// toasty_driver_integration_suite/test.rs

1use std::{
2    error::Error,
3    sync::{Arc, RwLock},
4};
5
6use toasty::{Db, schema::ModelSet};
7use tokio::runtime::Runtime;
8
9use crate::{Fault, InstrumentedDriver, InstrumentedHandle, Isolate, Setup};
10
/// Process-wide lock that lets ordinary tests overlap while serial tests run alone.
///
/// `Test::run` takes a shared (read) guard for normal tests, so any number of
/// them may proceed together, and an exclusive (write) guard for tests that
/// were marked serial via `Test::set_serial`.
static TEST_LOCK: RwLock<()> = RwLock::<()>::new(());
15
16/// Wraps the Tokio runtime and ensures cleanup happens.
17///
18/// This also passes necessary
19pub struct Test {
20    /// Handle to the DB suite setup
21    setup: Arc<dyn Setup>,
22
23    /// Handles isolating tables between tests
24    isolate: Isolate,
25
26    /// Tokio runtime used by the test
27    runtime: Option<Runtime>,
28
29    /// Single handle controlling the instrumented driver test middleware:
30    /// the operations log and the fault-injection queue. Populated by
31    /// `try_setup_db_with`.
32    handle: InstrumentedHandle,
33
34    /// List of all tables created during the test. These will need to be removed later.
35    tables: Vec<String>,
36
37    /// Whether this test requires exclusive (serial) execution
38    serial: bool,
39}
40
41impl Test {
42    pub fn new(setup: Arc<dyn Setup>) -> Self {
43        let runtime = tokio::runtime::Builder::new_current_thread()
44            .enable_all()
45            .build()
46            .expect("failed to create Tokio runtime");
47
48        Test {
49            setup,
50            isolate: Isolate::new(),
51            runtime: Some(runtime),
52            handle: InstrumentedHandle::default(),
53            tables: vec![],
54            serial: false,
55        }
56    }
57
58    /// Try to setup a database with models, returns Result for error handling
59    pub async fn try_setup_db(&mut self, models: ModelSet) -> toasty::Result<Db> {
60        self.try_setup_db_with(models, |_| {}).await
61    }
62
63    /// Try to setup a database with models, allowing the caller to customize
64    /// the [`toasty::db::Builder`] before it is built (e.g., to set pool
65    /// configuration).
66    pub async fn try_setup_db_with(
67        &mut self,
68        models: ModelSet,
69        customize: impl FnOnce(&mut toasty::db::Builder),
70    ) -> toasty::Result<Db> {
71        let mut builder = toasty::Db::builder();
72        builder.models(models);
73
74        // Set the table prefix
75        builder.table_name_prefix(&self.isolate.table_prefix());
76
77        // Apply caller customizations
78        customize(&mut builder);
79
80        // Always wrap with the instrumented test driver
81        let instrumented_driver = InstrumentedDriver::new(self.setup.driver());
82        self.handle = instrumented_driver.handle();
83
84        // Build the database with the instrumented driver
85        let db = builder.build(instrumented_driver).await?;
86        db.push_schema().await?;
87
88        for table in &db.schema().db.tables {
89            self.tables.push(table.name.clone());
90        }
91
92        Ok(db)
93    }
94
95    /// Setup a database with models, always with logging enabled
96    pub async fn setup_db(&mut self, models: ModelSet) -> Db {
97        self.try_setup_db(models).await.unwrap()
98    }
99
100    /// Setup a database, applying the given customization to the
101    /// [`toasty::db::Builder`] before building.
102    pub async fn setup_db_with(
103        &mut self,
104        models: ModelSet,
105        customize: impl FnOnce(&mut toasty::db::Builder),
106    ) -> Db {
107        self.try_setup_db_with(models, customize).await.unwrap()
108    }
109
110    /// Get the driver capability
111    pub fn capability(&self) -> &'static toasty_core::driver::Capability {
112        self.setup.driver().capability()
113    }
114
115    /// Get the instrumented-driver control handle. The handle exposes
116    /// the operation log (for assertions) and fault injection.
117    pub fn log(&self) -> &InstrumentedHandle {
118        &self.handle
119    }
120
121    /// Queue a fault to fire on the next driver `exec` call. Faults
122    /// fire in FIFO order. Only useful after `setup_db` has installed
123    /// the instrumented driver.
124    pub fn inject_fault(&self, fault: Fault) {
125        self.handle.inject_fault(fault);
126    }
127
128    /// Set whether this test requires exclusive (serial) execution
129    pub fn set_serial(&mut self, serial: bool) {
130        self.serial = serial;
131    }
132
133    /// Run an async test function using the internal runtime
134    pub fn run<R>(&mut self, f: impl AsyncFn(&mut Test) -> R)
135    where
136        R: Into<TestResult>,
137    {
138        // Acquire the appropriate lock: write lock for serial tests (exclusive),
139        // read lock for normal tests (parallel).
140        let _guard: Box<dyn std::any::Any> = if self.serial {
141            Box::new(TEST_LOCK.write().unwrap_or_else(|e| e.into_inner()))
142        } else {
143            Box::new(TEST_LOCK.read().unwrap_or_else(|e| e.into_inner()))
144        };
145
146        // Temporarily take the runtime to avoid borrow checker issues
147        let runtime = self.runtime.take().expect("runtime already consumed");
148        let f: std::pin::Pin<Box<dyn std::future::Future<Output = R>>> = Box::pin(f(self));
149        let result = runtime.block_on(f).into();
150
151        // now, wut
152        for table in &self.tables {
153            runtime.block_on(self.setup.delete_table(table));
154        }
155
156        if let Some(error) = result.error {
157            panic!("Driver test returned an error: {error}");
158        }
159
160        self.runtime = Some(runtime);
161    }
162}
163
/// Outcome of a driver test body, as consumed by `Test::run`.
///
/// A test body may return `()` (always a success) or any `Result` whose error
/// converts into `Box<dyn Error>`; an `Err` value is kept in `error`.
#[derive(Debug)]
pub struct TestResult {
    /// The error returned by the test body, or `None` on success.
    error: Option<Box<dyn Error>>,
}

impl From<()> for TestResult {
    /// A unit-returning test body always counts as a success.
    fn from(_: ()) -> Self {
        TestResult { error: None }
    }
}

impl<O, E> From<Result<O, E>> for TestResult
where
    E: Into<Box<dyn Error>>,
{
    /// Keeps the error of an `Err` (boxed); the `Ok` value is discarded.
    fn from(value: Result<O, E>) -> Self {
        TestResult {
            error: value.err().map(Into::into),
        }
    }
}