diff --git a/internal/csi-common/utils.go b/internal/csi-common/utils.go index daf6170ee..42beeac52 100644 --- a/internal/csi-common/utils.go +++ b/internal/csi-common/utils.go @@ -28,6 +28,7 @@ import ( "github.com/ceph/ceph-csi/internal/util/log" "github.com/container-storage-interface/spec/lib/go/csi" + "github.com/csi-addons/spec/lib/go/replication" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/kubernetes-csi/csi-lib-utils/protosanitizer" "google.golang.org/grpc" @@ -152,6 +153,20 @@ func getReqID(req interface{}) string { reqID = r.GroupSnapshotId case *csi.GetVolumeGroupSnapshotRequest: reqID = r.GroupSnapshotId + + // Replication + case *replication.EnableVolumeReplicationRequest: + reqID = r.VolumeId + case *replication.DisableVolumeReplicationRequest: + reqID = r.VolumeId + case *replication.PromoteVolumeRequest: + reqID = r.VolumeId + case *replication.DemoteVolumeRequest: + reqID = r.VolumeId + case *replication.ResyncVolumeRequest: + reqID = r.VolumeId + case *replication.GetVolumeReplicationInfoRequest: + reqID = r.VolumeId } return reqID diff --git a/internal/csi-common/utils_test.go b/internal/csi-common/utils_test.go index ddb16648a..1a12e9ecb 100644 --- a/internal/csi-common/utils_test.go +++ b/internal/csi-common/utils_test.go @@ -24,6 +24,7 @@ import ( "testing" "github.com/container-storage-interface/spec/lib/go/csi" + "github.com/csi-addons/spec/lib/go/replication" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" mount "k8s.io/mount-utils" @@ -75,6 +76,25 @@ func TestGetReqID(t *testing.T) { &csi.GetVolumeGroupSnapshotRequest{ GroupSnapshotId: fakeID, }, + + &replication.EnableVolumeReplicationRequest{ + VolumeId: fakeID, + }, + &replication.DisableVolumeReplicationRequest{ + VolumeId: fakeID, + }, + &replication.PromoteVolumeRequest{ + VolumeId: fakeID, + }, + &replication.DemoteVolumeRequest{ + VolumeId: fakeID, + }, + &replication.ResyncVolumeRequest{ + VolumeId: fakeID, + }, + &replication.GetVolumeReplicationInfoRequest{ + VolumeId: fakeID, + }, } for _, r := range req { if got := getReqID(r); got != fakeID {