| | """ |
| | PyTorch Autoencoder model for Hugging Face Transformers. |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from typing import Optional, Tuple, Union, Dict, Any, List |
| | from dataclasses import dataclass |
| | import random |
| | import re |
| |
|
| | |
| | try: |
| | from transformers.modeling_utils import PreTrainedModel |
| | except Exception: |
| | |
| | from transformers import PreTrainedModel |
| |
|
| | from transformers.modeling_outputs import BaseModelOutput |
| | from transformers.utils import ModelOutput |
| |
|
| | try: |
| | from .configuration_autoencoder import AutoencoderConfig |
| | except Exception: |
| | from configuration_autoencoder import AutoencoderConfig |
| |
|
| | |
| | try: |
| | from .blocks import ( |
| | BlockFactory, |
| | BlockSequence, |
| | LinearBlockConfig, |
| | AttentionBlockConfig, |
| | RecurrentBlockConfig, |
| | ConvolutionalBlockConfig, |
| | VariationalBlockConfig, |
| | VariationalBlock, |
| | ) |
| | except Exception: |
| | from blocks import ( |
| | BlockFactory, |
| | BlockSequence, |
| | LinearBlockConfig, |
| | AttentionBlockConfig, |
| | RecurrentBlockConfig, |
| | ConvolutionalBlockConfig, |
| | VariationalBlockConfig, |
| | VariationalBlock, |
| | ) |
| |
|
| | |
| | try: |
| | from .utils import _get_activation |
| | except Exception: |
| | from utils import _get_activation |
| |
|
| | |
| | try: |
| | from .preprocessing import PreprocessingBlock |
| | except Exception: |
| | from preprocessing import PreprocessingBlock |
| |
|
| |
|
@dataclass
class AutoencoderOutput(ModelOutput):
    """
    Output type of AutoencoderModel.

    Args:
        last_hidden_state (torch.FloatTensor): The latent representation of the input.
        reconstructed (torch.FloatTensor, optional): The reconstructed input.
        hidden_states (tuple(torch.FloatTensor), optional): Hidden states of the encoder layers.
        attentions (tuple(torch.FloatTensor), optional): Not used in basic autoencoder.
        preprocessing_loss (torch.FloatTensor, optional): Loss from learnable preprocessing.
    """

    # Latent code produced by the encoder (always populated by forward()).
    last_hidden_state: torch.FloatTensor = None
    # Decoder output; same feature width as the model input.
    reconstructed: Optional[torch.FloatTensor] = None
    # Optional intermediate states; for VAEs this is (latent, mu, logvar).
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Kept for HF output-API compatibility; never filled by this model.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Auxiliary loss contributed by the preprocessing block, if any.
    preprocessing_loss: Optional[torch.FloatTensor] = None
|
| |
|
@dataclass
class AutoencoderForReconstructionOutput(ModelOutput):
    """
    Output type of AutoencoderForReconstruction.

    Args:
        loss (torch.FloatTensor, optional): The reconstruction loss.
        reconstructed (torch.FloatTensor): The reconstructed input.
        last_hidden_state (torch.FloatTensor): The latent representation.
        hidden_states (tuple(torch.FloatTensor), optional): Hidden states of the encoder layers.
        preprocessing_loss (torch.FloatTensor, optional): Loss from learnable preprocessing.
    """

    # Total training loss (reconstruction + any auxiliary regularizers).
    loss: Optional[torch.FloatTensor] = None
    # Decoder output; same feature width as the model input.
    reconstructed: torch.FloatTensor = None
    # Latent code produced by the wrapped base autoencoder.
    last_hidden_state: torch.FloatTensor = None
    # Optional intermediate states forwarded from the base model.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Auxiliary loss contributed by the preprocessing block, if any.
    preprocessing_loss: Optional[torch.FloatTensor] = None
|
| |
|
class AutoencoderEncoder(nn.Module):
    """Encoder part of the autoencoder.

    Builds a Linear -> [BatchNorm] -> activation -> [Dropout] stack from
    ``config.hidden_dims``, then projects to ``config.latent_dim`` — via
    separate ``mu``/``logvar`` heads when ``config.is_variational``.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config

        layers = []
        input_dim = config.input_dim

        for hidden_dim in config.hidden_dims:
            layers.append(nn.Linear(input_dim, hidden_dim))

            if config.use_batch_norm:
                layers.append(nn.BatchNorm1d(hidden_dim))

            # FIX: was ``self._get_activation(...)`` but this class defines no
            # such method (AttributeError at construction). Use the module-level
            # helper, consistent with AutoencoderDecoder.
            layers.append(_get_activation(config.activation))

            if config.dropout_rate > 0:
                layers.append(nn.Dropout(config.dropout_rate))

            input_dim = hidden_dim

        self.encoder = nn.Sequential(*layers)

        if config.is_variational:
            # Separate heads for the posterior mean and log-variance.
            self.fc_mu = nn.Linear(input_dim, config.latent_dim)
            self.fc_logvar = nn.Linear(input_dim, config.latent_dim)
        else:
            # Single deterministic projection to the latent space.
            self.fc_out = nn.Linear(input_dim, config.latent_dim)

    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Forward pass through encoder.

        Args:
            x: Input tensor of shape (batch_size, input_dim).

        Returns:
            The latent tensor, or ``(z, mu, logvar)`` when variational.
        """
        # Denoising AE: corrupt the input with Gaussian noise at train time only.
        if self.config.is_denoising and self.training:
            noise = torch.randn_like(x) * self.config.noise_factor
            x = x + noise

        encoded = self.encoder(x)

        if self.config.is_variational:
            mu = self.fc_mu(encoded)
            logvar = self.fc_logvar(encoded)

            # Reparameterization trick during training; deterministic mean at eval.
            if self.training:
                std = torch.exp(0.5 * logvar)
                eps = torch.randn_like(std)
                z = mu + eps * std
            else:
                z = mu

            return z, mu, logvar
        else:
            latent = self.fc_out(encoded)

            # Sparse AE: keep activations non-negative so an L1 penalty (added
            # by the training head) can drive them to zero.
            # NOTE(review): applied only in training mode, so eval-time latents
            # may be negative — confirm this asymmetry is intended.
            if self.config.is_sparse and self.training:
                latent = F.relu(latent)

            return latent
|
| |
|
class AutoencoderDecoder(nn.Module):
    """Decoder part of the autoencoder.

    Mirrors the encoder: a stack of linear blocks over ``config.decoder_dims``
    ending in a projection back to ``config.input_dim``. The output layer gets
    a Sigmoid only when training against BCE; otherwise it stays linear.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config

        modules: List[nn.Module] = []
        prev_dim = config.latent_dim
        all_dims = config.decoder_dims + [config.input_dim]
        last_index = len(all_dims) - 1

        for idx, out_dim in enumerate(all_dims):
            modules.append(nn.Linear(prev_dim, out_dim))

            if idx == last_index:
                # Output layer: squash into [0, 1] only for BCE targets.
                if config.reconstruction_loss == "bce":
                    modules.append(nn.Sigmoid())
            else:
                # Hidden layer: optional norm, activation, optional dropout.
                if config.use_batch_norm:
                    modules.append(nn.BatchNorm1d(out_dim))
                modules.append(_get_activation(config.activation))
                if config.dropout_rate > 0:
                    modules.append(nn.Dropout(config.dropout_rate))

            prev_dim = out_dim

        self.decoder = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a latent vector back to input space."""
        return self.decoder(x)
|
| |
|
class RecurrentEncoder(nn.Module):
    """Recurrent encoder for sequence data.

    Runs an RNN/GRU/LSTM over the input sequence and uses the final hidden
    state (optionally projected and normalized) as the latent representation.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config

        # Select the recurrent cell class from the config.
        if config.rnn_type == "lstm":
            rnn_class = nn.LSTM
        elif config.rnn_type == "gru":
            rnn_class = nn.GRU
        elif config.rnn_type == "rnn":
            rnn_class = nn.RNN
        else:
            raise ValueError(f"Unknown RNN type: {config.rnn_type}")

        # Inter-layer dropout is only meaningful for stacked RNNs
        # (PyTorch warns when num_layers == 1), hence the conditional.
        self.rnn = rnn_class(
            input_size=config.input_dim,
            hidden_size=config.latent_dim,
            num_layers=config.num_layers,
            batch_first=True,
            dropout=config.dropout_rate if config.num_layers > 1 else 0,
            bidirectional=config.bidirectional
        )

        # Bidirectional RNNs concatenate forward/backward states
        # (2 * latent_dim); project back down to latent_dim.
        if config.bidirectional:
            self.projection = nn.Linear(config.latent_dim * 2, config.latent_dim)
        else:
            self.projection = None

        # Optional normalization of the final hidden state.
        if config.use_batch_norm:
            self.batch_norm = nn.BatchNorm1d(config.latent_dim)
        else:
            self.batch_norm = None

        # Optional dropout on the final hidden state.
        if config.dropout_rate > 0:
            self.dropout = nn.Dropout(config.dropout_rate)
        else:
            self.dropout = None

    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """
        Forward pass through recurrent encoder.

        Args:
            x: Input tensor of shape (batch_size, seq_len, input_dim)
            lengths: Sequence lengths for packed sequences (optional)

        Returns:
            Encoded representation or tuple for VAE
        """
        batch_size, seq_len, _ = x.shape

        # Denoising AE: corrupt inputs with Gaussian noise during training only.
        if self.config.is_denoising and self.training:
            noise = torch.randn_like(x) * self.config.noise_factor
            x = x + noise

        # Pack variable-length sequences so the RNN skips padding positions.
        if lengths is not None:
            x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)

        # LSTMs return (hidden, cell); the other cell types return hidden only.
        if self.config.rnn_type == "lstm":
            output, (hidden, cell) = self.rnn(x)
        else:
            output, hidden = self.rnn(x)
            cell = None

        # Restore padded layout (output itself is unused below; kept as-is).
        if lengths is not None:
            output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

        if self.config.bidirectional:
            # hidden is (num_layers * 2, batch, latent_dim): take the last
            # layer's forward+backward states and concatenate them per sample,
            # yielding (batch, 2 * latent_dim).
            hidden = hidden.view(self.config.num_layers, 2, batch_size, self.config.latent_dim)
            hidden = hidden[-1]
            hidden = hidden.transpose(0, 1).contiguous().view(batch_size, -1)

            # Project the concatenated vector back to latent_dim.
            if self.projection:
                hidden = self.projection(hidden)
        else:
            # Final hidden state of the last (top) layer.
            hidden = hidden[-1]

        if self.batch_norm:
            hidden = self.batch_norm(hidden)

        # NOTE(review): nn.Dropout is already a no-op in eval mode, so the
        # extra self.training check is redundant but harmless.
        if self.dropout and self.training:
            hidden = self.dropout(hidden)

        if self.config.is_variational:
            # NOTE(review): mu/logvar are carved out as halves of the hidden
            # vector, so the effective latent size here is latent_dim // 2 —
            # unlike the feedforward encoder, which uses two latent_dim-sized
            # heads. Confirm downstream consumers expect the halved width.
            mu = hidden[:, :self.config.latent_dim // 2]
            logvar = hidden[:, self.config.latent_dim // 2:]

            # Reparameterization trick during training; deterministic at eval.
            if self.training:
                std = torch.exp(0.5 * logvar)
                eps = torch.randn_like(std)
                z = mu + eps * std
            else:
                z = mu

            return z, mu, logvar
        else:
            return hidden
|
| |
|
class RecurrentDecoder(nn.Module):
    """Recurrent decoder for sequence data.

    Unrolls a unidirectional RNN step by step from a latent vector (used as
    the initial hidden state), projecting each hidden state back to
    ``input_dim`` to rebuild the sequence.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config

        # Select the recurrent cell class from the config.
        if config.rnn_type == "lstm":
            rnn_class = nn.LSTM
        elif config.rnn_type == "gru":
            rnn_class = nn.GRU
        elif config.rnn_type == "rnn":
            rnn_class = nn.RNN
        else:
            raise ValueError(f"Unknown RNN type: {config.rnn_type}")

        # Inter-layer dropout is only valid for stacked RNNs (num_layers > 1).
        self.rnn = rnn_class(
            input_size=config.latent_dim,
            hidden_size=config.latent_dim,
            num_layers=config.num_layers,
            batch_first=True,
            dropout=config.dropout_rate if config.num_layers > 1 else 0,
            bidirectional=False
        )

        # Maps each hidden state back to the original feature dimension.
        self.output_projection = nn.Linear(config.latent_dim, config.input_dim)

        # Optional normalization of each step's hidden state.
        if config.use_batch_norm:
            self.batch_norm = nn.BatchNorm1d(config.latent_dim)
        else:
            self.batch_norm = None

        # Optional dropout on each step's hidden state.
        if config.dropout_rate > 0:
            self.dropout = nn.Dropout(config.dropout_rate)
        else:
            self.dropout = None

    def forward(self, z: torch.Tensor, target_length: int, target_sequence: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass through recurrent decoder.

        Args:
            z: Latent representation of shape (batch_size, latent_dim)
            target_length: Length of sequence to generate
            target_sequence: Target sequence for teacher forcing (optional)

        Returns:
            Decoded sequence of shape (batch_size, seq_len, input_dim)
        """
        batch_size = z.size(0)
        device = z.device

        # Seed every RNN layer's initial hidden state with the latent vector;
        # an LSTM additionally needs a (zeroed) cell state.
        if self.config.rnn_type == "lstm":
            h_0 = z.unsqueeze(0).repeat(self.config.num_layers, 1, 1)
            c_0 = torch.zeros_like(h_0)
            hidden = (h_0, c_0)
        else:
            hidden = z.unsqueeze(0).repeat(self.config.num_layers, 1, 1)

        outputs = []

        # First step input: zeros, since there is no previous token.
        current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)

        for t in range(target_length):
            # Stochastic teacher forcing, enabled only during training.
            use_teacher_forcing = (target_sequence is not None and
                                   self.training and
                                   random.random() < self.config.teacher_forcing_ratio)

            if use_teacher_forcing and t > 0:
                current_input = target_sequence[:, t-1:t, :]
                # NOTE(review): when the target's feature size differs from
                # latent_dim (the usual case, since targets are input_dim
                # wide), this falls back to zeros, silently disabling teacher
                # forcing — confirm whether targets should be embedded to
                # latent_dim instead.
                if current_input.size(-1) != self.config.latent_dim:
                    current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)

            # FIX: the former if/else on rnn_type had two token-identical
            # branches; a single call suffices — the LSTM/non-LSTM difference
            # is carried entirely by `hidden` (tuple vs. tensor).
            output, hidden = self.rnn(current_input, hidden)

            # (batch, 1, latent_dim) -> (batch, latent_dim) for norm/projection.
            output_flat = output.squeeze(1)

            if self.batch_norm:
                output_flat = self.batch_norm(output_flat)

            if self.dropout and self.training:
                output_flat = self.dropout(output_flat)

            step_output = self.output_projection(output_flat)
            outputs.append(step_output.unsqueeze(1))

            # Without teacher forcing the next input stays zeroed.
            # NOTE(review): feeding zeros (rather than the previous output)
            # makes generation depend only on the hidden state — confirm.
            if not use_teacher_forcing:
                current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)

        # Stitch per-step outputs into (batch, target_length, input_dim).
        return torch.cat(outputs, dim=1)
|
| |
|
class AutoencoderModel(PreTrainedModel):
    """
    The bare Autoencoder Model transformer outputting raw hidden-states without any specific head on top.

    This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the
    PyTorch documentation for all matter related to general usage and behavior.
    """

    config_class = AutoencoderConfig
    base_model_prefix = "autoencoder"
    supports_gradient_checkpointing = False

    def __init__(self, config: AutoencoderConfig):
        super().__init__(config)
        self.config = config

        # Optional learnable preprocessing applied before the encoder.
        if config.has_preprocessing:
            self.pre_block = PreprocessingBlock(config, inverse=False)
        else:
            self.pre_block = None

        norm = "batch" if config.use_batch_norm else "none"

        # Builds a chain of LinearBlockConfigs describing a plain MLP stack.
        def default_linear_sequence(in_dim: int, dims: List[int], activation: str, normalization: str, dropout: float) -> List[LinearBlockConfig]:
            cfgs: List[LinearBlockConfig] = []
            prev = in_dim
            for h in dims:
                cfgs.append(
                    LinearBlockConfig(
                        input_dim=prev,
                        output_dim=h,
                        activation=activation,
                        normalization=normalization,
                        dropout_rate=dropout,
                        use_residual=False,
                    )
                )
                prev = h
            return cfgs

        # Encoder: use caller-supplied block configs when present, otherwise
        # derive a default linear stack from hidden_dims.
        if getattr(config, "encoder_blocks", None):
            enc_cfgs = config.encoder_blocks
            # Infer the encoder output width from the last block that declares
            # one; blocks may be dicts or attribute-style config objects.
            last_out = None
            for b in enc_cfgs:
                if isinstance(b, dict):
                    last_out = b.get("output_dim", last_out)
                else:
                    last_out = getattr(b, "output_dim", last_out)
            enc_out_dim = last_out or (config.hidden_dims[-1] if config.hidden_dims else config.input_dim)
        else:
            enc_cfgs = default_linear_sequence(config.input_dim, config.hidden_dims, config.activation, norm, config.dropout_rate)
            enc_out_dim = config.hidden_dims[-1] if config.hidden_dims else config.input_dim
        base_encoder_seq: BlockSequence = BlockFactory.build_sequence(enc_cfgs) if len(enc_cfgs) > 0 else BlockSequence([])
        self.encoder_seq = base_encoder_seq

        # Latent projection: mu/logvar heads for VAEs, a single linear otherwise.
        # NOTE(review): fc_mu/fc_logvar appear unused by forward(), which
        # builds a lazy VariationalBlock instead — confirm they are needed.
        if config.is_variational:
            self.fc_mu = nn.Linear(enc_out_dim, config.latent_dim)
            self.fc_logvar = nn.Linear(enc_out_dim, config.latent_dim)
            self.to_latent = None
        else:
            self.fc_mu = None
            self.fc_logvar = None
            self.to_latent = nn.Linear(enc_out_dim, config.latent_dim)

        # Decoder: caller-supplied block configs or a default stack that ends
        # at input_dim.
        if getattr(config, "decoder_blocks", None):
            dec_cfgs = config.decoder_blocks
        else:
            dec_dims = config.decoder_dims + [config.input_dim]
            dec_cfgs = default_linear_sequence(config.latent_dim, dec_dims, config.activation, norm, config.dropout_rate)

        # Force the final block into a plain linear output (no activation,
        # normalization, or dropout on the reconstruction).
        # NOTE(review): this mutates the last block config in place and uses
        # attribute assignment — it would fail on dict-style decoder_blocks;
        # confirm dict blocks are not expected here.
        if len(dec_cfgs) > 0:
            last = dec_cfgs[-1]
            last.activation = "identity"
            last.normalization = "none"
            last.dropout_rate = 0.0
        self.decoder_seq: BlockSequence = BlockFactory.build_sequence(dec_cfgs) if len(dec_cfgs) > 0 else BlockSequence([])

        # Optional encoder/decoder weight tying (currently a stub).
        if config.tie_weights:
            self._tie_weights()

        # Standard HF weight initialization / final setup.
        self.post_init()

    def _tie_weights(self):
        """Tie encoder and decoder weights (transpose relationship)."""
        # Not implemented; kept as a hook so config.tie_weights doesn't crash.
        pass

    def get_input_embeddings(self):
        """Get input embeddings (not applicable for basic autoencoder)."""
        return None

    def set_input_embeddings(self, value):
        """Set input embeddings (not applicable for basic autoencoder)."""
        pass

    def forward(
        self,
        input_values: torch.Tensor,
        sequence_lengths: Optional[torch.Tensor] = None,
        target_length: Optional[int] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], AutoencoderOutput]:
        """
        Forward pass through the autoencoder.

        Args:
            input_values (torch.Tensor): Input tensor. Shape depends on autoencoder type:
                - Standard: (batch_size, input_dim)
                - Recurrent: (batch_size, seq_len, input_dim)
            sequence_lengths (torch.Tensor, optional): Sequence lengths for recurrent AE.
                Accepted for API compatibility but not used by this forward pass.
            target_length (int, optional): Target sequence length for recurrent decoder.
                Accepted for API compatibility but not used by this forward pass.
            output_hidden_states (bool, optional): Whether to return hidden states.
            return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.

        Returns:
            AutoencoderOutput or tuple: The model outputs.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Learnable preprocessing (if configured) runs before encoding.
        # NOTE(review): preprocessing_loss is always reported as 0.0 here;
        # confirm whether PreprocessingBlock should contribute a real loss.
        if self.pre_block is not None:
            input_values = self.pre_block(input_values)
        preprocessing_loss = torch.tensor(0.0, device=input_values.device)

        enc_out = self.encoder_seq(input_values)

        if self.config.is_variational:
            # NOTE(review): the variational head is created lazily on the
            # first forward pass, so it is absent from state_dict and any
            # optimizer built before that call — and the fc_mu/fc_logvar
            # heads from __init__ go unused. Confirm this is intentional.
            self._variational = getattr(self, '_variational', None)
            if self._variational is None:
                self._variational = VariationalBlock(VariationalBlockConfig(input_dim=enc_out.shape[-1], latent_dim=self.config.latent_dim)).to(enc_out.device)
            latent = self._variational(enc_out, training=self.training)
            # Cache mu/logvar for the KL term computed by the training head.
            self._mu = self._variational._mu
            self._logvar = self._variational._logvar
        else:
            latent = self.to_latent(enc_out) if self.to_latent is not None else enc_out
            self._mu, self._logvar = None, None

        reconstructed = self.decoder_seq(latent)

        hidden_states = None
        if output_hidden_states:
            if self.config.is_variational:
                hidden_states = (latent, getattr(self, '_mu', None), getattr(self, '_logvar', None))
            else:
                hidden_states = (latent,)

        if not return_dict:
            # Tuple layout: (latent, reconstructed[, hidden_states]).
            return tuple(v for v in [latent, reconstructed, hidden_states] if v is not None)

        return AutoencoderOutput(
            last_hidden_state=latent,
            reconstructed=reconstructed,
            hidden_states=hidden_states,
            preprocessing_loss=preprocessing_loss,
        )
|
| |
|
class AutoencoderForReconstruction(PreTrainedModel):
    """
    Autoencoder Model with a reconstruction head on top for reconstruction tasks.

    This model inherits from PreTrainedModel and adds a reconstruction loss calculation.
    """

    config_class = AutoencoderConfig
    base_model_prefix = "autoencoder"

    def __init__(self, config: AutoencoderConfig):
        super().__init__(config)
        self.config = config

        # Wrapped base model that produces the latent code and reconstruction.
        self.autoencoder = AutoencoderModel(config)

        # Standard HF weight initialization / final setup.
        self.post_init()

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.autoencoder.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Set input embeddings."""
        self.autoencoder.set_input_embeddings(value)

    def _compute_reconstruction_loss(
        self,
        reconstructed: torch.Tensor,
        target: torch.Tensor
    ) -> torch.Tensor:
        """Compute reconstruction loss based on the configured loss type.

        Raises:
            ValueError: If ``config.reconstruction_loss`` names an unknown loss.
        """
        if self.config.reconstruction_loss == "mse":
            return F.mse_loss(reconstructed, target, reduction="mean")
        elif self.config.reconstruction_loss == "bce":
            # NOTE(review): with_logits applies a sigmoid internally, but the
            # feedforward AutoencoderDecoder already appends nn.Sigmoid for
            # "bce" — that would double-squash; confirm which component owns
            # the sigmoid.
            return F.binary_cross_entropy_with_logits(reconstructed, target, reduction="mean")
        elif self.config.reconstruction_loss == "l1":
            return F.l1_loss(reconstructed, target, reduction="mean")
        elif self.config.reconstruction_loss == "huber":
            return F.huber_loss(reconstructed, target, reduction="mean")
        elif self.config.reconstruction_loss == "smooth_l1":
            return F.smooth_l1_loss(reconstructed, target, reduction="mean")
        elif self.config.reconstruction_loss == "kl_div":
            # NOTE(review): per the PyTorch docs, kl_div with reduction="mean"
            # averages over every element rather than per sample; "batchmean"
            # matches the mathematical KL definition — confirm intent.
            return F.kl_div(F.log_softmax(reconstructed, dim=-1), F.softmax(target, dim=-1), reduction="mean")
        elif self.config.reconstruction_loss == "cosine":
            return 1 - F.cosine_similarity(reconstructed, target, dim=-1).mean()
        elif self.config.reconstruction_loss == "focal":
            return self._focal_loss(reconstructed, target)
        elif self.config.reconstruction_loss == "dice":
            return self._dice_loss(reconstructed, target)
        elif self.config.reconstruction_loss == "tversky":
            return self._tversky_loss(reconstructed, target)
        elif self.config.reconstruction_loss == "ssim":
            return self._ssim_loss(reconstructed, target)
        elif self.config.reconstruction_loss == "perceptual":
            return self._perceptual_loss(reconstructed, target)
        else:
            raise ValueError(f"Unknown reconstruction loss: {self.config.reconstruction_loss}")

    def _focal_loss(self, pred: torch.Tensor, target: torch.Tensor, alpha: float = 1.0, gamma: float = 2.0) -> torch.Tensor:
        """Compute focal loss for handling class imbalance.

        Down-weights easy elements (small per-element MSE) by the modulating
        factor (1 - e^{-mse})^gamma before averaging.
        """
        ce_loss = F.mse_loss(pred, target, reduction="none")
        pt = torch.exp(-ce_loss)
        focal_loss = alpha * (1 - pt) ** gamma * ce_loss
        return focal_loss.mean()

    def _dice_loss(self, pred: torch.Tensor, target: torch.Tensor, smooth: float = 1e-6) -> torch.Tensor:
        """Compute Dice loss for segmentation-like tasks.

        ``smooth`` guards against division by zero for all-zero inputs.
        """
        pred_flat = pred.view(-1)
        target_flat = target.view(-1)
        intersection = (pred_flat * target_flat).sum()
        dice = (2.0 * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)
        return 1 - dice

    def _tversky_loss(self, pred: torch.Tensor, target: torch.Tensor, alpha: float = 0.7, beta: float = 0.3, smooth: float = 1e-6) -> torch.Tensor:
        """Compute Tversky loss, a generalization of Dice loss.

        ``alpha`` weights false negatives and ``beta`` false positives, so the
        default (0.7, 0.3) penalizes missed positives more heavily.
        """
        pred_flat = pred.view(-1)
        target_flat = target.view(-1)
        true_pos = (pred_flat * target_flat).sum()
        false_neg = (target_flat * (1 - pred_flat)).sum()
        false_pos = ((1 - target_flat) * pred_flat).sum()
        tversky = (true_pos + smooth) / (true_pos + alpha * false_neg + beta * false_pos + smooth)
        return 1 - tversky

    def _ssim_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute SSIM-based loss (simplified version).

        Uses per-vector statistics along the last dimension instead of the
        windowed image-space SSIM.
        """
        mu1 = pred.mean(dim=-1, keepdim=True)
        mu2 = target.mean(dim=-1, keepdim=True)
        sigma1_sq = ((pred - mu1) ** 2).mean(dim=-1, keepdim=True)
        sigma2_sq = ((target - mu2) ** 2).mean(dim=-1, keepdim=True)
        sigma12 = ((pred - mu1) * (target - mu2)).mean(dim=-1, keepdim=True)

        # Stability constants (SSIM's C1/C2).
        c1, c2 = 0.01, 0.03
        ssim = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / ((mu1**2 + mu2**2 + c1) * (sigma1_sq + sigma2_sq + c2))
        return 1 - ssim.mean()

    def _perceptual_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute perceptual loss (simplified version using feature differences)."""
        # Compare L2-normalized feature directions rather than raw magnitudes.
        pred_norm = F.normalize(pred, p=2, dim=-1)
        target_norm = F.normalize(target, p=2, dim=-1)
        return F.mse_loss(pred_norm, target_norm)

    def forward(
        self,
        input_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        sequence_lengths: Optional[torch.Tensor] = None,
        target_length: Optional[int] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], AutoencoderForReconstructionOutput]:
        """
        Forward pass with reconstruction loss calculation.

        Args:
            input_values (torch.Tensor): Input tensor. Shape depends on autoencoder type:
                - Standard: (batch_size, input_dim)
                - Recurrent: (batch_size, seq_len, input_dim)
            labels (torch.Tensor, optional): Target tensor for reconstruction. If None, uses input_values.
            sequence_lengths (torch.Tensor, optional): Sequence lengths for recurrent AE.
            target_length (int, optional): Target sequence length for recurrent decoder.
            output_hidden_states (bool, optional): Whether to return hidden states.
            return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.

        Returns:
            AutoencoderForReconstructionOutput or tuple: The model outputs including loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Autoencoding default: reconstruct the input itself.
        if labels is None:
            labels = input_values

        # Always request a dict from the base model so fields can be accessed
        # by name regardless of the caller's return_dict choice.
        outputs = self.autoencoder(
            input_values=input_values,
            sequence_lengths=sequence_lengths,
            target_length=target_length,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        reconstructed = outputs.reconstructed
        latent = outputs.last_hidden_state
        hidden_states = outputs.hidden_states

        recon_loss = self._compute_reconstruction_loss(reconstructed, labels)

        total_loss = recon_loss

        # Fold in any auxiliary loss reported by the preprocessing block.
        if hasattr(outputs, 'preprocessing_loss') and outputs.preprocessing_loss is not None:
            total_loss += outputs.preprocessing_loss

        if self.config.is_variational and hasattr(self.autoencoder, '_mu') and self.autoencoder._mu is not None:
            # Gaussian KL term averaged per element, scaled by beta (beta-VAE).
            # NOTE(review): this reassigns total_loss from recon_loss, thereby
            # dropping the preprocessing_loss added just above — confirm.
            kl_loss = -0.5 * torch.sum(1 + self.autoencoder._logvar - self.autoencoder._mu.pow(2) - self.autoencoder._logvar.exp())
            kl_loss = kl_loss / (self.autoencoder._mu.size(0) * self.autoencoder._mu.size(1))
            total_loss = recon_loss + self.config.beta * kl_loss

        elif self.config.is_sparse:
            # L1 sparsity penalty on the latent code (fixed 0.1 weight).
            latent = outputs.last_hidden_state
            sparsity_loss = torch.mean(torch.abs(latent))
            total_loss = recon_loss + 0.1 * sparsity_loss

        elif self.config.is_contractive:
            # NOTE(review): latent.grad is only populated after a backward
            # pass, so the guarded branch below is effectively a no-op during
            # a normal forward — confirm whether the contractive penalty
            # should use the encoder Jacobian instead.
            latent = outputs.last_hidden_state
            latent.retain_grad()
            if latent.grad is not None:
                contractive_loss = torch.sum(latent.grad ** 2)
                total_loss = recon_loss + 0.1 * contractive_loss

        loss = total_loss

        if not return_dict:
            # Tuple layout: (loss, reconstructed, latent[, hidden_states]).
            output = (reconstructed, latent)
            if hidden_states is not None:
                output = output + (hidden_states,)
            return ((loss,) + output) if loss is not None else output

        return AutoencoderForReconstructionOutput(
            loss=loss,
            reconstructed=reconstructed,
            last_hidden_state=latent,
            hidden_states=hidden_states,
            preprocessing_loss=outputs.preprocessing_loss if hasattr(outputs, 'preprocessing_loss') else None,
        )
| |
|