| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn import BCEWithLogitsLoss, MSELoss |
|
|
| from .loss_deformable_detr import DeformableDetrForObjectDetectionLoss, DeformableDetrForSegmentationLoss |
| from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss |
| from .loss_rt_detr import RTDetrForObjectDetectionLoss |
|
|
|
|
def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
    """Cross-entropy that optionally normalizes by an explicit token count.

    When ``num_items_in_batch`` is given (gradient-accumulation aware training),
    the summed loss is divided by that count; otherwise the standard mean over
    non-ignored targets is returned.
    """
    if num_items_in_batch is None:
        return nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="mean")
    summed = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="sum")
    return summed / num_items_in_batch
|
|
|
|
def ForCausalLMLoss(
    logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
):
    """Causal-LM loss with optional down/up-weighting of special-token targets.

    Logits and labels are shifted for next-token prediction, flattened, and fed
    to cross-entropy. Positions whose (shifted) label is in ``special_token_ids``
    are scaled by a weight read from the ``special_token_loss`` env var:
    "F" -> 0, "T" -> 1 (default), anything else -> float after the "T"
    (e.g. "T0.5" -> 0.5).

    Args:
        logits: prediction scores, shape (..., seq_len, vocab_size).
        labels: target token ids, shape (..., seq_len).
        vocab_size: size of the vocabulary (last logits dim).
        num_items_in_batch: if given, normalize the summed loss by this count
            instead of taking a mean (gradient-accumulation aware).
        ignore_index: label value excluded from the loss.
        **kwargs: may contain ``special_token_ids`` (iterable of ids, or None
            to fall back to plain ``fixed_cross_entropy``); defaults to the
            Qwen-style range 151665..151671.

    Returns:
        Scalar loss tensor.
    """
    import os

    # Generalized: callers may override the hard-coded id range; default keeps
    # the original behavior.
    special_token_ids = kwargs.pop("special_token_ids", list(range(151665, 151672)))

    env_value = os.getenv("special_token_loss", "T")
    if env_value == "F":
        special_token_weight = 0.0
    elif env_value == "T":
        special_token_weight = 1.0
    else:
        # e.g. "T0.5" -> 0.5; raises ValueError on malformed input (as before).
        special_token_weight = float(env_value.split("T")[-1])
        print(f"special_token_weight: {special_token_weight}")

    # Upcast for numerically stable CE, then shift so token t predicts t+1.
    logits = logits.float()
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)

    if special_token_ids is None:
        print("special_token_ids is None, use default loss func")
        return fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)

    special_token_mask = torch.isin(
        shift_labels, torch.tensor(special_token_ids, device=shift_labels.device)
    )
    per_token_loss = nn.functional.cross_entropy(
        shift_logits, shift_labels, ignore_index=ignore_index, reduction="none"
    )
    # Scale (or zero out) the loss at special-token positions.
    per_token_loss = torch.where(
        special_token_mask, per_token_loss * special_token_weight, per_token_loss
    )

    if num_items_in_batch is not None:
        return per_token_loss.sum() / num_items_in_batch
    # BUG FIX: the previous .mean() divided by ALL shifted positions, including
    # ignore_index ones (which contribute 0), diluting the loss. Divide by the
    # number of valid targets to match cross_entropy's "mean" semantics.
    num_valid = (shift_labels != ignore_index).sum().clamp(min=1)
    return per_token_loss.sum() / num_valid
|
|
|
|
def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs):
    """Sequence-classification loss dispatched on ``config.problem_type``.

    If the problem type is unset it is inferred once (and cached on ``config``):
    one label -> regression; integer labels with >1 classes -> single-label
    classification; otherwise multi-label classification.
    """
    num_labels = config.num_labels

    # Infer and cache the problem type on first use.
    if config.problem_type is None:
        if num_labels == 1:
            config.problem_type = "regression"
        elif num_labels > 1 and labels.dtype in (torch.long, torch.int):
            config.problem_type = "single_label_classification"
        else:
            config.problem_type = "multi_label_classification"

    if config.problem_type == "regression":
        mse = MSELoss()
        if num_labels == 1:
            loss = mse(pooled_logits.squeeze(), labels.squeeze())
        else:
            loss = mse(pooled_logits, labels)
    elif config.problem_type == "single_label_classification":
        loss = fixed_cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1), **kwargs)
    elif config.problem_type == "multi_label_classification":
        loss = BCEWithLogitsLoss()(pooled_logits, labels)
    return loss
|
|
|
|
def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions, **kwargs):
    """Extractive-QA loss: mean of start- and end-position cross-entropies.

    Returns None when either target is missing (e.g. at inference time).
    Targets outside the sequence are clamped to ``seq_len`` and excluded from
    the loss via ``ignore_index``.
    """
    if start_positions is None or end_positions is None:
        return None

    # Collapse any trailing singleton dim (multi-GPU gather artifact) and move
    # targets onto the logits' device.
    if start_positions.dim() > 1:
        start_positions = start_positions.squeeze(-1).to(start_logits.device)
    if end_positions.dim() > 1:
        end_positions = end_positions.squeeze(-1).to(end_logits.device)

    # Out-of-range positions are clamped to seq_len, then ignored by the loss.
    ignored_index = start_logits.size(1)
    start_positions = start_positions.clamp(0, ignored_index)
    end_positions = end_positions.clamp(0, ignored_index)

    start_loss = fixed_cross_entropy(start_logits, start_positions, ignore_index=ignored_index, **kwargs)
    end_loss = fixed_cross_entropy(end_logits, end_positions, ignore_index=ignored_index, **kwargs)
    return (start_loss + end_loss) / 2
|
|
|
|
def ForTokenClassification(logits, labels, config, **kwargs):
    """Token-classification loss: flatten to (N, num_labels) and apply CE."""
    # Upcast to float32 for numerical stability under fp16/bf16 training.
    flat_logits = logits.view(-1, config.num_labels).float()
    flat_labels = labels.view(-1)
    return fixed_cross_entropy(flat_logits, flat_labels, **kwargs)
|
|
|
|
# Maps a model-class-name suffix (e.g. "LlamaForCausalLM" -> "ForCausalLM")
# to the loss function used for that task family. Detection-style entries
# reuse the (Deformable-)DETR losses imported above.
LOSS_MAPPING = {
    "ForCausalLM": ForCausalLMLoss,
    "ForQuestionAnswering": ForQuestionAnsweringLoss,
    "ForSequenceClassification": ForSequenceClassificationLoss,
    "ForTokenClassification": ForTokenClassification,
    "ForSegmentation": ForSegmentationLoss,
    "ForObjectDetection": ForObjectDetectionLoss,
    "DeformableDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
    "ConditionalDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
    "GroundingDinoForObjectDetection": DeformableDetrForObjectDetectionLoss,
    "ConditionalDetrForSegmentation": DeformableDetrForSegmentationLoss,
    "RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
}
|
|