Heatmaps
Learn Heatmaps in SQLPad's Data Science in Action: Interactive Visualization with Plotly and Pandas course with practical examples and guided lessons.
Introduction
In this lesson, we will learn about Heatmaps, a powerful data visualization tool that helps us analyze complex datasets and discover patterns or trends. We will be using the popular Plotly library in combination with Pandas to create interactive and customizable heatmaps. Heatmaps represent data using a color scale, where each cell's color intensity corresponds to the value of the underlying data point.
By the end of this lesson, you'll have a solid understanding of how to create and customize heatmaps using Plotly and Pandas, and how to interpret the information they convey.
Importing Required Libraries
We start by importing Pandas and Plotly Express, then load the gapminder dataset that ships with Plotly through px.data.gapminder(). It has one row per country and year, which we will reshape into a matrix for the heatmap.
Code Block 1: Importing Required Libraries and Loading the Dataset
import pandas as pd
import plotly.express as px
# Load the dataset
df = px.data.gapminder()
# Display the first few rows of the dataset
print(df.head())
Code Block 2: Creating a Heatmap
# Create a heatmap to visualize the life expectancy of countries over time
fig = px.imshow(
    df.pivot_table(index="year", columns="country", values="lifeExp"),
    labels=dict(x="Country", y="Year", color="Life Expectancy"),
    title="Life Expectancy by Country and Year")
# Show the heatmap
fig.show()
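The pivot_table call in Code Block 2 is what turns the long-format table (one row per country and year) into the matrix that px.imshow draws. The sketch below shows the same reshaping on a tiny hand-made table. The country names and values are invented purely for illustration and are not taken from the gapminder dataset.

```python
import pandas as pd

# Hypothetical long-format data: one row per (year, country) pair.
long_df = pd.DataFrame({
    "year":    [2000, 2000, 2005, 2005],
    "country": ["A", "B", "A", "B"],
    "lifeExp": [70.0, 65.0, 72.0, 66.5],
})

# Rows become years, columns become countries, cells hold lifeExp.
# pivot_table averages duplicate (year, country) pairs; here every pair
# occurs once, so each cell is simply the original value.
matrix = long_df.pivot_table(index="year", columns="country", values="lifeExp")

print(matrix)
```

Each row of the resulting matrix becomes one horizontal band of the heatmap, and each column becomes one vertical band.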
Loading and Preparing Data with Pandas
In this code example, we will load and prepare data using Pandas in order to create a heatmap.
Code Block 1: Loading and Preparing Data with Pandas
We will use the classic airline passengers dataset, loaded as a CSV file from a public GitHub URL (it is not one of Plotly's built-in datasets). It records the number of airline passengers per month from 1949 to 1960.
import pandas as pd
import plotly.express as px
import pyodide.http
# Load the airline passengers dataset (pyodide.http is only available in the browser-based Pyodide runtime)
df = pd.read_csv(pyodide.http.open_url('https://raw.githubusercontent.com/jbrownlee/Datasets/master/airline-passengers.csv'))
# If you are running the code outside of this website, e.g., in a notebook on your local computer,
# remove the pyodide import and use pd.read_csv directly:
# df = pd.read_csv('https://raw.githubusercontent.com/jbrownlee/Datasets/master/airline-passengers.csv')
df.rename(columns={'Passengers':'passengers', 'Month':'month'}, inplace=True)
df['first_day_of_mon'] = pd.to_datetime(df['month'])
df['year'] = df['first_day_of_mon'].dt.year
df['month'] = df['first_day_of_mon'].dt.month
# Pivot the dataset to create a matrix with years as rows and months as columns
df = df.pivot_table(values='passengers', index='year', columns='month')
# Display the first few rows of the dataframe
print(df.head())
Code Block 2: Constructing a Heatmap using Plotly
After preparing the data, we will use Plotly to create a heatmap that visualizes the number of passengers per month over the years.
import plotly.graph_objects as go
# Create a heatmap
fig = go.Figure(data=go.Heatmap(
    z=df.values,
    x=df.columns,
    y=df.index,
    colorscale='Viridis'))
# Customize the layout
fig.update_layout(
    title='Number of Passengers per Month (1949-1960)',
    xaxis_title='Month',
    yaxis_title='Year'
)
# Display the heatmap
fig.show()
Customizing the Color Scale
In this code example, we will customize the color scale of a heatmap using Plotly and Pandas. We will use the built-in dataset px.data.gapminder() from Plotly.
First, let's import the necessary libraries and create a dataframe using the gapminder dataset. We will then display the first few rows of the dataframe.
import plotly.express as px
import pandas as pd
df = px.data.gapminder()
print(df.head())
Now that we have our dataframe, let's create a heatmap and customize its color scale. We will use the px.density_heatmap() function, which groups the rows into cells by year and continent and, because z is given, sums the pop values that fall into each cell by default. The color_continuous_scale parameter sets the color scale.
fig = px.density_heatmap(df,
                         x="year",
                         y="continent",
                         z="pop",
                         color_continuous_scale="Viridis",
                         title="Customizing the Color Scale")
fig.show()
In this example, we used the "Viridis" color scale. You can experiment with different color scales by changing the color_continuous_scale parameter value. Some other color scales you can try are "Inferno", "Magma", and "Plasma".
Adding Annotations to the Heatmap
In this code example, we will be adding annotations to a heatmap using Plotly and Pandas. We will use the built-in dataset iris from Plotly for this exercise.
First, we need to load the iris dataset, group it by species, and calculate the mean values for each feature.
import plotly.express as px
import pandas as pd
# Load the iris dataset
data = px.data.iris()
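# Group by species and calculate the mean of the four measurement columns
# (species_id is a numeric label rather than a feature, so it is left out)
features = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
df = data.groupby('species')[features].mean()
print(df)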
Now that we have the mean values for each feature, we can create the heatmap and add annotations to display the values on the heatmap.
import plotly.figure_factory as ff
# Create the heatmap
fig = ff.create_annotated_heatmap(
    z=df.values,
    x=list(df.columns),
    y=list(df.index.astype(str)),
    annotation_text=df.round(2).values,
    colorscale='Viridis'
)
# Set the title and axis labels
fig.update_layout(
    title='Mean Feature Values by Iris Species',
    xaxis_title='Features',
    yaxis_title='Species'
)
# Display the heatmap
fig.show()
Exercises
1. Heatmaps with Plotly
Instruction
Create a heatmap to visualize the correlation matrix of the Iris dataset using Plotly.
My Solution
# Your solution goes here
Hint
- Load the Iris dataset using px.data.iris().
- Calculate the correlation matrix using the corr() function.
- Create the heatmap using ff.create_annotated_heatmap() with the correlation matrix values and a custom color scale.
- Set the axis labels using fig.update_layout() and the xaxis_title, yaxis_title, xaxis, and yaxis attributes.
- Display the heatmap using fig.show().
Solution
import plotly.figure_factory as ff
import plotly.express as px
import pandas as pd
# Loading the Iris dataset
iris = px.data.iris()
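# Calculating the correlation matrix of the four numeric features
# (species is a string column, which corr() cannot handle in recent Pandas
# versions, and species_id is only a numeric label)
features = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
corr_matrix = iris[features].corr()
# Generating the heatmap, with annotations rounded to two decimals
fig = ff.create_annotated_heatmap(
    z=corr_matrix.values,
    annotation_text=corr_matrix.round(2).values,
    colorscale='Viridis'
)
# Setting axis labels so that each tick shows the matching feature name
fig.update_layout(
    xaxis_title='Features',
    yaxis_title='Features',
    xaxis=dict(tickvals=list(range(len(features))), ticktext=features),
    yaxis=dict(tickvals=list(range(len(features))), ticktext=features)
)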
# Displaying the heatmap
fig.show()