xlog_cuda/device_runtime/async_resource.rs
1//! [`AsyncCudaResource`] — stream-ordered allocation backed by
2//! cudarc's `CudaStream::alloc` (which forwards to `cuMemAllocAsync`
3//! when the context supports it).
4//!
5//! Each [`DeviceMemoryResource::allocate`] call resolves the
6//! caller-supplied [`StreamId`] to a live `cudarc::driver::CudaStream`
7//! via the [`StreamPool`], allocates against that stream, and stores
8//! the resulting `CudaSlice<u8>` in the resource's live map. Drop on
9//! deallocate invokes `cuMemFreeAsync` (when supported) on the same
10//! stream the allocation was bound to.
11//!
12//! This backend is the production candidate. It is **not** the
13//! sanitizer/cert backend — pool/async behavior can hide byte-level
14//! out-of-bounds patterns from Compute Sanitizer; the cert role
15//! belongs to [`DirectCudaResource`] (subject to manual Compute Sanitizer
16//! confirmation on a supported host).
17//!
18//! # Stream-ordering contract enforced here
19//! * `allocate(.., stream, ..)` is ordered on the resolved
20//! `CudaStream`. The returned `DeviceBlock` carries the same
21//! `alloc_stream`.
22//! * `deallocate(block)` releases the underlying memory ordered on
23//! the block's `alloc_stream`. Callers must have synchronized any
24//! work on a different stream before deallocation.
25//! * Reuse of the underlying byte address by a future `allocate` is
26//! ordered after the previous deallocate by the CUDA driver's
27//! stream-ordered memory allocator semantics. The stream-ordered
28//! allocation lifetime regression test encodes this.
29//!
30//! # `bytes_outstanding` and pending-free accounting
31//!
32//! The trait contract is "live + retired-but-not-yet-freed". A queued
33//! `cuMemFreeAsync` is "retired-but-not-yet-freed" until the host
34//! synchronizes the stream the free was queued on. We therefore keep
35//! two atomic counters:
36//!
37//! * `live_bytes` — bytes for blocks currently in the live map.
38//! * `pending_bytes` — bytes for blocks whose `CudaSlice` has been
39//! dropped (so a `cuMemFreeAsync` is queued on the alloc stream)
40//! but whose stream has not yet been synchronized by us.
41//!
42//! `bytes_outstanding()` returns `live_bytes + pending_bytes`.
43//!
44//! `reap_pending()` drains the per-stream pending map under the
45//! per-stream mutex, synchronizes each drained stream, and then
46//! subtracts only the **synchronized** total from `pending_bytes`
47//! via `fetch_sub` — it does **not** zero the counter. A
48//! `deallocate` that races between reap's drain and its `fetch_sub`
49//! re-populates both the per-stream map and the global atomic
50//! together (under the same mutex), so its bytes either land
51//! entirely before the drain (reaped this round) or entirely after
52//! (kept for the next reap), never split.
53//!
54//! On the first stream-sync failure, the failing entry and every
55//! remaining un-iterated drained entry are **restored** into
56//! `pending_per_stream` so a subsequent reap can retry them. Only
57//! the bytes for streams that successfully synchronized are
58//! decremented from `pending_bytes`. Without this recovery, a
59//! transient driver error mid-reap would lose track of pending
60//! bytes forever — the drained map would be gone, `pending_bytes`
61//! would still count them, but no stream id would be queued for
//! a future reap. Consumers of `bytes_outstanding()` — `GlobalDeviceBudget`
//! in production, and the final assertions of the stream-ordered allocation
//! lifetime tests — therefore see a consistent value even across transient
//! sync failures.
65
66use std::collections::HashMap;
67use std::sync::atomic::{AtomicUsize, Ordering};
68use std::sync::{Arc, Mutex};
69
70use cudarc::driver::{CudaEvent, CudaSlice};
71
72use super::resource::{
73 Access, AllocTag, BlockId, BlockState, DeviceBlock, DeviceMemoryResource, Generation,
74 ResourceError, ResourceResult, StreamId,
75};
76use super::stream_pool::StreamPool;
77use crate::CudaDevice;
78
/// One live allocation tracked by [`AsyncCudaResource`]. Carries
/// the cudarc-owned `CudaSlice<u8>` (whose drop queues the
/// underlying `cuMemFreeAsync`) plus access-aware dependency
/// state and the allocation's [`Generation`].
///
/// # Dependency state
///
/// The block's outstanding dependencies are tracked in two
/// distinct sets so future operations can wait on the minimal
/// correct fence:
///
/// * `last_write` — the most recent write event recorded on
///   the block, paired with the stream that recorded it. Until
///   the first write it holds the allocation-ready event that
///   `allocate` records on the alloc stream. A subsequent read
///   on a different stream must wait on this event; a
///   subsequent write on a different stream must wait on this
///   event AND every entry in `outstanding_reads`.
/// * `outstanding_reads` — every read event recorded since the
///   current `last_write` was installed (or since allocation,
///   if no write has occurred yet), each paired with its
///   recording stream. A subsequent write on a different
///   stream must wait on each entry here. Cleared at finish
///   time when a new write event replaces `last_write`: the
///   writer's prepare-time waits are expected to have covered
///   every prior reader's dependency, so any future operation
///   that waits on the new `last_write` transitively observes
///   those reads' completion.
///
/// On `deallocate`, the alloc stream waits on `last_write` (if
/// it was recorded on a different stream) AND every
/// cross-stream entry in `outstanding_reads` before the queued
/// `cuMemFreeAsync` runs.
///
/// # ABA / generation guard
///
/// The `generation` field guards against address recycling:
/// `deallocate`, `prepare_block_use` and `finish_block_use`
/// all check `block.generation == entry.generation` before
/// touching the entry. Mismatch returns
/// [`ResourceError::UseAfterFree`] and the entry is unchanged.
struct LiveEntry {
    /// The cudarc-owned device allocation. Dropping it is what
    /// queues the free on the stream it was allocated on.
    slice: CudaSlice<u8>,
    /// Generation stamped at allocation time; the same value is
    /// returned to the caller inside the `DeviceBlock` so later
    /// calls can be validated by `(ptr, generation)`.
    generation: Generation,
    /// Most recent write event on this block, OR the
    /// allocation-ready event if no write has happened yet.
    /// Future reads/writes on a different stream wait on this
    /// event. Replaced by `finish_block_use` whenever
    /// `access.writes()` is true. The allocation-ready seed
    /// exists because cuMemAllocAsync orders the allocation
    /// only on `alloc_stream` — a cross-stream consumer that
    /// submits a kernel before allocation completes would read
    /// pool-recycled garbage.
    last_write: Option<(StreamId, CudaEvent)>,
    /// Read events recorded since `last_write` was installed
    /// (or since allocation). Future writes on a different
    /// stream wait on each entry. Cleared by `finish_block_use`
    /// when a write replaces `last_write`.
    outstanding_reads: Vec<(StreamId, CudaEvent)>,
}
136
/// Stream-ordered cudarc-backed allocator.
///
/// Keeps every live allocation in a mutex-protected map and two
/// byte counters (`live_bytes`, `pending_bytes`) whose sum is what
/// `bytes_outstanding` reports.
pub struct AsyncCudaResource {
    /// Device handle supplied at construction; exposed through
    /// `device()`.
    device: Arc<CudaDevice>,
    /// CUDA ordinal supplied at construction. Stamped into every
    /// `DeviceBlock` handed out and compared against incoming
    /// blocks to reject ones that belong to another device.
    device_ordinal: u32,
    /// Resolves caller-supplied [`StreamId`]s to stream handles.
    stream_pool: Arc<StreamPool>,
    /// Live allocations keyed by raw device pointer. Each entry
    /// holds the cudarc slice and any recorded last-use events
    /// from cross-stream consumers. Removed on deallocate; the
    /// slice is then dropped, queueing `cuMemFreeAsync` on its
    /// bound stream — *after* the stream has been told to wait on
    /// every recorded event.
    live: Mutex<HashMap<u64, LiveEntry>>,
    /// Bytes for blocks currently in `live`. Always accurate.
    live_bytes: AtomicUsize,
    /// Bytes for blocks dropped (queued for cuMemFreeAsync) but
    /// whose owning stream has not yet been synchronized by us.
    /// Equal to the sum of values in `pending_per_stream` at any
    /// quiescent moment. `deallocate` adds to the map and to this
    /// counter while holding the `pending_per_stream` mutex;
    /// `reap_pending` only ever subtracts the total it actually
    /// synchronized (`fetch_sub`, never a reset), so it cannot
    /// wipe out bytes that a racing `deallocate` queued after
    /// reap drained the per-stream map.
    pending_bytes: AtomicUsize,
    /// Per-stream pending-free byte totals. Used by `reap_pending`
    /// to (a) compute the total to subtract from `pending_bytes`
    /// after stream synchronization, and (b) preserve any bytes
    /// added by a `deallocate` that races with reap — those bytes
    /// remain in this map and in `pending_bytes`, ready for the
    /// next reap.
    pending_per_stream: Mutex<HashMap<StreamId, usize>>,
}
166
167impl AsyncCudaResource {
168 /// Construct a resource bound to `device` using `stream_pool` for
169 /// stream resolution. `device_ordinal` is the CUDA ordinal for
170 /// logging / multi-device disambiguation.
171 pub fn new(device: Arc<CudaDevice>, device_ordinal: u32, stream_pool: Arc<StreamPool>) -> Self {
172 Self {
173 device,
174 device_ordinal,
175 stream_pool,
176 live: Mutex::new(HashMap::new()),
177 live_bytes: AtomicUsize::new(0),
178 pending_bytes: AtomicUsize::new(0),
179 pending_per_stream: Mutex::new(HashMap::new()),
180 }
181 }
182
183 pub fn device(&self) -> &Arc<CudaDevice> {
184 &self.device
185 }
186
187 pub fn stream_pool(&self) -> &Arc<StreamPool> {
188 &self.stream_pool
189 }
190
191 /// Bytes currently held by live blocks (excludes pending frees).
192 /// Test/diagnostic accessor — production code should use
193 /// `bytes_outstanding`.
194 pub fn live_bytes(&self) -> usize {
195 self.live_bytes.load(Ordering::Relaxed)
196 }
197
198 /// Bytes queued for `cuMemFreeAsync` whose stream has not yet
199 /// been synchronized by us. Test/diagnostic accessor.
200 pub fn pending_free_bytes(&self) -> usize {
201 self.pending_bytes.load(Ordering::Relaxed)
202 }
203
204 /// Sum of per-stream pending byte tallies. Test/diagnostic
205 /// accessor used to assert the invariant
206 /// `pending_free_bytes() == pending_per_stream_total()`. The
207 /// invariant must hold at any quiescent moment; if it fails
208 /// the bookkeeping under the `pending_per_stream` mutex has
209 /// drifted from the global atomic — see `deallocate` and
210 /// `reap_pending`, which update both as a unit.
211 pub fn pending_per_stream_total(&self) -> usize {
212 let map = self
213 .pending_per_stream
214 .lock()
215 .expect("AsyncCudaResource pending_per_stream poisoned");
216 map.values().copied().sum()
217 }
218
219 /// Number of recorded outstanding-read events plus a
220 /// last_write event (0 or 1) currently attached to the live
221 /// block at `ptr`. Test/diagnostic accessor — used by
222 /// reproducers to confirm `finish_block_use` actually
223 /// attached events before deallocate consumed them. Returns
224 /// `None` if `ptr` is not currently in the live map.
225 pub fn pending_use_event_count(&self, ptr: u64) -> Option<usize> {
226 let live = self
227 .live
228 .lock()
229 .expect("AsyncCudaResource live map poisoned");
230 live.get(&ptr)
231 .map(|e| e.outstanding_reads.len() + if e.last_write.is_some() { 1 } else { 0 })
232 }
233}
234
235impl DeviceMemoryResource for AsyncCudaResource {
236 fn allocate(
237 &self,
238 bytes: usize,
239 stream: StreamId,
240 tag: AllocTag,
241 ) -> ResourceResult<DeviceBlock> {
242 if bytes == 0 {
243 return Err(ResourceError::Driver(
244 "AsyncCudaResource: zero-byte allocation not supported".to_string(),
245 ));
246 }
247 let cu_stream = self.stream_pool.resolve(stream).ok_or_else(|| {
248 ResourceError::StreamMisuse(format!(
249 "AsyncCudaResource: unknown StreamId({})",
250 stream.0
251 ))
252 })?;
253
254 // SAFETY: bytes > 0 verified above. cudarc's
255 // `CudaStream::alloc::<u8>(len)` forwards to `cuMemAllocAsync`
256 // when the context has async-alloc enabled (CUDA 11.2+);
257 // otherwise it falls back to synchronous alloc internally.
258 // Failures are surfaced as `ResourceError::Driver`.
259 let slice = unsafe {
260 cu_stream
261 .alloc::<u8>(bytes)
262 .map_err(|e| ResourceError::Driver(format!("cuMemAllocAsync({}): {}", bytes, e)))?
263 };
264
265 // Record an "allocation-ready" event on the alloc stream
266 // immediately after the cuMemAllocAsync call. Cross-
267 // stream consumers MUST wait on this event before
268 // touching the bytes, otherwise the launch (on a
269 // different stream) may begin before the allocation
270 // completes and read pre-init / pool-recycled garbage.
271 // We store it in `last_write` so the access-aware
272 // prepare path's existing read-waits-on-last_write and
273 // write-waits-on-last_write rules cover it for free.
274 // Same-stream consumers skip the wait (already ordered).
275 let alloc_event = cu_stream.record_event(None).map_err(|e| {
276 ResourceError::Driver(format!(
277 "AsyncCudaResource::allocate: record allocation-ready event failed: {}",
278 e
279 ))
280 })?;
281
282 // Extract the raw device pointer for the DeviceBlock surface.
283 // The "sync" handle returned by `device_ptr` is intentionally
284 // leaked — the slice's lifetime is managed by our live map,
285 // not by the sync token.
286 let (raw_ptr, sync) =
287 <CudaSlice<u8> as cudarc::driver::DevicePtr<u8>>::device_ptr(&slice, slice.stream());
288 std::mem::forget(sync);
289 let ptr = raw_ptr;
290
291 {
292 let mut live = self
293 .live
294 .lock()
295 .expect("AsyncCudaResource live map poisoned");
296 // Use `contains_key` then `insert` so a (theoretical)
297 // pointer collision returns `Err` without mutating the
298 // map. The `live.insert(ptr, slice).is_some()` pattern
299 // would replace the existing entry, drop the old slice
300 // (queueing cuMemFreeAsync on memory we still believe
301 // we own), and leave the new slice resident while we
302 // return Err — `live_bytes` would also not be updated.
303 // Avoid that here.
304 if live.contains_key(&ptr) {
305 return Err(ResourceError::Driver(format!(
306 "AsyncCudaResource: pointer collision on alloc ({:#x})",
307 ptr
308 )));
309 }
310 // Generation must match between the LiveEntry and the
311 // returned DeviceBlock so record_block_use and
312 // deallocate can ABA-validate by (ptr, generation).
313 let generation = Generation::next();
314 live.insert(
315 ptr,
316 LiveEntry {
317 slice,
318 generation,
319 last_write: Some((stream, alloc_event)),
320 outstanding_reads: Vec::new(),
321 },
322 );
323 self.live_bytes.fetch_add(bytes, Ordering::Relaxed);
324 Ok(DeviceBlock {
325 ptr,
326 device_ordinal: self.device_ordinal,
327 alloc_stream: stream,
328 bytes,
329 align: std::mem::align_of::<u8>(),
330 tag,
331 generation,
332 state: BlockState::Live,
333 })
334 }
335 }
336
337 fn deallocate(&self, block: DeviceBlock) -> ResourceResult<()> {
338 if block.device_ordinal != self.device_ordinal {
339 return Err(ResourceError::Driver(format!(
340 "AsyncCudaResource: deallocate on wrong device (block ord {} vs resource ord {})",
341 block.device_ordinal, self.device_ordinal
342 )));
343 }
344 // Resolve the alloc stream FIRST. If resolution fails the
345 // live entry stays in place and accounting is unchanged —
346 // the caller can retry. Removing the entry first then
347 // erroring would queue `cuMemFreeAsync` on a stream the
348 // caller did not expect (via the slice drop on the error
349 // return path) AND leave accounting drift behind.
350 let alloc_stream = self
351 .stream_pool
352 .resolve(block.alloc_stream)
353 .ok_or_else(|| {
354 ResourceError::StreamMisuse(format!(
355 "AsyncCudaResource::deallocate: alloc_stream StreamId({}) does not resolve",
356 block.alloc_stream.0
357 ))
358 })?;
359
360 // Take the live-map lock and validate (ptr, generation)
361 // before removing. The generation guard closes the ABA
362 // window: if the address was freed and reused, the older
363 // block's deallocate must NOT tear down the new live
364 // entry. Mismatch -> UseAfterFree, no mutation.
365 //
366 // While the entry is still in the map, queue waits on
367 // alloc_stream for: the block's last_write (if any) and
368 // every outstanding_read. cudarc's `wait` records the
369 // dependency synchronously; if any wait call fails, the
370 // events stay owned by the entry, the entry stays in the
371 // map, and accounting is untouched — caller can retry.
372 //
373 // Same-stream waits are skipped — events recorded on
374 // `block.alloc_stream` are already ordered before
375 // anything else queued there, so requesting a wait would
376 // just be busywork. Cross-stream events are the ones
377 // that fence the queued cuMemFreeAsync against in-flight
378 // consumers.
379 //
380 // Only after every wait succeeds do we remove the entry,
381 // taking ownership of the slice and events, and exit the
382 // lock. From that point removal is committed and the
383 // slice drop below queues cuMemFreeAsync correctly
384 // ordered after every wait we just submitted.
385 let (slice, last_write, outstanding_reads) = {
386 let mut live = self
387 .live
388 .lock()
389 .expect("AsyncCudaResource live map poisoned");
390 match live.get(&block.ptr) {
391 Some(entry) if entry.generation == block.generation => {
392 if let Some((write_stream, event)) = &entry.last_write {
393 if *write_stream != block.alloc_stream {
394 alloc_stream.wait(event).map_err(|e| {
395 ResourceError::Driver(format!(
396 "AsyncCudaResource::deallocate: cuStreamWaitEvent on \
397 last_write failed: {}",
398 e
399 ))
400 })?;
401 }
402 }
403 for (read_stream, event) in &entry.outstanding_reads {
404 if *read_stream != block.alloc_stream {
405 alloc_stream.wait(event).map_err(|e| {
406 ResourceError::Driver(format!(
407 "AsyncCudaResource::deallocate: cuStreamWaitEvent on \
408 outstanding read failed: {}",
409 e
410 ))
411 })?;
412 }
413 }
414 let LiveEntry {
415 slice,
416 last_write,
417 outstanding_reads,
418 ..
419 } = live
420 .remove(&block.ptr)
421 .expect("present under lock per get above");
422 (slice, last_write, outstanding_reads)
423 }
424 Some(_) | None => {
425 return Err(ResourceError::UseAfterFree {
426 generation: block.generation,
427 });
428 }
429 }
430 };
431
432 // Move the bytes from "live" to "pending free": the slice
433 // drop below queues `cuMemFreeAsync` on `block.alloc_stream`,
434 // but the driver may not actually free until that stream
435 // drains. The trait contract requires us to keep counting
436 // these bytes until `reap_pending` confirms completion.
437 //
438 // The pending bookkeeping is updated as a unit under the
439 // `pending_per_stream` mutex: per-stream tally first, then
440 // the global atomic. `reap_pending` reads (drain, sync,
441 // subtract) symmetrically under the same mutex around the
442 // drain so it can only subtract the exact total it drained.
443 // A `deallocate` that races with reap therefore lands either
444 // entirely before reap's drain (its bytes are reaped this
445 // round) or entirely after (its bytes stay pending for the
446 // next reap) — never split.
447 self.live_bytes.fetch_sub(block.bytes, Ordering::Relaxed);
448 {
449 let mut per_stream = self
450 .pending_per_stream
451 .lock()
452 .expect("AsyncCudaResource pending_per_stream poisoned");
453 *per_stream.entry(block.alloc_stream).or_insert(0) += block.bytes;
454 self.pending_bytes.fetch_add(block.bytes, Ordering::Relaxed);
455 }
456
457 // Dropping the CudaSlice<u8> invokes cuMemFreeAsync on its
458 // bound stream when async-alloc is enabled, otherwise falls
459 // back to synchronous cuMemFree. Either way the deallocation
460 // is ordered on the slice's stream, which matches the
461 // DeviceBlock's `alloc_stream` — and now also waits for
462 // every recorded cross-stream use event we just queued
463 // above.
464 drop(slice);
465 // Drop the events explicitly after the slice drop has
466 // queued the free. The event handles can be released as
467 // soon as the wait calls return — cudarc's `wait` records
468 // the dependency in the stream and does not retain the
469 // event.
470 drop(last_write);
471 drop(outstanding_reads);
472 Ok(())
473 }
474
/// Returns the CUDA ordinal this resource was constructed with.
fn device_ordinal(&self) -> u32 {
    self.device_ordinal
}
478
479 fn bytes_outstanding(&self) -> usize {
480 self.live_bytes.load(Ordering::Relaxed) + self.pending_bytes.load(Ordering::Relaxed)
481 }
482
483 fn reap_pending(&self) -> ResourceResult<()> {
484 self.reap_pending_with(|stream_id| match self.stream_pool.resolve(stream_id) {
485 Some(stream) => stream.synchronize().map_err(|e| {
486 ResourceError::Driver(format!(
487 "AsyncCudaResource::reap_pending: stream sync failed: {}",
488 e
489 ))
490 }),
491 // Pool returned no handle for this id. The pool currently
492 // never rotates entries, so this is a defensive branch.
493 // If the id is unresolved there is no stream we can
494 // synchronize on; treat the bytes as definitely freed —
495 // the only consistent accounting is to release them and
496 // let the caller surface any subsequent error against a
497 // known stream.
498 None => Ok(()),
499 })
500 }
501
/// Always `true`: this resource implements `prepare_block_use`,
/// `finish_block_use` and `record_block_use`.
fn supports_block_use_tracking(&self) -> bool {
    true
}
505
506 fn record_block_use(&self, block: &DeviceBlock, use_stream: StreamId) -> ResourceResult<()> {
507 // Backward-compatibility shim. Pre-migration callers used
508 // `record_block_use` for "this stream did SOMETHING with
509 // this block; please wait on me before freeing." That
510 // semantics maps to `finish_block_use(.., Access::Read)`:
511 // the event is recorded on `use_stream` and appended to
512 // outstanding_reads so deallocate waits on it. New
513 // callers MUST call `prepare_block_use` BEFORE the launch
514 // and `finish_block_use` after; this shim does NOT queue
515 // the pre-launch wait so it is unsafe for use-after-write
516 // / use-after-prior-read scenarios.
517 self.finish_block_use(BlockId::from_block(block), use_stream, Access::Read)
518 }
519
520 fn prepare_block_use(
521 &self,
522 block: BlockId,
523 use_stream: StreamId,
524 access: Access,
525 ) -> ResourceResult<()> {
526 if block.device_ordinal != self.device_ordinal {
527 return Err(ResourceError::Driver(format!(
528 "AsyncCudaResource::prepare_block_use: block device {} != resource device {}",
529 block.device_ordinal, self.device_ordinal
530 )));
531 }
532 let use_cu_stream = self.stream_pool.resolve(use_stream).ok_or_else(|| {
533 ResourceError::StreamMisuse(format!(
534 "AsyncCudaResource::prepare_block_use: unknown StreamId({})",
535 use_stream.0
536 ))
537 })?;
538
539 // Validate (ptr, generation) and queue cross-stream
540 // waits while holding the live-map lock. The waits are
541 // cuStreamWaitEvent calls which record a dependency in
542 // the use stream and return — they don't block, so the
543 // lock is held only briefly. Same-stream events are
544 // skipped (already ordered).
545 let live = self
546 .live
547 .lock()
548 .expect("AsyncCudaResource live map poisoned");
549 let entry = match live.get(&block.ptr) {
550 Some(entry) if entry.generation == block.generation => entry,
551 Some(_) | None => {
552 return Err(ResourceError::UseAfterFree {
553 generation: block.generation,
554 });
555 }
556 };
557 if access.reads() || access.writes() {
558 // Reader: wait on prior write.
559 // Writer / RW: wait on prior write AND every prior reader.
560 if let Some((write_stream, event)) = &entry.last_write {
561 if *write_stream != use_stream {
562 use_cu_stream.wait(event).map_err(|e| {
563 ResourceError::Driver(format!(
564 "AsyncCudaResource::prepare_block_use: wait on last_write failed: {}",
565 e
566 ))
567 })?;
568 }
569 }
570 }
571 if access.writes() {
572 for (read_stream, event) in &entry.outstanding_reads {
573 if *read_stream != use_stream {
574 use_cu_stream.wait(event).map_err(|e| {
575 ResourceError::Driver(format!(
576 "AsyncCudaResource::prepare_block_use: wait on outstanding read \
577 failed: {}",
578 e
579 ))
580 })?;
581 }
582 }
583 }
584 Ok(())
585 }
586
587 fn finish_block_use(
588 &self,
589 block: BlockId,
590 use_stream: StreamId,
591 access: Access,
592 ) -> ResourceResult<()> {
593 if block.device_ordinal != self.device_ordinal {
594 return Err(ResourceError::Driver(format!(
595 "AsyncCudaResource::finish_block_use: block device {} != resource device {}",
596 block.device_ordinal, self.device_ordinal
597 )));
598 }
599 let use_cu_stream = self.stream_pool.resolve(use_stream).ok_or_else(|| {
600 ResourceError::StreamMisuse(format!(
601 "AsyncCudaResource::finish_block_use: unknown StreamId({})",
602 use_stream.0
603 ))
604 })?;
605 // Validate (ptr, generation) BEFORE recording the event
606 // on `use_stream`. This avoids creating an event that we
607 // would have to immediately destroy on the ABA failure
608 // path.
609 {
610 let live = self
611 .live
612 .lock()
613 .expect("AsyncCudaResource live map poisoned");
614 match live.get(&block.ptr) {
615 Some(entry) if entry.generation == block.generation => {}
616 Some(_) | None => {
617 return Err(ResourceError::UseAfterFree {
618 generation: block.generation,
619 });
620 }
621 }
622 }
623 // Record the event on the use stream OUTSIDE the live-map
624 // lock — event creation/record can block on the CUDA
625 // driver and we don't want to hold the live-map lock
626 // across that. Re-validate generation after acquiring the
627 // lock so a racing dealloc that already removed the entry
628 // doesn't see a phantom event attached to a stale block.
629 let event = use_cu_stream.record_event(None).map_err(|e| {
630 ResourceError::Driver(format!(
631 "AsyncCudaResource::finish_block_use: event record failed: {}",
632 e
633 ))
634 })?;
635 let mut live = self
636 .live
637 .lock()
638 .expect("AsyncCudaResource live map poisoned");
639 match live.get_mut(&block.ptr) {
640 Some(entry) if entry.generation == block.generation => {
641 if access.writes() {
642 // Writer: the prepare phase queued waits on
643 // every prior reader and on last_write, so
644 // any future op that observes the new
645 // last_write transitively observes those
646 // dependencies. Drop the prior state.
647 entry.last_write = Some((use_stream, event));
648 entry.outstanding_reads.clear();
649 } else {
650 debug_assert!(access.reads());
651 entry.outstanding_reads.push((use_stream, event));
652 }
653 Ok(())
654 }
655 Some(_) | None => {
656 // Event drops here, releasing the CUDA event.
657 // cudarc's wait was never queued so no stream
658 // dependency leaks.
659 drop(event);
660 Err(ResourceError::UseAfterFree {
661 generation: block.generation,
662 })
663 }
664 }
665 }
666}
667
668impl AsyncCudaResource {
669 /// Drain pending per-stream entries and synchronize each
670 /// drained stream via `sync_stream`, releasing only the bytes
671 /// for streams that the closure successfully synchronized.
672 ///
673 /// On the first synchronization failure, the failing entry and
674 /// **every remaining un-iterated drained entry** are restored
675 /// into `pending_per_stream` so a subsequent reap can retry
676 /// them, and `pending_bytes` is decremented only by the
677 /// already-synchronized total. The closure's error is then
678 /// returned to the caller. Without this recovery, a transient
679 /// driver error mid-reap would lose track of pending bytes
680 /// forever (drained map is gone, `pending_bytes` still counts
681 /// them, but no stream is queued for a future reap).
682 ///
683 /// Production callers go through [`reap_pending`]
684 /// (the trait method), which passes a closure that resolves
685 /// the [`StreamId`] against [`StreamPool`] and calls
686 /// `CudaStream::synchronize`. This helper exists so unit tests
687 /// can inject controlled sync failures without touching the
688 /// CUDA driver.
689 pub(crate) fn reap_pending_with<F>(&self, mut sync_stream: F) -> ResourceResult<()>
690 where
691 F: FnMut(StreamId) -> ResourceResult<()>,
692 {
693 // Drain the per-stream map atomically. Anything added by a
694 // racing `deallocate` after this point lands in a fresh
695 // entry and waits for the next reap.
696 //
697 // Critically, we do NOT touch `pending_bytes` here — only
698 // after a stream has synchronized do we subtract its bytes.
699 // A `deallocate` that races between our drain and our
700 // subtract has already added to `pending_bytes` under the
701 // same mutex (see `deallocate`), and that addition is
702 // preserved because we `fetch_sub` the synchronized total
703 // rather than `store(0)`.
704 let drained: HashMap<StreamId, usize> = {
705 let mut per_stream = self
706 .pending_per_stream
707 .lock()
708 .expect("AsyncCudaResource pending_per_stream poisoned");
709 std::mem::take(&mut *per_stream)
710 };
711 if drained.is_empty() {
712 return Ok(());
713 }
714
715 let mut synced_total: usize = 0;
716 let mut failure: Option<ResourceError> = None;
717 let mut unsynced: Vec<(StreamId, usize)> = Vec::new();
718 let mut iter = drained.into_iter();
719 while let Some((stream_id, bytes)) = iter.next() {
720 match sync_stream(stream_id) {
721 Ok(()) => {
722 synced_total = synced_total.saturating_add(bytes);
723 }
724 Err(e) => {
725 // Restore the failing entry and every remaining
726 // drained entry so they can be retried by a
727 // future reap.
728 unsynced.push((stream_id, bytes));
729 unsynced.extend(iter.by_ref());
730 failure = Some(e);
731 break;
732 }
733 }
734 }
735
736 if !unsynced.is_empty() {
737 let mut per_stream = self
738 .pending_per_stream
739 .lock()
740 .expect("AsyncCudaResource pending_per_stream poisoned");
741 for (stream_id, bytes) in unsynced {
742 *per_stream.entry(stream_id).or_insert(0) += bytes;
743 }
744 }
745
746 if synced_total > 0 {
747 self.pending_bytes
748 .fetch_sub(synced_total, Ordering::Relaxed);
749 }
750
751 match failure {
752 Some(e) => Err(e),
753 None => Ok(()),
754 }
755 }
756}
757
758#[cfg(test)]
759mod tests {
760 use super::*;
761
762 fn try_setup() -> Option<(Arc<CudaDevice>, Arc<StreamPool>)> {
763 let device = Arc::new(CudaDevice::new(0).ok()?);
764 let pool = Arc::new(StreamPool::with_defaults(Arc::clone(&device)));
765 Some((device, pool))
766 }
767
/// Full allocate -> deallocate -> reap cycle on the default stream,
/// checking the live/pending split at every step.
#[test]
fn allocate_then_deallocate_round_trips_on_default_stream() {
    let (device, pool) = match try_setup() {
        Some(handles) => handles,
        None => return,
    };
    let resource = AsyncCudaResource::new(device, 0, pool);

    let block = resource
        .allocate(2048, StreamId::DEFAULT, AllocTag::UNTAGGED)
        .expect("alloc");
    assert_eq!(block.bytes, 2048);
    assert_eq!(block.alloc_stream, StreamId::DEFAULT);
    assert_eq!(resource.bytes_outstanding(), 2048);
    assert_eq!(resource.live_bytes(), 2048);
    assert_eq!(resource.pending_free_bytes(), 0);

    // After deallocate the free is queued but not yet reaped, so the
    // bytes have moved from live to pending.
    resource.deallocate(block).expect("dealloc");
    assert_eq!(resource.live_bytes(), 0);
    assert_eq!(resource.pending_free_bytes(), 2048);
    assert_eq!(resource.bytes_outstanding(), 2048);

    // Reaping synchronizes the stream and releases the pending bytes.
    resource.reap_pending().expect("reap pending");
    assert_eq!(resource.bytes_outstanding(), 0);
    assert_eq!(resource.pending_free_bytes(), 0);
}
793
/// Same round trip, but on a stream acquired from the pool rather
/// than the default stream.
#[test]
fn allocate_on_acquired_non_default_stream() {
    let (device, pool) = match try_setup() {
        Some(handles) => handles,
        None => return,
    };
    let resource = AsyncCudaResource::new(device, 0, Arc::clone(&pool));
    let side_stream = pool.acquire().expect("acquire non-default stream");

    let block = resource
        .allocate(1024, side_stream, AllocTag("async-test"))
        .expect("alloc on non-default stream");
    assert_eq!(block.alloc_stream, side_stream);

    resource.deallocate(block).expect("dealloc");
    // The bytes stay outstanding until the reap below.
    assert_eq!(resource.bytes_outstanding(), 1024);

    resource.reap_pending().expect("reap pending");
    assert_eq!(resource.bytes_outstanding(), 0);
}
811
/// A `StreamId` the pool does not know must be rejected as stream
/// misuse.
#[test]
fn allocate_unknown_stream_id_rejected() {
    let (device, pool) = match try_setup() {
        Some(handles) => handles,
        None => return,
    };
    let resource = AsyncCudaResource::new(device, 0, pool);
    let outcome = resource.allocate(64, StreamId(99), AllocTag::UNTAGGED);
    assert!(matches!(outcome, Err(ResourceError::StreamMisuse(_))));
}
821
/// Deallocating a block that was never handed out by this resource
/// must report use-after-free.
#[test]
fn deallocate_unknown_block_returns_use_after_free() {
    let (device, pool) = match try_setup() {
        Some(handles) => handles,
        None => return,
    };
    let resource = AsyncCudaResource::new(device, 0, pool);

    // Hand-built block whose pointer is not in the live map.
    let forged = DeviceBlock {
        ptr: 0xfeed_face,
        device_ordinal: 0,
        alloc_stream: StreamId::DEFAULT,
        bytes: 16,
        align: 1,
        tag: AllocTag::UNTAGGED,
        generation: Generation::next(),
        state: BlockState::Live,
    };

    let outcome = resource.deallocate(forged);
    assert!(matches!(outcome, Err(ResourceError::UseAfterFree { .. })));
}
843
/// Reaping a resource with nothing pending succeeds and leaves the
/// outstanding count at zero.
#[test]
fn reap_with_no_pending_is_noop() {
    let (device, pool) = match try_setup() {
        Some(handles) => handles,
        None => return,
    };
    let resource = AsyncCudaResource::new(device, 0, pool);
    resource.reap_pending().expect("reap on empty");
    assert_eq!(resource.bytes_outstanding(), 0);
}
853
854 /// Test-only helper: install pending state directly so we can
855 /// exercise `reap_pending_with` without going through real
856 /// CUDA streams. Bypasses the normal `allocate`/`deallocate`
857 /// path; intended exclusively for the failure-recovery test.
858 fn install_pending(r: &AsyncCudaResource, entries: &[(StreamId, usize)]) {
859 let mut per_stream = r
860 .pending_per_stream
861 .lock()
862 .expect("AsyncCudaResource pending_per_stream poisoned");
863 let mut total: usize = 0;
864 for (id, bytes) in entries {
865 *per_stream.entry(*id).or_insert(0) += *bytes;
866 total = total.saturating_add(*bytes);
867 }
868 drop(per_stream);
869 r.pending_bytes.fetch_add(total, Ordering::Relaxed);
870 }
871
/// A failed stream sync during a reap must leave the un-synced bytes
/// accounted for in both counters, and a later successful reap must
/// drain them.
#[test]
fn reap_pending_recovers_unsynced_streams_when_sync_fails() {
    // Sync outcomes are injected through `reap_pending_with`, so the
    // recovery logic is exercised without issuing real stream work;
    // building the resource still goes through `try_setup`.
    let (device, pool) = match try_setup() {
        Some(parts) => parts,
        None => return,
    };
    let r = AsyncCudaResource::new(Arc::clone(&device), 0, Arc::clone(&pool));

    // Seed 1024 bytes on StreamId(1) and 2048 bytes on StreamId(2);
    // the injected closure below refuses to sync StreamId(2).
    install_pending(&r, &[(StreamId(1), 1024), (StreamId(2), 2048)]);
    assert_eq!(r.pending_free_bytes(), 3072);
    assert_eq!(r.pending_per_stream_total(), 3072);

    // Record every stream the closure reports as synchronized. The
    // per-stream visiting order is not assumed, so the expectations
    // below are derived from this record rather than hard-coded.
    let recorded = std::sync::Mutex::new(Vec::<StreamId>::new());
    let result = r.reap_pending_with(|stream_id| {
        if stream_id != StreamId(2) {
            recorded.lock().unwrap().push(stream_id);
            return Ok(());
        }
        Err(ResourceError::Driver(
            "simulated sync failure on StreamId(2)".into(),
        ))
    });

    assert!(matches!(result, Err(ResourceError::Driver(_))));

    let synced = recorded.into_inner().unwrap();
    // Two possible outcomes, depending on visiting order:
    //   * StreamId(1) first: it syncs, then StreamId(2) fails, so
    //     2048 bytes remain pending.
    //   * StreamId(2) first: the failure stops the reap before
    //     StreamId(1) is visited, so all 3072 bytes remain pending.
    // Either way: pending == 3072 - bytes of the synced streams.
    let stream_one_synced = synced.iter().any(|id| *id == StreamId(1));
    let synced_bytes: usize = if stream_one_synced { 1024 } else { 0 };
    let expected_pending = 3072 - synced_bytes;

    assert_eq!(
        r.pending_free_bytes(),
        expected_pending,
        "synced={:?}; pending_bytes must reflect only un-synced bytes",
        synced
    );
    assert_eq!(
        r.pending_per_stream_total(),
        expected_pending,
        "synced={:?}; pending_per_stream_total must equal pending_free_bytes \
         (cross-counter invariant)",
        synced
    );

    // Retry with a closure that never fails: whatever was left
    // pending must now drain completely, showing the leftover
    // entries were kept for retry rather than dropped.
    r.reap_pending_with(|_| Ok(())).expect("retry reap");
    assert_eq!(r.pending_free_bytes(), 0);
    assert_eq!(r.pending_per_stream_total(), 0);
}
939
/// Happy-path check for the closure-based reap: when every sync
/// succeeds, both pending counters must end at zero.
#[test]
fn reap_pending_drains_normally_when_sync_always_succeeds() {
    // Bail out early when `try_setup` yields nothing.
    let (device, pool) = match try_setup() {
        Some(parts) => parts,
        None => return,
    };
    let resource = AsyncCudaResource::new(Arc::clone(&device), 0, Arc::clone(&pool));

    // Seed pending bytes on two streams, then reap with an
    // always-succeeding sync closure.
    install_pending(&resource, &[(StreamId(1), 256), (StreamId(2), 512)]);
    resource.reap_pending_with(|_| Ok(())).expect("reap");

    assert_eq!(resource.pending_free_bytes(), 0);
    assert_eq!(resource.pending_per_stream_total(), 0);
}
954}