Subplots and facet grids
Learn Subplots and facet grids in SQLPad's Data Science in Action course with practical examples and guided lessons.
In this lesson, we will explore how to create subplots and facet grids using Plotly. Subplots let you display multiple plots in the same figure, and facet grids split a dataset into subsets (for example, one per category) and draw the same plot for each subset side by side. Let's dive in!
Subplots
Subplots are a way to display multiple plots in a single figure. The make_subplots function from plotly.subplots creates a grid of subplots with a specified number of rows and columns. You then place each plot in a specific cell of the grid by passing row and col to fig.add_trace.
Let's start by importing the necessary libraries and creating a simple subplot:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import plotly.express as px
# Load built-in dataset
df = px.data.iris()
# Create a subplot with 1 row and 2 columns
fig = make_subplots(rows=1, cols=2)
# Add scatter plots to the subplot
fig.add_trace(go.Scatter(x=df['sepal_width'], y=df['sepal_length'], mode='markers', name='Sepal'), row=1, col=1)
fig.add_trace(go.Scatter(x=df['petal_width'], y=df['petal_length'], mode='markers', name='Petal'), row=1, col=2)
# Label the x- and y-axes of each subplot
fig.update_xaxes(title_text="Sepal Width", row=1, col=1)
fig.update_yaxes(title_text="Sepal Length", row=1, col=1)
fig.update_xaxes(title_text="Petal Width", row=1, col=2)
fig.update_yaxes(title_text="Petal Length", row=1, col=2)
# Show the figure
fig.show()
In the example above, we created a figure with one row and two columns of subplots. We then added one scatter plot to each cell with add_trace, using the row and col arguments to choose the cell. Lastly, we set the x- and y-axis titles of each subplot with update_xaxes and update_yaxes.
Facet Grids
A facet grid is a collection of subplots that display the same type of plot for different subsets of the data, usually one subset per value of a categorical variable. This is particularly useful when you want to see how the relationship between two variables changes from one group to another.
In Plotly, you can create a facet grid with the facet_col or facet_row parameter of plotly.express functions such as px.scatter. Let's create a facet grid using the Iris dataset:
import plotly.express as px
# Load built-in dataset
df = px.data.iris()
# Create a scatter plot facet grid
fig = px.scatter(df, x='sepal_width', y='sepal_length', color='species', facet_col='species', title='Sepal Width vs Sepal Length for each Species')
# Show the figure
fig.show()
In the example above, we used the facet_col parameter to create a facet grid with each column representing a different species in the Iris dataset. The scatter plot shows the relationship between sepal width and sepal length for each species.
You can also create a facet grid using the facet_row parameter:
# Create a scatter plot facet grid with rows
fig = px.scatter(df, x='sepal_width', y='sepal_length', color='species', facet_row='species', title='Sepal Width vs Sepal Length for each Species (rows)')
# Show the figure
fig.show()
In this example, we used the facet_row parameter to create a facet grid with each row representing a different species in the Iris dataset.
That's it! Now you know how to create subplots and facet grids in Plotly. Try different numbers of rows and columns in make_subplots, or pass both facet_row and facet_col to a Plotly Express function, to build more complex visualizations.
Exercises
1. Subplots and Facet Grids
Instruction
Create a subplot with 1 row and 2 columns, displaying scatter plots of sepal width vs sepal length and petal width vs petal length using the Iris dataset. Then, create a facet grid with each column representing a different species in the Iris dataset, showing the relationship between sepal width and sepal length for each species.
My Solution
# Your solution goes here
Hint
- Import the necessary libraries: plotly.graph_objects, plotly.subplots, pandas, and plotly.express.
- Load the Iris dataset with px.data.iris(), which returns a pandas DataFrame.
- Create a subplot with 1 row and 2 columns using make_subplots.
- Add scatter plots to the subplot using add_trace, specifying the row and column.
- Update the x- and y-axis titles for each plot using update_xaxes and update_yaxes.
- Show the figure using fig.show().
- Create a scatter plot facet grid using px.scatter and the facet_col parameter.
- Show the facet grid figure using fig.show().
Solution
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import plotly.express as px
# Load built-in dataset
df = px.data.iris()
# Create a subplot with 1 row and 2 columns
fig = make_subplots(rows=1, cols=2)
# Add scatter plots to the subplot
fig.add_trace(go.Scatter(x=df['sepal_width'], y=df['sepal_length'], mode='markers', name='Sepal'), row=1, col=1)
fig.add_trace(go.Scatter(x=df['petal_width'], y=df['petal_length'], mode='markers', name='Petal'), row=1, col=2)
# Label the x- and y-axes of each subplot
fig.update_xaxes(title_text="Sepal Width", row=1, col=1)
fig.update_yaxes(title_text="Sepal Length", row=1, col=1)
fig.update_xaxes(title_text="Petal Width", row=1, col=2)
fig.update_yaxes(title_text="Petal Length", row=1, col=2)
# Show the figure
fig.show()
# Create a scatter plot facet grid
fig = px.scatter(df, x='sepal_width', y='sepal_length', color='species', facet_col='species', title='Sepal Width vs Sepal Length for each Species')
# Show the figure
fig.show()