Skip to content

Commit

Permalink
Add first working SIMD version
Browse files Browse the repository at this point in the history
  • Loading branch information
VMois committed Apr 22, 2024
1 parent a5924c7 commit 672e207
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 30 deletions.
78 changes: 74 additions & 4 deletions simd_atol.mojo
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from tensor import Tensor
from utils.loop import unroll
from sys import strided_load
from memory.unsafe import bitcast

alias char = Int8
alias simd_width = simdwidthof[char]()
Expand Down Expand Up @@ -34,14 +35,83 @@ fn _is_uint(s: String) raises -> Bool:
return is_int


@always_inline
fn _combine_chunks[new_dtype: DType, old_dtype: DType, old_len: Int](value: SIMD[old_dtype, old_len]) raises -> SIMD[new_dtype, old_len // 2]:
var left_selected: SIMD[old_dtype, old_len]
var right_selected: SIMD[old_dtype, old_len]
var right_multiplied: SIMD[old_dtype, old_len]
@parameter
if old_len == 16:
alias even_mask = SIMD[old_dtype, old_len](0, 0x0f, 0, 0x0f, 0, 0x0f, 0, 0x0f, 0, 0x0f, 0, 0x0f, 0, 0x0f, 0, 0x0f)
alias odd_mask = SIMD[old_dtype, old_len](0x0f, 0, 0x0f, 0, 0x0f, 0, 0x0f, 0, 0x0f, 0, 0x0f, 0, 0x0f, 0, 0x0f, 0)
left_selected = value & even_mask
right_selected = value & odd_mask
var left_shifted = left_selected.shift_left[1]()
alias multiplier = SIMD[old_dtype, old_len](10, 0, 10, 0, 10, 0, 10, 0, 10, 0, 10, 0, 10, 0, 10, 0)
right_multiplied = right_selected * multiplier
return bitcast[new_dtype, old_len // 2](left_shifted + right_multiplied)
elif old_len == 8:
alias even_mask = SIMD[old_dtype, old_len](0, 0x00ff, 0, 0x00ff, 0, 0x00ff, 0, 0x00ff)
alias odd_mask = SIMD[old_dtype, old_len](0x00ff, 0, 0x00ff, 0, 0x00ff, 0, 0x00ff, 0)
left_selected = value & even_mask
right_selected = value & odd_mask
var left_shifted = left_selected.shift_left[1]()
alias multiplier = SIMD[old_dtype, old_len](100, 0, 100, 0, 100, 0, 100, 0)
right_multiplied = right_selected * multiplier
#print("left shifted", left_shifted)
#print("left selected", left_selected)
#print("right selected", right_selected)
return bitcast[new_dtype, old_len // 2](left_shifted + right_multiplied)
elif old_len == 4:
alias even_mask = SIMD[old_dtype, old_len](0, 0xffff, 0, 0xffff)
alias odd_mask = SIMD[old_dtype, old_len](0xffff, 0, 0xffff, 0)
left_selected = value & even_mask
right_selected = value & odd_mask
var left_shifted = left_selected.shift_left[1]()
alias multiplier = SIMD[old_dtype, old_len](10000, 0, 10000, 0)
right_multiplied = right_selected * multiplier
#print("left shifted", left_shifted)
#print("left selected", left_selected)
#print("right selected", right_selected)
return bitcast[new_dtype, old_len // 2](left_shifted + right_multiplied)
elif old_len == 2:
return (value[0] * 100000000 + value[1]).cast[new_dtype]()
else:
raise Error("Unsupported length")


fn atol(s: String) raises -> Int:
return 42
"""
Convert String that consists of 16 or less characters into integer.
"""

if len(s) == 0 or len(s) > 16:
raise Error("Only 16 or less Strings are supported.")

if not _is_uint(s):
raise Error("String is not convertible to integer.")

#print("Original:", s)
var zeros = SIMD[DType.uint8, simd_width](48)
var ptr = rebind[DTypePointer[DType.uint8]](s._as_ptr())
var adjusted_value = ptr.load[width=simd_width](0) - zeros
#print("Adjusted value", adjusted_value)
var chunk16 = _combine_chunks[DType.uint16](adjusted_value)
#print(chunk16)
var chunk32 = _combine_chunks[DType.uint32](chunk16)
#print(chunk32)
var chunk32_2 = _combine_chunks[DType.uint64](chunk32)
#print(chunk32_2)
var chunk32_3 = _combine_chunks[DType.uint64](chunk32_2)
#print(chunk32_3)

return chunk32_3.to_int() // (10 ** (simd_width - len(s)))


fn main() raises:
var s1: String = "2357"
var s1: String = "5852010871235579"
var s2: String = "-1257"
var s3: String = "9.03"

print(_is_uint(s3))
print(atol(s1))

32 changes: 6 additions & 26 deletions test_atol.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -9,35 +9,15 @@ fn test_is_int() raises:
assert_equal(False, _is_uint("9.03"), "9.03 is not a valid uint.")


fn test_atol() raises:
fn test_simd_atol() raises:
assert_equal(375, atol(String("375")))
assert_equal(1, atol(String("001")))
assert_equal(-89, atol(String("-89")))
assert_equal(5852010871235579, atol(String("5852010871235579")))
assert_equal(9999, atol(String("9999")))
assert_equal(0, atol(String("0000")))
assert_equal(0, atol(String("0")))

# Negative cases
try:
_ = atol(String("9.03"))
raise Error("Failed to raise when converting string to integer.")
except e:
assert_equal(str(e), "String is not convertible to integer.")

try:
_ = atol(String(""))
raise Error("Failed to raise when converting empty string to integer.")
except e:
assert_equal(str(e), "Empty String cannot be converted to integer.")

try:
_ = atol(String("9223372036854775832"))
raise Error(
"Failed to raise when converting an integer too large to store in"
" Int."
)
except e:
assert_equal(
str(e), "String expresses an integer too large to store in Int."
)

fn main() raises:
#test_atol()
test_is_int()
test_simd_atol()

0 comments on commit 672e207

Please sign in to comment.