How to Visualize a Random Forest in Python?

While training a random forest in Python, we might want to visualize it to see the decision trees the model builds during training. To recap, a random forest is a supervised machine learning algorithm that uses the training data to build a forest of decision trees and then combines their predictions. The classic formulation uses majority voting, while scikit-learn's implementation averages the trees' predicted class probabilities. In this article, we will go through several methods to visualize a random forest in Python using different tools and modules.

What is a Random Forest in Machine Learning?

Random forest is a supervised machine learning algorithm, which means it learns from a training dataset that contains both input values and the corresponding output values. The model builds many decision trees, each trained on data randomly sampled from the training dataset.

The diagram below shows how a random forest works.

[Figure: how a random forest works]

As shown above, the model is made up of a forest of decision trees, each built on randomly selected data, which is why it is called a random forest.
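To make the "combine their predictions" step concrete, here is a minimal sketch that checks scikit-learn's behaviour on synthetic data generated with make_classification (a stand-in, not the article's dataset). The forest's predict_proba equals the mean of its trees' predict_proba, and predict picks the class with the highest averaged probability.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic two-class data (a stand-in, not the article's dataset)
X, y = make_classification(n_samples=200, n_features=2, n_informative=2,
                           n_redundant=0, random_state=0)

forest = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)

# Class probabilities from every tree: shape (n_trees, n_samples, n_classes)
per_tree = np.stack([est.predict_proba(X) for est in forest.estimators_])

# Average over the trees, then take the most probable class for each sample
avg_proba = per_tree.mean(axis=0)
manual_pred = forest.classes_[avg_proba.argmax(axis=1)]

print(np.allclose(avg_proba, forest.predict_proba(X)))  # True
print(np.array_equal(manual_pred, forest.predict(X)))   # True
```

This is soft voting. A hard majority vote would instead count each tree's predicted class, and the two schemes can disagree on borderline samples.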

How many decision trees are there in a random forest?

There is no fixed limit on the number of decision trees in a random forest. In scikit-learn, RandomForestClassifier builds 100 trees by default (since version 0.22; earlier versions defaulted to 10), and you can change this with the n_estimators parameter.

The n_estimators parameter specifies how many decision trees the forest contains. We usually set it when initializing the random forest model, as shown below:

# import the classifier and create a random forest containing 50 decision trees
from sklearn.ensemble import RandomForestClassifier
classifier = RandomForestClassifier(n_estimators=50)

Here, the random forest has been set to contain 50 decision trees.
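After fitting, the trees are stored in the estimators_ attribute, and its length matches n_estimators. A minimal sketch on synthetic data (make_classification is a stand-in for a real dataset) that checks both the default and the value of 50 used above:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic data, used only to have something to fit on (assumption)
X, y = make_classification(n_samples=100, n_features=4, random_state=0)

# Default forest: n_estimators is 100 in scikit-learn >= 0.22
default_forest = RandomForestClassifier(random_state=0)
print(default_forest.n_estimators)  # 100

# Forest with 50 trees: after fitting, estimators_ holds exactly 50 fitted trees
classifier = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)
print(len(classifier.estimators_))  # 50
```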

Training the random forest model

In order to visualize a random forest model in Python, we first need to train it on a training dataset. Let us start by importing the dataset we will use to train the model.

# importing pandas
import pandas as pd

# importing the dataset
data = pd.read_csv('dataset.csv')

# display the first five rows of the dataset
data.head()

Output:

	Age	Salary	Purchased
0	19	19000	0
1	35	20000	0
2	26	43000	0
3	27	57000	0
4	19	76000	0

As you can see, this is a binary classification dataset: the output column, Purchased, has two possible outcomes.

Next, we separate the input features from the output column and split the data into training and testing sets.

# dividing the dataset
X = data.drop('Purchased', axis=1)
y = data['Purchased']

# importing the train_test_split method from sklearn
from sklearn.model_selection import train_test_split

# splitting the data 
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=0)

We have assigned 20% of the data to the testing part and the remaining to the training part.
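To check that the split behaves as described, here is a minimal sketch. It builds a hypothetical 400-row DataFrame with the same three columns (random values, not the contents of dataset.csv) and confirms the 80/20 proportions. The rule used to generate Purchased is an arbitrary assumption for illustration.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical stand-in for dataset.csv: 400 rows with the same three columns
rng = np.random.RandomState(0)
data = pd.DataFrame({
    "Age": rng.randint(18, 60, size=400),
    "Salary": rng.randint(15000, 150000, size=400),
})
# Arbitrary labelling rule, only so that both classes occur
data["Purchased"] = ((data["Age"] > 42) | (data["Salary"] > 90000)).astype(int)

X = data.drop("Purchased", axis=1)
y = data["Purchased"]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=0)

# test_size=0.20 puts 20% of the 400 rows in the test set
print(len(X_train), len(X_test))  # 320 80
```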

Now, it is time to train the model using the training dataset. We will fix the number of decision trees to 50 while initializing the model.

# import the random forest classifier
from sklearn.ensemble import RandomForestClassifier

# create a random forest with 50 decision trees
classifier = RandomForestClassifier(n_estimators=50)

# train the random forest on the training data
classifier.fit(X_train, y_train)

Once training is complete, we can visualize the random forest.
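The methods below plot one tree at a time, so it can help to inspect the fitted trees first. The sketch below runs on a synthetic stand-in for the Age/Salary data (an assumption, not the article's dataset.csv). It uses the DecisionTreeClassifier methods get_depth() and get_n_leaves() to estimate how large each plot will be.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

# Synthetic stand-in for the article's Age/Salary data (assumption)
rng = np.random.RandomState(0)
X = pd.DataFrame({"Age": rng.randint(18, 60, size=300),
                  "Salary": rng.randint(15000, 150000, size=300)})
y = ((X["Age"] > 42) | (X["Salary"] > 90000)).astype(int)

classifier = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

# estimators_ is a plain Python list of fitted DecisionTreeClassifier objects
print(type(classifier.estimators_[0]).__name__)  # DecisionTreeClassifier

# (index, depth, number of leaves) for every tree in the forest
sizes = [(i, est.get_depth(), est.get_n_leaves())
         for i, est in enumerate(classifier.estimators_)]

# The tree with the fewest leaves is usually the easiest one to read when plotted
smallest = min(sizes, key=lambda s: s[2])
print("smallest tree (index, depth, leaves):", smallest)
```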

Method-1: Visualize a random forest classifier using a tree

We will now use our first method to visualize the random forest classifier, which relies on the tree submodule from the sklearn module. Because a random forest contains many decision trees, we cannot visualize all of them at once; instead, we visualize them one by one. Let us start with the first decision tree in our random forest.

from sklearn import tree
import matplotlib.pyplot as plt

# create a figure with a single set of axes
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(4, 4))

# plot the first decision tree in the forest onto those axes
tree.plot_tree(classifier.estimators_[0],
               filled=True,
               ax=axes);

# save the plot as an image
fig.savefig('rf_individualtree.png')

Output:

[Figure: the first decision tree in the random forest]

In the code, estimators_[0] refers to the first decision tree in the random forest. Changing the index to 1 gives us the second decision tree.

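# create a figure with a single set of axes
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(4, 4))

# plot the second decision tree in the forest
tree.plot_tree(classifier.estimators_[1],
               filled=True,
               ax=axes);

# save to a separate file so the first tree's image is not overwritten
fig.savefig('rf_second_tree.png')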

Output:

[Figure: the second decision tree in the random forest]
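To compare several trees side by side, you can pass a different Matplotlib axes object to plot_tree for each tree. Below is a minimal sketch, again on synthetic Age/Salary data. The max_depth=3 setting on the forest, the class labels "0"/"1", and the output filename are illustrative choices, not part of the original example.

```python
import matplotlib
matplotlib.use("Agg")  # file-only backend so the sketch also runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import tree
from sklearn.ensemble import RandomForestClassifier

# Synthetic stand-in for the article's Age/Salary data (assumption)
rng = np.random.RandomState(0)
X = pd.DataFrame({"Age": rng.randint(18, 60, size=300),
                  "Salary": rng.randint(15000, 150000, size=300)})
y = ((X["Age"] > 42) | (X["Salary"] > 90000)).astype(int)

# max_depth=3 keeps each tree small enough to read in a subplot
classifier = RandomForestClassifier(n_estimators=50, max_depth=3,
                                    random_state=0).fit(X, y)

# One subplot per tree; the ax argument tells plot_tree where to draw
n_trees = 3
fig, axes = plt.subplots(nrows=1, ncols=n_trees, figsize=(15, 5))
for i, ax in enumerate(axes):
    tree.plot_tree(classifier.estimators_[i],
                   feature_names=list(X.columns),
                   class_names=["0", "1"],
                   filled=True,
                   ax=ax)
    ax.set_title(f"Tree {i}")

fig.savefig("rf_first_three_trees.png")
```

Passing feature_names also replaces the generic x[0]/x[1] labels in each node with the column names.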

Because the forest has 50 trees, the valid indices for estimators_ are 0 to 49. If we use an index outside this range, such as 51, we get an error, as shown below:

# using the fig
fig, axes = plt.subplots(nrows = 1,ncols = 1,figsize = (4,4))

# tree 
tree.plot_tree(classifier.estimators_[51],
               filled = True);

# saving the decision tree
fig.savefig('rf_individualtree.png')

Output:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
/tmp/ipykernel_68742/1231554001.py in <module>
      6 
      7 # tree
----> 8 tree.plot_tree(classifier.estimators_[51],
      9                filled = True);
     10 

IndexError: list index out of range

We get this error because we trained the model with n_estimators=50, so estimators_ is a list of 50 trees and any index outside the range 0 to 49 raises an IndexError.

Method-2: Visualize a Random Forest Classifier in text format

When the trees are large, plotting them as images can take a long time and the result can be hard to read. A simpler alternative is tree.export_text, which prints a decision tree as indented text, one line per split.

# export the first decision tree of the forest as text
text_representation = tree.export_text(classifier.estimators_[0])

# printing the text
print(text_representation)

This prints the first decision tree in the following format, where feature_0 is Age and feature_1 is Salary:

|--- feature_0 <= 42.50
|   |--- feature_1 <= 90500.00
|   |   |--- feature_0 <= 36.50
|   |   |   |--- class: 0.0
|   |   |--- feature_0 >  36.50
|   |   |   |--- feature_0 <= 41.50
|   |   |   |   |--- feature_0 <= 40.50
|   |   |   |   |   |--- feature_0 <= 38.50
|   |   |   |   |   |   |--- feature_0 <= 37.50
|   |   |   |   |   |   |   |--- feature_1 <= 76500.00
|   |   |   |   |   |   |   |   |--- class: 0.0
|   |   |   |   |   |   |   |--- feature_1 >  76500.00
|   |   |   |   |   |   |   |   |--- feature_1 <= 79000.00
|   |   |   |   |   |   |   |   |   |--- class: 1.0
|   |   |   |   |   |   |   |   |--- feature_1 >  79000.00
|   |   |   |   |   |   |   |   |   |--- class: 0.0
|   |   |   |   |   |   |--- feature_0 >  37.50
|   |   |   |   |   |   |   |--- class: 0.0
|   |   |   |   |   |--- feature_0 >  38.50
|   |   |   |   |   |   |--- feature_0 <= 39.50
|   |   |   |   |   |   |   |--- feature_1 <= 77000.00
|   |   |   |   |   |   |   |   |--- class: 1.0
|   |   |   |   |   |   |   |--- feature_1 >  77000.00
|   |   |   |   |   |   |   |   |--- class: 0.0
|   |   |   |   |   |   |--- feature_0 >  39.50
|   |   |   |   |   |   |   |--- feature_1 <= 66000.00
|   |   |   |   |   |   |   |   |--- class: 0.0
|   |   |   |   |   |   |   |--- feature_1 >  66000.00
|   |   |   |   |   |   |   |   |--- feature_1 <= 74500.00
|   |   |   |   |   |   |   |   |   |--- class: 1.0
|   |   |   |   |   |   |   |   |--- feature_1 >  74500.00
|   |   |   |   |   |   |   |   |   |--- class: 0.0
|   |   |   |   |--- feature_0 >  40.50
|   |   |   |   |   |--- class: 0.0
|   |   |   |--- feature_0 >  41.50
|   |   |   |   |--- feature_1 <= 71500.00
|   |   |   |   |   |--- class: 0.0
|   |   |   |   |--- feature_1 >  71500.00
|   |   |   |   |   |--- class: 1.0
|   |--- feature_1 >  90500.00
|   |   |--- feature_1 <= 116500.00
|   |   |   |--- feature_0 <= 34.50
|   |   |   |   |--- feature_0 <= 31.50
|   |   |   |   |   |--- feature_1 <= 111500.00
|   |   |   |   |   |   |--- class: 1.0
|   |   |   |   |   |--- feature_1 >  111500.00
|   |   |   |   |   |   |--- class: 0.0
|   |   |   |   |--- feature_0 >  31.50
|   |   |   |   |   |--- feature_0 <= 33.50
|   |   |   |   |   |   |--- class: 0.0
|   |   |   |   |   |--- feature_0 >  33.50
|   |   |   |   |   |   |--- feature_1 <= 113500.00
|   |   |   |   |   |   |   |--- class: 1.0
|   |   |   |   |   |   |--- feature_1 >  113500.00
|   |   |   |   |   |   |   |--- class: 0.0
|   |   |   |--- feature_0 >  34.50
|   |   |   |   |--- class: 1.0
|   |   |--- feature_1 >  116500.00
|   |   |   |--- class: 1.0
|--- feature_0 >  42.50
|   |--- feature_0 <= 46.50
|   |   |--- feature_0 <= 43.50
|   |   |   |--- feature_1 <= 122500.00
|   |   |   |   |--- class: 1.0
|   |   |   |--- feature_1 >  122500.00
|   |   |   |   |--- class: 0.0
|   |   |--- feature_0 >  43.50
|   |   |   |--- feature_1 <= 25000.00
|   |   |   |   |--- class: 0.0
|   |   |   |--- feature_1 >  25000.00
|   |   |   |   |--- feature_1 <= 62000.00
|   |   |   |   |   |--- class: 1.0
|   |   |   |   |--- feature_1 >  62000.00
|   |   |   |   |   |--- feature_1 <= 83500.00
|   |   |   |   |   |   |--- class: 0.0
|   |   |   |   |   |--- feature_1 >  83500.00
|   |   |   |   |   |   |--- feature_1 <= 92000.00
|   |   |   |   |   |   |   |--- class: 1.0
|   |   |   |   |   |   |--- feature_1 >  92000.00
|   |   |   |   |   |   |   |--- feature_1 <= 117500.00
|   |   |   |   |   |   |   |   |--- class: 0.0
|   |   |   |   |   |   |   |--- feature_1 >  117500.00
|   |   |   |   |   |   |   |   |--- class: 1.0
|   |--- feature_0 >  46.50
|   |   |--- class: 1.0

This tree represents the first decision tree in our random forest.
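The generic names feature_0 and feature_1 and the deep nesting make this output hard to read. export_text also accepts feature_names and max_depth parameters. Here is a minimal sketch on synthetic Age/Salary data (a stand-in for the article's dataset) that labels the splits with column names and prints only the top two levels:

```python
import numpy as np
import pandas as pd
from sklearn import tree
from sklearn.ensemble import RandomForestClassifier

# Synthetic stand-in for the article's Age/Salary data (assumption)
rng = np.random.RandomState(0)
X = pd.DataFrame({"Age": rng.randint(18, 60, size=300),
                  "Salary": rng.randint(15000, 150000, size=300)})
y = ((X["Age"] > 42) | (X["Salary"] > 90000)).astype(int)

classifier = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

# feature_names replaces feature_0/feature_1 with the column names;
# max_depth=2 prints only the top two levels and truncates deeper branches
text_representation = tree.export_text(classifier.estimators_[0],
                                       feature_names=list(X.columns),
                                       max_depth=2)
print(text_representation)
```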

Method-3: Visualize the random forest using Graphviz

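Graphviz is open-source graph visualization software, and the graphviz Python package provides an interface to it. Here, we use scikit-learn's export_graphviz function to convert one decision tree from the forest into the DOT graph description language. We then use the graphviz module to render it as an image. Rendering requires the Graphviz software itself to be installed in addition to the Python package.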

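# importing the modules
import graphviz
from sklearn.tree import export_graphviz

# export the first decision tree of the forest in DOT format
dot_data = export_graphviz(classifier.estimators_[0])

# wrap the DOT source in a graphviz object that renders to PNG
graph = graphviz.Source(dot_data, format='png')

# write the DOT source to rf_tree and the image to rf_tree.png
graph.render('rf_tree')

# display the tree inline in a Jupyter notebook
graph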

Output:

[Figure: the first decision tree rendered with Graphviz (partial view)]

Because the decision tree is large, the screenshot above shows only part of it.
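Because a full tree rarely fits in one screenshot, export_graphviz can also export only the top of the tree. In this minimal sketch on synthetic Age/Salary data (an assumption), max_depth=2 limits the export to the first two levels, and feature_names and class_names replace the generic labels. The resulting DOT string can be passed to graphviz.Source exactly as in the code above. Rendering needs the Graphviz binaries, so the sketch only builds the string.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import export_graphviz

# Synthetic stand-in for the article's Age/Salary data (assumption)
rng = np.random.RandomState(0)
X = pd.DataFrame({"Age": rng.randint(18, 60, size=300),
                  "Salary": rng.randint(15000, 150000, size=300)})
y = ((X["Age"] > 42) | (X["Salary"] > 90000)).astype(int)

classifier = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

# Export only the top two levels, with column names and class labels
dot_data = export_graphviz(classifier.estimators_[0],
                           out_file=None,
                           max_depth=2,
                           feature_names=list(X.columns),
                           class_names=["0", "1"],
                           filled=True,
                           rounded=True)

# DOT text describing the graph; pass it to graphviz.Source to render it
print(dot_data.splitlines()[0])  # digraph Tree {
```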

Visualizing a random forest regressor

Visualizing a random forest regressor works the same way as visualizing the classifier, and all of the methods above can be used. The only difference is the structure of the decision trees. Each leaf holds a predicted numeric value instead of a class, and because the target can take many possible values, a fully grown regression tree usually has many more leaves and looks much messier.

Here is an example of visualizing the first decision tree of a random forest regressor trained on a simple dataset.
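The plotting snippet below uses a regressor whose training step the article does not show. As a minimal sketch under assumptions, here is one way to fit a RandomForestRegressor on a hypothetical one-feature dataset: a noisy sine curve, not the article's data. Using 10 trees is only to keep the example fast. The last lines check that the forest's prediction is the average of its trees' predictions.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

# Hypothetical one-feature dataset: a noisy sine curve (not the article's data)
rng = np.random.RandomState(0)
X_reg = np.sort(rng.uniform(0, 5, size=(80, 1)), axis=0)
y_reg = np.sin(X_reg).ravel() + rng.normal(scale=0.1, size=80)

# Fully grown trees (no max_depth), as with the default settings
regressor = RandomForestRegressor(n_estimators=10, random_state=0)
regressor.fit(X_reg, y_reg)

# The forest's prediction is the mean of the individual trees' predictions
per_tree = np.stack([est.predict(X_reg) for est in regressor.estimators_])
print(np.allclose(per_tree.mean(axis=0), regressor.predict(X_reg)))  # True
```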

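# importing the modules
from sklearn import tree
import matplotlib.pyplot as plt

# `regressor` is assumed to be an already fitted RandomForestRegressor
# create a figure with a single set of axes
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(4, 4))

# plotting the first decision tree of the regressor
tree.plot_tree(regressor.estimators_[0],
               filled=True,
               ax=axes);
fig.savefig('rf_regressor_tree.png')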

Output:

[Figure: the first decision tree of the random forest regressor]

As you can see, the regression tree is much messier than the classification trees because it has many possible output values.
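To keep a regression tree readable, you can limit how much of it is drawn. The sketch below uses hypothetical noisy sine data; the feature name "x" and the filename are illustrative. It prints the first tree's depth and leaf count, then uses plot_tree's max_depth parameter to draw only the top two levels, with deeper nodes collapsed. This only shortens the drawing; the fitted tree itself is unchanged.

```python
import matplotlib
matplotlib.use("Agg")  # file-only backend so the sketch also runs without a display
import matplotlib.pyplot as plt
import numpy as np
from sklearn import tree
from sklearn.ensemble import RandomForestRegressor

# Hypothetical noisy sine data as a stand-in for a regression dataset
rng = np.random.RandomState(0)
X_reg = np.sort(rng.uniform(0, 5, size=(80, 1)), axis=0)
y_reg = np.sin(X_reg).ravel() + rng.normal(scale=0.1, size=80)

regressor = RandomForestRegressor(n_estimators=10, random_state=0).fit(X_reg, y_reg)
first_tree = regressor.estimators_[0]

# A fully grown regression tree typically has many leaves, each with its own predicted value
print("depth:", first_tree.get_depth(), "leaves:", first_tree.get_n_leaves())

# Draw only the top two levels of the tree; deeper nodes are collapsed
fig, ax = plt.subplots(figsize=(8, 4))
tree.plot_tree(first_tree, max_depth=2, feature_names=["x"], filled=True, ax=ax)
fig.savefig("rf_regressor_tree_top.png")
```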

Summary

In this short article, we learned how to visualize a random forest in Python using three methods: plotting individual trees with tree.plot_tree, printing them as text with tree.export_text, and rendering them with Graphviz. We also saw why the trees of a random forest regressor are messier and harder to read than those of a classifier.
