Minor changes; add train_timestep_fraction

Kevin Black, 2023-06-27 22:17:32 -07:00
parent bae3f43f5f
commit 28d2d8c40e
5 changed files with 50 additions and 26 deletions

config/base.py

@@ -32,14 +32,15 @@ def get_config():
     train.cfg = True
     train.adv_clip_max = 10
     train.clip_range = 1e-4
+    train.timestep_fraction = 1.0

     # sampling
     config.sample = sample = ml_collections.ConfigDict()
-    sample.num_steps = 5
+    sample.num_steps = 10
     sample.eta = 1.0
     sample.guidance_scale = 5.0
     sample.batch_size = 1
-    sample.num_batches_per_epoch = 1
+    sample.num_batches_per_epoch = 2

     # prompting
     config.prompt_fn = "imagenet_animals"
@@ -49,7 +50,7 @@ def get_config():
     config.reward_fn = "jpeg_compressibility"

     config.per_prompt_stat_tracking = ml_collections.ConfigDict()
-    config.per_prompt_stat_tracking.buffer_size = 64
+    config.per_prompt_stat_tracking.buffer_size = 16
     config.per_prompt_stat_tracking.min_count = 16

     return config
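
For reference, `buffer_size` and `min_count` control per-prompt reward normalization: rewards are standardized against a running buffer kept for each prompt, falling back to global statistics until a prompt has seen at least `min_count` rewards. A minimal sketch of the idea (hypothetical class, not the repo's implementation):

```python
from collections import deque

import numpy as np


# Minimal sketch of per-prompt reward normalization (hypothetical class; the repo's
# actual tracker may differ). Rewards are standardized against a running buffer kept
# per prompt; until a prompt has seen `min_count` rewards, global statistics are used.
class PerPromptStats:
    def __init__(self, buffer_size=16, min_count=16):
        self.buffer_size = buffer_size
        self.min_count = min_count
        self.buffers = {}

    def update(self, prompts, rewards):
        prompts = np.asarray(prompts)
        rewards = np.asarray(rewards, dtype=np.float64)
        advantages = np.empty_like(rewards)
        for prompt in np.unique(prompts):
            mask = prompts == prompt
            buf = self.buffers.setdefault(prompt, deque(maxlen=self.buffer_size))
            buf.extend(rewards[mask])
            if len(buf) < self.min_count:
                mean, std = rewards.mean(), rewards.std() + 1e-6
            else:
                mean, std = np.mean(buf), np.std(buf) + 1e-6
            advantages[mask] = (rewards[mask] - mean) / std
        return advantages
```

With `buffer_size = 16` and `min_count = 16`, a prompt's own statistics only take over once its buffer is full.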

config/dgx.py

@@ -8,20 +8,25 @@ base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py"
 def get_config():
     config = base.get_config()

-    config.mixed_precision = "no"
+    config.pretrained.model = "runwayml/stable-diffusion-v1-5"
+
+    config.mixed_precision = "fp16"
     config.allow_tf32 = True
     config.use_lora = False

     config.train.batch_size = 4
-    config.train.gradient_accumulation_steps = 8
-    config.train.learning_rate = 1e-5
-    config.train.clip_range = 1.0
+    config.train.gradient_accumulation_steps = 2
+    config.train.learning_rate = 3e-5
+    config.train.clip_range = 1e-4

     # sampling
     config.sample.num_steps = 50
-    config.sample.batch_size = 16
-    config.sample.num_batches_per_epoch = 2
+    config.sample.batch_size = 8
+    config.sample.num_batches_per_epoch = 4

-    config.per_prompt_stat_tracking = None
+    config.per_prompt_stat_tracking = {
+        "buffer_size": 16,
+        "min_count": 16,
+    }

     return config
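
For a sense of scale under this config, a quick back-of-the-envelope calculation (the GPU count is an assumption; nothing in this file fixes it):

```python
# Rough bookkeeping for the config above. The GPU count is an assumption.
num_processes = 4

samples_per_epoch = 8 * 4 * num_processes       # sample.batch_size * num_batches_per_epoch * GPUs = 128
effective_train_batch = 4 * 2 * num_processes   # train.batch_size * gradient_accumulation_steps * GPUs = 32

# every sampled trajectory should be consumed by training, so the first number
# should divide evenly by the second
assert samples_per_epoch % effective_train_batch == 0
```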

ddpo_pytorch/aesthetic_scorer.py

@@ -14,18 +14,15 @@ class MLP(nn.Module):
         super().__init__()
         self.layers = nn.Sequential(
             nn.Linear(768, 1024),
-            nn.Identity(),
+            nn.Dropout(0.2),
             nn.Linear(1024, 128),
-            nn.Identity(),
+            nn.Dropout(0.2),
             nn.Linear(128, 64),
-            nn.Identity(),
+            nn.Dropout(0.1),
             nn.Linear(64, 16),
             nn.Linear(16, 1),
         )

-        state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
-        self.load_state_dict(state_dict)
-
     @torch.no_grad()
     def forward(self, embed):
         return self.layers(embed)
@@ -37,6 +34,9 @@ class AestheticScorer(torch.nn.Module):
         self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
         self.mlp = MLP()
+        state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
+        self.mlp.load_state_dict(state_dict)
+
         self.eval()
     @torch.no_grad()
     def __call__(self, images):
@@ -44,5 +44,5 @@ class AestheticScorer(torch.nn.Module):
         inputs = {k: v.cuda() for k, v in inputs.items()}
         embed = self.clip.get_image_features(**inputs)
         # normalize embedding
-        embed = embed / embed.norm(dim=-1, keepdim=True)
+        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
         return self.mlp(embed)
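
The switch from `.norm()` to `torch.linalg.vector_norm` is behavior-preserving here: both compute an L2 norm over the last dimension. A standalone check:

```python
import torch

x = torch.randn(3, 768)
a = x / x.norm(dim=-1, keepdim=True)
b = x / torch.linalg.vector_norm(x, dim=-1, keepdim=True)
assert torch.allclose(a, b)
```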

ddpo_pytorch/rewards.py

@@ -35,8 +35,6 @@ def aesthetic_score():
     scorer = AestheticScorer().cuda()

     def _fn(images, prompts, metadata):
-        if not isinstance(images, torch.Tensor):
-            images = torch.as_tensor(images)
         scores = scorer(images)
         return scores, {}
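
The reward functions in this module follow a factory pattern: the outer function does one-time setup and returns a `_fn(images, prompts, metadata)` callable producing per-image scores plus an info dict. A toy reward in the same shape, for illustration only:

```python
import torch


# Toy reward with the same interface (illustrative only; the image dtype/layout
# convention is an assumption, not taken from this diff).
def dummy_brightness_reward():
    # one-time setup (e.g. building a scorer and moving it to GPU) would go here
    def _fn(images, prompts, metadata):
        images = torch.as_tensor(images)
        scores = images.float().reshape(images.shape[0], -1).mean(dim=1)
        return scores, {}

    return _fn
```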

scripts/train.py

@@ -34,17 +34,23 @@ logger = get_logger(__name__)
 def main(_):
     # basic Accelerate and logging setup
     config = FLAGS.config
+
+    # number of timesteps within each trajectory to train on
+    num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction)
+
     accelerator = Accelerator(
         log_with="wandb",
         mixed_precision=config.mixed_precision,
         project_dir=config.logdir,
-        gradient_accumulation_steps=config.train.gradient_accumulation_steps * config.sample.num_steps,
+        # we always accumulate gradients across timesteps; config.train.gradient_accumulation_steps is the number of
+        # _samples_ to accumulate across
+        gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
     )
     if accelerator.is_main_process:
         accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict())
     logger.info(f"\n{config}")

-    # set seed
+    # set seed (device_specific is very important to get different prompts on different devices)
     set_seed(config.seed, device_specific=True)

     # load scheduler, tokenizer and models.
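
Concretely, with the dgx config above (`sample.num_steps = 50`, the base default `train.timestep_fraction = 1.0`, `train.gradient_accumulation_steps = 2`), the accumulation arithmetic works out as follows (standalone arithmetic, not repo code):

```python
sample_num_steps = 50        # config.sample.num_steps
timestep_fraction = 1.0      # config.train.timestep_fraction (base default)
samples_to_accumulate = 2    # config.train.gradient_accumulation_steps

# timesteps of each trajectory actually trained on
num_train_timesteps = int(sample_num_steps * timestep_fraction)                 # 50

# what gets handed to Accelerator: one optimizer step every
# samples_to_accumulate * num_train_timesteps backward passes
accelerator_accumulation_steps = samples_to_accumulate * num_train_timesteps    # 100
```

Because the training loop below shuffles the time dimension independently per sample and then iterates only `num_train_timesteps` entries, setting `timestep_fraction` below 1.0 trains on a random subset of each trajectory's timesteps while each optimizer step still covers `gradient_accumulation_steps` sample batches.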
@@ -152,7 +158,8 @@ def main(_):
             config.per_prompt_stat_tracking.min_count,
         )

-    # for some reason, autocast is necessary for non-lora training but for lora training it uses more memory
+    # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
+    # more memory
     autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast

     # Prepare everything with our `accelerator`.
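
The `autocast` pattern works because both branches are zero-argument callables that return a context manager, so the loop body can always write `with autocast():`. A standalone illustration, with `torch.autocast` standing in for `accelerator.autocast`:

```python
import contextlib

import torch

use_lora = False
# torch.autocast("cpu") stands in for accelerator.autocast here
autocast = contextlib.nullcontext if use_lora else (lambda: torch.autocast("cpu"))

with autocast():
    y = torch.randn(4, 4) @ torch.randn(4, 4)
```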
@@ -289,8 +296,15 @@ def main(_):
         #################### TRAINING ####################
         for inner_epoch in range(config.train.num_inner_epochs):
             # shuffle samples along batch dimension
-            indices = torch.randperm(total_batch_size, device=accelerator.device)
-            samples = {k: v[indices] for k, v in samples.items()}
+            perm = torch.randperm(total_batch_size, device=accelerator.device)
+            samples = {k: v[perm] for k, v in samples.items()}
+
+            # shuffle along time dimension independently for each sample
+            perms = torch.stack(
+                [torch.randperm(num_timesteps, device=accelerator.device) for _ in range(total_batch_size)]
+            )
+            for key in ["timesteps", "latents", "next_latents", "log_probs"]:
+                samples[key] = samples[key][torch.arange(total_batch_size, device=accelerator.device)[:, None], perms]

             # rebatch for training
             samples_batched = {k: v.reshape(-1, config.train.batch_size, *v.shape[1:]) for k, v in samples.items()}
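
The time shuffle relies on advanced indexing: a column of row indices broadcasts against a matrix of per-row permutations, reordering dimension 1 independently for every sample; using the same `perms` for timesteps, latents, next_latents, and log_probs keeps each (state, action, log-prob) tuple aligned. A self-contained illustration with made-up shapes:

```python
import torch

batch_size, num_timesteps = 4, 6
x = torch.arange(batch_size * num_timesteps).reshape(batch_size, num_timesteps)
perms = torch.stack([torch.randperm(num_timesteps) for _ in range(batch_size)])

# row i of the result is x[i] reordered by perms[i]
shuffled = x[torch.arange(batch_size)[:, None], perms]
assert torch.equal(shuffled.sort(dim=1).values, x)  # same elements per row, new order
```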
@@ -300,6 +314,7 @@ def main(_):

             # train
             pipeline.unet.train()
+            info = defaultdict(list)
             for i, sample in tqdm(
                 list(enumerate(samples_batched)),
                 desc=f"Epoch {epoch}.{inner_epoch}: training",
@@ -312,9 +327,8 @@ def main(_):
                 else:
                     embeds = sample["prompt_embeds"]

-                info = defaultdict(list)
                 for j in tqdm(
-                    range(num_timesteps),
+                    range(num_train_timesteps),
                     desc="Timestep",
                     position=1,
                     leave=False,
@@ -371,14 +385,20 @@ def main(_):
                         optimizer.step()
                         optimizer.zero_grad()

                     # Checks if the accelerator has performed an optimization step behind the scenes
                     if accelerator.sync_gradients:
+                        assert (j == num_train_timesteps - 1) and (i + 1) % config.train.gradient_accumulation_steps == 0
                         # log training-related stuff
                         info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                         info = accelerator.reduce(info, reduction="mean")
                         info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                         accelerator.log(info, step=global_step)
                         global_step += 1
+                        info = defaultdict(list)
+
+            # make sure we did an optimization step at the end of the inner epoch
+            assert accelerator.sync_gradients
+

 if __name__ == "__main__":
     app.run(main)
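
The two assertions encode the accumulation arithmetic: Accelerate syncs gradients every `gradient_accumulation_steps * num_train_timesteps` micro-batches, which lands exactly on the last timestep of every `gradient_accumulation_steps`-th sample batch; the final assertion then holds whenever the number of training batches per inner epoch is a multiple of `gradient_accumulation_steps`. A standalone check of that arithmetic (not repo code):

```python
gradient_accumulation_steps = 2     # samples accumulated per optimizer step
num_train_timesteps = 50
num_sample_batches = 4              # assumed to be a multiple of gradient_accumulation_steps
micro_steps_per_optim_step = gradient_accumulation_steps * num_train_timesteps

micro = 0
for i in range(num_sample_batches):
    for j in range(num_train_timesteps):
        micro += 1
        if micro % micro_steps_per_optim_step == 0:  # when Accelerate syncs gradients
            assert j == num_train_timesteps - 1
            assert (i + 1) % gradient_accumulation_steps == 0
```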