GaussianMixtureModel

class pyspark.mllib.clustering.GaussianMixtureModel(java_model: py4j.java_gateway.JavaObject)
- A clustering model derived from the Gaussian Mixture Model method.
- New in version 1.3.0.
- Examples
- The examples assume an active SparkContext named sc.

>>> from numpy import array
>>> from numpy.testing import assert_equal
>>> from pyspark.mllib.clustering import GaussianMixture, GaussianMixtureModel
>>> from shutil import rmtree
>>> import tempfile

>>> clusterdata_1 = sc.parallelize(array([-0.1, -0.05, -0.01, -0.1,
...                                       0.9, 0.8, 0.75, 0.935,
...                                       -0.83, -0.68, -0.91, -0.76]).reshape(6, 2), 2)
>>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001,
...                               maxIterations=50, seed=10)
>>> labels = model.predict(clusterdata_1).collect()
>>> labels[0] == labels[1]
False
>>> labels[1] == labels[2]
False
>>> labels[4] == labels[5]
True
>>> model.predict([-0.1, -0.05])
0
>>> softPredicted = model.predictSoft([-0.1, -0.05])
>>> abs(softPredicted[0] - 1.0) < 0.03
True
>>> abs(softPredicted[1] - 0.0) < 0.03
True
>>> abs(softPredicted[2] - 0.0) < 0.03
True

>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
>>> sameModel = GaussianMixtureModel.load(sc, path)
>>> assert_equal(model.weights, sameModel.weights)
>>> mus, sigmas = list(
...     zip(*[(g.mu, g.sigma) for g in model.gaussians]))
>>> sameMus, sameSigmas = list(
...     zip(*[(g.mu, g.sigma) for g in sameModel.gaussians]))
>>> mus == sameMus
True
>>> sigmas == sameSigmas
True
>>> try:
...     rmtree(path)
... except OSError:
...     pass

>>> data = array([-5.1971, -2.5359, -3.8220,
...               -5.2211, -5.0602, 4.7118,
...               6.8989, 3.4592, 4.6322,
...               5.7048, 4.6567, 5.5026,
...               4.5605, 5.2043, 6.2734])
>>> clusterdata_2 = sc.parallelize(data.reshape(5, 3))
>>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001,
...                               maxIterations=150, seed=4)
>>> labels = model.predict(clusterdata_2).collect()
>>> labels[0] == labels[1]
True
>>> labels[2] == labels[3] == labels[4]
True

- Methods
- call(name, *a): Call the java_model method whose name is name, passing *a as its arguments.
- load(sc, path): Load the GaussianMixtureModel from disk.
- predict(x): Find the cluster in which the point ‘x’, or each point in the RDD ‘x’, has maximum membership in this model.
- predictSoft(x): Find the membership of the point ‘x’, or of each point in the RDD ‘x’, to all mixture components.
- save(sc, path): Save this model to the given path.
- Attributes
- gaussians: Array of MultivariateGaussian, where gaussians[i] represents the Multivariate Gaussian (Normal) distribution for Gaussian i.
- k: Number of Gaussians in the mixture.
- weights: Weights for each Gaussian distribution in the mixture, where weights[i] is the weight for Gaussian i; the weights sum to 1.
- Methods Documentation
call(name: str, *a: Any) → Any
- Call the java_model method whose name is name, passing *a as its arguments.
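The name-based delegation that call performs can be sketched in plain Python. This is a minimal sketch, assuming a hypothetical FakeJavaModel stand-in in place of the py4j JavaObject (it is not a Spark class); PySpark's wrapper additionally converts arguments and results between Python and the JVM, which the sketch omits.

```python
class FakeJavaModel:
    """Hypothetical stand-in for the JVM-side model (not part of Spark)."""

    def __init__(self, weights):
        self._weights = list(weights)

    def weights(self):
        return list(self._weights)

    def weight(self, i):
        return self._weights[i]

    def k(self):
        return len(self._weights)


class ModelWrapperSketch:
    """Hypothetical wrapper showing how call(name, *a) dispatches by name."""

    def __init__(self, java_model):
        self._java_model = java_model

    def call(self, name, *a):
        # Look up the method by its string name on the wrapped object
        # and invoke it with the positional arguments unchanged.
        method = getattr(self._java_model, name)
        return method(*a)


wrapper = ModelWrapperSketch(FakeJavaModel([0.25, 0.75]))
print(wrapper.call("k"))          # 2
print(wrapper.call("weights"))    # [0.25, 0.75]
print(wrapper.call("weight", 1))  # 0.75
```

Dispatching through getattr keeps the wrapper thin: any method the wrapped object exposes becomes reachable without writing a Python method for it.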
classmethod load(sc: pyspark.context.SparkContext, path: str) → pyspark.mllib.clustering.GaussianMixtureModel
- Load the GaussianMixtureModel from disk.
- New in version 1.5.0.
- Parameters
- sc : SparkContext
- path : str
- Path to where the model is stored. 
 
 
predict(x: Union[VectorLike, pyspark.rdd.RDD[VectorLike]]) → Union[numpy.int64, pyspark.rdd.RDD[int]]
- Find the cluster in which the point ‘x’, or each point in the RDD ‘x’, has maximum membership in this model.
- New in version 1.3.0.
- Parameters
- x : pyspark.mllib.linalg.Vector or pyspark.RDD
- A feature vector or an RDD of vectors representing data points. 
 
- Returns
- numpy.int64 or pyspark.RDD of int
- Predicted cluster label or an RDD of predicted cluster labels if the input is an RDD. 
 
 
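Since predict returns the cluster with maximum membership, a hard label can be derived from a vector of soft memberships by taking the index of its largest entry. The sketch below assumes that reading of "maximum membership"; the membership vectors are illustrative values, not output from a trained model, and hard_label/hard_labels are hypothetical helper names.

```python
def hard_label(memberships):
    """Return the index of the largest membership value (first index wins on ties)."""
    best_index = 0
    for i, value in enumerate(memberships):
        if value > memberships[best_index]:
            best_index = i
    return best_index


def hard_labels(rows):
    """Apply hard_label to every membership vector, mirroring predict() on an RDD."""
    return [hard_label(row) for row in rows]


# Illustrative membership vectors for a 3-component mixture.
soft = [
    [0.97, 0.02, 0.01],  # strongly in component 0
    [0.10, 0.30, 0.60],  # mostly in component 2
    [0.45, 0.45, 0.10],  # tie between 0 and 1, so the first index is chosen
]
print(hard_labels(soft))  # [0, 2, 0]
```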
predictSoft(x: Union[VectorLike, pyspark.rdd.RDD[VectorLike]]) → Union[numpy.ndarray, pyspark.rdd.RDD[array.array]]
- Find the membership of the point ‘x’, or of each point in the RDD ‘x’, to all mixture components.
- New in version 1.3.0.
- Parameters
- x : pyspark.mllib.linalg.Vector or pyspark.RDD
- A feature vector or an RDD of vectors representing data points. 
 
- Returns
- numpy.ndarray or pyspark.RDD of array.array
- The membership values of the vector ‘x’, or of each vector in the RDD ‘x’, for all mixture components.
 
 
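In a Gaussian mixture, the membership of x in component i is its posterior probability: w_i * N(x | mu_i, Sigma_i) divided by the sum of w_j * N(x | mu_j, Sigma_j) over all components j. The sketch below computes this for 2-D points only (matching the 2-D data in the first example above) using an explicit 2x2 inverse; it is an illustration of the formula, not Spark's implementation, and it does not reproduce any numerical safeguards Spark may apply.

```python
import math


def gaussian_pdf_2d(x, mu, sigma):
    """Density of a 2-D normal distribution with mean mu and 2x2 covariance sigma."""
    (a, b), (c, d) = sigma
    det = a * d - b * c
    # A symmetric 2x2 matrix is positive definite iff a > 0 and det > 0.
    if a <= 0 or det <= 0:
        raise ValueError("covariance matrix must be positive definite")
    inv = ((d / det, -b / det), (-c / det, a / det))
    dx = (x[0] - mu[0], x[1] - mu[1])
    # Mahalanobis term (x - mu)^T Sigma^{-1} (x - mu).
    m = (dx[0] * (inv[0][0] * dx[0] + inv[0][1] * dx[1])
         + dx[1] * (inv[1][0] * dx[0] + inv[1][1] * dx[1]))
    return math.exp(-0.5 * m) / (2.0 * math.pi * math.sqrt(det))


def soft_memberships(x, weights, mus, sigmas):
    """Posterior probability of each mixture component given the point x."""
    # No log-space stabilization: points far from every mean can underflow to 0.
    weighted = [w * gaussian_pdf_2d(x, mu, s)
                for w, mu, s in zip(weights, mus, sigmas)]
    total = sum(weighted)
    return [v / total for v in weighted]


weights = [0.5, 0.5]
mus = [(0.0, 0.0), (5.0, 5.0)]
sigmas = [((1.0, 0.0), (0.0, 1.0)), ((1.0, 0.0), (0.0, 1.0))]
print(soft_memberships((2.5, 2.5), weights, mus, sigmas))  # [0.5, 0.5]
print(soft_memberships((0.0, 0.0), weights, mus, sigmas))
```

A point halfway between two equally weighted, identically shaped components splits its membership evenly, while a point at one mean belongs almost entirely to that component.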
save(sc: pyspark.context.SparkContext, path: str) → None
- Save this model to the given path.
- New in version 1.3.0.
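A save/load pair must round-trip the model's parameters: the weights and each Gaussian's mu and sigma, which are exactly what the doctest in the class description compares after reloading. The sketch below illustrates such a round trip with a JSON file under a hypothetical file name, model.json; this layout is an assumption for illustration only, and Spark's own on-disk format is different and should be written and read only through save and load.

```python
import json
import os
import shutil
import tempfile


def save_mixture(path, weights, mus, sigmas):
    """Write mixture parameters to <path>/model.json (hypothetical layout)."""
    os.makedirs(path, exist_ok=True)
    payload = {
        "k": len(weights),
        "components": [
            {"weight": w, "mu": list(mu), "sigma": [list(row) for row in s]}
            for w, mu, s in zip(weights, mus, sigmas)
        ],
    }
    with open(os.path.join(path, "model.json"), "w") as f:
        json.dump(payload, f)


def load_mixture(path):
    """Read back the parameters written by save_mixture."""
    with open(os.path.join(path, "model.json")) as f:
        payload = json.load(f)
    comps = payload["components"]
    # Sanity check that the stored component count matches the data.
    if len(comps) != payload["k"]:
        raise ValueError("component count does not match k")
    weights = [c["weight"] for c in comps]
    mus = [c["mu"] for c in comps]
    sigmas = [c["sigma"] for c in comps]
    return weights, mus, sigmas


path = tempfile.mkdtemp()
weights = [0.4, 0.6]
mus = [[0.0, 0.0], [5.0, 5.0]]
sigmas = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.5], [0.5, 1.0]]]
save_mixture(path, weights, mus, sigmas)
loaded = load_mixture(path)
print(loaded == (weights, mus, sigmas))  # True
shutil.rmtree(path)  # clean up the temporary directory
```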
- Attributes Documentation
gaussians
- Array of MultivariateGaussian, where gaussians[i] represents the Multivariate Gaussian (Normal) distribution for Gaussian i.
- New in version 1.4.0.
k
- Number of Gaussians in the mixture.
- New in version 1.4.0.
weights
- Weights for each Gaussian distribution in the mixture, where weights[i] is the weight for Gaussian i; the weights sum to 1.
- New in version 1.4.0.
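The three attributes are tied together: gaussians and weights each hold one entry per component, so k equals their length, and the weights sum to 1. The sketch below encodes those invariants in a hypothetical plain-Python container (MixtureParams and GaussianComponent are illustrative names, not Spark classes); the 1e-9 tolerance for the weight sum is an assumption.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class GaussianComponent:
    """Hypothetical stand-in for MultivariateGaussian: a mean and a covariance."""
    mu: List[float]
    sigma: List[List[float]]


@dataclass
class MixtureParams:
    """Hypothetical container mirroring the weights / gaussians / k attributes."""
    weights: List[float]
    gaussians: List[GaussianComponent]

    @property
    def k(self) -> int:
        # One weight per component, so k is the number of weights.
        return len(self.weights)

    def validate(self, tol: float = 1e-9) -> None:
        if len(self.gaussians) != self.k:
            raise ValueError("need exactly one Gaussian per weight")
        if any(w < 0 for w in self.weights):
            raise ValueError("weights must be non-negative")
        if not math.isclose(sum(self.weights), 1.0, abs_tol=tol):
            raise ValueError("weights must sum to 1")
        for g in self.gaussians:
            d = len(g.mu)
            if len(g.sigma) != d or any(len(row) != d for row in g.sigma):
                raise ValueError("sigma must be a d x d matrix matching mu")


params = MixtureParams(
    weights=[0.3, 0.7],
    gaussians=[
        GaussianComponent(mu=[0.0, 0.0], sigma=[[1.0, 0.0], [0.0, 1.0]]),
        GaussianComponent(mu=[5.0, 5.0], sigma=[[2.0, 0.5], [0.5, 1.0]]),
    ],
)
params.validate()
print(params.k)  # 2

bad = MixtureParams(weights=[0.5, 0.6], gaussians=params.gaussians)
try:
    bad.validate()
except ValueError as err:
    print(err)  # weights must sum to 1
```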
 