adjust threshhold for cifar100
This commit is contained in:
parent
82183d3df7
commit
d36e1d1077
@ -10,6 +10,7 @@ api = API('./NAS-Bench-201-v1_1-096897.pth')
|
|||||||
parser = argparse.ArgumentParser(description='Process some integers.')
|
parser = argparse.ArgumentParser(description='Process some integers.')
|
||||||
|
|
||||||
parser.add_argument('--file_path', type=str, default='211035.txt',)
|
parser.add_argument('--file_path', type=str, default='211035.txt',)
|
||||||
|
parser.add_argument('--datasets', type=str, default='cifar10',)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
def process_graph_data(text):
|
def process_graph_data(text):
|
||||||
@ -89,6 +90,7 @@ def nodes_to_arch_str(nodes):
|
|||||||
return arch_str
|
return arch_str
|
||||||
|
|
||||||
filename = args.file_path
|
filename = args.file_path
|
||||||
|
datasets_name = args.datasets
|
||||||
|
|
||||||
with open('./output_graphs/' + filename, 'r') as f:
|
with open('./output_graphs/' + filename, 'r') as f:
|
||||||
texts = f.read()
|
texts = f.read()
|
||||||
@ -96,7 +98,15 @@ with open('./output_graphs/' + filename, 'r') as f:
|
|||||||
valid = 0
|
valid = 0
|
||||||
not_valid = 0
|
not_valid = 0
|
||||||
scores = []
|
scores = []
|
||||||
dist = {'<90':0, '<91':0, '<92':0, '<93':0, '<94':0, '>94':0}
|
|
||||||
|
# 定义分类标准和分布字典的映射
|
||||||
|
thresholds = {
|
||||||
|
'cifar10': [90, 91, 92, 93, 94],
|
||||||
|
'cifar100': [68,69,70, 71, 72, 73]
|
||||||
|
}
|
||||||
|
dist = {f'<{threshold}': 0 for threshold in thresholds[datasets_name]}
|
||||||
|
dist[f'>{thresholds[datasets_name][-1]}'] = 0
|
||||||
|
|
||||||
for i in range(len(df)):
|
for i in range(len(df)):
|
||||||
nodes = df['nodes'][i]
|
nodes = df['nodes'][i]
|
||||||
edges = df['edges'][i]
|
edges = df['edges'][i]
|
||||||
@ -105,32 +115,30 @@ with open('./output_graphs/' + filename, 'r') as f:
|
|||||||
valid += 1
|
valid += 1
|
||||||
arch_str = nodes_to_arch_str(nodes)
|
arch_str = nodes_to_arch_str(nodes)
|
||||||
index = api.query_index_by_arch(arch_str)
|
index = api.query_index_by_arch(arch_str)
|
||||||
# results = api.query_by_index(index, 'cifar10', hp='200')
|
res = api.get_more_info(index, datasets_name, None, hp=200, is_random=False)
|
||||||
# print(results)
|
|
||||||
# result = results[888].get_eval('ori-test')
|
|
||||||
res = api.get_more_info(index, 'cifar10', None, hp=200, is_random=False)
|
|
||||||
acc = res['test-accuracy']
|
acc = res['test-accuracy']
|
||||||
scores.append((index, acc))
|
scores.append((index, acc))
|
||||||
if acc < 90:
|
|
||||||
dist['<90'] += 1
|
# 根据阈值更新分布
|
||||||
elif acc < 91 and acc >= 90:
|
updated = False
|
||||||
dist['<91'] += 1
|
for threshold in thresholds[datasets_name]:
|
||||||
elif acc < 92 and acc >= 91:
|
if acc < threshold:
|
||||||
dist['<92'] += 1
|
dist[f'<{threshold}'] += 1
|
||||||
elif acc < 93 and acc >= 92:
|
updated = True
|
||||||
dist['<93'] += 1
|
break
|
||||||
elif acc < 94 and acc >= 93:
|
if not updated:
|
||||||
dist['<94'] += 1
|
dist[f'>{thresholds[datasets_name][-1]}'] += 1
|
||||||
else:
|
|
||||||
dist['>94'] += 1
|
|
||||||
else:
|
else:
|
||||||
not_valid += 1
|
not_valid += 1
|
||||||
with open('./output_graphs/' + filename + '.json', 'w') as f:
|
|
||||||
|
with open('./output_graphs/' + filename + '_' + datasets_name +'.json', 'w') as f:
|
||||||
json.dump(scores, f)
|
json.dump(scores, f)
|
||||||
|
|
||||||
print(scores)
|
print(scores)
|
||||||
print(valid, not_valid)
|
print(valid, not_valid)
|
||||||
print(dist)
|
print(dist)
|
||||||
print("mean: ", np.mean([x[1] for x in scores]))
|
print("mean: ", np.mean([x[1] for x in scores]))
|
||||||
print("max: ", np.max([x[1] for x in scores]))
|
print("max: ", np.max([x[1] for x in scores]))
|
||||||
print("min: ", np.min([x[1] for x in scores]))
|
print("min: ", np.min([x[1] for x in scores]))
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user