swap-nas/analyze.py

64 lines
1.8 KiB
Python
Raw Normal View History

import csv
import matplotlib.pyplot as plt
from scipy import stats
import pandas as pd
2024-08-29 09:20:29 +02:00
import argparse
2024-08-29 09:20:29 +02:00
def plot(l, filename, total=15625):
    """Save a bar chart of the swap-score histogram to *filename*.

    Args:
        l: Per-bin counts, where bin ``k`` counts scores in
            ``[k * 10000, (k + 1) * 10000)``. Only the first seven bins
            (0-70k) are drawn; any further bins are ignored, and missing
            bins are drawn as zero.
        filename: Output image path. The dataset name shown in the title is
            the text after the last ``_`` and before the following ``.``.
        total: Divisor that turns counts into fractions. Defaults to 15625,
            the value this script always used (presumably the size of the
            search space -- confirm before using with another benchmark).
    """
    labels = ['0-10k', '10k-20k', '20k-30k', '30k-40k', '40k-50k', '50k-60k', '60k-70k']
    # Exactly one height per label: truncate extra bins, pad missing ones with 0.
    counts = list(l[:len(labels)]) + [0] * (len(labels) - len(l))
    fractions = [c / total for c in counts]
    dataset = filename.split('_')[-1].split('.')[0]

    plt.figure(figsize=(8, 6))
    try:
        plt.subplots_adjust(top=0.85)
        # Keep the usual 0.3 ceiling, but grow it so tall bars and their
        # value labels are not clipped.
        plt.ylim(0, max(0.3, max(fractions) + 0.05))
        plt.title('Distribution of Swap Scores in ' + dataset)
        plt.bar(labels, fractions)
        for i, v in enumerate(fractions):
            plt.text(i, v + 0.01, str(round(v, 2)), ha='center', va='bottom')
        plt.savefig(filename)
    finally:
        # Release the figure; otherwise every call leaves an open figure behind.
        plt.close()
def _read_rows(filename):
    """Return ``(swap_score, valid_acc, index)`` tuples parsed from a CSV file.

    The first row is skipped as a header and blank lines are ignored. The
    first two columns are parsed as floats; the third is kept as text.
    """
    rows = []
    # newline='' is what the csv module requires for correct line handling.
    with open(filename, newline='') as file:
        reader = csv.reader(file)
        next(reader, None)  # skip header (tolerates an empty file)
        for row in reader:
            if not row:
                continue
            rows.append((float(row[0]), float(row[1]), row[2]))
    return rows


def _bin_counts(scores, width=10000, min_bins=10):
    """Count scores per bin, where bin ``k`` covers ``[k*width, (k+1)*width)``.

    At least ``min_bins`` bins are returned; the list grows as needed so that
    large scores never overflow it.

    Raises:
        ValueError: if a score is negative (it has no bin).
    """
    counts = [0] * min_bins
    for score in scores:
        if score < 0:
            raise ValueError(f'negative swap score {score!r} cannot be binned')
        k = int(score // width)
        if k >= len(counts):
            counts.extend([0] * (k + 1 - len(counts)))
        counts[k] += 1
    return counts


def analyse(filename):
    """Summarise a swap-score results CSV and return its Spearman correlation.

    Reads rows whose first three columns are the swap score, the validation
    accuracy and an index. It prints the highest swap score and saves a
    histogram of the scores to ``filename + '.png'`` via ``plot``.

    Returns:
        Spearman rank correlation between swap score and validation accuracy.

    Raises:
        ValueError: if the file has no data rows, a score/accuracy is not
            numeric, or a score is negative.
    """
    rows = _read_rows(filename)
    if not rows:
        raise ValueError(f'{filename} contains no data rows')
    # Build the frame from every row; previously only the characters of the
    # last row's strings were used.
    results = pd.DataFrame(rows, columns=['swap_score', 'valid_acc', 'index'])
    print(results['swap_score'].max())
    plot(_bin_counts(results['swap_score']), filename + '.png')
    return stats.spearmanr(results.swap_score, results.valid_acc)[0]
if __name__ == '__main__':
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--filename', type=str, help='Filename to analyze', default='swap_results.csv')
    # New, optional: the directory was previously hard-coded as 'output'.
    parser.add_argument('--output-dir', type=str, default='output',
                        help='Directory containing the results file')
    args = parser.parse_args()
    print(analyse(os.path.join(args.output_dir, args.filename)))