Skip to content

Commit

Permalink
Refactored Request/Reply and removed reflect dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
alinz committed May 18, 2024
1 parent b30e446 commit dbee65e
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 40 deletions.
5 changes: 3 additions & 2 deletions examples/request-reply-single/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,12 @@ func main() {
return &Resp{Result: req.A / req.B}, nil
})

fn := bus.Request[*Req, *Resp](c, funcName)
fn := bus.Request(c, funcName)

for range 1000 {
req := &Req{A: 4, B: 2}
resp, err := fn(ctx, req)
resp := &Resp{}
err := fn(ctx, req, resp)
if err != nil {
fmt.Printf("%s = %s\n", req, err)
} else {
Expand Down
5 changes: 3 additions & 2 deletions examples/request-reply/request/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ func main() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

fn := bus.Request[*data.Req, *data.Resp](c, "func.div")
fn := bus.Request(c, "func.div")

for range 1000 {
req := &data.Req{A: 4, B: 2}
resp, err := fn(ctx, req)
resp := &data.Resp{}
err := fn(ctx, req, resp)
if err != nil {
fmt.Printf("%s = %s\n", req, err)
} else {
Expand Down
49 changes: 18 additions & 31 deletions request_reply.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,27 @@ import (
"encoding/json"
"fmt"
"log/slog"
"reflect"
"time"
)

type RequestReplyFunc[Req, Resp any] func(context.Context, Req) (Resp, error)
type RequestFunc func(ctx context.Context, req any, resp any) error
type ReplyFunc[Req, Resp any] func(ctx context.Context, req Req) (Resp, error)

func Request[Req, Resp any](stream Stream, subject string) RequestReplyFunc[Req, Resp] {
return func(ctx context.Context, req Req) (resp Resp, err error) {
func Request(stream Stream, subject string) RequestFunc {
return func(ctx context.Context, req any, resp any) (err error) {
evt, err := NewEvent(
WithSubject(subject),
WithReply(),
WithJsonData(req),
WithExpiresAt(30*time.Second),
)
if err != nil {
return resp, err
return err
}

err = stream.Put(ctx, evt)
if err != nil {
return resp, err
return err
}

for msgs, err := range stream.Get(
Expand All @@ -34,11 +34,11 @@ func Request[Req, Resp any](stream Stream, subject string) RequestReplyFunc[Req,
WithFromOldest(),
) {
if err != nil {
return resp, err
return err
}

if len(msgs.Events) != 1 {
return resp, fmt.Errorf("expected one event")
return fmt.Errorf("expected one event but got %d", len(msgs.Events))
}

evt := msgs.Events[0]
Expand All @@ -50,31 +50,31 @@ func Request[Req, Resp any](stream Stream, subject string) RequestReplyFunc[Req,

err = json.Unmarshal(evt.Data, &replyMsg)
if err != nil {
return resp, err
return err
}

if replyMsg.Type == "error" {
var errMsg string
err = json.Unmarshal(replyMsg.Payload, &errMsg)
if err != nil {
return resp, err
return err
}
return resp, fmt.Errorf(errMsg)
return fmt.Errorf(errMsg)
}

resp, err = jsonUnmarshal[Resp](replyMsg.Payload)
err = json.Unmarshal(replyMsg.Payload, resp)
if err != nil {
return resp, err
return err
}

return resp, nil
return nil
}

return
}
}

func Reply[Req, Resp any](ctx context.Context, stream Stream, subject string, fn RequestReplyFunc[Req, Resp]) {
func Reply[Req, Resp any](ctx context.Context, stream Stream, subject string, fn ReplyFunc[*Req, *Resp]) {
queueName := fmt.Sprintf("queue.%s", subject)
msgs := stream.Get(
ctx,
Expand All @@ -97,7 +97,8 @@ func Reply[Req, Resp any](ctx context.Context, stream Stream, subject string, fn

event := msg.Events[0]

req, err := jsonUnmarshal[Req](event.Data)
var req Req
err = json.Unmarshal(event.Data, &req)
if err != nil {
return
}
Expand All @@ -107,7 +108,7 @@ func Reply[Req, Resp any](ctx context.Context, stream Stream, subject string, fn
Payload any `json:"payload"`
}

resp, err := fn(ctx, req)
resp, err := fn(ctx, &req)
if err != nil {
replyMsg.Type = "error"
replyMsg.Payload = err.Error()
Expand All @@ -133,17 +134,3 @@ func Reply[Req, Resp any](ctx context.Context, stream Stream, subject string, fn
}
}()
}

func jsonUnmarshal[T any](data json.RawMessage) (v T, err error) {
v = initializePointer(v)
err = json.Unmarshal(data, &v)
return v, err
}

func initializePointer[T any](v T) T {
t := reflect.TypeOf(v)
if t.Kind() != reflect.Ptr {
return v
}
return reflect.New(t.Elem()).Interface().(T)
}
12 changes: 7 additions & 5 deletions request_reply_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,26 @@ func TestRequestReply(t *testing.T) {
Result int
}

bus.Reply(context.TODO(), client, "func.div", func(ctx context.Context, req *Req) (*Resp, error) {
bus.Reply(context.TODO(), client, "func.div", func(ctx context.Context, req *Req) (resp *Resp, err error) {
if req.B == 0 {
return nil, fmt.Errorf("division by zero")
}

return &Resp{Result: req.A / req.B}, nil
})

fn := bus.Request[*Req, *Resp](client, "func.div")
fn := bus.Request(client, "func.div")

req := &Req{A: 4, B: 2}
resp, err := fn(context.Background(), req)
resp := &Resp{}
err := fn(context.Background(), req, resp)

assert.NoError(t, err)
assert.Equal(t, 2, resp.Result)

req = &Req{A: 4, B: 0}
resp, err = fn(context.Background(), req)
resp = &Resp{}
err = fn(context.Background(), req, resp)
assert.Error(t, err)
assert.Equal(t, "division by zero", err.Error())
assert.Nil(t, resp)
}

0 comments on commit dbee65e

Please sign in to comment.