1use crate::{
2 Schema,
3 schema::{
4 app::{Field, Model, ModelId, ModelRoot},
5 db::{self, Column, ColumnId, Table, TableId},
6 },
7 stmt::{
8 Delete, Expr, ExprArg, ExprColumn, ExprFunc, ExprReference, ExprSet, Insert, InsertTarget,
9 Query, Returning, Select, Source, SourceTable, Statement, TableDerived, TableFactor,
10 TableRef, Type, TypeUnion, Update, UpdateTarget,
11 },
12};
13
/// Context in which statement expressions are resolved and typed.
///
/// Pairs a schema reference with the [`ExprTarget`] that references are
/// resolved against. Contexts form a chain through `parent`, which is walked
/// when a reference carries a non-zero nesting level.
#[derive(Debug)]
pub struct ExprContext<'a, T = Schema> {
    /// Schema used to look up models and tables.
    schema: &'a T,
    /// Enclosing context, or `None` for the outermost scope.
    parent: Option<&'a ExprContext<'a, T>>,
    /// Target that references with nesting level zero resolve against.
    target: ExprTarget<'a>,
}
36
/// Result of resolving an [`ExprReference`] against an [`ExprContext`].
#[derive(Debug)]
pub enum ResolvedRef<'a> {
    /// The reference resolved to a database column.
    Column(&'a Column),

    /// The reference resolved to a field of the target model.
    Field(&'a Field),

    /// The reference resolved to the target model itself.
    Model(&'a ModelRoot),

    /// The reference points into a `TableRef::Cte` entry of a source table.
    Cte {
        /// Nesting of the column reference plus the nesting recorded on the
        /// `TableRef::Cte` entry.
        nesting: usize,
        /// Index copied from the `TableRef::Cte` entry.
        index: usize,
    },

    /// The reference points at a column of a derived table.
    Derived(DerivedRef<'a>),
}
97
/// Reference to a single column of a derived table (`TableRef::Derived`).
#[derive(Debug)]
pub struct DerivedRef<'a> {
    /// Nesting level copied from the column reference that was resolved.
    pub nesting: usize,

    /// Column index copied from the column reference; used to index into the
    /// rows of the derived table's subquery.
    pub index: usize,

    /// The derived table the column belongs to.
    pub derived: &'a TableDerived,
}
110
111impl DerivedRef<'_> {
112 pub fn is_column_always_null(&self) -> bool {
118 let ExprSet::Values(values) = &self.derived.subquery.body else {
119 return false;
120 };
121
122 if values.is_empty() {
123 return false;
124 }
125
126 values.rows.iter().all(|row| self.row_column_is_null(row))
127 }
128
129 fn row_column_is_null(&self, row: &Expr) -> bool {
130 match row {
131 Expr::Value(super::Value::Record(record)) => {
132 self.index < record.len() && record[self.index].is_null()
133 }
134 Expr::Record(record) => {
135 self.index < record.len()
136 && matches!(&record.fields[self.index], Expr::Value(super::Value::Null))
137 }
138 Expr::Value(super::Value::Null) => true,
139 _ => false,
140 }
141 }
142}
143
/// What the references inside an expression are resolved against.
#[derive(Debug, Clone, Copy)]
pub enum ExprTarget<'a> {
    /// No target; references cannot be resolved in this context.
    Free,

    /// References resolve against a model (its fields, or the columns of its
    /// mapped table).
    Model(&'a ModelRoot),

    /// References resolve against the columns of a single database table.
    Table(&'a Table),

    /// References resolve against a source table, whose `tables` list is
    /// indexed by the `table` component of a column reference.
    Source(&'a SourceTable),
}
164
/// Schema lookups needed to build expression targets and resolve references.
pub trait Resolve {
    /// Returns the database table for `model`, or `None` if this schema
    /// cannot provide one.
    fn table_for_model(&self, model: &ModelRoot) -> Option<&Table>;

    /// Looks up a model by ID, or returns `None` if this schema cannot
    /// provide it.
    fn model(&self, id: ModelId) -> Option<&Model>;

    /// Looks up a database table by ID, or returns `None` if this schema
    /// cannot provide it.
    fn table(&self, id: TableId) -> Option<&Table>;
}
188
/// Conversion of a statement component into the [`ExprTarget`] that its
/// expressions resolve against.
pub trait IntoExprTarget<'a, T = Schema> {
    /// Consumes `self` and produces the target, consulting `schema` where a
    /// model or table lookup is required.
    fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a>;
}
195
/// Stack of argument-type scopes used while inferring expression types.
#[derive(Debug)]
struct ArgTyStack<'a> {
    /// Argument types of this scope, indexed by `ExprArg::position`.
    tys: &'a [Type],
    /// Enclosing scope, reached by an `ExprArg` with a non-zero nesting.
    parent: Option<&'a ArgTyStack<'a>>,
}
201
202impl<'a, T> ExprContext<'a, T> {
203 pub fn schema(&self) -> &'a T {
205 self.schema
206 }
207
208 pub fn target(&self) -> ExprTarget<'a> {
210 self.target
211 }
212
213 pub fn target_at(&self, nesting: usize) -> &ExprTarget<'a> {
215 let mut curr = self;
216
217 for _ in 0..nesting {
219 let Some(parent) = curr.parent else {
220 todo!("bug: invalid nesting level");
221 };
222
223 curr = parent;
224 }
225
226 &curr.target
227 }
228}
229
230impl<'a> ExprContext<'a, ()> {
231 pub fn new_free() -> ExprContext<'a, ()> {
233 ExprContext {
234 schema: &(),
235 parent: None,
236 target: ExprTarget::Free,
237 }
238 }
239}
240
241impl<'a, T: Resolve> ExprContext<'a, T> {
242 pub fn new(schema: &'a T) -> ExprContext<'a, T> {
244 ExprContext::new_with_target(schema, ExprTarget::Free)
245 }
246
247 pub fn new_with_target(
249 schema: &'a T,
250 target: impl IntoExprTarget<'a, T>,
251 ) -> ExprContext<'a, T> {
252 let target = target.into_expr_target(schema);
253 ExprContext {
254 schema,
255 parent: None,
256 target,
257 }
258 }
259
260 pub fn scope<'child>(
263 &'child self,
264 target: impl IntoExprTarget<'child, T>,
265 ) -> ExprContext<'child, T> {
267 let target = target.into_expr_target(self.schema);
268 ExprContext {
269 schema: self.schema,
270 parent: Some(self),
271 target,
272 }
273 }
274
    /// Resolves `expr_reference` to the schema item it denotes.
    ///
    /// The reference's nesting level selects which context in the parent
    /// chain to resolve against (see `target_at`); the kind of target found
    /// there then decides which reference kinds are accepted.
    ///
    /// # Panics
    ///
    /// Panics when the selected target is `ExprTarget::Free`, when a `Model`
    /// or `Field` reference is resolved against a table or source target,
    /// when a model has no table in the schema, when a table ID is not found
    /// in the schema, when the `TODO` assertion on `expr_column.table` fails,
    /// when an index is out of range, or on the unimplemented
    /// `TableRef::Arg` case.
    pub fn resolve_expr_reference(&self, expr_reference: &ExprReference) -> ResolvedRef<'a> {
        // Every reference kind carries a nesting level; pull it out first.
        let nesting = match expr_reference {
            ExprReference::Column(expr_column) => expr_column.nesting,
            ExprReference::Field { nesting, .. } => *nesting,
            ExprReference::Model { nesting } => *nesting,
        };

        let target = self.target_at(nesting);

        match target {
            ExprTarget::Free => todo!("cannot resolve column in free context"),
            ExprTarget::Model(model) => match expr_reference {
                ExprReference::Model { .. } => ResolvedRef::Model(model),
                ExprReference::Field { index, .. } => ResolvedRef::Field(&model.fields[*index]),
                ExprReference::Column(expr_column) => {
                    assert_eq!(expr_column.table, 0, "TODO: is this true?");

                    // A column reference against a model target is resolved
                    // through the table the schema reports for that model.
                    let Some(table) = self.schema.table_for_model(model) else {
                        panic!(
                            "Failed to find database table for model '{:?}' - model may not be mapped to a table",
                            model.name
                        )
                    };
                    ResolvedRef::Column(&table.columns[expr_column.column])
                }
            },
            // A table target only understands column references.
            ExprTarget::Table(table) => match expr_reference {
                ExprReference::Model { .. } => {
                    panic!("Cannot resolve ExprReference::Model in Table target context")
                }
                ExprReference::Field { .. } => panic!(
                    "Cannot resolve ExprReference::Field in Table target context - use ExprReference::Column instead"
                ),
                ExprReference::Column(expr_column) => {
                    ResolvedRef::Column(&table.columns[expr_column.column])
                }
            },
            ExprTarget::Source(source_table) => {
                match expr_reference {
                    ExprReference::Column(expr_column) => {
                        // `expr_column.table` selects an entry in the source's
                        // table list; the entry kind decides the result.
                        let table_ref = &source_table.tables[expr_column.table];
                        match table_ref {
                            TableRef::Table(table_id) => {
                                let Some(table) = self.schema.table(*table_id) else {
                                    panic!(
                                        "Failed to resolve table with ID {:?} - table not found in schema.",
                                        table_id,
                                    );
                                };
                                ResolvedRef::Column(&table.columns[expr_column.column])
                            }
                            TableRef::Derived(derived) => ResolvedRef::Derived(DerivedRef {
                                nesting: expr_column.nesting,
                                index: expr_column.column,
                                derived,
                            }),
                            TableRef::Cte {
                                nesting: cte_nesting,
                                index,
                            } => {
                                // The resulting nesting is the sum of the
                                // column reference's nesting and the nesting
                                // stored on the CTE table entry.
                                ResolvedRef::Cte {
                                    nesting: expr_column.nesting + cte_nesting,
                                    index: *index,
                                }
                            }
                            TableRef::Arg(_) => todo!(),
                        }
                    }
                    ExprReference::Model { .. } => {
                        panic!("Cannot resolve ExprReference::Model in Source::Table context")
                    }
                    ExprReference::Field { .. } => panic!(
                        "Cannot resolve ExprReference::Field in Source::Table context - use ExprReference::Column instead"
                    ),
                }
            }
        }
    }
370
371 pub fn infer_stmt_ty(&self, stmt: &Statement, args: &[Type]) -> Type {
373 let cx = self.scope(stmt);
374
375 match stmt {
376 Statement::Delete(stmt) => stmt
377 .returning
378 .as_ref()
379 .map(|returning| cx.infer_returning_ty(returning, args, false))
380 .unwrap_or(Type::Unit),
381 Statement::Insert(stmt) => stmt
382 .returning
383 .as_ref()
384 .map(|returning| cx.infer_returning_ty(returning, args, stmt.source.single))
385 .unwrap_or(Type::Unit),
386 Statement::Query(stmt) => match &stmt.body {
387 ExprSet::Select(body) => cx.infer_returning_ty(&body.returning, args, stmt.single),
388 ExprSet::SetOp(_body) => todo!(),
389 ExprSet::Update(_body) => todo!(),
390 ExprSet::Values(_body) => todo!(),
391 ExprSet::Insert(body) => body
392 .returning
393 .as_ref()
394 .map(|returning| cx.infer_returning_ty(returning, args, stmt.single))
395 .unwrap_or(Type::Unit),
396 },
397 Statement::Update(stmt) => stmt
398 .returning
399 .as_ref()
400 .map(|returning| cx.infer_returning_ty(returning, args, false))
401 .unwrap_or(Type::Unit),
402 }
403 }
404
405 fn infer_returning_ty(&self, returning: &Returning, args: &[Type], single: bool) -> Type {
406 let arg_ty_stack = ArgTyStack::new(args);
407
408 match returning {
409 Returning::Model { .. } => {
410 let ty = Type::Model(
411 self.target
412 .model_id()
413 .expect("returning `Model` when not in model context"),
414 );
415
416 if single { ty } else { Type::list(ty) }
417 }
418 Returning::Changed => todo!(),
419 Returning::Expr(expr) => {
420 let ty = self.infer_expr_ty2(&arg_ty_stack, expr, false);
421
422 if single { ty } else { Type::list(ty) }
423 }
424 Returning::Value(expr) => self.infer_expr_ty2(&arg_ty_stack, expr, true),
425 }
426 }
427
428 pub fn infer_expr_ty(&self, expr: &Expr, args: &[Type]) -> Type {
430 let arg_ty_stack = ArgTyStack::new(args);
431 self.infer_expr_ty2(&arg_ty_stack, expr, false)
432 }
433
    /// Recursive worker behind `infer_expr_ty` and `infer_returning_ty`.
    ///
    /// * `args` - stack of argument-type scopes; `Expr::Map` and `Expr::Let`
    ///   push a new scope for their inner expression.
    /// * `expr` - expression whose type is inferred.
    /// * `returning_expr` - when `true`, a single-step projection over an
    ///   argument or a reference is typed as the base itself, and a bare
    ///   `Expr::Reference` is rejected by assertion.
    ///
    /// # Panics
    ///
    /// Panics on expression kinds that are not handled yet (`todo!`), on a
    /// `Map` whose base is not a list, on a projection through a type that is
    /// neither a record nor a list, and when the assertions below fail.
    fn infer_expr_ty2(&self, args: &ArgTyStack<'_>, expr: &Expr, returning_expr: bool) -> Type {
        match expr {
            Expr::Arg(e) => args.resolve_arg_ty(e).clone(),
            Expr::And(_) => Type::Bool,
            Expr::BinaryOp(_) => Type::Bool,
            Expr::Cast(e) => e.ty.clone(),
            Expr::Reference(expr_ref) => {
                assert!(
                    !returning_expr,
                    "should have been handled in Expr::Project. Invalid expr?"
                );
                self.infer_expr_reference_ty(expr_ref)
            }
            Expr::IsNull(_) => Type::Bool,
            Expr::IsVariant(_) => Type::Bool,
            Expr::List(e) => {
                // The list's item type is taken from its first item only.
                debug_assert!(!e.items.is_empty());
                Type::list(self.infer_expr_ty2(args, &e.items[0], returning_expr))
            }
            Expr::Map(e) => {
                let base = self.infer_expr_ty2(args, &e.base, returning_expr);

                let Type::List(item) = base else {
                    todo!("error handling; base={base:#?}")
                };

                // The map body sees the list's item type as the only
                // argument of a new innermost scope.
                let scope_tys = &[*item];

                let args = args.scope(scope_tys);

                let ty = self.infer_expr_ty2(&args, &e.map, returning_expr);

                Type::list(ty)
            }
            Expr::Or(_) => Type::Bool,
            Expr::Project(e) => {
                if returning_expr {
                    // In returning mode a one-step projection over an
                    // argument or reference takes the type of the base; the
                    // projection step itself is not applied.
                    match &*e.base {
                        Expr::Arg(expr_arg) => {
                            assert!(e.projection.as_slice().len() == 1);
                            return args.resolve_arg_ty(expr_arg).clone();
                        }
                        Expr::Reference(expr_reference) => {
                            assert!(e.projection.as_slice().len() == 1);
                            return self.infer_expr_reference_ty(expr_reference);
                        }
                        _ => {}
                    }
                }

                let mut base = self.infer_expr_ty2(args, &e.base, returning_expr);

                for step in e.projection.iter() {
                    base = match base {
                        // Move the selected field type out of the owned
                        // record instead of cloning it.
                        Type::Record(mut fields) => {
                            std::mem::replace(&mut fields[*step], Type::Null)
                        }
                        // Stepping into a list yields its item type; the
                        // step index is not consulted.
                        Type::List(items) => *items,
                        expr => todo!(
                            "returning_expr={returning_expr:#?}; expr={expr:#?}; project={e:#?}"
                        ),
                    }
                }

                base
            }
            Expr::Record(e) => Type::Record(
                e.fields
                    .iter()
                    .map(|field| self.infer_expr_ty2(args, field, returning_expr))
                    .collect(),
            ),
            Expr::Value(value) => value.infer_ty(),
            Expr::Let(expr_let) => {
                // Bindings are typed in the current scope; the body is typed
                // in a new scope whose arguments are the binding types.
                let scope_tys: Vec<_> = expr_let
                    .bindings
                    .iter()
                    .map(|b| self.infer_expr_ty2(args, b, returning_expr))
                    .collect();
                let args = args.scope(&scope_tys);
                self.infer_expr_ty2(&args, &expr_let.body, returning_expr)
            }
            Expr::Match(expr_match) => {
                // The result is the simplified union of every arm's type and
                // the else branch's type.
                let mut union = TypeUnion::new();
                for arm in &expr_match.arms {
                    let ty = self.infer_expr_ty2(args, &arm.expr, returning_expr);
                    union.insert(ty);
                }
                let else_ty = self.infer_expr_ty2(args, &expr_match.else_expr, returning_expr);
                union.insert(else_ty);
                union.simplify()
            }
            Expr::Error(_) => Type::Unknown,
            Expr::Exists(_) => Type::Bool,
            Expr::Func(ExprFunc::Count(_)) => Type::U64,
            Expr::Func(ExprFunc::LastInsertId(_)) => Type::I64,
            _ => todo!("{expr:#?}"),
        }
    }
553
554 pub fn infer_expr_reference_ty(&self, expr_reference: &ExprReference) -> Type {
556 match self.resolve_expr_reference(expr_reference) {
557 ResolvedRef::Model(model) => Type::Model(model.id),
558 ResolvedRef::Column(column) => column.ty.clone(),
559 ResolvedRef::Field(field) => field.expr_ty().clone(),
560 ResolvedRef::Cte { .. } => todo!("type inference for CTE columns not implemented"),
561 ResolvedRef::Derived(_) => {
562 todo!("type inference for derived table columns not implemented")
563 }
564 }
565 }
566}
567
568impl<'a> ExprContext<'a, Schema> {
569 pub fn target_as_model(&self) -> Option<&'a ModelRoot> {
572 self.target.as_model()
573 }
574
575 pub fn expr_ref_column(&self, column_id: impl Into<ColumnId>) -> ExprReference {
583 let column_id = column_id.into();
584
585 match self.target {
586 ExprTarget::Free => {
587 panic!("Cannot create ExprColumn in free context - no table target available")
588 }
589 ExprTarget::Model(model) => {
590 let Some(table) = self.schema.table_for_model(model) else {
591 panic!(
592 "Failed to find database table for model '{:?}' - model may not be mapped to a table",
593 model.name
594 )
595 };
596
597 assert_eq!(table.id, column_id.table);
598 }
599 ExprTarget::Table(table) => assert_eq!(table.id, column_id.table),
600 ExprTarget::Source(source_table) => {
601 let [TableRef::Table(table_id)] = source_table.tables[..] else {
602 panic!(
603 "Expected exactly one table reference, found {} tables",
604 source_table.tables.len()
605 );
606 };
607 assert_eq!(table_id, column_id.table);
608 }
609 }
610
611 ExprReference::Column(ExprColumn {
612 nesting: 0,
613 table: 0,
614 column: column_id.index,
615 })
616 }
617}
618
// Implemented by hand rather than derived so that no `T: Clone` bound is
// required; the context only stores shared references and a `Copy` target.
impl<'a, T> Clone for ExprContext<'a, T> {
    /// Returns a bitwise copy of the context (the type is `Copy`).
    fn clone(&self) -> Self {
        *self
    }
}
624
// Implemented by hand so that `ExprContext` is `Copy` for every `T`, without
// a `T: Copy` bound.
impl<'a, T> Copy for ExprContext<'a, T> {}
626
627impl<'a> ResolvedRef<'a> {
628 #[track_caller]
634 pub fn as_column_unwrap(self) -> &'a Column {
635 match self {
636 ResolvedRef::Column(column) => column,
637 _ => panic!("Expected ResolvedRef::Column, found {:?}", self),
638 }
639 }
640
641 #[track_caller]
647 pub fn as_field_unwrap(self) -> &'a Field {
648 match self {
649 ResolvedRef::Field(field) => field,
650 _ => panic!("Expected ResolvedRef::Field, found {:?}", self),
651 }
652 }
653
654 #[track_caller]
660 pub fn as_model_unwrap(self) -> &'a ModelRoot {
661 match self {
662 ResolvedRef::Model(model) => model,
663 _ => panic!("Expected ResolvedRef::Model, found {:?}", self),
664 }
665 }
666}
667
/// Resolution against the combined schema: every lookup is forwarded to the
/// schema and wrapped in `Some`, so these methods never return `None`.
impl Resolve for Schema {
    /// Forwards to `self.app.model`.
    fn model(&self, id: ModelId) -> Option<&Model> {
        Some(self.app.model(id))
    }

    /// Forwards to `self.db.table`.
    fn table(&self, id: TableId) -> Option<&Table> {
        Some(self.db.table(id))
    }

    /// Forwards to `self.table_for`, keyed by the model's ID.
    fn table_for_model(&self, model: &ModelRoot) -> Option<&Table> {
        Some(self.table_for(model.id))
    }
}
681
/// Resolution against a database-only schema: tables can be looked up, but
/// model lookups always return `None`.
impl Resolve for db::Schema {
    /// Always `None`; this impl performs no model lookup.
    fn model(&self, _id: ModelId) -> Option<&Model> {
        None
    }

    /// Looks the table up through `db::Schema::table`, named by its full path
    /// (presumably the inherent method of the same name; confirm against
    /// `db::Schema`).
    fn table(&self, id: TableId) -> Option<&Table> {
        Some(db::Schema::table(self, id))
    }

    /// Always `None`; this impl performs no model-to-table lookup.
    fn table_for_model(&self, _model: &ModelRoot) -> Option<&Table> {
        None
    }
}
695
/// Resolution against no schema at all (used by `ExprContext::new_free`):
/// every lookup returns `None`.
impl Resolve for () {
    /// Always `None`.
    fn model(&self, _id: ModelId) -> Option<&Model> {
        None
    }

    /// Always `None`.
    fn table(&self, _id: TableId) -> Option<&Table> {
        None
    }

    /// Always `None`.
    fn table_for_model(&self, _model: &ModelRoot) -> Option<&Table> {
        None
    }
}
709
710impl<'a> ExprTarget<'a> {
711 pub fn as_model(self) -> Option<&'a ModelRoot> {
713 match self {
714 ExprTarget::Model(model) => Some(model),
715 _ => None,
716 }
717 }
718
719 #[track_caller]
725 pub fn as_model_unwrap(self) -> &'a ModelRoot {
726 match self.as_model() {
727 Some(model) => model,
728 _ => panic!("expected ExprTarget::Model; was {self:#?}"),
729 }
730 }
731
732 pub fn model_id(self) -> Option<ModelId> {
734 Some(match self {
735 ExprTarget::Model(model) => model.id,
736 _ => return None,
737 })
738 }
739
740 pub fn as_table(self) -> Option<&'a Table> {
742 match self {
743 ExprTarget::Table(table) => Some(table),
744 _ => None,
745 }
746 }
747
748 #[track_caller]
754 pub fn as_table_unwrap(self) -> &'a Table {
755 self.as_table()
756 .unwrap_or_else(|| panic!("expected ExprTarget::Table; was {self:#?}"))
757 }
758}
759
760impl<'a, T: Resolve> IntoExprTarget<'a, T> for ExprTarget<'a> {
761 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
762 match self {
763 ExprTarget::Source(source_table) => {
764 if source_table.from.len() == 1 && source_table.from[0].joins.is_empty() {
765 match &source_table.from[0].relation {
766 TableFactor::Table(source_table_id) => {
767 debug_assert_eq!(0, source_table_id.0);
768 debug_assert_eq!(1, source_table.tables.len());
769
770 match &source_table.tables[0] {
771 TableRef::Table(table_id) => {
772 let table = schema.table(*table_id).unwrap();
773 ExprTarget::Table(table)
774 }
775 _ => self,
776 }
777 }
778 }
779 } else {
780 self
781 }
782 }
783 _ => self,
784 }
785 }
786}
787
/// A model reference becomes a model target; the schema is not consulted.
impl<'a, T> IntoExprTarget<'a, T> for &'a ModelRoot {
    fn into_expr_target(self, _schema: &'a T) -> ExprTarget<'a> {
        ExprTarget::Model(self)
    }
}
793
/// A table reference becomes a table target; the schema is not consulted.
impl<'a, T> IntoExprTarget<'a, T> for &'a Table {
    fn into_expr_target(self, _schema: &'a T) -> ExprTarget<'a> {
        ExprTarget::Table(self)
    }
}
799
/// A query's target is the target of its body.
impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Query {
    fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
        self.body.into_expr_target(schema)
    }
}
805
/// A set expression's target is the target of the statement it wraps.
///
/// `Values` has no target (`ExprTarget::Free`); set operations are not
/// implemented and panic via `todo!`.
impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a ExprSet {
    fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
        match self {
            ExprSet::Select(select) => select.into_expr_target(schema),
            ExprSet::SetOp(_) => todo!(),
            ExprSet::Update(update) => update.into_expr_target(schema),
            ExprSet::Values(_) => ExprTarget::Free,
            ExprSet::Insert(insert) => insert.into_expr_target(schema),
        }
    }
}
817
/// A select's target is the target of its source.
impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Select {
    fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
        self.source.into_expr_target(schema)
    }
}
823
/// An insert's target is derived from its `target` field.
impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Insert {
    fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
        self.target.into_expr_target(schema)
    }
}
829
/// An update's target is derived from its `target` field.
impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Update {
    fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
        self.target.into_expr_target(schema)
    }
}
835
/// A delete's target is derived from its `from` field.
impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Delete {
    fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
        self.from.into_expr_target(schema)
    }
}
841
842impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a InsertTarget {
843 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
844 match self {
845 InsertTarget::Scope(query) => query.into_expr_target(schema),
846 InsertTarget::Model(model) => {
847 let Some(model) = schema.model(*model) else {
848 todo!()
849 };
850 ExprTarget::Model(model.as_root_unwrap())
851 }
852 InsertTarget::Table(insert_table) => {
853 let table = schema.table(insert_table.table).unwrap();
854 ExprTarget::Table(table)
855 }
856 }
857 }
858}
859
860impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a UpdateTarget {
861 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
862 match self {
863 UpdateTarget::Query(query) => query.into_expr_target(schema),
864 UpdateTarget::Model(model) => {
865 let Some(model) = schema.model(*model) else {
866 todo!()
867 };
868 ExprTarget::Model(model.as_root_unwrap())
869 }
870 UpdateTarget::Table(table_id) => {
871 let Some(table) = schema.table(*table_id) else {
872 todo!()
873 };
874 ExprTarget::Table(table)
875 }
876 }
877 }
878}
879
880impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Source {
881 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
882 match self {
883 Source::Model(source_model) => {
884 let Some(model) = schema.model(source_model.id) else {
885 todo!()
886 };
887 ExprTarget::Model(model.as_root_unwrap())
888 }
889 Source::Table(source_table) => {
890 ExprTarget::Source(source_table).into_expr_target(schema)
891 }
892 }
893 }
894}
895
/// A statement's target is the target of whichever concrete statement it
/// wraps.
impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Statement {
    fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
        match self {
            Statement::Delete(stmt) => stmt.into_expr_target(schema),
            Statement::Insert(stmt) => stmt.into_expr_target(schema),
            Statement::Query(stmt) => stmt.into_expr_target(schema),
            Statement::Update(stmt) => stmt.into_expr_target(schema),
        }
    }
}
906
907impl<'a> ArgTyStack<'a> {
908 fn new(tys: &'a [Type]) -> ArgTyStack<'a> {
909 ArgTyStack { tys, parent: None }
910 }
911
912 fn resolve_arg_ty(&self, expr_arg: &ExprArg) -> &'a Type {
913 let mut nesting = expr_arg.nesting;
914 let mut args = self;
915
916 while nesting > 0 {
917 args = args.parent.unwrap();
918 nesting -= 1;
919 }
920
921 &args.tys[expr_arg.position]
922 }
923
924 fn scope<'child>(&'child self, tys: &'child [Type]) -> ArgTyStack<'child> {
925 ArgTyStack {
926 tys,
927 parent: Some(self),
928 }
929 }
930}