Ver fuente en GitHub

Deje de entrenar cuando una métrica monitoreada haya dejado de mejorar.

Hereda de: Callback

Ver alias

Alias ​​de compatibilidad para la migración

Ver Guía de migración para más detalles.

tf.compat.v1.keras.callbacks.EarlyStopping

tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0, patience=0, verbose=0,
    mode='auto', baseline=None, restore_best_weights=False
)

Suponiendo que el objetivo de un entrenamiento es minimizar la pérdida. Con esto, la métrica a monitorear sería 'loss'y el modo sería 'min'. A model.fit() el ciclo de entrenamiento verificará al final de cada época si la pérdida ya no está disminuyendo, considerando el min_delta y patience si es aplicable. Una vez que se encuentra que ya no disminuye, model.stop_training se marca True y el entrenamiento finaliza.

La cantidad a monitorear debe estar disponible en logs dictar Para que así sea, pase la pérdida o las métricas en model.compile().

Argumentos
monitor Cantidad a monitorear.
min_delta El cambio mínimo en la cantidad monitoreada para calificar como una mejora, es decir, un cambio absoluto de menos de min_delta, contará como ninguna mejora.
patience Número de épocas sin mejora después de las cuales se detendrá el entrenamiento.
verbose modo de verbosidad.
mode Uno de "auto", "min", "max". En min modo, el entrenamiento se detendrá cuando la cantidad monitoreada haya dejado de disminuir; en "max" modo se detendrá cuando la cantidad monitoreada haya dejado de aumentar; en "auto" modo, la dirección se infiere automáticamente del nombre de la cantidad monitoreada.
baseline Valor de referencia para la cantidad supervisada. El entrenamiento se detendrá si el modelo no muestra una mejora con respecto a la línea de base.
restore_best_weights Si restaurar los pesos del modelo desde la época con el mejor valor de la cantidad monitoreada. Si es Falso, se utilizan los pesos del modelo obtenidos en el último paso del entrenamiento.

Ejemplo:

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
# This callback will stop the training when there is no improvement in
# the validation loss for three consecutive epochs.
model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss='mse')
history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
                    epochs=10, batch_size=1, callbacks=[callback],
                    verbose=0)
len(history.history['loss'])  # Only 4 epochs are run.
4

Métodos

get_monitor_value

Ver fuente

get_monitor_value(
    logs
)

set_model

Ver fuente

set_model(
    model
)

set_params

Ver fuente

set_params(
    params
)