
he_normal (Keras) is truncated while kaiming_normal_ (PyTorch) is not

Cross Validated. Asked by londumas on November 3, 2020

Thanks for having a look at my post.

I took an extensive look at the differences in weight initialization between PyTorch
and Keras, and it appears that he_normal (Keras) and kaiming_normal_ (PyTorch)
are defined differently on the two platforms.

Both claim to apply the method presented in He et al. 2015 (https://arxiv.org/abs/1502.01852):
https://pytorch.org/docs/stable/nn.init.html,
https://www.tensorflow.org/api_docs/python/tf/keras/initializers/HeNormal.
However, I found no trace of truncation in that paper,
even though truncation makes a lot of sense to me.

Is there a bug in the simple code that follows, or do these two platforms indeed claim
to apply the same method from a paper but differ in their implementation?
If so, which one is correct? Which is best?


import numpy as np
import matplotlib.pyplot as plt

import torch

import keras
import keras.models as Model
# keras.layers.core no longer exists in recent Keras; import the layers directly
from keras.layers import Input, Dense

real = 100  # number of independent 16x16 weight matrices drawn per framework

### PyTorch
params = np.array([])
for _ in range(real):
    lin = torch.nn.Linear(in_features=16, out_features=16)
    # Defaults: mode='fan_in', nonlinearity='leaky_relu' with a=0, i.e. gain sqrt(2)
    torch.nn.init.kaiming_normal_(lin.weight)
    params = np.append(params, lin.weight.detach().numpy())
params = params.flatten()
plt.hist(params, bins=50, alpha=0.4, label=r'PyTorch')

### Keras
params = np.array([])
for _ in range(real):
    X_input = Input([16])
    X = Dense(units=16, activation='relu', kernel_initializer='he_normal')(X_input)
    model = Model.Model(inputs=X_input, outputs=X)
    # get_weights()[0] is the kernel; [1] is the (zero-initialized) bias
    params = np.append(params, model.get_weights()[0])
params = params.flatten()
plt.hist(params, bins=50, alpha=0.4, label=r'Keras')

###
plt.xlabel(r'Weights')
plt.ylabel(r'#')
plt.yscale('log')
plt.legend()
plt.grid()
plt.show()

One Answer

I think you're correct that the two initializers are different; this difference is consistent with the description in the documentation.

For Keras, the documentation says

It draws samples from a truncated normal distribution centered on 0 with stddev = sqrt(2 / fan_in) where fan_in is the number of input units in the weight tensor.
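A minimal pure-Python sketch of what this description amounts to, assuming the behaviour of recent TensorFlow/Keras versions of `VarianceScaling` (which `he_normal` uses with scale 2 and `fan_in` mode): the pre-truncation standard deviation is divided by about 0.8796, so that the samples left after truncation have standard deviation sqrt(2 / fan_in), and any draw farther than two pre-truncation standard deviations from 0 is re-drawn. The helper name is hypothetical, and plain rejection sampling stands in for TensorFlow's own sampler.

```python
import math
import random

# Std of a standard normal truncated to [-2, 2]; recent TensorFlow/Keras
# versions divide by this constant so the truncated samples keep the target std.
TRUNC_STD_2SIGMA = 0.87962566103423978


def he_normal_truncated_sample(fan_in, n, rng=random):
    """Hypothetical helper: n draws mimicking Keras's truncated he_normal."""
    target_std = math.sqrt(2.0 / fan_in)
    sigma = target_std / TRUNC_STD_2SIGMA  # pre-truncation std
    samples = []
    while len(samples) < n:
        x = rng.gauss(0.0, sigma)
        if abs(x) <= 2.0 * sigma:  # keep draws within 2 sigma, redraw the rest
            samples.append(x)
    return samples


if __name__ == "__main__":
    rng = random.Random(0)
    w = he_normal_truncated_sample(fan_in=16, n=100_000, rng=rng)
    print(max(abs(x) for x in w))  # never exceeds 2 * sigma (about 0.804 here)
```

The hard bound at 2 sigma is what produces the sharp cut-off of the Keras histogram on the log-scale plot.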

By contrast, the PyTorch documentation says

Fills the input Tensor with values according to the method described in Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification - He, K. et al. (2015), using a normal distribution. The resulting tensor will have values sampled from $\mathcal{N}(0, \text{std}^2)$ where

$$\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}$$
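The formula can be checked with a few lines of plain Python. This is a sketch, not PyTorch's code: `kaiming_std` and `kaiming_normal_sample` are hypothetical names, and the default gain of sqrt(2) reflects what `kaiming_normal_` uses with its defaults (`nonlinearity='leaky_relu'` with `a=0`, which reduces to the ReLU gain) and `mode='fan_in'`.

```python
import math
import random


def kaiming_std(fan_in, gain=math.sqrt(2.0)):
    """std = gain / sqrt(fan_mode), with fan_mode = fan_in."""
    return gain / math.sqrt(fan_in)


def kaiming_normal_sample(fan_in, n, rng=random):
    """Hypothetical helper: n untruncated draws from N(0, std^2)."""
    std = kaiming_std(fan_in)
    return [rng.gauss(0.0, std) for _ in range(n)]


if __name__ == "__main__":
    print(kaiming_std(16))  # sqrt(2) / 4, about 0.354, the same as sqrt(2 / 16)
```

Nothing bounds these draws, so occasional large weights appear in the tails.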

Choosing $\text{gain}=\sqrt{2}$ and $\text{fan\_mode}=\text{fan\_in}$ makes the standard deviations the same, but the Keras function uses a truncated distribution while the PyTorch function does not, so the resulting distributions differ. Again, this is consistent with your findings.
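Why the standard deviations can still match despite the truncation: cutting a normal at plus or minus two standard deviations shrinks its spread, and the shrink factor follows from the textbook variance of a symmetrically truncated standard normal, Var = 1 - 2c·φ(c) / (Φ(c) - Φ(-c)). A short sketch using only `math.erf` (the function name is ours) reproduces the constant of about 0.8796 that, as assumed above, recent TensorFlow/Keras versions divide by to compensate.

```python
import math


def truncated_std(c):
    """Std of a standard normal truncated symmetrically to [-c, c]."""
    pdf_c = math.exp(-0.5 * c * c) / math.sqrt(2.0 * math.pi)  # phi(c)
    mass = math.erf(c / math.sqrt(2.0))  # Phi(c) - Phi(-c)
    return math.sqrt(1.0 - 2.0 * c * pdf_c / mass)


if __name__ == "__main__":
    print(round(truncated_std(2.0), 4))  # 0.8796
```

Without that compensation, truncating N(0, 2/fan_in) at 2 sigma would lower its standard deviation by roughly 12%.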

So the PyTorch function doesn't truncate, while the Keras function does.
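The practical difference shows up only in the tails, which is what the log-scale histogram in the question exposes. A stdlib-only sketch under the same assumptions as above (fan_in = 16, target std sqrt(2/16), a Keras-style cut at two pre-truncation standard deviations, hypothetical helper names) counts how many untruncated draws land beyond that cut, something a truncated initializer never produces.

```python
import math
import random

FAN_IN = 16
TARGET_STD = math.sqrt(2.0 / FAN_IN)  # same target std in both frameworks
SIGMA_PRE = TARGET_STD / 0.87962566103423978  # assumed Keras pre-truncation std
CUT = 2.0 * SIGMA_PRE  # truncation bound, about 0.804 for fan_in = 16


def fraction_beyond_cut(n, rng):
    """Share of untruncated N(0, TARGET_STD^2) draws with |w| > CUT."""
    hits = sum(1 for _ in range(n) if abs(rng.gauss(0.0, TARGET_STD)) > CUT)
    return hits / n


def expected_fraction_beyond_cut():
    """Closed form: P(|Z| > CUT / TARGET_STD) for a standard normal Z."""
    z = CUT / TARGET_STD  # 2 / 0.8796, about 2.27
    return 1.0 - math.erf(z / math.sqrt(2.0))


if __name__ == "__main__":
    rng = random.Random(0)
    print(fraction_beyond_cut(200_000, rng))  # close to the closed form below
    print(expected_fraction_beyond_cut())     # roughly 0.023
```

In other words, about 2% of the untruncated PyTorch-style weights fall outside the range the Keras-style initializer allows; the bulk of both distributions is very similar.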

Looking at the paper, you're again correct: He et al. 2015 does not describe any truncation in the text. Since the cited article doesn't seem to support the truncation used by the Keras function, it could be reasonable to open an issue on the library's GitHub or another official Keras channel. It's also possible that the Keras authors meant to cite a different article, or something like that.

As for your last question,

What is best?

Best for what? Does one of the initializations suit your task well? If so, use that one. If not, use a different one. The He paper describes a network design, finds that this initialization works well, and provides some commentary and theoretical justification. But the network that you want to build may not match the models He was examining, or it may not conform to some of the assumptions that He made in the theoretical analysis. In particular, the He paper is focused on ReLU and PReLU networks; if you're using a different activation function, your network may require an alternative initialization scheme.
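As an illustration of how the activation enters the derivation: for PReLU/leaky-ReLU units with negative slope a, He et al. arrive at the condition 0.5·(1 + a²)·n·Var[w] = 1, which reduces to the familiar sqrt(2 / fan_in) for plain ReLU (a = 0). PyTorch exposes the same idea through the `a` argument of `kaiming_normal_`. A small sketch of that formula, with a hypothetical function name:

```python
import math


def he_std(fan_in, negative_slope=0.0):
    """Std satisfying 0.5 * (1 + a^2) * fan_in * Var[w] = 1 (He et al. 2015)."""
    return math.sqrt(2.0 / ((1.0 + negative_slope ** 2) * fan_in))


if __name__ == "__main__":
    print(he_std(16))        # ReLU: sqrt(2 / 16)
    print(he_std(16, 0.25))  # PReLU with a = 0.25, the paper's initial slope
    print(he_std(16, 1.0))   # a = 1 is a linear unit: sqrt(1 / 16) = 0.25
```

The larger the negative slope, the less the activation discards, and the smaller the initial weights need to be; activations outside this family need their own analysis.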

Answered by Sycorax on November 3, 2020
