"""Render the cell genotypes stored in a checkpoint as Graphviz PDF plots.

The checkpoint is expected to hold a ``'genotypes'`` mapping whose values
expose ``normal`` and ``reduce`` attributes; one plot per key is written to
``<save_dir>/normal`` and ``<save_dir>/reduce``.
"""
import os
import sys
import time
import glob
import random
import argparse
from copy import deepcopy
from pathlib import Path

import numpy as np
import torch

# Make the sibling ``lib`` directory importable before anything else is
# loaded from it.
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))

from graphviz import Digraph

parser = argparse.ArgumentParser("Visualize the Networks")
# Both options are mandatory: without them the script previously crashed with
# an opaque TypeError instead of printing a usage message.
parser.add_argument('--checkpoint', type=str, required=True,
                    help='The path to the checkpoint.')
parser.add_argument('--save_dir', type=str, required=True,
                    help='The directory to save the network plot.')


def plot(genotype, filename):
    """Draw one cell genotype as a left-to-right graph and render it to PDF.

    Args:
        genotype: indexable sequence of ``(op, j, weight)`` triples with an
            even length; entries ``2*i`` and ``2*i + 1`` are the two incoming
            edges of intermediate node ``i``. ``j == 0`` selects the input
            ``c_{k-2}``, ``j == 1`` selects ``c_{k-1}``, and any other value
            selects intermediate node ``j - 2``. ``op`` is used as the edge
            label; ``weight`` is ignored.
        filename: output path handed to ``Digraph.render``.

    Raises:
        AssertionError: if ``len(genotype)`` is odd.
    """
    g = Digraph(
        format='pdf',
        edge_attr=dict(fontsize='20', fontname="times"),
        node_attr=dict(style='filled', shape='rect', align='center',
                       fontsize='20', height='0.5', width='0.5',
                       penwidth='2', fontname="times"),
        engine='dot')
    g.body.extend(['rankdir=LR'])

    # The two cell inputs.
    g.node("c_{k-2}", fillcolor='darkseagreen2')
    g.node("c_{k-1}", fillcolor='darkseagreen2')

    assert len(genotype) % 2 == 0, \
        'genotype must hold two edges per node, got {:} entries'.format(len(genotype))
    steps = len(genotype) // 2

    for i in range(steps):
        g.node(str(i), fillcolor='lightblue')

    for i in range(steps):
        for k in (2 * i, 2 * i + 1):
            op, j, _weight = genotype[k]
            if j == 0:
                u = "c_{k-2}"
            elif j == 1:
                u = "c_{k-1}"
            else:
                # Indices >= 2 refer to earlier intermediate nodes.
                u = str(j - 2)
            g.edge(u, str(i), label=op, fillcolor="gray")

    # Every intermediate node feeds the cell output.
    g.node("c_{k}", fillcolor='palegoldenrod')
    for i in range(steps):
        g.edge(str(i), "c_{k}", fillcolor="gray")

    g.render(filename, view=False)


def main(args):
    """Load the checkpoint named by ``args`` and plot every stored genotype.

    Args:
        args: namespace with ``checkpoint`` (path to the checkpoint file) and
            ``save_dir`` (directory that receives the plots).

    Raises:
        FileNotFoundError: if ``args.checkpoint`` is not an existing file.
    """
    checkpoint_path = args.checkpoint
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            'Invalid path for checkpoint : {:}'.format(checkpoint_path))

    # NOTE: torch.load unpickles the file; only open checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    genotypes = checkpoint['genotypes']

    save_dir = Path(args.save_dir)
    subs = ['normal', 'reduce']
    for sub in subs:
        (save_dir / sub).mkdir(parents=True, exist_ok=True)

    for key, network in genotypes.items():
        # ``normal`` first, then ``reduce`` -- same order as the sub-directories.
        for sub in subs:
            save_path = str(save_dir / sub / 'epoch-{:03d}'.format(int(key)))
            print('save into {:}'.format(save_path))
            plot(getattr(network, sub), save_path)


if __name__ == '__main__':
    # Parse only when executed as a script so that importing this module
    # (e.g. to reuse ``plot``) does not consume the host program's argv.
    args = parser.parse_args()
    main(args)