update the naswot code

This commit is contained in:
mhz 2024-08-14 16:24:13 +02:00
parent 4552f55b4d
commit b36ecd3ad0
5 changed files with 104 additions and 98 deletions

File diff suppressed because one or more lines are too long

View File

@ -49,7 +49,7 @@ def get_data(dataset, data_loc, trainval, batch_size, augtype, repeat, args, pin
val_acc_type = 'x-valid'
if trainval and 'cifar10' in dataset:
cifar_split = load_config('config_utils/cifar-split.txt', None, None)
cifar_split = load_config('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/naswot/naswot/config_utils/cifardata/cifar-split.txt', None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
if repeat > 0:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,

@ -0,0 +1 @@
Subproject commit b94247037ee470418a3e56dcb83814e9be83f3a8

View File

@ -12,41 +12,41 @@ import time
from naswot.models import get_cell_based_tiny_net
from naswot.utils import add_dropout, init_network
parser = argparse.ArgumentParser(description='NAS Without Training')
parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder')
parser.add_argument('--api_loc', default='../NAS-Bench-201-v1_0-e61699.pth',
type=str, help='path to API')
parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
parser.add_argument('--save_string', default='naswot', type=str, help='prefix of results file')
parser.add_argument('--score', default='hook_logdet', type=str, help='the score to evaluate')
parser.add_argument('--nasspace', default='nasbench201', type=str, help='the nas search space to use')
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--repeat', default=1, type=int, help='how often to repeat a single image with a batch')
parser.add_argument('--augtype', default='none', type=str, help='which perturbations to use')
parser.add_argument('--sigma', default=0.05, type=float, help='noise level if augtype is "gaussnoise"')
parser.add_argument('--GPU', default='0', type=str)
parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--init', default='', type=str)
parser.add_argument('--trainval', action='store_true')
parser.add_argument('--dropout', action='store_true')
parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--maxofn', default=1, type=int, help='score is the max of this many evaluations of the network')
parser.add_argument('--n_samples', default=100, type=int)
parser.add_argument('--n_runs', default=500, type=int)
parser.add_argument('--stem_out_channels', default=16, type=int, help='output channels of stem convolution (nasbench101)')
parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules (nasbench101)')
parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack (nasbench101)')
parser.add_argument('--num_labels', default=1, type=int, help='#classes (nasbench101)')
# parser = argparse.ArgumentParser(description='NAS Without Training')
# parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder')
# parser.add_argument('--api_loc', default='../NAS-Bench-201-v1_0-e61699.pth',
# type=str, help='path to API')
# parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
# parser.add_argument('--save_string', default='naswot', type=str, help='prefix of results file')
# parser.add_argument('--score', default='hook_logdet', type=str, help='the score to evaluate')
# parser.add_argument('--nasspace', default='nasbench201', type=str, help='the nas search space to use')
# parser.add_argument('--batch_size', default=128, type=int)
# parser.add_argument('--repeat', default=1, type=int, help='how often to repeat a single image with a batch')
# parser.add_argument('--augtype', default='none', type=str, help='which perturbations to use')
# parser.add_argument('--sigma', default=0.05, type=float, help='noise level if augtype is "gaussnoise"')
# parser.add_argument('--GPU', default='0', type=str)
# parser.add_argument('--seed', default=1, type=int)
# parser.add_argument('--init', default='', type=str)
# parser.add_argument('--trainval', action='store_true')
# parser.add_argument('--dropout', action='store_true')
# parser.add_argument('--dataset', default='cifar10', type=str)
# parser.add_argument('--maxofn', default=1, type=int, help='score is the max of this many evaluations of the network')
# parser.add_argument('--n_samples', default=100, type=int)
# parser.add_argument('--n_runs', default=500, type=int)
# parser.add_argument('--stem_out_channels', default=16, type=int, help='output channels of stem convolution (nasbench101)')
# parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules (nasbench101)')
# parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack (nasbench101)')
# parser.add_argument('--num_labels', default=1, type=int, help='#classes (nasbench101)')
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU
# args = parser.parse_args()
# os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU
# Reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# # Reproducibility
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# random.seed(args.seed)
# np.random.seed(args.seed)
# torch.manual_seed(args.seed)
def get_batch_jacobian(net, x, target, device, args=None):
@ -58,10 +58,16 @@ def get_batch_jacobian(net, x, target, device, args=None):
return jacob, target.detach(), y.detach(), out.detach()
def get_config_by_nodes(nodes):
# check if the nodes[0] is a number
if isinstance(nodes[0], int):
num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
arch_str = '|' + num_to_op[nodes[1]] + '~0|+|' + \
num_to_op[nodes[2]] + '~0|' + num_to_op[nodes[3]] + '~1|+|' + \
num_to_op[nodes[4]] + '~0|' + num_to_op[nodes[5]] + '~1|' + num_to_op[nodes[6]] + '~2|'
else:
arch_str = '|' + nodes[1] + '~0|+|' + \
nodes[2] + '~0|' + nodes[3] + '~1|+|' + \
nodes[4] + '~0|' + nodes[5] + '~1|' + nodes[6] + '~2|'
config = {
'name': 'infer.tiny',
'C': 16,
@ -234,64 +240,64 @@ def get_nasbench201_idx_score(idx, train_loader, searchspace, args, device):
print('final result')
return np.nan
class Args:
pass
args = Args()
args.trainval = True
args.augtype = 'none'
args.repeat = 1
args.score = 'hook_logdet'
args.sigma = 0.05
args.nasspace = 'nasbench201'
args.batch_size = 128
args.GPU = '0'
args.dataset = 'cifar10-valid'
args.api_loc = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
args.data_loc = '../cifardata/'
args.seed = 777
args.init = ''
args.save_loc = 'results'
args.save_string = 'naswot'
args.dropout = False
args.maxofn = 1
args.n_samples = 100
args.n_runs = 500
args.stem_out_channels = 16
args.num_stacks = 3
args.num_modules_per_stack = 3
args.num_labels = 1
# class Args:
# pass
# args = Args()
# args.trainval = True
# args.augtype = 'none'
# args.repeat = 1
# args.score = 'hook_logdet'
# args.sigma = 0.05
# args.nasspace = 'nasbench201'
# args.batch_size = 128
# args.GPU = '0'
# args.dataset = 'cifar10-valid'
# args.api_loc = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
# args.data_loc = '../cifardata/'
# args.seed = 777
# args.init = ''
# args.save_loc = 'results'
# args.save_string = 'naswot'
# args.dropout = False
# args.maxofn = 1
# args.n_samples = 100
# args.n_runs = 500
# args.stem_out_channels = 16
# args.num_stacks = 3
# args.num_modules_per_stack = 3
# args.num_labels = 1
if 'valid' in args.dataset:
args.dataset = args.dataset.replace('-valid', '')
print('start to get search space')
start_time = time.time()
print(get_config_by_nodes(nodes=[0,2,2,3,4,2,4,6]))
end_time = time.time()
start_time = time.time()
searchspace = nasspace.get_search_space(args)
end_time = time.time()
print(f'search space time: {end_time - start_time}')
train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
print('start to get score')
print('5374')
num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
start_time = time.time()
print(get_nasbench201_nodes_score(nodes=[0,2,2,3,4,2,4,6],train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
end_time = time.time()
start_time = time.time()
print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
end_time = time.time()
print(f'5374 time: {end_time - start_time}')
print('5375')
start_time = time.time()
print(get_nasbench201_idx_score(5375,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
end_time = time.time()
print(f'5375 time: {end_time - start_time}')
print('5376')
start_time = time.time()
print(get_nasbench201_idx_score(5376,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
end_time = time.time()
print(f'5376 time: {end_time - start_time}')
# if 'valid' in args.dataset:
# args.dataset = args.dataset.replace('-valid', '')
# print('start to get search space')
# start_time = time.time()
# print(get_config_by_nodes(nodes=[0,2,2,3,4,2,4,6]))
# end_time = time.time()
# start_time = time.time()
# searchspace = nasspace.get_search_space(args)
# end_time = time.time()
# print(f'search space time: {end_time - start_time}')
# train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
# print('start to get score')
# print('5374')
# num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
# start_time = time.time()
# print(get_nasbench201_nodes_score(nodes=[0,2,2,3,4,2,4,6],train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
# end_time = time.time()
# start_time = time.time()
# print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
# end_time = time.time()
# print(f'5374 time: {end_time - start_time}')
# print('5375')
# start_time = time.time()
# print(get_nasbench201_idx_score(5375,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
# end_time = time.time()
# print(f'5375 time: {end_time - start_time}')
# print('5376')
# start_time = time.time()
# print(get_nasbench201_idx_score(5376,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
# end_time = time.time()
# print(f'5376 time: {end_time - start_time}')
# device = "cuda:0"
# dataset = dataset

View File

@ -3,5 +3,8 @@ from setuptools import setup, find_packages
setup(
name='naswot',
version='0.1',
packages=find_packages()
packages=find_packages(),
package_data={
'naswot': ['config_utils/cifardata/*']
}
)