Minor changes, add assets
This commit is contained in:
		
							
								
								
									
										3
									
								
								ddpo_pytorch/assets/activities_v0.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								ddpo_pytorch/assets/activities_v0.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| washing the dishes | ||||
| riding a bike | ||||
| playing chess | ||||
							
								
								
									
										45
									
								
								ddpo_pytorch/assets/common_animals.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								ddpo_pytorch/assets/common_animals.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,45 @@ | ||||
| cat | ||||
| dog | ||||
| horse | ||||
| monkey | ||||
| rabbit | ||||
| zebra | ||||
| spider | ||||
| bird | ||||
| sheep | ||||
| deer | ||||
| cow | ||||
| goat | ||||
| lion | ||||
| tiger | ||||
| bear | ||||
| raccoon | ||||
| fox | ||||
| wolf | ||||
| lizard | ||||
| beetle | ||||
| ant | ||||
| butterfly | ||||
| fish | ||||
| shark | ||||
| whale | ||||
| dolphin | ||||
| squirrel | ||||
| mouse | ||||
| rat | ||||
| snake | ||||
| turtle | ||||
| frog | ||||
| chicken | ||||
| duck | ||||
| goose | ||||
| bee | ||||
| pig | ||||
| turkey | ||||
| fly | ||||
| llama | ||||
| camel | ||||
| bat | ||||
| gorilla | ||||
| hedgehog | ||||
| kangaroo | ||||
| @@ -1,6 +1,8 @@ | ||||
| # Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py | ||||
| # with the following modifications: | ||||
| # - | ||||
| # - It uses the patched version of `ddim_step_with_logprob` from `ddim_with_logprob.py`. As such, it only supports the | ||||
| #   `ddim` scheduler. | ||||
| # - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step. | ||||
|  | ||||
| from typing import Any, Callable, Dict, List, Optional, Union | ||||
|  | ||||
| @@ -10,7 +12,6 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( | ||||
|     StableDiffusionPipeline, | ||||
|     rescale_noise_cfg, | ||||
| ) | ||||
| from diffusers.schedulers.scheduling_ddim import DDIMScheduler | ||||
| from .ddim_with_logprob import ddim_step_with_logprob | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| from importlib import resources | ||||
| import os | ||||
| import functools | ||||
| import random | ||||
| import inflect | ||||
| @@ -8,35 +9,45 @@ ASSETS_PATH = resources.files("ddpo_pytorch.assets") | ||||
|  | ||||
|  | ||||
| @functools.cache | ||||
| def load_lines(name): | ||||
|     with ASSETS_PATH.joinpath(name).open() as f: | ||||
| def _load_lines(path): | ||||
|     """ | ||||
|     Load lines from a file. First tries to load from `path` directly, and if that doesn't exist, searches the | ||||
|     `ddpo_pytorch/assets` directory for a file named `path`. | ||||
|     """ | ||||
|     if not os.path.exists(path): | ||||
|         newpath = ASSETS_PATH.joinpath(path) | ||||
|     if not os.path.exists(newpath): | ||||
|         raise FileNotFoundError(f"Could not find {path} or ddpo_pytorch.assets/{path}") | ||||
|     path = newpath | ||||
|     with open(path, "r") as f: | ||||
|         return [line.strip() for line in f.readlines()] | ||||
|  | ||||
|  | ||||
| def imagenet(low, high): | ||||
|     return random.choice(load_lines("imagenet_classes.txt")[low:high]), {} | ||||
| def from_file(path, low=None, high=None): | ||||
|     prompts = _load_lines(path)[low:high] | ||||
|     return random.choice(prompts), {} | ||||
|  | ||||
|  | ||||
| def imagenet_all(): | ||||
|     return imagenet(0, 1000) | ||||
|     return from_file("imagenet_classes.txt") | ||||
|  | ||||
|  | ||||
| def imagenet_animals(): | ||||
|     return imagenet(0, 398) | ||||
|     return from_file("imagenet_classes.txt", 0, 398) | ||||
|  | ||||
|  | ||||
| def imagenet_dogs(): | ||||
|     return imagenet(151, 269) | ||||
|     return from_file("imagenet_classes.txt", 151, 269) | ||||
|  | ||||
|  | ||||
| def nouns_activities(nouns_file, activities_file): | ||||
|     nouns = load_lines(nouns_file) | ||||
|     activities = load_lines(activities_file) | ||||
|     nouns = _load_lines(nouns_file) | ||||
|     activities = _load_lines(activities_file) | ||||
|     return f"{IE.a(random.choice(nouns))} {random.choice(activities)}", {} | ||||
|  | ||||
|  | ||||
| def counting(nouns_file, low, high): | ||||
|     nouns = load_lines(nouns_file) | ||||
|     nouns = _load_lines(nouns_file) | ||||
|     number = IE.number_to_words(random.randint(low, high)) | ||||
|     noun = random.choice(nouns) | ||||
|     plural_noun = IE.plural(noun) | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| from collections import defaultdict | ||||
| import contextlib | ||||
| import os | ||||
| from absl import app, flags, logging | ||||
| from absl import app, flags | ||||
| from ml_collections import config_flags | ||||
| from accelerate import Accelerator | ||||
| from accelerate.utils import set_seed | ||||
|   | ||||
		Reference in New Issue
	
	Block a user