Skip to main content

toasty_driver_postgresql/
lib.rs

1#![warn(missing_docs)]
2
3//! Toasty driver for [PostgreSQL](https://www.postgresql.org/) using
4//! [`tokio-postgres`](https://docs.rs/tokio-postgres).
5//!
6//! # Examples
7//!
8//! ```no_run
9//! use toasty_driver_postgresql::PostgreSQL;
10//!
11//! let driver = PostgreSQL::new("postgresql://localhost/mydb").unwrap();
12//! ```
13
14mod oid_cache;
15mod statement_cache;
16#[cfg(feature = "tls")]
17mod tls;
18mod r#type;
19mod value;
20
21pub(crate) use value::Value;
22
23use async_trait::async_trait;
24use percent_encoding::percent_decode_str;
25use std::{borrow::Cow, sync::Arc};
26use toasty_core::{
27    Result, Schema,
28    driver::{Capability, Driver, ExecResponse, Operation},
29    schema::db::{self, Migration, SchemaDiff, Table},
30    stmt,
31    stmt::ValueRecord,
32};
33use toasty_sql::{self as sql};
34use tokio_postgres::{Client, Config, Socket, tls::MakeTlsConnect, types::ToSql};
35use url::Url;
36
37use crate::{oid_cache::OidCache, statement_cache::StatementCache};
38
39/// A PostgreSQL [`Driver`] that connects via `tokio-postgres`.
40///
41/// # Examples
42///
43/// ```no_run
44/// use toasty_driver_postgresql::PostgreSQL;
45///
46/// let driver = PostgreSQL::new("postgresql://localhost/mydb").unwrap();
47/// ```
48#[derive(Debug)]
49pub struct PostgreSQL {
50    url: String,
51    config: Config,
52    #[cfg(feature = "tls")]
53    tls: Option<tls::MakeRustlsConnect>,
54}
55
56impl PostgreSQL {
57    /// Create a new PostgreSQL driver from a connection URL
58    pub fn new(url: impl Into<String>) -> Result<Self> {
59        let url_str = url.into();
60        let url = Url::parse(&url_str).map_err(toasty_core::Error::driver_operation_failed)?;
61
62        if !matches!(url.scheme(), "postgresql" | "postgres") {
63            return Err(toasty_core::Error::invalid_connection_url(format!(
64                "connection URL does not have a `postgresql` scheme; url={}",
65                url
66            )));
67        }
68
69        let host = url.host_str().ok_or_else(|| {
70            toasty_core::Error::invalid_connection_url(format!(
71                "missing host in connection URL; url={}",
72                url
73            ))
74        })?;
75
76        if url.path().is_empty() {
77            return Err(toasty_core::Error::invalid_connection_url(format!(
78                "no database specified - missing path in connection URL; url={}",
79                url
80            )));
81        }
82
83        let mut config = Config::new();
84        config.host(host);
85
86        let dbname = percent_decode_str(url.path().trim_start_matches('/'))
87            .decode_utf8()
88            .map_err(|_| {
89                toasty_core::Error::invalid_connection_url("database name is not valid UTF-8")
90            })?;
91        config.dbname(&*dbname);
92
93        if let Some(port) = url.port() {
94            config.port(port);
95        }
96
97        if !url.username().is_empty() {
98            let user = percent_decode_str(url.username())
99                .decode_utf8()
100                .map_err(|_| {
101                    toasty_core::Error::invalid_connection_url("username is not valid UTF-8")
102                })?;
103            config.user(&*user);
104        }
105
106        if let Some(password) = url.password() {
107            config.password(percent_decode_str(password).collect::<Vec<u8>>());
108        }
109
110        for (key, value) in url.query_pairs() {
111            if key == "application_name" {
112                config.application_name(&*value);
113            }
114        }
115
116        #[cfg(feature = "tls")]
117        let tls = tls::configure_tls(&url, &mut config)?;
118
119        #[cfg(not(feature = "tls"))]
120        for (key, value) in url.query_pairs() {
121            if key == "sslmode" && value != "disable" {
122                return Err(toasty_core::Error::invalid_connection_url(
123                    "TLS not available: compile with the `tls` feature",
124                ));
125            }
126        }
127
128        Ok(Self {
129            url: url_str,
130            config,
131            #[cfg(feature = "tls")]
132            tls,
133        })
134    }
135
136    async fn connect_with_config(&self, config: Config) -> Result<Connection> {
137        #[cfg(feature = "tls")]
138        if let Some(ref tls) = self.tls {
139            return Connection::connect(config, tls.clone()).await;
140        }
141        Connection::connect(config, tokio_postgres::NoTls).await
142    }
143}
144
145#[async_trait]
146impl Driver for PostgreSQL {
147    fn url(&self) -> Cow<'_, str> {
148        Cow::Borrowed(&self.url)
149    }
150
151    fn capability(&self) -> &'static Capability {
152        &Capability::POSTGRESQL
153    }
154
155    async fn connect(&self) -> toasty_core::Result<Box<dyn toasty_core::driver::Connection>> {
156        Ok(Box::new(
157            self.connect_with_config(self.config.clone()).await?,
158        ))
159    }
160
161    fn generate_migration(&self, schema_diff: &SchemaDiff<'_>) -> Migration {
162        let statements = sql::MigrationStatement::from_diff(schema_diff, &Capability::POSTGRESQL);
163
164        let sql_strings: Vec<String> = statements
165            .iter()
166            .map(|stmt| sql::Serializer::postgresql(stmt.schema()).serialize(stmt.statement()))
167            .collect();
168
169        Migration::new_sql(sql_strings.join("\n"))
170    }
171
172    async fn reset_db(&self) -> toasty_core::Result<()> {
173        let dbname = self
174            .config
175            .get_dbname()
176            .ok_or_else(|| {
177                toasty_core::Error::invalid_connection_url("no database name configured")
178            })?
179            .to_string();
180
181        // We cannot drop a database we are currently connected to, so we need a temp database.
182        let temp_dbname = "__toasty_reset_temp";
183
184        let connect = |dbname: &str| {
185            let mut config = self.config.clone();
186            config.dbname(dbname);
187            self.connect_with_config(config)
188        };
189
190        // Step 1: Connect to the target DB and create a temp DB
191        let conn = connect(&dbname).await?;
192        conn.client
193            .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
194            .await
195            .map_err(toasty_core::Error::driver_operation_failed)?;
196        conn.client
197            .execute(&format!("CREATE DATABASE \"{}\"", temp_dbname), &[])
198            .await
199            .map_err(toasty_core::Error::driver_operation_failed)?;
200        drop(conn);
201
202        // Step 2: Connect to the temp DB, drop and recreate the target
203        let conn = connect(temp_dbname).await?;
204        conn.client
205            .execute(
206                "SELECT pg_terminate_backend(pid) \
207                 FROM pg_stat_activity \
208                 WHERE datname = $1 AND pid <> pg_backend_pid()",
209                &[&dbname],
210            )
211            .await
212            .map_err(toasty_core::Error::driver_operation_failed)?;
213        conn.client
214            .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", dbname), &[])
215            .await
216            .map_err(toasty_core::Error::driver_operation_failed)?;
217        conn.client
218            .execute(&format!("CREATE DATABASE \"{}\"", dbname), &[])
219            .await
220            .map_err(toasty_core::Error::driver_operation_failed)?;
221        drop(conn);
222
223        // Step 3: Connect back to the target and clean up the temp DB
224        let conn = connect(&dbname).await?;
225        conn.client
226            .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
227            .await
228            .map_err(toasty_core::Error::driver_operation_failed)?;
229
230        Ok(())
231    }
232}
233
234/// An open connection to a PostgreSQL database.
235#[derive(Debug)]
236pub struct Connection {
237    client: Client,
238    statement_cache: StatementCache,
239    oid_cache: OidCache,
240}
241
242impl Connection {
243    /// Initialize a Toasty PostgreSQL connection using an initialized client.
244    pub fn new(client: Client) -> Self {
245        Self {
246            client,
247            statement_cache: StatementCache::new(100),
248            oid_cache: OidCache::new(),
249        }
250    }
251
252    /// Connects to a PostgreSQL database using a [`postgres::Config`].
253    ///
254    /// See [`postgres::Client::configure`] for more information.
255    pub async fn connect<T>(config: Config, tls: T) -> Result<Self>
256    where
257        T: MakeTlsConnect<Socket> + 'static,
258        T::Stream: Send,
259    {
260        let (client, connection) = config
261            .connect(tls)
262            .await
263            .map_err(toasty_core::Error::driver_operation_failed)?;
264
265        tokio::spawn(async move {
266            if let Err(e) = connection.await {
267                eprintln!("connection error: {e}");
268            }
269        });
270
271        Ok(Self::new(client))
272    }
273
274    /// Creates a table.
275    pub async fn create_table(&mut self, schema: &db::Schema, table: &Table) -> Result<()> {
276        let serializer = sql::Serializer::postgresql(schema);
277
278        let sql = serializer.serialize(&sql::Statement::create_table(
279            table,
280            &Capability::POSTGRESQL,
281        ));
282
283        self.client
284            .execute(&sql, &[])
285            .await
286            .map_err(toasty_core::Error::driver_operation_failed)?;
287
288        for index in &table.indices {
289            if index.primary_key {
290                continue;
291            }
292
293            let sql = serializer.serialize(&sql::Statement::create_index(index));
294
295            self.client
296                .execute(&sql, &[])
297                .await
298                .map_err(toasty_core::Error::driver_operation_failed)?;
299        }
300
301        Ok(())
302    }
303}
304
305impl From<Client> for Connection {
306    fn from(client: Client) -> Self {
307        Self::new(client)
308    }
309}
310
311#[async_trait]
312impl toasty_core::driver::Connection for Connection {
313    async fn exec(&mut self, schema: &Arc<Schema>, op: Operation) -> Result<ExecResponse> {
314        tracing::trace!(driver = "postgresql", op = %op.name(), "driver exec");
315
316        if let Operation::Transaction(ref t) = op {
317            let sql = sql::Serializer::postgresql(&schema.db).serialize_transaction(t);
318            self.client.batch_execute(&sql).await.map_err(|e| {
319                if let Some(db_err) = e.as_db_error() {
320                    match db_err.code().code() {
321                        "40001" => toasty_core::Error::serialization_failure(db_err.message()),
322                        "25006" => toasty_core::Error::read_only_transaction(db_err.message()),
323                        _ => toasty_core::Error::driver_operation_failed(e),
324                    }
325                } else {
326                    toasty_core::Error::driver_operation_failed(e)
327                }
328            })?;
329            return Ok(ExecResponse::count(0));
330        }
331
332        let (sql, typed_params, ret_tys) = match op {
333            Operation::Insert(op) => (sql::Statement::from(op.stmt), op.params, None),
334            Operation::QuerySql(query) => {
335                assert!(
336                    query.last_insert_id_hack.is_none(),
337                    "last_insert_id_hack is MySQL-specific and should not be set for PostgreSQL"
338                );
339                (sql::Statement::from(query.stmt), query.params, query.ret)
340            }
341            op => todo!("op={:#?}", op),
342        };
343
344        let width = sql.returning_len();
345
346        let sql_as_str = sql::Serializer::postgresql(&schema.db).serialize(&sql);
347
348        tracing::debug!(db.system = "postgresql", db.statement = %sql_as_str, params = typed_params.len(), "executing SQL");
349
350        self.oid_cache
351            .preload(&self.client, typed_params.iter().map(|tv| &tv.ty))
352            .await?;
353        let param_types: Vec<_> = typed_params
354            .iter()
355            .map(|tv| self.oid_cache.get(&tv.ty))
356            .collect();
357
358        let values: Vec<_> = typed_params
359            .into_iter()
360            .map(|tv| Value::from(tv.value))
361            .collect();
362        let params = values
363            .iter()
364            .map(|param| param as &(dyn ToSql + Sync))
365            .collect::<Vec<_>>();
366
367        let statement = self
368            .statement_cache
369            .prepare_typed(&mut self.client, &sql_as_str, &param_types)
370            .await
371            .map_err(toasty_core::Error::driver_operation_failed)?;
372
373        if width.is_none() {
374            let count = self
375                .client
376                .execute(&statement, &params)
377                .await
378                .map_err(toasty_core::Error::driver_operation_failed)?;
379            return Ok(ExecResponse::count(count));
380        }
381
382        let rows = self
383            .client
384            .query(&statement, &params)
385            .await
386            .map_err(toasty_core::Error::driver_operation_failed)?;
387
388        if width.is_none() {
389            let [row] = &rows[..] else { todo!() };
390            let total = row.get::<usize, i64>(0);
391            let condition_matched = row.get::<usize, i64>(1);
392
393            if total == condition_matched {
394                Ok(ExecResponse::count(total as _))
395            } else {
396                Err(toasty_core::Error::condition_failed(
397                    "update condition did not match",
398                ))
399            }
400        } else {
401            let ret_tys = ret_tys.as_ref().unwrap().clone();
402            let results = rows.into_iter().map(move |row| {
403                let mut results = Vec::new();
404                for (i, column) in row.columns().iter().enumerate() {
405                    results.push(Value::from_sql(i, &row, column, &ret_tys[i]).into_inner());
406                }
407
408                Ok(ValueRecord::from_vec(results))
409            });
410
411            Ok(ExecResponse::value_stream(stmt::ValueStream::from_iter(
412                results,
413            )))
414        }
415    }
416
417    async fn push_schema(&mut self, schema: &Schema) -> Result<()> {
418        let serializer = sql::Serializer::postgresql(&schema.db);
419
420        // Create PostgreSQL enum types before creating tables.
421        // Collect unique enum types across all columns.
422        let mut created_enum_types = hashbrown::HashSet::new();
423        for table in &schema.db.tables {
424            for column in &table.columns {
425                if let toasty_core::schema::db::Type::Enum(type_enum) = &column.storage_ty
426                    && created_enum_types.insert(type_enum.name.clone())
427                {
428                    let sql = serializer.serialize(&sql::Statement::create_enum_type(type_enum));
429
430                    tracing::debug!(enum_type = ?type_enum.name, "creating enum type");
431                    self.client
432                        .execute(&sql, &[])
433                        .await
434                        .map_err(toasty_core::Error::driver_operation_failed)?;
435                }
436            }
437        }
438
439        for table in &schema.db.tables {
440            tracing::debug!(table = %table.name, "creating table");
441            self.create_table(&schema.db, table).await?;
442        }
443        Ok(())
444    }
445
446    async fn applied_migrations(
447        &mut self,
448    ) -> Result<Vec<toasty_core::schema::db::AppliedMigration>> {
449        // Ensure the migrations table exists
450        self.client
451            .execute(
452                "CREATE TABLE IF NOT EXISTS __toasty_migrations (
453                id BIGINT PRIMARY KEY,
454                name TEXT NOT NULL,
455                applied_at TIMESTAMP NOT NULL
456            )",
457                &[],
458            )
459            .await
460            .map_err(toasty_core::Error::driver_operation_failed)?;
461
462        // Query all applied migrations
463        let rows = self
464            .client
465            .query(
466                "SELECT id FROM __toasty_migrations ORDER BY applied_at",
467                &[],
468            )
469            .await
470            .map_err(toasty_core::Error::driver_operation_failed)?;
471
472        Ok(rows
473            .iter()
474            .map(|row| {
475                let id: i64 = row.get(0);
476                toasty_core::schema::db::AppliedMigration::new(id as u64)
477            })
478            .collect())
479    }
480
481    async fn apply_migration(
482        &mut self,
483        id: u64,
484        name: &str,
485        migration: &toasty_core::schema::db::Migration,
486    ) -> Result<()> {
487        tracing::info!(id = id, name = %name, "applying migration");
488        // Ensure the migrations table exists
489        self.client
490            .execute(
491                "CREATE TABLE IF NOT EXISTS __toasty_migrations (
492                id BIGINT PRIMARY KEY,
493                name TEXT NOT NULL,
494                applied_at TIMESTAMP NOT NULL
495            )",
496                &[],
497            )
498            .await
499            .map_err(toasty_core::Error::driver_operation_failed)?;
500
501        // Start transaction
502        let transaction = self
503            .client
504            .transaction()
505            .await
506            .map_err(toasty_core::Error::driver_operation_failed)?;
507
508        // Execute each migration statement
509        for statement in migration.statements() {
510            if let Err(e) = transaction
511                .batch_execute(statement)
512                .await
513                .map_err(toasty_core::Error::driver_operation_failed)
514            {
515                transaction
516                    .rollback()
517                    .await
518                    .map_err(toasty_core::Error::driver_operation_failed)?;
519                return Err(e);
520            }
521        }
522
523        // Record the migration
524        if let Err(e) = transaction
525            .execute(
526                "INSERT INTO __toasty_migrations (id, name, applied_at) VALUES ($1, $2, NOW())",
527                &[&(id as i64), &name],
528            )
529            .await
530            .map_err(toasty_core::Error::driver_operation_failed)
531        {
532            transaction
533                .rollback()
534                .await
535                .map_err(toasty_core::Error::driver_operation_failed)?;
536            return Err(e);
537        }
538
539        // Commit transaction
540        transaction
541            .commit()
542            .await
543            .map_err(toasty_core::Error::driver_operation_failed)?;
544        Ok(())
545    }
546}