diff --git a/sklearn_numba_dpex/common/kernels.py b/sklearn_numba_dpex/common/kernels.py index cb3f109..670aac9 100644 --- a/sklearn_numba_dpex/common/kernels.py +++ b/sklearn_numba_dpex/common/kernels.py @@ -112,6 +112,36 @@ def broadcast_division(dividend_array, divisor_vector): return broadcast_division[global_size, work_group_size] +def make_check_all_equal_kernel(shape, work_group_size): + n_items = math.prod(shape) + global_size = math.ceil(n_items / work_group_size) * work_group_size + zero_idx = np.uint32(0) + one_idx = np.uint32(1) + + @dpex.kernel + def check_all_equal_kernel(left, right, all_equal): + item_idx = dpex.get_global_id(0) + + if item_idx >= n_items: + return + + current_all_equal_value = all_equal[zero_idx] + + if current_all_equal_value == zero_idx: + return + + if left[item_idx] != right[item_idx]: + all_equal[zero_idx] = zero_idx + + def check_all_equal(left, right, all_equal): + left = dpt.reshape(left, (-1,)) + right = dpt.reshape(right, (-1,)) + all_equal[zero_idx] = one_idx + check_all_equal_kernel[global_size, work_group_size](left, right, all_equal) + + return check_all_equal + + @lru_cache def make_broadcast_ops_1d_2d_axis1_kernel(shape, ops, work_group_size, dtype): """ @@ -434,6 +464,9 @@ def _make_partial_sum_reduction_2d_axis1_kernel( reduction_block_size = 2 * work_group_size work_group_shape = (1, work_group_size) + local_sum_and_set_items_if = _make_sum_and_set_items_if_kernel_func() + global_sum_and_set_items_if = _make_sum_and_set_items_if_kernel_func() + @dpex.kernel # fmt: off def partial_sum_reduction( @@ -499,18 +532,16 @@ def partial_sum_reduction( local_values = dpex.local.array(work_group_size, dtype=dtype) - # We must be careful to not read items outside of the array ! - if augend_idx >= sum_axis_size: - local_values[local_work_id] = zero - elif addend_idx >= sum_axis_size: - local_values[local_work_id] = ( - fused_elementwise_func(summands[row_idx, augend_idx]) - ) - else: - local_values[local_work_id] = ( - fused_elementwise_func(summands[row_idx, augend_idx]) + - fused_elementwise_func(summands[row_idx, addend_idx]) - ) + _prepare_local_memory( + local_work_id, + row_idx, + augend_idx, + addend_idx, + sum_axis_size, + summands, + # OUT + local_values + ) dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE) @@ -525,26 +556,62 @@ def partial_sum_reduction( # are discarded. n_active_work_items = n_active_work_items // two_as_a_long work_item_idx = first_value_idx + local_work_id + n_active_work_items - if ( - (local_work_id < n_active_work_items) and - (work_item_idx < sum_axis_size) - ): - # Yet again, the remaining work items choose two values to sum such - # that contiguous work items read and write into contiguous slots of - # `local_values`. - local_values[local_work_id] += ( - local_values[local_work_id + n_active_work_items] - ) + + # Yet again, the remaining work items choose two values to sum such that + # contiguous work items read and write into contiguous slots of + # `local_values`. 
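The pairwise scheme described in the comment above can be modeled on the host. Below is a minimal NumPy sketch of one work group's halving loop, as used by `partial_sum_reduction` (illustrative only; the function name is hypothetical):

import numpy as np

def model_work_group_sum(local_values):
    # Host-side model of the in-kernel halving loop: at each step, each of
    # the remaining active work items adds the value held by its mirror item
    # in the upper half, so contiguous work items keep reading and writing
    # contiguous slots of `local_values`.
    local_values = local_values.copy()
    n_active_work_items = local_values.shape[0]
    while n_active_work_items > 1:
        n_active_work_items //= 2
        for local_work_id in range(n_active_work_items):
            local_values[local_work_id] += (
                local_values[local_work_id + n_active_work_items]
            )
    return local_values[0]

values = np.arange(8, dtype=np.float32)
assert model_work_group_sum(values) == values.sum()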
+ local_sum_and_set_items_if( + ( + (local_work_id < n_active_work_items) and + (work_item_idx < sum_axis_size) + ), + local_work_id, + local_work_id, + local_work_id + n_active_work_items, + local_values, + # OUT + local_values + ) dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE) # At this point local_values[0] + local_values[1] is equal to the sum of all # elements in summands that have been covered by the work group, we write it # into global memory - if local_work_id == zero_idx: - result[row_idx, local_work_group_id_in_row] = ( - local_values[zero_idx] + local_values[one_idx] + global_sum_and_set_items_if( + local_work_id == zero_idx, + (row_idx, local_work_group_id_in_row), + zero_idx, + one_idx, + local_values, + # OUT + result + ) + + # HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_hack_906 + @dpex.func + # fmt: off + def _prepare_local_memory( + local_work_id, # PARAM + row_idx, # PARAM + augend_idx, # PARAM + addend_idx, # PARAM + sum_axis_size, # PARAM + summands, # IN + local_values, # OUT + ): + # fmt: on + # We must be careful to not read items outside of the array ! + if augend_idx >= sum_axis_size: + local_values[local_work_id] = zero + elif addend_idx >= sum_axis_size: + local_values[local_work_id] = fused_elementwise_func( + summands[row_idx, augend_idx] ) + else: + local_values[local_work_id] = fused_elementwise_func( + summands[row_idx, augend_idx] + ) + fused_elementwise_func(summands[row_idx, addend_idx]) return work_group_shape, reduction_block_size, partial_sum_reduction @@ -633,6 +700,9 @@ def _make_partial_sum_reduction_2d_axis0_kernel( reduction_block_size = 2 * n_sub_groups_per_work_group work_group_shape = (n_sub_groups_per_work_group, sub_group_size) + local_sum_and_set_items_if = _make_sum_and_set_items_if_kernel_func() + global_sum_and_set_items_if = _make_sum_and_set_items_if_kernel_func() + # ???: how does this strategy compares to having each thread reducing N contiguous # items ? @dpex.kernel @@ -703,19 +773,18 @@ def partial_sum_reduction( (dpex.get_group_id(one_idx) * sub_group_size) + local_col_idx ) - # We must be careful to not read items outside of the array ! sum_axis_size = summands.shape[zero_idx] - if (col_idx >= n_cols) or (augend_row_idx >= sum_axis_size): - local_values[local_row_idx, local_col_idx] = zero - elif addend_row_idx >= sum_axis_size: - local_values[local_row_idx, local_col_idx] = ( - fused_elementwise_func(summands[augend_row_idx, col_idx]) - ) - else: - local_values[local_row_idx, local_col_idx] = ( - fused_elementwise_func(summands[augend_row_idx, col_idx]) + - fused_elementwise_func(summands[addend_row_idx, col_idx]) - ) + _prepare_local_memory( + local_row_idx, + local_col_idx, + col_idx, + augend_row_idx, + addend_row_idx, + sum_axis_size, + summands, + # OUT + local_values + ) dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE) @@ -731,29 +800,86 @@ def partial_sum_reduction( # are discarded. 
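The axis-0 variant applies the same halving scheme per column: sub-group rows are folded pairwise onto the rows above them, and each column is reduced independently. A host-side NumPy model of one work group's block (illustrative only, not part of the patch):

import numpy as np

def model_axis0_block_reduction(local_values):
    # Rows of the block are folded pairwise onto the top rows; every column
    # is reduced independently, which keeps per-sub-group memory accesses
    # contiguous.
    local_values = local_values.copy()
    n_active_sub_groups = local_values.shape[0]
    while n_active_sub_groups > 1:
        n_active_sub_groups //= 2
        for local_row_idx in range(n_active_sub_groups):
            local_values[local_row_idx] += (
                local_values[local_row_idx + n_active_sub_groups]
            )
    return local_values[0]

block = np.arange(32, dtype=np.float32).reshape(8, 4)
assert np.array_equal(model_axis0_block_reduction(block), block.sum(axis=0))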
n_active_sub_groups = n_active_sub_groups // two_as_a_long
 work_item_row_idx = first_row_idx + local_row_idx + n_active_sub_groups
- if (
- (local_row_idx < n_active_sub_groups) and
- (col_idx < n_cols) and
- (work_item_row_idx < sum_axis_size)
- ):
- local_values[local_row_idx, local_col_idx] += (
- local_values[local_row_idx + n_active_sub_groups, local_col_idx]
- )
+
+ local_sum_and_set_items_if(
+ (
+ (local_row_idx < n_active_sub_groups) and
+ (col_idx < n_cols) and
+ (work_item_row_idx < sum_axis_size)
+ ),
+ (local_row_idx, local_col_idx),
+ (local_row_idx, local_col_idx),
+ (local_row_idx + n_active_sub_groups, local_col_idx),
+ local_values,
+ # OUT
+ local_values
+ )

 dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)

 # At this point local_values[0, :] + local_values[1, :] is equal to the sum of
 # all elements in summands that have been covered by the work group, we write
 # it into global memory
- if (local_row_idx == zero_idx) and (col_idx < n_cols):
- result[local_block_id_in_col, col_idx] = (
- local_values[zero_idx, local_col_idx] +
- local_values[one_idx, local_col_idx]
+ global_sum_and_set_items_if(
+ (local_row_idx == zero_idx) and (col_idx < n_cols),
+ (local_block_id_in_col, col_idx),
+ (zero_idx, local_col_idx),
+ (one_idx, local_col_idx),
+ local_values,
+ # OUT
+ result
+ )
+
+ # HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_hack_906
+ @dpex.func
+ # fmt: off
+ def _prepare_local_memory(
+ local_row_idx, # PARAM
+ local_col_idx, # PARAM
+ col_idx, # PARAM
+ augend_row_idx, # PARAM
+ addend_row_idx, # PARAM
+ sum_axis_size, # PARAM
+ summands, # IN
+ local_values, # OUT
+ ):
+ # fmt: on
+ # We must be careful to not read items outside of the array !
+ if (col_idx >= n_cols) or (augend_row_idx >= sum_axis_size):
+ local_values[local_row_idx, local_col_idx] = zero
+ elif addend_row_idx >= sum_axis_size:
+ local_values[local_row_idx, local_col_idx] = fused_elementwise_func(
+ summands[augend_row_idx, col_idx]
 )
+ else:
+ local_values[local_row_idx, local_col_idx] = fused_elementwise_func(
+ summands[augend_row_idx, col_idx]
+ ) + fused_elementwise_func(summands[addend_row_idx, col_idx])

 return work_group_shape, reduction_block_size, partial_sum_reduction


+# HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_hack_906
+def _make_sum_and_set_items_if_kernel_func():
+ @dpex.func
+ # fmt: off
+ def set_sum_of_items_kernel_func(
+ condition, # PARAM
+ result_idx, # PARAM
+ addend_idx, # PARAM
+ augend_idx, # PARAM
+ summands, # IN
+ result # OUT
+ ):
+ # fmt: on
+ if not condition:
+ return
+
+ result[result_idx] = summands[addend_idx] + summands[augend_idx]
+
+ return set_sum_of_items_kernel_func
+
+
 @lru_cache
 def make_argmin_reduction_1d_kernel(size, device, dtype, work_group_size="max"):
 """Implement 1d argmin with the same strategy than for
@@ -805,55 +931,127 @@ def partial_argmin_reduction(
 local_argmin = dpex.local.array(work_group_size, dtype=local_argmin_dtype)
 local_values = dpex.local.array(work_group_size, dtype=dtype)

+ _prepare_local_memory(
+ local_work_id,
+ group_id,
+ current_size,
+ has_previous_result,
+ previous_result,
+ values,
+ # OUT
+ local_argmin,
+ local_values,
+ )
+
+ dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)
+ n_active_work_items = work_group_size
+ for i in range(n_local_iterations):
+ n_active_work_items = n_active_work_items // two_as_a_long
+ _local_iteration(
+ local_work_id,
+ n_active_work_items,
+ # OUT
+ local_values,
+ local_argmin
+ )
+ dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)
+
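The loop above carries an index along with each value, so the reduction produces an argmin rather than a sum. A host-side NumPy model of the `(local_values, local_argmin)` pair reduction (illustrative only):

import numpy as np

def model_local_argmin(values):
    # Host-side model of the in-kernel loop: each active work item keeps the
    # smaller of its value and the value held by its mirror item, along with
    # the corresponding index.
    local_values = values.astype(np.float32).copy()
    local_argmin = np.arange(len(values))
    n_active_work_items = len(values)
    while n_active_work_items > 1:
        n_active_work_items //= 2
        for local_x_idx in range(n_active_work_items):
            local_y_idx = local_x_idx + n_active_work_items
            if local_values[local_x_idx] > local_values[local_y_idx]:
                local_values[local_x_idx] = local_values[local_y_idx]
                local_argmin[local_x_idx] = local_argmin[local_y_idx]
    return local_argmin[0]

values = np.asarray([3.0, 1.0, 2.0, 0.5])
assert model_local_argmin(values) == np.argmin(values)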
+ _register_result(
+ first_work_id,
+ group_id,
+ local_argmin,
+ local_values,
+ # OUT
+ argmin_indices
+ )

+ # HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_hack_906
+ @dpex.func
+ # fmt: off
+ def _prepare_local_memory(
+ local_work_id, # PARAM
+ group_id, # PARAM
+ current_size, # PARAM
+ has_previous_result, # PARAM
+ previous_result, # IN
+ values, # IN
+ local_argmin, # OUT
+ local_values, # OUT
+ ):
+ # fmt: on
 first_value_idx = group_id * work_group_size * two_as_a_long
 x_idx = first_value_idx + local_work_id
- y_idx = first_value_idx + work_group_size + local_work_id
 if x_idx >= current_size:
 local_values[local_work_id] = inf
- else:
- if has_previous_result:
- x_idx = previous_result[x_idx]
-
- if y_idx >= current_size:
- local_argmin[local_work_id] = x_idx
- local_values[local_work_id] = values[x_idx]
-
- else:
- if has_previous_result:
- y_idx = previous_result[y_idx]
-
- x = values[x_idx]
- y = values[y_idx]
- if x < y or (x == y and x_idx < y_idx):
- local_argmin[local_work_id] = x_idx
- local_values[local_work_id] = x
- else:
- local_argmin[local_work_id] = y_idx
- local_values[local_work_id] = y
+ return

- dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)
- n_active_work_items = work_group_size
- for i in range(n_local_iterations):
- n_active_work_items = n_active_work_items // two_as_a_long
- if local_work_id < n_active_work_items:
- local_x_idx = local_work_id
- local_y_idx = local_work_id + n_active_work_items
+ if has_previous_result:
+ x_idx = previous_result[x_idx]
+
+ y_idx = first_value_idx + work_group_size + local_work_id

- x = local_values[local_x_idx]
- y = local_values[local_y_idx]
+ if y_idx >= current_size:
+ local_argmin[local_work_id] = x_idx
+ local_values[local_work_id] = values[x_idx]
+ return

- if x > y:
- local_values[local_x_idx] = y
- local_argmin[local_x_idx] = local_argmin[local_y_idx]
+ if has_previous_result:
+ y_idx = previous_result[y_idx]

- dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)
+ x = values[x_idx]
+ y = values[y_idx]
+ if x < y or (x == y and x_idx < y_idx):
+ local_argmin[local_work_id] = x_idx
+ local_values[local_work_id] = x
+ return
+
+ local_argmin[local_work_id] = y_idx
+ local_values[local_work_id] = y

- if first_work_id:
- if local_values[zero_idx] <= local_values[one_idx]:
- argmin_indices[group_id] = local_argmin[zero_idx]
- else:
- argmin_indices[group_id] = local_argmin[one_idx]
+ # HACK 906
+ @dpex.func
+ # fmt: off
+ def _local_iteration(
+ local_work_id, # PARAM
+ n_active_work_items, # PARAM
+ local_values, # INOUT
+ local_argmin # OUT
+ ):
+ # fmt: on
+ if local_work_id >= n_active_work_items:
+ return
+
+ local_x_idx = local_work_id
+ local_y_idx = local_work_id + n_active_work_items
+
+ x = local_values[local_x_idx]
+ y = local_values[local_y_idx]
+
+ if x <= y:
+ return
+
+ local_values[local_x_idx] = y
+ local_argmin[local_x_idx] = local_argmin[local_y_idx]

+ # HACK 906
+ @dpex.func
+ # fmt: off
+ def _register_result(
+ first_work_id, # PARAM
+ group_id, # PARAM
+ local_argmin, # IN
+ local_values, # IN
+ argmin_indices # OUT
+ ):
+ # fmt: on
+ if not first_work_id:
+ return
+
+ if local_values[zero_idx] <= local_values[one_idx]:
+ argmin_indices[group_id] = local_argmin[zero_idx]
+ else:
+ argmin_indices[group_id] = local_argmin[one_idx]

 # As many partial reductions as necessary are chained until only one element
 # remains.
diff --git a/sklearn_numba_dpex/common/tests/__init__.py b/sklearn_numba_dpex/common/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git 
a/sklearn_numba_dpex/kmeans/drivers.py b/sklearn_numba_dpex/kmeans/drivers.py index 121228b..2510ab8 100644 --- a/sklearn_numba_dpex/kmeans/drivers.py +++ b/sklearn_numba_dpex/kmeans/drivers.py @@ -18,6 +18,7 @@ make_argmin_reduction_1d_kernel, make_broadcast_division_1d_2d_axis0_kernel, make_broadcast_ops_1d_2d_axis1_kernel, + make_check_all_equal_kernel, make_half_l2_norm_2d_axis0_kernel, make_initialize_to_zeros_kernel, make_sum_reduction_2d_kernel, @@ -153,6 +154,10 @@ def lloyd( dtype=compute_dtype, ) + check_all_equal_kernel = make_check_all_equal_kernel( + (n_samples,), max_work_group_size + ) + # Allocate the necessary memory in the device global memory new_centroids_t = dpt.empty_like(centroids_t, device=device) centroids_half_l2_norm = dpt.empty(n_clusters, dtype=compute_dtype, device=device) @@ -163,7 +168,18 @@ def lloyd( sq_dist_to_nearest_centroid = per_sample_inertia = dpt.empty( n_samples, dtype=compute_dtype, device=device ) - assignments_idx = dpt.empty(n_samples, dtype=np.uint32, device=device) + + assignments_idx_new = dpt.empty(n_samples, dtype=np.uint32, device=device) + check_strict_convergence = tol == 0 + # See the main loop for a more elaborate note about checking "strict convergence" + if check_strict_convergence: + assignments_idx = dpt.empty(n_samples, dtype=np.uint32, device=device) + # allocation of one scalar where we store the result of array comparisons + check_equal_result = dpt.empty(1, dtype=np.uint32, device=device) + strict_convergence = False + else: + assignments_idx = assignments_idx_new + new_centroids_t_private_copies = dpt.empty( (n_centroids_private_copies, n_features, n_clusters), dtype=compute_dtype, @@ -203,7 +219,7 @@ def lloyd( centroids_t, centroids_half_l2_norm, # OUT: - assignments_idx, + assignments_idx_new, new_centroids_t_private_copies, cluster_sizes_private_copies, ) @@ -226,7 +242,7 @@ def lloyd( X_t, sample_weight, new_centroids_t, - assignments_idx, + assignments_idx_new, # OUT: per_sample_inertia, ) @@ -249,7 +265,7 @@ def lloyd( centroids_t, centroids_half_l2_norm, # OUT: - assignments_idx, + assignments_idx_new, ) # if verbose is True and if sample_weight is uniform, distances to @@ -262,7 +278,7 @@ def lloyd( X_t, dpt.ones_like(sample_weight), centroids_t, - assignments_idx, + assignments_idx_new, # OUT: sq_dist_to_nearest_centroid, ) @@ -273,7 +289,7 @@ def lloyd( sample_weight, new_centroids_t, cluster_sizes, - assignments_idx, + assignments_idx_new, empty_clusters_list, sq_dist_to_nearest_centroid, per_sample_inertia, @@ -283,17 +299,6 @@ def lloyd( # Change `new_centroids_t` inplace broadcast_division_kernel(new_centroids_t, cluster_sizes) - compute_centroid_shifts_kernel( - centroids_t, - new_centroids_t, - # OUT: - centroid_shifts, - ) - - centroid_shifts_sum, *_ = reduce_centroid_shifts_kernel(centroid_shifts) - # Use numpy type to work around https://github.com/IntelPython/dpnp/issues/1238 - centroid_shifts_sum = compute_dtype(centroid_shifts_sum) - # ???: unlike sklearn, sklearn_intelex checks that pseudo_inertia decreases # and keep an additional copy of centroids that is updated only if the # value of the pseudo_inertia is smaller than all the past values. @@ -313,14 +318,57 @@ def lloyd( # refers. For this reason, this strategy is not compatible with sklearn # unit tests, that consider new_centroids_t (the array after the update) # to be the best centroids at each iteration. 
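`check_all_equal_kernel`, used further below for the strict convergence check, follows a simple contract: the flag is initialized to 1 up front and any work item that observes a mismatch overwrites it with 0. A host-side NumPy model of that contract (illustrative only):

import numpy as np

def model_check_all_equal(left, right):
    # The device flag is initialized to 1; each work item compares one pair
    # of items and writes 0 on a mismatch. Work items that read an
    # already-zeroed flag return early.
    all_equal = np.ones(1, dtype=np.uint32)
    for item_idx in range(left.size):
        if all_equal[0] == 0:
            break
        if left.flat[item_idx] != right.flat[item_idx]:
            all_equal[0] = 0
    return bool(all_equal[0])

a = np.array([1, 2, 3], dtype=np.uint32)
assert model_check_all_equal(a, a)
assert not model_check_all_equal(a, a[::-1])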
- centroids_t, new_centroids_t = (new_centroids_t, centroids_t)
+ # ???: if two successive assignments have been computed equal, it's called
+ # "strict convergence" and means that the algorithm has converged and can't get
+ # better (since the same assignments will produce the same centroid updates,
+ # which will produce the same assignments, and so on...). scikit-learn decides
+ # to check for strict convergence at each iteration, but that sounds expensive
+ # since it requires an additional pass over the sample labels every time.
+ # Provided the user chooses a sensible value for `tol`, wouldn't the cost of
+ # this check in general outweigh its benefits?
+
+ # Following this reasoning, unlike scikit-learn, we choose to enforce this
+ # behavior if and only if `tol == 0`, because in this case, it is easy to see
+ # that lloyd can indeed fail to stop at the right time due to numerical errors.
+ # Moreover, this is enough to pass scikit-learn unit tests. When `tol > 0`, we
+ # rely on the user setting an appropriate tolerance threshold.
+
+ # TODO: open an issue at `scikit-learn` and propose to adopt this behavior
+ # instead?
+
+ if check_strict_convergence:
+ assignments_idx, assignments_idx_new = (
+ assignments_idx_new,
+ assignments_idx,
+ )
+ check_all_equal_kernel(
+ assignments_idx_new, assignments_idx, check_equal_result
+ )
+ are_assignments_equal, *_ = check_equal_result
+ if are_assignments_equal:
+ strict_convergence = True
+ break
+
+ compute_centroid_shifts_kernel(
+ centroids_t,
+ new_centroids_t,
+ # OUT:
+ centroid_shifts,
+ )
+
+ centroid_shifts_sum, *_ = reduce_centroid_shifts_kernel(centroid_shifts)
+ # Use numpy type to work around https://github.com/IntelPython/dpnp/issues/1238
+ centroid_shifts_sum = compute_dtype(centroid_shifts_sum)
+
 n_iteration += 1

 if verbose:
 converged_at = n_iteration - 1
- if centroid_shifts_sum == 0: # NB: possible if tol = 0
+ if (check_strict_convergence and strict_convergence) or (
+ centroid_shifts_sum == 0 # NB: possible if tol = 0
+ ):
 print(f"Converged at iteration {converged_at}: strict convergence.")

 elif centroid_shifts_sum <= tol:
@@ -478,7 +526,7 @@ def prepare_data_for_lloyd(X_t, init, tol, sample_weight, copy_x):
 )

 # At the time of writing this code, dpnp does not support functions (like `==`
- # operator) that would help computing `sample_weight_is_uniform` in a simpler
+ # operator) that would help computing `X_mean_is_zeroed` in a simpler
 # manner.
 # TODO: if dpnp support extends to relevant features, use it instead ?
 sum_of_squares_kernel = make_sum_reduction_2d_kernel(
@@ -539,7 +587,7 @@ def prepare_data_for_lloyd(X_t, init, tol, sample_weight, copy_x):
 variance = variance_kernel(dpt.reshape(X_t, -1))

 # Use numpy type to work around https://github.com/IntelPython/dpnp/issues/1238
- tol = (dpt.asnumpy(variance)[0] / n_features) * tol
+ tol = (dpt.asnumpy(variance)[0] / (n_features * n_samples)) * tol

 # check if sample_weight is uniform
 # At the time of writing this code, dpnp does not support functions (like `==`
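The corrected `tol` rescaling above is meant to match scikit-learn's convention, where the absolute tolerance is `np.mean(np.var(X, axis=0)) * tol`. Since `variance` here is a single sum of squared deviations over all entries of the mean-centered data, both `n_features` and `n_samples` must appear in the divisor. A NumPy sanity check (illustrative; assumes `X_t` holds mean-centered data, as done in `prepare_data_for_lloyd`):

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
X_centered = X - X.mean(axis=0)

n_samples, n_features = X.shape
tol = 1e-4

# scikit-learn's absolute tolerance for KMeans
expected = np.mean(np.var(X, axis=0)) * tol

# the driver's formulation: one global sum of squares of the centered data,
# normalized by both dimensions
variance = np.sum(X_centered**2)
actual = (variance / (n_features * n_samples)) * tol

assert np.isclose(expected, actual)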
diff --git a/sklearn_numba_dpex/kmeans/engine.py b/sklearn_numba_dpex/kmeans/engine.py
index 1027336..691f82a 100644
--- a/sklearn_numba_dpex/kmeans/engine.py
+++ b/sklearn_numba_dpex/kmeans/engine.py
@@ -206,7 +206,9 @@ def kmeans_single(self, X, sample_weight, centers_init_t):
 # XXX: having a C-contiguous centroid array is expected in sklearn in some
 # unit test and by the cython engine.
 assignments_idx = dpt.asnumpy(assignments_idx).astype(np.int32)
- best_centroids_t = np.asfortranarray(dpt.asnumpy(best_centroids_t))
+ best_centroids_t = np.asfortranarray(dpt.asnumpy(best_centroids_t)).astype(
+ self.estimator._output_dtype
+ )

 # ???: rather that returning whatever dtype the driver returns (which might
 # depends on device support for float64), shouldn't we cast to a dtype that
@@ -269,7 +271,9 @@ def get_euclidean_distances(self, X):
 )
 euclidean_distances = get_euclidean_distances(X.T, cluster_centers)
 if self._is_in_testing_mode:
- euclidean_distances = dpt.asnumpy(euclidean_distances)
+ euclidean_distances = dpt.asnumpy(euclidean_distances).astype(
+ self.estimator._output_dtype
+ )
 return euclidean_distances

 def _validate_data(self, X, reset=True):
@@ -291,6 +295,12 @@ def _validate_data(self, X, reset=True):
 else:
 accepted_dtypes = [np.float32]

+ if self._is_in_testing_mode and reset:
+ if (X_dtype := X.dtype) not in accepted_dtypes:
+ self.estimator._output_dtype = np.float64
+ else:
+ self.estimator._output_dtype = X_dtype
+
 with _validate_with_array_api(device):
 try:
 X = self.estimator._validate_data(
diff --git a/sklearn_numba_dpex/kmeans/kernels/compute_euclidean_distances.py b/sklearn_numba_dpex/kmeans/kernels/compute_euclidean_distances.py
index d05eb1a..00e43a0 100644
--- a/sklearn_numba_dpex/kmeans/kernels/compute_euclidean_distances.py
+++ b/sklearn_numba_dpex/kmeans/kernels/compute_euclidean_distances.py
@@ -118,18 +118,39 @@ def compute_distances(
 dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)

- if sample_idx < n_samples:
- for i in range(window_n_centroids):
- centroid_idx = first_centroid_idx + i
- if centroid_idx < n_clusters:
- euclidean_distances_t[first_centroid_idx + i, sample_idx] = (
- math.sqrt(sq_distances[i])
- )
+ _save_distance(
+ sample_idx,
+ first_centroid_idx,
+ sq_distances,
+ # OUT
+ euclidean_distances_t
+ )

 first_centroid_idx += window_n_centroids

 dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)

+ # HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_hack_906
+ @dpex.func
+ # fmt: off
+ def _save_distance(
+ sample_idx, # PARAM
+ first_centroid_idx, # PARAM
+ sq_distances, # IN
+ euclidean_distances_t # OUT
+ ):
+ # fmt: on
+ if sample_idx >= n_samples:
+ return
+
+ for i in range(window_n_centroids):
+ centroid_idx = first_centroid_idx + i
+
+ if centroid_idx < n_clusters:
+ euclidean_distances_t[centroid_idx, sample_idx] = (
+ math.sqrt(sq_distances[i])
+ )
+
 n_windows_for_sample = math.ceil(n_samples / window_n_centroids)

 global_size = (
diff --git a/sklearn_numba_dpex/kmeans/kernels/compute_labels.py b/sklearn_numba_dpex/kmeans/kernels/compute_labels.py
index f8f6e5b..25606d4 100644
--- a/sklearn_numba_dpex/kmeans/kernels/compute_labels.py
+++ b/sklearn_numba_dpex/kmeans/kernels/compute_labels.py
@@ -161,11 +161,20 @@ def assignment(
 dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)

- # No update step, only store min_idx in the output array
- if sample_idx >= n_samples:
- return
+ _setitem_if(
+ sample_idx < n_samples,
+ sample_idx,
+ min_idx,
+ # OUT
+ assignments_idx,
+ )

- assignments_idx[sample_idx] = min_idx
+ # HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_hack_906
+ @dpex.func
+ def _setitem_if(condition, index, value, array):
+ if condition:
+ array[index] = value
+ return condition

 n_windows_for_sample = math.ceil(n_samples / window_n_centroids)
diff --git a/sklearn_numba_dpex/kmeans/kernels/kmeans_plusplus.py b/sklearn_numba_dpex/kmeans/kernels/kmeans_plusplus.py
index ba22b93..707d9e9 100644
--- 
a/sklearn_numba_dpex/kmeans/kernels/kmeans_plusplus.py
+++ b/sklearn_numba_dpex/kmeans/kernels/kmeans_plusplus.py
@@ -216,24 +216,43 @@ def kmeansplusplus_single_step(
 dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)

- if sample_idx < n_samples:
- sample_weight_ = sample_weight[sample_idx]
- closest_dist_sq_ = closest_dist_sq[sample_idx]
- for i in range(window_n_candidates):
- candidate_idx = first_candidate_idx + i
- if candidate_idx < n_candidates:
- sq_distance_i = min(
- sq_distances[i] * sample_weight_,
- closest_dist_sq_
- )
- sq_distances_t[first_candidate_idx + i, sample_idx] = (
- sq_distance_i
- )
+ _save_sq_distances(
+ sample_idx,
+ first_candidate_idx,
+ sq_distances,
+ sample_weight,
+ closest_dist_sq,
+ # OUT
+ sq_distances_t
+ )

 first_candidate_idx += window_n_candidates

 dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)

+ # HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_hack_906
+ @dpex.func
+ # fmt: off
+ def _save_sq_distances(
+ sample_idx, # PARAM
+ first_candidate_idx, # PARAM
+ sq_distances, # IN
+ sample_weight, # IN
+ closest_dist_sq, # IN
+ sq_distances_t, # OUT
+ ):
+ # fmt: on
+ if sample_idx >= n_samples:
+ return
+
+ sample_weight_ = sample_weight[sample_idx]
+ closest_dist_sq_ = closest_dist_sq[sample_idx]
+ for i in range(window_n_candidates):
+ candidate_idx = first_candidate_idx + i
+ if candidate_idx < n_candidates:
+ sq_distance_i = min(sq_distances[i] * sample_weight_, closest_dist_sq_)
+ sq_distances_t[first_candidate_idx + i, sample_idx] = sq_distance_i
+
 n_windows_for_samples = math.ceil(n_samples / window_n_candidates)

 global_size = (
diff --git a/sklearn_numba_dpex/kmeans/kernels/lloyd_single_step.py b/sklearn_numba_dpex/kmeans/kernels/lloyd_single_step.py
index 5be9acb..14151e9 100644
--- a/sklearn_numba_dpex/kmeans/kernels/lloyd_single_step.py
+++ b/sklearn_numba_dpex/kmeans/kernels/lloyd_single_step.py
@@ -328,6 +328,33 @@ def fused_lloyd_single_step(
 # End of outer loop. By now min_idx and min_sample_pseudo_inertia
 # contains the expected values.

+ _update_result_data(
+ sample_idx,
+ min_idx,
+ sub_group_idx,
+ X_t,
+ sample_weight,
+ # OUT
+ assignments_idx,
+ cluster_sizes_private_copies,
+ new_centroids_t_private_copies,
+ )
+
+ # HACK 906: see sklearn_numba_dpex.patches.tests.test_patches.test_hack_906
+ @dpex.func
+ # fmt: off
+ def _update_result_data(
+ sample_idx, # PARAM
+ min_idx, # PARAM
+ sub_group_idx, # PARAM
+ X_t, # IN
+ sample_weight, # IN
+ assignments_idx, # OUT
+ cluster_sizes_private_copies, # OUT
+ new_centroids_t_private_copies, # OUT
+ ):
+ # fmt: on
+ # NB: this check can't be moved to the top of the kernel, because if a work item
 # exits early with a `return` it will never reach the barriers, thus causing a
 # deadlock.
# Early returns are only possible when there are no barriers within the
# remaining instructions.
@@ -384,6 +411,7 @@ fused_lloyd_single_step
 (privatization_idx, feature_idx, min_idx),
 X_t[feature_idx, sample_idx] * weight,
 )
+
 return (
 n_centroids_private_copies,
 fused_lloyd_single_step[global_size, work_group_shape],
diff --git a/sklearn_numba_dpex/patches/tests/test_patches.py b/sklearn_numba_dpex/patches/tests/test_patches.py
index 806a74b..45b3b3b 100644
--- a/sklearn_numba_dpex/patches/tests/test_patches.py
+++ b/sklearn_numba_dpex/patches/tests/test_patches.py
@@ -1,5 +1,6 @@
 import subprocess

+import dpctl
 import dpctl.tensor as dpt
 import numba_dpex as dpex
 import numpy as np
@@ -86,3 +87,82 @@ def test_spirv_fix():
 kmeans.fit(X_array)
 finally:
 _load_numba_dpex_with_patches()
+
+
+def test_hack_906():
+ """This test will raise when all hacks tagged with HACK 906 can be reverted.
+
+ The hack is used several times in the codebase to work around a bug in the JIT
+ compiler that affects sequences of instructions containing a conditional write
+ operation to an array followed by a barrier.
+
+ For kernels that contain such patterns, the output is sometimes wrong. See
+ https://github.com/IntelPython/numba-dpex/issues/906 for more information and
+ updates on the issue resolution.
+
+ The hack consists of wrapping the instructions that are suspected of triggering
+ the bug (basically all write operations in kernels that also contain a barrier)
+ in `dpex.func` device functions.
+
+ This hack makes the code significantly harder to read and should be reverted
+ ASAP.
+ """
+
+ dtype = np.float32
+
+ @dpex.kernel
+ def kernel(result):
+ local_idx = dpex.get_local_id(0)
+ local_values = dpex.local.array((1,), dtype=dtype)
+
+ dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)
+
+ if local_idx < 1:
+ local_values[0] = 1
+
+ dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)
+
+ if local_idx < 1:
+ result[0] = 10
+
+ result = dpt.zeros(sh=(1), dtype=dtype, device=dpctl.SyclDevice("cpu"))
+ kernel[32, 32](result)
+
+ rationale = """If this test fails, it means that the bug reported at
+ https://github.com/IntelPython/numba-dpex/issues/906 has been fixed, and all the
+ hacks tagged with `# HACK 906` that were used to work around it can now be
+ removed. This test can also be removed.
+ """
+
+ assert dpt.asnumpy(result)[0] != 10, rationale
+
+ # Test that highlights how the hack works
+ @dpex.kernel
+ def kernel(result):
+ local_idx = dpex.get_local_id(0)
+ local_values = dpex.local.array((1,), dtype=dtype)
+
+ _local_setitem_if((local_idx < 1), 0, 1, local_values)
+
+ dpex.barrier(dpex.CLK_LOCAL_MEM_FENCE)
+
+ _global_setitem_if((local_idx < 1), 0, 10, result)
+
+ # HACK: must define twice to work around the bug highlighted in test_regression_fix
+ _local_setitem_if = make_setitem_if_kernel_func()
+ _global_setitem_if = make_setitem_if_kernel_func()
+
+ result = dpt.zeros(sh=(1), dtype=dtype, device=dpctl.SyclDevice("cpu"))
+ kernel[32, 32](result)
+
+ assert dpt.asnumpy(result)[0] == 10
+
+
+# HACK 906
+def make_setitem_if_kernel_func():
+ @dpex.func
+ def _setitem_if(condition, index, value, array):
+ if condition:
+ array[index] = value
+ return condition
+
+ return _setitem_if
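One detail worth noting in the test above: `make_setitem_if_kernel_func` is called once per call site instead of sharing a single `dpex.func`, presumably so that each call site gets its own compiled specialization (one writes to local memory, the other to global memory). A plain-Python sketch of the factory pattern (illustrative only; the rationale for the two instances is an assumption based on the `# HACK: must define twice` comment, not documented numba_dpex semantics):

import numpy as np

def make_setitem_if():
    # Each call returns a distinct function object, so every call site can be
    # specialized (compiled) independently by the JIT.
    def _setitem_if(condition, index, value, array):
        if condition:
            array[index] = value
        return condition
    return _setitem_if

# One instance per call site, mirroring the kernels above:
local_setitem_if = make_setitem_if()
global_setitem_if = make_setitem_if()
assert local_setitem_if is not global_setitem_if

out = np.zeros(4)
global_setitem_if(True, 0, 10.0, out)
assert out[0] == 10.0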