
Commit 64d0301

[SPARK-32953][PYTHON] Add Arrow self_destruct support to toPandas

1 parent 497f599
3 files changed: 89 additions, 5 deletions

python/pyspark/sql/pandas/conversion.py

Lines changed: 44 additions & 4 deletions
@@ -105,13 +105,29 @@ def toPandas(self):
                     import pyarrow
                     # Rename columns to avoid duplicated column names.
                     tmp_column_names = ['col_{}'.format(i) for i in range(len(self.columns))]
-                    batches = self.toDF(*tmp_column_names)._collect_as_arrow()
+                    self_destruct = self.sql_ctx._conf.arrowPySparkSelfDestructEnabled()
+                    batches = self.toDF(*tmp_column_names)._collect_as_arrow(
+                        split_batches=self_destruct)
                     if len(batches) > 0:
                         table = pyarrow.Table.from_batches(batches)
+                        # Ensure only the table has a reference to the batches, so that
+                        # self_destruct (if enabled) is effective
+                        del batches
                         # Pandas DataFrame created from PyArrow uses datetime64[ns] for date type
                         # values, but we should use datetime.date to match the behavior with when
                         # Arrow optimization is disabled.
-                        pdf = table.to_pandas(date_as_object=True)
+                        pandas_options = {'date_as_object': True}
+                        if self_destruct:
+                            # Configure PyArrow to use as little memory as possible:
+                            # self_destruct - free columns as they are converted
+                            # split_blocks - create a separate Pandas block for each column
+                            # use_threads - convert one column at a time
+                            pandas_options.update({
+                                'self_destruct': True,
+                                'split_blocks': True,
+                                'use_threads': False,
+                            })
+                        pdf = table.to_pandas(**pandas_options)
                         # Rename back to the original column names.
                         pdf.columns = self.columns
                         for field in self.schema:
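
A note on the `del batches` above: with self_destruct enabled, PyArrow releases each column of the table as it is converted, so nothing else may hold a reference to the batches, and the table itself must not be reused after the conversion. A minimal sketch of that caveat with plain PyArrow (the table below is illustrative, not Spark code):

    import pyarrow as pa

    # Build a throwaway table; with self_destruct the conversion consumes it.
    table = pa.table({"x": list(range(1_000))})
    pdf = table.to_pandas(self_destruct=True, split_blocks=True, use_threads=False)
    # The table's columns were released during conversion; drop the last
    # reference rather than calling to_pandas (or anything else) on it again.
    del table
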
@@ -225,11 +241,16 @@ def _to_corrected_pandas_type(dt):
         else:
             return None
 
-    def _collect_as_arrow(self):
+    def _collect_as_arrow(self, split_batches=False):
         """
         Returns all records as a list of ArrowRecordBatches, pyarrow must be installed
         and available on driver and worker Python environments.
         This is an experimental feature.
+
+        :param split_batches: split batches such that each column is in its own allocation, so
+            that the selfDestruct optimization is effective; default False.
+
+        .. note:: Experimental.
         """
         from pyspark.sql.dataframe import DataFrame
 
@@ -240,7 +261,26 @@ def _collect_as_arrow(self):
 
         # Collect list of un-ordered batches where last element is a list of correct order indices
         try:
-            results = list(_load_from_socket((port, auth_secret), ArrowCollectSerializer()))
+            batch_stream = _load_from_socket((port, auth_secret), ArrowCollectSerializer())
+            if split_batches:
+                # When spark.sql.execution.arrow.pyspark.selfDestruct.enabled is set, ensure
+                # each column in each record batch is contained in its own allocation.
+                # Otherwise, selfDestruct does nothing; it frees each column as it is
+                # converted, but each column will actually be a list of slices of record
+                # batches, and so no memory is actually freed until all columns are
+                # converted.
+                import pyarrow as pa
+                results = []
+                for batch_or_indices in batch_stream:
+                    if isinstance(batch_or_indices, pa.RecordBatch):
+                        batch_or_indices = pa.RecordBatch.from_arrays([
+                            # This call actually reallocates the array
+                            pa.concat_arrays([array])
+                            for array in batch_or_indices
+                        ], schema=batch_or_indices.schema)
+                    results.append(batch_or_indices)
+            else:
+                results = list(batch_stream)
         finally:
             # Join serving thread and raise any exceptions from collectAsArrowToPython
             jsocket_auth_server.getResult()
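
To see why the `pa.concat_arrays([array])` call matters, here is a standalone sketch in plain PyArrow (the batch is made up for illustration): a slice of a record batch is a zero-copy view that keeps the parent's buffers alive, while rebuilding each column with `concat_arrays` copies it into its own allocation, which is what lets self_destruct free memory column by column.

    import pyarrow as pa

    parent = pa.RecordBatch.from_arrays([pa.array(range(1_000_000))], names=["x"])
    view = parent.slice(0, 10)  # zero-copy: shares the parent's buffers

    # Reallocate each column so the slice no longer pins the parent batch.
    rebuilt = pa.RecordBatch.from_arrays(
        [pa.concat_arrays([column]) for column in view],
        schema=view.schema)
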

python/pyspark/sql/tests/test_arrow.py

Lines changed: 32 additions & 1 deletion
@@ -25,7 +25,7 @@
 
 from pyspark import SparkContext, SparkConf
 from pyspark.sql import Row, SparkSession
-from pyspark.sql.functions import udf
+from pyspark.sql.functions import rand, udf
 from pyspark.sql.types import StructType, StringType, IntegerType, LongType, \
     FloatType, DoubleType, DecimalType, DateType, TimestampType, BinaryType, StructField, \
     ArrayType, NullType
@@ -196,6 +196,37 @@ def test_pandas_round_trip(self):
         pdf_arrow = df.toPandas()
         assert_frame_equal(pdf_arrow, pdf)
 
+    def test_pandas_self_destruct(self):
+        import pyarrow as pa
+        rows = 2 ** 10
+        cols = 4
+        expected_bytes = rows * cols * 8
+        df = self.spark.range(0, rows).select(*[rand() for _ in range(cols)])
+        # Test the self_destruct behavior by testing _collect_as_arrow directly
+        allocation_before = pa.total_allocated_bytes()
+        batches = df._collect_as_arrow(split_batches=True)
+        table = pa.Table.from_batches(batches)
+        del batches
+        pdf_split = table.to_pandas(self_destruct=True, split_blocks=True, use_threads=False)
+        allocation_after = pa.total_allocated_bytes()
+        difference = allocation_after - allocation_before
+        # Should be around 1x the data size (table should not hold on to any memory)
+        self.assertGreaterEqual(difference, 0.9 * expected_bytes)
+        self.assertLessEqual(difference, 1.1 * expected_bytes)
+
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.selfDestruct.enabled": False}):
+            no_self_destruct_pdf = df.toPandas()
+            # Note while memory usage is 2x data size here (both table and pdf hold on to
+            # memory), in this case Arrow still only tracks 1x worth of memory (since the
+            # batches are not allocated by Arrow in this case), so we can't make any
+            # assertions here
+
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.selfDestruct.enabled": True}):
+            self_destruct_pdf = df.toPandas()
+
+        assert_frame_equal(pdf_split, no_self_destruct_pdf)
+        assert_frame_equal(pdf_split, self_destruct_pdf)
+
     def test_filtered_frame(self):
         df = self.spark.range(3).toDF("i")
         pdf = df.filter("i < 0").toPandas()
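
The assertions above lean on PyArrow's allocation tracking. A small sketch of the same measurement outside Spark (the table and sizes are illustrative):

    import pyarrow as pa

    before = pa.total_allocated_bytes()
    table = pa.table({"x": list(range(1_000_000))})
    pdf = table.to_pandas(self_destruct=True, split_blocks=True, use_threads=False)
    del table
    after = pa.total_allocated_bytes()
    # With self_destruct, Arrow's tracked allocation should stay near the size of
    # the data pandas now owns, rather than roughly doubling during conversion.
    print(after - before)
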

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 13 additions & 0 deletions
@@ -2023,6 +2023,17 @@ object SQLConf {
       .version("3.0.0")
       .fallbackConf(ARROW_EXECUTION_ENABLED)
 
+  val ARROW_PYSPARK_SELF_DESTRUCT_ENABLED =
+    buildConf("spark.sql.execution.arrow.pyspark.selfDestruct.enabled")
+      .doc("When true, make use of Apache Arrow's self-destruct and split-blocks options " +
+        "for columnar data transfers in PySpark, when converting from Arrow to Pandas. " +
+        "This reduces memory usage at the cost of some CPU time. " +
+        "This optimization applies to: pyspark.sql.DataFrame.toPandas " +
+        "when 'spark.sql.execution.arrow.pyspark.enabled' is set.")
+      .version("3.2.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val PYSPARK_JVM_STACKTRACE_ENABLED =
     buildConf("spark.sql.pyspark.jvmStacktrace.enabled")
       .doc("When true, it shows the JVM stacktrace in the user-facing PySpark exception " +
@@ -3577,6 +3588,8 @@ class SQLConf extends Serializable with Logging {
 
   def arrowPySparkEnabled: Boolean = getConf(ARROW_PYSPARK_EXECUTION_ENABLED)
 
+  def arrowPySparkSelfDestructEnabled: Boolean = getConf(ARROW_PYSPARK_SELF_DESTRUCT_ENABLED)
+
   def pysparkJVMStacktraceEnabled: Boolean = getConf(PYSPARK_JVM_STACKTRACE_ENABLED)
 
   def arrowSparkREnabled: Boolean = getConf(ARROW_SPARKR_EXECUTION_ENABLED)
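
For completeness, a sketch of how a user would exercise the new flag end to end (assumes a Spark build containing this patch; the column name and row count are illustrative):

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import rand

    spark = SparkSession.builder.getOrCreate()
    # Both configs matter: selfDestruct only applies while the Arrow path is on.
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    spark.conf.set("spark.sql.execution.arrow.pyspark.selfDestruct.enabled", "true")

    df = spark.range(1_000_000).select(rand().alias("x"))
    pdf = df.toPandas()  # converts via Arrow, freeing columns as they are converted
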
