diff --git a/README.md b/README.md index 9a40124..e48e67f 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,12 @@ Look in the examples directory to learn how to use this library: API documentation is available at [godoc.org](https://godoc.org/github.com/go-routeros/routeros). Page on the [Mikrotik Wiki](http://wiki.mikrotik.com/wiki/API_in_Go). -Released versions: +Usage of `gopkg.in` was removed in favor of Go modules. Please, update you import paths to +`github.com/go-routeros/routeros/v3`. + +Old released versions: [**v2**](https://github.com/go-routeros/routeros/tree/v2) [**v1**](https://github.com/go-routeros/routeros/tree/v1) + +To install it, run: +`go get github.com/go-routeros/routeros/v3` \ No newline at end of file diff --git a/async.go b/async.go index 1bbd7fe..22ba5ba 100644 --- a/async.go +++ b/async.go @@ -1,6 +1,10 @@ package routeros -import "github.com/go-routeros/routeros/proto" +import ( + "context" + + "github.com/go-routeros/routeros/v3/proto" +) type sentenceProcessor interface { processSentence(sen *proto.Sentence) (bool, error) @@ -12,6 +16,11 @@ type replyCloser interface { // Async starts asynchronous mode and returns immediately. func (c *Client) Async() <-chan error { + return c.AsyncContext(context.Background()) +} + +// AsyncContext starts asynchronous mode with context and returns immediately. +func (c *Client) AsyncContext(ctx context.Context) <-chan error { c.mu.Lock() defer c.mu.Unlock() @@ -23,16 +32,16 @@ func (c *Client) Async() <-chan error { } c.async = true c.tags = make(map[string]sentenceProcessor) - go c.asyncLoopChan(errC) + go c.asyncLoopChan(ctx, errC) return errC } -func (c *Client) asyncLoopChan(errC chan<- error) { +func (c *Client) asyncLoopChan(ctx context.Context, errC chan<- error) { defer close(errC) + // If c.Close() has been called, c.closing will be true, and // err will be “use of closed network connection”. Ignore that error. - err := c.asyncLoop() - if err != nil { + if err := c.asyncLoop(ctx); err != nil { c.mu.Lock() closing := c.closing c.mu.Unlock() @@ -42,9 +51,17 @@ func (c *Client) asyncLoopChan(errC chan<- error) { } } -func (c *Client) asyncLoop() error { +// asyncLoop - main goroutine for async mode. Read and process sentences, handle context done. +func (c *Client) asyncLoop(ctx context.Context) error { + go func() { + <-ctx.Done() + + c.r.Cancel() + }() + for { sen, err := c.r.ReadSentence() + if err != nil { c.closeTags(err) return err @@ -53,6 +70,8 @@ func (c *Client) asyncLoop() error { c.mu.Lock() r, ok := c.tags[sen.Tag] c.mu.Unlock() + + // cannot find tag for this sentence, ignore if !ok { continue } @@ -71,9 +90,22 @@ func (c *Client) closeTags(err error) { c.mu.Lock() defer c.mu.Unlock() + // If c.Close() has been called, c.closing will be true, and + // err will be “use of closed network connection”. Ignore that error. + if c.closing { + for _, r := range c.tags { + closeReply(r, nil) + } + + c.tags = nil + + return + } + for _, r := range c.tags { closeReply(r, err) } + c.tags = nil } diff --git a/chan_reply.go b/chan_reply.go index bdfcda1..bb48d72 100644 --- a/chan_reply.go +++ b/chan_reply.go @@ -1,6 +1,8 @@ package routeros -import "github.com/go-routeros/routeros/proto" +import ( + "github.com/go-routeros/routeros/v3/proto" +) // chanReply is shared between ListenReply and AsyncReply. type chanReply struct { diff --git a/client.go b/client.go index 3db98a5..83db72d 100644 --- a/client.go +++ b/client.go @@ -4,88 +4,171 @@ Package routeros is a pure Go client library for accessing Mikrotik devices usin package routeros import ( - "crypto/md5" + "context" + "crypto/md5" //nolint:gosec "crypto/tls" "encoding/hex" "errors" "fmt" "io" + "log/slog" "net" + "os" "sync" + "sync/atomic" + "time" - "github.com/go-routeros/routeros/proto" + "github.com/go-routeros/routeros/v3/proto" ) // Client is a RouterOS API client. type Client struct { Queue int + log *slog.Logger + logMutex sync.Mutex + rwc io.ReadWriteCloser - r proto.Reader - w proto.Writer closing bool async bool nextTag int64 tags map[string]sentenceProcessor mu sync.Mutex + mw sync.Mutex + + r proto.Reader + w proto.Writer } +var ( + ErrNoChallengeReceived = errors.New("no ret (challenge) received") + ErrInvalidChallengeReceived = errors.New("invalid ret (challenge) hex string received") +) + +var defaultHandler = slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelInfo, +}) + // NewClient returns a new Client over rwc. Login must be called. func NewClient(rwc io.ReadWriteCloser) (*Client, error) { return &Client{ rwc: rwc, - r: proto.NewReader(rwc), - w: proto.NewWriter(rwc), + log: slog.New(defaultHandler), + + r: proto.NewReader(rwc), + w: proto.NewWriter(rwc), }, nil } +// incrementTag atomically increments tag number and returns result +func (c *Client) incrementTag() int64 { + return atomic.AddInt64(&c.nextTag, 1) +} + +// IsAsync return true if client run in async mode. +func (c *Client) IsAsync() bool { + c.mu.Lock() + defer c.mu.Unlock() + + return c.async +} + // Dial connects and logs in to a RouterOS device. func Dial(address, username, password string) (*Client, error) { - conn, err := net.Dial("tcp", address) + return DialContext(context.Background(), address, username, password) +} + +// DialTimeout connects and logs in to a RouterOS device with timeout. +func DialTimeout(address, username, password string, timeout time.Duration) (*Client, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + return DialContext(ctx, address, username, password) +} + +// DialContext connects and logs in to a RouterOS device using context. +func DialContext(ctx context.Context, address, username, password string) (*Client, error) { + conn, err := new(net.Dialer).DialContext(ctx, "tcp", address) if err != nil { - return nil, err + return nil, fmt.Errorf("could not connect to router os: %w", err) } - return newClientAndLogin(conn, username, password) + return newClientAndLogin(ctx, conn, username, password) } // DialTLS connects and logs in to a RouterOS device using TLS. func DialTLS(address, username, password string, tlsConfig *tls.Config) (*Client, error) { - conn, err := tls.Dial("tcp", address, tlsConfig) + return DialTLSContext(context.Background(), address, username, password, tlsConfig) +} + +// DialTLSTimeout connects and logs in to a RouterOS device using TLS with timeout. +func DialTLSTimeout(address, username, password string, tlsConfig *tls.Config, timeout time.Duration) (*Client, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + return DialTLSContext(ctx, address, username, password, tlsConfig) +} + +// DialTLSContext connects and logs in to a RouterOS device using TLS and context. +func DialTLSContext(ctx context.Context, address, username, password string, tlsConfig *tls.Config) (*Client, error) { + conn, err := (&tls.Dialer{Config: tlsConfig}).DialContext(ctx, "tcp", address) if err != nil { - return nil, err + return nil, fmt.Errorf("could not connect to router os: %w", err) } - return newClientAndLogin(conn, username, password) + return newClientAndLogin(ctx, conn, username, password) } -func newClientAndLogin(rwc io.ReadWriteCloser, username, password string) (*Client, error) { +// newClientAndLogin - creates a new client with context over specified rwc, then logs in to the RouterOS, returns new client. +func newClientAndLogin(ctx context.Context, rwc io.ReadWriteCloser, username, password string) (*Client, error) { c, err := NewClient(rwc) if err != nil { - rwc.Close() - return nil, err + return nil, fmt.Errorf("could not connect to router os: %w; close: %w", err, rwc.Close()) } - err = c.Login(username, password) + err = c.LoginContext(ctx, username, password) if err != nil { - c.Close() - return nil, err + return nil, fmt.Errorf("could not login: %w; close %w", err, c.Close()) } return c, nil } +func (c *Client) SetLogHandler(handler LogHandler) { + c.logMutex.Lock() + c.log = slog.New(handler) + c.logMutex.Unlock() +} + +func (c *Client) logger() *slog.Logger { + c.logMutex.Lock() + defer c.logMutex.Unlock() + + return c.log +} + // Close closes the connection to the RouterOS device. -func (c *Client) Close() { +func (c *Client) Close() error { c.mu.Lock() + defer c.mu.Unlock() + + c.r.Close() + c.w.Close() + if c.closing { - c.mu.Unlock() - return + return nil } + c.closing = true - c.mu.Unlock() - c.rwc.Close() + + return c.rwc.Close() } // Login runs the /login command. Dial and DialTLS call this automatically. func (c *Client) Login(username, password string) error { - r, err := c.Run("/login", "=name="+username, "=password="+password) + return c.LoginContext(context.Background(), username, password) +} + +// LoginContext runs the /login command. DialContext and DialTLSContext call this automatically. +func (c *Client) LoginContext(ctx context.Context, username, password string) error { + r, err := c.RunContext(ctx, "/login", "=name="+username, "=password="+password) if err != nil { return err } @@ -95,27 +178,25 @@ func (c *Client) Login(username, password string) error { if r.Done != nil { return nil } - return errors.New("RouterOS: /login: no ret (challenge) received") + return fmt.Errorf("RouterOS: /login: %w", ErrNoChallengeReceived) } // Login method pre-6.43 two stages, challenge - b, err := hex.DecodeString(ret) - if err != nil { - return fmt.Errorf("RouterOS: /login: invalid ret (challenge) hex string received: %s", err) + var dec []byte + if dec, err = hex.DecodeString(ret); err != nil { + return fmt.Errorf("RouterOS: /login: %w: %w", ErrInvalidChallengeReceived, err) } - r, err = c.Run("/login", "=name="+username, "=response="+c.challengeResponse(b, password)) - if err != nil { - return err - } + _, err = c.RunContext(ctx, "/login", "=name="+username, "=response="+c.challengeResponse(dec, password)) - return nil + return err } +// challengeResponse - prepare MD5 hash for auth challenge response func (c *Client) challengeResponse(cha []byte, password string) string { - h := md5.New() + h := md5.New() //nolint:gosec h.Write([]byte{0}) - io.WriteString(h, password) + h.Write([]byte(password)) h.Write(cha) return fmt.Sprintf("00%x", h.Sum(nil)) } diff --git a/client_test.go b/client_test.go index 939fa07..327f5e3 100644 --- a/client_test.go +++ b/client_test.go @@ -1,15 +1,16 @@ package routeros import ( - "flag" - "strings" + "context" + "errors" + "io" + "net" + "os" "testing" -) + "time" -var ( - routerosAddress = flag.String("routeros.address", "", "RouterOS address:port") - routerosUsername = flag.String("routeros.username", "admin", "RouterOS user name") - routerosPassword = flag.String("routeros.password", "admin", "RouterOS password") + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type liveTest struct { @@ -17,6 +18,26 @@ type liveTest struct { c *Client } +type testConfig struct { + Address string + Username string + Password string +} + +func fetchConfig(t *testing.T) *testConfig { + cfg := &testConfig{ + Address: os.Getenv("ROUTEROS_TEST_ADDRESS"), + Username: os.Getenv("ROUTEROS_TEST_USERNAME"), + Password: os.Getenv("ROUTEROS_TEST_PASSWORD"), + } + + if cfg.Address == "" || cfg.Username == "" || cfg.Password == "" { + t.Skip("skipping integration tests because address or username or password is missing") + } + + return cfg +} + func newLiveTest(t *testing.T) *liveTest { tt := &liveTest{T: t} tt.connect() @@ -24,53 +45,66 @@ func newLiveTest(t *testing.T) *liveTest { } func (t *liveTest) connect() { - if *routerosAddress == "" { - t.Skip("Flag -routeros.address not set") - } + cfg := fetchConfig(t.T) + ctx := context.WithValue(context.Background(), "logger", t.T) + var err error - t.c, err = Dial(*routerosAddress, *routerosUsername, *routerosPassword) - if err != nil { - t.Fatal(err) - } + t.c, err = DialContext(ctx, cfg.Address, cfg.Username, cfg.Password) + require.NoError(t, err) } -func (t *liveTest) run(sentence ...string) *Reply { +func (t *liveTest) runContext(ctx context.Context, sentence ...string) *Reply { t.Logf("Run: %#q", sentence) - r, err := t.c.RunArgs(sentence) - if err != nil { - t.Fatal(err) - } + + r, err := t.c.RunArgsContext(ctx, sentence) + require.NoError(t, err) + require.NotNil(t, r) + require.NotNil(t, r.Done, "done not received") + t.Logf("Reply: %s", r) return r } func (t *liveTest) getUptime() { - r := t.run("/system/resource/print") - if len(r.Re) != 1 { - t.Fatalf("len(!re)=%d; want 1", len(r.Re)) - } + // allow test to fail after 5 seconds if we didn't receive answer + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + ctx = context.WithValue(ctx, "logger", t.T) + + r := t.runContext(ctx, "/system/resource/print") + require.Len(t, r.Re, 1, "expected 1 response") + _, ok := r.Re[0].Map["uptime"] - if !ok { - t.Fatal("Missing uptime") - } + require.True(t, ok, "missing uptime") +} + +func deferCloser(t *testing.T, c io.Closer) { + require.NoError(t, c.Close()) } func TestRunSync(tt *testing.T) { t := newLiveTest(tt) - defer t.c.Close() + defer deferCloser(tt, t.c) t.getUptime() } func TestRunAsync(tt *testing.T) { + // allow test to fail after 5 seconds if we didn't receive answer + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + t := newLiveTest(tt) - defer t.c.Close() - t.c.Async() + defer deferCloser(tt, t.c) + + t.c.AsyncContext(ctx) + require.True(tt, t.c.async, "client should be in async mode") t.getUptime() } func TestRunError(tt *testing.T) { t := newLiveTest(tt) - defer t.c.Close() + defer deferCloser(tt, t.c) for i, sentence := range [][]string{ {"/xxx"}, {"/ip/address/add", "=address=127.0.0.2/32", "=interface=xxx"}, @@ -84,63 +118,85 @@ func TestRunError(tt *testing.T) { } func TestDialInvalidPort(t *testing.T) { - c, err := Dial("127.0.0.1:xxx", "x", "x") - if err == nil { - c.Close() - t.Fatalf("Dial succeeded; want error") + con, err := Dial("127.0.0.1:xxx", "x", "x") + if con != nil { + assert.NoError(t, con.Close()) } - if err.Error() != "dial tcp: lookup tcp/xxx: getaddrinfow: The specified class was not found." && - err.Error() != "dial tcp: lookup tcp/xxx: Servname not supported for ai_socktype" { - t.Fatal(err) + + require.Error(t, err) + require.IsType(t, &net.OpError{}, errors.Unwrap(err)) + + var e *net.DNSError + require.True(t, errors.As(err, &e)) + require.Contains(t, e.Err, "unknown port") +} + +func TestDialTimeout(t *testing.T) { + con, err := DialTimeout("255.255.255.0:8729", "x", "x", time.Millisecond) + if con != nil { + assert.NoError(t, con.Close()) + } + + require.Error(t, err) + + var e net.Error + require.Truef(t, errors.As(err, &e), "want=net.Error have=%q", err) + require.Truef(t, e.Timeout(), `expected="i/o timeout", have=%q`, e) +} + +func TestDialTLSTimeout(t *testing.T) { + con, err := DialTLSTimeout("255.255.255.0:8729", "x", "x", nil, time.Millisecond) + if con != nil { + assert.NoError(t, con.Close()) } + + require.Error(t, err) + + var e net.Error + require.Truef(t, errors.As(err, &e), "want=net.Error have=%q", err) + require.Truef(t, e.Timeout(), `expected="i/o timeout", have=%q`, e) } func TestDialTLSInvalidPort(t *testing.T) { - c, err := DialTLS("127.0.0.1:xxx", "x", "x", nil) - if err == nil { - c.Close() - t.Fatalf("Dial succeeded; want error") - } - if err.Error() != "dial tcp: lookup tcp/xxx: getaddrinfow: The specified class was not found." && - err.Error() != "dial tcp: lookup tcp/xxx: Servname not supported for ai_socktype" { - t.Fatal(err) + con, err := DialTLS("127.0.0.1:xxx", "x", "x", nil) + if con != nil { + assert.NoError(t, con.Close()) } + + require.Error(t, err) + require.IsType(t, &net.OpError{}, errors.Unwrap(err)) + + var e *net.DNSError + require.True(t, errors.As(err, &e)) + require.Contains(t, e.Err, "unknown port") } func TestInvalidLogin(t *testing.T) { - if *routerosAddress == "" { - t.Skip("Flag -routeros.address not set") - } - var err error - c, err := Dial(*routerosAddress, "xxx", "APasswordThatWillNeverExistir") - if err == nil { - c.Close() - t.Fatalf("Dial succeeded; want error") - } - if err.Error() != "from RouterOS device: cannot log in" && - err.Error() != "from RouterOS device: invalid user name or password (6)" { - t.Fatal(err) + cfg := fetchConfig(t) + + c, err := Dial(cfg.Address, "xxx", "APasswordThatWillNeverExistir") + if c != nil { + assert.NoError(t, c.Close()) } + + require.Error(t, err, "dial succeeded; want error") + + var devErr *DeviceError + require.Truef(t, errors.As(err, &devErr), "wait for device error: %v", err) + require.Contains(t, []string{"cannot log in", "invalid user name or password (6)"}, devErr.fetchMessage()) } func TestTrapHandling(tt *testing.T) { t := newLiveTest(tt) - defer t.c.Close() + defer deferCloser(tt, t.c) cmd := []string{"/ip/dns/static/add", "=type=A", "=name=example.com", "=ttl=30", "=address=1.0.0.0"} _, _ = t.c.RunArgs(cmd) _, err := t.c.RunArgs(cmd) - if err == nil { - t.Fatal("Should've returned an error due to a duplicate") - } - devErr, ok := err.(*DeviceError) - if !ok { - t.Fatal("Should've returned a DeviceError") - } - message := devErr.Sentence.Map["message"] - wanted := "entry already exists" - if !strings.Contains(message, wanted) { - t.Fatalf(`message=%#v; want %#v`, message, wanted) - } + require.Error(tt, err, "should've returned an error due to a duplicate") + + var devErr *DeviceError + require.True(t, errors.As(err, &devErr), "should've returned a DeviceError") + require.Contains(tt, devErr.Sentence.Map["message"], "entry already exists") } diff --git a/error.go b/error.go index 848c76e..8585d0f 100644 --- a/error.go +++ b/error.go @@ -2,13 +2,14 @@ package routeros import ( "errors" + "fmt" - "github.com/go-routeros/routeros/proto" + "github.com/go-routeros/routeros/v3/proto" ) var ( - errAlreadyAsync = errors.New("Async() has already been called") - errAsyncLoopEnded = errors.New("Async() loop has ended - probably read error") + errAlreadyAsync = errors.New("method Async() has already been called") + errAsyncLoopEnded = errors.New("method Async(): loop has ended - probably read error") ) // UnknownReplyError records the sentence whose Word is unknown. @@ -26,10 +27,14 @@ type DeviceError struct { Sentence *proto.Sentence } -func (err *DeviceError) Error() string { - m := err.Sentence.Map["message"] - if m == "" { - m = "unknown error: " + err.Sentence.String() +func (err *DeviceError) fetchMessage() string { + if m := err.Sentence.Map["message"]; m != "" { + return m } - return "from RouterOS device: " + m + + return "unknown error: " + err.Sentence.String() +} + +func (err *DeviceError) Error() string { + return fmt.Sprintf("from RouterOS device: %s", err.fetchMessage()) } diff --git a/examples/listen/main.go b/examples/listen/main.go index 667db81..8a660da 100644 --- a/examples/listen/main.go +++ b/examples/listen/main.go @@ -1,15 +1,19 @@ package main import ( + "context" "flag" - "log" + "log/slog" + "os" + "os/signal" "strings" "time" - "github.com/go-routeros/routeros" + "github.com/go-routeros/routeros/v3" ) var ( + debug = flag.Bool("debug", false, "debug log level mode") command = flag.String("command", "/ip/firewall/address-list/listen", "RouterOS command") address = flag.String("address", "127.0.0.1:8728", "RouterOS address and port") username = flag.String("username", "admin", "User name") @@ -19,96 +23,138 @@ var ( useTLS = flag.Bool("tls", false, "Use TLS") ) -func dial() (*routeros.Client, error) { +func dial(ctx context.Context) (*routeros.Client, error) { if *useTLS { - return routeros.DialTLS(*address, *username, *password, nil) + return routeros.DialTLSContext(ctx, *address, *username, *password, nil) } - return routeros.Dial(*address, *username, *password) + return routeros.DialContext(ctx, *address, *username, *password) +} + +func fatal(log *slog.Logger, message string, err error) { + log.Error(message, slog.Any("error", err)) + os.Exit(2) } func main() { - flag.Parse() + var err error + if err = flag.CommandLine.Parse(os.Args[1:]); err != nil { + panic(err) + } - c, err := dial() - if err != nil { - log.Fatal(err) + logLevel := slog.LevelInfo + if debug != nil && *debug { + logLevel = slog.LevelDebug } - defer c.Close() - c.Queue = 100 + handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + AddSource: true, + Level: logLevel, + }) + + log := slog.New(handler) + + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + + var cli *routeros.Client + if cli, err = dial(ctx); err != nil { + fatal(log, "could not dial", err) + } + + cli.SetLogHandler(handler) + defer func() { + if errClose := cli.Close(); errClose != nil { + log.Error("could not close routerOS client", slog.Any("error", errClose)) + } + }() + + cli.Queue = 100 if *async { - explicitAsync(c) + explicitAsync(ctx, log, cli) } else { - implicitAsync(c) + implicitAsync(ctx, log, cli) + } + + if err = cli.Close(); err != nil { + fatal(log, "could not close client", err) } } -func explicitAsync(c *routeros.Client) { +func explicitAsync(ctx context.Context, log *slog.Logger, c *routeros.Client) { errC := c.Async() + log.Debug("Running explicitAsync mode...") + + ctx, cancel := context.WithCancel(ctx) + defer cancel() go func() { - l, err := c.ListenArgs(strings.Split(*command, " ")) + l, err := c.ListenArgsContext(ctx, strings.Split(*command, " ")) if err != nil { - log.Fatal(err) + fatal(log, "could not listen", err) } go func() { time.Sleep(*timeout) - log.Print("Cancelling the RouterOS command...") - _, err := l.Cancel() - if err != nil { - log.Fatal(err) + log.Debug("Cancelling the RouterOS command...") + + if _, errCancel := l.CancelContext(ctx); errCancel != nil { + fatal(log, "could not cancel context", errCancel) } + + log.Debug("cancelled") + cancel() }() - log.Print("Waiting for !re...") + log.Info("Waiting for !re...") for sen := range l.Chan() { - log.Printf("Update: %s", sen) + log.Info("Update", slog.String("sentence", sen.String())) } - err = l.Err() - if err != nil { - log.Fatal(err) + if err = l.Err(); err != nil { + fatal(log, "received an error", err) } - - log.Print("Done!") - c.Close() }() - err := <-errC - if err != nil { - log.Fatal(err) + select { + case <-ctx.Done(): + return + case err := <-errC: + if err != nil { + fatal(log, "received an error", err) + } } } -func implicitAsync(c *routeros.Client) { - l, err := c.ListenArgs(strings.Split(*command, " ")) +func implicitAsync(ctx context.Context, log *slog.Logger, c *routeros.Client) { + l, err := c.ListenArgsContext(ctx, strings.Split(*command, " ")) if err != nil { - log.Fatal(err) + fatal(log, "could not listen", err) } go func() { time.Sleep(*timeout) - - log.Print("Cancelling the RouterOS command...") - _, err := l.Cancel() - if err != nil { - log.Fatal(err) + log.Debug("Cancelling the RouterOS command...") + if _, errCancel := l.Cancel(); errCancel != nil { + fatal(log, "could not cancel", errCancel) } }() - log.Print("Waiting for !re...") - for sen := range l.Chan() { - log.Printf("Update: %s", sen) - } +loop: + for { + select { + case <-ctx.Done(): + break loop + case sen, ok := <-l.Chan(): + if !ok { + break loop + } - err = l.Err() - if err != nil { - log.Fatal(err) + log.Info("Update", slog.String("sentence", sen.String())) + } + } + if err = l.Err(); err != nil { + fatal(log, "received an error", err) } - - log.Print("Done!") - c.Close() } diff --git a/examples/run/main.go b/examples/run/main.go index 2d88fd4..ec3a719 100644 --- a/examples/run/main.go +++ b/examples/run/main.go @@ -2,13 +2,15 @@ package main import ( "flag" - "log" + "log/slog" + "os" "strings" - "github.com/go-routeros/routeros" + "github.com/go-routeros/routeros/v3" ) var ( + debug = flag.Bool("debug", false, "debug log level mode") command = flag.String("command", "/system/resource/print", "RouterOS command") address = flag.String("address", "127.0.0.1:8728", "RouterOS address and port") username = flag.String("username", "admin", "User name") @@ -24,23 +26,45 @@ func dial() (*routeros.Client, error) { return routeros.Dial(*address, *username, *password) } +func fatal(log *slog.Logger, message string, err error) { + log.Error(message, slog.Any("error", err)) + os.Exit(2) +} + func main() { - flag.Parse() + var err error + if err = flag.CommandLine.Parse(os.Args[1:]); err != nil { + panic(err) + } + + logLevel := slog.LevelInfo + if debug != nil && *debug { + logLevel = slog.LevelDebug + } + + handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + AddSource: true, + Level: logLevel, + }) + + log := slog.New(handler) c, err := dial() if err != nil { - log.Fatal(err) + fatal(log, "could not connect", err) } defer c.Close() + c.SetLogHandler(handler) + if *async { c.Async() } r, err := c.RunArgs(strings.Split(*command, " ")) if err != nil { - log.Fatal(err) + fatal(log, "could not run args", err) } - log.Print(r) + log.Info("received results", slog.Any("results", r)) } diff --git a/examples/tab/main.go b/examples/tab/main.go index 2dabdcc..e1f9f46 100644 --- a/examples/tab/main.go +++ b/examples/tab/main.go @@ -3,14 +3,16 @@ package main import ( "flag" "fmt" - "log" + "log/slog" + "os" "strings" "time" - "github.com/go-routeros/routeros" + "github.com/go-routeros/routeros/v3" ) var ( + debug = flag.Bool("debug", false, "debug log level mode") address = flag.String("address", "192.168.0.1:8728", "Address") username = flag.String("username", "admin", "Username") password = flag.String("password", "admin", "Password") @@ -18,18 +20,41 @@ var ( interval = flag.Duration("interval", 1*time.Second, "Interval") ) +func fatal(log *slog.Logger, message string, err error) { + log.Error(message, slog.Any("error", err)) + os.Exit(2) +} + func main() { - flag.Parse() + var err error + if err = flag.CommandLine.Parse(os.Args[1:]); err != nil { + panic(err) + } + + logLevel := slog.LevelInfo + if debug != nil && *debug { + logLevel = slog.LevelDebug + } + + handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + AddSource: true, + Level: logLevel, + }) + + log := slog.New(handler) c, err := routeros.Dial(*address, *username, *password) if err != nil { - log.Fatal(err) + fatal(log, "could not dial", err) } + c.SetLogHandler(handler) + for { - reply, err := c.Run("/interface/print", "?disabled=false", "?running=true", "=.proplist="+*properties) - if err != nil { - log.Fatal(err) + var reply *routeros.Reply + + if reply, err = c.Run("/interface/print", "?disabled=false", "?running=true", "=.proplist="+*properties); err != nil { + fatal(log, "could not run", err) } for _, re := range reply.Re { diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..624a1c1 --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module github.com/go-routeros/routeros/v3 + +go 1.22.3 + +require github.com/stretchr/testify v1.9.0 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..60ce688 --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/listen.go b/listen.go index 002545b..e7fd9bb 100644 --- a/listen.go +++ b/listen.go @@ -1,9 +1,18 @@ package routeros import ( + "context" "fmt" + "log/slog" - "github.com/go-routeros/routeros/proto" + "github.com/go-routeros/routeros/v3/proto" +) + +const ( + fatalSentence = "!fatal" + doneSentence = "!done" + trapSentence = "!trap" + reSentence = "!re" ) // ListenReply is the struct returned by the Listen*() functions. @@ -26,28 +35,54 @@ func (l *ListenReply) Cancel() (*Reply, error) { return l.c.Run("/cancel", "=tag="+l.tag) } +// CancelContext sends a cancel command to the RouterOS device with context. +func (l *ListenReply) CancelContext(ctx context.Context) (*Reply, error) { + return l.c.RunContext(ctx, "/cancel", "=tag="+l.tag) +} + // Listen simply calls ListenArgsQueue() with queueSize set to c.Queue. func (c *Client) Listen(sentence ...string) (*ListenReply, error) { return c.ListenArgsQueue(sentence, c.Queue) } +// ListenContext simply calls ListenArgsQueue() with queueSize set to c.Queue. +func (c *Client) ListenContext(ctx context.Context, sentence ...string) (*ListenReply, error) { + return c.ListenArgsQueueContext(ctx, sentence, c.Queue) +} + // ListenArgs simply calls ListenArgsQueue() with queueSize set to c.Queue. func (c *Client) ListenArgs(sentence []string) (*ListenReply, error) { return c.ListenArgsQueue(sentence, c.Queue) } +// ListenArgsContext simply calls ListenArgsQueue() with queueSize set to c.Queue. +func (c *Client) ListenArgsContext(ctx context.Context, sentence []string) (*ListenReply, error) { + return c.ListenArgsQueueContext(ctx, sentence, c.Queue) +} + // ListenArgsQueue sends a sentence to the RouterOS device and returns immediately. func (c *Client) ListenArgsQueue(sentence []string, queueSize int) (*ListenReply, error) { - if !c.async { - c.Async() + return c.ListenArgsQueueContext(context.Background(), sentence, queueSize) +} + +// ListenArgsQueueContext sends a sentence to the RouterOS device and returns immediately. +func (c *Client) ListenArgsQueueContext(ctx context.Context, sentence []string, queueSize int) (*ListenReply, error) { + c.logger().Debug("ListenArgsQueueContext", slog.Any("sentences", sentence)) + + if !c.IsAsync() { + c.AsyncContext(ctx) } - c.nextTag++ + tag := c.incrementTag() + l := &ListenReply{c: c} - l.tag = fmt.Sprintf("l%d", c.nextTag) + l.tag = fmt.Sprintf("l%d", tag) l.reC = make(chan *proto.Sentence, queueSize) c.w.BeginSentence() + + c.logger().Debug("set listener tag", slog.String("tag", l.tag)) + for _, word := range sentence { c.w.WriteWord(word) } @@ -56,31 +91,39 @@ func (c *Client) ListenArgsQueue(sentence []string, queueSize int) (*ListenReply c.mu.Lock() defer c.mu.Unlock() - err := c.w.EndSentence() - if err != nil { + if err := c.w.EndSentence(); err != nil { return nil, err } + if c.tags == nil { return nil, errAsyncLoopEnded } + c.tags[l.tag] = l + + go func() { + <-ctx.Done() + + c.r.Cancel() + }() + return l, nil } func (l *ListenReply) processSentence(sen *proto.Sentence) (bool, error) { switch sen.Word { - case "!re": + case reSentence: l.reC <- sen - case "!done": + case doneSentence: l.Done = sen return true, nil - case "!trap": + case trapSentence: if sen.Map["category"] == "2" { l.Done = sen // "execution of command interrupted" return true, nil } return true, &DeviceError{sen} - case "!fatal": + case fatalSentence: return true, &DeviceError{sen} case "": // API docs say that empty sentences should be ignored diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..1d44c5b --- /dev/null +++ b/logger.go @@ -0,0 +1,7 @@ +package routeros + +import ( + "log/slog" +) + +type LogHandler slog.Handler diff --git a/proto/io_context.go b/proto/io_context.go new file mode 100644 index 0000000..73e1886 --- /dev/null +++ b/proto/io_context.go @@ -0,0 +1,103 @@ +package proto + +import ( + "bufio" + "io" + "sync/atomic" +) + +type ctxResult struct { + num int + err error +} + +type ctxReader struct { + io.Reader + close atomic.Bool + done chan struct{} +} + +func (c *ctxReader) Close() { + if c.close.Load() { + return + } + c.close.Store(true) + close(c.done) +} + +func (c *ctxReader) Cancel() { + if c.close.Load() { + return + } + + select { + case c.done <- struct{}{}: + default: + } +} + +func (c *ctxReader) Read(p []byte) (int, error) { + out := make(chan *ctxResult, 1) + buf := make([]byte, len(p)) + + go func() { + res := new(ctxResult) + res.num, res.err = c.Reader.Read(buf) + out <- res + close(out) + }() + + select { + case <-c.done: + return 0, io.EOF + case r := <-out: + copy(p, buf) + + return r.num, r.err + } +} + +type ctxWriter struct { + *bufio.Writer + close atomic.Bool + done chan struct{} +} + +func (c *ctxWriter) Close() { + if c.close.Load() { + return + } + c.close.Store(true) + close(c.done) +} + +func (c *ctxWriter) Cancel() { + if c.close.Load() { + return + } + + select { + case c.done <- struct{}{}: + default: + } +} + +func (c *ctxWriter) Write(p []byte) (int, error) { + out := make(chan *ctxResult, 1) + buf := make([]byte, len(p)) + copy(buf, p) + + go func() { + res := new(ctxResult) + res.num, res.err = c.Writer.Write(buf) + out <- res + close(out) + }() + + select { + case <-c.done: + return 0, io.EOF + case r := <-out: + return r.num, r.err + } +} diff --git a/proto/reader.go b/proto/reader.go index a9b1d66..3486591 100644 --- a/proto/reader.go +++ b/proto/reader.go @@ -10,15 +10,19 @@ import ( // Reader reads sentences from a RouterOS device. type Reader interface { ReadSentence() (*Sentence, error) + Cancel() + Close() } type reader struct { - *bufio.Reader + *ctxReader } // NewReader returns a new Reader to read from r. func NewReader(r io.Reader) Reader { - return &reader{bufio.NewReader(r)} + return &reader{ + ctxReader: &ctxReader{Reader: bufio.NewReader(r), done: make(chan struct{})}, + } } // ReadSentence reads a sentence. diff --git a/proto/reader_test.go b/proto/reader_test.go index 6f54b72..ac0c566 100644 --- a/proto/reader_test.go +++ b/proto/reader_test.go @@ -2,11 +2,15 @@ package proto import ( "bytes" + "crypto/rand" + "fmt" "testing" + + "github.com/stretchr/testify/require" ) func TestReadLength(t *testing.T) { - for _, d := range []struct { + for i, d := range []struct { length int64 rawBytes []byte }{ @@ -16,13 +20,22 @@ func TestReadLength(t *testing.T) { {0x002acdef, []byte{0xE0, 0x2a, 0xcd, 0xef}}, {0x10000080, []byte{0xF0, 0x10, 0x00, 0x00, 0x80}}, } { - r := NewReader(bytes.NewBuffer(d.rawBytes)).(*reader) - l, err := r.readLength() - if err != nil { - t.Fatalf("readLength error: %s", err) - } - if l != d.length { - t.Fatalf("Expected len=%X for input %#v, got %X", d.length, d.rawBytes, l) - } + t.Run(fmt.Sprintf("#%d length=%d", i, d.length), func(t *testing.T) { + r := NewReader(bytes.NewBuffer(d.rawBytes)).(*reader) + l, err := r.readLength() + require.NoError(t, err, "read length error") + require.Equal(t, d.length, l, "expected length is wrong") + }) } } + +func TestReadRandom(t *testing.T) { + randomBytes := make([]byte, 4) + _, err := rand.Read(randomBytes) + require.NoError(t, err, "read random bytes error") + + r := NewReader(bytes.NewBuffer(randomBytes)).(*reader) + _, err = r.readLength() + require.NoError(t, err, "read length error") + +} diff --git a/proto/sentence_test.go b/proto/sentence_test.go index 4746b79..6cfe3f8 100644 --- a/proto/sentence_test.go +++ b/proto/sentence_test.go @@ -5,6 +5,8 @@ import ( "fmt" "strings" "testing" + + "github.com/stretchr/testify/require" ) func TestReadWrite(t *testing.T) { @@ -17,31 +19,25 @@ func TestReadWrite(t *testing.T) { {[]string{"!done", ".tag=abc123"}, `[]`, "abc123"}, {strings.Split("!re =tx-byte=123456789 =only-key", " "), "[{`tx-byte` `123456789`} {`only-key` ``}]", ""}, } { - buf := &bytes.Buffer{} - // Write sentence into buf. - w := NewWriter(buf) - w.BeginSentence() - for _, word := range test.in { - w.WriteWord(word) - } - err := w.EndSentence() - if err != nil { - t.Errorf("#%d: Input(%#q)=%#v", i, test.in, err) - continue - } - // Read sentence from buf. - r := NewReader(buf) - sen, err := r.ReadSentence() - if err != nil { - t.Errorf("#%d: Input(%#q)=%#v", i, test.in, err) - continue - } - x := fmt.Sprintf("%#q", sen.List) - if x != test.out { - t.Errorf("#%d: Input(%#q)=%s; want %s", i, test.in, x, test.out) - } - if sen.Tag != test.tag { - t.Errorf("#%d: Input(%#q)=%s; want %s", i, test.in, sen.Tag, test.tag) - } + t.Run(fmt.Sprintf("#%d out=%s tag=%s", i, test.out, test.tag), func(t *testing.T) { + buf := &bytes.Buffer{} + // Write sentence into buf. + w := NewWriter(buf) + w.BeginSentence() + for _, word := range test.in { + w.WriteWord(word) + } + err := w.EndSentence() + require.NoErrorf(t, err, "#%d input(%#q)", i, test.in) + + // Read sentence from buf. + r := NewReader(buf) + sen, err := r.ReadSentence() + require.NoErrorf(t, err, "#%d input(%#q)", i, test.in) + + x := fmt.Sprintf("%#q", sen.List) + require.Equal(t, test.out, x, "#%d input(%#q)", i, test.in) + require.Equal(t, test.tag, sen.Tag, "#%d input(%#q)", i, test.in) + }) } } diff --git a/proto/writer.go b/proto/writer.go index b50f314..be50b9e 100644 --- a/proto/writer.go +++ b/proto/writer.go @@ -11,17 +11,21 @@ type Writer interface { BeginSentence() WriteWord(word string) EndSentence() error + + Cancel() + Close() } type writer struct { - *bufio.Writer + *ctxWriter + err error sync.Mutex } // NewWriter returns a new Writer to write to w. func NewWriter(w io.Writer) Writer { - return &writer{Writer: bufio.NewWriter(w)} + return &writer{ctxWriter: &ctxWriter{Writer: bufio.NewWriter(w), done: make(chan struct{})}} } // BeginSentence prepares w for writing a sentence. diff --git a/proto/writer_test.go b/proto/writer_test.go index a484bc1..167e71f 100644 --- a/proto/writer_test.go +++ b/proto/writer_test.go @@ -1,12 +1,14 @@ package proto import ( - "bytes" + "fmt" "testing" + + "github.com/stretchr/testify/require" ) func TestEncodeLength(t *testing.T) { - for _, d := range []struct { + for i, d := range []struct { length int rawBytes []byte }{ @@ -16,9 +18,8 @@ func TestEncodeLength(t *testing.T) { {0x002acdef, []byte{0xE0, 0x2a, 0xcd, 0xef}}, {0x10000080, []byte{0xF0, 0x10, 0x00, 0x00, 0x80}}, } { - b := encodeLength(d.length) - if !bytes.Equal(b, d.rawBytes) { - t.Fatalf("Expected output %#v for len=%d, got %#v", d.rawBytes, d.length, b) - } + t.Run(fmt.Sprintf("#%d length=%d", i, d.length), func(t *testing.T) { + require.Equal(t, d.rawBytes, encodeLength(d.length), "expected bytes is wrong") + }) } } diff --git a/proto_test.go b/proto_test.go index a2c8cef..a40cbd3 100644 --- a/proto_test.go +++ b/proto_test.go @@ -1,19 +1,42 @@ -package routeros_test +package routeros import ( + "crypto/rand" + "errors" "io" "testing" - "github.com/go-routeros/routeros" - "github.com/go-routeros/routeros/proto" + "github.com/stretchr/testify/require" + + "github.com/go-routeros/routeros/v3/proto" ) -func TestLogin(t *testing.T) { +func TestRandomData(t *testing.T) { c, s := newPair(t) - defer c.Close() + defer deferCloser(t, c) go func() { - defer s.Close() + defer deferCloser(t, s) + + randomBytes := make([]byte, 1024) + _, err := rand.Read(randomBytes) + require.NoError(t, err, "read random bytes error") + + s.readSentence(t, "/login @ [{`name` `userTest`} {`password` `passTest`}]") + s.writeSentence(t, "!done", string(randomBytes)) + }() + + err := c.Login("userTest", "passTest") + require.Error(t, err) + +} + +func TestLoginPre643(t *testing.T) { + c, s := newPair(t) + defer deferCloser(t, c) + + go func() { + defer deferCloser(t, s) s.readSentence(t, "/login @ [{`name` `userTest`} {`password` `passTest`}]") s.writeSentence(t, "!done", "=ret=abc123") s.readSentence(t, "/login @ [{`name` `userTest`} {`response` `0021277bff9ac7caf06aa608e46616d47f`}]") @@ -21,17 +44,29 @@ func TestLogin(t *testing.T) { }() err := c.Login("userTest", "passTest") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) +} + +func TestLoginPost643(t *testing.T) { + c, s := newPair(t) + defer deferCloser(t, c) + + go func() { + defer deferCloser(t, s) + s.readSentence(t, "/login @ [{`name` `userTest`} {`password` `passTest`}]") + s.writeSentence(t, "!done") + }() + + err := c.Login("userTest", "passTest") + require.NoError(t, err) } -func TestLoginIncorrect(t *testing.T) { +func TestLoginIncorrectPre643(t *testing.T) { c, s := newPair(t) - defer c.Close() + defer deferCloser(t, c) go func() { - defer s.Close() + defer deferCloser(t, s) s.readSentence(t, "/login @ [{`name` `userTest`} {`password` `passTest`}]") s.writeSentence(t, "!done", "=ret=abc123") s.readSentence(t, "/login @ [{`name` `userTest`} {`response` `0021277bff9ac7caf06aa608e46616d47f`}]") @@ -40,171 +75,157 @@ func TestLoginIncorrect(t *testing.T) { }() err := c.Login("userTest", "passTest") - if err == nil { - t.Fatalf("Login succeeded; want error") - } - if err.Error() != "from RouterOS device: incorrect login" { - t.Fatal(err) - } + require.Error(t, err, "Login succeeded; want error") + + var top *DeviceError + require.Truef(t, errors.As(err, &top), "want=DeviceError, have=%#v", err) + require.Contains(t, []string{"incorrect login"}, top.fetchMessage()) } -func TestLoginNoChallenge(t *testing.T) { +func TestLoginIncorrectPost643(t *testing.T) { c, s := newPair(t) - defer c.Close() + defer deferCloser(t, c) go func() { - defer s.Close() + defer deferCloser(t, s) s.readSentence(t, "/login @ [{`name` `userTest`} {`password` `passTest`}]") + s.writeSentence(t, "!trap", "=message=invalid user name or password (6)") s.writeSentence(t, "!done") }() err := c.Login("userTest", "passTest") - if err != nil { - t.Fatal(err) - } + require.Error(t, err, "Login succeeded; want error") + + var top *DeviceError + require.Truef(t, errors.As(err, &top), "want=DeviceError, have=%#v", err) + require.Contains(t, []string{"invalid user name or password (6)"}, top.fetchMessage()) +} + +func TestLoginNoChallenge(t *testing.T) { + c, s := newPair(t) + defer deferCloser(t, c) + + go func() { + defer deferCloser(t, s) + s.readSentence(t, "/login @ [{`name` `userTest`} {`password` `passTest`}]") + s.writeSentence(t, "!done") + }() + + require.NoError(t, c.Login("userTest", "passTest")) } func TestLoginInvalidChallenge(t *testing.T) { c, s := newPair(t) - defer c.Close() + defer deferCloser(t, c) go func() { - defer s.Close() + defer deferCloser(t, s) s.readSentence(t, "/login @ [{`name` `userTest`} {`password` `passTest`}]") s.writeSentence(t, "!done", "=ret=Invalid Hex String") }() err := c.Login("userTest", "passTest") - if err == nil { - t.Fatalf("Login succeeded; want error") - } - if err.Error() != "RouterOS: /login: invalid ret (challenge) hex string received: encoding/hex: invalid byte: U+0049 'I'" { - t.Fatal(err) - } + require.Error(t, err, "Login succeeded; want error") + require.Truef(t, errors.Is(err, ErrInvalidChallengeReceived), + "want=ErrInvalidChallengeReceived, have=%#v", err) } func TestLoginEOF(t *testing.T) { c, s := newPair(t) - defer c.Close() - s.Close() + defer deferCloser(t, c) + require.NoError(t, s.Close()) err := c.Login("userTest", "passTest") - if err == nil { - t.Fatalf("Login succeeded; want error") - } - if err.Error() != "io: read/write on closed pipe" { - t.Fatal(err) - } + require.Error(t, err, "Login succeeded; want error") + require.EqualError(t, err, io.ErrClosedPipe.Error()) } func TestCloseTwice(t *testing.T) { c, s := newPair(t) - defer s.Close() - c.Close() - c.Close() + defer deferCloser(t, s) + require.NoError(t, c.Close()) + require.NoError(t, c.Close()) } func TestAsyncTwice(t *testing.T) { c, s := newPair(t) - defer c.Close() - defer s.Close() + defer deferCloser(t, c) + defer deferCloser(t, s) c.Async() errC := c.Async() err := <-errC - want := "Async() has already been called" - if err.Error() != want { - t.Fatalf("Second Async()=%#q; want %#q", err, want) - } - - err = <-errC - if err != nil { - t.Fatalf("Async() channel should be closed after error; got %#q", err) - } + require.EqualError(t, err, errAlreadyAsync.Error()) + require.NoError(t, <-errC, errAsyncLoopEnded.Error()) } -func TestRun(t *testing.T) { +func TestProtoRun(t *testing.T) { c, s := newPair(t) - defer c.Close() + defer deferCloser(t, c) go func() { - defer s.Close() + defer deferCloser(t, s) s.readSentence(t, "/ip/address @ []") s.writeSentence(t, "!re", "=address=1.2.3.4/32") s.writeSentence(t, "!done") }() sen, err := c.Run("/ip/address") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + want := "!re @ [{`address` `1.2.3.4/32`}]\n!done @ []" - if sen.String() != want { - t.Fatalf("/ip/address (%s); want (%s)", sen, want) - } + require.Equal(t, want, sen.String(), "for /ip/address") } func TestRunWithListen(t *testing.T) { c, s := newPair(t) - defer c.Close() + defer deferCloser(t, c) go func() { - defer s.Close() + defer deferCloser(t, s) s.readSentence(t, "/ip/address @l1 []") s.writeSentence(t, "!re", ".tag=l1", "=address=1.2.3.4/32") s.writeSentence(t, "!done", ".tag=l1") }() listen, err := c.Listen("/ip/address") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) sen := <-listen.Chan() want := "!re @l1 [{`address` `1.2.3.4/32`}]" - if sen.String() != want { - t.Fatalf("/ip/address (%s); want (%s)", sen, want) - } + require.Equal(t, want, sen.String(), "for /ip/address") sen = <-listen.Chan() - if sen != nil { - t.Fatalf("Listen() channel should be closed after EOF; got %#q", sen) - } - err = listen.Err() - if err != nil { - t.Fatal(err) - } + require.Nil(t, sen, "Listener should have been closed after EOF") + require.NoError(t, listen.Err()) } -func TestRunAsync(t *testing.T) { +func TestProtoRunAsync(t *testing.T) { c, s := newPair(t) - defer c.Close() + defer deferCloser(t, c) c.Async() go func() { - defer s.Close() + defer deferCloser(t, s) s.readSentence(t, "/ip/address @r1 []") s.writeSentence(t, "!re", ".tag=r1", "=address=1.2.3.4/32") s.writeSentence(t, "!done", ".tag=r1") }() sen, err := c.Run("/ip/address") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + want := "!re @r1 [{`address` `1.2.3.4/32`}]\n!done @r1 []" - if sen.String() != want { - t.Fatalf("/ip/address (%s); want (%s)", sen, want) - } + require.Equal(t, want, sen.String(), "for /ip/address") } func TestRunEmptySentence(t *testing.T) { c, s := newPair(t) - defer c.Close() + defer deferCloser(t, c) go func() { - defer s.Close() + defer deferCloser(t, s) s.readSentence(t, "/ip/address @ []") s.writeSentence(t) s.writeSentence(t, "!re", "=address=1.2.3.4/32") @@ -212,151 +233,132 @@ func TestRunEmptySentence(t *testing.T) { }() sen, err := c.Run("/ip/address") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + want := "!re @ [{`address` `1.2.3.4/32`}]\n!done @ []" - if sen.String() != want { - t.Fatalf("/ip/address (%s); want (%s)", sen, want) - } + require.Equal(t, want, sen.String(), "for /ip/address") } func TestRunEOF(t *testing.T) { c, s := newPair(t) - defer c.Close() + defer deferCloser(t, c) go func() { - defer s.Close() + defer deferCloser(t, s) s.readSentence(t, "/ip/address @ []") }() _, err := c.Run("/ip/address") - if err == nil { - t.Fatalf("Run succeeded; want error") - } - if err != io.EOF { - t.Fatal(err) - } + require.Error(t, err, "Run succeeded; want error") + require.Truef(t, errors.Is(err, io.EOF), "want=io.EOF, have=%#v", err) } func TestRunEOFAsync(t *testing.T) { c, s := newPair(t) - defer c.Close() + defer deferCloser(t, c) c.Async() go func() { - defer s.Close() + defer deferCloser(t, s) s.readSentence(t, "/ip/address @r1 []") s.writeSentence(t, "!re", "=address=1.2.3.4/32") }() _, err := c.Run("/ip/address") - if err == nil { - t.Fatalf("Run succeeded; want error") - } - if err != io.EOF { - t.Fatal(err) - } + require.Error(t, err, "Run succeeded; want error") + require.Truef(t, errors.Is(err, io.EOF), "want=io.EOF, have=%#v", err) } func TestRunInvalidSentence(t *testing.T) { c, s := newPair(t) - defer c.Close() + defer deferCloser(t, c) go func() { - defer s.Close() + defer deferCloser(t, s) s.readSentence(t, "/ip/address @ []") s.writeSentence(t, "!xxx") }() _, err := c.Run("/ip/address") - if err == nil { - t.Fatalf("Run succeeded; want error") - } - if err.Error() != "unknown RouterOS reply word: !xxx" { - t.Fatal(err) - } + require.Error(t, err, "Run succeeded; want error") + + var unkErr *UnknownReplyError + require.Truef(t, errors.As(err, &unkErr), "want=UnknownReplyError, have=%#v", err) + require.Equal(t, unkErr.Sentence.Word, "!xxx") } func TestRunTrap(t *testing.T) { c, s := newPair(t) - defer c.Close() + defer deferCloser(t, c) go func() { - defer s.Close() + defer deferCloser(t, s) s.readSentence(t, "/ip/address @ []") s.writeSentence(t, "!trap", "=message=Some device error message") s.writeSentence(t, "!done") }() _, err := c.Run("/ip/address") - if err == nil { - t.Fatalf("Run succeeded; want error") - } - if err.Error() != "from RouterOS device: Some device error message" { - t.Fatal(err) - } + require.Error(t, err, "Run succeeded; want error") + + var devErr *DeviceError + require.Truef(t, errors.As(err, &devErr), "want=DeviceError, have=%#v", err) + require.Equal(t, devErr.fetchMessage(), "Some device error message") } func TestRunTrapWithoutMessage(t *testing.T) { c, s := newPair(t) - defer c.Close() + defer deferCloser(t, c) go func() { - defer s.Close() + defer deferCloser(t, s) s.readSentence(t, "/ip/address @ []") s.writeSentence(t, "!trap", "=some=unknown key") s.writeSentence(t, "!done") }() _, err := c.Run("/ip/address") - if err == nil { - t.Fatalf("Run succeeded; want error") - } - if err.Error() != "from RouterOS device: unknown error: !trap @ [{`some` `unknown key`}]" { - t.Fatal(err) - } + require.Error(t, err, "Run succeeded; want error") + + var devErr *DeviceError + require.Truef(t, errors.As(err, &devErr), "want=DeviceError, have=%#v", err) + require.Equal(t, devErr.fetchMessage(), "unknown error: !trap @ [{`some` `unknown key`}]") } func TestRunFatal(t *testing.T) { c, s := newPair(t) - defer c.Close() + defer deferCloser(t, c) go func() { - defer s.Close() + defer deferCloser(t, s) s.readSentence(t, "/ip/address @ []") - s.writeSentence(t, "!fatal", "=message=Some device error message") + s.writeSentence(t, fatalSentence, "=message=Some device error message") }() _, err := c.Run("/ip/address") - if err == nil { - t.Fatalf("Run succeeded; want error") - } - if err.Error() != "from RouterOS device: Some device error message" { - t.Fatal(err) - } + require.Error(t, err, "Run succeeded; want error") + + var devErr *DeviceError + require.Truef(t, errors.As(err, &devErr), "want=DeviceError, have=%#v", err) + require.Equal(t, devErr.fetchMessage(), "Some device error message") } func TestRunAfterClose(t *testing.T) { c, s := newPair(t) - c.Close() - s.Close() + require.NoError(t, c.Close()) + require.NoError(t, s.Close()) _, err := c.Run("/ip/address") - if err == nil { - t.Fatalf("Run succeeded; want error") - } - if err.Error() != "io: read/write on closed pipe" { - t.Fatal(err) - } + require.Error(t, err, "Run succeeded; want error") + require.EqualError(t, err, io.EOF.Error()) } func TestListen(t *testing.T) { c, s := newPair(t) - defer c.Close() + defer deferCloser(t, c) go func() { - defer s.Close() + defer deferCloser(t, s) s.readSentence(t, "/ip/address/listen @l1 []") s.writeSentence(t, "!re", ".tag=l1", "=address=1.2.3.4/32") s.readSentence(t, "/cancel @r2 [{`tag` `l1`}]") @@ -367,27 +369,20 @@ func TestListen(t *testing.T) { c.Queue = 1 listen, err := c.Listen("/ip/address/listen") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + reC := listen.Chan() - listen.Cancel() + _, err = listen.Cancel() + require.Equal(t, err, io.EOF) sen := <-reC want := "!re @l1 [{`address` `1.2.3.4/32`}]" - if sen.String() != want { - t.Fatalf("/ip/address/listen (%s); want (%s)", sen, want) - } + require.Equalf(t, want, sen.String(), "/ip/address/listen (%s); want (%s)", sen, want) sen = <-reC - if sen != nil { - t.Fatalf("Listen() channel should be closed after Close(); got %#q", sen) - } - err = listen.Err() - if err != nil { - t.Fatal(err) - } + require.Nilf(t, sen, "Listen() channel should be closed after Close(); got %#q", sen) + require.NoError(t, listen.Err()) } type conn struct { @@ -396,27 +391,25 @@ type conn struct { } func (c *conn) Close() error { - c.PipeReader.Close() - c.PipeWriter.Close() - return nil + if err := c.PipeReader.Close(); err != nil { + return err + } + + return c.PipeWriter.Close() } -func newPair(t *testing.T) (*routeros.Client, *fakeServer) { +func newPair(t *testing.T) (*Client, *fakeServer) { ar, aw := io.Pipe() br, bw := io.Pipe() - c, err := routeros.NewClient(&conn{ar, bw}) - if err != nil { - t.Fatal(err) - } + c, err := NewClient(&conn{ar, bw}) + require.NoError(t, err) - s := &fakeServer{ + return c, &fakeServer{ proto.NewReader(br), proto.NewWriter(aw), &conn{br, aw}, } - - return c, s } type fakeServer struct { @@ -427,12 +420,8 @@ type fakeServer struct { func (f *fakeServer) readSentence(t *testing.T, want string) { sen, err := f.r.ReadSentence() - if err != nil { - t.Fatal(err) - } - if sen.String() != want { - t.Fatalf("Sentence (%s); want (%s)", sen.String(), want) - } + require.NoError(t, err) + require.Equal(t, want, sen.String(), "wrong sentence") t.Logf("< %s\n", sen) } @@ -442,8 +431,6 @@ func (f *fakeServer) writeSentence(t *testing.T, sentence ...string) { for _, word := range sentence { f.w.WriteWord(word) } - err := f.w.EndSentence() - if err != nil { - t.Fatal(err) - } + + require.NoError(t, f.w.EndSentence()) } diff --git a/reply.go b/reply.go index 5c77993..db2be53 100644 --- a/reply.go +++ b/reply.go @@ -1,10 +1,9 @@ package routeros import ( - "bytes" - "fmt" + "strings" - "github.com/go-routeros/routeros/proto" + "github.com/go-routeros/routeros/v3/proto" ) // Reply has all the sentences from a reply. @@ -14,47 +13,28 @@ type Reply struct { } func (r *Reply) String() string { - b := &bytes.Buffer{} - for _, re := range r.Re { - fmt.Fprintf(b, "%s\n", re) + var sb strings.Builder + for _, sen := range r.Re { + sb.WriteString(sen.String()) + sb.WriteRune('\n') } - fmt.Fprintf(b, "%s", r.Done) - return b.String() -} - -// readReply reads one reply synchronously. It returns the reply. -func (c *Client) readReply() (*Reply, error) { - r := &Reply{} - var lastErr error - for { - sen, err := c.r.ReadSentence() - if err != nil { - return nil, err - } - done, err := r.processSentence(sen) - if err != nil { - if done { - return nil, err - } - - lastErr = err - } - if done { - return r, lastErr - } + if r.Done != nil { + sb.WriteString(r.Done.String()) } + + return sb.String() } func (r *Reply) processSentence(sen *proto.Sentence) (bool, error) { switch sen.Word { - case "!re": + case reSentence: r.Re = append(r.Re, sen) - case "!done": + case doneSentence: r.Done = sen return true, nil - case "!trap", "!fatal": - return sen.Word == "!fatal", &DeviceError{sen} + case trapSentence, fatalSentence: + return sen.Word == fatalSentence, &DeviceError{sen} case "": // API docs say that empty sentences should be ignored default: diff --git a/run.go b/run.go index 56cd950..3f4f6ed 100644 --- a/run.go +++ b/run.go @@ -1,9 +1,11 @@ package routeros import ( + "context" "fmt" + "log/slog" - "github.com/go-routeros/routeros/proto" + "github.com/go-routeros/routeros/v3/proto" ) type asyncReply struct { @@ -12,53 +14,100 @@ type asyncReply struct { } // Run simply calls RunArgs(). -func (c *Client) Run(sentence ...string) (*Reply, error) { - return c.RunArgs(sentence) +func (c *Client) Run(sentences ...string) (*Reply, error) { + return c.RunArgs(sentences) +} + +// RunContext simply calls RunArgsContext(). +func (c *Client) RunContext(ctx context.Context, sentences ...string) (*Reply, error) { + return c.RunArgsContext(ctx, sentences) } // RunArgs sends a sentence to the RouterOS device and waits for the reply. -func (c *Client) RunArgs(sentence []string) (*Reply, error) { +func (c *Client) RunArgs(sentences []string) (*Reply, error) { + return c.RunArgsContext(context.Background(), sentences) +} + +// RunArgsContext sends a sentence to the RouterOS device and waits for the reply. +func (c *Client) RunArgsContext(ctx context.Context, sentences []string) (*Reply, error) { + c.logger().Debug("RunArgsContext", slog.Any("sentences", sentences)) + c.w.BeginSentence() - for _, word := range sentence { - c.w.WriteWord(word) - } - if !c.async { - return c.endCommandSync() - } - a, err := c.endCommandAsync() - if err != nil { - return nil, err - } - for range a.reC { + for _, sentence := range sentences { + c.w.WriteWord(sentence) } - return &a.Reply, a.err -} -func (c *Client) endCommandSync() (*Reply, error) { - err := c.w.EndSentence() - if err != nil { - return nil, err + if !c.IsAsync() { + return c.runArgsContextSync() } - return c.readReply() -} -func (c *Client) endCommandAsync() (*asyncReply, error) { - c.nextTag++ + // async mode, assign new tag to request + tag := c.incrementTag() + a := &asyncReply{} a.reC = make(chan *proto.Sentence) - a.tag = fmt.Sprintf("r%d", c.nextTag) + a.tag = fmt.Sprintf("r%d", tag) c.w.WriteWord(".tag=" + a.tag) - - c.mu.Lock() - defer c.mu.Unlock() - - err := c.w.EndSentence() - if err != nil { + c.logger().Debug("set tag", slog.String("tag", a.tag)) + if err := c.w.EndSentence(); err != nil { return nil, err } + + c.mu.Lock() if c.tags == nil { + c.mu.Unlock() + return nil, errAsyncLoopEnded } + c.tags[a.tag] = a - return a, nil + c.mu.Unlock() + + // wait for asyncLoop to close channel or context done + for { + select { + case <-ctx.Done(): + c.r.Cancel() + + return nil, ctx.Err() + case _, ok := <-a.reC: + if !ok { // channel closed + return &a.Reply, a.err + } + } + } +} + +// runArgsContextSync - read command reply in sync mode and return +func (c *Client) runArgsContextSync() (*Reply, error) { + var err error + if err = c.w.EndSentence(); err != nil { + return nil, err + } + + out := new(Reply) + + var lastErr error + for { + var sen *proto.Sentence + + // read next sentence + if sen, err = c.r.ReadSentence(); err != nil { + return nil, err + } + + var done bool + + switch done, err = out.processSentence(sen); { + case err != nil && done: + // processed error sentence and it was fatal + return nil, err + case err != nil: + // processed error sentence, but it was not fatal, read next, store last error + lastErr = err + case done: + // processed sentence is Done, return result and last error + return out, lastErr + } + } }