Compare commits
2 Commits
82183d3df7
...
63ca6c716e
Author | SHA1 | Date | |
---|---|---|---|
63ca6c716e | |||
d36e1d1077 |
15626
graph_dit/swap_results_aircraft.csv
Normal file
15626
graph_dit/swap_results_aircraft.csv
Normal file
File diff suppressed because it is too large
Load Diff
@ -10,6 +10,7 @@ api = API('./NAS-Bench-201-v1_1-096897.pth')
|
||||
parser = argparse.ArgumentParser(description='Process some integers.')
|
||||
|
||||
parser.add_argument('--file_path', type=str, default='211035.txt',)
|
||||
parser.add_argument('--datasets', type=str, default='cifar10',)
|
||||
args = parser.parse_args()
|
||||
|
||||
def process_graph_data(text):
|
||||
@ -89,6 +90,7 @@ def nodes_to_arch_str(nodes):
|
||||
return arch_str
|
||||
|
||||
filename = args.file_path
|
||||
datasets_name = args.datasets
|
||||
|
||||
with open('./output_graphs/' + filename, 'r') as f:
|
||||
texts = f.read()
|
||||
@ -96,7 +98,15 @@ with open('./output_graphs/' + filename, 'r') as f:
|
||||
valid = 0
|
||||
not_valid = 0
|
||||
scores = []
|
||||
dist = {'<90':0, '<91':0, '<92':0, '<93':0, '<94':0, '>94':0}
|
||||
|
||||
# 定义分类标准和分布字典的映射
|
||||
thresholds = {
|
||||
'cifar10': [90, 91, 92, 93, 94],
|
||||
'cifar100': [68,69,70, 71, 72, 73]
|
||||
}
|
||||
dist = {f'<{threshold}': 0 for threshold in thresholds[datasets_name]}
|
||||
dist[f'>{thresholds[datasets_name][-1]}'] = 0
|
||||
|
||||
for i in range(len(df)):
|
||||
nodes = df['nodes'][i]
|
||||
edges = df['edges'][i]
|
||||
@ -105,32 +115,30 @@ with open('./output_graphs/' + filename, 'r') as f:
|
||||
valid += 1
|
||||
arch_str = nodes_to_arch_str(nodes)
|
||||
index = api.query_index_by_arch(arch_str)
|
||||
# results = api.query_by_index(index, 'cifar10', hp='200')
|
||||
# print(results)
|
||||
# result = results[888].get_eval('ori-test')
|
||||
res = api.get_more_info(index, 'cifar10', None, hp=200, is_random=False)
|
||||
res = api.get_more_info(index, datasets_name, None, hp=200, is_random=False)
|
||||
acc = res['test-accuracy']
|
||||
scores.append((index, acc))
|
||||
if acc < 90:
|
||||
dist['<90'] += 1
|
||||
elif acc < 91 and acc >= 90:
|
||||
dist['<91'] += 1
|
||||
elif acc < 92 and acc >= 91:
|
||||
dist['<92'] += 1
|
||||
elif acc < 93 and acc >= 92:
|
||||
dist['<93'] += 1
|
||||
elif acc < 94 and acc >= 93:
|
||||
dist['<94'] += 1
|
||||
else:
|
||||
dist['>94'] += 1
|
||||
|
||||
# 根据阈值更新分布
|
||||
updated = False
|
||||
for threshold in thresholds[datasets_name]:
|
||||
if acc < threshold:
|
||||
dist[f'<{threshold}'] += 1
|
||||
updated = True
|
||||
break
|
||||
if not updated:
|
||||
dist[f'>{thresholds[datasets_name][-1]}'] += 1
|
||||
else:
|
||||
not_valid += 1
|
||||
with open('./output_graphs/' + filename + '.json', 'w') as f:
|
||||
|
||||
with open('./output_graphs/' + filename + '_' + datasets_name +'.json', 'w') as f:
|
||||
json.dump(scores, f)
|
||||
|
||||
print(scores)
|
||||
print(valid, not_valid)
|
||||
print(dist)
|
||||
print("mean: ", np.mean([x[1] for x in scores]))
|
||||
print("max: ", np.max([x[1] for x in scores]))
|
||||
print("min: ", np.min([x[1] for x in scores]))
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user