Updating an array with multiple indices crashes the compilation pipeline #1446
Thanks for submitting the report, @positr0nium! The case in your example is actually known not to work: it happens when all possible array indices are passed to the … Could you confirm whether you find it useful to use … That said, the compiler shouldn't crash even in this case, so better messaging would certainly be nice.
This seems to crash too :(
Just to confirm @positr0nium, do you get the same error?

```
CompileError: Compilation failed:
test_function:28:13: error: Indices are not unique and/or not sorted, unique boolean: 0, sorted boolean :0
%10 = "stablehlo.scatter"(%0, %8, %9) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
^
test_function:28:13: note: see current operation:
%15 = "mhlo.scatter"(%2, %14, %1) <{indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
"mhlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<10xi64>, tensor<9x1xi32>, tensor<9xi64>) -> tensor<10xi64>
test_function:28:13: error: Indices are not unique and/or not sorted, unique boolean: 0, sorted boolean :0
%10 = "stablehlo.scatter"(%0, %8, %9) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
^
test_function:28:13: note: see current operation:
%15 = "mhlo.scatter"(%2, %14, %1) <{indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
"mhlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<10xi64>, tensor<9x1xi32>, tensor<9xi64>) -> tensor<10xi64>
test_function:28:13: error: failed to legalize operation 'mhlo.scatter'
%10 = "stablehlo.scatter"(%0, %8, %9) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
^
test_function:28:13: note: see current operation:
%16 = "mhlo.scatter"(%1, %15, %3) <{indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
"mhlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<10xi64>, tensor<9x1xi32>, tensor<9xi64>) -> tensor<10xi64>
While processing 'FinalizingBufferize' pass
While processing 'mlir::detail::OpToOpPassAdaptor' pass
Failed to lower MLIR module
```
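The log reports `unique boolean: 0, sorted boolean :0`, which appears to mean the scatter reached the lowering without being marked as having unique, sorted indices. Before declaring those properties to the compiler, it can help to confirm they actually hold. The helper below is a hypothetical sketch (not part of Catalyst or JAX) and assumes the index array is concrete, i.e. checked on the host outside `@qjit`, not a traced value. It tests for strictly increasing values, which implies both sorted and unique. This matters because JAX does not verify these flags: the `jax.lax.scatter` documentation states that behavior is undefined if `unique_indices=True` is set but the updated elements overlap.

```python
import numpy as np


def indices_are_sorted_and_unique(idx) -> bool:
    """Hypothetical helper: True if a concrete index array is strictly increasing.

    Strictly increasing implies both properties promised by
    indices_are_sorted=True and unique_indices=True.
    Assumes `idx` is a concrete (non-traced) array or sequence.
    """
    flat = np.asarray(idx).ravel()
    # Compare each element with its predecessor; empty or length-1 input passes.
    return bool(np.all(flat[1:] > flat[:-1]))


# Example usage on the host, before calling the compiled function:
print(indices_are_sorted_and_unique(np.arange(9)))  # True
print(indices_are_sorted_and_unique([0, 2, 2]))     # False (repeated index)
print(indices_are_sorted_and_unique([3, 1]))        # False (not sorted)
```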
Yes, this is the message I also receive.
Actually, this use case is fully supported, but it has some restrictions that are difficult for us to check or document, since they relate to a JAX library function. Catalyst only supports slices with the …

```python
from catalyst import qjit
import jax.numpy as jnp

@qjit
def test_function():
    N = 10
    jax_array = jnp.zeros(N, dtype=jnp.int64)
    idx_array = jnp.arange(N - 1, dtype=jnp.int64)
    init_val = jnp.ones(N - 1, dtype=jnp.int64)
    # Both flags must be set for Catalyst to lower this scatter
    return jax_array.at[idx_array].set(init_val, indices_are_sorted=True, unique_indices=True)
```

```pycon
>>> test_function()
Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], dtype=int64)
```

@josh146 Do you know how we could document this better? Maybe here? We could also try improving the error message to explicitly instruct the user to use these flags.
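To sanity-check the result above without Catalyst, here is a minimal NumPy sketch of what `jax_array.at[idx_array].set(init_val)` computes. The helper name `at_set_reference` is hypothetical, not a JAX or Catalyst API. JAX arrays are immutable, so `.at[...].set(...)` returns a new array; the sketch mimics that by copying the input before assigning.

```python
import numpy as np


def at_set_reference(arr, idx, vals):
    """Hypothetical NumPy reference for the functional update arr.at[idx].set(vals).

    Copies the input so the original array is left unchanged, then assigns
    `vals` at positions `idx`. Only matches JAX when `idx` has no repeats.
    """
    out = np.array(arr, copy=True)
    out[idx] = vals
    return out


N = 10
result = at_set_reference(
    np.zeros(N, dtype=np.int64),
    np.arange(N - 1, dtype=np.int64),
    np.ones(N - 1, dtype=np.int64),
)
print(result)  # [1 1 1 1 1 1 1 1 1 0]
```

This reference only agrees with JAX when the indices are unique: for repeated indices, JAX documents the order in which conflicting updates are applied as implementation-defined, whereas NumPy keeps the last assignment.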
Hi,
When using the `.at` methods with multiple indices, the MLIR pipeline crashes. Code example: