Add aesthetic scorer reward function

This commit is contained in:
Kevin Black 2023-06-27 10:40:36 -07:00
parent 8cab96dea4
commit bae3f43f5f
6 changed files with 68 additions and 2 deletions

View File

@ -0,0 +1,48 @@
# Based on https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/fe88a163f4661b4ddabba0751ff645e2e620746e/simple_inference.py
from importlib import resources
import torch
import torch.nn as nn
import numpy as np
from transformers import CLIPModel, CLIPProcessor
ASSETS_PATH = resources.files("ddpo_pytorch.assets")
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(768, 1024),
nn.Identity(),
nn.Linear(1024, 128),
nn.Identity(),
nn.Linear(128, 64),
nn.Identity(),
nn.Linear(64, 16),
nn.Linear(16, 1),
)
state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
self.load_state_dict(state_dict)
@torch.no_grad()
def forward(self, embed):
return self.layers(embed)
class AestheticScorer(torch.nn.Module):
def __init__(self):
super().__init__()
self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
self.mlp = MLP()
@torch.no_grad()
def __call__(self, images):
inputs = self.processor(images=images, return_tensors="pt")
inputs = {k: v.cuda() for k, v in inputs.items()}
embed = self.clip.get_image_features(**inputs)
# normalize embedding
embed = embed / embed.norm(dim=-1, keepdim=True)
return self.mlp(embed)

Binary file not shown.

View File

@ -40,6 +40,10 @@ def imagenet_dogs():
return from_file("imagenet_classes.txt", 151, 269)
def simple_animals():
return from_file("simple_animals.txt")
def nouns_activities(nouns_file, activities_file):
nouns = _load_lines(nouns_file)
activities = _load_lines(activities_file)

View File

@ -8,7 +8,7 @@ def jpeg_incompressibility():
def _fn(images, prompts, metadata):
if isinstance(images, torch.Tensor):
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
images = [Image.fromarray(image) for image in images]
buffers = [io.BytesIO() for _ in images]
for image, buffer in zip(images, buffers):
@ -27,3 +27,17 @@ def jpeg_compressibility():
return -rew, meta
return _fn
def aesthetic_score():
from ddpo_pytorch.aesthetic_scorer import AestheticScorer
scorer = AestheticScorer().cuda()
def _fn(images, prompts, metadata):
if not isinstance(images, torch.Tensor):
images = torch.as_tensor(images)
scores = scorer(images)
return scores, {}
return _fn

View File

@ -42,7 +42,7 @@ def main(_):
)
if accelerator.is_main_process:
accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict())
logger.info(config)
logger.info(f"\n{config}")
# set seed
set_seed(config.seed, device_specific=True)