Skip to content

Commit a6677c5

Browse files
committed
Parent Column class for Spark Connect and Spark Classic
1 parent 458f70b commit a6677c5

36 files changed

+1750
-915
lines changed

dev/sparktestsupport/modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ def __hash__(self):
476476
"pyspark.sql.session",
477477
"pyspark.sql.conf",
478478
"pyspark.sql.catalog",
479-
"pyspark.sql.column",
479+
"pyspark.sql.classic.column",
480480
"pyspark.sql.classic.dataframe",
481481
"pyspark.sql.datasource",
482482
"pyspark.sql.group",

python/pyspark/ml/connect/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# limitations under the License.
1616
#
1717
from pyspark.ml import functions as PyMLFunctions
18-
from pyspark.sql.connect.column import Column
18+
from pyspark.sql.column import Column
1919
from pyspark.sql.connect.functions.builtin import _invoke_function, _to_col, lit
2020

2121

python/pyspark/ml/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
pass # Let it throw a better error message later when the API is invoked.
2929

3030
from pyspark.sql.functions import pandas_udf
31-
from pyspark.sql.column import Column, _to_java_column
31+
from pyspark.sql.classic.column import Column, _to_java_column
3232
from pyspark.sql.types import (
3333
ArrayType,
3434
ByteType,

python/pyspark/ml/stat.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
from pyspark.ml.common import _java2py, _py2java
2323
from pyspark.ml.linalg import Matrix, Vector
2424
from pyspark.ml.wrapper import JavaWrapper, _jvm
25-
from pyspark.sql.column import Column, _to_seq
25+
from pyspark.sql.column import Column
26+
from pyspark.sql.classic.column import _to_seq
2627
from pyspark.sql.dataframe import DataFrame
2728
from pyspark.sql.functions import lit
2829

python/pyspark/pandas/internal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -959,7 +959,7 @@ def attach_distributed_sequence_column(
959959

960960
return sdf.select(
961961
ConnectColumn(DistributedSequenceID()).alias(column_name),
962-
"*", # type: ignore[call-overload]
962+
"*",
963963
)
964964
else:
965965
return PySparkDataFrame(

python/pyspark/pandas/spark/functions.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ def product(col: Column, dropna: bool) -> Column:
2525
if is_remote():
2626
from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit
2727

28-
return _invoke_function_over_columns( # type: ignore[return-value]
28+
return _invoke_function_over_columns(
2929
"pandas_product",
30-
col, # type: ignore[arg-type]
30+
col,
3131
lit(dropna),
3232
)
3333

@@ -42,9 +42,9 @@ def stddev(col: Column, ddof: int) -> Column:
4242
if is_remote():
4343
from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit
4444

45-
return _invoke_function_over_columns( # type: ignore[return-value]
45+
return _invoke_function_over_columns(
4646
"pandas_stddev",
47-
col, # type: ignore[arg-type]
47+
col,
4848
lit(ddof),
4949
)
5050

@@ -59,9 +59,9 @@ def var(col: Column, ddof: int) -> Column:
5959
if is_remote():
6060
from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit
6161

62-
return _invoke_function_over_columns( # type: ignore[return-value]
62+
return _invoke_function_over_columns(
6363
"pandas_var",
64-
col, # type: ignore[arg-type]
64+
col,
6565
lit(ddof),
6666
)
6767

@@ -76,9 +76,9 @@ def skew(col: Column) -> Column:
7676
if is_remote():
7777
from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
7878

79-
return _invoke_function_over_columns( # type: ignore[return-value]
79+
return _invoke_function_over_columns(
8080
"pandas_skew",
81-
col, # type: ignore[arg-type]
81+
col,
8282
)
8383

8484
else:
@@ -92,9 +92,9 @@ def kurt(col: Column) -> Column:
9292
if is_remote():
9393
from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
9494

95-
return _invoke_function_over_columns( # type: ignore[return-value]
95+
return _invoke_function_over_columns(
9696
"pandas_kurt",
97-
col, # type: ignore[arg-type]
97+
col,
9898
)
9999

100100
else:
@@ -108,9 +108,9 @@ def mode(col: Column, dropna: bool) -> Column:
108108
if is_remote():
109109
from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit
110110

111-
return _invoke_function_over_columns( # type: ignore[return-value]
111+
return _invoke_function_over_columns(
112112
"pandas_mode",
113-
col, # type: ignore[arg-type]
113+
col,
114114
lit(dropna),
115115
)
116116

@@ -125,10 +125,10 @@ def covar(col1: Column, col2: Column, ddof: int) -> Column:
125125
if is_remote():
126126
from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit
127127

128-
return _invoke_function_over_columns( # type: ignore[return-value]
128+
return _invoke_function_over_columns(
129129
"pandas_covar",
130-
col1, # type: ignore[arg-type]
131-
col2, # type: ignore[arg-type]
130+
col1,
131+
col2,
132132
lit(ddof),
133133
)
134134

@@ -143,9 +143,9 @@ def ewm(col: Column, alpha: float, ignore_na: bool) -> Column:
143143
if is_remote():
144144
from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit
145145

146-
return _invoke_function_over_columns( # type: ignore[return-value]
146+
return _invoke_function_over_columns(
147147
"ewm",
148-
col, # type: ignore[arg-type]
148+
col,
149149
lit(alpha),
150150
lit(ignore_na),
151151
)
@@ -161,9 +161,9 @@ def null_index(col: Column) -> Column:
161161
if is_remote():
162162
from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
163163

164-
return _invoke_function_over_columns( # type: ignore[return-value]
164+
return _invoke_function_over_columns(
165165
"null_index",
166-
col, # type: ignore[arg-type]
166+
col,
167167
)
168168

169169
else:
@@ -177,11 +177,11 @@ def timestampdiff(unit: str, start: Column, end: Column) -> Column:
177177
if is_remote():
178178
from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit
179179

180-
return _invoke_function_over_columns( # type: ignore[return-value]
180+
return _invoke_function_over_columns(
181181
"timestampdiff",
182182
lit(unit),
183-
start, # type: ignore[arg-type]
184-
end, # type: ignore[arg-type]
183+
start,
184+
end,
185185
)
186186

187187
else:

python/pyspark/sql/avro/functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from typing import Dict, Optional, TYPE_CHECKING, cast
2424

25-
from pyspark.sql.column import Column, _to_java_column
25+
from pyspark.sql.column import Column
2626
from pyspark.sql.utils import get_active_spark_context, try_remote_avro_functions
2727
from pyspark.util import _print_missing_jar
2828

@@ -78,6 +78,7 @@ def from_avro(
7878
[Row(value=Row(avro=Row(age=2, name='Alice')))]
7979
"""
8080
from py4j.java_gateway import JVMView
81+
from pyspark.sql.classic.column import _to_java_column
8182

8283
sc = get_active_spark_context()
8384
try:
@@ -128,6 +129,7 @@ def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
128129
[Row(suite=bytearray(b'\\x02\\x00'))]
129130
"""
130131
from py4j.java_gateway import JVMView
132+
from pyspark.sql.classic.column import _to_java_column
131133

132134
sc = get_active_spark_context()
133135
try:

0 commit comments

Comments (0)