import hypothesis.strategies as st
from hypothesis import given
from haskpy.internal import class_function, abstract_class_function
from haskpy.testing import assert_output
from haskpy import testing
# Use the "hidden" module in order to avoid circular imports
from ._applicative import Applicative
class Monad(Applicative):
    """Monad typeclass

    Minimum required implementations:

    - ``pure``
    - one of the following:

      - ``bind``
      - ``join`` and ``map``

    The required Applicative methods (``map`` and ``apply``) are given default
    implementations based on the required Monad methods. However, it is
    recommended to implement the other methods as well if speed has any
    relevance, because the generic defaults route everything through
    ``bind``.

    Note that the default implementations are mutually recursive: ``bind``
    uses ``map`` and ``join``, while ``join`` and ``map`` use ``bind``. A
    subclass that does not provide the minimum implementations listed above
    will therefore recurse until Python raises ``RecursionError``.

    """

    def bind(self, f):
        """m a -> (a -> m b) -> m b

        Default implementation is based on ``join`` and ``map``::

            self :: m a
            f :: a -> m b
            map f :: m a -> m (m b)
            join :: m (m b) -> m b

        """
        return self.map(f).join()

    def join(self):
        """m (m a) -> m a

        Default implementation is based on ``bind``::

            self :: m (m a)
            identity :: m a -> m a
            bind :: m (m a) -> (m a -> m a) -> m a

        """
        # Imported locally, presumably for the same circular-import reason
        # that ``map`` below imports ``compose`` inside the method.
        from haskpy.utils import identity
        return self.bind(identity)

    def apply(self, f):
        r"""m a -> m (a -> b) -> m b

        Default implementation is based on ``bind`` and ``map``::

            self :: m a
            f :: m (a -> b)

        In order to use ``bind``, let's write its type as follows::

            bind :: m (a -> b) -> ((a -> b) -> m b) -> m b

        Let's also use a simple helper function::

            h = \g -> map g self :: (a -> b) -> m b

        Now::

            bind f h :: m b

        """
        return f.bind(lambda g: self.map(g))

    def map(self, f):
        """m a -> (a -> b) -> m b

        Default implementation is based on ``bind`` and ``pure``::

            map f m = m >>= (pure . f)

        This implementation needs to be provided because the default
        implementation of ``apply`` uses ``map``. Falling back to the
        default ``map`` defined in ``Applicative`` would otherwise create a
        circular dependency with it.

        """
        # Because of circular dependencies, need to import here inside
        from haskpy.types.function import compose
        cls = type(self)
        # ``pure . f`` turns ``a -> b`` into the ``a -> m b`` that bind needs.
        return self.bind(compose(cls.pure, f))

    def __mod__(self, f):
        """Use ``%`` as the bind operator, similarly to ``>>=`` in Haskell

        That is, ``x % f`` is equivalent to ``bind(x, f)`` and ``x.bind(f)``.

        Why the ``%`` operator?

        - It is not used very often, so there is less risk of confusion.
        - It is not commutative, and neither is bind.
        - It is similar to bind in the sense that the result has the same
          unit as the left operand, while the right operand has a different
          unit.
        - The symbol works visually: a line "binds" two circles. The two
          circles also hint at two similar structures on both sides, which
          are just on different "levels".

        """
        return self.bind(f)

    #
    # Sampling methods for property tests
    #

    @abstract_class_function
    def sample_monad_type(cls, a):
        """Sample a monadic type ``m a`` for the given element type ``a``

        Must be implemented by each concrete Monad. As used by the tests
        below, the call's result is passed to ``data.draw`` to get a "type",
        which is itself drawn from again to obtain a value of type ``m a``.

        """
        pass

    #
    # Test typeclass laws
    #

    @class_function
    @assert_output
    def assert_monad_left_identity(cls, f, a):
        """Left identity: ``pure a >>= f`` equals ``f a``

        Returns the two sides for the ``assert_output`` decorator to compare.

        """
        return (f(a), cls.pure(a).bind(f))

    @class_function
    @given(st.data())
    def test_monad_left_identity(cls, data):
        """Test the left identity law ``pure a >>= f == f a``"""
        # Draw types
        # NOTE(review): ``a`` is an eq type presumably because sampled
        # functions require comparable inputs — confirm in haskpy.testing.
        a = data.draw(testing.sample_eq_type())
        b = data.draw(testing.sample_type())
        mb = data.draw(cls.sample_monad_type(b))
        # Draw values
        f = data.draw(testing.sample_function(mb))
        x = data.draw(a)
        cls.assert_monad_left_identity(f, x, data=data)
        return

    @class_function
    @assert_output
    def assert_monad_right_identity(cls, m):
        """Right identity: ``m >>= pure`` equals ``m``

        Returns the two sides for the ``assert_output`` decorator to compare.

        """
        return (m, m.bind(cls.pure))

    @class_function
    @given(st.data())
    def test_monad_right_identity(cls, data):
        """Test the right identity law ``m >>= pure == m``"""
        # Draw types
        a = data.draw(testing.sample_type())
        ma = data.draw(cls.sample_monad_type(a))
        # Draw values
        m = data.draw(ma)
        cls.assert_monad_right_identity(m, data=data)
        return

    @class_function
    @assert_output
    def assert_monad_associativity(cls, m, f, g):
        """Associativity: ``(m >>= f) >>= g`` equals ``m >>= (\\x -> f x >>= g)``

        Returns the two sides for the ``assert_output`` decorator to compare.

        """
        return (
            m.bind(f).bind(g),
            m.bind(lambda x: f(x).bind(g)),
        )

    @class_function
    @given(st.data())
    def test_monad_associativity(cls, data):
        """Test the associativity law of ``bind``"""
        # Draw types
        a = data.draw(testing.sample_eq_type())
        b = data.draw(testing.sample_eq_type())
        c = data.draw(testing.sample_type())
        ma = data.draw(cls.sample_monad_type(a))
        mb = data.draw(cls.sample_monad_type(b))
        mc = data.draw(cls.sample_monad_type(c))
        # Draw values
        m = data.draw(ma)
        f = data.draw(testing.sample_function(mb))
        g = data.draw(testing.sample_function(mc))
        cls.assert_monad_associativity(m, f, g, data=data)
        return

    #
    # Test laws based on default implementations
    #

    @class_function
    @assert_output
    def assert_monad_bind(cls, u, f):
        """Compare the default ``bind``, the method and the module function"""
        # Imported locally; ``.monad`` presumably imports this module, so a
        # top-level import would be circular.
        from .monad import bind
        return (
            Monad.bind(u, f),
            u.bind(f),
            bind(u, f),
        )

    @class_function
    @given(st.data())
    def test_monad_bind(cls, data):
        """Test consistency of ``bind`` with the default implementation"""
        # Draw types
        a = data.draw(testing.sample_eq_type())
        b = data.draw(testing.sample_type())
        ma = data.draw(cls.sample_monad_type(a))
        mb = data.draw(cls.sample_monad_type(b))
        # Draw values
        u = data.draw(ma)
        f = data.draw(testing.sample_function(mb))
        cls.assert_monad_bind(u, f, data=data)
        return

    @class_function
    @assert_output
    def assert_monad_join(cls, u):
        """Compare the default ``join``, the method and the module function"""
        # Imported locally; ``.monad`` presumably imports this module, so a
        # top-level import would be circular.
        from .monad import join
        return (
            Monad.join(u),
            u.join(),
            join(u),
        )

    @class_function
    @given(st.data())
    def test_monad_join(cls, data):
        """Test consistency of ``join`` with the default implementation"""
        # Draw types: a nested monadic type ``m (m b)``
        b = data.draw(testing.sample_type())
        mb = data.draw(cls.sample_monad_type(b))
        mmb = data.draw(cls.sample_monad_type(mb))
        # Draw values
        u = data.draw(mmb)
        cls.assert_monad_join(u, data=data)
        return

    @class_function
    @assert_output
    def assert_monad_map(cls, u, f):
        """Compare the default ``map`` with the instance's ``map``"""
        return (
            Monad.map(u, f),
            u.map(f),
        )

    @class_function
    @given(st.data())
    def test_monad_map(cls, data):
        """Test consistency of ``map`` with the default implementation"""
        # Draw types
        a = data.draw(testing.sample_eq_type())
        b = data.draw(testing.sample_type())
        ma = data.draw(cls.sample_monad_type(a))
        # Draw values
        u = data.draw(ma)
        f = data.draw(testing.sample_function(b))
        cls.assert_monad_map(u, f, data=data)
        return

    @class_function
    @assert_output
    def assert_monad_apply(cls, u, v):
        """Compare the default ``apply`` with the instance's ``apply``

        ``v`` is the value ``m a`` and ``u`` the wrapped functions
        ``m (a -> b)``, matching the ``apply(self, f)`` signature.

        """
        return (
            Monad.apply(v, u),
            v.apply(u),
        )

    @class_function
    @given(st.data())
    def test_monad_apply(cls, data):
        """Test consistency of ``apply`` with the default implementation"""
        # Draw types
        a = data.draw(testing.sample_eq_type())
        b = data.draw(testing.sample_type())
        ma = data.draw(cls.sample_monad_type(a))
        mab = data.draw(cls.sample_monad_type(testing.sample_function(b)))
        # Draw values
        v = data.draw(ma)
        u = data.draw(mab)
        cls.assert_monad_apply(u, v, data=data)
        return