nathanael-fijalkow committed
Commit cb13986 · 1 Parent(s): 551785b

add src folder and webhook

Files changed (6):
  1. app.py +114 -7
  2. requirements.txt +2 -0
  3. src/__init__.py +13 -0
  4. src/evaluate.py +619 -0
  5. src/model.py +437 -0
  6. src/tokenizer.py +278 -0
app.py CHANGED
@@ -9,20 +9,28 @@ This Gradio app provides:
 Leaderboard data is stored in a private HuggingFace dataset for persistence.
 """
 
+import hashlib
+import hmac
 import io
 import os
+import sys
 from datetime import datetime
 from pathlib import Path
 from typing import Optional
 
 import gradio as gr
 import pandas as pd
+from fastapi import FastAPI, Request, BackgroundTasks, HTTPException
+
+# Create the FastAPI app that hosts the webhook endpoints
+fastapi_app = FastAPI()
 
 # Configuration
 ORGANIZATION = os.environ.get("HF_ORGANIZATION", "LLM-course")
 LEADERBOARD_DATASET = os.environ.get("LEADERBOARD_DATASET", f"{ORGANIZATION}/chess-challenge-leaderboard")
 LEADERBOARD_FILENAME = "leaderboard.csv"
 HF_TOKEN = os.environ.get("HF_TOKEN")  # Required for private dataset access
+WEBHOOK_SECRET = os.environ.get("WEBHOOK_SECRET", "")  # For webhook verification
 
 STOCKFISH_LEVELS = {
     "Beginner (Level 0)": 0,
@@ -293,9 +301,9 @@ def evaluate_legal_moves(
     """Evaluate a model's legal move generation."""
     try:
         import sys
-        sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
+        sys.path.insert(0, str(Path(__file__).parent))
 
-        from chess_challenge.evaluate import ChessEvaluator, load_model_from_hub
+        from src.evaluate import ChessEvaluator, load_model_from_hub
 
         progress(0, desc="Loading model...")
         model, tokenizer = load_model_from_hub(model_id)
@@ -355,9 +363,9 @@ def evaluate_winrate(
     """Evaluate a model's win rate against Stockfish."""
     try:
         import sys
-        sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
+        sys.path.insert(0, str(Path(__file__).parent))
 
-        from chess_challenge.evaluate import ChessEvaluator, load_model_from_hub
+        from src.evaluate import ChessEvaluator, load_model_from_hub
 
         progress(0, desc="Loading model...")
         model, tokenizer = load_model_from_hub(model_id)
@@ -419,9 +427,9 @@ def evaluate_model(
     try:
         # Import evaluation code
        import sys
-        sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
+        sys.path.insert(0, str(Path(__file__).parent))
 
-        from chess_challenge.evaluate import ChessEvaluator, load_model_from_hub
+        from src.evaluate import ChessEvaluator, load_model_from_hub
 
         progress(0, desc="Loading model...")
         model, tokenizer = load_model_from_hub(model_id)
@@ -660,5 +668,104 @@ with gr.Blocks(
 )
 
 
+# =============================================================================
+# WEBHOOK HANDLERS FOR AUTOMATIC EVALUATION
+# =============================================================================
+
+def verify_webhook_signature(payload: bytes, signature: str) -> bool:
+    """Verify the webhook signature from Hugging Face."""
+    if not WEBHOOK_SECRET:
+        print("⚠️ WEBHOOK_SECRET not set - skipping signature verification")
+        return True
+    expected = hmac.new(WEBHOOK_SECRET.encode(), payload, hashlib.sha256).hexdigest()
+    return hmac.compare_digest(f"sha256={expected}", signature)
+
+
+def run_auto_evaluation(model_id: str):
+    """Run model evaluation in the background after a webhook trigger."""
+    try:
+        print(f"🚀 Auto-evaluating new model: {model_id}")
+
+        # Import evaluation functions
+        sys.path.insert(0, str(Path(__file__).parent))
+        from src.evaluate import ChessEvaluator, load_model_from_hub
+
+        # Load model
+        model, tokenizer = load_model_from_hub(model_id)
+
+        # Run legal moves evaluation (quick first pass)
+        evaluator = ChessEvaluator(
+            model=model,
+            tokenizer=tokenizer,
+            stockfish_level=1,
+        )
+        results = evaluator.evaluate_legal_moves(n_positions=100, verbose=True)
+
+        # Update leaderboard
+        leaderboard = load_leaderboard()
+        entry = next((e for e in leaderboard if e["model_id"] == model_id), None)
+        if entry is None:
+            entry = {"model_id": model_id}
+            leaderboard.append(entry)
+
+        entry.update({
+            "legal_rate": results.get("legal_rate_with_retry", 0),
+            "legal_rate_first_try": results.get("legal_rate_first_try", 0),
+            "last_updated": datetime.now().strftime("%Y-%m-%d %H:%M"),
+        })
+
+        save_leaderboard(leaderboard)
+        print(f"✅ Auto-evaluation complete for {model_id}: legal_rate={results.get('legal_rate_with_retry', 0):.1%}")
+
+    except Exception as e:
+        print(f"❌ Auto-evaluation failed for {model_id}: {e}")
+        import traceback
+        traceback.print_exc()
+
+
+@fastapi_app.post("/webhook")
+async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
+    """Handle incoming webhooks from Hugging Face."""
+    payload = await request.body()
+    signature = request.headers.get("X-Webhook-Signature", "")
+
+    # Verify signature
+    if not verify_webhook_signature(payload, signature):
+        print("❌ Webhook signature verification failed")
+        raise HTTPException(status_code=403, detail="Invalid signature")
+
+    data = await request.json()
+    event = data.get("event", {})
+    event_type = event.get("action")
+    repo = data.get("repo", {})
+    repo_type = repo.get("type")
+    repo_name = repo.get("name")
+
+    print(f"📥 Webhook received: {event_type} for {repo_type}/{repo_name}")
+
+    # Only process model creation/updates in our organization
+    if repo_type == "model" and repo_name and repo_name.startswith(f"{ORGANIZATION}/"):
+        if event_type in ["create", "update"]:
+            # Check if it's a chess model
+            if "chess" in repo_name.lower():
+                print(f"🎯 Queuing evaluation for chess model: {repo_name}")
+                background_tasks.add_task(run_auto_evaluation, repo_name)
+                return {"status": "evaluation_queued", "model": repo_name}
+            else:
+                print(f"⏭️ Skipping non-chess model: {repo_name}")
+
+    return {"status": "ignored"}
+
+
+@fastapi_app.get("/health")
+async def health_check():
+    """Health check endpoint."""
+    return {"status": "healthy", "organization": ORGANIZATION}
+
+
+# Mount the Gradio app onto the FastAPI app
+fastapi_app = gr.mount_gradio_app(fastapi_app, demo, path="/")
+
 if __name__ == "__main__":
-    demo.launch()
+    import uvicorn
+    uvicorn.run(fastapi_app, host="0.0.0.0", port=7860)
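A quick way to exercise the new endpoint locally (a sketch, not part of the commit: the payload below only mirrors the fields `handle_webhook` reads, not the full Hugging Face webhook schema, and `localhost:7860` assumes the default port used above):

    import hashlib
    import hmac
    import json
    import urllib.request

    secret = "secret"  # must match the WEBHOOK_SECRET env var of the running app
    body = json.dumps({
        "event": {"action": "create"},
        "repo": {"type": "model", "name": "LLM-course/my-chess-model"},  # hypothetical repo
    }).encode()

    # Sign the raw body exactly the way verify_webhook_signature() recomputes it
    digest = hmac.new(secret.encode(), body, hashlib.sha256).hexdigest()
    req = urllib.request.Request(
        "http://localhost:7860/webhook",
        data=body,
        headers={"X-Webhook-Signature": f"sha256={digest}", "Content-Type": "application/json"},
    )
    print(urllib.request.urlopen(req).read())  # expect {"status": "evaluation_queued", ...}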
requirements.txt CHANGED
@@ -5,3 +5,5 @@ python-chess>=1.999
 huggingface-hub>=0.20.0
 datasets>=2.14.0
 pandas>=2.0.0
+fastapi
+uvicorn
src/__init__.py ADDED
@@ -0,0 +1,13 @@
+"""Chess Challenge source module."""
+
+from .model import ChessConfig, ChessForCausalLM
+from .tokenizer import ChessTokenizer
+from .evaluate import ChessEvaluator, load_model_from_hub
+
+__all__ = [
+    "ChessConfig",
+    "ChessForCausalLM",
+    "ChessTokenizer",
+    "ChessEvaluator",
+    "load_model_from_hub",
+]
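With these package exports in place, downstream code can pull everything from one import; a minimal sketch (run from the repo root, with an untrained model and placeholder vocabulary):

    from src import ChessConfig, ChessForCausalLM, ChessTokenizer

    config = ChessConfig()            # defaults documented in src/model.py
    model = ChessForCausalLM(config)
    tokenizer = ChessTokenizer()      # placeholder vocab; build one from data for real use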
src/evaluate.py ADDED
@@ -0,0 +1,619 @@
+"""
+Evaluation script for the Chess Challenge.
+
+This script evaluates a trained chess model by playing games against
+Stockfish and computing ELO ratings.
+"""
+
+from __future__ import annotations
+
+import argparse
+import math
+import random
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import torch
+
+
+@dataclass
+class GameResult:
+    """Result of a single game."""
+    moves: List[str]
+    result: str  # "1-0", "0-1", or "1/2-1/2"
+    model_color: str  # "white" or "black"
+    termination: str  # "checkmate", "stalemate", "illegal_move", "max_moves", etc.
+    illegal_move_count: int
+
+
+class ChessEvaluator:
+    """
+    Evaluator for chess models.
+
+    This class handles playing games between a trained model and Stockfish,
+    tracking results, and computing ELO ratings.
+    """
+
+    def __init__(
+        self,
+        model,
+        tokenizer,
+        stockfish_path: Optional[str] = None,
+        stockfish_level: int = 1,
+        max_retries: int = 3,
+        device: str = "cuda" if torch.cuda.is_available() else "cpu",
+    ):
+        """
+        Initialize the evaluator.
+
+        Args:
+            model: The trained chess model.
+            tokenizer: The chess tokenizer.
+            stockfish_path: Path to the Stockfish executable.
+            stockfish_level: Stockfish skill level (0-20).
+            max_retries: Maximum retries for illegal moves.
+            device: Device to run the model on.
+        """
+        self.model = model.to(device)
+        self.tokenizer = tokenizer
+        self.max_retries = max_retries
+        self.device = device
+
+        # Initialize Stockfish
+        try:
+            import chess
+            import chess.engine
+
+            self.chess = chess
+
+            if stockfish_path is None:
+                # Fall back to whatever is on PATH
+                import shutil
+                stockfish_path = shutil.which("stockfish")
+
+            if stockfish_path:
+                self.engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
+                self.engine.configure({"Skill Level": stockfish_level})
+            else:
+                print("WARNING: Stockfish not found. Install it for full evaluation.")
+                self.engine = None
+
+        except ImportError:
+            raise ImportError(
+                "python-chess is required for evaluation. "
+                "Install it with: pip install python-chess"
+            )
+
+    def __del__(self):
+        """Clean up the Stockfish engine."""
+        if hasattr(self, "engine") and self.engine:
+            self.engine.quit()
+
+    def _convert_board_to_moves(self, board) -> str:
+        """Convert the board's move history to the model's input format."""
+        moves = []
+        temp_board = self.chess.Board()
+
+        for move in board.move_stack:
+            # Get piece and color
+            color = "W" if temp_board.turn == self.chess.WHITE else "B"
+            piece = temp_board.piece_at(move.from_square)
+            piece_letter = piece.symbol().upper() if piece else "P"
+
+            # Get squares
+            from_sq = self.chess.square_name(move.from_square)
+            to_sq = self.chess.square_name(move.to_square)
+
+            move_str = f"{color}{piece_letter}{from_sq}{to_sq}"
+
+            # Add promotion
+            if move.promotion:
+                move_str += f"={self.chess.piece_symbol(move.promotion).upper()}"
+
+            # Add capture suffix
+            if temp_board.is_capture(move):
+                move_str += "(x)"
+
+            # Add check/checkmate suffix
+            temp_board.push(move)
+            if temp_board.is_checkmate():
+                move_str = move_str.replace("(x)", "(x+*)") if "(x)" in move_str else move_str + "(+*)"
+            elif temp_board.is_check():
+                move_str = move_str.replace("(x)", "(x+)") if "(x)" in move_str else move_str + "(+)"
+
+            # Handle castling
+            if piece_letter == "K" and abs(ord(from_sq[0]) - ord(to_sq[0])) > 1:
+                if to_sq[0] == "g":  # Kingside
+                    move_str = move_str.split("(")[0] + "(o)"
+                else:  # Queenside
+                    move_str = move_str.split("(")[0] + "(O)"
+
+            moves.append(move_str)
+
+        return " ".join(moves)
+
+    def _get_model_move(
+        self,
+        board,
+        temperature: float = 0.7,
+        top_k: int = 10,
+    ) -> Tuple[Optional[str], int]:
+        """
+        Get the model's next move prediction.
+
+        Returns:
+            Tuple of (UCI move string or None, number of retries used).
+        """
+        self.model.eval()
+
+        # Convert the board to the input format
+        moves_str = self._convert_board_to_moves(board)
+
+        # Add the BOS token, alone if there are no moves yet
+        if not moves_str:
+            input_text = self.tokenizer.bos_token
+        else:
+            input_text = self.tokenizer.bos_token + " " + moves_str
+
+        # Tokenize
+        inputs = self.tokenizer(
+            input_text,
+            return_tensors="pt",
+            truncation=True,
+            max_length=self.model.config.n_ctx - 1,
+        ).to(self.device)
+
+        # Run the model once; retries only re-sample from the same logits,
+        # with already-tried tokens masked out.
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+        logits = outputs.logits[:, -1, :] / temperature
+
+        # Apply top-k filtering
+        if top_k > 0:
+            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+            logits[indices_to_remove] = float("-inf")
+
+        # Try to generate a legal move
+        for retry in range(self.max_retries):
+            # Sample
+            probs = torch.softmax(logits, dim=-1)
+            next_token = torch.multinomial(probs, num_samples=1)
+
+            # Decode the move
+            move_token = self.tokenizer.decode(next_token[0])
+
+            # Convert to UCI (tokens look like "WPe2e4", optionally with "=Q" and suffixes)
+            if len(move_token) >= 6:
+                uci_move = move_token[2:6]
+
+                # Handle promotion
+                if "=" in move_token:
+                    promo_idx = move_token.index("=")
+                    uci_move += move_token[promo_idx + 1].lower()
+
+                try:
+                    move = self.chess.Move.from_uci(uci_move)
+                    if move in board.legal_moves:
+                        return uci_move, retry
+                except (ValueError, self.chess.InvalidMoveError):
+                    pass
+
+            # Mask out the tried token so the next retry samples a different one
+            logits[0, next_token[0]] = float("-inf")
+
+        return None, self.max_retries
+
+    def _get_stockfish_move(self, board, time_limit: float = 0.1) -> str:
+        """Get Stockfish's move."""
+        if self.engine is None:
+            raise RuntimeError("Stockfish engine not initialized")
+
+        result = self.engine.play(board, self.chess.engine.Limit(time=time_limit))
+        return result.move.uci()
+
+    def play_game(
+        self,
+        model_color: str = "white",
+        max_moves: int = 200,
+        temperature: float = 0.7,
+    ) -> GameResult:
+        """
+        Play a single game between the model and Stockfish.
+
+        Args:
+            model_color: "white" or "black".
+            max_moves: Maximum number of moves before declaring a draw.
+            temperature: Sampling temperature for the model.
+
+        Returns:
+            GameResult with the game details.
+        """
+        board = self.chess.Board()
+        moves = []
+        illegal_move_count = 0
+
+        model_is_white = model_color == "white"
+
+        while not board.is_game_over() and len(moves) < max_moves:
+            is_model_turn = (board.turn == self.chess.WHITE) == model_is_white
+
+            if is_model_turn:
+                # Model's turn
+                uci_move, retries = self._get_model_move(board, temperature)
+                illegal_move_count += retries
+
+                if uci_move is None:
+                    # Model couldn't find a legal move: it forfeits
+                    return GameResult(
+                        moves=moves,
+                        result="0-1" if model_is_white else "1-0",
+                        model_color=model_color,
+                        termination="illegal_move",
+                        illegal_move_count=illegal_move_count + 1,
+                    )
+
+                move = self.chess.Move.from_uci(uci_move)
+            else:
+                # Stockfish's turn
+                if self.engine:
+                    uci_move = self._get_stockfish_move(board)
+                    move = self.chess.Move.from_uci(uci_move)
+                else:
+                    # Random move if no engine is available
+                    move = random.choice(list(board.legal_moves))
+
+            board.push(move)
+            moves.append(move.uci())
+
+        # Determine the result
+        if board.is_checkmate():
+            if board.turn == self.chess.WHITE:
+                result = "0-1"  # Black wins
+            else:
+                result = "1-0"  # White wins
+            termination = "checkmate"
+        elif board.is_stalemate():
+            result = "1/2-1/2"
+            termination = "stalemate"
+        elif board.is_insufficient_material():
+            result = "1/2-1/2"
+            termination = "insufficient_material"
+        elif board.can_claim_draw():
+            result = "1/2-1/2"
+            termination = "draw_claim"
+        elif len(moves) >= max_moves:
+            result = "1/2-1/2"
+            termination = "max_moves"
+        else:
+            result = "1/2-1/2"
+            termination = "unknown"
+
+        return GameResult(
+            moves=moves,
+            result=result,
+            model_color=model_color,
+            termination=termination,
+            illegal_move_count=illegal_move_count,
+        )
+
+    def evaluate_legal_moves(
+        self,
+        n_positions: int = 1000,
+        temperature: float = 0.7,
+        verbose: bool = True,
+    ) -> dict:
+        """
+        Evaluate the model's ability to generate legal moves.
+
+        This evaluation only checks whether the model generates legal moves,
+        without playing full games. Useful as a first-pass evaluation.
+
+        Args:
+            n_positions: Number of positions to test.
+            temperature: Sampling temperature.
+            verbose: Whether to print progress.
+
+        Returns:
+            Dictionary with legal move statistics.
+        """
+        results = {
+            "total_positions": 0,
+            "legal_first_try": 0,
+            "legal_with_retry": 0,
+            "illegal_all_retries": 0,
+            "positions": [],
+        }
+
+        # Generate random positions by playing random moves
+        for i in range(n_positions):
+            board = self.chess.Board()
+
+            # Play a random number of moves (5-40) to get varied positions
+            n_random_moves = random.randint(5, 40)
+            for _ in range(n_random_moves):
+                if board.is_game_over():
+                    break
+                move = random.choice(list(board.legal_moves))
+                board.push(move)
+
+            if board.is_game_over():
+                continue  # Skip terminal positions
+
+            results["total_positions"] += 1
+
+            # Test the model's move generation
+            uci_move, retries = self._get_model_move(board, temperature)
+
+            position_result = {
+                "fen": board.fen(),
+                "move_number": len(board.move_stack),
+                "legal": uci_move is not None,
+                "retries": retries,
+            }
+            results["positions"].append(position_result)
+
+            if uci_move is not None:
+                if retries == 0:
+                    results["legal_first_try"] += 1
+                else:
+                    results["legal_with_retry"] += 1
+            else:
+                results["illegal_all_retries"] += 1
+
+            if verbose and (i + 1) % 100 == 0:
+                legal_rate = (results["legal_first_try"] + results["legal_with_retry"]) / results["total_positions"]
+                print(f"  Positions: {i + 1}/{n_positions} | Legal rate: {legal_rate:.1%}")
+
+        # Calculate statistics
+        total = results["total_positions"]
+        if total > 0:
+            results["legal_rate_first_try"] = results["legal_first_try"] / total
+            results["legal_rate_with_retry"] = (results["legal_first_try"] + results["legal_with_retry"]) / total
+            results["illegal_rate"] = results["illegal_all_retries"] / total
+        else:
+            results["legal_rate_first_try"] = 0
+            results["legal_rate_with_retry"] = 0
+            results["illegal_rate"] = 1
+
+        return results
+
+    def evaluate(
+        self,
+        n_games: int = 100,
+        temperature: float = 0.7,
+        verbose: bool = True,
+    ) -> dict:
+        """
+        Run a full win-rate evaluation of the model against Stockfish.
+
+        Args:
+            n_games: Number of games to play.
+            temperature: Sampling temperature.
+            verbose: Whether to print progress.
+
+        Returns:
+            Dictionary with evaluation metrics.
+        """
+        results = {
+            "wins": 0,
+            "losses": 0,
+            "draws": 0,
+            "illegal_moves": 0,
+            "total_moves": 0,
+            "games": [],
+        }
+
+        for i in range(n_games):
+            # Alternate colors
+            model_color = "white" if i % 2 == 0 else "black"
+
+            game = self.play_game(
+                model_color=model_color,
+                temperature=temperature,
+            )
+
+            results["games"].append(game)
+            results["total_moves"] += len(game.moves)
+            results["illegal_moves"] += game.illegal_move_count
+
+            # Count the result
+            if game.result == "1/2-1/2":
+                results["draws"] += 1
+            elif (game.result == "1-0" and model_color == "white") or \
+                 (game.result == "0-1" and model_color == "black"):
+                results["wins"] += 1
+            else:
+                results["losses"] += 1
+
+            if verbose and (i + 1) % 10 == 0:
+                print(f"  Games: {i + 1}/{n_games} | "
+                      f"W: {results['wins']} L: {results['losses']} D: {results['draws']}")
+
+        # Calculate statistics
+        total = results["wins"] + results["losses"] + results["draws"]
+        results["win_rate"] = results["wins"] / total if total > 0 else 0
+        results["draw_rate"] = results["draws"] / total if total > 0 else 0
+        results["loss_rate"] = results["losses"] / total if total > 0 else 0
+
+        total_attempts = results["total_moves"] + results["illegal_moves"]
+
+        # Average length counts both legal moves and illegal attempts so early illegal
+        # terminations don't show up as near-zero-length games.
+        results["avg_game_length"] = total_attempts / total if total > 0 else 0
+
+        # Illegal move rate: illegal attempts over total attempts
+        results["illegal_move_rate"] = results["illegal_moves"] / total_attempts if total_attempts > 0 else 0
+
+        # Estimate ELO (simplified)
+        # Stockfish Level 1 is approximately 1350 ELO
+        stockfish_elo = 1350
+        if results["win_rate"] > 0 or results["loss_rate"] > 0:
+            score_ratio = (results["wins"] + 0.5 * results["draws"]) / total
+            if 0 < score_ratio < 1:
+                # Invert the ELO expected-score formula
+                # E = 1 / (1 + 10 ** ((R_opp - R) / 400)),
+                # clamping the estimate to +/-400 around the opponent's rating.
+                elo_diff = -400 * math.log10(1 / score_ratio - 1)
+                elo_diff = max(-400.0, min(400.0, elo_diff))
+            else:
+                elo_diff = 400.0 if score_ratio >= 1 else -400.0
+            results["estimated_elo"] = stockfish_elo + elo_diff
+        else:
+            results["estimated_elo"] = None
+
+        return results
+
+
+def load_model_from_hub(model_id: str, device: str = "auto"):
+    """
+    Load a model from the Hugging Face Hub.
+
+    Args:
+        model_id: Model ID on the Hugging Face Hub.
+        device: Device to load the model on.
+
+    Returns:
+        Tuple of (model, tokenizer).
+    """
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    # Import to register the custom classes (handle both package and script usage)
+    try:
+        from src.model import ChessConfig, ChessForCausalLM
+    except ImportError:
+        from .model import ChessConfig, ChessForCausalLM
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        trust_remote_code=True,
+        device_map=device,
+    )
+
+    return model, tokenizer
+
+
+def main():
+    """Main evaluation function."""
+    parser = argparse.ArgumentParser(description="Evaluate a chess model")
+
+    parser.add_argument(
+        "--model_path", type=str, required=True,
+        help="Path to the model or Hugging Face model ID"
+    )
+    parser.add_argument(
+        "--mode", type=str, default="both", choices=["legal", "winrate", "both"],
+        help="Evaluation mode: 'legal' for legal move rate, 'winrate' for games, 'both' for both"
+    )
+    parser.add_argument(
+        "--stockfish_path", type=str, default=None,
+        help="Path to the Stockfish executable"
+    )
+    parser.add_argument(
+        "--stockfish_level", type=int, default=1,
+        help="Stockfish skill level (0-20)"
+    )
+    parser.add_argument(
+        "--n_positions", type=int, default=500,
+        help="Number of positions for legal move evaluation"
+    )
+    parser.add_argument(
+        "--n_games", type=int, default=100,
+        help="Number of games to play for win rate evaluation"
+    )
+    parser.add_argument(
+        "--temperature", type=float, default=0.7,
+        help="Sampling temperature"
+    )
+
+    args = parser.parse_args()
+
+    print("=" * 60)
+    print("CHESS CHALLENGE - EVALUATION")
+    print("=" * 60)
+
+    # Load the model
+    print(f"\nLoading model from: {args.model_path}")
+
+    if "/" in args.model_path and not args.model_path.startswith("."):
+        # Assume a Hugging Face model ID
+        model, tokenizer = load_model_from_hub(args.model_path)
+    else:
+        # Local path
+        from transformers import AutoModelForCausalLM
+        try:
+            from src.tokenizer import ChessTokenizer
+            from src.model import ChessConfig, ChessForCausalLM  # registers the custom classes
+        except ImportError:
+            from .tokenizer import ChessTokenizer
+            from .model import ChessConfig, ChessForCausalLM
+
+        tokenizer = ChessTokenizer.from_pretrained(args.model_path)
+        model = AutoModelForCausalLM.from_pretrained(args.model_path)
+
+    # Create the evaluator
+    print("\nSetting up evaluator...")
+    evaluator = ChessEvaluator(
+        model=model,
+        tokenizer=tokenizer,
+        stockfish_path=args.stockfish_path,
+        stockfish_level=args.stockfish_level,
+    )
+
+    # Run the legal move evaluation
+    if args.mode in ["legal", "both"]:
+        print("\n" + "=" * 60)
+        print("PHASE 1: LEGAL MOVE EVALUATION")
+        print("=" * 60)
+        print(f"Testing {args.n_positions} random positions...")
+
+        legal_results = evaluator.evaluate_legal_moves(
+            n_positions=args.n_positions,
+            temperature=args.temperature,
+            verbose=True,
+        )
+
+        print("\n" + "-" * 40)
+        print("LEGAL MOVE RESULTS")
+        print("-" * 40)
+        print(f"  Positions tested:   {legal_results['total_positions']}")
+        print(f"  Legal (1st try):    {legal_results['legal_first_try']} ({legal_results['legal_rate_first_try']:.1%})")
+        print(f"  Legal (with retry): {legal_results['legal_first_try'] + legal_results['legal_with_retry']} ({legal_results['legal_rate_with_retry']:.1%})")
+        print(f"  Always illegal:     {legal_results['illegal_all_retries']} ({legal_results['illegal_rate']:.1%})")
+
+    # Run the win rate evaluation
+    if args.mode in ["winrate", "both"]:
+        print("\n" + "=" * 60)
+        print("PHASE 2: WIN RATE EVALUATION")
+        print("=" * 60)
+        print(f"Playing {args.n_games} games against Stockfish (Level {args.stockfish_level})...")
+
+        winrate_results = evaluator.evaluate(
+            n_games=args.n_games,
+            temperature=args.temperature,
+            verbose=True,
+        )
+
+        print("\n" + "-" * 40)
+        print("WIN RATE RESULTS")
+        print("-" * 40)
+        print(f"  Wins:   {winrate_results['wins']}")
+        print(f"  Losses: {winrate_results['losses']}")
+        print(f"  Draws:  {winrate_results['draws']}")
+        print(f"\n  Win Rate:  {winrate_results['win_rate']:.1%}")
+        print(f"  Draw Rate: {winrate_results['draw_rate']:.1%}")
+        print(f"  Loss Rate: {winrate_results['loss_rate']:.1%}")
+        print(f"\n  Avg Game Length:   {winrate_results['avg_game_length']:.1f} moves")
+        print(f"  Illegal Move Rate: {winrate_results['illegal_move_rate']:.2%}")
+
+        if winrate_results["estimated_elo"]:
+            print(f"\n  Estimated ELO: {winrate_results['estimated_elo']:.0f}")
+
+    print("\n" + "=" * 60)
+    print("EVALUATION COMPLETE")
+    print("=" * 60)
+
+
+if __name__ == "__main__":
+    main()
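The `estimated_elo` computation in `evaluate()` inverts the standard ELO expected-score curve; a standalone check of the numbers (the helper below is illustrative, not part of the commit):

    import math

    def elo_from_score(opponent_elo: float, score_ratio: float) -> float:
        # E = 1 / (1 + 10 ** ((R_opp - R) / 400))  =>  R = R_opp - 400 * log10(1/E - 1)
        return opponent_elo - 400 * math.log10(1 / score_ratio - 1)

    print(elo_from_score(1350, 0.50))  # 1350.0 - scoring 50% means equal strength
    print(elo_from_score(1350, 0.75))  # ~1540.8
    print(elo_from_score(1350, 0.25))  # ~1159.2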
src/model.py ADDED
@@ -0,0 +1,437 @@
+"""
+Chess Transformer Model for the Chess Challenge.
+
+This module provides a simple GPT-style transformer architecture
+designed to fit within the 1M parameter constraint.
+
+Key components:
+- ChessConfig: Configuration class for model hyperparameters
+- ChessForCausalLM: The main model class for next-move prediction
+"""
+
+from __future__ import annotations
+
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+
+class ChessConfig(PretrainedConfig):
+    """
+    Configuration class for the Chess Transformer model.
+
+    This configuration is designed for a ~1M parameter model.
+    Students can adjust these values to explore different architectures.
+
+    Parameter budget breakdown (with default values, biases included):
+    - Token embeddings: 1200 x 128 = 153,600
+    - Position embeddings: 256 x 128 = 32,768
+    - Transformer layers: 6 x ~165,000 = ~992,000
+    - LM head (with weight tying): 0 (shared with embeddings)
+    - Total: ~1,180,000 parameters
+
+    Attributes:
+        vocab_size: Size of the vocabulary (number of unique moves).
+        n_embd: Embedding dimension (d_model).
+        n_layer: Number of transformer layers.
+        n_head: Number of attention heads.
+        n_ctx: Maximum sequence length (context window).
+        n_inner: Feed-forward inner dimension (default: 3 * n_embd).
+        dropout: Dropout probability.
+        layer_norm_epsilon: Epsilon for layer normalization.
+        tie_weights: Whether to tie embedding and output weights.
+    """
+
+    model_type = "chess_transformer"
+
+    def __init__(
+        self,
+        vocab_size: int = 1200,
+        n_embd: int = 128,
+        n_layer: int = 6,
+        n_head: int = 4,
+        n_ctx: int = 256,
+        n_inner: Optional[int] = None,
+        dropout: float = 0.1,
+        layer_norm_epsilon: float = 1e-5,
+        tie_weights: bool = True,
+        pad_token_id: int = 0,
+        bos_token_id: int = 1,
+        eos_token_id: int = 2,
+        **kwargs,
+    ):
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            **kwargs,
+        )
+
+        self.vocab_size = vocab_size
+        self.n_embd = n_embd
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.n_ctx = n_ctx
+        self.n_inner = n_inner if n_inner is not None else 3 * n_embd  # Reduced from 4x to 3x
+        self.dropout = dropout
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.tie_weights = tie_weights
+        # Inform the HF base class about the tying behavior
+        self.tie_word_embeddings = bool(tie_weights)
+
+
+class MultiHeadAttention(nn.Module):
+    """
+    Multi-head self-attention module.
+
+    This is a standard scaled dot-product attention implementation
+    with causal masking for autoregressive generation.
+    """
+
+    def __init__(self, config: ChessConfig):
+        super().__init__()
+
+        assert config.n_embd % config.n_head == 0, \
+            f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"
+
+        self.n_head = config.n_head
+        self.n_embd = config.n_embd
+        self.head_dim = config.n_embd // config.n_head
+
+        # Combined QKV projection for efficiency
+        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+
+        self.dropout = nn.Dropout(config.dropout)
+
+        # Causal mask, registered as a non-persistent buffer
+        self.register_buffer(
+            "bias",
+            torch.tril(torch.ones(config.n_ctx, config.n_ctx)).view(
+                1, 1, config.n_ctx, config.n_ctx
+            ),
+            persistent=False,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size, seq_len, _ = x.size()
+
+        # Compute Q, K, V
+        qkv = self.c_attn(x)
+        q, k, v = qkv.split(self.n_embd, dim=2)
+
+        # Reshape for multi-head attention
+        q = q.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
+        k = k.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
+        v = v.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
+
+        # Scaled dot-product attention
+        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+
+        # Apply the causal mask
+        causal_mask = self.bias[:, :, :seq_len, :seq_len]
+        attn_weights = attn_weights.masked_fill(causal_mask == 0, float("-inf"))
+
+        # Apply the attention mask (for padding)
+        if attention_mask is not None:
+            # attention_mask shape: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len)
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            attn_weights = attn_weights.masked_fill(attention_mask == 0, float("-inf"))
+
+        attn_weights = F.softmax(attn_weights, dim=-1)
+        attn_weights = self.dropout(attn_weights)
+
+        # Apply attention to the values
+        attn_output = torch.matmul(attn_weights, v)
+
+        # Reshape back
+        attn_output = attn_output.transpose(1, 2).contiguous().view(
+            batch_size, seq_len, self.n_embd
+        )
+
+        # Output projection
+        attn_output = self.c_proj(attn_output)
+
+        return attn_output
+
+
+class FeedForward(nn.Module):
+    """
+    Feed-forward network (MLP) module.
+
+    Standard two-layer MLP with GELU activation.
+    """
+
+    def __init__(self, config: ChessConfig):
+        super().__init__()
+
+        self.c_fc = nn.Linear(config.n_embd, config.n_inner)
+        self.c_proj = nn.Linear(config.n_inner, config.n_embd)
+        self.dropout = nn.Dropout(config.dropout)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.c_fc(x)
+        x = F.gelu(x)
+        x = self.c_proj(x)
+        x = self.dropout(x)
+        return x
+
+
+class TransformerBlock(nn.Module):
+    """
+    A single transformer block with attention and feed-forward layers.
+
+    Uses pre-normalization (LayerNorm before attention/FFN) for better
+    training stability.
+    """
+
+    def __init__(self, config: ChessConfig):
+        super().__init__()
+
+        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.attn = MultiHeadAttention(config)
+        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.mlp = FeedForward(config)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # Pre-norm attention
+        x = x + self.attn(self.ln_1(x), attention_mask=attention_mask)
+        # Pre-norm FFN
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class ChessForCausalLM(PreTrainedModel):
+    """
+    Chess Transformer for Causal Language Modeling (next-move prediction).
+
+    This model is designed to predict the next chess move given a sequence
+    of previous moves. It uses a GPT-style architecture with:
+    - Token embeddings for chess moves
+    - Learned positional embeddings
+    - Stacked transformer blocks
+    - Linear head for next-token prediction
+
+    The model supports weight tying between the embedding layer and the
+    output projection to save parameters.
+
+    Example:
+        >>> config = ChessConfig(vocab_size=1200, n_embd=128, n_layer=6)
+        >>> model = ChessForCausalLM(config)
+        >>> inputs = {"input_ids": torch.tensor([[1, 42, 87]])}
+        >>> outputs = model(**inputs)
+        >>> next_move_logits = outputs.logits[:, -1, :]
+    """
+
+    config_class = ChessConfig
+    base_model_prefix = "transformer"
+    supports_gradient_checkpointing = True
+    # Suppress the missing-key warning for the tied lm_head when loading
+    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+
+    def __init__(self, config: ChessConfig):
+        super().__init__(config)
+
+        # Token and position embeddings
+        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
+        self.wpe = nn.Embedding(config.n_ctx, config.n_embd)
+
+        self.drop = nn.Dropout(config.dropout)
+
+        # Transformer blocks
+        self.h = nn.ModuleList([
+            TransformerBlock(config) for _ in range(config.n_layer)
+        ])
+
+        # Final layer norm
+        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+
+        # Output head
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+        # Declare tied weights for proper serialization
+        if config.tie_weights:
+            self._tied_weights_keys = ["lm_head.weight"]
+
+        # Initialize weights
+        self.post_init()
+
+        # Tie weights if configured
+        if config.tie_weights:
+            self.tie_weights()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.wte
+
+    def set_input_embeddings(self, new_embeddings: nn.Module):
+        self.wte = new_embeddings
+        if getattr(self.config, "tie_weights", False):
+            self.tie_weights()
+
+    def get_output_embeddings(self) -> nn.Module:
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings: nn.Module):
+        self.lm_head = new_embeddings
+
+    def tie_weights(self):
+        # Use the HF helper to tie or clone depending on config
+        if getattr(self.config, "tie_weights", False) or getattr(self.config, "tie_word_embeddings", False):
+            self._tie_or_clone_weights(self.lm_head, self.wte)
+
+    def _init_weights(self, module: nn.Module):
+        """Initialize weights following GPT-2 style."""
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+        elif isinstance(module, nn.LayerNorm):
+            torch.nn.init.ones_(module.weight)
+            torch.nn.init.zeros_(module.bias)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        """
+        Forward pass of the model.
+
+        Args:
+            input_ids: Token IDs of shape (batch_size, seq_len).
+            attention_mask: Attention mask of shape (batch_size, seq_len).
+            position_ids: Position IDs of shape (batch_size, seq_len).
+            labels: Labels for the language modeling loss.
+            return_dict: Whether to return a ModelOutput object.
+
+        Returns:
+            CausalLMOutputWithPast containing loss (if labels are provided) and logits.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        batch_size, seq_len = input_ids.size()
+        device = input_ids.device
+
+        # Create position IDs if not provided
+        if position_ids is None:
+            position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
+
+        # Get embeddings
+        token_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = self.drop(token_embeds + position_embeds)
+
+        # Pass through the transformer blocks
+        for block in self.h:
+            hidden_states = block(hidden_states, attention_mask=attention_mask)
+
+        # Final layer norm
+        hidden_states = self.ln_f(hidden_states)
+
+        # Get logits
+        logits = self.lm_head(hidden_states)
+
+        # Compute the loss if labels are provided
+        loss = None
+        if labels is not None:
+            # Shift logits and labels for next-token prediction
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+
+            # Flatten for cross-entropy
+            loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
+            loss = loss_fct(
+                shift_logits.view(-1, shift_logits.size(-1)),
+                shift_labels.view(-1),
+            )
+
+        if not return_dict:
+            output = (logits,)
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=None,
+            hidden_states=None,
+            attentions=None,
+        )
+
+    @torch.no_grad()
+    def generate_move(
+        self,
+        input_ids: torch.LongTensor,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+    ) -> int:
+        """
+        Generate the next move given a sequence of moves.
+
+        Args:
+            input_ids: Token IDs of shape (1, seq_len).
+            temperature: Sampling temperature (1.0 = no change).
+            top_k: If set, only sample from the top k tokens.
+            top_p: If set, use nucleus sampling with this threshold.
+
+        Returns:
+            The token ID of the predicted next move.
+        """
+        self.eval()
+
+        # Get logits for the last position
+        outputs = self(input_ids)
+        logits = outputs.logits[:, -1, :] / temperature
+
+        # Apply top-k filtering
+        if top_k is not None:
+            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+            logits[indices_to_remove] = float("-inf")
+
+        # Apply top-p (nucleus) filtering
+        if top_p is not None:
+            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+            # Remove tokens with cumulative probability above the threshold
+            sorted_indices_to_remove = cumulative_probs > top_p
+            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+            sorted_indices_to_remove[..., 0] = 0
+
+            indices_to_remove = sorted_indices_to_remove.scatter(
+                dim=-1, index=sorted_indices, src=sorted_indices_to_remove
+            )
+            logits[indices_to_remove] = float("-inf")
+
+        # Sample from the distribution
+        probs = F.softmax(logits, dim=-1)
+        next_token = torch.multinomial(probs, num_samples=1)
+
+        return next_token.item()
+
+
+# Register the model with the Auto classes for easy loading
+from transformers import AutoConfig, AutoModelForCausalLM
+
+AutoConfig.register("chess_transformer", ChessConfig)
+AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
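Since the challenge imposes a parameter budget, it is worth counting parameters empirically rather than trusting docstring arithmetic; a sketch (run from the repo root, default config):

    from src.model import ChessConfig, ChessForCausalLM

    model = ChessForCausalLM(ChessConfig())  # defaults: 128-dim, 6 layers, tied head
    n_params = sum(p.numel() for p in model.parameters())  # tied weights counted once
    print(f"{n_params:,} parameters")  # compare against the budget in the ChessConfig docstring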
src/tokenizer.py ADDED
@@ -0,0 +1,278 @@
+"""
+Custom Chess Tokenizer for the Chess Challenge.
+
+This tokenizer treats each move as a single token using the extended UCI notation
+from the Lichess dataset (e.g., WPe2e4, BNg8f6).
+
+The dataset format uses:
+- W/B prefix for White/Black
+- Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King
+- Source and destination squares (e.g., e2e4)
+- Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling
+"""
+
+from __future__ import annotations
+
+import json
+import os
+from typing import Dict, List, Optional
+
+from transformers import PreTrainedTokenizer
+
+
+class ChessTokenizer(PreTrainedTokenizer):
+    """
+    A custom tokenizer for chess moves using extended UCI notation.
+
+    This tokenizer maps each possible chess move to a unique token ID.
+    The vocabulary is built from the training dataset to ensure all moves
+    encountered during training have a corresponding token.
+
+    Example:
+        >>> tokenizer = ChessTokenizer(vocab=vocab)
+        >>> tokenizer.encode("WPe2e4 BPe7e5")
+        [42, 87]  # one ID per move; special tokens are not added automatically
+    """
+
+    model_input_names = ["input_ids", "attention_mask"]
+    vocab_files_names = {"vocab_file": "vocab.json"}
+
+    # Special tokens
+    PAD_TOKEN = "[PAD]"
+    BOS_TOKEN = "[BOS]"
+    EOS_TOKEN = "[EOS]"
+    UNK_TOKEN = "[UNK]"
+
+    def __init__(
+        self,
+        vocab_file: Optional[str] = None,
+        vocab: Optional[Dict[str, int]] = None,
+        **kwargs,
+    ):
+        """
+        Initialize the chess tokenizer.
+
+        Args:
+            vocab_file: Path to a JSON file containing the vocabulary mapping.
+            vocab: Dictionary mapping tokens to IDs (alternative to vocab_file).
+            **kwargs: Additional arguments passed to PreTrainedTokenizer.
+        """
+        # Initialize the special tokens
+        self._pad_token = self.PAD_TOKEN
+        self._bos_token = self.BOS_TOKEN
+        self._eos_token = self.EOS_TOKEN
+        self._unk_token = self.UNK_TOKEN
+
+        # Remove any duplicate special-token entries passed through kwargs
+        # to avoid "multiple values for keyword" errors when loading from disk.
+        kwargs.pop("pad_token", None)
+        kwargs.pop("bos_token", None)
+        kwargs.pop("eos_token", None)
+        kwargs.pop("unk_token", None)
+
+        # Load or create the vocabulary
+        if vocab is not None:
+            self._vocab = vocab
+        elif vocab_file is not None and os.path.exists(vocab_file):
+            with open(vocab_file, "r", encoding="utf-8") as f:
+                self._vocab = json.load(f)
+        else:
+            # Create a minimal vocabulary with just the special tokens;
+            # the full vocabulary should be built from the dataset.
+            self._vocab = self._create_default_vocab()
+
+        # Create the reverse mapping
+        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
+
+        # Call the parent init AFTER setting up the vocab
+        super().__init__(
+            pad_token=self._pad_token,
+            bos_token=self._bos_token,
+            eos_token=self._eos_token,
+            unk_token=self._unk_token,
+            **kwargs,
+        )
+
+    def _create_default_vocab(self) -> Dict[str, int]:
+        """
+        Create a minimal default vocabulary with just the special tokens.
+
+        For the full vocabulary, use `build_vocab_from_dataset()`.
+        This minimal vocab is just a placeholder - you should build from data.
+        """
+        special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
+        vocab = {token: idx for idx, token in enumerate(special_tokens)}
+        return vocab
+
+    @classmethod
+    def build_vocab_from_iterator(
+        cls,
+        iterator,
+        min_frequency: int = 1,
+    ) -> "ChessTokenizer":
+        """
+        Build a tokenizer vocabulary from an iterator of game strings.
+
+        Args:
+            iterator: An iterator yielding game strings (space-separated moves).
+            min_frequency: Minimum frequency for a token to be included.
+
+        Returns:
+            A ChessTokenizer with the built vocabulary.
+        """
+        from collections import Counter
+
+        token_counts = Counter()
+
+        for game in iterator:
+            moves = game.strip().split()
+            token_counts.update(moves)
+
+        # Filter by frequency
+        tokens = [
+            token for token, count in token_counts.items()
+            if count >= min_frequency
+        ]
+
+        # Sort for reproducibility
+        tokens = sorted(tokens)
+
+        # Build the vocabulary
+        special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
+        vocab = {token: idx for idx, token in enumerate(special_tokens + tokens)}
+
+        return cls(vocab=vocab)
+
+    @classmethod
+    def build_vocab_from_dataset(
+        cls,
+        dataset_name: str = "dlouapre/lichess_2025-01_1M",
+        split: str = "train",
+        column: str = "text",
+        min_frequency: int = 500,
+        max_samples: Optional[int] = 100000,
+    ) -> "ChessTokenizer":
+        """
+        Build a tokenizer vocabulary from a Hugging Face dataset.
+
+        Args:
+            dataset_name: Name of the dataset on the Hugging Face Hub.
+            split: Dataset split to use.
+            column: Column containing the game strings.
+            min_frequency: Minimum frequency for a token to be included (default: 500).
+            max_samples: Maximum number of samples to process (default: 100k).
+
+        Returns:
+            A ChessTokenizer with the built vocabulary.
+        """
+        from datasets import load_dataset
+
+        dataset = load_dataset(dataset_name, split=split)
+
+        if max_samples is not None:
+            dataset = dataset.select(range(min(max_samples, len(dataset))))
+
+        def game_iterator():
+            for example in dataset:
+                yield example[column]
+
+        return cls.build_vocab_from_iterator(game_iterator(), min_frequency=min_frequency)
+
+    @property
+    def vocab_size(self) -> int:
+        """Return the size of the vocabulary."""
+        return len(self._vocab)
+
+    def get_vocab(self) -> Dict[str, int]:
+        """Return the vocabulary as a dictionary."""
+        return dict(self._vocab)
+
+    def _tokenize(self, text: str) -> List[str]:
+        """
+        Tokenize a string of moves into a list of tokens.
+
+        Args:
+            text: A string of space-separated moves.
+
+        Returns:
+            List of move tokens.
+        """
+        return text.strip().split()
+
+    def _convert_token_to_id(self, token: str) -> int:
+        """Convert a token to its ID."""
+        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))
+
+    def _convert_id_to_token(self, index: int) -> str:
+        """Convert an ID to its token."""
+        return self._ids_to_tokens.get(index, self.UNK_TOKEN)
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        """Convert a list of tokens back to a string."""
+        # Filter out special tokens for cleaner output
+        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
+        return " ".join(t for t in tokens if t not in special)
+
+    def save_vocabulary(
+        self,
+        save_directory: str,
+        filename_prefix: Optional[str] = None,
+    ) -> tuple:
+        """
+        Save the vocabulary to a JSON file.
+
+        Args:
+            save_directory: Directory to save the vocabulary in.
+            filename_prefix: Optional prefix for the filename.
+
+        Returns:
+            Tuple containing the path to the saved vocabulary file.
+        """
+        if not os.path.isdir(save_directory):
+            os.makedirs(save_directory, exist_ok=True)
+
+        vocab_file = os.path.join(
+            save_directory,
+            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
+        )
+
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
+
+        return (vocab_file,)
+
+
+def count_vocab_from_dataset(
+    dataset_name: str = "dlouapre/lichess_2025-01_1M",
+    split: str = "train",
+    column: str = "text",
+    max_samples: Optional[int] = 10000,
+) -> Dict[str, int]:
+    """
+    Count token frequencies in a dataset (useful for vocabulary analysis).
+
+    Args:
+        dataset_name: Name of the dataset on the Hugging Face Hub.
+        split: Dataset split to use.
+        column: Column containing the game strings.
+        max_samples: Maximum number of samples to process.
+
+    Returns:
+        Dictionary mapping tokens to their frequencies.
+    """
+    from collections import Counter
+    from datasets import load_dataset
+
+    dataset = load_dataset(dataset_name, split=split)
+
+    if max_samples is not None:
+        dataset = dataset.select(range(min(max_samples, len(dataset))))
+
+    token_counts = Counter()
+
+    for example in dataset:
+        moves = example[column].strip().split()
+        token_counts.update(moves)
+
+    return dict(token_counts)
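A round-trip sketch of the tokenizer on a toy corpus (the two games below are made up; real vocabularies should come from `build_vocab_from_dataset()`):

    from src.tokenizer import ChessTokenizer

    games = ["WPe2e4 BPe7e5 WNg1f3", "WPd2d4 BNg8f6"]
    tok = ChessTokenizer.build_vocab_from_iterator(games, min_frequency=1)

    ids = tok("WPe2e4 BPe7e5")["input_ids"]
    print(ids)              # one ID per move token
    print(tok.decode(ids))  # "WPe2e4 BPe7e5" - special tokens are filtered out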