-
Notifications
You must be signed in to change notification settings - Fork 35
Fix Arrow serialization for SpanArray multidoc support #181
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ef5d3b1
311764f
d34810a
7595cdd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,51 +27,50 @@ | |
from text_extensions_for_pandas.array.span import SpanArray | ||
from text_extensions_for_pandas.array.token_span import TokenSpanArray | ||
from text_extensions_for_pandas.array.tensor import TensorArray | ||
from text_extensions_for_pandas.array.string_table import StringTable | ||
|
||
|
||
class ArrowSpanType(pa.PyExtensionType): | ||
""" | ||
PyArrow extension type definition for conversions to/from Span columns | ||
""" | ||
BEGINS_NAME = "span_begins" | ||
ENDS_NAME = "span_ends" | ||
TARGET_TEXT_DICT_NAME = "target_text" | ||
|
||
TARGET_TEXT_KEY = b"target_text" # metadata key/value gets serialized to bytes | ||
BEGINS_NAME = "char_begins" | ||
ENDS_NAME = "char_ends" | ||
|
||
def __init__(self, index_dtype, target_text): | ||
def __init__(self, index_dtype, target_text_dict_dtype): | ||
""" | ||
Create an instance of a Span data type with given index type and | ||
target text that will be stored in Field metadata. | ||
target text dictionary type. The dictionary type will hold target text ids | ||
that map to a dictionary of document target texts. | ||
|
||
:param index_dtype: | ||
:param target_text: | ||
:param index_dtype: type for the begin, end index arrays | ||
:param target_text_dict_dtype: type for the target text dictionary array | ||
""" | ||
assert pa.types.is_integer(index_dtype) | ||
|
||
# Store target text as field metadata | ||
metadata = {self.TARGET_TEXT_KEY: target_text} | ||
|
||
fields = [ | ||
pa.field(self.BEGINS_NAME, index_dtype, metadata=metadata), | ||
pa.field(self.ENDS_NAME, index_dtype) | ||
pa.field(self.BEGINS_NAME, index_dtype), | ||
pa.field(self.ENDS_NAME, index_dtype), | ||
pa.field(self.TARGET_TEXT_DICT_NAME, target_text_dict_dtype) | ||
] | ||
|
||
pa.PyExtensionType.__init__(self, pa.struct(fields)) | ||
|
||
def __reduce__(self): | ||
index_dtype = self.storage_type[self.BEGINS_NAME].type | ||
metadata = self.storage_type[self.BEGINS_NAME].metadata | ||
return ArrowSpanType, (index_dtype, metadata) | ||
target_text_dict_dtype = self.storage_type[self.TARGET_TEXT_DICT_NAME].type | ||
return ArrowSpanType, (index_dtype, target_text_dict_dtype) | ||
|
||
|
||
class ArrowTokenSpanType(pa.PyExtensionType): | ||
""" | ||
PyArrow extension type definition for conversions to/from TokenSpan columns | ||
""" | ||
|
||
TARGET_TEXT_KEY = ArrowSpanType.TARGET_TEXT_KEY | ||
BEGINS_NAME = "token_begins" | ||
ENDS_NAME = "token_ends" | ||
TARGET_TEXT_DICT_NAME = "token_spans" | ||
|
||
def __init__(self, index_dtype, target_text, num_char_span_splits): | ||
""" | ||
|
@@ -133,10 +132,15 @@ def span_to_arrow(char_span: SpanArray) -> pa.ExtensionArray: | |
begins_array = pa.array(char_span.begin) | ||
ends_array = pa.array(char_span.end) | ||
|
||
typ = ArrowSpanType(begins_array.type, char_span.target_text) | ||
# Create a dictionary array from StringTable used in this span | ||
dictionary = pa.array([char_span._string_table.unbox(s) | ||
for s in char_span._string_table.things]) | ||
target_text_dict_array = pa.DictionaryArray.from_arrays(char_span._text_ids, dictionary) | ||
|
||
typ = ArrowSpanType(begins_array.type, target_text_dict_array.type) | ||
fields = list(typ.storage_type) | ||
|
||
storage = pa.StructArray.from_arrays([begins_array, ends_array], fields=fields) | ||
storage = pa.StructArray.from_arrays([begins_array, ends_array, target_text_dict_array], fields=fields) | ||
|
||
return pa.ExtensionArray.from_storage(typ, storage) | ||
|
||
|
@@ -154,13 +158,21 @@ def arrow_to_span(extension_array: pa.ExtensionArray) -> SpanArray: | |
raise ValueError("Only pyarrow.Array with a single chunk is supported") | ||
extension_array = extension_array.chunk(0) | ||
|
||
# NOTE: workaround for bug in parquet reading | ||
if pa.types.is_struct(extension_array.type): | ||
index_dtype = extension_array.field(ArrowSpanType.BEGINS_NAME).type | ||
target_text_dict_dtype = extension_array.field(ArrowSpanType.TARGET_TEXT_DICT_NAME).type | ||
extension_array = pa.ExtensionArray.from_storage( | ||
ArrowSpanType(index_dtype, target_text_dict_dtype), | ||
extension_array) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For some reason, parquet doesn't return an extension array, but the struct array used as storage. Seems like a bug, so I'll follow up on it, but this workaround seemed ok for now. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, works for me. |
||
|
||
assert pa.types.is_struct(extension_array.storage.type) | ||
|
||
# Get target text from the begins field metadata and decode string | ||
metadata = extension_array.storage.type[ArrowSpanType.BEGINS_NAME].metadata | ||
target_text = metadata[ArrowSpanType.TARGET_TEXT_KEY] | ||
if isinstance(target_text, bytes): | ||
target_text = target_text.decode() | ||
# Create target text StringTable and text_ids from dictionary array | ||
target_text_dict_array = extension_array.storage.field(ArrowSpanType.TARGET_TEXT_DICT_NAME) | ||
table_texts = target_text_dict_array.dictionary.to_pylist() | ||
string_table = StringTable.from_things(table_texts) | ||
text_ids = target_text_dict_array.indices.to_numpy() | ||
|
||
# Get the begins/ends pyarrow arrays | ||
begins_array = extension_array.storage.field(ArrowSpanType.BEGINS_NAME) | ||
|
@@ -170,7 +182,7 @@ def arrow_to_span(extension_array: pa.ExtensionArray) -> SpanArray: | |
begins = begins_array.to_numpy() | ||
ends = ends_array.to_numpy() | ||
|
||
return SpanArray(target_text, begins, ends) | ||
return SpanArray((string_table, text_ids), begins, ends) | ||
|
||
|
||
def token_span_to_arrow(token_span: TokenSpanArray) -> pa.ExtensionArray: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,9 +16,11 @@ | |
import os | ||
import tempfile | ||
import unittest | ||
from distutils.version import LooseVersion | ||
|
||
# noinspection PyPackageRequirements | ||
import pytest | ||
import pyarrow as pa | ||
from pandas.tests.extension import base | ||
|
||
# import pytest fixtures | ||
|
@@ -455,17 +457,46 @@ def test_addition(self): | |
|
||
class CharSpanArrayIOTests(ArrayTestBase): | ||
|
||
@pytest.mark.skip("Temporarily disabled until Feather support reimplemented") | ||
def test_feather(self): | ||
arr = self._make_spans_of_tokens() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would you mind modifying these tests so that they write out a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added a separate test for multi-doc. |
||
df = pd.DataFrame({'Span': arr}) | ||
|
||
with tempfile.TemporaryDirectory() as dirpath: | ||
filename = os.path.join(dirpath, 'char_span_array_test.feather') | ||
filename = os.path.join(dirpath, 'span_array_test.feather') | ||
df.to_feather(filename) | ||
df_read = pd.read_feather(filename) | ||
pd.testing.assert_frame_equal(df, df_read) | ||
|
||
def test_feather_multi_doc(self): | ||
arr = self._make_spans_of_tokens() | ||
df1 = pd.DataFrame({'Span': arr}) | ||
|
||
arr = SpanArray( | ||
"Have at it.", np.array([0, 5, 8]), np.array([4, 7, 11]) | ||
) | ||
df2 = pd.DataFrame({'Span': arr}) | ||
|
||
df = pd.concat([df1, df2], ignore_index=True) | ||
self.assertFalse(df["Span"].array.is_single_document) | ||
|
||
with tempfile.TemporaryDirectory() as dirpath: | ||
filename = os.path.join(dirpath, 'span_array_multi_doc_test.feather') | ||
df.to_feather(filename) | ||
df_read = pd.read_feather(filename) | ||
pd.testing.assert_frame_equal(df, df_read) | ||
|
||
@pytest.mark.skipif(LooseVersion(pa.__version__) < LooseVersion("2.0.0"), | ||
reason="Nested Parquet data types only supported in Arrow >= 2.0.0") | ||
def test_parquet(self): | ||
arr = self._make_spans_of_tokens() | ||
df = pd.DataFrame({'Span': arr}) | ||
|
||
with tempfile.TemporaryDirectory() as dirpath: | ||
filename = os.path.join(dirpath, "span_array_test.parquet") | ||
df.to_parquet(filename) | ||
df_read = pd.read_parquet(filename) | ||
pd.testing.assert_frame_equal(df, df_read) | ||
|
||
|
||
@pytest.fixture | ||
def dtype(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This uses the private `char_span._string_table` and `._text_ids`. It didn't seem like we really needed accessors for these, so I left it like this.

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess this works. You might want to consider giving the `StringTable` class a `to_arrow_dictionary` method instead, though.

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, the dictionary in pyarrow is just an ordinary array; we just need to make sure that the indices match those in the `_text_ids` array, which I think they do. I believe I should unbox the "thing" first, though.