Compare commits
11 Commits: 5bf036a763 ... 4612cd198b

SHA1:
- 4612cd198b
- 889bd1974c
- af0e7786b6
- c6d53f08ae
- ef2608bb42
- 50ff507a15
- 03d7d04d41
- bb33ca9a68
- f46486e21b
- 5908a1edef
- ed34024a88
@ -61,13 +61,13 @@ At this moment, this project provides the following algorithms and scripts to ru
<tr> <!-- (6-th row) -->
<td align="center" valign="middle"> NATS-Bench </td>
<td align="center" valign="middle"> <a href="https://xuanyidong.com/assets/projects/NATS-Bench"> NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size</a> </td>
- <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/NATS-Bench">NATS-Bench.md</a> </td>
+ <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/NATS-Bench/blob/main/README.md">NATS-Bench.md</a> </td>
</tr>
<tr> <!-- (7-th row) -->
<td align="center" valign="middle"> ... </td>
<td align="center" valign="middle"> ENAS / REA / REINFORCE / BOHB </td>
<td align="center" valign="middle"> Please check the original papers </td>
- <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/main/docs/NAS-Bench-201.md">NAS-Bench-201.md</a> <a href="https://github.com/D-X-Y/NATS-Bench">NATS-Bench.md</a> </td>
+ <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/main/docs/NAS-Bench-201.md">NAS-Bench-201.md</a> <a href="https://github.com/D-X-Y/NATS-Bench/blob/main/README.md">NATS-Bench.md</a> </td>
</tr>
<tr> <!-- (start second block) -->
<td rowspan="1" align="center" valign="middle" halign="middle"> HPO </td>
@ -29,7 +29,7 @@ You can simply type `pip install nas-bench-201` to install our api. Please see s
You can move it to anywhere you want and send its path to our API for initialization.
- [2020.02.25] APIv1.0/FILEv1.0: [`NAS-Bench-201-v1_0-e61699.pth`](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) (2.2G), where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
- [2020.02.25] APIv1.0/FILEv1.0: The full data of each architecture can be download from [
- NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights.
+ NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the trained weights.
- [2020.02.25] APIv1.0/FILEv1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi).
- [2020.03.09] APIv1.2/FILEv1.0: More robust API with more functions and descriptions
- [2020.03.16] APIv1.3/FILEv1.1: [`NAS-Bench-201-v1_1-096897.pth`](https://drive.google.com/open?id=16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_) (4.7G), where `096897` is the last six digits for this file. It contains information of more trials compared to `NAS-Bench-201-v1_0-e61699.pth`, especially all models trained by 12 epochs on all datasets are avaliable.
@ -27,7 +27,7 @@ You can simply type `pip install nas-bench-201` to install our api. Please see s
You can move it to anywhere you want and send its path to our API for initialization.
- [2020.02.25] APIv1.0/FILEv1.0: [`NAS-Bench-201-v1_0-e61699.pth`](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) (2.2G), where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
- [2020.02.25] APIv1.0/FILEv1.0: The full data of each architecture can be download from [
- NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights.
+ NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the trained weights.
- [2020.02.25] APIv1.0/FILEv1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi).
- [2020.03.09] APIv1.2/FILEv1.0: More robust API with more functions and descriptions
- [2020.03.16] APIv1.3/FILEv1.1: [`NAS-Bench-201-v1_1-096897.pth`](https://drive.google.com/open?id=16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_) (4.7G), where `096897` is the last six digits for this file. It contains information of more trials compared to `NAS-Bench-201-v1_0-e61699.pth`, especially all models trained by 12 epochs on all datasets are avaliable.
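The two documentation hunks above both describe downloading a benchmark `.pth` file and passing its path to the API for initialization. As a concrete illustration, here is a minimal sketch of that step, assuming `nas-bench-201` has been installed via pip and the file sits at a placeholder path:

```python
from nas_201_api import NASBench201API as API

# Placeholder path: point this at wherever NAS-Bench-201-v1_1-096897.pth was downloaded.
api = API('/path/to/NAS-Bench-201-v1_1-096897.pth')

# Query the recorded trials of one architecture (indices run over the 15625 cells) on one dataset.
results = api.query_by_index(1, 'cifar10')
for seed, result in results.items():
    print(seed, result.get_train()['accuracy'])
```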
@ -3,7 +3,7 @@
</p>

---------
- [![MIT licensed](https://img.shields.io/badge/license-MIT-brightgreen.svg)](LICENSE.md)
+ [![MIT licensed](https://img.shields.io/badge/license-MIT-brightgreen.svg)](../LICENSE.md)

The automatic deep learning library (AutoDL-Projects) is an open-source, lightweight, and powerful project.
It implements a variety of neural architecture search (NAS) and hyperparameter optimization (HPO) algorithms.
@ -142,8 +142,8 @@

# Others

- If you would like to contribute to this codebase, please see [CONTRIBUTING.md](.github/CONTRIBUTING.md).
+ If you would like to contribute to this codebase, please see [CONTRIBUTING.md](../.github/CONTRIBUTING.md).
- In addition, please refer to [CODE-OF-CONDUCT.md](.github/CODE-OF-CONDUCT.md) for the code of conduct.
+ In addition, please refer to [CODE-OF-CONDUCT.md](../.github/CODE-OF-CONDUCT.md) for the code of conduct.

# License
- The entire codebase is under [MIT license](LICENSE.md)
+ The entire codebase is under [MIT license](../LICENSE.md)
@ -2,11 +2,11 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
import time, torch
- from procedures import prepare_seed, get_optim_scheduler
+ from xautodl.procedures import prepare_seed, get_optim_scheduler
- from utils import get_model_infos, obtain_accuracy
+ from xautodl.utils import get_model_infos, obtain_accuracy
- from config_utils import dict2config
+ from xautodl.config_utils import dict2config
- from log_utils import AverageMeter, time_string, convert_secs2time
+ from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
- from models import get_cell_based_tiny_net
+ from xautodl.models import get_cell_based_tiny_net


__all__ = ["evaluate_for_seed", "pure_evaluate"]
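The hunk above only changes import paths: the helper modules are now resolved through the packaged `xautodl` namespace instead of bare top-level modules. A quick sanity check, assuming the `xautodl` package is importable in the environment (for example, installed from this repository):

```python
# Verify that the package-qualified imports introduced by this change resolve.
from xautodl.procedures import prepare_seed, get_optim_scheduler
from xautodl.utils import get_model_infos, obtain_accuracy
from xautodl.config_utils import dict2config
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import get_cell_based_tiny_net

print("xautodl imports resolved")
```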
@ -16,10 +16,96 @@ from xautodl.procedures import get_machine_info
|
|||||||
from xautodl.datasets import get_datasets
|
from xautodl.datasets import get_datasets
|
||||||
from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
|
from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
|
||||||
from xautodl.models import CellStructure, CellArchitectures, get_search_spaces
|
from xautodl.models import CellStructure, CellArchitectures, get_search_spaces
|
||||||
from xautodl.functions import evaluate_for_seed
|
from functions import evaluate_for_seed
|
||||||
|
|
||||||
|
from torchvision import datasets, transforms
|
||||||
|
|
||||||
|
# NASBENCH201_CONFIG_PATH = os.path.join( os.getcwd(), 'main_exp', 'transfer_nag')
|
||||||
|
|
||||||
|
NASBENCH201_CONFIG_PATH = '/lustre/hpe/ws11/ws11.1/ws/xmuhanma-nbdit/autodl-projects/configs/nas-benchmark'
|
||||||
|
|
||||||
|
|
||||||
def evaluate_all_datasets(
|
def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed,
|
||||||
|
arch_config, workers, logger):
|
||||||
|
machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
|
||||||
|
all_infos = {'info': machine_info}
|
||||||
|
all_dataset_keys = []
|
||||||
|
# look all the datasets
|
||||||
|
for dataset, xpath, split in zip(datasets, xpaths, splits):
|
||||||
|
# train valid data
|
||||||
|
task = None
|
||||||
|
train_data, valid_data, xshape, class_num = get_datasets(
|
||||||
|
dataset, xpath, -1, task)
|
||||||
|
|
||||||
|
# load the configuration
|
||||||
|
if dataset in ['mnist', 'svhn', 'aircraft', 'oxford']:
|
||||||
|
if use_less:
|
||||||
|
# config_path = os.path.join(
|
||||||
|
# NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/LESS.config')
|
||||||
|
config_path = os.path.join(
|
||||||
|
NASBENCH201_CONFIG_PATH, 'LESS.config')
|
||||||
|
else:
|
||||||
|
# config_path = os.path.join(
|
||||||
|
# NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{}.config'.format(dataset))
|
||||||
|
config_path = os.path.join(
|
||||||
|
NASBENCH201_CONFIG_PATH, '{}.config'.format(dataset))
|
||||||
|
|
||||||
|
|
||||||
|
p = os.path.join(
|
||||||
|
NASBENCH201_CONFIG_PATH, '{:}-split.txt'.format(dataset))
|
||||||
|
if not os.path.exists(p):
|
||||||
|
import json
|
||||||
|
label_list = list(range(len(train_data)))
|
||||||
|
random.shuffle(label_list)
|
||||||
|
strlist = [str(label_list[i]) for i in range(len(label_list))]
|
||||||
|
splited = {'train': ["int", strlist[:len(train_data) // 2]],
|
||||||
|
'valid': ["int", strlist[len(train_data) // 2:]]}
|
||||||
|
with open(p, 'w') as f:
|
||||||
|
f.write(json.dumps(splited))
|
||||||
|
split_info = load_config(os.path.join(
|
||||||
|
NASBENCH201_CONFIG_PATH, '{:}-split.txt'.format(dataset)), None, None)
|
||||||
|
else:
|
||||||
|
raise ValueError('invalid dataset : {:}'.format(dataset))
|
||||||
|
|
||||||
|
config = load_config(
|
||||||
|
config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||||
|
# data loader
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size,
|
||||||
|
shuffle=True, num_workers=workers, pin_memory=True)
|
||||||
|
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
|
||||||
|
shuffle=False, num_workers=workers, pin_memory=True)
|
||||||
|
splits = load_config(os.path.join(
|
||||||
|
NASBENCH201_CONFIG_PATH, '{}-test-split.txt'.format(dataset)), None, None)
|
||||||
|
ValLoaders = {'ori-test': valid_loader,
|
||||||
|
'x-valid': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
|
||||||
|
sampler=torch.utils.data.sampler.SubsetRandomSampler(
|
||||||
|
splits.xvalid),
|
||||||
|
num_workers=workers, pin_memory=True),
|
||||||
|
'x-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
|
||||||
|
sampler=torch.utils.data.sampler.SubsetRandomSampler(
|
||||||
|
splits.xtest),
|
||||||
|
num_workers=workers, pin_memory=True)
|
||||||
|
}
|
||||||
|
dataset_key = '{:}'.format(dataset)
|
||||||
|
if bool(split):
|
||||||
|
dataset_key = dataset_key + '-valid'
|
||||||
|
logger.log(
|
||||||
|
'Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.
|
||||||
|
format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size))
|
||||||
|
logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(
|
||||||
|
dataset_key, config))
|
||||||
|
for key, value in ValLoaders.items():
|
||||||
|
logger.log(
|
||||||
|
'Evaluate ---->>>> {:10s} with {:} batchs'.format(key, len(value)))
|
||||||
|
|
||||||
|
results = evaluate_for_seed(
|
||||||
|
arch_config, config, arch, train_loader, ValLoaders, seed, logger)
|
||||||
|
all_infos[dataset_key] = results
|
||||||
|
all_dataset_keys.append(dataset_key)
|
||||||
|
all_infos['all_dataset_keys'] = all_dataset_keys
|
||||||
|
return all_infos
|
||||||
|
|
||||||
|
def evaluate_all_datasets1(
|
||||||
arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger
|
arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger
|
||||||
):
|
):
|
||||||
machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
|
machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
|
||||||
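The new `evaluate_all_datasets` added above generates a `{dataset}-split.txt` file on the fly when one is missing: it shuffles all sample indices and stores the two halves as string lists under `train` and `valid`. A standalone sketch of that file format follows; the file name, sample count, and the fixed seed are placeholders for illustration only:

```python
import json
import random

def make_split_file(path, num_samples, seed=0):
    # Same layout the diff writes: a JSON dict whose 'train'/'valid' entries
    # are ["int", <list of index strings>], later parsed by load_config.
    random.seed(seed)
    label_list = list(range(num_samples))
    random.shuffle(label_list)
    strlist = [str(i) for i in label_list]
    half = num_samples // 2
    splited = {"train": ["int", strlist[:half]],
               "valid": ["int", strlist[half:]]}
    with open(path, "w") as f:
        f.write(json.dumps(splited))

make_split_file("aircraft-split.txt", 10000)  # placeholder name and size
```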
@ -46,47 +132,117 @@ def evaluate_all_datasets(
|
|||||||
split_info = load_config(
|
split_info = load_config(
|
||||||
"configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
|
"configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
|
||||||
)
|
)
|
||||||
|
elif dataset.startswith("aircraft"):
|
||||||
|
if use_less:
|
||||||
|
config_path = "configs/nas-benchmark/LESS.config"
|
||||||
|
else:
|
||||||
|
config_path = "configs/nas-benchmark/aircraft.config"
|
||||||
|
split_info = load_config(
|
||||||
|
"configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
|
||||||
|
)
|
||||||
|
elif dataset.startswith("oxford"):
|
||||||
|
if use_less:
|
||||||
|
config_path = "configs/nas-benchmark/LESS.config"
|
||||||
|
else:
|
||||||
|
config_path = "configs/nas-benchmark/oxford.config"
|
||||||
|
split_info = load_config(
|
||||||
|
"configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("invalid dataset : {:}".format(dataset))
|
raise ValueError("invalid dataset : {:}".format(dataset))
|
||||||
config = load_config(
|
config = load_config(
|
||||||
config_path, {"class_num": class_num, "xshape": xshape}, logger
|
config_path, {"class_num": class_num, "xshape": xshape}, logger
|
||||||
)
|
)
|
||||||
# check whether use splited validation set
|
# check whether use splited validation set
|
||||||
|
# if dataset == 'aircraft':
|
||||||
|
# split = True
|
||||||
if bool(split):
|
if bool(split):
|
||||||
assert dataset == "cifar10"
|
if dataset == "cifar10" or dataset == "cifar100":
|
||||||
ValLoaders = {
|
assert dataset == "cifar10"
|
||||||
"ori-test": torch.utils.data.DataLoader(
|
ValLoaders = {
|
||||||
valid_data,
|
"ori-test": torch.utils.data.DataLoader(
|
||||||
|
valid_data,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=workers,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
assert len(train_data) == len(split_info.train) + len(
|
||||||
|
split_info.valid
|
||||||
|
), "invalid length : {:} vs {:} + {:}".format(
|
||||||
|
len(train_data), len(split_info.train), len(split_info.valid)
|
||||||
|
)
|
||||||
|
train_data_v2 = deepcopy(train_data)
|
||||||
|
train_data_v2.transform = valid_data.transform
|
||||||
|
valid_data = train_data_v2
|
||||||
|
# data loader
|
||||||
|
train_loader = torch.utils.data.DataLoader(
|
||||||
|
train_data,
|
||||||
batch_size=config.batch_size,
|
batch_size=config.batch_size,
|
||||||
shuffle=False,
|
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
|
||||||
num_workers=workers,
|
num_workers=workers,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
)
|
)
|
||||||
}
|
valid_loader = torch.utils.data.DataLoader(
|
||||||
assert len(train_data) == len(split_info.train) + len(
|
valid_data,
|
||||||
split_info.valid
|
batch_size=config.batch_size,
|
||||||
), "invalid length : {:} vs {:} + {:}".format(
|
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
|
||||||
len(train_data), len(split_info.train), len(split_info.valid)
|
num_workers=workers,
|
||||||
)
|
pin_memory=True,
|
||||||
train_data_v2 = deepcopy(train_data)
|
)
|
||||||
train_data_v2.transform = valid_data.transform
|
ValLoaders["x-valid"] = valid_loader
|
||||||
valid_data = train_data_v2
|
elif dataset == "aircraft":
|
||||||
# data loader
|
ValLoaders = {
|
||||||
train_loader = torch.utils.data.DataLoader(
|
"ori-test": torch.utils.data.DataLoader(
|
||||||
train_data,
|
valid_data,
|
||||||
batch_size=config.batch_size,
|
batch_size=config.batch_size,
|
||||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
|
shuffle=False,
|
||||||
num_workers=workers,
|
num_workers=workers,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
)
|
)
|
||||||
valid_loader = torch.utils.data.DataLoader(
|
}
|
||||||
valid_data,
|
train_data_v2 = deepcopy(train_data)
|
||||||
batch_size=config.batch_size,
|
train_data_v2.transform = valid_data.transform
|
||||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
|
valid_data = train_data_v2
|
||||||
num_workers=workers,
|
# use the DataLoader
|
||||||
pin_memory=True,
|
train_loader = torch.utils.data.DataLoader(
|
||||||
)
|
train_data,
|
||||||
ValLoaders["x-valid"] = valid_loader
|
batch_size=config.batch_size,
|
||||||
|
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
|
||||||
|
num_workers=workers,
|
||||||
|
pin_memory=True)
|
||||||
|
valid_loader = torch.utils.data.DataLoader(
|
||||||
|
valid_data,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
|
||||||
|
num_workers=workers,
|
||||||
|
pin_memory=True)
|
||||||
|
elif dataset == "oxford":
|
||||||
|
ValLoaders = {
|
||||||
|
"ori-test": torch.utils.data.DataLoader(
|
||||||
|
valid_data,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=workers,
|
||||||
|
pin_memory=True
|
||||||
|
)
|
||||||
|
}
|
||||||
|
# train_data_v2 = deepcopy(train_data)
|
||||||
|
# train_data_v2.transform = valid_data.transform
|
||||||
|
train_loader = torch.utils.data.DataLoader(
|
||||||
|
train_data,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
|
||||||
|
num_workers=workers,
|
||||||
|
pin_memory=True)
|
||||||
|
valid_loader = torch.utils.data.DataLoader(
|
||||||
|
valid_data,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
|
||||||
|
num_workers=workers,
|
||||||
|
pin_memory=True)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# data loader
|
# data loader
|
||||||
train_loader = torch.utils.data.DataLoader(
|
train_loader = torch.utils.data.DataLoader(
|
||||||
@ -103,7 +259,7 @@ def evaluate_all_datasets(
|
|||||||
num_workers=workers,
|
num_workers=workers,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
)
|
)
|
||||||
if dataset == "cifar10":
|
if dataset == "cifar10" or dataset == "aircraft" or dataset == "oxford":
|
||||||
ValLoaders = {"ori-test": valid_loader}
|
ValLoaders = {"ori-test": valid_loader}
|
||||||
elif dataset == "cifar100":
|
elif dataset == "cifar100":
|
||||||
cifar100_splits = load_config(
|
cifar100_splits = load_config(
|
||||||
|
@ -28,16 +28,41 @@ else
mode=cover
fi

+ # OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
+ # --mode ${mode} --save_dir ${save_dir} --max_node 4 \
+ # --use_less ${use_less} \
+ # --datasets cifar10 cifar10 cifar100 ImageNet16-120 \
+ # --splits 1 0 0 0 \
+ # --xpaths $TORCH_HOME/cifar.python \
+ # $TORCH_HOME/cifar.python \
+ # $TORCH_HOME/cifar.python \
+ # $TORCH_HOME/cifar.python/ImageNet16 \
+ # --channel 16 --num_cells 5 \
+ # --workers 4 \
+ # --srange ${xstart} ${xend} --arch_index ${arch_index} \
+ # --seeds ${all_seeds}
+
OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
--mode ${mode} --save_dir ${save_dir} --max_node 4 \
--use_less ${use_less} \
- --datasets cifar10 cifar10 cifar100 ImageNet16-120 \
+ --datasets aircraft \
- --splits 1 0 0 0 \
+ --xpaths /lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/ \
- --xpaths $TORCH_HOME/cifar.python \
+ --channel 16 \
- $TORCH_HOME/cifar.python \
+ --splits 1 \
- $TORCH_HOME/cifar.python \
+ --num_cells 5 \
- $TORCH_HOME/cifar.python/ImageNet16 \
- --channel 16 --num_cells 5 \
--workers 4 \
--srange ${xstart} ${xend} --arch_index ${arch_index} \
--seeds ${all_seeds}

+ # OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
+ # --mode ${mode} --save_dir ${save_dir} --max_node 4 \
+ # --use_less ${use_less} \
+ # --datasets oxford\
+ # --xpaths /lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/ \
+ # --channel 16 \
+ # --splits 1 \
+ # --num_cells 5 \
+ # --workers 4 \
+ # --srange ${xstart} ${xend} --arch_index ${arch_index} \
+ # --seeds ${all_seeds}
test.ipynb
Normal file
104336
test.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
616
test_network.py
Normal file
616
test_network.py
Normal file
@ -0,0 +1,616 @@
|
|||||||
|
from nas_201_api import NASBench201API as API
|
||||||
|
import os
|
||||||
|
|
||||||
|
import os, sys, time, torch, random, argparse
|
||||||
|
from PIL import ImageFile
|
||||||
|
|
||||||
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
|
from copy import deepcopy
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from xautodl.config_utils import load_config
|
||||||
|
from xautodl.procedures import save_checkpoint, copy_checkpoint
|
||||||
|
from xautodl.procedures import get_machine_info
|
||||||
|
from xautodl.datasets import get_datasets
|
||||||
|
from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
|
||||||
|
from xautodl.models import CellStructure, CellArchitectures, get_search_spaces
|
||||||
|
|
||||||
|
import time, torch
|
||||||
|
from xautodl.procedures import prepare_seed, get_optim_scheduler
|
||||||
|
from xautodl.utils import get_model_infos, obtain_accuracy
|
||||||
|
from xautodl.config_utils import dict2config
|
||||||
|
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
|
from xautodl.models import get_cell_based_tiny_net
|
||||||
|
|
||||||
|
cur_path = os.path.abspath(os.path.curdir)
|
||||||
|
data_path = os.path.join(cur_path, 'NAS-Bench-201-v1_1-096897.pth')
|
||||||
|
print(f'loading data from {data_path}')
|
||||||
|
print(f'loading')
|
||||||
|
api = API(data_path)
|
||||||
|
print(f'loaded')
|
||||||
|
|
||||||
|
def find_best_index(dataset):
|
||||||
|
len = 15625
|
||||||
|
accs = []
|
||||||
|
for i in range(1, len):
|
||||||
|
results = api.query_by_index(i, dataset)
|
||||||
|
dict_items = list(results.items())
|
||||||
|
train_info = dict_items[0][1].get_train()
|
||||||
|
acc = train_info['accuracy']
|
||||||
|
accs.append((i, acc))
|
||||||
|
return max(accs, key=lambda x: x[1])
|
||||||
|
|
||||||
|
best_cifar_10_index, best_cifar_10_acc = find_best_index('cifar10')
|
||||||
|
best_cifar_100_index, best_cifar_100_acc = find_best_index('cifar100')
|
||||||
|
best_ImageNet16_index, best_ImageNet16_acc= find_best_index('ImageNet16-120')
|
||||||
|
print(f'find best cifar10 index: {best_cifar_10_index}, acc: {best_cifar_10_acc}')
|
||||||
|
print(f'find best cifar100 index: {best_cifar_100_index}, acc: {best_cifar_100_acc}')
|
||||||
|
print(f'find best ImageNet16 index: {best_ImageNet16_index}, acc: {best_ImageNet16_acc}')
|
||||||
|
|
||||||
|
from xautodl.models import get_cell_based_tiny_net
|
||||||
|
def get_network_str_by_id(id, dataset):
|
||||||
|
config = api.get_net_config(id, dataset)
|
||||||
|
return config['arch_str']
|
||||||
|
|
||||||
|
best_cifar_10_str = get_network_str_by_id(best_cifar_10_index, 'cifar10')
|
||||||
|
best_cifar_100_str = get_network_str_by_id(best_cifar_100_index, 'cifar100')
|
||||||
|
best_ImageNet16_str = get_network_str_by_id(best_ImageNet16_index, 'ImageNet16-120')
|
||||||
|
|
||||||
|
def evaluate_all_datasets(
|
||||||
|
arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger
|
||||||
|
):
|
||||||
|
machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
|
||||||
|
all_infos = {"info": machine_info}
|
||||||
|
all_dataset_keys = []
|
||||||
|
# look all the datasets
|
||||||
|
for dataset, xpath, split in zip(datasets, xpaths, splits):
|
||||||
|
# train valid data
|
||||||
|
train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
|
||||||
|
# load the configuration
|
||||||
|
if dataset == "cifar10" or dataset == "cifar100":
|
||||||
|
if use_less:
|
||||||
|
config_path = "configs/nas-benchmark/LESS.config"
|
||||||
|
else:
|
||||||
|
config_path = "configs/nas-benchmark/CIFAR.config"
|
||||||
|
split_info = load_config(
|
||||||
|
"configs/nas-benchmark/cifar-split.txt", None, None
|
||||||
|
)
|
||||||
|
elif dataset.startswith("ImageNet16"):
|
||||||
|
if use_less:
|
||||||
|
config_path = "configs/nas-benchmark/LESS.config"
|
||||||
|
else:
|
||||||
|
config_path = "configs/nas-benchmark/ImageNet-16.config"
|
||||||
|
split_info = load_config(
|
||||||
|
"configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("invalid dataset : {:}".format(dataset))
|
||||||
|
config = load_config(
|
||||||
|
config_path, {"class_num": class_num, "xshape": xshape}, logger
|
||||||
|
)
|
||||||
|
# check whether use splited validation set
|
||||||
|
if bool(split):
|
||||||
|
assert dataset == "cifar10"
|
||||||
|
ValLoaders = {
|
||||||
|
"ori-test": torch.utils.data.DataLoader(
|
||||||
|
valid_data,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=workers,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
assert len(train_data) == len(split_info.train) + len(
|
||||||
|
split_info.valid
|
||||||
|
), "invalid length : {:} vs {:} + {:}".format(
|
||||||
|
len(train_data), len(split_info.train), len(split_info.valid)
|
||||||
|
)
|
||||||
|
train_data_v2 = deepcopy(train_data)
|
||||||
|
train_data_v2.transform = valid_data.transform
|
||||||
|
valid_data = train_data_v2
|
||||||
|
# data loader
|
||||||
|
train_loader = torch.utils.data.DataLoader(
|
||||||
|
train_data,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
|
||||||
|
num_workers=workers,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
valid_loader = torch.utils.data.DataLoader(
|
||||||
|
valid_data,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
|
||||||
|
num_workers=workers,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
ValLoaders["x-valid"] = valid_loader
|
||||||
|
else:
|
||||||
|
# data loader
|
||||||
|
train_loader = torch.utils.data.DataLoader(
|
||||||
|
train_data,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=workers,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
valid_loader = torch.utils.data.DataLoader(
|
||||||
|
valid_data,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=workers,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
if dataset == "cifar10":
|
||||||
|
ValLoaders = {"ori-test": valid_loader}
|
||||||
|
elif dataset == "cifar100":
|
||||||
|
cifar100_splits = load_config(
|
||||||
|
"configs/nas-benchmark/cifar100-test-split.txt", None, None
|
||||||
|
)
|
||||||
|
ValLoaders = {
|
||||||
|
"ori-test": valid_loader,
|
||||||
|
"x-valid": torch.utils.data.DataLoader(
|
||||||
|
valid_data,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
sampler=torch.utils.data.sampler.SubsetRandomSampler(
|
||||||
|
cifar100_splits.xvalid
|
||||||
|
),
|
||||||
|
num_workers=workers,
|
||||||
|
pin_memory=True,
|
||||||
|
),
|
||||||
|
"x-test": torch.utils.data.DataLoader(
|
||||||
|
valid_data,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
sampler=torch.utils.data.sampler.SubsetRandomSampler(
|
||||||
|
cifar100_splits.xtest
|
||||||
|
),
|
||||||
|
num_workers=workers,
|
||||||
|
pin_memory=True,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
elif dataset == "ImageNet16-120":
|
||||||
|
imagenet16_splits = load_config(
|
||||||
|
"configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None
|
||||||
|
)
|
||||||
|
ValLoaders = {
|
||||||
|
"ori-test": valid_loader,
|
||||||
|
"x-valid": torch.utils.data.DataLoader(
|
||||||
|
valid_data,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
sampler=torch.utils.data.sampler.SubsetRandomSampler(
|
||||||
|
imagenet16_splits.xvalid
|
||||||
|
),
|
||||||
|
num_workers=workers,
|
||||||
|
pin_memory=True,
|
||||||
|
),
|
||||||
|
"x-test": torch.utils.data.DataLoader(
|
||||||
|
valid_data,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
sampler=torch.utils.data.sampler.SubsetRandomSampler(
|
||||||
|
imagenet16_splits.xtest
|
||||||
|
),
|
||||||
|
num_workers=workers,
|
||||||
|
pin_memory=True,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError("invalid dataset : {:}".format(dataset))
|
||||||
|
|
||||||
|
dataset_key = "{:}".format(dataset)
|
||||||
|
if bool(split):
|
||||||
|
dataset_key = dataset_key + "-valid"
|
||||||
|
logger.log(
|
||||||
|
"Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
|
||||||
|
dataset_key,
|
||||||
|
len(train_data),
|
||||||
|
len(valid_data),
|
||||||
|
len(train_loader),
|
||||||
|
len(valid_loader),
|
||||||
|
config.batch_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.log(
|
||||||
|
"Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config)
|
||||||
|
)
|
||||||
|
for key, value in ValLoaders.items():
|
||||||
|
logger.log(
|
||||||
|
"Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value))
|
||||||
|
)
|
||||||
|
results = evaluate_for_seed(
|
||||||
|
arch_config, config, arch, train_loader, ValLoaders, seed, logger
|
||||||
|
)
|
||||||
|
all_infos[dataset_key] = results
|
||||||
|
all_dataset_keys.append(dataset_key)
|
||||||
|
all_infos["all_dataset_keys"] = all_dataset_keys
|
||||||
|
return all_infos
|
||||||
|
|
||||||
|
def evaluate_for_seed(
|
||||||
|
arch_config, config, arch, train_loader, valid_loaders, seed, logger
|
||||||
|
):
|
||||||
|
|
||||||
|
prepare_seed(seed) # random seed
|
||||||
|
net = get_cell_based_tiny_net(
|
||||||
|
dict2config(
|
||||||
|
{
|
||||||
|
"name": "infer.tiny",
|
||||||
|
"C": arch_config["channel"],
|
||||||
|
"N": arch_config["num_cells"],
|
||||||
|
"genotype": arch,
|
||||||
|
"num_classes": config.class_num,
|
||||||
|
},
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
|
||||||
|
flop, param = get_model_infos(net, config.xshape)
|
||||||
|
logger.log("Network : {:}".format(net.get_message()), False)
|
||||||
|
logger.log(
|
||||||
|
"{:} Seed-------------------------- {:} --------------------------".format(
|
||||||
|
time_string(), seed
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param))
|
||||||
|
# train and valid
|
||||||
|
optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config)
|
||||||
|
network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
|
||||||
|
# start training
|
||||||
|
start_time, epoch_time, total_epoch = (
|
||||||
|
time.time(),
|
||||||
|
AverageMeter(),
|
||||||
|
config.epochs + config.warmup,
|
||||||
|
)
|
||||||
|
(
|
||||||
|
train_losses,
|
||||||
|
train_acc1es,
|
||||||
|
train_acc5es,
|
||||||
|
valid_losses,
|
||||||
|
valid_acc1es,
|
||||||
|
valid_acc5es,
|
||||||
|
) = ({}, {}, {}, {}, {}, {})
|
||||||
|
train_times, valid_times = {}, {}
|
||||||
|
for epoch in range(total_epoch):
|
||||||
|
scheduler.update(epoch, 0.0)
|
||||||
|
|
||||||
|
train_loss, train_acc1, train_acc5, train_tm = procedure(
|
||||||
|
train_loader, network, criterion, scheduler, optimizer, "train"
|
||||||
|
)
|
||||||
|
train_losses[epoch] = train_loss
|
||||||
|
train_acc1es[epoch] = train_acc1
|
||||||
|
train_acc5es[epoch] = train_acc5
|
||||||
|
train_times[epoch] = train_tm
|
||||||
|
with torch.no_grad():
|
||||||
|
for key, xloder in valid_loaders.items():
|
||||||
|
valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(
|
||||||
|
xloder, network, criterion, None, None, "valid"
|
||||||
|
)
|
||||||
|
valid_losses["{:}@{:}".format(key, epoch)] = valid_loss
|
||||||
|
valid_acc1es["{:}@{:}".format(key, epoch)] = valid_acc1
|
||||||
|
valid_acc5es["{:}@{:}".format(key, epoch)] = valid_acc5
|
||||||
|
valid_times["{:}@{:}".format(key, epoch)] = valid_tm
|
||||||
|
|
||||||
|
# measure elapsed time
|
||||||
|
epoch_time.update(time.time() - start_time)
|
||||||
|
start_time = time.time()
|
||||||
|
need_time = "Time Left: {:}".format(
|
||||||
|
convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)
|
||||||
|
)
|
||||||
|
logger.log(
|
||||||
|
"{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]".format(
|
||||||
|
time_string(),
|
||||||
|
need_time,
|
||||||
|
epoch,
|
||||||
|
total_epoch,
|
||||||
|
train_loss,
|
||||||
|
train_acc1,
|
||||||
|
train_acc5,
|
||||||
|
valid_loss,
|
||||||
|
valid_acc1,
|
||||||
|
valid_acc5,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
info_seed = {
|
||||||
|
"flop": flop,
|
||||||
|
"param": param,
|
||||||
|
"channel": arch_config["channel"],
|
||||||
|
"num_cells": arch_config["num_cells"],
|
||||||
|
"config": config._asdict(),
|
||||||
|
"total_epoch": total_epoch,
|
||||||
|
"train_losses": train_losses,
|
||||||
|
"train_acc1es": train_acc1es,
|
||||||
|
"train_acc5es": train_acc5es,
|
||||||
|
"train_times": train_times,
|
||||||
|
"valid_losses": valid_losses,
|
||||||
|
"valid_acc1es": valid_acc1es,
|
||||||
|
"valid_acc5es": valid_acc5es,
|
||||||
|
"valid_times": valid_times,
|
||||||
|
"net_state_dict": net.state_dict(),
|
||||||
|
"net_string": "{:}".format(net),
|
||||||
|
"finish-train": True,
|
||||||
|
}
|
||||||
|
return info_seed
|
||||||
|
|
||||||
|
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
|
||||||
|
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
|
||||||
|
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||||
|
latencies = []
|
||||||
|
network.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
end = time.time()
|
||||||
|
for i, (inputs, targets) in enumerate(xloader):
|
||||||
|
targets = targets.cuda(non_blocking=True)
|
||||||
|
inputs = inputs.cuda(non_blocking=True)
|
||||||
|
data_time.update(time.time() - end)
|
||||||
|
# forward
|
||||||
|
features, logits = network(inputs)
|
||||||
|
loss = criterion(logits, targets)
|
||||||
|
batch_time.update(time.time() - end)
|
||||||
|
if batch is None or batch == inputs.size(0):
|
||||||
|
batch = inputs.size(0)
|
||||||
|
latencies.append(batch_time.val - data_time.val)
|
||||||
|
# record loss and accuracy
|
||||||
|
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||||
|
losses.update(loss.item(), inputs.size(0))
|
||||||
|
top1.update(prec1.item(), inputs.size(0))
|
||||||
|
top5.update(prec5.item(), inputs.size(0))
|
||||||
|
end = time.time()
|
||||||
|
if len(latencies) > 2:
|
||||||
|
latencies = latencies[1:]
|
||||||
|
return losses.avg, top1.avg, top5.avg, latencies
|
||||||
|
|
||||||
|
|
||||||
|
def procedure(xloader, network, criterion, scheduler, optimizer, mode):
|
||||||
|
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||||
|
if mode == "train":
|
||||||
|
network.train()
|
||||||
|
elif mode == "valid":
|
||||||
|
network.eval()
|
||||||
|
else:
|
||||||
|
raise ValueError("The mode is not right : {:}".format(mode))
|
||||||
|
|
||||||
|
data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
|
||||||
|
for i, (inputs, targets) in enumerate(xloader):
|
||||||
|
if mode == "train":
|
||||||
|
scheduler.update(None, 1.0 * i / len(xloader))
|
||||||
|
|
||||||
|
targets = targets.cuda(non_blocking=True)
|
||||||
|
if mode == "train":
|
||||||
|
optimizer.zero_grad()
|
||||||
|
# forward
|
||||||
|
features, logits = network(inputs)
|
||||||
|
loss = criterion(logits, targets)
|
||||||
|
# backward
|
||||||
|
if mode == "train":
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
# record loss and accuracy
|
||||||
|
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||||
|
losses.update(loss.item(), inputs.size(0))
|
||||||
|
top1.update(prec1.item(), inputs.size(0))
|
||||||
|
top5.update(prec5.item(), inputs.size(0))
|
||||||
|
# count time
|
||||||
|
batch_time.update(time.time() - end)
|
||||||
|
end = time.time()
|
||||||
|
return losses.avg, top1.avg, top5.avg, batch_time.sum
|
||||||
|
|
||||||
|
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
|
||||||
|
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
|
||||||
|
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||||
|
latencies = []
|
||||||
|
network.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
end = time.time()
|
||||||
|
for i, (inputs, targets) in enumerate(xloader):
|
||||||
|
targets = targets.cuda(non_blocking=True)
|
||||||
|
inputs = inputs.cuda(non_blocking=True)
|
||||||
|
data_time.update(time.time() - end)
|
||||||
|
# forward
|
||||||
|
features, logits = network(inputs)
|
||||||
|
loss = criterion(logits, targets)
|
||||||
|
batch_time.update(time.time() - end)
|
||||||
|
if batch is None or batch == inputs.size(0):
|
||||||
|
batch = inputs.size(0)
|
||||||
|
latencies.append(batch_time.val - data_time.val)
|
||||||
|
# record loss and accuracy
|
||||||
|
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||||
|
losses.update(loss.item(), inputs.size(0))
|
||||||
|
top1.update(prec1.item(), inputs.size(0))
|
||||||
|
top5.update(prec5.item(), inputs.size(0))
|
||||||
|
end = time.time()
|
||||||
|
if len(latencies) > 2:
|
||||||
|
latencies = latencies[1:]
|
||||||
|
return losses.avg, top1.avg, top5.avg, latencies
|
||||||
|
|
||||||
|
|
||||||
|
def procedure(xloader, network, criterion, scheduler, optimizer, mode):
|
||||||
|
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||||
|
if mode == "train":
|
||||||
|
network.train()
|
||||||
|
elif mode == "valid":
|
||||||
|
network.eval()
|
||||||
|
else:
|
||||||
|
raise ValueError("The mode is not right : {:}".format(mode))
|
||||||
|
|
||||||
|
data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
|
||||||
|
for i, (inputs, targets) in enumerate(xloader):
|
||||||
|
if mode == "train":
|
||||||
|
scheduler.update(None, 1.0 * i / len(xloader))
|
||||||
|
|
||||||
|
targets = targets.cuda(non_blocking=True)
|
||||||
|
if mode == "train":
|
||||||
|
optimizer.zero_grad()
|
||||||
|
# forward
|
||||||
|
features, logits = network(inputs)
|
||||||
|
loss = criterion(logits, targets)
|
||||||
|
# backward
|
||||||
|
if mode == "train":
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
# record loss and accuracy
|
||||||
|
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||||
|
losses.update(loss.item(), inputs.size(0))
|
||||||
|
top1.update(prec1.item(), inputs.size(0))
|
||||||
|
top5.update(prec5.item(), inputs.size(0))
|
||||||
|
# count time
|
||||||
|
batch_time.update(time.time() - end)
|
||||||
|
end = time.time()
|
||||||
|
return losses.avg, top1.avg, top5.avg, batch_time.sum
|
||||||
|
|
||||||
|
def train_single_model(
|
||||||
|
save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config
|
||||||
|
):
|
||||||
|
assert torch.cuda.is_available(), "CUDA is not available."
|
||||||
|
torch.backends.cudnn.enabled = True
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
# torch.backends.cudnn.benchmark = True
|
||||||
|
torch.set_num_threads(workers)
|
||||||
|
|
||||||
|
save_dir = (
|
||||||
|
Path(save_dir)
|
||||||
|
/ "specifics"
|
||||||
|
/ "{:}-{:}-{:}-{:}".format(
|
||||||
|
"LESS" if use_less else "FULL",
|
||||||
|
model_str,
|
||||||
|
arch_config["channel"],
|
||||||
|
arch_config["num_cells"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger = Logger(str(save_dir), 0, False)
|
||||||
|
print(CellArchitectures)
|
||||||
|
if model_str in CellArchitectures:
|
||||||
|
arch = CellArchitectures[model_str]
|
||||||
|
logger.log(
|
||||||
|
"The model string is found in pre-defined architecture dict : {:}".format(
|
||||||
|
model_str
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
arch = CellStructure.str2structure(model_str)
|
||||||
|
except:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid model string : {:}. It can not be found or parsed.".format(
|
||||||
|
model_str
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert arch.check_valid_op(
|
||||||
|
get_search_spaces("cell", "nas-bench-201")
|
||||||
|
), "{:} has the invalid op.".format(arch)
|
||||||
|
logger.log("Start train-evaluate {:}".format(arch.tostr()))
|
||||||
|
logger.log("arch_config : {:}".format(arch_config))
|
||||||
|
|
||||||
|
start_time, seed_time = time.time(), AverageMeter()
|
||||||
|
for _is, seed in enumerate(seeds):
|
||||||
|
logger.log(
|
||||||
|
"\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------".format(
|
||||||
|
_is, len(seeds), seed
|
||||||
|
)
|
||||||
|
)
|
||||||
|
to_save_name = save_dir / "seed-{:04d}.pth".format(seed)
|
||||||
|
if to_save_name.exists():
|
||||||
|
logger.log(
|
||||||
|
"Find the existing file {:}, directly load!".format(to_save_name)
|
||||||
|
)
|
||||||
|
checkpoint = torch.load(to_save_name)
|
||||||
|
else:
|
||||||
|
logger.log(
|
||||||
|
"Does not find the existing file {:}, train and evaluate!".format(
|
||||||
|
to_save_name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
checkpoint = evaluate_all_datasets(
|
||||||
|
arch,
|
||||||
|
datasets,
|
||||||
|
xpaths,
|
||||||
|
splits,
|
||||||
|
use_less,
|
||||||
|
seed,
|
||||||
|
arch_config,
|
||||||
|
workers,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
torch.save(checkpoint, to_save_name)
|
||||||
|
# log information
|
||||||
|
logger.log("{:}".format(checkpoint["info"]))
|
||||||
|
all_dataset_keys = checkpoint["all_dataset_keys"]
|
||||||
|
for dataset_key in all_dataset_keys:
|
||||||
|
logger.log(
|
||||||
|
"\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15)
|
||||||
|
)
|
||||||
|
dataset_info = checkpoint[dataset_key]
|
||||||
|
# logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
|
||||||
|
logger.log(
|
||||||
|
"Flops = {:} MB, Params = {:} MB".format(
|
||||||
|
dataset_info["flop"], dataset_info["param"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.log("config : {:}".format(dataset_info["config"]))
|
||||||
|
logger.log(
|
||||||
|
"Training State (finish) = {:}".format(dataset_info["finish-train"])
|
||||||
|
)
|
||||||
|
last_epoch = dataset_info["total_epoch"] - 1
|
||||||
|
train_acc1es, train_acc5es = (
|
||||||
|
dataset_info["train_acc1es"],
|
||||||
|
dataset_info["train_acc5es"],
|
||||||
|
)
|
||||||
|
valid_acc1es, valid_acc5es = (
|
||||||
|
dataset_info["valid_acc1es"],
|
||||||
|
dataset_info["valid_acc5es"],
|
||||||
|
)
|
||||||
|
print(dataset_info["train_acc1es"])
|
||||||
|
print(dataset_info["train_acc5es"])
|
||||||
|
print(dataset_info["valid_acc1es"])
|
||||||
|
print(dataset_info["valid_acc5es"])
|
||||||
|
logger.log(
|
||||||
|
"Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%".format(
|
||||||
|
train_acc1es[last_epoch],
|
||||||
|
train_acc5es[last_epoch],
|
||||||
|
100 - train_acc1es[last_epoch],
|
||||||
|
valid_acc1es['ori-test@'+str(last_epoch)],
|
||||||
|
valid_acc5es['ori-test@'+str(last_epoch)],
|
||||||
|
100 - valid_acc1es['ori-test@'+str(last_epoch)],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# measure elapsed time
|
||||||
|
seed_time.update(time.time() - start_time)
|
||||||
|
start_time = time.time()
|
||||||
|
need_time = "Time Left: {:}".format(
|
||||||
|
convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True)
|
||||||
|
)
|
||||||
|
logger.log(
|
||||||
|
"\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}".format(
|
||||||
|
_is, len(seeds), seed, need_time
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.close()
|
||||||
|
|
||||||
|
# |nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|
|
||||||
|
train_strs = [best_cifar_10_str, best_cifar_100_str, best_ImageNet16_str]
|
||||||
|
train_single_model(
|
||||||
|
save_dir="./outputs",
|
||||||
|
workers=8,
|
||||||
|
datasets=["ImageNet16-120"],
|
||||||
|
xpaths="./datasets/imagenet16-120",
|
||||||
|
splits=[0, 0, 0],
|
||||||
|
use_less=False,
|
||||||
|
seeds=[777],
|
||||||
|
model_str=best_ImageNet16_str,
|
||||||
|
arch_config={"channel": 16, "num_cells": 8},)
|
||||||
|
train_single_model(
|
||||||
|
save_dir="./outputs",
|
||||||
|
workers=8,
|
||||||
|
datasets=["cifar10"],
|
||||||
|
xpaths="./datasets/cifar10",
|
||||||
|
splits=[0, 0, 0],
|
||||||
|
use_less=False,
|
||||||
|
seeds=[777],
|
||||||
|
model_str=best_cifar_10_str,
|
||||||
|
arch_config={"channel": 16, "num_cells": 8},)
|
||||||
|
train_single_model(
|
||||||
|
save_dir="./outputs",
|
||||||
|
workers=8,
|
||||||
|
datasets=["cifar100"],
|
||||||
|
xpaths="./datasets/cifar100",
|
||||||
|
splits=[0, 0, 0],
|
||||||
|
use_less=False,
|
||||||
|
seeds=[777],
|
||||||
|
model_str=best_cifar_100_str,
|
||||||
|
arch_config={"channel": 16, "num_cells": 8},)
|
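Beyond re-training, the indices returned by `find_best_index` can also be turned into instantiated networks. A minimal sketch, assuming the `api` object and `best_cifar_10_index` defined in the listing above, and following the usual NAS-Bench-201 pattern of feeding the `get_net_config` output to `get_cell_based_tiny_net`:

```python
from xautodl.models import get_cell_based_tiny_net

# Fetch the stored architecture/network configuration and build the model.
config = api.get_net_config(best_cifar_10_index, 'cifar10')
network = get_cell_based_tiny_net(config)
print(network)
```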
@ -1,40 +1,39 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
|
# Modified by Hayeon Lee, Eunyoung Hyung 2021. 03.
|
||||||
##################################################
|
##################################################
|
||||||
import os, sys, torch
|
import os
|
||||||
|
import sys
|
||||||
|
import torch
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torchvision.datasets as dset
|
import torchvision.datasets as dset
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from xautodl.config_utils import load_config
|
|
||||||
|
|
||||||
from .DownsampledImageNet import ImageNet16
|
|
||||||
from .SearchDatasetWrap import SearchDataset
|
from .SearchDatasetWrap import SearchDataset
|
||||||
|
|
||||||
|
# from PIL import Image
|
||||||
|
import random
|
||||||
|
import pdb
|
||||||
|
from .aircraft import FGVCAircraft
|
||||||
|
from .pets import PetDataset
|
||||||
|
from config_utils import load_config
|
||||||
|
|
||||||
Dataset2Class = {
|
Dataset2Class = {'cifar10': 10,
|
||||||
"cifar10": 10,
|
'cifar100': 100,
|
||||||
"cifar100": 100,
|
'mnist': 10,
|
||||||
"imagenet-1k-s": 1000,
|
'svhn': 10,
|
||||||
"imagenet-1k": 1000,
|
'aircraft': 30,
|
||||||
"ImageNet16": 1000,
|
'oxford': 37}
|
||||||
"ImageNet16-150": 150,
|
|
||||||
"ImageNet16-120": 120,
|
|
||||||
"ImageNet16-200": 200,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class CUTOUT(object):
|
class CUTOUT(object):
|
||||||
|
|
||||||
def __init__(self, length):
|
def __init__(self, length):
|
||||||
self.length = length
|
self.length = length
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "{name}(length={length})".format(
|
return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||||
name=self.__class__.__name__, **self.__dict__
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, img):
|
def __call__(self, img):
|
||||||
h, w = img.size(1), img.size(2)
|
h, w = img.size(1), img.size(2)
|
||||||
@ -47,7 +46,7 @@ class CUTOUT(object):
|
|||||||
x1 = np.clip(x - self.length // 2, 0, w)
|
x1 = np.clip(x - self.length // 2, 0, w)
|
||||||
x2 = np.clip(x + self.length // 2, 0, w)
|
x2 = np.clip(x + self.length // 2, 0, w)
|
||||||
|
|
||||||
mask[y1:y2, x1:x2] = 0.0
|
mask[y1: y2, x1: x2] = 0.
|
||||||
mask = torch.from_numpy(mask)
|
mask = torch.from_numpy(mask)
|
||||||
mask = mask.expand_as(img)
|
mask = mask.expand_as(img)
|
||||||
img *= mask
|
img *= mask
|
||||||
@ -55,21 +54,19 @@ class CUTOUT(object):
|
|||||||
|
|
||||||
|
|
||||||
imagenet_pca = {
|
imagenet_pca = {
|
||||||
"eigval": np.asarray([0.2175, 0.0188, 0.0045]),
|
'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
|
||||||
"eigvec": np.asarray(
|
'eigvec': np.asarray([
|
||||||
[
|
[-0.5675, 0.7192, 0.4009],
|
||||||
[-0.5675, 0.7192, 0.4009],
|
[-0.5808, -0.0045, -0.8140],
|
||||||
[-0.5808, -0.0045, -0.8140],
|
[-0.5836, -0.6948, 0.4203],
|
||||||
[-0.5836, -0.6948, 0.4203],
|
])
|
||||||
]
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class Lighting(object):
|
class Lighting(object):
|
||||||
def __init__(
|
def __init__(self, alphastd,
|
||||||
self, alphastd, eigval=imagenet_pca["eigval"], eigvec=imagenet_pca["eigvec"]
|
eigval=imagenet_pca['eigval'],
|
||||||
):
|
eigvec=imagenet_pca['eigvec']):
|
||||||
self.alphastd = alphastd
|
self.alphastd = alphastd
|
||||||
assert eigval.shape == (3,)
|
assert eigval.shape == (3,)
|
||||||
assert eigvec.shape == (3, 3)
|
assert eigvec.shape == (3, 3)
|
||||||
@ -77,10 +74,10 @@ class Lighting(object):
|
|||||||
self.eigvec = eigvec
|
self.eigvec = eigvec
|
||||||
|
|
||||||
def __call__(self, img):
|
def __call__(self, img):
|
||||||
if self.alphastd == 0.0:
|
if self.alphastd == 0.:
|
||||||
return img
|
return img
|
||||||
rnd = np.random.randn(3) * self.alphastd
|
rnd = np.random.randn(3) * self.alphastd
|
||||||
rnd = rnd.astype("float32")
|
rnd = rnd.astype('float32')
|
||||||
v = rnd
|
v = rnd
|
||||||
old_dtype = np.asarray(img).dtype
|
old_dtype = np.asarray(img).dtype
|
||||||
v = v * self.eigval
|
v = v * self.eigval
|
||||||
@ -89,275 +86,222 @@ class Lighting(object):
|
|||||||
img = np.add(img, inc)
|
img = np.add(img, inc)
|
||||||
if old_dtype == np.uint8:
|
if old_dtype == np.uint8:
|
||||||
img = np.clip(img, 0, 255)
|
img = np.clip(img, 0, 255)
|
||||||
img = Image.fromarray(img.astype(old_dtype), "RGB")
|
img = Image.fromarray(img.astype(old_dtype), 'RGB')
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return self.__class__.__name__ + "()"
|
return self.__class__.__name__ + '()'
|
||||||
|
|
||||||
|
|
||||||
def get_datasets(name, root, cutout, use_num_cls=None):
    # handled here: cifar10 / cifar100 / mnist / svhn / aircraft / oxford (plus a
    # cub200 transform branch); the upstream imagenet-1k / ImageNet16 branches are not.

    if name == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif name == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif name.startswith('mnist'):
        mean, std = [0.1307, 0.1307, 0.1307], [0.3081, 0.3081, 0.3081]
    elif name.startswith('svhn'):
        mean, std = [0.4376821, 0.4437697, 0.47280442], [0.19803012, 0.20101562, 0.19703614]
    elif name.startswith('aircraft'):
        mean = [0.48933587508932375, 0.5183537408957618, 0.5387914411673883]
        std = [0.22388883112804625, 0.21641635409388751, 0.24615605842636115]
    elif name.startswith('oxford'):
        mean = [0.4828895122298728, 0.4448394893850807, 0.39566558230789783]
        std = [0.25925664613996574, 0.2532760018681693, 0.25981017205097917]
    else:
        raise TypeError("Unknow dataset : {:}".format(name))

    # Data Argumentation
    if name == 'cifar10' or name == 'cifar100':
        lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
                 transforms.Normalize(mean, std)]
        if cutout > 0:
            lists += [CUTOUT(cutout)]
        train_transform = transforms.Compose(lists)
        test_transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean, std)])
        xshape = (1, 3, 32, 32)
    elif name.startswith('cub200'):
        train_transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])
        test_transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])
        xshape = (1, 3, 32, 32)
    elif name.startswith('mnist'):
        train_transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
            transforms.Normalize(mean, std),
        ])
        test_transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
            transforms.Normalize(mean, std)
        ])
        xshape = (1, 3, 32, 32)
    elif name.startswith('svhn'):
        train_transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])
        test_transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])
        xshape = (1, 3, 32, 32)
    elif name.startswith('aircraft'):
        train_transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])
        test_transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        xshape = (1, 3, 32, 32)
    elif name.startswith('oxford'):
        train_transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])
        test_transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        xshape = (1, 3, 32, 32)
    else:
        raise TypeError("Unknow dataset : {:}".format(name))

    if name == 'cifar10':
        train_data = dset.CIFAR10(
            root, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR10(
            root, train=False, transform=test_transform, download=True)
        assert len(train_data) == 50000 and len(test_data) == 10000
    elif name == 'cifar100':
        train_data = dset.CIFAR100(
            root, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR100(
            root, train=False, transform=test_transform, download=True)
        assert len(train_data) == 50000 and len(test_data) == 10000
    elif name == 'mnist':
        train_data = dset.MNIST(
            root, train=True, transform=train_transform, download=True)
        test_data = dset.MNIST(
            root, train=False, transform=test_transform, download=True)
        assert len(train_data) == 60000 and len(test_data) == 10000
    elif name == 'svhn':
        train_data = dset.SVHN(root, split='train',
                               transform=train_transform, download=True)
        test_data = dset.SVHN(root, split='test',
                              transform=test_transform, download=True)
        assert len(train_data) == 73257 and len(test_data) == 26032
    elif name == 'aircraft':
        train_data = FGVCAircraft(root, class_type='manufacturer', split='trainval',
                                  transform=train_transform, download=False)
        test_data = FGVCAircraft(root, class_type='manufacturer', split='test',
                                 transform=test_transform, download=False)
        assert len(train_data) == 6667 and len(test_data) == 3333
    elif name == 'oxford':
        train_data = PetDataset(root, train=True, num_cl=37,
                                val_split=0.15, transforms=train_transform)
        test_data = PetDataset(root, train=False, num_cl=37,
                               val_split=0.15, transforms=test_transform)
    else:
        raise TypeError("Unknow dataset : {:}".format(name))

    class_num = Dataset2Class[name] if use_num_cls is None else len(use_num_cls)
    return train_data, test_data, xshape, class_num

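# Hedged usage sketch (added for illustration, not part of the diff): the root path
# below is a placeholder, and `use_num_cls` is assumed to be any sized collection of
# class indices to keep, so the returned `class_num` becomes its length instead of
# the Dataset2Class default.
def _demo_get_datasets():
    train_data, test_data, xshape, class_num = get_datasets(
        'mnist', root='./data', cutout=-1, use_num_cls=list(range(5)))
    print(len(train_data), len(test_data), xshape, class_num)  # class_num == 5
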
def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers, num_cls=None):
    if isinstance(batch_size, (list, tuple)):
        batch, test_batch = batch_size
    else:
        batch, test_batch = batch_size, batch_size
    if dataset == 'cifar10':
        # split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
        cifar_split = load_config(
            '{:}/cifar-split.txt'.format(config_root), None, None)
        # search over the proposed training and validation set
        train_split, valid_split = cifar_split.train, cifar_split.valid
        # logger.log('Load split file from {:}'.format(split_Fpath)) # they are two disjoint groups in the original CIFAR-10 training set
        # To split data
        xvalid_data = deepcopy(train_data)
        if hasattr(xvalid_data, 'transforms'):  # to avoid a print issue
            xvalid_data.transforms = valid_data.transform
        xvalid_data.transform = deepcopy(valid_data.transform)
        search_data = SearchDataset(
            dataset, train_data, train_split, valid_split)
        # data loader
        search_loader = torch.utils.data.DataLoader(
            search_data, batch_size=batch, shuffle=True,
            num_workers=workers, pin_memory=True)
        train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=batch,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
            num_workers=workers, pin_memory=True)
        valid_loader = torch.utils.data.DataLoader(
            xvalid_data, batch_size=test_batch,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
            num_workers=workers, pin_memory=True)
    elif dataset == 'cifar100':
        cifar100_test_split = load_config(
            '{:}/cifar100-test-split.txt'.format(config_root), None, None)
        search_train_data = train_data
        search_valid_data = deepcopy(valid_data)
        search_valid_data.transform = train_data.transform
        search_data = SearchDataset(
            dataset, [search_train_data, search_valid_data],
            list(range(len(search_train_data))), cifar100_test_split.xvalid)
        search_loader = torch.utils.data.DataLoader(
            search_data, batch_size=batch, shuffle=True,
            num_workers=workers, pin_memory=True)
        train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=batch, shuffle=True,
            num_workers=workers, pin_memory=True)
        valid_loader = torch.utils.data.DataLoader(
            valid_data, batch_size=test_batch,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_test_split.xvalid),
            num_workers=workers, pin_memory=True)
    elif dataset in ['mnist', 'svhn', 'aircraft', 'oxford']:
        # a random half/half valid-test split is generated once and cached on disk
        if not os.path.exists('{:}/{}-test-split.txt'.format(config_root, dataset)):
            import json
            label_list = list(range(len(valid_data)))
            random.shuffle(label_list)
            strlist = [str(label_list[i]) for i in range(len(label_list))]
            split = {'xvalid': ["int", strlist[:len(valid_data) // 2]],
                     'xtest': ["int", strlist[len(valid_data) // 2:]]}
            with open('{:}/{}-test-split.txt'.format(config_root, dataset), 'w') as f:
                f.write(json.dumps(split))
        test_split = load_config(
            '{:}/{}-test-split.txt'.format(config_root, dataset), None, None)

        search_train_data = train_data
        search_valid_data = deepcopy(valid_data)
        search_valid_data.transform = train_data.transform
        search_data = SearchDataset(
            dataset, [search_train_data, search_valid_data],
            list(range(len(search_train_data))), test_split.xvalid)
        search_loader = torch.utils.data.DataLoader(
            search_data, batch_size=batch, shuffle=True,
            num_workers=workers, pin_memory=True)
        train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=batch, shuffle=True,
            num_workers=workers, pin_memory=True)
        valid_loader = torch.utils.data.DataLoader(
            valid_data, batch_size=test_batch,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(test_split.xvalid),
            num_workers=workers, pin_memory=True)
    # (no ImageNet16-120 branch in this version)
    else:
        raise ValueError('invalid dataset : {:}'.format(dataset))
    return search_loader, train_loader, valid_loader

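# Hedged usage sketch (added for illustration, not part of the diff): wire the two
# functions above together for one of the newly supported datasets. The config_root
# path is a placeholder; for 'mnist' / 'svhn' / 'aircraft' / 'oxford' a
# '{dataset}-test-split.txt' file is written there on first use, as implemented above.
def _demo_search_loaders():
    train_data, valid_data, xshape, class_num = get_datasets(
        'svhn', root='./data', cutout=-1)
    search_loader, train_loader, valid_loader = get_nas_search_loaders(
        train_data, valid_data, 'svhn',
        config_root='./configs/nas-benchmark', batch_size=(64, 128), workers=4)
    print(len(search_loader), len(train_loader), len(valid_loader))
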
@ -213,6 +213,13 @@ AllConv3x3_CODE = Structure(
        (("nor_conv_3x3", 0), ("nor_conv_3x3", 1), ("nor_conv_3x3", 2)),
    ]  # node-3
)

Number_5374 = Structure(
    [
        (("nor_conv_3x3", 0),),  # node-1
        (("nor_conv_1x1", 0), ("nor_conv_3x3", 1)),  # node-2
        (("skip_connect", 0), ("none", 1), ("nor_conv_3x3", 2)),  # node-3
    ]
)
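# Hedged sketch (added for illustration, not part of the diff): each inner list of the
# genotype above is one intermediate node of the cell and every (op, j) pair consumes
# the output of node j. Assuming the surrounding Structure class offers the usual
# NAS-Bench-201 `tostr()` serialization, the new genotype corresponds to the arch
# string checked below (presumably the architecture indexed 5374 in the benchmark).
def _demo_number_5374():
    expected = ('|nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+'
                '|skip_connect~0|none~1|nor_conv_3x3~2|')
    arch_str = Number_5374.tostr()
    print(arch_str, arch_str == expected)
    return arch_str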

AllFull_CODE = Structure(
    [
@ -271,4 +278,5 @@ architectures = {
    "all_c1x1": AllConv1x1_CODE,
    "all_idnt": AllIdentity_CODE,
    "all_full": AllFull_CODE,
    "5374": Number_5374,
}
@ -12,6 +12,7 @@ def obtain_accuracy(output, target, topk=(1,)):

    res = []
    for k in topk:
        # correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
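# Hedged illustration (added, not part of the diff) of why .view(-1) was swapped for
# .reshape(-1) above: `view` requires a contiguous tensor, and the `correct` matrix
# built from topk()/t()/eq() earlier in obtain_accuracy may be non-contiguous on
# recent PyTorch, so .view(-1) can raise a RuntimeError while .reshape(-1) copies
# when necessary and always succeeds.
def _demo_view_vs_reshape():
    import torch
    correct = torch.ones(5, 3).t()    # transposed, hence non-contiguous
    flat = correct.reshape(-1)        # fine
    try:
        correct.view(-1)              # raises on a non-contiguous tensor
    except RuntimeError as err:
        print('view failed as expected:', err)
    return flat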