#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
# I wrote this package to make AutoDL-Projects compatible with the old GDAS projects.
# Ideally, this package will be merged into lib/models/cell_infers in the future.
# Currently, this package is used to reproduce the results in GDAS (Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019).
##################################################

import os, torch


def obtain_nas_infer_model(config, extra_model_path=None):
    """Build the inference network described by ``config``.

    Only ``config.arch == "dxys"`` is supported. The genotype is taken from
    the ``Networks`` table when ``config.genotype`` is set; otherwise it is
    rebuilt from the search checkpoint stored at ``extra_model_path``.

    Args:
        config: object exposing ``arch``, ``genotype``, ``dataset``,
            ``ichannel``, ``layers``, ``auxiliary`` and ``class_num``
            (plus ``stem_multi`` when ``dataset == "cifar"``).
        extra_model_path: path of a checkpoint readable by ``torch.load``
            that holds an ``"epoch"`` entry and a ``"genotypes"`` container
            indexed by epoch. Required only when ``config.genotype`` is None.

    Returns:
        A ``CifarNet`` (``dataset == "cifar"``) or ``ImageNet``
        (``dataset == "imagenet"``) instance.

    Raises:
        ValueError: if ``config.arch`` or ``config.dataset`` is unsupported,
            or if ``config.genotype`` is None and ``extra_model_path`` is
            missing or does not point to an existing file.
    """
    if config.arch != "dxys":
        raise ValueError("invalid nas arch type : {:}".format(config.arch))

    searched_genotype = None
    if config.genotype is None:
        # Without a named genotype the checkpoint is the only source, so the
        # path is mandatory: reject both a missing path and a non-file path
        # here instead of letting torch.load fail with an unrelated error.
        if extra_model_path is None or not os.path.isfile(extra_model_path):
            raise ValueError(
                "When genotype in confiig is None, extra_model_path must be set as a path instead of {:}".format(
                    extra_model_path
                )
            )
        # Only plain metadata is read from the checkpoint, so map everything
        # to CPU; this lets GPU-saved checkpoints load on CPU-only machines.
        # NOTE(review): torch.load unpickles the file - only pass trusted paths.
        checkpoint = torch.load(extra_model_path, map_location="cpu")
        last_epoch = checkpoint["epoch"]
        # Genotypes are indexed by epoch; the latest finished one is epoch - 1.
        searched_genotype = checkpoint["genotypes"][last_epoch - 1]

    # Imported lazily (as in the original) and only after the cheap input
    # checks above, so invalid arguments fail fast.
    from .DXYs import CifarNet, ImageNet, Networks
    from .DXYs import build_genotype_from_dict

    if config.genotype is None:
        genotype = build_genotype_from_dict(searched_genotype)
    else:
        genotype = Networks[config.genotype]

    if config.dataset == "cifar":
        return CifarNet(
            config.ichannel,
            config.layers,
            config.stem_multi,
            config.auxiliary,
            genotype,
            config.class_num,
        )
    if config.dataset == "imagenet":
        return ImageNet(
            config.ichannel,
            config.layers,
            config.auxiliary,
            genotype,
            config.class_num,
        )
    raise ValueError("invalid dataset : {:}".format(config.dataset))