toasty_driver_postgresql/
lib.rs

1mod statement_cache;
2mod r#type;
3mod value;
4
5pub(crate) use value::Value;
6
7use postgres::{tls::MakeTlsConnect, types::ToSql, Socket};
8use std::{borrow::Cow, sync::Arc};
9use toasty_core::{
10    async_trait,
11    driver::{Capability, Driver, Operation, Response},
12    schema::db::{self, Migration, SchemaDiff, Table},
13    stmt,
14    stmt::ValueRecord,
15    Result, Schema,
16};
17use toasty_sql::{self as sql, TypedValue};
18use tokio_postgres::{Client, Config};
19use url::Url;
20
21use crate::{r#type::TypeExt, statement_cache::StatementCache};
22
23#[derive(Debug)]
24pub struct PostgreSQL {
25    url: String,
26    config: Config,
27}
28
29impl PostgreSQL {
30    /// Create a new PostgreSQL driver from a connection URL
31    pub fn new(url: impl Into<String>) -> Result<Self> {
32        let url_str = url.into();
33        let url = Url::parse(&url_str).map_err(toasty_core::Error::driver_operation_failed)?;
34
35        if !matches!(url.scheme(), "postgresql" | "postgres") {
36            return Err(toasty_core::Error::invalid_connection_url(format!(
37                "connection URL does not have a `postgresql` scheme; url={}",
38                url
39            )));
40        }
41
42        let host = url.host_str().ok_or_else(|| {
43            toasty_core::Error::invalid_connection_url(format!(
44                "missing host in connection URL; url={}",
45                url
46            ))
47        })?;
48
49        if url.path().is_empty() {
50            return Err(toasty_core::Error::invalid_connection_url(format!(
51                "no database specified - missing path in connection URL; url={}",
52                url
53            )));
54        }
55
56        let mut config = Config::new();
57        config.host(host);
58        config.dbname(url.path().trim_start_matches('/'));
59
60        if let Some(port) = url.port() {
61            config.port(port);
62        }
63
64        if !url.username().is_empty() {
65            config.user(url.username());
66        }
67
68        if let Some(password) = url.password() {
69            config.password(password);
70        }
71
72        Ok(Self {
73            url: url_str,
74            config,
75        })
76    }
77}
78
79#[async_trait]
80impl Driver for PostgreSQL {
81    fn url(&self) -> Cow<'_, str> {
82        Cow::Borrowed(&self.url)
83    }
84
85    fn capability(&self) -> &'static Capability {
86        &Capability::POSTGRESQL
87    }
88
89    async fn connect(&self) -> toasty_core::Result<Box<dyn toasty_core::driver::Connection>> {
90        Ok(Box::new(
91            Connection::connect(self.config.clone(), tokio_postgres::NoTls).await?,
92        ))
93    }
94
95    fn generate_migration(&self, schema_diff: &SchemaDiff<'_>) -> Migration {
96        let statements = sql::MigrationStatement::from_diff(schema_diff, &Capability::POSTGRESQL);
97
98        let sql_strings: Vec<String> = statements
99            .iter()
100            .map(|stmt| {
101                let mut params = Vec::<TypedValue>::new();
102                let sql = sql::Serializer::postgresql(stmt.schema())
103                    .serialize(stmt.statement(), &mut params);
104                assert!(
105                    params.is_empty(),
106                    "migration statements should not have parameters"
107                );
108                sql
109            })
110            .collect();
111
112        Migration::new_sql(sql_strings.join("\n"))
113    }
114
115    async fn reset_db(&self) -> toasty_core::Result<()> {
116        let dbname = self
117            .config
118            .get_dbname()
119            .ok_or_else(|| {
120                toasty_core::Error::invalid_connection_url("no database name configured")
121            })?
122            .to_string();
123
124        // We cannot drop a database we are currently connected to, so we need a temp database.
125        let temp_dbname = "__toasty_reset_temp";
126
127        let connect = |dbname: &str| {
128            let mut config = self.config.clone();
129            config.dbname(dbname);
130            Connection::connect(config, tokio_postgres::NoTls)
131        };
132
133        // Step 1: Connect to the target DB and create a temp DB
134        let conn = connect(&dbname).await?;
135        conn.client
136            .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
137            .await
138            .map_err(toasty_core::Error::driver_operation_failed)?;
139        conn.client
140            .execute(&format!("CREATE DATABASE \"{}\"", temp_dbname), &[])
141            .await
142            .map_err(toasty_core::Error::driver_operation_failed)?;
143        drop(conn);
144
145        // Step 2: Connect to the temp DB, drop and recreate the target
146        let conn = connect(temp_dbname).await?;
147        conn.client
148            .execute(
149                "SELECT pg_terminate_backend(pid) \
150                 FROM pg_stat_activity \
151                 WHERE datname = $1 AND pid <> pg_backend_pid()",
152                &[&dbname],
153            )
154            .await
155            .map_err(toasty_core::Error::driver_operation_failed)?;
156        conn.client
157            .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", dbname), &[])
158            .await
159            .map_err(toasty_core::Error::driver_operation_failed)?;
160        conn.client
161            .execute(&format!("CREATE DATABASE \"{}\"", dbname), &[])
162            .await
163            .map_err(toasty_core::Error::driver_operation_failed)?;
164        drop(conn);
165
166        // Step 3: Connect back to the target and clean up the temp DB
167        let conn = connect(&dbname).await?;
168        conn.client
169            .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
170            .await
171            .map_err(toasty_core::Error::driver_operation_failed)?;
172
173        Ok(())
174    }
175}
176
/// A single Toasty connection to a PostgreSQL database.
///
/// Pairs a tokio-postgres [`Client`] with a statement cache through which
/// queries are prepared before execution.
#[derive(Debug)]
pub struct Connection {
    /// Underlying tokio-postgres client; every statement is sent through it.
    client: Client,
    /// Cache used when preparing statements (`prepare_typed`), so repeated
    /// SQL can reuse a prepared statement. NOTE(review): presumably keyed by
    /// SQL text and parameter types — confirm in `statement_cache`.
    statement_cache: StatementCache,
}
182
183impl Connection {
184    /// Initialize a Toasty PostgreSQL connection using an initialized client.
185    pub fn new(client: Client) -> Self {
186        Self {
187            client,
188            statement_cache: StatementCache::new(100),
189        }
190    }
191
192    /// Connects to a PostgreSQL database using a [`postgres::Config`].
193    ///
194    /// See [`postgres::Client::configure`] for more information.
195    pub async fn connect<T>(config: Config, tls: T) -> Result<Self>
196    where
197        T: MakeTlsConnect<Socket> + 'static,
198        T::Stream: Send,
199    {
200        let (client, connection) = config
201            .connect(tls)
202            .await
203            .map_err(toasty_core::Error::driver_operation_failed)?;
204
205        tokio::spawn(async move {
206            if let Err(e) = connection.await {
207                eprintln!("connection error: {e}");
208            }
209        });
210
211        Ok(Self::new(client))
212    }
213
214    /// Creates a table.
215    pub async fn create_table(&mut self, schema: &db::Schema, table: &Table) -> Result<()> {
216        let serializer = sql::Serializer::postgresql(schema);
217
218        let mut params: Vec<toasty_sql::TypedValue> = Vec::new();
219        let sql = serializer.serialize(
220            &sql::Statement::create_table(table, &Capability::POSTGRESQL),
221            &mut params,
222        );
223
224        assert!(
225            params.is_empty(),
226            "creating a table shouldn't involve any parameters"
227        );
228
229        self.client
230            .execute(&sql, &[])
231            .await
232            .map_err(toasty_core::Error::driver_operation_failed)?;
233
234        // NOTE: `params` is guaranteed to be empty based on the assertion above. If
235        // that changes, `params.clear()` should be called here.
236        for index in &table.indices {
237            if index.primary_key {
238                continue;
239            }
240
241            let sql = serializer.serialize(&sql::Statement::create_index(index), &mut params);
242
243            assert!(
244                params.is_empty(),
245                "creating an index shouldn't involve any parameters"
246            );
247
248            self.client
249                .execute(&sql, &[])
250                .await
251                .map_err(toasty_core::Error::driver_operation_failed)?;
252        }
253
254        Ok(())
255    }
256}
257
258impl From<Client> for Connection {
259    fn from(client: Client) -> Self {
260        Self {
261            client,
262            statement_cache: StatementCache::new(100),
263        }
264    }
265}
266
267#[async_trait]
268impl toasty_core::driver::Connection for Connection {
269    async fn exec(&mut self, schema: &Arc<Schema>, op: Operation) -> Result<Response> {
270        if let Operation::Transaction(ref t) = op {
271            let sql = sql::Serializer::postgresql(&schema.db).serialize_transaction(t);
272            self.client.batch_execute(&sql).await.map_err(|e| {
273                if let Some(db_err) = e.as_db_error() {
274                    match db_err.code().code() {
275                        "40001" => toasty_core::Error::serialization_failure(db_err.message()),
276                        "25006" => toasty_core::Error::read_only_transaction(db_err.message()),
277                        _ => toasty_core::Error::driver_operation_failed(e),
278                    }
279                } else {
280                    toasty_core::Error::driver_operation_failed(e)
281                }
282            })?;
283            return Ok(Response::count(0));
284        }
285
286        let (sql, ret_tys): (sql::Statement, _) = match op {
287            Operation::Insert(op) => (op.stmt.into(), None),
288            Operation::QuerySql(query) => {
289                assert!(
290                    query.last_insert_id_hack.is_none(),
291                    "last_insert_id_hack is MySQL-specific and should not be set for PostgreSQL"
292                );
293                (query.stmt.into(), query.ret)
294            }
295            op => todo!("op={:#?}", op),
296        };
297
298        let width = sql.returning_len();
299
300        let mut params: Vec<toasty_sql::TypedValue> = Vec::new();
301        let sql_as_str = sql::Serializer::postgresql(&schema.db).serialize(&sql, &mut params);
302
303        let param_types = params
304            .iter()
305            .map(|typed_value| typed_value.infer_ty().to_postgres_type())
306            .collect::<Vec<_>>();
307
308        let values: Vec<_> = params.into_iter().map(|tv| Value::from(tv.value)).collect();
309        let params = values
310            .iter()
311            .map(|param| param as &(dyn ToSql + Sync))
312            .collect::<Vec<_>>();
313
314        let statement = self
315            .statement_cache
316            .prepare_typed(&mut self.client, &sql_as_str, &param_types)
317            .await
318            .map_err(toasty_core::Error::driver_operation_failed)?;
319
320        if width.is_none() {
321            let count = self
322                .client
323                .execute(&statement, &params)
324                .await
325                .map_err(toasty_core::Error::driver_operation_failed)?;
326            return Ok(Response::count(count));
327        }
328
329        let rows = self
330            .client
331            .query(&statement, &params)
332            .await
333            .map_err(toasty_core::Error::driver_operation_failed)?;
334
335        if width.is_none() {
336            let [row] = &rows[..] else { todo!() };
337            let total = row.get::<usize, i64>(0);
338            let condition_matched = row.get::<usize, i64>(1);
339
340            if total == condition_matched {
341                Ok(Response::count(total as _))
342            } else {
343                Err(toasty_core::Error::condition_failed(
344                    "update condition did not match",
345                ))
346            }
347        } else {
348            let ret_tys = ret_tys.as_ref().unwrap().clone();
349            let results = rows.into_iter().map(move |row| {
350                let mut results = Vec::new();
351                for (i, column) in row.columns().iter().enumerate() {
352                    results.push(Value::from_sql(i, &row, column, &ret_tys[i]).into_inner());
353                }
354
355                Ok(ValueRecord::from_vec(results))
356            });
357
358            Ok(Response::value_stream(stmt::ValueStream::from_iter(
359                results,
360            )))
361        }
362    }
363
364    async fn push_schema(&mut self, schema: &Schema) -> Result<()> {
365        for table in &schema.db.tables {
366            self.create_table(&schema.db, table).await?;
367        }
368        Ok(())
369    }
370
371    async fn applied_migrations(
372        &mut self,
373    ) -> Result<Vec<toasty_core::schema::db::AppliedMigration>> {
374        // Ensure the migrations table exists
375        self.client
376            .execute(
377                "CREATE TABLE IF NOT EXISTS __toasty_migrations (
378                id BIGINT PRIMARY KEY,
379                name TEXT NOT NULL,
380                applied_at TIMESTAMP NOT NULL
381            )",
382                &[],
383            )
384            .await
385            .map_err(toasty_core::Error::driver_operation_failed)?;
386
387        // Query all applied migrations
388        let rows = self
389            .client
390            .query(
391                "SELECT id FROM __toasty_migrations ORDER BY applied_at",
392                &[],
393            )
394            .await
395            .map_err(toasty_core::Error::driver_operation_failed)?;
396
397        Ok(rows
398            .iter()
399            .map(|row| {
400                let id: i64 = row.get(0);
401                toasty_core::schema::db::AppliedMigration::new(id as u64)
402            })
403            .collect())
404    }
405
406    async fn apply_migration(
407        &mut self,
408        id: u64,
409        name: String,
410        migration: &toasty_core::schema::db::Migration,
411    ) -> Result<()> {
412        // Ensure the migrations table exists
413        self.client
414            .execute(
415                "CREATE TABLE IF NOT EXISTS __toasty_migrations (
416                id BIGINT PRIMARY KEY,
417                name TEXT NOT NULL,
418                applied_at TIMESTAMP NOT NULL
419            )",
420                &[],
421            )
422            .await
423            .map_err(toasty_core::Error::driver_operation_failed)?;
424
425        // Start transaction
426        let transaction = self
427            .client
428            .transaction()
429            .await
430            .map_err(toasty_core::Error::driver_operation_failed)?;
431
432        // Execute each migration statement
433        for statement in migration.statements() {
434            if let Err(e) = transaction
435                .batch_execute(statement)
436                .await
437                .map_err(toasty_core::Error::driver_operation_failed)
438            {
439                transaction
440                    .rollback()
441                    .await
442                    .map_err(toasty_core::Error::driver_operation_failed)?;
443                return Err(e);
444            }
445        }
446
447        // Record the migration
448        if let Err(e) = transaction
449            .execute(
450                "INSERT INTO __toasty_migrations (id, name, applied_at) VALUES ($1, $2, NOW())",
451                &[&(id as i64), &name],
452            )
453            .await
454            .map_err(toasty_core::Error::driver_operation_failed)
455        {
456            transaction
457                .rollback()
458                .await
459                .map_err(toasty_core::Error::driver_operation_failed)?;
460            return Err(e);
461        }
462
463        // Commit transaction
464        transaction
465            .commit()
466            .await
467            .map_err(toasty_core::Error::driver_operation_failed)?;
468        Ok(())
469    }
470}