Saltar al contenido

Comprobar el número total de parámetros en un modelo de PyTorch

Hacemos una verificación completa cada enunciado en nuestra página web con el objetivo de enseñarte en todo momento información certera y actualizada.

Solución:

PyTorch no tiene una función para calcular la cantidad total de parámetros como lo hace Keras, pero es posible sumar la cantidad de elementos para cada grupo de parámetros:

pytorch_total_params = sum(p.numel() for p in model.parameters())

Si desea calcular sólo el entrenable parámetros:

pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

Respuesta inspirada en esta respuesta en PyTorch Forums.

Nota: estoy respondiendo mi propia pregunta. Si alguien tiene una mejor solución, por favor comparta con nosotros.

Para obtener el recuento de parámetros de cada capa como Keras, PyTorch tiene model.named_paramters() que devuelve un iterador tanto del nombre del parámetro como del parámetro en sí.

Aquí hay un ejemplo:

from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        param = parameter.numel()
        table.add_row([name, param])
        total_params+=param
    print(table)
    print(f"Total Trainable Params: total_params")
    return total_params
    
count_parameters(net)

La salida sería algo como esto:

+-------------------+------------+
|      Modules      | Parameters |
+-------------------+------------+
| embeddings.weight |   922866   |
|    conv1.weight   |  1048576   |
|     conv1.bias    |    1024    |
|     bn1.weight    |    1024    |
|      bn1.bias     |    1024    |
|    conv2.weight   |  2097152   |
|     conv2.bias    |    1024    |
|     bn2.weight    |    1024    |
|      bn2.bias     |    1024    |
|    conv3.weight   |  2097152   |
|     conv3.bias    |    1024    |
|     bn3.weight    |    1024    |
|      bn3.bias     |    1024    |
|    lin1.weight    |  50331648  |
|     lin1.bias     |    512     |
|    lin2.weight    |   265728   |
|     lin2.bias     |    519     |
+-------------------+------------+
Total Trainable Params: 56773369

Si desea calcular la cantidad de pesos y sesgos en cada capa sin instanciar el modelo, simplemente puede cargar el archivo sin procesar e iterar sobre el resultado. collections.OrderedDict al igual que:

import torch


tensor_dict = torch.load('model.dat', map_location='cpu') # OrderedDict
tensor_list = list(tensor_dict.items())
for layer_tensor_name, tensor in tensor_list:
    print('Layer :  elements'.format(layer_tensor_name, torch.numel(tensor)))

Obtendrás algo como

conv1.weight: 312
conv1.bias: 26
batch_norm1.weight: 26
batch_norm1.bias: 26
batch_norm1.running_mean: 26
batch_norm1.running_var: 26
conv2.weight: 2340
conv2.bias: 10
batch_norm2.weight: 10
batch_norm2.bias: 10
batch_norm2.running_mean: 10
batch_norm2.running_var: 10
fcs.layers.0.weight: 135200
fcs.layers.0.bias: 260
fcs.layers.1.weight: 33800
fcs.layers.1.bias: 130
fcs.batch_norm_layers.0.weight: 260
fcs.batch_norm_layers.0.bias: 260
fcs.batch_norm_layers.0.running_mean: 260
fcs.batch_norm_layers.0.running_var: 260

Comentarios y calificaciones de la guía

Más adelante puedes encontrar las aclaraciones de otros usuarios, tú de igual manera puedes dejar el tuyo si lo deseas.

¡Haz clic para puntuar esta entrada!
(Votos: 0 Promedio: 0)



Utiliza Nuestro Buscador

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *