Fix aesthetic scorer
This commit is contained in:
		| @@ -5,6 +5,7 @@ import torch | ||||
| import torch.nn as nn | ||||
| import numpy as np | ||||
| from transformers import CLIPModel, CLIPProcessor | ||||
| from PIL import Image | ||||
|  | ||||
| ASSETS_PATH = resources.files("ddpo_pytorch.assets") | ||||
|  | ||||
| @@ -40,6 +41,8 @@ class AestheticScorer(torch.nn.Module): | ||||
|  | ||||
|     @torch.no_grad() | ||||
|     def __call__(self, images): | ||||
|         assert isinstance(images, list) | ||||
|         assert isinstance(images[0], Image.Image) | ||||
|         inputs = self.processor(images=images, return_tensors="pt") | ||||
|         inputs = {k: v.cuda() for k, v in inputs.items()} | ||||
|         embed = self.clip.get_image_features(**inputs) | ||||
|   | ||||
| @@ -35,6 +35,10 @@ def aesthetic_score(): | ||||
|     scorer = AestheticScorer().cuda() | ||||
|  | ||||
|     def _fn(images, prompts, metadata): | ||||
|         if isinstance(images, torch.Tensor): | ||||
|             images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy() | ||||
|             images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC | ||||
|         images = [Image.fromarray(image) for image in images] | ||||
|         scores = scorer(images) | ||||
|         return scores, {} | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user