Python implementation of mean shift clustering
Step 1: Import Libraries
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
import matplotlib.pyplot as plt
- What this does:
- Imports NumPy for numerical computations.
- Imports MeanShift and estimate_bandwidth from sklearn to perform clustering and estimate bandwidth.
- Imports matplotlib.pyplot for visualizing the clusters.
Step 2: Create a Dataset
from sklearn.datasets import make_blobs
# Create a synthetic dataset
data, labels_true = make_blobs(n_samples=300, centers=4, cluster_std=0.6, random_state=42)
- What this does:
- Generates a synthetic dataset with 300 points around 4 centers.
- cluster_std controls the spread of clusters; random_state ensures reproducibility.
Output: A feature array (data) of shape (300, 2) and an array of true labels (labels_true) that records which of the 4 blobs each point was drawn from.
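As a quick sanity check, the arrays returned by make_blobs can be inspected before clustering. This is a minimal sketch that assumes the default of two features per point, which the call above does not override.

```python
import numpy as np
from sklearn.datasets import make_blobs

data, labels_true = make_blobs(n_samples=300, centers=4, cluster_std=0.6, random_state=42)

# data holds one row per point and one column per feature (2 by default)
print("data shape:", data.shape)
# labels_true holds the index of the blob each point was drawn from
print("labels_true shape:", labels_true.shape)
# Count how many points were generated for each blob
print("points per blob:", np.bincount(labels_true))
```

Mean-Shift never sees labels_true; it is only useful afterwards for checking how well the clusters match the generating blobs.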
Step 3: Estimate Bandwidth
# Estimate a bandwidth from nearest-neighbor distances (a heuristic starting point)
bandwidth = estimate_bandwidth(data, quantile=0.2, n_samples=150)
print("Estimated Bandwidth:", bandwidth)
- What this does:
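- Uses estimate_bandwidth to derive a bandwidth from nearest-neighbor distances. With quantile=0.2, each sampled point's distance to the farthest of its nearest 20% of neighbors is measured, and these distances are averaged. A larger quantile gives a larger bandwidth and usually fewer clusters.
- n_samples specifies how many points are randomly sampled for the estimate, which keeps the neighbor search cheap on large datasets.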
Output Example:
Estimated Bandwidth: 1.3517988972452414
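To see how the quantile parameter influences the estimate, the sketch below recomputes the bandwidth for a few quantile values on the same dataset. The particular quantile values are chosen here for illustration only, and random_state is passed explicitly so that the same 150 points are sampled on every call.

```python
from sklearn.cluster import estimate_bandwidth
from sklearn.datasets import make_blobs

data, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.6, random_state=42)

bandwidths = {}
for quantile in (0.1, 0.2, 0.3, 0.5):
    # A fixed random_state keeps the sampled subset identical across quantiles
    bandwidths[quantile] = estimate_bandwidth(
        data, quantile=quantile, n_samples=150, random_state=0
    )
    print(f"quantile={quantile}: bandwidth={bandwidths[quantile]:.3f}")
```

The bandwidth grows with the quantile because each point looks further out for its farthest neighbor. A bandwidth that is too large merges neighboring clusters, while one that is too small splits a single cluster into several.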
Step 4: Apply Mean-Shift Clustering
# Perform Mean-Shift clustering
mean_shift = MeanShift(bandwidth=bandwidth, bin_seeding=True)
mean_shift.fit(data)
# Extract cluster centers and labels
cluster_centers = mean_shift.cluster_centers_
labels = mean_shift.labels_
n_clusters = len(np.unique(labels))
print("Number of clusters:", n_clusters)
print("Cluster centers:\n", cluster_centers)
- What this does:
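- Initializes the MeanShift model with the estimated bandwidth. Setting bin_seeding=True starts the search from a coarse grid of binned points instead of from every data point, which reduces the number of seeds and speeds up fitting.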
- Fits the model to the data.
- Extracts the cluster centers, labels, and number of clusters.
Output Example:
Number of clusters: 4
Cluster centers:
[[-6.51836811 -6.96447756]
[ 4.45514655 1.97914855]
[-2.68551874 8.84006166]
[ 2.00598273 4.27105033]]
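To make the idea behind the fit step more concrete, here is a minimal sketch of the shifting procedure for a single seed. It assumes a flat kernel (every point within the bandwidth counts equally), and the helper name shift_to_mode is hypothetical. It is an illustration of the technique, not a reproduction of scikit-learn's internals, which also handle many seeds and merge nearby modes.

```python
import numpy as np
from sklearn.cluster import estimate_bandwidth
from sklearn.datasets import make_blobs


def shift_to_mode(seed, points, bandwidth, max_iter=300, tol=1e-3):
    """Move one seed towards a density mode using a flat kernel (illustrative helper)."""
    position = np.asarray(seed, dtype=float)
    for _ in range(max_iter):
        # Collect all points that lie within the bandwidth of the current position
        distances = np.linalg.norm(points - position, axis=1)
        neighbours = points[distances <= bandwidth]
        # The mean of those neighbours is the next position
        new_position = neighbours.mean(axis=0)
        # Stop once the position barely moves
        if np.linalg.norm(new_position - position) < tol:
            return new_position
        position = new_position
    return position


data, labels_true = make_blobs(n_samples=300, centers=4, cluster_std=0.6, random_state=42)
bandwidth = estimate_bandwidth(data, quantile=0.2, n_samples=150)

# Start from the first data point and follow it to its mode
mode = shift_to_mode(data[0], data, bandwidth)
print("Seed:", data[0])
print("Mode:", mode)
```

Running the same procedure from many seeds and merging the modes that end up close together yields the cluster centers; each point is then labeled by the center it belongs to.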
Step 5: Visualize the Results
plt.figure(figsize=(8, 6))
# Assign colors to clusters
colors = ['red', 'blue', 'green', 'purple', 'orange', 'cyan']
# Plot each cluster (colors repeat if more than six clusters are found)
for i in range(n_clusters):
    cluster_data = data[labels == i]
    plt.scatter(cluster_data[:, 0], cluster_data[:, 1], s=50, color=colors[i % len(colors)], label=f"Cluster {i}")
# Plot cluster centers
plt.scatter(cluster_centers[:, 0], cluster_centers[:, 1], s=200, color='black', marker='X', label="Centers")
plt.title("Mean-Shift Clustering")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.legend()
plt.show()
- What this does:
- Visualizes the data points grouped by cluster, with different colors representing different clusters.
- Highlights cluster centers with a black "X" marker.
Output:
A scatter plot with clusters in distinct colors and their centers marked as black "X."
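Because Mean-Shift decides the number of clusters itself, a fixed list of color names can be awkward. One alternative, sketched below, is to pass the labels to a colormap in a single scatter call. The file name mean_shift_clusters.png and the choice of the tab10 colormap are assumptions made for this example, and the figure is saved to disk rather than shown so the script also runs without a display.

```python
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove this line to use plt.show()
import matplotlib.pyplot as plt
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

data, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.6, random_state=42)
bandwidth = estimate_bandwidth(data, quantile=0.2, n_samples=150)
mean_shift = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(data)
centers = mean_shift.cluster_centers_

fig, ax = plt.subplots(figsize=(8, 6))
# c=labels lets the colormap pick one color per cluster label
ax.scatter(data[:, 0], data[:, 1], c=mean_shift.labels_, cmap="tab10", s=50)
ax.scatter(centers[:, 0], centers[:, 1], s=200, color="black", marker="X", label="Centers")
ax.set_title("Mean-Shift Clustering")
ax.set_xlabel("Feature 1")
ax.set_ylabel("Feature 2")
ax.legend()

output_path = Path("mean_shift_clusters.png")  # hypothetical output file name
fig.savefig(output_path)
```

The trade-off is that individual clusters no longer get their own legend entries, so the loop-based version remains the better choice when each cluster needs a label.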
Step 6: Analyze the Results
Cluster Labels: Each data point is assigned a label based on its cluster.
print("Labels for first 10 points:", labels[:10])
Output Example:
Labels for first 10 points: [0 1 3 2 1 1 0 3 3 2]
Cluster Centers: Each center marks one of the densest regions of the data, that is, a local maximum of the estimated density where the shifting procedure converged.
print("Cluster Centers:\n", cluster_centers)
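The analysis can be taken a little further. The sketch below counts the points in each cluster, compares the result with the true blob labels using the adjusted Rand index, and assigns two new points to their nearest cluster center with predict. The two new points are made-up coordinates used only for illustration.

```python
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs
from sklearn.metrics import adjusted_rand_score

data, labels_true = make_blobs(n_samples=300, centers=4, cluster_std=0.6, random_state=42)
bandwidth = estimate_bandwidth(data, quantile=0.2, n_samples=150)
mean_shift = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(data)
labels = mean_shift.labels_

# Number of points assigned to each cluster label
cluster_sizes = np.bincount(labels)
print("Points per cluster:", cluster_sizes)

# Agreement with the generating blobs; 1.0 means a perfect match up to label renaming
ari = adjusted_rand_score(labels_true, labels)
print("Adjusted Rand index:", round(ari, 3))

# Assign unseen points to the nearest cluster center (illustrative coordinates)
new_points = np.array([[0.0, 0.0], [-7.0, -7.0]])
predicted = mean_shift.predict(new_points)
print("Predicted labels for new points:", predicted)
```

The adjusted Rand index is only available here because the dataset is synthetic; on real data without true labels, the cluster sizes and the plot are usually the first things to inspect.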
Full Code
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
# Create dataset
data, labels_true = make_blobs(n_samples=300, centers=4, cluster_std=0.6, random_state=42)
# Estimate bandwidth
bandwidth = estimate_bandwidth(data, quantile=0.2, n_samples=150)
print("Estimated Bandwidth:", bandwidth)
# Perform Mean-Shift Clustering
mean_shift = MeanShift(bandwidth=bandwidth, bin_seeding=True)
mean_shift.fit(data)
# Extract cluster info
cluster_centers = mean_shift.cluster_centers_
labels = mean_shift.labels_
n_clusters = len(np.unique(labels))
print("Number of clusters:", n_clusters)
print("Cluster centers:\n", cluster_centers)
# Visualize clusters
plt.figure(figsize=(8, 6))
colors = ['red', 'blue', 'green', 'purple', 'orange', 'cyan']
for i in range(n_clusters):
    cluster_data = data[labels == i]
    plt.scatter(cluster_data[:, 0], cluster_data[:, 1], s=50, color=colors[i % len(colors)], label=f"Cluster {i}")
plt.scatter(cluster_centers[:, 0], cluster_centers[:, 1], s=200, color='black', marker='X', label="Centers")
plt.title("Mean-Shift Clustering")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.legend()
plt.show()
Summary of Output Details
- Bandwidth: Displays the bandwidth value estimated from the data.
- Number of Clusters: Prints the number of clusters found by the algorithm.
- Cluster Centers: Outputs the coordinates of cluster centers.
- Scatter Plot: Visualizes clusters with distinct colors and highlights their centers.
This step-by-step process helps in understanding Mean-Shift Clustering and provides a clear workflow for implementation in Python.