diff --git a/eval/eval.go b/eval/eval.go index 9d75f8f..b2d086c 100644 --- a/eval/eval.go +++ b/eval/eval.go @@ -536,11 +536,6 @@ func (c *context) pick(pick *ast.BinaryExpr, x ast.Expr) (Value, error) { if err != nil { return nil, err } - if val.Type() != tagTyp { - return nil, c.error(pick.Right.Span(), - fmt.Sprintf("#%s requires a value of type %s, got %s", - tag, c.reg.String(tagTyp), c.reg.String(val.Type()))) - } return Variant{ref, tag, val}, nil } } diff --git a/eval/eval_test.go b/eval/eval_test.go index bcfc913..6ccd0ca 100644 --- a/eval/eval_test.go +++ b/eval/eval_test.go @@ -58,8 +58,7 @@ var expressions = []struct { // | "hello " ++ name -> name // | _ -> "" <| "hello Oseg"`, Text("Oseg")}, {`box::empty ; box : #empty`, `#empty`}, - // TODO: Cannot infer type of `n -> x * 2`. - // {`typ::fun (n -> x * 2) ; typ : #fun (int -> int)`, `#fun n -> x * 2`}, + {`typ::fun (x -> x * 2) ; typ : #fun (int -> int)`, `#fun x -> x * 2`}, // Destructuring. {`{ a = 1, b = 2 } |> | { a = c, b = d } -> c + d`, `3`}, diff --git a/scanner/scanner.go b/scanner/scanner.go index 6d06216..eced12e 100644 --- a/scanner/scanner.go +++ b/scanner/scanner.go @@ -158,11 +158,7 @@ func (s *Scanner) bytes() (tok token.Token, span token.Span) { s.next() } - if s.offset-offs < 2 { - s.error(s.offset, "too short base64 string") - tok = token.BAD - return - } + // The two chars `~~` encodes an empty byte array. for (s.offset-offs)%4 > 0 { if s.ch != '=' { diff --git a/scanner/scanner_test.go b/scanner/scanner_test.go index 1229fb3..5215fed 100644 --- a/scanner/scanner_test.go +++ b/scanner/scanner_test.go @@ -41,6 +41,7 @@ var elements = []elt{ {token.TEXT, `"world"`, literal}, {token.BYTE, "~ca", literal}, {token.BYTES, "~~aGVsbG8gd29ybGQ=", literal}, + {token.BYTES, "~~", literal}, // Operators and delimiters {token.ASSIGN, "=", operator}, diff --git a/types/infer.go b/types/infer.go index 038cb65..98fd8e8 100644 --- a/types/infer.go +++ b/types/infer.go @@ -34,14 +34,6 @@ func (s *Scope[T]) Bind(name string, val T) *Scope[T] { return &Scope[T]{s, name, val} } -func (s *Scope[T]) Rebind(name string, val T) bool { - if bound := s.Get(name); bound != nil { - bound.val = val - return true - } - return false -} - type TypeScope = *Scope[TypeRef] type context struct { @@ -59,13 +51,20 @@ func (c *context) bind(name string, ref TypeRef) TypeScope { return c.scope } -func Infer(se ast.SourceExpr) (string, error) { - var reg Registry - var scope TypeScope +// Unbinds the last bound variable. +func (c *context) unbind() { + c.scope = c.scope.parent +} +func DefaultScope() (reg Registry, scope TypeScope) { for _, p := range primitives { scope = scope.Bind(reg.String(p), p) } + return +} + +func Infer(se ast.SourceExpr) (string, error) { + reg, scope := DefaultScope() ref, err := InferInScope(®, scope, se) if err != nil { @@ -91,7 +90,8 @@ func InferInScope(reg *Registry, scope TypeScope, se ast.SourceExpr) (ref TypeRe } }() - return context.infer(se.Expr), err + ref = context.infer(se.Expr) + return ref, err } func (c *context) infer(expr ast.Expr) TypeRef { @@ -99,7 +99,12 @@ func (c *context) infer(expr ast.Expr) TypeRef { case *ast.Literal: return literalTypeRef(x.Kind) case *ast.Ident: - return c.scope.Lookup(c.source.GetString(x.Pos)) + name := c.source.GetString(x.Pos) + ref := c.scope.Lookup(name) + if ref == NeverRef { + c.bail(x.Pos, "unbound variable: "+name) + } + return c.reg.Instantiate(ref) case *ast.WhereExpr: return c.where(x) case *ast.ListExpr: @@ -108,69 +113,61 @@ func (c *context) infer(expr ast.Expr) TypeRef { return c.record(x) case ast.EnumExpr: return c.enum(x) + case *ast.FuncExpr: - unbound := c.reg.Unbound() - // Hold onto the binding, in case inferring the body rebinds its type. - binding := c.bind(c.source.GetString(x.Arg.Span()), unbound) + // Not sure how to juggle vars vs unbound. :/ + binder := c.reg.Var() + c.bind(c.source.GetString(x.Arg.Span()), binder) + defer c.unbind() ret := c.infer(x.Body) - return c.reg.Func(binding.val, ret) + return c.reg.Func(binder, ret) + case *ast.CallExpr: // Special-case pick with a value. if pick, ok := x.Fn.(*ast.BinaryExpr); ok && pick.Op == token.PICK { return c.pick(pick, x.Arg) } - typ := c.infer(x.Fn) + res := c.reg.Var() + fn := c.infer(x.Fn) arg := c.infer(x.Arg) + c.ensure(x, fn, c.reg.Func(arg, res)) + return res - if !typ.IsFunction() { - // If the argument is an unbound identifier, rebind it. - id, ok := x.Fn.(*ast.Ident) - if ok && typ.IsUnbound() { - name := c.source.GetString(id.Pos) - // Let's steal the now unused (?) unbound. - fn := c.reg.Func(arg, typ) - if c.scope.Rebind(name, fn) { - return typ - } - - // Let's try to rebind a type. - } - - c.bail(x.Span(), fmt.Sprintf("cannot call non-function %s", c.reg.String(typ))) - } - fn := c.reg.GetFunc(typ) - - ref := c.call(fn, arg) - if ref != NeverRef { - return ref + case *ast.BinaryExpr: + if x.Op == token.PICK { + return c.pick(x, nil) } - c.bail(x.Span(), fmt.Sprintf("cannot call %s with %s", c.reg.String(typ), c.reg.String(arg))) - - case *ast.BinaryExpr: left := c.infer(x.Left) right := c.infer(x.Right) switch x.Op { - case token.PICK: - return c.pick(x, nil) + case token.PREPEND: + return c.pend(x.Left, x.Right, left, right) + case token.APPEND: + return c.pend(x.Right, x.Left, right, left) case token.CONCAT: - if left == TextRef { - if right == TextRef { - return TextRef - } else if right.IsUnbound() { - - } else { - return NeverRef - } + if left == TextRef || right == TextRef { + c.ensure(x, left, right) + return TextRef } - case token.ADD: - if left == IntRef { - return c.ensure(x.Right, right, IntRef) + if left == BytesRef || right == BytesRef { + c.ensure(x, left, right) + return BytesRef } - if right == IntRef { - return c.ensure(x.Left, left, IntRef) + // Local var to ensure left and right are lists. + a := c.reg.List(c.reg.Var()) + c.ensure(x, left, right) + c.ensure(x, left, a) + return a + case token.ADD, token.SUB, token.MUL: + if left == FloatRef || right == FloatRef { + c.ensure(x, left, right) + return FloatRef } + // Assume int, like ML does. + c.ensure(x.Left, left, IntRef) + return c.ensure(x.Right, right, IntRef) } panic(fmt.Sprintf("can't infer binary expression %s", x.Op.String())) } @@ -179,78 +176,37 @@ func (c *context) infer(expr ast.Expr) TypeRef { } func (c *context) ensure(x ast.Expr, got, want TypeRef) TypeRef { - if got == want { - return got - } - - if got.IsUnbound() { - c.rebind(x, want) - return want - } - - if c.isAssignable(want, got) { - return got - } - - c.bail(x.Span(), fmt.Sprintf("expected %s, got %s", c.reg.String(want), c.reg.String(got))) - return NeverRef -} - -func (c *context) call(fn FuncRef, arg TypeRef) TypeRef { - if c.isAssignable(fn.Arg, arg) { - return fn.Result - } - - if fn.Arg.IsUnbound() { - return c.reg.Bind(fn.Result, fn.Arg, arg) - } - - if fn.Arg.IsFunction() && arg.IsFunction() { - afn := c.reg.GetFunc(fn.Arg) - bfn := c.reg.GetFunc(arg) - - // If completely unbound, replace with arg. - if afn.Arg.IsUnbound() && afn.Result.IsUnbound() { - res := c.reg.Bind(fn.Result, afn.Result, bfn.Result) - return c.reg.Bind(res, afn.Arg, bfn.Arg) - } - } - - return NeverRef -} - -func (c *context) isAssignable(a, b TypeRef) bool { - if a == b { - return true - } + if got != want { + // Really? Must make this API better. + defer func() { + if pnc := recover(); pnc != nil { + if msg, ok := pnc.(string); ok { + c.bail(x.Span(), msg) + } else { + panic(pnc) + } + } + }() - aTag, _ := a.extract() - switch aTag { - case listTag: - if b.IsList() && c.reg.GetList(b).IsUnbound() { - return true - } + c.reg.unify(got, want) } - - return false + return want } func (c *context) where(x *ast.WhereExpr) TypeRef { name := c.source.GetString(x.Id.Pos) - if x.Typ == nil { - // If there is no type annotation, we can infer it from the value. - c.bind(name, c.infer(x.Val)) - return c.infer(x.Expr) - } - tRef := c.typ(x.Typ) - vRef := c.infer(x.Val) - if tRef != vRef { - c.bail(x.Val.Span(), fmt.Sprintf("cannot assign %s to %s", c.reg.String(vRef), c.reg.String(tRef))) + tyVal := c.infer(x.Val) + + // If there's an annotation, make sure it matches the inferred type. + if x.Typ != nil { + c.ensure(x.Typ, tyVal, c.typ(x.Typ)) } - c.bind(name, tRef) - return c.infer(x.Expr) + c.bind(name, c.reg.generalize(tyVal)) + defer c.unbind() + tyExpr := c.infer(x.Expr) + return tyExpr } func (c *context) typ(x ast.Expr) TypeRef { @@ -273,42 +229,26 @@ func (c *context) typ(x ast.Expr) TypeRef { return NeverRef } -func (c *context) list(x *ast.ListExpr) (res TypeRef) { +func (c *context) list(x *ast.ListExpr) TypeRef { + res := NeverRef + for _, v := range x.Elements { typ := c.infer(v) - if typ == res { - continue - } else if res == NeverRef { + + if res == NeverRef { res = typ - } else if typ.IsUnbound() { - c.rebind(v, res) - } else { - c.bail(v.Span(), "list elements must all be of type "+c.reg.String(res)) - // Bad list. - return NeverRef + continue } + + c.ensure(v, res, typ) } + if res == NeverRef { - res = c.reg.Unbound() + res = c.reg.Var() } return c.reg.List(res) } -// Re-binds the type of expresion x, or fails. -func (c *context) rebind(x ast.Expr, ref TypeRef) { - name := c.source.GetString(x.Span()) - _, ok := x.(*ast.Ident) - if ok { - s := c.scope.Get(name) - if s != nil { - // Let's steal the now unused (?) generic. - s.val = ref - return - } - } - c.bail(x.Span(), fmt.Sprintf("can't rebind type of non-identifier %s", name)) -} - func (c *context) record(x *ast.RecordExpr) TypeRef { // If there is a rest/spread, our type is equal to that. if x.Rest != nil { @@ -377,14 +317,7 @@ func (c *context) pick(x *ast.BinaryExpr, val ast.Expr) TypeRef { } } else { valRef := c.infer(val) - // TODO: check assignability instead - if typ != valRef { - // Wrong type. - c.bail(val.Span(), - fmt.Sprintf("cannot assign %s to #%s which needs %s", - c.reg.String(valRef), tag, c.reg.String(typ))) - return NeverRef - } + c.ensure(val, valRef, typ) } return ref @@ -412,3 +345,16 @@ func literalTypeRef(tok token.Token) TypeRef { return NeverRef } + +// Either pre-pend or ap-pend. +func (c *context) pend(singleX, listX ast.Expr, single, list TypeRef) TypeRef { + // Special-case bytes. + if single == ByteRef || list == BytesRef { + c.ensure(singleX, single, ByteRef) + c.ensure(listX, list, BytesRef) + return BytesRef + } + + c.ensure(singleX, c.reg.List(single), list) + return list +} diff --git a/types/infer_test.go b/types/infer_test.go index a214e7e..6cffa32 100644 --- a/types/infer_test.go +++ b/types/infer_test.go @@ -19,26 +19,62 @@ func TestInfer(t *testing.T) { // Primitives {`5`, `int`}, {`a ; a = 5`, `int`}, + {`1 + 2`, `int`}, // Lists - {`[]`, `list a`}, // empty list has an unbound type for its values + {`[]`, `list $0`}, // empty list has an unbound type for its values {`[1, 2]`, `list int`}, // Records {`{ a = 1 }`, `{ a : int }`}, {`{ ..base, a = ~01 } ; base = { a = ~00 }`, `{ a : byte }`}, - // Enums + // // Enums {`bool ; bool : #true #false`, `#false #true`}, {`e ; e : #l int #r`, `#l int #r`}, {`e::r ; e : #l int #r`, `#l int #r`}, {`e::l 4 ; e : #l int #r`, `#l int #r`}, // Functions - {`_ -> "hi"`, `a -> text`}, - {`_ -> _ -> "hi"`, `a -> b -> text`}, + {`a -> a`, `$0 -> $0`}, + {`_ -> "hi"`, `$0 -> text`}, + {`_ -> _ -> "hi"`, `$0 -> $1 -> text`}, {`(_ -> "hi") ()`, `text`}, - {`a -> b -> { a = a, b = b }`, `a -> b -> { a : a, b : b }`}, - {`(a -> b -> { a = a, b = b }) 1`, `a -> { a : int, b : a }`}, + + // Prepend and append + {`a -> a >+ []`, `$1 -> list $1`}, + {`a -> a +< int`, `list int -> list int`}, + {`a -> a >+ ~~1111`, `byte -> bytes`}, + {`a -> a +< ~ff`, `bytes -> bytes`}, + + // Concat + {`"hi " ++ "you!"`, `text`}, + {`[] ++ [1]`, `list int`}, + {`~~1111 ++ ~~`, `bytes`}, + {`a -> b -> a ++ b`, `list $2 -> list $2 -> list $2`}, + + // Math + {`a -> 1.0 + a`, `float -> float`}, + {`4 - 3`, `int`}, + {`a -> b -> a * b`, `int -> int -> int`}, // Default to int. + + {`a -> b -> { a = a, b = b }`, `$0 -> $1 -> { a : $0, b : $1 }`}, + {`(a -> b -> { a = a, b = b }) 1`, `$2 -> { a : int, b : $2 }`}, {`(a -> b -> { a = a, b = b }) 1 "yo" `, `{ a : int, b : text }`}, {`a ; a : int = 1`, `int`}, {`a -> a + 1`, `int -> int`}, + {`b -> (a ; a : int = b)`, `int -> int`}, + + {`f -> f (f 1)`, `(int -> int) -> int`}, + {`a -> f -> f (f a)`, `$2 -> ($2 -> $2) -> $2`}, + {`f -> a -> f (f a)`, `($2 -> $2) -> $2 -> $2`}, + + {`f -> a -> [ a ]`, `$0 -> $1 -> list $1`}, + {`(f -> a -> [ a ]) "a"`, `$2 -> list $2`}, + {`(f -> a -> [ a ]) "a" 3`, `list int`}, + + {`f -> a -> ([ b, b ] ; b = (f a))`, `($1 -> $2) -> $1 -> list $2`}, + // If used the same, arguments must be the same. + {`a -> b -> [ a, b ]`, `$1 -> $1 -> list $1`}, + {`(a -> b -> [ a, b ]) 1`, `int -> list int`}, + + {`typ::fun (x -> x * 2) ; typ : #fun (int -> int)`, `#fun (int -> int)`}, } for _, ex := range examples { @@ -56,8 +92,12 @@ func TestInfer(t *testing.T) { func TestInferFailure(t *testing.T) { examples := []struct{ source, message string }{ + // Unbound + {`b ; a = b -> b`, `unbound variable: b`}, // Lists - {`[1, 1.0]`, `list elements must all be of type int`}, + {`[1, 1.0]`, `cannot unify 'int' with 'float'`}, + {`[4] ++ ["text"]`, `cannot unify 'int' with 'text'`}, + {`4 ++ 6`, `cannot unify 'int' with 'list $0'`}, // Records {`{ ..base, a = 1 } ; base = { a = ~00 }`, `type of a must be byte, not int`}, {`{ ..1, a = 1 }`, `cannot spread from non-record type int`}, @@ -65,10 +105,12 @@ func TestInferFailure(t *testing.T) { {`1::a`, `1 isn't an enum`}, {`a::a ; a : #b`, `#a isn't a valid option for enum #b`}, {`a::b 1 ; a : #b`, `#b doesn't take any value`}, - {`a::b 1 ; a : #b text`, `cannot assign int to #b which needs text`}, - {`1 + ~dd`, `expected int, got byte`}, - {`a ; a : int = 1.0`, `cannot assign float to int`}, - {`f ; f : int -> text = a -> 1`, `cannot assign a -> int to int -> text`}, + {`a::b 1 ; a : #b text`, `cannot unify 'int' with 'text'`}, + {`1 + ~dd`, `cannot unify 'byte' with 'int'`}, + {`a ; a : int = 1.0`, `cannot unify 'float' with 'int'`}, + {`f ; f : int -> text = a -> 1`, `cannot unify 'int' with 'text'`}, + // Math + {`1 + 1.0`, `cannot unify 'int' with 'float'`}, } for _, ex := range examples { @@ -86,29 +128,28 @@ func TestInferFailure(t *testing.T) { } func TestInferInScope(t *testing.T) { - reg := Registry{} - var scope *Scope[TypeRef] - - scope = scope.Bind("len", reg.Func(reg.List(reg.Unbound()), IntRef)) - examples := []struct{ source, typ string }{ - {`len`, `list a -> int`}, + {`len`, `list $0 -> int`}, {`len []`, `int`}, - {`f -> a -> [ a ]`, `a -> b -> list b`}, - {`(f -> a -> [ a ]) "a"`, `a -> list a`}, - {`(f -> a -> [ a ]) "a" 3`, `list int`}, - {`f -> a -> ([ b, b ] ; b = (f a))`, `(a -> b) -> a -> list b`}, - // If used the same, arguments must be the same. - {`a -> b -> [ a, b ]`, `a -> a -> list a`}, - {`(a -> b -> [ a, b ]) 1`, `int -> list int`}, - - {`(f -> a -> ([ b, b ] ; b = (f a))) len`, `list a -> list int`}, + {`(f -> a -> ([ b, b ] ; b = (f a))) len`, `list $4 -> list int`}, {`(f -> a -> ([ b, b ] ; b = (f a))) len []`, `list int`}, - // {`twice ; twice = f -> a -> f (f a)`, `(a -> a) -> a -> a`}, + + {`{ a = id 1, b = id "" }`, `{ a : int, b : text }`}, + {`{ a = id2 1, b = id2 "" } ; id2 = a -> a`, `{ a : int, b : text }`}, } for _, ex := range examples { se := must(parser.ParseExpr(ex.source)) + + // New registry every test. + reg := Registry{} + var scope *Scope[TypeRef] + + scope = scope.Bind("len", reg.Func(reg.List(reg.Unbound()), IntRef)) + + a := reg.Unbound() + scope = scope.Bind("id", reg.Func(a, a)) + ref, err := InferInScope(®, scope, se) if err != nil { t.Error(err) diff --git a/types/subst.go b/types/subst.go new file mode 100644 index 0000000..d181bfe --- /dev/null +++ b/types/subst.go @@ -0,0 +1,51 @@ +package types + +import "strings" + +// A single substitution. +type Sub struct { + // Must be a variable. + replace TypeRef + with TypeRef +} + +// A set of substitutions. +type Subst []Sub + +func (s Subst) binds(target TypeRef) bool { + for _, s := range s { + if s.replace == target { + return true + } + } + return false +} + +func (s Subst) bound(target TypeRef) TypeRef { + for _, s := range s { + if s.replace == target { + return s.with + } + } + return NeverRef +} + +func (s *Subst) bind(replace, with TypeRef) { + *s = append(*s, Sub{replace, with}) +} + +// For debugging. +func (s Subst) String(reg *Registry) string { + var b strings.Builder + + for i, s := range s { + if i != 0 { + b.WriteString(", ") + } + b.WriteString(VarString(s.replace)) + b.WriteString(": ") + b.WriteString(reg.String(s.with)) + } + + return b.String() +} diff --git a/types/type.go b/types/type.go index c079fee..74e8f54 100644 --- a/types/type.go +++ b/types/type.go @@ -3,6 +3,7 @@ package types import ( "maps" "slices" + "strconv" "strings" ) @@ -16,8 +17,19 @@ const ( enumTag recordTag unboundTag + varTag ) +var tagNames = [...]string{ + primitiveTag: "primitive", + listTag: "list", + funcTag: "func", + enumTag: "enum", + recordTag: "record", + unboundTag: "unbound", + varTag: "var", +} + // Efficiently encodes a type reference within a Registry. // // The zero value references the impossible "never" type. @@ -36,8 +48,21 @@ func (ref TypeRef) extract() (tag, int) { return tag, index } +func (ref TypeRef) tag() tag { + return tag(ref & 0x0f) +} + +func (ref TypeRef) index() int { + return int(ref >> 4) +} + func (ref TypeRef) hasTag(t tag) bool { - return tag(ref&0x0f) == t + return ref.tag() == t +} + +// Returns true if both TypeRefs have the same tag. +func (ref TypeRef) SameTypeAs(other TypeRef) bool { + return ref.hasTag(other.tag()) } // IsList returns true if the TypeRef is a list. @@ -55,6 +80,11 @@ func (ref TypeRef) IsUnbound() bool { return ref.hasTag(unboundTag) } +// IsVar returns true if the TypeRef is an var type. +func (ref TypeRef) IsVar() bool { + return ref.hasTag(varTag) +} + const ( // Shortcut to the TypeRef for Never. NeverRef TypeRef = TypeRef(int(primitiveTag) | (iota << 4)) // Inlined makeTypeRef @@ -95,6 +125,12 @@ type Registry struct { // Enums and records are maps to TypeRefs. enums []MapRef records []MapRef + // Type variables that will point to another type, + // or NeverRef if not yet assigned. + // + // Schemes are types with unbound TypeRefs. When instantiating a type, + // all unbound types will be replaced with fresh vars instead. + vars []TypeRef } // Returns the number of types in the registry, for debugging. @@ -175,35 +211,235 @@ func (c *Registry) Unbound() (ref TypeRef) { return } -// Bind replaces all occurrences of `unbound` with `resolved` in the `target` type. -func (c *Registry) Bind(target, unbound, resolved TypeRef) TypeRef { - // Base case: the target is the unbound we want to replace. - if target == unbound { - return resolved +// Var returns a new variable TypeRef. +func (c *Registry) Var() (ref TypeRef) { + i := len(c.vars) + c.vars = append(c.vars, NeverRef) + return makeTypeRef(varTag, i) +} + +// Resolve follows variables to their last bound var. +func (c *Registry) Resolve(ref TypeRef) TypeRef { + // Ignore non-vars. + if !ref.IsVar() { + return ref + } + other := c.Resolve(c.vars[ref.index()]) + if other == NeverRef { + return ref + } + c.vars[ref.index()] = other + return other +} + +// GetVar returns the TypeRef for an record type. +func (c *Registry) GetVar(ref TypeRef) TypeRef { + c.Resolve(ref) + tag, index := ref.extract() + if tag != varTag { + return ref + } + return c.vars[index] +} + +func (c *Registry) IsFree(ref TypeRef) bool { + return c.Resolve(ref).IsVar() +} + +// VarString returns the string representation of an unresolved variable. +func VarString(ref TypeRef) string { + tag, index := ref.extract() + if tag != varTag { + panic("VarString: got non-var tag " + tagNames[tag]) + } + return "$" + strconv.FormatInt(int64(index), 10) +} + +type MapTypeRef func(ref TypeRef) + +func (c *Registry) traverse(target TypeRef, mtr MapTypeRef) { + tag, index := target.extract() + switch tag { + case listTag: + c.traverse(c.lists[index], mtr) + case funcTag: + fn := c.funcs[index] + c.traverse(fn.Arg, mtr) + c.traverse(fn.Result, mtr) + case enumTag: + for _, v := range c.enums[index] { + c.traverse(v, mtr) + } + case recordTag: + for _, v := range c.records[index] { + c.traverse(v, mtr) + } + } + + mtr(target) +} + +type Replacer func(ref TypeRef) TypeRef + +func (c *Registry) replace(target TypeRef, f Replacer) TypeRef { + tag, index := target.extract() + switch tag { + case listTag: + return c.List(f(c.lists[index])) + case funcTag: + fn := c.funcs[index] + return c.Func(f(fn.Arg), f(fn.Result)) + case enumTag: + ref := make(MapRef, len(c.enums[index])) + for k, v := range c.enums[index] { + ref[k] = f(v) + } + return c.Enum(ref) + case recordTag: + ref := make(MapRef, len(c.records[index])) + for k, v := range c.records[index] { + ref[k] = f(v) + } + return c.Record(ref) + } + + // Else, the target remains unchanged. + return target +} + +// bind binds a free variable to a type. +func (reg *Registry) bind(a, b TypeRef) { + // Get to the bottom of `a`. + a = reg.Resolve(a) + + if !a.IsVar() { + panic("cannot bind non-free var " + reg.String(a)) + } + reg.vars[a.index()] = b +} + +// The opposite of instantiate. +func (c *Registry) generalize(target TypeRef) TypeRef { + var subst Subst + return c.replace(target, func(other TypeRef) TypeRef { + if other.IsVar() { + b := subst.bound(other) + if b == NeverRef { + b = c.Unbound() + subst.bind(other, b) + } + return b + } + return other + }) +} + +func (c *Registry) Instantiate(target TypeRef) TypeRef { + var subst Subst + c.insertUnbound(target, &subst) + return c.substitute(target, subst) +} + +func (c *Registry) insertUnbound(target TypeRef, subst *Subst) { + tag, index := target.extract() + switch tag { + case unboundTag: + if !subst.binds(target) { + subst.bind(target, c.Var()) + } + case listTag: + c.insertUnbound(c.lists[index], subst) + case funcTag: + fn := c.funcs[index] + c.insertUnbound(fn.Arg, subst) + c.insertUnbound(fn.Result, subst) + // TODO: Other types + } +} + +func (c *Registry) unify(a, b TypeRef) { + a = c.Resolve(a) + b = c.Resolve(b) + + tag, index := a.extract() + if tag == unboundTag { + panic("unexpected unbound var during unification") + } + + if tag == varTag { + c.traverse(b, func(ref TypeRef) { + if a == ref { + panic("occurs check failed") + } + }) + c.vars[index] = b + return + } + + if b.IsVar() { + c.unify(b, a) + return + } + + bTag, bIndex := b.extract() + if tag == bTag { + switch tag { + case funcTag: + aFn := c.GetFunc(a) + bFn := c.GetFunc(b) + + c.unify(aFn.Arg, bFn.Arg) + c.unify(aFn.Result, bFn.Result) + case listTag: + c.unify(c.GetList(a), c.GetList(b)) + case recordTag: + c.unify(c.GetList(a), c.GetList(b)) + case primitiveTag: + if index != bIndex { + panic("cannot unify '" + c.String(a) + "' with '" + c.String(b) + "'") + } + } + } else { + panic("cannot unify '" + c.String(a) + "' with '" + c.String(b) + "'") } +} +func (c *Registry) substitute(target TypeRef, subst Subst) TypeRef { tag, index := target.extract() switch tag { + case unboundTag: + for _, s := range subst { + if s.replace == target { + return s.with + } + } + case varTag: + for _, s := range subst { + if s.replace == target { + c.vars[index] = s.with + return s.with + } + } case listTag: return c.List( - c.Bind(c.lists[index], unbound, resolved), + c.substitute(c.lists[index], subst), ) case funcTag: fn := c.funcs[index] return c.Func( - c.Bind(fn.Arg, unbound, resolved), - c.Bind(fn.Result, unbound, resolved), + c.substitute(fn.Arg, subst), + c.substitute(fn.Result, subst), ) case enumTag: ref := make(MapRef, len(c.enums[index])) for k, v := range c.enums[index] { - ref[k] = c.Bind(v, unbound, resolved) + ref[k] = c.substitute(v, subst) } return c.Enum(ref) case recordTag: ref := make(MapRef, len(c.records[index])) for k, v := range c.records[index] { - ref[k] = c.Bind(v, unbound, resolved) + ref[k] = c.substitute(v, subst) } return c.Record(ref) } @@ -212,6 +448,39 @@ func (c *Registry) Bind(target, unbound, resolved TypeRef) TypeRef { return target } +// DebugString returns a string representation for TypeRef. +func (reg *Registry) DebugString() string { + var s stringer + s.reg = reg + + s.WriteString("Vars:\n") + for i := range reg.vars { + s.WriteString(" $") + s.WriteString(strconv.Itoa(i)) + s.WriteString(": ") + s.string(makeTypeRef(varTag, i), 0) + s.WriteString("\n") + } + s.WriteString("Functions:\n") + for i := range reg.funcs { + s.WriteString(" ") + s.WriteString(strconv.Itoa(i)) + s.WriteString(": ") + s.string(makeTypeRef(funcTag, i), 0) + s.WriteString("\n") + } + s.WriteString("Lists:\n") + for i := range reg.lists { + s.WriteString(" ") + s.WriteString(strconv.Itoa(i)) + s.WriteString(": ") + s.string(makeTypeRef(listTag, i), 0) + s.WriteString("\n") + } + + return s.String() +} + func findOrAdd[T comparable](ls *[]T, tag tag, el T) TypeRef { list := *ls for i, typ := range list { @@ -252,6 +521,7 @@ func (b *stringer) unbound(index int) { b.unbounds = append(b.unbounds, index) } b.WriteByte(unboundNames[i]) + // b.WriteByte(unboundNames[index]) } func (b *stringer) string(ref TypeRef, nesting int) { @@ -285,6 +555,14 @@ func (b *stringer) string(ref TypeRef, nesting int) { b.record(index) case unboundTag: b.unbound(index) + case varTag: + ref := b.reg.GetVar(ref) + if ref == NeverRef { + b.WriteByte('$') + b.WriteString(strconv.Itoa(index)) + } else { + b.string(ref, nesting) + } default: // The invalid type. panic("bad type-ref") diff --git a/types/type_test.go b/types/type_test.go index 12ef6fc..8c727c3 100644 --- a/types/type_test.go +++ b/types/type_test.go @@ -95,55 +95,75 @@ func TestGeneric(t *testing.T) { Eq(t, reg.String(listMap), "(a -> b) -> list a -> list b") } -func TestBind(t *testing.T) { +func TestInstantiate(t *testing.T) { reg := Registry{} + // A Scheme is represented by an unbound variable. + // These are left untouched. (forall x. x) untouched. a := reg.Unbound() - Eq(t, reg.String(a), "a") - Eq(t, reg.Bind(a, a, IntRef), IntRef) + b := reg.Unbound() + f := reg.Func(a, b) + Eq(t, reg.String(f), "a -> b") - id := reg.Func(a, a) - Eq(t, reg.String(id), "a -> a") - Eq(t, reg.Size(), 1) + l := reg.List(a) + Eq(t, reg.String(l), "list a") - inc := reg.Bind(id, a, IntRef) - Eq(t, reg.String(inc), "int -> int") - Eq(t, reg.Size(), 2) + g := reg.Instantiate(f) + h := reg.Instantiate(f) + Eq(t, reg.String(g), "$0 -> $1") + Eq(t, reg.String(h), "$2 -> $3") - b := reg.Unbound() - listMap := reg.Func(reg.Func(a, b), reg.Func(reg.List(a), reg.List(b))) - Eq(t, reg.String(listMap), "(a -> b) -> list a -> list b") - Eq(t, reg.Size(), 7) + Eq(t, reg.String(reg.Instantiate(l)), "list $4") +} + +func TestGetVar(t *testing.T) { + reg := Registry{} - // Replace b -> int. - listMapBInt := reg.Bind(listMap, b, IntRef) - Eq(t, reg.String(listMapBInt), "(a -> int) -> list a -> list int") - Eq(t, reg.Size(), 11) + a := reg.Var() + b := reg.Var() + c := reg.Var() + reg.bind(a, b) + reg.bind(b, c) + reg.bind(c, IntRef) - // Now also a -> int. - listMapABInt := reg.Bind(listMapBInt, a, IntRef) - Eq(t, reg.String(listMapABInt), "(int -> int) -> list int -> list int") - Eq(t, reg.Size(), 13) + Eq(t, reg.GetVar(a), IntRef) +} + +func TestResolve(t *testing.T) { + reg := Registry{} + + a := reg.Var() + b := reg.Var() - // Let's go the other way, replace a -> int. - listMapAInt := reg.Bind(listMap, a, IntRef) - Eq(t, reg.String(listMapAInt), "(int -> a) -> list int -> list a") - Eq(t, reg.Size(), 16) + Eq(t, reg.Resolve(a), a) + Eq(t, reg.Resolve(b), b) + Eq(t, reg.IsFree(a), true) + Eq(t, reg.IsFree(b), true) - // Returns same type if resolved the other way. - Eq(t, listMapABInt, reg.Bind(listMapAInt, b, IntRef)) - Eq(t, reg.Size(), 16) + reg.bind(a, b) - record := reg.Record(MapRef{"kind": IntRef, "a": a, "b": b}) - enum := reg.Enum(MapRef{"a": a, "b": b}) - recordToEnum := reg.Func(record, enum) + Eq(t, reg.IsFree(a), true) + Eq(t, reg.IsFree(b), true) + Eq(t, reg.Resolve(a), b) + Eq(t, reg.Resolve(b), b) - // TODO: Don't sort order of record keys. - Eq(t, reg.String(recordToEnum), "{ a : a, b : b, kind : int } -> #a a #b b") + reg.bind(b, IntRef) + + Eq(t, reg.IsFree(a), false) + Eq(t, reg.IsFree(b), false) + Eq(t, reg.Resolve(a), IntRef) + Eq(t, reg.Resolve(b), IntRef) +} + +func TestUnify_J(t *testing.T) { + reg := Registry{} - recordToEnumAInt := reg.Bind(recordToEnum, a, IntRef) - Eq(t, reg.String(recordToEnumAInt), "{ a : int, b : a, kind : int } -> #a int #b a") + res := reg.Var() + a := reg.Var() + fn := reg.Func(a, reg.List(a)) + reg.unify(fn, reg.Func(IntRef, res)) + Eq(t, reg.String(res), "list int") } func Neq[T comparable](t *testing.T, a, b T) {