Skip to content

Commit

Permalink
ExecInfo prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
sourishkrout committed Jul 16, 2024
1 parent 444fdfb commit 010ad65
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 36 deletions.
17 changes: 17 additions & 0 deletions internal/command/exec_info.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package command

import "context"

type runnerContextKey struct{}

var ExecutionInfoKey = &runnerContextKey{}

type ExecutionInfo struct {
RunID string
KnownName string
KnownID string
}

func ContextWithExecutionInfo(ctx context.Context, execInfo *ExecutionInfo) context.Context {
return context.WithValue(ctx, ExecutionInfoKey, execInfo)
}
18 changes: 15 additions & 3 deletions internal/owl/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package owl

import (
"bytes"
"context"
"encoding/json"
"fmt"
"slices"
Expand All @@ -13,6 +14,7 @@ import (

"github.com/graphql-go/graphql"
"github.com/stateful/godotenv"
commandpkg "github.com/stateful/runme/v3/internal/command"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -521,17 +523,27 @@ func (s *Store) SensitiveKeys() ([]string, error) {
return keys, nil
}

func (s *Store) Update(newOrUpdated, deleted []string) error {
func (s *Store) Update(context context.Context, newOrUpdated, deleted []string) error {
s.mu.Lock()
defer s.mu.Unlock()

execInfo, ok := context.Value(commandpkg.ExecutionInfoKey).(*commandpkg.ExecutionInfo)
if !ok {
return errors.New("execution info not found in context")
}

execRef := fmt.Sprintf("#%s", execInfo.KnownID)
if execInfo.KnownName != "" {
execRef = fmt.Sprintf("#%s", execInfo.KnownName)
}

if len(newOrUpdated) > 0 {
updateOpSet, err := NewOperationSet(WithOperation(UpdateSetOperation), WithSpecs(false))
if err != nil {
return err
}

err = updateOpSet.addEnvs("[execution]", newOrUpdated...)
err = updateOpSet.addEnvs(execRef, newOrUpdated...)
if err != nil {
return err
}
Expand All @@ -545,7 +557,7 @@ func (s *Store) Update(newOrUpdated, deleted []string) error {
return err
}

err = deleteOpSet.addEnvs("[execution]", deleted...)
err = deleteOpSet.addEnvs(execRef, deleted...)
if err != nil {
return err
}
Expand Down
12 changes: 7 additions & 5 deletions internal/runner/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ type command struct {

tempScriptFile string

wg sync.WaitGroup
mu sync.Mutex
err error
context context.Context
wg sync.WaitGroup
mu sync.Mutex
err error

logger *zap.Logger
}
Expand Down Expand Up @@ -103,7 +104,7 @@ type commandConfig struct {
Logger *zap.Logger
}

func newCommand(cfg *commandConfig) (*command, error) {
func newCommand(context context.Context, cfg *commandConfig) (*command, error) {
var pathEnv string

// If PATH is set in the session, use it in the system
Expand Down Expand Up @@ -244,6 +245,7 @@ func newCommand(cfg *commandConfig) (*command, error) {
args = append(args, cfg.Args...)

cmd := &command{
context: context,
ProgramPath: programPath,
Args: append(args, extraArgs...),
Directory: directory,
Expand Down Expand Up @@ -501,7 +503,7 @@ func (c *command) collectEnvs() {
newEnvStore(endEnvs...),
)

err = c.Session.UpdateStore(c.cmd.Env, newOrUpdated, deleted)
err = c.Session.UpdateStore(c.context, c.cmd.Env, newOrUpdated, deleted)
c.seterr(err)
}

Expand Down
16 changes: 16 additions & 0 deletions internal/runner/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func Test_command(t *testing.T) {
stderr := new(bytes.Buffer)

cmd, err := newCommand(
context.Background(),
&commandConfig{
ProgramName: "bash",
Stdout: stdout,
Expand Down Expand Up @@ -59,6 +60,7 @@ func Test_command(t *testing.T) {
stderr := new(bytes.Buffer)

cmd, err := newCommand(
context.Background(),
&commandConfig{
ProgramName: "bash",
Stdout: stdout,
Expand Down Expand Up @@ -86,6 +88,7 @@ func Test_command(t *testing.T) {
stderr := new(bytes.Buffer)

cmd, err := newCommand(
context.Background(),
&commandConfig{
ProgramName: "",
LanguageID: "shellscript",
Expand Down Expand Up @@ -114,6 +117,7 @@ func Test_command(t *testing.T) {
stderr := new(bytes.Buffer)

cmd, err := newCommand(
context.Background(),
&commandConfig{
ProgramName: "",
LanguageID: "js",
Expand Down Expand Up @@ -142,6 +146,7 @@ func Test_command(t *testing.T) {
stderr := new(bytes.Buffer)

cmd, err := newCommand(
context.Background(),
&commandConfig{
ProgramName: "/usr/bin/env node",
LanguageID: "js",
Expand Down Expand Up @@ -170,6 +175,7 @@ func Test_command(t *testing.T) {
stderr := new(bytes.Buffer)

cmd, err := newCommand(
context.Background(),
&commandConfig{
ProgramName: "",
LanguageID: "sql",
Expand Down Expand Up @@ -200,6 +206,7 @@ func Test_command(t *testing.T) {
stderr := new(bytes.Buffer)

cmd, err := newCommand(
context.Background(),
&commandConfig{
ProgramName: "bash",
Tty: true,
Expand Down Expand Up @@ -231,6 +238,7 @@ func Test_command(t *testing.T) {
stderr := new(bytes.Buffer)

cmd, err := newCommand(
context.Background(),
&commandConfig{
ProgramName: "bash",
Tty: true,
Expand Down Expand Up @@ -266,6 +274,7 @@ func Test_command(t *testing.T) {
_, _ = stdin.WriteString("hello")

cmd, err := newCommand(
context.Background(),
&commandConfig{
ProgramName: "bash",
Stdin: stdin,
Expand Down Expand Up @@ -295,6 +304,7 @@ func Test_command(t *testing.T) {
stderr := new(bytes.Buffer)

cmd, err := newCommand(
context.Background(),
&commandConfig{
ProgramName: "bash",
Tty: true,
Expand Down Expand Up @@ -337,6 +347,7 @@ func Test_command(t *testing.T) {
_, _ = stdin.WriteString("hello")

cmd, err := newCommand(
context.Background(),
&commandConfig{
ProgramName: "bash",
Stdin: stdin,
Expand Down Expand Up @@ -371,6 +382,7 @@ func Test_command(t *testing.T) {
stderr := new(bytes.Buffer)

cmd, err := newCommand(
context.Background(),
&commandConfig{
ProgramName: "bash",
Tty: true,
Expand Down Expand Up @@ -427,6 +439,7 @@ func Test_command(t *testing.T) {
require.NoError(t, err)

cmd, err := newCommand(
context.Background(),
&commandConfig{
ProgramName: "bash",
Session: session,
Expand Down Expand Up @@ -473,6 +486,7 @@ func Test_command(t *testing.T) {
stdin := new(bytes.Buffer)

cmd, err := newCommand(
context.Background(),
&commandConfig{
ProgramName: "bash",
Stdout: stdout,
Expand Down Expand Up @@ -502,6 +516,7 @@ func Test_command_Stop(t *testing.T) {
t.Parallel()

cmd, err := newCommand(
context.Background(),
&commandConfig{
ProgramName: "bash",
Stdin: bytes.NewBuffer(nil),
Expand Down Expand Up @@ -529,6 +544,7 @@ func Test_command_Stop(t *testing.T) {

func Test_exitCodeFromErr(t *testing.T) {
cmd, err := newCommand(
context.Background(),
&commandConfig{
ProgramName: "bash",
Tty: true,
Expand Down
22 changes: 15 additions & 7 deletions internal/runner/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ func ConvertRunnerProject(runnerProj *runnerv1.Project) (*project.Project, error
}

func (r *runnerService) Execute(srv runnerv1.RunnerService_ExecuteServer) error {
logger := r.logger.With(zap.String("_id", ulid.GenerateID()))
_id := ulid.GenerateID()
logger := r.logger.With(zap.String("_id", _id))

logger.Info("running Execute in runnerService")

Expand All @@ -221,6 +222,13 @@ func (r *runnerService) Execute(srv runnerv1.RunnerService_ExecuteServer) error
return errors.WithStack(err)
}

execInfo := &commandpkg.ExecutionInfo{
RunID: _id,
KnownName: req.GetKnownName(),
KnownID: req.GetKnownId(),
}
ctx := commandpkg.ContextWithExecutionInfo(srv.Context(), execInfo)

// We want to always log the request because it is used for AI training.
// see: https://github.com/stateful/runme/issues/574
if req.KnownId != "" {
Expand Down Expand Up @@ -255,7 +263,7 @@ func (r *runnerService) Execute(srv runnerv1.RunnerService_ExecuteServer) error
}

if len(req.Envs) > 0 {
err := sess.AddEnvs(req.Envs)
err := sess.AddEnvs(ctx, req.Envs)
if err != nil {
return err
}
Expand Down Expand Up @@ -304,7 +312,7 @@ func (r *runnerService) Execute(srv runnerv1.RunnerService_ExecuteServer) error
}

logger.Debug("command config", zap.Any("cfg", cfg))
cmd, err := newCommand(cfg)
cmd, err := newCommand(ctx, cfg)
if err != nil {
var errInvalidLanguage ErrInvalidLanguage
if errors.As(err, &errInvalidLanguage) {
Expand Down Expand Up @@ -343,10 +351,10 @@ func (r *runnerService) Execute(srv runnerv1.RunnerService_ExecuteServer) error
return err
}

cmdCtx := srv.Context()
cmdCtx := ctx

if req.Background {
cmdCtx = context.Background()
cmdCtx = commandpkg.ContextWithExecutionInfo(context.Background(), execInfo)
}

if err := cmd.StartWithOpts(cmdCtx, &startOpts{}); err != nil {
Expand Down Expand Up @@ -518,14 +526,14 @@ func (r *runnerService) Execute(srv runnerv1.RunnerService_ExecuteServer) error
}

if storeStdout {
err := sess.SetEnv("__", string(stdoutMem))
err := sess.SetEnv(ctx, "__", string(stdoutMem))
if err != nil {
logger.Sugar().Errorf("%v", err)
}

knownName := req.GetKnownName()
if knownName != "" && runnerConformsOpinionatedEnvVarNaming(knownName) {
err = sess.SetEnv(knownName, string(stdoutMem))
err = sess.SetEnv(ctx, knownName, string(stdoutMem))
if err != nil {
logger.Warn("failed to set env", zap.Error(err))
}
Expand Down
Loading

0 comments on commit 010ad65

Please sign in to comment.