Fix stat tracking bug
This commit is contained in:
parent
5c16a90ceb
commit
1ce0994c8a
@ -9,6 +9,8 @@ class PerPromptStatTracker:
|
|||||||
self.stats = {}
|
self.stats = {}
|
||||||
|
|
||||||
def update(self, prompts, rewards):
|
def update(self, prompts, rewards):
|
||||||
|
prompts = np.array(prompts)
|
||||||
|
rewards = np.array(rewards)
|
||||||
unique = np.unique(prompts)
|
unique = np.unique(prompts)
|
||||||
advantages = np.empty_like(rewards)
|
advantages = np.empty_like(rewards)
|
||||||
for prompt in unique:
|
for prompt in unique:
|
||||||
|
Loading…
Reference in New Issue
Block a user