Module batch

Batched Neural Evaluation

This module provides infrastructure for grouping neural predicate calls by network name, so that all calls to the same network can be evaluated together in one batched (typically GPU) forward pass.

Why Batching?

In DeepProbLog-style programs, the same neural network may be called many times:

```prolog
nn(mnist_net, [X], Y, [0..9]) :: digit(X, Y).
addition(X, Y, Z) :- digit(X, LeftDigit), digit(Y, RightDigit), Z is LeftDigit + RightDigit.
```

For a query like addition(img1, img2, Z), mnist_net must be evaluated twice, once for each digit. Instead of running two separate forward passes, we batch both calls into a single mnist_net([img1, img2]) call, which keeps the GPU busier than two single-input passes would.
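The grouping step can be sketched with a plain `HashMap` from network name to call indices. This is a minimal sketch under stated assumptions: `Call`, `group_by_network`, and `batch_inputs` are hypothetical names, not this crate's `NeuralCall` or `BatchCollector` API, and inputs are modelled as plain image indices rather than tensors.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a neural predicate call: the network to run
/// and the input it should be applied to (here, just an image index).
#[derive(Debug, Clone)]
struct Call {
    network: String,
    input: usize,
}

/// Groups call indices by network name. Indices within each group keep the
/// order in which the calls were recorded during proof search.
fn group_by_network(calls: &[Call]) -> HashMap<String, Vec<usize>> {
    let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
    for (i, call) in calls.iter().enumerate() {
        groups.entry(call.network.clone()).or_default().push(i);
    }
    groups
}

/// Gathers the inputs of one group so they can be fed to the network as a
/// single batch, i.e. `mnist_net([img1, img2])` instead of two calls.
fn batch_inputs(calls: &[Call], indices: &[usize]) -> Vec<usize> {
    indices.iter().map(|&i| calls[i].input).collect()
}
```

Given the two `digit/2` calls produced by `addition(img1, img2, Z)`, this yields a single `mnist_net` group with call indices `[0, 1]`, and `batch_inputs` gathers both images into one batch for a single forward pass.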

Usage

```rust
use xlog_neural::batch::{BatchCollector, NeuralCall};

let mut collector = BatchCollector::new();

// During proof search, collect neural calls
collector.add(NeuralCall::new("mnist", vec![0])); // digit(img[0], Y)
collector.add(NeuralCall::new("mnist", vec![1])); // digit(img[1], Y)

// Group by network for batched evaluation
let batches = collector.collect();
let mnist_indices = collector.indices_for_network("mnist");
// mnist_indices = [0, 1] - evaluate both in one forward pass
```

Structs

BatchCollector
Collects neural predicate calls for batched evaluation.
BatchMapping
Mapping from call index to batch result index.
BatchResult
Result of a batched neural evaluation.
NeuralCall
A single neural predicate call.
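To make the BatchMapping and BatchResult entries concrete, here is a hypothetical sketch of mapping each call index back to its row in a batched output. The `Slot`, `build_mapping`, and `output_for_call` names, and the representation of outputs as one `Vec<Vec<f32>>` per network (one row per call, e.g. a distribution over digit classes), are assumptions for illustration, not this crate's types.

```rust
use std::collections::HashMap;

/// Hypothetical location of one call's output inside a batched result:
/// which network's batch, and which row of that batch.
#[derive(Debug, Clone, PartialEq)]
struct Slot {
    network: String,
    row: usize,
}

/// Builds a call-index -> batch-slot mapping from per-network groups of
/// call indices (as produced by grouping calls by network name).
fn build_mapping(groups: &HashMap<String, Vec<usize>>) -> HashMap<usize, Slot> {
    let mut mapping = HashMap::new();
    for (network, indices) in groups {
        for (row, &call_idx) in indices.iter().enumerate() {
            mapping.insert(call_idx, Slot { network: network.clone(), row });
        }
    }
    mapping
}

/// Looks up the output row for a single call, given the outputs of each
/// batched forward pass keyed by network name. Returns None if the call
/// was never batched or the batch has no such row.
fn output_for_call<'a>(
    call_idx: usize,
    mapping: &HashMap<usize, Slot>,
    outputs: &'a HashMap<String, Vec<Vec<f32>>>,
) -> Option<&'a [f32]> {
    let slot = mapping.get(&call_idx)?;
    outputs.get(&slot.network)?.get(slot.row).map(|row| row.as_slice())
}
```

Keying the mapping by call index keeps result lookup independent of the order in which groups are visited, which matters because `HashMap` iteration order is unspecified.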