diff --git a/cmd/snapctl/commands.go b/cmd/snapctl/commands.go index 5a0bd2456..6566e1223 100644 --- a/cmd/snapctl/commands.go +++ b/cmd/snapctl/commands.go @@ -109,6 +109,17 @@ var ( flPluginVersion, }, }, + { + Name: "swap", + Usage: "swap :: or swap -t -n -v ", + Action: swapPlugins, + Flags: []cli.Flag{ + flPluginAsc, + flPluginType, + flPluginName, + flPluginVersion, + }, + }, { Name: "list", Usage: "list", diff --git a/cmd/snapctl/plugin.go b/cmd/snapctl/plugin.go index ffb6c8797..411e72cd1 100644 --- a/cmd/snapctl/plugin.go +++ b/cmd/snapctl/plugin.go @@ -116,6 +116,84 @@ func unloadPlugin(ctx *cli.Context) { fmt.Printf("Type: %s\n", r.Type) } +func swapPlugins(ctx *cli.Context) { + // plugin to load + pAsc := ctx.String("plugin-asc") + var paths []string + if len(ctx.Args()) < 1 || len(ctx.Args()) > 2 { + fmt.Println("Incorrect usage:") + cli.ShowCommandHelp(ctx, ctx.Command.Name) + os.Exit(1) + } + paths = append(paths, ctx.Args().First()) + if pAsc != "" { + if !strings.Contains(pAsc, ".asc") { + fmt.Println("Must be a .asc file for the -a flag") + cli.ShowCommandHelp(ctx, ctx.Command.Name) + os.Exit(1) + } + paths = append(paths, pAsc) + } + + // plugin to unload + var pDetails []string + var pType string + var pName string + var pVer int + var err error + + if len(ctx.Args()) == 2 { + pDetails = filepath.SplitList(ctx.Args()[1]) + if len(pDetails) == 3 { + pType = pDetails[0] + pName = pDetails[1] + pVer, err = strconv.Atoi(pDetails[2]) + if err != nil { + fmt.Println("Can't convert version string to integer") + cli.ShowCommandHelp(ctx, ctx.Command.Name) + os.Exit(1) + } + } + } else { + pType = ctx.String("plugin-type") + pName = ctx.String("plugin-name") + pVer = ctx.Int("plugin-version") + } + if pType == "" { + fmt.Println("Must provide plugin type") + cli.ShowCommandHelp(ctx, ctx.Command.Name) + os.Exit(1) + } + if pName == "" { + fmt.Println("Must provide plugin name") + cli.ShowCommandHelp(ctx, ctx.Command.Name) + os.Exit(1) + } + if pVer < 1 { + fmt.Println("Must provide plugin version") + cli.ShowCommandHelp(ctx, ctx.Command.Name) + os.Exit(1) + } + + r := pClient.SwapPlugin(paths, pType, pName, pVer) + if r.Err != nil { + fmt.Printf("Error swapping plugins:\n%v\n", r.Err.Error()) + os.Exit(1) + } + + fmt.Println("Plugin loaded") + fmt.Printf("Name: %s\n", r.LoadedPlugin.Name) + fmt.Printf("Version: %d\n", r.LoadedPlugin.Version) + fmt.Printf("Type: %s\n", r.LoadedPlugin.Type) + fmt.Printf("Signed: %v\n", r.LoadedPlugin.Signed) + fmt.Printf("Loaded Time: %s\n\n", r.LoadedPlugin.LoadedTime().Format(timeFormat)) + + fmt.Println("\nPlugin unloaded") + fmt.Printf("Name: %s\n", r.UnloadedPlugin.Name) + fmt.Printf("Version: %d\n", r.UnloadedPlugin.Version) + fmt.Printf("Type: %s\n", r.UnloadedPlugin.Type) +} + func listPlugins(ctx *cli.Context) { plugins := pClient.GetPlugins(ctx.Bool("running")) if plugins.Err != nil { diff --git a/mgmt/rest/client/client_func_test.go b/mgmt/rest/client/client_func_test.go index 03d4e697a..f26b153f0 100644 --- a/mgmt/rest/client/client_func_test.go +++ b/mgmt/rest/client/client_func_test.go @@ -202,6 +202,42 @@ func TestSnapClient(t *testing.T) { }) }) + Convey("SwapPlugins", t, func() { + Convey("Swap with different types should fail", func() { + sp := c.SwapPlugin(FILE_PLUGIN_PATH, p1.LoadedPlugins[0].Type, p1.LoadedPlugins[0].Name, p1.LoadedPlugins[0].Version) + So(sp.Err, ShouldNotBeNil) + lps := c.GetPlugins(false) + So(len(lps.LoadedPlugins), ShouldEqual, 1) + }) + Convey("Swap with same plugin should fail", func() { + sp := c.SwapPlugin(MOCK_PLUGIN_PATH1, p1.LoadedPlugins[0].Type, p1.LoadedPlugins[0].Name, p1.LoadedPlugins[0].Version) + So(sp.Err, ShouldNotBeNil) + lps := c.GetPlugins(false) + So(len(lps.LoadedPlugins), ShouldEqual, 1) + }) + Convey("Swap with plugins with the same type and name", func() { + sp := c.SwapPlugin(MOCK_PLUGIN_PATH2, p1.LoadedPlugins[0].Type, p1.LoadedPlugins[0].Name, p1.LoadedPlugins[0].Version) + So(sp.Err, ShouldBeNil) + lps := c.GetPlugins(false) + So(len(lps.LoadedPlugins), ShouldEqual, 1) + So(lps.LoadedPlugins[0].Type, ShouldEqual, "collector") + So(lps.LoadedPlugins[0].Name, ShouldEqual, "mock") + So(lps.LoadedPlugins[0].Type, ShouldEqual, p1.LoadedPlugins[0].Type) + So(lps.LoadedPlugins[0].Name, ShouldEqual, p1.LoadedPlugins[0].Name) + So(lps.LoadedPlugins[0].Version, ShouldNotEqual, p1.LoadedPlugins[0].Version) + + sp2 := c.SwapPlugin(MOCK_PLUGIN_PATH1, sp.LoadedPlugin.Type, sp.LoadedPlugin.Name, sp.LoadedPlugin.Version) + So(sp2.Err, ShouldBeNil) + lps2 := c.GetPlugins(false) + So(len(lps.LoadedPlugins), ShouldEqual, 1) + So(lps2.LoadedPlugins[0].Type, ShouldEqual, "collector") + So(lps2.LoadedPlugins[0].Name, ShouldEqual, "mock") + So(lps2.LoadedPlugins[0].Type, ShouldEqual, sp.LoadedPlugin.Type) + So(lps2.LoadedPlugins[0].Name, ShouldEqual, sp.LoadedPlugin.Name) + So(lps2.LoadedPlugins[0].Version, ShouldNotEqual, sp.LoadedPlugin.Version) + }) + }) + if cerr == nil { p2 = c.LoadPlugin(MOCK_PLUGIN_PATH2) } diff --git a/mgmt/rest/client/plugin.go b/mgmt/rest/client/plugin.go index eec524c7a..e4f32a6a7 100644 --- a/mgmt/rest/client/plugin.go +++ b/mgmt/rest/client/plugin.go @@ -20,6 +20,7 @@ limitations under the License. package client import ( + "errors" "fmt" "net/url" "time" @@ -78,6 +79,51 @@ func (c *Client) UnloadPlugin(pluginType, name string, version int) *UnloadPlugi return r } +// SwapPlugin swaps two plugins with the same type and name e.g. collector:mock:1 with collector:mock:2 +func (c *Client) SwapPlugin(loadPath []string, unloadType, unloadName string, unloadVersion int) *SwapPluginsResult { + r := &SwapPluginsResult{} + + if len(loadPath) != 1 { + r.Err = errors.New("Did not receive only one plugin to load.") + return r + } + // Load plugin + lp := c.LoadPlugin(loadPath) + if lp.Err != nil { + r.Err = errors.New(lp.Err.Error()) + return r + } + if len(lp.LoadedPlugins) != 1 { + r.Err = errors.New("There is not just one plugin to be loaded") + } + lpr := lp.LoadedPlugins[0].LoadedPlugin + + // Make sure both plugins have the same type and name before unloading. If not, rollback. + if lpr.Type != unloadType || lpr.Name != unloadName { + up := c.UnloadPlugin(lpr.Type, lpr.Name, lpr.Version) + if up.Err != nil { + r.Err = errors.New("Plugins do not have the same type and name. Failed to rollback after error.") + return r + } + r.Err = errors.New("Plugins don't have the same type and name.") + return r + } + // Unload plugin + up := c.UnloadPlugin(unloadType, unloadName, unloadVersion) + if up.Err != nil { + r.Err = up.Err + up2 := c.UnloadPlugin(lpr.Type, lpr.Name, lpr.Version) + if up2.Err != nil { + r.Err = errors.New("Failed to rollback after error unloading plugin.") + } + return r + } + upr := up.PluginUnloaded + r.LoadedPlugin = lp.LoadedPlugins[0] + r.UnloadedPlugin = upr + return r +} + // GetPlugins returns the loaded and available plugins through an HTTP GET request. // By specifying the details flag to tweak output info. An error returns if it failed. func (c *Client) GetPlugins(details bool) *GetPluginsResult { @@ -131,6 +177,12 @@ type UnloadPluginResult struct { Err error } +type SwapPluginsResult struct { + LoadedPlugin LoadedPlugin + UnloadedPlugin *rbody.PluginUnloaded + Err error +} + // We wrap this so we can provide some functionality (like LoadedTime) type LoadedPlugin struct { *rbody.LoadedPlugin