Skip to main content

xlog_prob/compilation/
gpu_pir.rs

1//! GPU-resident Provenance IR (PIR) representation.
2//!
3//! Mirrors `crate::pir::PirGraph` in a structure-of-arrays layout on device.
4
5use std::sync::Arc;
6
7use cudarc::driver::DeviceSlice;
8use xlog_core::{Result, XlogError};
9use xlog_cuda::memory::TrackedCudaSlice;
10use xlog_cuda::CudaKernelProvider;
11
12use crate::pir::{PirGraph, PirNode, PirNodeId};
13
// Node type tags, one per `PirNode` variant, stored as `u8` in the device-side
// `node_type` array.

/// Tag for `PirNode::Const`.
pub(crate) const PIR_CONST: u8 = 0;
/// Tag for `PirNode::Lit`.
pub const PIR_LIT: u8 = 1;
/// Tag for `PirNode::NegLit`.
pub const PIR_NEG_LIT: u8 = 2;
/// Tag for `PirNode::And`.
pub const PIR_AND: u8 = 3;
/// Tag for `PirNode::Or`.
pub const PIR_OR: u8 = 4;
/// Tag for `PirNode::Decision`.
pub(crate) const PIR_DECISION: u8 = 5;
21
22/// GPU-resident PIR graph (device-side mirror of `pir::PirGraph`).
23pub struct GpuPirGraph {
24    pub node_type: TrackedCudaSlice<u8>,
25    pub child_offsets: TrackedCudaSlice<u32>,
26    pub children: TrackedCudaSlice<u32>,
27    pub leaf_id: TrackedCudaSlice<u32>,
28    pub decision_var: TrackedCudaSlice<u32>,
29    pub decision_child_false: TrackedCudaSlice<u32>,
30    pub decision_child_true: TrackedCudaSlice<u32>,
31}
32
33/// GPU-resident PIR root list.
34pub struct GpuPirRoots {
35    pub roots: TrackedCudaSlice<u32>,
36}
37
38impl GpuPirGraph {
39    /// Upload a host PIR graph to device buffers.
40    ///
41    /// This is intended for tests and tooling. Production GPU-native paths
42    /// should construct PIR directly on device.
43    pub fn from_host(pir: &PirGraph, provider: &Arc<CudaKernelProvider>) -> Result<Self> {
44        let num_nodes = pir.len();
45        let num_nodes_u32 = u32::try_from(num_nodes).map_err(|_| {
46            XlogError::Compilation("GpuPirGraph::from_host: node count overflow".to_string())
47        })?;
48
49        let mut node_type: Vec<u8> = Vec::with_capacity(num_nodes);
50        let mut child_offsets: Vec<u32> = Vec::with_capacity(num_nodes + 1);
51        let mut children: Vec<u32> = Vec::new();
52        let mut leaf_id: Vec<u32> = Vec::with_capacity(num_nodes);
53        let mut decision_var: Vec<u32> = Vec::with_capacity(num_nodes);
54        let mut decision_child_false: Vec<u32> = Vec::with_capacity(num_nodes);
55        let mut decision_child_true: Vec<u32> = Vec::with_capacity(num_nodes);
56
57        child_offsets.push(0);
58
59        for (idx, node) in pir.nodes().iter().enumerate() {
60            let node_id = u32::try_from(idx).map_err(|_| {
61                XlogError::Compilation("GpuPirGraph::from_host: node id overflow".to_string())
62            })?;
63
64            match node {
65                PirNode::Const(value) => {
66                    node_type.push(PIR_CONST);
67                    leaf_id.push(u32::from(*value));
68                    decision_var.push(0);
69                    decision_child_false.push(0);
70                    decision_child_true.push(0);
71                }
72                PirNode::Lit { leaf } => {
73                    node_type.push(PIR_LIT);
74                    leaf_id.push(leaf.as_u32());
75                    decision_var.push(0);
76                    decision_child_false.push(0);
77                    decision_child_true.push(0);
78                }
79                PirNode::NegLit { leaf } => {
80                    node_type.push(PIR_NEG_LIT);
81                    leaf_id.push(leaf.as_u32());
82                    decision_var.push(0);
83                    decision_child_false.push(0);
84                    decision_child_true.push(0);
85                }
86                PirNode::And { children: kids } => {
87                    validate_children_sorted(node_id, kids, num_nodes_u32)?;
88                    node_type.push(PIR_AND);
89                    leaf_id.push(0);
90                    decision_var.push(0);
91                    decision_child_false.push(0);
92                    decision_child_true.push(0);
93                    for &child in kids {
94                        children.push(child.as_u32());
95                    }
96                }
97                PirNode::Or { children: kids } => {
98                    validate_children_sorted(node_id, kids, num_nodes_u32)?;
99                    node_type.push(PIR_OR);
100                    leaf_id.push(0);
101                    decision_var.push(0);
102                    decision_child_false.push(0);
103                    decision_child_true.push(0);
104                    for &child in kids {
105                        children.push(child.as_u32());
106                    }
107                }
108                PirNode::Decision {
109                    var,
110                    child_false,
111                    child_true,
112                } => {
113                    validate_child_id(node_id, *child_false, num_nodes_u32)?;
114                    validate_child_id(node_id, *child_true, num_nodes_u32)?;
115                    node_type.push(PIR_DECISION);
116                    leaf_id.push(0);
117                    decision_var.push(var.as_u32());
118                    decision_child_false.push(child_false.as_u32());
119                    decision_child_true.push(child_true.as_u32());
120                }
121            }
122
123            let next_off = u32::try_from(children.len()).map_err(|_| {
124                XlogError::Compilation(
125                    "GpuPirGraph::from_host: children count exceeds u32".to_string(),
126                )
127            })?;
128            child_offsets.push(next_off);
129        }
130
131        if child_offsets.len() != num_nodes + 1 {
132            return Err(XlogError::Compilation(
133                "GpuPirGraph::from_host: child_offsets length mismatch".to_string(),
134            ));
135        }
136
137        let memory = provider.memory();
138
139        let mut d_node_type = memory.alloc::<u8>(node_type.len())?;
140        let mut d_child_offsets = memory.alloc::<u32>(child_offsets.len())?;
141        let mut d_children = memory.alloc::<u32>(children.len())?;
142        let mut d_leaf_id = memory.alloc::<u32>(leaf_id.len())?;
143        let mut d_decision_var = memory.alloc::<u32>(decision_var.len())?;
144        let mut d_decision_child_false = memory.alloc::<u32>(decision_child_false.len())?;
145        let mut d_decision_child_true = memory.alloc::<u32>(decision_child_true.len())?;
146
147        provider
148            .htod_sync_copy_into_tracked(&node_type, &mut d_node_type)
149            .map_err(|e| XlogError::Kernel(format!("GpuPirGraph upload node_type: {}", e)))?;
150        provider
151            .htod_sync_copy_into_tracked(&child_offsets, &mut d_child_offsets)
152            .map_err(|e| XlogError::Kernel(format!("GpuPirGraph upload child_offsets: {}", e)))?;
153        provider
154            .htod_sync_copy_into_tracked(&children, &mut d_children)
155            .map_err(|e| XlogError::Kernel(format!("GpuPirGraph upload children: {}", e)))?;
156        provider
157            .htod_sync_copy_into_tracked(&leaf_id, &mut d_leaf_id)
158            .map_err(|e| XlogError::Kernel(format!("GpuPirGraph upload leaf_id: {}", e)))?;
159        provider
160            .htod_sync_copy_into_tracked(&decision_var, &mut d_decision_var)
161            .map_err(|e| XlogError::Kernel(format!("GpuPirGraph upload decision_var: {}", e)))?;
162        provider
163            .htod_sync_copy_into_tracked(&decision_child_false, &mut d_decision_child_false)
164            .map_err(|e| {
165                XlogError::Kernel(format!("GpuPirGraph upload decision_child_false: {}", e))
166            })?;
167        provider
168            .htod_sync_copy_into_tracked(&decision_child_true, &mut d_decision_child_true)
169            .map_err(|e| {
170                XlogError::Kernel(format!("GpuPirGraph upload decision_child_true: {}", e))
171            })?;
172
173        Ok(Self {
174            node_type: d_node_type,
175            child_offsets: d_child_offsets,
176            children: d_children,
177            leaf_id: d_leaf_id,
178            decision_var: d_decision_var,
179            decision_child_false: d_decision_child_false,
180            decision_child_true: d_decision_child_true,
181        })
182    }
183
184    pub fn num_nodes(&self) -> usize {
185        self.node_type.len()
186    }
187}
188
189impl GpuPirRoots {
190    pub fn from_host(roots: &[PirNodeId], provider: &Arc<CudaKernelProvider>) -> Result<Self> {
191        let mut host: Vec<u32> = Vec::with_capacity(roots.len());
192        for &r in roots {
193            host.push(r.as_u32());
194        }
195
196        let memory = provider.memory();
197        let mut d_roots = memory.alloc::<u32>(host.len())?;
198        provider
199            .htod_sync_copy_into_tracked(&host, &mut d_roots)
200            .map_err(|e| XlogError::Kernel(format!("GpuPirRoots upload: {}", e)))?;
201
202        Ok(Self { roots: d_roots })
203    }
204}
205
206fn validate_child_id(parent: u32, child: PirNodeId, num_nodes: u32) -> Result<()> {
207    let id = child.as_u32();
208    if id >= num_nodes {
209        return Err(XlogError::Compilation(format!(
210            "GpuPirGraph::from_host: child {:?} out of bounds for parent {}",
211            child, parent
212        )));
213    }
214    Ok(())
215}
216
217fn validate_children_sorted(parent: u32, children: &[PirNodeId], num_nodes: u32) -> Result<()> {
218    let mut prev: Option<u32> = None;
219    for &child in children {
220        let id = child.as_u32();
221        if id >= num_nodes {
222            return Err(XlogError::Compilation(format!(
223                "GpuPirGraph::from_host: child {:?} out of bounds for parent {}",
224                child, parent
225            )));
226        }
227        if let Some(p) = prev {
228            if id <= p {
229                return Err(XlogError::Compilation(format!(
230                    "GpuPirGraph::from_host: children of {} must be sorted and unique",
231                    parent
232                )));
233            }
234        }
235        prev = Some(id);
236    }
237    Ok(())
238}