diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 00cf899349aa1..e283fd7ee92b1 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -456,6 +456,7 @@ if(ARROW_COMPUTE) compute/kernels/scalar_validity.cc compute/kernels/vector_array_sort.cc compute/kernels/vector_cumulative_ops.cc + compute/kernels/vector_pairwise.cc compute/kernels/vector_nested.cc compute/kernels/vector_rank.cc compute/kernels/vector_replace.cc diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index b33e3feb72993..67595c3308f9b 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -35,6 +35,7 @@ #include "arrow/result.h" #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" +#include "arrow/util/reflection_internal.h" namespace arrow { @@ -150,6 +151,8 @@ static auto kRankOptionsType = GetFunctionOptionsType( DataMember("sort_keys", &RankOptions::sort_keys), DataMember("null_placement", &RankOptions::null_placement), DataMember("tiebreaker", &RankOptions::tiebreaker)); +static auto kPairwiseOptionsType = GetFunctionOptionsType( + DataMember("periods", &PairwiseOptions::periods)); } // namespace } // namespace internal @@ -217,6 +220,10 @@ RankOptions::RankOptions(std::vector sort_keys, NullPlacement null_plac tiebreaker(tiebreaker) {} constexpr char RankOptions::kTypeName[]; +PairwiseOptions::PairwiseOptions(int64_t periods) + : FunctionOptions(internal::kPairwiseOptionsType), periods(periods) {} +constexpr char PairwiseOptions::kTypeName[]; + namespace internal { void RegisterVectorOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType)); @@ -229,6 +236,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kSelectKOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kCumulativeOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kRankOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kPairwiseOptionsType)); } } // namespace internal @@ -338,6 +346,15 @@ Result> ValueCounts(const Datum& value, ExecContext return checked_pointer_cast(result.make_array()); } +Result> PairwiseDiff(const Array& array, + const PairwiseOptions& options, + bool check_overflow, ExecContext* ctx) { + auto func_name = check_overflow ? "pairwise_diff_checked" : "pairwise_diff"; + ARROW_ASSIGN_OR_RAISE(Datum result, + CallFunction(func_name, {Datum(array)}, &options, ctx)); + return result.make_array(); +} + // ---------------------------------------------------------------------- // Filter- and take-related selection functions diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 56bccb38c2b53..c85db1aa3ba88 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -234,6 +234,17 @@ class ARROW_EXPORT CumulativeOptions : public FunctionOptions { }; using CumulativeSumOptions = CumulativeOptions; // For backward compatibility +/// \brief Options for pairwise functions +class ARROW_EXPORT PairwiseOptions : public FunctionOptions { + public: + explicit PairwiseOptions(int64_t periods = 1); + static constexpr char const kTypeName[] = "PairwiseOptions"; + static PairwiseOptions Defaults() { return PairwiseOptions(); } + + /// Periods to shift for applying the binary operation, accepts negative values. + int64_t periods = 1; +}; + /// @} /// \brief Filter with a boolean selection filter @@ -650,6 +661,28 @@ Result CumulativeMin( const Datum& values, const CumulativeOptions& options = CumulativeOptions::Defaults(), ExecContext* ctx = NULLPTR); +/// \brief Return the first order difference of an array. +/// +/// Computes the first order difference of an array, i.e. +/// output[i] = input[i] - input[i - p] if i >= p +/// output[i] = null otherwise +/// where p is the period. For example, with p = 1, +/// Diff([1, 4, 9, 10, 15]) = [null, 3, 5, 1, 5]. +/// With p = 2, +/// Diff([1, 4, 9, 10, 15]) = [null, null, 8, 6, 6] +/// p can also be negative, in which case the diff is computed in +/// the opposite direction. +/// \param[in] array array input +/// \param[in] options options, specifying overflow behavior and period +/// \param[in] check_overflow whether to return error on overflow +/// \param[in] ctx the function execution context, optional +/// \return result as array +ARROW_EXPORT +Result> PairwiseDiff(const Array& array, + const PairwiseOptions& options, + bool check_overflow = false, + ExecContext* ctx = NULLPTR); + // ---------------------------------------------------------------------- // Deprecated functions diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h index c2c514dbb9f2f..3fbefe4a1ab7b 100644 --- a/cpp/src/arrow/compute/exec.h +++ b/cpp/src/arrow/compute/exec.h @@ -356,6 +356,9 @@ struct ARROW_EXPORT ExecResult { const std::shared_ptr& array_data() const { return std::get>(this->value); } + ArrayData* array_data_mutable() { + return std::get>(this->value).get(); + } bool is_array_data() const { return this->value.index() == 1; } }; diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index a52636aeb6bd2..5b5b5718e19dc 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -283,14 +283,16 @@ class ARROW_EXPORT OutputType { /// /// This function SHOULD _not_ be used to check for arity, that is to be /// performed one or more layers above. - using Resolver = Result (*)(KernelContext*, const std::vector&); + using Resolver = + std::function(KernelContext*, const std::vector&)>; /// \brief Output an exact type OutputType(std::shared_ptr type) // NOLINT implicit construction : kind_(FIXED), type_(std::move(type)) {} /// \brief Output a computed type depending on actual input types - OutputType(Resolver resolver) // NOLINT implicit construction + template + OutputType(Fn resolver) // NOLINT implicit construction : kind_(COMPUTED), resolver_(std::move(resolver)) {} OutputType(const OutputType& other) { diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index dcb024089475c..1afeb419c4958 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -69,6 +69,7 @@ add_arrow_benchmark(scalar_temporal_benchmark PREFIX "arrow-compute") add_arrow_compute_test(vector_test SOURCES vector_cumulative_ops_test.cc + vector_pairwise_test.cc vector_hash_test.cc vector_nested_test.cc vector_replace_test.cc diff --git a/cpp/src/arrow/compute/kernels/vector_pairwise.cc b/cpp/src/arrow/compute/kernels/vector_pairwise.cc new file mode 100644 index 0000000000000..440b1393a3ab2 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/vector_pairwise.cc @@ -0,0 +1,183 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Vector kernels for pairwise computation + +#include +#include +#include "arrow/builder.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/function.h" +#include "arrow/compute/kernel.h" +#include "arrow/compute/kernels/base_arithmetic_internal.h" +#include "arrow/compute/kernels/codegen_internal.h" +#include "arrow/compute/registry.h" +#include "arrow/compute/util.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" +#include "arrow/type_traits.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/logging.h" +#include "arrow/visit_type_inline.h" + +namespace arrow::compute::internal { + +// We reuse the kernel exec function of a scalar binary function to compute pairwise +// results. For example, for pairwise_diff, we reuse subtract's kernel exec. +struct PairwiseState : KernelState { + PairwiseState(const PairwiseOptions& options, ArrayKernelExec scalar_exec) + : periods(options.periods), scalar_exec(scalar_exec) {} + + int64_t periods; + ArrayKernelExec scalar_exec; +}; + +/// A generic pairwise implementation that can be reused by different ops. +Status PairwiseExecImpl(KernelContext* ctx, const ArraySpan& input, + const ArrayKernelExec& scalar_exec, int64_t periods, + ArrayData* result) { + // We only compute values in the region where the input-with-offset overlaps + // the original input. The margin where these do not overlap gets filled with null. + auto margin_length = std::min(abs(periods), input.length); + auto computed_length = input.length - margin_length; + auto margin_start = periods > 0 ? 0 : computed_length; + auto computed_start = periods > 0 ? margin_length : 0; + auto left_start = computed_start; + auto right_start = margin_length - computed_start; + // prepare bitmap + bit_util::ClearBitmap(result->buffers[0]->mutable_data(), margin_start, margin_length); + for (int64_t i = computed_start; i < computed_start + computed_length; i++) { + if (input.IsValid(i) && input.IsValid(i - periods)) { + bit_util::SetBit(result->buffers[0]->mutable_data(), i); + } else { + bit_util::ClearBit(result->buffers[0]->mutable_data(), i); + } + } + // prepare input span + ArraySpan left(input); + left.SetSlice(left_start, computed_length); + ArraySpan right(input); + right.SetSlice(right_start, computed_length); + // prepare output span + ArraySpan output_span; + output_span.SetMembers(*result); + output_span.offset = computed_start; + output_span.length = computed_length; + ExecResult output{output_span}; + // execute scalar function + RETURN_NOT_OK(scalar_exec(ctx, ExecSpan({left, right}, computed_length), &output)); + + return Status::OK(); +} + +Status PairwiseExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + const auto& state = checked_cast(*ctx->state()); + auto input = batch[0].array; + RETURN_NOT_OK(PairwiseExecImpl(ctx, batch[0].array, state.scalar_exec, state.periods, + out->array_data_mutable())); + return Status::OK(); +} + +const FunctionDoc pairwise_diff_doc( + "Compute first order difference of an array", + ("Computes the first order difference of an array, It internally calls \n" + "the scalar function \"subtract\" to compute \n differences, so its \n" + "behavior and supported types are the same as \n" + "\"subtract\". The period can be specified in :struct:`PairwiseOptions`.\n" + "\n" + "Results will wrap around on integer overflow. Use function \n" + "\"pairwise_diff_checked\" if you want overflow to return an error."), + {"input"}, "PairwiseOptions"); + +const FunctionDoc pairwise_diff_checked_doc( + "Compute first order difference of an array", + ("Computes the first order difference of an array, It internally calls \n" + "the scalar function \"subtract_checked\" (or the checked variant) to compute \n" + "differences, so its behavior and supported types are the same as \n" + "\"subtract_checked\". The period can be specified in :struct:`PairwiseOptions`.\n" + "\n" + "This function returns an error on overflow. For a variant that doesn't \n" + "fail on overflow, use function \"pairwise_diff\"."), + {"input"}, "PairwiseOptions"); + +const PairwiseOptions* GetDefaultPairwiseOptions() { + static const auto kDefaultPairwiseOptions = PairwiseOptions::Defaults(); + return &kDefaultPairwiseOptions; +} + +struct PairwiseKernelData { + InputType input; + OutputType output; + ArrayKernelExec exec; +}; + +void RegisterPairwiseDiffKernels(std::string_view func_name, + std::string_view base_func_name, const FunctionDoc& doc, + FunctionRegistry* registry) { + VectorKernel kernel; + kernel.can_execute_chunkwise = false; + kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE; + kernel.mem_allocation = MemAllocation::PREALLOCATE; + kernel.init = OptionsWrapper::Init; + auto func = std::make_shared(std::string(func_name), Arity::Unary(), + doc, GetDefaultPairwiseOptions()); + + auto base_func_result = registry->GetFunction(std::string(base_func_name)); + DCHECK_OK(base_func_result.status()); + const auto& base_func = checked_cast(**base_func_result); + DCHECK_EQ(base_func.arity().num_args, 2); + + for (const auto& base_func_kernel : base_func.kernels()) { + const auto& base_func_kernel_sig = base_func_kernel->signature; + if (!base_func_kernel_sig->in_types()[0].Equals( + base_func_kernel_sig->in_types()[1])) { + continue; + } + OutputType out_type(base_func_kernel_sig->out_type()); + // Need to wrap base output resolver + if (out_type.kind() == OutputType::COMPUTED) { + out_type = + OutputType([base_resolver = base_func_kernel_sig->out_type().resolver()]( + KernelContext* ctx, const std::vector& input_types) { + return base_resolver(ctx, {input_types[0], input_types[0]}); + }); + } + + kernel.signature = + KernelSignature::Make({base_func_kernel_sig->in_types()[0]}, out_type); + kernel.exec = PairwiseExec; + kernel.init = [scalar_exec = base_func_kernel->exec](KernelContext* ctx, + const KernelInitArgs& args) { + return std::make_unique( + checked_cast(*args.options), scalar_exec); + }; + DCHECK_OK(func->AddKernel(kernel)); + } + + DCHECK_OK(registry->AddFunction(std::move(func))); +} + +void RegisterVectorPairwise(FunctionRegistry* registry) { + RegisterPairwiseDiffKernels("pairwise_diff", "subtract", pairwise_diff_doc, registry); + RegisterPairwiseDiffKernels("pairwise_diff_checked", "subtract_checked", + pairwise_diff_checked_doc, registry); +} + +} // namespace arrow::compute::internal diff --git a/cpp/src/arrow/compute/kernels/vector_pairwise_test.cc b/cpp/src/arrow/compute/kernels/vector_pairwise_test.cc new file mode 100644 index 0000000000000..c77f8ecc1a403 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/vector_pairwise_test.cc @@ -0,0 +1,215 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include "arrow/compute/api_vector.h" +#include "arrow/compute/kernels/test_util.h" +#include "arrow/compute/registry.h" +#include "arrow/compute/type_fwd.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/testing/util.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" +#include "gmock/gmock.h" + +namespace arrow::compute { + +Result> GetOutputType( + const std::shared_ptr input_type) { + switch (input_type->id()) { + case Type::TIMESTAMP: { + return duration(checked_cast(*input_type).unit()); + } + case Type::TIME32: { + return duration(checked_cast(*input_type).unit()); + } + case Type::TIME64: { + return duration(checked_cast(*input_type).unit()); + } + case Type::DATE32: { + return duration(TimeUnit::SECOND); + } + case Type::DATE64: { + return duration(TimeUnit::MILLI); + } + case Type::DECIMAL128: { + const auto& real_type = checked_cast(*input_type); + return Decimal128Type::Make(real_type.precision() + 1, real_type.scale()); + } + case Type::DECIMAL256: { + const auto& real_type = checked_cast(*input_type); + return Decimal256Type::Make(real_type.precision() + 1, real_type.scale()); + } + default: { + return input_type; + } + } +} + +class TestPairwiseDiff : public ::testing::Test { + public: + void SetUp() override { + test_numerical_types_ = NumericTypes(); + test_temporal_types_ = TemporalTypes(); + test_decimal_types_ = {decimal(4, 2), decimal(70, 10)}; + + test_input_types_.insert(test_input_types_.end(), test_numerical_types_.begin(), + test_numerical_types_.end()); + test_input_types_.insert(test_input_types_.end(), test_temporal_types_.begin(), + test_temporal_types_.end()); + test_input_types_.insert(test_input_types_.end(), test_decimal_types_.begin(), + test_decimal_types_.end()); + } + + protected: + std::vector> test_numerical_types_; + std::vector> test_temporal_types_; + std::vector> test_decimal_types_; + std::vector> test_input_types_; +}; + +TEST_F(TestPairwiseDiff, Empty) { + for (int64_t period = -2; period <= 2; ++period) { + PairwiseOptions options(period); + for (auto input_type : test_input_types_) { + ASSERT_OK_AND_ASSIGN(auto output_type, GetOutputType(input_type)); + auto input = ArrayFromJSON(input_type, "[]"); + auto output = ArrayFromJSON(output_type, "[]"); + CheckVectorUnary("pairwise_diff", input, output, &options); + } + } +} + +TEST_F(TestPairwiseDiff, AllNull) { + for (int64_t period = -2; period <= 2; ++period) { + PairwiseOptions options(period); + for (auto input_type : test_input_types_) { + ASSERT_OK_AND_ASSIGN(auto output_type, GetOutputType(input_type)); + auto input = ArrayFromJSON(input_type, "[null, null, null]"); + auto output = ArrayFromJSON(output_type, "[null, null, null]"); + CheckVectorUnary("pairwise_diff", input, output, &options); + } + } +} + +TEST_F(TestPairwiseDiff, Numeric) { + { + PairwiseOptions options(1); + for (auto input_type : test_numerical_types_) { + ASSERT_OK_AND_ASSIGN(auto output_type, GetOutputType(input_type)); + auto input = ArrayFromJSON(input_type, "[null, 1, 2, null, 4, 5, 6]"); + auto output = ArrayFromJSON(output_type, "[null, null, 1, null, null, 1, 1]"); + CheckVectorUnary("pairwise_diff", input, output, &options); + } + } + + { + PairwiseOptions options(2); + for (auto input_type : test_numerical_types_) { + ASSERT_OK_AND_ASSIGN(auto output_type, GetOutputType(input_type)); + auto input = ArrayFromJSON(input_type, "[null, 1, 2, null, 4, 5, 6]"); + auto output = ArrayFromJSON(output_type, "[null, null, null, null, 2, null, 2]"); + CheckVectorUnary("pairwise_diff", input, output, &options); + } + } + + { + PairwiseOptions options(-1); + for (auto input_type : test_numerical_types_) { + ASSERT_OK_AND_ASSIGN(auto output_type, GetOutputType(input_type)); + auto input = ArrayFromJSON(input_type, "[6, 5, 4, null, 2, 1, null]"); + auto output = ArrayFromJSON(output_type, "[1, 1, null, null, 1, null, null]"); + CheckVectorUnary("pairwise_diff", input, output, &options); + } + } + + { + PairwiseOptions options(-2); + for (auto input_type : test_numerical_types_) { + ASSERT_OK_AND_ASSIGN(auto output_type, GetOutputType(input_type)); + auto input = ArrayFromJSON(input_type, "[6, 5, 4, null, 2, 1, null]"); + auto output = ArrayFromJSON(output_type, "[2, null, 2, null, null, null, null]"); + CheckVectorUnary("pairwise_diff", input, output, &options); + } + } +} + +TEST_F(TestPairwiseDiff, Overflow) { + { + PairwiseOptions options(1); + auto input = ArrayFromJSON(uint8(), "[3, 2, 1]"); + auto output = ArrayFromJSON(uint8(), "[null, 255, 255]"); + CheckVectorUnary("pairwise_diff", input, output, &options); + } + + { + PairwiseOptions options(1); + auto input = ArrayFromJSON(uint8(), "[3, 2, 1]"); + auto output = ArrayFromJSON(uint8(), "[null, 255, 255]"); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("overflow"), + CallFunction("pairwise_diff_checked", {input}, &options)); + } +} + +TEST_F(TestPairwiseDiff, Temporal) { + { + PairwiseOptions options(1); + for (auto input_type : test_temporal_types_) { + ASSERT_OK_AND_ASSIGN(auto output_type, GetOutputType(input_type)); + auto input = ArrayFromJSON(input_type, "[null, 5, 1, null, 9, 6, 37]"); + auto output = ArrayFromJSON( + output_type, + input_type->id() != Type::DATE32 // Subtract date32 results in seconds + ? "[null, null, -4, null, null, -3, 31]" + : "[null, null, -345600, null, null, -259200, 2678400]"); + CheckVectorUnary("pairwise_diff", input, output, &options); + } + } +} + +TEST_F(TestPairwiseDiff, Decimal) { + { + PairwiseOptions options(1); + auto input = ArrayFromJSON(decimal(4, 2), R"(["11.00", "22.11", "-10.25", "33.45"])"); + auto output = ArrayFromJSON(decimal(5, 2), R"([null, "11.11", "-32.36", "43.70"])"); + CheckVectorUnary("pairwise_diff", input, output, &options); + } + + { + PairwiseOptions options(-1); + auto input = ArrayFromJSON( + decimal(40, 30), + R"(["1111111111.222222222222222222222222222222", "2222222222.333333333333333333333333333333"])"); + auto output = ArrayFromJSON( + decimal(41, 30), R"(["-1111111111.111111111111111111111111111111", null])"); + CheckVectorUnary("pairwise_diff", input, output, &options); + } + + { /// Out of range decimal precision + PairwiseOptions options(1); + auto input = ArrayFromJSON(decimal(38, 0), R"(["1e38"])"); + + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + testing::HasSubstr("Decimal precision out of range"), + CallFunction("pairwise_diff", {input}, &options)); + } +} +} // namespace arrow::compute diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index fe8a83a3f6eae..a4b484a2069ea 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -310,6 +310,7 @@ static std::unique_ptr CreateBuiltInRegistry() { RegisterVectorSort(registry.get()); RegisterVectorRunEndEncode(registry.get()); RegisterVectorRunEndDecode(registry.get()); + RegisterVectorPairwise(registry.get()); // Aggregate functions RegisterHashAggregateBasic(registry.get()); diff --git a/cpp/src/arrow/compute/registry_internal.h b/cpp/src/arrow/compute/registry_internal.h index 6628d09716544..b4239701d9573 100644 --- a/cpp/src/arrow/compute/registry_internal.h +++ b/cpp/src/arrow/compute/registry_internal.h @@ -54,7 +54,7 @@ void RegisterVectorSelection(FunctionRegistry* registry); void RegisterVectorSort(FunctionRegistry* registry); void RegisterVectorRunEndEncode(FunctionRegistry* registry); void RegisterVectorRunEndDecode(FunctionRegistry* registry); - +void RegisterVectorPairwise(FunctionRegistry* registry); void RegisterVectorOptions(FunctionRegistry* registry); // Aggregate functions diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 70c17ae2b96ea..55e29588129b8 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -1847,3 +1847,28 @@ replaced, based on the remaining inputs. results in a corresponding null in the output. Also see: :ref:`if_else `. + +Pairwise functions +~~~~~~~~~~~~~~~~~~~~ +Pairwise functions are unary vector functions that perform a binary operation on +a pair of elements in the input array, typically on adjacent elements. The n-th +output is computed by applying the binary operation to the n-th and (n-p)-th inputs, +where p is the period. The default period is 1, in which case the binary +operation is applied to adjacent pairs of inputs. The period can also be +negative, in which case the n-th output is computed by applying the binary +operation to the n-th and (n+abs(p))-th inputs. + ++------------------------+-------+----------------------+----------------------+--------------------------------+----------+ +| Function name | Arity | Input types | Output type | Options class | Notes | ++========================+=======+======================+======================+================================+==========+ +| pairwise_diff | Unary | Numeric/Temporal | Numeric/Temporal | :struct:`PairwiseOptions` | \(1)(2) | ++------------------------+-------+----------------------+----------------------+--------------------------------+----------+ +| pairwise_diff_checked | Unary | Numeric/Temporal | Numeric/Temporal | :struct:`PairwiseOptions` | \(1)(3) | ++------------------------+-------+----------------------+----------------------+--------------------------------+----------+ + +* \(1) Computes the first order difference of an array, It internally calls + the scalar function ``Subtract`` (or the checked variant) to compute + differences, so its behavior and supported types are the same as + ``Subtract``. The period can be specified in :struct:`PairwiseOptions`. +* \(2) Wraps around the result when overflow is detected. +* \(3) Returns an ``Invalid`` :class:`Status` when overflow is detected. diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst index 43deedd653425..f29d4db3941fd 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -521,6 +521,14 @@ Structural Transforms replace_with_mask struct_field +Pairwise Functions +------------------ + +.. autosummary:: + :toctree: ../generated/ + + pairwise_diff + Compute Options --------------- @@ -547,6 +555,7 @@ Compute Options ModeOptions NullOptions PadOptions + PairwiseOptions PartitionNthOptions QuantileOptions ReplaceSliceOptions diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index d1aded326d5c8..a33a09548d629 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -1969,6 +1969,25 @@ class CumulativeOptions(_CumulativeOptions): self._set_options(start, skip_nulls) +cdef class _PairwiseOptions(FunctionOptions): + def _set_options(self, period): + self.wrapped.reset(new CPairwiseOptions(period)) + + +class PairwiseOptions(_PairwiseOptions): + """ + Options for `pairwise` functions. + + Parameters + ---------- + period : int, default 1 + Period for applying the period function. + """ + + def __init__(self, period=1): + self._set_options(period) + + cdef class _ArraySortOptions(FunctionOptions): def _set_options(self, order, null_placement): self.wrapped.reset(new CArraySortOptions( diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 3d428758a497c..0fefa18dd1136 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -50,6 +50,7 @@ ModeOptions, NullOptions, PadOptions, + PairwiseOptions, PartitionNthOptions, QuantileOptions, RandomOptions, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 37a261c833431..da46cdcb750d5 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2407,6 +2407,11 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: optional[shared_ptr[CScalar]] start c_bool skip_nulls + cdef cppclass CPairwiseOptions \ + "arrow::compute::PairwiseOptions"(CFunctionOptions): + CPairwiseOptions(int64_t period) + int64_t period + cdef cppclass CArraySortOptions \ "arrow::compute::ArraySortOptions"(CFunctionOptions): CArraySortOptions(CSortOrder, CNullPlacement) diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 38bdeb126348b..d9209ada24a5c 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -155,6 +155,7 @@ def test_option_class_equality(): pc.ModeOptions(), pc.NullOptions(), pc.PadOptions(5), + pc.PairwiseOptions(period=1), pc.PartitionNthOptions(1, null_placement="at_start"), pc.CumulativeOptions(start=None, skip_nulls=False), pc.QuantileOptions(), @@ -3481,3 +3482,33 @@ def test_run_end_encode(): check_run_end_encode_decode(pc.RunEndEncodeOptions(pa.int16())) check_run_end_encode_decode(pc.RunEndEncodeOptions('int32')) check_run_end_encode_decode(pc.RunEndEncodeOptions(pa.int64())) + + +def test_pairwise_diff(): + arr = pa.array([1, 2, 3, None, 4, 5]) + expected = pa.array([None, 1, 1, None, None, 1]) + result = pa.compute.pairwise_diff(arr, period=1) + assert result.equals(expected) + + arr = pa.array([1, 2, 3, None, 4, 5]) + expected = pa.array([None, None, 2, None, 1, None]) + result = pa.compute.pairwise_diff(arr, period=2) + assert result.equals(expected) + + # negative period + arr = pa.array([1, 2, 3, None, 4, 5], type=pa.int8()) + expected = pa.array([-1, -1, None, None, -1, None], type=pa.int8()) + result = pa.compute.pairwise_diff(arr, period=-1) + assert result.equals(expected) + + # wrap around overflow + arr = pa.array([1, 2, 3, None, 4, 5], type=pa.uint8()) + expected = pa.array([255, 255, None, None, 255, None], type=pa.uint8()) + result = pa.compute.pairwise_diff(arr, period=-1) + assert result.equals(expected) + + # fail on overflow + arr = pa.array([1, 2, 3, None, 4, 5], type=pa.uint8()) + with pytest.raises(pa.ArrowInvalid, + match="overflow"): + pa.compute.pairwise_diff_checked(arr, period=-1)