Merge pull request #9 from desaixie/main

Only log rewards from process 0
This commit is contained in:
Kevin Black 2023-08-22 11:54:52 -07:00 committed by GitHub
commit d7a63516cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -335,14 +335,7 @@ def main(_):
# collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
# gather rewards across processes
rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
# log rewards and images
accelerator.log(
{"reward": rewards, "epoch": epoch, "reward_mean": rewards.mean(), "reward_std": rewards.std()},
step=global_step,
)
# this is a hack to force wandb to log the images as JPEGs instead of PNGs
with tempfile.TemporaryDirectory() as tmpdir:
for i, image in enumerate(images):
@@ -353,11 +346,20 @@ def main(_):
{
"images": [
wandb.Image(os.path.join(tmpdir, f"{i}.jpg"), caption=f"{prompt:.25} | {reward:.2f}")
for i, (prompt, reward) in enumerate(zip(prompts, rewards))
for i, (prompt, reward) in enumerate(zip(prompts, rewards)) # only log rewards from process 0
],
},
step=global_step,
)
# gather rewards across processes
rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
# log rewards and images
accelerator.log(
{"reward": rewards, "epoch": epoch, "reward_mean": rewards.mean(), "reward_std": rewards.std()},
step=global_step,
)
# per-prompt mean/std tracking
if config.per_prompt_stat_tracking: