Fix aesthetic score (again), add llava reward
@@ -2,6 +2,8 @@ from collections import defaultdict
 import contextlib
 import os
 import datetime
+from concurrent import futures
+import time
 from absl import app, flags
 from ml_collections import config_flags
 from accelerate import Accelerator
@@ -227,6 +229,10 @@ def main(_):
     # Prepare everything with our `accelerator`.
     trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer)

+    # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks,
+    # which make a request to a remote server running llava inference.
+    executor = futures.ThreadPoolExecutor(max_workers=2)
+
     # Train!
     samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch
     total_train_batch_size = (
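For context on the pattern this hunk introduces: `ThreadPoolExecutor.submit` returns a `concurrent.futures.Future` immediately and runs the callback on a worker thread; calling `.result()` later blocks until it finishes. A minimal, self-contained sketch, where `fetch_reward` is a hypothetical stand-in for the request to the remote llava server:

```python
from concurrent import futures
import time

def fetch_reward(prompt):
    # hypothetical stand-in for an HTTP request to a llava inference server
    time.sleep(0.5)
    return len(prompt) / 10.0

executor = futures.ThreadPoolExecutor(max_workers=2)
future = executor.submit(fetch_reward, "a photo of a cat")  # returns a Future immediately
# ... the main thread is free to keep sampling here ...
reward = future.result()  # blocks until the worker thread finishes
print(reward)  # 1.6
```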
@@ -298,8 +304,10 @@ def main(_):
             log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
             timesteps = pipeline.scheduler.timesteps.repeat(config.sample.batch_size, 1)  # (batch_size, num_steps)

-            # compute rewards
-            rewards, reward_metadata = reward_fn(images, prompts, prompt_metadata)
+            # compute rewards asynchronously
+            rewards = executor.submit(reward_fn, images, prompts, prompt_metadata)
+            # yield to make sure reward computation starts
+            time.sleep(0)

             samples.append(
                 {
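A note on the `time.sleep(0)` after `submit`: submitting only enqueues the task, and under CPython's GIL the worker thread may not get scheduled until the main thread yields. A zero-second sleep is a conventional, best-effort way to yield so the reward computation actually starts before the sampling loop continues; it is a scheduling hint, not a guarantee. A toy sketch of the behavior (the `record_start` task is hypothetical):

```python
import time
from concurrent import futures

started_at = []

def record_start():
    # toy task that records when the worker thread actually began running
    started_at.append(time.monotonic())
    time.sleep(0.2)

executor = futures.ThreadPoolExecutor(max_workers=2)
submitted_at = time.monotonic()
future = executor.submit(record_start)
time.sleep(0)  # yield so the worker can pick the task up promptly
future.result()
print(f"worker started {started_at[0] - submitted_at:.6f}s after submit")
```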
@@ -309,10 +317,21 @@ def main(_):
                     "latents": latents[:, :-1],  # each entry is the latent before timestep t
                     "next_latents": latents[:, 1:],  # each entry is the latent after timestep t
                     "log_probs": log_probs,
-                    "rewards": torch.as_tensor(rewards, device=accelerator.device),
+                    "rewards": rewards,
                 }
             )

+        # wait for all rewards to be computed
+        for sample in tqdm(
+            samples,
+            desc="Waiting for rewards",
+            disable=not accelerator.is_local_main_process,
+            position=0,
+        ):
+            rewards, reward_metadata = sample["rewards"].result()
+            # accelerator.print(reward_metadata)
+            sample["rewards"] = torch.as_tensor(rewards, device=accelerator.device)
+
         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}

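Taken together: the sampling loop now stores a `Future` under `"rewards"`, and this pass resolves each one before the collate step, so by the time `torch.cat` runs every value is a tensor. A self-contained sketch of that store-then-resolve flow, with a stub reward function in place of the real callback:

```python
import torch
from concurrent import futures

executor = futures.ThreadPoolExecutor(max_workers=2)

def stub_reward_fn(n):
    # stand-in for the real reward callback; returns (rewards, metadata)
    return [float(i) for i in range(n)], {"source": "stub"}

# sampling phase: stash futures instead of blocking on each reward call
samples = [{"rewards": executor.submit(stub_reward_fn, 4)} for _ in range(3)]

# resolution phase: block on each future and convert the result to a tensor
for sample in samples:
    rewards, reward_metadata = sample["rewards"].result()
    sample["rewards"] = torch.as_tensor(rewards)

# collate phase: every entry is now a tensor, so torch.cat works uniformly
collated = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
print(collated["rewards"].shape)  # torch.Size([12])
```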
@@ -472,7 +491,7 @@ def main(_):
             # make sure we did an optimization step at the end of the inner epoch
             assert accelerator.sync_gradients

-        if epoch % config.save_freq == 0 and accelerator.is_main_process:
+        if epoch != 0 and epoch % config.save_freq == 0 and accelerator.is_main_process:
             accelerator.save_state()
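The extra `epoch != 0` guard avoids writing a checkpoint at epoch 0, before any training has happened. A tiny illustration of the gating (the `save_freq` value here is hypothetical):

```python
save_freq = 10  # hypothetical config value

for epoch in range(25):
    if epoch != 0 and epoch % save_freq == 0:
        print(f"would save checkpoint at epoch {epoch}")  # fires at 10 and 20, not 0
```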