Skip to content

Commit

Permalink
util: Support ValidateFields in validator middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Volosatovs committed Mar 6, 2019
1 parent 4453468 commit 97f87aa
Showing 1 changed file with 58 additions and 58 deletions.
116 changes: 58 additions & 58 deletions pkg/rpcmiddleware/validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,6 @@ func RegisterAllowedFieldMaskPaths(rpcFullMethod string, allowedPaths ...string)
}
}

func getAllowedFieldMaskPaths(rpcFullMethod string) map[string]struct{} {
return allowedFieldMaskPaths[rpcFullMethod]
}

type fieldMaskGetter interface {
GetFieldMask() types.FieldMask
}

var errForbiddenFieldMaskPaths = errors.DefineInvalidArgument("field_mask_paths", "forbidden path(s) in field mask", "forbidden_paths")

func forbiddenPaths(requestedPaths []string, allowedPaths map[string]struct{}) (invalidPaths []string) {
Expand All @@ -60,25 +52,53 @@ nextRequestedPath:
return
}

type validatorWithContext interface {
ValidateContext(ctx context.Context) error
}

type validator interface {
Validate() error
}

func convertError(err error) error {
if ttnErr, ok := errors.From(err); ok {
return ttnErr
}
return grpc.Errorf(codes.InvalidArgument, err.Error())
}

func validateMessage(ctx context.Context, fullMethod string, msg interface{}) error {
if v, ok := msg.(interface {
GetFieldMask() types.FieldMask
}); ok {
if forbiddenPaths := forbiddenPaths(v.GetFieldMask().Paths, allowedFieldMaskPaths[fullMethod]); len(forbiddenPaths) > 0 {
return errForbiddenFieldMaskPaths.WithAttributes("forbidden_paths", forbiddenPaths)
}
}

switch v := msg.(type) {
case interface {
ValidateContext(context.Context) error
}:
if err := v.ValidateContext(ctx); err != nil {
return convertError(err)
}

case interface {
Validate() error
}:
if err := v.Validate(); err != nil {
return convertError(err)
}

case interface {
ValidateFields(...string) error
}:
if err := v.ValidateFields(); err != nil {
return convertError(err)
}

}
return nil
}

// UnaryServerInterceptor returns a new unary server interceptor that validates
// incoming messages if those incoming messages implement:
// (A) ValidateContext(ctx context.Context) error
// (B) Validate() error
// (C) ValidateFields(...string) error
// If a message implements both, then (A) should call (B).
//
// Invalid messages will be rejected with the error returned from the validator,
Expand All @@ -89,28 +109,33 @@ func convertError(err error) error {
// then the field mask paths are validated according to the registered list.
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if v, ok := req.(fieldMaskGetter); ok {
if forbiddenPaths := forbiddenPaths(v.GetFieldMask().Paths, getAllowedFieldMaskPaths(info.FullMethod)); len(forbiddenPaths) > 0 {
return nil, errForbiddenFieldMaskPaths.WithAttributes("forbidden_paths", forbiddenPaths)
}
}
if v, ok := req.(validatorWithContext); ok {
if err := v.ValidateContext(ctx); err != nil {
return nil, convertError(err)
}
} else if v, ok := req.(validator); ok {
if err := v.Validate(); err != nil {
return nil, convertError(err)
}
if err := validateMessage(ctx, info.FullMethod, req); err != nil {
return nil, err
}
return handler(ctx, req)
}
}

type recvWrapper struct {
grpc.ServerStream
fullMethod string
}

func (s *recvWrapper) RecvMsg(msg interface{}) error {
if err := s.ServerStream.RecvMsg(msg); err != nil {
return err
}
if err := validateMessage(s.Context(), s.fullMethod, msg); err != nil {
return err
}
return nil
}

// StreamServerInterceptor returns a new streaming server interceptor that validates
// incoming messages if those incoming messages implement:
// (A) ValidateContext(ctx context.Context) error
// (B) Validate() error
// (C) ValidateFields(...string) error
// If a message implements both, then (A) should call (B).
//
// Invalid messages will be rejected with the error returned from the validator,
Expand All @@ -126,34 +151,9 @@ func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
// then the field mask paths are validated according to the registered list.
func StreamServerInterceptor() grpc.StreamServerInterceptor {
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
wrapper := &recvWrapper{ServerStream: stream, allowedFieldMaskPaths: getAllowedFieldMaskPaths(info.FullMethod)}
return handler(srv, wrapper)
}
}

type recvWrapper struct {
grpc.ServerStream
allowedFieldMaskPaths map[string]struct{}
}

func (s *recvWrapper) RecvMsg(m interface{}) error {
if err := s.ServerStream.RecvMsg(m); err != nil {
return err
}
if v, ok := m.(fieldMaskGetter); ok {
requested := v.GetFieldMask().Paths
if forbiddenPaths := forbiddenPaths(requested, s.allowedFieldMaskPaths); len(forbiddenPaths) > 0 {
return errForbiddenFieldMaskPaths.WithAttributes("forbidden_paths", forbiddenPaths)
}
return handler(srv, &recvWrapper{
ServerStream: stream,
fullMethod: info.FullMethod,
})
}
if v, ok := m.(validatorWithContext); ok {
if err := v.ValidateContext(s.Context()); err != nil {
return convertError(err)
}
} else if v, ok := m.(validator); ok {
if err := v.Validate(); err != nil {
return convertError(err)
}
}
return nil
}

0 comments on commit 97f87aa

Please sign in to comment.