Saltar al contenido

¿La mejor manera de guardar un modelo entrenado en PyTorch?

Entiende el código de forma correcta antes de usarlo a tu trabajo si tquieres aportar algo puedes dejarlo en los comentarios.

Solución:

Encontré esta página en su repositorio de github, simplemente pegaré el contenido aquí.


Enfoque recomendado para guardar un modelo

Hay dos enfoques principales para serializar y restaurar un modelo.

El primero (recomendado) guarda y carga solo los parámetros del modelo:

torch.save(the_model.state_dict(), PATH)

Entonces despúes:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

El segundo guarda y carga todo el modelo:

torch.save(the_model, PATH)

Entonces despúes:

the_model = torch.load(PATH)

Sin embargo, en este caso, los datos serializados están vinculados a las clases específicas y la estructura de directorio exacta utilizada, por lo que pueden romperse de varias maneras cuando se usan en otros proyectos o después de algunas refactorizaciones serias.

Depende de lo que quieras hacer.

Caso #1: Guarde el modelo para usarlo usted mismo para la inferencia: guarda el modelo, lo restaura y luego cambia el modelo al modo de evaluación. Esto se hace porque normalmente tiene BatchNorm y Dropout capas que por defecto están en modo tren en construcción:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Caso #2: Guardar modelo para retomar el entrenamiento más tarde: si necesita seguir entrenando el modelo que está a punto de guardar, necesita guardar más que solo el modelo. También debe guardar el estado del optimizador, las épocas, la puntuación, etc. Lo haría así:

state = 
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...

torch.save(state, filepath)

Para reanudar el entrenamiento harías cosas como: state = torch.load(filepath)y luego, para restaurar el estado de cada objeto individual, algo como esto:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Ya que estás reanudando el entrenamiento, NO llamar model.eval() una vez que restablezca los estados al cargar.

Caso # 3: Modelo para ser utilizado por otra persona sin acceso a su código: En Tensorflow puedes crear un .pb archivo que define tanto la arquitectura como los pesos del modelo. Esto es muy útil, especialmente cuando se usa Tensorflow serve. La forma equivalente de hacer esto en Pytorch sería:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

Esta forma todavía no es a prueba de balas y dado que pytorch todavía está experimentando muchos cambios, no lo recomendaría.

La biblioteca pickle Python implementa protocolos binarios para serializar y deserializar un objeto de Python.

Cuando usted import torch (o cuando usa PyTorch) lo hará import pickle para ti y no necesitas llamar pickle.dump() y pickle.load() directamente, cuáles son los métodos para guardar y cargar el objeto.

De hecho, torch.save() y torch.load() envolverá pickle.dump() y pickle.load() para usted.

UN state_dict la otra respuesta mencionada merece solo unas pocas notas más.

Qué state_dict Qué tenemos dentro de PyTorch? en realidad hay dos state_dicts.

El modelo PyTorch es torch.nn.Module posee model.parameters() llame para obtener parámetros aprendibles (w y b). Estos parámetros de aprendizaje, una vez establecidos aleatoriamente, se actualizarán con el tiempo a medida que aprendamos. Los parámetros que se pueden aprender son los primeros state_dict.

El segundo state_dict es el dict de estado del optimizador. Recuerda que el optimizador se utiliza para mejorar nuestros parámetros de aprendizaje. Pero el optimizador state_dict está arreglado. Nada que aprender allí.

Porque state_dict Los objetos son diccionarios de Python, se pueden guardar, actualizar, modificar y restaurar fácilmente, lo que agrega una gran cantidad de modularidad a los modelos y optimizadores de PyTorch.

Vamos a crear un modelo súper simple para explicar esto:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "t", optimizer.state_dict()[var_name])

Este código generará lo siguiente:

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    
param_groups     ['lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]]

Tenga en cuenta que este es un modelo mínimo. Puede intentar agregar una pila de secuencias

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

Tenga en cuenta que solo las capas con parámetros aprendibles (capas convolucionales, capas lineales, etc.) y los búferes registrados (capas de normas por lotes) tienen entradas en el modelo. state_dict.

Cosas que no se pueden aprender, pertenecen al objeto del optimizador state_dictque contiene información sobre el estado del optimizador, así como los hiperparámetros utilizados.

El resto de la historia es la misma; en la fase de inferencia (esta es una fase en la que usamos el modelo después del entrenamiento) para predecir; predecimos basándonos en los parámetros que aprendimos. Entonces, para la inferencia, solo necesitamos guardar los parámetros. model.state_dict().

torch.save(model.state_dict(), filepath)

Y para usar más tarde model.load_state_dict(torch.load(filepath)) model.eval()

Nota: No olvides la última línea. model.eval() esto es crucial después de cargar el modelo.

Tampoco trates de guardar torch.save(model.parameters(), filepath). Él model.parameters() es solo el objeto generador.

Por otro lado, torch.save(model, filepath) guarda el objeto del modelo en sí, pero tenga en cuenta que el modelo no tiene el optimizador state_dict. Verifique la otra excelente respuesta de @Jadiel de Armas para guardar el dictado de estado del optimizador.

Eres capaz de asentar nuestra investigación fijando un comentario y valorándolo te lo agradecemos.

¡Haz clic para puntuar esta entrada!
(Votos: 0 Promedio: 0)



Utiliza Nuestro Buscador

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *