Fix aesthetic scorer

This commit is contained in:
Kevin Black 2023-06-28 10:42:30 -07:00
parent 28d2d8c40e
commit fe9ed8a25f
2 changed files with 7 additions and 0 deletions

View File

@ -5,6 +5,7 @@ import torch
import torch.nn as nn
import numpy as np
from transformers import CLIPModel, CLIPProcessor
from PIL import Image
ASSETS_PATH = resources.files("ddpo_pytorch.assets")
@ -40,6 +41,8 @@ class AestheticScorer(torch.nn.Module):
@torch.no_grad()
def __call__(self, images):
assert isinstance(images, list)
assert isinstance(images[0], Image.Image)
inputs = self.processor(images=images, return_tensors="pt")
inputs = {k: v.cuda() for k, v in inputs.items()}
embed = self.clip.get_image_features(**inputs)

View File

@ -35,6 +35,10 @@ def aesthetic_score():
scorer = AestheticScorer().cuda()
def _fn(images, prompts, metadata):
if isinstance(images, torch.Tensor):
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
images = [Image.fromarray(image) for image in images]
scores = scorer(images)
return scores, {}