Skip to content

Sync token-classification pipeline with Hub spec #34064

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 28 additions & 36 deletions src/transformers/pipelines/token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
from ..models.bert.tokenization_bert import BasicTokenizer
from ..utils import (
ExplicitEnum,
add_end_docstrings,
is_tf_available,
is_torch_available,
)
from .base import ArgumentHandler, ChunkPipeline, Dataset, build_pipeline_init_args
from .base import ArgumentHandler, ChunkPipeline, Dataset


if is_tf_available():
Expand Down Expand Up @@ -60,40 +59,6 @@ class AggregationStrategy(ExplicitEnum):
MAX = "max"


@add_end_docstrings(
build_pipeline_init_args(has_tokenizer=True),
r"""
ignore_labels (`List[str]`, defaults to `["O"]`):
A list of labels to ignore.
grouped_entities (`bool`, *optional*, defaults to `False`):
DEPRECATED, use `aggregation_strategy` instead. Whether or not to group the tokens corresponding to the
same entity together in the predictions or not.
stride (`int`, *optional*):
If stride is provided, the pipeline is applied on all the text. The text is split into chunks of size
model_max_length. Works only with fast tokenizers and `aggregation_strategy` different from `NONE`. The
value of this argument defines the number of overlapping tokens between chunks. In other words, the model
will shift forward by `tokenizer.model_max_length - stride` tokens each step.
aggregation_strategy (`str`, *optional*, defaults to `"none"`):
The strategy to fuse (or not) tokens based on the model prediction.

- "none" : Will simply not do any aggregation and simply return raw results from the model
- "simple" : Will attempt to group entities following the default schema. (A, B-TAG), (B, I-TAG), (C,
I-TAG), (D, B-TAG2) (E, B-TAG2) will end up being [{"word": ABC, "entity": "TAG"}, {"word": "D",
"entity": "TAG2"}, {"word": "E", "entity": "TAG2"}] Notice that two consecutive B tags will end up as
different entities. On word based languages, we might end up splitting words undesirably : Imagine
Microsoft being tagged as [{"word": "Micro", "entity": "ENTERPRISE"}, {"word": "soft", "entity":
"NAME"}]. Look for FIRST, MAX, AVERAGE for ways to mitigate that and disambiguate words (on languages
that support that meaning, which is basically tokens separated by a space). These mitigations will
only work on real words, "New york" might still be tagged with two different entities.
- "first" : (works only on word based models) Will use the `SIMPLE` strategy except that words, cannot
end up with different tags. Words will simply use the tag of the first token of the word when there
is ambiguity.
- "average" : (works only on word based models) Will use the `SIMPLE` strategy except that words,
cannot end up with different tags. scores will be averaged first across tokens, and then the maximum
label is applied.
- "max" : (works only on word based models) Will use the `SIMPLE` strategy except that words, cannot
end up with different tags. Word entity will simply be the token with the maximum score.""",
)
class TokenClassificationPipeline(ChunkPipeline):
"""
Named Entity Recognition pipeline using any `ModelForTokenClassification`. See the [named entity recognition
Expand Down Expand Up @@ -224,6 +189,33 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
Args:
inputs (`str` or `List[str]`):
One or several texts (or one list of texts) for token classification.
ignore_labels (`List[str]`, defaults to `["O"]`):
A list of labels to ignore.
stride (`int`, *optional*):
If stride is provided, the pipeline is applied on all the text. The text is split into chunks of size
model_max_length. Works only with fast tokenizers and `aggregation_strategy` different from `NONE`. The
value of this argument defines the number of overlapping tokens between chunks. In other words, the model
will shift forward by `tokenizer.model_max_length - stride` tokens each step.
aggregation_strategy (`str`, *optional*, defaults to `"none"`):
The strategy to fuse (or not) tokens based on the model prediction.

- "none" : Will not do any aggregation and will simply return the raw results from the model.
- "simple" : Will attempt to group entities following the default schema. (A, B-TAG), (B, I-TAG), (C,
I-TAG), (D, B-TAG2), (E, B-TAG2) will end up being [{"word": "ABC", "entity": "TAG"}, {"word": "D",
"entity": "TAG2"}, {"word": "E", "entity": "TAG2"}]. Notice that two consecutive B tags will end up as
different entities. On word-based languages, we might end up splitting words undesirably: imagine
Microsoft being tagged as [{"word": "Micro", "entity": "ENTERPRISE"}, {"word": "soft", "entity":
"NAME"}]. Look for FIRST, MAX, AVERAGE for ways to mitigate that and disambiguate words (on languages
that support that meaning, which is basically tokens separated by a space). These mitigations will
only work on real words; "New york" might still be tagged with two different entities.
- "first" : (works only on word-based models) Will use the `SIMPLE` strategy except that words cannot
end up with different tags. Words will simply use the tag of the first token of the word when there
is ambiguity.
- "average" : (works only on word-based models) Will use the `SIMPLE` strategy except that words
cannot end up with different tags. Scores will be averaged first across tokens, and then the label
with the maximum averaged score is applied.
- "max" : (works only on word-based models) Will use the `SIMPLE` strategy except that words cannot
end up with different tags. The word's entity will simply be that of the token with the maximum score.

Return:
A list or a list of lists of `dict`: Each result comes as a list of dictionaries (one for each token in the
Expand Down
8 changes: 8 additions & 0 deletions tests/pipelines/test_pipelines_token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import unittest

import numpy as np
from huggingface_hub import TokenClassificationOutputElement

from transformers import (
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
Expand All @@ -26,6 +27,7 @@
)
from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler
from transformers.testing_utils import (
compare_pipeline_output_to_hub_spec,
is_pipeline_test,
is_torch_available,
nested_simplify,
Expand Down Expand Up @@ -103,6 +105,9 @@ def run_pipeline_test(self, token_classifier, _):
for i in range(n)
],
)
for output_element in nested_simplify(outputs):
compare_pipeline_output_to_hub_spec(output_element, TokenClassificationOutputElement)

outputs = token_classifier(["list of strings", "A simple string that is quite a bit longer"])
self.assertIsInstance(outputs, list)
self.assertEqual(len(outputs), 2)
Expand Down Expand Up @@ -137,6 +142,9 @@ def run_pipeline_test(self, token_classifier, _):
],
)

for output_element in nested_simplify(outputs):
compare_pipeline_output_to_hub_spec(output_element, TokenClassificationOutputElement)

self.run_aggregation_strategy(model, tokenizer)

def run_aggregation_strategy(self, model, tokenizer):
Expand Down
3 changes: 3 additions & 0 deletions tests/test_pipeline_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
ImageToTextInput,
ObjectDetectionInput,
QuestionAnsweringInput,
TokenClassificationInput,
ZeroShotImageClassificationInput,
)

Expand All @@ -47,6 +48,7 @@
ImageToTextPipeline,
ObjectDetectionPipeline,
QuestionAnsweringPipeline,
TokenClassificationPipeline,
ZeroShotImageClassificationPipeline,
)
from transformers.testing_utils import (
Expand Down Expand Up @@ -132,6 +134,7 @@
"image-to-text": (ImageToTextPipeline, ImageToTextInput),
"object-detection": (ObjectDetectionPipeline, ObjectDetectionInput),
"question-answering": (QuestionAnsweringPipeline, QuestionAnsweringInput),
"token-classification": (TokenClassificationPipeline, TokenClassificationInput),
"zero-shot-image-classification": (ZeroShotImageClassificationPipeline, ZeroShotImageClassificationInput),
}

Expand Down
Loading