Gin Rummy MDP

A reinforcement learning model for Gin Rummy, trained via PPO self-play in pure JAX.

R42 is the latest checkpoint: a 4.58M parameter SimBa residual network trained for 1.4 billion environment steps against a mixed curriculum of opponents. The checkpoint is an 18 MB pickle file that runs on CPU with no GPU required.

Play it live: https://sleeping-frames-coalition-justin.trycloudflare.com

Source code: github.com/GoodStartLabs/gin-rummy-mdp (not required to use the model)


Table of Contents

  1. Quick Start (Python)
  2. Card Encoding
  3. Action Space
  4. Game Flow
  5. Observation Vector (342 dimensions)
  6. Network Architecture
  7. Checkpoint Format
  8. Inference Step by Step
  9. Legal Action Rules
  10. Scoring Rules
  11. PyTorch Reference Implementation
  12. Available Checkpoints
  13. Training Details

1. Quick Start (Python)

# Install dependencies first: pip install jax flax huggingface_hub
from huggingface_hub import hf_hub_download
import pickle

# Download the R42 final checkpoint (1.4B steps)
path = hf_hub_download(
    repo_id="GoodStartLabs/gin-rummy-mdp",
    filename="checkpoints/r42/stage1_final.pkl",
    repo_type="model",
)

# Load parameters
with open(path, "rb") as f:
    params = pickle.load(f)
p = params.get("params", params)

# p is a dict of weight arrays — see "Checkpoint Format" below
print(sorted(p.keys()))
# ['Dense_0', 'Dense_1', ..., 'LayerNorm_0', ..., 'opp_dw_pred', 'value_discard', 'value_draw', 'value_knock']

2. Card Encoding

Cards are integers 0 through 51.

suit = card // 13    (0=Spades, 1=Hearts, 2=Diamonds, 3=Clubs)
rank = card %  13    (0=Ace, 1=Two, 2=Three, ..., 9=Ten, 10=Jack, 11=Queen, 12=King)
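The two formulas above can be sketched as a small codec (the helper and list names here are illustrative, not part of the repo):

```python
SUITS = ["Spades", "Hearts", "Diamonds", "Clubs"]
RANKS = ["Ace", "Two", "Three", "Four", "Five", "Six", "Seven",
         "Eight", "Nine", "Ten", "Jack", "Queen", "King"]

def card_name(card: int) -> str:
    """Map an integer 0-51 to a readable name, e.g. 17 -> 'Five of Hearts'."""
    return f"{RANKS[card % 13]} of {SUITS[card // 13]}"

def card_index(suit: int, rank: int) -> int:
    """Inverse mapping: (suit, rank) -> 0-51."""
    return suit * 13 + rank
```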

Full Mapping

Card Suit Rank Name
0 Spades 0 Ace of Spades
1 Spades 1 Two of Spades
2 Spades 2 Three of Spades
... ... ... ...
9 Spades 9 Ten of Spades
10 Spades 10 Jack of Spades
11 Spades 11 Queen of Spades
12 Spades 12 King of Spades
13 Hearts 0 Ace of Hearts
14 Hearts 1 Two of Hearts
... ... ... ...
25 Hearts 12 King of Hearts
26 Diamonds 0 Ace of Diamonds
... ... ... ...
38 Diamonds 12 King of Diamonds
39 Clubs 0 Ace of Clubs
... ... ... ...
51 Clubs 12 King of Clubs

Deadwood Values

Rank Card Deadwood Points
0 Ace 1
1 Two 2
2 Three 3
3 Four 4
4 Five 5
5 Six 6
6 Seven 7
7 Eight 8
8 Nine 9
9 Ten 10
10 Jack 10
11 Queen 10
12 King 10
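The whole table reduces to a single expression; a minimal sketch:

```python
def deadwood_value(card: int) -> int:
    """Per-card deadwood points: Ace = 1 up to Ten = 10; face cards capped at 10."""
    rank = card % 13  # 0 = Ace ... 12 = King
    return min(rank + 1, 10)
```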

Melds

  • Set (group): 3 or 4 cards of the same rank, any suits. Example: 5 of Spades (4), 5 of Hearts (17), 5 of Clubs (43).
  • Run (sequence): 3+ consecutive ranks in the same suit. Example: 3/4/5 of Diamonds (28, 29, 30). Aces are low only (A-2-3 is valid, Q-K-A is not).
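Both meld rules can be checked directly from the integer encoding. A sketch, assuming the input cards are distinct:

```python
def is_set(cards: list) -> bool:
    """3 or 4 distinct cards of the same rank."""
    return (len(cards) in (3, 4)
            and len({c % 13 for c in cards}) == 1
            and len(set(cards)) == len(cards))

def is_run(cards: list) -> bool:
    """3+ cards, one suit, consecutive ranks; aces are low only."""
    if len(cards) < 3 or len({c // 13 for c in cards}) != 1:
        return False
    ranks = sorted(c % 13 for c in cards)
    return all(b - a == 1 for a, b in zip(ranks, ranks[1:]))
```

Because Ace is rank 0, Q-K-A (ranks 11, 12, 0) never sorts into a consecutive sequence, so the aces-low rule falls out for free.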

3. Action Space

The model uses a unified 16-action space across all phases. The game phase determines which actions are legal.

Action Phase Meaning
0 Draw Draw from stock pile (face-down)
1 Draw Draw from discard pile (face-up top card)
2 Discard Discard card at hand index 0
3 Discard Discard card at hand index 1
4 Discard Discard card at hand index 2
5 Discard Discard card at hand index 3
6 Discard Discard card at hand index 4
7 Discard Discard card at hand index 5
8 Discard Discard card at hand index 6
9 Discard Discard card at hand index 7
10 Discard Discard card at hand index 8
11 Discard Discard card at hand index 9
12 Discard Discard card at hand index 10
13 Knock Decision Continue playing (don't knock)
14 Knock Decision Knock (requires deadwood <= 10)
15 Knock Decision Gin (requires deadwood = 0)
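A small decoding helper makes the layout concrete (illustrative, not part of the repo):

```python
def describe_action(action: int) -> str:
    """Human-readable name for one of the 16 unified actions."""
    if action == 0:
        return "draw from stock"
    if action == 1:
        return "draw from discard"
    if 2 <= action <= 12:
        return f"discard hand index {action - 2}"
    return {13: "continue (don't knock)", 14: "knock", 15: "gin"}[action]
```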

4. Game Flow

Deal: 10 cards each, 1 upcard placed on discard pile
         |
         v
     [DRAW PHASE] -- Player draws 1 card (from stock or discard)
         |
         v
     [DISCARD PHASE] -- Player discards 1 card (now has 10 again)
         |
         v
     Deadwood <= 10? --No--> Switch to other player, back to DRAW PHASE
         |
        Yes
         v
     [KNOCK DECISION]
         |
      Continue? --Yes--> Switch to other player, back to DRAW PHASE
         |
        No (Knock or Gin)
         v
     [GAME OVER] -- Score the hand

Terminal Conditions

  • Knock: Player declares knock (deadwood 1-10). Score is computed with layoffs.
  • Gin: Player declares gin (deadwood 0). Bonus awarded, no layoffs for defender.
  • Stock exhausted: When 2 or fewer cards remain in the stock pile, the hand is a draw (no points awarded).

5. Observation Vector (342 dimensions)

The model receives a flat float32[342] vector. Every feature is documented below with its exact index range.

Index Range Dims Feature Value Range
0:52 52 Hand mask — 1.0 if card is in the player's hand binary {0, 1}
52:104 52 Discard pile visible — 1.0 if card has been discarded binary {0, 1}
104:156 52 Discard top card — one-hot encoding of the top discard card one-hot
156 1 Deadwood — player's current deadwood / 100 [0, 1]
157:161 4 Phase — one-hot (draw / discard / knock_decision / game_over) one-hot
161 1 Hand size — number of cards in hand / 11 [0, 1]
162 1 Discard pile size — cards in discard / 52 [0, 1]
163 1 Stock remaining — cards left in stock / 31 [0, 1]
164 1 Turn count — turns elapsed / 35 [0, 1]
165 1 Can knock — 1.0 if deadwood <= 10 binary {0, 1}
166:177 11 Discard deadwood — deadwood after discarding each hand slot / 100 [0, 1]
177 1 Draw-from-discard deadwood — best deadwood if drawing top discard / 100 [0, 1]
178:230 52 Opponent drew-from-discard — 1.0 for each card opponent picked from discard binary {0, 1}
230:282 52 Opponent declined-discard — 1.0 for each card opponent chose not to pick binary {0, 1}
282 1 Opponent estimated deadwood — heuristic estimate from card counting [0, 1]
283:294 11 Discard safety — safety score per hand slot (high = opponent unlikely to want it) [0, 1]
294 1 Undercut risk — risk of being undercut if knocking now [0, 1]
295:306 11 Meld membership — 1.0 if discarding that slot increases deadwood binary {0, 1}
306:317 11 Connector scores — how many unseen cards could complete melds with each hand card [0, 1]
317 1 Fraction of hand in melds — sum(meld_membership) / 10 [0, 1]
318 1 Cards from gin — unmelded cards / 10 [0, 1]
319 1 Game urgency — 1.0 - stock_remaining (increases as deck runs out) [0, 1]
320 1 Knock margin estimate — (est_opp_dw - our_dw) / 50 [-1, 1]
321 1 Opponent draw activity — opponent discard draws / 5 [0, 1]
322:342 20 Opponent type — one-hot encoding of opponent type ID one-hot

Notes:

  • Index 163: stock is normalized by 31 (52 total - 21 dealt cards = 31 initial stock).
  • Index 166:177: invalid hand slots (index >= hand_size) are padded with 1.0.
  • Index 177: set to 1.0 if discard pile is empty.
  • Index 322:342: set to all zeros for unknown opponents or during human play.
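As a sketch of how a few of these slots are filled, here is a hypothetical partial builder (the full 342-dim vector needs the complete game state; indices follow the table above):

```python
import numpy as np

def partial_observation(hand: list, phase_id: int) -> np.ndarray:
    """Fill the hand mask, phase one-hot, and hand-size slots only.
    Hypothetical helper for illustration; not part of the repo."""
    obs = np.zeros(342, dtype=np.float32)
    for card in hand:              # 0:52   hand mask
        obs[card] = 1.0
    obs[157 + phase_id] = 1.0      # 157:161 phase one-hot
    obs[161] = len(hand) / 11.0    # hand size, normalized
    return obs
```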

6. Network Architecture

The R42 model uses a SimBa residual architecture (pre-LayerNorm residual blocks) with phase-decomposed value heads and an auxiliary opponent-deadwood prediction head.

Layer-by-Layer Specification

Input: float32[342]
  |
  v
Dense_0: Linear(342 -> 1024) + bias
  |  activation: ReLU
  v
=== Residual Block 1 ===
  |-- save as `residual`
  |   LayerNorm_0: scale[1024], bias[1024], eps=1e-5
  |   Dense_1: Linear(1024 -> 1024) + bias
  |   activation: ReLU
  |   Dense_2: Linear(1024 -> 1024) + bias (NO activation)
  |-- x = residual + x
  v
=== Residual Block 2 ===
  |-- save as `residual`
  |   LayerNorm_1: scale[1024], bias[1024], eps=1e-5
  |   Dense_3: Linear(1024 -> 1024) + bias
  |   activation: ReLU
  |   Dense_4: Linear(1024 -> 1024) + bias (NO activation)
  |-- x = residual + x
  v
LayerNorm_2: scale[1024], bias[1024], eps=1e-5
  |
  v (shared features, used by all heads below)
  |
  +-- Dense_5: Linear(1024 -> 16)     --> action logits
  +-- value_draw: Linear(1024 -> 1)   --> value estimate (draw phase)
  +-- value_discard: Linear(1024 -> 1) --> value estimate (discard phase)
  +-- value_knock: Linear(1024 -> 1)  --> value estimate (knock decision phase)
  +-- opp_dw_pred: Linear(1024 -> 1)  --> auxiliary opponent deadwood prediction

Parameter Count

Layer Parameters
Dense_0 (342x1024 + 1024) 351,232
Dense_1 (1024x1024 + 1024) 1,049,600
Dense_2 (1024x1024 + 1024) 1,049,600
Dense_3 (1024x1024 + 1024) 1,049,600
Dense_4 (1024x1024 + 1024) 1,049,600
Dense_5 (1024x16 + 16) 16,400
LayerNorm_0 (1024 + 1024) 2,048
LayerNorm_1 (1024 + 1024) 2,048
LayerNorm_2 (1024 + 1024) 2,048
value_draw (1024x1 + 1) 1,025
value_discard (1024x1 + 1) 1,025
value_knock (1024x1 + 1) 1,025
opp_dw_pred (1024x1 + 1) 1,025
Total 4,576,276
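The total can be verified in a few lines:

```python
# Verify the parameter total from the table above.
dense_shapes = [(342, 1024)] + [(1024, 1024)] * 4 + [(1024, 16)] + [(1024, 1)] * 4
dense_params = sum(i * o + o for i, o in dense_shapes)  # kernel + bias per layer
norm_params = 3 * (1024 + 1024)                         # scale + bias per LayerNorm
print(dense_params + norm_params)  # 4576276
```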

7. Checkpoint Format

The .pkl file is a Python pickle containing a dict:

{
    "params": {
        "Dense_0": {"kernel": float32[342, 1024], "bias": float32[1024]},
        "Dense_1": {"kernel": float32[1024, 1024], "bias": float32[1024]},
        "Dense_2": {"kernel": float32[1024, 1024], "bias": float32[1024]},
        "Dense_3": {"kernel": float32[1024, 1024], "bias": float32[1024]},
        "Dense_4": {"kernel": float32[1024, 1024], "bias": float32[1024]},
        "Dense_5": {"kernel": float32[1024, 16], "bias": float32[16]},
        "LayerNorm_0": {"scale": float32[1024], "bias": float32[1024]},
        "LayerNorm_1": {"scale": float32[1024], "bias": float32[1024]},
        "LayerNorm_2": {"scale": float32[1024], "bias": float32[1024]},
        "value_draw": {"kernel": float32[1024, 1], "bias": float32[1]},
        "value_discard": {"kernel": float32[1024, 1], "bias": float32[1]},
        "value_knock": {"kernel": float32[1024, 1], "bias": float32[1]},
        "opp_dw_pred": {"kernel": float32[1024, 1], "bias": float32[1]},
    }
}

Important: Flax Dense layers store the kernel as [input_dim, output_dim] (the transpose of PyTorch's [out, in] layout). The forward pass computes output = input @ kernel + bias.
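A quick shape check after loading can catch a mismatched or truncated checkpoint early (a sketch; the helper name is illustrative):

```python
import numpy as np

def check_params(p: dict) -> None:
    """Assert the loaded tree matches the R42 layout documented above."""
    assert np.shape(p["Dense_0"]["kernel"]) == (342, 1024)
    for name in ("Dense_1", "Dense_2", "Dense_3", "Dense_4"):
        assert np.shape(p[name]["kernel"]) == (1024, 1024)
    assert np.shape(p["Dense_5"]["kernel"]) == (1024, 16)
    for name in ("LayerNorm_0", "LayerNorm_1", "LayerNorm_2"):
        assert np.shape(p[name]["scale"]) == (1024,)
    for head in ("value_draw", "value_discard", "value_knock", "opp_dw_pred"):
        assert np.shape(p[head]["kernel"]) == (1024, 1)
```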


8. Inference Step by Step

1. Construct the 342-dimensional observation vector from game state
   (see Section 5 for the complete index map)

2. Forward pass through the network:
   x = relu(obs @ Dense_0.kernel + Dense_0.bias)
   # Residual block 1
   r = x
   x = layer_norm(x, LN0.scale, LN0.bias)
   x = relu(x @ Dense_1.kernel + Dense_1.bias)
   x = x @ Dense_2.kernel + Dense_2.bias
   x = r + x
   # Residual block 2
   r = x
   x = layer_norm(x, LN1.scale, LN1.bias)
   x = relu(x @ Dense_3.kernel + Dense_3.bias)
   x = x @ Dense_4.kernel + Dense_4.bias
   x = r + x
   # Final norm
   x = layer_norm(x, LN2.scale, LN2.bias)
   # Actor head
   logits = x @ Dense_5.kernel + Dense_5.bias   # float32[16]

3. Compute legal action mask based on current game phase (see Section 9)

4. Mask illegal actions:
   for i in range(16):
       if not legal[i]:
           logits[i] = -infinity

5. Select action:
   - Greedy: action = argmax(logits)
   - Stochastic: action = sample from softmax(logits)

6. Execute the action in your game engine

LayerNorm formula:

layer_norm(x, scale, bias, eps=1e-5):
    mean = mean(x)
    var = var(x)
    x_norm = (x - mean) / sqrt(var + eps)
    return x_norm * scale + bias
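Steps 2 through 5 collapse into a short dependency-free NumPy sketch, assuming `p` is the `params` dict from Section 7:

```python
import numpy as np

def layer_norm(x, scale, bias, eps=1e-5):
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * scale + bias

def actor_logits(p, obs):
    """Step 2: forward pass of the R42 trunk plus actor head."""
    relu = lambda z: np.maximum(z, 0.0)
    dense = lambda name, z: z @ p[name]["kernel"] + p[name]["bias"]
    x = relu(dense("Dense_0", obs))
    for ln, d1, d2 in (("LayerNorm_0", "Dense_1", "Dense_2"),
                       ("LayerNorm_1", "Dense_3", "Dense_4")):
        r = x
        x = layer_norm(x, p[ln]["scale"], p[ln]["bias"])
        x = r + dense(d2, relu(dense(d1, x)))    # residual block
    x = layer_norm(x, p["LayerNorm_2"]["scale"], p["LayerNorm_2"]["bias"])
    return dense("Dense_5", x)                   # float32[16] action logits

def select_action(logits, legal):
    """Steps 4-5: mask illegal actions, then pick greedily."""
    masked = np.where(np.asarray(legal), logits, -np.inf)
    return int(np.argmax(masked))
```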

9. Legal Action Rules

Draw Phase (player has 10 cards)

Action Legal When
0 (draw stock) Stock has > 2 cards remaining
1 (draw discard) Discard pile is not empty

Both are typically legal. If stock is nearly exhausted, only discard draw may be available.

Discard Phase (player has 11 cards)

Action Legal When
2 through 12 Hand index (action - 2) < hand_size AND the card at that index is not the card just drawn from the discard pile

The "re-discard ban" prevents immediately returning a card picked up from the discard pile.

Knock Decision Phase (player has 10 cards, deadwood <= 10)

Action Legal When
13 (continue) Always legal
14 (knock) Deadwood is 1 through 10
15 (gin) Deadwood is exactly 0

This phase only occurs when deadwood <= 10 after discarding. If deadwood > 10, the game skips directly to the other player's draw phase.
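The three tables above combine into one mask builder; a sketch with illustrative argument names:

```python
def legal_mask(phase: str, hand_size: int, stock_count: int,
               discard_empty: bool, drawn_discard_slot, deadwood: int):
    """16-element legality mask per the rules above. `drawn_discard_slot` is the
    hand index of the card just taken from the discard pile, or None."""
    legal = [False] * 16
    if phase == "draw":
        legal[0] = stock_count > 2        # draw from stock
        legal[1] = not discard_empty      # draw from discard
    elif phase == "discard":
        for slot in range(11):            # re-discard ban on the drawn card
            legal[2 + slot] = slot < hand_size and slot != drawn_discard_slot
    elif phase == "knock_decision":
        legal[13] = True                  # continue is always legal
        legal[14] = 1 <= deadwood <= 10   # knock
        legal[15] = deadwood == 0         # gin
    return legal
```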


10. Scoring Rules

Normal Knock (deadwood 1-10)

  1. Knocker reveals hand and forms melds.
  2. Defender reveals hand, forms their own melds, then lays off unmelded cards onto the knocker's melds (extending runs or completing sets).
  3. Compare deadwood:
    • Knocker wins: knocker_deadwood < defender_deadwood_after_layoffs. Knocker scores the difference.
    • Undercut: defender_deadwood_after_layoffs <= knocker_deadwood. Defender scores the difference + 25 bonus.

Gin (deadwood = 0)

  • Knocker scores defender's deadwood + 25 bonus.
  • Defender gets no layoffs against a gin hand.

Stock Exhausted

  • When 2 or fewer cards remain in the stock pile, the hand ends in a draw.
  • Neither player scores any points.

Deadwood Calculation

Sum the deadwood values (see Section 2) of all cards NOT part of any meld. The optimal meld arrangement is used (minimizing deadwood).
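Given the two final deadwood counts, the scoring logic is a few lines (a sketch; bonus values taken from this section):

```python
def score_hand(knocker_dw: int, defender_dw: int, gin: bool = False):
    """Return (knocker_points, defender_points); defender_dw is after layoffs."""
    if gin:
        return defender_dw + 25, 0                 # gin bonus, no layoffs allowed
    if knocker_dw < defender_dw:
        return defender_dw - knocker_dw, 0         # normal knock win
    return 0, knocker_dw - defender_dw + 25        # undercut bonus to defender
```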


11. PyTorch Reference Implementation

A complete, self-contained PyTorch implementation for loading and running the model. No JAX or Flax dependency required.

import pickle
import numpy as np
import torch
import torch.nn as nn


class GinRummyModel(nn.Module):
    """R42 SimBa architecture with phase-decomposed value heads."""

    def __init__(self):
        super().__init__()
        # Input projection
        self.dense_0 = nn.Linear(342, 1024)
        # Residual block 1
        self.ln_0 = nn.LayerNorm(1024, eps=1e-5)
        self.dense_1 = nn.Linear(1024, 1024)
        self.dense_2 = nn.Linear(1024, 1024)
        # Residual block 2
        self.ln_1 = nn.LayerNorm(1024, eps=1e-5)
        self.dense_3 = nn.Linear(1024, 1024)
        self.dense_4 = nn.Linear(1024, 1024)
        # Final norm
        self.ln_2 = nn.LayerNorm(1024, eps=1e-5)
        # Output heads
        self.actor = nn.Linear(1024, 16)
        self.value_draw = nn.Linear(1024, 1)
        self.value_discard = nn.Linear(1024, 1)
        self.value_knock = nn.Linear(1024, 1)
        self.opp_dw_pred = nn.Linear(1024, 1)

    def forward(self, obs):
        """
        Args:
            obs: float32 tensor of shape (..., 342)
        Returns:
            logits:        float32 (..., 16)  — raw action logits (mask before use)
            value_draw:    float32 (..., 1)   — value estimate for draw phase
            value_discard: float32 (..., 1)   — value estimate for discard phase
            value_knock:   float32 (..., 1)   — value estimate for knock phase
            opp_dw_pred:   float32 (..., 1)   — predicted opponent deadwood
        """
        # Input projection
        x = torch.relu(self.dense_0(obs))

        # Residual block 1
        r = x
        x = self.ln_0(x)
        x = torch.relu(self.dense_1(x))
        x = self.dense_2(x)
        x = r + x

        # Residual block 2
        r = x
        x = self.ln_1(x)
        x = torch.relu(self.dense_3(x))
        x = self.dense_4(x)
        x = r + x

        # Final norm + heads
        x = self.ln_2(x)
        logits = self.actor(x)
        return (
            logits,
            self.value_draw(x),
            self.value_discard(x),
            self.value_knock(x),
            self.opp_dw_pred(x),
        )


def load_from_pkl(pkl_path: str) -> GinRummyModel:
    """Load Flax weights from a .pkl checkpoint into PyTorch.

    Handles the Flax [in, out] -> PyTorch [out, in] kernel transposition.
    """
    with open(pkl_path, "rb") as f:
        params = pickle.load(f)
    p = params.get("params", params)

    model = GinRummyModel()

    def set_linear(module, name):
        # Flax kernel is [in_features, out_features]
        # PyTorch weight is [out_features, in_features]
        module.weight.data = torch.from_numpy(np.array(p[name]["kernel"]).T)
        module.bias.data = torch.from_numpy(np.array(p[name]["bias"]).ravel())

    def set_ln(module, name):
        module.weight.data = torch.from_numpy(np.array(p[name]["scale"]))
        module.bias.data = torch.from_numpy(np.array(p[name]["bias"]))

    set_linear(model.dense_0, "Dense_0")
    set_linear(model.dense_1, "Dense_1")
    set_linear(model.dense_2, "Dense_2")
    set_linear(model.dense_3, "Dense_3")
    set_linear(model.dense_4, "Dense_4")
    set_linear(model.actor, "Dense_5")
    set_linear(model.value_draw, "value_draw")
    set_linear(model.value_discard, "value_discard")
    set_linear(model.value_knock, "value_knock")
    set_linear(model.opp_dw_pred, "opp_dw_pred")
    set_ln(model.ln_0, "LayerNorm_0")
    set_ln(model.ln_1, "LayerNorm_1")
    set_ln(model.ln_2, "LayerNorm_2")

    model.eval()
    return model


# --- Usage example ---
if __name__ == "__main__":
    from huggingface_hub import hf_hub_download

    path = hf_hub_download(
        repo_id="GoodStartLabs/gin-rummy-mdp",
        filename="checkpoints/r42/stage1_final.pkl",
        repo_type="model",
    )
    model = load_from_pkl(path)

    # Create a dummy observation (all zeros)
    obs = torch.zeros(342)
    with torch.no_grad():
        logits, v_draw, v_discard, v_knock, opp_dw = model(obs)
    print(f"Logits shape: {logits.shape}")  # torch.Size([16])
    print(f"Top action: {logits.argmax().item()}")

12. Available Checkpoints

R42 (latest, recommended)

SimBa + auxiliary heads, trained with mixed opponent curriculum. Use stage1_final.pkl unless you need an intermediate snapshot.

File Steps Notes
checkpoints/r42/stage1_final.pkl 1.4B Final model (recommended)
checkpoints/r42/stage1_1400M.pkl 1.4B Same as final
checkpoints/r42/stage1_1300M.pkl 1.3B
checkpoints/r42/stage1_1200M.pkl 1.2B
checkpoints/r42/stage1_1100M.pkl 1.1B
checkpoints/r42/stage1_1000M.pkl 1.0B
checkpoints/r42/stage1_900M.pkl 900M
checkpoints/r42/stage1_800M.pkl 800M
checkpoints/r42/stage1_700M.pkl 700M
checkpoints/r42/stage1_600M.pkl 600M
checkpoints/r42/stage1_500M.pkl 500M
checkpoints/r42/stage1_400M.pkl 400M
checkpoints/r42/stage1_300M.pkl 300M
checkpoints/r42/stage1_200M.pkl 200M
checkpoints/r42/stage1_100M.pkl 100M
checkpoints/r42/run42_config.toml — Training configuration

Older Runs

Path Architecture Notes
checkpoints/r40/ SimBa 700M steps, predecessor to R42
checkpoints/r39/ SimBa 900M steps
checkpoints/r37_*.pkl, r38_*.pkl SimBa Earlier experiments
checkpoints/r33/, r34/, r35/, r36/ SimBa Older runs with different obs dims
checkpoints/r24_*, r25_*, r26_* MLP (no residual) Early MLP architecture

Other Files

Path Description
human_games/*.json Recorded games from human play sessions
configs/run26_config.toml Example training config

Compatibility note: Checkpoints from R39 and later use the same 342-dim observation and SimBa architecture as R42. Earlier runs (R24-R38) use different observation sizes or architectures and are not interchangeable.


13. Training Details

Parameter Value
Algorithm PPO (Proximal Policy Optimization)
Total environment steps 1.4 billion
Architecture SimBa (residual + LayerNorm), 4.58M params
Learning rate 2.5e-4 (annealed to 0)
Environments (parallel) 4,096
Steps per rollout 128
Minibatches 4
Update epochs 4
Discount (gamma) 1.0
GAE lambda 0.98
Clip epsilon 0.2
Entropy coefficient 0.025
Value function coefficient 0.75
Max gradient norm 0.5
Reward Categorical terminal (+1 gin, +0.25..+0.70 knock win, -0.15..-0.85 losses)
Auxiliary loss Opponent deadwood prediction (coefficient 0.1)
P0/P1 alternation Agent plays as both first and second player

Opponent Curriculum

The agent trains against a mix of opponents, sampled per-episode:

Opponent Probability Description
Heuristic 30% Rule-based player with meld tracking
Aggressive Knock 15% Knocks as soon as legally possible
Meld Builder 10% Prioritizes forming melds over low deadwood
Early Knock 10% Targets fast knocks with moderate deadwood
Defensive 10% Conservative, safety-focused play
Superhuman Lv5 10% Strong opponent with deep card tracking
Frozen Self 5% Past checkpoint of the learning agent
Superhuman Lv4 5% Moderate-strength superhuman
Superhuman Lv7 5% Advanced opponent with aggressive timing

License

Apache 2.0
