"""Gradio demo for TRIQA image quality assessment.

Combines content-aware (ImageNet-22k pretrained ConvNeXt-T) and quality-aware
(TRIQA contrastively pretrained ConvNeXt) features with a KonIQ-trained
regressor to predict a quality score on a 1-5 scale.
"""
import gradio as gr
import torch
import numpy as np
import pickle
from PIL import Image
import os
from convnext_original import ConvNeXt as ConvNeXtOriginal
from convnext_finetune import ConvNeXt
|
|
| |
# Globals populated once by load_models() and shared across requests
content_model = None
| quality_model = None |
| scaler = None |
| regression_model = None |
| device = None |
|
|
def get_activation(name, activations):
    """Return a forward hook that stores a module's detached output under `name`."""
    def hook(model, input, output):
        activations[name] = output.detach()
    return hook
|
|
def register_hooks(model):
    """Hook every module in the model; downstream code reads only the 'norm' entry."""
    activations = {}
    for name, module in model.named_modules():
        module.register_forward_hook(get_activation(name, activations))
    return activations
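# Usage sketch (illustrative): after `activations = register_hooks(model)`,
# any forward pass populates the dict, and the final pooled representation
# can be read as activations['norm'] (shape (N, 768) for ConvNeXt-T).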
|
|
def preprocess_image(image):
    """Convert a PIL image to a normalized (1, 3, H, W) float tensor."""
    # ImageNet mean/std, matching the ConvNeXt pretraining pipeline
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    image = image.convert('RGB')  # drop alpha channels / promote grayscale
    img_array = np.array(image, dtype=np.float32) / 255.0
    img_array = (img_array - mean) / std
    return torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).float()
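# Example (illustrative): preprocess_image(Image.open('photo.jpg')) yields a
# (1, 3, H, W) tensor; arbitrary H and W work because both backbones globally
# pool their features before the final LayerNorm.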
|
|
| def load_models(): |
| """Load all required models.""" |
| global content_model, quality_model, scaler, regression_model, device |
| |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| |
| |
    # Backbone weights and the KonIQ scaler/regressor are downloaded separately (Box link)
    required_files = [
| 'feature_models/convnext_tiny_22k_224.pth', |
| 'feature_models/triqa_quality_aware.pth', |
| 'Regression_Models/KonIQ_scaler.save', |
| 'Regression_Models/KonIQ_TRIQA.save' |
| ] |
| |
| missing_files = [f for f in required_files if not os.path.exists(f)] |
| if missing_files: |
| print(f"Missing model files: {missing_files}") |
| print("Please download model files from the Box link and place them in the correct directories.") |
| return None, None |
| |
    try:
        # Content-aware backbone: ConvNeXt-T pretrained on ImageNet-22k
        content_model = ConvNeXtOriginal(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
        content_state_dict = torch.load('feature_models/convnext_tiny_22k_224.pth', map_location=device)['model']
        # Drop the 22k classifier head; only backbone features are used
        content_state_dict = {k: v for k, v in content_state_dict.items() if not k.startswith('head.')}
        content_model.load_state_dict(content_state_dict, strict=False)
        content_model.to(device).eval()

        # Quality-aware backbone: ConvNeXt trained with TRIQA's contrastive objective
        quality_model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
        quality_state_dict = torch.load('feature_models/triqa_quality_aware.pth', map_location=device)
        quality_model.load_state_dict(quality_state_dict, strict=True)
        quality_model.to(device).eval()

        # Capture intermediate activations via forward hooks
        content_activations = register_hooks(content_model)
        quality_activations = register_hooks(quality_model)

        # KonIQ-trained feature scaler and regression head
        with open('Regression_Models/KonIQ_scaler.save', 'rb') as f:
            scaler = pickle.load(f)
        with open('Regression_Models/KonIQ_TRIQA.save', 'rb') as f:
            regression_model = pickle.load(f)

        return content_activations, quality_activations
| except Exception as e: |
| print(f"Error loading models: {e}") |
| return None, None |
|
|
def predict_quality(image):
    """Predict an image quality score on a 1-5 scale."""
    global content_model, quality_model, scaler, regression_model, device

    if any(m is None for m in (content_model, quality_model, scaler, regression_model)):
        return "Models not loaded. Please wait..."

    # TRIQA extracts features at full and half resolution and concatenates them
    image_half = image.resize((max(1, image.size[0] // 2), max(1, image.size[1] // 2)), Image.LANCZOS)

    img_full = preprocess_image(image).to(device)
    img_half = preprocess_image(image_half).to(device)
| |
    with torch.no_grad():
        # Content-aware features ('norm' = pooled 768-dim ConvNeXt representation)
        _ = content_model(img_full)
        content_full = content_model.activations['norm'].cpu().numpy().flatten()

        _ = content_model(img_half)
        content_half = content_model.activations['norm'].cpu().numpy().flatten()

        content_features = np.concatenate([content_full, content_half])

        # Quality-aware features from the TRIQA backbone
        _ = quality_model(img_full)
        quality_full = quality_model.activations['norm'].cpu().numpy().flatten()

        _ = quality_model(img_half)
        quality_half = quality_model.activations['norm'].cpu().numpy().flatten()

        quality_features = np.concatenate([quality_full, quality_half])

    # Scale the concatenated features and regress to a quality score
    combined_features = np.concatenate([content_features, quality_features])
    normalized_features = scaler.transform(combined_features.reshape(1, -1))
    quality_score = regression_model.predict(normalized_features)[0]

    return f"Quality Score: {quality_score:.2f}/5.0"
|
|
| def create_demo(): |
| """Create the Gradio demo interface.""" |
| |
| |
    # Load models and attach the hook-captured activation dicts
    try:
        content_activations, quality_activations = load_models()
        if content_activations is None or quality_activations is None:
            return None
        content_model.activations = content_activations
        quality_model.activations = quality_activations
        print("Models loaded successfully!")
    except Exception as e:
        print(f"Error loading models: {e}")
        return None
| |
| |
| with gr.Blocks(title="TRIQA: Image Quality Assessment", theme=gr.themes.Soft()) as demo: |
| gr.Markdown(""" |
| # TRIQA: Image Quality Assessment |
| |
| **TRIQA** combines content-aware and quality-aware features from ConvNeXt models to predict image quality scores on a 1-5 scale. |
| |
| ### How to use: |
| 1. Upload an image using the file uploader below |
| 2. Click "Assess Quality" to get the quality score |
    3. The score ranges from 1 to 5, where 5 represents the highest quality
| |
| ### Paper Links: |
| - **arXiv**: [https://arxiv.org/pdf/2507.12687](https://arxiv.org/pdf/2507.12687) |
| - **IEEE Xplore**: [https://ieeexplore.ieee.org/abstract/document/11084443](https://ieeexplore.ieee.org/abstract/document/11084443) |
| """) |
| |
| with gr.Row(): |
| with gr.Column(): |
| input_image = gr.Image( |
| label="Upload Image", |
| type="pil", |
| height=400 |
| ) |
| submit_btn = gr.Button("Assess Quality", variant="primary") |
| |
| with gr.Column(): |
| output_text = gr.Textbox( |
| label="Quality Assessment Result", |
| value="Upload an image and click 'Assess Quality' to get the quality score.", |
| interactive=False |
| ) |
| |
| gr.Examples( |
| examples=[ |
| ["sample_image/233045618.jpg"], |
| ["sample_image/25239707.jpg"], |
| ["sample_image/44009500.jpg"], |
| ["sample_image/5129172.jpg"], |
| ["sample_image/85119046.jpg"] |
| ], |
| inputs=input_image, |
| label="Sample Images" |
| ) |
| |
| submit_btn.click( |
| fn=predict_quality, |
| inputs=input_image, |
| outputs=output_text |
| ) |
| |
| gr.Markdown(""" |
| ### Citation: |
| If you use this code in your research, please cite our paper: |
| |
| ```bibtex |
| @INPROCEEDINGS{11084443, |
| author={Sureddi, Rajesh and Zadtootaghaj, Saman and Barman, Nabajeet and Bovik, Alan C.}, |
| booktitle={2025 IEEE International Conference on Image Processing (ICIP)}, |
| title={Triqa: Image Quality Assessment by Contrastive Pretraining on Ordered Distortion Triplets}, |
| year={2025}, |
| volume={}, |
| number={}, |
| pages={1744-1749}, |
| keywords={Image quality;Training;Deep learning;Contrastive learning;Predictive models;Feature extraction;Distortion;Data models;Synthetic data;Image Quality Assessment;Contrastive Learning}, |
| doi={10.1109/ICIP55913.2025.11084443}} |
| ``` |
| """) |
| |
| return demo |
|
|
| if __name__ == "__main__": |
| demo = create_demo() |
| if demo: |
| demo.launch(server_name="0.0.0.0", server_port=7860, share=True) |
| else: |
| print("Failed to create demo. Please check model files.") |
|
|