Simple Decision Tree Python Program
In this article I will show you how to create your own Machine Learning program to classify a car as ‘unacceptable’, ‘acceptable’, ‘good’, or ‘very good’, using a Machine Learning (ML) algorithm called a Decision Tree and the Python programming language!
Decision Trees are a type of supervised learning algorithm (meaning that they are given labeled data to train on). The training data is repeatedly split into two or more sub-nodes according to a certain parameter. A tree can be described by two kinds of parts: decision nodes and leaves. The decision nodes are where the data is split. The leaves are the decisions, or final outcomes. In programming terms, you can think of a decision tree as a chain of “if statements”, one per decision node, that you follow until you reach a leaf node (the final outcome).
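To make the “if statements” idea concrete, here is a minimal hand-written sketch. The function name, the two features, and the thresholds are all hypothetical and chosen only for illustration (I am assuming the features are encoded as small integers); a trained model would learn its own splits from the data.

```python
def classify_car(safety: int, persons: int) -> str:
    """Hand-written stand-in for a tiny decision tree (hypothetical rules)."""
    # Decision node 1: split on the safety rating
    if safety <= 1:
        return "unacc"  # leaf
    # Decision node 2: split on the passenger capacity
    if persons <= 2:
        return "unacc"  # leaf
    return "acc"  # leaf


print(classify_car(safety=3, persons=4))  # prints: acc
```

Each `if` plays the role of a decision node and each `return` plays the role of a leaf. Training a decision tree amounts to choosing those conditions automatically.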
Decision Tree Pros:
- Simple to understand and to interpret
- Requires little data preparation
Decision Tree Cons:
- Prone to over-fitting
- Decision trees can be unstable (a small variation in the data may result in a completely different tree being generated)
If you prefer not to read this article and would like a video version of it, you can check out the video below. It goes through everything in this article in a little more detail, and will help make it easy for you to start programming your own Decision Tree Machine Learning model. Or you can use both as supplementary materials for learning about Decision Trees!
The original data set is the car evaluation data set from http://archive.ics.uci.edu/ml/datasets/Car+Evaluation.
More specifically, it is distributed as a .data file, which you can download from http://archive.ics.uci.edu/ml/machine-learning-databases/car/car.data.
We will classify the ‘values’ column, which holds the quality rating of each car. I have already converted every value in the original data set to an integer, except for our dependent variable in the ‘values’ column, and saved the result as a new data set called ‘car_integer_exceptY.csv’. Get the ‘car_integer_exceptY.csv’ data set here:
Each attribute/feature is described below:
# buying (buying price): vhigh (4), high (3), med (2), low (1)
# main (maintenance price): vhigh (4), high (3), med (2), low (1)
# doors (number of doors): 2, 3, 4, 5-more (5)
# persons (number of passengers fit in a car): 2, 4, more (6)
# lug_boot (size of luggage capacity): small (1), med (2), big (3)
# safety: low (1), med (2), high (3)
# values: unacc = unacceptable, acc = acceptable, good = good, vgood = very good
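The mapping above can be sketched in plain Python. This is only one way the conversion might be done, not necessarily how I produced the file; the dictionary names and the `encode_row` helper are hypothetical, and I am assuming the raw file spells the last door option `5more` and the last persons option `more`.

```python
# Integer codes taken from the attribute list above.
ORDINAL = {"low": 1, "med": 2, "high": 3, "vhigh": 4}   # buying, main, safety
LUG_BOOT = {"small": 1, "med": 2, "big": 3}
DOORS = {"2": 2, "3": 3, "4": 4, "5more": 5}            # assumed raw spelling
PERSONS = {"2": 2, "4": 4, "more": 6}                   # assumed raw spelling


def encode_row(raw_row):
    """Convert one raw row of strings to integers, leaving 'values' as text."""
    buying, main, doors, persons, lug_boot, safety, values = raw_row
    return [
        ORDINAL[buying],
        ORDINAL[main],
        DOORS[doors],
        PERSONS[persons],
        LUG_BOOT[lug_boot],
        ORDINAL[safety],
        values,  # the dependent variable stays a string label
    ]


print(encode_row("vhigh,high,2,2,med,high,unacc".split(",")))
# prints: [4, 3, 2, 2, 2, 3, 'unacc']
```

Keeping the codes in order (low < med < high < vhigh) matters, because the tree splits on thresholds such as "safety <= 1.5".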
First import the packages or dependencies that will make it easier to write this program.
# Import the dependencies / libraries
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
Load the data by storing the car data set into a variable called ‘df’ as a dataframe. You can get the data set here.
#Create a dataframe from the cars dataset / csv file
df = pd.read_csv('DataSets/Cars/car_integer_exceptY.csv')
Print the first 5 rows of the data.
#print the first 5 rows of the data set
print(df.head())
Split your data into the independent variable(s) and dependent variable.
# Split your data into the independent variable(s) and dependent variable
X_train = df.loc[:,'buying':'safety'] #Gets all the rows in the dataset from column 'buying' to column 'safety'
Y_train = df.loc[:,'values'] #Gets all of the rows in the dataset from column 'values'
Create the Decision Tree model with at most 3 leaves!
# The actual decision tree classifier
tree = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
Train the model by using the fit method.
# Train the model
tree.fit(X_train, Y_train)
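When the fit method runs, the tree looks for splits that make the resulting groups as pure as possible. By default, scikit-learn's DecisionTreeClassifier scores candidate splits with Gini impurity. The sketch below computes that score by hand on a few made-up labels; the helper names `gini` and `split_impurity` are hypothetical, and this is a simplified illustration rather than scikit-learn's actual implementation.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = Counter(labels)
    return 1.0 - sum((count / n) ** 2 for count in counts.values())


def split_impurity(left, right):
    """Weighted average impurity of the two groups produced by a split."""
    n = len(left) + len(right)
    return len(left) / n * gini(left) + len(right) / n * gini(right)


# Hypothetical labels, made up for illustration only
labels = ["unacc", "unacc", "unacc", "acc"]
print(gini(labels))  # prints: 0.375
print(split_impurity(["unacc", "unacc", "unacc"], ["acc"]))  # prints: 0.0
```

A split that lowers the weighted impurity the most is preferred; an impurity of 0.0 means every group contains a single class.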
Make your prediction on the input buying=vhigh, main=high, doors=2, persons=2, lug_boot=med, safety=high, which translates to the integer input 4,3,2,2,2,3.
# Make your prediction
# input: buying=vhigh, main=high, doors=2, persons=2, lug_boot=med, safety=high
# integer conversion of input: 4,3,2,2,2,3
prediction = tree.predict([[4,3,2,2,2,3]])
Print the prediction and notice that we get back a result of [‘unacc’], meaning unacceptable, and the program is done!
#Print the prediction
print('Printing the prediction: ')
print(prediction)
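If you want to see the “if statements” that a fitted tree actually learned, scikit-learn can print them with `export_text`. The sketch below is self-contained, so it uses six toy rows that I made up purely for illustration (they are not the real car data set); with the real data you would pass the fitted model and its own column names instead.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical toy rows in the same integer encoding as the article's CSV.
# These six rows are made up for illustration; they are NOT the real data set.
toy = pd.DataFrame(
    {
        "buying":   [4, 4, 1, 1, 2, 3],
        "main":     [3, 4, 1, 2, 2, 3],
        "doors":    [2, 2, 4, 4, 3, 2],
        "persons":  [2, 2, 4, 6, 4, 2],
        "lug_boot": [2, 1, 3, 3, 2, 1],
        "safety":   [3, 1, 3, 3, 2, 1],
        "values":   ["unacc", "unacc", "acc", "acc", "acc", "unacc"],
    }
)

X = toy.loc[:, "buying":"safety"]
y = toy.loc[:, "values"]

# Same settings as the classifier in the article
model = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
model.fit(X, y)

# export_text renders the fitted tree as nested, human-readable rules
rules = export_text(model, feature_names=list(X.columns))
print(rules)
print("Number of leaves:", model.get_n_leaves())
```

Because of `max_leaf_nodes=3`, the printed tree can never have more than three leaves, which keeps it easy to read but also limits how much detail the model can capture.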
You can see the video above for how I coded this program and code along with me with a few more detailed explanations, or you can just click the YouTube link here.
If you are also interested in reading more on machine learning and want to get started right away with problems and examples, then I strongly recommend you check out Hands-On Machine Learning with Scikit-Learn and TensorFlow: Concepts, Tools, and Techniques to Build Intelligent Systems. It is a great book for helping beginners learn how to write machine learning programs and understand machine learning concepts.
Thanks for reading this article; I hope it's helpful to you all! If you enjoyed this article and found it helpful, please leave some claps to show your appreciation. Keep up the learning, and if you like machine learning, mathematics, computer science, programming or algorithm analysis, please visit and subscribe to my YouTube channels (randerson112358 & compsci112358).