|
|
| """
|
| Comprehensive unit tests for Vortex model components.
|
| Run with: python -m pytest test_model.py -v
|
| """
|
|
|
| import pytest
|
| import torch
|
| import sys
|
| from pathlib import Path
|
|
|
|
|
| sys.path.insert(0, str(Path(__file__).parent))
|
|
|
|
|
def test_tokenizer():
    """Round-trip a science sentence through VortexScienceTokenizer."""
    from tokenizer.vortex_tokenizer import VortexScienceTokenizer
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    tok = VortexScienceTokenizer(VORTEX_7B_CONFIG)

    # Sample mixes LaTeX math and a chemical formula.
    sample = "The equation is $E = mc^2$ and H2O is water."
    enc = tok.encode(sample, return_tensors="pt")
    assert "input_ids" in enc
    assert enc["input_ids"].shape[0] == 1

    # Decoding must at least produce a string.
    round_tripped = tok.decode(enc["input_ids"][0].tolist())
    assert isinstance(round_tripped, str)
    print("✓ Tokenizer test passed")
|
|
|
|
|
def test_ssm_layer():
    """Test VortexSSM."""
    from models.ssm_layer import VortexSSM

    bsz, length, width, n_state = 2, 64, 512, 16

    layer = VortexSSM(width, d_state=n_state)
    inputs = torch.randn(bsz, length, width)

    # Stateless forward keeps the input shape.
    assert layer(inputs).shape == inputs.shape

    # Stateful forward returns an updated recurrent state.
    init_state = torch.zeros(bsz, layer.d_inner, n_state)
    out, next_state = layer(inputs, state=init_state, return_state=True)
    assert out.shape == inputs.shape
    assert next_state.shape == (bsz, layer.d_inner, n_state)

    # Single-token decode step.
    token = torch.randn(bsz, width)
    step_out, stepped_state = layer.step(token, init_state)
    assert step_out.shape == (bsz, width)
    assert stepped_state.shape == (bsz, layer.d_inner, n_state)

    print("✓ SSM layer test passed")
|
|
|
|
|
def test_attention_layer():
    """Test VortexLocalAttention."""
    from models.attention_layer import VortexLocalAttention

    bsz, length, width, heads = 2, 128, 512, 8

    layer = VortexLocalAttention(width, heads, window_size=64, use_flash_attention=False)
    inputs = torch.randn(bsz, length, width)

    # Plain windowed attention preserves the input shape.
    assert layer(inputs).shape == inputs.shape

    # Promoting one position to global must not change the output shape.
    is_global = torch.zeros(bsz, length, dtype=torch.bool)
    is_global[0, 0] = True
    assert layer(inputs, global_mask=is_global).shape == inputs.shape

    print("✓ Local attention test passed")
|
|
|
|
|
def test_scigate_ffn():
    """Test SciGateFFN."""
    from models.scigate_ffn import SciGateFFN

    bsz, length, width = 2, 64, 512
    n_domains = 7

    layer = SciGateFFN(width, expansion=4, num_domains=n_domains)
    inputs = torch.randn(bsz, length, width)

    # No domain conditioning at all.
    assert layer(inputs).shape == inputs.shape

    # Hard per-sample domain ids.
    ids = torch.randint(0, n_domains, (bsz,))
    assert layer(inputs, domain_ids=ids).shape == inputs.shape

    # Soft per-token domain tags (all probability mass on domain 0).
    tags = torch.zeros(bsz, length, n_domains)
    tags[:, :, 0] = 1.0
    assert layer(inputs, domain_tags=tags).shape == inputs.shape

    print("✓ SciGate FFN test passed")
|
|
|
|
|
def test_equation_module():
    """Test EquationModule."""
    from models.science_modules.equation_module import EquationModule

    width, bsz, length = 512, 2, 64

    module = EquationModule(width)
    hidden = torch.randn(bsz, length, width)
    samples = ["E = mc^2 is famous.", "The integral $\\int x dx = x^2/2$."]

    # Text-conditioned forward keeps the hidden-state shape.
    assert module(hidden, text=samples).shape == hidden.shape

    # Auxiliary loss over a span marked as equation tokens is non-negative.
    mask = torch.zeros(bsz, length)
    mask[0, 5:10] = 1.0
    aux = module.compute_equation_loss(hidden, mask)
    assert aux.item() >= 0

    print("✓ Equation module test passed")
|
|
|
|
|
def test_numerical_module():
    """Test NumericalReasoningModule."""
    from models.science_modules.numerical_module import NumericalReasoningModule

    width, bsz, length = 512, 2, 64

    module = NumericalReasoningModule(width)
    hidden = torch.randn(bsz, length, width)
    # Samples contain scientific-notation numbers for the module to pick up.
    samples = ["Speed of light: 2.998e8 m/s", "6.022e23 is Avogadro's number."]

    assert module(hidden, text=samples).shape == hidden.shape

    print("✓ Numerical reasoning module test passed")
|
|
|
|
|
def test_citation_module():
    """Test CitationModule."""
    from models.science_modules.citation_module import CitationModule

    width, bsz, length = 512, 2, 64

    module = CitationModule(width)
    hidden = torch.randn(bsz, length, width)
    # One author-year citation, one bracketed numeric citation.
    samples = ["(Einstein, 1905) changed physics.", "See also [1, 2] for details."]

    out, confidence = module(hidden, text=samples)
    assert out.shape == hidden.shape
    assert confidence.shape == (bsz, length, 1)

    # Citation loss over a marked span is non-negative.
    mask = torch.zeros(bsz, length)
    mask[0, 0:5] = 1.0
    aux = module.compute_citation_loss(hidden, mask, confidence)
    assert aux.item() >= 0

    print("✓ Citation module test passed")
|
|
|
|
|
def test_molecular_module():
    """Test MolecularModule."""
    from models.science_modules.molecular_module import MolecularModule

    width, bsz, length = 512, 2, 64

    module = MolecularModule(width)
    hidden = torch.randn(bsz, length, width)
    # One chemical formula, one nucleotide sequence.
    samples = ["H2O is water.", "DNA sequence: ACGTACGT"]

    assert module(hidden, text=samples).shape == hidden.shape

    print("✓ Molecular module test passed")
|
|
|
|
|
def test_vortex_model():
    """Test full VortexModel."""
    from models.vortex_model import VortexModel
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the 7B config so the smoke test stays fast.
    config = VORTEX_7B_CONFIG.copy()
    config.update({
        "d_model": 256,
        "num_layers": 4,
        "num_heads": 4,
        "vocab_size": 1000,
    })

    model = VortexModel(config)

    bsz, length = 2, 32
    tokens = torch.randint(0, config["vocab_size"], (bsz, length))

    # Forward pass yields per-token vocabulary logits.
    logits = model(tokens)["logits"]
    assert logits.shape == (bsz, length, config["vocab_size"])

    # Parameter count must be reported and positive.
    num_params = model.get_num_params()
    assert num_params > 0

    print(f"✓ VortexModel test passed (params: {num_params:,})")
|
|
|
|
|
def test_quality_filter():
    """Test ScienceQualityFilter accept/reject decisions.

    Covers a well-formed scientific paragraph (accepted), a too-short
    snippet (rejected), and text with unbalanced $...$ math delimiters
    (rejected).
    """
    from data.quality_filter import ScienceQualityFilter

    # Named `quality_filter` rather than `filter` to avoid shadowing the
    # builtin of the same name.
    quality_filter = ScienceQualityFilter()

    # Substantial text with statistics, an equation, and a citation.
    good_text = """
The experiment collected data from 100 participants. Results show a
significant effect (p < 0.05). The equation E = mc^2 is fundamental.
According to Smith et al., this confirms the hypothesis.
"""
    assert quality_filter.filter(good_text)

    # Too little content should be rejected.
    assert not quality_filter.filter("Too short.")

    # Unbalanced $...$ delimiters should be rejected.
    bad_eq = "Equation $E = mc^2 and another $F = ma."
    assert not quality_filter.filter(bad_eq)

    print("✓ Quality filter test passed")
|
|
|
|
|
def test_domain_classifier():
    """Test DomainClassifier."""
    from data.domain_classifier import DomainClassifier

    width = 256
    clf = DomainClassifier(width)

    # Tensor path: hidden states -> 7 domain logits per sample.
    bsz, length = 4, 32
    states = torch.randn(bsz, length, width)
    assert clf(states).shape == (bsz, 7)

    # Text path: returns a (domain index, confidence) pair.
    sample = "Quantum mechanics describes particle behavior."
    domain, conf = clf.classify_text(sample)
    assert domain in range(7)
    assert 0 <= conf <= 1

    print("✓ Domain classifier test passed")
|
|
|
|
|
def test_deduplication():
    """Test MinHashLSH."""
    from data.deduplication import MinHashLSH

    index = MinHashLSH(num_permutations=32, threshold=0.7, bands=4, rows_per_band=8)

    # doc1 and doc2 are near-duplicates; doc3 is unrelated.
    corpus = {
        "doc1": "The quick brown fox jumps over the lazy dog.",
        "doc2": "The quick brown fox jumps over the lazy dog!!!",
        "doc3": "Completely different text about science.",
    }
    for doc_id, body in corpus.items():
        index.add_document(doc_id, body)

    # Querying with doc1's text must surface its near-duplicate doc2.
    matches = index.query(corpus["doc1"])
    assert len(matches) >= 1
    assert any(match[0] == "doc2" for match in matches)

    print("✓ Deduplication test passed")
|
|
|
|
|
def test_losses():
    """Test VortexLoss."""
    from training.losses import VortexLoss

    weights = {
        "lm_loss": 1.0,
        "equation_loss": 0.3,
        "domain_loss": 0.1,
        "citation_loss": 0.1,
        "numerical_loss": 0.2,
    }
    criterion = VortexLoss({"loss_weights": weights})

    bsz, length, n_vocab = 2, 32, 1000
    logits = torch.randn(bsz, length, n_vocab)
    targets = torch.randint(0, n_vocab, (bsz, length))

    # LM loss alone must make the weighted total positive.
    result = criterion(logits, targets)
    assert "total_loss" in result
    assert "lm_loss" in result
    assert result["total_loss"].item() > 0

    print("✓ Losses test passed")
|
|
|
|
|
def test_curriculum():
    """Test CurriculumScheduler stage lookup and sampler weights.

    Checks that training-progress fractions map to the expected stage
    names and that the dataset sampler yields a normalized weight dict.
    """
    from training.curriculum import CurriculumScheduler

    config = {
        "curriculum_stages": [
            {"name": "foundation", "start": 0.0, "end": 0.2},
            {"name": "domain", "start": 0.2, "end": 0.5},
            {"name": "reasoning", "start": 0.5, "end": 0.8},
            {"name": "integration", "start": 0.8, "end": 1.0},
        ]
    }

    total_steps = 1000
    scheduler = CurriculumScheduler(config, total_steps)

    # Steps within each fractional window resolve to that stage's name.
    assert scheduler.get_stage_name(0) == "foundation"
    assert scheduler.get_stage_name(250) == "domain"
    assert scheduler.get_stage_name(500) == "reasoning"
    assert scheduler.get_stage_name(800) == "integration"

    # Sampler weights must form a valid distribution.  Compare the sum
    # with a tolerance: the weights are floats, so an exact `== 1.0`
    # check is fragile under accumulated rounding error.
    weights = scheduler.get_dataset_sampler(100)
    assert isinstance(weights, dict)
    assert abs(sum(weights.values()) - 1.0) < 1e-6

    print("✓ Curriculum test passed")
|
|
|
|
|
def test_hf_integration():
    """Test HuggingFace integration.

    Builds a tiny VortexForCausalLM, checks its logits shape, then
    round-trips config + weights through save_pretrained/from_pretrained
    and verifies the Auto* classes resolve the custom "vortex" type.
    """
    import tempfile

    from configuration_vortex import VortexConfig
    from modeling_vortex import VortexForCausalLM
    from tokenization_vortex import VortexTokenizer

    # A tiny config keeps the round-trip fast.
    config = VortexConfig(
        d_model=128,
        num_layers=2,
        num_heads=4,
        vocab_size=100,
    )

    model = VortexForCausalLM(config)
    batch_size = 2
    seq_len = 16
    input_ids = torch.randint(0, 100, (batch_size, seq_len))

    outputs = model(input_ids)
    assert outputs.logits.shape == (batch_size, seq_len, 100)

    from transformers import AutoConfig, AutoModelForCausalLM

    # Save/load round-trip inside a temporary directory so the artifacts
    # are removed even when an assertion fails.  (The previous version
    # wrote to a fixed "./test_hf_model" path and only cleaned it up on
    # success, leaking the directory on any failure.)
    with tempfile.TemporaryDirectory() as save_dir:
        model.save_pretrained(save_dir)
        config.save_pretrained(save_dir)

        loaded_config = AutoConfig.from_pretrained(save_dir)
        loaded_model = AutoModelForCausalLM.from_pretrained(save_dir)

        assert loaded_config.model_type == "vortex"
        assert isinstance(loaded_model, VortexForCausalLM)

    print("✓ HuggingFace integration test passed")
|
|
|
|
|
def run_all_tests():
    """Execute every test, print a summary, and return overall success."""
    tests = [
        test_tokenizer,
        test_ssm_layer,
        test_attention_layer,
        test_scigate_ffn,
        test_equation_module,
        test_numerical_module,
        test_citation_module,
        test_molecular_module,
        test_vortex_model,
        test_quality_filter,
        test_domain_classifier,
        test_deduplication,
        test_losses,
        test_curriculum,
        test_hf_integration,
    ]

    print("Running Vortex unit tests...\n")

    # Each test prints its own success line; we only record failures here.
    failures = []
    for case in tests:
        try:
            case()
        except Exception as e:
            print(f"✗ {case.__name__} failed: {e}")
            failures.append(case)
            import traceback
            traceback.print_exc()

    failed = len(failures)
    passed = len(tests) - failed

    print(f"\n{'='*50}")
    print(f"Tests: {passed + failed} total, {passed} passed, {failed} failed")
    print(f"{'='*50}")

    return failed == 0
|
|
|
|
|
if __name__ == "__main__":
    # Exit code 0 only when every test passed.
    sys.exit(0 if run_all_tests() else 1)
|
|
|