fix the index bug in the output

This commit is contained in:
Mhrooz 2024-08-26 10:58:20 +02:00
parent f72990a675
commit aa4b38a0cc

View File

@ -59,15 +59,15 @@ if __name__ == "__main__":
nasbench_len = 15625
# for index, i in arch_info.iterrows():
for i in range(nasbench_len):
for ind in range(nasbench_len):
# print(f'Evaluating network: {index}')
print(f'Evaluating network: {i}')
print(f'Evaluating network: {ind}')
config = api.get_net_config(i, 'cifar10')
config = api.get_net_config(ind, 'cifar10')
network = get_cell_based_tiny_net(config)
# nas_results = api.query_by_index(i, 'cifar10')
# acc = nas_results[111].get_eval('ori-test')
nas_results = api.get_more_info(i, 'cifar10', None, hp=200, is_random=False)
nas_results = api.get_more_info(ind, 'cifar10', None, hp=200, is_random=False)
acc = nas_results['test-accuracy']
# print(type(network))
@ -96,7 +96,7 @@ if __name__ == "__main__":
print(f'Average SWAP score: {np.mean(swap_score)}')
print(f'Elapsed time: {end_time - start_time:.2f} seconds')
results.append([np.mean(swap_score), acc, i])
results.append([np.mean(swap_score), acc, ind])
results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index'])
results.to_csv('output/swap_results.csv', float_format='%.4f', index=False)