# RFC: Multihead Attention and EinsumDense on Keras

| Status        | Accepted                                                 |
| :------------ | :------------------------------------------------------- |
| **RFC #**     | [260](https://github.com/tensorflow/community/pull/260) |
| **Author(s)** | Hongkun Yu ([email protected]), Mark Omernick ([email protected]) |
| **Sponsor**   | Francois Chollet ([email protected]) |
| **Updated**   | 2020-06-16 |

## Objective

Introduce the MultiHeadAttention layer and the EinsumDense layer to tf.keras.

## Motivation

MultiHeadAttention is widely used and has become a standard building block in
deep learning libraries. We propose to contribute a flexible, well-defined
implementation to Keras that absorbs common best practices from reference
libraries.

## User Benefit

We can standardize the implementation of Transformer layers and adopt best
practices. We offer a rich set of functionalities for different use cases,
e.g. different projection spaces, outputting multi-head attention scores for
analysis, etc. We also modularize the computations to make the
MultiHeadAttention layer extensible to variants.

## Design Proposal

### Key Features

* Returns multi-head attention scores, which is commonly useful for
    attention visualization and analysis.
* Supports query (Q), key (K), and value (V) tensors as individual inputs and
    supports projecting Q, K, and V to different dimensions.
* Projects the final output to user-specified dimensions.
* Uses tf.einsum to express high-dimensional computations and adopts the
    [tf.keras.layers.experimental.EinsumDense](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/EinsumDense)
    layer.
* Supports high-dimensional attention when the target and source are 2D, 3D,
    etc.

### Code Examples

* How to write a TransformerBlock for an encoder.

```python
class TransformerBlock(tf.keras.layers.Layer):
  def __init__(self, embed_dim, num_heads, ff_dim):
    super(TransformerBlock, self).__init__()
    self.att = MultiHeadAttention(num_heads=num_heads, key_size=embed_dim)
    self.ffn = tf.keras.Sequential(
        [tf.keras.layers.Dense(ff_dim, activation="relu"),
         tf.keras.layers.Dense(embed_dim),]
    )
    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

  def call(self, inputs, attention_mask=None):
    attn_output = self.att(query=inputs, value=inputs,
                           attention_mask=attention_mask)
    out1 = self.layernorm1(inputs + attn_output)
    ffn_output = self.ffn(out1)
    return self.layernorm2(out1 + ffn_output)
```

* Use an attention mask to avoid performing attention on padding token indices.

```python
test_layer = TransformerBlock(
    embed_dim=2,
    num_heads=2,
    ff_dim=4)
query = np.array([[[0.1, 0.2], [0.0, 0.0]]])
mask = np.array([[[1, 0], [1, 0]]], dtype='bool')
output = test_layer(query, attention_mask=mask)
```

* Inside a Transformer decoder, we often want to output the cross-attention
    scores to analyze how the target sequence attends to the source sequence.
    We can then visualize the alignment according to the attention scores, as
    sketched below.

```python
test_layer = MultiHeadAttention(
    num_heads=2, key_size=2, return_attention_scores=True)
target = np.array([[[0.1, 0.2], [0.0, 0.0]]])
source = np.array([[[0.1, 0.2], [3.0, 1.0]]])
output, scores = test_layer(query=target, value=source)
scores = tf.math.reduce_sum(scores, axis=1)  # sum over heads, shape = (1, 2, 2)
```
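
The plotting below is a minimal sketch of such a visualization and is not part
of the proposed API; it assumes matplotlib is available and that `scores` is the
head-summed tensor from the previous snippet (rows are target positions,
columns are source positions).

```python
import matplotlib.pyplot as plt

# scores[0] has shape (target_length, source_length) after summing over heads.
plt.imshow(scores[0], cmap="viridis")
plt.xlabel("source position")
plt.ylabel("target position")
plt.colorbar()
plt.show()
```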

* Attention beyond sequences, e.g. with 2D or 3D targets and sources.

```python
test_layer = MultiHeadAttention(
    num_heads=2, key_size=2, attention_axes=(1, 2))  # 2D attention over axes 1 and 2.
query_shape = [2, 3, 4, 4]  # batch, target, target, embedding.
value_shape = [2, 3, 2, 4]  # batch, source, source, embedding.
mask_shape = [2, 3, 4, 3, 2]
query = 10 * np.random.random_sample(query_shape)
value = 10 * np.random.random_sample(value_shape)
mask_data = np.random.randint(2, size=mask_shape).astype("bool")
output = test_layer(query=query, value=value, attention_mask=mask_data)
```

### Interface

```python
class MultiHeadAttention(tf.keras.layers.Layer):
  """MultiHeadAttention layer.

  This is an implementation of multi-headed attention based on "Attention
  Is All You Need". If `query`, `key`, `value` are the same, then
  this is self-attention. Each timestep in `query` attends to the
  corresponding sequence in `key`, and returns a fixed-width vector.

  This layer first projects `query`, `key` and `value`. These are
  (effectively) a list of tensors of length `num_heads`, where the
  corresponding shapes are [batch_size, <query dimensions>, key_size],
  [batch_size, <key/value dimensions>, key_size],
  [batch_size, <key/value dimensions>, value_size].

  Then, the query and key tensors are dot-producted and scaled. These are
  softmaxed to obtain attention probabilities. The value tensors are then
  interpolated by these probabilities and concatenated back to a single
  tensor.

  Finally, the result tensor, with value_size as its last dimension, can take
  a linear projection and be returned.

  Examples:

  Performs 1D cross-attention over two sequence inputs with an attention mask.
  Returns the additional attention weights over heads.

  >>> layer = MultiHeadAttention(num_heads=2, key_size=2,
  ...                            return_attention_scores=True)
  >>> target = tf.keras.Input(shape=[8, 16])
  >>> source = tf.keras.Input(shape=[4, 16])
  >>> mask_tensor = tf.keras.Input(shape=[8, 4])
  >>> output_tensor, weights = layer(query=target, value=source,
  ...                                attention_mask=mask_tensor)
  >>> print(output_tensor.shape, weights.shape)
  (None, 8, 16) (None, 2, 8, 4)

  Performs 2D self-attention over a 5D input tensor on axes 2 and 3.

  >>> layer = MultiHeadAttention(num_heads=2, key_size=2, attention_axes=(2, 3))
  >>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
  >>> output_tensor = layer(query=input_tensor, value=input_tensor)
  >>> print(output_tensor.shape)
  (None, 5, 3, 4, 16)

  Arguments:
    num_heads: Number of attention heads.
    key_size: Size of each attention head for query and key.
    value_size: Size of each attention head for value.
    dropout: Dropout probability for a Dropout layer on attention_scores.
    use_bias: Boolean, whether the dense layers use bias vectors/matrices.
    output_shape: The expected shape of an output tensor, besides the batch and
      sequence dims. If not specified, projects back to the key feature dim.
    attention_axes: axes over which the attention is applied. `None` means
      attention over all axes except batch, heads, and features.
    return_attention_scores: bool, if `True`, returns the multi-head
      attention scores as an additional output argument.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
  """

  def call(self, query, value, key=None, attention_mask=None):
    """Implements the forward pass.

    Size glossary:
    * Number of heads (H): the number of attention heads.
    * Value size (V): the size of each value embedding per head.
    * Key size (K): the size of each key embedding per head. Equally, the size
      of each query embedding per head. Typically K <= V.
    * Batch dimensions (B).
    * Query (target) attention axes shape (T).
    * Value (source) attention axes shape (S); its rank must match the
      target's.

    Args:
      query: Query `Tensor` of shape `[B, T, dim]`.
      value: Value `Tensor` of shape `[B, S, dim]`.
      key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will
        use `value` for both `key` and `value`, which is the most common case.
      attention_mask: a boolean mask of shape `[B, T, S]` that prevents
        attention to certain positions.

    Returns:
      attention_output: The result of the computation, of shape [B, T, E],
        where `T` is for target sequence shapes and `E` is the query input last
        dimension if `output_shape` is `None`. Otherwise, the multi-head outputs
        are projected to the shape specified by `output_shape`.
      attention_scores: [Optional] multi-head attention coefficients over the
        attention axes.
    """
```

### Auxiliary Layers and Changes

* EinsumDense layer

We use `tf.einsum` to implement a dense layer that can perform einsum
calculations of arbitrary dimensionality. The following example shows how to
instantiate a layer that applies the same dense operation to every element in
a sequence. Here, `output_shape` has two values (since there are two non-batch
dimensions in the output); the first dimension in `output_shape` is `None`
because the sequence dimension `b` has an unknown shape.

```python
layer = EinsumDense("abc,cd->abd", output_shape=(None, 64), bias_axes="d")
input_tensor = tf.keras.Input(shape=[32, 128])
output_tensor = layer(input_tensor)  # output shape is (None, 32, 64)
```
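
For illustration only, the same layer can also express the kind of per-head
projection that `MultiHeadAttention` needs; the equation and the `num_heads`
and `key_size` values below are a sketch of the 1D case, not a prescribed
configuration.

```python
num_heads, key_size = 2, 4
# Project [batch, seq, dim] -> [batch, seq, num_heads, key_size] in one einsum.
query_dense = EinsumDense(
    "abc,cde->abde", output_shape=(None, num_heads, key_size), bias_axes="de")
query_input = tf.keras.Input(shape=[8, 16])
query_heads = query_dense(query_input)  # output shape is (None, 8, 2, 4)
```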

* Masked Softmax

Inside the attention computation, we need to mask logits before the softmax,
which has become a common treatment in many applications. We propose to add an
optional `mask` argument to `tf.nn.softmax`. The downstream Keras `Softmax`
layer will also take an optional `mask` tensor. This `mask` tensor should have
the same rank as the input tensor and mask elements on the axes over which the
softmax is performed.
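
To illustrate the intended semantics (this is not the proposed implementation),
a masked softmax can be emulated today by pushing masked logits toward negative
infinity before the softmax; the `-1e9` additive constant is a common
convention, not part of this proposal.

```python
import tensorflow as tf

def masked_softmax(logits, mask, axis=-1):
  # mask is a boolean tensor broadcastable to logits; True keeps a position.
  adder = (1.0 - tf.cast(mask, logits.dtype)) * -1e9
  return tf.nn.softmax(logits + adder, axis=axis)
```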

Inside the `MultiHeadAttention` Keras layer, we will use the Keras `Softmax`
layer with a mask and adjust the attention mask shape to match the inputs. The
dimension expansion logic and the multi-axes softmax will be handled locally in
the `MultiHeadAttention` layer.

* Keras Dense Attention

We propose two changes to
[tf.keras.layers.Attention](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Attention),
sketched in the example after this list:

1.  The layer `call` method takes an optional argument, `mask`, which requires
    two tensors, `q_mask` and `v_mask`. They follow the Keras framework
    requirements, with shapes (batch_size, target_length) and (batch_size,
    source_length). This limits the flexibility of masking, whereas the
    `MultiHeadAttention` layer generalizes the attention mask to (batch dims,
    target dims, source dims). To be consistent, we would like to introduce an
    optional argument `attention_mask` for `tf.keras.layers.Attention`. In the
    reduced case of `tf.keras.layers.Attention`, its shape is (batch_size,
    target_length, source_length). Whenever `attention_mask` is specified, the
    `mask` argument may be omitted.
2.  The layer does not return attention scores. We will add a bool argument,
    `return_attention_scores`, to `__init__` and return the attention score
    tensor if it is true.
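
A sketch of how the proposed arguments might be used once both changes land;
the argument names follow this proposal and could still change during review.

```python
attention = tf.keras.layers.Attention(return_attention_scores=True)
query = tf.keras.Input(shape=[8, 16])
value = tf.keras.Input(shape=[4, 16])
# A (batch_size, target_length, source_length) boolean mask replaces [q_mask, v_mask].
attention_mask = tf.keras.Input(shape=[8, 4], dtype="bool")
output, scores = attention([query, value], attention_mask=attention_mask)
```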

* TFA `MultiHeadAttention` Deprecation and Re-mapping

[MultiHeadAttention](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/layers/multihead_attention.py)
has been released in TensorFlow Addons. The proposed `MultiHeadAttention` has
similar `__init__` arguments and `call` interface, where the minor differences
are the argument names and the attention `mask` shape. We expect the new
`MultiHeadAttention` Keras layer to cover those functionalities. Once the
implementation is merged as an experimental layer, we will work with the TF
Addons team to design the deprecation and re-mapping procedure.

### Alternatives Considered

We examined the multi-head attention layers implemented in various libraries.
There are a few features that we do not include in this Keras layer, and we
feel it is better to subclass the `MultiHeadAttention` layer to fulfill those
needs.

* Attention caching for decoding. Implemented in
    [Flax](https://github.com/google/flax/blob/master/flax/nn/attention.py#L301).
    Caching is a special treatment for inference, and we noticed that
    different treatments are required for dynamic- and static-shape programs.
    Thus, subclassing as a
    [CachedAttention](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/attention.py)
    layer is the solution inside the Model Garden.
* A [MultiHeadAttention](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/layers/multihead_attention.py)
    Keras layer is also implemented in TF Addons. The design in this doc covers
    the features of the TF Addons implementation but generalizes to more use
    cases.

### Performance Implications

* We will add microbenchmarks following the common practices of Keras layers.
* We have end-to-end integration/regression tests for models using this layer,
    e.g. BERT.

### Dependencies

No dependencies.

### Engineering Impact

* The Keras layer can be tested inside the package.
* The TensorFlow team will maintain the code.

### Platforms and Environments

* Works on all platforms and environments.

### Best Practices

* No changes to TensorFlow best practices.

### Tutorials and Examples

* Code examples can be found inside the TensorFlow Model Garden. For example,
    an encoder
    [Transformer](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer.py).

* A 2D attention example in the
    [unit test](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/attention_test.py#L135).

### Compatibility

* This is a new layer without compatibility concerns.
* The proposal works with TFLite, distribution strategy, tf.function, and
    GPU/TPU, and is serializable to SavedModel. These are tested inside
    TensorFlow Model Garden applications.

### User Impact

* We will first introduce the layers as
    `tf.keras.layers.experimental.MultiHeadAttention` and
    `tf.keras.layers.experimental.EinsumDense`. When the APIs are stable and
    the functionalities are fully verified, the next step is to graduate them
    to core Keras layers by removing the `experimental` scope.

## Detailed Design

The layer has been implemented as
[MultiHeadAttention](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/attention.py#L116)
inside the TensorFlow Model Garden.

First, since we rely on `tf.einsum` to define the projections and the attention
computation, we need to figure out the einsum notation of each computation.
Furthermore, to make the layer generalize to high-dimensional cases, i.e. where
there is more than one batch dimension and the attention softmax can be
performed over multiple axes, we need to track the batch axes and attention
axes inside the einsum notations. We use a vector of chars and two local
methods to generate the einsum notations for the projections and the attention.
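
For intuition, the following sketch writes out the kind of notation generated
for the standard case of one batch dimension and a single attention axis; the
equation strings and shapes are illustrative, and the layer builds them
programmatically from the tracked axes.

```python
import tensorflow as tf

batch, target_len, source_len, num_heads, key_size = 2, 8, 4, 2, 3
# Tensors as they look after the per-head projections.
query = tf.random.normal([batch, target_len, num_heads, key_size])
key = tf.random.normal([batch, source_len, num_heads, key_size])
value = tf.random.normal([batch, source_len, num_heads, key_size])

# Scaled dot-product attention expressed with einsum, keeping heads separate.
scores = tf.einsum("aecd,abcd->acbe", key, query) / tf.sqrt(float(key_size))
probs = tf.nn.softmax(scores)                          # [batch, heads, T, S]
outputs = tf.einsum("acbe,aecd->abcd", probs, value)   # [batch, T, heads, key_size]
```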

Second, the layer by default implements the most common dot-product attention.
There are various ways to implement the attention computation, so we modularize
it as two methods, `build_attention` and `compute_attention`. Thus, users will
be able to just override them to get a new Keras layer with a novel attention
method, as sketched below. For example, we implemented
[TalkingHeadAttention](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/talking_heads_attention.py)
introduced by the ["Talking-Heads Attention"](https://arxiv.org/abs/2003.02436)
paper. Using the Keras Attention layer as another example, since it supports
the basic single-head 1-D attention case, we can use it inside
`build_attention` and `compute_attention`.
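
A minimal sketch of the intended extension point. The hook names follow the
text above, and the class name and method signature shown here are
hypothetical; the concrete signatures may differ in the final implementation.

```python
class CosineAttention(MultiHeadAttention):
  """Hypothetical variant that scores with cosine similarity instead of dot products."""

  def compute_attention(self, query, key, value, attention_mask=None):
    # Normalize the per-head query and key embeddings, then reuse the default
    # einsum-based score/softmax/combine path on the normalized tensors.
    query = tf.math.l2_normalize(query, axis=-1)
    key = tf.math.l2_normalize(key, axis=-1)
    return super().compute_attention(query, key, value, attention_mask)
```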

## Questions and Discussion Topics

- cuDNN has a
    [multi-head attention](https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnMultiHeadAttnForward)
    function. How do we incorporate it? A: We modularize the attention
    computation components in order to support new low-level functions without
    changing this layer's interface. The cuDNN function supports the classic
    dot-product attention with classic input dimensions. We will be able to use
    it once TensorFlow adds an op for it.