# File: xautodl/exps/vis/test.py
# Last modified: 2019-12-20 20:41:49 +11:00
# Original size: 28 lines, 1.0 KiB (Python)

# python ./exps/vis/test.py
import os, sys
from pathlib import Path
import torch
import numpy as np
from collections import OrderedDict
# Make the `lib` directory (two levels above this file's directory) importable,
# giving it priority over existing sys.path entries; skip it if already present.
lib_dir = Path(__file__).parent.joinpath('..', '..', 'lib').resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
def test_nas_api(arch_file=None):
    """Print what ``ArchResults`` reports for one saved architecture file.

    This is a manual smoke check rather than an automated test: it asserts
    nothing and only prints. The file is loaded with ``torch.load``. For each
    of its ``'full'`` and ``'less'`` entries, an ``ArchResults`` object is
    rebuilt with ``ArchResults.create_from_state_dict``, and the outputs of
    several of its query methods are printed.

    Args:
        arch_file: Path to the ``.pth`` file to load. When ``None`` (the
            default), the original hard-coded location is used. That path
            exists only on the original author's machine, so pass an explicit
            path anywhere else.
    """
    # Presumably resolved from the `lib` directory that this module inserts
    # into sys.path at import time.
    from nas_102_api import ArchResults

    if arch_file is None:
        arch_file = '/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-102-4/simplifies/architectures/000157-FULL.pth'
    # Map any stored tensors onto the CPU so a file written on a GPU machine
    # can still be loaded where CUDA is unavailable.
    xdata = torch.load(arch_file, map_location='cpu')
    for key in ['full', 'less']:
        print('\n------------------------- {:} -------------------------'.format(key))
        arch_res = ArchResults.create_from_state_dict(xdata[key])
        print(arch_res)
        print(arch_res.arch_idx_str())
        print(arch_res.get_dataset_names())
        print(arch_res.get_comput_costs('cifar10-valid'))
        # Metrics for ('cifar10-valid', 'x-valid'). The two calls differ only in
        # the final boolean flag; its meaning is defined in nas_102_api
        # (presumably a random-trial toggle -- confirm there).
        print(arch_res.get_metrics('cifar10-valid', 'x-valid', None, False))
        print(arch_res.get_metrics('cifar10-valid', 'x-valid', None, True))
        # 777 is presumably a trial seed -- verify against ArchResults.query.
        print(arch_res.query('cifar10-valid', 777))
if __name__ == '__main__':
    # Run the smoke check only when this file is executed directly, not when
    # it is imported (e.g. collected by a test runner).
    test_nas_api()