# Evaluate the SWAP training-free metric on NAS-Bench-201 architectures
# and record per-network scores to CSV.
# Enable CPU fallback for PyTorch ops not implemented on Apple MPS.
# NOTE: must be set in the environment *before* torch is imported.
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

# Standard library
import argparse
import time

# Third-party
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from scipy import stats

# Project modules
from src.utils.utilities import *
from src.metrics.swap import SWAP
from src.datasets.utilities import get_datasets
from src.search_space.networks import *

# NAS-Bench-201 query API
from nas_201_api import NASBench201API as API

# xautodl: builds a trainable network from a NAS-Bench-201 cell config
from xautodl.models import get_cell_based_tiny_net
# --- Initialize NAS-Bench-201 --------------------------------------------
# Loading the benchmark file is slow, so report how long it took.
nas_201_path = 'datasets/NAS-Bench-201-v1_1-096897.pth'
print(f'Loading NAS-Bench-201 from {nas_201_path}')
t_load = time.time()
api = API(nas_201_path)
print(f'Loaded NAS-Bench-201 in {time.time() - t_load:.2f} seconds')

# Silence noisy third-party deprecation/user warnings on the console.
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
# --- Command-line arguments ----------------------------------------------
parser = argparse.ArgumentParser()

# general settings
parser.add_argument('--data_path', default="datasets", type=str, nargs='?',
                    help='path to the image dataset (datasets or datasets/ILSVRC/Data/CLS-LOC)')
parser.add_argument('--seed', default=0, type=int,
                    help='random seed')
parser.add_argument('--device', default="cuda", type=str, nargs='?',
                    help='setup device (cpu, mps or cuda)')
parser.add_argument('--repeats', default=32, type=int, nargs='?',
                    help='times of calculating the training-free metric')
parser.add_argument('--input_samples', default=16, type=int, nargs='?',
                    help='input batch size for training-free metric')
parser.add_argument('--datasets', default='cifar10', type=str,
                    help='input datasets')
parser.add_argument('--start_index', default=0, type=int,
                    help='start index of the networks to evaluate')

args = parser.parse_args()
if __name__ == "__main__":

    device = torch.device(args.device)

    # Build ONE fixed batch of inputs; SWAP is evaluated on the same batch
    # for every architecture so scores are comparable across networks.
    train_data, _, _ = get_datasets(args.datasets, args.data_path,
                                    (args.input_samples, 3, 32, 32), -1)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.input_samples,
                                               num_workers=0, pin_memory=True)
    inputs, _ = next(iter(train_loader))

    results = []

    # NAS-Bench-201 contains 15,625 architectures in total.
    nasbench_len = 15625

    # Per-network results are appended incrementally so a long run can be
    # resumed with --start_index after an interruption.
    os.makedirs('output', exist_ok=True)  # ensure output dir exists before appending
    filename = f'output/swap_results_{args.datasets}.csv'

    # NAS-Bench-201 has no 'aircraft' entry, so fall back to the cifar10
    # topology when querying network configs; otherwise query the requested
    # dataset directly.
    # BUG FIX: api_datasets was previously assigned only in the 'aircraft'
    # branch, raising NameError for every other dataset (incl. the default).
    api_datasets = 'cifar10' if args.datasets == 'aircraft' else args.datasets

    for ind in range(args.start_index, nasbench_len):
        print(f'Evaluating network: {ind}')

        config = api.get_net_config(ind, api_datasets)
        network = get_cell_based_tiny_net(config)

        # Placeholder accuracy: the real benchmark query is disabled, so the
        # final Spearman correlation is meaningless until it is re-enabled.
        # nas_results = api.get_more_info(ind, api_datasets, None, hp=200, is_random=False)
        # acc = nas_results['test-accuracy']
        acc = 99

        start_time = time.time()
        network = network.to(device)
        end_time = time.time()
        print(f'Loaded network in {end_time - start_time:.2f} seconds')

        print('Initializing SWAP')
        swap = SWAP(model=network, inputs=inputs, device=device, seed=args.seed)

        swap_score = []

        print('Calculating SWAP score')
        start_time = time.time()
        # Average over several random weight re-initializations to reduce
        # the variance of the training-free score.
        for rep in range(args.repeats):
            print(f'Iteration: {rep+1}/{args.repeats}', end='\r')
            network = network.apply(network_weight_gaussian_init)
            swap.reinit()
            swap_score.append(swap.forward())
            swap.clear()
        end_time = time.time()

        mean_score = np.mean(swap_score)
        print(f'Average SWAP score: {mean_score}')
        print(f'Elapsed time: {end_time - start_time:.2f} seconds')

        results.append([mean_score, acc, ind])
        # Append to disk immediately so progress survives a crash/kill.
        with open(filename, 'a') as f:
            f.write(f'{mean_score},{acc},{ind}\n')

    results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index'])
    results.to_csv('output/swap_results.csv', float_format='%.4f', index=False)

    print()
    # NOTE(review): valid_acc is the constant placeholder above, so this
    # prints nan until the accuracy query is re-enabled.
    print(f'Spearman\'s Correlation Coefficient: {stats.spearmanr(results.swap_score, results.valid_acc)[0]}')