diff --git a/internal/util/cephcmds.go b/internal/util/cephcmds.go index a405f4973..66c32b196 100644 --- a/internal/util/cephcmds.go +++ b/internal/util/cephcmds.go @@ -22,6 +22,7 @@ import ( "errors" "fmt" "os/exec" + "time" "github.com/ceph/ceph-csi/internal/util/log" @@ -65,6 +66,59 @@ func ExecCommand(ctx context.Context, program string, args ...string) (string, s return stdout, stderr, nil } +// ExecCommandWithTimeout executes passed in program with args, timeout and +// returns separate stdout and stderr streams. If the command is not executed +// within given timeout, the process will be killed. In case ctx is not set to +// context.TODO(), the command will be logged after it was executed. +func ExecCommandWithTimeout( + ctx context.Context, + timeout time.Duration, + program string, + args ...string) ( + string, + string, + error) { + var ( + sanitizedArgs = StripSecretInArgs(args) + stdoutBuf bytes.Buffer + stderrBuf bytes.Buffer + ) + + cctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd := exec.CommandContext(cctx, program, args...) // #nosec:G204, commands executing not vulnerable. + cmd.Stdout = &stdoutBuf + cmd.Stderr = &stderrBuf + + err := cmd.Run() + stdout := stdoutBuf.String() + stderr := stderrBuf.String() + if err != nil { + // if its a timeout log return context deadline exceeded error message + if errors.Is(cctx.Err(), context.DeadlineExceeded) { + err = fmt.Errorf("timeout: %w", cctx.Err()) + } + err = fmt.Errorf("an error (%w) and stderror (%s) occurred while running %s args: %v", + err, + stderr, + program, + sanitizedArgs) + + if ctx != context.TODO() { + log.ErrorLog(ctx, "%s", err) + } + + return stdout, stderr, err + } + + if ctx != context.TODO() { + log.UsefulLog(ctx, "command succeeded: %s %v", program, sanitizedArgs) + } + + return stdout, stderr, nil +} + // GetPoolID fetches the ID of the pool that matches the passed in poolName // parameter. func GetPoolID(monitors string, cr *Credentials, poolName string) (int64, error) { diff --git a/internal/util/cephcmds_test.go b/internal/util/cephcmds_test.go new file mode 100644 index 000000000..134eb34f0 --- /dev/null +++ b/internal/util/cephcmds_test.go @@ -0,0 +1,89 @@ +/* +Copyright 2021 The Ceph-CSI Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package util + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestExecCommandWithTimeout(t *testing.T) { + t.Parallel() + type args struct { + ctx context.Context + program string + timeout time.Duration + args []string + } + tests := []struct { + name string + args args + stdout string + expectedErr error + wantErr bool + }{ + { + name: "echo hello", + args: args{ + ctx: context.TODO(), + program: "echo", + timeout: time.Second, + args: []string{"hello"}, + }, + stdout: "hello\n", + expectedErr: nil, + wantErr: false, + }, + { + name: "sleep with timeout", + args: args{ + ctx: context.TODO(), + program: "sleep", + timeout: time.Second, + args: []string{"3"}, + }, + stdout: "", + expectedErr: context.DeadlineExceeded, + wantErr: true, + }, + } + for _, tt := range tests { + newtt := tt + t.Run(newtt.name, func(t *testing.T) { + t.Parallel() + stdout, _, err := ExecCommandWithTimeout(newtt.args.ctx, + newtt.args.timeout, + newtt.args.program, + newtt.args.args...) + if (err != nil) != newtt.wantErr { + t.Errorf("ExecCommandWithTimeout() error = %v, wantErr %v", err, newtt.wantErr) + + return + } + + if newtt.wantErr && !errors.Is(err, newtt.expectedErr) { + t.Errorf("ExecCommandWithTimeout() error expected got = %v, want %v", err, newtt.expectedErr) + } + + if stdout != newtt.stdout { + t.Errorf("ExecCommandWithTimeout() got = %v, want %v", stdout, newtt.stdout) + } + }) + } +}