1use std::ffi::c_void;
6use std::sync::Arc;
7
8use cudarc::driver::{DeviceSlice, LaunchConfig};
9use xlog_core::{Result, XlogError};
10use xlog_cuda::memory::TrackedCudaSlice;
11use xlog_cuda::provider::{pir_kernels, scan_kernels, RadixSortScratch, PIR_MODULE, SCAN_MODULE};
12use xlog_cuda::{AsKernelParam, CudaKernelProvider, LaunchAsync};
13
14use crate::compilation::gpu_pir::{GpuPirGraph, PIR_AND, PIR_CONST};
15
/// Host-side batch of PIR nodes stored column-wise (one `Vec` per attribute), ready to be
/// uploaded with `PirBatch::to_device` and interned by `GpuPirInterner`.
///
/// Node `i` is described by index `i` of every per-node column. Its children are stored in
/// CSR form: `children[child_offsets[i] as usize..child_offsets[i + 1] as usize]`.
#[derive(Debug, Clone)]
pub struct PirBatch {
    /// Node kind tag per node (e.g. `PIR_AND`; the interner rejects `PIR_CONST`).
    pub node_type: Vec<u8>,
    /// Per-node leaf identifier; one entry per node.
    pub leaf_id: Vec<u32>,
    /// Per-node decision variable; one entry per node.
    pub decision_var: Vec<u32>,
    /// Per-node "false" child — presumably only meaningful for decision nodes; TODO confirm.
    pub decision_child_false: Vec<u32>,
    /// Per-node "true" child — presumably only meaningful for decision nodes; TODO confirm.
    pub decision_child_true: Vec<u32>,
    /// CSR offsets into `children`: `len() + 1` entries, the last equal to `children.len()`.
    pub child_offsets: Vec<u32>,
    /// Child ids of all nodes, concatenated in node order and indexed via `child_offsets`.
    pub children: Vec<u32>,
}
27
28impl PirBatch {
29 pub fn len(&self) -> usize {
30 self.node_type.len()
31 }
32
33 pub fn is_empty(&self) -> bool {
34 self.node_type.is_empty()
35 }
36
37 pub fn and_or_batch(children: Vec<Vec<u32>>) -> Self {
41 let num_nodes = children.len();
42 let mut node_type = vec![PIR_AND; num_nodes];
43 let leaf_id = vec![0u32; num_nodes];
44 let decision_var = vec![0u32; num_nodes];
45 let decision_child_false = vec![0u32; num_nodes];
46 let decision_child_true = vec![0u32; num_nodes];
47
48 let mut child_offsets = Vec::with_capacity(num_nodes + 1);
49 let mut flat_children = Vec::new();
50 child_offsets.push(0);
51 for kids in children {
52 flat_children.extend(kids);
53 child_offsets.push(flat_children.len() as u32);
54 }
55
56 if node_type.is_empty() {
57 node_type = Vec::new();
58 }
59
60 Self {
61 node_type,
62 leaf_id,
63 decision_var,
64 decision_child_false,
65 decision_child_true,
66 child_offsets,
67 children: flat_children,
68 }
69 }
70
71 pub fn to_device(&self, provider: &Arc<CudaKernelProvider>) -> Result<GpuPirBatch> {
72 let num_nodes = self.node_type.len();
73 if self.leaf_id.len() != num_nodes
74 || self.decision_var.len() != num_nodes
75 || self.decision_child_false.len() != num_nodes
76 || self.decision_child_true.len() != num_nodes
77 {
78 return Err(XlogError::Compilation(
79 "PirBatch: array length mismatch".to_string(),
80 ));
81 }
82 if self.child_offsets.len() != num_nodes + 1 {
83 return Err(XlogError::Compilation(
84 "PirBatch: child_offsets must be len num_nodes+1".to_string(),
85 ));
86 }
87 if let Some(&last) = self.child_offsets.last() {
88 if last as usize != self.children.len() {
89 return Err(XlogError::Compilation(
90 "PirBatch: child_offsets last entry must equal children len".to_string(),
91 ));
92 }
93 }
94
95 let memory = provider.memory();
96
97 let mut d_node_type = memory.alloc::<u8>(num_nodes)?;
98 let mut d_leaf_id = memory.alloc::<u32>(num_nodes)?;
99 let mut d_decision_var = memory.alloc::<u32>(num_nodes)?;
100 let mut d_decision_child_false = memory.alloc::<u32>(num_nodes)?;
101 let mut d_decision_child_true = memory.alloc::<u32>(num_nodes)?;
102 let mut d_child_offsets = memory.alloc::<u32>(self.child_offsets.len())?;
103 let mut d_children = memory.alloc::<u32>(self.children.len())?;
104
105 provider
106 .htod_sync_copy_into_tracked(&self.node_type, &mut d_node_type)
107 .map_err(|e| XlogError::Kernel(format!("PirBatch upload node_type: {}", e)))?;
108 provider
109 .htod_sync_copy_into_tracked(&self.leaf_id, &mut d_leaf_id)
110 .map_err(|e| XlogError::Kernel(format!("PirBatch upload leaf_id: {}", e)))?;
111 provider
112 .htod_sync_copy_into_tracked(&self.decision_var, &mut d_decision_var)
113 .map_err(|e| XlogError::Kernel(format!("PirBatch upload decision_var: {}", e)))?;
114 provider
115 .htod_sync_copy_into_tracked(&self.decision_child_false, &mut d_decision_child_false)
116 .map_err(|e| {
117 XlogError::Kernel(format!("PirBatch upload decision_child_false: {}", e))
118 })?;
119 provider
120 .htod_sync_copy_into_tracked(&self.decision_child_true, &mut d_decision_child_true)
121 .map_err(|e| {
122 XlogError::Kernel(format!("PirBatch upload decision_child_true: {}", e))
123 })?;
124 provider
125 .htod_sync_copy_into_tracked(&self.child_offsets, &mut d_child_offsets)
126 .map_err(|e| XlogError::Kernel(format!("PirBatch upload child_offsets: {}", e)))?;
127 provider
128 .htod_sync_copy_into_tracked(&self.children, &mut d_children)
129 .map_err(|e| XlogError::Kernel(format!("PirBatch upload children: {}", e)))?;
130
131 Ok(GpuPirBatch {
132 node_type: d_node_type,
133 leaf_id: d_leaf_id,
134 decision_var: d_decision_var,
135 decision_child_false: d_decision_child_false,
136 decision_child_true: d_decision_child_true,
137 child_offsets: d_child_offsets,
138 children: d_children,
139 })
140 }
141}
142
/// Device-resident copy of a `PirBatch`, produced by `PirBatch::to_device`.
///
/// Each field mirrors the host column of the same name (same length, same CSR layout).
/// Note that `GpuPirInterner::intern_device_batch` sorts `children` in place.
pub struct GpuPirBatch {
    /// Node kind tag per node.
    pub node_type: TrackedCudaSlice<u8>,
    /// Per-node leaf identifier.
    pub leaf_id: TrackedCudaSlice<u32>,
    /// Per-node decision variable.
    pub decision_var: TrackedCudaSlice<u32>,
    /// Per-node "false" child id.
    pub decision_child_false: TrackedCudaSlice<u32>,
    /// Per-node "true" child id.
    pub decision_child_true: TrackedCudaSlice<u32>,
    /// CSR offsets into `children` (`num_nodes() + 1` entries).
    pub child_offsets: TrackedCudaSlice<u32>,
    /// Flattened child ids of all nodes.
    pub children: TrackedCudaSlice<u32>,
}
153
154impl GpuPirBatch {
155 pub fn num_nodes(&self) -> usize {
156 self.node_type.len()
157 }
158
159 pub fn num_children(&self) -> usize {
160 self.children.len()
161 }
162}
163
/// Deduplicating (hash-consing) interner that merges batches of PIR nodes into one
/// device-resident `GpuPirGraph` with fixed node and child capacities.
pub struct GpuPirInterner {
    /// Provider used for allocations, host/device copies and kernel launches.
    provider: Arc<CudaKernelProvider>,
    /// Capacity of every per-node graph column; at least 2 (slots for the two const nodes).
    node_cap: u32,
    /// Capacity of the graph's flattened `children` column.
    child_cap: u32,
    /// Interned graph storage, allocated once in `new` with the capacities above.
    graph: GpuPirGraph,
    /// 64-bit hash per graph slot; the lookup table for deduplication is built from it.
    graph_hashes: TrackedCudaSlice<u64>,
    /// One-element device counter of nodes currently in `graph` (initialised to 2).
    num_nodes: TrackedCudaSlice<u32>,
    /// One-element device counter of child entries currently in `graph` (initialised to 0).
    num_children: TrackedCudaSlice<u32>,
}
174
175impl GpuPirInterner {
176 pub fn new(provider: &Arc<CudaKernelProvider>, node_cap: u32, child_cap: u32) -> Result<Self> {
177 if node_cap < 2 {
178 return Err(XlogError::Compilation(
179 "GpuPirInterner requires node_cap >= 2 for const nodes".to_string(),
180 ));
181 }
182 let memory = provider.memory();
183 let device = provider.device().inner();
184
185 let mut node_type = memory.alloc::<u8>(node_cap as usize)?;
186 let mut child_offsets = memory.alloc::<u32>((node_cap as usize) + 1)?;
187 let mut children = memory.alloc::<u32>(child_cap as usize)?;
188 let mut leaf_id = memory.alloc::<u32>(node_cap as usize)?;
189 let mut decision_var = memory.alloc::<u32>(node_cap as usize)?;
190 let mut decision_child_false = memory.alloc::<u32>(node_cap as usize)?;
191 let mut decision_child_true = memory.alloc::<u32>(node_cap as usize)?;
192 let mut graph_hashes = memory.alloc::<u64>(node_cap as usize)?;
193
194 device
195 .memset_zeros(&mut node_type)
196 .map_err(|e| XlogError::Kernel(format!("GpuPirInterner init node_type: {}", e)))?;
197 device
198 .memset_zeros(&mut child_offsets)
199 .map_err(|e| XlogError::Kernel(format!("GpuPirInterner init child_offsets: {}", e)))?;
200 device
201 .memset_zeros(&mut children)
202 .map_err(|e| XlogError::Kernel(format!("GpuPirInterner init children: {}", e)))?;
203 device
204 .memset_zeros(&mut decision_var)
205 .map_err(|e| XlogError::Kernel(format!("GpuPirInterner init decision_var: {}", e)))?;
206 device
207 .memset_zeros(&mut decision_child_false)
208 .map_err(|e| {
209 XlogError::Kernel(format!("GpuPirInterner init decision_child_false: {}", e))
210 })?;
211 device.memset_zeros(&mut decision_child_true).map_err(|e| {
212 XlogError::Kernel(format!("GpuPirInterner init decision_child_true: {}", e))
213 })?;
214 device
215 .memset_zeros(&mut graph_hashes)
216 .map_err(|e| XlogError::Kernel(format!("GpuPirInterner init graph_hashes: {}", e)))?;
217
218 let mut leaf_id_host = vec![0u32; node_cap as usize];
219 if node_cap > 1 {
220 leaf_id_host[1] = 1;
221 }
222 provider
223 .htod_launch_metadata_sync_copy_into(&leaf_id_host, &mut leaf_id)
224 .map_err(|e| XlogError::Kernel(format!("GpuPirInterner init leaf_id: {}", e)))?;
225
226 let hash_fn = device
227 .get_func(PIR_MODULE, pir_kernels::PIR_HASH_KEYS)
228 .ok_or_else(|| XlogError::Kernel("pir_hash_keys not found".to_string()))?;
229 let num_const = 2u32;
230 let block_size = 256u32;
231 let grid_const = num_const.div_ceil(block_size);
232 unsafe {
234 hash_fn.clone().launch(
235 LaunchConfig {
236 grid_dim: (grid_const.max(1), 1, 1),
237 block_dim: (block_size, 1, 1),
238 shared_mem_bytes: 0,
239 },
240 (
241 &node_type,
242 &leaf_id,
243 &decision_var,
244 &decision_child_false,
245 &decision_child_true,
246 &child_offsets,
247 &children,
248 num_const,
249 &mut graph_hashes,
250 ),
251 )
252 }
253 .map_err(|e| XlogError::Kernel(format!("GpuPirInterner init hash: {}", e)))?;
254
255 let mut num_nodes = memory.alloc::<u32>(1)?;
256 let mut num_children = memory.alloc::<u32>(1)?;
257 provider
258 .htod_launch_metadata_sync_copy_into(&[2u32], &mut num_nodes)
259 .map_err(|e| XlogError::Kernel(format!("GpuPirInterner init num_nodes: {}", e)))?;
260 provider
261 .htod_launch_metadata_sync_copy_into(&[0u32], &mut num_children)
262 .map_err(|e| XlogError::Kernel(format!("GpuPirInterner init num_children: {}", e)))?;
263
264 Ok(Self {
265 provider: Arc::clone(provider),
266 node_cap,
267 child_cap,
268 graph: GpuPirGraph {
269 node_type,
270 child_offsets,
271 children,
272 leaf_id,
273 decision_var,
274 decision_child_false,
275 decision_child_true,
276 },
277 graph_hashes,
278 num_nodes,
279 num_children,
280 })
281 }
282
283 pub fn graph(&self) -> &GpuPirGraph {
284 &self.graph
285 }
286
287 pub fn intern_batch(&mut self, batch: &PirBatch) -> Result<TrackedCudaSlice<u32>> {
288 if batch.node_type.contains(&PIR_CONST) {
289 return Err(XlogError::Compilation(
290 "GpuPirInterner does not accept PIR_CONST in batches".to_string(),
291 ));
292 }
293 let mut device_batch = batch.to_device(&self.provider)?;
294 self.intern_device_batch(&mut device_batch)
295 }
296
297 pub fn intern_device_batch(
298 &mut self,
299 batch: &mut GpuPirBatch,
300 ) -> Result<TrackedCudaSlice<u32>> {
301 let num_nodes = batch.num_nodes();
302 if num_nodes == 0 {
303 return self.provider.memory().alloc::<u32>(0);
304 }
305 let num_nodes_u32 = u32::try_from(num_nodes).map_err(|_| {
306 XlogError::Compilation("GpuPirInterner: num_nodes overflow".to_string())
307 })?;
308 let num_children = batch.num_children();
309 let num_children_u32 = u32::try_from(num_children).map_err(|_| {
310 XlogError::Compilation("GpuPirInterner: num_children overflow".to_string())
311 })?;
312
313 let device = self.provider.device().inner();
314 let memory = self.provider.memory();
315 let block_size = 256u32;
316
317 let mut canon_child_offsets = memory.alloc::<u32>(num_nodes + 1)?;
319 let mut canon_children = memory.alloc::<u32>(num_children)?;
320
321 if num_children_u32 == 0 {
322 device.memset_zeros(&mut canon_child_offsets).map_err(|e| {
323 XlogError::Kernel(format!("GpuPirInterner zero child_offsets: {}", e))
324 })?;
325 } else {
326 let mut parent_ids = memory.alloc::<u32>(num_children)?;
327 let fill_fn = device
328 .get_func(PIR_MODULE, pir_kernels::PIR_FILL_CHILD_PARENTS)
329 .ok_or_else(|| XlogError::Kernel("pir_fill_child_parents not found".to_string()))?;
330 let grid_nodes = num_nodes_u32.div_ceil(block_size);
331 unsafe {
333 fill_fn.clone().launch(
334 LaunchConfig {
335 grid_dim: (grid_nodes, 1, 1),
336 block_dim: (block_size, 1, 1),
337 shared_mem_bytes: 0,
338 },
339 (&batch.child_offsets, num_nodes_u32, &mut parent_ids),
340 )
341 }
342 .map_err(|e| XlogError::Kernel(format!("pir_fill_child_parents failed: {}", e)))?;
343
344 let mut sort_scratch = RadixSortScratch::new(&self.provider, num_children_u32)?;
345 self.provider.radix_sort_u32_pairs(
346 &mut batch.children,
347 &mut parent_ids,
348 num_children_u32,
349 &mut sort_scratch,
350 )?;
351 self.provider.radix_sort_u32_pairs(
352 &mut parent_ids,
353 &mut batch.children,
354 num_children_u32,
355 &mut sort_scratch,
356 )?;
357
358 let mut pair_unique_mask = memory.alloc::<u8>(num_children)?;
359 let mark_pairs = device
360 .get_func(PIR_MODULE, pir_kernels::PIR_MARK_UNIQUE_PAIRS)
361 .ok_or_else(|| XlogError::Kernel("pir_mark_unique_pairs not found".to_string()))?;
362 let grid_pairs = num_children_u32.div_ceil(block_size);
363 unsafe {
365 mark_pairs.clone().launch(
366 LaunchConfig {
367 grid_dim: (grid_pairs, 1, 1),
368 block_dim: (block_size, 1, 1),
369 shared_mem_bytes: 0,
370 },
371 (
372 &parent_ids,
373 &batch.children,
374 num_children_u32,
375 &mut pair_unique_mask,
376 ),
377 )
378 }
379 .map_err(|e| XlogError::Kernel(format!("pir_mark_unique_pairs failed: {}", e)))?;
380
381 let pair_prefix = self
382 .provider
383 .scan_u8_mask_device(&pair_unique_mask, num_children_u32)?;
384
385 let mut unique_pairs_total = memory.alloc::<u32>(1)?;
386 device
387 .memset_zeros(&mut unique_pairs_total)
388 .map_err(|e| XlogError::Kernel(format!("zero unique_pairs_total: {}", e)))?;
389 let count_mask = device
390 .get_func(SCAN_MODULE, scan_kernels::COUNT_MASK)
391 .ok_or_else(|| XlogError::Kernel("count_mask not found".to_string()))?;
392 unsafe {
394 count_mask.clone().launch(
395 LaunchConfig {
396 grid_dim: (grid_pairs, 1, 1),
397 block_dim: (block_size, 1, 1),
398 shared_mem_bytes: 0,
399 },
400 (&pair_unique_mask, num_children_u32, &mut unique_pairs_total),
401 )
402 }
403 .map_err(|e| XlogError::Kernel(format!("count_mask (pairs) failed: {}", e)))?;
404
405 let mut canon_parent = memory.alloc::<u32>(num_children)?;
406 let compact_pairs = device
407 .get_func(PIR_MODULE, pir_kernels::PIR_COMPACT_PAIRS)
408 .ok_or_else(|| XlogError::Kernel("pir_compact_pairs not found".to_string()))?;
409 unsafe {
411 compact_pairs.clone().launch(
412 LaunchConfig {
413 grid_dim: (grid_pairs, 1, 1),
414 block_dim: (block_size, 1, 1),
415 shared_mem_bytes: 0,
416 },
417 (
418 &parent_ids,
419 &batch.children,
420 &pair_unique_mask,
421 &pair_prefix,
422 num_children_u32,
423 &mut canon_parent,
424 &mut canon_children,
425 ),
426 )
427 }
428 .map_err(|e| XlogError::Kernel(format!("pir_compact_pairs failed: {}", e)))?;
429
430 let mut child_counts = memory.alloc::<u32>(num_nodes)?;
431 device
432 .memset_zeros(&mut child_counts)
433 .map_err(|e| XlogError::Kernel(format!("zero child_counts: {}", e)))?;
434 let count_children = device
435 .get_func(PIR_MODULE, pir_kernels::PIR_COUNT_CHILDREN)
436 .ok_or_else(|| XlogError::Kernel("pir_count_children not found".to_string()))?;
437 unsafe {
439 count_children.clone().launch(
440 LaunchConfig {
441 grid_dim: (grid_pairs, 1, 1),
442 block_dim: (block_size, 1, 1),
443 shared_mem_bytes: 0,
444 },
445 (
446 &canon_parent,
447 &unique_pairs_total,
448 num_nodes_u32,
449 &mut child_counts,
450 ),
451 )
452 }
453 .map_err(|e| XlogError::Kernel(format!("pir_count_children failed: {}", e)))?;
454
455 self.provider
456 .exclusive_scan_u32_inplace(&mut child_counts, num_nodes_u32)?;
457
458 let write_offsets = device
459 .get_func(PIR_MODULE, pir_kernels::PIR_WRITE_CHILD_OFFSETS)
460 .ok_or_else(|| {
461 XlogError::Kernel("pir_write_child_offsets not found".to_string())
462 })?;
463 let grid_nodes = num_nodes_u32.div_ceil(block_size);
464 unsafe {
466 write_offsets.clone().launch(
467 LaunchConfig {
468 grid_dim: (grid_nodes.max(1), 1, 1),
469 block_dim: (block_size, 1, 1),
470 shared_mem_bytes: 0,
471 },
472 (
473 &child_counts,
474 num_nodes_u32,
475 &unique_pairs_total,
476 &mut canon_child_offsets,
477 ),
478 )
479 }
480 .map_err(|e| XlogError::Kernel(format!("pir_write_child_offsets failed: {}", e)))?;
481 }
482
483 let mut hashes = memory.alloc::<u64>(num_nodes)?;
485 let hash_fn = device
486 .get_func(PIR_MODULE, pir_kernels::PIR_HASH_KEYS)
487 .ok_or_else(|| XlogError::Kernel("pir_hash_keys not found".to_string()))?;
488 let grid_nodes = num_nodes_u32.div_ceil(block_size);
489 unsafe {
491 hash_fn.clone().launch(
492 LaunchConfig {
493 grid_dim: (grid_nodes, 1, 1),
494 block_dim: (block_size, 1, 1),
495 shared_mem_bytes: 0,
496 },
497 (
498 &batch.node_type,
499 &batch.leaf_id,
500 &batch.decision_var,
501 &batch.decision_child_false,
502 &batch.decision_child_true,
503 &canon_child_offsets,
504 &canon_children,
505 num_nodes_u32,
506 &mut hashes,
507 ),
508 )
509 }
510 .map_err(|e| XlogError::Kernel(format!("pir_hash_keys failed: {}", e)))?;
511
512 let mut key_tag = memory.alloc::<u32>(num_nodes)?;
513 let mut key_payload = memory.alloc::<u32>(num_nodes)?;
514 let mut key_child_len = memory.alloc::<u32>(num_nodes)?;
515 let pack_fn = device
516 .get_func(PIR_MODULE, pir_kernels::PIR_PACK_KEYS)
517 .ok_or_else(|| XlogError::Kernel("pir_pack_keys not found".to_string()))?;
518 unsafe {
520 pack_fn.clone().launch(
521 LaunchConfig {
522 grid_dim: (grid_nodes, 1, 1),
523 block_dim: (block_size, 1, 1),
524 shared_mem_bytes: 0,
525 },
526 (
527 &batch.node_type,
528 &batch.leaf_id,
529 &batch.decision_var,
530 &batch.decision_child_false,
531 &batch.decision_child_true,
532 &canon_child_offsets,
533 num_nodes_u32,
534 &mut key_tag,
535 &mut key_payload,
536 &mut key_child_len,
537 ),
538 )
539 }
540 .map_err(|e| XlogError::Kernel(format!("pir_pack_keys failed: {}", e)))?;
541
542 let mut indices = memory.alloc::<u32>(num_nodes)?;
544 self.provider.init_indices(&mut indices, num_nodes_u32)?;
545 let mut keys = memory.alloc::<u32>(num_nodes)?;
546 let mut node_sort = RadixSortScratch::new(&self.provider, num_nodes_u32)?;
547
548 self.provider
549 .gather_u32_by_indices(&key_child_len, &indices, &mut keys, num_nodes_u32)?;
550 self.provider.radix_sort_u32_pairs(
551 &mut keys,
552 &mut indices,
553 num_nodes_u32,
554 &mut node_sort,
555 )?;
556
557 self.provider
558 .gather_u32_by_indices(&key_payload, &indices, &mut keys, num_nodes_u32)?;
559 self.provider.radix_sort_u32_pairs(
560 &mut keys,
561 &mut indices,
562 num_nodes_u32,
563 &mut node_sort,
564 )?;
565
566 self.provider
567 .gather_u32_by_indices(&key_tag, &indices, &mut keys, num_nodes_u32)?;
568 self.provider.radix_sort_u32_pairs(
569 &mut keys,
570 &mut indices,
571 num_nodes_u32,
572 &mut node_sort,
573 )?;
574
575 self.provider
576 .gather_u64_lo_by_indices(&hashes, &indices, &mut keys, num_nodes_u32)?;
577 self.provider.radix_sort_u32_pairs(
578 &mut keys,
579 &mut indices,
580 num_nodes_u32,
581 &mut node_sort,
582 )?;
583
584 self.provider
585 .gather_u64_hi_by_indices(&hashes, &indices, &mut keys, num_nodes_u32)?;
586 self.provider.radix_sort_u32_pairs(
587 &mut keys,
588 &mut indices,
589 num_nodes_u32,
590 &mut node_sort,
591 )?;
592
593 let mut sorted_node_type = memory.alloc::<u8>(num_nodes)?;
595 let mut sorted_leaf_id = memory.alloc::<u32>(num_nodes)?;
596 let mut sorted_decision_var = memory.alloc::<u32>(num_nodes)?;
597 let mut sorted_decision_child_false = memory.alloc::<u32>(num_nodes)?;
598 let mut sorted_decision_child_true = memory.alloc::<u32>(num_nodes)?;
599
600 self.provider.gather_u8_by_indices(
601 &batch.node_type,
602 &indices,
603 &mut sorted_node_type,
604 num_nodes_u32,
605 )?;
606 self.provider.gather_u32_by_indices(
607 &batch.leaf_id,
608 &indices,
609 &mut sorted_leaf_id,
610 num_nodes_u32,
611 )?;
612 self.provider.gather_u32_by_indices(
613 &batch.decision_var,
614 &indices,
615 &mut sorted_decision_var,
616 num_nodes_u32,
617 )?;
618 self.provider.gather_u32_by_indices(
619 &batch.decision_child_false,
620 &indices,
621 &mut sorted_decision_child_false,
622 num_nodes_u32,
623 )?;
624 self.provider.gather_u32_by_indices(
625 &batch.decision_child_true,
626 &indices,
627 &mut sorted_decision_child_true,
628 num_nodes_u32,
629 )?;
630
631 let mut sorted_child_len = memory.alloc::<u32>(num_nodes)?;
633 self.provider.gather_u32_by_indices(
634 &key_child_len,
635 &indices,
636 &mut sorted_child_len,
637 num_nodes_u32,
638 )?;
639 self.provider
640 .exclusive_scan_u32_inplace(&mut sorted_child_len, num_nodes_u32)?;
641
642 let mut sorted_child_offsets = memory.alloc::<u32>(num_nodes + 1)?;
643 let write_offsets = device
644 .get_func(PIR_MODULE, pir_kernels::PIR_WRITE_CHILD_OFFSETS)
645 .ok_or_else(|| XlogError::Kernel("pir_write_child_offsets not found".to_string()))?;
646 let total_children_view = canon_child_offsets.slice(num_nodes..(num_nodes + 1));
647 unsafe {
649 write_offsets.clone().launch(
650 LaunchConfig {
651 grid_dim: (grid_nodes.max(1), 1, 1),
652 block_dim: (block_size, 1, 1),
653 shared_mem_bytes: 0,
654 },
655 (
656 &sorted_child_len,
657 num_nodes_u32,
658 &total_children_view,
659 &mut sorted_child_offsets,
660 ),
661 )
662 }
663 .map_err(|e| XlogError::Kernel(format!("pir_write_child_offsets(sorted) failed: {}", e)))?;
664
665 let mut sorted_children = memory.alloc::<u32>(num_children)?;
666 let gather_children = device
667 .get_func(PIR_MODULE, pir_kernels::PIR_GATHER_CHILDREN)
668 .ok_or_else(|| XlogError::Kernel("pir_gather_children not found".to_string()))?;
669 unsafe {
671 gather_children.clone().launch(
672 LaunchConfig {
673 grid_dim: (grid_nodes, 1, 1),
674 block_dim: (block_size, 1, 1),
675 shared_mem_bytes: 0,
676 },
677 (
678 &indices,
679 &canon_child_offsets,
680 &canon_children,
681 &sorted_child_offsets,
682 num_nodes_u32,
683 &mut sorted_children,
684 ),
685 )
686 }
687 .map_err(|e| XlogError::Kernel(format!("pir_gather_children failed: {}", e)))?;
688
689 let mut sorted_hashes = memory.alloc::<u64>(num_nodes)?;
691 let hash_fn = device
692 .get_func(PIR_MODULE, pir_kernels::PIR_HASH_KEYS)
693 .ok_or_else(|| XlogError::Kernel("pir_hash_keys not found".to_string()))?;
694 unsafe {
696 hash_fn.clone().launch(
697 LaunchConfig {
698 grid_dim: (grid_nodes, 1, 1),
699 block_dim: (block_size, 1, 1),
700 shared_mem_bytes: 0,
701 },
702 (
703 &sorted_node_type,
704 &sorted_leaf_id,
705 &sorted_decision_var,
706 &sorted_decision_child_false,
707 &sorted_decision_child_true,
708 &sorted_child_offsets,
709 &sorted_children,
710 num_nodes_u32,
711 &mut sorted_hashes,
712 ),
713 )
714 }
715 .map_err(|e| XlogError::Kernel(format!("pir_hash_keys(sorted) failed: {}", e)))?;
716
717 let hash_table = self
718 .provider
719 .build_hash_table_u64(&self.graph_hashes, self.node_cap)?;
720 let existing_id = memory.alloc::<u32>(num_nodes)?;
721 let find_existing = device
722 .get_func(PIR_MODULE, pir_kernels::PIR_FIND_EXISTING)
723 .ok_or_else(|| XlogError::Kernel("pir_find_existing not found".to_string()))?;
724 let mut find_params: Vec<*mut c_void> = vec![
725 (&sorted_hashes).as_kernel_param(),
726 (&sorted_node_type).as_kernel_param(),
727 (&sorted_leaf_id).as_kernel_param(),
728 (&sorted_decision_var).as_kernel_param(),
729 (&sorted_decision_child_false).as_kernel_param(),
730 (&sorted_decision_child_true).as_kernel_param(),
731 (&sorted_child_offsets).as_kernel_param(),
732 (&sorted_children).as_kernel_param(),
733 num_nodes_u32.as_kernel_param(),
734 (&self.graph.node_type).as_kernel_param(),
735 (&self.graph.child_offsets).as_kernel_param(),
736 (&self.graph.children).as_kernel_param(),
737 (&self.graph.leaf_id).as_kernel_param(),
738 (&self.graph.decision_var).as_kernel_param(),
739 (&self.graph.decision_child_false).as_kernel_param(),
740 (&self.graph.decision_child_true).as_kernel_param(),
741 (&self.num_nodes).as_kernel_param(),
742 (&hash_table.bucket_offsets).as_kernel_param(),
743 (&hash_table.bucket_counts).as_kernel_param(),
744 (&hash_table.bucket_entries).as_kernel_param(),
745 (&hash_table.bucket_entry_hashes).as_kernel_param(),
746 hash_table.bucket_mask.as_kernel_param(),
747 (&existing_id).as_kernel_param(),
748 ];
749 unsafe {
751 find_existing.clone().launch(
752 LaunchConfig {
753 grid_dim: (grid_nodes, 1, 1),
754 block_dim: (block_size, 1, 1),
755 shared_mem_bytes: 0,
756 },
757 &mut find_params,
758 )
759 }
760 .map_err(|e| XlogError::Kernel(format!("pir_find_existing failed: {}", e)))?;
761
762 let mut unique_mask = memory.alloc::<u8>(num_nodes)?;
764 let mark_unique = device
765 .get_func(PIR_MODULE, pir_kernels::PIR_MARK_UNIQUE)
766 .ok_or_else(|| XlogError::Kernel("pir_mark_unique not found".to_string()))?;
767 unsafe {
769 mark_unique.clone().launch(
770 LaunchConfig {
771 grid_dim: (grid_nodes, 1, 1),
772 block_dim: (block_size, 1, 1),
773 shared_mem_bytes: 0,
774 },
775 (
776 &sorted_hashes,
777 &sorted_node_type,
778 &sorted_leaf_id,
779 &sorted_decision_var,
780 &sorted_decision_child_false,
781 &sorted_decision_child_true,
782 &sorted_child_offsets,
783 &sorted_children,
784 num_nodes_u32,
785 &mut unique_mask,
786 ),
787 )
788 }
789 .map_err(|e| XlogError::Kernel(format!("pir_mark_unique failed: {}", e)))?;
790
791 let mut new_mask = memory.alloc::<u8>(num_nodes)?;
792 let mark_new = device
793 .get_func(PIR_MODULE, pir_kernels::PIR_MARK_NEW_GROUPS)
794 .ok_or_else(|| XlogError::Kernel("pir_mark_new_groups not found".to_string()))?;
795 unsafe {
797 mark_new.clone().launch(
798 LaunchConfig {
799 grid_dim: (grid_nodes, 1, 1),
800 block_dim: (block_size, 1, 1),
801 shared_mem_bytes: 0,
802 },
803 (&unique_mask, &existing_id, num_nodes_u32, &mut new_mask),
804 )
805 }
806 .map_err(|e| XlogError::Kernel(format!("pir_mark_new_groups failed: {}", e)))?;
807
808 let unique_prefix = self
809 .provider
810 .scan_u8_mask_device(&unique_mask, num_nodes_u32)?;
811
812 let new_prefix = self
813 .provider
814 .scan_u8_mask_device(&new_mask, num_nodes_u32)?;
815
816 let mut new_nodes_total = memory.alloc::<u32>(1)?;
817 device
818 .memset_zeros(&mut new_nodes_total)
819 .map_err(|e| XlogError::Kernel(format!("zero new_nodes_total: {}", e)))?;
820 let count_mask = device
821 .get_func(SCAN_MODULE, scan_kernels::COUNT_MASK)
822 .ok_or_else(|| XlogError::Kernel("count_mask not found".to_string()))?;
823 unsafe {
825 count_mask.clone().launch(
826 LaunchConfig {
827 grid_dim: (grid_nodes, 1, 1),
828 block_dim: (block_size, 1, 1),
829 shared_mem_bytes: 0,
830 },
831 (&new_mask, num_nodes_u32, &mut new_nodes_total),
832 )
833 }
834 .map_err(|e| XlogError::Kernel(format!("count_mask (new nodes) failed: {}", e)))?;
835
836 let mut group_node_id = memory.alloc::<u32>(num_nodes)?;
837 let build_group = device
838 .get_func(PIR_MODULE, pir_kernels::PIR_BUILD_GROUP_IDS)
839 .ok_or_else(|| XlogError::Kernel("pir_build_group_ids not found".to_string()))?;
840 unsafe {
842 build_group.clone().launch(
843 LaunchConfig {
844 grid_dim: (grid_nodes, 1, 1),
845 block_dim: (block_size, 1, 1),
846 shared_mem_bytes: 0,
847 },
848 (
849 &unique_mask,
850 &unique_prefix,
851 &existing_id,
852 &new_prefix,
853 &self.num_nodes,
854 num_nodes_u32,
855 &mut group_node_id,
856 ),
857 )
858 }
859 .map_err(|e| XlogError::Kernel(format!("pir_build_group_ids failed: {}", e)))?;
860
861 let mut graph_child_counts = memory.alloc::<u32>(num_nodes)?;
862 let build_counts = device
863 .get_func(PIR_MODULE, pir_kernels::PIR_BUILD_GRAPH_CHILD_COUNTS)
864 .ok_or_else(|| {
865 XlogError::Kernel("pir_build_graph_child_counts not found".to_string())
866 })?;
867 unsafe {
869 build_counts.clone().launch(
870 LaunchConfig {
871 grid_dim: (grid_nodes, 1, 1),
872 block_dim: (block_size, 1, 1),
873 shared_mem_bytes: 0,
874 },
875 (
876 &sorted_node_type,
877 &sorted_child_offsets,
878 &new_mask,
879 num_nodes_u32,
880 &mut graph_child_counts,
881 ),
882 )
883 }
884 .map_err(|e| XlogError::Kernel(format!("pir_build_graph_child_counts failed: {}", e)))?;
885
886 let mut graph_children_total = memory.alloc::<u32>(1)?;
887 device
888 .memset_zeros(&mut graph_children_total)
889 .map_err(|e| XlogError::Kernel(format!("zero graph_children_total: {}", e)))?;
890 let sum_counts = device
891 .get_func(PIR_MODULE, pir_kernels::PIR_SUM_COUNTS)
892 .ok_or_else(|| XlogError::Kernel("pir_sum_counts not found".to_string()))?;
893 unsafe {
895 sum_counts.clone().launch(
896 LaunchConfig {
897 grid_dim: (grid_nodes, 1, 1),
898 block_dim: (block_size, 1, 1),
899 shared_mem_bytes: 0,
900 },
901 (
902 &graph_child_counts,
903 num_nodes_u32,
904 &mut graph_children_total,
905 ),
906 )
907 }
908 .map_err(|e| XlogError::Kernel(format!("pir_sum_counts failed: {}", e)))?;
909
910 self.provider
911 .exclusive_scan_u32_inplace(&mut graph_child_counts, num_nodes_u32)?;
912
913 let out_ids = memory.alloc::<u32>(num_nodes)?;
914 let emit = device
915 .get_func(PIR_MODULE, pir_kernels::PIR_EMIT_NODES_AND_IDS)
916 .ok_or_else(|| XlogError::Kernel("pir_emit_nodes_and_ids not found".to_string()))?;
917 let mut emit_params: Vec<*mut c_void> = vec![
918 (&sorted_node_type).as_kernel_param(),
919 (&sorted_leaf_id).as_kernel_param(),
920 (&sorted_decision_var).as_kernel_param(),
921 (&sorted_decision_child_false).as_kernel_param(),
922 (&sorted_decision_child_true).as_kernel_param(),
923 (&sorted_child_offsets).as_kernel_param(),
924 (&sorted_children).as_kernel_param(),
925 (&unique_mask).as_kernel_param(),
926 (&unique_prefix).as_kernel_param(),
927 (&group_node_id).as_kernel_param(),
928 (&graph_child_counts).as_kernel_param(),
929 (&indices).as_kernel_param(),
930 num_nodes_u32.as_kernel_param(),
931 (&self.num_nodes).as_kernel_param(),
932 (&self.num_children).as_kernel_param(),
933 self.node_cap.as_kernel_param(),
934 self.child_cap.as_kernel_param(),
935 (&self.graph.node_type).as_kernel_param(),
936 (&self.graph.child_offsets).as_kernel_param(),
937 (&self.graph.children).as_kernel_param(),
938 (&self.graph.leaf_id).as_kernel_param(),
939 (&self.graph.decision_var).as_kernel_param(),
940 (&self.graph.decision_child_false).as_kernel_param(),
941 (&self.graph.decision_child_true).as_kernel_param(),
942 (&new_mask).as_kernel_param(),
943 (&sorted_hashes).as_kernel_param(),
944 (&self.graph_hashes).as_kernel_param(),
945 (&out_ids).as_kernel_param(),
946 ];
947 unsafe {
949 emit.clone().launch(
950 LaunchConfig {
951 grid_dim: (grid_nodes, 1, 1),
952 block_dim: (block_size, 1, 1),
953 shared_mem_bytes: 0,
954 },
955 &mut emit_params,
956 )
957 }
958 .map_err(|e| XlogError::Kernel(format!("pir_emit_nodes_and_ids failed: {}", e)))?;
959
960 let update_counts = device
961 .get_func(PIR_MODULE, pir_kernels::PIR_UPDATE_COUNTS)
962 .ok_or_else(|| XlogError::Kernel("pir_update_counts not found".to_string()))?;
963 unsafe {
965 update_counts.clone().launch(
966 LaunchConfig {
967 grid_dim: (1, 1, 1),
968 block_dim: (1, 1, 1),
969 shared_mem_bytes: 0,
970 },
971 (
972 &new_nodes_total,
973 &graph_children_total,
974 self.node_cap,
975 self.child_cap,
976 &mut self.num_nodes,
977 &mut self.num_children,
978 ),
979 )
980 }
981 .map_err(|e| XlogError::Kernel(format!("pir_update_counts failed: {}", e)))?;
982 Ok(out_ids)
984 }
985}