diff --git a/pkg/rpcmiddleware/validator/validator.go b/pkg/rpcmiddleware/validator/validator.go index 2c7b53c4aa..58e48a26e9 100644 --- a/pkg/rpcmiddleware/validator/validator.go +++ b/pkg/rpcmiddleware/validator/validator.go @@ -39,14 +39,6 @@ func RegisterAllowedFieldMaskPaths(rpcFullMethod string, allowedPaths ...string) } } -func getAllowedFieldMaskPaths(rpcFullMethod string) map[string]struct{} { - return allowedFieldMaskPaths[rpcFullMethod] -} - -type fieldMaskGetter interface { - GetFieldMask() types.FieldMask -} - var errForbiddenFieldMaskPaths = errors.DefineInvalidArgument("field_mask_paths", "forbidden path(s) in field mask", "forbidden_paths") func forbiddenPaths(requestedPaths []string, allowedPaths map[string]struct{}) (invalidPaths []string) { @@ -60,14 +52,6 @@ nextRequestedPath: return } -type validatorWithContext interface { - ValidateContext(ctx context.Context) error -} - -type validator interface { - Validate() error -} - func convertError(err error) error { if ttnErr, ok := errors.From(err); ok { return ttnErr @@ -75,10 +59,46 @@ func convertError(err error) error { return grpc.Errorf(codes.InvalidArgument, err.Error()) } +func validateMessage(ctx context.Context, fullMethod string, msg interface{}) error { + if v, ok := msg.(interface { + GetFieldMask() types.FieldMask + }); ok { + if forbiddenPaths := forbiddenPaths(v.GetFieldMask().Paths, allowedFieldMaskPaths[fullMethod]); len(forbiddenPaths) > 0 { + return errForbiddenFieldMaskPaths.WithAttributes("forbidden_paths", forbiddenPaths) + } + } + + switch v := msg.(type) { + case interface { + ValidateContext(context.Context) error + }: + if err := v.ValidateContext(ctx); err != nil { + return convertError(err) + } + + case interface { + Validate() error + }: + if err := v.Validate(); err != nil { + return convertError(err) + } + + case interface { + ValidateFields(...string) error + }: + if err := v.ValidateFields(); err != nil { + return convertError(err) + } + + } + return nil +} + // UnaryServerInterceptor returns a new unary server interceptor that validates // incoming messages if those incoming messages implement: // (A) ValidateContext(ctx context.Context) error // (B) Validate() error +// (C) ValidateFields(...string) error // If a message implements both, then (A) should call (B). // // Invalid messages will be rejected with the error returned from the validator, @@ -89,28 +109,33 @@ func convertError(err error) error { // then the field mask paths are validated according to the registered list. func UnaryServerInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - if v, ok := req.(fieldMaskGetter); ok { - if forbiddenPaths := forbiddenPaths(v.GetFieldMask().Paths, getAllowedFieldMaskPaths(info.FullMethod)); len(forbiddenPaths) > 0 { - return nil, errForbiddenFieldMaskPaths.WithAttributes("forbidden_paths", forbiddenPaths) - } - } - if v, ok := req.(validatorWithContext); ok { - if err := v.ValidateContext(ctx); err != nil { - return nil, convertError(err) - } - } else if v, ok := req.(validator); ok { - if err := v.Validate(); err != nil { - return nil, convertError(err) - } + if err := validateMessage(ctx, info.FullMethod, req); err != nil { + return nil, err } return handler(ctx, req) } } +type recvWrapper struct { + grpc.ServerStream + fullMethod string +} + +func (s *recvWrapper) RecvMsg(msg interface{}) error { + if err := s.ServerStream.RecvMsg(msg); err != nil { + return err + } + if err := validateMessage(s.Context(), s.fullMethod, msg); err != nil { + return err + } + return nil +} + // StreamServerInterceptor returns a new streaming server interceptor that validates // incoming messages if those incoming messages implement: // (A) ValidateContext(ctx context.Context) error // (B) Validate() error +// (C) ValidateFields(...string) error // If a message implements both, then (A) should call (B). // // Invalid messages will be rejected with the error returned from the validator, @@ -126,34 +151,9 @@ func UnaryServerInterceptor() grpc.UnaryServerInterceptor { // then the field mask paths are validated according to the registered list. func StreamServerInterceptor() grpc.StreamServerInterceptor { return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - wrapper := &recvWrapper{ServerStream: stream, allowedFieldMaskPaths: getAllowedFieldMaskPaths(info.FullMethod)} - return handler(srv, wrapper) - } -} - -type recvWrapper struct { - grpc.ServerStream - allowedFieldMaskPaths map[string]struct{} -} - -func (s *recvWrapper) RecvMsg(m interface{}) error { - if err := s.ServerStream.RecvMsg(m); err != nil { - return err - } - if v, ok := m.(fieldMaskGetter); ok { - requested := v.GetFieldMask().Paths - if forbiddenPaths := forbiddenPaths(requested, s.allowedFieldMaskPaths); len(forbiddenPaths) > 0 { - return errForbiddenFieldMaskPaths.WithAttributes("forbidden_paths", forbiddenPaths) - } + return handler(srv, &recvWrapper{ + ServerStream: stream, + fullMethod: info.FullMethod, + }) } - if v, ok := m.(validatorWithContext); ok { - if err := v.ValidateContext(s.Context()); err != nil { - return convertError(err) - } - } else if v, ok := m.(validator); ok { - if err := v.Validate(); err != nil { - return convertError(err) - } - } - return nil }