update vis
This commit is contained in:
		| @@ -1,11 +1,12 @@ | ||||
| # python ./exps/vis/test.py | ||||
| import os, sys | ||||
| import os, sys, random | ||||
| from pathlib import Path | ||||
| import torch | ||||
| import numpy as np | ||||
| from collections import OrderedDict | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from graphviz import Digraph | ||||
|  | ||||
|  | ||||
| def test_nas_api(): | ||||
| @@ -23,5 +24,35 @@ def test_nas_api(): | ||||
|     print(archRes.get_metrics('cifar10-valid', 'x-valid', None,  True)) | ||||
|     print(archRes.query('cifar10-valid', 777)) | ||||
|  | ||||
|  | ||||
| OPS    = ['skip-connect', 'conv-1x1', 'conv-3x3', 'pool-3x3'] | ||||
| COLORS = ['chartreuse'  , 'cyan'    , 'navyblue', 'chocolate1'] | ||||
|  | ||||
| def plot(filename): | ||||
|   g = Digraph( | ||||
|       format='png', | ||||
|       edge_attr=dict(fontsize='20', fontname="times"), | ||||
|       node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"), | ||||
|       engine='dot') | ||||
|   g.body.extend(['rankdir=LR']) | ||||
|  | ||||
|   steps = 5 | ||||
|   for i in range(0, steps): | ||||
|     if i == 0: | ||||
|       g.node(str(i), fillcolor='darkseagreen2') | ||||
|     elif i+1 == steps: | ||||
|       g.node(str(i), fillcolor='palegoldenrod') | ||||
|     else: g.node(str(i), fillcolor='lightblue') | ||||
|  | ||||
|   for i in range(1, steps): | ||||
|     for xin in range(i): | ||||
|       op_i = random.randint(0, len(OPS)-1) | ||||
|       #g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i]) | ||||
|       g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i]) | ||||
|       #import pdb; pdb.set_trace() | ||||
|   g.render(filename, cleanup=True, view=False) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   test_nas_api() | ||||
|   for i in range(200): plot('{:04d}'.format(i)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user