Compare commits: c67c2adfee...1958463f02

10 commits:

- 1958463f02
- 378dd18298
- bfcba5e28e
- b590ec0a7c
- 500edd2b53
- e17ecd265d
- 5955244f37
- d7a63516cb
- 3130ddfaff
- 173b2bb6e0
README.md (12 changes)

````diff
@@ -46,3 +46,15 @@ accelerate launch scripts/train.py --config config/dgx.py:aesthetic
 ```
 
 If you want to run the LLaVA prompt-image alignment experiments, you need to dedicate a few GPUs to running LLaVA inference using [this repo](https://github.com/kvablack/LLaVA-server/).
+
+## Reward Curves
+<img src="https://github.com/kvablack/ddpo-pytorch/assets/12429600/593c9be3-e2a7-45d8-b1ae-ca4f77197c18" width="49%">
+<img src="https://github.com/kvablack/ddpo-pytorch/assets/12429600/d12fef0a-68b8-4cef-a9b8-cb1b6878fcec" width="49%">
+<img src="https://github.com/kvablack/ddpo-pytorch/assets/12429600/669076d5-2826-4b77-835b-d82e0c18a2a6" width="49%">
+<img src="https://github.com/kvablack/ddpo-pytorch/assets/12429600/393a929e-36af-46f2-8022-33384bdae1c8" width="49%">
+
+## Training using 🤗 `trl`
+
+🤗 `trl` provides a [`DDPOTrainer` class](https://huggingface.co/docs/trl/ddpo_trainer) which lets you fine-tune Stable Diffusion on different reward functions using DDPO. The integration supports LoRA, too. You can check out the [supplementary blog post](https://huggingface.co/blog/trl-ddpo) for additional guidance. The DDPO integration was contributed by @metric-space to `trl`.
+
+
````
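The `DDPOTrainer` route added to the README above needs only a config, a prompt function, a reward function, and a Stable Diffusion pipeline wrapper. The sketch below is illustrative rather than a drop-in recipe: it assumes `trl`'s documented `DDPOConfig`, `DDPOTrainer`, and `DefaultDDPOStableDiffusionPipeline` names, and the `prompt_fn`/`reward_fn` stand-ins are hypothetical, following the same `(images, prompts, metadata) -> (rewards, metadata)` convention used by this repo's reward functions.

```python
import torch
from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline


def prompt_fn():
    # hypothetical prompt sampler: returns (prompt, prompt_metadata)
    return "a photograph of a cat", {}


def reward_fn(images, prompts, metadata):
    # hypothetical reward: returns (rewards, reward_metadata); a real reward
    # would score the decoded images, e.g. with an aesthetic predictor
    return torch.ones(len(images)), {}


config = DDPOConfig(
    num_epochs=10,
    sample_num_steps=50,
    sample_batch_size=4,
    train_batch_size=2,
)
pipeline = DefaultDDPOStableDiffusionPipeline("runwayml/stable-diffusion-v1-5")
trainer = DDPOTrainer(config, reward_fn, prompt_fn, pipeline)
trainer.train()
```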
```diff
@@ -35,7 +35,9 @@ class AestheticScorer(torch.nn.Module):
         self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
         self.mlp = MLP()
-        state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
+        state_dict = torch.load(
+            ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth")
+        )
         self.mlp.load_state_dict(state_dict)
         self.dtype = dtype
         self.eval()
```
```diff
@@ -20,9 +20,13 @@ def _left_broadcast(t, shape):
 
 
 def _get_variance(self, timestep, prev_timestep):
-    alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
+    alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(
+        timestep.device
+    )
     alpha_prod_t_prev = torch.where(
-        prev_timestep.cpu() >= 0, self.alphas_cumprod.gather(0, prev_timestep.cpu()), self.final_alpha_cumprod
+        prev_timestep.cpu() >= 0,
+        self.alphas_cumprod.gather(0, prev_timestep.cpu()),
+        self.final_alpha_cumprod,
     ).to(timestep.device)
     beta_prod_t = 1 - alpha_prod_t
     beta_prod_t_prev = 1 - alpha_prod_t_prev
@@ -86,31 +90,45 @@ def ddim_step_with_logprob(
     # - pred_prev_sample -> "x_t-1"
 
     # 1. get previous step value (=t-1)
-    prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+    prev_timestep = (
+        timestep - self.config.num_train_timesteps // self.num_inference_steps
+    )
     # to prevent OOB on gather
     prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)
 
     # 2. compute alphas, betas
     alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())
     alpha_prod_t_prev = torch.where(
-        prev_timestep.cpu() >= 0, self.alphas_cumprod.gather(0, prev_timestep.cpu()), self.final_alpha_cumprod
+        prev_timestep.cpu() >= 0,
+        self.alphas_cumprod.gather(0, prev_timestep.cpu()),
+        self.final_alpha_cumprod,
     )
     alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)
-    alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(sample.device)
+    alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(
+        sample.device
+    )
 
     beta_prod_t = 1 - alpha_prod_t
 
     # 3. compute predicted original sample from predicted noise also called
     # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
     if self.config.prediction_type == "epsilon":
-        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+        pred_original_sample = (
+            sample - beta_prod_t ** (0.5) * model_output
+        ) / alpha_prod_t ** (0.5)
         pred_epsilon = model_output
     elif self.config.prediction_type == "sample":
         pred_original_sample = model_output
-        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+        pred_epsilon = (
+            sample - alpha_prod_t ** (0.5) * pred_original_sample
+        ) / beta_prod_t ** (0.5)
     elif self.config.prediction_type == "v_prediction":
-        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
-        pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+        pred_original_sample = (alpha_prod_t**0.5) * sample - (
+            beta_prod_t**0.5
+        ) * model_output
+        pred_epsilon = (alpha_prod_t**0.5) * model_output + (
+            beta_prod_t**0.5
+        ) * sample
     else:
         raise ValueError(
             f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
@@ -133,13 +151,19 @@ def ddim_step_with_logprob(
 
     if use_clipped_model_output:
         # the pred_epsilon is always re-derived from the clipped x_0 in Glide
-        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+        pred_epsilon = (
+            sample - alpha_prod_t ** (0.5) * pred_original_sample
+        ) / beta_prod_t ** (0.5)
 
     # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
+    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
+        0.5
+    ) * pred_epsilon
 
     # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-    prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+    prev_sample_mean = (
+        alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+    )
 
     if prev_sample is not None and generator is not None:
         raise ValueError(
@@ -149,7 +173,10 @@ def ddim_step_with_logprob(
 
     if prev_sample is None:
         variance_noise = randn_tensor(
-            model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+            model_output.shape,
+            generator=generator,
+            device=model_output.device,
+            dtype=model_output.dtype,
         )
         prev_sample = prev_sample_mean + std_dev_t * variance_noise
 
```
```diff
@@ -116,7 +116,15 @@ def pipeline_with_logprob(
     width = width or self.unet.config.sample_size * self.vae_scale_factor
 
     # 1. Check inputs. Raise error if not correct
-    self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+    self.check_inputs(
+        prompt,
+        height,
+        width,
+        callback_steps,
+        negative_prompt,
+        prompt_embeds,
+        negative_prompt_embeds,
+    )
 
     # 2. Define call parameters
     if prompt is not None and isinstance(prompt, str):
@@ -133,7 +141,11 @@ def pipeline_with_logprob(
     do_classifier_free_guidance = guidance_scale > 1.0
 
     # 3. Encode input prompt
-    text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+    text_encoder_lora_scale = (
+        cross_attention_kwargs.get("scale", None)
+        if cross_attention_kwargs is not None
+        else None
+    )
     prompt_embeds = self._encode_prompt(
         prompt,
         device,
@@ -172,7 +184,9 @@ def pipeline_with_logprob(
     with self.progress_bar(total=num_inference_steps) as progress_bar:
         for i, t in enumerate(timesteps):
             # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            latent_model_input = (
+                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            )
             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
             # predict the noise residual
@@ -187,27 +201,39 @@ def pipeline_with_logprob(
             # perform guidance
             if do_classifier_free_guidance:
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                noise_pred = noise_pred_uncond + guidance_scale * (
+                    noise_pred_text - noise_pred_uncond
+                )
 
             if do_classifier_free_guidance and guidance_rescale > 0.0:
                 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
-                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+                noise_pred = rescale_noise_cfg(
+                    noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
+                )
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents, log_prob = ddim_step_with_logprob(self.scheduler, noise_pred, t, latents, **extra_step_kwargs)
+            latents, log_prob = ddim_step_with_logprob(
+                self.scheduler, noise_pred, t, latents, **extra_step_kwargs
+            )
 
             all_latents.append(latents)
             all_log_probs.append(log_prob)
 
             # call the callback, if provided
-            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+            if i == len(timesteps) - 1 or (
+                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+            ):
                 progress_bar.update()
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)
 
     if not output_type == "latent":
-        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        image = self.vae.decode(
+            latents / self.vae.config.scaling_factor, return_dict=False
+        )[0]
+        image, has_nsfw_concept = self.run_safety_checker(
+            image, device, prompt_embeds.dtype
+        )
     else:
         image = latents
         has_nsfw_concept = None
@@ -217,7 +243,9 @@ def pipeline_with_logprob(
     else:
         do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
 
-    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+    image = self.image_processor.postprocess(
+        image, output_type=output_type, do_denormalize=do_denormalize
+    )
 
     # Offload last model to CPU
     if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
```
```diff
@@ -35,7 +35,11 @@ def aesthetic_score():
     scorer = AestheticScorer(dtype=torch.float32).cuda()
 
     def _fn(images, prompts, metadata):
-        images = (images * 255).round().clamp(0, 255).to(torch.uint8)
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8)
+        else:
+            images = images.transpose(0, 3, 1, 2)  # NHWC -> NCHW
+            images = torch.tensor(images, dtype=torch.uint8)
         scores = scorer(images)
         return scores, {}
 
```
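After this change the aesthetic reward accepts either a float image tensor in NCHW layout (the path `scripts/train.py` uses) or a uint8 numpy array in NHWC layout. A minimal sketch of both call paths, assuming the `aesthetic_score` factory is importable from this repo's `ddpo_pytorch.rewards` module and that a GPU and the scorer weights are available:

```python
import numpy as np
import torch

# assumed import path for this repo's reward factory
from ddpo_pytorch.rewards import aesthetic_score

reward_fn = aesthetic_score()
prompts, metadata = ["a cat", "a dog"], [{}, {}]

# 1) float tensor in [0, 1], NCHW (as produced with output_type="pt")
images_pt = torch.rand(2, 3, 512, 512)
scores, _ = reward_fn(images_pt, prompts, metadata)

# 2) uint8 numpy array, NHWC (the branch added in this diff)
images_np = np.random.randint(0, 256, size=(2, 512, 512, 3), dtype=np.uint8)
scores, _ = reward_fn(images_np, prompts, metadata)
```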
```diff
@@ -55,7 +59,9 @@ def llava_strict_satisfaction():
     batch_size = 4
     url = "http://127.0.0.1:8085"
     sess = requests.Session()
-    retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False)
+    retries = Retry(
+        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
+    )
     sess.mount("http://", HTTPAdapter(max_retries=retries))
 
     def _fn(images, prompts, metadata):
@@ -121,7 +127,9 @@ def llava_bertscore():
     batch_size = 16
     url = "http://127.0.0.1:8085"
     sess = requests.Session()
-    retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False)
+    retries = Retry(
+        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
+    )
     sess.mount("http://", HTTPAdapter(max_retries=retries))
 
     def _fn(images, prompts, metadata):
@@ -152,8 +160,11 @@ def llava_bertscore():
             # format for LLaVA server
             data = {
                 "images": jpeg_images,
-                "queries": [["Answer concisely: what is going on in this image?"]] * len(image_batch),
-                "answers": [[f"The image contains {prompt}"] for prompt in prompt_batch],
+                "queries": [["Answer concisely: what is going on in this image?"]]
+                * len(image_batch),
+                "answers": [
+                    [f"The image contains {prompt}"] for prompt in prompt_batch
+                ],
             }
             data_bytes = pickle.dumps(data)
 
@@ -167,7 +178,9 @@ def llava_bertscore():
             all_scores += scores.tolist()
 
             # save the precision and f1 scores for analysis
-            all_info["precision"] += np.array(response_data["precision"]).squeeze().tolist()
+            all_info["precision"] += (
+                np.array(response_data["precision"]).squeeze().tolist()
+            )
             all_info["f1"] += np.array(response_data["f1"]).squeeze().tolist()
             all_info["outputs"] += np.array(response_data["outputs"]).squeeze().tolist()
 
```
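The `llava_bertscore` hunk above shows the payload the reward sends to the LLaVA server referenced in the README. A standalone sketch of that round trip follows; the pickled-POST transport and the response keys ("outputs", "precision", "f1") are inferred from the surrounding reward code rather than documented here, and `jpeg_bytes_of` is a hypothetical helper.

```python
import io
import pickle

import requests
from PIL import Image


def jpeg_bytes_of(image: Image.Image) -> bytes:
    # hypothetical helper: JPEG-encode a PIL image to bytes
    buf = io.BytesIO()
    image.save(buf, format="JPEG")
    return buf.getvalue()


url = "http://127.0.0.1:8085"  # same default the reward functions use
images = [Image.new("RGB", (512, 512), "gray") for _ in range(2)]
prompts = ["a cat", "a dog"]

data = {
    "images": [jpeg_bytes_of(im) for im in images],
    "queries": [["Answer concisely: what is going on in this image?"]] * len(images),
    "answers": [[f"The image contains {p}"] for p in prompts],
}
response = requests.post(url, data=pickle.dumps(data))  # assumed transport
response_data = pickle.loads(response.content)  # expected keys: "outputs", "precision", "f1"
```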
scripts/train.py (207 changes)

```diff
@@ -48,7 +48,9 @@ def main(_):
         config.resume_from = os.path.normpath(os.path.expanduser(config.resume_from))
         if "checkpoint_" not in os.path.basename(config.resume_from):
             # get the most recent checkpoint in this directory
-            checkpoints = list(filter(lambda x: "checkpoint_" in x, os.listdir(config.resume_from)))
+            checkpoints = list(
+                filter(lambda x: "checkpoint_" in x, os.listdir(config.resume_from))
+            )
             if len(checkpoints) == 0:
                 raise ValueError(f"No checkpoints found in {config.resume_from}")
             config.resume_from = os.path.join(
@@ -72,11 +74,14 @@ def main(_):
         # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
         # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
         # the total number of optimizer steps to accumulate across.
-        gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
+        gradient_accumulation_steps=config.train.gradient_accumulation_steps
+        * num_train_timesteps,
     )
     if accelerator.is_main_process:
         accelerator.init_trackers(
-            project_name="ddpo-pytorch", config=config.to_dict(), init_kwargs={"wandb": {"name": config.run_name}}
+            project_name="ddpo-pytorch",
+            config=config.to_dict(),
+            init_kwargs={"wandb": {"name": config.run_name}},
         )
     logger.info(f"\n{config}")
 
@@ -84,7 +89,9 @@ def main(_):
     set_seed(config.seed, device_specific=True)
 
     # load scheduler, tokenizer and models.
-    pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision)
+    pipeline = StableDiffusionPipeline.from_pretrained(
+        config.pretrained.model, revision=config.pretrained.revision
+    )
     # freeze parameters of models to save more memory
     pipeline.vae.requires_grad_(False)
     pipeline.text_encoder.requires_grad_(False)
@@ -121,22 +128,36 @@ def main(_):
         lora_attn_procs = {}
         for name in pipeline.unet.attn_processors.keys():
             cross_attention_dim = (
-                None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim
+                None
+                if name.endswith("attn1.processor")
+                else pipeline.unet.config.cross_attention_dim
             )
             if name.startswith("mid_block"):
                 hidden_size = pipeline.unet.config.block_out_channels[-1]
             elif name.startswith("up_blocks"):
                 block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[block_id]
+                hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[
+                    block_id
+                ]
             elif name.startswith("down_blocks"):
                 block_id = int(name[len("down_blocks.")])
                 hidden_size = pipeline.unet.config.block_out_channels[block_id]
 
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+            lora_attn_procs[name] = LoRAAttnProcessor(
+                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+            )
         pipeline.unet.set_attn_processor(lora_attn_procs)
-        trainable_layers = AttnProcsLayers(pipeline.unet.attn_processors)
+
+        # this is a hack to synchronize gradients properly. the module that registers the parameters we care about (in
+        # this case, AttnProcsLayers) needs to also be used for the forward pass. AttnProcsLayers doesn't have a
+        # `forward` method, so we wrap it to add one and capture the rest of the unet parameters using a closure.
+        class _Wrapper(AttnProcsLayers):
+            def forward(self, *args, **kwargs):
+                return pipeline.unet(*args, **kwargs)
+
+        unet = _Wrapper(pipeline.unet.attn_processors)
     else:
-        trainable_layers = pipeline.unet
+        unet = pipeline.unet
 
     # set up diffusers-friendly checkpoint saving with Accelerate
 
@@ -155,13 +176,19 @@ def main(_):
         if config.use_lora and isinstance(models[0], AttnProcsLayers):
             # pipeline.unet.load_attn_procs(input_dir)
             tmp_unet = UNet2DConditionModel.from_pretrained(
-                config.pretrained.model, revision=config.pretrained.revision, subfolder="unet"
+                config.pretrained.model,
+                revision=config.pretrained.revision,
+                subfolder="unet",
             )
             tmp_unet.load_attn_procs(input_dir)
-            models[0].load_state_dict(AttnProcsLayers(tmp_unet.attn_processors).state_dict())
+            models[0].load_state_dict(
+                AttnProcsLayers(tmp_unet.attn_processors).state_dict()
+            )
             del tmp_unet
         elif not config.use_lora and isinstance(models[0], UNet2DConditionModel):
-            load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+            load_model = UNet2DConditionModel.from_pretrained(
+                input_dir, subfolder="unet"
+            )
             models[0].register_to_config(**load_model.config)
             models[0].load_state_dict(load_model.state_dict())
             del load_model
@@ -191,7 +218,7 @@ def main(_):
         optimizer_cls = torch.optim.AdamW
 
     optimizer = optimizer_cls(
-        trainable_layers.parameters(),
+        unet.parameters(),
         lr=config.train.learning_rate,
         betas=(config.train.adam_beta1, config.train.adam_beta2),
         weight_decay=config.train.adam_weight_decay,
@@ -225,29 +252,42 @@ def main(_):
     # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
     # more memory
     autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
+    # autocast = accelerator.autocast
 
     # Prepare everything with our `accelerator`.
-    trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer)
+    unet, optimizer = accelerator.prepare(unet, optimizer)
 
     # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks which makes a request to a
     # remote server running llava inference.
     executor = futures.ThreadPoolExecutor(max_workers=2)
 
     # Train!
-    samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch
+    samples_per_epoch = (
+        config.sample.batch_size
+        * accelerator.num_processes
+        * config.sample.num_batches_per_epoch
+    )
     total_train_batch_size = (
-        config.train.batch_size * accelerator.num_processes * config.train.gradient_accumulation_steps
+        config.train.batch_size
+        * accelerator.num_processes
+        * config.train.gradient_accumulation_steps
     )
 
     logger.info("***** Running training *****")
     logger.info(f" Num Epochs = {config.num_epochs}")
     logger.info(f" Sample batch size per device = {config.sample.batch_size}")
     logger.info(f" Train batch size per device = {config.train.batch_size}")
-    logger.info(f" Gradient Accumulation steps = {config.train.gradient_accumulation_steps}")
+    logger.info(
+        f" Gradient Accumulation steps = {config.train.gradient_accumulation_steps}"
+    )
     logger.info("")
     logger.info(f" Total number of samples per epoch = {samples_per_epoch}")
-    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
-    logger.info(f" Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}")
+    logger.info(
+        f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}"
+    )
+    logger.info(
+        f" Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}"
+    )
     logger.info(f" Number of inner epochs = {config.train.num_inner_epochs}")
 
     assert config.sample.batch_size >= config.train.batch_size
@@ -275,7 +315,10 @@ def main(_):
         ):
             # generate prompts
             prompts, prompt_metadata = zip(
-                *[prompt_fn(**config.prompt_fn_kwargs) for _ in range(config.sample.batch_size)]
+                *[
+                    prompt_fn(**config.prompt_fn_kwargs)
+                    for _ in range(config.sample.batch_size)
+                ]
             )
 
             # encode prompts
@@ -300,9 +343,13 @@ def main(_):
                     output_type="pt",
                 )
 
-            latents = torch.stack(latents, dim=1)  # (batch_size, num_steps + 1, 4, 64, 64)
+            latents = torch.stack(
+                latents, dim=1
+            )  # (batch_size, num_steps + 1, 4, 64, 64)
             log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
-            timesteps = pipeline.scheduler.timesteps.repeat(config.sample.batch_size, 1)  # (batch_size, num_steps)
+            timesteps = pipeline.scheduler.timesteps.repeat(
+                config.sample.batch_size, 1
+            )  # (batch_size, num_steps)
 
             # compute rewards asynchronously
             rewards = executor.submit(reward_fn, images, prompts, prompt_metadata)
@@ -314,8 +361,12 @@ def main(_):
                     "prompt_ids": prompt_ids,
                     "prompt_embeds": prompt_embeds,
                     "timesteps": timesteps,
-                    "latents": latents[:, :-1],  # each entry is the latent before timestep t
-                    "next_latents": latents[:, 1:],  # each entry is the latent after timestep t
+                    "latents": latents[
+                        :, :-1
+                    ],  # each entry is the latent before timestep t
+                    "next_latents": latents[
+                        :, 1:
+                    ],  # each entry is the latent after timestep t
                     "log_probs": log_probs,
                     "rewards": rewards,
                 }
@@ -335,35 +386,50 @@ def main(_):
         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
 
-        # gather rewards across processes
-        rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
-
-        # log rewards and images
-        accelerator.log(
-            {"reward": rewards, "epoch": epoch, "reward_mean": rewards.mean(), "reward_std": rewards.std()},
-            step=global_step,
-        )
         # this is a hack to force wandb to log the images as JPEGs instead of PNGs
         with tempfile.TemporaryDirectory() as tmpdir:
             for i, image in enumerate(images):
-                pil = Image.fromarray((image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8))
+                pil = Image.fromarray(
+                    (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
+                )
                 pil = pil.resize((256, 256))
                 pil.save(os.path.join(tmpdir, f"{i}.jpg"))
             accelerator.log(
                 {
                     "images": [
-                        wandb.Image(os.path.join(tmpdir, f"{i}.jpg"), caption=f"{prompt:.25} | {reward:.2f}")
-                        for i, (prompt, reward) in enumerate(zip(prompts, rewards))
+                        wandb.Image(
+                            os.path.join(tmpdir, f"{i}.jpg"),
+                            caption=f"{prompt:.25} | {reward:.2f}",
+                        )
+                        for i, (prompt, reward) in enumerate(
+                            zip(prompts, rewards)
+                        )  # only log rewards from process 0
                     ],
                 },
                 step=global_step,
             )
 
+        # gather rewards across processes
+        rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
+
+        # log rewards and images
+        accelerator.log(
+            {
+                "reward": rewards,
+                "epoch": epoch,
+                "reward_mean": rewards.mean(),
+                "reward_std": rewards.std(),
+            },
+            step=global_step,
+        )
+
         # per-prompt mean/std tracking
         if config.per_prompt_stat_tracking:
             # gather the prompts across processes
             prompt_ids = accelerator.gather(samples["prompt_ids"]).cpu().numpy()
-            prompts = pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
+            prompts = pipeline.tokenizer.batch_decode(
+                prompt_ids, skip_special_tokens=True
+            )
             advantages = stat_tracker.update(prompts, rewards)
         else:
             advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
@@ -379,7 +445,10 @@ def main(_):
         del samples["prompt_ids"]
 
         total_batch_size, num_timesteps = samples["timesteps"].shape
-        assert total_batch_size == config.sample.batch_size * config.sample.num_batches_per_epoch
+        assert (
+            total_batch_size
+            == config.sample.batch_size * config.sample.num_batches_per_epoch
+        )
         assert num_timesteps == config.sample.num_steps
 
         #################### TRAINING ####################
@@ -390,16 +459,27 @@ def main(_):
 
             # shuffle along time dimension independently for each sample
             perms = torch.stack(
-                [torch.randperm(num_timesteps, device=accelerator.device) for _ in range(total_batch_size)]
+                [
+                    torch.randperm(num_timesteps, device=accelerator.device)
+                    for _ in range(total_batch_size)
+                ]
            )
             for key in ["timesteps", "latents", "next_latents", "log_probs"]:
-                samples[key] = samples[key][torch.arange(total_batch_size, device=accelerator.device)[:, None], perms]
+                samples[key] = samples[key][
+                    torch.arange(total_batch_size, device=accelerator.device)[:, None],
+                    perms,
+                ]
 
             # rebatch for training
-            samples_batched = {k: v.reshape(-1, config.train.batch_size, *v.shape[1:]) for k, v in samples.items()}
+            samples_batched = {
+                k: v.reshape(-1, config.train.batch_size, *v.shape[1:])
+                for k, v in samples.items()
+            }
 
             # dict of lists -> list of dicts for easier iteration
-            samples_batched = [dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())]
+            samples_batched = [
+                dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())
+            ]
 
             # train
             pipeline.unet.train()
@@ -412,7 +492,9 @@ def main(_):
             ):
                 if config.train.cfg:
                     # concat negative prompts to sample prompts to avoid two forward passes
-                    embeds = torch.cat([train_neg_prompt_embeds, sample["prompt_embeds"]])
+                    embeds = torch.cat(
+                        [train_neg_prompt_embeds, sample["prompt_embeds"]]
+                    )
                 else:
                     embeds = sample["prompt_embeds"]
 
@@ -423,21 +505,25 @@ def main(_):
                     leave=False,
                     disable=not accelerator.is_local_main_process,
                 ):
-                    with accelerator.accumulate(pipeline.unet):
+                    with accelerator.accumulate(unet):
                         with autocast():
                             if config.train.cfg:
-                                noise_pred = pipeline.unet(
+                                noise_pred = unet(
                                     torch.cat([sample["latents"][:, j]] * 2),
                                     torch.cat([sample["timesteps"][:, j]] * 2),
                                     embeds,
                                 ).sample
                                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                                noise_pred = noise_pred_uncond + config.sample.guidance_scale * (
-                                    noise_pred_text - noise_pred_uncond
+                                noise_pred = (
+                                    noise_pred_uncond
+                                    + config.sample.guidance_scale
+                                    * (noise_pred_text - noise_pred_uncond)
                                 )
                             else:
-                                noise_pred = pipeline.unet(
-                                    sample["latents"][:, j], sample["timesteps"][:, j], embeds
+                                noise_pred = unet(
+                                    sample["latents"][:, j],
+                                    sample["timesteps"][:, j],
+                                    embeds,
                                 ).sample
                             # compute the log prob of next_latents given latents under the current model
                             _, log_prob = ddim_step_with_logprob(
@@ -451,12 +537,16 @@ def main(_):
 
                         # ppo logic
                         advantages = torch.clamp(
-                            sample["advantages"], -config.train.adv_clip_max, config.train.adv_clip_max
+                            sample["advantages"],
+                            -config.train.adv_clip_max,
+                            config.train.adv_clip_max,
                        )
                         ratio = torch.exp(log_prob - sample["log_probs"][:, j])
                         unclipped_loss = -advantages * ratio
                         clipped_loss = -advantages * torch.clamp(
-                            ratio, 1.0 - config.train.clip_range, 1.0 + config.train.clip_range
+                            ratio,
+                            1.0 - config.train.clip_range,
+                            1.0 + config.train.clip_range,
                         )
                         loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss))
 
@@ -464,14 +554,25 @@ def main(_):
                         # John Schulman says that (ratio - 1) - log(ratio) is a better
                         # estimator, but most existing code uses this so...
                         # http://joschu.net/blog/kl-approx.html
-                        info["approx_kl"].append(0.5 * torch.mean((log_prob - sample["log_probs"][:, j]) ** 2))
-                        info["clipfrac"].append(torch.mean((torch.abs(ratio - 1.0) > config.train.clip_range).float()))
+                        info["approx_kl"].append(
+                            0.5
+                            * torch.mean((log_prob - sample["log_probs"][:, j]) ** 2)
+                        )
+                        info["clipfrac"].append(
+                            torch.mean(
+                                (
+                                    torch.abs(ratio - 1.0) > config.train.clip_range
+                                ).float()
+                            )
+                        )
                         info["loss"].append(loss)
 
                         # backward pass
                         accelerator.backward(loss)
                         if accelerator.sync_gradients:
-                            accelerator.clip_grad_norm_(trainable_layers.parameters(), config.train.max_grad_norm)
+                            accelerator.clip_grad_norm_(
+                                unet.parameters(), config.train.max_grad_norm
+                            )
                         optimizer.step()
                         optimizer.zero_grad()
 
```