Skip to content

Commit

Permalink
hide IO for toString and convert to Show implementation for `Tens…
Browse files Browse the repository at this point in the history
…or` (#242)
  • Loading branch information
joelberkeley authored Mar 23, 2022
1 parent e3e0828 commit 22fdf00
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
8 changes: 4 additions & 4 deletions src/Tensor.idr
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ toArray (MkTensor {shape} mkOp) = unsafePerformIO $ do
||| Return a string representation of an unevaluated `Tensor`, detailing all enqueued operations.
||| Useful for debugging.
export
toString : Tensor shape dtype -> IO String
toString (MkTensor f) = do
builder <- prim__mkXlaBuilder ""
pure (prim__opToString builder !(f builder))
Show (Tensor shape dtype) where
show (MkTensor f) = unsafePerformIO $ do
builder <- prim__mkXlaBuilder ""
pure (prim__opToString builder !(f builder))

----------------------------- structural operations ----------------------------

Expand Down
17 changes: 8 additions & 9 deletions test/Unit/TestTensor.idr
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,22 @@ test_const_toArray = do
in assert name (sufficientlyEq x x')
) doubles

test_toString : IO ()
test_toString = do
str <- toString $ const {shape=[]} {dtype=S32} 1
assert "toString for scalar Int" (str == "constant, shape=[], metadata={:0}")
test_show : IO ()
test_show = do
let x = const {shape=[]} {dtype=S32} 1
assert "show for scalar Int" (show x == "constant, shape=[], metadata={:0}")

let x = const {shape=[]} {dtype=S32} 1
y = const {shape=[]} {dtype=S32} 2
str <- toString (x + y)
assert "toString for scalar addition" $ str ==
assert "show for scalar addition" $ show (Tensor.(+) x y) ==
"""
add, shape=[], metadata={:0}
constant, shape=[], metadata={:0}
constant, shape=[], metadata={:0}
"""

str <- toString $ const {shape=[_]} {dtype=F64} [1.3, 2.0, -0.4]
assert "toString for vector F64" $ str == "constant, shape=[3], metadata={:0}"
let x = const {shape=[_]} {dtype=F64} [1.3, 2.0, -0.4]
assert "show for vector F64" $ show x == "constant, shape=[3], metadata={:0}"

test_reshape : IO ()
test_reshape = do
Expand Down Expand Up @@ -944,7 +943,7 @@ export
test : IO ()
test = do
test_const_toArray
test_toString
test_show
test_reshape
test_slice
test_index
Expand Down

0 comments on commit 22fdf00

Please sign in to comment.