toasty_driver_postgresql/
lib.rs1#![warn(missing_docs)]
2
3mod statement_cache;
15mod r#type;
16mod value;
17
18pub(crate) use value::Value;
19
20use async_trait::async_trait;
21use postgres::{Socket, tls::MakeTlsConnect, types::ToSql};
22use std::{borrow::Cow, sync::Arc};
23use toasty_core::{
24 Result, Schema,
25 driver::{Capability, Driver, ExecResponse, Operation},
26 schema::db::{self, Migration, SchemaDiff, Table},
27 stmt,
28 stmt::ValueRecord,
29};
30use toasty_sql::{self as sql, TypedValue};
31use tokio_postgres::{Client, Config};
32use url::Url;
33
34use crate::{statement_cache::StatementCache, r#type::TypeExt};
35
36#[derive(Debug)]
46pub struct PostgreSQL {
47 url: String,
48 config: Config,
49}
50
51impl PostgreSQL {
52 pub fn new(url: impl Into<String>) -> Result<Self> {
54 let url_str = url.into();
55 let url = Url::parse(&url_str).map_err(toasty_core::Error::driver_operation_failed)?;
56
57 if !matches!(url.scheme(), "postgresql" | "postgres") {
58 return Err(toasty_core::Error::invalid_connection_url(format!(
59 "connection URL does not have a `postgresql` scheme; url={}",
60 url
61 )));
62 }
63
64 let host = url.host_str().ok_or_else(|| {
65 toasty_core::Error::invalid_connection_url(format!(
66 "missing host in connection URL; url={}",
67 url
68 ))
69 })?;
70
71 if url.path().is_empty() {
72 return Err(toasty_core::Error::invalid_connection_url(format!(
73 "no database specified - missing path in connection URL; url={}",
74 url
75 )));
76 }
77
78 let mut config = Config::new();
79 config.host(host);
80 config.dbname(url.path().trim_start_matches('/'));
81
82 if let Some(port) = url.port() {
83 config.port(port);
84 }
85
86 if !url.username().is_empty() {
87 config.user(url.username());
88 }
89
90 if let Some(password) = url.password() {
91 config.password(password);
92 }
93
94 Ok(Self {
95 url: url_str,
96 config,
97 })
98 }
99}
100
101#[async_trait]
102impl Driver for PostgreSQL {
103 fn url(&self) -> Cow<'_, str> {
104 Cow::Borrowed(&self.url)
105 }
106
107 fn capability(&self) -> &'static Capability {
108 &Capability::POSTGRESQL
109 }
110
111 async fn connect(&self) -> toasty_core::Result<Box<dyn toasty_core::driver::Connection>> {
112 Ok(Box::new(
113 Connection::connect(self.config.clone(), tokio_postgres::NoTls).await?,
114 ))
115 }
116
117 fn generate_migration(&self, schema_diff: &SchemaDiff<'_>) -> Migration {
118 let statements = sql::MigrationStatement::from_diff(schema_diff, &Capability::POSTGRESQL);
119
120 let sql_strings: Vec<String> = statements
121 .iter()
122 .map(|stmt| {
123 let mut params = Vec::<TypedValue>::new();
124 let sql = sql::Serializer::postgresql(stmt.schema())
125 .serialize(stmt.statement(), &mut params);
126 assert!(
127 params.is_empty(),
128 "migration statements should not have parameters"
129 );
130 sql
131 })
132 .collect();
133
134 Migration::new_sql(sql_strings.join("\n"))
135 }
136
137 async fn reset_db(&self) -> toasty_core::Result<()> {
138 let dbname = self
139 .config
140 .get_dbname()
141 .ok_or_else(|| {
142 toasty_core::Error::invalid_connection_url("no database name configured")
143 })?
144 .to_string();
145
146 let temp_dbname = "__toasty_reset_temp";
148
149 let connect = |dbname: &str| {
150 let mut config = self.config.clone();
151 config.dbname(dbname);
152 Connection::connect(config, tokio_postgres::NoTls)
153 };
154
155 let conn = connect(&dbname).await?;
157 conn.client
158 .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
159 .await
160 .map_err(toasty_core::Error::driver_operation_failed)?;
161 conn.client
162 .execute(&format!("CREATE DATABASE \"{}\"", temp_dbname), &[])
163 .await
164 .map_err(toasty_core::Error::driver_operation_failed)?;
165 drop(conn);
166
167 let conn = connect(temp_dbname).await?;
169 conn.client
170 .execute(
171 "SELECT pg_terminate_backend(pid) \
172 FROM pg_stat_activity \
173 WHERE datname = $1 AND pid <> pg_backend_pid()",
174 &[&dbname],
175 )
176 .await
177 .map_err(toasty_core::Error::driver_operation_failed)?;
178 conn.client
179 .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", dbname), &[])
180 .await
181 .map_err(toasty_core::Error::driver_operation_failed)?;
182 conn.client
183 .execute(&format!("CREATE DATABASE \"{}\"", dbname), &[])
184 .await
185 .map_err(toasty_core::Error::driver_operation_failed)?;
186 drop(conn);
187
188 let conn = connect(&dbname).await?;
190 conn.client
191 .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
192 .await
193 .map_err(toasty_core::Error::driver_operation_failed)?;
194
195 Ok(())
196 }
197}
198
199#[derive(Debug)]
201pub struct Connection {
202 client: Client,
203 statement_cache: StatementCache,
204}
205
206impl Connection {
207 pub fn new(client: Client) -> Self {
209 Self {
210 client,
211 statement_cache: StatementCache::new(100),
212 }
213 }
214
215 pub async fn connect<T>(config: Config, tls: T) -> Result<Self>
219 where
220 T: MakeTlsConnect<Socket> + 'static,
221 T::Stream: Send,
222 {
223 let (client, connection) = config
224 .connect(tls)
225 .await
226 .map_err(toasty_core::Error::driver_operation_failed)?;
227
228 tokio::spawn(async move {
229 if let Err(e) = connection.await {
230 eprintln!("connection error: {e}");
231 }
232 });
233
234 Ok(Self::new(client))
235 }
236
237 pub async fn create_table(&mut self, schema: &db::Schema, table: &Table) -> Result<()> {
239 let serializer = sql::Serializer::postgresql(schema);
240
241 let mut params: Vec<toasty_sql::TypedValue> = Vec::new();
242 let sql = serializer.serialize(
243 &sql::Statement::create_table(table, &Capability::POSTGRESQL),
244 &mut params,
245 );
246
247 assert!(
248 params.is_empty(),
249 "creating a table shouldn't involve any parameters"
250 );
251
252 self.client
253 .execute(&sql, &[])
254 .await
255 .map_err(toasty_core::Error::driver_operation_failed)?;
256
257 for index in &table.indices {
260 if index.primary_key {
261 continue;
262 }
263
264 let sql = serializer.serialize(&sql::Statement::create_index(index), &mut params);
265
266 assert!(
267 params.is_empty(),
268 "creating an index shouldn't involve any parameters"
269 );
270
271 self.client
272 .execute(&sql, &[])
273 .await
274 .map_err(toasty_core::Error::driver_operation_failed)?;
275 }
276
277 Ok(())
278 }
279}
280
281impl From<Client> for Connection {
282 fn from(client: Client) -> Self {
283 Self {
284 client,
285 statement_cache: StatementCache::new(100),
286 }
287 }
288}
289
290#[async_trait]
291impl toasty_core::driver::Connection for Connection {
292 async fn exec(&mut self, schema: &Arc<Schema>, op: Operation) -> Result<ExecResponse> {
293 tracing::trace!(driver = "postgresql", op = %op.name(), "driver exec");
294
295 if let Operation::Transaction(ref t) = op {
296 let sql = sql::Serializer::postgresql(&schema.db).serialize_transaction(t);
297 self.client.batch_execute(&sql).await.map_err(|e| {
298 if let Some(db_err) = e.as_db_error() {
299 match db_err.code().code() {
300 "40001" => toasty_core::Error::serialization_failure(db_err.message()),
301 "25006" => toasty_core::Error::read_only_transaction(db_err.message()),
302 _ => toasty_core::Error::driver_operation_failed(e),
303 }
304 } else {
305 toasty_core::Error::driver_operation_failed(e)
306 }
307 })?;
308 return Ok(ExecResponse::count(0));
309 }
310
311 let (sql, ret_tys): (sql::Statement, _) = match op {
312 Operation::Insert(op) => (op.stmt.into(), None),
313 Operation::QuerySql(query) => {
314 assert!(
315 query.last_insert_id_hack.is_none(),
316 "last_insert_id_hack is MySQL-specific and should not be set for PostgreSQL"
317 );
318 (query.stmt.into(), query.ret)
319 }
320 op => todo!("op={:#?}", op),
321 };
322
323 let width = sql.returning_len();
324
325 let mut params: Vec<toasty_sql::TypedValue> = Vec::new();
326 let sql_as_str = sql::Serializer::postgresql(&schema.db).serialize(&sql, &mut params);
327
328 tracing::debug!(db.system = "postgresql", db.statement = %sql_as_str, params = params.len(), "executing SQL");
329
330 let param_types = params
331 .iter()
332 .map(|typed_value| typed_value.infer_ty().to_postgres_type())
333 .collect::<Vec<_>>();
334
335 let values: Vec<_> = params.into_iter().map(|tv| Value::from(tv.value)).collect();
336 let params = values
337 .iter()
338 .map(|param| param as &(dyn ToSql + Sync))
339 .collect::<Vec<_>>();
340
341 let statement = self
342 .statement_cache
343 .prepare_typed(&mut self.client, &sql_as_str, ¶m_types)
344 .await
345 .map_err(toasty_core::Error::driver_operation_failed)?;
346
347 if width.is_none() {
348 let count = self
349 .client
350 .execute(&statement, ¶ms)
351 .await
352 .map_err(toasty_core::Error::driver_operation_failed)?;
353 return Ok(ExecResponse::count(count));
354 }
355
356 let rows = self
357 .client
358 .query(&statement, ¶ms)
359 .await
360 .map_err(toasty_core::Error::driver_operation_failed)?;
361
362 if width.is_none() {
363 let [row] = &rows[..] else { todo!() };
364 let total = row.get::<usize, i64>(0);
365 let condition_matched = row.get::<usize, i64>(1);
366
367 if total == condition_matched {
368 Ok(ExecResponse::count(total as _))
369 } else {
370 Err(toasty_core::Error::condition_failed(
371 "update condition did not match",
372 ))
373 }
374 } else {
375 let ret_tys = ret_tys.as_ref().unwrap().clone();
376 let results = rows.into_iter().map(move |row| {
377 let mut results = Vec::new();
378 for (i, column) in row.columns().iter().enumerate() {
379 results.push(Value::from_sql(i, &row, column, &ret_tys[i]).into_inner());
380 }
381
382 Ok(ValueRecord::from_vec(results))
383 });
384
385 Ok(ExecResponse::value_stream(stmt::ValueStream::from_iter(
386 results,
387 )))
388 }
389 }
390
391 async fn push_schema(&mut self, schema: &Schema) -> Result<()> {
392 for table in &schema.db.tables {
393 tracing::debug!(table = %table.name, "creating table");
394 self.create_table(&schema.db, table).await?;
395 }
396 Ok(())
397 }
398
399 async fn applied_migrations(
400 &mut self,
401 ) -> Result<Vec<toasty_core::schema::db::AppliedMigration>> {
402 self.client
404 .execute(
405 "CREATE TABLE IF NOT EXISTS __toasty_migrations (
406 id BIGINT PRIMARY KEY,
407 name TEXT NOT NULL,
408 applied_at TIMESTAMP NOT NULL
409 )",
410 &[],
411 )
412 .await
413 .map_err(toasty_core::Error::driver_operation_failed)?;
414
415 let rows = self
417 .client
418 .query(
419 "SELECT id FROM __toasty_migrations ORDER BY applied_at",
420 &[],
421 )
422 .await
423 .map_err(toasty_core::Error::driver_operation_failed)?;
424
425 Ok(rows
426 .iter()
427 .map(|row| {
428 let id: i64 = row.get(0);
429 toasty_core::schema::db::AppliedMigration::new(id as u64)
430 })
431 .collect())
432 }
433
434 async fn apply_migration(
435 &mut self,
436 id: u64,
437 name: &str,
438 migration: &toasty_core::schema::db::Migration,
439 ) -> Result<()> {
440 tracing::info!(id = id, name = %name, "applying migration");
441 self.client
443 .execute(
444 "CREATE TABLE IF NOT EXISTS __toasty_migrations (
445 id BIGINT PRIMARY KEY,
446 name TEXT NOT NULL,
447 applied_at TIMESTAMP NOT NULL
448 )",
449 &[],
450 )
451 .await
452 .map_err(toasty_core::Error::driver_operation_failed)?;
453
454 let transaction = self
456 .client
457 .transaction()
458 .await
459 .map_err(toasty_core::Error::driver_operation_failed)?;
460
461 for statement in migration.statements() {
463 if let Err(e) = transaction
464 .batch_execute(statement)
465 .await
466 .map_err(toasty_core::Error::driver_operation_failed)
467 {
468 transaction
469 .rollback()
470 .await
471 .map_err(toasty_core::Error::driver_operation_failed)?;
472 return Err(e);
473 }
474 }
475
476 if let Err(e) = transaction
478 .execute(
479 "INSERT INTO __toasty_migrations (id, name, applied_at) VALUES ($1, $2, NOW())",
480 &[&(id as i64), &name],
481 )
482 .await
483 .map_err(toasty_core::Error::driver_operation_failed)
484 {
485 transaction
486 .rollback()
487 .await
488 .map_err(toasty_core::Error::driver_operation_failed)?;
489 return Err(e);
490 }
491
492 transaction
494 .commit()
495 .await
496 .map_err(toasty_core::Error::driver_operation_failed)?;
497 Ok(())
498 }
499}