Skip to main content

xlog_prob/compilation/
gpu_pir_intern.rs

1//! GPU PIR interner (device-side hash-consing).
2//!
3//! Implements deterministic, memory-bounded interning of PIR node batches on GPU.
4
5use std::ffi::c_void;
6use std::sync::Arc;
7
8use cudarc::driver::{DeviceSlice, LaunchConfig};
9use xlog_core::{Result, XlogError};
10use xlog_cuda::memory::TrackedCudaSlice;
11use xlog_cuda::provider::{pir_kernels, scan_kernels, RadixSortScratch, PIR_MODULE, SCAN_MODULE};
12use xlog_cuda::{AsKernelParam, CudaKernelProvider, LaunchAsync};
13
14use crate::compilation::gpu_pir::{GpuPirGraph, PIR_AND, PIR_CONST};
15
/// Host-side PIR batch (for tests and host-driven workflows).
///
/// Nodes are stored column-wise: every per-node vector has one entry per node,
/// and child lists are flattened into `children` and addressed through
/// `child_offsets`. `to_device` rejects batches whose per-node vectors differ in
/// length from `node_type`.
#[derive(Debug, Clone)]
pub struct PirBatch {
    /// Per-node type tag (for example `PIR_AND`); its length defines the node count.
    pub node_type: Vec<u8>,
    /// Per-node leaf identifier (presumably only meaningful for leaf nodes — confirm).
    pub leaf_id: Vec<u32>,
    /// Per-node decision variable (presumably only meaningful for decision nodes — confirm).
    pub decision_var: Vec<u32>,
    /// Per-node "false" child of a decision node.
    pub decision_child_false: Vec<u32>,
    /// Per-node "true" child of a decision node.
    pub decision_child_true: Vec<u32>,
    /// Offsets into `children`; must have `len() + 1` entries and its last
    /// entry must equal `children.len()`.
    pub child_offsets: Vec<u32>,
    /// Child lists of all nodes, concatenated in node order.
    pub children: Vec<u32>,
}
27
28impl PirBatch {
29    pub fn len(&self) -> usize {
30        self.node_type.len()
31    }
32
33    pub fn is_empty(&self) -> bool {
34        self.node_type.is_empty()
35    }
36
37    /// Build a batch of AND nodes with the given child lists.
38    ///
39    /// This helper is primarily used by tests.
40    pub fn and_or_batch(children: Vec<Vec<u32>>) -> Self {
41        let num_nodes = children.len();
42        let mut node_type = vec![PIR_AND; num_nodes];
43        let leaf_id = vec![0u32; num_nodes];
44        let decision_var = vec![0u32; num_nodes];
45        let decision_child_false = vec![0u32; num_nodes];
46        let decision_child_true = vec![0u32; num_nodes];
47
48        let mut child_offsets = Vec::with_capacity(num_nodes + 1);
49        let mut flat_children = Vec::new();
50        child_offsets.push(0);
51        for kids in children {
52            flat_children.extend(kids);
53            child_offsets.push(flat_children.len() as u32);
54        }
55
56        if node_type.is_empty() {
57            node_type = Vec::new();
58        }
59
60        Self {
61            node_type,
62            leaf_id,
63            decision_var,
64            decision_child_false,
65            decision_child_true,
66            child_offsets,
67            children: flat_children,
68        }
69    }
70
71    pub fn to_device(&self, provider: &Arc<CudaKernelProvider>) -> Result<GpuPirBatch> {
72        let num_nodes = self.node_type.len();
73        if self.leaf_id.len() != num_nodes
74            || self.decision_var.len() != num_nodes
75            || self.decision_child_false.len() != num_nodes
76            || self.decision_child_true.len() != num_nodes
77        {
78            return Err(XlogError::Compilation(
79                "PirBatch: array length mismatch".to_string(),
80            ));
81        }
82        if self.child_offsets.len() != num_nodes + 1 {
83            return Err(XlogError::Compilation(
84                "PirBatch: child_offsets must be len num_nodes+1".to_string(),
85            ));
86        }
87        if let Some(&last) = self.child_offsets.last() {
88            if last as usize != self.children.len() {
89                return Err(XlogError::Compilation(
90                    "PirBatch: child_offsets last entry must equal children len".to_string(),
91                ));
92            }
93        }
94
95        let memory = provider.memory();
96
97        let mut d_node_type = memory.alloc::<u8>(num_nodes)?;
98        let mut d_leaf_id = memory.alloc::<u32>(num_nodes)?;
99        let mut d_decision_var = memory.alloc::<u32>(num_nodes)?;
100        let mut d_decision_child_false = memory.alloc::<u32>(num_nodes)?;
101        let mut d_decision_child_true = memory.alloc::<u32>(num_nodes)?;
102        let mut d_child_offsets = memory.alloc::<u32>(self.child_offsets.len())?;
103        let mut d_children = memory.alloc::<u32>(self.children.len())?;
104
105        provider
106            .htod_sync_copy_into_tracked(&self.node_type, &mut d_node_type)
107            .map_err(|e| XlogError::Kernel(format!("PirBatch upload node_type: {}", e)))?;
108        provider
109            .htod_sync_copy_into_tracked(&self.leaf_id, &mut d_leaf_id)
110            .map_err(|e| XlogError::Kernel(format!("PirBatch upload leaf_id: {}", e)))?;
111        provider
112            .htod_sync_copy_into_tracked(&self.decision_var, &mut d_decision_var)
113            .map_err(|e| XlogError::Kernel(format!("PirBatch upload decision_var: {}", e)))?;
114        provider
115            .htod_sync_copy_into_tracked(&self.decision_child_false, &mut d_decision_child_false)
116            .map_err(|e| {
117                XlogError::Kernel(format!("PirBatch upload decision_child_false: {}", e))
118            })?;
119        provider
120            .htod_sync_copy_into_tracked(&self.decision_child_true, &mut d_decision_child_true)
121            .map_err(|e| {
122                XlogError::Kernel(format!("PirBatch upload decision_child_true: {}", e))
123            })?;
124        provider
125            .htod_sync_copy_into_tracked(&self.child_offsets, &mut d_child_offsets)
126            .map_err(|e| XlogError::Kernel(format!("PirBatch upload child_offsets: {}", e)))?;
127        provider
128            .htod_sync_copy_into_tracked(&self.children, &mut d_children)
129            .map_err(|e| XlogError::Kernel(format!("PirBatch upload children: {}", e)))?;
130
131        Ok(GpuPirBatch {
132            node_type: d_node_type,
133            leaf_id: d_leaf_id,
134            decision_var: d_decision_var,
135            decision_child_false: d_decision_child_false,
136            decision_child_true: d_decision_child_true,
137            child_offsets: d_child_offsets,
138            children: d_children,
139        })
140    }
141}
142
/// Device-resident PIR batch.
///
/// Column-for-column device counterpart of `PirBatch`; `PirBatch::to_device`
/// uploads each host vector into the field of the same name.
pub struct GpuPirBatch {
    /// Per-node type tag; its length defines the node count.
    pub node_type: TrackedCudaSlice<u8>,
    /// Per-node leaf identifier.
    pub leaf_id: TrackedCudaSlice<u32>,
    /// Per-node decision variable.
    pub decision_var: TrackedCudaSlice<u32>,
    /// Per-node "false" child of a decision node.
    pub decision_child_false: TrackedCudaSlice<u32>,
    /// Per-node "true" child of a decision node.
    pub decision_child_true: TrackedCudaSlice<u32>,
    /// Offsets into `children` (one more entry than there are nodes when
    /// produced by `PirBatch::to_device`).
    pub child_offsets: TrackedCudaSlice<u32>,
    /// Flattened child lists of all nodes.
    pub children: TrackedCudaSlice<u32>,
}
153
impl GpuPirBatch {
    /// Number of nodes in the batch, taken from the length of `node_type`.
    pub fn num_nodes(&self) -> usize {
        self.node_type.len()
    }

    /// Total number of entries in the flattened `children` buffer.
    pub fn num_children(&self) -> usize {
        self.children.len()
    }
}
163
/// GPU PIR interner (device-side).
///
/// Owns a fixed-capacity, device-resident `GpuPirGraph` together with the
/// per-node hashes and device-side counters used while interning batches.
pub struct GpuPirInterner {
    /// Kernel provider used for device allocation, copies and kernel launches.
    provider: Arc<CudaKernelProvider>,
    /// Node capacity: size of every per-node graph buffer and of `graph_hashes`.
    node_cap: u32,
    /// Capacity of the graph's flattened `children` buffer.
    child_cap: u32,
    /// Device-resident graph storage.
    graph: GpuPirGraph,
    /// One 64-bit hash per graph node slot (`node_cap` entries).
    graph_hashes: TrackedCudaSlice<u64>,
    /// Single-element device counter, initialised to 2 by `new`
    /// (presumably the number of nodes currently in the graph — confirm).
    num_nodes: TrackedCudaSlice<u32>,
    /// Single-element device counter, initialised to 0 by `new`
    /// (presumably the number of child entries currently in use — confirm).
    num_children: TrackedCudaSlice<u32>,
}
174
175impl GpuPirInterner {
176    pub fn new(provider: &Arc<CudaKernelProvider>, node_cap: u32, child_cap: u32) -> Result<Self> {
177        if node_cap < 2 {
178            return Err(XlogError::Compilation(
179                "GpuPirInterner requires node_cap >= 2 for const nodes".to_string(),
180            ));
181        }
182        let memory = provider.memory();
183        let device = provider.device().inner();
184
185        let mut node_type = memory.alloc::<u8>(node_cap as usize)?;
186        let mut child_offsets = memory.alloc::<u32>((node_cap as usize) + 1)?;
187        let mut children = memory.alloc::<u32>(child_cap as usize)?;
188        let mut leaf_id = memory.alloc::<u32>(node_cap as usize)?;
189        let mut decision_var = memory.alloc::<u32>(node_cap as usize)?;
190        let mut decision_child_false = memory.alloc::<u32>(node_cap as usize)?;
191        let mut decision_child_true = memory.alloc::<u32>(node_cap as usize)?;
192        let mut graph_hashes = memory.alloc::<u64>(node_cap as usize)?;
193
194        device
195            .memset_zeros(&mut node_type)
196            .map_err(|e| XlogError::Kernel(format!("GpuPirInterner init node_type: {}", e)))?;
197        device
198            .memset_zeros(&mut child_offsets)
199            .map_err(|e| XlogError::Kernel(format!("GpuPirInterner init child_offsets: {}", e)))?;
200        device
201            .memset_zeros(&mut children)
202            .map_err(|e| XlogError::Kernel(format!("GpuPirInterner init children: {}", e)))?;
203        device
204            .memset_zeros(&mut decision_var)
205            .map_err(|e| XlogError::Kernel(format!("GpuPirInterner init decision_var: {}", e)))?;
206        device
207            .memset_zeros(&mut decision_child_false)
208            .map_err(|e| {
209                XlogError::Kernel(format!("GpuPirInterner init decision_child_false: {}", e))
210            })?;
211        device.memset_zeros(&mut decision_child_true).map_err(|e| {
212            XlogError::Kernel(format!("GpuPirInterner init decision_child_true: {}", e))
213        })?;
214        device
215            .memset_zeros(&mut graph_hashes)
216            .map_err(|e| XlogError::Kernel(format!("GpuPirInterner init graph_hashes: {}", e)))?;
217
218        let mut leaf_id_host = vec![0u32; node_cap as usize];
219        if node_cap > 1 {
220            leaf_id_host[1] = 1;
221        }
222        provider
223            .htod_launch_metadata_sync_copy_into(&leaf_id_host, &mut leaf_id)
224            .map_err(|e| XlogError::Kernel(format!("GpuPirInterner init leaf_id: {}", e)))?;
225
226        let hash_fn = device
227            .get_func(PIR_MODULE, pir_kernels::PIR_HASH_KEYS)
228            .ok_or_else(|| XlogError::Kernel("pir_hash_keys not found".to_string()))?;
229        let num_const = 2u32;
230        let block_size = 256u32;
231        let grid_const = num_const.div_ceil(block_size);
232        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
233        unsafe {
234            hash_fn.clone().launch(
235                LaunchConfig {
236                    grid_dim: (grid_const.max(1), 1, 1),
237                    block_dim: (block_size, 1, 1),
238                    shared_mem_bytes: 0,
239                },
240                (
241                    &node_type,
242                    &leaf_id,
243                    &decision_var,
244                    &decision_child_false,
245                    &decision_child_true,
246                    &child_offsets,
247                    &children,
248                    num_const,
249                    &mut graph_hashes,
250                ),
251            )
252        }
253        .map_err(|e| XlogError::Kernel(format!("GpuPirInterner init hash: {}", e)))?;
254
255        let mut num_nodes = memory.alloc::<u32>(1)?;
256        let mut num_children = memory.alloc::<u32>(1)?;
257        provider
258            .htod_launch_metadata_sync_copy_into(&[2u32], &mut num_nodes)
259            .map_err(|e| XlogError::Kernel(format!("GpuPirInterner init num_nodes: {}", e)))?;
260        provider
261            .htod_launch_metadata_sync_copy_into(&[0u32], &mut num_children)
262            .map_err(|e| XlogError::Kernel(format!("GpuPirInterner init num_children: {}", e)))?;
263
264        Ok(Self {
265            provider: Arc::clone(provider),
266            node_cap,
267            child_cap,
268            graph: GpuPirGraph {
269                node_type,
270                child_offsets,
271                children,
272                leaf_id,
273                decision_var,
274                decision_child_false,
275                decision_child_true,
276            },
277            graph_hashes,
278            num_nodes,
279            num_children,
280        })
281    }
282
    /// Borrow the device-resident graph owned by this interner.
    pub fn graph(&self) -> &GpuPirGraph {
        &self.graph
    }
286
287    pub fn intern_batch(&mut self, batch: &PirBatch) -> Result<TrackedCudaSlice<u32>> {
288        if batch.node_type.contains(&PIR_CONST) {
289            return Err(XlogError::Compilation(
290                "GpuPirInterner does not accept PIR_CONST in batches".to_string(),
291            ));
292        }
293        let mut device_batch = batch.to_device(&self.provider)?;
294        self.intern_device_batch(&mut device_batch)
295    }
296
297    pub fn intern_device_batch(
298        &mut self,
299        batch: &mut GpuPirBatch,
300    ) -> Result<TrackedCudaSlice<u32>> {
301        let num_nodes = batch.num_nodes();
302        if num_nodes == 0 {
303            return self.provider.memory().alloc::<u32>(0);
304        }
305        let num_nodes_u32 = u32::try_from(num_nodes).map_err(|_| {
306            XlogError::Compilation("GpuPirInterner: num_nodes overflow".to_string())
307        })?;
308        let num_children = batch.num_children();
309        let num_children_u32 = u32::try_from(num_children).map_err(|_| {
310            XlogError::Compilation("GpuPirInterner: num_children overflow".to_string())
311        })?;
312
313        let device = self.provider.device().inner();
314        let memory = self.provider.memory();
315        let block_size = 256u32;
316
317        // Canonicalize AND/OR children (sort + dedup) into new buffers.
318        let mut canon_child_offsets = memory.alloc::<u32>(num_nodes + 1)?;
319        let mut canon_children = memory.alloc::<u32>(num_children)?;
320
321        if num_children_u32 == 0 {
322            device.memset_zeros(&mut canon_child_offsets).map_err(|e| {
323                XlogError::Kernel(format!("GpuPirInterner zero child_offsets: {}", e))
324            })?;
325        } else {
326            let mut parent_ids = memory.alloc::<u32>(num_children)?;
327            let fill_fn = device
328                .get_func(PIR_MODULE, pir_kernels::PIR_FILL_CHILD_PARENTS)
329                .ok_or_else(|| XlogError::Kernel("pir_fill_child_parents not found".to_string()))?;
330            let grid_nodes = num_nodes_u32.div_ceil(block_size);
331            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
332            unsafe {
333                fill_fn.clone().launch(
334                    LaunchConfig {
335                        grid_dim: (grid_nodes, 1, 1),
336                        block_dim: (block_size, 1, 1),
337                        shared_mem_bytes: 0,
338                    },
339                    (&batch.child_offsets, num_nodes_u32, &mut parent_ids),
340                )
341            }
342            .map_err(|e| XlogError::Kernel(format!("pir_fill_child_parents failed: {}", e)))?;
343
344            let mut sort_scratch = RadixSortScratch::new(&self.provider, num_children_u32)?;
345            self.provider.radix_sort_u32_pairs(
346                &mut batch.children,
347                &mut parent_ids,
348                num_children_u32,
349                &mut sort_scratch,
350            )?;
351            self.provider.radix_sort_u32_pairs(
352                &mut parent_ids,
353                &mut batch.children,
354                num_children_u32,
355                &mut sort_scratch,
356            )?;
357
358            let mut pair_unique_mask = memory.alloc::<u8>(num_children)?;
359            let mark_pairs = device
360                .get_func(PIR_MODULE, pir_kernels::PIR_MARK_UNIQUE_PAIRS)
361                .ok_or_else(|| XlogError::Kernel("pir_mark_unique_pairs not found".to_string()))?;
362            let grid_pairs = num_children_u32.div_ceil(block_size);
363            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
364            unsafe {
365                mark_pairs.clone().launch(
366                    LaunchConfig {
367                        grid_dim: (grid_pairs, 1, 1),
368                        block_dim: (block_size, 1, 1),
369                        shared_mem_bytes: 0,
370                    },
371                    (
372                        &parent_ids,
373                        &batch.children,
374                        num_children_u32,
375                        &mut pair_unique_mask,
376                    ),
377                )
378            }
379            .map_err(|e| XlogError::Kernel(format!("pir_mark_unique_pairs failed: {}", e)))?;
380
381            let pair_prefix = self
382                .provider
383                .scan_u8_mask_device(&pair_unique_mask, num_children_u32)?;
384
385            let mut unique_pairs_total = memory.alloc::<u32>(1)?;
386            device
387                .memset_zeros(&mut unique_pairs_total)
388                .map_err(|e| XlogError::Kernel(format!("zero unique_pairs_total: {}", e)))?;
389            let count_mask = device
390                .get_func(SCAN_MODULE, scan_kernels::COUNT_MASK)
391                .ok_or_else(|| XlogError::Kernel("count_mask not found".to_string()))?;
392            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
393            unsafe {
394                count_mask.clone().launch(
395                    LaunchConfig {
396                        grid_dim: (grid_pairs, 1, 1),
397                        block_dim: (block_size, 1, 1),
398                        shared_mem_bytes: 0,
399                    },
400                    (&pair_unique_mask, num_children_u32, &mut unique_pairs_total),
401                )
402            }
403            .map_err(|e| XlogError::Kernel(format!("count_mask (pairs) failed: {}", e)))?;
404
405            let mut canon_parent = memory.alloc::<u32>(num_children)?;
406            let compact_pairs = device
407                .get_func(PIR_MODULE, pir_kernels::PIR_COMPACT_PAIRS)
408                .ok_or_else(|| XlogError::Kernel("pir_compact_pairs not found".to_string()))?;
409            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
410            unsafe {
411                compact_pairs.clone().launch(
412                    LaunchConfig {
413                        grid_dim: (grid_pairs, 1, 1),
414                        block_dim: (block_size, 1, 1),
415                        shared_mem_bytes: 0,
416                    },
417                    (
418                        &parent_ids,
419                        &batch.children,
420                        &pair_unique_mask,
421                        &pair_prefix,
422                        num_children_u32,
423                        &mut canon_parent,
424                        &mut canon_children,
425                    ),
426                )
427            }
428            .map_err(|e| XlogError::Kernel(format!("pir_compact_pairs failed: {}", e)))?;
429
430            let mut child_counts = memory.alloc::<u32>(num_nodes)?;
431            device
432                .memset_zeros(&mut child_counts)
433                .map_err(|e| XlogError::Kernel(format!("zero child_counts: {}", e)))?;
434            let count_children = device
435                .get_func(PIR_MODULE, pir_kernels::PIR_COUNT_CHILDREN)
436                .ok_or_else(|| XlogError::Kernel("pir_count_children not found".to_string()))?;
437            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
438            unsafe {
439                count_children.clone().launch(
440                    LaunchConfig {
441                        grid_dim: (grid_pairs, 1, 1),
442                        block_dim: (block_size, 1, 1),
443                        shared_mem_bytes: 0,
444                    },
445                    (
446                        &canon_parent,
447                        &unique_pairs_total,
448                        num_nodes_u32,
449                        &mut child_counts,
450                    ),
451                )
452            }
453            .map_err(|e| XlogError::Kernel(format!("pir_count_children failed: {}", e)))?;
454
455            self.provider
456                .exclusive_scan_u32_inplace(&mut child_counts, num_nodes_u32)?;
457
458            let write_offsets = device
459                .get_func(PIR_MODULE, pir_kernels::PIR_WRITE_CHILD_OFFSETS)
460                .ok_or_else(|| {
461                    XlogError::Kernel("pir_write_child_offsets not found".to_string())
462                })?;
463            let grid_nodes = num_nodes_u32.div_ceil(block_size);
464            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
465            unsafe {
466                write_offsets.clone().launch(
467                    LaunchConfig {
468                        grid_dim: (grid_nodes.max(1), 1, 1),
469                        block_dim: (block_size, 1, 1),
470                        shared_mem_bytes: 0,
471                    },
472                    (
473                        &child_counts,
474                        num_nodes_u32,
475                        &unique_pairs_total,
476                        &mut canon_child_offsets,
477                    ),
478                )
479            }
480            .map_err(|e| XlogError::Kernel(format!("pir_write_child_offsets failed: {}", e)))?;
481        }
482
483        // Hash and pack keys for nodes.
484        let mut hashes = memory.alloc::<u64>(num_nodes)?;
485        let hash_fn = device
486            .get_func(PIR_MODULE, pir_kernels::PIR_HASH_KEYS)
487            .ok_or_else(|| XlogError::Kernel("pir_hash_keys not found".to_string()))?;
488        let grid_nodes = num_nodes_u32.div_ceil(block_size);
489        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
490        unsafe {
491            hash_fn.clone().launch(
492                LaunchConfig {
493                    grid_dim: (grid_nodes, 1, 1),
494                    block_dim: (block_size, 1, 1),
495                    shared_mem_bytes: 0,
496                },
497                (
498                    &batch.node_type,
499                    &batch.leaf_id,
500                    &batch.decision_var,
501                    &batch.decision_child_false,
502                    &batch.decision_child_true,
503                    &canon_child_offsets,
504                    &canon_children,
505                    num_nodes_u32,
506                    &mut hashes,
507                ),
508            )
509        }
510        .map_err(|e| XlogError::Kernel(format!("pir_hash_keys failed: {}", e)))?;
511
512        let mut key_tag = memory.alloc::<u32>(num_nodes)?;
513        let mut key_payload = memory.alloc::<u32>(num_nodes)?;
514        let mut key_child_len = memory.alloc::<u32>(num_nodes)?;
515        let pack_fn = device
516            .get_func(PIR_MODULE, pir_kernels::PIR_PACK_KEYS)
517            .ok_or_else(|| XlogError::Kernel("pir_pack_keys not found".to_string()))?;
518        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
519        unsafe {
520            pack_fn.clone().launch(
521                LaunchConfig {
522                    grid_dim: (grid_nodes, 1, 1),
523                    block_dim: (block_size, 1, 1),
524                    shared_mem_bytes: 0,
525                },
526                (
527                    &batch.node_type,
528                    &batch.leaf_id,
529                    &batch.decision_var,
530                    &batch.decision_child_false,
531                    &batch.decision_child_true,
532                    &canon_child_offsets,
533                    num_nodes_u32,
534                    &mut key_tag,
535                    &mut key_payload,
536                    &mut key_child_len,
537                ),
538            )
539        }
540        .map_err(|e| XlogError::Kernel(format!("pir_pack_keys failed: {}", e)))?;
541
542        // Sort indices by (hash, tag, payload, len).
543        let mut indices = memory.alloc::<u32>(num_nodes)?;
544        self.provider.init_indices(&mut indices, num_nodes_u32)?;
545        let mut keys = memory.alloc::<u32>(num_nodes)?;
546        let mut node_sort = RadixSortScratch::new(&self.provider, num_nodes_u32)?;
547
548        self.provider
549            .gather_u32_by_indices(&key_child_len, &indices, &mut keys, num_nodes_u32)?;
550        self.provider.radix_sort_u32_pairs(
551            &mut keys,
552            &mut indices,
553            num_nodes_u32,
554            &mut node_sort,
555        )?;
556
557        self.provider
558            .gather_u32_by_indices(&key_payload, &indices, &mut keys, num_nodes_u32)?;
559        self.provider.radix_sort_u32_pairs(
560            &mut keys,
561            &mut indices,
562            num_nodes_u32,
563            &mut node_sort,
564        )?;
565
566        self.provider
567            .gather_u32_by_indices(&key_tag, &indices, &mut keys, num_nodes_u32)?;
568        self.provider.radix_sort_u32_pairs(
569            &mut keys,
570            &mut indices,
571            num_nodes_u32,
572            &mut node_sort,
573        )?;
574
575        self.provider
576            .gather_u64_lo_by_indices(&hashes, &indices, &mut keys, num_nodes_u32)?;
577        self.provider.radix_sort_u32_pairs(
578            &mut keys,
579            &mut indices,
580            num_nodes_u32,
581            &mut node_sort,
582        )?;
583
584        self.provider
585            .gather_u64_hi_by_indices(&hashes, &indices, &mut keys, num_nodes_u32)?;
586        self.provider.radix_sort_u32_pairs(
587            &mut keys,
588            &mut indices,
589            num_nodes_u32,
590            &mut node_sort,
591        )?;
592
593        // Gather sorted node arrays.
594        let mut sorted_node_type = memory.alloc::<u8>(num_nodes)?;
595        let mut sorted_leaf_id = memory.alloc::<u32>(num_nodes)?;
596        let mut sorted_decision_var = memory.alloc::<u32>(num_nodes)?;
597        let mut sorted_decision_child_false = memory.alloc::<u32>(num_nodes)?;
598        let mut sorted_decision_child_true = memory.alloc::<u32>(num_nodes)?;
599
600        self.provider.gather_u8_by_indices(
601            &batch.node_type,
602            &indices,
603            &mut sorted_node_type,
604            num_nodes_u32,
605        )?;
606        self.provider.gather_u32_by_indices(
607            &batch.leaf_id,
608            &indices,
609            &mut sorted_leaf_id,
610            num_nodes_u32,
611        )?;
612        self.provider.gather_u32_by_indices(
613            &batch.decision_var,
614            &indices,
615            &mut sorted_decision_var,
616            num_nodes_u32,
617        )?;
618        self.provider.gather_u32_by_indices(
619            &batch.decision_child_false,
620            &indices,
621            &mut sorted_decision_child_false,
622            num_nodes_u32,
623        )?;
624        self.provider.gather_u32_by_indices(
625            &batch.decision_child_true,
626            &indices,
627            &mut sorted_decision_child_true,
628            num_nodes_u32,
629        )?;
630
631        // Build sorted child offsets/children.
632        let mut sorted_child_len = memory.alloc::<u32>(num_nodes)?;
633        self.provider.gather_u32_by_indices(
634            &key_child_len,
635            &indices,
636            &mut sorted_child_len,
637            num_nodes_u32,
638        )?;
639        self.provider
640            .exclusive_scan_u32_inplace(&mut sorted_child_len, num_nodes_u32)?;
641
642        let mut sorted_child_offsets = memory.alloc::<u32>(num_nodes + 1)?;
643        let write_offsets = device
644            .get_func(PIR_MODULE, pir_kernels::PIR_WRITE_CHILD_OFFSETS)
645            .ok_or_else(|| XlogError::Kernel("pir_write_child_offsets not found".to_string()))?;
646        let total_children_view = canon_child_offsets.slice(num_nodes..(num_nodes + 1));
647        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
648        unsafe {
649            write_offsets.clone().launch(
650                LaunchConfig {
651                    grid_dim: (grid_nodes.max(1), 1, 1),
652                    block_dim: (block_size, 1, 1),
653                    shared_mem_bytes: 0,
654                },
655                (
656                    &sorted_child_len,
657                    num_nodes_u32,
658                    &total_children_view,
659                    &mut sorted_child_offsets,
660                ),
661            )
662        }
663        .map_err(|e| XlogError::Kernel(format!("pir_write_child_offsets(sorted) failed: {}", e)))?;
664
665        let mut sorted_children = memory.alloc::<u32>(num_children)?;
666        let gather_children = device
667            .get_func(PIR_MODULE, pir_kernels::PIR_GATHER_CHILDREN)
668            .ok_or_else(|| XlogError::Kernel("pir_gather_children not found".to_string()))?;
669        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
670        unsafe {
671            gather_children.clone().launch(
672                LaunchConfig {
673                    grid_dim: (grid_nodes, 1, 1),
674                    block_dim: (block_size, 1, 1),
675                    shared_mem_bytes: 0,
676                },
677                (
678                    &indices,
679                    &canon_child_offsets,
680                    &canon_children,
681                    &sorted_child_offsets,
682                    num_nodes_u32,
683                    &mut sorted_children,
684                ),
685            )
686        }
687        .map_err(|e| XlogError::Kernel(format!("pir_gather_children failed: {}", e)))?;
688
689        // Recompute hashes in sorted order for uniqueness checks.
690        let mut sorted_hashes = memory.alloc::<u64>(num_nodes)?;
691        let hash_fn = device
692            .get_func(PIR_MODULE, pir_kernels::PIR_HASH_KEYS)
693            .ok_or_else(|| XlogError::Kernel("pir_hash_keys not found".to_string()))?;
694        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
695        unsafe {
696            hash_fn.clone().launch(
697                LaunchConfig {
698                    grid_dim: (grid_nodes, 1, 1),
699                    block_dim: (block_size, 1, 1),
700                    shared_mem_bytes: 0,
701                },
702                (
703                    &sorted_node_type,
704                    &sorted_leaf_id,
705                    &sorted_decision_var,
706                    &sorted_decision_child_false,
707                    &sorted_decision_child_true,
708                    &sorted_child_offsets,
709                    &sorted_children,
710                    num_nodes_u32,
711                    &mut sorted_hashes,
712                ),
713            )
714        }
715        .map_err(|e| XlogError::Kernel(format!("pir_hash_keys(sorted) failed: {}", e)))?;
716
717        let hash_table = self
718            .provider
719            .build_hash_table_u64(&self.graph_hashes, self.node_cap)?;
720        let existing_id = memory.alloc::<u32>(num_nodes)?;
721        let find_existing = device
722            .get_func(PIR_MODULE, pir_kernels::PIR_FIND_EXISTING)
723            .ok_or_else(|| XlogError::Kernel("pir_find_existing not found".to_string()))?;
724        let mut find_params: Vec<*mut c_void> = vec![
725            (&sorted_hashes).as_kernel_param(),
726            (&sorted_node_type).as_kernel_param(),
727            (&sorted_leaf_id).as_kernel_param(),
728            (&sorted_decision_var).as_kernel_param(),
729            (&sorted_decision_child_false).as_kernel_param(),
730            (&sorted_decision_child_true).as_kernel_param(),
731            (&sorted_child_offsets).as_kernel_param(),
732            (&sorted_children).as_kernel_param(),
733            num_nodes_u32.as_kernel_param(),
734            (&self.graph.node_type).as_kernel_param(),
735            (&self.graph.child_offsets).as_kernel_param(),
736            (&self.graph.children).as_kernel_param(),
737            (&self.graph.leaf_id).as_kernel_param(),
738            (&self.graph.decision_var).as_kernel_param(),
739            (&self.graph.decision_child_false).as_kernel_param(),
740            (&self.graph.decision_child_true).as_kernel_param(),
741            (&self.num_nodes).as_kernel_param(),
742            (&hash_table.bucket_offsets).as_kernel_param(),
743            (&hash_table.bucket_counts).as_kernel_param(),
744            (&hash_table.bucket_entries).as_kernel_param(),
745            (&hash_table.bucket_entry_hashes).as_kernel_param(),
746            hash_table.bucket_mask.as_kernel_param(),
747            (&existing_id).as_kernel_param(),
748        ];
749        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
750        unsafe {
751            find_existing.clone().launch(
752                LaunchConfig {
753                    grid_dim: (grid_nodes, 1, 1),
754                    block_dim: (block_size, 1, 1),
755                    shared_mem_bytes: 0,
756                },
757                &mut find_params,
758            )
759        }
760        .map_err(|e| XlogError::Kernel(format!("pir_find_existing failed: {}", e)))?;
761
762        // Mark unique nodes in sorted order.
763        let mut unique_mask = memory.alloc::<u8>(num_nodes)?;
764        let mark_unique = device
765            .get_func(PIR_MODULE, pir_kernels::PIR_MARK_UNIQUE)
766            .ok_or_else(|| XlogError::Kernel("pir_mark_unique not found".to_string()))?;
767        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
768        unsafe {
769            mark_unique.clone().launch(
770                LaunchConfig {
771                    grid_dim: (grid_nodes, 1, 1),
772                    block_dim: (block_size, 1, 1),
773                    shared_mem_bytes: 0,
774                },
775                (
776                    &sorted_hashes,
777                    &sorted_node_type,
778                    &sorted_leaf_id,
779                    &sorted_decision_var,
780                    &sorted_decision_child_false,
781                    &sorted_decision_child_true,
782                    &sorted_child_offsets,
783                    &sorted_children,
784                    num_nodes_u32,
785                    &mut unique_mask,
786                ),
787            )
788        }
789        .map_err(|e| XlogError::Kernel(format!("pir_mark_unique failed: {}", e)))?;
790
791        let mut new_mask = memory.alloc::<u8>(num_nodes)?;
792        let mark_new = device
793            .get_func(PIR_MODULE, pir_kernels::PIR_MARK_NEW_GROUPS)
794            .ok_or_else(|| XlogError::Kernel("pir_mark_new_groups not found".to_string()))?;
795        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
796        unsafe {
797            mark_new.clone().launch(
798                LaunchConfig {
799                    grid_dim: (grid_nodes, 1, 1),
800                    block_dim: (block_size, 1, 1),
801                    shared_mem_bytes: 0,
802                },
803                (&unique_mask, &existing_id, num_nodes_u32, &mut new_mask),
804            )
805        }
806        .map_err(|e| XlogError::Kernel(format!("pir_mark_new_groups failed: {}", e)))?;
807
808        let unique_prefix = self
809            .provider
810            .scan_u8_mask_device(&unique_mask, num_nodes_u32)?;
811
812        let new_prefix = self
813            .provider
814            .scan_u8_mask_device(&new_mask, num_nodes_u32)?;
815
816        let mut new_nodes_total = memory.alloc::<u32>(1)?;
817        device
818            .memset_zeros(&mut new_nodes_total)
819            .map_err(|e| XlogError::Kernel(format!("zero new_nodes_total: {}", e)))?;
820        let count_mask = device
821            .get_func(SCAN_MODULE, scan_kernels::COUNT_MASK)
822            .ok_or_else(|| XlogError::Kernel("count_mask not found".to_string()))?;
823        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
824        unsafe {
825            count_mask.clone().launch(
826                LaunchConfig {
827                    grid_dim: (grid_nodes, 1, 1),
828                    block_dim: (block_size, 1, 1),
829                    shared_mem_bytes: 0,
830                },
831                (&new_mask, num_nodes_u32, &mut new_nodes_total),
832            )
833        }
834        .map_err(|e| XlogError::Kernel(format!("count_mask (new nodes) failed: {}", e)))?;
835
836        let mut group_node_id = memory.alloc::<u32>(num_nodes)?;
837        let build_group = device
838            .get_func(PIR_MODULE, pir_kernels::PIR_BUILD_GROUP_IDS)
839            .ok_or_else(|| XlogError::Kernel("pir_build_group_ids not found".to_string()))?;
840        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
841        unsafe {
842            build_group.clone().launch(
843                LaunchConfig {
844                    grid_dim: (grid_nodes, 1, 1),
845                    block_dim: (block_size, 1, 1),
846                    shared_mem_bytes: 0,
847                },
848                (
849                    &unique_mask,
850                    &unique_prefix,
851                    &existing_id,
852                    &new_prefix,
853                    &self.num_nodes,
854                    num_nodes_u32,
855                    &mut group_node_id,
856                ),
857            )
858        }
859        .map_err(|e| XlogError::Kernel(format!("pir_build_group_ids failed: {}", e)))?;
860
861        let mut graph_child_counts = memory.alloc::<u32>(num_nodes)?;
862        let build_counts = device
863            .get_func(PIR_MODULE, pir_kernels::PIR_BUILD_GRAPH_CHILD_COUNTS)
864            .ok_or_else(|| {
865                XlogError::Kernel("pir_build_graph_child_counts not found".to_string())
866            })?;
867        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
868        unsafe {
869            build_counts.clone().launch(
870                LaunchConfig {
871                    grid_dim: (grid_nodes, 1, 1),
872                    block_dim: (block_size, 1, 1),
873                    shared_mem_bytes: 0,
874                },
875                (
876                    &sorted_node_type,
877                    &sorted_child_offsets,
878                    &new_mask,
879                    num_nodes_u32,
880                    &mut graph_child_counts,
881                ),
882            )
883        }
884        .map_err(|e| XlogError::Kernel(format!("pir_build_graph_child_counts failed: {}", e)))?;
885
886        let mut graph_children_total = memory.alloc::<u32>(1)?;
887        device
888            .memset_zeros(&mut graph_children_total)
889            .map_err(|e| XlogError::Kernel(format!("zero graph_children_total: {}", e)))?;
890        let sum_counts = device
891            .get_func(PIR_MODULE, pir_kernels::PIR_SUM_COUNTS)
892            .ok_or_else(|| XlogError::Kernel("pir_sum_counts not found".to_string()))?;
893        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
894        unsafe {
895            sum_counts.clone().launch(
896                LaunchConfig {
897                    grid_dim: (grid_nodes, 1, 1),
898                    block_dim: (block_size, 1, 1),
899                    shared_mem_bytes: 0,
900                },
901                (
902                    &graph_child_counts,
903                    num_nodes_u32,
904                    &mut graph_children_total,
905                ),
906            )
907        }
908        .map_err(|e| XlogError::Kernel(format!("pir_sum_counts failed: {}", e)))?;
909
910        self.provider
911            .exclusive_scan_u32_inplace(&mut graph_child_counts, num_nodes_u32)?;
912
913        let out_ids = memory.alloc::<u32>(num_nodes)?;
914        let emit = device
915            .get_func(PIR_MODULE, pir_kernels::PIR_EMIT_NODES_AND_IDS)
916            .ok_or_else(|| XlogError::Kernel("pir_emit_nodes_and_ids not found".to_string()))?;
917        let mut emit_params: Vec<*mut c_void> = vec![
918            (&sorted_node_type).as_kernel_param(),
919            (&sorted_leaf_id).as_kernel_param(),
920            (&sorted_decision_var).as_kernel_param(),
921            (&sorted_decision_child_false).as_kernel_param(),
922            (&sorted_decision_child_true).as_kernel_param(),
923            (&sorted_child_offsets).as_kernel_param(),
924            (&sorted_children).as_kernel_param(),
925            (&unique_mask).as_kernel_param(),
926            (&unique_prefix).as_kernel_param(),
927            (&group_node_id).as_kernel_param(),
928            (&graph_child_counts).as_kernel_param(),
929            (&indices).as_kernel_param(),
930            num_nodes_u32.as_kernel_param(),
931            (&self.num_nodes).as_kernel_param(),
932            (&self.num_children).as_kernel_param(),
933            self.node_cap.as_kernel_param(),
934            self.child_cap.as_kernel_param(),
935            (&self.graph.node_type).as_kernel_param(),
936            (&self.graph.child_offsets).as_kernel_param(),
937            (&self.graph.children).as_kernel_param(),
938            (&self.graph.leaf_id).as_kernel_param(),
939            (&self.graph.decision_var).as_kernel_param(),
940            (&self.graph.decision_child_false).as_kernel_param(),
941            (&self.graph.decision_child_true).as_kernel_param(),
942            (&new_mask).as_kernel_param(),
943            (&sorted_hashes).as_kernel_param(),
944            (&self.graph_hashes).as_kernel_param(),
945            (&out_ids).as_kernel_param(),
946        ];
947        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
948        unsafe {
949            emit.clone().launch(
950                LaunchConfig {
951                    grid_dim: (grid_nodes, 1, 1),
952                    block_dim: (block_size, 1, 1),
953                    shared_mem_bytes: 0,
954                },
955                &mut emit_params,
956            )
957        }
958        .map_err(|e| XlogError::Kernel(format!("pir_emit_nodes_and_ids failed: {}", e)))?;
959
960        let update_counts = device
961            .get_func(PIR_MODULE, pir_kernels::PIR_UPDATE_COUNTS)
962            .ok_or_else(|| XlogError::Kernel("pir_update_counts not found".to_string()))?;
963        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
964        unsafe {
965            update_counts.clone().launch(
966                LaunchConfig {
967                    grid_dim: (1, 1, 1),
968                    block_dim: (1, 1, 1),
969                    shared_mem_bytes: 0,
970                },
971                (
972                    &new_nodes_total,
973                    &graph_children_total,
974                    self.node_cap,
975                    self.child_cap,
976                    &mut self.num_nodes,
977                    &mut self.num_children,
978                ),
979            )
980        }
981        .map_err(|e| XlogError::Kernel(format!("pir_update_counts failed: {}", e)))?;
982        // No device synchronize: returns device-resident IDs; same-stream ordering suffices.
983        Ok(out_ids)
984    }
985}