Minor changes, add assets
This commit is contained in:
parent
4c5322ca85
commit
8cab96dea4
3
ddpo_pytorch/assets/activities_v0.txt
Normal file
3
ddpo_pytorch/assets/activities_v0.txt
Normal file
@ -0,0 +1,3 @@
|
||||
washing the dishes
|
||||
riding a bike
|
||||
playing chess
|
45
ddpo_pytorch/assets/common_animals.txt
Normal file
45
ddpo_pytorch/assets/common_animals.txt
Normal file
@ -0,0 +1,45 @@
|
||||
cat
|
||||
dog
|
||||
horse
|
||||
monkey
|
||||
rabbit
|
||||
zebra
|
||||
spider
|
||||
bird
|
||||
sheep
|
||||
deer
|
||||
cow
|
||||
goat
|
||||
lion
|
||||
tiger
|
||||
bear
|
||||
raccoon
|
||||
fox
|
||||
wolf
|
||||
lizard
|
||||
beetle
|
||||
ant
|
||||
butterfly
|
||||
fish
|
||||
shark
|
||||
whale
|
||||
dolphin
|
||||
squirrel
|
||||
mouse
|
||||
rat
|
||||
snake
|
||||
turtle
|
||||
frog
|
||||
chicken
|
||||
duck
|
||||
goose
|
||||
bee
|
||||
pig
|
||||
turkey
|
||||
fly
|
||||
llama
|
||||
camel
|
||||
bat
|
||||
gorilla
|
||||
hedgehog
|
||||
kangaroo
|
@ -1,6 +1,8 @@
|
||||
# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
|
||||
# with the following modifications:
|
||||
# -
|
||||
# - It uses the patched version of `ddim_step_with_logprob` from `ddim_with_logprob.py`. As such, it only supports the
|
||||
# `ddim` scheduler.
|
||||
# - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step.
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
@ -10,7 +12,6 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
||||
StableDiffusionPipeline,
|
||||
rescale_noise_cfg,
|
||||
)
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from .ddim_with_logprob import ddim_step_with_logprob
|
||||
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
from importlib import resources
|
||||
import os
|
||||
import functools
|
||||
import random
|
||||
import inflect
|
||||
@ -8,35 +9,45 @@ ASSETS_PATH = resources.files("ddpo_pytorch.assets")
|
||||
|
||||
|
||||
@functools.cache
|
||||
def load_lines(name):
|
||||
with ASSETS_PATH.joinpath(name).open() as f:
|
||||
def _load_lines(path):
|
||||
"""
|
||||
Load lines from a file. First tries to load from `path` directly, and if that doesn't exist, searches the
|
||||
`ddpo_pytorch/assets` directory for a file named `path`.
|
||||
"""
|
||||
if not os.path.exists(path):
|
||||
newpath = ASSETS_PATH.joinpath(path)
|
||||
if not os.path.exists(newpath):
|
||||
raise FileNotFoundError(f"Could not find {path} or ddpo_pytorch.assets/{path}")
|
||||
path = newpath
|
||||
with open(path, "r") as f:
|
||||
return [line.strip() for line in f.readlines()]
|
||||
|
||||
|
||||
def imagenet(low, high):
|
||||
return random.choice(load_lines("imagenet_classes.txt")[low:high]), {}
|
||||
def from_file(path, low=None, high=None):
|
||||
prompts = _load_lines(path)[low:high]
|
||||
return random.choice(prompts), {}
|
||||
|
||||
|
||||
def imagenet_all():
|
||||
return imagenet(0, 1000)
|
||||
return from_file("imagenet_classes.txt")
|
||||
|
||||
|
||||
def imagenet_animals():
|
||||
return imagenet(0, 398)
|
||||
return from_file("imagenet_classes.txt", 0, 398)
|
||||
|
||||
|
||||
def imagenet_dogs():
|
||||
return imagenet(151, 269)
|
||||
return from_file("imagenet_classes.txt", 151, 269)
|
||||
|
||||
|
||||
def nouns_activities(nouns_file, activities_file):
|
||||
nouns = load_lines(nouns_file)
|
||||
activities = load_lines(activities_file)
|
||||
nouns = _load_lines(nouns_file)
|
||||
activities = _load_lines(activities_file)
|
||||
return f"{IE.a(random.choice(nouns))} {random.choice(activities)}", {}
|
||||
|
||||
|
||||
def counting(nouns_file, low, high):
|
||||
nouns = load_lines(nouns_file)
|
||||
nouns = _load_lines(nouns_file)
|
||||
number = IE.number_to_words(random.randint(low, high))
|
||||
noun = random.choice(nouns)
|
||||
plural_noun = IE.plural(noun)
|
||||
|
@ -1,7 +1,7 @@
|
||||
from collections import defaultdict
|
||||
import contextlib
|
||||
import os
|
||||
from absl import app, flags, logging
|
||||
from absl import app, flags
|
||||
from ml_collections import config_flags
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import set_seed
|
||||
|
Loading…
Reference in New Issue
Block a user