Entiende el código de forma correcta antes de usarlo a tu trabajo si tquieres aportar algo puedes dejarlo en los comentarios.
Solución:
Encontré esta página en su repositorio de github, simplemente pegaré el contenido aquí.
Enfoque recomendado para guardar un modelo
Hay dos enfoques principales para serializar y restaurar un modelo.
El primero (recomendado) guarda y carga solo los parámetros del modelo:
torch.save(the_model.state_dict(), PATH)
Entonces despúes:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
El segundo guarda y carga todo el modelo:
torch.save(the_model, PATH)
Entonces despúes:
the_model = torch.load(PATH)
Sin embargo, en este caso, los datos serializados están vinculados a las clases específicas y la estructura de directorio exacta utilizada, por lo que pueden romperse de varias maneras cuando se usan en otros proyectos o después de algunas refactorizaciones serias.
Depende de lo que quieras hacer.
Caso #1: Guarde el modelo para usarlo usted mismo para la inferencia: guarda el modelo, lo restaura y luego cambia el modelo al modo de evaluación. Esto se hace porque normalmente tiene BatchNorm
y Dropout
capas que por defecto están en modo tren en construcción:
torch.save(model.state_dict(), filepath)
#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()
Caso #2: Guardar modelo para retomar el entrenamiento más tarde: si necesita seguir entrenando el modelo que está a punto de guardar, necesita guardar más que solo el modelo. También debe guardar el estado del optimizador, las épocas, la puntuación, etc. Lo haría así:
state =
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
...
torch.save(state, filepath)
Para reanudar el entrenamiento harías cosas como: state = torch.load(filepath)
y luego, para restaurar el estado de cada objeto individual, algo como esto:
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
Ya que estás reanudando el entrenamiento, NO llamar model.eval()
una vez que restablezca los estados al cargar.
Caso # 3: Modelo para ser utilizado por otra persona sin acceso a su código: En Tensorflow puedes crear un .pb
archivo que define tanto la arquitectura como los pesos del modelo. Esto es muy útil, especialmente cuando se usa Tensorflow serve
. La forma equivalente de hacer esto en Pytorch sería:
torch.save(model, filepath)
# Then later:
model = torch.load(filepath)
Esta forma todavía no es a prueba de balas y dado que pytorch todavía está experimentando muchos cambios, no lo recomendaría.
La biblioteca pickle Python implementa protocolos binarios para serializar y deserializar un objeto de Python.
Cuando usted import torch
(o cuando usa PyTorch) lo hará import pickle
para ti y no necesitas llamar pickle.dump()
y pickle.load()
directamente, cuáles son los métodos para guardar y cargar el objeto.
De hecho, torch.save()
y torch.load()
envolverá pickle.dump()
y pickle.load()
para usted.
UN state_dict
la otra respuesta mencionada merece solo unas pocas notas más.
Qué state_dict
Qué tenemos dentro de PyTorch? en realidad hay dos state_dict
s.
El modelo PyTorch es torch.nn.Module
posee model.parameters()
llame para obtener parámetros aprendibles (w y b). Estos parámetros de aprendizaje, una vez establecidos aleatoriamente, se actualizarán con el tiempo a medida que aprendamos. Los parámetros que se pueden aprender son los primeros state_dict
.
El segundo state_dict
es el dict de estado del optimizador. Recuerda que el optimizador se utiliza para mejorar nuestros parámetros de aprendizaje. Pero el optimizador state_dict
está arreglado. Nada que aprender allí.
Porque state_dict
Los objetos son diccionarios de Python, se pueden guardar, actualizar, modificar y restaurar fácilmente, lo que agrega una gran cantidad de modularidad a los modelos y optimizadores de PyTorch.
Vamos a crear un modelo súper simple para explicar esto:
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "t", model.state_dict()[param_tensor].size())
print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "t", optimizer.state_dict()[var_name])
Este código generará lo siguiente:
Model's state_dict:
weight torch.Size([2, 5])
bias torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328, 0.1360, 0.1553, -0.1838, -0.0316],
[ 0.0479, 0.1760, 0.1712, 0.2244, 0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state
param_groups ['lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]]
Tenga en cuenta que este es un modelo mínimo. Puede intentar agregar una pila de secuencias
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.Conv2d(A, B, C)
torch.nn.Linear(H, D_out),
)
Tenga en cuenta que solo las capas con parámetros aprendibles (capas convolucionales, capas lineales, etc.) y los búferes registrados (capas de normas por lotes) tienen entradas en el modelo. state_dict
.
Cosas que no se pueden aprender, pertenecen al objeto del optimizador state_dict
que contiene información sobre el estado del optimizador, así como los hiperparámetros utilizados.
El resto de la historia es la misma; en la fase de inferencia (esta es una fase en la que usamos el modelo después del entrenamiento) para predecir; predecimos basándonos en los parámetros que aprendimos. Entonces, para la inferencia, solo necesitamos guardar los parámetros. model.state_dict()
.
torch.save(model.state_dict(), filepath)
Y para usar más tarde model.load_state_dict(torch.load(filepath)) model.eval()
Nota: No olvides la última línea. model.eval()
esto es crucial después de cargar el modelo.
Tampoco trates de guardar torch.save(model.parameters(), filepath)
. Él model.parameters()
es solo el objeto generador.
Por otro lado, torch.save(model, filepath)
guarda el objeto del modelo en sí, pero tenga en cuenta que el modelo no tiene el optimizador state_dict
. Verifique la otra excelente respuesta de @Jadiel de Armas para guardar el dictado de estado del optimizador.
Eres capaz de asentar nuestra investigación fijando un comentario y valorándolo te lo agradecemos.