update the script to use nasbench-201 api
This commit is contained in:
parent
6d9db64a48
commit
05ee34e355
BIN
graph_dit/exp_201/barplog.png
Normal file
BIN
graph_dit/exp_201/barplog.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 30 KiB |
@ -2,44 +2,45 @@
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from nas_201_api import NASBench201API as API
|
from nas_201_api import NASBench201API as API
|
||||||
from naswot.score_networks import get_nasbench201_idx_score
|
# from naswot.score_networks import get_nasbench201_idx_score
|
||||||
from naswot import datasets as dt
|
# from naswot import datasets as dt
|
||||||
from naswot import nasspace
|
# from naswot import nasspace
|
||||||
|
|
||||||
class Args():
|
# class Args():
|
||||||
pass
|
# pass
|
||||||
args = Args()
|
# args = Args()
|
||||||
args.trainval = True
|
# args.trainval = True
|
||||||
args.augtype = 'none'
|
# args.augtype = 'none'
|
||||||
args.repeat = 1
|
# args.repeat = 1
|
||||||
args.score = 'hook_logdet'
|
# args.score = 'hook_logdet'
|
||||||
args.sigma = 0.05
|
# args.sigma = 0.05
|
||||||
args.nasspace = 'nasbench201'
|
# args.nasspace = 'nasbench201'
|
||||||
args.batch_size = 128
|
# args.batch_size = 128
|
||||||
args.GPU = '0'
|
# args.GPU = '0'
|
||||||
args.dataset = 'cifar10'
|
# args.dataset = 'cifar10'
|
||||||
args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
# args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
||||||
args.data_loc = '../cifardata/'
|
# args.data_loc = '../cifardata/'
|
||||||
args.seed = 777
|
# args.seed = 777
|
||||||
args.init = ''
|
# args.init = ''
|
||||||
args.save_loc = 'results'
|
# args.save_loc = 'results'
|
||||||
args.save_string = 'naswot'
|
# args.save_string = 'naswot'
|
||||||
args.dropout = False
|
# args.dropout = False
|
||||||
args.maxofn = 1
|
# args.maxofn = 1
|
||||||
args.n_samples = 100
|
# args.n_samples = 100
|
||||||
args.n_runs = 500
|
# args.n_runs = 500
|
||||||
args.stem_out_channels = 16
|
# args.stem_out_channels = 16
|
||||||
args.num_stacks = 3
|
# args.num_stacks = 3
|
||||||
args.num_modules_per_stack = 3
|
# args.num_modules_per_stack = 3
|
||||||
args.num_labels = 1
|
# args.num_labels = 1
|
||||||
searchspace = nasspace.get_search_space(args)
|
# searchspace = nasspace.get_search_space(args)
|
||||||
train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
|
# train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
|
||||||
device = torch.device('cuda:2')
|
# device = torch.device('cuda:2')
|
||||||
|
|
||||||
|
|
||||||
|
source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
||||||
|
api = API(source)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
|
||||||
# api = API(source)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -50,8 +51,10 @@ percentages = []
|
|||||||
len_201 = 15625
|
len_201 = 15625
|
||||||
|
|
||||||
for i in range(len_201):
|
for i in range(len_201):
|
||||||
percentage = get_nasbench201_idx_score(i, train_loader, searchspace, args, device)
|
# percentage = get_nasbench201_idx_score(i, train_loader, searchspace, args, device)
|
||||||
percentages.append(percentage)
|
results = api.query_by_index(i, 'cifar10')
|
||||||
|
result = results[111].get_eval('ori-test')
|
||||||
|
percentages.append(result)
|
||||||
|
|
||||||
# 定义10%区间
|
# 定义10%区间
|
||||||
bins = [i for i in range(0, 101, 10)]
|
bins = [i for i in range(0, 101, 10)]
|
||||||
|
Loading…
Reference in New Issue
Block a user