본문 바로가기
Coloring (Additional Study)/Contest

DCC Normalization

by 생각하는 이상훈 2023. 10. 29.
728x90

Normalization

처음에는 normalization을 imagenet에 적합하다고 알려져 있는 값인 평균: [0.485, 0.456, 0.406] 표준편차: [0.229, 0.224, 0.225]로 normalize이미지를 normalize하여 결과를 봤으나 오히려 성능이 떨어졌다. 따라서 우리가 갖고있는 데이터에 적절한 normalization 값을 찾아보고자 했다. 몇몇 방법을 사용해 봤지만 직접 계산한 값중에는 음식이 있는 위치인 중앙을 중심으로 crop하여 norm과 std를 계산한 것이 성능이 제일 좋았다. 아래는 해당 calculated mean과 std를 구하는 코드이다.

import os
import torch
from torchvision import datasets, transforms
from PIL import Image
import numpy as np
def main():
    # Data paths
    train_loader = os.path.join('kfood_train', 'train')
    valid_loader = os.path.join('kfood_val', 'val')

    # Set up transformation to only resize and convert the image to a tensor.
    # No normalization at this stage.
    transform_only_resize = transforms.Compose([
        transforms.Resize((244, 244)),
        transforms.ToTensor()
    ])

    # Load datasets
    dataset_train = datasets.ImageFolder(root=train_loader, transform=transform_only_resize)

    # DataLoader for the datasets
    dataset_train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=64, shuffle=True, num_workers=os.cpu_count())

    # Lists to store the values of pixels from the central part of the images
    pixels = []

    # Loop through the dataset to get the center pixels
    for images, _ in dataset_train_loader:
        # Crop out the center (122, 122) of the images
        center_crops = images[:, :, 61:183, 61:183]
        pixels.append(center_crops)

    # Concatenate all tensor lists to make one tensor
    all_pixels = torch.cat(pixels, dim=0)

    # Calculate the mean and std along the color channels
    mean = torch.mean(all_pixels, dim=[0, 2, 3])
    std = torch.std(all_pixels, dim=[0, 2, 3])

    print("Mean: ", mean.numpy())
    print("Std: ", std.numpy())

if __name__ == '__main__':
    main()

아래 코드를 통해 원본 이미지, 계산된 normalize값으로 normalization한 이미지, 0.5로 통일해서 normalization한 이미지 출력해서 비교해보았다.

import os
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

def main():
    # Data paths
    train_loader_path = os.path.join('kfood_train', 'train')

    simple_mean = [0.5, 0.5, 0.5]
    simple_std = [0.5, 0.5, 0.5]

    calculated_mean = [0.6048, 0.5008, 0.3549]
    calculated_std = [0.2398, 0.2449, 0.2553]

    # 정규화 변환 및 원본 이미지 변환 설정
    transform_normalized = transforms.Compose([
        transforms.Resize((244, 244)),
        transforms.ToTensor(),
        transforms.Normalize(calculated_mean, calculated_std)
    ])
    
    transform_normalized2 = transforms.Compose([
        transforms.Resize((244, 244)),
        transforms.ToTensor(),
        transforms.Normalize(simple_mean, simple_std)
    ])

    transform_original = transforms.Compose([
        transforms.Resize((244, 244)),
        transforms.ToTensor()
    ])

    # Load datasets
    dataset = datasets.ImageFolder(root=train_loader_path)
    

    # 이미지 표시하기
    plt.figure(figsize=(20, 10))

    seen_labels = set()  # 이미 본 레이블을 저장하기 위한 세트
    count = 0  # 몇 개의 이미지가 출력되었는지 카운트하기 위한 변수

    for img, label in dataset:
        if label not in seen_labels:
            seen_labels.add(label)
            
            # 원본 이미지
            plt.subplot(5, 3, 3*count+1)
            img_original = transform_original(img).permute(1, 2, 0).numpy()
            plt.imshow(img_original)
            plt.title("Original")
            plt.axis('off')
            
            # 0.5로 정규화된 이미지
            plt.subplot(5, 3, 3*count+2)
            img_normalized2 = transform_normalized2(img).permute(1, 2, 0).numpy()
            plt.imshow(img_normalized2)
            plt.title("Normalized by 0.5")
            plt.axis('off')
            
            # 계산 값으로 정규화된 이미지
            plt.subplot(5, 3, 3*count+3)
            img_normalized = transform_normalized(img).permute(1, 2, 0).numpy()
            plt.imshow(img_normalized)
            plt.title("Normalized by calculation")
            plt.axis('off')
            
            count += 1

        if count >= 5:  # 5개의 이미지가 출력되면 루프를 종료
            break

    plt.tight_layout()
    plt.show()


if __name__ == '__main__':
    main()

계산한 값으로 normalize한 이미지가 좀더 왜곡 되어 음식을 제외한 부분을 밝게 날려버리는 효과가 있어보였다. resnet18기반으로 가볍게 모델을 돌리니 성능이 더 잘나오는 모습을 볼 수 있었는데 resnet50을 이용하고 학습을 많이 진행할 수록 0.5를 이용한 normalization이 오히려 과적합을 잘 잡아내는 것을 볼 수 있었다. Normalization에 대한 이해도를 높일 수 있는 기회였다.


 

728x90