네이버 부스트캠프 학습 정리/2주차

[PyTorch] parameter & buffer

AI쟁이J 2023. 3. 19. 15:02

parameter

torch.nn.parameter 클래스를 통해 자동미분이 가능한 torch.Tensor 객체들을 만들 수 있다.

이 클래스는 torch.Tensor 클래스를 상속받은 자식 클래스이며 torch.nn.Module 클래스의 attribute로 할당하는 경우

model.parameters()에 자동으로 추가된다.

예를 들어

linear transformation \(y = xw + b\)의 식에서 x는 torch.Tensor로 구현할 수 있지만 가중값 w와 편향 b는 만들 수 없다.

이 경우 nn.Module 내의 미리 만들어진 tensor를 보관할 수 있는 nn.Parameter 모듈을 사용한다.

import torch
from torch import nn
from torch.nn.parameter import Parameter

class Linear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()

        # W,b parameter를 1로 초기화 해 생성
        self.W = Parameter(torch.ones((out_features, in_features)))
        self.b = Parameter(torch.ones((out_features)))

    def forward(self, x):
        output = torch.addmm(self.b, x, self.W.T)

        return output
        
        
x = torch.Tensor([[1, 2],
                  [3, 4]])

linear = Linear(2, 3)
output = linear(x)
output
>>> torch.Tensor([[4, 4, 4],
                  [8, 8, 8]])

Q. 그렇다면, 왜 W, b는 tensor로 만들지 않고 굳이 Parameter 클래스를 사용할까?

 

A. gradient를 자동 계산하는 기능을 제공하는 PyTorch의 기능을 사용하기 위해 Parameter 클래스를 사용한다.

Tensor 클래스를 통해 만든 값은 그 정보가 저장되지 않기 때문에 기존의 정보를 사용한 gradient를 계산할 수 없기 때문에 자동 미분 기능을 사용할 수 없다.

 

Buffer

Parameter가 아니더라도 그 값을 저장하고 싶은 Tensor는 buffer에 등록해서 저장할 수 있다.

요약하면

Tensor

gradinet 계산 X

값 업데이트 X

모델 저장시 값 저장 X

 

Parameter

gradinet 계산 O

값 업데이트 O

모델 저장시 값 저장 O

 

Buffer

gradinet 계산 X

값 업데이트 X

모델 저장시 값 저장 O

 

라고 할 수 있다.

import torch
from torch import nn
from torch.nn.parameter import Parameter

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.parameter = Parameter(torch.Tensor([7]))
        self.tensor = torch.Tensor([7])
        self.register_buffer('buffer', self.tensor) # torch.Tensor([7])를 'buffer' 이름으로 등록
        
model = Model()

try:
    buffer = model.get_buffer('buffer')
    if buffer == 7:
        print(model.state_dict())
except:
    print("해당 buffer가 존재하지 않음")
        
>>> OrderedDict([('parameter', tensor([7.])), ('buffer', tensor([7.]))])

이 Buffer 기능은 대표적인 예시로 PyTorch 모듈의 BatchNorm에서 사용된다. 

BatchNorm의 공식 소스코드 中

 

'네이버 부스트캠프 학습 정리 > 2주차' 카테고리의 다른 글

[PyTorch] PyTorch 알쓸신잡  (0) 2023.03.19
[PyTorch] PyTorch의 데이터  (0) 2023.03.19
[PyTorch] nn.Module  (0) 2023.03.17
[PyTorch] 파이토치의 기본  (0) 2023.03.17