from PIL import Image
import io
import numpy as np
import torch


def _to_pil_images(images):
    """Convert a batch of images to a list of ``PIL.Image`` objects.

    Args:
        images: Either a ``torch.Tensor`` laid out NCHW with float values
            expected in [0, 1] (they are scaled by 255, rounded and clamped to
            [0, 255]), or an array-like already laid out NHWC that
            ``Image.fromarray`` accepts per element (presumably uint8 -- confirm
            with callers).

    Returns:
        A list with one ``PIL.Image`` per batch element.
    """
    if isinstance(images, torch.Tensor):
        images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
        images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
    # Image.fromarray cannot handle an (H, W, 1) array; drop the singleton
    # channel axis so single-channel batches become grayscale ("L") images.
    if getattr(images, "ndim", None) == 4 and images.shape[-1] == 1:
        images = images[..., 0]
    return [Image.fromarray(image) for image in images]


def jpeg_incompressibility():
    """Build a reward function that scores images by their JPEG-encoded size.

    Returns:
        ``_fn(images, prompts, metadata) -> (np.ndarray, dict)``. The array holds,
        for each image, its size in kilobytes (bytes / 1000) after encoding as
        JPEG at quality 95; the dict is always empty. ``prompts`` and
        ``metadata`` are accepted for interface compatibility but unused.
    """

    def _fn(images, prompts, metadata):
        sizes = []
        for image in _to_pil_images(images):
            with io.BytesIO() as buffer:
                image.save(buffer, format="JPEG", quality=95)
                # After save() the stream position equals the encoded length.
                sizes.append(buffer.tell() / 1000)
        return np.array(sizes), {}

    return _fn


def jpeg_compressibility():
    """Build a reward function that favors images with small JPEG size.

    Returns:
        ``_fn(images, prompts, metadata) -> (np.ndarray, dict)`` returning the
        negation of ``jpeg_incompressibility``'s sizes (in kilobytes) together
        with its (empty) info dict.
    """
    jpeg_fn = jpeg_incompressibility()

    def _fn(images, prompts, metadata):
        rew, meta = jpeg_fn(images, prompts, metadata)
        return -rew, meta

    return _fn


def aesthetic_score():
    """Build a reward function backed by the project's ``AestheticScorer``.

    The scorer is constructed once and moved to CUDA via ``.cuda()``, so a
    CUDA device is required when this factory is called.

    Returns:
        ``_fn(images, prompts, metadata) -> (scores, dict)`` where ``scores`` is
        whatever ``AestheticScorer`` returns for the list of PIL images (its
        type is defined elsewhere -- confirm before relying on it) and the dict
        is always empty.
    """
    # Imported lazily so the JPEG rewards work without this project module.
    from ddpo_pytorch.aesthetic_scorer import AestheticScorer

    scorer = AestheticScorer().cuda()

    def _fn(images, prompts, metadata):
        scores = scorer(_to_pil_images(images))
        return scores, {}

    return _fn