Saltar al contenido

¿Cómo reemplazar (o insertar) la capa intermedia en el modelo Keras?

Solución:

La siguiente función le permite insertar una nueva capa antes o después de cada capa del modelo original cuyo nombre coincida con una expresión regular, o bien reemplazar dicha capa, incluso en modelos no secuenciales como DenseNet o ResNet.

import re
from keras.models import Model

def insert_layer_nonseq(model, layer_regex, insert_layer_factory,
                        insert_layer_name=None, position='after'):
    """Insert a new layer before/after, or instead of, every layer matching a regex.

    Works on non-sequential (functional) Keras models: the graph is first
    analysed to record which layers feed each layer, and the model is then
    rebuilt layer by layer, inserting the new layers where requested.

    Args:
        model: the original Keras model. ``model.layers[0]`` is treated as
            the (single) input layer and is mapped to ``model.input``.
        layer_regex: regular expression tested with ``re.match`` (anchored
            at the start of the name) against every layer name.
        insert_layer_factory: zero-argument callable returning a NEW layer
            instance; it is called once for each matching layer.
        insert_layer_name: optional explicit name for the inserted layer.
            If omitted, the name becomes
            ``"<matched layer name>_<factory layer name>"``.
        position: ``'before'``, ``'after'`` or ``'replace'``.

    Returns:
        A new ``Model`` built on ``model.inputs`` whose outputs follow the
        order of ``model.output_names``.

    Raises:
        ValueError: if ``position`` is not one of the accepted values.
    """
    # Validate once, up front, instead of only when a layer happens to match.
    if position not in ('before', 'after', 'replace'):
        raise ValueError('position must be: before, after or replace')

    # Auxiliary dictionary to describe the network graph
    network_dict = {'input_layers_of': {}, 'new_output_tensor_of': {}}

    # Record, for every layer, the names of the layers feeding it.
    for layer in model.layers:
        for node in layer._outbound_nodes:
            consumer_name = node.outbound_layer.name
            network_dict['input_layers_of'].setdefault(
                consumer_name, []).append(layer.name)

    # The input layer's "new" output is simply the model input.
    network_dict['new_output_tensor_of'][model.layers[0].name] = model.input

    # Rebuild every layer after the input.
    for layer in model.layers[1:]:

        # Determine input tensors from the already rebuilt producers.
        layer_input = [network_dict['new_output_tensor_of'][layer_aux]
                       for layer_aux in network_dict['input_layers_of'][layer.name]]
        if len(layer_input) == 1:
            layer_input = layer_input[0]

        if re.match(layer_regex, layer.name):
            if position == 'after':
                x = layer(layer_input)
            else:
                # 'before' and 'replace' both feed the new layer with the
                # matched layer's own inputs (previously 'before' reused a
                # stale ``x`` from the previous iteration).
                x = layer_input

            new_layer = insert_layer_factory()
            if insert_layer_name:
                new_name = insert_layer_name
            else:
                new_name = "{}_{}".format(layer.name, new_layer.name)
            try:
                new_layer.name = new_name
            except AttributeError:
                # Some Keras versions expose ``name`` as a read-only property.
                new_layer._name = new_name
            x = new_layer(x)
            print('New layer: {} Old layer: {} Type: {}'.format(new_layer.name,
                                                            layer.name, position))
            if position == 'before':
                x = layer(x)
        else:
            x = layer(layer_input)

        # Set new output tensor (the original one, or the one of the inserted
        # layer)
        network_dict['new_output_tensor_of'][layer.name] = x

    # Collect outputs in the original model's output order (the old code
    # tested a stale loop variable and used layer order).
    model_outputs = [network_dict['new_output_tensor_of'][name]
                     for name in model.output_names]

    return Model(inputs=model.inputs, outputs=model_outputs)

La diferencia con respecto al caso más simple de un modelo puramente secuencial es que, antes de iterar sobre las capas para encontrar las que coinciden con la expresión regular, primero se analiza el grafo de la red y se almacenan en un diccionario auxiliar las capas de entrada de cada capa. Luego, a medida que se itera sobre las capas, también se almacena el nuevo tensor de salida de cada una; al construir el nuevo modelo, esos tensores se usan como entradas de las capas siguientes.

Un caso de uso sería el siguiente, donde se inserta una capa Dropout (de abandono) después de cada capa de activación de ResNet50:

from keras.applications.resnet50 import ResNet50
from keras.layers import Dropout  # was used below but never imported
from keras.models import load_model

model = ResNet50()


def dropout_layer_factory():
    """Return a fresh Dropout layer; insert_layer_nonseq calls this once per match."""
    return Dropout(rate=0.2, name="dropout")


# Insert a Dropout after every layer whose name contains "activation".
model = insert_layer_nonseq(model, '.*activation.*', dropout_layer_factory)

# Fix possible problems with new model: round-trip it through disk
# (writes temp.h5 in the current working directory).
model.save('temp.h5')
model = load_model('temp.h5')

model.summary()

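Para modelos puramente secuenciales (en los que cada capa recibe únicamente la salida de la anterior), puede utilizar las siguientes funciones: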

def replace_intermediate_layer_in_keras(model, layer_id, new_layer):
    """Return a copy of a purely sequential model with one layer swapped out.

    The model is rebuilt by chaining its layers linearly, so it only works
    when every layer takes exactly one input coming from the previous layer.

    Args:
        model: the original Keras model.
        layer_id: index into ``model.layers`` of the layer to replace.
            Index 0 (the input layer) or an out-of-range index leaves the
            model unchanged.
        new_layer: layer instance used instead of ``model.layers[layer_id]``;
            its output must be compatible with the next layer's input.

    Returns:
        A new ``Model``; the untouched layers are the same layer objects as
        in ``model`` (so they share weights).
    """
    from keras.models import Model

    layers = list(model.layers)

    x = layers[0].output
    for i, layer in enumerate(layers[1:], start=1):
        x = new_layer(x) if i == layer_id else layer(x)

    # Keras 2 spells these arguments ``inputs``/``outputs``; the legacy
    # ``input``/``output`` keywords are deprecated and not accepted by tf.keras.
    return Model(inputs=layers[0].input, outputs=x)

def insert_intermediate_layer_in_keras(model, layer_id, new_layer):
    """Return a copy of a purely sequential model with a layer inserted.

    ``new_layer`` is placed immediately BEFORE ``model.layers[layer_id]``.
    The model is rebuilt by chaining its layers linearly, so it only works
    when every layer takes exactly one input coming from the previous layer.

    Args:
        model: the original Keras model.
        layer_id: index into ``model.layers`` of the layer that will receive
            ``new_layer``'s output. Index 0 or an out-of-range index inserts
            nothing.
        new_layer: layer instance to insert; its output must be compatible
            with the input of ``model.layers[layer_id]``.

    Returns:
        A new ``Model``; the original layers are reused (weights shared).
    """
    from keras.models import Model

    layers = list(model.layers)

    x = layers[0].output
    for i, layer in enumerate(layers[1:], start=1):
        if i == layer_id:
            x = new_layer(x)
        x = layer(x)

    # Keras 2 spells these arguments ``inputs``/``outputs``; the legacy
    # ``input``/``output`` keywords are deprecated and not accepted by tf.keras.
    return Model(inputs=layers[0].input, outputs=x)

Ejemplo:

if __name__ == '__main__':
    from keras.layers import Conv2D, BatchNormalization

    # NOTE(review): keras_simple_model() is not defined in this snippet;
    # supply your own small sequential model here.
    model = keras_simple_model()
    # summary() prints the table itself and returns None, so it is not
    # wrapped in print() (that would also print "None").
    model.summary()
    model = replace_intermediate_layer_in_keras(
        model, 3,
        Conv2D(4, (3, 3), activation=None, padding='same', name="conv2_repl", use_bias=False))
    model.summary()
    model = insert_intermediate_layer_in_keras(model, 4, BatchNormalization())
    model.summary()

Hay algunas limitaciones en los reemplazos debido a las formas de las capas: la capa nueva debe producir una salida cuya forma sea compatible con la entrada de la capa siguiente.

Así fue como lo hice:

import keras 
from keras.models import Model 
from tqdm import tqdm 
from keras import backend as K

def make_list(X):
    """Return ``X`` itself if it is a list, otherwise a one-element list holding it."""
    return X if isinstance(X, list) else [X]

def list_no_list(X):
    """Unwrap a one-element sequence to its sole item; return anything else as-is."""
    return X[0] if len(X) == 1 else X

def replace_layer(model, replace_layer_subname, replacement_fn,
                  **kwargs):
    """
    Rebuild ``model`` with every layer whose name contains a substring replaced.

    args:
        model :: keras.models.Model instance
        replace_layer_subname :: str -- every (non-input) layer whose name
            contains this substring is replaced
        replacement_fn :: callable invoked as
            ``replacement_fn(old_layer=<layer>, **kwargs)`` once per replaced
            layer; it must return a callable (a layer or a plain function)
            that is applied to the replaced layer's input tensor(s)
            > its output must have the shape the following layers expect
        **kwargs :: forwarded to ``replacement_fn``
    returns:
        new model with replaced layers; inputs/outputs keep the order of
        ``model.inputs`` / ``model.outputs``
    quick examples:
        want to just remove all layers with 'batch_norm' in the name:
            > new_model = replace_layer(model, 'batch_norm', lambda **kwargs: (lambda u: u))
        want to replace all Conv1D(N, m, padding='same') with an LSTM (lets say all have 'conv1d' in name)
            > new_model = replace_layer(model, 'conv1d', lambda old_layer, **kwargs: LSTM(units=old_layer.filters, return_sequences=True))
    """
    tsr_dict = {}      # original tensor name -> rebuilt tensor
    replacements = {}  # replaced layer name -> callable from replacement_fn

    model_input_names = [inp.name for inp in make_list(model.input)]
    model_output_names = [out.name for out in make_list(model.output)]

    for layer in model.layers:
        ### input layers: keep their tensors unchanged. Detected by class, not
        ### by name, so a layer merely *named* "...input..." is rebuilt and an
        ### InputLayer with a custom name is still recognised.
        if type(layer).__name__ == 'InputLayer':
            for j in range(len(layer._inbound_nodes)):
                for inpt_tsr in make_list(layer.get_output_at(j)):
                    tsr_dict[inpt_tsr.name] = inpt_tsr
            continue

        ### Loop if layer is used multiple times.
        ### NOTE(review): _inbound_nodes may also contain nodes from other
        ### graphs this layer was called in; their inputs are not in tsr_dict
        ### and would raise KeyError — confirm for reused layers.
        for j in range(len(layer._inbound_nodes)):

            ### check layer inp/outp
            inpt_names = [inp.name for inp in make_list(layer.get_input_at(j))]
            outp_names = [out.name for out in make_list(layer.get_output_at(j))]

            ### setup layer inputs
            inpt = list_no_list([tsr_dict[name] for name in inpt_names])

            ### remake layer
            if replace_layer_subname in layer.name:
                # Build the replacement once per layer so that a shared layer
                # stays shared (one set of weights) after replacement.
                if layer.name not in replacements:
                    print('replacing '+layer.name)
                    replacements[layer.name] = replacement_fn(old_layer=layer, **kwargs)
                x = replacements[layer.name](inpt)
            else:
                x = layer(inpt)

            ### reinstantiate outputs into dict
            for name, out_tsr in zip(outp_names, make_list(x)):
                tsr_dict[name] = out_tsr

    # Preserve the original model's input/output ordering.
    model_inputs = [tsr_dict[name] for name in model_input_names]
    model_outputs = [tsr_dict[name] for name in model_output_names]

    return Model(model_inputs, model_outputs)

Tengo una capa personalizada (que tomé de un código publicado en internet) llamada BatchNormalizationFreeze, así que un ejemplo de uso es este:

 # replacement_fn must RETURN the new layer; replace_layer applies it to the inputs itself.
 new_model = replace_layer(model, 'batch_normal', lambda **kwargs: BatchNormalizationFreeze())

Si necesita sustituir cada capa por varias capas, basta con que la función de reemplazo devuelva un pseudomodelo que las aplique todas a la vez.

¡Haz clic para puntuar esta entrada!
(Votos: 0 Promedio: 0)


Tags : /

Utiliza Nuestro Buscador

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *