Heatmap



What is a Heatmap?

A heatmap is a graphical representation of data that encodes values as colors. Heatmaps make complex data easy to visualize and are useful for identifying patterns and areas of concentration. They show the relationship between two variables, one plotted on the x-axis and the other on the y-axis. The variables on each axis can be of any type: categorical labels or numerical values. When plotting categorical variables, it helps to order them by a meaningful criterion (for example, calendar order for months) so that the patterns in the data become visible.

We can use categorical color palettes to represent categorical data, while numerical data requires a color scale that blends from one color to another to represent low and high values.
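To make the distinction concrete, here is a minimal sketch of a heatmap of categorical data. It assumes a small hypothetical grid of 'low', 'mid' and 'high' labels (not data from this article), encodes the labels as integer codes because heatmap() plots numbers, and uses seaborn's 'deep' categorical palette so that each category gets one distinct color rather than a point on a gradient.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering for scripts; omit this line in a notebook
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.colors import ListedColormap

# Hypothetical grid of category labels, used only for illustration.
labels = pd.DataFrame([["low", "mid", "high", "mid"],
                       ["high", "low", "low", "mid"],
                       ["mid", "high", "mid", "low"]])
categories = ["low", "mid", "high"]
order = {name: code for code, name in enumerate(categories)}

# heatmap() plots numbers, so encode each label as an integer code 0, 1 or 2.
codes = labels.apply(lambda col: col.map(order))

# A categorical palette: one distinct hue per category rather than a gradient.
cmap = ListedColormap(sns.color_palette("deep", n_colors=len(categories)))

# Pad the limits by 0.5 so every integer code falls in the middle of its color band.
ax = sns.heatmap(codes, cmap=cmap, vmin=-0.5, vmax=len(categories) - 0.5)

# Label the color bar with the category names instead of the numeric codes.
colorbar = ax.collections[0].colorbar
colorbar.set_ticks(list(range(len(categories))))
colorbar.set_ticklabels(categories)
plt.show()
```

The 0.5 padding on vmin and vmax is a design choice: it places each integer code in the center of its color band, so the category names line up with the color bar ticks.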

from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns

x = np.random.rand(5,5)
sample_data = pd.DataFrame(x)
sample_data
0	1	2	3	4
0	0.104282	0.520084	0.586270	0.421520	0.924707
1	0.264854	0.518364	0.965355	0.545920	0.975186
2	0.546176	0.509287	0.142528	0.992683	0.136078
3	0.349071	0.260509	0.267000	0.135422	0.025028
4	0.845259	0.869380	0.178070	0.917876	0.344369

Consider the DataFrame sample_data, which holds data in tabular form: the table above has 5 rows and 5 columns of random values. To visualize it as a heatmap, each value in the DataFrame is mapped to a color. Seaborn provides the heatmap() function for this purpose.

sns.heatmap(data=sample_data)
plt.xlabel('column names-->')
plt.ylabel('index values-->')
plt.title('Heatmap with random values')
plt.show()


The Pandas DataFrame is passed to the heatmap() function and the plot is displayed using the Matplotlib show() function.

The heatmap generated above represents the numerical values in sample_data with different colors. The DataFrame index values are used as y-tick labels and the column names are used as x-tick labels. This heatmap uses dark colors to display low values and light colors to display high values.

A color bar to the right of the figure shows how colors correspond to values. In this figure it has ticks at 0.2, 0.4, 0.6 and 0.8. The tick positions depend on the minimum and maximum values in the input data, so they will differ for other random draws. The colormap used here runs from dark colors at one extreme to light colors at the other.
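The mapping from values to colors can be thought of as two steps: each value is rescaled linearly into the range 0 to 1 using the color limits, and the result is looked up in a colormap. The sketch below reproduces that idea with Matplotlib's Normalize class. It is a simplified illustration rather than seaborn's internal code; the 2x2 array stands in for sample_data, and 'magma' is a stand-in dark-to-light Matplotlib colormap, not seaborn's own default.

```python
import numpy as np
import matplotlib
from matplotlib.colors import Normalize

# A small 2x2 array standing in for sample_data.
data = np.array([[0.1, 0.5],
                 [0.9, 0.3]])

# Without vmin/vmax, the color limits come from the data's minimum and maximum.
norm = Normalize(vmin=data.min(), vmax=data.max())
scaled = norm(data)  # linear rescaling: minimum -> 0.0, maximum -> 1.0

# Look up each rescaled value in a dark-to-light colormap (requires Matplotlib >= 3.5).
cmap = matplotlib.colormaps["magma"]
rgba = cmap(scaled)  # one RGBA color per cell, shape (2, 2, 4)
print(np.round(np.asarray(scaled), 2))
```

Passing vmin and vmax to heatmap() changes only the first step: the limits used for rescaling are fixed instead of being taken from the data.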

Let us look at another example, using the ‘flights’ example dataset loaded with seaborn's load_dataset() function.

flights_dataset = sns.load_dataset('flights')
flights_dataset.head(15)
year	month	passengers
0	1949	January	112
1	1949	February	118
2	1949	March	132
3	1949	April	129
4	1949	May	121
5	1949	June	135
6	1949	July	148
7	1949	August	148
8	1949	September	136
9	1949	October	119
10	1949	November	104
11	1949	December	118
12	1950	January	115
13	1950	February	126
14	1950	March	141

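The data in the ‘flights_dataset’ DataFrame is in long form, with one row per month and year. We reorganize it into wide form with the Pandas pivot_table() function, using months as the index, years as the columns and passenger counts as the cell values. Because each month-year pair occurs only once, the default mean aggregation of pivot_table() leaves the counts unchanged.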
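A minimal sketch of the long-to-wide step on four rows copied from the table above shows what pivot_table() does. It also shows that, because no (month, year) pair repeats, DataFrame.pivot(), which reshapes without aggregating, gives the same table. The four-row excerpt is an assumption made so the sketch runs without downloading the dataset.

```python
import pandas as pd

# The first two months of 1949 and 1950, copied from the long-form table above.
long_df = pd.DataFrame({
    "year": [1949, 1949, 1950, 1950],
    "month": ["January", "February", "January", "February"],
    "passengers": [112, 118, 115, 126],
})

# One row per month, one column per year; each cell holds a passenger count.
wide = pd.pivot_table(data=long_df, index="month", columns="year", values="passengers")

# No (month, year) pair repeats, so pivot() (no aggregation) yields the same values.
same = long_df.pivot(index="month", columns="year", values="passengers")
print(wide)
```

Plain string months sort alphabetically, so February comes before January here. The pivoted flights table in this article lists the months in calendar order, most likely because seaborn loads the month column as a categorical column whose categories are in calendar order; keep this in mind when building such a table from your own strings.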

flights= pd.pivot_table(data=flights_dataset,index='month',columns='year',values='passengers')
flights
year	1949	1950	1951	1952	1953	1954	1955	1956	1957	1958	1959	1960
month												
January	112	115	145	171	196	204	242	284	315	340	360	417
February	118	126	150	180	196	188	233	277	301	318	342	391
March	132	141	178	193	236	235	267	317	356	362	406	419
April	129	135	163	181	235	227	269	313	348	348	396	461
May	121	125	172	183	229	234	270	318	355	363	420	472
June	135	149	178	218	243	264	315	374	422	435	472	535
July	148	170	199	230	264	302	364	413	465	491	548	622
August	148	170	199	242	272	293	347	405	467	505	559	606
September	136	158	184	209	237	259	312	355	404	404	463	508
October	119	133	162	191	211	229	274	306	347	359	407	461
November	104	114	146	172	180	203	237	271	305	310	362	390
December	118	140	166	194	201	229	278	306	336	337	405	432

The data is now in a format suited to our analysis, so we can plot a heatmap.

sns.heatmap(data=flights)
plt.title('flights data')
plt.xlabel('year')
plt.ylabel('month')
plt.show()

The heatmap above maps the numerical values in the ‘flights’ DataFrame to different colors. The index values (‘month’) of the ‘flights’ DataFrame are used as y-tick labels and the column names (‘year’) as x-tick labels. The heatmap uses dark colors for low values and light colors for high values.

In this example, the heatmap is generated using the default parameter values. Let us customize the appearance of the heatmap by changing the default settings.

Customize your heatmap

Example 1

sns.heatmap(data=flights, vmin=100, vmax=630, annot=True, fmt='d', linewidths=0.3, cbar=True)
plt.title('flights data')
plt.xlabel('year')
plt.ylabel('month')
plt.show()


In the above heatmap, we have set the lower and upper bounds for the color bar, displayed numerical values in each cell and added borders to the cells. These parameters are defined below:

Parameters:

  • vmin, vmax : values that set the color bar range; vmin is the lower bound and vmax the upper bound for color scaling. If they are not specified, the limits are inferred from the minimum and maximum values in the input data.
  • annot : if True, the data value of each cell is written in the cell.
  • fmt : string formatting code used for the annotations. 'd' works here because the passenger counts are integers; the default is '.2g'.
  • linewidths : width of the lines that separate the cells of the heatmap.
  • cbar : whether to draw the color bar; the default is True.
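The effect of these parameters can be checked on the Axes object that heatmap() returns. The sketch below is an illustration under two assumptions: it uses a 3x3 excerpt of the flights table copied from above instead of the full dataset, and it reads the color limits and annotation texts back from the plot rather than inspecting the image. linecolor is a further heatmap() parameter, shown here at its default of white.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering for scripts; omit this line in a notebook
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# January to March of 1949-1951, copied from the pivoted flights table above.
excerpt = pd.DataFrame(
    {1949: [112, 118, 132], 1950: [115, 126, 141], 1951: [145, 150, 178]},
    index=pd.Index(["January", "February", "March"], name="month"),
)
excerpt.columns.name = "year"

ax = sns.heatmap(data=excerpt, vmin=100, vmax=630, annot=True, fmt="d",
                 linewidths=0.3, linecolor="white", cbar=True)

mesh = ax.collections[0]                         # the grid of colored cells
color_limits = mesh.get_clim()                   # the (vmin, vmax) used for scaling
annotations = [t.get_text() for t in ax.texts]   # cell labels, row by row
plt.show()
```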

Example 2

sns.heatmap(data=flights,center=flights.loc['July',1957])
plt.title('flights data')
plt.xlabel('year')
plt.ylabel('month')
plt.show()


The above heatmap is plotted by setting the parameter ‘center’ to a numerical value, which is used as the center of the colormap. When center is set and no cmap is given, seaborn uses a diverging colormap by default.

In this example, the value in the cell (July, 1957) is the center of the colormap. Note that this value becomes the midpoint of the color scale, not of the data. Cells with values below the center are drawn with a color scheme that gradually goes from blue to light green, and cells with values above the center, which occur only in the years 1957 to 1960, are drawn with a color scheme that gradually goes from black to red. In other words, a diverging color scheme is applied to the heatmap.
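To see why the colors split this way, it helps to compute which part of the diverging colormap the data actually use. Roughly speaking, seaborn chooses the smallest range symmetric about center that covers the data, maps that range onto the full colormap, and keeps only the slice between the data's minimum and maximum. The sketch below reproduces that arithmetic in plain Python; colormap_span is a hypothetical helper written for illustration, not part of seaborn.

```python
def colormap_span(vmin, vmax, center):
    """Hypothetical helper: return the (low, high) positions, on a 0-1 scale, of
    the diverging colormap slice used by data in [vmin, vmax] when `center`
    sits at 0.5. Assumes vmin < vmax."""
    # Half-width of the smallest range symmetric about center that covers vmin..vmax.
    half = max(vmax - center, center - vmin)
    bottom = center - half
    # Rescale bottom..(center + half) linearly onto 0..1.
    low = (vmin - bottom) / (2 * half)
    high = (vmax - bottom) / (2 * half)
    return low, high

# 104 (November 1949) and 622 (July 1960) are the minimum and maximum of the flights table.
print(colormap_span(104, 622, center=465))  # center = July 1957, as in Example 2
print(colormap_span(104, 622, center=622))  # center = July 1960, as in Example 3
```

With the center at July 1957, the data use the colormap from 0 to about 0.72, so values below the center get more of the color range than values above it. With the center at the table's maximum, the span is exactly 0 to 0.5: only the lower half of the diverging colormap is used, which is why Example 3 reads as a single light-to-dark gradient.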

Example 3

sns.heatmap(data=flights,center=flights.loc['July',1960])
plt.title('flights data')
plt.xlabel('year')
plt.ylabel('month')
plt.show()


In the above heatmap, the center of the colormap is set to the value in cell (July, 1960). Since 622 is the largest value in the table, every cell lies at or below the center, so only the lower half of the diverging colormap is used: low values are drawn in light colors and high values in dark colors. Moving from the left to the right of the heatmap, the colors change from light to dark shades, which clearly shows the number of passengers increasing each year. The number of passengers is also higher in July and August than in the other months of any given year, as the cells for July and August are darker than the cells above and below them.
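The claim about July and August can also be checked without relying on shades: idxmax() returns, for each column, the index label of its largest value. The sketch uses only the 1958 to 1960 columns, copied from the pivoted table above, so it runs without downloading the dataset; on the full table, flights.idxmax() does the same for every year.

```python
import pandas as pd

months = ["January", "February", "March", "April", "May", "June", "July",
          "August", "September", "October", "November", "December"]

# Passenger counts for 1958-1960, copied from the pivoted flights table above.
flights_excerpt = pd.DataFrame({
    1958: [340, 318, 362, 348, 363, 435, 491, 505, 404, 359, 310, 337],
    1959: [360, 342, 406, 396, 420, 472, 548, 559, 463, 407, 362, 405],
    1960: [417, 391, 419, 461, 472, 535, 622, 606, 508, 461, 390, 432],
}, index=pd.Index(months, name="month"))

# For each year (column), the month label with the most passengers.
busiest = flights_excerpt.idxmax()
print(busiest)
```

In case of a tie, idxmax() reports the first month, which matters for 1949 to 1951, where July and August have equal counts.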

Example 4

sns.heatmap(data=sample_data,square=True,xticklabels=2, yticklabels=2,cbar_kws={'orientation':'horizontal'})
plt.show()


By passing True to the parameter ‘square’ you can set the shape of the cells to a square.

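By default, the heatmap takes the x-tick and y-tick labels from the column names and the index of the DataFrame. This can be changed with the parameters xticklabels and yticklabels, which accept “auto”, a bool, a list or an int. True uses the column names or index, False draws no labels, a list supplies alternative labels, an int labels every n-th column or row, and “auto” tries to label as densely as possible without overlapping labels. Here, xticklabels=2 and yticklabels=2 label every second column and row.

You can specify whether the color bar is drawn horizontally or vertically through the color bar keyword arguments (cbar_kws), which seaborn passes on to Matplotlib's colorbar.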
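The two kinds of tick-label arguments can be compared by reading the labels back from the returned Axes. This sketch assumes a fresh random 5x5 DataFrame in the shape of sample_data; the replacement y labels 'a' to 'e' are arbitrary illustration values.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering for scripts; omit this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Random 5x5 data in the same shape as sample_data.
sample_data = pd.DataFrame(np.random.rand(5, 5))

# An int n labels every n-th column; a list replaces the row labels entirely.
ax = sns.heatmap(sample_data, square=True, xticklabels=2,
                 yticklabels=["a", "b", "c", "d", "e"],
                 cbar_kws={"orientation": "horizontal"})

x_labels = [t.get_text() for t in ax.get_xticklabels()]
y_labels = [t.get_text() for t in ax.get_yticklabels()]
print(x_labels, y_labels)
plt.show()
```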

Example 5 – Masking

Say you want to display only the cells above or below the diagonal of the heatmap. This can be achieved with masking. Let us first display the cells above the diagonal, using the DataFrame sample_data as input.

np.tril() is a NumPy function that returns the lower triangle of the array passed to it, while np.triu() returns the upper triangle. Let us pass the DataFrame sample_data to np.tril().

lower_triangle = np.tril(sample_data)
lower_triangle
array([[0.10428215, 0.        , 0.        , 0.        , 0.        ],
       [0.26485421, 0.51836401, 0.        , 0.        , 0.        ],
       [0.54617601, 0.50928675, 0.14252839, 0.        , 0.        ],
       [0.34907137, 0.26050913, 0.26700028, 0.13542169, 0.        ],
       [0.84525918, 0.86937953, 0.17806983, 0.91787589, 0.34436948]])

lower_triangle is a new array holding the lower triangle of sample_data, including the diagonal, with all other elements set to zero.

sns.heatmap(data=sample_data,mask=lower_triangle)
plt.show()


The parameter ‘mask’ accepts an array or a DataFrame. Above, we pass sample_data as the data and lower_triangle as the mask. Cells where the mask is True (here, any nonzero value) are hidden, while cells where it is False (a value of 0) are plotted. Since lower_triangle is nonzero on and below the diagonal, only the cells above the diagonal are shown.
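Using the data values themselves as the mask works here because random values are practically never exactly 0, but a lower-triangle cell that held 0 would be treated as False and shown. A common alternative, sketched below on a fresh random 5x5 DataFrame, builds a boolean mask from the shape alone with np.ones, so it does not depend on the data values.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering for scripts; omit this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

sample_data = pd.DataFrame(np.random.rand(5, 5))

# True on and below the diagonal: those cells are hidden, so only the
# cells strictly above the diagonal are drawn.
hide_lower = np.tril(np.ones(sample_data.shape, dtype=bool))
ax = sns.heatmap(data=sample_data, mask=hide_lower)
plt.show()

# With k=1, np.triu marks only the cells strictly above the diagonal; used as
# a mask, it keeps the diagonal and the lower triangle visible.
hide_strict_upper = np.triu(np.ones(sample_data.shape, dtype=bool), k=1)
```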

Below is another example, which displays the cells below the diagonal.

upper_triangle = np.triu(sample_data)
upper_triangle
array([[0.10428215, 0.52008394, 0.58626989, 0.42152029, 0.92470701],
       [0.        , 0.51836401, 0.96535456, 0.54591957, 0.97518597],
       [0.        , 0.        , 0.14252839, 0.99268308, 0.13607794],
       [0.        , 0.        , 0.        , 0.13542169, 0.02502764],
       [0.        , 0.        , 0.        , 0.        , 0.34436948]])
sns.heatmap(data=sample_data,mask=upper_triangle)
plt.show()

Read and plot data from a csv file

The csv file contains the number of road accidents in 2017 in each Indian state and union territory, broken down by time of occurrence. The file used in this example is available at this URL: https://data.gov.in/resources/stateut-wise-road-accidents-time-occurance-during-2017

Let us read the data in the csv file into a DataFrame using the Pandas read_csv() function.

input_file = pd.read_csv(r'C:\Users\Ajay Tech\Documents\training\visualization\Data\Road_Accidents_2017.csv')

input_file.tail()
States/UTs	06-09hrs (Day)	09-12hrs (Day)	12-15hrs (Day)	15-18hrs (Day)	18-21hrs (Night)	21-24hrs (Night)	00-03hrs (Night)	03-06hrs (Night)	Total Accidents
32	Daman & Diu	5	7	15	16	17	13	6	0	79
33	Delhi	747	858	828	807	1008	1159	714	552	6673
34	Lakshadweep	0	0	0	1	0	0	0	0	1
35	Puducherry	250	256	250	257	257	216	118	89	1693
36	Total	51551	71426	71594	82456	85686	49567	25050	27580	464910

The last row (‘Total’) holds the column totals, that is, the sum over all states for each time interval, and the last column (‘Total Accidents’) holds the total for each state. These totals are not required for our analysis, so let us delete the last row and column.

input_file.drop(36, axis=0,inplace=True)
input_file.drop('Total Accidents', axis=1,inplace=True)
input_file.tail()

States/UTs	06-09hrs (Day)	09-12hrs (Day)	12-15hrs (Day)	15-18hrs (Day)	18-21hrs (Night)	21-24hrs (Night)	00-03hrs (Night)	03-06hrs (Night)
31	D & N Haveli	7	6	12	14	14	5	5	4
32	Daman & Diu	5	7	15	16	17	13	6	0
33	Delhi	747	858	828	807	1008	1159	714	552
34	Lakshadweep	0	0	0	1	0	0	0	0
35	Puducherry	250	256	250	257	257	216	118	89
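Dropping row 36 relies on the summary row sitting at that position. A label-based variant, sketched below, selects the rows whose state is not ‘Total’ instead. It assumes a few rows and columns copied from the tail() output above, inlined with io.StringIO so the sketch runs without the csv file.

```python
import io
import pandas as pd

# A few rows and columns from the tail() output above, inlined so the sketch
# runs without the csv file.
csv_text = """States/UTs,06-09hrs (Day),21-24hrs (Night),Total Accidents
Delhi,747,1159,6673
Puducherry,250,216,1693
Total,51551,49567,464910
"""
input_file = pd.read_csv(io.StringIO(csv_text))

# Remove the summary row by its label rather than its position (36 in the full file),
# and the summary column by name; input_file itself is left unchanged.
cleaned = input_file[input_file["States/UTs"] != "Total"].drop(columns="Total Accidents")
print(cleaned)
```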

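Next, we reorganize the data with the pivot_table() function, using the ‘States/UTs’ column as the index. Each state appears in only one row, so the default mean aggregation leaves the values unchanged; the effect is to make the state names the index and to sort both the rows and the columns by label, which is why the time intervals now start with 00-03hrs.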

input_data = input_file.pivot_table(index='States/UTs')
input_data.head()
00-03hrs (Night)	03-06hrs (Night)	06-09hrs (Day)	09-12hrs (Day)	12-15hrs (Day)	15-18hrs (Day)	18-21hrs (Night)	21-24hrs (Night)
States/UTs								
A & N Islands	3	7	20	27	33	41	39	19
Andhra Pradesh	1406	1648	2808	3581	3765	4484	5265	2770
Arunachal Pradesh	42	39	40	24	25	24	23	24
Assam	393	391	749	1206	1262	1340	1121	708
Bihar	495	755	1344	1550	1292	1462	1312	645
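The sketch below illustrates that reading of pivot_table() on three rows and three columns copied from the tail() output above, deliberately listed out of alphabetical order. It compares the result with set_index() followed by sorting, an alternative that avoids aggregation altogether.

```python
import pandas as pd

# Three rows and three columns from the tail() output above, in the file's
# column order and deliberately not sorted by state.
input_file = pd.DataFrame({
    "States/UTs": ["Puducherry", "Delhi", "Lakshadweep"],
    "06-09hrs (Day)": [250, 747, 0],
    "21-24hrs (Night)": [216, 1159, 0],
    "00-03hrs (Night)": [118, 714, 0],
})

# One row per state; with a single row per state, the default mean leaves values unchanged.
input_data = input_file.pivot_table(index="States/UTs")

# Equivalent without aggregation: set the index, then sort rows and columns by label.
alternative = input_file.set_index("States/UTs").sort_index().sort_index(axis=1)
print(input_data)
```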

Let us generate a heatmap for the first 20 rows in the data.

sns.heatmap(data = input_data.head(20),cmap='YlOrRd',linewidths=0.3,xticklabels=['00-03','03-06','06-09','09-12','12-15','15-18','18-21','21-24'])
plt.xticks(rotation=20)
plt.xlabel('Time of occurrence (hrs)')
plt.title('No. of road accidents/time interval of day')
plt.show()

The heatmap uses light colors for low values and dark colors for high values. For a few states, the number of accidents is higher during the interval 9 AM to 9 PM; for example, for the last state in the heatmap (Madhya Pradesh), the cells in this interval are drawn in dark colors. For most states, the largest number of road accidents occurs in the 6-9 PM interval, as can be inferred from the color of the cells in that column.
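Reading the peak interval from color shades is approximate; idxmax(axis=1) reports each row's peak interval directly. On the full table you would call input_data.head(20).idxmax(axis=1). The sketch below uses only the five rows printed by head() above, so its counts say nothing about the dataset as a whole.

```python
import pandas as pd

columns = ["00-03hrs (Night)", "03-06hrs (Night)", "06-09hrs (Day)", "09-12hrs (Day)",
           "12-15hrs (Day)", "15-18hrs (Day)", "18-21hrs (Night)", "21-24hrs (Night)"]

# The first five rows of input_data, copied from the head() output above.
excerpt = pd.DataFrame(
    [[3, 7, 20, 27, 33, 41, 39, 19],
     [1406, 1648, 2808, 3581, 3765, 4484, 5265, 2770],
     [42, 39, 40, 24, 25, 24, 23, 24],
     [393, 391, 749, 1206, 1262, 1340, 1121, 708],
     [495, 755, 1344, 1550, 1292, 1462, 1312, 645]],
    index=pd.Index(["A & N Islands", "Andhra Pradesh", "Arunachal Pradesh",
                    "Assam", "Bihar"], name="States/UTs"),
    columns=columns,
)

# For each state, the time interval with the most accidents.
peak_interval = excerpt.idxmax(axis=1)
print(peak_interval)

# How many of these states peak in each interval.
print(peak_interval.value_counts())
```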
