AnswerBun.com

How to use pre-trained weights for training a convolutional NN in TensorFlow?

Stack Overflow Asked by kim on August 28, 2020

In my experiment, I want to train a convolutional NN (CNN) on CIFAR-10 using ImageNet pre-trained weights, and I used ResNet50. CIFAR-10 is a set of 32x32x3 images, while ResNet50 uses 224x224x3 inputs, so I need to resize the input images in order to use the ImageNet weights. I came up with the following attempt to train a simple CNN on top of the ImageNet model:

My current attempt:

Please see my whole implementation in this gist:

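from tensorflow.keras import models
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Activation, Conv2D, Dense, Flatten, MaxPooling2D
# Assumed (defined in the gist): X_train/X_test are CIFAR-10 images resized to 224x224x3,
# y_train/y_test are one-hot labels (required by categorical_crossentropy).
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
x = Conv2D(32, (3, 3))(base_model.output)
x = Activation('relu')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Flatten()(x)
x = Dense(256)(x)
x = Dense(10)(x)
x = Activation('softmax')(x)
outputs = x
model = models.Model(base_model.input, outputs)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X_train, y_train, batch_size=50, epochs=3, verbose=1, validation_data=(X_test, y_test))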

But this attempt gave me a ResourceExhaustedError. I had run into this error before, and reducing batch_size removed it, but now, even with batch_size as small as possible, I still end up with the error. I am wondering whether the way I am training the CNN on ImageNet weights above is incorrect, or whether something else is wrong in my attempt.

Update:

I want to understand how to use pre-trained weights (i.e., ResNet50 trained on ImageNet) to train a convolutional NN; I am not sure how to do this in TensorFlow. Can anyone suggest a feasible approach to get this right? Thanks

Can anyone point out what went wrong with my attempt? What is the correct way to train a state-of-the-art CNN model on CIFAR-10 using ImageNet weights? Can anyone share thoughts or an efficient way of doing this in TensorFlow? Any ideas? Thanks!

2 Answers

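You could use a filter to help downsample the matrix. Mathematically speaking, with stride 1 and 'valid' padding the output size is input - kernel + 1, so a 213x213 kernel generates a 32x32 output from a 244x244 input (for ResNet50's 224x224 input, a 193x193 kernel gives the same 32x32 result). Something like:

from tensorflow.keras.layers import Conv2D
Conv2D(32, (213, 213), strides=(1, 1), input_shape=(244, 244, 3))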
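To check the kernel-size arithmetic, here is a small sketch of the standard convolution output-size formula, out = floor((n + 2p - k) / s) + 1. The helper name conv_output_size is hypothetical; Keras padding='valid' corresponds to padding=0.

```python
def conv_output_size(n, k, stride=1, padding=0):
    """Spatial output size of a convolution along one dimension.

    n: input size, k: kernel size, stride: step between windows,
    padding: zeros added on each side (0 for Keras padding='valid').
    """
    return (n + 2 * padding - k) // stride + 1


print(conv_output_size(244, 212))  # 33: one pixel more than intended
print(conv_output_size(244, 213))  # 32
print(conv_output_size(224, 193))  # 32, for ResNet50's 224x224 input
```

With stride 1 and no padding the formula reduces to k = n - out + 1, which is why the kernel has to be one pixel larger than the plain difference n - out.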

Link to documentation on kernel size and stride for Keras Conv2D

How filters in CNNs work

Here's a link to an article that simplifies it.

Answered by stackz on August 28, 2020

You might be getting this error because you are trying to allocate memory (RAM) for the whole dataset at once. For starters, you are probably storing the images in a NumPy array, and those images are then converted to tensors, so you already use twice the memory before creating anything. On top of that, ResNet is a very heavy model, and you are handing it the whole dataset at once; that is why models work with batches. Try creating a generator with tf.data.Dataset (see its documentation) or use the very easy keras.preprocessing.image.ImageDataGenerator class. You can save the paths of your image files in a DataFrame column, with another column representing the class, and use .flow_from_dataframe. Or you can use .flow_from_directory if you have your images saved in a directory.
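To put rough numbers on the memory argument, here is a back-of-the-envelope estimate. It assumes X_train holds all 50,000 CIFAR-10 training images resized to 224x224x3 and stored as float32; the question does not show how X_train was built, so that is an assumption. The helper name dataset_bytes is hypothetical.

```python
def dataset_bytes(num_images, height, width, channels, bytes_per_value):
    """Bytes needed to hold an image set as one dense array."""
    return num_images * height * width * channels * bytes_per_value


# CIFAR-10 ships 50,000 training images as uint8 (1 byte per value).
original = dataset_bytes(50_000, 32, 32, 3, 1)
# Assumption: the same images resized to 224x224 and stored as float32 (4 bytes).
resized = dataset_bytes(50_000, 224, 224, 3, 4)
# A single resized input batch of 50 images, for comparison.
one_batch = dataset_bytes(50, 224, 224, 3, 4)

print(f"original:  {original / 1e9:.2f} GB")   # original:  0.15 GB
print(f"resized:   {resized / 1e9:.2f} GB")    # resized:   30.11 GB
print(f"one batch: {one_batch / 1e6:.2f} MB")  # one batch: 30.11 MB
```

About 30 GB for the resized training set alone (before any duplicate NumPy/tensor copy) would not fit in typical RAM or GPU memory, which could explain why shrinking batch_size did not help, while one resized batch is only about 30 MB.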

Check out the documentation.
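A minimal sketch of the tf.data approach applied to the question's setup, assuming TensorFlow 2.4+ (for tf.data.AUTOTUNE) and arrays shaped like those from tf.keras.datasets.cifar10.load_data() (uint8 images, integer labels). The helper names make_dataset and build_model are hypothetical. Images are resized to 224x224 per batch inside the pipeline, so only the 32x32 originals are held in memory. It also shows one common way to use the pre-trained weights: freeze the ImageNet ResNet50 base and train only a small new head. Note that this swaps the question's Conv2D/Flatten head for GlobalAveragePooling2D; it is an illustration, not the only valid design.

```python
import tensorflow as tf

IMG_SIZE = 224  # ResNet50's default ImageNet input size


def make_dataset(images, labels, batch_size=32, training=True):
    """Batched pipeline that resizes 32x32 images to 224x224 on the fly."""
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    if training:
        ds = ds.shuffle(buffer_size=min(int(images.shape[0]), 10_000))
    ds = ds.batch(batch_size)

    def preprocess(x, y):
        # Resize one batch at a time, then apply ResNet50's own input preprocessing.
        x = tf.image.resize(tf.cast(x, tf.float32), (IMG_SIZE, IMG_SIZE))
        x = tf.keras.applications.resnet50.preprocess_input(x)
        return x, y

    ds = ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.prefetch(tf.data.AUTOTUNE)


def build_model(num_classes=10, weights="imagenet"):
    """Frozen ResNet50 feature extractor plus a small trainable classifier."""
    base = tf.keras.applications.ResNet50(
        weights=weights, include_top=False, input_shape=(IMG_SIZE, IMG_SIZE, 3))
    base.trainable = False  # keep the pre-trained convolutional weights fixed
    inputs = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
    x = base(inputs, training=False)  # run BatchNorm layers in inference mode
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
    model = tf.keras.Model(inputs, outputs)
    # Integer labels, so the sparse variant of the cross-entropy loss is used.
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model
```

With the real data this would be used roughly as model = build_model() (which downloads the ImageNet weights) followed by model.fit(make_dataset(X_train, y_train), validation_data=make_dataset(X_test, y_test, training=False), epochs=3). Keeping the labels as integers is why the sketch uses sparse_categorical_crossentropy instead of the question's categorical_crossentropy.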

Answered by Deshwal on August 28, 2020

© 2022 AnswerBun.com. All rights reserved.