📅 2014-Oct-24 ⬩ ✍️ Ashwin Nanjappa ⬩ 🏷️ confusion matrix, matplotlib ⬩ 📚 Archive
A confusion matrix is an excellent way to illustrate the results of multi-class classification. Matplotlib can render one as a color-coded plot with a single function call. However, you first need to have your results in the form of a confusion matrix.
Let me illustrate with an example. Assume you have 4 classes: A, B, C and D. Your classifier is fully accurate on A, C and D. However, samples that should be class B are classified as A 10% of the time and as C 20% of the time. You should be able to extract such classification results from your classifier easily.
You just need to put these results in a 2D float NumPy array in the form of a confusion matrix. In this type of matrix, the true classes are typically listed on the Y axis, from top to bottom, and the predicted classes on the X axis, from left to right. Each row therefore shows how the samples of one true class were distributed over the predicted classes. For our example, assuming 100 samples per class, the confusion matrix would look like this:
[[100 0 0 0 ]
[10 70 20 0 ]
[0 0 100 0 ]
[0 0 0 100]]
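If your classifier only gives you a list of true labels and a list of predicted labels, the matrix above can be built with a simple counting loop. This is a minimal sketch, not part of the original example: the names `classes`, `y_true`, `y_pred` and `counts` are hypothetical, and the label lists are made up to reproduce the 100-samples-per-class numbers shown above.

```python
import numpy as np

classes = ["A", "B", "C", "D"]
# Map each class name to its row/column position in the matrix
index = {name: i for i, name in enumerate(classes)}

# Made-up labels: 100 samples per true class, in the order A, B, C, D
y_true = ["A"] * 100 + ["B"] * 100 + ["C"] * 100 + ["D"] * 100
y_pred = (["A"] * 100                              # all A samples predicted as A
          + ["A"] * 10 + ["B"] * 70 + ["C"] * 20   # B samples split over A, B and C
          + ["C"] * 100                            # all C samples predicted as C
          + ["D"] * 100)                           # all D samples predicted as D

# Rows are true classes, columns are predicted classes
counts = np.zeros((len(classes), len(classes)), dtype=float)
for t, p in zip(y_true, y_pred):
    counts[index[t], index[p]] += 1

print(counts)
```

Libraries such as scikit-learn also provide a ready-made function for this step, but the loop makes the row and column convention explicit.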
Optionally, you can also normalize the results so that each row sums to 1.0:
[[1.0 0 0 0 ]
[0.1 0.7 0.2 0 ]
[0 0 1.0 0 ]
[0 0 0 1.0]]
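The normalization can be done in one step by dividing each row by its sum. The following is a small sketch under the assumption that every true class has at least one sample, so no row sum is zero; the names `counts` and `m` are chosen here for illustration.

```python
import numpy as np

# Raw counts: rows are true classes, columns are predicted classes
counts = np.array([[100, 0, 0, 0],
                   [10, 70, 20, 0],
                   [0, 0, 100, 0],
                   [0, 0, 0, 100]], dtype=float)

# Sum of each row, kept as a column vector so it broadcasts across the columns
row_sums = counts.sum(axis=1, keepdims=True)

# Each row now holds the fraction of that true class assigned to each prediction
m = counts / row_sums

print(m)
```

Normalizing per row keeps the colors comparable between classes even when the classes have different numbers of samples.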
Once you have this as a 2D float NumPy array, just pass it to the matshow function of matplotlib.pyplot to generate the confusion matrix plot. To show a color scale that maps the colors in the plot to their values, call the colorbar function:
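import numpy as np
import matplotlib.pyplot as plt

# Normalized confusion matrix: rows are true classes, columns are predicted classes
m = np.array([[1.0, 0.0, 0.0, 0.0],
              [0.1, 0.7, 0.2, 0.0],
              [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.0, 0.0, 1.0]])

plt.matshow(m)   # Draw the matrix as a grid of colored cells
plt.colorbar()   # Add the color scale next to the plot
plt.show()       # Display the figure when running as a script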
To add X-axis and Y-axis labels or make other modifications, use the same Matplotlib calls you would use for any other type of plot.
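As one possible illustration of such modifications, the sketch below labels the axes and puts the class names on the ticks. It uses the Axes object interface rather than the plt calls shown earlier; the variable names and label texts are assumptions made for this example.

```python
import numpy as np
import matplotlib.pyplot as plt

classes = ["A", "B", "C", "D"]
m = np.array([[1.0, 0.0, 0.0, 0.0],
              [0.1, 0.7, 0.2, 0.0],
              [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.0, 0.0, 1.0]])

fig, ax = plt.subplots()
image = ax.matshow(m)   # Same plot as plt.matshow, drawn on an explicit Axes
fig.colorbar(image)     # Color scale for the values in m

# Put one tick per class and show the class names instead of indices
ax.set_xticks(range(len(classes)))
ax.set_yticks(range(len(classes)))
ax.set_xticklabels(classes)
ax.set_yticklabels(classes)

# Rows are true classes, columns are predicted classes
ax.set_xlabel("Predicted class")
ax.set_ylabel("True class")
```

Call plt.show() afterwards to display the figure, or fig.savefig with a file name to write it to disk.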
Tried with: Python 2.7.6 and Ubuntu 14.04