2023-06-24 04:25:54 +02:00
from PIL import Image
import io
import numpy as np
import torch
def jpeg_incompressibility():
    """Build a reward function that scores images by their JPEG-encoded size.

    Returns:
        A callable ``_fn(images, prompts, metadata)``. ``prompts`` and
        ``metadata`` are accepted but not used. It returns a tuple of
        (NumPy array with one value per image, the number of bytes written by
        a quality-95 JPEG encode divided by 1000; an empty info dict).

    If ``images`` is a ``torch.Tensor`` it is multiplied by 255, rounded,
    clamped to [0, 255], cast to uint8, moved to the CPU and transposed with
    axes (0, 2, 3, 1). Any other input is iterated as-is, each item being
    handed to ``Image.fromarray``.
    """

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            as_uint8 = (images * 255).round().clamp(0, 255).to(torch.uint8)
            images = as_uint8.cpu().numpy().transpose(0, 2, 3, 1)  # NCHW -> NHWC

        kilobytes = []
        for pixels in images:
            encoded = io.BytesIO()
            Image.fromarray(pixels).save(encoded, format="JPEG", quality=95)
            # The stream position after saving equals the encoded byte count.
            kilobytes.append(encoded.tell() / 1000)
        return np.array(kilobytes), {}

    return _fn
def jpeg_compressibility():
    """Build a reward function that is the negation of ``jpeg_incompressibility``.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` that forwards all three
        arguments to the reward produced by ``jpeg_incompressibility()`` and
        returns ``(-reward, info)``, with the info dict passed through
        unchanged.
    """
    incompressibility_fn = jpeg_incompressibility()

    def _fn(images, prompts, metadata):
        reward, info = incompressibility_fn(images, prompts, metadata)
        return -reward, info

    return _fn
2023-06-27 19:40:36 +02:00
def aesthetic_score():
    """Build a reward function that scores images with ``AestheticScorer``.

    Calling this imports ``AestheticScorer`` from
    ``ddpo_pytorch.aesthetic_scorer``, instantiates it with
    ``dtype=torch.float32`` and moves it to CUDA, so a CUDA device is needed.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` returning
        ``(scores, {})``, where ``scores`` is whatever the scorer returns for
        the uint8 image batch. ``prompts`` and ``metadata`` are not used.
    """
    from ddpo_pytorch.aesthetic_scorer import AestheticScorer

    scorer = AestheticScorer(dtype=torch.float32).cuda()

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            # Tensor input is only rescaled to uint8 and keeps its layout.
            # NOTE(review): presumably NCHW floats in [0, 1] -- confirm with caller.
            images = (images * 255).round().clamp(0, 255).to(torch.uint8)
        else:
            # Array input is reordered before being wrapped in a tensor. This
            # must stay in the non-tensor branch: ``Tensor.transpose`` accepts
            # exactly two dims, so the four-axis call would raise on a tensor.
            images = images.transpose(0, 3, 1, 2)  # NHWC -> NCHW
            images = torch.tensor(images, dtype=torch.uint8)
        scores = scorer(images)
        return scores, {}

    return _fn
def llava_strict_satisfaction(url="", batch_size=4):
    """Build a reward that checks LLaVA's answers against ground-truth answers.

    Submits images to LLaVA and computes a reward by matching the responses to
    ground truth answers directly without using BERTScore. Prompt metadata must
    have "questions" and "answers" keys. See
    https://github.com/kvablack/LLaVA-server for server-side code.

    Args:
        url: Address the requests are POSTed to. The default is the empty
            string this file ships with. NOTE(review): ``requests`` cannot
            POST to an empty URL, so the server address has to be supplied.
        batch_size: Divisor used to decide how many requests the images are
            split across (``ceil(len(images) / batch_size)`` requests).

    Returns:
        A callable ``_fn(images, prompts, metadata)``; ``prompts`` is not
        used. It returns ``(scores, info)``: ``scores`` is a NumPy array with,
        per image, the fraction of expected answers found as a substring of
        the matching response, and ``info["answers"]`` is a NumPy array of the
        server's ``"outputs"``.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    sess = requests.Session()
    # allowed_methods=False lets the retry policy apply to POST as well;
    # HTTP 500 responses are retried with backoff.
    retries = Retry(
        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
    )
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _fn(images, prompts, metadata):
        del prompts
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
        metadata_batched = np.array_split(metadata, np.ceil(len(metadata) / batch_size))

        all_scores = []
        all_info = {
            "answers": [],
        }
        for image_batch, metadata_batch in zip(images_batched, metadata_batched):
            jpeg_images = []

            # Compress the images using JPEG
            for image in image_batch:
                img = Image.fromarray(image)
                buffer = BytesIO()
                img.save(buffer, format="JPEG", quality=80)
                # Without this append the request would carry no images.
                jpeg_images.append(buffer.getvalue())

            # format for LLaVA server
            data = {
                "images": jpeg_images,
                "queries": [m["questions"] for m in metadata_batch],
            }
            data_bytes = pickle.dumps(data)

            # send a request to the llava server
            response = sess.post(url, data=data_bytes, timeout=120)

            # SECURITY: unpickling can execute arbitrary code; only point
            # ``url`` at a server you trust.
            response_data = pickle.loads(response.content)

            # One row per image, one boolean per (answer, response) pair.
            correct = np.array(
                [
                    [ans in resp for ans, resp in zip(m["answers"], responses)]
                    for m, responses in zip(metadata_batch, response_data["outputs"])
                ]
            )
            scores = correct.mean(axis=-1)

            all_scores += scores.tolist()
            all_info["answers"] += response_data["outputs"]

        return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}

    return _fn
def llava_bertscore(url="", batch_size=16):
    """Build a reward that compares LLaVA's image descriptions to the prompts.

    Submits images to LLaVA and computes a reward by comparing the responses to
    the prompts using BERTScore. See https://github.com/kvablack/LLaVA-server
    for server-side code.

    Args:
        url: Address the requests are POSTed to. The default is the empty
            string this file ships with. NOTE(review): ``requests`` cannot
            POST to an empty URL, so the server address has to be supplied.
        batch_size: Divisor used to decide how many requests the images are
            split across (``ceil(len(images) / batch_size)`` requests).

    Returns:
        A callable ``_fn(images, prompts, metadata)``; ``metadata`` is not
        used. It returns ``(scores, info)``: ``scores`` is a NumPy array built
        from the server's ``"recall"`` values, and ``info`` holds NumPy arrays
        under ``"precision"``, ``"f1"`` and ``"outputs"``.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    sess = requests.Session()
    # allowed_methods=False lets the retry policy apply to POST as well;
    # HTTP 500 responses are retried with backoff.
    retries = Retry(
        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
    )
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _as_list(values):
        # squeeze() collapses a one-image batch to a 0-d array whose tolist()
        # is a bare scalar, and ``list += scalar`` raises TypeError.
        # atleast_1d keeps that case a one-element list and leaves larger
        # batches unchanged.
        return np.atleast_1d(np.array(values).squeeze()).tolist()

    def _fn(images, prompts, metadata):
        del metadata
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
        prompts_batched = np.array_split(prompts, np.ceil(len(prompts) / batch_size))

        all_scores = []
        all_info = {
            "precision": [],
            "f1": [],
            "outputs": [],
        }
        for image_batch, prompt_batch in zip(images_batched, prompts_batched):
            jpeg_images = []

            # Compress the images using JPEG
            for image in image_batch:
                img = Image.fromarray(image)
                buffer = BytesIO()
                img.save(buffer, format="JPEG", quality=80)
                # Without this append the request would carry no images.
                jpeg_images.append(buffer.getvalue())

            # format for LLaVA server
            data = {
                "images": jpeg_images,
                "queries": [["Answer concisely: what is going on in this image?"]]
                * len(image_batch),
                "answers": [
                    [f"The image contains {prompt}"] for prompt in prompt_batch
                ],
            }
            data_bytes = pickle.dumps(data)

            # send a request to the llava server
            response = sess.post(url, data=data_bytes, timeout=120)

            # SECURITY: unpickling can execute arbitrary code; only point
            # ``url`` at a server you trust.
            response_data = pickle.loads(response.content)

            # use the recall score as the reward
            all_scores += _as_list(response_data["recall"])

            # save the precision and f1 scores for analysis
            all_info["precision"] += _as_list(response_data["precision"])
            all_info["f1"] += _as_list(response_data["f1"])
            all_info["outputs"] += _as_list(response_data["outputs"])

        return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}

    return _fn