Hacemos una verificación completa cada enunciado en nuestra página web con el objetivo de enseñarte en todo momento información certera y actualizada.
Solución:
PyTorch no tiene una función para calcular la cantidad total de parámetros como lo hace Keras, pero es posible sumar la cantidad de elementos para cada grupo de parámetros:
pytorch_total_params = sum(p.numel() for p in model.parameters())
Si desea calcular sólo el entrenable parámetros:
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
Respuesta inspirada en esta respuesta en PyTorch Forums.
Nota: estoy respondiendo mi propia pregunta. Si alguien tiene una mejor solución, por favor comparta con nosotros.
Para obtener el recuento de parámetros de cada capa como Keras, PyTorch tiene model.named_paramters() que devuelve un iterador tanto del nombre del parámetro como del parámetro en sí.
Aquí hay un ejemplo:
from prettytable import PrettyTable
def count_parameters(model):
table = PrettyTable(["Modules", "Parameters"])
total_params = 0
for name, parameter in model.named_parameters():
if not parameter.requires_grad: continue
param = parameter.numel()
table.add_row([name, param])
total_params+=param
print(table)
print(f"Total Trainable Params: total_params")
return total_params
count_parameters(net)
La salida sería algo como esto:
+-------------------+------------+
| Modules | Parameters |
+-------------------+------------+
| embeddings.weight | 922866 |
| conv1.weight | 1048576 |
| conv1.bias | 1024 |
| bn1.weight | 1024 |
| bn1.bias | 1024 |
| conv2.weight | 2097152 |
| conv2.bias | 1024 |
| bn2.weight | 1024 |
| bn2.bias | 1024 |
| conv3.weight | 2097152 |
| conv3.bias | 1024 |
| bn3.weight | 1024 |
| bn3.bias | 1024 |
| lin1.weight | 50331648 |
| lin1.bias | 512 |
| lin2.weight | 265728 |
| lin2.bias | 519 |
+-------------------+------------+
Total Trainable Params: 56773369
Si desea calcular la cantidad de pesos y sesgos en cada capa sin instanciar el modelo, simplemente puede cargar el archivo sin procesar e iterar sobre el resultado. collections.OrderedDict
al igual que:
import torch
tensor_dict = torch.load('model.dat', map_location='cpu') # OrderedDict
tensor_list = list(tensor_dict.items())
for layer_tensor_name, tensor in tensor_list:
print('Layer : elements'.format(layer_tensor_name, torch.numel(tensor)))
Obtendrás algo como
conv1.weight: 312
conv1.bias: 26
batch_norm1.weight: 26
batch_norm1.bias: 26
batch_norm1.running_mean: 26
batch_norm1.running_var: 26
conv2.weight: 2340
conv2.bias: 10
batch_norm2.weight: 10
batch_norm2.bias: 10
batch_norm2.running_mean: 10
batch_norm2.running_var: 10
fcs.layers.0.weight: 135200
fcs.layers.0.bias: 260
fcs.layers.1.weight: 33800
fcs.layers.1.bias: 130
fcs.batch_norm_layers.0.weight: 260
fcs.batch_norm_layers.0.bias: 260
fcs.batch_norm_layers.0.running_mean: 260
fcs.batch_norm_layers.0.running_var: 260
Comentarios y calificaciones de la guía
Más adelante puedes encontrar las aclaraciones de otros usuarios, tú de igual manera puedes dejar el tuyo si lo deseas.