Compare commits

...

4 Commits

SHA1        Message                 Date
889bd1974c  merged                  2024-10-14 23:24:24 +02:00
af0e7786b6  just play around        2024-10-14 23:20:28 +02:00
c6d53f08ae  can train aircraft now  2024-10-14 23:19:49 +02:00
ef2608bb42  can train aircraft now  2024-10-14 23:19:28 +02:00
5 changed files with 9400 additions and 237 deletions

View File

@@ -2,11 +2,11 @@
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
 #####################################################
 import time, torch
-from procedures import prepare_seed, get_optim_scheduler
-from utils import get_model_infos, obtain_accuracy
-from config_utils import dict2config
-from log_utils import AverageMeter, time_string, convert_secs2time
-from models import get_cell_based_tiny_net
+from xautodl.procedures import prepare_seed, get_optim_scheduler
+from xautodl.utils import get_model_infos, obtain_accuracy
+from xautodl.config_utils import dict2config
+from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
+from xautodl.models import get_cell_based_tiny_net
 __all__ = ["evaluate_for_seed", "pure_evaluate"]
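The hunk above moves every flat import under the installed xautodl package. A minimal sketch of a compatibility shim, in case both layouts need to work during the transition; the try/except fallback is an assumption for illustration, not part of the commit:

    # Hedged sketch: tolerate both the old flat layout and the new
    # xautodl package layout (this fallback is not in the commit).
    try:
        from xautodl.procedures import prepare_seed, get_optim_scheduler
    except ImportError:
        from procedures import prepare_seed, get_optim_scheduler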

View File

@@ -16,8 +16,9 @@ from xautodl.procedures import get_machine_info
 from xautodl.datasets import get_datasets
 from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
 from xautodl.models import CellStructure, CellArchitectures, get_search_spaces
-from xautodl.functions import evaluate_for_seed
+from functions import evaluate_for_seed
+from torchvision import datasets, transforms

 def evaluate_all_datasets(
     arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger
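Note that this hunk goes the opposite way from the first file: evaluate_for_seed is now imported from a local functions module rather than from the installed xautodl package, so that module must be on the import path when the script runs. A hedged sketch of one way to guarantee that; the placement next to main.py is an assumption based on the paths in this PR:

    import os, sys

    # Hedged sketch: make the sibling "functions" module importable
    # regardless of the current working directory.
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from functions import evaluate_for_seed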
@@ -46,47 +47,85 @@ def evaluate_all_datasets(
             split_info = load_config(
                 "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
             )
+        elif dataset.startswith("aircraft"):
+            if use_less:
+                config_path = "configs/nas-benchmark/LESS.config"
+            else:
+                config_path = "configs/nas-benchmark/aircraft.config"
+            split_info = load_config(
+                "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
+            )
         else:
             raise ValueError("invalid dataset : {:}".format(dataset))
         config = load_config(
             config_path, {"class_num": class_num, "xshape": xshape}, logger
         )
         # check whether use splited validation set
+        # if dataset == 'aircraft':
+        #     split = True
         if bool(split):
-            assert dataset == "cifar10"
-            ValLoaders = {
-                "ori-test": torch.utils.data.DataLoader(
-                    valid_data,
-                    batch_size=config.batch_size,
-                    shuffle=False,
-                    num_workers=workers,
-                    pin_memory=True,
-                )
-            }
-            assert len(train_data) == len(split_info.train) + len(
-                split_info.valid
-            ), "invalid length : {:} vs {:} + {:}".format(
-                len(train_data), len(split_info.train), len(split_info.valid)
-            )
-            train_data_v2 = deepcopy(train_data)
-            train_data_v2.transform = valid_data.transform
-            valid_data = train_data_v2
-            # data loader
-            train_loader = torch.utils.data.DataLoader(
-                train_data,
-                batch_size=config.batch_size,
-                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
-                num_workers=workers,
-                pin_memory=True,
-            )
-            valid_loader = torch.utils.data.DataLoader(
-                valid_data,
-                batch_size=config.batch_size,
-                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
-                num_workers=workers,
-                pin_memory=True,
-            )
-            ValLoaders["x-valid"] = valid_loader
+            if dataset == "cifar10" or dataset == "cifar100":
+                assert dataset == "cifar10"
+                ValLoaders = {
+                    "ori-test": torch.utils.data.DataLoader(
+                        valid_data,
+                        batch_size=config.batch_size,
+                        shuffle=False,
+                        num_workers=workers,
+                        pin_memory=True,
+                    )
+                }
+                assert len(train_data) == len(split_info.train) + len(
+                    split_info.valid
+                ), "invalid length : {:} vs {:} + {:}".format(
+                    len(train_data), len(split_info.train), len(split_info.valid)
+                )
+                train_data_v2 = deepcopy(train_data)
+                train_data_v2.transform = valid_data.transform
+                valid_data = train_data_v2
+                # data loader
+                train_loader = torch.utils.data.DataLoader(
+                    train_data,
+                    batch_size=config.batch_size,
+                    sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
+                    num_workers=workers,
+                    pin_memory=True,
+                )
+                valid_loader = torch.utils.data.DataLoader(
+                    valid_data,
+                    batch_size=config.batch_size,
+                    sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
+                    num_workers=workers,
+                    pin_memory=True,
+                )
+                ValLoaders["x-valid"] = valid_loader
+            elif dataset == "aircraft":
+                ValLoaders = {
+                    "ori-test": torch.utils.data.DataLoader(
+                        valid_data,
+                        batch_size=config.batch_size,
+                        shuffle=False,
+                        num_workers=workers,
+                        pin_memory=True,
+                    )
+                }
+                train_data_v2 = deepcopy(train_data)
+                train_data_v2.transform = valid_data.transform
+                valid_data = train_data_v2
+                # use DataLoader
+                train_loader = torch.utils.data.DataLoader(
+                    train_data,
+                    batch_size=config.batch_size,
+                    sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
+                    num_workers=workers,
+                    pin_memory=True)
+                valid_loader = torch.utils.data.DataLoader(
+                    valid_data,
+                    batch_size=config.batch_size,
+                    sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
+                    num_workers=workers,
+                    pin_memory=True)
         else:
             # data loader
             train_loader = torch.utils.data.DataLoader(
@@ -103,7 +142,7 @@ def evaluate_all_datasets(
                 num_workers=workers,
                 pin_memory=True,
             )
-            if dataset == "cifar10":
+            if dataset == "cifar10" or dataset == "aircraft":
                 ValLoaders = {"ori-test": valid_loader}
             elif dataset == "cifar100":
                 cifar100_splits = load_config(
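Both split branches above carve one underlying dataset into train/valid loaders with SubsetRandomSampler; note that, unlike the cifar branch, the new aircraft branch builds valid_loader but never registers it under ValLoaders["x-valid"]. A minimal, self-contained sketch of the sampler pattern; the toy tensors and the 80/20 index lists are assumptions for illustration:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.sampler import SubsetRandomSampler

    # Toy stand-ins for train_data and split_info (illustration only).
    data = TensorDataset(torch.randn(100, 3, 16, 16),
                         torch.randint(0, 100, (100,)))
    train_idx, valid_idx = list(range(80)), list(range(80, 100))

    # Same pattern as the diff: one dataset, two disjoint samplers.
    train_loader = DataLoader(data, batch_size=16,
                              sampler=SubsetRandomSampler(train_idx))
    valid_loader = DataLoader(data, batch_size=16,
                              sampler=SubsetRandomSampler(valid_idx))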

View File

@@ -28,16 +28,30 @@ else
   mode=cover
 fi

+# OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
+#   --mode ${mode} --save_dir ${save_dir} --max_node 4 \
+#   --use_less ${use_less} \
+#   --datasets cifar10 cifar10 cifar100 ImageNet16-120 \
+#   --splits 1 0 0 0 \
+#   --xpaths $TORCH_HOME/cifar.python \
+#     $TORCH_HOME/cifar.python \
+#     $TORCH_HOME/cifar.python \
+#     $TORCH_HOME/cifar.python/ImageNet16 \
+#   --channel 16 --num_cells 5 \
+#   --workers 4 \
+#   --srange ${xstart} ${xend} --arch_index ${arch_index} \
+#   --seeds ${all_seeds}
 OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
   --mode ${mode} --save_dir ${save_dir} --max_node 4 \
   --use_less ${use_less} \
-  --datasets cifar10 cifar10 cifar100 ImageNet16-120 \
-  --splits 1 0 0 0 \
-  --xpaths $TORCH_HOME/cifar.python \
-    $TORCH_HOME/cifar.python \
-    $TORCH_HOME/cifar.python \
-    $TORCH_HOME/cifar.python/ImageNet16 \
-  --channel 16 --num_cells 5 \
+  --datasets aircraft \
+  --xpaths /lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/fgvc-aircraft-2013b/data/ \
+  --channel 16 \
+  --splits 1 \
+  --num_cells 5 \
   --workers 4 \
   --srange ${xstart} ${xend} --arch_index ${arch_index} \
   --seeds ${all_seeds}
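The new --xpaths value matches the path that the datasets diff below hardcodes for ImageFolder (the aircraft branch there does not actually read the root argument). A quick hedged sanity check of the expected directory layout; the paths and the 100-class expectation come from this PR, everything else is an assumption:

    import os

    # Path copied from the diff; Dataset2Class maps "aircraft" to 100.
    root = "/lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/fgvc-aircraft-2013b/data"
    train_dir = os.path.join(root, "train_sorted_image")
    classes = [d for d in os.listdir(train_dir)
               if os.path.isdir(os.path.join(train_dir, d))]
    assert len(classes) == 100, "expected 100 aircraft classes, got {:}".format(len(classes))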

test.ipynb: 9477 changed lines

File diff suppressed because it is too large

View File

@@ -24,6 +24,8 @@ Dataset2Class = {
     "ImageNet16-150": 150,
     "ImageNet16-120": 120,
     "ImageNet16-200": 200,
+    "aircraft": 100,
+    "oxford": 102
 }
@@ -109,6 +111,12 @@ def get_datasets(name, root, cutout):
     elif name.startswith("ImageNet16"):
         mean = [x / 255 for x in [122.68, 116.66, 104.01]]
         std = [x / 255 for x in [63.22, 61.26, 65.09]]
+    elif name == 'aircraft':
+        mean = [0.4785, 0.5100, 0.5338]
+        std = [0.1845, 0.1830, 0.2060]
+    elif name == 'oxford':
+        mean = [0.4811, 0.4492, 0.3957]
+        std = [0.2260, 0.2231, 0.2249]
     else:
         raise TypeError("Unknow dataset : {:}".format(name))
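For reference, per-channel statistics like the aircraft and oxford values above are typically computed once over the training images. A hedged sketch of the usual recipe; the dataset argument stands for any torchvision-style dataset yielding (image, label) pairs in [0, 1], which is an assumption here:

    import torch
    from torch.utils.data import DataLoader

    def channel_stats(dataset, batch_size=256):
        """Estimate per-channel mean/std as E[x] and sqrt(E[x^2] - E[x]^2)."""
        loader = DataLoader(dataset, batch_size=batch_size)
        count, mean, sq_mean = 0, torch.zeros(3), torch.zeros(3)
        for images, _ in loader:  # images: (B, 3, H, W)
            mean += images.mean(dim=(2, 3)).sum(dim=0)
            sq_mean += images.pow(2).mean(dim=(2, 3)).sum(dim=0)
            count += images.size(0)
        mean /= count
        std = (sq_mean / count - mean.pow(2)).sqrt()
        return mean, std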
@@ -127,6 +135,13 @@ def get_datasets(name, root, cutout):
             [transforms.ToTensor(), transforms.Normalize(mean, std)]
         )
         xshape = (1, 3, 32, 32)
+    elif name.startswith("aircraft") or name.startswith("oxford"):
+        lists = [transforms.RandomCrop(16, padding=0), transforms.ToTensor(), transforms.Normalize(mean, std)]
+        if cutout > 0:
+            lists += [CUTOUT(cutout)]
+        train_transform = transforms.Compose(lists)
+        test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
+        xshape = (1, 3, 16, 16)
     elif name.startswith("ImageNet16"):
         lists = [
             transforms.RandomHorizontalFlip(),
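One thing worth flagging in the hunk above: the train pipeline crops aircraft/oxford images to 16x16 (matching the declared xshape), while the test pipeline resizes to 224x224, so train and test tensors have different spatial sizes. A short sketch that makes the mismatch visible; the blank 300x300 image is a stand-in for a real photo:

    from PIL import Image
    from torchvision import transforms

    mean = [0.4785, 0.5100, 0.5338]  # aircraft stats from the diff above
    std = [0.1845, 0.1830, 0.2060]
    img = Image.new("RGB", (300, 300))  # stand-in image (assumption)

    train_t = transforms.Compose([transforms.RandomCrop(16, padding=0),
                                  transforms.ToTensor(),
                                  transforms.Normalize(mean, std)])
    test_t = transforms.Compose([transforms.Resize((224, 224)),
                                 transforms.ToTensor(),
                                 transforms.Normalize(mean, std)])

    print(train_t(img).shape)  # torch.Size([3, 16, 16])  -> matches xshape
    print(test_t(img).shape)   # torch.Size([3, 224, 224]) -> does not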
@@ -207,6 +222,10 @@ def get_datasets(name, root, cutout):
             root, train=False, transform=test_transform, download=True
         )
         assert len(train_data) == 50000 and len(test_data) == 10000
+    elif name == "aircraft":
+        train_data = dset.ImageFolder(root='/lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/fgvc-aircraft-2013b/data/train_sorted_image', transform=train_transform)
+        test_data = dset.ImageFolder(root='/lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/fgvc-aircraft-2013b/data/train_sorted_image', transform=test_transform)
     elif name.startswith("imagenet-1k"):
         train_data = dset.ImageFolder(osp.join(root, "train"), train_transform)
         test_data = dset.ImageFolder(osp.join(root, "val"), test_transform)
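Also note that in the hunk above both train_data and test_data read from train_sorted_image, so the "ori-test" numbers are measured on the training images themselves. A hedged sketch of the presumably intended wiring, assuming a sibling test_sorted_image directory sorted the same way; that directory name is hypothetical and not part of this diff, and the ToTensor transforms are minimal stand-ins for the real ones built earlier in get_datasets:

    import torchvision.datasets as dset
    import torchvision.transforms as transforms

    # Minimal stand-in transforms so the sketch is self-contained.
    train_transform = transforms.ToTensor()
    test_transform = transforms.ToTensor()

    root = "/lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/fgvc-aircraft-2013b/data"
    train_data = dset.ImageFolder(root + "/train_sorted_image", train_transform)
    # Hypothetical held-out directory (assumption), keeping splits disjoint.
    test_data = dset.ImageFolder(root + "/test_sorted_image", test_transform)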