From ec499edf841fd7fa831e1579ae70c13171954c22 Mon Sep 17 00:00:00 2001
From: Kevin Black
Date: Tue, 4 Jul 2023 00:23:33 -0700
Subject: [PATCH] Fix aesthetic score (again), add llava reward

---
 ddpo_pytorch/aesthetic_scorer.py |  10 +--
 ddpo_pytorch/rewards.py          | 141 +++++++++++++++++++++++++++++--
 scripts/train.py                 |  27 +++++-
 3 files changed, 164 insertions(+), 14 deletions(-)

diff --git a/ddpo_pytorch/aesthetic_scorer.py b/ddpo_pytorch/aesthetic_scorer.py
index 461fc2c..46e3202 100644
--- a/ddpo_pytorch/aesthetic_scorer.py
+++ b/ddpo_pytorch/aesthetic_scorer.py
@@ -30,22 +30,22 @@ class MLP(nn.Module):
 
 
 class AestheticScorer(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, dtype):
         super().__init__()
         self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
         self.mlp = MLP()
         state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
         self.mlp.load_state_dict(state_dict)
+        self.dtype = dtype
         self.eval()
 
     @torch.no_grad()
     def __call__(self, images):
-        assert isinstance(images, list)
-        assert isinstance(images[0], Image.Image)
+        device = next(self.parameters()).device
         inputs = self.processor(images=images, return_tensors="pt")
-        inputs = {k: v.cuda() for k, v in inputs.items()}
+        inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()}
         embed = self.clip.get_image_features(**inputs)
         # normalize embedding
         embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
-        return self.mlp(embed)
+        return self.mlp(embed).squeeze(1)
diff --git a/ddpo_pytorch/rewards.py b/ddpo_pytorch/rewards.py
index 0c58533..669d8e0 100644
--- a/ddpo_pytorch/rewards.py
+++ b/ddpo_pytorch/rewards.py
@@ -32,14 +32,145 @@ def jpeg_compressibility():
 def aesthetic_score():
     from ddpo_pytorch.aesthetic_scorer import AestheticScorer
 
-    scorer = AestheticScorer().cuda()
+    scorer = AestheticScorer(dtype=torch.float32).cuda()
 
     def _fn(images, prompts, metadata):
-        if isinstance(images, torch.Tensor):
-            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
-            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
-        images = [Image.fromarray(image) for image in images]
+        images = (images * 255).round().clamp(0, 255).to(torch.uint8)
         scores = scorer(images)
         return scores, {}
 
     return _fn
+
+
+def llava_strict_satisfaction():
+    """Submits images to LLaVA and computes a reward by matching the responses to ground truth answers directly without
+    using BERTScore. Prompt metadata must have "questions" and "answers" keys. See
+    https://github.com/kvablack/LLaVA-server for server-side code.
+    """
+    import requests
+    from requests.adapters import HTTPAdapter, Retry
+    from io import BytesIO
+    import pickle
+
+    batch_size = 4
+    url = "http://127.0.0.1:8085"
+    sess = requests.Session()
+    retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False)
+    sess.mount("http://", HTTPAdapter(max_retries=retries))
+
+    def _fn(images, prompts, metadata):
+        del prompts
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+
+        images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
+        metadata_batched = np.array_split(metadata, np.ceil(len(metadata) / batch_size))
+
+        all_scores = []
+        all_info = {
+            "answers": [],
+        }
+        for image_batch, metadata_batch in zip(images_batched, metadata_batched):
+            jpeg_images = []
+
+            # Compress the images using JPEG
+            for image in image_batch:
+                img = Image.fromarray(image)
+                buffer = BytesIO()
+                img.save(buffer, format="JPEG", quality=80)
+                jpeg_images.append(buffer.getvalue())
+
+            # format for LLaVA server
+            data = {
+                "images": jpeg_images,
+                "queries": [m["questions"] for m in metadata_batch],
+            }
+            data_bytes = pickle.dumps(data)
+
+            # send a request to the llava server
+            response = sess.post(url, data=data_bytes, timeout=120)
+
+            response_data = pickle.loads(response.content)
+
+            correct = np.array(
+                [
+                    [ans in resp for ans, resp in zip(m["answers"], responses)]
+                    for m, responses in zip(metadata_batch, response_data["outputs"])
+                ]
+            )
+            scores = correct.mean(axis=-1)
+
+            all_scores += scores.tolist()
+            all_info["answers"] += response_data["outputs"]
+
+        return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}
+
+    return _fn
+
+
+def llava_bertscore():
+    """Submits images to LLaVA and computes a reward by comparing the responses to the prompts using BERTScore. See
+    https://github.com/kvablack/LLaVA-server for server-side code.
+    """
+    import requests
+    from requests.adapters import HTTPAdapter, Retry
+    from io import BytesIO
+    import pickle
+
+    batch_size = 16
+    url = "http://127.0.0.1:8085"
+    sess = requests.Session()
+    retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False)
+    sess.mount("http://", HTTPAdapter(max_retries=retries))
+
+    def _fn(images, prompts, metadata):
+        del metadata
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+
+        images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
+        prompts_batched = np.array_split(prompts, np.ceil(len(prompts) / batch_size))
+
+        all_scores = []
+        all_info = {
+            "precision": [],
+            "f1": [],
+            "outputs": [],
+        }
+        for image_batch, prompt_batch in zip(images_batched, prompts_batched):
+            jpeg_images = []
+
+            # Compress the images using JPEG
+            for image in image_batch:
+                img = Image.fromarray(image)
+                buffer = BytesIO()
+                img.save(buffer, format="JPEG", quality=80)
+                jpeg_images.append(buffer.getvalue())
+
+            # format for LLaVA server
+            data = {
+                "images": jpeg_images,
+                "queries": [["Answer concisely: what is going on in this image?"]] * len(image_batch),
+                "answers": [[f"The image contains {prompt}"] for prompt in prompt_batch],
+            }
+            data_bytes = pickle.dumps(data)
+
+            # send a request to the llava server
+            response = sess.post(url, data=data_bytes, timeout=120)
+
+            response_data = pickle.loads(response.content)
+
+            # use the recall score as the reward
+            scores = np.array(response_data["recall"]).squeeze()
+            all_scores += scores.tolist()
+
+            # save the precision and f1 scores for analysis
+            all_info["precision"] += np.array(response_data["precision"]).squeeze().tolist()
+            all_info["f1"] += np.array(response_data["f1"]).squeeze().tolist()
+            all_info["outputs"] += np.array(response_data["outputs"]).squeeze().tolist()
+
+        return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}
+
+    return _fn
diff --git a/scripts/train.py b/scripts/train.py
index f5928d5..e5e832e 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -2,6 +2,8 @@ from collections import defaultdict
 import contextlib
 import os
 import datetime
+from concurrent import futures
+import time
 from absl import app, flags
 from ml_collections import config_flags
 from accelerate import Accelerator
@@ -227,6 +229,10 @@ def main(_):
     # Prepare everything with our `accelerator`.
     trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer)
 
+    # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks, which make a request to a
+    # remote server running llava inference.
+    executor = futures.ThreadPoolExecutor(max_workers=2)
+
     # Train!
     samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch
     total_train_batch_size = (
@@ -298,8 +304,10 @@
             log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
             timesteps = pipeline.scheduler.timesteps.repeat(config.sample.batch_size, 1)  # (batch_size, num_steps)
 
-            # compute rewards
-            rewards, reward_metadata = reward_fn(images, prompts, prompt_metadata)
+            # compute rewards asynchronously
+            rewards = executor.submit(reward_fn, images, prompts, prompt_metadata)
+            # yield to make sure reward computation starts
+            time.sleep(0)
 
             samples.append(
                 {
@@ -309,10 +317,21 @@
                     "latents": latents[:, :-1],  # each entry is the latent before timestep t
                     "next_latents": latents[:, 1:],  # each entry is the latent after timestep t
                     "log_probs": log_probs,
-                    "rewards": torch.as_tensor(rewards, device=accelerator.device),
+                    "rewards": rewards,
                 }
             )
 
+        # wait for all rewards to be computed
+        for sample in tqdm(
+            samples,
+            desc="Waiting for rewards",
+            disable=not accelerator.is_local_main_process,
+            position=0,
+        ):
+            rewards, reward_metadata = sample["rewards"].result()
+            # accelerator.print(reward_metadata)
+            sample["rewards"] = torch.as_tensor(rewards, device=accelerator.device)
+
         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
 
@@ -472,7 +491,7 @@
         # make sure we did an optimization step at the end of the inner epoch
         assert accelerator.sync_gradients
 
-        if epoch % config.save_freq == 0 and accelerator.is_main_process:
+        if epoch != 0 and epoch % config.save_freq == 0 and accelerator.is_main_process:
             accelerator.save_state()
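
Usage notes (illustrative sketches, not part of the applied diff):

With this change, AestheticScorer owns its dtype and device handling and returns a flat (batch,) tensor of scores, so aesthetic_score can hand it a uint8 image tensor directly instead of converting to PIL first. A minimal sketch of the new calling convention, assuming a CUDA device and the bundled aesthetic-predictor weights; the random images are placeholders:

    import torch
    from ddpo_pytorch.aesthetic_scorer import AestheticScorer

    scorer = AestheticScorer(dtype=torch.float32).cuda()
    # uint8 NCHW images on the CPU; the CLIP processor resizes and normalizes
    # them, and __call__ moves the pixel values to the scorer's dtype/device.
    images = torch.randint(0, 255, (4, 3, 512, 512), dtype=torch.uint8)
    scores = scorer(images)  # float tensor of shape (4,), higher = more aesthetic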
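Both llava_* rewards speak the same pickle-over-HTTP protocol to LLaVA-server: the request is a pickled dict holding per-image JPEG bytes and per-image lists of queries (the strict variant sends only "images" and "queries"; the BERTScore variant also sends "answers"), and the response is a pickled dict whose keys ("outputs", and for the BERTScore variant "recall", "precision", "f1") are exactly the ones read above. A standalone sketch of one round trip, assuming a server is already listening on the default address used above; the input image path is hypothetical:

    import pickle
    from io import BytesIO

    import requests
    from PIL import Image

    url = "http://127.0.0.1:8085"  # LLaVA-server, as in the reward functions

    # JPEG-compress one image exactly as the rewards do (quality=80)
    buffer = BytesIO()
    Image.open("example.jpg").save(buffer, format="JPEG", quality=80)  # hypothetical file

    payload = {
        "images": [buffer.getvalue()],  # one entry per image
        "queries": [["Answer concisely: what is going on in this image?"]],
        # "answers" is only sent by the BERTScore variant:
        "answers": [["The image contains a corgi riding a bicycle"]],
    }
    response = requests.post(url, data=pickle.dumps(payload), timeout=120)
    response_data = pickle.loads(response.content)
    print(response_data["outputs"])  # per-image lists of LLaVA responses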
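The train.py change makes reward computation overlap with sampling: executor.submit returns a concurrent.futures.Future immediately, the Future is stashed in the sample dict where the reward tensor used to live, and .result() blocks only once all sampling batches have been generated. A toy sketch of that pattern in isolation (slow_reward is a stand-in for reward_fn):

    import time
    from concurrent import futures

    def slow_reward(batch):
        time.sleep(1)  # stands in for a remote LLaVA request
        return [len(x) for x in batch]

    executor = futures.ThreadPoolExecutor(max_workers=2)

    samples = []
    for batch in (["a", "bb"], ["ccc"]):
        future = executor.submit(slow_reward, batch)  # returns immediately
        samples.append({"batch": batch, "rewards": future})

    # ... sampling would continue here while rewards compute in the background ...

    for sample in samples:
        sample["rewards"] = sample["rewards"].result()  # blocks until ready

    print([s["rewards"] for s in samples])  # [[1, 2], [3]]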