From 1ce0994c8a3c5ef4c79e330e461cfa7a45cf23d7 Mon Sep 17 00:00:00 2001 From: Kevin Black Date: Mon, 26 Jun 2023 22:25:43 -0700 Subject: [PATCH] Fix stat tracking bug --- ddpo_pytorch/stat_tracking.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ddpo_pytorch/stat_tracking.py b/ddpo_pytorch/stat_tracking.py index 4199ab9..ee50034 100644 --- a/ddpo_pytorch/stat_tracking.py +++ b/ddpo_pytorch/stat_tracking.py @@ -9,6 +9,8 @@ class PerPromptStatTracker: self.stats = {} def update(self, prompts, rewards): + prompts = np.array(prompts) + rewards = np.array(rewards) unique = np.unique(prompts) advantages = np.empty_like(rewards) for prompt in unique: