Merge pull request #9 from desaixie/main
Only log rewards from process 0
This commit is contained in:
commit
d7a63516cb
@ -335,14 +335,7 @@ def main(_):
|
|||||||
# collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
|
# collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
|
||||||
samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
|
samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
|
||||||
|
|
||||||
# gather rewards across processes
|
|
||||||
rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
|
|
||||||
|
|
||||||
# log rewards and images
|
|
||||||
accelerator.log(
|
|
||||||
{"reward": rewards, "epoch": epoch, "reward_mean": rewards.mean(), "reward_std": rewards.std()},
|
|
||||||
step=global_step,
|
|
||||||
)
|
|
||||||
# this is a hack to force wandb to log the images as JPEGs instead of PNGs
|
# this is a hack to force wandb to log the images as JPEGs instead of PNGs
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
for i, image in enumerate(images):
|
for i, image in enumerate(images):
|
||||||
@ -353,12 +346,21 @@ def main(_):
|
|||||||
{
|
{
|
||||||
"images": [
|
"images": [
|
||||||
wandb.Image(os.path.join(tmpdir, f"{i}.jpg"), caption=f"{prompt:.25} | {reward:.2f}")
|
wandb.Image(os.path.join(tmpdir, f"{i}.jpg"), caption=f"{prompt:.25} | {reward:.2f}")
|
||||||
for i, (prompt, reward) in enumerate(zip(prompts, rewards))
|
for i, (prompt, reward) in enumerate(zip(prompts, rewards)) # only log rewards from process 0
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
step=global_step,
|
step=global_step,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# gather rewards across processes
|
||||||
|
rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
|
||||||
|
|
||||||
|
# log rewards and images
|
||||||
|
accelerator.log(
|
||||||
|
{"reward": rewards, "epoch": epoch, "reward_mean": rewards.mean(), "reward_std": rewards.std()},
|
||||||
|
step=global_step,
|
||||||
|
)
|
||||||
|
|
||||||
# per-prompt mean/std tracking
|
# per-prompt mean/std tracking
|
||||||
if config.per_prompt_stat_tracking:
|
if config.per_prompt_stat_tracking:
|
||||||
# gather the prompts across processes
|
# gather the prompts across processes
|
||||||
|
Loading…
Reference in New Issue
Block a user