2023-06-24 04:25:54 +02:00
|
|
|
from PIL import Image
|
|
|
|
import io
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
def jpeg_incompressibility():
    """Build a reward function that scores images by their JPEG-encoded size.

    The returned callable has the signature ``_fn(images, prompts, metadata)``
    and returns ``(sizes, info)``:

    * ``sizes`` is a NumPy array with one entry per image: the number of bytes
      produced by saving that image as a quality-95 JPEG, divided by 1000.
    * ``info`` is always an empty dict.

    ``prompts`` and ``metadata`` are accepted for interface compatibility and
    are not used.
    """

    def _encoded_kilobytes(picture):
        # Encode into an in-memory stream; the stream position after saving is
        # taken as the encoded size in bytes.
        stream = io.BytesIO()
        picture.save(stream, format="JPEG", quality=95)
        return stream.tell() / 1000

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            # Tensor input is scaled by 255 and quantised to uint8 (presumably
            # it arrives as floats in [0, 1] - confirm against the caller).
            quantised = (images * 255).round().clamp(0, 255).to(torch.uint8)
            images = quantised.cpu().numpy().transpose(0, 2, 3, 1)  # NCHW -> NHWC
        rewards = np.array(
            [_encoded_kilobytes(Image.fromarray(frame)) for frame in images]
        )
        return rewards, {}

    return _fn
|
|
|
|
|
|
|
|
|
|
|
|
def jpeg_compressibility():
    """Build a reward function that is the negation of ``jpeg_incompressibility``.

    The returned callable ``_fn(images, prompts, metadata)`` forwards its
    arguments unchanged to the function built by ``jpeg_incompressibility()``
    and returns ``(-reward, info)``, so smaller JPEG encodings receive higher
    (less negative) rewards.
    """
    size_reward = jpeg_incompressibility()

    def _fn(images, prompts, metadata):
        sizes, info = size_reward(images, prompts, metadata)
        # Flip the sign only; the info dict is passed through untouched.
        return -sizes, info

    return _fn
|
2023-06-27 19:40:36 +02:00
|
|
|
|
|
|
|
|
|
|
|
def aesthetic_score():
    """Build a reward function backed by ``AestheticScorer``.

    A float32 ``AestheticScorer`` is constructed once and moved to CUDA when
    this factory is called. The returned callable
    ``_fn(images, prompts, metadata)`` converts ``images`` to a uint8 tensor,
    runs the scorer on it and returns ``(scores, {})``. ``prompts`` and
    ``metadata`` are not used.
    """
    # Imported lazily so the scorer's dependencies are only needed when this
    # reward is actually requested.
    from ddpo_pytorch.aesthetic_scorer import AestheticScorer

    scorer = AestheticScorer(dtype=torch.float32).cuda()

    def _fn(images, prompts, metadata):
        if not isinstance(images, torch.Tensor):
            # Non-tensor input is treated as an NHWC array and reordered to
            # the NCHW layout used for tensor input.
            channels_first = images.transpose(0, 3, 1, 2)  # NHWC -> NCHW
            batch = torch.tensor(channels_first, dtype=torch.uint8)
        else:
            # Tensor input is scaled by 255 and quantised to uint8.
            batch = (images * 255).round().clamp(0, 255).to(torch.uint8)
        return scorer(batch), {}

    return _fn
|
|
|
|
|
|
|
|
|
|
|
|
def llava_strict_satisfaction():
    """Submits images to LLaVA and computes a reward by matching the responses to ground truth answers directly without
    using BERTScore. Prompt metadata must have "questions" and "answers" keys. See
    https://github.com/kvablack/LLaVA-server for server-side code.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` returning
        ``(scores, info)``. ``scores`` is a NumPy array holding, per image,
        the fraction of ground-truth answers that occur as substrings of the
        corresponding LLaVA responses. ``info["answers"]`` is a NumPy array
        built from the raw server outputs. ``prompts`` is ignored. Empty
        input yields empty arrays without contacting the server.

    Raises (from ``_fn``):
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    batch_size = 4
    url = "http://127.0.0.1:8085"
    sess = requests.Session()
    # Retry HTTP 500 responses with backoff; allowed_methods=False lets the
    # retry policy apply to every HTTP method, including POST.
    retries = Retry(
        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
    )
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _to_jpeg_bytes(image):
        """Encode one HWC uint8 array as quality-80 JPEG bytes."""
        buffer = BytesIO()
        Image.fromarray(image).save(buffer, format="JPEG", quality=80)
        return buffer.getvalue()

    def _fn(images, prompts, metadata):
        del prompts
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        all_scores = []
        all_info = {
            "answers": [],
        }

        # np.array_split rejects 0 sections, so short-circuit empty input.
        if len(images) == 0:
            return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}

        # Pass an integer section count rather than the float from np.ceil.
        images_batched = np.array_split(
            images, int(np.ceil(len(images) / batch_size))
        )
        metadata_batched = np.array_split(
            metadata, int(np.ceil(len(metadata) / batch_size))
        )

        for image_batch, metadata_batch in zip(images_batched, metadata_batched):
            # Compress the images using JPEG
            jpeg_images = [_to_jpeg_bytes(image) for image in image_batch]

            # format for LLaVA server
            data = {
                "images": jpeg_images,
                "queries": [m["questions"] for m in metadata_batch],
            }
            data_bytes = pickle.dumps(data)

            # send a request to the llava server
            response = sess.post(url, data=data_bytes, timeout=120)
            # Surface HTTP errors directly instead of failing later while
            # trying to unpickle an error page.
            response.raise_for_status()

            # NOTE(security): pickle.loads can execute arbitrary code. This is
            # only acceptable because the server is a trusted local process;
            # do not point `url` at an untrusted host.
            response_data = pickle.loads(response.content)

            # NOTE(review): building a 2-D array here assumes every image in
            # the batch has the same number of question/answer pairs - confirm.
            correct = np.array(
                [
                    [ans in resp for ans, resp in zip(m["answers"], responses)]
                    for m, responses in zip(metadata_batch, response_data["outputs"])
                ]
            )
            scores = correct.mean(axis=-1)

            all_scores += scores.tolist()
            all_info["answers"] += response_data["outputs"]

        return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}

    return _fn
|
|
|
|
|
|
|
|
|
|
|
|
def llava_bertscore():
    """Submits images to LLaVA and computes a reward by comparing the responses to the prompts using BERTScore. See
    https://github.com/kvablack/LLaVA-server for server-side code.

    Returns:
        A callable ``_fn(images, prompts, metadata)`` returning
        ``(scores, info)``. ``scores`` is a NumPy array built from the
        ``"recall"`` field of the server responses. ``info`` holds NumPy
        arrays for the ``"precision"``, ``"f1"`` and ``"outputs"`` fields.
        ``metadata`` is ignored. Empty input yields empty arrays without
        contacting the server.

    Raises (from ``_fn``):
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    batch_size = 16
    url = "http://127.0.0.1:8085"
    sess = requests.Session()
    # Retry HTTP 500 responses with backoff; allowed_methods=False lets the
    # retry policy apply to every HTTP method, including POST.
    retries = Retry(
        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
    )
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _to_jpeg_bytes(image):
        """Encode one HWC uint8 array as quality-80 JPEG bytes."""
        buffer = BytesIO()
        Image.fromarray(image).save(buffer, format="JPEG", quality=80)
        return buffer.getvalue()

    def _flat_list(values):
        """Squeeze ``values`` and return it as a list, even for one element.

        A bare ``.squeeze().tolist()`` turns a single-image batch into a
        scalar (or a str), which breaks ``list += ...`` accumulation: floats
        raise TypeError and strings are extended character by character.
        ``np.atleast_1d`` keeps the result one-dimensional.
        """
        return np.atleast_1d(np.array(values).squeeze()).tolist()

    def _fn(images, prompts, metadata):
        del metadata
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        all_scores = []
        all_info = {
            "precision": [],
            "f1": [],
            "outputs": [],
        }

        # np.array_split rejects 0 sections, so short-circuit empty input.
        if len(images) == 0:
            return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}

        # Pass an integer section count rather than the float from np.ceil.
        images_batched = np.array_split(
            images, int(np.ceil(len(images) / batch_size))
        )
        prompts_batched = np.array_split(
            prompts, int(np.ceil(len(prompts) / batch_size))
        )

        for image_batch, prompt_batch in zip(images_batched, prompts_batched):
            # Compress the images using JPEG
            jpeg_images = [_to_jpeg_bytes(image) for image in image_batch]

            # format for LLaVA server
            data = {
                "images": jpeg_images,
                "queries": [["Answer concisely: what is going on in this image?"]]
                * len(image_batch),
                "answers": [
                    [f"The image contains {prompt}"] for prompt in prompt_batch
                ],
            }
            data_bytes = pickle.dumps(data)

            # send a request to the llava server
            response = sess.post(url, data=data_bytes, timeout=120)
            # Surface HTTP errors directly instead of failing later while
            # trying to unpickle an error page.
            response.raise_for_status()

            # NOTE(security): pickle.loads can execute arbitrary code. This is
            # only acceptable because the server is a trusted local process;
            # do not point `url` at an untrusted host.
            response_data = pickle.loads(response.content)

            # use the recall score as the reward
            all_scores += _flat_list(response_data["recall"])

            # save the precision and f1 scores for analysis
            all_info["precision"] += _flat_list(response_data["precision"])
            all_info["f1"] += _flat_list(response_data["f1"])
            all_info["outputs"] += _flat_list(response_data["outputs"])

        return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}

    return _fn
|