toasty_driver_mysql/
lib.rs

1#![warn(missing_docs)]
2#![allow(clippy::needless_range_loop)]
3
4//! Toasty driver for [MySQL](https://www.mysql.com/) using
5//! [`mysql_async`](https://docs.rs/mysql_async).
6//!
7//! # Examples
8//!
9//! ```no_run
10//! use toasty_driver_mysql::MySQL;
11//!
12//! let driver = MySQL::new("mysql://localhost/mydb").unwrap();
13//! ```
14
15mod value;
16pub(crate) use value::Value;
17
18use async_trait::async_trait;
19use mysql_async::{
20    Conn, OptsBuilder,
21    prelude::{Queryable, ToValue},
22};
23use std::{borrow::Cow, sync::Arc};
24use toasty_core::{
25    Result, Schema,
26    driver::{Capability, Driver, ExecResponse, Operation},
27    schema::db::{self, Migration, SchemaDiff, Table},
28    stmt::{self, ValueRecord},
29};
30use toasty_sql::{self as sql, TypedValue};
31use url::Url;
32
33/// A MySQL [`Driver`] that connects via `mysql_async`.
34///
35/// # Examples
36///
37/// ```no_run
38/// use toasty_driver_mysql::MySQL;
39///
40/// let driver = MySQL::new("mysql://localhost/mydb").unwrap();
41/// ```
42#[derive(Debug)]
43pub struct MySQL {
44    url: String,
45    opts: OptsBuilder,
46}
47
48impl MySQL {
49    /// Create a new MySQL driver from a connection URL.
50    ///
51    /// The URL must use the `mysql` scheme and include a database path, e.g.
52    /// `mysql://user:pass@host:3306/dbname`.
53    pub fn new(url: impl Into<String>) -> Result<Self> {
54        let url_str = url.into();
55        let url = Url::parse(&url_str).map_err(toasty_core::Error::driver_operation_failed)?;
56
57        if url.scheme() != "mysql" {
58            return Err(toasty_core::Error::invalid_connection_url(format!(
59                "connection url does not have a `mysql` scheme; url={}",
60                url
61            )));
62        }
63
64        url.host_str().ok_or_else(|| {
65            toasty_core::Error::invalid_connection_url(format!(
66                "missing host in connection URL; url={}",
67                url
68            ))
69        })?;
70
71        if url.path().is_empty() {
72            return Err(toasty_core::Error::invalid_connection_url(format!(
73                "no database specified - missing path in connection URL; url={}",
74                url
75            )));
76        }
77
78        let opts = mysql_async::Opts::from_url(url.as_ref())
79            .map_err(toasty_core::Error::driver_operation_failed)?;
80        let opts = mysql_async::OptsBuilder::from_opts(opts).client_found_rows(true);
81
82        Ok(Self { url: url_str, opts })
83    }
84}
85
86#[async_trait]
87impl Driver for MySQL {
88    fn url(&self) -> Cow<'_, str> {
89        Cow::Borrowed(&self.url)
90    }
91
92    fn capability(&self) -> &'static Capability {
93        &Capability::MYSQL
94    }
95
96    async fn connect(&self) -> Result<Box<dyn toasty_core::driver::Connection>> {
97        let conn = Conn::new(self.opts.clone())
98            .await
99            .map_err(toasty_core::Error::driver_operation_failed)?;
100        Ok(Box::new(Connection::new(conn)))
101    }
102
103    fn generate_migration(&self, schema_diff: &SchemaDiff<'_>) -> Migration {
104        let statements = sql::MigrationStatement::from_diff(schema_diff, &Capability::MYSQL);
105
106        let sql_strings: Vec<String> = statements
107            .iter()
108            .map(|stmt| {
109                let mut params = Vec::<TypedValue>::new();
110                let sql =
111                    sql::Serializer::mysql(stmt.schema()).serialize(stmt.statement(), &mut params);
112                assert!(
113                    params.is_empty(),
114                    "migration statements should not have parameters"
115                );
116                sql
117            })
118            .collect();
119
120        Migration::new_sql_with_breakpoints(&sql_strings)
121    }
122
123    async fn reset_db(&self) -> toasty_core::Result<()> {
124        let mut conn = Conn::new(self.opts.clone())
125            .await
126            .map_err(toasty_core::Error::driver_operation_failed)?;
127
128        let dbname = conn
129            .opts()
130            .db_name()
131            .ok_or_else(|| {
132                toasty_core::Error::invalid_connection_url("no database name configured")
133            })?
134            .to_string();
135
136        conn.query_drop(format!("DROP DATABASE IF EXISTS `{}`", dbname))
137            .await
138            .map_err(toasty_core::Error::driver_operation_failed)?;
139
140        conn.query_drop(format!("CREATE DATABASE `{}`", dbname))
141            .await
142            .map_err(toasty_core::Error::driver_operation_failed)?;
143
144        conn.query_drop(format!("USE `{}`", dbname))
145            .await
146            .map_err(toasty_core::Error::driver_operation_failed)?;
147
148        Ok(())
149    }
150}
151
152/// An open connection to a MySQL database.
153#[derive(Debug)]
154pub struct Connection {
155    conn: Conn,
156}
157
158impl Connection {
159    /// Wrap an existing [`mysql_async::Conn`] as a Toasty connection.
160    pub fn new(conn: Conn) -> Self {
161        Self { conn }
162    }
163
164    /// Create a table and its indices from a schema definition.
165    pub async fn create_table(&mut self, schema: &db::Schema, table: &Table) -> Result<()> {
166        let serializer = sql::Serializer::mysql(schema);
167
168        let mut params: Vec<toasty_sql::TypedValue> = Vec::new();
169
170        let sql = serializer.serialize(
171            &sql::Statement::create_table(table, &Capability::MYSQL),
172            &mut params,
173        );
174
175        assert!(
176            params.is_empty(),
177            "creating a table shouldn't involve any parameters"
178        );
179
180        self.conn
181            .exec_drop(&sql, ())
182            .await
183            .map_err(toasty_core::Error::driver_operation_failed)?;
184
185        for index in &table.indices {
186            if index.primary_key {
187                continue;
188            }
189
190            let sql = serializer.serialize(&sql::Statement::create_index(index), &mut params);
191
192            assert!(
193                params.is_empty(),
194                "creating an index shouldn't involve any parameters"
195            );
196
197            self.conn
198                .exec_drop(&sql, ())
199                .await
200                .map_err(toasty_core::Error::driver_operation_failed)?;
201        }
202
203        Ok(())
204    }
205}
206
207impl From<Conn> for Connection {
208    fn from(conn: Conn) -> Self {
209        Self { conn }
210    }
211}
212
213#[async_trait]
214impl toasty_core::driver::Connection for Connection {
215    async fn exec(&mut self, schema: &Arc<Schema>, op: Operation) -> Result<ExecResponse> {
216        tracing::trace!(driver = "mysql", op = %op.name(), "driver exec");
217
218        let (sql, ret, last_insert_id_hack): (sql::Statement, _, _) = match op {
219            Operation::QuerySql(op) => (op.stmt.into(), op.ret, op.last_insert_id_hack),
220            Operation::Transaction(op) => {
221                let sql = sql::Serializer::mysql(&schema.db).serialize_transaction(&op);
222                self.conn.query_drop(sql).await.map_err(|e| match e {
223                    mysql_async::Error::Server(se) => match se.code {
224                        1213 => toasty_core::Error::serialization_failure(se.message),
225                        1792 => toasty_core::Error::read_only_transaction(se.message),
226                        _ => toasty_core::Error::driver_operation_failed(
227                            mysql_async::Error::Server(se),
228                        ),
229                    },
230                    other => toasty_core::Error::driver_operation_failed(other),
231                })?;
232                return Ok(ExecResponse::count(0));
233            }
234            op => todo!("op={:#?}", op),
235        };
236
237        let mut params: Vec<toasty_sql::TypedValue> = Vec::new();
238
239        let sql_as_str = sql::Serializer::mysql(&schema.db).serialize(&sql, &mut params);
240
241        tracing::debug!(db.system = "mysql", db.statement = %sql_as_str, params = params.len(), "executing SQL");
242
243        let params = params
244            .into_iter()
245            .map(|tv| Value::from(tv.value))
246            .collect::<Vec<_>>();
247        let args = params
248            .iter()
249            .map(|param| param.to_value())
250            .collect::<Vec<_>>();
251
252        let statement = self
253            .conn
254            .prep(&sql_as_str)
255            .await
256            .map_err(toasty_core::Error::driver_operation_failed)?;
257
258        if ret.is_none() {
259            let count = self
260                .conn
261                .exec_iter(&statement, mysql_async::Params::Positional(args))
262                .await
263                .map_err(toasty_core::Error::driver_operation_failed)?
264                .affected_rows();
265
266            // Handle the last_insert_id_hack for MySQL INSERT with RETURNING
267            if let Some(num_rows) = last_insert_id_hack {
268                // Assert the previous statement was an INSERT
269                assert!(
270                    matches!(sql, sql::Statement::Insert(_)),
271                    "last_insert_id_hack should only be used with INSERT statements"
272                );
273
274                // Execute SELECT LAST_INSERT_ID() on the same connection
275                let first_id: u64 = self
276                    .conn
277                    .query_first("SELECT LAST_INSERT_ID()")
278                    .await
279                    .map_err(toasty_core::Error::driver_operation_failed)?
280                    .ok_or_else(|| {
281                        toasty_core::Error::driver_operation_failed(std::io::Error::other(
282                            "LAST_INSERT_ID() returned no rows",
283                        ))
284                    })?;
285
286                // Generate rows with sequential IDs
287                let results = (0..num_rows).map(move |offset| {
288                    let id = first_id + offset;
289                    // Return a record with a single field containing the ID
290                    Ok(ValueRecord::from_vec(vec![stmt::Value::U64(id)]))
291                });
292
293                return Ok(ExecResponse::value_stream(stmt::ValueStream::from_iter(
294                    results,
295                )));
296            }
297
298            return Ok(ExecResponse::count(count));
299        }
300
301        let rows: Vec<mysql_async::Row> = self
302            .conn
303            .exec(&statement, &args)
304            .await
305            .map_err(toasty_core::Error::driver_operation_failed)?;
306
307        if let Some(returning) = ret {
308            let results = rows.into_iter().map(move |mut row| {
309                assert_eq!(
310                    row.len(),
311                    returning.len(),
312                    "row={row:#?}; returning={returning:#?}"
313                );
314
315                let mut results = Vec::new();
316                for i in 0..row.len() {
317                    let column = &row.columns()[i];
318                    results.push(Value::from_sql(i, &mut row, column, &returning[i]).into_inner());
319                }
320
321                Ok(ValueRecord::from_vec(results))
322            });
323
324            Ok(ExecResponse::value_stream(stmt::ValueStream::from_iter(
325                results,
326            )))
327        } else {
328            let [row] = &rows[..] else { todo!() };
329            let total = row.get::<i64, usize>(0).unwrap();
330            let condition_matched = row.get::<i64, usize>(1).unwrap();
331
332            if total == condition_matched {
333                Ok(ExecResponse::count(total as _))
334            } else {
335                Err(toasty_core::Error::condition_failed(
336                    "update condition did not match",
337                ))
338            }
339        }
340    }
341
342    async fn push_schema(&mut self, schema: &Schema) -> Result<()> {
343        for table in &schema.db.tables {
344            tracing::debug!(table = %table.name, "creating table");
345            self.create_table(&schema.db, table).await?;
346        }
347        Ok(())
348    }
349
350    async fn applied_migrations(
351        &mut self,
352    ) -> Result<Vec<toasty_core::schema::db::AppliedMigration>> {
353        // Ensure the migrations table exists
354        self.conn
355            .exec_drop(
356                "CREATE TABLE IF NOT EXISTS __toasty_migrations (
357                id BIGINT UNSIGNED PRIMARY KEY,
358                name TEXT NOT NULL,
359                applied_at TIMESTAMP NOT NULL
360            )",
361                (),
362            )
363            .await
364            .map_err(toasty_core::Error::driver_operation_failed)?;
365
366        // Query all applied migrations
367        let rows: Vec<u64> = self
368            .conn
369            .exec("SELECT id FROM __toasty_migrations ORDER BY applied_at", ())
370            .await
371            .map_err(toasty_core::Error::driver_operation_failed)?;
372
373        Ok(rows
374            .into_iter()
375            .map(toasty_core::schema::db::AppliedMigration::new)
376            .collect())
377    }
378
379    async fn apply_migration(
380        &mut self,
381        id: u64,
382        name: &str,
383        migration: &toasty_core::schema::db::Migration,
384    ) -> Result<()> {
385        tracing::info!(id = id, name = %name, "applying migration");
386        // Ensure the migrations table exists
387        self.conn
388            .exec_drop(
389                "CREATE TABLE IF NOT EXISTS __toasty_migrations (
390                id BIGINT UNSIGNED PRIMARY KEY,
391                name TEXT NOT NULL,
392                applied_at TIMESTAMP NOT NULL
393            )",
394                (),
395            )
396            .await
397            .map_err(toasty_core::Error::driver_operation_failed)?;
398
399        // Start transaction
400        let mut transaction = self
401            .conn
402            .start_transaction(Default::default())
403            .await
404            .map_err(toasty_core::Error::driver_operation_failed)?;
405
406        // Execute each migration statement
407        for statement in migration.statements() {
408            if let Err(e) = transaction
409                .query_drop(statement)
410                .await
411                .map_err(toasty_core::Error::driver_operation_failed)
412            {
413                transaction
414                    .rollback()
415                    .await
416                    .map_err(toasty_core::Error::driver_operation_failed)?;
417                return Err(e);
418            }
419        }
420
421        // Record the migration
422        if let Err(e) = transaction
423            .exec_drop(
424                "INSERT INTO __toasty_migrations (id, name, applied_at) VALUES (?, ?, NOW())",
425                (id, name),
426            )
427            .await
428            .map_err(toasty_core::Error::driver_operation_failed)
429        {
430            transaction
431                .rollback()
432                .await
433                .map_err(toasty_core::Error::driver_operation_failed)?;
434            return Err(e);
435        }
436
437        // Commit transaction
438        transaction
439            .commit()
440            .await
441            .map_err(toasty_core::Error::driver_operation_failed)?;
442        Ok(())
443    }
444}