From ef2608bb42e560a7136efb32887bcb50de17472f Mon Sep 17 00:00:00 2001 From: Mhrooz Date: Mon, 14 Oct 2024 23:19:28 +0200 Subject: [PATCH] can train aircraft now --- exps/NAS-Bench-201/functions.py | 10 ++-- exps/NAS-Bench-201/main.py | 103 ++++++++++++++++++++++---------- 2 files changed, 76 insertions(+), 37 deletions(-) diff --git a/exps/NAS-Bench-201/functions.py b/exps/NAS-Bench-201/functions.py index 5ac92bb..7915e50 100644 --- a/exps/NAS-Bench-201/functions.py +++ b/exps/NAS-Bench-201/functions.py @@ -2,11 +2,11 @@ # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # ##################################################### import time, torch -from procedures import prepare_seed, get_optim_scheduler -from utils import get_model_infos, obtain_accuracy -from config_utils import dict2config -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_cell_based_tiny_net +from xautodl.procedures import prepare_seed, get_optim_scheduler +from xautodl.utils import get_model_infos, obtain_accuracy +from xautodl.config_utils import dict2config +from xautodl.log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.models import get_cell_based_tiny_net __all__ = ["evaluate_for_seed", "pure_evaluate"] diff --git a/exps/NAS-Bench-201/main.py b/exps/NAS-Bench-201/main.py index 5b32850..61be861 100644 --- a/exps/NAS-Bench-201/main.py +++ b/exps/NAS-Bench-201/main.py @@ -16,8 +16,9 @@ from xautodl.procedures import get_machine_info from xautodl.datasets import get_datasets from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time from xautodl.models import CellStructure, CellArchitectures, get_search_spaces -from xautodl.functions import evaluate_for_seed +from functions import evaluate_for_seed +from torchvision import datasets, transforms def evaluate_all_datasets( arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger @@ -46,47 +47,85 @@ def evaluate_all_datasets( split_info = load_config( "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None ) + elif dataset.startswith("aircraft"): + if use_less: + config_path = "configs/nas-benchmark/LESS.config" + else: + config_path = "configs/nas-benchmark/aircraft.config" + split_info = load_config( + "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None + ) + else: raise ValueError("invalid dataset : {:}".format(dataset)) config = load_config( config_path, {"class_num": class_num, "xshape": xshape}, logger ) # check whether use splited validation set + # if dataset == 'aircraft': + # split = True if bool(split): - assert dataset == "cifar10" - ValLoaders = { - "ori-test": torch.utils.data.DataLoader( - valid_data, + if dataset == "cifar10" or dataset == "cifar100": + assert dataset == "cifar10" + ValLoaders = { + "ori-test": torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + shuffle=False, + num_workers=workers, + pin_memory=True, + ) + } + assert len(train_data) == len(split_info.train) + len( + split_info.valid + ), "invalid length : {:} vs {:} + {:}".format( + len(train_data), len(split_info.train), len(split_info.valid) + ) + train_data_v2 = deepcopy(train_data) + train_data_v2.transform = valid_data.transform + valid_data = train_data_v2 + # data loader + train_loader = torch.utils.data.DataLoader( + train_data, batch_size=config.batch_size, - shuffle=False, + sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), num_workers=workers, pin_memory=True, ) - } - assert len(train_data) == len(split_info.train) + len( - split_info.valid - ), "invalid length : {:} vs {:} + {:}".format( - len(train_data), len(split_info.train), len(split_info.valid) - ) - train_data_v2 = deepcopy(train_data) - train_data_v2.transform = valid_data.transform - valid_data = train_data_v2 - # data loader - train_loader = torch.utils.data.DataLoader( - train_data, - batch_size=config.batch_size, - sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), - num_workers=workers, - pin_memory=True, - ) - valid_loader = torch.utils.data.DataLoader( - valid_data, - batch_size=config.batch_size, - sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), - num_workers=workers, - pin_memory=True, - ) - ValLoaders["x-valid"] = valid_loader + valid_loader = torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), + num_workers=workers, + pin_memory=True, + ) + ValLoaders["x-valid"] = valid_loader + elif dataset == "aircraft": + ValLoaders = { + "ori-test": torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + shuffle=False, + num_workers=workers, + pin_memory=True, + ) + } + train_data_v2 = deepcopy(train_data) + train_data_v2.transform = valid_data.transform + valid_data = train_data_v2 + # 使用 DataLoader + train_loader = torch.utils.data.DataLoader( + train_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), + num_workers=workers, + pin_memory=True) + valid_loader = torch.utils.data.DataLoader( + valid_data, + batch_size=config.batch_size, + sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), + num_workers=workers, + pin_memory=True) else: # data loader train_loader = torch.utils.data.DataLoader( @@ -103,7 +142,7 @@ def evaluate_all_datasets( num_workers=workers, pin_memory=True, ) - if dataset == "cifar10": + if dataset == "cifar10" or dataset == "aircraft": ValLoaders = {"ori-test": valid_loader} elif dataset == "cifar100": cifar100_splits = load_config(