Optimizing Inference
Inference is the process of making predictions using a trained model. In Brain4J, inference performance can be significantly improved using batched inference.
Instead of processing one input at a time, it's much more efficient to group multiple samples into a single batch. For example, predicting 100 inputs individually is slower and less efficient than predicting all 100 in a single batch.
This is because Brain4J executes tensor operations using multi-threaded routines that scale better with larger data chunks.
Here's a basic example of inference with a single input:
Tensor input = Tensors.vector(dimension); // a single 1D input of length `dimension`
Tensor output = model.predict(input);     // prediction for one sample
This works, but it's not optimal when you need to make many predictions.
Using batched inference:
List<Tensor> inputs = ...; // All inputs must be 1D tensors of the same dimension
Tensor batch = Tensors.mergeTensors(inputs); // Shape: [batch_size, input_dim]
Tensor output = model.predict(batch); // Output shape: [batch_size, output_dim]
Note: Tensors.mergeTensors stacks the input tensors on top of each other, producing a matrix of shape [batch_size, dimension].
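Putting both approaches together, the sketch below contrasts a per-sample prediction loop with a single batched call. It is a minimal, illustrative example: it only uses the Tensor, Tensors, and model.predict APIs shown above, the dimension and sample count are arbitrary placeholders, and the Brain4J imports and model construction are omitted since they depend on your setup.
import java.util.ArrayList;
import java.util.List;

// Assumes a trained `model` and that Tensors.vector(dimension) creates a
// 1D tensor of the given length, as in the snippets above.
int dimension = 16; // hypothetical input size
int samples = 100;  // number of inputs to predict

// Collect all inputs as 1D tensors of the same dimension
List<Tensor> inputs = new ArrayList<>();
for (int i = 0; i < samples; i++) {
    inputs.add(Tensors.vector(dimension));
}

// Slow path: one predict call per sample
for (Tensor input : inputs) {
    Tensor output = model.predict(input);
}

// Fast path: merge into a single [samples, dimension] batch and predict once
Tensor batch = Tensors.mergeTensors(inputs);
Tensor outputs = model.predict(batch); // shape: [samples, output_dim]
Both paths produce the same predictions; the batched call simply lets Brain4J's multi-threaded tensor routines work on one large chunk of data instead of many small ones.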
Check out Examples & Use Cases
This wiki is still under construction. If you feel that you can contribute, please do so! Thanks.