Why did my PyTorch linear regression fail?

Data Science Asked on June 27, 2021

I am new to PyTorch and want to start with a simple example: linear regression.

I created some random training and test samples.

Here is my code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
import numpy as np  # used below as np.random and np.random.seed
import sys  # used below for sys.exit

np.random.seed(100)

class MLP(nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.l1 = nn.Linear(100, 100)
    self.l2 = nn.Linear(100, 1)

  def forward(self, x): 
    x = (self.l1(x))
    x = (self.l2(x))
    return x

mlp = MLP().float()

target = nn.MSELoss()
o = opt.SGD(mlp.parameters(), lr=0.02, momentum=0.9)

w = torch.tensor(np.random.rand(100, 1) * 3).float()
x = torch.tensor(np.random.rand(100, 100) * 100).float()
y = torch.mm(x, w) + 2 

test_x = torch.tensor(np.random.rand(100, 100) * 10).float()
test_y = torch.mm(test_x, w) + 2 

for epoch in range(100):
  op = mlp(x)
  if epoch == 10: 
    print(op)
    sys.exit(1)
  o.zero_grad()
  loss = target(op, y)
  loss.backward()
  o.step()
  test_pred = mlp(test_x)
  print(test_pred.shape)
  print(test_y.shape)
  print('%dth: loss=%.4f os_loss=%.4f'%(epoch, loss.item(), target(test_pred, test_y).item()))

I found that the loss becomes NaN after several rounds.

I can't figure out why. I think my neural network setup is correct; can you help with this?

One Answer

The problem seems to come from your learning rate combined with the lack of normalization in your data. Your network is clearly unstable here, so its outputs climb to sky-high values (around 10^20) that end up as NaN. A typical learning rate for SGD is 0.001, but that assumes normalized data (inputs and outputs between 0 and 1). Here your inputs and outputs are large, and MSE amplifies them even further because it squares the error. Since the gradient of the MSE with respect to the weights is proportional to both the error and the inputs, the resulting gradient is far too strong for lr=0.02, and each update overshoots instead of converging.
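The claim that the gradient is far too strong can be checked numerically. The sketch below is a stand-in under stated assumptions, not the poster's model: it uses numpy instead of PyTorch, a single linear layer instead of the two-layer MLP, a hypothetical seed, and standardization (zero mean, unit variance per column) as the form of normalization. It compares the MSE gradient at zero weights for the question's raw data and for the same data standardized.

```python
import numpy as np

rng = np.random.default_rng(100)  # hypothetical seed; any seed shows the same effect

# Same data recipe as the question: 100 samples, 100 features in [0, 100)
w_true = rng.random((100, 1)) * 3
X = rng.random((100, 100)) * 100
y = X @ w_true + 2


def mse_grad_at_zero(X, y):
    """Gradient of mean((X @ w - y) ** 2) with respect to w, evaluated at w = 0."""
    n = X.shape[0]
    residual = -y  # predictions X @ 0 are all zero
    return 2.0 / n * (X.T @ residual)


def standardize(a):
    """Rescale each column to zero mean and unit (population) standard deviation."""
    return (a - a.mean(axis=0)) / a.std(axis=0)


raw_norm = np.linalg.norm(mse_grad_at_zero(X, y))
scaled_norm = np.linalg.norm(mse_grad_at_zero(standardize(X), standardize(y)))

print(f"gradient norm, raw data:          {raw_norm:.3e}")
print(f"gradient norm, standardized data: {scaled_norm:.3e}")
```

On the raw data the gradient norm should come out on the order of millions, while on the standardized data it cannot exceed 20 here, because each of its 100 components is minus twice the correlation between one feature and the target.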

This is how I understand the network's behaviour: with lr=0.02 it is clearly unstable.
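To see how values around 10^20 become NaN rather than just large numbers: `.float()` makes every tensor float32, whose largest finite value is about 3.4e38. A minimal numpy sketch of one typical chain, assuming numpy's float32 behaves like PyTorch's here (both follow IEEE 754 rules):

```python
import numpy as np

print(np.finfo(np.float32).max)  # largest finite float32, about 3.4e38

with np.errstate(over="ignore", invalid="ignore"):  # silence the expected warnings
    output = np.float32(1e20)        # an output of the size reported above
    squared = output * output        # MSE squares the error: 1e40 overflows to inf
    diff = squared - squared         # inf - inf has no defined value, giving nan
    weight = np.float32(0.5) - diff  # once nan enters an update, it spreads to every result

# The same square fits comfortably in float64 (max about 1.8e308)
squared64 = np.float64(1e20) * np.float64(1e20)
```

Once a single weight is NaN, every later prediction and loss that depends on it is NaN too, which matches the loss suddenly turning into NaN after several rounds.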

One way to solve it is to greatly reduce your learning rate (lr=0.00000001, i.e. 1e-8, worked for me). Another is to normalize your data (both inputs and outputs).
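Both remedies can be compared in a small numpy sketch. Assumptions: a single linear layer (weights plus bias) stands in for the two-layer MLP, full-batch gradient descent mirrors the `torch.optim.SGD` momentum update with the question's momentum of 0.9, the seed is hypothetical, and normalization means per-column standardization.

```python
import numpy as np


def standardize(a):
    """Rescale each column to zero mean and unit (population) standard deviation."""
    return (a - a.mean(axis=0)) / a.std(axis=0)


def train_linear(X, y, lr, momentum=0.9, epochs=100):
    """Full-batch gradient descent on the MSE of y ~ X @ w + b.

    The update follows torch.optim.SGD with momentum (no dampening, no Nesterov):
    buf = momentum * buf + grad, then param -= lr * buf.
    Returns the loss recorded at the start of every epoch.
    """
    n, d = X.shape
    w, b = np.zeros((d, 1)), 0.0
    buf_w, buf_b = np.zeros((d, 1)), 0.0
    losses = []
    with np.errstate(over="ignore", invalid="ignore"):  # one run is expected to diverge
        for _ in range(epochs):
            residual = X @ w + b - y
            losses.append(float(np.mean(residual ** 2)))
            grad_w = 2.0 / n * (X.T @ residual)
            grad_b = 2.0 * float(np.mean(residual))
            buf_w = momentum * buf_w + grad_w
            buf_b = momentum * buf_b + grad_b
            w = w - lr * buf_w
            b = b - lr * buf_b
    return losses


rng = np.random.default_rng(100)  # hypothetical seed
w_true = rng.random((100, 1)) * 3
X = rng.random((100, 100)) * 100
y = X @ w_true + 2

runs = {
    "raw, lr=0.02": train_linear(X, y, lr=0.02),  # the question's settings
    "raw, lr=1e-8": train_linear(X, y, lr=1e-8),  # remedy 1: tiny learning rate
    "standardized, lr=0.02": train_linear(standardize(X), standardize(y), lr=0.02),  # remedy 2
}
for name, losses in runs.items():
    print(f"{name:>22}: first loss {losses[0]:.3e}, last loss {losses[-1]:.3e}")
```

With lr=0.02 the raw-data loss should overflow to inf/NaN within the 100 epochs, while both the tiny learning rate and the standardized data keep the loss finite and below where it started. A tiny learning rate keeps the steepest direction stable but also makes progress along the other directions very slow; normalization instead fixes the scale of the problem itself.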

Here is your modified code. You can try changing the learning rate to see how the network's behaviour changes (from stable to unstable).

import torch
import torch.nn as nn
import torch.optim as opt
import numpy as np
import matplotlib.pyplot as plt
import math
import os

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # Intel OpenMP workaround for a duplicate-runtime crash on some setups; likely unnecessary on yours


class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.l1 = nn.Linear(100, 100)
        self.l2 = nn.Linear(100, 1)

    def forward(self, x):
        x = self.l1(x)
        x = self.l2(x)
        return x


mlp = MLP().float()

target = nn.MSELoss()
o = opt.SGD(mlp.parameters(), lr=0.00000001)  # Try 0.0000001 and 0.000001

w = torch.tensor(np.random.rand(100, 1) * 3).float()
x_train = torch.tensor(np.random.rand(100, 100)*100).float()
y_train = torch.mm(x_train, w) + 2

test_x = torch.tensor(np.random.rand(100, 100)*10).float()
test_y = torch.mm(test_x, w) + 2

losslist = []
losstestlist = []

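for epoch in range(100):
    op = mlp(x_train)
    o.zero_grad()
    loss = target(op, y_train)
    losslist.append(math.log10(loss.item()))  # log10 scale makes the divergence visible
    loss.backward()
    o.step()
    with torch.no_grad():  # evaluation only, no gradient graph needed
        test_pred = mlp(test_x)
        test_loss = target(test_pred, test_y)
    losstestlist.append(math.log10(test_loss.item()))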


plt.plot(losslist, 'r', label='train loss')
plt.plot(losstestlist, 'b', label='test loss')
plt.legend()
plt.show()
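The answer describes normalized data as values between 0 and 1, which corresponds to min-max scaling. A minimal numpy sketch, assuming a hypothetical `ZeroOneScaler` helper (not a PyTorch or numpy API): fit the min and max on the training set only, reuse them for the test set, and map predictions back to the original units before computing the test loss. Reusing the training statistics matters here because the question draws `test_x` from a range ten times smaller than the training inputs.

```python
import numpy as np


class ZeroOneScaler:
    """Hypothetical helper: rescales each column to [0, 1] using min/max from fit()."""

    def fit(self, a):
        self.lo = a.min(axis=0)
        self.span = a.max(axis=0) - self.lo
        return self

    def transform(self, a):
        return (a - self.lo) / self.span

    def inverse_transform(self, a):
        return a * self.span + self.lo


rng = np.random.default_rng(0)  # hypothetical seed
w_true = rng.random((100, 1)) * 3
x_train = rng.random((100, 100)) * 100
y_train = x_train @ w_true + 2
test_x = rng.random((100, 100)) * 10

x_scaler = ZeroOneScaler().fit(x_train)  # statistics come from training data only
y_scaler = ZeroOneScaler().fit(y_train)

x_train_s = x_scaler.transform(x_train)  # every column now spans exactly [0, 1]
y_train_s = y_scaler.transform(y_train)
test_x_s = x_scaler.transform(test_x)    # lands in a narrow band near 0, possibly slightly below

# A model trained on (x_train_s, y_train_s) predicts in scaled units, so convert
# its test predictions with y_scaler.inverse_transform(...) before comparing to test_y.
```

The scaled arrays can then be turned into the tensors used above, for example with `torch.tensor(x_train_s).float()`.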

Hope this helps.

Correct answer by Ubikuity on June 27, 2021
