diff --git a/scripts/train.py b/scripts/train.py
index e5e832e..ae2b71e 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -335,14 +335,7 @@ def main(_):
         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
 
-        # gather rewards across processes
-        rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
-
-        # log rewards and images
-        accelerator.log(
-            {"reward": rewards, "epoch": epoch, "reward_mean": rewards.mean(), "reward_std": rewards.std()},
-            step=global_step,
-        )
+        # this is a hack to force wandb to log the images as JPEGs instead of PNGs
 
         with tempfile.TemporaryDirectory() as tmpdir:
             for i, image in enumerate(images):
@@ -353,11 +346,20 @@ def main(_):
                 {
                     "images": [
                         wandb.Image(os.path.join(tmpdir, f"{i}.jpg"), caption=f"{prompt:.25} | {reward:.2f}")
-                        for i, (prompt, reward) in enumerate(zip(prompts, rewards))
+                        for i, (prompt, reward) in enumerate(zip(prompts, rewards))  # only log rewards from process 0
                     ],
                 },
                 step=global_step,
             )
+
+        # gather rewards across processes
+        rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
+
+        # log rewards and images
+        accelerator.log(
+            {"reward": rewards, "epoch": epoch, "reward_mean": rewards.mean(), "reward_std": rewards.std()},
+            step=global_step,
+        )
 
         # per-prompt mean/std tracking
         if config.per_prompt_stat_tracking: