37 lines
1.2 KiB
Python
37 lines
1.2 KiB
Python
import numpy as np
|
|
from collections import deque
|
|
|
|
|
|
class PerPromptStatTracker:
    """Standardize rewards per prompt using a bounded history of past rewards.

    For every distinct prompt, the most recent ``buffer_size`` rewards seen
    across calls to :meth:`update` are kept in a ``deque``. A reward is
    standardized against its prompt's history once that history holds at
    least ``min_count`` entries; until then, the mean/std of the whole
    current batch of rewards is used instead.

    Args:
        buffer_size: ``maxlen`` of each per-prompt ``deque`` (``None`` keeps
            an unbounded history).
        min_count: Minimum number of buffered rewards a prompt needs before
            its own history is used for normalization.
    """

    def __init__(self, buffer_size, min_count):
        self.buffer_size = buffer_size
        self.min_count = min_count
        # prompt -> deque of that prompt's most recent rewards
        self.stats = {}

    def update(self, prompts, rewards):
        """Record ``rewards`` and return them standardized per prompt.

        Args:
            prompts: Sequence of prompt keys, one per reward (converted with
                ``np.array``).
            rewards: Sequence of scalar rewards aligned element-wise with
                ``prompts``.

        Returns:
            A NumPy array shaped like ``rewards`` holding
            ``(reward - mean) / (std + 1e-6)``. Floating-point rewards keep
            their dtype; integer/bool rewards produce ``float64`` so the
            normalized values are not truncated.
        """
        prompts = np.array(prompts)
        rewards = np.array(rewards)

        # np.empty_like(rewards) would inherit an integer dtype for integer
        # rewards and silently truncate the fractional advantages.
        if np.issubdtype(rewards.dtype, np.floating):
            out_dtype = rewards.dtype
        else:
            out_dtype = np.float64
        advantages = np.empty(rewards.shape, dtype=out_dtype)

        for prompt in np.unique(prompts):
            mask = prompts == prompt
            prompt_rewards = rewards[mask]

            if prompt not in self.stats:
                self.stats[prompt] = deque(maxlen=self.buffer_size)
            history = self.stats[prompt]
            # The current batch is buffered first, so it contributes to its
            # own normalization statistics.
            history.extend(prompt_rewards)

            if len(history) < self.min_count:
                # Not enough history yet: fall back to whole-batch statistics.
                mean = np.mean(rewards)
                std = np.std(rewards) + 1e-6
            else:
                mean = np.mean(history)
                std = np.std(history) + 1e-6
            # The 1e-6 added to std avoids division by zero for constant rewards.
            advantages[mask] = (prompt_rewards - mean) / std

        return advantages

    def get_stats(self):
        """Return ``{prompt: {"mean", "std", "count"}}`` over each prompt's buffer."""
        return {
            prompt: {
                "mean": np.mean(history),
                "std": np.std(history),
                "count": len(history),
            }
            for prompt, history in self.stats.items()
        }
|