
Unleashing the Power of Multi-Query Attention: A Turbocharged Alternative to Multi-Head Attention

Evergreen Technologies
4 min read · Jun 17, 2023


Author: Ravindra Sadaphule

Introduction

Attention mechanisms have revolutionized the field of natural language processing, giving birth to the Transformer architecture, which powers state-of-the-art models like BERT, GPT, and many more. Among these attention mechanisms, Multi-Head Attention has been the star player. But there's a new player in town: Multi-Query Attention! In this blog post, we will dive into the depths of Multi-Query Attention, understand how it works, and see how it stacks up against traditional Multi-Head Attention.

Multi-Query Attention (MQA) is a more memory-efficient alternative to multi-head attention for autoregressive decoding. In MQA, a single Key and a single Value head are shared across all of the query heads, which shrinks the key/value cache and can cut decoder latency by up to an order of magnitude during incremental inference. This makes attention-based sequence models far more practical in settings where inference speed is critical.
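To put a rough number on the memory savings, here is a back-of-the-envelope sketch. The head count, head size, and fp16 byte width below are assumed values for illustration, not figures from the paper: the multi-head cache stores one Key and one Value vector per head per token, while the multi-query cache stores a single shared pair.

    # Hypothetical per-token KV-cache sizes for one decoder layer (assumed values)
    num_heads, d_head, dtype_bytes = 16, 64, 2        # 16 heads of size 64, fp16

    mha_cache = 2 * num_heads * d_head * dtype_bytes  # K and V for every head: 4096 bytes
    mqa_cache = 2 * 1 * d_head * dtype_bytes          # one shared K/V pair: 256 bytes

    print(mha_cache // mqa_cache)  # 16 -- the cache shrinks by a factor of num_heads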

What is Multi-Head Attention?

Before we dive into Multi-Query Attention, let’s quickly recap Multi-Head Attention. In Multi-Head Attention, the model is able to focus on different parts of the input sequence for each word in the output sequence. It does this by creating multiple sets of Query, Key, and Value vectors (hence the name “Multi-Head”). Each set is used to compute a different weighted sum of the input, and the results are concatenated and linearly transformed into the final output.

import tensorflow as tf

def multi_head_attention(Q, K, V, num_heads):
    # Split the last (model) dimension of Q, K, V into num_heads pieces
    Qs = tf.split(Q, num_heads, axis=-1)
    Ks = tf.split(K, num_heads, axis=-1)
    Vs = tf.split(V, num_heads, axis=-1)

    # Perform scaled dot-product attention for each head
    outputs = []
    for i in range(num_heads):
        attn_scores = tf.matmul(Qs[i], Ks[i], transpose_b=True)
        attn_scores /= tf.math.sqrt(float(Ks[i].shape[-1]))
        attn_weights = tf.nn.softmax(attn_scores, axis=-1)
        outputs.append(tf.matmul(attn_weights, Vs[i]))

    # Concatenate the per-head outputs and apply a final linear projection
    # (in a real model, W_o would be a learned weight matrix)
    concatenated = tf.concat(outputs, axis=-1)
    W_o = tf.random.normal((concatenated.shape[-1], concatenated.shape[-1]))
    return tf.matmul(concatenated, W_o)
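For contrast, here is a minimal sketch of Multi-Query Attention under the same conventions as the snippet above. The function name and tensor layout are illustrative assumptions, not the exact code from the MQA paper. The key difference is that K and V are not split: each holds a single shared head of width d_model // num_heads, and only the queries get their own per-head slices.

    def multi_query_attention(Q, K, V, num_heads):
        # Only the queries are split into heads; K and V each hold one
        # shared head of width d_model // num_heads (an assumed layout)
        Qs = tf.split(Q, num_heads, axis=-1)
        scale = tf.math.sqrt(float(K.shape[-1]))

        outputs = []
        for i in range(num_heads):
            # Each head uses its own queries but the same shared keys and values
            attn_scores = tf.matmul(Qs[i], K, transpose_b=True) / scale
            attn_weights = tf.nn.softmax(attn_scores, axis=-1)
            outputs.append(tf.matmul(attn_weights, V))

        # Concatenate per-head outputs; a learned output projection would follow
        return tf.concat(outputs, axis=-1)

Because the single shared K and V pair is all that needs to be cached during incremental decoding, the key/value cache shrinks by a factor of num_heads, which is exactly where the latency savings come from.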
