AI with Java/DeepLearning4j

    DeepLearning4j

    A production-grade, open-source deep learning library for the JVM. Build, train, and deploy neural networks natively in Java.

    DeepLearning4j (DL4J) is a distributed deep learning framework written for Java and Scala. Unlike Spring AI or LangChain4j which focus on LLM integration, DL4J lets you build and train your own neural networks from scratch—CNNs for image classification, RNNs for sequence modeling, and more. It's the TensorFlow/PyTorch equivalent for the JVM.

    DL4J integrates with Spark for distributed training, uses ND4J for high-performance numerical computing, and supports GPU acceleration via CUDA. It's ideal for enterprises that need to build custom ML models while staying in the Java ecosystem.
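    To get a feel for ND4J before diving into networks, here is a minimal sketch of its NumPy-style array API (the class name `Nd4jQuickstart` and the example values are illustrative):

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class Nd4jQuickstart {
    public static void main(String[] args) {
        // 2x2 matrix, row-major: [[1, 2], [3, 4]]
        INDArray a = Nd4j.create(new double[]{1, 2, 3, 4}, new int[]{2, 2});
        INDArray b = Nd4j.ones(2, 2);

        INDArray sum = a.add(b);    // element-wise addition: [[2, 3], [4, 5]]
        INDArray prod = a.mmul(b);  // matrix multiplication: [[3, 3], [7, 7]]

        System.out.println(sum);
        System.out.println(prod);
    }
}
```

    The same `INDArray` type flows through every DL4J API below, whether on the CPU backend or CUDA.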

    DL4J Ecosystem

    DeepLearning4j

    Core neural network library for building and training models

    ND4J

    N-Dimensional arrays for Java—NumPy for the JVM

    Spark Integration

    Distributed training across clusters

    DataVec

    Data preprocessing and ETL for ML pipelines
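    As a taste of DataVec, here is a hedged sketch of a transform pipeline for an Iris-style CSV (the column and class names are illustrative, not part of any DL4J example):

```java
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;

public class IrisPreprocessing {
    public static TransformProcess buildPipeline() {
        // Describe the raw CSV columns
        Schema schema = new Schema.Builder()
            .addColumnsDouble("sepal_length", "sepal_width",
                              "petal_length", "petal_width")
            .addColumnCategorical("species", "setosa", "versicolor", "virginica")
            .build();

        // Convert the categorical label to an integer index for training
        return new TransformProcess.Builder(schema)
            .categoricalToInteger("species")
            .build();
    }
}
```

    A `TransformProcess` like this can run locally or be executed at scale on Spark, which is how DataVec and the Spark integration fit together.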

    Getting Started

    Maven Dependencies

    Add DL4J to your project

    pom.xml
    <dependencies>
        <!-- Core DL4J -->
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>1.0.0-M2.1</version>
        </dependency>
        <!-- ND4J Backend (CPU) -->
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native-platform</artifactId>
            <version>1.0.0-M2.1</version>
        </dependency>
        <!-- For GPU support, use instead: -->
        <!--
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-cuda-11.6-platform</artifactId>
            <version>1.0.0-M2.1</version>
        </dependency>
        -->
    </dependencies>

    Building a Neural Network

    MLP for Classification

    A simple feedforward network for the Iris dataset

    IrisClassifier.java
    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.layers.DenseLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.learning.config.Adam;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    public class IrisClassifier {

        public MultiLayerNetwork buildModel() {
            MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
                .seed(42)
                .updater(new Adam(0.001))
                .list()
                .layer(new DenseLayer.Builder()
                    .nIn(4)                       // 4 input features
                    .nOut(10)                     // hidden layer size
                    .activation(Activation.RELU)
                    .build())
                .layer(new DenseLayer.Builder()
                    .nIn(10)
                    .nOut(10)
                    .activation(Activation.RELU)
                    .build())
                .layer(new OutputLayer.Builder()
                    .nIn(10)
                    .nOut(3)                      // 3 classes
                    .activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT)
                    .build())
                .build();

            MultiLayerNetwork model = new MultiLayerNetwork(config);
            model.init();
            return model;
        }
    }

    Training the Model

    Data Loading and Training Loop

    TrainingService.java
    import java.io.File;

    import org.datavec.api.records.reader.RecordReader;
    import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
    import org.datavec.api.split.FileSplit;
    import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.evaluation.classification.Evaluation;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
    import org.nd4j.linalg.factory.Nd4j;

    public class TrainingService {

        public void train(MultiLayerNetwork model, String dataPath) throws Exception {
            // Load data
            RecordReader recordReader = new CSVRecordReader(0, ',');
            recordReader.initialize(new FileSplit(new File(dataPath)));

            DataSetIterator iterator = new RecordReaderDataSetIterator(
                recordReader,
                32,   // batch size
                4,    // label index (5th column)
                3     // number of classes
            );

            // Train for 100 epochs
            for (int epoch = 0; epoch < 100; epoch++) {
                iterator.reset();
                model.fit(iterator);

                // Evaluate every 10 epochs
                if (epoch % 10 == 0) {
                    Evaluation eval = model.evaluate(iterator);
                    System.out.println("Epoch " + epoch + " - Accuracy: " + eval.accuracy());
                }
            }

            // Save the model
            model.save(new File("iris-model.zip"));
        }

        public void predict(MultiLayerNetwork model, double[] features) {
            INDArray input = Nd4j.create(features).reshape(1, 4);
            INDArray output = model.output(input);
            int predictedClass = Nd4j.argMax(output, 1).getInt(0);
            System.out.println("Predicted class: " + predictedClass);
        }
    }

    Convolutional Neural Network

    Image Classification with CNN

    LeNet-style architecture for MNIST

    CNN for MNIST
    import org.deeplearning4j.nn.conf.inputs.InputType;
    import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
    import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
    // plus the imports from the IrisClassifier example above

    public MultiLayerNetwork buildCNN() {
        MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
            .seed(42)
            .updater(new Adam(0.001))
            .list()
            // First conv layer
            .layer(new ConvolutionLayer.Builder(5, 5)
                .nIn(1)        // 1 channel (grayscale)
                .nOut(20)      // 20 filters
                .stride(1, 1)
                .activation(Activation.RELU)
                .build())
            .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
            // Second conv layer
            .layer(new ConvolutionLayer.Builder(5, 5)
                .nOut(50)
                .stride(1, 1)
                .activation(Activation.RELU)
                .build())
            .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
            // Dense layers
            .layer(new DenseLayer.Builder()
                .nOut(500)
                .activation(Activation.RELU)
                .build())
            .layer(new OutputLayer.Builder()
                .nOut(10)      // 10 digits
                .activation(Activation.SOFTMAX)
                .lossFunction(LossFunctions.LossFunction.MCXENT)
                .build())
            .setInputType(InputType.convolutionalFlat(28, 28, 1))
            .build();

        return new MultiLayerNetwork(config);
    }

    When to Use DL4J

    ✓ Good Fit

    • Need to train custom models (not just use LLMs)
    • Java/Scala-only environment requirements
    • Distributed training on Spark clusters
    • Embedded ML in Java applications
    • Production deployment on JVM infrastructure
    • Import TensorFlow/Keras models

    ⚠️ Consider Alternatives

    • Just need to call LLMs → Use Spring AI
    • RAG/chatbots → Use LangChain4j
    • Cutting-edge research → Python still leads
    • Simple ML → Consider Weka or Tribuo
    • Team more comfortable with Python
    • Need latest model architectures quickly

    DL4J vs Other AI in Java Options

    Feature           | DL4J                  | Spring AI       | LangChain4j
    ------------------|-----------------------|-----------------|----------------
    Primary Use       | Train neural networks | LLM integration | AI agents & RAG
    Training          | Full Support          | No              | No
    LLM Calling       | No                    | Core Feature    | Core Feature
    GPU Support       | CUDA                  | N/A             | N/A
    Spark Integration | Native                | No              | No

    Transfer Learning

    Import Pre-trained Models

    Leverage Keras/TensorFlow models in your Java application

    Training deep neural networks from scratch requires massive datasets and computational resources. Transfer learning lets you take a model trained on millions of images (like VGG16 or ResNet) and fine-tune it for your specific task. DeepLearning4j supports importing Keras H5 models directly, which means you can train in Python and deploy in Java—the best of both worlds.

    TransferLearningService.java
    import org.deeplearning4j.nn.graph.ComputationGraph;
    import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
    import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
    import org.deeplearning4j.nn.transferlearning.TransferLearning;
    // plus the imports from the IrisClassifier example above

    public class TransferLearningService {

        public ComputationGraph loadKerasModel(String modelPath) throws Exception {
            // Import a Keras model (e.g., VGG16 pre-trained on ImageNet)
            return KerasModelImport.importKerasModelAndWeights(modelPath);
        }

        public MultiLayerNetwork fineTuneForCustomTask(MultiLayerNetwork baseModel,
                                                       int numClasses) {
            // Freeze early layers (feature extraction)
            // Only train the last few layers for your task
            TransferLearning.Builder builder = new TransferLearning.Builder(baseModel)
                .fineTuneConfiguration(new FineTuneConfiguration.Builder()
                    .updater(new Adam(0.0001))  // Lower learning rate
                    .seed(42)
                    .build())
                .setFeatureExtractor(4)         // Freeze layers 0-4; for MultiLayerNetwork
                                                // this takes a layer index, not a name
                .removeOutputLayer()
                .addLayer(new OutputLayer.Builder()
                    .nIn(512)
                    .nOut(numClasses)           // Your custom classes
                    .activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT)
                    .build());

            return builder.build();
        }
    }

    Recurrent Neural Networks

    LSTM for Sequence Modeling

    Time series, NLP, and sequential data

    Recurrent Neural Networks (RNNs) and their more powerful variant LSTMs excel at processing sequential data where order matters—think stock prices, sensor readings, or text. Unlike feedforward networks, LSTMs maintain a "memory" that allows them to learn long-term dependencies. DL4J provides full support for LSTM, GRU, and other recurrent architectures.

    LSTM Network
    import org.deeplearning4j.nn.conf.layers.LSTM;
    import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.DataSet;
    // plus the imports from the IrisClassifier example above

    public MultiLayerNetwork buildLSTM(int numFeatures, int numClasses) {
        MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
            .seed(42)
            .updater(new Adam(0.001))
            .list()
            .layer(new LSTM.Builder()
                .nIn(numFeatures)
                .nOut(128)
                .activation(Activation.TANH)
                .build())
            .layer(new LSTM.Builder()
                .nIn(128)
                .nOut(64)
                .activation(Activation.TANH)
                .build())
            .layer(new RnnOutputLayer.Builder()
                .nIn(64)
                .nOut(numClasses)
                .activation(Activation.SOFTMAX)
                .lossFunction(LossFunctions.LossFunction.MCXENT)
                .build())
            .build();

        return new MultiLayerNetwork(config);
    }

    // For time series prediction (regression)
    public void trainOnSequences(MultiLayerNetwork model,
                                 INDArray features,   // Shape: [batch, features, timeSteps]
                                 INDArray labels) {
        DataSet dataSet = new DataSet(features, labels);
        model.fit(dataSet);
    }

    Distributed Training with Spark

    One of DL4J's killer features is native Apache Spark integration. When your dataset is too large for a single machine, or training takes too long, you can distribute the workload across a Spark cluster. DL4J uses a parameter averaging approach: each worker trains on a subset of the data, and the model parameters are periodically averaged across workers. This can cut training time from days to hours.

    Spark Distributed Training
    import java.io.File;

    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.spark.api.TrainingMaster;
    import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
    import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
    import org.nd4j.linalg.dataset.DataSet;

    public class DistributedTraining {

        public void trainOnSpark(JavaSparkContext sc,
                                 MultiLayerNetwork model,
                                 JavaRDD<DataSet> trainingData) throws Exception {
            // Configure distributed training
            TrainingMaster trainingMaster =
                new ParameterAveragingTrainingMaster.Builder(32)
                    .averagingFrequency(5)        // Sync every 5 mini-batches
                    .workerPrefetchNumBatches(2)  // Prefetch for efficiency
                    .batchSizePerWorker(32)       // Per-worker batch size
                    .build();

            // Wrap model for Spark
            SparkDl4jMultiLayer sparkModel = new SparkDl4jMultiLayer(sc, model, trainingMaster);

            // Train across the cluster
            for (int epoch = 0; epoch < 10; epoch++) {
                sparkModel.fit(trainingData);
                System.out.println("Completed epoch " + epoch);
            }

            // Get the trained model back
            MultiLayerNetwork trainedModel = sparkModel.getNetwork();
            trainedModel.save(new File("distributed-model.zip"));
        }
    }

    Scaling Tip: Start with a small cluster (3-5 nodes) and monitor GPU utilization. The parameter averaging frequency is a trade-off: sync too often and you waste time on network overhead; sync too rarely and workers diverge.

    Real-World Applications

    DeepLearning4j powers production systems across industries where Java is the enterprise standard. Here are scenarios where DL4J excels over Python alternatives:

    Fraud Detection

    Financial institutions use DL4J to build real-time fraud detection systems that process millions of transactions per second. The JVM's low-latency garbage collection and mature ecosystem make it ideal for millisecond-sensitive decisions.

    RNN/LSTM
    Anomaly Detection

    Predictive Maintenance

    Manufacturing and IoT platforms analyze sensor data to predict equipment failures before they happen. DL4J's Spark integration handles the massive data volumes from industrial sensors.

    Time Series
    Spark

    Recommendation Engines

    E-commerce and media companies run recommendation models that need to integrate with existing Java backends. DL4J fits naturally into Spring-based microservice architectures.

    Embeddings
    Collaborative Filtering

    Model Serving & Deployment

    Embedding Models in Spring Boot

    Serve predictions via REST API

    One of DL4J's biggest advantages is deployment simplicity. Your trained model is just a serialized Java object—no Python runtime, no ONNX conversion, no Docker containers with TensorFlow Serving. Load the model at startup and serve predictions directly from your Spring Boot application.

    PredictionController.java
    import java.io.File;

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.springframework.web.bind.annotation.PostMapping;
    import org.springframework.web.bind.annotation.RequestBody;
    import org.springframework.web.bind.annotation.RequestMapping;
    import org.springframework.web.bind.annotation.RestController;

    @RestController
    @RequestMapping("/api/predictions")
    public class PredictionController {

        private final MultiLayerNetwork model;

        public PredictionController() throws Exception {
            // Load model at startup
            this.model = MultiLayerNetwork.load(new File("models/production-model.zip"), true);
        }

        @PostMapping("/classify")
        public PredictionResponse classify(@RequestBody PredictionRequest request) {
            // Convert input to INDArray
            INDArray input = Nd4j.create(request.features())
                .reshape(1, request.features().length);

            // Run inference
            INDArray output = model.output(input);

            // Get predicted class and confidence
            int predictedClass = Nd4j.argMax(output, 1).getInt(0);
            double confidence = output.getDouble(predictedClass);

            return new PredictionResponse(predictedClass, confidence);
        }
    }

    record PredictionRequest(double[] features) { }
    record PredictionResponse(int predictedClass, double confidence) { }

    Tips for Success

    Performance

    • Use nd4j-native-platform for AVX2 support
    • Enable GPU with CUDA backend for large models
    • Use workspaces for memory efficiency
    • Batch your data appropriately (32-128)
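    The workspace tip above is a one-line configuration change. A minimal sketch (the workspace-mode settings are real DL4J options; the single-layer network is stripped down purely for illustration):

```java
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class WorkspaceConfig {
    public static MultiLayerConfiguration build() {
        return new NeuralNetConfiguration.Builder()
            // Reuse off-heap memory between iterations instead of
            // allocating fresh buffers each pass (reduces GC pressure)
            .trainingWorkspaceMode(WorkspaceMode.ENABLED)
            .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
            .list()
            .layer(new OutputLayer.Builder()
                .nIn(4).nOut(3)
                .activation(Activation.SOFTMAX)
                .lossFunction(LossFunctions.LossFunction.MCXENT)
                .build())
            .build();
    }
}
```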

    Development

    • Start with examples from dl4j-examples repo
    • Use UI module for training visualization
    • Import Keras models for transfer learning
    • Join the community on Discourse

    Train Your Own Models

    DeepLearning4j is the right choice when you need to build custom neural networks in Java—not just call existing LLMs. Start with the examples repository.