Skip to main content

toasty_driver_postgresql/
lib.rs

1#![warn(missing_docs)]
2
3//! Toasty driver for [PostgreSQL](https://www.postgresql.org/) using
4//! [`tokio-postgres`](https://docs.rs/tokio-postgres).
5//!
6//! # Examples
7//!
8//! ```no_run
9//! use toasty_driver_postgresql::PostgreSQL;
10//!
11//! let driver = PostgreSQL::new("postgresql://localhost/mydb").unwrap();
12//! ```
13
14mod oid_cache;
15mod statement_cache;
16#[cfg(feature = "tls")]
17mod tls;
18mod r#type;
19mod value;
20
21pub(crate) use value::Value;
22
23use async_trait::async_trait;
24use percent_encoding::percent_decode_str;
25use std::{borrow::Cow, sync::Arc};
26use toasty_core::{
27    Result, Schema,
28    driver::{
29        Capability, Driver, ExecResponse, Operation,
30        operation::{Transaction, TransactionMode},
31    },
32    schema::{
33        db::{self, Migration, Table},
34        diff,
35    },
36    stmt,
37    stmt::ValueRecord,
38};
39use toasty_sql::{self as sql};
40use tokio_postgres::{Client, Config, Socket, tls::MakeTlsConnect, types::ToSql};
41use url::Url;
42
43use crate::{oid_cache::OidCache, statement_cache::StatementCache};
44
45/// Classifies a `tokio_postgres::Error` into a Toasty error.
46///
47/// Errors that carry a server-side `DbError` are mapped to typed
48/// variants where one exists (`SerializationFailure`,
49/// `ReadOnlyTransaction`); everything else with a `DbError` becomes
50/// `DriverOperationFailed`. Errors *without* a `DbError` are
51/// classified as `ConnectionLost`: per `tokio-postgres`, those
52/// originate from the underlying socket or protocol layer (closed
53/// socket, IO error, end-of-stream during handshake), which the
54/// pool treats as evictable.
55fn classify_pg_error(e: tokio_postgres::Error) -> toasty_core::Error {
56    if let Some(db_err) = e.as_db_error() {
57        match db_err.code().code() {
58            "40001" => toasty_core::Error::serialization_failure(db_err.message()),
59            "25006" => toasty_core::Error::read_only_transaction(db_err.message()),
60            _ => toasty_core::Error::driver_operation_failed(e),
61        }
62    } else {
63        toasty_core::Error::connection_lost(e)
64    }
65}
66
67/// A PostgreSQL [`Driver`] that connects via `tokio-postgres`.
68///
69/// # Examples
70///
71/// ```no_run
72/// use toasty_driver_postgresql::PostgreSQL;
73///
74/// let driver = PostgreSQL::new("postgresql://localhost/mydb").unwrap();
75/// ```
76#[derive(Debug)]
77pub struct PostgreSQL {
78    url: String,
79    config: Config,
80    #[cfg(feature = "tls")]
81    tls: Option<tls::MakeRustlsConnect>,
82}
83
84impl PostgreSQL {
85    /// Create a new PostgreSQL driver from a connection URL
86    pub fn new(url: impl Into<String>) -> Result<Self> {
87        let url_str = url.into();
88        let url = Url::parse(&url_str).map_err(toasty_core::Error::driver_operation_failed)?;
89
90        if !matches!(url.scheme(), "postgresql" | "postgres") {
91            return Err(toasty_core::Error::invalid_connection_url(format!(
92                "connection URL does not have a `postgresql` scheme; url={}",
93                url
94            )));
95        }
96
97        let host = url.host_str().ok_or_else(|| {
98            toasty_core::Error::invalid_connection_url(format!(
99                "missing host in connection URL; url={}",
100                url
101            ))
102        })?;
103
104        if url.path().is_empty() {
105            return Err(toasty_core::Error::invalid_connection_url(format!(
106                "no database specified - missing path in connection URL; url={}",
107                url
108            )));
109        }
110
111        let mut config = Config::new();
112        config.host(host);
113
114        let dbname = percent_decode_str(url.path().trim_start_matches('/'))
115            .decode_utf8()
116            .map_err(|_| {
117                toasty_core::Error::invalid_connection_url("database name is not valid UTF-8")
118            })?;
119        config.dbname(&*dbname);
120
121        if let Some(port) = url.port() {
122            config.port(port);
123        }
124
125        if !url.username().is_empty() {
126            let user = percent_decode_str(url.username())
127                .decode_utf8()
128                .map_err(|_| {
129                    toasty_core::Error::invalid_connection_url("username is not valid UTF-8")
130                })?;
131            config.user(&*user);
132        }
133
134        if let Some(password) = url.password() {
135            config.password(percent_decode_str(password).collect::<Vec<u8>>());
136        }
137
138        for (key, value) in url.query_pairs() {
139            if key == "application_name" {
140                config.application_name(&*value);
141            }
142        }
143
144        #[cfg(feature = "tls")]
145        let tls = tls::configure_tls(&url, &mut config)?;
146
147        #[cfg(not(feature = "tls"))]
148        for (key, value) in url.query_pairs() {
149            if key == "sslmode" && value != "disable" {
150                return Err(toasty_core::Error::invalid_connection_url(
151                    "TLS not available: compile with the `tls` feature",
152                ));
153            }
154        }
155
156        Ok(Self {
157            url: url_str,
158            config,
159            #[cfg(feature = "tls")]
160            tls,
161        })
162    }
163
164    async fn connect_with_config(&self, config: Config) -> Result<Connection> {
165        #[cfg(feature = "tls")]
166        if let Some(ref tls) = self.tls {
167            return Connection::connect(config, tls.clone()).await;
168        }
169        Connection::connect(config, tokio_postgres::NoTls).await
170    }
171}
172
173#[async_trait]
174impl Driver for PostgreSQL {
175    fn url(&self) -> Cow<'_, str> {
176        Cow::Borrowed(&self.url)
177    }
178
179    fn capability(&self) -> &'static Capability {
180        &Capability::POSTGRESQL
181    }
182
183    async fn connect(&self) -> toasty_core::Result<Box<dyn toasty_core::driver::Connection>> {
184        Ok(Box::new(
185            self.connect_with_config(self.config.clone()).await?,
186        ))
187    }
188
189    fn generate_migration(&self, schema_diff: &diff::Schema<'_>) -> Migration {
190        let statements = sql::MigrationStatement::from_diff(schema_diff, &Capability::POSTGRESQL);
191
192        let sql_strings: Vec<String> = statements
193            .iter()
194            .map(|stmt| sql::Serializer::postgresql(stmt.schema()).serialize(stmt.statement()))
195            .collect();
196
197        Migration::new_sql(sql_strings.join("\n"))
198    }
199
200    async fn reset_db(&self) -> toasty_core::Result<()> {
201        let dbname = self
202            .config
203            .get_dbname()
204            .ok_or_else(|| {
205                toasty_core::Error::invalid_connection_url("no database name configured")
206            })?
207            .to_string();
208
209        // We cannot drop a database we are currently connected to, so we need a temp database.
210        let temp_dbname = "__toasty_reset_temp";
211
212        let connect = |dbname: &str| {
213            let mut config = self.config.clone();
214            config.dbname(dbname);
215            self.connect_with_config(config)
216        };
217
218        // Step 1: Connect to the target DB and create a temp DB
219        let conn = connect(&dbname).await?;
220        conn.client
221            .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
222            .await
223            .map_err(classify_pg_error)?;
224        conn.client
225            .execute(&format!("CREATE DATABASE \"{}\"", temp_dbname), &[])
226            .await
227            .map_err(classify_pg_error)?;
228        drop(conn);
229
230        // Step 2: Connect to the temp DB, drop and recreate the target
231        let conn = connect(temp_dbname).await?;
232        conn.client
233            .execute(
234                "SELECT pg_terminate_backend(pid) \
235                 FROM pg_stat_activity \
236                 WHERE datname = $1 AND pid <> pg_backend_pid()",
237                &[&dbname],
238            )
239            .await
240            .map_err(classify_pg_error)?;
241        conn.client
242            .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", dbname), &[])
243            .await
244            .map_err(classify_pg_error)?;
245        conn.client
246            .execute(&format!("CREATE DATABASE \"{}\"", dbname), &[])
247            .await
248            .map_err(classify_pg_error)?;
249        drop(conn);
250
251        // Step 3: Connect back to the target and clean up the temp DB
252        let conn = connect(&dbname).await?;
253        conn.client
254            .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
255            .await
256            .map_err(classify_pg_error)?;
257
258        Ok(())
259    }
260}
261
262/// An open connection to a PostgreSQL database.
263#[derive(Debug)]
264pub struct Connection {
265    client: Client,
266    statement_cache: StatementCache,
267    oid_cache: OidCache,
268}
269
270impl Connection {
271    /// Initialize a Toasty PostgreSQL connection using an initialized client.
272    pub fn new(client: Client) -> Self {
273        Self {
274            client,
275            statement_cache: StatementCache::new(100),
276            oid_cache: OidCache::new(),
277        }
278    }
279
280    /// Connects to a PostgreSQL database using a [`postgres::Config`].
281    ///
282    /// See [`postgres::Client::configure`] for more information.
283    pub async fn connect<T>(config: Config, tls: T) -> Result<Self>
284    where
285        T: MakeTlsConnect<Socket> + 'static,
286        T::Stream: Send,
287    {
288        let (client, connection) = config.connect(tls).await.map_err(classify_pg_error)?;
289
290        tokio::spawn(async move {
291            if let Err(e) = connection.await {
292                eprintln!("connection error: {e}");
293            }
294        });
295
296        Ok(Self::new(client))
297    }
298
299    /// Creates a table.
300    pub async fn create_table(&mut self, schema: &db::Schema, table: &Table) -> Result<()> {
301        let serializer = sql::Serializer::postgresql(schema);
302
303        let sql = serializer.serialize(&sql::Statement::create_table(
304            table,
305            &Capability::POSTGRESQL,
306        ));
307
308        self.client
309            .execute(&sql, &[])
310            .await
311            .map_err(classify_pg_error)?;
312
313        for index in &table.indices {
314            if index.primary_key {
315                continue;
316            }
317
318            let sql = serializer.serialize(&sql::Statement::create_index(index));
319
320            self.client
321                .execute(&sql, &[])
322                .await
323                .map_err(classify_pg_error)?;
324        }
325
326        Ok(())
327    }
328}
329
330impl From<Client> for Connection {
331    fn from(client: Client) -> Self {
332        Self::new(client)
333    }
334}
335
336#[async_trait]
337impl toasty_core::driver::Connection for Connection {
338    async fn exec(&mut self, schema: &Arc<Schema>, op: Operation) -> Result<ExecResponse> {
339        tracing::trace!(driver = "postgresql", op = %op.name(), "driver exec");
340
341        if let Operation::Transaction(ref t) = op {
342            // PostgreSQL has no `BEGIN IMMEDIATE` / `BEGIN EXCLUSIVE`
343            // analogue; reject non-Default modes loudly rather than
344            // silently dropping them at the serializer.
345            if let Transaction::Start {
346                mode: mode @ (TransactionMode::Immediate | TransactionMode::Exclusive),
347                ..
348            } = t
349            {
350                return Err(toasty_core::Error::unsupported_feature(format!(
351                    "PostgreSQL does not support TransactionMode::{mode:?}"
352                )));
353            }
354            let sql = sql::Serializer::postgresql(&schema.db).serialize_transaction(t);
355            self.client
356                .batch_execute(&sql)
357                .await
358                .map_err(classify_pg_error)?;
359            return Ok(ExecResponse::count(0));
360        }
361
362        let (sql, typed_params, ret_tys) = match op {
363            Operation::Insert(op) => (sql::Statement::from(op.stmt), op.params, None),
364            Operation::QuerySql(query) => {
365                assert!(
366                    query.last_insert_id_hack.is_none(),
367                    "last_insert_id_hack is MySQL-specific and should not be set for PostgreSQL"
368                );
369                (sql::Statement::from(query.stmt), query.params, query.ret)
370            }
371            op => todo!("op={:#?}", op),
372        };
373
374        let width = sql.returning_len();
375
376        let sql_as_str = sql::Serializer::postgresql(&schema.db).serialize(&sql);
377
378        tracing::debug!(db.system = "postgresql", db.statement = %sql_as_str, params = typed_params.len(), "executing SQL");
379
380        self.oid_cache
381            .preload(&self.client, typed_params.iter().map(|tv| &tv.ty))
382            .await?;
383        let param_types: Vec<_> = typed_params
384            .iter()
385            .map(|tv| self.oid_cache.get(&tv.ty).clone())
386            .collect();
387
388        let values: Vec<_> = typed_params
389            .into_iter()
390            .map(|tv| Value::from(tv.value))
391            .collect();
392        let params = values
393            .iter()
394            .map(|param| param as &(dyn ToSql + Sync))
395            .collect::<Vec<_>>();
396
397        let statement = self
398            .statement_cache
399            .prepare_typed(&mut self.client, &sql_as_str, &param_types)
400            .await
401            .map_err(classify_pg_error)?;
402
403        if width.is_none() {
404            let count = self
405                .client
406                .execute(&statement, &params)
407                .await
408                .map_err(classify_pg_error)?;
409            return Ok(ExecResponse::count(count));
410        }
411
412        let rows = self
413            .client
414            .query(&statement, &params)
415            .await
416            .map_err(classify_pg_error)?;
417
418        if width.is_none() {
419            let [row] = &rows[..] else { todo!() };
420            let total = row.get::<usize, i64>(0);
421            let condition_matched = row.get::<usize, i64>(1);
422
423            if total == condition_matched {
424                Ok(ExecResponse::count(total as _))
425            } else {
426                Err(toasty_core::Error::condition_failed(
427                    "update condition did not match",
428                ))
429            }
430        } else {
431            let ret_tys = ret_tys.as_ref().unwrap().clone();
432            let results = rows.into_iter().map(move |row| {
433                let mut results = Vec::new();
434                for (i, column) in row.columns().iter().enumerate() {
435                    results.push(Value::from_sql(i, &row, column, &ret_tys[i]).into_inner());
436                }
437
438                Ok(ValueRecord::from_vec(results))
439            });
440
441            Ok(ExecResponse::value_stream(stmt::ValueStream::from_iter(
442                results,
443            )))
444        }
445    }
446
447    async fn push_schema(&mut self, schema: &Schema) -> Result<()> {
448        let serializer = sql::Serializer::postgresql(&schema.db);
449
450        // Create PostgreSQL enum types before creating tables.
451        // Collect unique enum types across all columns.
452        let mut created_enum_types = hashbrown::HashSet::new();
453        for table in &schema.db.tables {
454            for column in &table.columns {
455                if let toasty_core::schema::db::Type::Enum(type_enum) = &column.storage_ty
456                    && created_enum_types.insert(type_enum.name.clone())
457                {
458                    let sql = serializer.serialize(&sql::Statement::create_enum_type(type_enum));
459
460                    tracing::debug!(enum_type = ?type_enum.name, "creating enum type");
461                    self.client
462                        .execute(&sql, &[])
463                        .await
464                        .map_err(classify_pg_error)?;
465                }
466            }
467        }
468
469        for table in &schema.db.tables {
470            tracing::debug!(table = %table.name, "creating table");
471            self.create_table(&schema.db, table).await?;
472        }
473        Ok(())
474    }
475
476    async fn applied_migrations(
477        &mut self,
478    ) -> Result<Vec<toasty_core::schema::db::AppliedMigration>> {
479        // Ensure the migrations table exists
480        self.client
481            .execute(
482                "CREATE TABLE IF NOT EXISTS __toasty_migrations (
483                id BIGINT PRIMARY KEY,
484                name TEXT NOT NULL,
485                applied_at TIMESTAMP NOT NULL
486            )",
487                &[],
488            )
489            .await
490            .map_err(classify_pg_error)?;
491
492        // Query all applied migrations
493        let rows = self
494            .client
495            .query(
496                "SELECT id FROM __toasty_migrations ORDER BY applied_at",
497                &[],
498            )
499            .await
500            .map_err(classify_pg_error)?;
501
502        Ok(rows
503            .iter()
504            .map(|row| {
505                let id: i64 = row.get(0);
506                toasty_core::schema::db::AppliedMigration::new(id as u64)
507            })
508            .collect())
509    }
510
511    async fn apply_migration(
512        &mut self,
513        id: u64,
514        name: &str,
515        migration: &toasty_core::schema::db::Migration,
516    ) -> Result<()> {
517        tracing::info!(id = id, name = %name, "applying migration");
518        // Ensure the migrations table exists
519        self.client
520            .execute(
521                "CREATE TABLE IF NOT EXISTS __toasty_migrations (
522                id BIGINT PRIMARY KEY,
523                name TEXT NOT NULL,
524                applied_at TIMESTAMP NOT NULL
525            )",
526                &[],
527            )
528            .await
529            .map_err(classify_pg_error)?;
530
531        // Start transaction
532        let transaction = self.client.transaction().await.map_err(classify_pg_error)?;
533
534        // Execute each migration statement
535        for statement in migration.statements() {
536            if let Err(e) = transaction
537                .batch_execute(statement)
538                .await
539                .map_err(classify_pg_error)
540            {
541                transaction.rollback().await.map_err(classify_pg_error)?;
542                return Err(e);
543            }
544        }
545
546        // Record the migration
547        if let Err(e) = transaction
548            .execute(
549                "INSERT INTO __toasty_migrations (id, name, applied_at) VALUES ($1, $2, NOW())",
550                &[&(id as i64), &name],
551            )
552            .await
553            .map_err(classify_pg_error)
554        {
555            transaction.rollback().await.map_err(classify_pg_error)?;
556            return Err(e);
557        }
558
559        // Commit transaction
560        transaction.commit().await.map_err(classify_pg_error)?;
561        Ok(())
562    }
563
564    fn is_valid(&self) -> bool {
565        !self.client.is_closed()
566    }
567
568    async fn ping(&mut self) -> Result<()> {
569        // An empty `simple_query` is the lightest sync round-trip in
570        // the PG protocol — it skips parsing entirely. Any failure is
571        // surfaced as `connection_lost`: the only meaningful outcome
572        // of a ping is "the connection is alive" or "evict it."
573        self.client
574            .simple_query("")
575            .await
576            .map(|_| ())
577            .map_err(toasty_core::Error::connection_lost)
578    }
579}