You might remember my post last year about lazy infinite lists in Lean, back when I didn’t know how to write Lean. I still don’t know how to write Lean, but today we’ll try implementing infinite lists again, this time without cheating with unsafe or partial.

People often say Lean prevents you from doing infinite recursion. That’s not exactly true. What Lean really wants to prevent you from doing is this:

theorem bad : 2 + 2 = 5 := by
  exact bad

Basically, dangerous infinite recursion that can be used to prove false is banned. But there are also non-dangerous forms of infinite recursion, such as a while loop inside a do block (which is internally implemented using partial). Another way to do infinite recursion is by using partial_fixpoint.
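
To make the contrast with bad concrete, here is a minimal sketch of a possibly non-terminating definition that partial_fixpoint accepts. The name collatzSteps is made up for illustration. It relies on my understanding that tail-recursive functions returning Option qualify for partial_fixpoint without any extra proof.

```lean
-- Hypothetical example: count Collatz steps until we reach 1.
-- Nobody knows whether this terminates for every n, so structural or
-- well-founded recursion is not an option here.
def collatzSteps (n acc : Nat) : Option Nat :=
  if n ≤ 1 then some acc
  else if n % 2 = 0 then collatzSteps (n / 2) (acc + 1)
  else collatzSteps (3 * n + 1) (acc + 1)
partial_fixpoint

#eval collatzSteps 6 0
```

Unlike bad, this should not let us prove anything false: as far as I understand, a run that never stops is simply identified with the bottom element of Option's order, which is none.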

I recently learned about the library Coinductive by Michael Sammler, as well as Alex Keizer’s fork to make it computable, which makes partial_fixpoint much more ergonomic to use. First, we’ll import some stuff and do some polynomial functor magic (I’m using Willem Vanhulle’s Lean formatter, so the code style probably looks slightly weird):

import Batteries.Data.Rat.Float
import Coinductive

namespace Stream
open Coinductive Lean.Order

inductive StreamF (α : Type u) (Stream : Type u) : Type u where
| snil
| scons (x : α) (tl : Thunk Stream)

inductive StreamF.In (α : Type u) : Type u where
| snil
| scons (x : α)

@[simp, grind =]
theorem Thunk.get_mk (fn : Unit → α) : Thunk.get ⟨fn⟩ = fn () := by rfl

@[simp]
theorem Thunk.mk_get (x : Thunk α) : Thunk.mk (fun _ ↦ x.get) = x := by simp [Thunk.ext_iff]

instance (α : Type u) : PF (StreamF α)
    where
  P :=
    ⟨StreamF.In α, fun
      | .snil => PEmpty
      | .scons _ => PUnit⟩
  unpack
    | .snil => .obj (.snil) nofun
    | .scons hd tl => .obj (.scons hd) fun _ ↦ tl.get
  pack
    | .obj (.snil) _ => .snil
    | .obj (.scons hd) tl => .scons hd (tl ⟨⟩)
  unpack_pack := by rintro _ (_ | _) <;> simp
  pack_unpack := by rintro _ (⟨⟨⟩, _⟩ | ⟨⟨⟩⟩) <;> simp <;> funext x <;> cases x

abbrev Stream (α : Type u) : Type u :=
  CoInd (StreamF α)

def Stream.fold (t : StreamF α (Stream α)) : Stream α :=
  CoInd.fold _ t

def Stream.snil : Stream α :=
  .fold .snil

def Stream.scons (hd : α) (tl : Thunk (Stream α)) : Stream α :=
  .fold <| .scons hd tl

instance : Inhabited (StreamF α PUnit) where default := .snil

Cool, so now we can write our Stream.map function?

def Stream.map (f : α → β) (s : Stream α) : Stream β :=
  match s.unfold with
  | .snil => .snil
  | .scons hd tl => .scons (f hd) (Stream.map f tl.get)
partial_fixpoint

Not so fast! Lean throws an error:

Could not prove 'Stream.Stream.map' to be monotone in its recursive calls:
  Cannot eliminate recursive call `Stream.map f tl.get` enclosed in
    scons (f hd✝) { fn := fun x => map f tl✝.get }

I won’t delve into the theory here, but basically Lean will only accept this definition of Stream.map if we show that it preserves the partial order on the Stream type. Here’s the full proof.

@[simp]
theorem unfold_snil : CoInd.unfold _ Stream.snil = StreamF.snil (α := α) := by simp [Stream.snil, Stream.fold]

@[simp]
theorem unfold_scons : CoInd.unfold _ (Stream.scons i s) = StreamF.scons i s := by simp [Stream.scons, Stream.fold]

@[simp]
theorem Stream.bot_eq : CoInd.bot (StreamF α) = Stream.snil :=
  by
  rw [CoInd.bot_eq]
  simp [PF.map, PF.pack, Stream.snil, Stream.fold]

theorem Stream.le_unfold (s1 s2 : Stream α) :
    (s1 ⊑ s2) = (s1 = .snil ∨ ∃ i s1' s2', s1 = .scons i s1' ∧ s2 = .scons i s2' ∧ s1'.get ⊑ s2'.get) :=
  by
  ext
  constructor
  · intro h
    rw [CoInd.le_unfold] at h
    rcases h with (rfl | ⟨i, _, _, _, _, h1, h2⟩); simp
    rw [← unfold_fold _ s1, ← unfold_fold _ s2]
    rw [← PF.unpack_pack s1.unfold, ← PF.unpack_pack s2.unfold]
    simp only [h1, h2]
    cases i <;> simp [PF.pack, snil, scons, fold]
    right
    exists ?_, ?_; rotate_left 1
    constructor; rfl
    apply Exists.intro
    constructor; rfl
    simp_all
  · rintro (rfl | ⟨_, _, _, rfl, rfl, _⟩)
    · simp [CoInd.le_unfold]
    · simp [CoInd.le_unfold]
      right
      simp [PF.unpack]
      constructor <;> try rfl
      grind

theorem scons_monoN α i (s1 s2 : Thunk (Stream α)) n :
    CoIndN.le _ (s1.get.approx n) (s2.get.approx n) →
      CoIndN.le _ ((Stream.scons i s1).approx (n + 1)) ((Stream.scons i s2).approx (n + 1)) :=
  by
  intro hs
  simp [CoIndN.le, PF.unpack]
  right
  constructor <;> try rfl
  grind [coherent1]

@[partial_fixpoint_monotone]
theorem scons_mono [PartialOrder β] i (f : β → Stream α) : monotone f → monotone fun x ↦ Stream.scons i ⟨fun _ ↦ f x⟩ :=
  by
  intro hf t1 t2 hle
  apply CoInd.le_leN
  rintro ⟨n⟩; simp [CoIndN.le]
  apply scons_monoN
  grind [CoInd.leN_le, monotone]

def Stream.map (f : α → β) (s : Stream α) : Stream β :=
  match s.unfold with
  | .snil => .snil
  | .scons hd tl => .scons (f hd) (Stream.map f tl.get)
partial_fixpoint

@[partial_fixpoint_monotone]
theorem map_mono [PartialOrder γ] (f : α → β) (g : γ → Stream α) : monotone g → monotone fun x ↦ Stream.map f (g x) :=
  by
  intro hf t1 t2 hle
  apply CoInd.le_leN
  intro n
  dsimp only
  have hs : g t1 ⊑ g t2 := by grind [monotone]
  generalize g t1 = s1, g t2 = s2 at hs
  induction n generalizing s1 s2 with
  | zero => simp [CoIndN.le]
  | succ n ih =>
    unfold Stream.map
    rw [Stream.le_unfold] at hs
    rcases hs with rfl | ⟨hd, tl1, tl2, rfl, rfl, htl⟩
    · simp [CoIndN.le, CoIndN.bot]
    · simp_all
      simp [CoIndN.le, PF.unpack]
      right
      constructor <;> try rfl
      intro x
      exact ih _ _ htl
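
For readers wondering what the monotone in these statements unpacks to, here is a small sketch that does not depend on the Coinductive library. It assumes that Lean.Order.monotone f is defined as ∀ x y, x ⊑ y → f x ⊑ f y, which is how I read the core library. The examples are mine, not part of the proof above.

```lean
open Lean.Order

-- Assumption: `monotone f` unfolds to ∀ x y, x ⊑ y → f x ⊑ f y.
-- The identity function is monotone for trivial reasons.
example {α : Type} [PartialOrder α] : monotone (fun x : α => x) :=
  fun _ _ h => h

-- Using a monotonicity hypothesis: related inputs give related outputs.
example {α β : Type} [PartialOrder α] [PartialOrder β]
    (f : α → β) (hf : monotone f) (x y : α) (h : x ⊑ y) : f x ⊑ f y :=
  hf _ _ h
```

The theorems tagged with partial_fixpoint_monotone have exactly this shape: given that the recursive call is monotone, the expression wrapped around it is monotone too.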

We can now do our favorite trick for defining cosine (I omitted a few of the monotonicity theorems for brevity):

def Stream.take (s : Stream α) : Nat → List (Option α)
  | 0 => []
  | n + 1 => s.shead :: s.stail.take n

def pos :=
  Stream.scons (1 : Rat) <| pos.map (· + 1)
partial_fixpoint

#eval pos.take 10

notation s "integrate" c => Stream.scons c <| Stream.zipWith (· / ·) s pos

#eval (1 : Rat) / 3

def expSeries :=
  expSeries integrate 1
partial_fixpoint

#eval expSeries.take 10

def evalAt n (s : Stream Rat) (x : Rat) :=
  s.take n |>.foldr
    (fun a acc ↦
      match a with
      | none => acc * x
      | some a => a + acc * x)
    0

#eval (evalAt 10 expSeries 2 : Float)
#eval Float.exp 2

mutual
def sinSeries :=
  cosSeries integrate (0 : Rat)
partial_fixpoint
def cosSeries :=
  (sinSeries integrate (-1)).map (-·)
partial_fixpoint
end

#eval (evalAt 10 sinSeries 2 : Float)
#eval Float.sin 2

Yay, no more unsafe or partial! But there’s a big problem with this approach: it’s extremely slow. A simple pos.take 25 already takes 30 seconds, and the running time seems to grow exponentially as we increase that number. Meanwhile, the same code in Haskell takes linear time:

pos = 1 : map (1+) pos
take 2500 pos

What if, instead of the type of possibly infinite lists, we used the type of only infinite lists? Then we couldn’t use partial_fixpoint directly to define recursive functions on this type, because its partial order has no bottom element. Fortunately, there’s a way around this, similar to the trick used to define F91 in Lean: we can prove that Stream.map preserves infiniteness and restrict its input and output types to only infinite lists. I haven’t figured out how to implement this idea yet, but maybe it’ll be in a future blog post.
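
For reference, here is a rough sketch of the first half of the F91 trick as I understand it. This is my reading, not necessarily the exact formulation alluded to above. The function is defined in Option, where a bottom element (none) exists, so partial_fixpoint applies.

```lean
-- McCarthy's 91 function, written in Option so that partial_fixpoint applies.
-- The nested recursive call goes through Option's bind via do-notation.
def f91 (n : Nat) : Option Nat :=
  if n > 100 then some (n - 10)
  else do f91 (← f91 (n + 11))
partial_fixpoint
```

The second half, not shown here, would be to prove that f91 n is always some r and use that proof to extract a total function. The analogue for streams would be to prove that Stream.map never produces snil when given an infinite input.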