# Coercive subtyping extension to MLsub type inference

After studying academic papers to find out whether there is a good way to add coercions to a traditional Hindley-Milner type system, I realised that the options are few and that there are not many satisfying ways to implement function overloading or coercive subtyping with type inference. It seemed worthwhile to document my work.

MLsub is a subtyping type inference engine designed by Stephen Dolan for his thesis. I followed the implementation proposed in the thesis and extended it with coercive subtyping. For a short while I thought this was a mistake because of my clumsy implementation.

The code in this post is written in Lever 0.9.0. In many places that language reads like Python, and it has similar semantics. Be warned that subsequent Lever versions have had considerable modifications that make them unable to run this code. I find comfort in using my own language for language projects, although in retrospect it may be harmful in cases like this.

The code sits in the suoml repository. The commit is 7981c14 at the time of writing.

## MLSub

MLsub combines ML-style parametric polymorphism and subtyping. The thesis presents a subtyping equivalent of unification, named biunification, and shows how to implement biunification constraint elimination as the rewriting of a finite-state automaton.

The approach works well for inferring types for many kinds of programs that would be untypeable in a Hindley-Milner type system, but the subsumption rules are strict and may require the language designer to predefine them for the user.

## Separating constraint generation from solving

"Generalizing Hindley-Milner Type Inference Algorithms" presents half of the solution to the problem I had. The MLsub inferencer has the same glass jaw as Algorithm W when it comes to recursion: we have to retrieve a dependency graph before doing type inference. That results in two slightly related traversals through the program we want to type-infer. If you separate constraint generation from type inference, you can retrieve the dependency graph while generating constraints.

The program starts by parsing the source code and populating an environment for type inferencing. Then it proceeds to figure out the types.

```
actions = object({
    def = Def
    app = App
    let = Let
    abs = Abs
    if_then_else = IfThenElse
    var = Var
    string = (val):
        return Lit(val, types.t_string)
    int = (val):
        return Lit(val, types.t_int)
    rational = (val):
        return Lit(val, types.t_rational)
})

parse = language.read_file("sample.fml")
decls = parse.traverse((rule, args, loc):
    return getattr(actions, rule)(args...))

declared = {}
populate_with_globals(declared)

namespace = {}
for decl in decls
    namespace[decl.name] = decl

dependency_graph = {}
constraints_graph = {}
for decl in decls
    constraints = []
    dependencies = set()
    constraint_generation(decl, constraints, dependencies, declared, {})
    dependency_graph[decl.name] = dependencies
    constraints_graph[decl.name] = constraints

scc = tarjan_find_scc(dependency_graph)
for group in scc
    is_rec = group_is_recursive(group, dependency_graph)
    solve_constraints(declared, group, is_rec, constraints_graph)

for name, scheme in declared.items()
    print(name, "::", to_raw_type(scheme))
```

You can see above that `constraint_generation` produces the list of constraints and the dependency graph together. Once the declarations have been grouped into strongly connected components, we can proceed directly to solving the constraints.
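The grouping step can be sketched in ordinary Python for illustration (the post's own code is Lever; `tarjan_find_scc` and `group_is_recursive` below mirror the names used above but are not the suoml implementation):

```python
# Hypothetical Python rendition of the SCC grouping, not suoml's code.
def tarjan_find_scc(graph):
    index = {}        # discovery index of each node
    low = {}          # lowest index reachable from the node
    stack = []
    on_stack = set()
    sccs = []
    counter = [0]

    def strongconnect(v):
        index[v] = low[v] = counter[0]
        counter[0] += 1
        stack.append(v)
        on_stack.add(v)
        for w in graph.get(v, ()):
            if w not in graph:
                continue  # a reference to a predefined global
            if w not in index:
                strongconnect(w)
                low[v] = min(low[v], low[w])
            elif w in on_stack:
                low[v] = min(low[v], index[w])
        if low[v] == index[v]:
            group = []
            while True:
                w = stack.pop()
                on_stack.discard(w)
                group.append(w)
                if w == v:
                    break
            sccs.append(group)

    for v in graph:
        if v not in index:
            strongconnect(v)
    return sccs  # dependencies come out before their dependents

def group_is_recursive(group, graph):
    # A group is recursive if it is mutually recursive (more than one
    # member) or its single member refers to itself.
    return len(group) > 1 or group[0] in graph.get(group[0], ())
```

Tarjan's algorithm conveniently emits the components in reverse topological order, so solving them in emission order means every declaration's dependencies are solved before the declaration itself.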

## Constraint generation

Separating constraint generation from solving helps in more ways than one: it also simplifies the algorithm and makes it much easier to represent in code.

```
constraint_generation = (a, cn, dep, declared, env):
    if isinstance(a, Def)
        ab = Abs(a.args, a.body)
        result = constraint_generation(ab, cn, dep, declared, env)
        cn.append(['def', a.name, result])
        return result
    if isinstance(a, Abs)
        env = dict(env)
        args = []
        for arg in a.args
            input, output = new_port(cn, true)
            env[arg] = output
            args.append(input)
        body = constraint_generation(a.body, cn, dep, declared, env)
        return ['function', args, body]
    if isinstance(a, App)
        lhs = constraint_generation(a.lhs, cn, dep, declared, env)
        args = [lhs]
        for arg in a.args
            arg = constraint_generation(arg, cn, dep, declared, env)
            args.append(arg)
        input, output = new_port(cn)
        cn.append(['call', args, input])
        return output
    if isinstance(a, Var)
        if a.name in env
            return env[a.name]
        if a.name not in declared
            dep.add(a.name)
        return ['instantiate', a.name]
    if isinstance(a, Lit)
        return ['literal', a.type]
    if isinstance(a, IfThenElse)
        cond = constraint_generation(a.cond, cn, dep, declared, env)
        cn.append(['<:', cond, ['literal', types.t_bool]])
        t = constraint_generation(a.t, cn, dep, declared, env)
        f = constraint_generation(a.f, cn, dep, declared, env)
        input, output = new_port(cn)
        cn.append(['<:', t, input])
        cn.append(['<:', f, input])
        return output
    if isinstance(a, Let)
        lhs = constraint_generation(a.lhs, cn, dep, declared, env)
        env = dict(env)
        env[a.name] = lhs
        return constraint_generation(a.rhs, cn, dep, declared, env)
    assert false, ["implement constraint_generation", a]
```

We generate 'def', 'call' and '<:' constraints; in addition, `new_port` emits 'port' constraints that establish flow edges.
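To make the constraint shapes concrete, here is a stripped-down Python rendition covering only variables, literals and if-then-else. It is hypothetical and only mimics the structure above; in particular this `new_port` just invents fresh variable pairs rather than building automaton states:

```python
import itertools

fresh = itertools.count()

def new_port(cn):
    # A fresh (input, output) pair; the real new_port builds automaton
    # states and the 'port' constraint later becomes a flow edge.
    input, output = ("in", next(fresh)), ("out", next(fresh))
    cn.append(["port", input, output])
    return input, output

def constraint_generation(a, cn, dep, declared, env):
    kind = a[0]
    if kind == "var":
        name = a[1]
        if name in env:
            return env[name]
        if name not in declared:
            dep.add(name)
        return ["instantiate", name]
    if kind == "lit":
        return ["literal", a[1]]
    if kind == "if":
        cond = constraint_generation(a[1], cn, dep, declared, env)
        cn.append(["<:", cond, ["literal", "bool"]])
        t = constraint_generation(a[2], cn, dep, declared, env)
        f = constraint_generation(a[3], cn, dep, declared, env)
        input, output = new_port(cn)
        cn.append(["<:", t, input])
        cn.append(["<:", f, input])
        return output
    raise AssertionError(a)

cn, dep = [], set()
result = constraint_generation(
    ["if", ["var", "c"], ["lit", "int"], ["lit", "int"]],
    cn, dep, {}, {})
```

Running this on the equivalent of `if c then 1 else 2` records `c` as a dependency and emits one 'port' constraint and three '<:' constraints.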

I left out let-polymorphism because I am convinced I would not do much with it myself. You can implement it in MLsub by keeping track of free variables.

## Constraint solving

Constraint solving is divided into three routines. The first routine concatenates the constraints of a strongly connected group of declarations and resolves the recursion.

```
solve_constraints = (declared, group, is_rec, constraints_graph):
    visited = set()
    constraints = []
    for name in group
        constraints.extend(constraints_graph[name])
    if is_rec
        free = set()
        ports = {}
        env = dict(declared)
        for name in group
            port = types.new_port(free)
            ports[name] = port
            env[name] = generalize(port.output, free)
        for name, result in run_constraints(constraints, env, visited)
            types.biunify([result, ports[name].input], visited)
        for name in group
            declared[name] = generalize(ports[name].output, null)
    else
        for name, result in run_constraints(constraints, declared, visited)
            declared[name] = generalize(result, null)
```

The second routine interprets every constraint it gets and yields the resulting type for each definition.

```
run_constraints = (constraints, declared, visited):
    for c in constraints
        which = c[0]
        if which == "port"
            c[1].add_flow(c[2])
        elif which == "call"
            args = []
            for arg in c[1]
                args.append(build_node(arg, declared, visited))
            input = c[2]
            call_c = types.new_callsite(types.t_call, [0], args.length)
            calltype = types.State(-1)
            calltype.heads.add(types.get_function_header(args.length))
            for i in range(args.length)
                calltype.add_transition(types.dom[i], args[i])
            calltype.add_transition(types.cod, input)
            types.biunify([call_c, calltype], visited)
        elif which == "<:"
            o = build_node(c[1], declared, visited)
            i = build_node(c[2], declared, visited, -1)
            types.biunify([o, i], visited)
        elif which == "def"
            yield [c[1], build_node(c[2], declared, visited)]
        else
            assert false, which
```

The third routine produces types from the constructors in the constraints.

```
build_node = (node_desc, declared, visited, pol=+1):
    if isinstance(node_desc, types.State)
        return node_desc
    if node_desc[0] == "instantiate"
        scheme = instantiate_type(declared[node_desc[1]])
        return scheme.root
    if node_desc[0] == "literal"
        result = types.State(pol)
        result.heads.add(node_desc[1])
        return result
    if node_desc[0] == "function"
        args = node_desc[1]
        resu = node_desc[2]
        functype = types.State(pol)
        functype.heads.add(types.get_function_header(args.length))
        for i in range(args.length)
            functype.add_transition(types.dom[i], args[i])
        functype.add_transition(types.cod, resu)
        return functype
    assert false, node_desc
```

The second routine is already where we diverge from MLsub's implementation. In MLsub, a call would be solved through biunification as `f <: argument -> output`. Instead we construct a biunification `callsite <: (f, argument) -> output`. This makes function calls overloadable, so we do not need prior knowledge of how a specific type can be called.

I skip instantiation, generalization and the printing of types. They are important parts of MLsub, but the thesis already covers them. I only describe what these routines do and how they have to handle the callsites.

Instantiation copies a generalized schema. It also has to copy any generalized callsites it finds in the schema.

Generalization produces a generalized schema from a type signature. It has to check whether any of the generalized contravariant variables can activate a callsite, and then copy and generalize those callsites.

Printing prints a generalized schema. It has to detect callsites and annotate them as constraints in the type printout.

## Coercive callsites

Coercion is limited to the callsite of an operator. An operator is a function with no specific implementation of its own; instead the implementation comes from the objects passed into it as arguments. When multiple arguments are used as selectors, those arguments have to coerce into the same type.

Callsites are resolved immediately when we gain more information about their arguments through biunification. A callsite may be discarded when all the flow variables pointing to its arguments are eliminated.

The coercive callsites extend states in the type automaton. When the 'site' variable is set, the state must appear in the states of the corresponding callsite.

```
class State
    +init = (self, pol=+1):
        self.pol = pol
        self.heads = set()
        self.flow = set()
        self.transitions = {}
        self.site = null
```

The callsite itself consists of an expected shape, set of states associated to that expected shape, and an activation record that memoizes which operations and coercions have been activated on the site.

```
class OperatorSite
    +init = (self, op, expect, states, activations):
        self.op = op
        self.expect = expect
        self.states = states
        self.activations = activations
    copy = (self, fn):
        op = self.op
        expect = self.expect
        states = {}
        activations = {}
        n_site = OperatorSite(op, expect, states, activations)
        for edge, state in self.states.items()
            states[edge] = n_state = fn(state)
            if state.site == self
                n_state.site = n_site
        for head, activation in self.activations.items()
            activations[head] = activation.copy(fn)
        return n_site
```

```
class Activation
    +init = (self, head, state, coerce):
        self.head = head     # head
        self.state = state   # State(+1)
        self.coerce = coerce # set([head, index])
    copy = (self, fn):
        head = self.head
        state = fn(self.state)
        coerce = self.coerce
        return Activation(head, state, coerce)
```

Any callsite may be activated when new type heads are merged into a state referencing the callsite. As a result, the biunification algorithm gains a small addition: the `cc` variable, which collects potential new activations for the callsites that were affected.

```
biunify = (pair, visited):
    if pair in visited
        return
    visited.add(pair)
    p, q = pair
    assert isinstance(p, State) and isinstance(q, State), [p,q]
    assert p.pol > q.pol, "polarity conflict"
    for x in p.heads
        for y in q.heads
            assert x == y, ["type error", x, y]
    cc = []
    for s in q.flow
        merge(s, p, cc)
    for s in p.flow
        merge(s, q, cc)
    for edge, ws in p.transitions.items()
        wu = q.transitions.get(edge, {})
        if edge.pol > 0
            for s in ws
                for u in wu
                    biunify([s, u], visited)
        else
            for s in ws
                for u in wu
                    biunify([u, s], visited)
    for site, sink, head in cc
        update_coercion(site, sink, head, visited)

merge = (dst, src, cc):
    assert dst.pol == src.pol, "polarity conflict"
    for head in src.heads
        dst.add_head(head, cc)
    for state in src.flow
        dst.add_flow(state)
    for edge, ws in src.transitions.items()
        try
            dst.transitions[edge].update(ws)
        except KeyError as _
            dst.transitions[edge] = set(ws)
```

Otherwise the biunification and merging presented above are identical to those in the original MLsub.
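In Python terms the merge step amounts to pointwise unions. A minimal sketch, using a hypothetical stand-in `State` and omitting the callsite bookkeeping done through `cc`:

```python
class State:
    # Minimal stand-in for an automaton state: polarity, type heads,
    # flow edges and labeled transitions.
    def __init__(self, pol=+1):
        self.pol = pol
        self.heads = set()
        self.flow = set()
        self.transitions = {}

def merge(dst, src):
    # Union src into dst pointwise. The real merge goes through
    # add_head, which also records new heads into `cc` so affected
    # callsites can be re-examined.
    assert dst.pol == src.pol, "polarity conflict"
    dst.heads |= src.heads
    dst.flow |= src.flow
    for edge, ws in src.transitions.items():
        dst.transitions.setdefault(edge, set()).update(ws)
```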

When new callsites are created, they are shaped very much like functions, with the difference that the selector arguments in the callsite are annotated with the site itself.

```
new_callsite = (op, indices, argc):
    expect = get_function_header(argc)
    states = {}
    activations = {}
    site = OperatorSite(op, expect, states, activations)
    callsite = State(+1)
    callsite.heads.add(expect)
    for i in range(argc)
        if i in indices
            sink = State(+1)
            sink.site = site
        else
            sink = State(+1)
        arg = State(-1)
        arg.add_flow(sink)
        states[dom[i]] = sink
        callsite.add_transition(dom[i], arg)
    sink = State(-1)
    out = State(+1)
    out.add_flow(sink)
    states[cod] = sink
    callsite.add_transition(cod, out)
    return callsite
```

When a callsite receives a type head it has not seen before, it produces cartesian sets combining the new type head with the existing type heads of the other arguments.
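A small Python sketch of that cartesian-set construction (the name `cartesian_sets` is illustrative only; frozensets are used so the head sets can nest):

```python
def cartesian_sets(new_head, other_arg_heads):
    # Start from the new head alone, then combine it with every head
    # already seen on each of the other selector arguments.
    k = {frozenset([new_head])}
    for heads in other_arg_heads:   # one entry per non-pivot selector
        k = {h | frozenset([t]) for h in k for t in heads}
    return k
```

For a binary operator whose other selector argument already carries the heads int and rational, a newly arriving rational head produces the sets {rational, int} and {rational}.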

```
update_coercion = (site, pivot, head, visited):
    k = set([set([head])])
    doms = set()
    for edge, sink in site.states.items()
        continue if sink.site != site
        doms.add(edge)
        continue if sink == pivot
        n = set()
        for t in sink.heads
            for h in k
                n.add(h.union([t]))
        k = n
    # k contains all the cartesian sets introduced by the new head.
    # doms contains all domain edges that coerce.
```

From each set a type is chosen such that every other type in the set can be coerced into it. If there is more than one candidate, we treat it as if there were none. This is an attempt to ensure that coercions stay unique.
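One possible reading of that rule in Python (`coercible`, a relation of allowed coercions, is a hypothetical stand-in and not part of suoml):

```python
# Allowed coercions: (source, target) pairs. Hypothetical example.
coercible = {("int", "rational")}   # int may coerce to rational

def unique_coercion(heads):
    # Pick the single head that every other head can coerce into;
    # return None when zero or several candidates exist.
    candidates = [
        t for t in heads
        if all(s == t or (s, t) in coercible for s in heads)
    ]
    if len(candidates) == 1:
        return candidates[0]
    return None
```

With this relation, {int, rational} selects rational, while {int, string} has no target, and mutually coercible pairs would give two candidates and thus no target either.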

```
# Every new cartesian set produces a potential new coercion.
for heads in k
    coercions = set()
    # 'unique_coercion' determines the new target type.
    # It may fail, but in a real implementation there
    # might be a default function listed for the operator.
    target = unique_coercion(heads)
    for edge in doms
        sink = site.states[edge]
        if sink == pivot
            coercions.add([head, edge])
        else
            for h in sink.heads
                if h in heads
                    coercions.add([h, edge])
```

The callsite is activated with the selected type, unless that has already happened, because repeated activations would not produce new information. The implementation for the operator is instantiated from a method table and biunified against the callsite, with a twist.

```
# Coercion is not re-applied if it has already been done.
if target in site.activations
    activation = site.activations[target]
else
    state = binary_operator(site.op, target, site.expect)
    assert state.heads == set([site.expect]), "type error"
    activation = Activation(target, state, set())
    for edge, state in site.states.items()
        continue if edge in doms
        if edge.pol > 0
            for w in activation.state.transitions[edge]
                biunify([w, state], visited)
        else
            for w in activation.state.transitions[edge]
                biunify([state, w], visited)
```

The selector arguments aren't biunified directly. Instead, coercions are inserted between the arguments of the activated function and the selector argument. It is ensured that the coercion function does not impose constraints on the selector inputs.

```
for h, edge in coercions
    continue if [h, edge] in activation.coerce
    activation.coerce.add([h, edge])
    k = simple_coercion(h, target)
    assert k.heads == set([get_function_header(1)]), "type error"
    # Removing heads from the domain.
    for w in k.transitions[dom[0]]
        w.heads.clear()
    sink = site.states[edge]
    for w in k.transitions[dom[0]]
        biunify([sink, w], visited)
    for w in k.transitions[cod]
        for d in activation.state.transitions[edge]
            biunify([w, d], visited)
```

This forms a coercion system around operator objects: a call on an operator produces an appropriate callsite, and the callsite produces the necessary coercions.

## Coercion & Operator functions

One question remains: if function application is abstracted under the same interface, how do we ensure that any callable object can be used as an implementation for an operator?

We have to retrieve the function signature by invoking a callsite when the type is analysed. The invocation then fills in the function signature.

## Method tables associated to types

This system has to associate a method table with every type present in the language.
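A minimal sketch of what such an association could look like: a hypothetical per-type registry mapping operator names to implementations, consulted when a callsite is activated with a target type. (suoml's `binary_operator` additionally takes the expected function header; this is a simplification.)

```python
# Hypothetical per-type method tables; the values stand in for
# operator implementations.
method_tables = {
    "int":      {"+": "int_add"},
    "rational": {"+": "rational_add"},
}

def binary_operator(op, target):
    # Look up the implementation of `op` for the coercion target type.
    table = method_tables.get(target)
    if table is None or op not in table:
        raise TypeError("no method for {} on {}".format(op, target))
    return table[op]
```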