# RFC: Multihead Attention and EinsumDense on Keras

| Status        | Accepted                                                 |
| :------------ | :------------------------------------------------------- |
| **RFC #**     | [260](https://github.com/tensorflow/community/pull/260) |
| **Author(s)** | Hongkun Yu ([email protected]), Mark Omernick ([email protected]) |
| **Sponsor**   | Francois Chollet ([email protected]) |
| **Updated**   | 2020-06-16                                               |

## Objective

Introduce the MultiHeadAttention layer and EinsumDense layer to tf.keras.

## Motivation

MultiHeadAttention is very popular and has become standard in deep learning
libraries. We propose to contribute a flexible, well-defined implementation
inside Keras that absorbs common best practices from reference libraries.

## User Benefit

We can standardize the implementation of Transformer layers and adopt best
practices. We offer a rich set of functionalities for different use cases, e.g.
different projection spaces and outputting multi-head attention scores for
analysis. We also modularize the computation to make the MultiHeadAttention
layer extensible to variants.

## Design Proposal

### Key Features

*   Returns multi-headed attention scores, which is commonly useful for
    attention visualization and analysis.
*   Supports query (Q), key (K), and value (V) tensors as individual inputs and
    supports projecting Q, K, and V to different dimensions.
*   Projects the final output to user-specified dimensions.
*   Uses tf.einsum to express high-dimensional computation and adopts the
    [tf.keras.layers.experimental.EinsumDense](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/EinsumDense)
    layer.
*   Supports high-dimensional attention when the target and source are 2D, 3D,
    etc.

### Code Examples

*   How to write a TransformerBlock for an encoder:

```python
class TransformerBlock(tf.keras.layers.Layer):

  def __init__(self, embed_dim, num_heads, ff_dim):
    super(TransformerBlock, self).__init__()
    self.att = MultiHeadAttention(num_heads=num_heads, key_size=embed_dim)
    self.ffn = tf.keras.Sequential(
        [tf.keras.layers.Dense(ff_dim, activation="relu"),
         tf.keras.layers.Dense(embed_dim)])
    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

  def call(self, inputs, attention_mask=None):
    attn_output = self.att(query=inputs, value=inputs,
                           attention_mask=attention_mask)
    out1 = self.layernorm1(inputs + attn_output)
    ffn_output = self.ffn(out1)
    return self.layernorm2(out1 + ffn_output)
```

*   Use an attention mask to avoid performing attention on padding token
    indices:

```python
test_layer = TransformerBlock(
    embed_dim=2,
    num_heads=2,
    ff_dim=4)
query = np.array([[[0.1, 0.2], [0.0, 0.0]]])
mask = np.array([[[1, 0], [1, 0]]], dtype='bool')
output = test_layer(query, mask)
```

*   Inside a Transformer decoder, we often want to output the cross-attention
    scores to analyze how the target sequence attends to the source sequence,
    and we can visualize the alignment according to the attention scores:

```python
test_layer = MultiHeadAttention(
    num_heads=2, key_size=2, return_attention_scores=True)
target = np.array([[[0.1, 0.2], [0.0, 0.0]]])
source = np.array([[[0.1, 0.2], [3.0, 1.0]]])
output, scores = test_layer(query=target, value=source)
scores = tf.math.reduce_sum(scores, axis=1)  # shape = (1, 2, 2)
```

*   Attention beyond sequences, taking 2D and 3D targets and sources:

```python
query_shape = [2, 3, 4, 4]  # batch, target, target, embedding.
value_shape = [2, 3, 2, 4]  # batch, source, source, embedding.
mask_shape = [2, 3, 4, 3, 2]
query = 10 * np.random.random_sample(query_shape)
value = 10 * np.random.random_sample(value_shape)
mask_data = np.random.randint(2, size=mask_shape).astype("bool")
output = test_layer(query=query, value=value, attention_mask=mask_data)
```

### Interface

```python
class MultiHeadAttention(tf.keras.layers.Layer):
  """MultiHeadAttention layer.

  This is an implementation of multi-headed attention based on "Attention Is
  All You Need". If `query`, `key`, `value` are the same, then
  this is self-attention. Each timestep in `query` attends to the
  corresponding sequence in `key`, and returns a fixed-width vector.

  This layer first projects `query`, `key` and `value`. These are
  (effectively) a list of tensors of length `num_attention_heads`, where the
  corresponding shapes are [batch_size, <query dimensions>, key_size],
  [batch_size, <key/value dimensions>, key_size],
  [batch_size, <key/value dimensions>, value_size].

  Then, the query and key tensors are dot-producted and scaled. These are
  softmaxed to obtain attention probabilities. The value tensors are then
  interpolated by these probabilities, then concatenated back to a single
  tensor.

  Finally, the result tensor, whose last dimension is value_size, can take a
  linear projection and be returned.

  Examples:

  Performs 1D cross-attention over two sequence inputs with an attention mask.
  Returns the additional attention weights over heads.

  >>> layer = MultiHeadAttention(num_heads=2, key_size=2,
  ...                            return_attention_scores=True)
  >>> target = tf.keras.Input(shape=[8, 16])
  >>> source = tf.keras.Input(shape=[4, 16])
  >>> mask_tensor = tf.keras.Input(shape=[8, 4])
  >>> output_tensor, weights = layer(query=target, value=source,
  ...                                attention_mask=mask_tensor)
  >>> print(output_tensor.shape, weights.shape)
  (None, 8, 16) (None, 2, 8, 4)

  Performs 2D self-attention over a 5D input tensor on axes 2 and 3.

  >>> layer = MultiHeadAttention(num_heads=2, key_size=2, attention_axes=(2, 3))
  >>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
  >>> output_tensor = layer(query=input_tensor, value=input_tensor)
  >>> print(output_tensor.shape)
  (None, 5, 3, 4, 16)

  Arguments:
    num_heads: Number of attention heads.
    key_size: Size of each attention head for query and key.
    value_size: Size of each attention head for value.
    dropout: Dropout probability for a Dropout layer on attention_scores.
    use_bias: Boolean, whether the dense layers use bias vectors/matrices.
    output_shape: The expected shape of an output tensor, besides the batch and
      sequence dims. If not specified, projects back to the key feature dim.
    attention_axes: Axes over which the attention is applied. `None` means
      attention over all axes except batch, heads, and features.
    return_attention_scores: bool, if `True`, returns the multi-head attention
      scores as an additional output argument.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
  """

  def call(self, query, value, key=None, attention_mask=None):
    """Implements the forward pass.

    Size glossary:
      * Number of heads (H): the number of attention heads.
      * Value size (V): the size of each value embedding per head.
      * Key size (K): the size of each key embedding per head. Equally, the
        size of each query embedding per head. Typically K <= V.
      * Batch dimensions (B).
      * Query (target) attention axes shape (T).
      * Value (source) attention axes shape (S); the rank must match the
        target.

    Args:
      query: Query `Tensor` of shape `[B, T, dim]`.
      value: Value `Tensor` of shape `[B, S, dim]`.
      key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will
        use `value` for both `key` and `value`, which is the most common case.
      attention_mask: A boolean mask of shape `[B, T, S]` that prevents
        attention to certain positions.

    Returns:
      attention_output: The result of the computation, of shape [B, T, E],
        where `T` is for target sequence shapes and `E` is the query input
        last dimension if `output_shape` is `None`. Otherwise, the multi-head
        outputs are projected to the shape specified by `output_shape`.
      attention_scores: [Optional] multi-head attention coefficients over
        attention axes.
    """
```
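
To make the computation described in the docstring concrete, here is a minimal
sketch of single-batch-dimension, 1D attention written directly with
`tf.einsum`. The variable names, shapes, and equations below are illustrative
assumptions, not the layer's internal notation.

```python
import tensorflow as tf

batch, target_len, source_len = 2, 8, 4
num_heads, key_size, value_size = 2, 3, 5

# Per-head projected tensors, as described in the docstring above.
query = tf.random.normal([batch, target_len, num_heads, key_size])
key = tf.random.normal([batch, source_len, num_heads, key_size])
value = tf.random.normal([batch, source_len, num_heads, value_size])

# Dot-product and scale, then softmax over the source axis.
scores = tf.einsum("btnk,bsnk->bnts", query, key)
scores = scores / tf.math.sqrt(tf.cast(key_size, scores.dtype))
probs = tf.nn.softmax(scores, axis=-1)  # [batch, heads, target, source]

# Interpolate the values by the probabilities; a final dense projection would
# map the concatenated heads to the output dimension.
context = tf.einsum("bnts,bsnv->btnv", probs, value)  # [batch, target, heads, value_size]
```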

### Auxiliary Layers and Changes

*   EinsumDense layer

We use `tf.einsum` to implement a dense layer that can perform einsum
calculations of arbitrary dimensionality. This example shows how to instantiate
a layer that applies the same dense operation to every element in a sequence.
Here, `output_shape` has two values (since there are two non-batch dimensions
in the output); the first dimension in `output_shape` is `None`, because the
sequence dimension `b` has an unknown shape.

```python
layer = EinsumDense("abc,cd->abd", output_shape=(None, 64), bias_axes="d")
input_tensor = tf.keras.Input(shape=[32, 128])
output_tensor = layer(input_tensor)  # output shape is (None, 32, 64)
```
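
The same layer also covers the per-head projections that `MultiHeadAttention`
needs. The equation and shapes below are an illustrative assumption of how such
a projection could be written, not the exact notation the layer generates.

```python
import tensorflow as tf

num_heads, key_size = 2, 3
# Hypothetical per-head query projection:
# [batch, seq, dim] -> [batch, seq, num_heads, key_size].
query_projection = tf.keras.layers.experimental.EinsumDense(
    "abc,cde->abde",
    output_shape=(None, num_heads, key_size),
    bias_axes="de")
query = tf.keras.Input(shape=[8, 16])
projected = query_projection(query)  # output shape is (None, 8, 2, 3)
```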

*   Masked Softmax

Inside the attention computation, we need to mask logits before the softmax;
this has become a common treatment in many applications. We propose to add an
optional `mask` argument to `tf.nn.softmax`. The downstream keras `Softmax`
layer will also take an optional `mask` tensor. This `mask` tensor should have
the same rank as the input tensor and mask elements along the axis on which the
softmax is performed.

Inside the `MultiHeadAttention` keras layer, we will use the keras `Softmax`
layer with the mask and adjust the attention mask shape to match the inputs.
The dimension expansion logic and the multi-axis softmax will be handled
locally in the `MultiHeadAttention` layer.
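
The masking itself follows the usual trick of adding a large negative bias to
the masked logits before the softmax. A minimal sketch is below; the helper
name and the `-1e9` constant are assumptions for illustration.

```python
import tensorflow as tf

def masked_softmax(logits, mask, axis=-1):
  # Positions where `mask` is False receive a large negative additive bias,
  # so they end up with (near-)zero probability after the softmax.
  adder = (1.0 - tf.cast(mask, logits.dtype)) * -1e9
  return tf.nn.softmax(logits + adder, axis=axis)

logits = tf.constant([[1.0, 2.0, 3.0]])
mask = tf.constant([[True, True, False]])
print(masked_softmax(logits, mask))  # the last position gets ~0 probability
```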

*   Keras Dense Attention

We propose two changes to
[tf.keras.layers.Attention](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Attention).
(1) The layer's call method takes an optional argument, `mask`, which requires
two tensors, `q_mask` and `v_mask`. Following keras framework requirements,
their shapes are (batch_size, target_length) and (batch_size, source_length).
This limits the flexibility of masking, whereas the `MultiHeadAttention` layer
generalizes the attention mask to (batch dims, target dims, source dims). To be
consistent, we would like to introduce an optional argument `attention_mask`
for `tf.keras.layers.Attention`. In the reduced case of
`tf.keras.layers.Attention`, its shape is (batch_size, target_length,
source_length). Whenever `attention_mask` is specified, the `mask` argument may
be omitted.
(2) The layer does not return attention scores. We will add a bool argument,
`return_attention_scores`, to the `__init__` and return the attention score
tensor if it is true. A usage sketch of the proposed arguments follows.
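
The sketch below shows how the proposed arguments might be used; neither
`attention_mask` nor `return_attention_scores` exists in the released
`tf.keras.layers.Attention` API at the time of this RFC, so this is purely
illustrative.

```python
import tensorflow as tf

# Proposed __init__ argument.
attention = tf.keras.layers.Attention(return_attention_scores=True)

query = tf.keras.Input(shape=[8, 16])              # (batch, target_length, dim)
value = tf.keras.Input(shape=[4, 16])              # (batch, source_length, dim)
mask = tf.keras.Input(shape=[8, 4], dtype="bool")  # (batch, target_length, source_length)

# Proposed call argument: `attention_mask` replaces the [q_mask, v_mask] pair.
outputs, scores = attention([query, value], attention_mask=mask)
```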

*   TFA `MultiHeadAttention` Deprecation and Re-mapping

[MultiHeadAttention](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/layers/multihead_attention.py)
has been released in TF Addons. The proposed `MultiHeadAttention` has similar
`__init__` arguments and `call` interface; the minor differences are argument
names and the attention `mask` shape. We expect the new `MultiHeadAttention`
keras layer to cover its functionalities. Once the implementations are merged
as experimental layers, we will work with the TF Addons team to design the
deprecation and re-mapping procedure.

### Alternatives Considered

We examined the multi-head attention layers implemented in various libraries.
There are a few features that we do not include inside this keras layer, and we
feel it is better to subclass the `MultiHeadAttention` layer to fulfill those
needs.

*   Attention caching for decoding, implemented in
    [Flax](https://github.com/google/flax/blob/master/flax/nn/attention.py#L301).
    The caching is a special treatment for inference, and we noticed that
    different treatments are required for dynamic- or static-shape programs.
    Thus, subclassing as a
    [CachedAttention](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/attention.py)
    layer is the solution inside the model garden.
*   A [MultiHeadAttention](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/layers/multihead_attention.py)
    keras layer is also implemented in TF-Addons. The design in this doc covers
    the features in the TF-Addons implementation but generalizes to more use
    cases.

### Performance Implications

*   We will add microbenchmarks following the common practices of keras layers.
*   We have end-to-end integration/regression tests for models using this
    layer, e.g. BERT.

### Dependencies

No dependencies.

### Engineering Impact

*   The keras layer can be tested inside the package.
*   TensorFlow team will maintain the code.

### Platforms and Environments

*   Works for all platforms and environments.

### Best Practices

*   No change to TensorFlow best practices.

### Tutorials and Examples

*   Code examples can be found inside the TensorFlow Model Garden, for example
    an encoder
    [Transformer](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer.py).
*   A 2D attention example is in the
    [unit test](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/attention_test.py#L135).

### Compatibility

*   This is a new layer without compatibility concerns.
*   The proposal works with TFLite, distribution strategy, tf.function, and
    GPU/TPU, and is serializable to SavedModel. These are tested inside
    TensorFlow Model Garden applications.

### User Impact

*   We will first introduce the layers as
    `tf.keras.layers.experimental.MultiHeadAttention` and
    `tf.keras.layers.experimental.EinsumDense`. When the APIs are stable and
    the functionalities are fully verified, the next step is to graduate them
    to core keras layers by removing the `experimental` scope.

## Detailed Design

The layer has been implemented as the
[MultiHeadAttention](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/attention.py#L116)
inside TensorFlow Model Garden.

First, as we rely on `tf.einsum` to define the projections and the attention
computation, we need to figure out the einsum notation of each computation.
Furthermore, to make the layer generalize to high-dimension cases, i.e. where
there is more than one batch dimension and the attention softmax can be
performed over multiple axes, we need to track the batch axes and attention
axes inside the einsum notations. We use a vector of chars and two local
methods to generate the einsum notations for the projections and the attention,
as sketched below.

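As an illustration of this notation-building step, here is a simplified sketch.
The helper name and the exact string-building scheme are assumptions for
exposition, not the Model Garden implementation.

```python
import string

def projection_equation(free_dims, bound_dims, output_dims):
  # Builds an einsum equation such as "abc,cde->abde" for a dense projection:
  # `free_dims` batch/attention axes are kept, `bound_dims` input feature axes
  # are contracted, and `output_dims` new axes (e.g. heads, head size) are
  # appended.
  letters = iter(string.ascii_lowercase)
  kept = "".join(next(letters) for _ in range(free_dims))
  contracted = "".join(next(letters) for _ in range(bound_dims))
  added = "".join(next(letters) for _ in range(output_dims))
  return f"{kept}{contracted},{contracted}{added}->{kept}{added}"

# One batch axis plus one target axis, a single feature axis, projected to
# (num_heads, key_size): the classic "abc,cde->abde" query projection.
print(projection_equation(free_dims=2, bound_dims=1, output_dims=2))
```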

Second, the layer by default implements the most common dot-product attention.
There are various ways to implement the attention computation, so we modularize
it as two methods, `build_attention` and `compute_attention`. Thus, users will
be able to just override them to get a new keras layer with a novel attention
method. For example, we implemented
[TalkingHeadAttention](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/talking_heads_attention.py)
introduced by the ["Talking-Heads Attention"](https://arxiv.org/abs/2003.02436)
paper. Using the keras Attention layer as another example, since it supports
the basic single-head 1-D attention case, we can use it inside
`build_attention` and `compute_attention`. A subclassing sketch follows.

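A minimal subclassing sketch is shown below; the method signature and the
cosine-similarity variant are assumptions for illustration, based only on the
description above, and assume the proposed `MultiHeadAttention` layer is in
scope.

```python
import tensorflow as tf

class CosineMultiHeadAttention(MultiHeadAttention):
  """Hypothetical variant that swaps in a different attention score function."""

  def compute_attention(self, query, key, value, attention_mask=None):
    # Assumed hook: L2-normalize the projected query/key so that the base
    # class's scaled dot-product becomes a cosine-similarity score, then reuse
    # the inherited masking, softmax, and value-interpolation machinery.
    query = tf.math.l2_normalize(query, axis=-1)
    key = tf.math.l2_normalize(key, axis=-1)
    return super(CosineMultiHeadAttention, self).compute_attention(
        query, key, value, attention_mask=attention_mask)
```
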
## Questions and Discussion Topics

-   cuDNN has a
    [multi-head attention](https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnMultiHeadAttnForward)
    function. How do we incorporate it? A: we modularize the attention
    computation components in order to support new low-level functions without
    changing this layer's interface. The cuDNN function supports the classic
    dot-product attention with classic input dimensions. We will be able to use
    it once TensorFlow adds an op for it.