Merge pull request #9 from desaixie/main
Only log rewards from process 0
This commit is contained in:
commit
d7a63516cb
@ -335,14 +335,7 @@ def main(_):
|
||||
# collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
|
||||
samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
|
||||
|
||||
# gather rewards across processes
|
||||
rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
|
||||
|
||||
# log rewards and images
|
||||
accelerator.log(
|
||||
{"reward": rewards, "epoch": epoch, "reward_mean": rewards.mean(), "reward_std": rewards.std()},
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
# this is a hack to force wandb to log the images as JPEGs instead of PNGs
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
for i, image in enumerate(images):
|
||||
@ -353,11 +346,20 @@ def main(_):
|
||||
{
|
||||
"images": [
|
||||
wandb.Image(os.path.join(tmpdir, f"{i}.jpg"), caption=f"{prompt:.25} | {reward:.2f}")
|
||||
for i, (prompt, reward) in enumerate(zip(prompts, rewards))
|
||||
for i, (prompt, reward) in enumerate(zip(prompts, rewards)) # only log rewards from process 0
|
||||
],
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
# gather rewards across processes
|
||||
rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
|
||||
|
||||
# log rewards and images
|
||||
accelerator.log(
|
||||
{"reward": rewards, "epoch": epoch, "reward_mean": rewards.mean(), "reward_std": rewards.std()},
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
# per-prompt mean/std tracking
|
||||
if config.per_prompt_stat_tracking:
|
||||
|
Loading…
Reference in New Issue
Block a user