Class FMClassifier
Object
org.apache.spark.ml.PipelineStage
org.apache.spark.ml.Estimator<M>
org.apache.spark.ml.Predictor<FeaturesType,E,M>
org.apache.spark.ml.classification.Classifier<FeaturesType,E,M>
org.apache.spark.ml.classification.ProbabilisticClassifier<Vector,FMClassifier,FMClassificationModel>
org.apache.spark.ml.classification.FMClassifier
- All Implemented Interfaces:
Serializable, org.apache.spark.internal.Logging, ClassifierParams, FMClassifierParams, ProbabilisticClassifierParams, Params, HasFeaturesCol, HasFitIntercept, HasLabelCol, HasMaxIter, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, HasRegParam, HasSeed, HasSolver, HasStepSize, HasThresholds, HasTol, HasWeightCol, PredictorParams, FactorizationMachines, FactorizationMachinesParams, DefaultParamsWritable, Identifiable, MLWritable, scala.Serializable
public class FMClassifier
extends ProbabilisticClassifier<Vector,FMClassifier,FMClassificationModel>
implements FactorizationMachines, FMClassifierParams, DefaultParamsWritable, org.apache.spark.internal.Logging
Factorization Machines learning algorithm for classification.
It supports two solvers: plain gradient descent and AdamW.
The implementation is based upon: S. Rendle. "Factorization machines" 2010.
FM can estimate feature interactions even in problems with very sparse data (such as advertising and recommendation systems). The FM formula is:
$$ \begin{align} y = \sigma\left( w_0 + \sum\limits^n_{i=1} w_i x_i + \sum\limits^n_{i=1} \sum\limits^n_{j=i+1} \langle v_i, v_j \rangle x_i x_j \right) \end{align} $$
The first two terms are the global bias and the linear term (the same as in linear regression), and the last term captures pairwise interactions. v_i describes the i-th variable with k factors.
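The formula can be evaluated directly from its definition. The following is a minimal plain-Java sketch, not Spark's implementation: the class and method names (FmScoreSketch, rawScore, rawScoreFast) are hypothetical, and the coefficients in main are made-up illustration values. rawScoreFast uses the identity from the cited Rendle paper that rewrites the pairwise term as 0.5 * sum over factors f of ((sum_i v_{i,f} x_i)^2 - sum_i v_{i,f}^2 x_i^2), which costs O(k n) instead of O(k n^2).

```java
final class FmScoreSketch {
    /** Logistic function sigma(z) = 1 / (1 + exp(-z)). */
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Raw score w0 + sum_i w_i x_i + sum_{i<j} <v_i, v_j> x_i x_j, computed naively. */
    static double rawScore(double w0, double[] w, double[][] v, double[] x) {
        double score = w0;
        for (int i = 0; i < x.length; i++) {
            score += w[i] * x[i];
        }
        for (int i = 0; i < x.length; i++) {
            for (int j = i + 1; j < x.length; j++) {
                double dot = 0.0;
                for (int f = 0; f < v[i].length; f++) {
                    dot += v[i][f] * v[j][f];
                }
                score += dot * x[i] * x[j];
            }
        }
        return score;
    }

    /** Same raw score, with the pairwise term computed per factor in O(k n). */
    static double rawScoreFast(double w0, double[] w, double[][] v, double[] x) {
        double score = w0;
        for (int i = 0; i < x.length; i++) {
            score += w[i] * x[i];
        }
        int k = v.length == 0 ? 0 : v[0].length;
        for (int f = 0; f < k; f++) {
            double sum = 0.0;
            double sumSq = 0.0;
            for (int i = 0; i < x.length; i++) {
                double t = v[i][f] * x[i];
                sum += t;
                sumSq += t * t;
            }
            score += 0.5 * (sum * sum - sumSq);
        }
        return score;
    }

    public static void main(String[] args) {
        // Illustration values only: n = 3 features, k = 2 factors.
        double w0 = 0.5;
        double[] w = {1.0, -1.0, 0.5};
        double[][] v = {{1.0, 0.0}, {0.0, 1.0}, {1.0, 1.0}};
        double[] x = {1.0, 2.0, 1.0};
        System.out.println(rawScore(w0, w, v, x));
        System.out.println(rawScoreFast(w0, w, v, x));
        System.out.println(sigmoid(rawScore(w0, w, v, x)));
    }
}
```

For the values in main, the linear part is 0.5 + 1 - 2 + 0.5 = 0 and the pairwise part is 0 + 1 + 2 = 3, so both methods return a raw score of 3.0 before the sigmoid is applied.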
The FM classification model uses logistic loss, which can be minimized by gradient descent. Regularization terms such as L2 are usually added to the loss function to prevent overfitting.
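As a rough illustration of that objective, the sketch below computes the mean logistic loss over raw scores plus an L2 penalty. It is not Spark's code: the class and method names are hypothetical, labels are assumed to be 0 or 1, and the 0.5 factor on the penalty and the averaging over instances are conventions assumed here rather than taken from this page.

```java
final class FmLossSketch {
    /** Numerically stable softplus: log(1 + exp(t)). */
    static double softplus(double t) {
        return Math.max(t, 0.0) + Math.log1p(Math.exp(-Math.abs(t)));
    }

    /** Logistic loss for a label in {0, 1} and a raw (pre-sigmoid) score. */
    static double logisticLoss(double label, double rawScore) {
        // label 1: -log(sigma(z)) = softplus(-z); label 0: -log(1 - sigma(z)) = softplus(z)
        return label > 0.5 ? softplus(-rawScore) : softplus(rawScore);
    }

    /** L2 penalty 0.5 * regParam * ||params||^2 (the 0.5 factor is an assumed convention). */
    static double l2Penalty(double regParam, double[] params) {
        double sumSq = 0.0;
        for (double p : params) {
            sumSq += p * p;
        }
        return 0.5 * regParam * sumSq;
    }

    /** Mean logistic loss over all instances plus the L2 penalty. */
    static double objective(double[] labels, double[] rawScores, double regParam, double[] params) {
        double total = 0.0;
        for (int i = 0; i < labels.length; i++) {
            total += logisticLoss(labels[i], rawScores[i]);
        }
        return total / labels.length + l2Penalty(regParam, params);
    }

    public static void main(String[] args) {
        double[] labels = {1.0, 0.0};
        double[] rawScores = {0.0, 0.0};
        double[] params = {1.0, 2.0};
        System.out.println(objective(labels, rawScores, 0.1, params));
    }
}
```

A raw score of 0 corresponds to a predicted probability of 0.5, so each instance in main contributes log(2) to the loss regardless of its label.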
- Note:
- Multiclass labels are not currently supported.
-
Nested Class Summary
Nested classes/interfaces inherited from interface org.apache.spark.internal.Logging
org.apache.spark.internal.Logging.SparkShellLoggingFilter
-
Constructor Summary
-
Method Summary
Modifier and Type, Method: Description
copy: Creates a copy of this instance with the same UID and some extra params.
final IntParam factorSize(): Param for dimensionality of the factors (>= 0)
final BooleanParam fitIntercept(): Param for whether to fit an intercept term.
final BooleanParam fitLinear(): Param for whether to fit linear term (aka 1-way term)
final DoubleParam initStd(): Param for standard deviation of initial coefficients
static FMClassifier load
final IntParam maxIter(): Param for maximum number of iterations (>= 0).
final DoubleParam miniBatchFraction(): Param for mini-batch fraction, must be in range (0, 1]
static MLReader<T> read()
final DoubleParam regParam(): Param for regularization parameter (>= 0).
final LongParam seed(): Param for random seed.
setFactorSize(int value): Set the dimensionality of the factors.
setFitIntercept(boolean value): Set whether to fit intercept term.
setFitLinear(boolean value): Set whether to fit linear term.
setInitStd(double value): Set the standard deviation of initial coefficients.
setMaxIter(int value): Set the maximum number of iterations.
setMiniBatchFraction(double value): Set the mini-batch fraction parameter.
setRegParam(double value): Set the L2 regularization parameter.
setSeed(long value): Set the random seed for weight initialization.
setSolver: Set the solver algorithm used for optimization.
setStepSize(double value): Set the initial step size for the first step (like learning rate).
setTol(double value): Set the convergence tolerance of iterations.
solver(): The solver algorithm for optimization.
stepSize(): Param for step size to be used for each iteration of optimization (> 0).
final DoubleParam tol(): Param for the convergence tolerance for iterative algorithms (>= 0).
uid(): An immutable unique ID for the object and its derivatives.
weightCol(): Param for weight column name.
Methods inherited from class org.apache.spark.ml.classification.ProbabilisticClassifier
probabilityCol, setProbabilityCol, setThresholds, thresholds
Methods inherited from class org.apache.spark.ml.classification.Classifier
rawPredictionCol, setRawPredictionCol
Methods inherited from class org.apache.spark.ml.Predictor
featuresCol, fit, labelCol, predictionCol, setFeaturesCol, setLabelCol, setPredictionCol, transformSchema
Methods inherited from class org.apache.spark.ml.PipelineStage
params
Methods inherited from class java.lang.Object
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface org.apache.spark.ml.util.DefaultParamsWritable
write
Methods inherited from interface org.apache.spark.ml.regression.FactorizationMachines
initCoefficients, trainImpl
Methods inherited from interface org.apache.spark.ml.regression.FactorizationMachinesParams
getFactorSize, getFitLinear, getInitStd, getMiniBatchFraction
Methods inherited from interface org.apache.spark.ml.param.shared.HasFeaturesCol
featuresCol, getFeaturesCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasFitIntercept
getFitIntercept
Methods inherited from interface org.apache.spark.ml.param.shared.HasLabelCol
getLabelCol, labelCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasMaxIter
getMaxIter
Methods inherited from interface org.apache.spark.ml.param.shared.HasPredictionCol
getPredictionCol, predictionCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasProbabilityCol
getProbabilityCol, probabilityCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasRawPredictionCol
getRawPredictionCol, rawPredictionCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasRegParam
getRegParam
Methods inherited from interface org.apache.spark.ml.param.shared.HasStepSize
getStepSize
Methods inherited from interface org.apache.spark.ml.param.shared.HasThresholds
getThresholds, thresholds
Methods inherited from interface org.apache.spark.ml.param.shared.HasWeightCol
getWeightCol
Methods inherited from interface org.apache.spark.ml.util.Identifiable
toString
Methods inherited from interface org.apache.spark.internal.Logging
initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, isTraceEnabled, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning, org$apache$spark$internal$Logging$$log_, org$apache$spark$internal$Logging$$log__$eq
Methods inherited from interface org.apache.spark.ml.util.MLWritable
save
Methods inherited from interface org.apache.spark.ml.param.Params
clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, onParamChange, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
Methods inherited from interface org.apache.spark.ml.classification.ProbabilisticClassifierParams
validateAndTransformSchema
-
Constructor Details
-
FMClassifier
-
FMClassifier
public FMClassifier()
-
-
Method Details
-
load
-
read
-
factorSize
Description copied from interface: FactorizationMachinesParams
Param for dimensionality of the factors (>= 0)
- Specified by:
factorSize in interface FactorizationMachinesParams
- Returns:
- (undocumented)
-
fitLinear
Description copied from interface: FactorizationMachinesParams
Param for whether to fit linear term (aka 1-way term)
- Specified by:
fitLinear in interface FactorizationMachinesParams
- Returns:
- (undocumented)
-
miniBatchFraction
Description copied from interface: FactorizationMachinesParams
Param for mini-batch fraction, must be in range (0, 1]
- Specified by:
miniBatchFraction in interface FactorizationMachinesParams
- Returns:
- (undocumented)
-
initStd
Description copied from interface: FactorizationMachinesParams
Param for standard deviation of initial coefficients
- Specified by:
initStd in interface FactorizationMachinesParams
- Returns:
- (undocumented)
-
solver
Description copied from interface: FactorizationMachinesParams
The solver algorithm for optimization. Supported options: "gd", "adamW". Default: "adamW"
- Specified by:
solver in interface FactorizationMachinesParams
- Specified by:
solver in interface HasSolver
- Returns:
- (undocumented)
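To make the difference between the two solver options concrete, the sketch below shows one update step of plain gradient descent next to one step of AdamW in its textbook form (adaptive moment estimates with decoupled weight decay). It is only an illustration under stated assumptions, not Spark's solver code: the class and method names are hypothetical, and beta1, beta2, eps and weightDecay are assumed hyperparameters that this page does not document.

```java
final class SolverStepSketch {
    /** Plain gradient descent: theta <- theta - stepSize * grad. */
    static void gdStep(double[] theta, double[] grad, double stepSize) {
        for (int i = 0; i < theta.length; i++) {
            theta[i] -= stepSize * grad[i];
        }
    }

    /** Running first and second moment estimates kept between AdamW steps. */
    static final class AdamWState {
        final double[] m;
        final double[] v;
        int t = 0;

        AdamWState(int size) {
            m = new double[size];
            v = new double[size];
        }
    }

    /** One textbook AdamW step; all hyperparameters are assumptions for illustration. */
    static void adamWStep(double[] theta, double[] grad, AdamWState s, double stepSize,
                          double beta1, double beta2, double eps, double weightDecay) {
        s.t++;
        double correction1 = 1.0 - Math.pow(beta1, s.t);
        double correction2 = 1.0 - Math.pow(beta2, s.t);
        for (int i = 0; i < theta.length; i++) {
            s.m[i] = beta1 * s.m[i] + (1.0 - beta1) * grad[i];
            s.v[i] = beta2 * s.v[i] + (1.0 - beta2) * grad[i] * grad[i];
            double mHat = s.m[i] / correction1;
            double vHat = s.v[i] / correction2;
            // Weight decay is applied to theta directly instead of being folded into the gradient.
            theta[i] -= stepSize * (mHat / (Math.sqrt(vHat) + eps) + weightDecay * theta[i]);
        }
    }

    public static void main(String[] args) {
        double[] a = {1.0};
        gdStep(a, new double[]{2.0}, 0.1);
        System.out.println(a[0]);

        double[] b = {0.0};
        adamWStep(b, new double[]{2.0}, new AdamWState(1), 0.1, 0.9, 0.999, 1e-8, 0.0);
        System.out.println(b[0]);
    }
}
```

On the first AdamW step the bias-corrected ratio mHat / sqrt(vHat) has magnitude close to 1, so the update is roughly stepSize in size whatever the gradient scale, whereas the gradient descent update scales linearly with the gradient.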
-
weightCol
Description copied from interface: HasWeightCol
Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0.
- Specified by:
weightCol in interface HasWeightCol
- Returns:
- (undocumented)
-
regParam
Description copied from interface: HasRegParam
Param for regularization parameter (>= 0).
- Specified by:
regParam in interface HasRegParam
- Returns:
- (undocumented)
-
fitIntercept
Description copied from interface: HasFitIntercept
Param for whether to fit an intercept term.
- Specified by:
fitIntercept in interface HasFitIntercept
- Returns:
- (undocumented)
-
seed
Description copied from interface: HasSeed
Param for random seed.
-
tol
Description copied from interface: HasTol
Param for the convergence tolerance for iterative algorithms (>= 0).
-
stepSize
Description copied from interface: HasStepSize
Param for step size to be used for each iteration of optimization (> 0).
- Specified by:
stepSize in interface HasStepSize
- Returns:
- (undocumented)
-
maxIter
Description copied from interface: HasMaxIter
Param for maximum number of iterations (>= 0).
- Specified by:
maxIter in interface HasMaxIter
- Returns:
- (undocumented)
-
uid
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
- Specified by:
uid in interface Identifiable
- Returns:
- (undocumented)
-
setFactorSize
Set the dimensionality of the factors. Default is 8.
- Parameters:
value - (undocumented)
- Returns:
- (undocumented)
-
setFitIntercept
Set whether to fit intercept term. Default is true.
- Parameters:
value - (undocumented)
- Returns:
- (undocumented)
-
setFitLinear
Set whether to fit linear term. Default is true.
- Parameters:
value - (undocumented)
- Returns:
- (undocumented)
-
setRegParam
Set the L2 regularization parameter. Default is 0.0.
- Parameters:
value - (undocumented)
- Returns:
- (undocumented)
-
setMiniBatchFraction
Set the mini-batch fraction parameter. Default is 1.0.
- Parameters:
value - (undocumented)
- Returns:
- (undocumented)
-
setInitStd
Set the standard deviation of initial coefficients. Default is 0.01.
- Parameters:
value - (undocumented)
- Returns:
- (undocumented)
-
setMaxIter
Set the maximum number of iterations. Default is 100.
- Parameters:
value - (undocumented)
- Returns:
- (undocumented)
-
setStepSize
Set the initial step size for the first step (like learning rate). Default is 1.0.
- Parameters:
value - (undocumented)
- Returns:
- (undocumented)
-
setTol
Set the convergence tolerance of iterations. Default is 1E-6.
- Parameters:
value - (undocumented)
- Returns:
- (undocumented)
-
setSolver
Set the solver algorithm used for optimization. Supported options: "gd", "adamW". Default: "adamW"
- Parameters:
value - (undocumented)
- Returns:
- (undocumented)
-
setSeed
Set the random seed for weight initialization.
- Parameters:
value - (undocumented)
- Returns:
- (undocumented)
-
copy
Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. Subclasses should implement this method and set the return type properly. See defaultCopy().
- Specified by:
copy in interface Params
- Specified by:
copy in class Predictor<Vector,FMClassifier,FMClassificationModel>
- Parameters:
extra - (undocumented)
- Returns:
- (undocumented)
-