Class GradientBoostedTrees

Object
org.apache.spark.mllib.tree.GradientBoostedTrees
All Implemented Interfaces:
Serializable, org.apache.spark.internal.Logging, scala.Serializable

public class GradientBoostedTrees extends Object implements scala.Serializable, org.apache.spark.internal.Logging
A class that implements Stochastic Gradient Boosting for regression and binary classification.

The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.

Notes on Gradient Boosting vs. TreeBoost:
- This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
- Both algorithms learn tree ensembles by minimizing loss functions.
- TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes based on the loss function, whereas the original gradient boosting method does not.
- When the loss is SquaredError, these methods give the same result, but they could differ for other loss functions.
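
The notes above can be illustrated with a minimal, single-machine sketch of gradient boosting under squared-error loss. It is not Spark's implementation: the class SquaredErrorBoostingSketch, its Stump type, and every method name are hypothetical, the base learners are depth-1 trees on a single feature, and the row subsampling that makes the method stochastic is omitted. Each iteration fits a stump to the current residuals, which for squared error are proportional to the negative gradient of the loss, and adds it to the ensemble scaled by a learning rate.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical single-machine sketch; this is NOT Spark's implementation.
final class SquaredErrorBoostingSketch {

    /** Depth-1 regression tree on a single numeric feature. */
    static final class Stump {
        final double threshold;
        final double leftValue;
        final double rightValue;

        Stump(double threshold, double leftValue, double rightValue) {
            this.threshold = threshold;
            this.leftValue = leftValue;
            this.rightValue = rightValue;
        }

        double predict(double x) {
            return x <= threshold ? leftValue : rightValue;
        }
    }

    /** Least-squares stump fit: tries each observed x as a split threshold. */
    static Stump fitStump(double[] xs, double[] targets) {
        double total = 0.0;
        for (double t : targets) total += t;
        // Fallback when no threshold separates the data: predict the mean everywhere.
        Stump best = new Stump(Double.NEGATIVE_INFINITY, 0.0, total / targets.length);
        double bestSse = Double.POSITIVE_INFINITY;
        for (double threshold : xs) {
            double leftSum = 0.0, rightSum = 0.0;
            int leftCount = 0, rightCount = 0;
            for (int i = 0; i < xs.length; i++) {
                if (xs[i] <= threshold) { leftSum += targets[i]; leftCount++; }
                else { rightSum += targets[i]; rightCount++; }
            }
            if (leftCount == 0 || rightCount == 0) continue;
            // Each leaf predicts the mean target of the points that fall into it.
            Stump candidate = new Stump(threshold, leftSum / leftCount, rightSum / rightCount);
            double sse = 0.0;
            for (int i = 0; i < xs.length; i++) {
                double error = targets[i] - candidate.predict(xs[i]);
                sse += error * error;
            }
            if (sse < bestSse) { bestSse = sse; best = candidate; }
        }
        return best;
    }

    /** Boosting loop: every new stump is fit to the residuals of the current ensemble. */
    static List<Stump> fit(double[] xs, double[] ys, int numIterations, double learningRate) {
        double[] predictions = new double[xs.length]; // the ensemble starts at 0
        List<Stump> ensemble = new ArrayList<>();
        for (int m = 0; m < numIterations; m++) {
            double[] residuals = new double[xs.length];
            for (int i = 0; i < xs.length; i++) residuals[i] = ys[i] - predictions[i];
            Stump stump = fitStump(xs, residuals);
            ensemble.add(stump);
            for (int i = 0; i < xs.length; i++) {
                predictions[i] += learningRate * stump.predict(xs[i]);
            }
        }
        return ensemble;
    }

    /** Ensemble prediction: sum of stump outputs, each scaled by the learning rate. */
    static double predict(List<Stump> ensemble, double learningRate, double x) {
        double sum = 0.0;
        for (Stump stump : ensemble) sum += learningRate * stump.predict(x);
        return sum;
    }

    static double meanSquaredError(List<Stump> ensemble, double learningRate,
                                   double[] xs, double[] ys) {
        double sse = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double error = ys[i] - predict(ensemble, learningRate, xs[i]);
            sse += error * error;
        }
        return sse / xs.length;
    }

    public static void main(String[] args) {
        double[] xs = {1, 2, 3, 4, 5, 6};
        double[] ys = {1, 1, 1, 5, 5, 5};
        for (int iterations : new int[] {1, 10}) {
            List<Stump> model = fit(xs, ys, iterations, 0.5);
            System.out.println(iterations + " iteration(s): MSE = "
                + meanSquaredError(model, 0.5, xs, ys));
        }
    }
}
```

In this sketch each leaf already predicts the mean residual of its region, which is also the value that minimizes squared error within that leaf. That is one way to see why a per-leaf adjustment in the style of TreeBoost would not change the result for this particular loss.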

param: boostingStrategy Parameters for the gradient boosting algorithm.
param: seed Random seed.

  • Constructor Details

    • GradientBoostedTrees

      public GradientBoostedTrees(BoostingStrategy boostingStrategy)
      Parameters:
      boostingStrategy - Parameters for the gradient boosting algorithm.
  • Method Details

    • train

      public static GradientBoostedTreesModel train(RDD<LabeledPoint> input, BoostingStrategy boostingStrategy)
      Method to train a gradient boosting model.

      Parameters:
      input - Training dataset: RDD of LabeledPoint. For classification, labels should take values {0, 1, ..., numClasses-1}. For regression, labels are real numbers.
      boostingStrategy - Configuration options for the boosting algorithm.
      Returns:
      GradientBoostedTreesModel that can be used for prediction.
    • train

      public static GradientBoostedTreesModel train(JavaRDD<LabeledPoint> input, BoostingStrategy boostingStrategy)
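      Java-friendly API for org.apache.spark.mllib.tree.GradientBoostedTrees.train.
      Parameters:
      input - Training dataset: JavaRDD of LabeledPoint. For classification, labels should take values {0, 1, ..., numClasses-1}. For regression, labels are real numbers.
      boostingStrategy - Configuration options for the boosting algorithm.
      Returns:
      GradientBoostedTreesModel that can be used for prediction.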
    • org$apache$spark$internal$Logging$$log_

      public static org.slf4j.Logger org$apache$spark$internal$Logging$$log_()
    • org$apache$spark$internal$Logging$$log__$eq

      public static void org$apache$spark$internal$Logging$$log__$eq(org.slf4j.Logger x$1)
    • run

      public GradientBoostedTreesModel run(RDD<LabeledPoint> input)
      Method to train a gradient boosting model.

      Parameters:
      input - Training dataset: RDD of LabeledPoint.
      Returns:
      GradientBoostedTreesModel that can be used for prediction.
    • run

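      public GradientBoostedTreesModel run(JavaRDD<LabeledPoint> input)
      Java-friendly API for org.apache.spark.mllib.tree.GradientBoostedTrees.run.
      Parameters:
      input - Training dataset: JavaRDD of LabeledPoint.
      Returns:
      GradientBoostedTreesModel that can be used for prediction.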
    • runWithValidation

      public GradientBoostedTreesModel runWithValidation(RDD<LabeledPoint> input, RDD<LabeledPoint> validationInput)
      Method to train a gradient boosting model, using a validation dataset to decide when to stop adding trees.

      Parameters:
      input - Training dataset: RDD of LabeledPoint.
      validationInput - Validation dataset. This dataset should be different from the training dataset, but it should follow the same distribution. For example, the two datasets could be created from one original dataset by using org.apache.spark.rdd.RDD.randomSplit().
      Returns:
      GradientBoostedTreesModel that can be used for prediction.
    • runWithValidation

      public GradientBoostedTreesModel runWithValidation(JavaRDD<LabeledPoint> input, JavaRDD<LabeledPoint> validationInput)
      Java-friendly API for org.apache.spark.mllib.tree.GradientBoostedTrees.runWithValidation.
      Parameters:
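      input - Training dataset: JavaRDD of LabeledPoint.
      validationInput - Validation dataset. This dataset should be different from the training dataset, but it should follow the same distribution.
      Returns:
      GradientBoostedTreesModel that can be used for prediction.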
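
The validationInput description for runWithValidation asks for a validation dataset that is disjoint from the training dataset but drawn from the same distribution. The sketch below shows one way such a pair could be produced in memory, by sending each record independently to one side or the other with a seeded random draw. It is only an analogue of what a random split does: the class RandomSplitSketch, its Split type, and the split method are hypothetical names, and no Spark API is used.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical in-memory analogue of a random split; this is NOT a Spark API.
final class RandomSplitSketch {

    /** Holds the two disjoint parts produced by a split. */
    static final class Split<T> {
        final List<T> training = new ArrayList<>();
        final List<T> validation = new ArrayList<>();
    }

    /**
     * Sends each record to the training part with probability trainingFraction,
     * otherwise to the validation part. The same seed yields the same split.
     */
    static <T> Split<T> split(List<T> data, double trainingFraction, long seed) {
        if (trainingFraction <= 0.0 || trainingFraction >= 1.0) {
            throw new IllegalArgumentException("trainingFraction must be in (0, 1)");
        }
        Random rng = new Random(seed);
        Split<T> result = new Split<>();
        for (T record : data) {
            if (rng.nextDouble() < trainingFraction) {
                result.training.add(record);
            } else {
                result.validation.add(record);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>();
        for (int i = 0; i < 1000; i++) data.add(i);
        Split<Integer> parts = split(data, 0.8, 42L);
        System.out.println("training=" + parts.training.size()
            + ", validation=" + parts.validation.size());
    }
}
```

Because every record is drawn independently, the part sizes only approximate the requested fraction. Both parts still come from the same underlying data, which is the property the validationInput description asks for.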