// toasty_driver_mysql/lib.rs

1#![warn(missing_docs)]
2#![allow(clippy::needless_range_loop)]
3
4//! Toasty driver for [MySQL](https://www.mysql.com/) using
5//! [`mysql_async`](https://docs.rs/mysql_async).
6//!
7//! # Examples
8//!
9//! ```no_run
10//! use toasty_driver_mysql::MySQL;
11//!
12//! let driver = MySQL::new("mysql://localhost/mydb").unwrap();
13//! ```
14
15mod value;
16pub(crate) use value::Value;
17
18use async_trait::async_trait;
19use mysql_async::{
20    Conn, OptsBuilder,
21    prelude::{Queryable, ToValue},
22};
23use std::{borrow::Cow, cell::Cell, sync::Arc};
24use toasty_core::{
25    Result, Schema,
26    driver::{
27        Capability, Driver, ExecResponse, Operation,
28        operation::{Transaction, TransactionMode},
29    },
30    schema::{
31        db::{self, Migration, Table},
32        diff,
33    },
34    stmt::{self, ValueRecord},
35};
36use toasty_sql::{self as sql};
37use url::Url;
38
39/// Classifies a `mysql_async::Error` into a Toasty error.
40///
41/// `Error::Io` (any TCP/TLS-level fault) and the IO-shaped `Driver`
42/// variants (`ConnectionClosed`, `PoolDisconnected`) become
43/// `ConnectionLost`. `Server` errors with known SQLSTATE codes are
44/// mapped to typed variants. Everything else is
45/// `DriverOperationFailed`.
46fn classify_mysql_error(e: mysql_async::Error) -> toasty_core::Error {
47    use mysql_async::{DriverError, Error};
48    match e {
49        Error::Io(_) => toasty_core::Error::connection_lost(e),
50        Error::Driver(DriverError::ConnectionClosed | DriverError::PoolDisconnected) => {
51            toasty_core::Error::connection_lost(e)
52        }
53        Error::Server(se) => match se.code {
54            1213 => toasty_core::Error::serialization_failure(se.message),
55            1792 => toasty_core::Error::read_only_transaction(se.message),
56            _ => toasty_core::Error::driver_operation_failed(Error::Server(se)),
57        },
58        other => toasty_core::Error::driver_operation_failed(other),
59    }
60}
61
62/// Classify a `mysql_async::Error`, also flipping the connection's
63/// validity flag if the error indicates the connection is gone.
64fn record_mysql_err(valid: &Cell<bool>, e: mysql_async::Error) -> toasty_core::Error {
65    let err = classify_mysql_error(e);
66    if err.is_connection_lost() {
67        valid.set(false);
68    }
69    err
70}
71
/// A MySQL [`Driver`] that connects via `mysql_async`.
///
/// # Examples
///
/// ```no_run
/// use toasty_driver_mysql::MySQL;
///
/// let driver = MySQL::new("mysql://localhost/mydb").unwrap();
/// ```
#[derive(Debug)]
pub struct MySQL {
    /// The connection URL exactly as it was passed to `MySQL::new`;
    /// handed back unchanged by `Driver::url`.
    url: String,
    /// Connection options parsed from `url`, with `client_found_rows`
    /// enabled. Cloned each time a new connection is opened.
    opts: OptsBuilder,
}
86
87impl MySQL {
88    /// Create a new MySQL driver from a connection URL.
89    ///
90    /// The URL must use the `mysql` scheme and include a database path, e.g.
91    /// `mysql://user:pass@host:3306/dbname`.
92    pub fn new(url: impl Into<String>) -> Result<Self> {
93        let url_str = url.into();
94        let url = Url::parse(&url_str).map_err(toasty_core::Error::driver_operation_failed)?;
95
96        if url.scheme() != "mysql" {
97            return Err(toasty_core::Error::invalid_connection_url(format!(
98                "connection url does not have a `mysql` scheme; url={}",
99                url
100            )));
101        }
102
103        url.host_str().ok_or_else(|| {
104            toasty_core::Error::invalid_connection_url(format!(
105                "missing host in connection URL; url={}",
106                url
107            ))
108        })?;
109
110        if url.path().is_empty() {
111            return Err(toasty_core::Error::invalid_connection_url(format!(
112                "no database specified - missing path in connection URL; url={}",
113                url
114            )));
115        }
116
117        let opts = mysql_async::Opts::from_url(url.as_ref())
118            .map_err(toasty_core::Error::driver_operation_failed)?;
119        let opts = mysql_async::OptsBuilder::from_opts(opts).client_found_rows(true);
120
121        Ok(Self { url: url_str, opts })
122    }
123}
124
125#[async_trait]
126impl Driver for MySQL {
127    fn url(&self) -> Cow<'_, str> {
128        Cow::Borrowed(&self.url)
129    }
130
131    fn capability(&self) -> &'static Capability {
132        &Capability::MYSQL
133    }
134
135    async fn connect(&self) -> Result<Box<dyn toasty_core::driver::Connection>> {
136        let conn = Conn::new(self.opts.clone())
137            .await
138            .map_err(classify_mysql_error)?;
139        Ok(Box::new(Connection::new(conn)))
140    }
141
142    fn generate_migration(&self, schema_diff: &diff::Schema<'_>) -> Migration {
143        let statements = sql::MigrationStatement::from_diff(schema_diff, &Capability::MYSQL);
144
145        let sql_strings: Vec<String> = statements
146            .iter()
147            .map(|stmt| sql::Serializer::mysql(stmt.schema()).serialize(stmt.statement()))
148            .collect();
149
150        Migration::new_sql_with_breakpoints(&sql_strings)
151    }
152
153    async fn reset_db(&self) -> toasty_core::Result<()> {
154        let mut conn = Conn::new(self.opts.clone())
155            .await
156            .map_err(classify_mysql_error)?;
157
158        let dbname = conn
159            .opts()
160            .db_name()
161            .ok_or_else(|| {
162                toasty_core::Error::invalid_connection_url("no database name configured")
163            })?
164            .to_string();
165
166        conn.query_drop(format!("DROP DATABASE IF EXISTS `{}`", dbname))
167            .await
168            .map_err(classify_mysql_error)?;
169
170        conn.query_drop(format!("CREATE DATABASE `{}`", dbname))
171            .await
172            .map_err(classify_mysql_error)?;
173
174        conn.query_drop(format!("USE `{}`", dbname))
175            .await
176            .map_err(classify_mysql_error)?;
177
178        Ok(())
179    }
180}
181
/// An open connection to a MySQL database.
#[derive(Debug)]
pub struct Connection {
    /// The underlying `mysql_async` connection every statement runs on.
    conn: Conn,
    /// Set to `false` once an operation on this connection has failed
    /// with a connection-lost error, or once a ping has failed.
    /// `mysql_async::Conn` does not expose a passive flag, so the driver
    /// tracks one itself. Read by `is_valid`.
    valid: Cell<bool>,
}
191
192impl Connection {
193    /// Wrap an existing [`mysql_async::Conn`] as a Toasty connection.
194    pub fn new(conn: Conn) -> Self {
195        Self {
196            conn,
197            valid: Cell::new(true),
198        }
199    }
200
201    /// Create a table and its indices from a schema definition.
202    pub async fn create_table(&mut self, schema: &db::Schema, table: &Table) -> Result<()> {
203        let serializer = sql::Serializer::mysql(schema);
204
205        let sql = serializer.serialize(&sql::Statement::create_table(table, &Capability::MYSQL));
206
207        self.conn
208            .exec_drop(&sql, ())
209            .await
210            .map_err(|e| record_mysql_err(&self.valid, e))?;
211
212        for index in &table.indices {
213            if index.primary_key {
214                continue;
215            }
216
217            let sql = serializer.serialize(&sql::Statement::create_index(index));
218
219            self.conn
220                .exec_drop(&sql, ())
221                .await
222                .map_err(|e| record_mysql_err(&self.valid, e))?;
223        }
224
225        Ok(())
226    }
227}
228
229impl From<Conn> for Connection {
230    fn from(conn: Conn) -> Self {
231        Self::new(conn)
232    }
233}
234
235#[async_trait]
236impl toasty_core::driver::Connection for Connection {
237    async fn exec(&mut self, schema: &Arc<Schema>, op: Operation) -> Result<ExecResponse> {
238        tracing::trace!(driver = "mysql", op = %op.name(), "driver exec");
239
240        let (sql, typed_params, ret, last_insert_id_hack) = match op {
241            Operation::QuerySql(op) => (
242                sql::Statement::from(op.stmt),
243                op.params,
244                op.ret,
245                op.last_insert_id_hack,
246            ),
247            Operation::Transaction(op) => {
248                // MySQL has no `BEGIN IMMEDIATE` / `BEGIN EXCLUSIVE`
249                // analogue; reject non-Default modes loudly rather than
250                // silently dropping them at the serializer.
251                if let Transaction::Start {
252                    mode: mode @ (TransactionMode::Immediate | TransactionMode::Exclusive),
253                    ..
254                } = &op
255                {
256                    return Err(toasty_core::Error::unsupported_feature(format!(
257                        "MySQL does not support TransactionMode::{mode:?}"
258                    )));
259                }
260                let sql = sql::Serializer::mysql(&schema.db).serialize_transaction(&op);
261                self.conn
262                    .query_drop(sql)
263                    .await
264                    .map_err(|e| record_mysql_err(&self.valid, e))?;
265                return Ok(ExecResponse::count(0));
266            }
267            op => todo!("op={:#?}", op),
268        };
269
270        let (sql_as_str, arg_order) =
271            sql::Serializer::mysql(&schema.db).serialize_with_arg_order(&sql);
272
273        tracing::debug!(db.system = "mysql", db.statement = %sql_as_str, params = typed_params.len(), "executing SQL");
274
275        // MySQL uses positional `?` without indices, so params must be reordered
276        // to match the order `Expr::Arg(n)` placeholders appear in the SQL.
277        let params: Vec<_> = arg_order
278            .iter()
279            .map(|&pos| Value::from(typed_params[pos].value.clone()))
280            .collect();
281        let args = params
282            .iter()
283            .map(|param| param.to_value())
284            .collect::<Vec<_>>();
285
286        let statement = self
287            .conn
288            .prep(&sql_as_str)
289            .await
290            .map_err(|e| record_mysql_err(&self.valid, e))?;
291
292        if ret.is_none() {
293            let count = self
294                .conn
295                .exec_iter(&statement, mysql_async::Params::Positional(args))
296                .await
297                .map_err(|e| record_mysql_err(&self.valid, e))?
298                .affected_rows();
299
300            // Handle the last_insert_id_hack for MySQL INSERT with RETURNING
301            if let Some(num_rows) = last_insert_id_hack {
302                // Assert the previous statement was an INSERT
303                assert!(
304                    matches!(sql, sql::Statement::Insert(_)),
305                    "last_insert_id_hack should only be used with INSERT statements"
306                );
307
308                // Execute SELECT LAST_INSERT_ID() on the same connection
309                let first_id: u64 = self
310                    .conn
311                    .query_first("SELECT LAST_INSERT_ID()")
312                    .await
313                    .map_err(|e| record_mysql_err(&self.valid, e))?
314                    .ok_or_else(|| {
315                        toasty_core::Error::driver_operation_failed(std::io::Error::other(
316                            "LAST_INSERT_ID() returned no rows",
317                        ))
318                    })?;
319
320                // Generate rows with sequential IDs
321                let results = (0..num_rows).map(move |offset| {
322                    let id = first_id + offset;
323                    // Return a record with a single field containing the ID
324                    Ok(ValueRecord::from_vec(vec![stmt::Value::U64(id)]))
325                });
326
327                return Ok(ExecResponse::value_stream(stmt::ValueStream::from_iter(
328                    results,
329                )));
330            }
331
332            return Ok(ExecResponse::count(count));
333        }
334
335        let rows: Vec<mysql_async::Row> = self
336            .conn
337            .exec(&statement, &args)
338            .await
339            .map_err(|e| record_mysql_err(&self.valid, e))?;
340
341        if let Some(returning) = ret {
342            let results = rows.into_iter().map(move |mut row| {
343                assert_eq!(
344                    row.len(),
345                    returning.len(),
346                    "row={row:#?}; returning={returning:#?}"
347                );
348
349                let mut results = Vec::new();
350                for i in 0..row.len() {
351                    let column = &row.columns()[i];
352                    results.push(Value::from_sql(i, &mut row, column, &returning[i]).into_inner());
353                }
354
355                Ok(ValueRecord::from_vec(results))
356            });
357
358            Ok(ExecResponse::value_stream(stmt::ValueStream::from_iter(
359                results,
360            )))
361        } else {
362            let [row] = &rows[..] else { todo!() };
363            let total = row.get::<i64, usize>(0).unwrap();
364            let condition_matched = row.get::<i64, usize>(1).unwrap();
365
366            if total == condition_matched {
367                Ok(ExecResponse::count(total as _))
368            } else {
369                Err(toasty_core::Error::condition_failed(
370                    "update condition did not match",
371                ))
372            }
373        }
374    }
375
376    async fn push_schema(&mut self, schema: &Schema) -> Result<()> {
377        for table in &schema.db.tables {
378            tracing::debug!(table = %table.name, "creating table");
379            self.create_table(&schema.db, table).await?;
380        }
381        Ok(())
382    }
383
384    async fn applied_migrations(
385        &mut self,
386    ) -> Result<Vec<toasty_core::schema::db::AppliedMigration>> {
387        // Ensure the migrations table exists
388        self.conn
389            .exec_drop(
390                "CREATE TABLE IF NOT EXISTS __toasty_migrations (
391                id BIGINT UNSIGNED PRIMARY KEY,
392                name TEXT NOT NULL,
393                applied_at TIMESTAMP NOT NULL
394            )",
395                (),
396            )
397            .await
398            .map_err(|e| record_mysql_err(&self.valid, e))?;
399
400        // Query all applied migrations
401        let rows: Vec<u64> = self
402            .conn
403            .exec("SELECT id FROM __toasty_migrations ORDER BY applied_at", ())
404            .await
405            .map_err(|e| record_mysql_err(&self.valid, e))?;
406
407        Ok(rows
408            .into_iter()
409            .map(toasty_core::schema::db::AppliedMigration::new)
410            .collect())
411    }
412
413    async fn apply_migration(
414        &mut self,
415        id: u64,
416        name: &str,
417        migration: &toasty_core::schema::db::Migration,
418    ) -> Result<()> {
419        tracing::info!(id = id, name = %name, "applying migration");
420        // Ensure the migrations table exists
421        self.conn
422            .exec_drop(
423                "CREATE TABLE IF NOT EXISTS __toasty_migrations (
424                id BIGINT UNSIGNED PRIMARY KEY,
425                name TEXT NOT NULL,
426                applied_at TIMESTAMP NOT NULL
427            )",
428                (),
429            )
430            .await
431            .map_err(|e| record_mysql_err(&self.valid, e))?;
432
433        // Start transaction
434        let mut transaction = self
435            .conn
436            .start_transaction(Default::default())
437            .await
438            .map_err(|e| record_mysql_err(&self.valid, e))?;
439
440        // Execute each migration statement
441        for statement in migration.statements() {
442            if let Err(e) = transaction
443                .query_drop(statement)
444                .await
445                .map_err(|e| record_mysql_err(&self.valid, e))
446            {
447                transaction
448                    .rollback()
449                    .await
450                    .map_err(|e| record_mysql_err(&self.valid, e))?;
451                return Err(e);
452            }
453        }
454
455        // Record the migration
456        if let Err(e) = transaction
457            .exec_drop(
458                "INSERT INTO __toasty_migrations (id, name, applied_at) VALUES (?, ?, NOW())",
459                (id, name),
460            )
461            .await
462            .map_err(|e| record_mysql_err(&self.valid, e))
463        {
464            transaction
465                .rollback()
466                .await
467                .map_err(|e| record_mysql_err(&self.valid, e))?;
468            return Err(e);
469        }
470
471        // Commit transaction
472        transaction
473            .commit()
474            .await
475            .map_err(|e| record_mysql_err(&self.valid, e))?;
476        Ok(())
477    }
478
    /// Returns the driver-tracked validity flag without touching the
    /// network: `true` until a connection-lost error or a failed ping
    /// has been observed on this connection.
    fn is_valid(&self) -> bool {
        self.valid.get()
    }
482
483    async fn ping(&mut self) -> Result<()> {
484        // `COM_PING` is the cheapest server round-trip in the MySQL
485        // protocol. Any failure is surfaced as `connection_lost`: the
486        // only meaningful outcome of a ping is "the connection is
487        // alive" or "evict it." Also flip the validity flag so a
488        // subsequent `is_valid` check observes the dead connection.
489        match self.conn.ping().await {
490            Ok(()) => Ok(()),
491            Err(e) => {
492                self.valid.set(false);
493                Err(toasty_core::Error::connection_lost(e))
494            }
495        }
496    }
497}