From 8266815cdae6b3fd2e3cb1ad5ba51fed39b13053 Mon Sep 17 00:00:00 2001 From: Ahmad Zidan <6338730+lan666as@users.noreply.github.com> Date: Tue, 29 Apr 2025 14:10:08 +0700 Subject: [PATCH] feat: add AWS Managed IAM auth for OpenSearch vector DB (#18963) --- .../middleware/vdb/opensearch_config.py | 31 ++++++++-- .../vdb/opensearch/opensearch_vector.py | 59 ++++++++++++++----- .../vdb/opensearch/test_opensearch.py | 59 ++++++++++++++++++- docker/.env.example | 6 +- docker/docker-compose.yaml | 5 +- 5 files changed, 138 insertions(+), 22 deletions(-) diff --git a/api/configs/middleware/vdb/opensearch_config.py b/api/configs/middleware/vdb/opensearch_config.py index 81dde4c04d..96f478e9a6 100644 --- a/api/configs/middleware/vdb/opensearch_config.py +++ b/api/configs/middleware/vdb/opensearch_config.py @@ -1,4 +1,5 @@ -from typing import Optional +import enum +from typing import Literal, Optional from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings @@ -9,6 +10,14 @@ class OpenSearchConfig(BaseSettings): Configuration settings for OpenSearch """ + class AuthMethod(enum.StrEnum): + """ + Authentication method for OpenSearch + """ + + BASIC = "basic" + AWS_MANAGED_IAM = "aws_managed_iam" + OPENSEARCH_HOST: Optional[str] = Field( description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')", default=None, @@ -19,6 +28,16 @@ class OpenSearchConfig(BaseSettings): default=9200, ) + OPENSEARCH_SECURE: bool = Field( + description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)", + default=False, + ) + + OPENSEARCH_AUTH_METHOD: AuthMethod = Field( + description="Authentication method for OpenSearch connection (default is 'basic')", + default=AuthMethod.BASIC, + ) + OPENSEARCH_USER: Optional[str] = Field( description="Username for authenticating with OpenSearch", default=None, @@ -29,7 +48,11 @@ class OpenSearchConfig(BaseSettings): default=None, ) - OPENSEARCH_SECURE: bool = Field( - description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)", - default=False, + OPENSEARCH_AWS_REGION: Optional[str] = Field( + description="AWS region for OpenSearch (e.g. 'us-west-2')", + default=None, + ) + + OPENSEARCH_AWS_SERVICE: Optional[Literal["es", "aoss"]] = Field( + description="AWS service for OpenSearch (e.g. 'aoss' for OpenSearch Serverless)", default=None ) diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 6636646cff..e23b8d197f 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -1,10 +1,9 @@ import json import logging -import ssl -from typing import Any, Optional +from typing import Any, Literal, Optional from uuid import uuid4 -from opensearchpy import OpenSearch, helpers +from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers from opensearchpy.helpers import BulkIndexError from pydantic import BaseModel, model_validator @@ -24,9 +23,12 @@ logger = logging.getLogger(__name__) class OpenSearchConfig(BaseModel): host: str port: int + secure: bool = False + auth_method: Literal["basic", "aws_managed_iam"] = "basic" user: Optional[str] = None password: Optional[str] = None - secure: bool = False + aws_region: Optional[str] = None + aws_service: Optional[str] = None @model_validator(mode="before") @classmethod @@ -35,24 +37,40 @@ class OpenSearchConfig(BaseModel): raise ValueError("config OPENSEARCH_HOST is required") if not values.get("port"): raise ValueError("config OPENSEARCH_PORT is required") + if values.get("auth_method") == "aws_managed_iam": + if not values.get("aws_region"): + raise ValueError("config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method") + if not values.get("aws_service"): + raise ValueError("config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method") return values - def create_ssl_context(self) -> ssl.SSLContext: - ssl_context = ssl.create_default_context() - ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE # Disable Certificate Validation - return ssl_context + def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth: + import boto3 # type: ignore + + return Urllib3AWSV4SignerAuth( + credentials=boto3.Session().get_credentials(), + region=self.aws_region, + service=self.aws_service, # type: ignore[arg-type] + ) def to_opensearch_params(self) -> dict[str, Any]: params = { "hosts": [{"host": self.host, "port": self.port}], "use_ssl": self.secure, "verify_certs": self.secure, + "connection_class": Urllib3HttpConnection, + "pool_maxsize": 20, } - if self.user and self.password: + + if self.auth_method == "basic": + logger.info("Using basic authentication for OpenSearch Vector DB") + params["http_auth"] = (self.user, self.password) - if self.secure: - params["ssl_context"] = self.create_ssl_context() + elif self.auth_method == "aws_managed_iam": + logger.info("Using AWS managed IAM role for OpenSearch Vector DB") + + params["http_auth"] = self.create_aws_managed_iam_auth() + return params @@ -76,16 +94,23 @@ class OpenSearchVector(BaseVector): action = { "_op_type": "index", "_index": self._collection_name.lower(), - "_id": uuid4().hex, "_source": { Field.CONTENT_KEY.value: documents[i].page_content, Field.VECTOR.value: embeddings[i], # Make sure you pass an array here Field.METADATA_KEY.value: documents[i].metadata, }, } + # See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377 + if self._client_config.aws_service not in ["aoss"]: + action["_id"] = uuid4().hex actions.append(action) - helpers.bulk(self._client, actions) + helpers.bulk( + client=self._client, + actions=actions, + timeout=30, + max_retries=3, + ) def get_ids_by_metadata_field(self, key: str, value: str): query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}} @@ -234,6 +259,7 @@ class OpenSearchVector(BaseVector): }, } + logger.info(f"Creating OpenSearch index {self._collection_name.lower()}") self._client.indices.create(index=self._collection_name.lower(), body=index_body) redis_client.set(collection_exist_cache_key, 1, ex=3600) @@ -252,9 +278,12 @@ class OpenSearchVectorFactory(AbstractVectorFactory): open_search_config = OpenSearchConfig( host=dify_config.OPENSEARCH_HOST or "localhost", port=dify_config.OPENSEARCH_PORT, + secure=dify_config.OPENSEARCH_SECURE, + auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value, user=dify_config.OPENSEARCH_USER, password=dify_config.OPENSEARCH_PASSWORD, - secure=dify_config.OPENSEARCH_SECURE, + aws_region=dify_config.OPENSEARCH_AWS_REGION, + aws_service=dify_config.OPENSEARCH_AWS_SERVICE, ) return OpenSearchVector(collection_name=collection_name, config=open_search_config) diff --git a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py index 35eed75c2f..2d44dd2924 100644 --- a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py +++ b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py @@ -23,13 +23,70 @@ def setup_mock_redis(): ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock) +class TestOpenSearchConfig: + def test_to_opensearch_params(self): + config = OpenSearchConfig( + host="localhost", + port=9200, + secure=True, + user="admin", + password="password", + ) + + params = config.to_opensearch_params() + + assert params["hosts"] == [{"host": "localhost", "port": 9200}] + assert params["use_ssl"] is True + assert params["verify_certs"] is True + assert params["connection_class"].__name__ == "Urllib3HttpConnection" + assert params["http_auth"] == ("admin", "password") + + @patch("boto3.Session") + @patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth") + def test_to_opensearch_params_with_aws_managed_iam( + self, mock_aws_signer_auth: MagicMock, mock_boto_session: MagicMock + ): + mock_credentials = MagicMock() + mock_boto_session.return_value.get_credentials.return_value = mock_credentials + + mock_auth_instance = MagicMock() + mock_aws_signer_auth.return_value = mock_auth_instance + + aws_region = "ap-southeast-2" + aws_service = "aoss" + host = f"aoss-endpoint.{aws_region}.aoss.amazonaws.com" + port = 9201 + + config = OpenSearchConfig( + host=host, + port=port, + secure=True, + auth_method="aws_managed_iam", + aws_region=aws_region, + aws_service=aws_service, + ) + + params = config.to_opensearch_params() + + assert params["hosts"] == [{"host": host, "port": port}] + assert params["use_ssl"] is True + assert params["verify_certs"] is True + assert params["connection_class"].__name__ == "Urllib3HttpConnection" + assert params["http_auth"] is mock_auth_instance + + mock_aws_signer_auth.assert_called_once_with( + credentials=mock_credentials, region=aws_region, service=aws_service + ) + assert mock_boto_session.return_value.get_credentials.called + + class TestOpenSearchVector: def setup_method(self): self.collection_name = "test_collection" self.example_doc_id = "example_doc_id" self.vector = OpenSearchVector( collection_name=self.collection_name, - config=OpenSearchConfig(host="localhost", port=9200, user="admin", password="password", secure=False), + config=OpenSearchConfig(host="localhost", port=9200, secure=False, user="admin", password="password"), ) self.vector._client = MagicMock() diff --git a/docker/.env.example b/docker/.env.example index 1adb07ca64..7bff2975fb 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -526,9 +526,13 @@ RELYT_DATABASE=postgres # open search configuration, only available when VECTOR_STORE is `opensearch` OPENSEARCH_HOST=opensearch OPENSEARCH_PORT=9200 +OPENSEARCH_SECURE=true +OPENSEARCH_AUTH_METHOD=basic OPENSEARCH_USER=admin OPENSEARCH_PASSWORD=admin -OPENSEARCH_SECURE=true +# If using AWS managed IAM, e.g. Managed Cluster or OpenSearch Serverless +OPENSEARCH_AWS_REGION=ap-southeast-1 +OPENSEARCH_AWS_SERVICE=aoss # tencent vector configurations, only available when VECTOR_STORE is `tencent` TENCENT_VECTOR_DB_URL=http://127.0.0.1 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 1bf6954299..3ed0f60e96 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -225,9 +225,12 @@ x-shared-env: &shared-api-worker-env RELYT_DATABASE: ${RELYT_DATABASE:-postgres} OPENSEARCH_HOST: ${OPENSEARCH_HOST:-opensearch} OPENSEARCH_PORT: ${OPENSEARCH_PORT:-9200} + OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true} + OPENSEARCH_AUTH_METHOD: ${OPENSEARCH_AUTH_METHOD:-basic} OPENSEARCH_USER: ${OPENSEARCH_USER:-admin} OPENSEARCH_PASSWORD: ${OPENSEARCH_PASSWORD:-admin} - OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true} + OPENSEARCH_AWS_REGION: ${OPENSEARCH_AWS_REGION:-ap-southeast-1} + OPENSEARCH_AWS_SERVICE: ${OPENSEARCH_AWS_SERVICE:-aoss} TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-http://127.0.0.1} TENCENT_VECTOR_DB_API_KEY: ${TENCENT_VECTOR_DB_API_KEY:-dify} TENCENT_VECTOR_DB_TIMEOUT: ${TENCENT_VECTOR_DB_TIMEOUT:-30}