Python/머신러닝, 딥러닝

Python 딥러닝 포켓몬 분류

dustKim 2024. 6. 24. 21:57
포켓몬 분류

 

- Train: https://www.kaggle.com/datasets/thedagger/pokemon-generation-one
- Validation: https://www.kaggle.com/hlrhegemony/pokemon-image-dataset

 

Complete Pokemon Image Dataset

2,500+ clean labeled images, all official art, for Generations 1 through 8.

www.kaggle.com

더보기
# 케글에서 가져오기 위해서 모듈 import
import os

# 케글 연결하기
os.environ["KAGGLE_USERNAME"] = "name"
os.environ["KAGGLE_KEY"] = "apikey"

# 데이터셋 가져오기
!kaggle datasets download -d thedagger/pokemon-generation-one
!kaggle datasets download -d hlrhegemony/pokemon-image-dataset

# zip파일 풀기
!unzip -q /content/pokemon-generation-one.zip
!unzip -q /content/pokemon-image-dataset.zip

!mv dataset train
!rm -rf train/dataset
!mv images validation
# 데이터 확인하기
train_labels = os.listdir("train")
print(train_labels)
print(len(train_labels))
결과

 

 

# 데이터 확인하기
valid_labels = os.listdir("validation")
print(valid_labels)
print(len(valid_labels))
결과

 

 

import shutil

# valid 데이터에서 train 데이터에 없는 것을 삭제 
for val_label in valid_labels:
  if val_label not in train_labels:
    shutil.rmtree(os.path.join("validation", val_label))
    
# 삭제 후 확인
valid_labels = os.listdir("validation")
len(valid_labels)
결과

 

 

# train 데이터에는 있지만 valid 데이터에 없는 것을 확인
for train_label in train_labels:
  if train_label not in valid_labels:
    print(train_label)
결과

 

 

# train 데이터에는 있지만 valid 데이터에 없는 디렉토리 파일을 만들기
# 데이터의 양을 맞춰주기 위함이다.
for train_label in train_labels:
  if train_label not in valid_labels:
    os.makedirs(os.path.join("validation", train_label),exist_ok=True)

# 만들고 확인하기
valid_labels = os.listdir("validation")
len(valid_labels)
결과

 

 

# 필요한 모듈 import
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader

# device 확인
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
결과

 

 

data_transforms = {
    "train": transforms.Compose([
        transforms.Resize((224, 224)), # 사이즈를 바꿔줌, 사이즈 통일
        # 회전 각도 0, 이미지 기울기 최대 10도까지 변형 적용, 이미지 80% ~ 120% 사이 임의의 크기로 확대/축소
        transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
        transforms.RandomHorizontalFlip(), # 이미지를 랜덤으로 수평으로 뒤집음
        transforms.ToTensor() # Tensor형으로 변환
    ]),
    "validation": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
}
# 위에 함수를 사용하여 폴더에 있는 사진들 변환
image_datasets = {
    "train": datasets.ImageFolder("train", data_transforms["train"]),
    "validation": datasets.ImageFolder("validation", data_transforms["validation"])
}
dataloaders ={
    "train": DataLoader(
        image_datasets["train"],
        batch_size=32,
        shuffle=True
    ),
    "validation": DataLoader(
        image_datasets["validation"],
        batch_size=32,
        shuffle=False
    )
}
# 이미지 뽑아보기
imgs, labels = next(iter(dataloaders["train"]))

fig, axes = plt.subplots(4, 8, figsize=(16, 8))

for ax, img, label in zip(axes.flatten(), imgs, labels):
  ax.imshow(img.permute(1, 2, 0))
  ax.set_title(label.item())
  ax.axis("off")
결과

 

 

# 101번 이미지는 ?
image_datasets["train"].classes[101]
결과

 

EfficientNet

 

- 구글의 연구팀이 개발한 이미지 분류, 객체 검출 등 컴퓨터 비전 작업에서 높은 성능을 보여주는 신경망 모델이다.
- 신경망의 깊이, 너비, 해상도를 동시에 확장하는 방법을 통해 효율성과 성능을 극대화한 것이 특징이다.
- EfficientnetB4는 EfficientNet 시리즈의 중간 크기 모델이다.

더보기
# 필요한 모듈 import
from torchvision.models import efficientnet_b4, EfficientNet_B4_Weights
from torchvision.models._api import WeightsEnum
from torch.hub import load_state_dict_from_url
def get_state_dict(self, *args, **kwargs):
    kwargs.pop("check_hash")
    return load_state_dict_from_url(self.url, *args, **kwargs)

WeightsEnum.get_state_dict = get_state_dict

# 모델 생성
model = efficientnet_b4(weights=EfficientNet_B4_Weights.IMAGENET1K_V1).to(device)
model
결과

 

 

for param in model.parameters():
  param.requires_grad = False

model.classifier = nn.Sequential(
    nn.Linear(1792, 512),
    nn.ReLU(),
    nn.Linear(512, 149)
).to(device)
print(model)
결과

 

 

optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)

epochs = 10

for epoch in range(epochs):
    for phase in ['train', 'validation']:
        if phase == 'train':
            model.train()
        else:
            model.eval()

        sum_losses = 0
        sum_accs = 0

        for x_batch, y_batch in dataloaders[phase]:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            y_pred = model(x_batch)
            loss = nn.CrossEntropyLoss()(y_pred, y_batch)

            if phase == 'train':
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            sum_losses = sum_losses + loss
            y_prob = nn.Softmax(1)(y_pred)
            y_pred_index = torch.argmax(y_prob, axis=1)
            acc = (y_batch == y_pred_index).float().sum() / len(y_batch) * 100
            sum_accs = sum_accs + acc

        avg_loss = sum_losses / len(dataloaders[phase])
        avg_acc = sum_accs / len(dataloaders[phase])
        print(f'{phase:10s}: Epoch {epoch+1:4d}/{epochs} Loss: {avg_loss:.4f} Accuracy: {avg_acc:.2f}%')
결과

 

 

# 학습된 모델 파일 저장
torch.save(model.state_dict(), "model.pth") # model.h5
model = models.efficientnet_b4().ti(device)

model.classifier = nn.Sequential(
    nn.Linear(1792, 512),
    nn.ReLU(),
    nn.Linear(512, 149)
).to(device)

print(model)
결과

 

 

model.load_state_dict(torch.load("model.pth"))
결과

 

 

from PIL import Image

img1 = Image.open("validation/Snorlax/4.jpg")
img2 = Image.open("validation/Diglett/0.jpg")

fig, axes = plt.subplots(1, 2, figsize=(12, 6))
axes[0].imshow(img1)
axes[0].axis("off")
axes[1].imshow(img2)
axes[1].axis("off")
plt.show()
결과

 

 

img1_input = data_transforms['validation'](img1)
img2_input = data_transforms['validation'](img2)
print(img1_input.shape)
print(img2_input.shape)
결과

 

 

test_batch = torch.stack([img1_input, img2_input])
test_batch = test_batch.to(device)
test_batch.shape
결과

 

 

y_pred = model(test_batch)
y_pred
결과

 

 

y_prob = nn.Softmax(1)(y_pred)
y_prob
결과

 

 

probs, idx = torch.topk(y_prob, k=3)
print(probs)
print(idx)
결과

 

 

fig, axes = plt.subplots(1, 2, figsize=(15, 6))
axes[0].set_title("{:.2f}% {}, {:.2f}% {}, {:.2f}% {}".format(
    probs[0, 0] * 100,
    image_datasets["validation"].classes[idx[0, 0]],
    probs[0, 1] * 100,
    image_datasets["validation"].classes[idx[0, 1]],
    probs[0, 2] * 100,
    image_datasets["validation"].classes[idx[0, 2]]
))
axes[0].imshow(img1)
axes[0].axis("off")

axes[1].set_title("{:.2f}% {}, {:.2f}% {}, {:.2f}% {}".format(
    probs[1, 0] * 100,
    image_datasets["validation"].classes[idx[1, 0]],
    probs[1, 1] * 100,
    image_datasets["validation"].classes[idx[1, 1]],
    probs[1, 2] * 100,
    image_datasets["validation"].classes[idx[1, 2]]
))
axes[1].imshow(img2)
axes[1].axis("off")
결과