Skip to main content

xlog_prob/compilation/
validation.rs

1//! GPU-native equivalence validation (φ ≡ C) using the GPU CDCL verifier.
2
3use std::sync::Arc;
4
5use std::ffi::c_void;
6
7use cudarc::driver::LaunchConfig;
8use xlog_core::{Result, XlogError};
9use xlog_cuda::memory::TrackedCudaSlice;
10use xlog_cuda::provider::sat_kernels;
11use xlog_cuda::provider::SAT_MODULE;
12use xlog_cuda::{AsKernelParam, CudaKernelProvider, LaunchAsync};
13use xlog_solve::{GpuCdclConfig, GpuCdclSolver, GpuCnf};
14
15#[cfg(debug_assertions)]
16use crate::compilation::gpu_d4::validate_cnf_gpu;
17
18use crate::gpu::GpuXgcf;
19
20const MAX_GRID_X: u64 = 65_535;
21
22fn checked_launch_grid(elements: u32, block: u32, context: &str) -> Result<u32> {
23    if block == 0 {
24        return Err(XlogError::Kernel(format!(
25            "{context}: CUDA launch block size must be nonzero"
26        )));
27    }
28    let grid = if elements == 0 {
29        1
30    } else {
31        u64::from(elements).div_ceil(u64::from(block))
32    };
33    if grid > MAX_GRID_X {
34        return Err(XlogError::Kernel(format!(
35            "{context}: launch grid {grid} exceeds x-dimension limit {MAX_GRID_X} \
36             for {elements} elements with block size {block}"
37        )));
38    }
39    Ok(grid as u32)
40}
41
42fn checked_clause_offset_span(clause_cap: u32, context: &str) -> Result<u32> {
43    clause_cap
44        .checked_add(1)
45        .ok_or_else(|| XlogError::Kernel(format!("{context}: clause offset span overflow")))
46}
47
48/// Configuration for GPU-native equivalence verification (phi equiv C).
49///
50/// Controls the CDCL solver parameters and whether to reuse the solver
51/// workspace across multiple equivalence checks. Workspace reuse amortizes
52/// device-memory allocation when verifying many circuits in sequence (e.g.,
53/// during incremental compilation).
54#[derive(Debug, Clone, Copy, Default)]
55#[non_exhaustive]
56pub struct GpuEquivalenceConfig {
57    /// CDCL solver configuration for the equivalence verifier.
58    pub cdcl: GpuCdclConfig,
59    /// Reuse the CDCL workspace across successive verifier invocations.
60    pub reuse_workspace: bool,
61}
62
63/// GPU-resident equivalence queries + device metadata required to solve them without host reads.
64pub struct GpuEquivalenceQueries {
65    pub q1: GpuCnf,
66    pub q2: GpuCnf,
67    /// Base variable id for the ¬phi selector vars in q2 (len=1, device-resident).
68    pub q2_unsat_var_base: TrackedCudaSlice<u32>,
69}
70
71struct CircuitCnf {
72    cnf: GpuCnf,
73    /// Exclusive prefix sum over `is_internal(node)` (len = num_nodes).
74    /// Used to map internal node ids -> Tseitin vars in kernels.
75    internal_prefix: TrackedCudaSlice<u32>,
76}
77
/// Encode `circuit` as a device-resident CNF using GPU kernels only (no device->host reads).
///
/// Launch sequence. Every step is issued on the provider's stream and relies on same-stream
/// ordering rather than explicit synchronization, so the statement order below matters:
/// 1. `sat_xgcf_cnf_counts` fills per-node internal/clause/literal counts.
/// 2. `sat_xgcf_cnf_capture_last_counts` saves the last element of each count array, which
///    the exclusive scans would otherwise discard.
/// 3. In-place exclusive scans turn the counts into per-node bases
///    (`internal_prefix`, `clause_base`, `lit_base`).
/// 4. `sat_xgcf_cnf_compute_totals` writes the device-resident `num_vars` / `num_clauses` /
///    `num_lits` of the result.
/// 5. `sat_xgcf_cnf_emit` writes clause offsets and literals.
/// 6. `sat_cnf_write_terminator` writes the final CSR offset, which the emit kernel does not.
///
/// Parameters:
/// - `base_num_vars`: device-resident variable count passed to the totals and emit kernels.
///   `base_var_cap` is its host-side upper bound and must cover `circuit.max_var()`.
/// - `compile_needed`: device-resident flag forwarded to the counts and emit kernels.
///   It is presumably a gate that lets them skip work when 0 — confirm against the kernels.
///
/// Returns the CNF (with host-side upper-bound capacities) and the internal-node prefix sum.
///
/// # Errors
/// - `XlogError::Compilation` in any of these cases: `base_var_cap == 0`; the circuit
///   references a variable above `base_var_cap`; the circuit has no nodes; its root is out of
///   range; or its node count does not fit in `u32`.
/// - `XlogError::Kernel` in any of these cases: a capacity overflows `u32`; the launch grid
///   is too large; a kernel is missing from `SAT_MODULE`; or a launch fails.
/// - Any error from device allocation or the provider's exclusive scan.
fn build_circuit_cnf(
    provider: &Arc<CudaKernelProvider>,
    circuit: &GpuXgcf,
    base_num_vars: &TrackedCudaSlice<u32>,
    base_var_cap: u32,
    compile_needed: &TrackedCudaSlice<u32>,
) -> Result<CircuitCnf> {
    if base_var_cap == 0 {
        return Err(XlogError::Compilation(
            "GPU equivalence verifier requires base_var_cap > 0".to_string(),
        ));
    }
    if circuit.max_var() > base_var_cap {
        return Err(XlogError::Compilation(format!(
            "Circuit references var {} but base CNF has only {} vars",
            circuit.max_var(),
            base_var_cap
        )));
    }

    let num_nodes = circuit.num_nodes();
    if num_nodes == 0 {
        return Err(XlogError::Compilation(
            "GPU equivalence verifier requires circuit with num_nodes > 0".to_string(),
        ));
    }
    if circuit.root() as usize >= num_nodes {
        return Err(XlogError::Compilation(format!(
            "GPU equivalence verifier: circuit root {} out of bounds (num_nodes={})",
            circuit.root(),
            num_nodes
        )));
    }

    let num_nodes_u32 = u32::try_from(num_nodes).map_err(|_| {
        XlogError::Compilation(format!(
            "GPU equivalence verifier: circuit num_nodes {} exceeds u32::MAX",
            num_nodes
        ))
    })?;

    // Safe, host-known upper bounds (no device->host reads required).
    // With n = num_nodes and e = num_edges, the bounds are:
    //   vars    <= base_var_cap + n
    //   clauses <= e + 4n
    //   lits    <= 3e + 12n
    // NOTE(review): these presumably mirror the per-node maxima produced by
    // sat_xgcf_cnf_counts — keep them in sync if the encoding changes.
    // The saturating math only saturates for absurd sizes, and the u32::try_from below then
    // reports "exceeds u32::MAX".
    let num_edges = circuit.num_edges();
    let n64 = num_nodes as u64;
    let e64 = num_edges as u64;

    let var_cap = u32::try_from((base_var_cap as u64).saturating_add(n64))
        .map_err(|_| XlogError::Kernel("Circuit CNF var capacity exceeds u32::MAX".to_string()))?;
    let clause_cap =
        u32::try_from(e64.checked_add(4u64.saturating_mul(n64)).ok_or_else(|| {
            XlogError::Kernel("Circuit CNF clause capacity overflow".to_string())
        })?)
        .map_err(|_| {
            XlogError::Kernel("Circuit CNF clause capacity exceeds u32::MAX".to_string())
        })?;
    let lit_cap = u32::try_from(
        (3u64.saturating_mul(e64))
            .checked_add(12u64.saturating_mul(n64))
            .ok_or_else(|| {
                XlogError::Kernel("Circuit CNF literal capacity overflow".to_string())
            })?,
    )
    .map_err(|_| XlogError::Kernel("Circuit CNF literal capacity exceeds u32::MAX".to_string()))?;

    let memory = provider.memory();
    let device = provider.device().inner();

    // Per-node count arrays (len = num_nodes) used for exclusive scans.
    // They first hold raw counts and are scanned in place into per-node bases below.
    let mut internal_prefix = memory.alloc::<u32>(num_nodes)?;
    let mut clause_base = memory.alloc::<u32>(num_nodes)?;
    let mut lit_base = memory.alloc::<u32>(num_nodes)?;

    let counts_fn = device
        .get_func(SAT_MODULE, sat_kernels::SAT_XGCF_CNF_COUNTS)
        .ok_or_else(|| XlogError::Kernel("sat_xgcf_cnf_counts kernel not found".to_string()))?;

    // One thread per node; the same grid is reused for the emit kernel.
    let block = 256u32;
    let grid = checked_launch_grid(num_nodes_u32, block, "sat_xgcf_cnf_counts")?;

    // SAFETY: sat_xgcf_cnf_counts(compile_needed, node_type, child_offsets, num_nodes, internal_counts, clause_counts, lit_counts)
    unsafe {
        counts_fn.clone().launch(
            LaunchConfig {
                grid_dim: (grid, 1, 1),
                block_dim: (block, 1, 1),
                shared_mem_bytes: 0,
            },
            (
                compile_needed,
                circuit.node_type(),
                circuit.child_offsets(),
                num_nodes_u32,
                &mut internal_prefix,
                &mut clause_base,
                &mut lit_base,
            ),
        )
    }
    .map_err(|e| XlogError::Kernel(format!("sat_xgcf_cnf_counts failed: {}", e)))?;

    // Capture last elements before scans overwrite them.
    // An exclusive scan drops the final count, and the totals kernel needs
    // last_base + last_count.
    let mut internal_last = memory.alloc::<u32>(1)?;
    let mut clause_last = memory.alloc::<u32>(1)?;
    let mut lit_last = memory.alloc::<u32>(1)?;

    let capture_last_fn = device
        .get_func(SAT_MODULE, sat_kernels::SAT_XGCF_CNF_CAPTURE_LAST_COUNTS)
        .ok_or_else(|| {
            XlogError::Kernel("sat_xgcf_cnf_capture_last_counts kernel not found".to_string())
        })?;
    // SAFETY: sat_xgcf_cnf_capture_last_counts(internal_counts, clause_counts, lit_counts, num_nodes, out_internal_last, out_clause_last, out_lit_last)
    unsafe {
        capture_last_fn.clone().launch(
            LaunchConfig {
                grid_dim: (1, 1, 1),
                block_dim: (1, 1, 1),
                shared_mem_bytes: 0,
            },
            (
                &internal_prefix,
                &clause_base,
                &lit_base,
                num_nodes_u32,
                &mut internal_last,
                &mut clause_last,
                &mut lit_last,
            ),
        )
    }
    .map_err(|e| XlogError::Kernel(format!("sat_xgcf_cnf_capture_last_counts failed: {}", e)))?;

    // Counts -> per-node exclusive bases (in place).
    provider.exclusive_scan_u32_inplace(&mut internal_prefix, num_nodes_u32)?;
    provider.exclusive_scan_u32_inplace(&mut clause_base, num_nodes_u32)?;
    provider.exclusive_scan_u32_inplace(&mut lit_base, num_nodes_u32)?;
    // No device synchronize: next ops are alloc + kernel launches on same stream.

    // Output CNF buffers + device-resident meta.
    // d_offsets has clause_cap + 1 entries for the CSR terminator.
    let d_num_vars = memory.alloc::<u32>(1)?;
    let d_num_clauses = memory.alloc::<u32>(1)?;
    let d_num_lits = memory.alloc::<u32>(1)?;
    let mut d_offsets = memory.alloc::<u32>((clause_cap as usize) + 1)?;
    let d_lits = memory.alloc::<i32>(lit_cap as usize)?;

    let totals_fn = device
        .get_func(SAT_MODULE, sat_kernels::SAT_XGCF_CNF_COMPUTE_TOTALS)
        .ok_or_else(|| {
            XlogError::Kernel("sat_xgcf_cnf_compute_totals kernel not found".to_string())
        })?;
    // SAFETY: sat_xgcf_cnf_compute_totals(internal_prefix, clause_base, lit_base, internal_last*, clause_last*, lit_last*, num_nodes, base_num_vars, clause_cap, lit_cap, out_num_vars*, out_num_clauses*, out_num_lits*)
    // The scalar arguments (num_nodes_u32, clause_cap, lit_cap) are named locals. Their
    // addresses therefore stay valid until the launch below copies the parameter values.
    let mut totals_params: Vec<*mut c_void> = vec![
        (&internal_prefix).as_kernel_param(),
        (&clause_base).as_kernel_param(),
        (&lit_base).as_kernel_param(),
        (&internal_last).as_kernel_param(),
        (&clause_last).as_kernel_param(),
        (&lit_last).as_kernel_param(),
        num_nodes_u32.as_kernel_param(),
        (base_num_vars).as_kernel_param(),
        clause_cap.as_kernel_param(),
        lit_cap.as_kernel_param(),
        (&d_num_vars).as_kernel_param(),
        (&d_num_clauses).as_kernel_param(),
        (&d_num_lits).as_kernel_param(),
    ];
    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
    unsafe {
        totals_fn.clone().launch(
            LaunchConfig {
                grid_dim: (1, 1, 1),
                block_dim: (1, 1, 1),
                shared_mem_bytes: 0,
            },
            &mut totals_params,
        )
    }
    .map_err(|e| XlogError::Kernel(format!("sat_xgcf_cnf_compute_totals failed: {}", e)))?;

    let emit_fn = device
        .get_func(SAT_MODULE, sat_kernels::SAT_XGCF_CNF_EMIT)
        .ok_or_else(|| XlogError::Kernel("sat_xgcf_cnf_emit kernel not found".to_string()))?;

    // sat_xgcf_cnf_emit(...) exceeds cudarc's tuple-arity impls for LaunchAsync, so launch with
    // an explicit parameter list.
    // The only scalar here, num_nodes_u32, is a named local and therefore stable storage.
    let mut params: Vec<*mut c_void> = vec![
        compile_needed.as_kernel_param(),
        circuit.node_type().as_kernel_param(),
        circuit.child_offsets().as_kernel_param(),
        circuit.child_indices().as_kernel_param(),
        circuit.lit().as_kernel_param(),
        circuit.decision_var().as_kernel_param(),
        circuit.decision_child_false().as_kernel_param(),
        circuit.decision_child_true().as_kernel_param(),
        (&internal_prefix).as_kernel_param(),
        (&clause_base).as_kernel_param(),
        (&lit_base).as_kernel_param(),
        (base_num_vars).as_kernel_param(),
        num_nodes_u32.as_kernel_param(),
        (&d_offsets).as_kernel_param(),
        (&d_lits).as_kernel_param(),
    ];

    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
    unsafe {
        emit_fn.clone().launch(
            LaunchConfig {
                grid_dim: (grid, 1, 1),
                block_dim: (block, 1, 1),
                shared_mem_bytes: 0,
            },
            &mut params,
        )
    }
    .map_err(|e| XlogError::Kernel(format!("sat_xgcf_cnf_emit failed: {}", e)))?;

    // sat_xgcf_cnf_emit does not write the CSR terminator; finalize deterministically on device.
    let term_fn = device
        .get_func(SAT_MODULE, sat_kernels::SAT_CNF_WRITE_TERMINATOR)
        .ok_or_else(|| {
            XlogError::Kernel("sat_cnf_write_terminator kernel not found".to_string())
        })?;
    // SAFETY: sat_cnf_write_terminator(out_offsets, num_clauses*, num_lits*)
    unsafe {
        term_fn.clone().launch(
            LaunchConfig {
                grid_dim: (1, 1, 1),
                block_dim: (1, 1, 1),
                shared_mem_bytes: 0,
            },
            (&mut d_offsets, &d_num_clauses, &d_num_lits),
        )
    }
    .map_err(|e| XlogError::Kernel(format!("sat_cnf_write_terminator failed: {}", e)))?;
    // No device synchronize: returns device-resident CNF; same-stream ordering suffices.

    Ok(CircuitCnf {
        cnf: GpuCnf {
            var_cap,
            clause_cap,
            lit_cap,
            num_vars: d_num_vars,
            num_clauses: d_num_clauses,
            num_lits: d_num_lits,
            clause_offsets: d_offsets,
            literals: d_lits,
        },
        internal_prefix,
    })
}
326
/// Build q1 = φ ∧ ¬C as a device-resident CNF.
///
/// All output is written by kernels on the provider's stream, in this order:
/// 1. φ's clauses/literals, copied at base 0 (`sat_cnf_copy_into`).
/// 2. The circuit CNF, copied at the device-resident bases `phi.num_clauses` /
///    `phi.num_lits`.
/// 3. A unit clause forcing the circuit root false, appended by
///    `sat_xgcf_write_root_unit_clause`. That kernel also writes the combined device-resident
///    totals.
///
/// Capacities are host-side sums: φ's caps + the circuit CNF's caps + 1 (for the unit clause).
/// No new variables are introduced, so `var_cap` is the circuit CNF's `var_cap`.
///
/// # Errors
/// `XlogError::Kernel` in any of these cases: capacity overflow; launch-grid overflow; a
/// missing kernel; a failed launch; or a failed upload of the device-side zero.
/// Allocation errors are propagated.
fn build_phi_and_not_c(
    provider: &Arc<CudaKernelProvider>,
    phi: &GpuCnf,
    circuit: &GpuXgcf,
    circuit_cnf: &CircuitCnf,
    compile_needed: &TrackedCudaSlice<u32>,
) -> Result<GpuCnf> {
    let device = provider.device().inner();
    let memory = provider.memory();

    let phi_clause_cap = phi.clause_cap;
    let phi_lit_cap = phi.lit_cap;

    // +1 clause / +1 literal for the root unit clause.
    let clause_cap = u32::try_from(
        (phi_clause_cap as u64)
            .checked_add(circuit_cnf.cnf.clause_cap as u64)
            .and_then(|v| v.checked_add(1))
            .ok_or_else(|| XlogError::Kernel("phi ∧ ¬C clause capacity overflow".to_string()))?,
    )
    .map_err(|_| XlogError::Kernel("phi ∧ ¬C clause capacity exceeds u32::MAX".to_string()))?;
    let lit_cap = u32::try_from(
        (phi_lit_cap as u64)
            .checked_add(circuit_cnf.cnf.lit_cap as u64)
            .and_then(|v| v.checked_add(1))
            .ok_or_else(|| XlogError::Kernel("phi ∧ ¬C literal capacity overflow".to_string()))?,
    )
    .map_err(|_| XlogError::Kernel("phi ∧ ¬C literal capacity exceeds u32::MAX".to_string()))?;

    let var_cap = circuit_cnf.cnf.var_cap;

    let out_num_vars = memory.alloc::<u32>(1)?;
    let out_num_clauses = memory.alloc::<u32>(1)?;
    let out_num_lits = memory.alloc::<u32>(1)?;
    // Outputs of sat_xgcf_write_root_unit_clause that only matter for C ∧ ¬φ (selector-var
    // base and ¬φ clause/literal bases). They are allocated here only to satisfy the kernel
    // signature and are dropped unread.
    let d_unused0 = memory.alloc::<u32>(1)?;
    let d_unused1 = memory.alloc::<u32>(1)?;
    let d_unused2 = memory.alloc::<u32>(1)?;

    // Device-resident 0. It serves as the destination base for the φ copy and as the
    // "extra" counts passed to the unit-clause kernel (q1 has no ¬φ section).
    let mut d_zero = memory.alloc::<u32>(1)?;
    provider
        .htod_launch_metadata_sync_copy_into(&[0u32], &mut d_zero)
        .map_err(|e| XlogError::Kernel(format!("Failed to upload zero: {}", e)))?;

    let mut out_offsets = memory.alloc::<u32>((clause_cap as usize) + 1)?;
    let mut out_lits = memory.alloc::<i32>(lit_cap as usize)?;

    let copy_fn = device
        .get_func(SAT_MODULE, sat_kernels::SAT_CNF_COPY_INTO)
        .ok_or_else(|| XlogError::Kernel("sat_cnf_copy_into kernel not found".to_string()))?;

    // The copy kernel covers both the offset array (clause_cap + 1 entries) and the literal
    // array, so size the grid for the larger of the two.
    let block = 256u32;
    let phi_copy_elems =
        checked_clause_offset_span(phi_clause_cap, "sat_cnf_copy_into(phi)")?.max(phi_lit_cap);
    let grid = checked_launch_grid(phi_copy_elems, block, "sat_cnf_copy_into(phi)")?;

    // Copy phi (exact sizes) into the front.
    // sat_cnf_copy_into(src_offsets, src_lits, src_num_clauses*, src_num_lits*, src_clause_cap, src_lit_cap,
    //                  dst_clause_base*, dst_lit_base*, dst_clause_cap, dst_lit_cap, dst_offsets, dst_lits)
    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
    unsafe {
        copy_fn.clone().launch(
            LaunchConfig {
                grid_dim: (grid, 1, 1),
                block_dim: (block, 1, 1),
                shared_mem_bytes: 0,
            },
            (
                &phi.clause_offsets,
                &phi.literals,
                &phi.num_clauses,
                &phi.num_lits,
                phi.clause_cap,
                phi.lit_cap,
                &d_zero,
                &d_zero,
                clause_cap,
                lit_cap,
                &mut out_offsets,
                &mut out_lits,
            ),
        )
    }
    .map_err(|e| XlogError::Kernel(format!("sat_cnf_copy_into(phi) failed: {}", e)))?;

    // Copy CNF(C) after phi using device-resident bases (phi.num_clauses/phi.num_lits).
    let circuit_copy_elems =
        checked_clause_offset_span(circuit_cnf.cnf.clause_cap, "sat_cnf_copy_into(circuit)")?
            .max(circuit_cnf.cnf.lit_cap);
    let grid_c = checked_launch_grid(circuit_copy_elems, block, "sat_cnf_copy_into(circuit)")?;
    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
    unsafe {
        copy_fn.clone().launch(
            LaunchConfig {
                grid_dim: (grid_c, 1, 1),
                block_dim: (block, 1, 1),
                shared_mem_bytes: 0,
            },
            (
                &circuit_cnf.cnf.clause_offsets,
                &circuit_cnf.cnf.literals,
                &circuit_cnf.cnf.num_clauses,
                &circuit_cnf.cnf.num_lits,
                circuit_cnf.cnf.clause_cap,
                circuit_cnf.cnf.lit_cap,
                &phi.num_clauses,
                &phi.num_lits,
                clause_cap,
                lit_cap,
                &mut out_offsets,
                &mut out_lits,
            ),
        )
    }
    .map_err(|e| XlogError::Kernel(format!("sat_cnf_copy_into(C) failed: {}", e)))?;

    // Finalize: append unit clause forcing root false + write device-resident totals for the combined query.
    let unit_fn = device
        .get_func(SAT_MODULE, sat_kernels::SAT_XGCF_WRITE_ROOT_UNIT_CLAUSE)
        .ok_or_else(|| {
            XlogError::Kernel("sat_xgcf_write_root_unit_clause kernel not found".to_string())
        })?;

    // IMPORTANT: When launching with an explicit `Vec<*mut c_void>` parameter list, scalar kernel
    // arguments MUST be backed by stable host storage until `cuLaunchKernel` copies them. Do not
    // pass temporaries like `circuit.root().as_kernel_param()` or `0i32.as_kernel_param()`.
    let root = circuit.root();
    // The kernel parameter is `force_true`; 0 here means "force the root false" (¬C).
    let force_true: i32 = 0;
    let out_var_cap = var_cap;
    let out_clause_cap = clause_cap;
    let out_lit_cap = lit_cap;

    let mut params: Vec<*mut c_void> = vec![
        compile_needed.as_kernel_param(),
        circuit.node_type().as_kernel_param(),
        circuit.lit().as_kernel_param(),
        (&circuit_cnf.internal_prefix).as_kernel_param(),
        (&phi.num_vars).as_kernel_param(),
        root.as_kernel_param(),
        force_true.as_kernel_param(), // force_true = 0 → unit clause forces root false
        (&phi.num_clauses).as_kernel_param(), // clause_base
        (&phi.num_lits).as_kernel_param(),    // lit_base
        (&circuit_cnf.cnf.num_vars).as_kernel_param(),
        (&circuit_cnf.cnf.num_clauses).as_kernel_param(),
        (&circuit_cnf.cnf.num_lits).as_kernel_param(),
        (&d_zero).as_kernel_param(), // extra_num_vars
        (&d_zero).as_kernel_param(), // extra_num_clauses
        (&d_zero).as_kernel_param(), // extra_num_lits
        out_var_cap.as_kernel_param(),
        out_clause_cap.as_kernel_param(),
        out_lit_cap.as_kernel_param(),
        (&out_num_vars).as_kernel_param(),
        (&out_num_clauses).as_kernel_param(),
        (&out_num_lits).as_kernel_param(),
        (&d_unused0).as_kernel_param(), // unsat_var_base (unused for q1)
        (&d_unused1).as_kernel_param(), // ¬phi clause base (unused for q1)
        (&d_unused2).as_kernel_param(), // ¬phi lit base (unused for q1)
        (&out_offsets).as_kernel_param(),
        (&out_lits).as_kernel_param(),
    ];

    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
    unsafe {
        unit_fn.clone().launch(
            LaunchConfig {
                grid_dim: (1, 1, 1),
                block_dim: (1, 1, 1),
                shared_mem_bytes: 0,
            },
            &mut params,
        )
    }
    .map_err(|e| XlogError::Kernel(format!("sat_xgcf_write_root_unit_clause failed: {}", e)))?;
    // No device synchronize: returns device-resident CNF; same-stream ordering suffices.

    Ok(GpuCnf {
        var_cap,
        clause_cap,
        lit_cap,
        num_vars: out_num_vars,
        num_clauses: out_num_clauses,
        num_lits: out_num_lits,
        clause_offsets: out_offsets,
        literals: out_lits,
    })
}
511
/// Build q2 = C ∧ ¬φ as a device-resident CNF.
///
/// Returns the CNF together with a device-resident (length 1) base id for the ¬φ selector
/// variables.
///
/// All output is written by kernels on the provider's stream, in this order:
/// 1. The circuit CNF, copied at base 0.
/// 2. `sat_not_phi_counts` computes the exact ¬φ clause/literal/variable contributions from
///    `phi.num_*`.
/// 3. `sat_xgcf_write_root_unit_clause` inserts a unit clause forcing the root true. It also
///    writes the combined totals, the selector-var base, and the ¬φ clause/literal bases.
/// 4. `sat_emit_not_phi` writes the ¬φ encoding at those device-resident bases.
///
/// ¬φ adds one selector variable per φ clause, so `var_cap` grows by `phi.clause_cap`.
///
/// # Errors
/// `XlogError::Kernel` in any of these cases: capacity overflow; launch-grid overflow; a
/// missing kernel; a failed launch; or a failed upload of the device-side zero.
/// Allocation errors are propagated.
fn build_c_and_not_phi(
    provider: &Arc<CudaKernelProvider>,
    phi: &GpuCnf,
    circuit: &GpuXgcf,
    circuit_cnf: &CircuitCnf,
    compile_needed: &TrackedCudaSlice<u32>,
) -> Result<(GpuCnf, TrackedCudaSlice<u32>)> {
    let device = provider.device().inner();
    let memory = provider.memory();

    let phi_clause_cap = phi.clause_cap;
    let phi_lit_cap = phi.lit_cap;

    // ¬phi encoding:
    // clauses_notphi = sum(len_j + 1) + 1 = L + m + 1
    // lits_notphi = sum(3*len_j + 1) + m = 3L + 2m
    // where L = total phi literals and m = number of phi clauses (bounded here by the caps).
    let notphi_clause_cap = u32::try_from(
        (phi_lit_cap as u64)
            .checked_add(phi_clause_cap as u64)
            .and_then(|v| v.checked_add(1))
            .ok_or_else(|| XlogError::Kernel("¬phi clause count overflow".to_string()))?,
    )
    .map_err(|_| XlogError::Kernel("¬phi clause count exceeds u32::MAX".to_string()))?;
    let notphi_lit_cap = u32::try_from(
        (phi_lit_cap as u64)
            .checked_mul(3)
            .and_then(|v| v.checked_add(2u64.saturating_mul(phi_clause_cap as u64)))
            .ok_or_else(|| XlogError::Kernel("¬phi literal count overflow".to_string()))?,
    )
    .map_err(|_| XlogError::Kernel("¬phi literal count exceeds u32::MAX".to_string()))?;

    // One selector variable per phi clause.
    let var_cap = circuit_cnf
        .cnf
        .var_cap
        .checked_add(phi_clause_cap)
        .ok_or_else(|| XlogError::Kernel("C ∧ ¬phi var capacity overflow".to_string()))?;
    // +1 clause / +1 literal for the root unit clause.
    let clause_cap = u32::try_from(
        (circuit_cnf.cnf.clause_cap as u64)
            .checked_add(1)
            .and_then(|v| v.checked_add(notphi_clause_cap as u64))
            .ok_or_else(|| XlogError::Kernel("C ∧ ¬phi clause capacity overflow".to_string()))?,
    )
    .map_err(|_| XlogError::Kernel("C ∧ ¬phi clause capacity exceeds u32::MAX".to_string()))?;
    let lit_cap = u32::try_from(
        (circuit_cnf.cnf.lit_cap as u64)
            .checked_add(1)
            .and_then(|v| v.checked_add(notphi_lit_cap as u64))
            .ok_or_else(|| XlogError::Kernel("C ∧ ¬phi literal capacity overflow".to_string()))?,
    )
    .map_err(|_| XlogError::Kernel("C ∧ ¬phi literal capacity exceeds u32::MAX".to_string()))?;

    let out_num_vars = memory.alloc::<u32>(1)?;
    let out_num_clauses = memory.alloc::<u32>(1)?;
    let out_num_lits = memory.alloc::<u32>(1)?;

    // Device-resident 0: destination base for the circuit copy and clause/lit base for the
    // unit-clause kernel.
    let mut d_zero = memory.alloc::<u32>(1)?;
    provider
        .htod_launch_metadata_sync_copy_into(&[0u32], &mut d_zero)
        .map_err(|e| XlogError::Kernel(format!("Failed to upload zero: {}", e)))?;

    // Device-resident exact extras for ¬phi (computed from phi.num_*).
    let mut d_extra_num_vars = memory.alloc::<u32>(1)?;
    let mut d_extra_num_clauses = memory.alloc::<u32>(1)?;
    let mut d_extra_num_lits = memory.alloc::<u32>(1)?;

    // Written by sat_xgcf_write_root_unit_clause and read by sat_emit_not_phi.
    let d_unsat_var_base = memory.alloc::<u32>(1)?;
    let d_notphi_clause_base = memory.alloc::<u32>(1)?;
    let d_notphi_lit_base = memory.alloc::<u32>(1)?;

    let mut out_offsets = memory.alloc::<u32>((clause_cap as usize) + 1)?;
    let mut out_lits = memory.alloc::<i32>(lit_cap as usize)?;

    let copy_fn = device
        .get_func(SAT_MODULE, sat_kernels::SAT_CNF_COPY_INTO)
        .ok_or_else(|| XlogError::Kernel("sat_cnf_copy_into kernel not found".to_string()))?;

    // Copy CNF(C) into the front (exact sizes).
    let block = 256u32;
    let circuit_copy_elems =
        checked_clause_offset_span(circuit_cnf.cnf.clause_cap, "sat_cnf_copy_into(circuit)")?
            .max(circuit_cnf.cnf.lit_cap);
    let grid = checked_launch_grid(circuit_copy_elems, block, "sat_cnf_copy_into(circuit)")?;
    // sat_cnf_copy_into(src_offsets, src_lits, src_num_clauses*, src_num_lits*, src_clause_cap, src_lit_cap,
    //                  dst_clause_base*, dst_lit_base*, dst_clause_cap, dst_lit_cap, dst_offsets, dst_lits)
    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
    unsafe {
        copy_fn.clone().launch(
            LaunchConfig {
                grid_dim: (grid, 1, 1),
                block_dim: (block, 1, 1),
                shared_mem_bytes: 0,
            },
            (
                &circuit_cnf.cnf.clause_offsets,
                &circuit_cnf.cnf.literals,
                &circuit_cnf.cnf.num_clauses,
                &circuit_cnf.cnf.num_lits,
                circuit_cnf.cnf.clause_cap,
                circuit_cnf.cnf.lit_cap,
                &d_zero,
                &d_zero,
                clause_cap,
                lit_cap,
                &mut out_offsets,
                &mut out_lits,
            ),
        )
    }
    .map_err(|e| XlogError::Kernel(format!("sat_cnf_copy_into(C) failed: {}", e)))?;

    // Compute exact ¬phi size contributions on GPU.
    let notphi_counts_fn = device
        .get_func(SAT_MODULE, sat_kernels::SAT_NOT_PHI_COUNTS)
        .ok_or_else(|| XlogError::Kernel("sat_not_phi_counts kernel not found".to_string()))?;
    // SAFETY: arguments are passed as sat_not_phi_counts(compile_needed, phi_num_clauses*, phi_num_lits*, out_extra_num_vars*, out_extra_num_clauses*, out_extra_num_lits*)
    // and must match the PTX signature; all outputs are 1-element device buffers allocated above.
    unsafe {
        notphi_counts_fn.clone().launch(
            LaunchConfig {
                grid_dim: (1, 1, 1),
                block_dim: (1, 1, 1),
                shared_mem_bytes: 0,
            },
            (
                compile_needed,
                &phi.num_clauses,
                &phi.num_lits,
                &mut d_extra_num_vars,
                &mut d_extra_num_clauses,
                &mut d_extra_num_lits,
            ),
        )
    }
    .map_err(|e| XlogError::Kernel(format!("sat_not_phi_counts failed: {}", e)))?;

    // Prepare: insert unit clause forcing root true and compute device-resident totals / bases.
    let unit_fn = device
        .get_func(SAT_MODULE, sat_kernels::SAT_XGCF_WRITE_ROOT_UNIT_CLAUSE)
        .ok_or_else(|| {
            XlogError::Kernel("sat_xgcf_write_root_unit_clause kernel not found".to_string())
        })?;

    // IMPORTANT: scalar kernel parameters passed through an explicit `Vec<*mut c_void>` must be
    // backed by stable host storage (named locals) until the launch copies them; never pass
    // temporaries such as `circuit.root().as_kernel_param()`.
    let root = circuit.root();
    let force_true: i32 = 1;
    let out_var_cap = var_cap;
    let out_clause_cap = clause_cap;
    let out_lit_cap = lit_cap;

    let mut params: Vec<*mut c_void> = vec![
        compile_needed.as_kernel_param(),
        circuit.node_type().as_kernel_param(),
        circuit.lit().as_kernel_param(),
        (&circuit_cnf.internal_prefix).as_kernel_param(),
        (&phi.num_vars).as_kernel_param(),
        root.as_kernel_param(),
        force_true.as_kernel_param(), // force_true
        (&d_zero).as_kernel_param(),  // clause_base
        (&d_zero).as_kernel_param(),  // lit_base
        (&circuit_cnf.cnf.num_vars).as_kernel_param(),
        (&circuit_cnf.cnf.num_clauses).as_kernel_param(),
        (&circuit_cnf.cnf.num_lits).as_kernel_param(),
        (&d_extra_num_vars).as_kernel_param(), // extra_num_vars (u_j vars)
        (&d_extra_num_clauses).as_kernel_param(), // extra_num_clauses
        (&d_extra_num_lits).as_kernel_param(), // extra_num_lits
        out_var_cap.as_kernel_param(),
        out_clause_cap.as_kernel_param(),
        out_lit_cap.as_kernel_param(),
        (&out_num_vars).as_kernel_param(),
        (&out_num_clauses).as_kernel_param(),
        (&out_num_lits).as_kernel_param(),
        (&d_unsat_var_base).as_kernel_param(),
        (&d_notphi_clause_base).as_kernel_param(),
        (&d_notphi_lit_base).as_kernel_param(),
        (&out_offsets).as_kernel_param(),
        (&out_lits).as_kernel_param(),
    ];

    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
    unsafe {
        unit_fn.clone().launch(
            LaunchConfig {
                grid_dim: (1, 1, 1),
                block_dim: (1, 1, 1),
                shared_mem_bytes: 0,
            },
            &mut params,
        )
    }
    .map_err(|e| XlogError::Kernel(format!("sat_xgcf_write_root_unit_clause failed: {}", e)))?;

    // Emit ¬phi encoding after CNF(C) + unit using device-resident base indices.
    let not_phi_fn = device
        .get_func(SAT_MODULE, sat_kernels::SAT_EMIT_NOT_PHI)
        .ok_or_else(|| XlogError::Kernel("sat_emit_not_phi kernel not found".to_string()))?;

    // Re-declares the same block size as above (harmless shadowing); one thread per phi clause.
    let block = 256u32;
    let grid = checked_launch_grid(phi_clause_cap, block, "sat_emit_not_phi")?;

    // SAFETY: arguments are passed as sat_emit_not_phi(compile_needed, phi_offsets, phi_lits, phi_num_clauses*, unsat_var_base*, out_clause_base*, out_lit_base*, out_offsets, out_lits)
    // and must match the PTX signature; out_offsets/out_lits were sized from clause_cap/lit_cap above.
    unsafe {
        not_phi_fn.clone().launch(
            LaunchConfig {
                grid_dim: (grid, 1, 1),
                block_dim: (block, 1, 1),
                shared_mem_bytes: 0,
            },
            (
                compile_needed,
                &phi.clause_offsets,
                &phi.literals,
                &phi.num_clauses,
                &d_unsat_var_base,
                &d_notphi_clause_base,
                &d_notphi_lit_base,
                &mut out_offsets,
                &mut out_lits,
            ),
        )
    }
    .map_err(|e| XlogError::Kernel(format!("sat_emit_not_phi failed: {}", e)))?;
    // No device synchronize: returns device-resident CNF; same-stream ordering suffices.

    Ok((
        GpuCnf {
            var_cap,
            clause_cap,
            lit_cap,
            num_vars: out_num_vars,
            num_clauses: out_num_clauses,
            num_lits: out_num_lits,
            clause_offsets: out_offsets,
            literals: out_lits,
        },
        d_unsat_var_base,
    ))
}
747
748pub(crate) fn check_equivalence_gpu(
749    phi: &GpuCnf,
750    phi_decision_var_limit: &TrackedCudaSlice<u32>,
751    circuit: &GpuXgcf,
752    provider: &Arc<CudaKernelProvider>,
753    config: GpuEquivalenceConfig,
754) -> Result<()> {
755    let queries = build_equivalence_queries_gpu(phi, circuit, provider)?;
756
757    #[cfg(debug_assertions)]
758    {
759        // Fail-fast: if query CNFs are malformed, the solver may hang or misbehave.
760        validate_cnf_gpu(&queries.q1, provider.as_ref())?;
761        validate_cnf_gpu(&queries.q2, provider.as_ref())?;
762    }
763
764    let solver = GpuCdclSolver::new(provider.clone(), config.cdcl);
765    if config.reuse_workspace {
766        let max_var_cap = std::cmp::max(queries.q1.var_cap, queries.q2.var_cap);
767        let max_clause_cap = std::cmp::max(queries.q1.clause_cap, queries.q2.clause_cap);
768        let mut ws = solver.new_workspace(max_var_cap, max_clause_cap)?;
769        // q1: decisions only on semantically meaningful phi vars (exclude internal/Tseitin vars).
770        solver.solve_expect_unsat_with_branch_limit_ws(
771            &mut ws,
772            &queries.q1,
773            phi_decision_var_limit,
774        )?;
775        // q2: decisions on semantically meaningful phi vars + ¬phi selector vars.
776        solver.solve_expect_unsat_with_decision_ranges_ws(
777            &mut ws,
778            &queries.q2,
779            phi_decision_var_limit,
780            &queries.q2_unsat_var_base,
781            &phi.num_clauses,
782        )?;
783    } else {
784        // q1: decisions only on semantically meaningful phi vars (exclude internal/Tseitin vars).
785        solver.solve_expect_unsat_with_branch_limit(&queries.q1, phi_decision_var_limit)?;
786        // q2: decisions on semantically meaningful phi vars + ¬phi selector vars.
787        solver.solve_expect_unsat_with_decision_ranges(
788            &queries.q2,
789            phi_decision_var_limit,
790            &queries.q2_unsat_var_base,
791            &phi.num_clauses,
792        )?;
793    }
794    Ok(())
795}
796
797/// Build the two equivalence-check queries on GPU:
798/// - q1 = φ ∧ ¬C
799/// - q2 = C ∧ ¬φ
800///
801/// This helper exists so tests and tooling can inspect query CNFs without duplicating kernel
802/// orchestration logic.
803pub fn build_equivalence_queries_gpu(
804    phi: &GpuCnf,
805    circuit: &GpuXgcf,
806    provider: &Arc<CudaKernelProvider>,
807) -> Result<GpuEquivalenceQueries> {
808    // Non-gated path: force compilation/verification on.
809    let memory = provider.memory();
810    let mut compile_needed = memory.alloc::<u32>(1)?;
811    provider
812        .htod_launch_metadata_sync_copy_into(&[1u32], &mut compile_needed)
813        .map_err(|e| XlogError::Kernel(format!("Failed to upload compile_needed=1: {}", e)))?;
814
815    let circuit_cnf = build_circuit_cnf(
816        provider,
817        circuit,
818        &phi.num_vars,
819        phi.var_cap,
820        &compile_needed,
821    )?;
822    let q1 = build_phi_and_not_c(provider, phi, circuit, &circuit_cnf, &compile_needed)?;
823    let (q2, q2_unsat_var_base) =
824        build_c_and_not_phi(provider, phi, circuit, &circuit_cnf, &compile_needed)?;
825    Ok(GpuEquivalenceQueries {
826        q1,
827        q2,
828        q2_unsat_var_base,
829    })
830}
831
/// Gated variant of `check_equivalence_gpu`: verifies `phi ≡ circuit` by expecting
/// both `q1 = φ ∧ ¬C` and `q2 = C ∧ ¬φ` to be UNSAT.
///
/// Instead of forcing compilation on, it threads the caller-supplied device flag
/// `compile_needed` through every query builder. It also uses the `*_gated` solver
/// entry points.
///
/// NOTE(review): how the builders and solver react to `compile_needed == 0` is not
/// visible here. Presumably they skip work on-device, so the host never has to read
/// the flag — confirm against the kernel/solver sources.
///
/// In debug builds only:
/// - each stage is followed by a full device synchronize, so an asynchronous failure
///   surfaces with the name of the stage that just ran;
/// - progress and the per-query caps are logged to stderr;
/// - both query CNFs are checked with `validate_cnf_gpu` before solving.
///
/// Release builds perform none of these extra steps.
///
/// # Errors
/// Propagates errors from the query builders, the debug-only syncs/validation,
/// workspace allocation, and both `solve_expect_unsat*_gated` calls.
pub(crate) fn check_equivalence_gpu_gated(
    phi: &GpuCnf,
    phi_decision_var_limit: &TrackedCudaSlice<u32>,
    circuit: &GpuXgcf,
    provider: &Arc<CudaKernelProvider>,
    config: GpuEquivalenceConfig,
    compile_needed: &TrackedCudaSlice<u32>,
) -> Result<()> {
    // Stage 1: CNF encoding of the circuit, shared by both queries.
    #[cfg(debug_assertions)]
    eprintln!("[xlog-prob] equivalence: build_circuit_cnf");
    let circuit_cnf = build_circuit_cnf(
        provider,
        circuit,
        &phi.num_vars,
        phi.var_cap,
        compile_needed,
    )?;
    #[cfg(debug_assertions)]
    {
        // Debug-only sync: attribute any asynchronous kernel failure to this stage.
        provider.device().synchronize().map_err(|e| {
            XlogError::Kernel(format!("sync after build_circuit_cnf failed: {}", e))
        })?;
        eprintln!("[xlog-prob] equivalence: build_phi_and_not_c");
    }

    // Stage 2: q1 = phi ∧ ¬C.
    let q1 = build_phi_and_not_c(provider, phi, circuit, &circuit_cnf, compile_needed)?;
    #[cfg(debug_assertions)]
    {
        provider.device().synchronize().map_err(|e| {
            XlogError::Kernel(format!("sync after build_phi_and_not_c failed: {}", e))
        })?;
        eprintln!("[xlog-prob] equivalence: build_c_and_not_phi");
    }
    // Stage 3: q2 = C ∧ ¬phi, plus the base index of its ¬phi selector vars.
    let (q2, q2_unsat_var_base) =
        build_c_and_not_phi(provider, phi, circuit, &circuit_cnf, compile_needed)?;
    #[cfg(debug_assertions)]
    {
        provider.device().synchronize().map_err(|e| {
            XlogError::Kernel(format!("sync after build_c_and_not_phi failed: {}", e))
        })?;
        // Host-side capacity bounds only; no device reads.
        eprintln!(
            "[xlog-prob] equivalence: caps: phi(v={} c={} l={}) circuit_cnf(v={} c={} l={}) q1(v={} c={} l={}) q2(v={} c={} l={})",
            phi.var_cap,
            phi.clause_cap,
            phi.lit_cap,
            circuit_cnf.cnf.var_cap,
            circuit_cnf.cnf.clause_cap,
            circuit_cnf.cnf.lit_cap,
            q1.var_cap,
            q1.clause_cap,
            q1.lit_cap,
            q2.var_cap,
            q2.clause_cap,
            q2.lit_cap,
        );
        // NOTE(review): this line is printed before the CNF validation below runs.
        eprintln!("[xlog-prob] equivalence: solve_expect_unsat q1");
    }

    // Debug-only structural validation of both queries before solving.
    // NOTE(review): this runs regardless of the compile_needed flag's device value.
    #[cfg(debug_assertions)]
    {
        validate_cnf_gpu(&q1, provider.as_ref())?;
        validate_cnf_gpu(&q2, provider.as_ref())?;
    }

    // Stage 4: expect both queries UNSAT.
    let solver = GpuCdclSolver::new(provider.clone(), config.cdcl);
    if config.reuse_workspace {
        // One workspace sized for the larger query serves both solves.
        let max_var_cap = std::cmp::max(q1.var_cap, q2.var_cap);
        let max_clause_cap = std::cmp::max(q1.clause_cap, q2.clause_cap);
        let mut ws = solver.new_workspace(max_var_cap, max_clause_cap)?;
        // q1: decisions restricted by phi_decision_var_limit.
        solver.solve_expect_unsat_with_branch_limit_gated_ws(
            &mut ws,
            &q1,
            compile_needed,
            phi_decision_var_limit,
        )?;
        #[cfg(debug_assertions)]
        {
            provider.device().synchronize().map_err(|e| {
                XlogError::Kernel(format!("sync after solve_expect_unsat(q1) failed: {}", e))
            })?;
            eprintln!("[xlog-prob] equivalence: solve_expect_unsat q2");
        }
        // q2: phi-var decisions plus the range starting at q2_unsat_var_base
        // (the ¬phi selector vars), with phi.num_clauses passed alongside.
        solver.solve_expect_unsat_with_decision_ranges_gated_ws(
            &mut ws,
            &q2,
            compile_needed,
            phi_decision_var_limit,
            &q2_unsat_var_base,
            &phi.num_clauses,
        )?;
    } else {
        // Same two solves, each allocating its own workspace internally.
        solver.solve_expect_unsat_with_branch_limit_gated(
            &q1,
            compile_needed,
            phi_decision_var_limit,
        )?;
        #[cfg(debug_assertions)]
        {
            provider.device().synchronize().map_err(|e| {
                XlogError::Kernel(format!("sync after solve_expect_unsat(q1) failed: {}", e))
            })?;
            eprintln!("[xlog-prob] equivalence: solve_expect_unsat q2");
        }
        solver.solve_expect_unsat_with_decision_ranges_gated(
            &q2,
            compile_needed,
            phi_decision_var_limit,
            &q2_unsat_var_base,
            &phi.num_clauses,
        )?;
    }
    #[cfg(debug_assertions)]
    {
        provider.device().synchronize().map_err(|e| {
            XlogError::Kernel(format!("sync after solve_expect_unsat(q2) failed: {}", e))
        })?;
        eprintln!("[xlog-prob] equivalence: done");
    }
    Ok(())
}
952
/// Pre-launch CNF size bound for the equivalence verifier, as
/// `(max_vars, max_clauses)`.
///
/// The GPU CDCL equivalence check is treewidth-exponential. On a hard CNF it
/// can run long enough to hit a CUDA launch failure, which poisons the primary
/// context for every later compile in the process. That cannot be caught
/// after the fact — a poisoned context cannot be recovered in-process — so the
/// only sound mitigation is to decline *before* launch.
///
/// The bound applies to the host-side `var_cap`/`clause_cap` (capacity upper
/// bounds; no device read). Defaults are unbounded (`u32::MAX`), so behavior is
/// unchanged unless an operator sets `XLOG_D4_VERIFY_MAX_VARS` /
/// `XLOG_D4_VERIFY_MAX_CLAUSES`. Choosing a production default is a follow-up
/// that requires calibrating against the measured explosion boundary.
///
/// The environment is read once per process (cached in a `OnceLock`); changing
/// it after the first call has no effect. An unset, blank, unparseable, or
/// non-Unicode value means "unbounded". The last two cases are reported on
/// stderr (see [`parse_size_budget`]).
fn verify_size_budget() -> (u32, u32) {
    fn env_u32(key: &str) -> u32 {
        parse_size_budget(key, std::env::var(key))
    }
    static BUDGET: std::sync::OnceLock<(u32, u32)> = std::sync::OnceLock::new();
    *BUDGET.get_or_init(|| {
        (
            env_u32("XLOG_D4_VERIFY_MAX_VARS"),
            env_u32("XLOG_D4_VERIFY_MAX_CLAUSES"),
        )
    })
}

/// Interpret one size-budget environment value read for `key`.
///
/// Returns the trimmed value parsed as `u32`.
/// - Unset or blank input returns `u32::MAX` (unbounded).
/// - A present value that is not a valid `u32`, or not valid Unicode, also
///   returns `u32::MAX`, but is reported on stderr. Without the warning, a typo
///   such as `4k` would silently disable the guard the operator meant to enable.
fn parse_size_budget(key: &str, raw: Result<String, std::env::VarError>) -> u32 {
    match raw {
        Err(std::env::VarError::NotPresent) => u32::MAX,
        Err(std::env::VarError::NotUnicode(value)) => {
            eprintln!(
                "[xlog-prob] ignoring {key}={value:?}: not valid Unicode; size bound stays unbounded"
            );
            u32::MAX
        }
        Ok(value) => {
            let trimmed = value.trim();
            if trimmed.is_empty() {
                return u32::MAX;
            }
            trimmed.parse::<u32>().unwrap_or_else(|e| {
                eprintln!(
                    "[xlog-prob] ignoring {key}={value:?}: {e}; size bound stays unbounded"
                );
                u32::MAX
            })
        }
    }
}
981
982/// Pure size-bound decision: typed [`XlogError::VerifyBudgetExceeded`] when the
983/// CNF caps exceed the budget. GPU-free so the decline path is unit-testable
984/// without a device.
985fn enforce_verify_size_bound(
986    var_cap: u32,
987    clause_cap: u32,
988    var_budget: u32,
989    clause_budget: u32,
990    context: &str,
991) -> Result<()> {
992    if var_cap > var_budget || clause_cap > clause_budget {
993        // A size-based decline is a COMPILE-capacity issue ("too big to
994        // compile safely"), distinct from the verify-phase conflict budget.
995        return Err(XlogError::CompileCapacityExceeded {
996            context: context.to_string(),
997            detail: format!(
998                "CNF {var_cap} vars / {clause_cap} clauses exceeds size bound \
999                 ({var_budget} vars / {clause_budget} clauses)"
1000            ),
1001        });
1002    }
1003    Ok(())
1004}
1005
1006/// Decline the verify with a typed [`XlogError::VerifyBudgetExceeded`] when the
1007/// CNF exceeds the configured size bound, before any kernel launches.
1008///
1009/// Callable from the compile path so the size guard runs BEFORE `compile_gpu_d4`
1010/// — the D4 compile itself can crash (context-poisoning launch failure) on a
1011/// large CNF, earlier than the verify, so a size check only at the verify entry
1012/// is too late to prevent it.
1013pub(crate) fn check_verify_size_bound(phi: &GpuCnf, context: &str) -> Result<()> {
1014    // Operator/calibration diagnostic: log the CNF caps the bound sees, so the
1015    // safe value of XLOG_D4_VERIFY_MAX_VARS/_MAX_CLAUSES can be read off a real
1016    // workload (must sit above every program that compiles fine, below the
1017    // explosion boundary). Off unless XLOG_DEBUG_VERIFY_SIZE=1.
1018    if std::env::var("XLOG_DEBUG_VERIFY_SIZE").as_deref() == Ok("1") {
1019        eprintln!(
1020            "[xlog-prob] verify-size {context}: var_cap={} clause_cap={} lit_cap={}",
1021            phi.var_cap, phi.clause_cap, phi.lit_cap
1022        );
1023    }
1024    let (var_budget, clause_budget) = verify_size_budget();
1025    enforce_verify_size_bound(
1026        phi.var_cap,
1027        phi.clause_cap,
1028        var_budget,
1029        clause_budget,
1030        context,
1031    )
1032}
1033
1034pub fn validate_equivalence_gpu(
1035    phi: &GpuCnf,
1036    phi_decision_var_limit: &TrackedCudaSlice<u32>,
1037    circuit: &GpuXgcf,
1038    provider: &Arc<CudaKernelProvider>,
1039    config: GpuEquivalenceConfig,
1040) -> Result<()> {
1041    check_verify_size_bound(phi, "validate_equivalence_gpu")?;
1042    check_equivalence_gpu(phi, phi_decision_var_limit, circuit, provider, config)
1043}
1044
1045pub fn validate_equivalence_gpu_gated(
1046    phi: &GpuCnf,
1047    phi_decision_var_limit: &TrackedCudaSlice<u32>,
1048    circuit: &GpuXgcf,
1049    provider: &Arc<CudaKernelProvider>,
1050    config: GpuEquivalenceConfig,
1051    compile_needed: &TrackedCudaSlice<u32>,
1052) -> Result<()> {
1053    check_verify_size_bound(phi, "validate_equivalence_gpu_gated")?;
1054    check_equivalence_gpu_gated(
1055        phi,
1056        phi_decision_var_limit,
1057        circuit,
1058        provider,
1059        config,
1060        compile_needed,
1061    )
1062}
1063
#[cfg(test)]
mod verify_size_bound_tests {
    use super::enforce_verify_size_bound;
    use xlog_core::XlogError;

    // A CNF over the configured bound must decline with a typed
    // CompileCapacityExceeded carrying the tripping sizes — BEFORE any launch —
    // never a panic and never a doomed kernel that poisons the context.
    #[test]
    fn over_var_budget_declines_typed() {
        let err = enforce_verify_size_bound(5000, 100, 4096, u32::MAX, "ctx")
            .expect_err("must decline over var budget");
        match err {
            // Size declines are COMPILE-capacity, distinct from verify budget.
            XlogError::CompileCapacityExceeded { context, detail } => {
                assert_eq!(context, "ctx");
                // detail names the tripping sizes and the bound.
                assert!(detail.contains("5000 vars"), "detail: {detail}");
                assert!(detail.contains("size bound"), "detail: {detail}");
            }
            other => panic!("wrong error variant: {other:?}"),
        }
    }

    #[test]
    fn over_clause_budget_declines_typed() {
        let err = enforce_verify_size_bound(10, 20_000, u32::MAX, 16_384, "ctx")
            .expect_err("must decline over clause budget");
        match err {
            XlogError::CompileCapacityExceeded { detail, .. } => {
                assert!(detail.contains("20000 clauses"), "detail: {detail}");
            }
            other => panic!("wrong error variant: {other:?}"),
        }
    }

    #[test]
    fn within_budget_proceeds() {
        enforce_verify_size_bound(100, 200, 4096, 16_384, "ctx")
            .expect("within budget must pass the bound");
    }

    // The bound is inclusive: caps exactly at the budget still proceed, and
    // exceeding it by one on either axis declines.
    #[test]
    fn at_budget_proceeds_one_over_declines() {
        enforce_verify_size_bound(4096, 16_384, 4096, 16_384, "ctx")
            .expect("caps equal to the budget must pass the bound");
        assert!(matches!(
            enforce_verify_size_bound(4097, 16_384, 4096, 16_384, "ctx"),
            Err(XlogError::CompileCapacityExceeded { .. })
        ));
        assert!(matches!(
            enforce_verify_size_bound(4096, 16_385, 4096, 16_384, "ctx"),
            Err(XlogError::CompileCapacityExceeded { .. })
        ));
    }

    // Default budget is unbounded (u32::MAX) => no regression: any real CNF
    // passes the bound unless an operator opts in via env.
    #[test]
    fn unbounded_default_never_declines() {
        enforce_verify_size_bound(u32::MAX, u32::MAX, u32::MAX, u32::MAX, "ctx")
            .expect("unbounded default must not decline");
    }
}