# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
# with the following modifications:
# - It uses the patched version of `ddim_step_with_logprob` from `ddim_with_logprob.py`. As such, it only supports the
# `ddim` scheduler.
# - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step.
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipeline,
    rescale_noise_cfg,
)

from .ddim_with_logprob import ddim_step_with_logprob

@torch.no_grad()
def pipeline_with_logprob(
    self: StableDiffusionPipeline,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of the [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. A higher guidance scale encourages the model to generate images that are closely linked to the
            text `prompt`, usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
            input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_scale` is defined as `φ` in equation 16. of
            [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
            Guidance rescale factor should fix overexposure when using zero terminal SNR.

    Examples:
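        A minimal usage sketch. The checkpoint id, dtype, prompt, and `eta` value below are illustrative
        assumptions, and `pipeline_with_logprob` is assumed to be importable from this module:

        ```py
        >>> import torch
        >>> from diffusers import DDIMScheduler, StableDiffusionPipeline

        >>> pipe = StableDiffusionPipeline.from_pretrained(
        ...     "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
        ... ).to("cuda")
        >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)  # only `ddim` is supported
        >>> images, has_nsfw, all_latents, all_log_probs = pipeline_with_logprob(
        ...     pipe, prompt="a photo of a corgi", num_inference_steps=50, eta=1.0
        ... )
        >>> len(all_latents), len(all_log_probs)
        (51, 50)
        ```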

    Returns:
        `tuple`:
            Unlike the original pipeline, this patched version always returns a plain tuple
            `(images, has_nsfw_concepts, all_latents, all_log_probs)`; the `return_dict` argument is ignored.
            `images` is a list with the generated images and `has_nsfw_concepts` is a list of `bool`s denoting
            whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content,
            according to the `safety_checker`. `all_latents` is a list of the latents of every denoising step
            (including the initial noise) and `all_log_probs` is a list of the log probability of each
            denoising step.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    all_latents = [latents]
    all_log_probs = []
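    # Note: these two lists are the extra outputs of this patched pipeline. After the loop,
    # `all_latents` holds the initial noise plus the latent after every denoising step
    # (num_inference_steps + 1 tensors) and `all_log_probs` holds one log-prob tensor per step.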
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents, log_prob = ddim_step_with_logprob(self.scheduler, noise_pred, t, latents, **extra_step_kwargs)

            all_latents.append(latents)
            all_log_probs.append(log_prob)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    return image, has_nsfw_concept, all_latents, all_log_probs
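

# A hypothetical sketch of how the extra outputs might be consumed downstream (e.g. to build
# per-step trajectories for a policy-gradient method). The variable names are illustrative and
# not part of this module; only the return signature of `pipeline_with_logprob` is assumed:
#
#     images, has_nsfw, all_latents, all_log_probs = pipeline_with_logprob(
#         pipe, prompt=prompts, eta=1.0, output_type="pt"
#     )
#     latents = torch.stack(all_latents, dim=1)      # (batch, num_steps + 1, *latent_shape)
#     log_probs = torch.stack(all_log_probs, dim=1)  # (batch, num_steps), assuming per-sample log probs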