diff --git a/graph_dit/test_perf.py b/graph_dit/test_perf.py index 43c0160..f4790c6 100644 --- a/graph_dit/test_perf.py +++ b/graph_dit/test_perf.py @@ -10,6 +10,7 @@ api = API('./NAS-Bench-201-v1_1-096897.pth') parser = argparse.ArgumentParser(description='Process some integers.') parser.add_argument('--file_path', type=str, default='211035.txt',) +parser.add_argument('--datasets', type=str, default='cifar10',) args = parser.parse_args() def process_graph_data(text): @@ -89,6 +90,7 @@ def nodes_to_arch_str(nodes): return arch_str filename = args.file_path +datasets_name = args.datasets with open('./output_graphs/' + filename, 'r') as f: texts = f.read() @@ -96,7 +98,15 @@ with open('./output_graphs/' + filename, 'r') as f: valid = 0 not_valid = 0 scores = [] - dist = {'<90':0, '<91':0, '<92':0, '<93':0, '<94':0, '>94':0} + + # 定义分类标准和分布字典的映射 + thresholds = { + 'cifar10': [90, 91, 92, 93, 94], + 'cifar100': [68,69,70, 71, 72, 73] + } + dist = {f'<{threshold}': 0 for threshold in thresholds[datasets_name]} + dist[f'>{thresholds[datasets_name][-1]}'] = 0 + for i in range(len(df)): nodes = df['nodes'][i] edges = df['edges'][i] @@ -105,32 +115,30 @@ with open('./output_graphs/' + filename, 'r') as f: valid += 1 arch_str = nodes_to_arch_str(nodes) index = api.query_index_by_arch(arch_str) - # results = api.query_by_index(index, 'cifar10', hp='200') - # print(results) - # result = results[888].get_eval('ori-test') - res = api.get_more_info(index, 'cifar10', None, hp=200, is_random=False) + res = api.get_more_info(index, datasets_name, None, hp=200, is_random=False) acc = res['test-accuracy'] scores.append((index, acc)) - if acc < 90: - dist['<90'] += 1 - elif acc < 91 and acc >= 90: - dist['<91'] += 1 - elif acc < 92 and acc >= 91: - dist['<92'] += 1 - elif acc < 93 and acc >= 92: - dist['<93'] += 1 - elif acc < 94 and acc >= 93: - dist['<94'] += 1 - else: - dist['>94'] += 1 + + # 根据阈值更新分布 + updated = False + for threshold in thresholds[datasets_name]: + if acc < threshold: + dist[f'<{threshold}'] += 1 + updated = True + break + if not updated: + dist[f'>{thresholds[datasets_name][-1]}'] += 1 else: not_valid += 1 - with open('./output_graphs/' + filename + '.json', 'w') as f: + + with open('./output_graphs/' + filename + '_' + datasets_name +'.json', 'w') as f: json.dump(scores, f) + print(scores) print(valid, not_valid) print(dist) print("mean: ", np.mean([x[1] for x in scores])) print("max: ", np.max([x[1] for x in scores])) print("min: ", np.min([x[1] for x in scores])) +