This project is a visualization toolkit that illustrates how a Gaussian Mixture Model (GMM) converges, using interactive plots in 3D space. Although it focuses on 3D visualization, it also supports 2D plotting. The toolkit includes a custom GMM implementation that uses K-means++ to initialize the centroids and regularizes the covariance matrices to keep them positive definite, which improves numerical stability. The GMM parameters are estimated with the Expectation-Maximization (EM) algorithm. For data with more than three dimensions, the toolkit applies Principal Component Analysis (PCA) to reduce the data to three dimensions so that it can be visualized in 3D.
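The two stabilizing ideas mentioned above, K-means++ seeding and covariance regularization, can be sketched with NumPy alone. This is a minimal illustration and not the toolkit's own code: the function names `kmeans_pp_init` and `regularize_covariance` and the value `eps=1e-6` are assumptions made for the example.

```python
import numpy as np

def kmeans_pp_init(X, k, rng):
    """Pick k initial centroids from the rows of X using K-means++ seeding."""
    n = X.shape[0]
    centers = [X[rng.integers(n)]]  # first centroid: uniform at random
    for _ in range(1, k):
        # Squared distance from every point to its nearest chosen centroid.
        diffs = X[:, None, :] - np.asarray(centers)[None, :, :]
        d2 = np.min(np.sum(diffs ** 2, axis=2), axis=1)
        # Sample the next centroid with probability proportional to d2.
        centers.append(X[rng.choice(n, p=d2 / d2.sum())])
    return np.asarray(centers)

def regularize_covariance(cov, eps=1e-6):
    """Add eps to the diagonal so the matrix is strictly positive definite."""
    return cov + eps * np.eye(cov.shape[0])

# Hypothetical toy data: three well-separated blobs in 3D.
rng = np.random.default_rng(129)
X = np.vstack([rng.normal(loc=c, scale=1.0, size=(50, 3)) for c in (-8.0, 0.0, 8.0)])
centroids = kmeans_pp_init(X, k=3, rng=rng)

singular = np.ones((3, 3))        # rank 1, so not positive definite
stable = regularize_covariance(singular)
np.linalg.cholesky(stable)        # succeeds only for positive definite input
```

Because already chosen points have zero distance to themselves, they are never sampled again, which spreads the initial centroids across the data.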
git clone https://github.com/ChenTaHung/GMM-Visualization path/you/want/to/clone
git clone git@github.com:ChenTaHung/GMM-Visualization.git path/you/want/to/clone
import numpy as np
import os
os.chdir("/folder/that/you/cloned/")
from src.main.GMMViz.GaussianMixtureModel import GMM
from src.main.GMMViz.GmmPlot import GmmViz
from src.main.GMMViz.DataGenerater import DataGenerater
import plotly.io as pio
"""
3D Gaussian Mixture Model
"""
pio.renderers.default = "notebook"
# Generate dataset with k = 3 groups within a dim = 3 dimensional space.
X3 = DataGenerater.genData(k = 3, # number of clearly separated clusters to generate
                           dim = 3, # dimension of the data
                           points_per_cluster = 200,
                           lim = [-10, 10], # range of the mean values of the clusters
                           plot = True, # only data with at most 3 dimensions can be plotted.
                           random_state = 129)
# instantiate the object
gmm3 = GMM(n_clusters=3, random_state=129)
# fit the GMM to the data
gmm3.fit(X3)
The input `X3` can be either a pandas DataFrame or a NumPy array.
Use `gmm3.getEstimands(parm = )` with one of the options `['mean', 'Sigma', 'log_likelihood']` to get the corresponding parameter values recorded during the convergence of the GMM. If no argument is passed, it returns a dictionary containing all of the parameter estimates.
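To make the three quantities concrete, the following sketch runs a bare-bones EM loop and records the means, covariances, and log-likelihood at every iteration. It is an illustration written for this README, not the toolkit's implementation: `em_history`, `log_gaussian`, the random initialization, and the regularization value are all assumptions. Only the dictionary keys mirror the option names listed above.

```python
import numpy as np

def log_gaussian(X, mean, cov):
    """Log density of N(mean, cov) evaluated at each row of X."""
    d = X.shape[1]
    chol = np.linalg.cholesky(cov)
    z = np.linalg.solve(chol, (X - mean).T)  # whitened residuals
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (d * np.log(2.0 * np.pi) + log_det + np.sum(z ** 2, axis=0))

def em_history(X, k, n_iter=15, reg=1e-6, seed=0):
    """Run EM and record the estimates at every iteration (illustrative only)."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    means = X[rng.choice(n, size=k, replace=False)]
    covs = np.stack([np.cov(X, rowvar=False) + reg * np.eye(d)] * k)
    weights = np.full(k, 1.0 / k)
    history = {"mean": [], "Sigma": [], "log_likelihood": []}
    for _ in range(n_iter):
        # E-step: responsibilities, computed in log space for stability.
        log_p = np.stack(
            [np.log(weights[j]) + log_gaussian(X, means[j], covs[j]) for j in range(k)],
            axis=1,
        )
        log_norm = np.logaddexp.reduce(log_p, axis=1)
        resp = np.exp(log_p - log_norm[:, None])
        # Record the parameters together with the log-likelihood they produce.
        history["mean"].append(means.copy())
        history["Sigma"].append(covs.copy())
        history["log_likelihood"].append(float(log_norm.sum()))
        # M-step: weighted updates of the mixing weights, means, and covariances.
        nk = resp.sum(axis=0)
        weights = nk / n
        means = (resp.T @ X) / nk[:, None]
        covs = np.stack([
            (resp[:, j, None] * (X - means[j])).T @ (X - means[j]) / nk[j]
            + reg * np.eye(d)
            for j in range(k)
        ])
    return history

# Hypothetical toy data: three blobs in 3D.
rng = np.random.default_rng(129)
centers = np.array([[-6.0, 0.0, 0.0], [6.0, 0.0, 0.0], [0.0, 6.0, 6.0]])
X = np.vstack([rng.normal(loc=c, scale=1.0, size=(100, 3)) for c in centers])
history = em_history(X, k=3)
```

The recorded log-likelihood values should not decrease from one iteration to the next (up to a tiny effect from the regularization term), which is the property a convergence plot makes visible.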
When a dataset has more than 3 dimensions, it cannot be visualized directly in 3D space. Principal Component Analysis (PCA) addresses this by reducing the dataset's dimensionality: it projects the data onto the three directions of maximum variance, which are the eigenvectors of the covariance matrix with the largest eigenvalues. Setting the number of principal components (`n_components`) to 3 lets the transformed dataset be visualized in three-dimensional space.
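The projection described above can be written in a few lines of NumPy. This is a minimal sketch of PCA through the eigendecomposition of the covariance matrix; the function name `pca_project` and the synthetic 7-dimensional data are assumptions for the example, and the toolkit's own `PCA_fit` may be implemented differently.

```python
import numpy as np

def pca_project(X, n_components=3):
    """Project X onto its n_components directions of largest variance."""
    X = np.asarray(X, dtype=float)
    X_centered = X - X.mean(axis=0)
    cov = np.cov(X_centered, rowvar=False)
    eigvals, eigvecs = np.linalg.eigh(cov)       # eigenvalues in ascending order
    order = np.argsort(eigvals)[::-1][:n_components]
    components = eigvecs[:, order]               # columns are principal directions
    return X_centered @ components, eigvals[order]

# Hypothetical 7-dimensional data with correlated features.
rng = np.random.default_rng(129)
X7_demo = rng.normal(size=(200, 7)) @ rng.normal(size=(7, 7))
Z, top_eigvals = pca_project(X7_demo, n_components=3)
```

Each column of `Z` has a sample variance equal to the corresponding eigenvalue, so the first column carries the most variance and the third the least of the three.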
"""
Over 3 dimensions : Using PCA
"""
X7 = DataGenerater.genData(k=6, dim=7, points_per_cluster=100, lim=[-20, 20], plot = False, random_state = 129) # plot = False, since data with more than 3 dimensions cannot be visualized directly.
PCAGMM = GMM(n_clusters=6)
PCAGMM.PCA_fit(X = X7, n_components=3) # n_components defaults to 3, which yields 3-dimensional data.
There are two options for plotting the GMM in 3-dimensional space: matplotlib (`utiPlotly=False`) and Plotly (`utiPlotly=True`).
The `plot()` method draws the multivariate Gaussian distribution of each cluster as an ellipsoid.
# instantiate the GmmViz object
V3F = GmmViz(gmm3, utiPlotly=False) # plot via matplotlib
# use plot method to plot
V3F.plot(fig_title="GMM-3D",
         path_prefix="doc/image/dim3/parms/", # images will be stored in the `path_prefix` directory.
         show_plot = False, # whether to show the figure in the editor. Default is `False`.
         save_plot = True, # export the figures. Default is `True`.
         max_iter = 15) # number of iterations to plot. Default is 15.
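For readers curious how a Gaussian turns into an ellipsoid, the sketch below builds the surface from a mean and a covariance matrix: the eigenvectors give the axes and the square roots of the eigenvalues give the radii. The function name `gaussian_ellipsoid`, the `n_std` scaling, and the example covariance are assumptions for illustration and are not taken from the toolkit.

```python
import numpy as np

def gaussian_ellipsoid(mean, cov, n_std=2.0, resolution=20):
    """Return surface points of the n_std ellipsoid of N(mean, cov) in 3D."""
    eigvals, eigvecs = np.linalg.eigh(cov)
    u = np.linspace(0.0, 2.0 * np.pi, resolution)
    v = np.linspace(0.0, np.pi, resolution)
    # Unit sphere sampled on a (resolution, resolution) grid.
    sphere = np.stack(
        [
            np.outer(np.cos(u), np.sin(v)),
            np.outer(np.sin(u), np.sin(v)),
            np.outer(np.ones_like(u), np.cos(v)),
        ],
        axis=-1,
    )
    radii = n_std * np.sqrt(eigvals)
    # Stretch the sphere along each axis, rotate it, then shift it to the mean.
    return mean + (sphere * radii) @ eigvecs.T

# Hypothetical cluster parameters.
mean = np.array([1.0, -2.0, 0.5])
cov = np.array([[4.0, 1.0, 0.0],
                [1.0, 2.0, 0.5],
                [0.0, 0.5, 1.0]])
surface = gaussian_ellipsoid(mean, cov, n_std=2.0)
```

The three coordinate grids `surface[..., 0]`, `surface[..., 1]`, and `surface[..., 2]` have the shape that matplotlib's `plot_surface` expects, and every point lies at Mahalanobis distance `n_std` from the mean.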
In the `plot()` method, the `show_plot` parameter controls whether the figure is shown in the editor. The default is `False`.
The images exported by the `plot()` method can be combined into a GIF file.
GmmViz.generateGIF(image_path = "doc/image/dim3/parms", # directory of the images showing each iteration
                   output_path_filename = "doc/image/dim3/parms/gif/GMM-3D-Parms.gif",
                   fps = 2) # frames per second, which sets how long each frame is shown
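Two details matter when assembling frames into an animation: the frames must be ordered by iteration number rather than alphabetically, and `fps` translates into a display time per frame. The sketch below shows both using only the standard library. The file names are hypothetical, since the naming scheme used by `plot()` is not documented here, and `ordered_frames` is not part of the toolkit.

```python
import re

def natural_key(name):
    """Split a name into text and integer parts so that 'iter_2' sorts before 'iter_10'."""
    return [int(tok) if tok.isdigit() else tok.lower() for tok in re.split(r"(\d+)", name)]

def ordered_frames(filenames, fps=2):
    """Return the PNG frames in iteration order and the duration of each frame in seconds."""
    frames = sorted(
        (f for f in filenames if f.lower().endswith(".png")),
        key=natural_key,
    )
    return frames, 1.0 / fps

# Hypothetical directory listing.
listing = ["iter_10.png", "iter_2.png", "iter_1.png", "notes.txt"]
frames, seconds_per_frame = ordered_frames(listing, fps=2)
```

With `fps = 2`, each frame is shown for half a second; a plain alphabetical sort would have placed `iter_10.png` before `iter_2.png`.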
"""
Interactive 3D plot
"""
# plot
pio.renderers.default = "browser" # it will open the browser to show the plots.
# GMM for 3 dim dataset
V3T = GmmViz(gmm3, utiPlotly=True)
V3T.plot(fig_title = "GMM-3D", path_prefix="doc/image/dim3-plotly/parms/", show_plot = False)
GmmViz.generateGIF(image_path = "doc/image/dim3-plotly/parms", output_path_filename = "doc/image/dim3-plotly/parms/gif/GMM-3D-Parms-plotly.gif", fps = 2)
# PCA_fit GMM
V7T = GmmViz(PCAGMM, utiPlotly=True)
V7T.plot(fig_title = "GMM-3D",
         path_prefix = "", # Directory to export the images, keep default value = "" when save_plot = False
         save_plot = False, # no need to save the figures.
         show_plot=True)
Visualize 3-dimensional data via Plotly
"""
2 DIMENSIONAL GAUSSIAN MIXTURE MODEL
"""
np.random.seed(128)
X2 = DataGenerater.genData(k=3, dim=2, points_per_cluster=200, lim=[-10, 10], plot = True)
gmm2 = GMM(n_clusters=3, random_state=129)
gmm2.fit(X2)
V2 = GmmViz(gmm2)
# plot convergence
V2.plot(fig_title="GMM-2D", path_prefix="doc/image/dim2/parms/")
# generate GIF (the convergence must be plotted first)
GmmViz.generateGIF(image_path = "doc/image/dim2/parms", output_path_filename = "doc/image/dim2/parms/gif/GMM-2D-Parms.gif", fps = 2)
# Likelihood
V2.plot_likelihood(output_path_filename="doc/image/dim2/ll/")
GmmViz.generateGIF(image_path = "doc/image/dim2/ll", output_path_filename = "doc/image/dim2/ll/gif/GMM-2D-LL.gif", fps = 2)
OS : macOS Sonoma 14.5
IDE: Visual Studio Code
Language : Python 3.9.7
Package list:
backports.shutil-get-terminal-size 1.0.0
imageio 2.9.0
matplotlib 3.7.2
matplotlib-inline 0.1.6
numpy 1.20.3
numpydoc 1.1.0
pandas 1.5.3
plotly 5.21.0
scipy 1.10.1
Denny Chen
- LinkedIn Profile : https://www.linkedin.com/in/dennychen-tahung/
- E-Mail : denny20700@gmail.com