Skip to content

Commit bd61bc2

Browse files
committed
feat: add ArrowJSONtype to extend pyarrow for JSONDtype
1 parent b6c1428 commit bd61bc2

File tree

3 files changed

+89
-4
lines changed

3 files changed

+89
-4
lines changed

db_dtypes/__init__.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030

3131
from db_dtypes import core
3232
from db_dtypes.version import __version__
33-
from . import _versions_helpers
3433

34+
from . import _versions_helpers
3535

3636
date_dtype_name = "dbdate"
3737
time_dtype_name = "dbtime"
@@ -50,7 +50,7 @@
5050
# To use JSONArray and JSONDtype, you'll need Pandas 1.5.0 or later. With the removal
5151
# of Python 3.7 compatibility, the minimum Pandas version will be updated to 1.5.0.
5252
if packaging.version.Version(pandas.__version__) >= packaging.version.Version("1.5.0"):
53-
from db_dtypes.json import JSONArray, JSONDtype
53+
from db_dtypes.json import ArrowJSONType, JSONArray, JSONDtype
5454
else:
5555
JSONArray = None
5656
JSONDtype = None
@@ -359,7 +359,7 @@ def __sub__(self, other):
359359
)
360360

361361

362-
if not JSONArray or not JSONDtype:
362+
if not JSONArray or not JSONDtype or not ArrowJSONType:
363363
__all__ = [
364364
"__version__",
365365
"DateArray",
@@ -370,6 +370,7 @@ def __sub__(self, other):
370370
else:
371371
__all__ = [
372372
"__version__",
373+
"ArrowJSONType",
373374
"DateArray",
374375
"DateDtype",
375376
"JSONDtype",

db_dtypes/json.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ def construct_array_type(cls):
6464
"""Return the array type associated with this dtype."""
6565
return JSONArray
6666

67+
def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> JSONArray:
68+
"""Convert the pyarrow array to the extension array."""
69+
return JSONArray(array)
70+
6771

6872
class JSONArray(arrays.ArrowExtensionArray):
6973
"""Extension array that handles BigQuery JSON data, leveraging a string-based
@@ -92,6 +96,10 @@ def __init__(self, values) -> None:
9296
else:
9397
raise NotImplementedError(f"Unsupported pandas version: {pd.__version__}")
9498

99+
def __arrow_array__(self):
100+
"""Convert to an arrow array. This is required for pyarrow extension."""
101+
return self.pa_data
102+
95103
@classmethod
96104
def _box_pa(
97105
cls, value, pa_type: pa.DataType | None = None
@@ -151,7 +159,12 @@ def _serialize_json(value):
151159
def _deserialize_json(value):
152160
"""A static method that converts a JSON string back into its original value."""
153161
if not pd.isna(value):
154-
return json.loads(value)
162+
# Attempt to interpret the value as a JSON object.
163+
# If it's not valid JSON, treat it as a regular string.
164+
try:
165+
return json.loads(value)
166+
except json.JSONDecodeError:
167+
return value
155168
else:
156169
return value
157170

@@ -244,3 +257,39 @@ def __array__(self, dtype=None, copy: bool | None = None) -> np.ndarray:
244257
result[mask] = self._dtype.na_value
245258
result[~mask] = data[~mask].pa_data.to_numpy()
246259
return result
260+
261+
262+
class ArrowJSONType(pa.ExtensionType):
263+
"""Arrow extension type for the `dbjson` Pandas extension type."""
264+
265+
def __init__(self) -> None:
266+
super().__init__(pa.string(), "dbjson")
267+
268+
def __arrow_ext_serialize__(self) -> bytes:
269+
# No parameters are necessary
270+
return b""
271+
272+
def __eq__(self, other):
273+
if isinstance(other, pyarrow.BaseExtensionType):
274+
return type(self) == type(other)
275+
else:
276+
return NotImplemented
277+
278+
def __ne__(self, other) -> bool:
279+
return not self == other
280+
281+
@classmethod
282+
def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowJSONType:
283+
# return an instance of this subclass
284+
return ArrowJSONType()
285+
286+
def __hash__(self) -> int:
287+
return hash(str(self))
288+
289+
def to_pandas_dtype(self):
290+
return JSONDtype()
291+
292+
293+
# Register the type to be included in RecordBatches, sent over IPC and received in
294+
# another Python process.
295+
pa.register_extension_type(ArrowJSONType())

tests/unit/test_json.py

+35
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import numpy as np
1818
import pandas as pd
19+
import pyarrow as pa
1920
import pytest
2021

2122
import db_dtypes
@@ -114,3 +115,37 @@ def test_as_numpy_array():
114115
]
115116
)
116117
pd._testing.assert_equal(result, expected)
118+
119+
120+
def test_arrow_json_storage_type():
121+
arrow_json_type = db_dtypes.ArrowJSONType()
122+
assert arrow_json_type.extension_name == "dbjson"
123+
assert pa.types.is_string(arrow_json_type.storage_type)
124+
125+
126+
def test_arrow_json_constructors():
127+
storage_array = pa.array(
128+
["0", "str", '{"b": 2}', '{"a": [1, 2, 3]}'], type=pa.string()
129+
)
130+
arr_1 = db_dtypes.ArrowJSONType().wrap_array(storage_array)
131+
assert isinstance(arr_1, pa.ExtensionArray)
132+
133+
arr_2 = pa.ExtensionArray.from_storage(db_dtypes.ArrowJSONType(), storage_array)
134+
assert isinstance(arr_2, pa.ExtensionArray)
135+
136+
assert arr_1 == arr_2
137+
138+
139+
def test_arrow_json_to_pandas():
140+
storage_array = pa.array(
141+
[None, "0", "str", '{"b": 2}', '{"a": [1, 2, 3]}'], type=pa.string()
142+
)
143+
arr = db_dtypes.ArrowJSONType().wrap_array(storage_array)
144+
145+
s = arr.to_pandas()
146+
assert isinstance(s.dtypes, db_dtypes.JSONDtype)
147+
assert pd.isna(s[0])
148+
assert s[1] == 0
149+
assert s[2] == "str"
150+
assert s[3]["b"] == 2
151+
assert s[4]["a"] == [1, 2, 3]

0 commit comments

Comments
 (0)