diff --git a/graph_dit/exp_201/barplog.png b/graph_dit/exp_201/barplog.png new file mode 100644 index 0000000..f59f79a Binary files /dev/null and b/graph_dit/exp_201/barplog.png differ diff --git a/graph_dit/exp_201/main.py b/graph_dit/exp_201/main.py index 2cad786..378f752 100644 --- a/graph_dit/exp_201/main.py +++ b/graph_dit/exp_201/main.py @@ -2,44 +2,45 @@ import matplotlib.pyplot as plt import pandas as pd from nas_201_api import NASBench201API as API -from naswot.score_networks import get_nasbench201_idx_score -from naswot import datasets as dt -from naswot import nasspace +# from naswot.score_networks import get_nasbench201_idx_score +# from naswot import datasets as dt +# from naswot import nasspace -class Args(): - pass -args = Args() -args.trainval = True -args.augtype = 'none' -args.repeat = 1 -args.score = 'hook_logdet' -args.sigma = 0.05 -args.nasspace = 'nasbench201' -args.batch_size = 128 -args.GPU = '0' -args.dataset = 'cifar10' -args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' -args.data_loc = '../cifardata/' -args.seed = 777 -args.init = '' -args.save_loc = 'results' -args.save_string = 'naswot' -args.dropout = False -args.maxofn = 1 -args.n_samples = 100 -args.n_runs = 500 -args.stem_out_channels = 16 -args.num_stacks = 3 -args.num_modules_per_stack = 3 -args.num_labels = 1 -searchspace = nasspace.get_search_space(args) -train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) -device = torch.device('cuda:2') +# class Args(): +# pass +# args = Args() +# args.trainval = True +# args.augtype = 'none' +# args.repeat = 1 +# args.score = 'hook_logdet' +# args.sigma = 0.05 +# args.nasspace = 'nasbench201' +# args.batch_size = 128 +# args.GPU = '0' +# args.dataset = 'cifar10' +# args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' +# args.data_loc = '../cifardata/' +# args.seed = 777 +# args.init = '' +# args.save_loc = 'results' +# args.save_string = 'naswot' +# args.dropout = False +# args.maxofn = 1 +# args.n_samples = 100 +# args.n_runs = 500 +# args.stem_out_channels = 16 +# args.num_stacks = 3 +# args.num_modules_per_stack = 3 +# args.num_labels = 1 +# searchspace = nasspace.get_search_space(args) +# train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) +# device = torch.device('cuda:2') + + +source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' +api = API(source) - -# source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' -# api = API(source) @@ -50,8 +51,10 @@ percentages = [] len_201 = 15625 for i in range(len_201): - percentage = get_nasbench201_idx_score(i, train_loader, searchspace, args, device) - percentages.append(percentage) + # percentage = get_nasbench201_idx_score(i, train_loader, searchspace, args, device) + results = api.query_by_index(i, 'cifar10') + result = results[111].get_eval('ori-test') + percentages.append(result) # 定义10%区间 bins = [i for i in range(0, 101, 10)]