Tensor Product Attention Is All You Need

IIIS, Tsinghua University · Shanghai Qi Zhi Institute · University of California, Los Angeles · TapTap

Abstract

Scaling language models to handle longer input sequences typically necessitates large key-value (KV) caches, resulting in substantial memory overhead during inference. In this paper, we propose Tensor Product Attention (TPA), a novel attention mechanism that uses tensor decompositions to represent queries, keys, and values compactly, significantly shrinking the KV cache size at inference time. By factorizing these representations into contextual low-rank components (contextual factorization) and integrating seamlessly with RoPE, TPA achieves improved model quality alongside memory efficiency. Based on TPA, we introduce the Tensor ProducT ATTenTion Transformer (T6), a new model architecture for sequence modeling. Through extensive empirical evaluation on language modeling tasks, we demonstrate that T6 exceeds the performance of standard Transformer baselines, including MHA, MQA, GQA, and MLA, across various metrics, including perplexity and a range of well-known evaluation benchmarks. Notably, TPA's memory efficiency enables the processing of significantly longer sequences under fixed resource constraints, addressing a critical scalability challenge in modern language models.

Tensor Product Attention

  • We propose Tensor Product Attention (TPA), a mechanism that factorizes the query, key, and value activations using contextual tensor decompositions, achieving a 10× or greater reduction in inference-time KV cache size relative to the standard attention mechanism [Vaswani et al., 2017] while improving performance over prior methods such as MHA, MQA, GQA, and MLA (a formula-level sketch of this factorization follows this list).
  • In addition, we unify existing attention mechanisms by revealing that MHA, MQA, and GQA all arise naturally as non-contextual variants of TPA.
  • We introduce the Tensor ProducT ATTenTion Transformer (T6), a new TPA-based model architecture for sequence modeling. In language modeling experiments, T6 consistently improves validation perplexity and downstream evaluation performance while reducing KV cache size.
  • We show TPA integrates seamlessly with RoPE [Su et al., 2024], facilitating easy adoption in popular foundation model architectures such as LLaMA and Gemma.
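
As a rough, formula-level sketch of the factorization above (the notation here is ours and is not defined on this page: $\mathbf{x}_t$ is the hidden state of token $t$, $h$ the number of heads, $d_h$ the per-head dimension, and $R_Q$ the query rank), each token's query slices are written as a sum of contextual rank-1 tensor products,

\[
\mathbf{Q}_t \;=\; \frac{1}{R_Q}\sum_{r=1}^{R_Q} \mathbf{a}^{Q}_{r}(\mathbf{x}_t)\otimes\mathbf{b}^{Q}_{r}(\mathbf{x}_t)\;\in\;\mathbb{R}^{h\times d_h},
\]

with analogous rank-$R_K$ and rank-$R_V$ sums for $\mathbf{K}_t$ and $\mathbf{V}_t$. Holding the head-side factors $\mathbf{a}_r$ fixed rather than computing them from $\mathbf{x}_t$ gives the non-contextual variants mentioned above (MHA, MQA, GQA).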

Experimental Results

Training loss and validation loss of pretraining medium-size (353M) models with different attention mechanisms on the FineWeb-Edu-100B dataset.

Training loss and validation loss of pretraining large-size (773M) models with different attention mechanisms on the FineWeb-Edu-100B dataset.

The evaluation results of pretrained medium-size models (353M)

The evaluation results of pretrained large-size models (773M)

Downstream Evaluation. We evaluate zero-shot and two-shot performance on standard benchmarks, including ARC (Yadav et al., 2019), BoolQ (Clark et al., 2019), HellaSwag (Zellers et al., 2019), OBQA (Mihaylov et al., 2018), PIQA (Bisk et al., 2020), WinoGrande (Sakaguchi et al., 2020) and MMLU (Hendrycks et al., 2021), using the lm-evaluation-harness codebase (Gao et al., 2024).
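
For reference, the snippet below sketches how such zero-shot and two-shot scores can be collected with lm-evaluation-harness via its Python API (v0.4-style simple_evaluate). The checkpoint path is a placeholder, and exact task names and arguments may differ across harness versions.

# Minimal sketch: evaluating a pretrained checkpoint with lm-evaluation-harness.
# The path below is a placeholder, not an official released checkpoint.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",                                   # Hugging Face causal-LM backend
    model_args="pretrained=/path/to/t6-medium",   # placeholder checkpoint path
    tasks=["arc_easy", "arc_challenge", "boolq", "hellaswag",
           "openbookqa", "piqa", "winogrande", "mmlu"],
    num_fewshot=0,                                # set to 2 for the two-shot setting
    batch_size=8,
)
print(results["results"])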

Tensor Factorization of Queries, Keys, and Values
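
Below is a minimal PyTorch-style sketch of this contextual factorization: per-token head factors and per-token dimension factors are produced by small linear maps and combined by outer products to form Q, K, and V. All shapes, ranks, and module names are illustrative assumptions, not the official implementation; attention itself and the RoPE applied to the key factors are omitted.

# A self-contained sketch of contextual tensor factorization of Q, K, V.
# Shapes, ranks (r_q, r_k, r_v), and layer names are illustrative assumptions.
import torch
import torch.nn as nn


class TPAFactorization(nn.Module):
    """Produce per-token rank-R factors whose outer products form Q, K, V."""

    def __init__(self, d_model=1024, n_heads=16, d_head=64, r_q=6, r_k=2, r_v=2):
        super().__init__()
        self.n_heads, self.d_head = n_heads, d_head
        self.r_q, self.r_k, self.r_v = r_q, r_k, r_v
        # Head-side factors a_r(x) in R^{n_heads} and token-side factors b_r(x)
        # in R^{d_head}, one pair of linear maps per component (Q, K, V).
        self.a_q = nn.Linear(d_model, r_q * n_heads)
        self.b_q = nn.Linear(d_model, r_q * d_head)
        self.a_k = nn.Linear(d_model, r_k * n_heads)
        self.b_k = nn.Linear(d_model, r_k * d_head)
        self.a_v = nn.Linear(d_model, r_v * n_heads)
        self.b_v = nn.Linear(d_model, r_v * d_head)

    def _compose(self, a, b, rank):
        # a: (batch, seq, rank, n_heads), b: (batch, seq, rank, d_head)
        # Sum of rank-1 outer products -> (batch, seq, n_heads, d_head).
        return torch.einsum("bsrh,bsrd->bshd", a, b) / rank

    def forward(self, x):  # x: (batch, seq, d_model)
        B, S, _ = x.shape
        q = self._compose(self.a_q(x).view(B, S, self.r_q, self.n_heads),
                          self.b_q(x).view(B, S, self.r_q, self.d_head), self.r_q)
        k = self._compose(self.a_k(x).view(B, S, self.r_k, self.n_heads),
                          self.b_k(x).view(B, S, self.r_k, self.d_head), self.r_k)
        v = self._compose(self.a_v(x).view(B, S, self.r_v, self.n_heads),
                          self.b_v(x).view(B, S, self.r_v, self.d_head), self.r_v)
        # Each output is (batch, seq, n_heads, d_head); standard scaled dot-product
        # attention can then be applied head by head.
        return q, k, v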

KV Caching and Memory Reduction
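
A back-of-the-envelope comparison of the per-token, per-layer cache footprint, assuming standard MHA caches the full keys and values while TPA caches only their low-rank factors. The head counts and ranks below are illustrative choices, not the paper's exact configurations.

# Rough per-token, per-layer KV-cache size comparison (illustrative numbers).

def mha_cache_per_token(n_heads: int, d_head: int) -> int:
    # Full K and V: one d_head-vector per head, for each of K and V.
    return 2 * n_heads * d_head

def tpa_cache_per_token(n_heads: int, d_head: int, r_k: int, r_v: int) -> int:
    # Factors only: a head-side vector (length n_heads) and a token-side
    # vector (length d_head) per rank, for K and V respectively.
    return (r_k + r_v) * (n_heads + d_head)

if __name__ == "__main__":
    n_heads, d_head, r_k, r_v = 32, 128, 2, 2              # illustrative settings
    mha = mha_cache_per_token(n_heads, d_head)              # 8192 values per token
    tpa = tpa_cache_per_token(n_heads, d_head, r_k, r_v)    # 640 values per token
    print(f"MHA: {mha} values/token, TPA: {tpa} values/token, "
          f"ratio ≈ {mha / tpa:.1f}x")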

Citation

Please cite the paper and star this repo if you use Tensor Product Attention (TPA) or the Tensor ProducT ATTenTion Transformer (T6) and find it interesting or useful. Thanks!

@article{zhang2025tensor,
    title={Tensor Product Attention Is All You Need},
    author={Zhang, Yifan and Liu, Yifeng and Yuan, Huizhuo and Qin, Zhen and Yuan, Yang and Gu, Quanquan and Yao, Andrew Chi-Chih},
    journal={arXiv preprint arXiv:2501.06425},
    year={2025},
}