# python ./exps/vis/test.py
"""Ad-hoc checks: NAS-Bench-201 record queries, cell plots, autograd, one-shot evaluation."""
import os
import sys
import random
from pathlib import Path
from copy import deepcopy
import torch
import numpy as np
from collections import OrderedDict

# Make the repository's ``lib`` directory importable before the project imports below.
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))

from nas_201_api import NASBench201API as API


def test_nas_api():
    """Load one saved architecture record and print what ``ArchResults`` reports for it.

    The record is read from a hard-coded path and is indexed by the keys
    ``"full"`` and ``"less"``; each entry is rebuilt with
    ``ArchResults.create_from_state_dict`` and queried on ``cifar10-valid``.
    """
    from nas_201_api import ArchResults

    xdata = torch.load(
        "/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-201-4/simplifies/architectures/000157-FULL.pth"
    )
    for key in ["full", "less"]:
        print("\n------------------------- {:} -------------------------".format(key))
        archRes = ArchResults.create_from_state_dict(xdata[key])
        print(archRes)
        print(archRes.arch_idx_str())
        print(archRes.get_dataset_names())
        print(archRes.get_comput_costs("cifar10-valid"))
        # get the metrics
        print(archRes.get_metrics("cifar10-valid", "x-valid", None, False))
        print(archRes.get_metrics("cifar10-valid", "x-valid", None, True))
        print(archRes.query("cifar10-valid", 777))


# Edge labels and the graphviz colour used for each one; the two lists are index-aligned.
OPS = ["skip-connect", "conv-1x1", "conv-3x3", "pool-3x3"]
COLORS = ["chartreuse", "cyan", "navyblue", "chocolate1"]


def plot(filename):
    """Render a 5-node graph with randomly chosen operations on its edges.

    Every node receives one edge from each lower-numbered node; the operation
    on each edge is drawn uniformly from ``OPS`` and coloured with the matching
    ``COLORS`` entry.

    Args:
        filename: Output name passed to ``Digraph.render`` (PNG format,
            intermediate source file cleaned up, viewer not opened).
    """
    from graphviz import Digraph

    g = Digraph(
        format="png",
        edge_attr=dict(fontsize="20", fontname="times"),
        node_attr=dict(
            style="filled",
            shape="rect",
            align="center",
            fontsize="20",
            height="0.5",
            width="0.5",
            penwidth="2",
            fontname="times",
        ),
        engine="dot",
    )
    g.body.extend(["rankdir=LR"])

    steps = 5
    # The first and last nodes get their own fill colours; the rest share one.
    for i in range(0, steps):
        if i == 0:
            g.node(str(i), fillcolor="darkseagreen2")
        elif i + 1 == steps:
            g.node(str(i), fillcolor="palegoldenrod")
        else:
            g.node(str(i), fillcolor="lightblue")

    for i in range(1, steps):
        for xin in range(i):
            op_i = random.randint(0, len(OPS) - 1)
            g.edge(
                str(xin),
                str(i),
                label=OPS[op_i],
                color=COLORS[op_i],
                fillcolor=COLORS[op_i],
            )
    g.render(filename, cleanup=True, view=False)


def test_auto_grad():
    """Compute second-order gradients of a tiny network, one first-order entry at a time.

    Returns:
        A list with one item per scalar entry of the flattened first-order
        gradient; each item is the tuple returned by ``torch.autograd.grad``
        for that entry with respect to the network parameters.
    """

    class Net(torch.nn.Module):
        """Linear layer followed by ``exp`` and a mean over all outputs."""

        def __init__(self, iS):
            super().__init__()
            self.layer = torch.nn.Linear(iS, 1)

        def forward(self, inputs):
            outputs = self.layer(inputs)
            outputs = torch.exp(outputs)
            return outputs.mean()

    net = Net(10)
    inputs = torch.rand(256, 10)
    loss = net(inputs)
    # Materialise the parameters once so every grad call sees the same sequence.
    params = list(net.parameters())
    # create_graph=True records the backward pass so it can be differentiated again.
    first_order_grads = torch.autograd.grad(
        loss, params, retain_graph=True, create_graph=True
    )
    first_order_grads = torch.cat([x.reshape(-1) for x in first_order_grads])
    second_order_grads = []
    for grads in first_order_grads:
        # All entries share one graph; without retain_graph=True the first call
        # frees its buffers and the next iteration raises a RuntimeError.
        s_grads = torch.autograd.grad(grads, params, retain_graph=True)
        second_order_grads.append(s_grads)
    return second_order_grads


def test_one_shot_model(ckpath, use_train):
    """Rebuild a search model from a checkpoint and run ``evaluate_one_shot`` on it.

    Args:
        ckpath: Path of the checkpoint to load; it must provide the keys
            ``"args"`` and ``"search_model"``. Example paths:
            output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth
            output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth
        use_train: Value convertible with ``int``; any value greater than 0 is
            forwarded to ``evaluate_one_shot`` as ``True``.

    Returns:
        The ``(archs, probs, accuracies)`` tuple produced by ``evaluate_one_shot``.

    Raises:
        ValueError: If the dataset stored in the checkpoint is not ``"cifar10"``.
    """
    from models import get_cell_based_tiny_net, get_search_spaces
    from datasets import get_datasets, SearchDataset
    from config_utils import load_config, dict2config
    from utils.nas_utils import evaluate_one_shot

    use_train = int(use_train) > 0
    print("ckpath : {:}".format(ckpath))
    ckp = torch.load(ckpath)
    xargs = ckp["args"]
    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1
    )
    # NOTE: the config path is hard-coded here rather than taken from
    # xargs.config_path, and the loaded config is not used further below.
    config = load_config(
        "./configs/nas-benchmark/algos/DARTS.config",
        {"class_num": class_num, "xshape": xshape},
        None,
    )
    if xargs.dataset != "cifar10":
        # Fixed attribute typo (`dataseet`) that turned this into an AttributeError.
        raise ValueError("invalid dataset : {:}".format(xargs.dataset))

    cifar_split = load_config("configs/nas-benchmark/cifar-split.txt", None, None)
    # Validation samples are drawn from the training set (indices in
    # cifar_split.valid) but use the validation-time transform.
    xvalid_data = deepcopy(train_data)
    xvalid_data.transform = valid_data.transform
    valid_loader = torch.utils.data.DataLoader(
        xvalid_data,
        batch_size=2048,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid),
        num_workers=12,
        pin_memory=True,
    )

    search_space = get_search_spaces("cell", xargs.search_space_name)
    model_config = dict2config(
        {
            "name": "SETN",
            "C": xargs.channel,
            "N": xargs.num_cells,
            "max_nodes": xargs.max_nodes,
            "num_classes": class_num,
            "space": search_space,
            "affine": False,
            "track_running_stats": True,
        },
        None,
    )
    search_model = get_cell_based_tiny_net(model_config)
    search_model.load_state_dict(ckp["search_model"])
    search_model = search_model.cuda()
    api = API("/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth")
    archs, probs, accuracies = evaluate_one_shot(
        search_model, valid_loader, api, use_train
    )
    return archs, probs, accuracies


if __name__ == "__main__":
    # Other entry points that can be swapped in by hand:
    #   test_nas_api()
    #   for i in range(200): plot('{:04d}'.format(i))
    #   test_auto_grad()
    if len(sys.argv) < 3:
        raise SystemExit(
            "usage: python {:} <ckpath> <use_train>".format(sys.argv[0])
        )
    test_one_shot_model(sys.argv[1], sys.argv[2])