
FlashQLA: CP-/Bwd-Friendly Fused Linear Attention Kernels for GDN


Introduction

Following the release of Qwen3-Next, Gated Delta Network (GDN) has become the workhorse attention layer across the Qwen family, from Qwen3-Next-80B-A3B through the subsequent Qwen3.5 / Qwen3.6 series. As models scale to 397B-A17B / 122B-A10B / 35B / 27B and context windows stretch beyond 256K tokens, the overhead of the GDN block in end-to-end training and inference has become non-negligible.

Today we officially open-source FlashQLA: a high-performance linear attention kernel library built on TileLang. FlashQLA applies targeted operator fusion and performance optimizations to the forward and backward passes of GDN Chunked Prefill, achieving 2-3× forward and 2× backward speedups over the FLA Triton kernels across multiple scenarios on NVIDIA Hopper GPUs. The gains are most pronounced in pretraining and in edge-side agentic inference.

Key highlights of this release:

  1. Gate-driven automatic intra-card context parallelism (CP). By exploiting the exponential decay of the GDN gate, FlashQLA automatically enables intra-card CP in TP, long-sequence, and small-head-count settings, improving GPU SM utilization.
  2. Hardware-friendly algebraic reformulation. We algebraically restructure parts of the forward and backward flows of GDN Chunked Prefill, reducing Tensor Core, CUDA Core, and SFU overhead without sacrificing numerical precision.
  3. TileLang fused warp-specialized kernels. Rather than decomposing the computation into one independent kernel per step, or fusing the entire flow into a single kernel, we design around CP and backward-pass requirements: we use TileLang to build several key fused kernels and manually implement warpgroup specialization to overlap data movement, Tensor Core computation, and CUDA Core computation.

FlashQLA code and benchmarks are open-sourced at github.com/QwenLM/FlashQLA.

Key Problems in FLA GDN Chunked Prefill

Let us first review the forward computation flow of GDN Chunked Prefill, taking chunk index i as an example:

[Figure 2: Forward computation flow of GDN Chunked Prefill for chunk i]

Ignoring gate preprocessing and CP, each step of this flow corresponds to one kernel in FLA. This flow has two main efficiency problems:

  1. Most of these kernels are memory-bound. The flow reads K, V, and other inputs repeatedly, and the intermediates W, U, and S must be written to HBM and read back by the next kernel, incurring significant memory traffic.
  2. Because the SSM state is recurrent across chunks, the third step (chunk_gated_delta_rule_fwd_kernel) can launch only batch_size * num_heads thread blocks at a time, which leaves the GPU underutilized in small-model, small-batch, or TP scenarios.
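A minimal NumPy sketch of the chunked forward pass for a single head may help make W, U, and S concrete. It assumes the FLA-style convention (state S of shape [K, V], S_t = α_t (I − β_t k_t k_tᵀ) S_{t−1} + β_t k_t v_tᵀ, o_t = scale · S_tᵀ q_t, gate passed as g = log α). Per chunk it builds W and U with a unit-lower-triangular solve, forms V' = U − W S, then the chunk's outputs and the next state, and checks everything against the token-by-token recurrence. The function names are hypothetical, and this float64 reference is one plausible reading of the flow, not FLA's or FlashQLA's kernel code.

```python
import numpy as np

def gdn_recurrent(q, k, v, g, beta, scale, S0):
    """Token-by-token gated delta rule for one head; S has shape [K, V], g = log(alpha)."""
    S = S0.copy()
    out = np.zeros((q.shape[0], v.shape[1]))
    for t in range(q.shape[0]):
        # S_t = alpha_t (I - beta_t k_t k_t^T) S_{t-1} + beta_t k_t v_t^T
        S = np.exp(g[t]) * (S - beta[t] * np.outer(k[t], k[t] @ S)) + beta[t] * np.outer(k[t], v[t])
        out[t] = scale * (q[t] @ S)  # o_t = scale * S_t^T q_t
    return out, S

def gdn_chunked(q, k, v, g, beta, scale, S0, chunk=4):
    """Chunked form: per chunk, W and U via a unit-lower-triangular solve, V' = U - W S,
    then the chunk's outputs O and the state S handed to the next chunk."""
    S = S0.copy()
    out = np.zeros((q.shape[0], v.shape[1]))
    for start in range(0, q.shape[0], chunk):
        sl = slice(start, start + chunk)
        qc, kc, vc, bc = q[sl], k[sl], v[sl], beta[sl]
        G = np.cumsum(g[sl])                  # log of the cumulative decay Gamma_r inside the chunk
        D = np.exp(G[:, None] - G[None, :])   # D[r, s] = Gamma_r / Gamma_s (only s <= r is used)
        C = len(G)
        A = np.tril(bc[:, None] * (kc @ kc.T) * D, k=-1)  # strictly lower triangular
        I_plus_A = np.eye(C) + A
        W = np.linalg.solve(I_plus_A, (bc * np.exp(G))[:, None] * kc)
        U = np.linalg.solve(I_plus_A, bc[:, None] * vc)
        V_new = U - W @ S                       # V' = U - W S
        inter = np.exp(G)[:, None] * (qc @ S)   # contribution of the incoming state
        intra = (np.tril(qc @ kc.T) * D) @ V_new  # causal intra-chunk term, diagonal included
        out[sl] = scale * (inter + intra)
        S = np.exp(G[-1]) * S + (np.exp(G[-1] - G)[:, None] * kc).T @ V_new
    return out, S
```

Each iteration of the chunk loop touches W, U, V', and S; in an unfused pipeline these live in HBM between kernels, which is the memory traffic of point 1, and the chunk loop itself is the sequential dependency behind point 2.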

The natural fixes for these two problems conflict. For the first, the most intuitive solution is a fully fused kernel in which all inputs are read only once and all intermediates stay on-chip. When batch_size * num_heads is large enough, this is the best option. However, it runs straight into the second problem: in edge-side inference with small models at batch_size=1, or in large-model online deployments with TP, where long-sequence inputs (for example from coding agents) cannot form a large enough batch for chunked prefill, a fully fused kernel offers only limited speedup over the original FLA implementation.
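To put rough numbers on the second problem, the sketch below counts the thread blocks a state-recurrence kernel can launch when it parallelizes only over batch and heads (optionally multiplied by a V split), and compares that with the SM count. It assumes 132 SMs (as on H100/H200 SXM parts) and one resident block per SM; both are simplifications, and the function name is hypothetical.

```python
def state_kernel_sm_fill(batch_size, num_heads, num_sms=132, v_splits=1):
    """Blocks launched by a kernel that parallelizes only over (batch, head[, V split]),
    and the fraction of SMs they can occupy, assuming one resident block per SM."""
    blocks = batch_size * num_heads * v_splits
    return blocks, min(blocks, num_sms) / num_sms

for heads in (64, 32, 16, 8):
    blocks, fill = state_kernel_sm_fill(batch_size=1, num_heads=heads)
    print(f"batch_size=1, num_heads={heads:2d}: {blocks:3d} blocks, {fill:.0%} of SMs")
```

Under this model, at batch_size=1 even 64 heads occupy less than half of the SMs, and 8 heads (for example, after TP sharding) occupy about 6%.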

The earliest solution to the second problem comes from DeltaNet's context parallelism scheme: split a long sequence into multiple sub-sequences, run the recurrence on each in parallel starting from S0 = 0, and then compute an additional matrix M to correct the recurrent results. The scheme was later optimized by inserting a step before the recurrence kernel that computes the S0 of each sub-sequence, and this version has been merged into the FLA repository. For CP rank j, the preprocessing flow is:

[Figure 3: CP preprocessing flow for CP rank j]

This CP scheme has drawbacks of its own. First, it adds significant extra computation: recurrently computing the M matrix costs even more than computing S itself. Second, it does not compose well with fully fused kernels, because the matrix inversion and other steps must finish before the S0 of each sub-sequence can be computed.
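The correction-based CP idea can be sketched at token granularity: each rank runs its segment from a zero state and also carries a K × K transition matrix M, so that its true end state is M @ S_start + S_local. Once all ranks finish (in parallel), a short sequential pass over ranks recovers every rank's S0. This is a simplified, hypothetical illustration of the idea under the FLA-style state convention, not the FLA implementation, which operates chunk-wise.

```python
import numpy as np

def segment_summary(k, v, g, beta):
    """Run one CP rank's segment from S = 0 and accumulate the K x K transition M,
    so that S_end = M @ S_start + S_local for any incoming state S_start."""
    K, V = k.shape[1], v.shape[1]
    S_local = np.zeros((K, V))
    M = np.eye(K)
    for t in range(k.shape[0]):
        A_t = np.exp(g[t]) * (np.eye(K) - beta[t] * np.outer(k[t], k[t]))
        S_local = A_t @ S_local + beta[t] * np.outer(k[t], v[t])
        M = A_t @ M  # the extra per-step work that only exists for the correction
    return M, S_local

def cp_initial_states(k, v, g, beta, num_ranks, S0):
    """Return the S0 of every rank and the final state. The per-rank summaries are
    independent of each other; only the combine over ranks is sequential."""
    segments = np.array_split(np.arange(k.shape[0]), num_ranks)
    summaries = [segment_summary(k[i], v[i], g[i], beta[i]) for i in segments]
    starts, S = [], S0
    for M, S_local in summaries:
        starts.append(S)
        S = M @ S + S_local
    return starts, S
```

Here M is a dense K × K matrix that must be updated at every step alongside S, which illustrates where the extra computation comes from.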

A Balanced Solution: Fusing Kernels While Enabling Intra-Card CP

Based on the two problems above, a compromise solution can be derived: split the GDN Chunked Prefill forward computation into two fused kernels, inserting CP-related preprocessing steps between them. After some transformations and simplifications, the following computation flow is obtained:

[Figure 4: Forward computation flow with two fused kernels and CP preprocessing]

We also designed a simple mathematical model to determine the degree of parallelism automatically. Let N be the number of chunks in a sequence and L the number of chunks per CP rank. The runtime of steps 2.1 and 3 is proportional to L, while that of step 2.2 is proportional to N/L. Writing the total as $T(L) \propto L + \lambda N / L$ and minimizing over L gives $L = \sqrt{\lambda N}$, where $\lambda$ is a coefficient determined by batch_size, num_heads, and other hyperparameters.
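The cost model can be checked numerically. The sketch below normalizes the coefficient of L to 1, so the modeled time is L + λN/L, whose continuous minimum is at L = √(λN); because the function is convex, the best integer L is the floor or ceiling of that value, clipped to [1, N]. The names are hypothetical, N/L is treated as continuous, and λ here is just an input, not FlashQLA's calibrated value.

```python
import math

def modeled_time(L, N, lam):
    """Relative cost: steps 2.1 and 3 scale with L, step 2.2 with N / L (coefficient of L normalized to 1)."""
    return L + lam * N / L

def choose_chunks_per_rank(N, lam):
    """Integer L in [1, N] minimizing modeled_time; the continuous optimum is sqrt(lam * N)."""
    L_star = math.sqrt(lam * N)
    candidates = {min(N, max(1, math.floor(L_star))), min(N, max(1, math.ceil(L_star)))}
    # modeled_time is convex in L, so the best integer is the floor or ceiling of L_star (clipped).
    return min(candidates, key=lambda L: modeled_time(L, N, lam))
```

For example, with N = 4096 chunks and λ = 1, the sketch picks L = 64, i.e. 64 CP ranks of 64 chunks each.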

In production, intra-card CP is not always needed. Following the original FLA implementation, step 3 can instead raise parallelism by 2-4× by splitting v_head_dim, at the cost of redundant memory reads of Q and K. Based on measurements, we enable CP only when batch_size * num_heads <= 40, or when batch_size * num_heads <= 56 and seq_len >= 8192.
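The enabling rule can be written as a small predicate. This sketch encodes the stated thresholds only; the function name is hypothetical, and it assumes num_heads counts the value heads and seq_len is the per-sequence length, neither of which the text pins down.

```python
def should_enable_intra_card_cp(batch_size: int, num_heads: int, seq_len: int) -> bool:
    """Thresholds quoted in the text (tuned on measured data); hypothetical helper name."""
    bh = batch_size * num_heads
    return bh <= 40 or (bh <= 56 and seq_len >= 8192)
```

Parenthesizing the second clause makes the intended precedence explicit: the seq_len condition applies only to the 41 to 56 range.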

Further Optimization via Gate Decay

Revisiting the GDN recurrence:

[Figure 5: The GDN recurrence]
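For reference, the token-level form of the recurrence in the FLA-style convention (state $S \in \mathbb{R}^{K \times V}$, gate $\alpha_i = e^{g_i}$, output scale omitted) is written below, together with a short bound showing why warmup works. The bound is a sketch that assumes L2-normalized keys ($\lVert k_j \rVert_2 = 1$) and $\beta_j \in [0, 1]$; it is not taken from the FlashQLA code.

$$
S_i = \alpha_i \left(I - \beta_i k_i k_i^\top\right) S_{i-1} + \beta_i k_i v_i^\top, \qquad o_i = S_i^\top q_i
$$

With $A_j = \alpha_j (I - \beta_j k_j k_j^\top)$, unrolling $W$ steps gives

$$
S_i = A_i A_{i-1} \cdots A_{i-W+1} \, S_{i-W} + \tilde S_i ,
$$

where $\tilde S_i$ is exactly the state produced by warming up from $S_{i-W} = 0$. Under the assumptions above, $I - \beta_j k_j k_j^\top$ is symmetric with eigenvalues in $[0, 1]$, so $\lVert A_j \rVert_2 \le \alpha_j$ and

$$
\lVert S_i - \tilde S_i \rVert_2 \le \Big( \prod_{j=i-W+1}^{i} \alpha_j \Big) \lVert S_{i-W} \rVert_2 .
$$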

For $\alpha_i \in (0, 1)$, the influence of each $S_i$ on subsequent states decays exponentially, which gives the recurrence a sliding-window property: for a sufficiently long window $W$, starting the computation from $S_{i-W} = 0$ yields an accurate $S_i$ without having to start from $S_0$. We refer to this process as warmup. On real data, we find that $\alpha_i$ is not constantly 1 on 60–80% of linear attention heads, and for these heads 6–8 chunks of warmup are enough to drive the error in $S_i$ below the noise floor.
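The warmup argument is easy to test numerically. The sketch below runs the recurrence for one head, compares the exact state with a warmup that starts from zero W tokens earlier, and checks the error against the bound (∏ α) · ‖S_{i−W}‖₂, which holds when keys are L2-normalized and β ∈ [0, 1]. It works per token rather than per chunk, and all names are hypothetical.

```python
import numpy as np

def gdn_step(S, k_t, v_t, alpha_t, beta_t):
    """One gated delta rule update: S <- alpha (I - beta k k^T) S + beta k v^T, S of shape [K, V]."""
    return alpha_t * (S - beta_t * np.outer(k_t, k_t @ S)) + beta_t * np.outer(k_t, v_t)

def run_range(S, k, v, alpha, beta, start, end):
    """Apply the recurrence to tokens start .. end-1, starting from state S."""
    for t in range(start, end):
        S = gdn_step(S, k[t], v[t], alpha[t], beta[t])
    return S

def warmup_error(k, v, alpha, beta, S0, end, W):
    """Spectral-norm error of a zero-initialized warmup over the last W tokens before `end`
    (requires W <= end), together with the bound prod(alpha) * ||S_{end-W}||_2."""
    start = end - W
    S_mid = run_range(S0, k, v, alpha, beta, 0, start)          # exact state at the window start
    exact = run_range(S_mid, k, v, alpha, beta, start, end)
    approx = run_range(np.zeros_like(S0), k, v, alpha, beta, start, end)
    err = np.linalg.norm(exact - approx, 2)
    bound = np.prod(alpha[start:end]) * np.linalg.norm(S_mid, 2)
    return err, bound
```

Heads whose α stays close to 1 make the bound useless, which is consistent with applying warmup only to heads that show the sliding-window property.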

Therefore, for linear attention heads with the sliding-window property, we can design a lighter CP preprocessing flow that discards the computation of the correction term M and directly obtains an equally accurate sub-sequence S0 through warmup:

 

      C0  C1  C2  C3  C4  C5  C6  C7  C8  C9  C10 C11 C12
R1    O   O   O   O   O
R2                X   X   O   O   O   O
R3                                X   X   O   O   O   O

(C0–C12: chunks; R1–R3: CP ranks.)

X denotes warming up with a zero initial state until the gate has decayed sufficiently, then writing out the S0 of that CP rank; O denotes subsequent normal recurrent computation. The warmup length for each rank is determined by an independent kernel that collects gate statistics, and the cost of this step is negligible.
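A gate-statistics pass of this kind could look like the following sketch. It assumes a per-chunk summary of the gate (the sum of log α over the chunk's tokens) and picks, for each rank, the smallest number of preceding chunks whose combined decay falls below a tolerance, falling back to the full history when the gate never decays enough. The function name, the tolerance, and the fallback are illustrative assumptions, not FlashQLA's actual kernel.

```python
import numpy as np

def warmup_chunks(chunk_log_decay, rank_starts, tol=1e-6):
    """For each CP rank starting at chunk c, return the smallest W such that the decay
    accumulated over chunks c-W .. c-1 is at most tol; fall back to W = c (full history)
    if the gate never decays that much. chunk_log_decay[c] = sum of log(alpha) over chunk c."""
    log_tol = np.log(tol)
    result = []
    for c in rank_starts:
        back = np.cumsum(chunk_log_decay[:c][::-1])  # decay over the last 1, 2, ... chunks before c
        hit = np.nonzero(back <= log_tol)[0]
        result.append(int(hit[0]) + 1 if hit.size else int(c))
    return result

print(warmup_chunks(np.full(13, np.log(0.01)), [0, 5, 9], tol=1e-3))
```

With a per-chunk decay of 0.01 and tol=1e-3, ranks starting at C5 and C9 each get two warmup chunks, the same shape as the R2 and R3 rows. For a head whose gate stays at 1 (all log-decays zero), every rank would fall back to its full history, i.e. to the M-based approach.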

TileLang Warp-Specialized Kernel

We implement FlashQLA in TileLang using a warpgroup-specialization pattern: one producer warpgroup and three consumer warpgroups reside in the same SM, exchange data through shared memory, and synchronize via mbarriers.
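As a CPU analogy of the producer/consumer handshake, the Python sketch below uses a two-slot ring buffer in place of shared memory and per-slot "full" and "empty" semaphores in place of mbarriers: the producer waits for a slot to be empty before filling it, and the consumer waits for it to be full before computing on it. It uses a single consumer and plain threads, so it only illustrates the synchronization pattern, not TileLang's API or FlashQLA's three-consumer schedule.

```python
import threading

def pipeline(tiles, num_slots=2):
    """CPU analogy of one producer and one consumer warpgroup sharing a ring of buffer slots.
    Per-slot 'full'/'empty' semaphores stand in for mbarriers."""
    slots = [None] * num_slots
    full = [threading.Semaphore(0) for _ in range(num_slots)]   # slot holds data ready to consume
    empty = [threading.Semaphore(1) for _ in range(num_slots)]  # slot may be overwritten
    results = []

    def producer():
        for i, tile in enumerate(tiles):
            s = i % num_slots
            empty[s].acquire()   # wait until the consumer has released this slot
            slots[s] = tile      # stands in for an asynchronous copy into shared memory
            full[s].release()    # signal that the slot is ready

    def consumer():
        for i in range(len(tiles)):
            s = i % num_slots
            full[s].acquire()               # wait for the producer
            results.append(sum(slots[s]))   # stands in for the compute stage
            empty[s].release()              # hand the slot back to the producer

    workers = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    return results
```

With more slots the producer can run further ahead of the consumer, which is the same trade-off as choosing the number of pipeline stages in shared memory.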

Forward

In the forward pass, the three consumer warpgroups compute V', S, and O respectively, overlapping computation with memory traffic through a ping-pong structure.

[Figure 6: Forward kernel warpgroup schedule]

Notes:

  • S output per chunk is for debugging only; normally only O and the last chunk's S are output.

CP Preprocessing

As mentioned earlier, the CP preprocessing splits into two cases: the original approach (computing both M and S) and the sliding-window approach (computing only S). We designed a single fused kernel that handles both:

[Figure 7: Fused CP preprocessing kernel]

Notes:

  • The last two steps of WG1 and WG2 correspond to the M matrix computation and are triggered only when required.
  • S is output per chunk only during backward recomputation.

Backward

In the backward pass, we reuse the CP preprocessing kernel from the previous section to recompute the S matrix, then fuse bwd_dv, bwd_dhu, bwd_dqkwg, and bwd_wy into a single kernel with corresponding algebraic optimizations. Because of on-chip resource constraints, the backward kernel does not use multi-stage pipelining; instead, it relies on its long compute chain to hide memory traffic. The full schedule is available in the FlashQLA repo.
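One plausible motivation for recomputing S instead of saving it from the forward pass is memory. The sketch below estimates the size of per-chunk states for one layer and one sequence, assuming a chunk size of 64, K = V = 128, and 2-byte (bf16) storage; these values and the function name are illustrative assumptions. With 32 heads and a 256K-token sequence, this comes to 4 GiB for a single layer.

```python
def per_chunk_state_bytes(seq_len, num_heads, k_dim=128, v_dim=128, chunk=64, bytes_per_elem=2):
    """Bytes to materialize one K x V state per chunk and head, for one sequence and one layer.
    Defaults (chunk 64, K = V = 128, bf16) are illustrative assumptions."""
    num_chunks = -(-seq_len // chunk)  # ceiling division
    return num_chunks * num_heads * k_dim * v_dim * bytes_per_elem

print(per_chunk_state_bytes(seq_len=256 * 1024, num_heads=32) / 2**30, "GiB")
```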

[Figure 8: Fused backward kernel]

Benchmark

We benchmarked FlashQLA against the FLA Triton and FlashInfer baselines (FLA 0.5.0, Triton 3.5.1, FlashInfer 0.6.9, TileLang 0.1.8) on the head configurations used by the Qwen3.5 / Qwen3.6 family: hv ∈ {64, 48, 32, 24, 16, 8}, corresponding to TP1 through TP8.

[Figure 9: Forward and backward benchmark results]

Specifically, the forward (FWD) benchmarks measure single-kernel latency for different models and TP settings under varying batch lengths, while the backward (BWD) benchmarks examine the relationship between total token count within a batch and latency during a single update step.

Selected H200 single-layer forward results:

[Figure 10: Selected H200 single-layer forward results]

The speedup grows with the TP degree because intra-card AutoCP improves SM utilization in exactly the regimes (TP sharding and small head counts) where the baseline leaves SMs idle.

Usage

FlashQLA exposes both a high-level API matching FLA’s signature and low-level fwd/bwd entry points:

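import torch
import torch.nn.functional as F
from qla import chunk_gated_delta_rule

# Illustrative sizes and inputs (assumptions for this example, not library defaults).
B, T, H_q, H_v, K, V = 1, 8192, 16, 32, 128, 128
opts = dict(device="cuda", dtype=torch.bfloat16)
q = F.normalize(torch.randn(B, T, H_q, K, **opts), dim=-1)
k = F.normalize(torch.randn(B, T, H_q, K, **opts), dim=-1)
v = torch.randn(B, T, H_v, V, **opts)
g = F.logsigmoid(torch.randn(B, T, H_v, **opts))  # log-space decay, log(alpha) < 0 (assumed convention)
beta = torch.rand(B, T, H_v, **opts)              # beta in (0, 1)
scale = K ** -0.5

o, final_state = chunk_gated_delta_rule(
    q=q,                            # [B, T, H_q, K]
    k=k,                            # [B, T, H_q, K]
    v=v,                            # [B, T, H_v, V]
    g=g,                            # [B, T, H_v]
    beta=beta,                      # [B, T, H_v]
    scale=scale,
    initial_state=None,             # optional, [B, H_v, K, V]
    output_final_state=True,
    cu_seqlens=None,                # optional, varlen support
)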

Requirements: SM90, CUDA 12.8+, PyTorch 2.8+. Install:

git clone https://github.com/QwenLM/FlashQLA.git
cd FlashQLA && pip install -v .

Acknowledgments

FlashQLA is inspired by Flash Linear Attention, FlashInfer and TileLang. We thank these communities for the reference implementations.

Citation

If FlashQLA is useful for your research, please cite:

@misc{flashqla2026,
    title  = {FlashQLA: Flash Qwen Linear Attention},
    author = {Zhang, Chengruidong and Lin, Xi and Jiang, Huiqiang and Wang, Zekun and
              Li, Xiao and Cao, Yizhong and Zhuang, Bohan and Men, Rui and Zhang, Jianwei and
              Zheng, Bo and Lin, Junyang and Liu, Dayiheng and Zhou, Jingren},
    year   = {2026},
    publisher = {GitHub},
    howpublished = {\url{https://github.com/QwenLM/FlashQLA}}
}

Alibaba Cloud Community
