try to pack the naswot

This commit is contained in:
mhz 2024-07-28 23:45:02 +02:00
parent 13f77963d0
commit fd4f0452f9
102 changed files with 112 additions and 13 deletions

View File

@ -11,7 +11,7 @@ __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_ci
] ]
# useful modules # useful modules
from config_utils import dict2config from naswot.config_utils import dict2config
from .SharedUtils import change_key from .SharedUtils import change_key
from .cell_searchs import CellStructure, CellArchitectures from .cell_searchs import CellStructure, CellArchitectures

View File

@ -1,16 +1,16 @@
from models import get_cell_based_tiny_net, get_search_spaces from naswot.models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API from nas_201_api import NASBench201API as API
from nasbench import api as nasbench101api from nasbench import api as nasbench101api
from nas_101_api.model import Network from naswot.nas_101_api.model import Network
from nas_101_api.model_spec import ModelSpec from naswot.nas_101_api.model_spec import ModelSpec
import itertools import itertools
import random import random
import numpy as np import numpy as np
from models.cell_searchs.genotypes import Structure from naswot.models.cell_searchs.genotypes import Structure
from copy import deepcopy from copy import deepcopy
from pycls.models.nas.nas import NetworkImageNet, NetworkCIFAR from naswot.pycls.models.nas.nas import NetworkImageNet, NetworkCIFAR
from pycls.models.anynet import AnyNet from naswot.pycls.models.anynet import AnyNet
from pycls.models.nas.genotypes import GENOTYPES, Genotype from naswot.pycls.models.nas.genotypes import GENOTYPES, Genotype
import json import json
import torch import torch
@ -26,6 +26,7 @@ class Nasbench201:
print(config) print(config)
config['num_classes'] = 1 config['num_classes'] = 1
network = get_cell_based_tiny_net(config) network = get_cell_based_tiny_net(config)
print(network)
return network return network
def __iter__(self): def __iter__(self):
for uid in range(len(self)): for uid in range(len(self)):

View File

@ -1,16 +1,16 @@
import argparse import argparse
import nasspace from naswot import nasspace
import datasets import datasets
import random import random
import numpy as np import numpy as np
import torch import torch
import os import os
from scores import get_score_func from naswot.scores import get_score_func
from scipy import stats from scipy import stats
import time import time
# from pycls.models.nas.nas import Cell # from pycls.models.nas.nas import Cell
from models import get_cell_based_tiny_net from naswot.models import get_cell_based_tiny_net
from utils import add_dropout, init_network from naswot.utils import add_dropout, init_network
parser = argparse.ArgumentParser(description='NAS Without Training') parser = argparse.ArgumentParser(description='NAS Without Training')
parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder') parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder')
@ -57,11 +57,95 @@ def get_batch_jacobian(net, x, target, device, args=None):
jacob = x.grad.detach() jacob = x.grad.detach()
return jacob, target.detach(), y.detach(), out.detach() return jacob, target.detach(), y.detach(), out.detach()
def get_nasbench201_nodes_score(nodes, train_loader, searchspace, args, device): def get_config_by_nodes(nodes):
num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
arch_str = '|' + num_to_op[nodes[1]] + '~0|+|' + \
num_to_op[nodes[2]] + '~0|' + num_to_op[nodes[3]] + '~1|+|' + \
num_to_op[nodes[4]] + '~0|' + num_to_op[nodes[5]] + '~1|' + num_to_op[nodes[6]] + '~2|'
config = {
'name': 'infer.tiny',
'C': 16,
'N': 5,
'arch_str': arch_str,
'num_classes': 10,
}
return config
def get_nasbench201_nodes_score(nodes, train_loader, searchspace, args, device):
assert len(nodes) == 8
network = get_cell_based_tiny_net(get_config_by_nodes(nodes))
try:
if args.dropout:
add_dropout(network, args.sigma)
if args.init != '':
init_network(network, args.init)
if 'hook_' in args.score:
network.K = np.zeros((args.batch_size, args.batch_size))
def counting_forward_hook(module, inp, out):
try:
if not module.visited_backwards:
return
if isinstance(inp, tuple):
# print(len(inp))
inp = inp[0]
inp = inp.view(inp.size(0), -1)
x = (inp > 0).float()
K = x @ x.t()
K2 = (1.-x) @ (1.-x.t())
network.K = network.K + K.cpu().numpy() + K2.cpu().numpy()
except:
pass
def counting_backward_hook(module, inp, out):
module.visited_backwards = True
for name, module in network.named_modules():
if 'ReLU' in str(type(module)):
#hooks[name] = module.register_forward_hook(counting_hook)
module.register_forward_hook(counting_forward_hook)
module.register_backward_hook(counting_backward_hook)
network = network.to(device)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
s = []
for j in range(args.maxofn):
data_iterator = iter(train_loader)
x, target = next(data_iterator)
x2 = torch.clone(x)
x2 = x2.to(device)
x, target = x.to(device), target.to(device)
jacobs, labels, y, out = get_batch_jacobian(network, x, target, device, args)
if 'hook_' in args.score:
network(x2.to(device))
s.append(get_score_func(args.score)(network.K, target))
else:
s.append(get_score_func(args.score)(jacobs, labels))
return np.mean(s)
scores[i] = np.mean(s)
accs[i] = searchspace.get_final_accuracy(uid, acc_type, args.trainval)
accs_ = accs[~np.isnan(scores)]
scores_ = scores[~np.isnan(scores)]
numnan = np.isnan(scores).sum()
tau, p = stats.kendalltau(accs_[:max(i-numnan, 1)], scores_[:max(i-numnan, 1)])
print(f'{tau}')
if i % 1000 == 0:
np.save(filename, scores)
np.save(accfilename, accs)
except Exception as e:
print(e)
print('final result')
return np.nan
def get_nasbench201_idx_score(idx, train_loader, searchspace, args, device): def get_nasbench201_idx_score(idx, train_loader, searchspace, args, device):
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# searchspace = nasspace.get_search_space(args) # searchspace = nasspace.get_search_space(args)
@ -181,12 +265,19 @@ if 'valid' in args.dataset:
args.dataset = args.dataset.replace('-valid', '') args.dataset = args.dataset.replace('-valid', '')
print('start to get search space') print('start to get search space')
start_time = time.time() start_time = time.time()
print(get_config_by_nodes(nodes=[0,2,2,3,4,2,4,6]))
end_time = time.time()
start_time = time.time()
searchspace = nasspace.get_search_space(args) searchspace = nasspace.get_search_space(args)
end_time = time.time() end_time = time.time()
print(f'search space time: {end_time - start_time}') print(f'search space time: {end_time - start_time}')
train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
print('start to get score') print('start to get score')
print('5374') print('5374')
num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
start_time = time.time()
print(get_nasbench201_nodes_score(nodes=[0,2,2,3,4,2,4,6],train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
end_time = time.time()
start_time = time.time() start_time = time.time()
print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
end_time = time.time() end_time = time.time()

Some files were not shown because too many files have changed in this diff Show More