Fix aesthetic scorer

This commit is contained in:
Kevin Black 2023-06-28 10:42:30 -07:00
parent 28d2d8c40e
commit fe9ed8a25f
2 changed files with 7 additions and 0 deletions

View File

@ -5,6 +5,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
from transformers import CLIPModel, CLIPProcessor from transformers import CLIPModel, CLIPProcessor
from PIL import Image
ASSETS_PATH = resources.files("ddpo_pytorch.assets") ASSETS_PATH = resources.files("ddpo_pytorch.assets")
@ -40,6 +41,8 @@ class AestheticScorer(torch.nn.Module):
@torch.no_grad() @torch.no_grad()
def __call__(self, images): def __call__(self, images):
assert isinstance(images, list)
assert isinstance(images[0], Image.Image)
inputs = self.processor(images=images, return_tensors="pt") inputs = self.processor(images=images, return_tensors="pt")
inputs = {k: v.cuda() for k, v in inputs.items()} inputs = {k: v.cuda() for k, v in inputs.items()}
embed = self.clip.get_image_features(**inputs) embed = self.clip.get_image_features(**inputs)

View File

@ -35,6 +35,10 @@ def aesthetic_score():
scorer = AestheticScorer().cuda() scorer = AestheticScorer().cuda()
def _fn(images, prompts, metadata): def _fn(images, prompts, metadata):
if isinstance(images, torch.Tensor):
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
images = [Image.fromarray(image) for image in images]
scores = scorer(images) scores = scorer(images)
return scores, {} return scores, {}