| """ |
| Adapted from ImageReward (https://github.com/THUDM/ImageReward) |
| """ |
|
|
| import os |
| import torch |
| import torch.nn as nn |
| from PIL import Image |
|
|
| |
| from blip.blip_pretrain import blip_pretrain |
| from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize |
| from huggingface_hub import PyTorchModelHubMixin |
|
|
# Newer torchvision exposes resampling modes via InterpolationMode;
# fall back to the PIL constant on older versions.
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC
|
|
# Backbone configuration shared by all CycleReward variants:
#   blip_path:  public checkpoint URL for the BLIP ViT-L pretraining weights
#   vit:        visual encoder size ('large' -> 24 blocks, else 12; see CycleReward.__init__)
#   image_size: square input resolution expected by the visual encoder
#   mlp_dim:    feature width fed into the reward MLP head
cyclereward_args = {
    'blip_path': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth',
    'vit': 'large',
    'image_size': 224,
    'mlp_dim': 768
}
|
|
| def _convert_image_to_rgb(image): |
| return image.convert("RGB") |
|
|
def _transform(n_px):
    """Build the CLIP-style preprocessing pipeline for *n_px*-pixel crops."""
    # Normalization constants are the standard CLIP image mean/std.
    mean = (0.48145466, 0.4578275, 0.40821073)
    std = (0.26862954, 0.26130258, 0.27577711)
    steps = [
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(mean, std),
    ]
    return Compose(steps)
|
|
class CycleReward(nn.Module, PyTorchModelHubMixin):
    """BLIP-based preference reward model (adapted from ImageReward).

    Wraps a pretrained BLIP vision/text encoder with a small MLP head and
    scores image-text alignment.  Depending on ``model_type`` the forward
    pass returns rewards for text preference pairs ('I2T'), image preference
    pairs ('T2I'), or both ('Combo').
    """

    def __init__(self, device='cpu',
                 model_type='CycleReward-Combo',
                 max_length=128,
                 fix_rate=0.7,
                 med_config=None,
                 ):
        """
        Args:
            device: device onto which input batches are moved by the reward
                methods.  Note the module's parameters are NOT moved here;
                callers do ``model.to(device)`` separately.
            model_type: variant name; must contain 'Combo', 'I2T' or 'T2I'.
            max_length: tokenizer padding/truncation length used by `score`.
            fix_rate: fraction (0-1) of the lower encoder layers to freeze.
            med_config: optional BERT config forwarded to `blip_pretrain`.
        """
        super().__init__()
        self.device = device
        self.model_type = model_type
        self.max_length = max_length

        self.blip = blip_pretrain(
            pretrained=cyclereward_args['blip_path'],
            med_config=med_config,
            image_size=cyclereward_args['image_size'],
            vit=cyclereward_args['vit']
        )
        self.preprocess = _transform(cyclereward_args['image_size'])
        self.mlp = MLP(cyclereward_args['mlp_dim'])

        # The BLIP projection heads are never trained.
        for name, parms in self.blip.named_parameters():
            if '_proj' in name:
                parms.requires_grad_(False)

        # Freeze the embeddings and the lowest `fix_rate` fraction of
        # transformer layers in both encoders.  NOTE(review): this relies on
        # named_parameters() yielding parameters in registration order —
        # everything up to (and including the first parameter of) the
        # boundary layer is frozen, then the loop stops.
        self.image_layer_num = 24 if cyclereward_args['vit'] == 'large' else 12
        if fix_rate > 0:
            text_fix_num = "layer.{}".format(int(12 * fix_rate))
            image_fix_num = "blocks.{}".format(int(self.image_layer_num * fix_rate))
            for name, parms in self.blip.text_encoder.named_parameters():
                parms.requires_grad_(False)
                if text_fix_num in name:
                    break
            for name, parms in self.blip.visual_encoder.named_parameters():
                parms.requires_grad_(False)
                if image_fix_num in name:
                    break

    def forward(self, batch):
        """Compute preference rewards for a batch.

        Args:
            batch: dict; required keys depend on the variant (see
                `text_reward` / `image_reward`).

        Returns:
            (text_reward, image_reward) tuple; each element is a (B, 2)
            tensor of [preferred, rejected] rewards, or None when the
            variant does not produce that direction.

        Raises:
            ValueError: if ``model_type`` names no known variant.
        """
        if 'Combo' in self.model_type:
            text_reward = self.text_reward(batch)
            image_reward = self.image_reward(batch)
        elif 'I2T' in self.model_type:
            text_reward = self.text_reward(batch)
            image_reward = None
        elif 'T2I' in self.model_type:
            text_reward = None
            image_reward = self.image_reward(batch)
        else:
            # Previously an unrecognized model_type fell through and crashed
            # with an UnboundLocalError; fail with an explicit message.
            raise ValueError(f"unknown model_type: {self.model_type!r}")
        return text_reward, image_reward

    def _cross_embed(self, input_ids, attention_mask, image_embeds, image_atts):
        """Cross-attend text over image features and return the pooled
        [CLS] embedding as float32.  Shared by the reward/score paths."""
        output = self.blip.text_encoder(
            input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        ).last_hidden_state
        return output[:, 0, :].float()

    def text_reward(self, batch):
        """Reward a (preferred, rejected) caption pair for the same images.

        Expects batch keys: 'images', 'preferred_ids', 'preferred_mask',
        'rejected_ids', 'rejected_mask'.

        Returns:
            (B, 2) tensor of [preferred, rejected] rewards.
        """
        images = batch["images"].to(self.device)
        preferred_ids = batch["preferred_ids"].to(self.device)
        preferred_mask = batch["preferred_mask"].to(self.device)
        rejected_ids = batch["rejected_ids"].to(self.device)
        rejected_mask = batch["rejected_mask"].to(self.device)

        # Encode the shared images once; full attention over all patches.
        image_embeds = self.blip.visual_encoder(images)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)

        preferred_embeds = self._cross_embed(preferred_ids, preferred_mask, image_embeds, image_atts)
        rejected_embeds = self._cross_embed(rejected_ids, rejected_mask, image_embeds, image_atts)

        preferred_reward = self.mlp(preferred_embeds)
        rejected_reward = self.mlp(rejected_embeds)
        return torch.concat((preferred_reward, rejected_reward), dim=1)

    def image_reward(self, batch):
        """Reward a (preferred, rejected) image pair for the same prompt.

        Expects batch keys: 'prompt_ids', 'prompt_mask', 'image_preferred',
        'image_rejected'.

        Returns:
            (B, 2) tensor of [preferred, rejected] rewards.
        """
        image_preferred = batch["image_preferred"].to(self.device)
        image_rejected = batch["image_rejected"].to(self.device)
        # Flatten any extra tokenizer dims to (B, seq_len).
        prompt_ids = batch["prompt_ids"].view(batch["prompt_ids"].shape[0], -1).to(self.device)
        prompt_mask = batch["prompt_mask"].view(batch["prompt_mask"].shape[0], -1).to(self.device)

        image_embeds_preferred = self.blip.visual_encoder(image_preferred)
        image_atts_preferred = torch.ones(image_embeds_preferred.size()[:-1], dtype=torch.long).to(self.device)

        image_embeds_rejected = self.blip.visual_encoder(image_rejected)
        image_atts_rejected = torch.ones(image_embeds_rejected.size()[:-1], dtype=torch.long).to(self.device)

        preferred_embeds = self._cross_embed(prompt_ids, prompt_mask, image_embeds_preferred, image_atts_preferred)
        rejected_embeds = self._cross_embed(prompt_ids, prompt_mask, image_embeds_rejected, image_atts_rejected)

        preferred_reward = self.mlp(preferred_embeds)
        rejected_reward = self.mlp(rejected_embeds)
        return torch.concat((preferred_reward, rejected_reward), dim=1)

    @torch.no_grad()
    def score(self, image, prompt):
        """Score image-text alignment for an already-preprocessed image.

        Args:
            image: batched image tensor, forwarded to the visual encoder
                as-is — the caller is responsible for preprocessing and
                placing it on the model's device.
            prompt: text (string or list of strings) tokenized with padding
                to ``max_length``.

        Returns:
            (B, 1) reward tensor from the MLP head.
        """
        text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=self.max_length, return_tensors="pt").to(self.device)

        image_embeds = self.blip.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)

        text_embeds = self._cross_embed(text_input.input_ids, text_input.attention_mask, image_embeds, image_atts)
        return self.mlp(text_embeds)
| |
class MLP(nn.Module):
    """Funnel-shaped GELU MLP mapping a feature vector to a scalar reward.

    Layer widths: input_size -> 1024 -> 128 -> 64 -> 16 -> 1, with GELU
    between consecutive linear layers.
    """

    def __init__(self, input_size):
        super().__init__()
        # Build the stack with a loop; the resulting nn.Sequential has the
        # same child ordering (Linear/GELU alternation) as an explicit
        # listing, so state_dict keys are unaffected.
        widths = [input_size, 1024, 128, 64, 16]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.GELU())
        stack.append(nn.Linear(widths[-1], 1))
        self.layers = nn.Sequential(*stack)

        def _init(module):
            # Xavier-uniform weights and zero biases for every linear layer.
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        self.layers.apply(_init)

    def forward(self, input):
        """Return the scalar reward for each row of *input*."""
        return self.layers(input)
|
|
|
|
def main():
    """Demo: load pretrained CycleReward-Combo and score one image/caption pair."""
    # Fall back to CPU so the demo also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = CycleReward.from_pretrained("carolineec/CycleReward-Combo")
    model.to(device)
    # `.to(device)` moves the parameters but not `model.device`, which
    # `score()` uses to place the tokenized text — keep them in sync,
    # otherwise CUDA runs fail with a device mismatch.
    model.device = device
    model.eval()

    preprocess = model.preprocess
    image_path = "cat.jpg"
    caption = "a photo of a cat"
    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
    print('prepared data')

    score = model.score(image, caption)
    print('my score:', score.item())


# Guard the demo so importing this module does not download the model
# or require an image file on disk.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|