Fix bugs
This commit is contained in:
		| @@ -1,14 +1,18 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||||
| ##################################################### | ##################################################### | ||||||
| # python exps/LFNA/basic-prev.py --env_version v1 --prev_time 5 --hidden_dim 16 --epochs 500 --init_lr 0.1 | # python exps/GeMOSA/basic-prev.py --env_version v1 --prev_time 5 --hidden_dim 16 --epochs 500 --init_lr 0.1 | ||||||
| # python exps/LFNA/basic-prev.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05 | # python exps/GeMOSA/basic-prev.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05 | ||||||
| ##################################################### | ##################################################### | ||||||
| import sys, time, copy, torch, random, argparse | import sys, time, copy, torch, random, argparse | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  |  | ||||||
|  | lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||||
|  | if str(lib_dir) not in sys.path: | ||||||
|  |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
| from xautodl.procedures import ( | from xautodl.procedures import ( | ||||||
|     prepare_seed, |     prepare_seed, | ||||||
|     prepare_logger, |     prepare_logger, | ||||||
| @@ -38,9 +42,9 @@ def subsample(historical_x, historical_y, maxn=10000): | |||||||
|  |  | ||||||
|  |  | ||||||
| def main(args): | def main(args): | ||||||
|     logger, env_info, model_kwargs = lfna_setup(args) |     logger, model_kwargs = lfna_setup(args) | ||||||
|  |  | ||||||
|     w_container_per_epoch = dict() |     w_containers = dict() | ||||||
|  |  | ||||||
|     per_timestamp_time, start_time = AverageMeter(), time.time() |     per_timestamp_time, start_time = AverageMeter(), time.time() | ||||||
|     for idx in range(args.prev_time, env_info["total"]): |     for idx in range(args.prev_time, env_info["total"]): | ||||||
| @@ -111,7 +115,7 @@ def main(args): | |||||||
|         save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( |         save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( | ||||||
|             idx, env_info["total"] |             idx, env_info["total"] | ||||||
|         ) |         ) | ||||||
|         w_container_per_epoch[idx] = model.get_w_container().no_grad_clone() |         w_containers[idx] = model.get_w_container().no_grad_clone() | ||||||
|         save_checkpoint( |         save_checkpoint( | ||||||
|             { |             { | ||||||
|                 "model_state_dict": model.state_dict(), |                 "model_state_dict": model.state_dict(), | ||||||
| @@ -127,7 +131,7 @@ def main(args): | |||||||
|         start_time = time.time() |         start_time = time.time() | ||||||
|  |  | ||||||
|     save_checkpoint( |     save_checkpoint( | ||||||
|         {"w_container_per_epoch": w_container_per_epoch}, |         {"w_containers": w_containers}, | ||||||
|         logger.path(None) / "final-ckp.pth", |         logger.path(None) / "final-ckp.pth", | ||||||
|         logger, |         logger, | ||||||
|     ) |     ) | ||||||
|   | |||||||
| @@ -68,6 +68,8 @@ def main(args): | |||||||
|         # build model |         # build model | ||||||
|         model = get_model(**model_kwargs) |         model = get_model(**model_kwargs) | ||||||
|         model = model.to(args.device) |         model = model.to(args.device) | ||||||
|  |         if idx == 0: | ||||||
|  |             print(model) | ||||||
|         # build optimizer |         # build optimizer | ||||||
|         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) |         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) | ||||||
|         criterion = torch.nn.MSELoss() |         criterion = torch.nn.MSELoss() | ||||||
|   | |||||||
| @@ -16,7 +16,7 @@ def lfna_setup(args): | |||||||
|         input_dim=1, |         input_dim=1, | ||||||
|         output_dim=1, |         output_dim=1, | ||||||
|         hidden_dims=[args.hidden_dim] * 2, |         hidden_dims=[args.hidden_dim] * 2, | ||||||
|         act_cls="gelu", |         act_cls="relu", | ||||||
|         norm_cls="layer_norm_1d", |         norm_cls="layer_norm_1d", | ||||||
|     ) |     ) | ||||||
|     return logger, model_kwargs |     return logger, model_kwargs | ||||||
|   | |||||||
| @@ -23,10 +23,12 @@ if str(lib_dir) not in sys.path: | |||||||
| import qlib | import qlib | ||||||
| from qlib import config as qconfig | from qlib import config as qconfig | ||||||
| from qlib.workflow import R | from qlib.workflow import R | ||||||
| qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN) |  | ||||||
|  | qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN) | ||||||
|  |  | ||||||
| from utils.qlib_utils import QResult | from utils.qlib_utils import QResult | ||||||
|  |  | ||||||
|  |  | ||||||
| def filter_finished(recorders): | def filter_finished(recorders): | ||||||
|     returned_recorders = dict() |     returned_recorders = dict() | ||||||
|     not_finished = 0 |     not_finished = 0 | ||||||
| @@ -44,6 +46,7 @@ def add_to_dict(xdict, timestamp, value): | |||||||
|         raise ValueError("This date [{:}] is already in the dict".format(date)) |         raise ValueError("This date [{:}] is already in the dict".format(date)) | ||||||
|     xdict[date] = value |     xdict[date] = value | ||||||
|  |  | ||||||
|  |  | ||||||
| def query_info(save_dir, verbose, name_filter, key_map): | def query_info(save_dir, verbose, name_filter, key_map): | ||||||
|     if isinstance(save_dir, list): |     if isinstance(save_dir, list): | ||||||
|         results = [] |         results = [] | ||||||
| @@ -61,7 +64,10 @@ def query_info(save_dir, verbose, name_filter, key_map): | |||||||
|     for idx, (key, experiment) in enumerate(experiments.items()): |     for idx, (key, experiment) in enumerate(experiments.items()): | ||||||
|         if experiment.id == "0": |         if experiment.id == "0": | ||||||
|             continue |             continue | ||||||
|         if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None: |         if ( | ||||||
|  |             name_filter is not None | ||||||
|  |             and re.fullmatch(name_filter, experiment.name) is None | ||||||
|  |         ): | ||||||
|             continue |             continue | ||||||
|         recorders = experiment.list_recorders() |         recorders = experiment.list_recorders() | ||||||
|         recorders, not_finished = filter_finished(recorders) |         recorders, not_finished = filter_finished(recorders) | ||||||
| @@ -77,10 +83,10 @@ def query_info(save_dir, verbose, name_filter, key_map): | |||||||
|             ) |             ) | ||||||
|         result = QResult(experiment.name) |         result = QResult(experiment.name) | ||||||
|         for recorder_id, recorder in recorders.items(): |         for recorder_id, recorder in recorders.items(): | ||||||
|             file_names = ['results-train.pkl', 'results-valid.pkl', 'results-test.pkl'] |             file_names = ["results-train.pkl", "results-valid.pkl", "results-test.pkl"] | ||||||
|             date2IC = OrderedDict() |             date2IC = OrderedDict() | ||||||
|             for file_name in file_names: |             for file_name in file_names: | ||||||
|                 xtemp = recorder.load_object(file_name)['all-IC'] |                 xtemp = recorder.load_object(file_name)["all-IC"] | ||||||
|                 timestamps, values = xtemp.index.tolist(), xtemp.tolist() |                 timestamps, values = xtemp.index.tolist(), xtemp.tolist() | ||||||
|                 for timestamp, value in zip(timestamps, values): |                 for timestamp, value in zip(timestamps, values): | ||||||
|                     add_to_dict(date2IC, timestamp, value) |                     add_to_dict(date2IC, timestamp, value) | ||||||
| @@ -104,7 +110,7 @@ def query_info(save_dir, verbose, name_filter, key_map): | |||||||
|  |  | ||||||
|  |  | ||||||
| ## | ## | ||||||
| paths = [root_dir / 'outputs' / 'qlib-baselines-csi300'] | paths = [root_dir / "outputs" / "qlib-baselines-csi300"] | ||||||
| paths = [path.resolve() for path in paths] | paths = [path.resolve() for path in paths] | ||||||
| print(paths) | print(paths) | ||||||
|  |  | ||||||
| @@ -112,12 +118,12 @@ key_map = dict() | |||||||
| for xset in ("train", "valid", "test"): | for xset in ("train", "valid", "test"): | ||||||
|     key_map["{:}-mean-IC".format(xset)] = "IC ({:})".format(xset) |     key_map["{:}-mean-IC".format(xset)] = "IC ({:})".format(xset) | ||||||
|     key_map["{:}-mean-ICIR".format(xset)] = "ICIR ({:})".format(xset) |     key_map["{:}-mean-ICIR".format(xset)] = "ICIR ({:})".format(xset) | ||||||
| qresults = query_info(paths, False, 'TSF-2x24-drop0_0s.*-.*-01', key_map) | qresults = query_info(paths, False, "TSF-2x24-drop0_0s.*-.*-01", key_map) | ||||||
| print('Find {:} results'.format(len(qresults))) | print("Find {:} results".format(len(qresults))) | ||||||
| times = [] | times = [] | ||||||
| for qresult in qresults: | for qresult in qresults: | ||||||
|     times.append(qresult.name.split('0_0s')[-1]) |     times.append(qresult.name.split("0_0s")[-1]) | ||||||
| print(times) | print(times) | ||||||
| save_path = os.path.join(note_dir, 'temp-time-x.pth') | save_path = os.path.join(note_dir, "temp-time-x.pth") | ||||||
| torch.save(qresults, save_path) | torch.save(qresults, save_path) | ||||||
| print(save_path) | print(save_path) | ||||||
|   | |||||||
| @@ -24,7 +24,7 @@ from qlib.model.base import Model | |||||||
| from qlib.data.dataset import DatasetH | from qlib.data.dataset import DatasetH | ||||||
| from qlib.data.dataset.handler import DataHandlerLP | from qlib.data.dataset.handler import DataHandlerLP | ||||||
|  |  | ||||||
| qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN) | qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN) | ||||||
|  |  | ||||||
| dataset_config = { | dataset_config = { | ||||||
|     "class": "DatasetH", |     "class": "DatasetH", | ||||||
| @@ -47,7 +47,7 @@ dataset_config = { | |||||||
|             "test": ("2017-01-01", "2020-08-01"), |             "test": ("2017-01-01", "2020-08-01"), | ||||||
|         }, |         }, | ||||||
|     }, |     }, | ||||||
|         } | } | ||||||
| pprint.pprint(dataset_config) | pprint.pprint(dataset_config) | ||||||
| dataset = init_instance_by_config(dataset_config) | dataset = init_instance_by_config(dataset_config) | ||||||
|  |  | ||||||
| @@ -55,7 +55,7 @@ df_train, df_valid, df_test = dataset.prepare( | |||||||
|     ["train", "valid", "test"], |     ["train", "valid", "test"], | ||||||
|     col_set=["feature", "label"], |     col_set=["feature", "label"], | ||||||
|     data_key=DataHandlerLP.DK_L, |     data_key=DataHandlerLP.DK_L, | ||||||
|         ) | ) | ||||||
| model = get_transformer(None) | model = get_transformer(None) | ||||||
| print(model) | print(model) | ||||||
|  |  | ||||||
| @@ -72,4 +72,5 @@ label = labels[batch][mask] | |||||||
| loss = torch.nn.functional.mse_loss(pred, label) | loss = torch.nn.functional.mse_loss(pred, label) | ||||||
|  |  | ||||||
| from sklearn.metrics import mean_squared_error | from sklearn.metrics import mean_squared_error | ||||||
|  |  | ||||||
| mse_loss = mean_squared_error(pred.numpy(), label.numpy()) | mse_loss = mean_squared_error(pred.numpy(), label.numpy()) | ||||||
|   | |||||||
							
								
								
									
										4
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								setup.py
									
									
									
									
									
								
							| @@ -37,7 +37,9 @@ def read(fname="README.md"): | |||||||
| # What packages are required for this module to be executed? | # What packages are required for this module to be executed? | ||||||
| REQUIRED = ["numpy>=1.16.5,<=1.19.5"] | REQUIRED = ["numpy>=1.16.5,<=1.19.5"] | ||||||
|  |  | ||||||
| packages = find_packages(exclude=("tests", "scripts", "scripts-search", "lib*", "exps*")) | packages = find_packages( | ||||||
|  |     exclude=("tests", "scripts", "scripts-search", "lib*", "exps*") | ||||||
|  | ) | ||||||
| print("packages: {:}".format(packages)) | print("packages: {:}".format(packages)) | ||||||
|  |  | ||||||
| setup( | setup( | ||||||
|   | |||||||
| @@ -64,65 +64,29 @@ class ComposedSinFunc(FitFunc): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class ComposedSinFuncV2(FitFunc): | class ComposedCosFunc(FitFunc): | ||||||
|     """The composed sin function that outputs: |     """The composed sin function that outputs: | ||||||
|       f(x) = amplitude-scale-of(x) * sin( period-phase-shift-of(x) ) |     f(x) = a * cos( b*x ) + c | ||||||
|     - the amplitude scale is a quadratic function of x |  | ||||||
|     - the period-phase-shift is another quadratic function of x |  | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, **kwargs): |     def __init__(self, params, xstr="x"): | ||||||
|         super(ComposedSinFuncV2, self).__init__(0, None) |         super(ComposedCosFunc, self).__init__(3, None, params, xstr) | ||||||
|         self.fit(**kwargs) |  | ||||||
|  |  | ||||||
|     def __call__(self, x): |     def __call__(self, x): | ||||||
|         self.check_valid() |         self.check_valid() | ||||||
|         scale = self._params["amplitude_scale"](x) |         a = self._params[0] | ||||||
|         period_phase = self._params["period_phase_shift"](x) |         b = self._params[1] | ||||||
|         return scale * math.sin(period_phase) |         c = self._params[2] | ||||||
|  |         return a * math.cos(b * x) + c | ||||||
|     def fit(self, **kwargs): |  | ||||||
|         num_sin_phase = kwargs.get("num_sin_phase", 7) |  | ||||||
|         sin_speed_use_power = kwargs.get("sin_speed_use_power", True) |  | ||||||
|         min_amplitude = kwargs.get("min_amplitude", 1) |  | ||||||
|         max_amplitude = kwargs.get("max_amplitude", 4) |  | ||||||
|         phase_shift = kwargs.get("phase_shift", 0.0) |  | ||||||
|         # create parameters |  | ||||||
|         if kwargs.get("amplitude_scale", None) is None: |  | ||||||
|             amplitude_scale = QuadraticFunc( |  | ||||||
|                 [(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)] |  | ||||||
|             ) |  | ||||||
|         else: |  | ||||||
|             amplitude_scale = kwargs.get("amplitude_scale") |  | ||||||
|         if kwargs.get("period_phase_shift", None) is None: |  | ||||||
|             fitting_data = [] |  | ||||||
|             if sin_speed_use_power: |  | ||||||
|                 temp_max_scalar = 2 ** (num_sin_phase - 1) |  | ||||||
|             else: |  | ||||||
|                 temp_max_scalar = num_sin_phase - 1 |  | ||||||
|             for i in range(num_sin_phase): |  | ||||||
|                 if sin_speed_use_power: |  | ||||||
|                     value = (2 ** i) / temp_max_scalar |  | ||||||
|                     next_value = (2 ** (i + 1)) / temp_max_scalar |  | ||||||
|                 else: |  | ||||||
|                     value = i / temp_max_scalar |  | ||||||
|                     next_value = (i + 1) / temp_max_scalar |  | ||||||
|                 for _phase in (0, 0.25, 0.5, 0.75): |  | ||||||
|                     inter_value = value + (next_value - value) * _phase |  | ||||||
|                     fitting_data.append((inter_value, math.pi * (2 * i + _phase))) |  | ||||||
|             period_phase_shift = QuarticFunc(fitting_data) |  | ||||||
|         else: |  | ||||||
|             period_phase_shift = kwargs.get("period_phase_shift") |  | ||||||
|         self.set( |  | ||||||
|             dict(amplitude_scale=amplitude_scale, period_phase_shift=period_phase_shift) |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def _getitem(self, x, weights): |     def _getitem(self, x, weights): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}({amplitude_scale} * sin({period_phase_shift}))".format( |         return "{name}({a} * sin({b} * {x}) + {c})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             amplitude_scale=self._params["amplitude_scale"], |             a=self._params[0], | ||||||
|             period_phase_shift=self._params["period_phase_shift"], |             b=self._params[1], | ||||||
|  |             c=self._params[2], | ||||||
|  |             x=self.xstr, | ||||||
|         ) |         ) | ||||||
|   | |||||||
| @@ -5,5 +5,5 @@ from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc | |||||||
| from .math_dynamic_funcs import DynamicLinearFunc | from .math_dynamic_funcs import DynamicLinearFunc | ||||||
| from .math_dynamic_funcs import DynamicQuadraticFunc | from .math_dynamic_funcs import DynamicQuadraticFunc | ||||||
| from .math_adv_funcs import ConstantFunc | from .math_adv_funcs import ConstantFunc | ||||||
| from .math_adv_funcs import ComposedSinFunc | from .math_adv_funcs import ComposedSinFunc, ComposedCosFunc | ||||||
| from .math_dynamic_generator import GaussianDGenerator | from .math_dynamic_generator import GaussianDGenerator | ||||||
|   | |||||||
| @@ -4,7 +4,11 @@ from .synthetic_env import SyntheticDEnv | |||||||
| from .math_core import LinearFunc | from .math_core import LinearFunc | ||||||
| from .math_core import DynamicLinearFunc | from .math_core import DynamicLinearFunc | ||||||
| from .math_core import DynamicQuadraticFunc | from .math_core import DynamicQuadraticFunc | ||||||
| from .math_core import ConstantFunc, ComposedSinFunc as SinFunc | from .math_core import ( | ||||||
|  |     ConstantFunc, | ||||||
|  |     ComposedSinFunc as SinFunc, | ||||||
|  |     ComposedCosFunc as CosFunc, | ||||||
|  | ) | ||||||
| from .math_core import GaussianDGenerator | from .math_core import GaussianDGenerator | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -50,6 +54,25 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio | |||||||
|         dynamic_env = SyntheticDEnv( |         dynamic_env = SyntheticDEnv( | ||||||
|             data_generator, oracle_map, time_generator, num_per_task |             data_generator, oracle_map, time_generator, num_per_task | ||||||
|         ) |         ) | ||||||
|  |     elif version.lower() == "v3": | ||||||
|  |         mean_generator = SinFunc(params={0: 1, 1: 1, 2: 0})  # sin(t) | ||||||
|  |         std_generator = CosFunc(params={0: 0.5, 1: 1, 2: 1})  # 0.5 cos(t) + 1 | ||||||
|  |         data_generator = GaussianDGenerator( | ||||||
|  |             [mean_generator], [[std_generator]], (-2, 2) | ||||||
|  |         ) | ||||||
|  |         time_generator = TimeStamp( | ||||||
|  |             min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode | ||||||
|  |         ) | ||||||
|  |         oracle_map = DynamicQuadraticFunc( | ||||||
|  |             params={ | ||||||
|  |                 0: LinearFunc(params={0: 0.1, 1: 0}),  # 0.1 * t | ||||||
|  |                 1: SinFunc(params={0: 1, 1: 1, 2: 0}),  # sin(t) | ||||||
|  |                 2: ConstantFunc(0), | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |         dynamic_env = SyntheticDEnv( | ||||||
|  |             data_generator, oracle_map, time_generator, num_per_task | ||||||
|  |         ) | ||||||
|     else: |     else: | ||||||
|         raise ValueError("Unknown version: {:}".format(version)) |         raise ValueError("Unknown version: {:}".format(version)) | ||||||
|     return dynamic_env |     return dynamic_env | ||||||
|   | |||||||
| @@ -39,9 +39,9 @@ def get_model(config: Dict[Text, Any], **kwargs): | |||||||
|         norm_cls = super_name2norm[kwargs["norm_cls"]] |         norm_cls = super_name2norm[kwargs["norm_cls"]] | ||||||
|         sub_layers, last_dim = [], kwargs["input_dim"] |         sub_layers, last_dim = [], kwargs["input_dim"] | ||||||
|         for i, hidden_dim in enumerate(kwargs["hidden_dims"]): |         for i, hidden_dim in enumerate(kwargs["hidden_dims"]): | ||||||
|  |             sub_layers.append(SuperLinear(last_dim, hidden_dim)) | ||||||
|             if hidden_dim > 1: |             if hidden_dim > 1: | ||||||
|                 sub_layers.append(norm_cls(hidden_dim, elementwise_affine=False)) |                 sub_layers.append(norm_cls(hidden_dim, elementwise_affine=False)) | ||||||
|             sub_layers.append(SuperLinear(last_dim, hidden_dim)) |  | ||||||
|             sub_layers.append(act_cls()) |             sub_layers.append(act_cls()) | ||||||
|             last_dim = hidden_dim |             last_dim = hidden_dim | ||||||
|         sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"])) |         sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"])) | ||||||
|   | |||||||
| @@ -8,24 +8,44 @@ | |||||||
|  |  | ||||||
| import os, torch | import os, torch | ||||||
|  |  | ||||||
|  |  | ||||||
| def obtain_nas_infer_model(config, extra_model_path=None): | def obtain_nas_infer_model(config, extra_model_path=None): | ||||||
|  |  | ||||||
|   if config.arch == 'dxys': |     if config.arch == "dxys": | ||||||
|         from .DXYs import CifarNet, ImageNet, Networks |         from .DXYs import CifarNet, ImageNet, Networks | ||||||
|         from .DXYs import build_genotype_from_dict |         from .DXYs import build_genotype_from_dict | ||||||
|  |  | ||||||
|         if config.genotype is None: |         if config.genotype is None: | ||||||
|             if extra_model_path is not None and not os.path.isfile(extra_model_path): |             if extra_model_path is not None and not os.path.isfile(extra_model_path): | ||||||
|         raise ValueError('When genotype in confiig is None, extra_model_path must be set as a path instead of {:}'.format(extra_model_path)) |                 raise ValueError( | ||||||
|  |                     "When genotype in confiig is None, extra_model_path must be set as a path instead of {:}".format( | ||||||
|  |                         extra_model_path | ||||||
|  |                     ) | ||||||
|  |                 ) | ||||||
|             xdata = torch.load(extra_model_path) |             xdata = torch.load(extra_model_path) | ||||||
|       current_epoch = xdata['epoch'] |             current_epoch = xdata["epoch"] | ||||||
|       genotype_dict = xdata['genotypes'][current_epoch-1] |             genotype_dict = xdata["genotypes"][current_epoch - 1] | ||||||
|             genotype = build_genotype_from_dict(genotype_dict) |             genotype = build_genotype_from_dict(genotype_dict) | ||||||
|         else: |         else: | ||||||
|             genotype = Networks[config.genotype] |             genotype = Networks[config.genotype] | ||||||
|     if config.dataset == 'cifar': |         if config.dataset == "cifar": | ||||||
|       return CifarNet(config.ichannel, config.layers, config.stem_multi, config.auxiliary, genotype, config.class_num) |             return CifarNet( | ||||||
|     elif config.dataset == 'imagenet': |                 config.ichannel, | ||||||
|       return ImageNet(config.ichannel, config.layers, config.auxiliary, genotype, config.class_num) |                 config.layers, | ||||||
|     else: raise ValueError('invalid dataset : {:}'.format(config.dataset)) |                 config.stem_multi, | ||||||
|  |                 config.auxiliary, | ||||||
|  |                 genotype, | ||||||
|  |                 config.class_num, | ||||||
|  |             ) | ||||||
|  |         elif config.dataset == "imagenet": | ||||||
|  |             return ImageNet( | ||||||
|  |                 config.ichannel, | ||||||
|  |                 config.layers, | ||||||
|  |                 config.auxiliary, | ||||||
|  |                 genotype, | ||||||
|  |                 config.class_num, | ||||||
|  |             ) | ||||||
|         else: |         else: | ||||||
|     raise ValueError('invalid nas arch type : {:}'.format(config.arch)) |             raise ValueError("invalid dataset : {:}".format(config.dataset)) | ||||||
|  |     else: | ||||||
|  |         raise ValueError("invalid nas arch type : {:}".format(config.arch)) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user