Saltar al contenido

Cómo usar tf.reset_default_graph ()

Solución:

Probablemente así es como lo usa:

import tensorflow as tf
a = tf.constant(1)
with tf.Session() as sess:
    tf.reset_default_graph()

Obtiene un error porque lo usa en una sesión. Desde el tf.reset_default_graph() documentación:

Llamar a esta función mientras un tf.Session o tf.InteractiveSession está activo dará como resultado un comportamiento indefinido. El uso de cualquier objeto tf.Operation o tf.Tensor creado previamente después de llamar a esta función dará como resultado un comportamiento indefinido


tf.reset_default_graph() puede ser útil (al menos para mí) durante la fase de prueba mientras experimento en el cuaderno jupyter. Sin embargo, nunca lo he usado en producción y no veo cómo sería útil allí.

Aquí hay un ejemplo que podría estar en un cuaderno:

import tensorflow as tf
# create some graph
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(...)

Ahora ya no necesito estas cosas, pero si creo otro gráfico y lo visualizo en el tensorboard, veré los nodos antiguos y los nuevos. Para resolver esto, podría reiniciar el kernel y ejecutar solo la siguiente celda. Sin embargo, solo puedo hacer:

tf.reset_default_graph()
# create a new graph
with tf.Session() as sess:
    print sess.run(...)

Editar después de que OP agregó su código:

with tf.name_scope("predict"):
    tf.reset_default_graph()

Esto es lo que sucede aproximadamente. Tu código falla porque tf.name_scope ya agregué algo a un gráfico. Mientras está dentro de este “agregar algo al gráfico”, le dice a TF que elimine el gráfico por completo, pero no puede porque está ocupado agregando algo.

Por alguna razón, necesito construir un nuevo gráfico DURANTE MUCHAS VECES, y acabo de probarlo, ¡que finalmente funciona! Muchas gracias por la respuesta de Salvador Dali 🙂

import tensorflow as tf
from my_models import Classifier

for i in range(10):
    tf.reset_default_graph()
    # build the graph
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
    classifier = Classifier(global_step)
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        print("do sth here.")
¡Haz clic para puntuar esta entrada!
(Votos: 0 Promedio: 0)



Utiliza Nuestro Buscador

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *