"""Generator factory for the deblurring pipeline.

Builds the FPNInception generator when it can be imported, and degrades
gracefully to a small stand-in network otherwise so the surrounding
pipeline can still be smoke-tested.
"""
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
import numpy as np
import logging

from logging_utils import setup_logger

# Configure logging
logger = setup_logger(__name__)

# Try to import the necessary modules, use fallback if not available
try:
    from models.fpn_inception import FPNInception
    INCEPTION_AVAILABLE = True
    logger.info("Successfully imported FPNInception model")
except ImportError as e:
    logger.error(f"Error importing FPNInception: {str(e)}")
    INCEPTION_AVAILABLE = False


# Simple fallback model for testing purposes
class FallbackDeblurModel(nn.Module):
    """Minimal stand-in deblurring network used when FPNInception is unavailable.

    A small conv encoder/decoder whose output is added residually to the
    input and clamped to [-1, 1], keeping the 3-channel in / 3-channel out
    contract so it can slot into the pipeline for testing.
    """

    def __init__(self):
        super().__init__()
        logger.info("Initializing fallback model for testing")
        # Simple autoencoder-like structure: two 3x3 convs then a 2x
        # downsample; the decoder mirrors it with a stride-2 transpose conv.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Residual pass: clamp(decode(encode(x)) + x, -1, 1).

        NOTE(review): the residual add assumes H and W are even (2x pool
        followed by a stride-2 upsample) — confirm input sizes upstream.
        """
        # Simple pass-through for testing
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return torch.clamp(decoded + x, min=-1, max=1)


def get_norm_layer(norm_type='instance'):
    """Return a constructor for the requested 2d normalization layer.

    Args:
        norm_type: 'batch' or 'instance'.

    Returns:
        A ``functools.partial`` over ``nn.BatchNorm2d`` (affine) or
        ``nn.InstanceNorm2d`` (non-affine, with running stats tracked).

    Raises:
        NotImplementedError: if ``norm_type`` is not recognized.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer


def get_generator(model_config):
    """Build the generator network named by ``model_config``.

    Args:
        model_config: either the generator name as a plain string, or a
            config dict carrying the name under the ``'g_name'`` key.

    Returns:
        ``nn.DataParallel(FPNInception(...))`` when the import succeeded and
        construction works; otherwise a ``FallbackDeblurModel`` instance.

    Raises:
        ValueError: if the generator name is not recognized.
    """
    if isinstance(model_config, str):
        generator_name = model_config
    else:
        generator_name = model_config['g_name']

    # Try to use FPNInception if available
    if generator_name == 'fpn_inception':
        if INCEPTION_AVAILABLE:
            try:
                logger.info("Creating FPNInception model")
                model_g = FPNInception(norm_layer=get_norm_layer(norm_type='instance'))
                return nn.DataParallel(model_g)
            except Exception as e:
                # Deliberately broad: any construction failure (weights,
                # CUDA, etc.) degrades to the fallback so the pipeline
                # remains testable rather than crashing here.
                logger.error(f"Error creating FPNInception model: {str(e)}")
                logger.warning("Falling back to simple model for testing")
                return FallbackDeblurModel()
        else:
            logger.warning("FPNInception not available, using fallback model")
            return FallbackDeblurModel()
    else:
        raise ValueError("Generator Network [%s] not recognized." % generator_name)