Fix aesthetic scorer
This commit is contained in:
parent
28d2d8c40e
commit
fe9ed8a25f
@ -5,6 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
from PIL import Image
|
||||
|
||||
ASSETS_PATH = resources.files("ddpo_pytorch.assets")
|
||||
|
||||
@ -40,6 +41,8 @@ class AestheticScorer(torch.nn.Module):
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, images):
|
||||
assert isinstance(images, list)
|
||||
assert isinstance(images[0], Image.Image)
|
||||
inputs = self.processor(images=images, return_tensors="pt")
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
embed = self.clip.get_image_features(**inputs)
|
||||
|
@ -35,6 +35,10 @@ def aesthetic_score():
|
||||
scorer = AestheticScorer().cuda()
|
||||
|
||||
def _fn(images, prompts, metadata):
|
||||
if isinstance(images, torch.Tensor):
|
||||
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
|
||||
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
|
||||
images = [Image.fromarray(image) for image in images]
|
||||
scores = scorer(images)
|
||||
return scores, {}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user