Skip to main content

toasty_driver_integration_suite/
instrumented_driver.rs

1use async_trait::async_trait;
2use std::{
3    borrow::Cow,
4    collections::VecDeque,
5    fmt,
6    sync::{
7        Arc, Mutex,
8        atomic::{AtomicBool, Ordering},
9    },
10};
11use toasty_core::{
12    Result, Schema,
13    driver::{Capability, Connection, Driver, ExecResponse, Operation, Rows},
14    schema::{
15        db::{AppliedMigration, Migration},
16        diff,
17    },
18};
19
/// Failure to inject into an upcoming driver call.
///
/// The fault queue is shared by every connection the driver hands out and
/// is consumed first-in, first-out: each `exec` (and each `ping`) removes at
/// most one fault from the front of the queue and, if it got one, acts on it
/// instead of delegating to the wrapped connection.
#[derive(Clone, Debug)]
pub enum Fault {
    /// Makes the next `exec` fail with `Error::connection_lost` without the
    /// underlying connection ever being used. The owning
    /// `InstrumentedConnection` then reports `is_valid() == false`, just as it
    /// would after a genuine lost connection, prompting the pool to evict it.
    ConnectionLost,
}
33
34#[derive(Debug)]
35pub struct DriverOp {
36    pub operation: Operation,
37    pub response: ExecResponse,
38}
39
40/// Single control handle for the [`InstrumentedDriver`] test middleware.
41/// Exposes both the operation log (for assertions) and the fault queue
42/// (for injecting failures). Cheaply cloneable; every clone refers to
43/// the same shared state.
44#[derive(Clone, Default)]
45pub struct InstrumentedHandle {
46    inner: Arc<InstrumentedState>,
47}
48
49#[derive(Default)]
50struct InstrumentedState {
51    ops_log: Mutex<Vec<DriverOp>>,
52    faults: Mutex<VecDeque<Fault>>,
53}
54
55impl InstrumentedHandle {
56    /// Get the number of logged operations
57    pub fn len(&self) -> usize {
58        self.inner.ops_log.lock().unwrap().len()
59    }
60
61    /// Check if the log is empty
62    pub fn is_empty(&self) -> bool {
63        self.inner.ops_log.lock().unwrap().is_empty()
64    }
65
66    /// Clear the log
67    pub fn clear(&self) {
68        self.inner.ops_log.lock().unwrap().clear();
69    }
70
71    /// Remove and return the first operation from the log
72    #[track_caller]
73    pub fn pop(&self) -> (Operation, ExecResponse) {
74        let mut ops = self.inner.ops_log.lock().unwrap();
75        if ops.is_empty() {
76            panic!("no operations in log");
77        }
78        let driver_op = ops.remove(0);
79        (driver_op.operation, driver_op.response)
80    }
81
82    #[track_caller]
83    pub fn pop_op(&self) -> Operation {
84        self.pop().0
85    }
86
87    /// Queue a fault to fire on the next driver `exec` call. Faults fire
88    /// in FIFO order across all connections produced by the driver.
89    pub fn inject_fault(&self, fault: Fault) {
90        self.inner
91            .faults
92            .lock()
93            .expect("Failed to acquire faults lock")
94            .push_back(fault);
95    }
96}
97
98impl fmt::Debug for InstrumentedHandle {
99    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100        let ops = self.inner.ops_log.lock().unwrap();
101        f.debug_struct("InstrumentedHandle")
102            .field("ops", &*ops)
103            .finish()
104    }
105}
106
107/// Test-only driver wrapper that instruments an underlying driver: it
108/// records every operation for later assertion and can inject faults
109/// (connection loss, etc.) to exercise error-handling paths.
110#[derive(Debug)]
111pub struct InstrumentedDriver {
112    inner: Box<dyn Driver>,
113    handle: InstrumentedHandle,
114}
115
116impl InstrumentedDriver {
117    pub fn new(driver: Box<dyn Driver>) -> Self {
118        Self {
119            inner: driver,
120            handle: InstrumentedHandle::default(),
121        }
122    }
123
124    /// Get the single control handle for this driver. The handle exposes
125    /// both the operations log and the fault-injection queue.
126    pub fn handle(&self) -> InstrumentedHandle {
127        self.handle.clone()
128    }
129}
130
131#[async_trait]
132impl Driver for InstrumentedDriver {
133    fn url(&self) -> Cow<'_, str> {
134        self.inner.url()
135    }
136
137    fn capability(&self) -> &'static Capability {
138        self.inner.capability()
139    }
140
141    async fn connect(&self) -> Result<Box<dyn Connection>> {
142        Ok(Box::new(InstrumentedConnection {
143            inner: self.inner.connect().await?,
144            handle: self.handle.clone(),
145            valid: AtomicBool::new(true),
146        }))
147    }
148
149    fn generate_migration(&self, schema_diff: &diff::Schema<'_>) -> Migration {
150        self.inner.generate_migration(schema_diff)
151    }
152
153    async fn reset_db(&self) -> Result<()> {
154        self.inner.reset_db().await
155    }
156}
157
158/// Per-connection counterpart of [`InstrumentedDriver`]: records each
159/// `exec` and consults the shared fault queue before delegating.
160#[derive(Debug)]
161pub struct InstrumentedConnection {
162    /// The underlying driver that actually executes operations
163    inner: Box<dyn Connection>,
164
165    /// Shared handle: ops log + fault queue.
166    handle: InstrumentedHandle,
167
168    /// Set to `false` once an injected `ConnectionLost` fault has fired
169    /// against this connection. Surfaced through [`Connection::is_valid`]
170    /// so the pool evicts it the same way it would after a real
171    /// connection-lost error.
172    valid: AtomicBool,
173}
174
175#[async_trait]
176impl Connection for InstrumentedConnection {
177    async fn exec(&mut self, schema: &Arc<Schema>, operation: Operation) -> Result<ExecResponse> {
178        // Pop a queued fault, if any, and short-circuit before reaching
179        // the underlying driver.
180        let fault = self
181            .handle
182            .inner
183            .faults
184            .lock()
185            .expect("Failed to acquire faults lock")
186            .pop_front();
187        if let Some(fault) = fault {
188            match fault {
189                Fault::ConnectionLost => {
190                    self.valid.store(false, Ordering::Release);
191                    return Err(toasty_core::Error::connection_lost(std::io::Error::other(
192                        "injected connection-lost fault",
193                    )));
194                }
195            }
196        }
197
198        // Clone the operation for logging
199        let operation_clone = operation.clone();
200
201        // Execute the operation on the underlying driver
202        let mut response = self.inner.exec(schema, operation).await?;
203
204        // Duplicate the response for logging
205        let duplicated_response = duplicate_response_mut(&mut response).await?;
206
207        // Log the operation and response
208        let driver_op = DriverOp {
209            operation: operation_clone,
210            response: duplicated_response,
211        };
212
213        self.handle
214            .inner
215            .ops_log
216            .lock()
217            .expect("Failed to acquire ops log lock")
218            .push(driver_op);
219
220        Ok(response)
221    }
222
223    async fn push_schema(&mut self, schema: &Schema) -> Result<()> {
224        self.inner.push_schema(schema).await
225    }
226
227    async fn applied_migrations(&mut self) -> Result<Vec<AppliedMigration>> {
228        self.inner.applied_migrations().await
229    }
230
231    async fn apply_migration(&mut self, id: u64, name: &str, migration: &Migration) -> Result<()> {
232        self.inner.apply_migration(id, name, migration).await
233    }
234
235    fn is_valid(&self) -> bool {
236        self.valid.load(Ordering::Acquire) && self.inner.is_valid()
237    }
238
239    async fn ping(&mut self) -> Result<()> {
240        // Consume a queued fault before delegating, mirroring `exec`.
241        // A `ConnectionLost` fault here lets tests target the sweep's
242        // ping path the same way they target user query paths.
243        let fault = self
244            .handle
245            .inner
246            .faults
247            .lock()
248            .expect("Failed to acquire faults lock")
249            .pop_front();
250        if let Some(fault) = fault {
251            match fault {
252                Fault::ConnectionLost => {
253                    self.valid.store(false, Ordering::Release);
254                    return Err(toasty_core::Error::connection_lost(std::io::Error::other(
255                        "injected connection-lost fault",
256                    )));
257                }
258            }
259        }
260        self.inner.ping().await
261    }
262}
263
264/// Duplicate an ExecResponse, using ValueStream::dup() for value streams
265/// This version takes a mutable reference so we can call dup() on the ValueStream
266async fn duplicate_response_mut(response: &mut ExecResponse) -> Result<ExecResponse> {
267    let values = match &mut response.values {
268        Rows::Count(count) => Rows::Count(*count),
269        Rows::Value(_) => todo!(),
270        Rows::Stream(stream) => {
271            // Duplicate the value stream
272            let duplicated_stream = stream.dup().await?;
273            Rows::Stream(duplicated_stream)
274        }
275    };
276
277    Ok(ExecResponse::from_rows(values))
278}