Ver fuente en GitHub

Gestiona guardar / restaurar valores rastreables en el disco.

tf.train.Checkpoint(
    root=None,**kwargs
)

Los objetos de TensorFlow pueden contener un estado rastreable, como tf.Variables, tf.keras.optimizers.Optimizer implementaciones, tf.data.Dataset iteradores, tf.keras.Layer implementaciones, o tf.keras.Model implementaciones. Estos se llaman objetos rastreables.

A Checkpoint El objeto se puede construir para guardar uno o un grupo de objetos rastreables en un archivo de punto de control. Mantiene un save_counter para numerar los puntos de control.

Ejemplo:

model = tf.keras.Model(...)
checkpoint = tf.train.Checkpoint(model)# Save a checkpoint to /tmp/training_checkpoints-save_counter. Every time# checkpoint.save is called, the save counter is increased.
save_path = checkpoint.save('/tmp/training_checkpoints')# Restore the checkpointed values to the `model` object.
checkpoint.restore(save_path)

Ejemplo 2:

import tensorflow as tf
import os

checkpoint_directory ="/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory,"ckpt")# Create a Checkpoint that will manage two objects with trackable state,# one we name "optimizer" and the other we name "model".
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))for _ inrange(num_training_steps):
  optimizer.minimize(...)# Variables will be restored on creation.
status.assert_consumed()# Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)

Checkpoint.save() y Checkpoint.restore() escribir y leer puntos de control basados ​​en objetos, en contraste con los de TensorFlow 1.x tf.compat.v1.train.Saver que escribe y lee variable.name puntos de control basados. El punto de control basado en objetos guarda un gráfico de dependencias entre objetos de Python (Layers, Optimizers, Variables, etc.) con bordes con nombre, y este gráfico se usa para hacer coincidir variables al restaurar un punto de control. Puede ser más robusto a los cambios en el programa Python y ayuda a admitir la restauración al crear para las variables.

Checkpoint los objetos tienen dependencias de los objetos pasados ​​como argumentos de palabra clave a sus constructores, y cada dependencia recibe un nombre que es idéntico al nombre del argumento de palabra clave para el que se creó. Clases de TensorFlow como Layerarena Optimizers agregará automáticamente dependencias en sus propias variables (por ejemplo, “kernel” y “sesgo” para tf.keras.layers.Dense). Heredando de tf.keras.Model facilita la gestión de dependencias en clases definidas por el usuario, ya que Model se engancha en la asignación de atributos. Por ejemplo:

classRegress(tf.keras.Model):def__init__(self):super(Regress, self).__init__()
    self.input_transform = tf.keras.layers.Dense(10)# ...defcall(self, inputs):
    x = self.input_transform(inputs)# ...

Esta Model tiene una dependencia llamada “input_transform” en su Dense capa, que a su vez depende de sus variables. Como resultado, guardar una instancia de Regress utilizando tf.train.Checkpoint también guardará todas las variables creadas por el Dense capa.

Cuando se asignan variables a varios trabajadores, cada trabajador escribe su propia sección del punto de control. Estas secciones luego se fusionan / reindexan para comportarse como un único punto de control. Esto evita copiar todas las variables a un trabajador, pero requiere que todos los trabajadores vean un sistema de archivos común.

Esta función difiere ligeramente del modelo Keras save_weights función. tf.keras.Model.save_weights crea un archivo de punto de control con el nombre especificado en filepath, tiempo tf.train.Checkpoint numera los puntos de control, usando filepath como prefijo para los nombres de archivo de punto de control. Aparte de esto, model.save_weights() y tf.train.Checkpoint(model).save() son equivalentes.

Ver el guía de puntos de control de formación para detalles.

Args
root El objeto raíz al punto de control.
**kwargs Los argumentos de palabras clave se establecen como atributos de este objeto y se guardan con el punto de control. Los valores deben ser objetos rastreables.
Eleva
ValueError Si root o los objetos en kwargs no son rastreables. A ValueError también se eleva si el root objeto rastrea diferentes objetos de los enumerados en atributos en kwargs (p. ej. root.child = A y tf.train.Checkpoint(root, child=B) son incompatibles).
Atributos
save_counter Incrementado cuando save() se llama. Se usa para numerar los puntos de control.

Métodos

read

Ver fuente

read(
    save_path, options=None)

Lee un punto de control de entrenamiento escrito con write.

Lee esto Checkpoint y cualquier objeto del que dependa.

Este método es como restore() pero no espera el save_counter variable en el puesto de control. Solo restaura los objetos de los que ya depende el punto de control.

El método está diseñado principalmente para ser utilizado por utilidades de gestión de puntos de control de nivel superior que utilizan write() en lugar de save() y tienen sus propios mecanismos para numerar y rastrear los puntos de control.

Uso de ejemplo:

# Create a checkpoint with write()
ckpt = tf.train.Checkpoint(v=tf.Variable(1.))
path = ckpt.write('/tmp/my_checkpoint')# Later, load the checkpoint with read()# With restore() assert_consumed() would have failed.
checkpoint.read(path).assert_consumed()# You can also pass options to read(). For example this# runs the IO ops on the localhost:
options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.read(path, options=options)
Args
save_path La ruta al punto de control devuelta por write.
options Opcional tf.train.CheckpointOptions objeto.
Devoluciones
Un objeto de estado de carga, que se puede utilizar para realizar afirmaciones sobre el estado de una restauración de punto de control. Ver restore para detalles.

restore

Ver fuente

restore(
    save_path, options=None)

Restaura un punto de control de entrenamiento.

Restaura esto Checkpoint y cualquier objeto del que dependa.

Este método está diseñado para cargar puntos de control creados por save(). Para los puestos de control creados por write() utilizar el read() método que no espera el save_counter variable agregada por save().

restore() asigna valores inmediatamente si las variables para restaurar ya se han creado, o pospone la restauración hasta que se crean las variables. Las dependencias agregadas después de esta llamada se compararán si tienen un objeto correspondiente en el punto de control (la solicitud de restauración se pondrá en cola en cualquier objeto rastreable esperando que se agregue la dependencia esperada).

Para asegurarse de que la carga esté completa y no se realizarán más asignaciones, use el assert_consumed() método del objeto de estado devuelto por restore():

checkpoint = tf.train.Checkpoint(...)
checkpoint.restore(path).assert_consumed()# You can additionally pass options to restore():
options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.restore(path, options=options).assert_consumed()

Se generará una excepción si no se encontraron objetos de Python en el gráfico de dependencia en el punto de control, o si algún valor con puntos de control no tiene un objeto de Python coincidente.

Basado en nombre tf.compat.v1.train.Saver Los puntos de control de TensorFlow 1.x se pueden cargar con este método. Los nombres se utilizan para hacer coincidir las variables. Vuelva a codificar los puntos de control basados ​​en nombres utilizando tf.train.Checkpoint.save tan pronto como sea posible.

Cargando desde puntos de control de modelos guardados

Para cargar valores de un modelo guardado, simplemente pase el directorio modelo guardado a checkpoint.restore:

model = tf.keras.Model(...)
tf.saved_model.save(model, path)# or model.save(path, save_format='tf')

checkpoint = tf.train.Checkpoint(model)
checkpoint.restore(path).expect_partial()

Este ejemplo llama expect_partial() en el estado cargado, ya que SavedModels guardados de Keras a menudo genera claves adicionales en el punto de control. De lo contrario, el programa imprime muchas advertencias sobre las claves no utilizadas en el momento de la salida.

Args
save_path La ruta al punto de control, según lo devuelto por save o tf.train.latest_checkpoint. Si el punto de control fue escrito por el nombre tf.compat.v1.train.Saver, los nombres se utilizan para hacer coincidir las variables. Esta ruta también puede ser un directorio de modelo guardado.
options Opcional tf.train.CheckpointOptions objeto.
Devoluciones
Un objeto de estado de carga, que se puede utilizar para hacer afirmaciones sobre el estado de una restauración de punto de control.

El objeto de estado devuelto tiene los siguientes métodos:

  • assert_consumed(): Genera una excepción si alguna variable no coincide: valores de puntos de control que no tienen un objeto de Python coincidente o objetos de Python en el gráfico de dependencia sin valores en el punto de control. Este método devuelve el objeto de estado y, por lo tanto, puede estar encadenado con otras aserciones.

  • assert_existing_objects_matched(): Genera una excepción si alguno de los objetos de Python existentes en el gráfico de dependencia no tiene comparación. diferente a assert_consumed, esta aserción pasará si los valores en el punto de control no tienen objetos Python correspondientes. Por ejemplo un tf.keras.Layer objeto que aún no se ha construido, y por lo tanto no ha creado ninguna variable, pasará esta aserción pero fallará assert_consumed. Útil cuando se carga parte de un punto de control más grande en un nuevo programa de Python, por ejemplo, un punto de control de entrenamiento con un tf.compat.v1.train.Optimizer se guardó pero solo se está cargando el estado requerido para la inferencia. Este método devuelve el objeto de estado y, por lo tanto, puede estar encadenado con otras aserciones.

  • assert_nontrivial_match(): Afirma que se coincidió con algo aparte del objeto raíz. Esta es una afirmación muy débil, pero es útil para verificar la cordura en el código de la biblioteca donde pueden existir objetos en el punto de control que no se han creado en Python y algunos objetos de Python pueden no tener un valor de punto de control.

  • expect_partial(): Silenciar las advertencias sobre restauraciones incompletas de puntos de control. De lo contrario, las advertencias se imprimen para las partes no utilizadas del archivo u objeto del punto de control cuando Checkpoint se elimina el objeto (a menudo al cerrar el programa).

Eleva
NotFoundError si no se puede encontrar un punto de control o modelo guardado en save_path.

save

Ver fuente

save(
    file_prefix, options=None)

Guarda un punto de control de formación y proporciona una gestión básica del punto de control.

El punto de control guardado incluye variables creadas por este objeto y cualquier objeto rastreable del que depende en ese momento Checkpoint.save() se llama.

save es un envoltorio de conveniencia básica alrededor del write método, numerando secuencialmente los puntos de control utilizando save_counter y actualizar los metadatos utilizados por tf.train.latest_checkpoint. La administración de puntos de control más avanzada, por ejemplo, recolección de basura y numeración personalizada, puede ser proporcionada por otras utilidades que también write y read. (tf.train.CheckpointManager por ejemplo).

step = tf.Variable(0, name="step")
checkpoint = tf.Checkpoint(step=step)
checkpoint.save("/tmp/ckpt")#Later, read the checkpoint with restore()
checkpoint.restore("/tmp/ckpt").assert_consumed()#You can also pass options to save()andrestore(). For example this#runsthe IO ops on the localhost:
options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.save("/tmp/ckpt", options=options)#Later, read the checkpoint with restore()
checkpoint.restore("/tmp/ckpt", options=options).assert_consumed()
Args
file_prefix Un prefijo que se utilizará para los nombres de archivo del punto de control (/ ruta / al / directorio / y_a_prefijo). Los nombres se generan basándose en este prefijo y Checkpoint.save_counter.
options Opcional tf.train.CheckpointOptions objeto.
Devoluciones
La ruta completa al punto de control.

write

Ver fuente

write(
    file_prefix, options=None)

Escribe un punto de control de entrenamiento.

El punto de control incluye variables creadas por este objeto y cualquier objeto rastreable del que depende en ese momento Checkpoint.write() se llama.

write no numera los puntos de control, incrementa save_countero actualice los metadatos utilizados por tf.train.latest_checkpoint. Está destinado principalmente a ser utilizado por utilidades de gestión de puntos de control de nivel superior. save proporciona una implementación muy básica de estas características.

Puntos de control escritos con write debe leerse con read.

Uso de ejemplo:

step = tf.Variable(0, name="step")
checkpoint = tf.Checkpoint(step=step)
checkpoint.write("/tmp/ckpt")#Later, read the checkpoint with read()
checkpoint.read("/tmp/ckpt").assert_consumed()#You can also pass options to write()andread(). For example this#runsthe IO ops on the localhost:
options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.write("/tmp/ckpt", options=options)#Later, read the checkpoint with read()
checkpoint.read("/tmp/ckpt", options=options).assert_consumed()
Args
file_prefix Un prefijo que se utilizará para los nombres de archivo del punto de control (/ ruta / al / directorio / y_a_prefijo).
options Opcional tf.train.CheckpointOptions objeto.
Devoluciones
La ruta completa al punto de control (es decir, file_prefix).