
Fix Arrow serialization for SpanArray multidoc support #181


Merged (4 commits) on Mar 25, 2021

60 changes: 36 additions & 24 deletions text_extensions_for_pandas/array/arrow_conversion.py
@@ -27,51 +27,50 @@
from text_extensions_for_pandas.array.span import SpanArray
from text_extensions_for_pandas.array.token_span import TokenSpanArray
from text_extensions_for_pandas.array.tensor import TensorArray
from text_extensions_for_pandas.array.string_table import StringTable


class ArrowSpanType(pa.PyExtensionType):
"""
PyArrow extension type definition for conversions to/from Span columns
"""
BEGINS_NAME = "span_begins"
ENDS_NAME = "span_ends"
TARGET_TEXT_DICT_NAME = "target_text"

TARGET_TEXT_KEY = b"target_text" # metadata key/value gets serialized to bytes
BEGINS_NAME = "char_begins"
ENDS_NAME = "char_ends"

def __init__(self, index_dtype, target_text):
def __init__(self, index_dtype, target_text_dict_dtype):
"""
Create an instance of a Span data type with given index type and
target text that will be stored in Field metadata.
target text dictionary type. The dictionary type will hold target text ids
that map to a dictionary of document target texts.

:param index_dtype:
:param target_text:
:param index_dtype: type for the begin, end index arrays
:param target_text_dict_dtype: type for the target text dictionary array
"""
assert pa.types.is_integer(index_dtype)

# Store target text as field metadata
metadata = {self.TARGET_TEXT_KEY: target_text}

fields = [
pa.field(self.BEGINS_NAME, index_dtype, metadata=metadata),
pa.field(self.ENDS_NAME, index_dtype)
pa.field(self.BEGINS_NAME, index_dtype),
pa.field(self.ENDS_NAME, index_dtype),
pa.field(self.TARGET_TEXT_DICT_NAME, target_text_dict_dtype)
]

pa.PyExtensionType.__init__(self, pa.struct(fields))

def __reduce__(self):
index_dtype = self.storage_type[self.BEGINS_NAME].type
metadata = self.storage_type[self.BEGINS_NAME].metadata
return ArrowSpanType, (index_dtype, metadata)
target_text_dict_dtype = self.storage_type[self.TARGET_TEXT_DICT_NAME].type
return ArrowSpanType, (index_dtype, target_text_dict_dtype)


class ArrowTokenSpanType(pa.PyExtensionType):
"""
PyArrow extension type definition for conversions to/from TokenSpan columns
"""

TARGET_TEXT_KEY = ArrowSpanType.TARGET_TEXT_KEY
BEGINS_NAME = "token_begins"
ENDS_NAME = "token_ends"
TARGET_TEXT_DICT_NAME = "token_spans"

def __init__(self, index_dtype, target_text, num_char_span_splits):
"""
@@ -133,10 +132,15 @@ def span_to_arrow(char_span: SpanArray) -> pa.ExtensionArray:
begins_array = pa.array(char_span.begin)
ends_array = pa.array(char_span.end)

typ = ArrowSpanType(begins_array.type, char_span.target_text)
# Create a dictionary array from StringTable used in this span
dictionary = pa.array([char_span._string_table.unbox(s)
for s in char_span._string_table.things])
target_text_dict_array = pa.DictionaryArray.from_arrays(char_span._text_ids, dictionary)
Member Author:

This uses the private char_span._string_table and ._text_ids. It didn't seem like we really needed accessors for these, so I left it like this.

Member:

I guess this works. You might want to consider giving the StringTable class a to_arrow_dictionary method instead though.

Member Author:

Well, the dictionary in pyarrow is just an ordinary array; we just need to make sure that the indices match those in the _text_ids array, which I think they do. I believe I should unbox the "thing" first, though.
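For reference, a rough sketch of what that helper could look like if it were added to StringTable. The method name and signature are assumptions taken from the review suggestion, not code in this PR; it just builds the dictionary from the table's unboxed strings so the indices line up with the caller's text ids:

```python
import pyarrow as pa

def to_arrow_dictionary(self, indices):
    """Hypothetical StringTable method (sketch only, not part of this change):
    build a pyarrow DictionaryArray whose dictionary holds this table's
    unboxed strings and whose indices are caller-supplied text ids
    (e.g. a SpanArray's _text_ids)."""
    dictionary = pa.array([self.unbox(thing) for thing in self.things])
    return pa.DictionaryArray.from_arrays(pa.array(indices), dictionary)
```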


typ = ArrowSpanType(begins_array.type, target_text_dict_array.type)
fields = list(typ.storage_type)

storage = pa.StructArray.from_arrays([begins_array, ends_array], fields=fields)
storage = pa.StructArray.from_arrays([begins_array, ends_array, target_text_dict_array], fields=fields)

return pa.ExtensionArray.from_storage(typ, storage)

@@ -154,13 +158,21 @@ def arrow_to_span(extension_array: pa.ExtensionArray) -> SpanArray:
raise ValueError("Only pyarrow.Array with a single chunk is supported")
extension_array = extension_array.chunk(0)

# NOTE: workaround for bug in parquet reading
if pa.types.is_struct(extension_array.type):
index_dtype = extension_array.field(ArrowSpanType.BEGINS_NAME).type
target_text_dict_dtype = extension_array.field(ArrowSpanType.TARGET_TEXT_DICT_NAME).type
extension_array = pa.ExtensionArray.from_storage(
ArrowSpanType(index_dtype, target_text_dict_dtype),
extension_array)
Member Author:

For some reason, Parquet doesn't return an extension array but the struct array used as storage. It seems like a bug, so I'll follow up on it, but this workaround seemed OK for now.

Member:

ok, works for me.
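For context, a minimal, self-contained sketch of the Parquet round trip that exercises this workaround. It assumes pyarrow >= 2.0.0 for nested Parquet types; the document text and file name are made up for illustration:

```python
import os
import tempfile

import numpy as np
import pandas as pd

from text_extensions_for_pandas.array.span import SpanArray

# Round-trip a DataFrame with a span column through Parquet.
arr = SpanArray("Have at it.", np.array([0, 5, 8]), np.array([4, 7, 11]))
df = pd.DataFrame({"Span": arr})

with tempfile.TemporaryDirectory() as dirpath:
    filename = os.path.join(dirpath, "span_array_test.parquet")
    df.to_parquet(filename)
    df_read = pd.read_parquet(filename)

# Without the is_struct() branch above, the value read back from Parquet can be
# the raw struct storage rather than an ArrowSpanType extension array.
```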


assert pa.types.is_struct(extension_array.storage.type)

# Get target text from the begins field metadata and decode string
metadata = extension_array.storage.type[ArrowSpanType.BEGINS_NAME].metadata
target_text = metadata[ArrowSpanType.TARGET_TEXT_KEY]
if isinstance(target_text, bytes):
target_text = target_text.decode()
# Create target text StringTable and text_ids from dictionary array
target_text_dict_array = extension_array.storage.field(ArrowSpanType.TARGET_TEXT_DICT_NAME)
table_texts = target_text_dict_array.dictionary.to_pylist()
string_table = StringTable.from_things(table_texts)
text_ids = target_text_dict_array.indices.to_numpy()

# Get the begins/ends pyarrow arrays
begins_array = extension_array.storage.field(ArrowSpanType.BEGINS_NAME)
@@ -170,7 +182,7 @@ def arrow_to_span(extension_array: pa.ExtensionArray) -> SpanArray:
begins = begins_array.to_numpy()
ends = ends_array.to_numpy()

return SpanArray(target_text, begins, ends)
return SpanArray((string_table, text_ids), begins, ends)


def token_span_to_arrow(token_span: TokenSpanArray) -> pa.ExtensionArray:
35 changes: 33 additions & 2 deletions text_extensions_for_pandas/array/test_span.py
@@ -16,9 +16,11 @@
import os
import tempfile
import unittest
from distutils.version import LooseVersion

# noinspection PyPackageRequirements
import pytest
import pyarrow as pa
from pandas.tests.extension import base

# import pytest fixtures
@@ -455,17 +457,46 @@ def test_addition(self):

class CharSpanArrayIOTests(ArrayTestBase):

@pytest.mark.skip("Temporarily disabled until Feather support reimplemented")
def test_feather(self):
arr = self._make_spans_of_tokens()
Member:

Would you mind modifying these tests so that they write out a SpanArray containing spans over two different document texts?

Member Author:

I added a separate test for multi-doc.

df = pd.DataFrame({'Span': arr})

with tempfile.TemporaryDirectory() as dirpath:
filename = os.path.join(dirpath, 'char_span_array_test.feather')
filename = os.path.join(dirpath, 'span_array_test.feather')
df.to_feather(filename)
df_read = pd.read_feather(filename)
pd.testing.assert_frame_equal(df, df_read)

def test_feather_multi_doc(self):
arr = self._make_spans_of_tokens()
df1 = pd.DataFrame({'Span': arr})

arr = SpanArray(
"Have at it.", np.array([0, 5, 8]), np.array([4, 7, 11])
)
df2 = pd.DataFrame({'Span': arr})

df = pd.concat([df1, df2], ignore_index=True)
self.assertFalse(df["Span"].array.is_single_document)

with tempfile.TemporaryDirectory() as dirpath:
filename = os.path.join(dirpath, 'span_array_multi_doc_test.feather')
df.to_feather(filename)
df_read = pd.read_feather(filename)
pd.testing.assert_frame_equal(df, df_read)

@pytest.mark.skipif(LooseVersion(pa.__version__) < LooseVersion("2.0.0"),
reason="Nested Parquet data types only supported in Arrow >= 2.0.0")
def test_parquet(self):
arr = self._make_spans_of_tokens()
df = pd.DataFrame({'Span': arr})

with tempfile.TemporaryDirectory() as dirpath:
filename = os.path.join(dirpath, "span_array_test.parquet")
df.to_parquet(filename)
df_read = pd.read_parquet(filename)
pd.testing.assert_frame_equal(df, df_read)


@pytest.fixture
def dtype():
13 changes: 13 additions & 0 deletions text_extensions_for_pandas/array/thing_table.py
@@ -170,6 +170,19 @@ def merge_things(cls, things: Union[Sequence[Any], np.ndarray]):
str_ids[i] = new_table.maybe_add_thing(things[i])
return new_table, str_ids

@classmethod
def from_things(cls, things: Union[Sequence[Any], np.ndarray]):
"""
Factory method for creating a ThingTable from a sequence of unique things.

:param things: sequence of unique things to be added to the ThingTable.
:return: A ThingTable containing the elements of `things`.
"""
new_table = cls()
for thing in things:
new_table.add_thing(thing)
return new_table

def thing_to_id(self, thing: Any) -> int:
"""
:param thing: A thing to look up in this table
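As a usage note, a small sketch of how the new from_things factory pairs with thing_to_id when rebuilding a StringTable from an Arrow dictionary, as arrow_to_span() does above. The document texts here are made up for illustration:

```python
from text_extensions_for_pandas.array.string_table import StringTable

# Rebuild a table of document texts from the unique values stored in an
# Arrow dictionary array.
texts = ["In the beginning was the word.", "Have at it."]
string_table = StringTable.from_things(texts)

# Each text maps back to a stable integer id that pairs with a span's
# begin/end offsets when reconstructing a SpanArray.
text_ids = [string_table.thing_to_id(t) for t in texts]
```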