toasty_driver_mysql/
lib.rs

1#![allow(clippy::needless_range_loop)]
2
3mod value;
4pub(crate) use value::Value;
5
6use mysql_async::{
7    prelude::{Queryable, ToValue},
8    Conn, OptsBuilder,
9};
10use std::{borrow::Cow, sync::Arc};
11use toasty_core::{
12    async_trait,
13    driver::{Capability, Driver, Operation, Response},
14    schema::db::{self, Migration, SchemaDiff, Table},
15    stmt::{self, ValueRecord},
16    Result, Schema,
17};
18use toasty_sql::{self as sql, TypedValue};
19use url::Url;
20
/// Toasty driver for MySQL.
///
/// Holds the connection URL it was created from and the `mysql_async`
/// connection options derived from that URL; the options are cloned each time
/// a new connection is opened.
#[derive(Debug)]
pub struct MySQL {
    /// The connection URL as passed to [`MySQL::new`]; returned by `Driver::url`.
    url: String,
    /// Connection options parsed from `url`, with `client_found_rows` enabled.
    opts: OptsBuilder,
}
26
27impl MySQL {
28    pub fn new(url: impl Into<String>) -> Result<Self> {
29        let url_str = url.into();
30        let url = Url::parse(&url_str).map_err(toasty_core::Error::driver_operation_failed)?;
31
32        if url.scheme() != "mysql" {
33            return Err(toasty_core::Error::invalid_connection_url(format!(
34                "connection url does not have a `mysql` scheme; url={}",
35                url
36            )));
37        }
38
39        url.host_str().ok_or_else(|| {
40            toasty_core::Error::invalid_connection_url(format!(
41                "missing host in connection URL; url={}",
42                url
43            ))
44        })?;
45
46        if url.path().is_empty() {
47            return Err(toasty_core::Error::invalid_connection_url(format!(
48                "no database specified - missing path in connection URL; url={}",
49                url
50            )));
51        }
52
53        let opts = mysql_async::Opts::from_url(url.as_ref())
54            .map_err(toasty_core::Error::driver_operation_failed)?;
55        let opts = mysql_async::OptsBuilder::from_opts(opts).client_found_rows(true);
56
57        Ok(Self { url: url_str, opts })
58    }
59}
60
61#[async_trait]
62impl Driver for MySQL {
63    fn url(&self) -> Cow<'_, str> {
64        Cow::Borrowed(&self.url)
65    }
66
67    fn capability(&self) -> &'static Capability {
68        &Capability::MYSQL
69    }
70
71    async fn connect(&self) -> Result<Box<dyn toasty_core::driver::Connection>> {
72        let conn = Conn::new(self.opts.clone())
73            .await
74            .map_err(toasty_core::Error::driver_operation_failed)?;
75        Ok(Box::new(Connection::new(conn)))
76    }
77
78    fn generate_migration(&self, schema_diff: &SchemaDiff<'_>) -> Migration {
79        let statements = sql::MigrationStatement::from_diff(schema_diff, &Capability::MYSQL);
80
81        let sql_strings: Vec<String> = statements
82            .iter()
83            .map(|stmt| {
84                let mut params = Vec::<TypedValue>::new();
85                let sql =
86                    sql::Serializer::mysql(stmt.schema()).serialize(stmt.statement(), &mut params);
87                assert!(
88                    params.is_empty(),
89                    "migration statements should not have parameters"
90                );
91                sql
92            })
93            .collect();
94
95        Migration::new_sql_with_breakpoints(&sql_strings)
96    }
97
98    async fn reset_db(&self) -> toasty_core::Result<()> {
99        let mut conn = Conn::new(self.opts.clone())
100            .await
101            .map_err(toasty_core::Error::driver_operation_failed)?;
102
103        let dbname = conn
104            .opts()
105            .db_name()
106            .ok_or_else(|| {
107                toasty_core::Error::invalid_connection_url("no database name configured")
108            })?
109            .to_string();
110
111        conn.query_drop(format!("DROP DATABASE IF EXISTS `{}`", dbname))
112            .await
113            .map_err(toasty_core::Error::driver_operation_failed)?;
114
115        conn.query_drop(format!("CREATE DATABASE `{}`", dbname))
116            .await
117            .map_err(toasty_core::Error::driver_operation_failed)?;
118
119        conn.query_drop(format!("USE `{}`", dbname))
120            .await
121            .map_err(toasty_core::Error::driver_operation_failed)?;
122
123        Ok(())
124    }
125}
126
/// A single MySQL connection through which Toasty executes operations and
/// applies migrations.
#[derive(Debug)]
pub struct Connection {
    /// The underlying `mysql_async` connection.
    conn: Conn,
}
131
132impl Connection {
133    pub fn new(conn: Conn) -> Self {
134        Self { conn }
135    }
136
137    pub async fn create_table(&mut self, schema: &db::Schema, table: &Table) -> Result<()> {
138        let serializer = sql::Serializer::mysql(schema);
139
140        let mut params: Vec<toasty_sql::TypedValue> = Vec::new();
141
142        let sql = serializer.serialize(
143            &sql::Statement::create_table(table, &Capability::MYSQL),
144            &mut params,
145        );
146
147        assert!(
148            params.is_empty(),
149            "creating a table shouldn't involve any parameters"
150        );
151
152        self.conn
153            .exec_drop(&sql, ())
154            .await
155            .map_err(toasty_core::Error::driver_operation_failed)?;
156
157        for index in &table.indices {
158            if index.primary_key {
159                continue;
160            }
161
162            let sql = serializer.serialize(&sql::Statement::create_index(index), &mut params);
163
164            assert!(
165                params.is_empty(),
166                "creating an index shouldn't involve any parameters"
167            );
168
169            self.conn
170                .exec_drop(&sql, ())
171                .await
172                .map_err(toasty_core::Error::driver_operation_failed)?;
173        }
174
175        Ok(())
176    }
177}
178
179impl From<Conn> for Connection {
180    fn from(conn: Conn) -> Self {
181        Self { conn }
182    }
183}
184
185#[async_trait]
186impl toasty_core::driver::Connection for Connection {
187    async fn exec(&mut self, schema: &Arc<Schema>, op: Operation) -> Result<Response> {
188        let (sql, ret, last_insert_id_hack): (sql::Statement, _, _) = match op {
189            Operation::QuerySql(op) => (op.stmt.into(), op.ret, op.last_insert_id_hack),
190            Operation::Transaction(op) => {
191                let sql = sql::Serializer::mysql(&schema.db).serialize_transaction(&op);
192                self.conn.query_drop(sql).await.map_err(|e| match e {
193                    mysql_async::Error::Server(se) => match se.code {
194                        1213 => toasty_core::Error::serialization_failure(se.message),
195                        1792 => toasty_core::Error::read_only_transaction(se.message),
196                        _ => toasty_core::Error::driver_operation_failed(
197                            mysql_async::Error::Server(se),
198                        ),
199                    },
200                    other => toasty_core::Error::driver_operation_failed(other),
201                })?;
202                return Ok(Response::count(0));
203            }
204            op => todo!("op={:#?}", op),
205        };
206
207        let mut params: Vec<toasty_sql::TypedValue> = Vec::new();
208
209        let sql_as_str = sql::Serializer::mysql(&schema.db).serialize(&sql, &mut params);
210
211        let params = params
212            .into_iter()
213            .map(|tv| Value::from(tv.value))
214            .collect::<Vec<_>>();
215        let args = params
216            .iter()
217            .map(|param| param.to_value())
218            .collect::<Vec<_>>();
219
220        let statement = self
221            .conn
222            .prep(&sql_as_str)
223            .await
224            .map_err(toasty_core::Error::driver_operation_failed)?;
225
226        if ret.is_none() {
227            let count = self
228                .conn
229                .exec_iter(&statement, mysql_async::Params::Positional(args))
230                .await
231                .map_err(toasty_core::Error::driver_operation_failed)?
232                .affected_rows();
233
234            // Handle the last_insert_id_hack for MySQL INSERT with RETURNING
235            if let Some(num_rows) = last_insert_id_hack {
236                // Assert the previous statement was an INSERT
237                assert!(
238                    matches!(sql, sql::Statement::Insert(_)),
239                    "last_insert_id_hack should only be used with INSERT statements"
240                );
241
242                // Execute SELECT LAST_INSERT_ID() on the same connection
243                let first_id: u64 = self
244                    .conn
245                    .query_first("SELECT LAST_INSERT_ID()")
246                    .await
247                    .map_err(toasty_core::Error::driver_operation_failed)?
248                    .ok_or_else(|| {
249                        toasty_core::Error::driver_operation_failed(std::io::Error::other(
250                            "LAST_INSERT_ID() returned no rows",
251                        ))
252                    })?;
253
254                // Generate rows with sequential IDs
255                let results = (0..num_rows).map(move |offset| {
256                    let id = first_id + offset;
257                    // Return a record with a single field containing the ID
258                    Ok(ValueRecord::from_vec(vec![stmt::Value::U64(id)]))
259                });
260
261                return Ok(Response::value_stream(stmt::ValueStream::from_iter(
262                    results,
263                )));
264            }
265
266            return Ok(Response::count(count));
267        }
268
269        let rows: Vec<mysql_async::Row> = self
270            .conn
271            .exec(&statement, &args)
272            .await
273            .map_err(toasty_core::Error::driver_operation_failed)?;
274
275        if let Some(returning) = ret {
276            let results = rows.into_iter().map(move |mut row| {
277                assert_eq!(
278                    row.len(),
279                    returning.len(),
280                    "row={row:#?}; returning={returning:#?}"
281                );
282
283                let mut results = Vec::new();
284                for i in 0..row.len() {
285                    let column = &row.columns()[i];
286                    results.push(Value::from_sql(i, &mut row, column, &returning[i]).into_inner());
287                }
288
289                Ok(ValueRecord::from_vec(results))
290            });
291
292            Ok(Response::value_stream(stmt::ValueStream::from_iter(
293                results,
294            )))
295        } else {
296            let [row] = &rows[..] else { todo!() };
297            let total = row.get::<i64, usize>(0).unwrap();
298            let condition_matched = row.get::<i64, usize>(1).unwrap();
299
300            if total == condition_matched {
301                Ok(Response::count(total as _))
302            } else {
303                Err(toasty_core::Error::condition_failed(
304                    "update condition did not match",
305                ))
306            }
307        }
308    }
309
310    async fn push_schema(&mut self, schema: &Schema) -> Result<()> {
311        for table in &schema.db.tables {
312            self.create_table(&schema.db, table).await?;
313        }
314        Ok(())
315    }
316
317    async fn applied_migrations(
318        &mut self,
319    ) -> Result<Vec<toasty_core::schema::db::AppliedMigration>> {
320        // Ensure the migrations table exists
321        self.conn
322            .exec_drop(
323                "CREATE TABLE IF NOT EXISTS __toasty_migrations (
324                id BIGINT UNSIGNED PRIMARY KEY,
325                name TEXT NOT NULL,
326                applied_at TIMESTAMP NOT NULL
327            )",
328                (),
329            )
330            .await
331            .map_err(toasty_core::Error::driver_operation_failed)?;
332
333        // Query all applied migrations
334        let rows: Vec<u64> = self
335            .conn
336            .exec("SELECT id FROM __toasty_migrations ORDER BY applied_at", ())
337            .await
338            .map_err(toasty_core::Error::driver_operation_failed)?;
339
340        Ok(rows
341            .into_iter()
342            .map(toasty_core::schema::db::AppliedMigration::new)
343            .collect())
344    }
345
346    async fn apply_migration(
347        &mut self,
348        id: u64,
349        name: String,
350        migration: &toasty_core::schema::db::Migration,
351    ) -> Result<()> {
352        // Ensure the migrations table exists
353        self.conn
354            .exec_drop(
355                "CREATE TABLE IF NOT EXISTS __toasty_migrations (
356                id BIGINT UNSIGNED PRIMARY KEY,
357                name TEXT NOT NULL,
358                applied_at TIMESTAMP NOT NULL
359            )",
360                (),
361            )
362            .await
363            .map_err(toasty_core::Error::driver_operation_failed)?;
364
365        // Start transaction
366        let mut transaction = self
367            .conn
368            .start_transaction(Default::default())
369            .await
370            .map_err(toasty_core::Error::driver_operation_failed)?;
371
372        // Execute each migration statement
373        for statement in migration.statements() {
374            if let Err(e) = transaction
375                .query_drop(statement)
376                .await
377                .map_err(toasty_core::Error::driver_operation_failed)
378            {
379                transaction
380                    .rollback()
381                    .await
382                    .map_err(toasty_core::Error::driver_operation_failed)?;
383                return Err(e);
384            }
385        }
386
387        // Record the migration
388        if let Err(e) = transaction
389            .exec_drop(
390                "INSERT INTO __toasty_migrations (id, name, applied_at) VALUES (?, ?, NOW())",
391                (id, name),
392            )
393            .await
394            .map_err(toasty_core::Error::driver_operation_failed)
395        {
396            transaction
397                .rollback()
398                .await
399                .map_err(toasty_core::Error::driver_operation_failed)?;
400            return Err(e);
401        }
402
403        // Commit transaction
404        transaction
405            .commit()
406            .await
407            .map_err(toasty_core::Error::driver_operation_failed)?;
408        Ok(())
409    }
410}