Fix stat tracking bug
This commit is contained in:
parent
5c16a90ceb
commit
1ce0994c8a
@@ -9,6 +9,8 @@ class PerPromptStatTracker:
|
||||
self.stats = {}
|
||||
|
||||
def update(self, prompts, rewards):
|
||||
prompts = np.array(prompts)
|
||||
rewards = np.array(rewards)
|
||||
unique = np.unique(prompts)
|
||||
advantages = np.empty_like(rewards)
|
||||
for prompt in unique:
|
||||
|
Loading…
Reference in New Issue
Block a user