TensorFlow
TensorFlow - EarlyStopping 라이브러리 사용법
Cong_S
2022. 6. 13. 12:35
# 오차를 나타낸 차트 = 학습을 반복할수록 오차가 적어짐
plt.plot(epoch_history.history['loss'])
plt.show()
이 그래프를 보면 수 백번 에포크를 진행한 이후에는 인공지능의 성능이 거의 향상되지 않는 것 같다.
성능이 더이상 향상되지 않는데 계속해서 학습을 진행하는 것은 여러모로 낭비이므로
이런 경우에 자동으로 학습이 멈추도록 코드를 수정해보자.
model = build_model()
ealry_stop = tf.keras.callbacks.EarlyStopping(monitor = 'val_loss', patience= 10)
이 때는 텐서플로우의 EarlyStopping 함수를 사용한다.
파라미터 값의 monitor에는 어떤 데이터를 모니터하면서 멈출지 말지 결정하는 것이므로
매 학습이 끝난 후 성능을 평가하는 validation 의 loss 를 설정해준 것이고
patience는 몇번의 에포크를 본 후에 멈출지를 결정하는 것이다.
위 코드를 다시 해석해보면 매 에포크마다 validation 의 loss를 보면서
10번의 에포크 동안 성능 향상이 없으면 해당 학습을 종료하겠다는 뜻이다.
model.fit(X_train, y_train, epochs = 1000000, validation_split= 0.2, callbacks=[ealry_stop])
위에서 만든 코드를 변수에 저장하고
fit의 callback 파라미터에 리스트로 넣어준다.
이 때 callback 이란? 프레임워크가 실행하는 코드, 코드 실행을 프레임워크에 맡기는 것을 말한다.
다시 말해, 텐서플로우가 모니터링을 하고 스스로 멈출수 있게 프레임워크에게 코드 실행권을 주는 것이다.
그래서 위 fit 함수를 실행하면 아래와 같은 결과가 나온다.
100만 번 중에 44번만에 학습이 종료가 된 것을 볼 수 있다.
차트의 x 값의 갯수도 44개 밖에 안되는 것을 알수 있다.