머신 러닝 이론

Mnist 데이터 셋 불러오기

skawlsgus2 2025. 3. 27. 17:53

torchvision 패키지에 존재하는 Mnist 데이터 셋을 불러와 확인해보자 

 

우선 전체 과정은 다음과 같다

  1. 이미지 전처리 정의
  2. 데이터를 전처리하여 훈련용과 테스트용으로 나누기
  3. 데이터로더 정의
  4. 데이터로더에 따라 데이터 불러오기
  5. 불러온 데이터 그려서 확인하기 

0) 사용할 패키지

from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

 

 

DataLoader는 배치 사이즈가 몇 인지(한 번에 가져올 이미지 수), 셔플을 할 것인지, 어떤 데이터 셋을 대상으로 하는지

정의하여 어떻게 데이터를 불러올 것인가에 대한 설명서 같은 역할을 한다.  

 

transforms 는 전처리 단계를 구성할 수 있도록 한다.

 

datasets 에는 다양한 데이터 셋이 들어있다. 여기서는 MNIST 데이터 셋을 사용한다.

 

matplotlib 은 불러온 데이터를 그려볼 수 있게 해준다.

 

1) 이미지 전처리 정의 

mnist_transform = transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize(mean = (0.5,),std = (1.0,))])

Compose 클래스를 사용하여 여러 개의 이미지 변환(transform) 함수를 순차적으로 적용할 수 있도록 묶어준다.

여기서는 텐서로 변환, 이미지 정규화의 두 가지 변환 함수를 순차적으로 적용할 수 있도록 정의해 주었다.

텐서로 변환하는 이유는 PIL Image 또는 NumPy 배열 형태로 존재하는 이미지를 PyTorch 모델이 입력 받을 수 있는 텐서 형태로 변환하기 위해서 이고, 

이미지 정규화를 하는 이유는 픽셀 값을 -1~1 사이의 작은 값으로 변환하여 모델이 최적의 가중치를 찾는 속도를 빠르게 만들기 위함이다.

 

2) 데이터를 전처리하여 훈련용과 테스트용으로 나누기

trainset = datasets.MNIST(root = '/content/',
                          train = True,
                          download = True,
                          transform = mnist_transform)
testset = datasets.MNIST(root = '/content/',
                          train = False,
                          download = True,
                          transform = mnist_transform)

과정 1) 에서 정의한 전처리를 적용하여 /content/ 경로에 mnist_train_small.csv, mnist_test.csv 파일을 생성한다.

다만 코랩 환경에서는 /content/sample_data/ 경로에 생성된다.

MNIST  데이터 셋을 훈련용과 테스트용으로 나누어 전처리한다.

 

3) 데이터로더 정의

train_loader = DataLoader(trainset, batch_size = 8, shuffle = True, num_workers = 2)
test_loader = DataLoader(testset, batch_size = 8, shuffle = False, num_workers = 2)

훈련용 로더는 trainset 을 8장 씩 셔플하여 가져오고, 테스트용 로더는 testset을 8장 씩 셔플하여 가져도록 정의한다. 

 

4) 데이터로더에 따라 데이터 불러오기

dataiter = iter(train_loader)
images,labels = dataiter.__next__()
print(images.shape)
print(labels.shape)

어떤 데이터를 몇 장씩 가져올지 데이터로더로 정의를 했으니 이제 데이터를 배치 사이즈 만큼 가져와 보자. 

훈련용 데이터로더를 매개변수로 iter() 함수를 실행하면 iterator 객체를 반환한다(dataiter)

iterator 객체의 __next__() 메서드를 실행하면 배치 데이터를 가지고 온다. 구체적으로는 이미지와 라벨을 가지고 온다.

 

5) 데이터 그려보기 

figure = plt.figure(figsize = (12,6))
cols, rows = 4,2
for i in range (1,cols*rows +1):
  sample_idx = torch.randint(len(trainset),size = (1,)).item()
  img, label = trainset[sample_idx]
  figure.add_subplot(rows,cols,i)
  plt.title(label)
  plt.axis('off')
  plt.imshow(img.squeeze(),cmap = 'gray')
plt.show()

8 개의 서브플롯을 만들어 이미지를 확인해 본다. trainset 의 크기 보다 작은 임의의 정수 1개를 담는 텐서를 만들어 해당 정수를 sample_idx 에 담는다. trainset 을 인덱싱하여 이미지와 라벨정보를 각각 img와 label에 담는다.

3차원 텐서의 img를 imshow로 확인하기 위해 2차원으로 축소하여 확인한다.

결과는 다음과 같다

서브플롯을 만들어 그려본 이미지 데이터

#전체 코드
# torchvidion 의 datasets 를 사용해 데이터셋 가져오기
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

# torchvision.transforms의 Compose 함수를 사용해서 DataLoader의 인자로 들어갈 transform 정의
mnist_transform = transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize(mean = (0.5,),std = (1.0,))])

#trainset, testset 나누기
trainset = datasets.MNIST(root = '/content/',
                          train = True,
                          download = True,
                          transform = mnist_transform)
testset = datasets.MNIST(root = '/content/',
                          train = False,
                          download = True,
                          transform = mnist_transform)

#DataLoader정의 
train_loader = DataLoader(trainset, batch_size = 8, shuffle = True, num_workers = 2)
test_loader = DataLoader(testset, batch_size = 8, shuffle = False, num_workers = 2)

#DataLoader로 데이터 셋을 배치 단위로 가져오기
dataiter = iter(train_loader)
images,labels = dataiter.__next__()
print(images.shape)
print(labels.shape)

#trainset에서 하나씩 꺼내어 그리기
figure = plt.figure(figsize = (12,6))
cols, rows = 4,2
for i in range (1,cols*rows +1):
  sample_idx = torch.randint(len(trainset),size = (1,)).item()
  img, label = trainset[sample_idx]
  figure.add_subplot(rows,cols,i)
  plt.title(label)
  plt.axis('off')
  plt.imshow(img.squeeze(),cmap = 'gray')
  print(img.squeeze().shape)
plt.show()