toasty_driver_integration_suite/
test.rs

1use std::{
2    error::Error,
3    sync::{Arc, Mutex, RwLock},
4};
5
6use toasty::{Db, schema::ModelSet};
7use tokio::runtime::Runtime;
8
9use crate::{ExecLog, Isolate, LoggingDriver, Setup};
10
/// Process-wide gate that lets ordinary tests overlap while serial tests run alone.
///
/// [`Test::run`] takes a shared (read) guard for ordinary tests, so any number of
/// them may hold it at once, and an exclusive (write) guard for tests marked serial
/// via [`Test::set_serial`], which therefore never overlap with any other test.
static TEST_LOCK: RwLock<()> = RwLock::new(());
15
16/// Wraps the Tokio runtime and ensures cleanup happens.
17///
18/// This also passes necessary
19pub struct Test {
20    /// Handle to the DB suite setup
21    setup: Arc<dyn Setup>,
22
23    /// Handles isolating tables between tests
24    isolate: Isolate,
25
26    /// Tokio runtime used by the test
27    runtime: Option<Runtime>,
28
29    exec_log: ExecLog,
30
31    /// List of all tables created during the test. These will need to be removed later.
32    tables: Vec<String>,
33
34    /// Whether this test requires exclusive (serial) execution
35    serial: bool,
36}
37
38impl Test {
39    pub fn new(setup: Arc<dyn Setup>) -> Self {
40        let runtime = tokio::runtime::Builder::new_current_thread()
41            .enable_all()
42            .build()
43            .expect("failed to create Tokio runtime");
44
45        Test {
46            setup,
47            isolate: Isolate::new(),
48            runtime: Some(runtime),
49            exec_log: ExecLog::new(Arc::new(Mutex::new(Vec::new()))),
50            tables: vec![],
51            serial: false,
52        }
53    }
54
55    /// Try to setup a database with models, returns Result for error handling
56    pub async fn try_setup_db(&mut self, models: ModelSet) -> toasty::Result<Db> {
57        let mut builder = toasty::Db::builder();
58        builder.models(models);
59
60        // Set the table prefix
61        builder.table_name_prefix(&self.isolate.table_prefix());
62
63        // Always wrap with logging
64        let logging_driver = LoggingDriver::new(self.setup.driver());
65        let ops_log = logging_driver.ops_log_handle();
66        self.exec_log = ExecLog::new(ops_log);
67
68        // Build the database with the logging driver
69        let db = builder.build(logging_driver).await?;
70        db.push_schema().await?;
71
72        for table in &db.schema().db.tables {
73            self.tables.push(table.name.clone());
74        }
75
76        Ok(db)
77    }
78
79    /// Setup a database with models, always with logging enabled
80    pub async fn setup_db(&mut self, models: ModelSet) -> Db {
81        self.try_setup_db(models).await.unwrap()
82    }
83
84    /// Get the driver capability
85    pub fn capability(&self) -> &'static toasty_core::driver::Capability {
86        self.setup.driver().capability()
87    }
88
89    /// Get the execution log for assertions
90    pub fn log(&mut self) -> &mut ExecLog {
91        &mut self.exec_log
92    }
93
94    /// Set whether this test requires exclusive (serial) execution
95    pub fn set_serial(&mut self, serial: bool) {
96        self.serial = serial;
97    }
98
99    /// Run an async test function using the internal runtime
100    pub fn run<R>(&mut self, f: impl AsyncFn(&mut Test) -> R)
101    where
102        R: Into<TestResult>,
103    {
104        // Acquire the appropriate lock: write lock for serial tests (exclusive),
105        // read lock for normal tests (parallel).
106        let _guard: Box<dyn std::any::Any> = if self.serial {
107            Box::new(TEST_LOCK.write().unwrap_or_else(|e| e.into_inner()))
108        } else {
109            Box::new(TEST_LOCK.read().unwrap_or_else(|e| e.into_inner()))
110        };
111
112        // Temporarily take the runtime to avoid borrow checker issues
113        let runtime = self.runtime.take().expect("runtime already consumed");
114        let f: std::pin::Pin<Box<dyn std::future::Future<Output = R>>> = Box::pin(f(self));
115        let result = runtime.block_on(f).into();
116
117        // now, wut
118        for table in &self.tables {
119            runtime.block_on(self.setup.delete_table(table));
120        }
121
122        if let Some(error) = result.error {
123            panic!("Driver test returned an error: {error}");
124        }
125
126        self.runtime = Some(runtime);
127    }
128}
129
/// Outcome of a driver test body, as consumed by [`Test::run`].
///
/// Test bodies may return `()` or any `Result<_, E>` whose error converts into
/// `Box<dyn Error>`. Both convert into this type, and an error here makes
/// [`Test::run`] fail the test.
#[derive(Debug)]
pub struct TestResult {
    /// The error returned by the test body, if any.
    error: Option<Box<dyn Error>>,
}

impl From<()> for TestResult {
    /// A body that returns `()` always counts as success.
    fn from(_: ()) -> Self {
        TestResult { error: None }
    }
}

impl<O, E> From<Result<O, E>> for TestResult
where
    E: Into<Box<dyn Error>>,
{
    /// `Ok(_)` is success (the value is discarded). `Err(e)` records `e` as the failure.
    fn from(value: Result<O, E>) -> Self {
        TestResult {
            error: value.err().map(Into::into),
        }
    }
}