/**
 * Self-Play Dataset Generation for Trigo AI
 *
 * This script generates a dataset of self-play games for training the Trigo AI model.
 * The AI plays against itself using either tree attention or MCTS-based agents.
 *
 * Features:
 * - Configurable number of games to generate
 * - Random or fixed board shapes (random pool: 2D 2x1x1 to 13x13x1, 3D 2x2x2 to 5x5x5)
 * - Temperature-based sampling for move diversity (temperature 0 = greedy)
 * - MCTS (Monte Carlo Tree Search) mode with AlphaGo Zero algorithm
 * - Automatic game termination detection (territory check after 50% coverage)
 * - TGN format output with score notation for each game
 * - Optional visit count statistics for MCTS training data
 * - Per-board-shape statistics
 * - Progress tracking
 *
 * Usage:
 *   npx tsx tools/selfPlayGames.ts [options]
 *
 * Options:
 *   --games <n>              Number of games to generate (default: 10)
 *   --output <dir>           Output directory (default: ./tools/output/selfplay)
 *   --board <shape>          Board shape "X*Y*Z" or "random" (default: "random")
 *   --temperature <t>        Sampling temperature (default: 1.0)
 *   --max-moves <n>          Maximum moves per game (default: 300)
 *   --model <path>           Path to tree model ONNX file
 *   --eval-model <path>      Path to evaluation model ONNX file (for MCTS)
 *   --verbose                Enable verbose logging
 *
 * MCTS Options:
 *   --use-mcts                     Enable MCTS for move selection
 *   --mcts-simulations <n>         MCTS simulations per move (default: 600)
 *   --mcts-cpuct <c>               PUCT exploration constant (default: 1.0)
 *   --mcts-dirichlet-alpha <a>     Dirichlet noise alpha (default: 0.03)
 *   --mcts-dirichlet-epsilon <e>   Dirichlet noise epsilon (default: 0.25)
 *   --save-visit-counts            Save visit count statistics
 *
 * Output:
 *   - Each game saved as game_<hash>.tgn (hash of the TGN content)
 *   - Dataset statistics in <timestamp>_dataset_stats.json
 *   - Visit counts saved as game_<hash>_visit_counts.json (if --save-visit-counts),
 *     sharing the hash of the matching .tgn file
 */

import * as ort from "onnxruntime-node";
import * as path from "path";
import * as fs from "fs";
import * as crypto from "crypto";
import { fileURLToPath } from "url";
import { TrigoGame, StoneType } from "../inc/trigo/game";
import { ModelInferencer } from "../inc/modelInferencer";
import { loadEnvConfig, getOnnxModelPaths, getAbsoluteModelPath, getOnnxSessionOptions } from "../inc/config";
import { TrigoTreeAgent } from "../inc/trigoTreeAgent";
import { TrigoEvaluationAgent } from "../inc/trigoEvaluationAgent";
import { MCTSAgent, type MCTSConfig } from "../inc/mctsAgent";
import type { Move, BoardShape } from "../inc/trigo/types";
import { encodeAb0yz } from "../inc/trigo/ab0yz";

// ES module equivalent of __dirname
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);

// Load environment variables
await loadEnvConfig();

// Default model paths from environment
const defaultModelPaths = getOnnxModelPaths();

// Board shape types
type BoardShapeTuple = [number, number, number];

/**
 * Generate all board shapes in a range (inclusive on every axis).
 *
 * @param min - Lower corner [x, y, z]
 * @param max - Upper corner [x, y, z]
 * @returns Every [x, y, z] with min[k] <= v[k] <= max[k], ordered by x, then y, then z
 */
const arangeShape = (min: BoardShapeTuple, max: BoardShapeTuple): BoardShapeTuple[] => {
	const result: BoardShapeTuple[] = [];
	for (let x = min[0]; x <= max[0]; x++) {
		for (let y = min[1]; y <= max[1]; y++) {
			for (let z = min[2]; z <= max[2]; z++) {
				result.push([x, y, z]);
			}
		}
	}
	return result;
};

/**
 * Candidate board shapes for random selection (picked uniformly from this list):
 * - 2D boards: X in [2, 13], Y in [1, 13], Z = 1 (2x1x1 up to 13x13x1)
 * - 3D boards: X, Y, Z in [2, 5] (2x2x2 up to 5x5x5)
 */
const CANDIDATE_BOARD_SHAPES = [
	...arangeShape([2, 1, 1], [13, 13, 1]),
	...arangeShape([2, 2, 2], [5, 5, 5]),
];

// Configuration
interface GenerationConfig {
	numGames: number;
	outputDir: string;
	temperature: number; // <= 0 means greedy (argmax) move selection for the tree agent
	maxMoves: number;
	verbose: boolean;
	modelPath: string;
	evaluationModelPath: string;
	vocabSize: number;
	seqLen: number;
	boardShape: BoardShape | "random";
	// MCTS configuration
	useMCTS: boolean;
	mctsSimulations: number;
	mctsCPuct: number;
	mctsDirichletAlpha: number;
	mctsDirichletEpsilon: number;
	saveVisitCounts: boolean;
}

// Game statistics
interface GameStats {
	gameId: number;
	boardShape: string;
	moveCount: number;
	maxMovesReached: boolean;
	duration: number; // milliseconds
	averageMoveTime: number; // milliseconds
	scoreDiff: number; // white - black (positive: white wins, negative: black wins)
}

// Board shape statistics
interface BoardShapeStats {
	boardShape: string;
	gameCount: number;
	averageScoreDiff: number;
	averageMoveCount: number;
}

// Dataset statistics
interface DatasetStats {
	totalGames: number;
	totalMoves: number;
	blackWins: number;
	whiteWins: number;
	draws: number; // games with scoreDiff === 0
	resignations: number; // self-play never resigns; kept for schema compatibility
	maxMovesReached: number;
	averageGameLength: number;
	averageMoveTime: number;
	generationTime: number; // milliseconds
	averageScoreDiff: number; // average white - black
	boardShapeStats: BoardShapeStats[]; // stats per board shape
	games: GameStats[];
}

/**
 * Arithmetic mean of `values`; 0 for an empty list so the stats never contain NaN.
 */
function average(values: number[]): number {
	return values.length > 0 ? values.reduce((sum, v) => sum + v, 0) / values.length : 0;
}

/**
 * Format `count / total` as a percentage with one decimal place ("0.0" when total is 0).
 */
function formatPercent(count: number, total: number): string {
	return total > 0 ? ((count / total) * 100).toFixed(1) : "0.0";
}

/**
 * Parse board shape string (e.g., "5*5*5" or "9*9*1").
 * Any non-digit run separates the three dimensions.
 * Special value "random" (case-insensitive) selects randomly from CANDIDATE_BOARD_SHAPES.
 *
 * @throws Error if the string does not contain exactly three dimensions, or any dimension is < 1
 */
function parseBoardShape(shapeStr: string): BoardShape | "random" {
	// Handle random selection
	if (shapeStr.toLowerCase() === "random") {
		return "random";
	}

	// Parse explicit board shape
	const parts = shapeStr.split(/[^0-9]+/).filter(Boolean).map(Number);
	if (parts.length !== 3) {
		throw new Error(`Invalid board shape: ${shapeStr}. Expected format: "X*Y*Z" or "random"`);
	}
	if (parts.some((dim) => dim < 1)) {
		throw new Error(`Invalid board shape: ${shapeStr}. Every dimension must be >= 1`);
	}

	return { x: parts[0], y: parts[1], z: parts[2] };
}

/**
 * Select a random board shape (uniformly) from CANDIDATE_BOARD_SHAPES.
 */
function selectRandomBoardShape(): BoardShape {
	const randomIndex = Math.floor(Math.random() * CANDIDATE_BOARD_SHAPES.length);
	const [x, y, z] = CANDIDATE_BOARD_SHAPES[randomIndex];
	return { x, y, z };
}

/**
 * Parse command line arguments.
 *
 * @throws Error when an option is missing its value or the value is out of range
 */
function parseArgs(): GenerationConfig {
	const args = process.argv.slice(2);
	const config: GenerationConfig = {
		numGames: 10,
		outputDir: path.join(__dirname, "output/selfplay"),
		temperature: 1.0,
		maxMoves: 300,
		verbose: false,
		modelPath: getAbsoluteModelPath(defaultModelPaths.treeModel),
		evaluationModelPath: getAbsoluteModelPath(defaultModelPaths.evaluationModel),
		vocabSize: 128,
		seqLen: 256,
		boardShape: "random",
		// MCTS defaults
		useMCTS: false,
		mctsSimulations: 600,
		mctsCPuct: 1.0,
		mctsDirichletAlpha: 0.03,
		mctsDirichletEpsilon: 0.25,
		saveVisitCounts: false
	};

	// Index of the argument currently being processed (shared with the helpers below)
	let i = 0;

	/** Consume and return the value that follows option `flag`. */
	const takeValue = (flag: string): string => {
		const value = args[i + 1];
		if (value === undefined || value.trim() === "" || value.startsWith("--")) {
			throw new Error(`Missing value for option ${flag}`);
		}
		i++;
		return value;
	};

	/** Consume the value of `flag` as an integer >= `min`. */
	const takeInt = (flag: string, min: number): number => {
		const raw = takeValue(flag);
		const value = Number(raw);
		if (!Number.isInteger(value) || value < min) {
			throw new Error(`Invalid value for ${flag}: "${raw}" (expected an integer >= ${min})`);
		}
		return value;
	};

	/** Consume the value of `flag` as a finite number in [min, max]. */
	const takeFloat = (flag: string, min: number, max = Infinity): number => {
		const raw = takeValue(flag);
		const value = Number(raw);
		if (!Number.isFinite(value) || value < min || value > max) {
			throw new Error(`Invalid value for ${flag}: "${raw}" (expected a number in [${min}, ${max}])`);
		}
		return value;
	};

	for (; i < args.length; i++) {
		const arg = args[i];
		switch (arg) {
			case "--games":
				config.numGames = takeInt(arg, 1);
				break;
			case "--output":
				config.outputDir = takeValue(arg);
				break;
			case "--temperature":
				config.temperature = takeFloat(arg, 0);
				break;
			case "--max-moves":
				config.maxMoves = takeInt(arg, 1);
				break;
			case "--board":
				config.boardShape = parseBoardShape(takeValue(arg));
				break;
			case "--model":
				config.modelPath = takeValue(arg);
				break;
			case "--eval-model":
				config.evaluationModelPath = takeValue(arg);
				break;
			case "--use-mcts":
				config.useMCTS = true;
				break;
			case "--mcts-simulations":
				config.mctsSimulations = takeInt(arg, 1);
				break;
			case "--mcts-cpuct":
				config.mctsCPuct = takeFloat(arg, 0);
				break;
			case "--mcts-dirichlet-alpha":
				config.mctsDirichletAlpha = takeFloat(arg, 0);
				break;
			case "--mcts-dirichlet-epsilon":
				config.mctsDirichletEpsilon = takeFloat(arg, 0, 1);
				break;
			case "--save-visit-counts":
				config.saveVisitCounts = true;
				break;
			case "--verbose":
				config.verbose = true;
				break;
			case "--help":
				printHelp();
				process.exit(0);
			default:
				if (arg.startsWith("--")) {
					console.error(`Unknown option: ${arg}`);
					printHelp();
					process.exit(1);
				}
		}
	}

	return config;
}

/**
 * Print help message
 */
function printHelp(): void {
	console.log(`
Usage: npx tsx tools/selfPlayGames.ts [options]

Options:
  --games <n>              Number of games to generate (default: 10)
  --output <dir>           Output directory (default: ./tools/output/selfplay)
  --board <shape>          Board shape "X*Y*Z" or "random" (default: "random")
  --temperature <t>        Sampling temperature (default: 1.0, 0 = greedy)
  --max-moves <n>          Maximum moves per game (default: 300)
  --model <path>           Path to tree model ONNX file
  --eval-model <path>      Path to evaluation model ONNX file (for MCTS)
  --verbose                Enable verbose logging

MCTS Options:
  --use-mcts                     Enable MCTS for move selection
  --mcts-simulations <n>         MCTS simulations per move (default: 600)
  --mcts-cpuct <c>               PUCT exploration constant (default: 1.0)
  --mcts-dirichlet-alpha <a>     Dirichlet noise alpha (default: 0.03)
  --mcts-dirichlet-epsilon <e>   Dirichlet noise epsilon (default: 0.25)
  --save-visit-counts            Save visit count statistics

  --help                   Show this help message

Board Shape Examples:
  --board "5*5*5"          Fixed 5x5x5 board for all games
  --board "9*9*1"          Fixed 9x9x1 (2D) board for all games
  --board random           Random board shape for each game (default)

Examples:
  # Generate 100 games with random board shapes (no MCTS)
  npx tsx tools/selfPlayGames.ts --games 100

  # Generate 50 games on 5x5x5 board with MCTS
  npx tsx tools/selfPlayGames.ts --games 50 --board "5*5*5" --use-mcts --mcts-simulations 600

  # Generate games with custom models
  npx tsx tools/selfPlayGames.ts --games 20 --model ./models/tree.onnx --eval-model ./models/eval.onnx

  # Generate games with custom output directory and save visit counts
  npx tsx tools/selfPlayGames.ts --games 20 --output ./my_dataset --use-mcts --save-visit-counts
`);
}

/**
 * Initialize the AI agent (tree agent, or MCTS agent wrapping tree + evaluation agents).
 *
 * Loads the tree model always; the evaluation model is loaded only when MCTS is enabled.
 */
async function initializeAgent(config: GenerationConfig): Promise<TrigoTreeAgent | MCTSAgent> {
	console.log("Initializing AI Agent...");
	console.log(`  Mode: ${config.useMCTS ? "MCTS Search" : "Tree Attention"}`);
	console.log(`  Tree Model: ${config.modelPath}`);
	if (config.useMCTS) {
		console.log(`  Evaluation Model: ${config.evaluationModelPath}`);
		console.log(`  MCTS Simulations: ${config.mctsSimulations}`);
		console.log(`  C-PUCT: ${config.mctsCPuct}`);
	}
	console.log(`  Vocab Size: ${config.vocabSize}`);
	console.log(`  Sequence Length: ${config.seqLen}`);

	// Load tree model
	const sessionOptions = getOnnxSessionOptions();
	const treeSession = await ort.InferenceSession.create(config.modelPath, sessionOptions);
	// The casts bridge onnxruntime-node types to the inferencer's runtime-agnostic interface.
	const treeInferencer = new ModelInferencer(ort.Tensor as any, {
		vocabSize: config.vocabSize,
		seqLen: config.seqLen,
		modelPath: config.modelPath
	});
	treeInferencer.setSession(treeSession as any);
	const treeAgent = new TrigoTreeAgent(treeInferencer);

	// Return tree agent if MCTS is disabled
	if (!config.useMCTS) {
		console.log("✓ Tree Agent initialized\n");
		return treeAgent;
	}

	// Load evaluation model for MCTS
	const evalSession = await ort.InferenceSession.create(config.evaluationModelPath, sessionOptions);
	const evalInferencer = new ModelInferencer(ort.Tensor as any, {
		vocabSize: config.vocabSize,
		seqLen: config.seqLen,
		modelPath: config.evaluationModelPath
	});
	evalInferencer.setSession(evalSession as any);
	const evaluationAgent = new TrigoEvaluationAgent(evalInferencer);

	// Create MCTS agent
	const mctsConfig: MCTSConfig = {
		numSimulations: config.mctsSimulations,
		cPuct: config.mctsCPuct,
		temperature: config.temperature,
		dirichletAlpha: config.mctsDirichletAlpha,
		dirichletEpsilon: config.mctsDirichletEpsilon
	};

	const mctsAgent = new MCTSAgent(treeAgent, evaluationAgent, mctsConfig);
	console.log("✓ MCTS Agent initialized\n");

	return mctsAgent;
}

/**
 * Sample a move from the softmax of `score / temperature`.
 *
 * Scores are treated as unnormalized log-probabilities. A temperature <= 0 (or any
 * numerically degenerate distribution) falls back to greedy argmax selection instead of
 * producing NaN probabilities.
 *
 * @throws Error if `scoredMoves` is empty
 */
function sampleMove(scoredMoves: Array<{ move: Move; score: number; notation: string }>, temperature: number): Move {
	if (scoredMoves.length === 0) {
		throw new Error("sampleMove: no candidate moves to sample from");
	}

	// Greedy choice: used for temperature <= 0 and as the numeric fallback
	let bestIndex = 0;
	for (let i = 1; i < scoredMoves.length; i++) {
		if (scoredMoves[i].score > scoredMoves[bestIndex].score) {
			bestIndex = i;
		}
	}
	const greedyMove = scoredMoves[bestIndex].move;

	// `!(t > 0)` also catches NaN
	if (!(temperature > 0)) {
		return greedyMove;
	}

	// Apply temperature, then a numerically stable softmax (shift by the max)
	const adjustedScores = scoredMoves.map((m) => m.score / temperature);
	const maxScore = Math.max(...adjustedScores);
	if (!Number.isFinite(maxScore)) {
		return greedyMove;
	}
	const expScores = adjustedScores.map((score) => Math.exp(score - maxScore));
	const sumExp = expScores.reduce((sum, exp) => sum + exp, 0);
	if (!(sumExp > 0) || !Number.isFinite(sumExp)) {
		return greedyMove;
	}

	// Inverse-CDF sampling over the unnormalized weights
	const threshold = Math.random() * sumExp;
	let cumulative = 0;
	for (let i = 0; i < expScores.length; i++) {
		cumulative += expScores[i];
		if (threshold < cumulative) {
			return scoredMoves[i].move;
		}
	}

	// Floating-point rounding fallback
	return scoredMoves[scoredMoves.length - 1].move;
}

/**
 * Play a single self-play game.
 *
 * The game ends on two consecutive passes, when the game reports natural termination
 * (checked after 50% of positions could have been filled), on a failed move, or at
 * `config.maxMoves`. A player with no legal drop is forced to pass; the game only ends
 * if the opponent then passes too.
 *
 * @returns The finished game, its statistics, and (if requested) per-move MCTS visit counts
 */
async function playSelfPlayGame(
	agent: TrigoTreeAgent | MCTSAgent,
	gameId: number,
	config: GenerationConfig
): Promise<{ game: TrigoGame; stats: GameStats; visitCounts?: number[][] }> {
	// Select board shape (random or fixed)
	const boardShape: BoardShape = config.boardShape === "random" ? selectRandomBoardShape() : config.boardShape;

	const game = new TrigoGame(boardShape, {});
	const startTime = Date.now();
	let moveCount = 0;
	let totalMoveTime = 0;
	let consecutivePasses = 0;
	let traceWritten = false;
	const visitCountsHistory: number[][] = [];

	// Calculate territory check threshold (50% coverage, same as MCTS)
	const totalPositions = boardShape.x * boardShape.y * boardShape.z;
	const coverageThreshold = Math.floor(totalPositions * 0.5);
	let territoryCheckStarted = false;

	if (config.verbose) {
		console.log(`\nGame ${gameId} started [Board: ${boardShape.x}×${boardShape.y}×${boardShape.z}]`);
	}

	while (moveCount < config.maxMoves) {
		// Check if we should start territory checking (after 50% coverage)
		if (!territoryCheckStarted && moveCount >= coverageThreshold) {
			territoryCheckStarted = true;
			if (config.verbose) {
				console.log(`  Reached 50% coverage (${moveCount} moves), starting territory check`);
			}
		}

		// Get current player
		const currentPlayer = game.getCurrentPlayer() === StoneType.BLACK ? "black" : "white";

		const moveStartTime = Date.now();
		let selectedMove: Move;

		// Use MCTS or tree agent depending on agent type
		if (agent instanceof MCTSAgent) {
			const result = await agent.selectMove(game, moveCount);
			selectedMove = result.move;

			// Store visit counts if requested.
			// NOTE(review): only the counts are kept (map keys are dropped), matching the
			// existing JSON format — consumers cannot tell which move each count belongs to.
			if (config.saveVisitCounts && result.visitCounts) {
				visitCountsHistory.push(Array.from(result.visitCounts.values()));
			}
		} else {
			const validPositions = game.validMovePositions();

			if (validPositions.length === 0) {
				// Only passing is legal: skip scoring. The opponent may still have moves,
				// so the game continues until two consecutive passes.
				selectedMove = { player: currentPlayer, isPass: true };
			} else {
				// Candidate moves: every legal drop plus a pass
				const moves: Move[] = validPositions.map((pos) => ({
					x: pos.x,
					y: pos.y,
					z: pos.z,
					player: currentPlayer
				}));
				moves.push({ player: currentPlayer, isPass: true });

				const scoredMoves = await agent.scoreMoves(game, moves);
				if (scoredMoves.length === 0) {
					break;
				}

				// Sample move with temperature
				selectedMove = sampleMove(scoredMoves, config.temperature);
			}
		}

		totalMoveTime += Date.now() - moveStartTime;

		// Apply move
		let success = false;
		let moveNotation = "";

		if (selectedMove.isPass) {
			success = game.pass();
			moveNotation = "Pass";
			consecutivePasses++;
		} else if (
			selectedMove.x !== undefined &&
			selectedMove.y !== undefined &&
			selectedMove.z !== undefined
		) {
			success = game.drop({ x: selectedMove.x, y: selectedMove.y, z: selectedMove.z });
			moveNotation = encodeAb0yz(
				[selectedMove.x, selectedMove.y, selectedMove.z],
				[boardShape.x, boardShape.y, boardShape.z]
			);
			consecutivePasses = 0;
		}

		// Compact one-line move trace (verbose mode logs each move on its own line instead)
		if (!config.verbose) {
			process.stdout.write(moveNotation + " ");
			traceWritten = true;
		}

		if (!success) {
			console.error(`Failed to apply move: ${moveNotation}`);
			break;
		}

		moveCount++;

		if (config.verbose) {
			const player = currentPlayer === "black" ? "Black" : "White";
			console.log(`  Move ${moveCount}: ${player} plays ${moveNotation}`);
		}

		// Check for game end (two consecutive passes)
		if (consecutivePasses >= 2) {
			if (config.verbose) {
				console.log("  Game ended: Two consecutive passes");
			}
			break;
		}

		// Check territory after 50% coverage (same as MCTS)
		if (territoryCheckStarted && !selectedMove.isPass) {
			// Check for natural termination (all territory claimed, no capturing moves)
			if (game.isNaturallyTerminal()) {
				if (config.verbose) {
					const territory = game.getTerritory();
					console.log(`  Game ended: No neutral territory and no captures possible (settled)`);
					console.log(`  Black: ${territory.black}, White: ${territory.white}`);
				}
				break;
			} else if (config.verbose) {
				const territory = game.getTerritory();
				if (territory.neutral === 0) {
					console.log(`  Territory settled but captures still possible (continuing...)`);
				}
			}
		}
	}

	// Terminate the compact trace so the next progress line starts on its own line
	if (traceWritten) {
		process.stdout.write("\n");
	}

	const duration = Date.now() - startTime;
	const averageMoveTime = moveCount > 0 ? totalMoveTime / moveCount : 0;

	// Determine game result
	const maxMovesReached = moveCount >= config.maxMoves;

	// Get final territory and calculate score difference
	const territory = game.getTerritory();
	const scoreDiff = territory.white - territory.black;

	const stats: GameStats = {
		gameId,
		boardShape: `${boardShape.x}×${boardShape.y}×${boardShape.z}`,
		moveCount,
		maxMovesReached,
		duration,
		averageMoveTime,
		scoreDiff
	};

	return { game, stats, visitCounts: config.saveVisitCounts ? visitCountsHistory : undefined };
}

/**
 * Save game to a TGN file named after the first 16 hex chars of its SHA-256 content hash.
 * Identical games therefore map to the same file (and overwrite it).
 *
 * @returns The file name (not the full path) that was written
 */
function saveGame(game: TrigoGame, outputDir: string): string {
	const tgn = game.toTGN({}, { markResult: true });

	// Generate filename based on content hash (same as generateRandomGames.ts)
	const hash = crypto.createHash("sha256").update(tgn).digest("hex");
	const filename = `game_${hash.substring(0, 16)}.tgn`;
	const filepath = path.join(outputDir, filename);

	fs.writeFileSync(filepath, tgn, "utf-8");

	return filename;
}

/**
 * Generate dataset of self-play games, write each game (and optional visit counts)
 * to `config.outputDir`, then write and print aggregate statistics.
 */
async function generateDataset(config: GenerationConfig): Promise<void> {
	const boardShapeLabel =
		config.boardShape === "random"
			? "random"
			: `${config.boardShape.x}×${config.boardShape.y}×${config.boardShape.z}`;

	console.log("=".repeat(80));
	console.log("Trigo Self-Play Dataset Generation");
	console.log("=".repeat(80));
	console.log(`Configuration:`);
	console.log(`  Number of games: ${config.numGames}`);
	console.log(`  Output directory: ${config.outputDir}`);
	console.log(`  Board shape: ${boardShapeLabel}`);
	console.log(`  Temperature: ${config.temperature}`);
	console.log(`  Max moves per game: ${config.maxMoves}`);
	console.log(`  Verbose: ${config.verbose}`);
	console.log();

	// Create output directory
	if (!fs.existsSync(config.outputDir)) {
		fs.mkdirSync(config.outputDir, { recursive: true });
		console.log(`✓ Created output directory: ${config.outputDir}\n`);
	}

	// Initialize agent
	const agent = await initializeAgent(config);

	// Generate games
	const startTime = Date.now();
	const datasetStats: DatasetStats = {
		totalGames: 0,
		totalMoves: 0,
		blackWins: 0,
		whiteWins: 0,
		draws: 0,
		resignations: 0,
		maxMovesReached: 0,
		averageGameLength: 0,
		averageMoveTime: 0,
		generationTime: 0,
		averageScoreDiff: 0,
		boardShapeStats: [],
		games: []
	};

	console.log("Generating games...");
	console.log("=".repeat(80));

	for (let i = 1; i <= config.numGames; i++) {
		const gameStartTime = Date.now();

		// Play game
		const { game, stats, visitCounts } = await playSelfPlayGame(agent, i, config);

		// Save game with hash-based filename
		const gameFilename = saveGame(game, config.outputDir);

		// Save visit counts next to the game, sharing its content hash so the two files pair up
		if (visitCounts && config.saveVisitCounts) {
			const visitCountsFilename = gameFilename.replace(/\.tgn$/, "_visit_counts.json");
			const visitCountsPath = path.join(config.outputDir, visitCountsFilename);
			fs.writeFileSync(visitCountsPath, JSON.stringify(visitCounts, null, 2), "utf-8");
		}

		const gameDuration = Date.now() - gameStartTime;

		// Update statistics (scoreDiff = white - black)
		datasetStats.totalGames++;
		datasetStats.totalMoves += stats.moveCount;
		if (stats.maxMovesReached) datasetStats.maxMovesReached++;
		if (stats.scoreDiff > 0) {
			datasetStats.whiteWins++;
		} else if (stats.scoreDiff < 0) {
			datasetStats.blackWins++;
		} else {
			datasetStats.draws++;
		}
		datasetStats.games.push(stats);

		// Progress update
		const progress = ((i / config.numGames) * 100).toFixed(1);
		const result =
			stats.scoreDiff > 0
				? `White +${stats.scoreDiff}`
				: stats.scoreDiff < 0
					? `Black +${Math.abs(stats.scoreDiff)}`
					: "Draw";
		console.log(
			`[${progress}%] Game ${i}/${config.numGames} [${stats.boardShape}]: ` +
				`${stats.moveCount} moves, ${result}, ${(gameDuration / 1000).toFixed(1)}s`
		);
	}

	datasetStats.generationTime = Date.now() - startTime;
	datasetStats.averageGameLength = average(datasetStats.games.map((g) => g.moveCount));
	datasetStats.averageMoveTime = average(datasetStats.games.map((g) => g.averageMoveTime));

	// Calculate average score difference
	datasetStats.averageScoreDiff = average(datasetStats.games.map((g) => g.scoreDiff));

	// Calculate per-board-shape statistics
	const shapeMap = new Map<string, GameStats[]>();
	for (const game of datasetStats.games) {
		const bucket = shapeMap.get(game.boardShape);
		if (bucket) {
			bucket.push(game);
		} else {
			shapeMap.set(game.boardShape, [game]);
		}
	}

	datasetStats.boardShapeStats = Array.from(shapeMap.entries()).map(([boardShape, games]) => ({
		boardShape,
		gameCount: games.length,
		averageScoreDiff: average(games.map((g) => g.scoreDiff)),
		averageMoveCount: average(games.map((g) => g.moveCount))
	}));

	// Sort by board shape for consistent output
	datasetStats.boardShapeStats.sort((a, b) => a.boardShape.localeCompare(b.boardShape));

	// Save dataset statistics with a full timestamp (date + time) so runs never overwrite each other
	const timestamp = new Date().toISOString().replace(/[:.]/g, "-");
	const statsFilepath = path.join(config.outputDir, `${timestamp}_dataset_stats.json`);
	fs.writeFileSync(statsFilepath, JSON.stringify(datasetStats, null, 2), "utf-8");

	// Print summary
	console.log("=".repeat(80));
	console.log("Dataset Generation Complete!");
	console.log("=".repeat(80));
	console.log(`Total games: ${datasetStats.totalGames}`);
	console.log(`Total moves: ${datasetStats.totalMoves}`);
	console.log(`Average game length: ${datasetStats.averageGameLength.toFixed(1)} moves`);
	console.log(`Average move time: ${datasetStats.averageMoveTime.toFixed(1)}ms`);
	console.log(`Average score diff (W-B): ${datasetStats.averageScoreDiff.toFixed(2)}`);

	// Print per-board-shape statistics
	if (datasetStats.boardShapeStats.length > 0) {
		console.log(`\nBoard Shape Statistics:`);
		for (const shapeStats of datasetStats.boardShapeStats) {
			console.log(
				`  [${shapeStats.boardShape}] ${shapeStats.gameCount} games, ` +
					`avg score: ${shapeStats.averageScoreDiff.toFixed(2)}, ` +
					`avg moves: ${shapeStats.averageMoveCount.toFixed(1)}`
			);
		}
	}

	console.log(`\nBlack wins: ${datasetStats.blackWins} (${formatPercent(datasetStats.blackWins, datasetStats.totalGames)}%)`);
	console.log(`White wins: ${datasetStats.whiteWins} (${formatPercent(datasetStats.whiteWins, datasetStats.totalGames)}%)`);
	console.log(`Draws: ${datasetStats.draws} (${formatPercent(datasetStats.draws, datasetStats.totalGames)}%)`);
	console.log(`Resignations: ${datasetStats.resignations}`);
	console.log(`Max moves reached: ${datasetStats.maxMovesReached}`);
	console.log(`Total time: ${(datasetStats.generationTime / 1000 / 60).toFixed(1)} minutes`);
	console.log(`\nOutput directory: ${config.outputDir}`);
	console.log(`Statistics file: ${statsFilepath}`);
	console.log("=".repeat(80));
}

/**
 * Main function: parse arguments and generate the dataset; any error is logged and
 * the process exits with status 1.
 */
async function main(): Promise<void> {
	try {
		const config = parseArgs();
		await generateDataset(config);
	} catch (error: unknown) {
		console.error("Error:", error);
		process.exit(1);
	}
}

// Run main function (main handles its own errors)
void main();