Update NATS-Bench to cope with lib->xautodl
This commit is contained in:
parent
b4e8eae63a
commit
97717d826e
@ -9,4 +9,4 @@
|
|||||||
- [2020.10.15] [446262a](https://github.com/D-X-Y/AutoDL-Projects/tree/446262a) Update NATS-BENCH to version 1.0
|
- [2020.10.15] [446262a](https://github.com/D-X-Y/AutoDL-Projects/tree/446262a) Update NATS-BENCH to version 1.0
|
||||||
- [2020.12.20] [dae387a](https://github.com/D-X-Y/AutoDL-Projects/tree/dae387a) Update NATS-BENCH to version 1.1
|
- [2020.12.20] [dae387a](https://github.com/D-X-Y/AutoDL-Projects/tree/dae387a) Update NATS-BENCH to version 1.1
|
||||||
- [2021.05.18] [98fadf8](https://github.com/D-X-Y/AutoDL-Projects/tree/98fadf8) Before moving to `xautodl`
|
- [2021.05.18] [98fadf8](https://github.com/D-X-Y/AutoDL-Projects/tree/98fadf8) Before moving to `xautodl`
|
||||||
- [2021.05.19] [b50ad2a](https://github.com/D-X-Y/AutoDL-Projects/tree/b50ad2a) `xautodl` is close to ready
|
- [2021.05.21] [b4e8eae](https://github.com/D-X-Y/AutoDL-Projects/tree/b4e8eae) `xautodl` is close to ready
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
import os, sys, time, torch, random, argparse
|
import os, sys, time, torch, random, argparse
|
||||||
from typing import List, Text, Dict, Any
|
from typing import List, Text, Dict, Any
|
||||||
from PIL import ImageFile
|
from PIL import ImageFile
|
||||||
|
|
||||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
@ -12,14 +12,11 @@ from collections import defaultdict
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
from xautodl.config_utils import dict2config, load_config
|
||||||
if str(lib_dir) not in sys.path:
|
from xautodl.procedures import bench_evaluate_for_seed
|
||||||
sys.path.insert(0, str(lib_dir))
|
from xautodl.procedures import get_machine_info
|
||||||
from config_utils import dict2config, load_config
|
from xautodl.datasets import get_datasets
|
||||||
from procedures import bench_evaluate_for_seed
|
from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
|
||||||
from procedures import get_machine_info
|
|
||||||
from datasets import get_datasets
|
|
||||||
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
|
|
||||||
|
|
||||||
|
|
||||||
def obtain_valid_ckp(save_dir: Text, total: int):
|
def obtain_valid_ckp(save_dir: Text, total: int):
|
||||||
|
@ -12,14 +12,11 @@ from collections import defaultdict
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
from xautodl.config_utils import dict2config, load_config
|
||||||
if str(lib_dir) not in sys.path:
|
from xautodl.procedures import bench_evaluate_for_seed
|
||||||
sys.path.insert(0, str(lib_dir))
|
from xautodl.procedures import get_machine_info
|
||||||
from config_utils import dict2config, load_config
|
from xautodl.datasets import get_datasets
|
||||||
from procedures import bench_evaluate_for_seed
|
from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
|
||||||
from procedures import get_machine_info
|
|
||||||
from datasets import get_datasets
|
|
||||||
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
|
|
||||||
|
|
||||||
|
|
||||||
def obtain_valid_ckp(save_dir: Text, total: int, possible_seeds: List[int]):
|
def obtain_valid_ckp(save_dir: Text, total: int, possible_seeds: List[int]):
|
||||||
|
@ -10,18 +10,14 @@
|
|||||||
###################################################################
|
###################################################################
|
||||||
import os, sys, time, random, argparse, collections
|
import os, sys, time, random, argparse, collections
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
from xautodl.config_utils import load_config
|
||||||
if str(lib_dir) not in sys.path:
|
from xautodl.datasets import get_datasets, SearchDataset
|
||||||
sys.path.insert(0, str(lib_dir))
|
from xautodl.procedures import prepare_seed, prepare_logger
|
||||||
from config_utils import load_config
|
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from datasets import get_datasets, SearchDataset
|
from xautodl.models import CellStructure, get_search_spaces
|
||||||
from procedures import prepare_seed, prepare_logger
|
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
|
||||||
from nats_bench import create
|
from nats_bench import create
|
||||||
from models import CellStructure, get_search_spaces
|
|
||||||
|
|
||||||
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018
|
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018
|
||||||
import ConfigSpace
|
import ConfigSpace
|
||||||
|
@ -12,25 +12,47 @@ import numpy as np, collections
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
from xautodl.config_utils import load_config, dict2config, configure2str
|
||||||
if str(lib_dir) not in sys.path:
|
from xautodl.datasets import get_datasets, SearchDataset
|
||||||
sys.path.insert(0, str(lib_dir))
|
from xautodl.procedures import (
|
||||||
from config_utils import load_config, dict2config, configure2str
|
|
||||||
from datasets import get_datasets, SearchDataset
|
|
||||||
from procedures import (
|
|
||||||
prepare_seed,
|
prepare_seed,
|
||||||
prepare_logger,
|
prepare_logger,
|
||||||
save_checkpoint,
|
save_checkpoint,
|
||||||
copy_checkpoint,
|
copy_checkpoint,
|
||||||
get_optim_scheduler,
|
get_optim_scheduler,
|
||||||
)
|
)
|
||||||
from utils import get_model_infos, obtain_accuracy
|
from xautodl.utils import get_model_infos, obtain_accuracy
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from models import get_search_spaces
|
from xautodl.models import CellStructure, get_search_spaces
|
||||||
from nats_bench import create
|
from nats_bench import create
|
||||||
from regularized_ea import random_topology_func, random_size_func
|
|
||||||
|
|
||||||
|
def random_topology_func(op_names, max_nodes=4):
|
||||||
|
# Return a random architecture
|
||||||
|
def random_architecture():
|
||||||
|
genotypes = []
|
||||||
|
for i in range(1, max_nodes):
|
||||||
|
xlist = []
|
||||||
|
for j in range(i):
|
||||||
|
node_str = "{:}<-{:}".format(i, j)
|
||||||
|
op_name = random.choice(op_names)
|
||||||
|
xlist.append((op_name, j))
|
||||||
|
genotypes.append(tuple(xlist))
|
||||||
|
return CellStructure(genotypes)
|
||||||
|
|
||||||
|
return random_architecture
|
||||||
|
|
||||||
|
|
||||||
|
def random_size_func(info):
|
||||||
|
# Return a random architecture
|
||||||
|
def random_architecture():
|
||||||
|
channels = []
|
||||||
|
for i in range(info["numbers"]):
|
||||||
|
channels.append(str(random.choice(info["candidates"])))
|
||||||
|
return ":".join(channels)
|
||||||
|
|
||||||
|
return random_architecture
|
||||||
|
|
||||||
|
|
||||||
def main(xargs, api):
|
def main(xargs, api):
|
||||||
|
@ -16,23 +16,19 @@ import numpy as np, collections
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
from xautodl.config_utils import load_config, dict2config, configure2str
|
||||||
if str(lib_dir) not in sys.path:
|
from xautodl.datasets import get_datasets, SearchDataset
|
||||||
sys.path.insert(0, str(lib_dir))
|
from xautodl.procedures import (
|
||||||
from config_utils import load_config, dict2config, configure2str
|
|
||||||
from datasets import get_datasets, SearchDataset
|
|
||||||
from procedures import (
|
|
||||||
prepare_seed,
|
prepare_seed,
|
||||||
prepare_logger,
|
prepare_logger,
|
||||||
save_checkpoint,
|
save_checkpoint,
|
||||||
copy_checkpoint,
|
copy_checkpoint,
|
||||||
get_optim_scheduler,
|
get_optim_scheduler,
|
||||||
)
|
)
|
||||||
from utils import get_model_infos, obtain_accuracy
|
from xautodl.utils import get_model_infos, obtain_accuracy
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from models import CellStructure, get_search_spaces
|
from xautodl.models import CellStructure, get_search_spaces
|
||||||
from nats_bench import create
|
from nats_bench import create
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,26 +13,22 @@
|
|||||||
import os, sys, time, glob, random, argparse
|
import os, sys, time, glob, random, argparse
|
||||||
import numpy as np, collections
|
import numpy as np, collections
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.distributions import Categorical
|
from torch.distributions import Categorical
|
||||||
|
|
||||||
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
from xautodl.config_utils import load_config, dict2config, configure2str
|
||||||
if str(lib_dir) not in sys.path:
|
from xautodl.datasets import get_datasets, SearchDataset
|
||||||
sys.path.insert(0, str(lib_dir))
|
from xautodl.procedures import (
|
||||||
from config_utils import load_config, dict2config, configure2str
|
|
||||||
from datasets import get_datasets, SearchDataset
|
|
||||||
from procedures import (
|
|
||||||
prepare_seed,
|
prepare_seed,
|
||||||
prepare_logger,
|
prepare_logger,
|
||||||
save_checkpoint,
|
save_checkpoint,
|
||||||
copy_checkpoint,
|
copy_checkpoint,
|
||||||
get_optim_scheduler,
|
get_optim_scheduler,
|
||||||
)
|
)
|
||||||
from utils import get_model_infos, obtain_accuracy
|
from xautodl.utils import get_model_infos, obtain_accuracy
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from models import CellStructure, get_search_spaces
|
from xautodl.models import CellStructure, get_search_spaces
|
||||||
from nats_bench import create
|
from nats_bench import create
|
||||||
|
|
||||||
|
|
||||||
@ -206,7 +202,6 @@ def main(xargs, api):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser("The REINFORCE Algorithm")
|
parser = argparse.ArgumentParser("The REINFORCE Algorithm")
|
||||||
parser.add_argument("--data_path", type=str, help="Path to dataset")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dataset",
|
"--dataset",
|
||||||
type=str,
|
type=str,
|
||||||
|
@ -30,23 +30,19 @@ import numpy as np
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
from xautodl.config_utils import load_config, dict2config, configure2str
|
||||||
if str(lib_dir) not in sys.path:
|
from xautodl.datasets import get_datasets, get_nas_search_loaders
|
||||||
sys.path.insert(0, str(lib_dir))
|
from xautodl.procedures import (
|
||||||
from config_utils import load_config, dict2config, configure2str
|
|
||||||
from datasets import get_datasets, get_nas_search_loaders
|
|
||||||
from procedures import (
|
|
||||||
prepare_seed,
|
prepare_seed,
|
||||||
prepare_logger,
|
prepare_logger,
|
||||||
save_checkpoint,
|
save_checkpoint,
|
||||||
copy_checkpoint,
|
copy_checkpoint,
|
||||||
get_optim_scheduler,
|
get_optim_scheduler,
|
||||||
)
|
)
|
||||||
from utils import count_parameters_in_MB, obtain_accuracy
|
from xautodl.utils import count_parameters_in_MB, obtain_accuracy
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from models import get_cell_based_tiny_net, get_search_spaces
|
from xautodl.models import get_cell_based_tiny_net, get_search_spaces
|
||||||
from nats_bench import create
|
from nats_bench import create
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,23 +31,19 @@ import numpy as np
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
from xautodl.config_utils import load_config, dict2config, configure2str
|
||||||
if str(lib_dir) not in sys.path:
|
from xautodl.datasets import get_datasets, get_nas_search_loaders
|
||||||
sys.path.insert(0, str(lib_dir))
|
from xautodl.procedures import (
|
||||||
from config_utils import load_config, dict2config, configure2str
|
|
||||||
from datasets import get_datasets, get_nas_search_loaders
|
|
||||||
from procedures import (
|
|
||||||
prepare_seed,
|
prepare_seed,
|
||||||
prepare_logger,
|
prepare_logger,
|
||||||
save_checkpoint,
|
save_checkpoint,
|
||||||
copy_checkpoint,
|
copy_checkpoint,
|
||||||
get_optim_scheduler,
|
get_optim_scheduler,
|
||||||
)
|
)
|
||||||
from utils import count_parameters_in_MB, obtain_accuracy
|
from xautodl.utils import count_parameters_in_MB, obtain_accuracy
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from models import get_cell_based_tiny_net, get_search_spaces
|
from xautodl.models import get_cell_based_tiny_net, get_search_spaces
|
||||||
from nats_bench import create
|
from nats_bench import create
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
import os, time, copy, torch, pathlib
|
import os, time, copy, torch, pathlib
|
||||||
|
|
||||||
# modules in AutoDL
|
|
||||||
from xautodl import datasets
|
from xautodl import datasets
|
||||||
from xautodl.config_utils import load_config
|
from xautodl.config_utils import load_config
|
||||||
from xautodl.procedures import prepare_seed, get_optim_scheduler
|
from xautodl.procedures import prepare_seed, get_optim_scheduler
|
||||||
@ -83,7 +82,7 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
|
|||||||
def evaluate_for_seed(
|
def evaluate_for_seed(
|
||||||
arch_config, opt_config, train_loader, valid_loaders, seed: int, logger
|
arch_config, opt_config, train_loader, valid_loaders, seed: int, logger
|
||||||
):
|
):
|
||||||
|
"""A modular function to train and evaluate a single network, using the given random seed and optimization config with the provided loaders."""
|
||||||
prepare_seed(seed) # random seed
|
prepare_seed(seed) # random seed
|
||||||
net = get_cell_based_tiny_net(arch_config)
|
net = get_cell_based_tiny_net(arch_config)
|
||||||
# net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
|
# net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
|
||||||
|
Loading…
Reference in New Issue
Block a user