Class FMClassifier

All Implemented Interfaces:
Serializable, org.apache.spark.internal.Logging, ClassifierParams, FMClassifierParams, ProbabilisticClassifierParams, Params, HasFeaturesCol, HasFitIntercept, HasLabelCol, HasMaxIter, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, HasRegParam, HasSeed, HasSolver, HasStepSize, HasThresholds, HasTol, HasWeightCol, PredictorParams, FactorizationMachines, FactorizationMachinesParams, DefaultParamsWritable, Identifiable, MLWritable, scala.Serializable

public class FMClassifier extends ProbabilisticClassifier<Vector,FMClassifier,FMClassificationModel> implements FactorizationMachines, FMClassifierParams, DefaultParamsWritable, org.apache.spark.internal.Logging
Factorization Machines learning algorithm for classification. It supports two solvers: plain gradient descent and AdamW.

The implementation is based upon: S. Rendle. "Factorization machines" 2010.

FM can estimate feature interactions even in problems with huge sparsity, such as advertising and recommender systems. The FM formula is:

$$ \begin{align} y = \sigma\left( w_0 + \sum\limits^n_{i=1} w_i x_i + \sum\limits^n_{i=1} \sum\limits^n_{j=i+1} \langle v_i, v_j \rangle x_i x_j \right) \end{align} $$
The first two terms are the global bias and the linear term (the same as in linear regression), and the last term models pairwise interactions. v_i describes the i-th variable with k factors.
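
The formula can be made concrete with a small, dependency-free sketch. The class and method names below (`FmPredictionSketch`, `margin`, `predictProbability`) are hypothetical and are not part of Spark; the pairwise term is computed with the linear-time identity from Rendle's paper, 0.5 * sum over factors f of ((sum_i v_{i,f} x_i)^2 - sum_i (v_{i,f} x_i)^2), rather than with a double loop over feature pairs.

```java
// Hypothetical illustration class; not part of the Spark API.
final class FmPredictionSketch {

    /** Logistic function sigma(z) = 1 / (1 + exp(-z)). */
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /**
     * Raw FM score: w0 + sum_i w[i] * x[i] + sum_{i<j} <v[i], v[j]> * x[i] * x[j].
     * v[i] is the k-dimensional factor vector of feature i.
     */
    static double margin(double w0, double[] w, double[][] v, double[] x) {
        int n = x.length;
        double score = w0;

        // Linear term.
        for (int i = 0; i < n; i++) {
            score += w[i] * x[i];
        }

        // Pairwise term, computed per factor in O(k * n).
        int k = (n == 0) ? 0 : v[0].length;
        for (int f = 0; f < k; f++) {
            double sum = 0.0;
            double sumOfSquares = 0.0;
            for (int i = 0; i < n; i++) {
                double term = v[i][f] * x[i];
                sum += term;
                sumOfSquares += term * term;
            }
            score += 0.5 * (sum * sum - sumOfSquares);
        }
        return score;
    }

    /** Probability of the positive class: sigma(margin). */
    static double predictProbability(double w0, double[] w, double[][] v, double[] x) {
        return sigmoid(margin(w0, w, v, x));
    }
}
```

Because only features with non-zero x_i contribute to either sum, the same loop structure extends naturally to sparse inputs, which is the setting the formula is designed for.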

The FM classification model uses logistic loss, which can be minimized by gradient descent; regularization terms such as L2 are usually added to the loss function to prevent overfitting.
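
As a rough illustration of that objective, the sketch below evaluates the logistic loss and an L2 penalty for given raw scores. It assumes labels in {0, 1} and a penalty of the form (lambda / 2) * ||params||^2 added to the mean loss; the class name `FmLogLossSketch` and this exact scaling are assumptions for illustration, not a description of Spark's internals.

```java
// Hypothetical illustration class; not part of the Spark API.
final class FmLogLossSketch {

    /** Numerically stable log(1 + exp(z)). */
    static double log1pExp(double z) {
        return z > 0 ? z + Math.log1p(Math.exp(-z)) : Math.log1p(Math.exp(z));
    }

    /**
     * Logistic loss for one example with label y in {0, 1} and raw score m:
     * -[y * log(sigma(m)) + (1 - y) * log(1 - sigma(m))] = log(1 + exp(m)) - y * m.
     */
    static double logLoss(double margin, double label) {
        return log1pExp(margin) - label * margin;
    }

    /** Derivative of the loss with respect to the raw score: sigma(m) - y. */
    static double lossGradient(double margin, double label) {
        return 1.0 / (1.0 + Math.exp(-margin)) - label;
    }

    /** Mean logistic loss plus (lambda / 2) * ||params||^2 (assumed scaling). */
    static double objective(double[] margins, double[] labels, double[] params, double lambda) {
        double total = 0.0;
        for (int i = 0; i < margins.length; i++) {
            total += logLoss(margins[i], labels[i]);
        }
        double squaredNorm = 0.0;
        for (double p : params) {
            squaredNorm += p * p;
        }
        return total / margins.length + 0.5 * lambda * squaredNorm;
    }
}
```

The value returned by `lossGradient` is the factor that the chain rule multiplies into the derivative of the raw score with respect to each weight or factor when the loss is minimized by gradient descent.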

Note:
Multiclass labels are not currently supported.
  • Constructor Details

    • FMClassifier

      public FMClassifier(String uid)
    • FMClassifier

      public FMClassifier()
  • Method Details

    • load

      public static FMClassifier load(String path)
    • read

      public static MLReader<T> read()
    • factorSize

      public final IntParam factorSize()
      Description copied from interface: FactorizationMachinesParams
      Param for dimensionality of the factors (>= 0)
      Specified by:
      factorSize in interface FactorizationMachinesParams
      Returns:
      (undocumented)
    • fitLinear

      public final BooleanParam fitLinear()
      Description copied from interface: FactorizationMachinesParams
      Param for whether to fit linear term (aka 1-way term)
      Specified by:
      fitLinear in interface FactorizationMachinesParams
      Returns:
      (undocumented)
    • miniBatchFraction

      public final DoubleParam miniBatchFraction()
      Description copied from interface: FactorizationMachinesParams
      Param for mini-batch fraction, must be in range (0, 1]
      Specified by:
      miniBatchFraction in interface FactorizationMachinesParams
      Returns:
      (undocumented)
    • initStd

      public final DoubleParam initStd()
      Description copied from interface: FactorizationMachinesParams
      Param for standard deviation of initial coefficients
      Specified by:
      initStd in interface FactorizationMachinesParams
      Returns:
      (undocumented)
    • solver

      public final Param<String> solver()
      Description copied from interface: FactorizationMachinesParams
      The solver algorithm for optimization. Supported options: "gd", "adamW". Default: "adamW"

      Specified by:
      solver in interface FactorizationMachinesParams
      Specified by:
      solver in interface HasSolver
      Returns:
      (undocumented)
    • weightCol

      public final Param<String> weightCol()
      Description copied from interface: HasWeightCol
      Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0.
      Specified by:
      weightCol in interface HasWeightCol
      Returns:
      (undocumented)
    • regParam

      public final DoubleParam regParam()
      Description copied from interface: HasRegParam
      Param for regularization parameter (>= 0).
      Specified by:
      regParam in interface HasRegParam
      Returns:
      (undocumented)
    • fitIntercept

      public final BooleanParam fitIntercept()
      Description copied from interface: HasFitIntercept
      Param for whether to fit an intercept term.
      Specified by:
      fitIntercept in interface HasFitIntercept
      Returns:
      (undocumented)
    • seed

      public final LongParam seed()
      Description copied from interface: HasSeed
      Param for random seed.
      Specified by:
      seed in interface HasSeed
      Returns:
      (undocumented)
    • tol

      public final DoubleParam tol()
      Description copied from interface: HasTol
      Param for the convergence tolerance for iterative algorithms (>= 0).
      Specified by:
      tol in interface HasTol
      Returns:
      (undocumented)
    • stepSize

      public DoubleParam stepSize()
      Description copied from interface: HasStepSize
      Param for step size to be used for each iteration of optimization (> 0).
      Specified by:
      stepSize in interface HasStepSize
      Returns:
      (undocumented)
    • maxIter

      public final IntParam maxIter()
      Description copied from interface: HasMaxIter
      Param for maximum number of iterations (>= 0).
      Specified by:
      maxIter in interface HasMaxIter
      Returns:
      (undocumented)
    • uid

      public String uid()
      Description copied from interface: Identifiable
      An immutable unique ID for the object and its derivatives.
      Specified by:
      uid in interface Identifiable
      Returns:
      (undocumented)
    • setFactorSize

      public FMClassifier setFactorSize(int value)
      Set the dimensionality of the factors. Default is 8.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setFitIntercept

      public FMClassifier setFitIntercept(boolean value)
      Set whether to fit intercept term. Default is true.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setFitLinear

      public FMClassifier setFitLinear(boolean value)
      Set whether to fit linear term. Default is true.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setRegParam

      public FMClassifier setRegParam(double value)
      Set the L2 regularization parameter. Default is 0.0.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setMiniBatchFraction

      public FMClassifier setMiniBatchFraction(double value)
      Set the mini-batch fraction parameter. Default is 1.0.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setInitStd

      public FMClassifier setInitStd(double value)
      Set the standard deviation of initial coefficients. Default is 0.01.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setMaxIter

      public FMClassifier setMaxIter(int value)
      Set the maximum number of iterations. Default is 100.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setStepSize

      public FMClassifier setStepSize(double value)
      Set the initial step size for the first step (analogous to a learning rate). Default is 1.0.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setTol

      public FMClassifier setTol(double value)
      Set the convergence tolerance of iterations. Default is 1E-6.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setSolver

      public FMClassifier setSolver(String value)
      Set the solver algorithm used for optimization. Supported options: "gd", "adamW". Default: "adamW"

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setSeed

      public FMClassifier setSeed(long value)
      Set the random seed for weight initialization.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • copy

      public FMClassifier copy(ParamMap extra)
      Description copied from interface: Params
      Creates a copy of this instance with the same UID and some extra params. Subclasses should implement this method and set the return type properly. See defaultCopy().
      Specified by:
      copy in interface Params
      Specified by:
      copy in class Predictor<Vector,FMClassifier,FMClassificationModel>
      Parameters:
      extra - (undocumented)
      Returns:
      (undocumented)
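
To give a feel for how the stepSize and regParam settings documented above could interact in a gradient descent solver, here is a minimal, dependency-free sketch of a single L2-regularized update. The class `GdStepSketch` is hypothetical, and the decaying schedule stepSize / sqrt(iteration) is an assumption chosen for illustration; it should not be read as a specification of the "gd" or "adamW" solver in Spark.

```java
// Hypothetical illustration class; not part of the Spark API.
final class GdStepSketch {

    /**
     * One gradient descent update with an L2 penalty:
     *   eta    = stepSize / sqrt(iteration)   (assumed decay schedule)
     *   theta' = theta - eta * (gradient + regParam * theta)
     *
     * @param iteration 1-based iteration counter
     */
    static double[] step(double[] params, double[] gradient,
                         double stepSize, int iteration, double regParam) {
        double eta = stepSize / Math.sqrt(iteration);
        double[] updated = new double[params.length];
        for (int i = 0; i < params.length; i++) {
            updated[i] = params[i] - eta * (gradient[i] + regParam * params[i]);
        }
        return updated;
    }
}
```

Under this assumed schedule, the configured step size applies in full only on the first iteration and shrinks afterwards, while a larger regParam pulls every parameter more strongly toward zero on each update.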