Skip to content

Lambda Literals in SMT Proof

Brian Milnes edited this page Jan 20, 2025 · 1 revision
module NamingLambdaLiterals
open FStar.Mul
open FStar.IntegerIntervals
module F = FStar.FunctionalExtensionality

/// Defining a simple sigma (ranged sum) and then proving lemmas whose
/// statements pass lambda literals to it shows how hard even simple lemmas
/// can be to prove.

// Ranged sum: sigma i j e = e i + e (i + 1) + ... + e j, and 0 when the
// range is empty (j < i). Recursion peels off the top term e j, so the
// measure j - i shrinks on every call.
let rec sigma (i: int) (j: int) (e: interval i (j + 1) -> int)
: Tot int (decreases j - i)
= if i = j then e i
  else if i > j then 0
  else sigma i (j - 1) e + e j

// First attempt at distributing a constant factor over sigma, with the scaled
// summand written as a lambda literal. The attribute records that checking
// this definition is expected to FAIL with F* error 19. Error 19 is presumably
// the SMT-obligation failure raised by the assert below; confirm against the
// F* error-code list. Because the definition fails, it is not added to the
// environment, which is why the name can be reused by the working lemma later
// in the file.
[@@expect_failure [19]]
let rec sigma_const_dist (i: int) (j: int) (e: interval i (j + 1) -> int) (c: nat) :
  Lemma (requires i <= j)
        (ensures sigma i j (fun k -> c * (e k)) = c * (sigma i j e))
  (decreases j - i)
= if j < i 
  // Unreachable: the precondition gives i <= j.
  then false_elim()
  else if i = j
  // Base case: sigma i i unfolds to its single term, so nothing to prove.
  then ()
  else begin 
             // Induction hypothesis on the range [i, j - 1].
             sigma_const_dist i (j - 1) e c;
             // F* cannot prove this unfolding step, for reasons related to
             // the lambda literals appearing on both sides. The working
             // version later in the file names the lambda (name_it) instead.
             assert(sigma i j       (fun k -> c * (e k)) =
                    sigma i (j - 1) (fun k -> c * (e k)) +
                    (c * e j))
        end

/// Nik's solution, more specific than the one in section 9.4, is to prove a
/// functional-extensionality lemma for sigma (sigma_feq) and then lift the
/// lambda out into a named function (name_it).

// Sums of extensionally equal summands coincide. If f and g satisfy
// F.feq f g on [i, j], then sigma i j f == sigma i j g.
// The proof is by induction on j - i.
let rec sigma_feq (i j : int) (f g : interval i (j + 1) -> int)
: Lemma
  (requires i <= j /\ F.feq f g)
  (ensures sigma i j f == sigma i j g)
  (decreases (j - i))
= if i < j
  then sigma_feq i (j - 1) f g  // step: sums over [i, j - 1] agree, then add f j == g j
  else ()                       // i = j: both sums are the single term at i

// Pointwise scaling of e by c, i.e. the summand fun k -> c * e k, given a
// top-level name. Naming the lambda is the workaround for the lambda-literal
// failure shown by the expect_failure version of sigma_const_dist.
let name_it (#i #j:int) (e:interval i (j + 1) -> int) (c:nat)
: interval i (j + 1) -> int
= fun (k: interval i (j + 1)) -> c * e k

// Working version of constant distribution over sigma. The scaled summand is
// the named function name_it rather than a lambda literal. The refinement
// j: int { i <= j } plays the role of the precondition. The proof is by
// induction on j - i, spelled out as a calc chain.
let rec sigma_const_dist_aux (i: int) (j: int { i <= j }) (e: interval i (j + 1) -> int) (c: nat)
: Lemma
  (ensures sigma i j (name_it #i #j e c) == c * (sigma i j e))
  (decreases j - i)
= if i = j
  // Base case: both sides reduce to c * e i.
  then ()
  else begin
    calc (==) {
      sigma i j (name_it #i #j e c);
    // Unfold one step of sigma (here i < j), exposing the top term c * e j.
    (==) {}
      sigma i (j - 1) (name_it #i #j e c) + c * e j;
    // Switch name_it's implicit upper index from j to j - 1, so the summand
    // matches the one in the induction hypothesis. sigma_feq justifies this
    // because both versions compute c * e k at every k.
    (==) { sigma_feq i (j - 1) (name_it #i #j e c) (name_it #i #(j - 1) e c) }
      sigma i (j - 1) (name_it #i #(j - 1) e c) + c * e j;
    // Induction hypothesis on [i, j - 1].
    (==) { sigma_const_dist_aux i (j - 1) e c }
      c * sigma i (j - 1) e + c * e j;
    }
    // The remaining step, from this last line to the postcondition
    // c * (sigma i j e), is left to the SMT solver.
   end

// Constant distribution over sigma, stated with a lambda literal as the
// summand, as the failed attempt above wanted. It follows directly from
// sigma_const_dist_aux, whose statement uses the named summand name_it.
let sigma_const_dist (i: int) (j: int { i <= j }) (e: interval i (j + 1) -> int) (c: nat)
: Lemma (sigma i j (fun k -> c * e k) == c * sigma i j e)
= sigma_const_dist_aux i j e c
'''
Clone this wiki locally