Cómo usar TensorFlow con Java

C

Introducción

El Machine Learning está ganando popularidad y uso en todo el mundo. Ya ha cambiado drásticamente la forma en que se crean ciertas aplicaciones y es probable que continúe siendo una parte enorme (y en aumento) de nuestra vida diaria.

No hay forma de endulzarlo, el Machine Learning no es simple. Es bastante abrumador y puede parecer muy complejo para muchos.

Empresas como Google se encargaron de acercar los conceptos de Machine Learning a los desarrolladores y permitirles dar sus primeros pasos gradualmente, con una gran ayuda.

Por tanto, marcos como TensorFlow nació.

¿Qué es TensorFlow?

TensorFlow es un marco de Machine Learning de código abierto desarrollado por Google en Python y C ++.

Ayuda a los desarrolladores a adquirir datos, preparar y entrenar modelos, predecir estados futuros y realizar Machine Learning a gran escala.

Con él, podemos entrenar y ejecutar redes neuronales profundas que se utilizan con mayor frecuencia para el reconocimiento óptico de caracteres, el reconocimiento / clasificación de imágenes, el procesamiento del lenguaje natural, etc.

Tensores y operaciones

TensorFlow se basa en gráficos computacionales, que puedes imaginar como un gráfico clásico con nodes y bordes.

Cada node se conoce como una operación, y toman cero o más tensores y producen cero o más tensores. Una operación puede ser muy simple, como una suma básica, pero también pueden ser muy complejas.

Los tensores se representan como bordes del gráfico y son la unidad de datos central. Realizamos diferentes funciones en estos tensores a medida que los alimentamos a las operaciones. Pueden tener una o varias dimensiones, que a veces se denominan sus rangos – (Escalar: rango 0, Vector: rango 1, Matriz: rango 2)

Estos datos fluyen a través del gráfico computacional a través de tensores, afectados por las operaciones, de ahí el nombre TensorFlow.

Los tensores pueden almacenar datos en cualquier número de dimensiones y hay tres tipos principales de tensores: marcadores de posición, variables y constantes.

Instalación de TensorFlow

Con Maven, instalar TensorFlow es tan fácil como incluir la dependencia:

<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>tensorflow</artifactId>
  <version>1.13.1</version>
</dependency>

Si su dispositivo admite Soporte de GPU, luego use estas dependencias:

<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>libtensorflow</artifactId>
  <version>1.13.1</version>
</dependency>

<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>libtensorflow_jni_gpu</artifactId>
  <version>1.13.1</version>
</dependency>

Puede verificar la versión de TensorFlow instalada actualmente usando el TensorFlow objeto:

System.out.println(TensorFlow.version());

API de TensorFlow Java

Las ofertas de TensorFlow de la API de Java están incluidas en org.tensorflow paquete. Actualmente es experimental, por lo que no se garantiza que sea estable.

Tenga en cuenta que el único lenguaje totalmente compatible con TensorFlow es Python y que la API de Java no es tan funcional.

Nos presenta nuevas clases, una interfaz, una enumeración y una excepción.

Clases

Las nuevas clases introducidas a través de la API son:

  • Graph: Un gráfico de flujo de datos que representa un cálculo de TensorFlow
  • Operation: Un node Graph que realiza cálculos en tensores
  • OperationBuilder: Una clase de constructor para operaciones
  • Output<T>: Un identificador simbólico de un tensor producido por una Operación
  • SavedModelBundle: Representa un modelo cargado desde el almacenamiento.
  • SavedModelBundle.Loader: Proporciona opciones para cargar un modelo guardado
  • Server: Un servidor TensorFlow en proceso, para usar en entrenamiento distribuido
  • Session: Controlador para la ejecución de gráficos
  • Session.Run: Tensores de salida y metadatos obtenidos al ejecutar una sesión
  • Session.Runner: Ejecutar operaciones y evaluar tensores
  • Shape: La forma posiblemente parcialmente conocida de un tensor producido por una operación
  • Tensor<T>: Una matriz multidimensional de tipo estático cuyos elementos son de un tipo descrito por T
  • TensorFlow: Métodos de utilidad estáticos que describen el tiempo de ejecución de TensorFlow
  • Tensors: Métodos de fábrica con seguridad de tipos para crear objetos Tensor
Enum
  • DataType: Representa el tipo de elementos en un tensor como una enumeración
Interfaz
  • Operand<T>: Interfaz implementada por operandos de una operación de TensorFlow
Excepción
  • TensorFlowException: Se lanza una excepción sin marcar al ejecutar TensorFlow Graphs

Si comparamos todo esto con el módulo tf en Python, hay una diferencia obvia. La API de Java no tiene casi la misma cantidad de funcionalidad, al menos por ahora.

Gráficos

Como se mencionó anteriormente, TensorFlow se basa en gráficos computacionales, donde org.tensorflow.Graph es la implementación de Java.

Nota: Sus instancias son seguras para subprocesos, aunque necesitamos liberar explícitamente los recursos utilizados por Graph una vez que hayamos terminado con él.

Comencemos con un gráfico vacío:

Graph graph = new Graph();

Este gráfico no significa mucho, está vacío. Para hacer algo con él, primero debemos cargarlo con Operations.

Para cargarlo con operaciones, usamos el opBuilder() método, que devuelve un OperationBuilder objeto que agregará las operaciones a nuestro gráfico una vez que llamemos al .build() método.

Constantes

Agreguemos una constante a nuestro gráfico:

Operation x = graph.opBuilder("Const", "x")
               .setAttr("dtype", DataType.FLOAT)
               .setAttr("value", Tensor.create(3.0f))
               .build(); 

Marcadores de posición

Los marcadores de posición son un “tipo” de variable que no tiene un valor en la declaración. Sus valores se asignarán en una fecha posterior. Esto nos permite construir gráficos con operaciones sin ningún dato real:

Operation y = graph.opBuilder("Placeholder", "y")
        .setAttr("dtype", DataType.FLOAT)
        .build();

Funciones

Y ahora, finalmente, para redondear esto, necesitamos agregar ciertas funciones. Estos pueden ser tan simples como multiplicar, dividir o sumar, o tan complejos como multiplicaciones de matrices. Al igual que antes, definimos funciones usando el .opBuilder() método:

Operation xy = graph.opBuilder("Mul", "xy")
  .addInput(x.output(0))
  .addInput(y.output(0))
  .build();         

Nota: Estamos usando output(0) como tensor puede tener más de una salida.

Visualización de gráficos

Lamentablemente, la API de Java aún no incluye ninguna herramienta que le permita visualizar gráficos como lo haría en Python. Cuando la API de Java se actualice, también lo hará este artículo.

Sesiones

Como se mencionó antes, un Session es el conductor de un GraphEjecución. Encapsula el entorno en el que Operationsy Graphs se ejecutan para calcular Tensors.

Lo que esto significa es que los tensores en nuestro gráfico que construimos en realidad no tienen ningún valor, ya que no ejecutamos el gráfico dentro de una sesión.

Primero agreguemos el gráfico a una sesión:

Session session = new Session(graph);

Nuestro cálculo simplemente multiplica el x y y valor. Para ejecutar nuestro gráfico y calcularlo, fetch() la xy operación y alimentarlo x y y valores:

Tensor tensor = session.runner().fetch("xy").feed("x", Tensor.create(5.0f)).feed("y", Tensor.create(2.0f)).run().get(0);
System.out.println(tensor.floatValue());

Ejecutar este fragmento de código producirá:

10.0f

Guardar modelos en Python y cargar en Java

Esto puede sonar un poco extraño, pero dado que Python es el único lenguaje bien soportado, la API de Java todavía no tiene la funcionalidad para guardar modelos.

Esto significa que la API de Java está diseñada solo para el caso de uso de servicio, al menos hasta que sea totalmente compatible con TensorFlow. Al menos, podemos entrenar y guardar modelos en Python y luego cargarlos en Java para servirlos, usando el SavedModelBundle clase:

SavedModelBundle model = SavedModelBundle.load("./model", "serve"); 
Tensor tensor = model.session().runner().fetch("xy").feed("x", Tensor.create(5.0f)).feed("y", Tensor.create(2.0f)).run().get(0);  

System.out.println(tensor.floatValue());

Conclusión

TensorFlow es un marco potente, robusto y ampliamente utilizado. Se está mejorando constantemente y últimamente se ha introducido en nuevos lenguajes, incluidos Java y JavaScript.

Aunque la API de Java aún no tiene tanta funcionalidad como TensorFlow para Python, aún puede servir como una buena introducción a TensorFlow para desarrolladores de Java.

 

About the author

Ramiro de la Vega

Bienvenido a Pharos.sh

Soy Ramiro de la Vega, Estadounidense con raíces Españolas. Empecé a programar hace casi 20 años cuando era muy jovencito.

Espero que en mi web encuentres la inspiración y ayuda que necesitas para adentrarte en el fantástico mundo de la programación y conseguir tus objetivos por difíciles que sean.

Add comment

Sobre mi

Últimos Post

Etiquetas

Esta web utiliza cookies propias para su correcto funcionamiento. Al hacer clic en el botón Aceptar, aceptas el uso de estas tecnologías y el procesamiento de tus datos para estos propósitos. Más información
Privacidad