Minor changes, add assets

This commit is contained in:
Kevin Black 2023-06-27 10:20:03 -07:00
parent 4c5322ca85
commit 8cab96dea4
5 changed files with 73 additions and 13 deletions

View File

@ -0,0 +1,3 @@
washing the dishes
riding a bike
playing chess

View File

@ -0,0 +1,45 @@
cat
dog
horse
monkey
rabbit
zebra
spider
bird
sheep
deer
cow
goat
lion
tiger
bear
raccoon
fox
wolf
lizard
beetle
ant
butterfly
fish
shark
whale
dolphin
squirrel
mouse
rat
snake
turtle
frog
chicken
duck
goose
bee
pig
turkey
fly
llama
camel
bat
gorilla
hedgehog
kangaroo

View File

@ -1,6 +1,8 @@
# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py # Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
# with the following modifications: # with the following modifications:
# - # - It uses the patched version of `ddim_step_with_logprob` from `ddim_with_logprob.py`. As such, it only supports the
# `ddim` scheduler.
# - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step.
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
@ -10,7 +12,6 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline, StableDiffusionPipeline,
rescale_noise_cfg, rescale_noise_cfg,
) )
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from .ddim_with_logprob import ddim_step_with_logprob from .ddim_with_logprob import ddim_step_with_logprob

View File

@ -1,4 +1,5 @@
from importlib import resources from importlib import resources
import os
import functools import functools
import random import random
import inflect import inflect
@ -8,35 +9,45 @@ ASSETS_PATH = resources.files("ddpo_pytorch.assets")
@functools.cache @functools.cache
def load_lines(name): def _load_lines(path):
with ASSETS_PATH.joinpath(name).open() as f: """
Load lines from a file. First tries to load from `path` directly, and if that doesn't exist, searches the
`ddpo_pytorch/assets` directory for a file named `path`.
"""
if not os.path.exists(path):
newpath = ASSETS_PATH.joinpath(path)
if not os.path.exists(newpath):
raise FileNotFoundError(f"Could not find {path} or ddpo_pytorch.assets/{path}")
path = newpath
with open(path, "r") as f:
return [line.strip() for line in f.readlines()] return [line.strip() for line in f.readlines()]
def imagenet(low, high): def from_file(path, low=None, high=None):
return random.choice(load_lines("imagenet_classes.txt")[low:high]), {} prompts = _load_lines(path)[low:high]
return random.choice(prompts), {}
def imagenet_all(): def imagenet_all():
return imagenet(0, 1000) return from_file("imagenet_classes.txt")
def imagenet_animals(): def imagenet_animals():
return imagenet(0, 398) return from_file("imagenet_classes.txt", 0, 398)
def imagenet_dogs(): def imagenet_dogs():
return imagenet(151, 269) return from_file("imagenet_classes.txt", 151, 269)
def nouns_activities(nouns_file, activities_file): def nouns_activities(nouns_file, activities_file):
nouns = load_lines(nouns_file) nouns = _load_lines(nouns_file)
activities = load_lines(activities_file) activities = _load_lines(activities_file)
return f"{IE.a(random.choice(nouns))} {random.choice(activities)}", {} return f"{IE.a(random.choice(nouns))} {random.choice(activities)}", {}
def counting(nouns_file, low, high): def counting(nouns_file, low, high):
nouns = load_lines(nouns_file) nouns = _load_lines(nouns_file)
number = IE.number_to_words(random.randint(low, high)) number = IE.number_to_words(random.randint(low, high))
noun = random.choice(nouns) noun = random.choice(nouns)
plural_noun = IE.plural(noun) plural_noun = IE.plural(noun)

View File

@ -1,7 +1,7 @@
from collections import defaultdict from collections import defaultdict
import contextlib import contextlib
import os import os
from absl import app, flags, logging from absl import app, flags
from ml_collections import config_flags from ml_collections import config_flags
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.utils import set_seed from accelerate.utils import set_seed