Saltar al contenido

Keras: Cómo usar fit_generator con múltiples entradas

Nuestro team redactor ha estado mucho tiempo buscando para darle respuesta a tu duda, te regalamos la respuesta así que deseamos serte de gran apoyo.

Solución:

Prueba este generador:

def generator_two_img(X1, X2, y, batch_size):
    genX1 = gen.flow(X1, y,  batch_size=batch_size, seed=1)
    genX2 = gen.flow(X2, y, batch_size=batch_size, seed=1)
    while True:
        X1i = genX1.next()
        X2i = genX2.next()
        yield [X1i[0], X2i[0]], X1i[1]

Generador para 3 entradas:

def generator_three_img(X1, X2, X3, y, batch_size):
    genX1 = gen.flow(X1, y,  batch_size=batch_size, seed=1)
    genX2 = gen.flow(X2, y, batch_size=batch_size, seed=1)
    genX3 = gen.flow(X3, y, batch_size=batch_size, seed=1)
    while True:
        X1i = genX1.next()
        X2i = genX2.next()
        X3i = genX3.next()
        yield [X1i[0], X2i[0], X3i[0]], X1i[1]

Tengo una implementación para múltiples entradas para TimeseriesGenerator que lo he adaptado (no he podido probarlo lamentablemente) para cumplir con este ejemplo con ImageDataGenerator. Mi enfoque fue construir una clase contenedora para los múltiples generadores de keras.utils.Sequence y luego implementar los métodos base de la misma: __len__ y __getitem__:

from keras.preprocessing.image import ImageDataGenerator
from keras.utils import Sequence


class MultipleInputGenerator(Sequence):
    """Wrapper of 2 ImageDataGenerator"""

    def __init__(self, X1, X2, Y, batch_size):
        # Keras generator
        self.generator = ImageDataGenerator(rotation_range=15, 
                                            width_shift_range=0.2,
                                            height_shift_range=0.2,
                                            shear_range=0.2,
                                            zoom_range=0.2,
                                            horizontal_flip=True, 
                                            fill_mode='nearest')

        # Real time multiple input data augmentation
        self.genX1 = self.generator.flow(X1, Y, batch_size=batch_size)
        self.genX2 = self.generator.flow(X2, Y, batch_size=batch_size)

    def __len__(self):
        """It is mandatory to implement it on Keras Sequence"""
        return self.genX1.__len__()

    def __getitem__(self, index):
        """Getting items from the 2 generators and packing them"""
        X1_batch, Y_batch = self.genX1.__getitem__(index)
        X2_batch, Y_batch = self.genX2.__getitem__(index)

        X_batch = [X1_batch, X2_batch]

        return X_batch, Y_batch

Puedes usar este generador con model.fit_generator() una vez que el generador ha sido instanciado.

Reseñas y calificaciones del tutorial

¡Haz clic para puntuar esta entrada!
(Votos: 0 Promedio: 0)



Utiliza Nuestro Buscador

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *