Learn ML Algorithms by coding: Decision Trees

Rahul
Published in Lethal Brains
Aug 8, 2018


Introduction

The intention of this series, Learn ML Algorithms by coding, is to understand the intricacies of popular machine learning algorithms by coding them from scratch. You may wonder why I am doing this when I could write the three most popular lines from Scikit-learn (instantiate, fit and predict) for any ML algorithm.

The idea here is not to create a competing implementation, rather to understand the building blocks behind each machine learning algorithm.

Since the focus is on learning the algorithms, efficiency is not a priority. In places, I have even traded efficiency for explainability.

In a galaxy far far away….

Groot: I am Groot

Me: Hello Groot. Yes, been a long time. How are you?
Groot: I am Groot.
Me: I had been very busy, been working on multiple projects.
Groot: I am Groot.
Me: How do you know that I was working on Machine Learning projects? It was supposed to be top secret.
Groot: I am Groot.
Me: Damn you Star Lord! Such a blabbermouth.
Groot: I am Groot.
Me: Oh! We implemented multiple algorithms. Logistic Regression, Light GBM, XGBoost, Decision Trees, Random Fore…
Groot: I AM GROOT. I AM GROOT.
Me: Come on! Just because it says decision trees, doesn’t mean it has any relation with you. And you want me to teach you Decision trees? No way!!

Groot: I am Groot

Me: Oh no! Not that face. OK, I will do it, but let's learn by building our own decision tree.
Me: I hope you know basic data preprocessing and feature engineering. If not, read through this kernel. It's very extensive.

Dataset

Let us use the famous Titanic dataset from Kaggle. As I mentioned, we are not going to discuss preprocessing and feature engineering, so I have uploaded a preprocessed dataset here.

We will implement the decision tree for a binary classification problem (i.e. the target attribute has two classes). The Titanic dataset is a classic example: the Survived column is 1 for people who survived and 0 otherwise. Now let us build a decision tree.

The Crude way

Let us build a crude decision tree that predicts the outcome as probabilities. (In Scikit-learn, the predict method returns the predicted classes, while the predict_proba method returns the predicted probabilities. Let's implement the latter.) What do you think would be the simplest way to predict the probabilities?

Groot: I am Groot
Me: Exactly, the probabilities based on the frequencies of the classes (i.e. in a dataset with 100 rows, if 0 occurs 70 times and 1 occurs 30 times, the probabilities would be 0.7 and 0.3 respectively).

Crude implementation version 1

I have touched it up a little bit. The fit method accepts a dataframe (data) and a string naming the target attribute (target). Both of them are then assigned to the object. The independent attribute names are derived and assigned to the object as well.

The predict method iterates through the test data that is passed in and returns the probability based on frequency. Hence, at this point, every row gets the same output. __flow_data_thru_tree does this at the row level. Let's execute this code.
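The original gist is not reproduced in this text, so here is a minimal sketch of what such a class could look like. It assumes pandas; the class name CrudeDecisionTree and the toy data are made up for illustration, and only fit, predict and __flow_data_thru_tree come from the description above.

```python
import pandas as pd


class CrudeDecisionTree:
    """Predicts the same class frequencies for every row (no splits yet)."""

    def fit(self, data, target):
        # Keep the training frame and the name of the target column.
        self.data = data
        self.target = target
        # Every other column is treated as an independent attribute.
        self.independent = [c for c in data.columns if c != target]

    def predict(self, data):
        # One [P(class 0), P(class 1)] pair per test row.
        return [self.__flow_data_thru_tree(row) for _, row in data.iterrows()]

    def __flow_data_thru_tree(self, row):
        # No tree yet: return the class frequencies seen during fit.
        freq = self.data[self.target].value_counts(normalize=True)
        return [float(freq.get(0, 0.0)), float(freq.get(1, 0.0))]


# Toy data, made up for illustration (not the Titanic file).
train = pd.DataFrame({"Age": [22, 38, 26, 35, 54],
                      "Survived": [0, 1, 1, 1, 0]})
tree = CrudeDecisionTree()
tree.fit(train, "Survived")
probabilities = tree.predict(train.head(2))
print(probabilities)
```

On this toy frame every row gets 0.4 for class 0 and 0.6 for class 1, because two of the five passengers did not survive. The numbers for the real dataset will differ.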

As expected, we get the probabilities as 0.62 and 0.38. But this prediction is not of much use. In order to evolve our implementation, we need to understand what a decision tree is.

What is a Decision Tree?

A decision tree is a tree-based algorithm providing a multi-level conditional architecture for predictive analysis by evaluating the incremental information gain.

OK, relax. Let me dial it down. Think of the decision tree as a tree filled with if-else conditions. The most effective condition sits at the top, and the conditions become less effective as we descend (in most cases). Let us see an example.

In this tree, the top node checks whether the sex of the passenger is male. So the condition on the root node of the tree is sex = male: if yes, we traverse to the left; otherwise we slide to the right.

In the next level, a different criterion on a different feature is used to separate the data. Let's look at the tree from an implementation perspective. List the components of a tree.

Groot: I am Groot
(Translation below)

  • A tree has a root node.
  • A node can have a left branch and a right branch.
  • The lowest-layer nodes do not have any branches.

Good observation! The nodes with no branches are called leaf nodes, and the leaf nodes are decisive. Programmatically, we will not implement branches as a separate data structure. Since the branches can be represented as trees themselves, we can build decision trees recursively. The image below represents a branch taken from the larger tree shown above but, when isolated, it is a tree of its own.

To implement branches, we begin by adding a constructor to the class and initializing the left and right branches to None. We are intentionally not initializing them with a decision tree object, because whether a node has branches can be determined only by a split. That brings us to the most important question in a decision tree:

How is an optimal split made in the Decision Tree?

Finding the optimal split

Now would be the right time to introduce two important concepts in Decision Trees - Impurity and Information Gain. In a binary classification problem, an ideal split is the condition which can divide the data such that the branches are homogeneous.

In the example on the left, the ideal split on the root node can be made with the criterion Independent <= 5. With this condition, the left branch would have all the 0s and the right branch would have all the 1s, so our leaf nodes would be pure. But in a realistic use case, it is not as straightforward as this. The dataset has many features, and each feature can generate many candidate criteria.

Hence a split should decrease the impurity in the child nodes relative to the parent node, and the amount by which the impurity decreases is called the information gain.

We need a unit of measure to quantify the impurity and the information gain at each level. Common measures of impurity are

  • Gini
  • Cross Entropy
  • Misclassification

For our implementation, let us pick Gini. The following are the formulas for impurity and information gain when using the Gini index.
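The formula image for impurity is not reproduced in this text. A standard form of the Gini impurity, consistent with the p_i * (1 - p_i) term used in the implementation, is sketched here. Treat it as a reconstruction rather than the original figure.

```latex
I_{G} = \sum_{i=1}^{k} p_i \,(1 - p_i) = 1 - \sum_{i=1}^{k} p_i^{2}
```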

where k is the number of classes in the target attribute and p is the probability of the class at the node.
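The information gain figure is also missing from this text. One common form is sketched here under stated assumptions: the sum runs over the child nodes c produced by the split (left and right), r_c is the number of rows in child c, and n is read as the number of rows in the node being split. The subscript c and that reading of n are assumptions for illustration, not taken from the original figure.

```latex
IG = I_{G}(\text{parent}) - \sum_{c \in \{\text{left},\,\text{right}\}} \frac{r_c}{n}\, I_{G}(c)
```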

where k is the number of classes in the target attribute, r is the number of rows in the node and n is the number of rows in the dataset.

Let's dive in and implement impurity and information gain.

Given a feature (target) or a section of it, this method finds the probability of each class. Since we are working on a binary classification problem, we get only two probabilities, and they sum to 1. So p_i * (1 - p_i) is the same for both classes, hence the * 2.
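As a hedged illustration of that shortcut, the function below computes the binary Gini impurity as 2 * p * (1 - p). The function name calculate_impurity and the choice to treat an empty node as pure are assumptions made for this sketch.

```python
import pandas as pd


def calculate_impurity(target_values):
    """Gini impurity of a binary 0/1 target column (or a slice of it)."""
    values = pd.Series(target_values)
    if len(values) == 0:
        return 0.0  # assumption: treat an empty node as pure
    p_one = float((values == 1).mean())  # probability of class 1
    # p * (1 - p) is identical for both classes, hence the factor of 2.
    return 2 * p_one * (1 - p_one)


pure = calculate_impurity([0, 0, 0, 0])    # only one class present
mixed = calculate_impurity([0, 1, 0, 1])   # classes evenly mixed
skewed = calculate_impurity([0, 0, 0, 1])  # one class dominates
```

A pure node scores 0, an even mix scores the binary maximum of 0.5, and the skewed example lands in between at 0.375.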

In order to find the best split, we need to find the best split for each feature and use the one with the most information gain. First let us see the implementation for a single feature.

Here, given a feature, all the unique values are collected, and for each of those values a split is made such that each row is either less than or equal to the value or greater than the value.

In this scenario, if I make a split on independent <= 5, all the rows with a value less than or equal to 5 in the independent feature go to the “left” branch and the others go to the “right” branch.

Then we can calculate the impurity for the left branch and the right branch. Once we have both, we can calculate the information gain by using the __calculate_information_gain method, which is a straightforward implementation of the formula above.

This is done for all the unique values in the feature, and the split with the highest information gain is returned by the method.
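The single-feature search described above can be sketched as follows. The function names gini, information_gain and best_split_for_feature are hypothetical, the child weights use the number of rows in the node being split, and the toy frame is made up to mirror the independent <= 5 example.

```python
import pandas as pd


def gini(target_values):
    # Binary Gini impurity; an empty slice is treated as pure (assumption).
    p_one = float((target_values == 1).mean()) if len(target_values) else 0.0
    return 2 * p_one * (1 - p_one)


def information_gain(parent, left, right):
    # Parent impurity minus the row-weighted impurity of the two children.
    n = len(parent)
    weighted = (len(left) / n) * gini(left) + (len(right) / n) * gini(right)
    return gini(parent) - weighted


def best_split_for_feature(data, feature, target):
    """Return (information gain, threshold) of the best `feature <= value` split."""
    best_gain, best_value = 0.0, None
    for value in sorted(data[feature].unique()):
        mask = data[feature] <= value
        left, right = data.loc[mask, target], data.loc[~mask, target]
        if len(left) == 0 or len(right) == 0:
            continue  # a split that moves no rows is not a split
        gain = information_gain(data[target], left, right)
        if gain > best_gain:
            best_gain, best_value = gain, value
    return best_gain, best_value


# Toy data: 0s up to 5, 1s above it.
toy = pd.DataFrame({"independent": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                    "target":      [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]})
gain, threshold = best_split_for_feature(toy, "independent", "target")
```

On this toy frame the search settles on the threshold 5 with a gain of 0.5, since both children become pure.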

We must then repeat this for all the features, to find the best split across all of them.
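A compact sketch of that outer loop is given here. The names split_gain and find_best_split are hypothetical, and the toy columns are invented so that one feature separates the classes perfectly while the other carries no information.

```python
import pandas as pd


def gini(y):
    # Binary Gini impurity; an empty slice is treated as pure (assumption).
    p = float((y == 1).mean()) if len(y) else 0.0
    return 2 * p * (1 - p)


def split_gain(data, feature, value, target):
    # Information gain of the split `feature <= value` on this data.
    mask = data[feature] <= value
    left, right = data.loc[mask, target], data.loc[~mask, target]
    if len(left) == 0 or len(right) == 0:
        return 0.0
    n = len(data)
    weighted = (len(left) / n) * gini(left) + (len(right) / n) * gini(right)
    return gini(data[target]) - weighted


def find_best_split(data, target):
    """Return (gain, feature, value) for the best split over every feature."""
    best = (0.0, None, None)
    for feature in data.columns.drop(target):
        for value in data[feature].unique():
            gain = split_gain(data, feature, value, target)
            if gain > best[0]:
                best = (gain, feature, value)
    return best


# Toy data, made up for illustration.
toy = pd.DataFrame({"is_male":  [1, 1, 1, 0, 0, 0],
                    "age":      [30, 20, 40, 20, 40, 30],
                    "Survived": [0, 0, 0, 1, 1, 1]})
best_gain, best_feature, best_value = find_best_split(toy, "Survived")
```

Here is_male <= 0 wins with a gain of 0.5, while every age threshold leaves both children evenly mixed and gains nothing.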

Now that we have identified the split, let us create the branches.

In this method, we instantiate a decision tree for both the left and the right branch. We pass the rows with values less than or equal to the criterion to the left branch and the rows with values greater than the criterion to the right branch. Invoking this method from fit recursively creates the branches, and with that we have built the tree.

Groot: I am Groot

Not so fast, Groot. We are not done yet. Now we come to the point where we must ask the second most important question in a decision tree.

When do we stop splitting?

Before understanding “when”, let us discuss “why”. The logical reason to stop splitting would be that the data at the node is homogeneous. Another point to note is that decision trees have a major disadvantage:

Decision trees are likely to overfit noisy data. The probability of overfitting on noise increases as a tree gets deeper.

One way to avoid this is by predefining the maximum depth of the tree when initializing the decision tree. There are other ways of preventing overfitting, but for this implementation we can work with just the max depth. We can add a max_depth attribute to our class and increment the depth as we create branches, as shown below.
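Since the gist itself is missing here, the following is a minimal sketch of recursive branch creation with a depth cap, not the author's code. The single-underscore method names, the depth argument that starts at 1 for the root, and the toy data are all assumptions.

```python
import pandas as pd


def gini(y):
    # Binary Gini impurity; an empty slice is treated as pure (assumption).
    p = float((y == 1).mean()) if len(y) else 0.0
    return 2 * p * (1 - p)


class DecisionTree:
    def __init__(self, max_depth=3, depth=1):
        self.max_depth = max_depth
        self.depth = depth  # depth of this node; the root is 1
        self.left = None    # branches stay None until a split is made
        self.right = None
        self.split_feature = None
        self.criteria = None

    def fit(self, data, target):
        self.data, self.target = data, target
        self.impurity = gini(data[target])
        # Stop when the node is pure or the depth budget is used up.
        if self.impurity == 0 or self.depth >= self.max_depth:
            return self
        gain, feature, value = self._find_best_split()
        if feature is not None:
            self.split_feature, self.criteria = feature, value
            self._create_branches()
        return self

    def _find_best_split(self):
        best = (0.0, None, None)
        y, n = self.data[self.target], len(self.data)
        for feature in self.data.columns.drop(self.target):
            for value in self.data[feature].unique():
                mask = self.data[feature] <= value
                left, right = y[mask], y[~mask]
                if len(left) == 0 or len(right) == 0:
                    continue
                weighted = (len(left) / n) * gini(left) + (len(right) / n) * gini(right)
                gain = self.impurity - weighted
                if gain > best[0]:
                    best = (gain, feature, value)
        return best

    def _create_branches(self):
        mask = self.data[self.split_feature] <= self.criteria
        # Children are trees themselves, one level deeper than this node.
        self.left = DecisionTree(self.max_depth, self.depth + 1)
        self.right = DecisionTree(self.max_depth, self.depth + 1)
        self.left.fit(self.data[mask], self.target)
        self.right.fit(self.data[~mask], self.target)


# Toy data, made up for illustration.
toy = pd.DataFrame({"x": [1, 2, 3, 4, 5, 6],
                    "y": [0, 0, 0, 1, 0, 1]})
shallow = DecisionTree(max_depth=2).fit(toy, "y")
deep = DecisionTree(max_depth=4).fit(toy, "y")
```

With max_depth=2 the root splits once and both children stay leaves. With max_depth=4 the impure right child is allowed to split again, while the pure left child still stops on its own.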

We have finally created our decision tree. Now let us try to fit our training data. I hope it's not a bumpy ride.

Groot: I am Groot

We are able to successfully fit our training data. Now the next step is to predict the probabilities for the test data.

I am adding two more methods, is_leaf_node and probability, which can be accessed as properties using the property decorator. You can read more about it here.

In order to predict the probabilities of the test data, let us create two functions: one to iterate through each row and the other to traverse the tree for a row and find the prediction. The latter method invokes itself repeatedly until it reaches a leaf node.
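The traversal can be sketched on a small hand-built tree. The Node class below is a hypothetical stand-in for a fitted tree, used only to keep the example short; is_leaf_node and probability follow the description above, and the 0.5 cut-off for turning probabilities into classes is an assumption.

```python
import pandas as pd


class Node:
    """A hand-built stand-in for a fitted tree node (illustration only)."""

    def __init__(self, targets, feature=None, criteria=None, left=None, right=None):
        self.targets = pd.Series(targets)  # target values of the rows at this node
        self.feature, self.criteria = feature, criteria
        self.left, self.right = left, right

    @property
    def is_leaf_node(self):
        return self.left is None and self.right is None

    @property
    def probability(self):
        # [P(class 0), P(class 1)] from the class frequencies at this node.
        p_one = float((self.targets == 1).mean())
        return [1 - p_one, p_one]

    def predict(self, data):
        # Outer loop: one prediction per test row.
        return [self._flow_data_thru_tree(row) for _, row in data.iterrows()]

    def _flow_data_thru_tree(self, row):
        # Descend until a leaf is reached, then report its probabilities.
        if self.is_leaf_node:
            return self.probability
        branch = self.left if row[self.feature] <= self.criteria else self.right
        return branch._flow_data_thru_tree(row)


root = Node([0, 0, 0, 1, 1, 1, 1, 0], feature="is_male", criteria=0,
            left=Node([1, 1, 1, 0]),    # rows with is_male <= 0
            right=Node([0, 0, 0, 1]))   # rows with is_male > 0
test = pd.DataFrame({"is_male": [0, 1]})
probabilities = root.predict(test)
# Assumed cut-off: class 1 when its probability is at least 0.5.
predicted_classes = [1 if p[1] >= 0.5 else 0 for p in probabilities]
```

The first test row falls into the left leaf and the second into the right leaf, so the two rows receive different probabilities, unlike in the crude version.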

Now it is judgement time. Let's try to predict the probabilities of survival in our test data by traversing the tree.

After converting the probabilities into actual classes, the following is the prediction outcome.

Let's plug them all together

Me: Not a bad start to machine learning. What do you say, Groot?
Groot: I am Groot.
Me: I am not exactly sure what you mean.
Groot: I am GROOT.
Me: Still nothing. Hey Rocket, could you help me out with the translation?
Groot: I am GROOT.
Rocket: He says ‘That was frickin Good’

There are a lot of improvements required to make this implementation usable in the real world, but I am not planning on making them.

I have added a few validations and uploaded the code to GitHub. Feel free to send in pull requests.


AWS certified Solution Architect | Node.js - Machine learning Engineer | Amateur photographer