-
Notifications
You must be signed in to change notification settings - Fork 236
Lambda Literals in SMT Proof
Brian Milnes edited this page Jan 20, 2025
·
1 revision
module NamingLambdaLiterals
open FStar.Mul
open FStar.IntegerIntervals
module F = FStar.FunctionalExtensionality
/// A simple sigma (ranged sum) whose lemmas are stated with lambda literals.
/// It shows the problems that lambda literals cause when proving even simple lemmas.
/// [sigma i j e] is the ranged sum e(i) + e(i+1) + ... + e(j).
/// The summand [e] is a function on [interval i (j + 1)].
/// When the range is empty (j < i) the sum is 0.
/// Termination: the width [j - i] shrinks on every recursive call.
let rec sigma (i: int) (j: int) (e: interval i (j + 1) -> int)
  : Tot int (decreases j - i)
  = if i <= j
    then begin
      if i = j
      then e i                        // one-element range
      else sigma i (j - 1) e + e j    // sum of i..j-1, plus the last summand
    end
    else 0                            // empty range
/// First attempt at the distributivity lemma, stated directly with the lambda
/// literal [fun k -> c * (e k)]:
///   sigma i j (fun k -> c * (e k)) = c * (sigma i j e)     for i <= j.
/// The [expect_failure [19]] attribute marks this definition as one that is
/// expected NOT to verify; the comment before the final assert says where it gets stuck.
[@@expect_failure [19]]
let rec sigma_const_dist (i: int) (j: int) (e: interval i (j + 1) -> int) (c: nat) :
Lemma (requires i <= j)
(ensures sigma i j (fun k -> c * (e k)) = c * (sigma i j e))
(decreases j - i)
= if j < i
// This branch contradicts the precondition i <= j.
then false_elim()
else if i = j
// One-element range: no explicit proof term is given.
then ()
else begin
// Induction hypothesis on the shorter range i..j-1.
sigma_const_dist i (j - 1) e c;
// F* cannot prove this assertion, for reasons related to lambda literals.
assert(sigma i j (fun k -> c * (e k)) =
sigma i (j - 1) (fun k -> c * (e k)) +
(c * e j))
end
/// Nik's solution here, more specific than section 9.4, is to add a
/// functional-extensionality lemma for sigma and then lift the lambda out
/// into a named function.
/// Extensionality for [sigma]: over a non-empty range (i <= j), two summands
/// related by [F.feq] have the same sum.
/// Proof: induction on the width [j - i], shortening the range from the right.
let rec sigma_feq (i j : int) (f g : interval i (j + 1) -> int)
  : Lemma
      (requires i <= j /\ F.feq f g)
      (ensures sigma i j f == sigma i j g)
      (decreases (j - i))
  = if i < j
    then sigma_feq i (j - 1) f g    // inductive step on the range i..j-1
    else ()                          // i = j: no explicit proof term needed
/// [name_it #i #j e c] is the summand [fun k -> c * (e k)] given a top-level name,
/// so that lemmas can refer to it by name instead of repeating the lambda literal.
/// The implicit arguments [#i] and [#j] select the interval [interval i (j + 1)]
/// that both [e] and the resulting function are defined on.
let name_it (#i #j:int) (e:interval i (j + 1) -> int) (c:nat)
: interval i (j + 1) -> int
= fun k -> c * (e k)
/// Distributivity of a constant factor over [sigma], stated with the named
/// summand [name_it #i #j e c] rather than a lambda literal:
///   sigma i j (name_it #i #j e c) == c * (sigma i j e)     for i <= j.
/// Proof: induction on the width [j - i], written as a calc chain.
let rec sigma_const_dist_aux (i: int) (j: int { i <= j }) (e: interval i (j + 1) -> int) (c: nat)
: Lemma
(ensures sigma i j (name_it #i #j e c) == c * (sigma i j e))
(decreases j - i)
= if i = j
// Base case (one-element range): no explicit proof term is given.
then ()
else begin
calc (==) {
sigma i j (name_it #i #j e c);
// Split off the last summand; no lemma is supplied for this step.
(==) {}
sigma i (j - 1) (name_it #i #j e c) + c * e j;
// Swap the summand indexed by j for the one indexed by j - 1, using sigma_feq
// on the shorter range.
(==) { sigma_feq i (j - 1) (name_it #i #j e c) (name_it #i #(j - 1) e c) }
sigma i (j - 1) (name_it #i #(j - 1) e c) + c * e j;
// Induction hypothesis on the range i..j-1.
(==) { sigma_const_dist_aux i (j - 1) e c }
c * sigma i (j - 1) e + c * e j;
}
// The chain stops here; no explicit step relates its last line to
// c * (sigma i j e) from the ensures clause.
end
/// Final form of the distributivity lemma, stated with the lambda literal
/// [fun k -> c * e k]:
///   sigma i j (fun k -> c * e k) == c * sigma i j e     for i <= j.
/// The proof is a single call to [sigma_const_dist_aux], which states the same
/// equation for the named summand [name_it #i #j e c].
let sigma_const_dist (i: int) (j: int { i <= j }) (e: interval i (j + 1) -> int) (c: nat)
: Lemma
(ensures sigma i j (fun k -> c * e k) == c * sigma i j e)
= sigma_const_dist_aux i j e c
'''