일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 |
- 객체 지향형 DB
- 백트래킹
- sync_with_stdio(0)
- Backtracking
- BOJ
- 객체관계형DB
- 네트워크형DB
- 프로그래밍 언어 기술 동향
- aws winscp
- 클라우드기반 IDE
- 소프트웨어아키텍처 기술 동향
- 소프트웨어 개발도구
- 시간초과
- 온라인처리
- 계층형DB
- 공간DB
- 멀티미디어DB
- compare
- 정렬
- 개발프레임워크
- 메인 메모리 DB
- 개발프레임워크의 기술 동향
- vector unique erase
- TOPCIT
- Flutter
- 프로그래밍 언어
- ANSI-SPARC
- boj 11659
- compare구조체
- c++
옐그's 코딩라이프
[tensor] tflite를 이용하여 과일 classification 어플 만들기 (1) (모델 만들기) 본문
Image Classification App | Deploy TensorFlow model on Android | #2 - YouTube
위의 동영상을 참고하여 작성하였습니다.
진행 순서
0. 필요한 사진 다운받기
1. 필요한 라이브러리 불러오기
2. 데이터셋 불러오기
3. 클래스 이름 지정하고 데이터 시각화하기
4. 모델 만들기
5. 컴파일하기
6. 훈련시키기
7. 테스트 데이터셋 예측하기
8. tflite 로 추출하기
0. 필요한 사진 다운받기
kaggle(로그인 필요/데이터셋), 픽사베이(사진 하나씩 다운) 등에서 필요한 사진을 다운받아주세요.
혹시 돌리는 와중에 InvalidArgumentError: Unknown image file format. One of JPEG, PNG, GIF, BMP required.라는 에러가 발생하면
https://yell0wgreen.tistory.com/27 이 글을 참고해주세요
1. 필요한 라이브러리 불러오기
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.metrics import log_loss
2. 데이터셋 불러오기
필요에 따라 validation 데이터 셋은 안넣으셔도 됩니다.
코랩에서 하다보니 런타임이 끝날 때마다 불러왔던 파일이 리셋되기 때문에 구글 드라이브에 넣어놓고 구글 드라이브를 마운트하는 식으로 불러왔습니다.
img_height, img_width = 32, 32
batch_size = 20
train_ds = tf.keras.utils.image_dataset_from_directory(
"drive/MyDrive/archive/train",
image_size = (img_height, img_width),
batch_size = batch_size
)
valid_ds = tf.keras.utils.image_dataset_from_directory(
"drive/MyDrive/archive/validation",
image_size = (img_height, img_width),
batch_size = batch_size
)
test_ds = tf.keras.utils.image_dataset_from_directory(
"drive/MyDrive/archive/test",
image_size = (img_height, img_width),
batch_size = batch_size
)
눈모양 옆에 있는 파일+드라이브 모양을 누르면 구글 드라이브를 사용할 수 있습니다.
3. 클래스 이름 지정하고 데이터 시각화하기
class_names = ["apple", "banana", "mandarine", "orange", "tangelo"]
plt.figure(figsize=(10,10))
for images, labels in train_ds.take(1):
for i in range(9):
ax = plt.subplot(3, 3, i+1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.title(class_nam
class 이름을 하나하나 지정해주었지만, 아래와 같이 할 경우 알파벳 순서로 정렬되어 있는 디렉토리 이름이 class 이름이 됩니다.
class_names = train_ds.class_names
4. 모델 만들기
model = tf.keras.Sequential(
[
tf.keras.layers.Rescaling(1./255),
tf.keras.layers.Conv2D(32, 3, activation="relu"),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Conv2D(32, 3, activation="relu"),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Conv2D(32, 3, activation="relu"),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dense(5, activation = "softmax") # class 갯수
]
)
처음 tf.keras.Rescvaling(1./255)는 rgb채널값 255을 표준화해주는 과정이다.
확률을 output으로 내는 모델이 목적이어서 마지막줄은 activation fuction을 softmax로 넣어주었다.
맨 마지막에 5 대신 자신의 클래스 갯수에 맞는 숫자를 넣어주면 된다.
5. 컴파일하기
model.compile(
optimizer = "adam",
loss = tf.losses.SparseCategoricalCrossentropy(from_logits = True),
metrics = ['accuracy']
)
6. 훈련시키기
model.fit(
train_ds,
validation_data = valid_ds,
epochs = 10
)
위에서 validation 데이터셋을 따로 만들지 않으신 분들은 validation_data = valid_ds, 을 주석처리하거나 지우셔서 하시면 됩니다.
7. 테스트 데이터셋 예측하기
y_pred = model.predict(test_ds)
y_pred
8. tflite 로 추출하기
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open("model.tflite", 'wb') as f:
f.write(tflite_model)
이 과정은 어플에서 텐서로 만든 인공지능 모델을 사용하고 싶으신 분만 진행해주시면 됩니다.
'AI' 카테고리의 다른 글
[tensor] InvalidArgumentError : Unknown image file format. One of JPEG, PNG, GIF, BMP required (0) | 2022.11.28 |
---|---|
[conda] activate CommandNotFoundError 해결하기 (0) | 2022.09.07 |