diff --git a/internal/reference/reference.go b/internal/reference/reference.go index 780e46d..4859b8f 100644 --- a/internal/reference/reference.go +++ b/internal/reference/reference.go @@ -450,6 +450,27 @@ func (r *Resolver[T, F]) Alias(name string, a T) error { return nil } +// UpdateAlias updates an existing alias in any scope with a new value. This is used to update +// source aliases with iteration-specific values while maintaining the persistent scope structure. +// The alias must already exist in some scope. +func (r *Resolver[T, F]) UpdateAlias(name string, a T) error { + if len(r.aliases) == 0 { + return errors.New("internal error - no scope available for alias update") + } + + aKey := aliasKey{r.currLib, name} + + // Search from innermost to outermost scope to find the alias + for i := len(r.aliases) - 1; i >= 0; i-- { + if _, exists := r.aliases[i][aKey]; exists { + r.aliases[i][aKey] = a + return nil + } + } + + return fmt.Errorf("alias %s does not exist in any scope", name) +} + // PublicDefs returns the public definitions stored in the reference resolver. func (r *Resolver[T, F]) PublicDefs() (map[result.LibKey]map[string]T, error) { pDefs := make(map[result.LibKey]map[string]T) diff --git a/internal/reference/reference_test.go b/internal/reference/reference_test.go index d419c98..c59f1b6 100644 --- a/internal/reference/reference_test.go +++ b/internal/reference/reference_test.go @@ -1102,6 +1102,49 @@ func buildLibrary(t *testing.T, r *Resolver[model.IExpression, model.IExpression } } +func TestUpdateAlias(t *testing.T) { + r := NewResolver[string, string]() + r.SetCurrentUnnamed() + r.EnterScope() + + // Create initial alias + if err := r.Alias("A", "initial_value"); err != nil { + t.Fatalf("Alias(A) returned unexpected error: %v", err) + } + + // Verify initial value + got, err := r.ResolveLocal("A") + if err != nil { + t.Errorf("ResolveLocal(A) returned unexpected error: %v", err) + } + if got != "initial_value" { + t.Errorf("ResolveLocal(A) = %v, want %v", got, "initial_value") + } + + // Update alias + if err := r.UpdateAlias("A", "updated_value"); err != nil { + t.Fatalf("UpdateAlias(A) returned unexpected error: %v", err) + } + + // Verify updated value + got, err = r.ResolveLocal("A") + if err != nil { + t.Errorf("ResolveLocal(A) returned unexpected error: %v", err) + } + if got != "updated_value" { + t.Errorf("ResolveLocal(A) = %v, want %v", got, "updated_value") + } + + // Test error case - update non-existent alias + err = r.UpdateAlias("NonExistent", "value") + if err == nil { + t.Errorf("UpdateAlias(NonExistent) expected error but got success") + } + if !strings.Contains(err.Error(), "does not exist") { + t.Errorf("UpdateAlias error should contain 'does not exist', got: %v", err) + } +} + func newFHIRModelInfo(t *testing.T) *modelinfo.ModelInfos { t.Helper() fhirMIBytes, err := embeddata.ModelInfos.ReadFile("third_party/cqframework/fhir-modelinfo-4.0.1.xml") diff --git a/interpreter/expressions.go b/interpreter/expressions.go index 0fe9478..1579a7a 100644 --- a/interpreter/expressions.go +++ b/interpreter/expressions.go @@ -308,6 +308,11 @@ func (i *interpreter) evalAliasRef(a *model.AliasRef) (result.Value, error) { } func (i *interpreter) evalIdentifierRef(r *model.IdentifierRef) (result.Value, error) { + // First try to resolve as a local reference (alias, variable, etc.) + if val, err := i.refs.ResolveLocal(r.Name); err == nil { + return val, nil + } + obj, err := i.refs.ScopedStruct() if err != nil { return result.Value{}, err diff --git a/interpreter/query.go b/interpreter/query.go index 2b2bb5d..bb01844 100644 --- a/interpreter/query.go +++ b/interpreter/query.go @@ -63,23 +63,27 @@ func (i *interpreter) evalQuery(q *model.Query) (result.Value, error) { return result.Value{}, err } - sourceObj, err := i.letClause(q.Let) - if err != nil { - return result.Value{}, err + // Register let clauses in the query scope (they will be evaluated per iteration) + if len(q.Let) > 0 { + for _, letClause := range q.Let { + // Register let variables with placeholder values - they'll be updated per iteration + if err := i.refs.Alias(letClause.Identifier, result.Value{}); err != nil { + return result.Value{}, err + } + } } - sourceObjs = append(sourceObjs, sourceObj...) for _, relationship := range q.Relationship { var err error var sourceObj result.Value - iters, sourceObj, err = i.relationshipClause(iters, relationship) + iters, sourceObj, err = i.relationshipClause(iters, relationship, q.Let) if err != nil { return result.Value{}, err } sourceObjs = append(sourceObjs, sourceObj) } - iters, err = i.whereClause(iters, q.Where) + iters, err = i.whereClause(iters, q.Where, q.Let) if err != nil { return result.Value{}, err } @@ -95,7 +99,7 @@ func (i *interpreter) evalQuery(q *model.Query) (result.Value, error) { } if q.Return != nil { - finalVals, err = i.returnClause(iters, q.Return) + finalVals, err = i.returnClause(iters, q.Return, q.Let) if err != nil { return result.Value{}, err } @@ -178,7 +182,10 @@ func (i *interpreter) sourceClause(s []*model.AliasedSource) ([]iteration, []res } aliases = append(aliases, a) } - return cartesianProduct(aliases), sourceObjs, nil + + iterations := cartesianProduct(aliases) + + return iterations, sourceObjs, nil } // cartesianProduct converts [[{A, 4}], [{B, 1}, {B, 2}, {B, 3}]] into the cartesian product @@ -207,24 +214,8 @@ func cartesianProduct(aliases [][]alias) []iteration { return cartIters } -func (i *interpreter) letClause(m []*model.LetClause) ([]result.Value, error) { - sourceObjs := make([]result.Value, 0, len(m)) - for _, letClause := range m { - obj, err := i.evalExpression(letClause.Expression) - if err != nil { - return nil, err - } - sourceObjs = append(sourceObjs, obj) - if err := i.refs.Alias(letClause.Identifier, obj); err != nil { - return nil, err - } - } - - return sourceObjs, nil -} - -func (i *interpreter) relationshipClause(iters []iteration, m model.IRelationshipClause) ([]iteration, result.Value, error) { +func (i *interpreter) relationshipClause(iters []iteration, m model.IRelationshipClause, letClauses []*model.LetClause) ([]iteration, result.Value, error) { var relAliasName string var relAliasSource model.IExpression var suchThat model.IExpression @@ -260,20 +251,35 @@ func (i *interpreter) relationshipClause(iters []iteration, m model.IRelationshi return nil, result.Value{}, err } + // Register the relationship alias in the query scope + if len(relIters) > 0 { + if err := i.refs.Alias(relAliasName, relIters[0]); err != nil { + return nil, result.Value{}, err + } + } + filteredIters := []iteration{} for _, iter := range iters { for _, relIter := range relIters { - i.refs.EnterScope() - // Define the alias for the relationship. - if err := i.refs.Alias(relAliasName, relIter); err != nil { - return nil, result.Value{}, err - } - // Define the aliases for the query iterations. + // Register or update source aliases with current iteration values for _, alias := range iter { - if err := i.refs.Alias(alias.alias, alias.obj); err != nil { - return nil, result.Value{}, err + if err := i.refs.UpdateAlias(alias.alias, alias.obj); err != nil { + // If UpdateAlias fails, the alias doesn't exist yet, so create it + if err := i.refs.Alias(alias.alias, alias.obj); err != nil { + return nil, result.Value{}, err + } } } + + // Update the relationship alias with current relationship iteration value + if err := i.refs.UpdateAlias(relAliasName, relIter); err != nil { + return nil, result.Value{}, err + } + + // Evaluate let clauses for this iteration + if err := i.evaluateLetClauses(letClauses); err != nil { + return nil, result.Value{}, err + } filter, err := i.evalExpression(suchThat) if err != nil { @@ -286,29 +292,34 @@ func (i *interpreter) relationshipClause(iters []iteration, m model.IRelationshi if (with && filter.GolangValue() == true) || (!with && filter.GolangValue() == false) { // We found a relationship where such that expression evaluated to true. Save this iter in // filteredIters and break. - i.refs.ExitScope() filteredIters = append(filteredIters, iter) break } - i.refs.ExitScope() } } return filteredIters, relSourceObj, nil } -func (i *interpreter) whereClause(iters []iteration, where model.IExpression) ([]iteration, error) { +func (i *interpreter) whereClause(iters []iteration, where model.IExpression, letClauses []*model.LetClause) ([]iteration, error) { if where == nil { return iters, nil } var filteredIters []iteration for _, iter := range iters { - i.refs.EnterScope() + for _, alias := range iter { - if err := i.refs.Alias(alias.alias, alias.obj); err != nil { - return nil, err + if err := i.refs.UpdateAlias(alias.alias, alias.obj); err != nil { + if err := i.refs.Alias(alias.alias, alias.obj); err != nil { + return nil, err + } } } + + if err := i.evaluateLetClauses(letClauses); err != nil { + return nil, err + } + filter, err := i.evalExpression(where) if err != nil { return nil, err @@ -320,11 +331,26 @@ func (i *interpreter) whereClause(iters []iteration, where model.IExpression) ([ if filter.GolangValue() == true { filteredIters = append(filteredIters, iter) } - i.refs.ExitScope() } return filteredIters, nil } +// evaluateLetClauses evaluates let clauses for the current iteration +func (i *interpreter) evaluateLetClauses(letClauses []*model.LetClause) error { + for _, letClause := range letClauses { + obj, err := i.evalExpression(letClause.Expression) + if err != nil { + return err + } + + // Update the let variable with the current iteration's value + if err := i.refs.UpdateAlias(letClause.Identifier, obj); err != nil { + return err + } + } + return nil +} + func (i *interpreter) aggregateClause(iters []iteration, aggregateClause *model.AggregateClause) ([]result.Value, error) { var filteredIters []iteration if aggregateClause.Distinct { @@ -340,51 +366,59 @@ func (i *interpreter) aggregateClause(iters []iteration, aggregateClause *model. return nil, err } - for _, iter := range filteredIters { - i.refs.EnterScope() - - if err := i.refs.Alias(aggregateClause.Identifier, aggregateObj); err != nil { - i.refs.ExitScope() - return nil, err - } + // Register the aggregate identifier in the query scope + if err := i.refs.Alias(aggregateClause.Identifier, aggregateObj); err != nil { + return nil, err + } + for _, iter := range filteredIters { for _, alias := range iter { - if err := i.refs.Alias(alias.alias, alias.obj); err != nil { - i.refs.ExitScope() - return nil, err + if err := i.refs.UpdateAlias(alias.alias, alias.obj); err != nil { + // If UpdateAlias fails, the alias doesn't exist yet, so create it + if err := i.refs.Alias(alias.alias, alias.obj); err != nil { + return nil, err + } } } + if err := i.refs.UpdateAlias(aggregateClause.Identifier, aggregateObj); err != nil { + return nil, err + } + aggregateObj, err = i.evalExpression(aggregateClause.Expression) if err != nil { - i.refs.ExitScope() return nil, err } - - i.refs.ExitScope() } return []result.Value{aggregateObj}, nil } -func (i *interpreter) returnClause(iters []iteration, returnClause *model.ReturnClause) ([]result.Value, error) { +func (i *interpreter) returnClause(iters []iteration, returnClause *model.ReturnClause, letClauses []*model.LetClause) ([]result.Value, error) { returnObjs := make([]result.Value, 0, len(iters)) for _, iter := range iters { - i.refs.EnterScope() for _, alias := range iter { - if err := i.refs.Alias(alias.alias, alias.obj); err != nil { - return nil, err + if err := i.refs.UpdateAlias(alias.alias, alias.obj); err != nil { + // If UpdateAlias fails, the alias doesn't exist yet, so create it + if err := i.refs.Alias(alias.alias, alias.obj); err != nil { + return nil, err + } } } + + if err := i.evaluateLetClauses(letClauses); err != nil { + return nil, err + } + retObj, err := i.evalExpression(returnClause.Expression) if err != nil { return nil, err } + if returnClause.Distinct { returnObjs = appendIfDistinct(returnObjs, retObj) } else { returnObjs = append(returnObjs, retObj) } - i.refs.ExitScope() } return returnObjs, nil } @@ -529,7 +563,12 @@ func (i *interpreter) getSortValue(it model.ISortByItem, v result.Value) (result return result.Value{}, fmt.Errorf("internal error - unsupported sort by item type: %T", iv) } - return i.dateTimeOrError(rv) + if converted, err := i.dateTimeOrError(rv); err == nil { + return converted, nil + } + + // Return the original value for non-DateTime types (Integer, String, etc.) + return rv, nil } func (i *interpreter) sortByColumnOrExpression(objs []result.Value, sbis []model.ISortByItem) error { @@ -539,24 +578,76 @@ func (i *interpreter) sortByColumnOrExpression(objs []result.Value, sbis []model ap, err := i.getSortValue(sortItem, a) if err != nil { sortErr = err - continue + return 0 // Return 0 to indicate equal, which preserves original order } bp, err := i.getSortValue(sortItem, b) if err != nil { sortErr = err - continue + return 0 // Return 0 to indicate equal, which preserves original order } - av := ap.GolangValue().(result.DateTime).Date - bv := bp.GolangValue().(result.DateTime).Date - - // In the future when we have an implementation of dateTime comparison without precision we should swap to using that. - // TODO(b/308012659): Implement dateTime comparison that doesn't take a precision. - if av.Equal(bv) { + + switch ap.RuntimeType() { + case types.DateTime: + // In the future when we have an implementation of dateTime comparison without precision we should swap to using that. + // TODO(b/308012659): Implement dateTime comparison that doesn't take a precision. + av := ap.GolangValue().(result.DateTime).Date + bv := bp.GolangValue().(result.DateTime).Date + if av.Equal(bv) { + continue + } else if sortItem.SortDirection() == model.DESCENDING { + return bv.Compare(av) + } + return av.Compare(bv) + case types.Date: + av := ap.GolangValue().(result.Date).Date + bv := bp.GolangValue().(result.Date).Date + if av.Equal(bv) { + continue + } else if sortItem.SortDirection() == model.DESCENDING { + return bv.Compare(av) + } + return av.Compare(bv) + case types.Integer: + av := ap.GolangValue().(int32) + bv := bp.GolangValue().(int32) + if av == bv { + continue + } else if sortItem.SortDirection() == model.DESCENDING { + return compareNumeralInt(bv, av) + } + return compareNumeralInt(av, bv) + case types.Decimal: + av := ap.GolangValue().(float64) + bv := bp.GolangValue().(float64) + if av == bv { + continue + } else if sortItem.SortDirection() == model.DESCENDING { + return compareNumeralInt(bv, av) + } + return compareNumeralInt(av, bv) + case types.Long: + av := ap.GolangValue().(int64) + bv := bp.GolangValue().(int64) + if av == bv { + continue + } else if sortItem.SortDirection() == model.DESCENDING { + return compareNumeralInt(bv, av) + } + return compareNumeralInt(av, bv) + case types.String: + av := ap.GolangValue().(string) + bv := bp.GolangValue().(string) + cmp := strings.Compare(av, bv) + if cmp == 0 { + continue + } else if sortItem.SortDirection() == model.DESCENDING { + return -cmp + } + return cmp + default: + // For other types, try to continue with next sort item continue - } else if sortItem.SortDirection() == model.DESCENDING { - return bv.Compare(av) } - return av.Compare(bv) } // All columns evaluated to equal so this sort is undefined. return 0 diff --git a/parser/expressions.go b/parser/expressions.go index 16b968d..08d9d51 100644 --- a/parser/expressions.go +++ b/parser/expressions.go @@ -498,6 +498,15 @@ func (v *visitor) VisitReferentialIdentifier(ctx cql.IReferentialIdentifierConte modelFunc, err := v.refs.ResolveLocal(name) if err != nil { + // Only create IdentifierRef for forward references in specific contexts where it's legitimate + // (like sort clauses referencing tuple fields from return clauses) + // For general invalid references, we should still return an error + if v.isInSortContext() { + return &model.IdentifierRef{ + Name: name, + Expression: model.ResultType(types.Any), + } + } return v.badExpression(err.Error(), ctx) } return modelFunc() @@ -1225,6 +1234,11 @@ func stringToTimeUnit(s string) model.Unit { return model.UNSETUNIT } +// isInSortContext returns true if we're currently parsing within a sort context +func (v *visitor) isInSortContext() bool { + return v.inSortContext +} + // unquoteString takes the given CQL string, removes the surrounding ' and unescapes it. // Escaped to character mapping: https://cql.hl7.org/03-developersguide.html#literals. // TODO(b/302003569): properly unescaping unicode characters is not yet supported diff --git a/parser/library.go b/parser/library.go index aebe401..9245840 100644 --- a/parser/library.go +++ b/parser/library.go @@ -563,4 +563,7 @@ type visitor struct { // Accumulated parsing errors to be returned to the caller. errors parsingErrors + + // Track if we're currently parsing within a sort context + inSortContext bool } diff --git a/parser/query.go b/parser/query.go index 246e8e7..9bcbc5c 100644 --- a/parser/query.go +++ b/parser/query.go @@ -230,7 +230,10 @@ func (v *visitor) parseSortClause(sc cql.ISortClauseContext, q *model.Query) (*m v.refs.EnterStructScope(func() model.IExpression { return q.Source[0] }) defer v.refs.ExitStructScope() + // Set sort context flag to allow forward references in sort expressions + v.inSortContext = true sortExpr := v.VisitExpression(sbi.ExpressionTerm()) + v.inSortContext = false switch t := sortExpr.(type) { case *model.IdentifierRef: diff --git a/tests/enginetests/cross_library_test.go b/tests/enginetests/cross_library_test.go new file mode 100644 index 0000000..32d90d2 --- /dev/null +++ b/tests/enginetests/cross_library_test.go @@ -0,0 +1,171 @@ +package enginetests + +import ( + "context" + "testing" + + "github.com/google/cql/interpreter" + "github.com/google/cql/parser" + "github.com/google/cql/result" + "github.com/google/go-cmp/cmp" + "github.com/lithammer/dedent" + "google.golang.org/protobuf/testing/protocmp" +) + +func TestCrossLibraryAliasResolution(t *testing.T) { + tests := []struct { + name string + libraries []string + wantResult result.Value + wantError string + }{ + { + name: "Cross library function calls with alias M - should reproduce the error", + libraries: []string{ + // Helper library (like CDS_Connect_Commons) + dedent.Dedent(` + library TestCommons version '1.0.0' + using FHIR version '4.0.1' + + define function ActiveMedicationStatement(MedList List): + MedList M + where M > 5 + + define function ActiveMedicationRequest(MedList List): + MedList M + where M > 3 + `), + // Main library (like Statin Therapy) + dedent.Dedent(` + library MainLib version '1.0.0' + using FHIR version '4.0.1' + include TestCommons version '1.0.0' called TC + + define TESTRESULT: + exists(TC.ActiveMedicationStatement({1, 6, 3, 8})) + or exists(TC.ActiveMedicationRequest({1, 2, 4, 7})) + `), + }, + wantResult: newOrFatal(t, true), + }, + { + name: "Cross library with FHIR retrievals - closer to real scenario", + libraries: []string{ + // Helper library with FHIR functions + dedent.Dedent(` + library FHIRCommons version '1.0.0' + using FHIR version '4.0.1' + + define function ActiveMedicationStatement(MedList List): + MedList M + where M.status.value = 'active' + + define function ActiveMedicationRequest(MedList List): + MedList M + where M.status.value = 'active' + `), + // Main library calling FHIR functions + dedent.Dedent(` + library StatinLib version '1.0.0' + using FHIR version '4.0.1' + include FHIRCommons version '1.0.0' called FC + + define TESTRESULT: + exists(FC.ActiveMedicationStatement([MedicationStatement])) + or exists(FC.ActiveMedicationRequest([MedicationRequest])) + `), + }, + wantResult: newOrFatal(t, false), // Empty retrievals + }, + { + name: "Exact CDS_Connect_Commons pattern with let clause - works with our fix", + libraries: []string{ + // Exact pattern from CDS_Connect_Commons + dedent.Dedent(` + library CDSCommons version '1.0.0' + using FHIR version '4.0.1' + + define function PeriodToInterval(period FHIR.Period): + if period is null then + null + else + Interval[period."start".value, period."end".value] + + define function ActiveMedicationStatement(MedList List): + MedList M + let EffectivePeriod: PeriodToInterval(M.effective as FHIR.Period) + where M.status.value = 'active' + and (end of EffectivePeriod is null or end of EffectivePeriod after Now()) + `), + // Main library that calls the function + dedent.Dedent(` + library MainLib version '1.0.0' + using FHIR version '4.0.1' + include CDSCommons version '1.0.0' called C3F + + define TESTRESULT: + exists(C3F.ActiveMedicationStatement([MedicationStatement])) + `), + }, + wantResult: newOrFatal(t, false), // Empty retrievals + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + p := newFHIRParser(t) + + // Add FHIRHelpers to all libraries + libsWithHelpers := make([]string, len(tc.libraries)) + for i, lib := range tc.libraries { + libsWithHelpers[i] = addFHIRHelpersLib(t, lib)[0] + } + + parsedLibs, err := p.Libraries(context.Background(), libsWithHelpers, parser.Config{}) + if tc.wantError != "" { + if err == nil { + t.Fatalf("Expected parse error containing %q, but got no error", tc.wantError) + } + if !contains(err.Error(), tc.wantError) { + t.Fatalf("Expected parse error containing %q, but got: %v", tc.wantError, err) + } + return + } + if err != nil { + t.Fatalf("Parse returned unexpected error: %v", err) + } + + results, err := interpreter.Eval(context.Background(), parsedLibs, defaultInterpreterConfig(t, p)) + if tc.wantError != "" { + if err == nil { + t.Fatalf("Expected eval error containing %q, but got no error", tc.wantError) + } + if !contains(err.Error(), tc.wantError) { + t.Fatalf("Expected eval error containing %q, but got: %v", tc.wantError, err) + } + return + } + if err != nil { + t.Fatalf("Eval returned unexpected error: %v", err) + } + + gotResult := getTESTRESULTWithSources(t, results) + if diff := cmp.Diff(tc.wantResult, gotResult, protocmp.Transform()); diff != "" { + t.Errorf("Eval diff (-want +got)\n%v", diff) + } + }) + } +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(substr) == 0 || + (len(s) > len(substr) && (s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || + func() bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false + }()))) +} diff --git a/tests/enginetests/query_test.go b/tests/enginetests/query_test.go index 0d7ac27..b5f9e55 100644 --- a/tests/enginetests/query_test.go +++ b/tests/enginetests/query_test.go @@ -617,6 +617,550 @@ func TestQuery(t *testing.T) { cql: "define TESTRESULT: (null as Code) l return l.code", wantResult: newOrFatal(t, nil), }, + { + // Test case for the "could not resolve the local reference to M" bug fix + name: "Alias reference in where clause", + cql: dedent.Dedent(` + using FHIR version '4.0.1' + include FHIRHelpers version '4.0.1' called FHIRHelpers + + define TESTRESULT: [Encounter] M where M.id = '1'`), + wantResult: newOrFatal(t, result.List{ + Value: []result.Value{ + newOrFatal(t, result.Named{Value: RetrieveFHIRResource(t, "Encounter", "1"), RuntimeType: &types.Named{TypeName: "FHIR.Encounter"}}), + }, + StaticType: &types.List{ElementType: &types.Named{TypeName: "FHIR.Encounter"}}, + }), + }, + { + // Test case for function alias resolution - reproduces the CDS_Connect_Commons issue + name: "Function with query alias", + cql: dedent.Dedent(` + define function TestFunction(IntList List): + IntList M + where M > 5 + + define TESTRESULT: TestFunction({1, 6, 3, 8})`), + wantResult: newOrFatal(t, result.List{ + Value: []result.Value{ + newOrFatal(t, 6), + newOrFatal(t, 8), + }, + StaticType: &types.List{ElementType: types.Integer}, + }), + }, + { + // Test case for function with let clause and alias - closer to CDS_Connect_Commons pattern + name: "Function with let clause and query alias", + cql: dedent.Dedent(` + define function TestFunctionWithLet(IntList List): + IntList M + let Threshold: 5 + where M > Threshold + + define TESTRESULT: TestFunctionWithLet({1, 6, 3, 8})`), + wantResult: newOrFatal(t, result.List{ + Value: []result.Value{ + newOrFatal(t, 6), + newOrFatal(t, 8), + }, + StaticType: &types.List{ElementType: types.Integer}, + }), + }, + { + // Test case for multiple function calls with same alias - reproduces the real CDS_Connect_Commons issue + name: "Multiple functions with same alias name", + cql: dedent.Dedent(` + define function ActiveMedicationStatement(MedList List): + MedList M + where M > 5 + + define function ActiveMedicationRequest(MedList List): + MedList M + where M > 3 + + define TESTRESULT: + exists(ActiveMedicationStatement({1, 6, 3, 8})) + or exists(ActiveMedicationRequest({1, 2, 4, 7}))`), + wantResult: newOrFatal(t, true), + }, + { + // Test case that reproduces the exact real-world error with FHIR retrievals and cross-library calls + name: "Cross library function calls with FHIR retrievals", + cql: dedent.Dedent(` + library TestCommons version '1.0.0' + using FHIR version '4.0.1' + + define function ActiveMedicationStatement(MedList List): + MedList M + where M.status.value = 'active' + + define function ActiveMedicationRequest(MedList List): + MedList M + where M.status.value = 'active' + + define TESTRESULT: + exists(ActiveMedicationStatement([MedicationStatement])) + or exists(ActiveMedicationRequest([MedicationRequest]))`), + wantResult: newOrFatal(t, false), // Both retrievals return empty lists, so exists() returns false + }, + { + // Test case for let clause accessing source alias - should fail with current implementation + name: "Let clause with source alias reference", + cql: dedent.Dedent(` + define function TestLetSourceAlias(nums List): + nums N + let threshold: 5 + where N > threshold + + define TESTRESULT: TestLetSourceAlias({1, 10, 15, 20})`), + wantResult: newOrFatal(t, result.List{ + Value: []result.Value{ + newOrFatal(t, 10), + newOrFatal(t, 15), + newOrFatal(t, 20), + }, // Values where N > 5 + StaticType: &types.List{ElementType: types.Integer}, + }), + }, + { + // Test case for relationship clause with alias access + name: "Relationship clause with alias access", + cql: dedent.Dedent(` + define TESTRESULT: + ({1, 2, 3}) A + with ({4, 5}) B + such that A + B > 6 + return A`), + wantResult: newOrFatal(t, result.List{ + Value: []result.Value{ + newOrFatal(t, 2), // 2+4=6, 2+5=7 > 6 + newOrFatal(t, 3), // 3+4=7, 3+5=8 > 6 + }, + StaticType: &types.List{ElementType: types.Integer}, + }), + }, + { + // Test case for return clause accessing let variables + name: "Return clause accessing let variables", + cql: dedent.Dedent(` + define function TestReturnLet(nums List): + nums N + let factor: 10 + return N * factor + + define TESTRESULT: TestReturnLet({1, 2, 3})`), + wantResult: newOrFatal(t, result.List{ + Value: []result.Value{ + newOrFatal(t, 10), + newOrFatal(t, 20), + newOrFatal(t, 30), + }, + StaticType: &types.List{ElementType: types.Integer}, + }), + }, + { + // Test case for aggregate clause with source alias + name: "Aggregate clause with source alias", + cql: dedent.Dedent(` + define function TestAggregateAlias(nums List): + nums N + aggregate A starting 0: A + N + + define TESTRESULT: TestAggregateAlias({1, 2, 3})`), + wantResult: newOrFatal(t, 6), // Sum of 1+2+3 + }, + { + // Test case for complex multi-clause with all scope issues + name: "Complex multi-clause with all scope management", + cql: dedent.Dedent(` + define function TestAllClauses(nums List): + nums N + let threshold: 5, + multiplier: 2 + with ({1, 2}) H + such that N + H > threshold + where N > threshold + return N * multiplier + + define TESTRESULT: TestAllClauses({1, 3, 6, 8, 10})`), + wantResult: newOrFatal(t, result.List{ + Value: []result.Value{ + newOrFatal(t, 12), // 6*2, where 6>5 and 6+H>5 for some H + newOrFatal(t, 16), // 8*2, where 8>5 and 8+H>5 for some H + newOrFatal(t, 20), // 10*2, where 10>5 and 10+H>5 for some H + }, + StaticType: &types.List{ElementType: types.Integer}, + }), + }, + { + // Test case that should expose scope management issues - let clause referencing source alias + name: "Let clause referencing source alias directly", + cql: dedent.Dedent(` + define function TestLetWithSourceAlias(nums List): + nums N + let doubled: N * 2 + where doubled > 10 + + define TESTRESULT: TestLetWithSourceAlias({1, 3, 6, 8})`), + wantResult: newOrFatal(t, result.List{ + Value: []result.Value{ + newOrFatal(t, 6), // 6*2=12 > 10 + newOrFatal(t, 8), // 8*2=16 > 10 + }, + StaticType: &types.List{ElementType: types.Integer}, + }), + }, + { + // Test case that reproduces the NCQA_CQLBase error - sort by tuple field + name: "Sort by tuple field from return clause", + cql: dedent.Dedent(` + define TESTRESULT: + ({1, 3, 2}) I + return Tuple { + value: I, + sortKey: I * 10 + } + sort by sortKey asc`), + wantResult: newOrFatal(t, result.List{ + Value: []result.Value{ + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, 1), + "sortKey": newOrFatal(t, 10), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Integer, + "sortKey": types.Integer, + }}, + }), + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, 2), + "sortKey": newOrFatal(t, 20), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Integer, + "sortKey": types.Integer, + }}, + }), + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, 3), + "sortKey": newOrFatal(t, 30), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Integer, + "sortKey": types.Integer, + }}, + }), + }, + StaticType: &types.List{ElementType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Integer, + "sortKey": types.Integer, + }}}, + }), + }, + { + // Test sorting by integer field in tuple + name: "Sort by integer field ascending", + cql: dedent.Dedent(` + define TESTRESULT: + ({3, 1, 2}) I + return Tuple { + value: I, + intField: I + } + sort by intField asc`), + wantResult: newOrFatal(t, result.List{ + Value: []result.Value{ + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, 1), + "intField": newOrFatal(t, 1), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Integer, + "intField": types.Integer, + }}, + }), + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, 2), + "intField": newOrFatal(t, 2), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Integer, + "intField": types.Integer, + }}, + }), + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, 3), + "intField": newOrFatal(t, 3), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Integer, + "intField": types.Integer, + }}, + }), + }, + StaticType: &types.List{ElementType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Integer, + "intField": types.Integer, + }}}, + }), + }, + { + // Test sorting by integer field descending + name: "Sort by integer field descending", + cql: dedent.Dedent(` + define TESTRESULT: + ({1, 3, 2}) I + return Tuple { + value: I, + intField: I + } + sort by intField desc`), + wantResult: newOrFatal(t, result.List{ + Value: []result.Value{ + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, 3), + "intField": newOrFatal(t, 3), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Integer, + "intField": types.Integer, + }}, + }), + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, 2), + "intField": newOrFatal(t, 2), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Integer, + "intField": types.Integer, + }}, + }), + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, 1), + "intField": newOrFatal(t, 1), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Integer, + "intField": types.Integer, + }}, + }), + }, + StaticType: &types.List{ElementType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Integer, + "intField": types.Integer, + }}}, + }), + }, + { + // Test sorting by decimal field + name: "Sort by decimal field ascending", + cql: dedent.Dedent(` + define TESTRESULT: + ({3.2, 1.1, 2.5}) D + return Tuple { + value: D, + decField: D + } + sort by decField asc`), + wantResult: newOrFatal(t, result.List{ + Value: []result.Value{ + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, 1.1), + "decField": newOrFatal(t, 1.1), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Decimal, + "decField": types.Decimal, + }}, + }), + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, 2.5), + "decField": newOrFatal(t, 2.5), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Decimal, + "decField": types.Decimal, + }}, + }), + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, 3.2), + "decField": newOrFatal(t, 3.2), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Decimal, + "decField": types.Decimal, + }}, + }), + }, + StaticType: &types.List{ElementType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Decimal, + "decField": types.Decimal, + }}}, + }), + }, + { + // Test sorting by string field + name: "Sort by string field ascending", + cql: dedent.Dedent(` + define TESTRESULT: + ({'zebra', 'apple', 'banana'}) S + return Tuple { + value: S, + strField: S + } + sort by strField asc`), + wantResult: newOrFatal(t, result.List{ + Value: []result.Value{ + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, "apple"), + "strField": newOrFatal(t, "apple"), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.String, + "strField": types.String, + }}, + }), + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, "banana"), + "strField": newOrFatal(t, "banana"), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.String, + "strField": types.String, + }}, + }), + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, "zebra"), + "strField": newOrFatal(t, "zebra"), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.String, + "strField": types.String, + }}, + }), + }, + StaticType: &types.List{ElementType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.String, + "strField": types.String, + }}}, + }), + }, + { + // Test sorting by DateTime field + name: "Sort by DateTime field descending", + cql: dedent.Dedent(` + define TESTRESULT: + ({@2015-01-01T10:00:00.000Z, @2013-01-01T10:00:00.000Z, @2014-01-01T10:00:00.000Z}) DT + return Tuple { + value: DT, + dtField: DT + } + sort by dtField desc`), + wantResult: newOrFatal(t, result.List{ + Value: []result.Value{ + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, result.DateTime{Date: time.Date(2015, time.January, 1, 10, 0, 0, 0, time.UTC), Precision: model.MILLISECOND}), + "dtField": newOrFatal(t, result.DateTime{Date: time.Date(2015, time.January, 1, 10, 0, 0, 0, time.UTC), Precision: model.MILLISECOND}), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.DateTime, + "dtField": types.DateTime, + }}, + }), + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, result.DateTime{Date: time.Date(2014, time.January, 1, 10, 0, 0, 0, time.UTC), Precision: model.MILLISECOND}), + "dtField": newOrFatal(t, result.DateTime{Date: time.Date(2014, time.January, 1, 10, 0, 0, 0, time.UTC), Precision: model.MILLISECOND}), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.DateTime, + "dtField": types.DateTime, + }}, + }), + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, result.DateTime{Date: time.Date(2013, time.January, 1, 10, 0, 0, 0, time.UTC), Precision: model.MILLISECOND}), + "dtField": newOrFatal(t, result.DateTime{Date: time.Date(2013, time.January, 1, 10, 0, 0, 0, time.UTC), Precision: model.MILLISECOND}), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.DateTime, + "dtField": types.DateTime, + }}, + }), + }, + StaticType: &types.List{ElementType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.DateTime, + "dtField": types.DateTime, + }}}, + }), + }, + { + // Test sorting by Date field + name: "Sort by Date field ascending", + cql: dedent.Dedent(` + define TESTRESULT: + ({@2015-01-01, @2013-01-01, @2014-01-01}) D + return Tuple { + value: D, + dateField: D + } + sort by dateField asc`), + wantResult: newOrFatal(t, result.List{ + Value: []result.Value{ + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, result.Date{Date: time.Date(2013, time.January, 1, 0, 0, 0, 0, defaultEvalTimestamp.Location()), Precision: model.DAY}), + "dateField": newOrFatal(t, result.Date{Date: time.Date(2013, time.January, 1, 0, 0, 0, 0, defaultEvalTimestamp.Location()), Precision: model.DAY}), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Date, + "dateField": types.Date, + }}, + }), + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, result.Date{Date: time.Date(2014, time.January, 1, 0, 0, 0, 0, defaultEvalTimestamp.Location()), Precision: model.DAY}), + "dateField": newOrFatal(t, result.Date{Date: time.Date(2014, time.January, 1, 0, 0, 0, 0, defaultEvalTimestamp.Location()), Precision: model.DAY}), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Date, + "dateField": types.Date, + }}, + }), + newOrFatal(t, result.Tuple{ + Value: map[string]result.Value{ + "value": newOrFatal(t, result.Date{Date: time.Date(2015, time.January, 1, 0, 0, 0, 0, defaultEvalTimestamp.Location()), Precision: model.DAY}), + "dateField": newOrFatal(t, result.Date{Date: time.Date(2015, time.January, 1, 0, 0, 0, 0, defaultEvalTimestamp.Location()), Precision: model.DAY}), + }, + RuntimeType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Date, + "dateField": types.Date, + }}, + }), + }, + StaticType: &types.List{ElementType: &types.Tuple{ElementTypes: map[string]types.IType{ + "value": types.Date, + "dateField": types.Date, + }}}, + }), + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { @@ -633,10 +1177,10 @@ func TestQuery(t *testing.T) { if err != nil { t.Fatalf("Eval returned unexpected error: %v", err) } - gotResult := getTESTRESULTWithSources(t, results) - if diff := cmp.Diff(tc.wantResult, gotResult, protocmp.Transform()); diff != "" { + if diff := cmp.Diff(tc.wantResult, getTESTRESULT(t, results), protocmp.Transform()); diff != "" { t.Errorf("Eval diff (-want +got)\n%v", diff) } + gotResult := getTESTRESULTWithSources(t, results) if diff := cmp.Diff(tc.wantSourceExpression, gotResult.SourceExpression(), protocmp.Transform()); tc.wantSourceExpression != nil && diff != "" { t.Errorf("Eval SourceExpression diff (-want +got)\n%v", diff) }