From 088f4399296a038123712a2ac7e263c1b1b95714 Mon Sep 17 00:00:00 2001
From: Neil Mehta
Date: Wed, 15 Jan 2025 07:26:47 -0500
Subject: [PATCH] Preserve dtype of array when converting to torch (#1349)

We have been noticing the following error with a recent version of outlines
when used with MLX:

```
TypeError: argument 'token_id': 'float' object cannot be interpreted as an integer

At:
  /.../outlines_core/fsm/guide.py(294): get_next_state
  /.../outlines/processors/structured.py(101): process_logits
  /.../outlines/processors/base_logits_processor.py(90): __call__
```

The issue is that the MLX array of tokens, which holds integers, is being
force-converted to floats, even though outlines expects an integer array.
This happens because every MLX array is converted to `float32`, even when
that is not appropriate, as in this case.

According to the [link in the code comment](https://ml-explore.github.io/mlx/build/html/usage/numpy.html#pytorch),
the conversion to `float32` is only needed for `bfloat16`, because numpy does
not support `bfloat16`.

With this change, the MLX `_to_torch` implementation matches the other array
libraries: none of them force-cast to float.
---
 outlines/processors/base_logits_processor.py | 6 +++---
 tests/processors/test_base_processor.py      | 8 +++++++-
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/outlines/processors/base_logits_processor.py b/outlines/processors/base_logits_processor.py
index 44b55af2e..800e69f79 100644
--- a/outlines/processors/base_logits_processor.py
+++ b/outlines/processors/base_logits_processor.py
@@ -110,9 +110,9 @@ def _to_torch(tensor_like: Array) -> torch.Tensor:
             import mlx.core as mx
 
             # https://ml-explore.github.io/mlx/build/html/usage/numpy.html#pytorch
-            return torch.from_dlpack(
-                np.array(tensor_like.astype(mx.float32), copy=False)
-            )
+            if tensor_like.dtype == mx.bfloat16:
+                tensor_like = tensor_like.astype(mx.float32)
+            return torch.from_dlpack(np.array(tensor_like, copy=False))
 
         elif is_jax_array_type(type(tensor_like)):
             import jax
diff --git a/tests/processors/test_base_processor.py b/tests/processors/test_base_processor.py
index cd9f48278..d2a1e1af2 100644
--- a/tests/processors/test_base_processor.py
+++ b/tests/processors/test_base_processor.py
@@ -18,6 +18,7 @@
     import mlx.core as mx
 
     arrays["mlx"] = mx.array([[1, 2], [3, 4]], dtype=mx.float32)
+    arrays["mlx_bfloat16"] = mx.array([[1, 2], [3, 4]], dtype=mx.bfloat16)
 except ImportError:
     pass
 
@@ -59,7 +60,12 @@ def test_from_torch(array_type, processor):
     torch_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
     data = processor._from_torch(torch_tensor, type(arrays[array_type]))
     assert isinstance(data, type(arrays[array_type]))
-    assert np.allclose(data, arrays[array_type])
+    if array_type == "mlx_bfloat16":
+        # For bfloat16, we expect the output to be float32 due to the conversion
+        assert data.dtype == mx.float32
+        assert np.allclose(np.array(data), np.array([[1, 2], [3, 4]], dtype=np.float32))
+    else:
+        assert np.allclose(data, arrays[array_type])
 
 
 @pytest.mark.parametrize("array_type", arrays.keys())
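
As a quick illustration of the new behavior, here is a minimal standalone sketch (not part of the patch) that mirrors the patched MLX branch of `_to_torch`. The helper name `mlx_to_torch` is made up for the example, and it assumes `mlx`, `numpy`, and `torch` are installed:

```
# Minimal sketch mirroring the patched MLX branch of `_to_torch`.
# `mlx_to_torch` is an illustrative name, not part of the outlines API.
import mlx.core as mx
import numpy as np
import torch


def mlx_to_torch(tensor_like: mx.array) -> torch.Tensor:
    # numpy has no bfloat16, so only that dtype is upcast to float32;
    # every other dtype passes through unchanged.
    if tensor_like.dtype == mx.bfloat16:
        tensor_like = tensor_like.astype(mx.float32)
    return torch.from_dlpack(np.array(tensor_like, copy=False))


token_ids = mx.array([[1, 2], [3, 4]])              # integer MLX array
print(mlx_to_torch(token_ids).dtype)                # an integer torch dtype, no longer float32

logits = mx.array([[0.1, 0.2]], dtype=mx.bfloat16)  # bfloat16 MLX array
print(mlx_to_torch(logits).dtype)                   # torch.float32, via the bfloat16 upcast
```

Only `bfloat16` is special-cased because it is the one dtype numpy cannot represent; every other dtype round-trips through numpy unchanged.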