add a datsets option to specify the datset you want, add a plot script
This commit is contained in:
parent
aa4b38a0cc
commit
551abc31f3
48
analyze.py
Normal file
48
analyze.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
import csv
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from scipy import stats
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
def plot(l):
|
||||||
|
labels = ['0-10k', '10k-20k,', '20k-30k', '30k-40k', '40k-50k', '50k-60k', '60k-70k']
|
||||||
|
l = [i/15625 for i in l]
|
||||||
|
l = l[:7]
|
||||||
|
plt.bar(labels, l)
|
||||||
|
plt.savefig('plot.png')
|
||||||
|
|
||||||
|
def analyse(filename):
|
||||||
|
l = [0 for i in range(10)]
|
||||||
|
scores = []
|
||||||
|
count = 0
|
||||||
|
best_value = -1
|
||||||
|
with open(filename) as file:
|
||||||
|
reader = csv.reader(file)
|
||||||
|
header = next(reader)
|
||||||
|
data = [row for row in reader]
|
||||||
|
|
||||||
|
for row in data:
|
||||||
|
score = row[0]
|
||||||
|
best_value = max(best_value, float(score))
|
||||||
|
# print(score)
|
||||||
|
ind = float(score) // 10000
|
||||||
|
ind = int(ind)
|
||||||
|
l[ind] += 1
|
||||||
|
acc = row[1]
|
||||||
|
index = row[2]
|
||||||
|
datas = list(zip(score, acc, index))
|
||||||
|
scores.append(score)
|
||||||
|
print(max(scores))
|
||||||
|
results = pd.DataFrame(datas, columns=['swap_score', 'valid_acc', 'index'])
|
||||||
|
print(results['swap_score'].max())
|
||||||
|
print(best_value)
|
||||||
|
plot(l)
|
||||||
|
return stats.spearmanr(results.swap_score, results.valid_acc)[0]
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
print(analyse('output/swap_results.csv'))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -39,6 +39,7 @@ parser.add_argument('--seed', default=0, type=int, help='random seed')
|
|||||||
parser.add_argument('--device', default="cuda", type=str, nargs='?', help='setup device (cpu, mps or cuda)')
|
parser.add_argument('--device', default="cuda", type=str, nargs='?', help='setup device (cpu, mps or cuda)')
|
||||||
parser.add_argument('--repeats', default=32, type=int, nargs='?', help='times of calculating the training-free metric')
|
parser.add_argument('--repeats', default=32, type=int, nargs='?', help='times of calculating the training-free metric')
|
||||||
parser.add_argument('--input_samples', default=16, type=int, nargs='?', help='input batch size for training-free metric')
|
parser.add_argument('--input_samples', default=16, type=int, nargs='?', help='input batch size for training-free metric')
|
||||||
|
parser.add_argument('--datasets', default='cifar10', type=str, help='input datasets')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -48,7 +49,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# arch_info = pd.read_csv(args.data_path+'/DARTS_archs_CIFAR10.csv', names=['genotype', 'valid_acc'], sep=',')
|
# arch_info = pd.read_csv(args.data_path+'/DARTS_archs_CIFAR10.csv', names=['genotype', 'valid_acc'], sep=',')
|
||||||
|
|
||||||
train_data, _, _ = get_datasets('cifar10', args.data_path, (args.input_samples, 3, 32, 32), -1)
|
train_data, _, _ = get_datasets(args.datasets, args.data_path, (args.input_samples, 3, 32, 32), -1)
|
||||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.input_samples, num_workers=0, pin_memory=True)
|
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.input_samples, num_workers=0, pin_memory=True)
|
||||||
loader = iter(train_loader)
|
loader = iter(train_loader)
|
||||||
inputs, _ = next(loader)
|
inputs, _ = next(loader)
|
||||||
@ -63,11 +64,11 @@ if __name__ == "__main__":
|
|||||||
# print(f'Evaluating network: {index}')
|
# print(f'Evaluating network: {index}')
|
||||||
print(f'Evaluating network: {ind}')
|
print(f'Evaluating network: {ind}')
|
||||||
|
|
||||||
config = api.get_net_config(ind, 'cifar10')
|
config = api.get_net_config(ind, args.datasets)
|
||||||
network = get_cell_based_tiny_net(config)
|
network = get_cell_based_tiny_net(config)
|
||||||
# nas_results = api.query_by_index(i, 'cifar10')
|
# nas_results = api.query_by_index(i, 'cifar10')
|
||||||
# acc = nas_results[111].get_eval('ori-test')
|
# acc = nas_results[111].get_eval('ori-test')
|
||||||
nas_results = api.get_more_info(ind, 'cifar10', None, hp=200, is_random=False)
|
nas_results = api.get_more_info(ind, args.datasets, None, hp=200, is_random=False)
|
||||||
acc = nas_results['test-accuracy']
|
acc = nas_results['test-accuracy']
|
||||||
|
|
||||||
# print(type(network))
|
# print(type(network))
|
||||||
|
Loading…
Reference in New Issue
Block a user