dotviz visulation for cells
This commit is contained in:
parent
a1aa24c257
commit
d19056f071
81
visualiser.py
Normal file
81
visualiser.py
Normal file
@ -0,0 +1,81 @@
|
||||
import re
|
||||
from graphviz import Digraph
|
||||
import pandas as pd
|
||||
import time
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='Fast cell visualisation')
|
||||
parser.add_argument('--arch', default=1, type=int)
|
||||
parser.add_argument('--save', action='store_true')
|
||||
args = parser.parse_args()
|
||||
|
||||
def set_none(bit):
|
||||
print(bit)
|
||||
tmp = bit.split('~')
|
||||
tmp[0] = 'none'
|
||||
print('~'.join(tmp))
|
||||
return '~'.join(tmp)
|
||||
|
||||
def remove_pointless_ops(archstr):
|
||||
old = None
|
||||
new = archstr
|
||||
while old != new:
|
||||
old = new
|
||||
bits = old.strip('|').split('|')
|
||||
if 'none~' in bits[0]: # node 1 has no connections to it
|
||||
bits[3] = set_none(bits[3]) # node 1 -> 2 now none
|
||||
bits[6] = set_none(bits[6]) # node 1 -> 3 now none
|
||||
if 'none~' in bits[2] and 'none~' in bits[3]: # node 2 has no connections to it
|
||||
bits[7] = set_none(bits[7]) # node 2 -> 3 now none
|
||||
if 'none~' in bits[7]: # doesn't matter what comes through node 2
|
||||
bits[2] = set_none(bits[2]) # node 0 -> 2 now none
|
||||
bits[3] = set_none(bits[3]) # node 1 -> 2 now none
|
||||
if 'none~' in bits[6] and 'none~' in bits[7]: # doesn't matter what comes through node 1
|
||||
bits[0] = set_none(bits[0]) # node 0 -> 1 now none
|
||||
new = '|'.join(bits)
|
||||
print(new)
|
||||
return new
|
||||
|
||||
|
||||
df = pd.read_pickle('results/arch_score_acc.pd')
|
||||
|
||||
nodestr = df.iloc[args.arch]['cellstr']
|
||||
nodestr = nodestr[1:-1] # remove leading and trailing bars |
|
||||
|
||||
nodestr = remove_pointless_ops(nodestr)
|
||||
nodes = nodestr.split("|+|")
|
||||
|
||||
dot = Digraph(
|
||||
format='pdf',
|
||||
edge_attr=dict(fontsize='12'),
|
||||
node_attr=dict(fixedsize='true',shape="circle", height='0.5', width='0.5'),
|
||||
engine='dot')
|
||||
|
||||
dot.body.extend(['rankdir=LR'])
|
||||
|
||||
OPS = ['conv_3x3','avg_pool_3x3','skip_connect','conv_1x1','none']
|
||||
|
||||
dot.node('0', 'in')
|
||||
|
||||
## ops are separated by bars (|) so
|
||||
for i, node in enumerate(nodes):
|
||||
|
||||
# if node 3 then label as output
|
||||
if (i+1) == 3:
|
||||
dot.node(str(i+1), 'out')
|
||||
else:
|
||||
dot.node(str(i+1))
|
||||
|
||||
for op_str in node.split('|'):
|
||||
op_name = [o for o in OPS if o in op_str][0]
|
||||
if op_name == 'none':
|
||||
break
|
||||
connect = re.findall('~[0-9]', op_str)[0]
|
||||
connect = connect[1:]
|
||||
dot.edge(connect,str(i+1), label=op_name)
|
||||
|
||||
dot.render( view=True)
|
||||
|
||||
|
||||
if args.save:
|
||||
dot.render(f'outputs/{args.arch}.gv')
|
Loading…
Reference in New Issue
Block a user