diff --git a/ddpo_pytorch/config/base.py b/config/base.py similarity index 100% rename from ddpo_pytorch/config/base.py rename to config/base.py diff --git a/ddpo_pytorch/config/dgx.py b/config/dgx.py similarity index 79% rename from ddpo_pytorch/config/dgx.py rename to config/dgx.py index cd387ee..16b902b 100644 --- a/ddpo_pytorch/config/dgx.py +++ b/config/dgx.py @@ -1,5 +1,9 @@ import ml_collections -from ddpo_pytorch.config import base +import imp +import os + +base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py")) + def get_config(): config = base.get_config() @@ -20,4 +24,4 @@ def get_config(): config.per_prompt_stat_tracking = None - return config \ No newline at end of file + return config diff --git a/scripts/train.py b/scripts/train.py index 970153e..80da2e6 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -26,7 +26,7 @@ tqdm = partial(tqdm.tqdm, dynamic_ncols=True) FLAGS = flags.FLAGS -config_flags.DEFINE_config_file("config", "ddpo_pytorch/config/base.py", "Training configuration.") +config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.") logger = get_logger(__name__)