from collections import deque

import numpy as np


class PerPromptStatTracker:
    """Normalize rewards using a rolling per-prompt reward history.

    Each distinct prompt gets its own bounded buffer of the most recent
    rewards seen for it. Once a prompt's buffer holds at least
    ``min_count`` rewards, its rewards are normalized with that buffer's
    mean and standard deviation; until then the statistics of the current
    batch are used instead.

    Parameters
    ----------
    buffer_size
        Maximum number of rewards retained per prompt (the ``maxlen`` of
        each per-prompt deque; older rewards are dropped first).
    min_count
        Minimum number of buffered rewards a prompt needs before its own
        statistics are used for normalization.
    """

    def __init__(self, buffer_size, min_count):
        self.buffer_size = buffer_size
        self.min_count = min_count
        # Maps each prompt to a deque holding its most recent rewards.
        self.stats = {}

    def update(self, prompts, rewards):
        """Record a batch of rewards and return their normalized advantages.

        Parameters
        ----------
        prompts
            Sequence of prompts, one per reward. Entries that compare equal
            share one history buffer.
        rewards
            Sequence of numeric rewards aligned with ``prompts``.

        Returns
        -------
        numpy.ndarray
            Array shaped like ``rewards`` holding ``(reward - mean) / std``
            for every entry. Floating-point rewards keep their dtype; any
            other dtype is promoted to ``float64``.
        """
        prompts = np.asarray(prompts)
        rewards = np.asarray(rewards)
        # Advantages are fractional. Without this promotion the output array
        # would inherit an integer/bool dtype from the rewards and silently
        # truncate every normalized value on assignment.
        if not np.issubdtype(rewards.dtype, np.floating):
            rewards = rewards.astype(np.float64)

        advantages = np.empty_like(rewards)
        for prompt in np.unique(prompts):
            mask = prompts == prompt
            prompt_rewards = rewards[mask]

            history = self.stats.get(prompt)
            if history is None:
                history = deque(maxlen=self.buffer_size)
                self.stats[prompt] = history
            history.extend(prompt_rewards)

            if len(history) < self.min_count:
                # Too little history for this prompt: fall back to the
                # statistics of the whole current batch.
                mean = np.mean(rewards)
                std = np.std(rewards) + 1e-6
            else:
                mean = np.mean(history)
                std = np.std(history) + 1e-6
            # The 1e-6 added above keeps the division finite when all
            # rewards in the reference set are identical.
            advantages[mask] = (prompt_rewards - mean) / std

        return advantages

    def get_stats(self):
        """Return the mean, std and count of each prompt's buffered rewards.

        Returns
        -------
        dict
            Maps each tracked prompt to a dict with the keys ``"mean"``,
            ``"std"`` and ``"count"`` computed over its current buffer.
        """
        return {
            k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)}
            for k, v in self.stats.items()
        }