The Challenge

A) Convert nf4 to Triton. [Difficulty: Hard] [Max points: 14]

  1. Goal: Convert an nf4 quantized tensor into fp16 or bf16 in a single Triton kernel. The double dequant of the absmax and weight forming must be done in 1 Triton kernel. Must work on Tesla T4.
  2. Must be faster than Unsloth's fast_dequantize by 1.15x or more, and not use large intermediate memory buffers.
  3. Must not use torch.compile, but can use trace.enabled to help on writing Triton kernels.
  4. Good material: Unsloth fast_dequantize function, also bitsandbytes dequantize_blockwise
  5. Use test_dequantize_function to test your implementation.
  6. No CUDA allowed. Custom CUDA inside of the Triton is allowed.
  7. Watch Tim's videos on YouTube: 8-bit Optimizers
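import time
import torch
import torch.nn as nn
from inspect import currentframe as _C, getframeinfo
from transformers import set_seed
from transformers.activations import ACT2FN
from bitsandbytes.nn import Linear4bit
from unsloth.kernels.utils import fast_dequantize
from peft.utils.integrations import dequantize_module_weight as peft_dequantize

# Minimal versions of the helpers the test uses (the original notebook defines its own)
_F = lambda c: getframeinfo(c).lineno  # line number of the given frame, for error messages

def assert_same(x, y, line, dtype):
    assert x.dtype == dtype
    try:
        torch.testing.assert_close(x, y, check_stride = True)
    except AssertionError as error:
        raise RuntimeError(f"Failed allclose at line [{line}]\n{error}")

def unsloth_dequantize(weight):
    return fast_dequantize(weight.weight, weight.weight.quant_state)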

def bnb_Linear4bit(hd, m, dtype = torch.float16):
    return Linear4bit(
        hd, m, bias = None,
        compute_dtype       = dtype,
        compress_statistics = True,
        quant_type          = "nf4",
    )

# [NEW] as at 18th Feb 2025
def assert_correct_bnb(weight, dtype):
    assert(weight.weight.dtype == torch.uint8)
    assert(weight.weight.quant_state.dtype == dtype)
    assert(weight.weight.quant_state.absmax.dtype == torch.uint8)
    assert(weight.weight.quant_state.code.dtype == torch.float32)
    assert(weight.weight.quant_state.offset.dtype == torch.float32)
    assert(weight.weight.quant_state.blocksize == 64)
    assert(weight.weight.quant_state.state2.absmax.dtype == torch.float32)
    assert(weight.weight.quant_state.state2.code.dtype == torch.float32)
    assert(weight.weight.quant_state.state2.blocksize == 256)

class MLP(nn.Module):
    def __init__(self, hd = 4096, m = 14336, dtype = torch.float16):
        super().__init__()
        self.gate_proj = bnb_Linear4bit(hd, m, dtype = dtype).to("cuda")
        self.up_proj   = bnb_Linear4bit(hd, m, dtype = dtype).to("cuda")
        self.down_proj = bnb_Linear4bit(m, hd, dtype = dtype).to("cuda")
        # [NEW] as at 18th Feb 2025
        self.gate_proj.weight.quant_state.dtype = dtype
        self.up_proj  .weight.quant_state.dtype = dtype
        self.down_proj.weight.quant_state.dtype = dtype
        self.act_fn = ACT2FN["silu"]
    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

def mlp_forward(X, mlp, fx):
    up   = X @ fx(mlp.  up_proj).t()
    gate = X @ fx(mlp.gate_proj).t()
    h = mlp.act_fn(gate) * up
    down = h @ fx(mlp.down_proj).t()
    return down

def mlp_dequantize(X, mlp, fx):
    a = fx(mlp.  up_proj).t(); torch.cuda.synchronize()
    b = fx(mlp.gate_proj).t(); torch.cuda.synchronize()
    c = fx(mlp.down_proj).t(); torch.cuda.synchronize()
    return a, b, c

def test_dequantize(dequantize_fx):
    elapsed = 0
    options = [
        (2, 3333, 2048,  8192, 3407, torch.float16),
        (5,  777, 1024,  4096, 3409, torch.bfloat16),
        (3, 2048, 4096, 14336, 3408, torch.bfloat16),
    ]
    for (bsz, qlen, hd, m, seed, dt) in options:
        set_seed(seed)
        torch.set_default_dtype(torch.float32)
        mlp = MLP(hd = hd, m = m, dtype = dt)
        X = torch.randn((bsz, qlen, hd), device = "cuda", dtype = dt)
        torch.cuda.synchronize()

        # Warmup
        for _ in range(2):
            assert_same( mlp_forward(X, mlp, dequantize_fx), mlp(X), _F(_C()), dt)
            # [NEW] as at 18th Feb 2025
            assert_correct_bnb(mlp.  up_proj, dt)
            assert_correct_bnb(mlp.gate_proj, dt)
            assert_correct_bnb(mlp.down_proj, dt)
            a, b, c = mlp_dequantize(X, mlp, dequantize_fx)
            A, B, C = mlp_dequantize(X, mlp, unsloth_dequantize)
            assert_same(a, A, _F(_C()), dt)
            assert_same(b, B, _F(_C()), dt)
            assert_same(c, C, _F(_C()), dt)

        # Benchmarking
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(1000): mlp_dequantize(X, mlp, dequantize_fx)
        elapsed += time.time() - start
    return elapsed

For example, we can test our implementation via:

from unsloth.kernels.utils import fast_dequantize
def unsloth_dequantize(weight):
    return fast_dequantize(weight.weight, weight.weight.quant_state)
test_dequantize(unsloth_dequantize)
> 5.320246934890747

The elapsed time for Unsloth's implementation over 1000 trials is 5.38 seconds or so.

PEFT also has one, which should be mostly identical to Unsloth's version, albeit slightly slower.

from peft.utils.integrations import dequantize_module_weight as peft_dequantize
test_dequantize(peft_dequantize)

> 5.588372230529785

Write your Triton kernel below, and test it:

from triton import jit
import triton
import triton.language as tl

@triton.jit
def _your_dequantize_nf4_kernel():
    ### TRITON CODE GOES HERE
    return

def _your_dequantize_nf4(weight, quant_state):
    ### SETUP TRITON LAUNCH HERE
    return None

def your_dequantize_nf4(weight):
    return _your_dequantize_nf4(weight.weight.data, weight.weight.quant_state)

Test it below:

test_dequantize(your_dequantize_nf4)

Calculate the speedup (hopefully 1.15x faster or more):

test_dequantize(unsloth_dequantize) / test_dequantize(your_dequantize_nf4)

Marking Criteria for A) Max points = 14

if attempted_A:
    A_score = 0
    if single_triton_kernel: A_score += 3
    speedup = old_time / new_time
    if speedup <= 1.00: A_score -= 3
    if speedup >= 1.05: A_score += 1
    if speedup >= 1.10: A_score += 2
    if speedup >= 1.15: A_score += 2
    if kernel_works_in_torch_compile: A_score += 1
    else: A_score -= 1
    if custom_asm_works: A_score += 3
    if uses_cache_eviction: A_score += 1
    if tested_in_f16_and_bf16: A_score += 1
    else: A_score -= 1
    final_score += A_score
else:
    final_score += 0

Demystifying NF4 Dequantization

Understanding NF4 Dequantization with a Single Number

Let's say you want to store the weight value -0.347 in your neural network, but you only have 4 bits to work with.

Step 1: The Codebook

NF4 uses a lookup table of 16 carefully chosen values (since 4 bits = 2^4 = 16 possibilities):

Index:  0     1      2      3      4      5      6      7
Value: -1.0  -0.70  -0.52  -0.39  -0.28  -0.18  -0.09   0.0

Index:  8     9     10     11     12     13     14     15
Value:  0.08  0.16   0.25   0.34   0.44   0.56   0.72   1.0

These values are normalized between -1 and 1. To store -0.347, we:

  1. Find the closest normalized value: index 3 = -0.39
  2. Store just the number 3 (4 bits!)

Step 2: The Scale Factor (absmax)

But wait, what if our actual weight was -3.47, not -0.347? We need a scale factor.

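Every 64 weights share one absmax (absolute maximum) value, and each weight is divided by it before the codebook lookup. Let's say absmax = 10.0: then -3.47 / 10.0 = -0.347, which again maps to index 3.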

To dequantize:

dequantized_weight = codebook[3] × absmax
                   = -0.39 × 10.0
                   = -3.9

Close to our original -3.47! The remaining error comes from rounding -0.347 to the nearest codebook entry.
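Steps 1 and 2 together amount to block-wise absmax quantization. Here is a minimal pure-Python sketch of the round trip, assuming a block size of 64 and the rounded codebook values from the table above (bitsandbytes stores more precise float32 values). The helper names are hypothetical, not bitsandbytes APIs.

```python
# Rounded NF4 values copied from the table above (illustrative only)
NF4_CODE = [-1.0, -0.70, -0.52, -0.39, -0.28, -0.18, -0.09, 0.0,
            0.08, 0.16, 0.25, 0.34, 0.44, 0.56, 0.72, 1.0]
BLOCKSIZE = 64

def nearest_index(x):
    """Index of the codebook entry closest to a normalized value x."""
    return min(range(len(NF4_CODE)), key=lambda i: abs(NF4_CODE[i] - x))

def nf4_quantize_blocks(weights, blocksize=BLOCKSIZE):
    """Return one 4-bit index per weight and one absmax per block."""
    indices, absmaxes = [], []
    for start in range(0, len(weights), blocksize):
        block = weights[start:start + blocksize]
        absmax = max(abs(w) for w in block) or 1.0  # avoid dividing by zero for an all-zero block
        absmaxes.append(absmax)
        indices.extend(nearest_index(w / absmax) for w in block)
    return indices, absmaxes

def nf4_dequantize_blocks(indices, absmaxes, blocksize=BLOCKSIZE):
    """Weight i is rescaled by the absmax of block i // blocksize."""
    return [NF4_CODE[q] * absmaxes[i // blocksize] for i, q in enumerate(indices)]

weights = [-3.47, 10.0] + [0.0] * 62  # one block of 64 whose absmax is 10.0
q, scales = nf4_quantize_blocks(weights)
restored = nf4_dequantize_blocks(q, scales)
print(q[0], scales[0], round(restored[0], 4))  # 3 10.0 -3.9
```

Only the 4-bit indices and one absmax per block need to be stored; everything else is recomputed at dequantization time.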

Step 3: Double Quantization (the twist!)

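Here's where it gets clever. Storing a float32 absmax for every 64 weights still costs an extra 0.5 bits per weight. So bitsandbytes can quantize the absmax values too (this is what compress_statistics = True turns on in the test setup above)!

The absmax = 10.0 itself gets:

  1. Shifted by an offset and quantized to uint8 (256 possible values)
  2. Stored alongside a float32 scale factor (state2.absmax) that it shares with the other absmax values in its block of 256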

Step 3 Extended: Double Dequantization with Real Numbers

Let's trace through dequantizing one weight, using illustrative values for each input the kernel reads:

Given:

  • Quantized weight (4-bit): 5
  • Quantized absmax (uint8): 200
  • state2.absmax (float32): 2.0
  • state2.code (lookup table for absmax)
  • offset: 0.0218

First, dequantize the absmax:

# Look up the normalized absmax value (0.75 is an illustrative entry of state2.code)
state2.code[200] = 0.75

# Scale it and add the offset
dequantized_absmax = 0.75 × 2.0 + 0.0218
                   = 1.5 + 0.0218
                   = 1.5218

Then, dequantize the weight:

# Look up the NF4 value (index 5 in the table above)
code[5] = -0.18

# Scale by the dequantized absmax
final_weight = -0.18 × 1.5218
             ≈ -0.274
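The same two lookups in Python, as a sketch using the illustrative inputs listed above. Assumptions: `state2_code` is a placeholder list whose only meaningful entry is index 200, and the NF4 values are the rounded ones from the table, so code[5] is -0.18.

```python
# Rounded NF4 values from the table above (illustrative)
NF4_CODE = [-1.0, -0.70, -0.52, -0.39, -0.28, -0.18, -0.09, 0.0,
            0.08, 0.16, 0.25, 0.34, 0.44, 0.56, 0.72, 1.0]

# Placeholder for state2.code: only entry 200 matters in this example
state2_code = [0.0] * 256
state2_code[200] = 0.75

def dequantize_absmax(q_absmax, state2_absmax, offset):
    """First dequantization: uint8 absmax -> float scale factor."""
    return state2_code[q_absmax] * state2_absmax + offset

def dequantize_weight(q_weight, absmax):
    """Second dequantization: 4-bit index -> weight."""
    return NF4_CODE[q_weight] * absmax

absmax = dequantize_absmax(200, state2_absmax=2.0, offset=0.0218)
weight = dequantize_weight(5, absmax)
print(round(absmax, 4), round(weight, 4))  # 1.5218 -0.2739
```

The order matters: the absmax has to be dequantized first, because its float value is the scale for the weight lookup.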

The memory savings:

  • Without double quantization: 64 weights + 1 float32 absmax = 32 bytes + 4 bytes
  • With double quantization: 64 weights + 1 uint8 absmax = 32 bytes + 1 byte (plus 4/256 of a byte for the shared state2.absmax)
  • Savings: about 75% less memory for scale factors (0.5 bits per weight down to about 0.127)
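The same bookkeeping as a tiny calculator, assuming blocksize 64 and a state2 blocksize of 256 (the values assert_correct_bnb checks), and ignoring the single offset and the code tables, which are shared by the whole tensor:

```python
def scale_bits_per_weight(nested, blocksize=64, blocksize2=256):
    """Bits of scale-factor storage per weight (excludes the 4 bits of the weight itself)."""
    if not nested:
        return 32 / blocksize  # one float32 absmax per block
    # one uint8 absmax per block, plus one float32 state2.absmax per 256 blocks
    return 8 / blocksize + 32 / (blocksize * blocksize2)

plain = scale_bits_per_weight(nested=False)
double = scale_bits_per_weight(nested=True)
print(plain, round(double, 4), f"{1 - double / plain:.1%}")  # 0.5 0.127 74.6%
```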

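The offset mystery: You might wonder why there's an offset added. In bitsandbytes, the offset is the mean of all absmax values. It is subtracted before the absmax values are quantized and added back after they are dequantized, so the 8-bit code only has to represent deviations around zero. It's like centering your data before quantization.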
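To see why the offset helps, here is a simplified double-quantization round trip in pure Python. Assumptions: an evenly spaced signed 8-bit table stands in for bitsandbytes' non-linear dynamic map, the absmax values are made up, and the helper names are hypothetical. Only the structure (subtract the mean, quantize blocks of 256 with a per-block scale, add the mean back) follows bitsandbytes.

```python
# Stand-in for state2.code: 256 evenly spaced values in [-1, 1] (assumption; bitsandbytes' map is non-linear)
STATE2_CODE = [-1.0 + 2.0 * i / 255 for i in range(256)]

def nearest(code, x):
    return min(range(len(code)), key=lambda i: abs(code[i] - x))

def quantize_absmax(absmax, use_offset=True, blocksize2=256):
    """Quantize absmax values to uint8 indices, one float scale per block of 256."""
    offset = sum(absmax) / len(absmax) if use_offset else 0.0  # mean of all absmax values
    centered = [a - offset for a in absmax]
    q, scales = [], []
    for start in range(0, len(centered), blocksize2):
        block = centered[start:start + blocksize2]
        scale = max(abs(c) for c in block) or 1.0  # plays the role of state2.absmax
        scales.append(scale)
        q.extend(nearest(STATE2_CODE, c / scale) for c in block)
    return q, scales, offset

def dequantize_absmax(q, scales, offset, blocksize2=256):
    return [STATE2_CODE[v] * scales[i // blocksize2] + offset for i, v in enumerate(q)]

def max_error(absmax, use_offset):
    q, scales, offset = quantize_absmax(absmax, use_offset)
    restored = dequantize_absmax(q, scales, offset)
    return max(abs(a - r) for a, r in zip(absmax, restored))

# Made-up absmax values: all positive and clustered around 0.053
absmax = [0.05 + 0.001 * (i % 7) for i in range(512)]
print(max_error(absmax, use_offset=True) < max_error(absmax, use_offset=False))  # True
```

Without the offset, these all-positive values only ever land in a narrow slice at the top of a signed table, so the rounding error is noticeably larger.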

Writing My First Triton Kernel

The Challenge

Armed with an understanding of NF4 dequantization, I set out to write a Triton kernel that could:

  1. Dequantize millions of weights in parallel
  2. Beat Unsloth's implementation by 1.15x
  3. Do it all in a single kernel (no intermediate buffers)

Easy, right?

Narrator: It was not easy.

Attempt #1: The Naive Approach

My first thought: "Triton is like NumPy but on GPU. I'll just translate the logic!"

@triton.jit
def my_first_kernel(weight_ptr, output_ptr, ...):
    # Load everything, dequantize, store
    # How hard could it be?

First wall: What's a program_id?

In CPU programming, you write code that runs once. In GPU programming, you write code that runs thousands of times in parallel.

Think of it like this:

  • CPU: "Process all 1 million elements"
  • GPU: "Hey, 4000 workers! Each of you process 256 elements"

tl.program_id(0) tells each worker: "Which chunk am I responsible for?"

pid = tl.program_id(0) # Am I worker 0, 1, 2, ... ?

If pid=3 and BLOCK_SIZE=256, this worker handles elements 768-1023.
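The pid-to-elements mapping can be simulated on the CPU. This plain-Python sketch (hypothetical helper names) mirrors what pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) computes per program, and sizes the grid with a ceiling division, which is what triton.cdiv does on the host:

```python
def grid_size(n_elements, block_size):
    """Number of programs to launch: ceil(n_elements / block_size)."""
    return (n_elements + block_size - 1) // block_size

def program_range(pid, block_size, n_elements):
    """First and last element index handled by one program (the last program may be partial)."""
    start = pid * block_size
    return start, min(start + block_size, n_elements) - 1

N, BLOCK_SIZE = 1_000_000, 256
print(grid_size(N, BLOCK_SIZE))            # 3907
print(program_range(3, BLOCK_SIZE, N))     # (768, 1023)
print(program_range(3906, BLOCK_SIZE, N))  # (999936, 999999)
```

The partial last program is why kernels usually mask their loads and stores with something like offsets < N.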

Second wall: Two weights per byte

I assumed I could just do:

weights = tl.load(weight_ptr + pid * BLOCK_SIZE)

But NF4 weights are packed two 4-bit values per byte! So:

  • 256 weights = 128 bytes
  • I need to load 128 bytes, then unpack them

weight_start = (pid * BLOCK_SIZE) // 2  # Divide by 2 for packing
weights = tl.load(weight_ptr + weight_start + tl.arange(0, BLOCK_SIZE // 2))

Third wall: The absmax mapping

Each 64 weights share one absmax. If I'm processing 256 weights, I need 4 absmax values.

The index math:

# Which absmax does my first weight use?
absmax_start_idx = (pid * BLOCK_SIZE) // 64

But here's the kicker: those absmax values are also quantized! Each needs to be dequantized using state2.

I realized the kernel needs to:

  1. Figure out which state2.absmax block I'm in
  2. Load that one float32 value
  3. Use it to dequantize my 4 absmax values on-the-fly
  4. Use those to dequantize my 256 weights

All without creating intermediate buffers.

Version 1: "It compiles! Ship it!"

@triton.jit
def nf4_dequant_v1(weight_ptr, absmax_ptr, code_ptr, output_ptr, ...):
    pid = tl.program_id(0)

    # Load weights
    weights = tl.load(weight_ptr + pid * 256)

    # Unpack 4-bit values
    high = (weights >> 4) & 0xF
    low = weights & 0xF

    # Lookup in codebook
    high_float = tl.load(code_ptr + high)
    low_float = tl.load(code_ptr + low)

    # Store (somehow?)
    tl.store(output_ptr + ???, ???)

Result: Compilation error. I had forgotten that the weights are packed, so I needed to load half as many bytes!

Lesson learned: GPU memory layout matters. A lot.

Version 2: "Fixed the packing!"

weight_start = (pid * BLOCK_SIZE) // 2  # Aha! Divide by 2
weights = tl.load(weight_ptr + weight_start + tl.arange(0, BLOCK_SIZE // 2))

high = (weights >> 4) & 0xF
low = weights & 0xF

Result: Compiled! But got 100% mismatch with Unsloth.

The bug: I was storing high/low weights in the wrong order. The interleaving was backwards!

# Wrong:
tl.store(output_ptr + offsets, low_weights)
tl.store(output_ptr + offsets + 1, high_weights)

# Right:
tl.store(output_ptr + offsets, high_weights)
tl.store(output_ptr + offsets + 1, low_weights)

Lesson learned: Bit ordering is not intuitive. Test with small examples first!
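In that spirit, here is a small example that would have caught the ordering bug: pack a few known indices the way bitsandbytes does (first value in the high nibble), then unpack them both ways. Plain Python, hypothetical helper names:

```python
def pack_nf4(indices):
    """Pack 4-bit indices two per byte, first index in the high nibble (bitsandbytes' order)."""
    return bytes((indices[i] << 4) | indices[i + 1] for i in range(0, len(indices), 2))

def unpack(packed, high_first=True):
    """Unpack bytes into 4-bit indices, in either nibble order."""
    out = []
    for byte in packed:
        high, low = byte >> 4, byte & 0xF
        out.extend((high, low) if high_first else (low, high))
    return out

indices = [3, 11, 5, 14, 0, 15]  # small, easy to eyeball
packed = pack_nf4(indices)
print(packed.hex())                      # 3b5e0f
print(unpack(packed) == indices)         # True
print(unpack(packed, high_first=False))  # [11, 3, 14, 5, 15, 0]
```

The wrong order swaps every adjacent pair, which fits the near-total mismatch described above.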

Version 3: "The absmax mystery"

Now I needed to handle the absmax values. Each 64 weights share one absmax:

absmax_start_idx = (pid * BLOCK_SIZE) // 64
# Load 4 absmax values (256 weights / 64 = 4)
absmax_values = tl.load(absmax_ptr + absmax_start_idx + tl.arange(0, 4))

But these absmax values are stored as uint8, not float. Result: Values were 100x too small, because I was using the uint8 values directly instead of dequantizing them.

Lesson learned: Read the spec carefully. Double quantization means double work.

Version 4: "The state2 revelation"

I needed to dequantize the absmax values:

# First attempt - wrong!
absmax_float = absmax_uint8 / 255.0 * state2_absmax

Result: Still wrong. Values were negative (absmax should always be positive!)

The bug: I was missing the offset and using the wrong formula. The correct approach:

absmax_float = tl.load(state2_code_ptr + absmax_uint8) * state2_absmax + offset

That state2.code lookup was the key. And the offset adds back the mean that was subtracted from the absmax values before they were quantized.

Lesson learned: Quantization schemes have subtle details. Every parameter matters.

Version 5: "So close! Only 7 errors..."

Mismatched elements: 7 / 13,651,968 (0.0%)
Greatest absolute difference: 1.14e-05 (just over 1e-05 tolerance!)

The bug: I was casting everything to float32 for "better precision":

high_weight_float = tl.load(code_ptr + high_weight).to(tl.float32)

But Unsloth doesn't do this! Removing the casts:

high_weight_float = tl.load(code_ptr + high_weight)  # Keep native dtype

Result: Down to 1 error out of 58 million elements!

Version 6: "The optimization phase"

With correctness achieved, time to optimize. Added tl.gather for codebook lookups:

# Load codebook once into registers
codebook = tl.load(code_ptr + tl.arange(0, 16))

# Gather from it (faster than scattered loads)
high_weight_float = tl.gather(codebook, high_weight, axis=0)

Result: 1.55x speedup over Unsloth!

The Final Kernel: A Guided Tour

import triton
import triton.language as tl

@triton.jit
def nf4_dequant_kernel(
    weight_ptr,           # Packed 4-bit weights (uint8), two per byte
    absmax_ptr,           # Quantized absmax values (uint8), one per 64 weights
    code_ptr,             # NF4 codebook (16 float32 values)
    state2_absmax_ptr,    # float32 scale shared by every 256 absmax values
    state2_code_ptr,      # Codebook for absmax (256 values)
    offset,               # Python float (quant_state.offset.item()); a tensor would be passed as a pointer
    output_ptr,           # Dequantized weights output (fp16 or bf16)
    N,                    # Total number of weights; no masking below, so assumes N % BLOCK_SIZE == 0
    BLOCK_SIZE: tl.constexpr,  # must divide 64 * 256 so one program never spans two state2 blocks
):
    pid = tl.program_id(0)

    # Load NF4 codebook into registers
    codebook = tl.load(code_ptr + tl.arange(0, 16))

    # Calculate absmax indices
    absmax_start_idx = (pid * BLOCK_SIZE) // 64
    state2_idx = absmax_start_idx // 256

    # Load state2.absmax
    state2_absmax = tl.load(state2_absmax_ptr + state2_idx)

    # Load packed weights
    weight_start = (pid * BLOCK_SIZE) // 2
    weights = tl.load(weight_ptr + weight_start + tl.arange(0, BLOCK_SIZE // 2))

    # Unpack 4-bit values (high nibble first)
    high_weight = (weights >> 4) & 0xF
    lower_weight = weights & 0xF

    # Lookup in NF4 codebook
    high_weight_float = tl.gather(codebook, high_weight, axis=0)
    lower_weight_float = tl.gather(codebook, lower_weight, axis=0)

    # Dequantize absmax values (32 bytes = 64 weights share one absmax)
    byte_indices = tl.arange(0, BLOCK_SIZE // 2)
    global_absmax_idx = absmax_start_idx + (byte_indices // 32)
    absmax_uint8 = tl.load(absmax_ptr + global_absmax_idx)
    absmax_float = tl.load(state2_code_ptr + absmax_uint8) * state2_absmax + offset

    # Apply scale factors
    dequant_high = high_weight_float * absmax_float
    dequant_low = lower_weight_float * absmax_float

    # Store interleaved results (tl.store casts to the output pointer's dtype)
    output_offsets = (pid * BLOCK_SIZE) + 2 * tl.arange(0, BLOCK_SIZE // 2)
    tl.store(output_ptr + output_offsets, dequant_high)
    tl.store(output_ptr + output_offsets + 1, dequant_low)
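Before benchmarking, it helps to have a slow but obviously correct reference to diff the kernel against. This is a NumPy sketch of the same math, not bitsandbytes' or Unsloth's code; the argument names mirror the quant_state fields the kernel reads, and the toy inputs (including linear stand-in codebooks) are made up.

```python
import numpy as np

def reference_dequantize_nf4(packed, absmax_u8, code, state2_absmax, state2_code, offset,
                             blocksize=64, blocksize2=256):
    """Slow CPU reference for the kernel's math (argument names mirror quant_state fields)."""
    packed = np.asarray(packed, dtype=np.uint8)
    # Unpack: high nibble first, then low nibble, interleaved per byte
    idx = np.stack([packed >> 4, packed & 0xF], axis=1).reshape(-1)
    # Double dequantization of the scale factors
    absmax = state2_code[absmax_u8] * state2_absmax[np.arange(absmax_u8.size) // blocksize2] + offset
    # Each weight uses the absmax of its 64-weight block
    return code[idx] * absmax[np.arange(idx.size) // blocksize]

# Toy inputs: linear tables stand in for the real NF4 and 8-bit dynamic codebooks
rng = np.random.default_rng(0)
code = np.linspace(-1, 1, 16, dtype=np.float32)
state2_code = np.linspace(-1, 1, 256, dtype=np.float32)
n_weights = 64 * 256 * 2                                  # spans two state2 blocks
packed = rng.integers(0, 256, n_weights // 2, dtype=np.uint8)
absmax_u8 = rng.integers(0, 256, n_weights // 64, dtype=np.uint8)
state2_absmax = np.array([0.5, 2.0], dtype=np.float32)
out = reference_dequantize_nf4(packed, absmax_u8, code, state2_absmax, state2_code, offset=0.01)
print(out.shape)  # (32768,)
```

To compare against real weights, the quant_state tensors could be moved to NumPy (for example quant_state.absmax.cpu().numpy()) with offset passed as quant_state.offset.item(), and the kernel output flattened before an np.allclose check; treat that wiring as a sketch.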

Part 1: The Codebook Optimization

codebook = tl.load(code_ptr + tl.arange(0, 16))

Why this matters: The NF4 codebook is only 16 values, but we'll access it 256 times (once per weight in this block). Loading it once into registers avoids a separate memory lookup for every weight.

The alternative: tl.load(code_ptr + index) for each weight would cause 256 scattered memory accesses.

Part 2: The Index Math

absmax_start_idx = (pid * BLOCK_SIZE) // 64
state2_idx = absmax_start_idx // 256
Program 0: weights [0-255]    → absmax [0-3]    → state2[0]
Program 1: weights [256-511]  → absmax [4-7]    → state2[0]
Program 2: weights [512-767]  → absmax [8-11]   → state2[0]
Program 3: weights [768-1023] → absmax [12-15]  → state2[0]

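Key insight: Multiple programs share the same state2.absmax value. Each state2.absmax covers 256 absmax values × 64 weights = 16,384 weights, so with BLOCK_SIZE = 256 that is 64 consecutive programs. Each program calculates which one it needs and loads it once.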
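Extending the table past program 63 shows where the state2 index changes. This plain-Python sketch (hypothetical helper) also checks the assumption the kernel relies on when it loads a single state2.absmax per program: a program never straddles two state2 blocks, which holds as long as BLOCK_SIZE divides 64 × 256 = 16,384.

```python
def program_blocks(pid, block_size=256, blocksize=64, blocksize2=256):
    """Absmax range and state2 index touched by one program."""
    first = pid * block_size
    last = first + block_size - 1
    absmax_first, absmax_last = first // blocksize, last // blocksize
    state2_first, state2_last = absmax_first // blocksize2, absmax_last // blocksize2
    # The kernel loads one state2.absmax per program, so this must hold
    assert state2_first == state2_last, "program straddles two state2 blocks"
    return (absmax_first, absmax_last), state2_first

for pid in (0, 3, 63, 64):
    print(pid, program_blocks(pid))
# 0 ((0, 3), 0)
# 3 ((12, 15), 0)
# 63 ((252, 255), 0)
# 64 ((256, 259), 1)
```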

Part 3: The Unpacking

high_weight = (weights >> 4) & 0xF
lower_weight = weights & 0xF
Byte: 10110011
      ││││││││
      ││││└┴┴┴─→ low nibble  (0011 = 3)
      └┴┴┴────→ high nibble (1011 = 11)

This is why we loaded BLOCK_SIZE // 2 bytes—each byte gives us 2 weights.

Part 4: The Gather Operation

high_weight_float = tl.gather(codebook, high_weight, axis=0)

What's happening:

  • high_weight is an array: [3, 11, 5, 14, ...] (128 values)
  • codebook is an array: [-1.0, -0.70, ..., 1.0] (16 values)
  • gather looks up: [codebook[3], codebook[11], codebook[5], ...]

Why it's fast: All 128 lookups happen in parallel, reading from registers instead of memory.

Part 5: The Double Dequantization

global_absmax_idx = absmax_start_idx + (byte_indices // 32)
absmax_uint8 = tl.load(absmax_ptr + global_absmax_idx)
absmax_float = tl.load(state2_code_ptr + absmax_uint8) * state2_absmax + offset

Step 1: Figure out which absmax each byte needs

  • Bytes 0-31 use absmax[absmax_start_idx]
  • Bytes 32-63 use absmax[absmax_start_idx + 1]
  • Bytes 64-95 use absmax[absmax_start_idx + 2]
  • Bytes 96-127 use absmax[absmax_start_idx + 3]
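Why divide byte indices by 32 rather than weight indices by 64? The two give the same answer because both weights in a byte always land in the same absmax block. A plain-Python check with BLOCK_SIZE = 256 as above:

```python
BLOCK_SIZE = 256
byte_indices = range(BLOCK_SIZE // 2)           # the 128 packed bytes one program loads
local_absmax = [b // 32 for b in byte_indices]  # 32 bytes = 64 weights per absmax

# Both weights stored in byte b (weights 2b and 2b+1) fall in the same 64-weight block
assert all((2 * b) // 64 == b // 32 == (2 * b + 1) // 64 for b in byte_indices)
print({k: local_absmax.count(k) for k in sorted(set(local_absmax))})  # {0: 32, 1: 32, 2: 32, 3: 32}
```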

Step 2: Load those 4 absmax values (they're uint8)

Step 3: Dequantize them:

absmax_float = state2.code[absmax_uint8] * state2.absmax + offset

This is the 'double' in double quantization—we're dequantizing the scale factors themselves!

Part 6: The Final Multiplication

dequant_high = high_weight_float * absmax_float
dequant_low = lower_weight_float * absmax_float

What's happening element-wise:

dequant_high[0] = codebook[high_weight[0]] * absmax_float[0]
dequant_high[1] = codebook[high_weight[1]] * absmax_float[1]
...

Each weight gets multiplied by its corresponding absmax scale factor.

Part 7: The Interleaved Store

output_offsets = (pid * BLOCK_SIZE) + 2 * tl.arange(0, BLOCK_SIZE // 2)
tl.store(output_ptr + output_offsets, dequant_high)
tl.store(output_ptr + output_offsets + 1, dequant_low)

Memory layout:

Byte 0 → [high[0], low[0]]
Byte 1 → [high[1], low[1]]
Byte 2 → [high[2], low[2]]

Why two stores: We're writing to even positions (0, 2, 4...) and odd positions (1, 3, 5...) separately. This matches how the weights were originally packed.
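The effect of the two strided stores can be reproduced with NumPy fancy indexing and compared with the more familiar stack-and-flatten idiom. The arrays are made-up placeholders for dequant_high and dequant_low.

```python
import numpy as np

high = np.array([10, 20, 30])  # placeholder values for dequant_high
low = np.array([11, 21, 31])   # placeholder values for dequant_low

# What the two strided stores do: even offsets get high, odd offsets get low
out = np.empty(2 * len(high), dtype=high.dtype)
offsets = 2 * np.arange(len(high))
out[offsets] = high
out[offsets + 1] = low
print(out.tolist())  # [10, 11, 20, 21, 30, 31]

# The same layout via stack-and-flatten
print(np.array_equal(out, np.stack([high, low], axis=1).ravel()))  # True
```

In Triton, tl.interleave(a, b) builds the same interleaved order in registers, which might allow a single contiguous store; whether that is faster on a T4 would have to be measured.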

What We Learned

The key insights:

  1. Quantization is compression with structure - NF4 isn't random. Every design choice (4 bits, block size of 64, double quantization) comes from balancing memory, accuracy, and compute.

  2. GPU programming is about thinking in parallel - Not "how do I process this element?" but "how do 10,000 workers process 10,000 elements simultaneously?"

  3. Memory is the bottleneck - Our speedup didn't come from faster math. It came from loading the codebook once instead of 256 times. On GPUs, arithmetic is cheap; memory access is expensive.

  4. Precision is a choice, not a default - Casting to float32 made things slower and less accurate. Sometimes the "obvious" optimization is wrong.

Every abstraction (Triton, PyTorch, even Python) hides complexity. When you need performance, you have to peek under the hood: not to reject the abstractions, but to use them wisely.

Resources