toasty_driver_dynamodb/
lib.rs

1#![warn(missing_docs)]
2
3//! Toasty driver for [Amazon DynamoDB](https://aws.amazon.com/dynamodb/) using
4//! the [`aws-sdk-dynamodb`](https://docs.rs/aws-sdk-dynamodb) SDK.
5//!
6//! # Examples
7//!
8//! ```no_run
9//! # async fn example() -> toasty_core::Result<()> {
10//! use toasty_driver_dynamodb::DynamoDb;
11//!
12//! let driver = DynamoDb::from_env("dynamodb://localhost".to_string()).await?;
13//! # Ok(())
14//! # }
15//! ```
16
17mod op;
18mod r#type;
19mod value;
20
21pub(crate) use r#type::TypeExt;
22pub(crate) use value::Value;
23
24use async_trait::async_trait;
25use toasty_core::{
26    Error, Result, Schema,
27    driver::{Capability, Driver, ExecResponse, operation::Operation},
28    schema::db::{self, Column, ColumnId, Migration, SchemaDiff, Table},
29    stmt::{self, ExprContext},
30};
31
32use aws_sdk_dynamodb::{
33    Client,
34    error::SdkError,
35    operation::update_item::UpdateItemError,
36    types::{
37        AttributeDefinition, AttributeValue, BillingMode, Delete, GlobalSecondaryIndex,
38        KeySchemaElement, KeyType, KeysAndAttributes, Projection, ProjectionType, Put, PutRequest,
39        ReturnValuesOnConditionCheckFailure, TransactWriteItem, Update, WriteRequest,
40    },
41};
42use std::{borrow::Cow, collections::HashMap, sync::Arc};
43
44/// A DynamoDB [`Driver`] backed by the AWS SDK.
45///
46/// Create one with [`DynamoDb::from_env`] to load AWS credentials and region
47/// from the environment, or [`DynamoDb::new`] / [`DynamoDb::with_sdk_config`]
48/// for manual setup.
49#[derive(Debug, Clone)]
50pub struct DynamoDb {
51    url: String,
52    client: Client,
53}
54
55impl DynamoDb {
56    /// Create driver with pre-built client (backward compatible, synchronous)
57    pub fn new(url: String, client: Client) -> Self {
58        Self { url, client }
59    }
60
61    /// Create driver loading AWS config from environment (async factory)
62    /// Reads: AWS_REGION, AWS_ENDPOINT_URL_DYNAMODB, AWS credentials, etc.
63    pub async fn from_env(url: String) -> Result<Self> {
64        use aws_config::BehaviorVersion;
65
66        let sdk_config = aws_config::defaults(BehaviorVersion::latest()).load().await;
67        let client = Client::new(&sdk_config);
68        Ok(Self::new(url, client))
69    }
70
71    /// Create driver with custom SdkConfig (synchronous)
72    pub fn with_sdk_config(url: String, sdk_config: &aws_config::SdkConfig) -> Self {
73        let client = Client::new(sdk_config);
74        Self::new(url, client)
75    }
76}
77
78#[async_trait]
79impl Driver for DynamoDb {
80    fn url(&self) -> Cow<'_, str> {
81        Cow::Borrowed(&self.url)
82    }
83
84    fn capability(&self) -> &'static Capability {
85        &Capability::DYNAMODB
86    }
87
88    async fn connect(&self) -> toasty_core::Result<Box<dyn toasty_core::driver::Connection>> {
89        // Clone the shared client - cheap operation (Client uses Arc internally)
90        Ok(Box::new(Connection::new(self.client.clone())))
91    }
92
93    fn generate_migration(&self, _schema_diff: &SchemaDiff<'_>) -> Migration {
94        unimplemented!(
95            "DynamoDB migrations are not yet supported. DynamoDB schema changes require manual table updates through the AWS console or SDK."
96        )
97    }
98
99    async fn reset_db(&self) -> toasty_core::Result<()> {
100        // Use shared client directly
101        let mut exclusive_start_table_name = None;
102        loop {
103            let mut req = self.client.list_tables();
104            if let Some(start) = &exclusive_start_table_name {
105                req = req.exclusive_start_table_name(start);
106            }
107
108            let resp = req
109                .send()
110                .await
111                .map_err(toasty_core::Error::driver_operation_failed)?;
112
113            if let Some(table_names) = &resp.table_names {
114                for table_name in table_names {
115                    self.client
116                        .delete_table()
117                        .table_name(table_name)
118                        .send()
119                        .await
120                        .map_err(toasty_core::Error::driver_operation_failed)?;
121                }
122            }
123
124            exclusive_start_table_name = resp.last_evaluated_table_name;
125            if exclusive_start_table_name.is_none() {
126                break;
127            }
128        }
129
130        Ok(())
131    }
132}
133
134/// An open connection to DynamoDB.
135#[derive(Debug)]
136pub struct Connection {
137    /// Handle to the AWS SDK client
138    client: Client,
139}
140
141impl Connection {
142    /// Wrap an existing [`aws_sdk_dynamodb::Client`] as a Toasty connection.
143    pub fn new(client: Client) -> Self {
144        Self { client }
145    }
146}
147
148#[async_trait]
149impl toasty_core::driver::Connection for Connection {
150    async fn exec(&mut self, schema: &Arc<Schema>, op: Operation) -> Result<ExecResponse> {
151        self.exec2(schema, op).await
152    }
153
154    async fn push_schema(&mut self, schema: &Schema) -> Result<()> {
155        for table in &schema.db.tables {
156            tracing::debug!(table = %table.name, "creating table");
157            self.create_table(&schema.db, table, true).await?;
158        }
159        Ok(())
160    }
161
162    async fn applied_migrations(
163        &mut self,
164    ) -> Result<Vec<toasty_core::schema::db::AppliedMigration>> {
165        todo!("DynamoDB migrations are not yet implemented")
166    }
167
168    async fn apply_migration(
169        &mut self,
170        _id: u64,
171        _name: &str,
172        _migration: &toasty_core::schema::db::Migration,
173    ) -> Result<()> {
174        todo!("DynamoDB migrations are not yet implemented")
175    }
176}
177
178impl Connection {
179    async fn exec2(&mut self, schema: &Arc<Schema>, op: Operation) -> Result<ExecResponse> {
180        match op {
181            Operation::GetByKey(op) => self.exec_get_by_key(schema, op).await,
182            Operation::QueryPk(op) => self.exec_query_pk(schema, op).await,
183            Operation::DeleteByKey(op) => self.exec_delete_by_key(&schema.db, op).await,
184            Operation::UpdateByKey(op) => self.exec_update_by_key(&schema.db, op).await,
185            Operation::FindPkByIndex(op) => self.exec_find_pk_by_index(schema, op).await,
186            Operation::QuerySql(op) => {
187                assert!(
188                    op.last_insert_id_hack.is_none(),
189                    "last_insert_id_hack is MySQL-specific and should not be set for DynamoDB"
190                );
191                match op.stmt {
192                    stmt::Statement::Insert(op) => self.exec_insert(&schema.db, op).await,
193                    _ => todo!("op={:#?}", op),
194                }
195            }
196            Operation::Transaction(_) => Err(Error::unsupported_feature(
197                "transactions are not supported by the DynamoDB driver",
198            )),
199            _ => todo!("op={op:#?}"),
200        }
201    }
202}
203
204fn ddb_key(table: &Table, key: &stmt::Value) -> HashMap<String, AttributeValue> {
205    let mut ret = HashMap::new();
206
207    for (index, column) in table.primary_key_columns().enumerate() {
208        let value = match key {
209            stmt::Value::Record(record) => &record[index],
210            value => value,
211        };
212
213        ret.insert(column.name.clone(), Value::from(value.clone()).to_ddb());
214    }
215
216    ret
217}
218
219/// Convert a DynamoDB AttributeValue to stmt::Value (type-inferred).
220fn attr_value_to_stmt_value(attr: &AttributeValue) -> stmt::Value {
221    use AttributeValue as AV;
222
223    match attr {
224        AV::S(s) => stmt::Value::String(s.clone()),
225        AV::N(n) => {
226            // Try to parse as i64 first (most common), fallback to string
227            n.parse::<i64>()
228                .map(stmt::Value::I64)
229                .unwrap_or_else(|_| stmt::Value::String(n.clone()))
230        }
231        AV::Bool(b) => stmt::Value::Bool(*b),
232        AV::B(bytes) => stmt::Value::Bytes(bytes.clone().into_inner()),
233        AV::Null(_) => stmt::Value::Null,
234        // For complex types, convert to string representation
235        _ => stmt::Value::String(format!("{:?}", attr)),
236    }
237}
238
239/// Serialize a DynamoDB LastEvaluatedKey (for pagination) into stmt::Value.
240/// Format: flat record [name1, value1, name2, value2, ...]
241/// Example: { "pk": S("abc"), "sk": N("42") } → Record([String("pk"), String("abc"), String("sk"), I64(42)])
242fn serialize_ddb_cursor(last_key: &HashMap<String, AttributeValue>) -> stmt::Value {
243    let mut fields = Vec::with_capacity(last_key.len() * 2);
244
245    for (name, attr_value) in last_key {
246        fields.push(stmt::Value::String(name.clone()));
247        fields.push(attr_value_to_stmt_value(attr_value));
248    }
249
250    stmt::Value::Record(stmt::ValueRecord::from_vec(fields))
251}
252
253/// Deserialize a stmt::Value cursor into a DynamoDB ExclusiveStartKey.
254/// Expects flat record format: [name1, value1, name2, value2, ...]
255fn deserialize_ddb_cursor(cursor: &stmt::Value) -> HashMap<String, AttributeValue> {
256    let mut ret = HashMap::new();
257
258    if let stmt::Value::Record(fields) = cursor {
259        // Process pairs: [name, value, name, value, ...]
260        for chunk in fields.chunks(2) {
261            if chunk.len() == 2
262                && let (stmt::Value::String(name), value) = (&chunk[0], &chunk[1])
263            {
264                ret.insert(name.clone(), Value::from(value.clone()).to_ddb());
265            }
266        }
267    }
268
269    ret
270}
271
272fn ddb_key_schema(partition: &Column, range: Option<&Column>) -> Vec<KeySchemaElement> {
273    let mut ks = vec![];
274
275    ks.push(
276        KeySchemaElement::builder()
277            .attribute_name(&partition.name)
278            .key_type(KeyType::Hash)
279            .build()
280            .unwrap(),
281    );
282
283    if let Some(range) = range {
284        ks.push(
285            KeySchemaElement::builder()
286                .attribute_name(&range.name)
287                .key_type(KeyType::Range)
288                .build()
289                .unwrap(),
290        );
291    }
292
293    ks
294}
295
296fn item_to_record<'a, 'stmt>(
297    item: &HashMap<String, AttributeValue>,
298    columns: impl Iterator<Item = &'a Column>,
299) -> Result<stmt::ValueRecord> {
300    Ok(stmt::ValueRecord::from_vec(
301        columns
302            .map(|column| {
303                if let Some(value) = item.get(&column.name) {
304                    Value::from_ddb(&column.ty, value).into_inner()
305                } else {
306                    stmt::Value::Null
307                }
308            })
309            .collect(),
310    ))
311}
312
313fn ddb_expression(
314    cx: &ExprContext<'_, db::Schema>,
315    attrs: &mut ExprAttrs,
316    primary: bool,
317    expr: &stmt::Expr,
318) -> String {
319    match expr {
320        stmt::Expr::BinaryOp(expr_binary_op) => {
321            let lhs = ddb_expression(cx, attrs, primary, &expr_binary_op.lhs);
322            let rhs = ddb_expression(cx, attrs, primary, &expr_binary_op.rhs);
323
324            match expr_binary_op.op {
325                stmt::BinaryOp::Eq => format!("{lhs} = {rhs}"),
326                stmt::BinaryOp::Ne if primary => {
327                    todo!("!= conditions on primary key not supported")
328                }
329                stmt::BinaryOp::Ne => format!("{lhs} <> {rhs}"),
330                stmt::BinaryOp::Gt => format!("{lhs} > {rhs}"),
331                stmt::BinaryOp::Ge => format!("{lhs} >= {rhs}"),
332                stmt::BinaryOp::Lt => format!("{lhs} < {rhs}"),
333                stmt::BinaryOp::Le => format!("{lhs} <= {rhs}"),
334            }
335        }
336        stmt::Expr::Reference(expr_reference) => {
337            let column = cx.resolve_expr_reference(expr_reference).as_column_unwrap();
338            attrs.column(column).to_string()
339        }
340        stmt::Expr::Value(val) => attrs.value(val),
341        stmt::Expr::And(expr_and) => {
342            let operands = expr_and
343                .operands
344                .iter()
345                .map(|operand| ddb_expression(cx, attrs, primary, operand))
346                .collect::<Vec<_>>();
347            operands.join(" AND ")
348        }
349        stmt::Expr::Or(expr_or) => {
350            let operands = expr_or
351                .operands
352                .iter()
353                .map(|operand| ddb_expression(cx, attrs, primary, operand))
354                .collect::<Vec<_>>();
355            operands.join(" OR ")
356        }
357        stmt::Expr::InList(in_list) => {
358            let expr = ddb_expression(cx, attrs, primary, &in_list.expr);
359
360            // Extract the list items and create individual attribute values
361            let items = match &*in_list.list {
362                stmt::Expr::Value(stmt::Value::List(vals)) => vals
363                    .iter()
364                    .map(|val| attrs.value(val))
365                    .collect::<Vec<_>>()
366                    .join(", "),
367                _ => {
368                    // If it's not a literal list, treat it as a single expression
369                    ddb_expression(cx, attrs, primary, &in_list.list)
370                }
371            };
372
373            format!("{expr} IN ({items})")
374        }
375        stmt::Expr::IsNull(expr_is_null) => {
376            let inner = ddb_expression(cx, attrs, primary, &expr_is_null.expr);
377            format!("attribute_not_exists({inner})")
378        }
379        stmt::Expr::Not(expr_not) => {
380            let inner = ddb_expression(cx, attrs, primary, &expr_not.expr);
381            format!("(NOT {inner})")
382        }
383        _ => todo!("FILTER = {:#?}", expr),
384    }
385}
386
387#[derive(Default)]
388struct ExprAttrs {
389    columns: HashMap<ColumnId, String>,
390    attr_names: HashMap<String, String>,
391    attr_values: HashMap<String, AttributeValue>,
392}
393
394impl ExprAttrs {
395    fn column(&mut self, column: &Column) -> &str {
396        use std::collections::hash_map::Entry;
397
398        match self.columns.entry(column.id) {
399            Entry::Vacant(e) => {
400                let name = format!("#col_{}", column.id.index);
401                self.attr_names.insert(name.clone(), column.name.clone());
402                e.insert(name)
403            }
404            Entry::Occupied(e) => e.into_mut(),
405        }
406    }
407
408    fn value(&mut self, val: &stmt::Value) -> String {
409        self.ddb_value(Value::from(val.clone()).to_ddb())
410    }
411
412    fn ddb_value(&mut self, val: AttributeValue) -> String {
413        let i = self.attr_values.len();
414        let name = format!(":v_{i}");
415        self.attr_values.insert(name.clone(), val);
416        name
417    }
418}