Decision tree is one of the most basic machine learning algorithms. It has a wide array of use cases and is easy to interpret & implement. We can use decision trees for both regression & classification tasks. In this article we will try to understand the basics of the Decision Tree algorithm, and then how a decision tree gets generated from the training data set using the CART algorithm.
About Decision Tree:
Decision tree is a non-parametric supervised learning technique. It is a tree of multiple decision rules, all of which are derived from the data features. It is one of the easiest machine learning algorithms to understand & explain. Decision trees are also the fundamental building blocks of Random Forest, one of the most popular & powerful ML algorithms.
1. Structure of Decision Tree:
- The image below shows what a decision tree looks like. Each node represents a segment or region of the data; in the tree analogy, these segments or regions are the nodes or leaves of the tree.
- Root Node: the first node, which holds our entire training data set.
- Internal Node: the point where a sub-group is split into new sub-groups or leaf nodes. We can also call it a decision node, because this is where the node splits further based on the best attribute of the sub-group.
- Leaf Node: a final node that is not split any further; it holds the decision.
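To make the three node types concrete, here is a minimal Python sketch of how a fitted binary tree could be represented and walked from the root to a leaf. The `Node` class, its fields (`feature`, `threshold`, `left`, `right`, `value`) and the `predict_one` helper are hypothetical names chosen for illustration. The tree is hand-built with made-up thresholds rather than learned from data.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class Node:
    """One node of a binary decision tree (hypothetical structure).

    Root and internal (decision) nodes store an attribute index and a
    threshold; leaf nodes store only the final decision in `value`.
    """
    feature: Optional[int] = None      # index of the attribute tested here
    threshold: Optional[float] = None  # go left if x[feature] <= threshold
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    value: Optional[str] = None        # class label held by a leaf

    def is_leaf(self) -> bool:
        return self.value is not None


def predict_one(node: Node, x: Sequence[float]) -> str:
    """Walk from the root down to a leaf and return the leaf's decision."""
    while not node.is_leaf():
        node = node.left if x[node.feature] <= node.threshold else node.right
    return node.value


# Hand-built tree with made-up thresholds: the root tests attribute 1 ("Wind"),
# and its right child is an internal node testing attribute 0 ("Humidity").
root = Node(
    feature=1, threshold=2.0,
    left=Node(value="Rainy"),
    right=Node(
        feature=0, threshold=4.0,
        left=Node(value="Sunny"),
        right=Node(value="Rainy"),
    ),
)

print(predict_one(root, [3.0, 1.5]))  # Wind 1.5 <= 2.0 -> Rainy
print(predict_one(root, [3.0, 2.5]))  # Wind > 2.0, Humidity 3.0 <= 4.0 -> Sunny
```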
2. About nodes in Decision Tree:
- As mentioned before, a decision tree is a tree-like structure of nested nodes. A node is split into sub-nodes based on a threshold value of an attribute, which we will discuss in detail shortly.
- The Decision tree algorithm splits the training set (root node) into sub-groups (internal nodes), and any sub-group that is not split further becomes a leaf node. This process is termed Recursive partitioning.
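The threshold-based split of a single node can be sketched as one Python function. Given a node's rows, an attribute index f and a threshold t, every row whose value is <= t goes to the left sub-node and the rest go to the right. The `split_node` name, the `<=` convention and the toy rows are assumptions for illustration. Recursive partitioning simply applies the same function again to each resulting sub-group, as the second call shows.

```python
from typing import List, Sequence, Tuple

Row = Sequence[float]


def split_node(rows: List[Row], f: int, t: float) -> Tuple[List[Row], List[Row]]:
    """Partition rows on attribute f at threshold t (left: row[f] <= t)."""
    left = [row for row in rows if row[f] <= t]
    right = [row for row in rows if row[f] > t]
    return left, right


# Hypothetical (Humidity, Wind) rows, not the article's data set
rows = [(1.5, 0.2), (3.9, 1.1), (4.7, 1.4), (5.1, 2.3)]

# First split on attribute 1 (Wind) at 1.2
left, right = split_node(rows, f=1, t=1.2)
print(left)   # [(1.5, 0.2), (3.9, 1.1)]
print(right)  # [(4.7, 1.4), (5.1, 2.3)]

# Recursive step: split the left sub-group again, this time on attribute 0
left_left, left_right = split_node(left, f=0, t=2.0)
print(left_left, left_right)  # [(1.5, 0.2)] [(3.9, 1.1)]
```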
Now we will go a bit deeper to understand the statement "How the splitting of nodes happens based on a threshold value of an attribute".
The splitting of a node (e.g. the root node) into sub-nodes happens based on purity: the Decision Tree algorithm splits the node where it finds the best homogeneity for the sub-nodes. If a sub-node contains members of only one class, its homogeneity is highest. If a sub-node has a 5/5 class member distribution, homogeneity is lowest, and it gets higher for an 8/2 or 9/1 distribution.
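To put numbers on this, the snippet below scores a few two-class distributions with the Gini Index, the CART criterion discussed next. A score of 0 means a perfectly pure node, and 0.5 is the worst case for two classes. The `gini_from_counts` helper is a hypothetical name used only for illustration.

```python
def gini_from_counts(counts):
    """Gini impurity 1 - sum(p_i^2) for a list of per-class counts."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)


# Lower Gini = more homogeneous (purer) node
for counts in ([5, 5], [8, 2], [9, 1], [10, 0]):
    print(counts, round(gini_from_counts(counts), 2))
# [5, 5] 0.5
# [8, 2] 0.32
# [9, 1] 0.18
# [10, 0] 0.0
```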
To split a node, the Decision Tree algorithm needs the best attribute & threshold value. The best attribute & threshold value pair (f,t) is selected using one of the algorithms below. Each has its own measure for finding the split that gives the purest nodes:
- CART algorithm : Gini Index
- ID3 algorithm : Information Gain
- C4.5 algorithm : Gain Ratio
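As a rough illustration of how the three criteria differ, the sketch below scores one candidate split with each of them:

- the decrease in Gini Index (CART);
- Information Gain, i.e. the decrease in entropy (ID3);
- Gain Ratio, i.e. Information Gain divided by the split information of the child-node sizes (C4.5).

The class counts reuse the Humidity split worked through later in this article: root 4 Sunny & 6 Rainy, left bucket 2 Sunny & 5 Rainy, right bucket 2 Sunny & 1 Rainy. The helper names are hypothetical.

```python
import math


def gini(counts):
    """Gini Index (CART): 1 - sum(p_i^2)."""
    n = sum(counts)
    return 1.0 - sum((c / n) ** 2 for c in counts)


def entropy(counts):
    """Entropy in bits (ID3/C4.5): -sum(p_i * log2(p_i)), skipping empty classes."""
    n = sum(counts)
    return -sum((c / n) * math.log2(c / n) for c in counts if c > 0)


def weighted(measure, children):
    """Size-weighted average of an impurity measure over the child nodes."""
    n = sum(sum(child) for child in children)
    return sum(sum(child) / n * measure(child) for child in children)


parent = [4, 6]              # [Sunny, Rainy] in the root node
children = [[2, 5], [2, 1]]  # left bucket, right bucket

gini_decrease = gini(parent) - weighted(gini, children)
info_gain = entropy(parent) - weighted(entropy, children)
split_info = entropy([sum(child) for child in children])  # based on sizes 7 and 3
gain_ratio = info_gain / split_info

print(f"Gini decrease (CART):   {gini_decrease:.3f}")
print(f"Information Gain (ID3): {info_gain:.3f}")
print(f"Gain Ratio (C4.5):      {gain_ratio:.3f}")
```

CART picks the split that minimizes the weighted Gini, which is the same as maximizing the Gini decrease. ID3 and C4.5 maximize Information Gain and Gain Ratio respectively.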
In this article I will use the CART algorithm to create the decision tree.
This algorithm can be used for both classification & regression. For classification, CART uses the Gini Index criterion to split a node into sub-nodes. It starts with the training set as the root node. After splitting the root node in two, it splits the subsets using the same logic, then splits the sub-subsets, and so on recursively. It stops when further splitting no longer gives purer sub-nodes, or when the maximum number of leaves in the growing tree is reached. Restricting the growth of the tree this way is a form of Tree pruning (pre-pruning).
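The recursive growth described above can be sketched in a few lines of Python. This is a simplified illustration, not a full CART implementation:

- it handles classification only;
- it tries every midpoint between consecutive distinct sorted values as a threshold;
- it stops when a node is pure, when `max_depth` is reached, when a node has fewer than `min_samples_split` rows, or when no split lowers the weighted Gini.

The function names, the dict-based tree and the toy data are all hypothetical.

```python
from collections import Counter


def gini(labels):
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_split(X, y):
    """Return (f, t, weighted_gini) of the lowest-impurity split, or None."""
    best, best_score = None, gini(y)
    m = len(y)
    for f in range(len(X[0])):
        values = sorted(set(row[f] for row in X))
        for lo, hi in zip(values, values[1:]):
            t = (lo + hi) / 2  # candidate threshold between two distinct values
            left = [label for row, label in zip(X, y) if row[f] <= t]
            right = [label for row, label in zip(X, y) if row[f] > t]
            score = (len(left) * gini(left) + len(right) * gini(right)) / m
            if score < best_score:
                best, best_score = (f, t, score), score
    return best


def grow(X, y, depth=0, max_depth=3, min_samples_split=2):
    """Recursively partition (X, y); leaves hold the majority class."""
    majority = Counter(y).most_common(1)[0][0]
    if gini(y) == 0.0 or depth >= max_depth or len(y) < min_samples_split:
        return {"leaf": majority}
    split = best_split(X, y)
    if split is None:  # no split makes the children purer
        return {"leaf": majority}
    f, t, _ = split
    left_idx = [i for i, row in enumerate(X) if row[f] <= t]
    right_idx = [i for i, row in enumerate(X) if row[f] > t]
    return {
        "feature": f,
        "threshold": t,
        "left": grow([X[i] for i in left_idx], [y[i] for i in left_idx],
                     depth + 1, max_depth, min_samples_split),
        "right": grow([X[i] for i in right_idx], [y[i] for i in right_idx],
                      depth + 1, max_depth, min_samples_split),
    }


# Hypothetical (Humidity, Wind) rows, not the article's data set
X = [(1.0, 0.5), (1.5, 1.0), (2.0, 3.0), (3.0, 3.5), (4.0, 0.75), (4.5, 4.0)]
y = ["Rainy", "Rainy", "Sunny", "Sunny", "Rainy", "Sunny"]
print(grow(X, y))
# {'feature': 1, 'threshold': 2.0, 'left': {'leaf': 'Rainy'}, 'right': {'leaf': 'Sunny'}}
```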
How to calculate Gini Index?
The Gini Index of a node is $G = 1 - \sum_{i=1}^{c} p_i^2$, where $p_i$ is the probability of class i in that node & there are c classes in total.
Consider that we have only two predictors/attributes, Humidity & Wind, and 10 observations:
Class: Rainy (6 observations) & Sunny (4 observations)
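$$G = 1 - \left(\left(\frac{\text{num of Rainy observations}}{\text{total observations}}\right)^2 + \left(\frac{\text{num of Sunny observations}}{\text{total observations}}\right)^2\right)$$
$$G = 1 - \left(\left(\frac{6}{10}\right)^2 + \left(\frac{4}{10}\right)^2\right) = 1 - (0.36 + 0.16) = 1 - 0.52 = 0.48$$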
So, the Gini index for the first/initial set is 0.48
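The same calculation in Python, assuming the root node's labels are simply the 6 Rainy and 4 Sunny observations (their order does not matter for the Gini Index). `gini_index` is a hypothetical helper name.

```python
from collections import Counter


def gini_index(labels):
    """G = 1 - sum over classes of p_i^2, with p_i the class proportion."""
    m = len(labels)
    return 1.0 - sum((count / m) ** 2 for count in Counter(labels).values())


root_labels = ["Rainy"] * 6 + ["Sunny"] * 4
print(round(gini_index(root_labels), 2))  # 0.48
```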
Basic idea of how a node split happens:
Based on the attribute "Wind" (f) & threshold value "3.55" (t), the CART algorithm created nodes/subsets that give a pure subset on the right side of the above flow (ref: image 4).
Let us understand how we selected the best pair (f,t) to split the root node:
Step 1: Find the best Gini Index/score from the initial set
I wrote a small code snippet to understand it better:
Now, after loading the data, we will find the Gini score for the initial set (root node), which will be our starting best_gini score:
From the above code we get the inputs for the next set of instructions to execute:
- Line num 2: number of features = 2
- Line num 9: placeholders for the best attribute (best_attribute) & threshold (best_thr)
- Line num 11: best_gini score of the initial set = 0.48 (refer image 4)
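The author's snippet is not reproduced here. As a stand-in, below is a hypothetical sketch of what Step 1 computes, reusing the variable names mentioned above (`best_attribute`, `best_thr`, `best_gini`). The feature values are invented for illustration. Only the shape (10 rows, 2 attributes) and the class counts (6 Rainy, 4 Sunny) follow the article.

```python
from collections import Counter

# Hypothetical data: 10 observations x 2 attributes (Humidity, Wind).
# Values are invented; only the shape and the 6/4 class split follow the article.
X = [
    [1.4, 0.2], [1.5, 0.3], [1.3, 0.4], [1.6, 0.2], [3.4, 1.1],
    [3.9, 1.2], [4.6, 1.3], [4.7, 4.0], [5.0, 4.2], [5.1, 3.9],
]
y = ["Rainy", "Rainy", "Sunny", "Rainy", "Rainy",
     "Sunny", "Rainy", "Sunny", "Sunny", "Rainy"]

n_samples = len(y)
n_features = len(X[0])  # number of attributes = 2

# Placeholders for the best split found so far
best_attribute, best_thr = None, None

# Gini score of the initial set (root node): the baseline any split must beat
counts = Counter(y)
best_gini = 1.0 - sum((c / n_samples) ** 2 for c in counts.values())

print(n_features, best_attribute, best_thr, round(best_gini, 2))  # 2 None None 0.48
```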
Step 2: Find the best split from the initial/training set
Now I will explain the most important part: how do we search for the best attribute (f) & threshold value (t) in the initial set?
- Line num 1:8 – The algorithm loops once per attribute (i.e. 2 times) & creates two buckets, left & right. The left bucket starts empty, while the right one holds all the rows sorted by the attribute's values; these sorted values are named thresholds.
- Line num 14:26 – The two initialized buckets (left & right) enter the next loop, which iterates once per row (10 times). On each iteration the algorithm moves one observation from the right bucket to the left one & calculates the Gini weighted average, new_gini.
- Line num 35:38 – If new_gini is lower than best_gini, the current attribute & threshold become the best pair (f,t), and new_gini becomes the new best_gini.
Code flow: the algorithm picks each attribute and sorts its list of values (which we termed the Thresholds list). It then moves each observation to the left bucket in turn, calculates new_gini (the Gini weighted average) & compares it with best_gini. If the new_gini score is lower, it keeps the current attribute (f) & threshold value (t) as the best pair. The process is repeated for all attributes. Finally, the algorithm uses the best attribute (f) & threshold value (t) to split the node into sub-nodes/sub-sets.
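Here is a minimal, hypothetical Python sketch of that search loop. It is written from the description above rather than copied from the author's snippet. For each attribute it sorts the rows, moves one observation at a time from the right bucket to the left, and updates the class counts incrementally. It then computes new_gini as the weighted Gini average and keeps the (f, t) pair with the lowest score.

Two details are assumptions of this sketch:

- the threshold is taken as the midpoint between the last value moved left and the next value;
- positions where those two values are equal are skipped, because no threshold can separate them.

The data is made up, so the result will not match the article's (Wind, 3.55).

```python
from collections import Counter


def gini(counts, m):
    """Gini impurity of a bucket holding `m` observations with class `counts`."""
    if m == 0:
        return 0.0
    return 1.0 - sum((c / m) ** 2 for c in counts.values())


def find_best_split(X, y):
    """Search every attribute f and threshold t for the lowest weighted Gini."""
    m = len(y)
    best_gini = gini(Counter(y), m)  # Step 1 baseline: Gini of the initial set
    best_attribute, best_thr = None, None
    for f in range(len(X[0])):
        # Sort the rows by attribute f; the sorted values form the thresholds list
        order = sorted(range(m), key=lambda row: X[row][f])
        thresholds = [X[row][f] for row in order]
        classes = [y[row] for row in order]
        left, right = Counter(), Counter(classes)  # left bucket starts empty
        for i in range(1, m):
            moved = classes[i - 1]  # move one observation from right to left
            left[moved] += 1
            right[moved] -= 1
            if thresholds[i] == thresholds[i - 1]:
                continue  # identical values cannot be separated by a threshold
            # Gini weighted average of the two buckets (CART cost function)
            new_gini = (i * gini(left, i) + (m - i) * gini(right, m - i)) / m
            if new_gini < best_gini:
                best_gini = new_gini
                best_attribute = f
                best_thr = (thresholds[i - 1] + thresholds[i]) / 2
    return best_attribute, best_thr, best_gini


# Hypothetical (Humidity, Wind) rows; invented values, not the article's data set
X = [[1.5, 0.2], [1.5, 1.5], [3.9, 0.4], [4.6, 1.8], [4.7, 0.3], [5.1, 1.2]]
y = ["Rainy", "Sunny", "Rainy", "Sunny", "Rainy", "Sunny"]

f, t, g = find_best_split(X, y)
print(f, round(t, 2), round(g, 2))  # 1 0.8 0.0
```

Updating the class counts incrementally (rather than recounting both buckets on every move) keeps each attribute's sweep linear after sorting. This is the main reason to sweep sorted values instead of testing thresholds independently.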
The picture below summarizes the flow. Green circles represent class member Sunny, and Red triangles represent class member Rainy.
Example of how the best split happens
To make it clearer how the best split happens, consider the attribute "Humidity" in the 7th iteration, where the distribution is as below:
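The equation just after the distribution step is the CART cost function, or Gini weighted average, which gives the new Gini score (new_gini):
$$new\_gini = \frac{m_{left}}{m} G_{left} + \frac{m_{right}}{m} G_{right}$$
where m is the total number of observations in the node. The new_gini score is based on the class distribution in the left & right buckets: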
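- The left bucket: 2 Sunny & 5 Rainy
$$G_{left} = 1 - \left(\left(\frac{\text{num of Sunny observations}}{\text{total observations}}\right)^2 + \left(\frac{\text{num of Rainy observations}}{\text{total observations}}\right)^2\right)$$
$$G_{left} = 1 - \left(\left(\frac{2}{7}\right)^2 + \left(\frac{5}{7}\right)^2\right) = \frac{20}{49} \approx 0.41$$
- The right bucket: 2 Sunny & 1 Rainy
$$G_{right} = 1 - \left(\left(\frac{\text{num of Sunny observations}}{\text{total observations}}\right)^2 + \left(\frac{\text{num of Rainy observations}}{\text{total observations}}\right)^2\right)$$
$$G_{right} = 1 - \left(\left(\frac{2}{3}\right)^2 + \left(\frac{1}{3}\right)^2\right) = \frac{4}{9} \approx 0.44$$
$m_{left} = 7$, $m_{right} = 3$, $m = 10$
$$new\_gini = \frac{(7 \times 0.41) + (3 \times 0.44)}{10} \approx 0.42$$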
This new_gini is lower than best_gini (0.48), so we keep it as the best threshold value & attribute so far:
The Threshold list = [1.5, 1.5, 1.5, 1.6, 3.4, 3.9, 4.6, 4.7, 5.0, 5.1]
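The numbers in this worked example can be checked with a short script. It uses Python's standard `fractions` module so that no intermediate rounding creeps in. The bucket counts are the 7th-iteration Humidity buckets above, and the threshold is taken as the midpoint between the 7th & 8th sorted values. The `gini` helper name is hypothetical.

```python
from fractions import Fraction


def gini(sunny, rainy):
    """Exact Gini impurity of a bucket with `sunny` and `rainy` observations."""
    m = sunny + rainy
    return 1 - (Fraction(sunny, m) ** 2 + Fraction(rainy, m) ** 2)


# 7th iteration on Humidity: left = 2 Sunny & 5 Rainy, right = 2 Sunny & 1 Rainy
g_left, g_right = gini(2, 5), gini(2, 1)
m_left, m_right = 7, 3
m = m_left + m_right

# CART cost function: size-weighted average of the two buckets' Gini scores
new_gini = Fraction(m_left, m) * g_left + Fraction(m_right, m) * g_right
best_gini = gini(4, 6)  # root node: 4 Sunny & 6 Rainy -> 12/25 = 0.48

# Threshold = midpoint between the last value moved left and the next value
thresholds = [1.5, 1.5, 1.5, 1.6, 3.4, 3.9, 4.6, 4.7, 5.0, 5.1]
k = 7  # observations in the left bucket after the 7th iteration
best_thr = (thresholds[k - 1] + thresholds[k]) / 2

print(g_left, g_right, float(round(new_gini, 3)))  # 20/49 4/9 0.419
print(new_gini < best_gini, round(best_thr, 2))    # True 4.65
```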
Since we are on the 7th iteration, the best threshold value is the midpoint between the 7th & 8th values of the list: (4.6+4.7)/2 = 4.65
For now, the best feature & threshold (f,t) is (0, 4.65), where 0 is the index of the attribute "Humidity". If this stays the lowest Gini score, the algorithm will split the initial node based on the attribute "Humidity" & threshold value "4.65", which would look like the image below:
But it is not the best pair (f,t). The process above continues for all available attributes & keeps searching for a new lowest Gini score. Whenever it finds one, it keeps that threshold value & its attribute, and finally it splits the node based on the best attribute & threshold value. For our data set, the best Gini score is "0.40", for the "Wind" attribute (f) with "3.55" as the best threshold value (t). The tree below, generated by DecisionTreeClassifier from scikit-learn, shows the node split happening on the same attribute & threshold value:
Let's have a recap:
- We learnt what a decision tree is
- Different criterion metrics to create new nodes/subsets
- Gini Index
- How to calculate the Gini Index
- Example of how nodes get created based on an attribute (f) & threshold value (t)
- How we search for the best attribute & threshold value pair (f,t)
- How the (f,t) pair helps to create pure nodes
- Working code of a decision tree
Thank you & happy reading!
- Image 2 – An Introduction to Statistical Learning with Applications in R
- Image 5 – Hands-On Machine Learning with Scikit-Learn & TensorFlow
- Some Code ref