Heatmap
Table of Contents
- What is a Heatmap?
- Customize your heatmap
- Example 1
- Example 2
- Example 3
- Example 4
- Example 5 – Masking
- Read and plot data from a csv file
What is a Heatmap?
A heatmap is a graphical representation of data that encodes data values as colors. Heatmaps make it easy to visualize complex data and are useful for identifying patterns and areas of concentration. They show the relationship between two variables, plotted on the x and y axes. The variable on each axis can hold either categorical labels or numerical values. When plotting categorical variables, it helps to order the categories by a meaningful criterion (for example chronologically or by magnitude) so that patterns in the data become visible.
Categorical data can be represented with a categorical color palette, while numerical data requires a color scale that blends from one color to another to represent high and low values.
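As a small illustration of the two kinds of color scale, the sketch below builds a qualitative palette and a sequential colormap with seaborn's color_palette(). The palette name 'Set2' is only an example; 'rocket' is the sequential colormap that heatmap() uses by default when no center is given. The as_cmap=True argument assumes seaborn 0.11 or newer.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend, so no window is opened
from matplotlib.colors import Colormap
import seaborn as sns

# Qualitative palette: a short list of distinct RGB colors, one per category.
categorical_palette = sns.color_palette("Set2", 3)

# Sequential colormap: a continuous blend from dark to light for numerical values.
sequential_cmap = sns.color_palette("rocket", as_cmap=True)

print(len(categorical_palette))
print(sequential_cmap(0.0))  # RGBA color used for the lowest values
print(sequential_cmap(1.0))  # RGBA color used for the highest values
```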
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
x = np.random.rand(5,5)
sample_data = pd.DataFrame(x)
sample_data
0 1 2 3 4
0 0.104282 0.520084 0.586270 0.421520 0.924707
1 0.264854 0.518364 0.965355 0.545920 0.975186
2 0.546176 0.509287 0.142528 0.992683 0.136078
3 0.349071 0.260509 0.267000 0.135422 0.025028
4 0.845259 0.869380 0.178070 0.917876 0.344369
Consider the DataFrame sample_data, which holds data in tabular form: 5 rows and 5 columns of random values between 0 and 1. To visualize this data as a heatmap, each value in the DataFrame is mapped to a color. Seaborn provides the heatmap() function for this purpose.
sns.heatmap(data=sample_data)
plt.xlabel('column names-->')
plt.ylabel('index values-->')
plt.title('Heatmap with random values')
plt.show()

The Pandas DataFrame is passed to the heatmap() function and the plot is displayed using the Matplotlib show() function.
The heatmap above represents the numerical values in sample_data with different colors. The DataFrame index values are used as y-tick labels and the column names as x-tick labels. By default, the heatmap uses dark colors for low values and light colors for high values.
A color bar, which shows how colors map to values, is displayed on the right-hand side of the figure. In the figure above, its ticks are at 0.2, 0.4, 0.6 and 0.8. The range of the color bar is inferred from the minimum and maximum values in the input data, and the tick positions are chosen within that range; because the data is random, they can change from run to run. The colormap has dark colors at one end and light colors at the other.
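To check where the color range comes from, the sketch below draws the same kind of heatmap off-screen and reads the limits back from the Axes object that heatmap() returns. It assumes that the colored cells are the first collection on the Axes (ax.collections[0]) and uses a seeded random generator, so its values differ from the table above.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend, so no window is opened
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(0)  # seeded, so the sketch gives the same data each run
sample_data = pd.DataFrame(rng.random((5, 5)))

ax = sns.heatmap(data=sample_data)  # heatmap() returns the Matplotlib Axes
mesh = ax.collections[0]            # the QuadMesh that holds the colored cells

# With no vmin/vmax given, the color limits are the data minimum and maximum.
vmin, vmax = mesh.get_clim()
print(vmin, sample_data.values.min())
print(vmax, sample_data.values.max())

# Matplotlib attaches the color bar to the mesh; these are its tick positions.
ticks = mesh.colorbar.get_ticks()
print(ticks)

plt.close("all")
```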
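Let us look at another example, using the ‘flights’ example dataset. Seaborn's load_dataset() function downloads it from the online seaborn-data repository, so an internet connection is needed the first time.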
flights_dataset = sns.load_dataset('flights')
flights_dataset.head(15)
year month passengers
0 1949 January 112
1 1949 February 118
2 1949 March 132
3 1949 April 129
4 1949 May 121
5 1949 June 135
6 1949 July 148
7 1949 August 148
8 1949 September 136
9 1949 October 119
10 1949 November 104
11 1949 December 118
12 1950 January 115
13 1950 February 126
14 1950 March 141
The data in the ‘flights_dataset’ DataFrame is in long form, with one row per observation. We will reorganize it into wide form, with one row per month and one column per year, using the pandas pivot_table() function.
flights = pd.pivot_table(data=flights_dataset, index='month', columns='year', values='passengers')
flights
year 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960
month
January 112 115 145 171 196 204 242 284 315 340 360 417
February 118 126 150 180 196 188 233 277 301 318 342 391
March 132 141 178 193 236 235 267 317 356 362 406 419
April 129 135 163 181 235 227 269 313 348 348 396 461
May 121 125 172 183 229 234 270 318 355 363 420 472
June 135 149 178 218 243 264 315 374 422 435 472 535
July 148 170 199 230 264 302 364 413 465 491 548 622
August 148 170 199 242 272 293 347 405 467 505 559 606
September 136 158 184 209 237 259 312 355 404 404 463 508
October 119 133 162 191 211 229 274 306 347 359 407 461
November 104 114 146 172 180 203 237 271 305 310 362 390
December 118 140 166 194 201 229 278 306 336 337 405 432
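The reshaping step can be tried on a tiny long-form table. The sketch below uses the first passenger counts printed above. Note that pivot_table() sorts the index, so plain month strings come out in alphabetical order rather than calendar order.

```python
import pandas as pd

# A tiny long-form table: one row per (year, month) observation,
# using the first passenger counts printed above.
long_form = pd.DataFrame({
    "year": [1949, 1949, 1950, 1950],
    "month": ["January", "February", "January", "February"],
    "passengers": [112, 118, 115, 126],
})

# Months become the index, years become the columns.
# Each (month, year) pair occurs once, so the default mean keeps the value as is.
wide_form = pd.pivot_table(data=long_form, index="month", columns="year",
                           values="passengers")
print(wide_form)
```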
The flights DataFrame is now in wide form, with months as rows and years as columns, which is the layout heatmap() expects. Let us plot it.
sns.heatmap(data=flights)
plt.title('flights data')
plt.xlabel('year')
plt.ylabel('month')
plt.show()

The heatmap above maps the numerical values in the ‘flights’ DataFrame to different colors. The index values (‘month’) of the ‘flights’ DataFrame are used as y-tick labels and the column names (‘year’) as x-tick labels. The heatmap uses dark colors for low values and light colors for high values.
In this example, the heatmap is generated using the default parameter values. Let us customize the appearance of the heatmap by changing the default settings.
Customize your heatmap
Example 1
sns.heatmap(data=flights, vmin=100, vmax=630, annot=True, fmt='d', linewidths=0.3, cbar=True)
plt.title('flights data')
plt.xlabel('year')
plt.ylabel('month')
plt.show()

In the heatmap above, we have set the lower and upper bounds of the color bar, displayed the numerical value in each cell and added borders between the cells. These parameters are described below:
Parameters:
- vmin, vmax : the range of the color bar; vmin is the lower bound and vmax the upper bound for color scaling. If no values are specified, the limits are inferred from the minimum and maximum values in the input data.
- annot : if True, the data value of each cell is written inside the cell.
- fmt : the string formatting code used for the annotations (default '.2g'). 'd' formats integers and fails on floating-point values; use a code such as '.1f' for those.
- linewidths : the width of the lines that separate the cells of the heatmap.
- cbar : whether to draw the color bar; the default is True.
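The sketch below applies the same parameters to a small 2x3 slice of the flights table and reads back what they did: the color limits come from vmin and vmax, and annot=True adds one text label per cell, formatted with fmt='d'. Reading the labels from ax.texts relies on the annotations being ordinary Matplotlib text objects on the Axes; it is an inspection trick, not part of the heatmap() API.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend, so no window is opened
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# A 2x3 slice of the wide-form flights table (integer passenger counts).
flights_slice = pd.DataFrame(
    [[112, 115, 145],
     [118, 126, 150]],
    index=["January", "February"],
    columns=[1949, 1950, 1951],
)

ax = sns.heatmap(data=flights_slice, vmin=100, vmax=630,
                 annot=True, fmt="d", linewidths=0.3, cbar=True)
mesh = ax.collections[0]

# The color limits come from vmin and vmax, not from the data.
color_limits = mesh.get_clim()
print(color_limits)

# annot=True adds one text label per cell, formatted with fmt="d".
labels = [text.get_text() for text in ax.texts]
print(labels)

plt.close("all")
```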
Example 2
sns.heatmap(data=flights,center=flights.loc['July',1957])
plt.title('flights data')
plt.xlabel('year')
plt.ylabel('month')
plt.show()

The above heatmap is plotted by setting the parameter ‘center’ to a numerical value and this value is used as the center of the colormap when plotting data.
In this example, the value in the cell (July, 1957) is used as the center of the colormap. When center is set and no cmap is given, seaborn switches to a diverging colormap: values below the center are drawn in shades of blue, values above it in shades of red, and the colors get darker as the values approach the center. Because the number of passengers grows from year to year, the blue cells lie mostly in the left part of the heatmap and the red cells in the right part. This is a diverging color scheme.
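How much of the diverging colormap is used depends on where the center sits in the data range. The function below is a sketch that mirrors the recentering arithmetic in seaborn's heatmap source (an internal detail that may change between versions): the colormap is made symmetric around center, and only the slice covering vmin to vmax is kept. colormap_fraction is a hypothetical helper; 104 and 622 are the smallest and largest values in the flights table, and 465 and 622 are the (July, 1957) and (July, 1960) values used in this and the next example.

```python
def colormap_fraction(vmin, vmax, center):
    """Return the (low, high) slice of a diverging colormap used for [vmin, vmax].

    The colormap is treated as symmetric around `center`: it spans
    center - vrange .. center + vrange, where vrange is the larger distance
    from the center to either data extreme.
    """
    vrange = max(vmax - center, center - vmin)
    lo_edge, hi_edge = center - vrange, center + vrange
    low = (vmin - lo_edge) / (hi_edge - lo_edge)
    high = (vmax - lo_edge) / (hi_edge - lo_edge)
    return low, high


print(colormap_fraction(104, 622, 465))  # roughly (0.0, 0.72): mostly the lower part
print(colormap_fraction(104, 622, 622))  # (0.0, 0.5): only the lower half
```

When the center equals the maximum, the upper half of the diverging colormap is never reached, which is why the next heatmap shows only one color family.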
Example 3
sns.heatmap(data=flights,center=flights.loc['July',1960])
plt.title('flights data')
plt.xlabel('year')
plt.ylabel('month')
plt.show()

In the heatmap above, the center of the colormap is set to the value in cell (July, 1960). This is the largest value in the table (622), so only the part of the diverging colormap below the center is used: low values are drawn in light colors and high values in dark colors. Moving from the left to the right of the heatmap, the colors change from light to dark, which shows that the number of passengers increases year after year. The heatmap also shows that, in every year, July and August have more passengers than the other months: the cells for July and August are darker than the cells above and below them.
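The reading of the heatmap can be checked numerically. The sketch below rebuilds the wide-form flights table from the values printed earlier, ranks the months by their average passenger count, checks that July and August are the two busiest months in every year, and checks that the yearly totals grow every year.

```python
import pandas as pd

# Passenger counts copied from the wide-form flights table shown earlier.
years = list(range(1949, 1961))
rows = {
    "January":   [112, 115, 145, 171, 196, 204, 242, 284, 315, 340, 360, 417],
    "February":  [118, 126, 150, 180, 196, 188, 233, 277, 301, 318, 342, 391],
    "March":     [132, 141, 178, 193, 236, 235, 267, 317, 356, 362, 406, 419],
    "April":     [129, 135, 163, 181, 235, 227, 269, 313, 348, 348, 396, 461],
    "May":       [121, 125, 172, 183, 229, 234, 270, 318, 355, 363, 420, 472],
    "June":      [135, 149, 178, 218, 243, 264, 315, 374, 422, 435, 472, 535],
    "July":      [148, 170, 199, 230, 264, 302, 364, 413, 465, 491, 548, 622],
    "August":    [148, 170, 199, 242, 272, 293, 347, 405, 467, 505, 559, 606],
    "September": [136, 158, 184, 209, 237, 259, 312, 355, 404, 404, 463, 508],
    "October":   [119, 133, 162, 191, 211, 229, 274, 306, 347, 359, 407, 461],
    "November":  [104, 114, 146, 172, 180, 203, 237, 271, 305, 310, 362, 390],
    "December":  [118, 140, 166, 194, 201, 229, 278, 306, 336, 337, 405, 432],
}
flights = pd.DataFrame.from_dict(rows, orient="index", columns=years)

# Average passengers per month across all years, highest first.
monthly_mean = flights.mean(axis=1).sort_values(ascending=False)
print(monthly_mean.head(3))

# In every year, are July and August the two busiest months?
top_two_each_year = all(
    set(flights[year].nlargest(2).index) == {"July", "August"} for year in years
)
print(top_two_each_year)

# Total passengers per year; True if every year beats the previous one.
yearly_total = flights.sum(axis=0)
print(yearly_total.is_monotonic_increasing)
```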
Example 4
sns.heatmap(data=sample_data, square=True, xticklabels=2, yticklabels=2, cbar_kws={'orientation': 'horizontal'})
plt.show()

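Passing True to the parameter ‘square’ makes every cell square.
By default, the heatmap takes the x-tick labels from the column names and the y-tick labels from the index of the DataFrame. This can be changed with the parameters xticklabels and yticklabels, which accept “auto”, a bool, a list, or an int. True plots the column names (or index values), False plots no labels, a list supplies custom labels, an int n plots only every n-th label, and “auto” tries to plot as many non-overlapping labels as fit. In the example above, xticklabels=2 and yticklabels=2 label every second column and row.
You can specify whether the color bar is displayed horizontally or vertically with the color bar keyword arguments (cbar_kws), which are passed on to Matplotlib's colorbar() function.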
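The sketch below repeats the call on seeded random data, draws the figure off-screen and inspects the result: with xticklabels=2 and yticklabels=2, only every second label is kept, and the color bar reports a horizontal orientation. Reading the color bar from ax.collections[0].colorbar relies on Matplotlib attaching it to the mesh.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend, so no window is opened
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(0)  # seeded stand-in for the random sample_data
sample_data = pd.DataFrame(rng.random((5, 5)))

ax = sns.heatmap(data=sample_data, square=True, xticklabels=2, yticklabels=2,
                 cbar_kws={"orientation": "horizontal"})
ax.figure.canvas.draw()  # render once so the tick labels are final

# xticklabels=2 / yticklabels=2 keep every second label: 0, 2 and 4.
x_labels = [label.get_text() for label in ax.get_xticklabels()]
y_labels = [label.get_text() for label in ax.get_yticklabels()]
print(x_labels, y_labels)

# The color bar created through cbar_kws is horizontal.
orientation = ax.collections[0].colorbar.orientation
print(orientation)

plt.close("all")
```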
Example 5 – Masking
Suppose you want to display only the cells above or below the diagonal of the heatmap. This can be achieved with masking. Let us first display the cells above the diagonal, using the input DataFrame sample_data.
np.tril() is a NumPy function that returns the lower triangle of the array passed to it, with the elements above the main diagonal set to zero, while np.triu() returns the upper triangle. Let us pass the DataFrame sample_data to np.tril().
lower_triangle = np.tril(sample_data)
lower_triangle
array([[0.10428215, 0. , 0. , 0. , 0. ],
[0.26485421, 0.51836401, 0. , 0. , 0. ],
[0.54617601, 0.50928675, 0.14252839, 0. , 0. ],
[0.34907137, 0.26050913, 0.26700028, 0.13542169, 0. ],
[0.84525918, 0.86937953, 0.17806983, 0.91787589, 0.34436948]])
lower_triangle is a new array formed by keeping the lower triangular part of sample_data, including the main diagonal, and setting all other elements to zero.
sns.heatmap(data=sample_data,mask=lower_triangle)
plt.show()

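The parameter ‘mask’ accepts an array or a DataFrame with the same shape as the data. Above, we passed sample_data as the data and lower_triangle as the mask. Cells where the mask is True (for this numeric array, any non-zero value) are hidden, and data is plotted only where the mask is False, that is, where it holds 0. Since lower_triangle is non-zero on and below the diagonal, only the cells above the diagonal are drawn.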
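Since only the truth value of each mask entry matters, a common idiom is to build a boolean mask directly with np.ones(..., dtype=bool) instead of reusing the data values. This also avoids a subtle issue: a data value of exactly 0 in the lower triangle would count as False and stay visible. The sketch below uses seeded random data and a hypothetical zero placed below the diagonal to show the difference.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend, so no window is opened
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(0)  # seeded stand-in for the random sample_data
sample_data = pd.DataFrame(rng.random((5, 5)))

# Boolean mask: True on and below the diagonal, so those cells are hidden.
mask = np.tril(np.ones(sample_data.shape, dtype=bool))

ax = sns.heatmap(data=sample_data, mask=mask, annot=True, fmt=".2f")
# Only the 10 cells strictly above the diagonal are drawn and annotated.
print(len(ax.texts))
plt.close("all")

# Hypothetical data with an exact 0 below the diagonal.
data_with_zero = sample_data.copy()
data_with_zero.iloc[3, 1] = 0.0
# Building the mask from the values treats that 0 as False (the cell stays visible).
value_mask = np.tril(data_with_zero.to_numpy()) != 0
print(mask.sum(), value_mask.sum())  # 15 hidden cells vs. only 14
```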
Below is another example, which displays the cells below the diagonal.
upper_triangle = np.triu(sample_data)
upper_triangle
array([[0.10428215, 0.52008394, 0.58626989, 0.42152029, 0.92470701],
[0. , 0.51836401, 0.96535456, 0.54591957, 0.97518597],
[0. , 0. , 0.14252839, 0.99268308, 0.13607794],
[0. , 0. , 0. , 0.13542169, 0.02502764],
[0. , 0. , 0. , 0. , 0.34436948]])
sns.heatmap(data=sample_data,mask=upper_triangle)
plt.show()
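The two NumPy functions used in this section are easy to compare on a small integer array. The optional k argument moves the diagonal: k=-1 also zeroes the main diagonal, which is useful when the diagonal cells should stay visible in a masked heatmap.

```python
import numpy as np

a = np.arange(1, 10).reshape(3, 3)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

lower = np.tril(a)               # [[1, 0, 0], [4, 5, 0], [7, 8, 9]]
upper = np.triu(a)               # [[1, 2, 3], [0, 5, 6], [0, 0, 9]]
strict_lower = np.tril(a, k=-1)  # [[0, 0, 0], [4, 0, 0], [7, 8, 0]]

print(lower)
print(upper)
print(strict_lower)
```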

Read and plot data from a csv file
The CSV file contains the total number of road accidents in different Indian states and union territories in 2017, broken down by the time of occurrence. The file used in this example is available at this URL: https://data.gov.in/resources/stateut-wise-road-accidents-time-occurance-during-2017
Let us read the data in the csv file into a DataFrame using the Pandas read_csv() function.
input_file = pd.read_csv(r'C:\Users\Ajay Tech\Documents\training\visualization\Data\Road_Accidents_2017.csv')
input_file.tail()
States/UTs 06-09hrs (Day) 09-12hrs (Day) 12-15hrs (Day) 15-18hrs (Day) 18-21hrs (Night) 21-24hrs (Night) 00-03hrs (Night) 03-06hrs (Night) Total Accidents
32 Daman & Diu 5 7 15 16 17 13 6 0 79
33 Delhi 747 858 828 807 1008 1159 714 552 6673
34 Lakshadweep 0 0 0 1 0 0 0 0 1
35 Puducherry 250 256 250 257 257 216 118 89 1693
36 Total 51551 71426 71594 82456 85686 49567 25050 27580 464910
The last row (‘Total’) of the DataFrame contains the sum of each column, and the last column (‘Total Accidents’) contains the sum of each row. These totals are not needed for our analysis, so let us delete the last row and column.
input_file.drop(36, axis=0,inplace=True)
input_file.drop('Total Accidents', axis=1,inplace=True)
input_file.tail()
States/UTs 06-09hrs (Day) 09-12hrs (Day) 12-15hrs (Day) 15-18hrs (Day) 18-21hrs (Night) 21-24hrs (Night) 00-03hrs (Night) 03-06hrs (Night)
31 D & N Haveli 7 6 12 14 14 5 5 4
32 Daman & Diu 5 7 15 16 17 13 6 0
33 Delhi 747 858 828 807 1008 1159 714 552
34 Lakshadweep 0 0 0 1 0 0 0 0
35 Puducherry 250 256 250 257 257 216 118 89
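Dropping row 36 by position only works for this exact file. A sketch of a label-based alternative is shown below, on a few rows copied from the table above (only two time-interval columns are kept for brevity).

```python
import pandas as pd

# A few rows from the tail of the accidents table above (two interval columns only).
accidents = pd.DataFrame({
    "States/UTs": ["Delhi", "Lakshadweep", "Puducherry", "Total"],
    "18-21hrs (Night)": [1008, 0, 257, 85686],
    "21-24hrs (Night)": [1159, 0, 216, 49567],
    "Total Accidents": [6673, 1, 1693, 464910],
})

# Drop the summary row by its label and the summary column by its name,
# instead of relying on the row's position in the file.
cleaned = (accidents[accidents["States/UTs"] != "Total"]
           .drop(columns="Total Accidents"))
print(cleaned)
```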
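Next, we reorganize the DataFrame with the pivot_table() function so that the state names become the index. Note that pivot_table() also sorts the column labels, which is why the time intervals now start with ‘00-03hrs (Night)’.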
input_data = input_file.pivot_table(index='States/UTs')
input_data.head()
00-03hrs (Night) 03-06hrs (Night) 06-09hrs (Day) 09-12hrs (Day) 12-15hrs (Day) 15-18hrs (Day) 18-21hrs (Night) 21-24hrs (Night)
States/UTs
A & N Islands 3 7 20 27 33 41 39 19
Andhra Pradesh 1406 1648 2808 3581 3765 4484 5265 2770
Arunachal Pradesh 42 39 40 24 25 24 23 24
Assam 393 391 749 1206 1262 1340 1121 708
Bihar 495 755 1344 1550 1292 1462 1312 645
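The column reordering is easy to see on a small example. The sketch below uses three rows and three columns from the accidents table and compares pivot_table() with set_index(), which moves the state names into the index without sorting the columns.

```python
import pandas as pd

# Three rows from the accidents table above, columns in the file's original order.
accidents = pd.DataFrame({
    "States/UTs": ["Delhi", "Lakshadweep", "Puducherry"],
    "06-09hrs (Day)": [747, 0, 250],
    "18-21hrs (Night)": [1008, 0, 257],
    "00-03hrs (Night)": [714, 0, 118],
})

# pivot_table() aggregates (mean by default) and sorts the column labels,
# so "00-03hrs (Night)" moves to the front.
via_pivot = accidents.pivot_table(index="States/UTs")
print(list(via_pivot.columns))

# set_index() only moves the state names into the index and keeps column order.
via_index = accidents.set_index("States/UTs")
print(list(via_index.columns))
```

With one row per state, the mean leaves every value unchanged, so both DataFrames hold the same numbers in a different column order.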
Let us generate a heatmap for the first 20 rows in the data.
sns.heatmap(data=input_data.head(20), cmap='YlOrRd', linewidths=0.3, xticklabels=['00-03', '03-06', '06-09', '09-12', '12-15', '15-18', '18-21', '21-24'])
plt.xticks(rotation=20)
plt.xlabel('Time of occurrence (hrs)')
plt.title('No. of road accidents/time interval of day')
plt.show()

The heatmap uses light colors for low values and dark colors for high values. A few states have more accidents between 9 AM and 9 PM; for example, for the last state in the heatmap (Madhya Pradesh), the cells in this interval are drawn in dark colors. For most states, the largest number of road accidents occurs during the 6-9 PM interval (18-21 hrs), as the color of the cells in this interval shows.
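Observations such as "the maximum occurs during 6-9 PM for most states" can also be checked directly with idxmax(). The sketch below runs it on the five rows printed earlier. These rows alone cannot settle the claim; for the full picture, the same two lines would be applied to the whole input_data DataFrame.

```python
import pandas as pd

# First five rows of the pivoted accidents table shown above.
columns = ["00-03hrs (Night)", "03-06hrs (Night)", "06-09hrs (Day)", "09-12hrs (Day)",
           "12-15hrs (Day)", "15-18hrs (Day)", "18-21hrs (Night)", "21-24hrs (Night)"]
input_data = pd.DataFrame(
    [[3, 7, 20, 27, 33, 41, 39, 19],
     [1406, 1648, 2808, 3581, 3765, 4484, 5265, 2770],
     [42, 39, 40, 24, 25, 24, 23, 24],
     [393, 391, 749, 1206, 1262, 1340, 1121, 708],
     [495, 755, 1344, 1550, 1292, 1462, 1312, 645]],
    index=["A & N Islands", "Andhra Pradesh", "Arunachal Pradesh", "Assam", "Bihar"],
    columns=columns,
)

# Interval with the most accidents in each state, and how often each interval wins.
peak_interval = input_data.idxmax(axis=1)
print(peak_interval)
print(peak_interval.value_counts())
```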