네이버 부스트캠프 학습 정리/2주차

[PyTorch] nn.Module

AI쟁이J 2023. 3. 17. 18:06

torch.nn.Module

딥러닝을 구성하는 Layer의 base class로 input, output, forward, backward를 정의하고 학습의 대상이 되는 parameter도 정의하는 모듈

ex) Linear 함수를 구현

import torch
from torch import nn
from torch import Tensor

class MyLiner(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        self.weights = nn.Parameter(
                torch.randn(in_features, out_features))
        
        self.bias = nn.Parameter(torch.randn(out_features))

    def forward(self, x : Tensor):
        return x @ self.weights + self.bias
        
x = torch.randn(5, 7)

x
>>> tensor([[-0.7218,  0.2091, -1.1726, -0.8195,  0.7817,  0.5067, -0.1196],
        [-0.4944, -0.7200,  0.4529,  1.3559,  2.0116, -0.6601, -0.8189],
        [ 1.1851, -1.2569, -1.0573,  0.2247,  0.1006,  0.8877, -0.7179],
        [ 2.0409, -0.8997, -0.0208,  0.4647, -0.6975, -0.4925, -0.0798],
        [ 1.7346,  1.3235, -0.3311, -0.9184, -0.0104, -1.0035,  1.6446]])

nn.Parameter를 통한 파라미터 설정과 forward 함수를 지정해 x @ W + b의 선형화 함수를 구현했다.

layer = MyLiner(7, 1)
print(layer(x).shape)
layer(x)

>>> torch.Size([5, 1])
    tensor([[-0.4557],
            [ 2.8882],
            [ 4.6069],
            [ 3.5640],
            [-0.1540]], grad_fn=<AddBackward0>)

Liner함수 MyLiner(7,1)을 통해 5행 7열의 데이터인 x를 넣자

5x7 행렬과 7x1 행렬의 가중치의 곱으로 결과의 사이즈는 5x1이 되었다.

 

super().__init__?

파이토치에서 구현하는 클래스에서는 다른 일반적인 클래스에 들어가는 초기화 함수 __init__과 다르게 앞에 super()가 붙은 특이한 형태를 쓴다.

super()를 쓰는 이유에 대한 글

https://stackoverflow.com/questions/63058355/why-is-the-super-constructor-necessary-in-pytorch-custom-modules

 

Why is the super constructor necessary in PyTorch custom modules?

Why does super(LR, self).__init__() need to be called in the code below? I get the error "AttributeError: cannot assign module before Module.init() call" otherwise. That error is caused b...

stackoverflow.com

 

Torch.nn.Sequential

여러 모듈들을 하나로 묶어 순차적으로 실행시킬 때 torch.nn.Sequential을 사용함.

import torch
from torch import nn

class Add(nn.Module):
    def __init__(self, value):
        super().__init__()
        self.value = value

    def forward(self, x):
        return x + self.value

calculator = nn.Sequential(Add(3), Add(2), Add(5))

x = torch.tensor([1])

output = calculator(x)

output
>>> 11

두 값을 더하는 Add 모듈을 만든 후 nn.Sequential을 통해 대입한 x의 값에

Add(3), Add(2), Add(5) 모듈을 설정하면 순차적으로 x의 값은

1 > 4 > 6 > 11 이 되어 11이 최종 값이 된다.

 

nn.ModuleList

import torch
from torch import nn

class Add(nn.Module):
    def __init__(self, value):
        super().__init__()
        self.value = value

    def forward(self, x):
        return x + self.value


class Calculator(nn.Module):
    def __init__(self):
        super().__init__()
        self.add_list = nn.ModuleList([Add(2), Add(3), Add(5)])

    def forward(self, x):
        x = self.add_list[1](x)
        x = self.add_list[0](x)
        x = self.add_list[2](x)
        
        return x

x = torch.tensor([1])

calculator = Calculator()
output = calculator(x)
output

>>> 11

모듈을 쌓아 놓은 List를 생성 후 적용하고 싶은 모듈을 인덱스를 통해 호출할 수 있는 기능.

add_list[1] = Add(3) 등의 방식으로 모듈을 순차적으로 적용해 출력하면 11의 값이 나온다.

 

Python List vs ModuleList

Q. 굳이 nn.ModuleList를 호출하지 않고 Python List를 사용하면 안될까?

 

A. 기능은 동일하다. 하지만 Python List에 넣은 모듈들은 해당 리스트를 호출했을 경우 나타나지 않는다.

class PythonList(nn.Module):
    """Python List"""
    def __init__(self):
        super().__init__()

        # Python List
        self.add_list = [Add(2), Add(3), Add(5)]

    def forward(self, x):
        x = self.add_list[1](x)
        x = self.add_list[0](x)
        x = self.add_list[2](x)
        
        return x
        
PythonList()

>>> PythonList()

모듈이 나오지 않는다.

class PyTorchList(nn.Module):
    """PyTorch List"""
    def __init__(self):
        super().__init__()

        # Pytorch ModuleList
        self.add_list = nn.ModuleList([Add(2), Add(3), Add(5)])

    def forward(self, x):
        x = self.add_list[1](x)
        x = self.add_list[0](x)
        x = self.add_list[2](x)
        
        return x

PyTorchList()

>>> PyTorchList(
      (add_list): ModuleList(
        (0): Add()
        (1): Add()
        (2): Add()
      )
    )

모듈 리스트의 원소들이 나온다.

 

조건문

nn.Module에서는 조건문을 사용해서 만들 수 있다.

class Calculator(nn.Module):
    def __init__(self, cal_type):
        super().__init__()
        self.cal_type = cal_type
        self.add = Add(3)
        self.sub = Sub(3)

    def forward(self, x):
        if self.cal_type == "add":
            x = self.add(x)
        elif self.cal_type == "sub":
            x = self.sub(x)
        else:
          raise ValueError

        return x

조건문을 사용한 계산기 클래스를 생성했다. 이 경우 self.cal_type를 호출해서 이 타입이 'add' 'sub'인 경우를 조건으로 두어 해당 함수를 사용할 수 있고, 아닐 경우 ValueError를 raise할 수 있다.