Fix NATS-Bench and NAS-Bench-201 for removing lib

This commit is contained in:
D-X-Y 2021-05-24 11:04:18 +08:00
parent c5788ba19c
commit da2575cc6c
31 changed files with 136 additions and 216 deletions

View File

@ -8,10 +8,7 @@ import torch
from pathlib import Path from pathlib import Path
from collections import defaultdict from collections import defaultdict
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from log_utils import AverageMeter, time_string, convert_secs2time
def check_files(save_dir, meta_file, basestr): def check_files(save_dir, meta_file, basestr):

View File

@ -10,16 +10,13 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import load_config
if str(lib_dir) not in sys.path: from xautodl.procedures import save_checkpoint, copy_checkpoint
sys.path.insert(0, str(lib_dir)) from xautodl.procedures import get_machine_info
from config_utils import load_config from xautodl.datasets import get_datasets
from procedures import save_checkpoint, copy_checkpoint from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
from procedures import get_machine_info from xautodl.models import CellStructure, CellArchitectures, get_search_spaces
from datasets import get_datasets from xautodl.functions import evaluate_for_seed
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
from models import CellStructure, CellArchitectures, get_search_spaces
from functions import evaluate_for_seed
def evaluate_all_datasets( def evaluate_all_datasets(

View File

@ -3,12 +3,9 @@
################################################################################################ ################################################################################################
# python exps/NAS-Bench-201/show-best.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth # # python exps/NAS-Bench-201/show-best.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth #
################################################################################################ ################################################################################################
import sys, argparse import argparse
from pathlib import Path from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from nas_201_api import NASBench201API as API from nas_201_api import NASBench201API as API
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -8,16 +8,13 @@ from pathlib import Path
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict
from typing import Dict, Any, Text, List from typing import Dict, Any, Text, List
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
if str(lib_dir) not in sys.path: from xautodl.config_utils import dict2config
sys.path.insert(0, str(lib_dir))
from log_utils import AverageMeter, time_string, convert_secs2time
from config_utils import dict2config
# NAS-Bench-201 related module or function # NAS-Bench-201 related module or function
from models import CellStructure, get_cell_based_tiny_net from xautodl.models import CellStructure, get_cell_based_tiny_net
from xautodl.procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
from nas_201_api import NASBench201API, ArchResults, ResultsCount from nas_201_api import NASBench201API, ArchResults, ResultsCount
from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
api = NASBench201API( api = NASBench201API(
"{:}/.torch/NAS-Bench-201-v1_0-e61699.pth".format(os.environ["HOME"]) "{:}/.torch/NAS-Bench-201-v1_0-e61699.pth".format(os.environ["HOME"])

View File

@ -7,17 +7,14 @@ import torch
from pathlib import Path from pathlib import Path
from collections import defaultdict from collections import defaultdict
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
if str(lib_dir) not in sys.path: from xautodl.config_utils import load_config, dict2config
sys.path.insert(0, str(lib_dir)) from xautodl.datasets import get_datasets
from log_utils import AverageMeter, time_string, convert_secs2time
from config_utils import load_config, dict2config
from datasets import get_datasets
# NAS-Bench-201 related module or function # NAS-Bench-201 related module or function
from models import CellStructure, get_cell_based_tiny_net from xautodl.models import CellStructure, get_cell_based_tiny_net
from xautodl.procedures import bench_pure_evaluate as pure_evaluate
from nas_201_api import ArchResults, ResultsCount from nas_201_api import ArchResults, ResultsCount
from procedures import bench_pure_evaluate as pure_evaluate
def create_result_count(used_seed, dataset, arch_config, results, dataloader_dict): def create_result_count(used_seed, dataset, arch_config, results, dataloader_dict):

View File

@ -10,11 +10,8 @@ from tqdm import tqdm
import torch import torch
from pathlib import Path from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.log_utils import time_string
if str(lib_dir) not in sys.path: from xautodl.models import CellStructure
sys.path.insert(0, str(lib_dir))
from log_utils import time_string
from models import CellStructure
from nas_201_api import NASBench201API as API from nas_201_api import NASBench201API as API

View File

@ -17,10 +17,7 @@ from mpl_toolkits.mplot3d import Axes3D
matplotlib.use("agg") matplotlib.use("agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.log_utils import time_string
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from log_utils import time_string
from nas_201_api import NASBench201API as API from nas_201_api import NASBench201API as API

View File

@ -8,11 +8,8 @@
import os, sys, time, tqdm, argparse import os, sys, time, tqdm, argparse
from pathlib import Path from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import dict2config, load_config
if str(lib_dir) not in sys.path: from xautodl.datasets import get_datasets
sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from datasets import get_datasets
from nats_bench import create from nats_bench import create

View File

@ -19,12 +19,9 @@ matplotlib.use("agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.ticker as ticker import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import dict2config, load_config
if str(lib_dir) not in sys.path: from xautodl.log_utils import time_string
sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from nats_bench import create from nats_bench import create
from log_utils import time_string
def get_valid_test_acc(api, arch, dataset): def get_valid_test_acc(api, arch, dataset):

View File

@ -20,13 +20,9 @@ import seaborn as sns
matplotlib.use("agg") matplotlib.use("agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.ticker as ticker import matplotlib.ticker as ticker
from xautodl.config_utils import dict2config, load_config
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.log_utils import time_string
if str(lib_dir) not in sys.path: from xautodl.models import get_cell_based_tiny_net
sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from log_utils import time_string
from models import get_cell_based_tiny_net
from nats_bench import create from nats_bench import create

View File

@ -21,12 +21,9 @@ matplotlib.use("agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.ticker as ticker import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import dict2config, load_config
if str(lib_dir) not in sys.path: from xautodl.log_utils import time_string
sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from nats_bench import create from nats_bench import create
from log_utils import time_string
def fetch_data(root_dir="./output/search", search_space="tss", dataset=None): def fetch_data(root_dir="./output/search", search_space="tss", dataset=None):

View File

@ -20,12 +20,9 @@ matplotlib.use("agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.ticker as ticker import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import dict2config, load_config
if str(lib_dir) not in sys.path: from xautodl.log_utils import time_string
sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from nats_bench import create from nats_bench import create
from log_utils import time_string
def get_valid_test_acc(api, arch, dataset): def get_valid_test_acc(api, arch, dataset):

View File

@ -20,12 +20,9 @@ matplotlib.use("agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.ticker as ticker import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import dict2config, load_config
if str(lib_dir) not in sys.path: from xautodl.log_utils import time_string
sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from nats_bench import create from nats_bench import create
from log_utils import time_string
plt.rcParams.update( plt.rcParams.update(

View File

@ -21,12 +21,9 @@ matplotlib.use("agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.ticker as ticker import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import dict2config, load_config
if str(lib_dir) not in sys.path: from xautodl.log_utils import time_string
sys.path.insert(0, str(lib_dir)) from xautodl.models import get_cell_based_tiny_net
from config_utils import dict2config, load_config
from log_utils import time_string
from models import get_cell_based_tiny_net
from nats_bench import create from nats_bench import create

View File

@ -20,12 +20,9 @@ matplotlib.use("agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.ticker as ticker import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import dict2config, load_config
if str(lib_dir) not in sys.path: from xautodl.log_utils import time_string
sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from nats_bench import create from nats_bench import create
from log_utils import time_string
def fetch_data(root_dir="./output/search", search_space="tss", dataset=None): def fetch_data(root_dir="./output/search", search_space="tss", dataset=None):

View File

@ -17,15 +17,16 @@ from pathlib import Path
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict
from typing import Dict, Any, Text, List from typing import Dict, Any, Text, List
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
if str(lib_dir) not in sys.path: from xautodl.config_utils import dict2config
sys.path.insert(0, str(lib_dir)) from xautodl.models import CellStructure, get_cell_based_tiny_net
from log_utils import AverageMeter, time_string, convert_secs2time from xautodl.procedures import (
from config_utils import dict2config bench_pure_evaluate as pure_evaluate,
from models import CellStructure, get_cell_based_tiny_net get_nas_bench_loaders,
)
from xautodl.utils import get_md5_file
from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount
from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
from utils import get_md5_file
NATS_SSS_BASE_NAME = "NATS-sss-v1_0" # 2020.08.28 NATS_SSS_BASE_NAME = "NATS-sss-v1_0" # 2020.08.28

View File

@ -19,13 +19,10 @@ matplotlib.use("agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.ticker as ticker import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import dict2config, load_config
if str(lib_dir) not in sys.path: from xautodl.log_utils import time_string
sys.path.insert(0, str(lib_dir)) from xautodl.models import get_cell_based_tiny_net, CellStructure
from config_utils import dict2config, load_config
from nats_bench import create from nats_bench import create
from log_utils import time_string
from models import get_cell_based_tiny_net, CellStructure
def test_api(api, sss_or_tss=True): def test_api(api, sss_or_tss=True):

View File

@ -19,16 +19,17 @@ from pathlib import Path
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict
from typing import Dict, Any, Text, List from typing import Dict, Any, Text, List
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
if str(lib_dir) not in sys.path: from xautodl.config_utils import load_config, dict2config
sys.path.insert(0, str(lib_dir)) from xautodl.datasets import get_datasets
from log_utils import AverageMeter, time_string, convert_secs2time from xautodl.models import CellStructure, get_cell_based_tiny_net, get_search_spaces
from config_utils import load_config, dict2config from xautodl.procedures import (
from datasets import get_datasets bench_pure_evaluate as pure_evaluate,
from models import CellStructure, get_cell_based_tiny_net, get_search_spaces get_nas_bench_loaders,
)
from xautodl.utils import get_md5_file
from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount
from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
from utils import get_md5_file
from nas_201_api import NASBench201API from nas_201_api import NASBench201API

View File

@ -19,16 +19,16 @@ from pathlib import Path
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict
from typing import Dict, Any, Text, List from typing import Dict, Any, Text, List
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
if str(lib_dir) not in sys.path: from xautodl.config_utils import load_config, dict2config
sys.path.insert(0, str(lib_dir)) from xautodl.datasets import get_datasets
from log_utils import AverageMeter, time_string, convert_secs2time from xautodl.models import CellStructure, get_cell_based_tiny_net, get_search_spaces
from config_utils import load_config, dict2config from xautodl.procedures import (
from datasets import get_datasets bench_pure_evaluate as pure_evaluate,
from models import CellStructure, get_cell_based_tiny_net, get_search_spaces get_nas_bench_loaders,
)
from xautodl.utils import get_md5_file
from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount
from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
from utils import get_md5_file
from nas_201_api import NASBench201API from nas_201_api import NASBench201API

View File

@ -12,15 +12,12 @@ from copy import deepcopy
from pathlib import Path from pathlib import Path
import torch import torch
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import load_config
if str(lib_dir) not in sys.path: from xautodl.datasets import get_datasets, SearchDataset
sys.path.insert(0, str(lib_dir)) from xautodl.procedures import prepare_seed, prepare_logger
from config_utils import load_config from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from datasets import get_datasets, SearchDataset from xautodl.models import CellStructure, get_search_spaces
from procedures import prepare_seed, prepare_logger
from log_utils import AverageMeter, time_string, convert_secs2time
from nas_201_api import NASBench201API as API from nas_201_api import NASBench201API as API
from models import CellStructure, get_search_spaces
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018 # BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018
import ConfigSpace import ConfigSpace

View File

@ -8,21 +8,18 @@ from copy import deepcopy
import torch import torch
from pathlib import Path from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import load_config, dict2config, configure2str
if str(lib_dir) not in sys.path: from xautodl.datasets import get_datasets, get_nas_search_loaders
sys.path.insert(0, str(lib_dir)) from xautodl.procedures import (
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, get_nas_search_loaders
from procedures import (
prepare_seed, prepare_seed,
prepare_logger, prepare_logger,
save_checkpoint, save_checkpoint,
copy_checkpoint, copy_checkpoint,
get_optim_scheduler, get_optim_scheduler,
) )
from utils import get_model_infos, obtain_accuracy from xautodl.utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces from xautodl.models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API from nas_201_api import NASBench201API as API

View File

@ -8,23 +8,19 @@ import numpy as np
from copy import deepcopy from copy import deepcopy
import torch import torch
import torch.nn as nn import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import load_config, dict2config, configure2str
if str(lib_dir) not in sys.path: from xautodl.datasets import get_datasets, get_nas_search_loaders
sys.path.insert(0, str(lib_dir)) from xautodl.procedures import (
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, get_nas_search_loaders
from procedures import (
prepare_seed, prepare_seed,
prepare_logger, prepare_logger,
save_checkpoint, save_checkpoint,
copy_checkpoint, copy_checkpoint,
get_optim_scheduler, get_optim_scheduler,
) )
from utils import get_model_infos, obtain_accuracy from xautodl.utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces from xautodl.models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API from nas_201_api import NASBench201API as API
@ -443,7 +439,7 @@ def main(xargs):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser("DARTS Second Order") parser = argparse.ArgumentParser("DARTS Second Order")
parser.add_argument("--data_path", type=str, help="Path to dataset") parser.add_argument("--data_path", type=str, help="The path to dataset")
parser.add_argument( parser.add_argument(
"--dataset", "--dataset",
type=str, type=str,

View File

@ -8,23 +8,19 @@ import numpy as np
from copy import deepcopy from copy import deepcopy
import torch import torch
import torch.nn as nn import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import load_config, dict2config, configure2str
if str(lib_dir) not in sys.path: from xautodl.datasets import get_datasets, get_nas_search_loaders
sys.path.insert(0, str(lib_dir)) from xautodl.procedures import (
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, get_nas_search_loaders
from procedures import (
prepare_seed, prepare_seed,
prepare_logger, prepare_logger,
save_checkpoint, save_checkpoint,
copy_checkpoint, copy_checkpoint,
get_optim_scheduler, get_optim_scheduler,
) )
from utils import get_model_infos, obtain_accuracy from xautodl.utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces from xautodl.models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API from nas_201_api import NASBench201API as API
@ -527,7 +523,7 @@ def main(xargs):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser("ENAS") parser = argparse.ArgumentParser("ENAS")
parser.add_argument("--data_path", type=str, help="Path to dataset") parser.add_argument("--data_path", type=str, help="The path to dataset")
parser.add_argument( parser.add_argument(
"--dataset", "--dataset",
type=str, type=str,

View File

@ -6,23 +6,18 @@
import sys, time, random, argparse import sys, time, random, argparse
from copy import deepcopy from copy import deepcopy
import torch import torch
from pathlib import Path from xautodl.config_utils import load_config, dict2config
from xautodl.datasets import get_datasets, get_nas_search_loaders
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.procedures import (
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config
from datasets import get_datasets, get_nas_search_loaders
from procedures import (
prepare_seed, prepare_seed,
prepare_logger, prepare_logger,
save_checkpoint, save_checkpoint,
copy_checkpoint, copy_checkpoint,
get_optim_scheduler, get_optim_scheduler,
) )
from utils import get_model_infos, obtain_accuracy from xautodl.utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces from xautodl.models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API from nas_201_api import NASBench201API as API
@ -343,7 +338,7 @@ def main(xargs):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser("GDAS") parser = argparse.ArgumentParser("GDAS")
parser.add_argument("--data_path", type=str, help="Path to dataset") parser.add_argument("--data_path", type=str, help="The path to dataset")
parser.add_argument( parser.add_argument(
"--dataset", "--dataset",
type=str, type=str,

View File

@ -8,23 +8,19 @@ import numpy as np
from copy import deepcopy from copy import deepcopy
import torch import torch
import torch.nn as nn import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import load_config, dict2config, configure2str
if str(lib_dir) not in sys.path: from xautodl.datasets import get_datasets, get_nas_search_loaders
sys.path.insert(0, str(lib_dir)) from xautodl.procedures import (
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, get_nas_search_loaders
from procedures import (
prepare_seed, prepare_seed,
prepare_logger, prepare_logger,
save_checkpoint, save_checkpoint,
copy_checkpoint, copy_checkpoint,
get_optim_scheduler, get_optim_scheduler,
) )
from utils import get_model_infos, obtain_accuracy from xautodl.utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces from xautodl.models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API from nas_201_api import NASBench201API as API
@ -335,7 +331,7 @@ def main(xargs):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser("Random search for NAS.") parser = argparse.ArgumentParser("Random search for NAS.")
parser.add_argument("--data_path", type=str, help="Path to dataset") parser.add_argument("--data_path", type=str, help="The path to dataset")
parser.add_argument( parser.add_argument(
"--dataset", "--dataset",
type=str, type=str,

View File

@ -8,21 +8,19 @@ import torch
import torch.nn as nn import torch.nn as nn
from pathlib import Path from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import load_config, dict2config, configure2str
if str(lib_dir) not in sys.path: from xautodl.datasets import get_datasets, SearchDataset
sys.path.insert(0, str(lib_dir)) from xautodl.procedures import (
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import (
prepare_seed, prepare_seed,
prepare_logger, prepare_logger,
save_checkpoint, save_checkpoint,
copy_checkpoint, copy_checkpoint,
get_optim_scheduler, get_optim_scheduler,
) )
from utils import get_model_infos, obtain_accuracy from xautodl.utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from models import get_search_spaces from xautodl.models import get_search_spaces
from nas_201_api import NASBench201API as API from nas_201_api import NASBench201API as API
from R_EA import train_and_eval, random_architecture_func from R_EA import train_and_eval, random_architecture_func

View File

@ -3,3 +3,5 @@
The Python files in this folder are used to re-produce the results in our NAS-Bench-201 paper. The Python files in this folder are used to re-produce the results in our NAS-Bench-201 paper.
We have upgraded the codes to be more general and extendable at [NATS-algos](https://github.com/D-X-Y/AutoDL-Projects/tree/main/exps/NATS-algos). We have upgraded the codes to be more general and extendable at [NATS-algos](https://github.com/D-X-Y/AutoDL-Projects/tree/main/exps/NATS-algos).
**Notice** On 24 May 2021, the codes in `AutoDL` repo have been re-organized. If you find `module not found` error, please let me know. I will fix them ASAP.

View File

@ -10,22 +10,19 @@ import torch
import torch.nn as nn import torch.nn as nn
from pathlib import Path from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import load_config, dict2config, configure2str
if str(lib_dir) not in sys.path: from xautodl.datasets import get_datasets, SearchDataset
sys.path.insert(0, str(lib_dir)) from xautodl.procedures import (
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import (
prepare_seed, prepare_seed,
prepare_logger, prepare_logger,
save_checkpoint, save_checkpoint,
copy_checkpoint, copy_checkpoint,
get_optim_scheduler, get_optim_scheduler,
) )
from utils import get_model_infos, obtain_accuracy from xautodl.utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import CellStructure, get_search_spaces
from nas_201_api import NASBench201API as API from nas_201_api import NASBench201API as API
from models import CellStructure, get_search_spaces
class Model(object): class Model(object):

View File

@ -10,21 +10,18 @@ import torch
import torch.nn as nn import torch.nn as nn
from pathlib import Path from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import load_config, dict2config, configure2str
if str(lib_dir) not in sys.path: from xautodl.datasets import get_datasets, get_nas_search_loaders
sys.path.insert(0, str(lib_dir)) from xautodl.procedures import (
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, get_nas_search_loaders
from procedures import (
prepare_seed, prepare_seed,
prepare_logger, prepare_logger,
save_checkpoint, save_checkpoint,
copy_checkpoint, copy_checkpoint,
get_optim_scheduler, get_optim_scheduler,
) )
from utils import get_model_infos, obtain_accuracy from xautodl.utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces from xautodl.models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API from nas_201_api import NASBench201API as API

View File

@ -11,22 +11,19 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.distributions import Categorical from torch.distributions import Categorical
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.config_utils import load_config, dict2config, configure2str
if str(lib_dir) not in sys.path: from xautodl.datasets import get_datasets, SearchDataset
sys.path.insert(0, str(lib_dir)) from xautodl.procedures import (
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import (
prepare_seed, prepare_seed,
prepare_logger, prepare_logger,
save_checkpoint, save_checkpoint,
copy_checkpoint, copy_checkpoint,
get_optim_scheduler, get_optim_scheduler,
) )
from utils import get_model_infos, obtain_accuracy from xautodl.utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import CellStructure, get_search_spaces
from nas_201_api import NASBench201API as API from nas_201_api import NASBench201API as API
from models import CellStructure, get_search_spaces
from R_EA import train_and_eval from R_EA import train_and_eval

1
lib
View File

@ -1 +0,0 @@
xautodl