Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share

python - How do I use the UnaryTransformer in PySpark?

I can't figure out what's wrong with my implementation here, nor can I find any example of how to use the UnaryTransformer to calculate a custom transformation in a PySpark Pipeline.

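from pyspark.ml import Pipeline, UnaryTransformer
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.sql.types import DoubleType

spark = SparkSession.builder.getOrCreate()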

df = spark.createDataFrame([
    (0.0, 1.0),
    (1.0, 0.0),
    (2.0, 1.0),
    (0.0, 2.0),
    (0.0, 1.0),
    (2.0, 0.0)
], ["categoryIndex1", "categoryIndex2"])

class ScaleUp(UnaryTransformer):
    def createTransformFunc(self):
        """
        Creates the transform function using the given param map. The input param map already takes
        account of the embedded param map. So the param values should be determined
        solely by the input param map.
        """
        return f.udf(lambda item: item * 10, returnType=DoubleType())

    def outputDataType(self):
        """
        Returns the data type of the output column.
        """
        return DoubleType()

    def validateInputType(self, inputType):
        """
        Validates the input type. Throw an exception if it is invalid.
        """
        assert inputType == DoubleType(), f'Expected DoubleType() and found {inputType}'
  
scale_up = ScaleUp().setInputCol('categoryIndex2')
pipeline = Pipeline(stages=[scale_up])
pipeline.fit(df).transform(df).show()
Question from: https://stackoverflow.com/questions/65892971/how-do-i-use-the-unarytransformer-in-pyspark


1 Answer

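The createTransformFunc method should return a plain Python function, not a Spark UDF. UnaryTransformer wraps whatever createTransformFunc returns in a UDF itself (using outputDataType() as the return type), so returning f.udf(...) ends up wrapping a UDF inside another UDF. Return the bare function instead: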

from pyspark.ml import UnaryTransformer
from pyspark.sql.types import DoubleType

class ScaleUp(UnaryTransformer):
    def createTransformFunc(self):
        return lambda item: item * 10

    def outputDataType(self):
        return DoubleType()

    def validateInputType(self, inputType):
        assert inputType == DoubleType(), f'Expected DoubleType() and found {inputType}'