선형 회귀 모델(2)_코드 리뷰
https://github.com/ackrilll/Linear_Regression
GitHub - ackrilll/Linear_Regression
Contribute to ackrilll/Linear_Regression development by creating an account on GitHub.
github.com
지난 포스팅에서 선형 회귀 모델을 학습시켜 파라미터를 갱신하는 모습을 살펴봤다. 이번에는 핵심 코드들을
살펴보자
1) 데이터 생성
X = torch.randn(200,1)*10
y = X+3*torch.randn(200,1)
plt.scatter(X.numpy(),y.numpy())
plt.ylabel('y')
plt.xlabel('X')
plt.grid()
plt.show()
PyTorch 텐서 X와 y 생성하고, matplotlib 라이브러리를 사용하여 산점도(scatter plot)를 그린다.
X 는 평균 0, 표준 편차 1의 표준 정규 분포를 따르는 200x1 크기의 PyTorch 텐서이고
y 는 평균 0, 표준 편차 1의 표준 정규 분포를 따르는 200x1 크기의 PyTorch 텐서에 3을 곱해 X와 더한 텐서이다
그려지는 산점도는 다음과 같다. 선형 그래프의 패턴을 닮았다.
오른쪽은 초기의 파라미터 값을 가지고 그려본 직선이다. 아직 모델이 학습되지 않아 임의의 파라미터로 그려진 직선이라서 데이터 경향과 동떨어진 직선 그래프가 그려졌다.
2) 선형 회귀 모델 정의
class LinearRegressionModel(torch.nn.Module):
def __init__(self):
super(LinearRegressionModel,self).__init__()
self.linear = torch.nn.Linear(1,1)
def forward(self,x):
pred = self.linear(x)
return pred
생성자
torch.nn.Module 을 상속받는 LinearRegressionModel 클래스를 정의한다. torch.nn.Module을 상속받는 이유는 torch.nn.Module 의 생성자를 호출하면{super(LinearRegressionModel,self).__init__()}, 모델의 파라미터를 자동으로 추적할 수 있게 만들어주기 때문이다. 이후에 self.linear 라는 속성을 만들어 선형 레이어를 할당한다.
forward
순전파를 정의하는 함수로 x 텐서를 매개변수로 받아 선형 레이어를 거쳐 pred 를 반환한다.
3) LinearRegressionModel 클래스 인스턴스화
model = LinearRegressionModel()
4) 손실 함수 및 옵티마이저 정의
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.001)
torch.nn 에서는 여러 손실함수를 포함하는데 여기서 Mean Square Error 손실함수를 사용한다.
옵티마이저는 SGD를 사용하고 모델의 파라미터를 매개변수로 전달해 업데이트 할 파라미터를 옵티마이저가 알게 한다.
5) 모델 학습
epochs = 100
losses = []
for epoch in range(epochs):
optimizer.zero_grad()
y_pred = model(X)
loss = criterion(y_pred,y)
losses.append(loss.item())
loss.backward()
optimizer.step()
print(f'epoch: {epoch}, loss: {loss.item()}')
모델에 X텐서를 매개변수로 예측 y_pred 를 만든다. 예측값과 실제값으로 손실을 계산한다
loss.backward() -> 손실 값에 대한 기울기(gradient)를 계산한다.
optimizer.step() -> 계산된 기울기를 사용하여 모델의 파라미터를 업데이트한다.
파라미터가 업데이트 되면서 모델의 손실함수는 낮아지는 것을 기대할 수 있다.
6) 결과
w1,b1 = w[0][0].item(),b[0].item()
x1 = np.array([-30,30])
y1 = w1 * x1+b1
plt.plot(x1,y1,'r')
plt.scatter(X,y)
plt.grid()
plt.show()
업데이트된 파라미터 w1, b1을 가지고 다시 직선을 그려보면 데이터의 패턴과 유사한 그래프를 그리게 되었다.