이것저것 기록

[DL, PyTorch] 차원 재구성 - view(), reshape() 본문

Data Science/ML & DL

[DL, PyTorch] 차원 재구성 - view(), reshape()

anweh 2020. 10. 8. 20:41

파이토치의 reshape()과 view()는 둘 다 텐서의 모양을 변경하는 데에 사용될 수 있다.

그러나 둘 사이에 약간의 차이가 존재한다. 

 

 

- reshape(): reshape은 가능하면 input의 view를 반환하고, 안되면 contiguous한 tensor로 copy하고 view를 반환한다. 

- view(): view는 기존의 데이터와 같은 메모리 공간을 공유하며 stride 크기만 변경하여 보여주기만 다르게 한다. 그래서 contigious해야만 동작하며, 아닌 경우 에러가 발생함. 

출처: subinium.github.io/pytorch-Tensor-Variable/ 

 

 

바로 연습!

import numpy as np
import torch

t = np.zeros((4,4,3)) #0으로 채워진 4x4x3 numpy array 생성 
ft = torch.FloatTensor(t) #텐서로 변환
print(ft.shape) #torch.Size([4, 4, 3])

우선 필요한 라이브러리를 불러와준다. 

t라는 변수에 ( 4 x 4 x 3 ) 모양의 numpy array를 할당해주었다. 

이전의 포스팅에서 ( batch size(세로) x width(가로) x height(길이) ) 라는 것을 배웠으니 t라는 변수의 모양을 시각화 하면 다음과 같다. 

 shape of t

이제 이 ft (t) 의 모양을 바꾸어줄 것이다. 

## Case 1
print(ft.view([-1, 3])) # ft라는 텐서를 (?, 3)의 크기로 변경
print(ft.view([-1, 3]).shape) 
#원소의 개수(4x4x3 = 48 개는) 유치한 채 3차원으로 맞추다보니까 결과적으론 16x3 이 됨. 

ft는 원래 3차원인데 view()를 통해서 [-1, 3]으로 바꾸라고 했다.

이건 3차원인 ft를 2차원으로 바꿀건데, (?? x 3)의 모양이 되게 바꾸란 말이다.

ft는 총 48개의 원소를 가지고 있는데, 2차원이면서 (?? x 3)의 모양을 가지려면 ??는 16이 되어야한다. 

## Case 2
print(ft.view([-1, 2, 3])) # ft라는 텐서를 (?, 2, 3)의 크기로 변경
print(ft.view([-1, 2, 3]).shape) 

Case2에선 동일한 ft 텐서를 3차원이면서 (?? x 2 x 3)의 모양이 되게 바꾸려고 한다. 

48 / 6 = 8! 

원소의 개수(4x4x3 = 48 개는) 유치한 채 3차원이면서 (? x 2 x 3)의 모양에 맞추다보니까 결과적으론 (8 x 2 x 3)이 된다. 

 

 

reshape()의 사용법도 view()와 비슷하다. 

깃헙의 유명 네트워크의 코드를 보면 거의 다 reshape()을 사용하긴 한다. 

## Case 3
r = np.zeros((5, 5, 10))
fr = torch.FloatTensor(r)
print(fr.shape)

print(fr.reshape(10, 5, 5).shape) #torch.Size([10, 5, 5])
print(fr.reshape(1, -1).shape) #torch.Size([1, 250])

 

 

 

Comments