
Optimizing Inference


Inference is the process of making predictions using a trained model. In Brain4J, inference performance can be significantly improved using batched inference.

Batch Inference

Instead of processing one input at a time, it is much more efficient to group multiple samples into a single batch. For example, predicting 100 inputs individually is noticeably slower than predicting all 100 in a single batch.

This is because Brain4J executes tensor operations using multi-threaded routines that scale better with larger data chunks.

Example

Here's a basic inference with a single input:

Tensor input = Tensors.vector(dimension); // A single 1D input tensor
Tensor output = model.predict(input); // Output for one sample

This works, but it’s not optimal for performance with many predictions.

Using batched inference:

List<Tensor> inputs = ...; // All inputs must be 1D tensors of the same dimension
Tensor batch = Tensors.mergeTensors(inputs); // Shape: [batch_size, input_dim]
Tensor output = model.predict(batch); // Output shape: [batch_size, output_dim]

Note

Tensors.mergeTensors stacks the input tensors on top of one another, producing a matrix of shape [batch_size, input_dim].
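
To see the performance difference in practice, here is a minimal timing sketch. It is illustrative only and reuses the calls shown above (Tensors.vector, Tensors.mergeTensors, model.predict); it assumes model is an already trained model, dimension is its input size, and the inputs are placeholder vectors created with Tensors.vector(dimension).

int sampleCount = 100;
List<Tensor> inputs = new ArrayList<>();
for (int i = 0; i < sampleCount; i++) {
    inputs.add(Tensors.vector(dimension)); // Placeholder 1D inputs, all with the same dimension
}

// One predict call per sample
long start = System.nanoTime();
for (Tensor input : inputs) {
    model.predict(input);
}
long perSampleNanos = System.nanoTime() - start;

// A single predict call over the whole batch
start = System.nanoTime();
Tensor batch = Tensors.mergeTensors(inputs); // Shape: [100, dimension]
Tensor output = model.predict(batch); // Shape: [100, output_dim]
long batchedNanos = System.nanoTime() - start;

System.out.printf("per-sample: %.1f ms, batched: %.1f ms%n",
        perSampleNanos / 1e6, batchedNanos / 1e6);

The exact speedup depends on the model size and on your hardware, but the batched call typically benefits more from the multi-threaded tensor routines mentioned above.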

Next Steps

Check out Examples & Use Cases
