Saltar al contenido

Tensorflow convierte un archivo pb a TFLITE usando python

Solución:

Puede convertir a tflite directamente en Python directamente. Tienes que congelar el gráfico y usar toco_convert. Necesita que los nombres y formas de entrada y salida se determinen antes de llamar a la API, como en el caso de la línea de comandos.

Un fragmento de código de ejemplo

Copiado de la documentación, donde se define un gráfico “congelado” (sin variables) como parte de su código:

import tensorflow as tf

img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
out = tf.identity(val, name="out")
with tf.Session() as sess:
  tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
  open("test.tflite", "wb").write(tflite_model)

En el ejemplo anterior, no hay un paso de gráfico congelado ya que no hay variables. Si tiene variables y ejecuta toco sin congelar el gráfico, es decir, primero convierta esas variables en constantes, ¡entonces toco se quejará!

Si ha congelado graphdef y conoce las entradas y salidas

Entonces no necesitas la sesión. Puede llamar directamente a la API de toco:

path_to_frozen_graphdef_pb = '...'
input_tensors = [...]
output_tensors = [...]
frozen_graph_def = tf.GraphDef()
with open(path_to_frozen_graphdef_pb, 'rb') as f:
  frozen_graph_def.ParseFromString(f.read())
tflite_model = tf.contrib.lite.toco_convert(frozen_graph_def, input_tensors, output_tensors)

Si tiene graphdef no congelado y conoce las entradas y salidas

Luego debe cargar la sesión y congelar el gráfico primero antes de llamar a toco:

path_to_graphdef_pb = '...'
g = tf.GraphDef()
with open(path_to_graphdef_pb, 'rb') as f:
  g.ParseFromString(f.read())
output_node_names = ["..."]
input_tensors = [..]
output_tensors = [...]

with tf.Session(graph=g) as sess:
  frozen_graph_def = tf.graph_util.convert_variables_to_constants(
      sess, sess.graph_def, output_node_names)
# Note here we are passing frozen_graph_def obtained in the previous step to toco.
tflite_model = tf.contrib.lite.toco_convert(frozen_graph_def, input_tensors, output_tensors)

Si no conoce las entradas / salidas del gráfico

Esto puede suceder si no definió el gráfico, ej. descargó el gráfico de algún lugar o utilizó una API de alto nivel como los tf.estimators que le ocultan el gráfico. En este caso, debe cargar el gráfico y buscar para averiguar las entradas y salidas antes de llamar a toco. Vea mi respuesta a esta pregunta SO.

Esto es lo que funcionó para mí: (Modelo SSD_InceptionV2)

  1. Después de terminar el entrenamiento. Usé model_main.py de la carpeta object_detection. TFv1.11
  2. ExportGraph como TFLITE:
python /tensorflow/models/research/object_detection/export_tflite_ssd_graph.py

--pipeline_config_path annotations/ssd_inception_v2_coco.config 
--trained_checkpoint_prefix trained-inference-graphs/inference_graph_v7.pb/model.ckpt 
--output_directory trained-inference-graphs/inference_graph_v7.pb/tflite 
--max_detections 3
  1. Esto genera un archivo .pb para que pueda generar el archivo tflite de esta manera:
tflite_convert 
--output_file=test.tflite 
--graph_def_file=tflite_graph.pb 
--input_arrays=normalized_input_image_tensor 
--output_arrays="TFLite_Detection_PostProcess",'TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3'

--input_shape=1,300,300,3 
--allow_custom_ops

Ahora las entradas / salidas no estoy 100 seguro de cómo obtener esto, pero este código me ayuda antes:

import tensorflow as tf
frozen='/tensorflow/mobilenets/mobilenet_v1_1.0_224.pb'
gf = tf.GraphDef()
gf.ParseFromString(open(frozen,'rb').read())
[n.name + '=>' +  n.op for n in gf.node if n.op in ( 'Softmax','Placeholder')]    
[n.name + '=>' +  n.op for n in gf.node if n.op in ( 'Softmax','Mul')]

Siguiendo este ejemplo de TF, puede pasar el parámetro “–Saved_model_dir” para exportar el archivo Saved_model.pb y la carpeta de variables a algún directorio (ninguno existente) antes de ejecutar el script retrain.py:

python retrain.py …… –saved_model_dir /home/…./export

Para convertir su modelo a tflite, debe usar la siguiente línea:

convert_saved_model.convert(saved_model_dir="/home/.../export",output_arrays="final_result",output_tflite="/home/.../export/graph.tflite")

Nota: necesita importar convert_saved_model:

de tensorflow.contrib.lite.python importar convert_saved_model

Recuerda que puedes convertir a tflite de 2 formas:

ingrese la descripción de la imagen aquí

Pero la forma más fácil es exportar Saved_model.pb con variables en caso de que desee evitar el uso de herramientas de compilación como Bazel.

¡Haz clic para puntuar esta entrada!
(Votos: 0 Promedio: 0)



Utiliza Nuestro Buscador

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *