""" Tests for TouchGrass Loss Functions. """ import pytest import torch import torch.nn.functional as F from TouchGrass.training.losses import TouchGrassLoss, MusicAwareLoss class TestTouchGrassLoss: """Test suite for TouchGrassLoss.""" def setup_method(self): """Set up test fixtures.""" self.batch_size = 4 self.seq_len = 10 self.vocab_size = 32000 self.loss_fn = TouchGrassLoss( lm_loss_weight=1.0, eq_loss_weight=0.1, music_module_loss_weight=0.05 ) def test_loss_initialization(self): """Test loss function initialization.""" assert self.loss_fn.lm_loss_weight == 1.0 assert self.loss_fn.eq_loss_weight == 0.1 assert self.loss_fn.music_module_loss_weight == 0.05 def test_forward_with_all_outputs(self): """Test forward pass with all outputs.""" logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size) labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len)) eq_outputs = { "frustration": torch.rand(self.batch_size, self.seq_len, 1), "emotion": torch.randn(self.batch_size, self.seq_len, 4) } eq_labels = { "frustration": torch.rand(self.batch_size, self.seq_len, 1), "emotion": torch.randint(0, 4, (self.batch_size, self.seq_len)) } music_outputs = { "tab_validator": torch.rand(self.batch_size, self.seq_len, 1), "difficulty": torch.randn(self.batch_size, self.seq_len, 3), "interval_logits": torch.randn(self.batch_size, self.seq_len, 12) } music_labels = { "tab_validator": torch.rand(self.batch_size, self.seq_len, 1), "difficulty": torch.randint(0, 3, (self.batch_size, self.seq_len)), "interval_logits": torch.randint(0, 12, (self.batch_size, self.seq_len)) } loss_dict = self.loss_fn( logits=logits, labels=labels, eq_outputs=eq_outputs, eq_labels=eq_labels, music_outputs=music_outputs, music_labels=music_labels ) assert "total_loss" in loss_dict assert "lm_loss" in loss_dict assert "eq_loss" in loss_dict assert "music_loss" in loss_dict assert isinstance(loss_dict["total_loss"], torch.Tensor) assert loss_dict["total_loss"].shape == () def test_forward_without_auxiliary_losses(self): """Test forward pass with only LM loss.""" logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size) labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len)) loss_dict = self.loss_fn(logits=logits, labels=labels) assert "total_loss" in loss_dict assert "lm_loss" in loss_dict assert loss_dict["eq_loss"] == 0.0 assert loss_dict["music_loss"] == 0.0 # Total should equal LM loss only assert torch.isclose(loss_dict["total_loss"], loss_dict["lm_loss"]) def test_lm_loss_calculation(self): """Test that LM loss is computed correctly.""" logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size) labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len)) loss_dict = self.loss_fn(logits=logits, labels=labels) lm_loss = loss_dict["lm_loss"] # Manual calculation shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() expected_lm_loss = F.cross_entropy( shift_logits.view(-1, self.vocab_size), shift_labels.view(-1) ) assert torch.isclose(lm_loss, expected_lm_loss, rtol=1e-4) def test_eq_loss_frustration_mse(self): """Test that frustration loss uses MSE.""" eq_outputs = {"frustration": torch.rand(self.batch_size, self.seq_len, 1)} eq_labels = {"frustration": torch.rand(self.batch_size, self.seq_len, 1)} logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size) labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len)) loss_dict = self.loss_fn( logits=logits, labels=labels, eq_outputs=eq_outputs, eq_labels=eq_labels ) # EQ loss should be non-zero assert loss_dict["eq_loss"] > 0 def test_eq_loss_emotion_cross_entropy(self): """Test that emotion loss uses cross-entropy.""" eq_outputs = {"emotion": torch.randn(self.batch_size, self.seq_len, 4)} eq_labels = {"emotion": torch.randint(0, 4, (self.batch_size, self.seq_len))} logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size) labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len)) loss_dict = self.loss_fn( logits=logits, labels=labels, eq_outputs=eq_outputs, eq_labels=eq_labels ) assert loss_dict["eq_loss"] > 0 def test_music_loss_components(self): """Test that music module loss aggregates multiple components.""" music_outputs = { "tab_validator": torch.rand(self.batch_size, self.seq_len, 1), "difficulty": torch.randn(self.batch_size, self.seq_len, 3), "interval_logits": torch.randn(self.batch_size, self.seq_len, 12) } music_labels = { "tab_validator": torch.rand(self.batch_size, self.seq_len, 1), "difficulty": torch.randint(0, 3, (self.batch_size, self.seq_len)), "interval_logits": torch.randint(0, 12, (self.batch_size, self.seq_len)) } logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size) labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len)) loss_dict = self.loss_fn( logits=logits, labels=labels, music_outputs=music_outputs, music_labels=music_labels ) assert loss_dict["music_loss"] > 0 def test_loss_weighting(self): """Test that loss weights are applied correctly.""" # Create a scenario where we can isolate weights logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size) labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len)) # Only LM loss loss1 = self.loss_fn(logits=logits, labels=labels, lm_loss_weight=1.0) loss2 = self.loss_fn(logits=logits, labels=labels, lm_loss_weight=2.0) # With double weight, total loss should roughly double (if LM is only component) assert torch.isclose(loss2["total_loss"], 2 * loss1["total_loss"], rtol=1e-3) def test_gradient_computation(self): """Test that gradients can be computed.""" logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size, requires_grad=True) labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len)) loss_dict = self.loss_fn(logits=logits, labels=labels) loss_dict["total_loss"].backward() assert logits.grad is not None def test_different_batch_sizes(self): """Test loss with different batch sizes.""" for batch_size in [1, 2, 8]: seq_len = 10 logits = torch.randn(batch_size, seq_len, self.vocab_size) labels = torch.randint(0, self.vocab_size, (batch_size, seq_len)) loss_dict = self.loss_fn(logits=logits, labels=labels) assert loss_dict["total_loss"].shape == () def test_different_seq_lengths(self): """Test loss with different sequence lengths.""" for seq_len in [5, 20, 50, 100]: logits = torch.randn(self.batch_size, seq_len, self.vocab_size) labels = torch.randint(0, self.vocab_size, (self.batch_size, seq_len)) loss_dict = self.loss_fn(logits=logits, labels=labels) assert loss_dict["total_loss"].shape == () def test_loss_dict_keys(self): """Test that loss dictionary contains expected keys.""" logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size) labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len)) loss_dict = self.loss_fn(logits=logits, labels=labels) expected_keys = ["total_loss", "lm_loss", "eq_loss", "music_loss"] for key in expected_keys: assert key in loss_dict def test_loss_values_are_finite(self): """Test that all loss values are finite.""" logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size) labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len)) loss_dict = self.loss_fn(logits=logits, labels=labels) for key, value in loss_dict.items(): assert torch.isfinite(value), f"Loss {key} is not finite: {value}" def test_loss_weights_accumulate(self): """Test that total loss properly accumulates weighted components.""" logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size) labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len)) eq_outputs = {"frustration": torch.rand(self.batch_size, self.seq_len, 1)} eq_labels = {"frustration": torch.rand(self.batch_size, self.seq_len, 1)} music_outputs = {"difficulty": torch.randn(self.batch_size, self.seq_len, 3)} music_labels = {"difficulty": torch.randint(0, 3, (self.batch_size, self.seq_len))} loss_fn = TouchGrassLoss(lm_loss_weight=1.0, eq_loss_weight=0.5, music_module_loss_weight=0.25) loss_dict = loss_fn( logits=logits, labels=labels, eq_outputs=eq_outputs, eq_labels=eq_labels, music_outputs=music_outputs, music_labels=music_labels ) # Total should be weighted sum expected_total = ( 1.0 * loss_dict["lm_loss"] + 0.5 * loss_dict["eq_loss"] + 0.25 * loss_dict["music_loss"] ) assert torch.isclose(loss_dict["total_loss"], expected_total, rtol=1e-4) def test_with_custom_loss_weights(self): """Test initializing with custom loss weights.""" custom_loss_fn = TouchGrassLoss( lm_loss_weight=2.0, eq_loss_weight=0.5, music_module_loss_weight=0.2 ) assert custom_loss_fn.lm_loss_weight == 2.0 assert custom_loss_fn.eq_loss_weight == 0.5 assert custom_loss_fn.music_module_loss_weight == 0.2 def test_missing_auxiliary_outputs(self): """Test that missing auxiliary outputs are handled gracefully.""" logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size) labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len)) # Should work without eq_outputs or music_outputs loss_dict = self.loss_fn(logits=logits, labels=labels) assert loss_dict["total_loss"] > 0 class TestMusicAwareLoss: """Test suite for MusicAwareLoss (alternative implementation).""" def test_music_aware_loss_initialization(self): """Test MusicAwareLoss initialization.""" loss_fn = MusicAwareLoss() assert hasattr(loss_fn, "forward") def test_music_aware_loss_forward(self): """Test MusicAwareLoss forward pass.""" loss_fn = MusicAwareLoss() logits = torch.randn(2, 10, 1000) labels = torch.randint(0, 1000, (2, 10)) # Should work with just LM loss loss = loss_fn(logits, labels) assert isinstance(loss, torch.Tensor) assert loss.shape == () def test_music_aware_loss_with_weights(self): """Test MusicAwareLoss with custom weights.""" loss_fn = MusicAwareLoss( lm_weight=1.0, music_weight=0.1, eq_weight=0.05 ) logits = torch.randn(2, 10, 1000) labels = torch.randint(0, 1000, (2, 10)) loss = loss_fn(logits, labels) assert torch.isfinite(loss) if __name__ == "__main__": pytest.main([__file__, "-v"])