이것저것 기록

[DL, PyTorch] 신경망 모델 정의하기 -- Class, nn.Module 본문

Data Science/ML & DL

[DL, PyTorch] 신경망 모델 정의하기 -- Class, nn.Module

anweh 2020. 10. 13. 23:48

 

 

PyTorch로 신경망 모델을 설계할 때, 크게 다음과 같은 세 가지 스텝을 따르면 된다.

  1. Design your model using class with Variables 
  2. Construct loss and optim
  3. Train cycle (forward, backward, update)

 


 

 

이 포스팅에선 첫번째 단계인 클래스와 변수를 정의하는 방법을 다루려고 한다.

PyTorch로 신경망을 설계할 때크게 두 가지 방법이 있다.

  • 사용자 정의 nn 모듈
  • nn.Module을 상속한 클래스 이용 

어느 방식으로 신경망을 설계해도 상관 없지만, 복잡한 신경망을 직접 설계할 때는 내 마음대로 정의하는 방식의 nn모듈을 더 많이 사용한다고 한다. (참고: tutorials.pytorch.kr/beginner/pytorch_with_examples.html

여기서 모듈이란 한 개 이상의 레이어가 모여서 구성된 것을 말한다. 모듈+모듈 = 새모듈 < 이런 공식도 성립함. 

신경망(모델)은 한 개 이상의 모듈로 이루어진 내가 최종적으로 원하는 것을 뜻함! 

 

 

1. PyTorch 모델의 기본 구조

PyTorch로 설계하는 신경망은 기본적으로 다음과 같은 구조를 갖는다. 

PyTorch 내장 모델 뿐만 아니라 사용자 정의 모델도 반드시 이 구조를 따라야 한다.

import torch.nn as nn
import torch.nn.functional as F

class Model_Name(nn.Module):
    def __init__(self):
    
        super(Model_Name, self).__init__()
        self.module1 = ...
        self.module2 = ...
        
        """
        ex)
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)
        """

    def forward(self, x):
    
        x = some_function1(x)
        x = some_function2(x)
        
        """
        ex)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        """
        return x
        
model = Model_Name() # 여기에 변수를 넣어주면 됨.

PyTorch 모델로 쓰기 위해선 다음 두 가지 조건을 따라야한다. 내장된 모델(nn.Linear등)도 이를 만족한다. 

  1. torch.nn.Module을 상속해야한다. 
    • interitance: 상속; 어떤 클래스를 만들 때 다른 클래스의 기능을 그대로 가지고오는 것.
  2. __init()__과 forward()를 override 해야한다.
    • override: 재정의; torch.nn.Module(부모클래스)에서 정의한 메소드를 자식클래스에서 변경하는 것.
    • __init()__에서는 모델에서 사용될 module(nn.Linear, nn.Conv2d), activation function(nn.functional.relu, nn.functional.sigmoid)등을 정의한다. 
    • forward()에서는 모델에서 실행되어야하는 계산을 정의한다. backward 계산은 backward()를 이용하면 PyTorch가 알아서 해주니까 forward()만 정의해주면 된다. input을 넣어서 어떤 계산을 진행하하여 output이 나올지를 정의해준다고 이해하면 됨. 

 

2. nn.Module

PyTorch의 nn 라이브러리는 Neural Network의 모든 것을 포괄하는 모든 신경망 모델의 Base Class이다. 

다른 말로, 모든 신경망 모델은 nn.Module의 subclass라고 할 수 있다. 

nn.Module을 상속한 subclass가 신경망 모델로 사용되기 위해선 앞서 소개한 두 메소드를 override 해야한다. 

  • __init__(self): initialize; 내가 사용하고 싶은, 내 신경망 모델에 사용될 구성품들을 정의 및 초기화 하는 메소드이다. 
  • forward(self, x): specify the connections;  이닛에서 정의된 구성품들을 연결하는 메소드이다. 

 

 

3. PyTorch Layer의 종류

참조: pytorch.org/docs/stable/nn.html#module

  1. Linear layers
    • nn.Linear
    • nn.Bilinear
  2. Convolution layers
    • nn.Conv1d, nn.Conv2d, nn.Conv3d
    • nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d
    • nn.Unfold, nn.Fold
  3. Pooling layers
    • nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d
    • nn.MaxUnpool1d, nn.MaxUnpool2d, nn.MaxUnpool3d
    • nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d
    • nn.FractionalMaxPool2d
    • nn.LPPool1d, nn.LPPool2d
    • nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d
    • nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d
  4. Padding layers
    • nn.ReflectionPad1d, nn.ReflectionPad2d
    • nn.ReplicationPad1d, nn.ReplicationPad2d, nn.ReplicationPad3d
    • nn.ZeroPad2d
    • nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d
  5. Normalization layers
    • nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
    • nn.GroupNorm
    • nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d
    • nn.LayerNorm
    • nn.LocalResponseNorm
  6. Recurrent layers
    • nn.RNN, nn.RNNCell
    • nn.LSTM, nn.LSTMCell
    • nn.GRU, nn.GRUCell
  7. Dropout layers
    • nn.Dropout, nn.Dropout2d, nn.Dropout3d
    • nn.AlphaDropout
  8. Sparse layers
    • nn.Embedding
    • nn.EmbeddingBag

 

 

4. PyTorch 활성화 함수의 종류

참조: pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity

  1. Non-linear activations
    • nn.ELU, nn.SELU
    • nn.Hardshrink, nn.Hardtanh
    • nn.LeakyReLU, nn.PReLU, nn.ReLU, nn.ReLU6, nn.RReLU
    • nn.Sigmoid, nn.LogSigmoid
    • nn.Softplus, nn.Softshrink, nn.Softsign
    • nn.Tanh, nn.Tanhshrink
    • nn.Threshold
  2. Non-linear activations (other)
    • nn.Softmin
    • nn.Softmax, nn.Softmax2d, nn.LogSoftmax
    • nn.AdaptiveLogSoftmaxWithLoss
Comments