Skip to main content

pyxlog/
lib.rs

1//! Python bindings for XLOG via PyO3.
2#![allow(missing_docs)] // PyO3 #[pyclass] / #[pymethods] generate pub items without docs
3#![allow(
4    clippy::large_enum_variant,
5    clippy::needless_range_loop,
6    clippy::too_many_arguments,
7    clippy::type_complexity
8)]
9
10use std::collections::{HashMap, HashSet};
11use std::os::raw::{c_char, c_void};
12use std::sync::Arc;
13
14use pyo3::exceptions::{PyMemoryError, PyRuntimeError, PyValueError};
15use pyo3::prelude::*;
16use pyo3::types::{PyDict, PyList};
17
18use xlog_core::{MemoryBudget, Schema};
19use xlog_cuda::{
20    device_runtime::{
21        AsyncCudaResource, DeviceMemoryResource, GlobalDeviceBudget, StreamPool, XlogDeviceRuntime,
22    },
23    CudaBuffer, CudaDevice, CudaKernelProvider, DlpackManagedTensor, GpuMemoryManager,
24};
25#[cfg(feature = "arrow-device-import")]
26use xlog_cuda::{ArrowDeviceArray, ArrowDeviceArrayOwned};
27use xlog_gpu::logic as gpu_logic;
28use xlog_logic::ast::ProbEngine;
29use xlog_neural::{NetworkRegistry, TensorSourceRegistry};
30use xlog_prob::exact::GpuConfig;
31
32use xlog_core::RelId;
33use xlog_ir::ExecutionPlan;
34use xlog_logic::ast::Program as AstProgram;
35use xlog_runtime::{Executor, RelationStore};
36
37mod neural_registry;
38use neural_registry::NeuralPredicateRegistry;
39mod dlpack;
40mod ilp;
41mod ilp_exact;
42mod ilp_gpu;
43mod logic;
44mod neural;
45mod program;
46mod training;
47mod types;
48pub(crate) use program::{
49    CachedCircuit, CompiledProbProgram, HardFilter, InputSource, JoinPlan, NeuralGroup,
50    QuerySignature,
51};
52
/// PyCapsule name for a DLPack tensor that has not been consumed yet
/// (NUL-terminated so it can be passed straight to the capsule C API).
const DLPACK_CAPSULE_NAME: &[u8] = b"dltensor\0";
/// Name a consumer assigns to a DLPack capsule once it has taken ownership of the
/// tensor; the destructor only releases capsules still named `dltensor`.
const USED_DLPACK_CAPSULE_NAME: &[u8] = b"used_dltensor\0";

/// PyCapsule name for an Arrow device array that has not been consumed yet (NUL-terminated).
#[cfg(feature = "arrow-device-import")]
const ARROW_DEVICE_ARRAY_CAPSULE_NAME: &[u8] = b"arrow_device_array\0";
/// Name assigned to an Arrow device array capsule after its contents were taken.
#[cfg(feature = "arrow-device-import")]
const USED_ARROW_DEVICE_ARRAY_CAPSULE_NAME: &[u8] = b"used_arrow_device_array\0";
60
61unsafe extern "C" fn dlpack_capsule_destructor(capsule: *mut pyo3::ffi::PyObject) {
62    if capsule.is_null() {
63        return;
64    }
65
66    let valid =
67        pyo3::ffi::PyCapsule_IsValid(capsule, DLPACK_CAPSULE_NAME.as_ptr() as *const c_char);
68    if valid == 0 {
69        return;
70    }
71
72    let ptr =
73        pyo3::ffi::PyCapsule_GetPointer(capsule, DLPACK_CAPSULE_NAME.as_ptr() as *const c_char);
74    if ptr.is_null() {
75        pyo3::ffi::PyErr_Clear();
76        return;
77    }
78
79    let managed = ptr as *mut xlog_cuda::DLManagedTensor;
80    drop(DlpackManagedTensor::from_raw(managed));
81}
82
83pub(crate) fn dlpack_capsule_from_tensor(
84    py: Python<'_>,
85    tensor: DlpackManagedTensor,
86) -> PyResult<PyObject> {
87    let raw = tensor.into_raw();
88    let ptr = raw as *mut c_void;
89    // SAFETY: capsule validity was checked immediately before this call; pointer lifetime is managed by the capsule
90    let capsule = unsafe {
91        pyo3::ffi::PyCapsule_New(
92            ptr,
93            DLPACK_CAPSULE_NAME.as_ptr() as *const c_char,
94            Some(dlpack_capsule_destructor),
95        )
96    };
97    if capsule.is_null() {
98        // SAFETY: the pointer is a valid owned Python object pointer returned by the C API
99        unsafe {
100            drop(DlpackManagedTensor::from_raw(raw));
101        }
102        return Err(PyRuntimeError::new_err("Failed to create DLPack capsule"));
103    }
104    // SAFETY: capsule is a non-null owned pointer returned by PyCapsule_New; PyO3 takes ownership
105    let obj: Py<PyAny> = unsafe { Py::from_owned_ptr(py, capsule) };
106    Ok(obj)
107}
108
/// PyCapsule destructor installed by `arrow_device_capsule_from_device_array`.
///
/// The `ArrowDeviceArray` is released only while the capsule still carries the
/// `"arrow_device_array"` name; once a consumer renames it to
/// `"used_arrow_device_array"`, this destructor does nothing.
#[cfg(feature = "arrow-device-import")]
unsafe extern "C" fn arrow_device_array_capsule_destructor(capsule: *mut pyo3::ffi::PyObject) {
    let expected_name = ARROW_DEVICE_ARRAY_CAPSULE_NAME.as_ptr() as *const c_char;

    // Skip null objects and capsules whose contents were already taken.
    let still_owned =
        !capsule.is_null() && pyo3::ffi::PyCapsule_IsValid(capsule, expected_name) != 0;
    if !still_owned {
        return;
    }

    let payload = pyo3::ffi::PyCapsule_GetPointer(capsule, expected_name);
    if payload.is_null() {
        // A destructor cannot raise; discard the error GetPointer reported.
        pyo3::ffi::PyErr_Clear();
        return;
    }

    drop(ArrowDeviceArrayOwned::from_raw(
        payload.cast::<ArrowDeviceArray>(),
    ));
}
136
/// Wraps an owned Arrow device array in an `"arrow_device_array"` PyCapsule.
///
/// Ownership of the `ArrowDeviceArray` moves into the capsule. It is released by
/// `arrow_device_array_capsule_destructor` unless a consumer first renames the
/// capsule to `"used_arrow_device_array"` and takes ownership itself.
///
/// # Errors
/// Returns `RuntimeError` if the capsule cannot be created. The array is released
/// here, and the exception raised by `PyCapsule_New` (if any) is chained as the
/// error's `__cause__`.
#[cfg(feature = "arrow-device-import")]
pub(crate) fn arrow_device_capsule_from_device_array(
    py: Python<'_>,
    device_array: ArrowDeviceArrayOwned,
) -> PyResult<PyObject> {
    let raw = device_array.into_raw();
    // SAFETY: `raw` comes from `into_raw`, so this function owns it. The capsule name
    // is a static NUL-terminated string that outlives the capsule, as the C API requires.
    let capsule = unsafe {
        pyo3::ffi::PyCapsule_New(
            raw as *mut c_void,
            ARROW_DEVICE_ARRAY_CAPSULE_NAME.as_ptr() as *const c_char,
            Some(arrow_device_array_capsule_destructor),
        )
    };
    if capsule.is_null() {
        // PyCapsule_New sets a Python exception on failure. Take it before releasing
        // the array so no further work runs with an exception pending.
        let cause = PyErr::take(py);
        // SAFETY: capsule creation failed, so ownership of `raw` never left this
        // function; reclaim it exactly once so the array is released.
        unsafe {
            drop(ArrowDeviceArrayOwned::from_raw(raw));
        }
        let err = PyRuntimeError::new_err("Failed to create Arrow device array capsule");
        err.set_cause(py, cause);
        return Err(err);
    }
    // SAFETY: `capsule` is a non-null new (owned) reference returned by PyCapsule_New;
    // PyO3 takes over that reference.
    let obj: Py<PyAny> = unsafe { Py::from_owned_ptr(py, capsule) };
    Ok(obj)
}
165
/// Takes ownership of the `ArrowDeviceArray` stored in an `"arrow_device_array"` capsule.
///
/// On success the capsule is renamed to `"used_arrow_device_array"`, so
/// `arrow_device_array_capsule_destructor` skips it. The returned
/// `ArrowDeviceArrayOwned` is then the sole owner of the pointer.
///
/// NOTE(review): this rename-and-take protocol matches capsules produced by
/// `arrow_device_capsule_from_device_array`. Capsules from other producers may expect
/// the consumer to move the struct out and leave the capsule's own destructor to free
/// the allocation. Confirm that `ArrowDeviceArrayOwned::from_raw` is compatible with
/// such allocations before accepting foreign capsules.
///
/// # Errors
/// - `ValueError` if `obj` is not an unconsumed `"arrow_device_array"` capsule.
/// - `RuntimeError` if the capsule C API fails; the Python exception it set (if any)
///   is chained as `__cause__`, so the error indicator is not left pending.
#[cfg(feature = "arrow-device-import")]
pub(crate) fn arrow_device_from_py(obj: &Bound<'_, PyAny>) -> PyResult<ArrowDeviceArrayOwned> {
    let py = obj.py();
    let name = ARROW_DEVICE_ARRAY_CAPSULE_NAME.as_ptr() as *const c_char;

    // Turns the exception the capsule C API just set (if any) into the cause of a
    // RuntimeError, clearing the interpreter's error indicator in the process.
    let capi_error = |message: &'static str| -> PyErr {
        let cause = PyErr::take(py);
        let err = PyRuntimeError::new_err(message);
        err.set_cause(py, cause);
        err
    };

    // SAFETY: PyCapsule_IsValid accepts any object pointer, only inspects it, and never
    // fails; `obj` is kept alive by the borrow.
    if unsafe { pyo3::ffi::PyCapsule_IsValid(obj.as_ptr(), name) } == 0 {
        return Err(PyValueError::new_err(
            "Expected an Arrow device array capsule (arrow_device_array)",
        ));
    }

    // SAFETY: `obj` is a live capsule whose name matches (checked above).
    let ptr = unsafe { pyo3::ffi::PyCapsule_GetPointer(obj.as_ptr(), name) };
    if ptr.is_null() {
        return Err(capi_error("Failed to get Arrow device array pointer"));
    }

    // Mark consumed before taking ownership, so the capsule destructor no longer
    // matches the original name and will not free the pointer we are about to own.
    // SAFETY: `obj` is a live capsule; the new name is a static NUL-terminated string
    // that outlives it, as PyCapsule_SetName requires.
    let rc = unsafe {
        pyo3::ffi::PyCapsule_SetName(
            obj.as_ptr(),
            USED_ARROW_DEVICE_ARRAY_CAPSULE_NAME.as_ptr() as *const c_char,
        )
    };
    if rc != 0 {
        return Err(capi_error(
            "Failed to mark Arrow device array capsule as consumed",
        ));
    }

    // SAFETY: `ptr` is non-null and was stored under the arrow_device_array name. The
    // capsule is now marked consumed, so ownership transfers here exactly once.
    // Assumes the pointee matches the layout and allocation `from_raw` expects.
    Ok(unsafe { ArrowDeviceArrayOwned::from_raw(ptr as *mut ArrowDeviceArray) })
}
211
212pub(crate) fn provider_from_config(config: GpuConfig) -> xlog_core::Result<CudaKernelProvider> {
213    let device = Arc::new(CudaDevice::new(config.device_ordinal)?);
214    let stream_pool = Arc::new(StreamPool::with_defaults(Arc::clone(&device)));
215    let async_resource: Box<dyn DeviceMemoryResource + Send + Sync> =
216        Box::new(AsyncCudaResource::new(
217            Arc::clone(&device),
218            config.device_ordinal as u32,
219            Arc::clone(&stream_pool),
220        ));
221    let budget_limit = usize::try_from(config.memory_bytes).unwrap_or(usize::MAX);
222    let budgeted: Box<dyn DeviceMemoryResource + Send + Sync> =
223        Box::new(GlobalDeviceBudget::new(async_resource, budget_limit));
224    let runtime = Arc::new(XlogDeviceRuntime::with_resource(
225        Arc::clone(&device),
226        config.device_ordinal as u32,
227        stream_pool,
228        budgeted,
229    ));
230    let memory = Arc::new(GpuMemoryManager::with_runtime(
231        device.clone(),
232        MemoryBudget::with_limit(config.memory_bytes),
233        runtime,
234    ));
235    CudaKernelProvider::with_runtime(device, memory)
236}
237
238pub(crate) fn enforce_call_memory_limit(
239    provider: &Arc<CudaKernelProvider>,
240    memory_mb: Option<u64>,
241) -> PyResult<()> {
242    let Some(memory_mb) = memory_mb else {
243        return Ok(());
244    };
245    if memory_mb == 0 {
246        return Err(PyValueError::new_err("memory_mb must be > 0"));
247    }
248    let memory_limit_bytes = memory_mb.saturating_mul(1024 * 1024);
249    let allocated_bytes = provider.memory().allocated_bytes();
250    if allocated_bytes > memory_limit_bytes {
251        return Err(PyMemoryError::new_err(format!(
252            "per-call memory limit exceeded before evaluation: allocated_bytes={} memory_limit_bytes={}",
253            allocated_bytes, memory_limit_bytes
254        )));
255    }
256    Ok(())
257}
258
259pub(crate) fn provider_memory_stats(
260    py: Python<'_>,
261    provider: &Arc<CudaKernelProvider>,
262) -> PyResult<PyObject> {
263    let dict = PyDict::new(py);
264    let memory = provider.memory();
265    dict.set_item("allocated_bytes", memory.allocated_bytes())?;
266    dict.set_item("memory_limit_bytes", memory.budget().device_bytes)?;
267    dict.set_item("peak_memory_bytes", memory.allocated_bytes())?;
268    dict.set_item("status", "available")?;
269    Ok(dict.into())
270}
271
272#[allow(dead_code)]
273pub(crate) fn pack_rule_provenance(
274    py: Python<'_>,
275    entries: &[xlog_logic::RuleProvenance],
276) -> PyResult<PyObject> {
277    let list = PyList::empty(py);
278    for entry in entries {
279        let dict = PyDict::new(py);
280        dict.set_item("rule_id", &entry.rule_id)?;
281        dict.set_item("source_kind", entry.source_kind.as_str())?;
282        dict.set_item("head", &entry.head)?;
283        match &entry.source_span {
284            Some(source_span) => dict.set_item("source_span", source_span)?,
285            None => dict.set_item("source_span", py.None())?,
286        }
287        match &entry.generation_trace_hash {
288            Some(hash) => dict.set_item("generation_trace_hash", hash)?,
289            None => dict.set_item("generation_trace_hash", py.None())?,
290        }
291        dict.set_item("support_relation_ids", &entry.support_relation_ids)?;
292        dict.set_item(
293            "counterexample_relation_ids",
294            &entry.counterexample_relation_ids,
295        )?;
296        list.append(dict)?;
297    }
298    Ok(list.into())
299}
300
301#[allow(dead_code)]
302pub(crate) fn pack_query_proof_traces(
303    py: Python<'_>,
304    entries: &[xlog_logic::QueryProofTrace],
305) -> PyResult<PyObject> {
306    let list = PyList::empty(py);
307    for entry in entries {
308        let dict = PyDict::new(py);
309        dict.set_item("query_id", &entry.query_id)?;
310        dict.set_item("query", &entry.query)?;
311        dict.set_item("answer_relation", &entry.answer_relation)?;
312        dict.set_item("rule_ids", &entry.rule_ids)?;
313        dict.set_item("source_facts", &entry.source_facts)?;
314        dict.set_item("rejected_alternatives", &entry.rejected_alternatives)?;
315        list.append(dict)?;
316    }
317    Ok(list.into())
318}
319
320pub(crate) fn parse_prob_engine_override(s: &str) -> PyResult<ProbEngine> {
321    let v = s.trim().to_ascii_lowercase();
322    match v.as_str() {
323        "exact_ddnnf" | "exact" | "ddnnf" => Ok(ProbEngine::ExactDdnnf),
324        "mc" => Ok(ProbEngine::Mc),
325        other => Err(PyValueError::new_err(format!(
326            "Unknown prob_engine '{}'; expected 'exact_ddnnf' or 'mc'",
327            other
328        ))),
329    }
330}
331
332pub(crate) fn dlpack_from_py(obj: &Bound<'_, PyAny>) -> PyResult<DlpackManagedTensor> {
333    let py = obj.py();
334
335    // SAFETY: capsule validity was checked immediately before this call; pointer lifetime is managed by the capsule
336    let capsule_obj: Bound<'_, PyAny> = if unsafe {
337        pyo3::ffi::PyCapsule_IsValid(obj.as_ptr(), DLPACK_CAPSULE_NAME.as_ptr() as *const c_char)
338    } != 0
339    {
340        obj.clone()
341    } else if obj.hasattr("__dlpack__")? {
342        match obj.call_method0("__dlpack__") {
343            Ok(v) => v,
344            Err(err) => {
345                if err.is_instance_of::<pyo3::exceptions::PyTypeError>(py) {
346                    obj.call_method1("__dlpack__", (py.None(),))?
347                } else {
348                    return Err(err);
349                }
350            }
351        }
352    } else {
353        return Err(PyValueError::new_err(
354            "Expected a DLPack capsule or an object with __dlpack__",
355        ));
356    };
357
358    // SAFETY: capsule validity was checked immediately before this call; pointer lifetime is managed by the capsule
359    if unsafe {
360        pyo3::ffi::PyCapsule_IsValid(
361            capsule_obj.as_ptr(),
362            DLPACK_CAPSULE_NAME.as_ptr() as *const c_char,
363        )
364    } == 0
365    {
366        return Err(PyValueError::new_err("Invalid DLPack capsule"));
367    }
368
369    // SAFETY: capsule validity was checked immediately before this call; pointer lifetime is managed by the capsule
370    let ptr = unsafe {
371        pyo3::ffi::PyCapsule_GetPointer(
372            capsule_obj.as_ptr(),
373            DLPACK_CAPSULE_NAME.as_ptr() as *const c_char,
374        )
375    };
376    if ptr.is_null() {
377        return Err(PyRuntimeError::new_err("Failed to get DLPack pointer"));
378    }
379
380    // SAFETY: capsule is valid (checked above); renaming marks it consumed so the destructor skips cleanup
381    let rc = unsafe {
382        pyo3::ffi::PyCapsule_SetName(
383            capsule_obj.as_ptr(),
384            USED_DLPACK_CAPSULE_NAME.as_ptr() as *const c_char,
385        )
386    };
387    if rc != 0 {
388        return Err(PyRuntimeError::new_err(
389            "Failed to mark DLPack capsule as consumed",
390        ));
391    }
392
393    // SAFETY: ptr is non-null (checked above) and points to a DLManagedTensor matching the DLPack specification layout
394    Ok(unsafe { DlpackManagedTensor::from_raw(ptr as *mut xlog_cuda::DLManagedTensor) })
395}
396
397#[pyfunction]
398fn dlpack_is_cuda(obj: &Bound<'_, PyAny>) -> PyResult<bool> {
399    // SAFETY: capsule validity is checked before reading the DLPack header. This
400    // does not consume the capsule; ownership remains with its destructor.
401    if unsafe {
402        pyo3::ffi::PyCapsule_IsValid(obj.as_ptr(), DLPACK_CAPSULE_NAME.as_ptr() as *const c_char)
403    } == 0
404    {
405        return Err(PyValueError::new_err(
406            "Expected a DLPack capsule (dltensor)",
407        ));
408    }
409
410    // SAFETY: capsule validity was checked immediately before this call.
411    let ptr = unsafe {
412        pyo3::ffi::PyCapsule_GetPointer(obj.as_ptr(), DLPACK_CAPSULE_NAME.as_ptr() as *const c_char)
413    };
414    if ptr.is_null() {
415        return Err(PyRuntimeError::new_err("Failed to get DLPack pointer"));
416    }
417
418    // SAFETY: ptr is non-null and points to a DLManagedTensor owned by the capsule.
419    let managed = unsafe { &*(ptr as *const xlog_cuda::DLManagedTensor) };
420    Ok(managed.dl_tensor.device.device_type == xlog_cuda::dlpack::K_DLCUDA)
421}
422
/// Python wrapper, exposed as `DifferentiableProofTraceMap`, around
/// `xlog_logic::DifferentiableProofTraceMap`.
#[pyclass(name = "DifferentiableProofTraceMap")]
pub struct PyDifferentiableProofTraceMap {
    /// The wrapped Rust trace map; the Python methods delegate to it.
    inner: xlog_logic::DifferentiableProofTraceMap,
}
427
428fn pack_differentiable_proof_trace(
429    py: Python<'_>,
430    trace: &xlog_logic::ProofTrace,
431) -> PyResult<PyObject> {
432    let dict = PyDict::new(py);
433    dict.set_item("proof_id", trace.proof_id)?;
434    dict.set_item("answer_key", &trace.answer_key)?;
435    dict.set_item("clause_id", &trace.clause_id)?;
436    dict.set_item("support_atoms", &trace.support_atoms)?;
437    dict.set_item("weight", trace.weight)?;
438    dict.set_item("gradient", trace.gradient)?;
439    Ok(dict.into())
440}
441
442#[pymethods]
443impl PyDifferentiableProofTraceMap {
444    #[new]
445    fn new() -> Self {
446        Self {
447            inner: xlog_logic::DifferentiableProofTraceMap::new(),
448        }
449    }
450
451    fn insert(
452        &mut self,
453        answer_key: String,
454        clause_id: String,
455        support_atoms: Vec<String>,
456        initial_weight: f64,
457    ) -> PyResult<u64> {
458        if !initial_weight.is_finite() {
459            return Err(PyValueError::new_err(
460                "initial_weight must be a finite float",
461            ));
462        }
463        Ok(self.inner.insert(xlog_logic::ProofTraceSpec {
464            answer_key,
465            clause_id,
466            support_atoms,
467            initial_weight,
468        }))
469    }
470
471    fn trace(&self, py: Python<'_>, proof_id: u64) -> PyResult<Option<PyObject>> {
472        self.inner
473            .trace(proof_id)
474            .map(|trace| pack_differentiable_proof_trace(py, trace))
475            .transpose()
476    }
477
478    fn traces(&self, py: Python<'_>) -> PyResult<PyObject> {
479        let list = PyList::empty(py);
480        for trace in self.inner.traces() {
481            list.append(pack_differentiable_proof_trace(py, trace)?)?;
482        }
483        Ok(list.into())
484    }
485
486    fn accumulate_binary_logistic_gradients(
487        &mut self,
488        targets: Vec<(String, f64)>,
489    ) -> PyResult<f64> {
490        if targets.iter().any(|(_, target)| !target.is_finite()) {
491            return Err(PyValueError::new_err("targets must be finite floats"));
492        }
493        Ok(self.inner.accumulate_binary_logistic_gradients(&targets))
494    }
495
496    fn apply_gradients(&mut self, learning_rate: f64) -> PyResult<()> {
497        if !learning_rate.is_finite() || learning_rate < 0.0 {
498            return Err(PyValueError::new_err(
499                "learning_rate must be a finite non-negative float",
500            ));
501        }
502        self.inner.apply_gradients(learning_rate);
503        Ok(())
504    }
505}
506
/// Python-visible `Program` class.
///
/// A stateless unit struct; any methods it exposes are defined outside this file
/// (presumably in the `program` module — TODO confirm).
#[pyclass]
pub struct Program;
509
/// A compiled probabilistic program exposed to Python.
///
/// Besides the program itself, it holds the neural-network registries, tensor
/// sources, and the query-signature and circuit caches that the field docs below
/// describe.
#[pyclass]
pub struct CompiledProgram {
    /// The compiled probabilistic program (`CompiledProbProgram`, re-exported from the
    /// `program` module).
    pub(crate) program: CompiledProbProgram,
    /// Shared CUDA kernel provider. Judging by the name, it is presumably the one used
    /// to produce output tensors (TODO confirm).
    pub(crate) output_provider: Arc<CudaKernelProvider>,
    /// Registry for neural networks
    pub(crate) network_registry: NetworkRegistry,
    /// Registry for neural predicate metadata (predicate -> network/labels)
    pub(crate) neural_registry: NeuralPredicateRegistry,
    /// Names of neural networks declared in the program (from nn() declarations)
    pub(crate) declared_networks: HashSet<String>,
    /// Map from network name to form: true = embedding, false = classification
    pub(crate) declared_network_forms: HashMap<String, bool>,
    /// Registry for tensor data sources (images, embeddings, etc.)
    pub(crate) tensor_sources: TensorSourceRegistry,
    /// Name of the Stage-B existential-join domain tensor source, as supplied by
    /// the Python driver (single source of truth). `None` until a domain source is
    /// registered; read by the join forward to resolve `DomainRow`/`ConstDummy`
    /// groups instead of an engine-side hardcoded name.
    pub(crate) domain_source: Option<String>,
    /// Original program source (for dynamic query compilation)
    pub(crate) _source: String,
    /// Parsed program AST (for signature analysis)
    pub(crate) ast: xlog_logic::ast::Program,
    /// GPU configuration
    pub(crate) _gpu_config: GpuConfig,
    /// Probabilistic inference engine
    pub(crate) _prob_engine: ProbEngine,
    /// Cache of analyzed query signatures.
    pub(crate) query_signature_cache: HashMap<String, QuerySignature>,
    /// Cache of compiled circuits by template signature
    pub(crate) circuit_cache: HashMap<String, CachedCircuit>,
    /// Number of circuit-template cache hits observed by neural training paths.
    pub(crate) circuit_cache_hits: usize,
    /// Number of circuit-template cache misses observed by neural training paths.
    pub(crate) circuit_cache_misses: usize,
    /// Number of times the template compilation path executed.
    pub(crate) template_compile_count: usize,
    /// When true, batch queries sharing the same circuit template in training.
    pub(crate) batch_queries: bool,
    /// Latest circuit compilation profile (populated on cache miss when profiling).
    pub(crate) last_compile_profile: Option<xlog_prob::compilation::CircuitCompileProfile>,
}
552
/// Python-visible `LogicProgram` class.
///
/// A stateless unit struct; any methods it exposes are defined outside this file
/// (presumably in the `logic` module — TODO confirm).
#[pyclass]
pub struct LogicProgram;

/// A compiled GPU logic program paired with a shared CUDA kernel provider.
#[pyclass]
pub struct CompiledLogicProgram {
    /// The compiled `gpu_logic::LogicProgram`.
    pub(crate) program: gpu_logic::LogicProgram,
    /// Shared (`Arc`) CUDA kernel provider — presumably the one evaluation runs on;
    /// confirm.
    pub(crate) provider: Arc<CudaKernelProvider>,
}
561
/// Stateful logic session exposed to Python.
///
/// It holds a compiled logic program together with the relation data and
/// bookkeeping it keeps between method calls.
#[pyclass]
pub struct LogicRelationSession {
    /// The compiled `gpu_logic::LogicProgram` this session evaluates.
    pub(crate) program: gpu_logic::LogicProgram,
    /// Shared (`Arc`) CUDA kernel provider.
    pub(crate) provider: Arc<CudaKernelProvider>,
    /// Relation data held by the session.
    pub(crate) relation_store: RelationStore,
    /// Optional second relation store. Presumably the store produced or used by
    /// evaluation (TODO confirm).
    pub(crate) evaluation_store: Option<RelationStore>,
    /// Optional `gpu_logic::LogicSessionRuntime`; `None` until one is created.
    pub(crate) session_runtime: Option<gpu_logic::LogicSessionRuntime>,
    /// Most recently recorded `LogicDeltaStats`, if any.
    pub(crate) last_delta_stats: Option<LogicDeltaStats>,
    /// Registered Python callbacks, each tagged with an id.
    pub(crate) relation_callbacks: Vec<RelationChangeCallback>,
    /// Id to hand out to the next registered callback (presumably increasing — confirm).
    pub(crate) next_relation_callback_id: u64,
    /// Per-relation-name `u64` counter. Presumably bumped whenever a relation changes
    /// (TODO confirm).
    pub(crate) relation_generations: HashMap<String, u64>,
}

/// A Python callable registered on a `LogicRelationSession`, paired with its id.
pub(crate) struct RelationChangeCallback {
    /// Identifier of this registration (presumably used to unregister it — confirm).
    pub id: u64,
    /// The Python callable object.
    pub callback: PyObject,
}
579
/// Statistics for one delta update of a `LogicRelationSession`.
///
/// Stored on the session in its `last_delta_stats` field.
/// NOTE(review): the precise meaning of each counter is defined by the code that
/// fills this struct in, which is outside this file.
#[derive(Clone, Debug)]
pub(crate) struct LogicDeltaStats {
    pub input_delta_count: usize,
    pub changed_relations: usize,
    pub changed_relation_names: Vec<String>,
    pub insert_rows: u64,
    pub delete_rows: u64,
    pub has_deletes: bool,
    pub affected_sccs: usize,
    pub recomputed_sccs: usize,
    pub incremental_sccs: usize,
    pub coalesced_insert_rows: u64,
    pub coalesced_delete_rows: u64,
    pub canceled_rows: u64,
    /// `None` presumably means the equivalence check was not performed; confirm.
    pub equivalent_to_full_recompute: Option<bool>,
    /// Telemetry from the delta planner (`gpu_logic::DeltaPlannerTelemetry`).
    pub planner_telemetry: gpu_logic::DeltaPlannerTelemetry,
    /// Free-form debug trace lines.
    pub debug_trace: Vec<String>,
}
598
/// One query's answer relation, exposed to Python as read-only attributes.
#[pyclass]
pub struct LogicQueryResult {
    /// Name of the relation holding the answers.
    #[pyo3(get)]
    pub relation_name: String,
    /// Column names of the answer relation.
    #[pyo3(get)]
    pub columns: Vec<String>,
    /// Sort labels, presumably one per column (TODO confirm).
    #[pyo3(get)]
    pub sort_labels: Vec<String>,
    /// Python objects holding the answer data. Presumably one DLPack tensor per
    /// column (TODO confirm).
    #[pyo3(get)]
    pub tensors: Vec<PyObject>,
    /// Number of answer rows.
    #[pyo3(get)]
    pub num_rows: usize,
    /// Truth flag for the query (presumably for ground / boolean queries — confirm).
    #[pyo3(get)]
    pub is_true: bool,
}

/// Per-query results of a logic evaluation, exposed to Python.
#[pyclass]
pub struct LogicEvalResult {
    #[pyo3(get)]
    pub queries: Vec<Py<LogicQueryResult>>,
}
620
/// Output of an ILP tagged-credit computation, exposed to Python as read-only attributes.
///
/// Every field is an opaque Python object. Presumably each is a DLPack device tensor,
/// with `fact_row_offsets` indexing into the `entry_*` arrays (TODO confirm against
/// the producer in the `ilp` modules).
#[pyclass]
pub struct IlpTaggedCreditDeviceResult {
    #[pyo3(get)]
    pub fact_row_offsets: PyObject,
    #[pyo3(get)]
    pub entry_indices: PyObject,
    #[pyo3(get)]
    pub entry_i: PyObject,
    #[pyo3(get)]
    pub entry_j: PyObject,
    #[pyo3(get)]
    pub entry_k: PyObject,
}
634
/// Monte Carlo evaluation result whose sample counts stay on the device.
///
/// Exposed to Python as read-only attributes: the counts, sampling metadata, and the
/// `resident_no_host_*` counters. Presumably those counters are the data behind the
/// resident/no-host certification result (TODO confirm).
#[pyclass]
pub struct McDeviceEvalResult {
    /// Per-query satisfying-sample counts. DLPack int32 tensor on CUDA.
    #[pyo3(get)]
    pub query_counts: PyObject,
    /// Evidence satisfying-sample count. DLPack int32 tensor with shape [1] on CUDA.
    #[pyo3(get)]
    pub evidence_count: PyObject,
    /// Total number of samples drawn.
    #[pyo3(get)]
    pub total_samples: usize,
    /// Seed used for sampling.
    #[pyo3(get)]
    pub seed: u64,
    #[pyo3(get)]
    pub confidence: f64,
    #[pyo3(get)]
    pub nonmonotone_semantics: String,
    #[pyo3(get)]
    pub nonmonotone_sccs: usize,
    #[pyo3(get)]
    pub nonmonotone_cycles: usize,
    #[pyo3(get)]
    pub nonmonotone_iteration_limit_hits: usize,
    #[pyo3(get)]
    pub sampling_method: String,
    #[pyo3(get)]
    pub resident_no_host_certified: bool,
    #[pyo3(get)]
    pub resident_no_host_policy_result: String,
    #[pyo3(get)]
    pub resident_no_host_tracked_dtoh_calls: u64,
    #[pyo3(get)]
    pub resident_no_host_tracked_htod_calls: u64,
    #[pyo3(get)]
    pub resident_no_host_host_loop_iterations: u64,
    #[pyo3(get)]
    pub resident_no_host_per_sample_host_launches: u64,
    #[pyo3(get)]
    pub resident_no_host_untracked_metadata_reads: u64,
    #[pyo3(get)]
    pub resident_no_host_engine_launches: u64,
    #[pyo3(get)]
    pub resident_no_host_host_fixpoint_iterations: u64,
    #[pyo3(get)]
    pub resident_no_host_per_operator_host_allocations: u64,
}
680
/// Result of a probabilistic evaluation, exposed to Python as read-only attributes.
///
/// `prob` and `log_prob` are opaque Python objects (presumably tensors — confirm).
/// The `Option` fields are `None` when the engine that produced the result does not
/// report them; for example, `mc_engine` is `None` for exact inference.
#[pyclass]
pub struct EvalResult {
    /// Query atoms, as strings.
    #[pyo3(get)]
    pub atoms: Vec<String>,
    #[pyo3(get)]
    pub prob: PyObject,
    #[pyo3(get)]
    pub log_prob: PyObject,
    #[pyo3(get)]
    pub num_vars: usize,
    #[pyo3(get)]
    pub grad_true: Option<Vec<PyObject>>,
    #[pyo3(get)]
    pub grad_false: Option<Vec<PyObject>>,
    /// Whether the result is approximate (presumably `true` for MC — confirm).
    #[pyo3(get)]
    pub approx: bool,
    #[pyo3(get)]
    pub stderr: Option<PyObject>,
    #[pyo3(get)]
    pub ci_low: Option<PyObject>,
    #[pyo3(get)]
    pub ci_high: Option<PyObject>,
    #[pyo3(get)]
    pub samples: Option<usize>,
    #[pyo3(get)]
    pub evidence_samples: Option<usize>,
    #[pyo3(get)]
    pub seed: Option<u64>,
    #[pyo3(get)]
    pub confidence: Option<f64>,
    #[pyo3(get)]
    pub nonmonotone_semantics: Option<String>,
    #[pyo3(get)]
    pub nonmonotone_sccs: Option<usize>,
    #[pyo3(get)]
    pub nonmonotone_cycles: Option<usize>,
    #[pyo3(get)]
    pub nonmonotone_iteration_limit_hits: Option<usize>,
    #[pyo3(get)]
    pub sampling_method: Option<String>,
    /// MC only: which engine produced the result — `"gpu-resident"` for the
    /// resident megakernel engine, `"cpu-oracle"` for the explicitly opted-in
    /// CPU oracle. `None` for exact inference.
    #[pyo3(get)]
    pub mc_engine: Option<String>,
}
727
// =========================================================================
// Training Infrastructure
// =========================================================================

/// Statistics for a single training epoch.
///
/// Exposed to Python as read-only attributes.
#[pyclass]
#[derive(Clone)]
pub struct EpochStats {
    /// Average loss across all batches in the epoch
    #[pyo3(get)]
    pub avg_loss: f64,
    /// Number of batches processed
    #[pyo3(get)]
    pub num_batches: usize,
    /// Total number of queries processed
    #[pyo3(get)]
    pub total_queries: usize,
}

/// Training history tracking loss over epochs and batches.
///
/// Exposed to Python as read-only attributes.
#[pyclass]
#[derive(Clone)]
pub struct TrainingHistory {
    /// Loss at the end of each epoch
    #[pyo3(get)]
    pub epoch_losses: Vec<f64>,
    /// Wall-clock time (seconds) for each epoch
    #[pyo3(get)]
    pub epoch_times: Vec<f64>,
    /// Loss for each batch across all epochs
    #[pyo3(get)]
    pub batch_losses: Vec<f64>,
    /// True if training was stopped early due to validation loss plateau.
    #[pyo3(get)]
    pub stopped_early: bool,
}
764
/// Python-visible `IlpProgramFactory` class.
///
/// A stateless unit struct; any methods it exposes are defined outside this file
/// (presumably in the `ilp` modules — TODO confirm).
#[pyclass]
pub struct IlpProgramFactory;

/// A compiled ILP program exposed to Python.
///
/// It holds the source, AST, execution plan and executor, plus the join/projection
/// metadata and candidate tables used by the ILP bindings. NOTE(review): the
/// semantics of the join/candidate fields are defined by the `ilp*` modules; the
/// notes below are inferred from names only.
#[pyclass]
pub struct CompiledIlpProgram {
    /// Base program source.
    pub(crate) base_source: String,
    /// Source of the learnable part of the program.
    pub(crate) _learnable_source: String,
    /// Parsed program AST.
    pub(crate) ast: AstProgram,
    pub(crate) executor: Executor,
    /// Shared (`Arc`) CUDA kernel provider.
    pub(crate) provider: Arc<CudaKernelProvider>,
    pub(crate) plan: ExecutionPlan,
    /// Pairs of relation id and relation name (presumably for name lookups — confirm).
    pub(crate) rel_index: Vec<(RelId, String)>,
    /// Relation schemas keyed by relation name.
    pub(crate) schemas: HashMap<String, Schema>,
    /// Column indices, presumably of the join keys on each side (TODO confirm).
    pub(crate) left_keys: Vec<usize>,
    pub(crate) right_keys: Vec<usize>,
    /// Column indices, presumably projected into the rule head (TODO confirm).
    pub(crate) head_projection: Vec<usize>,
    pub(crate) compiled_schema_size: usize,
    pub(crate) head_rel_name: String,
    pub(crate) max_active_rules: usize,
    /// Optional map from a `(u32, u32, u32)` candidate key to a `u32` index.
    pub(crate) candidate_map: Option<HashMap<(u32, u32, u32), u32>>,
    /// Optional ordered list of `(u32, u32, u32)` candidate keys.
    pub(crate) candidate_order: Option<Vec<(u32, u32, u32)>>,
    /// Device buffers keyed by relation name (presumably substituted for the stored
    /// relation data — confirm).
    pub(crate) relation_overrides: HashMap<String, CudaBuffer>,
    /// Maximum bytes for per-chunk temp allocations (masks, prefix sums,
    /// chunk-local COO scratch). The final merged COO buffer is exact-NNZ
    /// sized and may exceed this budget. Default: 16 MB.
    pub(crate) coo_chunk_budget: u64,
    /// When true, raise instead of falling back to chunked COO path.
    /// Use in zero-D2H benchmarks and CI gates. Default: false.
    pub(crate) strict_zero_dtoh: bool,
}
795
/// Initializer for the native extension module, named `_native`.
///
/// Registers `__version__` (the crate version taken from Cargo at build time), every
/// `#[pyclass]` defined in this file, and the free functions exposed to Python. The
/// two Arrow device functions are registered only when the `arrow-device-import`
/// feature is enabled.
#[pymodule]
#[pyo3(name = "_native")]
fn pyxlog(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
    // Program, logic-session and result classes
    m.add_class::<Program>()?;
    m.add_class::<CompiledProgram>()?;
    m.add_class::<LogicProgram>()?;
    m.add_class::<CompiledLogicProgram>()?;
    m.add_class::<LogicRelationSession>()?;
    m.add_class::<LogicQueryResult>()?;
    m.add_class::<LogicEvalResult>()?;
    m.add_class::<McDeviceEvalResult>()?;
    m.add_class::<EvalResult>()?;
    // Training infrastructure
    m.add_class::<PyDifferentiableProofTraceMap>()?;
    m.add_class::<EpochStats>()?;
    m.add_class::<TrainingHistory>()?;
    // ILP bindings
    m.add_class::<IlpProgramFactory>()?;
    m.add_class::<CompiledIlpProgram>()?;
    m.add_class::<IlpTaggedCreditDeviceResult>()?;
    // Free functions
    m.add_function(wrap_pyfunction!(training::train_model, m)?)?;
    m.add_function(wrap_pyfunction!(training::train_model_tensor, m)?)?;
    m.add_function(wrap_pyfunction!(dlpack::dlpack_roundtrip, m)?)?;
    m.add_function(wrap_pyfunction!(dlpack_is_cuda, m)?)?;
    #[cfg(feature = "arrow-device-import")]
    m.add_function(wrap_pyfunction!(dlpack::export_arrow_device, m)?)?;
    #[cfg(feature = "arrow-device-import")]
    m.add_function(wrap_pyfunction!(dlpack::import_arrow_device, m)?)?;
    Ok(())
}