Minor changes; add train_timestep_fraction
This commit is contained in:
parent
bae3f43f5f
commit
28d2d8c40e
@ -32,14 +32,15 @@ def get_config():
|
||||
train.cfg = True
|
||||
train.adv_clip_max = 10
|
||||
train.clip_range = 1e-4
|
||||
train.timestep_fraction = 1.0
|
||||
|
||||
# sampling
|
||||
config.sample = sample = ml_collections.ConfigDict()
|
||||
sample.num_steps = 5
|
||||
sample.num_steps = 10
|
||||
sample.eta = 1.0
|
||||
sample.guidance_scale = 5.0
|
||||
sample.batch_size = 1
|
||||
sample.num_batches_per_epoch = 1
|
||||
sample.num_batches_per_epoch = 2
|
||||
|
||||
# prompting
|
||||
config.prompt_fn = "imagenet_animals"
|
||||
@ -49,7 +50,7 @@ def get_config():
|
||||
config.reward_fn = "jpeg_compressibility"
|
||||
|
||||
config.per_prompt_stat_tracking = ml_collections.ConfigDict()
|
||||
config.per_prompt_stat_tracking.buffer_size = 64
|
||||
config.per_prompt_stat_tracking.buffer_size = 16
|
||||
config.per_prompt_stat_tracking.min_count = 16
|
||||
|
||||
return config
|
@ -8,20 +8,25 @@ base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py"
|
||||
def get_config():
|
||||
config = base.get_config()
|
||||
|
||||
config.mixed_precision = "no"
|
||||
config.pretrained.model = "runwayml/stable-diffusion-v1-5"
|
||||
|
||||
config.mixed_precision = "fp16"
|
||||
config.allow_tf32 = True
|
||||
config.use_lora = False
|
||||
|
||||
config.train.batch_size = 4
|
||||
config.train.gradient_accumulation_steps = 8
|
||||
config.train.learning_rate = 1e-5
|
||||
config.train.clip_range = 1.0
|
||||
config.train.gradient_accumulation_steps = 2
|
||||
config.train.learning_rate = 3e-5
|
||||
config.train.clip_range = 1e-4
|
||||
|
||||
# sampling
|
||||
config.sample.num_steps = 50
|
||||
config.sample.batch_size = 16
|
||||
config.sample.num_batches_per_epoch = 2
|
||||
config.sample.batch_size = 8
|
||||
config.sample.num_batches_per_epoch = 4
|
||||
|
||||
config.per_prompt_stat_tracking = None
|
||||
config.per_prompt_stat_tracking = {
|
||||
"buffer_size": 16,
|
||||
"min_count": 16,
|
||||
}
|
||||
|
||||
return config
|
||||
|
@ -14,18 +14,15 @@ class MLP(nn.Module):
|
||||
super().__init__()
|
||||
self.layers = nn.Sequential(
|
||||
nn.Linear(768, 1024),
|
||||
nn.Identity(),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(1024, 128),
|
||||
nn.Identity(),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(128, 64),
|
||||
nn.Identity(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(64, 16),
|
||||
nn.Linear(16, 1),
|
||||
)
|
||||
|
||||
state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
|
||||
self.load_state_dict(state_dict)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, embed):
|
||||
return self.layers(embed)
|
||||
@ -37,6 +34,9 @@ class AestheticScorer(torch.nn.Module):
|
||||
self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
|
||||
self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
|
||||
self.mlp = MLP()
|
||||
state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
|
||||
self.mlp.load_state_dict(state_dict)
|
||||
self.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, images):
|
||||
@ -44,5 +44,5 @@ class AestheticScorer(torch.nn.Module):
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
embed = self.clip.get_image_features(**inputs)
|
||||
# normalize embedding
|
||||
embed = embed / embed.norm(dim=-1, keepdim=True)
|
||||
embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
|
||||
return self.mlp(embed)
|
||||
|
@ -35,8 +35,6 @@ def aesthetic_score():
|
||||
scorer = AestheticScorer().cuda()
|
||||
|
||||
def _fn(images, prompts, metadata):
|
||||
if not isinstance(images, torch.Tensor):
|
||||
images = torch.as_tensor(images)
|
||||
scores = scorer(images)
|
||||
return scores, {}
|
||||
|
||||
|
@ -34,17 +34,23 @@ logger = get_logger(__name__)
|
||||
def main(_):
|
||||
# basic Accelerate and logging setup
|
||||
config = FLAGS.config
|
||||
|
||||
# number of timesteps within each trajectory to train on
|
||||
num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction)
|
||||
|
||||
accelerator = Accelerator(
|
||||
log_with="wandb",
|
||||
mixed_precision=config.mixed_precision,
|
||||
project_dir=config.logdir,
|
||||
gradient_accumulation_steps=config.train.gradient_accumulation_steps * config.sample.num_steps,
|
||||
# we always accumulate gradients across timesteps; config.train.gradient_accumulation_steps is the number of
|
||||
# _samples_ to accumulate across
|
||||
gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
|
||||
)
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict())
|
||||
logger.info(f"\n{config}")
|
||||
|
||||
# set seed
|
||||
# set seed (device_specific is very important to get different prompts on different devices)
|
||||
set_seed(config.seed, device_specific=True)
|
||||
|
||||
# load scheduler, tokenizer and models.
|
||||
@ -152,7 +158,8 @@ def main(_):
|
||||
config.per_prompt_stat_tracking.min_count,
|
||||
)
|
||||
|
||||
# for some reason, autocast is necessary for non-lora training but for lora training it uses more memory
|
||||
# for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
|
||||
# more memory
|
||||
autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
@ -289,8 +296,15 @@ def main(_):
|
||||
#################### TRAINING ####################
|
||||
for inner_epoch in range(config.train.num_inner_epochs):
|
||||
# shuffle samples along batch dimension
|
||||
indices = torch.randperm(total_batch_size, device=accelerator.device)
|
||||
samples = {k: v[indices] for k, v in samples.items()}
|
||||
perm = torch.randperm(total_batch_size, device=accelerator.device)
|
||||
samples = {k: v[perm] for k, v in samples.items()}
|
||||
|
||||
# shuffle along time dimension independently for each sample
|
||||
perms = torch.stack(
|
||||
[torch.randperm(num_timesteps, device=accelerator.device) for _ in range(total_batch_size)]
|
||||
)
|
||||
for key in ["timesteps", "latents", "next_latents", "log_probs"]:
|
||||
samples[key] = samples[key][torch.arange(total_batch_size, device=accelerator.device)[:, None], perms]
|
||||
|
||||
# rebatch for training
|
||||
samples_batched = {k: v.reshape(-1, config.train.batch_size, *v.shape[1:]) for k, v in samples.items()}
|
||||
@ -300,6 +314,7 @@ def main(_):
|
||||
|
||||
# train
|
||||
pipeline.unet.train()
|
||||
info = defaultdict(list)
|
||||
for i, sample in tqdm(
|
||||
list(enumerate(samples_batched)),
|
||||
desc=f"Epoch {epoch}.{inner_epoch}: training",
|
||||
@ -312,9 +327,8 @@ def main(_):
|
||||
else:
|
||||
embeds = sample["prompt_embeds"]
|
||||
|
||||
info = defaultdict(list)
|
||||
for j in tqdm(
|
||||
range(num_timesteps),
|
||||
range(num_train_timesteps),
|
||||
desc="Timestep",
|
||||
position=1,
|
||||
leave=False,
|
||||
@ -371,14 +385,20 @@ def main(_):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
assert (j == num_train_timesteps - 1) and (i + 1) % config.train.gradient_accumulation_steps == 0
|
||||
# log training-related stuff
|
||||
info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
|
||||
info = accelerator.reduce(info, reduction="mean")
|
||||
info.update({"epoch": epoch, "inner_epoch": inner_epoch})
|
||||
accelerator.log(info, step=global_step)
|
||||
global_step += 1
|
||||
info = defaultdict(list)
|
||||
|
||||
# make sure we did an optimization step at the end of the inner epoch
|
||||
assert accelerator.sync_gradients
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
|
Loading…
Reference in New Issue
Block a user