TransWikia.com

How to use decode_predictions() for non-ImageNet models?

Data Science Asked on October 1, 2021

I know that decode_predictions() only works for models trained on the ImageNet dataset (1000 classes), such as VGG16.
But consider my scenario.

My Scenario:

I used the pretrained VGG16 model and added my own weights.


So this is now a non-ImageNet model. I set classes=9 because I trained my previous model on only 9 classes.


To get predictions, I can use the predict() method, and print(answer) then gives me the corresponding numeric class label.
But I actually need the class name to be printed.
Is it possible to get the class name? If so, can anyone explain how?

One Answer

When you run prediction with a deep learning classifier, you get an array of probabilities; in your case, 9 probabilities per image, one per class. So first perform the following operation on them:

import numpy as np

# model: your trained 9-class model; test_data: a batch of preprocessed images
prediction = model.predict(test_data)
# prediction will contain e.g. [[0.6, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05]]

prediction = np.argmax(prediction[0])
# Now prediction holds the index of 0.6, i.e. the maximum probability value (here 0)
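The snippet above only decodes the first image (prediction[0]). As a minimal sketch, assuming the model returns one row of 9 probabilities per image, np.argmax with axis=1 decodes a whole batch at once. The probability values here are made up for illustration; in practice they would come from model.predict(test_data).

```python
import numpy as np

# Hypothetical model output for a batch of two images and 9 classes;
# in practice this would come from model.predict(test_data).
prediction = np.array([
    [0.6, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05],
    [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.3, 0.05, 0.05],
])

# argmax along axis 1 picks the most probable class for every image (row),
# instead of only the first image as prediction[0] does.
predicted_indices = np.argmax(prediction, axis=1)
print(predicted_indices)  # [0 6]
```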

Next, you need a dictionary whose keys are 0, 1, 2, ..., 8 and whose values are the matching class names (classname1, classname2, ..., classname9), in the same order as the model's output units.
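One way to build that dictionary, assuming you trained with Keras' ImageDataGenerator.flow_from_directory: the returned iterator's class_indices attribute maps each class name (subdirectory name) to its index, assigned in alphanumeric order, so inverting it gives the index-to-name dictionary. The sketch below uses a hand-written stand-in for class_indices with hypothetical class names so it runs without TensorFlow; replace them with your own 9 classes.

```python
# Hypothetical class names; replace with your own 9 classes.
class_names = ["apple", "banana", "cherry", "date", "fig",
               "grape", "kiwi", "lemon", "mango"]

# Stand-in for train_generator.class_indices (name -> index). Keras'
# flow_from_directory assigns indices in alphanumeric order of the subdirectories.
class_indices = {name: i for i, name in enumerate(sorted(class_names))}

# Invert name -> index into the index -> name dictionary described above.
index_to_class = {index: name for name, index in class_indices.items()}

print(index_to_class[0])  # apple
```

If you did not use flow_from_directory, write the dictionary by hand in the same order as the labels you trained with.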

Looking up the index returned by np.argmax in this dictionary gives you the class name as output.
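If you want something closer to decode_predictions() for a custom model, here is a minimal sketch of a hypothetical helper, decode_custom_predictions (not part of Keras). Keras' decode_predictions returns, per image, a list of (class_name, class_description, score) tuples for the ImageNet classes. This helper instead returns (class_name, probability) pairs from your own index-to-name dictionary, sorted by descending probability. The class names and probabilities are made up for illustration.

```python
import numpy as np

# Hypothetical index -> class name mapping; index i must match output unit i.
CLASS_NAMES = {0: "apple", 1: "banana", 2: "cherry", 3: "date", 4: "fig",
               5: "grape", 6: "kiwi", 7: "lemon", 8: "mango"}


def decode_custom_predictions(preds, class_names, top=3):
    """Hypothetical helper: for each image, return the `top` most probable
    classes as (class_name, probability) tuples, highest first."""
    results = []
    for row in np.asarray(preds):
        # argsort is ascending, so reverse it and keep the first `top` indices.
        top_indices = np.argsort(row)[::-1][:top]
        results.append([(class_names[int(i)], float(row[i])) for i in top_indices])
    return results


# Made-up output for one image; in practice use model.predict(test_data).
preds = [[0.05, 0.7, 0.1, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02]]
result = decode_custom_predictions(preds, CLASS_NAMES, top=2)
print(result)  # [[('banana', 0.7), ('cherry', 0.1)]]
```

With top=1 this reduces to the argmax-plus-dictionary lookup from the answer.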

Answered by Swapnil Pote on October 1, 2021
