Add aesthetic scorer reward function
This commit is contained in:
		
							
								
								
									
										48
									
								
								ddpo_pytorch/aesthetic_scorer.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								ddpo_pytorch/aesthetic_scorer.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,48 @@ | ||||
| # Based on https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/fe88a163f4661b4ddabba0751ff645e2e620746e/simple_inference.py | ||||
|  | ||||
| from importlib import resources | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import numpy as np | ||||
| from transformers import CLIPModel, CLIPProcessor | ||||
|  | ||||
| ASSETS_PATH = resources.files("ddpo_pytorch.assets") | ||||
|  | ||||
|  | ||||
class MLP(nn.Module):
    """Linear scoring head mapping a 768-dim embedding to a single value.

    The architecture follows the improved-aesthetic-predictor reference
    linked at the top of this file. Weights are loaded from the packaged
    ``sac+logos+ava1-l14-linearMSE.pth`` checkpoint at construction time.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): the nn.Identity entries presumably stand in for
        # non-parametric layers of the reference model so that the indices
        # inside the Sequential match the checkpoint's keys -- confirm
        # before removing them.
        self.layers = nn.Sequential(
            nn.Linear(768, 1024),
            nn.Identity(),
            nn.Linear(1024, 128),
            nn.Identity(),
            nn.Linear(128, 64),
            nn.Identity(),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        )

        checkpoint = ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth")
        # A package resource is not guaranteed to be a real filesystem path
        # (e.g. zipped installs); as_file() materialises one when needed.
        with resources.as_file(checkpoint) as checkpoint_path:
            # Load onto CPU so the checkpoint deserialises regardless of the
            # device it was saved from; callers move the module afterwards.
            state_dict = torch.load(checkpoint_path, map_location="cpu")
        self.load_state_dict(state_dict)

    @torch.no_grad()
    def forward(self, embed):
        """Apply the linear stack to ``embed`` with gradients disabled.

        Args:
            embed: Tensor whose last dimension is 768.

        Returns:
            Tensor with the last dimension reduced to 1.
        """
        return self.layers(embed)
|  | ||||
|  | ||||
class AestheticScorer(torch.nn.Module):
    """Score images by feeding normalised CLIP image features to ``MLP``.

    Loads the ``openai/clip-vit-large-patch14`` model and processor, plus
    the ``MLP`` head defined in this module.
    """

    def __init__(self):
        super().__init__()
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.mlp = MLP()

    @torch.no_grad()
    def forward(self, images):
        """Return the MLP output for a batch of images (no gradients).

        Defined as ``forward`` rather than ``__call__`` so the standard
        ``nn.Module`` call machinery (hooks etc.) stays intact;
        ``scorer(images)`` works exactly as before.

        Args:
            images: Anything accepted by the CLIP processor's ``images``
                argument.

        Returns:
            The tensor produced by ``self.mlp`` for the batch
            (presumably shape ``(N, 1)`` -- confirm against callers).
        """
        # Follow the module's own device instead of assuming CUDA, so the
        # scorer also works on CPU or on whichever GPU it was moved to.
        device = next(self.parameters()).device
        inputs = self.processor(images=images, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        embed = self.clip.get_image_features(**inputs)
        # normalize embedding to unit L2 norm along the feature axis
        embed = embed / embed.norm(dim=-1, keepdim=True)
        return self.mlp(embed)
							
								
								
									
										
											BIN
										
									
								
								ddpo_pytorch/assets/sac+logos+ava1-l14-linearMSE.pth
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								ddpo_pytorch/assets/sac+logos+ava1-l14-linearMSE.pth
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| @@ -40,6 +40,10 @@ def imagenet_dogs(): | ||||
|     return from_file("imagenet_classes.txt", 151, 269) | ||||
|  | ||||
|  | ||||
def simple_animals():
    """Return whatever ``from_file`` yields for ``simple_animals.txt``."""
    prompts_file = "simple_animals.txt"
    return from_file(prompts_file)
|  | ||||
|  | ||||
| def nouns_activities(nouns_file, activities_file): | ||||
|     nouns = _load_lines(nouns_file) | ||||
|     activities = _load_lines(activities_file) | ||||
|   | ||||
| @@ -27,3 +27,17 @@ def jpeg_compressibility(): | ||||
|         return -rew, meta | ||||
|  | ||||
|     return _fn | ||||
|  | ||||
|  | ||||
def aesthetic_score():
    """Build a reward function backed by ``AestheticScorer``.

    The scorer is imported lazily, instantiated once and moved to CUDA
    when this factory is called; the returned closure reuses it.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` that returns
        ``(scores, {})``, where ``scores`` is the scorer's output for
        ``images``. ``prompts`` and ``metadata`` are accepted but unused.
    """
    from ddpo_pytorch.aesthetic_scorer import AestheticScorer

    model = AestheticScorer().cuda()

    def _fn(images, prompts, metadata):
        # Non-tensor inputs are converted with torch.as_tensor before scoring.
        batch = images if isinstance(images, torch.Tensor) else torch.as_tensor(images)
        return model(batch), {}

    return _fn
|   | ||||
| @@ -42,7 +42,7 @@ def main(_): | ||||
|     ) | ||||
|     if accelerator.is_main_process: | ||||
|         accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict()) | ||||
|     logger.info(config) | ||||
|     logger.info(f"\n{config}") | ||||
|  | ||||
|     # set seed | ||||
|     set_seed(config.seed, device_specific=True) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user