Fix aesthetic score (again), add llava reward

Kevin Black 2023-07-04 00:23:33 -07:00
parent c0bc708549
commit ec499edf84
3 changed files with 164 additions and 14 deletions

View File

@@ -30,22 +30,22 @@ class MLP(nn.Module):
 class AestheticScorer(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, dtype):
         super().__init__()
         self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
         self.mlp = MLP()
         state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
         self.mlp.load_state_dict(state_dict)
+        self.dtype = dtype
         self.eval()

     @torch.no_grad()
     def __call__(self, images):
-        assert isinstance(images, list)
-        assert isinstance(images[0], Image.Image)
+        device = next(self.parameters()).device
         inputs = self.processor(images=images, return_tensors="pt")
-        inputs = {k: v.cuda() for k, v in inputs.items()}
+        inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()}
         embed = self.clip.get_image_features(**inputs)
         # normalize embedding
         embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
         return self.mlp(embed).squeeze(1)
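After this change the scorer takes an explicit dtype and returns a 1-D score tensor (the added .squeeze(1) drops the MLP's trailing singleton dimension). A minimal usage sketch, assuming a batch of uint8 RGB images in NCHW layout, which is how the reward function below now passes them; the batch size and image size are illustrative only:

import torch

from ddpo_pytorch.aesthetic_scorer import AestheticScorer

scorer = AestheticScorer(dtype=torch.float32).cuda()

# stand-in for decoded samples: 4 random RGB images, uint8, NCHW
images = torch.randint(0, 256, (4, 3, 512, 512), dtype=torch.uint8)

scores = scorer(images)  # CLIPProcessor handles resizing and normalization
print(scores.shape)      # torch.Size([4]) -- one aesthetic score per image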

View File

@@ -32,14 +32,145 @@ def jpeg_compressibility():
 def aesthetic_score():
     from ddpo_pytorch.aesthetic_scorer import AestheticScorer

-    scorer = AestheticScorer().cuda()
+    scorer = AestheticScorer(dtype=torch.float32).cuda()

     def _fn(images, prompts, metadata):
         if isinstance(images, torch.Tensor):
-            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
-            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
-        images = [Image.fromarray(image) for image in images]
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8)
         scores = scorer(images)
         return scores, {}

     return _fn
+
+
+def llava_strict_satisfaction():
+    """Submits images to LLaVA and computes a reward by matching the responses to ground truth answers directly without
+    using BERTScore. Prompt metadata must have "questions" and "answers" keys. See
+    https://github.com/kvablack/LLaVA-server for server-side code.
+    """
+    import requests
+    from requests.adapters import HTTPAdapter, Retry
+    from io import BytesIO
+    import pickle
+
+    batch_size = 4
+    url = "http://127.0.0.1:8085"
+    sess = requests.Session()
+    retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False)
+    sess.mount("http://", HTTPAdapter(max_retries=retries))
+
+    def _fn(images, prompts, metadata):
+        del prompts
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+
+        images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
+        metadata_batched = np.array_split(metadata, np.ceil(len(metadata) / batch_size))
+
+        all_scores = []
+        all_info = {
+            "answers": [],
+        }
+        for image_batch, metadata_batch in zip(images_batched, metadata_batched):
+            jpeg_images = []
+
+            # Compress the images using JPEG
+            for image in image_batch:
+                img = Image.fromarray(image)
+                buffer = BytesIO()
+                img.save(buffer, format="JPEG", quality=80)
+                jpeg_images.append(buffer.getvalue())
+
+            # format for LLaVA server
+            data = {
+                "images": jpeg_images,
+                "queries": [m["questions"] for m in metadata_batch],
+            }
+            data_bytes = pickle.dumps(data)
+
+            # send a request to the llava server
+            response = sess.post(url, data=data_bytes, timeout=120)
+            response_data = pickle.loads(response.content)
+
+            correct = np.array(
+                [
+                    [ans in resp for ans, resp in zip(m["answers"], responses)]
+                    for m, responses in zip(metadata_batch, response_data["outputs"])
+                ]
+            )
+            scores = correct.mean(axis=-1)
+
+            all_scores += scores.tolist()
+            all_info["answers"] += response_data["outputs"]
+
+        return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}
+
+    return _fn
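For the strict-satisfaction reward, each prompt's metadata supplies the VQA questions and ground-truth answers, and the reward is the fraction of answers that appear verbatim in LLaVA's responses. A small worked example with made-up questions, answers, and responses (none of these come from the repo's prompt files):

# hypothetical metadata entry for one prompt
metadata_entry = {
    "questions": ["What animal is in the image?", "What is the animal doing?"],
    "answers": ["dog", "washing dishes"],
}

# one free-form LLaVA response per question, as returned in response_data["outputs"]
responses = ["a dog standing at a sink", "it is washing dishes"]

# reward = mean of exact-substring matches between answers and responses
correct = [ans in resp for ans, resp in zip(metadata_entry["answers"], responses)]
score = sum(correct) / len(correct)  # 1.0 here, since both "dog" and "washing dishes" match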
+
+
+def llava_bertscore():
+    """Submits images to LLaVA and computes a reward by comparing the responses to the prompts using BERTScore. See
+    https://github.com/kvablack/LLaVA-server for server-side code.
+    """
+    import requests
+    from requests.adapters import HTTPAdapter, Retry
+    from io import BytesIO
+    import pickle
+
+    batch_size = 16
+    url = "http://127.0.0.1:8085"
+    sess = requests.Session()
+    retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False)
+    sess.mount("http://", HTTPAdapter(max_retries=retries))
+
+    def _fn(images, prompts, metadata):
+        del metadata
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+
+        images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
+        prompts_batched = np.array_split(prompts, np.ceil(len(prompts) / batch_size))
+
+        all_scores = []
+        all_info = {
+            "precision": [],
+            "f1": [],
+            "outputs": [],
+        }
+        for image_batch, prompt_batch in zip(images_batched, prompts_batched):
+            jpeg_images = []
+
+            # Compress the images using JPEG
+            for image in image_batch:
+                img = Image.fromarray(image)
+                buffer = BytesIO()
+                img.save(buffer, format="JPEG", quality=80)
+                jpeg_images.append(buffer.getvalue())
+
+            # format for LLaVA server
+            data = {
+                "images": jpeg_images,
+                "queries": [["Answer concisely: what is going on in this image?"]] * len(image_batch),
+                "answers": [[f"The image contains {prompt}"] for prompt in prompt_batch],
+            }
+            data_bytes = pickle.dumps(data)
+
+            # send a request to the llava server
+            response = sess.post(url, data=data_bytes, timeout=120)
+            response_data = pickle.loads(response.content)
+
+            # use the recall score as the reward
+            scores = np.array(response_data["recall"]).squeeze()
+            all_scores += scores.tolist()
+
+            # save the precision and f1 scores for analysis
+            all_info["precision"] += np.array(response_data["precision"]).squeeze().tolist()
+            all_info["f1"] += np.array(response_data["f1"]).squeeze().tolist()
+            all_info["outputs"] += np.array(response_data["outputs"]).squeeze().tolist()
+
+        return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}
+
+    return _fn
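Both LLaVA rewards talk to the server with the same convention: the request body is a pickled dict of JPEG bytes plus per-image query (and, for BERTScore, reference-answer) lists, and the pickled response carries "outputs" along with "recall"/"precision"/"f1" for the BERTScore variant. A self-contained sketch of that round trip, inferred from the client code above; it assumes the LLaVA server from the linked repo is running locally on port 8085, and the exact response shapes should be checked against that repo:

import pickle
from io import BytesIO

import numpy as np
import requests
from PIL import Image

# one dummy 64x64 JPEG, standing in for a decoded sample
buf = BytesIO()
Image.new("RGB", (64, 64)).save(buf, format="JPEG", quality=80)
jpeg_images = [buf.getvalue()]
prompts = ["a dog washing dishes"]

payload = {
    "images": jpeg_images,
    "queries": [["Answer concisely: what is going on in this image?"]] * len(jpeg_images),
    "answers": [[f"The image contains {p}"] for p in prompts],
}
response = requests.post("http://127.0.0.1:8085", data=pickle.dumps(payload), timeout=120)

out = pickle.loads(response.content)
llava_text = out["outputs"]                  # LLaVA's responses, used by both rewards
reward = np.array(out["recall"]).squeeze()   # BERTScore recall, used as the reward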

View File

@@ -2,6 +2,8 @@ from collections import defaultdict
 import contextlib
 import os
 import datetime
+from concurrent import futures
+import time
 from absl import app, flags
 from ml_collections import config_flags
 from accelerate import Accelerator
@@ -227,6 +229,10 @@ def main(_):
     # Prepare everything with our `accelerator`.
     trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer)

+    # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks, which make a request
+    # to a remote server running llava inference.
+    executor = futures.ThreadPoolExecutor(max_workers=2)
+
     # Train!
     samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch
     total_train_batch_size = (
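The executor is what lets reward computation overlap with sampling further down: executor.submit returns a concurrent.futures.Future immediately, the sampling loop keeps going, and the blocking .result() call only happens once every batch has been queued. A minimal, self-contained sketch of that pattern; slow_reward is a made-up stand-in for the LLaVA request:

import time
from concurrent import futures

def slow_reward(batch):
    # stands in for a reward function that blocks on a remote LLaVA server
    time.sleep(1.0)
    return [0.5] * len(batch)

executor = futures.ThreadPoolExecutor(max_workers=2)

pending = []
for batch in (["img0", "img1"], ["img2", "img3"]):
    pending.append(executor.submit(slow_reward, batch))  # returns a Future immediately
    time.sleep(0)  # yield the GIL so the worker thread can start the request
    # ... sampling of the next batch would happen here ...

# block only at the end, after the slow calls have been overlapped with sampling
rewards = [future.result() for future in pending]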
@@ -298,8 +304,10 @@ def main(_):
             log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
             timesteps = pipeline.scheduler.timesteps.repeat(config.sample.batch_size, 1)  # (batch_size, num_steps)

-            # compute rewards
-            rewards, reward_metadata = reward_fn(images, prompts, prompt_metadata)
+            # compute rewards asynchronously
+            rewards = executor.submit(reward_fn, images, prompts, prompt_metadata)
+            # yield to make sure reward computation starts
+            time.sleep(0)

             samples.append(
                 {
@@ -309,10 +317,21 @@ def main(_):
                     "latents": latents[:, :-1],  # each entry is the latent before timestep t
                     "next_latents": latents[:, 1:],  # each entry is the latent after timestep t
                     "log_probs": log_probs,
-                    "rewards": torch.as_tensor(rewards, device=accelerator.device),
+                    "rewards": rewards,
                 }
             )

+        # wait for all rewards to be computed
+        for sample in tqdm(
+            samples,
+            desc="Waiting for rewards",
+            disable=not accelerator.is_local_main_process,
+            position=0,
+        ):
+            rewards, reward_metadata = sample["rewards"].result()
+            # accelerator.print(reward_metadata)
+            sample["rewards"] = torch.as_tensor(rewards, device=accelerator.device)
+
         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
@@ -472,7 +491,7 @@ def main(_):
         # make sure we did an optimization step at the end of the inner epoch
         assert accelerator.sync_gradients

-        if epoch % config.save_freq == 0 and accelerator.is_main_process:
+        if epoch != 0 and epoch % config.save_freq == 0 and accelerator.is_main_process:
             accelerator.save_state()