Skip to content

ONNX Conversion & Java Serving (CPU & GPU)

High-performance inference on CPU and GPU using ONNX Runtime (ORT) in a Java environment.

Why ONNX + Java?

  • JVM Ecosystem: Integrate directly with existing Java/Scala backend services (Spring Boot, Flink, Spark).
  • Lower Latency: Avoid HTTP overhead of calling an external model server (TensorFlow Serving / TorchServe) by running in-process (JNI).
  • CPU optimization: ONNX Runtime is highly optimized for CPU inference (AVX2/AVX512).

1. Model Conversion (Python)

Convert PyTorch or TensorFlow models to .onnx format.

PyTorch to ONNX

import torch
import torch.nn as nn

# 1. Load trained model
model = MyModel()
model.load_state_dict(torch.load("model.pth"))
model.eval()

# 2. Define dummy input (shape must match model input)
dummy_input = torch.randn(1, 3, 224, 224)

# 3. Export
torch.onnx.export(
    model, 
    dummy_input, 
    "model.onnx",
    input_names=["input"], 
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    opset_version=14
)

TensorFlow to ONNX

Use tf2onnx:

pip install tf2onnx
python -m tf2onnx.convert --saved-model ./tf_model --output model.onnx --opset 14

2. Java Inference (ONNX Runtime)

Use the onnxruntime Java API to load and query the model.

Dependency (Maven)

<!-- Optimized Native Binaries for CPU -->
<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.17.1</version>
</dependency>

Inference Code

import ai.onnxruntime.*;
import java.nio.FloatBuffer;
import java.util.Collections;

public class ModelService {
    private OrtEnvironment env;
    private OrtSession session;

    public ModelService(String modelPath) throws OrtException {
        this.env = OrtEnvironment.getEnvironment();
        // Optimization: Enable graph optimizations
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
        opts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
        opts.setIntraOpNumThreads(4); // Tune based on CPU cores

        this.session = env.createSession(modelPath, opts);
    }

    public float[] predict(float[] inputData, long[] shape) throws OrtException {
        // 1. Create Tensor from Java array
        OnnxTensor inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), shape);

        // 2. Run Inference
        OrtSession.Result result = session.run(Collections.singletonMap("input", inputTensor));

        // 3. Extract Output
        float[][] output = (float[][]) result.get(0).getValue();
        result.close(); // Important: Close native resource to prevent leaks (or use try-with-resources)

        return output[0];
    }
}

3. Java Inference (GPU / CUDA)

To run on NVIDIA GPUs, switch dependencies and configure the session options.

Dependency (Maven)

Replace onnxruntime with onnxruntime_gpu:

<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime_gpu</artifactId>
    <version>1.17.1</version>
</dependency>

CUDA Configuration Code

import ai.onnxruntime.providers.OrtCUDAProviderOptions;

// ... inside constructor ...

OrtSession.SessionOptions opts = new OrtSession.SessionOptions();

// 1. Configure CUDA Provider
OrtCUDAProviderOptions cudaOpts = new OrtCUDAProviderOptions(0); // GPU Device ID 0
opts.addCUDA(cudaOpts);

// 2. Create Session (will load model onto GPU)
this.session = env.createSession(modelPath, opts);

Note: Ensure the host machine has compatible NVIDIA Drivers and CUDA Toolkit installed (matching the ORT version).


4. Performance Tuning

CPU Tuning

Setting Recommendation
IntraOpNumThreads Set to number of physical cores available to the request. Don't oversubscribe.
InterOpNumThreads Keep low (1) unless running parallel subgraphs (rare).
Execution Mode SEQUENTIAL is usually faster for simple models; PARALLEL increases latency but throughput for complex graphs.
Memory Mapping Use sessionOptions.addSessionConfigEntry("session.load_model_format", "ORT") for faster loading if using optimized format.

GPU Tuning

Setting Recommendation
cudnn_conv_algo_search Set to HEURISTIC (default) or EXHAUSTIVE (slower init, faster run) via provider options.
gpu_mem_limit Set a hard limit on GPU memory usage if sharing the card with other processes.
IO Binding Crucial: Use run(..., runOptions, ioBinding) to keep inputs/outputs on device (GPU) and avoid CPU-GPU copies.

Memory Management

Critical: ONNX Runtime uses off-heap memory (C++). - Always close OrtSession.Result and OnnxTensor objects explicitly or use try-with-resources. - Garbage Collection (GC) won't reclaim native memory fast enough, leading to OOM.


5. Advanced Topics

Quantization (INT8)

Drastically reduce model size and latency with minimal accuracy loss. - Dynamic Quantization: Weights are INT8, activations quantized on-the-fly (Great for NLP/Transformers). - Static Quantization: Weights and activations are INT8. Requires a "calibration" dataset to determine ranges (Best for CNNs).

Execution Providers: CUDA vs TensorRT

  • CUDA: General purpose, robust support for most ONNX ops. Fast compilation.
  • TensorRT: NVIDIA's specialized optimizer.
    • Pros: Can be ~2-5x faster than basic CUDA execution.
    • Cons: Extremely slow engine build time (minutes on startup), strictly tied to specific GPU architecture. Use for stable, long-running production deployments.

6. Architecture Diagram

graph TD
    subgraph Python ["Training Phase - Python"]
        TF["TensorFlow / PyTorch"] -->|Export| ModelFile["model.onnx"]
    end

    subgraph Java ["Serving Phase - JVM"]
        AppService["Spring Boot / Flink"] 
        OrtRuntime["ONNX Runtime C++"]

        ModelFile -->|Load| OrtRuntime
        AppService -- JNI --> OrtRuntime
        OrtRuntime -- Result --> AppService
    end

    style OrtRuntime fill:#f9f,stroke:#333