dotviz visulation for cells
This commit is contained in:
parent
a1aa24c257
commit
d19056f071
81
visualiser.py
Normal file
81
visualiser.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
import re
|
||||||
|
from graphviz import Digraph
|
||||||
|
import pandas as pd
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Fast cell visualisation')
|
||||||
|
parser.add_argument('--arch', default=1, type=int)
|
||||||
|
parser.add_argument('--save', action='store_true')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
def set_none(bit):
|
||||||
|
print(bit)
|
||||||
|
tmp = bit.split('~')
|
||||||
|
tmp[0] = 'none'
|
||||||
|
print('~'.join(tmp))
|
||||||
|
return '~'.join(tmp)
|
||||||
|
|
||||||
|
def remove_pointless_ops(archstr):
|
||||||
|
old = None
|
||||||
|
new = archstr
|
||||||
|
while old != new:
|
||||||
|
old = new
|
||||||
|
bits = old.strip('|').split('|')
|
||||||
|
if 'none~' in bits[0]: # node 1 has no connections to it
|
||||||
|
bits[3] = set_none(bits[3]) # node 1 -> 2 now none
|
||||||
|
bits[6] = set_none(bits[6]) # node 1 -> 3 now none
|
||||||
|
if 'none~' in bits[2] and 'none~' in bits[3]: # node 2 has no connections to it
|
||||||
|
bits[7] = set_none(bits[7]) # node 2 -> 3 now none
|
||||||
|
if 'none~' in bits[7]: # doesn't matter what comes through node 2
|
||||||
|
bits[2] = set_none(bits[2]) # node 0 -> 2 now none
|
||||||
|
bits[3] = set_none(bits[3]) # node 1 -> 2 now none
|
||||||
|
if 'none~' in bits[6] and 'none~' in bits[7]: # doesn't matter what comes through node 1
|
||||||
|
bits[0] = set_none(bits[0]) # node 0 -> 1 now none
|
||||||
|
new = '|'.join(bits)
|
||||||
|
print(new)
|
||||||
|
return new
|
||||||
|
|
||||||
|
|
||||||
|
df = pd.read_pickle('results/arch_score_acc.pd')
|
||||||
|
|
||||||
|
nodestr = df.iloc[args.arch]['cellstr']
|
||||||
|
nodestr = nodestr[1:-1] # remove leading and trailing bars |
|
||||||
|
|
||||||
|
nodestr = remove_pointless_ops(nodestr)
|
||||||
|
nodes = nodestr.split("|+|")
|
||||||
|
|
||||||
|
dot = Digraph(
|
||||||
|
format='pdf',
|
||||||
|
edge_attr=dict(fontsize='12'),
|
||||||
|
node_attr=dict(fixedsize='true',shape="circle", height='0.5', width='0.5'),
|
||||||
|
engine='dot')
|
||||||
|
|
||||||
|
dot.body.extend(['rankdir=LR'])
|
||||||
|
|
||||||
|
OPS = ['conv_3x3','avg_pool_3x3','skip_connect','conv_1x1','none']
|
||||||
|
|
||||||
|
dot.node('0', 'in')
|
||||||
|
|
||||||
|
## ops are separated by bars (|) so
|
||||||
|
for i, node in enumerate(nodes):
|
||||||
|
|
||||||
|
# if node 3 then label as output
|
||||||
|
if (i+1) == 3:
|
||||||
|
dot.node(str(i+1), 'out')
|
||||||
|
else:
|
||||||
|
dot.node(str(i+1))
|
||||||
|
|
||||||
|
for op_str in node.split('|'):
|
||||||
|
op_name = [o for o in OPS if o in op_str][0]
|
||||||
|
if op_name == 'none':
|
||||||
|
break
|
||||||
|
connect = re.findall('~[0-9]', op_str)[0]
|
||||||
|
connect = connect[1:]
|
||||||
|
dot.edge(connect,str(i+1), label=op_name)
|
||||||
|
|
||||||
|
dot.render( view=True)
|
||||||
|
|
||||||
|
|
||||||
|
if args.save:
|
||||||
|
dot.render(f'outputs/{args.arch}.gv')
|
Loading…
Reference in New Issue
Block a user