toasty_driver_postgresql/
lib.rs

1#![warn(missing_docs)]
2
3//! Toasty driver for [PostgreSQL](https://www.postgresql.org/) using
4//! [`tokio-postgres`](https://docs.rs/tokio-postgres).
5//!
6//! # Examples
7//!
8//! ```no_run
9//! use toasty_driver_postgresql::PostgreSQL;
10//!
11//! let driver = PostgreSQL::new("postgresql://localhost/mydb").unwrap();
12//! ```
13
14mod statement_cache;
15mod r#type;
16mod value;
17
18pub(crate) use value::Value;
19
20use async_trait::async_trait;
21use postgres::{Socket, tls::MakeTlsConnect, types::ToSql};
22use std::{borrow::Cow, sync::Arc};
23use toasty_core::{
24    Result, Schema,
25    driver::{Capability, Driver, ExecResponse, Operation},
26    schema::db::{self, Migration, SchemaDiff, Table},
27    stmt,
28    stmt::ValueRecord,
29};
30use toasty_sql::{self as sql, TypedValue};
31use tokio_postgres::{Client, Config};
32use url::Url;
33
34use crate::{statement_cache::StatementCache, r#type::TypeExt};
35
36/// A PostgreSQL [`Driver`] that connects via `tokio-postgres`.
37///
38/// # Examples
39///
40/// ```no_run
41/// use toasty_driver_postgresql::PostgreSQL;
42///
43/// let driver = PostgreSQL::new("postgresql://localhost/mydb").unwrap();
44/// ```
45#[derive(Debug)]
46pub struct PostgreSQL {
47    url: String,
48    config: Config,
49}
50
51impl PostgreSQL {
52    /// Create a new PostgreSQL driver from a connection URL
53    pub fn new(url: impl Into<String>) -> Result<Self> {
54        let url_str = url.into();
55        let url = Url::parse(&url_str).map_err(toasty_core::Error::driver_operation_failed)?;
56
57        if !matches!(url.scheme(), "postgresql" | "postgres") {
58            return Err(toasty_core::Error::invalid_connection_url(format!(
59                "connection URL does not have a `postgresql` scheme; url={}",
60                url
61            )));
62        }
63
64        let host = url.host_str().ok_or_else(|| {
65            toasty_core::Error::invalid_connection_url(format!(
66                "missing host in connection URL; url={}",
67                url
68            ))
69        })?;
70
71        if url.path().is_empty() {
72            return Err(toasty_core::Error::invalid_connection_url(format!(
73                "no database specified - missing path in connection URL; url={}",
74                url
75            )));
76        }
77
78        let mut config = Config::new();
79        config.host(host);
80        config.dbname(url.path().trim_start_matches('/'));
81
82        if let Some(port) = url.port() {
83            config.port(port);
84        }
85
86        if !url.username().is_empty() {
87            config.user(url.username());
88        }
89
90        if let Some(password) = url.password() {
91            config.password(password);
92        }
93
94        Ok(Self {
95            url: url_str,
96            config,
97        })
98    }
99}
100
101#[async_trait]
102impl Driver for PostgreSQL {
103    fn url(&self) -> Cow<'_, str> {
104        Cow::Borrowed(&self.url)
105    }
106
107    fn capability(&self) -> &'static Capability {
108        &Capability::POSTGRESQL
109    }
110
111    async fn connect(&self) -> toasty_core::Result<Box<dyn toasty_core::driver::Connection>> {
112        Ok(Box::new(
113            Connection::connect(self.config.clone(), tokio_postgres::NoTls).await?,
114        ))
115    }
116
117    fn generate_migration(&self, schema_diff: &SchemaDiff<'_>) -> Migration {
118        let statements = sql::MigrationStatement::from_diff(schema_diff, &Capability::POSTGRESQL);
119
120        let sql_strings: Vec<String> = statements
121            .iter()
122            .map(|stmt| {
123                let mut params = Vec::<TypedValue>::new();
124                let sql = sql::Serializer::postgresql(stmt.schema())
125                    .serialize(stmt.statement(), &mut params);
126                assert!(
127                    params.is_empty(),
128                    "migration statements should not have parameters"
129                );
130                sql
131            })
132            .collect();
133
134        Migration::new_sql(sql_strings.join("\n"))
135    }
136
137    async fn reset_db(&self) -> toasty_core::Result<()> {
138        let dbname = self
139            .config
140            .get_dbname()
141            .ok_or_else(|| {
142                toasty_core::Error::invalid_connection_url("no database name configured")
143            })?
144            .to_string();
145
146        // We cannot drop a database we are currently connected to, so we need a temp database.
147        let temp_dbname = "__toasty_reset_temp";
148
149        let connect = |dbname: &str| {
150            let mut config = self.config.clone();
151            config.dbname(dbname);
152            Connection::connect(config, tokio_postgres::NoTls)
153        };
154
155        // Step 1: Connect to the target DB and create a temp DB
156        let conn = connect(&dbname).await?;
157        conn.client
158            .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
159            .await
160            .map_err(toasty_core::Error::driver_operation_failed)?;
161        conn.client
162            .execute(&format!("CREATE DATABASE \"{}\"", temp_dbname), &[])
163            .await
164            .map_err(toasty_core::Error::driver_operation_failed)?;
165        drop(conn);
166
167        // Step 2: Connect to the temp DB, drop and recreate the target
168        let conn = connect(temp_dbname).await?;
169        conn.client
170            .execute(
171                "SELECT pg_terminate_backend(pid) \
172                 FROM pg_stat_activity \
173                 WHERE datname = $1 AND pid <> pg_backend_pid()",
174                &[&dbname],
175            )
176            .await
177            .map_err(toasty_core::Error::driver_operation_failed)?;
178        conn.client
179            .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", dbname), &[])
180            .await
181            .map_err(toasty_core::Error::driver_operation_failed)?;
182        conn.client
183            .execute(&format!("CREATE DATABASE \"{}\"", dbname), &[])
184            .await
185            .map_err(toasty_core::Error::driver_operation_failed)?;
186        drop(conn);
187
188        // Step 3: Connect back to the target and clean up the temp DB
189        let conn = connect(&dbname).await?;
190        conn.client
191            .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
192            .await
193            .map_err(toasty_core::Error::driver_operation_failed)?;
194
195        Ok(())
196    }
197}
198
/// An open connection to a PostgreSQL database.
///
/// Build one with [`Connection::connect`], [`Connection::new`], or
/// `From<Client>`.
#[derive(Debug)]
pub struct Connection {
    /// The `tokio-postgres` client that every statement is sent through.
    client: Client,
    /// Statement cache that `exec` prepares typed statements through.
    /// The constructors in this file create it with `StatementCache::new(100)`.
    statement_cache: StatementCache,
}
205
206impl Connection {
207    /// Initialize a Toasty PostgreSQL connection using an initialized client.
208    pub fn new(client: Client) -> Self {
209        Self {
210            client,
211            statement_cache: StatementCache::new(100),
212        }
213    }
214
215    /// Connects to a PostgreSQL database using a [`postgres::Config`].
216    ///
217    /// See [`postgres::Client::configure`] for more information.
218    pub async fn connect<T>(config: Config, tls: T) -> Result<Self>
219    where
220        T: MakeTlsConnect<Socket> + 'static,
221        T::Stream: Send,
222    {
223        let (client, connection) = config
224            .connect(tls)
225            .await
226            .map_err(toasty_core::Error::driver_operation_failed)?;
227
228        tokio::spawn(async move {
229            if let Err(e) = connection.await {
230                eprintln!("connection error: {e}");
231            }
232        });
233
234        Ok(Self::new(client))
235    }
236
237    /// Creates a table.
238    pub async fn create_table(&mut self, schema: &db::Schema, table: &Table) -> Result<()> {
239        let serializer = sql::Serializer::postgresql(schema);
240
241        let mut params: Vec<toasty_sql::TypedValue> = Vec::new();
242        let sql = serializer.serialize(
243            &sql::Statement::create_table(table, &Capability::POSTGRESQL),
244            &mut params,
245        );
246
247        assert!(
248            params.is_empty(),
249            "creating a table shouldn't involve any parameters"
250        );
251
252        self.client
253            .execute(&sql, &[])
254            .await
255            .map_err(toasty_core::Error::driver_operation_failed)?;
256
257        // NOTE: `params` is guaranteed to be empty based on the assertion above. If
258        // that changes, `params.clear()` should be called here.
259        for index in &table.indices {
260            if index.primary_key {
261                continue;
262            }
263
264            let sql = serializer.serialize(&sql::Statement::create_index(index), &mut params);
265
266            assert!(
267                params.is_empty(),
268                "creating an index shouldn't involve any parameters"
269            );
270
271            self.client
272                .execute(&sql, &[])
273                .await
274                .map_err(toasty_core::Error::driver_operation_failed)?;
275        }
276
277        Ok(())
278    }
279}
280
281impl From<Client> for Connection {
282    fn from(client: Client) -> Self {
283        Self {
284            client,
285            statement_cache: StatementCache::new(100),
286        }
287    }
288}
289
290#[async_trait]
291impl toasty_core::driver::Connection for Connection {
292    async fn exec(&mut self, schema: &Arc<Schema>, op: Operation) -> Result<ExecResponse> {
293        tracing::trace!(driver = "postgresql", op = %op.name(), "driver exec");
294
295        if let Operation::Transaction(ref t) = op {
296            let sql = sql::Serializer::postgresql(&schema.db).serialize_transaction(t);
297            self.client.batch_execute(&sql).await.map_err(|e| {
298                if let Some(db_err) = e.as_db_error() {
299                    match db_err.code().code() {
300                        "40001" => toasty_core::Error::serialization_failure(db_err.message()),
301                        "25006" => toasty_core::Error::read_only_transaction(db_err.message()),
302                        _ => toasty_core::Error::driver_operation_failed(e),
303                    }
304                } else {
305                    toasty_core::Error::driver_operation_failed(e)
306                }
307            })?;
308            return Ok(ExecResponse::count(0));
309        }
310
311        let (sql, ret_tys): (sql::Statement, _) = match op {
312            Operation::Insert(op) => (op.stmt.into(), None),
313            Operation::QuerySql(query) => {
314                assert!(
315                    query.last_insert_id_hack.is_none(),
316                    "last_insert_id_hack is MySQL-specific and should not be set for PostgreSQL"
317                );
318                (query.stmt.into(), query.ret)
319            }
320            op => todo!("op={:#?}", op),
321        };
322
323        let width = sql.returning_len();
324
325        let mut params: Vec<toasty_sql::TypedValue> = Vec::new();
326        let sql_as_str = sql::Serializer::postgresql(&schema.db).serialize(&sql, &mut params);
327
328        tracing::debug!(db.system = "postgresql", db.statement = %sql_as_str, params = params.len(), "executing SQL");
329
330        let param_types = params
331            .iter()
332            .map(|typed_value| typed_value.infer_ty().to_postgres_type())
333            .collect::<Vec<_>>();
334
335        let values: Vec<_> = params.into_iter().map(|tv| Value::from(tv.value)).collect();
336        let params = values
337            .iter()
338            .map(|param| param as &(dyn ToSql + Sync))
339            .collect::<Vec<_>>();
340
341        let statement = self
342            .statement_cache
343            .prepare_typed(&mut self.client, &sql_as_str, &param_types)
344            .await
345            .map_err(toasty_core::Error::driver_operation_failed)?;
346
347        if width.is_none() {
348            let count = self
349                .client
350                .execute(&statement, &params)
351                .await
352                .map_err(toasty_core::Error::driver_operation_failed)?;
353            return Ok(ExecResponse::count(count));
354        }
355
356        let rows = self
357            .client
358            .query(&statement, &params)
359            .await
360            .map_err(toasty_core::Error::driver_operation_failed)?;
361
362        if width.is_none() {
363            let [row] = &rows[..] else { todo!() };
364            let total = row.get::<usize, i64>(0);
365            let condition_matched = row.get::<usize, i64>(1);
366
367            if total == condition_matched {
368                Ok(ExecResponse::count(total as _))
369            } else {
370                Err(toasty_core::Error::condition_failed(
371                    "update condition did not match",
372                ))
373            }
374        } else {
375            let ret_tys = ret_tys.as_ref().unwrap().clone();
376            let results = rows.into_iter().map(move |row| {
377                let mut results = Vec::new();
378                for (i, column) in row.columns().iter().enumerate() {
379                    results.push(Value::from_sql(i, &row, column, &ret_tys[i]).into_inner());
380                }
381
382                Ok(ValueRecord::from_vec(results))
383            });
384
385            Ok(ExecResponse::value_stream(stmt::ValueStream::from_iter(
386                results,
387            )))
388        }
389    }
390
391    async fn push_schema(&mut self, schema: &Schema) -> Result<()> {
392        for table in &schema.db.tables {
393            tracing::debug!(table = %table.name, "creating table");
394            self.create_table(&schema.db, table).await?;
395        }
396        Ok(())
397    }
398
399    async fn applied_migrations(
400        &mut self,
401    ) -> Result<Vec<toasty_core::schema::db::AppliedMigration>> {
402        // Ensure the migrations table exists
403        self.client
404            .execute(
405                "CREATE TABLE IF NOT EXISTS __toasty_migrations (
406                id BIGINT PRIMARY KEY,
407                name TEXT NOT NULL,
408                applied_at TIMESTAMP NOT NULL
409            )",
410                &[],
411            )
412            .await
413            .map_err(toasty_core::Error::driver_operation_failed)?;
414
415        // Query all applied migrations
416        let rows = self
417            .client
418            .query(
419                "SELECT id FROM __toasty_migrations ORDER BY applied_at",
420                &[],
421            )
422            .await
423            .map_err(toasty_core::Error::driver_operation_failed)?;
424
425        Ok(rows
426            .iter()
427            .map(|row| {
428                let id: i64 = row.get(0);
429                toasty_core::schema::db::AppliedMigration::new(id as u64)
430            })
431            .collect())
432    }
433
434    async fn apply_migration(
435        &mut self,
436        id: u64,
437        name: &str,
438        migration: &toasty_core::schema::db::Migration,
439    ) -> Result<()> {
440        tracing::info!(id = id, name = %name, "applying migration");
441        // Ensure the migrations table exists
442        self.client
443            .execute(
444                "CREATE TABLE IF NOT EXISTS __toasty_migrations (
445                id BIGINT PRIMARY KEY,
446                name TEXT NOT NULL,
447                applied_at TIMESTAMP NOT NULL
448            )",
449                &[],
450            )
451            .await
452            .map_err(toasty_core::Error::driver_operation_failed)?;
453
454        // Start transaction
455        let transaction = self
456            .client
457            .transaction()
458            .await
459            .map_err(toasty_core::Error::driver_operation_failed)?;
460
461        // Execute each migration statement
462        for statement in migration.statements() {
463            if let Err(e) = transaction
464                .batch_execute(statement)
465                .await
466                .map_err(toasty_core::Error::driver_operation_failed)
467            {
468                transaction
469                    .rollback()
470                    .await
471                    .map_err(toasty_core::Error::driver_operation_failed)?;
472                return Err(e);
473            }
474        }
475
476        // Record the migration
477        if let Err(e) = transaction
478            .execute(
479                "INSERT INTO __toasty_migrations (id, name, applied_at) VALUES ($1, $2, NOW())",
480                &[&(id as i64), &name],
481            )
482            .await
483            .map_err(toasty_core::Error::driver_operation_failed)
484        {
485            transaction
486                .rollback()
487                .await
488                .map_err(toasty_core::Error::driver_operation_failed)?;
489            return Err(e);
490        }
491
492        // Commit transaction
493        transaction
494            .commit()
495            .await
496            .map_err(toasty_core::Error::driver_operation_failed)?;
497        Ok(())
498    }
499}