102->201 / NAS->autoDL / more configs of TAS / reorganize docs / fix bugs in NAS baselines
This commit is contained in:
		| @@ -8,11 +8,11 @@ from collections import OrderedDict | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from nas_102_api import NASBench102API as API | ||||
| from nas_201_api import NASBench201API as API | ||||
|  | ||||
| def test_nas_api(): | ||||
|   from nas_102_api import ArchResults | ||||
|   xdata   = torch.load('/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-102-4/simplifies/architectures/000157-FULL.pth') | ||||
|   from nas_201_api import ArchResults | ||||
|   xdata   = torch.load('/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-201-4/simplifies/architectures/000157-FULL.pth') | ||||
|   for key in ['full', 'less']: | ||||
|     print ('\n------------------------- {:} -------------------------'.format(key)) | ||||
|     archRes = ArchResults.create_from_state_dict(xdata[key]) | ||||
| @@ -81,8 +81,8 @@ def test_one_shot_model(ckpath, use_train): | ||||
|   from config_utils import load_config, dict2config | ||||
|   from utils.nas_utils import evaluate_one_shot | ||||
|   use_train = int(use_train) > 0 | ||||
|   #ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth' | ||||
|   #ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth' | ||||
|   #ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth' | ||||
|   #ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth' | ||||
|   print ('ckpath : {:}'.format(ckpath)) | ||||
|   ckp = torch.load(ckpath) | ||||
|   xargs = ckp['args'] | ||||
| @@ -103,7 +103,7 @@ def test_one_shot_model(ckpath, use_train): | ||||
|   search_model = get_cell_based_tiny_net(model_config) | ||||
|   search_model.load_state_dict( ckp['search_model'] ) | ||||
|   search_model = search_model.cuda() | ||||
|   api = API('/home/dxy/.torch/NAS-Bench-102-v1_0-e61699.pth') | ||||
|   api = API('/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth') | ||||
|   archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train) | ||||
|  | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user