Saltar al contenido

Obtén el valor de algunos pesos en un modelo entrenado por TensorFlow

Solución:

En TensorFlow, los pesos entrenados están representados por tf.Variable objetos. Si creaste un tf.Variable—Eg llamó v– usted mismo, puede obtener su valor como una matriz NumPy llamando sess.run(v) (dónde sess es un tf.Session).

Si actualmente no tiene un puntero al tf.Variable, puede obtener una lista de las variables entrenables en el gráfico actual llamando tf.trainable_variables(). Esta función devuelve una lista de todos los entrenables. tf.Variable objetos en el gráfico actual, y puede seleccionar el que desee haciendo coincidir el v.name propiedad. Por ejemplo:

# Desired variable is called "tower_2/filter:0".
var = [v for v in tf.trainable_variables() if v.name == "tower_2/filter:0"][0]

Respuesta compatible 2.0: Si construimos un modelo usando Keras Sequential API, podemos obtener los pesos del modelo usando el código que se menciona a continuación:

!pip install tensorflow==2.1

from tf.keras import Sequential

model = Sequential()

model.add(Conv2D(filters=conv1_fmaps, kernel_size=conv1_ksize,
                         strides=conv1_stride, padding=conv1_pad,
                         activation=tf.nn.relu, input_shape=(height, width, channels),
                    data_format="channels_last"))

model.add(MaxPool2D(pool_size = (2,2), strides= (2,2), padding="VALID"))

model.add(Dropout(0.25))

model.add(Flatten())

model.add(Dense(units = 32, activation = 'relu'))

model.add(Dense(units = 10, activation = 'softmax'))

model.summary()

print(model.trainable_variables) 

La última declaración, print(model.trainable_variables), devolverá los pesos del modelo como se muestra a continuación:

    [<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 32) dtype=float32>,
 <tf.Variable 'conv2d/bias:0' shape=(32,) dtype=float32>, <tf.Variable 
'dense/kernel:0' shape=(6272, 32) dtype=float32>, <tf.Variable 'dense/bias:0' 
shape=(32,) dtype=float32>, <tf.Variable 'dense_1/kernel:0' shape=(32, 10) 
dtype=float32>, <tf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32>]

Entonces, si continúa con este código paso a paso, primero obtendrá una lista de variables usadas / entrenables. Luego, podría ordenarlos en una lista en la que clasifique las matrices / listas de peso por nombres de variables, solo por ejemplo, cómo fue posible tratar con esa información.

vars = tf.trainable_variables()
print(vars) #some infos about variables...
vars_vals = sess.run(vars)
for var, val in zip(vars, vars_vals):
    print("var: {}, value: {}".format(var.name, val)) #...or sort it in a list....
¡Haz clic para puntuar esta entrada!
(Votos: 0 Promedio: 0)


Tags : /

Utiliza Nuestro Buscador

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *