diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 8e01f5d170ff5..2fde3edc2486c 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -81,7 +81,10 @@ def test_function_parity(self): missing_in_py = jvm_fn_set.difference(py_fn_set) # Functions that we expect to be missing in python until they are added to pyspark - expected_missing_in_py = set() + expected_missing_in_py = set( + # TODO(SPARK-53108): Implement the time_diff function in Python + ["time_diff"] + ) self.assertEqual( expected_missing_in_py, missing_in_py, "Missing functions in pyspark not as expected" diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index 49fa45ed02cbe..95ca5b44bf023 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -6292,6 +6292,28 @@ object functions { def timestamp_add(unit: String, quantity: Column, ts: Column): Column = Column.internalFn("timestampadd", lit(unit), quantity, ts) + /** + * Returns the difference between two times, measured in specified units. Throws a + * SparkIllegalArgumentException, in case the specified unit is not supported. + * + * @param unit + * A STRING representing the unit of the time difference. Supported units are: "HOUR", + * "MINUTE", "SECOND", "MILLISECOND", and "MICROSECOND". The unit is case-insensitive. + * @param start + * A starting TIME. + * @param end + * An ending TIME. + * @return + * The difference between `end` and `start` times, measured in specified units. + * @note + * If any of the inputs is `NULL`, the result is `NULL`. + * @group datetime_funcs + * @since 4.1.0 + */ + def time_diff(unit: Column, start: Column, end: Column): Column = { + Column.fn("time_diff", unit, start, end) + } + /** * Returns `time` truncated to the `unit`. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TimeFunctionsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/TimeFunctionsSuiteBase.scala index 005bfcb13d2e8..556755a73ff7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TimeFunctionsSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TimeFunctionsSuiteBase.scala @@ -241,6 +241,75 @@ abstract class TimeFunctionsSuiteBase extends QueryTest with SharedSparkSession checkAnswer(result2, expected) } + test("SPARK-53108: time_diff function") { + // Input data for the function. + val schema = StructType(Seq( + StructField("unit", StringType, nullable = false), + StructField("start", TimeType(), nullable = false), + StructField("end", TimeType(), nullable = false) + )) + val data = Seq( + Row("HOUR", LocalTime.parse("20:30:29"), LocalTime.parse("21:30:28")), + Row("second", LocalTime.parse("09:32:05.359123"), LocalTime.parse("17:23:49.906152")), + Row("MicroSecond", LocalTime.parse("09:32:05.359123"), LocalTime.parse("17:23:49.906152")) + ) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + + // Test the function using both `selectExpr` and `select`. + val result1 = df.selectExpr( + "time_diff(unit, start, end)" + ) + val result2 = df.select( + time_diff(col("unit"), col("start"), col("end")) + ) + // Check that both methods produce the same result. + checkAnswer(result1, result2) + + // Expected output of the function. + val expected = Seq( + 0, + 28304, + 28304547029L + ).toDF("diff").select(col("diff")) + // Check that the results match the expected output. + checkAnswer(result1, expected) + checkAnswer(result2, expected) + + // NULL result is returned for any NULL input. + val nullInputDF = Seq( + (null, LocalTime.parse("01:02:03"), LocalTime.parse("01:02:03")), + ("HOUR", null, LocalTime.parse("01:02:03")), + ("HOUR", LocalTime.parse("01:02:03"), null), + ("HOUR", null, null), + (null, LocalTime.parse("01:02:03"), null), + (null, null, LocalTime.parse("01:02:03")), + (null, null, null) + ).toDF("unit", "start", "end") + val nullResult = Seq[Integer]( + null, null, null, null, null, null, null + ).toDF("diff").select(col("diff")) + checkAnswer( + nullInputDF.select(time_diff(col("unit"), col("start"), col("end"))), + nullResult + ) + + // Error is thrown for malformed input. + val invalidUnitDF = Seq( + ("invalid_unit", LocalTime.parse("01:02:03"), LocalTime.parse("01:02:03")) + ).toDF("unit", "start", "end") + checkError( + exception = intercept[SparkIllegalArgumentException] { + invalidUnitDF.select(time_diff(col("unit"), col("start"), col("end"))).collect() + }, + condition = "INVALID_PARAMETER_VALUE.TIME_UNIT", + parameters = Map( + "functionName" -> "`time_diff`", + "parameter" -> "`unit`", + "invalidValue" -> "'invalid_unit'" + ) + ) + } + test("SPARK-53107: time_trunc function") { // Input data for the function (including null values). val schema = StructType(Seq(