toasty_driver_postgresql/
lib.rs1mod statement_cache;
2mod r#type;
3mod value;
4
5pub(crate) use value::Value;
6
7use postgres::{tls::MakeTlsConnect, types::ToSql, Socket};
8use std::{borrow::Cow, sync::Arc};
9use toasty_core::{
10 async_trait,
11 driver::{Capability, Driver, Operation, Response},
12 schema::db::{self, Migration, SchemaDiff, Table},
13 stmt,
14 stmt::ValueRecord,
15 Result, Schema,
16};
17use toasty_sql::{self as sql, TypedValue};
18use tokio_postgres::{Client, Config};
19use url::Url;
20
21use crate::{r#type::TypeExt, statement_cache::StatementCache};
22
23#[derive(Debug)]
24pub struct PostgreSQL {
25 url: String,
26 config: Config,
27}
28
29impl PostgreSQL {
30 pub fn new(url: impl Into<String>) -> Result<Self> {
32 let url_str = url.into();
33 let url = Url::parse(&url_str).map_err(toasty_core::Error::driver_operation_failed)?;
34
35 if !matches!(url.scheme(), "postgresql" | "postgres") {
36 return Err(toasty_core::Error::invalid_connection_url(format!(
37 "connection URL does not have a `postgresql` scheme; url={}",
38 url
39 )));
40 }
41
42 let host = url.host_str().ok_or_else(|| {
43 toasty_core::Error::invalid_connection_url(format!(
44 "missing host in connection URL; url={}",
45 url
46 ))
47 })?;
48
49 if url.path().is_empty() {
50 return Err(toasty_core::Error::invalid_connection_url(format!(
51 "no database specified - missing path in connection URL; url={}",
52 url
53 )));
54 }
55
56 let mut config = Config::new();
57 config.host(host);
58 config.dbname(url.path().trim_start_matches('/'));
59
60 if let Some(port) = url.port() {
61 config.port(port);
62 }
63
64 if !url.username().is_empty() {
65 config.user(url.username());
66 }
67
68 if let Some(password) = url.password() {
69 config.password(password);
70 }
71
72 Ok(Self {
73 url: url_str,
74 config,
75 })
76 }
77}
78
79#[async_trait]
80impl Driver for PostgreSQL {
81 fn url(&self) -> Cow<'_, str> {
82 Cow::Borrowed(&self.url)
83 }
84
85 fn capability(&self) -> &'static Capability {
86 &Capability::POSTGRESQL
87 }
88
89 async fn connect(&self) -> toasty_core::Result<Box<dyn toasty_core::driver::Connection>> {
90 Ok(Box::new(
91 Connection::connect(self.config.clone(), tokio_postgres::NoTls).await?,
92 ))
93 }
94
95 fn generate_migration(&self, schema_diff: &SchemaDiff<'_>) -> Migration {
96 let statements = sql::MigrationStatement::from_diff(schema_diff, &Capability::POSTGRESQL);
97
98 let sql_strings: Vec<String> = statements
99 .iter()
100 .map(|stmt| {
101 let mut params = Vec::<TypedValue>::new();
102 let sql = sql::Serializer::postgresql(stmt.schema())
103 .serialize(stmt.statement(), &mut params);
104 assert!(
105 params.is_empty(),
106 "migration statements should not have parameters"
107 );
108 sql
109 })
110 .collect();
111
112 Migration::new_sql(sql_strings.join("\n"))
113 }
114
115 async fn reset_db(&self) -> toasty_core::Result<()> {
116 let dbname = self
117 .config
118 .get_dbname()
119 .ok_or_else(|| {
120 toasty_core::Error::invalid_connection_url("no database name configured")
121 })?
122 .to_string();
123
124 let temp_dbname = "__toasty_reset_temp";
126
127 let connect = |dbname: &str| {
128 let mut config = self.config.clone();
129 config.dbname(dbname);
130 Connection::connect(config, tokio_postgres::NoTls)
131 };
132
133 let conn = connect(&dbname).await?;
135 conn.client
136 .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
137 .await
138 .map_err(toasty_core::Error::driver_operation_failed)?;
139 conn.client
140 .execute(&format!("CREATE DATABASE \"{}\"", temp_dbname), &[])
141 .await
142 .map_err(toasty_core::Error::driver_operation_failed)?;
143 drop(conn);
144
145 let conn = connect(temp_dbname).await?;
147 conn.client
148 .execute(
149 "SELECT pg_terminate_backend(pid) \
150 FROM pg_stat_activity \
151 WHERE datname = $1 AND pid <> pg_backend_pid()",
152 &[&dbname],
153 )
154 .await
155 .map_err(toasty_core::Error::driver_operation_failed)?;
156 conn.client
157 .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", dbname), &[])
158 .await
159 .map_err(toasty_core::Error::driver_operation_failed)?;
160 conn.client
161 .execute(&format!("CREATE DATABASE \"{}\"", dbname), &[])
162 .await
163 .map_err(toasty_core::Error::driver_operation_failed)?;
164 drop(conn);
165
166 let conn = connect(&dbname).await?;
168 conn.client
169 .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
170 .await
171 .map_err(toasty_core::Error::driver_operation_failed)?;
172
173 Ok(())
174 }
175}
176
177#[derive(Debug)]
178pub struct Connection {
179 client: Client,
180 statement_cache: StatementCache,
181}
182
183impl Connection {
184 pub fn new(client: Client) -> Self {
186 Self {
187 client,
188 statement_cache: StatementCache::new(100),
189 }
190 }
191
192 pub async fn connect<T>(config: Config, tls: T) -> Result<Self>
196 where
197 T: MakeTlsConnect<Socket> + 'static,
198 T::Stream: Send,
199 {
200 let (client, connection) = config
201 .connect(tls)
202 .await
203 .map_err(toasty_core::Error::driver_operation_failed)?;
204
205 tokio::spawn(async move {
206 if let Err(e) = connection.await {
207 eprintln!("connection error: {e}");
208 }
209 });
210
211 Ok(Self::new(client))
212 }
213
214 pub async fn create_table(&mut self, schema: &db::Schema, table: &Table) -> Result<()> {
216 let serializer = sql::Serializer::postgresql(schema);
217
218 let mut params: Vec<toasty_sql::TypedValue> = Vec::new();
219 let sql = serializer.serialize(
220 &sql::Statement::create_table(table, &Capability::POSTGRESQL),
221 &mut params,
222 );
223
224 assert!(
225 params.is_empty(),
226 "creating a table shouldn't involve any parameters"
227 );
228
229 self.client
230 .execute(&sql, &[])
231 .await
232 .map_err(toasty_core::Error::driver_operation_failed)?;
233
234 for index in &table.indices {
237 if index.primary_key {
238 continue;
239 }
240
241 let sql = serializer.serialize(&sql::Statement::create_index(index), &mut params);
242
243 assert!(
244 params.is_empty(),
245 "creating an index shouldn't involve any parameters"
246 );
247
248 self.client
249 .execute(&sql, &[])
250 .await
251 .map_err(toasty_core::Error::driver_operation_failed)?;
252 }
253
254 Ok(())
255 }
256}
257
258impl From<Client> for Connection {
259 fn from(client: Client) -> Self {
260 Self {
261 client,
262 statement_cache: StatementCache::new(100),
263 }
264 }
265}
266
267#[async_trait]
268impl toasty_core::driver::Connection for Connection {
269 async fn exec(&mut self, schema: &Arc<Schema>, op: Operation) -> Result<Response> {
270 if let Operation::Transaction(ref t) = op {
271 let sql = sql::Serializer::postgresql(&schema.db).serialize_transaction(t);
272 self.client.batch_execute(&sql).await.map_err(|e| {
273 if let Some(db_err) = e.as_db_error() {
274 match db_err.code().code() {
275 "40001" => toasty_core::Error::serialization_failure(db_err.message()),
276 "25006" => toasty_core::Error::read_only_transaction(db_err.message()),
277 _ => toasty_core::Error::driver_operation_failed(e),
278 }
279 } else {
280 toasty_core::Error::driver_operation_failed(e)
281 }
282 })?;
283 return Ok(Response::count(0));
284 }
285
286 let (sql, ret_tys): (sql::Statement, _) = match op {
287 Operation::Insert(op) => (op.stmt.into(), None),
288 Operation::QuerySql(query) => {
289 assert!(
290 query.last_insert_id_hack.is_none(),
291 "last_insert_id_hack is MySQL-specific and should not be set for PostgreSQL"
292 );
293 (query.stmt.into(), query.ret)
294 }
295 op => todo!("op={:#?}", op),
296 };
297
298 let width = sql.returning_len();
299
300 let mut params: Vec<toasty_sql::TypedValue> = Vec::new();
301 let sql_as_str = sql::Serializer::postgresql(&schema.db).serialize(&sql, &mut params);
302
303 let param_types = params
304 .iter()
305 .map(|typed_value| typed_value.infer_ty().to_postgres_type())
306 .collect::<Vec<_>>();
307
308 let values: Vec<_> = params.into_iter().map(|tv| Value::from(tv.value)).collect();
309 let params = values
310 .iter()
311 .map(|param| param as &(dyn ToSql + Sync))
312 .collect::<Vec<_>>();
313
314 let statement = self
315 .statement_cache
316 .prepare_typed(&mut self.client, &sql_as_str, ¶m_types)
317 .await
318 .map_err(toasty_core::Error::driver_operation_failed)?;
319
320 if width.is_none() {
321 let count = self
322 .client
323 .execute(&statement, ¶ms)
324 .await
325 .map_err(toasty_core::Error::driver_operation_failed)?;
326 return Ok(Response::count(count));
327 }
328
329 let rows = self
330 .client
331 .query(&statement, ¶ms)
332 .await
333 .map_err(toasty_core::Error::driver_operation_failed)?;
334
335 if width.is_none() {
336 let [row] = &rows[..] else { todo!() };
337 let total = row.get::<usize, i64>(0);
338 let condition_matched = row.get::<usize, i64>(1);
339
340 if total == condition_matched {
341 Ok(Response::count(total as _))
342 } else {
343 Err(toasty_core::Error::condition_failed(
344 "update condition did not match",
345 ))
346 }
347 } else {
348 let ret_tys = ret_tys.as_ref().unwrap().clone();
349 let results = rows.into_iter().map(move |row| {
350 let mut results = Vec::new();
351 for (i, column) in row.columns().iter().enumerate() {
352 results.push(Value::from_sql(i, &row, column, &ret_tys[i]).into_inner());
353 }
354
355 Ok(ValueRecord::from_vec(results))
356 });
357
358 Ok(Response::value_stream(stmt::ValueStream::from_iter(
359 results,
360 )))
361 }
362 }
363
364 async fn push_schema(&mut self, schema: &Schema) -> Result<()> {
365 for table in &schema.db.tables {
366 self.create_table(&schema.db, table).await?;
367 }
368 Ok(())
369 }
370
371 async fn applied_migrations(
372 &mut self,
373 ) -> Result<Vec<toasty_core::schema::db::AppliedMigration>> {
374 self.client
376 .execute(
377 "CREATE TABLE IF NOT EXISTS __toasty_migrations (
378 id BIGINT PRIMARY KEY,
379 name TEXT NOT NULL,
380 applied_at TIMESTAMP NOT NULL
381 )",
382 &[],
383 )
384 .await
385 .map_err(toasty_core::Error::driver_operation_failed)?;
386
387 let rows = self
389 .client
390 .query(
391 "SELECT id FROM __toasty_migrations ORDER BY applied_at",
392 &[],
393 )
394 .await
395 .map_err(toasty_core::Error::driver_operation_failed)?;
396
397 Ok(rows
398 .iter()
399 .map(|row| {
400 let id: i64 = row.get(0);
401 toasty_core::schema::db::AppliedMigration::new(id as u64)
402 })
403 .collect())
404 }
405
406 async fn apply_migration(
407 &mut self,
408 id: u64,
409 name: String,
410 migration: &toasty_core::schema::db::Migration,
411 ) -> Result<()> {
412 self.client
414 .execute(
415 "CREATE TABLE IF NOT EXISTS __toasty_migrations (
416 id BIGINT PRIMARY KEY,
417 name TEXT NOT NULL,
418 applied_at TIMESTAMP NOT NULL
419 )",
420 &[],
421 )
422 .await
423 .map_err(toasty_core::Error::driver_operation_failed)?;
424
425 let transaction = self
427 .client
428 .transaction()
429 .await
430 .map_err(toasty_core::Error::driver_operation_failed)?;
431
432 for statement in migration.statements() {
434 if let Err(e) = transaction
435 .batch_execute(statement)
436 .await
437 .map_err(toasty_core::Error::driver_operation_failed)
438 {
439 transaction
440 .rollback()
441 .await
442 .map_err(toasty_core::Error::driver_operation_failed)?;
443 return Err(e);
444 }
445 }
446
447 if let Err(e) = transaction
449 .execute(
450 "INSERT INTO __toasty_migrations (id, name, applied_at) VALUES ($1, $2, NOW())",
451 &[&(id as i64), &name],
452 )
453 .await
454 .map_err(toasty_core::Error::driver_operation_failed)
455 {
456 transaction
457 .rollback()
458 .await
459 .map_err(toasty_core::Error::driver_operation_failed)?;
460 return Err(e);
461 }
462
463 transaction
465 .commit()
466 .await
467 .map_err(toasty_core::Error::driver_operation_failed)?;
468 Ok(())
469 }
470}