Code Yarns 👨‍💻

How to create a confusion matrix plot using Matplotlib

📅 2014-Oct-24 ⬩ ✍️ Ashwin Nanjappa ⬩ 🏷️ confusion matrix, matplotlib

[Figure: Confusion matrix plot generated using Matplotlib]

A confusion matrix is an excellent way to illustrate the results of multi-class classification. Generating a colorful confusion matrix plot takes a single function call in Matplotlib. However, you first need to have your results in the form of a confusion matrix.

Let me illustrate with an example. Assume you have 4 classes: A, B, C and D. Your classifier does great on A, C and D, classifying all of their samples correctly. However, samples whose true class is B are classified as A 10% of the time and as C 20% of the time. You should be able to extract such classification results from your classifier easily.

You just need to put these results into a 2D float NumPy array in the form of a confusion matrix. In this type of matrix, the true classes are typically listed along the Y axis (rows, top to bottom) and the predicted classes along the X axis (columns, left to right). Expressed as percentages of each true class, the confusion matrix for our example looks like this:

[[100 0  0   0  ]
 [10  70 20  0  ]
 [0   0  100 0  ]
 [0   0  0   100]]
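If your classifier gives you lists of true and predicted labels rather than a ready-made matrix, you can accumulate the counts with NumPy. This is a minimal sketch, assuming the classes A to D are encoded as the integers 0 to 3; the function name confusion_matrix_counts and the sample labels are hypothetical, chosen so that the result matches the matrix above.

```python
import numpy as np

def confusion_matrix_counts(y_true, y_pred, num_classes):
    """Hypothetical helper: row = true class index, column = predicted class index."""
    cm = np.zeros((num_classes, num_classes), dtype=float)
    # np.add.at accumulates correctly even when the same (row, col) pair repeats
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm

# Hypothetical labels: A, B, C, D -> 0, 1, 2, 3, with 100 samples per class.
# Class B (1) is predicted as A (0) 10 times and as C (2) 20 times.
y_true = [0] * 100 + [1] * 100 + [2] * 100 + [3] * 100
y_pred = [0] * 100 + [0] * 10 + [1] * 70 + [2] * 20 + [2] * 100 + [3] * 100

cm = confusion_matrix_counts(y_true, y_pred, num_classes=4)
print(cm)
```

A plain `cm[y_true, y_pred] += 1` would not work here, because NumPy's fancy-index assignment applies each repeated index pair only once; `np.add.at` is the unbuffered alternative.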

Optionally, you can also normalize each row so that it sums to 1.0:

[[1.0 0   0   0  ]
 [0.1 0.7 0.2 0  ]
 [0   0   1.0 0  ]
 [0   0   0   1.0]]
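Row normalization amounts to dividing every row by its own total. A minimal sketch with NumPy, assuming the count matrix from the example; the guard for rows with no samples (which would otherwise divide by zero) is an addition, not something the example needs.

```python
import numpy as np

counts = np.array([[100, 0, 0, 0],
                   [10, 70, 20, 0],
                   [0, 0, 100, 0],
                   [0, 0, 0, 100]], dtype=float)

# keepdims=True keeps row_sums as a (4, 1) column, so it broadcasts across each row
row_sums = counts.sum(axis=1, keepdims=True)

# Divide only where a row has samples; rows with a zero total stay all zeros
normalized = np.divide(counts, row_sums,
                       out=np.zeros_like(counts), where=row_sums != 0)
print(normalized)
```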

Once you have this as a 2D float NumPy array, just pass it to Matplotlib's matshow function to generate the confusion matrix plot. To show a color scale that maps each color in the plot to its value, call the colorbar function:

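import numpy as np
import matplotlib.pyplot as plt

# Row-normalized confusion matrix from the example above:
# rows are true classes A-D, columns are predicted classes A-D
m = np.array([[1.0, 0.0, 0.0, 0.0],
              [0.1, 0.7, 0.2, 0.0],
              [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.0, 0.0, 1.0]])

plt.matshow(m)   # draw the matrix as a grid of colored cells
plt.colorbar()   # add a color scale mapping colors to values
plt.show()       # open the plot window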

To add X-axis and Y-axis labels, tick labels or other modifications, use the same Matplotlib calls you would use for any other type of plot.
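For instance, the labelling can be done through Matplotlib's object-oriented interface. This is a minimal sketch, not the author's code: the class names, the Blues colormap, the output filename confusion_matrix.png and the off-screen Agg backend (so it runs without a display) are all assumptions. Since matshow places the X tick labels above the plot, the X-axis label is moved to the top to match.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend so this runs without a display
import matplotlib.pyplot as plt

# Hypothetical class names, in row/column order
classes = ["A", "B", "C", "D"]
m = np.array([[1.0, 0.0, 0.0, 0.0],
              [0.1, 0.7, 0.2, 0.0],
              [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.0, 0.0, 1.0]])

fig, ax = plt.subplots()
im = ax.matshow(m, cmap="Blues")  # colormap choice is arbitrary
fig.colorbar(im, ax=ax)

# Label every row and column with its class name
ax.set_xticks(np.arange(len(classes)))
ax.set_yticks(np.arange(len(classes)))
ax.set_xticklabels(classes)
ax.set_yticklabels(classes)

ax.set_xlabel("Predicted class")
ax.set_ylabel("True class")
ax.xaxis.set_label_position("top")  # matshow draws the X tick labels on top

# Write each value inside its cell: white text on dark cells, black on light ones
for i in range(m.shape[0]):
    for j in range(m.shape[1]):
        ax.text(j, i, "{:.1f}".format(m[i, j]), ha="center", va="center",
                color="white" if m[i, j] > 0.5 else "black")

fig.savefig("confusion_matrix.png")
```

With an interactive backend, you can drop the matplotlib.use line and call plt.show() instead of fig.savefig to display the plot in a window.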

Tried with: Python 2.7.6 and Ubuntu 14.04


© 2022 Ashwin Nanjappa • All writing under CC BY-SA license • 🐘 @codeyarns@hachyderm.io📧