Skip to main content

xlog_cuda/device_runtime/
async_resource.rs

1//! [`AsyncCudaResource`] — stream-ordered allocation backed by
2//! cudarc's `CudaStream::alloc` (which forwards to `cuMemAllocAsync`
3//! when the context supports it).
4//!
5//! Each [`DeviceMemoryResource::allocate`] call resolves the
6//! caller-supplied [`StreamId`] to a live `cudarc::driver::CudaStream`
7//! via the [`StreamPool`], allocates against that stream, and stores
8//! the resulting `CudaSlice<u8>` in the resource's live map. Drop on
9//! deallocate invokes `cuMemFreeAsync` (when supported) on the same
10//! stream the allocation was bound to.
11//!
12//! This backend is the production candidate. It is **not** the
13//! sanitizer/cert backend — pool/async behavior can hide byte-level
14//! out-of-bounds patterns from Compute Sanitizer; the cert role
15//! belongs to [`DirectCudaResource`] (subject to manual Compute Sanitizer
16//! confirmation on a supported host).
17//!
18//! # Stream-ordering contract enforced here
19//!   * `allocate(.., stream, ..)` is ordered on the resolved
20//!     `CudaStream`. The returned `DeviceBlock` carries the same
21//!     `alloc_stream`.
22//!   * `deallocate(block)` releases the underlying memory ordered on
23//!     the block's `alloc_stream`. Callers must have synchronized any
24//!     work on a different stream before deallocation.
25//!   * Reuse of the underlying byte address by a future `allocate` is
26//!     ordered after the previous deallocate by the CUDA driver's
27//!     stream-ordered memory allocator semantics. The stream-ordered
28//!     allocation lifetime regression test encodes this.
29//!
30//! # `bytes_outstanding` and pending-free accounting
31//!
32//! The trait contract is "live + retired-but-not-yet-freed". A queued
33//! `cuMemFreeAsync` is "retired-but-not-yet-freed" until the host
34//! synchronizes the stream the free was queued on. We therefore keep
35//! two atomic counters:
36//!
37//!   * `live_bytes` — bytes for blocks currently in the live map.
38//!   * `pending_bytes` — bytes for blocks whose `CudaSlice` has been
39//!     dropped (so a `cuMemFreeAsync` is queued on the alloc stream)
40//!     but whose stream has not yet been synchronized by us.
41//!
42//! `bytes_outstanding()` returns `live_bytes + pending_bytes`.
43//!
44//! `reap_pending()` drains the per-stream pending map under the
45//! per-stream mutex, synchronizes each drained stream, and then
46//! subtracts only the **synchronized** total from `pending_bytes`
47//! via `fetch_sub` — it does **not** zero the counter. A
48//! `deallocate` that races between reap's drain and its `fetch_sub`
49//! re-populates both the per-stream map and the global atomic
50//! together (under the same mutex), so its bytes either land
51//! entirely before the drain (reaped this round) or entirely after
52//! (kept for the next reap), never split.
53//!
54//! On the first stream-sync failure, the failing entry and every
55//! remaining un-iterated drained entry are **restored** into
56//! `pending_per_stream` so a subsequent reap can retry them. Only
57//! the bytes for streams that successfully synchronized are
58//! decremented from `pending_bytes`. Without this recovery, a
59//! transient driver error mid-reap would lose track of pending
60//! bytes forever — the drained map would be gone, `pending_bytes`
61//! would still count them, but no stream id would be queued for
//! a future reap. Callers that read `bytes_outstanding()` (the
//! production `GlobalDeviceBudget`, and the final assertions of the
//! stream-ordered allocation lifetime tests) therefore see a
//! consistent value even across transient sync failures.
65
66use std::collections::HashMap;
67use std::sync::atomic::{AtomicUsize, Ordering};
68use std::sync::{Arc, Mutex};
69
70use cudarc::driver::{CudaEvent, CudaSlice};
71
72use super::resource::{
73    Access, AllocTag, BlockId, BlockState, DeviceBlock, DeviceMemoryResource, Generation,
74    ResourceError, ResourceResult, StreamId,
75};
76use super::stream_pool::StreamPool;
77use crate::CudaDevice;
78
/// One live allocation tracked by [`AsyncCudaResource`]. Carries
/// the cudarc-owned `CudaSlice<u8>` (whose drop queues the
/// underlying `cuMemFreeAsync`) plus access-aware dependency
/// state and the allocation's [`Generation`].
///
/// # Dependency state
///
/// The block's outstanding dependencies are tracked in two
/// distinct sets so future operations can wait on the minimal
/// correct fence:
///
///   * `last_write` — the most recent write event recorded on
///     the block, paired with the stream that recorded it. Until
///     the first write finishes, this slot holds the
///     allocation-ready event that `allocate` records on the
///     alloc stream. A subsequent read on a different stream must
///     wait on this event; a subsequent write on a different
///     stream must wait on this event AND every entry in
///     `outstanding_reads`.
///   * `outstanding_reads` — every read event recorded since the
///     current `last_write` was installed (or since allocation,
///     if no write has occurred yet), each paired with its
///     recording stream. A subsequent write on a different
///     stream must wait on each entry here. Cleared at finish
///     time when a new write event replaces `last_write`: the
///     writer's prepare-time waits already subsumed every prior
///     reader's dependency, so any future operation that waits
///     on the new `last_write` transitively observes those
///     reads' completion.
///
/// On `deallocate`, the alloc stream waits on `last_write` (if
/// any) AND every entry in `outstanding_reads` before the queued
/// `cuMemFreeAsync` runs. Events that were recorded on the alloc
/// stream itself are skipped, since they are already ordered
/// ahead of the free.
///
/// # ABA / generation guard
///
/// The `generation` field guards against address recycling:
/// every API that looks the entry up by pointer (`deallocate`,
/// `prepare_block_use`, `finish_block_use`) validates
/// `block.generation == entry.generation` before touching it.
/// Mismatch returns [`ResourceError::UseAfterFree`] and the
/// entry is unchanged.
struct LiveEntry {
    /// The cudarc-owned device buffer. Dropping it is what
    /// releases the memory, so the entry must stay in the live
    /// map for as long as the block is considered live.
    slice: CudaSlice<u8>,
    /// Generation stamped at allocation time; the same value is
    /// copied into the `DeviceBlock` handed back to the caller.
    generation: Generation,
    /// Most recent write event on this block, OR the
    /// allocation-ready event if no write has happened yet.
    /// Future reads/writes on a different stream wait on this
    /// event. Replaced by `finish_block_use` for accesses whose
    /// `writes()` is true (`Access::Write` / `Access::ReadWrite`).
    /// The allocation-ready seed exists because cuMemAllocAsync
    /// orders the allocation only on `alloc_stream` — a
    /// cross-stream consumer that submits a kernel before
    /// allocation completes would read pool-recycled garbage.
    last_write: Option<(StreamId, CudaEvent)>,
    /// Read events recorded since `last_write` was installed
    /// (or since allocation). Future writes on a different
    /// stream wait on each entry. Cleared by `finish_block_use`
    /// when a write replaces `last_write`.
    outstanding_reads: Vec<(StreamId, CudaEvent)>,
}
136
/// Stream-ordered cudarc-backed allocator.
///
/// Tracks every live allocation in `live` and maintains two byte
/// tallies, `live_bytes` and `pending_bytes`, whose sum is what
/// `bytes_outstanding` reports.
pub struct AsyncCudaResource {
    /// Device handle supplied at construction and exposed through
    /// `device()`.
    device: Arc<CudaDevice>,
    /// CUDA ordinal supplied at construction. Stamped onto every
    /// `DeviceBlock` returned by `allocate` and compared against
    /// incoming blocks so a block from another device is rejected.
    device_ordinal: u32,
    /// Pool used to resolve a caller-supplied `StreamId` to a
    /// stream handle for allocation, waits, event recording and
    /// synchronization.
    stream_pool: Arc<StreamPool>,
    /// Live allocations keyed by raw device pointer. Each entry
    /// holds the cudarc slice and any recorded last-use events
    /// from cross-stream consumers. Removed on deallocate; the
    /// slice is then dropped, queueing `cuMemFreeAsync` on its
    /// bound stream — *after* the stream has been told to wait on
    /// every recorded cross-stream event.
    live: Mutex<HashMap<u64, LiveEntry>>,
    /// Bytes for blocks currently in `live`. Incremented by
    /// `allocate` while the live-map lock is held; decremented by
    /// `deallocate` just after the entry has been removed and the
    /// lock released, so a concurrent reader may briefly observe
    /// the pre-removal value.
    live_bytes: AtomicUsize,
    /// Bytes for blocks dropped (queued for cuMemFreeAsync) but
    /// whose owning stream has not yet been synchronized by us.
    /// Equal to the sum of values in `pending_per_stream` whenever
    /// no reap is in flight (a reap drains the map first and only
    /// subtracts from this counter after synchronizing). Additions
    /// to both are made under the `pending_per_stream` mutex so a
    /// concurrent `reap_pending` cannot wipe out bytes that a racing
    /// `deallocate` queued after reap drained the per-stream map.
    pending_bytes: AtomicUsize,
    /// Per-stream pending-free byte totals. Used by `reap_pending`
    /// to (a) compute the total to subtract from `pending_bytes`
    /// after stream synchronization, and (b) preserve any bytes
    /// added by a `deallocate` that races with reap — those bytes
    /// remain in this map and in `pending_bytes`, ready for the
    /// next reap.
    pending_per_stream: Mutex<HashMap<StreamId, usize>>,
}
166
167impl AsyncCudaResource {
168    /// Construct a resource bound to `device` using `stream_pool` for
169    /// stream resolution. `device_ordinal` is the CUDA ordinal for
170    /// logging / multi-device disambiguation.
171    pub fn new(device: Arc<CudaDevice>, device_ordinal: u32, stream_pool: Arc<StreamPool>) -> Self {
172        Self {
173            device,
174            device_ordinal,
175            stream_pool,
176            live: Mutex::new(HashMap::new()),
177            live_bytes: AtomicUsize::new(0),
178            pending_bytes: AtomicUsize::new(0),
179            pending_per_stream: Mutex::new(HashMap::new()),
180        }
181    }
182
183    pub fn device(&self) -> &Arc<CudaDevice> {
184        &self.device
185    }
186
187    pub fn stream_pool(&self) -> &Arc<StreamPool> {
188        &self.stream_pool
189    }
190
191    /// Bytes currently held by live blocks (excludes pending frees).
192    /// Test/diagnostic accessor — production code should use
193    /// `bytes_outstanding`.
194    pub fn live_bytes(&self) -> usize {
195        self.live_bytes.load(Ordering::Relaxed)
196    }
197
198    /// Bytes queued for `cuMemFreeAsync` whose stream has not yet
199    /// been synchronized by us. Test/diagnostic accessor.
200    pub fn pending_free_bytes(&self) -> usize {
201        self.pending_bytes.load(Ordering::Relaxed)
202    }
203
204    /// Sum of per-stream pending byte tallies. Test/diagnostic
205    /// accessor used to assert the invariant
206    /// `pending_free_bytes() == pending_per_stream_total()`. The
207    /// invariant must hold at any quiescent moment; if it fails
208    /// the bookkeeping under the `pending_per_stream` mutex has
209    /// drifted from the global atomic — see `deallocate` and
210    /// `reap_pending`, which update both as a unit.
211    pub fn pending_per_stream_total(&self) -> usize {
212        let map = self
213            .pending_per_stream
214            .lock()
215            .expect("AsyncCudaResource pending_per_stream poisoned");
216        map.values().copied().sum()
217    }
218
219    /// Number of recorded outstanding-read events plus a
220    /// last_write event (0 or 1) currently attached to the live
221    /// block at `ptr`. Test/diagnostic accessor — used by
222    /// reproducers to confirm `finish_block_use` actually
223    /// attached events before deallocate consumed them. Returns
224    /// `None` if `ptr` is not currently in the live map.
225    pub fn pending_use_event_count(&self, ptr: u64) -> Option<usize> {
226        let live = self
227            .live
228            .lock()
229            .expect("AsyncCudaResource live map poisoned");
230        live.get(&ptr)
231            .map(|e| e.outstanding_reads.len() + if e.last_write.is_some() { 1 } else { 0 })
232    }
233}
234
235impl DeviceMemoryResource for AsyncCudaResource {
236    fn allocate(
237        &self,
238        bytes: usize,
239        stream: StreamId,
240        tag: AllocTag,
241    ) -> ResourceResult<DeviceBlock> {
242        if bytes == 0 {
243            return Err(ResourceError::Driver(
244                "AsyncCudaResource: zero-byte allocation not supported".to_string(),
245            ));
246        }
247        let cu_stream = self.stream_pool.resolve(stream).ok_or_else(|| {
248            ResourceError::StreamMisuse(format!(
249                "AsyncCudaResource: unknown StreamId({})",
250                stream.0
251            ))
252        })?;
253
254        // SAFETY: bytes > 0 verified above. cudarc's
255        // `CudaStream::alloc::<u8>(len)` forwards to `cuMemAllocAsync`
256        // when the context has async-alloc enabled (CUDA 11.2+);
257        // otherwise it falls back to synchronous alloc internally.
258        // Failures are surfaced as `ResourceError::Driver`.
259        let slice = unsafe {
260            cu_stream
261                .alloc::<u8>(bytes)
262                .map_err(|e| ResourceError::Driver(format!("cuMemAllocAsync({}): {}", bytes, e)))?
263        };
264
265        // Record an "allocation-ready" event on the alloc stream
266        // immediately after the cuMemAllocAsync call. Cross-
267        // stream consumers MUST wait on this event before
268        // touching the bytes, otherwise the launch (on a
269        // different stream) may begin before the allocation
270        // completes and read pre-init / pool-recycled garbage.
271        // We store it in `last_write` so the access-aware
272        // prepare path's existing read-waits-on-last_write and
273        // write-waits-on-last_write rules cover it for free.
274        // Same-stream consumers skip the wait (already ordered).
275        let alloc_event = cu_stream.record_event(None).map_err(|e| {
276            ResourceError::Driver(format!(
277                "AsyncCudaResource::allocate: record allocation-ready event failed: {}",
278                e
279            ))
280        })?;
281
282        // Extract the raw device pointer for the DeviceBlock surface.
283        // The "sync" handle returned by `device_ptr` is intentionally
284        // leaked — the slice's lifetime is managed by our live map,
285        // not by the sync token.
286        let (raw_ptr, sync) =
287            <CudaSlice<u8> as cudarc::driver::DevicePtr<u8>>::device_ptr(&slice, slice.stream());
288        std::mem::forget(sync);
289        let ptr = raw_ptr;
290
291        {
292            let mut live = self
293                .live
294                .lock()
295                .expect("AsyncCudaResource live map poisoned");
296            // Use `contains_key` then `insert` so a (theoretical)
297            // pointer collision returns `Err` without mutating the
298            // map. The `live.insert(ptr, slice).is_some()` pattern
299            // would replace the existing entry, drop the old slice
300            // (queueing cuMemFreeAsync on memory we still believe
301            // we own), and leave the new slice resident while we
302            // return Err — `live_bytes` would also not be updated.
303            // Avoid that here.
304            if live.contains_key(&ptr) {
305                return Err(ResourceError::Driver(format!(
306                    "AsyncCudaResource: pointer collision on alloc ({:#x})",
307                    ptr
308                )));
309            }
310            // Generation must match between the LiveEntry and the
311            // returned DeviceBlock so record_block_use and
312            // deallocate can ABA-validate by (ptr, generation).
313            let generation = Generation::next();
314            live.insert(
315                ptr,
316                LiveEntry {
317                    slice,
318                    generation,
319                    last_write: Some((stream, alloc_event)),
320                    outstanding_reads: Vec::new(),
321                },
322            );
323            self.live_bytes.fetch_add(bytes, Ordering::Relaxed);
324            Ok(DeviceBlock {
325                ptr,
326                device_ordinal: self.device_ordinal,
327                alloc_stream: stream,
328                bytes,
329                align: std::mem::align_of::<u8>(),
330                tag,
331                generation,
332                state: BlockState::Live,
333            })
334        }
335    }
336
    /// Releases `block`, moving its bytes from the live tally to the
    /// pending-free tally and queueing the free on the block's
    /// alloc stream.
    ///
    /// Steps, in order:
    ///   1. Reject a block whose `device_ordinal` differs from this
    ///      resource's.
    ///   2. Resolve `block.alloc_stream` through the stream pool.
    ///   3. Under the live-map lock: validate `(ptr, generation)`,
    ///      make the alloc stream wait on every event that was
    ///      recorded on a different stream, then remove the entry.
    ///   4. Subtract `block.bytes` from `live_bytes`; under the
    ///      `pending_per_stream` lock, add `block.bytes` to the
    ///      per-stream tally and to `pending_bytes`.
    ///   5. Drop the slice, then the events.
    ///
    /// # Errors
    ///
    ///   * `ResourceError::Driver` — wrong device, or a stream wait
    ///     failed.
    ///   * `ResourceError::StreamMisuse` — `block.alloc_stream` does
    ///     not resolve.
    ///   * `ResourceError::UseAfterFree` — `block.ptr` is not in the
    ///     live map, or its generation does not match.
    ///
    /// On every error path the live entry and all byte counters are
    /// left untouched (waits already queued before a failing wait
    /// are not undone).
    ///
    /// # Panics
    ///
    /// Panics if the live-map or `pending_per_stream` mutex is
    /// poisoned.
    fn deallocate(&self, block: DeviceBlock) -> ResourceResult<()> {
        if block.device_ordinal != self.device_ordinal {
            return Err(ResourceError::Driver(format!(
                "AsyncCudaResource: deallocate on wrong device (block ord {} vs resource ord {})",
                block.device_ordinal, self.device_ordinal
            )));
        }
        // Resolve the alloc stream FIRST. If resolution fails the
        // live entry stays in place and accounting is unchanged —
        // the caller can retry. Removing the entry first then
        // erroring would queue `cuMemFreeAsync` on a stream the
        // caller did not expect (via the slice drop on the error
        // return path) AND leave accounting drift behind.
        let alloc_stream = self
            .stream_pool
            .resolve(block.alloc_stream)
            .ok_or_else(|| {
                ResourceError::StreamMisuse(format!(
                    "AsyncCudaResource::deallocate: alloc_stream StreamId({}) does not resolve",
                    block.alloc_stream.0
                ))
            })?;

        // Take the live-map lock and validate (ptr, generation)
        // before removing. The generation guard closes the ABA
        // window: if the address was freed and reused, the older
        // block's deallocate must NOT tear down the new live
        // entry. Mismatch -> UseAfterFree, no mutation.
        //
        // While the entry is still in the map, queue waits on
        // alloc_stream for: the block's last_write (if any) and
        // every outstanding_read. cudarc's `wait` records the
        // dependency synchronously; if any wait call fails, the
        // events stay owned by the entry, the entry stays in the
        // map, and accounting is untouched — caller can retry.
        //
        // Same-stream waits are skipped — events recorded on
        // `block.alloc_stream` are already ordered before
        // anything else queued there, so requesting a wait would
        // just be busywork. Cross-stream events are the ones
        // that fence the queued cuMemFreeAsync against in-flight
        // consumers.
        //
        // Only after every wait succeeds do we remove the entry,
        // taking ownership of the slice and events, and exit the
        // lock. From that point removal is committed and the
        // slice drop below queues cuMemFreeAsync correctly
        // ordered after every wait we just submitted.
        let (slice, last_write, outstanding_reads) = {
            let mut live = self
                .live
                .lock()
                .expect("AsyncCudaResource live map poisoned");
            match live.get(&block.ptr) {
                Some(entry) if entry.generation == block.generation => {
                    if let Some((write_stream, event)) = &entry.last_write {
                        if *write_stream != block.alloc_stream {
                            alloc_stream.wait(event).map_err(|e| {
                                ResourceError::Driver(format!(
                                    "AsyncCudaResource::deallocate: cuStreamWaitEvent on \
                                     last_write failed: {}",
                                    e
                                ))
                            })?;
                        }
                    }
                    for (read_stream, event) in &entry.outstanding_reads {
                        if *read_stream != block.alloc_stream {
                            alloc_stream.wait(event).map_err(|e| {
                                ResourceError::Driver(format!(
                                    "AsyncCudaResource::deallocate: cuStreamWaitEvent on \
                                     outstanding read failed: {}",
                                    e
                                ))
                            })?;
                        }
                    }
                    // The `get` above proved the key is present and
                    // the lock has been held since, so `remove`
                    // cannot miss.
                    let LiveEntry {
                        slice,
                        last_write,
                        outstanding_reads,
                        ..
                    } = live
                        .remove(&block.ptr)
                        .expect("present under lock per get above");
                    (slice, last_write, outstanding_reads)
                }
                Some(_) | None => {
                    return Err(ResourceError::UseAfterFree {
                        generation: block.generation,
                    });
                }
            }
        };

        // Move the bytes from "live" to "pending free": the slice
        // drop below queues `cuMemFreeAsync` on `block.alloc_stream`,
        // but the driver may not actually free until that stream
        // drains. The trait contract requires us to keep counting
        // these bytes until `reap_pending` confirms completion.
        //
        // The pending bookkeeping is updated as a unit under the
        // `pending_per_stream` mutex: per-stream tally first, then
        // the global atomic. `reap_pending` takes the same mutex to
        // drain the map, so it subtracts only the total it drained.
        // A `deallocate` that races with reap therefore lands either
        // entirely before reap's drain (its bytes are reaped this
        // round) or entirely after (its bytes stay pending for the
        // next reap) — never split.
        //
        // NOTE(review): the accounting below trusts the caller's
        // `block.bytes`; it is not cross-checked against the size
        // of the slice that was actually stored for this pointer.
        self.live_bytes.fetch_sub(block.bytes, Ordering::Relaxed);
        {
            let mut per_stream = self
                .pending_per_stream
                .lock()
                .expect("AsyncCudaResource pending_per_stream poisoned");
            *per_stream.entry(block.alloc_stream).or_insert(0) += block.bytes;
            self.pending_bytes.fetch_add(block.bytes, Ordering::Relaxed);
        }

        // Dropping the CudaSlice<u8> invokes cuMemFreeAsync on its
        // bound stream when async-alloc is enabled, otherwise falls
        // back to synchronous cuMemFree. Either way the deallocation
        // is ordered on the slice's stream, which matches the
        // DeviceBlock's `alloc_stream` — and now also waits for
        // every recorded cross-stream use event we just queued
        // above.
        drop(slice);
        // Drop the events explicitly after the slice drop has
        // queued the free. The event handles can be released as
        // soon as the wait calls return — cudarc's `wait` records
        // the dependency in the stream and does not retain the
        // event.
        drop(last_write);
        drop(outstanding_reads);
        Ok(())
    }
474
    /// The CUDA ordinal this resource was constructed with — the
    /// same value stamped on every block `allocate` returns.
    fn device_ordinal(&self) -> u32 {
        self.device_ordinal
    }
478
479    fn bytes_outstanding(&self) -> usize {
480        self.live_bytes.load(Ordering::Relaxed) + self.pending_bytes.load(Ordering::Relaxed)
481    }
482
483    fn reap_pending(&self) -> ResourceResult<()> {
484        self.reap_pending_with(|stream_id| match self.stream_pool.resolve(stream_id) {
485            Some(stream) => stream.synchronize().map_err(|e| {
486                ResourceError::Driver(format!(
487                    "AsyncCudaResource::reap_pending: stream sync failed: {}",
488                    e
489                ))
490            }),
491            // Pool returned no handle for this id. The pool currently
492            // never rotates entries, so this is a defensive branch.
493            // If the id is unresolved there is no stream we can
494            // synchronize on; treat the bytes as definitely freed —
495            // the only consistent accounting is to release them and
496            // let the caller surface any subsequent error against a
497            // known stream.
498            None => Ok(()),
499        })
500    }
501
    /// Always `true`: this resource implements `prepare_block_use`,
    /// `finish_block_use` and `record_block_use` with real
    /// per-block event tracking.
    fn supports_block_use_tracking(&self) -> bool {
        true
    }
505
506    fn record_block_use(&self, block: &DeviceBlock, use_stream: StreamId) -> ResourceResult<()> {
507        // Backward-compatibility shim. Pre-migration callers used
508        // `record_block_use` for "this stream did SOMETHING with
509        // this block; please wait on me before freeing." That
510        // semantics maps to `finish_block_use(.., Access::Read)`:
511        // the event is recorded on `use_stream` and appended to
512        // outstanding_reads so deallocate waits on it. New
513        // callers MUST call `prepare_block_use` BEFORE the launch
514        // and `finish_block_use` after; this shim does NOT queue
515        // the pre-launch wait so it is unsafe for use-after-write
516        // / use-after-prior-read scenarios.
517        self.finish_block_use(BlockId::from_block(block), use_stream, Access::Read)
518    }
519
520    fn prepare_block_use(
521        &self,
522        block: BlockId,
523        use_stream: StreamId,
524        access: Access,
525    ) -> ResourceResult<()> {
526        if block.device_ordinal != self.device_ordinal {
527            return Err(ResourceError::Driver(format!(
528                "AsyncCudaResource::prepare_block_use: block device {} != resource device {}",
529                block.device_ordinal, self.device_ordinal
530            )));
531        }
532        let use_cu_stream = self.stream_pool.resolve(use_stream).ok_or_else(|| {
533            ResourceError::StreamMisuse(format!(
534                "AsyncCudaResource::prepare_block_use: unknown StreamId({})",
535                use_stream.0
536            ))
537        })?;
538
539        // Validate (ptr, generation) and queue cross-stream
540        // waits while holding the live-map lock. The waits are
541        // cuStreamWaitEvent calls which record a dependency in
542        // the use stream and return — they don't block, so the
543        // lock is held only briefly. Same-stream events are
544        // skipped (already ordered).
545        let live = self
546            .live
547            .lock()
548            .expect("AsyncCudaResource live map poisoned");
549        let entry = match live.get(&block.ptr) {
550            Some(entry) if entry.generation == block.generation => entry,
551            Some(_) | None => {
552                return Err(ResourceError::UseAfterFree {
553                    generation: block.generation,
554                });
555            }
556        };
557        if access.reads() || access.writes() {
558            // Reader: wait on prior write.
559            // Writer / RW: wait on prior write AND every prior reader.
560            if let Some((write_stream, event)) = &entry.last_write {
561                if *write_stream != use_stream {
562                    use_cu_stream.wait(event).map_err(|e| {
563                        ResourceError::Driver(format!(
564                            "AsyncCudaResource::prepare_block_use: wait on last_write failed: {}",
565                            e
566                        ))
567                    })?;
568                }
569            }
570        }
571        if access.writes() {
572            for (read_stream, event) in &entry.outstanding_reads {
573                if *read_stream != use_stream {
574                    use_cu_stream.wait(event).map_err(|e| {
575                        ResourceError::Driver(format!(
576                            "AsyncCudaResource::prepare_block_use: wait on outstanding read \
577                             failed: {}",
578                            e
579                        ))
580                    })?;
581                }
582            }
583        }
584        Ok(())
585    }
586
587    fn finish_block_use(
588        &self,
589        block: BlockId,
590        use_stream: StreamId,
591        access: Access,
592    ) -> ResourceResult<()> {
593        if block.device_ordinal != self.device_ordinal {
594            return Err(ResourceError::Driver(format!(
595                "AsyncCudaResource::finish_block_use: block device {} != resource device {}",
596                block.device_ordinal, self.device_ordinal
597            )));
598        }
599        let use_cu_stream = self.stream_pool.resolve(use_stream).ok_or_else(|| {
600            ResourceError::StreamMisuse(format!(
601                "AsyncCudaResource::finish_block_use: unknown StreamId({})",
602                use_stream.0
603            ))
604        })?;
605        // Validate (ptr, generation) BEFORE recording the event
606        // on `use_stream`. This avoids creating an event that we
607        // would have to immediately destroy on the ABA failure
608        // path.
609        {
610            let live = self
611                .live
612                .lock()
613                .expect("AsyncCudaResource live map poisoned");
614            match live.get(&block.ptr) {
615                Some(entry) if entry.generation == block.generation => {}
616                Some(_) | None => {
617                    return Err(ResourceError::UseAfterFree {
618                        generation: block.generation,
619                    });
620                }
621            }
622        }
623        // Record the event on the use stream OUTSIDE the live-map
624        // lock — event creation/record can block on the CUDA
625        // driver and we don't want to hold the live-map lock
626        // across that. Re-validate generation after acquiring the
627        // lock so a racing dealloc that already removed the entry
628        // doesn't see a phantom event attached to a stale block.
629        let event = use_cu_stream.record_event(None).map_err(|e| {
630            ResourceError::Driver(format!(
631                "AsyncCudaResource::finish_block_use: event record failed: {}",
632                e
633            ))
634        })?;
635        let mut live = self
636            .live
637            .lock()
638            .expect("AsyncCudaResource live map poisoned");
639        match live.get_mut(&block.ptr) {
640            Some(entry) if entry.generation == block.generation => {
641                if access.writes() {
642                    // Writer: the prepare phase queued waits on
643                    // every prior reader and on last_write, so
644                    // any future op that observes the new
645                    // last_write transitively observes those
646                    // dependencies. Drop the prior state.
647                    entry.last_write = Some((use_stream, event));
648                    entry.outstanding_reads.clear();
649                } else {
650                    debug_assert!(access.reads());
651                    entry.outstanding_reads.push((use_stream, event));
652                }
653                Ok(())
654            }
655            Some(_) | None => {
656                // Event drops here, releasing the CUDA event.
657                // cudarc's wait was never queued so no stream
658                // dependency leaks.
659                drop(event);
660                Err(ResourceError::UseAfterFree {
661                    generation: block.generation,
662                })
663            }
664        }
665    }
666}
667
668impl AsyncCudaResource {
669    /// Drain pending per-stream entries and synchronize each
670    /// drained stream via `sync_stream`, releasing only the bytes
671    /// for streams that the closure successfully synchronized.
672    ///
673    /// On the first synchronization failure, the failing entry and
674    /// **every remaining un-iterated drained entry** are restored
675    /// into `pending_per_stream` so a subsequent reap can retry
676    /// them, and `pending_bytes` is decremented only by the
677    /// already-synchronized total. The closure's error is then
678    /// returned to the caller. Without this recovery, a transient
679    /// driver error mid-reap would lose track of pending bytes
680    /// forever (drained map is gone, `pending_bytes` still counts
681    /// them, but no stream is queued for a future reap).
682    ///
683    /// Production callers go through [`reap_pending`]
684    /// (the trait method), which passes a closure that resolves
685    /// the [`StreamId`] against [`StreamPool`] and calls
686    /// `CudaStream::synchronize`. This helper exists so unit tests
687    /// can inject controlled sync failures without touching the
688    /// CUDA driver.
689    pub(crate) fn reap_pending_with<F>(&self, mut sync_stream: F) -> ResourceResult<()>
690    where
691        F: FnMut(StreamId) -> ResourceResult<()>,
692    {
693        // Drain the per-stream map atomically. Anything added by a
694        // racing `deallocate` after this point lands in a fresh
695        // entry and waits for the next reap.
696        //
697        // Critically, we do NOT touch `pending_bytes` here — only
698        // after a stream has synchronized do we subtract its bytes.
699        // A `deallocate` that races between our drain and our
700        // subtract has already added to `pending_bytes` under the
701        // same mutex (see `deallocate`), and that addition is
702        // preserved because we `fetch_sub` the synchronized total
703        // rather than `store(0)`.
704        let drained: HashMap<StreamId, usize> = {
705            let mut per_stream = self
706                .pending_per_stream
707                .lock()
708                .expect("AsyncCudaResource pending_per_stream poisoned");
709            std::mem::take(&mut *per_stream)
710        };
711        if drained.is_empty() {
712            return Ok(());
713        }
714
715        let mut synced_total: usize = 0;
716        let mut failure: Option<ResourceError> = None;
717        let mut unsynced: Vec<(StreamId, usize)> = Vec::new();
718        let mut iter = drained.into_iter();
719        while let Some((stream_id, bytes)) = iter.next() {
720            match sync_stream(stream_id) {
721                Ok(()) => {
722                    synced_total = synced_total.saturating_add(bytes);
723                }
724                Err(e) => {
725                    // Restore the failing entry and every remaining
726                    // drained entry so they can be retried by a
727                    // future reap.
728                    unsynced.push((stream_id, bytes));
729                    unsynced.extend(iter.by_ref());
730                    failure = Some(e);
731                    break;
732                }
733            }
734        }
735
736        if !unsynced.is_empty() {
737            let mut per_stream = self
738                .pending_per_stream
739                .lock()
740                .expect("AsyncCudaResource pending_per_stream poisoned");
741            for (stream_id, bytes) in unsynced {
742                *per_stream.entry(stream_id).or_insert(0) += bytes;
743            }
744        }
745
746        if synced_total > 0 {
747            self.pending_bytes
748                .fetch_sub(synced_total, Ordering::Relaxed);
749        }
750
751        match failure {
752            Some(e) => Err(e),
753            None => Ok(()),
754        }
755    }
756}
757
#[cfg(test)]
mod tests {
    use super::*;

    /// Open device 0 and build a default stream pool on it. Yields
    /// `None` when the device cannot be opened, which the tests treat
    /// as "skip".
    fn try_setup() -> Option<(Arc<CudaDevice>, Arc<StreamPool>)> {
        CudaDevice::new(0).ok().map(|raw| {
            let device = Arc::new(raw);
            let pool = Arc::new(StreamPool::with_defaults(Arc::clone(&device)));
            (device, pool)
        })
    }

    #[test]
    fn allocate_then_deallocate_round_trips_on_default_stream() {
        let Some((device, pool)) = try_setup() else {
            return;
        };
        let resource = AsyncCudaResource::new(device, 0, pool);

        // Freshly allocated: all bytes are live, none pending.
        let block = resource
            .allocate(2048, StreamId::DEFAULT, AllocTag::UNTAGGED)
            .expect("alloc");
        assert_eq!(block.bytes, 2048);
        assert_eq!(block.alloc_stream, StreamId::DEFAULT);
        assert_eq!(resource.bytes_outstanding(), 2048);
        assert_eq!(resource.live_bytes(), 2048);
        assert_eq!(resource.pending_free_bytes(), 0);

        // After deallocate the free is only queued, so the bytes move
        // from "live" to "pending" and remain outstanding.
        resource.deallocate(block).expect("dealloc");
        assert_eq!(resource.live_bytes(), 0);
        assert_eq!(resource.pending_free_bytes(), 2048);
        assert_eq!(resource.bytes_outstanding(), 2048);

        // Reaping clears the pending bytes.
        resource.reap_pending().expect("reap pending");
        assert_eq!(resource.bytes_outstanding(), 0);
        assert_eq!(resource.pending_free_bytes(), 0);
    }

    #[test]
    fn allocate_on_acquired_non_default_stream() {
        let Some((device, pool)) = try_setup() else {
            return;
        };
        let resource = AsyncCudaResource::new(device, 0, Arc::clone(&pool));
        let stream = pool.acquire().expect("acquire non-default stream");

        let block = resource
            .allocate(1024, stream, AllocTag("async-test"))
            .expect("alloc on non-default stream");
        assert_eq!(block.alloc_stream, stream);

        resource.deallocate(block).expect("dealloc");
        // Not reaped yet, so the bytes are still outstanding.
        assert_eq!(resource.bytes_outstanding(), 1024);

        resource.reap_pending().expect("reap pending");
        assert_eq!(resource.bytes_outstanding(), 0);
    }

    #[test]
    fn allocate_unknown_stream_id_rejected() {
        let Some((device, pool)) = try_setup() else {
            return;
        };
        let resource = AsyncCudaResource::new(device, 0, pool);
        assert!(matches!(
            resource.allocate(64, StreamId(99), AllocTag::UNTAGGED),
            Err(ResourceError::StreamMisuse(_))
        ));
    }

    #[test]
    fn deallocate_unknown_block_returns_use_after_free() {
        let Some((device, pool)) = try_setup() else {
            return;
        };
        let resource = AsyncCudaResource::new(device, 0, pool);

        // A block this resource never handed out.
        let never_allocated = DeviceBlock {
            ptr: 0xfeed_face,
            device_ordinal: 0,
            alloc_stream: StreamId::DEFAULT,
            bytes: 16,
            align: 1,
            tag: AllocTag::UNTAGGED,
            generation: Generation::next(),
            state: BlockState::Live,
        };
        let outcome = resource.deallocate(never_allocated);
        assert!(matches!(outcome, Err(ResourceError::UseAfterFree { .. })));
    }

    #[test]
    fn reap_with_no_pending_is_noop() {
        let Some((device, pool)) = try_setup() else {
            return;
        };
        let resource = AsyncCudaResource::new(device, 0, pool);
        resource.reap_pending().expect("reap on empty");
        assert_eq!(resource.bytes_outstanding(), 0);
    }

    /// Test-only helper that writes pending state straight into the
    /// resource's counters, skipping `allocate`/`deallocate`. This
    /// lets `reap_pending_with` be exercised without real CUDA
    /// streams.
    fn install_pending(r: &AsyncCudaResource, entries: &[(StreamId, usize)]) {
        let mut per_stream = r
            .pending_per_stream
            .lock()
            .expect("AsyncCudaResource pending_per_stream poisoned");
        // Record each entry in the per-stream map while summing the
        // byte total for the aggregate counter.
        let total = entries.iter().fold(0usize, |sum, &(id, bytes)| {
            *per_stream.entry(id).or_insert(0) += bytes;
            sum.saturating_add(bytes)
        });
        // Release the map lock before touching the atomic counter.
        drop(per_stream);
        r.pending_bytes.fetch_add(total, Ordering::Relaxed);
    }

    #[test]
    fn reap_pending_recovers_unsynced_streams_when_sync_fails() {
        // The recovery logic itself never calls the driver: the sync
        // failure is injected through `reap_pending_with`. A device is
        // still required to construct the AsyncCudaResource.
        let Some((device, pool)) = try_setup() else {
            return;
        };
        let resource = AsyncCudaResource::new(Arc::clone(&device), 0, Arc::clone(&pool));

        // Two pending streams, 3072 bytes in total. StreamId(2) is the
        // one whose sync will fail.
        install_pending(&resource, &[(StreamId(1), 1024), (StreamId(2), 2048)]);
        assert_eq!(resource.pending_free_bytes(), 3072);
        assert_eq!(resource.pending_per_stream_total(), 3072);

        // Record which streams synchronized successfully. The drained
        // map is a HashMap, so visit order is unspecified and the
        // assertions below are phrased to hold for either order.
        let mut synced: Vec<StreamId> = Vec::new();
        let result = resource.reap_pending_with(|stream_id| {
            if stream_id != StreamId(2) {
                synced.push(stream_id);
                return Ok(());
            }
            Err(ResourceError::Driver(
                "simulated sync failure on StreamId(2)".into(),
            ))
        });

        assert!(matches!(result, Err(ResourceError::Driver(_))));

        // Visit order [1, 2]: stream 1 syncs, stream 2 fails, so 1024
        //   bytes are released and only (2, 2048) is restored.
        // Visit order [2, 1]: stream 2 fails immediately, so nothing is
        //   released and both entries are restored.
        // In both cases pending == 3072 - bytes actually synced.
        let synced_bytes: usize = usize::from(synced.contains(&StreamId(1))) * 1024;
        let expected_pending = 3072 - synced_bytes;
        assert_eq!(
            resource.pending_free_bytes(),
            expected_pending,
            "synced={:?}; pending_bytes must reflect only un-synced bytes",
            synced
        );
        assert_eq!(
            resource.pending_per_stream_total(),
            expected_pending,
            "synced={:?}; pending_per_stream_total must equal pending_free_bytes \
             (cross-counter invariant)",
            synced
        );

        // A follow-up reap whose closure always succeeds must clear
        // everything, showing the restored entries were kept for retry.
        resource
            .reap_pending_with(|_| Ok(()))
            .expect("retry reap");
        assert_eq!(resource.pending_free_bytes(), 0);
        assert_eq!(resource.pending_per_stream_total(), 0);
    }

    #[test]
    fn reap_pending_drains_normally_when_sync_always_succeeds() {
        // Happy path through the closure-based entry point: every
        // stream syncs, so both counters must end at zero.
        let Some((device, pool)) = try_setup() else {
            return;
        };
        let resource = AsyncCudaResource::new(Arc::clone(&device), 0, Arc::clone(&pool));

        install_pending(&resource, &[(StreamId(1), 256), (StreamId(2), 512)]);
        resource.reap_pending_with(|_| Ok(())).expect("reap");
        assert_eq!(resource.pending_free_bytes(), 0);
        assert_eq!(resource.pending_per_stream_total(), 0);
    }
}