Skip to content

Commit

Permalink
Provided an optimized GetEnumerator for scan
Browse files Browse the repository at this point in the history
  • Loading branch information
manofstick committed Jun 22, 2017
1 parent 513ce76 commit 709ce85
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 11 deletions.
96 changes: 86 additions & 10 deletions src/fsharp/FSharp.Core/iseq.fs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,20 @@ namespace Microsoft.FSharp.Collections
member __.GetHashCode o = c.GetHashCode o._1
member __.Equals (lhs,rhs) = c.Equals (lhs._1, rhs._1) }

[<AbstractClass>]
type PreferGetEnumerator<'T>() =
inherit EnumerableBase<'T>()

abstract GetEnumerator: unit -> IEnumerator<'T>
abstract GetSeq : unit -> ISeq<'T>

interface IEnumerable<'T> with
member this.GetEnumerator () : IEnumerator<'T> = this.GetEnumerator ()

interface ISeq<'T> with
member this.PushTransform<'U> (next:ITransformFactory<'T,'U>) : ISeq<'U> = (this.GetSeq()).PushTransform next
member this.Fold<'Result> (f:PipeIdx->Folder<'T,'Result>) : 'Result = (this.GetSeq()).Fold f

[<CompiledName "Empty">]
let empty<'T> = Microsoft.FSharp.Collections.SeqComposition.Core.EmptyEnumerable<'T>.Instance

Expand Down Expand Up @@ -592,17 +606,79 @@ namespace Microsoft.FSharp.Collections
let concat (sources:ISeq<#ISeq<'T>>) : ISeq<'T> =
upcast (ThinConcatEnumerable (sources, id))

(*
Represents the following seq comprehension, but they don't work at this level
seq {
let f = OptimizedClosures.FSharpFunc<_,_,_>.Adapt folder
let mutable state = initialState
yield state
for item in enumerable do
state <- f.Invoke (state, item)
yield state }
*)
type ScanEnumerator<'T,'State>(folder:'State->'T->'State, initialState:'State, enumerable:seq<'T>) =
let f = OptimizedClosures.FSharpFunc<_,_,_>.Adapt folder

let mutable state = 0 (*Pre-start*)
let mutable enumerator = Unchecked.defaultof<IEnumerator<'T>>
let mutable current = initialState

interface IEnumerator<'State> with
member this.Current: 'State =
match state with
| 0(*PreStart*) -> notStarted()
| 1(*GetEnumerator*) -> current
| 2(*MoveNext*) -> current
| _(*Finished*) -> alreadyFinished()

interface IEnumerator with
member this.Current : obj =
box (this:>IEnumerator<'State>).Current

member this.MoveNext () : bool =
match state with
| 0(*PreStart*) ->
state <- 1(*GetEnumerator*)
true
| 1(*GetEnumerator*) ->
enumerator <- enumerable.GetEnumerator ()
state <- 2(*MoveNext*)
(this:>IEnumerator).MoveNext ()
| 2(*MoveNext*) ->
if enumerator.MoveNext () then
current <- f.Invoke (current, enumerator.Current)
true
else
current <- Unchecked.defaultof<_>
state <- 3(*Finished*)
false
| _(*Finished*) -> alreadyFinished()

member this.Reset () : unit = noReset ()

interface IDisposable with
member this.Dispose(): unit =
if isNotNull enumerator then
enumerator.Dispose ()

[<CompiledName "Scan">]
let inline scan (folder:'State->'T->'State) (initialState:'State) (source:ISeq<'T>) :ISeq<'State> =
let head = singleton initialState
let tail =
source.PushTransform { new ITransformFactory<'T,'State> with
override __.Compose _ _ next =
upcast { new Transform<'T,'V,'State>(next, initialState) with
override this.ProcessNext (input:'T) : bool =
this.State <- folder this.State input
TailCall.avoid (next.ProcessNext this.State) } }
concat (ofList [ head ; tail ])
let scan (folder:'State->'T->'State) (initialState:'State) (source:ISeq<'T>) : ISeq<'State> =
upcast { new PreferGetEnumerator<'State>() with
member this.GetEnumerator () =
upcast new ScanEnumerator<'T,'State>(folder, initialState, source)

member this.GetSeq () =
let head = singleton initialState
let tail =
source.PushTransform { new ITransformFactory<'T,'State> with
override __.Compose _ _ next =
let f = OptimizedClosures.FSharpFunc<_,_,_>.Adapt folder
upcast { new Transform<'T,'V,'State>(next, initialState) with
override this.ProcessNext (input:'T) : bool =
this.State <- f.Invoke (this.State, input)
TailCall.avoid (next.ProcessNext this.State) } }
concat (ofList [ head ; tail ]) }

[<CompiledName "Skip">]
let skip (skipCount:int) (source:ISeq<'T>) : ISeq<'T> =
Expand Down
2 changes: 1 addition & 1 deletion src/fsharp/FSharp.Core/iseq.fsi
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ namespace Microsoft.FSharp.Collections
val inline reduce : f:('T->'T->'T) -> source:ISeq<'T> -> 'T

[<CompiledName "Scan">]
val inline scan : folder:('State->'T->'State) -> initialState:'State -> source:ISeq<'T> -> ISeq<'State>
val scan : folder:('State->'T->'State) -> initialState:'State -> source:ISeq<'T> -> ISeq<'State>

[<CompiledName "Skip">]
val skip : skipCount:int -> source:ISeq<'T> -> ISeq<'T>
Expand Down
3 changes: 3 additions & 0 deletions src/fsharp/FSharp.Core/seqcore.fsi
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ namespace Microsoft.FSharp.Collections.SeqComposition
abstract member Append : ISeq<'T> -> ISeq<'T>
abstract member Length : unit -> int
abstract member GetRaw : unit -> seq<'T>
default Append : ISeq<'T> -> ISeq<'T>
default Length : unit -> int
default GetRaw : unit -> seq<'T>
interface ISeq<'T>

[<AbstractClass>]
Expand Down

0 comments on commit 709ce85

Please sign in to comment.