You might remember my post last year about lazy infinite lists in Lean, back when I didn’t know how to write Lean. I still don’t know how to write Lean, but today we’ll try implementing infinite lists again, this time without cheating with unsafe or partial.
People often say Lean prevents you from doing infinite recursion. That’s not exactly true. What Lean really wants to prevent you from doing is this:
theorem bad : 2 + 2 = 5 := by
  exact bad
Basically, dangerous infinite recursion that can be used to prove False is banned. But there are also non-dangerous forms of infinite recursion, such as a while loop inside a do block (which is internally implemented using partial). Another way to do infinite recursion is to use partial_fixpoint.
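For concreteness, here’s a minimal core-Lean sketch of both kinds of harmless infinite recursion. No Coinductive is needed, and countDown is a made-up name. The while loop is accepted because it’s compiled via partial under the hood. f91 (McCarthy’s 91 function, which shows up again at the end of this post) is accepted by partial_fixpoint without any termination proof, because its result type Option Nat has a least element, none, that stands for “didn’t terminate”.

```lean
-- A `while` loop in `do` notation: no termination proof needed, because the
-- loop combinator is implemented with `partial` under the hood.
def countDown (n : Nat) : IO Unit := do
  let mut i := n
  while i > 0 do
    IO.println i
    i := i - 1

-- McCarthy's 91 function via `partial_fixpoint`. The nested recursive call is
-- fine because `Option` has a bottom element (`none`) and `bind` is monotone.
def f91 (n : Nat) : Option Nat := do
  if n > 100 then
    pure (n - 10)
  else
    f91 (← f91 (n + 11))
partial_fixpoint

#eval countDown 3  -- prints 3, 2, 1 on separate lines
#eval f91 42       -- some 91
```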
I recently learned about the library Coinductive by Michael Sammler, as well as Alex Keizer’s fork that makes it computable; together they make partial_fixpoint much more ergonomic to use. First, we’ll import some stuff and do some polynomial functor magic. (I’m using Willem Vanhulle’s Lean formatter, so the code style probably looks slightly weird.)
import Batteries.Data.Rat.Float
import Coinductive
namespace Stream
open Coinductive Lean.Order
inductive StreamF (α : Type u) (Stream : Type u) : Type u where
| snil
| scons (x : α) (tl : Thunk Stream)
inductive StreamF.In (α : Type u) : Type u where
| snil
| scons (x : α)
@[simp, grind =]
theorem Thunk.get_mk (fn : Unit → α) : Thunk.get ⟨fn⟩ = fn () := by rfl
@[simp]
theorem Thunk.mk_get (x : Thunk α) : Thunk.mk (fun _ ↦ x.get) = x := by simp [Thunk.ext_iff]
instance (α : Type u) : PF (StreamF α)
  where
  P :=
    ⟨StreamF.In α, fun
      | .snil => PEmpty
      | .scons _ => PUnit⟩
  unpack
    | .snil => .obj (.snil) nofun
    | .scons hd tl => .obj (.scons hd) fun _ ↦ tl.get
  pack
    | .obj (.snil) _ => .snil
    | .obj (.scons hd) tl => .scons hd (tl ⟨⟩)
  unpack_pack := by rintro _ (_ | _) <;> simp
  pack_unpack := by rintro _ (⟨⟨⟩, _⟩ | ⟨⟨⟩⟩) <;> simp <;> funext x <;> cases x
abbrev Stream (α : Type u) : Type u :=
  CoInd (StreamF α)
def Stream.fold (t : StreamF α (Stream α)) : Stream α :=
  CoInd.fold _ t
def Stream.snil : Stream α :=
  .fold .snil
def Stream.scons (hd : α) (tl : Thunk (Stream α)) : Stream α :=
  .fold <| .scons hd tl
instance : Inhabited (StreamF α PUnit) where default := .snil
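If the PF instance looks like magic: a polynomial functor is described by a type of shapes (here StreamF.In α, i.e. either snil or scons x) and, for each shape, a type of positions where the children go (PEmpty for snil, PUnit for scons). unpack and pack convert between StreamF and this shape-plus-children form. The sketch below replays that idea with a hypothetical Container structure and a strict (Thunk-free) ListF. These names are made up for illustration, and this is not how the Coinductive library actually defines PF.

```lean
namespace ContainerSketch

-- Hypothetical stand-in for a polynomial functor: a type of shapes `A`, and
-- for each shape `a`, a type `B a` of positions where children are stored.
structure Container where
  A : Type
  B : A → Type

-- A value of the functor at `ρ`: one shape plus a child for every position.
structure Container.Obj (P : Container) (ρ : Type) where
  shape : P.A
  children : P.B shape → ρ

-- Strict (no `Thunk`) analogues of `StreamF` and `StreamF.In`.
inductive ListF (α ρ : Type) where
  | snil
  | scons (x : α) (tl : ρ)

inductive Shape (α : Type) where
  | snil
  | scons (x : α)

-- `snil` has no positions (`Empty`); `scons x` has exactly one (`Unit`).
def Shape.positions : Shape α → Type
  | .snil => Empty
  | .scons _ => Unit

def listC (α : Type) : Container :=
  { A := Shape α, B := Shape.positions }

def unpack : ListF α ρ → (listC α).Obj ρ
  | .snil => ⟨Shape.snil, fun e => Empty.elim e⟩
  | .scons x tl => ⟨Shape.scons x, fun _ => tl⟩

def pack : (listC α).Obj ρ → ListF α ρ
  | ⟨Shape.snil, _⟩ => .snil
  | ⟨Shape.scons x, k⟩ => .scons x (k ())

-- One of the two round-trip laws holds by computation.
example (s : ListF α ρ) : pack (unpack s) = s := by
  cases s <;> rfl

end ContainerSketch
```

The other round trip needs funext plus a case split on the empty position type, which is what the `funext x <;> cases x` at the end of pack_unpack is doing.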
Cool, so now we can write our Stream.map function?
def Stream.map (f : α → β) (s : Stream α) : Stream β :=
  match s.unfold with
  | .snil => .snil
  | .scons hd tl => .scons (f hd) (Stream.map f tl.get)
partial_fixpoint
Not so fast! Lean throws an error:
Could not prove 'Stream.Stream.map' to be monotone in its recursive calls:
Cannot eliminate recursive call `Stream.map f tl.get` enclosed in
scons (f hd✝) { fn := fun x => map f tl✝.get }
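I won’t delve into the theory here, but basically Lean will only accept this definition of Stream.map if it can show that the body is monotone in the recursive call with respect to the partial order on the Stream type. Since the recursive call sits under Stream.scons, the key lemma is that Stream.scons is monotone in its tail (scons_mono below); map_mono then lets later definitions, like pos, make recursive calls under Stream.map. Here’s the full proof.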
@[simp]
theorem unfold_snil : CoInd.unfold _ Stream.snil = StreamF.snil (α := α) := by simp [Stream.snil, Stream.fold]
@[simp]
theorem unfold_scons : CoInd.unfold _ (Stream.scons i s) = StreamF.scons i s := by simp [Stream.scons, Stream.fold]
@[simp]
theorem Stream.bot_eq : CoInd.bot (StreamF α) = Stream.snil :=
  by
    rw [CoInd.bot_eq]
    simp [PF.map, PF.pack, Stream.snil, Stream.fold]
theorem Stream.le_unfold (s1 s2 : Stream α) :
    (s1 ⊑ s2) = (s1 = .snil ∨ ∃ i s1' s2', s1 = .scons i s1' ∧ s2 = .scons i s2' ∧ s1'.get ⊑ s2'.get) :=
  by
    ext
    constructor
    · intro h
      rw [CoInd.le_unfold] at h
      rcases h with (rfl | ⟨i, _, _, _, _, h1, h2⟩); simp
      rw [← unfold_fold _ s1, ← unfold_fold _ s2]
      rw [← PF.unpack_pack s1.unfold, ← PF.unpack_pack s2.unfold]
      simp only [h1, h2]
      cases i <;> simp [PF.pack, snil, scons, fold]
      right
      exists ?_, ?_; rotate_left 1
      constructor; rfl
      apply Exists.intro
      constructor; rfl
      simp_all
    · rintro (rfl | ⟨_, _, _, rfl, rfl, _⟩)
      · simp [CoInd.le_unfold]
      · simp [CoInd.le_unfold]
        right
        simp [PF.unpack]
        constructor <;> try rfl
        grind
theorem scons_monoN α i (s1 s2 : Thunk (Stream α)) n :
    CoIndN.le _ (s1.get.approx n) (s2.get.approx n) →
    CoIndN.le _ ((Stream.scons i s1).approx (n + 1)) ((Stream.scons i s2).approx (n + 1)) :=
  by
    intro hs
    simp [CoIndN.le, PF.unpack]
    right
    constructor <;> try rfl
    grind [coherent1]
@[partial_fixpoint_monotone]
theorem scons_mono [PartialOrder β] i (f : β → Stream α) : monotone f → monotone fun x ↦ Stream.scons i ⟨fun _ ↦ f x⟩ :=
  by
    intro hf t1 t2 hle
    apply CoInd.le_leN
    rintro ⟨n⟩; simp [CoIndN.le]
    apply scons_monoN
    grind [CoInd.leN_le, monotone]
def Stream.map (f : α → β) (s : Stream α) : Stream β :=
  match s.unfold with
  | .snil => .snil
  | .scons hd tl => .scons (f hd) (Stream.map f tl.get)
partial_fixpoint
@[partial_fixpoint_monotone]
theorem map_mono [PartialOrder γ] (f : α → β) (g : γ → Stream α) : monotone g → monotone fun x ↦ Stream.map f (g x) :=
  by
    intro hf t1 t2 hle
    apply CoInd.le_leN
    intro n
    dsimp only
    have hs : (g t1) ⊑ (g t2) := by grind [monotone]
    generalize g t1 = s1, g t2 = s2 at hs
    induction n generalizing s1 s2 with
    | zero => simp [CoIndN.le]
    | succ n ih =>
      unfold Stream.map
      rw [Stream.le_unfold] at hs
      rcases hs with rfl | ⟨hd, tl1, tl2, rfl, rfl, htl⟩
      · simp [CoIndN.le, CoIndN.bot]
      · simp_all
        simp [CoIndN.le, PF.unpack]
        right
        constructor <;> try rfl
        intro x
        exact ih _ _ htl
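The order characterized by Stream.le_unfold says s1 ⊑ s2 exactly when s1 is snil, or both are scons with the same head and the tail of s1 is ⊑ the tail of s2. So snil plays the role of “not computed yet”, and ⊑ means “is a truncation of”. On ordinary finite lists the analogous relation is the prefix order. Here’s a small sketch of it (prefixLe is a hypothetical, Bool-valued helper so it’s easy to #eval), together with the property map_mono captures: mapping over a prefix gives a prefix.

```lean
namespace PrefixSketch

-- `prefixLe xs ys` holds iff `xs` is a prefix of `ys`: the finite analogue of
-- the order in `Stream.le_unfold`, with `[]` playing the role of `snil`.
def prefixLe [BEq α] : List α → List α → Bool
  | [], _ => true
  | x :: xs, y :: ys => x == y && prefixLe xs ys
  | _ :: _, [] => false

#eval prefixLe ([] : List Nat) [5]  -- true: `[]` is below everything
#eval prefixLe [1, 2] [1, 2, 3]     -- true
#eval prefixLe [1, 3] [1, 2, 3]     -- false

-- Monotonicity of `map` (cf. `map_mono`): a prefix is mapped to a prefix.
#eval prefixLe ([1, 2].map (· + 1)) ([1, 2, 3].map (· + 1))  -- true

end PrefixSketch
```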
We can now do our favorite trick for defining cosine (I omitted the definitions of Stream.shead, Stream.stail, and Stream.zipWith, as well as a few of the monotonicity theorems, for brevity):
def Stream.take (s : Stream α) : Nat → List (Option α)
  | 0 => []
  | n + 1 => s.shead :: s.stail.take n
def pos :=
  Stream.scons (1 : Rat) <| pos.map (· + 1)
partial_fixpoint
#eval pos.take 10
notation s "integrate" c => Stream.scons c <| Stream.zipWith (· / ·) s pos
#eval (1 : Rat) / 3
def expSeries :=
  expSeries integrate 1
partial_fixpoint
#eval expSeries.take 10
def evalAt n (s : Stream Rat) (x : Rat) :=
  s.take n |>.foldr
    (fun a acc ↦
      match a with
      | none => acc * x
      | some a => a + acc * x)
    0
#eval (evalAt 10 expSeries 2 : Float)
#eval Float.exp 2
mutual
def sinSeries :=
  cosSeries integrate (0 : Rat)
partial_fixpoint
def cosSeries :=
  (sinSeries integrate (-1)).map (-·)
partial_fixpoint
end
#eval (evalAt 10 sinSeries 2 : Float)
#eval Float.sin 2
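The integrate notation boils down to simple coefficient recurrences. Since pos is 1, 2, 3, and so on, s integrate c has constant term c, and its coefficient n + 1 is sₙ / (n + 1). So exp = 1 + ∫ exp gives aₙ₊₁ = aₙ / (n + 1), while sin and cos are defined in terms of each other. To sanity-check those recurrences, here’s a sketch that swaps in a different encoding. The coefficients are stored as plain functions Nat → Float and defined by ordinary structural recursion, so it needs neither partial_fixpoint nor Rat. The names expCoeff, sinCoeff, and cosCoeff are made up, and this is not a replacement for the lazy streams above.

```lean
namespace SeriesSketch

-- Coefficients as functions `Nat → Float` (not the post's `Stream` type).
-- `s integrate c` means: coefficient 0 is `c`, coefficient `n + 1` is `s n / (n + 1)`.

-- exp = 1 + ∫ exp
def expCoeff : Nat → Float
  | 0 => 1
  | n + 1 => expCoeff n / (n + 1).toFloat

mutual
  -- sin = 0 + ∫ cos
  def sinCoeff : Nat → Float
    | 0 => 0
    | n + 1 => cosCoeff n / (n + 1).toFloat

  -- cos = -(-1 + ∫ sin), i.e. 1 - ∫ sin
  def cosCoeff : Nat → Float
    | 0 => 1
    | n + 1 => -(sinCoeff n / (n + 1).toFloat)
end

-- Horner evaluation of the first `n` coefficients at `x`, like `evalAt` above.
def evalAt (n : Nat) (c : Nat → Float) (x : Float) : Float :=
  (List.range n).foldr (fun i acc => c i + acc * x) 0

#eval evalAt 10 expCoeff 2  -- compare with `Float.exp 2`
#eval evalAt 10 sinCoeff 2  -- compare with `Float.sin 2`
#eval evalAt 10 cosCoeff 2  -- compare with `Float.cos 2`

end SeriesSketch
```

With ten terms at x = 2, the truncation error of each of these series is well below 0.001, so each pair of numbers should agree to about three decimal places.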
Yay, no more unsafe or partial! But there’s a big problem with this approach: it’s extremely slow. A simple pos.take 25 already takes 30 seconds, and the running time seems to grow exponentially as we increase that number. Meanwhile, the same code in Haskell runs in linear time:
pos :: [Integer]
pos = 1 : map (1+) pos
main = print (take 2500 pos)
What if, instead of the type of possibly infinite lists, we used the type of only infinite lists? Then we couldn’t use partial_fixpoint directly to define recursive functions on this type, because its partial order has no bottom element. Fortunately, there’s a way around this, similar to the trick used to define F91 in Lean: we can prove that Stream.map preserves infiniteness and restrict the input and output types of Stream.map to infinite lists only. I haven’t figured out how to implement this idea yet, but maybe it’ll be in a future blog post.