Fix stat tracking bug
This commit is contained in:
		| @@ -9,6 +9,8 @@ class PerPromptStatTracker: | ||||
|         self.stats = {} | ||||
|  | ||||
|     def update(self, prompts, rewards): | ||||
|         prompts = np.array(prompts) | ||||
|         rewards = np.array(rewards) | ||||
|         unique = np.unique(prompts) | ||||
|         advantages = np.empty_like(rewards) | ||||
|         for prompt in unique: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user