toasty_driver_postgresql/
lib.rs1#![warn(missing_docs)]
2
3mod oid_cache;
15mod statement_cache;
16#[cfg(feature = "tls")]
17mod tls;
18mod r#type;
19mod value;
20
21pub(crate) use value::Value;
22
23use async_trait::async_trait;
24use percent_encoding::percent_decode_str;
25use std::{borrow::Cow, sync::Arc};
26use toasty_core::{
27 Result, Schema,
28 driver::{
29 Capability, Driver, ExecResponse, Operation,
30 operation::{Transaction, TransactionMode},
31 },
32 schema::{
33 db::{self, Migration, Table},
34 diff,
35 },
36 stmt,
37 stmt::ValueRecord,
38};
39use toasty_sql::{self as sql};
40use tokio_postgres::{Client, Config, Socket, tls::MakeTlsConnect, types::ToSql};
41use url::Url;
42
43use crate::{oid_cache::OidCache, statement_cache::StatementCache};
44
45fn classify_pg_error(e: tokio_postgres::Error) -> toasty_core::Error {
56 if let Some(db_err) = e.as_db_error() {
57 match db_err.code().code() {
58 "40001" => toasty_core::Error::serialization_failure(db_err.message()),
59 "25006" => toasty_core::Error::read_only_transaction(db_err.message()),
60 _ => toasty_core::Error::driver_operation_failed(e),
61 }
62 } else {
63 toasty_core::Error::connection_lost(e)
64 }
65}
66
67#[derive(Debug)]
77pub struct PostgreSQL {
78 url: String,
79 config: Config,
80 #[cfg(feature = "tls")]
81 tls: Option<tls::MakeRustlsConnect>,
82}
83
84impl PostgreSQL {
85 pub fn new(url: impl Into<String>) -> Result<Self> {
87 let url_str = url.into();
88 let url = Url::parse(&url_str).map_err(toasty_core::Error::driver_operation_failed)?;
89
90 if !matches!(url.scheme(), "postgresql" | "postgres") {
91 return Err(toasty_core::Error::invalid_connection_url(format!(
92 "connection URL does not have a `postgresql` scheme; url={}",
93 url
94 )));
95 }
96
97 let host = url.host_str().ok_or_else(|| {
98 toasty_core::Error::invalid_connection_url(format!(
99 "missing host in connection URL; url={}",
100 url
101 ))
102 })?;
103
104 if url.path().is_empty() {
105 return Err(toasty_core::Error::invalid_connection_url(format!(
106 "no database specified - missing path in connection URL; url={}",
107 url
108 )));
109 }
110
111 let mut config = Config::new();
112 config.host(host);
113
114 let dbname = percent_decode_str(url.path().trim_start_matches('/'))
115 .decode_utf8()
116 .map_err(|_| {
117 toasty_core::Error::invalid_connection_url("database name is not valid UTF-8")
118 })?;
119 config.dbname(&*dbname);
120
121 if let Some(port) = url.port() {
122 config.port(port);
123 }
124
125 if !url.username().is_empty() {
126 let user = percent_decode_str(url.username())
127 .decode_utf8()
128 .map_err(|_| {
129 toasty_core::Error::invalid_connection_url("username is not valid UTF-8")
130 })?;
131 config.user(&*user);
132 }
133
134 if let Some(password) = url.password() {
135 config.password(percent_decode_str(password).collect::<Vec<u8>>());
136 }
137
138 for (key, value) in url.query_pairs() {
139 if key == "application_name" {
140 config.application_name(&*value);
141 }
142 }
143
144 #[cfg(feature = "tls")]
145 let tls = tls::configure_tls(&url, &mut config)?;
146
147 #[cfg(not(feature = "tls"))]
148 for (key, value) in url.query_pairs() {
149 if key == "sslmode" && value != "disable" {
150 return Err(toasty_core::Error::invalid_connection_url(
151 "TLS not available: compile with the `tls` feature",
152 ));
153 }
154 }
155
156 Ok(Self {
157 url: url_str,
158 config,
159 #[cfg(feature = "tls")]
160 tls,
161 })
162 }
163
164 async fn connect_with_config(&self, config: Config) -> Result<Connection> {
165 #[cfg(feature = "tls")]
166 if let Some(ref tls) = self.tls {
167 return Connection::connect(config, tls.clone()).await;
168 }
169 Connection::connect(config, tokio_postgres::NoTls).await
170 }
171}
172
173#[async_trait]
174impl Driver for PostgreSQL {
175 fn url(&self) -> Cow<'_, str> {
176 Cow::Borrowed(&self.url)
177 }
178
179 fn capability(&self) -> &'static Capability {
180 &Capability::POSTGRESQL
181 }
182
183 async fn connect(&self) -> toasty_core::Result<Box<dyn toasty_core::driver::Connection>> {
184 Ok(Box::new(
185 self.connect_with_config(self.config.clone()).await?,
186 ))
187 }
188
189 fn generate_migration(&self, schema_diff: &diff::Schema<'_>) -> Migration {
190 let statements = sql::MigrationStatement::from_diff(schema_diff, &Capability::POSTGRESQL);
191
192 let sql_strings: Vec<String> = statements
193 .iter()
194 .map(|stmt| sql::Serializer::postgresql(stmt.schema()).serialize(stmt.statement()))
195 .collect();
196
197 Migration::new_sql(sql_strings.join("\n"))
198 }
199
200 async fn reset_db(&self) -> toasty_core::Result<()> {
201 let dbname = self
202 .config
203 .get_dbname()
204 .ok_or_else(|| {
205 toasty_core::Error::invalid_connection_url("no database name configured")
206 })?
207 .to_string();
208
209 let temp_dbname = "__toasty_reset_temp";
211
212 let connect = |dbname: &str| {
213 let mut config = self.config.clone();
214 config.dbname(dbname);
215 self.connect_with_config(config)
216 };
217
218 let conn = connect(&dbname).await?;
220 conn.client
221 .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
222 .await
223 .map_err(classify_pg_error)?;
224 conn.client
225 .execute(&format!("CREATE DATABASE \"{}\"", temp_dbname), &[])
226 .await
227 .map_err(classify_pg_error)?;
228 drop(conn);
229
230 let conn = connect(temp_dbname).await?;
232 conn.client
233 .execute(
234 "SELECT pg_terminate_backend(pid) \
235 FROM pg_stat_activity \
236 WHERE datname = $1 AND pid <> pg_backend_pid()",
237 &[&dbname],
238 )
239 .await
240 .map_err(classify_pg_error)?;
241 conn.client
242 .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", dbname), &[])
243 .await
244 .map_err(classify_pg_error)?;
245 conn.client
246 .execute(&format!("CREATE DATABASE \"{}\"", dbname), &[])
247 .await
248 .map_err(classify_pg_error)?;
249 drop(conn);
250
251 let conn = connect(&dbname).await?;
253 conn.client
254 .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
255 .await
256 .map_err(classify_pg_error)?;
257
258 Ok(())
259 }
260}
261
262#[derive(Debug)]
264pub struct Connection {
265 client: Client,
266 statement_cache: StatementCache,
267 oid_cache: OidCache,
268}
269
270impl Connection {
271 pub fn new(client: Client) -> Self {
273 Self {
274 client,
275 statement_cache: StatementCache::new(100),
276 oid_cache: OidCache::new(),
277 }
278 }
279
280 pub async fn connect<T>(config: Config, tls: T) -> Result<Self>
284 where
285 T: MakeTlsConnect<Socket> + 'static,
286 T::Stream: Send,
287 {
288 let (client, connection) = config.connect(tls).await.map_err(classify_pg_error)?;
289
290 tokio::spawn(async move {
291 if let Err(e) = connection.await {
292 eprintln!("connection error: {e}");
293 }
294 });
295
296 Ok(Self::new(client))
297 }
298
299 pub async fn create_table(&mut self, schema: &db::Schema, table: &Table) -> Result<()> {
301 let serializer = sql::Serializer::postgresql(schema);
302
303 let sql = serializer.serialize(&sql::Statement::create_table(
304 table,
305 &Capability::POSTGRESQL,
306 ));
307
308 self.client
309 .execute(&sql, &[])
310 .await
311 .map_err(classify_pg_error)?;
312
313 for index in &table.indices {
314 if index.primary_key {
315 continue;
316 }
317
318 let sql = serializer.serialize(&sql::Statement::create_index(index));
319
320 self.client
321 .execute(&sql, &[])
322 .await
323 .map_err(classify_pg_error)?;
324 }
325
326 Ok(())
327 }
328}
329
330impl From<Client> for Connection {
331 fn from(client: Client) -> Self {
332 Self::new(client)
333 }
334}
335
336#[async_trait]
337impl toasty_core::driver::Connection for Connection {
338 async fn exec(&mut self, schema: &Arc<Schema>, op: Operation) -> Result<ExecResponse> {
339 tracing::trace!(driver = "postgresql", op = %op.name(), "driver exec");
340
341 if let Operation::Transaction(ref t) = op {
342 if let Transaction::Start {
346 mode: mode @ (TransactionMode::Immediate | TransactionMode::Exclusive),
347 ..
348 } = t
349 {
350 return Err(toasty_core::Error::unsupported_feature(format!(
351 "PostgreSQL does not support TransactionMode::{mode:?}"
352 )));
353 }
354 let sql = sql::Serializer::postgresql(&schema.db).serialize_transaction(t);
355 self.client
356 .batch_execute(&sql)
357 .await
358 .map_err(classify_pg_error)?;
359 return Ok(ExecResponse::count(0));
360 }
361
362 let (sql, typed_params, ret_tys) = match op {
363 Operation::Insert(op) => (sql::Statement::from(op.stmt), op.params, None),
364 Operation::QuerySql(query) => {
365 assert!(
366 query.last_insert_id_hack.is_none(),
367 "last_insert_id_hack is MySQL-specific and should not be set for PostgreSQL"
368 );
369 (sql::Statement::from(query.stmt), query.params, query.ret)
370 }
371 op => todo!("op={:#?}", op),
372 };
373
374 let width = sql.returning_len();
375
376 let sql_as_str = sql::Serializer::postgresql(&schema.db).serialize(&sql);
377
378 tracing::debug!(db.system = "postgresql", db.statement = %sql_as_str, params = typed_params.len(), "executing SQL");
379
380 self.oid_cache
381 .preload(&self.client, typed_params.iter().map(|tv| &tv.ty))
382 .await?;
383 let param_types: Vec<_> = typed_params
384 .iter()
385 .map(|tv| self.oid_cache.get(&tv.ty).clone())
386 .collect();
387
388 let values: Vec<_> = typed_params
389 .into_iter()
390 .map(|tv| Value::from(tv.value))
391 .collect();
392 let params = values
393 .iter()
394 .map(|param| param as &(dyn ToSql + Sync))
395 .collect::<Vec<_>>();
396
397 let statement = self
398 .statement_cache
399 .prepare_typed(&mut self.client, &sql_as_str, ¶m_types)
400 .await
401 .map_err(classify_pg_error)?;
402
403 if width.is_none() {
404 let count = self
405 .client
406 .execute(&statement, ¶ms)
407 .await
408 .map_err(classify_pg_error)?;
409 return Ok(ExecResponse::count(count));
410 }
411
412 let rows = self
413 .client
414 .query(&statement, ¶ms)
415 .await
416 .map_err(classify_pg_error)?;
417
418 if width.is_none() {
419 let [row] = &rows[..] else { todo!() };
420 let total = row.get::<usize, i64>(0);
421 let condition_matched = row.get::<usize, i64>(1);
422
423 if total == condition_matched {
424 Ok(ExecResponse::count(total as _))
425 } else {
426 Err(toasty_core::Error::condition_failed(
427 "update condition did not match",
428 ))
429 }
430 } else {
431 let ret_tys = ret_tys.as_ref().unwrap().clone();
432 let results = rows.into_iter().map(move |row| {
433 let mut results = Vec::new();
434 for (i, column) in row.columns().iter().enumerate() {
435 results.push(Value::from_sql(i, &row, column, &ret_tys[i]).into_inner());
436 }
437
438 Ok(ValueRecord::from_vec(results))
439 });
440
441 Ok(ExecResponse::value_stream(stmt::ValueStream::from_iter(
442 results,
443 )))
444 }
445 }
446
447 async fn push_schema(&mut self, schema: &Schema) -> Result<()> {
448 let serializer = sql::Serializer::postgresql(&schema.db);
449
450 let mut created_enum_types = hashbrown::HashSet::new();
453 for table in &schema.db.tables {
454 for column in &table.columns {
455 if let toasty_core::schema::db::Type::Enum(type_enum) = &column.storage_ty
456 && created_enum_types.insert(type_enum.name.clone())
457 {
458 let sql = serializer.serialize(&sql::Statement::create_enum_type(type_enum));
459
460 tracing::debug!(enum_type = ?type_enum.name, "creating enum type");
461 self.client
462 .execute(&sql, &[])
463 .await
464 .map_err(classify_pg_error)?;
465 }
466 }
467 }
468
469 for table in &schema.db.tables {
470 tracing::debug!(table = %table.name, "creating table");
471 self.create_table(&schema.db, table).await?;
472 }
473 Ok(())
474 }
475
476 async fn applied_migrations(
477 &mut self,
478 ) -> Result<Vec<toasty_core::schema::db::AppliedMigration>> {
479 self.client
481 .execute(
482 "CREATE TABLE IF NOT EXISTS __toasty_migrations (
483 id BIGINT PRIMARY KEY,
484 name TEXT NOT NULL,
485 applied_at TIMESTAMP NOT NULL
486 )",
487 &[],
488 )
489 .await
490 .map_err(classify_pg_error)?;
491
492 let rows = self
494 .client
495 .query(
496 "SELECT id FROM __toasty_migrations ORDER BY applied_at",
497 &[],
498 )
499 .await
500 .map_err(classify_pg_error)?;
501
502 Ok(rows
503 .iter()
504 .map(|row| {
505 let id: i64 = row.get(0);
506 toasty_core::schema::db::AppliedMigration::new(id as u64)
507 })
508 .collect())
509 }
510
511 async fn apply_migration(
512 &mut self,
513 id: u64,
514 name: &str,
515 migration: &toasty_core::schema::db::Migration,
516 ) -> Result<()> {
517 tracing::info!(id = id, name = %name, "applying migration");
518 self.client
520 .execute(
521 "CREATE TABLE IF NOT EXISTS __toasty_migrations (
522 id BIGINT PRIMARY KEY,
523 name TEXT NOT NULL,
524 applied_at TIMESTAMP NOT NULL
525 )",
526 &[],
527 )
528 .await
529 .map_err(classify_pg_error)?;
530
531 let transaction = self.client.transaction().await.map_err(classify_pg_error)?;
533
534 for statement in migration.statements() {
536 if let Err(e) = transaction
537 .batch_execute(statement)
538 .await
539 .map_err(classify_pg_error)
540 {
541 transaction.rollback().await.map_err(classify_pg_error)?;
542 return Err(e);
543 }
544 }
545
546 if let Err(e) = transaction
548 .execute(
549 "INSERT INTO __toasty_migrations (id, name, applied_at) VALUES ($1, $2, NOW())",
550 &[&(id as i64), &name],
551 )
552 .await
553 .map_err(classify_pg_error)
554 {
555 transaction.rollback().await.map_err(classify_pg_error)?;
556 return Err(e);
557 }
558
559 transaction.commit().await.map_err(classify_pg_error)?;
561 Ok(())
562 }
563
564 fn is_valid(&self) -> bool {
565 !self.client.is_closed()
566 }
567
568 async fn ping(&mut self) -> Result<()> {
569 self.client
574 .simple_query("")
575 .await
576 .map(|_| ())
577 .map_err(toasty_core::Error::connection_lost)
578 }
579}