#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
#####################################################
from typing import Union, Dict, Text, Any
import importlib
import importlib.util  # `importlib.util` is a submodule and must be imported explicitly

from .yaml_utils import load_yaml

# Config key naming the attribute (class or function) to fetch from the module.
CLS_FUNC_KEY = "class_or_func"
# A dict is treated as a "call spec" only when it contains ALL of these keys.
KEYS = (CLS_FUNC_KEY, "module_path", "args", "kwargs")


def has_key_words(xdict):
    """Return True if ``xdict`` is a dict that contains every key in ``KEYS``."""
    if not isinstance(xdict, dict):
        return False
    key_set = set(KEYS)
    cur_set = set(xdict.keys())
    return key_set.intersection(cur_set) == key_set


def get_module_by_module_path(module_path):
    """Load a module either from a ``.py`` file path or a dotted import path.

    NOTE: loading a module executes its code; only use trusted config files.
    """
    if module_path.endswith(".py"):
        module_spec = importlib.util.spec_from_file_location("", module_path)
        module = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(module)
    else:
        module = importlib.import_module(module_path)
    return module


def _get_callable(config: Dict[Text, Any]):
    """Return the attribute named by ``config[CLS_FUNC_KEY]`` from ``config["module_path"]``."""
    module = get_module_by_module_path(config["module_path"])
    return getattr(module, config[CLS_FUNC_KEY])


def call_by_dict(config: Dict[Text, Any], *args, **kwargs) -> object:
    """Get an initialized instance (or call result) described by ``config``.

    Parameters
    ----------
    config : a dictionary, such as:
        {
          'class_or_func': 'ClassName',
          'module_path': a dotted module path or a path to a ``.py`` file,
          'args': list,
          'kwargs': dict,
        }
    *args : extra positional arguments, appended after ``config['args']``.
    **kwargs : extra keyword arguments; they override ``config['kwargs']``.

    Returns
    -------
    object: The result of calling the class or function with the merged arguments.
    """
    cls_or_func = _get_callable(config)
    args = tuple(list(config["args"]) + list(args))
    kwargs = {**config["kwargs"], **kwargs}
    return cls_or_func(*args, **kwargs)


def call_by_yaml(path, *args, **kwargs) -> object:
    """Load a call-spec dict from the YAML file at ``path`` and call it via `call_by_dict`."""
    config = load_yaml(path)
    # Previously referenced the undefined name `call_by_config`.
    return call_by_dict(config, *args, **kwargs)


def nested_call_by_dict(config: Union[Dict[Text, Any], Any], *args, **kwargs) -> object:
    """Similar to `call_by_dict`, but arguments may themselves contain call specs.

    Lists, tuples, and plain dicts are traversed recursively and keep their
    container type. Dicts that contain all of ``KEYS`` are called after their
    (merged) args/kwargs have been resolved recursively. Other values are
    returned unchanged. The extra ``*args``/``**kwargs`` are used only when
    ``config`` itself is a call spec.
    """
    if isinstance(config, list):
        return [nested_call_by_dict(x) for x in config]
    elif isinstance(config, tuple):
        # Build a real tuple; a bare parenthesized comprehension is a generator.
        return tuple(nested_call_by_dict(x) for x in config)
    elif not isinstance(config, dict):
        return config
    elif not has_key_words(config):
        # dict.items() yields (key, value); keep keys, resolve values.
        return {key: nested_call_by_dict(x) for key, x in config.items()}
    else:
        cls_or_func = _get_callable(config)
        args = tuple(list(config["args"]) + list(args))
        kwargs = {**config["kwargs"], **kwargs}
        # check whether there are nested special dicts in the arguments
        new_args = [nested_call_by_dict(x) for x in args]
        new_kwargs = {key: nested_call_by_dict(x) for key, x in kwargs.items()}
        return cls_or_func(*new_args, **new_kwargs)


def nested_call_by_yaml(path, *args, **kwargs) -> object:
    """Load a config from the YAML file at ``path`` and resolve it via `nested_call_by_dict`."""
    config = load_yaml(path)
    return nested_call_by_dict(config, *args, **kwargs)