Cómo usar TensorFlow con Java

    Introducción

    El Machine Learning está ganando popularidad y uso en todo el mundo. Ya ha cambiado drásticamente la forma en que se crean ciertas aplicaciones y es probable que continúe siendo una parte enorme (y en aumento) de nuestra vida diaria.

    No hay forma de endulzarlo, el Machine Learning no es simple. Es bastante abrumador y puede parecer muy complejo para muchos.

    Empresas como Google se encargaron de acercar los conceptos de Machine Learning a los desarrolladores y permitirles dar sus primeros pasos gradualmente, con una gran ayuda.

    Por tanto, marcos como TensorFlow nació.

    ¿Qué es TensorFlow?

    TensorFlow es un marco de Machine Learning de código abierto desarrollado por Google en Python y C ++.

    Ayuda a los desarrolladores a adquirir datos, preparar y entrenar modelos, predecir estados futuros y realizar Machine Learning a gran escala.

    Con él, podemos entrenar y ejecutar redes neuronales profundas que se utilizan con mayor frecuencia para el reconocimiento óptico de caracteres, el reconocimiento / clasificación de imágenes, el procesamiento del lenguaje natural, etc.

    Tensores y operaciones

    TensorFlow se basa en gráficos computacionales, que puedes imaginar como un gráfico clásico con nodes y bordes.

    Cada node se conoce como una operación, y toman cero o más tensores y producen cero o más tensores. Una operación puede ser muy simple, como una suma básica, pero también pueden ser muy complejas.

    Los tensores se representan como bordes del gráfico y son la unidad de datos central. Realizamos diferentes funciones en estos tensores a medida que los alimentamos a las operaciones. Pueden tener una o varias dimensiones, que a veces se denominan sus rangos – (Escalar: rango 0, Vector: rango 1, Matriz: rango 2)

    Estos datos fluyen a través del gráfico computacional a través de tensores, afectados por las operaciones, de ahí el nombre TensorFlow.

    Los tensores pueden almacenar datos en cualquier número de dimensiones y hay tres tipos principales de tensores: marcadores de posición, variables y constantes.

    Instalación de TensorFlow

    Con Maven, instalar TensorFlow es tan fácil como incluir la dependencia:

    <dependency>
      <groupId>org.tensorflow</groupId>
      <artifactId>tensorflow</artifactId>
      <version>1.13.1</version>
    </dependency>
    

    Si su dispositivo admite Soporte de GPU, luego use estas dependencias:

    <dependency>
      <groupId>org.tensorflow</groupId>
      <artifactId>libtensorflow</artifactId>
      <version>1.13.1</version>
    </dependency>
    
    <dependency>
      <groupId>org.tensorflow</groupId>
      <artifactId>libtensorflow_jni_gpu</artifactId>
      <version>1.13.1</version>
    </dependency>
    

    Puede verificar la versión de TensorFlow instalada actualmente usando el TensorFlow objeto:

    System.out.println(TensorFlow.version());
    

    API de TensorFlow Java

    Las ofertas de TensorFlow de la API de Java están incluidas en org.tensorflow paquete. Actualmente es experimental, por lo que no se garantiza que sea estable.

    Tenga en cuenta que el único lenguaje totalmente compatible con TensorFlow es Python y que la API de Java no es tan funcional.

    Nos presenta nuevas clases, una interfaz, una enumeración y una excepción.

    Clases

    Las nuevas clases introducidas a través de la API son:

    • Graph: Un gráfico de flujo de datos que representa un cálculo de TensorFlow
    • Operation: Un node Graph que realiza cálculos en tensores
    • OperationBuilder: Una clase de constructor para operaciones
    • Output<T>: Un identificador simbólico de un tensor producido por una Operación
    • SavedModelBundle: Representa un modelo cargado desde el almacenamiento.
    • SavedModelBundle.Loader: Proporciona opciones para cargar un modelo guardado
    • Server: Un servidor TensorFlow en proceso, para usar en entrenamiento distribuido
    • Session: Controlador para la ejecución de gráficos
    • Session.Run: Tensores de salida y metadatos obtenidos al ejecutar una sesión
    • Session.Runner: Ejecutar operaciones y evaluar tensores
    • Shape: La forma posiblemente parcialmente conocida de un tensor producido por una operación
    • Tensor<T>: Una matriz multidimensional de tipo estático cuyos elementos son de un tipo descrito por T
    • TensorFlow: Métodos de utilidad estáticos que describen el tiempo de ejecución de TensorFlow
    • Tensors: Métodos de fábrica con seguridad de tipos para crear objetos Tensor
    Enum
    • DataType: Representa el tipo de elementos en un tensor como una enumeración
    Interfaz
    • Operand<T>: Interfaz implementada por operandos de una operación de TensorFlow
    Excepción
    • TensorFlowException: Se lanza una excepción sin marcar al ejecutar TensorFlow Graphs

    Si comparamos todo esto con el módulo tf en Python, hay una diferencia obvia. La API de Java no tiene casi la misma cantidad de funcionalidad, al menos por ahora.

    Gráficos

    Como se mencionó anteriormente, TensorFlow se basa en gráficos computacionales, donde org.tensorflow.Graph es la implementación de Java.

    Nota: Sus instancias son seguras para subprocesos, aunque necesitamos liberar explícitamente los recursos utilizados por Graph una vez que hayamos terminado con él.

    Comencemos con un gráfico vacío:

    Graph graph = new Graph();
    

    Este gráfico no significa mucho, está vacío. Para hacer algo con él, primero debemos cargarlo con Operations.

    Para cargarlo con operaciones, usamos el opBuilder() método, que devuelve un OperationBuilder objeto que agregará las operaciones a nuestro gráfico una vez que llamemos al .build() método.

    Constantes

    Agreguemos una constante a nuestro gráfico:

    Operation x = graph.opBuilder("Const", "x")
                   .setAttr("dtype", DataType.FLOAT)
                   .setAttr("value", Tensor.create(3.0f))
                   .build(); 
    

    Marcadores de posición

    Los marcadores de posición son un «tipo» de variable que no tiene un valor en la declaración. Sus valores se asignarán en una fecha posterior. Esto nos permite construir gráficos con operaciones sin ningún dato real:

    Operation y = graph.opBuilder("Placeholder", "y")
            .setAttr("dtype", DataType.FLOAT)
            .build();
    

    Funciones

    Y ahora, finalmente, para redondear esto, necesitamos agregar ciertas funciones. Estos pueden ser tan simples como multiplicar, dividir o sumar, o tan complejos como multiplicaciones de matrices. Al igual que antes, definimos funciones usando el .opBuilder() método:

    Operation xy = graph.opBuilder("Mul", "xy")
      .addInput(x.output(0))
      .addInput(y.output(0))
      .build();         
    

    Nota: Estamos usando output(0) como tensor puede tener más de una salida.

    Visualización de gráficos

    Lamentablemente, la API de Java aún no incluye ninguna herramienta que le permita visualizar gráficos como lo haría en Python. Cuando la API de Java se actualice, también lo hará este artículo.

    Sesiones

    Como se mencionó antes, un Session es el conductor de un GraphEjecución. Encapsula el entorno en el que Operationsy Graphs se ejecutan para calcular Tensors.

    Lo que esto significa es que los tensores en nuestro gráfico que construimos en realidad no tienen ningún valor, ya que no ejecutamos el gráfico dentro de una sesión.

    Primero agreguemos el gráfico a una sesión:

    Session session = new Session(graph);
    

    Nuestro cálculo simplemente multiplica el x y y valor. Para ejecutar nuestro gráfico y calcularlo, fetch() la xy operación y alimentarlo x y y valores:

    Tensor tensor = session.runner().fetch("xy").feed("x", Tensor.create(5.0f)).feed("y", Tensor.create(2.0f)).run().get(0);
    System.out.println(tensor.floatValue());
    

    Ejecutar este fragmento de código producirá:

    10.0f
    

    Guardar modelos en Python y cargar en Java

    Esto puede sonar un poco extraño, pero dado que Python es el único lenguaje bien soportado, la API de Java todavía no tiene la funcionalidad para guardar modelos.

    Esto significa que la API de Java está diseñada solo para el caso de uso de servicio, al menos hasta que sea totalmente compatible con TensorFlow. Al menos, podemos entrenar y guardar modelos en Python y luego cargarlos en Java para servirlos, usando el SavedModelBundle clase:

    SavedModelBundle model = SavedModelBundle.load("./model", "serve"); 
    Tensor tensor = model.session().runner().fetch("xy").feed("x", Tensor.create(5.0f)).feed("y", Tensor.create(2.0f)).run().get(0);  
    
    System.out.println(tensor.floatValue());
    

    Conclusión

    TensorFlow es un marco potente, robusto y ampliamente utilizado. Se está mejorando constantemente y últimamente se ha introducido en nuevos lenguajes, incluidos Java y JavaScript.

    Aunque la API de Java aún no tiene tanta funcionalidad como TensorFlow para Python, aún puede servir como una buena introducción a TensorFlow para desarrolladores de Java.

     

    Etiquetas:

    Deja una respuesta

    Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *