toasty_driver_integration_suite/
instrumented_driver.rs1use async_trait::async_trait;
2use std::{
3 borrow::Cow,
4 collections::VecDeque,
5 fmt,
6 sync::{
7 Arc, Mutex,
8 atomic::{AtomicBool, Ordering},
9 },
10};
11use toasty_core::{
12 Result, Schema,
13 driver::{Capability, Connection, Driver, ExecResponse, Operation, Rows},
14 schema::{
15 db::{AppliedMigration, Migration},
16 diff,
17 },
18};
19
/// A failure that can be queued with [`InstrumentedHandle::inject_fault`].
///
/// Each queued fault is consumed by the next `exec` or `ping` call made on any
/// instrumented connection that shares the handle.
#[derive(Debug, Clone)]
pub enum Fault {
    /// Fail the call with a connection-lost error and mark the connection as
    /// no longer valid.
    ConnectionLost,
}
33
34#[derive(Debug)]
35pub struct DriverOp {
36 pub operation: Operation,
37 pub response: ExecResponse,
38}
39
40#[derive(Clone, Default)]
45pub struct InstrumentedHandle {
46 inner: Arc<InstrumentedState>,
47}
48
49#[derive(Default)]
50struct InstrumentedState {
51 ops_log: Mutex<Vec<DriverOp>>,
52 faults: Mutex<VecDeque<Fault>>,
53}
54
55impl InstrumentedHandle {
56 pub fn len(&self) -> usize {
58 self.inner.ops_log.lock().unwrap().len()
59 }
60
61 pub fn is_empty(&self) -> bool {
63 self.inner.ops_log.lock().unwrap().is_empty()
64 }
65
66 pub fn clear(&self) {
68 self.inner.ops_log.lock().unwrap().clear();
69 }
70
71 #[track_caller]
73 pub fn pop(&self) -> (Operation, ExecResponse) {
74 let mut ops = self.inner.ops_log.lock().unwrap();
75 if ops.is_empty() {
76 panic!("no operations in log");
77 }
78 let driver_op = ops.remove(0);
79 (driver_op.operation, driver_op.response)
80 }
81
82 #[track_caller]
83 pub fn pop_op(&self) -> Operation {
84 self.pop().0
85 }
86
87 pub fn inject_fault(&self, fault: Fault) {
90 self.inner
91 .faults
92 .lock()
93 .expect("Failed to acquire faults lock")
94 .push_back(fault);
95 }
96}
97
98impl fmt::Debug for InstrumentedHandle {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 let ops = self.inner.ops_log.lock().unwrap();
101 f.debug_struct("InstrumentedHandle")
102 .field("ops", &*ops)
103 .finish()
104 }
105}
106
107#[derive(Debug)]
111pub struct InstrumentedDriver {
112 inner: Box<dyn Driver>,
113 handle: InstrumentedHandle,
114}
115
116impl InstrumentedDriver {
117 pub fn new(driver: Box<dyn Driver>) -> Self {
118 Self {
119 inner: driver,
120 handle: InstrumentedHandle::default(),
121 }
122 }
123
124 pub fn handle(&self) -> InstrumentedHandle {
127 self.handle.clone()
128 }
129}
130
131#[async_trait]
132impl Driver for InstrumentedDriver {
133 fn url(&self) -> Cow<'_, str> {
134 self.inner.url()
135 }
136
137 fn capability(&self) -> &'static Capability {
138 self.inner.capability()
139 }
140
141 async fn connect(&self) -> Result<Box<dyn Connection>> {
142 Ok(Box::new(InstrumentedConnection {
143 inner: self.inner.connect().await?,
144 handle: self.handle.clone(),
145 valid: AtomicBool::new(true),
146 }))
147 }
148
149 fn generate_migration(&self, schema_diff: &diff::Schema<'_>) -> Migration {
150 self.inner.generate_migration(schema_diff)
151 }
152
153 async fn reset_db(&self) -> Result<()> {
154 self.inner.reset_db().await
155 }
156}
157
158#[derive(Debug)]
161pub struct InstrumentedConnection {
162 inner: Box<dyn Connection>,
164
165 handle: InstrumentedHandle,
167
168 valid: AtomicBool,
173}
174
175#[async_trait]
176impl Connection for InstrumentedConnection {
177 async fn exec(&mut self, schema: &Arc<Schema>, operation: Operation) -> Result<ExecResponse> {
178 let fault = self
181 .handle
182 .inner
183 .faults
184 .lock()
185 .expect("Failed to acquire faults lock")
186 .pop_front();
187 if let Some(fault) = fault {
188 match fault {
189 Fault::ConnectionLost => {
190 self.valid.store(false, Ordering::Release);
191 return Err(toasty_core::Error::connection_lost(std::io::Error::other(
192 "injected connection-lost fault",
193 )));
194 }
195 }
196 }
197
198 let operation_clone = operation.clone();
200
201 let mut response = self.inner.exec(schema, operation).await?;
203
204 let duplicated_response = duplicate_response_mut(&mut response).await?;
206
207 let driver_op = DriverOp {
209 operation: operation_clone,
210 response: duplicated_response,
211 };
212
213 self.handle
214 .inner
215 .ops_log
216 .lock()
217 .expect("Failed to acquire ops log lock")
218 .push(driver_op);
219
220 Ok(response)
221 }
222
223 async fn push_schema(&mut self, schema: &Schema) -> Result<()> {
224 self.inner.push_schema(schema).await
225 }
226
227 async fn applied_migrations(&mut self) -> Result<Vec<AppliedMigration>> {
228 self.inner.applied_migrations().await
229 }
230
231 async fn apply_migration(&mut self, id: u64, name: &str, migration: &Migration) -> Result<()> {
232 self.inner.apply_migration(id, name, migration).await
233 }
234
235 fn is_valid(&self) -> bool {
236 self.valid.load(Ordering::Acquire) && self.inner.is_valid()
237 }
238
239 async fn ping(&mut self) -> Result<()> {
240 let fault = self
244 .handle
245 .inner
246 .faults
247 .lock()
248 .expect("Failed to acquire faults lock")
249 .pop_front();
250 if let Some(fault) = fault {
251 match fault {
252 Fault::ConnectionLost => {
253 self.valid.store(false, Ordering::Release);
254 return Err(toasty_core::Error::connection_lost(std::io::Error::other(
255 "injected connection-lost fault",
256 )));
257 }
258 }
259 }
260 self.inner.ping().await
261 }
262}
263
264async fn duplicate_response_mut(response: &mut ExecResponse) -> Result<ExecResponse> {
267 let values = match &mut response.values {
268 Rows::Count(count) => Rows::Count(*count),
269 Rows::Value(_) => todo!(),
270 Rows::Stream(stream) => {
271 let duplicated_stream = stream.dup().await?;
273 Rows::Stream(duplicated_stream)
274 }
275 };
276
277 Ok(ExecResponse::from_rows(values))
278}