From da2575cc6c0805a694928ccc5b05aa4c0fa7a022 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Mon, 24 May 2021 11:04:18 +0800 Subject: [PATCH] Fix NATS-Bench and NAS-Bench-201 for removing lib --- exps/NAS-Bench-201/check.py | 5 +---- exps/NAS-Bench-201/main.py | 17 +++++++---------- exps/NAS-Bench-201/show-best.py | 5 +---- exps/NAS-Bench-201/statistics-v2.py | 11 ++++------- exps/NAS-Bench-201/statistics.py | 13 +++++-------- exps/NAS-Bench-201/test-correlation.py | 7 ++----- exps/NAS-Bench-201/visualize.py | 5 +---- exps/NATS-Bench/Analyze-time.py | 7 ++----- exps/NATS-Bench/draw-correlations.py | 7 ++----- exps/NATS-Bench/draw-fig2_5.py | 10 +++------- exps/NATS-Bench/draw-fig6.py | 7 ++----- exps/NATS-Bench/draw-fig7.py | 7 ++----- exps/NATS-Bench/draw-fig8.py | 7 ++----- exps/NATS-Bench/draw-ranks.py | 9 +++------ exps/NATS-Bench/draw-table.py | 7 ++----- exps/NATS-Bench/sss-collect.py | 17 +++++++++-------- exps/NATS-Bench/test-nats-api.py | 9 +++------ exps/NATS-Bench/tss-collect-patcher.py | 19 ++++++++++--------- exps/NATS-Bench/tss-collect.py | 18 +++++++++--------- exps/algos/BOHB.py | 13 +++++-------- exps/algos/DARTS-V1.py | 15 ++++++--------- exps/algos/DARTS-V2.py | 18 +++++++----------- exps/algos/ENAS.py | 18 +++++++----------- exps/algos/GDAS.py | 19 +++++++------------ exps/algos/RANDOM-NAS.py | 18 +++++++----------- exps/algos/RANDOM.py | 16 +++++++--------- exps/algos/README.md | 2 ++ exps/algos/R_EA.py | 15 ++++++--------- exps/algos/SETN.py | 15 ++++++--------- exps/algos/reinforce.py | 15 ++++++--------- lib | 1 - 31 files changed, 136 insertions(+), 216 deletions(-) delete mode 120000 lib diff --git a/exps/NAS-Bench-201/check.py b/exps/NAS-Bench-201/check.py index f6929db..b81ed34 100644 --- a/exps/NAS-Bench-201/check.py +++ b/exps/NAS-Bench-201/check.py @@ -8,10 +8,7 @@ import torch from pathlib import Path from collections import defaultdict -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.log_utils import AverageMeter, time_string, convert_secs2time def check_files(save_dir, meta_file, basestr): diff --git a/exps/NAS-Bench-201/main.py b/exps/NAS-Bench-201/main.py index 19a68e6..5b32850 100644 --- a/exps/NAS-Bench-201/main.py +++ b/exps/NAS-Bench-201/main.py @@ -10,16 +10,13 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True from copy import deepcopy from pathlib import Path -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import load_config -from procedures import save_checkpoint, copy_checkpoint -from procedures import get_machine_info -from datasets import get_datasets -from log_utils import Logger, AverageMeter, time_string, convert_secs2time -from models import CellStructure, CellArchitectures, get_search_spaces -from functions import evaluate_for_seed +from xautodl.config_utils import load_config +from xautodl.procedures import save_checkpoint, copy_checkpoint +from xautodl.procedures import get_machine_info +from xautodl.datasets import get_datasets +from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time +from xautodl.models import CellStructure, CellArchitectures, get_search_spaces +from xautodl.functions import evaluate_for_seed def evaluate_all_datasets( diff --git a/exps/NAS-Bench-201/show-best.py b/exps/NAS-Bench-201/show-best.py index eb9e929..814bdc4 100644 --- a/exps/NAS-Bench-201/show-best.py +++ b/exps/NAS-Bench-201/show-best.py @@ -3,12 +3,9 @@ ################################################################################################ # python exps/NAS-Bench-201/show-best.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth # ################################################################################################ -import sys, argparse +import argparse from pathlib import Path -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) from nas_201_api import NASBench201API as API if __name__ == "__main__": diff --git a/exps/NAS-Bench-201/statistics-v2.py b/exps/NAS-Bench-201/statistics-v2.py index d20a56d..037af0f 100644 --- a/exps/NAS-Bench-201/statistics-v2.py +++ b/exps/NAS-Bench-201/statistics-v2.py @@ -8,16 +8,13 @@ from pathlib import Path from collections import defaultdict, OrderedDict from typing import Dict, Any, Text, List -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from log_utils import AverageMeter, time_string, convert_secs2time -from config_utils import dict2config +from xautodl.log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.config_utils import dict2config # NAS-Bench-201 related module or function -from models import CellStructure, get_cell_based_tiny_net +from xautodl.models import CellStructure, get_cell_based_tiny_net +from xautodl.procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders from nas_201_api import NASBench201API, ArchResults, ResultsCount -from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders api = NASBench201API( "{:}/.torch/NAS-Bench-201-v1_0-e61699.pth".format(os.environ["HOME"]) diff --git a/exps/NAS-Bench-201/statistics.py b/exps/NAS-Bench-201/statistics.py index 14fbc8a..7587b7a 100644 --- a/exps/NAS-Bench-201/statistics.py +++ b/exps/NAS-Bench-201/statistics.py @@ -7,17 +7,14 @@ import torch from pathlib import Path from collections import defaultdict -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from log_utils import AverageMeter, time_string, convert_secs2time -from config_utils import load_config, dict2config -from datasets import get_datasets +from xautodl.log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.config_utils import load_config, dict2config +from xautodl.datasets import get_datasets # NAS-Bench-201 related module or function -from models import CellStructure, get_cell_based_tiny_net +from xautodl.models import CellStructure, get_cell_based_tiny_net +from xautodl.procedures import bench_pure_evaluate as pure_evaluate from nas_201_api import ArchResults, ResultsCount -from procedures import bench_pure_evaluate as pure_evaluate def create_result_count(used_seed, dataset, arch_config, results, dataloader_dict): diff --git a/exps/NAS-Bench-201/test-correlation.py b/exps/NAS-Bench-201/test-correlation.py index 31225ec..ccc35a4 100644 --- a/exps/NAS-Bench-201/test-correlation.py +++ b/exps/NAS-Bench-201/test-correlation.py @@ -10,11 +10,8 @@ from tqdm import tqdm import torch from pathlib import Path -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from log_utils import time_string -from models import CellStructure +from xautodl.log_utils import time_string +from xautodl.models import CellStructure from nas_201_api import NASBench201API as API diff --git a/exps/NAS-Bench-201/visualize.py b/exps/NAS-Bench-201/visualize.py index 588c17c..9571251 100644 --- a/exps/NAS-Bench-201/visualize.py +++ b/exps/NAS-Bench-201/visualize.py @@ -17,10 +17,7 @@ from mpl_toolkits.mplot3d import Axes3D matplotlib.use("agg") import matplotlib.pyplot as plt -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from log_utils import time_string +from xautodl.log_utils import time_string from nas_201_api import NASBench201API as API diff --git a/exps/NATS-Bench/Analyze-time.py b/exps/NATS-Bench/Analyze-time.py index 46c0e73..e9b0c6e 100644 --- a/exps/NATS-Bench/Analyze-time.py +++ b/exps/NATS-Bench/Analyze-time.py @@ -8,11 +8,8 @@ import os, sys, time, tqdm, argparse from pathlib import Path -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import dict2config, load_config -from datasets import get_datasets +from xautodl.config_utils import dict2config, load_config +from xautodl.datasets import get_datasets from nats_bench import create diff --git a/exps/NATS-Bench/draw-correlations.py b/exps/NATS-Bench/draw-correlations.py index 25b578e..db34a2d 100644 --- a/exps/NATS-Bench/draw-correlations.py +++ b/exps/NATS-Bench/draw-correlations.py @@ -19,12 +19,9 @@ matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import dict2config, load_config +from xautodl.config_utils import dict2config, load_config +from xautodl.log_utils import time_string from nats_bench import create -from log_utils import time_string def get_valid_test_acc(api, arch, dataset): diff --git a/exps/NATS-Bench/draw-fig2_5.py b/exps/NATS-Bench/draw-fig2_5.py index f9eb8fd..779d18d 100644 --- a/exps/NATS-Bench/draw-fig2_5.py +++ b/exps/NATS-Bench/draw-fig2_5.py @@ -20,13 +20,9 @@ import seaborn as sns matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker - -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import dict2config, load_config -from log_utils import time_string -from models import get_cell_based_tiny_net +from xautodl.config_utils import dict2config, load_config +from xautodl.log_utils import time_string +from xautodl.models import get_cell_based_tiny_net from nats_bench import create diff --git a/exps/NATS-Bench/draw-fig6.py b/exps/NATS-Bench/draw-fig6.py index 8b00ad8..10ef260 100644 --- a/exps/NATS-Bench/draw-fig6.py +++ b/exps/NATS-Bench/draw-fig6.py @@ -21,12 +21,9 @@ matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import dict2config, load_config +from xautodl.config_utils import dict2config, load_config +from xautodl.log_utils import time_string from nats_bench import create -from log_utils import time_string def fetch_data(root_dir="./output/search", search_space="tss", dataset=None): diff --git a/exps/NATS-Bench/draw-fig7.py b/exps/NATS-Bench/draw-fig7.py index 0a98537..e156af9 100644 --- a/exps/NATS-Bench/draw-fig7.py +++ b/exps/NATS-Bench/draw-fig7.py @@ -20,12 +20,9 @@ matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import dict2config, load_config +from xautodl.config_utils import dict2config, load_config +from xautodl.log_utils import time_string from nats_bench import create -from log_utils import time_string def get_valid_test_acc(api, arch, dataset): diff --git a/exps/NATS-Bench/draw-fig8.py b/exps/NATS-Bench/draw-fig8.py index 8579aa5..f8b4011 100644 --- a/exps/NATS-Bench/draw-fig8.py +++ b/exps/NATS-Bench/draw-fig8.py @@ -20,12 +20,9 @@ matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import dict2config, load_config +from xautodl.config_utils import dict2config, load_config +from xautodl.log_utils import time_string from nats_bench import create -from log_utils import time_string plt.rcParams.update( diff --git a/exps/NATS-Bench/draw-ranks.py b/exps/NATS-Bench/draw-ranks.py index 76ee5fb..1ce9f75 100644 --- a/exps/NATS-Bench/draw-ranks.py +++ b/exps/NATS-Bench/draw-ranks.py @@ -21,12 +21,9 @@ matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import dict2config, load_config -from log_utils import time_string -from models import get_cell_based_tiny_net +from xautodl.config_utils import dict2config, load_config +from xautodl.log_utils import time_string +from xautodl.models import get_cell_based_tiny_net from nats_bench import create diff --git a/exps/NATS-Bench/draw-table.py b/exps/NATS-Bench/draw-table.py index d90ec06..34c7467 100644 --- a/exps/NATS-Bench/draw-table.py +++ b/exps/NATS-Bench/draw-table.py @@ -20,12 +20,9 @@ matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import dict2config, load_config +from xautodl.config_utils import dict2config, load_config +from xautodl.log_utils import time_string from nats_bench import create -from log_utils import time_string def fetch_data(root_dir="./output/search", search_space="tss", dataset=None): diff --git a/exps/NATS-Bench/sss-collect.py b/exps/NATS-Bench/sss-collect.py index b5ab2d4..be1805b 100644 --- a/exps/NATS-Bench/sss-collect.py +++ b/exps/NATS-Bench/sss-collect.py @@ -17,15 +17,16 @@ from pathlib import Path from collections import defaultdict, OrderedDict from typing import Dict, Any, Text, List -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from log_utils import AverageMeter, time_string, convert_secs2time -from config_utils import dict2config -from models import CellStructure, get_cell_based_tiny_net +from xautodl.log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.config_utils import dict2config +from xautodl.models import CellStructure, get_cell_based_tiny_net +from xautodl.procedures import ( + bench_pure_evaluate as pure_evaluate, + get_nas_bench_loaders, +) +from xautodl.utils import get_md5_file + from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount -from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders -from utils import get_md5_file NATS_SSS_BASE_NAME = "NATS-sss-v1_0" # 2020.08.28 diff --git a/exps/NATS-Bench/test-nats-api.py b/exps/NATS-Bench/test-nats-api.py index 3547ebb..dae4597 100644 --- a/exps/NATS-Bench/test-nats-api.py +++ b/exps/NATS-Bench/test-nats-api.py @@ -19,13 +19,10 @@ matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import dict2config, load_config +from xautodl.config_utils import dict2config, load_config +from xautodl.log_utils import time_string +from xautodl.models import get_cell_based_tiny_net, CellStructure from nats_bench import create -from log_utils import time_string -from models import get_cell_based_tiny_net, CellStructure def test_api(api, sss_or_tss=True): diff --git a/exps/NATS-Bench/tss-collect-patcher.py b/exps/NATS-Bench/tss-collect-patcher.py index 6895aa9..5d493df 100644 --- a/exps/NATS-Bench/tss-collect-patcher.py +++ b/exps/NATS-Bench/tss-collect-patcher.py @@ -19,16 +19,17 @@ from pathlib import Path from collections import defaultdict, OrderedDict from typing import Dict, Any, Text, List -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from log_utils import AverageMeter, time_string, convert_secs2time -from config_utils import load_config, dict2config -from datasets import get_datasets -from models import CellStructure, get_cell_based_tiny_net, get_search_spaces +from xautodl.log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.config_utils import load_config, dict2config +from xautodl.datasets import get_datasets +from xautodl.models import CellStructure, get_cell_based_tiny_net, get_search_spaces +from xautodl.procedures import ( + bench_pure_evaluate as pure_evaluate, + get_nas_bench_loaders, +) +from xautodl.utils import get_md5_file + from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount -from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders -from utils import get_md5_file from nas_201_api import NASBench201API diff --git a/exps/NATS-Bench/tss-collect.py b/exps/NATS-Bench/tss-collect.py index aee2a6b..af06d43 100644 --- a/exps/NATS-Bench/tss-collect.py +++ b/exps/NATS-Bench/tss-collect.py @@ -19,16 +19,16 @@ from pathlib import Path from collections import defaultdict, OrderedDict from typing import Dict, Any, Text, List -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from log_utils import AverageMeter, time_string, convert_secs2time -from config_utils import load_config, dict2config -from datasets import get_datasets -from models import CellStructure, get_cell_based_tiny_net, get_search_spaces +from xautodl.log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.config_utils import load_config, dict2config +from xautodl.datasets import get_datasets +from xautodl.models import CellStructure, get_cell_based_tiny_net, get_search_spaces +from xautodl.procedures import ( + bench_pure_evaluate as pure_evaluate, + get_nas_bench_loaders, +) +from xautodl.utils import get_md5_file from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount -from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders -from utils import get_md5_file from nas_201_api import NASBench201API diff --git a/exps/algos/BOHB.py b/exps/algos/BOHB.py index f6e3f49..86f740c 100644 --- a/exps/algos/BOHB.py +++ b/exps/algos/BOHB.py @@ -12,15 +12,12 @@ from copy import deepcopy from pathlib import Path import torch -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import load_config -from datasets import get_datasets, SearchDataset -from procedures import prepare_seed, prepare_logger -from log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.config_utils import load_config +from xautodl.datasets import get_datasets, SearchDataset +from xautodl.procedures import prepare_seed, prepare_logger +from xautodl.log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.models import CellStructure, get_search_spaces from nas_201_api import NASBench201API as API -from models import CellStructure, get_search_spaces # BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018 import ConfigSpace diff --git a/exps/algos/DARTS-V1.py b/exps/algos/DARTS-V1.py index b1b409f..67441af 100644 --- a/exps/algos/DARTS-V1.py +++ b/exps/algos/DARTS-V1.py @@ -8,21 +8,18 @@ from copy import deepcopy import torch from pathlib import Path -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, get_nas_search_loaders -from procedures import ( +from xautodl.config_utils import load_config, dict2config, configure2str +from xautodl.datasets import get_datasets, get_nas_search_loaders +from xautodl.procedures import ( prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler, ) -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_cell_based_tiny_net, get_search_spaces +from xautodl.utils import get_model_infos, obtain_accuracy +from xautodl.log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.models import get_cell_based_tiny_net, get_search_spaces from nas_201_api import NASBench201API as API diff --git a/exps/algos/DARTS-V2.py b/exps/algos/DARTS-V2.py index 6739ffc..d6bd146 100644 --- a/exps/algos/DARTS-V2.py +++ b/exps/algos/DARTS-V2.py @@ -8,23 +8,19 @@ import numpy as np from copy import deepcopy import torch import torch.nn as nn -from pathlib import Path -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, get_nas_search_loaders -from procedures import ( +from xautodl.config_utils import load_config, dict2config, configure2str +from xautodl.datasets import get_datasets, get_nas_search_loaders +from xautodl.procedures import ( prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler, ) -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_cell_based_tiny_net, get_search_spaces +from xautodl.utils import get_model_infos, obtain_accuracy +from xautodl.log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.models import get_cell_based_tiny_net, get_search_spaces from nas_201_api import NASBench201API as API @@ -443,7 +439,7 @@ def main(xargs): if __name__ == "__main__": parser = argparse.ArgumentParser("DARTS Second Order") - parser.add_argument("--data_path", type=str, help="Path to dataset") + parser.add_argument("--data_path", type=str, help="The path to dataset") parser.add_argument( "--dataset", type=str, diff --git a/exps/algos/ENAS.py b/exps/algos/ENAS.py index 7b8dd1d..43e4c1b 100644 --- a/exps/algos/ENAS.py +++ b/exps/algos/ENAS.py @@ -8,23 +8,19 @@ import numpy as np from copy import deepcopy import torch import torch.nn as nn -from pathlib import Path -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, get_nas_search_loaders -from procedures import ( +from xautodl.config_utils import load_config, dict2config, configure2str +from xautodl.datasets import get_datasets, get_nas_search_loaders +from xautodl.procedures import ( prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler, ) -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_cell_based_tiny_net, get_search_spaces +from xautodl.utils import get_model_infos, obtain_accuracy +from xautodl.log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.models import get_cell_based_tiny_net, get_search_spaces from nas_201_api import NASBench201API as API @@ -527,7 +523,7 @@ def main(xargs): if __name__ == "__main__": parser = argparse.ArgumentParser("ENAS") - parser.add_argument("--data_path", type=str, help="Path to dataset") + parser.add_argument("--data_path", type=str, help="The path to dataset") parser.add_argument( "--dataset", type=str, diff --git a/exps/algos/GDAS.py b/exps/algos/GDAS.py index bc760f3..d030cb8 100644 --- a/exps/algos/GDAS.py +++ b/exps/algos/GDAS.py @@ -6,23 +6,18 @@ import sys, time, random, argparse from copy import deepcopy import torch -from pathlib import Path - -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import load_config, dict2config -from datasets import get_datasets, get_nas_search_loaders -from procedures import ( +from xautodl.config_utils import load_config, dict2config +from xautodl.datasets import get_datasets, get_nas_search_loaders +from xautodl.procedures import ( prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler, ) -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_cell_based_tiny_net, get_search_spaces +from xautodl.utils import get_model_infos, obtain_accuracy +from xautodl.log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.models import get_cell_based_tiny_net, get_search_spaces from nas_201_api import NASBench201API as API @@ -343,7 +338,7 @@ def main(xargs): if __name__ == "__main__": parser = argparse.ArgumentParser("GDAS") - parser.add_argument("--data_path", type=str, help="Path to dataset") + parser.add_argument("--data_path", type=str, help="The path to dataset") parser.add_argument( "--dataset", type=str, diff --git a/exps/algos/RANDOM-NAS.py b/exps/algos/RANDOM-NAS.py index 90ffd44..51192b4 100644 --- a/exps/algos/RANDOM-NAS.py +++ b/exps/algos/RANDOM-NAS.py @@ -8,23 +8,19 @@ import numpy as np from copy import deepcopy import torch import torch.nn as nn -from pathlib import Path -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, get_nas_search_loaders -from procedures import ( +from xautodl.config_utils import load_config, dict2config, configure2str +from xautodl.datasets import get_datasets, get_nas_search_loaders +from xautodl.procedures import ( prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler, ) -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_cell_based_tiny_net, get_search_spaces +from xautodl.utils import get_model_infos, obtain_accuracy +from xautodl.log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.models import get_cell_based_tiny_net, get_search_spaces from nas_201_api import NASBench201API as API @@ -335,7 +331,7 @@ def main(xargs): if __name__ == "__main__": parser = argparse.ArgumentParser("Random search for NAS.") - parser.add_argument("--data_path", type=str, help="Path to dataset") + parser.add_argument("--data_path", type=str, help="The path to dataset") parser.add_argument( "--dataset", type=str, diff --git a/exps/algos/RANDOM.py b/exps/algos/RANDOM.py index 8bf9b92..88349f5 100644 --- a/exps/algos/RANDOM.py +++ b/exps/algos/RANDOM.py @@ -8,21 +8,19 @@ import torch import torch.nn as nn from pathlib import Path -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, SearchDataset -from procedures import ( +from xautodl.config_utils import load_config, dict2config, configure2str +from xautodl.datasets import get_datasets, SearchDataset +from xautodl.procedures import ( prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler, ) -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_search_spaces +from xautodl.utils import get_model_infos, obtain_accuracy +from xautodl.log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.models import get_search_spaces + from nas_201_api import NASBench201API as API from R_EA import train_and_eval, random_architecture_func diff --git a/exps/algos/README.md b/exps/algos/README.md index 91121ac..edd4e7e 100644 --- a/exps/algos/README.md +++ b/exps/algos/README.md @@ -3,3 +3,5 @@ The Python files in this folder are used to re-produce the results in our NAS-Bench-201 paper. We have upgraded the codes to be more general and extendable at [NATS-algos](https://github.com/D-X-Y/AutoDL-Projects/tree/main/exps/NATS-algos). + +**Notice** On 24 May 2021, the codes in `AutoDL` repo have been re-organized. If you find `module not found` error, please let me know. I will fix them ASAP. diff --git a/exps/algos/R_EA.py b/exps/algos/R_EA.py index fdf79a1..d18a098 100644 --- a/exps/algos/R_EA.py +++ b/exps/algos/R_EA.py @@ -10,22 +10,19 @@ import torch import torch.nn as nn from pathlib import Path -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, SearchDataset -from procedures import ( +from xautodl.config_utils import load_config, dict2config, configure2str +from xautodl.datasets import get_datasets, SearchDataset +from xautodl.procedures import ( prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler, ) -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.utils import get_model_infos, obtain_accuracy +from xautodl.log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.models import CellStructure, get_search_spaces from nas_201_api import NASBench201API as API -from models import CellStructure, get_search_spaces class Model(object): diff --git a/exps/algos/SETN.py b/exps/algos/SETN.py index 3e048b9..327df31 100644 --- a/exps/algos/SETN.py +++ b/exps/algos/SETN.py @@ -10,21 +10,18 @@ import torch import torch.nn as nn from pathlib import Path -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, get_nas_search_loaders -from procedures import ( +from xautodl.config_utils import load_config, dict2config, configure2str +from xautodl.datasets import get_datasets, get_nas_search_loaders +from xautodl.procedures import ( prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler, ) -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time -from models import get_cell_based_tiny_net, get_search_spaces +from xautodl.utils import get_model_infos, obtain_accuracy +from xautodl.log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.models import get_cell_based_tiny_net, get_search_spaces from nas_201_api import NASBench201API as API diff --git a/exps/algos/reinforce.py b/exps/algos/reinforce.py index 0122aaa..5ef02da 100644 --- a/exps/algos/reinforce.py +++ b/exps/algos/reinforce.py @@ -11,22 +11,19 @@ import torch import torch.nn as nn from torch.distributions import Categorical -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from config_utils import load_config, dict2config, configure2str -from datasets import get_datasets, SearchDataset -from procedures import ( +from xautodl.config_utils import load_config, dict2config, configure2str +from xautodl.datasets import get_datasets, SearchDataset +from xautodl.procedures import ( prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler, ) -from utils import get_model_infos, obtain_accuracy -from log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.utils import get_model_infos, obtain_accuracy +from xautodl.log_utils import AverageMeter, time_string, convert_secs2time +from xautodl.models import CellStructure, get_search_spaces from nas_201_api import NASBench201API as API -from models import CellStructure, get_search_spaces from R_EA import train_and_eval diff --git a/lib b/lib deleted file mode 120000 index fccce11..0000000 --- a/lib +++ /dev/null @@ -1 +0,0 @@ -xautodl \ No newline at end of file