Gotta update() expression too.

If type vars get into the espression you have to keep them in sync with
the unification or you can lose information.


Some combinators can put symbols on the expression, you have to convert
those to type checkers or, as a hack, just look them up and run them.
This lets definitions work(-ish), ...
This commit is contained in:
Simon Forman 2018-06-27 22:26:27 -07:00
parent fc45727008
commit 6ca59847ab
3 changed files with 72 additions and 31 deletions

View File

@ -15,10 +15,11 @@ from joy.parser import Symbol
from joy.utils.stack import concat as CONCAT from joy.utils.stack import concat as CONCAT
from joy.utils.types import ( from joy.utils.types import (
AnyJoyType, A, AnyJoyType, A,
BooleanJoyType, B,
C, C,
DEFS, DEFS,
doc_from_stack_effect, doc_from_stack_effect,
FloatJoyType, FloatJoyType, F,
JoyTypeError, JoyTypeError,
NumberJoyType, N, NumberJoyType, N,
StackJoyType, S, StackJoyType, S,
@ -121,9 +122,9 @@ class CombinatorJoyType(FunctionJoyType):
if self.expect is None: if self.expect is None:
return f return f
g = self.expect, self.expect g = self.expect, self.expect
new_f = list(C(f, g)) new_f = list(C(f, g, ()))
assert len(new_f) == 1, repr(new_f) assert len(new_f) == 1, repr(new_f)
return new_f[0] return new_f[0][1]
@ -195,28 +196,28 @@ def _lil_uni(u, v, s):
raise JoyTypeError('Cannot unify %r and %r.' % (u, v)) raise JoyTypeError('Cannot unify %r and %r.' % (u, v))
def compose(f, g): def compose(f, g, e):
(f_in, f_out), (g_in, g_out) = f, g (f_in, f_out), (g_in, g_out) = f, g
for s in unify(g_in, f_out): for s in unify(g_in, f_out):
yield update(s, (f_in, g_out)) yield update(s, (e, (f_in, g_out)))
def C(f, g): def C(f, g, e):
f, g = relabel(f, g) f, g = relabel(f, g)
for fg in compose(f, g): for fg in compose(f, g, e):
yield delabel(fg) yield delabel(fg)
def meta_compose(F, G): def meta_compose(F, G, e):
for f, g in product(F, G): for f, g in product(F, G):
try: try:
for result in C(f, g): yield result for result in C(f, g, e): yield result
except JoyTypeError: except JoyTypeError:
pass pass
def MC(F, G): def MC(F, G, e):
res = sorted(set(meta_compose(F, G))) res = sorted(set(meta_compose(F, G, e)))
if not res: if not res:
raise JoyTypeError('Cannot unify %r and %r.' % (F, G)) raise JoyTypeError('Cannot unify %r and %r.' % (F, G))
return res return res
@ -236,16 +237,17 @@ def infer(e, F=ID):
n, e = e n, e = e
if isinstance(n, SymbolJoyType): if isinstance(n, SymbolJoyType):
res = flatten(infer(e, Fn) for Fn in MC([F], n.stack_effects)) res = flatten(infer(e, Fn) for e, Fn in MC([F], n.stack_effects, e))
elif isinstance(n, CombinatorJoyType): elif isinstance(n, CombinatorJoyType):
fi, fo = n.enter_guard(F) fi, fo = n.enter_guard(F)
res = [] res = flatten(_interpret(f, fi, fo, e) for f in n.stack_effects)
for combinator in n.stack_effects:
new_fo, ee, _ = combinator(fo, e, {}) elif isinstance(n, Symbol):
ee = update(FUNCTIONS, ee) # Fix Symbols. assert n not in FUNCTIONS, repr(n)
new_F = fi, new_fo func = joy.library._dictionary[n]
res.extend(infer(ee, new_F)) res = _interpret(func, F[0], F[1], e)
else: else:
fi, fo = F fi, fo = F
res = infer(e, (fi, (n, fo))) res = infer(e, (fi, (n, fo)))
@ -253,8 +255,18 @@ def infer(e, F=ID):
return res return res
def _interpret(f, fi, fo, e):
new_fo, ee, _ = f(fo, e, {})
ee = update(FUNCTIONS, ee) # Fix Symbols.
new_F = fi, new_fo
return infer(ee, new_F)
a0, a1, a2, a3, a4, a5, a6, a7, a8, a9 = A a0, a1, a2, a3, a4, a5, a6, a7, a8, a9 = A
b0, b1, b2, b3, b4, b5, b6, b7, b8, b9 = B
n0, n1, n2, n3, n4, n5, n6, n7, n8, n9 = N n0, n1, n2, n3, n4, n5, n6, n7, n8, n9 = N
f0, f1, f2, f3, f4, f5, f6, f7, f8, f9 = F
i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 = I = map(IntJoyType, _R)
s0, s1, s2, s3, s4, s5, s6, s7, s8, s9 = S s0, s1, s2, s3, s4, s5, s6, s7, s8, s9 = S
_R = range(1, 11) _R = range(1, 11)
@ -266,13 +278,13 @@ Ss = map(StackStarJoyType, _R)
FUNCTIONS = { FUNCTIONS = {
name: SymbolJoyType(name, [DEFS[name]], i) name: SymbolJoyType(name, [DEFS[name]], i)
for i, name in enumerate(''' for i, name in enumerate('''
ccons cons divmod_ dup dupd first ccons cons divmod_ dup dupd dupdd first first_two fourth over pop
over pm pop popd popdd popop pred popd popdd popop popopd popopdd rest rrest rolldown rollup second
rest rolldown rollup rrest second stack swaack swap swons third tuck uncons unswons stuncons
sqrt stack succ swaack swap swons stununcons unit eq ge gt le lt ne and_ bool_ not_
third tuck uncons
'''.strip().split()) '''.strip().split())
} }
# sqrt succ pred pm
FUNCTIONS['sum'] = SymbolJoyType('sum', [(((Ns[1], s1), s0), (n0, s0))], 100) FUNCTIONS['sum'] = SymbolJoyType('sum', [(((Ns[1], s1), s0), (n0, s0))], 100)
FUNCTIONS['mul'] = SymbolJoyType('mul', [ FUNCTIONS['mul'] = SymbolJoyType('mul', [
((i2, (i1, s0)), (i3, s0)), ((i2, (i1, s0)), (i3, s0)),
@ -307,9 +319,24 @@ FUNCTIONS['branch'] = CombinatorJoyType('branch', [branch_true, branch_false], 1
globals().update(FUNCTIONS) globals().update(FUNCTIONS)
branch.expect = s7, (s6, (b1, s5))
i.expect = x.expect = s7, s6
dip.expect = s8, (a8, s7) dip.expect = s8, (a8, s7)
dipd.expect = s8, (a8, (a7, s7))
infra.expect = s8, (s7, s6)
NULLARY = infer(((stack, s3), (dip, (infra, (first, ())))))
##print NULLARY
nullary = FUNCTIONS['nullary'] = CombinatorJoyType(
'nullary',
[joy.library._dictionary['nullary']],
101,
)
nullary.expect = s7, s6
# Type Checking...
def _ge(self, other): def _ge(self, other):
return (issubclass(other.__class__, self.__class__) return (issubclass(other.__class__, self.__class__)

View File

@ -238,8 +238,8 @@ A = a0, a1, a2, a3, a4, a5, a6, a7, a8, a9 = map(AnyJoyType, _R)
B = b0, b1, b2, b3, b4, b5, b6, b7, b8, b9 = map(BooleanJoyType, _R) B = b0, b1, b2, b3, b4, b5, b6, b7, b8, b9 = map(BooleanJoyType, _R)
N = n0, n1, n2, n3, n4, n5, n6, n7, n8, n9 = map(NumberJoyType, _R) N = n0, n1, n2, n3, n4, n5, n6, n7, n8, n9 = map(NumberJoyType, _R)
S = s0, s1, s2, s3, s4, s5, s6, s7, s8, s9 = map(StackJoyType, _R) S = s0, s1, s2, s3, s4, s5, s6, s7, s8, s9 = map(StackJoyType, _R)
F = f0, f1, f2, f3, f4, f5, f6, f7, f8, f9 = F = map(FloatJoyType, _R) F = f0, f1, f2, f3, f4, f5, f6, f7, f8, f9 = map(FloatJoyType, _R)
I = i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 = I = map(IntJoyType, _R) I = i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 = map(IntJoyType, _R)
def defs(): def defs():
@ -278,8 +278,6 @@ def defs():
and_ = __(b1, b2), __(b3) and_ = __(b1, b2), __(b3)
bool_ = not_ = __(a1), __(b1) bool_ = not_ = __(a1), __(b1)
add = div = floordiv = modulus = mul = pow_ = sub = truediv = \ add = div = floordiv = modulus = mul = pow_ = sub = truediv = \
lshift = rshift = __(n1, n2), __(n3,) lshift = rshift = __(n1, n2), __(n3,)
sqrt = C(dup, mul) sqrt = C(dup, mul)

View File

@ -90,11 +90,27 @@ class TestCombinators(TestMixin, unittest.TestCase):
def test_nullary(self): def test_nullary(self):
expression = n1, n2, (mul, s2), (stack, s3), dip, infra, first expression = n1, n2, (mul, s2), (stack, s3), dip, infra, first
f = [ f = [
(s1, (f1, (n1, (n2, s2)))), # (-- n2 n1 f1) (s1, (f1, (f2, (f3, s1)))), # (-- f3 f2 f1)
(s1, (i1, (n1, (n2, s2)))), # (-- n2 n1 i1) (s1, (f1, (f2, (i1, s1)))), # (-- i1 f2 f1)
(s1, (f1, (i1, (f2, s1)))), # (-- f2 i1 f1)
(s1, (i1, (i2, (i3, s1)))), # (-- i3 i2 i1)
] ]
self.assertEqualTypeStructure(infr(expression), f) self.assertEqualTypeStructure(infr(expression), f)
expression = n1, n2, (mul, s2), nullary
self.assertEqualTypeStructure(infr(expression), f)
def test_nullary_too(self):
expression = (stack, s3), dip, infra, first
f = ((s1, (a1, s2)), (a1, (a1, s2))) # (a1 [...1] -- a1 a1)
self.assertEqualTypeStructure(infr(expression), [f])
expression = nullary,
f = ((s1, (a1, s2)), (a1, (a1, s2))) # (a1 [...1] -- a1 a1)
# Something's not quite right here...
e = infr(expression)
self.assertEqualTypeStructure(infr(expression), [f])
def test_x(self): def test_x(self):
expression = (a1, (swap, ((dup, s2), (dip, s0)))), x expression = (a1, (swap, ((dup, s2), (dip, s0)))), x
f = (s0, ((a0, (swap, ((dup, s1), (dip, s2)))), (a1, (a1, s0)))) f = (s0, ((a0, (swap, ((dup, s1), (dip, s2)))), (a1, (a1, s0))))
@ -168,4 +184,4 @@ class TestYin(TestMixin, unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main() #defaultTest='TestCombinators.test_nullary_too')