# luohoa97's picture
# Deploy BitNet-Transformer Trainer
# d5b7ee9 verified
"""SQLite database layer — schema, queries, and connection management."""
from __future__ import annotations

import hashlib
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
def get_connection(db_path: Path) -> sqlite3.Connection:
    """Open a SQLite connection to ``db_path`` configured for this module.

    The connection is opened with ``check_same_thread=False``, has its
    journal mode switched to WAL, and returns rows as ``sqlite3.Row`` so
    columns can be read by name.
    """
    connection = sqlite3.connect(str(db_path), check_same_thread=False)
    connection.execute("PRAGMA journal_mode=WAL")
    connection.row_factory = sqlite3.Row
    return connection
def init_db(db_path: Path) -> sqlite3.Connection:
    """Open the database at ``db_path``, ensure all tables exist, and return it.

    Every table is created with ``IF NOT EXISTS``, so running this against an
    already-initialised database leaves existing tables in place. The caller
    owns the returned connection.
    """
    # The SQL text is kept flush-left so the script handed to SQLite is
    # exactly the same text as before.
    schema = """
CREATE TABLE IF NOT EXISTS trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp TEXT NOT NULL,
symbol TEXT NOT NULL,
action TEXT NOT NULL,
price REAL NOT NULL,
quantity INTEGER NOT NULL,
order_id TEXT,
reason TEXT,
pnl REAL,
portfolio_value REAL
);
CREATE TABLE IF NOT EXISTS signals (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp TEXT NOT NULL,
symbol TEXT NOT NULL,
action TEXT NOT NULL,
confidence REAL,
technical_score REAL,
sentiment_score REAL,
reason TEXT,
executed INTEGER DEFAULT 0
);
CREATE TABLE IF NOT EXISTS watchlist (
symbol TEXT PRIMARY KEY,
added_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS sentiment_cache (
headline_hash TEXT PRIMARY KEY,
headline TEXT NOT NULL,
label TEXT NOT NULL,
score REAL NOT NULL,
cached_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS price_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
symbol TEXT NOT NULL,
timestamp TEXT NOT NULL,
open REAL,
high REAL,
low REAL,
close REAL,
volume INTEGER,
UNIQUE(symbol, timestamp)
);
CREATE TABLE IF NOT EXISTS config (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
);
"""
    db = get_connection(db_path)
    db.executescript(schema)
    db.commit()
    return db
# ── Trades ─────────────────────────────────────────────────────────────────────
def save_trade(
    conn: sqlite3.Connection,
    symbol: str,
    action: str,
    price: float,
    quantity: int,
    order_id: str | None = None,
    reason: str | None = None,
    pnl: float | None = None,
    portfolio_value: float | None = None,
) -> int:
    """Insert one trade row, commit, and return the new row id.

    ``symbol`` and ``action`` are stored upper-cased. ``get_trade_history``
    filters on upper-cased values, so a row saved as ``"aapl"`` / ``"buy"``
    would otherwise never match a filtered query.

    The timestamp is the current UTC time as naive ISO-8601 text.
    """
    # datetime.utcnow() is deprecated since Python 3.12. Dropping tzinfo keeps
    # the stored text in the same naive-UTC format as rows already written.
    ts = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    cur = conn.execute(
        """INSERT INTO trades
        (timestamp, symbol, action, price, quantity, order_id, reason, pnl, portfolio_value)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
        (
            ts,
            symbol.upper(),
            action.upper(),
            price,
            quantity,
            order_id,
            reason,
            pnl,
            portfolio_value,
        ),
    )
    conn.commit()
    return cur.lastrowid
def get_trade_history(
    conn: sqlite3.Connection,
    symbol: str | None = None,
    action: str | None = None,
    limit: int = 100,
) -> list[dict]:
    """Return up to ``limit`` trades as dicts, newest timestamp first.

    ``symbol`` and ``action`` are optional equality filters; each is
    upper-cased before comparison, and a falsy value (``None`` or ``""``)
    disables that filter.
    """
    # Column names are fixed literals; only the values are bound parameters.
    filters = [
        (column, value.upper())
        for column, value in (("symbol", symbol), ("action", action))
        if value
    ]

    sql = "SELECT * FROM trades"
    if filters:
        sql += " WHERE " + " AND ".join(f"{column} = ?" for column, _ in filters)
    sql += " ORDER BY timestamp DESC LIMIT ?"

    args: list[Any] = [value for _, value in filters]
    args.append(limit)

    rows = conn.execute(sql, args).fetchall()
    return [dict(row) for row in rows]
# ── Signals ────────────────────────────────────────────────────────────────────
def save_signal(
    conn: sqlite3.Connection,
    symbol: str,
    action: str,
    confidence: float,
    technical_score: float,
    sentiment_score: float,
    reason: str,
    executed: bool = False,
) -> int:
    """Insert one signal row, commit, and return the new row id.

    The row is stamped with the current UTC time as naive ISO-8601 text, and
    ``executed`` is stored as the integer 0 or 1.
    """
    # datetime.utcnow() is deprecated since Python 3.12. Dropping tzinfo keeps
    # the stored text in the same naive-UTC format as rows already written.
    ts = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    cur = conn.execute(
        """INSERT INTO signals
        (timestamp, symbol, action, confidence, technical_score, sentiment_score, reason, executed)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
        (
            ts,
            symbol,
            action,
            confidence,
            technical_score,
            sentiment_score,
            reason,
            int(executed),
        ),
    )
    conn.commit()
    return cur.lastrowid
def get_recent_signals(conn: sqlite3.Connection, limit: int = 20) -> list[dict]:
    """Return up to ``limit`` signal rows as dicts, newest timestamp first."""
    cursor = conn.execute(
        "SELECT * FROM signals ORDER BY timestamp DESC LIMIT ?", (limit,)
    )
    return list(map(dict, cursor.fetchall()))
# ── Watchlist ──────────────────────────────────────────────────────────────────
def get_watchlist(conn: sqlite3.Connection) -> list[str]:
    """Return every watchlist symbol, sorted by symbol.

    Rows are read by column name, so the connection must produce rows that
    support string indexing (e.g. ``sqlite3.Row``).
    """
    rows = conn.execute("SELECT symbol FROM watchlist ORDER BY symbol").fetchall()
    return [entry["symbol"] for entry in rows]
def add_to_watchlist(conn: sqlite3.Connection, symbol: str) -> None:
    """Add ``symbol`` (upper-cased) to the watchlist and commit.

    Uses ``INSERT OR IGNORE``, so a symbol that is already present keeps its
    existing ``added_at`` value.
    """
    # datetime.utcnow() is deprecated since Python 3.12. Dropping tzinfo keeps
    # the stored text in the same naive-UTC format as rows already written.
    added_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    conn.execute(
        "INSERT OR IGNORE INTO watchlist (symbol, added_at) VALUES (?, ?)",
        (symbol.upper(), added_at),
    )
    conn.commit()
def remove_from_watchlist(conn: sqlite3.Connection, symbol: str) -> None:
    """Delete ``symbol`` (compared upper-cased) from the watchlist and commit.

    Removing a symbol that is not present changes nothing.
    """
    normalized = symbol.upper()
    conn.execute("DELETE FROM watchlist WHERE symbol = ?", (normalized,))
    conn.commit()
# ── Sentiment cache ────────────────────────────────────────────────────────────
def headline_hash(text: str) -> str:
    """Return the hexadecimal MD5 digest of ``text`` (default string encoding).

    Used as the primary key of the sentiment cache, not for security.
    """
    digest = hashlib.md5()
    digest.update(text.encode())
    return digest.hexdigest()
def get_cached_sentiment(conn: sqlite3.Connection, text: str) -> dict | None:
    """Look up the cached sentiment for ``text``.

    Returns a dict with ``label`` and ``score`` when the hash of ``text`` is
    in the cache, otherwise ``None``.
    """
    key = headline_hash(text)
    cursor = conn.execute(
        "SELECT label, score FROM sentiment_cache WHERE headline_hash = ?", (key,)
    )
    hit = cursor.fetchone()
    if hit is None:
        return None
    return dict(hit)
def cache_sentiment(conn: sqlite3.Connection, text: str, label: str, score: float) -> None:
    """Store (or overwrite) the sentiment result for ``text`` and commit.

    The cache key is the hash of the full ``text``; only the first 500
    characters of the headline itself are stored alongside it.
    """
    h = headline_hash(text)
    # datetime.utcnow() is deprecated since Python 3.12. Dropping tzinfo keeps
    # the stored text in the same naive-UTC format as rows already written.
    cached_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    conn.execute(
        """INSERT OR REPLACE INTO sentiment_cache
        (headline_hash, headline, label, score, cached_at)
        VALUES (?, ?, ?, ?, ?)""",
        (h, text[:500], label, score, cached_at),
    )
    conn.commit()
# ── Price history ──────────────────────────────────────────────────────────────
def upsert_price_bar(
    conn: sqlite3.Connection,
    symbol: str,
    timestamp: str,
    open_: float,
    high: float,
    low: float,
    close: float,
    volume: int,
) -> None:
    """Insert or replace the OHLCV bar for ``(symbol, timestamp)`` and commit.

    ``symbol`` is stored upper-cased. ``get_price_history`` queries with an
    upper-cased symbol, so a bar saved as ``"aapl"`` would otherwise be
    unreachable, and would sit beside an ``"AAPL"`` bar for the same
    timestamp despite the ``UNIQUE(symbol, timestamp)`` constraint.
    """
    conn.execute(
        """INSERT OR REPLACE INTO price_history
        (symbol, timestamp, open, high, low, close, volume)
        VALUES (?, ?, ?, ?, ?, ?, ?)""",
        (symbol.upper(), timestamp, open_, high, low, close, volume),
    )
    conn.commit()
def get_price_history(
    conn: sqlite3.Connection, symbol: str, limit: int = 200
) -> list[dict]:
    """Return up to ``limit`` stored bars for ``symbol`` as dicts, newest first.

    ``symbol`` is upper-cased before it is compared against stored rows.
    """
    sql = "SELECT * FROM price_history WHERE symbol = ? ORDER BY timestamp DESC LIMIT ?"
    cursor = conn.execute(sql, (symbol.upper(), limit))
    return [dict(bar) for bar in cursor.fetchall()]