Add aesthetic scorer reward function
ddpo_pytorch/aesthetic_scorer.py (new file, 48 lines added)
@@ -0,0 +1,48 @@
+# Based on https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/fe88a163f4661b4ddabba0751ff645e2e620746e/simple_inference.py
+
+from importlib import resources
+import torch
+import torch.nn as nn
+import numpy as np
+from transformers import CLIPModel, CLIPProcessor
+
+ASSETS_PATH = resources.files("ddpo_pytorch.assets")
+
+
+class MLP(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layers = nn.Sequential(
+            nn.Linear(768, 1024),
+            nn.Identity(),
+            nn.Linear(1024, 128),
+            nn.Identity(),
+            nn.Linear(128, 64),
+            nn.Identity(),
+            nn.Linear(64, 16),
+            nn.Linear(16, 1),
+        )
+
+        state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
+        self.load_state_dict(state_dict)
+
+    @torch.no_grad()
+    def forward(self, embed):
+        return self.layers(embed)
+
+
+class AestheticScorer(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+        self.mlp = MLP()
+
+    @torch.no_grad()
+    def __call__(self, images):
+        inputs = self.processor(images=images, return_tensors="pt")
+        inputs = {k: v.cuda() for k, v in inputs.items()}
+        embed = self.clip.get_image_features(**inputs)
+        # normalize embedding
+        embed = embed / embed.norm(dim=-1, keepdim=True)
+        return self.mlp(embed)
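For context (not part of this commit): a minimal usage sketch of the new scorer. It assumes a CUDA device is available, since __call__ moves the processed inputs to the GPU, and relies on CLIPProcessor accepting PIL images or uint8 HWC arrays directly.

import numpy as np
from ddpo_pytorch.aesthetic_scorer import AestheticScorer

scorer = AestheticScorer().cuda()

# A batch of four random uint8 RGB images; real use would pass decoded samples or PIL images.
images = [np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8) for _ in range(4)]
scores = scorer(images)  # shape (4, 1); higher values mean "more aesthetic" per the LAION predictor
print(scores.squeeze(-1))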
ddpo_pytorch/assets/sac+logos+ava1-l14-linearMSE.pth (new binary file, not shown)
@@ -40,6 +40,10 @@ def imagenet_dogs():
     return from_file("imagenet_classes.txt", 151, 269)
 
 
+def simple_animals():
+    return from_file("simple_animals.txt")
+
+
 def nouns_activities(nouns_file, activities_file):
     nouns = _load_lines(nouns_file)
     activities = _load_lines(activities_file)
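The from_file helper used above is not part of this diff. For orientation only, a hypothetical sketch of what it presumably does, inferred from its call sites (a filename plus an optional index range): load one prompt per line from the packaged assets, slice the list, and sample one prompt at random.

import random
from importlib import resources

def from_file(path, low=None, high=None):
    # Hypothetical sketch -- the real helper is not shown in this commit.
    lines = resources.files("ddpo_pytorch.assets").joinpath(path).read_text().splitlines()
    prompts = [line.strip() for line in lines if line.strip()][low:high]
    # Return a randomly chosen prompt plus a metadata dict, mirroring the (value, meta) pattern of the reward functions.
    return random.choice(prompts), {}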
@@ -8,7 +8,7 @@ def jpeg_incompressibility():
     def _fn(images, prompts, metadata):
         if isinstance(images, torch.Tensor):
             images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
-            images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
         images = [Image.fromarray(image) for image in images]
         buffers = [io.BytesIO() for _ in images]
         for image, buffer in zip(images, buffers):
@@ -27,3 +27,17 @@ def jpeg_compressibility():
         return -rew, meta
 
     return _fn
+
+
+def aesthetic_score():
+    from ddpo_pytorch.aesthetic_scorer import AestheticScorer
+
+    scorer = AestheticScorer().cuda()
+
+    def _fn(images, prompts, metadata):
+        if not isinstance(images, torch.Tensor):
+            images = torch.as_tensor(images)
+        scores = scorer(images)
+        return scores, {}
+
+    return _fn
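Illustrative only (not part of the commit): how the returned reward callable might be driven, assuming it lives in ddpo_pytorch.rewards alongside the jpeg rewards above and that a CUDA device is available. uint8 images are used here so the CLIP processor's default 1/255 rescaling applies cleanly.

import torch
from ddpo_pytorch.rewards import aesthetic_score  # module path assumed; the diff does not show the file name

reward_fn = aesthetic_score()  # builds the CLIP + aesthetic MLP scorer once, on the GPU

# Batch of four uint8 RGB images in NCHW layout; prompts and metadata are unused by this reward.
images = torch.randint(0, 256, (4, 3, 512, 512), dtype=torch.uint8)
rewards, meta = reward_fn(images, ["a photo of a dog"] * 4, [{}] * 4)
print(rewards.shape)  # torch.Size([4, 1])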
@@ -42,7 +42,7 @@ def main(_):
     )
     if accelerator.is_main_process:
         accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict())
-    logger.info(config)
+    logger.info(f"\n{config}")
 
     # set seed
     set_seed(config.seed, device_specific=True)