toasty_driver_postgresql/
lib.rs1#![warn(missing_docs)]
2
3mod oid_cache;
15mod statement_cache;
16#[cfg(feature = "tls")]
17mod tls;
18mod r#type;
19mod value;
20
21pub(crate) use value::Value;
22
23use async_trait::async_trait;
24use percent_encoding::percent_decode_str;
25use std::{borrow::Cow, sync::Arc};
26use toasty_core::{
27 Result, Schema,
28 driver::{Capability, Driver, ExecResponse, Operation},
29 schema::db::{self, Migration, SchemaDiff, Table},
30 stmt,
31 stmt::ValueRecord,
32};
33use toasty_sql::{self as sql};
34use tokio_postgres::{Client, Config, Socket, tls::MakeTlsConnect, types::ToSql};
35use url::Url;
36
37use crate::{oid_cache::OidCache, statement_cache::StatementCache};
38
39#[derive(Debug)]
49pub struct PostgreSQL {
50 url: String,
51 config: Config,
52 #[cfg(feature = "tls")]
53 tls: Option<tls::MakeRustlsConnect>,
54}
55
56impl PostgreSQL {
57 pub fn new(url: impl Into<String>) -> Result<Self> {
59 let url_str = url.into();
60 let url = Url::parse(&url_str).map_err(toasty_core::Error::driver_operation_failed)?;
61
62 if !matches!(url.scheme(), "postgresql" | "postgres") {
63 return Err(toasty_core::Error::invalid_connection_url(format!(
64 "connection URL does not have a `postgresql` scheme; url={}",
65 url
66 )));
67 }
68
69 let host = url.host_str().ok_or_else(|| {
70 toasty_core::Error::invalid_connection_url(format!(
71 "missing host in connection URL; url={}",
72 url
73 ))
74 })?;
75
76 if url.path().is_empty() {
77 return Err(toasty_core::Error::invalid_connection_url(format!(
78 "no database specified - missing path in connection URL; url={}",
79 url
80 )));
81 }
82
83 let mut config = Config::new();
84 config.host(host);
85
86 let dbname = percent_decode_str(url.path().trim_start_matches('/'))
87 .decode_utf8()
88 .map_err(|_| {
89 toasty_core::Error::invalid_connection_url("database name is not valid UTF-8")
90 })?;
91 config.dbname(&*dbname);
92
93 if let Some(port) = url.port() {
94 config.port(port);
95 }
96
97 if !url.username().is_empty() {
98 let user = percent_decode_str(url.username())
99 .decode_utf8()
100 .map_err(|_| {
101 toasty_core::Error::invalid_connection_url("username is not valid UTF-8")
102 })?;
103 config.user(&*user);
104 }
105
106 if let Some(password) = url.password() {
107 config.password(percent_decode_str(password).collect::<Vec<u8>>());
108 }
109
110 for (key, value) in url.query_pairs() {
111 if key == "application_name" {
112 config.application_name(&*value);
113 }
114 }
115
116 #[cfg(feature = "tls")]
117 let tls = tls::configure_tls(&url, &mut config)?;
118
119 #[cfg(not(feature = "tls"))]
120 for (key, value) in url.query_pairs() {
121 if key == "sslmode" && value != "disable" {
122 return Err(toasty_core::Error::invalid_connection_url(
123 "TLS not available: compile with the `tls` feature",
124 ));
125 }
126 }
127
128 Ok(Self {
129 url: url_str,
130 config,
131 #[cfg(feature = "tls")]
132 tls,
133 })
134 }
135
136 async fn connect_with_config(&self, config: Config) -> Result<Connection> {
137 #[cfg(feature = "tls")]
138 if let Some(ref tls) = self.tls {
139 return Connection::connect(config, tls.clone()).await;
140 }
141 Connection::connect(config, tokio_postgres::NoTls).await
142 }
143}
144
145#[async_trait]
146impl Driver for PostgreSQL {
147 fn url(&self) -> Cow<'_, str> {
148 Cow::Borrowed(&self.url)
149 }
150
151 fn capability(&self) -> &'static Capability {
152 &Capability::POSTGRESQL
153 }
154
155 async fn connect(&self) -> toasty_core::Result<Box<dyn toasty_core::driver::Connection>> {
156 Ok(Box::new(
157 self.connect_with_config(self.config.clone()).await?,
158 ))
159 }
160
161 fn generate_migration(&self, schema_diff: &SchemaDiff<'_>) -> Migration {
162 let statements = sql::MigrationStatement::from_diff(schema_diff, &Capability::POSTGRESQL);
163
164 let sql_strings: Vec<String> = statements
165 .iter()
166 .map(|stmt| sql::Serializer::postgresql(stmt.schema()).serialize(stmt.statement()))
167 .collect();
168
169 Migration::new_sql(sql_strings.join("\n"))
170 }
171
172 async fn reset_db(&self) -> toasty_core::Result<()> {
173 let dbname = self
174 .config
175 .get_dbname()
176 .ok_or_else(|| {
177 toasty_core::Error::invalid_connection_url("no database name configured")
178 })?
179 .to_string();
180
181 let temp_dbname = "__toasty_reset_temp";
183
184 let connect = |dbname: &str| {
185 let mut config = self.config.clone();
186 config.dbname(dbname);
187 self.connect_with_config(config)
188 };
189
190 let conn = connect(&dbname).await?;
192 conn.client
193 .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
194 .await
195 .map_err(toasty_core::Error::driver_operation_failed)?;
196 conn.client
197 .execute(&format!("CREATE DATABASE \"{}\"", temp_dbname), &[])
198 .await
199 .map_err(toasty_core::Error::driver_operation_failed)?;
200 drop(conn);
201
202 let conn = connect(temp_dbname).await?;
204 conn.client
205 .execute(
206 "SELECT pg_terminate_backend(pid) \
207 FROM pg_stat_activity \
208 WHERE datname = $1 AND pid <> pg_backend_pid()",
209 &[&dbname],
210 )
211 .await
212 .map_err(toasty_core::Error::driver_operation_failed)?;
213 conn.client
214 .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", dbname), &[])
215 .await
216 .map_err(toasty_core::Error::driver_operation_failed)?;
217 conn.client
218 .execute(&format!("CREATE DATABASE \"{}\"", dbname), &[])
219 .await
220 .map_err(toasty_core::Error::driver_operation_failed)?;
221 drop(conn);
222
223 let conn = connect(&dbname).await?;
225 conn.client
226 .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
227 .await
228 .map_err(toasty_core::Error::driver_operation_failed)?;
229
230 Ok(())
231 }
232}
233
234#[derive(Debug)]
236pub struct Connection {
237 client: Client,
238 statement_cache: StatementCache,
239 oid_cache: OidCache,
240}
241
242impl Connection {
243 pub fn new(client: Client) -> Self {
245 Self {
246 client,
247 statement_cache: StatementCache::new(100),
248 oid_cache: OidCache::new(),
249 }
250 }
251
252 pub async fn connect<T>(config: Config, tls: T) -> Result<Self>
256 where
257 T: MakeTlsConnect<Socket> + 'static,
258 T::Stream: Send,
259 {
260 let (client, connection) = config
261 .connect(tls)
262 .await
263 .map_err(toasty_core::Error::driver_operation_failed)?;
264
265 tokio::spawn(async move {
266 if let Err(e) = connection.await {
267 eprintln!("connection error: {e}");
268 }
269 });
270
271 Ok(Self::new(client))
272 }
273
274 pub async fn create_table(&mut self, schema: &db::Schema, table: &Table) -> Result<()> {
276 let serializer = sql::Serializer::postgresql(schema);
277
278 let sql = serializer.serialize(&sql::Statement::create_table(
279 table,
280 &Capability::POSTGRESQL,
281 ));
282
283 self.client
284 .execute(&sql, &[])
285 .await
286 .map_err(toasty_core::Error::driver_operation_failed)?;
287
288 for index in &table.indices {
289 if index.primary_key {
290 continue;
291 }
292
293 let sql = serializer.serialize(&sql::Statement::create_index(index));
294
295 self.client
296 .execute(&sql, &[])
297 .await
298 .map_err(toasty_core::Error::driver_operation_failed)?;
299 }
300
301 Ok(())
302 }
303}
304
305impl From<Client> for Connection {
306 fn from(client: Client) -> Self {
307 Self::new(client)
308 }
309}
310
311#[async_trait]
312impl toasty_core::driver::Connection for Connection {
313 async fn exec(&mut self, schema: &Arc<Schema>, op: Operation) -> Result<ExecResponse> {
314 tracing::trace!(driver = "postgresql", op = %op.name(), "driver exec");
315
316 if let Operation::Transaction(ref t) = op {
317 let sql = sql::Serializer::postgresql(&schema.db).serialize_transaction(t);
318 self.client.batch_execute(&sql).await.map_err(|e| {
319 if let Some(db_err) = e.as_db_error() {
320 match db_err.code().code() {
321 "40001" => toasty_core::Error::serialization_failure(db_err.message()),
322 "25006" => toasty_core::Error::read_only_transaction(db_err.message()),
323 _ => toasty_core::Error::driver_operation_failed(e),
324 }
325 } else {
326 toasty_core::Error::driver_operation_failed(e)
327 }
328 })?;
329 return Ok(ExecResponse::count(0));
330 }
331
332 let (sql, typed_params, ret_tys) = match op {
333 Operation::Insert(op) => (sql::Statement::from(op.stmt), op.params, None),
334 Operation::QuerySql(query) => {
335 assert!(
336 query.last_insert_id_hack.is_none(),
337 "last_insert_id_hack is MySQL-specific and should not be set for PostgreSQL"
338 );
339 (sql::Statement::from(query.stmt), query.params, query.ret)
340 }
341 op => todo!("op={:#?}", op),
342 };
343
344 let width = sql.returning_len();
345
346 let sql_as_str = sql::Serializer::postgresql(&schema.db).serialize(&sql);
347
348 tracing::debug!(db.system = "postgresql", db.statement = %sql_as_str, params = typed_params.len(), "executing SQL");
349
350 self.oid_cache
351 .preload(&self.client, typed_params.iter().map(|tv| &tv.ty))
352 .await?;
353 let param_types: Vec<_> = typed_params
354 .iter()
355 .map(|tv| self.oid_cache.get(&tv.ty))
356 .collect();
357
358 let values: Vec<_> = typed_params
359 .into_iter()
360 .map(|tv| Value::from(tv.value))
361 .collect();
362 let params = values
363 .iter()
364 .map(|param| param as &(dyn ToSql + Sync))
365 .collect::<Vec<_>>();
366
367 let statement = self
368 .statement_cache
369 .prepare_typed(&mut self.client, &sql_as_str, ¶m_types)
370 .await
371 .map_err(toasty_core::Error::driver_operation_failed)?;
372
373 if width.is_none() {
374 let count = self
375 .client
376 .execute(&statement, ¶ms)
377 .await
378 .map_err(toasty_core::Error::driver_operation_failed)?;
379 return Ok(ExecResponse::count(count));
380 }
381
382 let rows = self
383 .client
384 .query(&statement, ¶ms)
385 .await
386 .map_err(toasty_core::Error::driver_operation_failed)?;
387
388 if width.is_none() {
389 let [row] = &rows[..] else { todo!() };
390 let total = row.get::<usize, i64>(0);
391 let condition_matched = row.get::<usize, i64>(1);
392
393 if total == condition_matched {
394 Ok(ExecResponse::count(total as _))
395 } else {
396 Err(toasty_core::Error::condition_failed(
397 "update condition did not match",
398 ))
399 }
400 } else {
401 let ret_tys = ret_tys.as_ref().unwrap().clone();
402 let results = rows.into_iter().map(move |row| {
403 let mut results = Vec::new();
404 for (i, column) in row.columns().iter().enumerate() {
405 results.push(Value::from_sql(i, &row, column, &ret_tys[i]).into_inner());
406 }
407
408 Ok(ValueRecord::from_vec(results))
409 });
410
411 Ok(ExecResponse::value_stream(stmt::ValueStream::from_iter(
412 results,
413 )))
414 }
415 }
416
417 async fn push_schema(&mut self, schema: &Schema) -> Result<()> {
418 let serializer = sql::Serializer::postgresql(&schema.db);
419
420 let mut created_enum_types = hashbrown::HashSet::new();
423 for table in &schema.db.tables {
424 for column in &table.columns {
425 if let toasty_core::schema::db::Type::Enum(type_enum) = &column.storage_ty
426 && created_enum_types.insert(type_enum.name.clone())
427 {
428 let sql = serializer.serialize(&sql::Statement::create_enum_type(type_enum));
429
430 tracing::debug!(enum_type = ?type_enum.name, "creating enum type");
431 self.client
432 .execute(&sql, &[])
433 .await
434 .map_err(toasty_core::Error::driver_operation_failed)?;
435 }
436 }
437 }
438
439 for table in &schema.db.tables {
440 tracing::debug!(table = %table.name, "creating table");
441 self.create_table(&schema.db, table).await?;
442 }
443 Ok(())
444 }
445
446 async fn applied_migrations(
447 &mut self,
448 ) -> Result<Vec<toasty_core::schema::db::AppliedMigration>> {
449 self.client
451 .execute(
452 "CREATE TABLE IF NOT EXISTS __toasty_migrations (
453 id BIGINT PRIMARY KEY,
454 name TEXT NOT NULL,
455 applied_at TIMESTAMP NOT NULL
456 )",
457 &[],
458 )
459 .await
460 .map_err(toasty_core::Error::driver_operation_failed)?;
461
462 let rows = self
464 .client
465 .query(
466 "SELECT id FROM __toasty_migrations ORDER BY applied_at",
467 &[],
468 )
469 .await
470 .map_err(toasty_core::Error::driver_operation_failed)?;
471
472 Ok(rows
473 .iter()
474 .map(|row| {
475 let id: i64 = row.get(0);
476 toasty_core::schema::db::AppliedMigration::new(id as u64)
477 })
478 .collect())
479 }
480
481 async fn apply_migration(
482 &mut self,
483 id: u64,
484 name: &str,
485 migration: &toasty_core::schema::db::Migration,
486 ) -> Result<()> {
487 tracing::info!(id = id, name = %name, "applying migration");
488 self.client
490 .execute(
491 "CREATE TABLE IF NOT EXISTS __toasty_migrations (
492 id BIGINT PRIMARY KEY,
493 name TEXT NOT NULL,
494 applied_at TIMESTAMP NOT NULL
495 )",
496 &[],
497 )
498 .await
499 .map_err(toasty_core::Error::driver_operation_failed)?;
500
501 let transaction = self
503 .client
504 .transaction()
505 .await
506 .map_err(toasty_core::Error::driver_operation_failed)?;
507
508 for statement in migration.statements() {
510 if let Err(e) = transaction
511 .batch_execute(statement)
512 .await
513 .map_err(toasty_core::Error::driver_operation_failed)
514 {
515 transaction
516 .rollback()
517 .await
518 .map_err(toasty_core::Error::driver_operation_failed)?;
519 return Err(e);
520 }
521 }
522
523 if let Err(e) = transaction
525 .execute(
526 "INSERT INTO __toasty_migrations (id, name, applied_at) VALUES ($1, $2, NOW())",
527 &[&(id as i64), &name],
528 )
529 .await
530 .map_err(toasty_core::Error::driver_operation_failed)
531 {
532 transaction
533 .rollback()
534 .await
535 .map_err(toasty_core::Error::driver_operation_failed)?;
536 return Err(e);
537 }
538
539 transaction
541 .commit()
542 .await
543 .map_err(toasty_core::Error::driver_operation_failed)?;
544 Ok(())
545 }
546}