TransWikia.com

How to save model architecture in PyTorch?

Stack Overflow Asked on November 22, 2021

I know I can save a model with torch.save(model.state_dict(), FILE) or torch.save(model, FILE). But neither of them saves the architecture of the model.

So how can we save the architecture of a model in PyTorch, like creating a .pb file in TensorFlow? I want to apply different tweaks to my model. If I can't save the architecture of a model, is there a better way than copying the whole class definition and creating a new class every time?

3 Answers

Regarding the actual question:

So how can we save the architecture of a model in PyTorch like creating a .pb file in TensorFlow?

The answer is: you cannot.

Is there any way to load a trained model without declaring the class definition beforehand? I want the model architecture as well as the parameters to be loaded.

No, you have to load the class definition first; this is a Python pickling limitation.
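To see the pickling limitation concretely, here is a minimal sketch. torch.save(model, FILE) pickles the model object, and pickle records only where the class lives (its module and name), not the class code. The sketch uses plain pickle instead of torch.save to stay short; TinyNet is a hypothetical model, and deleting the class stands in for loading in a fresh process that never imported it.

```python
import pickle

import torch.nn as nn


class TinyNet(nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# The pickle holds the weights plus a *reference* to the class
# ("__main__" + "TinyNet"), not the class body itself.
blob = pickle.dumps(TinyNet())
stores_class_name = b"TinyNet" in blob

# Simulate a fresh process in which TinyNet was never defined.
del TinyNet
try:
    pickle.loads(blob)
    load_failed = False
except (AttributeError, pickle.UnpicklingError) as err:
    load_failed = True
    print("Unpickling failed:", err)
```

Run as a script, the last step fails because pickle cannot find TinyNet in the __main__ module.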

https://discuss.pytorch.org/t/how-to-save-load-torch-models/718/11

That said, there are other options (you have probably seen most of them already), listed in this PyTorch tutorial:

https://pytorch.org/tutorials/beginner/saving_loading_models.html
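The approach that tutorial recommends is to save only the state_dict and keep the class definition in your code. A minimal sketch, assuming a hypothetical TinyNet model and an in-memory buffer in place of a real FILE path:

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):  # hypothetical model; must be importable at load time
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# Save: only the parameters go into the file (a BytesIO stands in for FILE).
trained = TinyNet()
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)

# Load: rebuild the architecture from the class, then restore the weights.
buffer.seek(0)
restored = TinyNet()
restored.load_state_dict(torch.load(buffer))
restored.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.allclose(trained(x), restored(x))
```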

Answered by Xxxo on November 22, 2021

Saving all the parameters (state_dict) and all the Modules is not enough, since some operations that manipulate the tensors are reflected only in the actual code of the specific implementation (e.g., reshaping in ResNet).
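For example, take a model whose tail mirrors the pool, flatten, fc sequence in torchvision's ResNet. The flatten call exists only in forward(). SmallConvNet below is hypothetical; its state_dict holds weights and biases only, so nothing in it says that a reshape happens between the pooling and the linear layer:

```python
import torch
import torch.nn as nn


class SmallConvNet(nn.Module):  # hypothetical model with a ResNet-like tail
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)  # no parameters
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = self.pool(self.conv(x))  # shape (N, 8, 1, 1)
        x = torch.flatten(x, 1)      # shape (N, 8): this step lives only in code
        return self.fc(x)


model = SmallConvNet()
param_names = sorted(model.state_dict().keys())
print(param_names)  # ['conv.bias', 'conv.weight', 'fc.bias', 'fc.weight']
out = model(torch.randn(2, 3, 16, 16))
```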

Furthermore, the network might not have a fixed, predetermined compute graph: think of a network that has branching or a loop (recurrence).
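A small illustration of why a single recorded graph is not enough for such networks. BranchyNet is hypothetical, and torch.jit.trace (TorchScript tracing) is used here only as one convenient way to record the ops executed for one example input. The trace keeps the branch taken for that input and reuses it for every other input (PyTorch emits a TracerWarning at trace time, suppressed below):

```python
import warnings

import torch
import torch.nn as nn


class BranchyNet(nn.Module):  # hypothetical model with input-dependent control flow
    def __init__(self):
        super().__init__()
        self.pos_branch = nn.Linear(4, 2)
        self.neg_branch = nn.Linear(4, 2)

    def forward(self, x):
        # Which layer runs depends on the *values* in x, not just its shape.
        if x.sum() > 0:
            return self.pos_branch(x)
        return self.neg_branch(x)


torch.manual_seed(0)
model = BranchyNet().eval()
pos_x = torch.ones(1, 4)
neg_x = -torch.ones(1, 4)

# Record one graph using a positive example input.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(model, pos_x)

with torch.no_grad():
    eager_neg = model(neg_x)    # Python code picks neg_branch
    traced_neg = traced(neg_x)  # recorded graph still uses pos_branch
```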

Therefore, you must save the actual code.

Alternatively, if there are no branches or loops in the net, you can save the computation graph; see, e.g., this post.
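One way to save the computation graph of a branch-free model is TorchScript tracing with torch.jit.trace and torch.jit.save (the linked post may describe a different method). A minimal sketch, assuming a hypothetical MLP and an in-memory buffer instead of a file. torch.jit.load does not need the MLP class, because the archive holds the recorded graph plus the weights:

```python
import io

import torch
import torch.nn as nn


class MLP(nn.Module):  # hypothetical branch-free model
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

    def forward(self, x):
        return self.net(x)


model = MLP().eval()
example = torch.randn(1, 4)

# Record the ops executed for `example`, together with the weights.
traced = torch.jit.trace(model, example)
buffer = io.BytesIO()
torch.jit.save(traced, buffer)

# Loading restores graph + weights without referring to the MLP class.
buffer.seek(0)
loaded = torch.jit.load(buffer)

x = torch.randn(5, 4)
with torch.no_grad():
    matches = torch.allclose(model(x), loaded(x))
```

Since tracing records only the ops that actually ran, it is appropriate only in the no-branch, no-loop case described above; torch.jit.script is the TorchScript alternative that compiles control flow.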

You should also consider exporting your model to ONNX, which gives a representation that captures both the trained weights and the computation graph.

Answered by Shai on November 22, 2021

You can refer to this article to understand how to save the classifier. To make tweaks to a model, you can create a new model class that is a child (subclass) of the existing model class.


class newModel(oldModelClass):
    def __init__(self):
        # Build every layer that oldModelClass defines
        super(newModel, self).__init__()

With this setup, newModel has all the layers as well as the forward function of oldModelClass. If you need to make tweaks, you can define new layers in the __init__ function (after the super() call) and then override the forward function to use them.
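A runnable version of this idea, assuming a hypothetical OldModel standing in for oldModelClass (its layers and sizes are made up for illustration). Because the child keeps the parent's layer names, a state_dict from the parent loads directly into the child as long as the new layers add no parameters (Dropout has none); new layers with parameters would need load_state_dict(..., strict=False):

```python
import torch
import torch.nn as nn


class OldModel(nn.Module):  # hypothetical stand-in for oldModelClass
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(8, 16)
        self.head = nn.Linear(16, 3)

    def forward(self, x):
        return self.head(torch.relu(self.features(x)))


class NewModel(OldModel):
    def __init__(self):
        super().__init__()                # inherits self.features and self.head
        self.dropout = nn.Dropout(p=0.5)  # the tweak: one extra layer

    def forward(self, x):
        # New forward that uses the inherited layers plus the new one.
        h = torch.relu(self.features(x))
        return self.head(self.dropout(h))


old_model = OldModel().eval()
new_model = NewModel()
new_model.load_state_dict(old_model.state_dict())  # same keys, so strict loading works
new_model.eval()  # Dropout is a no-op in eval mode

x = torch.randn(4, 8)
with torch.no_grad():
    same_in_eval = torch.allclose(old_model(x), new_model(x))
```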

Answered by Roshan Santhosh on November 22, 2021
