1use std::ffi::c_void;
4use std::sync::Arc;
5
6use cudarc::driver::{DeviceSlice, LaunchConfig};
7use xlog_core::{Result, XlogError};
8use xlog_cuda::memory::TrackedCudaSlice;
9use xlog_cuda::provider::{cnf_kernels, CNF_MODULE};
10use xlog_cuda::{AsKernelParam, CudaKernelProvider, LaunchAsync};
11use xlog_solve::GpuCnf;
12
13use crate::compilation::gpu_pir::GpuPirGraph;
14use crate::compilation::gpu_pir::GpuPirRoots;
15
/// Device-resident tables mapping PIR entities to CNF variables, as produced by
/// `encode_cnf_gpu`.
///
/// All three tables are zero-filled before the assignment kernels run.
/// NOTE(review): presumably 0 means "no variable assigned" — confirm against the kernels.
pub struct GpuCnfVarTables {
    /// CNF variable per PIR node; `encode_cnf_gpu` allocates one entry per node.
    pub node_var: TrackedCudaSlice<u32>,
    /// CNF variable per leaf slot; `encode_cnf_gpu` allocates one entry per node.
    /// Presumably indexed by PIR leaf id — TODO confirm.
    pub leaf_var: TrackedCudaSlice<u32>,
    /// CNF variable per choice slot; `encode_cnf_gpu` allocates one entry per node.
    /// Presumably indexed by PIR decision variable — TODO confirm.
    pub choice_var: TrackedCudaSlice<u32>,
    /// Set by `encode_cnf_gpu` to the CNF variable *capacity* (3 × node count).
    /// This is not the number of variables actually assigned; that count lives
    /// on the device in `GpuCnf::num_vars`.
    pub max_var: u32,
}
23
/// Output of `encode_cnf_gpu`: the CNF formula together with the tables that
/// relate PIR nodes, leaves and choices to its variables.
pub struct GpuCnfEncoding {
    /// The encoded formula. Its actual variable/clause/literal counts are stored
    /// in device buffers, not on the host.
    pub cnf: GpuCnf,
    /// Node, leaf and choice variable tables.
    pub vars: GpuCnfVarTables,
    /// Single-element device buffer written by the `cnf_compute_leaf_choice_totals`
    /// kernel.
    /// NOTE(review): presumably the bound on variable ids a solver may branch on — confirm.
    pub decision_var_limit: TrackedCudaSlice<u32>,
}
35
/// Largest grid x-dimension that `checked_grid_dim` will accept.
///
/// NOTE(review): per the CUDA Programming Guide, 65,535 is the y/z grid limit and
/// the x limit only on compute capability < 3.0. Newer devices allow 2^31 - 1 in x.
/// Confirm this stricter bound is intentional: it caps 256-thread launches at
/// ~16.7M elements.
const MAX_GRID_X: u64 = 65_535;
37
38fn checked_grid_dim(n: u32, block: u32, context: &str) -> Result<u32> {
39 if block == 0 {
40 return Err(XlogError::Kernel(format!(
41 "{context}: CUDA launch block size must be nonzero"
42 )));
43 }
44 let grid = if n == 0 {
45 1
46 } else {
47 u64::from(n).div_ceil(u64::from(block))
48 };
49 if grid > MAX_GRID_X {
50 return Err(XlogError::Kernel(format!(
51 "{context}: launch grid {grid} exceeds x-dimension limit {MAX_GRID_X} \
52 for {n} elements with block size {block}"
53 )));
54 }
55 Ok(grid as u32)
56}
57
58pub fn encode_cnf_gpu(
60 pir: &GpuPirGraph,
61 roots: &GpuPirRoots,
62 provider: &Arc<CudaKernelProvider>,
63) -> Result<GpuCnfEncoding> {
64 if roots.roots.is_empty() {
65 return Err(XlogError::Compilation(
66 "Cannot encode CNF for empty PIR root set".to_string(),
67 ));
68 }
69 let num_nodes = pir.node_type.len();
70 if num_nodes == 0 {
71 return Err(XlogError::Compilation(
72 "Cannot encode CNF for empty PIR graph".to_string(),
73 ));
74 }
75
76 let num_nodes_u32 = u32::try_from(num_nodes)
77 .map_err(|_| XlogError::Compilation("PIR node count overflow".to_string()))?;
78 let num_roots_u32 = u32::try_from(roots.roots.len())
79 .map_err(|_| XlogError::Compilation("PIR root count exceeds u32::MAX".to_string()))?;
80
81 let num_edges = pir.children.len();
82 let n64 = num_nodes as u64;
83 let e64 = num_edges as u64;
84
85 let var_cap = u32::try_from(
86 n64.checked_mul(3)
87 .ok_or_else(|| XlogError::Kernel("CNF var capacity overflow".to_string()))?,
88 )
89 .map_err(|_| XlogError::Kernel("CNF var capacity exceeds u32::MAX".to_string()))?;
90 let clause_cap = u32::try_from(
91 e64.checked_add(
92 n64.checked_mul(4)
93 .ok_or_else(|| XlogError::Kernel("CNF clause capacity overflow".to_string()))?,
94 )
95 .ok_or_else(|| XlogError::Kernel("CNF clause capacity overflow".to_string()))?,
96 )
97 .map_err(|_| XlogError::Kernel("CNF clause capacity exceeds u32::MAX".to_string()))?;
98 let lit_cap =
99 u32::try_from(
100 e64.checked_mul(3)
101 .ok_or_else(|| XlogError::Kernel("CNF literal capacity overflow".to_string()))?
102 .checked_add(n64.checked_mul(12).ok_or_else(|| {
103 XlogError::Kernel("CNF literal capacity overflow".to_string())
104 })?)
105 .ok_or_else(|| XlogError::Kernel("CNF literal capacity overflow".to_string()))?,
106 )
107 .map_err(|_| XlogError::Kernel("CNF literal capacity exceeds u32::MAX".to_string()))?;
108
109 let leaf_cap = num_nodes_u32;
110 let choice_cap = num_nodes_u32;
111
112 let memory = provider.memory();
113 let device = provider.device().inner();
114
115 let mut reachable = memory.alloc::<u32>(num_nodes)?;
116 let mut queue = memory.alloc::<u32>(num_nodes)?;
117 let mut queue_ready = memory.alloc::<u32>(num_nodes)?;
118 let mut head = memory.alloc::<u32>(1)?;
119 let mut tail = memory.alloc::<u32>(1)?;
120 let mut in_flight = memory.alloc::<u32>(1)?;
121
122 let mut leaf_used = memory.alloc::<u32>(leaf_cap as usize)?;
123 let mut choice_used = memory.alloc::<u32>(choice_cap as usize)?;
124 let mut leaf_var = memory.alloc::<u32>(leaf_cap as usize)?;
125 let mut choice_var = memory.alloc::<u32>(choice_cap as usize)?;
126
127 let mut node_needs_var = memory.alloc::<u32>(num_nodes)?;
128 let mut node_var = memory.alloc::<u32>(num_nodes)?;
129
130 let mut clause_counts = memory.alloc::<u32>(num_nodes)?;
131 let mut lit_counts = memory.alloc::<u32>(num_nodes)?;
132
133 let mut leaf_prefix = memory.alloc::<u32>(leaf_cap as usize)?;
134 let mut choice_prefix = memory.alloc::<u32>(choice_cap as usize)?;
135
136 let mut node_last = memory.alloc::<u32>(1)?;
137 let mut clause_last = memory.alloc::<u32>(1)?;
138 let mut lit_last = memory.alloc::<u32>(1)?;
139
140 let mut num_leaf = memory.alloc::<u32>(1)?;
141 let mut num_choice = memory.alloc::<u32>(1)?;
142 let mut base_choice = memory.alloc::<u32>(1)?;
143 let mut base_node = memory.alloc::<u32>(1)?;
144 let mut decision_var_limit = memory.alloc::<u32>(1)?;
145
146 let d_num_vars = memory.alloc::<u32>(1)?;
147 let d_num_clauses = memory.alloc::<u32>(1)?;
148 let d_num_lits = memory.alloc::<u32>(1)?;
149
150 let mut d_offsets = memory.alloc::<u32>((clause_cap as usize) + 1)?;
151 let d_lits = memory.alloc::<i32>(lit_cap as usize)?;
152
153 device
154 .memset_zeros(&mut reachable)
155 .map_err(|e| XlogError::Kernel(format!("zero reachable: {}", e)))?;
156 device
157 .memset_zeros(&mut queue)
158 .map_err(|e| XlogError::Kernel(format!("zero queue: {}", e)))?;
159 device
160 .memset_zeros(&mut queue_ready)
161 .map_err(|e| XlogError::Kernel(format!("zero queue_ready: {}", e)))?;
162 device
163 .memset_zeros(&mut head)
164 .map_err(|e| XlogError::Kernel(format!("zero head: {}", e)))?;
165 device
166 .memset_zeros(&mut tail)
167 .map_err(|e| XlogError::Kernel(format!("zero tail: {}", e)))?;
168 device
169 .memset_zeros(&mut in_flight)
170 .map_err(|e| XlogError::Kernel(format!("zero in_flight: {}", e)))?;
171 device
172 .memset_zeros(&mut leaf_used)
173 .map_err(|e| XlogError::Kernel(format!("zero leaf_used: {}", e)))?;
174 device
175 .memset_zeros(&mut choice_used)
176 .map_err(|e| XlogError::Kernel(format!("zero choice_used: {}", e)))?;
177 device
178 .memset_zeros(&mut leaf_var)
179 .map_err(|e| XlogError::Kernel(format!("zero leaf_var: {}", e)))?;
180 device
181 .memset_zeros(&mut choice_var)
182 .map_err(|e| XlogError::Kernel(format!("zero choice_var: {}", e)))?;
183 device
184 .memset_zeros(&mut node_needs_var)
185 .map_err(|e| XlogError::Kernel(format!("zero node_needs_var: {}", e)))?;
186 device
187 .memset_zeros(&mut node_var)
188 .map_err(|e| XlogError::Kernel(format!("zero node_var: {}", e)))?;
189 device
190 .memset_zeros(&mut clause_counts)
191 .map_err(|e| XlogError::Kernel(format!("zero clause_counts: {}", e)))?;
192 device
193 .memset_zeros(&mut lit_counts)
194 .map_err(|e| XlogError::Kernel(format!("zero lit_counts: {}", e)))?;
195
196 let reach_init_fn = device
197 .get_func(CNF_MODULE, cnf_kernels::CNF_REACHABILITY_INIT)
198 .ok_or_else(|| XlogError::Kernel("cnf_reachability_init kernel not found".to_string()))?;
199 let reach_bfs_fn = device
200 .get_func(CNF_MODULE, cnf_kernels::CNF_REACHABILITY_BFS)
201 .ok_or_else(|| XlogError::Kernel("cnf_reachability_bfs kernel not found".to_string()))?;
202 let mark_leaf_choice_fn = device
203 .get_func(CNF_MODULE, cnf_kernels::CNF_MARK_LEAF_CHOICE)
204 .ok_or_else(|| XlogError::Kernel("cnf_mark_leaf_choice kernel not found".to_string()))?;
205 let assign_leaf_var_fn = device
206 .get_func(CNF_MODULE, cnf_kernels::CNF_ASSIGN_LEAF_VAR)
207 .ok_or_else(|| XlogError::Kernel("cnf_assign_leaf_var kernel not found".to_string()))?;
208 let assign_choice_var_fn = device
209 .get_func(CNF_MODULE, cnf_kernels::CNF_ASSIGN_CHOICE_VAR)
210 .ok_or_else(|| XlogError::Kernel("cnf_assign_choice_var kernel not found".to_string()))?;
211 let mark_node_vars_fn = device
212 .get_func(CNF_MODULE, cnf_kernels::CNF_MARK_NODE_VARS)
213 .ok_or_else(|| XlogError::Kernel("cnf_mark_node_vars kernel not found".to_string()))?;
214 let count_clauses_fn = device
215 .get_func(CNF_MODULE, cnf_kernels::CNF_COUNT_CLAUSES)
216 .ok_or_else(|| XlogError::Kernel("cnf_count_clauses kernel not found".to_string()))?;
217 let capture_last_fn = device
218 .get_func(CNF_MODULE, cnf_kernels::CNF_CAPTURE_LAST_COUNTS)
219 .ok_or_else(|| XlogError::Kernel("cnf_capture_last_counts kernel not found".to_string()))?;
220 let leaf_choice_totals_fn = device
221 .get_func(CNF_MODULE, cnf_kernels::CNF_COMPUTE_LEAF_CHOICE_TOTALS)
222 .ok_or_else(|| {
223 XlogError::Kernel("cnf_compute_leaf_choice_totals kernel not found".to_string())
224 })?;
225 let compute_totals_fn = device
226 .get_func(CNF_MODULE, cnf_kernels::CNF_COMPUTE_TOTALS)
227 .ok_or_else(|| XlogError::Kernel("cnf_compute_totals kernel not found".to_string()))?;
228 let assign_node_var_fn = device
229 .get_func(CNF_MODULE, cnf_kernels::CNF_ASSIGN_NODE_VAR)
230 .ok_or_else(|| XlogError::Kernel("cnf_assign_node_var kernel not found".to_string()))?;
231 let emit_clauses_fn = device
232 .get_func(CNF_MODULE, cnf_kernels::CNF_EMIT_CLAUSES)
233 .ok_or_else(|| XlogError::Kernel("cnf_emit_clauses kernel not found".to_string()))?;
234 let set_clause_end_fn = device
235 .get_func(CNF_MODULE, cnf_kernels::CNF_SET_CLAUSE_END)
236 .ok_or_else(|| XlogError::Kernel("cnf_set_clause_end kernel not found".to_string()))?;
237
238 let block = 256u32;
239
240 let grid_roots = checked_grid_dim(num_roots_u32, block, "cnf_reachability_init")?;
241 unsafe {
243 reach_init_fn.clone().launch(
244 LaunchConfig {
245 grid_dim: (grid_roots, 1, 1),
246 block_dim: (block, 1, 1),
247 shared_mem_bytes: 0,
248 },
249 (
250 &roots.roots,
251 num_roots_u32,
252 num_nodes_u32,
253 &mut reachable,
254 &mut queue,
255 &mut queue_ready,
256 &mut head,
257 &mut tail,
258 &mut in_flight,
259 ),
260 )
261 }
262 .map_err(|e| XlogError::Kernel(format!("cnf_reachability_init failed: {}", e)))?;
263
264 let grid_nodes = checked_grid_dim(num_nodes_u32, block, "cnf node kernels")?;
265 unsafe {
267 reach_bfs_fn.clone().launch(
268 LaunchConfig {
269 grid_dim: (grid_nodes, 1, 1),
270 block_dim: (block, 1, 1),
271 shared_mem_bytes: 0,
272 },
273 (
274 &pir.node_type,
275 &pir.child_offsets,
276 &pir.children,
277 &pir.decision_child_false,
278 &pir.decision_child_true,
279 num_nodes_u32,
280 &mut reachable,
281 &mut queue,
282 &mut queue_ready,
283 &mut head,
284 &mut tail,
285 &mut in_flight,
286 ),
287 )
288 }
289 .map_err(|e| XlogError::Kernel(format!("cnf_reachability_bfs failed: {}", e)))?;
290
291 unsafe {
293 mark_leaf_choice_fn.clone().launch(
294 LaunchConfig {
295 grid_dim: (grid_nodes, 1, 1),
296 block_dim: (block, 1, 1),
297 shared_mem_bytes: 0,
298 },
299 (
300 &pir.node_type,
301 &pir.leaf_id,
302 &pir.decision_var,
303 &reachable,
304 num_nodes_u32,
305 leaf_cap,
306 choice_cap,
307 &mut leaf_used,
308 &mut choice_used,
309 ),
310 )
311 }
312 .map_err(|e| XlogError::Kernel(format!("cnf_mark_leaf_choice failed: {}", e)))?;
313
314 if leaf_cap > 0 {
315 device
316 .dtod_copy(&leaf_used, &mut leaf_prefix)
317 .map_err(|e| XlogError::Kernel(format!("copy leaf_used: {}", e)))?;
318 provider.exclusive_scan_u32_inplace(&mut leaf_prefix, leaf_cap)?;
319 }
320 if choice_cap > 0 {
321 device
322 .dtod_copy(&choice_used, &mut choice_prefix)
323 .map_err(|e| XlogError::Kernel(format!("copy choice_used: {}", e)))?;
324 provider.exclusive_scan_u32_inplace(&mut choice_prefix, choice_cap)?;
325 }
326
327 unsafe {
329 leaf_choice_totals_fn.clone().launch(
330 LaunchConfig {
331 grid_dim: (1, 1, 1),
332 block_dim: (1, 1, 1),
333 shared_mem_bytes: 0,
334 },
335 (
336 &leaf_prefix,
337 &leaf_used,
338 leaf_cap,
339 &choice_prefix,
340 &choice_used,
341 choice_cap,
342 &mut num_leaf,
343 &mut num_choice,
344 &mut base_choice,
345 &mut base_node,
346 &mut decision_var_limit,
347 ),
348 )
349 }
350 .map_err(|e| XlogError::Kernel(format!("cnf_compute_leaf_choice_totals failed: {}", e)))?;
351
352 if leaf_cap > 0 {
353 let grid_leaf = checked_grid_dim(leaf_cap, block, "cnf_assign_leaf_var")?;
354 unsafe {
356 assign_leaf_var_fn.clone().launch(
357 LaunchConfig {
358 grid_dim: (grid_leaf, 1, 1),
359 block_dim: (block, 1, 1),
360 shared_mem_bytes: 0,
361 },
362 (&leaf_used, &leaf_prefix, leaf_cap, &mut leaf_var),
363 )
364 }
365 .map_err(|e| XlogError::Kernel(format!("cnf_assign_leaf_var failed: {}", e)))?;
366 }
367 if choice_cap > 0 {
368 let grid_choice = checked_grid_dim(choice_cap, block, "cnf_assign_choice_var")?;
369 unsafe {
371 assign_choice_var_fn.clone().launch(
372 LaunchConfig {
373 grid_dim: (grid_choice, 1, 1),
374 block_dim: (block, 1, 1),
375 shared_mem_bytes: 0,
376 },
377 (
378 &choice_used,
379 &choice_prefix,
380 choice_cap,
381 &base_choice,
382 &mut choice_var,
383 ),
384 )
385 }
386 .map_err(|e| XlogError::Kernel(format!("cnf_assign_choice_var failed: {}", e)))?;
387 }
388
389 unsafe {
391 mark_node_vars_fn.clone().launch(
392 LaunchConfig {
393 grid_dim: (grid_nodes, 1, 1),
394 block_dim: (block, 1, 1),
395 shared_mem_bytes: 0,
396 },
397 (
398 &pir.node_type,
399 &reachable,
400 num_nodes_u32,
401 &mut node_needs_var,
402 ),
403 )
404 }
405 .map_err(|e| XlogError::Kernel(format!("cnf_mark_node_vars failed: {}", e)))?;
406
407 unsafe {
409 count_clauses_fn.clone().launch(
410 LaunchConfig {
411 grid_dim: (grid_nodes, 1, 1),
412 block_dim: (block, 1, 1),
413 shared_mem_bytes: 0,
414 },
415 (
416 &pir.node_type,
417 &pir.child_offsets,
418 &reachable,
419 num_nodes_u32,
420 &mut clause_counts,
421 &mut lit_counts,
422 ),
423 )
424 }
425 .map_err(|e| XlogError::Kernel(format!("cnf_count_clauses failed: {}", e)))?;
426
427 unsafe {
429 capture_last_fn.clone().launch(
430 LaunchConfig {
431 grid_dim: (1, 1, 1),
432 block_dim: (1, 1, 1),
433 shared_mem_bytes: 0,
434 },
435 (
436 &node_needs_var,
437 &clause_counts,
438 &lit_counts,
439 num_nodes_u32,
440 &mut node_last,
441 &mut clause_last,
442 &mut lit_last,
443 ),
444 )
445 }
446 .map_err(|e| XlogError::Kernel(format!("cnf_capture_last_counts failed: {}", e)))?;
447
448 provider.exclusive_scan_u32_inplace(&mut node_needs_var, num_nodes_u32)?;
449 provider.exclusive_scan_u32_inplace(&mut clause_counts, num_nodes_u32)?;
450 provider.exclusive_scan_u32_inplace(&mut lit_counts, num_nodes_u32)?;
451
452 let mut totals_params: Vec<*mut c_void> = vec![
453 (&node_needs_var).as_kernel_param(),
454 (&clause_counts).as_kernel_param(),
455 (&lit_counts).as_kernel_param(),
456 (&node_last).as_kernel_param(),
457 (&clause_last).as_kernel_param(),
458 (&lit_last).as_kernel_param(),
459 num_nodes_u32.as_kernel_param(),
460 (&base_node).as_kernel_param(),
461 var_cap.as_kernel_param(),
462 clause_cap.as_kernel_param(),
463 lit_cap.as_kernel_param(),
464 (&d_num_vars).as_kernel_param(),
465 (&d_num_clauses).as_kernel_param(),
466 (&d_num_lits).as_kernel_param(),
467 ];
468 unsafe {
470 compute_totals_fn.clone().launch(
471 LaunchConfig {
472 grid_dim: (1, 1, 1),
473 block_dim: (1, 1, 1),
474 shared_mem_bytes: 0,
475 },
476 &mut totals_params,
477 )
478 }
479 .map_err(|e| XlogError::Kernel(format!("cnf_compute_totals failed: {}", e)))?;
480
481 unsafe {
483 assign_node_var_fn.clone().launch(
484 LaunchConfig {
485 grid_dim: (grid_nodes, 1, 1),
486 block_dim: (block, 1, 1),
487 shared_mem_bytes: 0,
488 },
489 (
490 &pir.node_type,
491 &pir.leaf_id,
492 &reachable,
493 &node_needs_var,
494 &base_node,
495 num_nodes_u32,
496 leaf_cap,
497 &leaf_var,
498 &mut node_var,
499 ),
500 )
501 }
502 .map_err(|e| XlogError::Kernel(format!("cnf_assign_node_var failed: {}", e)))?;
503
504 let mut emit_params: Vec<*mut c_void> = vec![
505 (&pir.node_type).as_kernel_param(),
506 (&pir.child_offsets).as_kernel_param(),
507 (&pir.children).as_kernel_param(),
508 (&pir.leaf_id).as_kernel_param(),
509 (&pir.decision_var).as_kernel_param(),
510 (&pir.decision_child_false).as_kernel_param(),
511 (&pir.decision_child_true).as_kernel_param(),
512 (&reachable).as_kernel_param(),
513 (&node_var).as_kernel_param(),
514 (&leaf_var).as_kernel_param(),
515 (&choice_var).as_kernel_param(),
516 (&clause_counts).as_kernel_param(),
517 (&lit_counts).as_kernel_param(),
518 num_nodes_u32.as_kernel_param(),
519 leaf_cap.as_kernel_param(),
520 choice_cap.as_kernel_param(),
521 (&d_offsets).as_kernel_param(),
522 (&d_lits).as_kernel_param(),
523 ];
524
525 unsafe {
527 emit_clauses_fn.clone().launch(
528 LaunchConfig {
529 grid_dim: (grid_nodes, 1, 1),
530 block_dim: (block, 1, 1),
531 shared_mem_bytes: 0,
532 },
533 &mut emit_params,
534 )
535 }
536 .map_err(|e| XlogError::Kernel(format!("cnf_emit_clauses failed: {}", e)))?;
537
538 unsafe {
540 set_clause_end_fn.clone().launch(
541 LaunchConfig {
542 grid_dim: (1, 1, 1),
543 block_dim: (1, 1, 1),
544 shared_mem_bytes: 0,
545 },
546 (&mut d_offsets, &d_num_clauses, &d_num_lits),
547 )
548 }
549 .map_err(|e| XlogError::Kernel(format!("cnf_set_clause_end failed: {}", e)))?;
550 Ok(GpuCnfEncoding {
552 cnf: GpuCnf {
553 var_cap,
554 clause_cap,
555 lit_cap,
556 num_vars: d_num_vars,
557 num_clauses: d_num_clauses,
558 num_lits: d_num_lits,
559 clause_offsets: d_offsets,
560 literals: d_lits,
561 },
562 vars: GpuCnfVarTables {
563 node_var,
564 leaf_var,
565 choice_var,
566 max_var: var_cap,
567 },
568 decision_var_limit,
569 })
570}