Understanding the Collision

What does a GPU see when it looks at a model parameter?

A parameter is just a tensor - a contiguous block of numbers in memory. A 4096×4096 weight matrix in float16 takes 32MB. Simple.

[float16, float16, float16, ... ] → 32MB on GPU

What does FSDP2 do?

FSDP2's job is elegant: take that 32MB tensor and split it across 2 GPUs. Each GPU holds 16MB. When you need the full tensor for computation, GPUs talk to each other (all-gather), compute, then throw away the extra copy.

GPU 0: [first half...]  → 16MB
GPU 1: [second half...] → 16MB

To do this, FSDP2 must convert your tensor into a DTensor (Distributed Tensor) - a tensor that knows it's sharded and knows how to reassemble itself.
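The 32MB figure and the two 16MB halves are easy to check. The sketch below rests on one assumption: that FSDP2 splits each parameter along dim 0 the way `torch.chunk` does, then pads short shards so every rank holds the same size. `tensor_bytes` and `shard_rows` are hypothetical helpers for illustration, not FSDP2 APIs.

```python
def tensor_bytes(rows: int, cols: int, itemsize: int) -> int:
    """Bytes needed for a dense rows x cols tensor."""
    return rows * cols * itemsize

def shard_rows(rows: int, world_size: int) -> list[int]:
    """Rows per rank when splitting dim 0 like torch.chunk (assumption).

    Every rank gets ceil(rows / world_size) rows except at the tail, which gets
    whatever is left (possibly zero); FSDP2 pads short shards to equal size.
    """
    chunk = -(-rows // world_size)  # integer ceiling division
    sizes, remaining = [], rows
    for _ in range(world_size):
        take = min(chunk, remaining)
        sizes.append(take)
        remaining -= take
    return sizes

full = tensor_bytes(4096, 4096, itemsize=2)           # float16 = 2 bytes per element
print(full / 2**20, "MiB")                             # 32.0 MiB
per_rank = [r * 4096 * 2 / 2**20 for r in shard_rows(4096, 2)]
print(per_rank)                                        # [16.0, 16.0]
```

Uneven row counts are where the padding matters: 5 rows over 4 ranks gives `[2, 2, 1, 0]`.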

What does QLoRA do?

QLoRA compresses that 32MB tensor down to ~8MB using 4-bit quantization (4 bits per weight instead of 16). But here's the catch: you can't store the 4-bit codes on their own. You need metadata to decode them back to usable values.

Params4bit:
  - data: [4-bit packed values...] → ~8MB
  - quant_state: {absmax, quant_type="nf4", blocksize=64, ...}

The quant_state is a Python object (bitsandbytes' QuantState) hanging off the tensor. It's not a tensor itself - it's metadata, including the per-block absmax scales, that tells bitsandbytes how to dequantize.
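To see why the metadata works like a decoder ring, here is a toy version of blockwise absmax quantization in pure Python. It is a simplified stand-in, not bitsandbytes' implementation: `CODEBOOK` is a uniform 16-level table (real NF4 uses a fixed, non-uniform table derived from a normal distribution), and all function names are hypothetical. The structure carries over, though. Each block stores 4-bit codes packed two per byte plus one absmax scale, and the codes are meaningless without the scale and the table.

```python
# Toy 16-entry codebook spanning [-1, 1]; real NF4 uses a different, non-uniform table.
CODEBOOK = [-1.0 + 2.0 * i / 15 for i in range(16)]

def quantize_block(values: list[float]) -> tuple[list[int], float]:
    """Scale by the block's absmax, then map each value to the nearest codebook index."""
    absmax = max(abs(v) for v in values) or 1.0  # avoid dividing by zero on an all-zero block
    codes = [min(range(16), key=lambda i: abs(CODEBOOK[i] - v / absmax)) for v in values]
    return codes, absmax

def pack_nibbles(codes: list[int]) -> bytes:
    """Store two 4-bit codes per byte, high nibble first."""
    if len(codes) % 2:
        codes = codes + [0]
    return bytes((codes[i] << 4) | codes[i + 1] for i in range(0, len(codes), 2))

def unpack_nibbles(packed: bytes, count: int) -> list[int]:
    codes = []
    for byte in packed:
        codes.extend((byte >> 4, byte & 0x0F))
    return codes[:count]

def dequantize_block(packed: bytes, count: int, absmax: float) -> list[float]:
    """Decoding needs the metadata: the codebook and this block's absmax."""
    return [CODEBOOK[c] * absmax for c in unpack_nibbles(packed, count)]

values = [0.5, -1.2, 0.0, 2.4]
codes, absmax = quantize_block(values)
packed = pack_nibbles(codes)                  # 4 values -> 2 bytes
restored = dequantize_block(packed, len(values), absmax)
print(len(packed), absmax)                    # 2 2.4
print([round(v, 2) for v in restored])
```

In bitsandbytes the block size is 64 (as in the quant_state above), so one fp32 absmax per block adds 32/64 = 0.5 bits per weight; double quantization (`bnb_4bit_use_double_quant`) shrinks that overhead by quantizing the absmax values themselves.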

The Collision

Now watch what happens when FSDP2 meets Params4bit:

  1. FSDP2 sees a parameter, tries to convert it to DTensor
  2. It inspects the tensor... but this isn't a normal tensor
  3. It has custom `__new__` and `__init__` methods
  4. It has this quant_state thing attached to it
  5. FSDP2 doesn't know how to shard an arbitrary Python object
RuntimeError: Cannot flatten integer dtype tensors

FSDP2 is asking: "How do I split this in half?"

And Params4bit is saying: "I'm not just a tensor. I'm a tensor plus a decoder ring. You can't split the decoder ring."
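You can reproduce the core of the collision, in miniature, with plain PyTorch. The sketch below attaches a `quant_state` attribute to an ordinary `nn.Parameter` (a hypothetical stand-in for `Params4bit`, with made-up metadata values) and splits it roughly the way a two-way shard would. The halves come back as plain tensors: the bytes survive, the decoder ring does not.

```python
import torch
import torch.nn as nn

# Stand-in for Params4bit: packed integer storage plus metadata hanging off the tensor.
weight = nn.Parameter(torch.zeros(8, 4, dtype=torch.uint8), requires_grad=False)
weight.quant_state = {"absmax": torch.ones(2), "quant_type": "nf4", "blocksize": 64}

# Split along dim 0, as a two-way shard would.
shards = torch.chunk(weight, 2, dim=0)

print(hasattr(weight, "quant_state"))                 # True
print([type(s).__name__ for s in shards])             # ['Tensor', 'Tensor']
print([hasattr(s, "quant_state") for s in shards])    # [False, False]
```

Python attributes live on the tensor object, not in its storage, so any operation that produces new tensor objects leaves them behind.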

The Key Insight

Let's think about what actually needs to be sharded.

In QLoRA training, you have two kinds of parameters:

Type                    Size                     Trainable   Needs gradients   Needs optimizer state
Base weights (4-bit)    ~4GB                     No          No                No
LoRA adapters (fp16)    ~336MB (~168M params)    Yes         Yes               Yes

The frozen 4-bit base weights never change. They're used in the forward pass (and in the backward pass, to propagate gradients through to the adapters), but never updated. The LoRA adapters are tiny, but they accumulate gradients and optimizer states.

So here's the insight: we don't need to shard the problematic Params4bit at all. We only need to shard the LoRA adapters, which are regular fp16 tensors that FSDP2 handles perfectly.

Base weights:  Replicate on both GPUs (~4GB each - acceptable)
LoRA adapters: Shard with FSDP2 (~168MB each - sharded)
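The adapter sizes follow directly from the LoRA config in the full script (r = 64 on all seven projections). A quick count, assuming Llama 3.1 8B's shapes from its model config: hidden size 4096, key/value projection width 1024 (8 KV heads × 128), MLP width 14336, and 32 layers. `lora_param_count` is a hypothetical helper.

```python
# Llama 3.1 8B projection shapes as (in_features, out_features).
HIDDEN, KV_DIM, INTERMEDIATE, NUM_LAYERS = 4096, 1024, 14336, 32
PROJ_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

def lora_param_count(r: int) -> int:
    """lora_A is (r, in_features) and lora_B is (out_features, r) for each projection."""
    per_layer = sum(r * (fan_in + fan_out) for fan_in, fan_out in PROJ_SHAPES.values())
    return per_layer * NUM_LAYERS

n = lora_param_count(64)
print(n)                                    # 167772160
print(f"{n * 2 / 1e6:.0f} MB total in fp16, {n * 2 / 2 / 1e6:.0f} MB per GPU")
# 336 MB total in fp16, 168 MB per GPU
```

The count is linear in r, so halving the rank to 32 halves the adapter memory.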

Here's what that looks like in practice:

from bitsandbytes.nn import Params4bit
from torch.distributed.fsdp import fully_shard

# Collect all the quantized parameters we want FSDP2 to ignore
ignored_params = {p for p in model.parameters() if isinstance(p, Params4bit)}

# Apply FSDP2 to each transformer layer, skipping quantized weights
for layer in model.base_model.model.model.layers:
    fully_shard(layer, ignored_params=ignored_params)

That's it. Two lines that unlock FSDP2 + QLoRA.

What's happening here?

  1. We identify every Params4bit in the model
  2. We pass them to ignored_params - telling FSDP2 "don't touch these"
  3. FSDP2 shards everything else (the LoRA adapters)

The quantized weights stay intact with their quant_state metadata. FSDP2 never tries to convert them to DTensors. No collision.
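Before calling `fully_shard`, it can be worth checking what actually lands in each bucket. This is a hypothetical helper (not part of FSDP2) that takes the same set you pass as `ignored_params` and reports how many bytes FSDP2 would leave alone versus manage. In the real script the set holds the `Params4bit` tensors; here a small `nn.Sequential` stands in for the model.

```python
import torch.nn as nn

def summarize_ignored(model: nn.Module, ignored_params: set) -> dict[str, int]:
    """Bytes of parameters FSDP2 would skip vs. manage, given an ignored_params set."""
    ignored_ids = {id(p) for p in ignored_params}
    totals = {"ignored_bytes": 0, "managed_bytes": 0}
    for p in model.parameters():
        key = "ignored_bytes" if id(p) in ignored_ids else "managed_bytes"
        totals[key] += p.numel() * p.element_size()
    return totals

# Toy model: pretend the first Linear is the frozen quantized base.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
ignored = set(model[0].parameters())
print(summarize_ignored(model, ignored))
# {'ignored_bytes': 80, 'managed_bytes': 40}
```

In the QLoRA setup you would expect the ignored bucket to dominate: gigabytes of 4-bit base weights versus a few hundred megabytes of adapters.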

But there's a catch. Run this code and you'll hit:

ValueError: Your model contains `DTensor` parameters, 
which is incompatible with DDP.

The Trainer sees your distributed setup and tries to wrap your already-sharded model with DDP. We need to tell it to back off:

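from trl import SFTTrainer

class FSDP2Trainer(SFTTrainer):
    def _wrap_model(self, model, training=True, dataloader=None):
        return model  # FSDP2 was already applied manually; skip the Trainer's wrapping

trainer = FSDP2Trainer(...)
# Accelerate's prepare() wraps the model a second time; make that step a no-op too
trainer.accelerator.prepare_model = lambda model, **kwargs: model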

Run it. Watch both GPUs light up. Check the loss going down.

torchrun --nproc_per_node=2 train_manual_fsdp2.py
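`torchrun --nproc_per_node=2` launches two copies of the script and tells each one who it is through environment variables; `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` are the ones this approach relies on. A minimal sketch of reading them and building the per-process `device_map` that the full script passes to `from_pretrained` (`process_placement` is a hypothetical helper name):

```python
import os

def process_placement() -> dict:
    """Read the identity torchrun assigns to this process (defaults suit a single-process run)."""
    rank = int(os.environ.get("RANK", 0))              # global index across all nodes
    local_rank = int(os.environ.get("LOCAL_RANK", 0))  # index on this machine, used as GPU id
    world_size = int(os.environ.get("WORLD_SIZE", 1))  # total number of processes
    # The "" key maps the whole model to one device, so each process loads a full copy.
    return {"rank": rank, "local_rank": local_rank, "world_size": world_size,
            "device_map": {"": local_rank}}

print(process_placement())
```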

Putting It All Together

Here's the complete script that trains Llama 3.1 8B with FSDP2 + QLoRA on 2 GPUs:

import os
import torch

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,"\
    "roundup_power2_divisions:[32:256,64:128,256:64,>:32]"

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from bitsandbytes.nn import Params4bit
from torch.distributed.fsdp import fully_shard

import torch.distributed as dist

# torchrun sets LOCAL_RANK; bind this process to its GPU before creating the NCCL group
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")
print(f"Rank {dist.get_rank()}: Loading model on GPU {local_rank}")


max_seq_length = 2048
torch.set_default_dtype(torch.float16)
model_name = "unsloth/meta-Llama-3.1-8B-Instruct-bnb-4bit"
dtype = torch.float16

bnb_config = BitsAndBytesConfig(
    load_in_4bit              = True,
    bnb_4bit_use_double_quant = True,
    bnb_4bit_quant_type       = "nf4",
    bnb_4bit_compute_dtype    = dtype,
    bnb_4bit_quant_storage    = dtype,  # Key for FSDP compatibility!
)

# Load the full quantized model onto this process's GPU
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation = "sdpa",
    quantization_config = bnb_config,
    device_map = {"": local_rank},
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"

lora_config = LoraConfig(
    r = 64,
    lora_alpha = 128,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_dropout = 0,
    bias = "none",
    task_type = TaskType.CAUSAL_LM,
)

# Apply LoRA
model = get_peft_model(model, lora_config)
ignored_params = {p for p in model.parameters() if isinstance(p, Params4bit)}

# Set requires_grad
with torch.no_grad():
    for name, param in model.named_parameters():
        if ".lora_A." in name or ".lora_B." in name:
            param.requires_grad_(True)
        else:
            param.requires_grad_(False)

model.gradient_checkpointing_enable()
model.enable_input_require_grads()

# Load dataset
url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
dataset = load_dataset("json", data_files = {"train" : url}, split = "train[:10%]")

for layer in model.base_model.model.model.layers:
    fully_shard(layer, ignored_params=ignored_params)

class FSDP2Trainer(SFTTrainer):
    def _wrap_model(self, model, training=True, dataloader=None):
        # Skip wrapping - we've already applied FSDP2 manually
        return model

# Use SFTTrainer - it integrates with accelerate
trainer = FSDP2Trainer(
    model = model,
    train_dataset = dataset,
    processing_class = tokenizer,
    args = SFTConfig(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 1,
        max_steps = 10,
        logging_steps = 1,
        output_dir = "outputs",
        seed = 3407,
        max_length = max_seq_length,
        fp16 = True,
        report_to = "none",
        dataset_num_proc = 4,
        fsdp="",
        ddp_backend=None,
        save_strategy="no",
        ddp_find_unused_parameters = False,
    ),
)

trainer.accelerator.prepare_model = lambda model, **kwargs: model
trainer.train()

print(f"Rank {int(os.environ.get('RANK', 0))}: Using GPU {torch.cuda.current_device()}")
print(f"Peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

Let's understand the key decisions:

  1. bnb_4bit_quant_storage=torch.float16

This is subtle but critical. By default, bitsandbytes stores quantized weights as uint8. But FSDP2 expects floating point dtypes. This flag tells bitsandbytes to store the packed 4-bit values in float16 containers.

  2. device_map={"": local_rank}

Each process must load the model on its assigned GPU. LOCAL_RANK is set by torchrun - it's 0 for the first process, 1 for the second.

  3. ignored_params collected before fully_shard()

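We collect all Params4bit parameters upfront, then pass the same set to every layer. FSDP2 leaves these alone: it doesn't shard them, move them, or reduce their gradients. Since every rank loaded the full model, each GPU keeps its own complete copy, which is effectively replication.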

  4. The monkey-patch: trainer.accelerator.prepare_model = lambda model, **kwargs: model

The Trainer's accelerator tries to wrap the model for distributed training. But we've already applied FSDP2. This lambda tells it "the model is already prepared, return it unchanged."
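The `bnb_4bit_quant_storage` point is easiest to see as arithmetic: the packed bytes are identical either way, and only the element type of the container changes. This sketch computes the container length for a 4096×4096 weight under that assumption; it mirrors the idea, not bitsandbytes' exact code, and `packed_container_len` is a hypothetical helper.

```python
def packed_container_len(numel: int, storage_itemsize: int) -> int:
    """Number of storage elements needed to hold numel 4-bit codes.

    Two codes fit in one byte; the bytes are then viewed as elements of the
    storage dtype (1 byte for uint8, 2 bytes for float16).
    """
    packed_bytes = (numel + 1) // 2                  # round up for an odd count
    return -(-packed_bytes // storage_itemsize)      # ceiling division

numel = 4096 * 4096
as_uint8 = packed_container_len(numel, storage_itemsize=1)
as_fp16 = packed_container_len(numel, storage_itemsize=2)
print(as_uint8, as_fp16)              # 8388608 4194304
print(as_uint8 * 1 == as_fp16 * 2)    # True: same 8 MiB either way
```

The float16 "values" in such a container are just reinterpreted bit patterns; bitsandbytes still decodes them as 4-bit codes. The flag changes only the dtype that wrapping code sees.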

Launch it:

torchrun --nproc_per_node=2 train.py

You'll see both GPUs working:

{'loss': 2.2167, 'learning_rate': 2e-05, 'epoch': 0.0}
{'loss': 2.0854, 'learning_rate': 1.78e-05, 'epoch': 0.0}
{'loss': 1.8448, 'learning_rate': 1.56e-05, 'epoch': 0.0}
...
{'loss': 1.4309, 'learning_rate': 2.22e-06, 'epoch': 0.01}

Memory breakdown per GPU:

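  1. Base model (4-bit, replicated): ~4GB
  2. LoRA adapters (sharded): ~168MB per GPU in fp16 (double that if kept in fp32)
  3. Gradients (sharded): same size as the adapter shard, ~168MB per GPU in fp16
  4. Optimizer states (sharded): AdamW keeps two states per parameter, ~336MB per GPU in fp16 (~671MB in fp32)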
  5. Activations (gradient checkpointing): ~2-3GB

Total: ~7-8GB per GPU - well within 16GB T4 limits.
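The breakdown can be turned into a small estimator. This is a back-of-the-envelope sketch, not a profiler: the ~4GB base and the activation figure are the rough numbers from the list above, the adapter count comes from r = 64 on all seven projections (167,772,160 parameters), and it assumes AdamW, whose two state tensors match the parameter dtype. `per_gpu_memory_gb` is a hypothetical helper.

```python
def per_gpu_memory_gb(
    lora_params: int,
    world_size: int,
    param_bytes: int,             # 2 for fp16 adapters, 4 for fp32
    base_gb: float = 4.0,         # replicated 4-bit base weights (rough figure)
    activations_gb: float = 2.5,  # midpoint of the ~2-3GB checkpointed-activation estimate
) -> dict[str, float]:
    shard = lora_params / world_size                  # LoRA state is sharded across GPUs
    parts = {
        "base": base_gb,
        "adapters": shard * param_bytes / 1e9,
        "gradients": shard * param_bytes / 1e9,       # one gradient per trainable parameter
        "optimizer": shard * 2 * param_bytes / 1e9,   # AdamW exp_avg + exp_avg_sq
        "activations": activations_gb,
    }
    parts["total"] = sum(parts.values())
    return parts

for label, nbytes in [("fp16", 2), ("fp32", 4)]:
    est = per_gpu_memory_gb(167_772_160, world_size=2, param_bytes=nbytes)
    print(label, round(est["total"], 2), "GB")
# fp16 7.17 GB
# fp32 7.84 GB
```

Both cases land in the ~7-8GB range. The estimate ignores CUDA context overhead and allocator fragmentation, so treat it as a floor.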

What features did we enable?

  • Mixed precision - fp16 training with 4-bit base
  • Gradient checkpointing - recompute activations to save memory
  • Parameter sharding - LoRA adapters distributed across GPUs
  • Selective replication - quantized weights replicated (frozen anyway)

CPU offloading isn't enabled here. Why? The quantized base weights require CUDA kernels; they can't run on CPU.

The Graveyard of Failed Attempts

The solution above looks clean. It wasn't. Here's the path that got us there - every dead end taught us something about how these systems actually work.

Attempt #1: Just use Accelerate's FSDP2 plugin

The obvious first try:

from accelerate import Accelerator, FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(fsdp_version=2)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

Result: ValueError: Cannot flatten integer dtype tensors

Accelerate's plugin doesn't know about Params4bit. It passes every parameter to FSDP2's sharding machinery. When FSDP2 tries to flatten the quantized weight tensor into a DTensor, it chokes on the non-standard type.

Attempt #2: Apply fully_shard() to every layer

Maybe if we wrap each layer manually?

for layer in model.base_model.model.model.layers:
    fully_shard(layer)

Result: Same error. FSDP2 still tries to shard every parameter in the layer, Params4bit included.

fully_shard(layer)
    ↓
Inspects all parameters in layer
    ↓
Finds: q_proj.weight (Params4bit), lora_A.weight (Tensor), lora_B.weight (Tensor)
    ↓
Tries to shard ALL of them
    ↓
Params4bit can't become DTensor

Attempt #3: Override _wrap_model() in Trainer

We discovered FSDP2's ignored_params and applied it. Training started! Then:

ValueError: Your model contains `DTensor` parameters, 
which is incompatible with DDP.

The traceback pointed to accelerator.prepare():

# Inside transformers/trainer.py line 2480
model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)

We overrode _wrap_model(), but the Trainer has another place where it wraps the model. Accelerate's prepare() sees DTensors (from our FSDP2 sharding) and complains.

Fix: monkey-patch prepare_model too:

trainer.accelerator.prepare_model = lambda model, **kwargs: model
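Why does assigning a lambda to `trainer.accelerator.prepare_model` work? Methods are looked up on the class, but an attribute set on the instance shadows them, so only that one accelerator object changes. The sketch uses a hypothetical `StubAccelerator` in place of Accelerate's class (its error message is copied from the one above) so it runs without any dependencies.

```python
class StubAccelerator:
    """Hypothetical stand-in for accelerate.Accelerator; only the relevant method."""

    def prepare_model(self, model, device_placement=None, evaluation_mode=False):
        raise ValueError("Your model contains `DTensor` parameters, which is incompatible with DDP.")

patched = StubAccelerator()
# The instance attribute shadows the class method: Python checks the instance __dict__
# before falling back to the (non-data) function descriptor on the class.
patched.prepare_model = lambda model, **kwargs: model

model = object()
print(patched.prepare_model(model, device_placement=True) is model)   # True
print("prepare_model" in vars(StubAccelerator()))                      # False: other instances untouched
```

The lambda accepts extra keyword arguments only; a caller passing additional positional arguments would get a `TypeError`, so `lambda model, *args, **kwargs: model` is the more defensive spelling.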

Attempt #4: Enable CPU offloading

The challenge requires showcasing offloading. Let's add it:

from torch.distributed.fsdp import CPUOffloadPolicy

fully_shard(layer, 
    ignored_params=ignored_params,
    offload_policy=CPUOffloadPolicy(pin_memory=True))

Here's what happens during forward pass with offloading:

  1. LoRA adapter A (on CPU) → move to GPU
  2. Base weight (on GPU, quantized) → dequantize
  3. Compute: output = Base @ input + (lora_alpha / r) * B @ (A @ input)
  4. LoRA adapter A → move back to CPU

Step 2 is the problem. The quantized base weight can't move to CPU - bitsandbytes dequantization requires CUDA kernels. So you have LoRA on CPU trying to interact with a base that must stay on GPU.

We tried pin_memory=False. Same error. The fundamental issue: CPU offloading assumes homogeneous parameters. Mixed precision with quantized weights breaks that assumption.

What We Learned

  1. FSDP2 is parameter-centric; FSDP1 is module-centric. This seemingly small API difference has huge implications for non-standard parameter types.

  2. Params4bit isn't just a tensor. It's a tensor + metadata. Any system that manipulates tensors directly (FSDP2's DTensor conversion) will break.

  3. ignored_params is the escape hatch. When you have parameters that can't be sharded, tell FSDP2 to leave them alone.

  4. Trainers have opinions. HuggingFace Trainer assumes it controls distribution. When you do manual FSDP, you must intercept multiple wrapping points.

  5. Not all features compose. CPU offloading + 4-bit quantization = conflict. Know when to stop stacking optimizations.

The best debugging happens when you understand what each layer of abstraction is doing. We had to read:

  1. FSDP2's fully_shard() source
  2. FSDPParam._init_sharded_param() where DTensor conversion happens
  3. Accelerate's prepare_model() where the second wrapping occurs

Read the source. Print intermediate values. Build mental models. That's how you debug systems that weren't designed to work together.

What We Built

We trained Llama 3.1 8B on 2 consumer GPUs using FSDP2 + QLoRA. Things that "shouldn't work together" now do:

Feature                              Status
FSDP2 parameter sharding             Done
4-bit QLoRA quantization             Done
HuggingFace Trainer compatibility    Done
Mixed precision (fp16)               Done
Gradient checkpointing               Done
Multi-GPU training                   Done

Memory per GPU: ~7-8GB (down from 32GB+ for full precision)

The key insight was small: don't shard what can't be sharded. Use ignored_params to let quantized weights replicate while LoRA adapters distribute.

The Bigger Picture

This exercise reveals something important about the current state of ML infrastructure.

We have extraordinary tools:

  1. FSDP2 for elegant distributed training
  2. QLoRA for memory-efficient fine-tuning
  3. HuggingFace for accessible model loading
  4. bitsandbytes for 4-bit quantization

But these tools were built by different teams, at different times, with different assumptions. They don't always compose cleanly.

PyTorch team: "Parameters are tensors"
bitsandbytes: "Parameters are tensors + quantization metadata"
FSDP2: "I'll shard any tensor"
bitsandbytes: "Not mine you won't"

The gap between "works in isolation" and "works together" is where real engineering happens. You read source code. You understand assumptions. You find the escape hatches.

Next Steps

If you're continuing this work, here are paths worth exploring:

  1. torchao quantization instead of bitsandbytes

PyTorch's native quantization library (torchao) was designed with composability in mind. It might integrate more cleanly with FSDP2:

from torchao.quantization import quantize_, int4_weight_only
quantize_(model, int4_weight_only())

Trade-off: Potentially slower kernels than bitsandbytes' hand-tuned CUDA.

  2. Proper checkpoint saving

Our solution skipped saving because DTensors can't serialize through PEFT's save_pretrained(). A proper solution would:

  • Gather the sharded parameters before saving, or
  • Use FSDP2's native state dict helpers

  3. Pipeline parallelism with zero-bubble scheduling

The original challenge mentioned this as an alternative path. Instead of sharding parameters, you shard the model vertically - different layers on different GPUs:

from torch.distributed.pipelining import ScheduleInterleavedZeroBubble

This avoids the Params4bit sharding problem entirely by keeping whole layers intact.

Final Thoughts

The most valuable skill in ML engineering isn't knowing every API. It's knowing how to figure things out when APIs don't work as documented.

This project required:

  1. Reading PyTorch FSDP2 source code
  2. Reading Accelerate's model preparation logic
  3. Building mental models of parameter sharding
  4. Testing hypotheses systematically

None of that was in any tutorial. It came from treating errors as clues, not roadblocks.

You have everything you need to solve problems like this. Start with working code. Break it intentionally. Understand why it broke. Fix it. Repeat.

The next "impossible" integration is waiting for someone to find the ignored_params equivalent. Maybe that's you.

Acknowledgements

This work was inspired by Unsloth's Challenge B: make QLoRA work with FSDP2. What started as a competition challenge became a deep learning journey through distributed systems, quantization internals, and the messy reality of composing ML infrastructure.

The Unsloth team's challenge wasn't just about getting code to run. It forced us to understand why things break. These aren't questions you encounter in tutorials. They're the questions that build real understanding.