Normalize training data with channel means and standard deviations in a CNN model

Stack Overflow Asked on November 18, 2021

I am using a CNN for multi-class image classification, but the accuracy is not very good. I suspect that normalizing the training data with the per-channel means and standard deviations might improve it. I came up with one way of doing this, but it is not very principled, because I just plugged in arbitrary values for the means and standard deviations. I am not sure how to compute the actual channel means and standard deviations. Can anyone point me to a way of doing this?

My current attempt:

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Dropout, Flatten, Input
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical

(X_train, y_train), (X_test, y_test) = cifar10.load_data()
output_class = np.unique(y_train)
n_class = len(output_class)

input_shape = (32, 32, 3)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
y_train_one_hot = to_categorical(y_train)
y_test_one_hot = to_categorical(y_test)

x = Input(shape=(32, 32, 3))
# input_shape is not needed here: the Input tensor already fixes the shape
conv = Conv2D(128, (3, 3), activation='relu')(x)
conv = MaxPooling2D(pool_size=(2, 2))(conv)
conv = Conv2D(64, (2, 2))(conv)
conv = MaxPooling2D(pool_size=(2, 2))(conv)
conv = Flatten()(conv)
conv = Dense(64, activation='relu')(conv)
conv = Dense(n_class, activation='softmax')(conv)  # n_class == 10 for CIFAR-10
model = Model(inputs=x, outputs=conv)

My attempt at normalization, where I simply assigned arbitrary values to the means and standard deviations:

mean = [125.307, 122.95, 113.865]  # hard-coded values, not computed from the data
std = [62.9932, 62.0887, 66.7048]  # hard-coded values, not computed from the data

for i in range(3):
  X_train[:,:,:,i] = (X_train[:,:,:,i] - mean[i]) / std[i]
  X_test[:,:,:,i] = (X_test[:,:,:,i] - mean[i]) / std[i]
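The per-channel loop above can also be written with NumPy broadcasting, because a length-3 array lines up with the last (channel) axis of an (N, 32, 32, 3) batch. A minimal sketch, assuming a small random array as a hypothetical stand-in for X_train and reusing the same hard-coded mean and std values; it checks that both versions give the same result:

```python
import numpy as np

# Hypothetical stand-in for X_train: 8 RGB images of 32x32 with values in [0, 255]
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(8, 32, 32, 3)).astype('float32')

mean = np.array([125.307, 122.95, 113.865], dtype='float32')  # one value per channel
std = np.array([62.9932, 62.0887, 66.7048], dtype='float32')  # one value per channel

# Loop version, as in the question: normalize one channel at a time
looped = images.copy()
for i in range(3):
    looped[:, :, :, i] = (looped[:, :, :, i] - mean[i]) / std[i]

# Broadcast version: shape (3,) is matched against the trailing channel axis
broadcast = (images - mean) / std

print(np.allclose(looped, broadcast))  # True
```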

Is there a way to compute the channel means and standard deviations programmatically so that I can use them for normalization? Is there a better approach to normalization in general, and what else could I do to improve the accuracy of this model?

2 Answers

I believe you could normalize the data this way, which is more promising:

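import numpy as np
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical

(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0
nb_classes = 10
Y_train = to_categorical(y_train, nb_classes)
Y_test = to_categorical(y_test, nb_classes)

## find the per-channel mean and std on the training set, then normalize both splits with them
## (axis=(0, 1, 2) reduces over images, rows and columns, leaving one value per channel)
train_mean = np.mean(X_train, axis=(0, 1, 2))
train_std = np.std(X_train, axis=(0, 1, 2))
X_train = (X_train - train_mean) / train_std
X_test = (X_test - train_mean) / train_std

## then do training ....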

Hope this is the kind of normalization you were after. Let me know if you have any questions :)
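A note on the reduction axes: reducing with axis=0 alone averages only over the image axis, so it yields one statistic per pixel position and channel (shape (32, 32, 3)). Per-channel statistics need a reduction over the image, height and width axes together, i.e. axis=(0, 1, 2). A minimal sketch, assuming channels-last data like CIFAR-10, with random arrays as hypothetical stand-ins for the real dataset and channel_stats as a hypothetical helper name:

```python
import numpy as np

def channel_stats(x):
    """Per-channel mean and std for a channels-last batch of shape (N, H, W, C)."""
    # Reduce over images, rows and columns, leaving one value per channel
    return x.mean(axis=(0, 1, 2)), x.std(axis=(0, 1, 2))

# Hypothetical stand-ins for X_train / X_test, already scaled to [0, 1]
rng = np.random.default_rng(42)
x_train = rng.random((100, 32, 32, 3))
x_test = rng.random((20, 32, 32, 3))

mean, std = channel_stats(x_train)
print(mean.shape, std.shape)  # (3,) (3,)

# Normalize both splits with the *training* statistics only
x_train_norm = (x_train - mean) / std
x_test_norm = (x_test - mean) / std

# For comparison: axis=0 alone keeps one statistic per pixel position and channel
print(x_train.mean(axis=0).shape)  # (32, 32, 3)
```

Using the training statistics for the test split keeps the test data from influencing preprocessing; the normalized training set then has mean 0 and standard deviation 1 in each channel.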

Answered by jyson on November 18, 2021

To normalize training image values from 0-255 to 0-1, you can simply divide them by 255.

import numpy as np

x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
# subtract the per-pixel mean (the average training image), computed on the training set only
x_train_mean = np.mean(x_train, axis=0)
x_train -= x_train_mean
x_test -= x_train_mean
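Here np.mean(x_train, axis=0) is a per-pixel mean: an average image of shape (32, 32, 3), computed from the training set and subtracted from both splits. A minimal sketch, assuming random integer arrays as hypothetical stand-ins for the CIFAR-10 images, showing that every pixel position of the centered training set then averages to zero:

```python
import numpy as np

# Hypothetical stand-ins for CIFAR-10 images (integer pixel values 0-255)
rng = np.random.default_rng(1)
x_train = rng.integers(0, 256, size=(50, 32, 32, 3))
x_test = rng.integers(0, 256, size=(10, 32, 32, 3))

# Scale to [0, 1] as in the answer (float64 here for stable statistics)
x_train = x_train.astype('float64') / 255
x_test = x_test.astype('float64') / 255

# The "mean image": one value per pixel position and channel
x_train_mean = np.mean(x_train, axis=0)
print(x_train_mean.shape)  # (32, 32, 3)

# Subtract the training mean image from both splits (never a test-set mean)
x_train -= x_train_mean
x_test -= x_train_mean
```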

Please note that the main reason your network does not reach good accuracy is that it is shallow. Try increasing the number of Conv2D layers and the number of filters in them. You have not shown your optimizer settings either, but Adam with a learning rate of 0.01 is a good start (the Keras default learning rate for Adam is 0.001).
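A minimal sketch of that suggestion, assuming TensorFlow 2.x with its bundled Keras: a deeper stack of Conv2D blocks compiled with Adam, with the learning rate exposed as a parameter and set here to the 0.01 suggested in this answer. The layer counts, filter sizes and dropout rates are illustrative choices, not tuned values, and build_deeper_cnn is a hypothetical helper name.

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

def build_deeper_cnn(n_class=10, learning_rate=0.01):
    """Illustrative deeper CNN for 32x32 RGB inputs (layer sizes are not tuned)."""
    inputs = tf.keras.Input(shape=(32, 32, 3))
    x = inputs
    # Three blocks of two 3x3 convolutions each, doubling the filters per block
    for filters in (32, 64, 128):
        x = layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
        x = layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)
        x = layers.Dropout(0.25)(x)
    x = layers.Flatten()(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(n_class, activation='softmax')(x)

    model = Model(inputs=inputs, outputs=outputs)
    # categorical_crossentropy matches the one-hot labels from to_categorical
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model

model = build_deeper_cnn()
model.summary()
```

If training with 0.01 turns out unstable, lowering learning_rate toward the 0.001 default is a natural first thing to try.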

Answered by Nima Aghli on November 18, 2021


© 2023 All rights reserved. Sites we Love: PCI Database, UKBizDB, Menu Kuliner, Sharing RPP