From ce76f9428ba8996a8879a285f3ab4c5a517e2e93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Oskar=20Segersv=C3=A4rd?= Date: Sat, 23 Aug 2025 16:01:00 +0200 Subject: [PATCH 1/3] Start to rebuild type inference based on algorithm W --- types/infer.go | 402 +++++++++++++++++++------------------------- types/infer_test.go | 92 +++++----- types/subst.go | 42 +++++ types/type.go | 228 ++++++++++++++++++++++++- types/type_test.go | 81 +++++++++ 5 files changed, 574 insertions(+), 271 deletions(-) create mode 100644 types/subst.go diff --git a/types/infer.go b/types/infer.go index 038cb65..5c4f7a2 100644 --- a/types/infer.go +++ b/types/infer.go @@ -34,14 +34,6 @@ func (s *Scope[T]) Bind(name string, val T) *Scope[T] { return &Scope[T]{s, name, val} } -func (s *Scope[T]) Rebind(name string, val T) bool { - if bound := s.Get(name); bound != nil { - bound.val = val - return true - } - return false -} - type TypeScope = *Scope[TypeRef] type context struct { @@ -59,6 +51,11 @@ func (c *context) bind(name string, ref TypeRef) TypeScope { return c.scope } +// Unbinds the last bound variable. +func (c *context) unbind() { + c.scope = c.scope.parent +} + func Infer(se ast.SourceExpr) (string, error) { var reg Registry var scope TypeScope @@ -91,15 +88,21 @@ func InferInScope(reg *Registry, scope TypeScope, se ast.SourceExpr) (ref TypeRe } }() - return context.infer(se.Expr), err + _, ref = context.infer(se.Expr) + return ref, err } -func (c *context) infer(expr ast.Expr) TypeRef { +func (c *context) infer(expr ast.Expr) (Subst, TypeRef) { switch x := expr.(type) { case *ast.Literal: - return literalTypeRef(x.Kind) + return nil, literalTypeRef(x.Kind) case *ast.Ident: - return c.scope.Lookup(c.source.GetString(x.Pos)) + name := c.source.GetString(x.Pos) + ref := c.scope.Lookup(name) + if ref == NeverRef { + c.bail(x.Pos, "unbound variable: "+name) + } + return nil, c.reg.Instantiate(ref) case *ast.WhereExpr: return c.where(x) case *ast.ListExpr: @@ -107,63 +110,47 @@ func (c *context) infer(expr ast.Expr) TypeRef { case *ast.RecordExpr: return c.record(x) case ast.EnumExpr: - return c.enum(x) - case *ast.FuncExpr: - unbound := c.reg.Unbound() - // Hold onto the binding, in case inferring the body rebinds its type. - binding := c.bind(c.source.GetString(x.Arg.Span()), unbound) - ret := c.infer(x.Body) - return c.reg.Func(binding.val, ret) - case *ast.CallExpr: - // Special-case pick with a value. - if pick, ok := x.Fn.(*ast.BinaryExpr); ok && pick.Op == token.PICK { - return c.pick(pick, x.Arg) - } - - typ := c.infer(x.Fn) - arg := c.infer(x.Arg) - - if !typ.IsFunction() { - // If the argument is an unbound identifier, rebind it. - id, ok := x.Fn.(*ast.Ident) - if ok && typ.IsUnbound() { - name := c.source.GetString(id.Pos) - // Let's steal the now unused (?) unbound. - fn := c.reg.Func(arg, typ) - if c.scope.Rebind(name, fn) { - return typ - } - - // Let's try to rebind a type. - } - - c.bail(x.Span(), fmt.Sprintf("cannot call non-function %s", c.reg.String(typ))) - } - fn := c.reg.GetFunc(typ) + // return nil, c.enum(x) + return nil, NeverRef - ref := c.call(fn, arg) - if ref != NeverRef { - return ref - } + case *ast.FuncExpr: + // Not sure how to juggle vars vs unbound. :/ + binder := c.reg.Var() + c.bind(c.source.GetString(x.Arg.Span()), binder) + defer c.unbind() + subs, ret := c.infer(x.Body) + return subs, c.reg.Func(c.reg.substitute(binder, subs), ret) - c.bail(x.Span(), fmt.Sprintf("cannot call %s with %s", c.reg.String(typ), c.reg.String(arg))) + case *ast.CallExpr: + // return nil, NeverRef + // // Special-case pick with a value. + // if pick, ok := x.Fn.(*ast.BinaryExpr); ok && pick.Op == token.PICK { + // return nil, c.pick(pick, x.Arg) + // } + + res := c.reg.Var() + s1, fn := c.infer(x.Fn) + s2, arg := c.infer(x.Arg) + s3 := c.reg.Unify(c.reg.substitute(fn, s2), c.reg.Func(arg, res)) + s4 := c.reg.Compose(s3, c.reg.Compose(s2, s1)) + return s4, c.reg.substitute(res, s3) case *ast.BinaryExpr: - left := c.infer(x.Left) - right := c.infer(x.Right) + _, left := c.infer(x.Left) + _, right := c.infer(x.Right) switch x.Op { - case token.PICK: - return c.pick(x, nil) - case token.CONCAT: - if left == TextRef { - if right == TextRef { - return TextRef - } else if right.IsUnbound() { - - } else { - return NeverRef - } - } + // case token.PICK: + // return c.pick(x, nil) + // case token.CONCAT: + // if left == TextRef { + // if right == TextRef { + // return TextRef + // } else if right.IsUnbound() { + + // } else { + // return NeverRef + // } + // } case token.ADD: if left == IntRef { return c.ensure(x.Right, right, IntRef) @@ -178,79 +165,41 @@ func (c *context) infer(expr ast.Expr) TypeRef { panic(fmt.Sprintf("can't infer node %T", expr)) } -func (c *context) ensure(x ast.Expr, got, want TypeRef) TypeRef { +func (c *context) ensure(x ast.Expr, got, want TypeRef) (Subst, TypeRef) { if got == want { - return got + return nil, got } - if got.IsUnbound() { - c.rebind(x, want) - return want - } - - if c.isAssignable(want, got) { - return got - } - - c.bail(x.Span(), fmt.Sprintf("expected %s, got %s", c.reg.String(want), c.reg.String(got))) - return NeverRef -} - -func (c *context) call(fn FuncRef, arg TypeRef) TypeRef { - if c.isAssignable(fn.Arg, arg) { - return fn.Result - } - - if fn.Arg.IsUnbound() { - return c.reg.Bind(fn.Result, fn.Arg, arg) - } - - if fn.Arg.IsFunction() && arg.IsFunction() { - afn := c.reg.GetFunc(fn.Arg) - bfn := c.reg.GetFunc(arg) - - // If completely unbound, replace with arg. - if afn.Arg.IsUnbound() && afn.Result.IsUnbound() { - res := c.reg.Bind(fn.Result, afn.Result, bfn.Result) - return c.reg.Bind(res, afn.Arg, bfn.Arg) - } - } - - return NeverRef -} - -func (c *context) isAssignable(a, b TypeRef) bool { - if a == b { - return true - } - - aTag, _ := a.extract() - switch aTag { - case listTag: - if b.IsList() && c.reg.GetList(b).IsUnbound() { - return true + // Really? Must make this API better. + defer func() { + if pnc := recover(); pnc != nil { + if msg, ok := pnc.(string); ok { + c.bail(x.Span(), msg) + } else { + panic(pnc) + } } - } + }() - return false + return c.reg.Unify(got, want), want } -func (c *context) where(x *ast.WhereExpr) TypeRef { +func (c *context) where(x *ast.WhereExpr) (Subst, TypeRef) { name := c.source.GetString(x.Id.Pos) - if x.Typ == nil { - // If there is no type annotation, we can infer it from the value. - c.bind(name, c.infer(x.Val)) - return c.infer(x.Expr) - } - tRef := c.typ(x.Typ) - vRef := c.infer(x.Val) - if tRef != vRef { - c.bail(x.Val.Span(), fmt.Sprintf("cannot assign %s to %s", c.reg.String(vRef), c.reg.String(tRef))) + s1, tyVal := c.infer(x.Val) + + // If there's an annotation, make sure it matches the inferred type. + if x.Typ != nil { + s2, _ := c.ensure(x.Typ, tyVal, c.typ(x.Typ)) + c.reg.apply(s2) } - c.bind(name, tRef) - return c.infer(x.Expr) + c.bind(name, tyVal) + defer c.unbind() + c.reg.apply(s1) // Apply anything learned about vars. + s2, tyExpr := c.infer(x.Expr) + return c.reg.Compose(s1, s2), tyExpr } func (c *context) typ(x ast.Expr) TypeRef { @@ -273,126 +222,115 @@ func (c *context) typ(x ast.Expr) TypeRef { return NeverRef } -func (c *context) list(x *ast.ListExpr) (res TypeRef) { +func (c *context) list(x *ast.ListExpr) (Subst, TypeRef) { + var sub Subst + res := NeverRef + for _, v := range x.Elements { - typ := c.infer(v) - if typ == res { - continue - } else if res == NeverRef { + s, typ := c.infer(v) + sub = c.reg.Compose(s, sub) + + if res == NeverRef { res = typ - } else if typ.IsUnbound() { - c.rebind(v, res) - } else { - c.bail(v.Span(), "list elements must all be of type "+c.reg.String(res)) - // Bad list. - return NeverRef + continue } + + s, res = c.ensure(v, res, typ) + sub = c.reg.Compose(s, sub) } - if res == NeverRef { - res = c.reg.Unbound() - } - return c.reg.List(res) -} -// Re-binds the type of expresion x, or fails. -func (c *context) rebind(x ast.Expr, ref TypeRef) { - name := c.source.GetString(x.Span()) - _, ok := x.(*ast.Ident) - if ok { - s := c.scope.Get(name) - if s != nil { - // Let's steal the now unused (?) generic. - s.val = ref - return - } + if res == NeverRef { + res = c.reg.Var() } - c.bail(x.Span(), fmt.Sprintf("can't rebind type of non-identifier %s", name)) + return sub, c.reg.List(res) } -func (c *context) record(x *ast.RecordExpr) TypeRef { +func (c *context) record(x *ast.RecordExpr) (Subst, TypeRef) { // If there is a rest/spread, our type is equal to that. - if x.Rest != nil { - rest := c.infer(x.Rest) - rec := c.reg.GetRecord(rest) - if rec == nil { - c.bail(x.Rest.Span(), fmt.Sprintf("cannot spread from non-record type %s", c.reg.String(rest))) - } - for k, v := range x.Entries { - expected, ok := rec[k] - if !ok { - c.bail(v.Span(), fmt.Sprintf("cannot set %s not in the base record", k)) - - } - actual := c.infer(v) - if actual != expected { - c.bail(v.Span(), fmt.Sprintf("type of %s must be %s, not %s", k, c.reg.String(expected), c.reg.String(actual))) - } - } - return rest - } - + // if x.Rest != nil { + // rest := c.infer(x.Rest) + // rec := c.reg.GetRecord(rest) + // if rec == nil { + // c.bail(x.Rest.Span(), fmt.Sprintf("cannot spread from non-record type %s", c.reg.String(rest))) + // } + // for k, v := range x.Entries { + // expected, ok := rec[k] + // if !ok { + // c.bail(v.Span(), fmt.Sprintf("cannot set %s not in the base record", k)) + + // } + // actual := c.infer(v) + // if actual != expected { + // c.bail(v.Span(), fmt.Sprintf("type of %s must be %s, not %s", k, c.reg.String(expected), c.reg.String(actual))) + // } + // } + // return s, rest + // } + + var s, s2 Subst ref := make(MapRef, len(x.Entries)) for k, v := range x.Entries { - ref[k] = c.infer(v) + s2, ref[k] = c.infer(v) + s = c.reg.Compose(s, s2) } - return c.reg.Record(ref) + return s, c.reg.Record(ref) } -func (c *context) enum(x ast.EnumExpr) TypeRef { - ref := make(MapRef, len(x)) - for _, v := range x { - name := c.source.GetString(v.Tag.Pos) - vRef := NeverRef - if v.Typ != nil { - vRef = c.infer(v.Typ) - } - ref[name] = vRef - } - return c.reg.Enum(ref) -} - -func (c *context) pick(x *ast.BinaryExpr, val ast.Expr) TypeRef { - // TODO: A binary expr for pick is annoying. - name := c.source.GetString(x.Left.Span()) - ref := c.scope.Lookup(name) - enum := c.reg.GetEnum(ref) - if enum == nil { - c.bail(x.Left.Span(), fmt.Sprintf("%s isn't an enum", name)) - } - - if id, ok := x.Right.(*ast.Ident); ok { - tag := c.source.GetString(id.Span()) - typ, ok := enum[tag] - if !ok { - c.bail(id.Span(), - fmt.Sprintf("#%s isn't a valid option for enum %s", - tag, c.reg.String(ref))) - } - - // We expect no value. - if typ == NeverRef { - // But there was one. - if val != nil { - c.bail(val.Span(), fmt.Sprintf("#%s doesn't take any value", tag)) - } - } else { - valRef := c.infer(val) - // TODO: check assignability instead - if typ != valRef { - // Wrong type. - c.bail(val.Span(), - fmt.Sprintf("cannot assign %s to #%s which needs %s", - c.reg.String(valRef), tag, c.reg.String(typ))) - return NeverRef - } - } - - return ref - } - - // TODO: better error handling? - return NeverRef -} +// func (c *context) enum(x ast.EnumExpr) TypeRef { +// ref := make(MapRef, len(x)) +// for _, v := range x { +// name := c.source.GetString(v.Tag.Pos) +// vRef := NeverRef +// if v.Typ != nil { +// vRef = c.infer(v.Typ) +// } +// ref[name] = vRef +// } +// return c.reg.Enum(ref) +// } + +// func (c *context) pick(x *ast.BinaryExpr, val ast.Expr) TypeRef { +// // TODO: A binary expr for pick is annoying. +// name := c.source.GetString(x.Left.Span()) +// ref := c.scope.Lookup(name) +// enum := c.reg.GetEnum(ref) +// if enum == nil { +// c.bail(x.Left.Span(), fmt.Sprintf("%s isn't an enum", name)) +// } + +// if id, ok := x.Right.(*ast.Ident); ok { +// tag := c.source.GetString(id.Span()) +// typ, ok := enum[tag] +// if !ok { +// c.bail(id.Span(), +// fmt.Sprintf("#%s isn't a valid option for enum %s", +// tag, c.reg.String(ref))) +// } + +// // We expect no value. +// if typ == NeverRef { +// // But there was one. +// if val != nil { +// c.bail(val.Span(), fmt.Sprintf("#%s doesn't take any value", tag)) +// } +// } else { +// valRef := c.infer(val) +// // TODO: check assignability instead +// if typ != valRef { +// // Wrong type. +// c.bail(val.Span(), +// fmt.Sprintf("cannot assign %s to #%s which needs %s", +// c.reg.String(valRef), tag, c.reg.String(typ))) +// return NeverRef +// } +// } + +// return ref +// } + +// // TODO: better error handling? +// return NeverRef +// } func literalTypeRef(tok token.Token) TypeRef { switch tok { diff --git a/types/infer_test.go b/types/infer_test.go index a214e7e..02b92d6 100644 --- a/types/infer_test.go +++ b/types/infer_test.go @@ -19,26 +19,38 @@ func TestInfer(t *testing.T) { // Primitives {`5`, `int`}, {`a ; a = 5`, `int`}, + {`1 + 2`, `int`}, // Lists - {`[]`, `list a`}, // empty list has an unbound type for its values + {`[]`, `list $0`}, // empty list has an unbound type for its values {`[1, 2]`, `list int`}, // Records {`{ a = 1 }`, `{ a : int }`}, - {`{ ..base, a = ~01 } ; base = { a = ~00 }`, `{ a : byte }`}, - // Enums - {`bool ; bool : #true #false`, `#false #true`}, - {`e ; e : #l int #r`, `#l int #r`}, - {`e::r ; e : #l int #r`, `#l int #r`}, - {`e::l 4 ; e : #l int #r`, `#l int #r`}, + // {`{ ..base, a = ~01 } ; base = { a = ~00 }`, `{ a : byte }`}, + // // Enums + // {`bool ; bool : #true #false`, `#false #true`}, + // {`e ; e : #l int #r`, `#l int #r`}, + // {`e::r ; e : #l int #r`, `#l int #r`}, + // {`e::l 4 ; e : #l int #r`, `#l int #r`}, // Functions - {`_ -> "hi"`, `a -> text`}, - {`_ -> _ -> "hi"`, `a -> b -> text`}, + {`a -> a`, `$0 -> $0`}, + {`_ -> "hi"`, `$0 -> text`}, + {`_ -> _ -> "hi"`, `$0 -> $1 -> text`}, {`(_ -> "hi") ()`, `text`}, - {`a -> b -> { a = a, b = b }`, `a -> b -> { a : a, b : b }`}, - {`(a -> b -> { a = a, b = b }) 1`, `a -> { a : int, b : a }`}, + {`a -> b -> { a = a, b = b }`, `$0 -> $1 -> { a : $0, b : $1 }`}, + {`(a -> b -> { a = a, b = b }) 1`, `$2 -> { a : int, b : $2 }`}, {`(a -> b -> { a = a, b = b }) 1 "yo" `, `{ a : int, b : text }`}, {`a ; a : int = 1`, `int`}, {`a -> a + 1`, `int -> int`}, + {`b -> (a ; a : int = b)`, `int -> int`}, + + {`f -> a -> [ a ]`, `$0 -> $1 -> list $1`}, + {`(f -> a -> [ a ]) "a"`, `$2 -> list $2`}, + {`(f -> a -> [ a ]) "a" 3`, `list int`}, + + {`f -> a -> ([ b, b ] ; b = (f a))`, `($1 -> $2) -> $1 -> list $2`}, + // If used the same, arguments must be the same. + {`a -> b -> [ a, b ]`, `$1 -> $1 -> list $1`}, + {`(a -> b -> [ a, b ]) 1`, `int -> list int`}, } for _, ex := range examples { @@ -56,19 +68,21 @@ func TestInfer(t *testing.T) { func TestInferFailure(t *testing.T) { examples := []struct{ source, message string }{ + // Unbound + {`b ; a = b -> b`, `unbound variable: b`}, // Lists - {`[1, 1.0]`, `list elements must all be of type int`}, - // Records - {`{ ..base, a = 1 } ; base = { a = ~00 }`, `type of a must be byte, not int`}, - {`{ ..1, a = 1 }`, `cannot spread from non-record type int`}, - // Enums - {`1::a`, `1 isn't an enum`}, - {`a::a ; a : #b`, `#a isn't a valid option for enum #b`}, - {`a::b 1 ; a : #b`, `#b doesn't take any value`}, - {`a::b 1 ; a : #b text`, `cannot assign int to #b which needs text`}, - {`1 + ~dd`, `expected int, got byte`}, - {`a ; a : int = 1.0`, `cannot assign float to int`}, - {`f ; f : int -> text = a -> 1`, `cannot assign a -> int to int -> text`}, + {`[1, 1.0]`, `cannot unify 'int' with 'float'`}, + // // Records + // {`{ ..base, a = 1 } ; base = { a = ~00 }`, `type of a must be byte, not int`}, + // {`{ ..1, a = 1 }`, `cannot spread from non-record type int`}, + // // Enums + // {`1::a`, `1 isn't an enum`}, + // {`a::a ; a : #b`, `#a isn't a valid option for enum #b`}, + // {`a::b 1 ; a : #b`, `#b doesn't take any value`}, + // {`a::b 1 ; a : #b text`, `cannot assign int to #b which needs text`}, + {`1 + ~dd`, `cannot unify 'byte' with 'int'`}, + {`a ; a : int = 1.0`, `cannot unify 'float' with 'int'`}, + {`f ; f : int -> text = a -> 1`, `cannot unify 'int' with 'text'`}, } for _, ex := range examples { @@ -86,29 +100,31 @@ func TestInferFailure(t *testing.T) { } func TestInferInScope(t *testing.T) { - reg := Registry{} - var scope *Scope[TypeRef] - - scope = scope.Bind("len", reg.Func(reg.List(reg.Unbound()), IntRef)) - examples := []struct{ source, typ string }{ - {`len`, `list a -> int`}, + {`len`, `list $0 -> int`}, {`len []`, `int`}, - {`f -> a -> [ a ]`, `a -> b -> list b`}, - {`(f -> a -> [ a ]) "a"`, `a -> list a`}, - {`(f -> a -> [ a ]) "a" 3`, `list int`}, - {`f -> a -> ([ b, b ] ; b = (f a))`, `(a -> b) -> a -> list b`}, - // If used the same, arguments must be the same. - {`a -> b -> [ a, b ]`, `a -> a -> list a`}, - {`(a -> b -> [ a, b ]) 1`, `int -> list int`}, - - {`(f -> a -> ([ b, b ] ; b = (f a))) len`, `list a -> list int`}, + {`(f -> a -> ([ b, b ] ; b = (f a))) len`, `list $4 -> list int`}, {`(f -> a -> ([ b, b ] ; b = (f a))) len []`, `list int`}, + {`f -> f (f 1)`, `(int -> int) -> int`}, // {`twice ; twice = f -> a -> f (f a)`, `(a -> a) -> a -> a`}, + + {`{ a = id 1, b = id "" }`, `{ a : int, b : text }`}, + // Custom functions don't work, since var/unbound are different. + // {`{ a = id2 1, b = id2 "" } ; id2 = a -> a`, `{ a : int, b : text }`}, } for _, ex := range examples { se := must(parser.ParseExpr(ex.source)) + + // New registry every test. + reg := Registry{} + var scope *Scope[TypeRef] + + scope = scope.Bind("len", reg.Func(reg.List(reg.Unbound()), IntRef)) + + a := reg.Unbound() + scope = scope.Bind("id", reg.Func(a, a)) + ref, err := InferInScope(®, scope, se) if err != nil { t.Error(err) diff --git a/types/subst.go b/types/subst.go new file mode 100644 index 0000000..61cdfb0 --- /dev/null +++ b/types/subst.go @@ -0,0 +1,42 @@ +package types + +import "strings" + +// A single substitution. +type Sub struct { + // Must be a variable. + replace TypeRef + with TypeRef +} + +// A set of substitutions. +type Subst []Sub + +func (s Subst) binds(target TypeRef) bool { + for _, s := range s { + if s.replace == target { + return true + } + } + return false +} + +func (s *Subst) bind(replace, with TypeRef) { + *s = append(*s, Sub{replace, with}) +} + +// For debugging. +func (s Subst) String(reg *Registry) string { + var b strings.Builder + + for i, s := range s { + if i != 0 { + b.WriteString(", ") + } + b.WriteString(VarString(s.replace)) + b.WriteString(": ") + b.WriteString(reg.String(s.with)) + } + + return b.String() +} diff --git a/types/type.go b/types/type.go index c079fee..5fba61f 100644 --- a/types/type.go +++ b/types/type.go @@ -3,6 +3,7 @@ package types import ( "maps" "slices" + "strconv" "strings" ) @@ -16,8 +17,19 @@ const ( enumTag recordTag unboundTag + varTag ) +var tagNames = [...]string{ + primitiveTag: "primitive", + listTag: "list", + funcTag: "func", + enumTag: "enum", + recordTag: "record", + unboundTag: "unbound", + varTag: "var", +} + // Efficiently encodes a type reference within a Registry. // // The zero value references the impossible "never" type. @@ -36,8 +48,17 @@ func (ref TypeRef) extract() (tag, int) { return tag, index } +func (ref TypeRef) tag() tag { + return tag(ref & 0x0f) +} + func (ref TypeRef) hasTag(t tag) bool { - return tag(ref&0x0f) == t + return ref.tag() == t +} + +// Returns true if both TypeRefs have the same tag. +func (ref TypeRef) SameTypeAs(other TypeRef) bool { + return ref.hasTag(other.tag()) } // IsList returns true if the TypeRef is a list. @@ -55,6 +76,11 @@ func (ref TypeRef) IsUnbound() bool { return ref.hasTag(unboundTag) } +// IsVar returns true if the TypeRef is an var type. +func (ref TypeRef) IsVar() bool { + return ref.hasTag(varTag) +} + const ( // Shortcut to the TypeRef for Never. NeverRef TypeRef = TypeRef(int(primitiveTag) | (iota << 4)) // Inlined makeTypeRef @@ -95,6 +121,12 @@ type Registry struct { // Enums and records are maps to TypeRefs. enums []MapRef records []MapRef + // Type variables that will point to another type, + // or NeverRef if not yet assigned. + // + // Schemes are types with unbound TypeRefs. When instantiating a type, + // all unbound types will be replaced with fresh vars instead. + vars []TypeRef } // Returns the number of types in the registry, for debugging. @@ -175,6 +207,60 @@ func (c *Registry) Unbound() (ref TypeRef) { return } +// Var returns a new variable TypeRef. +func (c *Registry) Var() (ref TypeRef) { + i := len(c.vars) + c.vars = append(c.vars, NeverRef) + return makeTypeRef(varTag, i) +} + +// GetVar returns the TypeRef for an record type. +func (c *Registry) GetVar(ref TypeRef) TypeRef { + tag, index := ref.extract() + if tag != varTag { + return NeverRef + } + mid := c.vars[index] + if mid.hasTag(varTag) { + // Try to resolve one more layer. + c.vars[index] = c.GetVar(mid) + } + return c.vars[index] +} + +// VarString returns the string representation of an unresolved variable. +func VarString(ref TypeRef) string { + tag, index := ref.extract() + if tag != varTag { + panic("VarString: got non-var tag " + tagNames[tag]) + } + return "$" + strconv.FormatInt(int64(index), 10) +} + +type MapTypeRef func(ref TypeRef) + +func (c *Registry) traverse(target TypeRef, mtr MapTypeRef) { + tag, index := target.extract() + switch tag { + case listTag: + c.traverse(c.lists[index], mtr) + case funcTag: + fn := c.funcs[index] + c.traverse(fn.Arg, mtr) + c.traverse(fn.Result, mtr) + case enumTag: + for _, v := range c.enums[index] { + c.traverse(v, mtr) + } + case recordTag: + for _, v := range c.records[index] { + c.traverse(v, mtr) + } + } + + mtr(target) +} + // Bind replaces all occurrences of `unbound` with `resolved` in the `target` type. func (c *Registry) Bind(target, unbound, resolved TypeRef) TypeRef { // Base case: the target is the unbound we want to replace. @@ -212,6 +298,135 @@ func (c *Registry) Bind(target, unbound, resolved TypeRef) TypeRef { return target } +func (c *Registry) Instantiate(target TypeRef) TypeRef { + var subst Subst + c.insertUnbound(target, &subst) + return c.substitute(target, subst) +} + +func (c *Registry) insertUnbound(target TypeRef, subst *Subst) { + tag, index := target.extract() + switch tag { + case unboundTag: + if !subst.binds(target) { + subst.bind(target, c.Var()) + } + case listTag: + c.insertUnbound(c.lists[index], subst) + case funcTag: + fn := c.funcs[index] + c.insertUnbound(fn.Arg, subst) + c.insertUnbound(fn.Result, subst) + // TODO: Other types + } +} + +func (c *Registry) Unify(a, b TypeRef) Subst { + if a == b { + return nil + } + if a.IsUnbound() || (a.IsVar() && c.GetVar(a) == NeverRef) { + return c.BindVar(a, b) + } + if b.IsUnbound() || (b.IsVar() && c.GetVar(b) == NeverRef) { + return c.BindVar(b, a) + } + + if a.tag() == b.tag() { + if a.IsFunction() { + aFn := c.GetFunc(a) + bFn := c.GetFunc(b) + s1 := c.Unify(aFn.Arg, bFn.Arg) + s2 := c.Unify(c.substitute(aFn.Result, s1), c.substitute(bFn.Result, s1)) + return c.Compose(s1, s2) + } + if a.IsList() { + aEl := c.GetList(a) + bEl := c.GetList(b) + return c.Unify(aEl, bEl) + } + } + + panic("cannot unify '" + c.String(a) + "' with '" + c.String(b) + "'") +} + +func (c *Registry) BindVar(a, b TypeRef) Subst { + if a == b { + return nil + } + c.traverse(b, func(ref TypeRef) { + if a == ref { + panic("occurs check failed") + } + }) + return Subst{{replace: a, with: b}} +} + +func (c *Registry) Compose(a, b Subst) Subst { + res := slices.Clone(b) + for _, s := range res { + // fmt.Fprintf(os.Stderr, "replace %s: %s\n", c.String(s.replace), c.String(s.with)) + s.with = c.substitute(s.with, a) + } + for _, s := range a { + if !res.binds(s.replace) { + res.bind(s.replace, s.with) + } + } + + return res +} + +func (c *Registry) apply(subst Subst) { + for _, s := range subst { + c.substitute(s.replace, subst) + } +} + +func (c *Registry) substitute(target TypeRef, subst Subst) TypeRef { + tag, index := target.extract() + switch tag { + case unboundTag: + for _, s := range subst { + if s.replace == target { + return s.with + } + } + case varTag: + for _, s := range subst { + if s.replace == target { + c.vars[index] = s.with + return s.with + } + } + case listTag: + return c.List( + c.substitute(c.lists[index], subst), + ) + case funcTag: + fn := c.funcs[index] + return c.Func( + c.substitute(fn.Arg, subst), + c.substitute(fn.Result, subst), + ) + case enumTag: + ref := make(MapRef, len(c.enums[index])) + for k, v := range c.enums[index] { + ref[k] = c.substitute(v, subst) + } + return c.Enum(ref) + case recordTag: + ref := make(MapRef, len(c.records[index])) + for k, v := range c.records[index] { + ref[k] = c.substitute(v, subst) + } + return c.Record(ref) + } + + // Else, the target remains unchanged. + return target +} + func findOrAdd[T comparable](ls *[]T, tag tag, el T) TypeRef { list := *ls for i, typ := range list { @@ -252,6 +467,7 @@ func (b *stringer) unbound(index int) { b.unbounds = append(b.unbounds, index) } b.WriteByte(unboundNames[i]) + // b.WriteByte(unboundNames[index]) } func (b *stringer) string(ref TypeRef, nesting int) { @@ -284,7 +500,17 @@ func (b *stringer) string(ref TypeRef, nesting int) { case recordTag: b.record(index) case unboundTag: + // b.WriteByte('\'') + // b.WriteString(strconv.FormatInt(int64(index), 10)) b.unbound(index) + case varTag: + ref := b.reg.GetVar(ref) + if ref == NeverRef { + b.WriteByte('$') + b.WriteString(strconv.FormatInt(int64(index), 10)) + } else { + b.string(ref, nesting) + } default: // The invalid type. panic("bad type-ref") diff --git a/types/type_test.go b/types/type_test.go index 12ef6fc..7fd3b6e 100644 --- a/types/type_test.go +++ b/types/type_test.go @@ -146,6 +146,87 @@ func TestBind(t *testing.T) { } +func TestSubstitute(t *testing.T) { + reg := Registry{} + + x := reg.Var() + Eq(t, reg.substitute(x, Subst{{x, IntRef}}), IntRef) + + Eq(t, reg.substitute(reg.Func(x, x), Subst{{x, IntRef}}), reg.Func(IntRef, IntRef)) + + y := reg.Var() + Eq(t, reg.substitute(reg.Func(x, y), Subst{{x, IntRef}}), reg.Func(IntRef, y)) + + // A Scheme is represented by an unbound variable. + // These are left untouched. (forall x. x) untouched. + a := reg.Unbound() + Eq(t, reg.substitute(a, Subst{{x, IntRef}}), a) + + Eq(t, reg.substitute(reg.Func(x, y), Subst{{a, IntRef}, {y, TextRef}}), reg.Func(x, TextRef)) + + // Other... + b := reg.Unbound() + f := reg.Func(a, b) + Eq(t, reg.String(f), "a -> b") + + l := reg.List(a) + Eq(t, reg.String(l), "list a") + + g := reg.Instantiate(f) + h := reg.Instantiate(f) + Eq(t, reg.String(f), "a -> b") + Eq(t, reg.String(g), "$2 -> $3") + Eq(t, reg.String(h), "$4 -> $5") + + Eq(t, reg.String(reg.Instantiate(l)), "list $6") + + gFn := reg.GetFunc(g) + + subst := Subst{ + {replace: gFn.Arg, with: IntRef}, + } + + Eq(t, reg.substitute(g, subst), reg.Func(IntRef, gFn.Result)) + Eq(t, reg.substitute(h, subst), h) +} + +func TestCompose(t *testing.T) { + reg := Registry{} + + t1 := reg.Var() + t2 := reg.Var() + + s1 := Subst{{t2, IntRef}} + Eq(t, s1.String(®), "$1: int") + + s2 := Subst{{t1, reg.Func(IntRef, t2)}} + Eq(t, s2.String(®), "$0: int -> $1") + + Eq(t, reg.Compose(nil, nil).String(®), "") + Eq(t, reg.Compose(s1, nil).String(®), "$1: int") + Eq(t, reg.Compose(nil, s1).String(®), "$1: int") + + Eq(t, reg.Compose(s1, s2).String(®), "$0: int -> int, $1: int") +} + +func TestUnify(t *testing.T) { + reg := Registry{} + + a := reg.Var() + Eq(t, reg.Unify(a, IntRef).String(®), "$0: int") + + a = reg.Var() + Eq(t, reg.Unify(reg.Func(a, IntRef), reg.Func(TextRef, IntRef)).String(®), "$1: text") + + a = reg.Var() + b := reg.Var() + Eq(t, reg.Unify(reg.Func(a, b), reg.Func(b, a)).String(®), "$2: $3") + + a = reg.Var() + b = reg.Var() + Eq(t, reg.Unify(reg.Func(a, IntRef), reg.Func(HoleRef, b)).String(®), "$5: int, $4: ()") +} + func Neq[T comparable](t *testing.T, a, b T) { if a == b { t.Errorf("Expected %v NOT to be %v", a, b) From 5892a6b42724962994f8b6a210133d01efc26dff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Oskar=20Segersv=C3=A4rd?= Date: Sun, 24 Aug 2025 12:54:39 +0200 Subject: [PATCH 2/3] Pivot to algorithm J from https://bernsteinbear.com/blog/type-inference/ --- scanner/scanner.go | 6 +- scanner/scanner_test.go | 1 + types/infer.go | 317 +++++++++++++++++++++------------------- types/infer_test.go | 57 +++++--- types/subst.go | 9 ++ types/type.go | 202 +++++++++++++++---------- types/type_test.go | 135 +++++------------ 7 files changed, 381 insertions(+), 346 deletions(-) diff --git a/scanner/scanner.go b/scanner/scanner.go index 6d06216..eced12e 100644 --- a/scanner/scanner.go +++ b/scanner/scanner.go @@ -158,11 +158,7 @@ func (s *Scanner) bytes() (tok token.Token, span token.Span) { s.next() } - if s.offset-offs < 2 { - s.error(s.offset, "too short base64 string") - tok = token.BAD - return - } + // The two chars `~~` encodes an empty byte array. for (s.offset-offs)%4 > 0 { if s.ch != '=' { diff --git a/scanner/scanner_test.go b/scanner/scanner_test.go index 1229fb3..5215fed 100644 --- a/scanner/scanner_test.go +++ b/scanner/scanner_test.go @@ -41,6 +41,7 @@ var elements = []elt{ {token.TEXT, `"world"`, literal}, {token.BYTE, "~ca", literal}, {token.BYTES, "~~aGVsbG8gd29ybGQ=", literal}, + {token.BYTES, "~~", literal}, // Operators and delimiters {token.ASSIGN, "=", operator}, diff --git a/types/infer.go b/types/infer.go index 5c4f7a2..e35b798 100644 --- a/types/infer.go +++ b/types/infer.go @@ -56,13 +56,15 @@ func (c *context) unbind() { c.scope = c.scope.parent } -func Infer(se ast.SourceExpr) (string, error) { - var reg Registry - var scope TypeScope - +func DefaultScope() (reg Registry, scope TypeScope) { for _, p := range primitives { scope = scope.Bind(reg.String(p), p) } + return +} + +func Infer(se ast.SourceExpr) (string, error) { + reg, scope := DefaultScope() ref, err := InferInScope(®, scope, se) if err != nil { @@ -88,21 +90,21 @@ func InferInScope(reg *Registry, scope TypeScope, se ast.SourceExpr) (ref TypeRe } }() - _, ref = context.infer(se.Expr) + ref = context.infer(se.Expr) return ref, err } -func (c *context) infer(expr ast.Expr) (Subst, TypeRef) { +func (c *context) infer(expr ast.Expr) TypeRef { switch x := expr.(type) { case *ast.Literal: - return nil, literalTypeRef(x.Kind) + return literalTypeRef(x.Kind) case *ast.Ident: name := c.source.GetString(x.Pos) ref := c.scope.Lookup(name) if ref == NeverRef { c.bail(x.Pos, "unbound variable: "+name) } - return nil, c.reg.Instantiate(ref) + return c.reg.Instantiate(ref) case *ast.WhereExpr: return c.where(x) case *ast.ListExpr: @@ -110,54 +112,62 @@ func (c *context) infer(expr ast.Expr) (Subst, TypeRef) { case *ast.RecordExpr: return c.record(x) case ast.EnumExpr: - // return nil, c.enum(x) - return nil, NeverRef + return c.enum(x) case *ast.FuncExpr: // Not sure how to juggle vars vs unbound. :/ binder := c.reg.Var() c.bind(c.source.GetString(x.Arg.Span()), binder) defer c.unbind() - subs, ret := c.infer(x.Body) - return subs, c.reg.Func(c.reg.substitute(binder, subs), ret) + ret := c.infer(x.Body) + return c.reg.Func(binder, ret) case *ast.CallExpr: - // return nil, NeverRef - // // Special-case pick with a value. - // if pick, ok := x.Fn.(*ast.BinaryExpr); ok && pick.Op == token.PICK { - // return nil, c.pick(pick, x.Arg) - // } + // Special-case pick with a value. + if pick, ok := x.Fn.(*ast.BinaryExpr); ok && pick.Op == token.PICK { + return c.pick(pick, x.Arg) + } res := c.reg.Var() - s1, fn := c.infer(x.Fn) - s2, arg := c.infer(x.Arg) - s3 := c.reg.Unify(c.reg.substitute(fn, s2), c.reg.Func(arg, res)) - s4 := c.reg.Compose(s3, c.reg.Compose(s2, s1)) - return s4, c.reg.substitute(res, s3) + fn := c.infer(x.Fn) + arg := c.infer(x.Arg) + c.ensure(x, fn, c.reg.Func(arg, res)) + return res case *ast.BinaryExpr: - _, left := c.infer(x.Left) - _, right := c.infer(x.Right) + if x.Op == token.PICK { + return c.pick(x, nil) + } + + left := c.infer(x.Left) + right := c.infer(x.Right) switch x.Op { - // case token.PICK: - // return c.pick(x, nil) - // case token.CONCAT: - // if left == TextRef { - // if right == TextRef { - // return TextRef - // } else if right.IsUnbound() { - - // } else { - // return NeverRef - // } - // } - case token.ADD: - if left == IntRef { - return c.ensure(x.Right, right, IntRef) + case token.PREPEND: + return c.pend(x.Left, x.Right, left, right) + case token.APPEND: + return c.pend(x.Right, x.Left, right, left) + case token.CONCAT: + if left == TextRef || right == TextRef { + c.ensure(x, left, right) + return TextRef } - if right == IntRef { - return c.ensure(x.Left, left, IntRef) + if left == BytesRef || right == BytesRef { + c.ensure(x, left, right) + return BytesRef } + // Local var to ensure left and right are lists. + a := c.reg.List(c.reg.Var()) + c.ensure(x, left, right) + c.ensure(x, left, a) + return a + case token.ADD, token.SUB, token.MUL: + if left == FloatRef || right == FloatRef { + c.ensure(x, left, right) + return FloatRef + } + // Assume int, like ML does. + c.ensure(x.Left, left, IntRef) + return c.ensure(x.Right, right, IntRef) } panic(fmt.Sprintf("can't infer binary expression %s", x.Op.String())) } @@ -165,41 +175,38 @@ func (c *context) infer(expr ast.Expr) (Subst, TypeRef) { panic(fmt.Sprintf("can't infer node %T", expr)) } -func (c *context) ensure(x ast.Expr, got, want TypeRef) (Subst, TypeRef) { - if got == want { - return nil, got - } - - // Really? Must make this API better. - defer func() { - if pnc := recover(); pnc != nil { - if msg, ok := pnc.(string); ok { - c.bail(x.Span(), msg) - } else { - panic(pnc) +func (c *context) ensure(x ast.Expr, got, want TypeRef) TypeRef { + if got != want { + // Really? Must make this API better. + defer func() { + if pnc := recover(); pnc != nil { + if msg, ok := pnc.(string); ok { + c.bail(x.Span(), msg) + } else { + panic(pnc) + } } - } - }() + }() - return c.reg.Unify(got, want), want + c.reg.unify(got, want) + } + return want } -func (c *context) where(x *ast.WhereExpr) (Subst, TypeRef) { +func (c *context) where(x *ast.WhereExpr) TypeRef { name := c.source.GetString(x.Id.Pos) - s1, tyVal := c.infer(x.Val) + tyVal := c.infer(x.Val) // If there's an annotation, make sure it matches the inferred type. if x.Typ != nil { - s2, _ := c.ensure(x.Typ, tyVal, c.typ(x.Typ)) - c.reg.apply(s2) + c.ensure(x.Typ, tyVal, c.typ(x.Typ)) } - c.bind(name, tyVal) + c.bind(name, c.reg.generalize(tyVal)) defer c.unbind() - c.reg.apply(s1) // Apply anything learned about vars. - s2, tyExpr := c.infer(x.Expr) - return c.reg.Compose(s1, s2), tyExpr + tyExpr := c.infer(x.Expr) + return tyExpr } func (c *context) typ(x ast.Expr) TypeRef { @@ -222,115 +229,110 @@ func (c *context) typ(x ast.Expr) TypeRef { return NeverRef } -func (c *context) list(x *ast.ListExpr) (Subst, TypeRef) { - var sub Subst +func (c *context) list(x *ast.ListExpr) TypeRef { res := NeverRef for _, v := range x.Elements { - s, typ := c.infer(v) - sub = c.reg.Compose(s, sub) + typ := c.infer(v) if res == NeverRef { res = typ continue } - s, res = c.ensure(v, res, typ) - sub = c.reg.Compose(s, sub) + c.ensure(v, res, typ) } if res == NeverRef { res = c.reg.Var() } - return sub, c.reg.List(res) + return c.reg.List(res) } -func (c *context) record(x *ast.RecordExpr) (Subst, TypeRef) { +func (c *context) record(x *ast.RecordExpr) TypeRef { // If there is a rest/spread, our type is equal to that. - // if x.Rest != nil { - // rest := c.infer(x.Rest) - // rec := c.reg.GetRecord(rest) - // if rec == nil { - // c.bail(x.Rest.Span(), fmt.Sprintf("cannot spread from non-record type %s", c.reg.String(rest))) - // } - // for k, v := range x.Entries { - // expected, ok := rec[k] - // if !ok { - // c.bail(v.Span(), fmt.Sprintf("cannot set %s not in the base record", k)) - - // } - // actual := c.infer(v) - // if actual != expected { - // c.bail(v.Span(), fmt.Sprintf("type of %s must be %s, not %s", k, c.reg.String(expected), c.reg.String(actual))) - // } - // } - // return s, rest - // } - - var s, s2 Subst + if x.Rest != nil { + rest := c.infer(x.Rest) + rec := c.reg.GetRecord(rest) + if rec == nil { + c.bail(x.Rest.Span(), fmt.Sprintf("cannot spread from non-record type %s", c.reg.String(rest))) + } + for k, v := range x.Entries { + expected, ok := rec[k] + if !ok { + c.bail(v.Span(), fmt.Sprintf("cannot set %s not in the base record", k)) + + } + actual := c.infer(v) + if actual != expected { + c.bail(v.Span(), fmt.Sprintf("type of %s must be %s, not %s", k, c.reg.String(expected), c.reg.String(actual))) + } + } + return rest + } + ref := make(MapRef, len(x.Entries)) for k, v := range x.Entries { - s2, ref[k] = c.infer(v) - s = c.reg.Compose(s, s2) + ref[k] = c.infer(v) + } + return c.reg.Record(ref) +} + +func (c *context) enum(x ast.EnumExpr) TypeRef { + ref := make(MapRef, len(x)) + for _, v := range x { + name := c.source.GetString(v.Tag.Pos) + vRef := NeverRef + if v.Typ != nil { + vRef = c.infer(v.Typ) + } + ref[name] = vRef } - return s, c.reg.Record(ref) + return c.reg.Enum(ref) } -// func (c *context) enum(x ast.EnumExpr) TypeRef { -// ref := make(MapRef, len(x)) -// for _, v := range x { -// name := c.source.GetString(v.Tag.Pos) -// vRef := NeverRef -// if v.Typ != nil { -// vRef = c.infer(v.Typ) -// } -// ref[name] = vRef -// } -// return c.reg.Enum(ref) -// } - -// func (c *context) pick(x *ast.BinaryExpr, val ast.Expr) TypeRef { -// // TODO: A binary expr for pick is annoying. -// name := c.source.GetString(x.Left.Span()) -// ref := c.scope.Lookup(name) -// enum := c.reg.GetEnum(ref) -// if enum == nil { -// c.bail(x.Left.Span(), fmt.Sprintf("%s isn't an enum", name)) -// } - -// if id, ok := x.Right.(*ast.Ident); ok { -// tag := c.source.GetString(id.Span()) -// typ, ok := enum[tag] -// if !ok { -// c.bail(id.Span(), -// fmt.Sprintf("#%s isn't a valid option for enum %s", -// tag, c.reg.String(ref))) -// } - -// // We expect no value. -// if typ == NeverRef { -// // But there was one. -// if val != nil { -// c.bail(val.Span(), fmt.Sprintf("#%s doesn't take any value", tag)) -// } -// } else { -// valRef := c.infer(val) -// // TODO: check assignability instead -// if typ != valRef { -// // Wrong type. -// c.bail(val.Span(), -// fmt.Sprintf("cannot assign %s to #%s which needs %s", -// c.reg.String(valRef), tag, c.reg.String(typ))) -// return NeverRef -// } -// } - -// return ref -// } - -// // TODO: better error handling? -// return NeverRef -// } +func (c *context) pick(x *ast.BinaryExpr, val ast.Expr) TypeRef { + // TODO: A binary expr for pick is annoying. + name := c.source.GetString(x.Left.Span()) + ref := c.scope.Lookup(name) + enum := c.reg.GetEnum(ref) + if enum == nil { + c.bail(x.Left.Span(), fmt.Sprintf("%s isn't an enum", name)) + } + + if id, ok := x.Right.(*ast.Ident); ok { + tag := c.source.GetString(id.Span()) + typ, ok := enum[tag] + if !ok { + c.bail(id.Span(), + fmt.Sprintf("#%s isn't a valid option for enum %s", + tag, c.reg.String(ref))) + } + + // We expect no value. + if typ == NeverRef { + // But there was one. + if val != nil { + c.bail(val.Span(), fmt.Sprintf("#%s doesn't take any value", tag)) + } + } else { + valRef := c.infer(val) + // TODO: check assignability instead + if typ != valRef { + // Wrong type. + c.bail(val.Span(), + fmt.Sprintf("cannot assign %s to #%s which needs %s", + c.reg.String(valRef), tag, c.reg.String(typ))) + return NeverRef + } + } + + return ref + } + + // TODO: better error handling? + return NeverRef +} func literalTypeRef(tok token.Token) TypeRef { switch tok { @@ -350,3 +352,16 @@ func literalTypeRef(tok token.Token) TypeRef { return NeverRef } + +// Either pre-pend or ap-pend. +func (c *context) pend(singleX, listX ast.Expr, single, list TypeRef) TypeRef { + // Special-case bytes. + if single == ByteRef || list == BytesRef { + c.ensure(singleX, single, ByteRef) + c.ensure(listX, list, BytesRef) + return BytesRef + } + + c.ensure(singleX, c.reg.List(single), list) + return list +} diff --git a/types/infer_test.go b/types/infer_test.go index 02b92d6..fe29802 100644 --- a/types/infer_test.go +++ b/types/infer_test.go @@ -25,17 +25,35 @@ func TestInfer(t *testing.T) { {`[1, 2]`, `list int`}, // Records {`{ a = 1 }`, `{ a : int }`}, - // {`{ ..base, a = ~01 } ; base = { a = ~00 }`, `{ a : byte }`}, + {`{ ..base, a = ~01 } ; base = { a = ~00 }`, `{ a : byte }`}, // // Enums - // {`bool ; bool : #true #false`, `#false #true`}, - // {`e ; e : #l int #r`, `#l int #r`}, - // {`e::r ; e : #l int #r`, `#l int #r`}, - // {`e::l 4 ; e : #l int #r`, `#l int #r`}, + {`bool ; bool : #true #false`, `#false #true`}, + {`e ; e : #l int #r`, `#l int #r`}, + {`e::r ; e : #l int #r`, `#l int #r`}, + {`e::l 4 ; e : #l int #r`, `#l int #r`}, // Functions {`a -> a`, `$0 -> $0`}, {`_ -> "hi"`, `$0 -> text`}, {`_ -> _ -> "hi"`, `$0 -> $1 -> text`}, {`(_ -> "hi") ()`, `text`}, + + // Prepend and append + {`a -> a >+ []`, `$1 -> list $1`}, + {`a -> a +< int`, `list int -> list int`}, + {`a -> a >+ ~~1111`, `byte -> bytes`}, + {`a -> a +< ~ff`, `bytes -> bytes`}, + + // Concat + {`"hi " ++ "you!"`, `text`}, + {`[] ++ [1]`, `list int`}, + {`~~1111 ++ ~~`, `bytes`}, + {`a -> b -> a ++ b`, `list $2 -> list $2 -> list $2`}, + + // Math + {`a -> 1.0 + a`, `float -> float`}, + {`4 - 3`, `int`}, + {`a -> b -> a * b`, `int -> int -> int`}, // Default to int. + {`a -> b -> { a = a, b = b }`, `$0 -> $1 -> { a : $0, b : $1 }`}, {`(a -> b -> { a = a, b = b }) 1`, `$2 -> { a : int, b : $2 }`}, {`(a -> b -> { a = a, b = b }) 1 "yo" `, `{ a : int, b : text }`}, @@ -43,6 +61,10 @@ func TestInfer(t *testing.T) { {`a -> a + 1`, `int -> int`}, {`b -> (a ; a : int = b)`, `int -> int`}, + {`f -> f (f 1)`, `(int -> int) -> int`}, + {`a -> f -> f (f a)`, `$2 -> ($2 -> $2) -> $2`}, + {`f -> a -> f (f a)`, `($2 -> $2) -> $2 -> $2`}, + {`f -> a -> [ a ]`, `$0 -> $1 -> list $1`}, {`(f -> a -> [ a ]) "a"`, `$2 -> list $2`}, {`(f -> a -> [ a ]) "a" 3`, `list int`}, @@ -72,17 +94,21 @@ func TestInferFailure(t *testing.T) { {`b ; a = b -> b`, `unbound variable: b`}, // Lists {`[1, 1.0]`, `cannot unify 'int' with 'float'`}, - // // Records - // {`{ ..base, a = 1 } ; base = { a = ~00 }`, `type of a must be byte, not int`}, - // {`{ ..1, a = 1 }`, `cannot spread from non-record type int`}, - // // Enums - // {`1::a`, `1 isn't an enum`}, - // {`a::a ; a : #b`, `#a isn't a valid option for enum #b`}, - // {`a::b 1 ; a : #b`, `#b doesn't take any value`}, - // {`a::b 1 ; a : #b text`, `cannot assign int to #b which needs text`}, + {`[4] ++ ["text"]`, `cannot unify 'int' with 'text'`}, + {`4 ++ 6`, `cannot unify 'int' with 'list $0'`}, + // Records + {`{ ..base, a = 1 } ; base = { a = ~00 }`, `type of a must be byte, not int`}, + {`{ ..1, a = 1 }`, `cannot spread from non-record type int`}, + // Enums + {`1::a`, `1 isn't an enum`}, + {`a::a ; a : #b`, `#a isn't a valid option for enum #b`}, + {`a::b 1 ; a : #b`, `#b doesn't take any value`}, + {`a::b 1 ; a : #b text`, `cannot assign int to #b which needs text`}, {`1 + ~dd`, `cannot unify 'byte' with 'int'`}, {`a ; a : int = 1.0`, `cannot unify 'float' with 'int'`}, {`f ; f : int -> text = a -> 1`, `cannot unify 'int' with 'text'`}, + // Math + {`1 + 1.0`, `cannot unify 'int' with 'float'`}, } for _, ex := range examples { @@ -105,12 +131,9 @@ func TestInferInScope(t *testing.T) { {`len []`, `int`}, {`(f -> a -> ([ b, b ] ; b = (f a))) len`, `list $4 -> list int`}, {`(f -> a -> ([ b, b ] ; b = (f a))) len []`, `list int`}, - {`f -> f (f 1)`, `(int -> int) -> int`}, - // {`twice ; twice = f -> a -> f (f a)`, `(a -> a) -> a -> a`}, {`{ a = id 1, b = id "" }`, `{ a : int, b : text }`}, - // Custom functions don't work, since var/unbound are different. - // {`{ a = id2 1, b = id2 "" } ; id2 = a -> a`, `{ a : int, b : text }`}, + {`{ a = id2 1, b = id2 "" } ; id2 = a -> a`, `{ a : int, b : text }`}, } for _, ex := range examples { diff --git a/types/subst.go b/types/subst.go index 61cdfb0..d181bfe 100644 --- a/types/subst.go +++ b/types/subst.go @@ -21,6 +21,15 @@ func (s Subst) binds(target TypeRef) bool { return false } +func (s Subst) bound(target TypeRef) TypeRef { + for _, s := range s { + if s.replace == target { + return s.with + } + } + return NeverRef +} + func (s *Subst) bind(replace, with TypeRef) { *s = append(*s, Sub{replace, with}) } diff --git a/types/type.go b/types/type.go index 5fba61f..74e8f54 100644 --- a/types/type.go +++ b/types/type.go @@ -52,6 +52,10 @@ func (ref TypeRef) tag() tag { return tag(ref & 0x0f) } +func (ref TypeRef) index() int { + return int(ref >> 4) +} + func (ref TypeRef) hasTag(t tag) bool { return ref.tag() == t } @@ -214,20 +218,34 @@ func (c *Registry) Var() (ref TypeRef) { return makeTypeRef(varTag, i) } +// Resolve follows variables to their last bound var. +func (c *Registry) Resolve(ref TypeRef) TypeRef { + // Ignore non-vars. + if !ref.IsVar() { + return ref + } + other := c.Resolve(c.vars[ref.index()]) + if other == NeverRef { + return ref + } + c.vars[ref.index()] = other + return other +} + // GetVar returns the TypeRef for an record type. func (c *Registry) GetVar(ref TypeRef) TypeRef { + c.Resolve(ref) tag, index := ref.extract() if tag != varTag { - return NeverRef - } - mid := c.vars[index] - if mid.hasTag(varTag) { - // Try to resolve one more layer. - c.vars[index] = c.GetVar(mid) + return ref } return c.vars[index] } +func (c *Registry) IsFree(ref TypeRef) bool { + return c.Resolve(ref).IsVar() +} + // VarString returns the string representation of an unresolved variable. func VarString(ref TypeRef) string { tag, index := ref.extract() @@ -261,35 +279,26 @@ func (c *Registry) traverse(target TypeRef, mtr MapTypeRef) { mtr(target) } -// Bind replaces all occurrences of `unbound` with `resolved` in the `target` type. -func (c *Registry) Bind(target, unbound, resolved TypeRef) TypeRef { - // Base case: the target is the unbound we want to replace. - if target == unbound { - return resolved - } +type Replacer func(ref TypeRef) TypeRef +func (c *Registry) replace(target TypeRef, f Replacer) TypeRef { tag, index := target.extract() switch tag { case listTag: - return c.List( - c.Bind(c.lists[index], unbound, resolved), - ) + return c.List(f(c.lists[index])) case funcTag: fn := c.funcs[index] - return c.Func( - c.Bind(fn.Arg, unbound, resolved), - c.Bind(fn.Result, unbound, resolved), - ) + return c.Func(f(fn.Arg), f(fn.Result)) case enumTag: ref := make(MapRef, len(c.enums[index])) for k, v := range c.enums[index] { - ref[k] = c.Bind(v, unbound, resolved) + ref[k] = f(v) } return c.Enum(ref) case recordTag: ref := make(MapRef, len(c.records[index])) for k, v := range c.records[index] { - ref[k] = c.Bind(v, unbound, resolved) + ref[k] = f(v) } return c.Record(ref) } @@ -298,6 +307,33 @@ func (c *Registry) Bind(target, unbound, resolved TypeRef) TypeRef { return target } +// bind binds a free variable to a type. +func (reg *Registry) bind(a, b TypeRef) { + // Get to the bottom of `a`. + a = reg.Resolve(a) + + if !a.IsVar() { + panic("cannot bind non-free var " + reg.String(a)) + } + reg.vars[a.index()] = b +} + +// The opposite of instantiate. +func (c *Registry) generalize(target TypeRef) TypeRef { + var subst Subst + return c.replace(target, func(other TypeRef) TypeRef { + if other.IsVar() { + b := subst.bound(other) + if b == NeverRef { + b = c.Unbound() + subst.bind(other, b) + } + return b + } + return other + }) +} + func (c *Registry) Instantiate(target TypeRef) TypeRef { var subst Subst c.insertUnbound(target, &subst) @@ -321,65 +357,50 @@ func (c *Registry) insertUnbound(target TypeRef, subst *Subst) { } } -func (c *Registry) Unify(a, b TypeRef) Subst { - if a == b { - return nil - } - if a.IsUnbound() || (a.IsVar() && c.GetVar(a) == NeverRef) { - return c.BindVar(a, b) - } - if b.IsUnbound() || (b.IsVar() && c.GetVar(b) == NeverRef) { - return c.BindVar(b, a) - } +func (c *Registry) unify(a, b TypeRef) { + a = c.Resolve(a) + b = c.Resolve(b) - if a.tag() == b.tag() { - if a.IsFunction() { - aFn := c.GetFunc(a) - bFn := c.GetFunc(b) - s1 := c.Unify(aFn.Arg, bFn.Arg) - s2 := c.Unify(c.substitute(aFn.Result, s1), c.substitute(bFn.Result, s1)) - return c.Compose(s1, s2) - } - if a.IsList() { - aEl := c.GetList(a) - bEl := c.GetList(b) - return c.Unify(aEl, bEl) - } + tag, index := a.extract() + if tag == unboundTag { + panic("unexpected unbound var during unification") } - panic("cannot unify '" + c.String(a) + "' with '" + c.String(b) + "'") -} - -func (c *Registry) BindVar(a, b TypeRef) Subst { - if a == b { - return nil + if tag == varTag { + c.traverse(b, func(ref TypeRef) { + if a == ref { + panic("occurs check failed") + } + }) + c.vars[index] = b + return } - c.traverse(b, func(ref TypeRef) { - if a == ref { - panic("occurs check failed") - } - }) - return Subst{{replace: a, with: b}} -} -func (c *Registry) Compose(a, b Subst) Subst { - res := slices.Clone(b) - for _, s := range res { - // fmt.Fprintf(os.Stderr, "replace %s: %s\n", c.String(s.replace), c.String(s.with)) - s.with = c.substitute(s.with, a) - } - for _, s := range a { - if !res.binds(s.replace) { - res.bind(s.replace, s.with) - } + if b.IsVar() { + c.unify(b, a) + return } - return res -} + bTag, bIndex := b.extract() + if tag == bTag { + switch tag { + case funcTag: + aFn := c.GetFunc(a) + bFn := c.GetFunc(b) -func (c *Registry) apply(subst Subst) { - for _, s := range subst { - c.substitute(s.replace, subst) + c.unify(aFn.Arg, bFn.Arg) + c.unify(aFn.Result, bFn.Result) + case listTag: + c.unify(c.GetList(a), c.GetList(b)) + case recordTag: + c.unify(c.GetList(a), c.GetList(b)) + case primitiveTag: + if index != bIndex { + panic("cannot unify '" + c.String(a) + "' with '" + c.String(b) + "'") + } + } + } else { + panic("cannot unify '" + c.String(a) + "' with '" + c.String(b) + "'") } } @@ -427,6 +448,39 @@ func (c *Registry) substitute(target TypeRef, subst Subst) TypeRef { return target } +// DebugString returns a string representation for TypeRef. +func (reg *Registry) DebugString() string { + var s stringer + s.reg = reg + + s.WriteString("Vars:\n") + for i := range reg.vars { + s.WriteString(" $") + s.WriteString(strconv.Itoa(i)) + s.WriteString(": ") + s.string(makeTypeRef(varTag, i), 0) + s.WriteString("\n") + } + s.WriteString("Functions:\n") + for i := range reg.funcs { + s.WriteString(" ") + s.WriteString(strconv.Itoa(i)) + s.WriteString(": ") + s.string(makeTypeRef(funcTag, i), 0) + s.WriteString("\n") + } + s.WriteString("Lists:\n") + for i := range reg.lists { + s.WriteString(" ") + s.WriteString(strconv.Itoa(i)) + s.WriteString(": ") + s.string(makeTypeRef(listTag, i), 0) + s.WriteString("\n") + } + + return s.String() +} + func findOrAdd[T comparable](ls *[]T, tag tag, el T) TypeRef { list := *ls for i, typ := range list { @@ -500,14 +554,12 @@ func (b *stringer) string(ref TypeRef, nesting int) { case recordTag: b.record(index) case unboundTag: - // b.WriteByte('\'') - // b.WriteString(strconv.FormatInt(int64(index), 10)) b.unbound(index) case varTag: ref := b.reg.GetVar(ref) if ref == NeverRef { b.WriteByte('$') - b.WriteString(strconv.FormatInt(int64(index), 10)) + b.WriteString(strconv.Itoa(index)) } else { b.string(ref, nesting) } diff --git a/types/type_test.go b/types/type_test.go index 7fd3b6e..8c727c3 100644 --- a/types/type_test.go +++ b/types/type_test.go @@ -95,76 +95,12 @@ func TestGeneric(t *testing.T) { Eq(t, reg.String(listMap), "(a -> b) -> list a -> list b") } -func TestBind(t *testing.T) { +func TestInstantiate(t *testing.T) { reg := Registry{} - a := reg.Unbound() - Eq(t, reg.String(a), "a") - Eq(t, reg.Bind(a, a, IntRef), IntRef) - - id := reg.Func(a, a) - Eq(t, reg.String(id), "a -> a") - Eq(t, reg.Size(), 1) - - inc := reg.Bind(id, a, IntRef) - Eq(t, reg.String(inc), "int -> int") - Eq(t, reg.Size(), 2) - - b := reg.Unbound() - listMap := reg.Func(reg.Func(a, b), reg.Func(reg.List(a), reg.List(b))) - Eq(t, reg.String(listMap), "(a -> b) -> list a -> list b") - Eq(t, reg.Size(), 7) - - // Replace b -> int. - listMapBInt := reg.Bind(listMap, b, IntRef) - Eq(t, reg.String(listMapBInt), "(a -> int) -> list a -> list int") - Eq(t, reg.Size(), 11) - - // Now also a -> int. - listMapABInt := reg.Bind(listMapBInt, a, IntRef) - Eq(t, reg.String(listMapABInt), "(int -> int) -> list int -> list int") - Eq(t, reg.Size(), 13) - - // Let's go the other way, replace a -> int. - listMapAInt := reg.Bind(listMap, a, IntRef) - Eq(t, reg.String(listMapAInt), "(int -> a) -> list int -> list a") - Eq(t, reg.Size(), 16) - - // Returns same type if resolved the other way. - Eq(t, listMapABInt, reg.Bind(listMapAInt, b, IntRef)) - Eq(t, reg.Size(), 16) - - record := reg.Record(MapRef{"kind": IntRef, "a": a, "b": b}) - enum := reg.Enum(MapRef{"a": a, "b": b}) - recordToEnum := reg.Func(record, enum) - - // TODO: Don't sort order of record keys. - Eq(t, reg.String(recordToEnum), "{ a : a, b : b, kind : int } -> #a a #b b") - - recordToEnumAInt := reg.Bind(recordToEnum, a, IntRef) - Eq(t, reg.String(recordToEnumAInt), "{ a : int, b : a, kind : int } -> #a int #b a") - -} - -func TestSubstitute(t *testing.T) { - reg := Registry{} - - x := reg.Var() - Eq(t, reg.substitute(x, Subst{{x, IntRef}}), IntRef) - - Eq(t, reg.substitute(reg.Func(x, x), Subst{{x, IntRef}}), reg.Func(IntRef, IntRef)) - - y := reg.Var() - Eq(t, reg.substitute(reg.Func(x, y), Subst{{x, IntRef}}), reg.Func(IntRef, y)) - // A Scheme is represented by an unbound variable. // These are left untouched. (forall x. x) untouched. a := reg.Unbound() - Eq(t, reg.substitute(a, Subst{{x, IntRef}}), a) - - Eq(t, reg.substitute(reg.Func(x, y), Subst{{a, IntRef}, {y, TextRef}}), reg.Func(x, TextRef)) - - // Other... b := reg.Unbound() f := reg.Func(a, b) Eq(t, reg.String(f), "a -> b") @@ -174,57 +110,60 @@ func TestSubstitute(t *testing.T) { g := reg.Instantiate(f) h := reg.Instantiate(f) - Eq(t, reg.String(f), "a -> b") - Eq(t, reg.String(g), "$2 -> $3") - Eq(t, reg.String(h), "$4 -> $5") + Eq(t, reg.String(g), "$0 -> $1") + Eq(t, reg.String(h), "$2 -> $3") - Eq(t, reg.String(reg.Instantiate(l)), "list $6") + Eq(t, reg.String(reg.Instantiate(l)), "list $4") +} - gFn := reg.GetFunc(g) +func TestGetVar(t *testing.T) { + reg := Registry{} - subst := Subst{ - {replace: gFn.Arg, with: IntRef}, - } + a := reg.Var() + b := reg.Var() + c := reg.Var() + reg.bind(a, b) + reg.bind(b, c) + reg.bind(c, IntRef) - Eq(t, reg.substitute(g, subst), reg.Func(IntRef, gFn.Result)) - Eq(t, reg.substitute(h, subst), h) + Eq(t, reg.GetVar(a), IntRef) } -func TestCompose(t *testing.T) { +func TestResolve(t *testing.T) { reg := Registry{} - t1 := reg.Var() - t2 := reg.Var() + a := reg.Var() + b := reg.Var() - s1 := Subst{{t2, IntRef}} - Eq(t, s1.String(®), "$1: int") + Eq(t, reg.Resolve(a), a) + Eq(t, reg.Resolve(b), b) + Eq(t, reg.IsFree(a), true) + Eq(t, reg.IsFree(b), true) - s2 := Subst{{t1, reg.Func(IntRef, t2)}} - Eq(t, s2.String(®), "$0: int -> $1") + reg.bind(a, b) - Eq(t, reg.Compose(nil, nil).String(®), "") - Eq(t, reg.Compose(s1, nil).String(®), "$1: int") - Eq(t, reg.Compose(nil, s1).String(®), "$1: int") + Eq(t, reg.IsFree(a), true) + Eq(t, reg.IsFree(b), true) + Eq(t, reg.Resolve(a), b) + Eq(t, reg.Resolve(b), b) - Eq(t, reg.Compose(s1, s2).String(®), "$0: int -> int, $1: int") + reg.bind(b, IntRef) + + Eq(t, reg.IsFree(a), false) + Eq(t, reg.IsFree(b), false) + Eq(t, reg.Resolve(a), IntRef) + Eq(t, reg.Resolve(b), IntRef) } -func TestUnify(t *testing.T) { +func TestUnify_J(t *testing.T) { reg := Registry{} + res := reg.Var() a := reg.Var() - Eq(t, reg.Unify(a, IntRef).String(®), "$0: int") - - a = reg.Var() - Eq(t, reg.Unify(reg.Func(a, IntRef), reg.Func(TextRef, IntRef)).String(®), "$1: text") - - a = reg.Var() - b := reg.Var() - Eq(t, reg.Unify(reg.Func(a, b), reg.Func(b, a)).String(®), "$2: $3") + fn := reg.Func(a, reg.List(a)) + reg.unify(fn, reg.Func(IntRef, res)) - a = reg.Var() - b = reg.Var() - Eq(t, reg.Unify(reg.Func(a, IntRef), reg.Func(HoleRef, b)).String(®), "$5: int, $4: ()") + Eq(t, reg.String(res), "list int") } func Neq[T comparable](t *testing.T, a, b T) { From 4c3e30a5f72ced9b281de99daddf59b0ffe71bac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Oskar=20Segersv=C3=A4rd?= Date: Sun, 24 Aug 2025 16:58:02 +0200 Subject: [PATCH 3/3] Fix pick with a known type --- eval/eval.go | 5 ----- eval/eval_test.go | 3 +-- types/infer.go | 9 +-------- types/infer_test.go | 4 +++- 4 files changed, 5 insertions(+), 16 deletions(-) diff --git a/eval/eval.go b/eval/eval.go index 9d75f8f..b2d086c 100644 --- a/eval/eval.go +++ b/eval/eval.go @@ -536,11 +536,6 @@ func (c *context) pick(pick *ast.BinaryExpr, x ast.Expr) (Value, error) { if err != nil { return nil, err } - if val.Type() != tagTyp { - return nil, c.error(pick.Right.Span(), - fmt.Sprintf("#%s requires a value of type %s, got %s", - tag, c.reg.String(tagTyp), c.reg.String(val.Type()))) - } return Variant{ref, tag, val}, nil } } diff --git a/eval/eval_test.go b/eval/eval_test.go index bcfc913..6ccd0ca 100644 --- a/eval/eval_test.go +++ b/eval/eval_test.go @@ -58,8 +58,7 @@ var expressions = []struct { // | "hello " ++ name -> name // | _ -> "" <| "hello Oseg"`, Text("Oseg")}, {`box::empty ; box : #empty`, `#empty`}, - // TODO: Cannot infer type of `n -> x * 2`. - // {`typ::fun (n -> x * 2) ; typ : #fun (int -> int)`, `#fun n -> x * 2`}, + {`typ::fun (x -> x * 2) ; typ : #fun (int -> int)`, `#fun x -> x * 2`}, // Destructuring. {`{ a = 1, b = 2 } |> | { a = c, b = d } -> c + d`, `3`}, diff --git a/types/infer.go b/types/infer.go index e35b798..98fd8e8 100644 --- a/types/infer.go +++ b/types/infer.go @@ -317,14 +317,7 @@ func (c *context) pick(x *ast.BinaryExpr, val ast.Expr) TypeRef { } } else { valRef := c.infer(val) - // TODO: check assignability instead - if typ != valRef { - // Wrong type. - c.bail(val.Span(), - fmt.Sprintf("cannot assign %s to #%s which needs %s", - c.reg.String(valRef), tag, c.reg.String(typ))) - return NeverRef - } + c.ensure(val, valRef, typ) } return ref diff --git a/types/infer_test.go b/types/infer_test.go index fe29802..6cffa32 100644 --- a/types/infer_test.go +++ b/types/infer_test.go @@ -73,6 +73,8 @@ func TestInfer(t *testing.T) { // If used the same, arguments must be the same. {`a -> b -> [ a, b ]`, `$1 -> $1 -> list $1`}, {`(a -> b -> [ a, b ]) 1`, `int -> list int`}, + + {`typ::fun (x -> x * 2) ; typ : #fun (int -> int)`, `#fun (int -> int)`}, } for _, ex := range examples { @@ -103,7 +105,7 @@ func TestInferFailure(t *testing.T) { {`1::a`, `1 isn't an enum`}, {`a::a ; a : #b`, `#a isn't a valid option for enum #b`}, {`a::b 1 ; a : #b`, `#b doesn't take any value`}, - {`a::b 1 ; a : #b text`, `cannot assign int to #b which needs text`}, + {`a::b 1 ; a : #b text`, `cannot unify 'int' with 'text'`}, {`1 + ~dd`, `cannot unify 'byte' with 'int'`}, {`a ; a : int = 1.0`, `cannot unify 'float' with 'int'`}, {`f ; f : int -> text = a -> 1`, `cannot unify 'int' with 'text'`},