(* week-06_reasoning-about-lambda-dropped-functions.v *)
(* LPP 2025 - CS3234 2024-2025, Sem2 *)
(* Olivier Danvy <olivier@comp.nus.edu.sg> *)
(* Version of 21 Feb 2025 *)

(* ********** *)

(* fold_unfold_tactic name:
   introduces all the universally quantified variables of the goal,
   unfolds the constant name, folds it back,
   and closes the resulting goal with reflexivity.
   It is meant for goals whose two sides coincide once name has been
   unfolded (e.g., the fold-unfold equations of a recursive function). *)
Ltac fold_unfold_tactic name :=
  intros;
  unfold name;
  fold name;
  reflexivity.

Require Import Arith Bool List.

(* ********** *)

(* add_acc n m adds n to m with an accumulator.
   The local loop is tail-recursive:
   it moves one successor at a time from its counter i
   to its accumulator a until the counter reaches 0,
   at which point it returns the accumulator.
   The loop is started with the counter n and the accumulator m. *)

Definition add_acc (n m : nat) : nat :=
  let fix loop (i a : nat) {struct i} : nat :=
    match i with
    | 0 =>
      a
    | S i' =>
      loop i' (S a)
    end
  in loop n m.

(* ***** *)

(* 0 is right-neutral for add_acc: add_acc n 0 = n.

   add_acc is lambda-dropped (its worker loop is local to it), so the
   proof first names the loop with remember, then states the two
   fold-unfold lemmas of the loop, and then proves a Eureka lemma,
   about_loop, that pulls the accumulator out of the loop.  The Eureka
   lemma is what lets the induction hypothesis (which is about the
   accumulator 0) be used in the inductive case (which features the
   accumulator 1). *)

Lemma O_is_right_neutral_for_add_acc :
  forall n : nat,
    add_acc n 0 = n.
Proof.
  unfold add_acc.
  remember (fix loop (n0 a : nat) {struct n0} : nat := match n0 with
                                              | 0 => a
                                              | S n' => loop n' (S a)
                                              end)
    as loop eqn:H_loop.
  assert (fold_unfold_loop_O :
            forall a : nat,
              loop 0 a = a).
  { intro a.
    rewrite -> H_loop.
    reflexivity. }
  assert (fold_unfold_loop_S :
            forall n' a : nat,
              loop (S n') a = loop n' (S a)).
  { intros n' a.
    rewrite -> H_loop.
    reflexivity. }
  (* Eureka lemma: the accumulator can be pulled out of the loop. *)
  assert (about_loop :
            forall n a : nat,
              loop n a = loop n 0 + a).
  { intro n'.
    induction n' as [ | n'' IHn'']; intro a.
    - rewrite -> (fold_unfold_loop_O a).
      rewrite -> (fold_unfold_loop_O 0).
      exact (Nat.add_0_l a).
    - rewrite -> (fold_unfold_loop_S n'' a).
      rewrite -> (fold_unfold_loop_S n'' 0).
      rewrite -> (IHn'' (S a)).
      rewrite -> (IHn'' 1).
      (* Goal: loop n'' 0 + S a = loop n'' 0 + 1 + a *)
      rewrite <- (Nat.add_assoc (loop n'' 0) 1 a).
      rewrite -> (Nat.add_1_l a).
      reflexivity. }
  intro n.
  induction n as [ | n' IHn'].
  - rewrite -> (fold_unfold_loop_O 0).
    reflexivity.
  - rewrite -> (fold_unfold_loop_S n' 0).
    rewrite -> (about_loop n' 1).
    rewrite -> IHn'.
    exact (Nat.add_1_r n').
Qed.

(* ***** *)

(* add_acc is associative.

   The proof names the local loop of add_acc with remember, states its
   two fold-unfold lemmas, and proves a Eureka lemma, about_loop_S:
   incrementing the result of the loop is the same as running the loop
   with a counter that is one larger.  The main proof is by induction on
   n1, with n2 and n3 kept universally quantified so that the induction
   hypothesis can be instantiated with S n2. *)

Lemma add_acc_is_associative :
  forall n1 n2 n3 : nat,
    add_acc n1 (add_acc n2 n3) = add_acc (add_acc n1 n2) n3.
Proof.
  unfold add_acc.
  remember (fix loop (n0 a : nat) {struct n0} : nat := match n0 with
                                              | 0 => a
                                              | S n' => loop n' (S a)
                                              end)
    as loop eqn:H_loop.
  assert (fold_unfold_loop_O :
            forall a : nat,
              loop 0 a = a).
  { intro a.
    rewrite -> H_loop.
    reflexivity. }
  assert (fold_unfold_loop_S :
            forall n' a : nat,
              loop (S n') a = loop n' (S a)).
  { intros n' a.
    rewrite -> H_loop.
    reflexivity. }
  (* Eureka lemma: a successor outside the loop can be moved to its counter.
     It does not depend on the main induction, so it is proved once, here. *)
  assert (about_loop_S :
            forall x y : nat,
              S (loop x y) = loop (S x) y).
  { intro x.
    induction x as [ | x' IHx'].
    - intro y.
      rewrite -> (fold_unfold_loop_O y).
      rewrite -> (fold_unfold_loop_S 0 y).
      rewrite -> (fold_unfold_loop_O (S y)).
      reflexivity.
    - intro y.
      rewrite -> (fold_unfold_loop_S x' y).
      rewrite -> (fold_unfold_loop_S (S x') y).
      exact (IHx' (S y)). }
  intro n1.
  induction n1 as [ | n1' IHn1'].
  - intros n2 n3.
    rewrite -> (fold_unfold_loop_O (loop n2 n3)).
    rewrite -> (fold_unfold_loop_O n2).
    reflexivity.
  - intros n2 n3.
    rewrite -> (fold_unfold_loop_S n1' (loop n2 n3)).
    rewrite -> (fold_unfold_loop_S n1' n2).
    rewrite <- (IHn1' (S n2) n3).
    (* Goal: loop n1' (S (loop n2 n3)) = loop n1' (loop (S n2) n3) *)
    rewrite -> (about_loop_S n2 n3).
    reflexivity.
Qed.

(* ********** *)

(* power x n raises x to the n-th power with an accumulator.
   The local loop is tail-recursive:
   starting from the accumulator 1, it multiplies the accumulator by x
   once per step while counting i down from n to 0,
   and then returns the accumulator (so power x 0 = 1).
   The multiplication is written x * a (x on the left); the proof of
   about_exponentiating_with_a_sum remembers the loop in exactly that form. *)

Definition power (x n : nat) : nat :=
  let fix loop (i a : nat) {struct i} : nat :=
    match i with
    | 0 =>
      a
    | S i' =>
      loop i' (x * a)
    end
  in loop n 1.

(* Exponentiating with a sum: power x (n1 + n2) = power x n1 * power x n2.

   The local loop of power depends on x, so x is introduced before the
   loop is named with remember.  After the two fold-unfold lemmas of the
   loop, a Eureka lemma (eureka) factors the accumulator out of the loop:
   loop n a = loop n 1 * a.  The main proof is by induction on n1, with
   n2 kept universally quantified. *)

Proposition about_exponentiating_with_a_sum :
  forall x n1 n2 : nat,
    power x (n1 + n2) = power x n1 * power x n2.
Proof.
  unfold power.
  intro x.
  remember (fix loop (i a : nat) {struct i} : nat := match i with
                                                     | 0 => a
                                                     | S i' => loop i' (x * a)
                                                     end)
    as loop eqn:H_loop.
  assert (fold_unfold_loop_O :
            forall a : nat,
              loop 0 a = a).
  { intro a.
    rewrite -> H_loop.
    reflexivity. }
  assert (fold_unfold_loop_S :
            forall i' a : nat,
              loop (S i') a = loop i' (x * a)).
  { intros i' a.
    rewrite -> H_loop.
    reflexivity. }
  (* Eureka lemma: the accumulator can be factored out of the loop. *)
  assert (eureka :
            forall n a : nat,
              loop n a = loop n 1 * a).
  { intro n.
    induction n as [ | n' IHn'].
    - intro a.
      rewrite -> (fold_unfold_loop_O a).
      rewrite -> (fold_unfold_loop_O 1).
      symmetry.
      exact (Nat.mul_1_l a).
    - intro a.
      rewrite -> (fold_unfold_loop_S n' a).
      rewrite -> (fold_unfold_loop_S n' 1).
      rewrite -> (Nat.mul_1_r x).
      rewrite -> (IHn' (x * a)).
      rewrite -> (IHn' x).
      (* Goal: loop n' 1 * (x * a) = loop n' 1 * x * a *)
      exact (Nat.mul_assoc (loop n' 1) x a). }
  intro n1.
  induction n1 as [ | n1' IHn1'].
  - intro n2.
    rewrite -> (Nat.add_0_l n2).
    rewrite -> (fold_unfold_loop_O 1).
    symmetry.
    exact (Nat.mul_1_l (loop n2 1)).
  - intro n2.
    rewrite -> (Nat.add_succ_l n1' n2).
    rewrite -> (fold_unfold_loop_S (n1' + n2) 1).
    rewrite -> (Nat.mul_1_r x).
    rewrite -> (fold_unfold_loop_S n1' 1).
    rewrite -> (Nat.mul_1_r x).
    rewrite -> (eureka (n1' + n2) x).
    rewrite -> (eureka n1' x).
    rewrite -> (IHn1' n2).
    (* Goal: loop n1' 1 * loop n2 1 * x = loop n1' 1 * x * loop n2 1 *)
    rewrite -> (Nat.mul_comm (loop n1' 1 * loop n2 1) x).
    rewrite -> (Nat.mul_comm (loop n1' 1) x).
    (* Goal: x * (loop n1' 1 * loop n2 1) = x * loop n1' 1 * loop n2 1 *)
    exact (Nat.mul_assoc x (loop n1' 1) (loop n2 1)).
Qed.

(* ********** *)

(* end of week-06_reasoning-about-lambda-dropped-functions.v *)
