From 8cab96dea4bc885122d0f2edf25b64239d67e514 Mon Sep 17 00:00:00 2001 From: Kevin Black Date: Tue, 27 Jun 2023 10:20:03 -0700 Subject: [PATCH] Minor changes, add assets --- ddpo_pytorch/assets/activities_v0.txt | 3 ++ ddpo_pytorch/assets/common_animals.txt | 45 +++++++++++++++++++ .../diffusers_patch/pipeline_with_logprob.py | 5 ++- ddpo_pytorch/prompts.py | 31 ++++++++----- scripts/train.py | 2 +- 5 files changed, 73 insertions(+), 13 deletions(-) create mode 100644 ddpo_pytorch/assets/activities_v0.txt create mode 100644 ddpo_pytorch/assets/common_animals.txt diff --git a/ddpo_pytorch/assets/activities_v0.txt b/ddpo_pytorch/assets/activities_v0.txt new file mode 100644 index 0000000..abea045 --- /dev/null +++ b/ddpo_pytorch/assets/activities_v0.txt @@ -0,0 +1,3 @@ +washing the dishes +riding a bike +playing chess \ No newline at end of file diff --git a/ddpo_pytorch/assets/common_animals.txt b/ddpo_pytorch/assets/common_animals.txt new file mode 100644 index 0000000..bc9e117 --- /dev/null +++ b/ddpo_pytorch/assets/common_animals.txt @@ -0,0 +1,45 @@ +cat +dog +horse +monkey +rabbit +zebra +spider +bird +sheep +deer +cow +goat +lion +tiger +bear +raccoon +fox +wolf +lizard +beetle +ant +butterfly +fish +shark +whale +dolphin +squirrel +mouse +rat +snake +turtle +frog +chicken +duck +goose +bee +pig +turkey +fly +llama +camel +bat +gorilla +hedgehog +kangaroo diff --git a/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py b/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py index 09378c2..67726bf 100644 --- a/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py +++ b/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py @@ -1,6 +1,8 @@ # Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py # with the following modifications: -# - +# - It uses the patched version of `ddim_step_with_logprob` from `ddim_with_logprob.py`. As such, it only supports the +# `ddim` scheduler. +# - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step. from typing import Any, Callable, Dict, List, Optional, Union @@ -10,7 +12,6 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( StableDiffusionPipeline, rescale_noise_cfg, ) -from diffusers.schedulers.scheduling_ddim import DDIMScheduler from .ddim_with_logprob import ddim_step_with_logprob diff --git a/ddpo_pytorch/prompts.py b/ddpo_pytorch/prompts.py index 8cecf28..fdf6e34 100644 --- a/ddpo_pytorch/prompts.py +++ b/ddpo_pytorch/prompts.py @@ -1,4 +1,5 @@ from importlib import resources +import os import functools import random import inflect @@ -8,35 +9,45 @@ ASSETS_PATH = resources.files("ddpo_pytorch.assets") @functools.cache -def load_lines(name): - with ASSETS_PATH.joinpath(name).open() as f: +def _load_lines(path): + """ + Load lines from a file. First tries to load from `path` directly, and if that doesn't exist, searches the + `ddpo_pytorch/assets` directory for a file named `path`. + """ + if not os.path.exists(path): + newpath = ASSETS_PATH.joinpath(path) + if not os.path.exists(newpath): + raise FileNotFoundError(f"Could not find {path} or ddpo_pytorch.assets/{path}") + path = newpath + with open(path, "r") as f: return [line.strip() for line in f.readlines()] -def imagenet(low, high): - return random.choice(load_lines("imagenet_classes.txt")[low:high]), {} +def from_file(path, low=None, high=None): + prompts = _load_lines(path)[low:high] + return random.choice(prompts), {} def imagenet_all(): - return imagenet(0, 1000) + return from_file("imagenet_classes.txt") def imagenet_animals(): - return imagenet(0, 398) + return from_file("imagenet_classes.txt", 0, 398) def imagenet_dogs(): - return imagenet(151, 269) + return from_file("imagenet_classes.txt", 151, 269) def nouns_activities(nouns_file, activities_file): - nouns = load_lines(nouns_file) - activities = load_lines(activities_file) + nouns = _load_lines(nouns_file) + activities = _load_lines(activities_file) return f"{IE.a(random.choice(nouns))} {random.choice(activities)}", {} def counting(nouns_file, low, high): - nouns = load_lines(nouns_file) + nouns = _load_lines(nouns_file) number = IE.number_to_words(random.randint(low, high)) noun = random.choice(nouns) plural_noun = IE.plural(noun) diff --git a/scripts/train.py b/scripts/train.py index 4b58f74..78df021 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,7 +1,7 @@ from collections import defaultdict import contextlib import os -from absl import app, flags, logging +from absl import app, flags from ml_collections import config_flags from accelerate import Accelerator from accelerate.utils import set_seed