diff --git a/examples/request-reply-single/main.go b/examples/request-reply-single/main.go index 31d7386..f66e2fd 100644 --- a/examples/request-reply-single/main.go +++ b/examples/request-reply-single/main.go @@ -44,11 +44,12 @@ func main() { return &Resp{Result: req.A / req.B}, nil }) - fn := bus.Request[*Req, *Resp](c, funcName) + fn := bus.Request(c, funcName) for range 1000 { req := &Req{A: 4, B: 2} - resp, err := fn(ctx, req) + resp := &Resp{} + err := fn(ctx, req, resp) if err != nil { fmt.Printf("%s = %s\n", req, err) } else { diff --git a/examples/request-reply/request/main.go b/examples/request-reply/request/main.go index a67453c..0841a39 100644 --- a/examples/request-reply/request/main.go +++ b/examples/request-reply/request/main.go @@ -19,11 +19,12 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - fn := bus.Request[*data.Req, *data.Resp](c, "func.div") + fn := bus.Request(c, "func.div") for range 1000 { req := &data.Req{A: 4, B: 2} - resp, err := fn(ctx, req) + resp := &data.Resp{} + err := fn(ctx, req, resp) if err != nil { fmt.Printf("%s = %s\n", req, err) } else { diff --git a/request_reply.go b/request_reply.go index 0080152..e2b0a95 100644 --- a/request_reply.go +++ b/request_reply.go @@ -5,14 +5,14 @@ import ( "encoding/json" "fmt" "log/slog" - "reflect" "time" ) -type RequestReplyFunc[Req, Resp any] func(context.Context, Req) (Resp, error) +type RequestFunc func(ctx context.Context, req any, resp any) error +type ReplyFunc[Req, Resp any] func(ctx context.Context, req Req) (Resp, error) -func Request[Req, Resp any](stream Stream, subject string) RequestReplyFunc[Req, Resp] { - return func(ctx context.Context, req Req) (resp Resp, err error) { +func Request(stream Stream, subject string) RequestFunc { + return func(ctx context.Context, req any, resp any) (err error) { evt, err := NewEvent( WithSubject(subject), WithReply(), @@ -20,12 +20,12 @@ func Request[Req, Resp any](stream Stream, subject string) RequestReplyFunc[Req, WithExpiresAt(30*time.Second), ) if err != nil { - return resp, err + return err } err = stream.Put(ctx, evt) if err != nil { - return resp, err + return err } for msgs, err := range stream.Get( @@ -34,11 +34,11 @@ func Request[Req, Resp any](stream Stream, subject string) RequestReplyFunc[Req, WithFromOldest(), ) { if err != nil { - return resp, err + return err } if len(msgs.Events) != 1 { - return resp, fmt.Errorf("expected one event") + return fmt.Errorf("expected one event but got %d", len(msgs.Events)) } evt := msgs.Events[0] @@ -50,31 +50,31 @@ func Request[Req, Resp any](stream Stream, subject string) RequestReplyFunc[Req, err = json.Unmarshal(evt.Data, &replyMsg) if err != nil { - return resp, err + return err } if replyMsg.Type == "error" { var errMsg string err = json.Unmarshal(replyMsg.Payload, &errMsg) if err != nil { - return resp, err + return err } - return resp, fmt.Errorf(errMsg) + return fmt.Errorf(errMsg) } - resp, err = jsonUnmarshal[Resp](replyMsg.Payload) + err = json.Unmarshal(replyMsg.Payload, resp) if err != nil { - return resp, err + return err } - return resp, nil + return nil } return } } -func Reply[Req, Resp any](ctx context.Context, stream Stream, subject string, fn RequestReplyFunc[Req, Resp]) { +func Reply[Req, Resp any](ctx context.Context, stream Stream, subject string, fn ReplyFunc[*Req, *Resp]) { queueName := fmt.Sprintf("queue.%s", subject) msgs := stream.Get( ctx, @@ -97,7 +97,8 @@ func Reply[Req, Resp any](ctx context.Context, stream Stream, subject string, fn event := msg.Events[0] - req, err := jsonUnmarshal[Req](event.Data) + var req Req + err = json.Unmarshal(event.Data, &req) if err != nil { return } @@ -107,7 +108,7 @@ func Reply[Req, Resp any](ctx context.Context, stream Stream, subject string, fn Payload any `json:"payload"` } - resp, err := fn(ctx, req) + resp, err := fn(ctx, &req) if err != nil { replyMsg.Type = "error" replyMsg.Payload = err.Error() @@ -133,17 +134,3 @@ func Reply[Req, Resp any](ctx context.Context, stream Stream, subject string, fn } }() } - -func jsonUnmarshal[T any](data json.RawMessage) (v T, err error) { - v = initializePointer(v) - err = json.Unmarshal(data, &v) - return v, err -} - -func initializePointer[T any](v T) T { - t := reflect.TypeOf(v) - if t.Kind() != reflect.Ptr { - return v - } - return reflect.New(t.Elem()).Interface().(T) -} diff --git a/request_reply_test.go b/request_reply_test.go index ccac42c..8973106 100644 --- a/request_reply_test.go +++ b/request_reply_test.go @@ -21,7 +21,7 @@ func TestRequestReply(t *testing.T) { Result int } - bus.Reply(context.TODO(), client, "func.div", func(ctx context.Context, req *Req) (*Resp, error) { + bus.Reply(context.TODO(), client, "func.div", func(ctx context.Context, req *Req) (resp *Resp, err error) { if req.B == 0 { return nil, fmt.Errorf("division by zero") } @@ -29,16 +29,18 @@ func TestRequestReply(t *testing.T) { return &Resp{Result: req.A / req.B}, nil }) - fn := bus.Request[*Req, *Resp](client, "func.div") + fn := bus.Request(client, "func.div") req := &Req{A: 4, B: 2} - resp, err := fn(context.Background(), req) + resp := &Resp{} + err := fn(context.Background(), req, resp) + assert.NoError(t, err) assert.Equal(t, 2, resp.Result) req = &Req{A: 4, B: 0} - resp, err = fn(context.Background(), req) + resp = &Resp{} + err = fn(context.Background(), req, resp) assert.Error(t, err) assert.Equal(t, "division by zero", err.Error()) - assert.Nil(t, resp) }