| import gradio as gr |
| from PIL import Image |
| import torch |
| from transformers import VisionEncoderDecoderModel, AutoTokenizer, ViTFeatureExtractor, AutoImageProcessor, AutoModelForImageClassification |
| from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction |
| import nltk |
| import warnings |
|
|
# Make sure the NLTK "punkt" tokenizer data is present; download it once
# on first run instead of failing later at scoring time.
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt")


# Keep the demo logs readable: transformers/torch emit noisy UserWarnings.
warnings.filterwarnings("ignore", category=UserWarning)
|
|
# Prefer GPU inference when CUDA is available.
device = "cuda" if torch.cuda.is_available() else "cpu"


# Captioning stack: a ViT encoder + BioMedBERT decoder checkpoint fine-tuned
# on the ROCO radiology dataset. Tokenizer and feature extractor come from
# the same repository so preprocessing matches training.
# NOTE(review): ViTFeatureExtractor is deprecated upstream in favor of
# ViTImageProcessor — confirm before bumping the transformers version.
caption_model = VisionEncoderDecoderModel.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO").to(device)
tokenizer = AutoTokenizer.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")
feature_extractor = ViTFeatureExtractor.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")


# Custom stylesheet for the Gradio UI; style.css is expected next to this script.
with open("style.css") as f:
    custom_css = f.read()
|
|
def load_classifier(model_id):
    """Load an image-classification checkpoint and its preprocessor.

    Returns a ``(processor, model)`` pair, with the model moved to the
    module-level ``device``.
    """
    model = AutoModelForImageClassification.from_pretrained(model_id)
    processor = AutoImageProcessor.from_pretrained(model_id)
    return processor, model.to(device)
|
|
# Hierarchical classifier cascade. Insertion order matters: classify_image
# iterates this dict in order and only keeps the "tumor_type" prediction
# when the earlier "abnormality" stage produced "tumor".
classifiers = {
    "plane": load_classifier("bombshelll/swin-brain-plane-classification"),
    "modality": load_classifier("bombshelll/swin-brain-modality-classification"),
    "abnormality": load_classifier("bombshelll/swin-brain-abnormalities-classification"),
    "tumor_type": load_classifier("bombshelll/swin-brain-tumor-type-classification")
}
|
|
def classify_image(image):
    """Run the hierarchical classifier cascade on a brain image.

    Iterates the module-level ``classifiers`` dict in insertion order
    (plane, modality, abnormality, tumor_type) and collects one label per
    stage. The tumor-type stage is conditional: it is only run — and only
    reported — when the abnormality stage predicted "tumor".

    Returns a dict mapping stage name to its predicted label; the
    "tumor_type" key is absent for non-tumor images.
    """
    results = {}
    for name, (processor, model) in classifiers.items():
        # Skip the tumor-type forward pass entirely when no tumor was found,
        # instead of computing a prediction and throwing it away.
        if name == "tumor_type" and results.get("abnormality") != "tumor":
            continue
        inputs = processor(image, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
        results[name] = model.config.id2label[logits.argmax(-1).item()]
    return results
|
|
def preprocess_caption(text):
    """Normalize a caption for BLEU scoring and split it into tokens.

    Lower-cases the text, collapses modality/sequence synonyms (MRI, CT,
    T1, T2, FLAIR spellings), normalizes British spellings and plurals,
    and turns hyphens into spaces. Returns the whitespace-split token list.
    """
    normalized = str(text).lower()
    # Ordered substitution table — order matters (e.g. "t1-weighted" must
    # collapse before the generic "-" -> " " pass).
    substitutions = (
        ("magnetic resonance imaging", "mri"),
        ("magnetic resonance image", "mri"),
        ("computed tomography", "ct"),
        ("t1-weighted", "t1"),
        ("t1w1", "t1"),
        ("t1ce", "t1"),
        ("t2-weighted", "t2"),
        ("t2w", "t2"),
        ("t2/flair", "flair"),
        ("tumour", "tumor"),
        ("lesions", "lesion"),
        ("-", " "),
    )
    for old, new in substitutions:
        normalized = normalized.replace(old, new)
    return normalized.split()
|
|
def generate_captions(image, keywords):
    """Generate two captions for *image* with the module-level caption model.

    The first caption is produced by plain generation; the second primes the
    decoder with the classifier *keywords* joined into a prompt. Returns the
    pair ``(caption_without_keywords, caption_with_keywords)``.
    """
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)

    caption_model.eval()

    # Caption 1: unconditioned generation.
    with torch.no_grad():
        plain_ids = caption_model.generate(pixel_values, max_length=80)
    plain_caption = tokenizer.decode(plain_ids[0], skip_special_tokens=True)

    # Caption 2: seed the decoder with the keyword prompt, dropping the
    # final special token of the tokenized prompt, and use beam search.
    keyword_prompt = " ".join(keywords)
    prompt_ids = tokenizer(keyword_prompt, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        guided_ids = caption_model.generate(
            pixel_values,
            decoder_input_ids=prompt_ids[:, :-1],
            max_length=80,
            num_beams=4,
            no_repeat_ngram_size=3,
            length_penalty=2.0
        )
    guided_caption = tokenizer.decode(guided_ids[0], skip_special_tokens=True)

    return plain_caption, guided_caption
|
|
def run_pipeline(image, actual_caption):
    """End-to-end handler for the Gradio UI.

    Classifies the image, generates captions with and without keyword
    conditioning, and — when a ground-truth caption is provided — scores
    each caption with smoothed sentence-level BLEU.

    Returns ``(classification_text, caption1, caption2, bleu1, bleu2)``;
    the BLEU fields are "-" when no reference caption was given.
    """
    classification = classify_image(image)
    keywords = list(classification.values())
    caption1, caption2 = generate_captions(image, keywords)

    classification_text = (
        f"Plane: {classification.get('plane')}\n"
        f"Modality: {classification.get('modality')}\n"
        f"Abnormality: {classification.get('abnormality')}\n"
        + (f"Tumor Type: {classification.get('tumor_type')}" if "tumor_type" in classification else "")
    )

    # Guard against None: Gradio may pass None for a cleared textbox, and
    # None.strip() would raise AttributeError. Treat it as an empty reference.
    if actual_caption and actual_caption.strip():
        ref = [preprocess_caption(actual_caption)]
        hyp1 = preprocess_caption(caption1)
        hyp2 = preprocess_caption(caption2)
        smooth = SmoothingFunction().method1
        bleu1 = f"{sentence_bleu(ref, hyp1, smoothing_function=smooth):.2f}"
        bleu2 = f"{sentence_bleu(ref, hyp2, smoothing_function=smooth):.2f}"
    else:
        bleu1 = "-"
        bleu2 = "-"

    return classification_text, caption1, caption2, bleu1, bleu2
|
|
# --- Gradio UI ------------------------------------------------------------
# Two-column layout: inputs (image + optional reference caption) on the
# left, classification result, both captions, and BLEU scores on the right.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="pink"), css=custom_css) as demo:
    gr.Markdown(
        """
        <link href="https://fonts.googleapis.com/css2?family=Poppins&display=swap" rel="stylesheet">
        <h1 style='text-align: center;'>🧠 Brain Hierarchical Classification + Captioning</h1>
        <p style='text-align: center;'>Upload an MRI/CT brain image. The system will classify the image and generate captions. Optionally, provide ground truth to see BLEU scores.</p>
        """,
        elem_id="title"
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="🖼️ Upload Brain MRI/CT")
            actual_caption = gr.Textbox(label="💬 Ground Truth Caption (optional)")
            btn = gr.Button("🚀 Submit")
        with gr.Column():
            cls_box = gr.Textbox(label="📋 Classification Result", lines=4)
            cap1_box = gr.Textbox(label="📝 Caption without Keyword Integration", lines=4)
            cap2_box = gr.Textbox(label="🧠 Caption with Keyword Integration", lines=4)
            bleu1_box = gr.Textbox(label="📏 BLEU (No Keyword)", lines=1)
            bleu2_box = gr.Textbox(label="📏 BLEU (With Keyword)", lines=1)


    # Wire the submit button to the full pipeline; output order must match
    # run_pipeline's return tuple.
    btn.click(
        fn=run_pipeline,
        inputs=[image_input, actual_caption],
        outputs=[cls_box, cap1_box, cap2_box, bleu1_box, bleu2_box]
    )


demo.launch()
|
|