Gin Rummy MDP
A reinforcement learning model for Gin Rummy, trained via PPO self-play in pure JAX.
R42 is the latest checkpoint: a 4.58M parameter SimBa residual network trained for 1.4 billion environment steps against a mixed curriculum of opponents. The checkpoint is an 18 MB pickle file that runs on CPU with no GPU required.
Play it live: https://sleeping-frames-coalition-justin.trycloudflare.com
Source code: github.com/GoodStartLabs/gin-rummy-mdp (not required to use the model)
Table of Contents
- Quick Start (Python)
- Card Encoding
- Action Space
- Game Flow
- Observation Vector (342 dimensions)
- Network Architecture
- Checkpoint Format
- Inference Step by Step
- Legal Action Rules
- Scoring Rules
- PyTorch Reference Implementation
- Available Checkpoints
- Training Details
1. Quick Start (Python)
```
pip install jax flax huggingface_hub
```

```python
from huggingface_hub import hf_hub_download
import pickle, jax.numpy as jnp

# Download the R42 final checkpoint (1.4B steps)
path = hf_hub_download(
    repo_id="GoodStartLabs/gin-rummy-mdp",
    filename="checkpoints/r42/stage1_final.pkl",
    repo_type="model",
)

# Load parameters
with open(path, "rb") as f:
    params = pickle.load(f)
p = params.get("params", params)

# p is a dict of weight arrays; see "Checkpoint Format" below
print(sorted(p.keys()))
# ['Dense_0', 'Dense_1', ..., 'LayerNorm_0', ..., 'opp_dw_pred', 'value_discard', 'value_draw', 'value_knock']
```
2. Card Encoding
Cards are integers 0 through 51.
suit = card // 13 (0=Spades, 1=Hearts, 2=Diamonds, 3=Clubs)
rank = card % 13 (0=Ace, 1=Two, 2=Three, ..., 9=Ten, 10=Jack, 11=Queen, 12=King)
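The two formulas above can be checked with a few lines of Python (the name strings below follow the tables in this section):

```python
SUITS = ["Spades", "Hearts", "Diamonds", "Clubs"]
RANKS = ["Ace", "Two", "Three", "Four", "Five", "Six", "Seven",
         "Eight", "Nine", "Ten", "Jack", "Queen", "King"]

def card_name(card: int) -> str:
    """Convert a card integer (0-51) to its human-readable name."""
    suit, rank = card // 13, card % 13
    return f"{RANKS[rank]} of {SUITS[suit]}"

print(card_name(0))   # Ace of Spades
print(card_name(30))  # Five of Diamonds
print(card_name(51))  # King of Clubs
```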
Full Mapping
| Card | Suit | Rank | Name |
|---|---|---|---|
| 0 | Spades | 0 | Ace of Spades |
| 1 | Spades | 1 | Two of Spades |
| 2 | Spades | 2 | Three of Spades |
| ... | ... | ... | ... |
| 9 | Spades | 9 | Ten of Spades |
| 10 | Spades | 10 | Jack of Spades |
| 11 | Spades | 11 | Queen of Spades |
| 12 | Spades | 12 | King of Spades |
| 13 | Hearts | 0 | Ace of Hearts |
| 14 | Hearts | 1 | Two of Hearts |
| ... | ... | ... | ... |
| 25 | Hearts | 12 | King of Hearts |
| 26 | Diamonds | 0 | Ace of Diamonds |
| ... | ... | ... | ... |
| 38 | Diamonds | 12 | King of Diamonds |
| 39 | Clubs | 0 | Ace of Clubs |
| ... | ... | ... | ... |
| 51 | Clubs | 12 | King of Clubs |
Deadwood Values
| Rank | Card | Deadwood Points |
|---|---|---|
| 0 | Ace | 1 |
| 1 | Two | 2 |
| 2 | Three | 3 |
| 3 | Four | 4 |
| 4 | Five | 5 |
| 5 | Six | 6 |
| 6 | Seven | 7 |
| 7 | Eight | 8 |
| 8 | Nine | 9 |
| 9 | Ten | 10 |
| 10 | Jack | 10 |
| 11 | Queen | 10 |
| 12 | King | 10 |
Melds
- Set (group): 3 or 4 cards of the same rank, any suits. Example: 5 of Spades (4), 5 of Hearts (17), 5 of Clubs (43).
- Run (sequence): 3+ consecutive ranks in the same suit. Example: 3/4/5 of Diamonds (28, 29, 30). Aces are low only (A-2-3 is valid, Q-K-A is not).
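A minimal sketch of meld validation under these rules (helper names are ours, not from the repo):

```python
def is_set(cards):
    """Set: 3 or 4 distinct cards sharing one rank."""
    return (3 <= len(cards) <= 4
            and len(set(cards)) == len(cards)
            and len({c % 13 for c in cards}) == 1)

def is_run(cards):
    """Run: 3+ consecutive ranks in one suit; aces are low only."""
    if len(cards) < 3 or len({c // 13 for c in cards}) != 1:
        return False
    ranks = sorted(c % 13 for c in cards)
    # With ace = rank 0, a Q-K-A wrap can never be consecutive
    return all(b - a == 1 for a, b in zip(ranks, ranks[1:]))

print(is_set([4, 17, 43]))   # True: 5 of Spades/Hearts/Clubs
print(is_run([28, 29, 30]))  # True: 3/4/5 of Diamonds
print(is_run([24, 25, 13]))  # False: Q-K-A of Hearts wraps
```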
3. Action Space
The model uses a unified 16-action space across all phases. The game phase determines which actions are legal.
| Action | Phase | Meaning |
|---|---|---|
| 0 | Draw | Draw from stock pile (face-down) |
| 1 | Draw | Draw from discard pile (face-up top card) |
| 2 | Discard | Discard card at hand index 0 |
| 3 | Discard | Discard card at hand index 1 |
| 4 | Discard | Discard card at hand index 2 |
| 5 | Discard | Discard card at hand index 3 |
| 6 | Discard | Discard card at hand index 4 |
| 7 | Discard | Discard card at hand index 5 |
| 8 | Discard | Discard card at hand index 6 |
| 9 | Discard | Discard card at hand index 7 |
| 10 | Discard | Discard card at hand index 8 |
| 11 | Discard | Discard card at hand index 9 |
| 12 | Discard | Discard card at hand index 10 |
| 13 | Knock Decision | Continue playing (don't knock) |
| 14 | Knock Decision | Knock (requires deadwood <= 10) |
| 15 | Knock Decision | Gin (requires deadwood = 0) |
4. Game Flow
```
Deal: 10 cards each, 1 upcard placed on discard pile
  |
  v
+-> [DRAW PHASE] -- Player draws 1 card (from stock or discard)
|     |
|     v
|   [DISCARD PHASE] -- Player discards 1 card (now has 10 again)
|     |
|     v
|   Deadwood <= 10? --No--> Switch to other player, go to DRAW
|     |
|    Yes
|     v
|   [KNOCK DECISION]
|     |
|   Continue? --Yes--> Switch to other player, go to DRAW
|     |
|    No (Knock or Gin)
|     v
+-- [GAME OVER] -- Score the hand
```
Terminal Conditions
- Knock: Player declares knock (deadwood 1-10). Score is computed with layoffs.
- Gin: Player declares gin (deadwood 0). Bonus awarded, no layoffs for defender.
- Stock exhausted: When 2 or fewer cards remain in the stock pile, the hand is a draw (no points awarded).
5. Observation Vector (342 dimensions)
The model receives a flat float32[342] vector. Every feature is documented below with its exact index range.
| Index Range | Dims | Feature | Value Range |
|---|---|---|---|
| 0:52 | 52 | Hand mask – 1.0 if card is in the player's hand | binary {0, 1} |
| 52:104 | 52 | Discard pile visible – 1.0 if card has been discarded | binary {0, 1} |
| 104:156 | 52 | Discard top card – one-hot encoding of the top discard card | one-hot |
| 156 | 1 | Deadwood – player's current deadwood / 100 | [0, 1] |
| 157:161 | 4 | Phase – one-hot (draw / discard / knock_decision / game_over) | one-hot |
| 161 | 1 | Hand size – number of cards in hand / 11 | [0, 1] |
| 162 | 1 | Discard pile size – cards in discard / 52 | [0, 1] |
| 163 | 1 | Stock remaining – cards left in stock / 31 | [0, 1] |
| 164 | 1 | Turn count – turns elapsed / 35 | [0, 1] |
| 165 | 1 | Can knock – 1.0 if deadwood <= 10 | binary {0, 1} |
| 166:177 | 11 | Discard deadwood – deadwood after discarding each hand slot / 100 | [0, 1] |
| 177 | 1 | Draw-from-discard deadwood – best deadwood if drawing top discard / 100 | [0, 1] |
| 178:230 | 52 | Opponent drew-from-discard – 1.0 for each card opponent picked from discard | binary {0, 1} |
| 230:282 | 52 | Opponent declined-discard – 1.0 for each card opponent chose not to pick | binary {0, 1} |
| 282 | 1 | Opponent estimated deadwood – heuristic estimate from card counting | [0, 1] |
| 283:294 | 11 | Discard safety – safety score per hand slot (high = opponent unlikely to want it) | [0, 1] |
| 294 | 1 | Undercut risk – risk of being undercut if knocking now | [0, 1] |
| 295:306 | 11 | Meld membership – 1.0 if discarding that slot increases deadwood | binary {0, 1} |
| 306:317 | 11 | Connector scores – how many unseen cards could complete melds with each hand card | [0, 1] |
| 317 | 1 | Fraction of hand in melds – sum(meld_membership) / 10 | [0, 1] |
| 318 | 1 | Cards from gin – unmelded cards / 10 | [0, 1] |
| 319 | 1 | Game urgency – 1.0 - stock_remaining (increases as deck runs out) | [0, 1] |
| 320 | 1 | Knock margin estimate – (est_opp_dw - our_dw) / 50 | [-1, 1] |
| 321 | 1 | Opponent draw activity – opponent discard draws / 5 | [0, 1] |
| 322:342 | 20 | Opponent type – one-hot encoding of opponent type ID | one-hot |
Notes:
- Index 163: the stock is normalized by 31 (52 cards - 21 dealt = 31 initial stock).
- Indices 166:177: invalid hand slots (index >= hand_size) are padded with 1.0.
- Index 177: set to 1.0 if the discard pile is empty.
- Indices 322:342: set to all zeros for unknown opponents or during human play.
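The index ranges above can be cross-checked mechanically. This sketch (ours, not from the repo) verifies that the segments are contiguous and sum to 342 dimensions:

```python
# (name, start, width) triples transcribed from the table above
SEGMENTS = [
    ("hand_mask", 0, 52), ("discard_visible", 52, 52), ("discard_top", 104, 52),
    ("deadwood", 156, 1), ("phase", 157, 4), ("hand_size", 161, 1),
    ("discard_size", 162, 1), ("stock_remaining", 163, 1), ("turn_count", 164, 1),
    ("can_knock", 165, 1), ("discard_deadwood", 166, 11), ("draw_discard_dw", 177, 1),
    ("opp_drew", 178, 52), ("opp_declined", 230, 52), ("opp_est_dw", 282, 1),
    ("discard_safety", 283, 11), ("undercut_risk", 294, 1), ("meld_membership", 295, 11),
    ("connector", 306, 11), ("frac_melded", 317, 1), ("cards_from_gin", 318, 1),
    ("urgency", 319, 1), ("knock_margin", 320, 1), ("opp_activity", 321, 1),
    ("opp_type", 322, 20),
]

end = 0
for name, start, width in SEGMENTS:
    assert start == end, f"gap or overlap before {name}"
    end = start + width
assert end == 342
print("layout OK:", end, "dims")  # layout OK: 342 dims
```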
6. Network Architecture
The R42 model uses a SimBa (Simplified Balanced) residual architecture with phase-decomposed value heads and an auxiliary opponent deadwood prediction head.
Layer-by-Layer Specification
```
Input: float32[342]
  |
  v
Dense_0: Linear(342 -> 1024) + bias
  |  activation: ReLU
  v
=== Residual Block 1 ===
  |-- save as `residual`
  |   LayerNorm_0: scale[1024], bias[1024], eps=1e-5
  |   Dense_1: Linear(1024 -> 1024) + bias
  |   activation: ReLU
  |   Dense_2: Linear(1024 -> 1024) + bias   (NO activation)
  |-- x = residual + x
  v
=== Residual Block 2 ===
  |-- save as `residual`
  |   LayerNorm_1: scale[1024], bias[1024], eps=1e-5
  |   Dense_3: Linear(1024 -> 1024) + bias
  |   activation: ReLU
  |   Dense_4: Linear(1024 -> 1024) + bias   (NO activation)
  |-- x = residual + x
  v
LayerNorm_2: scale[1024], bias[1024], eps=1e-5
  |
  v  (shared features, used by all heads below)
  |
  +-- Dense_5: Linear(1024 -> 16)      --> action logits
  +-- value_draw: Linear(1024 -> 1)    --> value estimate (draw phase)
  +-- value_discard: Linear(1024 -> 1) --> value estimate (discard phase)
  +-- value_knock: Linear(1024 -> 1)   --> value estimate (knock decision phase)
  +-- opp_dw_pred: Linear(1024 -> 1)   --> auxiliary opponent deadwood prediction
```
Parameter Count
| Layer | Parameters |
|---|---|
| Dense_0 (342x1024 + 1024) | 351,232 |
| Dense_1 (1024x1024 + 1024) | 1,049,600 |
| Dense_2 (1024x1024 + 1024) | 1,049,600 |
| Dense_3 (1024x1024 + 1024) | 1,049,600 |
| Dense_4 (1024x1024 + 1024) | 1,049,600 |
| Dense_5 (1024x16 + 16) | 16,400 |
| LayerNorm_0 (1024 + 1024) | 2,048 |
| LayerNorm_1 (1024 + 1024) | 2,048 |
| LayerNorm_2 (1024 + 1024) | 2,048 |
| value_draw (1024x1 + 1) | 1,025 |
| value_discard (1024x1 + 1) | 1,025 |
| value_knock (1024x1 + 1) | 1,025 |
| opp_dw_pred (1024x1 + 1) | 1,025 |
| Total | 4,576,276 |
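The total in the table can be reproduced directly from the layer shapes:

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # kernel + bias

total = (
    linear_params(342, 1024)          # Dense_0
    + 4 * linear_params(1024, 1024)   # Dense_1..Dense_4
    + linear_params(1024, 16)         # Dense_5 (actor)
    + 3 * 2 * 1024                    # LayerNorm_0..2 (scale + bias)
    + 4 * linear_params(1024, 1)      # value_draw/discard/knock + opp_dw_pred
)
print(total)  # 4576276
```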
7. Checkpoint Format
The .pkl file is a Python pickle containing a dict:
```
{
    "params": {
        "Dense_0":       {"kernel": float32[342, 1024],  "bias": float32[1024]},
        "Dense_1":       {"kernel": float32[1024, 1024], "bias": float32[1024]},
        "Dense_2":       {"kernel": float32[1024, 1024], "bias": float32[1024]},
        "Dense_3":       {"kernel": float32[1024, 1024], "bias": float32[1024]},
        "Dense_4":       {"kernel": float32[1024, 1024], "bias": float32[1024]},
        "Dense_5":       {"kernel": float32[1024, 16],   "bias": float32[16]},
        "LayerNorm_0":   {"scale": float32[1024], "bias": float32[1024]},
        "LayerNorm_1":   {"scale": float32[1024], "bias": float32[1024]},
        "LayerNorm_2":   {"scale": float32[1024], "bias": float32[1024]},
        "value_draw":    {"kernel": float32[1024, 1], "bias": float32[1]},
        "value_discard": {"kernel": float32[1024, 1], "bias": float32[1]},
        "value_knock":   {"kernel": float32[1024, 1], "bias": float32[1]},
        "opp_dw_pred":   {"kernel": float32[1024, 1], "bias": float32[1]},
    }
}
```
Important: Flax Dense layers store the kernel as [input_dim, output_dim] (NOT transposed). The forward pass computes output = input @ kernel + bias.
8. Inference Step by Step
1. Construct the 342-dimensional observation vector from game state
(see Section 5 for the complete index map)
2. Forward pass through the network:
x = relu(obs @ Dense_0.kernel + Dense_0.bias)
# Residual block 1
r = x
x = layer_norm(x, LN0.scale, LN0.bias)
x = relu(x @ Dense_1.kernel + Dense_1.bias)
x = x @ Dense_2.kernel + Dense_2.bias
x = r + x
# Residual block 2
r = x
x = layer_norm(x, LN1.scale, LN1.bias)
x = relu(x @ Dense_3.kernel + Dense_3.bias)
x = x @ Dense_4.kernel + Dense_4.bias
x = r + x
# Final norm
x = layer_norm(x, LN2.scale, LN2.bias)
# Actor head
logits = x @ Dense_5.kernel + Dense_5.bias # float32[16]
3. Compute legal action mask based on current game phase (see Section 9)
4. Mask illegal actions:
for i in range(16):
if not legal[i]:
logits[i] = -infinity
5. Select action:
- Greedy: action = argmax(logits)
- Stochastic: action = sample from softmax(logits)
6. Execute the action in your game engine
LayerNorm formula:
layer_norm(x, scale, bias, eps=1e-5):
mean = mean(x)
var = var(x)
x_norm = (x - mean) / sqrt(var + eps)
return x_norm * scale + bias
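For reference, these steps transcribe into a dependency-free NumPy forward pass. This is our sketch, not code from the repo; the smoke test below uses random weights with the shapes from Section 7, so substitute a real checkpoint for actual inference:

```python
import numpy as np

def layer_norm(x, scale, bias, eps=1e-5):
    return (x - x.mean()) / np.sqrt(x.var() + eps) * scale + bias

def forward(p, obs):
    """Actor-head forward pass; `p` follows the Section 7 layout."""
    x = np.maximum(obs @ p["Dense_0"]["kernel"] + p["Dense_0"]["bias"], 0.0)
    for ln, d_a, d_b in [("LayerNorm_0", "Dense_1", "Dense_2"),
                         ("LayerNorm_1", "Dense_3", "Dense_4")]:
        r = x                                                  # residual
        x = layer_norm(x, p[ln]["scale"], p[ln]["bias"])
        x = np.maximum(x @ p[d_a]["kernel"] + p[d_a]["bias"], 0.0)
        x = x @ p[d_b]["kernel"] + p[d_b]["bias"]              # no activation
        x = r + x
    x = layer_norm(x, p["LayerNorm_2"]["scale"], p["LayerNorm_2"]["bias"])
    return x @ p["Dense_5"]["kernel"] + p["Dense_5"]["bias"]   # logits

# Smoke test with random weights (checkpoint shapes from Section 7)
rng = np.random.default_rng(0)
p = {f"Dense_{i}": {"kernel": rng.normal(scale=0.01, size=s).astype(np.float32),
                    "bias": np.zeros(s[1], np.float32)}
     for i, s in enumerate([(342, 1024)] + [(1024, 1024)] * 4 + [(1024, 16)])}
for i in range(3):
    p[f"LayerNorm_{i}"] = {"scale": np.ones(1024, np.float32),
                           "bias": np.zeros(1024, np.float32)}
print(forward(p, np.zeros(342, np.float32)).shape)  # (16,)
```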
9. Legal Action Rules
Draw Phase (player has 10 cards)
| Action | Legal When |
|---|---|
| 0 (draw stock) | Stock has > 2 cards remaining |
| 1 (draw discard) | Discard pile is not empty |
Both are typically legal. When the stock is nearly exhausted, only the discard draw may be available.
Discard Phase (player has 11 cards)
| Action | Legal When |
|---|---|
| 2 through 12 | Hand index (action - 2) < hand_size AND the card at that index is not the card just drawn from the discard pile |
The "re-discard ban" prevents immediately returning a card picked up from the discard pile.
Knock Decision Phase (player has 10 cards, deadwood <= 10)
| Action | Legal When |
|---|---|
| 13 (continue) | Always legal |
| 14 (knock) | Deadwood is 1 through 10 |
| 15 (gin) | Deadwood is exactly 0 |
This phase only occurs when deadwood <= 10 after discarding. If deadwood > 10, the game skips directly to the other player's draw phase.
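These rules can be collected into a single mask function (a sketch; the function and variable names are ours):

```python
import numpy as np

def legal_action_mask(phase, hand_size, stock_count, discard_empty,
                      drawn_from_discard_idx, deadwood):
    """Build the 16-dim legal action mask described in this section.
    `drawn_from_discard_idx` is the hand index of a card just taken
    from the discard pile, or None."""
    mask = np.zeros(16, dtype=bool)
    if phase == "draw":
        mask[0] = stock_count > 2              # draw from stock
        mask[1] = not discard_empty            # draw from discard
    elif phase == "discard":
        for i in range(hand_size):             # actions 2..12 = hand slots
            mask[2 + i] = (i != drawn_from_discard_idx)  # re-discard ban
    elif phase == "knock_decision":
        mask[13] = True                        # continue is always legal
        mask[14] = 1 <= deadwood <= 10         # knock
        mask[15] = deadwood == 0               # gin
    return mask

m = legal_action_mask("discard", 11, 20, False, 3, 7)
print(m[2:13].astype(int))  # [1 1 1 0 1 1 1 1 1 1 1]
```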
10. Scoring Rules
Normal Knock (deadwood 1-10)
- Knocker reveals hand and forms melds.
- Defender reveals hand, forms their own melds, then lays off unmelded cards onto the knocker's melds (extending runs or completing sets).
- Compare deadwood:
- Knocker wins: knocker_deadwood < defender_deadwood_after_layoffs. Knocker scores the difference.
- Undercut: defender_deadwood_after_layoffs <= knocker_deadwood. Defender scores the difference + 25 bonus.
Gin (deadwood = 0)
- Knocker scores defender's deadwood + 25 bonus.
- Defender gets no layoffs against a gin hand.
Stock Exhausted
- When 2 or fewer cards remain in the stock pile, the hand ends in a draw.
- Neither player scores any points.
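The scoring rules above can be sketched as one function (our transcription; deadwood arguments are the values after the defender's layoffs, where layoffs apply):

```python
def score_hand(knocker_deadwood, defender_deadwood_after_layoffs,
               is_gin, stock_exhausted=False, gin_bonus=25, undercut_bonus=25):
    """Return (knocker_points, defender_points) for a finished hand."""
    if stock_exhausted:
        return 0, 0                        # drawn hand: no points
    if is_gin:                             # knocker deadwood 0; no layoffs
        return defender_deadwood_after_layoffs + gin_bonus, 0
    diff = defender_deadwood_after_layoffs - knocker_deadwood
    if diff > 0:
        return diff, 0                     # knocker scores the difference
    return 0, -diff + undercut_bonus       # undercut: ties go to defender

print(score_hand(7, 15, False))  # (8, 0)
print(score_hand(8, 8, False))   # (0, 25)  undercut on a tie
print(score_hand(0, 12, True))   # (37, 0)  gin
```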
Deadwood Calculation
Sum the deadwood values (see Section 2) of all cards NOT part of any meld. The optimal meld arrangement is used (minimizing deadwood).
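Given a hand and a chosen meld arrangement, the deadwood sum itself is straightforward; finding the *optimal* arrangement requires a search over meld combinations, which this sketch (ours, not from the repo) leaves to the caller:

```python
DEADWOOD_VALUE = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10]  # indexed by rank

def deadwood(hand, melds):
    """Sum deadwood values of cards not covered by any meld.
    `melds` is a list of card lists; a card may appear in at most one meld."""
    melded = {c for meld in melds for c in meld}
    return sum(DEADWOOD_VALUE[c % 13] for c in hand if c not in melded)

hand = [4, 17, 43, 28, 29, 30, 0, 10, 25, 51]
melds = [[4, 17, 43], [28, 29, 30]]  # set of Fives, 3-4-5 of Diamonds
print(deadwood(hand, melds))         # Ace + Jack + two Kings = 1+10+10+10 = 31
```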
11. PyTorch Reference Implementation
A complete, self-contained PyTorch implementation for loading and running the model. No JAX or Flax dependency required.
```python
import pickle

import numpy as np
import torch
import torch.nn as nn


class GinRummyModel(nn.Module):
    """R42 SimBa architecture with phase-decomposed value heads."""

    def __init__(self):
        super().__init__()
        # Input projection
        self.dense_0 = nn.Linear(342, 1024)
        # Residual block 1
        self.ln_0 = nn.LayerNorm(1024, eps=1e-5)
        self.dense_1 = nn.Linear(1024, 1024)
        self.dense_2 = nn.Linear(1024, 1024)
        # Residual block 2
        self.ln_1 = nn.LayerNorm(1024, eps=1e-5)
        self.dense_3 = nn.Linear(1024, 1024)
        self.dense_4 = nn.Linear(1024, 1024)
        # Final norm
        self.ln_2 = nn.LayerNorm(1024, eps=1e-5)
        # Output heads
        self.actor = nn.Linear(1024, 16)
        self.value_draw = nn.Linear(1024, 1)
        self.value_discard = nn.Linear(1024, 1)
        self.value_knock = nn.Linear(1024, 1)
        self.opp_dw_pred = nn.Linear(1024, 1)

    def forward(self, obs):
        """
        Args:
            obs: float32 tensor of shape (..., 342)
        Returns:
            logits: float32 (..., 16) - raw action logits (mask before use)
            value_draw: float32 (..., 1) - value estimate for draw phase
            value_discard: float32 (..., 1) - value estimate for discard phase
            value_knock: float32 (..., 1) - value estimate for knock phase
            opp_dw_pred: float32 (..., 1) - predicted opponent deadwood
        """
        # Input projection
        x = torch.relu(self.dense_0(obs))
        # Residual block 1
        r = x
        x = self.ln_0(x)
        x = torch.relu(self.dense_1(x))
        x = self.dense_2(x)
        x = r + x
        # Residual block 2
        r = x
        x = self.ln_1(x)
        x = torch.relu(self.dense_3(x))
        x = self.dense_4(x)
        x = r + x
        # Final norm + heads
        x = self.ln_2(x)
        logits = self.actor(x)
        return (
            logits,
            self.value_draw(x),
            self.value_discard(x),
            self.value_knock(x),
            self.opp_dw_pred(x),
        )


def load_from_pkl(pkl_path: str) -> GinRummyModel:
    """Load Flax weights from a .pkl checkpoint into PyTorch.

    Handles the Flax [in, out] -> PyTorch [out, in] kernel transposition.
    """
    with open(pkl_path, "rb") as f:
        params = pickle.load(f)
    p = params.get("params", params)

    model = GinRummyModel()

    def set_linear(module, name):
        # Flax kernel is [in_features, out_features];
        # PyTorch weight is [out_features, in_features]
        module.weight.data = torch.from_numpy(np.array(p[name]["kernel"]).T)
        module.bias.data = torch.from_numpy(np.array(p[name]["bias"]).ravel())

    def set_ln(module, name):
        module.weight.data = torch.from_numpy(np.array(p[name]["scale"]))
        module.bias.data = torch.from_numpy(np.array(p[name]["bias"]))

    set_linear(model.dense_0, "Dense_0")
    set_linear(model.dense_1, "Dense_1")
    set_linear(model.dense_2, "Dense_2")
    set_linear(model.dense_3, "Dense_3")
    set_linear(model.dense_4, "Dense_4")
    set_linear(model.actor, "Dense_5")
    set_linear(model.value_draw, "value_draw")
    set_linear(model.value_discard, "value_discard")
    set_linear(model.value_knock, "value_knock")
    set_linear(model.opp_dw_pred, "opp_dw_pred")
    set_ln(model.ln_0, "LayerNorm_0")
    set_ln(model.ln_1, "LayerNorm_1")
    set_ln(model.ln_2, "LayerNorm_2")

    model.eval()
    return model


# --- Usage example ---
if __name__ == "__main__":
    from huggingface_hub import hf_hub_download

    path = hf_hub_download(
        repo_id="GoodStartLabs/gin-rummy-mdp",
        filename="checkpoints/r42/stage1_final.pkl",
        repo_type="model",
    )
    model = load_from_pkl(path)

    # Create a dummy observation (all zeros)
    obs = torch.zeros(342)
    with torch.no_grad():
        logits, v_draw, v_discard, v_knock, opp_dw = model(obs)
    print(f"Logits shape: {logits.shape}")  # torch.Size([16])
    print(f"Top action: {logits.argmax().item()}")
```
12. Available Checkpoints
R42 (latest, recommended)
SimBa + auxiliary heads, trained with mixed opponent curriculum. Use stage1_final.pkl unless you need an intermediate snapshot.
| File | Steps | Notes |
|---|---|---|
| checkpoints/r42/stage1_final.pkl | 1.4B | Final model (recommended) |
| checkpoints/r42/stage1_1400M.pkl | 1.4B | Same as final |
| checkpoints/r42/stage1_1300M.pkl | 1.3B | |
| checkpoints/r42/stage1_1200M.pkl | 1.2B | |
| checkpoints/r42/stage1_1100M.pkl | 1.1B | |
| checkpoints/r42/stage1_1000M.pkl | 1.0B | |
| checkpoints/r42/stage1_900M.pkl | 900M | |
| checkpoints/r42/stage1_800M.pkl | 800M | |
| checkpoints/r42/stage1_700M.pkl | 700M | |
| checkpoints/r42/stage1_600M.pkl | 600M | |
| checkpoints/r42/stage1_500M.pkl | 500M | |
| checkpoints/r42/stage1_400M.pkl | 400M | |
| checkpoints/r42/stage1_300M.pkl | 300M | |
| checkpoints/r42/stage1_200M.pkl | 200M | |
| checkpoints/r42/stage1_100M.pkl | 100M | |
| checkpoints/r42/run42_config.toml | – | Training configuration |
Older Runs
| Path | Architecture | Notes |
|---|---|---|
| checkpoints/r40/ | SimBa | 700M steps, predecessor to R42 |
| checkpoints/r39/ | SimBa | 900M steps |
| checkpoints/r37_*.pkl, r38_*.pkl | SimBa | Earlier experiments |
| checkpoints/r33/, r34/, r35/, r36/ | SimBa | Older runs with different obs dims |
| checkpoints/r24_*, r25_*, r26_* | MLP (no residual) | Early MLP architecture |
Other Files
| Path | Description |
|---|---|
| human_games/*.json | Recorded games from human play sessions |
| configs/run26_config.toml | Example training config |
Compatibility note: Checkpoints from R39 and later use the same 342-dim observation and SimBa architecture as R42. Earlier runs (R24-R38) use different observation sizes or architectures and are not interchangeable.
13. Training Details
| Parameter | Value |
|---|---|
| Algorithm | PPO (Proximal Policy Optimization) |
| Total environment steps | 1.4 billion |
| Architecture | SimBa (residual + LayerNorm), 4.58M params |
| Learning rate | 2.5e-4 (annealed to 0) |
| Environments (parallel) | 4,096 |
| Steps per rollout | 128 |
| Minibatches | 4 |
| Update epochs | 4 |
| Discount (gamma) | 1.0 |
| GAE lambda | 0.98 |
| Clip epsilon | 0.2 |
| Entropy coefficient | 0.025 |
| Value function coefficient | 0.75 |
| Max gradient norm | 0.5 |
| Reward | Categorical terminal (+1 gin, +0.25..+0.70 knock win, -0.15..-0.85 losses) |
| Auxiliary loss | Opponent deadwood prediction (coefficient 0.1) |
| P0/P1 alternation | Agent plays as both first and second player |
Opponent Curriculum
The agent trains against a mix of opponents, sampled per-episode:
| Opponent | Probability | Description |
|---|---|---|
| Heuristic | 30% | Rule-based player with meld tracking |
| Aggressive Knock | 15% | Knocks as soon as legally possible |
| Meld Builder | 10% | Prioritizes forming melds over low deadwood |
| Early Knock | 10% | Targets fast knocks with moderate deadwood |
| Defensive | 10% | Conservative, safety-focused play |
| Superhuman Lv5 | 10% | Strong opponent with deep card tracking |
| Frozen Self | 5% | Past checkpoint of the learning agent |
| Superhuman Lv4 | 5% | Moderate-strength superhuman |
| Superhuman Lv7 | 5% | Advanced opponent with aggressive timing |
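The per-episode sampling can be sketched with NumPy (opponent names are ours, taken from the table; the exact sampler in the repo may differ):

```python
import numpy as np

OPPONENTS = ["heuristic", "aggressive_knock", "meld_builder", "early_knock",
             "defensive", "superhuman_lv5", "frozen_self", "superhuman_lv4",
             "superhuman_lv7"]
PROBS = [0.30, 0.15, 0.10, 0.10, 0.10, 0.10, 0.05, 0.05, 0.05]
assert abs(sum(PROBS) - 1.0) < 1e-9  # curriculum probabilities sum to 100%

# Sample one opponent per episode
rng = np.random.default_rng(0)
draws = rng.choice(OPPONENTS, size=10_000, p=PROBS)
frac = (draws == "heuristic").mean()
print(f"heuristic fraction over 10k episodes: {frac:.3f}")  # expect roughly 0.30
```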
License
Apache 2.0