Class RandomForestClassifier
Object
org.apache.spark.ml.PipelineStage
org.apache.spark.ml.Estimator<M>
org.apache.spark.ml.Predictor<FeaturesType,E,M>
org.apache.spark.ml.classification.Classifier<FeaturesType,E,M>
org.apache.spark.ml.classification.ProbabilisticClassifier<Vector,RandomForestClassifier,RandomForestClassificationModel>
org.apache.spark.ml.classification.RandomForestClassifier
- All Implemented Interfaces:
Serializable, org.apache.spark.internal.Logging, ClassifierParams, ProbabilisticClassifierParams, Params, HasCheckpointInterval, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, HasSeed, HasThresholds, HasWeightCol, PredictorParams, DecisionTreeParams, RandomForestClassifierParams, RandomForestParams, TreeClassifierParams, TreeEnsembleClassifierParams, TreeEnsembleParams, DefaultParamsWritable, Identifiable, MLWritable
public class RandomForestClassifier
extends ProbabilisticClassifier<Vector,RandomForestClassifier,RandomForestClassificationModel>
implements RandomForestClassifierParams, DefaultParamsWritable
Random Forest learning algorithm for classification. It supports both binary and multiclass labels, as well as both continuous and categorical features.
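As background (not Spark code), a trained random forest classifies an instance by combining its trees' per-class probability estimates and taking the most probable class. A minimal plain-Python sketch of that combination step, with hypothetical toy probability vectors as input:

```python
def forest_predict(tree_probs):
    """Average per-tree class-probability vectors and pick the argmax class.

    `tree_probs` is a list of probability vectors, one per tree (toy input;
    Spark derives these internally from each tree's leaf statistics).
    """
    n_trees = len(tree_probs)
    n_classes = len(tree_probs[0])
    # Average the probability assigned to each class across all trees.
    avg = [sum(p[c] for p in tree_probs) / n_trees for c in range(n_classes)]
    # Predicted label is the class with the highest averaged probability.
    return avg.index(max(avg)), avg

# Three trees voting over three classes: class 0 wins on average.
label, probs = forest_predict([[0.7, 0.2, 0.1], [0.4, 0.5, 0.1], [0.6, 0.3, 0.1]])
```

In the Spark API, the averaged vector and the argmax label correspond to the model's probabilityCol and predictionCol outputs.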
Nested Class Summary
Nested classes/interfaces inherited from interface org.apache.spark.internal.Logging
org.apache.spark.internal.Logging.LogStringContext, org.apache.spark.internal.Logging.SparkShellLoggingFilter -
Constructor Summary
Constructors
RandomForestClassifier()
Method Summary
- final BooleanParam bootstrap() - Whether bootstrap samples are used when building trees.
- final BooleanParam cacheNodeIds() - If false, the algorithm will pass trees to executors to match instances with nodes.
- final IntParam checkpointInterval() - Param for checkpoint interval (>= 1) or disable checkpoint (-1).
- RandomForestClassifier copy(ParamMap extra) - Creates a copy of this instance with the same UID and some extra params.
- final Param<String> featureSubsetStrategy() - The number of features to consider for splits at each tree node.
- final Param<String> impurity() - Criterion used for information gain calculation (case-insensitive).
- final Param<String> leafCol() - Leaf indices column name.
- static RandomForestClassifier load(String path)
- final IntParam maxBins() - Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node.
- final IntParam maxDepth() - Maximum depth of the tree (nonnegative).
- final IntParam maxMemoryInMB() - Maximum memory in MB allocated to histogram aggregation.
- final DoubleParam minInfoGain() - Minimum information gain for a split to be considered at a tree node.
- final IntParam minInstancesPerNode() - Minimum number of instances each child must have after split.
- final DoubleParam minWeightFractionPerNode() - Minimum fraction of the weighted sample count that each child must have after split.
- final IntParam numTrees() - Number of trees to train (at least 1).
- static MLReader<T> read()
- final LongParam seed() - Param for random seed.
- setBootstrap(boolean value)
- setCacheNodeIds(boolean value)
- setCheckpointInterval(int value) - Specifies how often to checkpoint the cached node IDs.
- setFeatureSubsetStrategy(String value)
- setImpurity(String value)
- setMaxBins(int value)
- setMaxDepth(int value)
- setMaxMemoryInMB(int value)
- setMinInfoGain(double value)
- setMinInstancesPerNode(int value)
- setMinWeightFractionPerNode(double value)
- setNumTrees(int value)
- setSeed(long value)
- setSubsamplingRate(double value)
- setWeightCol(String value) - Sets the value of param weightCol().
- final DoubleParam subsamplingRate() - Fraction of the training data used for learning each decision tree, in range (0, 1].
- static final String[] supportedFeatureSubsetStrategies() - Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2
- static final String[] supportedImpurities() - Accessor for supported impurity settings: entropy, gini
- String uid() - An immutable unique ID for the object and its derivatives.
- final Param<String> weightCol() - Param for weight column name.
Methods inherited from class org.apache.spark.ml.classification.ProbabilisticClassifier
probabilityCol, setProbabilityCol, setThresholds, thresholds
Methods inherited from class org.apache.spark.ml.classification.Classifier
rawPredictionCol, setRawPredictionCol
Methods inherited from class org.apache.spark.ml.Predictor
featuresCol, fit, labelCol, predictionCol, setFeaturesCol, setLabelCol, setPredictionCol, transformSchema
Methods inherited from class org.apache.spark.ml.PipelineStage
params
Methods inherited from class java.lang.Object
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface org.apache.spark.ml.tree.DecisionTreeParams
getCacheNodeIds, getLeafCol, getMaxBins, getMaxDepth, getMaxMemoryInMB, getMinInfoGain, getMinInstancesPerNode, getMinWeightFractionPerNode, getOldStrategy, setLeafCol
Methods inherited from interface org.apache.spark.ml.util.DefaultParamsWritable
write
Methods inherited from interface org.apache.spark.ml.param.shared.HasCheckpointInterval
getCheckpointInterval
Methods inherited from interface org.apache.spark.ml.param.shared.HasFeaturesCol
featuresCol, getFeaturesCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasLabelCol
getLabelCol, labelCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasPredictionCol
getPredictionCol, predictionCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasProbabilityCol
getProbabilityCol, probabilityCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasRawPredictionCol
getRawPredictionCol, rawPredictionCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasThresholds
getThresholds, thresholds
Methods inherited from interface org.apache.spark.ml.param.shared.HasWeightCol
getWeightCol
Methods inherited from interface org.apache.spark.ml.util.Identifiable
toString
Methods inherited from interface org.apache.spark.internal.Logging
initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, isTraceEnabled, log, logBasedOnLevel, logDebug, logDebug, logDebug, logDebug, logError, logError, logError, logError, logInfo, logInfo, logInfo, logInfo, logName, LogStringContext, logTrace, logTrace, logTrace, logTrace, logWarning, logWarning, logWarning, logWarning, org$apache$spark$internal$Logging$$log_, org$apache$spark$internal$Logging$$log__$eq, withLogContext
Methods inherited from interface org.apache.spark.ml.util.MLWritable
save
Methods inherited from interface org.apache.spark.ml.param.Params
clear, copyValues, defaultCopy, defaultParamMap, estimateMatadataSize, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, onParamChange, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
Methods inherited from interface org.apache.spark.ml.tree.RandomForestParams
getBootstrap, getNumTrees
Methods inherited from interface org.apache.spark.ml.tree.TreeClassifierParams
getImpurity, getOldImpurity
Methods inherited from interface org.apache.spark.ml.tree.TreeEnsembleClassifierParams
validateAndTransformSchema
Methods inherited from interface org.apache.spark.ml.tree.TreeEnsembleParams
getFeatureSubsetStrategy, getOldStrategy, getSubsamplingRate
-
Constructor Details
-
RandomForestClassifier
-
RandomForestClassifier
public RandomForestClassifier()
-
-
Method Details
-
supportedImpurities
Accessor for supported impurity settings: entropy, gini -
supportedFeatureSubsetStrategies
Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 -
load
-
read
-
impurity
Description copied from interface: TreeClassifierParams
Criterion used for information gain calculation (case-insensitive). This impurity type is used in DecisionTreeClassifier and RandomForestClassifier. Supported: "entropy" and "gini". (default = gini)
- Specified by:
impurity in interface TreeClassifierParams
- Returns:
- (undocumented)
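The two supported criteria have simple closed forms; a stdlib-only sketch (illustration, not Spark code) over a list of per-class counts, with entropy shown in log base 2:

```python
import math

def gini(counts):
    """Gini impurity: 1 - sum(p_i^2) over class proportions p_i."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)

def entropy(counts):
    """Entropy: -sum(p_i * log2(p_i)), skipping empty classes."""
    total = sum(counts)
    return -sum((c / total) * math.log2(c / total) for c in counts if c > 0)

# A pure node has zero impurity under both criteria; a 50/50 binary node
# maximizes both (0.5 for gini, 1.0 for entropy).
```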
-
numTrees
Description copied from interface: RandomForestParams
Number of trees to train (at least 1). If 1, then no bootstrapping is used. If greater than 1, then bootstrapping is done. TODO: Change to always do bootstrapping (simpler). SPARK-7130 (default = 20)
Note: The reason we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams) is that the param maxIter controls how many trees a GBT has. The semantics in the two algorithms are a bit different.
- Specified by:
numTrees in interface RandomForestParams
- Returns:
- (undocumented)
-
bootstrap
Description copied from interface: RandomForestParams
Whether bootstrap samples are used when building trees.
- Specified by:
bootstrap in interface RandomForestParams
- Returns:
- (undocumented)
-
subsamplingRate
Description copied from interface: TreeEnsembleParams
Fraction of the training data used for learning each decision tree, in range (0, 1]. (default = 1.0)
- Specified by:
subsamplingRate in interface TreeEnsembleParams
- Returns:
- (undocumented)
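To illustrate how subsamplingRate and bootstrap interact, here is a simplified sketch of drawing one tree's training sample: roughly subsamplingRate * n draws with replacement. This is an assumption-laden simplification; Spark's actual implementation reweights rows (e.g. with Poisson counts) rather than materializing copies.

```python
import random

def bootstrap_sample(rows, subsampling_rate, seed):
    """Draw one tree's training sample with replacement.

    Simplified illustration: makes round(subsampling_rate * len(rows))
    independent draws. Spark instead assigns each row a random weight,
    but the expected sample size is the same.
    """
    rng = random.Random(seed)  # per-tree RNG, seeded for reproducibility
    k = round(subsampling_rate * len(rows))
    return [rng.choice(rows) for _ in range(k)]

# With rate 0.8, a 100-row dataset yields an 80-draw sample (with repeats).
sample = bootstrap_sample(list(range(100)), 0.8, seed=42)
```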
-
featureSubsetStrategy
Description copied from interface: TreeEnsembleParams
The number of features to consider for splits at each tree node. Supported options:
- "auto": Choose automatically for the task: if numTrees == 1, set to "all"; if numTrees > 1 (forest), set to "sqrt" for classification and to "onethird" for regression.
- "all": use all features
- "onethird": use 1/3 of the features
- "sqrt": use sqrt(number of features)
- "log2": use log2(number of features)
- "n": when n is in the range (0, 1.0], use n * number of features; when n is in the range (1, number of features), use n features.
(default = "auto")
These settings are based on the following references:
- log2: tested in Breiman (2001)
- sqrt: recommended by Breiman's manual for random forests
- The defaults of sqrt (classification) and onethird (regression) match the R randomForest package.
- Specified by:
featureSubsetStrategy in interface TreeEnsembleParams
- Returns:
- (undocumented)
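The option table above can be resolved to a concrete feature count; a plain-Python sketch of that resolution (the ceil rounding here is illustrative, and Spark's exact rounding may differ):

```python
import math

def features_per_split(strategy, num_features, num_trees, classification=True):
    """Map a featureSubsetStrategy name to a per-split feature count."""
    if strategy == "auto":
        if num_trees == 1:
            strategy = "all"
        else:
            # forest: sqrt for classification, onethird for regression
            strategy = "sqrt" if classification else "onethird"
    if strategy == "all":
        return num_features
    if strategy == "onethird":
        return max(1, math.ceil(num_features / 3))
    if strategy == "sqrt":
        return max(1, math.ceil(math.sqrt(num_features)))
    if strategy == "log2":
        return max(1, math.ceil(math.log2(num_features)))
    # numeric "n": a fraction if <= 1.0, otherwise an absolute count
    n = float(strategy)
    return max(1, math.ceil(n * num_features)) if n <= 1.0 else int(n)
```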
-
leafCol
Description copied from interface: DecisionTreeParams
Leaf indices column name. Predicted leaf index of each instance in each tree, by preorder. (default = "")
- Specified by:
leafCol in interface DecisionTreeParams
- Returns:
- (undocumented)
-
maxDepth
Description copied from interface: DecisionTreeParams
Maximum depth of the tree (nonnegative). E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (default = 5)
- Specified by:
maxDepth in interface DecisionTreeParams
- Returns:
- (undocumented)
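The depth examples above generalize: a binary tree of depth d has at most 2^d leaves and 2^(d+1) - 1 nodes in total, which is worth keeping in mind when raising maxDepth. A quick illustration:

```python
def max_leaves(depth):
    """Upper bound on leaf nodes of a binary tree with the given maxDepth."""
    return 2 ** depth

def max_nodes(depth):
    """Upper bound on total nodes (internal + leaves): 2^(depth+1) - 1."""
    return 2 ** (depth + 1) - 1

# depth 0: one leaf; depth 1: one internal node plus two leaves (3 total);
# the default depth of 5 allows up to 63 nodes per tree.
```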
-
maxBins
Description copied from interface: DecisionTreeParams
Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node. More bins give higher granularity. Must be at least 2, and at least the number of categories in any categorical feature. (default = 32)
- Specified by:
maxBins in interface DecisionTreeParams
- Returns:
- (undocumented)
-
minInstancesPerNode
Description copied from interface: DecisionTreeParams
Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Must be at least 1. (default = 1)
- Specified by:
minInstancesPerNode in interface DecisionTreeParams
- Returns:
- (undocumented)
-
minWeightFractionPerNode
Description copied from interface: DecisionTreeParams
Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in the interval [0.0, 0.5). (default = 0.0)
- Specified by:
minWeightFractionPerNode in interface DecisionTreeParams
- Returns:
- (undocumented)
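The two child-size constraints (minInstancesPerNode above and minWeightFractionPerNode here) can be checked together for a candidate split. A hypothetical sketch, where `left` and `right` are lists of per-instance weights (Spark performs this check internally during training):

```python
def split_is_valid(left, right, min_instances=1, min_weight_fraction=0.0):
    """Return True iff both children satisfy the size constraints.

    `left`/`right`: per-instance weights in each child (toy inputs).
    """
    total_weight = sum(left) + sum(right)
    for child in (left, right):
        if len(child) < min_instances:
            return False  # too few instances in this child
        if sum(child) / total_weight < min_weight_fraction:
            return False  # child's weight share is below the floor
    return True
```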
-
minInfoGain
Description copied from interface: DecisionTreeParams
Minimum information gain for a split to be considered at a tree node. Should be at least 0.0. (default = 0.0)
- Specified by:
minInfoGain in interface DecisionTreeParams
- Returns:
- (undocumented)
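Information gain is the weighted impurity decrease of a candidate split; this self-contained sketch (illustration, not Spark code) shows the formula with gini impurity, though the same shape applies with entropy:

```python
def gini_impurity(counts):
    """Gini impurity: 1 - sum(p_i^2) over class proportions."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)

def info_gain(parent, left, right):
    """Impurity(parent) minus the size-weighted impurities of the children.

    `parent`, `left`, `right` are per-class instance counts.
    """
    n, nl, nr = sum(parent), sum(left), sum(right)
    return (gini_impurity(parent)
            - (nl / n) * gini_impurity(left)
            - (nr / n) * gini_impurity(right))

# A perfectly separating split of a 50/50 node gains 0.5;
# a split that leaves both children at the parent's mix gains 0.
```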
-
maxMemoryInMB
Description copied from interface: DecisionTreeParams
Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size. (default = 256 MB)
- Specified by:
maxMemoryInMB in interface DecisionTreeParams
- Returns:
- (undocumented)
-
cacheNodeIds
Description copied from interface: DecisionTreeParams
If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often the cache should be checkpointed, or disable it, by setting checkpointInterval. (default = false)
- Specified by:
cacheNodeIds in interface DecisionTreeParams
- Returns:
- (undocumented)
-
weightCol
Description copied from interface: HasWeightCol
Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0.
- Specified by:
weightCol in interface HasWeightCol
- Returns:
- (undocumented)
-
seed
Description copied from interface: HasSeed
Param for random seed.
-
checkpointInterval
Description copied from interface: HasCheckpointInterval
Param for the checkpoint interval (>= 1), or -1 to disable checkpointing. E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.
- Specified by:
checkpointInterval in interface HasCheckpointInterval
- Returns:
- (undocumented)
-
uid
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
- Specified by:
uid in interface Identifiable
- Returns:
- (undocumented)
-
setMaxDepth
-
setMaxBins
-
setMinInstancesPerNode
-
setMinWeightFractionPerNode
-
setMinInfoGain
-
setMaxMemoryInMB
-
setCacheNodeIds
-
setCheckpointInterval
Specifies how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the checkpoint directory is set in SparkContext. Must be at least 1. (default = 10)
- Parameters:
value - (undocumented)
- Returns:
- (undocumented)
-
setImpurity
-
setSubsamplingRate
-
setSeed
-
setNumTrees
-
setBootstrap
-
setFeatureSubsetStrategy
-
setWeightCol
Sets the value of param weightCol(). If this is not set or empty, we treat all instance weights as 1.0. By default weightCol is not set, so all instances have weight 1.0.
- Parameters:
value - (undocumented)
- Returns:
- (undocumented)
-
copy
Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. Subclasses should implement this method and set the return type properly. See defaultCopy().
- Specified by:
copy in interface Params
- Specified by:
copy in class Predictor<Vector,RandomForestClassifier,RandomForestClassificationModel>
- Parameters:
extra - (undocumented)
- Returns:
- (undocumented)
-