Skip to main content

xlog_prob/compilation/
gpu_weights.rs

1//! GPU-native weight table builders for exact inference.
2
3use std::sync::Arc;
4
5use cudarc::driver::{DeviceSlice, LaunchConfig};
6use xlog_core::{Result, XlogError};
7use xlog_cuda::memory::TrackedCudaSlice;
8use xlog_cuda::provider::{weights_kernels, WEIGHTS_MODULE};
9use xlog_cuda::{CudaKernelProvider, LaunchAsync};
10
11use crate::compilation::gpu_cnf::GpuCnfVarTables;
12
13pub struct GpuWeights {
14    pub log_true: TrackedCudaSlice<f64>,
15    pub log_false: TrackedCudaSlice<f64>,
16}
17
18fn kernel_count_u32(context: &str, count: usize) -> Result<u32> {
19    u32::try_from(count)
20        .map_err(|_| XlogError::Compilation(format!("{context} exceeds GPU u32 index space")))
21}
22
23fn grid_for(count: u32, block: u32) -> Result<u32> {
24    if count == 0 {
25        return Ok(0);
26    }
27    if block == 0 {
28        return Err(XlogError::Compilation(
29            "GPU weight kernel block size must be nonzero".to_string(),
30        ));
31    }
32    let grid = (count as u64).div_ceil(block as u64);
33    let step = grid
34        .checked_mul(block as u64)
35        .ok_or_else(|| XlogError::Compilation("GPU weight grid-stride overflow".to_string()))?;
36    if step > u32::MAX as u64 {
37        return Err(XlogError::Compilation(
38            "GPU weight grid-stride step exceeds u32 index space".to_string(),
39        ));
40    }
41    u32::try_from(grid).map_err(|_| {
42        XlogError::Compilation("GPU weight kernel grid exceeds u32 index space".to_string())
43    })
44}
45
46fn checked_var_table_count(var_cap: u32) -> Result<u32> {
47    var_cap.checked_add(1).ok_or_else(|| {
48        XlogError::Compilation("GPU weight var_cap exceeds u32 table index space".to_string())
49    })
50}
51
52fn weights_len_for_var_cap(var_cap: u32) -> Result<usize> {
53    (var_cap as usize)
54        .checked_add(1)
55        .ok_or_else(|| XlogError::Compilation("weight table size overflow".to_string()))
56}
57
58fn query_weights_len_for_var_cap(var_cap: u32) -> Result<usize> {
59    (var_cap as usize)
60        .checked_add(1)
61        .ok_or_else(|| XlogError::Compilation("query var_cap overflow".to_string()))
62}
63
64fn evidence_len_for_var_cap(var_cap: u32) -> Result<usize> {
65    (var_cap as usize)
66        .checked_add(1)
67        .ok_or_else(|| XlogError::Compilation("evidence var_cap overflow".to_string()))
68}
69
70pub fn build_evidence_by_var_gpu(
71    node_var: &TrackedCudaSlice<u32>,
72    evidence_nodes: &TrackedCudaSlice<u32>,
73    evidence_vals: &TrackedCudaSlice<u8>,
74    var_cap: u32,
75    provider: &Arc<CudaKernelProvider>,
76) -> Result<TrackedCudaSlice<u8>> {
77    if evidence_nodes.len() != evidence_vals.len() {
78        return Err(XlogError::Compilation(format!(
79            "GPU evidence nodes len {} != vals len {}",
80            evidence_nodes.len(),
81            evidence_vals.len()
82        )));
83    }
84    let len = evidence_len_for_var_cap(var_cap)?;
85
86    let memory = provider.memory();
87    let device = provider.device().inner();
88    let mut evidence_by_var = memory.alloc::<u8>(len)?;
89    device
90        .memset_zeros(&mut evidence_by_var)
91        .map_err(|e| XlogError::Kernel(format!("Failed to zero evidence buffer: {}", e)))?;
92
93    let count = evidence_nodes.len();
94    if count == 0 {
95        return Ok(evidence_by_var);
96    }
97    let count_u32 = kernel_count_u32("GPU evidence node count", count)?;
98
99    let func = device
100        .get_func(
101            WEIGHTS_MODULE,
102            weights_kernels::WEIGHTS_SET_EVIDENCE_FROM_NODES,
103        )
104        .ok_or_else(|| {
105            XlogError::Kernel("weights_set_evidence_from_nodes kernel not found".to_string())
106        })?;
107    let block = 256u32;
108    let grid = grid_for(count_u32, block)?;
109    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
110    unsafe {
111        func.clone().launch(
112            LaunchConfig {
113                grid_dim: (grid.max(1), 1, 1),
114                block_dim: (block, 1, 1),
115                shared_mem_bytes: 0,
116            },
117            (
118                node_var,
119                evidence_nodes,
120                evidence_vals,
121                count_u32,
122                var_cap,
123                &mut evidence_by_var,
124            ),
125        )
126    }
127    .map_err(|e| XlogError::Kernel(format!("weights_set_evidence_from_nodes failed: {}", e)))?;
128    // No device synchronize: returns device-resident slice; same-stream ordering suffices.
129    Ok(evidence_by_var)
130}
131
132pub fn map_nodes_to_vars_gpu(
133    node_var: &TrackedCudaSlice<u32>,
134    node_ids: &TrackedCudaSlice<u32>,
135    var_cap: u32,
136    provider: &Arc<CudaKernelProvider>,
137) -> Result<TrackedCudaSlice<u32>> {
138    let memory = provider.memory();
139    let device = provider.device().inner();
140    let mut out = memory.alloc::<u32>(node_ids.len())?;
141    let count = node_ids.len();
142    if count == 0 {
143        return Ok(out);
144    }
145    let count_u32 = kernel_count_u32("GPU node-to-var map count", count)?;
146
147    let func = device
148        .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_MAP_NODES_TO_VARS)
149        .ok_or_else(|| {
150            XlogError::Kernel("weights_map_nodes_to_vars kernel not found".to_string())
151        })?;
152
153    let block = 256u32;
154    let grid = grid_for(count_u32, block)?;
155    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
156    unsafe {
157        func.clone().launch(
158            LaunchConfig {
159                grid_dim: (grid.max(1), 1, 1),
160                block_dim: (block, 1, 1),
161                shared_mem_bytes: 0,
162            },
163            (node_var, node_ids, count_u32, var_cap, &mut out),
164        )
165    }
166    .map_err(|e| XlogError::Kernel(format!("weights_map_nodes_to_vars failed: {}", e)))?;
167    // No device synchronize: returns device-resident slice; same-stream ordering suffices.
168    Ok(out)
169}
170
171pub fn apply_query_vars_device(
172    provider: &Arc<CudaKernelProvider>,
173    query_vars: &TrackedCudaSlice<u32>,
174    var_cap: u32,
175    log_false: &mut TrackedCudaSlice<f64>,
176    saved: &mut TrackedCudaSlice<f64>,
177) -> Result<()> {
178    let count = query_vars.len();
179    if saved.len() < count {
180        return Err(XlogError::Compilation(format!(
181            "query restore buffer len {} < query vars len {}",
182            saved.len(),
183            count
184        )));
185    }
186    let weights_len = query_weights_len_for_var_cap(var_cap)?;
187    if log_false.len() < weights_len {
188        return Err(XlogError::Compilation(format!(
189            "log_false len {} < var_cap+1 {}",
190            log_false.len(),
191            weights_len
192        )));
193    }
194    if count == 0 {
195        return Ok(());
196    }
197    let count_u32 = kernel_count_u32("GPU query apply count", count)?;
198
199    let device = provider.device().inner();
200    let func = device
201        .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_APPLY_QUERY_VARS)
202        .ok_or_else(|| {
203            XlogError::Kernel("weights_apply_query_vars kernel not found".to_string())
204        })?;
205
206    let block = 256u32;
207    let grid = grid_for(count_u32, block)?;
208    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
209    unsafe {
210        func.clone().launch(
211            LaunchConfig {
212                grid_dim: (grid.max(1), 1, 1),
213                block_dim: (block, 1, 1),
214                shared_mem_bytes: 0,
215            },
216            (query_vars, count_u32, var_cap, log_false, saved),
217        )
218    }
219    .map_err(|e| XlogError::Kernel(format!("weights_apply_query_vars failed: {}", e)))?;
220    // No device synchronize: same-stream ordering guarantees visibility to subsequent kernels.
221    Ok(())
222}
223
224pub fn restore_query_vars_device(
225    provider: &Arc<CudaKernelProvider>,
226    query_vars: &TrackedCudaSlice<u32>,
227    var_cap: u32,
228    log_false: &mut TrackedCudaSlice<f64>,
229    saved: &TrackedCudaSlice<f64>,
230) -> Result<()> {
231    let count = query_vars.len();
232    if saved.len() < count {
233        return Err(XlogError::Compilation(format!(
234            "query restore buffer len {} < query vars len {}",
235            saved.len(),
236            count
237        )));
238    }
239    let weights_len = query_weights_len_for_var_cap(var_cap)?;
240    if log_false.len() < weights_len {
241        return Err(XlogError::Compilation(format!(
242            "log_false len {} < var_cap+1 {}",
243            log_false.len(),
244            weights_len
245        )));
246    }
247    if count == 0 {
248        return Ok(());
249    }
250    let count_u32 = kernel_count_u32("GPU query restore count", count)?;
251
252    let device = provider.device().inner();
253    let func = device
254        .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_RESTORE_QUERY_VARS)
255        .ok_or_else(|| {
256            XlogError::Kernel("weights_restore_query_vars kernel not found".to_string())
257        })?;
258
259    let block = 256u32;
260    let grid = grid_for(count_u32, block)?;
261    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
262    unsafe {
263        func.clone().launch(
264            LaunchConfig {
265                grid_dim: (grid.max(1), 1, 1),
266                block_dim: (block, 1, 1),
267                shared_mem_bytes: 0,
268            },
269            (query_vars, count_u32, var_cap, log_false, saved),
270        )
271    }
272    .map_err(|e| XlogError::Kernel(format!("weights_restore_query_vars failed: {}", e)))?;
273    // No device synchronize: same-stream ordering guarantees visibility to subsequent kernels.
274    Ok(())
275}
276
277pub fn build_weights_gpu(
278    vars: &GpuCnfVarTables,
279    leaf_probs: &TrackedCudaSlice<f64>,
280    choice_true: &TrackedCudaSlice<f64>,
281    choice_false: &TrackedCudaSlice<f64>,
282    evidence_by_var: &TrackedCudaSlice<u8>,
283    provider: &Arc<CudaKernelProvider>,
284) -> Result<GpuWeights> {
285    let var_cap = vars.max_var;
286    let weights_len = weights_len_for_var_cap(var_cap)?;
287
288    if vars.leaf_var.len() < leaf_probs.len() {
289        return Err(XlogError::Compilation(format!(
290            "leaf_probs len {} exceeds leaf_var len {}",
291            leaf_probs.len(),
292            vars.leaf_var.len()
293        )));
294    }
295    if vars.choice_var.len() < choice_true.len() {
296        return Err(XlogError::Compilation(format!(
297            "choice_true len {} exceeds choice_var len {}",
298            choice_true.len(),
299            vars.choice_var.len()
300        )));
301    }
302    if choice_true.len() != choice_false.len() {
303        return Err(XlogError::Compilation(format!(
304            "choice_true len {} != choice_false len {}",
305            choice_true.len(),
306            choice_false.len()
307        )));
308    }
309    if evidence_by_var.len() != weights_len {
310        return Err(XlogError::Compilation(format!(
311            "evidence_by_var len {} != weights len {}",
312            evidence_by_var.len(),
313            weights_len
314        )));
315    }
316
317    let memory = provider.memory();
318    let device = provider.device().inner();
319    let mut log_true = memory.alloc::<f64>(weights_len)?;
320    let mut log_false = memory.alloc::<f64>(weights_len)?;
321
322    // Initialize to 0.0
323    device
324        .memset_zeros(&mut log_true)
325        .map_err(|e| XlogError::Kernel(format!("Failed to zero log_true weights: {}", e)))?;
326    device
327        .memset_zeros(&mut log_false)
328        .map_err(|e| XlogError::Kernel(format!("Failed to zero log_false weights: {}", e)))?;
329
330    let block = 256u32;
331
332    if !leaf_probs.is_empty() {
333        let leaf_count = kernel_count_u32("GPU leaf probability count", leaf_probs.len())?;
334        let func = device
335            .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_FILL_LEAF)
336            .ok_or_else(|| XlogError::Kernel("weights_fill_leaf kernel not found".to_string()))?;
337        let grid = grid_for(leaf_count, block)?;
338        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
339        unsafe {
340            func.clone().launch(
341                LaunchConfig {
342                    grid_dim: (grid.max(1), 1, 1),
343                    block_dim: (block, 1, 1),
344                    shared_mem_bytes: 0,
345                },
346                (
347                    &vars.leaf_var,
348                    leaf_probs,
349                    leaf_count,
350                    var_cap,
351                    &mut log_true,
352                    &mut log_false,
353                ),
354            )
355        }
356        .map_err(|e| XlogError::Kernel(format!("weights_fill_leaf failed: {}", e)))?;
357    }
358
359    if !choice_true.is_empty() {
360        let choice_count = kernel_count_u32("GPU choice probability count", choice_true.len())?;
361        let func = device
362            .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_FILL_CHOICE)
363            .ok_or_else(|| XlogError::Kernel("weights_fill_choice kernel not found".to_string()))?;
364        let grid = grid_for(choice_count, block)?;
365        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
366        unsafe {
367            func.clone().launch(
368                LaunchConfig {
369                    grid_dim: (grid.max(1), 1, 1),
370                    block_dim: (block, 1, 1),
371                    shared_mem_bytes: 0,
372                },
373                (
374                    &vars.choice_var,
375                    choice_true,
376                    choice_false,
377                    choice_count,
378                    var_cap,
379                    &mut log_true,
380                    &mut log_false,
381                ),
382            )
383        }
384        .map_err(|e| XlogError::Kernel(format!("weights_fill_choice failed: {}", e)))?;
385    }
386
387    if !evidence_by_var.is_empty() {
388        let var_table_count = checked_var_table_count(var_cap)?;
389        let func = device
390            .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_APPLY_EVIDENCE)
391            .ok_or_else(|| {
392                XlogError::Kernel("weights_apply_evidence kernel not found".to_string())
393            })?;
394        let grid = grid_for(var_table_count, block)?;
395        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
396        unsafe {
397            func.clone().launch(
398                LaunchConfig {
399                    grid_dim: (grid.max(1), 1, 1),
400                    block_dim: (block, 1, 1),
401                    shared_mem_bytes: 0,
402                },
403                (evidence_by_var, var_cap, &mut log_true, &mut log_false),
404            )
405        }
406        .map_err(|e| XlogError::Kernel(format!("weights_apply_evidence failed: {}", e)))?;
407    }
408    // No device synchronize: returns device-resident weights; same-stream ordering suffices.
409    Ok(GpuWeights {
410        log_true,
411        log_false,
412    })
413}
414
415#[allow(dead_code)] // reserved: host-side weight upload path for testing/diagnostics
416pub(crate) fn upload_weights_from_host(
417    provider: &Arc<CudaKernelProvider>,
418    weights: &[(f64, f64)],
419) -> Result<GpuWeights> {
420    let weights_len = weights.len();
421    let mut host_true: Vec<f64> = Vec::with_capacity(weights_len);
422    let mut host_false: Vec<f64> = Vec::with_capacity(weights_len);
423    for &(t, f) in weights {
424        host_true.push(t);
425        host_false.push(f);
426    }
427
428    let memory = provider.memory();
429    let mut log_true = memory.alloc::<f64>(weights_len)?;
430    let mut log_false = memory.alloc::<f64>(weights_len)?;
431    provider
432        .htod_sync_copy_into_tracked(&host_true, &mut log_true)
433        .map_err(|e| XlogError::Kernel(format!("Upload log_true weights failed: {}", e)))?;
434    provider
435        .htod_sync_copy_into_tracked(&host_false, &mut log_false)
436        .map_err(|e| XlogError::Kernel(format!("Upload log_false weights failed: {}", e)))?;
437
438    Ok(GpuWeights {
439        log_true,
440        log_false,
441    })
442}