DeepLearning4j
A production-grade, open-source deep learning library for Java. Build, train, and deploy neural networks natively on the JVM.
DeepLearning4j (DL4J) is a distributed deep learning framework written for Java and Scala. Unlike Spring AI or LangChain4j which focus on LLM integration, DL4J lets you build and train your own neural networks from scratch—CNNs for image classification, RNNs for sequence modeling, and more. It's the TensorFlow/PyTorch equivalent for the JVM.
DL4J integrates with Spark for distributed training, uses ND4J for high-performance numerical computing, and supports GPU acceleration via CUDA. It's ideal for enterprises that need to build custom ML models while staying in the Java ecosystem.
DL4J Ecosystem
DeepLearning4j
Core neural network library for building and training models
ND4J
N-Dimensional arrays for Java—NumPy for the JVM
Spark Integration
Distributed training across clusters
DataVec
Data preprocessing and ETL for ML pipelines
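The core idea behind ND4J is the same as NumPy's: dense n-dimensional arrays stored row-major, with heavy operations dispatched to native BLAS (or CUDA). As a rough illustration of what that buys you, here is a plain-JDK sketch (not the ND4J API) of the row-major matrix multiply that ND4J executes natively in a single `mmul` call:

```java
import java.util.Arrays;

public class MatMulSketch {

    // Multiply an (m x k) matrix by a (k x n) matrix, both stored row-major
    // in flat arrays -- the default memory layout in ND4J (and NumPy).
    static double[] matmul(double[] a, double[] b, int m, int k, int n) {
        double[] c = new double[m * n];
        for (int i = 0; i < m; i++) {
            for (int p = 0; p < k; p++) {
                double aip = a[i * k + p];
                for (int j = 0; j < n; j++) {
                    c[i * n + j] += aip * b[p * n + j];
                }
            }
        }
        return c;
    }

    public static void main(String[] args) {
        // [[1, 2], [3, 4]] x [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
        double[] c = matmul(new double[]{1, 2, 3, 4},
                            new double[]{5, 6, 7, 8}, 2, 2, 2);
        System.out.println(Arrays.toString(c));
    }
}
```

ND4J replaces these hand-written loops with vectorized native kernels, which is why DL4J training stays competitive on the JVM.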
Getting Started
Maven Dependencies
Add DL4J to your project
```xml
<dependencies>
    <!-- Core DL4J -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>1.0.0-M2.1</version>
    </dependency>
    <!-- ND4J Backend (CPU) -->
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native-platform</artifactId>
        <version>1.0.0-M2.1</version>
    </dependency>
    <!-- For GPU support, use instead:
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-cuda-11.6-platform</artifactId>
        <version>1.0.0-M2.1</version>
    </dependency>
    -->
</dependencies>
```

Building a Neural Network
MLP for Classification
A simple feedforward network for the Iris dataset
```java
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class IrisClassifier {

    public MultiLayerNetwork buildModel() {
        MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
            .seed(42)
            .updater(new Adam(0.001))
            .list()
            .layer(new DenseLayer.Builder()
                .nIn(4)                      // 4 input features
                .nOut(10)                    // hidden layer size
                .activation(Activation.RELU)
                .build())
            .layer(new DenseLayer.Builder()
                .nIn(10)
                .nOut(10)
                .activation(Activation.RELU)
                .build())
            .layer(new OutputLayer.Builder()
                .nIn(10)
                .nOut(3)                     // 3 classes
                .activation(Activation.SOFTMAX)
                .lossFunction(LossFunctions.LossFunction.MCXENT)
                .build())
            .build();

        MultiLayerNetwork model = new MultiLayerNetwork(config);
        model.init();
        return model;
    }
}
```

Training the Model
Data Loading and Training Loop
```java
import java.io.File;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

public class TrainingService {

    public void train(MultiLayerNetwork model, String dataPath) throws Exception {
        // Load data
        RecordReader recordReader = new CSVRecordReader(0, ',');
        recordReader.initialize(new FileSplit(new File(dataPath)));

        DataSetIterator iterator = new RecordReaderDataSetIterator(
            recordReader,
            32,   // batch size
            4,    // label index (5th column)
            3     // number of classes
        );

        // Train for 100 epochs
        for (int epoch = 0; epoch < 100; epoch++) {
            iterator.reset();
            model.fit(iterator);

            // Evaluate every 10 epochs
            if (epoch % 10 == 0) {
                Evaluation eval = model.evaluate(iterator);
                System.out.println("Epoch " + epoch + " - Accuracy: " + eval.accuracy());
            }
        }

        // Save the model
        model.save(new File("iris-model.zip"));
    }

    public void predict(MultiLayerNetwork model, double[] features) {
        INDArray input = Nd4j.create(features).reshape(1, 4);
        INDArray output = model.output(input);
        int predictedClass = Nd4j.argMax(output, 1).getInt(0);
        System.out.println("Predicted class: " + predictedClass);
    }
}
```

Convolutional Neural Network
Image Classification with CNN
LeNet-style architecture for MNIST
```java
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;

public MultiLayerNetwork buildCNN() {
    MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
        .seed(42)
        .updater(new Adam(0.001))
        .list()
        // First conv layer
        .layer(new ConvolutionLayer.Builder(5, 5)
            .nIn(1)        // 1 channel (grayscale)
            .nOut(20)      // 20 filters
            .stride(1, 1)
            .activation(Activation.RELU)
            .build())
        .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
            .kernelSize(2, 2)
            .stride(2, 2)
            .build())
        // Second conv layer
        .layer(new ConvolutionLayer.Builder(5, 5)
            .nOut(50)
            .stride(1, 1)
            .activation(Activation.RELU)
            .build())
        .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
            .kernelSize(2, 2)
            .stride(2, 2)
            .build())
        // Dense layers
        .layer(new DenseLayer.Builder()
            .nOut(500)
            .activation(Activation.RELU)
            .build())
        .layer(new OutputLayer.Builder()
            .nOut(10)      // 10 digits
            .activation(Activation.SOFTMAX)
            .lossFunction(LossFunctions.LossFunction.MCXENT)
            .build())
        .setInputType(InputType.convolutionalFlat(28, 28, 1))
        .build();

    return new MultiLayerNetwork(config);
}
```

When to Use DL4J
✓ Good Fit
- Need to train custom models (not just use LLMs)
- Java/Scala-only environment requirements
- Distributed training on Spark clusters
- Embedded ML in Java applications
- Production deployment on JVM infrastructure
- Import TensorFlow/Keras models
⚠️ Consider Alternatives
- Just need to call LLMs → Use Spring AI
- RAG/chatbots → Use LangChain4j
- Cutting-edge research → Python still leads
- Simple ML → Consider Weka or Tribuo
- Team more comfortable with Python
- Need latest model architectures quickly
DL4J vs Other AI in Java Options
| Feature | DL4J | Spring AI | LangChain4j |
|---|---|---|---|
| Primary Use | Train neural networks | LLM integration | AI agents & RAG |
| Training | Full Support | No | No |
| LLM Calling | No | Core Feature | Core Feature |
| GPU Support | CUDA | N/A | N/A |
| Spark Integration | Native | No | No |
Transfer Learning
Import Pre-trained Models
Leverage Keras/TensorFlow models in your Java application
Training deep neural networks from scratch requires massive datasets and computational resources. Transfer learning lets you take a model trained on millions of images (like VGG16 or ResNet) and fine-tune it for your specific task. DeepLearning4j supports importing Keras H5 models directly, which means you can train in Python and deploy in Java—the best of both worlds.
```java
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;

public class TransferLearningService {

    public ComputationGraph loadKerasModel(String modelPath) throws Exception {
        // Import a Keras model (e.g., VGG16 pre-trained on ImageNet)
        ComputationGraph model = KerasModelImport.importKerasModelAndWeights(modelPath);
        return model;
    }

    public MultiLayerNetwork fineTuneForCustomTask(MultiLayerNetwork baseModel, int numClasses) {
        // Freeze early layers (feature extraction)
        // Only train the last few layers for your task
        TransferLearning.Builder builder = new TransferLearning.Builder(baseModel)
            .fineTuneConfiguration(new FineTuneConfiguration.Builder()
                .updater(new Adam(0.0001))   // Lower learning rate
                .seed(42)
                .build())
            .setFeatureExtractor(5)          // Freeze layers 0-5; for MultiLayerNetwork
                                             // this takes a layer index (layer names are
                                             // for the ComputationGraph builder)
            .removeOutputLayer()
            .addLayer(new OutputLayer.Builder()
                .nIn(512)
                .nOut(numClasses)            // Your custom classes
                .activation(Activation.SOFTMAX)
                .lossFunction(LossFunctions.LossFunction.MCXENT)
                .build());

        return builder.build();
    }
}
```

Recurrent Neural Networks
LSTM for Sequence Modeling
Time series, NLP, and sequential data
Recurrent Neural Networks (RNNs) and their more powerful variant LSTMs excel at processing sequential data where order matters—think stock prices, sensor readings, or text. Unlike feedforward networks, LSTMs maintain a "memory" that allows them to learn long-term dependencies. DL4J provides full support for LSTM, GRU, and other recurrent architectures.
```java
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.nd4j.linalg.dataset.DataSet;

public MultiLayerNetwork buildLSTM(int numFeatures, int numClasses) {
    MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
        .seed(42)
        .updater(new Adam(0.001))
        .list()
        .layer(new LSTM.Builder()
            .nIn(numFeatures)
            .nOut(128)
            .activation(Activation.TANH)
            .build())
        .layer(new LSTM.Builder()
            .nIn(128)
            .nOut(64)
            .activation(Activation.TANH)
            .build())
        .layer(new RnnOutputLayer.Builder()
            .nIn(64)
            .nOut(numClasses)
            .activation(Activation.SOFTMAX)
            .lossFunction(LossFunctions.LossFunction.MCXENT)
            .build())
        .build();

    return new MultiLayerNetwork(config);
}

// For time series prediction (regression)
public void trainOnSequences(MultiLayerNetwork model,
                             INDArray features,  // Shape: [batch, features, timeSteps]
                             INDArray labels) {
    DataSet dataSet = new DataSet(features, labels);
    model.fit(dataSet);
}
```

Distributed Training with Spark
One of DL4J's killer features is native Apache Spark integration. When your dataset is too large for a single machine, or training takes too long, you can distribute the workload across a Spark cluster. DL4J uses a parameter averaging approach where each worker trains on a subset of data, and gradients are periodically synchronized. This can cut training time from days to hours.
```java
import java.io.File;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.nd4j.linalg.dataset.DataSet;

public class DistributedTraining {

    public void trainOnSpark(JavaSparkContext sc,
                             MultiLayerNetwork model,
                             JavaRDD<DataSet> trainingData) throws Exception {
        // Configure distributed training
        TrainingMaster trainingMaster = new ParameterAveragingTrainingMaster.Builder(32)
            .averagingFrequency(5)          // Sync every 5 mini-batches
            .workerPrefetchNumBatches(2)    // Prefetch for efficiency
            .batchSizePerWorker(32)         // Per-worker batch size
            .build();

        // Wrap model for Spark
        SparkDl4jMultiLayer sparkModel = new SparkDl4jMultiLayer(sc, model, trainingMaster);

        // Train across the cluster
        for (int epoch = 0; epoch < 10; epoch++) {
            sparkModel.fit(trainingData);
            System.out.println("Completed epoch " + epoch);
        }

        // Get the trained model back
        MultiLayerNetwork trainedModel = sparkModel.getNetwork();
        trainedModel.save(new File("distributed-model.zip"));
    }
}
```

Scaling Tip: Start with a small cluster (3-5 nodes) and monitor GPU utilization. The parameter averaging frequency is a trade-off: sync too often and you waste time on network overhead; sync too rarely and workers diverge.
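The averaging step itself is easy to picture. As a conceptual sketch only (plain Java, not the DL4J API), one synchronization round replaces every worker's parameters with the element-wise mean of all workers' parameter vectors; DL4J's ParameterAveragingTrainingMaster does this, plus updater-state handling, every `averagingFrequency` mini-batches:

```java
import java.util.List;

public class ParameterAveraging {

    // One sync step: element-wise mean of the workers' parameter vectors.
    static double[] average(List<double[]> workerParams) {
        int n = workerParams.get(0).length;
        double[] avg = new double[n];
        for (double[] params : workerParams) {
            for (int i = 0; i < n; i++) {
                avg[i] += params[i] / workerParams.size();
            }
        }
        return avg;
    }

    public static void main(String[] args) {
        // Three workers whose weights drifted apart during local training
        double[] merged = average(List.of(
            new double[]{0.9, 2.1},
            new double[]{1.1, 1.8},
            new double[]{1.0, 2.1}
        ));
        System.out.printf("%.2f %.2f%n", merged[0], merged[1]);
    }
}
```

This is why averaging frequency matters: between syncs each worker's vector drifts toward its own data shard, and the mean only pulls them back together when the step runs.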
Real-World Applications
DeepLearning4j powers production systems across industries where Java is the enterprise standard. Here are scenarios where DL4J excels over Python alternatives:
Fraud Detection
Financial institutions use DL4J to build real-time fraud detection systems that process millions of transactions per second. The JVM's low-latency garbage collection and mature ecosystem make it ideal for millisecond-sensitive decisions.
Predictive Maintenance
Manufacturing and IoT platforms analyze sensor data to predict equipment failures before they happen. DL4J's Spark integration handles the massive data volumes from industrial sensors.
Recommendation Engines
E-commerce and media companies run recommendation models that need to integrate with existing Java backends. DL4J fits naturally into Spring-based microservice architectures.
Model Serving & Deployment
Embedding Models in Spring Boot
Serve predictions via REST API
One of DL4J's biggest advantages is deployment simplicity. Your trained model is just a serialized Java object—no Python runtime, no ONNX conversion, no Docker containers with TensorFlow Serving. Load the model at startup and serve predictions directly from your Spring Boot application.
```java
import java.io.File;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

@RestController
@RequestMapping("/api/predictions")
public class PredictionController {

    private final MultiLayerNetwork model;

    public PredictionController() throws Exception {
        // Load model at startup
        this.model = MultiLayerNetwork.load(
            new File("models/production-model.zip"), true);
    }

    @PostMapping("/classify")
    public PredictionResponse classify(@RequestBody PredictionRequest request) {
        // Convert input to INDArray (records expose features(), not getFeatures())
        INDArray input = Nd4j.create(request.features())
            .reshape(1, request.features().length);

        // Run inference
        INDArray output = model.output(input);

        // Get predicted class and confidence
        int predictedClass = Nd4j.argMax(output, 1).getInt(0);
        double confidence = output.getDouble(predictedClass);

        return new PredictionResponse(predictedClass, confidence);
    }
}

record PredictionRequest(double[] features) {}
record PredictionResponse(int predictedClass, double confidence) {}
```

Tips for Success
Performance
- Use nd4j-native-platform for AVX2 support
- Enable GPU with CUDA backend for large models
- Use workspaces for memory efficiency
- Batch your data appropriately (32-128)
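On batching: in DL4J a DataSetIterator handles this for you (the batch size of 32 passed to RecordReaderDataSetIterator earlier), but conceptually it is just partitioning the dataset into fixed-size chunks, one gradient update per chunk. A plain-Java sketch, not the DL4J API:

```java
import java.util.ArrayList;
import java.util.List;

public class Batching {

    // Split `samples` into mini-batches of at most `batchSize` rows --
    // the granularity at which fit() computes each gradient update.
    static List<double[][]> toBatches(double[][] samples, int batchSize) {
        List<double[][]> batches = new ArrayList<>();
        for (int start = 0; start < samples.length; start += batchSize) {
            int end = Math.min(start + batchSize, samples.length);
            double[][] batch = new double[end - start][];
            for (int i = start; i < end; i++) {
                batch[i - start] = samples[i];
            }
            batches.add(batch);
        }
        return batches;
    }

    public static void main(String[] args) {
        double[][] data = new double[150][4];   // e.g., 150 Iris rows, 4 features
        List<double[][]> batches = toBatches(data, 32);
        // 150 samples at 32 per batch -> 4 full batches plus a partial batch of 22
        System.out.println(batches.size() + " batches, last has " +
            batches.get(batches.size() - 1).length);
    }
}
```

Smaller batches give noisier but more frequent updates; larger ones use hardware better. The 32-128 range above is the usual compromise.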
Development
- Start with examples from dl4j-examples repo
- Use UI module for training visualization
- Import Keras models for transfer learning
- Join the community on Discourse