ddpo-pytorch/ddpo_pytorch/rewards.py

46 lines
1.3 KiB
Python
Raw Normal View History

2023-06-24 04:25:54 +02:00
from PIL import Image
import io
import numpy as np
import torch
def jpeg_incompressibility(*, quality=95):
    """Build a reward equal to each image's JPEG-encoded size in kilobytes.

    Every image is re-encoded as JPEG in memory, and its encoded size
    (bytes / 1000) is used as the reward. Images that compress poorly
    (larger files) therefore receive higher rewards.

    Args:
        quality: JPEG quality passed to Pillow's encoder. Defaults to 95,
            the value previously hard-coded here, so existing callers get
            identical rewards.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` that returns
        ``(sizes, {})``. ``sizes`` is a 1-D ``np.ndarray`` of floats with one
        entry per image. ``prompts`` and ``metadata`` are ignored.

        ``images`` may be a ``torch.Tensor`` in NCHW layout. Such a tensor is
        scaled by 255, rounded, clipped to the uint8 range and converted to
        an NHWC numpy array. Any other input must be an iterable of arrays
        that ``PIL.Image.fromarray`` accepts.
    """

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        sizes = []
        for image in images:
            # Encode one image at a time and close its buffer right away,
            # instead of keeping every encoded JPEG alive until the end.
            with io.BytesIO() as buffer:
                Image.fromarray(image).save(buffer, format="JPEG", quality=quality)
                # After save() the stream position equals the encoded byte count.
                sizes.append(buffer.tell() / 1000)
        return np.array(sizes), {}

    return _fn
def jpeg_compressibility():
    """Build a reward that favours images which compress well as JPEG.

    This reward is the negation of the ``jpeg_incompressibility`` reward.
    Smaller encoded sizes give larger (less negative) rewards.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` that returns
        ``(-sizes, info)``. ``info`` is passed through unchanged from the
        wrapped reward.
    """
    incompressibility = jpeg_incompressibility()

    def _fn(images, prompts, metadata):
        encoded_kb, info = incompressibility(images, prompts, metadata)
        return -encoded_kb, info

    return _fn
2023-06-27 19:40:36 +02:00
def aesthetic_score():
    """Build a reward that scores images with ``AestheticScorer``.

    The scorer is imported and instantiated once, when this factory is
    called. It is moved to the GPU with ``.cuda()`` and then reused by every
    call of the returned function.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` that returns
        ``(scores, {})``. ``scores`` is whatever ``AestheticScorer`` returns
        for the batch. ``prompts`` and ``metadata`` are ignored.

        ``images`` may be a ``torch.Tensor`` in NCHW layout. Such a tensor is
        scaled by 255, rounded, clipped to the uint8 range and converted to
        an NHWC numpy array. Any other input must be an iterable of arrays
        that ``PIL.Image.fromarray`` accepts. In both cases the scorer
        receives a list of PIL images.
    """
    # Local import: the scorer module is loaded only when this factory runs.
    from ddpo_pytorch.aesthetic_scorer import AestheticScorer

    model = AestheticScorer().cuda()

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            quantized = (images * 255).round().clamp(0, 255).to(torch.uint8)
            images = quantized.cpu().numpy().transpose(0, 2, 3, 1)  # NCHW -> NHWC
        pil_images = list(map(Image.fromarray, images))
        return model(pil_images), {}

    return _fn