diff --git a/graph_dit/naswot/score_networks.py b/graph_dit/naswot/score_networks.py index 76f3a21..e5a9c76 100644 --- a/graph_dit/naswot/score_networks.py +++ b/graph_dit/naswot/score_networks.py @@ -9,6 +9,7 @@ from scores import get_score_func from scipy import stats import time # from pycls.models.nas.nas import Cell +from models import get_cell_based_tiny_net from utils import add_dropout, init_network parser = argparse.ArgumentParser(description='NAS Without Training') @@ -56,11 +57,22 @@ def get_batch_jacobian(net, x, target, device, args=None): jacob = x.grad.detach() return jacob, target.detach(), y.detach(), out.detach() -def get_nasbench201_idx_score(idx, train_loader, searchspace, args): - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +def get_nasbench201_nodes_score(nodes, train_loader, searchspace, args, device): + op_type = { + 'input': 0, + 'nor_conv_1x1': 1, + 'nor_conv_3x3': 2, + 'avg_pool_3x3': 3, + 'skip_connect': 4, + 'none': 5, + 'output': 6, + } + +def get_nasbench201_idx_score(idx, train_loader, searchspace, args, device): + # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # searchspace = nasspace.get_search_space(args) - if 'valid' in args.dataset: - args.dataset = args.dataset.replace('-valid', '') + # if 'valid' in args.dataset: + # args.dataset = args.dataset.replace('-valid', '') # train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) # os.makedirs(args.save_loc, exist_ok=True) @@ -182,17 +194,17 @@ train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, arg print('start to get score') print('5374') start_time = time.time() -print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args)) +print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) end_time = time.time() print(f'5374 time: {end_time - start_time}') print('5375') start_time = time.time() -print(get_nasbench201_idx_score(5375,train_loader=train_loader, searchspace=searchspace, args=args)) +print(get_nasbench201_idx_score(5375,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) end_time = time.time() print(f'5375 time: {end_time - start_time}') print('5376') start_time = time.time() -print(get_nasbench201_idx_score(5376,train_loader=train_loader, searchspace=searchspace, args=args)) +print(get_nasbench201_idx_score(5376,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) end_time = time.time() print(f'5376 time: {end_time - start_time}')