wrote the get_nasbench201_idx_score
This commit is contained in:
		| @@ -9,6 +9,7 @@ from scores import get_score_func | |||||||
| from scipy import stats | from scipy import stats | ||||||
| import time | import time | ||||||
| # from pycls.models.nas.nas import Cell | # from pycls.models.nas.nas import Cell | ||||||
|  | from models import get_cell_based_tiny_net | ||||||
| from utils import add_dropout, init_network  | from utils import add_dropout, init_network  | ||||||
|  |  | ||||||
| parser = argparse.ArgumentParser(description='NAS Without Training') | parser = argparse.ArgumentParser(description='NAS Without Training') | ||||||
| @@ -56,11 +57,22 @@ def get_batch_jacobian(net, x, target, device, args=None): | |||||||
|     jacob = x.grad.detach() |     jacob = x.grad.detach() | ||||||
|     return jacob, target.detach(), y.detach(), out.detach() |     return jacob, target.detach(), y.detach(), out.detach() | ||||||
|  |  | ||||||
| def get_nasbench201_idx_score(idx, train_loader, searchspace, args): | def get_nasbench201_nodes_score(nodes, train_loader, searchspace, args, device): | ||||||
|     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |     op_type = { | ||||||
|  |     'input': 0, | ||||||
|  |     'nor_conv_1x1': 1, | ||||||
|  |     'nor_conv_3x3': 2, | ||||||
|  |     'avg_pool_3x3': 3, | ||||||
|  |     'skip_connect': 4, | ||||||
|  |     'none': 5, | ||||||
|  |     'output': 6, | ||||||
|  |     } | ||||||
|  |  | ||||||
|  | def get_nasbench201_idx_score(idx, train_loader, searchspace, args, device): | ||||||
|  |     # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||||||
|     # searchspace = nasspace.get_search_space(args) |     # searchspace = nasspace.get_search_space(args) | ||||||
|     if 'valid' in args.dataset: |     # if 'valid' in args.dataset: | ||||||
|         args.dataset = args.dataset.replace('-valid', '') |     #     args.dataset = args.dataset.replace('-valid', '') | ||||||
|  |  | ||||||
|     # train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) |     # train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) | ||||||
|     # os.makedirs(args.save_loc, exist_ok=True) |     # os.makedirs(args.save_loc, exist_ok=True) | ||||||
| @@ -182,17 +194,17 @@ train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, arg | |||||||
| print('start to get score') | print('start to get score') | ||||||
| print('5374') | print('5374') | ||||||
| start_time = time.time() | start_time = time.time() | ||||||
| print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args)) | print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) | ||||||
| end_time = time.time() | end_time = time.time() | ||||||
| print(f'5374 time: {end_time - start_time}') | print(f'5374 time: {end_time - start_time}') | ||||||
| print('5375') | print('5375') | ||||||
| start_time = time.time() | start_time = time.time() | ||||||
| print(get_nasbench201_idx_score(5375,train_loader=train_loader, searchspace=searchspace, args=args)) | print(get_nasbench201_idx_score(5375,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) | ||||||
| end_time = time.time() | end_time = time.time() | ||||||
| print(f'5375 time: {end_time - start_time}') | print(f'5375 time: {end_time - start_time}') | ||||||
| print('5376') | print('5376') | ||||||
| start_time = time.time() | start_time = time.time() | ||||||
| print(get_nasbench201_idx_score(5376,train_loader=train_loader, searchspace=searchspace, args=args)) | print(get_nasbench201_idx_score(5376,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) | ||||||
| end_time = time.time() | end_time = time.time() | ||||||
| print(f'5376 time: {end_time - start_time}') | print(f'5376 time: {end_time - start_time}') | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user