Fix aesthetic score (again), add llava reward
parent c0bc708549
commit ec499edf84
ddpo_pytorch/aesthetic_scorer.py
@@ -30,22 +30,22 @@ class MLP(nn.Module):
 class AestheticScorer(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, dtype):
         super().__init__()
         self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
         self.mlp = MLP()
         state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
         self.mlp.load_state_dict(state_dict)
+        self.dtype = dtype
         self.eval()
 
     @torch.no_grad()
     def __call__(self, images):
-        assert isinstance(images, list)
-        assert isinstance(images[0], Image.Image)
+        device = next(self.parameters()).device
         inputs = self.processor(images=images, return_tensors="pt")
-        inputs = {k: v.cuda() for k, v in inputs.items()}
+        inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()}
         embed = self.clip.get_image_features(**inputs)
         # normalize embedding
         embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
-        return self.mlp(embed)
+        return self.mlp(embed).squeeze(1)
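Threading `dtype` through lets the CLIP inputs be cast to the scorer's weight dtype instead of hard-coding `.cuda()`, and the dropped asserts mean the scorer no longer requires PIL images. A minimal sketch of calling it directly, assuming a CUDA device; the batch shape is illustrative:

import torch
from ddpo_pytorch.aesthetic_scorer import AestheticScorer

scorer = AestheticScorer(dtype=torch.float32).cuda()
images = torch.randint(0, 256, (4, 3, 512, 512), dtype=torch.uint8)  # NCHW uint8 batch
scores = scorer(images)  # the CLIP processor resizes/normalizes; .squeeze(1) -> shape (4,)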
ddpo_pytorch/rewards.py
@@ -32,14 +32,145 @@ def jpeg_compressibility():
 def aesthetic_score():
     from ddpo_pytorch.aesthetic_scorer import AestheticScorer
 
-    scorer = AestheticScorer().cuda()
+    scorer = AestheticScorer(dtype=torch.float32).cuda()
 
     def _fn(images, prompts, metadata):
-        if isinstance(images, torch.Tensor):
-            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
-            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
-        images = [Image.fromarray(image) for image in images]
+        images = (images * 255).round().clamp(0, 255).to(torch.uint8)
         scores = scorer(images)
         return scores, {}
 
     return _fn
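With that change the aesthetic reward no longer round-trips through numpy and PIL. A hypothetical smoke test of the returned closure (prompts and metadata are accepted but unused by this reward):

import torch

reward_fn = aesthetic_score()
images = torch.rand(2, 3, 512, 512)  # sampler output: floats in [0, 1], NCHW
scores, info = reward_fn(images, None, None)  # scores: CUDA tensor of shape (2,); info is {}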
+
+
+def llava_strict_satisfaction():
+    """Submits images to LLaVA and computes a reward by matching the responses to ground truth answers directly without
+    using BERTScore. Prompt metadata must have "questions" and "answers" keys. See
+    https://github.com/kvablack/LLaVA-server for server-side code.
+    """
+    import requests
+    from requests.adapters import HTTPAdapter, Retry
+    from io import BytesIO
+    import pickle
+
+    batch_size = 4
+    url = "http://127.0.0.1:8085"
+    sess = requests.Session()
+    retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False)
+    sess.mount("http://", HTTPAdapter(max_retries=retries))
+
+    def _fn(images, prompts, metadata):
+        del prompts
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+
+        images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
+        metadata_batched = np.array_split(metadata, np.ceil(len(metadata) / batch_size))
+
+        all_scores = []
+        all_info = {
+            "answers": [],
+        }
+        for image_batch, metadata_batch in zip(images_batched, metadata_batched):
+            jpeg_images = []
+
+            # Compress the images using JPEG
+            for image in image_batch:
+                img = Image.fromarray(image)
+                buffer = BytesIO()
+                img.save(buffer, format="JPEG", quality=80)
+                jpeg_images.append(buffer.getvalue())
+
+            # format for LLaVA server
+            data = {
+                "images": jpeg_images,
+                "queries": [m["questions"] for m in metadata_batch],
+            }
+            data_bytes = pickle.dumps(data)
+
+            # send a request to the llava server
+            response = sess.post(url, data=data_bytes, timeout=120)
+
+            response_data = pickle.loads(response.content)
+
+            correct = np.array(
+                [
+                    [ans in resp for ans, resp in zip(m["answers"], responses)]
+                    for m, responses in zip(metadata_batch, response_data["outputs"])
+                ]
+            )
+            scores = correct.mean(axis=-1)
+
+            all_scores += scores.tolist()
+            all_info["answers"] += response_data["outputs"]
+
+        return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}
+
+    return _fn
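The strict reward is the fraction of ground-truth answers that appear verbatim as substrings of LLaVA's responses. A toy check of that matching rule, with made-up responses and no server involved:

import numpy as np

metadata_batch = [
    {"questions": ["What animal is shown?", "What color is it?"],
     "answers": ["dog", "brown"]},
]
outputs = [["a dog running on grass", "the dog is black"]]  # hypothetical LLaVA outputs
correct = np.array(
    [
        [ans in resp for ans, resp in zip(m["answers"], responses)]
        for m, responses in zip(metadata_batch, outputs)
    ]
)
print(correct.mean(axis=-1))  # [0.5] -- "dog" matched, "brown" did not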
+
+
+def llava_bertscore():
+    """Submits images to LLaVA and computes a reward by comparing the responses to the prompts using BERTScore. See
+    https://github.com/kvablack/LLaVA-server for server-side code.
+    """
+    import requests
+    from requests.adapters import HTTPAdapter, Retry
+    from io import BytesIO
+    import pickle
+
+    batch_size = 16
+    url = "http://127.0.0.1:8085"
+    sess = requests.Session()
+    retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False)
+    sess.mount("http://", HTTPAdapter(max_retries=retries))
+
+    def _fn(images, prompts, metadata):
+        del metadata
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+
+        images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
+        prompts_batched = np.array_split(prompts, np.ceil(len(prompts) / batch_size))
+
+        all_scores = []
+        all_info = {
+            "precision": [],
+            "f1": [],
+            "outputs": [],
+        }
+        for image_batch, prompt_batch in zip(images_batched, prompts_batched):
+            jpeg_images = []
+
+            # Compress the images using JPEG
+            for image in image_batch:
+                img = Image.fromarray(image)
+                buffer = BytesIO()
+                img.save(buffer, format="JPEG", quality=80)
+                jpeg_images.append(buffer.getvalue())
+
+            # format for LLaVA server
+            data = {
+                "images": jpeg_images,
+                "queries": [["Answer concisely: what is going on in this image?"]] * len(image_batch),
+                "answers": [[f"The image contains {prompt}"] for prompt in prompt_batch],
+            }
+            data_bytes = pickle.dumps(data)
+
+            # send a request to the llava server
+            response = sess.post(url, data=data_bytes, timeout=120)
+
+            response_data = pickle.loads(response.content)
+
+            # use the recall score as the reward
+            scores = np.array(response_data["recall"]).squeeze()
+            all_scores += scores.tolist()
+
+            # save the precision and f1 scores for analysis
+            all_info["precision"] += np.array(response_data["precision"]).squeeze().tolist()
+            all_info["f1"] += np.array(response_data["f1"]).squeeze().tolist()
+            all_info["outputs"] += np.array(response_data["outputs"]).squeeze().tolist()
+
+        return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}
+
+    return _fn
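A hypothetical end-to-end call, assuming a LLaVA-server instance is listening on 127.0.0.1:8085. Each image is captioned once, and the BERTScore recall of "The image contains {prompt}" against that caption becomes the reward; images and prompts below are placeholders:

import numpy as np

reward_fn = llava_bertscore()
images = np.zeros((2, 512, 512, 3), dtype=np.uint8)  # NHWC uint8; torch tensors also accepted
prompts = np.array(["a red bicycle", "two cats on a sofa"])
scores, info = reward_fn(images, prompts, None)  # info carries "precision", "f1", "outputs"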
scripts/train.py
@@ -2,6 +2,8 @@ from collections import defaultdict
 import contextlib
 import os
 import datetime
+from concurrent import futures
+import time
 from absl import app, flags
 from ml_collections import config_flags
 from accelerate import Accelerator
@@ -227,6 +229,10 @@ def main(_):
     # Prepare everything with our `accelerator`.
     trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer)
 
+    # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks, which make a request
+    # to a remote server running llava inference.
+    executor = futures.ThreadPoolExecutor(max_workers=2)
+
     # Train!
     samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch
     total_train_batch_size = (
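For context, this is the submit/result pattern the commit adopts, shown in isolation; slow_reward is a hypothetical stand-in for a reward function that blocks on a network request:

from concurrent import futures
import time

executor = futures.ThreadPoolExecutor(max_workers=2)

def slow_reward(x):
    time.sleep(1)  # stands in for the round-trip to the LLaVA server
    return x * 2, {"metadata": None}

future = executor.submit(slow_reward, 21)   # returns immediately; runs in a worker thread
# ... sampling of the next batch proceeds here ...
rewards, reward_metadata = future.result()  # blocks until ready -> (42, {...})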
@@ -298,8 +304,10 @@ def main(_):
             log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
             timesteps = pipeline.scheduler.timesteps.repeat(config.sample.batch_size, 1)  # (batch_size, num_steps)
 
-            # compute rewards
-            rewards, reward_metadata = reward_fn(images, prompts, prompt_metadata)
+            # compute rewards asynchronously
+            rewards = executor.submit(reward_fn, images, prompts, prompt_metadata)
+            # yield to make sure reward computation starts
+            time.sleep(0)
 
             samples.append(
                 {
@@ -309,10 +317,21 @@ def main(_):
                     "latents": latents[:, :-1],  # each entry is the latent before timestep t
                     "next_latents": latents[:, 1:],  # each entry is the latent after timestep t
                     "log_probs": log_probs,
-                    "rewards": torch.as_tensor(rewards, device=accelerator.device),
+                    "rewards": rewards,
                 }
             )
 
+        # wait for all rewards to be computed
+        for sample in tqdm(
+            samples,
+            desc="Waiting for rewards",
+            disable=not accelerator.is_local_main_process,
+            position=0,
+        ):
+            rewards, reward_metadata = sample["rewards"].result()
+            # accelerator.print(reward_metadata)
+            sample["rewards"] = torch.as_tensor(rewards, device=accelerator.device)
+
         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
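Since each "rewards" entry is a tensor again by the time the futures are resolved, the collate step concatenates along the batch dimension as before. A quick shape check of that pattern:

import torch

samples = [{"rewards": torch.ones(2)}, {"rewards": torch.zeros(2)}]
collated = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
print(collated["rewards"].shape)  # torch.Size([4])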
@@ -472,7 +491,7 @@ def main(_):
         # make sure we did an optimization step at the end of the inner epoch
         assert accelerator.sync_gradients
 
-        if epoch % config.save_freq == 0 and accelerator.is_main_process:
+        if epoch != 0 and epoch % config.save_freq == 0 and accelerator.is_main_process:
             accelerator.save_state()