Minor changes; add train_timestep_fraction
@@ -32,14 +32,15 @@ def get_config():
     train.cfg = True
     train.adv_clip_max = 10
     train.clip_range = 1e-4
+    train.timestep_fraction = 1.0
 
     # sampling
     config.sample = sample = ml_collections.ConfigDict()
-    sample.num_steps = 5
+    sample.num_steps = 10
     sample.eta = 1.0
     sample.guidance_scale = 5.0
     sample.batch_size = 1
-    sample.num_batches_per_epoch = 1
+    sample.num_batches_per_epoch = 2
 
     # prompting
     config.prompt_fn = "imagenet_animals"
@@ -49,7 +50,7 @@ def get_config():
     config.reward_fn = "jpeg_compressibility"
 
     config.per_prompt_stat_tracking = ml_collections.ConfigDict()
-    config.per_prompt_stat_tracking.buffer_size = 64
+    config.per_prompt_stat_tracking.buffer_size = 16
    config.per_prompt_stat_tracking.min_count = 16
 
     return config
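The new train.timestep_fraction knob sets what fraction of each sampled trajectory's denoising steps is actually trained on. A minimal sketch of the arithmetic it drives, using this config's values (variable names mirror the training script further down):

    # Sketch only: how timestep_fraction becomes a step count.
    num_steps = 10           # sample.num_steps
    timestep_fraction = 1.0  # train.timestep_fraction
    num_train_timesteps = int(num_steps * timestep_fraction)
    assert num_train_timesteps == 10  # e.g. a fraction of 0.5 would give 5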
@@ -8,20 +8,25 @@ base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py"
 def get_config():
     config = base.get_config()
 
-    config.mixed_precision = "no"
+    config.pretrained.model = "runwayml/stable-diffusion-v1-5"
+
+    config.mixed_precision = "fp16"
     config.allow_tf32 = True
     config.use_lora = False
 
     config.train.batch_size = 4
-    config.train.gradient_accumulation_steps = 8
-    config.train.learning_rate = 1e-5
-    config.train.clip_range = 1.0
+    config.train.gradient_accumulation_steps = 2
+    config.train.learning_rate = 3e-5
+    config.train.clip_range = 1e-4
 
     # sampling
     config.sample.num_steps = 50
-    config.sample.batch_size = 16
-    config.sample.num_batches_per_epoch = 2
+    config.sample.batch_size = 8
+    config.sample.num_batches_per_epoch = 4
 
-    config.per_prompt_stat_tracking = None
+    config.per_prompt_stat_tracking = {
+        "buffer_size": 16,
+        "min_count": 16,
+    }
 
     return config
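For scale, the sampling and training numbers above have to line up. A quick sanity check of the implied per-GPU batch arithmetic (a sketch; these standalone names are illustrative, not from the repo):

    # Illustrative arithmetic for the config above.
    sample_batch_size = 8
    num_batches_per_epoch = 4
    samples_per_epoch = sample_batch_size * num_batches_per_epoch  # 32 samples per GPU per epoch

    train_batch_size = 4
    gradient_accumulation_steps = 2
    samples_per_optimizer_step = train_batch_size * gradient_accumulation_steps  # 8

    # the sample buffer must rebatch evenly into training batches
    assert samples_per_epoch % train_batch_size == 0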
@@ -14,18 +14,15 @@ class MLP(nn.Module):
         super().__init__()
         self.layers = nn.Sequential(
             nn.Linear(768, 1024),
-            nn.Identity(),
+            nn.Dropout(0.2),
             nn.Linear(1024, 128),
-            nn.Identity(),
+            nn.Dropout(0.2),
             nn.Linear(128, 64),
-            nn.Identity(),
+            nn.Dropout(0.1),
             nn.Linear(64, 16),
             nn.Linear(16, 1),
         )
 
-        state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
-        self.load_state_dict(state_dict)
-
     @torch.no_grad()
     def forward(self, embed):
         return self.layers(embed)
@@ -37,6 +34,9 @@ class AestheticScorer(torch.nn.Module):
         self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
         self.mlp = MLP()
+        state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
+        self.mlp.load_state_dict(state_dict)
+        self.eval()
 
     @torch.no_grad()
     def __call__(self, images):
@@ -44,5 +44,5 @@ class AestheticScorer(torch.nn.Module):
         inputs = {k: v.cuda() for k, v in inputs.items()}
         embed = self.clip.get_image_features(**inputs)
         # normalize embedding
-        embed = embed / embed.norm(dim=-1, keepdim=True)
+        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
         return self.mlp(embed)
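The normalization change is a pure spelling update: Tensor.norm and torch.linalg.vector_norm both compute the L2 norm over the last dimension here. A quick equivalence check (assumes a PyTorch version that ships torch.linalg.vector_norm):

    import torch

    embed = torch.randn(4, 768)
    a = embed / embed.norm(dim=-1, keepdim=True)
    b = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
    assert torch.allclose(a, b)  # identical up to floating-point noise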
@@ -35,8 +35,6 @@ def aesthetic_score():
     scorer = AestheticScorer().cuda()
 
     def _fn(images, prompts, metadata):
-        if not isinstance(images, torch.Tensor):
-            images = torch.as_tensor(images)
         scores = scorer(images)
         return scores, {}
 
@@ -34,17 +34,23 @@ logger = get_logger(__name__)
 def main(_):
     # basic Accelerate and logging setup
     config = FLAGS.config
+
+    # number of timesteps within each trajectory to train on
+    num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction)
+
     accelerator = Accelerator(
         log_with="wandb",
         mixed_precision=config.mixed_precision,
         project_dir=config.logdir,
-        gradient_accumulation_steps=config.train.gradient_accumulation_steps * config.sample.num_steps,
+        # we always accumulate gradients across timesteps; config.train.gradient_accumulation_steps is the number of
+        # _samples_ to accumulate across
+        gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
     )
     if accelerator.is_main_process:
         accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict())
     logger.info(f"\n{config}")
 
-    # set seed
+    # set seed (device_specific is very important to get different prompts on different devices)
     set_seed(config.seed, device_specific=True)
 
     # load scheduler, tokenizer and models.
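Multiplying by num_train_timesteps makes Accelerate count every (sample microbatch, timestep) pair as one accumulation step, so the optimizer only fires after whole samples' worth of timesteps have contributed. A sketch of the cadence with the numbers from the config above (illustrative, not repo code):

    # With sample.num_steps = 50, timestep_fraction = 1.0,
    # and train.gradient_accumulation_steps = 2:
    num_train_timesteps = int(50 * 1.0)                # 50
    accelerate_accumulation = 2 * num_train_timesteps  # 100
    # -> one optimizer step per 100 backward passes:
    #    2 sample batches, each unrolled over 50 trained timesteps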
@@ -152,7 +158,8 @@ def main(_):
             config.per_prompt_stat_tracking.min_count,
         )
 
-    # for some reason, autocast is necessary for non-lora training but for lora training it uses more memory
+    # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
+    # more memory
     autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
 
     # Prepare everything with our `accelerator`.
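The two branches are interchangeable here because contextlib.nullcontext and accelerator.autocast are both zero-argument callables returning a context manager, so the rest of the script can always write `with autocast():`. A minimal sketch of the no-op (LoRA) branch:

    import contextlib

    autocast = contextlib.nullcontext  # the use_lora branch
    with autocast():
        pass  # forward pass goes here; per the comment above, LoRA training needs no autocast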
@@ -289,8 +296,15 @@ def main(_):
         #################### TRAINING ####################
         for inner_epoch in range(config.train.num_inner_epochs):
             # shuffle samples along batch dimension
-            indices = torch.randperm(total_batch_size, device=accelerator.device)
-            samples = {k: v[indices] for k, v in samples.items()}
+            perm = torch.randperm(total_batch_size, device=accelerator.device)
+            samples = {k: v[perm] for k, v in samples.items()}
+
+            # shuffle along time dimension independently for each sample
+            perms = torch.stack(
+                [torch.randperm(num_timesteps, device=accelerator.device) for _ in range(total_batch_size)]
+            )
+            for key in ["timesteps", "latents", "next_latents", "log_probs"]:
+                samples[key] = samples[key][torch.arange(total_batch_size, device=accelerator.device)[:, None], perms]
 
             # rebatch for training
             samples_batched = {k: v.reshape(-1, config.train.batch_size, *v.shape[1:]) for k, v in samples.items()}
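The last added line uses advanced indexing to apply a different permutation to each row: the column vector arange(total_batch_size)[:, None] broadcasts against perms, so row i is gathered at positions perms[i]. A self-contained demo of that pattern:

    import torch

    batch, T = 3, 4
    x = torch.arange(batch * T).reshape(batch, T)
    perms = torch.stack([torch.randperm(T) for _ in range(batch)])
    shuffled = x[torch.arange(batch)[:, None], perms]
    # row i of `shuffled` is x[i] reordered by perms[i]; reusing the same
    # `perms` for every key keeps timesteps/latents/log_probs aligned
    for i in range(batch):
        assert torch.equal(shuffled[i], x[i][perms[i]])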
@@ -300,6 +314,7 @@
 
             # train
             pipeline.unet.train()
+            info = defaultdict(list)
             for i, sample in tqdm(
                 list(enumerate(samples_batched)),
                 desc=f"Epoch {epoch}.{inner_epoch}: training",
@@ -312,9 +327,8 @@
                 else:
                     embeds = sample["prompt_embeds"]
 
-                info = defaultdict(list)
                 for j in tqdm(
-                    range(num_timesteps),
+                    range(num_train_timesteps),
                     desc="Timestep",
                     position=1,
                     leave=False,
@@ -371,14 +385,20 @@
                         optimizer.step()
                         optimizer.zero_grad()
 
                     # Checks if the accelerator has performed an optimization step behind the scenes
                     if accelerator.sync_gradients:
+                        assert (j == num_train_timesteps - 1) and (i + 1) % config.train.gradient_accumulation_steps == 0
                         # log training-related stuff
                         info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                         info = accelerator.reduce(info, reduction="mean")
+                        info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                         accelerator.log(info, step=global_step)
                         global_step += 1
+                        info = defaultdict(list)
 
+            # make sure we did an optimization step at the end of the inner epoch
+            assert accelerator.sync_gradients
+
 
 if __name__ == "__main__":
     app.run(main)
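With the accumulation setup above, accelerator.sync_gradients flips to True exactly on the microbatch where the optimizer actually steps, which is what lets the new assert pin it to the last trained timestep of the last accumulated sample. A small standalone check of that cadence (illustrative counts, not repo code):

    num_train_timesteps = 50
    gradient_accumulation_steps = 2  # samples accumulated per optimizer step
    num_batches = 4

    sync_points = []
    for i in range(num_batches):
        for j in range(num_train_timesteps):
            microbatch = i * num_train_timesteps + j + 1
            if microbatch % (gradient_accumulation_steps * num_train_timesteps) == 0:
                # mirrors the assert added in the diff
                assert j == num_train_timesteps - 1 and (i + 1) % gradient_accumulation_steps == 0
                sync_points.append((i, j))

    print(sync_points)  # [(1, 49), (3, 49)]: one optimizer step per 2 sample batches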