TransWikia.com

Don't understand the simple ResNet from 'Deep Implicit Layers' Tutorial

Data Science Asked by D_H on August 1, 2021

The Deep Implicit Layers Tutorial is a nice resource that dives into Neural ODEs, Deep Equilibrium Models, etc., using the JAX library. Chapter 3 (of 5), http://implicit-layers-tutorial.org/neural_odes/, contains the following example ResNet, which confuses me:

```python
import jax.numpy as jnp

def mlp(params, inputs):
  # A multi-layer perceptron, i.e. a fully-connected neural network.
  for w, b in params:
    outputs = jnp.dot(inputs, w) + b  # Linear transform
    inputs = jnp.tanh(outputs)        # Nonlinearity
  return outputs


def resnet(params, inputs, depth):
  for i in range(depth):
    outputs = mlp(params, inputs) + inputs
  return outputs
```

What is particularly confusing to me is that in the ‘resnet’ function, the for loop over the depth seems COMPLETELY redundant.

The output of each iteration is never fed into the next one (‘inputs’ is never reassigned inside the loop), and the index ‘i’ is never used.
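A quick way to check the observation above is to run the function as written with two different depths and compare the results. This is a minimal sketch that swaps `jax.numpy` for plain NumPy (only `dot` and `tanh` are used, which `jax.numpy` mirrors, so it should behave the same with `jnp`); the layer sizes, the random seed, and the `params`/`x` values are hypothetical and chosen only for illustration.

```python
import numpy as np

def mlp(params, inputs):
    # Same MLP as in the question, with NumPy standing in for jax.numpy.
    for w, b in params:
        outputs = np.dot(inputs, w) + b  # linear transform
        inputs = np.tanh(outputs)        # nonlinearity
    return outputs

def resnet(params, inputs, depth):
    # Same logic as in the question: `inputs` is never reassigned,
    # so every iteration recomputes exactly the same value.
    for i in range(depth):
        outputs = mlp(params, inputs) + inputs
    return outputs

# Hypothetical random parameters for a 4 -> 8 -> 4 MLP
# (output width must equal input width for the residual sum).
rng = np.random.default_rng(0)
params = [
    (rng.normal(size=(4, 8)), rng.normal(size=8)),
    (rng.normal(size=(8, 4)), rng.normal(size=4)),
]
x = rng.normal(size=(3, 4))  # a batch of 3 inputs

out_1 = resnet(params, x, depth=1)
out_10 = resnet(params, x, depth=10)
print(np.array_equal(out_1, out_10))  # True: depth has no effect
```

Any `depth >= 1` gives the same result as a single residual block. With `depth=0` the loop body never runs, so `outputs` is never bound and the `return` raises `UnboundLocalError`.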

Am I missing some fundamental information about ResNets? Is it a mistake? Can someone explain the point of the for loop?
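For comparison, here is a hypothetical variant in which the loop does matter, assuming the intent was for each residual block to consume the previous block's output (the usual ResNet pattern). The name `resnet_chained` is made up for this sketch, and NumPy again stands in for `jax.numpy`. Note that the same `params` are reused at every step, so all blocks share weights; a standard ResNet would typically give each block its own parameters. The update x_{k+1} = x_k + f(x_k) has the form of one forward Euler step with step size 1 for dx/dt = f(x), which is the usual link between ResNets and Neural ODEs.

```python
import numpy as np

def mlp(params, inputs):
    # Same MLP as in the question, with NumPy standing in for jax.numpy.
    for w, b in params:
        outputs = np.dot(inputs, w) + b  # linear transform
        inputs = np.tanh(outputs)        # nonlinearity
    return outputs

def resnet_chained(params, inputs, depth):
    # Hypothetical variant: `inputs` is reassigned, so block k+1
    # consumes block k's output, x_{k+1} = x_k + f(x_k).
    # The same `params` are reused at every step (weight tying).
    for _ in range(depth):
        inputs = mlp(params, inputs) + inputs
    return inputs

# Hypothetical random parameters for a 4 -> 8 -> 4 MLP.
rng = np.random.default_rng(0)
params = [
    (rng.normal(size=(4, 8)), rng.normal(size=8)),
    (rng.normal(size=(8, 4)), rng.normal(size=4)),
]
x = rng.normal(size=(3, 4))  # a batch of 3 inputs

out = resnet_chained(params, x, depth=10)
print(out.shape)  # (3, 4)
```

In this version `depth` changes the result, and `depth=0` simply returns the input unchanged instead of raising an error.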


© 2024 TransWikia.com. All rights reserved.