Compare commits
4 Commits
50ff507a15
...
889bd1974c
Author | SHA1 | Date | |
---|---|---|---|
889bd1974c | |||
af0e7786b6 | |||
c6d53f08ae | |||
ef2608bb42 |
@ -2,11 +2,11 @@
|
|||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||||
#####################################################
|
#####################################################
|
||||||
import time, torch
|
import time, torch
|
||||||
from procedures import prepare_seed, get_optim_scheduler
|
from xautodl.procedures import prepare_seed, get_optim_scheduler
|
||||||
from utils import get_model_infos, obtain_accuracy
|
from xautodl.utils import get_model_infos, obtain_accuracy
|
||||||
from config_utils import dict2config
|
from xautodl.config_utils import dict2config
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from models import get_cell_based_tiny_net
|
from xautodl.models import get_cell_based_tiny_net
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["evaluate_for_seed", "pure_evaluate"]
|
__all__ = ["evaluate_for_seed", "pure_evaluate"]
|
||||||
|
@ -16,8 +16,9 @@ from xautodl.procedures import get_machine_info
|
|||||||
from xautodl.datasets import get_datasets
|
from xautodl.datasets import get_datasets
|
||||||
from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
|
from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
|
||||||
from xautodl.models import CellStructure, CellArchitectures, get_search_spaces
|
from xautodl.models import CellStructure, CellArchitectures, get_search_spaces
|
||||||
from xautodl.functions import evaluate_for_seed
|
from functions import evaluate_for_seed
|
||||||
|
|
||||||
|
from torchvision import datasets, transforms
|
||||||
|
|
||||||
def evaluate_all_datasets(
|
def evaluate_all_datasets(
|
||||||
arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger
|
arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger
|
||||||
@ -46,47 +47,85 @@ def evaluate_all_datasets(
|
|||||||
split_info = load_config(
|
split_info = load_config(
|
||||||
"configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
|
"configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
|
||||||
)
|
)
|
||||||
|
elif dataset.startswith("aircraft"):
|
||||||
|
if use_less:
|
||||||
|
config_path = "configs/nas-benchmark/LESS.config"
|
||||||
|
else:
|
||||||
|
config_path = "configs/nas-benchmark/aircraft.config"
|
||||||
|
split_info = load_config(
|
||||||
|
"configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("invalid dataset : {:}".format(dataset))
|
raise ValueError("invalid dataset : {:}".format(dataset))
|
||||||
config = load_config(
|
config = load_config(
|
||||||
config_path, {"class_num": class_num, "xshape": xshape}, logger
|
config_path, {"class_num": class_num, "xshape": xshape}, logger
|
||||||
)
|
)
|
||||||
# check whether use splited validation set
|
# check whether use splited validation set
|
||||||
|
# if dataset == 'aircraft':
|
||||||
|
# split = True
|
||||||
if bool(split):
|
if bool(split):
|
||||||
assert dataset == "cifar10"
|
if dataset == "cifar10" or dataset == "cifar100":
|
||||||
ValLoaders = {
|
assert dataset == "cifar10"
|
||||||
"ori-test": torch.utils.data.DataLoader(
|
ValLoaders = {
|
||||||
valid_data,
|
"ori-test": torch.utils.data.DataLoader(
|
||||||
|
valid_data,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=workers,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
assert len(train_data) == len(split_info.train) + len(
|
||||||
|
split_info.valid
|
||||||
|
), "invalid length : {:} vs {:} + {:}".format(
|
||||||
|
len(train_data), len(split_info.train), len(split_info.valid)
|
||||||
|
)
|
||||||
|
train_data_v2 = deepcopy(train_data)
|
||||||
|
train_data_v2.transform = valid_data.transform
|
||||||
|
valid_data = train_data_v2
|
||||||
|
# data loader
|
||||||
|
train_loader = torch.utils.data.DataLoader(
|
||||||
|
train_data,
|
||||||
batch_size=config.batch_size,
|
batch_size=config.batch_size,
|
||||||
shuffle=False,
|
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
|
||||||
num_workers=workers,
|
num_workers=workers,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
)
|
)
|
||||||
}
|
valid_loader = torch.utils.data.DataLoader(
|
||||||
assert len(train_data) == len(split_info.train) + len(
|
valid_data,
|
||||||
split_info.valid
|
batch_size=config.batch_size,
|
||||||
), "invalid length : {:} vs {:} + {:}".format(
|
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
|
||||||
len(train_data), len(split_info.train), len(split_info.valid)
|
num_workers=workers,
|
||||||
)
|
pin_memory=True,
|
||||||
train_data_v2 = deepcopy(train_data)
|
)
|
||||||
train_data_v2.transform = valid_data.transform
|
ValLoaders["x-valid"] = valid_loader
|
||||||
valid_data = train_data_v2
|
elif dataset == "aircraft":
|
||||||
# data loader
|
ValLoaders = {
|
||||||
train_loader = torch.utils.data.DataLoader(
|
"ori-test": torch.utils.data.DataLoader(
|
||||||
train_data,
|
valid_data,
|
||||||
batch_size=config.batch_size,
|
batch_size=config.batch_size,
|
||||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
|
shuffle=False,
|
||||||
num_workers=workers,
|
num_workers=workers,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
)
|
)
|
||||||
valid_loader = torch.utils.data.DataLoader(
|
}
|
||||||
valid_data,
|
train_data_v2 = deepcopy(train_data)
|
||||||
batch_size=config.batch_size,
|
train_data_v2.transform = valid_data.transform
|
||||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
|
valid_data = train_data_v2
|
||||||
num_workers=workers,
|
# 使用 DataLoader
|
||||||
pin_memory=True,
|
train_loader = torch.utils.data.DataLoader(
|
||||||
)
|
train_data,
|
||||||
ValLoaders["x-valid"] = valid_loader
|
batch_size=config.batch_size,
|
||||||
|
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
|
||||||
|
num_workers=workers,
|
||||||
|
pin_memory=True)
|
||||||
|
valid_loader = torch.utils.data.DataLoader(
|
||||||
|
valid_data,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
|
||||||
|
num_workers=workers,
|
||||||
|
pin_memory=True)
|
||||||
else:
|
else:
|
||||||
# data loader
|
# data loader
|
||||||
train_loader = torch.utils.data.DataLoader(
|
train_loader = torch.utils.data.DataLoader(
|
||||||
@ -103,7 +142,7 @@ def evaluate_all_datasets(
|
|||||||
num_workers=workers,
|
num_workers=workers,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
)
|
)
|
||||||
if dataset == "cifar10":
|
if dataset == "cifar10" or dataset == "aircraft":
|
||||||
ValLoaders = {"ori-test": valid_loader}
|
ValLoaders = {"ori-test": valid_loader}
|
||||||
elif dataset == "cifar100":
|
elif dataset == "cifar100":
|
||||||
cifar100_splits = load_config(
|
cifar100_splits = load_config(
|
||||||
|
@ -28,16 +28,30 @@ else
|
|||||||
mode=cover
|
mode=cover
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
|
||||||
|
# --mode ${mode} --save_dir ${save_dir} --max_node 4 \
|
||||||
|
# --use_less ${use_less} \
|
||||||
|
# --datasets cifar10 cifar10 cifar100 ImageNet16-120 \
|
||||||
|
# --splits 1 0 0 0 \
|
||||||
|
# --xpaths $TORCH_HOME/cifar.python \
|
||||||
|
# $TORCH_HOME/cifar.python \
|
||||||
|
# $TORCH_HOME/cifar.python \
|
||||||
|
# $TORCH_HOME/cifar.python/ImageNet16 \
|
||||||
|
# --channel 16 --num_cells 5 \
|
||||||
|
# --workers 4 \
|
||||||
|
# --srange ${xstart} ${xend} --arch_index ${arch_index} \
|
||||||
|
# --seeds ${all_seeds}
|
||||||
|
|
||||||
OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
|
OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
|
||||||
--mode ${mode} --save_dir ${save_dir} --max_node 4 \
|
--mode ${mode} --save_dir ${save_dir} --max_node 4 \
|
||||||
--use_less ${use_less} \
|
--use_less ${use_less} \
|
||||||
--datasets cifar10 cifar10 cifar100 ImageNet16-120 \
|
--datasets aircraft \
|
||||||
--splits 1 0 0 0 \
|
--xpaths /lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/fgvc-aircraft-2013b/data/ \
|
||||||
--xpaths $TORCH_HOME/cifar.python \
|
--channel 16 \
|
||||||
$TORCH_HOME/cifar.python \
|
--splits 1 \
|
||||||
$TORCH_HOME/cifar.python \
|
--num_cells 5 \
|
||||||
$TORCH_HOME/cifar.python/ImageNet16 \
|
|
||||||
--channel 16 --num_cells 5 \
|
|
||||||
--workers 4 \
|
--workers 4 \
|
||||||
--srange ${xstart} ${xend} --arch_index ${arch_index} \
|
--srange ${xstart} ${xend} --arch_index ${arch_index} \
|
||||||
--seeds ${all_seeds}
|
--seeds ${all_seeds}
|
||||||
|
|
||||||
|
|
||||||
|
9477
test.ipynb
9477
test.ipynb
File diff suppressed because it is too large
Load Diff
@ -24,6 +24,8 @@ Dataset2Class = {
|
|||||||
"ImageNet16-150": 150,
|
"ImageNet16-150": 150,
|
||||||
"ImageNet16-120": 120,
|
"ImageNet16-120": 120,
|
||||||
"ImageNet16-200": 200,
|
"ImageNet16-200": 200,
|
||||||
|
"aircraft": 100,
|
||||||
|
"oxford": 102
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -109,6 +111,12 @@ def get_datasets(name, root, cutout):
|
|||||||
elif name.startswith("ImageNet16"):
|
elif name.startswith("ImageNet16"):
|
||||||
mean = [x / 255 for x in [122.68, 116.66, 104.01]]
|
mean = [x / 255 for x in [122.68, 116.66, 104.01]]
|
||||||
std = [x / 255 for x in [63.22, 61.26, 65.09]]
|
std = [x / 255 for x in [63.22, 61.26, 65.09]]
|
||||||
|
elif name == 'aircraft':
|
||||||
|
mean = [0.4785, 0.5100, 0.5338]
|
||||||
|
std = [0.1845, 0.1830, 0.2060]
|
||||||
|
elif name == 'oxford':
|
||||||
|
mean = [0.4811, 0.4492, 0.3957]
|
||||||
|
std = [0.2260, 0.2231, 0.2249]
|
||||||
else:
|
else:
|
||||||
raise TypeError("Unknow dataset : {:}".format(name))
|
raise TypeError("Unknow dataset : {:}".format(name))
|
||||||
|
|
||||||
@ -127,6 +135,13 @@ def get_datasets(name, root, cutout):
|
|||||||
[transforms.ToTensor(), transforms.Normalize(mean, std)]
|
[transforms.ToTensor(), transforms.Normalize(mean, std)]
|
||||||
)
|
)
|
||||||
xshape = (1, 3, 32, 32)
|
xshape = (1, 3, 32, 32)
|
||||||
|
elif name.startswith("aircraft") or name.startswith("oxford"):
|
||||||
|
lists = [transforms.RandomCrop(16, padding=0), transforms.ToTensor(), transforms.Normalize(mean, std)]
|
||||||
|
if cutout > 0:
|
||||||
|
lists += [CUTOUT(cutout)]
|
||||||
|
train_transform = transforms.Compose(lists)
|
||||||
|
test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||||
|
xshape = (1, 3, 16, 16)
|
||||||
elif name.startswith("ImageNet16"):
|
elif name.startswith("ImageNet16"):
|
||||||
lists = [
|
lists = [
|
||||||
transforms.RandomHorizontalFlip(),
|
transforms.RandomHorizontalFlip(),
|
||||||
@ -207,6 +222,10 @@ def get_datasets(name, root, cutout):
|
|||||||
root, train=False, transform=test_transform, download=True
|
root, train=False, transform=test_transform, download=True
|
||||||
)
|
)
|
||||||
assert len(train_data) == 50000 and len(test_data) == 10000
|
assert len(train_data) == 50000 and len(test_data) == 10000
|
||||||
|
elif name == "aircraft":
|
||||||
|
train_data = dset.ImageFolder(root='/lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/fgvc-aircraft-2013b/data/train_sorted_image', transform=train_transform)
|
||||||
|
test_data = dset.ImageFolder(root='/lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/fgvc-aircraft-2013b/data/train_sorted_image', transform=test_transform)
|
||||||
|
|
||||||
elif name.startswith("imagenet-1k"):
|
elif name.startswith("imagenet-1k"):
|
||||||
train_data = dset.ImageFolder(osp.join(root, "train"), train_transform)
|
train_data = dset.ImageFolder(osp.join(root, "train"), train_transform)
|
||||||
test_data = dset.ImageFolder(osp.join(root, "val"), test_transform)
|
test_data = dset.ImageFolder(osp.join(root, "val"), test_transform)
|
||||||
|
Loading…
Reference in New Issue
Block a user