1use crate::{
2 schema::{
3 app::{Field, Model, ModelId, ModelRoot},
4 db::{self, Column, ColumnId, Table, TableId},
5 },
6 stmt::{
7 Delete, Expr, ExprArg, ExprColumn, ExprReference, ExprSet, Insert, InsertTarget, Query,
8 Returning, Select, Source, SourceTable, Statement, TableDerived, TableFactor, TableRef,
9 Type, TypeUnion, Update, UpdateTarget,
10 },
11 Schema,
12};
13
14#[derive(Debug)]
16pub struct ExprContext<'a, T = Schema> {
17 schema: &'a T,
18 parent: Option<&'a ExprContext<'a, T>>,
19 target: ExprTarget<'a>,
20}
21
22#[derive(Debug)]
34pub enum ResolvedRef<'a> {
35 Column(&'a Column),
44
45 Field(&'a Field),
55
56 Model(&'a ModelRoot),
58
59 Cte { nesting: usize, index: usize },
69
70 Derived(DerivedRef<'a>),
76}
77
78#[derive(Debug)]
80pub struct DerivedRef<'a> {
81 pub nesting: usize,
83
84 pub index: usize,
86
87 pub derived: &'a TableDerived,
89}
90
91impl DerivedRef<'_> {
92 pub fn is_column_always_null(&self) -> bool {
98 let ExprSet::Values(values) = &self.derived.subquery.body else {
99 return false;
100 };
101
102 if values.is_empty() {
103 return false;
104 }
105
106 values.rows.iter().all(|row| self.row_column_is_null(row))
107 }
108
109 fn row_column_is_null(&self, row: &Expr) -> bool {
110 match row {
111 Expr::Value(super::Value::Record(record)) => {
112 self.index < record.len() && record[self.index].is_null()
113 }
114 Expr::Record(record) => {
115 self.index < record.len()
116 && matches!(&record.fields[self.index], Expr::Value(super::Value::Null))
117 }
118 Expr::Value(super::Value::Null) => true,
119 _ => false,
120 }
121 }
122}
123
124#[derive(Debug, Clone, Copy)]
125pub enum ExprTarget<'a> {
126 Free,
128
129 Model(&'a ModelRoot),
131
132 Table(&'a Table),
136
137 Source(&'a SourceTable),
139}
140
141pub trait Resolve {
142 fn table_for_model(&self, model: &ModelRoot) -> Option<&Table>;
143
144 fn model(&self, id: ModelId) -> Option<&Model>;
150
151 fn table(&self, id: TableId) -> Option<&Table>;
157}
158
159pub trait IntoExprTarget<'a, T = Schema> {
160 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a>;
161}
162
163#[derive(Debug)]
164struct ArgTyStack<'a> {
165 tys: &'a [Type],
166 parent: Option<&'a ArgTyStack<'a>>,
167}
168
169impl<'a, T> ExprContext<'a, T> {
170 pub fn schema(&self) -> &'a T {
171 self.schema
172 }
173
174 pub fn target(&self) -> ExprTarget<'a> {
175 self.target
176 }
177
178 pub fn target_at(&self, nesting: usize) -> &ExprTarget<'a> {
180 let mut curr = self;
181
182 for _ in 0..nesting {
184 let Some(parent) = curr.parent else {
185 todo!("bug: invalid nesting level");
186 };
187
188 curr = parent;
189 }
190
191 &curr.target
192 }
193}
194
195impl<'a> ExprContext<'a, ()> {
196 pub fn new_free() -> ExprContext<'a, ()> {
197 ExprContext {
198 schema: &(),
199 parent: None,
200 target: ExprTarget::Free,
201 }
202 }
203}
204
205impl<'a, T: Resolve> ExprContext<'a, T> {
206 pub fn new(schema: &'a T) -> ExprContext<'a, T> {
207 ExprContext::new_with_target(schema, ExprTarget::Free)
208 }
209
210 pub fn new_with_target(
211 schema: &'a T,
212 target: impl IntoExprTarget<'a, T>,
213 ) -> ExprContext<'a, T> {
214 let target = target.into_expr_target(schema);
215 ExprContext {
216 schema,
217 parent: None,
218 target,
219 }
220 }
221
222 pub fn scope<'child>(
223 &'child self,
224 target: impl IntoExprTarget<'child, T>,
225 ) -> ExprContext<'child, T> {
227 let target = target.into_expr_target(self.schema);
228 ExprContext {
229 schema: self.schema,
230 parent: Some(self),
231 target,
232 }
233 }
234
235 pub fn resolve_expr_reference(&self, expr_reference: &ExprReference) -> ResolvedRef<'a> {
251 let nesting = match expr_reference {
252 ExprReference::Column(expr_column) => expr_column.nesting,
253 ExprReference::Field { nesting, .. } => *nesting,
254 ExprReference::Model { nesting } => *nesting,
255 };
256
257 let target = self.target_at(nesting);
258
259 match target {
260 ExprTarget::Free => todo!("cannot resolve column in free context"),
261 ExprTarget::Model(model) => match expr_reference {
262 ExprReference::Model { .. } => ResolvedRef::Model(model),
263 ExprReference::Field { index, .. } => ResolvedRef::Field(&model.fields[*index]),
264 ExprReference::Column(expr_column) => {
265 assert_eq!(expr_column.table, 0, "TODO: is this true?");
266
267 let Some(table) = self.schema.table_for_model(model) else {
268 panic!("Failed to find database table for model '{:?}' - model may not be mapped to a table", model.name)
269 };
270 ResolvedRef::Column(&table.columns[expr_column.column])
271 }
272 },
273 ExprTarget::Table(table) => match expr_reference {
274 ExprReference::Model { .. } => panic!(
275 "Cannot resolve ExprReference::Model in Table target context"
276 ),
277 ExprReference::Field {.. } => panic!(
278 "Cannot resolve ExprReference::Field in Table target context - use ExprReference::Column instead"
279 ),
280 ExprReference::Column(expr_column) => ResolvedRef::Column(&table.columns[expr_column.column]),
281 },
282 ExprTarget::Source(source_table) => {
283 match expr_reference {
284 ExprReference::Column(expr_column) => {
285 let table_ref = &source_table.tables[expr_column.table];
287 match table_ref {
288 TableRef::Table(table_id) => {
289 let Some(table) = self.schema.table(*table_id) else {
290 panic!(
291 "Failed to resolve table with ID {:?} - table not found in schema.",
292 table_id,
293 );
294 };
295 ResolvedRef::Column(&table.columns[expr_column.column])
296 }
297 TableRef::Derived(derived) => {
298 ResolvedRef::Derived(DerivedRef {
299 nesting: expr_column.nesting,
300 index: expr_column.column,
301 derived,
302 })
303 }
304 TableRef::Cte {
305 nesting: cte_nesting,
306 index,
307 } => {
308 ResolvedRef::Cte {
310 nesting: expr_column.nesting + cte_nesting,
311 index: *index,
312 }
313 }
314 TableRef::Arg(_) => todo!(),
315 }
316 }
317 ExprReference::Model { .. } => panic!(
318 "Cannot resolve ExprReference::Model in Source::Table context"
319 ),
320 ExprReference::Field { .. } => panic!(
321 "Cannot resolve ExprReference::Field in Source::Table context - use ExprReference::Column instead"
322 ),
323 }
324 }
325 }
326 }
327
328 pub fn infer_stmt_ty(&self, stmt: &Statement, args: &[Type]) -> Type {
329 let cx = self.scope(stmt);
330
331 match stmt {
332 Statement::Delete(stmt) => stmt
333 .returning
334 .as_ref()
335 .map(|returning| cx.infer_returning_ty(returning, args, false))
336 .unwrap_or(Type::Unit),
337 Statement::Insert(stmt) => stmt
338 .returning
339 .as_ref()
340 .map(|returning| cx.infer_returning_ty(returning, args, stmt.source.single))
341 .unwrap_or(Type::Unit),
342 Statement::Query(stmt) => match &stmt.body {
343 ExprSet::Select(body) => cx.infer_returning_ty(&body.returning, args, stmt.single),
344 ExprSet::SetOp(_body) => todo!(),
345 ExprSet::Update(_body) => todo!(),
346 ExprSet::Values(_body) => todo!(),
347 ExprSet::Insert(body) => body
348 .returning
349 .as_ref()
350 .map(|returning| cx.infer_returning_ty(returning, args, stmt.single))
351 .unwrap_or(Type::Unit),
352 },
353 Statement::Update(stmt) => stmt
354 .returning
355 .as_ref()
356 .map(|returning| cx.infer_returning_ty(returning, args, false))
357 .unwrap_or(Type::Unit),
358 }
359 }
360
361 fn infer_returning_ty(&self, returning: &Returning, args: &[Type], single: bool) -> Type {
362 let arg_ty_stack = ArgTyStack::new(args);
363
364 match returning {
365 Returning::Model { .. } => {
366 let ty = Type::Model(
367 self.target
368 .model_id()
369 .expect("returning `Model` when not in model context"),
370 );
371
372 if single {
373 ty
374 } else {
375 Type::list(ty)
376 }
377 }
378 Returning::Changed => todo!(),
379 Returning::Expr(expr) => {
380 let ty = self.infer_expr_ty2(&arg_ty_stack, expr, false);
381
382 if single {
383 ty
384 } else {
385 Type::list(ty)
386 }
387 }
388 Returning::Value(expr) => self.infer_expr_ty2(&arg_ty_stack, expr, true),
389 }
390 }
391
392 pub fn infer_expr_ty(&self, expr: &Expr, args: &[Type]) -> Type {
393 let arg_ty_stack = ArgTyStack::new(args);
394 self.infer_expr_ty2(&arg_ty_stack, expr, false)
395 }
396
397 fn infer_expr_ty2(&self, args: &ArgTyStack<'_>, expr: &Expr, returning_expr: bool) -> Type {
398 match expr {
399 Expr::Arg(e) => args.resolve_arg_ty(e).clone(),
400 Expr::And(_) => Type::Bool,
401 Expr::BinaryOp(_) => Type::Bool,
402 Expr::Cast(e) => e.ty.clone(),
403 Expr::Reference(expr_ref) => {
404 assert!(
405 !returning_expr,
406 "should have been handled in Expr::Project. Invalid expr?"
407 );
408 self.infer_expr_reference_ty(expr_ref)
409 }
410 Expr::IsNull(_) => Type::Bool,
411 Expr::IsVariant(_) => Type::Bool,
412 Expr::List(e) => {
413 debug_assert!(!e.items.is_empty());
414 Type::list(self.infer_expr_ty2(args, &e.items[0], returning_expr))
415 }
416 Expr::Map(e) => {
417 let base = self.infer_expr_ty2(args, &e.base, returning_expr);
419
420 let Type::List(item) = base else {
422 todo!("error handling; base={base:#?}")
423 };
424
425 let scope_tys = &[*item];
426
427 let args = args.scope(scope_tys);
429
430 let ty = self.infer_expr_ty2(&args, &e.map, returning_expr);
432
433 Type::list(ty)
435 }
436 Expr::Or(_) => Type::Bool,
437 Expr::Project(e) => {
438 if returning_expr {
439 match &*e.base {
440 Expr::Arg(expr_arg) => {
441 assert!(e.projection.as_slice().len() == 1);
446 return args.resolve_arg_ty(expr_arg).clone();
447 }
448 Expr::Reference(expr_reference) => {
449 assert!(e.projection.as_slice().len() == 1);
454 return self.infer_expr_reference_ty(expr_reference);
455 }
456 _ => {}
457 }
458 }
459
460 let mut base = self.infer_expr_ty2(args, &e.base, returning_expr);
461
462 for step in e.projection.iter() {
463 base = match base {
464 Type::Record(mut fields) => {
465 std::mem::replace(&mut fields[*step], Type::Null)
466 }
467 Type::List(items) => *items,
468 expr => todo!(
469 "returning_expr={returning_expr:#?}; expr={expr:#?}; project={e:#?}"
470 ),
471 }
472 }
473
474 base
475 }
476 Expr::Record(e) => Type::Record(
477 e.fields
478 .iter()
479 .map(|field| self.infer_expr_ty2(args, field, returning_expr))
480 .collect(),
481 ),
482 Expr::Value(value) => value.infer_ty(),
483 Expr::Let(expr_let) => {
484 let scope_tys: Vec<_> = expr_let
485 .bindings
486 .iter()
487 .map(|b| self.infer_expr_ty2(args, b, returning_expr))
488 .collect();
489 let args = args.scope(&scope_tys);
490 self.infer_expr_ty2(&args, &expr_let.body, returning_expr)
491 }
492 Expr::Match(expr_match) => {
493 let mut union = TypeUnion::new();
498 for arm in &expr_match.arms {
499 let ty = self.infer_expr_ty2(args, &arm.expr, returning_expr);
500 union.insert(ty);
501 }
502 let else_ty = self.infer_expr_ty2(args, &expr_match.else_expr, returning_expr);
503 union.insert(else_ty);
504 union.simplify()
505 }
506 Expr::Error(_) => Type::Unknown,
510 _ => todo!("{expr:#?}"),
511 }
512 }
513
514 pub fn infer_expr_reference_ty(&self, expr_reference: &ExprReference) -> Type {
515 match self.resolve_expr_reference(expr_reference) {
516 ResolvedRef::Model(model) => Type::Model(model.id),
517 ResolvedRef::Column(column) => column.ty.clone(),
518 ResolvedRef::Field(field) => field.expr_ty().clone(),
519 ResolvedRef::Cte { .. } => todo!("type inference for CTE columns not implemented"),
520 ResolvedRef::Derived(_) => {
521 todo!("type inference for derived table columns not implemented")
522 }
523 }
524 }
525}
526
527impl<'a> ExprContext<'a, Schema> {
528 pub fn target_as_model(&self) -> Option<&'a ModelRoot> {
529 self.target.as_model()
530 }
531
532 pub fn expr_ref_column(&self, column_id: impl Into<ColumnId>) -> ExprReference {
533 let column_id = column_id.into();
534
535 match self.target {
536 ExprTarget::Free => {
537 panic!("Cannot create ExprColumn in free context - no table target available")
538 }
539 ExprTarget::Model(model) => {
540 let Some(table) = self.schema.table_for_model(model) else {
541 panic!("Failed to find database table for model '{:?}' - model may not be mapped to a table", model.name)
542 };
543
544 assert_eq!(table.id, column_id.table);
545 }
546 ExprTarget::Table(table) => assert_eq!(table.id, column_id.table),
547 ExprTarget::Source(source_table) => {
548 let [TableRef::Table(table_id)] = source_table.tables[..] else {
549 panic!(
550 "Expected exactly one table reference, found {} tables",
551 source_table.tables.len()
552 );
553 };
554 assert_eq!(table_id, column_id.table);
555 }
556 }
557
558 ExprReference::Column(ExprColumn {
559 nesting: 0,
560 table: 0,
561 column: column_id.index,
562 })
563 }
564}
565
566impl<'a, T> Clone for ExprContext<'a, T> {
567 fn clone(&self) -> Self {
568 *self
569 }
570}
571
572impl<'a, T> Copy for ExprContext<'a, T> {}
573
574impl<'a> ResolvedRef<'a> {
575 #[track_caller]
576 pub fn expect_column(self) -> &'a Column {
577 match self {
578 ResolvedRef::Column(column) => column,
579 _ => panic!("Expected ResolvedRef::Column, found {:?}", self),
580 }
581 }
582
583 #[track_caller]
584 pub fn expect_field(self) -> &'a Field {
585 match self {
586 ResolvedRef::Field(field) => field,
587 _ => panic!("Expected ResolvedRef::Field, found {:?}", self),
588 }
589 }
590
591 #[track_caller]
592 pub fn expect_model(self) -> &'a ModelRoot {
593 match self {
594 ResolvedRef::Model(model) => model,
595 _ => panic!("Expected ResolvedRef::Model, found {:?}", self),
596 }
597 }
598}
599
600impl Resolve for Schema {
601 fn model(&self, id: ModelId) -> Option<&Model> {
602 Some(self.app.model(id))
603 }
604
605 fn table(&self, id: TableId) -> Option<&Table> {
606 Some(self.db.table(id))
607 }
608
609 fn table_for_model(&self, model: &ModelRoot) -> Option<&Table> {
610 Some(self.table_for(model.id))
611 }
612}
613
614impl Resolve for db::Schema {
615 fn model(&self, _id: ModelId) -> Option<&Model> {
616 None
617 }
618
619 fn table(&self, id: TableId) -> Option<&Table> {
620 Some(db::Schema::table(self, id))
621 }
622
623 fn table_for_model(&self, _model: &ModelRoot) -> Option<&Table> {
624 None
625 }
626}
627
628impl Resolve for () {
629 fn model(&self, _id: ModelId) -> Option<&Model> {
630 None
631 }
632
633 fn table(&self, _id: TableId) -> Option<&Table> {
634 None
635 }
636
637 fn table_for_model(&self, _model: &ModelRoot) -> Option<&Table> {
638 None
639 }
640}
641
642impl<'a> ExprTarget<'a> {
643 pub fn as_model(self) -> Option<&'a ModelRoot> {
644 match self {
645 ExprTarget::Model(model) => Some(model),
646 _ => None,
647 }
648 }
649
650 #[track_caller]
651 pub fn as_model_unwrap(self) -> &'a ModelRoot {
652 match self.as_model() {
653 Some(model) => model,
654 _ => panic!("expected ExprTarget::Model; was {self:#?}"),
655 }
656 }
657
658 pub fn model_id(self) -> Option<ModelId> {
659 Some(match self {
660 ExprTarget::Model(model) => model.id,
661 _ => return None,
662 })
663 }
664
665 #[track_caller]
666 pub fn as_table_unwrap(self) -> &'a Table {
667 match self {
668 ExprTarget::Table(table) => table,
669 _ => panic!("expected ExprTarget::Table; was {self:#?}"),
670 }
671 }
672}
673
674impl<'a, T: Resolve> IntoExprTarget<'a, T> for ExprTarget<'a> {
675 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
676 match self {
677 ExprTarget::Source(source_table) => {
678 if source_table.from.len() == 1 && source_table.from[0].joins.is_empty() {
679 match &source_table.from[0].relation {
680 TableFactor::Table(source_table_id) => {
681 debug_assert_eq!(0, source_table_id.0);
682 debug_assert_eq!(1, source_table.tables.len());
683
684 match &source_table.tables[0] {
685 TableRef::Table(table_id) => {
686 let table = schema.table(*table_id).unwrap();
687 ExprTarget::Table(table)
688 }
689 _ => self,
690 }
691 }
692 }
693 } else {
694 self
695 }
696 }
697 _ => self,
698 }
699 }
700}
701
702impl<'a, T> IntoExprTarget<'a, T> for &'a ModelRoot {
703 fn into_expr_target(self, _schema: &'a T) -> ExprTarget<'a> {
704 ExprTarget::Model(self)
705 }
706}
707
708impl<'a, T> IntoExprTarget<'a, T> for &'a Table {
709 fn into_expr_target(self, _schema: &'a T) -> ExprTarget<'a> {
710 ExprTarget::Table(self)
711 }
712}
713
714impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Query {
715 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
716 self.body.into_expr_target(schema)
717 }
718}
719
720impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a ExprSet {
721 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
722 match self {
723 ExprSet::Select(select) => select.into_expr_target(schema),
724 ExprSet::SetOp(_) => todo!(),
725 ExprSet::Update(update) => update.into_expr_target(schema),
726 ExprSet::Values(_) => ExprTarget::Free,
727 ExprSet::Insert(insert) => insert.into_expr_target(schema),
728 }
729 }
730}
731
732impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Select {
733 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
734 self.source.into_expr_target(schema)
735 }
736}
737
738impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Insert {
739 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
740 self.target.into_expr_target(schema)
741 }
742}
743
744impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Update {
745 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
746 self.target.into_expr_target(schema)
747 }
748}
749
750impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Delete {
751 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
752 self.from.into_expr_target(schema)
753 }
754}
755
756impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a InsertTarget {
757 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
758 match self {
759 InsertTarget::Scope(query) => query.into_expr_target(schema),
760 InsertTarget::Model(model) => {
761 let Some(model) = schema.model(*model) else {
762 todo!()
763 };
764 ExprTarget::Model(model.expect_root())
765 }
766 InsertTarget::Table(insert_table) => {
767 let table = schema.table(insert_table.table).unwrap();
768 ExprTarget::Table(table)
769 }
770 }
771 }
772}
773
774impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a UpdateTarget {
775 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
776 match self {
777 UpdateTarget::Query(query) => query.into_expr_target(schema),
778 UpdateTarget::Model(model) => {
779 let Some(model) = schema.model(*model) else {
780 todo!()
781 };
782 ExprTarget::Model(model.expect_root())
783 }
784 UpdateTarget::Table(table_id) => {
785 let Some(table) = schema.table(*table_id) else {
786 todo!()
787 };
788 ExprTarget::Table(table)
789 }
790 }
791 }
792}
793
794impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Source {
795 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
796 match self {
797 Source::Model(source_model) => {
798 let Some(model) = schema.model(source_model.model) else {
799 todo!()
800 };
801 ExprTarget::Model(model.expect_root())
802 }
803 Source::Table(source_table) => {
804 ExprTarget::Source(source_table).into_expr_target(schema)
805 }
806 }
807 }
808}
809
810impl<'a, T: Resolve> IntoExprTarget<'a, T> for &'a Statement {
811 fn into_expr_target(self, schema: &'a T) -> ExprTarget<'a> {
812 match self {
813 Statement::Delete(stmt) => stmt.into_expr_target(schema),
814 Statement::Insert(stmt) => stmt.into_expr_target(schema),
815 Statement::Query(stmt) => stmt.into_expr_target(schema),
816 Statement::Update(stmt) => stmt.into_expr_target(schema),
817 }
818 }
819}
820
821impl<'a> ArgTyStack<'a> {
822 fn new(tys: &'a [Type]) -> ArgTyStack<'a> {
823 ArgTyStack { tys, parent: None }
824 }
825
826 fn resolve_arg_ty(&self, expr_arg: &ExprArg) -> &'a Type {
827 let mut nesting = expr_arg.nesting;
828 let mut args = self;
829
830 while nesting > 0 {
831 args = args.parent.unwrap();
832 nesting -= 1;
833 }
834
835 &args.tys[expr_arg.position]
836 }
837
838 fn scope<'child>(&'child self, tys: &'child [Type]) -> ArgTyStack<'child> {
839 ArgTyStack {
840 tys,
841 parent: Some(self),
842 }
843 }
844}