Introducci贸n a las GAN con Python y TensorFlow

    Introducci贸n

    Los modelos generativos son una familia de arquitecturas de IA cuyo objetivo es crear muestras de datos desde cero. Lo logran capturando las distribuciones de datos del tipo de cosas que queremos generar.

    Este tipo de modelos se est谩n investigando intensamente y hay una gran cantidad de publicidad a su alrededor. Solo mire el cuadro que muestra la cantidad de art铆culos publicados en el campo durante los 煤ltimos a帽os:

    Desde 2014, cuando el primer art铆culo sobre redes generativas de confrontaci贸n se public贸, los modelos generativos se est谩n volviendo incre铆blemente poderosos y ahora podemos generar muestras de datos hiperrealistas para una amplia gama de distribuciones: im谩genes, videos, m煤sica, escritos, etc.

    A continuaci贸n, se muestran algunos ejemplos de im谩genes generadas por una GAN:

    驴Qu茅 son los modelos generativos?

    El marco de GAN

    El marco m谩s exitoso propuesto para modelos generativos, al menos en los 煤ltimos a帽os, toma el nombre de Redes generativas antag贸nicas (GAN).

    En pocas palabras, un GAN se compone de dos modelos separados, representados por redes neuronales: un generador G y un discriminador D. El objetivo del discriminador es saber si una muestra de datos proviene de una distribuci贸n de datos real, o si en su lugar se genera por G.

    El objetivo del generador es generar muestras de datos para enga帽ar al discriminador.

    El generador no es m谩s que una red neuronal profunda. Toma como entrada un vector de ruido aleatorio (generalmente gaussiano o de una distribuci贸n uniforme) y genera una muestra de datos de la distribuci贸n que queremos capturar.

    El discriminador es, nuevamente, solo una red neuronal. Su objetivo es, como su nombre indica, discriminar entre muestras reales y falsas. En consecuencia, su entrada es una muestra de datos, ya sea proveniente del generador o de la distribuci贸n de datos real.

    La salida es un n煤mero simple, que representa la probabilidad de que la entrada sea real. Una probabilidad alta significa que el discriminador est谩 seguro de que las muestras que le est谩n dando son genuinas. Por el contrario, una probabilidad baja muestra una alta confianza en el hecho de que la muestra proviene de la red de generadores:

    Imag铆nese un falsificador de arte que intenta crear obras de arte falsas y un cr铆tico de arte que necesita distinguir entre pinturas adecuadas y falsas.

    En este escenario, el cr铆tico act煤a como nuestro discriminador y el falsificador es el generador, tomando retroalimentaci贸n del cr铆tico para mejorar sus habilidades y hacer que su arte forjado parezca m谩s convincente:

    Formaci贸n

    Entrenar a un GAN puede ser algo doloroso. La inestabilidad del entrenamiento siempre ha sido un problema y muchas investigaciones se han centrado en hacer que el entrenamiento sea m谩s estable.

    La funci贸n objetivo b谩sica de un modelo vanilla GAN es la siguiente:

    Aqu铆, re se refiere a la red discriminadora, mientras que GRAMO obviamente se refiere al generador.

    Como muestra la f贸rmula, el generador se optimiza para confundir al discriminador al m谩ximo, al tratar de hacer que genere altas probabilidades de muestras de datos falsas.

    Por el contrario, el discriminador intenta mejorar la distinci贸n entre muestras procedentes de G y muestras procedentes de la distribuci贸n real.

    El t茅rmino adversario proviene exactamente de la forma en que se entrenan los GANS, enfrentando a las dos redes entre s铆.

    Una vez que hemos entrenado nuestro modelo, el discriminador ya no es necesario. Todo lo que tenemos que hacer es alimentar al generador con un vector de ruido aleatorio y, con suerte, obtendremos una muestra de datos artificial y realista como resultado.

    Problemas de GAN

    Entonces, 驴por qu茅 son tan dif铆ciles de entrenar las GAN? Como se indic贸 anteriormente, las GAN son muy dif铆ciles de entrenar en su forma b谩sica. Veremos brevemente por qu茅 este es el caso.

    Equilibrio de Nash dif铆cil de alcanzar

    Dado que estas dos redes se disparan informaci贸n entre s铆, se podr铆a representar como un juego en el que uno adivina si la entrada es real o no.

    El marco GAN es un juego no convexo, de dos jugadores y no cooperativo con par谩metros continuos de alta dimensi贸n, en el que cada jugador quiere minimizar su funci贸n de coste. El 贸ptimo de este proceso toma el nombre de Equilibrio de Nash – donde cada jugador no se desempe帽ar谩 mejor cambiando una estrategia, dado que el otro jugador no cambia su estrategia.

    Sin embargo, las GAN se entrenan t铆picamente usando t茅cnicas de descenso de gradientes que est谩n dise帽adas para encontrar el valor bajo de una funci贸n de costo y no encontrar el Equilibrio de Nash de un juego.

    Colapso de modo

    La mayor铆a de las distribuciones de datos son multimodales. Toma el Conjunto de datos MNIST: hay 10 “modos” de datos, que se refieren a los diferentes d铆gitos entre 0 y 9.

    Un buen modelo generativo podr铆a producir muestras con suficiente variabilidad, pudiendo as铆 generar muestras de todas las diferentes clases.

    Sin embargo, esto no siempre sucede.

    Digamos que el generador se vuelve realmente bueno para producir el d铆gito “3”. Si las muestras producidas son lo suficientemente convincentes, el discriminador probablemente les asignar谩 altas probabilidades.

    Como resultado, el generador se ver谩 impulsado a producir muestras que provengan de ese modo espec铆fico, ignorando las otras clases la mayor parte del tiempo. B谩sicamente enviar谩 spam al mismo n煤mero y con cada n煤mero que pase el discriminador, este comportamiento solo se aplicar谩 m谩s.

    Gradiente decreciente

    Muy similar al ejemplo anterior, el discriminador puede tener demasiado 茅xito en distinguir muestras de datos. Cuando eso es cierto, el gradiente del generador desaparece, comienza a aprender cada vez menos y no logra converger.

    Este desequilibrio, al igual que el anterior, se puede producir si entrenamos las redes por separado. La evoluci贸n de la red neuronal puede ser bastante impredecible, lo que puede llevar a que una est茅 por delante de la otra por una milla. Si los capacitamos juntos, principalmente nos aseguramos de que estas cosas no sucedan.

    Lo 煤ltimo

    Ser铆a imposible ofrecer una visi贸n completa de todas las mejoras y desarrollos que hicieron que las GAN fueran m谩s poderosas y estables en los 煤ltimos a帽os.

    En cambio, lo que har茅 es compilar una lista de las arquitecturas y t茅cnicas m谩s exitosas, proporcionando enlaces a recursos relevantes para profundizar m谩s.

    DCGAN

    Las GAN convolucionales profundas (DCGAN) introdujeron convoluciones en las redes generadoras y discriminadoras.

    Sin embargo, esto no fue simplemente una cuesti贸n de agregar capas convolucionales al modelo, ya que el entrenamiento se volvi贸 a煤n m谩s inestable.

    Se tuvieron que aplicar varios trucos para que los DCGAN fueran 煤tiles:

    • La normalizaci贸n de lotes se aplic贸 tanto al generador como a la red discriminadora
    • La deserci贸n se utiliza como t茅cnica de regularizaci贸n
    • El generador necesitaba una forma de aumentar la muestra del vector de entrada aleatoria a una imagen de salida. Aqu铆 se emplea la transposici贸n de capas convolucionales
    • LeakyRelu y TanH las activaciones se utilizan en ambas redes

    WGAN

    GAN de Wasserstein (WGAN) tienen como objetivo mejorar la estabilidad del entrenamiento. Hay una gran cantidad de matem谩ticas detr谩s de este tipo de modelo. Se puede encontrar una explicaci贸n m谩s accesible aqu铆.

    Las ideas b谩sicas aqu铆 fueron proponer una nueva funci贸n de costos que tenga un gradiente m谩s suave en todas partes.

    La nueva funci贸n de costo usa una m茅trica llamada distancia de Wasserstein, que tiene un gradiente m谩s suave en todas partes.

    Como resultado, el discriminador, que ahora se llama cr铆tico, genera valores de confianza que ya no deben interpretarse como una probabilidad. Los valores altos significan que el modelo conf铆a en que la entrada es real.

    Dos mejoras significativas para WGAN son:

    • No tiene signos de colapso de modo en los experimentos.
    • El generador a煤n puede aprender cuando el cr铆tico se desempe帽a bien

    SAGANs

    GAN de auto-atenci贸n (SAGAN) introducen un mecanismo de atenci贸n al marco de GAN.

    Los mecanismos de atenci贸n permiten utilizar informaci贸n global de forma local. Lo que esto significa es que podemos capturar el significado de diferentes partes de una imagen y usar esa informaci贸n para producir mejores muestras.

    Esto proviene de la observaci贸n de que las convoluciones son bastante malas para capturar dependencias a largo plazo en muestras de entrada, ya que la convoluci贸n es una operaci贸n local cuyo campo receptivo depende del tama帽o espacial del kernel.

    Esto significa que, por ejemplo, no es posible que una salida en la posici贸n superior izquierda de una imagen tenga relaci贸n con la salida en la parte inferior derecha.

    Una forma de resolver este problema ser铆a utilizar n煤cleos con tama帽os m谩s grandes para capturar m谩s informaci贸n. Sin embargo, esto har铆a que el modelo fuera computacionalmente ineficaz y muy lento de entrenar.

    La auto-atenci贸n resuelve este problema, proporcionando una forma eficiente de capturar informaci贸n global y usarla localmente cuando pueda resultar 煤til.

    BigGANs

    BigGANs son, en el momento de redactar este art铆culo, considerados m谩s o menos de 煤ltima generaci贸n, en lo que respecta a la calidad de las muestras generadas.

    Lo que hicieron los investigadores aqu铆 fue reunir todo lo que hab铆a estado funcionando hasta ese momento y luego escalarlo masivamente.
    Su modelo de referencia era de hecho un SAGAN, al que a帽adieron algunos trucos para mejorar la estabilidad.

    Demostraron que las GAN se benefician enormemente del escalado, incluso cuando no se introducen m谩s mejoras funcionales en el modelo, como se cita en el documento original:

    Hemos demostrado que las Redes Adversarias Generativas entrenadas para modelar im谩genes naturales de m煤ltiples categor铆as se benefician enormemente de la ampliaci贸n, tanto en t茅rminos de fidelidad como de variedad de las muestras generadas. Como resultado, nuestros modelos establecen un nuevo nivel de rendimiento entre los modelos ImageNet GAN, mejorando el estado de la t茅cnica por un amplio margen.

    Una GAN simple en Python

    Implementaci贸n de c贸digo

    Dicho todo esto, sigamos adelante e implementemos un GAN simple que genera d铆gitos del 0 al 9, un ejemplo bastante cl谩sico:

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec
    import os
    
    # Sample z from uniform distribution
    def sample_Z(m, n):
        return np.random.uniform(-1., 1., size=[m, n])
    
    def plot(samples):
        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)
    
        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
    
        return fig
    

    Ahora podemos definir el marcador de posici贸n para nuestras muestras de entrada y vectores de ruido:

    # Input image, for discriminator model.
    X = tf.placeholder(tf.float32, shape=[None, 784])
    
    # Input noise for generator.
    Z = tf.placeholder(tf.float32, shape=[None, 100])
    

    Ahora, definimos nuestras redes generadoras y discriminadoras. Son perceptrones simples con una sola capa oculta.

    Usamos relu activaciones en las neuronas de la capa oculta, y sigmoides para las capas de salida.

    def generator(z):
        with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
            x = tf.layers.dense(z, 128, activation=tf.nn.relu)
            x = tf.layers.dense(z, 784)
            x = tf.nn.sigmoid(x)
        return x
    
    def discriminator(x):
        with tf.variable_scope("discriminator", reuse=tf.AUTO_REUSE):
            x = tf.layers.dense(x, 128, activation=tf.nn.relu)
            x = tf.layers.dense(x, 1)
            x = tf.nn.sigmoid(x)
        return x
    

    Ahora podemos definir nuestros modelos, funciones de p茅rdida y optimizadores:

    # Generator model
    G_sample = generator(Z)
    
    # Discriminator models
    D_real = discriminator(X)
    D_fake = discriminator(G_sample)
    
    
    # Loss function
    D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
    G_loss = -tf.reduce_mean(tf.log(D_fake))
    
    # Select parameters
    disc_vars = [var for var in tf.trainable_variables() if var.name.startswith("disc")]
    gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("gen")]
    
    # Optimizers
    D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=disc_vars)
    G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=gen_vars)
    

    Finalmente, podemos escribir una rutina de entrenamiento. En cada iteraci贸n, realizamos un paso de optimizaci贸n para el discriminador y otro para el generador.

    Cada 100 iteraciones guardamos algunas muestras generadas para que podamos ver el progreso.

    # Batch size
    mb_size = 128
    
    # Dimension of input noise
    Z_dim = 100
    
    mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
    
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    if not os.path.exists('out2/'):
        os.makedirs('out2/')
    
    i = 0
    
    for it in range(1000000):
    
        # Save generated images every 1000 iterations.
        if it % 1000 == 0:
            samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})
    
            fig = plot(samples)
            plt.savefig('out2/{}.png'.format(str(i).zfill(3)), bbox_inches="tight")
            i += 1
            plt.close(fig)
    
    
        # Get next batch of images. Each batch has mb_size samples.
        X_mb, _ = mnist.train.next_batch(mb_size)
    
    
        # Run disciminator solver
        _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})
    
        # Run generator solver
        _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})
    
        # Print loss
        if it % 1000 == 0:
            print('Iter: {}'.format(it))
            print('D loss: {:.4}'. format(D_loss_curr))
    

    Resultados y posibles mejoras

    Durante las primeras iteraciones, todo lo que vemos es ruido aleatorio:

    Aqu铆, las redes a煤n no han aprendido nada. Sin embargo, despu茅s de solo un par de minutos, 隆ya podemos ver c贸mo nuestros d铆gitos van tomando forma!

    Recursos

    Si quieres jugar con el c贸digo, est谩 en GitHub!

    Etiquetas:

    Deja una respuesta

    Tu direcci贸n de correo electr贸nico no ser谩 publicada. Los campos obligatorios est谩n marcados con *