Coercive subtyping extension to MLsub type inference

I studied academic papers to find out whether there is a good way to implement coercions in a traditional Hindley-Milner type system. I realised that the options are few, and that there are not many satisfying ways to implement function overloading or coercive subtyping with type inference, so it seemed worthwhile to document my work.

MLsub is a type inference engine with subtyping, designed by Stephen Dolan for his thesis. I followed the implementation proposed in the thesis and extended it with coercive subtyping. For a short while I thought this was a mistake, because my implementation was clumsy.

The code in this post is written in Lever 0.9.0. In many respects the language reads like Python and has similar semantics. Be warned that later Lever versions changed considerably and cannot run this code. I find comfort in using my own language for language projects, although in retrospect it can hurt in cases like this.

The code is in the suoml repository. At the time of writing, the commit is 7981c14.


MLsub combines ML-style parametric polymorphism and subtyping. The thesis presents a subtyping counterpart of unification, called biunification, and shows how to implement biunification constraint elimination as rewriting of a finite-state automaton.

The approach works well for inferring the types of many kinds of programs that would be untypeable in a Hindley-Milner type system. However, the subsumption rules are strict, and the language designer may have to predefine them for the user.

Separating constraint generation from solving

The paper "Generalizing Hindley-Milner Type Inference Algorithms" presents half of the solution to the problem I had. When it comes to recursion, the MLsub inferencer has a glass jaw similar to Algorithm W's: we have to retrieve a dependency graph before doing type inference. That results in two slightly related traversals through the program being typed. If constraint generation is separated from solving, the dependency graph can be retrieved while the constraints are generated.

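Here is a minimal sketch of that single traversal, written in Python rather than Lever 0.9.0. The toy AST classes, the tuple encoding of types and the name `generate` are illustrative assumptions, not the suoml code. The point is only that the dependency set falls out of the same walk that emits the constraints.

```python
from dataclasses import dataclass

# Hypothetical toy AST, standing in for the Var/Lit/App/Abs classes of the post.
@dataclass
class Var:
    name: str

@dataclass
class Lit:
    type: str

@dataclass
class App:
    fn: object
    args: list

@dataclass
class Abs:
    params: list
    body: object

def generate(expr, env, constraints, deps):
    """One traversal: append constraints and record global dependencies."""
    if isinstance(expr, Var):
        if expr.name in env:              # bound locally: not a dependency
            return env[expr.name]
        deps.add(expr.name)               # refers to another declaration
        return ("instantiate", expr.name)
    if isinstance(expr, Lit):
        return ("literal", expr.type)
    if isinstance(expr, Abs):
        env = dict(env)
        params = []
        for p in expr.params:
            v = ("var", p)
            env[p] = v
            params.append(v)
        return ("function", params, generate(expr.body, env, constraints, deps))
    if isinstance(expr, App):
        parts = [generate(expr.fn, env, constraints, deps)]
        parts += [generate(a, env, constraints, deps) for a in expr.args]
        result = ("result", len(constraints))   # fresh name for the call's output
        constraints.append(("call", parts, result))
        return result
    raise TypeError(f"unknown node {expr!r}")

constraints, deps = [], set()
decl = Abs(["x"], App(Var("g"), [Var("x"), Lit("int")]))
generate(decl, {}, constraints, deps)
print(sorted(deps))   # ['g']
print(constraints)    # [('call', [('instantiate', 'g'), ('var', 'x'), ('literal', 'int')], ('result', 0))]
```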
The program starts by parsing the source code and populating an environment for type inference. Then it proceeds to figure out the types.

actions = object({
    def = Def
    app = App
    let = Let
    abs = Abs
    if_then_else = IfThenElse
    var = Var
    string = (val):
        return Lit(val, types.t_string)
    int = (val):
        return Lit(val, types.t_int)
    rational = (val):
        return Lit(val, types.t_rational)
})
parse = language.read_file("sample.fml")
decls = parse.traverse((rule, args, loc):
    return getattr(actions, rule)(args...))

declared = {}

namespace = {}
for decl in decls
    namespace[decl.name] = decl

dependency_graph = {}
constraints_graph = {}
for decl in decls
    constraints = []
    dependencies = set()
    constraint_generation(decl, constraints, dependencies, declared, {})
    dependency_graph[decl.name] = dependencies
    constraints_graph[decl.name] = constraints

scc = tarjan_find_scc(dependency_graph)

for group in scc
    is_rec = group_is_recursive(group, dependency_graph)
    solve_constraints(declared, group, is_rec, constraints_graph)

for name, scheme in declared.items()
    print(name, "::", to_raw_type(scheme))

As shown above, constraint_generation produces the list of constraints and the dependency graph together. Once the declarations have been grouped into strongly connected components, we go directly to solving the constraints.
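The main program relies on `tarjan_find_scc` and `group_is_recursive`, which the post does not show. Here is one plausible Python version. It assumes the dependency graph maps each declaration name to the set of names it references. Tarjan's algorithm completes a component only after every component reachable from it. As a result, the groups come out with dependencies before their users, which is the order the solving loop needs.

```python
def tarjan_find_scc(graph):
    """Strongly connected components of graph (name -> set of referenced names).

    Tarjan's algorithm finishes a component only after every component
    reachable from it, so dependencies come out before their users.
    Names that are not keys of graph (builtins, typos) are skipped.
    """
    index, lowlink = {}, {}
    stack, on_stack, result = [], set(), []

    def visit(v):
        index[v] = lowlink[v] = len(index)
        stack.append(v)
        on_stack.add(v)
        for w in graph[v]:
            if w not in graph:
                continue
            if w not in index:
                visit(w)
                lowlink[v] = min(lowlink[v], lowlink[w])
            elif w in on_stack:
                lowlink[v] = min(lowlink[v], index[w])
        if lowlink[v] == index[v]:          # v is the root of a component
            group = []
            while True:
                w = stack.pop()
                on_stack.discard(w)
                group.append(w)
                if w == v:
                    break
            result.append(group)

    for v in graph:
        if v not in index:
            visit(v)
    return result

def group_is_recursive(group, graph):
    """Mutual recursion (several members) or direct self-reference."""
    return len(group) > 1 or group[0] in graph[group[0]]

deps = {"main": {"even"}, "even": {"odd"}, "odd": {"even"}, "id": set()}
for group in tarjan_find_scc(deps):
    print(group, group_is_recursive(group, deps))
# ['odd', 'even'] True
# ['main'] False
# ['id'] False
```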

Constraint generation

Separating constraint generation from solving helps with more than the dependency graph. It also simplifies the algorithm and makes it much easier to express in code.

constraint_generation = (a, cn, dep, declared, env):
    if isinstance(a, Def)
        # A definition is typed as the function it defines.
        ab = Abs(a.args, a.body)
        result = constraint_generation(ab, cn, dep, declared, env)
        cn.append(['def', a.name, result])
        return result
    if isinstance(a, Abs)
        env = dict(env)
        args = []
        for arg in a.args
            # Uses of the argument read 'output'; the function domain gets 'input'.
            input, output = new_port(cn, true)
            env[arg] = output
            args.append(input)
        body = constraint_generation(a.body, cn, dep, declared, env)
        return ['function', args, body]
    if isinstance(a, App)
        lhs = constraint_generation(a.lhs, cn, dep, declared, env)
        args = [lhs]
        for arg in a.args
            arg = constraint_generation(arg, cn, dep, declared, env)
            args.append(arg)
        input, output = new_port(cn)
        cn.append(['call', args, input])
        return output
    if isinstance(a, Var)
        if a.name in env
            return env[a.name]
        # A reference to another declaration is a dependency.
        if a.name not in declared
            dep.add(a.name)
        return ['instantiate', a.name]
    if isinstance(a, Lit)
        return ['literal', a.type]
    if isinstance(a, IfThenElse)
        cond = constraint_generation(a.cond, cn, dep, declared, env)
        cn.append(['<:', cond, ['literal', types.t_bool]])
        t = constraint_generation(a.t, cn, dep, declared, env)
        f = constraint_generation(a.f, cn, dep, declared, env)
        # Both branches flow into the same result port.
        input, output = new_port(cn)
        cn.append(['<:', t, input])
        cn.append(['<:', f, input])
        return output
    if isinstance(a, Let)
        lhs = constraint_generation(a.lhs, cn, dep, declared, env)
        env = dict(env)
        env[a.name] = lhs
        return constraint_generation(a.rhs, cn, dep, declared, env)
    assert false, ["implement constraint_generation", a]

We generate 'def', 'call' and '<:' constraints.

I left out let-polymorphism because I am convinced I won't use it much myself. You can implement it in MLsub by keeping track of free variables.
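For reference, this is what the free-variable bookkeeping looks like in a plain Hindley-Milner setting. The Python sketch works on simple type terms rather than on MLsub's automata. The names `type_vars`, `generalize` and `instantiate`, and the tuple representation, are assumptions for illustration. Variables that are free in the environment are left alone; every other variable is quantified and renamed on each use.

```python
import itertools

# Hypothetical representation: a type variable is a str, a constructed
# type is a tuple (constructor, *arguments), e.g. ("->", "a", "a").
fresh_names = (f"t{n}" for n in itertools.count())

def type_vars(t):
    if isinstance(t, str):
        return {t}
    return set().union(*(type_vars(arg) for arg in t[1:]))

def generalize(t, env_free):
    """Quantify every variable of t that is not free in the environment."""
    return (frozenset(type_vars(t) - env_free), t)

def instantiate(scheme):
    """Copy a scheme, renaming quantified variables to fresh ones."""
    quantified, t = scheme
    renaming = {v: next(fresh_names) for v in quantified}
    def copy(t):
        if isinstance(t, str):
            return renaming.get(t, t)
        return (t[0],) + tuple(copy(arg) for arg in t[1:])
    return copy(t)

# let id = fun x -> x: 'a' is not free in the environment, so it is quantified.
scheme = generalize(("->", "a", "a"), env_free=set())
print(instantiate(scheme), instantiate(scheme))   # ('->', 't0', 't0') ('->', 't1', 't1')
```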

Constraint solving

Constraint solving is divided into three routines. The first routine concatenates the constraints of a strongly connected group of declarations and resolves the recursion.

solve_constraints = (declared, group, is_rec, constraints_graph):
    visited = set()
    constraints = []
    for name in group
        constraints.extend(constraints_graph[name])
    if is_rec
        # Each member of a recursive group is first represented by a port.
        free  = set()
        ports = {}
        env = dict(declared)
        for name in group
            port = types.new_port(free)
            ports[name] = port
            env[name] = generalize(port.output, free)
        for name, result in run_constraints(constraints, env, visited)
            types.biunify([result, ports[name].input], visited)
        for name in group
            declared[name] = generalize(ports[name].output, null)
    else
        for name, result in run_constraints(constraints, declared, visited)
            declared[name] = generalize(result, null)

The second routine interprets every constraint it receives and yields the type of each definition.

run_constraints = (constraints, declared, visited):
    for c in constraints
        which = c[0]
        if which == "port"
        elif which == "call"
            args = []
            for arg in c[1]
                args.append(build_node(arg, declared, visited))
            input = c[2]
            call_c = types.new_callsite(types.t_call, [0], args.length)
            calltype = types.State(-1)
            for i in range(args.length)
                calltype.add_transition(types.dom[i], args[i])
            calltype.add_transition(types.cod, input)
            types.biunify([call_c, calltype], visited)
        elif which == "<:"
            o = build_node(c[1], declared, visited)
            i = build_node(c[2], declared, visited, -1)
            types.biunify([o, i], visited)
        elif which == "def"
            yield [c[1], build_node(c[2], declared, visited)]
        else
            assert false, which

The third routine produces types from the constructors in the constraints.

build_node = (node_desc, declared, visited, pol=+1):
    if isinstance(node_desc, types.State)
        return node_desc
    if node_desc[0] == "instantiate"
        scheme = instantiate_type(declared[node_desc[1]])
        return scheme.root
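    if node_desc[0] == "literal"
        result = types.State(pol)
        result.heads.add(node_desc[1])
        return result
    if node_desc[0] == "function"
        args = node_desc[1]
        resu = node_desc[2]
        functype = types.State(pol)
        # Arguments are contravariant, so they are built with the opposite polarity.
        for i in range(args.length)
            functype.add_transition(types.dom[i], build_node(args[i], declared, visited, -pol))
        functype.add_transition(types.cod, build_node(resu, declared, visited, pol))
        return functype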
    assert false, node_desc

In the second routine we already diverge from MLsub's implementation. In MLsub, a call would be solved through biunification as f <: argument -> output. Instead, we construct the biunification constraint callsite <: (f, argument) -> output. This makes function calls overloadable, so we do not need prior knowledge of how a specific type can be called.
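To make the difference concrete, here is a Python sketch that uses concrete types instead of automata. The hypothetical `resolve_call` decides how to call something only after it sees the callee's type. A function type is used directly; any other type is asked for a `call` entry in a method table. Under `f <: argument -> output`, only the first branch would exist. `Fn`, `METHODS` and the exact-equality check (standing in for subtyping) are simplifying assumptions.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Fn:
    params: tuple
    result: str

# Hypothetical method tables: the "call" entry tells how a value of
# that type behaves when it is applied to arguments.
METHODS = {
    "Vector": {"call": Fn(("int",), "float")},
}

def resolve_call(callee, args):
    """Result type of callee(args), where callee is a Fn or a type name."""
    if isinstance(callee, Fn):
        fn = callee                        # ordinary function: used directly
    else:
        fn = METHODS.get(callee, {}).get("call")
    if fn is None:
        raise TypeError(f"{callee} cannot be called")
    if len(fn.params) != len(args):
        raise TypeError("wrong number of arguments")
    for expected, got in zip(fn.params, args):
        if expected != got:                # exact match stands in for subtyping
            raise TypeError(f"expected {expected}, got {got}")
    return fn.result

print(resolve_call(Fn(("int",), "bool"), ["int"]))   # bool
print(resolve_call("Vector", ["int"]))                # float
```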

I skip instantiation, generalization and printing of types. They are important parts of MLsub, but the thesis already covers them. I only describe what these routines do and how they have to handle callsites.

Instantiation copies a generalized scheme. It has to copy any generalized callsites it finds in the scheme.

Generalization produces a generalized scheme from a type signature. It has to check whether any of the generalized contravariant variables can activate a callsite, and then copy and generalize those callsites.

Printing prints a generalized scheme. It has to detect callsites and present them as constraints in the printed type.

Coercive callsites

Coercion is limited to operator callsites. An operator is a function with no specific implementation of its own; instead, the implementation comes from the objects passed to it as arguments. When multiple arguments are used as selectors, those arguments have to coerce into the same type.
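The post does not show the `unique_coercion` helper that `update_coercion` relies on. Here is one way it could pick the common type, sketched in Python over a hypothetical coercion table. A head qualifies if every head in the set coerces into it. If more than one head qualifies, the result is treated as if none did.

```python
# Hypothetical coercion table: (source, target) pairs with an implicit
# coercion. Every type also coerces into itself.
COERCIONS = {("int", "rational"), ("int", "float"), ("rational", "float")}

def coerces(src, dst):
    return src == dst or (src, dst) in COERCIONS

def unique_coercion(heads):
    """The single head that every head in `heads` coerces into, else None.

    More than one candidate is treated like none, which keeps the
    selected coercion unique.
    """
    candidates = [t for t in heads if all(coerces(h, t) for h in heads)]
    return candidates[0] if len(candidates) == 1 else None

print(unique_coercion({"int", "rational"}))   # rational
print(unique_coercion({"int", "string"}))     # None
```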

Callsites are resolved immediately when we gain more information about their arguments through biunification. A callsite may be discarded when all the flow variables pointing to its arguments are eliminated.

Coercive callsites extend the states of the type automaton with a 'site' field. When 'site' is set, the state must appear among the states of the corresponding callsite.

class State
    +init = (self, pol=+1):
        self.pol = pol
        self.heads = set()
        self.flow = set()
        self.transitions = {}
        self.site = null     # OperatorSite this state belongs to, if any

The callsite itself consists of an expected shape, a set of states associated with that shape, and an activation record that memoizes which operations and coercions have been activated on the site.

class OperatorSite
    +init = (self, op, expect, states, activations):
        self.op = op
        self.expect = expect
        self.states = states
        self.activations = activations

    copy = (self, fn):
        op = self.op
        expect = self.expect
        states = {}
        activations = {}
        n_site = OperatorSite(op, expect, states, activations)
        for edge, state in self.states.items()
            states[edge] = n_state = fn(state)
            if state.site == self
                n_state.site = n_site
        for head, activation in self.activations.items()
            activations[head] = activation.copy(fn)
        return n_site

class Activation
    +init = (self, head, state, coerce):
        self.head = head     # head
        self.state = state   # State(+1)
        self.coerce = coerce # set([head, index])

    copy = (self, fn):
        head = self.head
        state = fn(self.state)
        coerce = self.coerce
        return Activation(head, state, coerce)
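The `fn` passed to `copy` is presumably a state-copying function supplied by instantiation. The Python sketch below shows how the pieces could fit together, assuming a memoizing copier (`make_copier` is a hypothetical name). Memoization keeps shared and cyclic states shared in the copy, and `OperatorSite.copy` re-points the `site` back-references to the new site. `Activation.copy` is left out for brevity.

```python
class State:
    def __init__(self, pol=+1):
        self.pol = pol
        self.heads = set()
        self.flow = set()
        self.transitions = {}   # edge label -> set of State
        self.site = None        # owning OperatorSite, if this is a site state

class OperatorSite:
    def __init__(self, op):
        self.op = op
        self.states = {}        # edge label -> State

    def copy(self, fn):
        n_site = OperatorSite(self.op)
        for edge, state in self.states.items():
            n_state = n_site.states[edge] = fn(state)
            if state.site is self:          # re-point the back-reference
                n_state.site = n_site
        return n_site

def make_copier():
    """Hypothetical `fn` for copy(): copies a state graph, memoizing so
    shared and cyclic structure stays shared in the copy."""
    memo = {}
    def fn(state):
        if state in memo:
            return memo[state]
        new = memo[state] = State(state.pol)   # register before recursing
        new.heads = set(state.heads)
        new.flow = {fn(s) for s in state.flow}
        new.transitions = {e: {fn(s) for s in ws}
                           for e, ws in state.transitions.items()}
        return new
    return fn

site = OperatorSite("+")
sink = State(+1)
sink.site = site
sink.heads.add("int")
site.states["dom0"] = sink
fn = make_copier()
n_site = site.copy(fn)
print(n_site.states["dom0"].site is n_site, fn(sink) is n_site.states["dom0"])   # True True
```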

Any callsite may be activated when new type heads are merged into a state that references the callsite. As a result, the biunification algorithm gets a small addition: the cc variable, which collects potential new activations of the affected callsites.

biunify = (pair, visited):
    if pair in visited
        return
    visited.add(pair)
    p, q = pair
    assert isinstance(p, State) and isinstance(q, State), [p,q]
    assert p.pol > q.pol, "polarity conflict"

    for x in p.heads
        for y in q.heads
            assert x == y, ["type error", x, y]

    cc = []
    for s in q.flow
        merge(s, p, cc)
    for s in p.flow
        merge(s, q, cc)

    for edge, ws in p.transitions.items()
        wu = q.transitions.get(edge, {})
        if edge.pol > 0
            for s in ws
                for u in wu
                    biunify([s, u], visited)
        else
            for s in ws
                for u in wu
                    biunify([u, s], visited)

    for site, sink, head in cc
        update_coercion(site, sink, head, visited)

merge = (dst, src, cc):
    assert dst.pol == src.pol, "polarity conflict"
    for head in src.heads
        dst.add_head(head, cc)
    for state in src.flow
        dst.flow.add(state)
        state.flow.add(dst)
    for edge, ws in src.transitions.items()
        try
            for w in ws
                dst.transitions[edge].add(w)
        except KeyError as _
            dst.transitions[edge] = set(ws)

Apart from the cc variable, the biunification and merging presented above are identical to those in the original MLsub.
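For experimenting outside Lever, here is a self-contained Python transcription of `State`, `merge` and `biunify` with the `cc` addition. It rests on several assumptions. Heads are compared with plain equality, as in the assert above. Each edge carries its variance in `pol`. `DOM0` and `COD` are hypothetical edge labels, and `update_coercion` is passed in as a callback. A head that newly reaches a state whose `site` is set gets reported once.

```python
class Edge:
    """Transition label: pol -1 for contravariant (domain), +1 for covariant."""
    def __init__(self, name, pol):
        self.name, self.pol = name, pol

DOM0, COD = Edge("dom0", -1), Edge("cod", +1)

class State:
    def __init__(self, pol=+1, site=None):
        self.pol = pol
        self.heads = set()
        self.flow = set()
        self.transitions = {}   # Edge -> set of State
        self.site = site        # set on callsite selector states

    def add_head(self, head, cc):
        if head not in self.heads:
            self.heads.add(head)
            if self.site is not None:       # new head on a selector: maybe activate
                cc.append((self.site, self, head))

    def add_transition(self, edge, state):
        self.transitions.setdefault(edge, set()).add(state)

def merge(dst, src, cc):
    assert dst.pol == src.pol, "polarity conflict"
    for head in list(src.heads):
        dst.add_head(head, cc)
    for state in list(src.flow):
        dst.flow.add(state)
        state.flow.add(dst)
    for edge, ws in list(src.transitions.items()):
        dst.transitions.setdefault(edge, set()).update(ws)

def biunify(p, q, visited, update_coercion):
    """Enforce p <: q for positive p and negative q."""
    if (p, q) in visited:
        return
    visited.add((p, q))
    assert p.pol > q.pol, "polarity conflict"
    for x in p.heads:
        for y in q.heads:
            if x != y:
                raise TypeError(f"type error: {x} vs {y}")
    cc = []
    for s in list(q.flow):
        merge(s, p, cc)
    for s in list(p.flow):
        merge(s, q, cc)
    for edge, ws in list(p.transitions.items()):
        for s in list(ws):
            for u in list(q.transitions.get(edge, ())):
                if edge.pol > 0:
                    biunify(s, u, visited, update_coercion)
                else:
                    biunify(u, s, visited, update_coercion)
    for site, sink, head in cc:
        update_coercion(site, sink, head, visited)

seen = []
def record(site, sink, head, visited):
    seen.append(head)

# A callsite argument and its selector sink, linked by flow edges.
arg, sink = State(-1), State(+1, site="site#1")
arg.flow.add(sink); sink.flow.add(arg)
lit = State(+1); lit.heads.add("int")
biunify(lit, arg, set(), record)
print(sink.heads, seen)   # {'int'} ['int']
```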

New callsites are shaped very much like functions. The difference is that the selector arguments of the callsite are annotated with the site itself.

new_callsite = (op, indices, argc):
    expect = get_function_header(argc)
    states = {}
    activations = {}
    site = OperatorSite(op, expect, states, activations)

    callsite = State(+1)
    for i in range(argc)
        if i in indices
            sink = State(+1)
            sink.site = site
        else
            sink = State(+1)
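        arg = State(-1)
        # Flow edge: heads reaching the argument are merged into its sink.
        arg.flow.add(sink)
        sink.flow.add(arg)
        states[dom[i]] = sink
        callsite.add_transition(dom[i], arg)
    sink = State(-1)
    out = State(+1)
    sink.flow.add(out)
    out.flow.add(sink)
    states[cod] = sink
    callsite.add_transition(cod, out)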
    return callsite

When a callsite receives a type head it has not seen before, it produces Cartesian sets combining the new head with the existing type heads of the other selector arguments.

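update_coercion = (site, pivot, head, visited):
    k = set([set([head])])
    doms = set()
    for edge, sink in site.states.items()
        continue if sink.site != site
        doms.add(edge)
        continue if sink == pivot
        # Extend every partial combination with each head of this selector.
        n = set()
        for t in sink.heads
            for h in k
                m = set(h)
                m.add(t)
                n.add(m)
        k = n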
    # k contains all the cartesian sets introduced by the new head.
    # doms contains all domain edges that coerce.

From each set, a type is chosen such that every other type in the set can be coerced into it. If there is more than one such type, we treat it as if there were none. This is an attempt to keep coercions unique.

    # Every new cartesian produces a potential new coercion.
    for heads in k
        coercions = set()
        # 'unique_coercion' determines the new target type.
        # it may fail, but in real implementation there
        # might be a default function listed for the operator.
        target = unique_coercion(heads)
        for edge in doms
            sink = site.states[edge]
            if sink == pivot
                coercions.add([head, edge])
            else
                for h in sink.heads
                    if h in heads
                        coercions.add([h, edge])

The callsite is activated with the selected type unless that has already happened, because repeated activations would not produce new information. The operator's implementation is instantiated from a method table and biunified with the callsite, with a twist.

        # Coercion is not re-applied if it has already been done.
        if target in site.activations
            activation = site.activations[target]
        else
            state = binary_operator(site.op, target, site.expect)
            assert state.heads == set([site.expect]), "type error"
            activation = Activation(target, state, set())
            site.activations[target] = activation

            # Non-selector edges are biunified directly.
            for edge, state in site.states.items()
                continue if edge in doms
                if edge.pol > 0
                    for w in activation.state.transitions[edge]
                        biunify([w, state], visited)
                else
                    for w in activation.state.transitions[edge]
                        biunify([state, w], visited)

The selector arguments are not biunified directly. Instead, coercions are inserted between the arguments of the activated function and the selector arguments. The code ensures that the coercion function does not impose constraints on the selector inputs.

        for h, edge in coercions
            continue if [h, edge] in activation.coerce
            activation.coerce.add([h, edge])
            k = simple_coercion(h, target)
            assert k.heads == set([get_function_header(1)]), "type error"
            # Removing heads from the domain, so the coercion does not
            # constrain the selector input.
            for w in k.transitions[dom[0]]
                w.heads = set()
            sink = site.states[edge]
            for w in k.transitions[dom[0]]
                biunify([sink, w], visited)
            for w in k.transitions[cod]
                for d in activation.state.transitions[edge]
                    biunify([w, d], visited)

This forms a coercion system around operator objects. A call on an operator produces an appropriate callsite, and the callsite produces the necessary coercions.
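Stripped of the automaton, the bookkeeping in `update_coercion` can be modelled at the level of head sets. In this Python sketch, `Site`, `unique_target` and the `("activate", ...)` and `("coerce", ...)` work items are hypothetical. Each new head is combined with the heads already seen at the other selectors. Every combination with a unique target activates the operator once per target, and each (head, selector) coercion is emitted once. When no target exists, the sketch does nothing, whereas a real implementation might fall back to a default function.

```python
from itertools import product

# Hypothetical coercion relation and target chooser.
COERCIONS = {("int", "rational")}

def unique_target(heads):
    cands = [t for t in heads
             if all(h == t or (h, t) in COERCIONS for h in heads)]
    return cands[0] if len(cands) == 1 else None

class Site:
    """Head-level model of an operator callsite with selector arguments."""
    def __init__(self, n_selectors):
        self.selector_heads = [set() for _ in range(n_selectors)]
        self.activations = {}    # target head -> set of (head, index) done

    def add_head(self, index, head):
        """Record a head arriving at selector `index`; return the new work."""
        if head in self.selector_heads[index]:
            return []
        self.selector_heads[index].add(head)
        # Cartesian product: the new head against every head already seen
        # at the other selectors.
        choices = [[head] if i == index else sorted(hs)
                   for i, hs in enumerate(self.selector_heads)]
        work = []
        for combo in product(*choices):
            target = unique_target(set(combo))
            if target is None:
                continue
            if target not in self.activations:   # activate the operator once
                self.activations[target] = set()
                work.append(("activate", target))
            done = self.activations[target]
            for i, h in enumerate(combo):
                if (h, i) not in done:           # each coercion only once
                    done.add((h, i))
                    work.append(("coerce", h, i, target))
        return work

site = Site(2)
print(site.add_head(0, "int"))        # []
print(site.add_head(1, "rational"))
# [('activate', 'rational'), ('coerce', 'int', 0, 'rational'), ('coerce', 'rational', 1, 'rational')]
```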

Coercion & Operator functions

One question remains: if function application is abstracted behind the same interface, how do we ensure that any callable object can be used as the implementation of an operator?

We have to retrieve the function signature by invoking a callsite when the type is analysed. The invocation then fills in the function signature.

Method tables associated to types

This system has to associate a method table with every type present in the language.
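As a value-level illustration, here is a Python sketch of method tables that hold operator implementations and coercions. The names `METHOD_TABLES`, `lookup_operator`, `lookup_coercion` and `apply_operator` are hypothetical. At the type level, `binary_operator` and `simple_coercion` in `update_coercion` would presumably consult tables of this kind.

```python
from fractions import Fraction

# Hypothetical method tables: per type, operator implementations and
# coercions into other types.
METHOD_TABLES = {
    "int": {"ops": {"+": lambda a, b: a + b}, "coerce": {"rational": Fraction}},
    "rational": {"ops": {"+": lambda a, b: a + b}, "coerce": {}},
}

def lookup_operator(op, target):
    return METHOD_TABLES[target]["ops"][op]

def lookup_coercion(src, dst):
    if src == dst:
        return lambda value: value          # identity coercion
    return METHOD_TABLES[src]["coerce"][dst]

def apply_operator(op, target, args):
    """Coerce every (type, value) argument into target, then run the operator."""
    impl = lookup_operator(op, target)
    return impl(*(lookup_coercion(t, target)(v) for t, v in args))

print(apply_operator("+", "rational", [("int", 1), ("rational", Fraction(1, 2))]))  # 3/2
```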