Fix stat tracking bug
This commit is contained in:
		@@ -9,6 +9,8 @@ class PerPromptStatTracker:
 | 
				
			|||||||
        self.stats = {}
 | 
					        self.stats = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def update(self, prompts, rewards):
 | 
					    def update(self, prompts, rewards):
 | 
				
			||||||
 | 
					        prompts = np.array(prompts)
 | 
				
			||||||
 | 
					        rewards = np.array(rewards)
 | 
				
			||||||
        unique = np.unique(prompts)
 | 
					        unique = np.unique(prompts)
 | 
				
			||||||
        advantages = np.empty_like(rewards)
 | 
					        advantages = np.empty_like(rewards)
 | 
				
			||||||
        for prompt in unique:
 | 
					        for prompt in unique:
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user