Minor changes, add assets

This commit is contained in:
Kevin Black 2023-06-27 10:20:03 -07:00
parent 4c5322ca85
commit 8cab96dea4
5 changed files with 73 additions and 13 deletions

View File

@ -0,0 +1,3 @@
washing the dishes
riding a bike
playing chess

View File

@ -0,0 +1,45 @@
cat
dog
horse
monkey
rabbit
zebra
spider
bird
sheep
deer
cow
goat
lion
tiger
bear
raccoon
fox
wolf
lizard
beetle
ant
butterfly
fish
shark
whale
dolphin
squirrel
mouse
rat
snake
turtle
frog
chicken
duck
goose
bee
pig
turkey
fly
llama
camel
bat
gorilla
hedgehog
kangaroo

View File

@ -1,6 +1,8 @@
# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
# with the following modifications:
# -
# - It uses the patched version of `ddim_step_with_logprob` from `ddim_with_logprob.py`. As such, it only supports the
# `ddim` scheduler.
# - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step.
from typing import Any, Callable, Dict, List, Optional, Union
@ -10,7 +12,6 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline,
rescale_noise_cfg,
)
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from .ddim_with_logprob import ddim_step_with_logprob

View File

@ -1,4 +1,5 @@
from importlib import resources
import os
import functools
import random
import inflect
@ -8,35 +9,45 @@ ASSETS_PATH = resources.files("ddpo_pytorch.assets")
@functools.cache
def load_lines(name):
with ASSETS_PATH.joinpath(name).open() as f:
def _load_lines(path):
"""
Load lines from a file. First tries to load from `path` directly, and if that doesn't exist, searches the
`ddpo_pytorch/assets` directory for a file named `path`.
"""
if not os.path.exists(path):
newpath = ASSETS_PATH.joinpath(path)
if not os.path.exists(newpath):
raise FileNotFoundError(f"Could not find {path} or ddpo_pytorch.assets/{path}")
path = newpath
with open(path, "r") as f:
return [line.strip() for line in f.readlines()]
def imagenet(low, high):
return random.choice(load_lines("imagenet_classes.txt")[low:high]), {}
def from_file(path, low=None, high=None):
prompts = _load_lines(path)[low:high]
return random.choice(prompts), {}
def imagenet_all():
return imagenet(0, 1000)
return from_file("imagenet_classes.txt")
def imagenet_animals():
return imagenet(0, 398)
return from_file("imagenet_classes.txt", 0, 398)
def imagenet_dogs():
return imagenet(151, 269)
return from_file("imagenet_classes.txt", 151, 269)
def nouns_activities(nouns_file, activities_file):
nouns = load_lines(nouns_file)
activities = load_lines(activities_file)
nouns = _load_lines(nouns_file)
activities = _load_lines(activities_file)
return f"{IE.a(random.choice(nouns))} {random.choice(activities)}", {}
def counting(nouns_file, low, high):
nouns = load_lines(nouns_file)
nouns = _load_lines(nouns_file)
number = IE.number_to_words(random.randint(low, high))
noun = random.choice(nouns)
plural_noun = IE.plural(noun)

View File

@ -1,7 +1,7 @@
from collections import defaultdict
import contextlib
import os
from absl import app, flags, logging
from absl import app, flags
from ml_collections import config_flags
from accelerate import Accelerator
from accelerate.utils import set_seed