update the naswot code

mhz 2024-08-14 16:24:13 +02:00
parent 4552f55b4d
commit b36ecd3ad0
5 changed files with 104 additions and 98 deletions

File diff suppressed because one or more lines are too long


@@ -49,7 +49,7 @@ def get_data(dataset, data_loc, trainval, batch_size, augtype, repeat, args, pin
     val_acc_type = 'x-valid'
     if trainval and 'cifar10' in dataset:
-        cifar_split = load_config('config_utils/cifar-split.txt', None, None)
+        cifar_split = load_config('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/naswot/naswot/config_utils/cifardata/cifar-split.txt', None, None)
         train_split, valid_split = cifar_split.train, cifar_split.valid
         if repeat > 0:
             train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
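For context, cifar-split.txt supplies the train/valid index lists for this loader; the DataLoader call in the hunk is truncated, so the sampler wiring below is an assumption shown with dummy data rather than the file's actual contents:

# Sketch with placeholder data; in the diff the indices come from
# load_config(...).train / .valid and train_data is the CIFAR-10 train set.
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

train_data = TensorDataset(torch.randn(1000, 3, 32, 32), torch.randint(0, 10, (1000,)))
train_split = list(range(0, 800))      # placeholder indices
valid_split = list(range(800, 1000))   # placeholder indices

train_loader = DataLoader(train_data, batch_size=128, sampler=SubsetRandomSampler(train_split))
valid_loader = DataLoader(train_data, batch_size=128, sampler=SubsetRandomSampler(valid_split))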

@@ -0,0 +1 @@
+Subproject commit b94247037ee470418a3e56dcb83814e9be83f3a8


@@ -12,41 +12,41 @@ import time
 from naswot.models import get_cell_based_tiny_net
 from naswot.utils import add_dropout, init_network
-parser = argparse.ArgumentParser(description='NAS Without Training')
-parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder')
-parser.add_argument('--api_loc', default='../NAS-Bench-201-v1_0-e61699.pth',
-                    type=str, help='path to API')
-parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
-parser.add_argument('--save_string', default='naswot', type=str, help='prefix of results file')
-parser.add_argument('--score', default='hook_logdet', type=str, help='the score to evaluate')
-parser.add_argument('--nasspace', default='nasbench201', type=str, help='the nas search space to use')
-parser.add_argument('--batch_size', default=128, type=int)
-parser.add_argument('--repeat', default=1, type=int, help='how often to repeat a single image with a batch')
-parser.add_argument('--augtype', default='none', type=str, help='which perturbations to use')
-parser.add_argument('--sigma', default=0.05, type=float, help='noise level if augtype is "gaussnoise"')
-parser.add_argument('--GPU', default='0', type=str)
-parser.add_argument('--seed', default=1, type=int)
-parser.add_argument('--init', default='', type=str)
-parser.add_argument('--trainval', action='store_true')
-parser.add_argument('--dropout', action='store_true')
-parser.add_argument('--dataset', default='cifar10', type=str)
-parser.add_argument('--maxofn', default=1, type=int, help='score is the max of this many evaluations of the network')
-parser.add_argument('--n_samples', default=100, type=int)
-parser.add_argument('--n_runs', default=500, type=int)
-parser.add_argument('--stem_out_channels', default=16, type=int, help='output channels of stem convolution (nasbench101)')
-parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules (nasbench101)')
-parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack (nasbench101)')
-parser.add_argument('--num_labels', default=1, type=int, help='#classes (nasbench101)')
-args = parser.parse_args()
-os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU
-# Reproducibility
-torch.backends.cudnn.deterministic = True
-torch.backends.cudnn.benchmark = False
-random.seed(args.seed)
-np.random.seed(args.seed)
-torch.manual_seed(args.seed)
+# parser = argparse.ArgumentParser(description='NAS Without Training')
+# parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder')
+# parser.add_argument('--api_loc', default='../NAS-Bench-201-v1_0-e61699.pth',
+#                     type=str, help='path to API')
+# parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
+# parser.add_argument('--save_string', default='naswot', type=str, help='prefix of results file')
+# parser.add_argument('--score', default='hook_logdet', type=str, help='the score to evaluate')
+# parser.add_argument('--nasspace', default='nasbench201', type=str, help='the nas search space to use')
+# parser.add_argument('--batch_size', default=128, type=int)
+# parser.add_argument('--repeat', default=1, type=int, help='how often to repeat a single image with a batch')
+# parser.add_argument('--augtype', default='none', type=str, help='which perturbations to use')
+# parser.add_argument('--sigma', default=0.05, type=float, help='noise level if augtype is "gaussnoise"')
+# parser.add_argument('--GPU', default='0', type=str)
+# parser.add_argument('--seed', default=1, type=int)
+# parser.add_argument('--init', default='', type=str)
+# parser.add_argument('--trainval', action='store_true')
+# parser.add_argument('--dropout', action='store_true')
+# parser.add_argument('--dataset', default='cifar10', type=str)
+# parser.add_argument('--maxofn', default=1, type=int, help='score is the max of this many evaluations of the network')
+# parser.add_argument('--n_samples', default=100, type=int)
+# parser.add_argument('--n_runs', default=500, type=int)
+# parser.add_argument('--stem_out_channels', default=16, type=int, help='output channels of stem convolution (nasbench101)')
+# parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules (nasbench101)')
+# parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack (nasbench101)')
+# parser.add_argument('--num_labels', default=1, type=int, help='#classes (nasbench101)')
+# args = parser.parse_args()
+# os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU
+# # Reproducibility
+# torch.backends.cudnn.deterministic = True
+# torch.backends.cudnn.benchmark = False
+# random.seed(args.seed)
+# np.random.seed(args.seed)
+# torch.manual_seed(args.seed)
 
 def get_batch_jacobian(net, x, target, device, args=None):
@@ -58,10 +58,16 @@ def get_batch_jacobian(net, x, target, device, args=None):
     return jacob, target.detach(), y.detach(), out.detach()
 
 def get_config_by_nodes(nodes):
-    num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
-    arch_str = '|' + num_to_op[nodes[1]] + '~0|+|' + \
-        num_to_op[nodes[2]] + '~0|' + num_to_op[nodes[3]] + '~1|+|' + \
-        num_to_op[nodes[4]] + '~0|' + num_to_op[nodes[5]] + '~1|' + num_to_op[nodes[6]] + '~2|'
+    # check if the nodes[0] is a number
+    if isinstance(nodes[0], int):
+        num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
+        arch_str = '|' + num_to_op[nodes[1]] + '~0|+|' + \
+            num_to_op[nodes[2]] + '~0|' + num_to_op[nodes[3]] + '~1|+|' + \
+            num_to_op[nodes[4]] + '~0|' + num_to_op[nodes[5]] + '~1|' + num_to_op[nodes[6]] + '~2|'
+    else:
+        arch_str = '|' + nodes[1] + '~0|+|' + \
+            nodes[2] + '~0|' + nodes[3] + '~1|+|' + \
+            nodes[4] + '~0|' + nodes[5] + '~1|' + nodes[6] + '~2|'
     config = {
         'name': 'infer.tiny',
         'C': 16,
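For reference, the change above makes get_config_by_nodes accept either integer op indices or op-name strings for the cell nodes. A minimal standalone sketch of just the arch-string logic, with the diff's if/else folded into one comprehension; only the num_to_op ordering is taken from the hunk, everything else is illustrative:

# Sketch of the arch-string construction; not the full function from the diff.
def nodes_to_arch_str(nodes):
    num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
    # Accept either integer indices into num_to_op or op-name strings.
    ops = [num_to_op[n] if isinstance(n, int) else n for n in nodes]
    return ('|' + ops[1] + '~0|+|'
            + ops[2] + '~0|' + ops[3] + '~1|+|'
            + ops[4] + '~0|' + ops[5] + '~1|' + ops[6] + '~2|')

# Both calls print the same string for the cell used in the commented-out driver below:
# |nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|
print(nodes_to_arch_str([0, 2, 2, 3, 4, 2, 4, 6]))
print(nodes_to_arch_str(['input', 'nor_conv_3x3', 'nor_conv_3x3', 'avg_pool_3x3',
                         'skip_connect', 'nor_conv_3x3', 'skip_connect', 'output']))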
@@ -234,64 +240,64 @@ def get_nasbench201_idx_score(idx, train_loader, searchspace, args, device)
     print('final result')
     return np.nan
 
-class Args:
-    pass
-args = Args()
-args.trainval = True
-args.augtype = 'none'
-args.repeat = 1
-args.score = 'hook_logdet'
-args.sigma = 0.05
-args.nasspace = 'nasbench201'
-args.batch_size = 128
-args.GPU = '0'
-args.dataset = 'cifar10-valid'
-args.api_loc = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
-args.data_loc = '../cifardata/'
-args.seed = 777
-args.init = ''
-args.save_loc = 'results'
-args.save_string = 'naswot'
-args.dropout = False
-args.maxofn = 1
-args.n_samples = 100
-args.n_runs = 500
-args.stem_out_channels = 16
-args.num_stacks = 3
-args.num_modules_per_stack = 3
-args.num_labels = 1
-if 'valid' in args.dataset:
-    args.dataset = args.dataset.replace('-valid', '')
-print('start to get search space')
-start_time = time.time()
-print(get_config_by_nodes(nodes=[0,2,2,3,4,2,4,6]))
-end_time = time.time()
-start_time = time.time()
-searchspace = nasspace.get_search_space(args)
-end_time = time.time()
-print(f'search space time: {end_time - start_time}')
-train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
-print('start to get score')
-print('5374')
-num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
-start_time = time.time()
-print(get_nasbench201_nodes_score(nodes=[0,2,2,3,4,2,4,6],train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
-end_time = time.time()
-start_time = time.time()
-print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
-end_time = time.time()
-print(f'5374 time: {end_time - start_time}')
-print('5375')
-start_time = time.time()
-print(get_nasbench201_idx_score(5375,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
-end_time = time.time()
-print(f'5375 time: {end_time - start_time}')
-print('5376')
-start_time = time.time()
-print(get_nasbench201_idx_score(5376,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
-end_time = time.time()
-print(f'5376 time: {end_time - start_time}')
+# class Args:
+#     pass
+# args = Args()
+# args.trainval = True
+# args.augtype = 'none'
+# args.repeat = 1
+# args.score = 'hook_logdet'
+# args.sigma = 0.05
+# args.nasspace = 'nasbench201'
+# args.batch_size = 128
+# args.GPU = '0'
+# args.dataset = 'cifar10-valid'
+# args.api_loc = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+# args.data_loc = '../cifardata/'
+# args.seed = 777
+# args.init = ''
+# args.save_loc = 'results'
+# args.save_string = 'naswot'
+# args.dropout = False
+# args.maxofn = 1
+# args.n_samples = 100
+# args.n_runs = 500
+# args.stem_out_channels = 16
+# args.num_stacks = 3
+# args.num_modules_per_stack = 3
+# args.num_labels = 1
+# if 'valid' in args.dataset:
+#     args.dataset = args.dataset.replace('-valid', '')
+# print('start to get search space')
+# start_time = time.time()
+# print(get_config_by_nodes(nodes=[0,2,2,3,4,2,4,6]))
+# end_time = time.time()
+# start_time = time.time()
+# searchspace = nasspace.get_search_space(args)
+# end_time = time.time()
+# print(f'search space time: {end_time - start_time}')
+# train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
+# print('start to get score')
+# print('5374')
+# num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
+# start_time = time.time()
+# print(get_nasbench201_nodes_score(nodes=[0,2,2,3,4,2,4,6],train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
+# end_time = time.time()
+# start_time = time.time()
+# print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
+# end_time = time.time()
+# print(f'5374 time: {end_time - start_time}')
+# print('5375')
+# start_time = time.time()
+# print(get_nasbench201_idx_score(5375,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
+# end_time = time.time()
+# print(f'5375 time: {end_time - start_time}')
+# print('5376')
+# start_time = time.time()
+# print(get_nasbench201_idx_score(5376,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
+# end_time = time.time()
+# print(f'5376 time: {end_time - start_time}')
 # device = "cuda:0"
 # dataset = dataset
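With the module-level driver commented out, callers now build the args object and loaders themselves before scoring an architecture. A rough sketch of that call sequence, lifted from the commented-out block above; the import paths and the score_networks module name are assumptions, and the argument values are just the ones the old driver used:

# Sketch of driving the scorer as a library; not part of this commit.
from types import SimpleNamespace

import torch
from naswot import datasets, nasspace                          # assumed import path
from naswot.score_networks import get_nasbench201_nodes_score  # assumed module name

args = SimpleNamespace(
    trainval=True, augtype='none', repeat=1, score='hook_logdet', sigma=0.05,
    nasspace='nasbench201', batch_size=128, GPU='0', dataset='cifar10',
    api_loc='NAS-Bench-201-v1_1-096897.pth', data_loc='../cifardata/',
    seed=777, init='', save_loc='results', save_string='naswot', dropout=False,
    maxofn=1, n_samples=100, n_runs=500, stem_out_channels=16,
    num_stacks=3, num_modules_per_stack=3, num_labels=1,
)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
searchspace = nasspace.get_search_space(args)
train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval,
                                 args.batch_size, args.augtype, args.repeat, args)
# Score one NAS-Bench-201 cell, encoded as node op indices as in the driver above.
score = get_nasbench201_nodes_score(nodes=[0, 2, 2, 3, 4, 2, 4, 6],
                                    train_loader=train_loader,
                                    searchspace=searchspace, args=args, device=device)
print(score)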


@@ -3,5 +3,8 @@ from setuptools import setup, find_packages
 setup(
     name='naswot',
     version='0.1',
-    packages=find_packages()
+    packages=find_packages(),
+    package_data={
+        'naswot': ['config_utils/cifardata/*']
+    }
 )
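One note on the setup.py change: package_data only bundles files that live inside the naswot package itself, so this assumes cifar-split.txt and the other split files sit under naswot/config_utils/cifardata/. A quick way to check that an installed copy actually ships the data; the importlib.resources usage is illustrative, not part of this commit:

# Sketch: confirm the bundled split file is reachable from the installed package.
from importlib import resources

split_file = resources.files('naswot').joinpath('config_utils/cifardata/cifar-split.txt')
print(split_file.is_file())        # True only if the package_data glob matched it
with split_file.open('r') as f:    # read it the same way load_config eventually would
    print(f.readline().strip())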