Skip to content

Commit

Permalink
feat: Cleanup processes (#20)
Browse files Browse the repository at this point in the history
Closes #1
  • Loading branch information
mikew authored Oct 12, 2024
1 parent b830a8a commit 944906f
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 68 deletions.
204 changes: 155 additions & 49 deletions src/client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"math/rand/v2"
"os"
"os/exec"
"os/signal"
"path"
"runtime"
"strings"
Expand All @@ -15,6 +16,7 @@ import (
"github.com/urfave/cli/v2"

"nvrh/src/context"
"nvrh/src/exec_helpers"
"nvrh/src/nvim_helpers"
"nvrh/src/ssh_helpers"
)
Expand Down Expand Up @@ -55,6 +57,7 @@ var CliClientOpenCommand = cli.Command{
},

Action: func(c *cli.Context) error {
// Prepare the context.
sessionId := fmt.Sprintf("%d", time.Now().Unix())
nvrhContext := context.NvrhContext{
SessionId: sessionId,
Expand All @@ -72,52 +75,153 @@ var CliClientOpenCommand = cli.Command{
BrowserScriptPath: fmt.Sprintf("/tmp/nvrh-browser-%s", sessionId),
}

// TODO Could really use a context to pass around instead of a bunch of
// args.
// socketPath := fmt.Sprintf("/tmp/nvrh-socket-%d", sessionId)
if nvrhContext.ShouldUsePorts {
min := 1025
max := 65535
nvrhContext.PortNumber = rand.IntN((max - min) + min)
nvrhContext.PortNumber = rand.IntN(max-min) + min
}

if nvrhContext.Server == "" {
return fmt.Errorf("<server> is required")
}

go ssh_helpers.StartRemoteNvim(nvrhContext)
var nv *nvim.Nvim

doneChan := make(chan error)

signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, os.Interrupt)

// Prepare remote instance.
go func() {
remoteCmd := ssh_helpers.BuildRemoteNvimCmd(&nvrhContext)
nvrhContext.CommandsToKill = append(nvrhContext.CommandsToKill, remoteCmd)

// We don't want the ssh process ending too early, if it does we can't
// clean up the remote nvim instance.
exec_helpers.PrepareForForking(remoteCmd)

if err := remoteCmd.Run(); err != nil {
log.Printf("Error running ssh: %v", err)
} else {
log.Printf("Remote nvim exited")
}
}()

// Prepare client instance.
nvChan := make(chan *nvim.Nvim, 1)
go func() {
nv, err := nvim_helpers.WaitForNvim(nvrhContext)
nv, err := nvim_helpers.WaitForNvim(&nvrhContext)

if err != nil {
log.Printf("Error connecting to nvim: %v", err)
return
}

defer func() {
log.Print("Closing nvim")
nv.Close()
}()
log.Print("Connected to nvim")
nvChan <- nv

nv.RegisterHandler("tunnel-port", ssh_helpers.MakeRpcTunnelHandler(nvrhContext.Server))
nv.RegisterHandler("open-url", RpcHandleOpenUrl)
if err := prepareRemoteNvim(&nvrhContext, nv); err != nil {
log.Printf("Error preparing remote nvim: %v", err)
}

clientCmd := BuildClientNvimCmd(&nvrhContext)
nvrhContext.CommandsToKill = append(nvrhContext.CommandsToKill, clientCmd)

if err := clientCmd.Run(); err != nil {
log.Printf("Error running local editor: %v", err)
doneChan <- err
} else {
log.Printf("Local editor exited")
doneChan <- nil
}
}()

nv = <-nvChan

go func() {
select {
case sig := <-signalChan:
doneChan <- fmt.Errorf("Received signal: %s", sig)
}
}()

batch := nv.NewBatch()
err := <-doneChan

// Let nvim know the channel id so it can send us messages.
batch.Command(fmt.Sprintf(`let $NVRH_CHANNEL_ID="%d"`, nv.ChannelID()))
// Set $BROWSER so the remote machine can open a browser locally.
batch.Command(fmt.Sprintf(`let $BROWSER="%s"`, nvrhContext.BrowserScriptPath))
log.Printf("Closing nvrh")
closeNvimSocket(nv)
killAllCmds(nvrhContext.CommandsToKill)

// Add command to tunnel port.
batch.Command("command! -nargs=1 NvrhTunnelPort call rpcnotify(str2nr($NVRH_CHANNEL_ID), 'tunnel-port', [<f-args>])")
// Add command to open url.
batch.Command("command! -nargs=1 NvrhOpenUrl call rpcnotify(str2nr($NVRH_CHANNEL_ID), 'open-url', [<f-args>])")
if err != nil {
return err
}

return nil
},
}

func BuildClientNvimCmd(nvrhContext *context.NvrhContext) *exec.Cmd {
replacedArgs := make([]string, len(nvrhContext.LocalEditor))
for i, arg := range nvrhContext.LocalEditor {
replacedArgs[i] = strings.Replace(arg, "{{SOCKET_PATH}}", nvrhContext.LocalSocketOrPort(), -1)
}

log.Printf("Starting local editor: %v", replacedArgs)

editorCommand := exec.Command(replacedArgs[0], replacedArgs[1:]...)
if replacedArgs[0] == "nvim" {
editorCommand.Stdin = os.Stdin
editorCommand.Stdout = os.Stdout
editorCommand.Stderr = os.Stderr
}

// Prepare the browser script.
var output any
batch.ExecLua(`
return editorCommand
}

func prepareRemoteNvim(nvrhContext *context.NvrhContext, nv *nvim.Nvim) error {
nv.RegisterHandler("tunnel-port", ssh_helpers.MakeRpcTunnelHandler(nvrhContext))
nv.RegisterHandler("open-url", RpcHandleOpenUrl)

batch := nv.NewBatch()

// Let nvim know the channel id so it can send us messages.
batch.Command(fmt.Sprintf(`let $NVRH_CHANNEL_ID="%d"`, nv.ChannelID()))
// Set $BROWSER so the remote machine can open a browser locally.
batch.Command(fmt.Sprintf(`let $BROWSER="%s"`, nvrhContext.BrowserScriptPath))

// Add command to tunnel port.
// TODO use `vim.api.nvim_create_user_command`, and check to see if the
// port is already mapped somehow.
batch.ExecLua(`
vim.api.nvim_create_user_command(
'NvrhTunnelPort',
function(args)
vim.rpcnotify(tonumber(os.getenv('NVRH_CHANNEL_ID')), 'tunnel-port', { args.args })
end,
{
nargs = 1,
force = true,
}
)
return true
`, nil, nil)

// Add command to open url.
batch.ExecLua(`
vim.api.nvim_create_user_command(
'NvrhOpenUrl',
function(args)
vim.rpcnotify(tonumber(os.getenv('NVRH_CHANNEL_ID')), 'open-url', { args.args })
end,
{
nargs = 1,
force = true,
}
)
`, nil, nil)

// Prepare the browser script.
batch.ExecLua(`
local browser_script_path, socket_path, channel_id = ...
local script_contents = [[
Expand All @@ -134,39 +238,18 @@ vim.fn.writefile(vim.fn.split(script_contents, '\n'), browser_script_path)
os.execute('chmod +x ' .. browser_script_path)
return true
`, &output, nvrhContext.BrowserScriptPath, nvrhContext.LocalSocketOrPort(), nv.ChannelID())
`, nil, nvrhContext.BrowserScriptPath, nvrhContext.LocalSocketOrPort(), nv.ChannelID())

if err := batch.Execute(); err != nil {
log.Fatalf("Error while preparing remote nvim: %v", err)
}

log.Print("Connected to nvim")
startLocalEditor(nvrhContext)
}()

select {}
},
}

func startLocalEditor(nvrhContext context.NvrhContext) {
replacedArgs := make([]string, len(nvrhContext.LocalEditor))
for i, arg := range nvrhContext.LocalEditor {
replacedArgs[i] = strings.Replace(arg, "{{SOCKET_PATH}}", nvrhContext.LocalSocketOrPort(), -1)
}

log.Printf("Starting local editor: %v", replacedArgs)

editorCommand := exec.Command(replacedArgs[0], replacedArgs[1:]...)
if replacedArgs[0] == "nvim" {
editorCommand.Stdin = os.Stdin
editorCommand.Stdout = os.Stdout
editorCommand.Stderr = os.Stderr
if err := batch.Execute(); err != nil {
return err
}

if err := editorCommand.Run(); err != nil {
log.Printf("Error running editor: %v", err)
return
}
return nil
}

func RpcHandleOpenUrl(v *nvim.Nvim, args []string) {
Expand All @@ -190,3 +273,26 @@ func RpcHandleOpenUrl(v *nvim.Nvim, args []string) {
log.Printf("Don't know how to open url on %s", goos)
}
}

func killAllCmds(cmds []*exec.Cmd) {
for _, cmd := range cmds {
log.Printf("Killing command: %v", cmd)
if cmd.Process != nil {
if err := cmd.Process.Kill(); err != nil {
log.Printf("Error killing command: %v", err)
}
}
}
}

func closeNvimSocket(nv *nvim.Nvim) {
if nv == nil {
return
}

log.Print("Closing nvim")
if err := nv.ExecLua("vim.cmd('qall!')", nil, nil); err != nil {
log.Printf("Error closing remote nvim: %v", err)
}
nv.Close()
}
3 changes: 3 additions & 0 deletions src/context/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package context

import (
"fmt"
"os/exec"
)

type NvrhContext struct {
Expand All @@ -18,6 +19,8 @@ type NvrhContext struct {
LocalEditor []string

BrowserScriptPath string

CommandsToKill []*exec.Cmd
}

func (nc NvrhContext) LocalSocketOrPort() string {
Expand Down
15 changes: 15 additions & 0 deletions src/exec_helpers/unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
//go:build !windows
// +build !windows

package exec_helpers

import (
"os/exec"
"syscall"
)

func PrepareForForking(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
Setpgid: true,
}
}
14 changes: 14 additions & 0 deletions src/exec_helpers/windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//go:build windows
// +build windows

package exec_helpers

import (
"os/exec"
)

func PrepareForForking(cmd *exec.Cmd) {
// cmd.SysProcAttr = &syscall.SysProcAttr{
// Setpgid: true,
// }
}
2 changes: 1 addition & 1 deletion src/nvim_helpers/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"nvrh/src/context"
)

func WaitForNvim(nvrhContext context.NvrhContext) (*nvim.Nvim, error) {
func WaitForNvim(nvrhContext *context.NvrhContext) (*nvim.Nvim, error) {
for {
nv, err := nvim.Dial(nvrhContext.LocalSocketOrPort())

Expand Down
29 changes: 11 additions & 18 deletions src/ssh_helpers/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import (
"github.com/neovim/go-client/nvim"
)

func StartRemoteNvim(nvrhContext context.NvrhContext) {
nvimCommand := buildRemoteCommand(nvrhContext)
log.Printf("Starting remote nvim: %s", nvimCommand)
func BuildRemoteNvimCmd(nvrhContext *context.NvrhContext) *exec.Cmd {
nvimCommandString := buildRemoteCommandString(nvrhContext)
log.Printf("Starting remote nvim: %s", nvimCommandString)

tunnel := fmt.Sprintf("%s:%s", nvrhContext.LocalSocketPath, nvrhContext.RemoteSocketPath)
if nvrhContext.ShouldUsePorts {
Expand All @@ -29,7 +29,7 @@ func StartRemoteNvim(nvrhContext context.NvrhContext) {
nvrhContext.Server,
// TODO Not really sure if this is better than piping it as exampled
// below.
fmt.Sprintf("$SHELL -i -c '%s'", nvimCommand),
fmt.Sprintf("$SHELL -i -c '%s'", nvimCommandString),
)

if runtime.GOOS == "windows" {
Expand All @@ -53,19 +53,10 @@ func StartRemoteNvim(nvrhContext context.NvrhContext) {
// Close the pipe after writing
// stdinPipe.Close()

if err := sshCommand.Start(); err != nil {
log.Printf("Error starting command: %v", err)
return
}

defer sshCommand.Process.Kill()

if err := sshCommand.Wait(); err != nil {
log.Printf("Error waiting for command: %v", err)
}
return sshCommand
}

func buildRemoteCommand(nvrhContext context.NvrhContext) string {
func buildRemoteCommandString(nvrhContext *context.NvrhContext) string {
envPairsString := ""
if len(nvrhContext.RemoteEnv) > 0 {
envPairsString = strings.Join(nvrhContext.RemoteEnv, " ")
Expand All @@ -79,18 +70,20 @@ func buildRemoteCommand(nvrhContext context.NvrhContext) string {
)
}

func MakeRpcTunnelHandler(server string) func(*nvim.Nvim, []string) {
func MakeRpcTunnelHandler(nvrhContext *context.NvrhContext) func(*nvim.Nvim, []string) {
return func(v *nvim.Nvim, args []string) {
go func() {
log.Printf("Tunneling %s:%s", server, args[0])
log.Printf("Tunneling %s:%s", nvrhContext.Server, args[0])

sshCommand := exec.Command(
"ssh",
"-NL",
fmt.Sprintf("%s:0.0.0.0:%s", args[0], args[0]),
server,
nvrhContext.Server,
)

nvrhContext.CommandsToKill = append(nvrhContext.CommandsToKill, sshCommand)

if err := sshCommand.Start(); err != nil {
log.Printf("Error starting command: %v", err)
return
Expand Down

0 comments on commit 944906f

Please sign in to comment.