Skip to content

Commit a679500

Browse files
committed
Add UnmarshalMsgWithState to Unmarshaler interface
1 parent ab5758d commit a679500

File tree

4 files changed

+47
-5
lines changed

4 files changed

+47
-5
lines changed

gen/unmarshal.go

+17-2
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,17 @@ func (u *unmarshalGen) Execute(p Elem) ([]string, error) {
6060
u.p.printf("\n return ((*(%s))(%s)).UnmarshalMsg(bts)", baseType, c)
6161
u.p.printf("\n}")
6262

63+
u.p.printf("\nfunc (%s %s) UnmarshalMsgWithState(bts []byte, st msgp.UnmarshalState) ([]byte, error) {", c, methodRecv)
64+
u.p.printf("\n return ((*(%s))(%s)).UnmarshalMsgWithState(bts, st)", baseType, c)
65+
u.p.printf("\n}")
66+
6367
u.p.printf("\nfunc (_ %[2]s) CanUnmarshalMsg(%[1]s interface{}) bool {", c, methodRecv)
6468
u.p.printf("\n _, ok := (%s).(%s)", c, methodRecv)
6569
u.p.printf("\n return ok")
6670
u.p.printf("\n}")
6771

6872
u.topics.Add(methodRecv, "UnmarshalMsg")
73+
u.topics.Add(methodRecv, "UnmarshalMsgWithState")
6974
u.topics.Add(methodRecv, "CanUnmarshalMsg")
7075

7176
return u.msgs, u.p.err
@@ -75,7 +80,12 @@ func (u *unmarshalGen) Execute(p Elem) ([]string, error) {
7580
c := p.Varname()
7681
methodRecv := methodReceiver(p)
7782

78-
u.p.printf("\nfunc (%s %s) UnmarshalMsg(bts []byte) (o []byte, err error) {", c, methodRecv)
83+
u.p.printf("\nfunc (%s %s) UnmarshalMsgWithState(bts []byte, st msgp.UnmarshalState) (o []byte, err error) {", c, methodRecv)
84+
u.p.printf("\n if st.Depth == 0 {")
85+
u.p.printf("\n err = msgp.ErrMaxDepthExceeded{}")
86+
u.p.printf("\n return")
87+
u.p.printf("\n }")
88+
u.p.printf("\n st.Depth--")
7989
next(u, p)
8090
u.p.print("\no = bts")
8191

@@ -91,12 +101,17 @@ func (u *unmarshalGen) Execute(p Elem) ([]string, error) {
91101
}
92102
u.p.nakedReturn()
93103

104+
u.p.printf("\nfunc (%s %s) UnmarshalMsg(bts []byte) (o []byte, err error) {", c, methodRecv)
105+
u.p.printf("\n return %s.UnmarshalMsgWithState(bts, msgp.DefaultUnmarshalState)", c)
106+
u.p.printf("\n}")
107+
94108
u.p.printf("\nfunc (_ %[2]s) CanUnmarshalMsg(%[1]s interface{}) bool {", c, methodRecv)
95109
u.p.printf("\n _, ok := (%s).(%s)", c, methodRecv)
96110
u.p.printf("\n return ok")
97111
u.p.printf("\n}")
98112

99113
u.topics.Add(methodRecv, "UnmarshalMsg")
114+
u.topics.Add(methodRecv, "UnmarshalMsgWithState")
100115
u.topics.Add(methodRecv, "CanUnmarshalMsg")
101116

102117
return u.msgs, u.p.err
@@ -236,7 +251,7 @@ func (u *unmarshalGen) gBase(b *BaseElem) {
236251
case Ext:
237252
u.p.printf("\nbts, err = msgp.ReadExtensionBytes(bts, %s)", lowered)
238253
case IDENT:
239-
u.p.printf("\nbts, err = %s.UnmarshalMsg(bts)", lowered)
254+
u.p.printf("\nbts, err = %s.UnmarshalMsgWithState(bts, st)", lowered)
240255
case String:
241256
if b.common.AllocBound() != "" {
242257
sz := randIdent()

msgp/errors.go

+10-2
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,9 @@ func Resumable(e error) bool {
8181
//
8282
// ErrShortBytes is not wrapped with any context due to backward compatibility
8383
// issues with the public API.
84-
//
8584
func WrapError(err error, ctx ...interface{}) error {
8685
switch e := err.(type) {
87-
case errShort:
86+
case errShort, ErrMaxDepthExceeded:
8887
return e
8988
case contextError:
9089
return e.withContext(ctxString(ctx))
@@ -344,3 +343,12 @@ func (e *ErrUnsupportedType) withContext(ctx string) error {
344343
o.ctx = addCtx(o.ctx, ctx)
345344
return &o
346345
}
346+
347+
// ErrMaxDepthExceeded is returned if the maximum traversal depth is exceeded.
348+
type ErrMaxDepthExceeded struct{}
349+
350+
// Error implements error
351+
func (e ErrMaxDepthExceeded) Error() string { return "Max depth exceeded" }
352+
353+
// Resumable implements Error
354+
func (e ErrMaxDepthExceeded) Resumable() bool { return false }

msgp/read.go

+9
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,14 @@ func (t Type) String() string {
7878
// field in a struct rather than unmarshaling the entire struct.
7979
type Unmarshaler interface {
8080
UnmarshalMsg([]byte) ([]byte, error)
81+
UnmarshalMsgWithState([]byte, UnmarshalState) ([]byte, error)
8182
CanUnmarshalMsg(o interface{}) bool
8283
}
84+
85+
// UnmarshalState holds state while running UnmarshalMsg.
86+
type UnmarshalState struct {
87+
AllowableDepth uint64
88+
}
89+
90+
// DefaultUnmarshalState defines the default state.
91+
var DefaultUnmarshalState = UnmarshalState{AllowableDepth: 10000}

msgp/read_bytes.go

+11-1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,16 @@ func (*Raw) CanUnmarshalMsg(z interface{}) bool {
8787
// It sets the contents of *Raw to be the next
8888
// object in the provided byte slice.
8989
func (r *Raw) UnmarshalMsg(b []byte) ([]byte, error) {
90+
return r.UnmarshalMsgWithState(b, DefaultUnmarshalState)
91+
}
92+
93+
// UnmarshalMsg implements msgp.Unmarshaler.
94+
// It sets the contents of *Raw to be the next
95+
// object in the provided byte slice.
96+
func (r *Raw) UnmarshalMsgWithState(b []byte, st UnmarshalState) ([]byte, error) {
97+
if st.AllowableDepth == 0 {
98+
return nil, ErrMaxDepthExceeded{}
99+
}
90100
l := len(b)
91101
out, err := Skip(b)
92102
if err != nil {
@@ -1185,7 +1195,7 @@ func ReadStringBytes(b []byte) (string, []byte, error) {
11851195
// into a slice of bytes. 'v' is the value of
11861196
// the 'str' object, which may reside in memory
11871197
// pointed to by 'scratch.' 'o' is the remaining bytes
1188-
// in 'b.''
1198+
// in 'b'.
11891199
// Possible errors:
11901200
// - ErrShortBytes (b not long enough)
11911201
// - TypeError{} (not 'str' type)

0 commit comments

Comments
 (0)