Gotta update() expression too.

If type vars get into the espression you have to keep them in sync with
the unification or you can lose information.


Some combinators can put symbols on the expression, you have to convert
those to type checkers or, as a hack, just look them up and run them.
This lets definitions work(-ish), ...
This commit is contained in:
Simon Forman 2018-06-27 22:26:27 -07:00
parent fc45727008
commit 6ca59847ab
3 changed files with 72 additions and 31 deletions

View File

@ -15,10 +15,11 @@ from joy.parser import Symbol
from joy.utils.stack import concat as CONCAT
from joy.utils.types import (
AnyJoyType, A,
BooleanJoyType, B,
C,
DEFS,
doc_from_stack_effect,
FloatJoyType,
FloatJoyType, F,
JoyTypeError,
NumberJoyType, N,
StackJoyType, S,
@ -121,9 +122,9 @@ class CombinatorJoyType(FunctionJoyType):
if self.expect is None:
return f
g = self.expect, self.expect
new_f = list(C(f, g))
new_f = list(C(f, g, ()))
assert len(new_f) == 1, repr(new_f)
return new_f[0]
return new_f[0][1]
@ -195,28 +196,28 @@ def _lil_uni(u, v, s):
raise JoyTypeError('Cannot unify %r and %r.' % (u, v))
def compose(f, g):
def compose(f, g, e):
(f_in, f_out), (g_in, g_out) = f, g
for s in unify(g_in, f_out):
yield update(s, (f_in, g_out))
yield update(s, (e, (f_in, g_out)))
def C(f, g):
def C(f, g, e):
f, g = relabel(f, g)
for fg in compose(f, g):
for fg in compose(f, g, e):
yield delabel(fg)
def meta_compose(F, G):
def meta_compose(F, G, e):
for f, g in product(F, G):
try:
for result in C(f, g): yield result
for result in C(f, g, e): yield result
except JoyTypeError:
pass
def MC(F, G):
res = sorted(set(meta_compose(F, G)))
def MC(F, G, e):
res = sorted(set(meta_compose(F, G, e)))
if not res:
raise JoyTypeError('Cannot unify %r and %r.' % (F, G))
return res
@ -236,16 +237,17 @@ def infer(e, F=ID):
n, e = e
if isinstance(n, SymbolJoyType):
res = flatten(infer(e, Fn) for Fn in MC([F], n.stack_effects))
res = flatten(infer(e, Fn) for e, Fn in MC([F], n.stack_effects, e))
elif isinstance(n, CombinatorJoyType):
fi, fo = n.enter_guard(F)
res = []
for combinator in n.stack_effects:
new_fo, ee, _ = combinator(fo, e, {})
ee = update(FUNCTIONS, ee) # Fix Symbols.
new_F = fi, new_fo
res.extend(infer(ee, new_F))
res = flatten(_interpret(f, fi, fo, e) for f in n.stack_effects)
elif isinstance(n, Symbol):
assert n not in FUNCTIONS, repr(n)
func = joy.library._dictionary[n]
res = _interpret(func, F[0], F[1], e)
else:
fi, fo = F
res = infer(e, (fi, (n, fo)))
@ -253,8 +255,18 @@ def infer(e, F=ID):
return res
def _interpret(f, fi, fo, e):
new_fo, ee, _ = f(fo, e, {})
ee = update(FUNCTIONS, ee) # Fix Symbols.
new_F = fi, new_fo
return infer(ee, new_F)
a0, a1, a2, a3, a4, a5, a6, a7, a8, a9 = A
b0, b1, b2, b3, b4, b5, b6, b7, b8, b9 = B
n0, n1, n2, n3, n4, n5, n6, n7, n8, n9 = N
f0, f1, f2, f3, f4, f5, f6, f7, f8, f9 = F
i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 = I = map(IntJoyType, _R)
s0, s1, s2, s3, s4, s5, s6, s7, s8, s9 = S
_R = range(1, 11)
@ -266,13 +278,13 @@ Ss = map(StackStarJoyType, _R)
FUNCTIONS = {
name: SymbolJoyType(name, [DEFS[name]], i)
for i, name in enumerate('''
ccons cons divmod_ dup dupd first
over pm pop popd popdd popop pred
rest rolldown rollup rrest second
sqrt stack succ swaack swap swons
third tuck uncons
ccons cons divmod_ dup dupd dupdd first first_two fourth over pop
popd popdd popop popopd popopdd rest rrest rolldown rollup second
stack swaack swap swons third tuck uncons unswons stuncons
stununcons unit eq ge gt le lt ne and_ bool_ not_
'''.strip().split())
}
# sqrt succ pred pm
FUNCTIONS['sum'] = SymbolJoyType('sum', [(((Ns[1], s1), s0), (n0, s0))], 100)
FUNCTIONS['mul'] = SymbolJoyType('mul', [
((i2, (i1, s0)), (i3, s0)),
@ -307,9 +319,24 @@ FUNCTIONS['branch'] = CombinatorJoyType('branch', [branch_true, branch_false], 1
globals().update(FUNCTIONS)
branch.expect = s7, (s6, (b1, s5))
i.expect = x.expect = s7, s6
dip.expect = s8, (a8, s7)
dipd.expect = s8, (a8, (a7, s7))
infra.expect = s8, (s7, s6)
NULLARY = infer(((stack, s3), (dip, (infra, (first, ())))))
##print NULLARY
nullary = FUNCTIONS['nullary'] = CombinatorJoyType(
'nullary',
[joy.library._dictionary['nullary']],
101,
)
nullary.expect = s7, s6
# Type Checking...
def _ge(self, other):
return (issubclass(other.__class__, self.__class__)

View File

@ -238,8 +238,8 @@ A = a0, a1, a2, a3, a4, a5, a6, a7, a8, a9 = map(AnyJoyType, _R)
B = b0, b1, b2, b3, b4, b5, b6, b7, b8, b9 = map(BooleanJoyType, _R)
N = n0, n1, n2, n3, n4, n5, n6, n7, n8, n9 = map(NumberJoyType, _R)
S = s0, s1, s2, s3, s4, s5, s6, s7, s8, s9 = map(StackJoyType, _R)
F = f0, f1, f2, f3, f4, f5, f6, f7, f8, f9 = F = map(FloatJoyType, _R)
I = i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 = I = map(IntJoyType, _R)
F = f0, f1, f2, f3, f4, f5, f6, f7, f8, f9 = map(FloatJoyType, _R)
I = i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 = map(IntJoyType, _R)
def defs():
@ -278,8 +278,6 @@ def defs():
and_ = __(b1, b2), __(b3)
bool_ = not_ = __(a1), __(b1)
add = div = floordiv = modulus = mul = pow_ = sub = truediv = \
lshift = rshift = __(n1, n2), __(n3,)
sqrt = C(dup, mul)

View File

@ -90,11 +90,27 @@ class TestCombinators(TestMixin, unittest.TestCase):
def test_nullary(self):
expression = n1, n2, (mul, s2), (stack, s3), dip, infra, first
f = [
(s1, (f1, (n1, (n2, s2)))), # (-- n2 n1 f1)
(s1, (i1, (n1, (n2, s2)))), # (-- n2 n1 i1)
(s1, (f1, (f2, (f3, s1)))), # (-- f3 f2 f1)
(s1, (f1, (f2, (i1, s1)))), # (-- i1 f2 f1)
(s1, (f1, (i1, (f2, s1)))), # (-- f2 i1 f1)
(s1, (i1, (i2, (i3, s1)))), # (-- i3 i2 i1)
]
self.assertEqualTypeStructure(infr(expression), f)
expression = n1, n2, (mul, s2), nullary
self.assertEqualTypeStructure(infr(expression), f)
def test_nullary_too(self):
expression = (stack, s3), dip, infra, first
f = ((s1, (a1, s2)), (a1, (a1, s2))) # (a1 [...1] -- a1 a1)
self.assertEqualTypeStructure(infr(expression), [f])
expression = nullary,
f = ((s1, (a1, s2)), (a1, (a1, s2))) # (a1 [...1] -- a1 a1)
# Something's not quite right here...
e = infr(expression)
self.assertEqualTypeStructure(infr(expression), [f])
def test_x(self):
expression = (a1, (swap, ((dup, s2), (dip, s0)))), x
f = (s0, ((a0, (swap, ((dup, s1), (dip, s2)))), (a1, (a1, s0))))
@ -168,4 +184,4 @@ class TestYin(TestMixin, unittest.TestCase):
if __name__ == '__main__':
unittest.main()
unittest.main() #defaultTest='TestCombinators.test_nullary_too')