Matplotlib: Subplots and Layout Management
When you need to display multiple plots within a single figure, Matplotlib's subplot functionality comes into play. This allows for organized visualization of different aspects of your data or comparisons between various models. Matplotlib offers several ways to create and manage subplot layouts.
1. Using plt.subplots()
This is the recommended method for creating figures and subplots. It returns a Figure object and an Axes object (or an array of Axes objects if you specify more than one subplot).
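What plt.subplots() returns depends on the shape of the grid, and you can check this directly. The sketch below selects the non-interactive Agg backend so that it runs without a display. That choice is an assumption for this example. In an interactive session, drop the matplotlib.use() line and call plt.show() at the end.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend (assumption: no display needed)
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

# One subplot: the second return value is a single Axes object
fig1, ax = plt.subplots()

# A 2x2 grid: a 2D NumPy array of Axes, indexed as axes_grid[row, col]
fig2, axes_grid = plt.subplots(nrows=2, ncols=2)

# A single row: the length-1 dimension is squeezed away, leaving a 1D array
fig3, axes_row = plt.subplots(nrows=1, ncols=3)

# squeeze=False always returns a 2D array, even for a 1x1 grid
fig4, axes_2d = plt.subplots(squeeze=False)

print(isinstance(ax, Axes))  # True
print(axes_grid.shape)       # (2, 2)
print(axes_row.shape)        # (3,)
print(axes_2d.shape)         # (1, 1)
```

Passing squeeze=False can simplify code that builds grids of varying size, because you can always index with [row, col].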
Basic subplots() usage (single plot)
import matplotlib.pyplot as plt
import numpy as np
# Create a figure and a single subplot
fig, ax = plt.subplots()
# Plot on the Axes object
ax.plot([0, 1, 2, 3], [0, 1, 4, 9])
ax.set_title('Single Subplot')
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')
plt.show()
Multiple Subplots in a Grid
Pass nrows and ncols to set the shape of the grid. The figsize argument sets the overall figure size in inches.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 2 * np.pi, 400)
y_sin = np.sin(x)
y_cos = np.cos(x)
y_tan = np.tan(x)
y_square = x**2
# Create a figure with 2 rows and 2 columns of subplots
# fig: the entire figure
# axes: a 2D NumPy array of Axes objects (subplots)
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))
# Plot on the first subplot (top-left)
axes[0, 0].plot(x, y_sin, color='blue')
axes[0, 0].set_title('Sine Wave')
axes[0, 0].set_ylabel('sin(x)')
# Plot on the second subplot (top-right)
axes[0, 1].plot(x, y_cos, color='red')
axes[0, 1].set_title('Cosine Wave')
axes[0, 1].set_ylabel('cos(x)')
# Plot on the third subplot (bottom-left)
axes[1, 0].plot(x, y_tan, color='green')
axes[1, 0].set_title('Tangent Wave')
axes[1, 0].set_xlabel('X-axis')
axes[1, 0].set_ylabel('tan(x)')
axes[1, 0].set_ylim(-5, 5) # Limit y-axis for tan
# Plot on the fourth subplot (bottom-right)
axes[1, 1].plot(x, y_square, color='purple')
axes[1, 1].set_title('X Squared')
axes[1, 1].set_xlabel('X-axis')
axes[1, 1].set_ylabel('x^2')
# Add the super title for the entire figure before tight_layout(), so space is reserved for it
fig.suptitle('Various Mathematical Functions', fontsize=16)
# Adjust the layout to prevent titles/labels from overlapping
fig.tight_layout()
plt.show()
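When every panel follows the same pattern, you can loop over axes.flat instead of indexing each cell by hand. The sketch below is a minimal version of the grid above under that idea. The panels list is a hypothetical pairing of functions and titles, and the Agg backend is chosen only so the script runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend (assumption: no display needed)
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 400)

# Hypothetical (function, title) pairs, one per subplot
panels = [
    (np.sin, 'Sine Wave'),
    (np.cos, 'Cosine Wave'),
    (np.tan, 'Tangent Wave'),
    (np.square, 'X Squared'),
]

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))

# axes.flat walks the 2D array row by row: [0, 0], [0, 1], [1, 0], [1, 1]
for ax, (func, title) in zip(axes.flat, panels):
    ax.plot(x, func(x))
    ax.set_title(title)

axes[1, 0].set_ylim(-5, 5)  # keep tan(x) readable near its asymptotes

fig.suptitle('Various Mathematical Functions', fontsize=16)
fig.tight_layout()
```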
2. Sharing Axes
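Subplots in a grid often share the same x-axis or y-axis range. plt.subplots() controls this with its sharex and sharey parameters. Each accepts True (share across all subplots), 'row', 'col', or False (the default, no sharing).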
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)
y3 = np.tan(x)
# Share the x-axis across all three subplots (sharey is left at its default, False)
fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(8, 10), sharex=True)
axes[0].plot(x, y1, color='blue')
axes[0].set_title('Sine Wave')
axes[0].set_ylabel('sin(x)')
axes[1].plot(x, y2, color='red')
axes[1].set_title('Cosine Wave')
axes[1].set_ylabel('cos(x)')
axes[2].plot(x, y3, color='green')
axes[2].set_title('Tangent Wave')
axes[2].set_xlabel('X-axis')
axes[2].set_ylabel('tan(x)')
axes[2].set_ylim(-5, 5)
fig.suptitle('Shared X-axis Example', fontsize=16)  # add before tight_layout() so space is reserved for it
fig.tight_layout()
plt.show()
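The 'row' and 'col' options link only some of the subplots. The sketch below assumes a 2x2 grid in which columns share x and rows share y. It uses get_shared_x_axes().joined() to confirm which pairs are linked. The data is arbitrary, and the Agg backend is used only so the script runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend (assumption: no display needed)
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# sharex='col': subplots in the same column share an x-axis
# sharey='row': subplots in the same row share a y-axis
fig, axes = plt.subplots(nrows=2, ncols=2, sharex='col', sharey='row')

axes[0, 0].plot(x, np.sin(x))
axes[0, 1].plot(2 * x, np.cos(x))        # right column covers a wider x-range
axes[1, 0].plot(x, 10 * np.sin(x))       # bottom row covers a taller y-range
axes[1, 1].plot(2 * x, 10 * np.cos(x))

# Check which pairs of Axes are linked
same_col = axes[0, 0].get_shared_x_axes().joined(axes[0, 0], axes[1, 0])
diff_col = axes[0, 0].get_shared_x_axes().joined(axes[0, 0], axes[0, 1])
same_row = axes[0, 0].get_shared_y_axes().joined(axes[0, 0], axes[0, 1])
print(same_col, diff_col, same_row)  # True False True
```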
3. Using plt.subplot() (Legacy Method)
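plt.subplot() adds a single subplot at a given position in a grid, makes it the current Axes, and returns it. Its arguments are nrows, ncols, index, where index is 1-based and counts left to right, then top to bottom. This approach works, but plt.subplots() is generally preferred because it returns explicit Figure and Axes objects instead of relying on pyplot's implicit "current Axes" state.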
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)
plt.figure(figsize=(10, 4))
# Create the first subplot (1 row, 2 columns, 1st plot)
plt.subplot(1, 2, 1)
plt.plot(x, y1)
plt.title('Plot 1 (Sine)')
plt.xlabel('X')
plt.ylabel('Y')
# Create the second subplot (1 row, 2 columns, 2nd plot)
plt.subplot(1, 2, 2)
plt.plot(x, y2, color='orange')
plt.title('Plot 2 (Cosine)')
plt.xlabel('X')
plt.ylabel('Y')
plt.tight_layout()
plt.show()
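The sketch below shows the implicit "current Axes" behavior of plt.subplot() and its three-digit shorthand. It then shows the object-oriented equivalent, Figure.add_subplot(), which takes the same nrows, ncols, index arguments. The Agg backend is an assumption so the script runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend (assumption: no display needed)
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# pyplot style: each call adds an Axes to the current figure and makes it current
plt.figure(figsize=(10, 4))
left = plt.subplot(1, 2, 1)             # index is 1-based: position 1 is the left cell
left_is_current = plt.gca() is left
right = plt.subplot(122)                # three-digit shorthand for (1, 2, 2)
plt.plot(x, np.cos(x), color='orange')  # draws on `right`, now the current Axes

# Object-oriented equivalent: call add_subplot() on an explicit Figure
fig = plt.figure(figsize=(10, 4))
ax1 = fig.add_subplot(1, 2, 1)
ax2 = fig.add_subplot(1, 2, 2)
ax1.plot(x, np.sin(x))
ax2.plot(x, np.cos(x), color='orange')
fig.tight_layout()
```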
4. GridSpec for Complex Layouts
For more complex or uneven subplot layouts, GridSpec offers fine-grained control over the placement and span of subplots.
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import numpy as np
# Create some dummy data
x = np.arange(0, 10, 0.1)
y1 = np.sin(x)
y2 = np.cos(x)
y3 = x
y4 = x**2
fig = plt.figure(figsize=(12, 8))
gs = GridSpec(nrows=3, ncols=3, figure=fig) # 3 rows, 3 columns grid
# A wide plot at the top
ax1 = fig.add_subplot(gs[0, :]) # First row, spans all 3 columns
ax1.plot(x, y1, color='blue')
ax1.set_title('Wide Plot (sin(x))')
# Two plots in the middle row
ax2 = fig.add_subplot(gs[1, 0]) # Second row, first column
ax2.plot(x, y2, color='red')
ax2.set_title('Plot 2 (cos(x))')
ax3 = fig.add_subplot(gs[1, 1:]) # Second row, spans the second and third columns
ax3.plot(x, y3, color='green')
ax3.set_title('Plot 3 (x)')
# Bottom row: a narrow plot and a wide histogram
ax4 = fig.add_subplot(gs[2, 0]) # Third row, first column
ax4.plot(x, y4, color='purple')
ax4.set_title('Plot 4 (x^2)')
ax5 = fig.add_subplot(gs[2, 1:]) # Third row, spans the second and third columns
ax5.hist(np.random.randn(1000), bins=30, color='gray', alpha=0.7)
ax5.set_title('Histogram')
fig.suptitle('Complex Layout with GridSpec', fontsize=16)  # add before tight_layout() so space is reserved for it
fig.tight_layout()
plt.show()
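GridSpec slices can span rows as well as columns, and the width_ratios and height_ratios arguments make cells unequal in size. The sketch below is a minimal illustration. It places a tall plot across two rows of the last column and makes that column twice as wide. The specific ratios and data are arbitrary, and the Agg backend is used only so the script runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend (assumption: no display needed)
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import numpy as np

x = np.arange(0, 10, 0.1)

fig = plt.figure(figsize=(12, 8))
# width_ratios makes the last column twice as wide as each of the others
gs = GridSpec(nrows=3, ncols=3, figure=fig, width_ratios=[1, 1, 2])

ax_tall = fig.add_subplot(gs[0:2, 2])   # first two rows of the last column
ax_tall.plot(x, x**2, color='purple')
ax_tall.set_title('Tall Plot (x^2)')

ax_small = fig.add_subplot(gs[0, 0])    # a single cell, for comparison
ax_small.plot(x, np.sin(x))
ax_small.set_title('sin(x)')

ax_bottom = fig.add_subplot(gs[2, :])   # the entire bottom row
ax_bottom.plot(x, np.cos(x), color='red')
ax_bottom.set_title('Wide Plot (cos(x))')

fig.tight_layout()
```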
5. Adjusting Spacing
plt.tight_layout() is usually sufficient. For manual control, call fig.subplots_adjust() instead. A later call to tight_layout() overwrites the values it sets.
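# fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9, wspace=0.4, hspace=0.4)
# left, right, bottom, top: positions of the subplot-area edges, as fractions of figure width/height
# wspace: padding between subplots, as a fraction of the average Axes width
# hspace: padding between subplots, as a fraction of the average Axes height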
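Here is a runnable version of the call above on a 2x2 grid. The figure size and panel titles are arbitrary, and the Agg backend is an assumption so the script runs headless. The resulting parameters can be read back from fig.subplotpars.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend (assumption: no display needed)
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8, 6))

# Set the subplot-area edges and the gaps between subplots explicitly
fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9, wspace=0.4, hspace=0.4)

for ax in axes.flat:
    ax.set_title('Panel')

print(fig.subplotpars.wspace, fig.subplotpars.hspace)  # 0.4 0.4
```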
Further Topics:
- add_axes() for arbitrary placement of axes.
- Integrating with plt.figure().
- Customizing subplot titles and labels.
- Dealing with overlapping elements in complex layouts.
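As a starting point for the first topic, Figure.add_axes() places an Axes anywhere on the figure. It takes a [left, bottom, width, height] rectangle in figure-fraction coordinates. The sketch below uses it to create an inset. The rectangle values are arbitrary choices, and the Agg backend is used only so the script runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend (assumption: no display needed)
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

fig = plt.figure(figsize=(8, 5))

# add_axes takes [left, bottom, width, height], each a fraction (0 to 1) of the figure
main_ax = fig.add_axes([0.1, 0.1, 0.85, 0.8])
main_ax.plot(x, np.sin(x))
main_ax.set_title('Main Axes')

# A small inset in the upper-right region, drawn on top of the main Axes
inset_ax = fig.add_axes([0.65, 0.6, 0.25, 0.2])
inset_ax.plot(x, np.cos(x), color='orange')
inset_ax.set_title('Inset', fontsize=9)
```

Unlike subplots, Axes created this way are not managed by tight_layout(), so you are responsible for avoiding overlaps.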
Mastering subplots and layout management is crucial for creating informative and professional-looking multi-panel visualizations in Matplotlib.