Saltar al contenido

Clasificación binaria de PyTorch: ¿la misma estructura de red, datos ‘más simples’, pero peor rendimiento?

Solución:

TL; DR

Sus datos de entrada no están normalizados.

  1. usar x_data = (x_data - x_data.mean()) / x_data.std()
  2. aumentar la tasa de aprendizaje optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

Obtendrás
ingrese la descripción de la imagen aquí

convergencia en solo 1000 iteraciones.

Más detalles

La diferencia clave entre los dos ejemplos que tiene es que los datos x en el primer ejemplo se centra en (0, 0) y tiene una varianza muy baja.
Por otro lado, los datos del segundo ejemplo se centran en 92 y tienen una varianza relativamente grande.

Este sesgo inicial en los datos no se tiene en cuenta cuando se inicializan aleatoriamente las ponderaciones, lo cual se realiza en base a la suposición de que las entradas se distribuyen aproximadamente normalmente alrededor cero.
Es casi imposible que el proceso de optimización compense esta gran desviación, por lo que el modelo se atasca en una solución subóptima.

Una vez que normaliza las entradas, al restar la media y dividir por la std, el proceso de optimización se vuelve estable nuevamente y converge rápidamente a una buena solución.

Para obtener más detalles sobre la normalización de entrada y la inicialización de pesos, puede leer la sección 2.2 en Él et al Profundizando en los rectificadores: superando el rendimiento a nivel humano en la clasificación de ImageNet (ICCV 2015).

¿Qué pasa si no puedo normalizar los datos?

Si, por alguna razón, no puede calcular los datos medios y estándar por adelantado, aún puede usar nn.BatchNorm1d para estimar y normalizar los datos como parte del proceso de formación. Por ejemplo

class Model(nn.Module):
    def __init__(self, input_size, H1, output_size):
        super().__init__()
        self.bn = nn.BatchNorm1d(input_size)  # adding batchnorm
        self.linear = nn.Linear(input_size, H1)
        self.linear2 = nn.Linear(H1, output_size)
    
    def forward(self, x):
        x = torch.sigmoid(self.linear(self.bn(x)))  # batchnorm the input x
        x = torch.sigmoid(self.linear2(x))
        return x

Esta modificación sin cualquier cambio en los datos de entrada, produce una convergencia similar después de solo 1000 épocas:
ingrese la descripción de la imagen aquí

Un comentario menor

Para estabilidad numérica, es mejor usar nn.BCEWithLogitsLoss en lugar de nn.BCELoss. Para este fin, debe quitar el torch.sigmoid desde el forward() salida, la sigmoid se computará dentro de la pérdida.
Vea, por ejemplo, este hilo sobre la pérdida sigmoidea + entropía cruzada relacionada para predicciones binarias.

¡Haz clic para puntuar esta entrada!
(Votos: 0 Promedio: 0)



Utiliza Nuestro Buscador

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *