Skip to main content

xlog_logic/optimizer/
stream_schedule_pass.rs

1//! Ahead-of-time stream schedule construction for independent WCOJ rules.
2
3use xlog_ir::Stratum;
4
/// Hardware inputs used by the stream schedule pass.
///
/// Besides the hardware SM count, this also carries the number of independent
/// rules in the stratum being scheduled: `schedule_streams` reads the rule count
/// from here rather than from the `Stratum` it is handed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HardwareCapabilities {
    /// Number of streaming multiprocessors visible to the runtime.
    pub sm_count: usize,
    /// Count of independent rules in the stratum.
    pub independent_rule_count: usize,
}

impl HardwareCapabilities {
    /// Build schedule inputs from an SM count and independent-rule count.
    ///
    /// This is a `const fn` so fixed hardware profiles can be declared as
    /// constants (e.g. `const PROFILE: HardwareCapabilities = HardwareCapabilities::new(16, 4);`).
    #[must_use]
    pub const fn new(sm_count: usize, independent_rule_count: usize) -> Self {
        Self {
            sm_count,
            independent_rule_count,
        }
    }
}
23
/// Kind of work a phase node performs.
///
/// Every scheduled rule passes through these stages in declaration order:
/// Count, then Scan, then Resize, then Materialize.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StreamPhase {
    /// Runs the WCOJ count kernel.
    Count,
    /// Runs a deterministic prefix scan over the per-block counts.
    Scan,
    /// Allocates output storage once the scan has fixed the cardinality.
    Resize,
    /// Runs the WCOJ materialize kernel.
    Materialize,
}
36
37/// One phase scheduled for one independent rule.
38#[derive(Debug, Clone, Copy, PartialEq, Eq)]
39pub struct StreamPhaseNode {
40    /// Stratum id that owns this phase node.
41    pub stratum_id: u32,
42    /// Rule index within the stratum's independent-rule list.
43    pub rule_index: usize,
44    /// Phase executed for this rule.
45    pub phase: StreamPhase,
46    /// CUDA stream slot selected by greedy bin assignment.
47    pub stream_index: usize,
48}
49
50/// Phase-aligned stream schedule for a stratum.
51#[derive(Debug, Clone, PartialEq, Eq)]
52pub struct StreamSchedule {
53    /// Stratum id this schedule covers.
54    pub stratum_id: u32,
55    /// Number of CUDA stream slots used by the schedule.
56    pub stream_count: usize,
57    /// Ordered phase nodes. Phases are grouped by phase kind, then rule index.
58    pub phases: Vec<StreamPhaseNode>,
59}
60
61/// Build the phase-aligned stream schedule for a stratum.
62///
63/// Stream count follows the GPU occupancy rule:
64/// `min(SM_count / 4, max_independent_rules_in_stratum)`, with one stream
65/// retained for single-rule strata so the scheduler produces the same serial
66/// execution shape as the non-mux path.
67pub fn schedule_streams(stratum: &Stratum, hw: &HardwareCapabilities) -> StreamSchedule {
68    let rule_count = hw.independent_rule_count;
69    let stream_count = if rule_count == 0 {
70        0
71    } else {
72        let sm_lanes = (hw.sm_count / 4).max(1);
73        sm_lanes.min(rule_count)
74    };
75    let mut phases = Vec::with_capacity(rule_count.saturating_mul(4));
76    for phase in [
77        StreamPhase::Count,
78        StreamPhase::Scan,
79        StreamPhase::Resize,
80        StreamPhase::Materialize,
81    ] {
82        for rule_index in 0..rule_count {
83            phases.push(StreamPhaseNode {
84                stratum_id: stratum.id,
85                rule_index,
86                phase,
87                stream_index: if stream_count == 0 {
88                    0
89                } else {
90                    rule_index % stream_count
91                },
92            });
93        }
94    }
95    StreamSchedule {
96        stratum_id: stratum.id,
97        stream_count,
98        phases,
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture stratum; only its `id` (7) is observed by `schedule_streams`.
    fn stratum() -> Stratum {
        Stratum {
            id: 7,
            sccs: vec![0],
        }
    }

    #[test]
    fn schedules_four_rules_on_four_streams_by_phase() {
        let schedule = schedule_streams(&stratum(), &HardwareCapabilities::new(16, 4));
        assert_eq!(schedule.stream_count, 4);
        assert_eq!(schedule.phases.len(), 16);
        assert!(schedule.phases[0..4]
            .iter()
            .all(|node| node.phase == StreamPhase::Count));
        assert!(schedule.phases[4..8]
            .iter()
            .all(|node| node.phase == StreamPhase::Scan));
        assert!(schedule.phases[8..12]
            .iter()
            .all(|node| node.phase == StreamPhase::Resize));
        assert!(schedule.phases[12..16]
            .iter()
            .all(|node| node.phase == StreamPhase::Materialize));
        let stream_slots: Vec<_> = schedule.phases[0..4]
            .iter()
            .map(|node| node.stream_index)
            .collect();
        assert_eq!(stream_slots, vec![0, 1, 2, 3]);
    }

    #[test]
    fn single_rule_uses_one_stream() {
        let schedule = schedule_streams(&stratum(), &HardwareCapabilities::new(16, 1));
        assert_eq!(schedule.stream_count, 1);
        assert_eq!(schedule.phases.len(), 4);
        assert!(schedule.phases.iter().all(|node| node.stream_index == 0));
    }

    #[test]
    fn zero_rules_produce_empty_schedule() {
        let schedule = schedule_streams(&stratum(), &HardwareCapabilities::new(16, 0));
        assert_eq!(schedule.stratum_id, 7);
        assert_eq!(schedule.stream_count, 0);
        assert!(schedule.phases.is_empty());
    }

    #[test]
    fn fewer_than_four_sms_still_get_one_stream() {
        // 2 / 4 == 0 SM lanes; the pass clamps to a single serial stream.
        let schedule = schedule_streams(&stratum(), &HardwareCapabilities::new(2, 3));
        assert_eq!(schedule.stream_count, 1);
        assert_eq!(schedule.phases.len(), 12);
        assert!(schedule.phases.iter().all(|node| node.stream_index == 0));
    }

    #[test]
    fn extra_rules_wrap_round_robin_and_keep_their_stream() {
        // 8 SMs -> 2 stream slots shared by 5 rules.
        let schedule = schedule_streams(&stratum(), &HardwareCapabilities::new(8, 5));
        assert_eq!(schedule.stream_count, 2);
        assert_eq!(schedule.phases.len(), 20);
        let count_slots: Vec<_> = schedule.phases[0..5]
            .iter()
            .map(|node| node.stream_index)
            .collect();
        assert_eq!(count_slots, vec![0, 1, 0, 1, 0]);
        let scan_rules: Vec<_> = schedule.phases[5..10]
            .iter()
            .map(|node| node.rule_index)
            .collect();
        assert_eq!(scan_rules, vec![0, 1, 2, 3, 4]);
        assert!(schedule.phases[5..10]
            .iter()
            .all(|node| node.phase == StreamPhase::Scan));
        for node in &schedule.phases {
            assert_eq!(node.stratum_id, 7);
            assert_eq!(node.stream_index, node.rule_index % 2);
        }
    }
}