Este dilema se puede abordar de diversas formas, pero en este caso te damos la resolución más completa para nosotros.
Solución:
Respuesta dada por ptrblck
de la comunidad PyTorch. ¡Muchas gracias!
nb_classes = 9
confusion_matrix = torch.zeros(nb_classes, nb_classes)
with torch.no_grad():
for i, (inputs, classes) in enumerate(dataloaders['val']):
inputs = inputs.to(device)
classes = classes.to(device)
outputs = model_ft(inputs)
_, preds = torch.max(outputs, 1)
for t, p in zip(classes.view(-1), preds.view(-1)):
confusion_matrix[t.long(), p.long()] += 1
print(confusion_matrix)
Para obtener la precisión por clase:
print(confusion_matrix.diag()/confusion_matrix.sum(1))
Aquí hay un enfoque ligeramente modificado (directo) usando confusion_matrix de sklearn: –
from sklearn.metrics import confusion_matrix
nb_classes = 9
# Initialize the prediction and label lists(tensors)
predlist=torch.zeros(0,dtype=torch.long, device='cpu')
lbllist=torch.zeros(0,dtype=torch.long, device='cpu')
with torch.no_grad():
for i, (inputs, classes) in enumerate(dataloaders['val']):
inputs = inputs.to(device)
classes = classes.to(device)
outputs = model_ft(inputs)
_, preds = torch.max(outputs, 1)
# Append batch prediction results
predlist=torch.cat([predlist,preds.view(-1).cpu()])
lbllist=torch.cat([lbllist,classes.view(-1).cpu()])
# Confusion matrix
conf_mat=confusion_matrix(lbllist.numpy(), predlist.numpy())
print(conf_mat)
# Per-class accuracy
class_accuracy=100*conf_mat.diagonal()/conf_mat.sum(1)
print(class_accuracy)
Otra forma sencilla de obtener precisión es usar sklearns “accuracy_score”. Aquí hay un ejemplo:
from sklearn.metrics import accuracy_score
y_pred = y_pred.data.numpy()
accuracy = accuracy_score(labels, np.argmax(y_pred, axis=1))
Primero necesitas obtener los datos de la variable. “y_pred” son las predicciones de su modelo y, por supuesto, las etiquetas son sus etiquetas.
np.argmax devuelve el índice del mayor valor dentro del array. Queremos el valor más grande, ya que corresponde a la clase de probabilidad más alta cuando usamos softmax para la clasificación de clases múltiples. La puntuación de precisión devolverá un porcentaje de coincidencias entre las etiquetas y y_pred.
Si estás contento con lo expuesto, puedes dejar un enunciado acerca de qué le añadirías a esta división.