autodl-projects/xautodl/nas_infer_model/__init__.py

52 lines
2.0 KiB
Python
Raw Permalink Normal View History

2020-02-23 00:30:37 +01:00
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
# This package makes AutoDL-Projects compatible with the old GDAS projects.
# Ideally, this package will be merged into lib/models/cell_infers in the future.
# Currently, this package is used to reproduce the results of GDAS (Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019).
##################################################
import os, torch
2019-09-28 10:24:47 +02:00
2021-05-26 10:53:44 +02:00
def obtain_nas_infer_model(config, extra_model_path=None):
    """Build the inference network described by ``config``.

    Only ``config.arch == "dxys"`` is supported. The cell genotype comes from
    one of two places:

    * ``config.genotype`` is not None: it is used as a key into the
      ``Networks`` table exported by ``.DXYs``.
    * ``config.genotype`` is None: the checkpoint at ``extra_model_path`` is
      loaded with ``torch.load`` and the genotype recorded for its current
      epoch (``xdata["genotypes"][xdata["epoch"] - 1]``) is converted with
      ``build_genotype_from_dict``.

    Args:
        config: object exposing ``arch``, ``genotype``, ``dataset``,
            ``ichannel``, ``layers``, ``auxiliary`` and ``class_num``
            (plus ``stem_multi`` when ``dataset == "cifar"``).
        extra_model_path: path to a checkpoint file; required when
            ``config.genotype`` is None, ignored otherwise.

    Returns:
        A ``CifarNet`` when ``config.dataset == "cifar"``, or an ``ImageNet``
        when ``config.dataset == "imagenet"``.

    Raises:
        ValueError: if ``config.arch`` is not ``"dxys"``; if
            ``config.genotype`` is None and ``extra_model_path`` is None or
            is not an existing file; or if ``config.dataset`` is neither
            ``"cifar"`` nor ``"imagenet"``.
    """
    if config.arch != "dxys":
        raise ValueError("invalid nas arch type : {:}".format(config.arch))

    # Validate the checkpoint path before touching the model package so that
    # bad arguments fail fast. A missing path (None) is rejected here as well;
    # otherwise it would only surface later as an obscure torch.load failure.
    load_from_checkpoint = config.genotype is None
    if load_from_checkpoint and (
        extra_model_path is None or not os.path.isfile(extra_model_path)
    ):
        raise ValueError(
            "When genotype in config is None, extra_model_path must be set as a path instead of {:}".format(
                extra_model_path
            )
        )

    # Imported lazily, only once the "dxys" architecture is actually requested.
    from .DXYs import CifarNet, ImageNet, Networks
    from .DXYs import build_genotype_from_dict

    if load_from_checkpoint:
        # NOTE(security): torch.load unpickles the file; only pass checkpoints
        # that come from a trusted source.
        xdata = torch.load(extra_model_path)
        current_epoch = xdata["epoch"]
        # The genotype list is indexed from 0, hence ``current_epoch - 1``.
        genotype_dict = xdata["genotypes"][current_epoch - 1]
        genotype = build_genotype_from_dict(genotype_dict)
    else:
        genotype = Networks[config.genotype]

    if config.dataset == "cifar":
        return CifarNet(
            config.ichannel,
            config.layers,
            config.stem_multi,
            config.auxiliary,
            genotype,
            config.class_num,
        )
    if config.dataset == "imagenet":
        return ImageNet(
            config.ichannel,
            config.layers,
            config.auxiliary,
            genotype,
            config.class_num,
        )
    raise ValueError("invalid dataset : {:}".format(config.dataset))