Let’s Write a Decision Tree Classifier from Scratch - Machine Learning Recipes #8

Google for Developers
13 Sept 2017 · 09:52

Summary

TL;DR: In this educational video, Josh Gordon teaches viewers how to build a decision tree classifier from scratch in pure Python. He introduces a toy dataset for predicting fruit types from attributes like color and size, and covers decision tree learning, Gini impurity, and information gain along the way. The code is available in Jupyter notebook and Python file formats, and the video encourages viewers to swap in their own datasets for hands-on learning.

Takeaways

  • 🌳 The tutorial focuses on building a decision tree classifier from scratch using pure Python.
  • 📊 The dataset used is a toy dataset with both numeric and categorical attributes, aiming to predict fruit types based on features like color and size.
  • 📝 The data set is intentionally not perfectly separable to demonstrate how the tree handles overlapping examples.
  • 🔍 The CART algorithm is introduced for decision tree learning, standing for Classification and Regression Trees.
  • 📉 Gini impurity is explained as a metric for quantifying the uncertainty or impurity at a node, with lower values indicating less mixing of labels.
  • 🌐 Information gain is discussed as a concept for selecting the best question to ask at each node, aiming to reduce uncertainty.
  • 🔑 The process of partitioning data into subsets based on true or false responses to a question is detailed.
  • 🛠️ Utility functions are provided to assist with data manipulation, and demos are included to illustrate their usage.
  • 🔄 Recursion is used in the build tree function, which calls itself on each branch to keep splitting the data and grow the tree structure.
  • 📚 The video concludes with suggestions for further learning and encourages viewers to apply the concepts to their own datasets.

Q & A

  • What is the main topic of the video?

    -The main topic of the video is building a decision tree classifier from scratch in pure Python.

  • What dataset is used in the video to demonstrate the decision tree classifier?

    -A toy dataset with both numeric and categorical attributes is used, where the goal is to predict the type of fruit based on features like color and size.

  • What is the purpose of the dataset not being perfectly separable?

    -The dataset is not perfectly separable to demonstrate how the decision tree handles cases where examples have the same features but different labels.

  • What utility functions are mentioned in the script to work with the data?

    -The script mentions utility functions that make it easier to work with the data, with demos provided to show how they work.

  • What does CART stand for and what is its role in building the decision tree?

    -CART stands for Classification and Regression Trees, which is an algorithm used to build trees from data by deciding which questions to ask and when.

  • How does the decision tree algorithm decide which question to ask at each node?

    -The decision tree algorithm decides which question to ask at each node by calculating the information gain and choosing the question that produces the most gain.

  • What is Gini impurity and how is it used in the decision tree?

    -Gini impurity is a metric that quantifies the uncertainty, or mixing of labels, at a node. Combined with information gain, it is used to evaluate candidate questions and pick the best one to ask at each point in the decision tree.

  • How is information gain calculated in the context of the decision tree?

    -Information gain is calculated by starting with the uncertainty of the initial set, partitioning the data based on a question, calculating the weighted average uncertainty of the child nodes, and subtracting this from the starting uncertainty (a worked example follows this Q&A section).

  • What is the role of recursion in building the decision tree?

    -Recursion plays a role in building the decision tree by allowing the build tree function to call itself to add nodes for both the true and false branches of the tree.

  • How does the video conclude and what is the recommendation for the viewers?

    -The video concludes by encouraging viewers to modify the tree to work with their own datasets as a way to build a simple and interpretable classifier for their projects.
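
To make the information-gain arithmetic described above concrete, here is a small worked example; the numbers are illustrative and not taken from the video's dataset:

    starting uncertainty (Gini impurity of the parent node) = 0.64
    left child:  6 of 10 rows, impurity 0.50
    right child: 4 of 10 rows, impurity 0.00
    weighted average impurity = 0.6 * 0.50 + 0.4 * 0.00 = 0.30
    information gain          = 0.64 - 0.30 = 0.34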

Outlines

00:00

🌳 Introduction to Building a Decision Tree Classifier

Josh Gordon introduces a tutorial on constructing a decision tree classifier from scratch using Python. The episode's agenda includes an introduction to the dataset, a preview of the completed tree, and a step-by-step guide to building the tree. Key concepts such as decision tree learning, Gini impurity, and information gain are discussed. The dataset, designed to predict fruit types based on attributes like color and size, is intentionally not perfectly separable to demonstrate the tree's handling of such cases. The code for the episode is available in both Jupyter notebook and Python file formats. The tutorial encourages viewers to experiment with their own datasets and features.

05:01

📊 Understanding Gini Impurity and Information Gain

The script delves into the metrics used to build an effective decision tree: Gini impurity and information gain. Gini impurity measures the uncertainty at a node, with lower values indicating less mixing of labels. It's illustrated with examples, including a scenario with no uncertainty (impurity of zero) and one with high uncertainty (impurity of 0.8). Information gain quantifies how much a question reduces uncertainty, calculated by comparing the initial uncertainty with the weighted average uncertainty of the partitions created by the question. The process involves iterating over feature values to generate questions, partitioning data, and calculating information gain. The script includes code examples and demos to clarify these concepts and concludes with a walkthrough of the algorithm used to build the tree, emphasizing the recursive nature of the process.

Keywords

💡Decision Tree Classifier

A decision tree classifier is a type of machine learning model that uses a tree-like graph of decisions to predict the value of a target variable based on its input features. In the video, the host demonstrates how to build a decision tree classifier from scratch using Python. The classifier is used to predict the type of fruit based on attributes like color and size, illustrating the practical application of this machine learning technique.

💡Gini Impurity

Gini impurity is a measure used in decision trees to quantify the uncertainty or 'impurity' of a node during training. It is the probability that a randomly chosen example from the node would be labeled incorrectly if it were labeled at random according to the distribution of labels in that node. In the script, Gini impurity is used, together with information gain, to find the split that yields the purest subsets, guiding the construction of the decision tree.
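
As a quick reference, for a node whose labels occur with proportions p1, ..., pk, the Gini impurity is 1 - (p1^2 + ... + pk^2). For example (illustrative counts, not from the video), a node holding 3 apples and 1 grape has impurity 1 - (0.75^2 + 0.25^2) = 0.375, while a node holding only apples has impurity 0.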

💡Information Gain

Information gain refers to the reduction in uncertainty or impurity that results from partitioning a dataset based on an attribute. It is a key concept in the construction of decision trees, as it helps in selecting the best feature to split the data at each node. In the video, information gain is used to decide which questions to ask at each node in order to create a tree that accurately classifies the input data.

💡CART

CART stands for Classification and Regression Trees, which is an algorithm used to construct binary trees from data. It is a widely used method for building decision trees. In the script, CART is introduced as the algorithm that will be used to build the decision tree classifier, emphasizing its role in handling both classification and regression tasks.

💡Numeric and Categorical Attributes

Numeric and categorical attributes are types of features that describe data. Numeric attributes are quantitative, such as size or weight, while categorical attributes are qualitative, such as color or type. The script mentions that the toy dataset includes both types of attributes, which are used as inputs for the decision tree classifier to predict the type of fruit.

💡Toy Dataset

A toy dataset is a small, simplified dataset used for demonstration or testing purposes. In the video, the host has created a toy dataset with both numeric and categorical attributes to illustrate the process of building a decision tree classifier. This dataset is used to predict the type of fruit, and it is intentionally not perfectly separable to show how the tree handles such cases.

💡Feature

In the context of machine learning, a feature is an individual measurable property or characteristic of a phenomenon being observed. Features are used as inputs for the model. The script discusses features such as color and size, which are used to describe the data and help the decision tree classifier make predictions.

💡Label

In machine learning, a label is the target variable that the model is trying to predict. It represents the correct answer or value associated with a given set of features. In the video, the label is the type of fruit that the decision tree classifier aims to predict based on the input features.
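
For instance, in a row such as ['Green', 3, 'Apple'] (values assumed here for illustration), 'Green' and 3 are the features describing the example, and 'Apple' is the label the classifier tries to predict.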

💡Recursion

Recursion is a method in programming where a function calls itself in order to solve a problem. In the context of the video, recursion is used in the 'build tree' function to construct the decision tree. The function calls itself with subsets of the data until it reaches a base case where no further questions can be asked, and a leaf node is added.

💡Leaf Node

A leaf node is a terminal node in a decision tree that does not have any child nodes. It represents a decision or prediction made by the tree. In the script, leaf nodes are created when there are no further questions to ask, and the node makes a prediction based on the majority label in the subset of data it represents.

Highlights

Introduction to building a decision tree classifier from scratch in Python.

Outline of topics covered including dataset introduction, preview of the completed tree, and building process.

Availability of code in two formats: Jupyter notebook and regular Python file.

Description of a toy dataset with numeric and categorical attributes for predicting fruit type.

Encouragement for viewers to modify the dataset and apply the tree to their own problems.

Explanation of dataset format with examples and attributes.

Discussion on the dataset's non-separability to demonstrate tree handling of conflicting examples.

Introduction of utility functions and their demonstration for easier data manipulation.

Overview of the CART algorithm for decision tree learning.

Description of the tree-building process starting with a root node and partitioning data.

Explanation of how to quantify uncertainty at a node using Gini impurity.

Introduction to the concept of information gain for selecting the best questions to ask.

Detailed walk-through of the algorithm's recursive nature in building the tree.

Demonstration of how to handle data partitioning based on questions asked.

Calculation of Gini impurity to measure node uncertainty.

Explanation of information gain and its role in choosing the best question to ask.

Practical implementation of the build tree function using recursion.

Final tree construction with decision nodes and leaf nodes based on information gain.

Additional functions for classifying data and printing the tree for visualization.

Recommendation for next steps, including modifying the tree with personal datasets.

Conclusion and a prompt for viewers to explore more about decision trees through suggested books.

Transcripts

00:05 JOSH GORDON: Hey, everyone. Welcome back. In this episode, we'll write a decision tree classifier from scratch in pure Python. Here's an outline of what we'll cover. I'll start by introducing the data set we'll work with. Next, we'll preview the completed tree. And then, we'll build it. On the way, we'll cover concepts like decision tree learning, Gini impurity, and information gain. You can find the code for this episode in the description, and it's available in two formats, both as a Jupyter notebook and as a regular Python file. OK, let's get started.

00:34 For this episode, I've written a toy data set that includes both numeric and categorical attributes. And here, our goal will be to predict the type of fruit, like an apple or a grape, based on features like color and size. At the end of the episode, I encourage you to swap out this data set for one of your own and build a tree for a problem you care about.

00:53 Let's look at the format. I've re-drawn it here for clarity. Each row is an example. The first two columns provide features or attributes that describe the data, and the last column gives the label, or the class, we want to predict. If you like, you can modify this data set by adding additional features or more examples, and our program will work in exactly the same way.

01:12 Now, this data set is pretty straightforward, except for one thing: I've written it so it's not perfectly separable. By that I mean there's no way to tell apart the second and fifth examples. They have the same features, but different labels. This is so we can see how our tree handles this case. Towards the end of the notebook, you'll find testing data in the same format.
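
The summary doesn't reproduce the rows themselves, but a data set of the kind described (color and diameter as features, fruit as the label, with two rows that share features but not labels) might look like the sketch below; the specific values are assumptions for illustration, not copied from the video.

    # Each row is an example: the first two columns are features (color, diameter)
    # and the last column is the label we want to predict.
    # The second and fifth rows share the same features but carry different labels,
    # so this data is intentionally not perfectly separable.
    header = ["color", "diameter", "label"]
    training_data = [
        ["Green",  3, "Apple"],
        ["Yellow", 3, "Apple"],
        ["Red",    1, "Grape"],
        ["Red",    1, "Grape"],
        ["Yellow", 3, "Lemon"],
    ]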

01:33 Now I've written a few utility functions that make it easier to work with this data. And below each function, I've written a small demo to show how it works. I've repeated this pattern for every block of code in the notebook.

01:45 Now to build the tree, we use the decision tree learning algorithm called CART. As it happens, there's a whole family of algorithms used to build trees from data. At their core, they give you a procedure to decide which questions to ask and when. CART stands for Classification and Regression Trees, and here's a preview of how it works.

02:04 To begin, we'll add a root node for the tree. All nodes receive a list of rows as input, and the root will receive the entire training set. Now each node will ask a true/false question about one of the features, and in response to this question, we split, or partition, the data into two subsets. These subsets then become the input to two child nodes we add to the tree. The goal of the question is to unmix the labels as we proceed down, or in other words, to produce the purest possible distribution of the labels at each node. For example, the input to this node contains only a single type of label, so we'd say it's perfectly unmixed. There's no uncertainty about the type of label. On the other hand, the labels in this node are still mixed up, so we'd ask another question to further narrow it down.

02:50 The trick to building an effective tree is to understand which questions to ask and when. And to do that, we need to quantify how much a question helps to unmix the labels. We can quantify the amount of uncertainty at a single node using a metric called Gini impurity, and we can quantify how much a question reduces that uncertainty using a concept called information gain. We'll use these to select the best question to ask at each point. And given that question, we'll recursively build the tree on each of the new nodes. We'll continue dividing the data until there are no further questions to ask, at which point we'll add a leaf.

03:24 To implement this, first we need to understand what types of questions we can ask about the data. And second, we need to understand how to decide which question to ask when.

03:35 Now each node takes a list of rows as input. And to generate a list of questions, we'll iterate over every value for every feature that appears in those rows. Each of these becomes a candidate for a threshold we can use to partition the data. And there will often be many possibilities. In code, we represent a question by storing a column number and a column value, or the threshold we'll use to partition the data. For example, here's how we'd write a question to test if the color is green. And here's an example for a numeric attribute, to test if the diameter is greater than or equal to 3.

04:07 In response to a question, we divide, or partition, the data into two subsets. The first contains all the rows for which the question is true, and the second contains everything else. In code, our partition function takes a question and a list of rows as input. For example, here's how we partition the rows based on whether the color is red. Here, true rows contains all the red examples, and false rows contains everything else.
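
The notebook itself isn't reproduced here, but a minimal sketch of a question and the partition step consistent with this description might look as follows; names like Question and partition are assumptions, and the video's actual code may differ in detail.

    class Question:
        """Stores a column number and a value (the threshold).
        match() compares the corresponding feature in an example against it."""
        def __init__(self, column, value):
            self.column = column
            self.value = value

        def match(self, example):
            val = example[self.column]
            if isinstance(val, (int, float)):   # numeric attribute, e.g. diameter >= 3
                return val >= self.value
            return val == self.value            # categorical attribute, e.g. color == 'Green'

    def partition(rows, question):
        """Split rows into the ones where the question is true and everything else."""
        true_rows, false_rows = [], []
        for row in rows:
            (true_rows if question.match(row) else false_rows).append(row)
        return true_rows, false_rows

    # Example: partition a few rows on "is the color red?" (column 0 holds the color here)
    rows = [["Red", 1, "Grape"], ["Green", 3, "Apple"], ["Red", 1, "Grape"]]
    true_rows, false_rows = partition(rows, Question(0, "Red"))
    # true_rows -> the two red grapes; false_rows -> the green apple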

04:33 The best question is the one that reduces our uncertainty the most. Gini impurity lets us quantify how much uncertainty there is at a node, and information gain will let us quantify how much a question reduces that. Let's work on impurity first. Now this is a metric that ranges between 0 and 1, where lower values indicate less uncertainty, or mixing, at a node. It quantifies our chance of being incorrect if we randomly assign a label from a set to an example in that set.

05:00 Here's an example to make that clear. Imagine we have two bowls: one contains the examples, and the other contains labels. First, we'll randomly draw an example from the first bowl. Then we'll randomly draw a label from the second. And now, we'll classify the example as having that label. Gini impurity gives us our chance of being incorrect. In this example, we have only apples in each bowl. There's no way to make a mistake, so we say the impurity is zero. On the other hand, given a bowl with five different types of fruit in equal proportion, we'd say it has an impurity of 0.8. That's because we have only a one-out-of-five chance of being right, or a four-out-of-five chance of being wrong, if we randomly assign a label to an example.

05:41 In code, this method calculates the impurity of a data set, and I've written a couple of examples below that demonstrate how it works. You can see the impurity for the first set is zero, because there's no mixing. And here, you can see the impurity is 0.8.
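
A minimal sketch of such an impurity function, assuming the label sits in the last column of each row, could look like this; the helper name class_counts is an assumption, and the demo labels are illustrative.

    from collections import Counter

    def class_counts(rows):
        """Count how many rows of each label there are (the label is the last column)."""
        return Counter(row[-1] for row in rows)

    def gini(rows):
        """Chance of mislabeling a row drawn at random from `rows`
        if we also draw its label at random from the same rows."""
        counts = class_counts(rows)
        impurity = 1.0
        for label in counts:
            prob_of_label = counts[label] / float(len(rows))
            impurity -= prob_of_label ** 2
        return impurity

    # No mixing: only one kind of label, so the impurity is 0.
    print(gini([["Apple"], ["Apple"]]))                                             # 0.0
    # Five labels in equal proportion: the impurity is 0.8.
    print(gini([["Apple"], ["Orange"], ["Grape"], ["Grapefruit"], ["Blueberry"]]))  # 0.8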

05:57 Now information gain will let us find the question that reduces our uncertainty the most. It's just a number that describes how much a question helps to unmix the labels at a node. Here's the idea. We begin by calculating the uncertainty of our starting set. Then, for each question we can ask, we'll try partitioning the data and calculating the uncertainty of the child nodes that result. We'll take a weighted average of their uncertainty, because we care more about a large set with low uncertainty than a small set with high uncertainty. Then, we'll subtract this from our starting uncertainty, and that's our information gain. As we go, we'll keep track of the question that produces the most gain, and that will be the best one to ask at this node.

06:36 Let's see how this looks in code. Here, we'll iterate over every value for the features. We'll generate a question for that feature, then partition the data on it. Notice we discard any questions that fail to produce a split. Then, we'll calculate our information gain. Inside this function, you can see we take the impurity of each set and combine them as a weighted average. We see how much this reduces the uncertainty from our starting set, and we keep track of the best value. I've written a couple of demos below as well.
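
A sketch of that search, building on the gini, partition, and Question helpers sketched earlier, might look like the following; the function names are assumptions, not necessarily those used in the notebook.

    def info_gain(left, right, current_uncertainty):
        """Uncertainty of the starting node minus the weighted impurity of its two children."""
        p = float(len(left)) / (len(left) + len(right))
        return current_uncertainty - p * gini(left) - (1 - p) * gini(right)

    def find_best_split(rows):
        """Try every (column, value) question, skip any that fail to split the rows,
        and keep whichever question produces the highest information gain."""
        best_gain, best_question = 0.0, None
        current_uncertainty = gini(rows)
        n_features = len(rows[0]) - 1                 # the last column is the label
        for col in range(n_features):
            for val in {row[col] for row in rows}:    # each unique value is a candidate threshold
                question = Question(col, val)
                true_rows, false_rows = partition(rows, question)
                if not true_rows or not false_rows:   # this question didn't split the data
                    continue
                gain = info_gain(true_rows, false_rows, current_uncertainty)
                if gain >= best_gain:
                    best_gain, best_question = gain, question
        return best_gain, best_question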

07:04 OK, with these concepts in hand, we're ready to build the tree. To put this all together, I think the most useful thing I can do is walk you through the algorithm as it builds a tree for our training data. This uses recursion, so seeing it in action can be helpful. You can find the code for this inside the build tree function.

07:21 When we call build tree for the first time, it receives the entire training set as input, and as output it will return a reference to the root node of our tree. I'll draw a placeholder for the root here in gray, and here are the rows we're considering at this node. To start, that's the entire training set. Now we find the best question to ask at this node, and we do that by iterating over each of these values. We'll split the data and calculate the information gained for each one, and as we go, we'll keep track of the question that produces the most gain.

07:53 Now in this case, there's a useful question to ask, so the gain will be greater than zero, and we'll split the data using that question. Now we'll use recursion by calling build tree again to add a node for the true branch. The rows we're considering now are the first half of the split. Again, we'll find the best question to ask for this data. Once more we split and call the build tree function to add the child node.

08:17 Now for this data there are no further questions to ask, so the information gain will be zero, and this node becomes a leaf. It will predict that an example is either an apple or a lemon with 50% confidence, because that's the ratio of the labels in the data. Now we'll continue by building the false branch, and here, this will also become a leaf. We'll predict apple with 100% confidence. Now the previous call returns, and this node becomes a decision node. In code, that just means it holds a reference to the question we asked and the two child nodes that result. And we're nearly done.

08:52 Now we return to the root node and build the false branch. There are no further questions to ask, so this becomes a leaf, and that predicts grape with 100% confidence. Finally, the root node also becomes a decision node, and our call to build tree returns a reference to it. If you scroll down in the code, you'll see that I've added functions to classify data and print the tree. These start with a reference to the root node, so you can see how it works.
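
For reference, a minimal sketch of the recursive build and classify steps described above, reusing the class_counts, find_best_split, and partition helpers from the earlier sketches, could look like this; the class names here (Leaf, DecisionNode) are assumptions rather than the notebook's exact identifiers.

    class Leaf:
        """A leaf predicts the mix of labels among the rows that reached it."""
        def __init__(self, rows):
            self.predictions = class_counts(rows)     # e.g. Counter({'Apple': 1, 'Lemon': 1})

    class DecisionNode:
        """A decision node holds a question plus the two branches it splits the rows into."""
        def __init__(self, question, true_branch, false_branch):
            self.question = question
            self.true_branch = true_branch
            self.false_branch = false_branch

    def build_tree(rows):
        """Find the best question for these rows; if nothing gains, return a Leaf,
        otherwise split and recurse on each branch."""
        gain, question = find_best_split(rows)
        if gain == 0:                                 # base case: no further questions to ask
            return Leaf(rows)
        true_rows, false_rows = partition(rows, question)
        return DecisionNode(question, build_tree(true_rows), build_tree(false_rows))

    def classify(row, node):
        """Follow the question at each decision node until a leaf's predictions are reached."""
        if isinstance(node, Leaf):
            return node.predictions
        if node.question.match(row):
            return classify(row, node.true_branch)
        return classify(row, node.false_branch)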

09:17 OK, hope that was helpful, and you can check out the code for more details. There's a lot more I have to say about decision trees, but there's only so much we can fit into a short time. Here are a couple of topics that are good to be aware of, and you can check out the books in the description to learn more. As a next step, I'd recommend modifying the tree to work with your own data set. This can be a fun way to build a simple and interpretable classifier for use in your projects. Thanks for watching, everyone, and I'll see you next time.

Related Tags
Machine Learning, Decision Trees, Python Coding, Data Science, CART Algorithm, Gini Impurity, Information Gain, Classifier Building, Fruit Classification, Data Partitioning