diff --git a/joy/utils/polytypes.py b/joy/utils/polytypes.py index beada70..c467cc3 100644 --- a/joy/utils/polytypes.py +++ b/joy/utils/polytypes.py @@ -15,10 +15,11 @@ from joy.parser import Symbol from joy.utils.stack import concat as CONCAT from joy.utils.types import ( AnyJoyType, A, + BooleanJoyType, B, C, DEFS, doc_from_stack_effect, - FloatJoyType, + FloatJoyType, F, JoyTypeError, NumberJoyType, N, StackJoyType, S, @@ -121,9 +122,9 @@ class CombinatorJoyType(FunctionJoyType): if self.expect is None: return f g = self.expect, self.expect - new_f = list(C(f, g)) + new_f = list(C(f, g, ())) assert len(new_f) == 1, repr(new_f) - return new_f[0] + return new_f[0][1] @@ -195,28 +196,28 @@ def _lil_uni(u, v, s): raise JoyTypeError('Cannot unify %r and %r.' % (u, v)) -def compose(f, g): +def compose(f, g, e): (f_in, f_out), (g_in, g_out) = f, g for s in unify(g_in, f_out): - yield update(s, (f_in, g_out)) + yield update(s, (e, (f_in, g_out))) -def C(f, g): +def C(f, g, e): f, g = relabel(f, g) - for fg in compose(f, g): + for fg in compose(f, g, e): yield delabel(fg) -def meta_compose(F, G): +def meta_compose(F, G, e): for f, g in product(F, G): try: - for result in C(f, g): yield result + for result in C(f, g, e): yield result except JoyTypeError: pass -def MC(F, G): - res = sorted(set(meta_compose(F, G))) +def MC(F, G, e): + res = sorted(set(meta_compose(F, G, e))) if not res: raise JoyTypeError('Cannot unify %r and %r.' % (F, G)) return res @@ -236,16 +237,17 @@ def infer(e, F=ID): n, e = e if isinstance(n, SymbolJoyType): - res = flatten(infer(e, Fn) for Fn in MC([F], n.stack_effects)) + res = flatten(infer(e, Fn) for e, Fn in MC([F], n.stack_effects, e)) elif isinstance(n, CombinatorJoyType): fi, fo = n.enter_guard(F) - res = [] - for combinator in n.stack_effects: - new_fo, ee, _ = combinator(fo, e, {}) - ee = update(FUNCTIONS, ee) # Fix Symbols. - new_F = fi, new_fo - res.extend(infer(ee, new_F)) + res = flatten(_interpret(f, fi, fo, e) for f in n.stack_effects) + + elif isinstance(n, Symbol): + assert n not in FUNCTIONS, repr(n) + func = joy.library._dictionary[n] + res = _interpret(func, F[0], F[1], e) + else: fi, fo = F res = infer(e, (fi, (n, fo))) @@ -253,8 +255,18 @@ def infer(e, F=ID): return res +def _interpret(f, fi, fo, e): + new_fo, ee, _ = f(fo, e, {}) + ee = update(FUNCTIONS, ee) # Fix Symbols. + new_F = fi, new_fo + return infer(ee, new_F) + + a0, a1, a2, a3, a4, a5, a6, a7, a8, a9 = A +b0, b1, b2, b3, b4, b5, b6, b7, b8, b9 = B n0, n1, n2, n3, n4, n5, n6, n7, n8, n9 = N +f0, f1, f2, f3, f4, f5, f6, f7, f8, f9 = F +i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 = I = map(IntJoyType, _R) s0, s1, s2, s3, s4, s5, s6, s7, s8, s9 = S _R = range(1, 11) @@ -266,13 +278,13 @@ Ss = map(StackStarJoyType, _R) FUNCTIONS = { name: SymbolJoyType(name, [DEFS[name]], i) for i, name in enumerate(''' - ccons cons divmod_ dup dupd first - over pm pop popd popdd popop pred - rest rolldown rollup rrest second - sqrt stack succ swaack swap swons - third tuck uncons + ccons cons divmod_ dup dupd dupdd first first_two fourth over pop + popd popdd popop popopd popopdd rest rrest rolldown rollup second + stack swaack swap swons third tuck uncons unswons stuncons + stununcons unit eq ge gt le lt ne and_ bool_ not_ '''.strip().split()) } +# sqrt succ pred pm FUNCTIONS['sum'] = SymbolJoyType('sum', [(((Ns[1], s1), s0), (n0, s0))], 100) FUNCTIONS['mul'] = SymbolJoyType('mul', [ ((i2, (i1, s0)), (i3, s0)), @@ -307,9 +319,24 @@ FUNCTIONS['branch'] = CombinatorJoyType('branch', [branch_true, branch_false], 1 globals().update(FUNCTIONS) - +branch.expect = s7, (s6, (b1, s5)) +i.expect = x.expect = s7, s6 dip.expect = s8, (a8, s7) +dipd.expect = s8, (a8, (a7, s7)) +infra.expect = s8, (s7, s6) +NULLARY = infer(((stack, s3), (dip, (infra, (first, ()))))) +##print NULLARY + +nullary = FUNCTIONS['nullary'] = CombinatorJoyType( + 'nullary', + [joy.library._dictionary['nullary']], + 101, + ) +nullary.expect = s7, s6 + + +# Type Checking... def _ge(self, other): return (issubclass(other.__class__, self.__class__) diff --git a/joy/utils/types.py b/joy/utils/types.py index 20f26a2..01bf690 100644 --- a/joy/utils/types.py +++ b/joy/utils/types.py @@ -238,8 +238,8 @@ A = a0, a1, a2, a3, a4, a5, a6, a7, a8, a9 = map(AnyJoyType, _R) B = b0, b1, b2, b3, b4, b5, b6, b7, b8, b9 = map(BooleanJoyType, _R) N = n0, n1, n2, n3, n4, n5, n6, n7, n8, n9 = map(NumberJoyType, _R) S = s0, s1, s2, s3, s4, s5, s6, s7, s8, s9 = map(StackJoyType, _R) -F = f0, f1, f2, f3, f4, f5, f6, f7, f8, f9 = F = map(FloatJoyType, _R) -I = i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 = I = map(IntJoyType, _R) +F = f0, f1, f2, f3, f4, f5, f6, f7, f8, f9 = map(FloatJoyType, _R) +I = i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 = map(IntJoyType, _R) def defs(): @@ -278,8 +278,6 @@ def defs(): and_ = __(b1, b2), __(b3) bool_ = not_ = __(a1), __(b1) - - add = div = floordiv = modulus = mul = pow_ = sub = truediv = \ lshift = rshift = __(n1, n2), __(n3,) sqrt = C(dup, mul) diff --git a/test/test_type_inference.py b/test/test_type_inference.py index a1d3435..1d59407 100644 --- a/test/test_type_inference.py +++ b/test/test_type_inference.py @@ -90,11 +90,27 @@ class TestCombinators(TestMixin, unittest.TestCase): def test_nullary(self): expression = n1, n2, (mul, s2), (stack, s3), dip, infra, first f = [ - (s1, (f1, (n1, (n2, s2)))), # (-- n2 n1 f1) - (s1, (i1, (n1, (n2, s2)))), # (-- n2 n1 i1) + (s1, (f1, (f2, (f3, s1)))), # (-- f3 f2 f1) + (s1, (f1, (f2, (i1, s1)))), # (-- i1 f2 f1) + (s1, (f1, (i1, (f2, s1)))), # (-- f2 i1 f1) + (s1, (i1, (i2, (i3, s1)))), # (-- i3 i2 i1) ] self.assertEqualTypeStructure(infr(expression), f) + expression = n1, n2, (mul, s2), nullary + self.assertEqualTypeStructure(infr(expression), f) + + def test_nullary_too(self): + expression = (stack, s3), dip, infra, first + f = ((s1, (a1, s2)), (a1, (a1, s2))) # (a1 [...1] -- a1 a1) + self.assertEqualTypeStructure(infr(expression), [f]) + + expression = nullary, + f = ((s1, (a1, s2)), (a1, (a1, s2))) # (a1 [...1] -- a1 a1) + # Something's not quite right here... + e = infr(expression) + self.assertEqualTypeStructure(infr(expression), [f]) + def test_x(self): expression = (a1, (swap, ((dup, s2), (dip, s0)))), x f = (s0, ((a0, (swap, ((dup, s1), (dip, s2)))), (a1, (a1, s0)))) @@ -168,4 +184,4 @@ class TestYin(TestMixin, unittest.TestCase): if __name__ == '__main__': - unittest.main() + unittest.main() #defaultTest='TestCombinators.test_nullary_too')