How do you see how accurate a TensorFlow image classification model is for each class?



I’m working through the image classification tutorial on the TensorFlow website.

The model classifies flowers into one of 5 classes: daisy, dandelion, roses, sunflowers and tulips.

I can see what the overall accuracy is, but is there any way I can know how accurate it is for each class?

For example, my model could be very good at predicting daisies, dandelions, roses and sunflowers (near 100% accuracy) but poor at tulips (near 0%), and I’d still see about 80% overall accuracy (assuming the classes are balanced). I’d need to know the accuracy for each individual class to tell that model apart from one that predicts every class with roughly the same 80% accuracy.
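The arithmetic behind this example can be checked directly. This is a minimal sketch assuming a hypothetical balanced validation set of 100 images per class: when every class has the same number of samples, overall accuracy is just the mean of the per-class accuracies, so both scenarios land on the same 80%.

```python
import numpy as np

# Hypothetical balanced validation set: 100 images for each of the 5 classes.
n_per_class = np.array([100, 100, 100, 100, 100])

# Per-class accuracies for the two scenarios described above.
lopsided = np.array([1.0, 1.0, 1.0, 1.0, 0.0])  # perfect on four classes, useless on tulips
uniform = np.array([0.8, 0.8, 0.8, 0.8, 0.8])   # equally good on every class

def overall_accuracy(per_class_acc, counts):
    # Total correct predictions across all classes, divided by the total number of samples.
    return float(np.sum(per_class_acc * counts) / np.sum(counts))

print(overall_accuracy(lopsided, n_per_class))  # 0.8
print(overall_accuracy(uniform, n_per_class))   # 0.8
```

With imbalanced classes the overall figure becomes a weighted mean instead, so a weak minority class can hide even more easily.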


When I asked this question I didn’t have enough Python (or scikit-learn) knowledge to answer it. The classification report (as suggested by prashant0598) is close to what I need, although it only reports a single overall accuracy, not an accuracy for each class. Here’s how to use the classification report:

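import numpy as np
import pandas as pd
from sklearn.metrics import classification_report

# Collect labels and predictions in a single pass over val_ds so they stay aligned,
# even if the dataset is shuffled again each time it is iterated.
y_true, y_pred = [], []
for images, labels in val_ds:
    y_true.append(labels.numpy())
    y_pred.append(np.argmax(model.predict_on_batch(images), axis=1))

y_true = np.concatenate(y_true)
y_pred = np.concatenate(y_pred)

cr = classification_report(y_true, y_pred, output_dict=True, target_names=class_names)
pd.DataFrame(cr).transpose()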

The classification report gives per-class precision, recall, F1-score and support, which already says a lot about where the model struggles.
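To see what `output_dict=True` returns, here is a minimal sketch on hypothetical toy labels (three classes standing in for the flower classes; the values are made up, not taken from the tutorial). Each class name maps to its precision, recall, F1-score and support, while `accuracy` is a single number for the whole dataset.

```python
from sklearn.metrics import classification_report

# Hypothetical toy labels for three classes (0, 1, 2); in the real case these
# would be the y_true and y_pred arrays built from val_ds.
y_true = [0, 0, 0, 1, 1, 2, 2, 2]
y_pred = [0, 0, 1, 1, 1, 2, 0, 2]
names = ["daisy", "dandelion", "roses"]

cr = classification_report(y_true, y_pred, target_names=names, output_dict=True)

# One entry per class, each holding precision, recall, f1-score and support.
for name in names:
    print(name, {k: round(v, 3) for k, v in cr[name].items()})

# A single overall accuracy, plus the macro and weighted averages.
print("accuracy", cr["accuracy"])
print("macro avg", {k: round(v, 3) for k, v in cr["macro avg"].items()})
```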

To get the class accuracy out, we have to do this a bit more manually. Here’s one way:

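from sklearn.metrics import accuracy_score

def class_accuracy(class_no):
  # Keep only the samples whose true label is class_no...
  true_filter = y_true == class_no
  # ...and measure the fraction of them that were predicted correctly.
  acc = accuracy_score(y_true[true_filter], y_pred[true_filter])
  return acc

{class_name: class_accuracy(i) for i, class_name in enumerate(class_names)}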

{'daisy': 0.6589147286821705,
'dandelion': 0.75,
'roses': 0.6,
'sunflowers': 0.868421052631579,
'tulips': 0.6942675159235668}

So now I know: sunflowers are the easiest to predict, and roses are particularly tricky!
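With accuracy defined as in `class_accuracy` (the share of a class’s true samples that get the right label), per-class accuracy is exactly per-class recall, TP / (TP + FN), so the recall column of the classification report already holds these numbers. A minimal sketch on hypothetical toy labels (not the tutorial’s data) compares the masking approach with `recall_score(..., average=None)` and with the diagonal of the confusion matrix divided by its row sums:

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, recall_score

# Hypothetical toy labels; in practice use the y_true / y_pred arrays built from val_ds.
y_true = np.array([0, 0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 0, 1, 1, 1, 2, 0, 2])
labels = [0, 1, 2]

# 1) Masking approach: accuracy over the samples whose true label is k.
masked = np.array([accuracy_score(y_true[y_true == k], y_pred[y_true == k]) for k in labels])

# 2) Per-class recall: TP_k / (TP_k + FN_k), the same fraction.
recall = recall_score(y_true, y_pred, labels=labels, average=None)

# 3) Confusion matrix: rows are true classes, so diagonal / row sum gives it again.
cm = confusion_matrix(y_true, y_pred, labels=labels)
from_cm = cm.diagonal() / cm.sum(axis=1)

print(masked)
print(recall)
print(from_cm)
```

The confusion matrix version has the extra benefit of showing which classes the misclassified images end up in (for example, whether roses are mostly confused with tulips).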

Answered By – s_pike

This answer was collected from Stack Overflow and is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
