Solución:
model.eval()
es una especie de cambio para algunas capas / partes específicas del modelo que se comportan de manera diferente durante el entrenamiento y el tiempo de inferencia (evaluación). Por ejemplo, Dropouts Layers, BatchNorm Layers, etc. Debe desactivarlas durante la evaluación del modelo y .eval()
lo hará por ti. Además, la práctica común para la evaluación / validación es utilizar torch.no_grad()
en pareja con model.eval()
para desactivar el cálculo de gradientes:
# evaluate model:
model.eval()
with torch.no_grad():
...
out_data = model(data)
...
PERO, no olvides volver a training
modo después del paso de evaluación:
# training step
...
model.train()
...
model.eval
es un método de torch.nn.Module
El método opuesto es el model.train explicado muy bien por Umang Gupta.
¡Haz clic para puntuar esta entrada!
(Votos: 0 Promedio: 0)