adjust threshhold for cifar100

This commit is contained in:
mhz 2024-08-29 10:37:42 +02:00
parent 82183d3df7
commit d36e1d1077

View File

@ -10,6 +10,7 @@ api = API('./NAS-Bench-201-v1_1-096897.pth')
parser = argparse.ArgumentParser(description='Process some integers.') parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--file_path', type=str, default='211035.txt',) parser.add_argument('--file_path', type=str, default='211035.txt',)
parser.add_argument('--datasets', type=str, default='cifar10',)
args = parser.parse_args() args = parser.parse_args()
def process_graph_data(text): def process_graph_data(text):
@ -89,6 +90,7 @@ def nodes_to_arch_str(nodes):
return arch_str return arch_str
filename = args.file_path filename = args.file_path
datasets_name = args.datasets
with open('./output_graphs/' + filename, 'r') as f: with open('./output_graphs/' + filename, 'r') as f:
texts = f.read() texts = f.read()
@ -96,7 +98,15 @@ with open('./output_graphs/' + filename, 'r') as f:
valid = 0 valid = 0
not_valid = 0 not_valid = 0
scores = [] scores = []
dist = {'<90':0, '<91':0, '<92':0, '<93':0, '<94':0, '>94':0}
# 定义分类标准和分布字典的映射
thresholds = {
'cifar10': [90, 91, 92, 93, 94],
'cifar100': [68,69,70, 71, 72, 73]
}
dist = {f'<{threshold}': 0 for threshold in thresholds[datasets_name]}
dist[f'>{thresholds[datasets_name][-1]}'] = 0
for i in range(len(df)): for i in range(len(df)):
nodes = df['nodes'][i] nodes = df['nodes'][i]
edges = df['edges'][i] edges = df['edges'][i]
@ -105,32 +115,30 @@ with open('./output_graphs/' + filename, 'r') as f:
valid += 1 valid += 1
arch_str = nodes_to_arch_str(nodes) arch_str = nodes_to_arch_str(nodes)
index = api.query_index_by_arch(arch_str) index = api.query_index_by_arch(arch_str)
# results = api.query_by_index(index, 'cifar10', hp='200') res = api.get_more_info(index, datasets_name, None, hp=200, is_random=False)
# print(results)
# result = results[888].get_eval('ori-test')
res = api.get_more_info(index, 'cifar10', None, hp=200, is_random=False)
acc = res['test-accuracy'] acc = res['test-accuracy']
scores.append((index, acc)) scores.append((index, acc))
if acc < 90:
dist['<90'] += 1 # 根据阈值更新分布
elif acc < 91 and acc >= 90: updated = False
dist['<91'] += 1 for threshold in thresholds[datasets_name]:
elif acc < 92 and acc >= 91: if acc < threshold:
dist['<92'] += 1 dist[f'<{threshold}'] += 1
elif acc < 93 and acc >= 92: updated = True
dist['<93'] += 1 break
elif acc < 94 and acc >= 93: if not updated:
dist['<94'] += 1 dist[f'>{thresholds[datasets_name][-1]}'] += 1
else:
dist['>94'] += 1
else: else:
not_valid += 1 not_valid += 1
with open('./output_graphs/' + filename + '.json', 'w') as f:
with open('./output_graphs/' + filename + '_' + datasets_name +'.json', 'w') as f:
json.dump(scores, f) json.dump(scores, f)
print(scores) print(scores)
print(valid, not_valid) print(valid, not_valid)
print(dist) print(dist)
print("mean: ", np.mean([x[1] for x in scores])) print("mean: ", np.mean([x[1] for x in scores]))
print("max: ", np.max([x[1] for x in scores])) print("max: ", np.max([x[1] for x in scores]))
print("min: ", np.min([x[1] for x in scores])) print("min: ", np.min([x[1] for x in scores]))