1pub mod index;
16pub mod provenance;
17pub mod reduce;
18pub mod score;
19pub mod types;
20mod validate;
21
22pub use provenance::InductionProvenanceRegistry;
23pub use reduce::{reduce_per_topology, ScoredPair};
24pub use types::{
25 ExactInductionConfig, ExactInductionResult, InducedRuleProvenance, InducedRuleRegistry,
26 InductionAlternative, InductionSupportRow, RuleSourceKind, ScoredCandidate, Topology,
27};
28
29use xlog_core::{RelId, Result, ScalarType, XlogError};
30use xlog_cuda::{CudaBuffer, CudaKernelProvider};
31
32use validate::{classify_request, PreKernelOutcome, RequestMetadata};
33
/// Element type shared by both columns of a pair buffer accepted by exact induction.
///
/// `validate_pair_buffer` derives this from the buffer schema; every buffer in a single
/// request has to resolve to the same variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExactPairType {
    /// Both columns are `ScalarType::U64`.
    U64,
    /// Both columns are `ScalarType::U32`.
    U32,
    /// Both columns are `ScalarType::Symbol`.
    Symbol,
}
40
41pub struct InduceExactRequest<'a> {
51 pub head_rel_idx: RelId,
52 pub candidates: &'a [(RelId, &'a CudaBuffer)],
53 pub positives: &'a CudaBuffer,
54 pub negatives: Option<&'a CudaBuffer>,
55 pub config: ExactInductionConfig,
56}
57
58pub fn induce_exact(
67 provider: &CudaKernelProvider,
68 request: &InduceExactRequest<'_>,
69) -> Result<ExactInductionResult> {
70 if request.candidates.is_empty() {
73 return Ok(ExactInductionResult::default());
74 }
75
76 let pair_type = validate_pair_buffer(request.positives, "positives")?;
79 if let Some(neg) = request.negatives {
80 require_pair_type(neg, "negatives", pair_type)?;
81 }
82 for (i, (_, buf)) in request.candidates.iter().enumerate() {
83 require_pair_type(buf, &format!("candidate[{}]", i), pair_type)?;
84 }
85
86 let pos_count = cached_rows(request.positives, "positives")?;
92 let neg_count = request
93 .negatives
94 .map(|b| cached_rows(b, "negatives"))
95 .transpose()?
96 .unwrap_or(0);
97
98 let meta = RequestMetadata {
99 candidate_count: request.candidates.len() as u32,
100 positive_count: pos_count,
101 negative_count: neg_count,
102 k_per_topology: request.config.k_per_topology,
103 };
104
105 match classify_request(meta) {
106 PreKernelOutcome::TrivialEmpty(result) => Ok(result),
107 PreKernelOutcome::Proceed(m) => score_and_reduce(provider, request, m),
108 }
109}
110
111fn score_and_reduce(
112 provider: &CudaKernelProvider,
113 request: &InduceExactRequest<'_>,
114 meta: RequestMetadata,
115) -> Result<ExactInductionResult> {
116 let empty_neg_holder: Option<CudaBuffer> = if request.negatives.is_none() {
121 Some(provider.create_empty_buffer(request.positives.schema().clone())?)
122 } else {
123 None
124 };
125 let negatives: &CudaBuffer = match request.negatives {
126 Some(b) => b,
127 None => empty_neg_holder
128 .as_ref()
129 .expect("holder populated in the None branch above"),
130 };
131
132 let candidate_buffers: Vec<&CudaBuffer> = request.candidates.iter().map(|(_, b)| *b).collect();
134 let selected = provider.ilp_exact_score_topk(
135 &candidate_buffers,
136 request.positives,
137 negatives,
138 request.config.k_per_topology,
139 )?;
140 let mut candidates = Vec::with_capacity(selected.len());
141 for row in selected {
142 let topology = topology_from_kernel_idx(row.topology_idx)?;
143 let left_idx = row.left_idx as usize;
144 let right_idx = row.right_idx as usize;
145 let (left_rel_idx, _) = request.candidates.get(left_idx).ok_or_else(|| {
146 XlogError::Execution(format!(
147 "induce_exact: device selector returned left index {} for {} candidates",
148 left_idx,
149 request.candidates.len()
150 ))
151 })?;
152 let (right_rel_idx, _) = request.candidates.get(right_idx).ok_or_else(|| {
153 XlogError::Execution(format!(
154 "induce_exact: device selector returned right index {} for {} candidates",
155 right_idx,
156 request.candidates.len()
157 ))
158 })?;
159 candidates.push(ScoredCandidate {
160 topology,
161 head_rel_idx: request.head_rel_idx,
162 left_rel_idx: *left_rel_idx,
163 right_rel_idx: *right_rel_idx,
164 positives_covered: row.positives_covered,
165 negatives_covered: row.negatives_covered,
166 local_rank: row.local_rank,
167 next_positives_covered: row.next_positives_covered,
168 next_negatives_covered: row.next_negatives_covered,
169 tie_class_size: row.tie_class_size,
170 });
171 }
172 let total_scored = 4u32
173 .checked_mul(meta.candidate_count)
174 .and_then(|v| v.checked_mul(meta.candidate_count))
175 .ok_or_else(|| XlogError::Execution("induce_exact: total_scored overflow".into()))?;
176
177 Ok(ExactInductionResult {
178 candidates,
179 total_scored,
180 candidate_count: meta.candidate_count,
181 positive_count: meta.positive_count,
182 negative_count: meta.negative_count,
183 })
184}
185
186fn topology_from_kernel_idx(idx: u32) -> Result<Topology> {
187 match idx {
188 0 => Ok(Topology::Chain),
189 1 => Ok(Topology::Star),
190 2 => Ok(Topology::Fanout),
191 3 => Ok(Topology::Fanin),
192 _ => Err(XlogError::Execution(format!(
193 "induce_exact: device selector returned topology index {}",
194 idx
195 ))),
196 }
197}
198
199fn validate_pair_buffer(buf: &CudaBuffer, label: &str) -> Result<ExactPairType> {
200 if buf.arity() != 2 {
201 return Err(XlogError::Execution(format!(
202 "induce_exact: {} buffer has arity {}, expected 2",
203 label,
204 buf.arity(),
205 )));
206 }
207 let mut pair_type = None;
208 for col_idx in 0..2 {
209 let t = buf.schema().column_type(col_idx).ok_or_else(|| {
210 XlogError::Type(format!(
211 "induce_exact: {} buffer column {} has no schema type",
212 label, col_idx,
213 ))
214 })?;
215 let col_type = match t {
216 ScalarType::U64 => ExactPairType::U64,
217 ScalarType::U32 => ExactPairType::U32,
218 ScalarType::Symbol => ExactPairType::Symbol,
219 _ => {
220 return Err(XlogError::Type(format!(
221 "induce_exact: {} buffer column {} has type {:?}, expected U64, U32, or Symbol",
222 label, col_idx, t,
223 )));
224 }
225 };
226 if let Some(expected) = pair_type {
227 if expected != col_type {
228 return Err(XlogError::Type(format!(
229 "induce_exact: {} buffer column {} type mismatch: {:?} vs {:?}",
230 label, col_idx, expected, col_type,
231 )));
232 }
233 } else {
234 pair_type = Some(col_type);
235 }
236 }
237 Ok(pair_type.expect("arity 2 loop sets pair type"))
238}
239
240fn require_pair_type(buf: &CudaBuffer, label: &str, expected: ExactPairType) -> Result<()> {
241 let actual = validate_pair_buffer(buf, label)?;
242 if actual != expected {
243 return Err(XlogError::Type(format!(
244 "induce_exact: {} buffer type mismatch: expected {:?}, got {:?}",
245 label, expected, actual,
246 )));
247 }
248 Ok(())
249}
250
251fn cached_rows(buf: &CudaBuffer, label: &str) -> Result<u32> {
252 buf.cached_row_count().ok_or_else(|| {
253 XlogError::Execution(format!(
254 "induce_exact: {} buffer has no cached row count \
255 (DLPack ingest path should populate it; required to avoid hot-loop device-to-host transfer)",
256 label,
257 ))
258 })
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// The string form of each topology is a fixed contract (per the test name, shared with
    /// the Python side).
    #[test]
    fn topology_as_str_matches_python_contract() {
        let expected_names = [
            (Topology::Chain, "chain"),
            (Topology::Star, "star"),
            (Topology::Fanout, "fanout"),
            (Topology::Fanin, "fanin"),
        ];
        for (topology, name) in expected_names {
            assert_eq!(topology.as_str(), name);
        }
    }

    /// `Topology::ALL` must list the variants in the order the engine uses.
    #[test]
    fn topology_all_is_engine_order() {
        let engine_order = [
            Topology::Chain,
            Topology::Star,
            Topology::Fanout,
            Topology::Fanin,
        ];
        assert_eq!(Topology::ALL, engine_order);
    }
}