In this tutorial we will create a new Java project and set up the libraries required to use TensorFlow.
TensorFlow is an end-to-end, open-source platform for machine learning. It comes with a comprehensive, flexible ecosystem of tools, libraries, and community resources that lets researchers push the state of the art in machine learning and lets developers quickly build and deploy machine-learning-powered applications.
Below, we explain how to create a Hello World project that performs a simple calculation, as an example of how TensorFlow works with Java. We import TensorFlow from the Maven repository. The source code for this example is available on GitHub.
TensorFlow can run on the Java Virtual Machine to build, train, and run machine learning models, on any platform supported by its native libraries. It requires at least JDK 8; in this tutorial, we used JDK 11. For data scientists, TensorFlow is a must-have, especially for work built on computation graphs. Because a Java runtime is already installed on many devices, being able to use TensorFlow from Java is an added advantage.
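Because the minimum JDK matters, it can help to check it before loading TensorFlow at all. The sketch below is a stdlib-only program (the class name JdkCheck is hypothetical) that reads the standard java.specification.version system property and compares it with the JDK 8 minimum stated above. It handles both the old "1.8" numbering and the "11"-style numbering used from Java 9 onward.

```java
public class JdkCheck {
    // Minimum JDK stated in this tutorial for TensorFlow Java
    static final int MIN_JDK = 8;

    // Converts a java.specification.version string to a major version:
    // "1.8" -> 8 (pre-Java 9 scheme), "11" -> 11 (Java 9+ scheme)
    static int javaMajorVersion(String spec) {
        String s = spec.startsWith("1.") ? spec.substring(2) : spec;
        int dot = s.indexOf('.');
        return Integer.parseInt(dot >= 0 ? s.substring(0, dot) : s);
    }

    public static void main(String[] args) {
        int major = javaMajorVersion(System.getProperty("java.specification.version"));
        if (major < MIN_JDK) {
            System.err.println("JDK " + major + " found, but at least JDK " + MIN_JDK + " is required");
            System.exit(1);
        }
        System.out.println("JDK " + major + " is new enough");
    }
}
```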
After creating an empty Java project, we add TensorFlow to the Maven build by declaring its dependency in our pom.xml file. If you are new to Maven, read its beginner tutorials to understand the basic concepts. The dependency section of our pom.xml looks like this:
<!-- Include TensorFlow (pure CPU only) for all supported platforms -->
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-core-platform</artifactId>
    <!-- The code below uses the 0.3.x API (ConcreteFunction.create, TInt32.scalarOf) -->
    <version>0.3.3</version>
</dependency>
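Once the dependency is declared, one way to confirm that Maven actually put TensorFlow on the classpath is a reflection lookup of org.tensorflow.TensorFlow, the class whose version() method the Hello World example calls. This is a stdlib-only sketch with a hypothetical class name, ClasspathCheck; it loads the class without initializing it, so TensorFlow's native library is not loaded just to run the check.

```java
public class ClasspathCheck {
    // Returns true if the named class can be found, without running its static initializer
    static boolean isOnClasspath(String className) {
        try {
            Class.forName(className, false, ClasspathCheck.class.getClassLoader());
            return true;
        } catch (ClassNotFoundException e) {
            return false;
        }
    }

    public static void main(String[] args) {
        // The class the Hello World example calls version() on
        String tfClass = "org.tensorflow.TensorFlow";
        if (isOnClasspath(tfClass)) {
            System.out.println(tfClass + " found: the Maven dependency resolved");
        } else {
            System.out.println(tfClass + " missing: check the pom.xml dependency and rebuild");
        }
    }
}
```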
In this section, we show how to square a number using TensorFlow in Java. Computations are built through the Ops class, which is TensorFlow Java's API for creating operations. In our code, the multiply method uses an Ops instance to add a multiplication operation to a graph, and ConcreteFunction.create turns that graph into a function we can call. To pass a value into tf.math.mul(), we use the placeholder op, which holds a value that is fed into the computation when the function is called. The example is shown below:
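import org.tensorflow.ConcreteFunction;
import org.tensorflow.Signature;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.math.Mul;
import org.tensorflow.types.TInt32;

public class App {
    public static void main(String[] args) {
        System.out.println("Hello TensorFlow " + TensorFlow.version());
        // Build a callable function from the graph defined in multiply();
        // try-with-resources releases the function and tensors afterwards
        try (ConcreteFunction multiply = ConcreteFunction.create(App::multiply);
             TInt32 x = TInt32.scalarOf(10);
             Tensor squared = multiply.call(x)) {
            System.out.println(x.getInt() + " squared is " + ((TInt32) squared).getInt());
        } catch (Exception ex) {
            System.out.println(ex.getMessage());
        }
    }

    // Graph definition: the placeholder input "x" is multiplied by itself
    private static Signature multiply(Ops tf) {
        Placeholder<TInt32> x = tf.placeholder(TInt32.class);
        Mul<TInt32> squared = tf.math.mul(x, x);
        return Signature.builder().input("x", x).output("squared", squared).build();
    }
}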
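The placeholder pattern can be hard to picture, so here is a plain-Java toy model of it that does not use TensorFlow at all. A definition with one named input slot ("x") is built first, and a value is supplied only when it is called, mirroring how the placeholder in multiply() receives its value through multiply.call(x). ToySignature, square(), and the feeds map are hypothetical names for illustration; they are not part of TensorFlow's API.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.IntUnaryOperator;

public class PlaceholderSketch {
    // Toy stand-in (not TensorFlow): one named input slot plus a deferred computation.
    // Nothing is computed when it is built; the value arrives only at call time.
    static final class ToySignature {
        private final String inputName;
        private final IntUnaryOperator body;

        ToySignature(String inputName, IntUnaryOperator body) {
            this.inputName = inputName;
            this.body = body;
        }

        // "Feeds" a value into the named slot and evaluates the computation
        int call(Map<String, Integer> feeds) {
            Integer value = feeds.get(inputName);
            if (value == null) {
                throw new IllegalArgumentException("no value fed for input \"" + inputName + "\"");
            }
            return body.applyAsInt(value);
        }
    }

    // Mirrors the tutorial's multiply(Ops tf): input "x", output x * x
    static ToySignature square() {
        return new ToySignature("x", x -> x * x);
    }

    public static void main(String[] args) {
        ToySignature sig = square();
        // The same definition is reused with different fed values
        for (int v : new int[] {3, 10}) {
            Map<String, Integer> feeds = new HashMap<>();
            feeds.put("x", v);
            System.out.println(v + " squared is " + sig.call(feeds));
        }
    }
}
```

Running main prints "3 squared is 9" followed by "10 squared is 100". The point of the design is the separation: the computation is described once, and inputs are bound later, which is the role the placeholder plays in the TensorFlow version.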
In this tutorial, we showed how easy it is to add TensorFlow to a Java project using Maven, and we walked through a simple example that squares a number using TensorFlow operations.