Only log rewards from process 0
This commit is contained in:
		| @@ -335,14 +335,7 @@ def main(_): | ||||
|         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...) | ||||
|         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()} | ||||
|  | ||||
|         # gather rewards across processes | ||||
|         rewards = accelerator.gather(samples["rewards"]).cpu().numpy() | ||||
|          | ||||
|         # log rewards and images | ||||
|         accelerator.log( | ||||
|             {"reward": rewards, "epoch": epoch, "reward_mean": rewards.mean(), "reward_std": rewards.std()}, | ||||
|             step=global_step, | ||||
|         ) | ||||
|         # this is a hack to force wandb to log the images as JPEGs instead of PNGs | ||||
|         with tempfile.TemporaryDirectory() as tmpdir: | ||||
|             for i, image in enumerate(images): | ||||
| @@ -353,12 +346,21 @@ def main(_): | ||||
|                 { | ||||
|                     "images": [ | ||||
|                         wandb.Image(os.path.join(tmpdir, f"{i}.jpg"), caption=f"{prompt:.25} | {reward:.2f}") | ||||
|                         for i, (prompt, reward) in enumerate(zip(prompts, rewards)) | ||||
|                         for i, (prompt, reward) in enumerate(zip(prompts, rewards))  # only log rewards from process 0 | ||||
|                     ], | ||||
|                 }, | ||||
|                 step=global_step, | ||||
|             ) | ||||
|              | ||||
|         # gather rewards across processes | ||||
|         rewards = accelerator.gather(samples["rewards"]).cpu().numpy() | ||||
|  | ||||
|         # log rewards and images | ||||
|         accelerator.log( | ||||
|             {"reward": rewards, "epoch": epoch, "reward_mean": rewards.mean(), "reward_std": rewards.std()}, | ||||
|             step=global_step, | ||||
|         ) | ||||
|  | ||||
|         # per-prompt mean/std tracking | ||||
|         if config.per_prompt_stat_tracking: | ||||
|             # gather the prompts across processes | ||||
|   | ||||
		Reference in New Issue
	
	Block a user