diff --git a/internal/rbd/rbd_util.go b/internal/rbd/rbd_util.go index 3dff543fd..1ac823be9 100644 --- a/internal/rbd/rbd_util.go +++ b/internal/rbd/rbd_util.go @@ -2050,28 +2050,6 @@ func (ri *rbdImage) addSnapshotScheduling( return err } adminConn := ra.MirrorSnashotSchedule() - // list all the snapshot scheduling and check at least one image scheduling - // exists with specified interval. - ssList, err := adminConn.List(ls) - if err != nil { - return err - } - - for _, ss := range ssList { - // make sure we are matching image level scheduling. The - // `adminConn.List` lists the global level scheduling also. - if ss.Name == ri.String() { - for _, s := range ss.Schedule { - // TODO: Add support to check start time also. - // The start time is currently stored with different format - // in ceph. Comparison is not possible unless we know in - // which format ceph is storing it. - if s.Interval == interval { - return err - } - } - } - } err = adminConn.Add(ls, interval, startTime) if err != nil { return err diff --git a/internal/rbd/replicationcontrollerserver.go b/internal/rbd/replicationcontrollerserver.go index e00d1b7a5..16e16c418 100644 --- a/internal/rbd/replicationcontrollerserver.go +++ b/internal/rbd/replicationcontrollerserver.go @@ -117,12 +117,10 @@ func getMirroringMode(ctx context.Context, parameters map[string]string) (librbd return mirroringMode, nil } -// getSchedulingDetails gets the mirroring mode and scheduling details from the +// validateSchedulingDetails gets the mirroring mode and scheduling details from the // input GRPC request parameters and validates the scheduling is only supported // for snapshot mirroring mode. -func getSchedulingDetails(parameters map[string]string) (admin.Interval, admin.StartTime, error) { - admInt := admin.NoInterval - adminStartTime := admin.NoStartTime +func validateSchedulingDetails(parameters map[string]string) error { var err error val := parameters[imageMirroringKey] @@ -134,43 +132,49 @@ func getSchedulingDetails(parameters map[string]string) (admin.Interval, admin.S // an optional parameter. case "": default: - return admInt, adminStartTime, status.Error(codes.InvalidArgument, "scheduling is only supported for snapshot mode") + return status.Error(codes.InvalidArgument, "scheduling is only supported for snapshot mode") } // validate mandatory interval field interval, ok := parameters[schedulingIntervalKey] if ok && interval == "" { - return admInt, adminStartTime, status.Error(codes.InvalidArgument, "scheduling interval cannot be empty") + return status.Error(codes.InvalidArgument, "scheduling interval cannot be empty") } - adminStartTime = admin.StartTime(parameters[schedulingStartTimeKey]) + adminStartTime := admin.StartTime(parameters[schedulingStartTimeKey]) if !ok { // startTime is alone not supported it has to be present with interval if adminStartTime != "" { - return admInt, admin.NoStartTime, status.Errorf(codes.InvalidArgument, + return status.Errorf(codes.InvalidArgument, "%q parameter is supported only with %q", schedulingStartTimeKey, schedulingIntervalKey) } } if interval != "" { - admInt, err = validateSchedulingInterval(interval) + err = validateSchedulingInterval(interval) if err != nil { - return admInt, admin.NoStartTime, status.Error(codes.InvalidArgument, err.Error()) + return status.Error(codes.InvalidArgument, err.Error()) } } - return admInt, adminStartTime, nil + return nil +} + +// getSchedulingDetails returns scheduling interval and scheduling startTime. +func getSchedulingDetails(parameters map[string]string) (admin.Interval, admin.StartTime) { + return admin.Interval(parameters[schedulingIntervalKey]), + admin.StartTime(parameters[schedulingStartTimeKey]) } // validateSchedulingInterval return the interval as it is if its ending with // `m|h|d` or else it will return error. -func validateSchedulingInterval(interval string) (admin.Interval, error) { +func validateSchedulingInterval(interval string) error { re := regexp.MustCompile(`^\d+[mhd]$`) if re.MatchString(interval) { - return admin.Interval(interval), nil + return nil } - return "", errors.New("interval specified without d, h, m suffix") + return errors.New("interval specified without d, h, m suffix") } // EnableVolumeReplication extracts the RBD volume information from the @@ -189,7 +193,7 @@ func (rs *ReplicationServer) EnableVolumeReplication(ctx context.Context, } defer cr.DeleteCredentials() - interval, startTime, err := getSchedulingDetails(req.GetParameters()) + err = validateSchedulingDetails(req.GetParameters()) if err != nil { return nil, err } @@ -237,19 +241,6 @@ func (rs *ReplicationServer) EnableVolumeReplication(ctx context.Context, } } - if interval != "" { - err = rbdVol.addSnapshotScheduling(interval, startTime) - if err != nil { - return nil, err - } - log.DebugLog( - ctx, - "Added scheduling at interval %s, start time %s for volume %s", - interval, - startTime, - rbdVol) - } - return &replication.EnableVolumeReplicationResponse{}, nil } @@ -437,6 +428,20 @@ func (rs *ReplicationServer) PromoteVolume(ctx context.Context, } } + interval, startTime := getSchedulingDetails(req.GetParameters()) + if interval != admin.NoInterval { + err = rbdVol.addSnapshotScheduling(interval, startTime) + if err != nil { + return nil, err + } + log.DebugLog( + ctx, + "Added scheduling at interval %s, start time %s for volume %s", + interval, + startTime, + rbdVol) + } + return &replication.PromoteVolumeResponse{}, nil } diff --git a/internal/rbd/replicationcontrollerserver_test.go b/internal/rbd/replicationcontrollerserver_test.go index 1d49e8954..cf306710a 100644 --- a/internal/rbd/replicationcontrollerserver_test.go +++ b/internal/rbd/replicationcontrollerserver_test.go @@ -29,37 +29,31 @@ func TestValidateSchedulingInterval(t *testing.T) { tests := []struct { name string interval string - want admin.Interval wantErr bool }{ { "valid interval in minutes", "3m", - admin.Interval("3m"), false, }, { "valid interval in hour", "22h", - admin.Interval("22h"), false, }, { "valid interval in days", "13d", - admin.Interval("13d"), false, }, { "invalid interval without number", "d", - admin.Interval(""), true, }, { "invalid interval without (m|h|d) suffix", "12", - admin.Interval(""), true, }, } @@ -67,14 +61,86 @@ func TestValidateSchedulingInterval(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - got, err := validateSchedulingInterval(tt.interval) + err := validateSchedulingInterval(tt.interval) if (err != nil) != tt.wantErr { t.Errorf("validateSchedulingInterval() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("validateSchedulingInterval() = %v, want %v", got, tt.want) + }) + } +} + +func TestValidateSchedulingDetails(t *testing.T) { + t.Parallel() + tests := []struct { + name string + parameters map[string]string + wantErr bool + }{ + { + "valid parameters", + map[string]string{ + imageMirroringKey: string(imageMirrorModeSnapshot), + schedulingIntervalKey: "1h", + schedulingStartTimeKey: "14:00:00-05:00", + }, + false, + }, + { + "valid parameters when optional startTime is missing", + map[string]string{ + imageMirroringKey: string(imageMirrorModeSnapshot), + schedulingIntervalKey: "1h", + }, + false, + }, + { + "when mirroring mode is journal", + map[string]string{ + imageMirroringKey: "journal", + schedulingIntervalKey: "1h", + }, + true, + }, + { + "when startTime is specified without interval", + map[string]string{ + imageMirroringKey: string(imageMirrorModeSnapshot), + schedulingStartTimeKey: "14:00:00-05:00", + }, + true, + }, + { + "when no scheduling is specified", + map[string]string{ + imageMirroringKey: string(imageMirrorModeSnapshot), + }, + false, + }, + { + "when no parameters and scheduling details are specified", + map[string]string{}, + false, + }, + { + "when no mirroring mode is specified", + map[string]string{ + schedulingIntervalKey: "1h", + schedulingStartTimeKey: "14:00:00-05:00", + }, + false, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := validateSchedulingDetails(tt.parameters) + if (err != nil) != tt.wantErr { + t.Errorf("getSchedulingDetails() error = %v, wantErr %v", err, tt.wantErr) + + return } }) } @@ -87,18 +153,15 @@ func TestGetSchedulingDetails(t *testing.T) { parameters map[string]string wantInterval admin.Interval wantStartTime admin.StartTime - wantErr bool }{ { "valid parameters", map[string]string{ - imageMirroringKey: string(imageMirrorModeSnapshot), schedulingIntervalKey: "1h", schedulingStartTimeKey: "14:00:00-05:00", }, admin.Interval("1h"), admin.StartTime("14:00:00-05:00"), - false, }, { "valid parameters when optional startTime is missing", @@ -108,17 +171,6 @@ func TestGetSchedulingDetails(t *testing.T) { }, admin.Interval("1h"), admin.NoStartTime, - false, - }, - { - "when mirroring mode is journal", - map[string]string{ - imageMirroringKey: "journal", - schedulingIntervalKey: "1h", - }, - admin.NoInterval, - admin.NoStartTime, - true, }, { "when startTime is specified without interval", @@ -127,46 +179,20 @@ func TestGetSchedulingDetails(t *testing.T) { schedulingStartTimeKey: "14:00:00-05:00", }, admin.NoInterval, - admin.NoStartTime, - true, - }, - { - "when no scheduling is specified", - map[string]string{ - imageMirroringKey: string(imageMirrorModeSnapshot), - }, - admin.NoInterval, - admin.NoStartTime, - false, + admin.StartTime("14:00:00-05:00"), }, { "when no parameters and scheduling details are specified", map[string]string{}, admin.NoInterval, admin.NoStartTime, - false, - }, - { - "when no mirroring mode is specified", - map[string]string{ - schedulingIntervalKey: "1h", - schedulingStartTimeKey: "14:00:00-05:00", - }, - admin.Interval("1h"), - admin.StartTime("14:00:00-05:00"), - false, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - interval, startTime, err := getSchedulingDetails(tt.parameters) - if (err != nil) != tt.wantErr { - t.Errorf("getSchedulingDetails() error = %v, wantErr %v", err, tt.wantErr) - - return - } + interval, startTime := getSchedulingDetails(tt.parameters) if !reflect.DeepEqual(interval, tt.wantInterval) { t.Errorf("getSchedulingDetails() interval = %v, want %v", interval, tt.wantInterval) }