Batched Neural Evaluation
This module provides infrastructure for grouping neural predicate calls by network name, enabling efficient batched GPU evaluation.
§Why Batching?
In DeepProbLog-style programs, the same neural network may be called many times:
nn(mnist_net, [X], Y, [0..9]) :: digit(X, Y).
addition(X, Y, Z) :- digit(X, LeftDigit), digit(Y, RightDigit), Z is LeftDigit + RightDigit.
For a query like addition(img1, img2, Z), we need to evaluate mnist_net
twice (once for each digit). Instead of two separate forward passes, we batch
them into a single mnist_net([img1, img2]) call for GPU efficiency.
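As a rough illustration of what the batched call buys, the sketch below uses a hypothetical `ToyNet` in place of `mnist_net`. It answers both `digit` subgoals with one `forward` call over `[img1, img2]` (here the image ids 3 and 5), then combines the two output distributions into P(Z) for `addition`. The toy probabilities, `ToyNet`, and `addition_distribution` are assumptions made for this example, and the combination step assumes the two digit classifications are independent.

```rust
/// Hypothetical stand-in for `mnist_net`: each `forward` call counts as ONE forward pass
/// over the whole batch. Inputs are plain image ids; the toy output puts 0.9 on digit
/// `id % 10` and spreads the remaining 0.1 evenly over the other nine digits.
struct ToyNet {
    forward_passes: usize,
}

impl ToyNet {
    fn forward(&mut self, batch: &[usize]) -> Vec<[f64; 10]> {
        self.forward_passes += 1;
        batch
            .iter()
            .map(|&img| {
                let mut probs = [0.1 / 9.0; 10];
                probs[img % 10] = 0.9;
                probs
            })
            .collect()
    }
}

/// P(Z = z) for addition(X, Y, Z): sum of P(digit(X, a)) * P(digit(Y, b)) over all
/// pairs with a + b = z (assumes the two classifications are independent).
fn addition_distribution(left: &[f64; 10], right: &[f64; 10]) -> [f64; 19] {
    let mut p_z = [0.0; 19];
    for a in 0..10 {
        for b in 0..10 {
            p_z[a + b] += left[a] * right[b];
        }
    }
    p_z
}

fn main() {
    let mut net = ToyNet { forward_passes: 0 };
    // Both digit(img1, _) and digit(img2, _) are answered by a single batched call.
    let out = net.forward(&[3, 5]);
    let p_z = addition_distribution(&out[0], &out[1]);
    println!("forward passes: {}", net.forward_passes);
    println!("P(Z = 8) = {:.4}", p_z[8]);
}
```

The point is the counter: two `digit` subgoals cost one forward pass, and the per-image rows of the batched output are then consumed independently by the rest of the proof.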
§Usage
use xlog_neural::batch::{BatchCollector, NeuralCall};
let mut collector = BatchCollector::new();
// During proof search, collect neural calls
collector.add(NeuralCall::new("mnist", vec![0])); // digit(img[0], Y)
collector.add(NeuralCall::new("mnist", vec![1])); // digit(img[1], Y)
// Group by network for batched evaluation
let batches = collector.collect();
let mnist_indices = collector.indices_for_network("mnist");
// mnist_indices = [0, 1] - evaluate both in one forward pass
§Structs
- BatchCollector - Collects neural predicate calls for batched evaluation.
- BatchMapping - Mapping from call index to batch result index.
- BatchResult - Result of a batched neural evaluation.
- NeuralCall - A single neural predicate call.
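To make the relationship between these types concrete, here is a minimal sketch of the bookkeeping they imply, assuming calls are grouped by network name in proof-search order and each batch returns one output row per call. `Call`, `Collector`, `Slot`, `batch_inputs`, `mapping`, and `result_for_call` are hypothetical names chosen for this illustration. They are not the crate's `NeuralCall`, `BatchCollector`, `BatchMapping`, or `BatchResult`, whose fields and methods may differ.

```rust
use std::collections::BTreeMap;

/// Hypothetical stand-in for a neural call: network name plus input indices.
struct Call {
    network: String,
    inputs: Vec<usize>,
}

/// Hypothetical mapping entry: which network's batch a call went to, and its row there.
#[derive(Debug, Clone, PartialEq)]
struct Slot {
    network: String,
    row: usize,
}

/// Hypothetical collector: records calls in proof-search order.
#[derive(Default)]
struct Collector {
    calls: Vec<Call>,
}

impl Collector {
    /// Records a call and returns its call index.
    fn add(&mut self, network: &str, inputs: Vec<usize>) -> usize {
        self.calls.push(Call { network: network.to_string(), inputs });
        self.calls.len() - 1
    }

    /// Call indices belonging to one network, in insertion order.
    fn indices_for_network(&self, network: &str) -> Vec<usize> {
        self.calls
            .iter()
            .enumerate()
            .filter(|(_, c)| c.network == network)
            .map(|(i, _)| i)
            .collect()
    }

    /// Flattened inputs for one network's batch: what a single forward pass would receive.
    fn batch_inputs(&self, network: &str) -> Vec<usize> {
        let mut out = Vec::new();
        for i in self.indices_for_network(network) {
            out.extend_from_slice(&self.calls[i].inputs);
        }
        out
    }

    /// Maps every call index to (network, row within that network's batch).
    fn mapping(&self) -> Vec<Slot> {
        let mut next_row: BTreeMap<&str, usize> = BTreeMap::new();
        let mut slots = Vec::with_capacity(self.calls.len());
        for call in &self.calls {
            let row = next_row.entry(call.network.as_str()).or_insert(0);
            slots.push(Slot { network: call.network.clone(), row: *row });
            *row += 1;
        }
        slots
    }
}

/// Looks up one call's output in per-network batch results (one row per batched call).
fn result_for_call<'a>(
    mapping: &[Slot],
    results: &'a BTreeMap<String, Vec<Vec<f64>>>,
    call: usize,
) -> &'a [f64] {
    let slot = &mapping[call];
    &results[&slot.network][slot.row]
}

fn main() {
    let mut collector = Collector::default();
    collector.add("mnist", vec![0]); // digit(img[0], Y)
    collector.add("coin", vec![5]); // a call to some other (hypothetical) network
    collector.add("mnist", vec![1]); // digit(img[1], Y)

    println!("mnist batch inputs: {:?}", collector.batch_inputs("mnist"));

    // Pretend these rows came back from one forward pass per network.
    let mut results = BTreeMap::new();
    results.insert("mnist".to_string(), vec![vec![0.1, 0.9], vec![0.7, 0.3]]);
    results.insert("coin".to_string(), vec![vec![0.5, 0.5]]);

    let mapping = collector.mapping();
    for (call, slot) in mapping.iter().enumerate() {
        println!("call {} -> {:?} -> {:?}", call, slot, result_for_call(&mapping, &results, call));
    }
}
```

Keeping the mapping separate from the results means calls to different networks can be interleaved during proof search: call 2 lands in row 1 of the `mnist` batch even though a `coin` call was collected in between.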