1use super::{DiffContext, TableId, Type, table};
2use crate::stmt;
3
4use std::{
5 collections::{HashMap, HashSet},
6 fmt,
7 ops::Deref,
8};
9
/// A single column of a database table, as recorded in the schema.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Column {
    /// Identifier of this column (owning table plus an index).
    pub id: ColumnId,

    /// Name of the column.
    pub name: String,

    /// Statement-level type of the column (`stmt::Type`).
    // NOTE(review): presumably the application-facing type, as opposed to
    // `storage_ty` — confirm against the callers that populate it.
    pub ty: stmt::Type,

    /// Schema-level `Type` of the column. This is the type that
    /// `ColumnsDiff::from` compares when deciding whether a column changed.
    pub storage_ty: Type,

    /// Nullable flag. With the `serde` feature it is omitted from serialized
    /// output when `false` and defaults to `false` when absent on input.
    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "is_false"))]
    pub nullable: bool,

    /// Primary-key flag. With the `serde` feature it is omitted from serialized
    /// output when `false` and defaults to `false` when absent on input.
    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "is_false"))]
    pub primary_key: bool,

    /// Auto-increment flag. With the `serde` feature it is omitted from
    /// serialized output when `false` and defaults to `false` when absent on
    /// input.
    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "is_false"))]
    pub auto_increment: bool,
}
63
64#[cfg(feature = "serde")]
/// Predicate for serde's `skip_serializing_if`: returns `true` exactly when
/// the referenced flag is `false`.
///
/// Takes `&bool` because serde passes the field by reference.
fn is_false(b: &bool) -> bool {
    matches!(*b, false)
}
68
/// Identifies a column by the table that owns it and an index.
///
/// The type is `Copy` and hashable, so it can be used directly as a map or
/// set key.
#[derive(PartialEq, Eq, Clone, Copy, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ColumnId {
    /// Identifier of the table the column belongs to.
    pub table: TableId,
    /// Index component of the identifier.
    // NOTE(review): presumably the column's position within the table's column
    // list — confirm against the schema lookup code.
    pub index: usize,
}
90
91impl ColumnId {
92 pub(crate) fn placeholder() -> Self {
93 Self {
94 table: table::TableId::placeholder(),
95 index: usize::MAX,
96 }
97 }
98}
99
100impl From<&Column> for ColumnId {
101 fn from(value: &Column) -> Self {
102 value.id
103 }
104}
105
106impl fmt::Debug for ColumnId {
107 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
108 write!(fmt, "ColumnId({}/{})", self.table.0, self.index)
109 }
110}
111
/// The set of column-level changes between two lists of columns.
///
/// Built with `ColumnsDiff::from`; dereferences to the underlying list of
/// `ColumnsDiffItem`s.
pub struct ColumnsDiff<'a> {
    /// The individual add / drop / alter operations that were detected.
    items: Vec<ColumnsDiffItem<'a>>,
}
132
133impl<'a> ColumnsDiff<'a> {
134 pub fn from(cx: &DiffContext<'a>, previous: &'a [Column], next: &'a [Column]) -> Self {
140 fn has_diff(previous: &Column, next: &Column) -> bool {
141 previous.name != next.name
142 || previous.storage_ty != next.storage_ty
143 || previous.nullable != next.nullable
144 || previous.primary_key != next.primary_key
145 || previous.auto_increment != next.auto_increment
146 }
147
148 let mut items = vec![];
149 let mut add_ids: HashSet<_> = next.iter().map(|next| next.id).collect();
150
151 let next_map =
152 HashMap::<&str, &'a Column>::from_iter(next.iter().map(|to| (to.name.as_str(), to)));
153
154 for previous in previous {
155 let next = if let Some(next_id) = cx.rename_hints().get_column(previous.id) {
156 cx.next().column(next_id)
157 } else if let Some(next) = next_map.get(previous.name.as_str()) {
158 next
159 } else {
160 items.push(ColumnsDiffItem::DropColumn(previous));
161 continue;
162 };
163
164 add_ids.remove(&next.id);
165
166 if has_diff(previous, next) {
167 items.push(ColumnsDiffItem::AlterColumn { previous, next });
168 }
169 }
170
171 for column_id in add_ids {
172 items.push(ColumnsDiffItem::AddColumn(cx.next().column(column_id)));
173 }
174
175 Self { items }
176 }
177
178 pub const fn is_empty(&self) -> bool {
180 self.items.is_empty()
181 }
182}
183
184impl<'a> Deref for ColumnsDiff<'a> {
185 type Target = Vec<ColumnsDiffItem<'a>>;
186
187 fn deref(&self) -> &Self::Target {
188 &self.items
189 }
190}
191
/// One column-level change produced by `ColumnsDiff::from`.
pub enum ColumnsDiffItem<'a> {
    /// A column of the next schema that no previous column was paired with.
    AddColumn(&'a Column),
    /// A previous column that has no counterpart in the next schema.
    DropColumn(&'a Column),
    /// A previous column and its counterpart that differ in at least one
    /// compared property (which includes the name, so renames land here).
    AlterColumn {
        /// The column as it was before the change.
        previous: &'a Column,
        /// The column as it is after the change.
        next: &'a Column,
    },
}
206
207#[cfg(test)]
208mod tests {
209 use crate::schema::db::{
210 Column, ColumnId, ColumnsDiff, ColumnsDiffItem, DiffContext, PrimaryKey, RenameHints,
211 Schema, Table, TableId, Type,
212 };
213 use crate::stmt;
214
215 fn make_column(
216 table_id: usize,
217 index: usize,
218 name: &str,
219 storage_ty: Type,
220 nullable: bool,
221 ) -> Column {
222 Column {
223 id: ColumnId {
224 table: TableId(table_id),
225 index,
226 },
227 name: name.to_string(),
228 ty: stmt::Type::String, storage_ty,
230 nullable,
231 primary_key: false,
232 auto_increment: false,
233 }
234 }
235
236 fn make_schema_with_columns(table_id: usize, columns: Vec<Column>) -> Schema {
237 let mut schema = Schema::default();
238 schema.tables.push(Table {
239 id: TableId(table_id),
240 name: "test_table".to_string(),
241 columns,
242 primary_key: PrimaryKey {
243 columns: vec![],
244 index: super::super::IndexId {
245 table: TableId(table_id),
246 index: 0,
247 },
248 },
249 indices: vec![],
250 });
251 schema
252 }
253
254 #[test]
255 fn test_no_diff_same_columns() {
256 let from_cols = vec![
257 make_column(0, 0, "id", Type::Integer(8), false),
258 make_column(0, 1, "name", Type::Text, false),
259 ];
260 let to_cols = vec![
261 make_column(0, 0, "id", Type::Integer(8), false),
262 make_column(0, 1, "name", Type::Text, false),
263 ];
264
265 let from_schema = make_schema_with_columns(0, from_cols.clone());
266 let to_schema = make_schema_with_columns(0, to_cols.clone());
267 let hints = RenameHints::new();
268 let cx = DiffContext::new(&from_schema, &to_schema, &hints);
269
270 let diff = ColumnsDiff::from(&cx, &from_cols, &to_cols);
271 assert!(diff.is_empty());
272 }
273
274 #[test]
275 fn test_add_column() {
276 let from_cols = vec![make_column(0, 0, "id", Type::Integer(8), false)];
277 let to_cols = vec![
278 make_column(0, 0, "id", Type::Integer(8), false),
279 make_column(0, 1, "name", Type::Text, false),
280 ];
281
282 let from_schema = make_schema_with_columns(0, from_cols.clone());
283 let to_schema = make_schema_with_columns(0, to_cols.clone());
284 let hints = RenameHints::new();
285 let cx = DiffContext::new(&from_schema, &to_schema, &hints);
286
287 let diff = ColumnsDiff::from(&cx, &from_cols, &to_cols);
288 assert_eq!(diff.items.len(), 1);
289 assert!(matches!(diff.items[0], ColumnsDiffItem::AddColumn(_)));
290 if let ColumnsDiffItem::AddColumn(col) = diff.items[0] {
291 assert_eq!(col.name, "name");
292 }
293 }
294
295 #[test]
296 fn test_drop_column() {
297 let from_cols = vec![
298 make_column(0, 0, "id", Type::Integer(8), false),
299 make_column(0, 1, "name", Type::Text, false),
300 ];
301 let to_cols = vec![make_column(0, 0, "id", Type::Integer(8), false)];
302
303 let from_schema = make_schema_with_columns(0, from_cols.clone());
304 let to_schema = make_schema_with_columns(0, to_cols.clone());
305 let hints = RenameHints::new();
306 let cx = DiffContext::new(&from_schema, &to_schema, &hints);
307
308 let diff = ColumnsDiff::from(&cx, &from_cols, &to_cols);
309 assert_eq!(diff.items.len(), 1);
310 assert!(matches!(diff.items[0], ColumnsDiffItem::DropColumn(_)));
311 if let ColumnsDiffItem::DropColumn(col) = diff.items[0] {
312 assert_eq!(col.name, "name");
313 }
314 }
315
316 #[test]
317 fn test_alter_column_type() {
318 let from_cols = vec![make_column(0, 0, "id", Type::Integer(8), false)];
319 let to_cols = vec![make_column(0, 0, "id", Type::Text, false)];
320
321 let from_schema = make_schema_with_columns(0, from_cols.clone());
322 let to_schema = make_schema_with_columns(0, to_cols.clone());
323 let hints = RenameHints::new();
324 let cx = DiffContext::new(&from_schema, &to_schema, &hints);
325
326 let diff = ColumnsDiff::from(&cx, &from_cols, &to_cols);
327 assert_eq!(diff.items.len(), 1);
328 assert!(matches!(diff.items[0], ColumnsDiffItem::AlterColumn { .. }));
329 }
330
331 #[test]
332 fn test_alter_column_nullable() {
333 let from_cols = vec![make_column(0, 0, "id", Type::Integer(8), false)];
334 let to_cols = vec![make_column(0, 0, "id", Type::Integer(8), true)];
335
336 let from_schema = make_schema_with_columns(0, from_cols.clone());
337 let to_schema = make_schema_with_columns(0, to_cols.clone());
338 let hints = RenameHints::new();
339 let cx = DiffContext::new(&from_schema, &to_schema, &hints);
340
341 let diff = ColumnsDiff::from(&cx, &from_cols, &to_cols);
342 assert_eq!(diff.items.len(), 1);
343 assert!(matches!(diff.items[0], ColumnsDiffItem::AlterColumn { .. }));
344 }
345
346 #[test]
347 fn test_rename_column_with_hint() {
348 let from_cols = vec![make_column(0, 0, "old_name", Type::Text, false)];
350 let to_cols = vec![make_column(0, 0, "new_name", Type::Text, false)];
351
352 let from_schema = make_schema_with_columns(0, from_cols.clone());
353 let to_schema = make_schema_with_columns(0, to_cols.clone());
354
355 let mut hints = RenameHints::new();
356 hints.add_column_hint(
357 ColumnId {
358 table: TableId(0),
359 index: 0,
360 },
361 ColumnId {
362 table: TableId(0),
363 index: 0,
364 },
365 );
366 let cx = DiffContext::new(&from_schema, &to_schema, &hints);
367
368 let diff = ColumnsDiff::from(&cx, &from_cols, &to_cols);
369 assert_eq!(diff.items.len(), 1);
370 assert!(matches!(diff.items[0], ColumnsDiffItem::AlterColumn { .. }));
371 if let ColumnsDiffItem::AlterColumn { previous, next } = diff.items[0] {
372 assert_eq!(previous.name, "old_name");
373 assert_eq!(next.name, "new_name");
374 }
375 }
376
377 #[test]
378 fn test_rename_column_without_hint_is_drop_and_add() {
379 let from_cols = vec![make_column(0, 0, "old_name", Type::Text, false)];
382 let to_cols = vec![make_column(0, 0, "new_name", Type::Text, false)];
383
384 let from_schema = make_schema_with_columns(0, from_cols.clone());
385 let to_schema = make_schema_with_columns(0, to_cols.clone());
386 let hints = RenameHints::new();
387 let cx = DiffContext::new(&from_schema, &to_schema, &hints);
388
389 let diff = ColumnsDiff::from(&cx, &from_cols, &to_cols);
390 assert_eq!(diff.items.len(), 2);
391
392 let has_drop = diff
393 .items
394 .iter()
395 .any(|item| matches!(item, ColumnsDiffItem::DropColumn(_)));
396 let has_add = diff
397 .items
398 .iter()
399 .any(|item| matches!(item, ColumnsDiffItem::AddColumn(_)));
400 assert!(has_drop);
401 assert!(has_add);
402 }
403
404 #[cfg(feature = "serde")]
405 mod serde_tests {
406 use crate::schema::db::{Column, ColumnId, TableId, Type};
407 use crate::stmt;
408
409 fn base_column() -> Column {
410 Column {
411 id: ColumnId {
412 table: TableId(0),
413 index: 0,
414 },
415 name: "test".to_string(),
416 ty: stmt::Type::String,
417 storage_ty: Type::Text,
418 nullable: false,
419 primary_key: false,
420 auto_increment: false,
421 }
422 }
423
424 #[test]
425 fn false_booleans_are_omitted() {
426 let toml = toml::to_string(&base_column()).unwrap();
427 assert!(!toml.contains("nullable"), "toml: {toml}");
428 assert!(!toml.contains("primary_key"), "toml: {toml}");
429 assert!(!toml.contains("auto_increment"), "toml: {toml}");
430 }
431
432 #[test]
433 fn nullable_true_is_included() {
434 let col = Column {
435 nullable: true,
436 ..base_column()
437 };
438 let toml = toml::to_string(&col).unwrap();
439 assert!(toml.contains("nullable = true"), "toml: {toml}");
440 }
441
442 #[test]
443 fn primary_key_true_is_included() {
444 let col = Column {
445 primary_key: true,
446 ..base_column()
447 };
448 let toml = toml::to_string(&col).unwrap();
449 assert!(toml.contains("primary_key = true"), "toml: {toml}");
450 }
451
452 #[test]
453 fn auto_increment_true_is_included() {
454 let col = Column {
455 auto_increment: true,
456 ..base_column()
457 };
458 let toml = toml::to_string(&col).unwrap();
459 assert!(toml.contains("auto_increment = true"), "toml: {toml}");
460 }
461
462 #[test]
463 fn missing_bool_fields_deserialize_as_false() {
464 let toml = "name = \"test\"\nty = \"String\"\nstorage_ty = \"Text\"\n\n[id]\ntable = 0\nindex = 0\n";
465 let col: Column = toml::from_str(toml).unwrap();
466 assert!(!col.nullable);
467 assert!(!col.primary_key);
468 assert!(!col.auto_increment);
469 }
470
471 #[test]
472 fn round_trip_all_true() {
473 let original = Column {
474 nullable: true,
475 primary_key: true,
476 auto_increment: true,
477 ..base_column()
478 };
479 let decoded: Column = toml::from_str(&toml::to_string(&original).unwrap()).unwrap();
480 assert_eq!(original, decoded);
481 }
482 }
483
484 #[test]
485 fn test_multiple_operations() {
486 let from_cols = vec![
487 make_column(0, 0, "id", Type::Integer(8), false),
488 make_column(0, 1, "old_name", Type::Text, false),
489 make_column(0, 2, "to_drop", Type::Text, false),
490 ];
491 let to_cols = vec![
492 make_column(0, 0, "id", Type::Text, false), make_column(0, 1, "new_name", Type::Text, false), make_column(0, 2, "added", Type::Integer(8), false), ];
496
497 let from_schema = make_schema_with_columns(0, from_cols.clone());
498 let to_schema = make_schema_with_columns(0, to_cols.clone());
499
500 let mut hints = RenameHints::new();
501 hints.add_column_hint(
502 ColumnId {
503 table: TableId(0),
504 index: 1,
505 },
506 ColumnId {
507 table: TableId(0),
508 index: 1,
509 },
510 );
511 let cx = DiffContext::new(&from_schema, &to_schema, &hints);
512
513 let diff = ColumnsDiff::from(&cx, &from_cols, &to_cols);
514 assert_eq!(diff.items.len(), 4);
516 }
517}