Compare commits
3 Commits
aead4df707
...
aa4b38a0cc
Author | SHA1 | Date | |
---|---|---|---|
aa4b38a0cc | |||
f72990a675 | |||
ff85bba9cd |
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,2 +1,3 @@
|
|||||||
__pycache__/
|
__pycache__/
|
||||||
datasets/
|
datasets/
|
||||||
|
swap_results.csv
|
@ -55,19 +55,22 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
|
# nasbench_len = 15625
|
||||||
nasbench_len = 15625
|
nasbench_len = 15625
|
||||||
|
|
||||||
# for index, i in arch_info.iterrows():
|
# for index, i in arch_info.iterrows():
|
||||||
for i in range(nasbench_len):
|
for ind in range(nasbench_len):
|
||||||
# print(f'Evaluating network: {index}')
|
# print(f'Evaluating network: {index}')
|
||||||
print(f'Evaluating network: {i}')
|
print(f'Evaluating network: {ind}')
|
||||||
|
|
||||||
config = api.get_net_config(i, 'cifar10')
|
config = api.get_net_config(ind, 'cifar10')
|
||||||
network = get_cell_based_tiny_net(config)
|
network = get_cell_based_tiny_net(config)
|
||||||
nas_results = api.query_by_index(i, 'cifar10')
|
# nas_results = api.query_by_index(i, 'cifar10')
|
||||||
acc = nas_results[111].get_eval('ori-test')
|
# acc = nas_results[111].get_eval('ori-test')
|
||||||
|
nas_results = api.get_more_info(ind, 'cifar10', None, hp=200, is_random=False)
|
||||||
|
acc = nas_results['test-accuracy']
|
||||||
|
|
||||||
print(type(network))
|
# print(type(network))
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# network = Network(3, 10, 1, eval(i.genotype))
|
# network = Network(3, 10, 1, eval(i.genotype))
|
||||||
@ -93,13 +96,13 @@ if __name__ == "__main__":
|
|||||||
print(f'Average SWAP score: {np.mean(swap_score)}')
|
print(f'Average SWAP score: {np.mean(swap_score)}')
|
||||||
print(f'Elapsed time: {end_time - start_time:.2f} seconds')
|
print(f'Elapsed time: {end_time - start_time:.2f} seconds')
|
||||||
|
|
||||||
results.append([np.mean(swap_score), acc, i])
|
results.append([np.mean(swap_score), acc, ind])
|
||||||
|
|
||||||
results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index'])
|
results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index'])
|
||||||
|
results.to_csv('output/swap_results.csv', float_format='%.4f', index=False)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print(f'Spearman\'s Correlation Coefficient: {stats.spearmanr(results.swap_score, results.valid_acc)[0]}')
|
print(f'Spearman\'s Correlation Coefficient: {stats.spearmanr(results.swap_score, results.valid_acc)[0]}')
|
||||||
results.to_csv('swap_results.csv', float_format='%.4f', index=False)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user