Package org.apache.spark.ml.ann
Interface TopologyModel
- All Superinterfaces:
Serializable
Trait for an ANN (artificial neural network) topology model.
-
Method Summary
- double computeGradient(breeze.linalg.DenseMatrix<Object> data, breeze.linalg.DenseMatrix<Object> target, Vector cumGradient, int blockSize): Computes the gradient for the network.
- breeze.linalg.DenseMatrix<Object>[] forward(breeze.linalg.DenseMatrix<Object> data, boolean includeLastLayer): Forward propagation.
- LayerModel[] layerModels(): Array of layer models.
- Layer[] layers(): Array of layers.
- Vector predict(Vector features): Prediction of the model.
- Vector predictRaw(Vector features): Raw prediction of the model.
- Vector raw2ProbabilityInPlace(Vector rawPrediction): Probability of the model.
- Vector weights()
-
Method Details
-
computeGradient
double computeGradient(breeze.linalg.DenseMatrix<Object> data, breeze.linalg.DenseMatrix<Object> target, Vector cumGradient, int blockSize)
Computes the gradient for the network.
- Parameters:
data - input data
target - target output
cumGradient - cumulative gradient
blockSize - block size
- Returns:
error (the loss of the network on the given data)
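The computation behind computeGradient can be illustrated with a minimal Java sketch. It assumes a network made of a single affine layer followed by softmax with cross-entropy loss, uses plain arrays (one example per row) instead of Breeze matrices and Spark's Vector, and omits blockSize. The class GradientSketch and its flat weight layout (matrix entries row-major, then biases) are hypothetical. The sketch adds the batch-averaged gradient into cumGradient and returns the mean loss, which is one plausible reading of "cumulative gradient" and "error".

```java
/**
 * Hypothetical sketch, not Spark's implementation: one affine layer followed by
 * softmax with cross-entropy loss, using plain arrays (one example per row).
 */
final class GradientSketch {

    private GradientSketch() {}

    /**
     * Adds the batch-averaged gradient into cumGradient and returns the mean loss.
     * Weight layout (an assumption): out*in matrix entries row-major, then out biases.
     */
    static double computeGradient(double[][] data, double[][] target,
                                  double[] weights, double[] cumGradient) {
        int batch = data.length;
        int in = data[0].length;
        int out = target[0].length;
        double loss = 0.0;
        for (int n = 0; n < batch; n++) {
            // Forward pass of the affine layer: z = W x + b
            double[] z = new double[out];
            for (int o = 0; o < out; o++) {
                double s = weights[out * in + o];
                for (int i = 0; i < in; i++) {
                    s += weights[o * in + i] * data[n][i];
                }
                z[o] = s;
            }
            // Numerically stable softmax: subtract the largest score first
            double max = Double.NEGATIVE_INFINITY;
            for (double v : z) {
                max = Math.max(max, v);
            }
            double sum = 0.0;
            for (int o = 0; o < out; o++) {
                z[o] = Math.exp(z[o] - max);
                sum += z[o];
            }
            for (int o = 0; o < out; o++) {
                double p = z[o] / sum;
                if (target[n][o] > 0.0) {
                    loss -= target[n][o] * Math.log(p);
                }
                // For softmax + cross-entropy, dLoss/dz = p - target
                double delta = (p - target[n][o]) / batch;
                for (int i = 0; i < in; i++) {
                    cumGradient[o * in + i] += delta * data[n][i];
                }
                cumGradient[out * in + o] += delta;
            }
        }
        return loss / batch;
    }
}
```

Because the gradient is added rather than assigned, calling the method on successive blocks accumulates their contributions in the same cumGradient array.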
-
forward
breeze.linalg.DenseMatrix<Object>[] forward(breeze.linalg.DenseMatrix<Object> data, boolean includeLastLayer)
Forward propagation.
- Parameters:
data - input data
includeLastLayer - whether to include the output of the last layer. In MultilayerPerceptronClassifier, the last layer is always softmax; its output is needed for class predictions but not for rawPrediction.
- Returns:
array of outputs for each of the layers
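Forward propagation can be sketched in plain Java by treating each layer as a function from one activation vector to the next. The class ForwardSketch and its helpers affine and softmax are hypothetical and use double arrays rather than Breeze matrices. Here includeLastLayer simply stops before the final layer, following the parameter description above, and the returned list holds one output per evaluated layer.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

/** Hypothetical sketch of layer-by-layer forward propagation; not Spark's code. */
final class ForwardSketch {

    private ForwardSketch() {}

    /** Returns one output per evaluated layer; the last layer is skipped unless includeLastLayer. */
    static List<double[]> forward(List<UnaryOperator<double[]>> layers, double[] data,
                                  boolean includeLastLayer) {
        int count = includeLastLayer ? layers.size() : Math.max(0, layers.size() - 1);
        List<double[]> outputs = new ArrayList<>(count);
        double[] current = data;
        for (int k = 0; k < count; k++) {
            current = layers.get(k).apply(current); // output of layer k feeds layer k + 1
            outputs.add(current);
        }
        return outputs;
    }

    /** Affine layer y = W x + b, with W stored as one row per output unit. */
    static UnaryOperator<double[]> affine(double[][] w, double[] b) {
        return x -> {
            double[] y = new double[b.length];
            for (int o = 0; o < b.length; o++) {
                double s = b[o];
                for (int i = 0; i < x.length; i++) {
                    s += w[o][i] * x[i];
                }
                y[o] = s;
            }
            return y;
        };
    }

    /** Softmax layer, as used for the last layer of a classifier. */
    static UnaryOperator<double[]> softmax() {
        return z -> {
            double max = Double.NEGATIVE_INFINITY;
            for (double v : z) {
                max = Math.max(max, v);
            }
            double[] p = new double[z.length];
            double sum = 0.0;
            for (int i = 0; i < z.length; i++) {
                p[i] = Math.exp(z[i] - max);
                sum += p[i];
            }
            for (int i = 0; i < p.length; i++) {
                p[i] /= sum;
            }
            return p;
        };
    }
}
```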
-
layerModels
LayerModel[] layerModels()
Array of layer models.
- Returns:
the array of layer models
-
layers
Layer[] layers()
Array of layers.
- Returns:
the array of layers
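One way to picture the difference between layers() and layerModels() is configuration versus weights. The sketch below is an assumption, not Spark's API: a hypothetical Layer only reports how many weights it needs and creates a LayerModel bound to its slice of a single flat weight array, so one weights vector can be shared by all layers. The names LayerSketch, weightSize, createModel and eval are illustrative only.

```java
/** Hypothetical sketch of a Layer / LayerModel split; names and methods are assumptions. */
final class LayerSketch {

    private LayerSketch() {}

    /** Configuration only: shape information, no weights. */
    interface Layer {
        int weightSize();

        LayerModel createModel(double[] weights, int offset);
    }

    /** A layer bound to concrete weights that can evaluate inputs. */
    interface LayerModel {
        double[] eval(double[] input);
    }

    /** Affine layer: numOut x numIn weights (row-major) followed by numOut biases. */
    static Layer affine(int numIn, int numOut) {
        return new Layer() {
            @Override
            public int weightSize() {
                return numIn * numOut + numOut;
            }

            @Override
            public LayerModel createModel(double[] w, int offset) {
                return x -> {
                    double[] y = new double[numOut];
                    for (int o = 0; o < numOut; o++) {
                        double s = w[offset + numIn * numOut + o]; // bias
                        for (int i = 0; i < numIn; i++) {
                            s += w[offset + o * numIn + i] * x[i];
                        }
                        y[o] = s;
                    }
                    return y;
                };
            }
        };
    }

    /** Binds each layer to consecutive slices of one flat weight vector. */
    static LayerModel[] layerModels(Layer[] layers, double[] weights) {
        LayerModel[] models = new LayerModel[layers.length];
        int offset = 0;
        for (int k = 0; k < layers.length; k++) {
            models[k] = layers[k].createModel(weights, offset);
            offset += layers[k].weightSize();
        }
        if (offset != weights.length) {
            throw new IllegalArgumentException("weight vector size mismatch");
        }
        return models;
    }
}
```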
-
predict
Vector predict(Vector features)
Prediction of the model. See ProbabilisticClassificationModel.
- Parameters:
features - input features
- Returns:
prediction
-
predictRaw
Vector predictRaw(Vector features)
Raw prediction of the model. See ProbabilisticClassificationModel.
- Parameters:
features - input features
- Returns:
raw prediction
Note: This method is only used by classification models.
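When the last layer is softmax, as in MultilayerPerceptronClassifier, predictRaw and predict differ only in whether that final layer is applied. The hypothetical PredictSketch below assumes a network made of one affine layer followed by softmax and uses plain double arrays in place of Spark's Vector; it is an illustration, not Spark's implementation.

```java
/**
 * Hypothetical sketch: a network made of one affine layer followed by softmax.
 * predictRaw stops before softmax; predict applies it. Not Spark's implementation.
 */
final class PredictSketch {
    private final double[][] w; // numClasses x numFeatures
    private final double[] b;   // one bias per class

    PredictSketch(double[][] w, double[] b) {
        this.w = w;
        this.b = b;
    }

    /** Raw scores: the affine output, before the final softmax layer. */
    double[] predictRaw(double[] features) {
        double[] z = new double[b.length];
        for (int c = 0; c < b.length; c++) {
            double s = b[c];
            for (int i = 0; i < features.length; i++) {
                s += w[c][i] * features[i];
            }
            z[c] = s;
        }
        return z;
    }

    /** Full prediction: softmax over the raw scores, one probability per class. */
    double[] predict(double[] features) {
        double[] z = predictRaw(features);
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) {
            max = Math.max(max, v);
        }
        double sum = 0.0;
        for (int c = 0; c < z.length; c++) {
            z[c] = Math.exp(z[c] - max);
            sum += z[c];
        }
        for (int c = 0; c < z.length; c++) {
            z[c] /= sum;
        }
        return z;
    }

    /** Index of the largest entry (the first one on ties). */
    static int argmax(double[] v) {
        int best = 0;
        for (int i = 1; i < v.length; i++) {
            if (v[i] > v[best]) {
                best = i;
            }
        }
        return best;
    }
}
```

Because softmax preserves the ordering of the scores, the index of the largest entry is the same in the raw and the probability outputs.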
-
raw2ProbabilityInPlace
Vector raw2ProbabilityInPlace(Vector rawPrediction)
Probability of the model, computed from the raw prediction in place. See ProbabilisticClassificationModel.
- Parameters:
rawPrediction - raw prediction vector
- Returns:
probability
Note: This method is only used by classification models.
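The method name suggests converting raw scores to probabilities without allocating a new vector. A minimal sketch, assuming the last layer is softmax as in MultilayerPerceptronClassifier; the class SoftmaxInPlace is hypothetical and works on a plain double array rather than Spark's Vector. Subtracting the maximum score before exponentiating avoids overflow for large raw values without changing the result.

```java
/** Hypothetical sketch: converts raw scores to softmax probabilities in place. */
final class SoftmaxInPlace {

    private SoftmaxInPlace() {}

    /** Overwrites rawPrediction with softmax(rawPrediction) and returns the same array. */
    static double[] raw2ProbabilityInPlace(double[] rawPrediction) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : rawPrediction) {
            max = Math.max(max, v);
        }
        double sum = 0.0;
        for (int i = 0; i < rawPrediction.length; i++) {
            // Subtracting the maximum keeps Math.exp from overflowing
            rawPrediction[i] = Math.exp(rawPrediction[i] - max);
            sum += rawPrediction[i];
        }
        for (int i = 0; i < rawPrediction.length; i++) {
            rawPrediction[i] /= sum;
        }
        return rawPrediction;
    }
}
```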
-
weights
Vector weights()
-