티스토리 뷰
Machine Learning/Practice
[keras, TF2.0] 케라스를 사용한 분산 훈련 (MirroredStrategy)
JG Ahn 2020. 2. 24. 18:38케라스를 사용한 분산 훈련
- Licensed under the Apache License, Version 2.0 (the "License")
- MIT License
- https://www.tensorflow.org/tutorials/distribute/keras
1. 개요
tf.distribute.Strategy
API는 훈련을 여러 처리 장치들로 분산시키는 것을 추상화한 것이다기존의 모델이나 훈련 코드를 조금만 바꾸어 분산 훈련을 할 수 있게 하는 것이 목표이다
이 튜토리얼에서는
tf.distribute.MirroredStrategy
를 사용합니다.- 이 전략은 모델의 모든 변수를 각 프로세서에 복사합니다.
- 그리고 각 프로세서의 그래디언트를 All-Reduce를 사용하여 모읍니다.
- 그다음 모아서 계산한 값을 각 프로세서의 모델 복사본에 적용합니다
MirroredStrategy
는 텐서플로에서 기본으로 제공하는 몇 가지 분산 전략 중 하나입니다. 다른 전략들에 대해서는 분산 전략 가이드를 참고하세요.
2. 필요한 패키지
from __future__ import absolute_import, division, print_function, unicode_literals
# 텐서플로와 텐서플로 데이터셋 패키지 가져오기
!pip install tensorflow-gpu==2.0.0-rc1
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()
import os
3. 데이터셋 다운로드
- MNIST 데이터셋을 사용합니다.
- with_info를 True로 설정하면 전체 데이터에 대한 메타 정보도 함께 불러옵니다.
- 이 정보는
info
변수에 저장됩니다. 훈련과 테스트 샘플 수를 비롯한 여러 가지 정보들이 들어있습니다.
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
4. 분산 전략 정의하기
- 분산과 관련된 처리를 하는
MirroredStrategy
객체를 만듭니다 - 이 객체가 컨텍스트 관리자(
tf.distribute.MirroredStrategy.scope
)도 제공하는데, 이 안에서 모델을 만들어야 합니다.
strategy = tf.distribute.MirroredStrategy()
print('장치의 수: {}'.format(strategy.num_replicas_in_sync))
5. 입력 파이프라인 구성하기
- 다중 GPU로 모델을 훈련할 때는 배치 크기를 늘려야 컴퓨팅 자원을 효과적으로 사용할 수 있다.
- 기본적으로 GPU 메모리에 맞추어 가능한 가장 큰 배치를 사용하고 이에 맞게 학습률도 조정해야 한다
# 데이터셋 내 샘플의 수는 info.splits.total_num_examples 로도
# 얻을 수 있습니다.
num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
- 픽셀 값은 0
255 사이이므로 01 범위로 정규화해야 합니다.
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255
return image, label
- 함수를 데이터에 적용합니다. 데이터를 셔플 하고 배치로 묶습니다
train_dataset = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
6. 모델 만들기
strategy.scope
컨텍스트 안에서 케라스 모델을 만들고 컴파일합니다.
with strategy.scope():
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(loss='sparse_categorical_crossentropy',
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])
7. 콜백 정의하기
- 여기서 사용하는 콜백은 다음과 같습니다
- 텐서 보드(TensorBoard): 이 콜백은 텐서 보드용 로그를 남겨서, 텐서 보드에서 그래프를 그릴 수 있게 해 줍니다.
- 모델 체크포인트(Checkpoint): 이 콜백은 매 에포크(epoch)가 끝난 후 모델을 저장합니다.
- 학습률 스케줄러: 이 콜백을 사용하면 매 에포크 혹은 배치가 끝난 후 학습률을 바꿀 수 있습니다.
- 콜백을 추가하는 방법을 보여드리기 위하여 노트북에 _학습률_을 표시하는 콜백도 추가합니다
# 체크포인트를 저장할 체크포인트 디렉터리를 지정합니다.
checkpoint_dir = './training_checkpoints'
# 체크포인트 파일의 이름
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# 학습률을 점점 줄이기 위한 함수
# 필요한 함수를 직접 정의하여 사용할 수 있습니다.
def decay(epoch):
if epoch < 3:
return 1e-3
elif epoch >= 3 and epoch < 7:
return 1e-4
else:
return 1e-5
# 에포크가 끝날 때마다 학습률(lr)을 출력하는 콜백.
class PrintLR(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
print('\n에포크 {}의 학습률은 {}입니다.'.format(epoch + 1,
model.optimizer.lr.numpy()))
callbacks = [
tf.keras.callbacks.TensorBoard(log_dir='./logs'),
tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
save_weights_only=True),
tf.keras.callbacks.LearningRateScheduler(decay),
PrintLR()
]
8. 훈련과 평가
- 모델의
fit
함수를 호출하고 튜토리얼의 시작 부분에서 만든 데이터셋을 넘깁니다. - 이 단계는 분산 훈련 여부와 상관없이 동일합니다.
model.fit(train_dataset, epochs=12, callbacks=callbacks)
에포크 1의 학습률은 0.0010000000474974513입니다.
938/938 [==============================] - 27s 29ms/step - loss: 0.2018 - accuracy: 0.9404
Epoch 2/12
926/938 [============================>.] - ETA: 0s - loss: 0.0676 - accuracy: 0.9799
에포크 2의 학습률은 0.0010000000474974513입니다.
- 체크포인트 디렉터리를 확인하면 체크포인트가 잘 저장되고 있는 것을 볼 수 있습니다
- 모델의 성능이 어떤지 확인하기 위하여, 가장 최근 체크포인트를 불러온 후 테스트 데이터에 대하여
evaluate
를 호출합니다.
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
eval_loss, eval_acc = model.evaluate(eval_dataset)
print('평가 손실: {}, 평가 정확도: {}'.format(eval_loss, eval_acc))
텐서 보드 로그를 다운로드한 후 터미널에서 다음과 같이 텐서 보드를 실행하여 훈련 결과를 확인할 수 있습니다.
$ tensorboard --logdir=path/to/log-directory
9. SavedModel로 내보내기
- 플랫폼에 무관한 SavedModel 형식으로 그래프와 변수들을 내보냅니다.
- 모델을 내보낸 후에는, 전략 범위(scope) 없이 불러올 수도 있고, 전략 범위와 함께 불러올 수도 있습니다.
path = 'saved_model/'
tf.keras.experimental.export_saved_model(model, path)
strategy.scope
없이 모델 불러오기.
unreplicated_model = tf.keras.experimental.load_from_saved_model(path)
unreplicated_model.compile(
loss='sparse_categorical_crossentropy',
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])
eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)
print('평가 손실: {}, 평가 정확도: {}'.format(eval_loss, eval_acc))
strategy.scope
와 함께 모델 불러오기.
with strategy.scope():
replicated_model = tf.keras.experimental.load_from_saved_model(path)
replicated_model.compile(loss='sparse_categorical_crossentropy',
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])
eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
print ('평가 손실: {}, 평가 정확도: {}'.format(eval_loss, eval_acc))
10. 예제와 튜토리얼
- 분산 전략을 쓰는 예제들이 더 있습니다.
tf.distribute.MirroredStrategy
를 사용하여 학습한 Transformer 예제.tf.distribute.MirroredStrategy
를 사용하여 학습한 NCF 예제.
'Machine Learning > Practice' 카테고리의 다른 글
[tf, keras] 모델 시각화 (netron) (0) | 2020.07.03 |
---|---|
Keras 소개 및 가이드 (0) | 2020.02.24 |
[keras, TF2.0] 신경망 훈련하기:기초적인 분류 문제(Fashion MNIST) (0) | 2020.02.24 |
[keras, TF2.0] 온도 데이터, 시계열 예측하기 (Time Series Forecasting) (12) | 2020.02.03 |
[Keras, TF2.0] 딥러닝 모델에서 효율적인 입력 파이프라인 만들기 (1) | 2020.02.02 |
댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
- Total
- Today
- Yesterday
TAG
- 인공지능 스피커 호출
- 핵심어 검출
- stft
- aws cli
- Tensorflow2.0
- 시계열
- LSTM
- lambda
- netron
- nlp 트렌드
- keras
- librosa
- 모델 시각화
- AWS
- TF2.0
- nlg
- Introduction to Algorithm
- 알고리즘 강의
- boto3
- tensorflow
- BOJ
- nlp
- MIT
- 6.006
- wavenet
- 오디오 전처리
- MFCC
- S3
- 알고리즘
- RNN
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
글 보관함