Skip to main content

toasty_driver_mysql/
lib.rs

1#![warn(missing_docs)]
2#![allow(clippy::needless_range_loop)]
3
4//! Toasty driver for [MySQL](https://www.mysql.com/) using
5//! [`mysql_async`](https://docs.rs/mysql_async).
6//!
7//! # Examples
8//!
9//! ```no_run
10//! use toasty_driver_mysql::MySQL;
11//!
12//! let driver = MySQL::new("mysql://localhost/mydb").unwrap();
13//! ```
14
15mod value;
16pub(crate) use value::Value;
17
18use async_trait::async_trait;
19use mysql_async::{
20    Conn, OptsBuilder,
21    prelude::{Queryable, ToValue},
22};
23use std::{borrow::Cow, sync::Arc};
24use toasty_core::{
25    Result, Schema,
26    driver::{Capability, Driver, ExecResponse, Operation},
27    schema::db::{self, Migration, SchemaDiff, Table},
28    stmt::{self, ValueRecord},
29};
30use toasty_sql::{self as sql};
31use url::Url;
32
33/// A MySQL [`Driver`] that connects via `mysql_async`.
34///
35/// # Examples
36///
37/// ```no_run
38/// use toasty_driver_mysql::MySQL;
39///
40/// let driver = MySQL::new("mysql://localhost/mydb").unwrap();
41/// ```
42#[derive(Debug)]
43pub struct MySQL {
44    url: String,
45    opts: OptsBuilder,
46}
47
48impl MySQL {
49    /// Create a new MySQL driver from a connection URL.
50    ///
51    /// The URL must use the `mysql` scheme and include a database path, e.g.
52    /// `mysql://user:pass@host:3306/dbname`.
53    pub fn new(url: impl Into<String>) -> Result<Self> {
54        let url_str = url.into();
55        let url = Url::parse(&url_str).map_err(toasty_core::Error::driver_operation_failed)?;
56
57        if url.scheme() != "mysql" {
58            return Err(toasty_core::Error::invalid_connection_url(format!(
59                "connection url does not have a `mysql` scheme; url={}",
60                url
61            )));
62        }
63
64        url.host_str().ok_or_else(|| {
65            toasty_core::Error::invalid_connection_url(format!(
66                "missing host in connection URL; url={}",
67                url
68            ))
69        })?;
70
71        if url.path().is_empty() {
72            return Err(toasty_core::Error::invalid_connection_url(format!(
73                "no database specified - missing path in connection URL; url={}",
74                url
75            )));
76        }
77
78        let opts = mysql_async::Opts::from_url(url.as_ref())
79            .map_err(toasty_core::Error::driver_operation_failed)?;
80        let opts = mysql_async::OptsBuilder::from_opts(opts).client_found_rows(true);
81
82        Ok(Self { url: url_str, opts })
83    }
84}
85
86#[async_trait]
87impl Driver for MySQL {
88    fn url(&self) -> Cow<'_, str> {
89        Cow::Borrowed(&self.url)
90    }
91
92    fn capability(&self) -> &'static Capability {
93        &Capability::MYSQL
94    }
95
96    async fn connect(&self) -> Result<Box<dyn toasty_core::driver::Connection>> {
97        let conn = Conn::new(self.opts.clone())
98            .await
99            .map_err(toasty_core::Error::driver_operation_failed)?;
100        Ok(Box::new(Connection::new(conn)))
101    }
102
103    fn generate_migration(&self, schema_diff: &SchemaDiff<'_>) -> Migration {
104        let statements = sql::MigrationStatement::from_diff(schema_diff, &Capability::MYSQL);
105
106        let sql_strings: Vec<String> = statements
107            .iter()
108            .map(|stmt| sql::Serializer::mysql(stmt.schema()).serialize(stmt.statement()))
109            .collect();
110
111        Migration::new_sql_with_breakpoints(&sql_strings)
112    }
113
114    async fn reset_db(&self) -> toasty_core::Result<()> {
115        let mut conn = Conn::new(self.opts.clone())
116            .await
117            .map_err(toasty_core::Error::driver_operation_failed)?;
118
119        let dbname = conn
120            .opts()
121            .db_name()
122            .ok_or_else(|| {
123                toasty_core::Error::invalid_connection_url("no database name configured")
124            })?
125            .to_string();
126
127        conn.query_drop(format!("DROP DATABASE IF EXISTS `{}`", dbname))
128            .await
129            .map_err(toasty_core::Error::driver_operation_failed)?;
130
131        conn.query_drop(format!("CREATE DATABASE `{}`", dbname))
132            .await
133            .map_err(toasty_core::Error::driver_operation_failed)?;
134
135        conn.query_drop(format!("USE `{}`", dbname))
136            .await
137            .map_err(toasty_core::Error::driver_operation_failed)?;
138
139        Ok(())
140    }
141}
142
143/// An open connection to a MySQL database.
144#[derive(Debug)]
145pub struct Connection {
146    conn: Conn,
147}
148
149impl Connection {
150    /// Wrap an existing [`mysql_async::Conn`] as a Toasty connection.
151    pub fn new(conn: Conn) -> Self {
152        Self { conn }
153    }
154
155    /// Create a table and its indices from a schema definition.
156    pub async fn create_table(&mut self, schema: &db::Schema, table: &Table) -> Result<()> {
157        let serializer = sql::Serializer::mysql(schema);
158
159        let sql = serializer.serialize(&sql::Statement::create_table(table, &Capability::MYSQL));
160
161        self.conn
162            .exec_drop(&sql, ())
163            .await
164            .map_err(toasty_core::Error::driver_operation_failed)?;
165
166        for index in &table.indices {
167            if index.primary_key {
168                continue;
169            }
170
171            let sql = serializer.serialize(&sql::Statement::create_index(index));
172
173            self.conn
174                .exec_drop(&sql, ())
175                .await
176                .map_err(toasty_core::Error::driver_operation_failed)?;
177        }
178
179        Ok(())
180    }
181}
182
183impl From<Conn> for Connection {
184    fn from(conn: Conn) -> Self {
185        Self { conn }
186    }
187}
188
189#[async_trait]
190impl toasty_core::driver::Connection for Connection {
191    async fn exec(&mut self, schema: &Arc<Schema>, op: Operation) -> Result<ExecResponse> {
192        tracing::trace!(driver = "mysql", op = %op.name(), "driver exec");
193
194        let (sql, typed_params, ret, last_insert_id_hack) = match op {
195            Operation::QuerySql(op) => (
196                sql::Statement::from(op.stmt),
197                op.params,
198                op.ret,
199                op.last_insert_id_hack,
200            ),
201            Operation::Transaction(op) => {
202                let sql = sql::Serializer::mysql(&schema.db).serialize_transaction(&op);
203                self.conn.query_drop(sql).await.map_err(|e| match e {
204                    mysql_async::Error::Server(se) => match se.code {
205                        1213 => toasty_core::Error::serialization_failure(se.message),
206                        1792 => toasty_core::Error::read_only_transaction(se.message),
207                        _ => toasty_core::Error::driver_operation_failed(
208                            mysql_async::Error::Server(se),
209                        ),
210                    },
211                    other => toasty_core::Error::driver_operation_failed(other),
212                })?;
213                return Ok(ExecResponse::count(0));
214            }
215            op => todo!("op={:#?}", op),
216        };
217
218        let (sql_as_str, arg_order) =
219            sql::Serializer::mysql(&schema.db).serialize_with_arg_order(&sql);
220
221        tracing::debug!(db.system = "mysql", db.statement = %sql_as_str, params = typed_params.len(), "executing SQL");
222
223        // MySQL uses positional `?` without indices, so params must be reordered
224        // to match the order `Expr::Arg(n)` placeholders appear in the SQL.
225        let params: Vec<_> = arg_order
226            .iter()
227            .map(|&pos| Value::from(typed_params[pos].value.clone()))
228            .collect();
229        let args = params
230            .iter()
231            .map(|param| param.to_value())
232            .collect::<Vec<_>>();
233
234        let statement = self
235            .conn
236            .prep(&sql_as_str)
237            .await
238            .map_err(toasty_core::Error::driver_operation_failed)?;
239
240        if ret.is_none() {
241            let count = self
242                .conn
243                .exec_iter(&statement, mysql_async::Params::Positional(args))
244                .await
245                .map_err(toasty_core::Error::driver_operation_failed)?
246                .affected_rows();
247
248            // Handle the last_insert_id_hack for MySQL INSERT with RETURNING
249            if let Some(num_rows) = last_insert_id_hack {
250                // Assert the previous statement was an INSERT
251                assert!(
252                    matches!(sql, sql::Statement::Insert(_)),
253                    "last_insert_id_hack should only be used with INSERT statements"
254                );
255
256                // Execute SELECT LAST_INSERT_ID() on the same connection
257                let first_id: u64 = self
258                    .conn
259                    .query_first("SELECT LAST_INSERT_ID()")
260                    .await
261                    .map_err(toasty_core::Error::driver_operation_failed)?
262                    .ok_or_else(|| {
263                        toasty_core::Error::driver_operation_failed(std::io::Error::other(
264                            "LAST_INSERT_ID() returned no rows",
265                        ))
266                    })?;
267
268                // Generate rows with sequential IDs
269                let results = (0..num_rows).map(move |offset| {
270                    let id = first_id + offset;
271                    // Return a record with a single field containing the ID
272                    Ok(ValueRecord::from_vec(vec![stmt::Value::U64(id)]))
273                });
274
275                return Ok(ExecResponse::value_stream(stmt::ValueStream::from_iter(
276                    results,
277                )));
278            }
279
280            return Ok(ExecResponse::count(count));
281        }
282
283        let rows: Vec<mysql_async::Row> = self
284            .conn
285            .exec(&statement, &args)
286            .await
287            .map_err(toasty_core::Error::driver_operation_failed)?;
288
289        if let Some(returning) = ret {
290            let results = rows.into_iter().map(move |mut row| {
291                assert_eq!(
292                    row.len(),
293                    returning.len(),
294                    "row={row:#?}; returning={returning:#?}"
295                );
296
297                let mut results = Vec::new();
298                for i in 0..row.len() {
299                    let column = &row.columns()[i];
300                    results.push(Value::from_sql(i, &mut row, column, &returning[i]).into_inner());
301                }
302
303                Ok(ValueRecord::from_vec(results))
304            });
305
306            Ok(ExecResponse::value_stream(stmt::ValueStream::from_iter(
307                results,
308            )))
309        } else {
310            let [row] = &rows[..] else { todo!() };
311            let total = row.get::<i64, usize>(0).unwrap();
312            let condition_matched = row.get::<i64, usize>(1).unwrap();
313
314            if total == condition_matched {
315                Ok(ExecResponse::count(total as _))
316            } else {
317                Err(toasty_core::Error::condition_failed(
318                    "update condition did not match",
319                ))
320            }
321        }
322    }
323
324    async fn push_schema(&mut self, schema: &Schema) -> Result<()> {
325        for table in &schema.db.tables {
326            tracing::debug!(table = %table.name, "creating table");
327            self.create_table(&schema.db, table).await?;
328        }
329        Ok(())
330    }
331
332    async fn applied_migrations(
333        &mut self,
334    ) -> Result<Vec<toasty_core::schema::db::AppliedMigration>> {
335        // Ensure the migrations table exists
336        self.conn
337            .exec_drop(
338                "CREATE TABLE IF NOT EXISTS __toasty_migrations (
339                id BIGINT UNSIGNED PRIMARY KEY,
340                name TEXT NOT NULL,
341                applied_at TIMESTAMP NOT NULL
342            )",
343                (),
344            )
345            .await
346            .map_err(toasty_core::Error::driver_operation_failed)?;
347
348        // Query all applied migrations
349        let rows: Vec<u64> = self
350            .conn
351            .exec("SELECT id FROM __toasty_migrations ORDER BY applied_at", ())
352            .await
353            .map_err(toasty_core::Error::driver_operation_failed)?;
354
355        Ok(rows
356            .into_iter()
357            .map(toasty_core::schema::db::AppliedMigration::new)
358            .collect())
359    }
360
361    async fn apply_migration(
362        &mut self,
363        id: u64,
364        name: &str,
365        migration: &toasty_core::schema::db::Migration,
366    ) -> Result<()> {
367        tracing::info!(id = id, name = %name, "applying migration");
368        // Ensure the migrations table exists
369        self.conn
370            .exec_drop(
371                "CREATE TABLE IF NOT EXISTS __toasty_migrations (
372                id BIGINT UNSIGNED PRIMARY KEY,
373                name TEXT NOT NULL,
374                applied_at TIMESTAMP NOT NULL
375            )",
376                (),
377            )
378            .await
379            .map_err(toasty_core::Error::driver_operation_failed)?;
380
381        // Start transaction
382        let mut transaction = self
383            .conn
384            .start_transaction(Default::default())
385            .await
386            .map_err(toasty_core::Error::driver_operation_failed)?;
387
388        // Execute each migration statement
389        for statement in migration.statements() {
390            if let Err(e) = transaction
391                .query_drop(statement)
392                .await
393                .map_err(toasty_core::Error::driver_operation_failed)
394            {
395                transaction
396                    .rollback()
397                    .await
398                    .map_err(toasty_core::Error::driver_operation_failed)?;
399                return Err(e);
400            }
401        }
402
403        // Record the migration
404        if let Err(e) = transaction
405            .exec_drop(
406                "INSERT INTO __toasty_migrations (id, name, applied_at) VALUES (?, ?, NOW())",
407                (id, name),
408            )
409            .await
410            .map_err(toasty_core::Error::driver_operation_failed)
411        {
412            transaction
413                .rollback()
414                .await
415                .map_err(toasty_core::Error::driver_operation_failed)?;
416            return Err(e);
417        }
418
419        // Commit transaction
420        transaction
421            .commit()
422            .await
423            .map_err(toasty_core::Error::driver_operation_failed)?;
424        Ok(())
425    }
426}