Cross Validated Asked by londumas on November 3, 2020

Thanks for having a look at my post.

I had an extensive look at the difference in weight initialization between PyTorch and Keras, and it appears that the definitions of `he_normal` (Keras) and `kaiming_normal_` (PyTorch) differ between the two platforms. Both claim to implement the method presented in He et al. 2015 (https://arxiv.org/abs/1502.01852):

https://pytorch.org/docs/stable/nn.init.html,

https://www.tensorflow.org/api_docs/python/tf/keras/initializers/HeNormal.

However, I found no trace of truncation in that paper, even though truncation makes a lot of sense to me.

Do I have a bug in the simple code that follows, or do these two platforms indeed claim to implement the same paper while differing in their implementations? If so, which one is correct, and which is best?

```
import numpy as np
import matplotlib.pyplot as plt
import torch
from keras.layers import Input, Dense
from keras.models import Model

real = 100  # number of realizations

### PyTorch
params = np.array([])
for _ in range(real):
    lin = torch.nn.Linear(in_features=16, out_features=16)
    torch.nn.init.kaiming_normal_(lin.weight)
    params = np.append(params, lin.weight.detach().numpy())
params = params.flatten()
plt.hist(params, bins=50, alpha=0.4, label=r'PyTorch')

### Keras
params = np.array([])
for _ in range(real):
    X_input = Input([16])
    X = Dense(units=16, activation='relu', kernel_initializer='he_normal')(X_input)
    model = Model(inputs=X_input, outputs=X)
    params = np.append(params, model.get_weights()[0])
params = params.flatten()
plt.hist(params, bins=50, alpha=0.4, label=r'Keras')

###
plt.xlabel(r'Weights')
plt.ylabel(r'#')
plt.yscale('log')
plt.legend()
plt.grid()
plt.show()
```

I think you're correct that the two initializers are different; this difference is consistent with the description in the documentation.

For Keras, the documentation says:

> It draws samples from a truncated normal distribution centered on 0 with `stddev = sqrt(2 / fan_in)`, where `fan_in` is the number of input units in the weight tensor.

By contrast, for Torch, the documentation says:

> Fills the input Tensor with values according to the method described in *Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification* - He, K. et al. (2015), using a normal distribution. The resulting tensor will have values sampled from $\mathcal{N}(0, \text{std}^2)$ where
> $$\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}$$

Choosing $\text{gain}=\sqrt{2}$ and $\text{fan\_mode}=\text{fan\_in}$ makes the standard deviations the same, but the Keras function is using a truncated distribution while the Torch function is not, so the resulting distributions will be different. Again, this is consistent with your findings.

So the Torch function isn't truncating, while the Keras function is.
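The size of the effect is easy to check numerically. Truncating a normal at two standard deviations (the truncation the Keras docs describe) shrinks the standard deviation by a factor of about 0.88, which is why the Keras histogram looks more concentrated than the Torch one. Here is a minimal NumPy sketch, using rejection sampling to stand in for the truncated draw:

```python
import numpy as np

rng = np.random.default_rng(0)
fan_in = 16
std = np.sqrt(2.0 / fan_in)  # He et al. std: sqrt(2 / fan_in), about 0.354

# Plain normal draw, as in torch.nn.init.kaiming_normal_
normal = rng.normal(0.0, std, size=1_000_000)

# Truncated normal: reject anything beyond 2 standard deviations,
# mimicking the distribution the Keras he_normal documentation describes
trunc = rng.normal(0.0, std, size=2_000_000)
trunc = trunc[np.abs(trunc) <= 2.0 * std][:1_000_000]

print(normal.std())  # about 0.354
print(trunc.std())   # about 0.88 * 0.354 = 0.31: truncation shrinks the spread
```

The 0.88 factor is just the standard result for a normal truncated symmetrically at $\pm 2\sigma$; whether a particular Keras/TensorFlow version rescales the stddev to compensate for it is a separate question worth checking in the source.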

When we look at the paper, again, you're correct: the He 2015 paper does not describe a truncation in the text. Since the cited article doesn't seem to support the truncation in the Keras function, it could be reasonable to open an issue on the library's GitHub or another official Keras channel. It's also possible that the Keras authors meant to cite a different article, or something like that.

As for your last question,

> What is best?

Best for what? Does one of the initializations suit your task well? If so, use that one. If not, use a different one. The He paper describes a network design and finds that this initialization works well, and provides some commentary and theoretical justification. But the network that you want to build may not match the models He was examining, or it may not conform to some of the assumptions that He made in the theoretical analysis. In particular, the He paper is focused on ReLU and PReLU networks; if you're using a different activation function, your results may require an alternative initialization scheme.
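As a concrete example of adapting the initialization to the activation: PyTorch lets you pass the nonlinearity to `kaiming_normal_`, and exposes its recommended gains via `torch.nn.init.calculate_gain`, so for a tanh network you would not use the ReLU gain of $\sqrt{2}$. A sketch (the choice of `xavier_normal_` for tanh is one common convention, not the only option):

```python
import torch

lin = torch.nn.Linear(in_features=16, out_features=16)

# ReLU network: the He/Kaiming default, which uses gain = sqrt(2)
torch.nn.init.kaiming_normal_(lin.weight, nonlinearity='relu')

# tanh network: PyTorch recommends a different gain (5/3)
gain = torch.nn.init.calculate_gain('tanh')
torch.nn.init.xavier_normal_(lin.weight, gain=gain)
```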

Answered by Sycorax on November 3, 2020
