Skip to main content

xlog_runtime/executor/
join_cache.rs

1use std::collections::HashMap;
2use xlog_core::{RelId, ScalarType, Schema};
3use xlog_cuda::{CudaBuffer, JoinIndexV2};
4
/// Cache key identifying one persistent join index.
///
/// Keys compare equal only when relation, relation version, key columns, schema signature,
/// and device ordinal all match. An index built for one version, schema, or device is
/// therefore never returned for another.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct JoinIndexKey {
    /// Relation the index was built over.
    pub(crate) rel: RelId,
    /// Relation version at build time; bumping it yields a distinct key.
    pub(crate) version: u64,
    /// Join-key column indices; order is significant for equality.
    pub(crate) key_cols: Vec<usize>,
    /// Column types and row size of the relation's schema.
    pub(crate) schema: JoinIndexSchemaSignature,
    /// Device the index is associated with (presumably the CUDA device ordinal — confirm).
    pub(crate) device_ordinal: u32,
}
13
/// Hashable summary of a relation schema, used to partition cached join indexes.
///
/// Only per-column scalar types and the row size are captured. Column names are not part
/// of the signature, so schemas that differ only by name produce equal signatures.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct JoinIndexSchemaSignature {
    /// Scalar type of each column, in column order.
    column_types: Vec<ScalarType>,
    /// `Schema::row_size_bytes()` of the summarised schema.
    row_size_bytes: usize,
}
19
20impl JoinIndexSchemaSignature {
21    fn from_schema(schema: &Schema) -> Self {
22        Self {
23            column_types: (0..schema.arity())
24                .filter_map(|idx| schema.column_type(idx))
25                .collect(),
26            row_size_bytes: schema.row_size_bytes(),
27        }
28    }
29}
30
31impl JoinIndexKey {
32    pub(crate) fn new(
33        rel: RelId,
34        version: u64,
35        key_cols: Vec<usize>,
36        schema: &Schema,
37        device_ordinal: u32,
38    ) -> Self {
39        Self {
40            rel,
41            version,
42            key_cols,
43            schema: JoinIndexSchemaSignature::from_schema(schema),
44            device_ordinal,
45        }
46    }
47}
48
/// One cache entry: the index payload plus bookkeeping for the byte budget and LRU eviction.
struct CachedJoinIndex {
    /// The cached index (or, in tests, a placeholder).
    index: CachedJoinIndexPayload,
    /// Bytes charged against the cache budget; for real entries this is
    /// `JoinIndexV2::estimated_bytes()` at insert time.
    bytes: u64,
    /// Cache clock value when this entry was last inserted or found by a lookup.
    last_used: u64,
}
54
/// Payload stored in a cache entry.
// `Placeholder` exists only under `cfg(test)`, so production builds have a single variant.
// The size-disparity lint is presumably only triggered in test builds and is silenced here
// rather than boxing the real index. NOTE(review): confirm that is the intent of the allow.
#[allow(clippy::large_enum_variant)]
enum CachedJoinIndexPayload {
    /// A built index ready for reuse.
    Ready(JoinIndexV2),
    /// Test-only entry carrying no index; lookups on it are reported as misses.
    #[cfg(test)]
    Placeholder,
}
61
/// Persistent join-index manager telemetry.
///
/// Counter fields accumulate for the lifetime of the cache; clearing the cache does not reset
/// them. `entries` and `total_bytes` are point-in-time values filled in when a snapshot is
/// taken with `JoinIndexCache::stats`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct JoinIndexCacheStats {
    /// Lookup attempts (every call to `JoinIndexCache::get`).
    pub lookups: u64,
    /// Lookups that returned a ready index.
    pub hits: u64,
    /// Lookups that returned nothing (key absent; in test builds also placeholder entries).
    pub misses: u64,
    /// Indexes accepted into the cache by `insert`; indexes larger than the budget are not counted.
    pub builds: u64,
    /// Entries removed, least recently used first, to make room under the byte budget.
    pub evictions: u64,
    /// Entries dropped because their relation was invalidated or the whole cache was cleared.
    pub invalidations: u64,
    /// Stale entries rejected by provider validation; counted only when an entry was removed.
    pub stale_rejections: u64,
    /// Background-build mode requests.
    pub background_build_requests: u64,
    /// Background-build mode completions.
    pub background_builds_completed: u64,
    /// Background builds whose indexed reuse was deferred until a later evaluation.
    pub background_builds_deferred: u64,
    /// Number of retained indexes at snapshot time.
    pub entries: usize,
    /// Sum of the retained indexes' charged bytes at snapshot time.
    pub total_bytes: u64,
}
90
/// Byte-budgeted, least-recently-used cache of join indexes keyed by `JoinIndexKey`.
pub(crate) struct JoinIndexCache {
    /// Retained entries.
    entries: HashMap<JoinIndexKey, CachedJoinIndex>,
    /// Logical clock, advanced on inserts and on lookups that find an entry.
    /// Used as the LRU timestamp.
    clock: u64,
    /// Running sum of `bytes` over `entries`.
    total_bytes: u64,
    /// Byte budget that inserts evict down to. It is `pub(crate)`, so other crate code may
    /// change it between calls.
    pub(crate) max_bytes: u64,
    /// Cumulative counters; the `entries`/`total_bytes` fields are filled in by `stats()`.
    stats: JoinIndexCacheStats,
}
98
99/// Estimate the GPU memory footprint of a join index built on `right` with `right_keys`.
100///
101/// Returns u64::MAX if keys are empty or column types are missing (signals "don't build").
102pub(crate) fn estimate_join_index_bytes(right: &CudaBuffer, right_keys: &[usize]) -> u64 {
103    if right_keys.is_empty() {
104        return u64::MAX;
105    }
106
107    let mut key_bytes_per_row: u64 = 0;
108    for &k in right_keys {
109        let Some(ty) = right.schema().column_type(k) else {
110            return u64::MAX;
111        };
112        key_bytes_per_row = key_bytes_per_row.saturating_add(ty.size_bytes() as u64);
113    }
114
115    let num_rows = right.num_rows();
116    let packed_bytes = num_rows.saturating_mul(key_bytes_per_row);
117    let target = num_rows.saturating_mul(2).max(1024);
118    let num_buckets = target.next_power_of_two();
119
120    // Stored index bytes: packed keys + (counts+offsets) + (entry row ids + entry hashes)
121    packed_bytes
122        .saturating_add(num_buckets.saturating_mul(8))
123        .saturating_add(num_rows.saturating_mul(12))
124}
125
126impl JoinIndexCache {
127    pub(crate) fn new(max_bytes: u64) -> Self {
128        Self {
129            entries: HashMap::new(),
130            clock: 0,
131            total_bytes: 0,
132            max_bytes,
133            stats: JoinIndexCacheStats::default(),
134        }
135    }
136
137    /// Decide whether to build a new join index for a relation.
138    ///
139    /// Heuristic: require higher "heat" for larger indexes, and avoid building under
140    /// memory pressure. Always skip if the estimated index cannot fit in the cache budget.
141    pub(crate) fn should_build(
142        &self,
143        est_index_bytes: u64,
144        build_heat: f32,
145        remaining_device_bytes: u64,
146        device_budget_bytes: u64,
147    ) -> bool {
148        let heat_threshold = if self.max_bytes > 0 && est_index_bytes > self.max_bytes / 2 {
149            0.6
150        } else {
151            0.3
152        };
153        let has_room =
154            remaining_device_bytes >= est_index_bytes.saturating_add(device_budget_bytes / 10);
155
156        build_heat >= heat_threshold && est_index_bytes <= self.max_bytes && has_room
157    }
158
159    pub(crate) fn clear(&mut self) {
160        let removed = self.entries.len() as u64;
161        self.entries.clear();
162        self.clock = 0;
163        self.total_bytes = 0;
164        self.stats.invalidations = self.stats.invalidations.saturating_add(removed);
165    }
166
167    pub(crate) fn get(&mut self, key: &JoinIndexKey) -> Option<&JoinIndexV2> {
168        self.stats.lookups = self.stats.lookups.saturating_add(1);
169        let Some(entry) = self.entries.get_mut(key) else {
170            self.stats.misses = self.stats.misses.saturating_add(1);
171            return None;
172        };
173        self.clock = self.clock.saturating_add(1);
174        entry.last_used = self.clock;
175        match &entry.index {
176            CachedJoinIndexPayload::Ready(index) => {
177                self.stats.hits = self.stats.hits.saturating_add(1);
178                Some(index)
179            }
180            #[cfg(test)]
181            CachedJoinIndexPayload::Placeholder => {
182                self.stats.misses = self.stats.misses.saturating_add(1);
183                None
184            }
185        }
186    }
187
188    pub(crate) fn insert(&mut self, key: JoinIndexKey, index: JoinIndexV2) {
189        let bytes = index.estimated_bytes();
190        if bytes > self.max_bytes {
191            return;
192        }
193
194        self.evict_until_fits(bytes);
195
196        self.clock = self.clock.saturating_add(1);
197        let last_used = self.clock;
198
199        if let Some(prev) = self.entries.remove(&key) {
200            self.total_bytes = self.total_bytes.saturating_sub(prev.bytes);
201        }
202
203        self.total_bytes = self.total_bytes.saturating_add(bytes);
204        self.entries.insert(
205            key,
206            CachedJoinIndex {
207                index: CachedJoinIndexPayload::Ready(index),
208                bytes,
209                last_used,
210            },
211        );
212        self.stats.builds = self.stats.builds.saturating_add(1);
213    }
214
215    pub(crate) fn remove(&mut self, key: &JoinIndexKey) {
216        if let Some(prev) = self.entries.remove(key) {
217            self.total_bytes = self.total_bytes.saturating_sub(prev.bytes);
218        }
219    }
220
221    pub(crate) fn remove_stale(&mut self, key: &JoinIndexKey) {
222        let before = self.entries.len();
223        self.remove(key);
224        if self.entries.len() < before {
225            self.stats.stale_rejections = self.stats.stale_rejections.saturating_add(1);
226        }
227    }
228
229    pub(crate) fn invalidate_rel(&mut self, rel: RelId) {
230        let keys: Vec<JoinIndexKey> = self
231            .entries
232            .keys()
233            .filter(|k| k.rel == rel)
234            .cloned()
235            .collect();
236        for key in keys {
237            if let Some(entry) = self.entries.remove(&key) {
238                self.total_bytes = self.total_bytes.saturating_sub(entry.bytes);
239                self.stats.invalidations = self.stats.invalidations.saturating_add(1);
240            }
241        }
242    }
243
244    pub(crate) fn evict_until_fits(&mut self, additional_bytes: u64) {
245        while !self.entries.is_empty()
246            && self.total_bytes.saturating_add(additional_bytes) > self.max_bytes
247        {
248            let mut oldest_key: Option<JoinIndexKey> = None;
249            let mut oldest_clock = u64::MAX;
250
251            for (k, v) in &self.entries {
252                if v.last_used < oldest_clock {
253                    oldest_clock = v.last_used;
254                    oldest_key = Some(k.clone());
255                }
256            }
257
258            let Some(key) = oldest_key else {
259                break;
260            };
261            if let Some(entry) = self.entries.remove(&key) {
262                self.total_bytes = self.total_bytes.saturating_sub(entry.bytes);
263                self.stats.evictions = self.stats.evictions.saturating_add(1);
264            } else {
265                break;
266            }
267        }
268    }
269
270    pub(crate) fn record_background_build_request(&mut self) {
271        self.stats.background_build_requests =
272            self.stats.background_build_requests.saturating_add(1);
273    }
274
275    pub(crate) fn record_background_build_complete(&mut self) {
276        self.stats.background_builds_completed =
277            self.stats.background_builds_completed.saturating_add(1);
278    }
279
280    pub(crate) fn record_background_build_deferred(&mut self) {
281        self.stats.background_builds_deferred =
282            self.stats.background_builds_deferred.saturating_add(1);
283    }
284
285    pub(crate) fn stats(&self) -> JoinIndexCacheStats {
286        let mut stats = self.stats.clone();
287        stats.entries = self.entries.len();
288        stats.total_bytes = self.total_bytes;
289        stats
290    }
291
292    #[cfg(test)]
293    fn insert_test_entry(&mut self, key: JoinIndexKey, bytes: u64) {
294        if bytes > self.max_bytes {
295            return;
296        }
297        self.evict_until_fits(bytes);
298        self.clock = self.clock.saturating_add(1);
299        self.total_bytes = self.total_bytes.saturating_add(bytes);
300        self.entries.insert(
301            key,
302            CachedJoinIndex {
303                index: CachedJoinIndexPayload::Placeholder,
304                bytes,
305                last_used: self.clock,
306            },
307        );
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use xlog_core::{ScalarType, Schema};

    /// Build a `Schema` from `(column name, scalar type)` pairs.
    fn make_schema(columns: Vec<(&str, ScalarType)>) -> Schema {
        Schema::new(
            columns
                .into_iter()
                .map(|(name, ty)| (String::from(name), ty))
                .collect(),
        )
    }

    #[test]
    fn persistent_key_includes_schema_generation_key_and_device() {
        let narrow = make_schema(vec![("k", ScalarType::U32)]);
        let wide = make_schema(vec![("k", ScalarType::U64)]);

        let base = JoinIndexKey::new(RelId(7), 3, vec![0], &narrow, 0);
        assert_eq!(base.rel, RelId(7));
        assert_eq!(base.version, 3);
        assert_eq!(base.key_cols, vec![0]);
        assert_eq!(base.device_ordinal, 0);

        // Changing any single component of the key must yield a distinct key.
        let variants = [
            (
                JoinIndexKey::new(RelId(7), 4, vec![0], &narrow, 0),
                "generation/version must partition stale indexes",
            ),
            (
                JoinIndexKey::new(RelId(7), 3, vec![0], &wide, 0),
                "schema changes must partition indexes",
            ),
            (
                JoinIndexKey::new(RelId(7), 3, vec![0], &narrow, 1),
                "device ordinal must partition indexes",
            ),
        ];
        for (other, reason) in variants {
            assert_ne!(base, other, "{}", reason);
        }
    }

    #[test]
    fn persistent_cache_budget_evicts_lru_and_records_stats() {
        let sig = make_schema(vec![("k", ScalarType::U32)]);
        let mut cache = JoinIndexCache::new(100);

        // Two 60-byte entries cannot coexist under a 100-byte budget, so the second insert
        // must evict the first (least recently used) one.
        for rel in [RelId(1), RelId(2)] {
            cache.insert_test_entry(JoinIndexKey::new(rel, 1, vec![0], &sig, 0), 60);
        }

        let snapshot = cache.stats();
        assert_eq!(snapshot.entries, 1);
        assert_eq!(snapshot.total_bytes, 60);
        assert_eq!(snapshot.evictions, 1);
    }

    #[test]
    fn persistent_cache_invalidation_records_removed_entries() {
        let sig = make_schema(vec![("k", ScalarType::U32)]);
        let mut cache = JoinIndexCache::new(100);
        cache.insert_test_entry(JoinIndexKey::new(RelId(1), 1, vec![0], &sig, 0), 32);

        cache.invalidate_rel(RelId(1));

        let snapshot = cache.stats();
        assert_eq!(snapshot.entries, 0);
        assert_eq!(snapshot.total_bytes, 0);
        assert_eq!(snapshot.invalidations, 1);
    }
}