Fix stat tracking bug
This commit is contained in:
		| @@ -9,6 +9,8 @@ class PerPromptStatTracker: | |||||||
|         self.stats = {} |         self.stats = {} | ||||||
|  |  | ||||||
|     def update(self, prompts, rewards): |     def update(self, prompts, rewards): | ||||||
|  |         prompts = np.array(prompts) | ||||||
|  |         rewards = np.array(rewards) | ||||||
|         unique = np.unique(prompts) |         unique = np.unique(prompts) | ||||||
|         advantages = np.empty_like(rewards) |         advantages = np.empty_like(rewards) | ||||||
|         for prompt in unique: |         for prompt in unique: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user