J o e
JoE's StOrY
J o e
  • 분류 전체보기 (206)
    • workSpace (184)
      • 도메인 지식 (2)
      • ALGORITHM (39)
      • ANDROID (3)
      • JS (0)
      • JAVA (21)
      • MYSQL (6)
      • NETWORK (3)
      • PYTHON (91)
      • LINUX (9)
      • PROJECT (4)
    • Others (20)
      • Opic (1)
      • myLife (17)
      • popSong (1)
      • 정보처리기사 (1)
    • 훈빠의 특강 (0)
      • opencv (0)
      • python (0)

블로그 메뉴

  • 홈
  • 태그
  • 방명록

공지사항

  • The code with long statements is⋯
  • 매일 매일이 행복하고 밝은 날이 될거에요

인기 글

태그

  • 태블릿 연동
  • 이미지 연산
  • sort_value
  • 넘파이 문제
  • 넘파이함수
  • full loss
  • numpy
  • MySQL
  • ㅖ43
  • Python
  • linearclassification
  • sort_index
  • java
  • 파이썬
  • 단어의 개수
  • How to create a GUI in Java with JFrame?
  • dao
  • read_html
  • DTO
  • Fully Connected Network

최근 댓글

최근 글

티스토리

J o e

WHY?

[PyTorch] 데이터 불러오기
workSpace/PYTHON

[PyTorch] 데이터 불러오기

2021. 1. 19. 22:20

 

 

 

https://github.com/pytorch/examples/tree/master/mnist

 

PyTorch Data Preprocess¶

In [4]:
import torch

from torchvision import datasets, transforms
import warnings
warnings.filterwarnings('ignore')
 

Data Loader 부르기¶

파이토치는 DataLoader를 불러 model에 넣음

In [5]:
# 사이즈를 넣어줌.
batch_size = 32 
test_batch_size = 32
In [6]:
# DataLoader함수를 사용해서 MNIST의 데이터를 불러옴.
# 인자에 디렉토리에 주소값, train 여부 , 다운로드 여부, transform으로 속성을 지정한다. 사이즈 조정하고, 셔플로 섞어줌.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('dataset/', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize(mean=(0.5,), std=(0.5,))
                   ])),
    batch_size=batch_size,
    shuffle=True)
In [7]:
#이번엔 train이 아니고 test이니까 train을 False로 지정, transforms, 즉 변형된 값들에 대한 구성을 설정해준다.
# transforms를 
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('dataset', train=False, 
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5,), (0.5))
                   ])),
    batch_size=test_batch_size,
    shuffle=True)
 

첫번재 iteration에서 나오는 데이터 확인¶

In [7]:
images, labels = next(iter(train_loader))
In [8]:
images.shape
 
Out[8]:
torch.Size([32, 1, 28, 28])
In [9]:
labels.shape
Out[9]:
torch.Size([32])
 

PyTorch는 TensorFlow와 다르게 [Batch Size, Channel, Height, Width] 임을 명시해야함

 

데이터 시각화¶

In [10]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
 

In [11]:
images[0].shape
Out[11]:
torch.Size([1, 28, 28])
In [12]:
torch_image = torch.squeeze(images[0])
torch_image.shape
Out[12]:
torch.Size([28, 28])
In [13]:
image = torch_image.numpy()
image.shape
Out[13]:
(28, 28)
In [14]:
label = labels[0].numpy()
In [15]:
label.shape
Out[15]:
()
In [16]:
label
Out[16]:
array(7, dtype=int64)
In [18]:
plt.title(label)
plt.imshow(image, 'gray')
plt.show()
 
 
 

참고 블로그

mjdeeplearning.tistory.com/81

'workSpace > PYTHON' 카테고리의 다른 글

[ML] Entropy, impurity, gini impurity, Ensemble, Random Forest 설명  (0) 2021.01.20
[ML] Decision Tree & Ensemble 설명  (0) 2021.01.20
[PyTorch] 기본 사용법  (0) 2021.01.19
[ML] 타이타닉 생존자 예측해보기  (0) 2021.01.19
[ML] 타이타닉 생존자 예측하기 및 설명  (0) 2021.01.19
    'workSpace/PYTHON' 카테고리의 다른 글
    • [ML] Entropy, impurity, gini impurity, Ensemble, Random Forest 설명
    • [ML] Decision Tree & Ensemble 설명
    • [PyTorch] 기본 사용법
    • [ML] 타이타닉 생존자 예측해보기
    J o e
    J o e
    나의 과거를 기록합니다.

    티스토리툴바