Drill-downs and filtering with Streamlit and Altair
For the last few months, I've written several apps using Streamlit, and one of my favorite libraries for optimizing the look and feel of my app is Altair. You can display an Altair chart definition in Streamlit using the st.altair_chart
widget. Altair is a powerful library full of styling, configurations, and interactions.
In this post, I'll show you how to create interactive and dynamic visualizations using Altair in Streamlit.
Avoiding re-runs
One significant advantage of creating drill-downs and filters with Altair on Streamlit is that these interactions occur at the front-end level and don't require a re-run of your Streamlit app.
Let's create some data first
I used a simple approach for this example and created a Pandas DataFrame with sales data. I also used the st.cache_data
decorator to save the DataFrame in the cache:
@st.cache_data
def get_data():
dates = pd.date_range(start="1/1/2022", end="12/31/2022")
data = pd.DataFrame()
sellers = {
"LATAM": ["S01", "S02", "S03"],
"EMEA": ["S10", "S11", "S12", "S13"],
"NA": ["S21", "S22", "S23", "S24", "S25", "S26"],
"APAC": ["S31", "S32", "S33", "S34", "S35", "S36"],
}
rows = 25000
data["transaction_date"] = np.random.choice([str(i) for i in dates], size=rows)
data["region"] = np.random.choice(regions, size=rows, p=[0.1, 0.3, 0.4, 0.2])
data["transaction_amount"] = np.random.uniform(100, 250000, size=rows).round(2)
data["seller"] = data.apply(
lambda x: np.random.choice(sellers.get(x["region"])), axis=1
)
return data.sort_values(by="transaction_date").reset_index(drop=True)
Color consistency
When creating drill-downs, it's crucial to maintain color consistency to enhance the clarity of your charts. Altair scales can be used to specify color domains and ranges that persist during drill-down.
Use the following three list variables in the Altair chart definitions:
regions = ["LATAM", "EMEA", "NA", "APAC"]
colors = ["#aa423a","#f6b404", "#327a88","#303e55","#c7ab84","#b1dbaa",
"#feeea5","#3e9a14","#6e4e92","#c98149", "#d1b844","#8db6d8"]
months = [
"Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec",
]
Selections
To define a selection in Altair, you can use the single, interval, or point methods.
In this example, a single selection is used to drill down based on region. The "empty" attribute can be specified to determine whether all objects or no objects are visible when no selection has been made.
region_select = alt.selection_single(fields=["region"], empty="all")
Filtering chart
This chart definition includes an Altair add_selection
method for filtering other chart definitions that will be created later.
To create a dynamic experience, the opacity attribute is used to reduce the opacity of unselected objects to 25%. Note that the color method uses the scale attribute to limit the available colors to the domain and range variables defined previously:
region_pie = (
(
alt.Chart(sales_data)
.mark_arc(innerRadius=50)
.encode(
theta=alt.Theta(
"transaction_amount",
type="quantitative",
aggregate="sum",
title="Sum of Transactions",
),
color=alt.Color(
field="region",
type="nominal",
scale=alt.Scale(domain=regions, range=colors),
title="Region",
),
opacity=alt.condition(region_select, alt.value(1), alt.value(0.25)),
)
)
.add_selection(region_select)
.properties(title="Region Sales")
)
Filtered and faceted charts
To enable filtering, implement the Altair transform_filter
method. Faceting is enabled by using the facet method inside the encoding method, which uses a field attribute and a column attribute to break down the chart into multiple related charts.
If you need to facet by two different fields, use the repeat()
method:
region_summary = (
(
alt.Chart(sales_data)
.mark_bar()
.encode(
x=alt.X(
"month(transaction_date)",
type="temporal",
),
y=alt.Y(
field="transaction_amount",
type="quantitative",
aggregate="sum",
title="Total Sales",
),
color=alt.Color(
"region",
type="nominal",
title="Regions",
scale=alt.Scale(domain=regions, range=colors),
legend=alt.Legend(
direction="vertical",
symbolType="triangle-left",
tickCount=4,
),
),
)
)
.transform_filter(region_select)
.properties(width=700, title="Monthly Sales")
)
sellers_monthly_pie = (
(
alt.Chart(sales_data)
.mark_arc(innerRadius=10)
.encode(
theta=alt.Theta(
field="transaction_amount",
type="quantitative",
aggregate="sum",
title="Total Transactions",
),
color=alt.Color(
"month(transaction_date)",
type="temporal",
title="Month",
scale=alt.Scale(domain=months, range=colors),
legend=alt.Legend(
direction="vertical",
symbolType="triangle-left",
tickCount=12,
),
),
facet=alt.Facet(
field="seller",
type="nominal",
columns=8,
title="Sellers",
),
tooltip=alt.Tooltip(["sum(transaction_amount)", "month(transaction_date)"]),
)
)
.transform_filter(region_select)
.properties(width=150, height=150, title="Sellers transactions per month")
)
All of it in just one widget
To enable selections and filters, enclose them within a single Streamlit widget. To arrange charts, use the following methods:
- VConcat and HConcat methods to concatenate charts
- Pipe symbol "|" to place charts next to each other
- Ampersand symbol "&" to set charts below or above each other
- Plus sign “+” to overlay charts
To make a dashboard-like arrangement, create a variable called top_row
and use the pipe symbol to arrange your region_pie
and region_summary
charts side by side. Then, using the ampersand, place top_row
and sellers_monthly_pie
below it. This creates a variable containing all the concatenated charts in a single Streamlit altair_chart
widget.
Note that when using concatenated charts, the use_container_width
attribute won't work. Therefore, you must specify the width of the charts in their properties:
#Create first row by concatenating horizontally
top_row = region_pie | region_summary
#Create dashboard by concatenating top_row with faceted chart
full_chart = top_row & sellers_monthly_pie
#Single Streamlit Object
st.altair_chart(full_chart)
Dashboard
Full code
import streamlit as st
import altair as alt
import pandas as pd
import numpy as np
st.set_page_config(layout="wide")
regions = ["LATAM", "EMEA", "NA", "APAC"]
colors = [
"#aa423a",
"#f6b404",
"#327a88",
"#303e55",
"#c7ab84",
"#b1dbaa",
"#feeea5",
"#3e9a14",
"#6e4e92",
"#c98149",
"#d1b844",
"#8db6d8",
]
months = [
"Jan",
"Feb",
"Mar",
"Apr",
"May",
"Jun",
"Jul",
"Aug",
"Sep",
"Oct",
"Nov",
"Dec",
]
st.title("2022 Sales Dashboard")
@st.cache_data
def get_data():
dates = pd.date_range(start="1/1/2022", end="12/31/2022")
data = pd.DataFrame()
sellers = {
"LATAM": ["S01", "S02", "S03"],
"EMEA": ["S10", "S11", "S12", "S13"],
"NA": ["S21", "S22", "S23", "S24", "S25", "S26"],
"APAC": ["S31", "S32", "S33", "S34", "S35", "S36"],
}
rows = 25000
data["transaction_date"] = np.random.choice([str(i) for i in dates], size=rows)
data["region"] = np.random.choice(regions, size=rows, p=[0.1, 0.3, 0.4, 0.2])
data["transaction_amount"] = np.random.uniform(100, 250000, size=rows).round(2)
data["seller"] = data.apply(
lambda x: np.random.choice(sellers.get(x["region"])), axis=1
)
return data.sort_values(by="transaction_date").reset_index(drop=True)
sales_data = get_data()
region_select = alt.selection_single(fields=["region"], empty="all")
region_pie = (
(
alt.Chart(sales_data)
.mark_arc(innerRadius=50)
.encode(
theta=alt.Theta(
"transaction_amount",
type="quantitative",
aggregate="sum",
title="Sum of Transactions",
),
color=alt.Color(
field="region",
type="nominal",
scale=alt.Scale(domain=regions, range=colors),
title="Region",
),
opacity=alt.condition(region_select, alt.value(1), alt.value(0.25)),
)
)
.add_selection(region_select)
.properties(title="Region Sales")
)
region_summary = (
(
alt.Chart(sales_data)
.mark_bar()
.encode(
x=alt.X(
"month(transaction_date)",
type="temporal",
),
y=alt.Y(
field="transaction_amount",
type="quantitative",
aggregate="sum",
title="Total Sales",
),
color=alt.Color(
"region",
type="nominal",
title="Regions",
scale=alt.Scale(domain=regions, range=colors),
legend=alt.Legend(
direction="vertical",
symbolType="triangle-left",
tickCount=4,
),
),
)
)
.transform_filter(region_select)
.properties(width=700, title="Monthly Sales")
)
sellers_monthly_pie = (
(
alt.Chart(sales_data)
.mark_arc(innerRadius=10)
.encode(
theta=alt.Theta(
field="transaction_amount",
type="quantitative",
aggregate="sum",
title="Total Transactions",
),
color=alt.Color(
"month(transaction_date)",
type="temporal",
title="Month",
scale=alt.Scale(domain=months, range=colors),
legend=alt.Legend(
direction="vertical",
symbolType="triangle-left",
tickCount=12,
),
),
facet=alt.Facet(
field="seller",
type="nominal",
columns=8,
title="Sellers",
),
tooltip=alt.Tooltip(["sum(transaction_amount)", "month(transaction_date)"]),
)
)
.transform_filter(region_select)
.properties(width=150, height=150, title="Sellers transactions per month")
)
top_row = region_pie | region_summary
full_chart = top_row & sellers_monthly_pie
st.altair_chart(full_chart)
Wrapping up
Altair charts in Streamlit are an efficient and performant way to add interactive charts to your app. There are many styles and combinations of interactions to create using these tools.
If you loved this post, check out my other articles on client-side filtering using Altair Sliders, paginating dataframes, and the multiselect widget. And if you have any questions, please post them in the comments below or contact me on GitHub, LinkedIn, Twitter, or Medium.
Happy Streamlit-ing! 🎈