Fix aesthetic scorer
This commit is contained in:
		| @@ -5,6 +5,7 @@ import torch | |||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| import numpy as np | import numpy as np | ||||||
| from transformers import CLIPModel, CLIPProcessor | from transformers import CLIPModel, CLIPProcessor | ||||||
|  | from PIL import Image | ||||||
|  |  | ||||||
| ASSETS_PATH = resources.files("ddpo_pytorch.assets") | ASSETS_PATH = resources.files("ddpo_pytorch.assets") | ||||||
|  |  | ||||||
| @@ -40,6 +41,8 @@ class AestheticScorer(torch.nn.Module): | |||||||
|  |  | ||||||
|     @torch.no_grad() |     @torch.no_grad() | ||||||
|     def __call__(self, images): |     def __call__(self, images): | ||||||
|  |         assert isinstance(images, list) | ||||||
|  |         assert isinstance(images[0], Image.Image) | ||||||
|         inputs = self.processor(images=images, return_tensors="pt") |         inputs = self.processor(images=images, return_tensors="pt") | ||||||
|         inputs = {k: v.cuda() for k, v in inputs.items()} |         inputs = {k: v.cuda() for k, v in inputs.items()} | ||||||
|         embed = self.clip.get_image_features(**inputs) |         embed = self.clip.get_image_features(**inputs) | ||||||
|   | |||||||
| @@ -35,6 +35,10 @@ def aesthetic_score(): | |||||||
|     scorer = AestheticScorer().cuda() |     scorer = AestheticScorer().cuda() | ||||||
|  |  | ||||||
|     def _fn(images, prompts, metadata): |     def _fn(images, prompts, metadata): | ||||||
|  |         if isinstance(images, torch.Tensor): | ||||||
|  |             images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy() | ||||||
|  |             images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC | ||||||
|  |         images = [Image.fromarray(image) for image in images] | ||||||
|         scores = scorer(images) |         scores = scorer(images) | ||||||
|         return scores, {} |         return scores, {} | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user