
An interactive toolkit for visualizing GMM convergence in 3D/2D, featuring PCA for dimensionality reduction, K-means++ initialization, and covariance regularization for stability.







The project is a visualization toolkit that interactively illustrates how a Gaussian Mixture Model (GMM) converges in 3D space; 2D plotting is also supported. It includes a custom GMM implementation that uses K-means++ to choose well-spread initial centroids and regularizes the covariance matrices to keep them positive definite, which improves numerical stability. The GMM parameters are estimated with the Expectation-Maximization (EM) algorithm. For high-dimensional data, the toolkit applies Principal Component Analysis (PCA) to reduce the data to three dimensions so that it can be visualized in 3D.
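The K-means++ seeding mentioned above can be sketched in a few lines of NumPy. This is a minimal illustration of the standard technique, not the toolkit's own code, and the function name `kmeans_pp_init` is hypothetical. The first centroid is drawn uniformly from the data; each later centroid is drawn with probability proportional to its squared distance from the nearest centroid already chosen, which tends to place one seed in each cluster.

```python
import numpy as np

def kmeans_pp_init(X, k, rng=None):
    """Hypothetical helper: choose k initial centroids from the rows of X with K-means++ seeding."""
    rng = np.random.default_rng(rng)
    X = np.asarray(X, dtype=float)
    n = X.shape[0]
    # First centroid: a data point chosen uniformly at random.
    centroids = [X[rng.integers(n)]]
    # Squared distance from every point to its nearest chosen centroid so far.
    d2 = ((X - centroids[0]) ** 2).sum(axis=1)
    for _ in range(1, k):
        total = d2.sum()
        # Sample proportionally to d2 (uniformly if every point coincides with a centroid).
        probs = d2 / total if total > 0 else np.full(n, 1.0 / n)
        new = X[rng.choice(n, p=probs)]
        centroids.append(new)
        # Update each point's nearest-centroid distance with the new centroid.
        d2 = np.minimum(d2, ((X - new) ** 2).sum(axis=1))
    return np.array(centroids)
```

Because an already chosen point has distance zero, it cannot be picked again, so the k seeds are distinct whenever the data contains at least k distinct points.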


git clone path/you/want/to/clone


Quick example:

1. Change into the directory where you cloned the repository, then import the modules:

import numpy as np
import os 

from src.main.GMMViz.GaussianMixtureModel import GMM
from src.main.GMMViz.GmmPlot import GmmViz
from src.main.GMMViz.DataGenerater import DataGenerater
import plotly.io as pio

2. Generate a test dataset (or load your own).

2.1 Working with data of up to 3 dimensions.

3D Gaussian Mixture Model

pio.renderers.default = "notebook"

# Generate a dataset with k = 3 clusters in a dim = 3 dimensional space.
X3 = DataGenerater.genData(k = 3,  # number of clearly separated clusters to generate
                           dim = 3, # dimension of the data
                           points_per_cluster = 200, 
                           lim = [-10, 10], # range from which each cluster's mean is drawn
                           plot = True, # only data with at most 3 dimensions can be plotted.
                           random_state = 129)


# instantiate the object
gmm3 = GMM(n_clusters=3, random_state=129)
# fit the GMM to the data
gmm3.fit(X = X3)
X3 can be a pandas DataFrame or a NumPy array.
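If you would rather build your own input than call DataGenerater.genData, any array of shape (n_samples, dim) should do. The sketch below is a hypothetical stand-in that mirrors genData's k, dim, points_per_cluster, lim and random_state arguments; it is not how DataGenerater is implemented, and the unit spread of each cluster is an assumption.

```python
import numpy as np

def make_gaussian_clusters(k, dim, points_per_cluster, lim=(-10, 10), random_state=None):
    """Hypothetical helper: k isotropic Gaussian clusters with means drawn uniformly from lim."""
    rng = np.random.default_rng(random_state)
    # One mean per cluster, each coordinate uniform in [lim[0], lim[1]).
    means = rng.uniform(lim[0], lim[1], size=(k, dim))
    # Assumed spread: standard deviation 1 along every axis.
    X = np.vstack([rng.normal(loc=m, scale=1.0, size=(points_per_cluster, dim)) for m in means])
    # True cluster index of every row, in the same order as X.
    labels = np.repeat(np.arange(k), points_per_cluster)
    return X, labels

X, y = make_gaussian_clusters(k=3, dim=3, points_per_cluster=200, lim=(-10, 10), random_state=129)
```

The returned X (or `pandas.DataFrame(X)`) can take the place of X3, and y records each row's true cluster, which is handy for checking the fitted clusters by eye.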

Call gmm3.getEstimands(parm = ...) with one of 'mean', 'Sigma', or 'log_likelihood' to retrieve that parameter's estimates over the course of the GMM's convergence. If no argument is passed, it returns the estimates of all parameters in a dictionary.
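To make the tracked quantities concrete, here is a minimal NumPy sketch of EM for a GMM that records the means, covariances and log-likelihood at every iteration, keyed by the same names getEstimands accepts. It is an illustration under assumptions, not the toolkit's implementation: the helper names, the starting covariances (the full-data covariance for every component) and the regularization constant `reg` (added to each covariance diagonal to keep it positive definite) are all hypothetical choices.

```python
import numpy as np

def log_gaussian(X, mean, cov):
    """Row-wise log N(x | mean, cov), computed through a Cholesky factor of cov."""
    d = X.shape[1]
    L = np.linalg.cholesky(cov)
    z = np.linalg.solve(L, (X - mean).T)          # whitened residuals, shape (d, n)
    maha = (z ** 2).sum(axis=0)                    # squared Mahalanobis distances
    log_det = 2.0 * np.log(np.diag(L)).sum()
    return -0.5 * (d * np.log(2 * np.pi) + log_det + maha)

def em_gmm(X, means, n_iter=15, reg=1e-6):
    """Hypothetical EM loop; returns per-iteration 'mean', 'Sigma' and 'log_likelihood'."""
    X = np.asarray(X, dtype=float)
    n, d = X.shape
    means = np.array(means, dtype=float)
    k = len(means)
    covs = np.array([np.cov(X.T) + reg * np.eye(d) for _ in range(k)])
    weights = np.full(k, 1.0 / k)
    history = {"mean": [], "Sigma": [], "log_likelihood": []}
    for _ in range(n_iter):
        # E-step: log(pi_j) + log N(x_i | mu_j, Sigma_j), normalized with log-sum-exp.
        log_p = np.column_stack([np.log(weights[j]) + log_gaussian(X, means[j], covs[j])
                                 for j in range(k)])
        log_norm = np.logaddexp.reduce(log_p, axis=1)
        resp = np.exp(log_p - log_norm[:, None])
        # M-step: re-estimate mixing weights, means and covariances.
        nk = resp.sum(axis=0)
        weights = nk / n
        means = (resp.T @ X) / nk[:, None]
        for j in range(k):
            diff = X - means[j]
            covs[j] = (resp[:, j, None] * diff).T @ diff / nk[j]
            covs[j] += reg * np.eye(d)             # regularize so Sigma_j stays positive definite
        history["mean"].append(means.copy())
        history["Sigma"].append(covs.copy())
        # Log-likelihood of the parameters used in this iteration's E-step.
        history["log_likelihood"].append(log_norm.sum())
    return history
```

The recorded log-likelihood should rise over the iterations and then level off; a plot of history["log_likelihood"] is the kind of convergence curve the toolkit visualizes.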

2.2 Working with data of more than 3 dimensions.

When a dataset has more than 3 dimensions, it cannot be visualized directly in 3D space. Principal Component Analysis (PCA) addresses this by reducing the dataset's dimensionality: it projects the data onto the three directions of maximum variance, which are the eigenvectors of the covariance matrix with the largest eigenvalues. Setting the number of principal components (n_components) to 3 yields a transformed dataset that can be visualized in three-dimensional space.
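The projection described above can be sketched with NumPy's symmetric eigensolver. This is a generic PCA illustration, not the code behind PCA_fit; the name `pca_project` and the mean-centering step (standard in PCA, though the toolkit's exact preprocessing is not documented here) are assumptions.

```python
import numpy as np

def pca_project(X, n_components=3):
    """Hypothetical helper: project X onto its top principal components."""
    X = np.asarray(X, dtype=float)
    Xc = X - X.mean(axis=0)                    # center each feature
    cov = np.cov(Xc, rowvar=False)             # (dim, dim) sample covariance
    eigvals, eigvecs = np.linalg.eigh(cov)     # eigenvalues in ascending order
    order = np.argsort(eigvals)[::-1][:n_components]
    W = eigvecs[:, order]                      # (dim, n_components) projection matrix
    return Xc @ W, eigvals[order]
```

Each projected column has a variance equal to its eigenvalue and the columns are uncorrelated, so the three retained axes carry as much of the original variance as any three orthogonal directions can.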

Over 3 dimensions: Using PCA 
X7 = DataGenerater.genData(k=6, dim=7, points_per_cluster=100, lim=[-20, 20], plot = False, random_state = 129) # plot = False, since data with more than 3 dimensions cannot be plotted directly.
PCAGMM = GMM(n_clusters=6)

PCAGMM.PCA_fit(X = X7, n_components=3) # n_components defaults to 3, which projects the data onto 3 dimensions.

3. Plot the fitted Gaussian distributions.

There are two options for plotting the GMM in three-dimensional space.

The plot() method draws each cluster's multivariate Gaussian distribution as an ellipsoid.
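One common way to turn a cluster's mean and covariance into an ellipsoid is to scale a unit sphere by the square roots of the covariance eigenvalues, rotate it by the eigenvectors, and shift it to the mean. The sketch below shows that construction; it is not GmmViz's code, and the helper name `ellipsoid_surface` and the 2-standard-deviation contour (`n_std`) are assumptions.

```python
import numpy as np

def ellipsoid_surface(mean, cov, n_std=2.0, n_points=30):
    """Hypothetical helper: X, Y, Z grids of the n_std contour of N(mean, cov) in 3D."""
    eigvals, eigvecs = np.linalg.eigh(cov)
    radii = n_std * np.sqrt(eigvals)           # semi-axis lengths along each eigenvector
    u = np.linspace(0, 2 * np.pi, n_points)    # azimuth
    v = np.linspace(0, np.pi, n_points)        # polar angle
    # Unit sphere as a (3, n_points, n_points) array of coordinates.
    sphere = np.stack([np.outer(np.cos(u), np.sin(v)),
                       np.outer(np.sin(u), np.sin(v)),
                       np.outer(np.ones_like(u), np.cos(v))])
    # Scale by the radii, rotate by the eigenvectors, then shift to the mean.
    pts = np.einsum("ij,jkl->ikl", eigvecs * radii, sphere)
    pts += np.asarray(mean, dtype=float)[:, None, None]
    return pts[0], pts[1], pts[2]
```

Every surface point lies at Mahalanobis distance n_std from the mean. The three grids can be passed to a Matplotlib 3D axes via `ax.plot_surface(X, Y, Z)` or to `plotly.graph_objects.Surface(x=X, y=Y, z=Z)`.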

3.1 Using matplotlib.pyplot (set utiPlotly = False):

# instantiate the GmmViz object
V3F = GmmViz(gmm3, utiPlotly=False) # plot via matplotlib

# use plot method to plot
V3F.plot(fig_title = "GMM-3D",
         path_prefix="doc/image/dim3/parms/", # images will be stored in the `path_prefix` directory.
         show_plot = False, # whether to display the figures in the editor. Default is `False`.
         save_plot = True, # whether to export the figures. Default is True.
         max_iter = 15) # number of iterations to plot. Default is 15.

In the plot() method, show_plot controls whether each figure is also displayed in the editor (default False), and save_plot controls whether the figures are exported to path_prefix (default True).

The images exported by the plot() method can be combined into a GIF:

GmmViz.generateGIF(image_path = "doc/image/dim3/parms", # directory of the images showing each iteration
                   output_path_filename = "doc/image/dim3/parms/gif/GMM-3D-Parms.gif", 
                   fps = 2) # frames per second; controls how long each frame is shown in the GIF


3.2 Using Plotly (set utiPlotly = True):

Interactive 3D plot
# plot
pio.renderers.default = "browser" # it will open the browser to show the plots.

# GMM for 3 dim dataset
V3T = GmmViz(gmm3, utiPlotly=True)
V3T.plot(fig_title = "GMM-3D", path_prefix="doc/image/dim3-plotly/parms/", show_plot = False)
GmmViz.generateGIF(image_path = "doc/image/dim3-plotly/parms", output_path_filename = "doc/image/dim3-plotly/parms/gif/GMM-3D-Parms-plotly.gif", fps = 2)

# PCA_fit GMM
V7T = GmmViz(PCAGMM, utiPlotly=True)
V7T.plot(fig_title = "GMM-3D", 
         path_prefix = "", # directory to export the images; keep the default "" when save_plot = False
         save_plot = False, # no need to save the figures.
         show_plot = True) # display the interactive figures in the browser.


Visualize 3-dimensional data via Plotly

4. Visualize a 2-dimensional dataset in 2D.


X2 = DataGenerater.genData(k=3, dim=2, points_per_cluster=200, lim=[-10, 10], plot = True)


gmm2 = GMM(n_clusters=3, random_state=129)
gmm2.fit(X = X2)

V2 = GmmViz(gmm2)

# plot convergence
V2.plot(fig_title="GMM-2D", path_prefix="doc/image/dim2/parms/")

# generate the GIF (the convergence must be plotted first)
GmmViz.generateGIF(image_path = "doc/image/dim2/parms", output_path_filename = "doc/image/dim2/parms/gif/GMM-2D-Parms.gif", fps = 2)

# Likelihood
GmmViz.generateGIF(image_path = "doc/image/dim2/ll", output_path_filename = "doc/image/dim2/ll/gif/GMM-2D-LL.gif", fps = 2)



OS: macOS Sonoma 14.5
IDE: Visual Studio Code
Language: Python 3.9.7

Package list:
backports.shutil-get-terminal-size 1.0.0
imageio                            2.9.0
matplotlib                         3.7.2
matplotlib-inline                  0.1.6
numpy                              1.20.3
numpydoc                           1.1.0
pandas                             1.5.3
plotly                             5.21.0
scipy                              1.10.1


Denny Chen












