Introducción a las GAN con Python y TensorFlow

    Introducción

    Los modelos generativos son una familia de arquitecturas de IA cuyo objetivo es crear muestras de datos desde cero. Lo logran capturando las distribuciones de datos del tipo de cosas que queremos generar.

    Este tipo de modelos se están investigando intensamente y hay una gran cantidad de publicidad a su alrededor. Solo mire el cuadro que muestra la cantidad de artículos publicados en el campo durante los últimos años:

    Desde 2014, cuando el primer artículo sobre redes generativas de confrontación se publicó, los modelos generativos se están volviendo increíblemente poderosos y ahora podemos generar muestras de datos hiperrealistas para una amplia gama de distribuciones: imágenes, videos, música, escritos, etc.

    A continuación, se muestran algunos ejemplos de imágenes generadas por una GAN:

    ¿Qué son los modelos generativos?

    El marco de GAN

    El marco más exitoso propuesto para modelos generativos, al menos en los últimos años, toma el nombre de Redes generativas antagónicas (GAN).

    En pocas palabras, un GAN se compone de dos modelos separados, representados por redes neuronales: un generador G y un discriminador D. El objetivo del discriminador es saber si una muestra de datos proviene de una distribución de datos real, o si en su lugar se genera por G.

    El objetivo del generador es generar muestras de datos para engañar al discriminador.

    El generador no es más que una red neuronal profunda. Toma como entrada un vector de ruido aleatorio (generalmente gaussiano o de una distribución uniforme) y genera una muestra de datos de la distribución que queremos capturar.

    El discriminador es, nuevamente, solo una red neuronal. Su objetivo es, como su nombre indica, discriminar entre muestras reales y falsas. En consecuencia, su entrada es una muestra de datos, ya sea proveniente del generador o de la distribución de datos real.

    La salida es un número simple, que representa la probabilidad de que la entrada sea real. Una probabilidad alta significa que el discriminador está seguro de que las muestras que le están dando son genuinas. Por el contrario, una probabilidad baja muestra una alta confianza en el hecho de que la muestra proviene de la red de generadores:

    Imagínese un falsificador de arte que intenta crear obras de arte falsas y un crítico de arte que necesita distinguir entre pinturas adecuadas y falsas.

    En este escenario, el crítico actúa como nuestro discriminador y el falsificador es el generador, tomando retroalimentación del crítico para mejorar sus habilidades y hacer que su arte forjado parezca más convincente:

    Formación

    Entrenar a un GAN puede ser algo doloroso. La inestabilidad del entrenamiento siempre ha sido un problema y muchas investigaciones se han centrado en hacer que el entrenamiento sea más estable.

    La función objetivo básica de un modelo vanilla GAN es la siguiente:

    Aquí, re se refiere a la red discriminadora, mientras que GRAMO obviamente se refiere al generador.

    Como muestra la fórmula, el generador se optimiza para confundir al discriminador al máximo, al tratar de hacer que genere altas probabilidades de muestras de datos falsas.

    Por el contrario, el discriminador intenta mejorar la distinción entre muestras procedentes de G y muestras procedentes de la distribución real.

    El término adversario proviene exactamente de la forma en que se entrenan los GANS, enfrentando a las dos redes entre sí.

    Una vez que hemos entrenado nuestro modelo, el discriminador ya no es necesario. Todo lo que tenemos que hacer es alimentar al generador con un vector de ruido aleatorio y, con suerte, obtendremos una muestra de datos artificial y realista como resultado.

    Problemas de GAN

    Entonces, ¿por qué son tan difíciles de entrenar las GAN? Como se indicó anteriormente, las GAN son muy difíciles de entrenar en su forma básica. Veremos brevemente por qué este es el caso.

    Equilibrio de Nash difícil de alcanzar

    Dado que estas dos redes se disparan información entre sí, se podría representar como un juego en el que uno adivina si la entrada es real o no.

    El marco GAN es un juego no convexo, de dos jugadores y no cooperativo con parámetros continuos de alta dimensión, en el que cada jugador quiere minimizar su función de coste. El óptimo de este proceso toma el nombre de Equilibrio de Nash – donde cada jugador no se desempeñará mejor cambiando una estrategia, dado que el otro jugador no cambia su estrategia.

    Sin embargo, las GAN se entrenan típicamente usando técnicas de descenso de gradientes que están diseñadas para encontrar el valor bajo de una función de costo y no encontrar el Equilibrio de Nash de un juego.

    Colapso de modo

    La mayoría de las distribuciones de datos son multimodales. Toma el Conjunto de datos MNIST: hay 10 “modos” de datos, que se refieren a los diferentes dígitos entre 0 y 9.

    Un buen modelo generativo podría producir muestras con suficiente variabilidad, pudiendo así generar muestras de todas las diferentes clases.

    Sin embargo, esto no siempre sucede.

    Digamos que el generador se vuelve realmente bueno para producir el dígito “3”. Si las muestras producidas son lo suficientemente convincentes, el discriminador probablemente les asignará altas probabilidades.

    Como resultado, el generador se verá impulsado a producir muestras que provengan de ese modo específico, ignorando las otras clases la mayor parte del tiempo. Básicamente enviará spam al mismo número y con cada número que pase el discriminador, este comportamiento solo se aplicará más.

    Gradiente decreciente

    Muy similar al ejemplo anterior, el discriminador puede tener demasiado éxito en distinguir muestras de datos. Cuando eso es cierto, el gradiente del generador desaparece, comienza a aprender cada vez menos y no logra converger.

    Este desequilibrio, al igual que el anterior, se puede producir si entrenamos las redes por separado. La evolución de la red neuronal puede ser bastante impredecible, lo que puede llevar a que una esté por delante de la otra por una milla. Si los capacitamos juntos, principalmente nos aseguramos de que estas cosas no sucedan.

    Lo último

    Sería imposible ofrecer una visión completa de todas las mejoras y desarrollos que hicieron que las GAN fueran más poderosas y estables en los últimos años.

    En cambio, lo que haré es compilar una lista de las arquitecturas y técnicas más exitosas, proporcionando enlaces a recursos relevantes para profundizar más.

    DCGAN

    Las GAN convolucionales profundas (DCGAN) introdujeron convoluciones en las redes generadoras y discriminadoras.

    Sin embargo, esto no fue simplemente una cuestión de agregar capas convolucionales al modelo, ya que el entrenamiento se volvió aún más inestable.

    Se tuvieron que aplicar varios trucos para que los DCGAN fueran útiles:

    • La normalización de lotes se aplicó tanto al generador como a la red discriminadora
    • La deserción se utiliza como técnica de regularización
    • El generador necesitaba una forma de aumentar la muestra del vector de entrada aleatoria a una imagen de salida. Aquí se emplea la transposición de capas convolucionales
    • LeakyRelu y TanH las activaciones se utilizan en ambas redes

    WGAN

    GAN de Wasserstein (WGAN) tienen como objetivo mejorar la estabilidad del entrenamiento. Hay una gran cantidad de matemáticas detrás de este tipo de modelo. Se puede encontrar una explicación más accesible aquí.

    Las ideas básicas aquí fueron proponer una nueva función de costos que tenga un gradiente más suave en todas partes.

    La nueva función de costo usa una métrica llamada distancia de Wasserstein, que tiene un gradiente más suave en todas partes.

    Como resultado, el discriminador, que ahora se llama crítico, genera valores de confianza que ya no deben interpretarse como una probabilidad. Los valores altos significan que el modelo confía en que la entrada es real.

    Dos mejoras significativas para WGAN son:

    • No tiene signos de colapso de modo en los experimentos.
    • El generador aún puede aprender cuando el crítico se desempeña bien

    SAGANs

    GAN de auto-atención (SAGAN) introducen un mecanismo de atención al marco de GAN.

    Los mecanismos de atención permiten utilizar información global de forma local. Lo que esto significa es que podemos capturar el significado de diferentes partes de una imagen y usar esa información para producir mejores muestras.

    Esto proviene de la observación de que las convoluciones son bastante malas para capturar dependencias a largo plazo en muestras de entrada, ya que la convolución es una operación local cuyo campo receptivo depende del tamaño espacial del kernel.

    Esto significa que, por ejemplo, no es posible que una salida en la posición superior izquierda de una imagen tenga relación con la salida en la parte inferior derecha.

    Una forma de resolver este problema sería utilizar núcleos con tamaños más grandes para capturar más información. Sin embargo, esto haría que el modelo fuera computacionalmente ineficaz y muy lento de entrenar.

    La auto-atención resuelve este problema, proporcionando una forma eficiente de capturar información global y usarla localmente cuando pueda resultar útil.

    BigGANs

    BigGANs son, en el momento de redactar este artículo, considerados más o menos de última generación, en lo que respecta a la calidad de las muestras generadas.

    Lo que hicieron los investigadores aquí fue reunir todo lo que había estado funcionando hasta ese momento y luego escalarlo masivamente.
    Su modelo de referencia era de hecho un SAGAN, al que añadieron algunos trucos para mejorar la estabilidad.

    Demostraron que las GAN se benefician enormemente del escalado, incluso cuando no se introducen más mejoras funcionales en el modelo, como se cita en el documento original:

    Hemos demostrado que las Redes Adversarias Generativas entrenadas para modelar imágenes naturales de múltiples categorías se benefician enormemente de la ampliación, tanto en términos de fidelidad como de variedad de las muestras generadas. Como resultado, nuestros modelos establecen un nuevo nivel de rendimiento entre los modelos ImageNet GAN, mejorando el estado de la técnica por un amplio margen.

    Una GAN simple en Python

    Implementación de código

    Dicho todo esto, sigamos adelante e implementemos un GAN simple que genera dígitos del 0 al 9, un ejemplo bastante clásico:

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec
    import os
    
    # Sample z from uniform distribution
    def sample_Z(m, n):
        return np.random.uniform(-1., 1., size=[m, n])
    
    def plot(samples):
        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)
    
        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
    
        return fig
    

    Ahora podemos definir el marcador de posición para nuestras muestras de entrada y vectores de ruido:

    # Input image, for discriminator model.
    X = tf.placeholder(tf.float32, shape=[None, 784])
    
    # Input noise for generator.
    Z = tf.placeholder(tf.float32, shape=[None, 100])
    

    Ahora, definimos nuestras redes generadoras y discriminadoras. Son perceptrones simples con una sola capa oculta.

    Usamos relu activaciones en las neuronas de la capa oculta, y sigmoides para las capas de salida.

    def generator(z):
        with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
            x = tf.layers.dense(z, 128, activation=tf.nn.relu)
            x = tf.layers.dense(z, 784)
            x = tf.nn.sigmoid(x)
        return x
    
    def discriminator(x):
        with tf.variable_scope("discriminator", reuse=tf.AUTO_REUSE):
            x = tf.layers.dense(x, 128, activation=tf.nn.relu)
            x = tf.layers.dense(x, 1)
            x = tf.nn.sigmoid(x)
        return x
    

    Ahora podemos definir nuestros modelos, funciones de pérdida y optimizadores:

    # Generator model
    G_sample = generator(Z)
    
    # Discriminator models
    D_real = discriminator(X)
    D_fake = discriminator(G_sample)
    
    
    # Loss function
    D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
    G_loss = -tf.reduce_mean(tf.log(D_fake))
    
    # Select parameters
    disc_vars = [var for var in tf.trainable_variables() if var.name.startswith("disc")]
    gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("gen")]
    
    # Optimizers
    D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=disc_vars)
    G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=gen_vars)
    

    Finalmente, podemos escribir una rutina de entrenamiento. En cada iteración, realizamos un paso de optimización para el discriminador y otro para el generador.

    Cada 100 iteraciones guardamos algunas muestras generadas para que podamos ver el progreso.

    # Batch size
    mb_size = 128
    
    # Dimension of input noise
    Z_dim = 100
    
    mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
    
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    if not os.path.exists('out2/'):
        os.makedirs('out2/')
    
    i = 0
    
    for it in range(1000000):
    
        # Save generated images every 1000 iterations.
        if it % 1000 == 0:
            samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})
    
            fig = plot(samples)
            plt.savefig('out2/{}.png'.format(str(i).zfill(3)), bbox_inches="tight")
            i += 1
            plt.close(fig)
    
    
        # Get next batch of images. Each batch has mb_size samples.
        X_mb, _ = mnist.train.next_batch(mb_size)
    
    
        # Run disciminator solver
        _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})
    
        # Run generator solver
        _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})
    
        # Print loss
        if it % 1000 == 0:
            print('Iter: {}'.format(it))
            print('D loss: {:.4}'. format(D_loss_curr))
    

    Resultados y posibles mejoras

    Durante las primeras iteraciones, todo lo que vemos es ruido aleatorio:

    Aquí, las redes aún no han aprendido nada. Sin embargo, después de solo un par de minutos, ¡ya podemos ver cómo nuestros dígitos van tomando forma!

    Recursos

    Si quieres jugar con el código, está en GitHub!

    Etiquetas:

    Deja una respuesta

    Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *