1use crate::{
2 Schema,
3 schema::{
4 app::{Field, Model, ModelId, ModelRoot},
5 db::{self, Column, ColumnId, Table, TableId},
6 },
7 stmt::{
8 Delete, Expr, ExprArg, ExprColumn, ExprFunc, ExprReference, ExprSet, Insert, InsertTarget,
9 Query, Returning, Select, Source, SourceTable, Statement, TableDerived, TableFactor,
10 TableRef, Type, TypeUnion, Update, UpdateTarget,
11 },
12};
13
/// Context used to resolve `ExprReference`s and to infer expression types.
///
/// Contexts form a stack. Each nested scope (created with `ExprContext::scope`)
/// keeps a reference to its parent, so a reference whose `nesting` is `n`
/// resolves against the target `n` scopes up.
#[derive(Debug)]
pub struct ExprContext<'a, T = Schema> {
    /// Schema used for model/table lookups (through the `Resolve` trait).
    schema: &'a T,

    /// Enclosing scope; `None` for a root context.
    parent: Option<&'a ExprContext<'a, T>>,

    /// What references in this scope resolve against.
    target: ExprTarget<'a>,
}
36
/// The result of resolving an `ExprReference` within an `ExprContext`.
#[derive(Debug)]
pub enum ResolvedRef<'a> {
    /// A database column, reached through a model, table, or source target.
    Column(&'a Column),

    /// A model field, reached through `ExprReference::Field` in a model target.
    Field(&'a Field),

    /// The model itself, reached through `ExprReference::Model` in a model target.
    Model(&'a ModelRoot),

    /// A reference into a common table expression.
    ///
    /// NOTE(review): only the CTE is identified here; the referenced column
    /// index (`ExprColumn::column`) is not carried, so callers that need it
    /// must read it from the original reference.
    Cte {
        /// Scope depth of the CTE relative to the resolving context: the
        /// reference's own nesting plus the nesting recorded on the
        /// `TableRef::Cte` entry.
        nesting: usize,

        /// Index recorded on the `TableRef::Cte` entry. Presumably this is the
        /// CTE's position in its `WITH` clause; TODO confirm.
        index: usize,
    },

    /// A column of a derived table (a `TableRef::Derived` source entry).
    Derived(DerivedRef<'a>),
}
97
/// A resolved reference to one column of a derived table.
#[derive(Debug)]
pub struct DerivedRef<'a> {
    /// Scope nesting copied from the referencing `ExprColumn`.
    pub nesting: usize,

    /// Column index within the derived table (copied from `ExprColumn::column`).
    pub index: usize,

    /// The derived table definition the column belongs to.
    pub derived: &'a TableDerived,
}
110
111impl DerivedRef<'_> {
112 pub fn is_column_always_null(&self) -> bool {
118 let ExprSet::Values(values) = &self.derived.subquery.body else {
119 return false;
120 };
121
122 if values.is_empty() {
123 return false;
124 }
125
126 values.rows.iter().all(|row| self.row_column_is_null(row))
127 }
128
129 fn row_column_is_null(&self, row: &Expr) -> bool {
130 match row {
131 Expr::Value(super::Value::Record(record)) => {
132 self.index < record.len() && record[self.index].is_null()
133 }
134 Expr::Record(record) => {
135 self.index < record.len()
136 && matches!(&record.fields[self.index], Expr::Value(super::Value::Null))
137 }
138 Expr::Value(super::Value::Null) => true,
139 _ => false,
140 }
141 }
142}
143
/// What references in a scope resolve against.
#[derive(Debug, Clone, Copy)]
pub enum ExprTarget<'a> {
    /// No target. Resolving a reference against it panics (`todo!`).
    Free,

    /// An application model. Field and model references resolve directly;
    /// column references go through the table the model is mapped to.
    Model(&'a ModelRoot),

    /// A single database table; only column references are valid.
    Table(&'a Table),

    /// A `SourceTable`. A column reference picks an entry of its `tables` by
    /// `ExprColumn::table` (a plain table, a derived table, a CTE, or an arg).
    Source(&'a SourceTable),
}
164
/// Schema lookups needed by `ExprContext`.
///
/// Each method returns `None` when the implementor cannot provide the item.
/// In this file, `db::Schema` returns `None` for model lookups, and `()`
/// returns `None` for everything.
pub trait Resolve {
    /// Returns the database table that `model` is mapped to, if known.
    fn table_for_model(&self, model: &ModelRoot) -> Option<&Table>;

    /// Looks up a model by id.
    fn model(&self, id: ModelId) -> Option<&Model>;

    /// Looks up a table by id.
    fn table(&self, id: TableId) -> Option<&Table>;
}
188
/// Conversion into an `ExprTarget`.
///
/// `schema` is used to look up any models or tables the value refers to by id.
pub trait IntoExprTarget<'a, T = Schema> {
    /// Converts `self` into the target that references should resolve against.
    fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a>;
}
195
/// A stack of argument-type frames used during type inference.
///
/// An `ExprArg`'s `nesting` selects how many frames to walk up, and its
/// `position` indexes into that frame's `tys`.
#[derive(Debug)]
struct ArgTyStack<'a> {
    /// Types of the arguments bound by this frame.
    tys: &'a [Type],
    /// Enclosing frame, or `None` for the outermost frame.
    parent: Option<&'a ArgTyStack<'a>>,
}
201
202impl<'a, T> ExprContext<'a, T> {
203 pub fn schema(&self) -> &'a T {
205 self.schema
206 }
207
208 pub fn target(&self) -> ExprTarget<'a> {
210 self.target
211 }
212
213 pub fn target_at(&self, nesting: usize) -> &ExprTarget<'a> {
215 let mut curr = self;
216
217 for _ in 0..nesting {
219 let Some(parent) = curr.parent else {
220 todo!("bug: invalid nesting level");
221 };
222
223 curr = parent;
224 }
225
226 &curr.target
227 }
228}
229
230impl<'a> ExprContext<'a, ()> {
231 pub fn new_free() -> ExprContext<'a, ()> {
233 ExprContext {
234 schema: &(),
235 parent: None,
236 target: ExprTarget::Free,
237 }
238 }
239}
240
impl<'a, T: Resolve> ExprContext<'a, T> {
    /// Creates a root context for `schema` with an `ExprTarget::Free` target.
    pub fn new(schema: &'a T) -> ExprContext<'a, T> {
        ExprContext::new_with_target(schema, ExprTarget::Free)
    }

    /// Creates a root context (no parent). Its target is `target`, converted
    /// through `IntoExprTarget` against `schema`.
    pub fn new_with_target(
        schema: &'a T,
        target: impl IntoExprTarget<'a, T>,
    ) -> ExprContext<'a, T> {
        let target = target.into_expr_target(schema);
        ExprContext {
            schema,
            parent: None,
            target,
        }
    }

    /// Creates a child context whose parent is `self`.
    ///
    /// The child's target is `target`, converted against this context's
    /// schema. Inside the child, a reference with nesting `n` resolves
    /// against the target `n` scopes up (`0` = the child itself).
    pub fn scope<'child>(
        &'child self,
        target: impl IntoExprTarget<'child, T>,
    ) -> ExprContext<'child, T> {
        let target = target.into_expr_target(self.schema);
        ExprContext {
            schema: self.schema,
            parent: Some(self),
            target,
        }
    }

    /// Resolves `expr_reference` against the target of the scope it points
    /// at (selected by the reference's `nesting`).
    ///
    /// # Panics
    ///
    /// - The selected target is `ExprTarget::Free` (`todo!`).
    /// - The reference kind is invalid for the target (`Model`/`Field`
    ///   references in a table or source target).
    /// - A field, column, or table index is out of bounds.
    /// - The schema cannot supply the needed table.
    ///
    /// `TableRef::Arg` sources are not implemented (`todo!`).
    pub fn resolve_expr_reference(&self, expr_reference: &ExprReference) -> ResolvedRef<'a> {
        // Every reference kind records how many scopes up it points.
        let nesting = match expr_reference {
            ExprReference::Column(expr_column) => expr_column.nesting,
            ExprReference::Field { nesting, .. } => *nesting,
            ExprReference::Model { nesting } => *nesting,
        };

        let target = self.target_at(nesting);

        match target {
            ExprTarget::Free => todo!("cannot resolve column in free context"),
            ExprTarget::Model(model) => match expr_reference {
                ExprReference::Model { .. } => ResolvedRef::Model(model),
                ExprReference::Field { index, .. } => ResolvedRef::Field(&model.fields[*index]),
                ExprReference::Column(expr_column) => {
                    // Under a model target, columns are expected to come from
                    // table 0 (the model's mapped table). The author marked
                    // this assumption as unverified.
                    assert_eq!(expr_column.table, 0, "TODO: is this true?");

                    let Some(table) = self.schema.table_for_model(model) else {
                        panic!(
                            "Failed to find database table for model '{:?}' - model may not be mapped to a table",
                            model.name
                        )
                    };
                    ResolvedRef::Column(&table.columns[expr_column.column])
                }
            },
            ExprTarget::Table(table) => match expr_reference {
                ExprReference::Model { .. } => {
                    panic!("Cannot resolve ExprReference::Model in Table target context")
                }
                ExprReference::Field { .. } => panic!(
                    "Cannot resolve ExprReference::Field in Table target context - use ExprReference::Column instead"
                ),
                ExprReference::Column(expr_column) => {
                    ResolvedRef::Column(&table.columns[expr_column.column])
                }
            },
            ExprTarget::Source(source_table) => {
                match expr_reference {
                    ExprReference::Column(expr_column) => {
                        // `ExprColumn::table` indexes the source's table list.
                        let table_ref = &source_table.tables[expr_column.table];
                        match table_ref {
                            TableRef::Table(table_id) => {
                                let Some(table) = self.schema.table(*table_id) else {
                                    panic!(
                                        "Failed to resolve table with ID {:?} - table not found in schema.",
                                        table_id,
                                    );
                                };
                                ResolvedRef::Column(&table.columns[expr_column.column])
                            }
                            TableRef::Derived(derived) => ResolvedRef::Derived(DerivedRef {
                                nesting: expr_column.nesting,
                                index: expr_column.column,
                                derived,
                            }),
                            TableRef::Cte {
                                nesting: cte_nesting,
                                index,
                            } => {
                                // The CTE's nesting is relative to the scope the
                                // reference points at, so the two depths add up.
                                ResolvedRef::Cte {
                                    nesting: expr_column.nesting + cte_nesting,
                                    index: *index,
                                }
                            }
                            TableRef::Arg(_) => todo!(),
                        }
                    }
                    ExprReference::Model { .. } => {
                        panic!("Cannot resolve ExprReference::Model in Source::Table context")
                    }
                    ExprReference::Field { .. } => panic!(
                        "Cannot resolve ExprReference::Field in Source::Table context - use ExprReference::Column instead"
                    ),
                }
            }
        }
    }

    /// Infers the type a statement evaluates to.
    ///
    /// The result is the type of the statement's returning clause (see
    /// `infer_returning_ty`), or `Type::Unit` when there is none. `args` are
    /// the types of the statement's top-level arguments.
    ///
    /// Cardinality: `DELETE`/`UPDATE` are treated as multi-row
    /// (`single = false`); `INSERT` uses `stmt.source.single`; queries use
    /// `stmt.single`. Query bodies that are set operations, updates, or
    /// `VALUES` are not implemented (`todo!`).
    pub fn infer_stmt_ty(&self, stmt: &Statement, args: &[Type]) -> Type {
        // Resolve references inside the statement against its own target.
        let cx = self.scope(stmt);

        match stmt {
            Statement::Delete(stmt) => stmt
                .returning
                .as_ref()
                .map(|returning| cx.infer_returning_ty(returning, args, false))
                .unwrap_or(Type::Unit),
            Statement::Insert(stmt) => stmt
                .returning
                .as_ref()
                .map(|returning| cx.infer_returning_ty(returning, args, stmt.source.single))
                .unwrap_or(Type::Unit),
            Statement::Query(stmt) => match &stmt.body {
                ExprSet::Select(body) => cx.infer_returning_ty(&body.returning, args, stmt.single),
                ExprSet::SetOp(_body) => todo!(),
                ExprSet::Update(_body) => todo!(),
                ExprSet::Values(_body) => todo!(),
                ExprSet::Insert(body) => body
                    .returning
                    .as_ref()
                    .map(|returning| cx.infer_returning_ty(returning, args, stmt.single))
                    .unwrap_or(Type::Unit),
            },
            Statement::Update(stmt) => stmt
                .returning
                .as_ref()
                .map(|returning| cx.infer_returning_ty(returning, args, false))
                .unwrap_or(Type::Unit),
        }
    }

    /// Infers the type of a returning clause.
    ///
    /// - `Model` and `Project` results are wrapped in `Type::list` unless
    ///   `single` is set.
    /// - `Expr` is inferred in "returning expression" mode and returned
    ///   without list wrapping. Presumably the expression already encodes
    ///   its own cardinality; TODO confirm.
    /// - `Changed` is not implemented.
    ///
    /// Panics if `Returning::Model` is used while this scope's target is not
    /// a model.
    fn infer_returning_ty(&self, returning: &Returning, args: &[Type], single: bool) -> Type {
        let arg_ty_stack = ArgTyStack::new(args);

        match returning {
            Returning::Model { .. } => {
                let ty = Type::Model(
                    self.target
                        .model_id()
                        .expect("returning `Model` when not in model context"),
                );

                if single { ty } else { Type::list(ty) }
            }
            Returning::Changed => todo!(),
            Returning::Project(expr) => {
                let ty = self.infer_expr_ty2(&arg_ty_stack, expr, false);

                if single { ty } else { Type::list(ty) }
            }
            Returning::Expr(expr) => self.infer_expr_ty2(&arg_ty_stack, expr, true),
        }
    }

    /// Infers the type of `expr`, given the types of its top-level arguments.
    pub fn infer_expr_ty(&self, expr: &Expr, args: &[Type]) -> Type {
        let arg_ty_stack = ArgTyStack::new(args);
        self.infer_expr_ty2(&arg_ty_stack, expr, false)
    }

    /// Recursive worker for expression type inference.
    ///
    /// `args` is the current argument-type stack; `Map` and `Let` push new
    /// frames onto it.
    ///
    /// `returning_expr` is `true` when `expr` is (part of) a
    /// `Returning::Expr` body. In that mode, a `Project` whose base is an
    /// argument or a reference yields the base's type directly, and a bare
    /// `Reference` is rejected with an assertion.
    ///
    /// Unhandled expression kinds hit `todo!`.
    fn infer_expr_ty2(&self, args: &ArgTyStack<'_>, expr: &Expr, returning_expr: bool) -> Type {
        match expr {
            Expr::Arg(e) => args.resolve_arg_ty(e).clone(),
            Expr::And(_) => Type::Bool,
            Expr::AnyOp(_) | Expr::AllOp(_) => Type::Bool,
            Expr::BinaryOp(_) => Type::Bool,
            Expr::Cast(e) => e.ty.clone(),
            Expr::Reference(expr_ref) => {
                assert!(
                    !returning_expr,
                    "should have been handled in Expr::Project. Invalid expr?"
                );
                self.infer_expr_reference_ty(expr_ref)
            }
            Expr::IsNull(_) => Type::Bool,
            Expr::IsVariant(_) => Type::Bool,
            Expr::List(e) => {
                // The element type comes from the first item only. An empty
                // list trips this assertion in debug builds and the index
                // panics in release builds.
                debug_assert!(!e.items.is_empty());
                Type::list(self.infer_expr_ty2(args, &e.items[0], returning_expr))
            }
            Expr::Map(e) => {
                let base = self.infer_expr_ty2(args, &e.base, returning_expr);

                let Type::List(item) = base else {
                    todo!("error handling; base={base:#?}")
                };

                // The map body sees the list's item type as argument 0 of a
                // new argument frame.
                let scope_tys = &[*item];

                let args = args.scope(scope_tys);

                let ty = self.infer_expr_ty2(&args, &e.map, returning_expr);

                Type::list(ty)
            }
            Expr::Or(_) => Type::Bool,
            Expr::Project(e) => {
                if returning_expr {
                    // In returning-expression mode, a projection of an arg or
                    // reference takes the type of the base itself; the single
                    // projection step is not applied to the type.
                    // NOTE(review): rationale not visible here; confirm.
                    match &*e.base {
                        Expr::Arg(expr_arg) => {
                            assert!(e.projection.as_slice().len() == 1);
                            return args.resolve_arg_ty(expr_arg).clone();
                        }
                        Expr::Reference(expr_reference) => {
                            assert!(e.projection.as_slice().len() == 1);
                            return self.infer_expr_reference_ty(expr_reference);
                        }
                        _ => {}
                    }
                }

                let mut base = self.infer_expr_ty2(args, &e.base, returning_expr);

                // Each step either selects a record field by index or
                // unwraps a list to its item type (the step value is
                // ignored for lists).
                for step in e.projection.iter() {
                    base = match base {
                        Type::Record(mut fields) => {
                            std::mem::replace(&mut fields[*step], Type::Null)
                        }
                        Type::List(items) => *items,
                        expr => todo!(
                            "returning_expr={returning_expr:#?}; expr={expr:#?}; project={e:#?}"
                        ),
                    }
                }

                base
            }
            Expr::Record(e) => Type::Record(
                e.fields
                    .iter()
                    .map(|field| self.infer_expr_ty2(args, field, returning_expr))
                    .collect(),
            ),
            Expr::Value(value) => value.infer_ty(),
            Expr::Let(expr_let) => {
                // The bindings' types form a new argument frame for the body.
                let scope_tys: Vec<_> = expr_let
                    .bindings
                    .iter()
                    .map(|b| self.infer_expr_ty2(args, b, returning_expr))
                    .collect();
                let args = args.scope(&scope_tys);
                self.infer_expr_ty2(&args, &expr_let.body, returning_expr)
            }
            Expr::Match(expr_match) => {
                // The result is the simplified union of every arm's type
                // plus the else branch's type.
                let mut union = TypeUnion::new();
                for arm in &expr_match.arms {
                    let ty = self.infer_expr_ty2(args, &arm.expr, returning_expr);
                    union.insert(ty);
                }
                let else_ty = self.infer_expr_ty2(args, &expr_match.else_expr, returning_expr);
                union.insert(else_ty);
                union.simplify()
            }
            Expr::Error(_) => Type::Unknown,
            Expr::Exists(_) => Type::Bool,
            Expr::Func(ExprFunc::Count(_)) => Type::U64,
            Expr::Func(ExprFunc::LastInsertId(_)) => Type::I64,
            _ => todo!("{expr:#?}"),
        }
    }

    /// Infers the type of a reference by resolving it first.
    ///
    /// CTE and derived-table columns are not supported yet (`todo!`).
    pub fn infer_expr_reference_ty(&self, expr_reference: &ExprReference) -> Type {
        match self.resolve_expr_reference(expr_reference) {
            ResolvedRef::Model(model) => Type::Model(model.id),
            ResolvedRef::Column(column) => column.ty.clone(),
            ResolvedRef::Field(field) => field.expr_ty().clone(),
            ResolvedRef::Cte { .. } => todo!("type inference for CTE columns not implemented"),
            ResolvedRef::Derived(_) => {
                todo!("type inference for derived table columns not implemented")
            }
        }
    }
}
568
569impl<'a> ExprContext<'a, Schema> {
570 pub fn target_as_model(&self) -> Option<&'a ModelRoot> {
573 self.target.as_model()
574 }
575
576 pub fn expr_ref_column(&self, column_id: impl Into<ColumnId>) -> ExprReference {
584 let column_id = column_id.into();
585
586 match self.target {
587 ExprTarget::Free => {
588 panic!("Cannot create ExprColumn in free context - no table target available")
589 }
590 ExprTarget::Model(model) => {
591 let Some(table) = self.schema.table_for_model(model) else {
592 panic!(
593 "Failed to find database table for model '{:?}' - model may not be mapped to a table",
594 model.name
595 )
596 };
597
598 assert_eq!(table.id, column_id.table);
599 }
600 ExprTarget::Table(table) => assert_eq!(table.id, column_id.table),
601 ExprTarget::Source(source_table) => {
602 let [TableRef::Table(table_id)] = source_table.tables[..] else {
603 panic!(
604 "Expected exactly one table reference, found {} tables",
605 source_table.tables.len()
606 );
607 };
608 assert_eq!(table_id, column_id.table);
609 }
610 }
611
612 ExprReference::Column(ExprColumn {
613 nesting: 0,
614 table: 0,
615 column: column_id.index,
616 })
617 }
618}
619
620impl<'a, T> Clone for ExprContext<'a, T> {
621 fn clone(&self) -> Self {
622 *self
623 }
624}
625
626impl<'a, T> Copy for ExprContext<'a, T> {}
627
628impl<'a> ResolvedRef<'a> {
629 #[track_caller]
635 pub fn as_column_unwrap(self) -> &'a Column {
636 match self {
637 ResolvedRef::Column(column) => column,
638 _ => panic!("Expected ResolvedRef::Column, found {:?}", self),
639 }
640 }
641
642 #[track_caller]
648 pub fn as_field_unwrap(self) -> &'a Field {
649 match self {
650 ResolvedRef::Field(field) => field,
651 _ => panic!("Expected ResolvedRef::Field, found {:?}", self),
652 }
653 }
654
655 #[track_caller]
661 pub fn as_model_unwrap(self) -> &'a ModelRoot {
662 match self {
663 ResolvedRef::Model(model) => model,
664 _ => panic!("Expected ResolvedRef::Model, found {:?}", self),
665 }
666 }
667}
668
669impl Resolve for Schema {
670 fn model(&self, id: ModelId) -> Option<&Model> {
671 Some(self.app.model(id))
672 }
673
674 fn table(&self, id: TableId) -> Option<&Table> {
675 Some(self.db.table(id))
676 }
677
678 fn table_for_model(&self, model: &ModelRoot) -> Option<&Table> {
679 Some(self.table_for(model.id))
680 }
681}
682
683impl Resolve for db::Schema {
684 fn model(&self, _id: ModelId) -> Option<&Model> {
685 None
686 }
687
688 fn table(&self, id: TableId) -> Option<&Table> {
689 Some(db::Schema::table(self, id))
690 }
691
692 fn table_for_model(&self, _model: &ModelRoot) -> Option<&Table> {
693 None
694 }
695}
696
697impl Resolve for () {
698 fn model(&self, _id: ModelId) -> Option<&Model> {
699 None
700 }
701
702 fn table(&self, _id: TableId) -> Option<&Table> {
703 None
704 }
705
706 fn table_for_model(&self, _model: &ModelRoot) -> Option<&Table> {
707 None
708 }
709}
710
711impl<'a> ExprTarget<'a> {
712 pub fn as_model(self) -> Option<&'a ModelRoot> {
714 match self {
715 ExprTarget::Model(model) => Some(model),
716 _ => None,
717 }
718 }
719
720 #[track_caller]
726 pub fn as_model_unwrap(self) -> &'a ModelRoot {
727 match self.as_model() {
728 Some(model) => model,
729 _ => panic!("expected ExprTarget::Model; was {self:#?}"),
730 }
731 }
732
733 pub fn model_id(self) -> Option<ModelId> {
735 Some(match self {
736 ExprTarget::Model(model) => model.id,
737 _ => return None,
738 })
739 }
740
741 pub fn as_table(self) -> Option<&'a Table> {
743 match self {
744 ExprTarget::Table(table) => Some(table),
745 _ => None,
746 }
747 }
748
749 #[track_caller]
755 pub fn as_table_unwrap(self) -> &'a Table {
756 self.as_table()
757 .unwrap_or_else(|| panic!("expected ExprTarget::Table; was {self:#?}"))
758 }
759}
760
761impl<'a, T: Resolve> IntoExprTarget<'a, T> for ExprTarget<'a> {
762 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
763 match self {
764 ExprTarget::Source(source_table) => {
765 if source_table.from.len() == 1 && source_table.from[0].joins.is_empty() {
766 match &source_table.from[0].relation {
767 TableFactor::Table(source_table_id) => {
768 debug_assert_eq!(0, source_table_id.0);
769 debug_assert_eq!(1, source_table.tables.len());
770
771 match &source_table.tables[0] {
772 TableRef::Table(table_id) => {
773 let table = schema.table(*table_id).unwrap();
774 ExprTarget::Table(table)
775 }
776 _ => self,
777 }
778 }
779 }
780 } else {
781 self
782 }
783 }
784 _ => self,
785 }
786 }
787}
788
789impl<'a, T> IntoExprTarget<'a, T> for &'a ModelRoot {
790 fn into_expr_target(self, _schema: &'a T) -> ExprTarget<'a> {
791 ExprTarget::Model(self)
792 }
793}
794
795impl<'a, T> IntoExprTarget<'a, T> for &'a Table {
796 fn into_expr_target(self, _schema: &'a T) -> ExprTarget<'a> {
797 ExprTarget::Table(self)
798 }
799}
800
801impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Query {
802 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
803 self.body.into_expr_target(schema)
804 }
805}
806
807impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a ExprSet {
808 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
809 match self {
810 ExprSet::Select(select) => select.into_expr_target(schema),
811 ExprSet::SetOp(_) => todo!(),
812 ExprSet::Update(update) => update.into_expr_target(schema),
813 ExprSet::Values(_) => ExprTarget::Free,
814 ExprSet::Insert(insert) => insert.into_expr_target(schema),
815 }
816 }
817}
818
819impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Select {
820 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
821 self.source.into_expr_target(schema)
822 }
823}
824
825impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Insert {
826 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
827 self.target.into_expr_target(schema)
828 }
829}
830
831impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Update {
832 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
833 self.target.into_expr_target(schema)
834 }
835}
836
837impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Delete {
838 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
839 self.from.into_expr_target(schema)
840 }
841}
842
843impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a InsertTarget {
844 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
845 match self {
846 InsertTarget::Scope(query) => query.into_expr_target(schema),
847 InsertTarget::Model(model) => {
848 let Some(model) = schema.model(*model) else {
849 todo!()
850 };
851 ExprTarget::Model(model.as_root_unwrap())
852 }
853 InsertTarget::Table(insert_table) => {
854 let table = schema.table(insert_table.table).unwrap();
855 ExprTarget::Table(table)
856 }
857 }
858 }
859}
860
861impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a UpdateTarget {
862 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
863 match self {
864 UpdateTarget::Query(query) => query.into_expr_target(schema),
865 UpdateTarget::Model(model) => {
866 let Some(model) = schema.model(*model) else {
867 todo!()
868 };
869 ExprTarget::Model(model.as_root_unwrap())
870 }
871 UpdateTarget::Table(table_id) => {
872 let Some(table) = schema.table(*table_id) else {
873 todo!()
874 };
875 ExprTarget::Table(table)
876 }
877 }
878 }
879}
880
881impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Source {
882 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
883 match self {
884 Source::Model(source_model) => {
885 let Some(model) = schema.model(source_model.id) else {
886 todo!()
887 };
888 ExprTarget::Model(model.as_root_unwrap())
889 }
890 Source::Table(source_table) => {
891 ExprTarget::Source(source_table).into_expr_target(schema)
892 }
893 }
894 }
895}
896
897impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Statement {
898 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
899 match self {
900 Statement::Delete(stmt) => stmt.into_expr_target(schema),
901 Statement::Insert(stmt) => stmt.into_expr_target(schema),
902 Statement::Query(stmt) => stmt.into_expr_target(schema),
903 Statement::Update(stmt) => stmt.into_expr_target(schema),
904 }
905 }
906}
907
908impl<'a> ArgTyStack<'a> {
909 fn new(tys: &'a [Type]) -> ArgTyStack<'a> {
910 ArgTyStack { tys, parent: None }
911 }
912
913 fn resolve_arg_ty(&self, expr_arg: &ExprArg) -> &'a Type {
914 let mut nesting = expr_arg.nesting;
915 let mut args = self;
916
917 while nesting > 0 {
918 args = args.parent.unwrap();
919 nesting -= 1;
920 }
921
922 &args.tys[expr_arg.position]
923 }
924
925 fn scope<'child>(&'child self, tys: &'child [Type]) -> ArgTyStack<'child> {
926 ArgTyStack {
927 tys,
928 parent: Some(self),
929 }
930 }
931}