Saltar al contenido

¿Cómo convertir imágenes RGB a escala de grises en el cargador de datos PyTorch?

Solución:

Cuando usas ImageFolder class y sin un cargador personalizado, pytorch usa PIL para cargar la imagen y la convierte a RGB. Cargador predeterminado si el backend de la imagen de Torchvision es PIL:

def pil_loader(path):
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')

Puedes usar Escala de grises de Torchvision función en transforma. Convertirá la imagen RGB de 3 canales en escala de grises de 1 canal. Obtenga más información sobre esto en https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.Grayscale

A continuación se muestra un código de muestra,

import torchvision as tv
import numpy as np
import torch.utils.data as data
dataDir="D:\general\ML_DL\datasets\CIFAR"
trainTransform  = tv.transforms.Compose([tv.transforms.Grayscale(num_output_channels=1),
                                    tv.transforms.ToTensor(), 
                                    tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainSet        = tv.datasets.CIFAR10(dataDir, train=True, download=False, transform=trainTransform)
dataloader      = data.DataLoader(trainSet, batch_size=1, shuffle=False, num_workers=0)
images, labels  = iter(dataloader).next()
print (images.size())
¡Haz clic para puntuar esta entrada!
(Votos: 0 Promedio: 0)



Utiliza Nuestro Buscador

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *