How to create a confusion matrix plot using Matplotlib

A confusion matrix is an excellent way to illustrate the results of multi-class classification. Generating a colorful confusion matrix plot takes a single function call in Matplotlib. However, you first need to have your results in the form of a confusion matrix.

Let me illustrate with an example. Assume you have 4 classes: A, B, C and D. Your classifier does great on A, C and D, classifying them with perfect accuracy. However, samples whose true class is B are classified as A 10% of the time and as C 20% of the time, so only 70% of them are correctly classified as B. You should be able to extract such classification results from your classifier easily.
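If your classifier gives you a list of true labels and a list of predicted labels, a minimal NumPy sketch like the one below can turn them into those fractions. The helper name `normalized_confusion_matrix` and the sample labels are hypothetical, made up only to reproduce the example: rows are true classes, columns are predicted classes, and each row is divided by the number of samples of that true class.

```python
import numpy as np

def normalized_confusion_matrix(y_true, y_pred, classes):
    """Hypothetical helper: build a row-normalized confusion matrix.

    Rows are true classes, columns are predicted classes, both ordered
    as in `classes`. Each row sums to 1.0; a class that never appears
    in y_true gets a row of zeros.
    """
    index = {c: i for i, c in enumerate(classes)}
    counts = np.zeros((len(classes), len(classes)), dtype=float)
    # Count how often each true class was predicted as each class
    for t, p in zip(y_true, y_pred):
        counts[index[t], index[p]] += 1
    # Divide every row by its total number of samples, skipping empty rows
    row_sums = counts.sum(axis=1, keepdims=True)
    return np.divide(counts, row_sums,
                     out=np.zeros_like(counts), where=row_sums > 0)

# Made-up labels matching the example: 10 samples of B, of which
# 1 is predicted as A, 7 as B and 2 as C; A, C and D are always correct
y_true = ["A"] + ["B"] * 10 + ["C", "D"]
y_pred = ["A"] + ["A"] + ["B"] * 7 + ["C"] * 2 + ["C", "D"]
m = normalized_confusion_matrix(y_true, y_pred, ["A", "B", "C", "D"])
print(m[1])  # row for true class B: 0.1, 0.7, 0.2, 0.0
```

Normalizing each row makes classes with different numbers of samples directly comparable; drop the division if you prefer raw counts.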

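You just need to put these results into a 2D float NumPy array in the form of a confusion matrix. In this type of matrix, the true classes are typically listed on the Y axis, from top to bottom, and the predicted classes on the X axis, from left to right. Each cell holds the fraction of samples of the true class (its row) that were predicted as the class of its column. For our example, the confusion matrix would look like this:

True \ Predicted    A     B     C     D
A                 1.0   0.0   0.0   0.0
B                 0.1   0.7   0.2   0.0
C                 0.0   0.0   1.0   0.0
D                 0.0   0.0   0.0   1.0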

Once you have this as a 2D float NumPy array, just pass it to Matplotlib's matshow function to generate the confusion matrix plot. To add a color scale showing which value each color represents, call the colorbar function:

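import matplotlib.pyplot as plt
import numpy as np

# Confusion matrix for the example:
# rows are true classes A-D, columns are predicted classes A-D
m = np.array([[1.0, 0.0, 0.0, 0.0],
              [0.1, 0.7, 0.2, 0.0],
              [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.0, 0.0, 1.0]])
plt.matshow(m)   # draw the matrix as a grid of colored cells
plt.colorbar()   # add the color scale next to the plot
plt.show()       # display the figure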

To add X-axis and Y-axis labels or make other changes, use the same Matplotlib calls you would use for any other type of plot.
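As one sketch of those typical calls (not the only way to do it), the hypothetical helper `plot_confusion_matrix` below puts the class names on both axes, labels the axes, and writes each value into its cell. It uses Matplotlib's object-oriented API (`ax.matshow`, `fig.colorbar`) and assumes the matrix is normalized to the range 0 to 1, hence `vmin`/`vmax`; the "Blues" colormap is an arbitrary choice.

```python
import matplotlib.pyplot as plt
import numpy as np

def plot_confusion_matrix(m, classes):
    """Hypothetical helper: draw a labeled confusion matrix.

    `m` is a square 2D NumPy array (rows = true, columns = predicted)
    with values in [0, 1]; `classes` lists the class names in order.
    """
    fig, ax = plt.subplots()
    image = ax.matshow(m, cmap="Blues", vmin=0.0, vmax=1.0)
    fig.colorbar(image, ax=ax)

    # Show class names instead of numeric indices on both axes
    ticks = np.arange(len(classes))
    ax.set_xticks(ticks)
    ax.set_xticklabels(classes)
    ax.set_yticks(ticks)
    ax.set_yticklabels(classes)

    # matshow puts the X tick labels on top, so move the X label there too
    ax.xaxis.set_label_position("top")
    ax.set_xlabel("Predicted class")
    ax.set_ylabel("True class")

    # Write each value in its cell; white text stays readable on dark cells
    for i in range(m.shape[0]):
        for j in range(m.shape[1]):
            color = "white" if m[i, j] > 0.5 else "black"
            ax.text(j, i, f"{m[i, j]:.2f}", ha="center", va="center",
                    color=color)
    return fig, ax

m = np.array([[1.0, 0.0, 0.0, 0.0],
              [0.1, 0.7, 0.2, 0.0],
              [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.0, 0.0, 1.0]])
fig, ax = plot_confusion_matrix(m, ["A", "B", "C", "D"])
```

Call plt.show() afterwards to display the figure, or fig.savefig("confusion_matrix.png") to write it to a file instead.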

Hello,

You might be interested in

https://github.com/scls19fr/pandas_confusion/

It’s a Python implementation of a confusion matrix built on Pandas.

Some features:

– plot confusion matrix
– plot normalized confusion matrix
– class statistics
– overall statistics

Kind regards
