Commenting pass
This commit is contained in:
		| @@ -36,7 +36,7 @@ def main(_): | ||||
|     # basic Accelerate and logging setup | ||||
|     config = FLAGS.config | ||||
|  | ||||
|     unique_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") | ||||
|     unique_id = datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S") | ||||
|     if not config.run_name: | ||||
|         config.run_name = unique_id | ||||
|     else: | ||||
| @@ -67,8 +67,9 @@ def main(_): | ||||
|         log_with="wandb", | ||||
|         mixed_precision=config.mixed_precision, | ||||
|         project_config=accelerator_config, | ||||
|         # we always accumulate gradients across timesteps; config.train.gradient_accumulation_steps is the number of | ||||
|         # _samples_ to accumulate across | ||||
|         # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the | ||||
|         # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get | ||||
|         # the total number of optimizer steps to accumulate across. | ||||
|         gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps, | ||||
|     ) | ||||
|     if accelerator.is_main_process: | ||||
| @@ -243,6 +244,7 @@ def main(_): | ||||
|     logger.info(f"  Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}") | ||||
|     logger.info(f"  Number of inner epochs = {config.train.num_inner_epochs}") | ||||
|  | ||||
|     assert config.sample.batch_size >= config.train.batch_size | ||||
|     assert config.sample.batch_size % config.train.batch_size == 0 | ||||
|     assert samples_per_epoch % total_train_batch_size == 0 | ||||
|  | ||||
| @@ -418,6 +420,7 @@ def main(_): | ||||
|                                 noise_pred = pipeline.unet( | ||||
|                                     sample["latents"][:, j], sample["timesteps"][:, j], embeds | ||||
|                                 ).sample | ||||
|                             # compute the log prob of next_latents given latents under the current model | ||||
|                             _, log_prob = ddim_step_with_logprob( | ||||
|                                 pipeline.scheduler, | ||||
|                                 noise_pred, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user