이전 포스팅에서 out.backward() 로 텐서의 기울기를 저장하는 과정을 보았다. 하지만
기울기를 저장할 때 어떤 텐서의 기울기를 저장할지 지정할 수 있을까?
import torch
# 초기 가중치 설정
w = torch.tensor(1.0, requires_grad=True)
# 입력값 x, 정답 y
x = torch.tensor(3.0)
y = torch.tensor(6.0) # y = 2 * x
# 예측값 (y_pred)
y_pred = w * x # 모델이 예측한 값
# 손실(loss) 계산
loss = (y_pred - y) ** 2 # MSE (Mean Squared Error)
# 역전파 실행 (미분 계산)
loss.backward()
# w의 기울기 출력
print(w.grad)
이 전과 같은 코드로 w, x, y를 지정하고 loss.backward() 로 loss를 w로 미분한 값을 w.grad에 저장하였다.
이 때 w를 지정하여 기울기를 저장할 수 있는 것은 w = torch.tensor(1.0, requires_grad=True) 이 코드 때문이다.
requires_grad=True 로 만들었기 때문에 w에 대한 기울기를 추적하게 된다. 그렇다면 텐서끼리 관계가 있을 때는 어떻게 될까
a,b,c 텐서가 있고 출력 out 이 있다고 가정해 보자
a = [[1,1],[1,1]] 로 주고, b = a + 2, c = b**2 의 관계로 되어 있다. 출력 out 은 c 텐서 원소의 총합으로 정의한다.
이 때 out의 a에 대한 기울기를 알아보자
a = torch.ones(2,2,requires_grad = True)
b = a + 2
c = b**2
out = c.sum()
out.backward()
a 에 대한 기울기를 추적하기 위해서 a 텐서는 requires_grad = True 로 설정하였다. 그리고 out.backward()를 하여 a.grad에 기울기를 저장한다. 이 때 각 텐서에 대한 기울기를 살펴 보자.
1) d(out)/dc : out은 c 텐서 원소의 합이다 즉, out = c1+c2+c3+c4 라는 것이다. 이 때 c1,c2,c3,c4 에 대해 각각 미분하면
1이 나오므로 d(out)/dc = [[1,1],[1,1]] 이 된다. 그리고 c의 값은 (1+2)**2 = 9가 된다.
c.grad가 None 인 이유는 requires_grad = True로 설정하지 않아 기울기를 추적하지 않기 때문이다. 만약 기울기를 추적한다면 [[1,1],[1,1]] 이 될 것이다
2) d(out)/db : d(out)/db은 체인룰을 적용하면 d(out)/dc * dc/db 가 된다. d(out)/dc 은 [[1,1],[1,1]] 이었으므로 dc/db만 구하면 d(out)/db 값을 구할 수 있다. c = b**2 관계에 있으므로 dc/db = 2b 가 된다. 이 때 b = a+2 = 3 이므로 [[3,3],[3,3]] 이 된다
또한 dc/db 는 [[6,6],[6,6]] 이 될 것이다. 따라서 d(out)/db 은 [[6,6],[6,6]] 이 된다.
c 의 경우와 같이 requires_grad = True로 설정하지 않아 기울기가 None 으로 표시된다. 만약 기울기를 추적한다면, 다음과 같은 결과가 나온다.
3) d(out)/da : 체인룰을 적용하면 d(out)/da = d(out)/db * db/da 이다. 이 때 d(out)/db 는 [[6,6],[6,6]] 이므로 db/da 를 구하면 d(out)/da 를 알 수 있다. a 텐서는 [[1,1],[1,1]] 이고, b = a + 2 이므로 db/da = 1 이다. 즉 [[1,1],[1,1]] 이 된다.
따라서 d(out)/da = [[6,6],[6,6]] * [[1,1],[1,1]] = [[6,6],[6,6]] 이다.
'머신 러닝 이론' 카테고리의 다른 글
Cross Entropy 손실함수 (0) | 2025.04.05 |
---|---|
컨볼루션 레이어 (0) | 2025.03.28 |
Mnist 데이터 셋 불러오기 (0) | 2025.03.27 |
loss.backward() (0) | 2025.03.26 |
torch.empty() 와 torch.rand() 의 차이 (0) | 2025.03.25 |