Decision tree for classification and regression using Python

Decision tree

Decision tree is a popular supervised machine learning algorithm, frequently used both to classify categorical data and to regress continuous data. In this article, we will learn how to implement decision tree classification and regression using the scikit-learn package of Python.

Decision tree classification helps to take vital decisions in many fields: in the banking and finance sectors, for example, whether a credit/loan should be given to a customer depending on his/her risk-bearing credentials; in medical testing, whether a new medicine should be tried on a patient depending on his/her medical history; and in many more.

In the two cases above, the target variable is binary, i.e. it has only two categories of response. The target variable can also have more than two categories, and the decision tree can be applied in such multinomial cases too. A decision tree can also handle both numerical and categorical data, so it gives its users a lot of flexibility.


Introduction to decision tree

Decision tree problems generally consist of some existing conditions which determine their categorical response. If we arrange the conditions, the decisions that depend on those conditions, and the further decisions that some of those decisions lead to, the whole decision-making structure resembles a tree. Hence the name decision tree.

The first and topmost condition, which initiates the decision-making process, is called the root condition (or root node). Each node branching out from it is called either a decision node, if it takes part in further decision making, or a leaf node, if it does not. In this way, the recursive process continues until all the elements are grouped into particular categories and the final nodes are all leaf nodes.
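The node types described above can be inspected in a fitted tree. Below is a minimal sketch that uses scikit-learn (the package used later in this article) with its bundled iris data set as a stand-in; the data set and `max_depth=2` are our choices for illustration, not part of the original example. In scikit-learn's `tree_` structure, node 0 is the root, and a node whose `children_left` entry is -1 has no children, i.e. it is a leaf.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Stand-in data: 150 iris flowers, 4 numeric features, 3 species
X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

structure = clf.tree_
is_leaf = structure.children_left == -1   # -1 means "no child", i.e. a leaf node

print("root node splits on feature", structure.feature[0],
      "at threshold", round(float(structure.threshold[0]), 2))
for node in range(structure.node_count):
    kind = "root" if node == 0 else ("leaf" if is_leaf[node] else "decision")
    print(f"node {node}: {kind} node, {structure.n_node_samples[node]} samples")
```

Every sample enters at the root, and each one ends up in exactly one leaf, which is where its category is decided.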

An example of decision tree

Here we can take the example of the recent COVID-19 pandemic and the testing of positive cases. We all know that the main problem with this disease is that it is highly infectious, so identifying COVID-positive patients and isolating them is essential to stop its further spread. This needs rigorous testing. But COVID testing is a time-consuming and resource-intensive process, and it becomes even more of a challenge for countries like India, with a population of 1.3 billion.

So, if we can identify which persons actually need testing, we can save a lot of time and resources by shrinking the testing population significantly right away. It is a kind of divide-and-conquer policy. See the decision tree below for classifying persons who need to be tested.

An example of decision tree

The whole classification process is very similar to how a human being judges a situation and makes a decision. That’s why this machine learning technique is simple to understand and easy to implement. Further, being a non-parametric approach, it makes no assumption about the distribution of the data, so it can be applied even when that distribution is not known.

A distinctive feature of a decision tree is that, unlike many other machine learning algorithms, it is a white box technique. That means the logic used in the classification process is visible to us. Thanks to its simple logic, the training time of this algorithm stays relatively short even when the data are large and high-dimensional. Moreover, the decision tree is the building block of advanced machine learning techniques such as random forest, bagging and gradient boosting.

Advantages of decision tree

  • A decision tree has the great advantage of being able to handle both numerical and categorical variables, whereas many other modelling techniques can handle only one kind of variable.
  • Little data preprocessing is required. Apart from dealing with missing values, steps such as data standardization or creating dummy variables for categorical data are not needed, which saves a lot of the user’s time. (scikit-learn’s implementation does, however, expect categorical features to be encoded as numbers.)
  • The model’s assumptions are not rigid, so it still works reasonably well when the data deviate slightly from them.
  • A decision tree model can be validated through statistical tests, so its reliability can be established easily.
  • As it is a white box model, the logic behind it is visible to us, and we can easily interpret the result, unlike with a black-box model such as an artificial neural network.
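The point about standardization can be checked directly. A minimal sketch, again using scikit-learn's iris data as a stand-in: a tree only asks whether a feature is above or below a threshold, and standardizing a feature shifts and rescales its values without changing their order, so a tree fitted on standardized data splits the samples in the same way as one fitted on raw data.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_scaled = StandardScaler().fit_transform(X)   # mean 0, standard deviation 1 per feature

raw_tree = DecisionTreeClassifier(random_state=0).fit(X, y)
scaled_tree = DecisionTreeClassifier(random_state=0).fit(X_scaled, y)

# Same splits of the samples (only the threshold values differ), hence the same predictions
same = np.array_equal(raw_tree.predict(X), scaled_tree.predict(X_scaled))
print("identical predictions:", same)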

No technique is without flaws, though; there is always a flip side, and the decision tree is no exception.

Disadvantages of Decision tree

  • A very serious problem with a decision tree is that it is highly prone to overfitting. That means a fully grown tree tends to become a too complex model that fits the training data, noise included, too closely and therefore generalizes poorly to new data.
  • Decision tree algorithms generally choose the locally optimal split at each node (a greedy approach). As this process is repeated recursively for each node, the whole process ends up with a locally optimal instead of a globally optimal decision tree.
  • The result obtained from a decision tree is very unstable: a small variation in the data can lead to a completely different classification/regression result. That’s why ensemble techniques such as random forest came in; they combine the results of a number of models instead of relying on a single one.
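The overfitting problem is easy to reproduce. In this minimal sketch the data are synthetic (`make_classification` with `flip_y=0.1`, which assigns the class of about 10% of the samples at random, i.e. label noise), and `max_depth=3` is an arbitrary example of limiting tree growth; the exact scores depend on the data.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic data: 10 features; flip_y=0.1 gives about 10% of samples a random class (label noise)
X, y = make_classification(n_samples=500, n_features=10, flip_y=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

full_tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)   # grown until leaves are pure
shallow_tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_train, y_train)

for name, model in [("unrestricted", full_tree), ("max_depth=3", shallow_tree)]:
    print(f"{name:>12}: depth={model.get_depth()}, "
          f"train accuracy={model.score(X_train, y_train):.3f}, "
          f"test accuracy={model.score(X_test, y_test):.3f}")
```

The unrestricted tree memorizes the training data, noisy labels included, so its training accuracy is perfect while its test accuracy is typically noticeably lower; that gap is a quick symptom of overfitting. Limiting `max_depth` (or setting `min_samples_leaf`) is one common remedy.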

Classification and Regression Tree (CART)

The decision tree has two main categories: the classification tree and the regression tree. Together they are referred to as CART (Classification And Regression Trees). The term was coined in 1984 by Leo Breiman, Jerome Friedman, Richard Olshen and Charles Stone.


When the response is categorical in nature, the decision tree performs classification, as in the examples given before: whether a person is sick or not, or whether a product passes or fails a quality test. In all these cases, the problem at hand is to assign each observation of the target variable to a group.

The target variable can be binomial, i.e. with only two categories like yes/no, male/female, sick/not sick etc., or multinomial, i.e. with more than two categories. An example of a multinomial variable is the economic status of people, with categories like very rich, rich, middle class, lower-middle class, poor, very poor etc. A benefit of the decision tree is that it can handle both binomial and multinomial variables.
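In scikit-learn, the same `DecisionTreeClassifier` handles both cases without any change in the code. A minimal sketch using two data sets bundled with the library as stand-ins: the breast cancer data (binomial target, 2 classes) and the iris data (multinomial target, 3 species).

```python
from sklearn.datasets import load_breast_cancer, load_iris
from sklearn.tree import DecisionTreeClassifier

# Binomial target: breast cancer data, 2 classes (0 = malignant, 1 = benign)
Xb, yb = load_breast_cancer(return_X_y=True)
binary = DecisionTreeClassifier(max_depth=3, random_state=0).fit(Xb, yb)

# Multinomial target: iris data, 3 species coded 0, 1, 2
Xm, ym = load_iris(return_X_y=True)
multi = DecisionTreeClassifier(max_depth=3, random_state=0).fit(Xm, ym)

print(binary.classes_)                      # [0 1]
print(multi.classes_)                       # [0 1 2]
print(multi.predict_proba(Xm[:2]).shape)    # (2, 3): one probability per class
```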


On the other hand, the decision tree is applied to regression problems when the target variable is continuous in nature. For example, consider predicting the rainfall on a future date from other weather parameters: here the target variable is continuous, so it is a regression problem.
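A regression tree predicts a continuous value by averaging: each leaf stores the mean of the training targets that fall into it. A minimal sketch on hypothetical toy data, with a single split (`max_depth=1`) so that the two leaf means can be checked by hand.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Hypothetical toy data: a continuous target that grows linearly with x
x = np.arange(10, dtype=float).reshape(-1, 1)   # 0, 1, ..., 9 as one feature column
y = 2 * x.ravel() + 1                           # 1, 3, 5, ..., 19

stump = DecisionTreeRegressor(max_depth=1).fit(x, y)   # one split, two leaves

# Predictions for one point on each side of the split...
print(stump.predict([[2.0], [7.0]]))
# ...equal the mean target of the training points in each leaf (5 and 15 here)
print(y[:5].mean(), y[5:].mean())
```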

Application of Decision tree with Python

Here we will use the scikit-learn package to implement the decision tree. The package has a class called DecisionTreeClassifier() which is capable of classifying both binomial (target variable with only two classes) and multinomial (target variable with more than two classes) variables.

Performing classification using decision tree

Importing required libraries

The first step in coding is to import all the libraries we are going to use. The basic libraries for any kind of data science project include pandas, numpy, matplotlib etc. The purpose of these libraries is discussed in detail in the article simple linear regression with python.

# importing libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

About the data

The example dataset used here for demonstration was collected by the “National Institute of Diabetes and Digestive and Kidney Diseases” and contains vital parameters of diabetes patients of Pima Indian heritage.

Here is a glimpse of the first ten rows of the data set:

First ten rows of the diabetes data set

The independent variables in the data set are several physiological parameters of a patient. The dependent variable indicates whether the patient is suffering from diabetes or not: the dependent column contains a binary variable, with 1 indicating that the person has diabetes and 0 that he/she does not.

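# Reading the data (hypothetical file name; use the path of your copy of the data set)
dataset = pd.read_csv("diabetes.csv")

# Printing data details
dataset.info()                    # for a quick view of the data: columns, types, non-null counts
print(dataset.head())             # printing first few rows of the data
print(dataset.tail())             # to show last few rows of the data
print(dataset.sample(10))         # display a sample of 10 rows from the data
print(dataset.describe())         # printing summary statistics of the data
print(pd.isnull(dataset).sum())   # count the null values in each column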

Creating variables

As we can see, the data frame contains nine variables in nine columns. The first eight columns contain the independent variables: physiological variables correlated with diabetes symptoms. The ninth column shows whether the patient is diabetic or not. So here x stores the independent variables and y stores the dependent variable, the diabetes outcome.
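The code for this step is not shown above. A minimal sketch, assuming the outcome is the ninth (last) column as described; because the diabetes file itself is not available here, a small hypothetical stand-in data frame with random values and made-up column names takes its place.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the diabetes data frame: eight feature columns
# with random values, followed by a 0/1 outcome column in ninth position
rng = np.random.default_rng(0)
dataset = pd.DataFrame(rng.normal(size=(20, 8)),
                       columns=[f"feature_{i}" for i in range(1, 9)])
dataset["outcome"] = rng.integers(0, 2, size=20)

# First eight columns: independent variables; ninth column: dependent variable
x = dataset.iloc[:, 0:8]
y = dataset.iloc[:, 8]
print(x.shape, y.shape)   # (20, 8) (20,)
```

With the real data loaded into `dataset`, only the two `iloc` lines are needed.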


Performing the classification

To do the classification, we need to import DecisionTreeClassifier() from sklearn. This classifier is capable of classifying binary variables, i.e. variables with only two classes, as well as multiclass variables.

# Use of the classifier
from sklearn import tree
clf = tree.DecisionTreeClassifier()
clf = clf.fit(x, y)   # train the tree on the independent (x) and dependent (y) variables
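The snippet above trains on the whole data set, so it says nothing about how well the tree classifies new patients. Below is a minimal, self-contained sketch of a held-out evaluation. It uses scikit-learn's bundled breast cancer data as a stand-in binary problem (an assumption, since the diabetes file is not bundled with the library), and its 80/20 split mirrors the one used in the regression part of this article.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Stand-in binary data set bundled with scikit-learn (0 = malignant, 1 = benign)
x, y = load_breast_cancer(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0)

clf = DecisionTreeClassifier(random_state=0).fit(x_train, y_train)
y_pred = clf.predict(x_test)   # predictions for data the tree has never seen

print("accuracy on held-out data:", round(accuracy_score(y_test, y_pred), 3))
print(confusion_matrix(y_test, y_pred))   # rows: actual class, columns: predicted class
```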

Plotting the tree

Now that the model is ready, we can plot the tree. The line below creates the plot:

tree.plot_tree(clf)
Generally, the plot created this way has a low resolution and gets distorted when used as an image. One solution to this problem is to save it in PDF format, so that the resolution is maintained.

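# The decision tree creation, saved as a PDF to keep the resolution (example file name)
tree.plot_tree(clf, filled=True)
plt.savefig("decision_tree.pdf")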

Another way to produce a high-resolution, good-quality image of the tree is to use the Graphviz format, importing export_graphviz() from tree.

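# Creating better graph (needs the graphviz Python package and the Graphviz software)
import graphviz
dot_data = tree.export_graphviz(clf, out_file=None, filled=True,
                                class_names=["non-diabetic", "diabetic"])
graph = graphviz.Source(dot_data)
graph.render("diabetes_tree")   # writes diabetes_tree.pdf (PDF is the default format)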
Decision tree created using Graphviz to classify the data

The tree represents the logic of classification in a very simple way. We can easily understand how the data has been classified and the steps to achieve that.
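Besides the picture, the same logic can be printed as plain-text rules with scikit-learn's `export_text()`, which needs neither matplotlib nor Graphviz. A minimal sketch, with the iris data standing in for the diabetes data:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# One line per condition; indentation shows the depth, and "class:" lines are the leaves
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```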

Performing regression using decision tree

About the data set

The dataset used here for demonstration contains the height and weight of persons and a column with their gender. The original dataset has several thousand rows, but for this regression I have used only the first 50 rows, containing data on 25 males and 25 females.

Importing libraries

In addition to the basic libraries we imported for the classification problem, here we need to import DecisionTreeRegressor() from sklearn.

# Import the necessary modules and libraries
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt

Reading the dataset

I have already described the dataset used here for demonstration. The code below imports the data and stores it in a dataframe called dataset.

dataset = pd.read_csv("height_weight.csv")  # hypothetical file name; use the path of your copy of the data
Here is a glimpse of the dataset

Height and weight data set

Creating variables

As we can see, the dataframe contains three variables in three columns. Only the last two columns are of interest to us: we want to regress the weight of a person on his/her height. So here the independent variable height is x and the dependent variable weight is y.
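The code for this step is also not shown. A minimal sketch, assuming the column order described above (gender, height, weight); since the original file is not available here, a hypothetical stand-in with made-up values and column names is generated first. Height is taken as a two-dimensional slice because scikit-learn expects its input features as a 2-D array.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in with the same layout as the article's data:
# a gender column followed by height and weight (values are made up)
rng = np.random.default_rng(1)
height = rng.uniform(58, 75, size=50)
dataset = pd.DataFrame({
    "Gender": ["Male"] * 25 + ["Female"] * 25,
    "Height": height,
    "Weight": 4 * height - 100 + rng.normal(0, 8, size=50),
})

# The slice 1:2 keeps height as a 2-D array, which scikit-learn expects for features
x = dataset.iloc[:, 1:2].values   # height, shape (50, 1)
y = dataset.iloc[:, 2].values     # weight, shape (50,)
print(x.shape, y.shape)           # (50, 1) (50,)
```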


Splitting the dataset

It is common practice to split the whole data set into a training and a testing data set. Here we have set test_size to 20%, which means the training data set will consist of 80% of the total data. The test data set serves as an independent data set for testing the model after it has been trained with the training data.

# Splitting the data for training and testing
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test=train_test_split(x,y, test_size=0.20, random_state=0)

Fitting the decision tree regression

Here we fit decision tree regression with two different depth values to draw a comparison between them.

# Creating regression models with two different depths
regr_1 = DecisionTreeRegressor(max_depth=2)
regr_2 = DecisionTreeRegressor(max_depth=5)
regr_1.fit(x_train, y_train)
regr_2.fit(x_train, y_train)
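The depth limit controls how detailed the fitted curve can be: a tree of depth d has at most 2^d leaves, and a regression tree predicts one constant value per leaf, so its prediction line is a step function with at most 2^d levels. A minimal sketch on hypothetical noisy data (not the height-weight data) that counts leaves and distinct predictions for the two depths used above:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Hypothetical noisy data in which y tends to rise with x
rng = np.random.default_rng(0)
x = np.sort(rng.uniform(50, 75, size=80)).reshape(-1, 1)
y = 4 * x.ravel() - 100 + rng.normal(0, 8, size=80)

grid = np.arange(50, 75, 0.5)[:, np.newaxis]   # same kind of grid as X_test in the article
results = {}
for depth in (2, 5):
    model = DecisionTreeRegressor(max_depth=depth).fit(x, y)
    n_leaves = model.get_n_leaves()
    n_levels = len(np.unique(model.predict(grid)))   # distinct values on the step curve
    results[depth] = (n_leaves, n_levels)
    print(f"max_depth={depth}: {n_leaves} leaves, {n_levels} distinct predicted values")
```

The deeper tree follows the data more closely, which is also what makes it more prone to overfitting.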


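The lines of code below give predictions from both regression models for a new set of heights, X_test, running from 50 up to (but not including) 75 in steps of 0.5. Python is case-sensitive, so this X_test is separate from the x_test created by the split.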

# Making prediction
X_test = np.arange(50,75, 0.5)[:, np.newaxis]
y_1 = regr_1.predict(X_test)
y_2 = regr_2.predict(X_test)

Visualizing prediction performance

The lines of code below generate a height vs. weight scatter plot along with the two prediction lines from the two regression models.

# Plot the results
plt.scatter(x, y, s=20, edgecolor="black",
            c="darkorange", label="data")
plt.plot(X_test, y_1, color="cornflowerblue",
         label="max_depth=2", linewidth=2)
plt.plot(X_test, y_2, color="yellowgreen",
         label="max_depth=5", linewidth=2)
plt.xlabel("Height")
plt.ylabel("Weight")
plt.title("Decision Tree Regression")
plt.legend()
plt.show()


In this post, you have learned about the decision tree and how it can be applied to classification as well as regression problems using Python’s scikit-learn.

The decision tree is a popular supervised machine learning algorithm frequently used by data scientists. Its simple logic and easy-to-follow algorithm are the main reasons behind its popularity. Being a white box algorithm, it lets us clearly understand how it does its work.

The DecisionTreeClassifier() and DecisionTreeRegressor() classes of scikit-learn are two very useful tools for applying decision trees, and I hope you feel confident about using them after reading this article.

If you have any questions regarding this article or any confusion about its application in Python, post them in the comments below and I will try my best to answer them.


Supervised Machine Learning: a beginner’s guide

Supervised Machine Learning

The most common type of Machine Learning is Supervised Machine Learning. The name comes from the fact that the learning process is supervised by results that are already known. The learning process goes through several iterations and continues until the difference between the actual and the estimated results falls below an acceptable level.

“Computers are able to see, hear and learn. Welcome to the future.”

~Dave Waters. Department of Earth Sciences, University of Oxford Associate Professor of Metamorphic Petrology (retired)

The data used in supervised machine learning are called “labelled data” because they are already tagged with the right answer. Once the training part is complete and a robust model is achieved, new inputs are provided. The task of the model is then to predict the labels of these unseen inputs based on the labelled data used before.

In mathematical notation, this can be represented as an output variable Y that is a function of an input variable X:

$$Y = f(X)$$

During the training phase of supervised machine learning, both X and Y are known; what remains unknown is the mapping function f. The algorithm tries to find the mapping function that predicts Y most precisely.
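To make the idea of iteratively learning f concrete, here is a minimal sketch that assumes f is a straight line, f(X) = wX + b, and uses gradient descent as the learning procedure (the article does not prescribe a particular algorithm). The labelled data are hypothetical, generated from Y = 3X + 2 plus noise. Training stops once an iteration no longer reduces the mean squared error by more than a tiny tolerance, which is one way of deciding that the error has reached an acceptable level.

```python
import numpy as np

# Hypothetical labelled data: inputs X with known answers Y,
# generated from Y = 3X + 2 plus random noise
rng = np.random.default_rng(0)
X = rng.uniform(0, 10, size=100)
Y = 3 * X + 2 + rng.normal(0, 0.5, size=100)

# Start from a guess f(X) = w*X + b and improve it iteration by iteration
w, b = 0.0, 0.0
learning_rate = 0.01
previous_mse = np.inf
for step in range(20_000):
    error = (w * X + b) - Y                 # estimated minus actual result
    mse = np.mean(error ** 2)
    if previous_mse - mse < 1e-10:          # error has stopped improving: accept the model
        break
    previous_mse = mse
    # Move w and b against the gradient of the mean squared error
    w -= learning_rate * 2 * np.mean(error * X)
    b -= learning_rate * 2 * np.mean(error)

print(f"learned f(X) = {w:.2f}*X + {b:.2f} after {step} iterations (MSE {mse:.3f})")
```

The learned w and b should come out close to 3 and 2, the values used to generate the data; the remaining error reflects the noise in the labels.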

Example of Supervised machine learning

You must have come across the term pattern recognition from some online or offline source. It is something of a buzzword today and is used to make our lives more sophisticated and comfortable. From very simple applications like your smartphone’s face recognition or handwriting recognition to advanced uses like cancer cell detection, supervised learning is the essence of pattern recognition.

Its simple applications, be it your smartphone’s face lock feature, handwriting recognition or voice recognition, are already making our lives easier. The self-driving car concept also depends heavily on supervised learning, and you can find this theory at work in almost every sector of industry nowadays.

An application in agriculture

To understand how this system works, we will take an example of its application in the agricultural field.

Application of supervised machine learning
Photo by Roman Synkevych on Unsplash

Predicting the crop yield well before harvest is essential for proper policy planning. A precise prediction of the expected production helps the government fix prices and provide better storage for the produce, and it helps farmers plan their marketing channels.

Crop yield is determined by several factors: physical parameters of the crop itself, like crop height and number of tillers; weather parameters, like rainfall, humidity and sunshine hours; and soil health factors, like carbon balance and organic matter. These and several others play an important role and contribute to the ultimate yield.

If we have a sufficient amount of labelled data, i.e. a set of data containing all these independent variables affecting the yield along with the corresponding yield, we can train the algorithm with this training dataset. This is supervised learning: it is as if the learning process has been supervised by a teacher.

The learning process stops only when a robust model is achieved and the prediction accuracy reaches an acceptable level.

A real-world problem solved by Supervised Machine learning

Here I am going to cite an example of supervised learning in modern research and how it is being used to address complex problems of the real world.

A project was taken up by a group of scientists to identify endangered species of the Mojave Desert in California. The main objective of the study was to locate two threatened species of the area, the Mohave ground squirrel and the desert tortoise, by analyzing images captured by smartphones.

The challenge faced by the biologists was to track and rescue these two endangered species, as they are very hard to spot. Nature has given them such an ability to camouflage themselves against the desert background and vegetation that they are almost impossible for the human eye to see.

So here the scientists used computer vision and developed a machine learning algorithm to identify the pattern, distinguish it from the desert backdrop and classify the animals according to their characteristics.

Types of supervised machine learning

There are two main categories of supervised machine learning.

  • Classification
  • Regression 
Supervised Machine Learning, its categories and popular algorithms


Classification is applicable when the target variable is categorical and the objective is to classify it. If the algorithm classifies into two classes, it is called binary classification; if the number of classes is more than two, it is called multiclass classification.

Classification in Supervised Machine Learning

In the given figure, a binary classification is demonstrated: a group of people has been classified by gender based on a dataset consisting of their height and weight.

The task is done in the same way as discussed before. First, the algorithm is trained with a dataset in which each record has an assigned category. Then, based on this training, the algorithm categorizes new input data.

Example of classification

The most common example of a classification problem is identifying whether a new mail is spam or not spam; identifying loan defaulters is also a classification problem.

To train the algorithm, it is provided with a dataset of mails and a corresponding column indicating whether each one is spam or not spam. Similarly, for loan defaulters, the algorithm is first trained with a list of customers labelled as defaulters or not. The supervised learning model is then used to identify the type of customer in an independent input dataset.
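The spam case can be sketched in a few lines. This is a minimal illustration, not a realistic filter: the six labelled mails are hypothetical, the text is turned into word counts with scikit-learn's `CountVectorizer`, and a multinomial Naive Bayes classifier (one of the algorithms listed below) is our choice of model.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB

# Hypothetical labelled mails: the text and the known answer for each one
mails = [
    "win a free prize now",
    "free money click now",
    "claim your free prize",
    "meeting at noon tomorrow",
    "please review the attached report",
    "lunch with the team tomorrow",
]
labels = ["spam", "spam", "spam", "not spam", "not spam", "not spam"]

vectorizer = CountVectorizer()                 # turns each mail into word counts
features = vectorizer.fit_transform(mails)
model = MultinomialNB().fit(features, labels)  # training on the labelled data

new_mails = ["free prize for you now", "team meeting tomorrow"]
predictions = model.predict(vectorizer.transform(new_mails))
print(predictions)
```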

There are a number of algorithms for classification. The most popular ones are

  • Naive Bayes classifier
  • Linear classifier
  • Support vector machine
  • Random forest
  • Decision tree
  • K-Nearest neighbour


Regression is a statistical process which tries to find the relationship between the dependent and independent variables. The major difference from classification is that in regression the target variable is continuous.

If the regression equation between a single independent variable and the dependent variable is linear, it is a simple linear regression equation. If the regression equation of Y on X is linear, it does not necessarily follow that the regression equation of X on Y is also linear, and vice versa. The dependent variable is a function of the independent variables with respective constant parameters, plus an error term which is a random variable. A regression model has the expression:

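$$Y = f(X_1, X_2, \ldots, X_n;\ \beta_0, \beta_1, \ldots, \beta_n) + \epsilon$$

where Y is the dependent variable, $X_1, X_2, \ldots, X_n$ are the independent variables, $\beta_0, \beta_1, \ldots, \beta_n$ are the regression coefficients, and $\epsilon$ is the error term, normally distributed with mean 0 and variance $\sigma^2$. Because it contains this random error term, such a regression model is a probabilistic (stochastic) model rather than a purely deterministic one.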
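When f is linear, the model becomes $Y = \beta_0 + \beta_1 X_1 + \cdots + \beta_n X_n + \epsilon$, and the coefficients can be estimated from labelled data by ordinary least squares. A minimal sketch with two independent variables: the data are simulated from hypothetical coefficients (1.5, 2.0 and -0.5) with normally distributed errors, and NumPy's `lstsq` is used to recover them.

```python
import numpy as np

# Hypothetical data simulated from a linear model with two independent variables:
# Y = 1.5 + 2.0*X1 - 0.5*X2 + epsilon, with epsilon ~ Normal(mean 0, sd 0.3)
rng = np.random.default_rng(42)
n_samples = 200
X1 = rng.uniform(0, 10, n_samples)
X2 = rng.uniform(0, 10, n_samples)
epsilon = rng.normal(0.0, 0.3, n_samples)
Y = 1.5 + 2.0 * X1 - 0.5 * X2 + epsilon

# Design matrix: a column of ones (for beta_0) followed by X1 and X2
A = np.column_stack([np.ones(n_samples), X1, X2])
beta_hat, *_ = np.linalg.lstsq(A, Y, rcond=None)   # least-squares estimates
print("estimated beta_0, beta_1, beta_2:", np.round(beta_hat, 2))
```

The estimates should land close to 1.5, 2.0 and -0.5; the small remaining discrepancy comes from the random error term.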

Example of regression

Regression in Supervised Machine Learning

An example of simple linear regression is regressing the weight of a group of people on their height. Here height and weight are the independent and dependent variables respectively, since a person’s height influences his/her weight, not the other way round.

The blue line in the above figure is the regression line fitted with a supervised machine learning technique. This represents the best-fitted line obtained through a rigorous training process until a robust model with acceptable accuracy is achieved.

Researchers use a number of algorithms to perform regression. The most frequently used ones are:

  • Simple linear regression
  • Multiple linear regression
  • Logistic regression (despite its name, mostly used for classification)
  • Polynomial regression etc.