Minor changes, add assets
This commit is contained in:
parent
4c5322ca85
commit
8cab96dea4
3
ddpo_pytorch/assets/activities_v0.txt
Normal file
3
ddpo_pytorch/assets/activities_v0.txt
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
washing the dishes
|
||||||
|
riding a bike
|
||||||
|
playing chess
|
45
ddpo_pytorch/assets/common_animals.txt
Normal file
45
ddpo_pytorch/assets/common_animals.txt
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
cat
|
||||||
|
dog
|
||||||
|
horse
|
||||||
|
monkey
|
||||||
|
rabbit
|
||||||
|
zebra
|
||||||
|
spider
|
||||||
|
bird
|
||||||
|
sheep
|
||||||
|
deer
|
||||||
|
cow
|
||||||
|
goat
|
||||||
|
lion
|
||||||
|
tiger
|
||||||
|
bear
|
||||||
|
raccoon
|
||||||
|
fox
|
||||||
|
wolf
|
||||||
|
lizard
|
||||||
|
beetle
|
||||||
|
ant
|
||||||
|
butterfly
|
||||||
|
fish
|
||||||
|
shark
|
||||||
|
whale
|
||||||
|
dolphin
|
||||||
|
squirrel
|
||||||
|
mouse
|
||||||
|
rat
|
||||||
|
snake
|
||||||
|
turtle
|
||||||
|
frog
|
||||||
|
chicken
|
||||||
|
duck
|
||||||
|
goose
|
||||||
|
bee
|
||||||
|
pig
|
||||||
|
turkey
|
||||||
|
fly
|
||||||
|
llama
|
||||||
|
camel
|
||||||
|
bat
|
||||||
|
gorilla
|
||||||
|
hedgehog
|
||||||
|
kangaroo
|
@ -1,6 +1,8 @@
|
|||||||
# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
|
# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
|
||||||
# with the following modifications:
|
# with the following modifications:
|
||||||
# -
|
# - It uses the patched version of `ddim_step_with_logprob` from `ddim_with_logprob.py`. As such, it only supports the
|
||||||
|
# `ddim` scheduler.
|
||||||
|
# - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step.
|
||||||
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
@ -10,7 +12,6 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
|||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
rescale_noise_cfg,
|
rescale_noise_cfg,
|
||||||
)
|
)
|
||||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
|
||||||
from .ddim_with_logprob import ddim_step_with_logprob
|
from .ddim_with_logprob import ddim_step_with_logprob
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from importlib import resources
|
from importlib import resources
|
||||||
|
import os
|
||||||
import functools
|
import functools
|
||||||
import random
|
import random
|
||||||
import inflect
|
import inflect
|
||||||
@ -8,35 +9,45 @@ ASSETS_PATH = resources.files("ddpo_pytorch.assets")
|
|||||||
|
|
||||||
|
|
||||||
@functools.cache
|
@functools.cache
|
||||||
def load_lines(name):
|
def _load_lines(path):
|
||||||
with ASSETS_PATH.joinpath(name).open() as f:
|
"""
|
||||||
|
Load lines from a file. First tries to load from `path` directly, and if that doesn't exist, searches the
|
||||||
|
`ddpo_pytorch/assets` directory for a file named `path`.
|
||||||
|
"""
|
||||||
|
if not os.path.exists(path):
|
||||||
|
newpath = ASSETS_PATH.joinpath(path)
|
||||||
|
if not os.path.exists(newpath):
|
||||||
|
raise FileNotFoundError(f"Could not find {path} or ddpo_pytorch.assets/{path}")
|
||||||
|
path = newpath
|
||||||
|
with open(path, "r") as f:
|
||||||
return [line.strip() for line in f.readlines()]
|
return [line.strip() for line in f.readlines()]
|
||||||
|
|
||||||
|
|
||||||
def imagenet(low, high):
|
def from_file(path, low=None, high=None):
|
||||||
return random.choice(load_lines("imagenet_classes.txt")[low:high]), {}
|
prompts = _load_lines(path)[low:high]
|
||||||
|
return random.choice(prompts), {}
|
||||||
|
|
||||||
|
|
||||||
def imagenet_all():
|
def imagenet_all():
|
||||||
return imagenet(0, 1000)
|
return from_file("imagenet_classes.txt")
|
||||||
|
|
||||||
|
|
||||||
def imagenet_animals():
|
def imagenet_animals():
|
||||||
return imagenet(0, 398)
|
return from_file("imagenet_classes.txt", 0, 398)
|
||||||
|
|
||||||
|
|
||||||
def imagenet_dogs():
|
def imagenet_dogs():
|
||||||
return imagenet(151, 269)
|
return from_file("imagenet_classes.txt", 151, 269)
|
||||||
|
|
||||||
|
|
||||||
def nouns_activities(nouns_file, activities_file):
|
def nouns_activities(nouns_file, activities_file):
|
||||||
nouns = load_lines(nouns_file)
|
nouns = _load_lines(nouns_file)
|
||||||
activities = load_lines(activities_file)
|
activities = _load_lines(activities_file)
|
||||||
return f"{IE.a(random.choice(nouns))} {random.choice(activities)}", {}
|
return f"{IE.a(random.choice(nouns))} {random.choice(activities)}", {}
|
||||||
|
|
||||||
|
|
||||||
def counting(nouns_file, low, high):
|
def counting(nouns_file, low, high):
|
||||||
nouns = load_lines(nouns_file)
|
nouns = _load_lines(nouns_file)
|
||||||
number = IE.number_to_words(random.randint(low, high))
|
number = IE.number_to_words(random.randint(low, high))
|
||||||
noun = random.choice(nouns)
|
noun = random.choice(nouns)
|
||||||
plural_noun = IE.plural(noun)
|
plural_noun = IE.plural(noun)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import contextlib
|
import contextlib
|
||||||
import os
|
import os
|
||||||
from absl import app, flags, logging
|
from absl import app, flags
|
||||||
from ml_collections import config_flags
|
from ml_collections import config_flags
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
|
Loading…
Reference in New Issue
Block a user