TransWikia.com

Should I train the "Unknown" class separately from the other classes

Data Science Asked on July 25, 2021

I have a CNN model that classifies 10 classes of audio spectrograms. However, since I work with an open set of data, I need to assign unfamiliar audio to an "Unknown" class. The problem is that I have far more training samples of unknown data than of the known classes. I'm afraid this will cause a problem when the model performs stochastic optimization.

Should I separate the "Unknown" training data and train the model on it separately? Or can I simply mix the unknown data with the other classes and train the model right away?

2 Answers

There are several ways of doing this. Examples are:

Binary classifier

Train a separate binary classifier for Known vs Unknown, using supervised learning. The Known data would come from your dataset, and the Unknown data would be a large set of samples from a diverse dataset such as AudioSet.
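As a rough illustration of the binary-classifier route, here is a minimal sketch. It assumes the audio has already been turned into fixed-length feature vectors; the synthetic arrays, the choice of logistic regression, and the use of `class_weight="balanced"` to offset the larger Unknown set are all assumptions for the example, not part of the original answer.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)

# Hypothetical stand-ins for feature vectors (e.g. CNN embeddings).
# Known: samples from your own dataset. Unknown: a much larger, diverse set.
known = rng.normal(loc=0.0, scale=1.0, size=(300, 16))
unknown = rng.normal(loc=3.0, scale=1.0, size=(3000, 16))

X = np.vstack([known, unknown])
y = np.concatenate([np.zeros(len(known), dtype=int),    # 0 = Known
                    np.ones(len(unknown), dtype=int)])  # 1 = Unknown

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0)

# class_weight="balanced" reweights the loss inversely to class frequency,
# so the 10x larger Unknown set does not dominate training.
clf = LogisticRegression(class_weight="balanced", max_iter=1000)
clf.fit(X_train, y_train)

# Report recall per class rather than accuracy, which is misleading
# when one class is much larger than the other.
known_recall = (clf.predict(X_test[y_test == 0]) == 0).mean()
unknown_recall = (clf.predict(X_test[y_test == 1]) == 1).mean()
print(f"Known recall: {known_recall:.2f}, Unknown recall: {unknown_recall:.2f}")
```

At inference time, a sample would first go through this gate, and only samples predicted as Known would be passed on to the 10-class CNN.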

Anomaly detector

Train an anomaly / out-of-distribution detection model, using only your dataset (Known) and unsupervised learning. This should be done on the learned representation in your CNN. You can use a Gaussian Mixture Model (e.g. from scikit-learn) as the anomaly model. To verify that it works, and to set hyperparameters such as the number of Gaussians and the anomaly threshold, you should use a few samples from another dataset (e.g. AudioSet).
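A minimal sketch of the anomaly-detector idea with scikit-learn's `GaussianMixture` follows. The embeddings here are synthetic placeholders (in practice they would be activations from a late layer of the CNN), and the component count, covariance type, and 1st-percentile threshold are assumed starting values that should be tuned on held-out data.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)

# Hypothetical stand-ins for CNN embeddings of audio clips.
known_train = rng.normal(loc=0.0, scale=1.0, size=(2000, 16))
known_val = rng.normal(loc=0.0, scale=1.0, size=(200, 16))
unknown_val = rng.normal(loc=6.0, scale=1.0, size=(200, 16))  # e.g. from AudioSet

# Fit the density model on Known embeddings only (no labels needed).
gmm = GaussianMixture(n_components=2, covariance_type="diag", random_state=0)
gmm.fit(known_train)

# score_samples returns the per-sample log-likelihood under the mixture.
# Assumed threshold: the 1st percentile of the training scores.
threshold = np.percentile(gmm.score_samples(known_train), 1)

def is_unknown(embeddings):
    """Flag samples whose log-likelihood falls below the threshold."""
    return gmm.score_samples(embeddings) < threshold

# Validate with a few held-out Known samples and a few outside samples.
false_alarm_rate = is_unknown(known_val).mean()
detection_rate = is_unknown(unknown_val).mean()
print(f"False alarms: {false_alarm_rate:.2%}, detected: {detection_rate:.2%}")
```

Raising the threshold catches more unknown audio at the cost of rejecting more known audio, so the small outside validation set is what determines where to put it.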

Correct answer by Jon Nordby on July 25, 2021

There is indeed a risk of overfitting to the "unknown" class (because it has far more samples) at the expense of the other classes, which could lead to known samples being wrongly classified as "unknown".

It mainly depends on how close the "unknown" data is to the "known" data.

Here are 3 potential solutions:

  • A simple one is to use a random sample of "unknown" data that is larger than the average amount of "known" data per class (but not too large: twice as large, for instance).

  • A more principled one is to use a representative sample drawn from a multivariate normal distribution of the unknown data. https://juanitorduz.github.io/multivariate_normal/

  • An advanced one is to select the "unknown" data that is closest to the "known" data using an unsupervised model such as clustering. https://scikit-learn.org/stable/unsupervised_learning.html Then add some more unknown data so that the model generalizes better.
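The first of these options can be sketched in a few lines of NumPy. The helper name `subsample_unknown`, the label value 10 for "Unknown", and the class sizes are hypothetical; the factor of 2 follows the "twice as large" suggestion above.

```python
import numpy as np

def subsample_unknown(labels, unknown_label, ratio=2.0, seed=0):
    """Return indices keeping all known samples and a random subset of
    unknown samples sized at `ratio` times the average known class size."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)

    unknown_idx = np.flatnonzero(labels == unknown_label)
    known_idx = np.flatnonzero(labels != unknown_label)

    # Average number of samples per known class.
    n_known_classes = len(np.unique(labels[known_idx]))
    target = int(ratio * len(known_idx) / n_known_classes)
    target = min(target, len(unknown_idx))  # cannot keep more than exist

    kept_unknown = rng.choice(unknown_idx, size=target, replace=False)
    return np.sort(np.concatenate([known_idx, kept_unknown]))

# Example: 10 known classes with 100 samples each, 5000 unknown samples.
labels = np.array([k for k in range(10) for _ in range(100)] + [10] * 5000)
kept = subsample_unknown(labels, unknown_label=10)
counts = np.bincount(labels[kept])
print(counts)  # per-class sample counts after subsampling
```

The returned indices can be used to slice both the spectrogram array and the label array before training, so the "unknown" class stays somewhat larger than a typical known class without dominating the mini-batches.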

Answered by Nicolas M on July 25, 2021
