You already know ML: a decision tree primer

A decision tree that looks like an actual tree, in cartoon style (Google Gemini)

Decision trees are one of the easiest machine learning algorithms to understand. In fact, you already use them. For example, when you were a child, you might have played the game 20 Questions.

The game works like this: first, one player thinks of something: a person, place, or thing. I will choose Austin Powers. Then the other player asks up to 20 yes-or-no questions to figure out what it is.

A good strategy is to start with general questions, then get more specific as the set of possible answers shrinks. This is because general questions give you more information than specific ones.

An example of how the “guesser” might approach the problem.

You might also think of decision trees like a series of if-then statements.

What’s missing from this model is learning. You can already code a bunch of if-else statements. This might work well for small problems, but the technique doesn’t scale.
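To make the if-then framing concrete, here is a tiny hand-coded "decision tree" in Python. The questions and answers are invented for illustration; this is exactly the kind of hard-coded logic that works for toy problems but doesn't scale:

```python
# A hand-written "decision tree" as nested if-else statements.
# Every question and answer here is made up for illustration.
def guess(is_alive: bool, is_fictional: bool, wears_glasses: bool) -> str:
    if is_alive:
        if is_fictional:
            if wears_glasses:
                return "Austin Powers"
            return "James Bond"
        return "a real person"
    return "a place or thing"

print(guess(True, True, True))  # -> Austin Powers
```

Every new answer we want to handle means another hand-written branch, which is why we want the machine to learn the branches instead.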

Decision trees use supervised learning to improve on this. Supervised learning means that humans give the algorithm input/output pairs to learn from. Information gain (a formal, information-theoretic measure) is used to decide which attributes (or features) of the data are most useful to split on.

Going back to our 20 Questions example, a question like, “Is it alive?” tells us much more than “Is it located in the state of California?” So we would say that the former question gives us more information than the latter.
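We can put a number on that intuition. Information gain is the entropy of the full set minus the weighted entropy of the groups after a split. Below is a minimal sketch (the function and variable names are mine) comparing an even split, like "Is it alive?", against a lopsided one, like "Is it in California?":

```python
import math
from collections import Counter

def entropy(labels):
    """Shannon entropy of a list of labels, in bits."""
    n = len(labels)
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())

def information_gain(labels, groups):
    """Entropy of the whole set minus the weighted entropy after a split."""
    n = len(labels)
    remainder = sum(len(g) / n * entropy(g) for g in groups)
    return entropy(labels) - remainder

# Eight possible answers. A general question splits them evenly;
# a specific question isolates just one answer.
labels = ["A", "B", "C", "D", "E", "F", "G", "H"]
even_split = [labels[:4], labels[4:]]      # e.g. "Is it alive?"
lopsided_split = [labels[:1], labels[1:]]  # e.g. "Is it in California?"

print(information_gain(labels, even_split))     # -> 1.0 bit
print(information_gain(labels, lopsided_split)) # -> ~0.54 bits
```

The even split earns a full bit of information; the specific question earns about half that, so a tree built on information gain asks the general question first.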

Let’s use a more concrete example. We will train a decision tree on the diabetes dataset from scikit-learn. It consists of 442 patients. Patients have 10 features, including age, sex, body mass index, and cholesterol levels. Each patient also has a label Y (or target), which is a quantitative measure of their disease progression. Our goal is to predict Y given the features.

The following image demonstrates how to load the dataset in a notebook, then display the first few rows.
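For readers without the image, here is a minimal sketch of that loading step. `load_diabetes` and its `as_frame` option are real scikit-learn API; the variable names are mine:

```python
from sklearn.datasets import load_diabetes

# as_frame=True returns pandas objects instead of raw NumPy arrays
diabetes = load_diabetes(as_frame=True)
X = diabetes.data    # 442 patients x 10 feature columns
y = diabetes.target  # disease progression, one value per patient

print(X.shape)   # -> (442, 10)
print(X.head())  # first few rows of the feature table
```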

You might be surprised looking at this data: no one's age can be negative. Scikit-learn has done us a favor by preprocessing the dataset. Each feature has been mean-centered and scaled, which helps the algorithm and explains why ages show up as small positive and negative numbers.

We then do a typical 80/20 train/test split. The tree is trained on 80% of the patients and then has to guess the targets for the remaining 20%, which it has never seen. That held-out portion is how we measure its performance.
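A sketch of the split and training steps, using scikit-learn's real `train_test_split` and `DecisionTreeRegressor` (the `random_state` values are arbitrary choices of mine for reproducibility):

```python
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)

# Hold out 20% of the patients for evaluation
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Fit an unpruned regression tree on the training portion only
tree = DecisionTreeRegressor(random_state=42)
tree.fit(X_train, y_train)
print(tree.tree_.node_count)  # total decision points in the fitted tree
```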

The tree has 691 nodes (decision points), so it is difficult to visualize in its entirety. Here is one grouping of nodes.

This doesn’t look too different from the 20 Questions example. At the top node, you can see that it also asks a question: “Is s5 (log of serum triglycerides) less than or equal to 0.006?” The other three values (squared_error, samples, and value) provide a bit more detail on how the tree is navigated.
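If you want to inspect the tree yourself without plotting, scikit-learn's `export_text` can print just the top few levels. This is a sketch under the same split as above; your exact thresholds may differ:

```python
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor, export_text

diabetes = load_diabetes()
X_train, X_test, y_train, y_test = train_test_split(
    diabetes.data, diabetes.target, test_size=0.2, random_state=42
)
tree = DecisionTreeRegressor(random_state=42).fit(X_train, y_train)

# The full tree is unreadable, so show only the first two levels of questions
summary = export_text(
    tree, feature_names=list(diabetes.feature_names), max_depth=2
)
print(summary)
```

Each line reads like a 20 Questions question: a feature, a threshold, and a branch to follow depending on the answer.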

Finally, we can measure how well we did using a metric called Mean Absolute Error (MAE): the average absolute difference between the tree's guesses and the actual values.
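The evaluation step can be sketched like this, using scikit-learn's real `mean_absolute_error` (the exact MAE you get will depend on the split):

```python
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
tree = DecisionTreeRegressor(random_state=42).fit(X_train, y_train)

# MAE: average absolute difference between predictions and true targets
predictions = tree.predict(X_test)
mae = mean_absolute_error(y_test, predictions)
print(mae)
```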

Here is a link to the Colab notebook I used for this blog post. I hope you found this information useful. You already knew how this algorithm works; I just gave you a brief reminder 😉
