| from constants import * |
| from utils import image_to_tensor, tokenizer, tensor_to_image, vocab_size, tokenizer |
| import torch |
| import torch.nn.functional as F |
| from PIL import ImageDraw, Image |
| from dataset import create_test_dataloader |
| from vision_language_model import VisionLanguageModel |
|
|
|
|
# Build the vision-language model from the hyper-parameters exported by
# `constants` and place it on the configured device.
model = VisionLanguageModel(
    n_embd=HIDDEN_DIM,
    vocab_size=vocab_size,
    img_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    num_heads=NUM_HEADS,
    num_blks_vit=NUM_LAYERS,
    num_blks_dec=NUM_LAYERS,
    emb_dropout=DROPOUT,
    blk_dropout=DROPOUT,
    max_context=CONTEXT_LENGTH,
    shared_embed_dim=SHARED_EMBED_DIM,
    lambda_contrastive=LAMBDA_CONTRASTIVE,
    lambda_regression=LAMBDA_REGRESSION,
).to(DEVICE)

# Trained weights to restore.
MODEL_PATH = "model_regression_multi_first_100.pth"

# weights_only=True restricts unpickling to tensors and plain containers.
# When DEVICE is exactly "cuda" the default map_location (None) is used;
# otherwise every tensor is first mapped to CPU so a checkpoint saved on a GPU
# can still be loaded on a machine without one.
model.load_state_dict(
    torch.load(
        MODEL_PATH,
        weights_only=True,
        map_location=None if DEVICE == "cuda" else torch.device("cpu"),
    )
)
model.eval()
|
|
def generate_sample_from_image_text(
    model,
    image_path,
    prompt_label,
    tokenizer,
    device,
    max_new_tokens=70,
    temperature=0.8,
    top_k=10,
    output_path="generated_output.png",
    *,
    do_sample=False,
):
    """
    Generate a point prediction for an image file and a text prompt, then save
    a visualization of the parsed coordinates.

    The generation loop is implemented inline: the projected image
    embeddings, the prompt token embeddings and the ``<result_start>``
    embedding are concatenated and fed through the decoder blocks one new
    token at a time.

    Args:
        model: The trained VisionLanguageModel.
        image_path: Path to the input image.
        prompt_label: Text placed between ``<point_start>`` and ``<point_end>``.
        tokenizer: The tokenizer used for training.
        device: The computation device ('cuda' or 'cpu').
        max_new_tokens (int): Max tokens to generate after the prompt.
        temperature (float): Softmax temperature; only used when ``do_sample``
            is True. A value that is None or <= 0 falls back to greedy decoding.
        top_k (int): K for top-k filtering when sampling (0 or None to disable).
        output_path (str): Path where to save the output image.
        do_sample (bool): If False (default), decode greedily with argmax,
            which is what this function has always done. If True, sample from
            the temperature-scaled, top-k filtered distribution.

    Returns:
        None. Saves the image with the predicted points to ``output_path``.
        Errors are printed with a traceback rather than raised.
    """
    model.eval()

    try:
        with torch.no_grad():
            # The context manager closes the image file once it is converted.
            with Image.open(image_path) as image:
                image_tensor = image_to_tensor(image).unsqueeze(0).to(device)

            prompt_text = f"<point_start>{prompt_label}<point_end>"
            prompt_ids = tokenizer(
                prompt_text, return_tensors="pt", truncation=True, padding=False
            ).input_ids.to(device)

            print("--- Generating Sample (Manual Loop) ---")
            print(f"Original Label/Prompt Hint: {prompt_label}")
            decoded_prompt = tokenizer.decode(prompt_ids[0], skip_special_tokens=False)
            print(f"Input Prompt Tokens Decoded: {decoded_prompt}")

            # Only the first id is used if "<result_start>" is not a single token.
            result_start_token_id = tokenizer.encode("<result_start>", add_special_tokens=False)[0]
            result_start_ids = torch.tensor([[result_start_token_id]], device=device)

            # Decoder input: [image patches | prompt tokens | <result_start>].
            image_embeds = model.multimodal_projector(model.vision_encoder(image_tensor))
            current_embeds = torch.cat(
                [
                    image_embeds,
                    model.decoder.token_embedding_table(prompt_ids),
                    model.decoder.token_embedding_table(result_start_ids),
                ],
                dim=1,
            )

            max_context = model.decoder.max_context
            generated_ids = []

            for _ in range(max_new_tokens):
                # Keep only the most recent `max_context` positions so the
                # position embedding table is never indexed out of range.
                if current_embeds.shape[1] > max_context:
                    print(f"Warning: Truncating context from {current_embeds.shape[1]} to {max_context}")
                    current_embeds = current_embeds[:, -max_context:, :]
                seq_len = current_embeds.shape[1]

                positions = torch.arange(seq_len, dtype=torch.long, device=device)
                x = current_embeds + model.decoder.position_embedding_table(positions).unsqueeze(0)

                # A single unpadded sequence: every position is attended to.
                attention_mask = torch.ones(1, seq_len, device=device, dtype=torch.long)
                for block in model.decoder.blocks:
                    x = block(x, attention_mask=attention_mask)

                # Only the last position is needed to predict the next token.
                logits = model.decoder.lm_head(model.decoder.ln_f(x[:, -1:, :])).squeeze(1)

                if do_sample and temperature is not None and temperature > 0:
                    logits = logits / temperature
                    if top_k is not None and top_k > 0:
                        kth_largest = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
                        logits = logits.masked_fill(logits < kth_largest, float("-inf"))
                    probs = F.softmax(logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)
                else:
                    # Greedy decoding: positive temperature scaling and top-k
                    # filtering cannot change the argmax, so they are skipped.
                    idx_next = torch.argmax(logits, dim=-1, keepdim=True)

                generated_ids.append(idx_next)

                if idx_next.item() == tokenizer.eos_token_id:
                    print("EOS token generated.")
                    break

                current_embeds = torch.cat(
                    [current_embeds, model.decoder.token_embedding_table(idx_next)], dim=1
                )

            if generated_ids:
                full_generated_sequence_ids = torch.cat(
                    [prompt_ids, result_start_ids, *generated_ids], dim=1
                )
            else:
                full_generated_sequence_ids = prompt_ids

            full_decoded_text = tokenizer.decode(full_generated_sequence_ids[0], skip_special_tokens=False)
            print(f"\nFull Generated Sequence (Manual Loop):\n{full_decoded_text}")

            save_coords_visualization(
                image_tensor=image_tensor[0],
                full_decoded_text=full_decoded_text,
                tokenizer=tokenizer,
                image_size=IMAGE_SIZE,
                num_bins=NUM_BINS,
                output_path=output_path,
            )
            print(f"Visualization saved to: {output_path}")

    except Exception as e:
        print(f"An error occurred during sample generation: {e}")
        import traceback
        traceback.print_exc()
|
|
def generate_sample_from_test_loader(
    model,
    test_loader,
    tokenizer,
    device,
    max_new_tokens=70,
    temperature=0.8,
    top_k=10,
    output_path="generated_output.png",
    TEST_BATCH=8,
    TEST_IDX=1,
    *,
    do_sample=False,
):
    """
    Generate a point prediction for one test-set sample and save a
    visualization of the parsed coordinates.

    The sample is item ``TEST_IDX`` of the batch that follows the first
    ``TEST_BATCH`` batches of ``test_loader``. The decoder input is
    [image patches | prompt tokens | ``<result_start>``]; new tokens are then
    generated one at a time by an inline loop.

    Args:
        model: The trained VisionLanguageModel.
        test_loader: DataLoader for the test set. Batches are read as dicts with
            'image', 'prompt_ids', 'prompt_attention_mask' and 'label' entries.
        tokenizer: The tokenizer used for training.
        device: The computation device ('cuda' or 'cpu').
        max_new_tokens (int): Max tokens to generate after the prompt.
        temperature (float): Softmax temperature; only used when ``do_sample``
            is True. A value that is None or <= 0 falls back to greedy decoding.
        top_k (int): K for top-k filtering when sampling (0 or None to disable).
        output_path (str): Path where to save the output image.
        TEST_BATCH (int): Number of batches to skip before taking one.
        TEST_IDX (int): Index of the sample inside the selected batch.
        do_sample (bool): If False (default), decode greedily with argmax,
            which is what this function has always done. If True, sample from
            the temperature-scaled, top-k filtered distribution.

    Returns:
        None. Saves the image with the predicted points to ``output_path``.
        Problems are printed rather than raised.
    """
    import itertools

    if not test_loader or len(test_loader.dataset) == 0:
        print("Test loader is empty or not available.")
        return

    model.eval()

    try:
        with torch.no_grad():
            # Skip the first TEST_BATCH batches and take the next one; raises
            # StopIteration (handled below) if the loader runs out first.
            batch = next(itertools.islice(test_loader, max(TEST_BATCH, 0), None))

            if batch is None:
                print("Test loader yielded an empty batch.")
                return
            batch_size = batch['image'].shape[0]
            if batch_size == 0:
                print("Test loader yielded a batch with 0 items.")
                return
            if not 0 <= TEST_IDX < batch_size:
                print(f"TEST_IDX={TEST_IDX} is out of range for a batch of {batch_size} items.")
                return

            image_tensor = batch['image'][TEST_IDX:TEST_IDX + 1].to(device)
            # Keep only real prompt tokens (attention mask == 1). Batch padding
            # would otherwise be embedded and attended to, whereas
            # generate_sample_from_image_text tokenizes with padding=False.
            prompt_keep = batch['prompt_attention_mask'][TEST_IDX].bool()
            prompt_ids = batch['prompt_ids'][TEST_IDX][prompt_keep].unsqueeze(0).to(device)
            label = batch['label'][TEST_IDX]

            print("--- Generating Sample (Manual Loop) ---")
            print(f"Original Label/Prompt Hint: {label}")
            prompt_text = tokenizer.decode(prompt_ids[0], skip_special_tokens=False)
            print(f"Input Prompt Tokens Decoded: {prompt_text}")

            # Only the first id is used if "<result_start>" is not a single token.
            result_start_token_id = tokenizer.encode("<result_start>", add_special_tokens=False)[0]
            result_start_ids = torch.tensor([[result_start_token_id]], device=device)

            # Decoder input: [image patches | prompt tokens | <result_start>].
            image_embeds = model.multimodal_projector(model.vision_encoder(image_tensor))
            current_embeds = torch.cat(
                [
                    image_embeds,
                    model.decoder.token_embedding_table(prompt_ids),
                    model.decoder.token_embedding_table(result_start_ids),
                ],
                dim=1,
            )

            max_context = model.decoder.max_context
            generated_ids = []

            for _ in range(max_new_tokens):
                # Keep only the most recent `max_context` positions so the
                # position embedding table is never indexed out of range.
                if current_embeds.shape[1] > max_context:
                    print(f"Warning: Truncating context from {current_embeds.shape[1]} to {max_context}")
                    current_embeds = current_embeds[:, -max_context:, :]
                seq_len = current_embeds.shape[1]

                positions = torch.arange(seq_len, dtype=torch.long, device=device)
                x = current_embeds + model.decoder.position_embedding_table(positions).unsqueeze(0)

                # A single unpadded sequence: every position is attended to.
                attention_mask = torch.ones(1, seq_len, device=device, dtype=torch.long)
                for block in model.decoder.blocks:
                    x = block(x, attention_mask=attention_mask)

                # Only the last position is needed to predict the next token.
                logits = model.decoder.lm_head(model.decoder.ln_f(x[:, -1:, :])).squeeze(1)

                if do_sample and temperature is not None and temperature > 0:
                    logits = logits / temperature
                    if top_k is not None and top_k > 0:
                        kth_largest = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
                        logits = logits.masked_fill(logits < kth_largest, float("-inf"))
                    probs = F.softmax(logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)
                else:
                    # Greedy decoding: positive temperature scaling and top-k
                    # filtering cannot change the argmax, so they are skipped.
                    idx_next = torch.argmax(logits, dim=-1, keepdim=True)

                generated_ids.append(idx_next)

                if idx_next.item() == tokenizer.eos_token_id:
                    print("EOS token generated.")
                    break

                current_embeds = torch.cat(
                    [current_embeds, model.decoder.token_embedding_table(idx_next)], dim=1
                )

            if generated_ids:
                full_generated_sequence_ids = torch.cat(
                    [prompt_ids, result_start_ids, *generated_ids], dim=1
                )
            else:
                full_generated_sequence_ids = prompt_ids

            full_decoded_text = tokenizer.decode(full_generated_sequence_ids[0], skip_special_tokens=False)
            print(f"\nFull Generated Sequence (Manual Loop):\n{full_decoded_text}")

            save_coords_visualization(
                image_tensor=image_tensor[0],
                full_decoded_text=full_decoded_text,
                tokenizer=tokenizer,
                image_size=IMAGE_SIZE,
                num_bins=NUM_BINS,
                output_path=output_path,
            )
            print(f"Visualization saved to: {output_path}")

    except StopIteration:
        print("Test loader is exhausted.")
    except Exception as e:
        print(f"An error occurred during sample generation: {e}")
        import traceback
        traceback.print_exc()
|
|
def _find_delimited(text, open_token, close_token, pos, limit):
    """Return ``(stripped_content, index_after_close_token)`` for the first
    ``open_token ... close_token`` pair starting at or after ``pos`` whose
    delimiters both begin before ``limit``; return None if there is none."""
    open_idx = text.find(open_token, pos)
    if open_idx == -1 or open_idx >= limit:
        return None
    content_start = open_idx + len(open_token)
    close_idx = text.find(close_token, content_start)
    if close_idx == -1 or close_idx >= limit:
        return None
    return text[content_start:close_idx].strip(), close_idx + len(close_token)


def parse_coordinate_tokens(text, tokenizer, num_bins):
    """
    Parse generated text and extract the predicted coordinate bins.

    Only the span after the first ``<result_start>`` and before the following
    ``<result_end>`` is scanned (up to the end of the string if
    ``<result_end>`` is missing). Each point is read as
    ``<pointx_start> X <pointx_end> <pointy_start> Y <pointy_end>``. The bin
    index of X and Y is the last run of digits in each, so both a bare
    number such as ``12`` and a wrapped token such as ``<bin_12>`` parse as 12.

    Args:
        text (str): The decoded output text from the model.
        tokenizer: Unused; kept for interface compatibility.
        num_bins (int): Number of coordinate bins; indices must lie in
            ``[0, num_bins)`` or the point is skipped.

    Returns:
        list[tuple[int, int]] | None: The (x_bin, y_bin) pairs in order, or
        None if ``text`` is not a string, ``<result_start>`` is missing, or no
        valid pair was found.
    """
    import re

    x_start_token, x_end_token = "<pointx_start>", "<pointx_end>"
    y_start_token, y_end_token = "<pointy_start>", "<pointy_end>"
    result_start_token, result_end_token = "<result_start>", "<result_end>"
    # Last run of ASCII digits, optionally followed by non-digits (e.g. ">").
    bin_pattern = re.compile(r"([0-9]+)[^0-9]*$")

    if not isinstance(text, str):
        print(f"Error during coordinate parsing: expected str, got {type(text).__name__}")
        return None

    start_index = text.find(result_start_token)
    if start_index == -1:
        print("Warning: <result_start> not found in generated text.")
        return None
    start_index += len(result_start_token)

    end_index = text.find(result_end_token, start_index)
    if end_index == -1:
        end_index = len(text)
        print(f"Warning: {result_end_token} not found. Parsing until end of string.")

    coords = []
    current_pos = start_index
    while current_pos < end_index:
        x_field = _find_delimited(text, x_start_token, x_end_token, current_pos, end_index)
        if x_field is None:
            break
        x_token_str, after_x = x_field

        y_field = _find_delimited(text, y_start_token, y_end_token, after_x, end_index)
        if y_field is None:
            break
        y_token_str, current_pos = y_field

        x_match = bin_pattern.search(x_token_str)
        y_match = bin_pattern.search(y_token_str)
        if x_match is None or y_match is None:
            print(f"Warning: Could not parse bins from tokens '{x_token_str}', '{y_token_str}'. Skipping.")
            continue

        x_bin, y_bin = int(x_match.group(1)), int(y_match.group(1))
        if 0 <= x_bin < num_bins and 0 <= y_bin < num_bins:
            coords.append((x_bin, y_bin))
        else:
            print(f"Warning: Parsed bin indices out of range ({x_bin}, {y_bin}). Skipping.")

    return coords or None
|
|
|
|
def save_coords_visualization(image_tensor, full_decoded_text, tokenizer, image_size, num_bins, output_path):
    """
    Parse predicted coordinate bins from the decoded text, draw them on the
    image and save the result.

    Args:
        image_tensor: A single image tensor; it is moved to CPU and converted
            with ``tensor_to_image``.
        full_decoded_text (str): Decoded model output containing the result span.
        tokenizer: Passed through to ``parse_coordinate_tokens``.
        image_size (int): Side length of the white placeholder image written
            when tensor conversion fails.
        num_bins (int): Number of coordinate bins per axis.
        output_path (str): Where the annotated image is written.

    Returns:
        None. Always writes an image to ``output_path``: the annotated image,
        or a placeholder if conversion fails.
    """
    parsed_bins = parse_coordinate_tokens(full_decoded_text, tokenizer, num_bins)

    try:
        # Draw on an RGB copy so the red annotations stay red even if the
        # converted image is grayscale or palette-based.
        pil_image = tensor_to_image(image_tensor.cpu()).convert("RGB")
    except Exception as e:
        print(f"Error converting tensor to image: {e}")
        # Still produce an output file so callers always find something at output_path.
        pil_image = Image.new('RGB', (image_size, image_size), color='white')
        draw = ImageDraw.Draw(pil_image)
        draw.text((10, 10), "Image conversion failed", fill="black")
        pil_image.save(output_path)
        return

    draw = ImageDraw.Draw(pil_image)
    radius = 5

    if parsed_bins:
        print(f"\nParsed Coordinate Bins: {parsed_bins}")
        # Bins split each axis of the drawn image into `num_bins` equal cells;
        # the marker is placed at the cell centre.
        bin_width = pil_image.width / num_bins
        bin_height = pil_image.height / num_bins
        for x_bin, y_bin in parsed_bins:
            center_x = (x_bin + 0.5) * bin_width
            center_y = (y_bin + 0.5) * bin_height
            bbox = [center_x - radius, center_y - radius, center_x + radius, center_y + radius]
            draw.ellipse(bbox, outline="red", width=3)

        coord_text = f"Generated Point(s): {parsed_bins}"
        draw.text((10, 10), coord_text, fill="red")
    else:
        print("\nCould not parse valid coordinates from the generated text.")
        draw.text((10, 10), "No Coordinates Parsed", fill="red")

    pil_image.save(output_path)
|
|
|
|
import argparse


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run point-prediction inference with the VisionLanguageModel."
    )
    parser.add_argument('--image', type=str, help='Path to input image')
    parser.add_argument('--prompt', type=str, help='Prompt label for generation')
    parser.add_argument(
        '--output',
        type=str,
        default="model_prediction.png",
        help='Where to save the visualization (default: model_prediction.png)',
    )
    args = parser.parse_args()

    # Providing only one of the pair would otherwise silently fall back to the
    # test-loader path.
    if bool(args.image) != bool(args.prompt):
        parser.error("--image and --prompt must be given together")

    # `model` and `tokenizer` are module-level names (built / imported above),
    # so no existence check is needed before using them.
    if args.image and args.prompt:
        generate_sample_from_image_text(
            model=model,
            image_path=args.image,
            prompt_label=args.prompt,
            tokenizer=tokenizer,
            device=DEVICE,
            output_path=args.output,
        )
    else:
        # Build the loader here, before it is used; it is only needed on this path.
        test_loader = create_test_dataloader(batch_size=2, num_workers=0)
        generate_sample_from_test_loader(
            model=model,
            test_loader=test_loader,
            tokenizer=tokenizer,
            device=DEVICE,
            output_path=args.output,
        )