From 5bf31e7a865e12e2e6a03c27e03d8fe7aab13936 Mon Sep 17 00:00:00 2001 From: zhuhao <37029601+hwzhuhao@users.noreply.github.com> Date: Fri, 25 Oct 2024 10:11:25 +0800 Subject: [PATCH] refactor: update load_stream method to directly yield file chunks (#9806) --- api/extensions/storage/aliyun_oss_storage.py | 9 +++------ api/extensions/storage/aws_s3_storage.py | 19 ++++++++----------- api/extensions/storage/azure_blob_storage.py | 10 +++------- api/extensions/storage/baidu_obs_storage.py | 9 +++------ .../storage/google_cloud_storage.py | 13 +++++-------- api/extensions/storage/huawei_obs_storage.py | 9 +++------ api/extensions/storage/local_fs_storage.py | 14 +++++--------- api/extensions/storage/oracle_oci_storage.py | 19 ++++++++----------- api/extensions/storage/supabase_storage.py | 13 +++++-------- api/extensions/storage/tencent_cos_storage.py | 7 ++----- .../storage/volcengine_tos_storage.py | 9 +++------ 11 files changed, 48 insertions(+), 83 deletions(-) diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py index 01c1000e50..67635b129e 100644 --- a/api/extensions/storage/aliyun_oss_storage.py +++ b/api/extensions/storage/aliyun_oss_storage.py @@ -36,12 +36,9 @@ class AliyunOssStorage(BaseStorage): return data def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - obj = self.client.get_object(self.__wrapper_folder_filename(filename)) - while chunk := obj.read(4096): - yield chunk - - return generate() + obj = self.client.get_object(self.__wrapper_folder_filename(filename)) + while chunk := obj.read(4096): + yield chunk def download(self, filename, target_filepath): self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath) diff --git a/api/extensions/storage/aws_s3_storage.py b/api/extensions/storage/aws_s3_storage.py index cb67313bb2..ab2d0fba3b 100644 --- a/api/extensions/storage/aws_s3_storage.py +++ b/api/extensions/storage/aws_s3_storage.py @@ -62,17 +62,14 @@ class AwsS3Storage(BaseStorage): return data def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - try: - response = self.client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response["Body"].iter_chunks() - except ClientError as ex: - if ex.response["Error"]["Code"] == "NoSuchKey": - raise FileNotFoundError("File not found") - else: - raise - - return generate() + try: + response = self.client.get_object(Bucket=self.bucket_name, Key=filename) + yield from response["Body"].iter_chunks() + except ClientError as ex: + if ex.response["Error"]["Code"] == "NoSuchKey": + raise FileNotFoundError("File not found") + else: + raise def download(self, filename, target_filepath): self.client.download_file(self.bucket_name, filename, target_filepath) diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py index 477507feda..11a7544274 100644 --- a/api/extensions/storage/azure_blob_storage.py +++ b/api/extensions/storage/azure_blob_storage.py @@ -32,13 +32,9 @@ class AzureBlobStorage(BaseStorage): def load_stream(self, filename: str) -> Generator: client = self._sync_client() - - def generate(filename: str = filename) -> Generator: - blob = client.get_blob_client(container=self.bucket_name, blob=filename) - blob_data = blob.download_blob() - yield from blob_data.chunks() - - return generate(filename) + blob = client.get_blob_client(container=self.bucket_name, blob=filename) + blob_data = blob.download_blob() + yield from blob_data.chunks() def download(self, filename, target_filepath): client = self._sync_client() diff --git a/api/extensions/storage/baidu_obs_storage.py b/api/extensions/storage/baidu_obs_storage.py index cd69439749..e0d2140e91 100644 --- a/api/extensions/storage/baidu_obs_storage.py +++ b/api/extensions/storage/baidu_obs_storage.py @@ -39,12 +39,9 @@ class BaiduObsStorage(BaseStorage): return response.data.read() def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data - while chunk := response.read(4096): - yield chunk - - return generate() + response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data + while chunk := response.read(4096): + yield chunk def download(self, filename, target_filepath): self.client.get_object_to_file(bucket_name=self.bucket_name, key=filename, file_name=target_filepath) diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py index e90392a6ba..26b662d2f0 100644 --- a/api/extensions/storage/google_cloud_storage.py +++ b/api/extensions/storage/google_cloud_storage.py @@ -39,14 +39,11 @@ class GoogleCloudStorage(BaseStorage): return data def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - bucket = self.client.get_bucket(self.bucket_name) - blob = bucket.get_blob(filename) - with blob.open(mode="rb") as blob_stream: - while chunk := blob_stream.read(4096): - yield chunk - - return generate() + bucket = self.client.get_bucket(self.bucket_name) + blob = bucket.get_blob(filename) + with blob.open(mode="rb") as blob_stream: + while chunk := blob_stream.read(4096): + yield chunk def download(self, filename, target_filepath): bucket = self.client.get_bucket(self.bucket_name) diff --git a/api/extensions/storage/huawei_obs_storage.py b/api/extensions/storage/huawei_obs_storage.py index 3c443d87ac..20be70ef83 100644 --- a/api/extensions/storage/huawei_obs_storage.py +++ b/api/extensions/storage/huawei_obs_storage.py @@ -27,12 +27,9 @@ class HuaweiObsStorage(BaseStorage): return data def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response - while chunk := response.read(4096): - yield chunk - - return generate() + response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response + while chunk := response.read(4096): + yield chunk def download(self, filename, target_filepath): self.client.getObject(bucketName=self.bucket_name, objectKey=filename, downloadPath=target_filepath) diff --git a/api/extensions/storage/local_fs_storage.py b/api/extensions/storage/local_fs_storage.py index b4146d7384..5a495ca4d4 100644 --- a/api/extensions/storage/local_fs_storage.py +++ b/api/extensions/storage/local_fs_storage.py @@ -40,15 +40,11 @@ class LocalFsStorage(BaseStorage): def load_stream(self, filename: str) -> Generator: filepath = self._build_filepath(filename) - - def generate() -> Generator: - if not os.path.exists(filepath): - raise FileNotFoundError("File not found") - with open(filepath, "rb") as f: - while chunk := f.read(4096): # Read in chunks of 4KB - yield chunk - - return generate() + if not os.path.exists(filepath): + raise FileNotFoundError("File not found") + with open(filepath, "rb") as f: + while chunk := f.read(4096): # Read in chunks of 4KB + yield chunk def download(self, filename, target_filepath): filepath = self._build_filepath(filename) diff --git a/api/extensions/storage/oracle_oci_storage.py b/api/extensions/storage/oracle_oci_storage.py index e4f50b34e9..b59f83b8de 100644 --- a/api/extensions/storage/oracle_oci_storage.py +++ b/api/extensions/storage/oracle_oci_storage.py @@ -36,17 +36,14 @@ class OracleOCIStorage(BaseStorage): return data def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - try: - response = self.client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response["Body"].iter_chunks() - except ClientError as ex: - if ex.response["Error"]["Code"] == "NoSuchKey": - raise FileNotFoundError("File not found") - else: - raise - - return generate() + try: + response = self.client.get_object(Bucket=self.bucket_name, Key=filename) + yield from response["Body"].iter_chunks() + except ClientError as ex: + if ex.response["Error"]["Code"] == "NoSuchKey": + raise FileNotFoundError("File not found") + else: + raise def download(self, filename, target_filepath): self.client.download_file(self.bucket_name, filename, target_filepath) diff --git a/api/extensions/storage/supabase_storage.py b/api/extensions/storage/supabase_storage.py index 1119244574..9f7c69a9ae 100644 --- a/api/extensions/storage/supabase_storage.py +++ b/api/extensions/storage/supabase_storage.py @@ -36,17 +36,14 @@ class SupabaseStorage(BaseStorage): return content def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - result = self.client.storage.from_(self.bucket_name).download(filename) - byte_stream = io.BytesIO(result) - while chunk := byte_stream.read(4096): # Read in chunks of 4KB - yield chunk - - return generate() + result = self.client.storage.from_(self.bucket_name).download(filename) + byte_stream = io.BytesIO(result) + while chunk := byte_stream.read(4096): # Read in chunks of 4KB + yield chunk def download(self, filename, target_filepath): result = self.client.storage.from_(self.bucket_name).download(filename) - Path(result).write_bytes(result) + Path(target_filepath).write_bytes(result) def exists(self, filename): result = self.client.storage.from_(self.bucket_name).list(filename) diff --git a/api/extensions/storage/tencent_cos_storage.py b/api/extensions/storage/tencent_cos_storage.py index 8fd8e703a1..13a6c9239c 100644 --- a/api/extensions/storage/tencent_cos_storage.py +++ b/api/extensions/storage/tencent_cos_storage.py @@ -29,11 +29,8 @@ class TencentCosStorage(BaseStorage): return data def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - response = self.client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response["Body"].get_stream(chunk_size=4096) - - return generate() + response = self.client.get_object(Bucket=self.bucket_name, Key=filename) + yield from response["Body"].get_stream(chunk_size=4096) def download(self, filename, target_filepath): response = self.client.get_object(Bucket=self.bucket_name, Key=filename) diff --git a/api/extensions/storage/volcengine_tos_storage.py b/api/extensions/storage/volcengine_tos_storage.py index 389c5630e3..de82be04ea 100644 --- a/api/extensions/storage/volcengine_tos_storage.py +++ b/api/extensions/storage/volcengine_tos_storage.py @@ -27,12 +27,9 @@ class VolcengineTosStorage(BaseStorage): return data def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - response = self.client.get_object(bucket=self.bucket_name, key=filename) - while chunk := response.read(4096): - yield chunk - - return generate() + response = self.client.get_object(bucket=self.bucket_name, key=filename) + while chunk := response.read(4096): + yield chunk def download(self, filename, target_filepath): self.client.get_object_to_file(bucket=self.bucket_name, key=filename, file_path=target_filepath)