Scikit-learn: Clustering Algorithms
Clustering is an unsupervised machine learning task that involves grouping a set of objects in such a way that objects in the same group (called a cluster) are more similar to each other than to those in other groups. Scikit-learn provides a variety of clustering algorithms to discover hidden patterns or groupings in data.
1. K-Means Clustering
K-Means is one of the simplest and most popular unsupervised learning algorithms. It partitions n observations into k clusters, assigning each observation to the cluster with the nearest mean (the centroid), which serves as the cluster's prototype.
Key Concepts:
- Centroid: The center of a cluster, typically the mean of all points assigned to that cluster.
- Inertia: The sum of squared distances of samples to their closest cluster center. K-Means tries to minimize inertia.
- n_clusters: The number of clusters to form, which must be specified beforehand.
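The assign-then-update loop behind these concepts (often called Lloyd's algorithm) can be sketched in plain NumPy to show how centroids and inertia relate. This is a minimal illustration, not scikit-learn's implementation: the helper name `lloyd_kmeans` is hypothetical, it uses simple random initialization instead of k-means++, and it runs only once rather than keeping the best of several runs.

```python
import numpy as np

def lloyd_kmeans(X, k, n_iter=100, seed=0):
    """Hypothetical helper: plain Lloyd's K-Means with random initialization."""
    rng = np.random.default_rng(seed)
    # Start from k distinct data points (simple random init, NOT k-means++)
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    for _ in range(n_iter):
        # Assignment step: squared distance from every point to every centroid, shape (n, k)
        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1)
        # Update step: move each centroid to the mean of its points (keep it if the cluster is empty)
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        if np.allclose(new_centroids, centroids):  # converged: centroids stopped moving
            break
        centroids = new_centroids
    # Final assignment and inertia: sum of squared distances to the closest centroid
    d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    labels = d2.argmin(axis=1)
    inertia = d2.min(axis=1).sum()
    return labels, centroids, inertia

# Two well-separated groups of three points each
X = np.array([[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]], dtype=float)
labels, centroids, inertia = lloyd_kmeans(X, k=2)
print(labels, inertia)  # each group forms one cluster; inertia = 4/3 + 4/3 = 8/3
```

Each triangle of points contributes 4/3 to the inertia around its centroid, which is the quantity K-Means tries to minimize.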
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs # For synthetic clustering data
from sklearn.metrics import silhouette_score
# 1. Generate synthetic clustering data
# n_samples: total number of points
# centers: number of centers to generate
# cluster_std: standard deviation of the clusters
X, y_true = make_blobs(n_samples=300, centers=4, cluster_std=0.60, random_state=0)
# 2. Create a K-Means model
# n_clusters: number of clusters (a key hyperparameter)
# init='k-means++': smart initialization method to help convergence
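# n_init: number of times k-means is run with different centroid seeds (the run with the lowest inertia is kept);
# 'auto' chooses this automatically (1 run with init='k-means++') and requires scikit-learn >= 1.2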
kmeans = KMeans(n_clusters=4, init='k-means++', n_init='auto', random_state=42)
# 3. Fit the model and predict cluster labels
kmeans.fit(X)
y_kmeans = kmeans.predict(X)
# 4. Evaluate (e.g., using Silhouette Score if ground truth is unknown)
# Silhouette Score: ranges from -1 (worst) to +1 (best); values near 0 indicate overlapping clusters.
# Note: make_blobs also returns ground-truth labels (y_true), but because clustering is unsupervised,
# we evaluate with an intrinsic metric such as silhouette_score.
score = silhouette_score(X, y_kmeans)
print(f"K-Means Silhouette Score: {score:.2f}")
print(f"K-Means Inertia: {kmeans.inertia_:.2f}")
# 5. Visualize the clusters and their centroids
plt.figure(figsize=(8, 6))
plt.scatter(X[:, 0], X[:, 1], c=y_kmeans, s=50, cmap='viridis', alpha=0.7)
centers = kmeans.cluster_centers_
plt.scatter(centers[:, 0], centers[:, 1], c='red', s=200, alpha=0.9, marker='X', label='Centroids')
plt.title('K-Means Clustering')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.legend()
plt.show()
Choosing the Optimal K for K-Means
Two common techniques are the "Elbow Method" (plotting inertia against K) and comparing "Silhouette Scores" across values of K.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score
X, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.60, random_state=0)
# Elbow Method to find optimal K
wcss = [] # Within-Cluster Sum of Squares (Inertia)
for i in range(1, 11):  # Try K from 1 to 10
    kmeans = KMeans(n_clusters=i, init='k-means++', n_init='auto', random_state=42)
    kmeans.fit(X)
    wcss.append(kmeans.inertia_)
plt.figure(figsize=(8, 4))
plt.plot(range(1, 11), wcss, marker='o')
plt.title('Elbow Method')
plt.xlabel('Number of clusters (K)')
plt.ylabel('WCSS (Inertia)')
plt.xticks(np.arange(1, 11, 1))
plt.grid(True)
plt.show()
# The "elbow" (the K after which WCSS stops dropping sharply) suggests a good balance between distortion and number of clusters.
# Silhouette Score for different K
silhouette_scores = []
for i in range(2, 11):  # Silhouette score requires at least 2 clusters
    kmeans = KMeans(n_clusters=i, init='k-means++', n_init='auto', random_state=42)
    kmeans.fit(X)
    score = silhouette_score(X, kmeans.labels_)
    silhouette_scores.append(score)
plt.figure(figsize=(8, 4))
plt.plot(range(2, 11), silhouette_scores, marker='o', color='green')
plt.title('Silhouette Score Method')
plt.xlabel('Number of clusters (K)')
plt.ylabel('Silhouette Score')
plt.xticks(np.arange(2, 11, 1))
plt.grid(True)
plt.show()
# Higher silhouette score indicates better-defined clusters.
2. DBSCAN (Density-Based Spatial Clustering of Applications with Noise)
DBSCAN is a density-based clustering algorithm: it groups together points that are closely packed and marks points that lie alone in low-density regions as outliers (noise). It does not require specifying the number of clusters beforehand.
Key Concepts:
- eps (epsilon): The maximum distance between two samples for one to be considered as in the neighborhood of the other.
- min_samples: The number of samples (or total weight) in a neighborhood for a point to be considered a core point.
- Core point: A point that has at least min_samples points (including itself) within distance eps.
- Border point: A point that has fewer than min_samples points within eps, but is in the neighborhood of a core point.
- Noise point: A point that is neither a core point nor a border point.
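These definitions can be checked on a tiny hand-made dataset. In this sketch the points and the values eps=0.6 and min_samples=3 are chosen only for illustration: 0.5 and 1.0 each have three points within 0.6 (counting themselves), so they are core points; 0.0 and 1.5 have only two but lie within eps of a core point, so they are border points; 5.0 is isolated, so it is noise. The fitted model exposes this through its `core_sample_indices_` and `labels_` attributes.

```python
import numpy as np
from sklearn.cluster import DBSCAN

# Five 1-D samples: a dense run from 0.0 to 1.5 plus an isolated point at 5.0
X = np.array([[0.0], [0.5], [1.0], [1.5], [5.0]])
db = DBSCAN(eps=0.6, min_samples=3).fit(X)

is_core = np.zeros(len(X), dtype=bool)
is_core[db.core_sample_indices_] = True   # indices of core samples
is_noise = db.labels_ == -1                # noise points get label -1
# Anything in a cluster that is not a core point is a border point
kinds = np.where(is_core, "core", np.where(is_noise, "noise", "border"))

for x, label, kind in zip(X.ravel(), db.labels_, kinds):
    print(f"x={x:.1f}  label={int(label):>2}  {kind}")
# expected:
# x=0.0  label= 0  border
# x=0.5  label= 0  core
# x=1.0  label= 0  core
# x=1.5  label= 0  border
# x=5.0  label=-1  noise
```

Note that border points receive the label of a neighboring core point's cluster, so they cannot be distinguished from core points by `labels_` alone.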
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import DBSCAN
from sklearn.datasets import make_moons # For non-globular data
# 1. Generate synthetic non-globular data
X, y_true = make_moons(n_samples=200, noise=0.05, random_state=0)
# 2. Create a DBSCAN model
# eps: radius of neighborhood
# min_samples: minimum points in eps-neighborhood to be a core point
dbscan = DBSCAN(eps=0.3, min_samples=5)
# 3. Fit the model and get cluster labels
dbscan.fit(X)
y_dbscan = dbscan.labels_ # -1 indicates noise points
# 4. Visualize the clusters
plt.figure(figsize=(8, 6))
plt.scatter(X[:, 0], X[:, 1], c=y_dbscan, s=50, cmap='plasma', alpha=0.7)
plt.title('DBSCAN Clustering')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.show()
print(f"Number of clusters found by DBSCAN (excluding noise): {len(set(y_dbscan)) - (1 if -1 in y_dbscan else 0)}")
print(f"Number of noise points: {list(y_dbscan).count(-1)}")
3. Hierarchical Clustering (Agglomerative Clustering)
Hierarchical clustering builds a hierarchy of clusters. Agglomerative clustering is a "bottom-up" approach: each observation starts in its own cluster, and pairs of clusters are merged as one moves up the hierarchy.
Key Concepts:
- Linkage: The criterion used to decide which clusters to merge (e.g., 'ward', 'average', 'complete', 'single').
- Distance Metric: How the distance between individual data points is measured.
- Dendrogram: A tree-like diagram that records the sequence of merges (or splits).
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import make_blobs
from scipy.cluster.hierarchy import dendrogram, linkage # For visualization
X, y_true = make_blobs(n_samples=50, centers=3, cluster_std=0.8, random_state=0)
# 1. Create an Agglomerative Clustering model
# n_clusters: The number of clusters to form.
# linkage: Which linkage criterion to use.
agg_clustering = AgglomerativeClustering(n_clusters=3, linkage='ward')
# 2. Fit and predict
y_agg = agg_clustering.fit_predict(X)
# 3. Visualize the clusters
plt.figure(figsize=(8, 6))
plt.scatter(X[:, 0], X[:, 1], c=y_agg, s=50, cmap='rainbow', alpha=0.7)
plt.title('Agglomerative Clustering')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.show()
# 4. Visualize the dendrogram (requires scipy)
# Perform hierarchical clustering for dendrogram
linked = linkage(X, method='ward') # 'ward' minimizes variance within clusters
plt.figure(figsize=(10, 7))
dendrogram(linked, orientation='top', distance_sort='descending', show_leaf_counts=True)
plt.title('Hierarchical Clustering Dendrogram')
plt.xlabel('Sample Index')
plt.ylabel('Distance')
plt.show()
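The dendrogram can be turned back into flat cluster labels with SciPy's `fcluster`, either by asking for a fixed number of clusters or by cutting at a height on the Distance axis. This sketch compares the 3-cluster cut with the `AgglomerativeClustering` labels using the adjusted Rand index, which ignores the arbitrary numbering of clusters (1.0 means identical partitions); since both build a Ward tree on the same data, they would be expected to agree closely. The cut height used here, halfway between the merge that leaves 3 clusters and the one that leaves 2, is only one reasonable choice.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import make_blobs
from sklearn.metrics import adjusted_rand_score

X, _ = make_blobs(n_samples=50, centers=3, cluster_std=0.8, random_state=0)
linked = linkage(X, method='ward')

# Option 1: ask for at most 3 flat clusters; fcluster labels start at 1
labels_maxclust = fcluster(linked, t=3, criterion='maxclust')

# Option 2: cut at a height on the dendrogram's Distance axis.
# Column 2 of the linkage matrix holds merge heights; row -3 is the merge that
# leaves 3 clusters and row -2 the one that leaves 2, so a height in between gives 3.
cut_height = (linked[-3, 2] + linked[-2, 2]) / 2
labels_height = fcluster(linked, t=cut_height, criterion='distance')

sk_labels = AgglomerativeClustering(n_clusters=3, linkage='ward').fit_predict(X)

print("ARI vs scikit-learn (maxclust):", adjusted_rand_score(sk_labels, labels_maxclust))
print("ARI vs scikit-learn (height cut):", adjusted_rand_score(sk_labels, labels_height))
```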
Further Topics:
- Mean-Shift clustering
- Gaussian Mixture Models (GMM)
- External evaluation metrics for clustering when ground-truth labels are known (Adjusted Rand Index, Mutual Information, Homogeneity, Completeness, V-measure).
- Preprocessing for clustering (scaling, dimensionality reduction).
Clustering algorithms are powerful tools for exploratory data analysis and discovering underlying structures in data. Choosing the right algorithm depends heavily on the characteristics of your dataset and the definition of a "cluster" relevant to your problem.