From 5c16a90cebf8974f72a54073c4b0a39a719f9506 Mon Sep 17 00:00:00 2001 From: Kevin Black Date: Sun, 25 Jun 2023 21:02:27 -0700 Subject: [PATCH] Move config out of module --- {ddpo_pytorch/config => config}/base.py | 0 {ddpo_pytorch/config => config}/dgx.py | 8 ++++++-- scripts/train.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) rename {ddpo_pytorch/config => config}/base.py (100%) rename {ddpo_pytorch/config => config}/dgx.py (79%) diff --git a/ddpo_pytorch/config/base.py b/config/base.py similarity index 100% rename from ddpo_pytorch/config/base.py rename to config/base.py diff --git a/ddpo_pytorch/config/dgx.py b/config/dgx.py similarity index 79% rename from ddpo_pytorch/config/dgx.py rename to config/dgx.py index cd387ee..16b902b 100644 --- a/ddpo_pytorch/config/dgx.py +++ b/config/dgx.py @@ -1,5 +1,9 @@ import ml_collections -from ddpo_pytorch.config import base +import imp +import os + +base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py")) + def get_config(): config = base.get_config() @@ -20,4 +24,4 @@ def get_config(): config.per_prompt_stat_tracking = None - return config \ No newline at end of file + return config diff --git a/scripts/train.py b/scripts/train.py index 970153e..80da2e6 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -26,7 +26,7 @@ tqdm = partial(tqdm.tqdm, dynamic_ncols=True) FLAGS = flags.FLAGS -config_flags.DEFINE_config_file("config", "ddpo_pytorch/config/base.py", "Training configuration.") +config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.") logger = get_logger(__name__)