pyspark.sql.functions.arrow_udf
- pyspark.sql.functions.arrow_udf(f=None, returnType=None, functionType=None)
Creates an Arrow user-defined function.
Arrow UDFs are user-defined functions that Spark executes using Arrow to transfer and work with the data, which allows pyarrow.Array operations. An Arrow UDF is defined by using arrow_udf as a decorator or by wrapping a function with it; no additional configuration is required. In general, an Arrow UDF behaves like a regular PySpark function API.
New in version 4.1.0.
- Parameters
- f : function, optional
user-defined function. A Python function if used as a standalone function.
- returnType : pyspark.sql.types.DataType or str, optional
the return type of the user-defined function. The value can be either a pyspark.sql.types.DataType object or a DDL-formatted type string.
- functionType : int, optional
an enum value in pyspark.sql.functions.ArrowUDFType. Default: SCALAR. This parameter exists for compatibility. Using Python type hints is encouraged.
Notes
User-defined functions do not support conditional expressions or short-circuiting in boolean expressions; the function ends up being executed for all rows internally. If the function can fail on special rows, the workaround is to incorporate the condition into the function itself.
Not every UDF type takes keyword arguments on the calling side; the examples below state which types can use them.
The data type of the pyarrow.Array returned from the user-defined function should match the defined returnType (see types.to_arrow_type() and types.from_arrow_type()). When there is a mismatch between them, Spark might convert the returned data. The conversion is not guaranteed to be correct, and users should check the results for accuracy.
Examples
In order to use this API, customarily the below are imported:
>>> import pyarrow as pa
>>> from pyspark.sql.functions import arrow_udf
Python type hints detect the function types as below:
>>> from pyspark.sql.functions import ArrowUDFType
>>> from pyspark.sql.types import IntegerType
>>> @arrow_udf(IntegerType(), ArrowUDFType.SCALAR)
... def slen(v: pa.Array) -> pa.Array:
...     return pa.compute.utf8_length(v)
Note that the type hint should use pyarrow.Array in all cases.
- Arrays to Arrays
pyarrow.Array, … -> pyarrow.Array
The function takes one or more pyarrow.Array and outputs one pyarrow.Array. The output of the function should always be of the same length as the input.
>>> @arrow_udf("string")
... def to_upper(s: pa.Array) -> pa.Array:
...     return pa.compute.ascii_upper(s)
...
>>> df = spark.createDataFrame([("John Doe",)], ("name",))
>>> df.select(to_upper("name")).show()
+--------------+
|to_upper(name)|
+--------------+
|      JOHN DOE|
+--------------+
>>> @arrow_udf("first string, last string")
... def split_expand(v: pa.Array) -> pa.Array:
...     b = pa.compute.ascii_split_whitespace(v)
...     s0 = pa.array([t[0] for t in b])
...     s1 = pa.array([t[1] for t in b])
...     return pa.StructArray.from_arrays([s0, s1], names=["first", "last"])
...
>>> df = spark.createDataFrame([("John Doe",)], ("name",))
>>> df.select(split_expand("name")).show()
+------------------+
|split_expand(name)|
+------------------+
|       {John, Doe}|
+------------------+
This type of Arrow UDF can use keyword arguments:
>>> from pyspark.sql import functions as sf
>>> @arrow_udf(returnType=IntegerType())
... def calc(a: pa.Array, b: pa.Array) -> pa.Array:
...     return pa.compute.add(a, pa.compute.multiply(b, 10))
...
>>> spark.range(2).select(calc(b=sf.col("id") * 10, a=sf.col("id"))).show()
+-----------------------------+
|calc(b => (id * 10), a => id)|
+-----------------------------+
|                            0|
|                          101|
+-----------------------------+
Note
The length of the input is not that of the whole input column, but is the length of an internal batch used for each call to the function.
- Iterator of Arrays to Iterator of Arrays
Iterator[pyarrow.Array] -> Iterator[pyarrow.Array]
The function takes an iterator of pyarrow.Array and outputs an iterator of pyarrow.Array. In this case, the created Arrow UDF requires exactly one input column when it is called as a PySpark column. The total length of the output across the whole iterator must equal the total length of the input; as long as that holds, the function may prefetch data from the input iterator.
It is also useful when the UDF execution requires initializing some state, although internally it works identically to the Arrays to Arrays case. The pseudocode below illustrates this.
@arrow_udf("long")
def calculate(iterator: Iterator[pa.Array]) -> Iterator[pa.Array]:
    # Do some expensive initialization with a state
    state = very_expensive_initialization()
    for x in iterator:
        # Use that state for whole iterator.
        yield calculate_with_state(x, state)

df.select(calculate("value")).show()
>>> import pandas as pd
>>> from typing import Iterator
>>> @arrow_udf("long")
... def plus_one(iterator: Iterator[pa.Array]) -> Iterator[pa.Array]:
...     for v in iterator:
...         yield pa.compute.add(v, 1)
...
>>> df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["v"]))
>>> df.select(plus_one(df.v)).show()
+-----------+
|plus_one(v)|
+-----------+
|          2|
|          3|
|          4|
+-----------+
Note
The length of each array is the length of a batch used internally.
- Iterator of Multiple Arrays to Iterator of Arrays
Iterator[Tuple[pyarrow.Array, …]] -> Iterator[pyarrow.Array]
The function takes an iterator of tuples of multiple pyarrow.Array and outputs an iterator of pyarrow.Array. In this case, the created Arrow UDF requires as many input columns as there are arrays in the tuple when it is called as a PySpark column. Otherwise, it has the same characteristics and restrictions as the Iterator of Arrays to Iterator of Arrays case.
>>> from typing import Iterator, Tuple
>>> from pyspark.sql import functions as sf
>>> @arrow_udf("long")
... def multiply(iterator: Iterator[Tuple[pa.Array, pa.Array]]) -> Iterator[pa.Array]:
...     for v1, v2 in iterator:
...         yield pa.compute.multiply(v1, v2.field("v"))
...
>>> df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["v"]))
>>> df.withColumn('output', multiply(sf.col("v"), sf.struct(sf.col("v")))).show()
+---+------+
|  v|output|
+---+------+
|  1|     1|
|  2|     4|
|  3|     9|
+---+------+
Note
The length of each array is the length of a batch used internally.
- Arrays to Scalar
pyarrow.Array, … -> Any
The function takes one or more pyarrow.Array and returns a scalar value. The returned scalar can be a Python primitive type (e.g., int or float), a numpy data type (e.g., numpy.int64 or numpy.float64), or a pyarrow.Scalar instance, which supports complex return types. Any should ideally be replaced by a specific scalar type.
>>> @arrow_udf("double")
... def mean_udf(v: pa.Array) -> float:
...     return pa.compute.mean(v).as_py()
...
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
>>> df.groupby("id").agg(mean_udf(df['v'])).show()
+---+-----------+
| id|mean_udf(v)|
+---+-----------+
|  1|        1.5|
|  2|        6.0|
+---+-----------+
The return type can also be a complex type such as struct, list, or map.
>>> @arrow_udf("struct<m1: double, m2: double>")
... def min_max_udf(v: pa.Array) -> pa.Scalar:
...     m1 = pa.compute.min(v)
...     m2 = pa.compute.max(v)
...     t = pa.struct([pa.field("m1", pa.float64()), pa.field("m2", pa.float64())])
...     return pa.scalar(value={"m1": m1.as_py(), "m2": m2.as_py()}, type=t)
...
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
>>> df.groupby("id").agg(min_max_udf(df['v'])).show()
+---+--------------+
| id|min_max_udf(v)|
+---+--------------+
|  1|    {1.0, 2.0}|
|  2|   {3.0, 10.0}|
+---+--------------+
This type of Arrow UDF can use keyword arguments:
>>> @arrow_udf("double")
... def weighted_mean_udf(v: pa.Array, w: pa.Array) -> float:
...     import numpy as np
...     return np.average(v.to_numpy(), weights=w)
...
>>> df = spark.createDataFrame(
...     [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
...     ("id", "v", "w"))
>>> df.groupby("id").agg(weighted_mean_udf(w=df["w"], v=df["v"])).show()
+---+---------------------------------+
| id|weighted_mean_udf(w => w, v => v)|
+---+---------------------------------+
|  1|               1.6666666666666667|
|  2|                7.166666666666667|
+---+---------------------------------+
This UDF can also be used as a window function, as below:
>>> from pyspark.sql import Window
>>> @arrow_udf("double")
... def mean_udf(v: pa.Array) -> float:
...     return pa.compute.mean(v).as_py()
...
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
>>> w = Window.partitionBy('id').orderBy('v').rowsBetween(-1, 0)
>>> df.withColumn('mean_v', mean_udf("v").over(w)).show()
+---+----+------+
| id|   v|mean_v|
+---+----+------+
|  1| 1.0|   1.0|
|  1| 2.0|   1.5|
|  2| 3.0|   3.0|
|  2| 5.0|   4.0|
|  2|10.0|   7.5|
+---+----+------+
Note
For performance reasons, the input arrays passed to window functions are not copied. Therefore, mutating the input arrays is not allowed and will cause incorrect results. For the same reason, users should also not rely on the index of the input arrays.