Compare commits

...

2 Commits

Author SHA1 Message Date
mhz
63ca6c716e add the aircraft result 2024-09-01 23:09:56 +02:00
mhz
d36e1d1077 adjust threshhold for cifar100 2024-08-29 10:37:42 +02:00
2 changed files with 15652 additions and 18 deletions

File diff suppressed because it is too large Load Diff

View File

@ -10,6 +10,7 @@ api = API('./NAS-Bench-201-v1_1-096897.pth')
parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--file_path', type=str, default='211035.txt',)
parser.add_argument('--datasets', type=str, default='cifar10',)
args = parser.parse_args()
def process_graph_data(text):
@ -89,6 +90,7 @@ def nodes_to_arch_str(nodes):
return arch_str
filename = args.file_path
datasets_name = args.datasets
with open('./output_graphs/' + filename, 'r') as f:
texts = f.read()
@ -96,7 +98,15 @@ with open('./output_graphs/' + filename, 'r') as f:
valid = 0
not_valid = 0
scores = []
dist = {'<90':0, '<91':0, '<92':0, '<93':0, '<94':0, '>94':0}
# 定义分类标准和分布字典的映射
thresholds = {
'cifar10': [90, 91, 92, 93, 94],
'cifar100': [68,69,70, 71, 72, 73]
}
dist = {f'<{threshold}': 0 for threshold in thresholds[datasets_name]}
dist[f'>{thresholds[datasets_name][-1]}'] = 0
for i in range(len(df)):
nodes = df['nodes'][i]
edges = df['edges'][i]
@ -105,32 +115,30 @@ with open('./output_graphs/' + filename, 'r') as f:
valid += 1
arch_str = nodes_to_arch_str(nodes)
index = api.query_index_by_arch(arch_str)
# results = api.query_by_index(index, 'cifar10', hp='200')
# print(results)
# result = results[888].get_eval('ori-test')
res = api.get_more_info(index, 'cifar10', None, hp=200, is_random=False)
res = api.get_more_info(index, datasets_name, None, hp=200, is_random=False)
acc = res['test-accuracy']
scores.append((index, acc))
if acc < 90:
dist['<90'] += 1
elif acc < 91 and acc >= 90:
dist['<91'] += 1
elif acc < 92 and acc >= 91:
dist['<92'] += 1
elif acc < 93 and acc >= 92:
dist['<93'] += 1
elif acc < 94 and acc >= 93:
dist['<94'] += 1
else:
dist['>94'] += 1
# 根据阈值更新分布
updated = False
for threshold in thresholds[datasets_name]:
if acc < threshold:
dist[f'<{threshold}'] += 1
updated = True
break
if not updated:
dist[f'>{thresholds[datasets_name][-1]}'] += 1
else:
not_valid += 1
with open('./output_graphs/' + filename + '.json', 'w') as f:
with open('./output_graphs/' + filename + '_' + datasets_name +'.json', 'w') as f:
json.dump(scores, f)
print(scores)
print(valid, not_valid)
print(dist)
print("mean: ", np.mean([x[1] for x in scores]))
print("max: ", np.max([x[1] for x in scores]))
print("min: ", np.min([x[1] for x in scores]))