Skip to content

Commit

Permalink
Don't call the hook if it's already been canceled. (prebid#138)
Browse files Browse the repository at this point in the history
* Don't call the hook if it's already been canceled.

* Don't use label+goto

* Adjust comments.
  • Loading branch information
scr-oath authored and GitHub Enterprise committed Sep 4, 2024
1 parent 4ef5c8f commit baf9e51
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 36 deletions.
2 changes: 1 addition & 1 deletion endpoints/openrtb2/amp_auction.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (deps *endpointDeps) AmpAuction(w http.ResponseWriter, r *http.Request, _ h
// to compute the auction timeout.
start := time.Now()

hookExecutor := hookexecution.NewHookExecutor(deps.hookExecutionPlanBuilder, hookexecution.EndpointAmp, deps.metricsEngine)
hookExecutor := hookexecution.NewHookExecutor(deps.hookExecutionPlanBuilder, hookexecution.EndpointAmp, deps.metricsEngine, hookexecution.WithRequestContext(r))

ao := analytics.AmpObject{
Status: http.StatusOK,
Expand Down
2 changes: 1 addition & 1 deletion endpoints/openrtb2/auction.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func (deps *endpointDeps) Auction(w http.ResponseWriter, r *http.Request, _ http
// to compute the auction timeout.
start := time.Now()

hookExecutor := hookexecution.NewHookExecutor(deps.hookExecutionPlanBuilder, hookexecution.EndpointAuction, deps.metricsEngine)
hookExecutor := hookexecution.NewHookExecutor(deps.hookExecutionPlanBuilder, hookexecution.EndpointAuction, deps.metricsEngine, hookexecution.WithRequestContext(r))

ao := analytics.AuctionObject{
Status: http.StatusOK,
Expand Down
5 changes: 3 additions & 2 deletions endpoints/openrtb2/video_auction.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"github.com/prebid/openrtb/v20/openrtb3"
"io"
"net/http"
"net/url"
Expand All @@ -14,6 +13,8 @@ import (
"strings"
"time"

"github.com/prebid/openrtb/v20/openrtb3"

"github.com/buger/jsonparser"
"github.com/gofrs/uuid"
"github.com/golang/glog"
Expand Down Expand Up @@ -126,7 +127,7 @@ func NewVideoEndpoint(
func (deps *endpointDeps) VideoAuctionEndpoint(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
start := time.Now()

hookExecutor := hookexecution.NewHookExecutor(deps.hookExecutionPlanBuilder, hookexecution.EndpointVideo, deps.metricsEngine)
hookExecutor := hookexecution.NewHookExecutor(deps.hookExecutionPlanBuilder, hookexecution.EndpointVideo, deps.metricsEngine, hookexecution.WithRequestContext(r))
vo := analytics.VideoObject{
Status: http.StatusOK,
Errors: make([]error, 0),
Expand Down
70 changes: 46 additions & 24 deletions hooks/hookexecution/execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,33 +107,55 @@ func executeHook[H any, P any](
ctx, cancel := context.WithTimeout(executionCtx.ctx, timeout)
defer cancel()

go func() {
result, err := hookHandler(ctx, moduleCtx, hw.Hook, payload)
hookRespCh <- hookResponse[P]{
Result: result,
Err: err,
// Only execute the hook if it's not already canceled
if ctx.Err() == nil {
// Execute the hook in the background
go func() {
defer func() {
if r := recover(); r != nil {
var err error
var ok bool
if err, ok = r.(error); !ok {
err = fmt.Errorf("panic during hook execution: %v", r)
}
hookRespCh <- hookResponse[P]{
Err: err,
}
}
}()

result, err := hookHandler(ctx, moduleCtx, hw.Hook, payload)
hookRespCh <- hookResponse[P]{
Result: result,
Err: err,
}
}()

// Figure out what the hook did and return if success or rejected
select {
case res := <-hookRespCh:
res.HookID = hookId
res.ExecutionTime = time.Since(startTime)
resp <- res
return
case <-ctx.Done():
// fall through to the error handler
case <-rejected:
return
}
}()
}

select {
case res := <-hookRespCh:
res.HookID = hookId
res.ExecutionTime = time.Since(startTime)
resp <- res
case <-ctx.Done():
theResp := hookResponse[P]{
Err: ctx.Err(),
ExecutionTime: time.Since(startTime),
HookID: hookId,
Result: hookstage.HookResult[P]{},
}
if errors.Is(theResp.Err, context.DeadlineExceeded) {
theResp.Err = TimeoutError{}
}
resp <- theResp
case <-rejected:
return
// Handle the context error case - either immediately, or after timeout.
theResp := hookResponse[P]{
Err: ctx.Err(),
ExecutionTime: time.Since(startTime),
HookID: hookId,
Result: hookstage.HookResult[P]{},
}
if errors.Is(theResp.Err, context.DeadlineExceeded) {
theResp.Err = TimeoutError{}
}
resp <- theResp
}

func collectHookResponses[P any](resp <-chan hookResponse[P], rejected chan<- struct{}) []hookResponse[P] {
Expand Down
67 changes: 64 additions & 3 deletions hooks/hookexecution/execution_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package hookexecution

import (
"context"
"testing"
"time"

"github.com/prebid/openrtb/v20/openrtb2"
"github.com/prebid/prebid-server/v2/config"
"github.com/prebid/prebid-server/v2/hooks"
"github.com/prebid/prebid-server/v2/hooks/hookstage"
"github.com/prebid/prebid-server/v2/openrtb_ext"
"github.com/prebid/prebid-server/v2/privacy"
Expand Down Expand Up @@ -90,7 +93,7 @@ func TestHandleModuleActivitiesBidderRequestPayload(t *testing.T) {
}
for _, test := range testCases {
t.Run(test.description, func(t *testing.T) {
//check input payload didn't change
// check input payload didn't change
origInPayloadData := test.inPayloadData
activityControl := privacy.NewActivityControl(test.privacyConfig)
newPayload := handleModuleActivities(test.hookCode, activityControl, test.inPayloadData, nil)
Expand Down Expand Up @@ -173,7 +176,7 @@ func TestHandleModuleActivitiesProcessedAuctionRequestPayload(t *testing.T) {
}
for _, test := range testCases {
t.Run(test.description, func(t *testing.T) {
//check input payload didn't change
// check input payload didn't change
origInPayloadData := test.inPayloadData
activityControl := privacy.NewActivityControl(test.privacyConfig)
account := &config.Account{Privacy: config.AccountPrivacy{IPv6Config: config.IPv6{AnonKeepBits: testIPv6ScrubBytes}}}
Expand Down Expand Up @@ -224,7 +227,7 @@ func TestHandleModuleActivitiesNoBidderRequestPayload(t *testing.T) {
}
for _, test := range testCases {
t.Run(test.description, func(t *testing.T) {
//check input payload didn't change
// check input payload didn't change
origInPayloadData := test.inPayloadData
activityControl := privacy.NewActivityControl(test.privacyConfig)
newPayload := handleModuleActivities(test.hookCode, activityControl, test.inPayloadData, &config.Account{})
Expand All @@ -233,3 +236,61 @@ func TestHandleModuleActivitiesNoBidderRequestPayload(t *testing.T) {
})
}
}

func TestExecuteHookContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
executionCtx := executionContext{
ctx: ctx,
}
var moduleCtx hookstage.ModuleInvocationContext
var hw hooks.HookWrapper[hookstage.Entrypoint]
var payload hookstage.EntrypointPayload
var hh hookHandler[hookstage.Entrypoint, hookstage.EntrypointPayload]
respCh := make(chan hookResponse[hookstage.EntrypointPayload], 1)
rejectedCh := make(chan struct{}, 1)
defer close(rejectedCh)
defer close(respCh)
executeHook(executionCtx, moduleCtx, hw, payload, hh, time.Minute, respCh, rejectedCh)
resp := <-respCh
assert.ErrorIs(t, resp.Err, context.Canceled)
}

func TestExecuteHookTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
defer cancel()
time.Sleep(2 * time.Nanosecond)
executionCtx := executionContext{
ctx: ctx,
}
var moduleCtx hookstage.ModuleInvocationContext
var hw hooks.HookWrapper[hookstage.Entrypoint]
var payload hookstage.EntrypointPayload
var hh hookHandler[hookstage.Entrypoint, hookstage.EntrypointPayload]
respCh := make(chan hookResponse[hookstage.EntrypointPayload], 1)
rejectedCh := make(chan struct{}, 1)
defer close(rejectedCh)
defer close(respCh)
executeHook(executionCtx, moduleCtx, hw, payload, hh, time.Minute, respCh, rejectedCh)
resp := <-respCh
assert.ErrorIs(t, resp.Err, TimeoutError{})
}

func TestExecuteHookPanic(t *testing.T) {
executionCtx := executionContext{
ctx: context.Background(),
}
var moduleCtx hookstage.ModuleInvocationContext
var hw hooks.HookWrapper[hookstage.Entrypoint]
var payload hookstage.EntrypointPayload
var hh hookHandler[hookstage.Entrypoint, hookstage.EntrypointPayload] = func(ctx context.Context, invocationContext hookstage.ModuleInvocationContext, entrypoint hookstage.Entrypoint, payload hookstage.EntrypointPayload) (hookstage.HookResult[hookstage.EntrypointPayload], error) {
panic("foobar")
}
respCh := make(chan hookResponse[hookstage.EntrypointPayload], 1)
rejectedCh := make(chan struct{}, 1)
defer close(rejectedCh)
defer close(respCh)
executeHook(executionCtx, moduleCtx, hw, payload, hh, time.Minute, respCh, rejectedCh)
resp := <-respCh
assert.EqualError(t, resp.Err, "panic during hook execution: foobar")
}
26 changes: 21 additions & 5 deletions hooks/hookexecution/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,34 @@ type hookExecutor struct {
ctx context.Context
}

func NewHookExecutor(builder hooks.ExecutionPlanBuilder, endpoint string, me metrics.MetricsEngine) *hookExecutor {
return &hookExecutor{
// HookExecutorOpt modifes a *hookExecutor
type HookExecutorOpt func(*hookExecutor)

// WithContext Sets the hookExecutor's context
func WithContext(ctx context.Context) HookExecutorOpt {
return func(executor *hookExecutor) {
executor.ctx = ctx
}
}

// WithRequestContext Sets the hookExecutor's context from the given request
func WithRequestContext(req *http.Request) HookExecutorOpt {
return WithContext(req.Context())
}

func NewHookExecutor(builder hooks.ExecutionPlanBuilder, endpoint string, me metrics.MetricsEngine, opts ...HookExecutorOpt) *hookExecutor {
ret := &hookExecutor{
endpoint: endpoint,
planBuilder: builder,
stageOutcomes: []StageOutcome{},
moduleContexts: &moduleContexts{ctxs: make(map[string]hookstage.ModuleContext)},
metricEngine: me,
ctx: context.Background(),
}
for _, opt := range opts {
opt(ret)
}
return ret
}

func (e *hookExecutor) SetAccount(account *config.Account) {
Expand All @@ -95,9 +114,6 @@ func (e *hookExecutor) GetOutcomes() []StageOutcome {
}

func (e *hookExecutor) ExecuteEntrypointStage(req *http.Request, body []byte) ([]byte, *RejectError) {
// Grab the request context to propagate to the hook callbacks.
e.ctx = req.Context()

plan := e.planBuilder.PlanForEntrypointStage(e.endpoint)
if len(plan) == 0 {
return body, nil
Expand Down

0 comments on commit baf9e51

Please sign in to comment.