Skip to content

Commit

Permalink
Merge pull request #8102 from heyitsanthony/txn-nested
Browse files Browse the repository at this point in the history
api: nested txns
  • Loading branch information
Anthony Romano authored Jun 22, 2017
2 parents 23c8167 + b10ea20 commit 9cb12de
Show file tree
Hide file tree
Showing 14 changed files with 1,033 additions and 507 deletions.
2 changes: 2 additions & 0 deletions Documentation/dev-guide/api_reference_v3.md
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,7 @@ Empty field.
| request_range | | RangeRequest |
| request_put | | PutRequest |
| request_delete_range | | DeleteRangeRequest |
| request_txn | | TxnRequest |



Expand All @@ -691,6 +692,7 @@ Empty field.
| response_range | | RangeResponse |
| response_put | | PutResponse |
| response_delete_range | | DeleteRangeResponse |
| response_txn | | TxnResponse |



Expand Down
6 changes: 6 additions & 0 deletions Documentation/dev-guide/apispec/swagger/rpc.swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -1954,6 +1954,9 @@
},
"request_range": {
"$ref": "#/definitions/etcdserverpbRangeRequest"
},
"request_txn": {
"$ref": "#/definitions/etcdserverpbTxnRequest"
}
}
},
Expand Down Expand Up @@ -1993,6 +1996,9 @@
},
"response_range": {
"$ref": "#/definitions/etcdserverpbRangeResponse"
},
"response_txn": {
"$ref": "#/definitions/etcdserverpbTxnResponse"
}
}
},
Expand Down
84 changes: 65 additions & 19 deletions clientv3/integration/txn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,27 +106,35 @@ func TestTxnReadRetry(t *testing.T) {
defer clus.Terminate(t)

kv := clus.Client(0)
clus.Members[0].Stop(t)
<-clus.Members[0].StopNotify()

donec := make(chan struct{})
go func() {
ctx := context.TODO()
_, err := kv.Txn(ctx).Then(clientv3.OpGet("foo")).Commit()
if err != nil {
t.Fatalf("expected response, got error %v", err)
thenOps := [][]clientv3.Op{
{clientv3.OpGet("foo")},
{clientv3.OpTxn(nil, []clientv3.Op{clientv3.OpGet("foo")}, nil)},
{clientv3.OpTxn(nil, nil, nil)},
{},
}
for i := range thenOps {
clus.Members[0].Stop(t)
<-clus.Members[0].StopNotify()

donec := make(chan struct{})
go func() {
_, err := kv.Txn(context.TODO()).Then(thenOps[i]...).Commit()
if err != nil {
t.Fatalf("expected response, got error %v", err)
}
donec <- struct{}{}
}()
// wait for txn to fail on disconnect
time.Sleep(100 * time.Millisecond)

// restart node; client should resume
clus.Members[0].Restart(t)
select {
case <-donec:
case <-time.After(2 * clus.Members[1].ServerConfig.ReqTimeout()):
t.Fatalf("waited too long")
}
donec <- struct{}{}
}()
// wait for txn to fail on disconnect
time.Sleep(100 * time.Millisecond)

// restart node; client should resume
clus.Members[0].Restart(t)
select {
case <-donec:
case <-time.After(2 * clus.Members[1].ServerConfig.ReqTimeout()):
t.Fatalf("waited too long")
}
}

Expand Down Expand Up @@ -179,3 +187,41 @@ func TestTxnCompareRange(t *testing.T) {
t.Fatal("expected prefix compare to false, got compares as true")
}
}

func TestTxnNested(t *testing.T) {
defer testutil.AfterTest(t)

clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 3})
defer clus.Terminate(t)

kv := clus.Client(0)

tresp, err := kv.Txn(context.TODO()).
If(clientv3.Compare(clientv3.Version("foo"), "=", 0)).
Then(
clientv3.OpPut("foo", "bar"),
clientv3.OpTxn(nil, []clientv3.Op{clientv3.OpPut("abc", "123")}, nil)).
Else(clientv3.OpPut("foo", "baz")).Commit()
if err != nil {
t.Fatal(err)
}
if len(tresp.Responses) != 2 {
t.Errorf("expected 2 top-level txn responses, got %+v", tresp.Responses)
}

// check txn writes were applied
resp, err := kv.Get(context.TODO(), "foo")
if err != nil {
t.Fatal(err)
}
if len(resp.Kvs) != 1 || string(resp.Kvs[0].Value) != "bar" {
t.Errorf("unexpected Get response %+v", resp)
}
resp, err = kv.Get(context.TODO(), "abc")
if err != nil {
t.Fatal(err)
}
if len(resp.Kvs) != 1 || string(resp.Kvs[0].Value) != "123" {
t.Errorf("unexpected Get response %+v", resp)
}
}
9 changes: 8 additions & 1 deletion clientv3/kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,13 @@ type OpResponse struct {
put *PutResponse
get *GetResponse
del *DeleteResponse
txn *TxnResponse
}

func (op OpResponse) Put() *PutResponse { return op.put }
func (op OpResponse) Get() *GetResponse { return op.get }
func (op OpResponse) Del() *DeleteResponse { return op.del }
func (op OpResponse) Txn() *TxnResponse { return op.txn }

type kv struct {
remote pb.KVClient
Expand Down Expand Up @@ -134,7 +136,6 @@ func (kv *kv) Do(ctx context.Context, op Op) (OpResponse, error) {
func (kv *kv) do(ctx context.Context, op Op) (OpResponse, error) {
var err error
switch op.t {
// TODO: handle other ops
case tRange:
var resp *pb.RangeResponse
resp, err = kv.remote.Range(ctx, op.toRangeRequest(), grpc.FailFast(false))
Expand All @@ -155,6 +156,12 @@ func (kv *kv) do(ctx context.Context, op Op) (OpResponse, error) {
if err == nil {
return OpResponse{del: (*DeleteResponse)(resp)}, nil
}
case tTxn:
var resp *pb.TxnResponse
resp, err = kv.remote.Txn(ctx, op.toTxnRequest())
if err == nil {
return OpResponse{txn: (*TxnResponse)(resp)}, nil
}
default:
panic("Unknown op")
}
Expand Down
64 changes: 39 additions & 25 deletions clientv3/namespace/kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (kv *kvPrefix) Delete(ctx context.Context, key string, opts ...clientv3.OpO
}

func (kv *kvPrefix) Do(ctx context.Context, op clientv3.Op) (clientv3.OpResponse, error) {
if len(op.KeyBytes()) == 0 {
if len(op.KeyBytes()) == 0 && !op.IsTxn() {
return clientv3.OpResponse{}, rpctypes.ErrEmptyKey
}
r, err := kv.KV.Do(ctx, kv.prefixOp(op))
Expand All @@ -88,6 +88,8 @@ func (kv *kvPrefix) Do(ctx context.Context, op clientv3.Op) (clientv3.OpResponse
kv.unprefixPutResponse(r.Put())
case r.Del() != nil:
kv.unprefixDeleteResponse(r.Del())
case r.Txn() != nil:
kv.unprefixTxnResponse(r.Txn())
}
return r, nil
}
Expand All @@ -102,34 +104,17 @@ func (kv *kvPrefix) Txn(ctx context.Context) clientv3.Txn {
}

func (txn *txnPrefix) If(cs ...clientv3.Cmp) clientv3.Txn {
newCmps := make([]clientv3.Cmp, len(cs))
for i := range cs {
newCmps[i] = cs[i]
pfxKey, endKey := txn.kv.prefixInterval(cs[i].KeyBytes(), cs[i].RangeEnd)
newCmps[i].WithKeyBytes(pfxKey)
if len(cs[i].RangeEnd) != 0 {
newCmps[i].RangeEnd = endKey
}
}
txn.Txn = txn.Txn.If(newCmps...)
txn.Txn = txn.Txn.If(txn.kv.prefixCmps(cs)...)
return txn
}

func (txn *txnPrefix) Then(ops ...clientv3.Op) clientv3.Txn {
newOps := make([]clientv3.Op, len(ops))
for i := range ops {
newOps[i] = txn.kv.prefixOp(ops[i])
}
txn.Txn = txn.Txn.Then(newOps...)
txn.Txn = txn.Txn.Then(txn.kv.prefixOps(ops)...)
return txn
}

func (txn *txnPrefix) Else(ops ...clientv3.Op) clientv3.Txn {
newOps := make([]clientv3.Op, len(ops))
for i := range ops {
newOps[i] = txn.kv.prefixOp(ops[i])
}
txn.Txn = txn.Txn.Else(newOps...)
txn.Txn = txn.Txn.Else(txn.kv.prefixOps(ops)...)
return txn
}

Expand All @@ -143,10 +128,14 @@ func (txn *txnPrefix) Commit() (*clientv3.TxnResponse, error) {
}

func (kv *kvPrefix) prefixOp(op clientv3.Op) clientv3.Op {
begin, end := kv.prefixInterval(op.KeyBytes(), op.RangeBytes())
op.WithKeyBytes(begin)
op.WithRangeBytes(end)
return op
if !op.IsTxn() {
begin, end := kv.prefixInterval(op.KeyBytes(), op.RangeBytes())
op.WithKeyBytes(begin)
op.WithRangeBytes(end)
return op
}
cmps, thenOps, elseOps := op.Txn()
return clientv3.OpTxn(kv.prefixCmps(cmps), kv.prefixOps(thenOps), kv.prefixOps(elseOps))
}

func (kv *kvPrefix) unprefixGetResponse(resp *clientv3.GetResponse) {
Expand Down Expand Up @@ -182,6 +171,10 @@ func (kv *kvPrefix) unprefixTxnResponse(resp *clientv3.TxnResponse) {
if tv.ResponseDeleteRange != nil {
kv.unprefixDeleteResponse((*clientv3.DeleteResponse)(tv.ResponseDeleteRange))
}
case *pb.ResponseOp_ResponseTxn:
if tv.ResponseTxn != nil {
kv.unprefixTxnResponse((*clientv3.TxnResponse)(tv.ResponseTxn))
}
default:
}
}
Expand All @@ -190,3 +183,24 @@ func (kv *kvPrefix) unprefixTxnResponse(resp *clientv3.TxnResponse) {
func (p *kvPrefix) prefixInterval(key, end []byte) (pfxKey []byte, pfxEnd []byte) {
return prefixInterval(p.pfx, key, end)
}

func (kv *kvPrefix) prefixCmps(cs []clientv3.Cmp) []clientv3.Cmp {
newCmps := make([]clientv3.Cmp, len(cs))
for i := range cs {
newCmps[i] = cs[i]
pfxKey, endKey := kv.prefixInterval(cs[i].KeyBytes(), cs[i].RangeEnd)
newCmps[i].WithKeyBytes(pfxKey)
if len(cs[i].RangeEnd) != 0 {
newCmps[i].RangeEnd = endKey
}
}
return newCmps
}

func (kv *kvPrefix) prefixOps(ops []clientv3.Op) []clientv3.Op {
newOps := make([]clientv3.Op, len(ops))
for i := range ops {
newOps[i] = kv.prefixOp(ops[i])
}
return newOps
}
44 changes: 44 additions & 0 deletions clientv3/op.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const (
tRange opType = iota + 1
tPut
tDeleteRange
tTxn
)

var (
Expand Down Expand Up @@ -67,10 +68,18 @@ type Op struct {
// for put
val []byte
leaseID LeaseID

// txn
cmps []Cmp
thenOps []Op
elseOps []Op
}

// accesors / mutators

func (op Op) IsTxn() bool { return op.t == tTxn }
func (op Op) Txn() ([]Cmp, []Op, []Op) { return op.cmps, op.thenOps, op.elseOps }

// KeyBytes returns the byte slice holding the Op's key.
func (op Op) KeyBytes() []byte { return op.key }

Expand Down Expand Up @@ -113,6 +122,22 @@ func (op Op) toRangeRequest() *pb.RangeRequest {
return r
}

func (op Op) toTxnRequest() *pb.TxnRequest {
thenOps := make([]*pb.RequestOp, len(op.thenOps))
for i, tOp := range op.thenOps {
thenOps[i] = tOp.toRequestOp()
}
elseOps := make([]*pb.RequestOp, len(op.elseOps))
for i, eOp := range op.elseOps {
elseOps[i] = eOp.toRequestOp()
}
cmps := make([]*pb.Compare, len(op.cmps))
for i := range op.cmps {
cmps[i] = (*pb.Compare)(&op.cmps[i])
}
return &pb.TxnRequest{Compare: cmps, Success: thenOps, Failure: elseOps}
}

func (op Op) toRequestOp() *pb.RequestOp {
switch op.t {
case tRange:
Expand All @@ -123,12 +148,27 @@ func (op Op) toRequestOp() *pb.RequestOp {
case tDeleteRange:
r := &pb.DeleteRangeRequest{Key: op.key, RangeEnd: op.end, PrevKv: op.prevKV}
return &pb.RequestOp{Request: &pb.RequestOp_RequestDeleteRange{RequestDeleteRange: r}}
case tTxn:
return &pb.RequestOp{Request: &pb.RequestOp_RequestTxn{RequestTxn: op.toTxnRequest()}}
default:
panic("Unknown Op")
}
}

func (op Op) isWrite() bool {
if op.t == tTxn {
for _, tOp := range op.thenOps {
if tOp.isWrite() {
return true
}
}
for _, tOp := range op.elseOps {
if tOp.isWrite() {
return true
}
}
return false
}
return op.t != tRange
}

Expand Down Expand Up @@ -194,6 +234,10 @@ func OpPut(key, val string, opts ...OpOption) Op {
return ret
}

func OpTxn(cmps []Cmp, thenOps []Op, elseOps []Op) Op {
return Op{t: tTxn, cmps: cmps, thenOps: thenOps, elseOps: elseOps}
}

func opWatch(key string, opts ...OpOption) Op {
ret := Op{t: tRange, key: []byte(key)}
ret.applyOpts(opts)
Expand Down
Loading

0 comments on commit 9cb12de

Please sign in to comment.