diff --git a/api/utils/__init__.py b/api/utils/__init__.py
index 37a77930f..92086b99b 100644
--- a/api/utils/__init__.py
+++ b/api/utils/__init__.py
@@ -70,6 +70,12 @@ def show_configs():
         if "password" in v:
             v = copy.deepcopy(v)
             v["password"] = "*" * 8
+        if "access_key" in v:
+            v = copy.deepcopy(v)
+            v["access_key"] = "*" * 8
+        if "secret_key" in v:
+            v = copy.deepcopy(v)
+            v["secret_key"] = "*" * 8
         msg += f"\n\t{k}: {v}"
     logging.info(msg)
diff --git a/conf/service_conf.yaml b/conf/service_conf.yaml
index 18b41e164..0ba487b4e 100644
--- a/conf/service_conf.yaml
+++ b/conf/service_conf.yaml
@@ -37,6 +37,12 @@ redis:
 #   access_key: 'access_key'
 #   secret_key: 'secret_key'
 #   region: 'region'
+# oss:
+#   access_key: 'access_key'
+#   secret_key: 'secret_key'
+#   endpoint_url: 'http://oss-cn-hangzhou.aliyuncs.com'
+#   region: 'cn-hangzhou'
+#   bucket: 'bucket_name'
 # azure:
 #   auth_type: 'sas'
 #   container_url: 'container_url'
diff --git a/docker/.env b/docker/.env
index 53b4bb6da..ff256e39a 100644
--- a/docker/.env
+++ b/docker/.env
@@ -138,3 +138,11 @@ TIMEZONE='Asia/Shanghai'
 # - `ERROR`
 # For example, following line changes the log level of `ragflow.es_conn` to `DEBUG`:
 # LOG_LEVELS=ragflow.es_conn=DEBUG
+
+# Aliyun OSS configuration
+# STORAGE_IMPL=OSS
+# ACCESS_KEY=xxx
+# SECRET_KEY=xxx
+# ENDPOINT=http://oss-cn-hangzhou.aliyuncs.com
+# REGION=cn-hangzhou
+# BUCKET=ragflow65536
\ No newline at end of file
diff --git a/docker/service_conf.yaml.template b/docker/service_conf.yaml.template
index c667161f0..ce41f8147 100644
--- a/docker/service_conf.yaml.template
+++ b/docker/service_conf.yaml.template
@@ -37,6 +37,12 @@ redis:
 #   access_key: 'access_key'
 #   secret_key: 'secret_key'
 #   region: 'region'
+# oss:
+#   access_key: '${ACCESS_KEY}'
+#   secret_key: '${SECRET_KEY}'
+#   endpoint_url: '${ENDPOINT}'
+#   region: '${REGION}'
+#   bucket: '${BUCKET}'
 # azure:
 #   auth_type: 'sas'
 #   container_url: 'container_url'
diff --git a/rag/settings.py b/rag/settings.py
index 83e087484..72f48c04e 100644
--- a/rag/settings.py
+++ b/rag/settings.py
@@ -26,6 +26,7 @@ INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
 AZURE = get_base_config("azure", {})
 S3 = get_base_config("s3", {})
 MINIO = decrypt_database_config(name="minio")
+OSS = get_base_config("oss", {})
 try:
     REDIS = decrypt_database_config(name="redis")
 except Exception:
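Taken together, the configuration changes above form one pipeline: values in `docker/.env` are substituted into `docker/service_conf.yaml.template`, and `get_base_config("oss", {})` then exposes the parsed `oss:` section as `settings.OSS`. The sketch below illustrates that flow under stated assumptions: `os.path.expandvars` merely stands in for whatever substitution the container entrypoint actually performs, and the parsing mimics, rather than calls, `get_base_config`.

```python
# Illustrative only: how .env values could reach settings.OSS.
# ${VAR} placeholders are filled from the environment, the YAML is
# parsed, and the "oss" key is read with get_base_config-like semantics.
import os
import yaml  # pyyaml

template = """
oss:
  access_key: '${ACCESS_KEY}'
  secret_key: '${SECRET_KEY}'
  endpoint_url: '${ENDPOINT}'
  region: '${REGION}'
  bucket: '${BUCKET}'
"""

os.environ.setdefault("ACCESS_KEY", "xxx")
os.environ.setdefault("SECRET_KEY", "xxx")
os.environ.setdefault("ENDPOINT", "http://oss-cn-hangzhou.aliyuncs.com")
os.environ.setdefault("REGION", "cn-hangzhou")
os.environ.setdefault("BUCKET", "ragflow65536")

rendered = os.path.expandvars(template)   # ${VAR} -> value from environment
conf = yaml.safe_load(rendered)
oss = conf.get("oss", {})                 # analogous to get_base_config("oss", {})
print(oss["endpoint_url"], oss["bucket"])
```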
diff --git a/rag/utils/oss_conn.py b/rag/utils/oss_conn.py
new file mode 100644
index 000000000..5525ecbc8
--- /dev/null
+++ b/rag/utils/oss_conn.py
@@ -0,0 +1,158 @@
+#
+#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+
+import logging
+import time
+from io import BytesIO
+
+import boto3
+from botocore.config import Config
+from botocore.exceptions import ClientError
+
+from rag import settings
+from rag.utils import singleton
+
+
+@singleton
+class RAGFlowOSS(object):
+    def __init__(self):
+        self.conn = None
+        self.oss_config = settings.OSS
+        self.access_key = self.oss_config.get('access_key', None)
+        self.secret_key = self.oss_config.get('secret_key', None)
+        self.endpoint_url = self.oss_config.get('endpoint_url', None)
+        self.region = self.oss_config.get('region', None)
+        self.bucket = self.oss_config.get('bucket', None)
+        self.__open__()
+
+    @staticmethod
+    def use_default_bucket(method):
+        def wrapper(self, bucket, *args, **kwargs):
+            # If a default bucket is configured, it takes precedence over the caller's.
+            actual_bucket = self.bucket if self.bucket else bucket
+            return method(self, actual_bucket, *args, **kwargs)
+        return wrapper
+
+    def __open__(self):
+        try:
+            if self.conn:
+                self.__close__()
+        except Exception:
+            pass
+
+        try:
+            # Reference: https://help.aliyun.com/zh/oss/developer-reference/use-amazon-s3-sdks-to-access-oss
+            # OSS's S3-compatible endpoint requires virtual-hosted-style addressing and SigV4 signing.
+            self.conn = boto3.client(
+                's3',
+                region_name=self.region,
+                aws_access_key_id=self.access_key,
+                aws_secret_access_key=self.secret_key,
+                endpoint_url=self.endpoint_url,
+                config=Config(s3={"addressing_style": "virtual"}, signature_version='v4')
+            )
+        except Exception:
+            logging.exception(f"Fail to connect at region {self.region}")
+
+    def __close__(self):
+        self.conn = None
+
+    @use_default_bucket
+    def bucket_exists(self, bucket):
+        try:
+            logging.debug(f"head_bucket bucketname {bucket}")
+            self.conn.head_bucket(Bucket=bucket)
+            exists = True
+        except ClientError:
+            logging.exception(f"head_bucket error {bucket}")
+            exists = False
+        return exists
+
+    def health(self):
+        bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
+
+        if not self.bucket_exists(bucket):
+            self.conn.create_bucket(Bucket=bucket)
+            logging.debug(f"create bucket {bucket} ********")
+
+        r = self.conn.upload_fileobj(BytesIO(binary), bucket, fnm)
+        return r
+
+    def get_properties(self, bucket, key):
+        # Object metadata is not used by this backend's callers; return an empty dict.
+        return {}
+
+    def list(self, bucket, dir, recursive=True):
+        # Listing is not used by this backend's callers; return an empty list.
+        return []
+
+    @use_default_bucket
+    def put(self, bucket, fnm, binary):
+        logging.debug(f"bucket name {bucket}; filename: {fnm}")
+        for _ in range(3):
+            try:
+                if not self.bucket_exists(bucket):
+                    self.conn.create_bucket(Bucket=bucket)
+                    logging.info(f"create bucket {bucket} ********")
+                r = self.conn.upload_fileobj(BytesIO(binary), bucket, fnm)
+                return r
+            except Exception:
+                logging.exception(f"Fail put {bucket}/{fnm}")
+                # Reopen the connection before retrying.
+                self.__open__()
+                time.sleep(1)
+
+    @use_default_bucket
+    def rm(self, bucket, fnm):
+        try:
+            self.conn.delete_object(Bucket=bucket, Key=fnm)
+        except Exception:
+            logging.exception(f"Fail rm {bucket}/{fnm}")
+
+    @use_default_bucket
+    def get(self, bucket, fnm):
+        for _ in range(3):
+            try:
+                r = self.conn.get_object(Bucket=bucket, Key=fnm)
+                object_data = r['Body'].read()
+                return object_data
+            except Exception:
+                logging.exception(f"fail get {bucket}/{fnm}")
+                self.__open__()
+                time.sleep(1)
+        return
+
+    @use_default_bucket
+    def obj_exist(self, bucket, fnm):
+        try:
+            if self.conn.head_object(Bucket=bucket, Key=fnm):
+                return True
+        except ClientError as e:
+            # head_object raises a ClientError with code 404 when the key is absent.
+            if e.response['Error']['Code'] == '404':
+                return False
+            else:
+                raise
+
+    @use_default_bucket
+    def get_presigned_url(self, bucket, fnm, expires):
+        for _ in range(10):
+            try:
+                r = self.conn.generate_presigned_url(
+                    'get_object',
+                    Params={'Bucket': bucket, 'Key': fnm},
+                    ExpiresIn=expires
+                )
+                return r
+            except Exception:
+                logging.exception(f"fail get url {bucket}/{fnm}")
+                self.__open__()
+                time.sleep(1)
+        return
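Before wiring credentials into RAGFlow, it can be useful to verify them against the S3-compatible OSS endpoint in isolation. The snippet below is a minimal standalone check using the same `boto3` client options as `RAGFlowOSS.__open__`; the credentials, region, and bucket name are placeholders.

```python
# Standalone connectivity check for an S3-compatible OSS endpoint, using the
# same client options as RAGFlowOSS.__open__ (virtual-hosted addressing, SigV4).
# All credentials and names below are placeholders.
import boto3
from botocore.config import Config
from botocore.exceptions import ClientError

client = boto3.client(
    "s3",
    region_name="cn-hangzhou",
    aws_access_key_id="your-access-key",
    aws_secret_access_key="your-secret-key",
    endpoint_url="http://oss-cn-hangzhou.aliyuncs.com",
    config=Config(s3={"addressing_style": "virtual"}, signature_version="v4"),
)

try:
    client.head_bucket(Bucket="your-bucket")  # 200 means reachable and authorized
    print("bucket reachable")
except ClientError as e:
    print("check failed:", e.response["Error"]["Code"])
```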
diff --git a/rag/utils/storage_factory.py b/rag/utils/storage_factory.py
index a27fda91a..63587b3b0 100644
--- a/rag/utils/storage_factory.py
+++ b/rag/utils/storage_factory.py
@@ -21,6 +21,7 @@ from rag.utils.azure_sas_conn import RAGFlowAzureSasBlob
 from rag.utils.azure_spn_conn import RAGFlowAzureSpnBlob
 from rag.utils.minio_conn import RAGFlowMinio
 from rag.utils.s3_conn import RAGFlowS3
+from rag.utils.oss_conn import RAGFlowOSS
 
 
 class Storage(Enum):
@@ -28,6 +29,7 @@ class Storage(Enum):
     AZURE_SPN = 2
     AZURE_SAS = 3
     AWS_S3 = 4
+    OSS = 5
 
 
 class StorageFactory:
@@ -36,6 +38,7 @@ class StorageFactory:
         Storage.AZURE_SPN: RAGFlowAzureSpnBlob,
         Storage.AZURE_SAS: RAGFlowAzureSasBlob,
         Storage.AWS_S3: RAGFlowS3,
+        Storage.OSS: RAGFlowOSS,
     }
 
     @classmethod
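For completeness, here is a rough usage sketch of how the new backend would be selected at runtime. The body of `StorageFactory`'s classmethod is truncated in this diff, so the `create()` call below is an assumption based on `storage_mapping`; setting `STORAGE_IMPL=OSS` in `docker/.env` is what makes the enum lookup resolve to `Storage.OSS`.

```python
# Rough usage sketch, not the repo's actual bootstrap code. Assumes the
# truncated classmethod is a create() that instantiates storage_mapping[storage].
import os

from rag.utils.storage_factory import Storage, StorageFactory

kind = Storage[os.getenv("STORAGE_IMPL", "MINIO")]  # STORAGE_IMPL=OSS -> Storage.OSS
storage = StorageFactory.create(kind)               # assumed: returns RAGFlowOSS() for Storage.OSS

# The bucket argument is overridden by the configured default bucket, if any
# (see the use_default_bucket decorator in rag/utils/oss_conn.py).
storage.put("ragflow65536", "hello.txt", b"hello oss")
data = storage.get("ragflow65536", "hello.txt")
assert data == b"hello oss"
```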