머신 러닝 이론

backward() 와 requires_grad = True

skawlsgus2 2025. 3. 27. 14:28

이전 포스팅에서 out.backward() 로 텐서의 기울기를 저장하는 과정을 보았다. 하지만 

기울기를 저장할 때 어떤 텐서의 기울기를 저장할지 지정할 수 있을까? 

import torch

# 초기 가중치 설정
w = torch.tensor(1.0, requires_grad=True)

# 입력값 x, 정답 y
x = torch.tensor(3.0)
y = torch.tensor(6.0)  # y = 2 * x

# 예측값 (y_pred)
y_pred = w * x  # 모델이 예측한 값

# 손실(loss) 계산
loss = (y_pred - y) ** 2  # MSE (Mean Squared Error)

# 역전파 실행 (미분 계산)
loss.backward()

# w의 기울기 출력
print(w.grad)

이 전과 같은 코드로 w, x, y를 지정하고 loss.backward() 로 loss를 w로 미분한 값을 w.grad에 저장하였다.

이 때 w를 지정하여 기울기를 저장할 수 있는 것은 w = torch.tensor(1.0, requires_grad=True) 이 코드 때문이다.

 requires_grad=True 로 만들었기 때문에 w에 대한 기울기를 추적하게 된다. 그렇다면 텐서끼리 관계가 있을 때는 어떻게 될까

 

a,b,c 텐서가 있고 출력 out 이 있다고 가정해 보자 

a = [[1,1],[1,1]] 로 주고, b = a + 2, c = b**2 의 관계로 되어 있다. 출력 out 은 c 텐서 원소의 총합으로 정의한다.

이 때 out의 a에 대한 기울기를 알아보자 

a = torch.ones(2,2,requires_grad = True)
b = a + 2
c = b**2
out = c.sum()
out.backward()

a 에 대한 기울기를 추적하기 위해서 a 텐서는 requires_grad = True 로 설정하였다. 그리고 out.backward()를 하여 a.grad에 기울기를 저장한다. 이 때 각 텐서에 대한 기울기를 살펴 보자. 

1) d(out)/dc : out은 c 텐서 원소의 합이다 즉, out = c1+c2+c3+c4 라는 것이다. 이 때 c1,c2,c3,c4 에 대해 각각 미분하면

1이 나오므로 d(out)/dc = [[1,1],[1,1]] 이 된다. 그리고 c의 값은 (1+2)**2 = 9가 된다. 

c 에 대한 기울기 추적 X 일 때

c.grad가 None 인 이유는 requires_grad = True로 설정하지 않아 기울기를 추적하지 않기 때문이다. 만약 기울기를 추적한다면 [[1,1],[1,1]] 이 될 것이다

c.grad가 [[1,1],[1,1]] 이 되었다.

2) d(out)/db : d(out)/db은 체인룰을 적용하면 d(out)/dc * dc/db 가 된다. d(out)/dc 은 [[1,1],[1,1]] 이었으므로 dc/db만 구하면 d(out)/db 값을 구할 수 있다. c = b**2 관계에 있으므로 dc/db = 2b 가 된다. 이 때 b = a+2 = 3 이므로 [[3,3],[3,3]] 이 된다

또한 dc/db 는 [[6,6],[6,6]] 이 될 것이다. 따라서 d(out)/db 은 [[6,6],[6,6]] 이 된다.

b 에 대한 기울기 추적 X 일 때

c 의 경우와 같이 requires_grad = True로 설정하지 않아 기울기가 None 으로 표시된다. 만약 기울기를 추적한다면, 다음과 같은 결과가 나온다. 

d(out)/db 는 [[6,6],[6,6]]

3) d(out)/da : 체인룰을 적용하면 d(out)/da = d(out)/db * db/da 이다. 이 때 d(out)/db 는 [[6,6],[6,6]] 이므로 db/da 를 구하면 d(out)/da  를 알 수 있다.  a 텐서는 [[1,1],[1,1]] 이고, b = a + 2 이므로 db/da = 1 이다. 즉 [[1,1],[1,1]] 이 된다. 

따라서 d(out)/da = [[6,6],[6,6]] * [[1,1],[1,1]] = [[6,6],[6,6]] 이다.

d(out)/da 는 [[6,6],[6,6]]

'머신 러닝 이론' 카테고리의 다른 글

Cross Entropy 손실함수  (0) 2025.04.05
컨볼루션 레이어  (0) 2025.03.28
Mnist 데이터 셋 불러오기  (0) 2025.03.27
loss.backward()  (0) 2025.03.26
torch.empty() 와 torch.rand() 의 차이  (0) 2025.03.25