
Verifying the Monad Laws with Supercompilation

A while ago I wrote a post about how one can use Coq to make a proper monad module. I was just thinking today that it would be nice to have a tool for Haskell that would let one write down conjectures and discharge them automatically with a theorem prover. Supercompilation makes a nice, clean theorem prover for Haskell, since the equations of interest can be expressed in Haskell itself. Below is an example of the list monad, followed by the three monad laws written as conj1, conj2 and conj3. I prove the first law by manual supercompilation. The other two are left as an exercise for the interested reader.

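-- Element equality: assumed decidable and reflexive (eq x x = True).
-- Any lawful Eq instance will do.
eq :: Eq a => a -> a -> Bool
eq = (==)

-- Equality on lists, defined in terms of eq on the elements.
equal :: Eq a => [a] -> [a] -> Bool
equal xs ys = case xs of
                [] -> case ys of
                        [] -> True
                        _ -> False
                (x:xs') -> case ys of
                             [] -> False
                             (y:ys') -> case (eq x y) of
                                          False -> False
                                          True ->  equal xs' ys'

-- The list monad.
bind :: [a] -> (a -> [b]) -> [b]
bind [] f = []
bind (h:t) f = (f h)++(bind t f)

ret :: a -> [a]
ret a = [a]

-- The three monad laws, stated as Boolean-valued conjectures.
conj1 :: Eq b => a -> (a -> [b]) -> Bool
conj1 a f = equal (bind (ret a) f) (f a)
conj2 :: Eq a => [a] -> Bool
conj2 m = equal (bind m (\a -> ret a)) m
conj3 :: Eq c => [a] -> (a -> [b]) -> (b -> [c]) -> Bool
conj3 m f g = equal (bind (bind m f) g) (bind m (\x -> (bind (f x) g)))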

In order to define equality on lists we had to refer to a separate equality predicate, eq, on the elements of the list. We will assume that this equality is decidable and that supercompilation can prove its reflexivity, that is, "eq x x = True" will be taken as an assumption. It is somewhat hard to imagine a case where supercompilation would struggle with this, because of the case substitution rule.
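To see why the reflexivity assumption matters, one can pass the element equality in as a parameter. In this minimal sketch (the names equalBy, conj1By, reflexiveEq and brokenEq are hypothetical, introduced only for illustration), conj1 holds with a reflexive predicate but fails with one that never returns True whenever (f a) is non-empty.

```haskell
-- List equality parameterised by an element predicate, so we can see what
-- happens when the assumption eq x x = True is dropped.
equalBy :: (a -> a -> Bool) -> [a] -> [a] -> Bool
equalBy _ [] [] = True
equalBy eq (x:xs) (y:ys) = eq x y && equalBy eq xs ys
equalBy _ _ _ = False

bind :: [a] -> (a -> [b]) -> [b]
bind [] _ = []
bind (h:t) f = f h ++ bind t f

-- conj1 with the element equality supplied explicitly.
conj1By :: (b -> b -> Bool) -> a -> (a -> [b]) -> Bool
conj1By eq a f = equalBy eq (bind [a] f) (f a)

-- A reflexive predicate and a deliberately non-reflexive one.
reflexiveEq, brokenEq :: Int -> Int -> Bool
reflexiveEq = (==)
brokenEq _ _ = False
```

Here conj1By reflexiveEq 1 (\x -> [x, x]) is True, the same call with brokenEq is False, and conj1By brokenEq 1 (const []) is still True, because eq is never consulted when both lists are empty.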

Now we take conj1 and attempt to prove it by semantics-preserving transformations of the source code [1]. I've been brutally explicit in the steps used so that people who are interested can see exactly how to do these sorts of things themselves. I've found that program transformation techniques can be extremely useful in reasoning about code. Of course, it helps to assume that all functions are total, and I'll be doing just that. I use the notation M[x:=y] to mean the substitution of y for x in M. Aside from that, everything is just Haskell.

conj1 a f = equal (bind (ret a) f) (f a)

{- unfold equal -} 
conj1 a f = case (bind (ret a) f) of 
              [] -> case (f a) of 
                      [] -> True 
                      _ -> False
              (x:xs') -> case (f a) of 
                           [] -> False 
                           (y:ys') -> case eq x y of 
                                        False -> False 
                                        True -> equal xs' ys'

{- unfold bind -}
conj1 a f = case (case (ret a) of 
                    [] -> []
                    (z:zs) -> (f z)++(bind zs f))
              [] -> case (f a) of 
                      [] -> True 
                      _ -> False
              (x:xs') -> case (f a) of 
                           [] -> False 
                           (y:ys') -> case eq x y of 
                                        False -> False 
                                        True -> equal xs' ys'

{- case distribution -} 
conj1 a f = case (ret a) of 
              [] -> case (f a) of 
                      [] -> True 
                      _ -> False
              (z:zs) -> case ((f z)++(bind zs f)) of
                          [] -> case (f a) of 
                                  [] -> True 
                                  _ -> False
                          (x:xs') -> case (f a) of 
                                       [] -> False 
                                       (y:ys') -> case eq x y of 
                                                    False -> False 
                                                    True -> equal xs' ys'
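The case distribution step (often called case-of-case) can be stated as a rewrite on a toy expression type. This is a minimal sketch with hypothetical names (Expr, Pat, caseDistribute); it assumes the variables bound by the inner patterns do not occur free in the outer alternatives, so no renaming is performed.

```haskell
-- A toy first-order expression language, just rich enough to state the
-- case distribution rule used in the derivation.
data Expr
  = Var String
  | Con String [Expr]        -- constructor application, e.g. Con ":" [x, xs]
  | App String [Expr]        -- call of a named function or function variable
  | Case Expr [(Pat, Expr)]  -- case scrutinee of alternatives
  deriving (Eq, Show)

-- A constructor pattern: constructor name plus the variables it binds.
data Pat = Pat String [String]
  deriving (Eq, Show)

-- case (case e of p_i -> b_i) of alts  ==>  case e of p_i -> case b_i of alts
-- Assumes inner pattern variables are distinct from the free variables of
-- the outer alternatives, so no capture can occur.
caseDistribute :: Expr -> Expr
caseDistribute (Case (Case e inner) outer) =
  Case e [(p, Case b outer) | (p, b) <- inner]
caseDistribute e = e
```

In the step above the first branch has also been tidied at once: after distribution it reads case [] of ..., which case selection immediately reduces to the [] alternative.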

{- unfold ret -}
conj1 a f = case [a] of 
              [] -> case (f a) of 
                      [] -> True 
                      _ -> False
              (z:zs) -> case ((f z)++(bind zs f)) of
                          [] -> case (f a) of 
                                  [] -> True 
                                  _ -> False
                          (x:xs') -> case (f a) of 
                                       [] -> False 
                                       (y:ys') -> case eq x y of 
                                                    False -> False 
                                                    True -> equal xs' ys'

{- case selection -}
conj1 a f = (case ((f z)++(bind zs f)) of 
                      [] -> case (f a) of 
                              [] -> True 
                              _ -> False
                      (x:xs') -> case (f a) of 
                                   [] -> False 
                                   (y:ys') -> case eq x y of 
                                                False -> False 
                                                True -> equal xs' ys') 
                                     [z := a,  zs := []]
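Case selection (case-of-known-constructor) picks the alternative whose pattern matches a known constructor and substitutes the constructor's arguments for the pattern variables, which is the M[x:=y] step written above. The sketch below reuses a toy expression type with hypothetical names (Expr, Pat, subst, caseSelect); its substitution is naive and assumes bound names never clash with the free names of the replacements, so it does no renaming.

```haskell
import Data.Maybe (fromMaybe)

data Expr
  = Var String
  | Con String [Expr]        -- constructor application
  | App String [Expr]        -- call of a named function or function variable
  | Case Expr [(Pat, Expr)]
  deriving (Eq, Show)

data Pat = Pat String [String]
  deriving (Eq, Show)

-- M[x1 := e1, ..., xn := en]. Naive: assumes no variable capture is possible.
subst :: [(String, Expr)] -> Expr -> Expr
subst s (Var v) = fromMaybe (Var v) (lookup v s)
subst s (Con c es) = Con c (map (subst s) es)
subst s (App f es) = App f (map (subst s) es)
subst s (Case e alts) = Case (subst s e) [(p, subst (shadow p) b) | (p, b) <- alts]
  where
    -- Variables bound by a pattern hide outer substitutions of the same name.
    shadow (Pat _ vs) = [(v, r) | (v, r) <- s, v `notElem` vs]

-- case C e1..en of { ...; C x1..xn -> b; ... }  ==>  b[x1 := e1, ..., xn := en]
caseSelect :: Expr -> Expr
caseSelect e@(Case (Con c args) alts) =
  case [(vs, b) | (Pat c' vs, b) <- alts, c' == c, length vs == length args] of
    ((vs, b) : _) -> subst (zip vs args) b
    [] -> e  -- no matching alternative: leave the term unchanged
caseSelect e = e
```

Applied to case [a] of { [] -> ...; (z:zs) -> (f z)++(bind zs f) }, with [a] encoded as Con ":" [Var "a", Con "[]" []], caseSelect yields the encoding of (f a)++(bind [] f), matching the next step.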

{- substitution -}
conj1 a f = case ((f a)++(bind [] f)) of 
              [] -> case (f a) of 
                      [] -> True 
                      _ -> False
              (x:xs') -> case (f a) of 
                           [] -> False 
                           (y:ys') -> case eq x y of 
                                        False -> False 
                                        True -> equal xs' ys'

{- unfold bind, then ++ -}
conj1 a f = case (case (f a) of
                    [] -> []
                    (z:zs') -> z:(zs'++[])) of 
              [] -> case (f a) of 
                      [] -> True 
                      _ -> False 
              (x:xs') -> case (f a) of 
                           [] -> False 
                           (y:ys') -> case eq x y of 
                                        False -> False 
                                        True -> equal xs' ys' 

{- case distribution, case selection on the known value of (f a), substitution -}
conj1 a f = case (f a) of
              [] -> True                     
              (z:zs') -> case eq z z of 
                           False -> False 
                           True -> equal (zs'++[]) zs'

{- Assumption: eq z z = True -}
conj1 a f = case (f a) of
              [] -> True                     
              (z:zs') -> equal (zs'++[]) zs'

{- unfold equal and ++ -}
conj1 a f = case (f a) of
              [] -> True                     
              (z:zs') -> case (case zs' of 
                                 [] -> [] 
                                 (w:ws) -> w:(ws++[])) of
                           [] -> case zs' of 
                                   [] -> True
                                   (y:ys) -> False
                           (x:xs) -> case zs' of 
                                       [] -> False 
                                       (y:ys) -> case (eq x y) of
                                                   False -> False 
                                                   True -> equal xs ys

{- case distribution, case selection -}
conj1 a f = case (f a) of
              [] -> True                     
              (z:zs') -> case zs' of 
                           [] -> True
                           (w:ws) -> case (eq w w) of
                                       False -> False 
                                       True -> equal (ws++[]) ws
                 
{- Assumption: eq w w = True -}
conj1 a f = case (f a) of
              [] -> True                     
              (z:zs') -> case zs' of 
                           [] -> True
                           (w:ws) -> equal (ws++[]) ws
            
{- fold: equal (ws++[]) ws is a renaming of equal (zs'++[]) zs' seen earlier, so we tie the knot with a recursive function g -}
conj1 a f = case (f a) of 
              [] -> True                     
              (z:zs') -> let g = \xs ->
                                 case xs of 
                                   [] -> True 
                                   (y:ys') -> g ys'
                         in g zs'
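A cheap way to gain confidence that the transformations preserved meaning is to run the starting point, an intermediate stage and the residual program side by side. This is a minimal sketch assuming eq is (==); the names conj1Mid and conj1Final are hypothetical labels for the stage after "eq z z = True" and for the final folded program. Agreement on samples is evidence, not proof.

```haskell
-- The definitions from the post, with eq taken to be (==).
equal :: Eq a => [a] -> [a] -> Bool
equal [] [] = True
equal (x:xs) (y:ys) = x == y && equal xs ys
equal _ _ = False

bind :: [a] -> (a -> [b]) -> [b]
bind [] _ = []
bind (h:t) f = f h ++ bind t f

-- conj1 as originally stated (with ret unfolded).
conj1 :: Eq b => a -> (a -> [b]) -> Bool
conj1 a f = equal (bind [a] f) (f a)

-- Intermediate stage, after using the assumption eq z z = True.
conj1Mid :: Eq b => a -> (a -> [b]) -> Bool
conj1Mid a f = case f a of
  [] -> True
  (_:zs') -> equal (zs' ++ []) zs'

-- Residual program after the fold.
conj1Final :: a -> (a -> [b]) -> Bool
conj1Final a f = case f a of
  [] -> True
  (_:zs') -> let g xs = case xs of
                          [] -> True
                          (_:ys') -> g ys'
             in g zs'
```

Note that conj1Final no longer needs an Eq constraint: the residual g never inspects list elements, reflecting that every use of eq was discharged by the reflexivity assumption.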

We now have a function that can only return True, regardless of the value of (f a). Assuming that f is total, we can therefore replace this term with True. The proof of this is simply by induction on (f a), but the principle can be built into a checker, and it is the principle used in the Poitín [3] theorem prover.
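Spelled out, the induction is short. It uses the definition of g from the final step and, as stated above, assumes f is total; it also assumes (f a) is a finite list.

```latex
\begin{aligned}
&\text{Claim: } g\,xs = \mathit{True} \text{ for every finite list } xs.\\
&\text{Base case: } g\,[\,] = \mathit{True} \text{ by the first branch of } g.\\
&\text{Inductive step: assume } g\,ys' = \mathit{True}. \text{ Then } g\,(y : ys') = g\,ys' = \mathit{True}.\\
&\text{Hence both branches of } \mathtt{case}\ (f\,a)\ \mathtt{of}\ \{[\,] \to \mathit{True};\ (z : zs') \to g\,zs'\}
  \text{ give } \mathit{True}, \text{ so } \mathit{conj1}\,a\,f = \mathit{True}.
\end{aligned}
```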

Something along these lines would be very lightweight in comparison to using a full theorem prover, and could simply issue a warning if some laws couldn't be proved. The proof of this in Coq is actually more work and requires lemmas about the list type to be proven. I did in fact originally fudge the ((f a)++[]) situation in the derivation by replacing it with (f a), and I'm not entirely sure what supercompilation does with it if you leave it in. I'll have to try it later. UPDATE: This is fixed in the code above.

If we want to prove this result by hand, we can also derive it in a much simpler way by not following the supercompilation method verbatim:

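conj1 a f = equal (bind [a] f) (f a)                 {- unfold ret -}
conj1 a f = equal (append (f a) (bind [] f)) (f a)   {- unfold bind -}
conj1 a f = equal (append (f a) []) (f a)            {- unfold bind -}
conj1 a f = equal (f a) (f a)                        {- append xs [] = xs -}
conj1 a f = True                                     {- equal is reflexive, since eq x x = True -}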
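The short derivation silently relies on two facts about finite lists: appending [] changes nothing, and equal is reflexive given "eq x x = True". Each follows by a one-line induction, and they are exactly the kind of separately proved lemma that a conventional proof assistant would need:

```latex
\begin{aligned}
&\text{Lemma 1: } \mathit{append}\,xs\,[\,] = xs.\\
&\quad \mathit{append}\,[\,]\,[\,] = [\,], \qquad
  \mathit{append}\,(x : xs)\,[\,] = x : \mathit{append}\,xs\,[\,] = x : xs \ \text{(IH)}.\\
&\text{Lemma 2: } \mathit{equal}\,xs\,xs = \mathit{True}, \text{ given } \mathit{eq}\,x\,x = \mathit{True}.\\
&\quad \mathit{equal}\,[\,]\,[\,] = \mathit{True}, \qquad
  \mathit{equal}\,(x : xs)\,(x : xs) = \mathit{equal}\,xs\,xs = \mathit{True}
  \ \text{(since } \mathit{eq}\,x\,x = \mathit{True}\text{, then IH)}.
\end{aligned}
```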

The advantage of the former is of course that it is entirely mechanical.

UPDATE: I went ahead and incorporated the definition of append and supercompiled it (by hand) with this definition. It reduces to exactly the same term, so supercompilation, without any special knowledge, is in fact capable of proving the conjecture. This is a real win over most proof assistants, which would require this append lemma to be proved separately or incorporated into the knowledge base. For a more visual and more compact depiction of the derivation, we can draw a partial process tree with the append function definition included in the derivation.

[Figure: partial process tree for conj1, with the definition of append included]

And indeed the second monad law:

[Figure: derivation of the second monad law (conj2)]

Is the third monad law for the list monad provable using supercompilation? I expect it is, but since I don't have a supercompiler, and it is a bit of a bear to do by hand, I'm not sure.
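Short of a supercompiler, testing can at least hunt for counterexamples. The following sketch checks all three conjectures, including the third, on small inputs; it assumes eq is (==), and the sample lists and functions (sampleLists, sampleFuns) are hypothetical choices made only for illustration. Passing says nothing about inputs outside the samples, which is exactly why a prover is wanted.

```haskell
-- The definitions from the post, with eq taken to be (==).
equal :: Eq a => [a] -> [a] -> Bool
equal [] [] = True
equal (x:xs) (y:ys) = x == y && equal xs ys
equal _ _ = False

bind :: [a] -> (a -> [b]) -> [b]
bind [] _ = []
bind (h:t) f = f h ++ bind t f

ret :: a -> [a]
ret a = [a]

conj1 :: Eq b => a -> (a -> [b]) -> Bool
conj1 a f = equal (bind (ret a) f) (f a)

conj2 :: Eq a => [a] -> Bool
conj2 m = equal (bind m (\a -> ret a)) m

conj3 :: Eq c => [a] -> (a -> [b]) -> (b -> [c]) -> Bool
conj3 m f g = equal (bind (bind m f) g) (bind m (\x -> bind (f x) g))

-- Hypothetical sample inputs.
sampleLists :: [[Int]]
sampleLists = [[], [0], [1, 2], [3, 3, 4], [5, 6, 7, 8]]

sampleFuns :: [Int -> [Int]]
sampleFuns = [const [], \x -> [x], \x -> [x, x + 1], \x -> replicate (x `mod` 3) x]

-- True when every conjecture holds on every combination of samples.
allLawsHoldOnSamples :: Bool
allLawsHoldOnSamples =
  and [conj1 a f | a <- [0 .. 5], f <- sampleFuns]
    && all conj2 sampleLists
    && and [conj3 m f g | m <- sampleLists, f <- sampleFuns, g <- sampleFuns]
```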

Incidentally, there is already a supercompiler being developed for Haskell called Supero [2]. It probably wouldn't be too much work to extend it with theorem-proving capabilities using ideas from [3].

[1] A Transformation System for Developing Recursive Programs
[2] A Supercompiler for Core Haskell
[3] Poitín: Distilling Theorems From Conjectures
