import io

import numpy as np
import torch
from PIL import Image


def _to_uint8_nhwc(images):
    """Convert a torch image batch to a uint8 NHWC numpy array.

    Tensors are scaled by 255, rounded, clamped to [0, 255], cast to uint8,
    moved to the CPU and transposed from NCHW to NHWC (so tensor values are
    presumably floats in [0, 1] -- confirm against the caller). Anything that
    is not a ``torch.Tensor`` is returned untouched and is assumed to already
    be uint8 NHWC.
    """
    if isinstance(images, torch.Tensor):
        images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
        images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
    return images


def _encode_jpeg(image, quality):
    """Return the JPEG-encoded bytes of a single HWC image array."""
    buffer = io.BytesIO()
    Image.fromarray(image).save(buffer, format="JPEG", quality=quality)
    return buffer.getvalue()


def _split_batches(items, batch_size):
    """Split ``items`` into ceil(len / batch_size) nearly equal chunks.

    Returns an empty list for empty input, because ``np.array_split`` raises
    ``ValueError`` when asked for zero sections.
    """
    n_items = len(items)
    if n_items == 0:
        return []
    return np.array_split(items, int(np.ceil(n_items / batch_size)))


def _flat_list(values):
    """Squeeze ``values`` and return it as a list that is never a bare scalar.

    ``squeeze()`` collapses a one-element batch to a 0-d array whose
    ``tolist()`` is a scalar; ``np.atleast_1d`` keeps it a one-item list so
    it can be appended to the running results with ``+=``.
    """
    return np.atleast_1d(np.array(values).squeeze()).tolist()


def _make_llava_session():
    """Create a ``requests`` session that retries HTTP 500 responses.

    ``requests`` is imported lazily so that it is only required when a LLaVA
    reward is actually constructed.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry

    session = requests.Session()
    retries = Retry(
        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
    )
    session.mount("http://", HTTPAdapter(max_retries=retries))
    return session


def _query_llava(session, url, payload):
    """POST the pickled ``payload`` to ``url`` and return the unpickled reply."""
    import pickle

    response = session.post(url, data=pickle.dumps(payload), timeout=120)
    # NOTE(security): pickle.loads executes arbitrary code embedded in the
    # payload. Only point `url` at a trusted LLaVA server.
    return pickle.loads(response.content)


def jpeg_incompressibility():
    """Build a reward function that scores images by their JPEG size.

    The returned ``_fn(images, prompts, metadata)`` encodes every image as
    JPEG at quality 95 and returns ``(sizes, {})`` where ``sizes`` is a numpy
    array holding the encoded size of each image in bytes divided by 1000.
    ``prompts`` and ``metadata`` are ignored.
    """

    def _fn(images, prompts, metadata):
        images = _to_uint8_nhwc(images)
        sizes = [len(_encode_jpeg(image, quality=95)) / 1000 for image in images]
        return np.array(sizes), {}

    return _fn


def jpeg_compressibility():
    """Build a reward function that is the negation of JPEG incompressibility.

    Smaller JPEG encodings therefore receive higher rewards. The info dict
    is passed through unchanged.
    """
    jpeg_fn = jpeg_incompressibility()

    def _fn(images, prompts, metadata):
        rew, meta = jpeg_fn(images, prompts, metadata)
        return -rew, meta

    return _fn


def aesthetic_score():
    """Build a reward function backed by ``AestheticScorer``.

    The scorer is constructed once, in float32, and moved to CUDA when this
    factory is called. The returned ``_fn`` feeds it a uint8 NCHW batch and
    returns ``(scores, {})``; ``prompts`` and ``metadata`` are ignored.
    """
    from ddpo_pytorch.aesthetic_scorer import AestheticScorer

    scorer = AestheticScorer(dtype=torch.float32).cuda()

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8)
        else:
            # Non-tensor input is assumed to be a uint8 NHWC numpy array.
            images = images.transpose(0, 3, 1, 2)  # NHWC -> NCHW
            images = torch.tensor(images, dtype=torch.uint8)
        scores = scorer(images)
        return scores, {}

    return _fn


def llava_strict_satisfaction():
    """Submits images to LLaVA and computes a reward by matching the responses to ground truth answers directly without
    using BERTScore. Prompt metadata must have "questions" and "answers" keys. See
    https://github.com/kvablack/LLaVA-server for server-side code.

    The returned ``_fn(images, prompts, metadata)`` ignores ``prompts`` and
    returns ``(scores, info)``: ``scores`` is the per-image fraction of
    ground-truth answers that occur as a substring of the matching response,
    and ``info["answers"]`` holds the raw server outputs. Empty input yields
    empty arrays without contacting the server.
    """
    batch_size = 4
    url = "http://127.0.0.1:8085"
    sess = _make_llava_session()

    def _fn(images, prompts, metadata):
        del prompts

        images = _to_uint8_nhwc(images)
        images_batched = _split_batches(images, batch_size)
        metadata_batched = _split_batches(metadata, batch_size)

        all_scores = []
        all_info = {
            "answers": [],
        }
        for image_batch, metadata_batch in zip(images_batched, metadata_batched):
            # Compress the images using JPEG
            jpeg_images = [_encode_jpeg(image, quality=80) for image in image_batch]

            # format for LLaVA server
            data = {
                "images": jpeg_images,
                "queries": [m["questions"] for m in metadata_batch],
            }

            # send a request to the llava server
            response_data = _query_llava(sess, url, data)

            # An answer counts as correct when it appears verbatim inside the
            # corresponding response.
            correct = np.array(
                [
                    [ans in resp for ans, resp in zip(m["answers"], responses)]
                    for m, responses in zip(metadata_batch, response_data["outputs"])
                ]
            )
            scores = correct.mean(axis=-1)

            all_scores += scores.tolist()
            all_info["answers"] += response_data["outputs"]

        return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}

    return _fn


def llava_bertscore():
    """Submits images to LLaVA and computes a reward by comparing the responses to the prompts using BERTScore. See
    https://github.com/kvablack/LLaVA-server for server-side code.

    The returned ``_fn(images, prompts, metadata)`` ignores ``metadata`` and
    returns ``(scores, info)``: ``scores`` is the "recall" value reported by
    the server for each image, and ``info`` carries the "precision", "f1" and
    "outputs" values. Empty input yields empty arrays without contacting the
    server, and a single-image batch is handled like any other.
    """
    batch_size = 16
    url = "http://127.0.0.1:8085"
    sess = _make_llava_session()

    def _fn(images, prompts, metadata):
        del metadata

        images = _to_uint8_nhwc(images)
        images_batched = _split_batches(images, batch_size)
        prompts_batched = _split_batches(prompts, batch_size)

        all_scores = []
        all_info = {
            "precision": [],
            "f1": [],
            "outputs": [],
        }
        for image_batch, prompt_batch in zip(images_batched, prompts_batched):
            # Compress the images using JPEG
            jpeg_images = [_encode_jpeg(image, quality=80) for image in image_batch]

            # format for LLaVA server
            data = {
                "images": jpeg_images,
                "queries": [["Answer concisely: what is going on in this image?"]]
                * len(image_batch),
                "answers": [
                    [f"The image contains {prompt}"] for prompt in prompt_batch
                ],
            }

            # send a request to the llava server
            response_data = _query_llava(sess, url, data)

            # use the recall score as the reward
            all_scores += _flat_list(response_data["recall"])

            # save the precision and f1 scores for analysis
            all_info["precision"] += _flat_list(response_data["precision"])
            all_info["f1"] += _flat_list(response_data["f1"])
            all_info["outputs"] += _flat_list(response_data["outputs"])

        return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}

    return _fn