update code styles

This commit is contained in:
D-X-Y 2020-01-09 22:26:23 +11:00
parent 5ac5060a33
commit ad34af9913
26 changed files with 192 additions and 81 deletions

View File

@ -41,7 +41,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k GDAS_
## Performance on ImageNet ## Performance on ImageNet
| Model | FLOPs (GB) | Params (M) | Top-1 Error | Top-5 Error | Optimizer | | Model | FLOPs (GB) | Params (M) | Top-1 Error | Top-5 Error | Optimizer |
|:--------------:|:----------:|:----------:|:-----------:|:-----------:|:----------:| |:-----------------:|:----------:|:----------:|:-----------:|:-----------:|:----------:|
| ResNet-18 | 1.814 | 11.69 | 30.24 | 10.92 | Official | | ResNet-18 | 1.814 | 11.69 | 30.24 | 10.92 | Official |
| ResNet-18 | 1.814 | 11.69 | 29.97 | 10.43 | Step-120 | | ResNet-18 | 1.814 | 11.69 | 29.97 | 10.43 | Step-120 |
| ResNet-18 | 1.814 | 11.69 | 29.35 | 10.13 | Cosine-120 | | ResNet-18 | 1.814 | 11.69 | 29.35 | 10.13 | Cosine-120 |

View File

@ -4,7 +4,7 @@
The following is a set of guidelines for contributing to NAS-Projects. The following is a set of guidelines for contributing to NAS-Projects.
#### Table Of Contents ## Table Of Contents
[How Can I Contribute?](#how-can-i-contribute) [How Can I Contribute?](#how-can-i-contribute)
* [Reporting Bugs](#reporting-bugs) * [Reporting Bugs](#reporting-bugs)

View File

@ -140,6 +140,8 @@ This command will train 390 architectures (id from 0 to 389) using the following
| CIFAR-100 | train | valid / test | | CIFAR-100 | train | valid / test |
| ImageNet-16-120 | train | valid / test | | ImageNet-16-120 | train | valid / test |
Note that the above `train`, `valid`, and `test` indicate the proposed splits in our NAS-Bench-102, and they might be different with the original splits.
3. calculate the latency, merge the results of all architectures, and simplify the results. 3. calculate the latency, merge the results of all architectures, and simplify the results.
(see commands in `output/NAS-BENCH-102-4/meta-node-4.cal-script.txt` which is automatically generated by step-1). (see commands in `output/NAS-BENCH-102-4/meta-node-4.cal-script.txt` which is automatically generated by step-1).
``` ```
@ -167,7 +169,7 @@ If researchers can provide better results with different hyper-parameters, we ar
**Note that** you need to prepare the training and test data as described in [Preparation and Download](#preparation-and-download) **Note that** you need to prepare the training and test data as described in [Preparation and Download](#preparation-and-download)
- [1] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1` - [1] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1`, where `cifar10` can be replaced with `cifar100` or `ImageNet16-120`.
- [2] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1` - [2] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1`
- [3] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 -1` - [3] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 -1`
- [4] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh cifar10 -1` - [4] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh cifar10 -1`

View File

@ -8,7 +8,6 @@ from tqdm import tqdm
from collections import OrderedDict from collections import OrderedDict
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
from pathlib import Path from pathlib import Path
from collections import defaultdict from collections import defaultdict
import matplotlib import matplotlib
@ -498,6 +497,8 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_
def get_accs(xdata): def get_accs(xdata):
epochs, xresults = xdata['epoch'], [] epochs, xresults = xdata['epoch'], []
metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False)
xresults.append( metrics['accuracy'] )
for iepoch in range(epochs): for iepoch in range(epochs):
genotype = xdata['genotypes'][iepoch] genotype = xdata['genotypes'][iepoch]
index = api.query_index_by_arch(genotype) index = api.query_index_by_arch(genotype)
@ -547,7 +548,6 @@ if __name__ == '__main__':
#visualize_relative_ranking(vis_save_dir) #visualize_relative_ranking(vis_save_dir)
api = API(args.api_path) api = API(args.api_path)
"""
for x_maxs in [50, 250]: for x_maxs in [50, 250]:
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
@ -555,11 +555,12 @@ if __name__ == '__main__':
show_nas_sharing_w(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) show_nas_sharing_w(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
show_nas_sharing_w(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) show_nas_sharing_w(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
just_show(api)
""" """
just_show(api)
plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1)) plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1))
plot_results_nas(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1)) plot_results_nas(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1))
plot_results_nas(api, 'cifar100' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3)) plot_results_nas(api, 'cifar100' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
plot_results_nas(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-com.pdf', (55,75, 3)) plot_results_nas(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3)) plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
plot_results_nas(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-com.pdf', (35,50, 3)) plot_results_nas(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
"""

View File

@ -10,7 +10,6 @@ from copy import deepcopy
from pathlib import Path from pathlib import Path
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.distributions import Categorical
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str from config_utils import load_config, dict2config, configure2str

View File

@ -121,9 +121,19 @@ def main(xargs):
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True) search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
elif xargs.dataset == 'cifar100': elif xargs.dataset == 'cifar100':
raise ValueError('not support yet : {:}'.format(xargs.dataset)) cifar100_test_split = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None)
elif xargs.dataset.startswith('ImageNet16'): search_train_data = train_data
raise ValueError('not support yet : {:}'.format(xargs.dataset)) search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
search_data = SearchDataset(xargs.dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), cifar100_test_split.xvalid)
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_test_split.xvalid), num_workers=xargs.workers, pin_memory=True)
elif xargs.dataset == 'ImageNet16-120':
imagenet_test_split = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None)
search_train_data = train_data
search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
search_data = SearchDataset(xargs.dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid)
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=xargs.workers, pin_memory=True)
else: else:
raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size)) logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
@ -168,7 +178,7 @@ def main(xargs):
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
else: else:
logger.log("=> do not find the last-info file : {:}".format(last_info)) logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {} start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()}
# start training # start training
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
@ -230,7 +240,7 @@ if __name__ == '__main__':
parser.add_argument('--data_path', type=str, help='Path to dataset') parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
# channels and number-of-cells # channels and number-of-cells
parser.add_argument('--config_path', type=str, help='The config paths.') parser.add_argument('--config_path', type=str, help='The config path.')
parser.add_argument('--search_space_name', type=str, help='The search space name.') parser.add_argument('--search_space_name', type=str, help='The search space name.')
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.') parser.add_argument('--channel', type=int, help='The number of channels.')

View File

@ -181,8 +181,8 @@ def main(xargs):
logger.log('Load split file from {:}'.format(split_Fpath)) logger.log('Load split file from {:}'.format(split_Fpath))
else: else:
raise ValueError('invalid dataset : {:}'.format(xargs.dataset)) raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
config_path = 'configs/nas-benchmark/algos/DARTS.config' #config_path = 'configs/nas-benchmark/algos/DARTS.config'
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
# To split data # To split data
train_data_v2 = deepcopy(train_data) train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform train_data_v2.transform = valid_data.transform
@ -233,7 +233,7 @@ def main(xargs):
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)) logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
else: else:
logger.log("=> do not find the last-info file : {:}".format(last_info)) logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {} start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()}
# start training # start training
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
@ -297,6 +297,7 @@ if __name__ == '__main__':
parser.add_argument('--data_path', type=str, help='Path to dataset') parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
# channels and number-of-cells # channels and number-of-cells
parser.add_argument('--config_path', type=str, help='The config path.')
parser.add_argument('--search_space_name', type=str, help='The search space name.') parser.add_argument('--search_space_name', type=str, help='The search space name.')
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.') parser.add_argument('--channel', type=int, help='The number of channels.')

View File

@ -3,7 +3,7 @@
########################################################################### ###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 # # Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
########################################################################### ###########################################################################
import os, sys, time, glob, random, argparse import os, sys, time, random, argparse
import numpy as np import numpy as np
from copy import deepcopy from copy import deepcopy
import torch import torch
@ -11,7 +11,7 @@ import torch.nn as nn
from pathlib import Path from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str from config_utils import load_config, dict2config
from datasets import get_datasets, SearchDataset from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy from utils import get_model_infos, obtain_accuracy

View File

@ -1,12 +1,14 @@
# python ./exps/vis/test.py # python ./exps/vis/test.py
import os, sys, random import os, sys, random
from pathlib import Path from pathlib import Path
from copy import deepcopy
import torch import torch
import numpy as np import numpy as np
from collections import OrderedDict from collections import OrderedDict
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from nas_102_api import NASBench102API as API
def test_nas_api(): def test_nas_api():
from nas_102_api import ArchResults from nas_102_api import ArchResults
@ -72,7 +74,40 @@ def test_auto_grad():
s_grads = torch.autograd.grad(grads, net.parameters()) s_grads = torch.autograd.grad(grads, net.parameters())
second_order_grads.append( s_grads ) second_order_grads.append( s_grads )
def test_one_shot_model(ckpath, use_train):
from models import get_cell_based_tiny_net, get_search_spaces
from datasets import get_datasets, SearchDataset
from config_utils import load_config, dict2config
from utils.nas_utils import evaluate_one_shot
use_train = int(use_train) > 0
#ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth'
#ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth'
print ('ckpath : {:}'.format(ckpath))
ckp = torch.load(ckpath)
xargs = ckp['args']
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
if xargs.dataset == 'cifar10':
cifar_split = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
xvalid_data = deepcopy(train_data)
xvalid_data.transform = valid_data.transform
valid_loader= torch.utils.data.DataLoader(xvalid_data, batch_size=2048, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid), num_workers=12, pin_memory=True)
else: raise ValueError('invalid dataset : {:}'.format(xargs.dataseet))
search_space = get_search_spaces('cell', xargs.search_space_name)
model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells,
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
'space' : search_space,
'affine' : False, 'track_running_stats': True}, None)
search_model = get_cell_based_tiny_net(model_config)
search_model.load_state_dict( ckp['search_model'] )
search_model = search_model.cuda()
api = API('/home/dxy/.torch/NAS-Bench-102-v1_0-e61699.pth')
archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
if __name__ == '__main__': if __name__ == '__main__':
#test_nas_api() #test_nas_api()
#for i in range(200): plot('{:04d}'.format(i)) #for i in range(200): plot('{:04d}'.format(i))
test_auto_grad() #test_auto_grad()
test_one_shot_model(sys.argv[1], sys.argv[2])

View File

@ -9,6 +9,15 @@ class SearchDataset(data.Dataset):
def __init__(self, name, data, train_split, valid_split, check=True): def __init__(self, name, data, train_split, valid_split, check=True):
self.datasetname = name self.datasetname = name
if isinstance(data, (list, tuple)): # new type of SearchDataset
assert len(data) == 2, 'invalid length: {:}'.format( len(data) )
self.train_data = data[0]
self.valid_data = data[1]
self.train_split = train_split.copy()
self.valid_split = valid_split.copy()
self.mode_str = 'V2' # new mode
else:
self.mode_str = 'V1' # old mode
self.data = data self.data = data
self.train_split = train_split.copy() self.train_split = train_split.copy()
self.valid_split = valid_split.copy() self.valid_split = valid_split.copy()
@ -18,7 +27,7 @@ class SearchDataset(data.Dataset):
self.length = len(self.train_split) self.length = len(self.train_split)
def __repr__(self): def __repr__(self):
return ('{name}(name={datasetname}, train={tr_L}, valid={val_L})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split))) return ('{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split), ver=self.mode_str))
def __len__(self): def __len__(self):
return self.length return self.length
@ -27,6 +36,11 @@ class SearchDataset(data.Dataset):
assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index) assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index)
train_index = self.train_split[index] train_index = self.train_split[index]
valid_index = random.choice( self.valid_split ) valid_index = random.choice( self.valid_split )
if self.mode_str == 'V1':
train_image, train_label = self.data[train_index] train_image, train_label = self.data[train_index]
valid_image, valid_label = self.data[valid_index] valid_image, valid_label = self.data[valid_index]
elif self.mode_str == 'V2':
train_image, train_label = self.train_data[train_index]
valid_image, valid_label = self.valid_data[valid_index]
else: raise ValueError('invalid mode : {:}'.format(self.mode_str))
return train_image, train_label, valid_image, valid_label return train_image, train_label, valid_image, valid_label

View File

@ -34,7 +34,7 @@ class PointMeta():
def get_box(self, return_diagonal=False): def get_box(self, return_diagonal=False):
if self.box is None: return None if self.box is None: return None
if return_diagonal == False: if not return_diagonal:
return self.box.clone() return self.box.clone()
else: else:
W = (self.box[2]-self.box[0]).item() W = (self.box[2]-self.box[0]).item()

View File

@ -1,4 +1,3 @@
import torch
import torch.nn as nn import torch.nn as nn
from copy import deepcopy from copy import deepcopy
from ..cell_operations import OPS from ..cell_operations import OPS

View File

@ -68,7 +68,7 @@ class Structure:
for i, node_info in enumerate(self.nodes): for i, node_info in enumerate(self.nodes):
sums = [] sums = []
for op, xin in node_info: for op, xin in node_info:
if op == 'none' or nodes[xin] == False: x = False if op == 'none' or nodes[xin] is False: x = False
else: x = True else: x = True
sums.append( x ) sums.append( x )
nodes[i+1] = sum(sums) > 0 nodes[i+1] = sum(sums) > 0

View File

@ -85,7 +85,7 @@ class SearchCell(nn.Module):
candidates = self.edges[node_str] candidates = self.edges[node_str]
select_op = random.choice(candidates) select_op = random.choice(candidates)
sops.append( select_op ) sops.append( select_op )
if not hasattr(select_op, 'is_zero') or select_op.is_zero == False: has_non_zero=True if not hasattr(select_op, 'is_zero') or select_op.is_zero is False: has_non_zero=True
if has_non_zero: break if has_non_zero: break
inter_nodes = [] inter_nodes = []
for j, select_op in enumerate(sops): for j, select_op in enumerate(sops):

View File

@ -1,4 +1,4 @@
import math, torch import math
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from ..initialization import initialize_resnet from ..initialization import initialize_resnet

View File

@ -70,6 +70,9 @@ class NASBench102API(object):
def __repr__(self): def __repr__(self):
return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs))) return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs)))
def random(self):
return random.randint(0, len(self.meta_archs)-1)
def query_index_by_arch(self, arch): def query_index_by_arch(self, arch):
if isinstance(arch, str): if isinstance(arch, str):
if arch in self.archstr2index: arch_index = self.archstr2index[ arch ] if arch in self.archstr2index: arch_index = self.archstr2index[ arch ]

View File

@ -1,7 +1,7 @@
################################################## ##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
################################################## ##################################################
import os, sys, time, torch, random, PIL, copy, numpy as np import os, sys, torch, random, PIL, copy, numpy as np
from os import path as osp from os import path as osp
from shutil import copyfile from shutil import copyfile

View File

@ -1,3 +1,5 @@
from .evaluation_utils import obtain_accuracy from .evaluation_utils import obtain_accuracy
from .gpu_manager import GPUManager from .gpu_manager import GPUManager
from .flop_benchmark import get_model_infos from .flop_benchmark import get_model_infos
from .affine_utils import normalize_points, denormalize_points
from .affine_utils import identity2affine, solve2theta, affine2image

View File

@ -1,10 +1,3 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
#
# functions for affine transformation # functions for affine transformation
import math, torch import math, torch
import numpy as np import numpy as np

View File

@ -1,4 +1,4 @@
import copy, torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np

View File

@ -27,7 +27,7 @@ class GPUManager():
find = False find = False
for gpu in all_gpus: for gpu in all_gpus:
if gpu['index'] == CUDA_VISIBLE_DEVICE: if gpu['index'] == CUDA_VISIBLE_DEVICE:
assert find==False, 'Duplicate cuda device index : {}'.format(CUDA_VISIBLE_DEVICE) assert not find, 'Duplicate cuda device index : {}'.format(CUDA_VISIBLE_DEVICE)
find = True find = True
selected_gpus.append( gpu.copy() ) selected_gpus.append( gpu.copy() )
selected_gpus[-1]['index'] = '{}'.format(idx) selected_gpus[-1]['index'] = '{}'.format(idx)

52
lib/utils/nas_utils.py Normal file
View File

@ -0,0 +1,52 @@
# This file is for experimental usage
import os, sys, torch, random
import numpy as np
from copy import deepcopy
from tqdm import tqdm
import torch.nn as nn
from utils import obtain_accuracy
from models import CellStructure
from log_utils import time_string
def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
weights = deepcopy(model.state_dict())
model.train(cal_mode)
with torch.no_grad():
logits = nn.functional.log_softmax(model.arch_parameters, dim=-1)
archs = CellStructure.gen_all(model.op_names, model.max_nodes, False)
probs, accuracies, gt_accs = [], [], []
loader_iter = iter(xloader)
random.seed(seed)
random.shuffle(archs)
for idx, arch in enumerate(archs):
arch_index = api.query_index_by_arch( arch )
metrics = api.get_more_info(arch_index, 'cifar10-valid', None, False, False)
gt_accs.append( metrics['valid-accuracy'] )
select_logits = []
for i, node_info in enumerate(arch.nodes):
for op, xin in node_info:
node_str = '{:}<-{:}'.format(i+1, xin)
op_index = model.op_names.index(op)
select_logits.append( logits[model.edge2index[node_str], op_index] )
cur_prob = sum(select_logits).item()
probs.append( cur_prob )
cor_prob = np.corrcoef(probs, gt_accs)[0,1]
print ('correlation for probabilities : {:}'.format(cor_prob))
for idx, arch in enumerate(archs):
model.set_cal_mode('dynamic', arch)
try:
inputs, targets = next(loader_iter)
except:
loader_iter = iter(xloader)
inputs, targets = next(loader_iter)
_, logits = model(inputs.cuda())
_, preds = torch.max(logits, dim=-1)
correct = (preds == targets.cuda() ).float()
accuracies.append( correct.mean().item() )
if idx != 0 and (idx % 300 == 0 or idx + 1 == len(archs) or idx == 10):
cor_accs = np.corrcoef(accuracies, gt_accs[:idx+1])[0,1]
print ('{:} {:03d}/{:03d} mode={:5s}, correlation : accs={:.4f}, arch={:}'.format(time_string(), idx, len(archs), 'Train' if cal_mode else 'Eval', cor_accs, arch))
model.load_state_dict(weights)
return archs, probs, accuracies

View File

@ -1 +0,0 @@
from .affine_utils import normalize_points, denormalize_points

View File

@ -33,6 +33,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V2.py \
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ --save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
--dataset ${dataset} --data_path ${data_path} \ --dataset ${dataset} --data_path ${data_path} \
--search_space_name ${space} \ --search_space_name ${space} \
--config_path configs/nas-benchmark/algos/DARTS.config \
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ --arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ --arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
--workers 4 --print_freq 200 --rand_seed ${seed} --workers 4 --print_freq 200 --rand_seed ${seed}

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
# bash ./scripts/prepare.sh # bash ./scripts/prepare.sh
datasets='cifar10 cifar100 imagenet-1k' #datasets='cifar10 cifar100 imagenet-1k'
#ratios='0.5 0.8 0.9' #ratios='0.5 0.8 0.9'
ratios='0.5' ratios='0.5'
save_dir=./.latent-data/splits save_dir=./.latent-data/splits

View File

@ -33,7 +33,7 @@ OMP_NUM_THREADS=4 python ./exps/basic-main.py --dataset ${dataset} \
--procedure basic \ --procedure basic \
--save_dir ${xsave_dir} \ --save_dir ${xsave_dir} \
--cutout_length -1 \ --cutout_length -1 \
--batch_size 256 --rand_seed ${rseed} --workers 6 \ --batch_size ${batch} --rand_seed ${rseed} --workers 6 \
--eval_frequency 1 --print_freq 100 --print_freq_eval 200 --eval_frequency 1 --print_freq 100 --print_freq_eval 200
# KD training # KD training
@ -47,5 +47,5 @@ OMP_NUM_THREADS=4 python ./exps/KD-main.py --dataset ${dataset} \
--save_dir ${xsave_dir} \ --save_dir ${xsave_dir} \
--KD_alpha 0.9 --KD_temperature 4 \ --KD_alpha 0.9 --KD_temperature 4 \
--cutout_length -1 \ --cutout_length -1 \
--batch_size 256 --rand_seed ${rseed} --workers 6 \ --batch_size ${batch} --rand_seed ${rseed} --workers 6 \
--eval_frequency 1 --print_freq 100 --print_freq_eval 200 --eval_frequency 1 --print_freq 100 --print_freq_eval 200