Only log rewards from process 0
This commit is contained in:
		| @@ -335,14 +335,7 @@ def main(_): | |||||||
|         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...) |         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...) | ||||||
|         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()} |         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()} | ||||||
|  |  | ||||||
|         # gather rewards across processes |          | ||||||
|         rewards = accelerator.gather(samples["rewards"]).cpu().numpy() |  | ||||||
|  |  | ||||||
|         # log rewards and images |  | ||||||
|         accelerator.log( |  | ||||||
|             {"reward": rewards, "epoch": epoch, "reward_mean": rewards.mean(), "reward_std": rewards.std()}, |  | ||||||
|             step=global_step, |  | ||||||
|         ) |  | ||||||
|         # this is a hack to force wandb to log the images as JPEGs instead of PNGs |         # this is a hack to force wandb to log the images as JPEGs instead of PNGs | ||||||
|         with tempfile.TemporaryDirectory() as tmpdir: |         with tempfile.TemporaryDirectory() as tmpdir: | ||||||
|             for i, image in enumerate(images): |             for i, image in enumerate(images): | ||||||
| @@ -353,11 +346,20 @@ def main(_): | |||||||
|                 { |                 { | ||||||
|                     "images": [ |                     "images": [ | ||||||
|                         wandb.Image(os.path.join(tmpdir, f"{i}.jpg"), caption=f"{prompt:.25} | {reward:.2f}") |                         wandb.Image(os.path.join(tmpdir, f"{i}.jpg"), caption=f"{prompt:.25} | {reward:.2f}") | ||||||
|                         for i, (prompt, reward) in enumerate(zip(prompts, rewards)) |                         for i, (prompt, reward) in enumerate(zip(prompts, rewards))  # only log rewards from process 0 | ||||||
|                     ], |                     ], | ||||||
|                 }, |                 }, | ||||||
|                 step=global_step, |                 step=global_step, | ||||||
|             ) |             ) | ||||||
|  |              | ||||||
|  |         # gather rewards across processes | ||||||
|  |         rewards = accelerator.gather(samples["rewards"]).cpu().numpy() | ||||||
|  |  | ||||||
|  |         # log rewards and images | ||||||
|  |         accelerator.log( | ||||||
|  |             {"reward": rewards, "epoch": epoch, "reward_mean": rewards.mean(), "reward_std": rewards.std()}, | ||||||
|  |             step=global_step, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         # per-prompt mean/std tracking |         # per-prompt mean/std tracking | ||||||
|         if config.per_prompt_stat_tracking: |         if config.per_prompt_stat_tracking: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user