Compare commits
No commits in common. "a0473008a1b59f7beb7e6698e95873fca398bed8" and "6d9db64a48095fb226b73c1fb69571928897975c" have entirely different histories.
a0473008a1
...
6d9db64a48
Binary file not shown.
Before Width: | Height: | Size: 30 KiB |
@ -2,45 +2,44 @@
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from nas_201_api import NASBench201API as API
|
from nas_201_api import NASBench201API as API
|
||||||
# from naswot.score_networks import get_nasbench201_idx_score
|
from naswot.score_networks import get_nasbench201_idx_score
|
||||||
# from naswot import datasets as dt
|
from naswot import datasets as dt
|
||||||
# from naswot import nasspace
|
from naswot import nasspace
|
||||||
|
|
||||||
# class Args():
|
class Args():
|
||||||
# pass
|
pass
|
||||||
# args = Args()
|
args = Args()
|
||||||
# args.trainval = True
|
args.trainval = True
|
||||||
# args.augtype = 'none'
|
args.augtype = 'none'
|
||||||
# args.repeat = 1
|
args.repeat = 1
|
||||||
# args.score = 'hook_logdet'
|
args.score = 'hook_logdet'
|
||||||
# args.sigma = 0.05
|
args.sigma = 0.05
|
||||||
# args.nasspace = 'nasbench201'
|
args.nasspace = 'nasbench201'
|
||||||
# args.batch_size = 128
|
args.batch_size = 128
|
||||||
# args.GPU = '0'
|
args.GPU = '0'
|
||||||
# args.dataset = 'cifar10'
|
args.dataset = 'cifar10'
|
||||||
# args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
||||||
# args.data_loc = '../cifardata/'
|
args.data_loc = '../cifardata/'
|
||||||
# args.seed = 777
|
args.seed = 777
|
||||||
# args.init = ''
|
args.init = ''
|
||||||
# args.save_loc = 'results'
|
args.save_loc = 'results'
|
||||||
# args.save_string = 'naswot'
|
args.save_string = 'naswot'
|
||||||
# args.dropout = False
|
args.dropout = False
|
||||||
# args.maxofn = 1
|
args.maxofn = 1
|
||||||
# args.n_samples = 100
|
args.n_samples = 100
|
||||||
# args.n_runs = 500
|
args.n_runs = 500
|
||||||
# args.stem_out_channels = 16
|
args.stem_out_channels = 16
|
||||||
# args.num_stacks = 3
|
args.num_stacks = 3
|
||||||
# args.num_modules_per_stack = 3
|
args.num_modules_per_stack = 3
|
||||||
# args.num_labels = 1
|
args.num_labels = 1
|
||||||
# searchspace = nasspace.get_search_space(args)
|
searchspace = nasspace.get_search_space(args)
|
||||||
# train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
|
train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
|
||||||
# device = torch.device('cuda:2')
|
device = torch.device('cuda:2')
|
||||||
|
|
||||||
|
|
||||||
source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
|
||||||
api = API(source)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
||||||
|
# api = API(source)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -51,10 +50,8 @@ percentages = []
|
|||||||
len_201 = 15625
|
len_201 = 15625
|
||||||
|
|
||||||
for i in range(len_201):
|
for i in range(len_201):
|
||||||
# percentage = get_nasbench201_idx_score(i, train_loader, searchspace, args, device)
|
percentage = get_nasbench201_idx_score(i, train_loader, searchspace, args, device)
|
||||||
results = api.query_by_index(i, 'cifar10')
|
percentages.append(percentage)
|
||||||
result = results[111].get_eval('ori-test')
|
|
||||||
percentages.append(result)
|
|
||||||
|
|
||||||
# 定义10%区间
|
# 定义10%区间
|
||||||
bins = [i for i in range(0, 101, 10)]
|
bins = [i for i in range(0, 101, 10)]
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user