Unsupervised Learning January 01, 2025

Python Implementation of Mean-Shift Clustering

Step 1: Import Libraries

import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
import matplotlib.pyplot as plt
  • What this does:
    • Imports NumPy for numerical computations.
    • Imports MeanShift and estimate_bandwidth from sklearn to perform clustering and estimate bandwidth.
    • Imports matplotlib.pyplot for visualizing the clusters.

Step 2: Create a Dataset

from sklearn.datasets import make_blobs

# Create a synthetic dataset
data, labels_true = make_blobs(n_samples=300, centers=4, cluster_std=0.6, random_state=42)
  • What this does:
    • Generates a synthetic dataset with 300 points around 4 centers.
    • cluster_std controls the spread of clusters; random_state ensures reproducibility.

Output: A dataset with features (data) and true labels (labels_true).

Step 3: Estimate Bandwidth

# Estimate the optimal bandwidth
bandwidth = estimate_bandwidth(data, quantile=0.2, n_samples=150)
print("Estimated Bandwidth:", bandwidth)
  • What this does:
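    • Uses estimate_bandwidth to compute a heuristic bandwidth from nearest-neighbor distances. The quantile parameter (between 0 and 1) controls how many neighbors are considered: a larger value gives a larger bandwidth and usually fewer clusters. Per the scikit-learn documentation, 0.5 corresponds to the median of all pairwise distances.
    • n_samples sets how many points are sampled for the estimate; using a subset (here 150 of the 300 points) keeps the computation fast.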

Output Example:

Estimated Bandwidth: 1.3517988972452414
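
To see how quantile influences the result, the following minimal sketch sweeps a few values and reports the bandwidth and cluster count for each. The quantile values are arbitrary choices for illustration, and the full dataset is used for the estimate (n_samples is left at its default) so the run is deterministic.

```python
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

# Same synthetic dataset as in Step 2
data, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.6, random_state=42)

# Illustrative quantile values (assumed for this sketch, not tuned)
quantiles = [0.1, 0.2, 0.3, 0.5]
bandwidths = []

for q in quantiles:
    # Estimate the bandwidth on the full dataset for this quantile
    bw = estimate_bandwidth(data, quantile=q)
    model = MeanShift(bandwidth=bw, bin_seeding=True).fit(data)
    n_found = len(np.unique(model.labels_))
    bandwidths.append(bw)
    print(f"quantile={q:.1f}  bandwidth={bw:.3f}  clusters={n_found}")
```

A larger quantile should never produce a smaller bandwidth here, so the printed bandwidths increase down the list while the cluster count tends to fall.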

Step 4: Apply Mean-Shift Clustering

# Perform Mean-Shift clustering
mean_shift = MeanShift(bandwidth=bandwidth, bin_seeding=True)
mean_shift.fit(data)

# Extract cluster centers and labels
cluster_centers = mean_shift.cluster_centers_
labels = mean_shift.labels_
n_clusters = len(np.unique(labels))

print("Number of clusters:", n_clusters)
print("Cluster centers:\n", cluster_centers)
  • What this does:
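    • Initializes the MeanShift model with the estimated bandwidth. Setting bin_seeding=True starts the search from a coarse grid of bins (sized by the bandwidth) instead of from every data point, so fewer seeds are processed and the fit runs faster.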
    • Fits the model to the data.
    • Extracts the cluster centers, labels, and number of clusters.

Output Example (illustrative; exact values depend on the generated data and library version):

Number of clusters: 4
Cluster centers:
 [[-6.51836811 -6.96447756]
  [ 4.45514655  1.97914855]
  [-2.68551874  8.84006166]
  [ 2.00598273  4.27105033]]
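
The effect of bin seeding can be inspected directly. The sketch below is an illustration rather than part of the original workflow: it calls scikit-learn's get_bin_seeds helper with the bandwidth as the bin size and compares the number of seeds with the number of data points.

```python
from sklearn.cluster import estimate_bandwidth, get_bin_seeds
from sklearn.datasets import make_blobs

# Same dataset and bandwidth estimate as in Steps 2 and 3
data, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.6, random_state=42)
bandwidth = estimate_bandwidth(data, quantile=0.2, n_samples=150)

# Bin the points onto a grid whose cell size equals the bandwidth;
# each occupied bin contributes one seed
seeds = get_bin_seeds(data, bin_size=bandwidth, min_bin_freq=1)

print("Data points:", len(data))
print("Bin seeds:  ", len(seeds))
```

With bin_seeding=False, every data point would be used as a seed instead, which is slower on larger datasets.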

Step 5: Visualize the Results

# Assign colors to clusters
plt.figure(figsize=(8, 6))
colors = ['red', 'blue', 'green', 'purple', 'orange', 'cyan']

# Plot each cluster
for i in range(n_clusters):
    cluster_data = data[labels == i]
    # Cycle through the palette so more than six clusters do not raise an IndexError
    plt.scatter(cluster_data[:, 0], cluster_data[:, 1], s=50, color=colors[i % len(colors)], label=f"Cluster {i}")

# Plot cluster centers
plt.scatter(cluster_centers[:, 0], cluster_centers[:, 1], s=200, color='black', marker='X', label="Centers")

plt.title("Mean-Shift Clustering")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.legend()
plt.show()
  • What this does:
    • Visualizes the data points grouped by cluster, with different colors representing different clusters.
    • Highlights cluster centers with a black "X" marker.

Output:
A scatter plot with clusters in distinct colors and their centers marked with a black "X".
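
The plot gives a visual check. Because the synthetic dataset also comes with ground-truth labels (labels_true), a numeric check is possible as well. The sketch below uses the Adjusted Rand Index from sklearn.metrics; this metric is an addition chosen for illustration, not something the walkthrough above relies on.

```python
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs
from sklearn.metrics import adjusted_rand_score

# Same dataset, bandwidth, and model settings as in Steps 2 to 4
data, labels_true = make_blobs(n_samples=300, centers=4, cluster_std=0.6, random_state=42)
bandwidth = estimate_bandwidth(data, quantile=0.2, n_samples=150)
labels = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit_predict(data)

# ARI is 1.0 for a perfect match and close to 0 for random labelings;
# it ignores how the cluster IDs happen to be numbered
ari = adjusted_rand_score(labels_true, labels)
print(f"Adjusted Rand Index: {ari:.3f}")
```

On real data without ground-truth labels this comparison is not available, so it is mainly useful for sanity-checking parameters on synthetic examples.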

Step 6: Analyze the Results

  • Cluster Labels: Each data point is assigned a label based on its cluster.

    print("Labels for first 10 points:", labels[:10])
    

    Output Example:

    Labels for first 10 points: [0 1 3 2 1 1 0 3 3 2]
    
  • Cluster Centers: Each center approximates a mode (local maximum) of the estimated density of the data.

    print("Cluster Centers:\n", cluster_centers)
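
To make the idea of "moving toward a density maximum" concrete, here is a minimal sketch of the mean-shift update for a single starting point, using a flat kernel. The function name shift_to_mode, the two-blob demo data, and the bandwidth of 2.0 are all assumptions made for this illustration; scikit-learn's MeanShift adds seeding, merging of nearby modes, and label assignment on top of this basic step.

```python
import numpy as np
from sklearn.datasets import make_blobs

def shift_to_mode(start, points, bandwidth, max_iter=100, tol=1e-4):
    """Repeatedly move a point to the mean of its neighbors (flat kernel)."""
    current = np.asarray(start, dtype=float)
    for _ in range(max_iter):
        distances = np.linalg.norm(points - current, axis=1)
        neighbors = points[distances <= bandwidth]
        if len(neighbors) == 0:
            # No points in the window: nothing to shift toward
            return current
        new = neighbors.mean(axis=0)
        if np.linalg.norm(new - current) < tol:
            return new  # the shift has become negligible
        current = new
    return current

# Hypothetical demo data: two well-separated blobs with known centers
demo_points, _ = make_blobs(
    n_samples=200, centers=[[0.0, 0.0], [5.0, 5.0]], cluster_std=0.5, random_state=0
)

mode = shift_to_mode(demo_points[0], demo_points, bandwidth=2.0)
print("Start:", demo_points[0])
print("Mode: ", mode)
```

Starting from any data point, the returned mode should land close to the center of the blob that point belongs to.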
    

Full Code

import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs

# Create dataset
data, labels_true = make_blobs(n_samples=300, centers=4, cluster_std=0.6, random_state=42)

# Estimate bandwidth
bandwidth = estimate_bandwidth(data, quantile=0.2, n_samples=150)
print("Estimated Bandwidth:", bandwidth)

# Perform Mean-Shift Clustering
mean_shift = MeanShift(bandwidth=bandwidth, bin_seeding=True)
mean_shift.fit(data)

# Extract cluster info
cluster_centers = mean_shift.cluster_centers_
labels = mean_shift.labels_
n_clusters = len(np.unique(labels))

print("Number of clusters:", n_clusters)
print("Cluster centers:\n", cluster_centers)

# Visualize clusters
plt.figure(figsize=(8, 6))
colors = ['red', 'blue', 'green', 'purple', 'orange', 'cyan']

for i in range(n_clusters):
    cluster_data = data[labels == i]
    # Cycle through the palette so more than six clusters do not raise an IndexError
    plt.scatter(cluster_data[:, 0], cluster_data[:, 1], s=50, color=colors[i % len(colors)], label=f"Cluster {i}")

plt.scatter(cluster_centers[:, 0], cluster_centers[:, 1], s=200, color='black', marker='X', label="Centers")
plt.title("Mean-Shift Clustering")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.legend()
plt.show()

Summary of Output Details

  1. Bandwidth: Displays the estimated optimal bandwidth value.
  2. Number of Clusters: Prints the number of clusters found by the algorithm.
  3. Cluster Centers: Outputs the coordinates of cluster centers.
  4. Scatter Plot: Visualizes clusters with distinct colors and highlights their centers.

This step-by-step process helps in understanding Mean-Shift Clustering and provides a clear workflow for implementation in Python.

Purnima