can train aircraft now
This commit is contained in:
parent
bb33ca9a68
commit
ef2608bb42
@ -2,11 +2,11 @@
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
#####################################################
|
||||
import time, torch
|
||||
from procedures import prepare_seed, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from config_utils import dict2config
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import get_cell_based_tiny_net
|
||||
from xautodl.procedures import prepare_seed, get_optim_scheduler
|
||||
from xautodl.utils import get_model_infos, obtain_accuracy
|
||||
from xautodl.config_utils import dict2config
|
||||
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from xautodl.models import get_cell_based_tiny_net
|
||||
|
||||
|
||||
__all__ = ["evaluate_for_seed", "pure_evaluate"]
|
||||
|
@ -16,8 +16,9 @@ from xautodl.procedures import get_machine_info
|
||||
from xautodl.datasets import get_datasets
|
||||
from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
|
||||
from xautodl.models import CellStructure, CellArchitectures, get_search_spaces
|
||||
from xautodl.functions import evaluate_for_seed
|
||||
from functions import evaluate_for_seed
|
||||
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
def evaluate_all_datasets(
|
||||
arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger
|
||||
@ -46,47 +47,85 @@ def evaluate_all_datasets(
|
||||
split_info = load_config(
|
||||
"configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
|
||||
)
|
||||
elif dataset.startswith("aircraft"):
|
||||
if use_less:
|
||||
config_path = "configs/nas-benchmark/LESS.config"
|
||||
else:
|
||||
config_path = "configs/nas-benchmark/aircraft.config"
|
||||
split_info = load_config(
|
||||
"configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError("invalid dataset : {:}".format(dataset))
|
||||
config = load_config(
|
||||
config_path, {"class_num": class_num, "xshape": xshape}, logger
|
||||
)
|
||||
# check whether use splited validation set
|
||||
# if dataset == 'aircraft':
|
||||
# split = True
|
||||
if bool(split):
|
||||
assert dataset == "cifar10"
|
||||
ValLoaders = {
|
||||
"ori-test": torch.utils.data.DataLoader(
|
||||
valid_data,
|
||||
if dataset == "cifar10" or dataset == "cifar100":
|
||||
assert dataset == "cifar10"
|
||||
ValLoaders = {
|
||||
"ori-test": torch.utils.data.DataLoader(
|
||||
valid_data,
|
||||
batch_size=config.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
}
|
||||
assert len(train_data) == len(split_info.train) + len(
|
||||
split_info.valid
|
||||
), "invalid length : {:} vs {:} + {:}".format(
|
||||
len(train_data), len(split_info.train), len(split_info.valid)
|
||||
)
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
valid_data = train_data_v2
|
||||
# data loader
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_data,
|
||||
batch_size=config.batch_size,
|
||||
shuffle=False,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
}
|
||||
assert len(train_data) == len(split_info.train) + len(
|
||||
split_info.valid
|
||||
), "invalid length : {:} vs {:} + {:}".format(
|
||||
len(train_data), len(split_info.train), len(split_info.valid)
|
||||
)
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
valid_data = train_data_v2
|
||||
# data loader
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_data,
|
||||
batch_size=config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
valid_loader = torch.utils.data.DataLoader(
|
||||
valid_data,
|
||||
batch_size=config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
ValLoaders["x-valid"] = valid_loader
|
||||
valid_loader = torch.utils.data.DataLoader(
|
||||
valid_data,
|
||||
batch_size=config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
ValLoaders["x-valid"] = valid_loader
|
||||
elif dataset == "aircraft":
|
||||
ValLoaders = {
|
||||
"ori-test": torch.utils.data.DataLoader(
|
||||
valid_data,
|
||||
batch_size=config.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
}
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
valid_data = train_data_v2
|
||||
# 使用 DataLoader
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_data,
|
||||
batch_size=config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
|
||||
num_workers=workers,
|
||||
pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(
|
||||
valid_data,
|
||||
batch_size=config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
|
||||
num_workers=workers,
|
||||
pin_memory=True)
|
||||
else:
|
||||
# data loader
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
@ -103,7 +142,7 @@ def evaluate_all_datasets(
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
if dataset == "cifar10":
|
||||
if dataset == "cifar10" or dataset == "aircraft":
|
||||
ValLoaders = {"ori-test": valid_loader}
|
||||
elif dataset == "cifar100":
|
||||
cifar100_splits = load_config(
|
||||
|
Loading…
Reference in New Issue
Block a user