feat: Snapshot stream overrides values on each request (backport k244) (#16559)

Co-authored-by: Dylan Guedes <djmgguedes@gmail.com>
loki-gh-app[bot] and DylanGuedes authored Mar 5, 2025
1 parent 5cb3479 commit e0fad70
Showing 11 changed files with 272 additions and 68 deletions.
2 changes: 1 addition & 1 deletion clients/pkg/promtail/targets/lokipush/pushtarget.go
@@ -111,7 +111,7 @@ func (t *PushTarget) run() error {
func (t *PushTarget) handleLoki(w http.ResponseWriter, r *http.Request) {
logger := util_log.WithContext(r.Context(), util_log.Logger)
userID, _ := tenant.TenantID(r.Context())
-req, err := push.ParseRequest(logger, userID, r, nil, push.EmptyLimits{}, push.ParseLokiRequest, nil, nil, false)
+req, err := push.ParseRequest(logger, userID, r, push.EmptyLimits{}, push.ParseLokiRequest, nil, nil, false)
if err != nil {
level.Warn(t.logger).Log("msg", "failed to parse incoming push request", "err", err.Error())
http.Error(w, err.Error(), http.StatusBadRequest)
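The nil argument dropped from ParseRequest above was the old push.TenantsRetention parameter; retention and policy lookups now arrive through a single resolver. Inferred from this call site and the methods requestScopedStreamResolver implements later in this commit, the reworked push.StreamResolver interface presumably looks like the sketch below; the authoritative definition lives in pkg/loghttp/push and may differ:

package push

import (
	"time"

	"github.com/prometheus/prometheus/model/labels"
)

// StreamResolver resolves retention and policy decisions for a stream's
// label set. Sketch only: inferred from requestScopedStreamResolver's
// method set in this commit, not copied from the real package.
type StreamResolver interface {
	RetentionPeriodFor(lbs labels.Labels) time.Duration
	RetentionHoursFor(lbs labels.Labels) string
	PolicyFor(lbs labels.Labels) string
}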
35 changes: 29 additions & 6 deletions pkg/compactor/retention/expiration.go
@@ -134,19 +134,40 @@ func NewTenantsRetention(l Limits) *TenantsRetention {
}

func (tr *TenantsRetention) RetentionHoursFor(userID string, lbs labels.Labels) string {
-period := tr.RetentionPeriodFor(userID, lbs)
-return util.RetentionHours(period)
+return NewTenantRetentionSnapshot(tr.limits, userID).RetentionHoursFor(lbs)
}

func (tr *TenantsRetention) RetentionPeriodFor(userID string, lbs labels.Labels) time.Duration {
-streamRetentions := tr.limits.StreamRetention(userID)
-globalRetention := tr.limits.RetentionPeriod(userID)
+return NewTenantRetentionSnapshot(tr.limits, userID).RetentionPeriodFor(lbs)
}

+// TenantRetentionSnapshot is a snapshot of retention rules for a tenant.
+// The underlying retention rules may change on the original limits object passed to
+// NewTenantRetentionSnapshot, but the snapshot is immutable.
+type TenantRetentionSnapshot struct {
+streamRetentions []validation.StreamRetention
+globalRetention time.Duration
+}
+
+func NewTenantRetentionSnapshot(limits Limits, userID string) *TenantRetentionSnapshot {
+return &TenantRetentionSnapshot{
+streamRetentions: limits.StreamRetention(userID),
+globalRetention: limits.RetentionPeriod(userID),
+}
+}
+
+func (r *TenantRetentionSnapshot) RetentionHoursFor(lbs labels.Labels) string {
+period := r.RetentionPeriodFor(lbs)
+return util.RetentionHours(period)
+}
+
+func (r *TenantRetentionSnapshot) RetentionPeriodFor(lbs labels.Labels) time.Duration {
var (
matchedRule validation.StreamRetention
found bool
)
Outer:
-for _, streamRetention := range streamRetentions {
+for _, streamRetention := range r.streamRetentions {
for _, m := range streamRetention.Matchers {
if !m.Matches(lbs.Get(m.Name)) {
continue Outer
@@ -166,10 +187,12 @@ Outer:
found = true
matchedRule = streamRetention
}

if found {
return time.Duration(matchedRule.Period)
}
-return globalRetention
+
+return r.globalRetention
}

type latestRetentionStartTime struct {
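The comment on TenantRetentionSnapshot above is the heart of this change: retention rules are copied out of the limits object once, so a runtime-overrides reload cannot change answers halfway through a request. A minimal, self-contained sketch of that snapshot pattern (liveLimits is a hypothetical stand-in for Loki's reloadable limits, not the real Limits interface):

package main

import (
	"fmt"
	"time"
)

// liveLimits stands in for Loki's reloadable runtime overrides.
type liveLimits struct {
	retention time.Duration
}

// retentionSnapshot copies the value at construction time, mirroring what
// NewTenantRetentionSnapshot does; later reloads cannot affect it.
type retentionSnapshot struct {
	retention time.Duration
}

func newRetentionSnapshot(l *liveLimits) *retentionSnapshot {
	return &retentionSnapshot{retention: l.retention}
}

func main() {
	live := &liveLimits{retention: 24 * time.Hour}
	snap := newRetentionSnapshot(live)

	live.retention = 36 * time.Hour // a config reload lands mid-request

	fmt.Println(live.retention) // 36h0m0s
	fmt.Println(snap.retention) // still 24h0m0s for the in-flight request
}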
67 changes: 47 additions & 20 deletions pkg/distributor/distributor.go
@@ -180,7 +180,6 @@ type Distributor struct {
streamShardCount prometheus.Counter
tenantPushSanitizedStructuredMetadata *prometheus.CounterVec

-policyResolver push.PolicyResolver
usageTracker push.UsageTracker
ingesterTasks chan pushIngesterTask
ingesterTaskWg sync.WaitGroup
@@ -224,11 +223,6 @@ func New(
return client.New(internalCfg, addr)
}

-policyResolver := push.PolicyResolver(func(userID string, lbs labels.Labels) string {
-mappings := overrides.PoliciesStreamMapping(userID)
-return getPolicy(userID, lbs, mappings, logger)
-})
-
validator, err := NewValidator(overrides, usageTracker)
if err != nil {
return nil, err
@@ -286,7 +280,6 @@ func New(
healthyInstancesCount: atomic.NewUint32(0),
rateLimitStrat: rateLimitStrat,
tee: tee,
-policyResolver: policyResolver,
usageTracker: usageTracker,
ingesterTasks: make(chan pushIngesterTask),
ingesterAppends: promauto.With(registerer).NewCounterVec(prometheus.CounterOpts{
@@ -460,9 +453,17 @@ func (p *pushTracker) doneWithResult(err error) {
}
}

+func (d *Distributor) Push(ctx context.Context, req *logproto.PushRequest) (*logproto.PushResponse, error) {
+tenantID, err := tenant.TenantID(ctx)
+if err != nil {
+return nil, err
+}
+return d.PushWithResolver(ctx, req, newRequestScopedStreamResolver(tenantID, d.validator.Limits, d.logger))
+}
+
// Push a set of streams.
// The returned error is the last one seen.
-func (d *Distributor) Push(ctx context.Context, req *logproto.PushRequest) (*logproto.PushResponse, error) {
+func (d *Distributor) PushWithResolver(ctx context.Context, req *logproto.PushRequest, streamResolver *requestScopedStreamResolver) (*logproto.PushResponse, error) {
tenantID, err := tenant.TenantID(ctx)
if err != nil {
return nil, err
@@ -538,7 +539,7 @@ func (d *Distributor) Push(ctx context.Context, req *logproto.PushRequest) (*log

var lbs labels.Labels
var retentionHours, policy string
-lbs, stream.Labels, stream.Hash, retentionHours, policy, err = d.parseStreamLabels(validationContext, stream.Labels, stream)
+lbs, stream.Labels, stream.Hash, retentionHours, policy, err = d.parseStreamLabels(validationContext, stream.Labels, stream, streamResolver)
if err != nil {
d.writeFailuresManager.Log(tenantID, err)
validationErrors.Add(err)
@@ -661,7 +662,7 @@ func (d *Distributor) Push(ctx context.Context, req *logproto.PushRequest) (*log
}

if !d.ingestionRateLimiter.AllowN(now, tenantID, validationContext.validationMetrics.aggregatedPushStats.lineSize) {
-d.trackDiscardedData(ctx, req, validationContext, tenantID, validationContext.validationMetrics, validation.RateLimited)
+d.trackDiscardedData(ctx, req, validationContext, tenantID, validationContext.validationMetrics, validation.RateLimited, streamResolver)

err = fmt.Errorf(validation.RateLimitedErrorMsg, tenantID, int(d.ingestionRateLimiter.Limit(now, tenantID)), validationContext.validationMetrics.aggregatedPushStats.lineCount, validationContext.validationMetrics.aggregatedPushStats.lineSize)
d.writeFailuresManager.Log(tenantID, err)
@@ -810,6 +811,7 @@ func (d *Distributor) trackDiscardedData(
tenantID string,
validationMetrics validationMetrics,
reason string,
+streamResolver push.StreamResolver,
) {
for policy, retentionToStats := range validationMetrics.policyPushStats {
for retentionHours, stats := range retentionToStats {
@@ -820,7 +822,7 @@

if d.usageTracker != nil {
for _, stream := range req.Streams {
-lbs, _, _, _, _, err := d.parseStreamLabels(validationContext, stream.Labels, stream)
+lbs, _, _, _, _, err := d.parseStreamLabels(validationContext, stream.Labels, stream, streamResolver)
if err != nil {
continue
}
@@ -1199,11 +1201,11 @@ type labelData struct {
hash uint64
}

-func (d *Distributor) parseStreamLabels(vContext validationContext, key string, stream logproto.Stream) (labels.Labels, string, uint64, string, string, error) {
-mapping := d.validator.Limits.PoliciesStreamMapping(vContext.userID)
+// parseStreamLabels parses stream labels using a request-scoped policy resolver
+func (d *Distributor) parseStreamLabels(vContext validationContext, key string, stream logproto.Stream, streamResolver push.StreamResolver) (labels.Labels, string, uint64, string, string, error) {
if val, ok := d.labelCache.Get(key); ok {
-retentionHours := d.tenantsRetention.RetentionHoursFor(vContext.userID, val.ls)
-policy := getPolicy(vContext.userID, val.ls, mapping, d.logger)
+retentionHours := streamResolver.RetentionHoursFor(val.ls)
+policy := streamResolver.PolicyFor(val.ls)
return val.ls, val.ls.String(), val.hash, retentionHours, policy, nil
}

@@ -1214,7 +1216,7 @@ func (d *Distributor) parseStreamLabels(vContext validationContext, key string,
return nil, "", 0, retentionHours, "", fmt.Errorf(validation.InvalidLabelsErrorMsg, key, err)
}

-policy := getPolicy(vContext.userID, ls, mapping, d.logger)
+policy := streamResolver.PolicyFor(ls)
retentionHours := d.tenantsRetention.RetentionHoursFor(vContext.userID, ls)

if err := d.validator.ValidateLabels(vContext, ls, stream, retentionHours, policy); err != nil {
@@ -1311,16 +1313,41 @@ func (d *Distributor) HealthyInstancesCount() int {
return int(d.healthyInstancesCount.Load())
}

-func getPolicy(userID string, lbs labels.Labels, mapping validation.PolicyStreamMapping, logger log.Logger) string {
-policies := mapping.PolicyFor(lbs)
+type requestScopedStreamResolver struct {
+userID string
+policyStreamMappings validation.PolicyStreamMapping
+retention *retention.TenantRetentionSnapshot
+
+logger log.Logger
+}
+
+func newRequestScopedStreamResolver(userID string, overrides Limits, logger log.Logger) *requestScopedStreamResolver {
+return &requestScopedStreamResolver{
+userID: userID,
+policyStreamMappings: overrides.PoliciesStreamMapping(userID),
+retention: retention.NewTenantRetentionSnapshot(overrides, userID),
+logger: logger,
+}
+}
+
+func (r requestScopedStreamResolver) RetentionPeriodFor(lbs labels.Labels) time.Duration {
+return r.retention.RetentionPeriodFor(lbs)
+}
+
+func (r requestScopedStreamResolver) RetentionHoursFor(lbs labels.Labels) string {
+return r.retention.RetentionHoursFor(lbs)
+}
+
+func (r requestScopedStreamResolver) PolicyFor(lbs labels.Labels) string {
+policies := r.policyStreamMappings.PolicyFor(lbs)

var policy string
if len(policies) > 0 {
policy = policies[0]
if len(policies) > 1 {
-level.Warn(logger).Log(
+level.Warn(r.logger).Log(
"msg", "multiple policies matched for the same stream",
-"org_id", userID,
+"org_id", r.userID,
"stream", lbs.String(),
"policy", policy,
"policies", strings.Join(policies, ","),
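Taken together, these hunks retire the long-lived policyResolver closure in favor of a resolver constructed once per request: gRPC callers keep the old Push signature, while the HTTP handler builds the resolver itself so parsing and pushing share one snapshot. A stripped-down sketch of that delegation pattern (the types here are simplified stand-ins, not the real Distributor internals):

package main

import (
	"context"
	"fmt"
)

// resolver plays the role of requestScopedStreamResolver: built once per
// request, then read-only.
type resolver struct{ tenant string }

type distributor struct{}

// Push keeps its old signature for gRPC callers and derives its own resolver.
func (d *distributor) Push(ctx context.Context, tenant string) error {
	return d.PushWithResolver(ctx, &resolver{tenant: tenant})
}

// PushWithResolver lets callers that already hold a resolver (the HTTP
// handler) reuse it, so every stream in the request sees the same snapshot.
func (d *distributor) PushWithResolver(_ context.Context, r *resolver) error {
	fmt.Println("pushing for tenant", r.tenant)
	return nil
}

func main() {
	_ = (&distributor{}).Push(context.Background(), "123")
}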
112 changes: 109 additions & 3 deletions pkg/distributor/distributor_test.go
@@ -1259,11 +1259,12 @@ func Benchmark_SortLabelsOnPush(b *testing.B) {
distributors, _ := prepare(&testing.T{}, 1, 5, limits, nil)
d := distributors[0]
request := makeWriteRequest(10, 10)
+streamResolver := newRequestScopedStreamResolver("123", d.validator.Limits, nil)
vCtx := d.validator.getValidationContextForTime(testTime, "123")
for n := 0; n < b.N; n++ {
stream := request.Streams[0]
stream.Labels = `{buzz="f", a="b"}`
-_, _, _, _, _, err := d.parseStreamLabels(vCtx, stream.Labels, stream)
+_, _, _, _, _, err := d.parseStreamLabels(vCtx, stream.Labels, stream, streamResolver)
if err != nil {
panic("parseStreamLabels fail,err:" + err.Error())
}
@@ -1307,11 +1308,11 @@ func TestParseStreamLabels(t *testing.T) {
d := distributors[0]

vCtx := d.validator.getValidationContextForTime(testTime, "123")

+streamResolver := newRequestScopedStreamResolver("123", d.validator.Limits, nil)
t.Run(tc.name, func(t *testing.T) {
lbs, lbsString, hash, _, _, err := d.parseStreamLabels(vCtx, tc.origLabels, logproto.Stream{
Labels: tc.origLabels,
-})
+}, streamResolver)
if tc.expectedErr != nil {
require.Equal(t, tc.expectedErr, err)
return
@@ -2257,3 +2258,108 @@ func BenchmarkDistributor_PushWithPolicies(b *testing.B) {
})
}
}

func TestRequestScopedStreamResolver(t *testing.T) {
limits := &validation.Limits{}
flagext.DefaultValues(limits)

limits.RetentionPeriod = model.Duration(24 * time.Hour)
limits.StreamRetention = []validation.StreamRetention{
{
Period: model.Duration(48 * time.Hour),
Selector: `{env="prod"}`,
},
}
limits.PolicyStreamMapping = validation.PolicyStreamMapping{
"policy0": []*validation.PriorityStream{
{
Selector: `{env="prod"}`,
},
},
}

// Load matchers
require.NoError(t, limits.Validate())

overrides, err := validation.NewOverrides(*limits, nil)
require.NoError(t, err)

resolver := newRequestScopedStreamResolver("123", overrides, nil)

retentionHours := resolver.RetentionHoursFor(labels.FromStrings("env", "prod"))
require.Equal(t, "48", retentionHours)
retentionPeriod := resolver.RetentionPeriodFor(labels.FromStrings("env", "prod"))
require.Equal(t, 48*time.Hour, retentionPeriod)

retentionHours = resolver.RetentionHoursFor(labels.FromStrings("env", "dev"))
require.Equal(t, "24", retentionHours)
retentionPeriod = resolver.RetentionPeriodFor(labels.FromStrings("env", "dev"))
require.Equal(t, 24*time.Hour, retentionPeriod)

policy := resolver.PolicyFor(labels.FromStrings("env", "prod"))
require.Equal(t, "policy0", policy)

policy = resolver.PolicyFor(labels.FromStrings("env", "dev"))
require.Empty(t, policy)

// We now modify the underlying limits to test that the resolver is not affected by changes to the limits
limits.RetentionPeriod = model.Duration(36 * time.Hour)
limits.StreamRetention = []validation.StreamRetention{
{
Period: model.Duration(72 * time.Hour),
Selector: `{env="dev"}`,
},
}
limits.PolicyStreamMapping = validation.PolicyStreamMapping{
"policy1": []*validation.PriorityStream{
{
Selector: `{env="dev"}`,
},
},
}

// Load matchers
require.NoError(t, limits.Validate())

newOverrides, err := validation.NewOverrides(*limits, nil)
require.NoError(t, err)

// overwrite the overrides we passed to the resolver by the new ones
*overrides = *newOverrides

// All should be the same as before
retentionHours = resolver.RetentionHoursFor(labels.FromStrings("env", "prod"))
require.Equal(t, "48", retentionHours)
retentionPeriod = resolver.RetentionPeriodFor(labels.FromStrings("env", "prod"))
require.Equal(t, 48*time.Hour, retentionPeriod)

retentionHours = resolver.RetentionHoursFor(labels.FromStrings("env", "dev"))
require.Equal(t, "24", retentionHours)
retentionPeriod = resolver.RetentionPeriodFor(labels.FromStrings("env", "dev"))
require.Equal(t, 24*time.Hour, retentionPeriod)

policy = resolver.PolicyFor(labels.FromStrings("env", "prod"))
require.Equal(t, "policy0", policy)

policy = resolver.PolicyFor(labels.FromStrings("env", "dev"))
require.Empty(t, policy)

// But a new resolver should return the new values
newResolver := newRequestScopedStreamResolver("123", overrides, nil)

retentionHours = newResolver.RetentionHoursFor(labels.FromStrings("env", "prod"))
require.Equal(t, "36", retentionHours)
retentionPeriod = newResolver.RetentionPeriodFor(labels.FromStrings("env", "prod"))
require.Equal(t, 36*time.Hour, retentionPeriod)

retentionHours = newResolver.RetentionHoursFor(labels.FromStrings("env", "dev"))
require.Equal(t, "72", retentionHours)
retentionPeriod = newResolver.RetentionPeriodFor(labels.FromStrings("env", "dev"))
require.Equal(t, 72*time.Hour, retentionPeriod)

policy = newResolver.PolicyFor(labels.FromStrings("env", "prod"))
require.Empty(t, policy)

policy = newResolver.PolicyFor(labels.FromStrings("env", "dev"))
require.Equal(t, "policy1", policy)
}
8 changes: 6 additions & 2 deletions pkg/distributor/http.go
@@ -40,8 +40,12 @@ func (d *Distributor) pushHandler(w http.ResponseWriter, r *http.Request, pushRe
pushRequestParser = d.RequestParserWrapper(pushRequestParser)
}

+// Create a request-scoped policy and retention resolver that will ensure consistent policy and retention resolution
+// across all parsers for this HTTP request.
+streamResolver := newRequestScopedStreamResolver(tenantID, d.validator.Limits, logger)

logPushRequestStreams := d.tenantConfigs.LogPushRequestStreams(tenantID)
-req, err := push.ParseRequest(logger, tenantID, r, d.tenantsRetention, d.validator.Limits, pushRequestParser, d.usageTracker, d.policyResolver, logPushRequestStreams)
+req, err := push.ParseRequest(logger, tenantID, r, d.validator.Limits, pushRequestParser, d.usageTracker, streamResolver, logPushRequestStreams)
if err != nil {
if !errors.Is(err, push.ErrAllLogsFiltered) {
if d.tenantConfigs.LogPushRequest(tenantID) {
@@ -77,7 +81,7 @@
)
}

-_, err = d.Push(r.Context(), req)
+_, err = d.PushWithResolver(r.Context(), req, streamResolver)
if err == nil {
if d.tenantConfigs.LogPushRequest(tenantID) {
level.Debug(logger).Log(
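The comment added to pushHandler above states the invariant this commit is after: one resolver per HTTP request, shared by push.ParseRequest and PushWithResolver. A small sketch of why that matters (hypothetical types; liveOverrides stands in for the reloadable tenant limits):

package main

import "fmt"

// liveOverrides stands in for the reloadable tenant limits.
type liveOverrides struct{ retentionHours string }

// requestResolver snapshots the value once per request, mirroring
// newRequestScopedStreamResolver.
type requestResolver struct{ retentionHours string }

func newRequestResolver(o *liveOverrides) *requestResolver {
	return &requestResolver{retentionHours: o.retentionHours}
}

func main() {
	overrides := &liveOverrides{retentionHours: "24"}
	r := newRequestResolver(overrides) // built once in the handler

	parsed := r.retentionHours      // what the parse phase would see
	overrides.retentionHours = "48" // an overrides reload mid-request
	pushed := r.retentionHours      // what the push phase would see

	fmt.Println(parsed == pushed) // true: both phases agree by construction
}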
3 changes: 1 addition & 2 deletions pkg/distributor/http_test.go
@@ -125,10 +125,9 @@ func newFakeParser() *fakeParser {
func (p *fakeParser) parseRequest(
_ string,
_ *http.Request,
-_ push.TenantsRetention,
_ push.Limits,
_ push.UsageTracker,
-_ push.PolicyResolver,
+_ push.StreamResolver,
_ bool,
_ log.Logger,
) (*logproto.PushRequest, *push.Stats, error) {
(The remaining 5 changed files are not shown here.)
