xlog_cuda/device_runtime/
stream_pool.rs1use std::sync::Arc;
22use std::sync::Mutex;
23
24use cudarc::driver::CudaStream;
25
26use super::resource::StreamId;
27use crate::CudaDevice;
28
/// Cap on pooled (non-default) streams applied by `StreamPool::with_defaults`.
pub const DEFAULT_MAX_STREAMS: usize = 16;

/// Name of the environment variable that overrides the per-stream pool size.
/// The value is interpreted as a whole number of MiB.
pub const ENV_WCOJ_POOL_MB_PER_STREAM: &str = "XLOG_WCOJ_POOL_MB_PER_STREAM";

/// Per-stream pool size, in MiB, used when the override variable is unset,
/// unparsable, or zero.
pub const DEFAULT_POOL_MB_PER_STREAM: u64 = 256;
36
37pub fn configured_pool_mb_per_stream() -> u64 {
38 std::env::var(ENV_WCOJ_POOL_MB_PER_STREAM)
39 .ok()
40 .and_then(|raw| raw.trim().parse::<u64>().ok())
41 .filter(|mb| *mb > 0)
42 .unwrap_or(DEFAULT_POOL_MB_PER_STREAM)
43}
44
45pub fn configured_pool_bytes_per_stream() -> u64 {
46 configured_pool_mb_per_stream().saturating_mul(1024 * 1024)
47}
48
49pub fn planned_pool_budget_bytes(arms: u64, streams: u64) -> u64 {
50 arms.saturating_mul(streams)
51 .saturating_mul(configured_pool_bytes_per_stream())
52}
53
/// Reasons a stream could not be handed out by the pool.
#[derive(Debug)]
pub enum StreamPoolError {
    /// The pool already holds `max` streams and will not create another.
    Capacity { max: usize },
    /// Forking a new stream failed; the payload is a human-readable
    /// description of the underlying failure.
    ForkFailed(String),
}

impl std::fmt::Display for StreamPoolError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::Capacity { max } => write!(f, "stream pool at capacity (max={})", max),
            Self::ForkFailed(ref reason) => write!(f, "stream fork failed: {}", reason),
        }
    }
}

// No `source()` override: the fork failure is carried only as a string.
impl std::error::Error for StreamPoolError {}
81
82pub struct StreamPool {
84 device: Arc<CudaDevice>,
85 max_streams: usize,
86 pool_bytes_per_stream: u64,
87 streams: Mutex<Vec<Arc<CudaStream>>>,
92}
93
94impl StreamPool {
95 pub fn new(device: Arc<CudaDevice>, max_streams: usize) -> Self {
97 Self {
98 device,
99 max_streams: max_streams.max(1),
100 pool_bytes_per_stream: configured_pool_bytes_per_stream(),
101 streams: Mutex::new(Vec::new()),
102 }
103 }
104
105 pub fn with_defaults(device: Arc<CudaDevice>) -> Self {
107 Self::new(device, DEFAULT_MAX_STREAMS)
108 }
109
110 pub fn acquire(&self) -> Result<StreamId, StreamPoolError> {
126 let mut streams = self.streams.lock().expect("stream pool poisoned");
127 if streams.len() >= self.max_streams {
128 return Err(StreamPoolError::Capacity {
129 max: self.max_streams,
130 });
131 }
132 match self.device.inner().stream().fork() {
133 Ok(handle) => {
134 streams.push(handle);
135 Ok(StreamId(streams.len() as u32))
139 }
140 Err(e) => Err(StreamPoolError::ForkFailed(e.to_string())),
141 }
142 }
143
144 pub fn resolve(&self, id: StreamId) -> Option<Arc<CudaStream>> {
148 if id == StreamId::DEFAULT {
149 return Some(Arc::clone(self.device.inner().stream()));
150 }
151 let streams = self.streams.lock().expect("stream pool poisoned");
152 let idx = id.0 as usize;
153 if idx == 0 || idx > streams.len() {
154 return None;
155 }
156 Some(Arc::clone(&streams[idx - 1]))
157 }
158
159 pub fn non_default_len(&self) -> usize {
161 self.streams.lock().expect("stream pool poisoned").len()
162 }
163
164 pub fn device(&self) -> &Arc<CudaDevice> {
167 &self.device
168 }
169
170 pub fn max_streams(&self) -> usize {
172 self.max_streams
173 }
174
175 pub fn pool_bytes_per_stream(&self) -> u64 {
178 self.pool_bytes_per_stream
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes tests that mutate the process-wide environment.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Scoped override of `ENV_WCOJ_POOL_MB_PER_STREAM`.
    ///
    /// Holds `ENV_LOCK` for its whole lifetime and restores the previous
    /// value of the variable on drop. Because restoration happens in `Drop`,
    /// it also runs when an assertion fails and the test unwinds, so one
    /// failing test cannot leak a modified environment into the next.
    struct EnvOverride {
        saved: Option<String>,
        // Dropped after `Drop::drop` runs, so the variable is restored while
        // the lock is still held.
        _guard: std::sync::MutexGuard<'static, ()>,
    }

    impl EnvOverride {
        /// Sets the variable to `value`, or removes it when `value` is `None`.
        fn apply(value: Option<&str>) -> Self {
            // A poisoned lock only means an earlier env test panicked. The
            // guarded data is `()`, so it is always safe to keep going; this
            // avoids a misleading secondary failure.
            let guard = ENV_LOCK
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            let saved = std::env::var(ENV_WCOJ_POOL_MB_PER_STREAM).ok();
            match value {
                Some(v) => std::env::set_var(ENV_WCOJ_POOL_MB_PER_STREAM, v),
                None => std::env::remove_var(ENV_WCOJ_POOL_MB_PER_STREAM),
            }
            Self {
                saved,
                _guard: guard,
            }
        }
    }

    impl Drop for EnvOverride {
        fn drop(&mut self) {
            match self.saved.take() {
                Some(v) => std::env::set_var(ENV_WCOJ_POOL_MB_PER_STREAM, v),
                None => std::env::remove_var(ENV_WCOJ_POOL_MB_PER_STREAM),
            }
        }
    }

    /// Opens device 0, or returns `None` when that fails, so that the
    /// device-backed tests below return early instead of failing.
    fn try_device() -> Option<Arc<CudaDevice>> {
        CudaDevice::new(0).ok().map(Arc::new)
    }

    #[test]
    fn acquire_returns_distinct_non_default_ids() {
        let Some(device) = try_device() else {
            return;
        };
        let pool = StreamPool::new(device, 4);
        let a = pool.acquire().expect("first acquire");
        let b = pool.acquire().expect("second acquire");
        assert_ne!(a, StreamId::DEFAULT);
        assert_ne!(b, StreamId::DEFAULT);
        assert_ne!(a, b, "consecutive acquire calls must yield distinct ids");
        assert_eq!(pool.non_default_len(), 2);
    }

    #[test]
    fn acquire_returns_capacity_error_at_max() {
        let Some(device) = try_device() else {
            return;
        };
        let pool = StreamPool::new(device, 1);
        let _first = pool.acquire().expect("first acquire under cap");
        let err = pool.acquire();
        assert!(
            matches!(err, Err(StreamPoolError::Capacity { max: 1 })),
            "expected Capacity error once max_streams hit, got {:?}",
            err
        );
    }

    #[test]
    fn resolve_default_returns_device_default_stream() {
        let Some(device) = try_device() else {
            return;
        };
        let pool = StreamPool::with_defaults(device);
        assert!(pool.resolve(StreamId::DEFAULT).is_some());
    }

    #[test]
    fn resolve_acquired_returns_owned_stream() {
        let Some(device) = try_device() else {
            return;
        };
        let pool = StreamPool::new(device, 4);
        let id = pool.acquire().expect("acquire");
        assert_ne!(id, StreamId::DEFAULT);
        assert!(pool.resolve(id).is_some());
    }

    #[test]
    fn resolve_unknown_returns_none() {
        let Some(device) = try_device() else {
            return;
        };
        let pool = StreamPool::with_defaults(device);
        assert!(pool.resolve(StreamId(99)).is_none());
    }

    #[test]
    fn pool_mb_per_stream_env_overrides_default() {
        let _env = EnvOverride::apply(Some("128"));
        assert_eq!(configured_pool_mb_per_stream(), 128);
    }

    #[test]
    fn planned_pool_budget_uses_default_4_by_4_contract() {
        let _env = EnvOverride::apply(None);
        assert_eq!(configured_pool_mb_per_stream(), DEFAULT_POOL_MB_PER_STREAM);
        assert_eq!(
            planned_pool_budget_bytes(4, 4),
            4_u64 * 4 * 256 * 1024 * 1024
        );
    }
}