Matplotlib
Summary: Basic 2-d plotting library for Data Science and Machine Learning.
Contents
Just enough Matplotlib
Matplotlib Reference
Introduction
Matplotlib is one of the most popular and well-established plotting packages in the Python Data Science world. Let’s uncover the basics of Matplotlib.
What is Matplotlib
Say we have wage data ( age vs number of hours worked ) and we want to analyze if there is a relationship between them. This is just an extract from a larger dataset. We will see a more realistic dataset later.

import numpy as np
wage_data = np.genfromtxt("../data/age_vs_hours_small.csv",delimiter=",",skip_header=1,dtype=int)
import matplotlib.pyplot as plt
plt.plot(wage_data[:,0],wage_data[:,1])

Granted this is not all that informative, but this is a very small dataset. Let’s increase the dataset size.
wage_data = np.genfromtxt("../data/age_vs_hours_large.csv",delimiter=",",skip_header=1,dtype=int)
plt.plot(wage_data[:,0],wage_data[:,1])

plt.scatter(wage_data[:,0],wage_data[:,1])

The plot is too dense. Let’s see if we can reduce the opacity a bit.
plt.scatter(wage_data[:,0],wage_data[:,1],alpha=0.1)

Also, for the viewer to understand the data being represented in this plot, let’s put some labels.
plt.scatter(wage_data[:,0],wage_data[:,1],alpha=0.1)
plt.xlabel("age")
plt.ylabel("Number of hours")
plt.title("Age vs Number of hours worked")
Text(0.5, 1.0, 'Age vs Number of hours worked')

Don’t worry if you don’t understand the syntax of what you saw. The point we are trying to make is that matplotlib is a full-fledged 2-d plotting toolkit that lets you plot most types of data with good control over each aspect of the plotting elements, like shape, size, color, opacity, labels etc. There is also a wide variety of plots to choose from: scatter plot, line plot, histogram, time series plots and many more. Later, we will learn about the most used controls in matplotlib with examples.
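The CSV files used above are not included here, so the following is a minimal, self-contained sketch of the same scatter plot. The randomly generated ages and hours are an assumption made purely for illustration; they are not the wage dataset.

```python
import numpy as np
import matplotlib.pyplot as plt

# Synthetic stand-in for the wage data (assumption: not the real dataset)
rng = np.random.default_rng(0)
age = rng.integers(18, 80, size=2000)
hours = np.clip(rng.normal(40, 12, size=2000), 1, 99).round()

# A low alpha makes dense regions appear darker than sparse ones
points = plt.scatter(age, hours, alpha=0.1)
plt.xlabel("age")
plt.ylabel("Number of hours")
plt.title("Age vs Number of hours worked")
```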
Why learn Matplotlib
The conclusions you can draw from the picture below are something you can never draw from the raw numbers themselves.
plt.scatter(wage_data[:,0],wage_data[:,1],alpha=0.1)
plt.xlabel("age")
plt.ylabel("Number of hours")
plt.title("Age vs Number of hours worked")
Text(0.5, 1.0, 'Age vs Number of hours worked')

That is the power of visualization. Every Data Scientist and ML engineer needs visualization to understand their data. By understand we mean,
- understanding data trends
- identifying outliers
- identifying patterns etc
Once we understand the data, we have to let our audience understand the results of our analysis. For example, unless you show a picture like the one shown above, how do you convince your audience about your analysis ? Visualization is a key part of communicating your test results.
Although there are dedicated software packages ( Tableau, Power BI ) and some other open source tools ( d3.js, plot.ly ), if you are invested in the Python environment for Data Science or ML, Matplotlib is a wonderful visualization tool. The strength of Matplotlib lies in its customization. Also, Matplotlib is the basis for another high level visualization library called Seaborn.
Matplotlib Philosophy
Matplotlib is modeled functionally after MATLAB. Most of the design philosophy seems to revolve around the way MATLAB does plotting.
Plot in 2d
Although Matplotlib can do a bit of 3d, most visualization in the data science world is still 2d. However, real world data is rarely 2d. For example, if you look at the Canada Wage data above, there were 2 parameters
- Age
- Number of hours.
This is 2-dimensional data. What if you added one more parameter – sex. Now you want to analyze the number of hours worked not just by age, but by sex as well.
import pandas as pd
wage_data = pd.read_csv("../data/age_vs_hours_vs_sex.csv",delimiter=",",
header=0)
wage_data.head()
age sex hours-per-week
0 39 Male 40
1 50 Male 13
2 38 Male 40
3 53 Male 40
4 28 Female 40
How do we plot 3 variables ( dimensions ) on 2d ? There are a variety of ways to do it. One of them is to use sub-plots.
# Treat sex as a categorical column before splitting the data by category
wage_data["sex"] = wage_data["sex"].astype("str")
wage_data["sex"] = wage_data["sex"].astype("category")
wage_data_male = wage_data[wage_data["sex"] == "Male"]
wage_data_female = wage_data[wage_data["sex"] == "Female"]
f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
ax1.scatter(wage_data_female["age"],wage_data_female["hours-per-week"],alpha=0.1,c= "g")
ax2.scatter(wage_data_male["age"],wage_data_male["hours-per-week"],alpha=0.1,c= "r")
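Sub-plots are only one option. Another common way to show a third variable is to draw each category in its own color on a single plot. The sketch below assumes a small synthetic table with the same three column names, since the original CSV is not available here.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Synthetic stand-in for the age / sex / hours table (assumption)
rng = np.random.default_rng(1)
demo = pd.DataFrame({
    "age": rng.integers(18, 80, size=500),
    "sex": rng.choice(["Male", "Female"], size=500),
    "hours-per-week": rng.integers(10, 60, size=500),
})

# One scatter call per category, all drawn on the same axes
colors = {"Female": "g", "Male": "r"}
for sex, group in demo.groupby("sex"):
    plt.scatter(group["age"], group["hours-per-week"],
                alpha=0.3, c=colors[sex], label=sex)
plt.xlabel("age")
plt.ylabel("hours-per-week")
plt.legend()
```

Whether one plot or two reads better depends on how much the categories overlap; with heavy overlap, the side-by-side sub-plots above are usually easier to compare.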

State Machine
This is another feature that is straight out of the MATLAB plotting technique. The pyplot module acts like a static state machine: it keeps track of the current figure and axes, and you keep adding elements to that current plot through the “pyplot” functions. Let’s see some examples to understand this better.
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
x = np.arange(10)
y = x
plt.plot(x,y)

Now that we have the basic plot, let’s add more on top of it, like overlaying a curve in a different color on the same straight line plot.
x = np.arange(10)
y = x
z = x**2
plt.plot(x,y)
plt.plot(x,z,c="r")

The first call to pyplot ( plt.plot(x,y) ) creates the plot object and all subsequent calls add more elements to the same plot object. Since all the subsequent functions on plt are acting on the same plot object, they share the same axis, axis labels, legend, plot labels and so on. Let’s see some more examples below.
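To make the shared state visible, here is a small sketch that asks pyplot for its current figure and axes after each call, using plt.gcf() and plt.gca(). The variable names are only for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(10)

plt.plot(x, x)                 # first call implicitly creates a figure and axes
fig_first, ax_first = plt.gcf(), plt.gca()

plt.plot(x, x**2, c="r")       # second call reuses the current axes
fig_second, ax_second = plt.gcf(), plt.gca()

# Both calls drew on the same objects, so the axes now holds two lines
print(fig_first is fig_second, ax_first is ax_second)
print(len(ax_second.lines))
```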
Let’s add a legend to make things clear to the viewer.
plt.plot(x,y, label = "Straight line")
plt.plot(x,z,c="r", label = "Quadratic")
plt.legend(loc='upper left')

Let’s mark some points on the plot, maybe to draw attention to some of the important data points.
x = np.arange(10)
y = x
z = x**2
plt.plot(x,y, label = "Straight line")
plt.annotate ( "Large Difference", xy=(8,8),
xytext=(+10, +30),arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
plt.plot(x,z,c="r", label = "Quadratic")
plt.annotate ( "Large Difference", xy=(8,8**2),
xytext=(+10, +30),arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
plt.legend(loc='upper left')

We will learn more about this as we go forward.
Layers
To plot multiple elements on a 2-d surface, we tend to use different visual properties like
- Color
- Shape
- Size etc.
To make this process easy, Matplotlib allows us to add the different elements in a layered fashion – one layer on top of another. Let’s see an example.
Say, we fill the area under the curve with a color and then pile another element on top of the same plot. The order of piling matters. We of course have fine control over this using zorder.
x = np.arange(10)
y = x
plt.plot(x,y)
plt.fill_between(x,y)

plt.plot(x,z,c="r")
plt.fill_between(x,z)
plt.plot(x,y)
plt.fill_between(x,y)

plt.plot(x,y)
plt.fill_between(x,y)
plt.plot(x,z,c="r")
plt.fill_between(x,z)
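The examples above rely on call order alone. As a hedged sketch of the zorder control mentioned earlier, the snippet below passes an explicit zorder to each element, so the stacking no longer depends on which call comes first. The specific colors and zorder values are arbitrary choices for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(10)
y = x
z = x**2

# Elements with a higher zorder are drawn on top, regardless of call order
line_top = plt.plot(x, y, c="k", zorder=3)[0]
fill_z = plt.fill_between(x, z, color="r", zorder=1)
fill_y = plt.fill_between(x, y, color="b", zorder=2)
```

Here the black line is called first but still ends up on top of both filled areas.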

Our approach to Matplotlib
Unless we understand the basics of probability and statistics, we can’t understand much about what these plots and graphs mean. So, we will postpone most of the advanced topics in Matplotlib to later sections. In this chapter, we will just focus on the basics of Matplotlib.
Plot
The generic module for plotting is matplotlib.pyplot, or pyplot if you exclude the package name. There are some basic elements in every plot.
Axis
Typically, every 2-d plot has 2 axes (although there can be more dimensions sometimes).
- x axis
- y axis
Each axis can have a label. You can set it using xlabel and ylabel.
Axis Labels
x = np.arange(1,10)
y = x**2 * 1000
plt.plot(x,y)
plt.xlabel("Age")
plt.ylabel("Salary")
Text(0, 0.5, 'Salary')

Axis Range
You might just want to show a certain range of the axis. You can zoom in on a certain section using the axis ( ) function, which takes the four limits as a list. The syntax for it is
# axis ( [xMin, xMax, yMin, yMax] )
plt.plot(x,y)
plt.xlabel("Age")
plt.ylabel("Salary")
plt.axis([1,5,0,30000])
[1, 5, 0, 30000]
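If you only want to change one axis, pyplot also has xlim and ylim. The following sketch sets the same range as the axis ( ) call above, one axis at a time, and then reads the limits back.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(1, 10)
y = x**2 * 1000

plt.plot(x, y)
plt.xlim(1, 5)        # same x range as plt.axis([1, 5, 0, 30000])
plt.ylim(0, 30000)    # same y range

# Called without arguments, axis() returns the current limits
print(plt.axis())
```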

Axis Ticks
plt.plot(x,y)
plt.xlabel("Age")
plt.ylabel("Salary")
locs,labels = plt.xticks()
print(locs) ; print(labels)
[ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10.]

plt.plot(x,y)
plt.xlabel("Age")
plt.ylabel("Salary")
locs,labels = plt.xticks()
labels = ["Infant","Infant","Infant","Young","Young",'Young',"Young","Young","Young",'Young',"Young"]
plt.xticks(locs,labels)
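The example above reuses whatever tick locations Matplotlib picked, so the label list has to match their count exactly. A slightly more robust sketch is to choose the tick positions yourself and attach one label per position. The positions and labels below are illustrative assumptions.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(1, 10)
y = x**2 * 1000

plt.plot(x, y)
plt.xlabel("Age")
plt.ylabel("Salary")

# Choose tick positions explicitly, then attach one label per position
tick_positions = [1, 2, 4, 6, 8]
tick_labels = ["Infant", "Infant", "Young", "Young", "Young"]
plt.xticks(tick_positions, tick_labels)
```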

Plot Title
This is where you describe the plot in words.
plt.plot(x,y)
plt.xlabel("Age")
plt.ylabel("Salary")
plt.title("Age vs Salary")
Text(0.5, 1.0, 'Age vs Salary')

Granted, this is a trivial description. However, you can customize it any way you want depending on the requirement.
plt.plot(x,y)
plt.xlabel("Age")
plt.ylabel("Salary")
plt.title("Age vs Salary", loc="left")
Text(0.0, 1.0, 'Age vs Salary')

Or you can have multiple titles
plt.plot(x,y)
plt.xlabel("Age")
plt.ylabel("Salary")
plt.title("Young", loc="left")
plt.title("Older", loc="right")
plt.title("Salary vs Age")
Text(0.5, 1.0, 'Salary vs Age')

With different fonts and sizes
plt.plot(x,y)
plt.xlabel("Age")
plt.ylabel("Salary")
plt.title("Young", loc="left")
plt.title("Older", loc="right")
plt.title("Salary vs Age", fontsize=16)
Text(0.5, 1.0, 'Salary vs Age')

plt.plot(x,y)
plt.xlabel("Age")
plt.ylabel("Salary")
plt.title("Young", loc="left")
plt.title("Older", loc="right")
plt.title("Salary vs Age", fontsize=16,color="red")
Text(0.5, 1.0, 'Salary vs Age')

Or using a font dictionary object
font = {'family': 'arial',
'color': 'blue',
'weight': 'normal',
'size': 18,
}
plt.plot(x,y)
plt.xlabel("Age")
plt.ylabel("Salary")
plt.title("Young", loc="left")
plt.title("Older", loc="right")
plt.title("Salary vs Age", fontdict = font)
Text(0.5, 1.0, 'Salary vs Age')

Grid
To correlate x and y values accurately, it is sometimes better to have a grid.
plt.plot(x,y)
plt.grid()

The default grid color and transparency should be good enough for most situations. As usual though, you can customize it to your liking.
plt.plot(x,y)
plt.grid(color="red", alpha=0.1,linestyle="--")

As you can see this is a more subdued version of the one above (although it is in red). You can pretty much use any property that a line can have,
- style
- color
- alpha ( transparency )
- width etc
Line Type
By default the line is blue, but you have fine grained control over how the line is plotted, like
- color
- type
- point type
- size
- alpha ( transparency ) etc
plt.plot(x,y)

Format String Notation
Instead of specifying all of the properties separately, matplotlib has a short string-based notation to specify most of the line properties.
x = np.arange(1,10)
y1 = x
y2 = x*10
y3 = x*20
#---------------------------------------------
plt.plot(x,y1,"g") # Green line
plt.plot(x,y2,"r.") # Red dot markers, no connecting line
plt.plot(x,y3,"c*--") # Cyan dashed line with star markers

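The format string goes something like this
plot ( x, y , "[color][marker][line type]" )
The format string is passed as the third positional argument, and each of its three parts is optional. Search for matplotlib format string, or see the Notes section of the pyplot.plot documentation, for the different values that are possible for these parameters.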
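As a quick check of how a format string is decoded, the sketch below draws one line with a format string and one with the equivalent keywords, then prints the color, marker and line style Matplotlib stored for each. The variable names are only for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(1, 10)

# Format string: color, marker and line style packed into one argument
short = plt.plot(x, x * 20, "r*--")[0]

# Keyword form spelling out the same three properties
verbose = plt.plot(x, x * 30, color="r", marker="*", linestyle="--")[0]

for line in (short, verbose):
    print(line.get_color(), line.get_marker(), line.get_linestyle())
```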
Format Keywords
Instead of using the format string, you can use specific keywords to specify the line format.
plt.plot(x,y1,color="green") # Green line
plt.plot(x,y2,color="red",linestyle="none",marker=".") # Red dot markers, no connecting line
plt.plot(x,y3,color="cyan", linestyle="dashed", marker="*") # Cyan dashed line with star markers

Legend
In case there are multiple plotting elements on the same plot, a legend helps us understand what each of them stands for.
plt.plot(x,y1,color="green", label="linear")
plt.plot(x,y2,color="red",linestyle="none",marker=".", label="x10 Multiplier")
plt.plot(x,y3,color="cyan", linestyle="dashed", marker="*",label="x20 multiplier")
plt.legend()

You can locate the Legend at a particular location or let Matplotlib choose a best fit.
p1 = plt.plot(x,y1,color="green", label="linear")
p2 = plt.plot(x,y2,color="red",linestyle="none",marker=".", label="x10 Multiplier")
p3 = plt.plot(x,y3,color="cyan", linestyle="dashed", marker="*",label="x20 multiplier")
legend = plt.legend(loc="upper center",shadow=True)

You can even have a drop shadow and change the color of the legend to make it more or less prominent.
p1 = plt.plot(x,y1,color="green", label="linear")
p2 = plt.plot(x,y2,color="red",linestyle="none",marker=".", label="x10 Multiplier")
p3 = plt.plot(x,y3,color="cyan", linestyle="dashed", marker="*",label="x20 multiplier")
plt.legend(loc="best",shadow=True)

Barplot
In simple terms, a bar plot visually compares multiple categories. For example, compare the weight of 3 students in a class.
import matplotlib.pyplot as plt
names = ["Ajay","Stacy","Brad"]
weight = [34,25,45]
plt.bar(names,weight)

You can flip it to a horizontal layout if you want. You might have seen this used to compare processor speeds in tech magazines.
plt.barh(names,weight)

Say you want to specifically call out a particular bar, in this case Brad being overweight. You can do that using the color attribute.
plt.bar(names,weight,color=["blue","blue","red"])

Or use a horizontal line to specify a cut-off.
plt.bar(names,weight)
plt.axhline(35,color="red")

Sometimes, you want to compare multiple attributes for the same category. For example, in the same data above, say we are comparing each weight against the corresponding average weight categorized by sex.
names = ["Ajay","Stacy","Brad"]
weight = [34,25,45]
#---------------------------
weight_avg = [30,28,30]
plt.bar(names,weight,label="weight")
plt.bar(names,weight_avg,label="average")
plt.axhline(30, color="gray")
plt.legend()

This is not all that useful, right ? The trick is to move the bars’ x position a bit to the right so that they appear side-by-side.
import numpy as np
names = ["Ajay","Stacy","Brad"]
weight = [34,25,45]
weight_avg = [30,28,30]
#---------------------------
x = np.arange(len(names)) # Creates x-axis numbers
w = 0.5
plt.bar(x,weight,label="weight")
plt.bar(x+w,weight_avg,label="average")
plt.legend()

Still not very useful, right ? What if we adjust the width of the bar ? Use the width attribute. You might have to tweak the width a bit based on the outcome.
plt.bar(x,weight,label="weight",width=0.45)
plt.bar(x+w,weight_avg,label="average",width=0.45)
plt.legend()

How about the labels then ? Create custom labels using xticks.
plt.bar(x,weight,label="weight",width=0.45)
plt.bar(x+w,weight_avg,label="average",width=0.45)
plt.legend()
#----------------------------------------------------
x_labels = np.append(x,x+w) # Create label values
x_labels = np.sort(x_labels)
labels = ["Ajay","Average","Stacy","Average","Brad","Average"] # Create actual labels
plt.xticks(x_labels,labels) # Set the labels as x ticks
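An alternative layout, sketched below under the same data, centers each pair of bars on its slot and puts a single tick with the person’s name under the pair, leaving the legend to distinguish weight from average. The width value is an arbitrary choice.

```python
import numpy as np
import matplotlib.pyplot as plt

names = ["Ajay", "Stacy", "Brad"]
weight = [34, 25, 45]
weight_avg = [30, 28, 30]

x = np.arange(len(names))   # one slot per person
w = 0.4                     # bar width; two bars fill 0.8 of each slot

# Shift each series half a bar width left or right of the slot centre
bars_weight = plt.bar(x - w / 2, weight, width=w, label="weight")
bars_avg = plt.bar(x + w / 2, weight_avg, width=w, label="average")

plt.xticks(x, names)        # one tick per person, centred under the pair
plt.legend()
```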

Lifecycle of a Plot
Matplotlib has 2 interfaces.
- MATLAB based interface ( State-based )
- Object Oriented interface
So far, we have seen the first one – completely based on pyplot module. How about the second ?
To understand the object oriented interface of Matplotlib, we have to start with a couple of fundamental concepts related to a plot.
- Figure
- Axes
Think of the entire plot as an object hierarchy with Figure at the top of it. Here is a small subset of the hierarchy with Figure being at the top, followed by Axes ( not the same as axis ), followed by different texts on the plot, different kinds of plots it can handle, the different axis and so on. The hierarchy doesn’t stop there. Further to x-axis for example, there are ticks and further to ticks there are its subsequent properties and so on.

Here is another picture from Matplotlib documentation that shows the same pictorially. Essentially, the same things as above, put visually on the plot to make understanding it even better.

This is a lot to understand. So, for starters, let’s just look at the picture below.

Let’s explore this in code.
import matplotlib.pyplot as plt
import numpy as np
x = np.arange(10)
y = x**2
fig,ax = plt.subplots(1,1) # 1,1 indicates, 1 row and 1 column. We will see more examples

This is a blank plot. Let’s add its components step by step.
# Use the plot function
fig,ax = plt.subplots(1,1)
ax.plot(x,y)

fig,ax = plt.subplots(1,1)
ax.plot(x,y)
#------ Set title -------------------#
ax.set_title("Quadratic Plot")
Text(0.5, 1.0, 'Quadratic Plot')

fig,ax = plt.subplots(1,1)
ax.plot(x,y)
ax.set_title("Quadratic Plot")
#--------------Set the x and y axis labels ----------------
ax.set_xlabel("age")
ax.set_ylabel("Cell growth")
Text(0, 0.5, 'Cell growth')

Say for example, you have 2 sets of data and you want to compare them. One way is to overlay both on the same plot.
x = np.arange(10)
y1 = x**2
fig,ax = plt.subplots(1,1)
#----------- new y-axis data ---------------
y2 = x**3
ax.plot(x,y1)
ax.plot(x,y2)

However, if you have a different data set altogether, you can use 2 subplots instead of one.
x1 = np.arange(10)
y1 = x1**2
fig,ax = plt.subplots(1,2)
#----------- new y-axis data ---------------
x2 = np.arange(20,30)
y2 = x2**2
ax[0].plot(x1,y1)
ax[1].plot(x2,y2)

We will learn more about subplots later. We just wanted to illustrate the object oriented way of plotting in Matplotlib. Let’s compare the object oriented code vs State-based code to understand the differences.
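As a minimal sketch of that comparison, the snippet below draws the same plot twice: once through pyplot’s implicit current axes, and once through explicit Figure and Axes objects. The titles are placeholders chosen for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(10)
y = x**2

# State-based (MATLAB style): pyplot tracks the current figure and axes
plt.figure()
plt.plot(x, y)
plt.title("State-based")
plt.xlabel("age")
ax_state = plt.gca()

# Object oriented: hold explicit Figure and Axes objects and call their methods
fig, ax_oo = plt.subplots(1, 1)
ax_oo.plot(x, y)
ax_oo.set_title("Object oriented")
ax_oo.set_xlabel("age")
```

The object oriented form is a little more verbose, but it stays unambiguous once a figure contains more than one Axes.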

In fact, all of pyplot.py is a bunch of wrapper functions (State-based, MATLAB style) around Object Oriented approach. For example,
pyplot.xlabel is equivalent to axes.set_xlabel
pyplot.title is equivalent to axes.set_title
..

More Plots
There are so many other plots that Matplotlib supports. We are going to look at just a couple of them here and many more as we encounter the respective topics.
Scatter Plot
A scatter plot is similar to a normal plot, except you don’t join the points. We are going to learn more about scatter plots when we talk about Correlation in Statistics. Essentially, a scatter plot is used to find out how strongly two variables are associated with each other. For example, is there a relationship between age and cell growth ?
import matplotlib.pyplot as plt
import numpy as np
age_1 = [28,26,26,27,29,20,28,24,29, 30, 29, 22, 22, 26, 26, 29, 28, 22, 22, 24, 27, 26, 25, 23, 25, 26, 22, 30, 30, 24, 30, 20, 24, 28, 24, 23, 24, 27, 24, 20, 24, 21, 21, 27, 28, 21, 26, 30, 26, 20]
cell_growth_1 = [76, 11, 70, 100, 48, 77, 57, 73, 99, 58, 7, 88, 19, 67, 16, 4, 36, 73, 20, 56, 10, 38, 13, 42, 85, 29, 49, 63, 78, 64, 13, 80, 34, 76, 99, 72, 77, 21, 31, 3, 61, 90, 75, 70, 87, 45, 50, 23, 90,12]
plt.scatter(age_1,cell_growth_1)

This suggests that there is no obvious relationship between age and cell growth in this sample. How about the following ?
age_2 = [28,26,26,27,29,20,28,24,29, 30, 29, 22, 22, 26, 26, 29, 28, 22, 22, 24, 27, 26, 25, 23, 25, 26, 22, 30, 30, 24, 30, 20, 24, 28, 24, 23, 24, 27, 24, 20, 24, 21, 21, 27, 28, 21, 26, 30, 26, 20]
cell_growth_2 = [32, 28, 34, 33, 39, 22, 37, 25, 30, 37, 35, 28, 32, 27, 33, 35, 29, 29, 31, 32, 30, 30, 28, 32, 31, 28, 25, 35, 31, 25, 36, 30, 25, 31, 33, 32, 30, 34, 26, 26, 27, 23, 26, 35, 33, 29, 28, 37, 33, 27]
plt.scatter(age_2,cell_growth_2)

Now, we have a relationship, and it seems to be somewhat linear. Let’s add one more variable: the type of medication used.
medication_2 = [2, 1, 2, 2, 2, 2, 1, 2, 2, 1, 1, 1, 2, 1, 1, 2, 2, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 2, 1, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 1, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1]
# Say 1 = Medicine 1 , 2 = Medicine 2
plt.scatter(age_2,cell_growth_2,c=medication_2)
# This plot can analyze if medication has any effect on cell growth

Similarly, you can vary the size, shape, marker style, transparency and so on.
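Here is a hedged sketch of varying marker shape, size and transparency together. The six data points are made up for illustration, and scaling the marker area by cell growth is just one possible encoding.

```python
import numpy as np
import matplotlib.pyplot as plt

# Small made-up sample (assumption: illustrative values only)
age = np.array([20, 22, 24, 26, 28, 30])
cell_growth = np.array([22, 27, 29, 31, 34, 37])
medication = np.array([1, 2, 1, 2, 1, 2])

# One scatter call per medication so each group gets its own marker shape
for med, marker in [(1, "o"), (2, "^")]:
    mask = medication == med
    plt.scatter(age[mask], cell_growth[mask],
                s=cell_growth[mask] * 5,   # marker area scales with growth
                marker=marker, alpha=0.6,
                label=f"Medicine {med}")
plt.xlabel("age")
plt.ylabel("cell growth")
plt.legend()
```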
Pie Chart
This is probably one of the simplest of charts. It is used to represent parts of a whole. Matplotlib automatically takes care of coloring it. For example, if you had a basket of stocks and wanted to represent each of their share, you can use a pie chart.
import matplotlib.pyplot as plt
stock_value = [5000, 12000, 25000, 3000, 20000]
stock_labels = ["Apple", "Amazon", "Facebook" , "Google", "HP"]
plt.pie(stock_value, labels=stock_labels)
[Text(1.068035996798755, 0.26324724033138536, 'Apple'),
Text(0.5345817254868832, 0.9613648520595428, 'Amazon'),
Text(-1.0540704321815815, 0.3145401786744169, 'Facebook'),
Text(-0.5345816692309615, -0.961364883341512, 'Google'),
Text(0.6248713371324109, -0.9052821726016441, 'HP')])

The label colors seem to be a problem. Let’s change the text color first.
plt.pie(stock_value, labels=stock_labels, textprops ={'color':"w"})
[Text(1.068035996798755, 0.26324724033138536, 'Apple'),
Text(0.5345817254868832, 0.9613648520595428, 'Amazon'),
Text(-1.0540704321815815, 0.3145401786744169, 'Facebook'),
Text(-0.5345816692309615, -0.961364883341512, 'Google'),
Text(0.6248713371324109, -0.9052821726016441, 'HP')])

If you wanted a bigger pie, choose a different radius.
plt.pie(stock_value, labels=stock_labels, textprops ={'color':"w"},
radius = 2)
[Text(2.13607199359751, 0.5264944806627707, 'Apple'),
Text(1.0691634509737664, 1.9227297041190856, 'Amazon'),
Text(-2.108140864363163, 0.6290803573488338, 'Facebook'),
Text(-1.069163338461923, -1.922729766683024, 'Google'),
Text(1.2497426742648219, -1.8105643452032882, 'HP')])

What if you wanted to get percentages ?
plt.pie(stock_value, labels=stock_labels, textprops ={'color':"w"},radius = 2,
autopct='%1.1f%%')
[Text(2.13607199359751, 0.5264944806627707, 'Apple'),
Text(1.0691634509737664, 1.9227297041190856, 'Amazon'),
Text(-2.108140864363163, 0.6290803573488338, 'Facebook'),
Text(-1.069163338461923, -1.922729766683024, 'Google'),
Text(1.2497426742648219, -1.8105643452032882, 'HP')],
[Text(1.1651301783259143, 0.28717880763423853, '7.7%'),
Text(0.5831800641675089, 1.0487616567922284, '18.5%'),
Text(-1.1498950169253614, 0.3431347403720912, '38.5%'),
Text(-0.5831800027974124, -1.048761690918013, '4.6%'),
Text(0.6816778223262664, -0.9875805519290662, '30.8%')])

How about making a particular slice stand out ? Say we want Facebook to stand out as it is our biggest share.
plt.pie(stock_value, labels=stock_labels, textprops ={'color':"w"},radius = 2,autopct='%1.1f%%',
explode = (0, 0, 0.2, 0,0))
[Text(2.13607199359751, 0.5264944806627707, 'Apple'),
Text(1.0691634509737664, 1.9227297041190856, 'Amazon'),
Text(-2.2997900338507233, 0.6862694807441824, 'Facebook'),
Text(-1.069163338461923, -1.922729766683024, 'Google'),
Text(1.2497426742648219, -1.8105643452032882, 'HP')],
[Text(1.1651301783259143, 0.28717880763423853, '7.7%'),
Text(0.5831800641675089, 1.0487616567922284, '18.5%'),
Text(-1.3415441864129218, 0.4003238637674397, '38.5%'),
Text(-0.5831800027974124, -1.048761690918013, '4.6%'),
Text(0.6816778223262664, -0.9875805519290662, '30.8%')])

This is pretty flat, right ? How about adding a shadow and rotating the starting angle ?
plt.pie(stock_value, labels=stock_labels, textprops ={'color':"w"},radius = 2,autopct='%1.1f%%',explode = (0, 0, 0.2, 0,0),
shadow=True, startangle=90)
[Text(-0.5264944806627704, 2.13607199359751, 'Apple'),
Text(-1.9227297041190854, 1.069163450973767, 'Amazon'),
Text(-0.6862694807441815, -2.299790033850724, 'Facebook'),
Text(1.922729766683024, -1.069163338461923, 'Google'),
Text(1.8105643452032882, 1.2497426742648217, 'HP')],
[Text(-0.28717880763423836, 1.1651301783259143, '7.7%'),
Text(-1.0487616567922282, 0.5831800641675092, '18.5%'),
Text(-0.4003238637674391, -1.3415441864129218, '38.5%'),
Text(1.048761690918013, -0.5831800027974126, '4.6%'),
Text(0.9875805519290662, 0.6816778223262663, '30.8%')])

How about a border for the pie chart ?
plt.pie(stock_value, labels=stock_labels, textprops ={'color':"w"},radius = 2,autopct='%1.1f%%',explode = (0, 0, 0.2, 0,0),shadow=True, startangle=90,
wedgeprops={"edgecolor":"w",'linewidth': 1, 'linestyle': 'solid', 'antialiased': True} )
[Text(-0.5264944806627704, 2.13607199359751, 'Apple'),
Text(-1.9227297041190854, 1.069163450973767, 'Amazon'),
Text(-0.6862694807441815, -2.299790033850724, 'Facebook'),
Text(1.922729766683024, -1.069163338461923, 'Google'),
Text(1.8105643452032882, 1.2497426742648217, 'HP')],
[Text(-0.28717880763423836, 1.1651301783259143, '7.7%'),
Text(-1.0487616567922282, 0.5831800641675092, '18.5%'),
Text(-0.4003238637674391, -1.3415441864129218, '38.5%'),
Text(1.048761690918013, -0.5831800027974126, '4.6%'),
Text(0.9875805519290662, 0.6816778223262663, '30.8%')])

The list of enhancements can go on and on.
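The long Text(...) listings printed after each call above are simply the return value of plt.pie. As a small sketch, you can capture that return value and work with it directly; when autopct is set, it unpacks into the wedges, the label texts and the percentage texts.

```python
import matplotlib.pyplot as plt

stock_value = [5000, 12000, 25000, 3000, 20000]
stock_labels = ["Apple", "Amazon", "Facebook", "Google", "HP"]

# With autopct set, pie() returns wedges, label texts and percentage texts
wedges, texts, autotexts = plt.pie(stock_value, labels=stock_labels,
                                   autopct="%1.1f%%")

# Pair every label with the percentage drawn inside its wedge
for label, pct in zip(texts, autotexts):
    print(label.get_text(), pct.get_text())
```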
Other Plots
There are so many other plots that matplotlib can do and we will be seeing most of them at different points as we cruise through the Statistics chapter of the course.
- Histogram
- Tables
- Box plots
- Box and whisker plots
- Violin plots
Matplotlib Reference
Contour Plots
Contour plots are exactly what you know from maps. What does a contour map tell you ? For example, look at a contour map of a hill. The innermost circle represents the highest point, and as the hill tapers down, all the points at each level ( a certain height ) are joined with a circle.

In Machine Learning, you can use contour maps to map the value of a function over a grid. To get the grid though, NumPy’s meshgrid ( ) function is extremely useful. For example, to map a circle, you could just do this.
Get a grid of data using meshgrid ( )
import numpy as np
x = np.linspace(1,10,10)
y = np.linspace(1,10,10)
xx,yy = np.meshgrid(x,y)
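To see what meshgrid actually returns, here is a tiny sketch with only 3 x values and 2 y values, small enough to print in full. The input values are arbitrary.

```python
import numpy as np

x = np.array([1, 2, 3])
y = np.array([10, 20])

xx, yy = np.meshgrid(x, y)

print(xx)   # x repeated along rows:    [[1 2 3], [1 2 3]]
print(yy)   # y repeated along columns: [[10 10 10], [20 20 20]]
print(xx.shape, yy.shape)   # both (2, 3): len(y) rows, len(x) columns
```

Pairing xx[i, j] with yy[i, j] gives the coordinates of every point on the grid, which is why a function of xx and yy is evaluated at all grid points at once.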
Formulate the circle values from the grid
The equation of a circle centered at the origin is $x^2 + y^2 = r^2$, so we evaluate $z = x^2 + y^2$ at every grid point.

z = xx**2 + yy**2
Plot the function ( z ) over the grid
import matplotlib.pyplot as plt
%matplotlib inline
plt.scatter(xx,yy)
contours = plt.contour(xx,yy,z)

In case you want to see the levels ( along which the contours are drawn), use
contours.levels
array([ 0., 25., 50., 75., 100., 125., 150., 175., 200.])
The levels that are chosen are automatically selected by matplotlib. As usual, we do have control over the levels and can specify them.
import matplotlib.pyplot as plt
%matplotlib inline
plt.scatter(xx,yy)
contours = plt.contour(xx,yy,z, levels = [0,5,10,15,20,25,30,35,40,45,50,55,60,65,70,75,80,85,90,95,100])
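When the function is a squared distance, it can be convenient to choose the levels from the radii you care about and to label each contour with its value. The sketch below assumes a grid centered on the origin and four arbitrary radii; it is an illustration, not the grid used above.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-10, 10, 200)
y = np.linspace(-10, 10, 200)
xx, yy = np.meshgrid(x, y)
z = xx**2 + yy**2                      # squared distance from the origin

radii = [2, 4, 6, 8]
contours = plt.contour(xx, yy, z, levels=[r**2 for r in radii])
plt.clabel(contours, fmt="%d")         # write each level's value on its line
plt.gca().set_aspect("equal")          # so circles are not drawn as ellipses
```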

Sometimes, you would want the contours to be colored solid. In cases like that, use the contour fill function – contourf function.
import matplotlib.pyplot as plt
%matplotlib inline
plt.scatter(xx,yy)
contours = plt.contourf(xx,yy,z)

The contours are jagged because we have just a few points per axis ( just 10 ). Let’s increase it to 1000 and you should be able to see a much smoother contour.
import numpy as np
x = np.linspace(1,10,1000)
y = np.linspace(1,10,1000)
xx,yy = np.meshgrid(x,y)
z = xx**2 + yy**2
z
array([[ 2. , 2.01809918, 2.03636069, ..., 100.63996429,
100.81990098, 101. ],
[ 2.01809918, 2.03619836, 2.05445987, ..., 100.65806347,
100.83800016, 101.01809918],
[ 2.03636069, 2.05445987, 2.07272137, ..., 100.67632497,
100.85626167, 101.03636069],
...,
[100.63996429, 100.65806347, 100.67632497, ..., 199.27992858,
199.45986527, 199.63996429],
[100.81990098, 100.83800016, 100.85626167, ..., 199.45986527,
199.63980196, 199.81990098],
[101. , 101.01809918, 101.03636069, ..., 199.63996429,
199.81990098, 200. ]])
import matplotlib.pyplot as plt
%matplotlib inline
plt.scatter(xx,yy)
plt.contourf(xx,yy,z)

Another use for contour plots is to visualize classification predictions in machine learning.
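As a rough sketch of that idea, the snippet below predicts a class for every point on a grid and colors the resulting regions with contourf. The two synthetic clusters and the nearest-centroid rule are assumptions chosen to keep the example self-contained; any trained classifier’s predictions could be plotted the same way.

```python
import numpy as np
import matplotlib.pyplot as plt

# Two made-up clusters of 2-d points (assumption: synthetic data)
rng = np.random.default_rng(0)
class_0 = rng.normal(loc=[2, 2], scale=0.8, size=(50, 2))
class_1 = rng.normal(loc=[5, 5], scale=0.8, size=(50, 2))
centroids = np.array([class_0.mean(axis=0), class_1.mean(axis=0)])

# Predict a class for every grid point: the index of the nearest centroid
xx, yy = np.meshgrid(np.linspace(0, 7, 200), np.linspace(0, 7, 200))
grid = np.column_stack([xx.ravel(), yy.ravel()])
distances = np.linalg.norm(grid[:, None, :] - centroids[None, :, :], axis=2)
predicted = distances.argmin(axis=1).reshape(xx.shape)

# Filled contours colour the predicted regions; the points go on top
plt.contourf(xx, yy, predicted, levels=[-0.5, 0.5, 1.5], alpha=0.3)
plt.scatter(class_0[:, 0], class_0[:, 1], label="class 0")
plt.scatter(class_1[:, 0], class_1[:, 1], label="class 1")
plt.legend()
```

The boundary between the two colored regions is the decision boundary of this simple rule.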