89 lines
2.9 KiB
Python
89 lines
2.9 KiB
Python
#####################################################
|
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
|
|
#####################################################
|
|
from typing import Union, Dict, Text, Any
|
|
import importlib
|
|
|
|
from .yaml_utils import load_yaml
|
|
|
|
# Config key whose value names the attribute (a class or a function) to fetch
# from the loaded module.
CLS_FUNC_KEY = "class_or_func"

# Every key a config dict must contain for has_key_words() to treat it as a
# callable spec.
KEYS = (
    CLS_FUNC_KEY,
    "module_path",
    "args",
    "kwargs",
)
|
|
|
|
|
|
def has_key_words(xdict):
    """Tell whether ``xdict`` is a dict that carries every key in ``KEYS``.

    Anything that is not a dict (lists, tuples, scalars, None, ...) yields
    False. Extra keys beyond ``KEYS`` are allowed.
    """
    if isinstance(xdict, dict):
        required = frozenset(KEYS)
        return required.issubset(xdict.keys())
    return False
|
|
|
|
|
|
def get_module_by_module_path(module_path):
    """Load a module from a ``.py`` file path or from a dotted import path.

    Parameters
    ----------
    module_path : str
        Either a filesystem path ending in ``.py`` (executed directly from that
        file) or an importable dotted module name such as ``"package.module"``.

    Returns
    -------
    module
        The loaded module object.

    Raises
    ------
    ImportError
        If no import spec/loader can be created for the given ``.py`` path, or
        if ``importlib.import_module`` cannot import the dotted name.
    """
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule is loaded, so import it explicitly before touching it.
    import importlib.util

    if module_path.endswith(".py"):
        module_spec = importlib.util.spec_from_file_location("", module_path)
        if module_spec is None or module_spec.loader is None:
            raise ImportError(
                "Cannot create a module loader for file: {:}".format(module_path)
            )
        module = importlib.util.module_from_spec(module_spec)
        # Run the file's top-level code inside the freshly created module.
        module_spec.loader.exec_module(module)
    else:
        module = importlib.import_module(module_path)
    return module
|
|
|
|
|
|
def call_by_dict(config: Dict[Text, Any], *args, **kwargs) -> object:
    """Instantiate a class (or call a function) described by ``config``.

    Parameters
    ----------
    config : dict
        A dictionary with the following keys::

            {
                'class_or_func': 'ClassName',  # attribute name inside the module
                'module_path': 'pkg.module',   # dotted import path, or a path to a .py file
                'args': list,                  # positional arguments
                'kwargs': dict,                # keyword arguments
            }

    *args
        Extra positional arguments, appended after ``config['args']``.
    **kwargs
        Extra keyword arguments. They override same-named entries in
        ``config['kwargs']``.

    Returns
    -------
    object:
        Whatever ``class_or_func(*args, **kwargs)`` returns. When
        ``class_or_func`` names a class, that is an initialized instance.

    Raises
    ------
    KeyError
        If one of the required keys is missing from ``config``.
    AttributeError
        If the loaded module has no attribute named by ``config['class_or_func']``.
    """
    module = get_module_by_module_path(config["module_path"])
    cls_or_func = getattr(module, config[CLS_FUNC_KEY])
    # Config-provided positionals come first, then the caller's extras.
    args = tuple(config["args"]) + args
    # Caller-supplied keyword arguments win over the config's defaults.
    kwargs = {**config["kwargs"], **kwargs}
    return cls_or_func(*args, **kwargs)
|
|
|
|
|
|
def call_by_yaml(path, *args, **kwargs) -> object:
    """Load a config from the YAML file at ``path`` and call it via ``call_by_dict``.

    Extra ``*args`` / ``**kwargs`` are forwarded unchanged to ``call_by_dict``.
    Only the top-level spec is resolved. Use ``nested_call_by_yaml`` when the
    arguments themselves contain callable specs.
    """
    # NOTE(review): assumes load_yaml returns a dict with the keys expected by
    # call_by_dict -- confirm against yaml_utils.
    config = load_yaml(path)
    return call_by_dict(config, *args, **kwargs)
|
|
|
|
|
|
def nested_call_by_dict(config: Union[Dict[Text, Any], Any], *args, **kwargs) -> object:
    """Like `call_by_dict`, but recursively resolves nested callable specs.

    A dict containing every key in ``KEYS`` is treated as a callable spec and
    replaced by the result of calling it. Its positional and keyword arguments
    are resolved recursively before the call. Lists, tuples and other dicts
    are traversed recursively, and any other value is returned unchanged.

    Parameters
    ----------
    config :
        A callable-spec dict, or any (possibly nested) list / tuple / dict /
        plain value.
    *args, **kwargs
        Extra arguments for the top-level call. They are only used when
        ``config`` itself is a callable spec; otherwise they are ignored.

    Returns
    -------
    object
        The resolved structure. Lists stay lists, tuples stay tuples, and plain
        dicts keep their keys with resolved values.
    """
    if isinstance(config, list):
        return [nested_call_by_dict(x) for x in config]
    if isinstance(config, tuple):
        # Materialize a real tuple; a bare parenthesized comprehension would be
        # a one-shot generator.
        return tuple(nested_call_by_dict(x) for x in config)
    if not isinstance(config, dict):
        return config
    if not has_key_words(config):
        # Plain dict: keep every key, resolve only its value.
        return {key: nested_call_by_dict(value) for key, value in config.items()}

    module = get_module_by_module_path(config["module_path"])
    cls_or_func = getattr(module, config[CLS_FUNC_KEY])
    args = tuple(config["args"]) + args
    kwargs = {**config["kwargs"], **kwargs}
    # Arguments may themselves contain callable specs; resolve them first.
    new_args = [nested_call_by_dict(x) for x in args]
    new_kwargs = {key: nested_call_by_dict(x) for key, x in kwargs.items()}
    return cls_or_func(*new_args, **new_kwargs)
|
|
|
|
|
|
def nested_call_by_yaml(path, *args, **kwargs) -> object:
    """Read the YAML config at ``path`` and resolve it with ``nested_call_by_dict``.

    Any extra ``*args`` / ``**kwargs`` are passed through to
    ``nested_call_by_dict`` as-is.
    """
    return nested_call_by_dict(load_yaml(path), *args, **kwargs)
|