Skip to main content

xlog_cuda_tests/harness/
xgcf.rs

1//! Helpers for testing XGCF circuit CUDA kernels.
2
3use std::ffi::c_void;
4
5use cudarc::driver::{DeviceSlice, LaunchConfig};
6use xlog_core::{Result, XlogError};
7use xlog_cuda::memory::TrackedCudaSlice;
8use xlog_cuda::{circuit_kernels, AsKernelParam, CudaFunction, LaunchAsync, CIRCUIT_MODULE};
9
10use super::TestContext;
11
/// Host-side description of a small XGCF circuit together with the reference
/// results the CUDA kernels are expected to reproduce.
///
/// The per-node vectors (`node_type`, `lit`, `decision_*`) are indexed by node
/// id; see `tiny_xgcf_spec` for a fully worked example of the layout.
#[derive(Clone, Debug)]
pub struct TinyXgcfSpec {
    // --- Circuit dimensions ---
    /// Number of nodes; sizes the per-node value and adjoint buffers.
    pub num_nodes: usize,
    /// Number of variables; gradient buffers are allocated with `num_vars + 1` entries.
    pub num_vars: usize,
    /// Node id of the circuit root (must be `< num_nodes` for root-based helpers).
    pub root: u32,

    // --- Circuit structure ---
    /// Per-node type tag as consumed by the kernels.
    pub node_type: Vec<u8>,
    /// Per-node start offsets into `child_indices`.
    /// NOTE(review): the tiny spec stores `num_nodes + 1` entries (CSR-style) — confirm
    /// against the kernel contract.
    pub child_offsets: Vec<u32>,
    /// Flattened child node ids, addressed through `child_offsets`.
    pub child_indices: Vec<u32>,
    /// Per-node signed literal; the tiny spec uses `0` for nodes that are not literals.
    pub lit: Vec<i32>,
    /// Per-node decision variable; the tiny spec uses `0` for non-decision nodes.
    pub decision_var: Vec<u32>,
    /// Per-node child taken when the decision variable is false.
    pub decision_child_false: Vec<u32>,
    /// Per-node child taken when the decision variable is true.
    pub decision_child_true: Vec<u32>,

    // --- Level schedule ---
    /// Node ids grouped level by level.
    pub level_nodes: Vec<u32>,
    /// `(offset, len)` of each level's slice of `level_nodes`, in evaluation order.
    pub levels: Vec<(u32, u32)>,

    // --- Weights ---
    /// Log-weight of each variable being true.
    /// In the tiny spec index 0 is a placeholder and variables are numbered from 1.
    pub var_log_true: Vec<f64>,
    /// Log-weight of each variable being false (same indexing as `var_log_true`).
    pub var_log_false: Vec<f64>,

    // --- Reference results ---
    /// Expected per-node forward values.
    pub expected_values: Vec<f64>,
    /// Expected gradient with respect to each variable's "true" log-weight.
    pub expected_grad_true: Vec<f64>,
    /// Expected gradient with respect to each variable's "false" log-weight.
    pub expected_grad_false: Vec<f64>,
}
32
/// Host-side results of one evaluation of a tiny XGCF circuit.
///
/// NOTE(review): the field meanings below are inferred from the field names and
/// from the matching device buffers on `TinyXgcfDevice`; confirm against the
/// function that builds this struct.
#[derive(Clone, Debug)]
pub struct TinyXgcfRun {
    /// Per-node forward values.
    pub values: Vec<f64>,
    /// Per-node adjoints from the backward pass.
    pub adj: Vec<f64>,
    /// Per-variable gradient for the "true" log-weights.
    pub grad_true: Vec<f64>,
    /// Per-variable gradient for the "false" log-weights.
    pub grad_false: Vec<f64>,
}
40
41/// Device-resident XGCF circuit + reusable buffers.
42///
43/// This is used by certification categories that validate *transfer efficiency* and *circuit reuse*.
44/// The key property: circuit structure is uploaded once; repeated evaluations reuse device buffers.
45pub struct TinyXgcfDevice {
46    pub num_nodes: usize,
47    pub num_vars: usize,
48    pub root: u32,
49    levels: Vec<(u32, u32)>,
50
51    // Cached kernel handles for performance-sensitive certification categories.
52    forward_fn: CudaFunction,
53    backward_propagate_fn: CudaFunction,
54    backward_decision_grad_fn: CudaFunction,
55    backward_lit_grad_fn: CudaFunction,
56
57    // Circuit structure (device-resident).
58    d_node_type: TrackedCudaSlice<u8>,
59    d_child_offsets: TrackedCudaSlice<u32>,
60    d_child_indices: TrackedCudaSlice<u32>,
61    d_lit: TrackedCudaSlice<i32>,
62    d_decision_var: TrackedCudaSlice<u32>,
63    d_decision_child_false: TrackedCudaSlice<u32>,
64    d_decision_child_true: TrackedCudaSlice<u32>,
65    d_level_nodes: TrackedCudaSlice<u32>,
66    d_level_offsets: TrackedCudaSlice<u32>,
67
68    // Per-evaluation inputs (device-resident).
69    d_var_log_true: TrackedCudaSlice<f64>,
70    d_var_log_false: TrackedCudaSlice<f64>,
71
72    // Per-evaluation outputs / scratch (device-resident).
73    d_values: TrackedCudaSlice<f64>,
74    d_adj: TrackedCudaSlice<f64>,
75    d_grad_true: TrackedCudaSlice<f64>,
76    d_grad_false: TrackedCudaSlice<f64>,
77}
78
79impl TinyXgcfDevice {
80    pub fn upload(ctx: &TestContext, spec: &TinyXgcfSpec) -> Result<Self> {
81        let device = ctx.device.inner();
82
83        let forward_fn = device
84            .get_func(CIRCUIT_MODULE, circuit_kernels::XGCF_FORWARD_LEVEL)
85            .ok_or_else(|| {
86                XlogError::Kernel(format!(
87                    "Kernel {} not found in {}",
88                    circuit_kernels::XGCF_FORWARD_LEVEL,
89                    CIRCUIT_MODULE
90                ))
91            })?;
92        let backward_propagate_fn = device
93            .get_func(
94                CIRCUIT_MODULE,
95                circuit_kernels::XGCF_BACKWARD_LEVEL_PROPAGATE,
96            )
97            .ok_or_else(|| {
98                XlogError::Kernel(format!(
99                    "Kernel {} not found in {}",
100                    circuit_kernels::XGCF_BACKWARD_LEVEL_PROPAGATE,
101                    CIRCUIT_MODULE
102                ))
103            })?;
104        let backward_decision_grad_fn = device
105            .get_func(
106                CIRCUIT_MODULE,
107                circuit_kernels::XGCF_BACKWARD_LEVEL_DECISION_GRAD,
108            )
109            .ok_or_else(|| {
110                XlogError::Kernel(format!(
111                    "Kernel {} not found in {}",
112                    circuit_kernels::XGCF_BACKWARD_LEVEL_DECISION_GRAD,
113                    CIRCUIT_MODULE
114                ))
115            })?;
116        let backward_lit_grad_fn = device
117            .get_func(
118                CIRCUIT_MODULE,
119                circuit_kernels::XGCF_BACKWARD_LEVEL_LIT_GRAD,
120            )
121            .ok_or_else(|| {
122                XlogError::Kernel(format!(
123                    "Kernel {} not found in {}",
124                    circuit_kernels::XGCF_BACKWARD_LEVEL_LIT_GRAD,
125                    CIRCUIT_MODULE
126                ))
127            })?;
128
129        let mut d_node_type = ctx.memory.alloc::<u8>(spec.node_type.len())?;
130        ctx.htod_sync_copy_into(&spec.node_type, &mut d_node_type)
131            .map_err(|e| XlogError::Kernel(format!("Failed to upload node_type: {}", e)))?;
132
133        let mut d_child_offsets = ctx.memory.alloc::<u32>(spec.child_offsets.len())?;
134        ctx.htod_sync_copy_into(&spec.child_offsets, &mut d_child_offsets)
135            .map_err(|e| XlogError::Kernel(format!("Failed to upload child_offsets: {}", e)))?;
136
137        let mut d_child_indices = ctx.memory.alloc::<u32>(spec.child_indices.len())?;
138        ctx.htod_sync_copy_into(&spec.child_indices, &mut d_child_indices)
139            .map_err(|e| XlogError::Kernel(format!("Failed to upload child_indices: {}", e)))?;
140
141        let mut d_lit = ctx.memory.alloc::<i32>(spec.lit.len())?;
142        ctx.htod_sync_copy_into(&spec.lit, &mut d_lit)
143            .map_err(|e| XlogError::Kernel(format!("Failed to upload lit: {}", e)))?;
144
145        let mut d_decision_var = ctx.memory.alloc::<u32>(spec.decision_var.len())?;
146        ctx.htod_sync_copy_into(&spec.decision_var, &mut d_decision_var)
147            .map_err(|e| XlogError::Kernel(format!("Failed to upload decision_var: {}", e)))?;
148
149        let mut d_decision_child_false =
150            ctx.memory.alloc::<u32>(spec.decision_child_false.len())?;
151        ctx.htod_sync_copy_into(&spec.decision_child_false, &mut d_decision_child_false)
152            .map_err(|e| {
153                XlogError::Kernel(format!("Failed to upload decision_child_false: {}", e))
154            })?;
155
156        let mut d_decision_child_true = ctx.memory.alloc::<u32>(spec.decision_child_true.len())?;
157        ctx.htod_sync_copy_into(&spec.decision_child_true, &mut d_decision_child_true)
158            .map_err(|e| {
159                XlogError::Kernel(format!("Failed to upload decision_child_true: {}", e))
160            })?;
161
162        let mut d_level_nodes = ctx.memory.alloc::<u32>(spec.level_nodes.len())?;
163        ctx.htod_sync_copy_into(&spec.level_nodes, &mut d_level_nodes)
164            .map_err(|e| XlogError::Kernel(format!("Failed to upload level_nodes: {}", e)))?;
165
166        // Device-resident level offsets (len = num_levels + 1) for level-aware kernels.
167        if spec.levels.is_empty() {
168            return Err(XlogError::Kernel(
169                "TinyXgcfSpec requires non-empty levels".to_string(),
170            ));
171        }
172        let mut level_offsets: Vec<u32> = Vec::with_capacity(spec.levels.len() + 1);
173        for &(offset, _len) in &spec.levels {
174            level_offsets.push(offset);
175        }
176        let (last_offset, last_len) = *spec.levels.last().unwrap();
177        level_offsets.push(last_offset + last_len);
178
179        if level_offsets[0] != 0 {
180            return Err(XlogError::Kernel(
181                "TinyXgcfSpec level_offsets must start at 0".to_string(),
182            ));
183        }
184        for (i, &(offset, len)) in spec.levels.iter().enumerate() {
185            let expected_next = offset + len;
186            if level_offsets[i] != offset || level_offsets[i + 1] != expected_next {
187                return Err(XlogError::Kernel(
188                    "TinyXgcfSpec levels must be contiguous and match offsets".to_string(),
189                ));
190            }
191        }
192        let total = *level_offsets.last().unwrap() as usize;
193        if total != spec.level_nodes.len() {
194            return Err(XlogError::Kernel(format!(
195                "TinyXgcfSpec level_nodes len {} != level_offsets.last {}",
196                spec.level_nodes.len(),
197                total
198            )));
199        }
200
201        let mut d_level_offsets = ctx.memory.alloc::<u32>(level_offsets.len())?;
202        ctx.htod_sync_copy_into(&level_offsets, &mut d_level_offsets)
203            .map_err(|e| XlogError::Kernel(format!("Failed to upload level_offsets: {}", e)))?;
204
205        let mut d_var_log_true = ctx.memory.alloc::<f64>(spec.var_log_true.len())?;
206        ctx.htod_sync_copy_into(&spec.var_log_true, &mut d_var_log_true)
207            .map_err(|e| XlogError::Kernel(format!("Failed to upload var_log_true: {}", e)))?;
208
209        let mut d_var_log_false = ctx.memory.alloc::<f64>(spec.var_log_false.len())?;
210        ctx.htod_sync_copy_into(&spec.var_log_false, &mut d_var_log_false)
211            .map_err(|e| XlogError::Kernel(format!("Failed to upload var_log_false: {}", e)))?;
212
213        let d_values = ctx.memory.alloc::<f64>(spec.num_nodes)?;
214        let d_adj = ctx.memory.alloc::<f64>(spec.num_nodes)?;
215        let d_grad_true = ctx.memory.alloc::<f64>(spec.num_vars + 1)?;
216        let d_grad_false = ctx.memory.alloc::<f64>(spec.num_vars + 1)?;
217
218        Ok(Self {
219            num_nodes: spec.num_nodes,
220            num_vars: spec.num_vars,
221            root: spec.root,
222            levels: spec.levels.clone(),
223            forward_fn,
224            backward_propagate_fn,
225            backward_decision_grad_fn,
226            backward_lit_grad_fn,
227            d_node_type,
228            d_child_offsets,
229            d_child_indices,
230            d_lit,
231            d_decision_var,
232            d_decision_child_false,
233            d_decision_child_true,
234            d_level_nodes,
235            d_level_offsets,
236            d_var_log_true,
237            d_var_log_false,
238            d_values,
239            d_adj,
240            d_grad_true,
241            d_grad_false,
242        })
243    }
244
245    fn launch_level_cached(
246        kernel: &CudaFunction,
247        num_level_nodes: u32,
248        params: &mut Vec<*mut c_void>,
249    ) -> Result<()> {
250        if num_level_nodes == 0 {
251            return Ok(());
252        }
253        let block_size = 256u32;
254        let num_blocks = (num_level_nodes + block_size - 1) / block_size;
255        let config = LaunchConfig {
256            grid_dim: (num_blocks, 1, 1),
257            block_dim: (block_size, 1, 1),
258            shared_mem_bytes: 0,
259        };
260        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
261        unsafe { kernel.clone().launch(config, params) }
262            .map_err(|e| XlogError::Kernel(format!("Failed to launch level kernel: {}", e)))?;
263        Ok(())
264    }
265
266    pub fn set_weights(
267        &mut self,
268        ctx: &TestContext,
269        log_true: &[f64],
270        log_false: &[f64],
271    ) -> Result<()> {
272        if log_true.len() != self.d_var_log_true.len()
273            || log_false.len() != self.d_var_log_false.len()
274        {
275            return Err(XlogError::Kernel(format!(
276                "Weight length mismatch: got (true={}, false={}), expected (true={}, false={})",
277                log_true.len(),
278                log_false.len(),
279                self.d_var_log_true.len(),
280                self.d_var_log_false.len()
281            )));
282        }
283        ctx.htod_sync_copy_into(log_true, &mut self.d_var_log_true)
284            .map_err(|e| XlogError::Kernel(format!("Failed to upload var_log_true: {}", e)))?;
285        ctx.htod_sync_copy_into(log_false, &mut self.d_var_log_false)
286            .map_err(|e| XlogError::Kernel(format!("Failed to upload var_log_false: {}", e)))?;
287        Ok(())
288    }
289
290    /// Launch forward kernels (no sync, no host transfers).
291    pub fn forward_launch(&mut self, _ctx: &TestContext) -> Result<()> {
292        for (level, &(_offset, len)) in self.levels.iter().enumerate() {
293            let level_u32 = level as u32;
294            let mut params: Vec<*mut c_void> = vec![
295                (&self.d_node_type).as_kernel_param(),
296                (&self.d_child_offsets).as_kernel_param(),
297                (&self.d_child_indices).as_kernel_param(),
298                (&self.d_lit).as_kernel_param(),
299                (&self.d_decision_var).as_kernel_param(),
300                (&self.d_decision_child_false).as_kernel_param(),
301                (&self.d_decision_child_true).as_kernel_param(),
302                (&self.d_level_nodes).as_kernel_param(),
303                (&self.d_level_offsets).as_kernel_param(),
304                level_u32.as_kernel_param(),
305                (&self.d_var_log_true).as_kernel_param(),
306                (&self.d_var_log_false).as_kernel_param(),
307                (&mut self.d_values).as_kernel_param(),
308            ];
309            Self::launch_level_cached(&self.forward_fn, len, &mut params)?;
310        }
311        Ok(())
312    }
313
314    pub fn forward_download_values(&mut self, ctx: &TestContext) -> Result<Vec<f64>> {
315        self.forward_launch(ctx)?;
316        ctx.sync_and_check()?;
317        ctx.dtoh_sync_copy(&self.d_values)
318            .map_err(|e| XlogError::Kernel(format!("Failed to download values: {}", e)))
319    }
320
321    pub fn forward_download_root(&mut self, ctx: &TestContext) -> Result<f64> {
322        self.forward_launch(ctx)?;
323        ctx.sync_and_check()?;
324        let root_idx: usize = self.root as usize;
325        if root_idx >= self.num_nodes {
326            return Err(XlogError::Kernel(format!(
327                "Root {} out of bounds for num_nodes {}",
328                self.root, self.num_nodes
329            )));
330        }
331        let root_view = self.d_values.slice(root_idx..(root_idx + 1));
332        let mut root_host = [0.0f64];
333        ctx.dtoh_sync_copy_into(&root_view, &mut root_host)
334            .map_err(|e| XlogError::Kernel(format!("Failed to download root value: {}", e)))?;
335        Ok(root_host[0])
336    }
337
338    /// Launch backward kernels using existing `d_values` (no sync, no host transfers).
339    pub fn backward_only_launch(&mut self, ctx: &TestContext) -> Result<()> {
340        let device = ctx.device.inner();
341        device
342            .memset_zeros(&mut self.d_adj)
343            .map_err(|e| XlogError::Kernel(format!("Failed to zero adj: {}", e)))?;
344        device
345            .memset_zeros(&mut self.d_grad_true)
346            .map_err(|e| XlogError::Kernel(format!("Failed to zero grad_true: {}", e)))?;
347        device
348            .memset_zeros(&mut self.d_grad_false)
349            .map_err(|e| XlogError::Kernel(format!("Failed to zero grad_false: {}", e)))?;
350
351        let root_idx: usize = self.root as usize;
352        if root_idx >= self.num_nodes {
353            return Err(XlogError::Kernel(format!(
354                "Root {} out of bounds for num_nodes {}",
355                self.root, self.num_nodes
356            )));
357        }
358        let mut root_view = self.d_adj.slice_mut(root_idx..(root_idx + 1));
359        ctx.htod_sync_copy_into(&[1.0f64], &mut root_view)
360            .map_err(|e| XlogError::Kernel(format!("Failed to set root adjoint: {}", e)))?;
361
362        for (level, &(_offset, len)) in self.levels.iter().enumerate().rev() {
363            let level_u32 = level as u32;
364            let mut params: Vec<*mut c_void> = vec![
365                (&self.d_node_type).as_kernel_param(),
366                (&self.d_child_offsets).as_kernel_param(),
367                (&self.d_child_indices).as_kernel_param(),
368                (&self.d_decision_var).as_kernel_param(),
369                (&self.d_decision_child_false).as_kernel_param(),
370                (&self.d_decision_child_true).as_kernel_param(),
371                (&self.d_level_nodes).as_kernel_param(),
372                (&self.d_level_offsets).as_kernel_param(),
373                level_u32.as_kernel_param(),
374                (&self.d_var_log_true).as_kernel_param(),
375                (&self.d_var_log_false).as_kernel_param(),
376                (&self.d_values).as_kernel_param(),
377                (&mut self.d_adj).as_kernel_param(),
378            ];
379            Self::launch_level_cached(&self.backward_propagate_fn, len, &mut params)?;
380        }
381
382        for (level, &(_offset, len)) in self.levels.iter().enumerate().rev() {
383            let level_u32 = level as u32;
384            let mut params: Vec<*mut c_void> = vec![
385                (&self.d_node_type).as_kernel_param(),
386                (&self.d_decision_var).as_kernel_param(),
387                (&self.d_decision_child_false).as_kernel_param(),
388                (&self.d_decision_child_true).as_kernel_param(),
389                (&self.d_level_nodes).as_kernel_param(),
390                (&self.d_level_offsets).as_kernel_param(),
391                level_u32.as_kernel_param(),
392                (&self.d_var_log_true).as_kernel_param(),
393                (&self.d_var_log_false).as_kernel_param(),
394                (&self.d_values).as_kernel_param(),
395                (&self.d_adj).as_kernel_param(),
396                (&mut self.d_grad_true).as_kernel_param(),
397                (&mut self.d_grad_false).as_kernel_param(),
398            ];
399            Self::launch_level_cached(&self.backward_decision_grad_fn, len, &mut params)?;
400        }
401
402        for (level, &(_offset, len)) in self.levels.iter().enumerate().rev() {
403            let level_u32 = level as u32;
404            let mut params: Vec<*mut c_void> = vec![
405                (&self.d_node_type).as_kernel_param(),
406                (&self.d_lit).as_kernel_param(),
407                (&self.d_level_nodes).as_kernel_param(),
408                (&self.d_level_offsets).as_kernel_param(),
409                level_u32.as_kernel_param(),
410                (&self.d_adj).as_kernel_param(),
411                (&mut self.d_grad_true).as_kernel_param(),
412                (&mut self.d_grad_false).as_kernel_param(),
413            ];
414            Self::launch_level_cached(&self.backward_lit_grad_fn, len, &mut params)?;
415        }
416
417        Ok(())
418    }
419
420    /// Convenience helper: forward + backward in one launch sequence (no sync, no host transfers).
421    pub fn forward_then_backward_launch(&mut self, ctx: &TestContext) -> Result<()> {
422        self.forward_launch(ctx)?;
423        self.backward_only_launch(ctx)
424    }
425}
426
/// Numerically stable `ln(exp(a) + exp(b))`.
///
/// The larger operand is factored out so the exponentials cannot overflow.
/// When that maximum is infinite the result is the maximum itself:
/// `-inf` when both inputs are `-inf` (log of zero total mass) and `+inf`
/// when either input is `+inf`. Without this shortcut the subtraction
/// `inf - inf` would turn the result into NaN.
///
/// The maximum is chosen with `f64::max`, which ignores a NaN operand.
fn logsumexp2(a: f64, b: f64) -> f64 {
    let m = a.max(b);
    if m.is_infinite() {
        // Covers both signs; subtracting an infinite `m` below would yield NaN.
        return m;
    }
    m + ((a - m).exp() + (b - m).exp()).ln()
}
434
435/// Tiny Decision-DNNF-shaped XGCF circuit that exercises CONST/LIT/AND/OR/DECISION nodes.
436pub fn tiny_xgcf_spec() -> TinyXgcfSpec {
437    const CONST0: u8 = 0;
438    const CONST1: u8 = 1;
439    const LIT: u8 = 2;
440    const AND: u8 = 3;
441    const OR: u8 = 4;
442    const DECISION: u8 = 5;
443
444    // Node indices:
445    // 0: CONST1
446    // 1: LIT(+1)
447    // 2: LIT(-2)
448    // 3: AND(1,2)
449    // 4: DECISION(var3, child_f=0, child_t=3)
450    // 5: CONST0
451    // 6: OR(4,5)   (root)
452    let num_nodes = 7;
453    let root = 6u32;
454
455    let node_type: Vec<u8> = vec![CONST1, LIT, LIT, AND, DECISION, CONST0, OR];
456    let lit: Vec<i32> = vec![0, 1, -2, 0, 0, 0, 0];
457    let decision_var: Vec<u32> = vec![0, 0, 0, 0, 3, 0, 0];
458    let decision_child_false: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0];
459    let decision_child_true: Vec<u32> = vec![0, 0, 0, 0, 3, 0, 0];
460
461    let child_offsets: Vec<u32> = vec![0, 0, 0, 0, 2, 2, 2, 4];
462    let child_indices: Vec<u32> = vec![1, 2, 4, 5];
463
464    // Levels: [0,1,2,5], [3], [4], [6]
465    let level_nodes: Vec<u32> = vec![0, 1, 2, 5, 3, 4, 6];
466    let levels: Vec<(u32, u32)> = vec![(0, 4), (4, 1), (5, 1), (6, 1)];
467
468    let num_vars = 3usize;
469    let var_log_true: Vec<f64> = vec![
470        0.0,
471        0.7f64.ln(), // var1
472        0.2f64.ln(), // var2
473        0.6f64.ln(), // var3
474    ];
475    let var_log_false: Vec<f64> = vec![
476        0.0,
477        0.3f64.ln(), // var1
478        0.8f64.ln(), // var2
479        0.4f64.ln(), // var3
480    ];
481
482    let v0 = 0.0;
483    let v1 = var_log_true[1];
484    let v2 = var_log_false[2];
485    let v3 = v1 + v2;
486    let v4 = logsumexp2(var_log_false[3] + v0, var_log_true[3] + v3);
487    let v5 = f64::NEG_INFINITY;
488    let v6 = logsumexp2(v4, v5);
489
490    let expected_values: Vec<f64> = vec![v0, v1, v2, v3, v4, v5, v6];
491
492    let p_false = (var_log_false[3] + v0 - v4).exp();
493    let p_true = (var_log_true[3] + v3 - v4).exp();
494
495    let mut expected_grad_true = vec![0.0f64; num_vars + 1];
496    let mut expected_grad_false = vec![0.0f64; num_vars + 1];
497    expected_grad_true[1] = p_true; // LIT(+1)
498    expected_grad_false[2] = p_true; // LIT(-2)
499    expected_grad_true[3] = p_true; // DECISION var3
500    expected_grad_false[3] = p_false;
501
502    TinyXgcfSpec {
503        num_nodes,
504        num_vars,
505        root,
506        node_type,
507        child_offsets,
508        child_indices,
509        lit,
510        decision_var,
511        decision_child_false,
512        decision_child_true,
513        level_nodes,
514        levels,
515        var_log_true,
516        var_log_false,
517        expected_values,
518        expected_grad_true,
519        expected_grad_false,
520    }
521}
522
523fn launch_level(
524    ctx: &TestContext,
525    kernel_name: &str,
526    num_level_nodes: u32,
527    params: &mut Vec<*mut c_void>,
528) -> Result<()> {
529    if num_level_nodes == 0 {
530        return Ok(());
531    }
532    let device = ctx.device.inner();
533    let kernel = device
534        .get_func(CIRCUIT_MODULE, kernel_name)
535        .ok_or_else(|| {
536            XlogError::Kernel(format!(
537                "Kernel {} not found in {}",
538                kernel_name, CIRCUIT_MODULE
539            ))
540        })?;
541
542    let block_size = 256u32;
543    let num_blocks = (num_level_nodes + block_size - 1) / block_size;
544    let config = LaunchConfig {
545        grid_dim: (num_blocks, 1, 1),
546        block_dim: (block_size, 1, 1),
547        shared_mem_bytes: 0,
548    };
549
550    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
551    unsafe { kernel.clone().launch(config, params) }
552        .map_err(|e| XlogError::Kernel(format!("Failed to launch {}: {}", kernel_name, e)))?;
553    Ok(())
554}
555
556pub fn run_tiny_xgcf_forward(ctx: &TestContext, spec: &TinyXgcfSpec) -> Result<Vec<f64>> {
557    let mut d_node_type = ctx.memory.alloc::<u8>(spec.node_type.len())?;
558    ctx.htod_sync_copy_into(&spec.node_type, &mut d_node_type)
559        .map_err(|e| XlogError::Kernel(format!("Failed to upload node_type: {}", e)))?;
560
561    let mut d_child_offsets = ctx.memory.alloc::<u32>(spec.child_offsets.len())?;
562    ctx.htod_sync_copy_into(&spec.child_offsets, &mut d_child_offsets)
563        .map_err(|e| XlogError::Kernel(format!("Failed to upload child_offsets: {}", e)))?;
564
565    let mut d_child_indices = ctx.memory.alloc::<u32>(spec.child_indices.len())?;
566    ctx.htod_sync_copy_into(&spec.child_indices, &mut d_child_indices)
567        .map_err(|e| XlogError::Kernel(format!("Failed to upload child_indices: {}", e)))?;
568
569    let mut d_lit = ctx.memory.alloc::<i32>(spec.lit.len())?;
570    ctx.htod_sync_copy_into(&spec.lit, &mut d_lit)
571        .map_err(|e| XlogError::Kernel(format!("Failed to upload lit: {}", e)))?;
572
573    let mut d_decision_var = ctx.memory.alloc::<u32>(spec.decision_var.len())?;
574    ctx.htod_sync_copy_into(&spec.decision_var, &mut d_decision_var)
575        .map_err(|e| XlogError::Kernel(format!("Failed to upload decision_var: {}", e)))?;
576
577    let mut d_decision_child_false = ctx.memory.alloc::<u32>(spec.decision_child_false.len())?;
578    ctx.htod_sync_copy_into(&spec.decision_child_false, &mut d_decision_child_false)
579        .map_err(|e| XlogError::Kernel(format!("Failed to upload decision_child_false: {}", e)))?;
580
581    let mut d_decision_child_true = ctx.memory.alloc::<u32>(spec.decision_child_true.len())?;
582    ctx.htod_sync_copy_into(&spec.decision_child_true, &mut d_decision_child_true)
583        .map_err(|e| XlogError::Kernel(format!("Failed to upload decision_child_true: {}", e)))?;
584
585    let mut d_level_nodes = ctx.memory.alloc::<u32>(spec.level_nodes.len())?;
586    ctx.htod_sync_copy_into(&spec.level_nodes, &mut d_level_nodes)
587        .map_err(|e| XlogError::Kernel(format!("Failed to upload level_nodes: {}", e)))?;
588
589    let mut level_offsets: Vec<u32> = Vec::with_capacity(spec.levels.len() + 1);
590    for &(offset, _len) in &spec.levels {
591        level_offsets.push(offset);
592    }
593    let (last_offset, last_len) = *spec
594        .levels
595        .last()
596        .ok_or_else(|| XlogError::Kernel("TinyXgcfSpec requires non-empty levels".to_string()))?;
597    level_offsets.push(last_offset + last_len);
598    if level_offsets[0] != 0 {
599        return Err(XlogError::Kernel(
600            "TinyXgcfSpec level_offsets must start at 0".to_string(),
601        ));
602    }
603    let mut d_level_offsets = ctx.memory.alloc::<u32>(level_offsets.len())?;
604    ctx.htod_sync_copy_into(&level_offsets, &mut d_level_offsets)
605        .map_err(|e| XlogError::Kernel(format!("Failed to upload level_offsets: {}", e)))?;
606
607    let mut level_offsets: Vec<u32> = Vec::with_capacity(spec.levels.len() + 1);
608    for &(offset, _len) in &spec.levels {
609        level_offsets.push(offset);
610    }
611    let (last_offset, last_len) = *spec
612        .levels
613        .last()
614        .ok_or_else(|| XlogError::Kernel("TinyXgcfSpec requires non-empty levels".to_string()))?;
615    level_offsets.push(last_offset + last_len);
616    if level_offsets[0] != 0 {
617        return Err(XlogError::Kernel(
618            "TinyXgcfSpec level_offsets must start at 0".to_string(),
619        ));
620    }
621    let mut d_level_offsets = ctx.memory.alloc::<u32>(level_offsets.len())?;
622    ctx.htod_sync_copy_into(&level_offsets, &mut d_level_offsets)
623        .map_err(|e| XlogError::Kernel(format!("Failed to upload level_offsets: {}", e)))?;
624
625    let mut d_var_log_true = ctx.memory.alloc::<f64>(spec.var_log_true.len())?;
626    ctx.htod_sync_copy_into(&spec.var_log_true, &mut d_var_log_true)
627        .map_err(|e| XlogError::Kernel(format!("Failed to upload var_log_true: {}", e)))?;
628
629    let mut d_var_log_false = ctx.memory.alloc::<f64>(spec.var_log_false.len())?;
630    ctx.htod_sync_copy_into(&spec.var_log_false, &mut d_var_log_false)
631        .map_err(|e| XlogError::Kernel(format!("Failed to upload var_log_false: {}", e)))?;
632
633    let mut d_values = ctx.memory.alloc::<f64>(spec.num_nodes)?;
634    let init_values = vec![0.0f64; spec.num_nodes];
635    ctx.htod_sync_copy_into(&init_values, &mut d_values)
636        .map_err(|e| XlogError::Kernel(format!("Failed to init values: {}", e)))?;
637
638    for (level, &(_offset, len)) in spec.levels.iter().enumerate() {
639        let level_u32 = level as u32;
640        let mut params: Vec<*mut c_void> = vec![
641            (&d_node_type).as_kernel_param(),
642            (&d_child_offsets).as_kernel_param(),
643            (&d_child_indices).as_kernel_param(),
644            (&d_lit).as_kernel_param(),
645            (&d_decision_var).as_kernel_param(),
646            (&d_decision_child_false).as_kernel_param(),
647            (&d_decision_child_true).as_kernel_param(),
648            (&d_level_nodes).as_kernel_param(),
649            (&d_level_offsets).as_kernel_param(),
650            level_u32.as_kernel_param(),
651            (&d_var_log_true).as_kernel_param(),
652            (&d_var_log_false).as_kernel_param(),
653            (&mut d_values).as_kernel_param(),
654        ];
655        launch_level(ctx, circuit_kernels::XGCF_FORWARD_LEVEL, len, &mut params)?;
656    }
657
658    ctx.sync_and_check()?;
659
660    ctx.dtoh_sync_copy(&d_values)
661        .map_err(|e| XlogError::Kernel(format!("Failed to download values: {}", e)))
662}
663
664pub fn run_tiny_xgcf_backward(ctx: &TestContext, spec: &TinyXgcfSpec) -> Result<TinyXgcfRun> {
665    let mut d_node_type = ctx.memory.alloc::<u8>(spec.node_type.len())?;
666    ctx.htod_sync_copy_into(&spec.node_type, &mut d_node_type)
667        .map_err(|e| XlogError::Kernel(format!("Failed to upload node_type: {}", e)))?;
668
669    let mut d_child_offsets = ctx.memory.alloc::<u32>(spec.child_offsets.len())?;
670    ctx.htod_sync_copy_into(&spec.child_offsets, &mut d_child_offsets)
671        .map_err(|e| XlogError::Kernel(format!("Failed to upload child_offsets: {}", e)))?;
672
673    let mut d_child_indices = ctx.memory.alloc::<u32>(spec.child_indices.len())?;
674    ctx.htod_sync_copy_into(&spec.child_indices, &mut d_child_indices)
675        .map_err(|e| XlogError::Kernel(format!("Failed to upload child_indices: {}", e)))?;
676
677    let mut d_lit = ctx.memory.alloc::<i32>(spec.lit.len())?;
678    ctx.htod_sync_copy_into(&spec.lit, &mut d_lit)
679        .map_err(|e| XlogError::Kernel(format!("Failed to upload lit: {}", e)))?;
680
681    let mut d_decision_var = ctx.memory.alloc::<u32>(spec.decision_var.len())?;
682    ctx.htod_sync_copy_into(&spec.decision_var, &mut d_decision_var)
683        .map_err(|e| XlogError::Kernel(format!("Failed to upload decision_var: {}", e)))?;
684
685    let mut d_decision_child_false = ctx.memory.alloc::<u32>(spec.decision_child_false.len())?;
686    ctx.htod_sync_copy_into(&spec.decision_child_false, &mut d_decision_child_false)
687        .map_err(|e| XlogError::Kernel(format!("Failed to upload decision_child_false: {}", e)))?;
688
689    let mut d_decision_child_true = ctx.memory.alloc::<u32>(spec.decision_child_true.len())?;
690    ctx.htod_sync_copy_into(&spec.decision_child_true, &mut d_decision_child_true)
691        .map_err(|e| XlogError::Kernel(format!("Failed to upload decision_child_true: {}", e)))?;
692
693    let mut d_level_nodes = ctx.memory.alloc::<u32>(spec.level_nodes.len())?;
694    ctx.htod_sync_copy_into(&spec.level_nodes, &mut d_level_nodes)
695        .map_err(|e| XlogError::Kernel(format!("Failed to upload level_nodes: {}", e)))?;
696
697    let mut level_offsets: Vec<u32> = Vec::with_capacity(spec.levels.len() + 1);
698    for &(offset, _len) in &spec.levels {
699        level_offsets.push(offset);
700    }
701    let (last_offset, last_len) = *spec
702        .levels
703        .last()
704        .ok_or_else(|| XlogError::Kernel("TinyXgcfSpec requires non-empty levels".to_string()))?;
705    level_offsets.push(last_offset + last_len);
706    if level_offsets[0] != 0 {
707        return Err(XlogError::Kernel(
708            "TinyXgcfSpec level_offsets must start at 0".to_string(),
709        ));
710    }
711    let mut d_level_offsets = ctx.memory.alloc::<u32>(level_offsets.len())?;
712    ctx.htod_sync_copy_into(&level_offsets, &mut d_level_offsets)
713        .map_err(|e| XlogError::Kernel(format!("Failed to upload level_offsets: {}", e)))?;
714
715    let mut d_var_log_true = ctx.memory.alloc::<f64>(spec.var_log_true.len())?;
716    ctx.htod_sync_copy_into(&spec.var_log_true, &mut d_var_log_true)
717        .map_err(|e| XlogError::Kernel(format!("Failed to upload var_log_true: {}", e)))?;
718
719    let mut d_var_log_false = ctx.memory.alloc::<f64>(spec.var_log_false.len())?;
720    ctx.htod_sync_copy_into(&spec.var_log_false, &mut d_var_log_false)
721        .map_err(|e| XlogError::Kernel(format!("Failed to upload var_log_false: {}", e)))?;
722
723    let mut d_values = ctx.memory.alloc::<f64>(spec.num_nodes)?;
724    let init_values = vec![0.0f64; spec.num_nodes];
725    ctx.htod_sync_copy_into(&init_values, &mut d_values)
726        .map_err(|e| XlogError::Kernel(format!("Failed to init values: {}", e)))?;
727
728    for (level, &(_offset, len)) in spec.levels.iter().enumerate() {
729        let level_u32 = level as u32;
730        let mut params: Vec<*mut c_void> = vec![
731            (&d_node_type).as_kernel_param(),
732            (&d_child_offsets).as_kernel_param(),
733            (&d_child_indices).as_kernel_param(),
734            (&d_lit).as_kernel_param(),
735            (&d_decision_var).as_kernel_param(),
736            (&d_decision_child_false).as_kernel_param(),
737            (&d_decision_child_true).as_kernel_param(),
738            (&d_level_nodes).as_kernel_param(),
739            (&d_level_offsets).as_kernel_param(),
740            level_u32.as_kernel_param(),
741            (&d_var_log_true).as_kernel_param(),
742            (&d_var_log_false).as_kernel_param(),
743            (&mut d_values).as_kernel_param(),
744        ];
745        launch_level(ctx, circuit_kernels::XGCF_FORWARD_LEVEL, len, &mut params)?;
746    }
747
748    // adj[root] = 1, others 0
749    let mut adj_init = vec![0.0f64; spec.num_nodes];
750    let root_idx: usize = spec.root as usize;
751    if root_idx >= adj_init.len() {
752        return Err(XlogError::Kernel(format!(
753            "Root {} out of bounds for num_nodes {}",
754            spec.root, spec.num_nodes
755        )));
756    }
757    adj_init[root_idx] = 1.0;
758    let mut d_adj = ctx.memory.alloc::<f64>(spec.num_nodes)?;
759    ctx.htod_sync_copy_into(&adj_init, &mut d_adj)
760        .map_err(|e| XlogError::Kernel(format!("Failed to init adj: {}", e)))?;
761
762    let mut d_grad_true = ctx.memory.alloc::<f64>(spec.num_vars + 1)?;
763    let mut d_grad_false = ctx.memory.alloc::<f64>(spec.num_vars + 1)?;
764    let grad_init = vec![0.0f64; spec.num_vars + 1];
765    ctx.htod_sync_copy_into(&grad_init, &mut d_grad_true)
766        .map_err(|e| XlogError::Kernel(format!("Failed to init grad_true: {}", e)))?;
767    ctx.htod_sync_copy_into(&grad_init, &mut d_grad_false)
768        .map_err(|e| XlogError::Kernel(format!("Failed to init grad_false: {}", e)))?;
769
770    for (level, &(_offset, len)) in spec.levels.iter().enumerate().rev() {
771        let level_u32 = level as u32;
772        let mut params: Vec<*mut c_void> = vec![
773            (&d_node_type).as_kernel_param(),
774            (&d_child_offsets).as_kernel_param(),
775            (&d_child_indices).as_kernel_param(),
776            (&d_decision_var).as_kernel_param(),
777            (&d_decision_child_false).as_kernel_param(),
778            (&d_decision_child_true).as_kernel_param(),
779            (&d_level_nodes).as_kernel_param(),
780            (&d_level_offsets).as_kernel_param(),
781            level_u32.as_kernel_param(),
782            (&d_var_log_true).as_kernel_param(),
783            (&d_var_log_false).as_kernel_param(),
784            (&d_values).as_kernel_param(),
785            (&mut d_adj).as_kernel_param(),
786        ];
787        launch_level(
788            ctx,
789            circuit_kernels::XGCF_BACKWARD_LEVEL_PROPAGATE,
790            len,
791            &mut params,
792        )?;
793    }
794
795    for (level, &(_offset, len)) in spec.levels.iter().enumerate().rev() {
796        let level_u32 = level as u32;
797        let mut params: Vec<*mut c_void> = vec![
798            (&d_node_type).as_kernel_param(),
799            (&d_decision_var).as_kernel_param(),
800            (&d_decision_child_false).as_kernel_param(),
801            (&d_decision_child_true).as_kernel_param(),
802            (&d_level_nodes).as_kernel_param(),
803            (&d_level_offsets).as_kernel_param(),
804            level_u32.as_kernel_param(),
805            (&d_var_log_true).as_kernel_param(),
806            (&d_var_log_false).as_kernel_param(),
807            (&d_values).as_kernel_param(),
808            (&d_adj).as_kernel_param(),
809            (&mut d_grad_true).as_kernel_param(),
810            (&mut d_grad_false).as_kernel_param(),
811        ];
812        launch_level(
813            ctx,
814            circuit_kernels::XGCF_BACKWARD_LEVEL_DECISION_GRAD,
815            len,
816            &mut params,
817        )?;
818    }
819
820    for (level, &(_offset, len)) in spec.levels.iter().enumerate().rev() {
821        let level_u32 = level as u32;
822        let mut params: Vec<*mut c_void> = vec![
823            (&d_node_type).as_kernel_param(),
824            (&d_lit).as_kernel_param(),
825            (&d_level_nodes).as_kernel_param(),
826            (&d_level_offsets).as_kernel_param(),
827            level_u32.as_kernel_param(),
828            (&d_adj).as_kernel_param(),
829            (&mut d_grad_true).as_kernel_param(),
830            (&mut d_grad_false).as_kernel_param(),
831        ];
832        launch_level(
833            ctx,
834            circuit_kernels::XGCF_BACKWARD_LEVEL_LIT_GRAD,
835            len,
836            &mut params,
837        )?;
838    }
839
840    ctx.sync_and_check()?;
841
842    let values = ctx
843        .dtoh_sync_copy(&d_values)
844        .map_err(|e| XlogError::Kernel(format!("Failed to download values: {}", e)))?;
845    let adj = ctx
846        .dtoh_sync_copy(&d_adj)
847        .map_err(|e| XlogError::Kernel(format!("Failed to download adj: {}", e)))?;
848    let grad_true = ctx
849        .dtoh_sync_copy(&d_grad_true)
850        .map_err(|e| XlogError::Kernel(format!("Failed to download grad_true: {}", e)))?;
851    let grad_false = ctx
852        .dtoh_sync_copy(&d_grad_false)
853        .map_err(|e| XlogError::Kernel(format!("Failed to download grad_false: {}", e)))?;
854
855    Ok(TinyXgcfRun {
856        values,
857        adj,
858        grad_true,
859        grad_false,
860    })
861}
862
863/// Generate a single-literal circuit: root = Lit(+var)
864pub fn gen_single_lit_circuit(var: u32) -> TinyXgcfSpec {
865    const LIT: u8 = 2;
866
867    let num_nodes = 1;
868    let num_vars = var as usize;
869    let root = 0;
870
871    let node_type = vec![LIT];
872    let child_offsets = vec![0, 0];
873    let child_indices = vec![];
874    let lit = vec![var as i32];
875    let decision_var = vec![0];
876    let decision_child_false = vec![0];
877    let decision_child_true = vec![0];
878    let level_nodes = vec![0];
879    let levels = vec![(0, 1)];
880
881    let mut var_log_true = vec![0.0; num_vars + 1];
882    let mut var_log_false = vec![0.0; num_vars + 1];
883    var_log_true[var as usize] = 0.7_f64.ln();
884    var_log_false[var as usize] = 0.3_f64.ln();
885
886    let expected_values = vec![var_log_true[var as usize]];
887    let mut expected_grad_true = vec![0.0; num_vars + 1];
888    let expected_grad_false = vec![0.0; num_vars + 1];
889    expected_grad_true[var as usize] = 1.0;
890
891    TinyXgcfSpec {
892        num_nodes,
893        num_vars,
894        root,
895        node_type,
896        child_offsets,
897        child_indices,
898        lit,
899        decision_var,
900        decision_child_false,
901        decision_child_true,
902        level_nodes,
903        levels,
904        var_log_true,
905        var_log_false,
906        expected_values,
907        expected_grad_true,
908        expected_grad_false,
909    }
910}
911
912/// Generate an AND circuit: root = AND(Lit(+1), Lit(+2))
913pub fn gen_and_circuit() -> TinyXgcfSpec {
914    const LIT: u8 = 2;
915    const AND: u8 = 3;
916
917    let num_nodes = 3;
918    let num_vars = 2;
919    let root = 2;
920
921    let node_type = vec![LIT, LIT, AND];
922    let child_offsets = vec![0, 0, 0, 2];
923    let child_indices = vec![0, 1];
924    let lit = vec![1, 2, 0];
925    let decision_var = vec![0, 0, 0];
926    let decision_child_false = vec![0, 0, 0];
927    let decision_child_true = vec![0, 0, 0];
928    let level_nodes = vec![0, 1, 2];
929    let levels = vec![(0, 2), (2, 1)];
930
931    let p1 = 0.7_f64;
932    let p2 = 0.6_f64;
933    let var_log_true = vec![0.0, p1.ln(), p2.ln()];
934    let var_log_false = vec![0.0, (1.0 - p1).ln(), (1.0 - p2).ln()];
935
936    let v0 = var_log_true[1];
937    let v1 = var_log_true[2];
938    let v2 = v0 + v1;
939    let expected_values = vec![v0, v1, v2];
940    let expected_grad_true = vec![0.0, 1.0, 1.0];
941    let expected_grad_false = vec![0.0, 0.0, 0.0];
942
943    TinyXgcfSpec {
944        num_nodes,
945        num_vars,
946        root,
947        node_type,
948        child_offsets,
949        child_indices,
950        lit,
951        decision_var,
952        decision_child_false,
953        decision_child_true,
954        level_nodes,
955        levels,
956        var_log_true,
957        var_log_false,
958        expected_values,
959        expected_grad_true,
960        expected_grad_false,
961    }
962}
963
964/// Generate an OR circuit: root = OR(Lit(+1), Lit(+2))
965pub fn gen_or_circuit() -> TinyXgcfSpec {
966    const LIT: u8 = 2;
967    const OR: u8 = 4;
968
969    let num_nodes = 3;
970    let num_vars = 2;
971    let root = 2;
972
973    let node_type = vec![LIT, LIT, OR];
974    let child_offsets = vec![0, 0, 0, 2];
975    let child_indices = vec![0, 1];
976    let lit = vec![1, 2, 0];
977    let decision_var = vec![0, 0, 0];
978    let decision_child_false = vec![0, 0, 0];
979    let decision_child_true = vec![0, 0, 0];
980    let level_nodes = vec![0, 1, 2];
981    let levels = vec![(0, 2), (2, 1)];
982
983    let p1 = 0.7_f64;
984    let p2 = 0.6_f64;
985    let var_log_true = vec![0.0, p1.ln(), p2.ln()];
986    let var_log_false = vec![0.0, (1.0 - p1).ln(), (1.0 - p2).ln()];
987
988    let v0 = var_log_true[1];
989    let v1 = var_log_true[2];
990    let v2 = logsumexp2(v0, v1);
991    let expected_values = vec![v0, v1, v2];
992
993    let p_child0 = (v0 - v2).exp();
994    let p_child1 = (v1 - v2).exp();
995    let expected_grad_true = vec![0.0, p_child0, p_child1];
996    let expected_grad_false = vec![0.0, 0.0, 0.0];
997
998    TinyXgcfSpec {
999        num_nodes,
1000        num_vars,
1001        root,
1002        node_type,
1003        child_offsets,
1004        child_indices,
1005        lit,
1006        decision_var,
1007        decision_child_false,
1008        decision_child_true,
1009        level_nodes,
1010        levels,
1011        var_log_true,
1012        var_log_false,
1013        expected_values,
1014        expected_grad_true,
1015        expected_grad_false,
1016    }
1017}
1018
1019/// Generate a Decision circuit: root = Decision(var, false_child=Const1, true_child=Lit(+1))
1020pub fn gen_decision_circuit() -> TinyXgcfSpec {
1021    const CONST1: u8 = 1;
1022    const LIT: u8 = 2;
1023    const DECISION: u8 = 5;
1024
1025    let num_nodes = 3;
1026    let num_vars = 2;
1027    let root = 2;
1028
1029    let node_type = vec![CONST1, LIT, DECISION];
1030    let child_offsets = vec![0, 0, 0, 0];
1031    let child_indices = vec![];
1032    let lit = vec![0, 1, 0];
1033    let decision_var = vec![0, 0, 2];
1034    let decision_child_false = vec![0, 0, 0];
1035    let decision_child_true = vec![0, 0, 1];
1036    let level_nodes = vec![0, 1, 2];
1037    let levels = vec![(0, 2), (2, 1)];
1038
1039    let p1 = 0.7_f64;
1040    let p2 = 0.6_f64;
1041    let var_log_true = vec![0.0, p1.ln(), p2.ln()];
1042    let var_log_false = vec![0.0, (1.0 - p1).ln(), (1.0 - p2).ln()];
1043
1044    let v0 = 0.0;
1045    let v1 = var_log_true[1];
1046    let v2 = logsumexp2(var_log_false[2] + v0, var_log_true[2] + v1);
1047    let expected_values = vec![v0, v1, v2];
1048
1049    let p_false = (var_log_false[2] + v0 - v2).exp();
1050    let p_true = (var_log_true[2] + v1 - v2).exp();
1051
1052    let expected_grad_true = vec![0.0, p_true, p_true];
1053    let expected_grad_false = vec![0.0, 0.0, p_false];
1054
1055    TinyXgcfSpec {
1056        num_nodes,
1057        num_vars,
1058        root,
1059        node_type,
1060        child_offsets,
1061        child_indices,
1062        lit,
1063        decision_var,
1064        decision_child_false,
1065        decision_child_true,
1066        level_nodes,
1067        levels,
1068        var_log_true,
1069        var_log_false,
1070        expected_values,
1071        expected_grad_true,
1072        expected_grad_false,
1073    }
1074}
1075
1076/// Generate a large circuit with N parallel literals under an OR node
1077pub fn gen_large_or_circuit(num_vars: usize) -> TinyXgcfSpec {
1078    const LIT: u8 = 2;
1079    const OR: u8 = 4;
1080
1081    let num_nodes = num_vars + 1;
1082    let root = num_vars as u32;
1083
1084    let mut node_type = vec![LIT; num_vars];
1085    node_type.push(OR);
1086
1087    let mut child_offsets: Vec<u32> = (0..=num_vars).map(|_| 0).collect();
1088    child_offsets.push(num_vars as u32);
1089
1090    let child_indices: Vec<u32> = (0..num_vars as u32).collect();
1091
1092    let mut lit: Vec<i32> = (1..=num_vars as i32).collect();
1093    lit.push(0);
1094
1095    let decision_var = vec![0; num_nodes];
1096    let decision_child_false = vec![0; num_nodes];
1097    let decision_child_true = vec![0; num_nodes];
1098
1099    let mut level_nodes: Vec<u32> = (0..num_vars as u32).collect();
1100    level_nodes.push(root);
1101    let levels = vec![(0, num_vars as u32), (num_vars as u32, 1)];
1102
1103    let p = 0.5_f64;
1104    let mut var_log_true = vec![0.0; num_vars + 1];
1105    let mut var_log_false = vec![0.0; num_vars + 1];
1106    for i in 1..=num_vars {
1107        var_log_true[i] = p.ln();
1108        var_log_false[i] = (1.0 - p).ln();
1109    }
1110
1111    let lit_val = p.ln();
1112    let mut expected_values = vec![lit_val; num_vars];
1113    let or_val = lit_val + (num_vars as f64).ln();
1114    expected_values.push(or_val);
1115
1116    let grad_per_lit = 1.0 / num_vars as f64;
1117    let mut expected_grad_true = vec![0.0; num_vars + 1];
1118    for i in 1..=num_vars {
1119        expected_grad_true[i] = grad_per_lit;
1120    }
1121    let expected_grad_false = vec![0.0; num_vars + 1];
1122
1123    TinyXgcfSpec {
1124        num_nodes,
1125        num_vars,
1126        root,
1127        node_type,
1128        child_offsets,
1129        child_indices,
1130        lit,
1131        decision_var,
1132        decision_child_false,
1133        decision_child_true,
1134        level_nodes,
1135        levels,
1136        var_log_true,
1137        var_log_false,
1138        expected_values,
1139        expected_grad_true,
1140        expected_grad_false,
1141    }
1142}
1143
1144/// Generate a deep chain circuit: AND(AND(AND(...Lit(1)...)))
1145pub fn gen_deep_chain_circuit(depth: usize) -> TinyXgcfSpec {
1146    const LIT: u8 = 2;
1147    const AND: u8 = 3;
1148
1149    let num_nodes = depth + 1;
1150    let num_vars = 1;
1151    let root = depth as u32;
1152
1153    let mut node_type = vec![LIT];
1154    for _ in 0..depth {
1155        node_type.push(AND);
1156    }
1157
1158    let mut child_offsets: Vec<u32> = vec![0];
1159    let mut child_indices: Vec<u32> = vec![];
1160    for i in 0..depth {
1161        child_offsets.push(child_indices.len() as u32);
1162        child_indices.push(i as u32);
1163    }
1164    child_offsets.push(child_indices.len() as u32);
1165
1166    let mut lit = vec![1i32];
1167    lit.extend(vec![0i32; depth]);
1168
1169    let decision_var = vec![0; num_nodes];
1170    let decision_child_false = vec![0; num_nodes];
1171    let decision_child_true = vec![0; num_nodes];
1172
1173    let level_nodes: Vec<u32> = (0..num_nodes as u32).collect();
1174    let levels: Vec<(u32, u32)> = (0..num_nodes).map(|i| (i as u32, 1)).collect();
1175
1176    let p = 0.7_f64;
1177    let var_log_true = vec![0.0, p.ln()];
1178    let var_log_false = vec![0.0, (1.0 - p).ln()];
1179
1180    let lit_val = p.ln();
1181    let expected_values = vec![lit_val; num_nodes];
1182
1183    let expected_grad_true = vec![0.0, 1.0];
1184    let expected_grad_false = vec![0.0, 0.0];
1185
1186    TinyXgcfSpec {
1187        num_nodes,
1188        num_vars,
1189        root,
1190        node_type,
1191        child_offsets,
1192        child_indices,
1193        lit,
1194        decision_var,
1195        decision_child_false,
1196        decision_child_true,
1197        level_nodes,
1198        levels,
1199        var_log_true,
1200        var_log_false,
1201        expected_values,
1202        expected_grad_true,
1203        expected_grad_false,
1204    }
1205}
1206
1207/// Compute numerical gradient for verification
1208pub fn numerical_gradient(
1209    ctx: &TestContext,
1210    spec: &TinyXgcfSpec,
1211    var: usize,
1212    eps: f64,
1213) -> xlog_core::Result<(f64, f64)> {
1214    let mut spec_plus = spec.clone();
1215    let mut spec_minus = spec.clone();
1216    spec_plus.var_log_true[var] += eps;
1217    spec_minus.var_log_true[var] -= eps;
1218
1219    let values_plus = run_tiny_xgcf_forward(ctx, &spec_plus)?;
1220    let values_minus = run_tiny_xgcf_forward(ctx, &spec_minus)?;
1221
1222    let grad_true =
1223        (values_plus[spec.root as usize] - values_minus[spec.root as usize]) / (2.0 * eps);
1224
1225    let mut spec_plus = spec.clone();
1226    let mut spec_minus = spec.clone();
1227    spec_plus.var_log_false[var] += eps;
1228    spec_minus.var_log_false[var] -= eps;
1229
1230    let values_plus = run_tiny_xgcf_forward(ctx, &spec_plus)?;
1231    let values_minus = run_tiny_xgcf_forward(ctx, &spec_minus)?;
1232
1233    let grad_false =
1234        (values_plus[spec.root as usize] - values_minus[spec.root as usize]) / (2.0 * eps);
1235
1236    Ok((grad_true, grad_false))
1237}