import argparse import nasspace import datasets import random import numpy as np import torch import os from scores import get_score_func from scipy import stats import time # from pycls.models.nas.nas import Cell from models import get_cell_based_tiny_net from utils import add_dropout, init_network parser = argparse.ArgumentParser(description='NAS Without Training') parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder') parser.add_argument('--api_loc', default='../NAS-Bench-201-v1_0-e61699.pth', type=str, help='path to API') parser.add_argument('--save_loc', default='results', type=str, help='folder to save results') parser.add_argument('--save_string', default='naswot', type=str, help='prefix of results file') parser.add_argument('--score', default='hook_logdet', type=str, help='the score to evaluate') parser.add_argument('--nasspace', default='nasbench201', type=str, help='the nas search space to use') parser.add_argument('--batch_size', default=128, type=int) parser.add_argument('--repeat', default=1, type=int, help='how often to repeat a single image with a batch') parser.add_argument('--augtype', default='none', type=str, help='which perturbations to use') parser.add_argument('--sigma', default=0.05, type=float, help='noise level if augtype is "gaussnoise"') parser.add_argument('--GPU', default='0', type=str) parser.add_argument('--seed', default=1, type=int) parser.add_argument('--init', default='', type=str) parser.add_argument('--trainval', action='store_true') parser.add_argument('--dropout', action='store_true') parser.add_argument('--dataset', default='cifar10', type=str) parser.add_argument('--maxofn', default=1, type=int, help='score is the max of this many evaluations of the network') parser.add_argument('--n_samples', default=100, type=int) parser.add_argument('--n_runs', default=500, type=int) parser.add_argument('--stem_out_channels', default=16, type=int, help='output channels of stem convolution (nasbench101)') parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules (nasbench101)') parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack (nasbench101)') parser.add_argument('--num_labels', default=1, type=int, help='#classes (nasbench101)') args = parser.parse_args() os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU # Reproducibility torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) def get_batch_jacobian(net, x, target, device, args=None): net.zero_grad() x.requires_grad_(True) y, out = net(x) y.backward(torch.ones_like(y)) jacob = x.grad.detach() return jacob, target.detach(), y.detach(), out.detach() def get_nasbench201_nodes_score(nodes, train_loader, searchspace, args, device): op_type = { 'input': 0, 'nor_conv_1x1': 1, 'nor_conv_3x3': 2, 'avg_pool_3x3': 3, 'skip_connect': 4, 'none': 5, 'output': 6, } def get_nasbench201_idx_score(idx, train_loader, searchspace, args, device): # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # searchspace = nasspace.get_search_space(args) # if 'valid' in args.dataset: # args.dataset = args.dataset.replace('-valid', '') # train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) # os.makedirs(args.save_loc, exist_ok=True) # filename = f'{args.save_loc}/{args.save_string}_{args.score}_{args.nasspace}_{args.dataset}{"_" + args.init + "_" if args.init != "" else args.init}_{"_dropout" if args.dropout else ""}_{args.augtype}_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.maxofn}_{args.seed}' # accfilename = f'{args.save_loc}/{args.save_string}_accs_{args.nasspace}_{args.dataset}_{args.trainval}' # scores = np.zeros(len(searchspace)) # accs = np.zeros(len(searchspace)) i = idx uid = idx print(f'uid: {uid}') print(f'get network') network = searchspace.get_network(uid) print(f'get network done') try: if args.dropout: add_dropout(network, args.sigma) if args.init != '': init_network(network, args.init) if 'hook_' in args.score: network.K = np.zeros((args.batch_size, args.batch_size)) def counting_forward_hook(module, inp, out): try: if not module.visited_backwards: return if isinstance(inp, tuple): # print(len(inp)) inp = inp[0] inp = inp.view(inp.size(0), -1) x = (inp > 0).float() K = x @ x.t() K2 = (1.-x) @ (1.-x.t()) network.K = network.K + K.cpu().numpy() + K2.cpu().numpy() except: pass def counting_backward_hook(module, inp, out): module.visited_backwards = True for name, module in network.named_modules(): if 'ReLU' in str(type(module)): #hooks[name] = module.register_forward_hook(counting_hook) module.register_forward_hook(counting_forward_hook) module.register_backward_hook(counting_backward_hook) network = network.to(device) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) s = [] for j in range(args.maxofn): data_iterator = iter(train_loader) x, target = next(data_iterator) x2 = torch.clone(x) x2 = x2.to(device) x, target = x.to(device), target.to(device) jacobs, labels, y, out = get_batch_jacobian(network, x, target, device, args) if 'hook_' in args.score: network(x2.to(device)) s.append(get_score_func(args.score)(network.K, target)) else: s.append(get_score_func(args.score)(jacobs, labels)) return np.mean(s) scores[i] = np.mean(s) accs[i] = searchspace.get_final_accuracy(uid, acc_type, args.trainval) accs_ = accs[~np.isnan(scores)] scores_ = scores[~np.isnan(scores)] numnan = np.isnan(scores).sum() tau, p = stats.kendalltau(accs_[:max(i-numnan, 1)], scores_[:max(i-numnan, 1)]) print(f'{tau}') if i % 1000 == 0: np.save(filename, scores) np.save(accfilename, accs) except Exception as e: print(e) print('final result') return np.nan class Args: pass args = Args() args.trainval = True args.augtype = 'none' args.repeat = 1 args.score = 'hook_logdet' args.sigma = 0.05 args.nasspace = 'nasbench201' args.batch_size = 128 args.GPU = '0' args.dataset = 'cifar10-valid' args.api_loc = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' args.data_loc = '../cifardata/' args.seed = 777 args.init = '' args.save_loc = 'results' args.save_string = 'naswot' args.dropout = False args.maxofn = 1 args.n_samples = 100 args.n_runs = 500 args.stem_out_channels = 16 args.num_stacks = 3 args.num_modules_per_stack = 3 args.num_labels = 1 if 'valid' in args.dataset: args.dataset = args.dataset.replace('-valid', '') print('start to get search space') start_time = time.time() searchspace = nasspace.get_search_space(args) end_time = time.time() print(f'search space time: {end_time - start_time}') train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) print('start to get score') print('5374') start_time = time.time() print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) end_time = time.time() print(f'5374 time: {end_time - start_time}') print('5375') start_time = time.time() print(get_nasbench201_idx_score(5375,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) end_time = time.time() print(f'5375 time: {end_time - start_time}') print('5376') start_time = time.time() print(get_nasbench201_idx_score(5376,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) end_time = time.time() print(f'5376 time: {end_time - start_time}') # device = "cuda:0" # dataset = dataset # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # savedataset = args.dataset # dataset = 'fake' if 'fake' in args.dataset else args.dataset # args.dataset = args.dataset.replace('fake', '') # if args.dataset == 'cifar10': # args.dataset = args.dataset + '-valid' # searchspace = nasspace.get_search_space(args) # if 'valid' in args.dataset: # args.dataset = args.dataset.replace('-valid', '') # train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) # os.makedirs(args.save_loc, exist_ok=True) # filename = f'{args.save_loc}/{args.save_string}_{args.score}_{args.nasspace}_{savedataset}{"_" + args.init + "_" if args.init != "" else args.init}_{"_dropout" if args.dropout else ""}_{args.augtype}_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.maxofn}_{args.seed}' # accfilename = f'{args.save_loc}/{args.save_string}_accs_{args.nasspace}_{savedataset}_{args.trainval}' # if args.dataset == 'cifar10': # acc_type = 'ori-test' # val_acc_type = 'x-valid' # else: # acc_type = 'x-test' # val_acc_type = 'x-valid' # scores = np.zeros(len(searchspace)) # try: # accs = np.load(accfilename + '.npy') # except: # accs = np.zeros(len(searchspace)) # for i, (uid, network) in enumerate(searchspace): # # Reproducibility # try: # if args.dropout: # add_dropout(network, args.sigma) # if args.init != '': # init_network(network, args.init) # if 'hook_' in args.score: # network.K = np.zeros((args.batch_size, args.batch_size)) # def counting_forward_hook(module, inp, out): # try: # if not module.visited_backwards: # return # if isinstance(inp, tuple): # print(len(inp)) # inp = inp[0] # inp = inp.view(inp.size(0), -1) # x = (inp > 0).float() # K = x @ x.t() # K2 = (1.-x) @ (1.-x.t()) # network.K = network.K + K.cpu().numpy() + K2.cpu().numpy() # except: # pass # def counting_backward_hook(module, inp, out): # module.visited_backwards = True # for name, module in network.named_modules(): # if 'ReLU' in str(type(module)): # #hooks[name] = module.register_forward_hook(counting_hook) # module.register_forward_hook(counting_forward_hook) # module.register_backward_hook(counting_backward_hook) # network = network.to(device) # random.seed(args.seed) # np.random.seed(args.seed) # torch.manual_seed(args.seed) # s = [] # for j in range(args.maxofn): # data_iterator = iter(train_loader) # x, target = next(data_iterator) # x2 = torch.clone(x) # x2 = x2.to(device) # x, target = x.to(device), target.to(device) # jacobs, labels, y, out = get_batch_jacobian(network, x, target, device, args) # if 'hook_' in args.score: # network(x2.to(device)) # s.append(get_score_func(args.score)(network.K, target)) # else: # s.append(get_score_func(args.score)(jacobs, labels)) # scores[i] = np.mean(s) # accs[i] = searchspace.get_final_accuracy(uid, acc_type, args.trainval) # accs_ = accs[~np.isnan(scores)] # scores_ = scores[~np.isnan(scores)] # numnan = np.isnan(scores).sum() # tau, p = stats.kendalltau(accs_[:max(i-numnan, 1)], scores_[:max(i-numnan, 1)]) # print(f'{tau}') # if i % 1000 == 0: # np.save(filename, scores) # np.save(accfilename, accs) # except Exception as e: # print(e) # accs[i] = searchspace.get_final_accuracy(uid, acc_type, args.trainval) # scores[i] = np.nan # np.save(filename, scores) # np.save(accfilename, accs)