from typing import Dict, List, Optional, Any, Union
from overrides import overrides
import torch
from torch.nn.modules import Linear, Dropout
import torch.nn.functional as F
from pytorch_pretrained_bert.modeling import BertModel
from allennlp.data import Vocabulary
from allennlp.models.model import Model
from allennlp.models.srl_util import convert_bio_tags_to_conll_format
from allennlp.nn import InitializerApplicator, RegularizerApplicator
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.nn.util import get_lengths_from_binary_sequence_mask, viterbi_decode
from allennlp.training.metrics.srl_eval_scorer import SrlEvalScorer, DEFAULT_SRL_EVAL_PATH


@Model.register("srl_bert")
class SrlBert(Model):
    """
    A BERT-based semantic role labelling model which tags BERT wordpieces with BIO labels.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    bert_model : ``Union[str, BertModel]``, required
        A string describing the BERT model to load or an already constructed BertModel.
    embedding_dropout : ``float``, optional (default = 0.0)
        The dropout probability applied to the BERT embeddings before the tag projection layer.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    label_smoothing : ``float``, optional (default = ``None``)
        If non-zero, the amount of label smoothing to apply when computing the cross entropy loss.
    ignore_span_metric : ``bool``, optional (default = ``False``)
        Whether to skip the span-based evaluation metric, which is irrelevant when predicting
        BIO tags for Open Information Extraction.
    srl_eval_path : ``str``, optional (default=``DEFAULT_SRL_EVAL_PATH``)
        The path to the srl-eval.pl script. By default, will use the srl-eval.pl included with allennlp,
        which is located at allennlp/tools/srl-eval.pl . If ``None``, srl-eval.pl is not used.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 bert_model: Union[str, BertModel],
                 embedding_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 label_smoothing: float = None,
                 ignore_span_metric: bool = False,
                 srl_eval_path: str = DEFAULT_SRL_EVAL_PATH) -> None:
        super(SrlBert, self).__init__(vocab, regularizer)
        if isinstance(bert_model, str):
            self.bert_model = BertModel.from_pretrained(bert_model)
        else:
            self.bert_model = bert_model

        self.num_classes = self.vocab.get_vocab_size("labels")
        if srl_eval_path is not None:
            # For the span based evaluation, we don't want to consider labels
            # for verb, because the verb index is provided to the model.
            self.span_metric = SrlEvalScorer(srl_eval_path, ignore_classes=["V"])
        else:
            self.span_metric = None
        self.tag_projection_layer = Linear(self.bert_model.config.hidden_size, self.num_classes)
        self.embedding_dropout = Dropout(p=embedding_dropout)
        self._label_smoothing = label_smoothing
        self.ignore_span_metric = ignore_span_metric
        initializer(self)

    def forward(self,  # type: ignore
                tokens: Dict[str, torch.Tensor],
                verb_indicator: torch.Tensor,
                metadata: List[Any],
                tags: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. For this model, this must be a ``SingleIdTokenIndexer`` which
            indexes wordpieces from the BERT vocabulary.
        verb_indicator : torch.LongTensor, required
            An integer ``SequenceLabelField`` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        metadata : ``List[Dict[str, Any]]``, required
            Metadata containing the original words in the sentence, the verb to compute the
            frame for, and start offsets for converting wordpieces back to a sequence of words,
            under the 'words', 'verb' and 'offsets' keys, respectively.
        tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels
            of shape ``(batch_size, num_tokens)``.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        mask = get_text_field_mask(tokens)
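        # The verb indicator is passed to BERT as its ``token_type_ids`` (segment ids),
        # so the predicate position is injected through BERT's segment embeddings rather
        # than through a separate indicator embedding.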
        bert_embeddings, _ = self.bert_model(input_ids=tokens["tokens"],
                                             token_type_ids=verb_indicator,
                                             attention_mask=mask,
                                             output_all_encoded_layers=False)
        embedded_text_input = self.embedding_dropout(bert_embeddings)
        batch_size, sequence_length, _ = embedded_text_input.size()
        logits = self.tag_projection_layer(embedded_text_input)

        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view([batch_size,
                                                                          sequence_length,
                                                                          self.num_classes])
        output_dict = {"logits": logits, "class_probabilities": class_probabilities}
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.decode.
        output_dict["mask"] = mask
        # We add in the offsets here so we can compute the un-wordpieced tags.
        words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"]) for x in metadata])
        output_dict["words"] = list(words)
        output_dict["verb"] = list(verbs)
        output_dict["wordpiece_offsets"] = list(offsets)

        if tags is not None:
            loss = sequence_cross_entropy_with_logits(logits,
                                                      tags,
                                                      mask,
                                                      label_smoothing=self._label_smoothing)
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [example_metadata["verb_index"] for example_metadata in metadata]
                batch_sentences = [example_metadata["words"] for example_metadata in metadata]
                # Get the BIO tags from decode()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of decode() to a separate function.
                batch_bio_predicted_tags = self.decode(output_dict).pop("tags")
                batch_conll_predicted_tags = [convert_bio_tags_to_conll_format(tags) for
                                              tags in batch_bio_predicted_tags]
                batch_bio_gold_tags = [example_metadata["gold_tags"] for example_metadata in metadata]
                batch_conll_gold_tags = [convert_bio_tags_to_conll_format(tags) for
                                         tags in batch_bio_gold_tags]
                self.span_metric(batch_verb_indices,
                                 batch_sentences,
                                 batch_conll_predicted_tags,
                                 batch_conll_gold_tags)
            output_dict["loss"] = loss
        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does constrained viterbi decoding on class probabilities output in :func:`forward`. The
        constraint simply specifies that the output tags must be a valid BIO sequence. We add a
        ``"tags"`` key to the dictionary with the result.

        NOTE: First, we decode a BIO sequence on top of the wordpieces. This is important; viterbi
        decoding produces low quality output if you decode on top of word representations directly,
        because the model gets confused by the 'missing' positions (which is sensible as it is trained
        to perform tagging on wordpieces, not words).

        Secondly, it's important that the indices we use to recover words from the wordpieces are the
        start_offsets (i.e. offsets which correspond to using the first wordpiece of words which are
        tokenized into multiple wordpieces), as otherwise we might get an ill-formed BIO sequence
        when we select out the word tags from the wordpiece tags. This happens in the case that a word
        is split into multiple wordpieces, and then we take the last tag of the word, which might
        correspond to, e.g., I-V, which would not be allowed as it is not preceded by a B tag.
        """
        all_predictions = output_dict['class_probabilities']
        sequence_lengths = get_lengths_from_binary_sequence_mask(output_dict["mask"]).data.tolist()

        if all_predictions.dim() == 3:
            predictions_list = [all_predictions[i].detach().cpu() for i in range(all_predictions.size(0))]
        else:
            predictions_list = [all_predictions]
        wordpiece_tags = []
        word_tags = []
        transition_matrix = self.get_viterbi_pairwise_potentials()
        start_transitions = self.get_start_transitions()
        # We decode on the wordpieces and then use the stored start offsets
        # to recover the word-level tags.
        for predictions, length, offsets in zip(predictions_list,
                                                sequence_lengths,
                                                output_dict["wordpiece_offsets"]):
            max_likelihood_sequence, _ = viterbi_decode(predictions[:length], transition_matrix,
                                                        allowed_start_transitions=start_transitions)
            tags = [self.vocab.get_token_from_index(x, namespace="labels")
                    for x in max_likelihood_sequence]
            wordpiece_tags.append(tags)
            word_tags.append([tags[i] for i in offsets])
        output_dict['wordpiece_tags'] = wordpiece_tags
        output_dict['tags'] = word_tags
        return output_dict

    def get_metrics(self, reset: bool = False):
        if self.ignore_span_metric:
            # Return an empty dictionary if ignoring the
            # span metric.
            return {}
        else:
            metric_dict = self.span_metric.get_metric(reset=reset)
            # This can be a lot of metrics, as there are 3 per class.
            # We only really care about the overall metrics, so we filter for them here.
            return {x: y for x, y in metric_dict.items() if "overall" in x}

    def get_viterbi_pairwise_potentials(self):
        """
        Generate a matrix of pairwise transition potentials for the BIO labels.
        The only constraint implemented here is that I-XXX labels must be preceded
        by either an identical I-XXX tag or a B-XXX tag. In order to achieve this
        constraint, pairs of labels which do not satisfy this constraint have a
        pairwise potential of -inf.

        Returns
        -------
        transition_matrix : torch.Tensor
            A (num_labels, num_labels) matrix of pairwise potentials.
        """
        all_labels = self.vocab.get_index_to_token_vocabulary("labels")
        num_labels = len(all_labels)
        transition_matrix = torch.zeros([num_labels, num_labels])

        for i, previous_label in all_labels.items():
            for j, label in all_labels.items():
                # I labels can only be preceded by themselves or
                # their corresponding B tag.
                if i != j and label[0] == 'I' and not previous_label == 'B' + label[1:]:
                    transition_matrix[i, j] = float("-inf")
        return transition_matrix

    def get_start_transitions(self):
        """
        In the BIO sequence, we cannot start the sequence with an I-XXX tag.
        This transition sequence is passed to viterbi_decode to specify this constraint.

        Returns
        -------
        start_transitions : torch.Tensor
            The pairwise potentials between a START token and
            the first token of the sequence.
        """
        all_labels = self.vocab.get_index_to_token_vocabulary("labels")
        num_labels = len(all_labels)

        start_transitions = torch.zeros(num_labels)

        for i, label in all_labels.items():
            if label[0] == "I":
                start_transitions[i] = float("-inf")

        return start_transitions
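

# A rough end-to-end sketch of how the model is typically driven (illustrative only;
# the ``batch`` variable and its field names are assumptions matching the ``forward``
# docstring above, not something defined in this file):
#
#     output_dict = model(tokens=batch["tokens"],
#                         verb_indicator=batch["verb_indicator"],
#                         metadata=batch["metadata"])
#     decoded = model.decode(output_dict)
#     decoded["tags"]    # word-level BIO tags, one list per instance
#     decoded["words"]   # the original tokens, aligned with "tags"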