Giter VIP home page Giter VIP logo

patternrecognition's Introduction

Hi there 👋

patternrecognition's People

Contributors

nhk9680 avatar

Stargazers

 avatar

Watchers

 avatar  avatar

patternrecognition's Issues

Future work: Data Augmentation

Future work: Data Augmentation

Introduction

현재 kaggle에 있는 Caltech 101 데이터셋은 101개+Background_Google=102개 클래스로 이루어져 있습니다.
각 클래스당 30개의 이미지가 있어서, 로드하면 3060개의 이미지가 리스트에 저장되는 것을 모두들 확인하셨을 겁니다.

실제 Caltech 101의 데이터 양은 다음과 같습니다.

Pictures of objects belonging to 101 categories. About 40 to 800 images per category. Most categories have about 50 images. Collected in September 2003 by Fei-Fei Li, Marco Andreetto, and Marc 'Aurelio Ranzato. The size of each image is roughly 300 x 200 pixels.

이 양은 이론을 공부하는 측면에서는 부족하지 않지만, 추후 빅데이터 분석이나 딥러닝 등의 분야에서 작업을 하다 보면 기본 10만개가 넘어가는 데이터를 만나게 됩니다.

이는 기술의 발전과 유저 데이터 양산으로 가능하게 되었으며, 곧 학습 성능의 향상이 동반됩니다.

따라서 우리 데이터셋도 더 많은 양으로 늘려서 모델을 학습시키면, 정확도가 올라갈 수 있습니다.

다만, 아무리 데이터 양이 많다고 해서 성능이 exponential 하게 올라가는 것은 아니고, 또한 그만큼 학습에 소요되는 시간과 비용 또한 증가하기 때문에 적정한 선에서 타협을 보아야 합니다.


설명

저희 Google Colab Notebook의

  • Python 3 Google Compute Engine Backend (GPU)
  • RAM 25.51 GB
  • Drive 358.27 GB

를 기준으로 해보겠습니다.

증폭량을 5배부터 낮춰본 결과 3배에서 메모리 오버플로우가 일어나지 않고 진행되었습니다.

from keras.preprocessing.image import ImageDataGenerator

keras 라이브러리를 이용하였습니다.

datagen = ImageDataGenerator(
rotation_range = 40,
shear_range = 0.2,
zoom_range = 0.2,
horizontal_flip = True,
vertical_flip = True,
brightness_range = (0.5, 1.5))

이미지 회전, 자르기, 확대, 수직/수평 반전, 밝기 조절 외에도 다양한 옵션이 있습니다.

keras는 이미지를 (1, w, h, channel)의 형태로 처리하기 때문에, 이에 맞게 변형시켜 넣어줘야 합니다.

RGB의 경우:
x.shape
#(w, h, channel=3)
x = x.reshape((1, ) + x.shape)
#(1, w, h, 3)
grayscale의 경우:
x.shape
#(w, h)
x = x.reshape((1, ) + x.shape + (1, ))
#(1, w, h, 1)

본인의 데이터를 확인하고 이에 맞게 변형합니다. SIFT 추출 시에 Grayscale로 작업합니다.

x_reshape = []
train_data = np.asarray(train_data)
for i, img in enumerate(train_data):
img = img.reshape((1,) + (img.shape) + (1, ))
x_reshape.append(img)

그 다음 ImageDataGenerator.flow() 메소드를 이용하여 이미지 변형을 합니다. batch size는 한 번 반복시에 변형되는 이미지의 양입니다.

3개의 변형 케이스만 만들 것이므로, 3번째 반복마다 끊어주고 다음 이미지로 넘어갑니다.

 gray_train_data = []
gray_train_label = []
for img, label in tqdm(zip(x_reshape, train_label), total=len(train_label)):
    for i, batch_x in enumerate(datagen.flow(img, batch_size = 1)):
        gray = cv2.resize(gray, (256, 256))
        gray_train_data.append(gray)
        gray_train_label.append(label)
        if i == 3-1 :
            break

변형된 이미지의 예시입니다.

image


실행

image

kmeans를 fit할 descriptor의 사이즈는 위와 같습니다.

image

추출된 히스토그램을 SVM에 학습시킵니다. 시간은 5시간 정도 소요되었습니다.

결과는 ...

image

오버피팅이 된 듯 합니다. 그마저도 성능이 좋지 않습니다.

데이터가 극단적으로 적은 환경에서의 대처 방안을 공부했다는 것으로 의의를 남깁니다.

Reference

Caltech101

Python | Data Augmentation

전처리부터 저장까지의 과정

데이터셋은 traintest1로 나누어져 있으며, jpg 포맷의 이미지입니다.
이를 클러스터링 방법론 중 하나인 SVM을 이용하여 분류하는 방법입니다.


학습 데이터 전처리

train 폴더에는 cat.0.jpg, cat.0.jpg, ... 와 같이 정답 라벨(cat/dog).index.jpg로 구성되어 있습니다.
이를 이미지 X와 라벨 데이터 Y로 나누어 저장합니다.

for i, img in enumerate(paths.list_images(dataset_train)):

imutils 라이브러리의 paths 클래스 중 list_images 메소드를 이용하여 dataset_train을 로드합니다.

학습량을 조절하기 위해 변수 i를 추가하여 일정량(1000)이 되면 중단되도록 설정하였습니다.

이상적으로는 test 데이터보다 수가 많아야 하지만, colab notebook의 runtime 문제 등으로 인하여 일부만 포함시켰습니다.

만약 본인의 환경이 여유가 된다면 모든 train 데이터를 포함시키는 것을 권장드립니다.

# dog: 1
  Y.append(1 if 'dog' in img else 0)  
  img = cv2.imread(img)
#  print(img.shape)
  img = cv2.resize(img, (32, 32), interpolation=cv2.INTER_AREA)
#  print(img.shape)
  img = img.flatten()
#  print(img.shape)
  X.append(img)

cat은 0, dog는 1로 매핑하였으므로 이미지 파일명을 확인하여 라벨을 추가합니다.

OpenCV를 이용하여 이미지 파일을 불러온 후, resize() 메소드를 이용하여 32x32 크기로 조정합니다.

마지막으로 용이한 학습을 위해 2차원 이미지 배열을 1차원으로 변환하여 이미지 배열 X에 추가합니다.

X = np.array(X)
Y = np.array(Y)

위에서 입력받은 X, Y는 type이 list 형태이므로, array 형태로 형변환 해줍니다.

이렇게 전처리된 데이터를 학습할 때, 오버피팅을 방지하고 어느 정도의 성능이 나오는지 확인하기 위해 train set과 validation set으로 분할합니다.

X_train, X_val, Y_train, Y_val = train_test_split(
    X, Y, test_size=0.25, random_state=42)

scikit learn에서 제공해주는 train_test_split() 메소드를 이용하여 간편하게 분할해줍니다. 이 때, test size는 25%를 권장합니다. 보통 8:2, 7:3 비율도 많이 이용합니다.


학습 기반 분류기 설계

앞서 말씀드린대로 SVM(Support Vector Machine)을 적용하기 위해, scikit learn에서 제공해주는 SVC 클래스를 이용합니다.

svc = SVC(kernel='poly')
svc.fit(X_train, Y_train)

적용 가능한 SVM의 kernellinear, poly, rbf 등이 있습니다. 기본값은 rbf입니다.

svc.fit()으로 모델을 학습시킵니다.

predict = svc.predict(X_val)
print(classification_report(Y_val, predict))

학습된 모델을 이용하여 validation set을 적용해 보고, classification_report를 이용하여 precision-recall 값과 mAP(mean Average Precision) 값을 구해봅니다.

              precision    recall  f1-score   support

           0       0.58      0.66      0.62       123
           1       0.62      0.54      0.58       128

    accuracy                           0.60       251
   macro avg       0.60      0.60      0.60       251
weighted avg       0.60      0.60      0.60       251

결과는 위와 같습니다.


테스트 데이터 전처리

앞서 진행했던 것과 동일하게, 이번에는 테스트 데이터를 전처리합니다.

다만 차이점은, 테스트 데이터는 정답이 없는 index.jpg의 형태입니다.
따라서 Y 배열은 필요없고 X 배열만 선언해줍니다.
또한, 테스트 데이터는 validation이 필요 없으므로 분할하지 않고, 모두 포함합니다.


학습한 모델로 테스트하기

result = svc.predict(X_test)

앞서 학습해둔 svc에 테스트 셋을 적용하여 그 결과를 추출합니다.


csv로 저장하기

python에서 데이터 셋을 조금 더 편하게 다루게 해주는 pandas 라이브러리를 이용합니다. pandas에서는 DataFrame이라는 형태로 데이터를 처리합니다. 이는 table과 유사한 형태라고 생각하시면 됩니다.

df = pd.DataFrame(result, columns=['label'])

DataFrame을 생성하고, 그 안에 추출한 결과값을 대입합니다. 열의 이름은 label로 설정합니다.

df.index += 1
df.index.name = 'id'

제출 예시인 Sample csv파일의 규격을 맞추기 위해

  • index 값의 시작을 1로 맞추고,
  • index 열의 이름을 'id'로 설정해줍니다.
df.to_csv('result_namhun_kim-2320.csv',index=True, header=True)

마지막으로 to_csv() 메소드를 이용하여 DataFrame을 csv 포맷으로 저장합니다.

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.