dotviz visulation for cells

This commit is contained in:
Jack Turner 2020-09-07 11:31:11 +01:00 committed by GitHub
parent a1aa24c257
commit d19056f071
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

81
visualiser.py Normal file
View File

@ -0,0 +1,81 @@
import re
from graphviz import Digraph
import pandas as pd
import time
import argparse
parser = argparse.ArgumentParser(description='Fast cell visualisation')
parser.add_argument('--arch', default=1, type=int)
parser.add_argument('--save', action='store_true')
args = parser.parse_args()
def set_none(bit):
print(bit)
tmp = bit.split('~')
tmp[0] = 'none'
print('~'.join(tmp))
return '~'.join(tmp)
def remove_pointless_ops(archstr):
old = None
new = archstr
while old != new:
old = new
bits = old.strip('|').split('|')
if 'none~' in bits[0]: # node 1 has no connections to it
bits[3] = set_none(bits[3]) # node 1 -> 2 now none
bits[6] = set_none(bits[6]) # node 1 -> 3 now none
if 'none~' in bits[2] and 'none~' in bits[3]: # node 2 has no connections to it
bits[7] = set_none(bits[7]) # node 2 -> 3 now none
if 'none~' in bits[7]: # doesn't matter what comes through node 2
bits[2] = set_none(bits[2]) # node 0 -> 2 now none
bits[3] = set_none(bits[3]) # node 1 -> 2 now none
if 'none~' in bits[6] and 'none~' in bits[7]: # doesn't matter what comes through node 1
bits[0] = set_none(bits[0]) # node 0 -> 1 now none
new = '|'.join(bits)
print(new)
return new
df = pd.read_pickle('results/arch_score_acc.pd')
nodestr = df.iloc[args.arch]['cellstr']
nodestr = nodestr[1:-1] # remove leading and trailing bars |
nodestr = remove_pointless_ops(nodestr)
nodes = nodestr.split("|+|")
dot = Digraph(
format='pdf',
edge_attr=dict(fontsize='12'),
node_attr=dict(fixedsize='true',shape="circle", height='0.5', width='0.5'),
engine='dot')
dot.body.extend(['rankdir=LR'])
OPS = ['conv_3x3','avg_pool_3x3','skip_connect','conv_1x1','none']
dot.node('0', 'in')
## ops are separated by bars (|) so
for i, node in enumerate(nodes):
# if node 3 then label as output
if (i+1) == 3:
dot.node(str(i+1), 'out')
else:
dot.node(str(i+1))
for op_str in node.split('|'):
op_name = [o for o in OPS if o in op_str][0]
if op_name == 'none':
break
connect = re.findall('~[0-9]', op_str)[0]
connect = connect[1:]
dot.edge(connect,str(i+1), label=op_name)
dot.render( view=True)
if args.save:
dot.render(f'outputs/{args.arch}.gv')