Skip to content

Commit

Permalink
Oversampling mitigation (#354)
Browse files Browse the repository at this point in the history
Oversampling-mitigation
  • Loading branch information
atshaw43 authored Sep 30, 2022
1 parent 9487f19 commit 11c498b
Show file tree
Hide file tree
Showing 18 changed files with 495 additions and 27 deletions.
4 changes: 3 additions & 1 deletion aws-xray-recorder-sdk-aws-sdk/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ dependencies {
testImplementation("com.amazonaws:aws-java-sdk-lambda:1.12.228")
testImplementation("com.amazonaws:aws-java-sdk-s3:1.12.228")
testImplementation("com.amazonaws:aws-java-sdk-sns:1.12.228")
testImplementation("org.powermock:powermock-reflect:2.0.2")
testImplementation("org.skyscreamer:jsonassert:1.3.0")
testImplementation("org.powermock:powermock-module-junit4:2.0.7")
testImplementation("org.powermock:powermock-api-mockito2:2.0.7")
testImplementation("com.github.stefanbirkner:system-rules:1.16.0")
}

tasks.jar {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ public void beforeRequest(Request<?> request) {
if (null != entityContext) {
recorder.setTraceEntity(entityContext);
}

Optional<Subsegment> previousSubsegment = recorder.getCurrentSubsegmentOptional();
Subsegment currentSubsegment = recorder.beginSubsegment(serviceName);
currentSubsegment.putAllAws(extractRequestParameters(request));
currentSubsegment.putAws(EntityDataKeys.AWS.OPERATION_KEY, operationName);
Expand All @@ -192,10 +194,14 @@ public void beforeRequest(Request<?> request) {
currentSubsegment.setNamespace(Namespace.AWS.toString());

if (recorder.getCurrentSegment() != null && recorder.getCurrentSubsegment().shouldPropagate()) {
boolean isSampled = previousSubsegment.isPresent() ?
previousSubsegment.get().isSampled() :
recorder.getCurrentSegment().isSampled();

TraceHeader header =
new TraceHeader(recorder.getCurrentSegment().getTraceId(),
recorder.getCurrentSegment().isSampled() ? currentSubsegment.getId() : null,
recorder.getCurrentSegment().isSampled() ? SampleDecision.SAMPLED : SampleDecision.NOT_SAMPLED);
isSampled ? currentSubsegment.getId() : null,
isSampled ? SampleDecision.SAMPLED : SampleDecision.NOT_SAMPLED);
request.addHeader(TraceHeader.HEADER_KEY, header.toString());
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package com.amazonaws.xray.handlers;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;

import com.amazonaws.DefaultRequest;
import com.amazonaws.Response;
import com.amazonaws.http.HttpResponse;
import com.amazonaws.services.lambda.model.InvokeRequest;
import com.amazonaws.services.lambda.model.InvokeResult;
import com.amazonaws.xray.AWSXRayRecorder;
import com.amazonaws.xray.AWSXRayRecorderBuilder;
import com.amazonaws.xray.contexts.LambdaSegmentContext;
import com.amazonaws.xray.contexts.LambdaSegmentContextResolver;
import com.amazonaws.xray.emitters.DefaultEmitter;
import com.amazonaws.xray.emitters.Emitter;
import com.amazonaws.xray.entities.Subsegment;
import com.amazonaws.xray.entities.TraceHeader;
import com.amazonaws.xray.strategy.sampling.NoSamplingStrategy;
import org.junit.FixMethodOrder;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.MethodSorters;
import org.mockito.Mockito;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;

@RunWith(PowerMockRunner.class)
@FixMethodOrder(MethodSorters.JVM)
@PrepareForTest({LambdaSegmentContext.class, LambdaSegmentContextResolver.class})
@PowerMockIgnore("javax.net.ssl.*")
public class TracingHandlerLambdaTest {
private static final String TRACE_HEADER =
"Root=1-57ff426a-80c11c39b0c928905eb0828d;Parent=1234abcd1234abcd;Sampled=1";

@Test
public void testSamplingOverrideFalseInLambda() throws Exception {
TraceHeader header = TraceHeader.fromString(TRACE_HEADER);

PowerMockito.stub(PowerMockito.method(
LambdaSegmentContext.class, "getTraceHeaderFromEnvironment")).toReturn(header);
PowerMockito.stub(PowerMockito.method(
LambdaSegmentContextResolver.class, "getLambdaTaskRoot")).toReturn("/var/task");

Emitter mockedEmitted = Mockito.mock(DefaultEmitter.class);

AWSXRayRecorder recorder = AWSXRayRecorderBuilder.standard()
.withEmitter(mockedEmitted)
.build();

recorder.beginSubsegmentWithoutSampling("Test");

Subsegment subsegment = ((Subsegment) recorder.getTraceEntity());
assertThat(subsegment.shouldPropagate()).isTrue();

DefaultRequest<Void> request = new DefaultRequest<>(new InvokeRequest(), "Test");

TracingHandler tracingHandler = new TracingHandler(recorder);
tracingHandler.beforeRequest(request);
tracingHandler.afterResponse(request, new Response(new InvokeResult(), new HttpResponse(request, null)));

assertThat(TraceHeader.fromString(request.getHeaders().get(TraceHeader.HEADER_KEY)).getSampled())
.isEqualTo(TraceHeader.SampleDecision.NOT_SAMPLED);

recorder.endSubsegment();

Mockito.verify(mockedEmitted, Mockito.times(0)).sendSubsegment(any());
}

@Test
public void testSamplingOverrideTrueInLambda() {
Emitter mockedEmitted = Mockito.mock(DefaultEmitter.class);

AWSXRayRecorder recorder = AWSXRayRecorderBuilder.standard()
.withSamplingStrategy(new NoSamplingStrategy())
.withEmitter(mockedEmitted)
.build();

TraceHeader header = TraceHeader.fromString(TRACE_HEADER);

PowerMockito.stub(PowerMockito.method(
LambdaSegmentContext.class, "getTraceHeaderFromEnvironment")).toReturn(header);
PowerMockito.stub(PowerMockito.method(
LambdaSegmentContextResolver.class, "getLambdaTaskRoot")).toReturn("/var/task");

Mockito.doAnswer(invocation -> { return true; }).when(mockedEmitted).sendSubsegment(any());

recorder.beginSubsegment("test1");
Subsegment subsegment = ((Subsegment) recorder.getTraceEntity());
assertThat(subsegment.shouldPropagate()).isTrue();
DefaultRequest<Void> request = new DefaultRequest<>(new InvokeRequest(), "Test");
TracingHandler tracingHandler = new TracingHandler(recorder);
tracingHandler.beforeRequest(request);
assertThat(TraceHeader.fromString(request.getHeaders().get(TraceHeader.HEADER_KEY)).getSampled())
.isEqualTo(TraceHeader.SampleDecision.SAMPLED);
tracingHandler.afterResponse(request, new Response(new InvokeResult(), new HttpResponse(request, null)));
recorder.endSubsegment();
Mockito.verify(mockedEmitted, Mockito.times(1)).sendSubsegment(any());
}

@Test
public void testSamplingOverrideMixedInLambda() {
Emitter mockedEmitted = Mockito.mock(DefaultEmitter.class);

AWSXRayRecorder recorder = AWSXRayRecorderBuilder.standard()
.withSamplingStrategy(new NoSamplingStrategy())
.withEmitter(mockedEmitted)
.build();

TraceHeader header = TraceHeader.fromString(TRACE_HEADER);

PowerMockito.stub(PowerMockito.method(
LambdaSegmentContext.class, "getTraceHeaderFromEnvironment")).toReturn(header);
PowerMockito.stub(PowerMockito.method(
LambdaSegmentContextResolver.class, "getLambdaTaskRoot")).toReturn("/var/task");

Mockito.doAnswer(invocation -> { return true; }).when(mockedEmitted).sendSubsegment(any());

recorder.beginSubsegment("test1");
Subsegment subsegment1 = ((Subsegment) recorder.getTraceEntity());
assertThat(subsegment1.shouldPropagate()).isTrue();
DefaultRequest<Void> request1 = new DefaultRequest<>(new InvokeRequest(), "Test");
TracingHandler tracingHandler1 = new TracingHandler(recorder);
tracingHandler1.beforeRequest(request1);
assertThat(TraceHeader.fromString(request1.getHeaders().get(TraceHeader.HEADER_KEY)).getSampled())
.isEqualTo(TraceHeader.SampleDecision.SAMPLED);
tracingHandler1.afterResponse(request1, new Response(new InvokeResult(), new HttpResponse(request1, null)));
recorder.endSubsegment();
Mockito.verify(mockedEmitted, Mockito.times(1)).sendSubsegment(any());

recorder.beginSubsegmentWithoutSampling("test2");
Subsegment subsegment2 = ((Subsegment) recorder.getTraceEntity());
assertThat(subsegment2.shouldPropagate()).isTrue();
DefaultRequest<Void> request2 = new DefaultRequest<>(new InvokeRequest(), "Test");
TracingHandler tracingHandler2 = new TracingHandler(recorder);
tracingHandler2.beforeRequest(request2);
assertThat(TraceHeader.fromString(request2.getHeaders().get(TraceHeader.HEADER_KEY)).getSampled())
.isEqualTo(TraceHeader.SampleDecision.NOT_SAMPLED);
tracingHandler2.afterResponse(request2, new Response(new InvokeResult(), new HttpResponse(request2, null)));
recorder.endSubsegment();
Mockito.verify(mockedEmitted, Mockito.times(1)).sendSubsegment(any());

recorder.beginSubsegment("test3");
Subsegment subsegment3 = ((Subsegment) recorder.getTraceEntity());
assertThat(subsegment3.shouldPropagate()).isTrue();
DefaultRequest<Void> request3 = new DefaultRequest<>(new InvokeRequest(), "Test");
TracingHandler tracingHandler3 = new TracingHandler(recorder);
tracingHandler3.beforeRequest(request3);
assertThat(TraceHeader.fromString(request3.getHeaders().get(TraceHeader.HEADER_KEY)).getSampled())
.isEqualTo(TraceHeader.SampleDecision.SAMPLED);
tracingHandler3.afterResponse(request3, new Response(new InvokeResult(), new HttpResponse(request3, null)));
recorder.endSubsegment();
Mockito.verify(mockedEmitted, Mockito.times(2)).sendSubsegment(any());

recorder.beginSubsegmentWithoutSampling("test4");
Subsegment subsegment4 = ((Subsegment) recorder.getTraceEntity());
assertThat(subsegment4.shouldPropagate()).isTrue();
DefaultRequest<Void> request4 = new DefaultRequest<>(new InvokeRequest(), "Test");
TracingHandler tracingHandler4 = new TracingHandler(recorder);
tracingHandler4.beforeRequest(request4);
assertThat(TraceHeader.fromString(request4.getHeaders().get(TraceHeader.HEADER_KEY)).getSampled())
.isEqualTo(TraceHeader.SampleDecision.NOT_SAMPLED);
tracingHandler4.afterResponse(request4, new Response(new InvokeResult(), new HttpResponse(request4, null)));
recorder.endSubsegment();
Mockito.verify(mockedEmitted, Mockito.times(2)).sendSubsegment(any());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ public static Subsegment beginSubsegment(String name) {
return globalRecorder.beginSubsegment(name);
}

public static Subsegment beginSubsegmentWithoutSampling(String name) {
return globalRecorder.beginSubsegmentWithoutSampling(name);
}

public static void endSubsegment() {
globalRecorder.endSubsegment();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import com.amazonaws.xray.exceptions.SubsegmentNotFoundException;
import com.amazonaws.xray.internal.FastIdGenerator;
import com.amazonaws.xray.internal.IdGenerator;
import com.amazonaws.xray.internal.SamplingStrategyOverride;
import com.amazonaws.xray.internal.SecureIdGenerator;
import com.amazonaws.xray.listeners.SegmentListener;
import com.amazonaws.xray.strategy.ContextMissingStrategy;
Expand Down Expand Up @@ -616,6 +617,31 @@ public Subsegment beginSubsegment(String name) {
return context.beginSubsegment(this, name);
}

/**
* Begins a subsegment.
*
* @param name
* the name to use for the created subsegment
* @throws SegmentNotFoundException
* if {@code contextMissingStrategy} throws exceptions and no segment is currently in progress
* @return the newly created subsegment, or {@code null} if {@code contextMissingStrategy} suppresses and no segment is
* currently in progress. The subsegment will not be sampled regardless of the SamplingStrategy.
*/
public Subsegment beginSubsegmentWithoutSampling(String name) {
SegmentContext context = getSegmentContext();
if (context == null) {
// No context available, we return a no-op subsegment so user code does not have to work around this. Based on
// ContextMissingStrategy they will still know about the issue unless they explicitly opt-ed out.
// This no-op subsegment is different from unsampled no-op subsegments only in that it should not cause trace
// context to be propagated downstream
return Subsegment.noOp(this, false);
}
return context.beginSubsegmentWithSamplingOverride(
this,
name,
SamplingStrategyOverride.FALSE);
}

/**
* Ends a subsegment.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.amazonaws.xray.entities.TraceHeader.SampleDecision;
import com.amazonaws.xray.entities.TraceID;
import com.amazonaws.xray.exceptions.SubsegmentNotFoundException;
import com.amazonaws.xray.internal.SamplingStrategyOverride;
import com.amazonaws.xray.listeners.SegmentListener;
import java.util.List;
import java.util.Objects;
Expand Down Expand Up @@ -61,16 +62,25 @@ private static FacadeSegment newFacadeSegment(AWSXRayRecorder recorder, String n
}

@Override
public Subsegment beginSubsegment(AWSXRayRecorder recorder, String name) {
public Subsegment beginSubsegmentWithSamplingOverride(
AWSXRayRecorder recorder,
String name,
SamplingStrategyOverride samplingStrategyOverride) {

if (logger.isDebugEnabled()) {
logger.debug("Beginning subsegment named: " + name);
}

Entity entity = getTraceEntity();
if (entity == null) { // First subsgment of a subsegment branch.
Segment parentSegment = newFacadeSegment(recorder, name);
Subsegment subsegment = parentSegment.isRecording()
? new SubsegmentImpl(recorder, name, parentSegment)
: Subsegment.noOp(parentSegment, recorder);

boolean isRecording = parentSegment.isRecording() &&
samplingStrategyOverride == SamplingStrategyOverride.DISABLED;

Subsegment subsegment = isRecording
? new SubsegmentImpl(recorder, name, parentSegment, samplingStrategyOverride)
: Subsegment.noOp(parentSegment, recorder, samplingStrategyOverride);
subsegment.setParent(parentSegment);
// Enable FacadeSegment to keep track of its subsegments for subtree streaming
parentSegment.addSubsegment(subsegment);
Expand All @@ -81,14 +91,18 @@ public Subsegment beginSubsegment(AWSXRayRecorder recorder, String name) {
// Ensure customers have not leaked subsegments across invocations
TraceID environmentRootTraceId = LambdaSegmentContext.getTraceHeaderFromEnvironment().getRootTraceId();
if (environmentRootTraceId != null &&
!environmentRootTraceId.equals(parentSubsegment.getParentSegment().getTraceId())) {
!environmentRootTraceId.equals(parentSubsegment.getParentSegment().getTraceId())) {
clearTraceEntity();
return beginSubsegment(recorder, name);
}
Segment parentSegment = parentSubsegment.getParentSegment();
Subsegment subsegment = parentSegment.isRecording()
? new SubsegmentImpl(recorder, name, parentSegment)
: Subsegment.noOp(parentSegment, recorder);

boolean isRecording = parentSegment.isRecording() &&
samplingStrategyOverride == SamplingStrategyOverride.DISABLED;

Subsegment subsegment = isRecording
? new SubsegmentImpl(recorder, name, parentSegment, samplingStrategyOverride)
: Subsegment.noOp(parentSegment, recorder, samplingStrategyOverride);
subsegment.setParent(parentSubsegment);
parentSubsegment.addSubsegment(subsegment);
setTraceEntity(subsegment);
Expand All @@ -102,6 +116,11 @@ public Subsegment beginSubsegment(AWSXRayRecorder recorder, String name) {
}
}

@Override
public Subsegment beginSubsegment(AWSXRayRecorder recorder, String name) {
return beginSubsegmentWithSamplingOverride(recorder, name, SamplingStrategyOverride.DISABLED);
}

@Override
public void endSubsegment(AWSXRayRecorder recorder) {
Entity current = getTraceEntity();
Expand Down Expand Up @@ -134,14 +153,13 @@ public void endSubsegment(AWSXRayRecorder recorder) {

Entity parentEntity = current.getParent();
if (parentEntity instanceof FacadeSegment) {
if (((FacadeSegment) parentEntity).isSampled()) {
if (((Subsegment) current).isSampled()) {
current.getCreator().getEmitter().sendSubsegment((Subsegment) current);
}
clearTraceEntity();
} else {
setTraceEntity(current.getParent());
}

} else {
recorder.getContextMissingStrategy().contextMissing("Failed to end subsegment: subsegment cannot be found.",
SubsegmentNotFoundException.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.amazonaws.xray.entities.Entity;
import com.amazonaws.xray.entities.Segment;
import com.amazonaws.xray.entities.Subsegment;
import com.amazonaws.xray.internal.SamplingStrategyOverride;
import java.util.Objects;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand Down Expand Up @@ -65,5 +66,10 @@ default void clearTraceEntity() {

Subsegment beginSubsegment(AWSXRayRecorder recorder, String name);

Subsegment beginSubsegmentWithSamplingOverride(
AWSXRayRecorder recorder,
String name,
SamplingStrategyOverride samplingStrategyOverride);

void endSubsegment(AWSXRayRecorder recorder);
}
Loading

0 comments on commit 11c498b

Please sign in to comment.