Saltar al contenido

pytorch – conexión entre loss.backward () y optimizer.step ()

Solución:

Sin profundizar demasiado en los aspectos internos de pytorch, puedo ofrecer una respuesta simplista:

Recuerde que al inicializar optimizer le dice explícitamente qué parámetros (tensores) del modelo debe actualizar. Los gradientes son “almacenados” por los propios tensores (tienen un grad y un requires_grad atributos) una vez que llamas backward() en la pérdida. Después de calcular los gradientes para todos los tensores en el modelo, llamar optimizer.step() hace que el optimizador repita todos los parámetros (tensores) que se supone que debe actualizar y usar sus almacenados internamente grad para actualizar sus valores.

En esta respuesta se puede encontrar más información sobre gráficos computacionales y la información adicional “grad” almacenada en los tensores de pytorch.

Hacer referencia a los parámetros por el optimizador a veces puede causar problemas, por ejemplo, cuando el modelo se mueve a la GPU después inicializando el optimizador. Asegúrese de haber terminado de configurar su modelo antes de construyendo el optimizador. Consulte esta respuesta para obtener más detalles.

Cuando usted llama loss.backward(), todo lo que hace es calcular el gradiente de pérdida con todos los parámetros de pérdida que tienen requires_grad = True y guárdelos en parameter.grad atributo para cada parámetro.

optimizer.step() actualiza todos los parámetros basados ​​en parameter.grad

Digamos que definimos un modelo: modely función de pérdida: criterion y tenemos la siguiente secuencia de pasos:

pred = model(input)
loss = criterion(pred, true_labels)
loss.backward()

pred tendrá un grad_fn atributo, que hace referencia a una función que lo creó y lo vincula al modelo. Por lo tanto, loss.backward() tendrá información sobre el modelo con el que está trabajando.

Intenta eliminar grad_fn atributo, por ejemplo con:

pred = pred.clone().detach()

Entonces los gradientes del modelo serán None y en consecuencia, los pesos no se actualizarán.

Y el optimizador está vinculado al modelo porque pasamos model.parameters() cuando creamos el optimizador.

¡Haz clic para puntuar esta entrada!
(Votos: 0 Promedio: 0)



Utiliza Nuestro Buscador

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *