# Fully Sharded Data Parallel (FSDP) v2

## Overview

When fine-tuning Large Language Models (LLMs) on TPUs, sharding the model across devices is essential for fitting it in memory and can improve training performance. The `optimum.tpu.fsdp_v2` module provides utilities for Fully Sharded Data Parallel training using SPMD (Single Program, Multiple Data), optimized for TPU devices.

## FSDP_v2 Features

- Model weight sharding across TPU devices
- Gradient checkpointing support
- Automatic configuration for common model architectures
- Integration with PyTorch/XLA's SPMD implementation

## Basic Usage

Here's how to enable and configure FSDP_v2 for a training run:

```python
from optimum.tpu import fsdp_v2
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Enable FSDP_v2
fsdp_v2.use_fsdp_v2()

# Load model and tokenizer
model_id = "meta-llama/Llama-2-7b-hf"  # Transformers-format checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16
)

# Get FSDP training configuration
fsdp_args = fsdp_v2.get_fsdp_training_args(model)
```

## Configuration Options

The `get_fsdp_training_args()` function returns a dictionary with a model-specific configuration such as:

```python
{
    'fsdp': 'full_shard',
    'fsdp_config': {
        'transformer_layer_cls_to_wrap': ['LlamaDecoderLayer'],  # Model-specific
        'xla': True,
        'xla_fsdp_v2': True,
        'xla_fsdp_grad_ckpt': True
    }
}
```

### Key Parameters

- `transformer_layer_cls_to_wrap`: Class names of the model layers to wrap with FSDP
- `xla`: Enables PyTorch/XLA FSDP training
- `xla_fsdp_v2`: Activates the SPMD-based FSDP_v2 implementation
- `xla_fsdp_grad_ckpt`: Enables gradient checkpointing for memory efficiency
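
Because the result is an ordinary dictionary, individual settings can be adjusted before training. The sketch below assumes the dictionary has the shape shown above; the literal is copied from that example rather than produced by `get_fsdp_training_args()`, and `without_grad_ckpt` is a hypothetical helper name. It turns off gradient checkpointing on a copy and leaves the original untouched:

```python
import copy

# Example configuration copied from the section above; in real use this
# would come from fsdp_v2.get_fsdp_training_args(model).
fsdp_args = {
    "fsdp": "full_shard",
    "fsdp_config": {
        "transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
        "xla": True,
        "xla_fsdp_v2": True,
        "xla_fsdp_grad_ckpt": True,
    },
}


def without_grad_ckpt(args):
    """Return a deep copy of `args` with gradient checkpointing disabled."""
    updated = copy.deepcopy(args)  # deep copy: fsdp_config is a nested dict
    updated["fsdp_config"]["xla_fsdp_grad_ckpt"] = False
    return updated


no_ckpt_args = without_grad_ckpt(fsdp_args)
print(no_ckpt_args["fsdp_config"]["xla_fsdp_grad_ckpt"])  # False
print(fsdp_args["fsdp_config"]["xla_fsdp_grad_ckpt"])     # True
```

Disabling gradient checkpointing generally raises memory use in exchange for less recomputation, so whether it helps depends on the model and the available TPU memory.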

## Advanced Usage

### Custom Layer Wrapping

You can customize which layers get wrapped with FSDP:

```python
custom_fsdp_args = fsdp_v2.get_fsdp_training_args(
    model,
    layer_cls_to_wrap=['CustomTransformerLayer']
)
```
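
To find a class name to pass here, it can help to list the module classes that repeat inside the model, since repeated blocks are the usual wrapping candidates. The following is a minimal sketch, not part of `optimum.tpu`: `count_module_classes` is a hypothetical helper that relies only on a `modules()` method, which `torch.nn.Module` provides, and the `_Stub` classes are stand-ins so the example runs without loading a real model.

```python
from collections import Counter


def count_module_classes(model):
    """Count how often each module class name occurs in `model`.

    `model` only needs a `modules()` method that yields itself and all
    submodules, as `torch.nn.Module` does.
    """
    return Counter(type(module).__name__ for module in model.modules())


class _Stub:
    """Stand-in for torch.nn.Module, only so this sketch runs without torch."""

    def __init__(self, *children):
        self._children = children

    def modules(self):
        # Yield this module first, then every descendant.
        yield self
        for child in self._children:
            yield from child.modules()


class CustomTransformerLayer(_Stub):
    pass


class ToyModel(_Stub):
    pass


toy = ToyModel(CustomTransformerLayer(), CustomTransformerLayer())
counts = count_module_classes(toy)
repeated = [name for name, n in counts.items() if n > 1]
print(repeated)  # ['CustomTransformerLayer']
```

With a real model, calling `count_module_classes(model)` on the loaded instance should work the same way, since only `modules()` is used.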

### Integration with Transformers Trainer

The FSDP_v2 configuration can be passed directly to the Transformers `Trainer`:

```python
from transformers import Trainer, TrainingArguments
# Or for instruction fine-tuning:
# from trl import SFTTrainer

trainer = Trainer(  # or SFTTrainer
    model=model,
    args=TrainingArguments(**fsdp_args),  # Unpack FSDP configuration
    train_dataset=dataset,
    ...
)
```
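
In practice, `TrainingArguments` usually takes more than the FSDP settings. A minimal sketch of combining both before unpacking is shown below. It uses a literal copy of the example dictionary from this page, `merge_training_kwargs` is a hypothetical helper, and the hyperparameter values are placeholders rather than recommendations:

```python
def merge_training_kwargs(fsdp_args, **other_args):
    """Combine FSDP arguments with other TrainingArguments keyword arguments.

    Raises ValueError if a key appears in both, so an FSDP setting is
    never overridden silently.
    """
    clash = set(fsdp_args) & set(other_args)
    if clash:
        raise ValueError(f"Arguments set twice: {sorted(clash)}")
    return {**fsdp_args, **other_args}


# Example configuration copied from this page; in real use this would
# come from fsdp_v2.get_fsdp_training_args(model).
fsdp_args = {
    "fsdp": "full_shard",
    "fsdp_config": {
        "transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
        "xla": True,
        "xla_fsdp_v2": True,
        "xla_fsdp_grad_ckpt": True,
    },
}

training_kwargs = merge_training_kwargs(
    fsdp_args,
    output_dir="./output",           # placeholder path
    per_device_train_batch_size=8,   # placeholder value
    num_train_epochs=1,              # placeholder value
)
print(sorted(training_kwargs))
```

The merged dictionary can then be unpacked as `TrainingArguments(**training_kwargs)`.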

## Next steps

- See our [example notebooks](../howto/more_examples) for best practices on training with optimum-tpu.
- For more details on PyTorch/XLA's FSDP implementation, refer to the [official documentation](https://pytorch.org/xla/master/#fully-sharded-data-parallel-via-spmd).