# Usage: python ./exps/vis/test.py [path/to/architecture-checkpoint.pth]
import os
import sys
from pathlib import Path
import torch
import numpy as np
from collections import OrderedDict

# Put the repository's `lib` directory (two levels up from this file) on
# sys.path so that `nas_102_api` can be imported inside test_nas_api().
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))

# Checkpoint used when no path is supplied. This is the location the script
# was originally hard-wired to; pass a different path on other machines.
DEFAULT_ARCH_FILE = '/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-102-4/simplifies/architectures/000157-FULL.pth'


def test_nas_api(arch_file=DEFAULT_ARCH_FILE):
    """Load one saved architecture checkpoint and print what ArchResults reports.

    The checkpoint is read with ``torch.load`` and is expected to be a mapping
    containing the keys ``'full'`` and ``'less'``. Each entry is turned into an
    ``ArchResults`` object via ``ArchResults.create_from_state_dict`` and a
    series of its accessors is printed to stdout.

    Args:
        arch_file: Path of the checkpoint to load. Defaults to
            ``DEFAULT_ARCH_FILE`` so existing no-argument calls keep working.

    Returns:
        None. All output goes to stdout.
    """
    # Imported here (not at module level) because it relies on the sys.path
    # adjustment performed above.
    from nas_102_api import ArchResults

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point this at checkpoints you trust.
    xdata = torch.load(arch_file)

    for key in ['full', 'less']:
        print('\n------------------------- {:} -------------------------'.format(key))
        archRes = ArchResults.create_from_state_dict(xdata[key])
        print(archRes)
        print(archRes.arch_idx_str())
        print(archRes.get_dataset_names())
        print(archRes.get_comput_costs('cifar10-valid'))
        # get the metrics
        # NOTE(review): the last two positional arguments are presumably an
        # iteration selector and an "is_random" flag -- confirm against the
        # ArchResults.get_metrics signature.
        print(archRes.get_metrics('cifar10-valid', 'x-valid', None, False))
        print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True))
        # NOTE(review): 777 looks like a training seed -- confirm against
        # ArchResults.query.
        print(archRes.query('cifar10-valid', 777))


if __name__ == '__main__':
    # An optional first command-line argument overrides the default checkpoint.
    if len(sys.argv) > 1:
        test_nas_api(sys.argv[1])
    else:
        test_nas_api()