Class DecisionTreeClassificationModel
Object
org.apache.spark.ml.PipelineStage
org.apache.spark.ml.Transformer
org.apache.spark.ml.Model<M>
org.apache.spark.ml.PredictionModel<FeaturesType,M>
org.apache.spark.ml.classification.ClassificationModel<FeaturesType,M>
org.apache.spark.ml.classification.ProbabilisticClassificationModel<Vector,DecisionTreeClassificationModel>
org.apache.spark.ml.classification.DecisionTreeClassificationModel
- All Implemented Interfaces:
Serializable, org.apache.spark.internal.Logging, ClassifierParams, ProbabilisticClassifierParams, Params, HasCheckpointInterval, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, HasSeed, HasThresholds, HasWeightCol, PredictorParams, DecisionTreeClassifierParams, DecisionTreeModel, DecisionTreeParams, TreeClassifierParams, Identifiable, MLWritable
public class DecisionTreeClassificationModel
extends ProbabilisticClassificationModel<Vector,DecisionTreeClassificationModel>
implements DecisionTreeModel, DecisionTreeClassifierParams, MLWritable, Serializable
Decision tree model (http://en.wikipedia.org/wiki/Decision_tree_learning) for classification.
It supports both binary and multiclass labels, as well as both continuous and categorical
features.
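Example (a minimal sketch, not taken from this page): a DecisionTreeClassificationModel is produced by fitting a DecisionTreeClassifier estimator; the DataFrame training with "label" and "features" columns is an assumption made for illustration.

  // Sketch: training a DecisionTreeClassifier to obtain the model documented here.
  // "training" is an assumed DataFrame with "label" and "features" columns.
  import org.apache.spark.ml.classification.{DecisionTreeClassifier, DecisionTreeClassificationModel}

  val dt = new DecisionTreeClassifier()
    .setLabelCol("label")
    .setFeaturesCol("features")
    .setMaxDepth(5)

  // fit() returns the trained DecisionTreeClassificationModel described on this page.
  val model: DecisionTreeClassificationModel = dt.fit(training)

  println(s"Learned a tree of depth ${model.depth} with ${model.numNodes} nodes")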
Nested Class Summary
Nested classes/interfaces inherited from interface org.apache.spark.internal.Logging
org.apache.spark.internal.Logging.LogStringContext, org.apache.spark.internal.Logging.SparkShellLoggingFilter
-
Method Summary
final BooleanParam cacheNodeIds()
  If false, the algorithm will pass trees to executors to match instances with nodes.
final IntParam checkpointInterval()
  Param for set checkpoint interval (>= 1) or disable checkpoint (-1).
DecisionTreeClassificationModel copy(ParamMap extra)
  Creates a copy of this instance with the same UID and some extra params.
int depth()
  Depth of the tree.
long estimatedSize()
Vector featureImportances()
Param<String> impurity()
  Criterion used for information gain calculation (case-insensitive).
Param<String> leafCol()
  Leaf indices column name.
final IntParam maxBins()
  Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node.
final IntParam maxDepth()
  Maximum depth of the tree (nonnegative).
final IntParam maxMemoryInMB()
  Maximum memory in MB allocated to histogram aggregation.
final DoubleParam minInfoGain()
  Minimum information gain for a split to be considered at a tree node.
final IntParam minInstancesPerNode()
  Minimum number of instances each child must have after split.
final DoubleParam minWeightFractionPerNode()
  Minimum fraction of the weighted sample count that each child must have after split.
int numClasses()
  Number of classes (values which the label can take).
int numFeatures()
  Returns the number of features the model was trained on.
double predict(Vector features)
  Predict label for the given features.
Vector predictRaw(Vector features)
  Raw prediction for each possible label.
static MLReader<DecisionTreeClassificationModel> read()
Node rootNode()
  Root of the decision tree
final LongParam seed()
  Param for random seed.
String toString()
  Summary of the model
Dataset<Row> transform(Dataset<?> dataset)
  Transforms dataset by reading from PredictionModel.featuresCol() and appending new columns as specified by parameters: predicted labels as PredictionModel.predictionCol() of type Double, raw predictions (confidences) as ClassificationModel.rawPredictionCol() of type Vector, and probability of each class as ProbabilisticClassificationModel.probabilityCol() of type Vector.
StructType transformSchema(StructType schema)
  Check transform validity and derive the output schema from the input schema.
String uid()
  An immutable unique ID for the object and its derivatives.
Param<String> weightCol()
  Param for weight column name.
MLWriter write()
  Returns an MLWriter instance for this ML instance.
Methods inherited from class org.apache.spark.ml.classification.ProbabilisticClassificationModel
normalizeToProbabilitiesInPlace, predictProbability, probabilityCol, setProbabilityCol, setThresholds, thresholds
Methods inherited from class org.apache.spark.ml.classification.ClassificationModel
rawPredictionCol, setRawPredictionCol, transformImpl
Methods inherited from class org.apache.spark.ml.PredictionModel
featuresCol, labelCol, predictionCol, setFeaturesCol, setPredictionCol
Methods inherited from class org.apache.spark.ml.Transformer
transform, transform, transform
Methods inherited from class org.apache.spark.ml.PipelineStage
params
Methods inherited from class java.lang.Object
equals, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface org.apache.spark.ml.tree.DecisionTreeClassifierParams
validateAndTransformSchema
Methods inherited from interface org.apache.spark.ml.tree.DecisionTreeModel
getEstimatedSize, getLeafField, leafIterator, maxSplitFeatureIndex, numNodes, predictLeaf, toDebugString
Methods inherited from interface org.apache.spark.ml.tree.DecisionTreeParams
getCacheNodeIds, getLeafCol, getMaxBins, getMaxDepth, getMaxMemoryInMB, getMinInfoGain, getMinInstancesPerNode, getMinWeightFractionPerNode, getOldStrategy, setLeafCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasCheckpointInterval
getCheckpointInterval
Methods inherited from interface org.apache.spark.ml.param.shared.HasFeaturesCol
featuresCol, getFeaturesCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasLabelCol
getLabelCol, labelCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasPredictionCol
getPredictionCol, predictionCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasProbabilityCol
getProbabilityCol, probabilityCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasRawPredictionCol
getRawPredictionCol, rawPredictionCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasThresholds
getThresholds, thresholds
Methods inherited from interface org.apache.spark.ml.param.shared.HasWeightCol
getWeightCol
Methods inherited from interface org.apache.spark.internal.Logging
initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, isTraceEnabled, log, logBasedOnLevel, logDebug, logDebug, logDebug, logDebug, logError, logError, logError, logError, logInfo, logInfo, logInfo, logInfo, logName, LogStringContext, logTrace, logTrace, logTrace, logTrace, logWarning, logWarning, logWarning, logWarning, org$apache$spark$internal$Logging$$log_, org$apache$spark$internal$Logging$$log__$eq, withLogContext
Methods inherited from interface org.apache.spark.ml.util.MLWritable
save
Methods inherited from interface org.apache.spark.ml.param.Params
clear, copyValues, defaultCopy, defaultParamMap, estimateMatadataSize, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, onParamChange, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
Methods inherited from interface org.apache.spark.ml.tree.TreeClassifierParams
getImpurity, getOldImpurity
-
Method Details
-
read
public static MLReader<DecisionTreeClassificationModel> read()
-
load
public static DecisionTreeClassificationModel load(String path)
-
impurity
Description copied from interface: TreeClassifierParams
Criterion used for information gain calculation (case-insensitive). This impurity type is used in DecisionTreeClassifier and RandomForestClassifier. Supported: "entropy" and "gini". (default = gini)
- Specified by:
impurity in interface TreeClassifierParams
- Returns:
- (undocumented)
-
leafCol
Description copied from interface: DecisionTreeParams
Leaf indices column name. Predicted leaf index of each instance in each tree by preorder. (default = "")
- Specified by:
leafCol in interface DecisionTreeParams
- Returns:
- (undocumented)
-
maxDepth
Description copied from interface: DecisionTreeParams
Maximum depth of the tree (nonnegative). E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (default = 5)
- Specified by:
maxDepth in interface DecisionTreeParams
- Returns:
- (undocumented)
-
maxBins
Description copied from interface: DecisionTreeParams
Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node. More bins give higher granularity. Must be at least 2 and at least the number of categories in any categorical feature. (default = 32)
- Specified by:
maxBins in interface DecisionTreeParams
- Returns:
- (undocumented)
-
minInstancesPerNode
Description copied from interface: DecisionTreeParams
Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Must be at least 1. (default = 1)
- Specified by:
minInstancesPerNode in interface DecisionTreeParams
- Returns:
- (undocumented)
-
minWeightFractionPerNode
Description copied from interface: DecisionTreeParams
Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in the interval [0.0, 0.5). (default = 0.0)
- Specified by:
minWeightFractionPerNode in interface DecisionTreeParams
- Returns:
- (undocumented)
-
minInfoGain
Description copied from interface: DecisionTreeParams
Minimum information gain for a split to be considered at a tree node. Should be at least 0.0. (default = 0.0)
- Specified by:
minInfoGain in interface DecisionTreeParams
- Returns:
- (undocumented)
-
maxMemoryInMB
Description copied from interface: DecisionTreeParams
Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size. (default = 256 MB)
- Specified by:
maxMemoryInMB in interface DecisionTreeParams
- Returns:
- (undocumented)
-
cacheNodeIds
Description copied from interface: DecisionTreeParams
If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often the cache should be checkpointed, or disable checkpointing, by setting checkpointInterval. (default = false)
- Specified by:
cacheNodeIds in interface DecisionTreeParams
- Returns:
- (undocumented)
-
weightCol
Description copied from interface: HasWeightCol
Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0.
- Specified by:
weightCol in interface HasWeightCol
- Returns:
- (undocumented)
-
seed
Description copied from interface: HasSeed
Param for random seed.
-
checkpointInterval
Description copied from interface: HasCheckpointInterval
Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.
- Specified by:
checkpointInterval in interface HasCheckpointInterval
- Returns:
- (undocumented)
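The tree and shared params documented above (maxDepth, maxBins, minInstancesPerNode, minInfoGain, impurity, cacheNodeIds, checkpointInterval, weightCol, seed) are normally set on the DecisionTreeClassifier estimator before fitting; the fitted model carries the chosen values. A minimal sketch, assuming a training DataFrame named training with "label", "features" and "weight" columns:

  import org.apache.spark.ml.classification.DecisionTreeClassifier

  val dt = new DecisionTreeClassifier()
    .setMaxDepth(4)               // maxDepth
    .setMaxBins(64)               // maxBins
    .setMinInstancesPerNode(2)    // minInstancesPerNode
    .setMinInfoGain(0.0)          // minInfoGain
    .setImpurity("gini")          // impurity: "gini" or "entropy"
    .setCacheNodeIds(true)        // cacheNodeIds
    .setCheckpointInterval(10)    // checkpointInterval (requires a SparkContext checkpoint dir)
    .setWeightCol("weight")       // weightCol
    .setSeed(42L)                 // seed

  val model = dt.fit(training)
  println(model.getMaxDepth)      // the fitted model exposes the same params via its getters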
-
depth
public int depth()
Description copied from interface: DecisionTreeModel
Depth of the tree. E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes.
- Specified by:
depth in interface DecisionTreeModel
- Returns:
- (undocumented)
-
uid
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
- Specified by:
uid in interface Identifiable
- Returns:
- (undocumented)
-
rootNode
Description copied from interface: DecisionTreeModel
Root of the decision tree
- Specified by:
rootNode in interface DecisionTreeModel
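A sketch of inspecting the learned tree through rootNode() and the DecisionTreeModel helpers; model is an assumed, already fitted DecisionTreeClassificationModel:

  import org.apache.spark.ml.tree.{InternalNode, LeafNode}

  println(s"depth = ${model.depth}, numNodes = ${model.numNodes}")

  model.rootNode match {
    case n: InternalNode =>
      println(s"Root splits on feature ${n.split.featureIndex} (gain = ${n.gain})")
    case l: LeafNode =>
      println(s"Tree is a single leaf predicting ${l.prediction}")
  }

  // Full text rendering of the learned tree:
  println(model.toDebugString)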
-
numFeatures
public int numFeatures()
Description copied from class: PredictionModel
Returns the number of features the model was trained on. If unknown, returns -1.
- Overrides:
numFeatures in class PredictionModel<Vector,DecisionTreeClassificationModel>
-
numClasses
public int numClasses()
Description copied from class: ClassificationModel
Number of classes (values which the label can take).
- Specified by:
numClasses in class ClassificationModel<Vector,DecisionTreeClassificationModel>
-
estimatedSize
public long estimatedSize()
-
predict
Description copied from class: ClassificationModel
Predict label for the given features. This method is used to implement transform() and output PredictionModel.predictionCol().
This default implementation for classification predicts the index of the maximum value from predictRaw().
- Overrides:
predict in class ClassificationModel<Vector,DecisionTreeClassificationModel>
- Parameters:
features - (undocumented)
- Returns:
- (undocumented)
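A sketch of single-vector scoring with the trained model; model and the feature values are assumptions for illustration (predictRaw and predictProbability accept the same input):

  import org.apache.spark.ml.linalg.Vectors

  val features = Vectors.dense(0.0, 1.0, 2.5)

  val label: Double = model.predict(features)      // index of the predicted class
  val raw = model.predictRaw(features)             // per-class confidences
  val prob = model.predictProbability(features)    // normalized class probabilities

  println(s"label=$label raw=$raw prob=$prob")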
-
transformSchema
Description copied from class: PipelineStage
Check transform validity and derive the output schema from the input schema.
We check validity for interactions between parameters during transformSchema and raise an exception if any parameter value is invalid. Parameter value checks which do not depend on other parameters are handled by Param.validate().
Typical implementations should first conduct verification on schema change and parameter validity, including complex parameter interaction checks.
- Overrides:
transformSchema in class ProbabilisticClassificationModel<Vector,DecisionTreeClassificationModel>
- Parameters:
schema - (undocumented)
- Returns:
- (undocumented)
-
transform
Description copied from class: ProbabilisticClassificationModel
Transforms dataset by reading from PredictionModel.featuresCol() and appending new columns as specified by parameters:
- predicted labels as PredictionModel.predictionCol() of type Double
- raw predictions (confidences) as ClassificationModel.rawPredictionCol() of type Vector
- probability of each class as ProbabilisticClassificationModel.probabilityCol() of type Vector
- Overrides:
transform in class ProbabilisticClassificationModel<Vector,DecisionTreeClassificationModel>
- Parameters:
dataset - input dataset
- Returns:
- transformed dataset
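A sketch of batch scoring with transform(); testData is an assumed DataFrame with the same featuresCol as the training data:

  val predictions = model.transform(testData)

  predictions
    .select("features", "prediction", "rawPrediction", "probability")
    .show(5, truncate = false)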
-
predictRaw
Description copied from class: ClassificationModel
Raw prediction for each possible label. The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives a measure of confidence in each possible label (where larger = more confident). This internal method is used to implement transform() and output ClassificationModel.rawPredictionCol().
- Specified by:
predictRaw in class ClassificationModel<Vector,DecisionTreeClassificationModel>
- Parameters:
features - (undocumented)
- Returns:
- vector where element i is the raw prediction for label i. This raw prediction may be any real number, where a larger value indicates greater confidence for that label.
-
copy
Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. Subclasses should implement this method and set the return type properly. See defaultCopy().
- Specified by:
copy in interface Params
- Specified by:
copy in class Model<DecisionTreeClassificationModel>
- Parameters:
extra - (undocumented)
- Returns:
- (undocumented)
-
toString
Description copied from interface: DecisionTreeModel
Summary of the model
- Specified by:
toString in interface DecisionTreeModel
- Specified by:
toString in interface Identifiable
- Overrides:
toString in class Object
-
featureImportances
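A sketch of reading per-feature importances from the fitted model (featureImportances returns a Vector whose entries sum to 1.0); model and the printed indices are assumptions for illustration:

  val importances = model.featureImportances      // org.apache.spark.ml.linalg.Vector

  importances.toArray.zipWithIndex
    .sortBy { case (imp, _) => -imp }             // most important features first
    .take(3)
    .foreach { case (imp, idx) => println(f"feature $idx%d: importance $imp%.3f") }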
-
write
Description copied from interface: MLWritable
Returns an MLWriter instance for this ML instance.
- Specified by:
write in interface MLWritable
- Returns:
- (undocumented)
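A sketch of persisting and reloading the model via the MLWritable/MLReadable API; the path is a placeholder:

  import org.apache.spark.ml.classification.DecisionTreeClassificationModel

  model.write.overwrite().save("/tmp/dtc-model")

  val restored = DecisionTreeClassificationModel.load("/tmp/dtc-model")
  println(restored.toDebugString)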
-