Compare commits

..

No commits in common. "aa4b38a0cc14359bbf6b892d0b023c00937862bb" and "aead4df707ee0bfb5d5c95c0d945e1d97ef0e8e1" have entirely different histories.

2 changed files with 9 additions and 13 deletions

3
.gitignore vendored
View File

@ -1,3 +1,2 @@
__pycache__/ __pycache__/
datasets/ datasets/
swap_results.csv

View File

@ -55,22 +55,19 @@ if __name__ == "__main__":
results = [] results = []
# nasbench_len = 15625
nasbench_len = 15625 nasbench_len = 15625
# for index, i in arch_info.iterrows(): # for index, i in arch_info.iterrows():
for ind in range(nasbench_len): for i in range(nasbench_len):
# print(f'Evaluating network: {index}') # print(f'Evaluating network: {index}')
print(f'Evaluating network: {ind}') print(f'Evaluating network: {i}')
config = api.get_net_config(ind, 'cifar10') config = api.get_net_config(i, 'cifar10')
network = get_cell_based_tiny_net(config) network = get_cell_based_tiny_net(config)
# nas_results = api.query_by_index(i, 'cifar10') nas_results = api.query_by_index(i, 'cifar10')
# acc = nas_results[111].get_eval('ori-test') acc = nas_results[111].get_eval('ori-test')
nas_results = api.get_more_info(ind, 'cifar10', None, hp=200, is_random=False)
acc = nas_results['test-accuracy']
# print(type(network)) print(type(network))
start_time = time.time() start_time = time.time()
# network = Network(3, 10, 1, eval(i.genotype)) # network = Network(3, 10, 1, eval(i.genotype))
@ -96,13 +93,13 @@ if __name__ == "__main__":
print(f'Average SWAP score: {np.mean(swap_score)}') print(f'Average SWAP score: {np.mean(swap_score)}')
print(f'Elapsed time: {end_time - start_time:.2f} seconds') print(f'Elapsed time: {end_time - start_time:.2f} seconds')
results.append([np.mean(swap_score), acc, ind]) results.append([np.mean(swap_score), acc, i])
results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index']) results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index'])
results.to_csv('output/swap_results.csv', float_format='%.4f', index=False)
print() print()
print(f'Spearman\'s Correlation Coefficient: {stats.spearmanr(results.swap_score, results.valid_acc)[0]}') print(f'Spearman\'s Correlation Coefficient: {stats.spearmanr(results.swap_score, results.valid_acc)[0]}')
results.to_csv('swap_results.csv', float_format='%.4f', index=False)