Skip to content

Commit

Permalink
[#827] executor/docker: Fix: command arguments are not evaluated (#832)
Browse files Browse the repository at this point in the history
  • Loading branch information
yohamta authored Feb 12, 2025
1 parent c449058 commit 33c93f4
Show file tree
Hide file tree
Showing 16 changed files with 1,001 additions and 633 deletions.
15 changes: 15 additions & 0 deletions cmd/stop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"time"

"github.com/dagu-org/dagu/internal/digraph/scheduler"
"github.com/stretchr/testify/require"
)

func TestStopCommand(t *testing.T) {
Expand All @@ -31,6 +32,20 @@ func TestStopCommand(t *testing.T) {
args: []string{"stop", dagFile.Location},
expectedOut: []string{"DAG stopped"}})

// Log the status of the DAG.
go func() {
for {
select {
case <-time.After(time.Millisecond * 100):
status, err := th.Client.GetLatestStatus(th.Context, dagFile.DAG)
require.NoError(t, err)
t.Logf("status: %s, started: %s, finished: %s", status.Status, status.StartedAt, status.FinishedAt)
case <-done:
return
}
}
}()

// Check the DAG is stopped.
dagFile.AssertLatestStatus(t, scheduler.StatusCancel)
<-done
Expand Down
4 changes: 4 additions & 0 deletions internal/digraph/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,10 @@ func (th *DAG) AssertEnv(t *testing.T, key, val string) {
}
}
t.Errorf("expected env %s=%s not found", key, val)
for i, env := range th.Env {
// print all envs that were found for debugging
t.Logf("env[%d]: %s", i, env)
}
}

func (th *DAG) AssertParam(t *testing.T, params ...string) {
Expand Down
164 changes: 126 additions & 38 deletions internal/digraph/executor/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,46 +18,115 @@ var _ Executor = (*commandExecutor)(nil)
var _ ExitCoder = (*commandExecutor)(nil)

type commandExecutor struct {
cmd *exec.Cmd
lock sync.Mutex
exitCode int
mu sync.Mutex
config *commandConfig
cmd *exec.Cmd
scriptFile string
exitCode int
}

// ExitCode implements ExitCoder.
func (e *commandExecutor) ExitCode() int {
return e.exitCode
}

func (e *commandExecutor) Run(_ context.Context) error {
e.lock.Lock()
err := e.cmd.Start()
e.lock.Unlock()
if err != nil {
func (e *commandExecutor) Run(ctx context.Context) error {
e.mu.Lock()

if len(e.config.Dir) > 0 && !fileutil.FileExists(e.config.Dir) {
e.mu.Unlock()
return fmt.Errorf("directory does not exist: %s", e.config.Dir)
}

if e.config.Script != "" {
scriptFile, err := setupScript(ctx, digraph.Step{Dir: e.config.Dir, Script: e.config.Script})
if err != nil {
e.mu.Unlock()
return fmt.Errorf("failed to setup script: %w", err)
}
e.scriptFile = scriptFile
defer func() {
// Remove the temporary script file after the command has finished
_ = os.Remove(scriptFile)
}()
}
e.cmd = e.config.newCmd(ctx, e.scriptFile)

if err := e.cmd.Start(); err != nil {
e.exitCode = exitCodeFromError(err)
e.mu.Unlock()
return err
}
e.mu.Unlock()

if err := e.cmd.Wait(); err != nil {
e.exitCode = exitCodeFromError(err)
return err
}

return nil
}

func (e *commandExecutor) SetStdout(out io.Writer) {
e.cmd.Stdout = out
e.config.Stdout = out
}

func (e *commandExecutor) SetStderr(out io.Writer) {
e.cmd.Stderr = out
e.config.Stderr = out
}

func (e *commandExecutor) Kill(sig os.Signal) error {
e.lock.Lock()
defer e.lock.Unlock()
if e.cmd == nil || e.cmd.Process == nil {
return nil
e.mu.Lock()
defer e.mu.Unlock()

if e.cmd != nil && e.cmd.Process != nil {
return syscall.Kill(-e.cmd.Process.Pid, sig.(syscall.Signal))
}

return nil
}

type commandConfig struct {
Ctx context.Context
Dir string
Command string
Args []string
Script string
ShellCommand string
ShellCommandArgs string
Stdout io.Writer
Stderr io.Writer
}

func (cfg *commandConfig) newCmd(ctx context.Context, scriptFile string) *exec.Cmd {
var cmd *exec.Cmd
switch {
case cfg.ShellCommand != "" && scriptFile != "":
// If script is provided ignore the shell command args

// nolint: gosec
cmd = exec.CommandContext(cfg.Ctx, cfg.ShellCommand, scriptFile)

case cfg.ShellCommand != "" && cfg.ShellCommandArgs != "":
// nolint: gosec
cmd = exec.CommandContext(cfg.Ctx, cfg.ShellCommand, "-c", cfg.ShellCommandArgs)

default:
cmd = createDirectCommand(cfg.Ctx, cfg.Command, cfg.Args, scriptFile)

}

stepContext := digraph.GetStepContext(ctx)
cmd.Env = append(cmd.Env, stepContext.AllEnvs()...)
cmd.Dir = cfg.Dir
cmd.Stdout = cfg.Stdout
cmd.Stderr = cfg.Stderr
cmd.SysProcAttr = &syscall.SysProcAttr{
Setpgid: true,
Pgid: 0,
}
return syscall.Kill(-e.cmd.Process.Pid, sig.(syscall.Signal))

return cmd
}

func init() {
Expand All @@ -80,42 +149,61 @@ func exitCodeFromError(err error) int {

func newCommand(ctx context.Context, step digraph.Step) (Executor, error) {
if len(step.Dir) > 0 && !fileutil.FileExists(step.Dir) {
return nil, fmt.Errorf("directory %q does not exist", step.Dir)
return nil, fmt.Errorf("directory does not exist: %s", step.Dir)
}

stepContext := digraph.GetStepContext(ctx)

cmd, err := createCommand(ctx, step)
cfg, err := createCommandConfig(ctx, step)
if err != nil {
return nil, fmt.Errorf("failed to create command: %w", err)
}
cmd.Env = append(cmd.Env, stepContext.AllEnvs()...)
cmd.Dir = step.Dir

cmd.SysProcAttr = &syscall.SysProcAttr{
Setpgid: true,
Pgid: 0,
}

return &commandExecutor{cmd: cmd}, nil
return &commandExecutor{config: cfg}, nil
}

func createCommand(ctx context.Context, step digraph.Step) (*exec.Cmd, error) {
func createCommandConfig(ctx context.Context, step digraph.Step) (*commandConfig, error) {
shellCommand := cmdutil.GetShellCommand(step.Shell)
shellCmdArgs := step.ShellCmdArgs
if shellCommand == "" || shellCmdArgs == "" {
return createDirectCommand(ctx, step, step.Args), nil

return &commandConfig{
Ctx: ctx,
Dir: step.Dir,
Command: step.Command,
Args: step.Args,
Script: step.Script,
ShellCommand: shellCommand,
ShellCommandArgs: shellCmdArgs,
}, nil
}

func setupScript(_ context.Context, step digraph.Step) (string, error) {
file, err := os.CreateTemp(step.Dir, "dagu_script-")
if err != nil {
return "", fmt.Errorf("failed to create script file: %w", err)
}
defer func() {
_ = file.Close()
}()

if _, err = file.WriteString(step.Script); err != nil {
return "", fmt.Errorf("failed to write script to file: %w", err)
}

if err = file.Sync(); err != nil {
return "", fmt.Errorf("failed to sync script file: %w", err)
}
return createShellCommand(ctx, shellCommand, shellCmdArgs), nil

return file.Name(), nil
}

// createDirectCommand creates a command that runs directly without a shell
func createDirectCommand(ctx context.Context, step digraph.Step, args []string) *exec.Cmd {
// nolint: gosec
return exec.CommandContext(ctx, step.Command, args...)
}
func createDirectCommand(ctx context.Context, cmd string, args []string, scriptFile string) *exec.Cmd {
arguments := make([]string, len(args))
copy(arguments, args)

// createShellCommand creates a command that runs through a shell
func createShellCommand(ctx context.Context, shell, shellCmd string) *exec.Cmd {
return exec.CommandContext(ctx, shell, "-c", shellCmd)
if scriptFile != "" {
arguments = append(arguments, scriptFile)
}

// nolint: gosec
return exec.CommandContext(ctx, cmd, arguments...)
}
17 changes: 14 additions & 3 deletions internal/digraph/executor/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ func (e *docker) Kill(_ os.Signal) error {
return nil
}

func (e *docker) Run(_ context.Context) error {
ctx, cancelFunc := context.WithCancel(context.Background())
func (e *docker) Run(ctx context.Context) error {
ctx, cancelFunc := context.WithCancel(ctx)
e.context = ctx
e.cancel = cancelFunc

Expand Down Expand Up @@ -114,7 +114,18 @@ func (e *docker) Run(_ context.Context) error {
e.containerConfig.Image = e.image
}

e.containerConfig.Cmd = append([]string{e.step.Command}, e.step.Args...)
// Evaluate args
stepContext := digraph.GetStepContext(ctx)
var args []string
for _, arg := range e.step.Args {
val, err := stepContext.EvalString(arg)
if err != nil {
return fmt.Errorf("failed to evaluate arg %s: %w", arg, err)
}
args = append(args, val)
}

e.containerConfig.Cmd = append([]string{e.step.Command}, args...)

resp, err := cli.ContainerCreate(
ctx, e.containerConfig, e.hostConfig, nil, nil, "",
Expand Down
Loading

0 comments on commit 33c93f4

Please sign in to comment.