From cdedecaff1cca5feb048aacbc3b5f4af04d7c271 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Fri, 8 Mar 2024 13:20:15 +0000 Subject: [PATCH] feat: add non-native hint with native output --- std/math/emulated/field_hint.go | 57 +++++++++++++++++++++++++++++---- 1 file changed, 50 insertions(+), 7 deletions(-) diff --git a/std/math/emulated/field_hint.go b/std/math/emulated/field_hint.go index 52b8e9ddfc..c264136b79 100644 --- a/std/math/emulated/field_hint.go +++ b/std/math/emulated/field_hint.go @@ -23,6 +23,17 @@ func (f *Field[T]) wrapHint(nonnativeInputs ...*Element[T]) []frontend.Variable // nonnativeHint function with nonnative inputs. After nonnativeHint returns, it // decomposes the outputs into limbs. func UnwrapHint(nativeInputs, nativeOutputs []*big.Int, nonnativeHint solver.Hint) error { + return unwrapHint(true, true, nativeInputs, nativeOutputs, nonnativeHint) +} + +// UnwrapHintWithNativeOutput unwraps the native inputs into nonnative inputs. Then +// it calls nonnativeHint function with nonnative inputs. After nonnativeHint +// returns, it returns native outputs as-is. +func UnwrapHintWithNativeOutput(nativeInputs, nativeOutputs []*big.Int, nonnativeHint solver.Hint) error { + return unwrapHint(true, false, nativeInputs, nativeOutputs, nonnativeHint) +} + +func unwrapHint(isEmulatedInput, isEmulatedOutput bool, nativeInputs, nativeOutputs []*big.Int, nonnativeHint solver.Hint) error { if len(nativeInputs) < 2 { return fmt.Errorf("hint wrapper header is 2 elements") } @@ -61,20 +72,32 @@ func UnwrapHint(nativeInputs, nativeOutputs []*big.Int, nonnativeHint solver.Hin } readPtr += 1 + currentInputLen } - if len(nativeOutputs)%nbLimbs != 0 { - return fmt.Errorf("output count doesn't divide limb count") + + var nonnativeOutputs []*big.Int + if isEmulatedOutput { + if len(nativeOutputs)%nbLimbs != 0 { + return fmt.Errorf("output count doesn't divide limb count") + } + nonnativeOutputs = make([]*big.Int, len(nativeOutputs)/nbLimbs) + } else { + nonnativeOutputs = make([]*big.Int, len(nativeOutputs)) } - nonnativeOutputs := make([]*big.Int, len(nativeOutputs)/nbLimbs) for i := range nonnativeOutputs { nonnativeOutputs[i] = new(big.Int) } if err := nonnativeHint(nonnativeMod, nonnativeInputs, nonnativeOutputs); err != nil { return fmt.Errorf("nonnative hint: %w", err) } - for i := range nonnativeOutputs { - nonnativeOutputs[i].Mod(nonnativeOutputs[i], nonnativeMod) - if err := decompose(nonnativeOutputs[i], uint(nbBits), nativeOutputs[i*nbLimbs:(i+1)*nbLimbs]); err != nil { - return fmt.Errorf("decompose %d-th element: %w", i, err) + if isEmulatedOutput { + for i := range nonnativeOutputs { + nonnativeOutputs[i].Mod(nonnativeOutputs[i], nonnativeMod) + if err := decompose(nonnativeOutputs[i], uint(nbBits), nativeOutputs[i*nbLimbs:(i+1)*nbLimbs]); err != nil { + return fmt.Errorf("decompose %d-th element: %w", i, err) + } + } + } else { + for i := range nonnativeOutputs { + nativeOutputs[i].Set(nonnativeOutputs[i]) } } return nil @@ -107,3 +130,23 @@ func (f *Field[T]) NewHint(hf solver.Hint, nbOutputs int, inputs ...*Element[T]) } return outputs, nil } + +// NewHintWithNativeOutput allows to call the emulation hint function hf on +// nonnative inputs, expecting nbOutputs results. This function splits +// internally the emulated element into limbs and passes these to the hint +// function. There is [UnwrapHint] function which performs corresponding +// recomposition of limbs into integer values (and vice verse for output). +// +// This method is an alternation of [NewHint] method, which allows to pass +// nonnative inputs to the hint function and returns native outputs. This is +// useful when the outputs do not necessarily have to be emulated elements (e.g. +// bits) as it skips enforcing range checks on the outputs. +func (f *Field[T]) NewHintWithNativeOutput(hf solver.Hint, nbOutputs int, inputs ...*Element[T]) ([]frontend.Variable, error) { + nativeInputs := f.wrapHint(inputs...) + nbNativeOutputs := nbOutputs + nativeOutputs, err := f.api.Compiler().NewHint(hf, nbNativeOutputs, nativeInputs...) + if err != nil { + return nil, fmt.Errorf("call hint: %w", err) + } + return nativeOutputs, nil +}