# ddpo-pytorch/config/dgx.py

import ml_collections
import imp
import os
base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py"))
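# base.py provides the shared default config; each function below starts from
# base.get_config() and overrides only the fields that differ for that experiment.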


def compressibility():
    config = base.get_config()

    config.pretrained.model = "CompVis/stable-diffusion-v1-4"

    config.num_epochs = 100
    config.use_lora = True
    config.save_freq = 1
    config.num_checkpoint_limit = 100000000

    # the DGX machine I used had 8 GPUs, so this corresponds to 8 * 8 * 4 = 256 samples per epoch.
    config.sample.batch_size = 8
    config.sample.num_batches_per_epoch = 4
    # this corresponds to (8 * 4) / (4 * 2) = 4 gradient updates per epoch.
    config.train.batch_size = 4
    config.train.gradient_accumulation_steps = 2
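    # in general: samples per epoch = num_gpus * sample.batch_size * sample.num_batches_per_epoch,
    # and gradient updates per epoch (per GPU) =
    #     (sample.batch_size * sample.num_batches_per_epoch) / (train.batch_size * train.gradient_accumulation_steps).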

    # prompting
    config.prompt_fn = "imagenet_animals"
    config.prompt_fn_kwargs = {}

    # rewards
    config.reward_fn = "jpeg_compressibility"
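
    # per-prompt stat tracking keeps a rolling buffer of the last `buffer_size` rewards for each
    # prompt and normalizes advantages with per-prompt mean/std once at least `min_count` samples
    # have been collected for that prompt; see base.py for the full description of these fields.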
    config.per_prompt_stat_tracking = {
        "buffer_size": 16,
        "min_count": 16,
    }

    return config


def incompressibility():
    config = compressibility()
    config.reward_fn = "jpeg_incompressibility"
    return config


def aesthetic():
    config = compressibility()
    config.num_epochs = 200
    config.reward_fn = "aesthetic_score"
    # this reward is a bit harder to optimize, so I used (8 * 4) / (4 * 4) = 2 gradient updates per epoch.
    config.train.gradient_accumulation_steps = 4
    config.prompt_fn = "simple_animals"
    config.per_prompt_stat_tracking = {
        "buffer_size": 32,
        "min_count": 16,
    }
    return config


def prompt_image_alignment():
    config = compressibility()

    config.num_epochs = 200

    # for this experiment, I reserved 2 GPUs for LLaVA inference so only 6 could be used for DDPO. the total number of
    # samples per epoch is 8 * 6 * 6 = 288.
    config.sample.batch_size = 8
    config.sample.num_batches_per_epoch = 6

    # again, this one is harder to optimize, so I used (8 * 6) / (4 * 6) = 2 gradient updates per epoch.
    config.train.batch_size = 4
    config.train.gradient_accumulation_steps = 6

    # prompting
    config.prompt_fn = "nouns_activities"
    config.prompt_fn_kwargs = {
        "nouns_file": "simple_animals.txt",
        "activities_file": "activities.txt",
    }

    # rewards
    config.reward_fn = "llava_bertscore"

    config.per_prompt_stat_tracking = {
        "buffer_size": 32,
        "min_count": 16,
    }
    return config


def get_config(name):
    return globals()[name]()
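

# get_config(name) is the hook that ml_collections' config_flags uses to pick an experiment by name.
# assuming the training entry point is scripts/train.py (as in the repo README), a run would look like:
#   accelerate launch scripts/train.py --config config/dgx.py:aesthetic
# the text after the ":" is passed to get_config(), which dispatches to one of the functions above.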