Seaborn splits Matplotlib parameters into two independent groups: the first group sets the aesthetic style of the plot, and the second group scales various elements of the figure.
Let us first see how to customize the look and style of the plots. Seaborn has five built-in themes to style its plots – darkgrid, whitegrid, dark, white, and ticks. The darkgrid is the default theme, but we can change this style to suit our requirements.
We can customize styles such as the background color, the color of tick marks, the text color, and the font type using the functions axes_style() and set_style(). Both functions take the same set of arguments.
The axes_style() function defines and returns a dictionary of rc parameters related to the styling of the plots. This function returns an object that can be used in a with statement to temporarily change the style parameters.
The set_style() function is used to set the aesthetic style of the plots; the rc parameters can also be customized through this function.
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
Below we see the rc parameters returned by the axes_style function. The default settings of the parameters can be changed to fine tune the look of the plots.
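For reference, you can print this dictionary yourself using the seaborn import above (the exact values depend on your Seaborn version):

print(sns.axes_style())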
A scatterplot is used to display the correlation between two numerical variables. The values of both the variables are displayed with dots. Each dot on the scatterplot represents one observation from the data set.
We will plot a scatter plot that uses the 'tips' dataset to display the tips received against the total_bill, both of which are quantitative variables. The scatter plot is rendered in the default theme, which is the darkgrid style.
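The code for this plot is not reproduced in the text; a minimal sketch, assuming the imports above and loading the tips dataset the same way it is loaded later in this article, might look like this:

tips = sns.load_dataset('tips')
sns.set_style('darkgrid')
sns.scatterplot(x='total_bill', y='tip', data=tips)
plt.title('Scatter plot in the default darkgrid theme')
plt.show()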
As you can see, the default theme has a light grey background with white gridlines.
Let's look at another example. In the scatter plot below, we have passed the 'ticks' theme to the style parameter; the 'ticks' theme lets the colors of the dataset show more visibly. Apart from this, we have also changed the axes edge color, text color, and tick color in the plot by overriding the default rc parameter values.
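A sketch of such a call, reusing the tips dataset loaded above (the specific colors below are placeholder assumptions, not the original values):

sns.set_style(style='ticks',
              rc={'axes.edgecolor': 'slategray',
                  'text.color': 'dimgray',
                  'xtick.color': 'dimgray',
                  'ytick.color': 'dimgray'})
sns.scatterplot(x='total_bill', y='tip', data=tips)
plt.title('Ticks theme')
plt.show()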
We can temporarily change the style parameters of a plot by using the axes_style function in a with statement as shown in the example below.
with sns.axes_style(style='whitegrid', rc={'font.family': 'serif', 'font.serif': 'Times'}):
    sns.scatterplot(x='total_bill', y='tip', data=tips)
    plt.title('Whitegrid theme')
    plt.show()
Scaling of plot elements
Next we will see how to scale the various elements of the plot. Seaborn has four preset contexts which set the size of the plot elements and allow us to customize the plot depending on how it will be presented. The four preset contexts, in order of relative size, are paper, notebook, talk, and poster. The notebook context is the default, and it can be changed depending on our requirement.
We can customize the size of plot elements such as labels, ticks, markers, and line width using the functions plotting_context() and set_context(). Both functions take the same set of arguments.
The plotting_context() function defines and returns a dictionary of rc parameters related to plot elements such as label size, tick size, and marker size. This function returns an object that can be used in a with statement to temporarily change the context parameters.
The set_context() function is used to set the plotting context parameters.
Below we see the rc parameters returned by the plotting_context function. The default settings of the parameters can be changed to scale plot elements.
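As with axes_style(), you can print the returned dictionary to inspect the defaults:

print(sns.plotting_context())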
The scatter plot below uses the 'tips' dataset to display the tips received against the total_bill. The scatter plot is rendered in the notebook context, which is the default.
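A sketch of that default-context plot, reusing the tips dataset from above:

sns.set_context('notebook')
sns.scatterplot(x='total_bill', y='tip', data=tips)
plt.title('Notebook context (default)')
plt.show()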
In the example below, we have passed 'talk' to the context parameter. Apart from this, we have also changed the label size, title size, and grid line width of the plot by overriding the default rc parameter values.
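The original code is not reproduced here; a sketch with assumed (not original) rc values could look like this:

sns.set_context(context='talk',
                rc={'axes.labelsize': 20,
                    'axes.titlesize': 22,
                    'grid.linewidth': 2.5})
sns.scatterplot(x='total_bill', y='tip', data=tips)
plt.title('Talk context')
plt.show()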
If you want to switch to Seaborn default settings, then call the set() function without passing any arguments.
Default settings:
context='notebook'
style='darkgrid'
palette='deep'
font='sans-serif'
font_scale=1
color_codes=True
sns.set()
sns.scatterplot(x='total_bill', y='tip', data=tips)
plt.title('Plot with default settings')
plt.show()
The despine() Function
Spines are the borders on the sides of a graph or plot. By default, a plot has four spines. The despine() function can be used to remove spines from the plot; by default, it removes the top and right spines.
sns.set(style='ticks')
sns.scatterplot(x='total_bill', y='tip', data=tips)
plt.title('Plot with four spines/borders')
plt.show()
sns.set(style='ticks')
sns.scatterplot(x='total_bill', y='tip', data=tips)
sns.despine()
plt.title('Plot with top and right spines removed')
plt.show()
You can choose to remove all the spines if you think they are unnecessary and distracting, see example below.
sns.set(style='white')
sns.scatterplot(x='total_bill', y='tip', data=tips)
sns.despine(left=True,bottom=True)
plt.title('Plot with all spines removed')
plt.show()
Seaborn is a data visualization library which provides a high-level interface to draw statistical graphs. It is built on top of Python’s core visualization library, Matplotlib. Seaborn extends the Matplotlib library for creating aesthetically pleasing graphs. Internally Seaborn uses Matplotlib to draw plots, so it complements the Matplotlib library but is not a replacement to it. Matplotlib is highly customizable, but it is hard to know what settings to tweak to render nice plots. Seaborn comes with a number of customized themes and a high-level interface for controlling the look of Matplotlib figures.
Seaborn comes with preset styles and color palettes which can be used to create aesthetically pleasing charts with few lines of code. It is closely integrated with the Pandas and Numpy library.
Below are the dependencies of the Seaborn library:
Python 3.6+
numpy (>= 1.13.3)
scipy (>= 1.0.1)
pandas (>= 0.22.0)
matplotlib (>= 2.1.2)
Once the required dependencies are installed, you are ready to install and use Seaborn.
The latest version of Seaborn can be installed using pip with the command — pip install seaborn
You can also install Seaborn using Anaconda prompt with the command — conda install seaborn
Seaborn is closely integrated with Pandas data structures. The Pandas library has two primary containers of data – DataFrame and Series.
DataFrames – A DataFrame is a collection of data arranged in rows and columns. DataFrames are similar to excel spreadsheets. They are two-dimensional structures, with two axes, the “index” axis and the “columns” axis.
Series – Series is a single column of the DataFrame. So a Pandas DataFrame is a collection of Series objects.
Basic Terms
Quantitative and Qualitative variables
In statistics two types of variables are used: Quantitative and Qualitative variables.
Quantitative: Quantitative variables are numerical values representing counts or measures. Examples: temperature, counts, percentages, weight. Quantitative variables are of two types – discrete and continuous:
Discrete variables are numeric variables that have a finite number of values between any two values. A discrete variable is always numeric.
Continuous variables are numeric variables that have an infinite number of values between any two values.
Qualitative: Qualitative variables are variables that can be placed into distinct categories according to some attributes or characteristics. They contain a finite number of categories or distinct groups. Examples: Gender, eye color.
Univariate and Bivariate Data
Statistical data are classified according to the number of variables being studied.
Univariate data: This type of data consists of only one variable. The variable is studied individually and we don’t look at more than one variable at a time.
Bivariate data: This type of data involves two different variables, where the two variables are studied to explore the relationship or association between them.
Loading Datasets
Seaborn comes with a few example datasets that can be used for practice; they are downloaded automatically the first time they are requested. To start working with a built-in Seaborn dataset, you can make use of the load_dataset() function. By default, the built-in datasets are loaded as Pandas DataFrames. Let us load the 'tips' dataset, which consists of the tips received by a waiter in a restaurant over a period of time.
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
tips = sns.load_dataset('tips')
tips.head(10)
total_bill tip sex smoker day time size
0 16.99 1.01 Female No Sun Dinner 2
1 10.34 1.66 Male No Sun Dinner 3
2 21.01 3.50 Male No Sun Dinner 3
3 23.68 3.31 Male No Sun Dinner 2
4 24.59 3.61 Female No Sun Dinner 4
5 25.29 4.71 Male No Sun Dinner 4
6 8.77 2.00 Male No Sun Dinner 2
7 26.88 3.12 Male No Sun Dinner 4
8 15.04 1.96 Male No Sun Dinner 2
9 14.78 3.23 Male No Sun Dinner 2
Bar Plot
A bar plot is a visual tool that uses bars to compare data among different groups or categories. The graph represents categories on one axis and discrete values on the other. Let us first draw a bar plot using Matplotlib's bar() function with input from the tips dataset, as sketched below.
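The Matplotlib code is not shown in the text; a minimal sketch might aggregate the mean tip per group first, since plt.bar() does no aggregation of its own:

avg_tip = tips.groupby('smoker')['tip'].mean()
plt.bar(avg_tip.index.astype(str), avg_tip.values)
plt.xlabel('smoker')
plt.ylabel('tip')
plt.title('Tips vs Smoker')
plt.show()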
The bar plot above is plotted with Matplotlib; the bars compare the tips received from two groups – smokers and non-smokers. Smoker is a qualitative variable. The Series data is passed to the axes arguments.
Let us again plot a barplot using the Seaborn library as shown below.
sns.barplot(x='smoker',y='tip',data=tips)
plt.title('Tips vs Smoker')
plt.show()
The bar plot above is plotted with Seaborn. We passed the column names to the x and y parameters of the barplot function and the DataFrame to the data parameter. As in the previous example, the bars compare the tips received from smokers and non-smokers. Notice how the bars are displayed in different colors, and the axes labels are taken from the input data. We can add custom labels to the plot by calling the set_xlabel() and set_ylabel() functions on the axes object. You can also set the labels using the xlabel() and ylabel() functions defined in the pyplot module of Matplotlib.
fig1, axes1 = plt.subplots()
sns.barplot(x='smoker', y='total_bill', data=tips, hue='sex',
            estimator=np.sum, errcolor='r', errwidth=0.75,
            capsize=0.2, ax=axes1)
axes1.set_xlabel('Smoker - Yes/No')
axes1.set_ylabel('Bill amount')
axes1.set_title('Total bill vs Smoker')
plt.show()
We can also create a figure object with multiple axes and render the plots onto a specific axes by using the ‘ax’ argument. If we do not specify any value for the argument, plot is rendered to the current axes.
The ‘hue’ parameter can be used to show information about the different sub-groups in a category. In the above example, the ‘hue’ parameter is assigned to the column ‘sex’ which further categorizes the data and has created two side by side bars. A separate colored bar represents each sub-group and a legend is added to let us know what each sub-category is.
The 'estimator' argument can be used to change how the data is aggregated. By default, each bar of a bar plot displays the mean (np.mean) value of a variable; the estimator argument changes this behaviour. It can receive a function such as np.sum, len, np.median, or any other statistical function.
The red cap-tipped lines that extend from the edges of the bars are called error bars, and they provide an additional layer of detail about the plotted data. Error bars indicate the estimated error or uncertainty of a data point. A short error bar shows that the values are concentrated around the plotted average, while a long error bar indicates that the values are more spread out and the average is less reliable.
Let me take you back to the number guessing game that we played on day 1 of the course. It is a simple game where the computer chooses a random number between 1 and 100 and you have to guess it. After each guess, the program helps you by telling you whether your guess is higher or lower than the chosen number. Say the number chosen is 60. Let's visualize this.
Basically, it is a series of decisions based on the clue you get from the program. For lack of a better strategy, we just guess the middle number on the remaining side (higher or lower). We can think of the same process as a decision tree.
A decision tree is essentially a series of decisions that are based on the data you are working with. For example, if you are guessing a number between 1 and 1000, the decision tree would have been much bigger. In this case, the guesses (cuts in the number line) are exactly in the middle – for lack of a better guessing method. However, a real decision tree makes a much more informed decision. Once again, let me show this with a simple example.
Take an apple that is rotten somewhere at the side.
Our goal is to find a series of cuts that maximises the fresh apple portion (and minimizes the rotten portion) with the least possible cuts. How would you do it ?
Something like this – the criterion you would use to make the cuts is the maximum area (volume) you can carve off that is not rotten.
Decision trees also work the same way. For example, let’s take the iris dataset. To make things simple, let’s just focus on
setosa and versicolor
sepal length and sepal width.
If you are asked to carve out one species from another using just horizontal and vertical lines, how would you do it ? It’s not an easy job to do it efficiently. Probably, we would do it something like this.
What were we basing our decisions (cut-off points) on ? Visually, we were essentially eye-balling to minimize the mix of species(or maximize the grouping of a single species) in such a way that more of a specific species fell on one side than the other.
Decision tree algorithms do just this – except that they use a bit of math to make the cuts. The rpart implementation we use below provides two cost functions for this
Gini Index ( default )
Entropy
We will start with the basic implementation and then we will focus on understanding the Gini index in a bit more detail.
model <- rpart(Species ~ Sepal.Length + Sepal.Width, method="class", data=iris[1:100,])
plot(model)
text(model, use.n = TRUE)
Visualization
One of the biggest advantages of decision trees is that the whole process is very intuitive to humans. It is more or less a white box (as opposed to other methods like neural nets, which are black boxes – we just can't make sense of the weights and layers). A useful way to understand a decision tree is to visualize it, which is what the plot() and text() calls above do.
With these plotting functions, some of the text gets cut off. A better representation is to create a ps, or PostScript, file, which can either be viewed directly or converted to a pdf and viewed.
post(model, file = "tree.ps",
title = "Iris (Setosa/Versicolor) simple decision tree ")
You can either use a ps file viewer or convert it to pdf and view it.
The splitting criterion used by the decision tree is one of the following
gini index
entropy
By default, rpart uses the gini index to calculate the cut-offs. Let's focus on the gini index cost function.
Let's look at the first cut:
Sepal.Length <= 5.45
Let’s do some calculations by hand. It will give us a better understanding of what is going on under the hood.
The formula to calculate the Gini index is
Gini = 1 - \sum_i p_i^2
where p_i is the probability of occurrence of the i-th class. In our case, we have just 2 classes:
setosa
versicolor
The above visual demonstrates how the calculations have been done.
Initial gini index.
Gini indices after the first cut has been made.
So, the combined gini index after the split is
0.208 + 0.183 = 0.391
which is less than the original gini index of 0.5 that we started with.
Now, the question arises, why did the first cut happen at sepal length <= 5.45 ? Why not at 6.0 ? To understand this, let’s actually make a cut at sepal length <= 6.0 and re-calculate the gini indices.
The gini index at the new cut-off sepal length <= 6.0 is 0.468. It is not much different from where we initially started (0.5). By now, you should be able to understand the reasons behind the classifier’s decision points.
Challenge
Try to calculate the gini index by hand (like above) for the split Sepal.Width <= 2.75.
Here is a visual of how the decision tree algorithm eventually solved the problem.
Now that we understand how decision trees work, let’s try and predict some data. Let’s first split our data into train/test datasets.
# iris_new is assumed to be the setosa/versicolor subset used above, e.g. iris_new = iris[1:100, ]
library(rpart)
data = iris_new
index = sample(1:nrow(data), nrow(data)*.8)
train = data[index,]
test = data[-index,]
# fit the tree on the training split and predict the test split
model = rpart(Species ~ ., method = "class", data = train)
y_pred = predict(model, newdata = test, type = "class")
library(caret)
cm = confusionMatrix(y_pred, test[,5])
cm
Confusion Matrix and Statistics
Reference
Prediction setosa versicolor virginica
setosa 7 3 0
versicolor 0 10 0
virginica 0 0 0
Overall Statistics
Accuracy : 0.85
95% CI : (0.6211, 0.9679)
No Information Rate : 0.65
P-Value [Acc > NIR] : 0.04438
Kappa : 0.7
Mcnemar's Test P-Value : NA
Statistics by Class:
Class: setosa Class: versicolor Class: virginica
Sensitivity 1.0000 0.7692 NA
Specificity 0.7692 1.0000 1
Pos Pred Value 0.7000 1.0000 NA
Neg Pred Value 1.0000 0.7000 NA
Prevalence 0.3500 0.6500 0
Detection Rate 0.3500 0.5000 0
Detection Prevalence 0.5000 0.5000 0
Balanced Accuracy 0.8846 0.8846 NA
That's an accuracy score of 85%. Pretty decent.
Decision Trees for Regression
Although decision trees are mostly used for classification problems, you can use them for regression as well.
Let's try to fit the Boston Housing dataset with decision trees. Just to make things simple, let's use just the lstat predictor to predict the target (medv).
library(mlbench)
data(BostonHousing)
boston = BostonHousing
index = sample(1:nrow(boston),nrow(boston)*.8)
train = boston[index,]
test = boston[-index,]
model = train(medv ~ lstat,
data = train,
method = "rpart",
trControl = trainControl("cv", number = 15),
tuneLength = 15)
Warning message in nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo, :
"There were missing values in resampled performance measures."
y_pred = predict(model, newdata = test)
# Calculate the RMSE of the predictions on the test set
RMSE(y_pred, test$medv)
5.38848765509317
Let's plot the predictions on top of the data to visually see how well the prediction does in comparison to the actual data.
As you can see from the plot, a decision tree's regression prediction is step-wise (as opposed to smooth). This is because a decision tree predicts by averaging the target values of the training observations that fall into the same leaf. You can experiment by changing the parameters in the model, like trainControl and tuneLength. Also, instead of criteria like the Gini index or entropy, decision trees use RMSE to calculate the splits in the case of regression.
Overfitting
One of the main drawbacks of decision trees is overfitting. We can very well observe that the model fits the training data almost 100%, but when it comes to test data there is a much larger variation. Such a large variation is something that you do NOT observe in other models – say, linear regression. Let's try and fit the iris data again (this time with all 3 species).
# Repeat the split/fit/score experiment 100 times to see how the train and
# test accuracies vary (reconstructed – the original loop is not shown)
train_score = test_score = numeric(100)
for (i in 1:100) {
  index = sample(1:nrow(iris), nrow(iris)*.8)
  train = iris[index,]
  test = iris[-index,]
  model = rpart(Species ~ ., method = "class", data = train)
  train_score[i] = mean(predict(model, train, type = "class") == train$Species)
  test_score[i] = mean(predict(model, test, type = "class") == test$Species)
}
plot(1:100, train_score, pch = 19, col = "blue", ylim = c(0.8, 1),
     xlab = "Number of times the model has predicted",
     ylab = "Score")
points(1:100, test_score, pch = 19, col = "red")
legend(x = "topleft", legend = c("Train", "Test"), col = c("blue", "red"), pch = 19, cex = 0.8)
As you can see, the training accuracy stays essentially fixed, but the test results are all over the place (although the range is quite limited in this case). What is essentially happening is that the model is trying to learn the noise as well. One solution is to limit the tree size. This is called pruning. rpart exposes a parameter called the complexity parameter, or cp. It is used as a cut-off at each node to decide whether another split adds enough value to the tree built so far (with all its splits) to be worth making. So, a simple cut-off on cp can be used to decide at what tree size the tree should be pruned.
Luckily, rpart can give us a graph of cp at different tree sizes – the plotcp() function plots this for us.
Tree Depth
Look at the depth of the decision tree for iris data.
In this decision tree, after a tree depth of 3 there is no real value addition. It's basically nitpicking at the small numbers – and that's exactly what leads to overfitting. What if we could restrict the tree to just 3 levels of depth?
model <- rpart(Species ~ ., method="class", data=iris)
plotcp(model)
As you can see from the plot above, beyond a tree size of three (at a cp value of 0.066), the tree starts to overfit the data. The relative error (on the y-axis) is computed using cross-validation. Let's prune the tree at a cp of 0.066.
Pruning – Solution to Overfitting
Overfitting is basically a problem where the model tries to fit every point in the training data. Since there are many borderline cases, it is not practical for any ML model to fit all the data points. To avoid this, we prune the tree (cut off some of its branches) so that it generalizes better, rather than fitting the training data 100%. There are 2 ways to prune a tree.
Pre-Pruning – Prune the decision tree while it is being created.
Post-Pruning – Prune the decision tree after the entire tree has been created.
model_new = prune(model, cp = 0.066)
Let’s plot the decision tree of both the old and new (pruned) model to see how they perform.
If you run the same plot and see how different the training and the test data looks, you will get an understanding of how we were able to prevent overfitting.
From: njlotterries1234@gmail.com
Subject: You won Lottery
Body: Congratulations !!! You won a lottery of 5 Million dollars. Click here to claim..
What do you think of this? Is this a spam e-mail or not? In all probability it is spam. How do you know? You look at the telltale words – words like "lottery", "viagra", "free", "money back". When you see these words, you generally tend to classify that message as spam. This is exactly how Naive Bayes works. Let's formalize our understanding a bit by going deeper.
Bayes Theorem & Conditional Probability
Before we get into “Naive” Bayes, we have to first understand Bayes theorem. To understand Bayes theorem, we have to first understand something called Conditional Probability. What exactly is it ?
Say there is a standard deck of cards and you draw a card at random.
What is the probability that it is a red card ?
What is the probability that it is a face card, given that it is a red card ?
This is called conditional probability. Bayes theorem is an alternate way to compute the same thing.
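In its general form, Bayes' theorem states:

P(A \mid B) = \frac{P(B \mid A) \cdot P(A)}{P(B)}

In our card example, A is "the card is a face card" and B is "the card is red".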
Now, let’s calculate each one of these probabilities.
Probability of a face card, P(face) = 12/52 = 3/13
Probability of a red card, P(red) = 26/52 = 1/2
Probability of a red card given it is a face card, P(red | face) = 6/12 = 1/2
And finally, the probability of a face card given it is a red card, P(face | red) = P(red | face) x P(face) / P(red) = (1/2 x 3/13) / (1/2) = 3/13 ≈ 0.23
What did we achieve here? It looks like we have made things more complicated, right? I agree with you. In fact, this formula by itself is not all that useful in machine learning. But there is an assumption that makes it extraordinarily useful in ML. Let's go back to the email example.
Again, not very useful on its own. Calculating the probability of the exact phrase "You won lottery" is very arbitrary – you cannot calculate the probability of occurrence of every different phrase or combination of words. The next time around, the subject line might say "Congratulations!! You won lottery", which is slightly different from "You won lottery". The point is, you cannot possibly calculate all the different combinations of words that could result from the English dictionary.
Naive Bayes
This is where Bayes' theorem becomes "naive". Let's revisit the formula again.
The probability of the word "You" occurring in the email is assumed to be independent of the word "won" occurring, e.g.,
Do you have the paper with you ?
we have won the contract
These sentences are completely independent. When we break the event down into its respective independent events, the probability can be simplified as follows.
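Under the independence assumption, the probability of the whole phrase factors into the probabilities of the individual words:

P(\text{"You won lottery"} \mid \text{spam}) = P(\text{You} \mid \text{spam}) \cdot P(\text{won} \mid \text{spam}) \cdot P(\text{lottery} \mid \text{spam})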
This is actually a "naive" assumption – because in reality there is some level of overlap. When you mention the word "lottery", you almost always use the word "win" or some variant like "won" or "winning". However, this is where ML gets lucky: even with the naive assumption, the results are pretty good for real-life text classification. Applying the simplification to Bayes' theorem, the spam probability becomes proportional to P(spam) · P(You | spam) · P(won | spam) · P(lottery | spam).
With a bit of naivety, this formula became so much more useful. In fact, it is so useful that Naive Bayes is used almost exclusively for most text classification tasks. Let's explore this example with some rough data – just believable, made-up data.
Probability of “You won lottery” being spam.
Probability of "You won lottery" being NOT spam.
So, the probability of this phrase not being spam is 1.51.
Pretty effective, right? Especially given the simplification. Calculating the probability of the individual words is easy. The heart of this algorithm is that, given any sentence, it can break it down into its components (words), and based on the "spamminess" of each word, the entire sentence can be classified as spam or not.
All we are trying to do in Naive Bayes, is to break down a complicated problem into its components. Once the component is classified, essentially the bigger piece is classified as well.
It is like solving a jigsaw puzzle. How do you solve one typically ? You look for smaller puzzles to solve. Say this is a picture of a car – you start to look for smaller components of the car, like a tire, a windshield and solve for each of these separately. Once you got the pieces figured out, all you have to do is to put them in order. Naive Bayes works more or less like this.
Classify fruits based on Characteristics
Now that we understand the basics of Naive Bayes, let’s create a simple dataset and solve it in excel. The purpose behind this exercise is to get familiar with Naive Bayes calculation using a smaller dataset. This is going to solidify our understanding a bit further, before we dive into more complicated examples.
Solve the fruits dataset in excel
The probability of each of the characteristics – round, large, small etc, can be calculated as below.
Now, let’s move on to the individual conditional probabilities. For example, what is the probability that a fruit is round, given that it is an apple ? In all the cases of Apple, the fruit is always round.
However, what is the probability that a fruit is red, given that it is an apple? One out of three apples is red.
Like that, we keep calculating the conditional probabilities of all the individual characteristics. Think of this like calculating the probability of each individual word being spam or not.
Time to test our data. Let’s say, we want to calculate the probability of a fruit being an Apple, if it is round and large. All we have to do is plug the numbers.
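With the naive assumption, "plugging the numbers" means evaluating an expression of this shape (the individual probabilities come from the tables computed in the spreadsheet):

P(\text{apple} \mid \text{round, large}) = \frac{P(\text{round} \mid \text{apple}) \cdot P(\text{large} \mid \text{apple}) \cdot P(\text{apple})}{P(\text{round}) \cdot P(\text{large})}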
What is the probability that a fruit is an apple, if it is round, large and smooth ?
Based on our little dataset, we are not doing too badly. Let's do the opposite now. What is the probability of a fruit being a grape, given that it is round, large and smooth?
Makes sense, right? A grape is never "large"; hence the probability of a fruit being a grape if it is "large" is relatively small – 16%.
fruit round large small red green black golden yellow smooth rough
<fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct>
apple yes yes no yes no no no no yes no
apple yes yes no no yes no no no yes no
apple yes yes no no no no yes no yes no
grape yes no yes yes no no no no yes no
grape yes no yes no yes no no no yes no
grape yes no yes no no yes no no yes no
melon yes yes no no yes no no no yes no
melon yes yes no no no no yes no no yes
melon yes yes no no no no no yes no yes
library(e1071)
model = naiveBayes(fruit ~ ., data = fruits)
pred = predict ( model , fruits[,2:11 ])
pred
apple
apple
apple
grape
grape
grape
apple
melon
melon
table(pred, fruits[,1])
pred apple grape melon
apple 3 0 1
grape 0 3 0
melon 0 0 2
That’s not bad, given such a small set of characteristics. Let’s actually get the confusion matrix to get the accuracy percentage.
library(caret)
cm = confusionMatrix(pred,as.factor(fruits[,1]))
cm
Confusion Matrix and Statistics
Reference
Prediction apple grape melon
apple 3 0 1
grape 0 3 0
melon 0 0 2
Overall Statistics
Accuracy : 0.8889
95% CI : (0.5175, 0.9972)
No Information Rate : 0.3333
P-Value [Acc > NIR] : 0.0009653
Kappa : 0.8333
Mcnemar's Test P-Value : NA
Statistics by Class:
Class: apple Class: grape Class: melon
Sensitivity 1.0000 1.0000 0.6667
Specificity 0.8333 1.0000 1.0000
Pos Pred Value 0.7500 1.0000 1.0000
Neg Pred Value 1.0000 1.0000 0.8571
Prevalence 0.3333 0.3333 0.3333
Detection Rate 0.3333 0.3333 0.2222
Detection Prevalence 0.4444 0.3333 0.2222
Balanced Accuracy 0.9167 1.0000 0.8333
That's an accuracy of almost 90%. We are not very far off, given that our dataset is pretty small. The one place where we went wrong is in classifying a melon as an apple. If we compare the predictions vs the actuals, we can see that we went wrong with the 7th entry (a melon being misclassified as an apple).
predict = pred
actual = fruits[,1]
data.frame(predict,actual)
predict actual
<fct> <fct>
apple apple
apple apple
apple apple
grape grape
grape grape
grape grape
apple melon
melon melon
melon melon
Let’s check out the actual entry.
As you can see, the entry for the melon (watermelon) coincides in its data points with the green apple. How could this happen? It is because of an oversimplification with regard to size. We only have 2 sizes – small and large. However, both the apple and the watermelon are large (and round and smooth), and that's why the NB algorithm got it wrong. If we had an extra size category (say XL), that would have solved this problem.
Classify messages as Spam
Now that we understand the basics of Naive Bayes, along with a simple example in Excel and R, we can proceed to solve the problem we started with – classifying a message as spam or not.
Step 1 – Get the dataset
There is a simple SMS (text message) dataset available at Kaggle or at the UCI ML datasets repository. You can also download the file from Ajay Tech's github page. Download the zip file and open it in Excel as a tab-delimited file, giving it column names if they are not already present. Each of these messages has been classified as either spam or ham (ham is just a technical word for "non-spam").
Step 2 – Read the dataset into R
data = read.csv("./data/spam.csv", encoding='ISO-8859-1')
head(data)
class message X X.1 X.2
<fct> <fct> <fct> <fct> <fct>
1 ham Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...
2 ham Ok lar... Joking wif u oni...
3 spam Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's
4 ham U dun say so early hor... U c already then say...
5 ham Nah I don't think he goes to usf, he lives around here though
6 spam FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, £1.50 to rcv
data = data[,c(1,2)]
head(data)
class message
<fct> <fct>
1 ham Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...
2 ham Ok lar... Joking wif u oni...
3 spam Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's
4 ham U dun say so early hor... U c already then say...
5 ham Nah I don't think he goes to usf, he lives around here though
6 spam FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, £1.50 to rcv
Step 3 – Simple EDA
How many messages are there in the dataset ?
nrow(data)
5572
summary(data$class)
ham
4825
spam
747
Out of these, the summary above counts the occurrences of spam vs ham (non-spam).
Just like we converted the fruits dataset's feature values from "yes" or "no" to 1 or 0, Naive Bayes (like most ML algorithms) needs the message text converted into features. To do that, we use some techniques from natural language processing.
Tokenize the message (into words) and create a sparse matrix
This process basically splits each sentence (message) into its individual words. Let's see a sample before we tokenize the entire dataset.
Now, let’s do the same on our real messages dataset.
Before we use the DTM as-is, we have to convert the 0s and 1s to factors – like a Yes and No. This is because Naive Bayes works well with factors. Let's write a small function that converts all values greater than 0 to "Yes" and everything else to "No".
library(e1071)
model = naiveBayes(msg_train_dtm, msg_class_train)
Step 7 – Evaluate the model.
pred = predict(model, msg_test_dtm)
table(msg_class_test, pred)
pred
msg_class_test ham spam
ham 950 13
spam 18 134
Measure the accuracy using the confusion matrix from the caret library.
library(caret)
cm = confusionMatrix(pred,msg_class_test)
cm
Confusion Matrix and Statistics
Reference
Prediction ham spam
ham 2385 46
spam 12 343
Accuracy : 0.9792
95% CI : (0.9732, 0.9842)
No Information Rate : 0.8604
P-Value [Acc > NIR] : < 2.2e-16
Kappa : 0.9101
Mcnemar's Test P-Value : 1.47e-05
Sensitivity : 0.9950
Specificity : 0.8817
Pos Pred Value : 0.9811
Neg Pred Value : 0.9662
Prevalence : 0.8604
Detection Rate : 0.8561
Detection Prevalence : 0.8726
Balanced Accuracy : 0.9384
'Positive' Class : ham
There is scope for a ton of optimization here like
convert all characters to lower case
remove punctuation
remove stop words etc
But that is a subject for another day. Here we will just focus on learning the Naive Bayes algorithm.
Challenge
Let's solve another problem with Naive Bayes. Load up the house-votes-84 dataset (available as HouseVotes84 in the mlbench package, which is what the solution below uses). The dataset should look like this.
These are the voting records of Congressmen in the US, voting Yes (for) or No (against) on 16 different issues. Instead of names, the class column identifies each congressman as either a Republican or a Democrat.
Task – Identify the congressmen as either a Democrat or Republican based on his voting pattern.
Solution – This problem is almost exactly like the fruits dataset we started with at the beginning of learning Naive Bayes.
# 1. Import the dataset
library(mlbench)
data(HouseVotes84, package = "mlbench")
data = HouseVotes84
head(data)
# 2. train/test split
index = sample(1:nrow(data),nrow(data)*.8)
train = data[index,]
test = data[-index,]
# 3. model the data
model = naiveBayes(Class ~ ., data = train)
# 4. predict the data
pred = predict(model, test)
# 5. Accuracy
table(pred, test$Class)
library(caret)
cm = confusionMatrix(pred,test$Class)
print (cm)
Class V1 V2 V3 V4 V5 V6 V7 V8 V9 V10 V11 V12 V13 V14 V15 V16
<fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct>
1 republican n y n y y y n n n y NA y y y n y
2 republican n y n y y y n n n n n y y y n NA
3 democrat NA y y NA y y n n n n y n y y n n
4 democrat n y y n NA y n n n n y n y n n y
5 democrat y y y n y y n n n n y NA y y y y
6 democrat n y y n y y n n n n n n y y y y
pred democrat republican
democrat 121 5
republican 21 71
Confusion Matrix and Statistics
Reference
Prediction democrat republican
democrat 121 5
republican 21 71
Accuracy : 0.8807
95% CI : (0.8301, 0.9206)
No Information Rate : 0.6514
P-Value [Acc > NIR] : 1.002e-14
Kappa : 0.7496
Mcnemar's Test P-Value : 0.003264
Sensitivity : 0.8521
Specificity : 0.9342
Pos Pred Value : 0.9603
Neg Pred Value : 0.7717
Prevalence : 0.6514
Detection Rate : 0.5550
Detection Prevalence : 0.5780
Balanced Accuracy : 0.8932
'Positive' Class : democrat
Challenge – IMDB review Sentiment Analysis
Similar to the SPAM/HAM problem, we can also predict if an IMDB review is positive or negative based on the words in it.
# step 1 - Read the data file
library("xlsx")
data = read.xlsx("./data/imdb-reviews-sentiment.xlsx", sheetIndex = 1, header=TRUE)
# step 2 - Create a DTM based on the text data
library(tm)
message_corpus = Corpus(VectorSource(data$review))
message_dtm <- DocumentTermMatrix(message_corpus)
# step 3 - function to convert the integers to "Yes" or "No" factors in the DTM
counts_to_factor = function(x){
x = ifelse(x > 0, 1, 0)
x = factor(x, levels = c(0,1), labels = c("No", "Yes"))
return (x)
}
# step 4 - Split into train and test sets, create DTMs and convert the counts to "Yes"/"No" factors
index = sample(1:nrow(data), nrow(data)*.8)
msg_cor_train = Corpus(VectorSource(data$review[index]))
msg_train_dtm = DocumentTermMatrix(msg_cor_train)
msg_train_dtm = apply(msg_train_dtm, MARGIN = 2, counts_to_factor)
msg_class_train = data$sentiment[index]
msg_cor_test = Corpus(VectorSource(data$review[-index]))
msg_test_dtm = DocumentTermMatrix(msg_cor_test)
msg_test_dtm = apply(msg_test_dtm, MARGIN = 2, counts_to_factor)
msg_class_test = data$sentiment[-index]
# step 5 - model the data using Naive Bayes
library(e1071)
model = naiveBayes(msg_train_dtm, msg_class_train)
# step 6 - predict the results from the model using the test data
pred = predict(model, msg_test_dtm)
# step 7 - get the accuracy from the confusion matrix
library(caret)
cm = confusionMatrix(pred, msg_class_test)
print (cm)
Confusion Matrix and Statistics
Reference
Prediction negative positive
negative 0 0
positive 0 2000
Accuracy : 1
95% CI : (0.9982, 1)
No Information Rate : 1
P-Value [Acc > NIR] : 1
Kappa : NaN
Mcnemar's Test P-Value : NA
Sensitivity : NA
Specificity : 1
Pos Pred Value : NA
Neg Pred Value : NA
Prevalence : 0
Detection Rate : 0
Detection Prevalence : 0
Balanced Accuracy : NA
'Positive' Class : negative
Naive Bayes on continuous variables
So far, we have seen Naive Bayes work on factor variables. Does NB ever work on continuous variables? Yes, it does – of course, with a discretized view of those variables (think of binning a normal distribution). The key assumption is that the variable has a normal distribution. For example, think of the iris dataset – is the sepal length of the setosa species normally distributed? Let's find out.
# Matplotlib does not directly plot kernel density estimates,
# so we are using Seaborn's distplot instead
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris
%matplotlib inline
# Loading iris via scikit-learn (the original loading step is not shown);
# iris_data holds the four numeric feature columns
iris_data = load_iris().data
# You can check from these curves that the sepal data is roughly normally
# distributed, but the petal data is not. Try them one by one.
sns.distplot(iris_data[:,0], hist=True, kde=True)
sns.distplot(iris_data[:,1], hist=True, kde=True)
sns.distplot(iris_data[:,2], hist=True, kde=True)
sns.distplot(iris_data[:,3], hist=True, kde=True)
Only the Sepal data is normally distributed. Ideally, we should just be using the sepal data ( Sepal Length and Sepal Width ). However, let’s just use all of these and see what happens. As an exercise, try using just the sepal data and check for the accuracy.
# 1. train/test split
from sklearn.model_selection import train_test_split
import pandas as pd
# iris_target holds the species labels (loaded the same way as iris_data above)
iris_target = load_iris().target
X_train, X_test, y_train, y_test = train_test_split(iris_data, iris_target, test_size=0.2)
# 2. Naive Bayes modeling
from sklearn.naive_bayes import MultinomialNB
model = MultinomialNB().fit(X_train, y_train)
# 3. Predict data
y_predict = model.predict(X_test)
# 4. Create a confusion matrix to check accuracy
print ( pd.crosstab(y_test, y_predict,rownames=['Actual'], colnames=['Predicted'], margins=True) )
# 5. Print the accuracy score
from sklearn.metrics import confusion_matrix, accuracy_score
print ( confusion_matrix(y_test, y_predict) )
print ( accuracy_score(y_test,y_predict))
That’s pretty accurate as well , isn’t it ? In fact even though one of the assumptions (all the variables should be independent of each other ) is wrong, Naive Bayes still outperforms some other classification algorithms.
The priors (the probability of a "setosa" occurring, a "versicolor" occurring, and so on) are 0.33 each (a third) – which we know.
How about the conditional probabilities? This is where it gets tricky for continuous variables. You cannot have a conditional probability for every individual value (the number of possible values is infinite). So, in the case of a normal distribution, an approximation is applied based on the following formula.
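For a feature x that is assumed normally distributed within class c (with class mean \mu_c and standard deviation \sigma_c), the conditional probability is approximated by the Gaussian density – this is what a Gaussian Naive Bayes classifier uses:

P(x \mid c) = \frac{1}{\sqrt{2\pi\sigma_c^2}} \exp\left(-\frac{(x - \mu_c)^2}{2\sigma_c^2}\right)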
Logistic regression is a type of linear regression. However, it is used for classification only. Huh.. that’s confusing, right ? Let’s dive in.
Let’s take the simple iris dataset. The target variable as you know by now ( from day 9 – Introduction to Classification in Python, where we discussed classification using K Nearest neighbours ) is categorical in nature. Let’s load the data first.
Let’s simplify this further – say, we wanted to predict the species based on a single parameter – Sepal Length. Let’s first plot it.
# iris_data is assumed to be the setosa/versicolor subset, e.g. iris_data = iris[1:100, ]
plot(iris_data[,1], iris_data[,5],
pch = 19, col = "blue",
xlab = "Sepal Length",
ylab = "Setosa or Versicolor")
We know that regression is used to predict a continuous variable. What about a categorical variable like this (species)? If we can draw a curve like this,
and for all target values predicted with value > 0.5 put it in one category, and for all target values less than 0.5, put it in the other category – like this.
A linear regression (multilinear in this case) equation looks like this.
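In general, with predictors x_1 through x_n, it takes the form:

y = b_0 + b_1 x_1 + b_2 x_2 + \dots + b_n x_n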
Logistic regression is very similar to linear regression. The difference lies in how the predictor is calculated. Let's see it in the next section.
Math
The name logistic regression is derived from the logit function. This function is based on odds.
logit function
Let's take an example. A standard die roll has 6 outcomes, so the probability of landing a 4 is p = 1/6.
Now, what about odds? The odds of an event are defined as p / (1 - p).
So, when we substitute p = 1/6 into the odds equation, it becomes
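\text{odds}(4) = \frac{p}{1 - p} = \frac{1/6}{5/6} = \frac{1}{5}

that is, one favourable outcome against five unfavourable ones.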
OK. Now that we understand Probability and Odds, let’s get to the log of odds.
How exactly is the logistic regression similar to linear regression ? Like so.
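That is, the log of odds is modeled as a linear combination of the predictors:

\log\left(\frac{p}{1 - p}\right) = b_0 + b_1 x_1 + b_2 x_2 + \dots + b_n x_n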
Where the predictor ( log of odds ) varies between ( -∞ to +∞ ).
To understand this better, let’s plot the log of odds between a probabilty value of 0 and 1.
x = runif(100, min = 0.1, max = 0.9)
x = sort(x)
y = log(x / (1 - x))
plot(x, y, pch = 19, col = "blue", xlab = "probability", ylab = "log of odds")
This is the logit curve used in logistic regression. It maps a probability value (0 to 1) to a number (-∞ to +∞). However, we are not looking for a continuous predictor, right? We are looking for a categorical outcome – and, in our case, we said we would decide the category based on probability.
p >= 0.5 – Category 1
p < 0.5 – Category 2
In order to calculate those probabilities, we would have to calculate the inverse function of the logit function.
sigmoid function
The inverse of the logit curve is the inverse-logit or sigmoid function ( or expit function). The sigmoid function transforms the numbers ( -∞ to +∞ ) back to values between 0 and 1. Here is the formula for the sigmoid function.
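S(x) = \frac{1}{1 + e^{-x}}

If x is the log of odds, then S(x) recovers the probability p.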
Essentially, if we flip the logit curve by 90° (take its inverse), we get the sigmoid function.
Here is the trick – as long as we are able to find a curve like the one below, where the predicted target is a value between 0 and 1 (a probability), we can say that all values below 0.5 (the halfway mark) belong to one category and the remaining values (above 0.5) belong to the other category. This is the essence of logistic regression.
Implementation
Let’s try to implement the logistic regression function in R step by step.
Data & Modeling
Just to keep the same example going, let’s try to fit the sepal length data to try and predict the species as either Setosa or Versicolor.
model = glm(Species ~ Sepal.Length + Sepal.Width,
data=iris_data,
family=binomial(link="logit"))
Warning message:
"glm.fit: algorithm did not converge"
Warning message:
"glm.fit: fitted probabilities numerically 0 or 1 occurred"
The predictions are probability values. We need to convert them to the actual factors (setosa and versicolor), because we are dealing with just 2 classes. We can use a simple ifelse() to convert all values > 0.5 to versicolor and the rest to setosa.
# convert the fitted probabilities to class labels, as described above
y_pred = predict(model, type = "response")
y_pred_levels = as.factor(ifelse(y_pred > 0.5, "versicolor", "setosa"))
library(caret)
cm = confusionMatrix(y_pred_levels, iris_data[,5])
cm
Warning message in levels(reference) != levels(data):
"longer object length is not a multiple of shorter object length"
Warning message in confusionMatrix.default(y_pred_levels, iris_data[, 5]):
"Levels are not in the same order for reference and data. Refactoring data to match."
Confusion Matrix and Statistics
Reference
Prediction setosa versicolor virginica
setosa 50 0 0
versicolor 0 50 0
virginica 0 0 0
Overall Statistics
Accuracy : 1
95% CI : (0.9638, 1)
No Information Rate : 0.5
P-Value [Acc > NIR] : < 2.2e-16
Kappa : 1
Mcnemar's Test P-Value : NA
Statistics by Class:
Class: setosa Class: versicolor Class: virginica
Sensitivity 1.0 1.0 NA
Specificity 1.0 1.0 1
Pos Pred Value 1.0 1.0 NA
Neg Pred Value 1.0 1.0 NA
Prevalence 0.5 0.5 0
Detection Rate 0.5 0.5 0
Detection Prevalence 0.5 0.5 0
Balanced Accuracy 1.0 1.0 NA
Basic Evaluation
Let's split the data into training and test sets and model it. This time, let's do the full iris dataset. Since there are 3 species to be predicted, we cannot use glm with a "binomial" family; let's use the multinom() function from the nnet library instead.
library(nnet)
index = sample(1:nrow(iris),nrow(iris)*.8)
train = iris[index,]
test = iris[-index,]
model = multinom(Species~.,data = train)
# weights: 18 (10 variable)
initial value 131.833475
iter 10 value 11.516467
iter 20 value 4.881298
iter 30 value 4.469920
iter 40 value 4.263054
iter 50 value 3.911756
iter 60 value 3.823284
iter 70 value 3.598069
iter 80 value 3.591202
iter 90 value 3.570975
iter 100 value 3.570835
final value 3.570835
stopped after 100 iterations
pred = predict(model,test)
As usual, to evaluate categorical target data, we use a confusion matrix.
library(caret)
cm = confusionMatrix(pred, as.factor(test$Species))
cm
Confusion Matrix and Statistics
Reference
Prediction setosa versicolor virginica
setosa 16 0 0
versicolor 0 8 1
virginica 0 2 3
Overall Statistics
Accuracy : 0.9
95% CI : (0.7347, 0.9789)
No Information Rate : 0.5333
P-Value [Acc > NIR] : 1.989e-05
Kappa : 0.8315
Mcnemar's Test P-Value : NA
Statistics by Class:
Class: setosa Class: versicolor Class: virginica
Sensitivity 1.0000 0.8000 0.7500
Specificity 1.0000 0.9500 0.9231
Pos Pred Value 1.0000 0.8889 0.6000
Neg Pred Value 1.0000 0.9048 0.9600
Prevalence 0.5333 0.3333 0.1333
Detection Rate 0.5333 0.2667 0.1000
Detection Prevalence 0.5333 0.3000 0.1667
Balanced Accuracy 1.0000 0.8750 0.8365
That's a 90% accuracy score.
Optimization
Let’s plot the logistic regression curve for the test data set.
As you can see, there are still quite a few misclassifications. All the false negatives and false positives in the plot above are examples of misclassification. Irrespective of the algorithm used to calculate the fit, there is only so much that can be done to increase the classification accuracy given the data as-is. The true positive and true negative rates have their own names
Sensitivity ( True Positive )
Specificity ( True Negative )
There is a specific optimization that can be done – and that is to specifically increase accuracy of one segment of the confusion matrix at the expense of the other segments. For example, if you look at a visual of the confusion matrix for our dataset.
For this dataset, classifying the species as “setosa” is positive and “versi-color” as negative.
setosa – positive
versi-color – negative
Let's actually calculate the accuracy values. Say the confusion matrix looks like this: [[11, 1], [1, 12]].
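Reading 11 and 12 as the correct predictions on the diagonal and the two 1s as misclassifications, the accuracy is:

\text{accuracy} = \frac{11 + 12}{11 + 1 + 1 + 12} = \frac{23}{25} = 0.92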
What if we want to predict 100% of the setosa cases (a much more accurate classification than 0.92)? Of course, as we discussed earlier, it will come at a cost. However, there is a use case for this scenario: if getting a particular classification right is extremely important, then we focus more on that classification than on the others. Have you seen the Brad Pitt movie "World War Z"? A plague emerges all around the world and an asylum is set up in Israel behind a high wall. However, before you enter the wall, they make absolutely sure that you do not have the plague. If having the plague is what we call positive, then they are essentially maximizing the green box in the picture above.
Another example: if you were to diagnose cancer patients, you would rather increase the odds of predicting cancer when the patient really has it (a true positive), even if it comes at the cost of wrongly classifying a non-cancer patient as positive (a false positive). The former can save a life, while the latter will just cost the company a patient.
Evaluation
ROC Curve
Receiver Operating Characteristic – also called the ROC curve – is a measure of how good the classification is. The pROC package in R makes it easy to create a ROC curve and calculate the area under it. First off, let's start with a classifier like logistic regression.
Step 1 – Get the data
iris_data = iris[51:150,]
iris_data = iris_data[order(iris_data$Sepal.Length),]
model = glm( Species ~ Sepal.Length, data = iris_data , family = binomial)
library(pROC)
# iris$Species has 3 classes and hence 3 factors. So, we are converting them to
# 0/1 factor using factor (c(iris_data$Species) - 2).
# -2 so that 3,2 become 1,0
roc = roc ( factor (c(iris_data$Species) - 2), model$fitted.values, plot = TRUE, legacy.axes = TRUE)
Setting levels: control = 0, case = 1
Setting direction: controls < cases
The area under the curve (AUC) is an indicator of how accurate our classifier is. You can get it with the auc() function from pROC (for example, auc(roc)), or simply by printing the roc object.
Summary: Regression is a basic machine learning technique. We will be learning simple linear regression, multiple linear regression, and polynomial regression in this section. On top of the basic technique itself, we will also cover many statistical concepts (like the null hypothesis, p-value, r-squared, RMSE) and other key concepts in machine learning, like feature selection and training/test data splits, that will be useful in evaluating models going forward as well.
In Machine Learning, most problems are classified as supervised vs
unsupervised learning. We will first focus on supervised learning
algorithms and later work on unsupervised learning algorithms.
Supervised learning is once again split into the following 2 groups
Classification
Regression
Given a particular height and weight, classify the person as either
male or female. This is an example of classification. You are
essentially trying to classify the person – in this case – as male or female based on certain characteristics.
In contrast, say you are trying to predict the body fat percentage
based on height and weight – this is an example of a regression problem.
What is the difference ? Body Fat % is a continuous variable – say it
starts at a minimum of 2% (really lean) and can go all the way up to 50 %
say (extremely obese) – as opposed to a categorical variable in the
example above ( Male or Female ).
Why Regression
If you are learning how to solve a regression problem for the first
time, you probably need to understand why you need regression to start
with. This is probably the simplest of the regression problems. Let’s
start with a simple data set – Swedish auto insurance claims. You can
google it or get it from kaggle. It is a very small data set with just 2 variables –
Number of Claims
Total payment for these claims ( in thousands )
Claims come first and the settlement happens much later. Assuming
these are claims this company receives per week, is there a way we can
predict how much the company will end up paying, just based on the
number of claims ?
What value does this bring to the company ?
Being able to predict the payment based on the number of claims gives a very good understanding of the company's expenses well in advance.
Why do you need machine learning for this problem ?
Each claim is different – a fender-bender claim costs a thousand dollars, while a total loss could cost 20 grand. The type of claim would make a good predictor, but let's just assume we don't have that at this point. Even if we had the type of claim, a fender bender can cost anywhere from 300 to 2000 dollars based on the severity of the damage, the arbitration, and several environmental factors. Essentially, there is no easy way to correlate the claims to the payment. If we tried to do this using some kind of IF/THEN logic, we would be going around in circles.
Solve Regression in R
data = read.csv("./data/insurance.csv", skip= 5, header=TRUE)
head(data)
Looking OK – but since we are reading the data from a file, we have to ensure that R is not reading the numbers as strings or other types. Let's quickly verify that the data types are correct.
class(data$claims)
'integer'
class(data$payment)
'numeric'
Looking good. Now, on to linear regression. We don't have to install any specific packages to do linear regression in R: base R provides a function called lm().
model = lm( payment ~ claims, data = data)
Our model is ready. Let's start predicting payments based on the count of claims. We will be using the predict() method. But before we do that, let's plot this out to understand what we have done so far.
# pch stands for plotting characteristics.
# A value of 19 means that it is a solid dot (as opposed to a default hollow dot)
plot( data$claim,data$payment,pch = 19, col = "blue")
pred = predict(model, newdata = data)
Linear regression has already solved this problem for us – we just didn't realize it yet. The parameters (also called coefficients) of the fitted line are stored in the model object.
plot( data$claim,data$payment,pch = 19, col = "blue")
lines(data$claim, fitted(model), col="red")
In case you are wondering how this line was drawn: you can just as well use the slope and intercept parameters to draw the same line.
How did we get the straight line ?
A straight line can be defined mathematically using
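y = ax + b

where a is the slope and b is the intercept.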
These are also called coefficients. The lm() function has already arrived at these numbers (slope and intercept). It has done so based on the data.
# extract the intercept and slope from the fitted model
intercept = coef(model)[1]
slope = coef(model)[2]
plot(data$claims, data$payment, pch = 19, col = "blue")
# abline defines the equation as "y = bx + a" as opposed to our definition "y = ax + b";
# that is why we set a = intercept and b = slope in the call.
abline(a = intercept, b = slope, col = "red")
What did we achieve ?
What we have essentially done is predicted a relationship between the
number of claims and the total amount paid. For example, what is the
total amount expected to be paid when the number of claims is 80 ?
Easy, right ?
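Using the fitted coefficients quoted a little further down (y = 3.4x + 19.99, with payments in thousands), the prediction for 80 claims would be roughly:

y = 3.4 \times 80 + 19.99 \approx 292

that is, about 292 thousand.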
Prediction
We don’t have to draw lines like this every time to predict the value of Y for a value of X. You can use the predict ( ) function. Let’s first predict all the original claim values.
pred = predict(model, newdata = data.frame(claims=data$claims))
head(pred)
You don't have to predict only the original values. Normally, we do a train/test split (which we will see later) and measure the accuracy on the test dataset. However, we can throw any predictor (claims, in this case) values at the model to see what payments it predicts. For example, say we pick 5 different claim values (10, 20, 30, 40, 60) and want to find out what our model predicts.
pred = predict(model, newdata = data.frame(claims=c(10,20,30,40,60)))
pred
Let’s also plot these to compare how well we predicted.
# original_claim_values is assumed to hold the actual payments observed
# for claim counts of 10, 20, 30, 40 and 60
plot(c(10,20,30,40,60), pred, pch = 19, col = "blue")
points(c(10,20,30,40,60), original_claim_values, pch = 19, col = "red")
# abline defines the equation as "y = bx + a" as opposed to our definition "y = ax + b";
# that is why we set a = intercept and b = slope in the call.
abline(a = intercept, b = slope, col = "black")
Some values are pretty close, some are a bit off – but nevertheless it's a good prediction for the amount of time we spent doing it.
Simple Linear Regression
What we have seen so far is an example of implementing Simple Linear Regression
in R. What we will see in this section is the math behind it – Give it a
try – if it gets tedious, please do a couple of re-reads. It is
critical that you understand this section.
How did LinearRegression fit the straight line
The fun starts now. How did the lm() function fit the straight line? How did it arrive at this equation?
y=3.4x+19.99
Obviously, there is no one straight line that will pass through all the points in this case.
If you take these 4 data points, we can eyeball a straight line that
goes through the middle of them. The ones marked with question marks are
visually not a good fit at all. The question that the linear model tries
to solve is,
What is the “optimum” straight line that best describes the relationship between X and Y ?
This is where statistics comes in.
Let’s zoom in
Let’s simplify and make up some numbers ( for easy calculation) of
claims vs payments. Say we have a set of 5 data points for claims vs
payments and we wish to fit a linear model that can predict further
data. This is a very small data set to do anything practical, but there
is a reason why we are doing such a small data set as we will see in the
coming sections.
If we were asked to eyeball a straight line that best fits these data points, this is how we would do it.
How did we do it ? Our eye is essentially trying to minimize the distances
from each of these points to the straight line. The best fitting
straight line is one which minimizes the distances for all these points.
Linear regression in machine learning does exactly that – Instead of a
human eye, machine learning takes the help of statistics to do this
approximation. There are a couple of methods to do this in statistics.
Ordinary Least Squares
Gradient Descent
Let’s explore the first method here.
Residuals
When we tried to minimize the distance of each of these points from
the line we are trying to fit, the distances between the points and the
straight line ( on the y axis ) are called residuals.
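In symbols, the residual for each data point is ei = yi − ŷi – the actual value minus the value on the fitted line. ( In R, once a model is fit, these are available directly via resid(model) or summary(model)$residuals. )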
Sum of Squares
Warning – Geeky Statistics stuff
To arrive at the best fit values for the straight line, statisticians
have arrived at the following formula based on the method of least
squares. How they arrived at it is pretty geeky and let’s leave that to
the statisticians.
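For a straight line Y = a + bX, the least squares estimates work out to
b = Σ (xi − x̄)(yi − ȳ) / Σ (xi − x̄)^2
a = ȳ − b * x̄
where x̄ and ȳ are the averages of the x and y values, and Σ stands for summation.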
This equation sounds scary, but it is not. I am going to prove it to
you in a minute down below. There are 2 things in this equation that
require an explanation.
The Σ symbol ( the Greek capital sigma ) – it is used for summation.
y with a bar on it ( or x with a bar ) – the bar just represents the average. So ȳ is the average of all the y values and x̄ is the average of all the x values.
Let’s take the same numbers that we have above and try to calculate
the formula by hand. Excel would make things easy, but let’s just do it
manually, since the numbers are not all that bad.
That was huge – now we can easily calculate a and b for the Y = a + bX
equation. The convention for slope and intercept here is different
from what we referred to previously. You can stick to one convention and
use it. However, you might see multiple variations of this ( like Y = mx
+ b, for example ).
These are the slope and intercept values. We can now use these to plot the fit line.
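As a sketch ( assuming the same five data points used in the hand calculation, which give a slope of 0.6 and an intercept of 34 ), the pieces used by the plot below can be set up like this:
# the five-point sample dataset ( assumed to be the same one used in the hand calculation )
sample = data.frame(x = c(20, 40, 60, 80, 100),
                    y = c(40, 60, 80, 80, 90))
# slope and intercept from the least squares formulas above
slope = 0.6
intercept = 34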
plot(sample$x,sample$y,pch=19,col="blue")
abline(a=intercept, b=slope, col = "red")
The model object has a whole host of other information that you can
use to judge how good of a fit the line is to the data. But first,
let’s put the predicted values alongside the actual values in a table.
The differences ( residuals ) are highlighted below.
There are so many other parameters ( like the p-value, r-squared, r-squared adjusted ) and graphs ( Residuals vs fitted , Q-Q Plot
etc ) that are used to analyze the performance of a model fit. We are
going to get to that in the next section. To set the stage for these
parameters, let’s scale up our data set.
Multilinear Regression
So far we have seen one predictor and one response variable. This is also called simple
linear regression. Of course, the real world is not that simple, right ?
There can be multiple predictors. We are going to look at one such
case with a classic example – the Boston Housing dataset.
It has 13 predictors and 1 response variable. So, the equation for that mathematically would be
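y = b0 + b1·x1 + b2·x2 + … + bn·xn
where b0 is the intercept and b1 … bn are the coefficients of the n predictors –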
the value of n would be 13 in this case. Let’s look at the dataset below.
Boston Housing dataset
Predicting the price of a house depends on many parameters like the
size, neighborhood, crime rate, pollution etc. Essentially, there is no
exact mathematical equation that can compute the price of a house from
these parameters – that’s where ML comes into the picture.
In the earlier example, we just had to predict the value of Y given a value of X –
there was just 1 predictor ( X ). The Boston Housing data set, however,
has tons of variables. Load the data first.
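A minimal sketch of loading it from a csv ( the file name and path here are an assumption – point it at wherever your copy of the dataset lives ):
# load the Boston Housing dataset ( path / file name assumed )
boston_housing = read.csv("./data/boston_housing.csv")
head(boston_housing)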
What is the problem we are trying to solve ? We are actually trying
to predict the Median house price based on a variety of parameters –
like Crime rate, Nitric Oxide content in the air, age of the house etc.
Each of these parameters has a “say” in the price of the house. The
target variable is called the Response Variable and the parameters that have a “say” in determining it are called the Predictors.
In our case, there are 13 predictors ( CRIM, ZN, INDUS…LSTAT ) and 1
response variable ( MEDV ). Just for simplicity’s sake, let’s pick
one predictor – rm – the number of rooms in the house. Let’s see how well we can predict the price of the house based on just this one parameter.
Let’s plot our predictor vs response variable as a scatter plot and draw the straight line that our model has predicted.
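Here is a quick sketch of the fitting step that produces the slope and intercept used in the plot below:
# fit medv ~ rm and pull the coefficients out of the model
model = lm(medv ~ rm, data = boston_housing)
intercept = coef(model)[1]
slope = coef(model)[2]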
plot(boston_housing$rm, boston_housing$medv, pch=19, col = "blue", xlim=c(4,10)) # try this without setting the range of x values (using xlim)
abline(a=intercept, b=slope, col = "red")
Looks like a decent enough fit. Let’s do another – lstat – lower status population percentage.
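Step 1 – Fit the model ( a sketch, following the same pattern as before )
model = lm(medv ~ lstat, data = boston_housing)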
Step 2 – Determine the slope and intercept from the model
slope = -1.010507
intercept = 35.31254
Step 3 – Plot the data and fit line
plot(boston_housing$lstat, boston_housing$medv, pch=19, col = "blue")
abline(a=intercept, b=slope, col = "red")
This seems accurate enough. However, if you look at both these
graphs, the relationship between the predictor and the response variable
looks a bit different for LSTAT vs RM. Let’s put these
graphs side by side to understand this better.
The first picture is lstat vs medv and the second is rm vs medv.
Also, not all variables might be relevant ( irrespective of the
direction of the relationship, decrease or increase ). Let’s take the parameter dis –
the weighted distance to employment centres. Once again, we can try to fit this using the same steps as before.
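A sketch of that fit:
# fit medv ~ dis and plot it, same steps as before
model = lm(medv ~ dis, data = boston_housing)
plot(boston_housing$dis, boston_housing$medv, pch = 19, col = "blue")
abline(a = coef(model)[1], b = coef(model)[2], col = "red")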
Visually, there is not much of a linear relationship between distance
and median value ( although we tried to fit it ). How exactly do we
measure the relationship ?
Correlation
This is the simplest measure of the relationship between 2 numeric variables.
Luckily, R provides a built-in function for calculating correlation –
cor ( ). For example,
# use = "pairwise.complete.obs" ensures that NAs are handled
cor(boston_housing$medv, boston_housing$dis, use = "pairwise.complete.obs")
0.138798442415068
Correlation values vary between -1 and +1. A value of 0 means no
correlation and +1 means highly correlated ( -1 also signifies a high
correlation, just that it is a negative correlation, meaning if the
predictor value increases, the response value decreases ). In the
example above, the relationship between distance and median value is
just about 14 %. How about the other predictors ?
cor(boston_housing$medv, boston_housing$lstat, use = "pairwise.complete.obs")
-0.706255058922179
cor(boston_housing$rm, boston_housing$medv, use = "pairwise.complete.obs")
0.740180804891272
This seems in-line with our plots above right ?
Correlation is NOT causation
However strong the correlation is, it does NOT imply causation. This
is a bit tricky to understand. For example, look at the picture below.
There seems to be a significant correlation between the number of rooms ( rm ) and the median price value ( medv ).
Now, imagine another column – Power Consumption. As the number of rooms increases, the power consumption also tends to increase.
Does it mean that power consumption drives the house price ? Very unlikely – isn’t it ? Mathematically, correlation
is just a tool that quantifies the association between 2
variables. However, it is up to us ( or the domain expert ) to determine if
the relationship is a real, causal one or fictitious.
Quiz
:The higher the correlation, the stronger the relationship between the variables.
:True
:False
: Correlation is a measure of how strong the relationship is between two variables.
:True
:False
: -0.95 represents a strong positive correlation
:True
:False
: A correlation value of 0 shows that there is a perfect correlation between the variables
:True
:False
: The picture below shows a strong negative correlation
:True
:False
x = c(10,20,30,40,50,60,70,80,90,100)
y = c(100,90,80,70,60,50,40,30,20,10)
plot(x,y, pch=19, col="blue")
p value
While the correlation quantifies the relationship between two variables, it doesn’t tell us if the relationship is statistically significant. The term statistically significant deserves an explanation – if a relationship ( between two variables ) is NOT caused by chance, then it is statistically significant. That is exactly what the p-value measures. Also called the probability value, the p-value answers the following question –
Is the predictor relevant in predicting the response, and if so, how relevant is it ?
We have to understand p-value in the context of Null Hypothesis.
Null Hypothesis
Null hypothesis ( denoted as H0 ) assumes
that there is no relationship between the predictor and the response
variable. For example, if you look at the relationship between the
number of rooms ( rm ) and the median price value of the home ( medv )
in the Boston Housing dataset, the Null hypothesis says that there is NO
relationship between them.
Well, although we can almost see a linear relationship visually
between those 2 variables, we still start off with the Null Hypothesis. It is
indicated in statistics as H0
The Alternate hypothesis indicates that they are related. It is indicated as H1
. The p-value indicates how inconsistent the observed data is with the
Null Hypothesis. This is a bit cryptic, right ? Let’s explore this
further.
Let’s just consider 2 random variables that are normally distributed.
Since they are random variables, there would be no relationship between
them, right ? Let’s check it out.
Dataset 1
A normal distribution of 100 values with a mean of 100 and sd of 20
x = rnorm(n = 100, mean = 100, sd = 20)
Dataset 2
Another normal distribution of 100 values with a mean of 100 and sd of 20
y = rnorm(n = 100, mean = 100, sd = 20)
Let’s plot x vs y and see how it looks.
plot(x,y)
This almost looks like the night sky, doesn’t it ? Point being, there is no relationship between x and y
as they are random variables – that is pretty understandable, right ?
Let’s try to calculate a relationship between these two variables (
although there is none ) and see how it compares against another
situation where there IS actually a relationship.
What is the p-value in this case ? Fit a linear model of y against x and look at its summary – the last column against the predictor ( x in this case ) is the p-value.
model = lm(y ~ x)
summary(model)
Call:
lm(formula = y ~ x)
Residuals:
Min 1Q Median 3Q Max
-44.483 -14.434 -0.342 11.768 50.386
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 98.19433 10.12278 9.700 5.41e-16 ***
x 0.01513 0.09677 0.156 0.876
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Residual standard error: 19.69 on 98 degrees of freedom
Multiple R-squared: 0.0002494, Adjusted R-squared: -0.009952
F-statistic: 0.02445 on 1 and 98 DF, p-value: 0.8761
And the p-value in this case is 0.876 – that’s about 88 %. The p-value is always between 0 and 1.
0.876 is a big number, right ? Does it indicate that there is a strong
relationship between x and y in this case ? On the contrary – a high
p-value indicates that there is NO evidence of a relationship between x and y,
which is what the Null Hypothesis states.
On the other hand, if we calculate the p-value of the relationship
between the number of rooms (“RM”) and median price (“MEDV”), then you
get a very small value.
model = lm (boston_housing$medv ~ boston_housing$rm)
summary(model)
Call:
lm(formula = boston_housing$medv ~ boston_housing$rm)
Residuals:
Min 1Q Median 3Q Max
-25.674 -2.488 -0.168 2.411 39.680
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) -38.2758 2.6708 -14.33 <2e-16 ***
boston_housing$rm 9.7779 0.4187 23.35 <2e-16 ***
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Residual standard error: 5.93 on 450 degrees of freedom
(54 observations deleted due to missingness)
Multiple R-squared: 0.5479, Adjusted R-squared: 0.5469
F-statistic: 545.3 on 1 and 450 DF, p-value: < 2.2e-16
The p-value in this case is less than 2.2 x 10^-16 – an extremely low value.
What it means is that there is a practically zero chance that the observed relationship is just random;
or in other words, there is a more than 99.999 % chance that the data does
represent a real relationship between the two variables. Whereas in the case
of the random variables, there was an 88 % chance that the observed correlation was just
random.
Optimum p-value
A threshold value below which we can safely conclude that the relationship
between the 2 variables did not happen by chance, but because there is a
true relationship, is what we will call the optimum p-value.
Is there a fixed p-value below which, we can safely conclude a strong relationship (Alternate Hypothesis) ?
Typically p <= 0.05 is accepted in most cases. However, the actual
value at which the business is comfortable is based on a variety of
parameters like
type of data
level of confidence the business requires etc
Key Parameters
r-squared – r2
R-squared is a measure of the explanatory power of the model – how
well the predictors are able to explain the response variable using the
model. A model’s r2 varies between 0 and 1. Applying the summary function on the lm ( ) model gives us the r2 value.
model = lm (boston_housing$medv ~ boston_housing$rm)
summary(model)$r.squared
0.547867623929491
In the example above, we tried to model the response variable medv (
median house value ) as a function of the number of rooms ( rm ) – the r2 value is 0.547. It is a measure of how well the model explains the relationship. A low value of r2 ( close to 0 ) means that the explanatory variables are not able to predict the response variable well.
A high value of r2 ( close to 1 ) means that the
explanatory variables are fully able to predict the response. In this
case, the number of rooms ( rm ) is able to explain around 55 % of the variance in the
median house value. The remaining 45 % of the variance is
unexplained by the explanatory variable.
How is r2 calculated
The formula to calculate r2 is
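r2 = 1 − ( Sum of Squares of the residuals ) / ( Total Sum of Squares )
where the Sum of Squares of the residuals is Σ (yi − ŷi)^2 and the Total Sum of Squares is Σ (yi − ȳ)^2.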
The denominator, the Total Sum of Squares, is the worst possible score ( the error you would make by always predicting the plain average ). The numerator, the Sum of Squares of the residuals, is how far off the model actually is. So r2 is essentially a normalized score of how well the model predicts the data, in comparison to the worst possible score.
What do we mean by worst possible score ? Well, if
we know the weights of each of the individuals in a class, without any
indicator of whose weight is whose ( just a list of weights ), what is
our best possible prediction of the weight of any individual in the class ? We
just take the arithmetic average – right ? And that’s what we use to
calculate the Total Sum of Squares.
Let’s calculate r2 by hand for our simple dataset.
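Using the least squares line for this data ( slope 0.6, intercept 34 ), the fitted values are 46, 58, 70, 82 and 94, and the average of y is 70, so
Total Sum of Squares = (40 - 70)^2 + (60 - 70)^2 + (80 - 70)^2 + (80 - 70)^2 + (90 - 70)^2 = 1600
Sum of Squares of the residuals = (40 - 46)^2 + (60 - 58)^2 + (80 - 70)^2 + (80 - 82)^2 + (90 - 94)^2 = 160
r2 = 1 − 160 / 1600 = 0.9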
Let’s verify this in R.
x = c(20,40,60,80,100)
y = c(40,60,80,80,90 )
model = lm (y ~ x)
summary(model)$r.squared
0.9
There you go – our manual calculation verifies with R’s calculation.
Exercise – Calculate the r2 value of the Boston Housing dataset for the predictor NOX ( the level of Nitric Oxide in the air ).
r-squared adjusted
Mathematically, r2 has a peculiar property. Adding more predictors increases the value of r2 . This is not very intuitive to begin with. Let’s try it on our Boston Housing dataset.
In the example above, we have tried to model medv from rm. So, the
only explanatory variable is rm ( number of rooms ). Based on that
relationship we have an r2 value of 0.547. What would happen
if we add another explanatory variable ? – say lstat ( percentage of
lower status of the population ).
# Fit model based on
# 1. RM - Number of rooms
# 2. LSTAT - Percentage of people with lower status
model = lm (boston_housing$medv ~ boston_housing$rm + boston_housing$lstat)
summary(model)$r.squared
0.652305668932009
See, the r2 has increased from 0.548 to 0.652 by including a second predictor. How about a third predictor ?
# Fit model based on
# 1. RM - Number of rooms
# 2. LSTAT - Percentage of people with lower status
# 3. NOX - Nitric oxide in the air
model = lm (boston_housing$medv ~ boston_housing$rm + boston_housing$lstat + boston_housing$nox)
summary(model)$r.squared
0.652671556550756
There is a slight increase in r2 – from 0.65231 to 0.65267.
r2 value – Predictors
0.547868 – number of rooms
0.652306 – number of rooms + lower status population
0.652672 – number of rooms + lower status population + Nitric Oxide in the air
You can keep adding as many predictors as you want and you will observe that the r2
value always increases with every predictor. That doesn’t seem right,
does it ? Let’s try something. To the Boston Housing dataset, let’s add
a column of just random numbers and see if r2 still increases. If it does, then we have a problem.
# Generate a column of 506 random numbers
x = rnorm(n = 506, mean = 100, sd = 20)
# and add it to the boston dataset
boston_housing_new = cbind(boston_housing, x)
# what is the new shape ? It should be 15 columns ( as opposed to the original 14 )
dim ( boston_housing_new)
506
15
Now, let’s try the regression with the predictors RM, LSTAT, NOX and the new random variable.
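A sketch of that fit ( using the boston_housing_new frame built above ):
# fit with the random column included as a predictor
model = lm(medv ~ rm + lstat + nox + x, data = boston_housing_new)
summary(model)$r.squared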
r2 value – Predictors
0.547868 – number of rooms
0.652306 – number of rooms + lower status population
0.6526716 – number of rooms + lower status population + Nitric Oxide in the air
0.6526719 – number of rooms + lower status population + Nitric Oxide in the air + some random variable
This is crazy, right ? Just add any random variable ( which is not supposed to have any predictive power ) and r2 still increases ? You are probably beginning to doubt the usefulness of r2 in the first place. Well, it’s not all bad with r2 – it is just that any added variable, even a random one, lets the model fit the data a tiny bit better. In order to counter this, there is a modified metric called r2 adjusted.
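r2 adjusted = 1 − (1 − r2) * (n − 1) / (n − p − 1)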
where
n = sample size
p = number of predictors
Essentially, the new parameter r2 adjusted works by penalizing additional parameters – that is why p, the number of predictors, appears in the denominator.
What this goes to show is that adding more predictors does not necessarily increase the explanatory power of the model. r2 adjusted accounts for this by incorporating a penalty for the number of predictors ( the more the predictors, the lower the r2 adjusted ).
When you add a predictor, it should add significant value. If it does not
( for example, adding NOX – Nitric Oxide – as a predictor ), r2 adjusted tells you that it is a worthless predictor and you had better get rid of it.
summary(model)$adj.r.squared
0.649563812528321
summary(model)$r.squared
0.652671894013658
Exercise
As an exercise, please model a multi-linear regression of Boston Housing dataset using
Scenario 1 – number of rooms + lower status population + Nitric Oxide in the air
Scenario 2 – number of rooms + lower status population + Nitric Oxide in the air + a random variable
Calculate r2 and r2 adjusted for both the scenarios. Did r2 adjusted decrease or increase in scenario 2 ?
RMSE – Root Mean Square error
While r2 is a relative measure of fit, RMSE
is an absolute measure of fit. Lower values indicate a better fit,
higher values a worse one. Calculating RMSE is quite simple – it is quite
similar to the standard deviation ( the square root of variance ). While
the Standard Deviation ( sd ) measures how far a distribution is spread
around its mean, RMSE measures how far the fitted values are from the actuals.
So, to calculate the RMSE, you just have to borrow
the residuals from the model and do a couple of mathematical operations on them
( quite similar to what you do to the differences from the mean when
computing the standard deviation ).
sqrt( mean ( summary(model)$residuals ^2 ) )
5.18556387861621
Now, you might have a question here. When you already have RMSE to measure the fit, why do you need another metric – r2
? Well, you can’t measure the growth rate of an elephant and a horse in
absolute terms. RMSE is in the magnitude ( and units ) of the actual response
variable, while r2 is on a uniform scale ( 0 to 1 ).
Feature Selection
In the Boston Housing dataset example, there are 13 predictors and 1
response variable. And the response variable is MEDV – Median house
value.
Previously, we have seen some examples of predicting the house value based on a random set of predictors –
RM – Number of rooms
LSTAT – Percentage of people with lower status
NOX – Nitric oxide content
etc.
However, we know that not all these variables have an equal say. So,
how do we know which of the predictors have the most impact in
predicting the house price ?
The process of identifying the features ( predictors ) that have the most predictive power is called Feature Selection.
One systematic way of doing this is called Stepwise Regression. Let’s explore it in the next section.
Stepwise Regression
In stepwise regression, we select a criterion to evaluate the fit.
Based on how well the model fits, we keep adding or removing
predictors ( features ) until we reach a point where we can no longer
improve the model based on the selected criterion.
What kind of parameters can be used to evaluate the model ? There are many choices like
p-value
r-squared
r-squared (adjusted )
F-tests or T-tests
etc
For now, since we are already aware of the p-value, let’s choose it as our selection criterion.
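Fit a model with all of the predictors first ( the formula here matches the summary output that follows ):
model = lm(medv ~ ., data = boston_housing)
summary(model)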
Call:
lm(formula = medv ~ ., data = boston_housing)
Residuals:
Min 1Q Median 3Q Max
-19.3991 -2.5377 -0.5848 1.7101 27.9916
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 20.866365 5.477190 3.810 0.000159 ***
crim -0.238025 0.223853 -1.063 0.288229
zn 0.038221 0.013372 2.858 0.004462 **
indus 0.051356 0.059643 0.861 0.389683
chas 2.435048 0.829794 2.935 0.003516 **
nox -11.657986 3.928009 -2.968 0.003163 **
rm 5.110158 0.454838 11.235 < 2e-16 ***
age -0.006094 0.013101 -0.465 0.642031
dis -1.271514 0.195766 -6.495 2.26e-10 ***
rad 0.294444 0.085387 3.448 0.000619 ***
tax -0.011360 0.003612 -3.145 0.001776 **
ptratio -0.831030 0.127000 -6.544 1.68e-10 ***
b 0.012314 0.003513 3.505 0.000503 ***
lstat -0.520753 0.057428 -9.068 < 2e-16 ***
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Residual standard error: 4.553 on 438 degrees of freedom
(54 observations deleted due to missingness)
Multiple R-squared: 0.7405, Adjusted R-squared: 0.7328
F-statistic: 96.16 on 13 and 438 DF, p-value: < 2.2e-16
Since we have chosen the p-value as our key assessment criterion, we are interested in this part of the table.
Irrespective of the selection criteria, there are 2 basic methods in stepwise regression.
Backward Elimination
Forward Selection
Backward Elimination
As we see from the table above, not all predictors have the same
p-value. So, in backward elimination, we start by eliminating the
predictor with the worst parameter value (p-value in this case) and
re-evaluate the model again.
# Eliminate INDUS predictor.
X = boston_housing[,c(1,2,4,5,6,7,8,9,10,11,12,13,14)]
model = lm(medv ~ . , data = X)
summary(model)
Call:
lm(formula = medv ~ ., data = X)
Residuals:
Min 1Q Median 3Q Max
-19.3513 -2.5184 -0.5814 1.6283 28.0395
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 20.669661 5.470812 3.778 0.000180 ***
crim -0.233426 0.223724 -1.043 0.297353
zn 0.037131 0.013308 2.790 0.005498 **
chas 2.508402 0.825166 3.040 0.002508 **
nox -10.830319 3.807459 -2.845 0.004656 **
rm 5.075414 0.452912 11.206 < 2e-16 ***
age -0.006238 0.013096 -0.476 0.634071
dis -1.307708 0.191144 -6.841 2.64e-11 ***
rad 0.278620 0.083362 3.342 0.000902 ***
tax -0.010031 0.003265 -3.072 0.002258 **
ptratio -0.817289 0.125956 -6.489 2.34e-10 ***
b 0.012233 0.003511 3.485 0.000542 ***
lstat -0.515724 0.057113 -9.030 < 2e-16 ***
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Residual standard error: 4.552 on 439 degrees of freedom
(54 observations deleted due to missingness)
Multiple R-squared: 0.7401, Adjusted R-squared: 0.733
F-statistic: 104.2 on 12 and 439 DF, p-value: < 2.2e-16
Eliminate crime ( crim ) next.
Age is the other remaining predictor with a relatively high p-value – we will deal with it after this.
# Eliminate CRIM predictor.
X = boston_housing[,c(2,4,5,6,7,8,9,10,11,12,13,14)]
model = lm(medv ~ . , data = X)
summary(model)
Call:
lm(formula = medv ~ ., data = X)
Residuals:
Min 1Q Median 3Q Max
-18.7560 -2.4916 -0.5937 1.6188 28.1229
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 21.327935 5.434858 3.924 0.000101 ***
zn 0.035475 0.013214 2.685 0.007535 **
chas 2.544654 0.824517 3.086 0.002155 **
nox -11.589297 3.737700 -3.101 0.002055 **
rm 5.043166 0.451901 11.160 < 2e-16 ***
age -0.006310 0.013097 -0.482 0.630189
dis -1.297433 0.190909 -6.796 3.51e-11 ***
rad 0.220041 0.061626 3.571 0.000395 ***
tax -0.010079 0.003265 -3.087 0.002152 **
ptratio -0.815266 0.125954 -6.473 2.57e-10 ***
b 0.012720 0.003480 3.655 0.000288 ***
lstat -0.527443 0.056004 -9.418 < 2e-16 ***
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Residual standard error: 4.552 on 440 degrees of freedom
(54 observations deleted due to missingness)
Multiple R-squared: 0.7394, Adjusted R-squared: 0.7329
F-statistic: 113.5 on 11 and 440 DF, p-value: < 2.2e-16
Next, we eliminate age.
# Eliminate AGE predictor.
X = boston_housing[,c(2,4,5,6,8,9,10,11,12,13,14)]
model = lm(medv ~ . , data = X)
summary(model)
Call:
lm(formula = medv ~ ., data = X)
Residuals:
Min 1Q Median 3Q Max
-18.6405 -2.5670 -0.5505 1.6077 27.8556
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 21.679282 5.381020 4.029 6.60e-05 ***
zn 0.036260 0.013102 2.768 0.005885 **
chas 2.525351 0.822826 3.069 0.002279 **
nox -12.097475 3.582668 -3.377 0.000799 ***
rm 4.988713 0.437159 11.412 < 2e-16 ***
dis -1.271884 0.183237 -6.941 1.40e-11 ***
rad 0.222320 0.061391 3.621 0.000327 ***
tax -0.010122 0.003261 -3.104 0.002034 **
ptratio -0.820426 0.125389 -6.543 1.68e-10 ***
b 0.012598 0.003468 3.633 0.000313 ***
lstat -0.537831 0.051641 -10.415 < 2e-16 ***
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Residual standard error: 4.548 on 441 degrees of freedom
(54 observations deleted due to missingness)
Multiple R-squared: 0.7393, Adjusted R-squared: 0.7334
F-statistic: 125.1 on 10 and 441 DF, p-value: < 2.2e-16
At this point, almost all of the remaining predictors have a
low enough ( < 0.05 ) p-value. If you want to go further, you will probably
end up with a set of predictors like this.
Stepwise regression with backward elimination ( based on p-value ) is just one of the methods. There are many other methods that we will be learning as we progress through the course.
Accuracy of the model
We have seen enough statistics for now. But we haven’t actually
predicted anything yet, right ? Using the 5 parameters that we have
narrowed down to, let’s start predicting median house prices and see how
accurate our model is.
In order to predict, we first need to train the model ( which we
have done already ). However, instead of training the model on the
entire dataset, we split the data into training and test datasets. This
process is called a Train/Test split. And as you might have guessed
already, doing this in R is pretty easy.
Training and Test datasets
How many rows of data do we have in Boston Housing dataset ?
dim(boston_housing)
506
14
506 rows. That’s all the data we have. Now, let’s split this data
into training and test datasets. What should the ratio of the split be ?
Typically an 80-20 % split should do. You can also do a 75-25 % split.
The exact split percentage is usually based on the amount of data
available and a bit of trial and error.
Essentially, this is what we are trying to do.
Use R’s sample ( ) function to sample the indices of 80% of the dataset size.
# Split the data to train and test
# 1. Get the number of rows
rows = nrow(boston_housing)
# 2. sample out 80% of the numbers between 1 and the number of rows
index = sample(1:rows, rows*0.8)
# 3. get the training data using the indices that correspond to 80% of the dataset
train = boston_housing[index,]
# 4. The remaining part of the data set becomes the test dataset.
test = boston_housing[-index,]
Now, that we know how to do the split, let’s do it on just the predictors that we need.
# Fit model based on
# 1. RM - Number of rooms
# 2. LSTAT - Percentage of people with lower status
# 3. DIS - weighted distances to five Boston employment centres
# 4. PTRATIO - pupil-teacher ratio by town
# 5. B - 1000(Bk - 0.63)^2 where Bk is the proportion of blacks by town
model = lm ( medv ~ rm + lstat + dis + ptratio + b, data = train)
Now that we have trained the model, let’s test it.
pred = predict ( model, newdata = test )
Now, let’s find out how accurate our predictions are. A quick and
dirty way to visually see this would be to plot the predicted vs actuals
on a scatterplot.
plot ( pred, test$medv,
xlab = "Predictions",
ylab = "Actuals",
xlim=c(0,50)) # use the xlimit to avoid some wild predictions
If the prediction were 100 % correct, we would get a perfect 45° line. Let’s draw that as a baseline to compare the prediction’s performance.
plot ( pred, test$medv,
xlab = "Predictions",
ylab = "Actuals",
xlim=c(0,50),
ylim=c(0,50)) # use the xlimit to avoid some wild predictions
abline(0,1)
That’s not a bad fit. Let’s calculate the r2 to have a numeric estimate of how good of a fit we have.
summary(model)$r.squared
0.714425296126006
That’s an r2 of about 0.71. So, we can say that our model explains roughly 70 % of the variance in the house prices.
Polynomial Regression
So far, we have seen examples of Linear Regression (
both simple and multi ). However, not all kinds of data can be fit with
Linear Regression. For example, population growth is a good example of
non-linear data – probably a better word for it is exponential
growth. Let’s take a simple example – the population of India and how
it is projected to grow in this century. ( This is real data taken from the
United Nations at https://population.un.org/wpp/Download/Standard/Population/ )
india = read.csv("./data/india_population.csv")
head(india)
Let’s plot this 150 years of data ( since 1950 until 2015 and projected data until 2100)
plot(india$year, india$population,
xlab = "year",
ylab = "Population in Billions",
type = "n") # does not produce any points or lines
lines(india$year, india$population, type="o")
Can you imagine trying to fit this data using Linear Regression ?
Well, it is going to be a huge approximation after – say 2040 where the
linearity ends. Let’s try it to see how well it fits.
# build the model
model = lm ( india$population ~ india$year, data = india)
# scatter plot
plot(india$year,india$population)
#get slope and intercept
intercept = coef(model)[1]
slope = coef(model)[2]
# fit line.
abline(a=intercept, b=slope, col = "red")
Well, it does an OK job given that the model is linear, but can we account for the curve somehow ?
This is exactly where polynomials come in. In R, we can add polynomial terms directly to the lm ( ) formula – for example with I(year^2) or the poly ( ) function. Let’s use that to fit the data.
# build the model
model = lm ( india$population ~ india$year + I(india$year ^ 2), data = india)
# fit line.
pred = predict(model, newdata = india)
# scatter plot
plot(india$year,india$population, pch=19, col= "blue")
points (india$year, pred, pch=19, col= "red")
summary(model)
Call:
lm(formula = india$population ~ india$year + I(india$year^2),
data = india)
Residuals:
Min 1Q Median 3Q Max
-101455469 -56743742 -2325066 54151912 214380534
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) -4.149e+11 1.331e+10 -31.18 <2e-16 ***
india$year 4.017e+08 1.315e+07 30.55 <2e-16 ***
I(india$year^2) -9.685e+04 3.246e+03 -29.84 <2e-16 ***
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Residual standard error: 67780000 on 148 degrees of freedom
Multiple R-squared: 0.9777, Adjusted R-squared: 0.9774
F-statistic: 3242 on 2 and 148 DF, p-value: < 2.2e-16
What is really happening here ? A single column ( year ) has effectively become multiple columns – year and year², on top of the intercept.
We can put both these models on the same plot and you can see how well the prediction hugs the data.
Visually, you can see the green line (polynomial fit) seems to follow
the blue line(data) much better than the red line(linear fit), right ?
How about we go one step further and see if increasing the polynomial
degree would make this any better.
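Here is a sketch of the three fits being compared ( the model names match the r2 printout further below ):
# linear, 2nd degree and 3rd degree fits on the India population data
model_lin = lm ( india$population ~ india$year, data = india)
model_poly_2 = lm ( india$population ~ india$year + I(india$year ^ 2), data = india)
model_poly_3 = lm ( india$population ~ india$year + I(india$year ^ 2) + I(india$year ^ 3), data = india)
# overlay all three fits on the data
plot (india$year, india$population, pch = 19, col = "blue")
points (india$year, fitted(model_lin), pch = 19, col = "red")
points (india$year, fitted(model_poly_2), pch = 19, col = "green")
points (india$year, fitted(model_poly_3), pch = 19, col = "orange")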
The orange line (polynomial fit, degree = 3) is hugging the actual
data curve (blue) much closer than the green line ( polynomial fit ,
degree = 2 ), right ? Let’s do the numbers ( r2 score ) as well, to really know the degree of fit.
cat ( summary(model_lin)$r.squared, " --> R squared - Linear fit", "\n")
cat ( summary(model_poly_2)$r.squared, " --> R squared - polynomial fit ( degree = 2)", "\n")
cat ( summary(model_poly_3)$r.squared, " --> R squared - polynomial fit ( degree = 3)", "\n")
0.8434444 --> R squared - Linear fit
0.9776831 --> R squared - polynomial fit ( degree = 2)
0.9953978 --> R squared - polynomial fit ( degree = 3)
Challenge
Use polynomial regression to fit the following parameters on the boston Housing dataset.
LSTAT ( Lower Status Population percentage )
MEDV ( Median house value )
After modeling ,
Plot both linear and polynomial model ( degree = 2 ) to visually show how they perform
Compare the r2 values between linear and polynomial model
Verify if increasing the degree of the polynomial regression ( say degree = 3 or 4 ) increases the performance.
0.544 – r2 score with linear modeling
0.640 – r2 score with polynomial modeling ( degree = 2 )
0.657 – r2 score with polynomial modeling ( degree = 3 )
That’s just a marginal improvement over second degree polynomial regression. So, you don’t need to go beyond second degree.
How about a 10th degree fit ?
# build the models
model_lin = lm ( medv ~ lstat , data = boston_housing)
model_poly_2 = lm ( medv ~ lstat + I(lstat ^ 2), data = boston_housing)
model_poly_3 = lm ( medv ~ lstat + I(lstat ^ 2) + I(lstat ^ 3), data = boston_housing)
# notice how we are specifying the polynomial fit
model_poly_10 = lm ( medv ~ poly(lstat,10), data = boston_housing)
# fit line.
pred_lin = predict(model_lin, newdata = boston_housing)
pred_poly_2 = predict(model_poly_2, newdata = boston_housing)
pred_poly_3 = predict(model_poly_3, newdata = boston_housing)
pred_poly_10 = predict(model_poly_10, newdata = boston_housing)
# scatter plot
lstat = boston_housing$lstat
medv = boston_housing$medv
plot (lstat,medv, pch=19, col= "blue")
points (lstat, pred_lin, pch=19, col= "red")
points (lstat, pred_poly_2, pch = 19, col = "green")
points (lstat, pred_poly_3, pch = 19, col = "orange")
points (lstat, pred_poly_10, pch = 19, col = "cyan")
#put a legend as well.
legend("topleft", legend=c("data", "Linear Model", "Polynomial model - 2", "Polynomial model - 3"),
col=c("blue", "red", "green", "orange"), lty=1)
0.544 – r2 score with linear modeling
0.640 – r2 score with polynomial modeling ( degree = 2 )
0.657 – r2 score with polynomial modeling ( degree = 3 )
0.684 – r2 score with polynomial modeling ( degree = 10 )
See, even if you go to a 10th degree polynomial fitting, the improvement in r2 is just about 0.04 from a second degree polynomial fitting.
Overfitting
Sometimes, when there is not enough data, the model might tend to
overfit. Look at the example data below. We are simulating a sine wave.
x = seq( from = 10, to = 100, by=5)
y = sin(x*pi/180)
Let’s plot it to see how it looks.
plot(x,y, pch = 19, col = "blue")
Let’s introduce a bit of variance to make the data a bit realistic.
variance = runif(length(x), min=0.1, max=0.3)
y = y + variance
plot(x,y, pch = 19, col = "blue")
Now, let’s see how a 2nd degree polynomial regression fits.
# build the models
model_poly_2 = lm ( y ~ x + I(x ^ 2))
# fit line.
pred_poly_2 = predict(model_poly_2, newdata = data.frame(x,y))
# scatter plot
plot (x,y, pch=19, col= "blue")
lines (x, pred_poly_2, pch = 19, col = "green", lty=1, lwd = 3)
#put a legend as well.
legend("topleft", legend=c("data", "Polynomial model - 2"),
col=c("blue", "green"), lty=1)
This looks to be a good fit, right ? Now, let’s try a higher order polynomial fit ( say degree = 10 )
# build the models
model_poly_2 = lm ( y ~ x + I(x ^ 2))
model_poly_10 = lm ( y ~ poly(x,10))
# fit line.
pred_poly_2 = predict(model_poly_2, newdata = data.frame(x,y))
pred_poly_10 = predict(model_poly_10, newdata = data.frame(x,y))
# scatter plot
plot (x,y, pch=19, col= "blue")
lines (x, pred_poly_2, pch = 19, col = "green", lty=1, lwd = 3)
lines (x, pred_poly_10, pch = 19, col = "red", lty=1, lwd = 3)
#put a legend as well.
legend("topleft", legend=c("data", "Polynomial model - 2", "Polynomial model - 10"),
col=c("blue", "green", "red"), lty=1)
Overfitting typically happens when the model is trying to work too
hard for the data. And why is it a problem ? An overfit model follows
the training data too closely and hence will not work well on new datasets. Think
of overfitting as localizing the solution to the training dataset – it is
more or less memorizing the data, not generalizing
a solution for the whole population. Obviously, it will not work as well when the
model is used on a real data set. We will see more examples of these when
we see other machine learning algorithms down the line.
If you are wondering why the simple linear regression is able to
learn the model just enough, but the higher degree polynomial regression
is over-learning it, that is because the higher order polynomial
regression has the flexibility to learn more ( as compared to a linear
or second order polynomial regression ). This is actually good, except
that it is not able to discern noise from data. Let’s increase the
dataset size to see if the same higher degree ( 15th order ) polynomial regression performs
better than a second order one.
Increase the sample size ( to about 900 points )
x = seq( from = 10, to = 100, by=0.1)
y = sin(x*pi/180)
variance = runif(length(x), min=0.1, max=0.3)
y = y + variance
plot(x,y, pch = 19, col = "blue")
# build the models
model_poly_2 = lm ( y ~ x + I(x ^ 2))
model_poly_15 = lm ( y ~ poly(x,15))
# fit line.
pred_poly_2 = predict(model_poly_2, newdata = data.frame(x,y))
pred_poly_15 = predict(model_poly_15, newdata = data.frame(x,y))
# scatter plot
plot (x,y, pch=19, col= "blue")
lines (x, pred_poly_2, pch = 19, col = "green", lty=1, lwd = 3)
lines (x, pred_poly_15, pch = 19, col = "red", lty=1, lwd = 3)
#put a legend as well.
legend("topleft", legend=c("data", "Polynomial model - 2", "Polynomial model - 15"),
col=c("blue", "green", "red"), lty=1)
You see, the data size did not matter – the 15th order polynomial
regression still overfits the data. The reason is that for this amount of
data and noise, a second or 3rd degree polynomial has enough power to
capture the complexity of the data. For more complicated data
sets, an increase in degree might capture the complexity better.
Detect Overfitting
If you look at the pictures above, we were able to clearly see an
overfit. This is because it is a 2-dimensional dataset. Most real-life data is
multi-dimensional, in which case we will not be able to
identify an overfit just by looking at a plot. There are basically 2
methods to identify an overfit.
1. Choose a simpler model.
Look at a measure of the score ( r2 ) – if there is not a significant difference, it is better to go for the simpler model.
x = seq( from = 10, to = 100, by=0.1)
y = sin(x*pi/180)
variance = runif(length(x), min=0.1, max=0.3)
y = y + variance
# build the models
model_poly_2 = lm ( y ~ x + I(x ^ 2))
model_poly_15 = lm ( y ~ poly(x,15))
# fit line.
pred_poly_2 = predict(model_poly_2, newdata = data.frame(x,y))
pred_poly_15 = predict(model_poly_15, newdata = data.frame(x,y))
# print out r-squared
print ( summary(model_poly_2)$r.squared)
print ( summary(model_poly_15)$r.squared)
The r2 score of the 2nd degree fit is almost as good as that of the 15th degree polynomial model. So, the simpler model ( the 2nd degree model ) wins.
2. Check the model performance across Training and Test datasets.
Another method to identify an overfit is by validating how well the
model is performing against the training dataset vs the test dataset.
# Split the data to train and test
# 1. Get the number of rows
rows = nrow(boston_housing)
# 2. sample out 80% of the numbers between 1 and the number of rows
index = sample(1:rows, rows*0.8)
# 3. get the training data using the indices that correspond to 80% of the dataset
train = boston_housing[index,]
# 4. The remaining part of the data set becomes the test dataset.
test = boston_housing[-index,]
Fit a 2nd degree and 15th degree polynomial regression
# build the models
model_lin = lm ( medv ~ lstat , data = train)
model_poly_2 = lm ( medv ~ lstat + I(lstat ^ 2), data = train)
model_poly_15 = lm ( medv ~ poly(lstat,15), data = train)
Compare the accuracy scores across both the scenarios ( 2nd degree and
15th degree polynomial regression ) by predicting on the test dataset
with the models trained on the training dataset.
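A minimal sketch of that comparison ( it assumes the train/test split and the models fit above; the helper name r2_on is just for this sketch, and rows with missing lstat or medv are dropped before scoring ):
# r-squared on the test set = 1 - SS_residuals / SS_total
test_cc = test[complete.cases(test$lstat, test$medv), ]
r2_on = function(model, newdata) {
  pred = predict(model, newdata = newdata)
  1 - sum((newdata$medv - pred)^2) / sum((newdata$medv - mean(newdata$medv))^2)
}
r2_on(model_poly_2, test_cc)   # 2nd degree polynomial
r2_on(model_poly_15, test_cc)  # 15th degree polynomial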
As you can see, the r-squared score on the test set decreases ( slightly though )
for the 15th degree polynomial. That goes to show that the 15th degree
polynomial is overfitting.
Linear Regression Assumptions
Homoskedasticity
In plain English, what this means is that the variance between the
predicted and the actual values should be uniform across the spectrum of predictions.
Homoskedasticity = Uniform Variance
Let’s take the relationship between the number of rooms (“RM”) and
the median price value (“MEDV”) and try to fit a linear regression
model.
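A sketch of that fit ( the names model_rm and y_pred are just for this sketch; y_pred is used in the residual calculation below ):
# fit medv ~ rm and predict back on the same data
model_rm = lm(medv ~ rm, data = boston_housing)
y_pred = predict(model_rm, newdata = boston_housing)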
Now that we have predicted the housing prices, let’s see how they
compare to the actual house prices. The differences between the predicted
and actual values are called residuals.
residuals_rm = y_pred - boston_housing$medv
Let’s plot the residuals vs the predicted values.
plot ( y_pred, residuals_rm,
xlim = c(1,50),
xlab = "Predicted Median House Price",
ylab = "Residuals",
main = "Fitted vs Residuals")
What this plot goes to show is that the variance ( of the residuals ) is not
uniform across the spectrum of house prices. For example, the variance
is high between 10K and 35K – but that’s probably because the bulk of the
data lies there. You can strip away the outliers, restrict the dataset
to between those values and try again.
plot ( y_pred, residuals_rm,
xlim = c(1,50),
ylim = c(-10,10),
xlab = "Predicted Median House Price",
ylab = "Residuals",
main = "Fitted vs Residuals")
This actually looks ok. However, the ideal scenario is what is called the night sky – like a simulated picture below.
# Generate 100 random numbers between 0 and 100.
x = runif(100, min = 0, max = 100)
y = runif(100, min = 0, max = 100)
plot(x,y)
What homoskedasticity ensures is that the predictive capacity of the model is uniform across the entire range of the predictors.
This is the ideal plot that confirms homoskedasticity – also called the night sky.
If you look at the variance ( residuals ), it is uniformly distributed
across the entire spectrum of house prices. What this means is that we
can confidently predict the house price across the entire spectrum with
the same level of accuracy.
Now, compare it to the previous Fitted vs Residuals plot to understand it better.
Normal Distribution of Residuals
Another assumption is that the residuals should be normally distributed. You can quickly verify that using a histogram,
hist(residuals_rm,breaks=20)
This looks almost normal, but a quick quantitative way to check how
closely a distribution adheres to a normal distribution is a Q-Q plot. Q-Q plot is an acronym for Quantile-Quantile plot. The way it is calculated is by comparing the quantiles of the sample against the quantiles of a normal distribution.
# First sort the residuals ( sample )
residuals_rm_s = sort(residuals_rm)
# How big is the dataset ? ( this prints 452 )
length(residuals_rm_s)
# create a normal distribution as big as the residuals dataset
norm=rnorm(length(residuals_rm_s), mean = 10, sd= 3)
# and sort it
norm_s = sort(norm)
# Scatter plot of normal distribution vs residuals
plot(norm_s,residuals_rm_s,
main = "Quantile - Quantile plot ( Q-Q )",
xlab = "Normal Distribution",
ylab = "Residuals")
What this plot goes to show is that the residuals are almost normally
distributed except at the fringes. So, if you can identify the outliers
and trim them, this dataset satisfies the Normal Distribution of Residuals assumption of Linear Regression.
No Multicollinearity
Another big word – and the explanation is a bit tricky as well.
Multicollinearity – If we are trying to predict z
from predictors x and y (z ~ x + y), then x and y should have orthogonal
predictive power. Meaning, each of them should have independent
predictive power. However, if x and y are somehow related, then we have a
situation called multicollinearity. The predictors should not be able to predict each other.
Let’s create a dummy dataset to understand this better.
x – a normal distribution
y – x with a bit of random variation added
z – x + y with a bit more random variation added, so that we simulate a relationship between x, y and z
var = rnorm ( 100, mean = 5, sd= 2)
x = rnorm(100, mean = 100, sd = 10)
y = x +var
var = rnorm ( 100, mean = 5, sd= 2)
z = x + var +y
X = data.frame(x,y,z)
model = lm ( z ~ ., data = X )
summary(model)
Call:
lm(formula = z ~ ., data = X)
Residuals:
Min 1Q Median 3Q Max
-4.9027 -1.2531 -0.0404 1.2897 4.6848
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 9.05493 1.99755 4.533 1.66e-05 ***
x 0.94709 0.09944 9.524 1.43e-15 ***
y 1.01445 0.09687 10.472 < 2e-16 ***
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Residual standard error: 2.043 on 97 degrees of freedom
Multiple R-squared: 0.9907, Adjusted R-squared: 0.9905
F-statistic: 5176 on 2 and 97 DF, p-value: < 2.2e-16
As we see, we have a pretty good r2 because we have modelled the data almost as z = x + y + a bit of variance. Now, let’s calculate the VIF of the model.
car::vif(model)
       x        y
26.14254 26.14254
These are very high values – any value above 5 is considered
high in terms of collinearity. What this shows is that the predictors x
and y are highly dependent on each other ( as opposed to independently
predicting the response variable ).
Collinearity is quantified using a parameter called VIF ( Variance Inflation Factor ), the formula for which is
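VIF ( for a given predictor ) = 1 / (1 − R²), where R² here is the r-squared obtained by regressing that predictor on all the other predictors.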
Point being, as long as there is a relationship among the predictor variables, it shows up as an elevated VIF ( a value greater than about 2 already signals some collinearity ). Let’s take another example.
Collinearity and correlation are different though. Look at the picture below. There is some correlation between
Total Power Consumption <–> Number of rooms
Total Power Consumption <–> Power consumption per room
But there is a perfect collinearity between
Total Power Consumption <–> Number of rooms & Power consumption/room
Correlation is between 2 variables, while collinearity is between a single variable and a combination of variables.
Now that we understand collinearity, let’s find out how many
variables have a collinear relationship in the Boston Housing dataset.
model = lm ( medv ~ ., data = boston_housing )
summary(model)
Call:
lm(formula = medv ~ ., data = boston_housing)
Residuals:
Min 1Q Median 3Q Max
-19.3991 -2.5377 -0.5848 1.7101 27.9916
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 20.866365 5.477190 3.810 0.000159 ***
crim -0.238025 0.223853 -1.063 0.288229
zn 0.038221 0.013372 2.858 0.004462 **
indus 0.051356 0.059643 0.861 0.389683
chas 2.435048 0.829794 2.935 0.003516 **
nox -11.657986 3.928009 -2.968 0.003163 **
rm 5.110158 0.454838 11.235 < 2e-16 ***
age -0.006094 0.013101 -0.465 0.642031
dis -1.271514 0.195766 -6.495 2.26e-10 ***
rad 0.294444 0.085387 3.448 0.000619 ***
tax -0.011360 0.003612 -3.145 0.001776 **
ptratio -0.831030 0.127000 -6.544 1.68e-10 ***
b 0.012314 0.003513 3.505 0.000503 ***
lstat -0.520753 0.057428 -9.068 < 2e-16 ***
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Residual standard error: 4.553 on 438 degrees of freedom
(54 observations deleted due to missingness)
Multiple R-squared: 0.7405, Adjusted R-squared: 0.7328
F-statistic: 96.16 on 13 and 438 DF, p-value: < 2.2e-16
for ( name in names(car::vif(model)) ) {
cat ( name ,car::vif(model)[name], "\n")
}
crim 6.791148
zn 2.30185
indus 3.575504
chas 1.07249
nox 4.348241
rm 2.001147
age 2.95385
dis 3.643646
rad 9.025998
tax 6.501212
ptratio 1.698408
b 1.261811
lstat 2.719365
So, the following variables have a high collinearity ( VIF > 5 ).
9.0 – VIF for RAD
6.8 – VIF for CRIM
6.5 – VIF for TAX
How to fix Collinearity ?
Iteratively remove the predictor with the highest VIF. So, in this case, remove RAD and do a re-run.
Use dimensionality reduction ( say PCA ) to reduce the number of predictors ( more on PCA & dimensionality reduction later ).
model = lm ( medv ~ crim + zn + indus + chas + nox + rm + age + dis + tax + ptratio + b + lstat, data = boston_housing )
for ( name in names(car::vif(model)) ) {
cat ( name ,car::vif(model)[name], "\n")
}
Regular expressions are a very important tool for a data scientist or
a machine learning engineer. Regular Expressions can be a dry and boring
topic to learn, but the problems they solve are very real
and interesting. So, we will learn Regular Expressions with a problem-solving
approach. We will define a series of small problems, solve them
step by step, and with each problem we will learn some aspect of
Regular Expressions.
If you are not comfortable with this kind of non-linear approach, this course might not be for you.
Extract phone numbers v1
Problem – Extract all the phone numbers from this text.
numbers = '''There are 3 phone numbers that you need to call in case of medical emergency.
For casualty, call 408-202-2222. For elderly emergencies, call 408-203-2222 and
for everything else call 408-202-4444
'''
Let’s take a much simpler case – just a list of 3 phone numbers and no text. If the pattern is at the beginning of the string, you can use the match ( ) function. Also, the match ( ) function only returns the first occurrence.
The match ( ) function returns a match object. It contains the span ( the start and end of the match ) and the actual match itself. Use the group ( ), span ( ), start ( ) and end ( ) functions to get the specifics of the match.
print ( "matching text = ", match.group())
print ( "start position = ", match.start())
print ( "end position = ",match.end())
matching text = 408-202-2222
start position = 0
end position = 12
Let’s try something slightly different with the match ( ) function. Will it be able to pick the pattern out of this text ?
No. Why is that ? The match ( ) function can only find the pattern at the beginning of the string. In this case, the first line is a blank line, so the match ( ) function fails. In cases like this, use the search ( ) function. In contrast to the match ( ) function, the search ( ) function can extract patterns from anywhere in the text.
match = re.search("408-\d\d\d-\d\d\d\d",numbers)
print ( match )
But we are still getting the first match only. We wanted all the matches, right ? Both the search ( ) and match ( ) functions return the first match only. To get all the matches, we will have to use other functions like findall ( ) or finditer ( ).
That’s much better, right ? The findall ( ) function returns all the matches it finds in the text. The other function, finditer ( ), returns the same results as an iterable.
matches = re.finditer("408-\d\d\d-\d\d\d\d",numbers)
for match in matches :
print ( match )
If you wanted just the match, use the group () function to extract the matching text.
matches = re.finditer("408-\d\d\d-\d\d\d\d",numbers)
for match in matches :
print ( match.group() )
408-202-2222
408-203-2222
408-202-4444
Now, we can solve the problem we started out with.
numbers = '''There are 3 phone numbers that you need to call in case of medical emergency.
For casualty, call 408-202-2222. For elderly emergencies, call 408-203-2222 and
for everything else call 408-202-4444
'''
matches = re.finditer("408-\d\d\d-\d\d\d\d",numbers)
for match in matches :
print ( match.group() )
408-202-2222
408-203-2222
408-202-4444
In fact, even if the starting phone number is not always constant, like a 408 in this case, still we should be able to extract the matches.
numbers = '''There are 3 phone numbers that you need to call in case of medical emergency.
For casualty, call 408-202-2222. For elderly emergencies, call 408-203-2222 and
for everything else call 800-202-4444
'''
matches = re.finditer("\d\d\d-\d\d\d-\d\d\d\d",numbers)
for match in matches :
print ( match.group() )
408-202-2222
408-203-2222
800-202-4444
See, all the numbers have been extracted.
Points to Remember
\d represents a single digit
match ( ) function returns the first match only, and only matches at the beginning of the string.
search ( ) function returns the first match only.
findall ( ) and finditer ( ) functions return all the matches.
Extract phone numbers v2
Problem – Extract all the phone numbers from this text message.
match = re.findall("\d\d\d-\d\d\d-\d\d\d\d",numbers)
print ( match )
['408-222-2222']
But this only matches the phone numbers without parentheses. What about the ones with parentheses ? We can try something like this.
match = re.findall("(\d\d\d)-\d\d\d-\d\d\d\d",numbers)
print ( match )
['408']
oops.. it is not working. Why ? Because a parenthesis is a special character – it is used to make groups out of regular expressions ( which we will see later ). To match an actual parenthesis, escape it with a backslash.
match = re.findall("\(\d\d\d\)-\d\d\d-\d\d\d\d",numbers)
print ( match )
['(408)-333-3333', '(800)-444-4444']
OK. Now we got the phone numbers with parentheses, but we missed the ones without parentheses. We want to capture either of these combinations. That’s when we use the OR operator. In regular expressions, we use the pipe operator ( | ) to represent either/or type of patterns.
match = re.findall("\(\d\d\d\)-\d\d\d-\d\d\d\d|\d\d\d-\d\d\d-\d\d\d\d",numbers)
print ( match )
There we go – we were able to capture both the patterns. However, the \d in the pattern repeats a lot making the string too long. Instead, we can use quantifiers to specify how long a particular sub-pattern can be. For example, the following pattern is exactly equivalent to the pattern above.
match = re.findall("\(\d{3}\)-\d{3}-\d{4}|\d{3}-\d{3}-\d{3}",numbers)
print ( match )
As you can see, quantifiers make the pattern much more compact in case there is a lot of repetition.
Points to Remember
If a parenthesis – ( or ) – needs to be matched in the pattern, escape it with a backslash ( \ ). This is done because parentheses are used to represent groups, which we will look into later.
| or the pipe character is used to represent a logical OR operator in regular expressions.
{ } Flower brackets ( curly braces ) are used to quantify the number of occurrences of a particular part of a regular expression. For example, a{3} is used to indicate that exactly 3 a‘s should be looked for.
Extract phone numbers v3
Problem – Extract all the phone numbers from this text message.
match = re.findall("\d{3}-\d{3}-\d{4}|\d{3}.\d{3}.\d{3}|\d{3}\d{3}\d{3}",numbers)
print ( match )
['408-222-2222', '408.333.333', '800 444 444']
This works. But can we make it any more concise ? There seems to be a lot of repetition. This is where character sets come in. In this case, the separator between the parts of the phone number is either a dash, a dot or a blank space. Can we somehow list all of these characters as possible separators, as opposed to specifying each pattern separately ?
match = re.findall("\d{3}[-.]\d{3}[-.]\d{4}",numbers)
print ( match )
['408-222-2222', '408.333.3333']
But what about phone numbers with spaces ? How do we represent a space in regular expressions ? We use the special character \s.
match = re.findall("\d{3}[-.\s]\d{3}[-.\s]\d{4}",numbers)
print ( match )
['408-222-2222', '408.333.3333', '800 444 4444']
There we go – we are able to capture all of the phone numbers.
Points to Remember
Characters enclosed in [ ] ( square brackets ) are called character sets. Regular expressions match any single character inside the character set.
\s is used to represent a space or blank character.
Extract phone numbers v4
Problem – Extract all the phone numbers from this text.
match = re.findall("\d{3}[-.\s]\d{3}[-.\s]\d{4}",numbers)
print ( match )
['408-222-2222', '408.333.3333', '408-444-4444']
But how about the 1 before the numbers ? How do we capture that ? Some phone numbers have it and some don’t. That’s where the ? quantifier comes in. If a pattern needs to be matched zero or 1 times, use the ? quantifier.
match = re.findall("1?\s\d{3}[-.\s]\d{3}[-.\s]\d{4}",numbers)
print ( match )
Much better. Now, what about the 800 number with parentheses ? How do we look for parentheses ? We have seen previously that a parenthesis is a special character and to match it we need to escape it. Let’s try that.
match = re.findall("1?\s\(?\d{3}\)?[\s]\d{3}[\s]\d{3}",numbers)
print ( match )
['1 (800) 444 444']
Alright, we got that as well. Now, to combine all of these, we can use the OR operator.
match = re.findall("1?\s\(?\d{3}\)?[\s]\d{3}[\s]\d{3}|1?\s\d{3}[-.\s]\d{3}[-.\s]\d{4}",numbers)
print ( match )
The first one is a US phone number, the second one is an Indian number and the third one is a Chinese number. How do we extract these ? Let’s start with the plus (+) at the beginning. How do we extract that ?
match = re.findall("+",numbers)
---------------------------------------------------------------------------
error Traceback (most recent call last)
<ipython-input-164-2fe5ce5c5168> in <module>
----> 1 match = re.findall("+",numbers)
error: nothing to repeat at position 0
Oops.. doesn’t work. That is because + is a special character. It is used to represent a quantifier – + means that a pattern repeats one or more times. So, to find + as a literal character, you would have to escape it.
match = re.findall("\+", numbers)
print ( match )
OK. Now, we are able to get the + in the string. Let’s extract the country code next. It is the set of digits right next to the + symbol. It could be 1 digit long (like the US), 2 digits (like India or China), or maybe 3 (Zimbabwe is +263). We can use the flower brackets to specify a pattern length of 1 to 3, like so – {1,3}
match = re.findall("\+\d{1,3}", numbers)
print ( match )
Next, we have a set of numbers separated by dashes. However, the count of digits between the dashes is arbitrary. So, we need some kind of quantifier again, to match a repeating pattern of count between 1 and n. We could just assume a higher number, say 6, for n and proceed like so.
match = re.findall("\+\d{1,3}\s{1,3}\d{1,6}-\d{1,6}-\d{1,6}", numbers)
print ( match )
Instead of using the {m,n} quantifier to identify digits that repeat at least once, you can use the quantifier +.
match = re.findall("\+\d+\s+\d+-\d+-\d+", numbers)
print ( match )
We are still missing another number, +91 98989-99898. This is because the number is divided into 2 parts (and not 3 parts separated by dashes). So, a simple solution would be to create another pattern and do an OR operation. That should capture all of the possible phone numbers in this case.
match = re.findall("\+\d+\s+\d+-\d+-\d+|\+\d+\s+\d+-\d+", numbers)
print ( match )
Learning
{m,n} is used to represent a pattern that repeats m to n number of times. It is a type of quantifier.
Since + is a special character (used to identify patterns that repeat 1 or more times), to identify + itself, escape it with a backslash (\)
Extract emails
Problem – Extract all the emails from this text.
text = ''' accounts@boa.com,
sales@boa.com,
cancellations@tesla.com,
accounts@delloitte.com,
cancellations@farmers.com,
accounts@dell@com'''
To solve text based patterns, one of the fundamental character sets is \w. It represents any character that can be found in a word – it could be alphabetic, numeric or an underscore. These are the only 3 types of characters that \w can match. For example, a single \w on this text captures all the word characters (a to z characters, 0-9 digits and underscore). You can see that in the output below.
matches = re.findall("\w",text)
print ( matches )
We need to step up from letters to identifying words. A word is just a repetition of a set of letters, numbers and underscores. So, we use the quantifier + to identify a word.
We are almost there, except the last email – accounts@dell@com. This is not a valid email. So, why is our pattern capturing it ? When we mention a dot (.) in our pattern (\w+@\w+.\w+), it captures any character. So, in order to match a literal dot, all we have to do is to escape it – prepend it with a backslash (\).
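Escaping the dot gives us the pattern we want (a sketch of the step just described):
matches = re.findall(r"\w+@\w+\.\w+",text)
print ( matches )   # accounts@dell@com is no longer captured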
html = '''<font size=2>
<font size=2>
<font size = 2>
< font size=2 >
<font size = 2 >'''
Quiz
Which of the following regular expressions captures all of the above combinations ? Observe the spaces precisely.
<\s+font
Exercise : Say we have a text with Canadian zip codes. The format for Canadian zip codes is
A1A A1A
where A represents a letter and 1 represents any digit. There is a space at the 4th character.
text = '''M1R 0E9
M3C 0C1
M3C 0C2
M3C 0C3
M3C 0E3
M3C 0E4
M3C 0H9
M3C 0J1
1M1 A1A
11M 1A1
M11 A1A
M3C0J1
M3C JJ1'''
# Test - The last five elements should NOT match
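A sketch of a pattern for this exercise (the specific pattern below is an assumption, not from the original): a letter-digit-letter block, a space, then digit-letter-digit, with word boundaries so that strings like 1M1 A1A or M3C0J1 are rejected.
import re
matches = re.findall(r"\b[A-Z]\d[A-Z] \d[A-Z]\d\b", text)
print ( matches )   # the last five entries in the text should not appear here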
Problem – Say there is a web server log file, find out how many times the login page was successfully hit and how many times it failed. For now, we will work with a sample snippet from the file. We will work with the real file in the next challenge.
log = '''
10.128.2.1 [29/Nov/2017:06:58:55 GET /login.php HTTP/1.1 Status Code - 302
10.128.2.1 [29/Nov/2017:06:59:02 POST /process.php HTTP/1.1 Status Code - 302
10.128.2.1 [29/Nov/2017:06:59:03 GET /home.php HTTP/1.1 Status Code - 200
10.131.2.1 [29/Nov/2017:06:59:04 GET /js/vendor/moment.min.js HTTP/1.1 Status Code - 200
10.130.2.1 [29/Nov/2017:06:59:06 GET /bootstrap-3.3.7/js/bootstrap.js HTTP/1.1 Status Code - 200
10.130.2.1 [29/Nov/2017:06:59:19 GET /profile.php?user=bala HTTP/1.1 Status Code - 200
10.128.2.1 [29/Nov/2017:06:59:19 GET /js/jquery.min.js HTTP/1.1 Status Code - 200
10.131.2.1 [29/Nov/2017:06:59:19 GET /js/chart.min.js HTTP/1.1 Status Code - 200
10.131.2.1 [29/Nov/2017:06:59:30 GET /edit.php?name=bala HTTP/1.1 Status Code - 200
10.131.2.1 [29/Nov/2017:06:59:37 GET /logout.php HTTP/1.1 Status Code - 302
10.131.2.1 [29/Nov/2017:06:59:37 GET /login.php HTTP/1.1 Status Code - 200
10.130.2.1 [29/Nov/2017:07:00:19 GET /login.php HTTP/1.1 Status Code - 200
10.130.2.1 [29/Nov/2017:07:00:21 GET /login.php HTTP/1.1 Status Code - 200
10.130.2.1 [29/Nov/2017:13:31:27 GET / HTTP/1.1 Status Code - 302
10.130.2.1 [29/Nov/2017:13:31:28 GET /login.php HTTP/1.1 Status Code - 200
10.129.2.1 [29/Nov/2017:13:38:03 POST /process.php HTTP/1.1 Status Code - 302
10.131.0.1 [29/Nov/2017:13:38:04 GET /home.php HTTP/1.1 Status Code - 200'''
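The findall call that produces matches is sketched below (an assumption – it mirrors the grouped pattern used in the full solution further down):
import re
# group 1 – IP address, group 2 – the login.php page, group 3 – HTTP, group 4 – status code
pattern = r"(\d+\.\d+\.\d+\.\d+).*(login\.php)\s(HTTP).*-\s(\d{3})"
matches = re.findall(pattern, log)
print ( matches )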
count_200 = 0
count_not_200 = 0
for match in matches :
    if match[3] == "200" :
        count_200 += 1
    else :
        count_not_200 += 1
success_perc = ( count_200 / (count_200 + count_not_200) ) * 100
print ( "login was successfully hit", success_perc, "% of the time")
Learning
(…) is used to represent groups in a regular expression
There can be multiple groups in a single regular expression
Each of the groups can be extracted out per each match of the regular expression
. (dot) represents ANY character
Challenge
Say there is a web server log file, find out how many times the login page was successfully hit and how many times it failed. The file is available in the data directory. If the HTTP code ( at the end of each line in the log file ) is 200, the page was successfully rendered. Otherwise, it is a failure.
Solution
# read file
data = []     # will contain the log data as a list
with open ( "./data/log_file.txt", "r") as f :
    for line in f :
        data.append(line)

# print the read data
for line in data[0:5]:
    print ( line, end="")

# parse the data using regular expressions and find matches for login.php
import re

login_data = []
pattern = r"(\d+\.\d+\.\d+\.\d+).*(login\.php)\s(HTTP).*-\s(\d{3})"

for line in data :
    matches = re.findall(pattern, line)
    if len(matches) > 0 :
        login = []
        login.append(matches[0][0])
        login.append(matches[0][1])
        login.append(matches[0][2])
        login.append(matches[0][3])
        login_data.append(login)

# print a sample
for line in login_data[0:5]:
    print ( line )

# calculate the success ratio
count_200 = 0         # successful
count_not_200 = 0     # unsuccessful

for element in login_data :
    if element[3] == "200":
        count_200 += 1
    else :
        count_not_200 += 1

percentage_success = ( count_200 / (count_200 + count_not_200) ) * 100
print ( "Login page was successfully hit", percentage_success, "% of the time")
text = '''Data science is an inter-disciplinary field that uses scientific methods, processes, algorithms and systems to extract knowledge and insights from structured and unstructured data. Data science is related to data mining and big data.
Data science is a "concept to unify statistics, data analysis, machine learning and their related methods" in order to "understand and analyze actual phenomena" with data. It employs techniques and theories drawn from many fields within the context of mathematics, statistics, computer science, and information science. Turing award winner Jim Gray imagined data science as a "fourth paradigm" of science (empirical, theoretical, computational and now data-driven) and asserted that "everything about science is changing because of the impact of information technology" and the data deluge. In 2015, the American Statistical Association identified database management, statistics and machine learning, and distributed and parallel systems as the three emerging foundational professional communities.'''
Problem – Find sentences that have 2 or more spaces in between them.
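One reading of this exercise, sketched as an assumption: use the {2,} quantifier to find every place where two or more consecutive spaces occur.
import re
matches = re.findall(r" {2,}", text)
print ( matches )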
Credit cards extractor
Problem – Find credit card numbers identified by their category (Visa, Master Card). These credit card numbers follow a certain pattern. Use the following patterns to identify the category.
Patterns –
All Visa card numbers start with a 4 and are either 13 or 16 digits long
All Master card numbers start with 51 through 55 or 2221 through 2720 and have exactly 16 digits
All Amex cards start with 34 or 37 and have exactly 15 digits
Or, we can make the last element in the pattern a non-capturing group – meaning, it will still be a group from a syntax perspective, but it will not be captured as a group. To do that, we use ?:. Whenever we use ?: at the beginning of a group, the group is still used to match the pattern, but the matched text is not captured.
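A minimal sketch of the Visa pattern as described (the card numbers below are made up for illustration):
import re
cards = "4111111111111 4111111111111111 5105105105105100"
# a 4, then 12 digits, then an optional non-captured group of 3 more digits – 13 or 16 digits in total
print ( re.findall(r"\b4\d{12}(?:\d{3})?\b", cards) )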
Now, let’s work on Master Card. Master Card represents a different pattern. It has a pretty broad range of numbers – the beginning digits start with
51 through 55, OR
2221 through 2720
The first one is easy enough. Let’s work on that first.
In all these examples, the first four digits are the ones that differ. The pattern for the remaining 12 digits stays the same. So, let’s compress all of these into an OR based pattern for the first 4 digits and keep the remaining 12 digits constant, as sketched below.
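A sketch of the Master Card pattern under those rules (again with made-up numbers) – the alternation covers the first four digits and \d{12} covers the remaining twelve:
# first four digits: 5100-5599 or 2221-2720, then 12 more digits (16 in total)
pattern = r"\b(?:5[1-5]\d{2}|222[1-9]|22[3-9]\d|2[3-6]\d{2}|27[01]\d|2720)\d{12}\b"
print ( re.findall(pattern, "5105105105105100 2221000000000009 2800000000000000") )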
Learning
(?:…) is used to represent non-capturing groups – meaning, they are used to match patterns, but the specific pattern within that parenthesis is not captured as a group. We have seen how this is useful in the case of the Visa pattern.
Cycling through a range of numbers – we have seen how to cycle through a large range of numbers when we discussed the pattern for Master Card.
Word occurrences
Problem – Find all the occurrences of a word in a text and segregate them into 2 categories
1 – standalone occurrence of the word
2 – the word is part of another word
For example, the word "bat" can occur in isolation, like in the sentence ("His cricket bat is awesome"), or as part of a different word, like in ("Aerial combat vs land-based combat").
Solution – Finding the pattern is quite easy. However, the trick is to find out whether the word occurred individually or as part of another word. Word boundaries can help in this case. \b is used to specify a word boundary.
text = '''Python is a general purpose programming language. Python's design philosopy ...
Let's pythonify some of the code...
while python is a high level language...'''
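Let’s search for standalone occurrences of "Python" using \b (a sketch of the attempt, assuming a plain, non-raw pattern string):
matches = re.findall("\bPython\b", text)
print ( matches )   # []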
But it is not working. Why ? That is because \b is a special escape character for backspace. When you specify it in a normal pattern string, it is not treated literally, but interpreted as a backspace. To avoid confusion, always use raw strings to define patterns. Raw strings can be specified in Python by prepending the string with an r. Let’s try this again.
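Retrying with a raw string (a sketch of the step just described):
matches = re.findall(r"\bPython\b", text)
print ( matches )   # ['Python', 'Python']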
That’s better. However, the word occurs on its own 3 times in the text, and we have picked up only 2. Why is the python in the last line not being picked up ? That is because regular expressions are case sensitive. The "p" in the word "python" in the last line is lower case. If you want to do a case insensitive search, use global flags. These can be specified as a third parameter in the findall() function.
matches = re.findall(r"\bPython\b", text, re.IGNORECASE) # you can also use re.I as a shortcut
print ( matches )
['Python', 'Python', 'python']
Learning
\b is used to specify a word boundary.
Always prepend the pattern string with r to make it a raw string. This way, the characters in the pattern are taken literally.
global flags can be used to alter the way regular expressions work. One of the global flags we have seen is re.IGNORECASE. It can be used to do a case insensitive search.
HTML tags
Problem – Find all the tags in a HTML or XML.
For example, here is a small snippet of HTML. There are many tags like <html>, <head>, <td> etc. We have to identify all the tags used in the following HTML.
import re
text = '''<html>
<head>
<title> What's in a title</title>
</head>
<body>
<tr>
<td>text one </td>
<td>text two </td>
</tr>
</body>
</html>
'''
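One way to do it – sketched here as an assumption – is a character set that excludes the closing bracket, so a match can never run past the end of a tag:
# [^>]+ matches one or more characters that are not '>'
matches = re.findall("<[^>]+>", text)
print ( matches )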
Another way to do it is to use non-greedy quantification. When you start a pattern with < and consume any character with .*, it consumes everything up to the last > it can find. That is why * and + are called greedy quantifiers. To negate this effect, add the ? operator after the quantifier. That way the * matches the least amount of text needed before the regular expression is satisfied.
import re
matches = re.findall("<.*?>", text)
print ( matches )
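For contrast, here is a sketch of what the greedy version does on the same text – on a line like <td>text one </td>, it swallows everything from the first < to the last > as one match:
matches = re.findall("<.*>", text)
print ( matches )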
Validate passwords
Problem – Validate the passwords in the text below. A valid password should have at least one capital letter, at least one number and at least one special character ( let’s limit special characters to @ , # , $ , % )
text = '''Aw@som$passw0rd
Awesomepassw0rd
Awesomepassword
Aw!som!passw0rd
aw!som!passw0rd'''
# All combinations except for the first one are invalid
Solution
This can be solved easily using plain Python. However, we want a more concise solution using regular expressions. In these kinds of situations, we are looking for some kind of validation. Regular expressions’ lookaround feature is very useful in these cases. The syntax for that is (?=…) where … represents any regular expression. Let’s start with the first condition.
That failed. Why ? Because regular expressions consume text and move forward. The expression [^A-Z]*[A-Z] consumed the text up to and including the capital letter, and the pattern is now looking for a number with \d after that point, which it cannot find. This is where lookarounds help.
This time it works. The reason is that we have converted the digit search \D*\d into a lookahead (?=\D*\d).
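Putting the full validation together as a sketch (the exact pattern is an assumption based on the conditions above – at least one digit, one capital letter and one of @ # $ %):
import re
pattern = r"(?=\D*\d)(?=[^A-Z]*[A-Z])(?=[^@#$%]*[@#$%])\S+"
for password in text.split():
    print ( password, "valid" if re.match(pattern, password) else "invalid" )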
An important aspect of lookarounds ( look ahead or look behind ) is that they do not consume any characters. For example, look at the example below. We want to find all the words that are followed by a comma, but we don’t want to show the comma.
text = "Hi there, how are you doing ?"
# \b for word boundary
# \w+ for a word
#(?=,) will ensure that the word is followed by a comma
matches = re.findall(r"\b\w+(?=,)", text)
print ( matches )
['there']
See, the comma is not shown in the output. Granted, it is not a big deal – we could do the same thing using groups. However, there are many situations (like the password example above) that cannot be handled using groups alone. That’s where lookarounds come in. Let’s continue the same example and find all the words preceded by a comma.
# (?=,\s) => verify (assert) that before the word, there is a comma followed by a space
# \w+ is a word
matches = re.findall(r"(?<=,\s)\w+", text)
print ( matches )
['how']
Learning
There are 2 types of lookarounds – look ahead and look behind.
(?=…) is used to do look ahead search.
Lookarounds are also called assertions
Regular expressions cheatsheet
Special Character – Description
. – Matches any character, except new line
[XYZ] – Character set
[^XYZ] – Negation of the character set
[A-Z] – Character range – matches any one character from A to Z
| – Logical OR (pipe)
\w – Matches any word character. Equivalent to [A-Za-z0-9_]
\W – Negation of any word character. Equivalent to [^A-Za-z0-9_]
\d – Matches any digit. Equivalent to [0-9]
\D – Matches any non-digit. Equivalent to [^0-9]
\s – Matches any whitespace character ( spaces, tabs or line breaks )
\S – Matches any non-whitespace character
^ – Matches beginning of line
$ – Matches end of line
\b – Word boundary
\B – Not a word boundary
* – Zero or more
+ – One or more
? – Zero or one
(XYZ) – Capturing group
(?:XYZ) – Non-capturing group
(?=XYZ) – Positive lookahead
(?!XYZ) – Negative lookahead
(?<=XYZ) – Positive lookbehind
(?<!XYZ) – Negative lookbehind
Challenge
Extract data to JSON
Say we have a bunch of cities along with their nicknames in the following format in a text file. Extract each city and its nicknames into JSON with the following structure.
cities = '''
1. Paris – The City of Love, The City of Light, La Ville-Lumiere
2. Prague – The City of Hundred Spires, The Golden City, The Mother of Cities
3. New York – The Big Apple
4. Las Vegas – Sin City'''
# required output format
{
"city_1" : ["nick name 1", "nick name 2"],
"city_2" : ["nick name 1", "nick name 2.."]
}
# import the file
with open("./data/cities.txt","r") as f :
    data = f.read()

import re
matches = re.findall(r"\d+\.\s+(\w+\s?\w+)\s+–\s+(.*)", data)
print ( matches[0:5] )

import json
city_dict = {}
for city in matches :
    # split the nicknames on commas and strip the surrounding spaces
    city_dict[city[0]] = [nickname.strip() for nickname in city[1].split(",")]

city_json = json.dumps(city_dict)
print ( city_json )
Lookup is one of the most commonly used functions in excel. Let’s understand what it is with an example.
Say, we have orders data from an e-commerce website like so.
What we have is a list of products and their prices. Say somebody gives you a subset of these products and asks you to pick up the corresponding prices. How would you do it ? Of course, you can do it manually if you are just talking about a couple of items. However, once you have more than a handful of items, you need a method for it. Let’s see how we can do it using the VLOOKUP formula in Excel.
For example, we want to find out the price of 3 items that you see on the right. If we want to get them using VLOOKUP, this is how we do it.
The VLOOKUP formula looks like a big deal, but it is actually not. We will look at the actual parameters in the next chapter, but first I want you to understand what VLOOKUP can do. In the example above, we were able to pull the prices of 3 items. What if there were a million items in the lookup data table and you are trying to find out the prices of 2000 items from it ? As you very well know, this cannot be done manually – This is where VLOOKUP works its magic.
How VLOOKUP works
Now that we understand what VLOOKUP can do for us, let’s understand how VLOOKUP works at a high level. There are 4 parts to the VLOOKUP formula.
Lookup Data ( or lookup table )
Item or key to be looked up
Column number in the lookup data that needs to be extracted
Exact vs approximate match
The following slideshow highlights each of these with the corresponding visuals. What we are trying to do here is to look up the price of item code 22697.
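As a sketch (the cell references here are assumptions, not taken from the original workbook), the lookup of item 22697’s price would look something like this:
=VLOOKUP(22697, A2:C500001, 3, FALSE)
Here A2:C500001 is the lookup data, 22697 is the key being looked up, 3 asks for the third column of the lookup data (assumed here to be the price) and FALSE asks for an exact match.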
What we have actually done is to just get the prices of 3 item codes. It might look trivial, but imagine the following data – A half a million rows of online retail data from a real store.
Now, say you are given 1000 item codes and asked to look up the price. Imagine what a pain that would be to do manually. VLOOKUP can do this in a split second.
VLOOKUP across tabs
VLOOKUP does not just work within a tab – it can work across tabs as well. Let’s take the same Online Retail product price list with half a million records, take a subset of these items, say 500 of them, and look up their prices.
We could do the look up in the same tab, but that makes it clumsy. Instead, let’s do it in a separate tab. To reference the lookup data in a separate tab, all you have to do is to prefix the tab name – like so.
Approximate vs Exact Match
So far, we have been doing exact matches of the key with the lookup data. Sometimes, we might have to do an approximate match. Approximate match is NOT a wildcard match as you might expect. We will deal with wildcard matches in a separate section down below. Let’s understand what approximate match does with an example.
Imagine a sales team with different commission percentages.
Now, say we have 4 sales persons with different sales volumes and we want to find out the percentage commission each of them should get.
You see, in this case, the sales volume does not EXACTly match the slabs. Instead, it is a bracket. In cases like this, you can use the approximate match feature of VLOOKUP.
The moment you set the match option to TRUE ( which also happens to be the default option ), an approximate match happens (as opposed to an exact match). You can see the results for yourself. Although the sales numbers on the left don’t exactly match with the slabs on the right, VLOOKUP is smart enough to identify which slab each of the sales volumes fall under and matches the percentage accordingly.
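A sketch of the approximate-match version (cell references are assumptions): the sales volume in B2 is looked up against the commission slabs, and TRUE asks VLOOKUP to pick the largest slab value that is less than or equal to the sales volume.
=VLOOKUP(B2, $E$2:$F$6, 2, TRUE)
For this to work, the slab column (the first column of the lookup range) must be sorted in ascending order.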
HLOOKUP
While VLOOKUP is the most used function (among VLOOKUP & HLOOKUP), Excel gives you the option to do HLOOKUP or Horizontal lookup as well. Look at the same Retail example flipped horizontally below.
The data is exactly the same, except it is set up horizontally – hence HLOOKUP.
SQL Equivalent of VLOOKUP
If you have some exposure to SQL, VLOOKUP works like a SELECT statement with a WHERE … IN clause. The equivalent of the above VLOOKUP in SQL would be something like this
SELECT item_desc, price FROM <lookup table>
WHERE item_number IN (<list of item numbers>)
#N/A error
What if the item that is looked up doesn’t exist ? Excel throws a #N/A error.
Number as Text
Sometimes, when the key in the lookup data is formatted as text, the lookup does not work and would result in an error.
Select the data, right click on it and click on “Convert to Number”. That should fix the problem.
VLOOKUP works Left to Right only
VLOOKUP always works left to right – meaning, the key should always be the first column in the lookup data.
Wildcard or Partial Matches
Sometimes, you might do text based matches. For example, in the example below, you might want to extract the prices using item description, which is a text field ( as opposed to item code, which is numeric). In cases like this, you might not get the text exactly right.
That’s when you can use wild card search. Just append the search key with a “*” (asterisk) and you should be good to go.
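A sketch of the wildcard version (cell references assumed): appending the asterisk to the search key lets any description that starts with the text in E2 match. Wildcards only work with the exact-match option (FALSE).
=VLOOKUP(E2 & "*", A2:C100, 3, FALSE)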
LOOKUP is case insensitive
Text based lookup is case insensitive. You can try it by changing cases in the example sheet above.
Merge data using VLOOKUP
Merging data from multiple sources ( typically separate tabs or separate sheets) is one of the use cases for VLOOKUP. The only constraint being that the data should be bound by a common key.
The simple example above shows data being merged from 2 look up tables. You can extend the same idea across multiple tabs or sheets. All you have to remember is that there should be a common key across the data being merged.
Named Ranges
The way cell ranges are specified in Excel when using VLOOKUP is a bit ugly looking. You can use named ranges when doing lookups. Instead of specifying the lookup data using rows and columns, you can very well use named ranges. For example, say we give the lookup data a name like this..
Now that the range has been given a name, you can use this to do lookups.
VLOOKUP in Google Sheets
VLOOKUP works perfectly well in Google sheets as well. The procedure is exactly the same.
Before getting into the subplots topic, let us first discuss the two main elements in a plot – Figure and Axes.
Figure
A Figure object can be thought of as a window on which plots are rendered and it contains all the plotting elements.
Axes
The Axes is the actual plotting area contained within the figure object. There can be multiple axes objects in a figure object. The axes contains the x-axis, y-axis, data points, lines, ticks, text etc.
Interfaces
The Matplotlib library provides two important interfaces for rendering plots – the state-machine (pyplot) interface and the object-oriented interface.
A finite state machine is an abstract concept wherein a system can exist in a finite number of states. But at any given point of time, it can be in exactly one of these states. Depending on the operations performed on the state machine, it can change from one state to another.
State-machine interface
The figure on which plots are rendered behaves like a state-machine system. The state-machine interface keeps track of the current figure and axes and the current state of the figure, and hence is called the stateful interface. The state-machine interface makes calls to various methods using the pyplot module and is also known as the pyplot interface. When a method is called using the pyplot module, figure and axes objects are created implicitly and changes are made to the figure, hence changing the state of the figure. We do not define any objects or variables when using this interface. We simply call methods defined in the pyplot module and the changes appear in the figure.
The pyplot interface is useful to render simple plots with minimum code or for quickly testing our code, because we can issue a command and immediately see the result. But if we are working on multiple plots, it becomes difficult to keep track of the active figure and the plots.
Object-oriented interface
Matplotlib also supports the object-oriented interface or the stateless interface. This interface is called stateless because instead of creating a global figure instance, a figure object is created and referenced with a variable. We can then directly call methods on this variable to change the different elements in a plot. Because we are storing a reference to the plotting objects in variables, it is easy to track the objects that we are working on. The object-oriented interface is a more pythonic and recommended way to use Matplotlib.
The pyplot interface is in fact built on top of the object-oriented interface. The top layer is the front-end interface that is used by users without much programming experience to test simple plots. Calling methods such as line, scatter, bar, hist etc. from the pyplot module creates a figure instance implicitly and generates the respective plots. The bottom layer can be used to develop complicated visualizations as it provides more control and allows for customization of the plots.
A figure object can be created using pyplot. In the object-oriented interface as well, pyplot is used to create the figure and the axes objects. In order to render the plots, we then call the various methods directly on these objects.
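A minimal sketch of the object-oriented interface – pyplot creates the figure, an axes is added to it, and the plotting calls are then made on the axes object directly:
from matplotlib import pyplot as plt
import numpy as np

x = np.arange(1,10)
y = x**2

fig = plt.figure(figsize=(8,5))      # create a figure object
ax = fig.add_subplot(111)            # add an axes to the figure
ax.plot(x,y,color='g')               # call plotting methods on the axes object
ax.set_title('Object-oriented interface')
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.show()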
Subplots
Subplot method
Up until now, we have seen various plots such as line plots, bar charts, pie charts, scatter plots etc. In many examples we have generated one plot on one figure, and we have also seen examples where we generated multiple plots on one figure. What if you are required to have plots side by side, or one plot on top of the other, in a figure? This can be achieved with subplots. A subplot is a plot that takes up only a part of the figure object. In Matplotlib, a subplot is also referred to as an axes. We will now see how to generate subplots using the object oriented approach.
The subplot method is used to divide a figure into multiple subplots: subplot(m, n, index, **kwargs)
The subplot method returns an axes object and is used to divide a figure into a matrix of m rows and n columns, creating m*n subplots in one figure. The index number provides the location of the current axes for plotting. The index number starts from 1 and the subplots are numbered from left to right, beginning from the first row up to the mth row.
from matplotlib import pyplot as plt
import numpy as np
fig1 = plt.figure(num=1,figsize=(10,6))
ax1 = plt.subplot(111)
x = np.arange(1,10)
y = np.arange(1,10)
ax1.plot(x,y,color='c',marker='p')
plt.title('Subplot_1')
plt.show()
Subplots method
In order to plot using the subplot function, the given subplot has to first be set as the current/active axes for plotting by specifying the index number, and this becomes tedious if there are too many subplots. Matplotlib also provides the subplots method, which returns a figure object and a numpy array of axes objects.
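A minimal sketch of the subplots method – it returns the figure and a 2x2 numpy array of axes in a single call, and each subplot is addressed by indexing into that array:
from matplotlib import pyplot as plt
import numpy as np

x = np.arange(1,10)

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10,6))
axes[0,0].plot(x, x)
axes[0,1].plot(x, x**2)
axes[1,0].plot(x, np.sqrt(x))
axes[1,1].plot(x, np.log(x))
plt.show()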
A frequency distribution is a table that shows the number of times distinct values in a dataset occur. Histograms are used to evaluate the frequency distribution of a given variable by visually displaying the number of data points occurring in a certain range of values; they are useful when there are large datasets to analyse. Similar to a bar graph, data in a histogram are represented using vertical bars or rectangles. So a histogram appears similar to a bar graph, but the bars in a bar graph are usually separated whereas in a histogram the bars are adjacent to each other.
Say, for example, you are conducting an experiment and you want to visually represent the outcome. In this experiment, you are rolling two dice 1000 times, and the outcome of each event is recorded by appending it to a list.
If you want to see the pattern of the outcomes, it is difficult to analyse the list. We can visualize the pattern by generating a histogram showing the frequency of occurrence of the sum of the two dice rolls.
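A sketch of the dice experiment described above (the exact simulation code is an assumption):
import numpy as np
from matplotlib import pyplot as plt

# roll two dice 1000 times and record the sum of each roll
rolls = np.random.randint(1,7,size=1000) + np.random.randint(1,7,size=1000)

plt.hist(rolls, bins=range(2,14), edgecolor='black')
plt.title('Sum of two dice - 1000 rolls')
plt.xlabel('sum')
plt.ylabel('frequency')
plt.show()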
Histograms are useful for displaying the pattern of your data and getting an idea of the frequency distribution of the variable. To plot a histogram, the entire range of the input dataset is split into equal sized groups or bins. A bar is drawn for each bin, with a height proportional to the number of values in the input data that fall into that bin.
Plot a Histogram with random numbers
The histogram below is plotted with random numbers using the ‘hist’ function defined in the pyplot module. The rand function defined in the numpy library creates an array of specified shape and fills it with random numbers from 0 (included) to 1 (excluded).
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
input_data = np.random.rand(10**3)
plt.hist(input_data,bins=50,color='r',alpha=0.4)
plt.title('Histogram')
plt.xlabel('bins')
plt.ylabel('frequency')
plt.show()
In the above example, the rand function has generated 1000 random numbers and, using the ‘hist’ function, these random numbers are distributed into 50 different bins. It can be observed from the above histogram that the distribution of random numbers is higher in some bins than in others. You can generate random numbers of the order of 10^4, 10^5, 10^6 and see how the values are distributed.
Plot a Histogram to analyze Airline On-time performance
The U.S. Department of Transportation’s (DOT) Bureau of Transportation Statistics (BTS) releases a summary of statistics and basic analysis on airline performance each month. This dataset is a summary of different air carriers showing their departure delays, arrival delays, scheduled departure, etc. Let us analyse the flight data released by BTS. For this example, I have downloaded data from the following website – (https://transtats.bts.gov/ONTIME/Departures.aspx) into a csv file. This data is collected at JFK International Airport for the American Airlines carrier during Jan’19.
Let us plot a histogram which shows the distribution of departure delays (in minutes) of all flights. The delay in departure is calculated as the difference in minutes between the scheduled and actual departure times. In the input dataset, early departures are represented as negative numbers and on-time departures are represented with a zero.
import csv
import numpy as np
from matplotlib import pyplot as plt

with open (r'C:\Users\Ajay Tech\Documents\training\visualization\Data\flight_delay_american.csv') as input_file1:
    csv_file = csv.reader(input_file1)
    header = next (csv_file)
    delay_min = []
    for row in csv_file:
        delay_min.append(int(row[5]))

bins = [-50,0,50,100,150,200,250,300,350,400,450,500,550,600,650,700,750]
plt.hist(delay_min,bins=bins,log=True,color='c')
plt.axvline(np.mean(delay_min), color='r', linestyle='dashed', linewidth=1)
plt.title('Histogram of Departure Delays(AA)')
plt.xlabel('Delay(min)')
plt.ylabel('No of flights')
plt.xticks(bins,rotation=30)
plt.show()
In the above script, the y scale is set to log scale instead of normal scale, because log scale allows us to visualize variations that would otherwise be barely visible. We have marked the average departure delay time on the histogram with a vertical reference line drawn using the axvline function. The axvline function plots a line across the x-axis which can be used to highlight specific points on the histogram. The dashed vertical line on the histogram indicates that, on average, American Airlines flights departing from JFK airport took off 7 minutes late in Jan’19.
Let us also see the performance of another carrier at JFK airport for the same period.
with open (r'C:\Users\Ajay Tech\Documents\training\visualization\Data\flight_delay_jetblue.csv') as input_file2:
    csv_file = csv.reader(input_file2)
    header = next (csv_file)
    delay_min = []
    for row in csv_file:
        delay_min.append(int(row[5]))

plt.hist(delay_min,bins=25,log=True,color='b',alpha=0.3)
plt.axvline(np.mean(delay_min), color='r', linestyle='dashed', linewidth=1)
plt.title('Histogram of Departure Delays(JB)')
plt.xlabel('Delay(min)')
plt.ylabel('No of flights')
plt.xticks(rotation=30)
plt.show()
The vertical line drawn using the axvline function indicates that the average departure delay time for JetBlue Airways flights flying out of JFK is 14 minutes. In fact, JetBlue Airways was named the most-delayed airline at JFK airport.