본문 바로가기

카테고리 없음

PyTorch를 이용한 ImageNet 데이터 세트 학습 코드 및 결과 총 정리 (ResNet 활용하여 80% 이상 정확도 얻기)

 본 게시글은 필자의 개인 경험을 토대로 작성된 것으로, 잘못된 정보를 포함하고 있을 수 있습니다. 또한 2023년 4월을 기준으로 작성되어, 현재 기준으로는 잘못된 정보가 포함되어 있을 수 있습니다. ※

 

  NVIDIA Tesla V100 혹은 NVIDIA TITAN RTX과 같은 GPU 하나만 있어도 3~4일 정도면 충분히 ImageNet을 학습하여 Top-1 정확도(accuracy)로 70% 가까이 뽑아낼 수 있다. 이번 포스팅에서는 간단히 Hugging Face에서 제공하는, 흔히 알려진 ImageNet 데이터 세트를 활용하여 CNN 모델(ResNet-18)을 학습을 진행하는 방법에 대해서 알아보겠다. 구체적으로 ResNet-18 모델을 사용하여 간단히 학습을 진행해 볼 수 있다.

 

  가장 먼저, Jupyter Notebook 상에서 다음과 같이 사용할 GPU의 번호를 설정할 수 있다.

 

import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

 

  이후에 다음과 같이 Hugging Face의 ImageNet 1k 데이터 세트에 대하여 사용할 수 있는 커스텀 데이터 세트(custom dataset)를 정의한 것을 확인할 수 있다.

 

  Hugging Face의 ImageNet 데이터 세트 중에는 회색(grayscale) 이미지가 포함되어 있어서, 필자는 다음과 같이 항상 RGB 색상의 이미지(3 channels) 형태로 이미지를 처리할 수 있도록 코드를 작성한 것을 확인할 수 있다.

 

import torch


class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, train_mode=True, transforms=None):
        if train_mode:
            self.dataset = imagenet_dataset["train"]
        else:
            self.dataset = imagenet_dataset["validation"]
        if transforms:
            self.transforms = transforms
        
    def __getitem__(self, index):
        image = self.dataset[index]["image"]
        label = self.dataset[index]["label"]
    
        current = image.convert("RGB")
        if self.transforms:
            current = self.transforms(current)
    
        return current, label
    
    def __len__(self):
        return len(self.dataset)

 

  이후에 다음과 같이 load_dataset 라이브러리를 사용하여 데이터 세트를 불러 올 수 있다. 이를 위해 기본적으로 Hugging Face에 로그인하는 과정이 필요할 수 있다. (Hugging Face에서 데이터 세트가 다운로드 된다.)

 

from datasets import load_dataset

imagenet_dataset = load_dataset("imagenet-1k")

 

  이후에 학습을 위하여 다음과 같이 일반적으로 사용되는 학습 목적의 데이터 증진(augmentation) 기법을 사용할 수 있다. 테스트 단계에서는 사실상 이미지 크기를 224 X 224로 변경하여 모델에 넣는 것으로 이해할 수 있다.

 

import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision

augment_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

batch_size = 256

train_dataset = CustomDataset(train_mode=True, transforms=augment_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=40)

test_dataset = CustomDataset(train_mode=False, transforms=transform_test)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=40)

print(len(train_dataset))
print(len(test_dataset))

 

  위 코드를 실행하면 다음과 같이 학습 데이터와 검증 데이터의 개수가 출력되는 것을 확인할 수 있다. Hugging Face에서 제공하는 ImageNet 1K 데이터 세트ImageNet 2012 1,000 classes 데이터 세트와 정확히 개수가 일치하는 것을 확인할 수 있다.

 

  ▶ 학습 데이터 개수: 1,281,167개

  ▶ 검증 데이터 개수: 50,000개

 

  이것은 일반적인 ImageNet과 정확히 일치한다. 이후에 실질적인 학습(training) 및 테스트(test) 수행 코드는 다음과 같다.

 

def train(epoch):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (images, labels) in enumerate(train_loader):
        augmented_images = images.cuda()
        targets = labels.cuda()

        optimizer.zero_grad()

        benign_outputs = net(augmented_images)
        loss = criterion(benign_outputs, targets)
        loss.backward()

        optimizer.step()
        train_loss += loss.item()
        _, predicted = benign_outputs.max(1)

        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    print(f'[{epoch}] Total benign train accuarcy:', 100. * correct / total)
    return 100. * correct / total


def test(epoch):
    net.eval()
    benign_loss = 0
    benign_correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.cuda(), targets.cuda()
            total += targets.size(0)

            outputs = net(inputs)
            loss = criterion(outputs, targets)
            benign_loss += loss.item()

            _, predicted = outputs.max(1)
            benign_correct += predicted.eq(targets).sum().item()

    print(f'[{epoch}] Total benign test accuarcy:', 100. * benign_correct / total)
    return 100. * benign_correct / total

 

  또한, 학습(training) 코드와 테스트(test) 코드가 각각 함수로 나누어져 작성되어 있기 때문에, 체계적으로 로깅(logging)을 진행하면서 학습이 되는 과정을 기록할 수 있을 것이다. 필자는 다음과 같이 logger 객체를 사용했다.

 

from torchvision.transforms.functional import to_pil_image
import logging

logger = logging.getLogger()
logger.setLevel(logging.INFO)

log_path = "training.log"
handler = logging.FileHandler(log_path, 'a')
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)

logger.addHandler(handler)

 

  다음과 같이 학습을 해볼 수 있다. 필자는 0.1부터 학습률(learning rate)이 시작되어, 코사인 학습률(cosine annealing learning rate) 기법을 적용하여 서서히 학습률이 감소하도록 하였다. 이는 일반적으로 많이 사용되는 하이퍼 파라미터 세팅이기도 하다. 결과적으로, 다음과 같은 코드를 작성할 수 있다.

 

import torch.optim as optim
import time

criterion = nn.CrossEntropyLoss()
learning_rate = 0.1
net = ResNet18(channel=3, num_classes=1000, im_size=(224, 224, 3))
net = net.cuda()

optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0005)
num_epochs = 200
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.0001)
best_accuracy = 0

start_time = time.time()
for epoch in range(0, num_epochs):
    train_accuracy = train(epoch)
    test_accuracy = test(epoch)
    best_accuracy = max(best_accuracy, test_accuracy)
    scheduler.step()

    message = f"[{epoch}] train_accuracy: {train_accuracy:.2f}%, test_accuracy: {test_accuracy:.2f}%, elapsed time: {time.time() - start_time:.4f} seconds."
    print(message)
    logger.info(message)

 

  로그 파일 "training.log" 파일의 내용으로는 바로 다음과 같이 기록되는 것을 확인할 수 있었다. (일부 생략된 결과) 그렇다면 얼마나 많은 시간이 소요되는가? 1시간에 2번의 epoch(반복)이 수행되므로, 최종적으로 200번의 epoch을 위해 약 4일(48시간) 정도가 소요되는 것으로 이해할 수 있다. 즉, 최종적인 학습이 모두 수행되기까지 약 4일 정도가 소요된다.

 

INFO - [0] train_accuracy: 0.39%, test_accuracy: 0.08%, elapsed time: 50.1816 seconds.
INFO - [1] train_accuracy: 21.52%, test_accuracy: 19.09%, elapsed time: 3417.9323 seconds.
INFO - [2] train_accuracy: 25.13%, test_accuracy: 20.68%, elapsed time: 5141.1221 seconds.
INFO - [3] train_accuracy: 26.67%, test_accuracy: 25.18%, elapsed time: 6865.0347 seconds.
INFO - [4] train_accuracy: 27.41%, test_accuracy: 22.51%, elapsed time: 8592.6948 seconds.
INFO - [5] train_accuracy: 27.89%, test_accuracy: 22.04%, elapsed time: 10312.3661 seconds.
INFO - [6] train_accuracy: 28.21%, test_accuracy: 21.88%, elapsed time: 12036.4238 seconds.
INFO - [7] train_accuracy: 28.42%, test_accuracy: 25.27%, elapsed time: 13760.0541 seconds.
INFO - [8] train_accuracy: 28.60%, test_accuracy: 26.38%, elapsed time: 15484.2011 seconds.
INFO - [9] train_accuracy: 28.79%, test_accuracy: 23.05%, elapsed time: 17207.5558 seconds.
INFO - [10] train_accuracy: 28.87%, test_accuracy: 24.49%, elapsed time: 18927.5624 seconds.
INFO - [11] train_accuracy: 29.07%, test_accuracy: 21.52%, elapsed time: 20643.2920 seconds.
INFO - [12] train_accuracy: 29.14%, test_accuracy: 27.60%, elapsed time: 22366.7608 seconds.
INFO - [13] train_accuracy: 29.21%, test_accuracy: 27.60%, elapsed time: 24085.4733 seconds.
INFO - [14] train_accuracy: 29.36%, test_accuracy: 20.93%, elapsed time: 25808.5252 seconds.
INFO - [15] train_accuracy: 29.36%, test_accuracy: 24.45%, elapsed time: 27529.0205 seconds.
INFO - [16] train_accuracy: 29.51%, test_accuracy: 24.26%, elapsed time: 29252.7996 seconds.
INFO - [17] train_accuracy: 29.56%, test_accuracy: 22.86%, elapsed time: 30971.7989 seconds.
INFO - [18] train_accuracy: 29.62%, test_accuracy: 26.58%, elapsed time: 32694.5667 seconds.
INFO - [19] train_accuracy: 29.61%, test_accuracy: 27.42%, elapsed time: 34418.3205 seconds.
INFO - [20] train_accuracy: 29.79%, test_accuracy: 27.31%, elapsed time: 36142.2254 seconds.
INFO - [21] train_accuracy: 29.81%, test_accuracy: 26.69%, elapsed time: 37861.3492 seconds.
INFO - [22] train_accuracy: 29.85%, test_accuracy: 28.85%, elapsed time: 39581.4022 seconds.
INFO - [23] train_accuracy: 29.93%, test_accuracy: 23.09%, elapsed time: 41306.3136 seconds.
INFO - [24] train_accuracy: 29.97%, test_accuracy: 25.78%, elapsed time: 42988.2724 seconds.
INFO - [25] train_accuracy: 30.06%, test_accuracy: 24.52%, elapsed time: 44682.1765 seconds.
INFO - [26] train_accuracy: 30.08%, test_accuracy: 25.99%, elapsed time: 46396.7877 seconds.
INFO - [27] train_accuracy: 30.14%, test_accuracy: 24.89%, elapsed time: 48107.5016 seconds.
INFO - [28] train_accuracy: 30.19%, test_accuracy: 26.90%, elapsed time: 49821.4210 seconds.
INFO - [29] train_accuracy: 30.25%, test_accuracy: 19.38%, elapsed time: 51540.0017 seconds.
INFO - [30] train_accuracy: 30.34%, test_accuracy: 26.31%, elapsed time: 53255.1550 seconds.
INFO - [31] train_accuracy: 30.42%, test_accuracy: 29.78%, elapsed time: 54973.4562 seconds.
INFO - [32] train_accuracy: 30.44%, test_accuracy: 21.07%, elapsed time: 56691.9628 seconds.
INFO - [33] train_accuracy: 30.55%, test_accuracy: 29.46%, elapsed time: 58373.2567 seconds.
INFO - [34] train_accuracy: 30.59%, test_accuracy: 25.63%, elapsed time: 60078.9948 seconds.
...
INFO - [40] train_accuracy: 30.93%, test_accuracy: 27.60%, elapsed time: 70349.2042 seconds.
INFO - [50] train_accuracy: 31.70%, test_accuracy: 31.77%, elapsed time: 87564.8081 seconds.
INFO - [60] train_accuracy: 32.58%, test_accuracy: 29.86%, elapsed time: 105548.2592 seconds.
INFO - [70] train_accuracy: 33.64%, test_accuracy: 25.45%, elapsed time: 123313.4451 seconds.
INFO - [80] train_accuracy: 34.82%, test_accuracy: 33.87%, elapsed time: 140434.3901 seconds.
INFO - [90] train_accuracy: 36.24%, test_accuracy: 32.73%, elapsed time: 157645.7088 seconds.
INFO - [100] train_accuracy: 37.80%, test_accuracy: 39.62%, elapsed time: 174914.3022 seconds.
INFO - [110] train_accuracy: 39.61%, test_accuracy: 37.91%, elapsed time: 192132.7988 seconds.
INFO - [120] train_accuracy: 41.52%, test_accuracy: 41.18%, elapsed time: 209347.4967 seconds.
INFO - [130] train_accuracy: 43.71%, test_accuracy: 42.41%, elapsed time: 226553.6149 seconds.
INFO - [140] train_accuracy: 46.19%, test_accuracy: 47.86%, elapsed time: 243681.9147 seconds.
INFO - [150] train_accuracy: 48.86%, test_accuracy: 52.52%, elapsed time: 260886.8874 seconds.
INFO - [160] train_accuracy: 51.79%, test_accuracy: 54.01%, elapsed time: 278025.4070 seconds.
INFO - [170] train_accuracy: 55.15%, test_accuracy: 57.40%, elapsed time: 295231.0110 seconds.
INFO - [180] train_accuracy: 59.03%, test_accuracy: 61.31%, elapsed time: 312378.6068 seconds.
INFO - [190] train_accuracy: 63.48%, test_accuracy: 66.29%, elapsed time: 329543.2688 seconds.
...
INFO - [199] train_accuracy: 66.13%, test_accuracy: 68.56%, elapsed time: 344980.0817 seconds.

 

  또한, 5번 정도의 epoch만을 반복했음에도 테스트 정확도(test accuracy)가 순식간에 25%에 가까이 되는 것을 확인할 수 있다.