Saltar al contenido

¿Qué hace model.eval () en pytorch?

Solución:

model.eval() es una especie de cambio para algunas capas / partes específicas del modelo que se comportan de manera diferente durante el entrenamiento y el tiempo de inferencia (evaluación). Por ejemplo, Dropouts Layers, BatchNorm Layers, etc. Debe desactivarlas durante la evaluación del modelo y .eval() lo hará por ti. Además, la práctica común para la evaluación / validación es utilizar torch.no_grad() en pareja con model.eval() para desactivar el cálculo de gradientes:

# evaluate model:
model.eval()

with torch.no_grad():
    ...
    out_data = model(data)
    ...

PERO, no olvides volver a training modo después del paso de evaluación:

# training step
...
model.train()
...

model.eval es un método de torch.nn.Module

ingrese la descripción de la imagen aquí

El método opuesto es el model.train explicado muy bien por Umang Gupta.

¡Haz clic para puntuar esta entrada!
(Votos: 0 Promedio: 0)



Utiliza Nuestro Buscador

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *