# he_normal (Keras) is truncated when kaiming_normal_ (pytorch) is not

Cross Validated Asked by londumas on November 3, 2020

Thanks for having a look at my post.

I looked closely at the difference in weight initialization between PyTorch
and Keras, and it appears that he_normal (Keras)
and kaiming_normal_ (PyTorch) are defined differently on the two platforms.

They both claim to apply the solution presented in He et al. 2015 (https://arxiv.org/abs/1502.01852):
https://pytorch.org/docs/stable/nn.init.html,
https://www.tensorflow.org/api_docs/python/tf/keras/initializers/HeNormal.
However, I found no trace of truncation in that paper.
To me, truncation makes a lot of sense.

Is there a bug in the simple code that follows, or do these two platforms really claim
to apply a solution from the same paper while differing in their implementation?
If so, which one is correct? What is best?

    import numpy as np
    import matplotlib.pyplot as plt

    import torch

    from keras.layers import Dense, Input
    from keras.models import Model

    real = 100  # number of independent realizations per framework

    ### PyTorch
    params = np.array([])
    for _ in range(real):
        lin = torch.nn.Linear(in_features=16, out_features=16)
        torch.nn.init.kaiming_normal_(lin.weight)
        params = np.append(params, lin.weight.detach().numpy())
    params = params.flatten()
    plt.hist(params, bins=50, alpha=0.4, label=r'PyTorch')

    ### Keras
    params = np.array([])
    for _ in range(real):
        X_input = Input([16])
        X = Dense(units=16, activation='relu', kernel_initializer='he_normal')(X_input)
        model = Model(inputs=X_input, outputs=X)
        # get_weights()[0] is the kernel; [1] would be the bias
        params = np.append(params, model.get_weights()[0])
    params = params.flatten()
    plt.hist(params, bins=50, alpha=0.4, label=r'Keras')

    ### Plot both histograms on a log scale to compare the tails
    plt.xlabel(r'Weights')
    plt.ylabel(r'#')
    plt.yscale('log')
    plt.legend()
    plt.grid()
    plt.show()


I think you're correct that the two initializers are different; this difference is consistent with the description in the documentation.

For Keras, the documentation says

It draws samples from a truncated normal distribution centered on 0 with stddev = sqrt(2 / fan_in) where fan_in is the number of input units in the weight tensor.

By contrast, for PyTorch, the documentation says

Fills the input Tensor with values according to the method described in Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification - He, K. et al. (2015), using a normal distribution. The resulting tensor will have values sampled from $$\mathcal{N}(0, \text{std}^2)$$ where

$$\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}$$

Choosing $$\text{gain}=\sqrt{2}$$ and $$\text{fan\_mode}=\text{fan\_in}$$ makes the standard deviations the same, but the Keras function is using a truncated distribution while the Torch function is not, so the resulting distributions will be different. Again, this is consistent with your findings.

So the Torch function isn't truncating, while the Keras function is.
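To make the difference concrete, here is a minimal NumPy sketch that draws weights with the same nominal standard deviation, once from a plain normal and once from a truncated normal. The helper names (`he_std`, `sample_normal`, `sample_truncated_normal`) are hypothetical, and the cutoff of two standard deviations with redrawing is an assumption about how the truncation works; the sketch also does not model any rescaling a library may apply to compensate for truncation.

```python
import numpy as np

def he_std(fan_in, gain=np.sqrt(2.0)):
    """Standard deviation gain / sqrt(fan_in), as in the formula quoted above."""
    return gain / np.sqrt(fan_in)

def sample_normal(rng, shape, std):
    """Plain (untruncated) zero-mean normal draw."""
    return rng.normal(0.0, std, size=shape)

def sample_truncated_normal(rng, shape, std, cutoff=2.0):
    """Zero-mean normal draw; values beyond cutoff * std are redrawn.

    The cutoff of 2 standard deviations is an assumption for illustration.
    """
    out = rng.normal(0.0, std, size=shape)
    mask = np.abs(out) > cutoff * std
    while mask.any():
        # Redraw only the entries that fell outside the allowed range.
        out[mask] = rng.normal(0.0, std, size=int(mask.sum()))
        mask = np.abs(out) > cutoff * std
    return out

rng = np.random.default_rng(0)
std = he_std(fan_in=16)
plain = sample_normal(rng, (100, 16, 16), std)
trunc = sample_truncated_normal(rng, (100, 16, 16), std)

print("nominal std:", std)
print("plain: max |w| =", np.abs(plain).max(), "empirical std =", plain.std())
print("trunc: max |w| =", np.abs(trunc).max(), "empirical std =", trunc.std())
```

The truncated sample has no tails beyond the cutoff and a smaller empirical standard deviation than the plain one, which is the kind of difference a log-scale histogram like the one in the question makes visible.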

When we look at the paper, again, you're correct: the He 2015 paper does not describe a truncation in the text. Since the cited article doesn't seem to support the truncation used by the Keras function, it could be reasonable to create an issue on the library's GitHub or another official Keras channel. It's also possible that the Keras authors meant to cite a different article, or something like that.

What is best?

Best for what? Does one of the initializations suit your task well? If so, use that one. If not, use a different one. The He paper describes a network design and finds that this initialization works well, and provides some commentary and theoretical justification. But the network that you want to build may not match the models He was examining, or it may not conform to some of the assumptions that He made in the theoretical analysis. In particular, the He paper is focused on ReLU and PReLU networks; if you're using a different activation function, your results may require an alternative initialization scheme.
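As a rough illustration of why the choice of scale is tied to the activation function, the following NumPy sketch pushes random inputs through a stack of ReLU layers and reports the mean squared activation at the end. The function name `mean_square_after`, the depth, the width, and the batch size are all hypothetical choices for the demonstration, and the network is a bare stack of square weight matrices without biases, not any model from the paper.

```python
import numpy as np

def mean_square_after(depth, width, gain, rng, batch=512):
    """Mean squared activation after `depth` ReLU layers.

    Weights are drawn from N(0, (gain / sqrt(fan_in))^2) with fan_in = width.
    """
    x = rng.normal(size=(batch, width))  # inputs with unit second moment
    for _ in range(depth):
        w = rng.normal(0.0, gain / np.sqrt(width), size=(width, width))
        x = np.maximum(x @ w, 0.0)  # ReLU
    return float(np.mean(x ** 2))

rng = np.random.default_rng(0)
he_ms = mean_square_after(depth=10, width=256, gain=np.sqrt(2.0), rng=rng)
unit_ms = mean_square_after(depth=10, width=256, gain=1.0, rng=rng)

print("gain = sqrt(2):", he_ms)
print("gain = 1:      ", unit_ms)
```

With ReLU, each layer discards roughly half of the signal's second moment, and a gain of sqrt(2) compensates for that, so the signal keeps roughly the same magnitude with depth; with a gain of 1 it shrinks layer by layer. A different activation function changes that factor, which is one reason the same initialization may not carry over.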

Answered by Sycorax on November 3, 2020
