diff --git a/connection.go b/connection.go index 0f3f6a4..39a0bb0 100644 --- a/connection.go +++ b/connection.go @@ -206,18 +206,46 @@ func DialConfig(url string, config Config) (*Connection, error) { } if config.SASL == nil { - config.SASL = []Authentication{uri.PlainAuth()} + if uri.AuthMechanism != nil { + for _, identifier := range uri.AuthMechanism { + switch strings.ToUpper(identifier) { + case "PLAIN": + config.SASL = append(config.SASL, uri.PlainAuth()) + case "AMQPLAIN": + config.SASL = append(config.SASL, uri.AMQPlainAuth()) + case "EXTERNAL": + config.SASL = append(config.SASL, &ExternalAuth{}) + default: + return nil, fmt.Errorf("unsupported auth_mechanism: %v", identifier) + } + } + } else { + config.SASL = []Authentication{uri.PlainAuth()} + } } if config.Vhost == "" { config.Vhost = uri.Vhost } + if config.Heartbeat == 0 { + config.Heartbeat = time.Duration(uri.Heartbeat) * time.Second + } + + if config.ChannelMax == 0 { + config.ChannelMax = uri.ChannelMax + } + + connectionTimeout := defaultConnectionTimeout + if uri.ConnectionTimeout != 0 { + connectionTimeout = time.Duration(uri.ConnectionTimeout) * time.Millisecond + } + addr := net.JoinHostPort(uri.Host, strconv.FormatInt(int64(uri.Port), 10)) dialer := config.Dial if dialer == nil { - dialer = DefaultDial(defaultConnectionTimeout) + dialer = DefaultDial(connectionTimeout) } conn, err = dialer("tcp", addr) diff --git a/uri.go b/uri.go index 87ef09e..6abfa7d 100644 --- a/uri.go +++ b/uri.go @@ -7,6 +7,7 @@ package amqp091 import ( "errors" + "fmt" "net" "net/url" "strconv" @@ -32,16 +33,20 @@ var defaultURI = URI{ // URI represents a parsed AMQP URI string. type URI struct { - Scheme string - Host string - Port int - Username string - Password string - Vhost string - CertFile string // client TLS auth - path to certificate (PEM) - CACertFile string // client TLS auth - path to CA certificate (PEM) - KeyFile string // client TLS auth - path to private key (PEM) - ServerName string // client TLS auth - server name + Scheme string + Host string + Port int + Username string + Password string + Vhost string + CertFile string // client TLS auth - path to certificate (PEM) + CACertFile string // client TLS auth - path to CA certificate (PEM) + KeyFile string // client TLS auth - path to private key (PEM) + ServerName string // client TLS auth - server name + AuthMechanism []string + Heartbeat int + ConnectionTimeout int + ChannelMax uint16 } // ParseURI attempts to parse the given AMQP URI according to the spec. @@ -62,6 +67,10 @@ type URI struct { // keyfile: // cacertfile: // server_name_indication: +// auth_mechanism: +// heartbeat: +// connection_timeout: +// channel_max: // // If cacertfile is not provided, system CA certificates will be used. // Mutual TLS (client auth) will be enabled only in case keyfile AND certfile provided. @@ -134,6 +143,31 @@ func ParseURI(uri string) (URI, error) { builder.KeyFile = params.Get("keyfile") builder.CACertFile = params.Get("cacertfile") builder.ServerName = params.Get("server_name_indication") + builder.AuthMechanism = params["auth_mechanism"] + + if params.Has("heartbeat") { + value, err := strconv.Atoi(params.Get("heartbeat")) + if err != nil { + return builder, fmt.Errorf("heartbeat is not an integer: %v", err) + } + builder.Heartbeat = value + } + + if params.Has("connection_timeout") { + value, err := strconv.Atoi(params.Get("connection_timeout")) + if err != nil { + return builder, fmt.Errorf("connection_timeout is not an integer: %v", err) + } + builder.ConnectionTimeout = value + } + + if params.Has("channel_max") { + value, err := strconv.ParseUint(params.Get("channel_max"), 10, 16) + if err != nil { + return builder, fmt.Errorf("connection_timeout is not an integer: %v", err) + } + builder.ChannelMax = uint16(value) + } return builder, nil } diff --git a/uri_test.go b/uri_test.go index a369441..68940ed 100644 --- a/uri_test.go +++ b/uri_test.go @@ -6,6 +6,7 @@ package amqp091 import ( + "reflect" "testing" ) @@ -388,3 +389,23 @@ func TestURITLSConfig(t *testing.T) { t.Fatal("Server name not set") } } + +func TestURIParameters(t *testing.T) { + url := "amqps://foo.bar/?auth_mechanism=plain&auth_mechanism=amqpplain&heartbeat=2&connection_timeout=5000&channel_max=8" + uri, err := ParseURI(url) + if err != nil { + t.Fatal("Could not parse") + } + if !reflect.DeepEqual(uri.AuthMechanism, []string{"plain", "amqpplain"}) { + t.Fatal("AuthMechanism not set") + } + if uri.Heartbeat != 2 { + t.Fatal("Heartbeat not set") + } + if uri.ConnectionTimeout != 5000 { + t.Fatal("ConnectionTimeout not set") + } + if uri.ChannelMax != 8 { + t.Fatal("ChannelMax name not set") + } +}