From bded4721d0906befb545191c4410f8a90d436b1a Mon Sep 17 00:00:00 2001
From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com>
Date: Sat, 27 Jan 2024 01:26:47 +0000
Subject: [PATCH] overload `eval` for evaluating multiple tensors at once
 (#384)

---
 backend/VERSION                               |  2 +-
 .../src/tensorflow/compiler/xla/literal.cpp   | 85 ++++++++++++-------
 backend/src/tensorflow/compiler/xla/literal.h | 27 ++++--
 .../tensorflow/compiler/xla/shape_util.cpp    | 16 ++++
 .../src/tensorflow/compiler/xla/shape_util.h  |  7 ++
 src/Compiler/Eval.idr                         |  7 +-
 src/Compiler/LiteralRW.idr                    | 34 ++++++--
 .../Prim/TensorFlow/Compiler/Xla/Literal.idr  | 20 ++---
 .../TensorFlow/Compiler/Xla/ShapeUtil.idr     | 16 ++++
 .../Xla/TensorFlow/Compiler/Xla/Literal.idr   | 61 ++++++-------
 .../Xla/TensorFlow/Compiler/Xla/ShapeUtil.idr | 26 ++++++
 .../Xla/TensorFlow/Compiler/Xla/XlaData.idr   | 16 ++--
 src/Tensor.idr                                | 72 +++++++++++++---
 src/Util.idr                                  |  7 ++
 test/Unit/TestTensor.idr                      | 49 +++++++++++
 test/Unit/TestTensor/Sampling.idr             | 26 ++----
 16 files changed, 342 insertions(+), 129 deletions(-)

diff --git a/backend/VERSION b/backend/VERSION
index c5d54ec32..7c1886bb9 100644
--- a/backend/VERSION
+++ b/backend/VERSION
@@ -1 +1 @@
-0.0.9
+0.0.10
diff --git a/backend/src/tensorflow/compiler/xla/literal.cpp b/backend/src/tensorflow/compiler/xla/literal.cpp
index 09b30f5d8..5ff1922f5 100644
--- a/backend/src/tensorflow/compiler/xla/literal.cpp
+++ b/backend/src/tensorflow/compiler/xla/literal.cpp
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "literal.h"
 #include "shape.h"
+#include "shape_util.h"
 
 extern "C" {
   Literal* Literal_new(Shape& shape) {
@@ -31,61 +32,85 @@ extern "C" {
 }
 
 template <typename NativeT>
-NativeT Literal_Get(Literal& lit, int* indices) {
+NativeT Literal_Get(Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index) {
   xla::Literal& lit_ = reinterpret_cast<xla::Literal&>(lit);
-  int64_t rank = lit_.shape().rank();
-  int64_t multi_index[rank];
-  std::copy(indices, indices + rank, multi_index);
-  return lit_.Get<NativeT>(absl::Span<const int64_t>(multi_index, rank));
+  int64_t multi_index_[multi_index_len];
+  std::copy(multi_index, multi_index + multi_index_len, multi_index_);
+  auto multi_index_span = absl::Span<const int64_t>(multi_index_, multi_index_len);
+  auto& shape_index_ = reinterpret_cast<xla::ShapeIndex&>(shape_index);
+  return lit_.Get<NativeT>(multi_index_span, shape_index_);
 };
 
 template <typename NativeT>
-void Literal_Set(Literal& lit, int* indices, NativeT value) {
+void Literal_Set(
+  Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index, NativeT value
+) {
   xla::Literal& lit_ = reinterpret_cast<xla::Literal&>(lit);
-  int64_t rank = lit_.shape().rank();
-  int64_t multi_index[rank];
-  std::copy(indices, indices + rank, multi_index);
-  lit_.Set(absl::Span<const int64_t>(multi_index, rank), value);
+  int64_t multi_index_[multi_index_len];
+  std::copy(multi_index, multi_index + multi_index_len, multi_index_);
+  auto multi_index_span = absl::Span<const int64_t>(multi_index_, multi_index_len);
+  auto& shape_index_ = reinterpret_cast<xla::ShapeIndex&>(shape_index);
+  lit_.Set(multi_index_span, shape_index_, value);
 };
 
 extern "C" {
-  int Literal_Get_bool(Literal& lit, int* indices) {
-    return (int) Literal_Get<bool>(lit, indices);
+  int Literal_Get_bool(
+    Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index
+  ) {
+    return (int) Literal_Get<bool>(lit, multi_index, multi_index_len, shape_index);
   }
 
-  int Literal_Get_int32_t(Literal& lit, int* indices) {
-    return Literal_Get<int32_t>(lit, indices);
+  int Literal_Get_int32_t(
+    Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index
+  ) {
+    return Literal_Get<int32_t>(lit, multi_index, multi_index_len, shape_index);
   }
 
-  int Literal_Get_uint32_t(Literal& lit, int* indices) {
-    return (int) Literal_Get<uint32_t>(lit, indices);
+  int Literal_Get_uint32_t(
+    Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index
+  ) {
+    return (int) Literal_Get<uint32_t>(lit, multi_index, multi_index_len, shape_index);
  }
 
-  int Literal_Get_uint64_t(Literal& lit, int* indices) {
-    return (int) Literal_Get<uint64_t>(lit, indices);
+  int Literal_Get_uint64_t(
+    Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index
+  ) {
+    return (int) Literal_Get<uint64_t>(lit, multi_index, multi_index_len, shape_index);
   }
 
-  double Literal_Get_double(Literal& lit, int* indices) {
-    return Literal_Get<double>(lit, indices);
+  double Literal_Get_double(
+    Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index
+  ) {
+    return Literal_Get<double>(lit, multi_index, multi_index_len, shape_index);
   }
 
-  void Literal_Set_bool(Literal& lit, int* indices, int value) {
-    Literal_Set(lit, indices, (bool) value);
+  void Literal_Set_bool(
+    Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index, int value
+  ) {
+    Literal_Set(lit, multi_index, multi_index_len, shape_index, (bool) value);
   }
 
-  void Literal_Set_int32_t(Literal& lit, int* indices, int value) {
-    Literal_Set(lit, indices, value);
+  void Literal_Set_int32_t(
+    Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index, int value
+  ) {
+    Literal_Set(lit, multi_index, multi_index_len, shape_index, value);
   }
 
-  void Literal_Set_uint32_t(Literal& lit, int* indices, int value) {
-    Literal_Set(lit, indices, (uint32_t) value);
+  void Literal_Set_uint32_t(
+    Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index, int value
+  ) {
+    Literal_Set(lit, multi_index, multi_index_len, shape_index, (uint32_t) value);
   }
 
-  void Literal_Set_uint64_t(Literal& lit, int* indices, int value) {
-    Literal_Set(lit, indices, (uint64_t) value);
+  void Literal_Set_uint64_t(
+    Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index, int value
+  ) {
+    Literal_Set(lit, multi_index, multi_index_len, shape_index, (uint64_t) value);
   }
 
-  void Literal_Set_double(Literal& lit, int* indices, double value) {
-    Literal_Set(lit, indices, value);
+  void Literal_Set_double(
+    Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index, double value
+  ) {
+    Literal_Set(lit, multi_index, multi_index_len, shape_index, value);
   }
 }
diff --git a/backend/src/tensorflow/compiler/xla/literal.h b/backend/src/tensorflow/compiler/xla/literal.h
index bb494a72c..7d404aa73 100644
--- a/backend/src/tensorflow/compiler/xla/literal.h
+++ b/backend/src/tensorflow/compiler/xla/literal.h
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "shape.h"
+#include "shape_util.h"
 
 extern "C" {
   struct Literal;
@@ -22,11 +23,23 @@ extern "C" {
 
   void Literal_delete(Literal* lit);
 
-  int Literal_Get_bool(Literal& lit, int* indices);
-  int Literal_Get_int(Literal& lit, int* indices);
-  double Literal_Get_double(Literal& lit, int* indices);
-
-  void Literal_Set_bool(Literal& lit, int* indices, int value);
-  void Literal_Set_int(Literal& lit, int* indices, int value);
-  void Literal_Set_double(Literal& lit, int* indices, double value);
+  int Literal_Get_bool(
+    Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index
+  );
+  int Literal_Get_int(
+    Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index
+  );
+  double Literal_Get_double(
+    Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index
+  );
+
+  void Literal_Set_bool(
+    Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index, int value
+  );
+  void Literal_Set_int(
+    Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index, int value
+  );
+  void Literal_Set_double(
+    Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index, double value
+  );
 }
diff --git a/backend/src/tensorflow/compiler/xla/shape_util.cpp b/backend/src/tensorflow/compiler/xla/shape_util.cpp
index 897ca8380..8d8ebe079 100644
--- a/backend/src/tensorflow/compiler/xla/shape_util.cpp
+++ b/backend/src/tensorflow/compiler/xla/shape_util.cpp
@@ -17,8 +17,24 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 
 #include "shape.h"
+#include "shape_util.h"
 
 extern "C" {
+  ShapeIndex* ShapeIndex_new() {
+    return reinterpret_cast<ShapeIndex*>(new xla::ShapeIndex());
+  }
+  void ShapeIndex_delete(ShapeIndex* s) {
+    delete reinterpret_cast<xla::ShapeIndex*>(s);
+  }
+
+  void ShapeIndex_push_back(ShapeIndex& shape_index, int value) {
+    reinterpret_cast<xla::ShapeIndex&>(shape_index).push_back(value);
+  }
+
+  void ShapeIndex_push_front(ShapeIndex& shape_index, int value) {
+    reinterpret_cast<xla::ShapeIndex&>(shape_index).push_front(value);
+  }
+
   Shape* MakeShape(int primitive_type, int* shape, int rank) {
     int64_t shape64[rank];
     std::copy(shape, shape + rank, shape64);
diff --git a/backend/src/tensorflow/compiler/xla/shape_util.h b/backend/src/tensorflow/compiler/xla/shape_util.h
index fddfcc0d6..42ba0016e 100644
--- a/backend/src/tensorflow/compiler/xla/shape_util.h
+++ b/backend/src/tensorflow/compiler/xla/shape_util.h
@@ -16,5 +16,12 @@ limitations under the License.
 #include "shape.h"
 
 extern "C" {
+  struct ShapeIndex;
+
+  ShapeIndex* ShapeIndex_new();
+  void ShapeIndex_delete(ShapeIndex* s);
+  void ShapeIndex_push_back(ShapeIndex& shape_index, int value);
+  void ShapeIndex_push_front(ShapeIndex& shape_index, int value);
+
   Shape* MakeShape(int primitive_type, int* shape, int rank);
 }
diff --git a/src/Compiler/Eval.idr b/src/Compiler/Eval.idr
index 8960d864f..ad2585b65 100644
--- a/src/Compiler/Eval.idr
+++ b/src/Compiler/Eval.idr
@@ -104,7 +104,7 @@ interpret xlaBuilder (MkFn params root env) = do
       set posInGraph param
 
   interpretE : Expr -> Builder XlaOp
-  interpretE (FromLiteral {dtype} lit) = constantLiteral xlaBuilder !(write {dtype} lit)
+  interpretE (FromLiteral {dtype} lit) = constantLiteral xlaBuilder !(write {dtype} [] lit)
   interpretE (Arg x) = get x
   interpretE (Tuple xs) = tuple xlaBuilder !(traverse get xs)
   interpretE (GetTupleElement idx x) = getTupleElement !(get x) idx
@@ -234,12 +234,11 @@ toString f = do
   pure $ opToString xlaBuilder root
 
 export covering
-execute : PrimitiveRW dtype a => Fn 0 -> {shape : _} -> ErrIO $ Literal shape a
+execute : Fn 0 -> ErrIO Literal
 execute f = do
   xlaBuilder <- mkXlaBuilder "root"
   computation <- compile xlaBuilder f
   gpuStatus <- validateGPUMachineManager
   platform <- if ok gpuStatus then gpuMachineManager else getPlatform "Host"
   client <- getOrCreateLocalClient platform
-  lit <- executeAndTransfer client computation
-  pure (read {dtype} lit)
+  executeAndTransfer client computation
diff --git a/src/Compiler/LiteralRW.idr b/src/Compiler/LiteralRW.idr
index 7cecb83c8..c0084c2a6 100644
--- a/src/Compiler/LiteralRW.idr
+++ b/src/Compiler/LiteralRW.idr
@@ -16,7 +16,8 @@ limitations under the License.
 module Compiler.LiteralRW
 
 import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData
-import Compiler.Xla.TensorFlow.Compiler.Xla.Literal
+import public Compiler.Xla.TensorFlow.Compiler.Xla.Literal
+import Compiler.Xla.TensorFlow.Compiler.Xla.ShapeUtil
 import Literal
 import Util
 
@@ -39,21 +40,36 @@ indexed = go shape []
     go (0 :: _) _ = []
     go (S d :: ds) idxs = concat $ map (\i => go ds (snoc idxs i)) (range (S d))
 
-export
+public export
 interface Primitive dtype => LiteralRW dtype ty where
-  set : Literal -> List Nat -> ty -> IO ()
-  get : Literal -> List Nat -> ty
+  set : Literal -> List Nat -> ShapeIndex -> ty -> IO ()
+  get : Literal -> List Nat -> ShapeIndex -> ty
 
 export
-write : (HasIO io, LiteralRW dtype a) => {shape : _} -> Literal shape a -> io Literal
-write xs = liftIO $ do
+write : HasIO io =>
+        LiteralRW dtype a =>
+        {shape : _} ->
+        List Nat ->
+        Literal shape a ->
+        io Literal
+write idxs xs = liftIO $ do
   literal <- allocLiteral {dtype} shape
-  sequence_ [| (\idxs => set {dtype} literal idxs) indexed xs |]
+  shapeIndex <- allocShapeIndex
+  traverse_ (pushBack shapeIndex) idxs
+  sequence_ [| (\idxs => set {dtype} literal idxs shapeIndex) indexed xs |]
   pure literal
 
 export
-read : LiteralRW dtype a => Literal -> {shape : _} -> Literal shape a
-read lit = map (get {dtype} lit) indexed
+read : HasIO io =>
+       LiteralRW dtype a =>
+       {shape : _} ->
+       List Nat ->
+       Literal ->
+       io $ Literal shape a
+read idxs lit = do
+  shapeIndex <- allocShapeIndex
+  traverse_ (pushBack shapeIndex) idxs
+  pure $ map (\mIdx => get {dtype} lit mIdx shapeIndex) (indexed {shape})
 
 export
 LiteralRW PRED Bool where
diff --git a/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Literal.idr b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Literal.idr
index ae5d947dc..417f5439f 100644
--- a/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Literal.idr
+++ b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Literal.idr
@@ -29,40 +29,40 @@ prim__delete : AnyPtr -> PrimIO ()
 
 export
 %foreign (libxla "Literal_Set_bool")
-prim__literalSetBool : GCAnyPtr -> GCPtr Int -> Int -> PrimIO ()
+prim__literalSetBool : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Int -> PrimIO ()
 
 export
 %foreign (libxla "Literal_Get_bool")
-literalGetBool : GCAnyPtr -> GCPtr Int -> Int
+literalGetBool : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Int
 
 export
 %foreign (libxla "Literal_Set_double")
-prim__literalSetDouble : GCAnyPtr -> GCPtr Int -> Double -> PrimIO ()
+prim__literalSetDouble : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Double -> PrimIO ()
 
 export
 %foreign (libxla "Literal_Get_double")
-literalGetDouble : GCAnyPtr -> GCPtr Int -> Double
+literalGetDouble : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Double
 
 export
 %foreign (libxla "Literal_Set_int32_t")
-prim__literalSetInt32t : GCAnyPtr -> GCPtr Int -> Int -> PrimIO ()
+prim__literalSetInt32t : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Int -> PrimIO ()
 
 export
 %foreign (libxla "Literal_Get_int32_t")
-literalGetInt32t : GCAnyPtr -> GCPtr Int -> Int
+literalGetInt32t : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Int
 
 export
 %foreign (libxla "Literal_Set_uint32_t")
-prim__literalSetUInt32t : GCAnyPtr -> GCPtr Int -> Bits32 -> PrimIO ()
+prim__literalSetUInt32t : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Bits32 -> PrimIO ()
 
 export
 %foreign (libxla "Literal_Get_uint32_t")
-literalGetUInt32t : GCAnyPtr -> GCPtr Int -> Bits32
+literalGetUInt32t : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Bits32
 
 export
 %foreign (libxla "Literal_Set_uint64_t")
-prim__literalSetUInt64t : GCAnyPtr -> GCPtr Int -> Bits64 -> PrimIO ()
+prim__literalSetUInt64t : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Bits64 -> PrimIO ()
 
 export
 %foreign (libxla "Literal_Get_uint64_t")
-literalGetUInt64t : GCAnyPtr -> GCPtr Int -> Bits64
+literalGetUInt64t : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Bits64
diff --git a/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/ShapeUtil.idr b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/ShapeUtil.idr
index 75fecd19f..6835a4094 100644
--- a/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/ShapeUtil.idr
+++ b/src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/ShapeUtil.idr
@@ -19,6 +19,22 @@ import System.FFI
 
 import Compiler.Xla.Prim.Util
 
+export
+%foreign (libxla "ShapeIndex_new")
+prim__shapeIndexNew : PrimIO AnyPtr
+
+export
+%foreign (libxla "ShapeIndex_delete")
+prim__shapeIndexDelete : AnyPtr -> PrimIO ()
+
+export
+%foreign (libxla "ShapeIndex_push_back")
+prim__shapeIndexPushBack : GCAnyPtr -> Int -> PrimIO ()
+
+export
+%foreign (libxla "ShapeIndex_push_front")
+prim__shapeIndexPushFront : GCAnyPtr -> Int -> PrimIO ()
+
 export
 %foreign (libxla "MakeShape")
 prim__mkShape : Int -> GCPtr Int -> Int -> PrimIO AnyPtr
diff --git a/src/Compiler/Xla/TensorFlow/Compiler/Xla/Literal.idr b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Literal.idr
index ec8616ca5..9d20c3d41 100644
--- a/src/Compiler/Xla/TensorFlow/Compiler/Xla/Literal.idr
+++ b/src/Compiler/Xla/TensorFlow/Compiler/Xla/Literal.idr
@@ -41,65 +41,66 @@ allocLiteral shape = do
 
 namespace Bool
   export
-  set : Literal -> List Nat -> Bool -> IO ()
-  set (MkLiteral lit) idxs value = do
+  set : Literal -> List Nat -> ShapeIndex -> Bool -> IO ()
+  set (MkLiteral lit) idxs (MkShapeIndex shapeIndex) value = do
     MkIntArray idxsArrayPtr <- mkIntArray idxs
-    primIO $ prim__literalSetBool lit idxsArrayPtr (if value then 1 else 0)
+    primIO $
+      prim__literalSetBool lit idxsArrayPtr (cast $ length idxs) shapeIndex (boolToCInt value)
 
   export
-  get : Literal -> List Nat -> Bool
-  get (MkLiteral lit) idxs = unsafePerformIO $ do
+  get : Literal -> List Nat -> ShapeIndex -> Bool
+  get (MkLiteral lit) idxs (MkShapeIndex shapeIndex) = unsafePerformIO $ do
     MkIntArray idxsArrayPtr <- mkIntArray idxs
-    pure $ cIntToBool $ literalGetBool lit idxsArrayPtr
+    pure $ cIntToBool $ literalGetBool lit idxsArrayPtr (cast $ length idxs) shapeIndex
 
 namespace Double
   export
-  set : Literal -> List Nat -> Double -> IO ()
-  set (MkLiteral lit) idxs value = do
+  set : Literal -> List Nat -> ShapeIndex -> Double -> IO ()
+  set (MkLiteral lit) idxs (MkShapeIndex shapeIndex) value = do
     MkIntArray idxsArrayPtr <- mkIntArray idxs
-    primIO $ prim__literalSetDouble lit idxsArrayPtr value
+    primIO $ prim__literalSetDouble lit idxsArrayPtr (cast $ length idxs) shapeIndex value
 
   export
-  get : Literal -> List Nat -> Double
-  get (MkLiteral lit) idxs = unsafePerformIO $ do
+  get : Literal -> List Nat -> ShapeIndex -> Double
+  get (MkLiteral lit) idxs (MkShapeIndex shapeIndex) = unsafePerformIO $ do
     MkIntArray idxsArrayPtr <- mkIntArray idxs
-    pure $ literalGetDouble lit idxsArrayPtr
+    pure $ literalGetDouble lit idxsArrayPtr (cast $ length idxs) shapeIndex
 
 namespace Int32t
   export
-  set : Literal -> List Nat -> Int32 -> IO ()
-  set (MkLiteral lit) idxs value = do
+  set : Literal -> List Nat -> ShapeIndex -> Int32 -> IO ()
+  set (MkLiteral lit) idxs (MkShapeIndex shapeIndex) value = do
    MkIntArray idxsArrayPtr <- mkIntArray idxs
-    primIO $ prim__literalSetInt32t lit idxsArrayPtr (cast value)
+    primIO $ prim__literalSetInt32t lit idxsArrayPtr (cast $ length idxs) shapeIndex (cast value)
 
   export
-  get : Literal -> List Nat -> Int32
-  get (MkLiteral lit) idxs = unsafePerformIO $ do
+  get : Literal -> List Nat -> ShapeIndex -> Int32
+  get (MkLiteral lit) idxs (MkShapeIndex shapeIndex) = unsafePerformIO $ do
     MkIntArray idxsArrayPtr <- mkIntArray idxs
-    pure $ cast $ literalGetInt32t lit idxsArrayPtr
+    pure $ cast $ literalGetInt32t lit idxsArrayPtr (cast $ length idxs) shapeIndex
 
 namespace UInt32t
   export
-  set : Literal -> List Nat -> Nat -> IO ()
-  set (MkLiteral lit) idxs value = do
+  set : Literal -> List Nat -> ShapeIndex -> Nat -> IO ()
+  set (MkLiteral lit) idxs (MkShapeIndex shapeIndex) value = do
     MkIntArray idxsArrayPtr <- mkIntArray idxs
-    primIO $ prim__literalSetUInt32t lit idxsArrayPtr (cast value)
+    primIO $ prim__literalSetUInt32t lit idxsArrayPtr (cast $ length idxs) shapeIndex (cast value)
 
   export
-  get : Literal -> List Nat -> Nat
-  get (MkLiteral lit) idxs = unsafePerformIO $ do
+  get : Literal -> List Nat -> ShapeIndex -> Nat
+  get (MkLiteral lit) idxs (MkShapeIndex shapeIndex) = unsafePerformIO $ do
    MkIntArray idxsArrayPtr <- mkIntArray idxs
-    pure $ cast $ literalGetUInt32t lit idxsArrayPtr
+    pure $ cast $ literalGetUInt32t lit idxsArrayPtr (cast $ length idxs) shapeIndex
 
 namespace UInt64t
   export
-  set : Literal -> List Nat -> Nat -> IO ()
-  set (MkLiteral lit) idxs value = do
+  set : Literal -> List Nat -> ShapeIndex -> Nat -> IO ()
+  set (MkLiteral lit) idxs (MkShapeIndex shapeIndex) value = do
     MkIntArray idxsArrayPtr <- mkIntArray idxs
-    primIO $ prim__literalSetUInt64t lit idxsArrayPtr (cast value)
+    primIO $ prim__literalSetUInt64t lit idxsArrayPtr (cast $ length idxs) shapeIndex (cast value)
 
   export
-  get : Literal -> List Nat -> Nat
-  get (MkLiteral lit) idxs = unsafePerformIO $ do
+  get : Literal -> List Nat -> ShapeIndex -> Nat
+  get (MkLiteral lit) idxs (MkShapeIndex shapeIndex) = unsafePerformIO $ do
     MkIntArray idxsArrayPtr <- mkIntArray idxs
-    pure $ cast $ literalGetUInt64t lit idxsArrayPtr
+    pure $ cast $ literalGetUInt64t lit idxsArrayPtr (cast $ length idxs) shapeIndex
diff --git a/src/Compiler/Xla/TensorFlow/Compiler/Xla/ShapeUtil.idr b/src/Compiler/Xla/TensorFlow/Compiler/Xla/ShapeUtil.idr
index 3ab047ea7..5ebc26c5b 100644
--- a/src/Compiler/Xla/TensorFlow/Compiler/Xla/ShapeUtil.idr
+++ b/src/Compiler/Xla/TensorFlow/Compiler/Xla/ShapeUtil.idr
@@ -21,6 +21,32 @@ import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData
 import Compiler.Xla.Util
 import Types
 
+public export
+data ShapeIndex : Type where
+  MkShapeIndex : GCAnyPtr -> ShapeIndex
+
+namespace ShapeIndex
+  export
+  delete : HasIO io => AnyPtr -> io ()
+  delete = primIO . prim__shapeIndexDelete
+
+export
+allocShapeIndex : HasIO io => io ShapeIndex
+allocShapeIndex = do
+  ptr <- primIO prim__shapeIndexNew
+  ptr <- onCollectAny ptr ShapeIndex.delete
+  pure (MkShapeIndex ptr)
+
+export
+pushBack : HasIO io => ShapeIndex -> Nat -> io ()
+pushBack (MkShapeIndex shapeIndex) value =
+  primIO $ prim__shapeIndexPushBack shapeIndex (cast value)
+
+export
+pushFront : HasIO io => ShapeIndex -> Nat -> io ()
+pushFront (MkShapeIndex shapeIndex) value =
+  primIO $ prim__shapeIndexPushFront shapeIndex (cast value)
+
 export
 mkShape : (HasIO io, Primitive dtype) => Types.Shape -> io Xla.Shape
 mkShape shape = do
diff --git a/src/Compiler/Xla/TensorFlow/Compiler/Xla/XlaData.idr b/src/Compiler/Xla/TensorFlow/Compiler/Xla/XlaData.idr
index fe9dfe27e..c9498a4cb 100644
--- a/src/Compiler/Xla/TensorFlow/Compiler/Xla/XlaData.idr
+++ b/src/Compiler/Xla/TensorFlow/Compiler/Xla/XlaData.idr
@@ -81,20 +81,20 @@ allocDotDimensionNumbers = do
 
 export
 addLhsContractingDimensions : HasIO io => DotDimensionNumbers -> Nat -> io ()
-addLhsContractingDimensions (MkDotDimensionNumbers dimension_numbers) n =
-  primIO $ prim__addLhsContractingDimensions dimension_numbers (cast n)
+addLhsContractingDimensions (MkDotDimensionNumbers dimensionNumbers) n =
+  primIO $ prim__addLhsContractingDimensions dimensionNumbers (cast n)
 
 export
 addRhsContractingDimensions : HasIO io => DotDimensionNumbers -> Nat -> io ()
-addRhsContractingDimensions (MkDotDimensionNumbers dimension_numbers) n =
-  primIO $ prim__addRhsContractingDimensions dimension_numbers (cast n)
+addRhsContractingDimensions (MkDotDimensionNumbers dimensionNumbers) n =
+  primIO $ prim__addRhsContractingDimensions dimensionNumbers (cast n)
 
 export
 addLhsBatchDimensions : HasIO io => DotDimensionNumbers -> Nat -> io ()
-addLhsBatchDimensions (MkDotDimensionNumbers dimension_numbers) n =
-  primIO $ prim__addLhsBatchDimensions dimension_numbers (cast n)
+addLhsBatchDimensions (MkDotDimensionNumbers dimensionNumbers) n =
+  primIO $ prim__addLhsBatchDimensions dimensionNumbers (cast n)
 
 export
 addRhsBatchDimensions : HasIO io => DotDimensionNumbers -> Nat -> io ()
-addRhsBatchDimensions (MkDotDimensionNumbers dimension_numbers) n =
-  primIO $ prim__addRhsBatchDimensions dimension_numbers (cast n)
+addRhsBatchDimensions (MkDotDimensionNumbers dimensionNumbers) n =
+  primIO $ prim__addRhsBatchDimensions dimensionNumbers (cast n)
diff --git a/src/Tensor.idr b/src/Tensor.idr
index ddbb017fb..8cf212c06 100644
--- a/src/Tensor.idr
+++ b/src/Tensor.idr
@@ -96,25 +96,74 @@ namespace S32
   fromInteger : Integer -> Graph $ Tensor [] S32
   fromInteger = tensor . Scalar . fromInteger
 
+partial
+try : Show e => Monad m => EitherT e m a -> m a
+try x = runEitherT x <&> \case
+  Right x => x
+  Left err => idris_crash (show err)
+
 ||| Evaluate a `Tensor`, returning its value as a `Literal`. This function builds and executes the
 ||| computational graph.
 |||
-||| This function will execute the graph on GPU if one is found, else it will use the host CPU.
+||| `eval` will execute the graph on GPU if one is found, else it will use the host CPU.
 |||
 ||| **Note:**
-||| * Each call to `eval` will rebuild and execute the graph. Similarly, multiple calls to
-|||   `eval` on different `Tensor`s in a computation will be treated entirely independently.
-|||   `eval` does not store intermediate values. This is a known limitation, and may change in
-|||   the future.
+||| * Each call to `eval` will rebuild and execute the graph; multiple calls to `eval` on different
+|||   tensors, even if they are in the same computation, will be treated entirely independently.
+|||   To efficiently evaluate multiple tensors at once, use `TensorList.eval`.
 ||| * `eval` performs logging. You can disable this by adjusting the TensorFlow logging level
 |||   with e.g. `export TF_CPP_MIN_LOG_LEVEL=3`.
 export partial
 eval : PrimitiveRW dtype ty => Graph (Tensor shape dtype) -> IO (Literal shape ty)
-eval $ MkGraph x = do
+eval $ MkGraph x =
   let (env, MkTensor root) = runState empty x
-  runEitherT (execute {dtype} (MkFn [] root env)) <&> \case
-    Right lit => lit
-    Left err => idris_crash (show err)
+   in try $ execute (MkFn [] root env) >>= read {dtype} []
+
+namespace TensorList
+  ||| A list of `Tensor`s, along with the conversions needed to evaluate them to `Literal`s.
+  ||| The list is parametrized by the shapes and types of the resulting `Literal`s.
+  public export
+  data TensorList : List Shape -> List Type -> Type where
+    Nil : TensorList [] []
+    (::) : PrimitiveRW dtype ty =>
+           Tensor shape dtype ->
+           TensorList shapes tys ->
+           TensorList (shape :: shapes) (ty :: tys)
+
+  ||| Evaluate a list of `Tensor`s as a list of `Literal`s. Tensors in the list can have different
+  ||| shapes and element types. For example,
+  ||| ```
+  ||| main : IO ()
+  ||| main = do [x, y] <- eval $ do x <- tensor {dtype = F64} [1.2, 3.4]
+  |||                               y <- reduce @{Sum} [0] x
+  |||                               pure [x, y]
+  |||           printLn x
+  |||           printLn y
+  ||| ```
+  ||| In contrast to `Tensor.eval` when called on multiple tensors, this function constructs and
+  ||| compiles the graph just once.
+  |||
+  ||| `eval` will execute the graph on GPU if one is found, else it will use the host CPU.
+  |||
+  ||| **Note:**
+  ||| * `eval` performs logging. You can disable this by adjusting the TensorFlow logging level
+  |||   with e.g. `export TF_CPP_MIN_LOG_LEVEL=3`.
+  export partial
+  eval : Graph (TensorList shapes tys) -> IO (All2 Literal shapes tys)
+  eval $ MkGraph xs =
+    let (env, xs) = runState empty xs
+        (env, root) = runState env (addNode $ Tuple $ nodes xs)
+     in try $ execute (MkFn [] root env) >>= readAll xs 0
+
+    where
+
+    nodes : TensorList s t -> List Nat
+    nodes [] = []
+    nodes (MkTensor x :: xs) = x :: nodes xs
+
+    readAll : HasIO io => TensorList s t -> Nat -> Literal -> io $ All2 Literal s t
+    readAll [] _ _ = pure []
+    readAll (MkTensor {dtype} _ :: ts) n lit = [| read {dtype} [n] lit :: readAll ts (S n) lit |]
 
 ||| A string representation of the graph used to define a `Tensor`, detailing all enqueued XLA
 ||| operations.
@@ -121,10 +170,9 @@ eval : PrimitiveRW dtype ty => Graph (Tensor shape dtype) -> IO (Literal shape
 |||
 ||| Useful for debugging.
 export partial
 Show (Graph $ Tensor shape dtype) where
-  show $ MkGraph x = let (env, MkTensor root) = runState empty x in
-    case unsafePerformIO $ runEitherT $ toString (MkFn [] root env) of
-      Right str => str
+  show $ MkGraph x = let (env, MkTensor root) = runState empty x
+                      in unsafePerformIO $ try $ toString (MkFn [] root env)
 
 ||| Bounds for numeric tensors. Will be infinite for floating point types.
 export
diff --git a/src/Util.idr b/src/Util.idr
index 8821a940f..4e95e0eb0 100644
--- a/src/Util.idr
+++ b/src/Util.idr
@@ -126,6 +126,13 @@ namespace List
     impl _ [] = []
     impl i (x :: xs) = if elem i idxs then impl (S i) xs else x :: impl (S i) xs
 
+namespace All2
+  ||| A binary version of `All` from the standard library.
+  public export
+  data All2 : (0 p : a -> b -> Type) -> List a -> List b -> Type where
+    Nil : All2 p [] []
+    (::) : forall xs, ys . p x y -> All2 p xs ys -> All2 p (x :: xs) (y :: ys)
+
 ||| A `Sorted f xs` proves that for all consecutive elements `x` and `y` in `xs`, `f x y` exists.
 ||| For example, a `Sorted LT xs` proves that all `Nat`s in `xs` appear in increasing numerical
 ||| order.
diff --git a/test/Unit/TestTensor.idr b/test/Unit/TestTensor.idr
index 509fc8eea..8d81c8ed4 100644
--- a/test/Unit/TestTensor.idr
+++ b/test/Unit/TestTensor.idr
@@ -53,6 +53,53 @@ tensorThenEval = property $ do
   x <- forAll (literal shape bool)
   x === unsafePerformIO (eval (tensor {dtype=PRED} x))
 
+partial
+evalTuple : Property
+evalTuple = property $ do
+  s0 <- forAll shapes
+  s1 <- forAll shapes
+  s2 <- forAll shapes
+
+  x0 <- forAll (literal s0 doubles)
+  x1 <- forAll (literal s1 int32s)
+  x2 <- forAll (literal s2 nats)
+
+  let y0 = tensor {dtype = F64} x0
+      y1 = tensor {dtype = S32} x1
+      y2 = tensor {dtype = U64} x2
+
+  let [] = unsafePerformIO $ eval (pure [])
+
+  let [x0'] = unsafePerformIO $ eval (do pure [!y0])
+
+  x0' ==~ x0
+
+  let [x0', x1'] = unsafePerformIO $ eval (do pure [!y0, !y1])
+
+  x0' ==~ x0
+  x1' === x1
+
+  let [x0', x1', x2'] = unsafePerformIO $ eval (do pure [!y0, !y1, !y2])
+
+  x0' ==~ x0
+  x1' === x1
+  x2' === x2
+
+partial
+evalTupleNonTrivial : Property
+evalTupleNonTrivial = property $ do
+  let xs = do y0 <- tensor [1.0, -2.0, 0.4]
+              y1 <- tensor 3.0
+              u <- exp y0
+              v <- slice [at 1] u + pure y1
+              w <- slice [0.to 2] u
+              pure [v, w]
+
+      [v, w] = unsafePerformIO $ eval xs
+
+  v ==~ Scalar (exp (-2.0) + 3.0)
+  w ==~ [| exp [1.0, -2.0] |]
+
 partial
 canConvertAtXlaNumericBounds : Property
 canConvertAtXlaNumericBounds = fixedProperty $ do
@@ -402,6 +449,8 @@ export partial
 group : Group
 group = MkGroup "Tensor" $ [
     ("eval . tensor", tensorThenEval)
+  , ("eval multiple tensors", evalTuple)
+  , ("eval multiple tensors for non-trivial graph", evalTupleNonTrivial)
  , ("can read/write finite numeric bounds to/from XLA", canConvertAtXlaNumericBounds)
   , ("bounded non-finite", boundedNonFinite)
   , ("iota", iota)
diff --git a/test/Unit/TestTensor/Sampling.idr b/test/Unit/TestTensor/Sampling.idr
index 66dc044af..71b9a8489 100644
--- a/test/Unit/TestTensor/Sampling.idr
+++ b/test/Unit/TestTensor/Sampling.idr
@@ -103,7 +103,7 @@ uniformSeedIsUpdated = withTests 20 . property $ do
   key <- forAll (literal [] nats)
   seed <- forAll (literal [1] nats)
 
-  let everything = do
+  let [seed, seed', seed'', sample, sample'] = unsafePerformIO $ eval $ do
        bound <- tensor bound
        bound' <- tensor bound'
        key <- tensor key
@@ -112,12 +112,7 @@ uniformSeedIsUpdated = withTests 20 . property $ do
        rng <- uniform key {shape=[10]} !(broadcast bound) !(broadcast bound')
        (seed', sample) <- runStateT seed rng
        (seed'', sample') <- runStateT seed' rng
-        seeds <- concat 0 !(concat 0 seed seed') seed''
-        samples <- concat 0 !(expand 0 sample) !(expand 0 sample')
-        pure (seeds, samples)
-
-      [seed, seed', seed''] = unsafeEval (do (seeds, _) <- everything; pure seeds)
-      [sample, sample'] = unsafeEval (do (_, samples) <- everything; pure samples)
+        pure [seed, seed', seed'', sample, sample']
 
   diff seed' (/=) seed
   diff seed'' (/=) seed'
@@ -131,7 +126,7 @@ uniformIsReproducible = withTests 20 . property $ do
   key <- forAll (literal [] nats)
   seed <- forAll (literal [1] nats)
 
-  let [sample, sample'] = unsafeEval $ do
+  let [sample, sample'] = unsafePerformIO $ eval $ do
        bound <- tensor bound
        bound' <- tensor bound'
        key <- tensor key
@@ -140,7 +135,7 @@ uniformIsReproducible = withTests 20 . property $ do
        rng <- uniform {shape=[10]} key !(broadcast bound) !(broadcast bound')
        sample <- evalStateT seed rng
        sample' <- evalStateT seed rng
-        concat 0 !(expand 0 sample) !(expand 0 sample')
+        pure [sample, sample']
 
   sample ==~ sample'
@@ -169,18 +164,13 @@ normalSeedIsUpdated = withTests 20 . property $ do
   key <- forAll (literal [] nats)
   seed <- forAll (literal [1] nats)
 
-  let everything = do
+  let [seed, seed', seed'', sample, sample'] = unsafePerformIO $ eval $ do
        key <- tensor key
        seed <- tensor seed
        let rng = normal key {shape=[10]}
        (seed', sample) <- runStateT seed rng
        (seed'', sample') <- runStateT seed' rng
-        seeds <- concat 0 !(concat 0 seed seed') seed''
-        samples <- concat 0 !(expand 0 sample) !(expand 0 sample')
-        pure (seeds, samples)
-
-      [seed, seed', seed''] = unsafeEval (do (seeds, _) <- everything; pure seeds)
-      [sample, sample'] = unsafeEval (do (_, samples) <- everything; pure samples)
+        pure [seed, seed', seed'', sample, sample']
 
   diff seed' (/=) seed
  diff seed'' (/=) seed'
@@ -192,14 +182,14 @@ normalIsReproducible = withTests 20 . property $ do
   key <- forAll (literal [] nats)
   seed <- forAll (literal [1] nats)
 
-  let [sample, sample'] = unsafeEval $ do
+  let [sample, sample'] = unsafePerformIO $ eval $ do
        key <- tensor key
        seed <- tensor seed
 
        let rng = normal {shape=[10]} key
        sample <- evalStateT seed rng
        sample' <- evalStateT seed rng
-        concat 0 !(expand 0 sample) !(expand 0 sample')
+        pure [sample, sample']
 
   sample ==~ sample'
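
Usage sketch (editorial note, not part of the commit): the snippet below shows how the new `TensorList.eval` overload is called, following the docstring example this patch adds to `src/Tensor.idr`. The module imports and the `main` wrapper are assumptions about the surrounding program, not taken from the diff.

import Literal
import Tensor

main : IO ()
main = do
  -- x and y are evaluated together: the graph is built, compiled and
  -- executed once, unlike two separate calls to Tensor.eval
  [x, y] <- eval $ do x <- tensor {dtype = F64} [1.2, 3.4]
                      y <- reduce @{Sum} [0] x
                      pure [x, y]
  printLn x
  printLn y

The result of this `eval` is an `All2 Literal shapes tys` (the heterogeneous list type added in `src/Util.idr` above), so the `[x, y]` pattern match yields one `Literal` per tensor, each with its own shape and element type.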