How to use pre-trained weights for training a convolutional NN in TensorFlow?

Stack Overflow Asked by kim on August 28, 2020

In my experiment, I want to train a convolutional NN (CNN) on CIFAR-10 using ImageNet pre-trained weights, and I used ResNet50. CIFAR-10 images are 32x32x3, while ResNet50 uses 224x224x3, so I need to resize the input images before I can train on top of the ImageNet weights. I came up with the following attempt to train a simple CNN on top of ImageNet-pretrained ResNet50:

My current attempt:

Please see my whole implementation in this gist:

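```python
from tensorflow.keras import models
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Activation, Conv2D, Dense, Flatten, MaxPooling2D

# ResNet50 with ImageNet weights, without its 1000-class classification head
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
x = Conv2D(32, (3, 3))(base_model.output)
x = Activation('relu')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Flatten()(x)
x = Dense(256)(x)
x = Dense(10)(x)
x = Activation('softmax')(x)
outputs = x
model = models.Model(base_model.input, outputs)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# X_train / X_test: CIFAR-10 images resized to 224x224x3; y_train / y_test: one-hot labels
model.fit(X_train, y_train, batch_size=50, epochs=3, verbose=1, validation_data=(X_test, y_test))
```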

But this attempt gave me a ResourceExhaustedError. I have run into this error before, and reducing batch_size removed it. Now, even with batch_size as small as possible, I still end up with the error. I am wondering whether the way I train the CNN on top of the ImageNet weights above is not correct, or whether something else is wrong in my attempt.
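A back-of-the-envelope check of why a smaller batch_size may no longer help: CIFAR-10 has 50,000 training images. Assuming (the gist is not shown) that all of them are resized to 224x224x3 up front and stored as float32, the input array alone is 49x larger in pixels and 4x larger per value than the original uint8 data. The helper name dataset_bytes is hypothetical; this is plain arithmetic, not a TensorFlow API.

```python
def dataset_bytes(num_images, height, width, channels=3, bytes_per_value=4):
    """Bytes needed to hold num_images images as one dense array.
    bytes_per_value=4 corresponds to float32, 1 to uint8."""
    return num_images * height * width * channels * bytes_per_value

GIB = 1024 ** 3
cifar_train = 50_000  # CIFAR-10 training set size

original = dataset_bytes(cifar_train, 32, 32, bytes_per_value=1)  # uint8, as downloaded
resized = dataset_bytes(cifar_train, 224, 224)                     # float32 after resizing

print(f"32x32 uint8:     {original / GIB:.2f} GiB")  # about 0.14 GiB
print(f"224x224 float32: {resized / GIB:.2f} GiB")   # about 28.04 GiB
print(resized // original)                           # 196 = 49 (pixels) * 4 (bytes)
```

Roughly 28 GiB for the training inputs alone (before any copy to a tensor or to the GPU) is independent of batch_size, which would be consistent with the error persisting even at very small batch sizes.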


I want to understand how to use pre-trained weights (i.e., ResNet50 trained on ImageNet) to train a convolutional NN; I am not sure how to get this done in TensorFlow. Can anyone provide a feasible approach to get this right? Thanks
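One common transfer-learning pattern, sketched here under the assumption of TensorFlow 2.x tf.keras (not necessarily the only right way): freeze the ImageNet-pretrained ResNet50, add a small classifier on top, and let an UpSampling2D layer resize 32x32 inputs to 224x224 inside the model, so the dataset itself never has to be stored at 224x224. The function name build_transfer_model is hypothetical, and the GlobalAveragePooling2D + Dropout head is a design choice that replaces the question's Conv2D/Flatten head.

```python
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input


def build_transfer_model(num_classes=10, weights="imagenet"):
    """Frozen ResNet50 feature extractor + small classifier for 32x32x3 inputs."""
    base = ResNet50(weights=weights, include_top=False, input_shape=(224, 224, 3))
    base.trainable = False  # keep the pretrained weights fixed in this first phase

    inputs = layers.Input(shape=(32, 32, 3))        # raw CIFAR-10 pixels in [0, 255]
    x = layers.UpSampling2D(size=(7, 7))(inputs)    # 32 * 7 = 224, done per batch in the model
    x = preprocess_input(x)                         # ResNet50's ImageNet preprocessing
    x = base(x, training=False)                     # keep BatchNorm in inference mode
    x = layers.GlobalAveragePooling2D()(x)          # (7, 7, 2048) -> (2048,)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    model = models.Model(inputs, outputs)
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",  # integer labels, as from cifar10.load_data()
                  metrics=["accuracy"])
    return model, base
```

Usage, assuming the data comes from tf.keras.datasets.cifar10.load_data(): model, base = build_transfer_model(), then model.fit(x_train.astype("float32"), y_train, batch_size=64, epochs=3, validation_data=(x_test.astype("float32"), y_test)). UpSampling2D uses nearest-neighbour interpolation by default; passing interpolation="bilinear" is an option.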

Can anyone point out what went wrong with my attempt? What would be the correct way of training a state-of-the-art CNN model on CIFAR-10 with ImageNet weights? Can anyone share thoughts or an efficient way of doing this in TensorFlow? Any idea? Thanks!
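A frequently used follow-up step, once a classifier on top of a frozen pretrained base has converged, is fine-tuning: unfreeze only the top part of the base and continue training with a much smaller learning rate. The sketch below assumes TensorFlow 2.x tf.keras, a functional model whose pretrained base is passed in separately as base, and labels as integers; the function name unfreeze_top_layers and the values 30 layers and 1e-5 are illustrative assumptions, not tuned settings. BatchNormalization layers are left frozen, as is common practice in the Keras fine-tuning examples.

```python
import tensorflow as tf
from tensorflow.keras import layers


def unfreeze_top_layers(model, base, num_layers=30, learning_rate=1e-5):
    """Make the last num_layers layers of the pretrained base trainable
    (except BatchNormalization) and recompile with a small learning rate."""
    base.trainable = True                      # propagates to all sub-layers
    for layer in base.layers[:-num_layers]:
        layer.trainable = False                # keep the lower layers frozen
    for layer in base.layers[-num_layers:]:
        if isinstance(layer, layers.BatchNormalization):
            layer.trainable = False            # keep BatchNorm statistics and scales fixed
    # Recompiling is required for the trainable changes to take effect
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model
```

After calling it, a few more epochs of model.fit continue training with the unfrozen top layers; the small learning rate is meant to adjust the ImageNet features only slightly rather than overwrite them.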

2 Answers

You could use a filter to help downsample the matrix. Mathematically speaking, with stride 1 and no padding, a 213x213 kernel generates a 32x32 output from a 244x244 input (244 - 213 + 1 = 32). Something like:

```python
Conv2D(32, (213, 213), strides=(1, 1), input_shape=(244, 244, 3))
```
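The kernel size needed for a given output follows from the standard output-size formula for a convolution, out = floor((n + 2p - k) / s) + 1, where p = 0 corresponds to Keras' padding='valid'. A plain-Python sketch (the helper name conv_output_size is hypothetical) shows that a 212x212 kernel on a 244x244 input gives 33x33, so 213x213 is needed for exactly 32x32, and for ResNet50's 224x224 input it would be 193x193:

```python
def conv_output_size(n, k, stride=1, padding=0):
    """Spatial output size of a k x k convolution over an n x n input.
    padding=0 matches Keras padding='valid'."""
    return (n + 2 * padding - k) // stride + 1


print(conv_output_size(244, 212))  # 33
print(conv_output_size(244, 213))  # 32
print(conv_output_size(224, 193))  # 32
```

Note that such a kernel is very large: 213 * 213 * 3 input channels * 32 filters is about 4.4 million weights in a single layer.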

Link to the documentation on kernel size and stride for Keras Conv2D

How filters in CNNs work

Here's a link to an article that simplifies it.

Answered by stackz on August 28, 2020

You might be getting this error because you are trying to allocate memory (RAM) for the whole dataset at once. For starters, you might be using a NumPy array to store the images, and those images are then converted to tensors, so you already use twice the memory before creating anything else. On top of that, ResNet is a very heavy model, and you are trying to pass the whole dataset at once; that is why models work with batches. Try to create a generator (see the documentation), or use the very easy keras.preprocessing.image.ImageDataGenerator class. You can save the paths of your image files in a DataFrame column, with another column holding the class, and use .flow_from_dataframe. Or you can use .flow_from_directory if your images are saved in a directory.
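The same batching idea can also be expressed with tf.data instead of ImageDataGenerator (a different, standard TensorFlow API swapped in here, not what the answer names). This sketch, assuming TensorFlow 2.x and in-memory CIFAR-10 arrays, keeps the data at 32x32 and resizes each batch to 224x224 only when it is fed to the model, so at most one batch of large images exists at a time. The function name make_dataset is hypothetical.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import preprocess_input


def make_dataset(images, labels, batch_size=32, size=224, shuffle=True):
    """Stream (images, labels) in batches, resizing each batch on the fly."""
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    if shuffle:
        ds = ds.shuffle(buffer_size=len(images))
    ds = ds.batch(batch_size)

    def _prep(x, y):
        x = tf.image.resize(tf.cast(x, tf.float32), (size, size))  # 32x32 -> size x size
        return preprocess_input(x), y                              # ImageNet preprocessing

    ds = ds.map(_prep, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.prefetch(tf.data.AUTOTUNE)                           # overlap input prep with training
```

Labels pass through unchanged, so one-hot labels (for categorical_crossentropy, as in the question) and integer labels both work, e.g. model.fit(make_dataset(X_train, y_train, batch_size=50), validation_data=make_dataset(X_test, y_test, shuffle=False), epochs=3) where X_train holds the original 32x32 images.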

Check out the documentation.

Answered by Deshwal on August 28, 2020
