parameter
torch.nn.parameter 클래스를 통해 자동미분이 가능한 torch.Tensor 객체들을 만들 수 있다.
이 클래스는 torch.Tensor 클래스를 상속받은 자식 클래스이며 torch.nn.Module 클래스의 attribute로 할당하는 경우
model.parameters()에 자동으로 추가된다.
예를 들어
linear transformation \(y = xw + b\)의 식에서 x는 torch.Tensor로 구현할 수 있지만 가중값 w와 편향 b는 만들 수 없다.
이 경우 nn.Module 내의 미리 만들어진 tensor를 보관할 수 있는 nn.Parameter 모듈을 사용한다.
import torch
from torch import nn
from torch.nn.parameter import Parameter
class Linear(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
# W,b parameter를 1로 초기화 해 생성
self.W = Parameter(torch.ones((out_features, in_features)))
self.b = Parameter(torch.ones((out_features)))
def forward(self, x):
output = torch.addmm(self.b, x, self.W.T)
return output
x = torch.Tensor([[1, 2],
[3, 4]])
linear = Linear(2, 3)
output = linear(x)
output
>>> torch.Tensor([[4, 4, 4],
[8, 8, 8]])
Q. 그렇다면, 왜 W, b는 tensor로 만들지 않고 굳이 Parameter 클래스를 사용할까?
A. gradient를 자동 계산하는 기능을 제공하는 PyTorch의 기능을 사용하기 위해 Parameter 클래스를 사용한다.
Tensor 클래스를 통해 만든 값은 그 정보가 저장되지 않기 때문에 기존의 정보를 사용한 gradient를 계산할 수 없기 때문에 자동 미분 기능을 사용할 수 없다.
Buffer
Parameter가 아니더라도 그 값을 저장하고 싶은 Tensor는 buffer에 등록해서 저장할 수 있다.
요약하면
Tensor
gradinet 계산 X
값 업데이트 X
모델 저장시 값 저장 X
Parameter
gradinet 계산 O
값 업데이트 O
모델 저장시 값 저장 O
Buffer
gradinet 계산 X
값 업데이트 X
모델 저장시 값 저장 O
라고 할 수 있다.
import torch
from torch import nn
from torch.nn.parameter import Parameter
class Model(nn.Module):
def __init__(self):
super().__init__()
self.parameter = Parameter(torch.Tensor([7]))
self.tensor = torch.Tensor([7])
self.register_buffer('buffer', self.tensor) # torch.Tensor([7])를 'buffer' 이름으로 등록
model = Model()
try:
buffer = model.get_buffer('buffer')
if buffer == 7:
print(model.state_dict())
except:
print("해당 buffer가 존재하지 않음")
>>> OrderedDict([('parameter', tensor([7.])), ('buffer', tensor([7.]))])
이 Buffer 기능은 대표적인 예시로 PyTorch 모듈의 BatchNorm에서 사용된다.
'네이버 부스트캠프 학습 정리 > 2주차' 카테고리의 다른 글
[PyTorch] PyTorch 알쓸신잡 (0) | 2023.03.19 |
---|---|
[PyTorch] PyTorch의 데이터 (0) | 2023.03.19 |
[PyTorch] nn.Module (0) | 2023.03.17 |
[PyTorch] 파이토치의 기본 (0) | 2023.03.17 |