Coercive subtyping extension to MLsub type inference
I studied academic papers to find out whether there is a good way to implement coercions in a traditional Hindley-Milner type system. I realised the choices and options are few, and that there are not many satisfying ways to implement function overloading or coercive subtyping with type inference. So it seemed worthwhile to document my work.
MLsub is a subtyping type inference engine designed by Stephen Dolan for his thesis. I followed the implementation proposed in the thesis and extended it with coercive subtyping. For a short while I thought this was a mistake, because my implementation was clumsy.
The code in this post is written in Lever 0.9.0. In many parts that language reads like Python, and it has similar semantics. Be warned that subsequent Lever versions have had considerable modifications that make them unable to run this code. I find comfort in using my own language for language projects, although in retrospect it may hurt in cases like this.
The code sits in the suoml repository. The commit is 7981c14 at the time of writing.
MLsub
MLsub combines ML-style parametric polymorphism and subtyping. The thesis presents a theory for a subtyping equivalent of unification, naming it biunification, and shows how to implement biunification constraint elimination as rewriting of a finite-state automaton.
The approach works well for inferring types for many kinds of programs that would be untypeable in a Hindley-Milner type system, but the subsumption rules are strict and may require that the language designer predefines them for the user.
Separating constraint generation from solving
"Generalizing Hindley-Milner Type Inference Algorithms" presents half of the solution to the problem I had. The MLsub inferencer has the same glass jaw as Algorithm W when it comes to recursion: we have to retrieve a dependency graph before doing type inference. This results in two slightly related traversals through the program we want to type. If you separate constraint generation from type inference, you can retrieve the dependency graph while generating constraints.
The program starts by parsing the source code and populating an environment for type inference. Then it proceeds to figure out the types.
actions = object({
def = Def
app = App
let = Let
abs = Abs
if_then_else = IfThenElse
var = Var
string = (val):
return Lit(val, types.t_string)
int = (val):
return Lit(val, types.t_int)
rational = (val):
return Lit(val, types.t_rational)
})
parse = language.read_file("sample.fml")
decls = parse.traverse((rule, args, loc):
return getattr(actions, rule)(args...))
declared = {}
populate_with_globals(declared)
namespace = {}
for decl in decls
namespace[decl.name] = decl
dependency_graph = {}
constraints_graph = {}
for decl in decls
constraints = []
dependencies = set()
constraint_generation(decl, constraints, dependencies, declared, {})
dependency_graph[decl.name] = dependencies
constraints_graph[decl.name] = constraints
scc = tarjan_find_scc(dependency_graph)
for group in scc
is_rec = group_is_recursive(group, dependency_graph)
solve_constraints(declared, group, is_rec, constraints_graph)
for name, scheme in declared.items()
print(name, "::", to_raw_type(scheme))
You can see above that constraint_generation produces the list of constraints and the dependency graph together. Then we get directly to solving the constraints after grouping the declarations into strongly connected components.
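To make the grouping step concrete, here is a minimal Python sketch of how the two helpers used above could behave. The real tarjan_find_scc and group_is_recursive live in the suoml repository, so treat this as an illustration under assumptions rather than the actual implementation.

```python
# Sketch of SCC grouping over a dependency graph of declarations.
# graph maps each declaration name to the set of names it depends on.

def tarjan_find_scc(graph):
    """Return strongly connected components in reverse topological order,
    so each group is solved before the groups that depend on it."""
    index, lowlink = {}, {}
    stack, on_stack = [], set()
    sccs = []
    counter = [0]

    def strongconnect(v):
        index[v] = lowlink[v] = counter[0]
        counter[0] += 1
        stack.append(v)
        on_stack.add(v)
        for w in graph.get(v, ()):
            if w not in graph:
                continue  # external dependency, e.g. a predefined global
            if w not in index:
                strongconnect(w)
                lowlink[v] = min(lowlink[v], lowlink[w])
            elif w in on_stack:
                lowlink[v] = min(lowlink[v], index[w])
        if lowlink[v] == index[v]:
            scc = set()
            while True:
                w = stack.pop()
                on_stack.discard(w)
                scc.add(w)
                if w == v:
                    break
            sccs.append(scc)

    for v in graph:
        if v not in index:
            strongconnect(v)
    return sccs

def group_is_recursive(group, graph):
    """A group needs the recursive solving path when it is a genuine
    cycle, or a single declaration that depends on itself."""
    if len(group) > 1:
        return True
    (name,) = group
    return name in graph.get(name, ())

deps = {"even": {"odd"}, "odd": {"even"}, "main": {"even"}}
groups = tarjan_find_scc(deps)
# 'even' and 'odd' form one recursive group, solved before 'main'.
```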
Constraint generation
Separating constraint generation from constraint solving does not only help with that one thing. It also simplifies the algorithm itself and makes it much easier to represent in code.
constraint_generation = (a, cn, dep, declared, env):
if isinstance(a, Def)
ab = Abs(a.args, a.body)
result = constraint_generation(ab, cn, dep, declared, env)
cn.append(['def', a.name, result])
return result
if isinstance(a, Abs)
env = dict(env)
args = []
for arg in a.args
input, output = new_port(cn, true)
env[arg] = output
args.append(input)
body = constraint_generation(a.body, cn, dep, declared, env)
return ['function', args, body]
if isinstance(a, App)
lhs = constraint_generation(a.lhs, cn, dep, declared, env)
args = [lhs]
for arg in a.args
arg = constraint_generation(arg, cn, dep, declared, env)
args.append(arg)
input, output = new_port(cn)
cn.append(['call', args, input])
return output
if isinstance(a, Var)
if a.name in env
return env[a.name]
if a.name not in declared
dep.add(a.name)
return ['instantiate', a.name]
if isinstance(a, Lit)
return ['literal', a.type]
if isinstance(a, IfThenElse)
cond = constraint_generation(a.cond, cn, dep, declared, env)
cn.append(['<:', cond, ['literal', types.t_bool]])
t = constraint_generation(a.t, cn, dep, declared, env)
f = constraint_generation(a.f, cn, dep, declared, env)
input, output = new_port(cn)
cn.append(['<:', t, input])
cn.append(['<:', f, input])
return output
if isinstance(a, Let)
lhs = constraint_generation(a.lhs, cn, dep, declared, env)
env = dict(env)
env[a.name] = lhs
return constraint_generation(a.rhs, cn, dep, declared, env)
assert false, ["implement constraint_generation", a]
We generate 'def', 'call' and '<:' constraints.
I left out let-polymorphism because I am convinced I don't use it much myself. You can implement it in MLsub by keeping track of free variables.
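To illustrate the shape of the output, here is a toy Python rendition of constraint_generation for a small subset of the AST (variables, literals, applications). The AST classes and the port representation are simplified stand-ins for illustration, not the Lever originals.

```python
# Toy constraint generation: one traversal fills both the constraint
# list (cn) and the dependency set (dep).
from dataclasses import dataclass

@dataclass
class Var:
    name: str

@dataclass
class Lit:
    value: object
    type: str

@dataclass
class App:
    lhs: object
    args: list

fresh = [0]
def new_port(cn):
    # A port is a paired input/output; here just tagged fresh ids.
    fresh[0] += 1
    input, output = ("in", fresh[0]), ("out", fresh[0])
    cn.append(["port", input, output])
    return input, output

def constraint_generation(a, cn, dep, declared, env):
    if isinstance(a, Var):
        if a.name in env:
            return env[a.name]
        if a.name not in declared:
            dep.add(a.name)  # becomes an edge in the dependency graph
        return ["instantiate", a.name]
    if isinstance(a, Lit):
        return ["literal", a.type]
    if isinstance(a, App):
        args = [constraint_generation(a.lhs, cn, dep, declared, env)]
        for arg in a.args:
            args.append(constraint_generation(arg, cn, dep, declared, env))
        input, output = new_port(cn)
        cn.append(["call", args, input])
        return output
    raise NotImplementedError(a)

cn, dep = [], set()
result = constraint_generation(App(Var("f"), [Lit(1, "int")]), cn, dep, {}, {})
# dep == {"f"}: 'f' is neither local nor declared yet, so it is
# recorded as a dependency for the SCC grouping.
```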
Constraint solving
Constraint solving is divided into three routines. The first routine concatenates the constraints of a strongly connected group of declarations and resolves the recursion.
solve_constraints = (declared, group, is_rec, constraints_graph):
visited = set()
constraints = []
for name in group
constraints.extend(constraints_graph[name])
if is_rec
free = set()
ports = {}
env = dict(declared)
for name in group
port = types.new_port(free)
ports[name] = port
env[name] = generalize(port.output, free)
for name, result in run_constraints(constraints, env, visited)
types.biunify([result, ports[name].input], visited)
for name in group
declared[name] = generalize(ports[name].output, null)
else
for name, result in run_constraints(constraints, declared, visited)
declared[name] = generalize(result, null)
The second routine interprets every constraint it gets and returns typing schemes.
run_constraints = (constraints, declared, visited):
for c in constraints
which = c[0]
if which == "port"
c[1].add_flow(c[2])
elif which == "call"
args = []
for arg in c[1]
args.append(build_node(arg, declared, visited))
input = c[2]
call_c = types.new_callsite(types.t_call, [0], args.length)
calltype = types.State(-1)
calltype.heads.add(types.get_function_header(args.length))
for i in range(args.length)
calltype.add_transition(types.dom[i], args[i])
calltype.add_transition(types.cod, input)
types.biunify([call_c, calltype], visited)
elif which == "<:"
o = build_node(c[1], declared, visited)
i = build_node(c[2], declared, visited, -1)
types.biunify([o, i], visited)
elif which == "def"
yield [c[1], build_node(c[2], declared, visited)]
else
assert false, which
The third routine produces types from the constructors in the constraints.
build_node = (node_desc, declared, visited, pol=+1):
if isinstance(node_desc, types.State)
return node_desc
if node_desc[0] == "instantiate"
scheme = instantiate_type(declared[node_desc[1]])
return scheme.root
if node_desc[0] == "literal"
result = types.State(pol)
result.heads.add(node_desc[1])
return result
if node_desc[0] == "function"
args = node_desc[1]
resu = node_desc[2]
functype = types.State(pol)
functype.heads.add(types.get_function_header(args.length))
for i in range(args.length)
functype.add_transition(types.dom[i], args[i])
functype.add_transition(types.cod, resu)
return functype
assert false, node_desc
In the second routine we already diverge from MLsub's implementation. In MLsub, a call would be solved through biunification as f <: argument -> output. Instead we construct a biunification callsite <: (f, argument) -> output. This makes function calls overloadable, so that we do not need prior knowledge of how a specific type can be called.
I skip instantiation, generalization and the printing of types. They are important parts of MLsub, but the thesis already covers them. I only describe what these routines do and how they have to handle the callsites.
Instantiation copies a generalized scheme. It has to copy any generalized callsites that it finds in the scheme.
Generalization produces a generalized scheme from a type signature. It has to check whether any of the generalized contravariant variables can activate a callsite, and then copy and generalize those callsites.
Printing renders a generalized scheme. It has to detect callsites and annotate them as constraints in the type printout.
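The copying that instantiation does can be sketched as a memoized mapping over states. The State below is a simplified stand-in for illustration; the point is only that shared structure (flow edges, callsite references) must stay shared in the copy, which is also the role of the fn argument threaded through the copy routines shown further down.

```python
# Memoized state copying: each state is copied at most once, so
# structure shared in the original stays shared in the copy.

class State:
    def __init__(self, pol):
        self.pol = pol
        self.flow = set()

def make_copier():
    memo = {}
    def fn(state):
        if state in memo:
            return memo[state]
        # Register the copy before recursing, so cyclic flow works.
        memo[state] = n = State(state.pol)
        n.flow = set(fn(s) for s in state.flow)
        return n
    return fn

a = State(+1)
b = State(-1)
c = State(+1)
a.flow.add(b)
c.flow.add(b)   # a and c share the flow target b

fn = make_copier()
a2, c2 = fn(a), fn(c)
# a2 and c2 are fresh states that share one fresh copy of b.
```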
Coercive callsites
Coercion is limited to the callsite of an operator. An operator is a function with no specific implementation; instead the implementation comes from the objects passed into it as arguments. When multiple arguments are used as a selector, those arguments have to coerce into the same type.
Callsites are resolved immediately when we gain more information about their arguments through biunification. A callsite may be discarded when all the flow variables pointing to its arguments are eliminated.
The coercive callsites extend states in the type automaton. When the 'site' variable is set, the state must appear in the states of the corresponding callsite.
class State
+init = (self, pol=+1):
self.pol = pol
self.heads = set()
self.flow = set()
self.transitions = {}
self.site = null
The callsite itself consists of an expected shape, set of states associated to that expected shape, and an activation record that memoizes which operations and coercions have been activated on the site.
class OperatorSite
+init = (self, op, expect, states, activations):
self.op = op
self.expect = expect
self.states = states
self.activations = activations
copy = (self, fn):
op = self.op
expect = self.expect
states = {}
activations = {}
n_site = OperatorSite(op, expect, states, activations)
for edge, state in self.states.items()
states[edge] = n_state = fn(state)
if state.site == self
n_state.site = n_site
for head, activation in self.activations.items()
activations[head] = activation.copy(fn)
return n_site
class Activation
+init = (self, head, state, coerce):
self.head = head # head
self.state = state # State(+1)
self.coerce = coerce # set([head, index])
copy = (self, fn):
head = self.head
state = fn(self.state)
coerce = self.coerce
return Activation(head, state, coerce)
Any callsite may be activated when new type heads are merged into a state referencing the callsite. As a result, the biunification algorithm gains a small addition: the cc variable, which collects potential new activations on the callsites that were affected.
biunify = (pair, visited):
if pair in visited
return
visited.add(pair)
p, q = pair
assert isinstance(p, State) and isinstance(q, State), [p,q]
assert p.pol > q.pol, "polarity conflict"
for x in p.heads
for y in q.heads
assert x == y, ["type error", x, y]
cc = []
for s in q.flow
merge(s, p, cc)
for s in p.flow
merge(s, q, cc)
for edge, ws in p.transitions.items()
wu = q.transitions.get(edge, {})
if edge.pol > 0
for s in ws
for u in wu
biunify([s, u], visited)
else
for s in ws
for u in wu
biunify([u, s], visited)
for site, sink, head in cc
update_coercion(site, sink, head, visited)
merge = (dst, src, cc):
assert dst.pol == src.pol, "polarity conflict"
for head in src.heads
dst.add_head(head, cc)
for state in src.flow
dst.add_flow(state)
for edge, ws in src.transitions.items()
try
dst.transitions[edge].update(ws)
except KeyError as _
dst.transitions[edge] = set(ws)
Otherwise the biunification and merging as presented above are identical to those in original MLsub.
When new callsites are created, they are shaped very much like functions, with the difference that the selector arguments in the callsite are annotated with the site itself.
new_callsite = (op, indices, argc):
expect = get_function_header(argc)
states = {}
activations = {}
site = OperatorSite(op, expect, states, activations)
callsite = State(+1)
callsite.heads.add(expect)
for i in range(argc)
if i in indices
sink = State(+1)
sink.site = site
else
sink = State(+1)
arg = State(-1)
arg.add_flow(sink)
states[dom[i]] = sink
callsite.add_transition(dom[i], arg)
sink = State(-1)
out = State(+1)
out.add_flow(sink)
states[cod] = sink
callsite.add_transition(cod, out)
return callsite
When a callsite receives a type head it has not seen before, it produces Cartesian sets combining the new type head with the existing type heads of the other arguments.
update_coercion = (site, pivot, head, visited):
k = set([set([head])])
doms = set()
for edge, sink in site.states.items()
continue if sink.site != site
doms.add(edge)
continue if sink == pivot
n = set()
for t in sink.heads
for h in k
n.add(h.union([t]))
k = n
# k contains all the cartesian sets introduced by the new head.
# doms contains all domain edges that coerce.
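The Cartesian construction above can be sketched in plain Python, using sets of heads in place of the automaton states; the head names are illustrative.

```python
# Combine a newly seen head with the known heads of every other
# coercing argument, producing one candidate head-set per combination.

def cartesian_sets(new_head, other_args_heads):
    k = {frozenset([new_head])}
    for heads in other_args_heads:
        # An argument with no known heads yet empties k, postponing
        # coercion until more information arrives.
        k = {h | {t} for t in heads for h in k}
    return k

cartesian_sets("int", [{"int", "rational"}])
# -> {frozenset({"int"}), frozenset({"int", "rational"})}
```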
From each set a type is chosen such that every other type in the set can be coerced into it. If there is more than one such type, we treat it as if there were none. This is an attempt to ensure that coercions stay unique.
# Every new cartesian produces a potential new coercion.
for heads in k
coercions = set()
# 'unique_coercion' determines the new target type.
# it may fail, but in real implementation there
# might be a default function listed for the operator.
target = unique_coercion(heads)
for edge in doms
sink = site.states[edge]
if sink == pivot
coercions.add([head, edge])
else
for h in sink.heads
if h in heads
coercions.add([h, edge])
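Here is a sketch of what the unique_coercion used above might do, with a hypothetical coercion table; the int-to-rational coercion is an assumption for illustration.

```python
# A head qualifies as the target when every other head in the set
# coerces into it; zero or several candidates count as failure.

coercible = {("int", "rational")}   # int coerces into rational (assumed)

def coerces(src, dst):
    return src == dst or (src, dst) in coercible

def unique_coercion(heads):
    candidates = [t for t in heads
                  if all(coerces(h, t) for h in heads)]
    if len(candidates) == 1:
        return candidates[0]
    return None  # ambiguous or impossible

unique_coercion({"int", "rational"})   # -> "rational"
unique_coercion({"int", "string"})     # -> None, no common target
```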
The callsite is activated with the selected type unless that has already happened, because repeated activations would not produce new details. The implementation for the operator is instantiated from a method table and biunified to the callsite, with a twist.
# Coercion is not re-applied if it has been already done.
if target in site.activations
activation = site.activations[target]
else
state = binary_operator(site.op, target, site.expect)
assert state.heads == set([site.expect])
"type error"
activation = Activation(target, state, set())
for edge, state in site.states.items()
continue if edge in doms
if edge.pol > 0
for w in activation.state.transitions[edge]
biunify([w, state], visited)
else
for w in activation.state.transitions[edge]
biunify([state, w], visited)
The selector arguments aren't biunified directly. Instead, coercions are inserted between the arguments of the activated function and the selector argument. It is ensured that the coercion function does not impose constraints on the selector inputs.
for h, edge in coercions
continue if [h, edge] in activation.coerce
activation.coerce.add([h, edge])
k = simple_coercion(h, target)
assert k.heads == set([get_function_header(1)])
"type error"
# Removing heads from the domain.
for w in k.transitions[dom[0]]
w.heads.clear()
sink = site.states[edge]
for w in k.transitions[dom[0]]
biunify([sink, w], visited)
for w in k.transitions[cod]
for d in activation.state.transitions[edge]
biunify([w, d], visited)
This system forms a coercion system around operator objects: a call on an operator produces an appropriate callsite, and the callsite produces the necessary coercions.
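At the value level, the net effect resembles the following Python sketch. The types and the int-to-rational coercion are assumptions for illustration, and the real mechanism of course operates on types during inference, not on values at run time.

```python
# Value-level analogy of the callsite machinery: pick the unique
# target type, coerce the selector arguments into it, then run the
# implementation from that type's method table.
from fractions import Fraction

implementations = {           # per-type implementations of '+'
    int: lambda a, b: a + b,
    Fraction: lambda a, b: a + b,
}
coercions = {(int, Fraction): Fraction}   # int -> rational (assumed)

def coerce(value, target):
    if type(value) is target:
        return value
    return coercions[type(value), target](value)

def add(a, b):
    heads = {type(a), type(b)}
    # Unique head that every other head coerces into.
    targets = [t for t in heads
               if all(h is t or (h, t) in coercions for h in heads)]
    assert len(targets) == 1, "no unique coercion"
    target = targets[0]
    return implementations[target](coerce(a, target), coerce(b, target))

add(1, 2)                  # stays within int
add(1, Fraction(1, 2))     # coerces 1 to Fraction, yields Fraction(3, 2)
```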
Coercion & Operator functions
There remains one question: if function application is abstracted under the same interface, how do we ensure that any callable object can be used as an implementation for an operator?
We have to retrieve the function signature by invoking a callsite when the type is analysed. The invocation then fills in the function signature.
Method tables associated to types
This system has to associate a method table to every type present in the language.
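One possible shape for such method tables, sketched in Python; the heads, operator names and implementation labels are illustrative stand-ins for the instantiable typing schemes the real system would hold.

```python
# Each type head carries a table mapping operator names to the
# implementation the callsite machinery would instantiate.

method_tables = {
    "int":      {"+": "int_add",      "*": "int_mul"},
    "rational": {"+": "rational_add", "*": "rational_mul"},
}

def lookup_operator(op, head):
    # What a binary_operator-style lookup could consult after the
    # target head has been selected.
    table = method_tables.get(head)
    if table is None or op not in table:
        raise TypeError("no implementation of " + op + " for " + head)
    return table[op]

lookup_operator("+", "rational")   # -> "rational_add"
```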