Star Attention ⭐⭐⭐⭐⭐
Star Attention is like giving LLMs a cheat sheet (the anchor block) to quickly understand the big picture (the context).
Paper: Star Attention: Efficient LLM Inference over Long Sequences
Researchers from NVIDIA are interested in long-sequence inference because the self-attention mechanism scales quadratically with sequence length. While techniques like Flash Attention and Ring Attention improve efficiency, reducing memory usage and increasing inference speed remain challenging, especially for very long sequences.
Hmm.. What’s the background?
Star Attention is built on the observation that LLM inference typically involves a two-stage process: prompt encoding (processing the input and storing its key-value vectors) and token generation (attending to the cached key-value vectors to produce new tokens). Recognizing that query tokens often need to attend to all prior tokens while context tokens may only need to attend locally, Star Attention proposes a two-phase approach to efficiently handle long sequences.
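To make that two-stage picture concrete, here is a minimal NumPy sketch of prefill plus decode with a key-value cache. Everything in it is a toy assumption (single head, no projections, K = V = hidden states); it only shows where the cache enters, not any particular library's API.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8                                                # toy head dimension
prompt = rng.standard_normal((32, d))                # toy prompt hidden states

# Stage 1 -- prompt encoding: process the input once and cache its
# key/value vectors (toy simplification: K = V = the hidden states).
kv_cache = {"K": prompt.copy(), "V": prompt.copy()}

def generate_step(q, cache):
    """Stage 2 -- token generation: a new query token attends to ALL cached
    key/value vectors, then its own K/V are appended for later tokens."""
    scores = cache["K"] @ q / np.sqrt(d)             # (num_cached,)
    w = np.exp(scores - scores.max())
    w /= w.sum()                                     # softmax over the cache
    out = w @ cache["V"]                             # attention output
    cache["K"] = np.vstack([cache["K"], q])          # grow the cache
    cache["V"] = np.vstack([cache["V"], q])
    return out

next_state = generate_step(rng.standard_normal(d), kv_cache)
```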
So what is proposed in the research paper?
The paper turns these insights into a two-phase algorithm:
Phase 1 (Context Encoding): The context is divided into blocks and distributed across multiple “context” hosts. Each host computes self-attention only over its assigned block, achieving linear complexity with respect to context length. An “anchor block” (the first block of the context) is prepended to each host’s block to help approximate global attention patterns and mitigate attention-sink issues.
Phase 2 (Query Encoding and Token Generation): The query is replicated across hosts and attends to the local key-value cache on each. A designated “query host” then computes global attention by efficiently aggregating the per-host results, minimizing communication overhead. Only the query host updates its key-value cache during this stage (a minimal sketch of both phases follows below).
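Here is a minimal single-head NumPy sketch of the two phases, as referenced above. It is not the paper’s implementation: the block size, single layer, and the K = V = hidden-state shortcut are all toy assumptions. What it does show are the two load-bearing ideas: each host encodes [anchor, block] but keeps only its block’s KV, and the query host merges per-host attention outputs exactly via a log-sum-exp (online-softmax) correction.

```python
import numpy as np

rng = np.random.default_rng(0)
d, block_len, num_blocks = 8, 16, 6                  # illustrative sizes

def attend(q, K, V):
    """Single-query softmax attention; also returns the log-sum-exp of the scores."""
    scores = K @ q / np.sqrt(d)
    m = scores.max()
    lse = m + np.log(np.exp(scores - m).sum())
    return np.exp(scores - lse) @ V, lse

def encode_block(prefix, blk):
    """Phase 1 on one host: causal self-attention over [anchor prefix, block];
    only the block's own positions are kept as that host's KV cache."""
    x = np.concatenate([prefix, blk])
    states = np.stack([attend(x[i], x[:i + 1], x[:i + 1])[0] for i in range(len(x))])
    return states[len(prefix):]                      # drop the anchor's states

context = rng.standard_normal((num_blocks * block_len, d))
blocks = np.split(context, num_blocks)
anchor = blocks[0]                                   # anchor block = first block
empty = anchor[:0]                                   # the first block needs no prefix

# Phase 1: every host sees [anchor, its block]; no block attends to any other
# block, so the cost grows linearly with the number of blocks.
host_kv = [encode_block(empty if i == 0 else anchor, b)
           for i, b in enumerate(blocks)]            # toy: K = V = encoded states

# Phase 2: the query attends to each host's local cache; the query host merges
# the partial outputs with a log-sum-exp correction, which reproduces global
# softmax attention over all cached tokens.
q = rng.standard_normal(d)
partial_outs, partial_lses = zip(*(attend(q, kv, kv) for kv in host_kv))
partial_lses = np.array(partial_lses)
m = partial_lses.max()
global_lse = m + np.log(np.exp(partial_lses - m).sum())
host_weights = np.exp(partial_lses - global_lse)     # each host's share of softmax mass
global_out = sum(w * o for w, o in zip(host_weights, partial_outs))
```

The merge is why Phase 2 is cheap to communicate: each host ships back only one output vector and one scalar log-sum-exp, never its full key-value cache.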
Experiments on the RULER and BABILong benchmarks with various Llama models (including 8B- and 70B-parameter versions) show that Star Attention achieves up to 11x faster inference than Ring Attention while maintaining 95-100% accuracy.
What’s next?
The authors identify several areas for future research:
Refining the anchor block mechanism: A deeper understanding of the anchor block’s function, and of how to choose its size relative to the context blocks, is needed.
Improving performance on complex tasks: Star Attention struggles on tasks like multi-hop tracing that require effective communication between context blocks; addressing this is key.
Scaling to longer sequences: Exploring the limits of Star Attention’s scalability to even longer sequences and larger models.
Learned something new? Consider sharing it!