Fix stat tracking bug

This commit is contained in:
Kevin Black 2023-06-26 22:25:43 -07:00
parent 5c16a90ceb
commit 1ce0994c8a

View File

@ -9,6 +9,8 @@ class PerPromptStatTracker:
self.stats = {}
def update(self, prompts, rewards):
prompts = np.array(prompts)
rewards = np.array(rewards)
unique = np.unique(prompts)
advantages = np.empty_like(rewards)
for prompt in unique: