File size: 832 Bytes
d5b7ee9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from safetensors.torch import load_file
from trading_cli.strategy.ai.model import create_model
import logging

logger = logging.getLogger(__name__)
# Send INFO-and-above records to a root stream handler so the smoke-test
# messages are visible. Note: basicConfig does nothing if the root logger
# already has handlers (e.g. when a test runner has configured logging).
logging.basicConfig(level=logging.INFO)

def test_inference(
    model_path: str = "models/ai_fusion_bitnet.safetensors",
    input_dim: int = 9,
) -> None:
    """Smoke-test loading the AI fusion model and running one forward pass.

    Builds the model with ``create_model(input_dim=input_dim)``, loads the
    state dict stored in the safetensors file at *model_path*, switches the
    model to eval mode, feeds one random ``(1, input_dim)`` tensor through it,
    and logs the raw output plus the arg-max index along the last dimension.

    Args:
        model_path: Path of the ``.safetensors`` checkpoint. A relative path
            is resolved against the current working directory.
        input_dim: Feature width used both to build the model and to shape
            the random probe input; specifying it once keeps the two in sync.

    Raises:
        Exception: Any error from model construction, weight loading, or the
            forward pass is logged and then re-raised, so a failure is
            reported (test failure / non-zero exit) instead of being hidden.
    """
    try:
        model = create_model(input_dim=input_dim)
        # load_state_dict is strict by default: missing or unexpected keys
        # raise, so a checkpoint that doesn't match the architecture fails here.
        model.load_state_dict(load_file(model_path))
        model.eval()
        logger.info("Model loaded successfully ✓")

        # Test with random input: a batch containing a single sample.
        x = torch.randn(1, input_dim)
        with torch.no_grad():
            output = model(x)
        logger.info("Output: %s", output)
        # NOTE(review): .item() requires argmax to yield exactly one element,
        # i.e. the model returns a single-sample tensor here — confirm.
        action = torch.argmax(output, dim=-1).item()
        logger.info("Action: %s", action)
    except Exception as e:
        logger.error("Inference test failed: %s", e)
        # Re-raise: swallowing the error made this check always "pass".
        raise

if __name__ == "__main__":
    # Run the inference smoke test when this file is executed directly.
    test_inference()