
PyTorch - using `param.grad` affects learning process? [Solved]

Stack Overflow Asked on January 26, 2021

I'm training ResNet34 on CIFAR-10. I'm seeing very weird behavior when I try to manipulate param.grad for the parameters in model.parameters().

The following function is where all the mess happens. It currently doesn't do anything useful; it's what remains of my attempts to understand what is going on.

def add_error(error):
    params = (param for param in model.parameters() if param.requires_grad)
    # [param.grad + err for param, err in zip(params, error)]                        # Line 2
    # new_error = [param.grad + err for param, err in zip(params, error)]            # Line 3
    for param in params:
        param.grad.zero_()
    # new_error = [torch.zeros(param.grad.shape, device=device) for param in params] # Line 6
    return new_error

It’s used in the gradient descent step:

def step(model, optimizer, batch, labels, error):
  optimizer.zero_grad()
  loss = compute_loss(model, batch, labels)
  loss.backward()

  new_error = add_error(error=error)  # <- add_error is called here
  optimizer.step()
  return new_error

where optimizer is optim.SGD(model.parameters(), lr=0.1) and compute_loss essentially calls nn.CrossEntropyLoss() on model(batch) and labels.
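For reference, the rest of the setup is essentially this (a minimal sketch; my actual model is a ResNet34 adapted for CIFAR-10, so torchvision's resnet34 here is just a stand-in):

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet34

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for the actual CIFAR-10 ResNet34.
model = resnet34(num_classes=10).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

def compute_loss(model, batch, labels):
  # Cross-entropy between the model's logits and the targets.
  return criterion(model(batch), labels)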

What I expect: since I set the gradients to 0, nothing should change no matter what else I do: the loss should stay around its initial value (2.4) the whole time.

What actually happens:

  • When I uncomment only line 6, everything works as expected: the loss stays close to the constant 2.4.
  • When I uncomment only line 3, learning proceeds as if I hadn't called add_error at all, i.e. the loss decreases at the same rate as plain SGD: 2.4 -> 1.7 -> 1.3 -> ... (per epoch). In other words, the gradients somehow still get applied.
  • When I uncomment lines 2 and 6, the weirdest thing of all happens: the loss jumps to 4.3 and then slowly decreases, 4.3 -> 4.2 -> 4.14 -> 4.1 -> ... (I suspect this decrease is caused by batch normalization).

Note that in no case do I actually use error; I never use it to update the gradients.
Also, adding more lines like Line 2 doesn't change the outcome.

Question: What’s happening?

  • Why does the loss decrease? The gradient should be $0$ the whole time.
  • How is it even possible for the loss value itself to change?

If it helps, I can try to produce an MCVE and post it on Pastebin (it would be a wall of code, too large to fit here).

One Answer

The problem was incredibly stupid: params is a generator and becomes exhausted after the first iteration over it. Creating a list instead of a generator solves the issue.
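Here is a minimal illustration of the effect and the fix (a standalone sketch, not the original training code):

import torch

# A few dummy "parameters" with populated gradients.
weights = [torch.ones(3, requires_grad=True) for _ in range(3)]
for w in weights:
  w.sum().backward()

params = (w for w in weights if w.requires_grad)  # generator

# The first iteration (e.g. the list comprehension on Line 3) consumes it...
_ = [w.grad + 0 for w in params]

# ...so this loop runs over an exhausted generator and zeroes nothing.
for w in params:
  w.grad.zero_()

print(all(w.grad.abs().sum() > 0 for w in weights))  # True: grads untouched

# Fix: build a list, which can be iterated any number of times.
params = [w for w in weights if w.requires_grad]
_ = [w.grad + 0 for w in params]
for w in params:
  w.grad.zero_()

print(all(w.grad.abs().sum() == 0 for w in weights))  # True: grads zeroed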

Answered by Dmitry on January 26, 2021
