
Commit 4fef9d9

[SPARK-32953][PYTHON] Add Arrow self_destruct support to toPandas
Parent commit: 2507301

4 files changed (+63, -6 lines)


python/pyspark/sql/pandas/conversion.py

Lines changed: 25 additions & 4 deletions
@@ -100,13 +100,30 @@ def toPandas(self):
                     import pyarrow
                     # Rename columns to avoid duplicated column names.
                     tmp_column_names = ['col_{}'.format(i) for i in range(len(self.columns))]
-                    batches = self.toDF(*tmp_column_names)._collect_as_arrow()
+                    self_destruct = self.sql_ctx._conf.arrowPySparkSelfDestructEnabled()
+                    batches = self.toDF(*tmp_column_names)._collect_as_arrow(
+                        split_batches=self_destruct
+                    )
                     if len(batches) > 0:
                         table = pyarrow.Table.from_batches(batches)
+                        # Ensure only the table has a reference to the batches, so that
+                        # self_destruct (if enabled) is effective
+                        del batches
                         # Pandas DataFrame created from PyArrow uses datetime64[ns] for date type
                         # values, but we should use datetime.date to match the behavior with when
                         # Arrow optimization is disabled.
-                        pdf = table.to_pandas(date_as_object=True)
+                        pandas_options = {'date_as_object': True}
+                        if self_destruct:
+                            # Configure PyArrow to use as little memory as possible:
+                            # self_destruct - free columns as they are converted
+                            # split_blocks - create a separate Pandas block for each column
+                            # use_threads - convert one column at a time
+                            pandas_options.update({
+                                'self_destruct': True,
+                                'split_blocks': True,
+                                'use_threads': False,
+                            })
+                        pdf = table.to_pandas(**pandas_options)
                         # Rename back to the original column names.
                         pdf.columns = self.columns
                         for field in self.schema:
@@ -217,11 +234,14 @@ def _to_corrected_pandas_type(dt):
         else:
             return None

-    def _collect_as_arrow(self):
+    def _collect_as_arrow(self, split_batches=False):
         """
         Returns all records as a list of ArrowRecordBatches, pyarrow must be installed
         and available on driver and worker Python environments.

+        :param split_batches: split batches such that each column is in its own allocation, so
+            that the selfDestruct optimization is effective; default False.
+
         .. note:: Experimental.
         """
         from pyspark.sql.dataframe import DataFrame
@@ -232,8 +252,9 @@ def _collect_as_arrow(self):
             port, auth_secret, jsocket_auth_server = self._jdf.collectAsArrowToPython()

         # Collect list of un-ordered batches where last element is a list of correct order indices
+        serializer = ArrowCollectSerializer(split_batches=split_batches)
         try:
-            results = list(_load_from_socket((port, auth_secret), ArrowCollectSerializer()))
+            results = list(_load_from_socket((port, auth_secret), serializer))
         finally:
             # Join serving thread and raise any exceptions from collectAsArrowToPython
             jsocket_auth_server.getResult()
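
For context, the pandas_options built in toPandas above map directly onto pyarrow.Table.to_pandas keyword arguments. A minimal standalone sketch, not part of the commit and using made-up data, of what those options do:

import pyarrow as pa

# Build a small table purely for illustration.
batch = pa.record_batch(
    [pa.array(range(1000)), pa.array([float(i) for i in range(1000)])],
    names=['a', 'b'])
table = pa.Table.from_batches([batch])
del batch  # the table should hold the only reference, mirroring `del batches` above

pdf = table.to_pandas(
    date_as_object=True,   # default behaviour the commit keeps
    self_destruct=True,    # free each Arrow column once it has been converted
    split_blocks=True,     # one pandas block per column instead of consolidated blocks
    use_threads=False)     # convert one column at a time
# After a self_destruct conversion, the Arrow table's memory has been released
# and the table must not be read again.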

python/pyspark/sql/pandas/serializers.py

Lines changed: 19 additions & 2 deletions
@@ -36,10 +36,14 @@ class ArrowCollectSerializer(Serializer):
     Deserialize a stream of batches followed by batch order information. Used in
     PandasConversionMixin._collect_as_arrow() after invoking Dataset.collectAsArrowToPython()
     in the JVM.
+
+    :param split_batches: split batches such that each column is in its own allocation, so
+        that the selfDestruct optimization is effective; default False.
     """

-    def __init__(self):
+    def __init__(self, split_batches=False):
         self.serializer = ArrowStreamSerializer()
+        self.split_batches = split_batches

     def dump_stream(self, iterator, stream):
         return self.serializer.dump_stream(iterator, stream)
@@ -51,7 +55,20 @@ def load_stream(self, stream):
         """
         # load the batches
         for batch in self.serializer.load_stream(stream):
-            yield batch
+            if self.split_batches:
+                import pyarrow as pa
+                # When spark.sql.execution.arrow.pyspark.selfDestruct.enabled, ensure
+                # each column in each record batch is contained in its own allocation.
+                # Otherwise, selfDestruct does nothing; it frees each column as its
+                # converted, but each column will actually be a list of slices of record
+                # batches, and so no memory is actually freed until all columns are
+                # converted.
+                split_batch = pa.RecordBatch.from_arrays([
+                    pa.concat_arrays([array]) for array in batch
+                ], schema=batch.schema)
+                yield split_batch
+            else:
+                yield batch

         # load the batch order indices or propagate any error that occurred in the JVM
         num = read_int(stream)
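
To make the split_batches branch above concrete: iterating a RecordBatch yields its column arrays, and pyarrow.concat_arrays copies each column into a fresh standalone allocation. A small hypothetical sketch (not from the commit; the columns are invented) of the same re-allocation outside Spark:

import pyarrow as pa

batch = pa.record_batch(
    [pa.array([1, 2, 3]), pa.array(['x', 'y', 'z'])],
    names=['i', 's'])

# Copy every column into its own buffer so a later to_pandas(self_destruct=True)
# can release columns one by one instead of holding the whole IPC buffer until
# every column has been converted.
split_batch = pa.RecordBatch.from_arrays(
    [pa.concat_arrays([array]) for array in batch],
    schema=batch.schema)

assert split_batch.equals(batch)  # same logical contents, separate allocations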

python/pyspark/sql/tests/test_arrow.py

Lines changed: 7 additions & 0 deletions
@@ -190,6 +190,13 @@ def test_pandas_round_trip(self):
         pdf_arrow = df.toPandas()
         assert_frame_equal(pdf_arrow, pdf)

+    def test_pandas_self_destruct(self):
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.selfDestruct.enabled": True}):
+            pdf = self.create_pandas_data_frame()
+            df = self.spark.createDataFrame(self.data, schema=self.schema)
+            pdf_arrow = df.toPandas()
+            assert_frame_equal(pdf_arrow, pdf)
+
     def test_filtered_frame(self):
         df = self.spark.range(3).toDF("i")
         pdf = df.filter("i < 0").toPandas()
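
Outside the test harness, a rough equivalent of the end-user flow the test exercises (a hypothetical session sketch, not part of the commit):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.pyspark.selfDestruct.enabled", "true")

# toPandas now takes the lower-memory Arrow path added by this commit.
df = spark.range(1000000).selectExpr("id", "id * 2 AS doubled")
pdf = df.toPandas()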

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 12 additions & 0 deletions
@@ -1843,6 +1843,16 @@ object SQLConf {
       .version("3.0.0")
       .fallbackConf(ARROW_EXECUTION_ENABLED)

+  val ARROW_PYSPARK_SELF_DESTRUCT_ENABLED =
+    buildConf("spark.sql.execution.arrow.pyspark.selfDestruct.enabled")
+      .doc("When true, make use of Apache Arrow's self-destruct option " +
+        "for columnar data transfers in PySpark. " +
+        "This reduces memory usage at the cost of some CPU time. " +
+        "This optimization applies to: pyspark.sql.DataFrame.toPandas")
+      .version("3.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val PYSPARK_JVM_STACKTRACE_ENABLED =
     buildConf("spark.sql.pyspark.jvmStacktrace.enabled")
       .doc("When true, it shows the JVM stacktrace in the user-facing PySpark exception " +
@@ -3302,6 +3312,8 @@ class SQLConf extends Serializable with Logging {

   def arrowPySparkEnabled: Boolean = getConf(ARROW_PYSPARK_EXECUTION_ENABLED)

+  def arrowPySparkSelfDestructEnabled: Boolean = getConf(ARROW_PYSPARK_SELF_DESTRUCT_ENABLED)
+
   def pysparkJVMStacktraceEnabled: Boolean = getConf(PYSPARK_JVM_STACKTRACE_ENABLED)

   def arrowSparkREnabled: Boolean = getConf(ARROW_SPARKR_EXECUTION_ENABLED)
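
Since the entry is registered in SQLConf with createWithDefault(false), it behaves like any other SQL conf and stays off unless explicitly enabled; a hypothetical check from PySpark (not part of the commit):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Readable and settable through SQL as well as through spark.conf.
spark.sql("SET spark.sql.execution.arrow.pyspark.selfDestruct.enabled").show(truncate=False)
spark.sql("SET spark.sql.execution.arrow.pyspark.selfDestruct.enabled=true")
print(spark.conf.get("spark.sql.execution.arrow.pyspark.selfDestruct.enabled"))  # 'true' after the SET above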
