feat: add AWS Managed IAM auth for OpenSearch vector DB (#18963)

This commit is contained in:
Ahmad Zidan 2025-04-29 14:10:08 +07:00 committed by GitHub
parent 8b4ea01810
commit 8266815cda
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 138 additions and 22 deletions

View File

@ -1,4 +1,5 @@
from typing import Optional
import enum
from typing import Literal, Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
@ -9,6 +10,14 @@ class OpenSearchConfig(BaseSettings):
Configuration settings for OpenSearch
"""
class AuthMethod(enum.StrEnum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
OPENSEARCH_HOST: Optional[str] = Field(
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
default=None,
@ -19,6 +28,16 @@ class OpenSearchConfig(BaseSettings):
default=9200,
)
OPENSEARCH_SECURE: bool = Field(
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
default=False,
)
OPENSEARCH_AUTH_METHOD: AuthMethod = Field(
description="Authentication method for OpenSearch connection (default is 'basic')",
default=AuthMethod.BASIC,
)
OPENSEARCH_USER: Optional[str] = Field(
description="Username for authenticating with OpenSearch",
default=None,
@ -29,7 +48,11 @@ class OpenSearchConfig(BaseSettings):
default=None,
)
OPENSEARCH_SECURE: bool = Field(
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
default=False,
OPENSEARCH_AWS_REGION: Optional[str] = Field(
description="AWS region for OpenSearch (e.g. 'us-west-2')",
default=None,
)
OPENSEARCH_AWS_SERVICE: Optional[Literal["es", "aoss"]] = Field(
description="AWS service for OpenSearch (e.g. 'aoss' for OpenSearch Serverless)", default=None
)

View File

@ -1,10 +1,9 @@
import json
import logging
import ssl
from typing import Any, Optional
from typing import Any, Literal, Optional
from uuid import uuid4
from opensearchpy import OpenSearch, helpers
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
@ -24,9 +23,12 @@ logger = logging.getLogger(__name__)
class OpenSearchConfig(BaseModel):
host: str
port: int
secure: bool = False
auth_method: Literal["basic", "aws_managed_iam"] = "basic"
user: Optional[str] = None
password: Optional[str] = None
secure: bool = False
aws_region: Optional[str] = None
aws_service: Optional[str] = None
@model_validator(mode="before")
@classmethod
@ -35,24 +37,40 @@ class OpenSearchConfig(BaseModel):
raise ValueError("config OPENSEARCH_HOST is required")
if not values.get("port"):
raise ValueError("config OPENSEARCH_PORT is required")
if values.get("auth_method") == "aws_managed_iam":
if not values.get("aws_region"):
raise ValueError("config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method")
if not values.get("aws_service"):
raise ValueError("config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method")
return values
def create_ssl_context(self) -> ssl.SSLContext:
    """Return an SSL context that performs no certificate validation.

    The context starts from ``ssl.create_default_context()`` and then turns
    off both hostname checking and certificate verification.

    Returns:
        An ``ssl.SSLContext`` with ``check_hostname`` set to ``False`` and
        ``verify_mode`` set to ``ssl.CERT_NONE``.
    """
    unverified_context = ssl.create_default_context()
    # Hostname checking has to be switched off first: the ssl module rejects
    # CERT_NONE while check_hostname is still enabled.
    unverified_context.check_hostname = False
    # Disable certificate validation entirely.
    unverified_context.verify_mode = ssl.CERT_NONE
    return unverified_context
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
    """Build the AWS SigV4 request signer used as ``http_auth``.

    Credentials are taken from a default ``boto3.Session()``; the signer is
    scoped to ``self.aws_region`` and ``self.aws_service``.

    Returns:
        A ``Urllib3AWSV4SignerAuth`` instance.
    """
    # Imported lazily so boto3 is only needed when this auth method is used.
    import boto3  # type: ignore

    aws_session = boto3.Session()
    aws_credentials = aws_session.get_credentials()
    return Urllib3AWSV4SignerAuth(
        credentials=aws_credentials,
        region=self.aws_region,
        service=self.aws_service,  # type: ignore[arg-type]
    )
def to_opensearch_params(self) -> dict[str, Any]:
    """Build the keyword arguments used to construct the OpenSearch client.

    Returns:
        A dict containing the host list, the TLS flags (``use_ssl`` and
        ``verify_certs`` both follow ``self.secure``), the connection class
        and the pool size. An ``http_auth`` entry is added according to
        ``self.auth_method``:

        * ``"basic"``: a ``(user, password)`` tuple, only when both values
          are set; otherwise no ``http_auth`` entry is added.
        * ``"aws_managed_iam"``: the signer returned by
          ``create_aws_managed_iam_auth()``.
    """
    params: dict[str, Any] = {
        "hosts": [{"host": self.host, "port": self.port}],
        "use_ssl": self.secure,
        "verify_certs": self.secure,
        "connection_class": Urllib3HttpConnection,
        "pool_maxsize": 20,
    }

    if self.auth_method == "basic":
        # user and password are optional fields; passing a tuple that
        # contains None as http_auth would hand the client unusable
        # credentials, so only attach them when both are present.
        if self.user and self.password:
            logger.info("Using basic authentication for OpenSearch Vector DB")
            params["http_auth"] = (self.user, self.password)
    elif self.auth_method == "aws_managed_iam":
        logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
        params["http_auth"] = self.create_aws_managed_iam_auth()

    return params
@ -76,16 +94,23 @@ class OpenSearchVector(BaseVector):
action = {
"_op_type": "index",
"_index": self._collection_name.lower(),
"_id": uuid4().hex,
"_source": {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
},
}
# See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
if self._client_config.aws_service not in ["aoss"]:
action["_id"] = uuid4().hex
actions.append(action)
helpers.bulk(self._client, actions)
helpers.bulk(
client=self._client,
actions=actions,
timeout=30,
max_retries=3,
)
def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
@ -234,6 +259,7 @@ class OpenSearchVector(BaseVector):
},
}
logger.info(f"Creating OpenSearch index {self._collection_name.lower()}")
self._client.indices.create(index=self._collection_name.lower(), body=index_body)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
@ -252,9 +278,12 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
open_search_config = OpenSearchConfig(
host=dify_config.OPENSEARCH_HOST or "localhost",
port=dify_config.OPENSEARCH_PORT,
secure=dify_config.OPENSEARCH_SECURE,
auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD,
secure=dify_config.OPENSEARCH_SECURE,
aws_region=dify_config.OPENSEARCH_AWS_REGION,
aws_service=dify_config.OPENSEARCH_AWS_SERVICE,
)
return OpenSearchVector(collection_name=collection_name, config=open_search_config)

View File

@ -23,13 +23,70 @@ def setup_mock_redis():
ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock)
class TestOpenSearchConfig:
    """Tests for the client parameters produced by ``OpenSearchConfig``."""

    def test_to_opensearch_params(self):
        """Basic auth forwards user/password as an ``http_auth`` tuple."""
        basic_config = OpenSearchConfig(
            host="localhost",
            port=9200,
            secure=True,
            user="admin",
            password="password",
        )

        client_kwargs = basic_config.to_opensearch_params()

        assert client_kwargs["hosts"] == [{"host": "localhost", "port": 9200}]
        assert client_kwargs["use_ssl"] is True
        assert client_kwargs["verify_certs"] is True
        assert client_kwargs["connection_class"].__name__ == "Urllib3HttpConnection"
        assert client_kwargs["http_auth"] == ("admin", "password")

    @patch("boto3.Session")
    @patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth")
    def test_to_opensearch_params_with_aws_managed_iam(
        self, mock_aws_signer_auth: MagicMock, mock_boto_session: MagicMock
    ):
        """AWS managed IAM auth uses a signer built from the boto3 session credentials."""
        # Arrange: stub the signer class and the credentials boto3 would return.
        fake_signer = MagicMock()
        mock_aws_signer_auth.return_value = fake_signer
        fake_credentials = MagicMock()
        mock_boto_session.return_value.get_credentials.return_value = fake_credentials

        aws_region = "ap-southeast-2"
        aws_service = "aoss"
        host = f"aoss-endpoint.{aws_region}.aoss.amazonaws.com"
        port = 9201

        iam_config = OpenSearchConfig(
            host=host,
            port=port,
            secure=True,
            auth_method="aws_managed_iam",
            aws_region=aws_region,
            aws_service=aws_service,
        )

        # Act
        client_kwargs = iam_config.to_opensearch_params()

        # Assert: connection settings are passed through unchanged ...
        assert client_kwargs["hosts"] == [{"host": host, "port": port}]
        assert client_kwargs["use_ssl"] is True
        assert client_kwargs["verify_certs"] is True
        assert client_kwargs["connection_class"].__name__ == "Urllib3HttpConnection"
        # ... and the signer instance is what ends up as http_auth.
        assert client_kwargs["http_auth"] is fake_signer

        mock_aws_signer_auth.assert_called_once_with(
            credentials=fake_credentials, region=aws_region, service=aws_service
        )
        assert mock_boto_session.return_value.get_credentials.called
class TestOpenSearchVector:
def setup_method(self):
self.collection_name = "test_collection"
self.example_doc_id = "example_doc_id"
self.vector = OpenSearchVector(
collection_name=self.collection_name,
config=OpenSearchConfig(host="localhost", port=9200, user="admin", password="password", secure=False),
config=OpenSearchConfig(host="localhost", port=9200, secure=False, user="admin", password="password"),
)
self.vector._client = MagicMock()

View File

@ -526,9 +526,13 @@ RELYT_DATABASE=postgres
# open search configuration, only available when VECTOR_STORE is `opensearch`
OPENSEARCH_HOST=opensearch
OPENSEARCH_PORT=9200
OPENSEARCH_SECURE=true
OPENSEARCH_AUTH_METHOD=basic
OPENSEARCH_USER=admin
OPENSEARCH_PASSWORD=admin
OPENSEARCH_SECURE=true
# Required only when OPENSEARCH_AUTH_METHOD=aws_managed_iam, e.g. for an AWS Managed Cluster or OpenSearch Serverless
OPENSEARCH_AWS_REGION=ap-southeast-1
OPENSEARCH_AWS_SERVICE=aoss
# tencent vector configurations, only available when VECTOR_STORE is `tencent`
TENCENT_VECTOR_DB_URL=http://127.0.0.1

View File

@ -225,9 +225,12 @@ x-shared-env: &shared-api-worker-env
RELYT_DATABASE: ${RELYT_DATABASE:-postgres}
OPENSEARCH_HOST: ${OPENSEARCH_HOST:-opensearch}
OPENSEARCH_PORT: ${OPENSEARCH_PORT:-9200}
OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true}
OPENSEARCH_AUTH_METHOD: ${OPENSEARCH_AUTH_METHOD:-basic}
OPENSEARCH_USER: ${OPENSEARCH_USER:-admin}
OPENSEARCH_PASSWORD: ${OPENSEARCH_PASSWORD:-admin}
OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true}
OPENSEARCH_AWS_REGION: ${OPENSEARCH_AWS_REGION:-ap-southeast-1}
OPENSEARCH_AWS_SERVICE: ${OPENSEARCH_AWS_SERVICE:-aoss}
TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-http://127.0.0.1}
TENCENT_VECTOR_DB_API_KEY: ${TENCENT_VECTOR_DB_API_KEY:-dify}
TENCENT_VECTOR_DB_TIMEOUT: ${TENCENT_VECTOR_DB_TIMEOUT:-30}