ddpo-pytorch/ddpo_pytorch/stat_tracking.py
2023-06-26 22:25:43 -07:00

37 lines
1.2 KiB
Python

import numpy as np
from collections import deque
class PerPromptStatTracker:
    """Track a rolling window of rewards per prompt and normalize new rewards.

    Each distinct prompt gets its own bounded history of recent rewards.
    New rewards are converted to advantages by subtracting a mean and
    dividing by a standard deviation, taken from that prompt's history once
    it is long enough and from the current batch otherwise.
    """

    def __init__(self, buffer_size, min_count):
        """Create an empty tracker.

        Args:
            buffer_size: Maximum number of rewards remembered per prompt;
                older rewards are dropped once the window is full.
            min_count: Minimum number of remembered rewards a prompt needs
                before its own statistics are used for normalization.
        """
        self.buffer_size = buffer_size
        self.min_count = min_count
        # Maps prompt -> deque of the most recent rewards seen for it.
        self.stats = {}

    def update(self, prompts, rewards):
        """Record a batch of rewards and return their normalized advantages.

        Args:
            prompts: Sequence of prompts, one per reward.
            rewards: Sequence of numeric rewards aligned with ``prompts``.

        Returns:
            A NumPy array shaped like ``rewards`` holding
            ``(reward - mean) / std`` for every entry. Floating-point inputs
            keep their dtype; any other dtype is promoted to ``float64``.
        """
        prompts = np.array(prompts)
        rewards = np.array(rewards)
        # `empty_like` copies the dtype of `rewards`; with integer (or bool)
        # rewards the output would be an integer array and the normalized
        # values would be silently truncated on assignment.
        if not np.issubdtype(rewards.dtype, np.floating):
            rewards = rewards.astype(np.float64)
        advantages = np.empty_like(rewards)

        for prompt in np.unique(prompts):
            mask = prompts == prompt
            prompt_rewards = rewards[mask]

            if prompt not in self.stats:
                self.stats[prompt] = deque(maxlen=self.buffer_size)
            history = self.stats[prompt]
            history.extend(prompt_rewards)

            if len(history) < self.min_count:
                # Too little history for this prompt: fall back to the
                # statistics of the whole current batch.
                mean = np.mean(rewards)
                std = np.std(rewards) + 1e-6
            else:
                mean = np.mean(history)
                std = np.std(history) + 1e-6
            # The 1e-6 added above keeps the division finite when all
            # rewards in the window are identical.
            advantages[mask] = (prompt_rewards - mean) / std

        return advantages

    def get_stats(self):
        """Return ``{prompt: {"mean", "std", "count"}}`` for every tracked prompt."""
        return {
            k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)}
            for k, v in self.stats.items()
        }