Update NATS-Bench to cope with lib->xautodl

This commit is contained in:
D-X-Y 2021-05-21 13:55:38 +08:00
parent b4e8eae63a
commit 97717d826e
11 changed files with 75 additions and 80 deletions

View File

@ -9,4 +9,4 @@
- [2020.10.15] [446262a](https://github.com/D-X-Y/AutoDL-Projects/tree/446262a) Update NATS-BENCH to version 1.0 - [2020.10.15] [446262a](https://github.com/D-X-Y/AutoDL-Projects/tree/446262a) Update NATS-BENCH to version 1.0
- [2020.12.20] [dae387a](https://github.com/D-X-Y/AutoDL-Projects/tree/dae387a) Update NATS-BENCH to version 1.1 - [2020.12.20] [dae387a](https://github.com/D-X-Y/AutoDL-Projects/tree/dae387a) Update NATS-BENCH to version 1.1
- [2021.05.18] [98fadf8](https://github.com/D-X-Y/AutoDL-Projects/tree/98fadf8) Before moving to `xautodl` - [2021.05.18] [98fadf8](https://github.com/D-X-Y/AutoDL-Projects/tree/98fadf8) Before moving to `xautodl`
- [2021.05.19] [b50ad2a](https://github.com/D-X-Y/AutoDL-Projects/tree/b50ad2a) `xautodl` is close to ready - [2021.05.21] [b4e8eae](https://github.com/D-X-Y/AutoDL-Projects/tree/b4e8eae) `xautodl` is close to ready

View File

@ -8,6 +8,7 @@
import os, sys, time, torch, random, argparse import os, sys, time, torch, random, argparse
from typing import List, Text, Dict, Any from typing import List, Text, Dict, Any
from PIL import ImageFile from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy from copy import deepcopy

View File

@ -12,14 +12,11 @@ from collections import defaultdict
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import dict2config, load_config
if str(lib_dir) not in sys.path: from xautodl.procedures import bench_evaluate_for_seed
sys.path.insert(0, str(lib_dir)) from xautodl.procedures import get_machine_info
from config_utils import dict2config, load_config from xautodl.datasets import get_datasets
from procedures import bench_evaluate_for_seed from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
from procedures import get_machine_info
from datasets import get_datasets
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
def obtain_valid_ckp(save_dir: Text, total: int): def obtain_valid_ckp(save_dir: Text, total: int):

View File

@ -12,14 +12,11 @@ from collections import defaultdict
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import dict2config, load_config
if str(lib_dir) not in sys.path: from xautodl.procedures import bench_evaluate_for_seed
sys.path.insert(0, str(lib_dir)) from xautodl.procedures import get_machine_info
from config_utils import dict2config, load_config from xautodl.datasets import get_datasets
from procedures import bench_evaluate_for_seed from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
from procedures import get_machine_info
from datasets import get_datasets
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
def obtain_valid_ckp(save_dir: Text, total: int, possible_seeds: List[int]): def obtain_valid_ckp(save_dir: Text, total: int, possible_seeds: List[int]):

View File

@ -10,18 +10,14 @@
################################################################### ###################################################################
import os, sys, time, random, argparse, collections import os, sys, time, random, argparse, collections
from copy import deepcopy from copy import deepcopy
from pathlib import Path
import torch import torch
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import load_config
if str(lib_dir) not in sys.path: from xautodl.datasets import get_datasets, SearchDataset
sys.path.insert(0, str(lib_dir)) from xautodl.procedures import prepare_seed, prepare_logger
from config_utils import load_config from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from datasets import get_datasets, SearchDataset from xautodl.models import CellStructure, get_search_spaces
from procedures import prepare_seed, prepare_logger
from log_utils import AverageMeter, time_string, convert_secs2time
from nats_bench import create from nats_bench import create
from models import CellStructure, get_search_spaces
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018 # BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018
import ConfigSpace import ConfigSpace

View File

@ -12,25 +12,47 @@ import numpy as np, collections
from copy import deepcopy from copy import deepcopy
import torch import torch
import torch.nn as nn import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import load_config, dict2config, configure2str
if str(lib_dir) not in sys.path: from xautodl.datasets import get_datasets, SearchDataset
sys.path.insert(0, str(lib_dir)) from xautodl.procedures import (
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import (
prepare_seed, prepare_seed,
prepare_logger, prepare_logger,
save_checkpoint, save_checkpoint,
copy_checkpoint, copy_checkpoint,
get_optim_scheduler, get_optim_scheduler,
) )
from utils import get_model_infos, obtain_accuracy from xautodl.utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from models import get_search_spaces from xautodl.models import CellStructure, get_search_spaces
from nats_bench import create from nats_bench import create
from regularized_ea import random_topology_func, random_size_func
def random_topology_func(op_names, max_nodes=4):
# Return a random architecture
def random_architecture():
genotypes = []
for i in range(1, max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
op_name = random.choice(op_names)
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return CellStructure(genotypes)
return random_architecture
def random_size_func(info):
# Return a random architecture
def random_architecture():
channels = []
for i in range(info["numbers"]):
channels.append(str(random.choice(info["candidates"])))
return ":".join(channels)
return random_architecture
def main(xargs, api): def main(xargs, api):

View File

@ -16,23 +16,19 @@ import numpy as np, collections
from copy import deepcopy from copy import deepcopy
import torch import torch
import torch.nn as nn import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import load_config, dict2config, configure2str
if str(lib_dir) not in sys.path: from xautodl.datasets import get_datasets, SearchDataset
sys.path.insert(0, str(lib_dir)) from xautodl.procedures import (
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import (
prepare_seed, prepare_seed,
prepare_logger, prepare_logger,
save_checkpoint, save_checkpoint,
copy_checkpoint, copy_checkpoint,
get_optim_scheduler, get_optim_scheduler,
) )
from utils import get_model_infos, obtain_accuracy from xautodl.utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from models import CellStructure, get_search_spaces from xautodl.models import CellStructure, get_search_spaces
from nats_bench import create from nats_bench import create

View File

@ -13,26 +13,22 @@
import os, sys, time, glob, random, argparse import os, sys, time, glob, random, argparse
import numpy as np, collections import numpy as np, collections
from copy import deepcopy from copy import deepcopy
from pathlib import Path
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.distributions import Categorical from torch.distributions import Categorical
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import load_config, dict2config, configure2str
if str(lib_dir) not in sys.path: from xautodl.datasets import get_datasets, SearchDataset
sys.path.insert(0, str(lib_dir)) from xautodl.procedures import (
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import (
prepare_seed, prepare_seed,
prepare_logger, prepare_logger,
save_checkpoint, save_checkpoint,
copy_checkpoint, copy_checkpoint,
get_optim_scheduler, get_optim_scheduler,
) )
from utils import get_model_infos, obtain_accuracy from xautodl.utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from models import CellStructure, get_search_spaces from xautodl.models import CellStructure, get_search_spaces
from nats_bench import create from nats_bench import create
@ -206,7 +202,6 @@ def main(xargs, api):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser("The REINFORCE Algorithm") parser = argparse.ArgumentParser("The REINFORCE Algorithm")
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument( parser.add_argument(
"--dataset", "--dataset",
type=str, type=str,

View File

@ -30,23 +30,19 @@ import numpy as np
from copy import deepcopy from copy import deepcopy
import torch import torch
import torch.nn as nn import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import load_config, dict2config, configure2str
if str(lib_dir) not in sys.path: from xautodl.datasets import get_datasets, get_nas_search_loaders
sys.path.insert(0, str(lib_dir)) from xautodl.procedures import (
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, get_nas_search_loaders
from procedures import (
prepare_seed, prepare_seed,
prepare_logger, prepare_logger,
save_checkpoint, save_checkpoint,
copy_checkpoint, copy_checkpoint,
get_optim_scheduler, get_optim_scheduler,
) )
from utils import count_parameters_in_MB, obtain_accuracy from xautodl.utils import count_parameters_in_MB, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces from xautodl.models import get_cell_based_tiny_net, get_search_spaces
from nats_bench import create from nats_bench import create

View File

@ -31,23 +31,19 @@ import numpy as np
from copy import deepcopy from copy import deepcopy
import torch import torch
import torch.nn as nn import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import load_config, dict2config, configure2str
if str(lib_dir) not in sys.path: from xautodl.datasets import get_datasets, get_nas_search_loaders
sys.path.insert(0, str(lib_dir)) from xautodl.procedures import (
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, get_nas_search_loaders
from procedures import (
prepare_seed, prepare_seed,
prepare_logger, prepare_logger,
save_checkpoint, save_checkpoint,
copy_checkpoint, copy_checkpoint,
get_optim_scheduler, get_optim_scheduler,
) )
from utils import count_parameters_in_MB, obtain_accuracy from xautodl.utils import count_parameters_in_MB, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces from xautodl.models import get_cell_based_tiny_net, get_search_spaces
from nats_bench import create from nats_bench import create

View File

@ -3,7 +3,6 @@
##################################################### #####################################################
import os, time, copy, torch, pathlib import os, time, copy, torch, pathlib
# modules in AutoDL
from xautodl import datasets from xautodl import datasets
from xautodl.config_utils import load_config from xautodl.config_utils import load_config
from xautodl.procedures import prepare_seed, get_optim_scheduler from xautodl.procedures import prepare_seed, get_optim_scheduler
@ -83,7 +82,7 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
def evaluate_for_seed( def evaluate_for_seed(
arch_config, opt_config, train_loader, valid_loaders, seed: int, logger arch_config, opt_config, train_loader, valid_loaders, seed: int, logger
): ):
"""A modular function to train and evaluate a single network, using the given random seed and optimization config with the provided loaders."""
prepare_seed(seed) # random seed prepare_seed(seed) # random seed
net = get_cell_based_tiny_net(arch_config) net = get_cell_based_tiny_net(arch_config)
# net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num) # net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)