46 lines
1.3 KiB
Python
46 lines
1.3 KiB
Python
from PIL import Image
|
|
import io
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
def jpeg_incompressibility():
    """Build a reward function that scores images by their JPEG-encoded size.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` that returns a tuple
        ``(sizes, {})``. ``sizes`` is a NumPy array holding, for each image,
        the number of bytes written when saving it as JPEG at quality 95,
        divided by 1000. The second element is always an empty dict.
    """

    def _fn(images, prompts, metadata):
        # ``prompts`` and ``metadata`` are part of the signature but are not
        # read by this reward.
        if isinstance(images, torch.Tensor):
            # Scale by 255 and quantize to uint8 while still in torch.
            # NOTE(review): presumably the tensor holds values in [0, 1] --
            # confirm against the caller.
            quantized = (images * 255).round().clamp(0, 255).to(torch.uint8)
            # NCHW -> NHWC
            images = quantized.cpu().numpy().transpose(0, 2, 3, 1)

        pictures = [Image.fromarray(frame) for frame in images]

        sizes_kb = []
        for picture in pictures:
            stream = io.BytesIO()
            picture.save(stream, format="JPEG", quality=95)
            # The stream position after saving is used as the encoded size.
            sizes_kb.append(stream.tell() / 1000)

        return np.array(sizes_kb), {}

    return _fn
|
|
|
|
|
|
def jpeg_compressibility():
    """Build a reward function that negates the JPEG incompressibility reward.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` that forwards its
        arguments to the function built by ``jpeg_incompressibility()`` and
        returns ``(-reward, meta)``, where ``reward`` and ``meta`` are the two
        values that function produced.
    """
    incompressibility = jpeg_incompressibility()

    def _fn(images, prompts, metadata):
        reward, info = incompressibility(images, prompts, metadata)
        # Flip the sign of the reward; the metadata passes through unchanged.
        negated = -reward
        return negated, info

    return _fn
|
|
|
|
|
|
def aesthetic_score():
    """Build a reward function backed by ``AestheticScorer``.

    An ``AestheticScorer`` is constructed once, when this factory is called,
    and moved to the GPU with ``.cuda()``; the returned closure reuses it.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` that returns a tuple
        ``(scores, {})``, where ``scores`` is whatever the scorer returns for
        the list of PIL images. The second element is always an empty dict.
    """
    # Imported here rather than at module level, so the scorer module is only
    # loaded when this reward is requested.
    from ddpo_pytorch.aesthetic_scorer import AestheticScorer

    model = AestheticScorer().cuda()

    def _fn(images, prompts, metadata):
        # ``prompts`` and ``metadata`` are part of the signature but are not
        # read by this reward.
        if isinstance(images, torch.Tensor):
            # Scale by 255 and quantize to uint8 while still in torch.
            # NOTE(review): presumably the tensor holds values in [0, 1] --
            # confirm against the caller.
            quantized = (images * 255).round().clamp(0, 255).to(torch.uint8)
            # NCHW -> NHWC
            images = quantized.cpu().numpy().transpose(0, 2, 3, 1)

        pictures = [Image.fromarray(frame) for frame in images]
        return model(pictures), {}

    return _fn
|