1use std::sync::Arc;
4
5use cudarc::driver::{DeviceSlice, LaunchConfig};
6use xlog_core::{Result, XlogError};
7use xlog_cuda::memory::TrackedCudaSlice;
8use xlog_cuda::provider::{cache_kernels, CACHE_MODULE};
9use xlog_cuda::{AsKernelParam, CudaKernelProvider, LaunchAsync};
10use xlog_solve::GpuCnf;
11
12use super::disk_cache;
13use crate::gpu::GpuXgcf;
14
/// Sizing parameters consumed by `GpuCircuitCache::new`.
///
/// `GpuCircuitCache::new` rejects a configuration in which any field is zero
/// or in which `table_size` is smaller than `num_slots`.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct GpuCircuitCacheConfig {
    /// Number of circuit slots; per-slot device buffers are sized as a multiple of this.
    pub num_slots: u32,
    /// Number of entries in the key / slot / state / last-used tables.
    pub table_size: u32,
    /// Node capacity of a single slot.
    pub node_cap: u32,
    /// Edge (child index) capacity of a single slot.
    pub edge_cap: u32,
    /// Level capacity of a single slot.
    pub level_cap: u32,
    /// Largest variable index a slot can hold; per-slot variable buffers get `var_cap + 1` entries.
    pub var_cap: u32,
}

impl Default for GpuCircuitCacheConfig {
    /// Default sizing: 4 slots, an 8-entry table, 65 536 nodes, 131 072 edges,
    /// 65 536 levels and 128 variables per slot.
    fn default() -> Self {
        GpuCircuitCacheConfig {
            num_slots: 4,
            table_size: 8,
            node_cap: 1 << 16,
            edge_cap: 1 << 17,
            level_cap: 1 << 16,
            var_cap: 128,
        }
    }
}
53
54fn cache_grid_dim_for_u32_count(context: &str, count: u32, block_dim: u32) -> Result<u32> {
55 if count == 0 {
56 return Ok(0);
57 }
58 if block_dim == 0 {
59 return Err(XlogError::Compilation(format!(
60 "{context}: GPU cache block size must be nonzero"
61 )));
62 }
63 let padded = count
64 .checked_add(block_dim - 1)
65 .ok_or_else(|| XlogError::Compilation(format!("{context}: GPU cache grid overflow")))?;
66 Ok(padded / block_dim)
67}
68
69fn cache_grid_dim_for_u64_count(context: &str, count: u64, block_dim: u32) -> Result<u32> {
70 if count == 0 {
71 return Ok(0);
72 }
73 if block_dim == 0 {
74 return Err(XlogError::Compilation(format!(
75 "{context}: GPU cache block size must be nonzero"
76 )));
77 }
78 let block = block_dim as u64;
79 let grid = count
80 .checked_add(block - 1)
81 .map(|padded| padded / block)
82 .ok_or_else(|| XlogError::Compilation(format!("{context}: GPU cache grid overflow")))?;
83 u32::try_from(grid)
84 .map_err(|_| XlogError::Compilation(format!("{context}: GPU cache grid exceeds u32")))
85}
86
87pub struct GpuCircuitCache {
88 provider: Arc<CudaKernelProvider>,
89 table_size: u32,
90 num_slots: u32,
91 node_cap: u32,
92 edge_cap: u32,
93 level_cap: u32,
94 var_cap: u32,
95 keys: TrackedCudaSlice<u64>,
96 slots: TrackedCudaSlice<u32>,
97 state: TrackedCudaSlice<u32>,
98 last_used: TrackedCudaSlice<u64>,
99 slot_states: TrackedCudaSlice<u32>,
100 clock: TrackedCudaSlice<u64>,
101 node_type: TrackedCudaSlice<u8>,
102 child_offsets: TrackedCudaSlice<u32>,
103 child_indices: TrackedCudaSlice<u32>,
104 lit: TrackedCudaSlice<i32>,
105 decision_var: TrackedCudaSlice<u32>,
106 decision_child_false: TrackedCudaSlice<u32>,
107 decision_child_true: TrackedCudaSlice<u32>,
108 level_nodes: TrackedCudaSlice<u32>,
109 level_offsets: TrackedCudaSlice<u32>,
110 var_log_true: TrackedCudaSlice<f64>,
111 var_log_false: TrackedCudaSlice<f64>,
112 values: TrackedCudaSlice<f64>,
113 adj: TrackedCudaSlice<f64>,
114 grad_true: TrackedCudaSlice<f64>,
115 grad_false: TrackedCudaSlice<f64>,
116 meta_num_nodes: TrackedCudaSlice<u32>,
117 meta_num_levels: TrackedCudaSlice<u32>,
118 meta_root: TrackedCudaSlice<u32>,
119 meta_max_var: TrackedCudaSlice<u32>,
120 always_on: TrackedCudaSlice<u32>,
121 zero_f64: TrackedCudaSlice<f64>,
122 one_f64: TrackedCudaSlice<f64>,
123 free_var_mask: TrackedCudaSlice<u8>,
124 has_free_var_mask: Vec<bool>,
125}
126
127pub struct GpuCacheLookup {
128 provider: Arc<CudaKernelProvider>,
129 slot: TrackedCudaSlice<u32>,
130 compile_needed: TrackedCudaSlice<u32>,
131}
132
133impl GpuCacheLookup {
134 pub fn slot_device(&self) -> &TrackedCudaSlice<u32> {
135 &self.slot
136 }
137
138 pub fn compile_needed_device(&self) -> &TrackedCudaSlice<u32> {
139 &self.compile_needed
140 }
141
142 pub fn provider(&self) -> &Arc<CudaKernelProvider> {
143 &self.provider
144 }
145
146 pub fn into_handle(self) -> Result<GpuCircuitCacheHandle> {
147 let slot_host_vec: Vec<u32> = self
148 .provider
149 .device()
150 .inner()
151 .dtoh_sync_copy(&self.slot)
152 .map_err(|e| XlogError::Kernel(format!("dtoh slot index: {}", e)))?;
153 Ok(GpuCircuitCacheHandle {
154 provider: self.provider,
155 slot: self.slot,
156 compile_needed: self.compile_needed,
157 slot_host: slot_host_vec[0],
158 num_nodes: 0,
159 num_levels: 0,
160 root: 0,
161 max_var: 0,
162 })
163 }
164}
165
166pub struct GpuCircuitCacheHandle {
167 provider: Arc<CudaKernelProvider>,
168 slot: TrackedCudaSlice<u32>,
169 compile_needed: TrackedCudaSlice<u32>,
170 slot_host: u32,
171 num_nodes: u32,
172 num_levels: u32,
173 root: u32,
174 max_var: u32,
175}
176
177impl GpuCircuitCacheHandle {
178 pub fn slot_device(&self) -> &TrackedCudaSlice<u32> {
179 &self.slot
180 }
181
182 pub fn compile_needed_device(&self) -> &TrackedCudaSlice<u32> {
183 &self.compile_needed
184 }
185
186 pub fn provider(&self) -> &Arc<CudaKernelProvider> {
187 &self.provider
188 }
189
190 pub fn num_nodes(&self) -> u32 {
191 self.num_nodes
192 }
193
194 pub fn num_levels(&self) -> u32 {
195 self.num_levels
196 }
197
198 pub fn root(&self) -> u32 {
199 self.root
200 }
201
202 pub fn max_var(&self) -> u32 {
203 self.max_var
204 }
205
206 pub(crate) fn slot_index(&self) -> u32 {
207 self.slot_host
208 }
209
210 #[allow(dead_code)] pub(crate) fn set_meta(&mut self, num_nodes: u32, num_levels: u32, root: u32, max_var: u32) {
212 self.num_nodes = num_nodes;
213 self.num_levels = num_levels;
214 self.root = root;
215 self.max_var = max_var;
216 }
217}
218
219pub fn hash_cnf_gpu(
224 cnf: &GpuCnf,
225 provider: &Arc<CudaKernelProvider>,
226) -> Result<TrackedCudaSlice<u64>> {
227 let memory = provider.memory();
228 let mut out_hash = memory.alloc::<u64>(1)?;
229
230 let func = provider
231 .device()
232 .inner()
233 .get_func(CACHE_MODULE, cache_kernels::CACHE_CNF_HASH)
234 .ok_or_else(|| XlogError::Kernel("cache_cnf_hash kernel not found".to_string()))?;
235
236 unsafe {
238 func.clone().launch(
239 LaunchConfig {
240 grid_dim: (1, 1, 1),
241 block_dim: (1, 1, 1),
242 shared_mem_bytes: 0,
243 },
244 (
245 &cnf.num_vars,
246 &cnf.num_clauses,
247 &cnf.num_lits,
248 &cnf.clause_offsets,
249 &cnf.literals,
250 &mut out_hash,
251 ),
252 )
253 }
254 .map_err(|e| XlogError::Kernel(format!("cache_cnf_hash launch failed: {}", e)))?;
255 Ok(out_hash)
257}
258
259impl GpuCircuitCache {
260 pub fn provider(&self) -> &Arc<CudaKernelProvider> {
261 &self.provider
262 }
263
264 pub fn var_log_weights_mut(
265 &mut self,
266 ) -> (&mut TrackedCudaSlice<f64>, &mut TrackedCudaSlice<f64>) {
267 (&mut self.var_log_true, &mut self.var_log_false)
268 }
269
270 pub fn grad_true(&self) -> &TrackedCudaSlice<f64> {
271 &self.grad_true
272 }
273
274 pub fn grad_false(&self) -> &TrackedCudaSlice<f64> {
275 &self.grad_false
276 }
277
278 pub fn values(&self) -> &TrackedCudaSlice<f64> {
279 &self.values
280 }
281
282 pub fn meta_num_nodes_device(&self) -> &TrackedCudaSlice<u32> {
283 &self.meta_num_nodes
284 }
285
286 pub fn meta_num_levels_device(&self) -> &TrackedCudaSlice<u32> {
287 &self.meta_num_levels
288 }
289
290 pub fn meta_root_device(&self) -> &TrackedCudaSlice<u32> {
291 &self.meta_root
292 }
293
294 pub fn meta_max_var_device(&self) -> &TrackedCudaSlice<u32> {
295 &self.meta_max_var
296 }
297
298 pub fn num_slots(&self) -> u32 {
299 self.num_slots
300 }
301
302 pub(crate) fn has_any_free_var_mask(&self) -> bool {
303 self.has_free_var_mask.iter().any(|&v| v)
304 }
305
306 pub(crate) fn has_free_var_mask_for_slot(&self, slot: u32) -> bool {
307 self.has_free_var_mask
308 .get(slot as usize)
309 .copied()
310 .unwrap_or(false)
311 }
312
313 pub(crate) fn var_stride(&self) -> Result<u32> {
314 self.var_cap
315 .checked_add(1)
316 .ok_or_else(|| XlogError::Compilation("GpuCircuitCache var_cap overflow".to_string()))
317 }
318
319 pub(crate) fn node_stride(&self) -> u32 {
320 self.node_cap
321 }
322
323 pub(crate) fn copy_slot_weights_to_batch(
324 &mut self,
325 handle: &GpuCircuitCacheHandle,
326 out_true_batch: &mut TrackedCudaSlice<f64>,
327 out_false_batch: &mut TrackedCudaSlice<f64>,
328 batch_size: u32,
329 ) -> Result<()> {
330 if batch_size == 0 {
331 return Ok(());
332 }
333 let var_stride = self.var_stride()?;
334 let expected = (batch_size as usize)
335 .checked_mul(var_stride as usize)
336 .ok_or_else(|| {
337 XlogError::Compilation("GpuCircuitCache batch weight size overflow".to_string())
338 })?;
339 if out_true_batch.len() != expected || out_false_batch.len() != expected {
340 return Err(XlogError::Compilation(format!(
341 "GpuCircuitCache batched weight buffers must both have len {}, got {} and {}",
342 expected,
343 out_true_batch.len(),
344 out_false_batch.len()
345 )));
346 }
347
348 let device = self.provider.device().inner();
349 let func = device
350 .get_func(
351 xlog_cuda::provider::WEIGHTS_MODULE,
352 xlog_cuda::provider::weights_kernels::WEIGHTS_COPY_SLOT_TO_BATCH,
353 )
354 .ok_or_else(|| {
355 XlogError::Kernel("weights_copy_slot_to_batch kernel not found".to_string())
356 })?;
357
358 let block_dim = 256u32;
359 let total = (batch_size as u64)
360 .checked_mul(var_stride as u64)
361 .ok_or_else(|| {
362 XlogError::Compilation("GpuCircuitCache batch copy overflow".to_string())
363 })?;
364 let grid_dim =
365 cache_grid_dim_for_u64_count("GpuCircuitCache batch weight copy", total, block_dim)?;
366 if grid_dim == 0 {
367 return Ok(());
368 }
369
370 unsafe {
372 func.clone().launch(
373 LaunchConfig {
374 grid_dim: (grid_dim, 1, 1),
375 block_dim: (block_dim, 1, 1),
376 shared_mem_bytes: 0,
377 },
378 (
379 handle.slot_device(),
380 self.var_cap,
381 &self.var_log_true,
382 &self.var_log_false,
383 out_true_batch,
384 out_false_batch,
385 var_stride,
386 batch_size,
387 ),
388 )
389 }
390 .map_err(|e| XlogError::Kernel(format!("weights_copy_slot_to_batch failed: {}", e)))?;
391
392 Ok(())
393 }
394
395 #[allow(clippy::too_many_arguments)]
396 pub(crate) fn eval_grads_inplace_fused_batched(
397 &mut self,
398 handle: &GpuCircuitCacheHandle,
399 var_log_true_batch: &TrackedCudaSlice<f64>,
400 var_log_false_batch: &TrackedCudaSlice<f64>,
401 values_batch: &mut TrackedCudaSlice<f64>,
402 adj_batch: &mut TrackedCudaSlice<f64>,
403 grad_true_batch: &mut TrackedCudaSlice<f64>,
404 grad_false_batch: &mut TrackedCudaSlice<f64>,
405 batch_size: u32,
406 ) -> Result<()> {
407 if batch_size == 0 {
408 return Ok(());
409 }
410 if self.has_free_var_mask_for_slot(handle.slot_index()) {
411 return Err(XlogError::Execution(
412 "Batched fused eval currently does not support free-var correction".to_string(),
413 ));
414 }
415
416 let var_stride = self.var_stride()?;
417 let node_stride = self.node_stride();
418 let expected_var = (batch_size as usize)
419 .checked_mul(var_stride as usize)
420 .ok_or_else(|| {
421 XlogError::Compilation("GpuCircuitCache batched var buffer overflow".to_string())
422 })?;
423 let expected_node = (batch_size as usize)
424 .checked_mul(node_stride as usize)
425 .ok_or_else(|| {
426 XlogError::Compilation("GpuCircuitCache batched node buffer overflow".to_string())
427 })?;
428
429 if var_log_true_batch.len() != expected_var
430 || var_log_false_batch.len() != expected_var
431 || grad_true_batch.len() != expected_var
432 || grad_false_batch.len() != expected_var
433 {
434 return Err(XlogError::Compilation(format!(
435 "GpuCircuitCache batched var buffers must have len {}",
436 expected_var
437 )));
438 }
439 if values_batch.len() != expected_node || adj_batch.len() != expected_node {
440 return Err(XlogError::Compilation(format!(
441 "GpuCircuitCache batched node buffers must have len {}",
442 expected_node
443 )));
444 }
445
446 let device = self.provider.device().inner();
447 device
448 .memset_zeros(adj_batch)
449 .map_err(|e| XlogError::Kernel(format!("Failed to zero batched adj: {}", e)))?;
450 device
451 .memset_zeros(grad_true_batch)
452 .map_err(|e| XlogError::Kernel(format!("Failed to zero batched grad_true: {}", e)))?;
453 device
454 .memset_zeros(grad_false_batch)
455 .map_err(|e| XlogError::Kernel(format!("Failed to zero batched grad_false: {}", e)))?;
456
457 let eval_all = device
458 .get_func(
459 xlog_cuda::CIRCUIT_MODULE,
460 xlog_cuda::circuit_kernels::XGCF_EVAL_ALL_LEVELS_CACHED_BATCHED,
461 )
462 .ok_or_else(|| {
463 XlogError::Kernel("xgcf_eval_all_levels_cached_batched not found".to_string())
464 })?;
465 let set_root_adj = device
466 .get_func(
467 xlog_cuda::CIRCUIT_MODULE,
468 xlog_cuda::circuit_kernels::XGCF_SET_ROOT_ADJ_CACHED_BATCHED,
469 )
470 .ok_or_else(|| {
471 XlogError::Kernel("xgcf_set_root_adj_cached_batched not found".to_string())
472 })?;
473 let backward_all = device
474 .get_func(
475 xlog_cuda::CIRCUIT_MODULE,
476 xlog_cuda::circuit_kernels::XGCF_BACKWARD_ALL_LEVELS_CACHED_BATCHED,
477 )
478 .ok_or_else(|| {
479 XlogError::Kernel("xgcf_backward_all_levels_cached_batched not found".to_string())
480 })?;
481
482 let block_size = 256u32;
483 let mut eval_params: Vec<*mut std::ffi::c_void> = vec![
484 handle.slot_device().as_kernel_param(),
485 self.node_cap.as_kernel_param(),
486 self.edge_cap.as_kernel_param(),
487 self.level_cap.as_kernel_param(),
488 self.var_cap.as_kernel_param(),
489 (&self.node_type).as_kernel_param(),
490 (&self.child_offsets).as_kernel_param(),
491 (&self.child_indices).as_kernel_param(),
492 (&self.lit).as_kernel_param(),
493 (&self.decision_var).as_kernel_param(),
494 (&self.decision_child_false).as_kernel_param(),
495 (&self.decision_child_true).as_kernel_param(),
496 (&self.level_nodes).as_kernel_param(),
497 (&self.level_offsets).as_kernel_param(),
498 (&self.meta_num_levels).as_kernel_param(),
499 var_log_true_batch.as_kernel_param(),
500 var_log_false_batch.as_kernel_param(),
501 var_stride.as_kernel_param(),
502 values_batch.as_kernel_param(),
503 node_stride.as_kernel_param(),
504 batch_size.as_kernel_param(),
505 ];
506 unsafe {
508 eval_all.clone().launch(
509 LaunchConfig {
510 grid_dim: (batch_size, 1, 1),
511 block_dim: (block_size, 1, 1),
512 shared_mem_bytes: 0,
513 },
514 &mut eval_params,
515 )
516 }
517 .map_err(|e| {
518 XlogError::Kernel(format!("xgcf_eval_all_levels_cached_batched failed: {}", e))
519 })?;
520
521 unsafe {
523 set_root_adj.clone().launch(
524 LaunchConfig {
525 grid_dim: (batch_size, 1, 1),
526 block_dim: (1, 1, 1),
527 shared_mem_bytes: 0,
528 },
529 (
530 handle.slot_device(),
531 self.node_cap,
532 &self.meta_root,
533 &mut *adj_batch,
534 node_stride,
535 batch_size,
536 ),
537 )
538 }
539 .map_err(|e| {
540 XlogError::Kernel(format!("xgcf_set_root_adj_cached_batched failed: {}", e))
541 })?;
542
543 let mut backward_params: Vec<*mut std::ffi::c_void> = vec![
544 handle.slot_device().as_kernel_param(),
545 self.node_cap.as_kernel_param(),
546 self.edge_cap.as_kernel_param(),
547 self.level_cap.as_kernel_param(),
548 self.var_cap.as_kernel_param(),
549 (&self.node_type).as_kernel_param(),
550 (&self.child_offsets).as_kernel_param(),
551 (&self.child_indices).as_kernel_param(),
552 (&self.decision_var).as_kernel_param(),
553 (&self.decision_child_false).as_kernel_param(),
554 (&self.decision_child_true).as_kernel_param(),
555 (&self.lit).as_kernel_param(),
556 (&self.level_nodes).as_kernel_param(),
557 (&self.level_offsets).as_kernel_param(),
558 (&self.meta_num_levels).as_kernel_param(),
559 var_log_true_batch.as_kernel_param(),
560 var_log_false_batch.as_kernel_param(),
561 var_stride.as_kernel_param(),
562 values_batch.as_kernel_param(),
563 node_stride.as_kernel_param(),
564 adj_batch.as_kernel_param(),
565 node_stride.as_kernel_param(),
566 grad_true_batch.as_kernel_param(),
567 grad_false_batch.as_kernel_param(),
568 var_stride.as_kernel_param(),
569 batch_size.as_kernel_param(),
570 ];
571 unsafe {
573 backward_all.clone().launch(
574 LaunchConfig {
575 grid_dim: (batch_size, 1, 1),
576 block_dim: (block_size, 1, 1),
577 shared_mem_bytes: 0,
578 },
579 &mut backward_params,
580 )
581 }
582 .map_err(|e| {
583 XlogError::Kernel(format!(
584 "xgcf_backward_all_levels_cached_batched failed: {}",
585 e
586 ))
587 })?;
588
589 Ok(())
590 }
591
592 pub(crate) fn copy_root_batched_from_values(
593 &self,
594 handle: &GpuCircuitCacheHandle,
595 values_batch: &TrackedCudaSlice<f64>,
596 out_roots: &mut TrackedCudaSlice<f64>,
597 batch_size: u32,
598 ) -> Result<()> {
599 if batch_size == 0 {
600 return Ok(());
601 }
602 let node_stride = self.node_stride();
603 let expected_values = (batch_size as usize)
604 .checked_mul(node_stride as usize)
605 .ok_or_else(|| {
606 XlogError::Compilation("GpuCircuitCache batched values overflow".to_string())
607 })?;
608 if values_batch.len() != expected_values || out_roots.len() != batch_size as usize {
609 return Err(XlogError::Compilation(format!(
610 "GpuCircuitCache root copy expects values len {} and roots len {}, got {} and {}",
611 expected_values,
612 batch_size,
613 values_batch.len(),
614 out_roots.len()
615 )));
616 }
617
618 let device = self.provider.device().inner();
619 let copy_root = device
620 .get_func(
621 xlog_cuda::CIRCUIT_MODULE,
622 xlog_cuda::circuit_kernels::XGCF_COPY_ROOT_CACHED_META_BATCHED,
623 )
624 .ok_or_else(|| {
625 XlogError::Kernel("xgcf_copy_root_cached_meta_batched not found".to_string())
626 })?;
627 unsafe {
629 copy_root.clone().launch(
630 LaunchConfig {
631 grid_dim: (batch_size, 1, 1),
632 block_dim: (1, 1, 1),
633 shared_mem_bytes: 0,
634 },
635 (
636 handle.slot_device(),
637 self.node_cap,
638 &self.meta_root,
639 values_batch,
640 node_stride,
641 out_roots,
642 batch_size,
643 ),
644 )
645 }
646 .map_err(|e| {
647 XlogError::Kernel(format!("xgcf_copy_root_cached_meta_batched failed: {}", e))
648 })?;
649 Ok(())
650 }
651
652 pub fn new(provider: &Arc<CudaKernelProvider>, config: GpuCircuitCacheConfig) -> Result<Self> {
653 if config.num_slots == 0 {
654 return Err(XlogError::Compilation(
655 "GpuCircuitCache requires num_slots > 0".to_string(),
656 ));
657 }
658 if config.table_size == 0 {
659 return Err(XlogError::Compilation(
660 "GpuCircuitCache requires table_size > 0".to_string(),
661 ));
662 }
663 if config.table_size < config.num_slots {
664 return Err(XlogError::Compilation(format!(
665 "GpuCircuitCache table_size {} < num_slots {}",
666 config.table_size, config.num_slots
667 )));
668 }
669 if config.node_cap == 0
670 || config.edge_cap == 0
671 || config.level_cap == 0
672 || config.var_cap == 0
673 {
674 return Err(XlogError::Compilation(
675 "GpuCircuitCache requires non-zero caps".to_string(),
676 ));
677 }
678
679 let memory = provider.memory();
680 let device = provider.device().inner();
681
682 let table_len = usize::try_from(config.table_size).map_err(|_| {
683 XlogError::Compilation("GpuCircuitCache table_size overflow".to_string())
684 })?;
685 let slot_len = usize::try_from(config.num_slots).map_err(|_| {
686 XlogError::Compilation("GpuCircuitCache num_slots overflow".to_string())
687 })?;
688
689 let node_cap = usize::try_from(config.node_cap)
690 .map_err(|_| XlogError::Compilation("GpuCircuitCache node_cap overflow".to_string()))?;
691 let edge_cap = usize::try_from(config.edge_cap)
692 .map_err(|_| XlogError::Compilation("GpuCircuitCache edge_cap overflow".to_string()))?;
693 let level_cap = usize::try_from(config.level_cap).map_err(|_| {
694 XlogError::Compilation("GpuCircuitCache level_cap overflow".to_string())
695 })?;
696 let var_cap = usize::try_from(config.var_cap)
697 .map_err(|_| XlogError::Compilation("GpuCircuitCache var_cap overflow".to_string()))?;
698
699 let node_slots = slot_len.checked_mul(node_cap).ok_or_else(|| {
700 XlogError::Compilation("GpuCircuitCache node slots overflow".to_string())
701 })?;
702 let edge_slots = slot_len.checked_mul(edge_cap).ok_or_else(|| {
703 XlogError::Compilation("GpuCircuitCache edge slots overflow".to_string())
704 })?;
705 let var_slots = slot_len.checked_mul(var_cap + 1).ok_or_else(|| {
706 XlogError::Compilation("GpuCircuitCache var slots overflow".to_string())
707 })?;
708 let node_offsets = slot_len.checked_mul(node_cap + 1).ok_or_else(|| {
709 XlogError::Compilation("GpuCircuitCache offset slots overflow".to_string())
710 })?;
711 let level_offsets = slot_len.checked_mul(level_cap + 1).ok_or_else(|| {
712 XlogError::Compilation("GpuCircuitCache level offsets overflow".to_string())
713 })?;
714
715 let mut keys = memory.alloc::<u64>(table_len)?;
716 let mut slots = memory.alloc::<u32>(table_len)?;
717 let mut state = memory.alloc::<u32>(table_len)?;
718 let mut last_used = memory.alloc::<u64>(table_len)?;
719 let mut slot_states = memory.alloc::<u32>(slot_len)?;
720 let mut clock = memory.alloc::<u64>(1)?;
721
722 let mut node_type = memory.alloc::<u8>(node_slots)?;
723 let mut child_offsets = memory.alloc::<u32>(node_offsets)?;
724 let mut child_indices = memory.alloc::<u32>(edge_slots)?;
725 let mut lit = memory.alloc::<i32>(node_slots)?;
726 let mut decision_var = memory.alloc::<u32>(node_slots)?;
727 let mut decision_child_false = memory.alloc::<u32>(node_slots)?;
728 let mut decision_child_true = memory.alloc::<u32>(node_slots)?;
729 let mut level_nodes = memory.alloc::<u32>(node_slots)?;
730 let mut level_offsets = memory.alloc::<u32>(level_offsets)?;
731
732 let mut var_log_true = memory.alloc::<f64>(var_slots)?;
733 let mut var_log_false = memory.alloc::<f64>(var_slots)?;
734 let mut values = memory.alloc::<f64>(node_slots)?;
735 let mut adj = memory.alloc::<f64>(node_slots)?;
736 let mut grad_true = memory.alloc::<f64>(var_slots)?;
737 let mut grad_false = memory.alloc::<f64>(var_slots)?;
738 let mut free_var_mask = memory.alloc::<u8>(var_slots)?;
739 let mut meta_num_nodes = memory.alloc::<u32>(slot_len)?;
740 let mut meta_num_levels = memory.alloc::<u32>(slot_len)?;
741 let mut meta_root = memory.alloc::<u32>(slot_len)?;
742 let mut meta_max_var = memory.alloc::<u32>(slot_len)?;
743 let mut always_on = memory.alloc::<u32>(1)?;
744 let zero_len = node_cap.max(var_cap + 1);
745 let mut zero_f64 = memory.alloc::<f64>(zero_len)?;
746 let mut one_f64 = memory.alloc::<f64>(1)?;
747
748 device
749 .memset_zeros(&mut keys)
750 .map_err(|e| XlogError::Kernel(format!("GpuCircuitCache zero keys failed: {}", e)))?;
751 device
752 .memset_zeros(&mut slots)
753 .map_err(|e| XlogError::Kernel(format!("GpuCircuitCache zero slots failed: {}", e)))?;
754 device
755 .memset_zeros(&mut state)
756 .map_err(|e| XlogError::Kernel(format!("GpuCircuitCache zero state failed: {}", e)))?;
757 device.memset_zeros(&mut last_used).map_err(|e| {
758 XlogError::Kernel(format!("GpuCircuitCache zero last_used failed: {}", e))
759 })?;
760 device.memset_zeros(&mut slot_states).map_err(|e| {
761 XlogError::Kernel(format!("GpuCircuitCache zero slot_states failed: {}", e))
762 })?;
763 device
764 .memset_zeros(&mut clock)
765 .map_err(|e| XlogError::Kernel(format!("GpuCircuitCache zero clock failed: {}", e)))?;
766
767 device.memset_zeros(&mut node_type).map_err(|e| {
768 XlogError::Kernel(format!("GpuCircuitCache zero node_type failed: {}", e))
769 })?;
770 device.memset_zeros(&mut child_offsets).map_err(|e| {
771 XlogError::Kernel(format!("GpuCircuitCache zero child_offsets failed: {}", e))
772 })?;
773 device.memset_zeros(&mut child_indices).map_err(|e| {
774 XlogError::Kernel(format!("GpuCircuitCache zero child_indices failed: {}", e))
775 })?;
776 device
777 .memset_zeros(&mut lit)
778 .map_err(|e| XlogError::Kernel(format!("GpuCircuitCache zero lit failed: {}", e)))?;
779 device.memset_zeros(&mut decision_var).map_err(|e| {
780 XlogError::Kernel(format!("GpuCircuitCache zero decision_var failed: {}", e))
781 })?;
782 device
783 .memset_zeros(&mut decision_child_false)
784 .map_err(|e| {
785 XlogError::Kernel(format!(
786 "GpuCircuitCache zero decision_child_false failed: {}",
787 e
788 ))
789 })?;
790 device.memset_zeros(&mut decision_child_true).map_err(|e| {
791 XlogError::Kernel(format!(
792 "GpuCircuitCache zero decision_child_true failed: {}",
793 e
794 ))
795 })?;
796 device.memset_zeros(&mut level_nodes).map_err(|e| {
797 XlogError::Kernel(format!("GpuCircuitCache zero level_nodes failed: {}", e))
798 })?;
799 device.memset_zeros(&mut level_offsets).map_err(|e| {
800 XlogError::Kernel(format!("GpuCircuitCache zero level_offsets failed: {}", e))
801 })?;
802 device.memset_zeros(&mut var_log_true).map_err(|e| {
803 XlogError::Kernel(format!("GpuCircuitCache zero var_log_true failed: {}", e))
804 })?;
805 device.memset_zeros(&mut var_log_false).map_err(|e| {
806 XlogError::Kernel(format!("GpuCircuitCache zero var_log_false failed: {}", e))
807 })?;
808 device
809 .memset_zeros(&mut values)
810 .map_err(|e| XlogError::Kernel(format!("GpuCircuitCache zero values failed: {}", e)))?;
811 device
812 .memset_zeros(&mut adj)
813 .map_err(|e| XlogError::Kernel(format!("GpuCircuitCache zero adj failed: {}", e)))?;
814 device.memset_zeros(&mut grad_true).map_err(|e| {
815 XlogError::Kernel(format!("GpuCircuitCache zero grad_true failed: {}", e))
816 })?;
817 device.memset_zeros(&mut grad_false).map_err(|e| {
818 XlogError::Kernel(format!("GpuCircuitCache zero grad_false failed: {}", e))
819 })?;
820 device.memset_zeros(&mut free_var_mask).map_err(|e| {
821 XlogError::Kernel(format!("GpuCircuitCache zero free_var_mask failed: {}", e))
822 })?;
823 device.memset_zeros(&mut meta_num_nodes).map_err(|e| {
824 XlogError::Kernel(format!("GpuCircuitCache zero meta_num_nodes failed: {}", e))
825 })?;
826 device.memset_zeros(&mut meta_num_levels).map_err(|e| {
827 XlogError::Kernel(format!(
828 "GpuCircuitCache zero meta_num_levels failed: {}",
829 e
830 ))
831 })?;
832 device.memset_zeros(&mut meta_root).map_err(|e| {
833 XlogError::Kernel(format!("GpuCircuitCache zero meta_root failed: {}", e))
834 })?;
835 device.memset_zeros(&mut meta_max_var).map_err(|e| {
836 XlogError::Kernel(format!("GpuCircuitCache zero meta_max_var failed: {}", e))
837 })?;
838 device.memset_zeros(&mut zero_f64).map_err(|e| {
839 XlogError::Kernel(format!("GpuCircuitCache zero zero_f64 failed: {}", e))
840 })?;
841 provider
842 .htod_launch_metadata_sync_copy_into(&[1u32], &mut always_on)
843 .map_err(|e| {
844 XlogError::Kernel(format!("GpuCircuitCache init always_on failed: {}", e))
845 })?;
846 provider
847 .htod_launch_metadata_sync_copy_into(&[1.0f64], &mut one_f64)
848 .map_err(|e| {
849 XlogError::Kernel(format!("GpuCircuitCache init one_f64 failed: {}", e))
850 })?;
851
852 Ok(Self {
853 provider: provider.clone(),
854 table_size: config.table_size,
855 num_slots: config.num_slots,
856 node_cap: config.node_cap,
857 edge_cap: config.edge_cap,
858 level_cap: config.level_cap,
859 var_cap: config.var_cap,
860 keys,
861 slots,
862 state,
863 last_used,
864 slot_states,
865 clock,
866 node_type,
867 child_offsets,
868 child_indices,
869 lit,
870 decision_var,
871 decision_child_false,
872 decision_child_true,
873 level_nodes,
874 level_offsets,
875 var_log_true,
876 var_log_false,
877 values,
878 adj,
879 grad_true,
880 grad_false,
881 meta_num_nodes,
882 meta_num_levels,
883 meta_root,
884 meta_max_var,
885 always_on,
886 zero_f64,
887 one_f64,
888 free_var_mask,
889 has_free_var_mask: vec![false; config.num_slots as usize],
890 })
891 }
892
893 pub fn lookup_or_insert(&mut self, key: u64) -> Result<GpuCacheLookup> {
894 let memory = self.provider.memory();
895 let mut key_device = memory.alloc::<u64>(1)?;
896 self.provider
897 .htod_launch_metadata_sync_copy_into(&[key], &mut key_device)
898 .map_err(|e| XlogError::Kernel(format!("cache upload key failed: {}", e)))?;
899 self.lookup_or_insert_device(&key_device)
900 }
901
902 pub(crate) fn lookup_or_insert_device(
903 &mut self,
904 key_device: &TrackedCudaSlice<u64>,
905 ) -> Result<GpuCacheLookup> {
906 let memory = self.provider.memory();
907 let mut out_slot = memory.alloc::<u32>(1)?;
908 let mut out_compile_needed = memory.alloc::<u32>(1)?;
909
910 let func = self
911 .provider
912 .device()
913 .inner()
914 .get_func(CACHE_MODULE, cache_kernels::CACHE_LOOKUP_OR_INSERT)
915 .ok_or_else(|| {
916 XlogError::Kernel("cache_lookup_or_insert kernel not found".to_string())
917 })?;
918
919 unsafe {
921 func.clone().launch(
922 LaunchConfig {
923 grid_dim: (1, 1, 1),
924 block_dim: (1, 1, 1),
925 shared_mem_bytes: 0,
926 },
927 (
928 key_device,
929 self.table_size,
930 self.num_slots,
931 &mut self.keys,
932 &mut self.slots,
933 &mut self.state,
934 &mut self.last_used,
935 &mut self.slot_states,
936 &mut self.clock,
937 &mut out_slot,
938 &mut out_compile_needed,
939 ),
940 )
941 }
942 .map_err(|e| XlogError::Kernel(format!("cache_lookup_or_insert failed: {}", e)))?;
943 Ok(GpuCacheLookup {
945 provider: self.provider.clone(),
946 slot: out_slot,
947 compile_needed: out_compile_needed,
948 })
949 }
950
951 pub fn claim_slot(&mut self, key: u64) -> Result<GpuCircuitCacheHandle> {
952 let lookup = self.lookup_or_insert(key)?;
953 lookup.into_handle()
954 }
955
956 pub fn store_from_xgcf(
957 &mut self,
958 handle: &mut GpuCircuitCacheHandle,
959 xgcf: &GpuXgcf,
960 ) -> Result<()> {
961 let device = self.provider.device().inner();
966 let num_nodes_host: Vec<u32> = device
967 .dtoh_sync_copy(xgcf.num_nodes_device())
968 .map_err(|e| XlogError::Kernel(format!("dtoh meta_num_nodes: {}", e)))?;
969 let num_nodes = num_nodes_host[0];
970 if num_nodes == 0 {
971 return Err(XlogError::Compilation(
972 "GpuCircuitCache store: num_nodes must be > 0".to_string(),
973 ));
974 }
975 if num_nodes > self.node_cap {
976 return Err(XlogError::Compilation(format!(
977 "GpuCircuitCache store: num_nodes {} exceeds node_cap {}",
978 num_nodes, self.node_cap
979 )));
980 }
981
982 let num_edges_host: Vec<u32> = device
983 .dtoh_sync_copy(xgcf.num_edges_device())
984 .map_err(|e| XlogError::Kernel(format!("dtoh meta_num_edges: {}", e)))?;
985 let num_edges = num_edges_host[0];
986 if num_edges > self.edge_cap {
987 return Err(XlogError::Compilation(format!(
988 "GpuCircuitCache store: num_edges {} exceeds edge_cap {}",
989 num_edges, self.edge_cap
990 )));
991 }
992
993 let num_levels = xgcf.num_levels();
994 if num_levels == 0 {
995 return Err(XlogError::Compilation(
996 "GpuCircuitCache store: num_levels must be > 0".to_string(),
997 ));
998 }
999 if num_levels > self.level_cap {
1000 return Err(XlogError::Compilation(format!(
1001 "GpuCircuitCache store: num_levels {} exceeds level_cap {}",
1002 num_levels, self.level_cap
1003 )));
1004 }
1005
1006 let root = xgcf.root();
1007 if root >= num_nodes {
1008 return Err(XlogError::Compilation(format!(
1009 "GpuCircuitCache store: root {} out of bounds (num_nodes={})",
1010 root, num_nodes
1011 )));
1012 }
1013
1014 let max_var = xgcf.max_var();
1015 if max_var > self.var_cap {
1016 return Err(XlogError::Compilation(format!(
1017 "GpuCircuitCache store: max_var {} exceeds var_cap {}",
1018 max_var, self.var_cap
1019 )));
1020 }
1021
1022 let expected_child_offsets = (num_nodes as usize) + 1;
1023 if xgcf.child_offsets().len() < expected_child_offsets {
1024 return Err(XlogError::Compilation(format!(
1025 "GpuCircuitCache store: child_offsets len {} < num_nodes+1 {}",
1026 xgcf.child_offsets().len(),
1027 expected_child_offsets
1028 )));
1029 }
1030 if xgcf.level_nodes().len() < num_nodes as usize {
1031 return Err(XlogError::Compilation(format!(
1032 "GpuCircuitCache store: level_nodes len {} < num_nodes {}",
1033 xgcf.level_nodes().len(),
1034 num_nodes
1035 )));
1036 }
1037 let expected_level_offsets = (num_levels as usize) + 1;
1038 if xgcf.level_offsets().len() != expected_level_offsets {
1039 return Err(XlogError::Compilation(format!(
1040 "GpuCircuitCache store: level_offsets len {} != num_levels+1 {}",
1041 xgcf.level_offsets().len(),
1042 expected_level_offsets
1043 )));
1044 }
1045
1046 handle.num_nodes = num_nodes;
1047 handle.num_levels = num_levels;
1048 handle.root = root;
1049 handle.max_var = max_var;
1050
1051 let store_u8 = device
1052 .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_U8)
1053 .ok_or_else(|| XlogError::Kernel("cache_store_u8 kernel not found".to_string()))?;
1054 let store_u32 = device
1055 .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_U32)
1056 .ok_or_else(|| XlogError::Kernel("cache_store_u32 kernel not found".to_string()))?;
1057 let store_i32 = device
1058 .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_I32)
1059 .ok_or_else(|| XlogError::Kernel("cache_store_i32 kernel not found".to_string()))?;
1060 let store_f64 = device
1061 .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_F64)
1062 .ok_or_else(|| XlogError::Kernel("cache_store_f64 kernel not found".to_string()))?;
1063 let store_meta = device
1064 .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_META)
1065 .ok_or_else(|| XlogError::Kernel("cache_store_meta kernel not found".to_string()))?;
1066
1067 let block_dim = 256u32;
1068
1069 let node_stride = self.node_cap;
1070 let offset_stride = self.node_cap.checked_add(1).ok_or_else(|| {
1071 XlogError::Compilation("GpuCircuitCache store: node_cap overflow".to_string())
1072 })?;
1073 let level_offset_stride = self.level_cap.checked_add(1).ok_or_else(|| {
1074 XlogError::Compilation("GpuCircuitCache store: level_cap overflow".to_string())
1075 })?;
1076 let var_stride = self.var_cap.checked_add(1).ok_or_else(|| {
1077 XlogError::Compilation("GpuCircuitCache store: var_cap overflow".to_string())
1078 })?;
1079
1080 let num_nodes_plus1 = num_nodes.checked_add(1).ok_or_else(|| {
1081 XlogError::Compilation("GpuCircuitCache store: num_nodes overflow".to_string())
1082 })?;
1083 let num_levels_plus1 = num_levels.checked_add(1).ok_or_else(|| {
1084 XlogError::Compilation("GpuCircuitCache store: num_levels overflow".to_string())
1085 })?;
1086 let weights_len = max_var.checked_add(1).ok_or_else(|| {
1087 XlogError::Compilation("GpuCircuitCache store: max_var overflow".to_string())
1088 })?;
1089
1090 let grid_nodes =
1091 cache_grid_dim_for_u32_count("GpuCircuitCache store node_type", num_nodes, block_dim)?;
1092 if grid_nodes != 0 {
1093 unsafe {
1095 store_u8.clone().launch(
1096 LaunchConfig {
1097 grid_dim: (grid_nodes, 1, 1),
1098 block_dim: (block_dim, 1, 1),
1099 shared_mem_bytes: 0,
1100 },
1101 (
1102 handle.slot_device(),
1103 handle.compile_needed_device(),
1104 node_stride,
1105 xgcf.node_type(),
1106 &mut self.node_type,
1107 num_nodes,
1108 ),
1109 )
1110 }
1111 .map_err(|e| XlogError::Kernel(format!("cache_store_u8 failed: {}", e)))?;
1112 }
1113
1114 let grid_offsets = cache_grid_dim_for_u32_count(
1115 "GpuCircuitCache store child_offsets",
1116 num_nodes_plus1,
1117 block_dim,
1118 )?;
1119 if grid_offsets != 0 {
1120 unsafe {
1122 store_u32.clone().launch(
1123 LaunchConfig {
1124 grid_dim: (grid_offsets, 1, 1),
1125 block_dim: (block_dim, 1, 1),
1126 shared_mem_bytes: 0,
1127 },
1128 (
1129 handle.slot_device(),
1130 handle.compile_needed_device(),
1131 offset_stride,
1132 xgcf.child_offsets(),
1133 &mut self.child_offsets,
1134 num_nodes_plus1,
1135 ),
1136 )
1137 }
1138 .map_err(|e| XlogError::Kernel(format!("cache_store_child_offsets failed: {}", e)))?;
1139 }
1140
1141 let grid_edges = cache_grid_dim_for_u32_count(
1142 "GpuCircuitCache store child_indices",
1143 num_edges,
1144 block_dim,
1145 )?;
1146 if grid_edges != 0 {
1147 unsafe {
1149 store_u32.clone().launch(
1150 LaunchConfig {
1151 grid_dim: (grid_edges, 1, 1),
1152 block_dim: (block_dim, 1, 1),
1153 shared_mem_bytes: 0,
1154 },
1155 (
1156 handle.slot_device(),
1157 handle.compile_needed_device(),
1158 self.edge_cap,
1159 xgcf.child_indices(),
1160 &mut self.child_indices,
1161 num_edges,
1162 ),
1163 )
1164 }
1165 .map_err(|e| XlogError::Kernel(format!("cache_store_child_indices failed: {}", e)))?;
1166 }
1167
1168 if grid_nodes != 0 {
1169 unsafe {
1171 store_i32.clone().launch(
1172 LaunchConfig {
1173 grid_dim: (grid_nodes, 1, 1),
1174 block_dim: (block_dim, 1, 1),
1175 shared_mem_bytes: 0,
1176 },
1177 (
1178 handle.slot_device(),
1179 handle.compile_needed_device(),
1180 node_stride,
1181 xgcf.lit(),
1182 &mut self.lit,
1183 num_nodes,
1184 ),
1185 )
1186 }
1187 .map_err(|e| XlogError::Kernel(format!("cache_store_lit failed: {}", e)))?;
1188
1189 unsafe {
1191 store_u32.clone().launch(
1192 LaunchConfig {
1193 grid_dim: (grid_nodes, 1, 1),
1194 block_dim: (block_dim, 1, 1),
1195 shared_mem_bytes: 0,
1196 },
1197 (
1198 handle.slot_device(),
1199 handle.compile_needed_device(),
1200 node_stride,
1201 xgcf.decision_var(),
1202 &mut self.decision_var,
1203 num_nodes,
1204 ),
1205 )
1206 }
1207 .map_err(|e| XlogError::Kernel(format!("cache_store_decision_var failed: {}", e)))?;
1208
1209 unsafe {
1211 store_u32.clone().launch(
1212 LaunchConfig {
1213 grid_dim: (grid_nodes, 1, 1),
1214 block_dim: (block_dim, 1, 1),
1215 shared_mem_bytes: 0,
1216 },
1217 (
1218 handle.slot_device(),
1219 handle.compile_needed_device(),
1220 node_stride,
1221 xgcf.decision_child_false(),
1222 &mut self.decision_child_false,
1223 num_nodes,
1224 ),
1225 )
1226 }
1227 .map_err(|e| {
1228 XlogError::Kernel(format!("cache_store_decision_child_false failed: {}", e))
1229 })?;
1230
1231 unsafe {
1233 store_u32.clone().launch(
1234 LaunchConfig {
1235 grid_dim: (grid_nodes, 1, 1),
1236 block_dim: (block_dim, 1, 1),
1237 shared_mem_bytes: 0,
1238 },
1239 (
1240 handle.slot_device(),
1241 handle.compile_needed_device(),
1242 node_stride,
1243 xgcf.decision_child_true(),
1244 &mut self.decision_child_true,
1245 num_nodes,
1246 ),
1247 )
1248 }
1249 .map_err(|e| {
1250 XlogError::Kernel(format!("cache_store_decision_child_true failed: {}", e))
1251 })?;
1252
1253 unsafe {
1255 store_u32.clone().launch(
1256 LaunchConfig {
1257 grid_dim: (grid_nodes, 1, 1),
1258 block_dim: (block_dim, 1, 1),
1259 shared_mem_bytes: 0,
1260 },
1261 (
1262 handle.slot_device(),
1263 handle.compile_needed_device(),
1264 node_stride,
1265 xgcf.level_nodes(),
1266 &mut self.level_nodes,
1267 num_nodes,
1268 ),
1269 )
1270 }
1271 .map_err(|e| XlogError::Kernel(format!("cache_store_level_nodes failed: {}", e)))?;
1272 }
1273
1274 let grid_levels = cache_grid_dim_for_u32_count(
1275 "GpuCircuitCache store level_offsets",
1276 num_levels_plus1,
1277 block_dim,
1278 )?;
1279 if grid_levels != 0 {
1280 unsafe {
1282 store_u32.clone().launch(
1283 LaunchConfig {
1284 grid_dim: (grid_levels, 1, 1),
1285 block_dim: (block_dim, 1, 1),
1286 shared_mem_bytes: 0,
1287 },
1288 (
1289 handle.slot_device(),
1290 handle.compile_needed_device(),
1291 level_offset_stride,
1292 xgcf.level_offsets(),
1293 &mut self.level_offsets,
1294 num_levels_plus1,
1295 ),
1296 )
1297 }
1298 .map_err(|e| XlogError::Kernel(format!("cache_store_level_offsets failed: {}", e)))?;
1299 }
1300
1301 let grid_weights = cache_grid_dim_for_u32_count(
1302 "GpuCircuitCache store free_var_mask",
1303 weights_len,
1304 block_dim,
1305 )?;
1306 if grid_weights != 0 {
1307 unsafe {
1309 store_f64.clone().launch(
1310 LaunchConfig {
1311 grid_dim: (grid_weights, 1, 1),
1312 block_dim: (block_dim, 1, 1),
1313 shared_mem_bytes: 0,
1314 },
1315 (
1316 handle.slot_device(),
1317 handle.compile_needed_device(),
1318 var_stride,
1319 xgcf.var_log_true(),
1320 &mut self.var_log_true,
1321 weights_len,
1322 ),
1323 )
1324 }
1325 .map_err(|e| XlogError::Kernel(format!("cache_store_var_log_true failed: {}", e)))?;
1326
1327 unsafe {
1329 store_f64.clone().launch(
1330 LaunchConfig {
1331 grid_dim: (grid_weights, 1, 1),
1332 block_dim: (block_dim, 1, 1),
1333 shared_mem_bytes: 0,
1334 },
1335 (
1336 handle.slot_device(),
1337 handle.compile_needed_device(),
1338 var_stride,
1339 xgcf.var_log_false(),
1340 &mut self.var_log_false,
1341 weights_len,
1342 ),
1343 )
1344 }
1345 .map_err(|e| XlogError::Kernel(format!("cache_store_var_log_false failed: {}", e)))?;
1346 }
1347
1348 unsafe {
1350 store_meta.clone().launch(
1351 LaunchConfig {
1352 grid_dim: (1, 1, 1),
1353 block_dim: (1, 1, 1),
1354 shared_mem_bytes: 0,
1355 },
1356 (
1357 handle.slot_device(),
1358 handle.compile_needed_device(),
1359 self.num_slots,
1360 num_nodes,
1361 num_levels,
1362 root,
1363 max_var,
1364 &mut self.meta_num_nodes,
1365 &mut self.meta_num_levels,
1366 &mut self.meta_root,
1367 &mut self.meta_max_var,
1368 ),
1369 )
1370 }
1371 .map_err(|e| XlogError::Kernel(format!("cache_store_meta failed: {}", e)))?;
1372
1373 Ok(())
1376 }
1377
1378 pub fn store_weights(
1379 &mut self,
1380 handle: &GpuCircuitCacheHandle,
1381 weights_true: &TrackedCudaSlice<f64>,
1382 weights_false: &TrackedCudaSlice<f64>,
1383 ) -> Result<()> {
1384 let weights_len = handle.max_var.checked_add(1).ok_or_else(|| {
1385 XlogError::Compilation("GpuCircuitCache store_weights max_var overflow".to_string())
1386 })?;
1387 let weights_len_usize = usize::try_from(weights_len).map_err(|_| {
1388 XlogError::Compilation("GpuCircuitCache store_weights len overflow".to_string())
1389 })?;
1390 if weights_true.len() < weights_len_usize || weights_false.len() < weights_len_usize {
1391 return Err(XlogError::Compilation(format!(
1392 "GpuCircuitCache store_weights requires weights len >= {}, got true={} false={}",
1393 weights_len,
1394 weights_true.len(),
1395 weights_false.len()
1396 )));
1397 }
1398
1399 let device = self.provider.device().inner();
1400 let store_f64 = device
1401 .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_F64)
1402 .ok_or_else(|| XlogError::Kernel("cache_store_f64 kernel not found".to_string()))?;
1403
1404 let block_dim = 256u32;
1405 let grid_dim = if weights_len == 0 {
1406 0
1407 } else {
1408 weights_len.div_ceil(block_dim)
1409 };
1410 if grid_dim != 0 {
1411 let var_stride = self.var_cap.checked_add(1).ok_or_else(|| {
1412 XlogError::Compilation("GpuCircuitCache store_weights var_cap overflow".to_string())
1413 })?;
1414 unsafe {
1416 store_f64.clone().launch(
1417 LaunchConfig {
1418 grid_dim: (grid_dim, 1, 1),
1419 block_dim: (block_dim, 1, 1),
1420 shared_mem_bytes: 0,
1421 },
1422 (
1423 handle.slot_device(),
1424 handle.compile_needed_device(),
1425 var_stride,
1426 weights_true,
1427 &mut self.var_log_true,
1428 weights_len,
1429 ),
1430 )
1431 }
1432 .map_err(|e| XlogError::Kernel(format!("cache_store_weights_true failed: {}", e)))?;
1433
1434 unsafe {
1436 store_f64.clone().launch(
1437 LaunchConfig {
1438 grid_dim: (grid_dim, 1, 1),
1439 block_dim: (block_dim, 1, 1),
1440 shared_mem_bytes: 0,
1441 },
1442 (
1443 handle.slot_device(),
1444 handle.compile_needed_device(),
1445 var_stride,
1446 weights_false,
1447 &mut self.var_log_false,
1448 weights_len,
1449 ),
1450 )
1451 }
1452 .map_err(|e| XlogError::Kernel(format!("cache_store_weights_false failed: {}", e)))?;
1453 }
1454
1455 Ok(())
1457 }
1458
1459 pub fn overwrite_weights(
1460 &mut self,
1461 handle: &GpuCircuitCacheHandle,
1462 weights_true: &TrackedCudaSlice<f64>,
1463 weights_false: &TrackedCudaSlice<f64>,
1464 ) -> Result<()> {
1465 let weights_len = handle.max_var.checked_add(1).ok_or_else(|| {
1466 XlogError::Compilation("GpuCircuitCache overwrite_weights max_var overflow".to_string())
1467 })?;
1468 let weights_len_usize = usize::try_from(weights_len).map_err(|_| {
1469 XlogError::Compilation("GpuCircuitCache overwrite_weights len overflow".to_string())
1470 })?;
1471 if weights_true.len() < weights_len_usize || weights_false.len() < weights_len_usize {
1472 return Err(XlogError::Compilation(format!(
1473 "GpuCircuitCache overwrite_weights requires weights len >= {}, got true={} false={}",
1474 weights_len,
1475 weights_true.len(),
1476 weights_false.len()
1477 )));
1478 }
1479
1480 let device = self.provider.device().inner();
1481 let store_f64 = device
1482 .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_F64)
1483 .ok_or_else(|| XlogError::Kernel("cache_store_f64 kernel not found".to_string()))?;
1484
1485 let block_dim = 256u32;
1486 let grid_dim = if weights_len == 0 {
1487 0
1488 } else {
1489 weights_len.div_ceil(block_dim)
1490 };
1491 if grid_dim != 0 {
1492 let var_stride = self.var_cap.checked_add(1).ok_or_else(|| {
1493 XlogError::Compilation(
1494 "GpuCircuitCache overwrite_weights var_cap overflow".to_string(),
1495 )
1496 })?;
1497 unsafe {
1499 store_f64.clone().launch(
1500 LaunchConfig {
1501 grid_dim: (grid_dim, 1, 1),
1502 block_dim: (block_dim, 1, 1),
1503 shared_mem_bytes: 0,
1504 },
1505 (
1506 handle.slot_device(),
1507 &self.always_on,
1508 var_stride,
1509 weights_true,
1510 &mut self.var_log_true,
1511 weights_len,
1512 ),
1513 )
1514 }
1515 .map_err(|e| {
1516 XlogError::Kernel(format!("cache_overwrite_weights_true failed: {}", e))
1517 })?;
1518
1519 unsafe {
1521 store_f64.clone().launch(
1522 LaunchConfig {
1523 grid_dim: (grid_dim, 1, 1),
1524 block_dim: (block_dim, 1, 1),
1525 shared_mem_bytes: 0,
1526 },
1527 (
1528 handle.slot_device(),
1529 &self.always_on,
1530 var_stride,
1531 weights_false,
1532 &mut self.var_log_false,
1533 weights_len,
1534 ),
1535 )
1536 }
1537 .map_err(|e| {
1538 XlogError::Kernel(format!("cache_overwrite_weights_false failed: {}", e))
1539 })?;
1540 }
1541
1542 Ok(())
1544 }
1545
1546 pub fn store_free_var_mask(
1547 &mut self,
1548 handle: &GpuCircuitCacheHandle,
1549 mask: &TrackedCudaSlice<u8>,
1550 ) -> Result<()> {
1551 let mask_len = u32::try_from(mask.len()).map_err(|_| {
1552 XlogError::Compilation("GpuCircuitCache free_var_mask len overflow".to_string())
1553 })?;
1554 let expected_len = handle.max_var.checked_add(1).ok_or_else(|| {
1555 XlogError::Compilation("GpuCircuitCache free_var_mask max_var overflow".to_string())
1556 })?;
1557 if mask_len != expected_len {
1558 return Err(XlogError::Compilation(format!(
1559 "GpuCircuitCache free_var_mask len {} != expected {}",
1560 mask_len, expected_len
1561 )));
1562 }
1563 let var_stride = self.var_cap.checked_add(1).ok_or_else(|| {
1564 XlogError::Compilation("GpuCircuitCache free_var_mask var_cap overflow".to_string())
1565 })?;
1566 if expected_len > var_stride {
1567 return Err(XlogError::Compilation(format!(
1568 "GpuCircuitCache free_var_mask len {} exceeds var_cap+1 {}",
1569 expected_len, var_stride
1570 )));
1571 }
1572
1573 let device = self.provider.device().inner();
1574 let store_u8 = device
1575 .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_U8)
1576 .ok_or_else(|| XlogError::Kernel("cache_store_u8 kernel not found".to_string()))?;
1577
1578 let block_dim = 256u32;
1579 let grid_dim = mask_len.div_ceil(block_dim);
1580 if grid_dim == 0 {
1581 return Ok(());
1582 }
1583
1584 unsafe {
1586 store_u8.clone().launch(
1587 LaunchConfig {
1588 grid_dim: (grid_dim, 1, 1),
1589 block_dim: (block_dim, 1, 1),
1590 shared_mem_bytes: 0,
1591 },
1592 (
1593 handle.slot_device(),
1594 handle.compile_needed_device(),
1595 var_stride,
1596 mask,
1597 &mut self.free_var_mask,
1598 mask_len,
1599 ),
1600 )
1601 }
1602 .map_err(|e| XlogError::Kernel(format!("cache_store_free_var_mask failed: {}", e)))?;
1603
1604 let slot_idx = handle.slot_index() as usize;
1606 debug_assert!(
1607 slot_idx < self.has_free_var_mask.len(),
1608 "slot_index {} exceeds num_slots {}",
1609 slot_idx,
1610 self.has_free_var_mask.len()
1611 );
1612 if slot_idx < self.has_free_var_mask.len() {
1613 self.has_free_var_mask[slot_idx] = true;
1614 }
1615 Ok(())
1616 }
1617
1618 pub(crate) fn restore_from_host_arrays(
1625 &mut self,
1626 handle: &mut GpuCircuitCacheHandle,
1627 artifact: &disk_cache::CircuitArtifact,
1628 ) -> Result<()> {
1629 let num_nodes = artifact.num_nodes;
1631 if num_nodes == 0 {
1632 return Err(XlogError::Compilation(
1633 "GpuCircuitCache restore: num_nodes must be > 0".to_string(),
1634 ));
1635 }
1636 if num_nodes > self.node_cap {
1637 return Err(XlogError::Compilation(format!(
1638 "GpuCircuitCache restore: num_nodes {} exceeds node_cap {}",
1639 num_nodes, self.node_cap
1640 )));
1641 }
1642
1643 let num_edges = artifact.num_edges;
1644 if num_edges > self.edge_cap {
1645 return Err(XlogError::Compilation(format!(
1646 "GpuCircuitCache restore: num_edges {} exceeds edge_cap {}",
1647 num_edges, self.edge_cap
1648 )));
1649 }
1650
1651 let num_levels = artifact.num_levels;
1652 if num_levels == 0 {
1653 return Err(XlogError::Compilation(
1654 "GpuCircuitCache restore: num_levels must be > 0".to_string(),
1655 ));
1656 }
1657 if num_levels > self.level_cap {
1658 return Err(XlogError::Compilation(format!(
1659 "GpuCircuitCache restore: num_levels {} exceeds level_cap {}",
1660 num_levels, self.level_cap
1661 )));
1662 }
1663
1664 let root = artifact.root;
1665 if root >= num_nodes {
1666 return Err(XlogError::Compilation(format!(
1667 "GpuCircuitCache restore: root {} out of bounds (num_nodes={})",
1668 root, num_nodes
1669 )));
1670 }
1671
1672 let max_var = artifact.max_var;
1673 if max_var > self.var_cap {
1674 return Err(XlogError::Compilation(format!(
1675 "GpuCircuitCache restore: max_var {} exceeds var_cap {}",
1676 max_var, self.var_cap
1677 )));
1678 }
1679
1680 let expected_child_offsets = (num_nodes as usize) + 1;
1681 if artifact.child_offsets.len() < expected_child_offsets {
1682 return Err(XlogError::Compilation(format!(
1683 "GpuCircuitCache restore: child_offsets len {} < num_nodes+1 {}",
1684 artifact.child_offsets.len(),
1685 expected_child_offsets
1686 )));
1687 }
1688 if artifact.level_nodes.len() < num_nodes as usize {
1689 return Err(XlogError::Compilation(format!(
1690 "GpuCircuitCache restore: level_nodes len {} < num_nodes {}",
1691 artifact.level_nodes.len(),
1692 num_nodes
1693 )));
1694 }
1695 let expected_level_offsets = (num_levels as usize) + 1;
1696 if artifact.level_offsets.len() != expected_level_offsets {
1697 return Err(XlogError::Compilation(format!(
1698 "GpuCircuitCache restore: level_offsets len {} != num_levels+1 {}",
1699 artifact.level_offsets.len(),
1700 expected_level_offsets
1701 )));
1702 }
1703
1704 handle.num_nodes = num_nodes;
1706 handle.num_levels = num_levels;
1707 handle.root = root;
1708 handle.max_var = max_var;
1709
1710 let device = self.provider.device().inner();
1712 let memory = self.provider.memory();
1713
1714 let store_u8 = device
1715 .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_U8)
1716 .ok_or_else(|| XlogError::Kernel("cache_store_u8 kernel not found".to_string()))?;
1717 let store_u32 = device
1718 .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_U32)
1719 .ok_or_else(|| XlogError::Kernel("cache_store_u32 kernel not found".to_string()))?;
1720 let store_i32 = device
1721 .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_I32)
1722 .ok_or_else(|| XlogError::Kernel("cache_store_i32 kernel not found".to_string()))?;
1723 let store_meta = device
1724 .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_META)
1725 .ok_or_else(|| XlogError::Kernel("cache_store_meta kernel not found".to_string()))?;
1726
1727 let block_dim = 256u32;
1728
1729 let node_stride = self.node_cap;
1730 let offset_stride = self.node_cap.checked_add(1).ok_or_else(|| {
1731 XlogError::Compilation("GpuCircuitCache restore: node_cap overflow".to_string())
1732 })?;
1733 let level_offset_stride = self.level_cap.checked_add(1).ok_or_else(|| {
1734 XlogError::Compilation("GpuCircuitCache restore: level_cap overflow".to_string())
1735 })?;
1736 let var_stride = self.var_cap.checked_add(1).ok_or_else(|| {
1737 XlogError::Compilation("GpuCircuitCache restore: var_cap overflow".to_string())
1738 })?;
1739
1740 let num_nodes_plus1 = num_nodes.checked_add(1).ok_or_else(|| {
1741 XlogError::Compilation("GpuCircuitCache restore: num_nodes overflow".to_string())
1742 })?;
1743 let num_levels_plus1 = num_levels.checked_add(1).ok_or_else(|| {
1744 XlogError::Compilation("GpuCircuitCache restore: num_levels overflow".to_string())
1745 })?;
1746
1747 let grid_nodes = cache_grid_dim_for_u32_count(
1749 "GpuCircuitCache restore node_type",
1750 num_nodes,
1751 block_dim,
1752 )?;
1753 if grid_nodes != 0 {
1754 let mut d_node_type = memory.alloc::<u8>(num_nodes as usize)?;
1755 self.provider
1756 .htod_sync_copy_into_tracked(
1757 &artifact.node_type[..num_nodes as usize],
1758 &mut d_node_type,
1759 )
1760 .map_err(|e| XlogError::Kernel(format!("restore htod node_type failed: {}", e)))?;
1761 unsafe {
1763 store_u8.clone().launch(
1764 LaunchConfig {
1765 grid_dim: (grid_nodes, 1, 1),
1766 block_dim: (block_dim, 1, 1),
1767 shared_mem_bytes: 0,
1768 },
1769 (
1770 handle.slot_device(),
1771 handle.compile_needed_device(),
1772 node_stride,
1773 &d_node_type,
1774 &mut self.node_type,
1775 num_nodes,
1776 ),
1777 )
1778 }
1779 .map_err(|e| {
1780 XlogError::Kernel(format!("restore cache_store node_type failed: {}", e))
1781 })?;
1782 }
1783
1784 let grid_offsets = cache_grid_dim_for_u32_count(
1786 "GpuCircuitCache restore child_offsets",
1787 num_nodes_plus1,
1788 block_dim,
1789 )?;
1790 if grid_offsets != 0 {
1791 let mut d_child_offsets = memory.alloc::<u32>(num_nodes_plus1 as usize)?;
1792 self.provider
1793 .htod_sync_copy_into_tracked(
1794 &artifact.child_offsets[..num_nodes_plus1 as usize],
1795 &mut d_child_offsets,
1796 )
1797 .map_err(|e| {
1798 XlogError::Kernel(format!("restore htod child_offsets failed: {}", e))
1799 })?;
1800 unsafe {
1802 store_u32.clone().launch(
1803 LaunchConfig {
1804 grid_dim: (grid_offsets, 1, 1),
1805 block_dim: (block_dim, 1, 1),
1806 shared_mem_bytes: 0,
1807 },
1808 (
1809 handle.slot_device(),
1810 handle.compile_needed_device(),
1811 offset_stride,
1812 &d_child_offsets,
1813 &mut self.child_offsets,
1814 num_nodes_plus1,
1815 ),
1816 )
1817 }
1818 .map_err(|e| {
1819 XlogError::Kernel(format!("restore cache_store child_offsets failed: {}", e))
1820 })?;
1821 }
1822
1823 let grid_edges = cache_grid_dim_for_u32_count(
1825 "GpuCircuitCache restore child_indices",
1826 num_edges,
1827 block_dim,
1828 )?;
1829 if grid_edges != 0 {
1830 let mut d_child_indices = memory.alloc::<u32>(num_edges as usize)?;
1831 self.provider
1832 .htod_sync_copy_into_tracked(
1833 &artifact.child_indices[..num_edges as usize],
1834 &mut d_child_indices,
1835 )
1836 .map_err(|e| {
1837 XlogError::Kernel(format!("restore htod child_indices failed: {}", e))
1838 })?;
1839 unsafe {
1841 store_u32.clone().launch(
1842 LaunchConfig {
1843 grid_dim: (grid_edges, 1, 1),
1844 block_dim: (block_dim, 1, 1),
1845 shared_mem_bytes: 0,
1846 },
1847 (
1848 handle.slot_device(),
1849 handle.compile_needed_device(),
1850 self.edge_cap,
1851 &d_child_indices,
1852 &mut self.child_indices,
1853 num_edges,
1854 ),
1855 )
1856 }
1857 .map_err(|e| {
1858 XlogError::Kernel(format!("restore cache_store child_indices failed: {}", e))
1859 })?;
1860 }
1861
1862 if grid_nodes != 0 {
1864 let mut d_lit = memory.alloc::<i32>(num_nodes as usize)?;
1865 self.provider
1866 .htod_sync_copy_into_tracked(&artifact.lit[..num_nodes as usize], &mut d_lit)
1867 .map_err(|e| XlogError::Kernel(format!("restore htod lit failed: {}", e)))?;
1868 unsafe {
1870 store_i32.clone().launch(
1871 LaunchConfig {
1872 grid_dim: (grid_nodes, 1, 1),
1873 block_dim: (block_dim, 1, 1),
1874 shared_mem_bytes: 0,
1875 },
1876 (
1877 handle.slot_device(),
1878 handle.compile_needed_device(),
1879 node_stride,
1880 &d_lit,
1881 &mut self.lit,
1882 num_nodes,
1883 ),
1884 )
1885 }
1886 .map_err(|e| XlogError::Kernel(format!("restore cache_store lit failed: {}", e)))?;
1887
1888 let mut d_decision_var = memory.alloc::<u32>(num_nodes as usize)?;
1890 self.provider
1891 .htod_sync_copy_into_tracked(
1892 &artifact.decision_var[..num_nodes as usize],
1893 &mut d_decision_var,
1894 )
1895 .map_err(|e| {
1896 XlogError::Kernel(format!("restore htod decision_var failed: {}", e))
1897 })?;
1898 unsafe {
1900 store_u32.clone().launch(
1901 LaunchConfig {
1902 grid_dim: (grid_nodes, 1, 1),
1903 block_dim: (block_dim, 1, 1),
1904 shared_mem_bytes: 0,
1905 },
1906 (
1907 handle.slot_device(),
1908 handle.compile_needed_device(),
1909 node_stride,
1910 &d_decision_var,
1911 &mut self.decision_var,
1912 num_nodes,
1913 ),
1914 )
1915 }
1916 .map_err(|e| {
1917 XlogError::Kernel(format!("restore cache_store decision_var failed: {}", e))
1918 })?;
1919
1920 let mut d_decision_child_false = memory.alloc::<u32>(num_nodes as usize)?;
1922 self.provider
1923 .htod_sync_copy_into_tracked(
1924 &artifact.decision_child_false[..num_nodes as usize],
1925 &mut d_decision_child_false,
1926 )
1927 .map_err(|e| {
1928 XlogError::Kernel(format!("restore htod decision_child_false failed: {}", e))
1929 })?;
1930 unsafe {
1932 store_u32.clone().launch(
1933 LaunchConfig {
1934 grid_dim: (grid_nodes, 1, 1),
1935 block_dim: (block_dim, 1, 1),
1936 shared_mem_bytes: 0,
1937 },
1938 (
1939 handle.slot_device(),
1940 handle.compile_needed_device(),
1941 node_stride,
1942 &d_decision_child_false,
1943 &mut self.decision_child_false,
1944 num_nodes,
1945 ),
1946 )
1947 }
1948 .map_err(|e| {
1949 XlogError::Kernel(format!(
1950 "restore cache_store decision_child_false failed: {}",
1951 e
1952 ))
1953 })?;
1954
1955 let mut d_decision_child_true = memory.alloc::<u32>(num_nodes as usize)?;
1957 self.provider
1958 .htod_sync_copy_into_tracked(
1959 &artifact.decision_child_true[..num_nodes as usize],
1960 &mut d_decision_child_true,
1961 )
1962 .map_err(|e| {
1963 XlogError::Kernel(format!("restore htod decision_child_true failed: {}", e))
1964 })?;
1965 unsafe {
1967 store_u32.clone().launch(
1968 LaunchConfig {
1969 grid_dim: (grid_nodes, 1, 1),
1970 block_dim: (block_dim, 1, 1),
1971 shared_mem_bytes: 0,
1972 },
1973 (
1974 handle.slot_device(),
1975 handle.compile_needed_device(),
1976 node_stride,
1977 &d_decision_child_true,
1978 &mut self.decision_child_true,
1979 num_nodes,
1980 ),
1981 )
1982 }
1983 .map_err(|e| {
1984 XlogError::Kernel(format!(
1985 "restore cache_store decision_child_true failed: {}",
1986 e
1987 ))
1988 })?;
1989
1990 let mut d_level_nodes = memory.alloc::<u32>(num_nodes as usize)?;
1992 self.provider
1993 .htod_sync_copy_into_tracked(
1994 &artifact.level_nodes[..num_nodes as usize],
1995 &mut d_level_nodes,
1996 )
1997 .map_err(|e| {
1998 XlogError::Kernel(format!("restore htod level_nodes failed: {}", e))
1999 })?;
2000 unsafe {
2002 store_u32.clone().launch(
2003 LaunchConfig {
2004 grid_dim: (grid_nodes, 1, 1),
2005 block_dim: (block_dim, 1, 1),
2006 shared_mem_bytes: 0,
2007 },
2008 (
2009 handle.slot_device(),
2010 handle.compile_needed_device(),
2011 node_stride,
2012 &d_level_nodes,
2013 &mut self.level_nodes,
2014 num_nodes,
2015 ),
2016 )
2017 }
2018 .map_err(|e| {
2019 XlogError::Kernel(format!("restore cache_store level_nodes failed: {}", e))
2020 })?;
2021 }
2022
2023 let grid_levels = cache_grid_dim_for_u32_count(
2025 "GpuCircuitCache restore level_offsets",
2026 num_levels_plus1,
2027 block_dim,
2028 )?;
2029 if grid_levels != 0 {
2030 let mut d_level_offsets = memory.alloc::<u32>(num_levels_plus1 as usize)?;
2031 self.provider
2032 .htod_sync_copy_into_tracked(
2033 &artifact.level_offsets[..num_levels_plus1 as usize],
2034 &mut d_level_offsets,
2035 )
2036 .map_err(|e| {
2037 XlogError::Kernel(format!("restore htod level_offsets failed: {}", e))
2038 })?;
2039 unsafe {
2041 store_u32.clone().launch(
2042 LaunchConfig {
2043 grid_dim: (grid_levels, 1, 1),
2044 block_dim: (block_dim, 1, 1),
2045 shared_mem_bytes: 0,
2046 },
2047 (
2048 handle.slot_device(),
2049 handle.compile_needed_device(),
2050 level_offset_stride,
2051 &d_level_offsets,
2052 &mut self.level_offsets,
2053 num_levels_plus1,
2054 ),
2055 )
2056 }
2057 .map_err(|e| {
2058 XlogError::Kernel(format!("restore cache_store level_offsets failed: {}", e))
2059 })?;
2060 }
2061
2062 unsafe {
2065 store_meta.clone().launch(
2066 LaunchConfig {
2067 grid_dim: (1, 1, 1),
2068 block_dim: (1, 1, 1),
2069 shared_mem_bytes: 0,
2070 },
2071 (
2072 handle.slot_device(),
2073 handle.compile_needed_device(),
2074 self.num_slots,
2075 num_nodes,
2076 num_levels,
2077 root,
2078 max_var,
2079 &mut self.meta_num_nodes,
2080 &mut self.meta_num_levels,
2081 &mut self.meta_root,
2082 &mut self.meta_max_var,
2083 ),
2084 )
2085 }
2086 .map_err(|e| XlogError::Kernel(format!("restore cache_store_meta failed: {}", e)))?;
2087
2088 let slot_idx = handle.slot_index() as usize;
2090
2091 let mask_cap = var_stride; let grid_mask_zero = cache_grid_dim_for_u32_count(
2095 "GpuCircuitCache restore zero free_var_mask",
2096 mask_cap,
2097 block_dim,
2098 )?;
2099 if grid_mask_zero != 0 {
2100 let mut d_zeros = memory.alloc::<u8>(mask_cap as usize)?;
2101 device.memset_zeros(&mut d_zeros).map_err(|e| {
2102 XlogError::Kernel(format!("restore memset_zeros free_var_mask failed: {}", e))
2103 })?;
2104 unsafe {
2106 store_u8.clone().launch(
2107 LaunchConfig {
2108 grid_dim: (grid_mask_zero, 1, 1),
2109 block_dim: (block_dim, 1, 1),
2110 shared_mem_bytes: 0,
2111 },
2112 (
2113 handle.slot_device(),
2114 handle.compile_needed_device(),
2115 var_stride,
2116 &d_zeros,
2117 &mut self.free_var_mask,
2118 mask_cap,
2119 ),
2120 )
2121 }
2122 .map_err(|e| {
2123 XlogError::Kernel(format!(
2124 "restore cache_store zero free_var_mask failed: {}",
2125 e
2126 ))
2127 })?;
2128 }
2129
2130 let has_mask = artifact.has_free_var_mask && !artifact.free_var_mask.is_empty();
2132 if has_mask {
2133 let mask_len = max_var.checked_add(1).ok_or_else(|| {
2134 XlogError::Compilation(
2135 "GpuCircuitCache restore: free_var_mask max_var overflow".to_string(),
2136 )
2137 })?;
2138 let actual_len = std::cmp::min(mask_len as usize, artifact.free_var_mask.len());
2139 if actual_len > 0 {
2140 let actual_len_u32 = u32::try_from(actual_len).map_err(|_| {
2141 XlogError::Compilation(
2142 "GpuCircuitCache restore free_var_mask len exceeds u32".to_string(),
2143 )
2144 })?;
2145 let grid_mask = cache_grid_dim_for_u32_count(
2146 "GpuCircuitCache restore free_var_mask",
2147 actual_len_u32,
2148 block_dim,
2149 )?;
2150 if grid_mask != 0 {
2151 let mut d_mask = memory.alloc::<u8>(actual_len)?;
2152 self.provider
2153 .htod_sync_copy_into_tracked(
2154 &artifact.free_var_mask[..actual_len],
2155 &mut d_mask,
2156 )
2157 .map_err(|e| {
2158 XlogError::Kernel(format!("restore htod free_var_mask failed: {}", e))
2159 })?;
2160 unsafe {
2162 store_u8.clone().launch(
2163 LaunchConfig {
2164 grid_dim: (grid_mask, 1, 1),
2165 block_dim: (block_dim, 1, 1),
2166 shared_mem_bytes: 0,
2167 },
2168 (
2169 handle.slot_device(),
2170 handle.compile_needed_device(),
2171 var_stride,
2172 &d_mask,
2173 &mut self.free_var_mask,
2174 actual_len_u32,
2175 ),
2176 )
2177 }
2178 .map_err(|e| {
2179 XlogError::Kernel(format!(
2180 "restore cache_store free_var_mask failed: {}",
2181 e
2182 ))
2183 })?;
2184 }
2185 }
2186 }
2187
2188 debug_assert!(
2190 slot_idx < self.has_free_var_mask.len(),
2191 "slot_index {} exceeds num_slots {}",
2192 slot_idx,
2193 self.has_free_var_mask.len()
2194 );
2195 if slot_idx < self.has_free_var_mask.len() {
2196 self.has_free_var_mask[slot_idx] = has_mask;
2197 }
2198
2199 Ok(())
2202 }
2203
2204 pub(crate) fn build_artifact_from_device(
2211 &self,
2212 handle: &GpuCircuitCacheHandle,
2213 provider: &Arc<CudaKernelProvider>,
2214 ) -> Result<disk_cache::CircuitArtifact> {
2215 let device = provider.device().inner();
2216 let slot = handle.slot_index() as usize;
2217 let num_nodes = handle.num_nodes();
2218 let num_levels = handle.num_levels();
2219 let root = handle.root();
2220 let max_var = handle.max_var();
2221
2222 if num_nodes == 0 {
2223 return Err(XlogError::Compilation(
2224 "build_artifact_from_device: num_nodes is 0".to_string(),
2225 ));
2226 }
2227
2228 let node_stride = self.node_cap as usize;
2229 let offset_stride = (self.node_cap as usize) + 1;
2230 let edge_stride = self.edge_cap as usize;
2231 let level_offset_stride = (self.level_cap as usize) + 1;
2232 let var_stride = (self.var_cap as usize) + 1;
2233
2234 let slot_node_start = slot * node_stride;
2235 let slot_offset_start = slot * offset_stride;
2236 let slot_level_offset_start = slot * level_offset_stride;
2237 let slot_var_start = slot * var_stride;
2238
2239 let nn = num_nodes as usize;
2240 let nn1 = nn + 1;
2241 let nl1 = (num_levels as usize) + 1;
2242
2243 let child_offsets_view = self
2246 .child_offsets
2247 .slice(slot_offset_start..(slot_offset_start + nn1));
2248 let child_offsets: Vec<u32> = device
2249 .dtoh_sync_copy(&child_offsets_view)
2250 .map_err(|e| XlogError::Kernel(format!("build_artifact dtoh child_offsets: {}", e)))?;
2251 let num_edges = if nn1 > 0 {
2252 child_offsets[nn]
2253 .checked_sub(child_offsets[0])
2254 .ok_or_else(|| {
2255 XlogError::Compilation(
2256 "build_artifact_from_device: child_offsets[num_nodes] < child_offsets[0]"
2257 .to_string(),
2258 )
2259 })?
2260 } else {
2261 0
2262 };
2263
2264 let slot_edge_start = slot * edge_stride;
2266 let ne = num_edges as usize;
2267 let child_indices: Vec<u32> = if ne > 0 {
2268 let view = self
2269 .child_indices
2270 .slice(slot_edge_start..(slot_edge_start + ne));
2271 device.dtoh_sync_copy(&view).map_err(|e| {
2272 XlogError::Kernel(format!("build_artifact dtoh child_indices: {}", e))
2273 })?
2274 } else {
2275 Vec::new()
2276 };
2277
2278 let node_type_view = self
2280 .node_type
2281 .slice(slot_node_start..(slot_node_start + nn));
2282 let node_type: Vec<u8> = device
2283 .dtoh_sync_copy(&node_type_view)
2284 .map_err(|e| XlogError::Kernel(format!("build_artifact dtoh node_type: {}", e)))?;
2285
2286 let lit_view = self.lit.slice(slot_node_start..(slot_node_start + nn));
2288 let lit: Vec<i32> = device
2289 .dtoh_sync_copy(&lit_view)
2290 .map_err(|e| XlogError::Kernel(format!("build_artifact dtoh lit: {}", e)))?;
2291
2292 let dv_view = self
2294 .decision_var
2295 .slice(slot_node_start..(slot_node_start + nn));
2296 let decision_var: Vec<u32> = device
2297 .dtoh_sync_copy(&dv_view)
2298 .map_err(|e| XlogError::Kernel(format!("build_artifact dtoh decision_var: {}", e)))?;
2299
2300 let dcf_view = self
2302 .decision_child_false
2303 .slice(slot_node_start..(slot_node_start + nn));
2304 let decision_child_false: Vec<u32> = device.dtoh_sync_copy(&dcf_view).map_err(|e| {
2305 XlogError::Kernel(format!("build_artifact dtoh decision_child_false: {}", e))
2306 })?;
2307
2308 let dct_view = self
2310 .decision_child_true
2311 .slice(slot_node_start..(slot_node_start + nn));
2312 let decision_child_true: Vec<u32> = device.dtoh_sync_copy(&dct_view).map_err(|e| {
2313 XlogError::Kernel(format!("build_artifact dtoh decision_child_true: {}", e))
2314 })?;
2315
2316 let ln_view = self
2318 .level_nodes
2319 .slice(slot_node_start..(slot_node_start + nn));
2320 let level_nodes: Vec<u32> = device
2321 .dtoh_sync_copy(&ln_view)
2322 .map_err(|e| XlogError::Kernel(format!("build_artifact dtoh level_nodes: {}", e)))?;
2323
2324 let lo_view = self
2326 .level_offsets
2327 .slice(slot_level_offset_start..(slot_level_offset_start + nl1));
2328 let level_offsets: Vec<u32> = device
2329 .dtoh_sync_copy(&lo_view)
2330 .map_err(|e| XlogError::Kernel(format!("build_artifact dtoh level_offsets: {}", e)))?;
2331
2332 let has_free_var_mask = self.has_free_var_mask_for_slot(slot as u32);
2334 let mask_len = (max_var as usize) + 1;
2335 let free_var_mask: Vec<u8> = if mask_len > 0 {
2336 let fvm_view = self
2337 .free_var_mask
2338 .slice(slot_var_start..(slot_var_start + mask_len));
2339 device.dtoh_sync_copy(&fvm_view).map_err(|e| {
2340 XlogError::Kernel(format!("build_artifact dtoh free_var_mask: {}", e))
2341 })?
2342 } else {
2343 Vec::new()
2344 };
2345
2346 Ok(disk_cache::CircuitArtifact {
2347 num_nodes,
2348 num_edges,
2349 num_levels,
2350 root,
2351 max_var,
2352 has_free_var_mask,
2353 node_type,
2354 child_offsets,
2355 child_indices,
2356 lit,
2357 decision_var,
2358 decision_child_false,
2359 decision_child_true,
2360 level_nodes,
2361 level_offsets,
2362 free_var_mask,
2363 })
2364 }
2365
2366 pub fn eval_log_wmc_device_inplace(
2367 &mut self,
2368 handle: &GpuCircuitCacheHandle,
2369 out_log_z: &mut TrackedCudaSlice<f64>,
2370 ) -> Result<()> {
2371 self.eval_log_wmc_device_only(handle, out_log_z)
2372 }
2373
2374 pub fn eval_log_wmc_device_only(
2375 &mut self,
2376 handle: &GpuCircuitCacheHandle,
2377 out_log_z: &mut TrackedCudaSlice<f64>,
2378 ) -> Result<()> {
2379 if out_log_z.len() != 1 {
2380 return Err(XlogError::Compilation(format!(
2381 "GPU cache logZ output len {} != 1",
2382 out_log_z.len()
2383 )));
2384 }
2385
2386 {
2387 let device = self.provider.device().inner();
2388 let eval_all = device
2389 .get_func(
2390 xlog_cuda::CIRCUIT_MODULE,
2391 xlog_cuda::circuit_kernels::XGCF_EVAL_ALL_LEVELS_CACHED,
2392 )
2393 .ok_or_else(|| {
2394 XlogError::Kernel("xgcf_eval_all_levels_cached kernel not found".to_string())
2395 })?;
2396
2397 let block_size: u32 = 256;
2398 let mut params: Vec<*mut std::ffi::c_void> = vec![
2399 handle.slot_device().as_kernel_param(),
2400 self.node_cap.as_kernel_param(),
2401 self.edge_cap.as_kernel_param(),
2402 self.level_cap.as_kernel_param(),
2403 self.var_cap.as_kernel_param(),
2404 (&self.node_type).as_kernel_param(),
2405 (&self.child_offsets).as_kernel_param(),
2406 (&self.child_indices).as_kernel_param(),
2407 (&self.lit).as_kernel_param(),
2408 (&self.decision_var).as_kernel_param(),
2409 (&self.decision_child_false).as_kernel_param(),
2410 (&self.decision_child_true).as_kernel_param(),
2411 (&self.level_nodes).as_kernel_param(),
2412 (&self.level_offsets).as_kernel_param(),
2413 (&self.var_log_true).as_kernel_param(),
2414 (&self.var_log_false).as_kernel_param(),
2415 (&self.values).as_kernel_param(),
2416 (&self.meta_num_levels).as_kernel_param(),
2417 ];
2418 unsafe {
2420 eval_all.clone().launch(
2421 LaunchConfig {
2422 grid_dim: (1, 1, 1),
2423 block_dim: (block_size, 1, 1),
2424 shared_mem_bytes: 0,
2425 },
2426 &mut params,
2427 )
2428 }
2429 .map_err(|e| XlogError::Kernel(format!("xgcf_eval_all_levels_cached failed: {}", e)))?;
2430 }
2431
2432 self.apply_free_var_correction_cached(handle, true, false)?;
2433
2434 let device = self.provider.device().inner();
2435 let copy_root = device
2436 .get_func(
2437 xlog_cuda::CIRCUIT_MODULE,
2438 xlog_cuda::circuit_kernels::XGCF_COPY_ROOT_CACHED_META,
2439 )
2440 .ok_or_else(|| {
2441 XlogError::Kernel("xgcf_copy_root_cached_meta kernel not found".to_string())
2442 })?;
2443 unsafe {
2445 copy_root.clone().launch(
2446 LaunchConfig {
2447 grid_dim: (1, 1, 1),
2448 block_dim: (1, 1, 1),
2449 shared_mem_bytes: 0,
2450 },
2451 (
2452 handle.slot_device(),
2453 self.node_cap,
2454 &self.values,
2455 &self.meta_root,
2456 out_log_z,
2457 ),
2458 )
2459 }
2460 .map_err(|e| XlogError::Kernel(format!("xgcf_copy_root_cached_meta failed: {}", e)))?;
2461
2462 Ok(())
2465 }
2466
2467 pub fn eval_grads_inplace(&mut self, handle: &GpuCircuitCacheHandle) -> Result<()> {
2468 let device = self.provider.device().inner();
2469 let eval_all = device
2470 .get_func(
2471 xlog_cuda::CIRCUIT_MODULE,
2472 xlog_cuda::circuit_kernels::XGCF_EVAL_ALL_LEVELS_CACHED,
2473 )
2474 .ok_or_else(|| {
2475 XlogError::Kernel("xgcf_eval_all_levels_cached kernel not found".to_string())
2476 })?;
2477 let block_size: u32 = 256;
2478 let mut params: Vec<*mut std::ffi::c_void> = vec![
2479 handle.slot_device().as_kernel_param(),
2480 self.node_cap.as_kernel_param(),
2481 self.edge_cap.as_kernel_param(),
2482 self.level_cap.as_kernel_param(),
2483 self.var_cap.as_kernel_param(),
2484 (&self.node_type).as_kernel_param(),
2485 (&self.child_offsets).as_kernel_param(),
2486 (&self.child_indices).as_kernel_param(),
2487 (&self.lit).as_kernel_param(),
2488 (&self.decision_var).as_kernel_param(),
2489 (&self.decision_child_false).as_kernel_param(),
2490 (&self.decision_child_true).as_kernel_param(),
2491 (&self.level_nodes).as_kernel_param(),
2492 (&self.level_offsets).as_kernel_param(),
2493 (&self.var_log_true).as_kernel_param(),
2494 (&self.var_log_false).as_kernel_param(),
2495 (&self.values).as_kernel_param(),
2496 (&self.meta_num_levels).as_kernel_param(),
2497 ];
2498 unsafe {
2500 eval_all.clone().launch(
2501 LaunchConfig {
2502 grid_dim: (1, 1, 1),
2503 block_dim: (block_size, 1, 1),
2504 shared_mem_bytes: 0,
2505 },
2506 &mut params,
2507 )
2508 }
2509 .map_err(|e| XlogError::Kernel(format!("xgcf_eval_all_levels_cached failed: {}", e)))?;
2510
2511 let device = self.provider.device().inner();
2512 let store_f64 = device
2513 .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_F64)
2514 .ok_or_else(|| XlogError::Kernel("cache_store_f64 kernel not found".to_string()))?;
2515
2516 let node_stride = self.node_cap;
2517 let var_stride = self.var_cap.checked_add(1).ok_or_else(|| {
2518 XlogError::Compilation("GPU cache eval_grads var_cap overflow".to_string())
2519 })?;
2520 let weights_len = self.var_cap.checked_add(1).ok_or_else(|| {
2521 XlogError::Compilation("GPU cache eval_grads var_cap overflow".to_string())
2522 })?;
2523
2524 let grid_nodes = cache_grid_dim_for_u32_count(
2525 "GpuCircuitCache eval_grads zero adj",
2526 self.node_cap,
2527 block_size,
2528 )?;
2529 if grid_nodes != 0 {
2530 unsafe {
2532 store_f64.clone().launch(
2533 LaunchConfig {
2534 grid_dim: (grid_nodes, 1, 1),
2535 block_dim: (block_size, 1, 1),
2536 shared_mem_bytes: 0,
2537 },
2538 (
2539 handle.slot_device(),
2540 &self.always_on,
2541 node_stride,
2542 &self.zero_f64,
2543 &mut self.adj,
2544 self.node_cap,
2545 ),
2546 )
2547 }
2548 .map_err(|e| XlogError::Kernel(format!("cache zero adj failed: {}", e)))?;
2549 }
2550
2551 let grid_weights = cache_grid_dim_for_u32_count(
2552 "GpuCircuitCache eval_grads zero weights",
2553 weights_len,
2554 block_size,
2555 )?;
2556 if grid_weights != 0 {
2557 unsafe {
2559 store_f64.clone().launch(
2560 LaunchConfig {
2561 grid_dim: (grid_weights, 1, 1),
2562 block_dim: (block_size, 1, 1),
2563 shared_mem_bytes: 0,
2564 },
2565 (
2566 handle.slot_device(),
2567 &self.always_on,
2568 var_stride,
2569 &self.zero_f64,
2570 &mut self.grad_true,
2571 weights_len,
2572 ),
2573 )
2574 }
2575 .map_err(|e| XlogError::Kernel(format!("cache zero grad_true failed: {}", e)))?;
2576
2577 unsafe {
2579 store_f64.clone().launch(
2580 LaunchConfig {
2581 grid_dim: (grid_weights, 1, 1),
2582 block_dim: (block_size, 1, 1),
2583 shared_mem_bytes: 0,
2584 },
2585 (
2586 handle.slot_device(),
2587 &self.always_on,
2588 var_stride,
2589 &self.zero_f64,
2590 &mut self.grad_false,
2591 weights_len,
2592 ),
2593 )
2594 }
2595 .map_err(|e| XlogError::Kernel(format!("cache zero grad_false failed: {}", e)))?;
2596 }
2597
2598 let add_scalar = device
2599 .get_func(
2600 xlog_cuda::CIRCUIT_MODULE,
2601 xlog_cuda::circuit_kernels::XGCF_ADD_SCALAR_CACHED,
2602 )
2603 .ok_or_else(|| {
2604 XlogError::Kernel("xgcf_add_scalar_cached kernel not found".to_string())
2605 })?;
2606 unsafe {
2608 add_scalar.clone().launch(
2609 LaunchConfig {
2610 grid_dim: (1, 1, 1),
2611 block_dim: (1, 1, 1),
2612 shared_mem_bytes: 0,
2613 },
2614 (
2615 handle.slot_device(),
2616 self.node_cap,
2617 &mut self.adj,
2618 &self.meta_root,
2619 &self.one_f64,
2620 ),
2621 )
2622 }
2623 .map_err(|e| XlogError::Kernel(format!("xgcf_add_scalar_cached (adj) failed: {}", e)))?;
2624
2625 let propagate = device
2626 .get_func(
2627 xlog_cuda::CIRCUIT_MODULE,
2628 xlog_cuda::circuit_kernels::XGCF_BACKWARD_LEVEL_PROPAGATE_CACHED,
2629 )
2630 .ok_or_else(|| {
2631 XlogError::Kernel(
2632 "xgcf_backward_level_propagate_cached kernel not found".to_string(),
2633 )
2634 })?;
2635 let decision_grad = device
2636 .get_func(
2637 xlog_cuda::CIRCUIT_MODULE,
2638 xlog_cuda::circuit_kernels::XGCF_BACKWARD_LEVEL_DECISION_GRAD_CACHED,
2639 )
2640 .ok_or_else(|| {
2641 XlogError::Kernel(
2642 "xgcf_backward_level_decision_grad_cached kernel not found".to_string(),
2643 )
2644 })?;
2645 let lit_grad = device
2646 .get_func(
2647 xlog_cuda::CIRCUIT_MODULE,
2648 xlog_cuda::circuit_kernels::XGCF_BACKWARD_LEVEL_LIT_GRAD_CACHED,
2649 )
2650 .ok_or_else(|| {
2651 XlogError::Kernel(
2652 "xgcf_backward_level_lit_grad_cached kernel not found".to_string(),
2653 )
2654 })?;
2655
2656 let num_blocks = self.node_cap.div_ceil(block_size);
2657 let num_levels = self.level_cap;
2658 for level in (0..num_levels).rev() {
2659 if num_blocks == 0 {
2660 continue;
2661 }
2662 let level_u32: u32 = level;
2663 let mut params: Vec<*mut std::ffi::c_void> = vec![
2664 handle.slot_device().as_kernel_param(),
2665 self.node_cap.as_kernel_param(),
2666 self.edge_cap.as_kernel_param(),
2667 self.level_cap.as_kernel_param(),
2668 self.var_cap.as_kernel_param(),
2669 (&self.node_type).as_kernel_param(),
2670 (&self.child_offsets).as_kernel_param(),
2671 (&self.child_indices).as_kernel_param(),
2672 (&self.decision_var).as_kernel_param(),
2673 (&self.decision_child_false).as_kernel_param(),
2674 (&self.decision_child_true).as_kernel_param(),
2675 (&self.level_nodes).as_kernel_param(),
2676 (&self.level_offsets).as_kernel_param(),
2677 level_u32.as_kernel_param(),
2678 (&self.var_log_true).as_kernel_param(),
2679 (&self.var_log_false).as_kernel_param(),
2680 (&self.values).as_kernel_param(),
2681 (&self.adj).as_kernel_param(),
2682 (&self.meta_num_levels).as_kernel_param(),
2683 ];
2684
2685 unsafe {
2687 propagate.clone().launch(
2688 LaunchConfig {
2689 grid_dim: (num_blocks, 1, 1),
2690 block_dim: (block_size, 1, 1),
2691 shared_mem_bytes: 0,
2692 },
2693 &mut params,
2694 )
2695 }
2696 .map_err(|e| {
2697 XlogError::Kernel(format!(
2698 "xgcf_backward_level_propagate_cached failed: {}",
2699 e
2700 ))
2701 })?;
2702
2703 let mut params: Vec<*mut std::ffi::c_void> = vec![
2704 handle.slot_device().as_kernel_param(),
2705 self.node_cap.as_kernel_param(),
2706 self.edge_cap.as_kernel_param(),
2707 self.level_cap.as_kernel_param(),
2708 self.var_cap.as_kernel_param(),
2709 (&self.node_type).as_kernel_param(),
2710 (&self.decision_var).as_kernel_param(),
2711 (&self.decision_child_false).as_kernel_param(),
2712 (&self.decision_child_true).as_kernel_param(),
2713 (&self.level_nodes).as_kernel_param(),
2714 (&self.level_offsets).as_kernel_param(),
2715 level_u32.as_kernel_param(),
2716 (&self.var_log_true).as_kernel_param(),
2717 (&self.var_log_false).as_kernel_param(),
2718 (&self.values).as_kernel_param(),
2719 (&self.adj).as_kernel_param(),
2720 (&self.grad_true).as_kernel_param(),
2721 (&self.grad_false).as_kernel_param(),
2722 (&self.meta_num_levels).as_kernel_param(),
2723 ];
2724
2725 unsafe {
2727 decision_grad.clone().launch(
2728 LaunchConfig {
2729 grid_dim: (num_blocks, 1, 1),
2730 block_dim: (block_size, 1, 1),
2731 shared_mem_bytes: 0,
2732 },
2733 &mut params,
2734 )
2735 }
2736 .map_err(|e| {
2737 XlogError::Kernel(format!(
2738 "xgcf_backward_level_decision_grad_cached failed: {}",
2739 e
2740 ))
2741 })?;
2742
2743 let mut params: Vec<*mut std::ffi::c_void> = vec![
2744 handle.slot_device().as_kernel_param(),
2745 self.node_cap.as_kernel_param(),
2746 self.edge_cap.as_kernel_param(),
2747 self.level_cap.as_kernel_param(),
2748 self.var_cap.as_kernel_param(),
2749 (&self.node_type).as_kernel_param(),
2750 (&self.lit).as_kernel_param(),
2751 (&self.level_nodes).as_kernel_param(),
2752 (&self.level_offsets).as_kernel_param(),
2753 level_u32.as_kernel_param(),
2754 (&self.adj).as_kernel_param(),
2755 (&self.grad_true).as_kernel_param(),
2756 (&self.grad_false).as_kernel_param(),
2757 (&self.meta_num_levels).as_kernel_param(),
2758 ];
2759
2760 unsafe {
2762 lit_grad.clone().launch(
2763 LaunchConfig {
2764 grid_dim: (num_blocks, 1, 1),
2765 block_dim: (block_size, 1, 1),
2766 shared_mem_bytes: 0,
2767 },
2768 &mut params,
2769 )
2770 }
2771 .map_err(|e| {
2772 XlogError::Kernel(format!("xgcf_backward_level_lit_grad_cached failed: {}", e))
2773 })?;
2774 }
2775
2776 self.apply_free_var_correction_cached(handle, true, true)?;
2777 Ok(())
2780 }
2781
2782 pub fn eval_grads_inplace_fused(&mut self, handle: &GpuCircuitCacheHandle) -> Result<()> {
2787 let device = self.provider.device().inner();
2788 let eval_all = device
2789 .get_func(
2790 xlog_cuda::CIRCUIT_MODULE,
2791 xlog_cuda::circuit_kernels::XGCF_EVAL_ALL_LEVELS_CACHED,
2792 )
2793 .ok_or_else(|| {
2794 XlogError::Kernel("xgcf_eval_all_levels_cached kernel not found".to_string())
2795 })?;
2796 let block_size: u32 = 256;
2797 let mut params: Vec<*mut std::ffi::c_void> = vec![
2798 handle.slot_device().as_kernel_param(),
2799 self.node_cap.as_kernel_param(),
2800 self.edge_cap.as_kernel_param(),
2801 self.level_cap.as_kernel_param(),
2802 self.var_cap.as_kernel_param(),
2803 (&self.node_type).as_kernel_param(),
2804 (&self.child_offsets).as_kernel_param(),
2805 (&self.child_indices).as_kernel_param(),
2806 (&self.lit).as_kernel_param(),
2807 (&self.decision_var).as_kernel_param(),
2808 (&self.decision_child_false).as_kernel_param(),
2809 (&self.decision_child_true).as_kernel_param(),
2810 (&self.level_nodes).as_kernel_param(),
2811 (&self.level_offsets).as_kernel_param(),
2812 (&self.var_log_true).as_kernel_param(),
2813 (&self.var_log_false).as_kernel_param(),
2814 (&self.values).as_kernel_param(),
2815 (&self.meta_num_levels).as_kernel_param(),
2816 ];
2817 unsafe {
2819 eval_all.clone().launch(
2820 LaunchConfig {
2821 grid_dim: (1, 1, 1),
2822 block_dim: (block_size, 1, 1),
2823 shared_mem_bytes: 0,
2824 },
2825 &mut params,
2826 )
2827 }
2828 .map_err(|e| XlogError::Kernel(format!("xgcf_eval_all_levels_cached failed: {}", e)))?;
2829
2830 let device = self.provider.device().inner();
2831 let store_f64 = device
2832 .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_F64)
2833 .ok_or_else(|| XlogError::Kernel("cache_store_f64 kernel not found".to_string()))?;
2834
2835 let node_stride = self.node_cap;
2836 let var_stride = self.var_cap.checked_add(1).ok_or_else(|| {
2837 XlogError::Compilation("GPU cache eval_grads var_cap overflow".to_string())
2838 })?;
2839 let weights_len = self.var_cap.checked_add(1).ok_or_else(|| {
2840 XlogError::Compilation("GPU cache eval_grads var_cap overflow".to_string())
2841 })?;
2842
2843 let grid_nodes = cache_grid_dim_for_u32_count(
2844 "GpuCircuitCache batched eval_grads zero adj",
2845 self.node_cap,
2846 block_size,
2847 )?;
2848 if grid_nodes != 0 {
2849 unsafe {
2851 store_f64.clone().launch(
2852 LaunchConfig {
2853 grid_dim: (grid_nodes, 1, 1),
2854 block_dim: (block_size, 1, 1),
2855 shared_mem_bytes: 0,
2856 },
2857 (
2858 handle.slot_device(),
2859 &self.always_on,
2860 node_stride,
2861 &self.zero_f64,
2862 &mut self.adj,
2863 self.node_cap,
2864 ),
2865 )
2866 }
2867 .map_err(|e| XlogError::Kernel(format!("cache zero adj failed: {}", e)))?;
2868 }
2869
2870 let grid_weights = cache_grid_dim_for_u32_count(
2871 "GpuCircuitCache batched eval_grads zero weights",
2872 weights_len,
2873 block_size,
2874 )?;
2875 if grid_weights != 0 {
2876 unsafe {
2878 store_f64.clone().launch(
2879 LaunchConfig {
2880 grid_dim: (grid_weights, 1, 1),
2881 block_dim: (block_size, 1, 1),
2882 shared_mem_bytes: 0,
2883 },
2884 (
2885 handle.slot_device(),
2886 &self.always_on,
2887 var_stride,
2888 &self.zero_f64,
2889 &mut self.grad_true,
2890 weights_len,
2891 ),
2892 )
2893 }
2894 .map_err(|e| XlogError::Kernel(format!("cache zero grad_true failed: {}", e)))?;
2895
2896 unsafe {
2898 store_f64.clone().launch(
2899 LaunchConfig {
2900 grid_dim: (grid_weights, 1, 1),
2901 block_dim: (block_size, 1, 1),
2902 shared_mem_bytes: 0,
2903 },
2904 (
2905 handle.slot_device(),
2906 &self.always_on,
2907 var_stride,
2908 &self.zero_f64,
2909 &mut self.grad_false,
2910 weights_len,
2911 ),
2912 )
2913 }
2914 .map_err(|e| XlogError::Kernel(format!("cache zero grad_false failed: {}", e)))?;
2915 }
2916
2917 let add_scalar = device
2918 .get_func(
2919 xlog_cuda::CIRCUIT_MODULE,
2920 xlog_cuda::circuit_kernels::XGCF_ADD_SCALAR_CACHED,
2921 )
2922 .ok_or_else(|| {
2923 XlogError::Kernel("xgcf_add_scalar_cached kernel not found".to_string())
2924 })?;
2925 unsafe {
2927 add_scalar.clone().launch(
2928 LaunchConfig {
2929 grid_dim: (1, 1, 1),
2930 block_dim: (1, 1, 1),
2931 shared_mem_bytes: 0,
2932 },
2933 (
2934 handle.slot_device(),
2935 self.node_cap,
2936 &mut self.adj,
2937 &self.meta_root,
2938 &self.one_f64,
2939 ),
2940 )
2941 }
2942 .map_err(|e| XlogError::Kernel(format!("xgcf_add_scalar_cached (adj) failed: {}", e)))?;
2943
2944 let backward_all = device
2946 .get_func(
2947 xlog_cuda::CIRCUIT_MODULE,
2948 xlog_cuda::circuit_kernels::XGCF_BACKWARD_ALL_LEVELS_CACHED,
2949 )
2950 .ok_or_else(|| XlogError::Kernel("xgcf_backward_all_levels_cached not found".into()))?;
2951
2952 let mut params: Vec<*mut std::ffi::c_void> = vec![
2953 handle.slot_device().as_kernel_param(),
2954 self.node_cap.as_kernel_param(),
2955 self.edge_cap.as_kernel_param(),
2956 self.level_cap.as_kernel_param(),
2957 self.var_cap.as_kernel_param(),
2958 (&self.node_type).as_kernel_param(),
2959 (&self.child_offsets).as_kernel_param(),
2960 (&self.child_indices).as_kernel_param(),
2961 (&self.decision_var).as_kernel_param(),
2962 (&self.decision_child_false).as_kernel_param(),
2963 (&self.decision_child_true).as_kernel_param(),
2964 (&self.lit).as_kernel_param(),
2965 (&self.level_nodes).as_kernel_param(),
2966 (&self.level_offsets).as_kernel_param(),
2967 (&self.var_log_true).as_kernel_param(),
2968 (&self.var_log_false).as_kernel_param(),
2969 (&self.values).as_kernel_param(),
2970 (&self.adj).as_kernel_param(),
2971 (&self.grad_true).as_kernel_param(),
2972 (&self.grad_false).as_kernel_param(),
2973 (&self.meta_num_levels).as_kernel_param(),
2974 ];
2975
2976 unsafe {
2978 backward_all.clone().launch(
2979 LaunchConfig {
2980 grid_dim: (1, 1, 1),
2981 block_dim: (block_size, 1, 1),
2982 shared_mem_bytes: 0,
2983 },
2984 &mut params,
2985 )
2986 }
2987 .map_err(|e| XlogError::Kernel(format!("xgcf_backward_all_levels_cached failed: {}", e)))?;
2988
2989 self.apply_free_var_correction_cached(handle, true, true)?;
2990 Ok(())
2991 }
2992
2993 fn apply_free_var_correction_cached(
2994 &mut self,
2995 handle: &GpuCircuitCacheHandle,
2996 apply_log_z: bool,
2997 apply_grads: bool,
2998 ) -> Result<()> {
2999 if !self.has_free_var_mask_for_slot(handle.slot_index()) {
3000 return Ok(());
3001 }
3002 let n = self
3003 .var_cap
3004 .checked_add(1)
3005 .ok_or_else(|| XlogError::Compilation("GPU cache free-var overflow".to_string()))?;
3006 if n == 0 {
3007 return Ok(());
3008 }
3009
3010 let device = self.provider.device().inner();
3011 let block_dim = 256u32;
3012 let grid_dim = n.div_ceil(block_dim);
3013
3014 if apply_grads {
3015 let apply_grad = device
3016 .get_func(
3017 xlog_cuda::CIRCUIT_MODULE,
3018 xlog_cuda::circuit_kernels::XGCF_FREE_VAR_APPLY_GRAD_CACHED,
3019 )
3020 .ok_or_else(|| {
3021 XlogError::Kernel(
3022 "xgcf_free_var_apply_grad_cached kernel not found".to_string(),
3023 )
3024 })?;
3025 unsafe {
3027 apply_grad.clone().launch(
3028 LaunchConfig {
3029 grid_dim: (grid_dim, 1, 1),
3030 block_dim: (block_dim, 1, 1),
3031 shared_mem_bytes: 0,
3032 },
3033 (
3034 handle.slot_device(),
3035 self.var_cap,
3036 &self.free_var_mask,
3037 &self.var_log_true,
3038 &self.var_log_false,
3039 n,
3040 &mut self.grad_true,
3041 &mut self.grad_false,
3042 ),
3043 )
3044 }
3045 .map_err(|e| {
3046 XlogError::Kernel(format!("xgcf_free_var_apply_grad_cached failed: {}", e))
3047 })?;
3048 }
3049
3050 if apply_log_z {
3051 let reduce_stage = device
3052 .get_func(
3053 xlog_cuda::CIRCUIT_MODULE,
3054 xlog_cuda::circuit_kernels::XGCF_FREE_VAR_REDUCE_STAGE_CACHED,
3055 )
3056 .ok_or_else(|| {
3057 XlogError::Kernel(
3058 "xgcf_free_var_reduce_stage_cached kernel not found".to_string(),
3059 )
3060 })?;
3061 let add_scalar = device
3062 .get_func(
3063 xlog_cuda::CIRCUIT_MODULE,
3064 xlog_cuda::circuit_kernels::XGCF_ADD_SCALAR_CACHED,
3065 )
3066 .ok_or_else(|| {
3067 XlogError::Kernel("xgcf_add_scalar_cached kernel not found".to_string())
3068 })?;
3069
3070 let memory = self.provider.memory();
3071 let mut buf_a = memory.alloc::<f64>(n as usize)?;
3072 let mut buf_b = memory.alloc::<f64>(n as usize)?;
3073
3074 let mut stage_n = n;
3075 let mut stage0 = true;
3076 let mut output_is_a = true;
3077 loop {
3078 let out_len = stage_n.div_ceil(2);
3079 let stage_grid = out_len.div_ceil(block_dim);
3080
3081 let (in_buf, out_buf): (&TrackedCudaSlice<f64>, &mut TrackedCudaSlice<f64>) =
3082 if output_is_a {
3083 (&buf_b, &mut buf_a)
3084 } else {
3085 (&buf_a, &mut buf_b)
3086 };
3087 let mode = if stage0 { 0u32 } else { 1u32 };
3088
3089 unsafe {
3091 reduce_stage.clone().launch(
3092 LaunchConfig {
3093 grid_dim: (stage_grid, 1, 1),
3094 block_dim: (block_dim, 1, 1),
3095 shared_mem_bytes: 0,
3096 },
3097 (
3098 handle.slot_device(),
3099 self.var_cap,
3100 &self.free_var_mask,
3101 &self.var_log_true,
3102 &self.var_log_false,
3103 in_buf,
3104 stage_n,
3105 mode,
3106 out_buf,
3107 ),
3108 )
3109 }
3110 .map_err(|e| {
3111 XlogError::Kernel(format!("xgcf_free_var_reduce_stage_cached failed: {}", e))
3112 })?;
3113
3114 if out_len == 1 {
3115 let result_buf = if output_is_a { &buf_a } else { &buf_b };
3116 unsafe {
3118 add_scalar.clone().launch(
3119 LaunchConfig {
3120 grid_dim: (1, 1, 1),
3121 block_dim: (1, 1, 1),
3122 shared_mem_bytes: 0,
3123 },
3124 (
3125 handle.slot_device(),
3126 self.node_cap,
3127 &mut self.values,
3128 &self.meta_root,
3129 result_buf,
3130 ),
3131 )
3132 }
3133 .map_err(|e| {
3134 XlogError::Kernel(format!("xgcf_add_scalar_cached failed: {}", e))
3135 })?;
3136 break;
3137 }
3138
3139 stage_n = out_len;
3140 stage0 = false;
3141 output_is_a = !output_is_a;
3142 }
3143 }
3144
3145 Ok(())
3146 }
3147}