mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-19 04:15:55 +08:00
feat: add AWS Managed IAM auth for OpenSearch vector DB (#18963)
This commit is contained in:
parent
8b4ea01810
commit
8266815cda
@ -1,4 +1,5 @@
|
||||
from typing import Optional
|
||||
import enum
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
@ -9,6 +10,14 @@ class OpenSearchConfig(BaseSettings):
|
||||
Configuration settings for OpenSearch
|
||||
"""
|
||||
|
||||
class AuthMethod(enum.StrEnum):
|
||||
"""
|
||||
Authentication method for OpenSearch
|
||||
"""
|
||||
|
||||
BASIC = "basic"
|
||||
AWS_MANAGED_IAM = "aws_managed_iam"
|
||||
|
||||
OPENSEARCH_HOST: Optional[str] = Field(
|
||||
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
|
||||
default=None,
|
||||
@ -19,6 +28,16 @@ class OpenSearchConfig(BaseSettings):
|
||||
default=9200,
|
||||
)
|
||||
|
||||
OPENSEARCH_SECURE: bool = Field(
|
||||
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
|
||||
default=False,
|
||||
)
|
||||
|
||||
OPENSEARCH_AUTH_METHOD: AuthMethod = Field(
|
||||
description="Authentication method for OpenSearch connection (default is 'basic')",
|
||||
default=AuthMethod.BASIC,
|
||||
)
|
||||
|
||||
OPENSEARCH_USER: Optional[str] = Field(
|
||||
description="Username for authenticating with OpenSearch",
|
||||
default=None,
|
||||
@ -29,7 +48,11 @@ class OpenSearchConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
OPENSEARCH_SECURE: bool = Field(
|
||||
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
|
||||
default=False,
|
||||
OPENSEARCH_AWS_REGION: Optional[str] = Field(
|
||||
description="AWS region for OpenSearch (e.g. 'us-west-2')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
OPENSEARCH_AWS_SERVICE: Optional[Literal["es", "aoss"]] = Field(
|
||||
description="AWS service for OpenSearch (e.g. 'aoss' for OpenSearch Serverless)", default=None
|
||||
)
|
||||
|
@ -1,10 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
import ssl
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Literal, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from opensearchpy import OpenSearch, helpers
|
||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
||||
from opensearchpy.helpers import BulkIndexError
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
@ -24,9 +23,12 @@ logger = logging.getLogger(__name__)
|
||||
class OpenSearchConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
secure: bool = False
|
||||
auth_method: Literal["basic", "aws_managed_iam"] = "basic"
|
||||
user: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
secure: bool = False
|
||||
aws_region: Optional[str] = None
|
||||
aws_service: Optional[str] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@ -35,24 +37,40 @@ class OpenSearchConfig(BaseModel):
|
||||
raise ValueError("config OPENSEARCH_HOST is required")
|
||||
if not values.get("port"):
|
||||
raise ValueError("config OPENSEARCH_PORT is required")
|
||||
if values.get("auth_method") == "aws_managed_iam":
|
||||
if not values.get("aws_region"):
|
||||
raise ValueError("config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method")
|
||||
if not values.get("aws_service"):
|
||||
raise ValueError("config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method")
|
||||
return values
|
||||
|
||||
def create_ssl_context(self) -> ssl.SSLContext:
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = ssl.CERT_NONE # Disable Certificate Validation
|
||||
return ssl_context
|
||||
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
|
||||
import boto3 # type: ignore
|
||||
|
||||
return Urllib3AWSV4SignerAuth(
|
||||
credentials=boto3.Session().get_credentials(),
|
||||
region=self.aws_region,
|
||||
service=self.aws_service, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def to_opensearch_params(self) -> dict[str, Any]:
|
||||
params = {
|
||||
"hosts": [{"host": self.host, "port": self.port}],
|
||||
"use_ssl": self.secure,
|
||||
"verify_certs": self.secure,
|
||||
"connection_class": Urllib3HttpConnection,
|
||||
"pool_maxsize": 20,
|
||||
}
|
||||
if self.user and self.password:
|
||||
|
||||
if self.auth_method == "basic":
|
||||
logger.info("Using basic authentication for OpenSearch Vector DB")
|
||||
|
||||
params["http_auth"] = (self.user, self.password)
|
||||
if self.secure:
|
||||
params["ssl_context"] = self.create_ssl_context()
|
||||
elif self.auth_method == "aws_managed_iam":
|
||||
logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
|
||||
|
||||
params["http_auth"] = self.create_aws_managed_iam_auth()
|
||||
|
||||
return params
|
||||
|
||||
|
||||
@ -76,16 +94,23 @@ class OpenSearchVector(BaseVector):
|
||||
action = {
|
||||
"_op_type": "index",
|
||||
"_index": self._collection_name.lower(),
|
||||
"_id": uuid4().hex,
|
||||
"_source": {
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
|
||||
Field.METADATA_KEY.value: documents[i].metadata,
|
||||
},
|
||||
}
|
||||
# See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
|
||||
if self._client_config.aws_service not in ["aoss"]:
|
||||
action["_id"] = uuid4().hex
|
||||
actions.append(action)
|
||||
|
||||
helpers.bulk(self._client, actions)
|
||||
helpers.bulk(
|
||||
client=self._client,
|
||||
actions=actions,
|
||||
timeout=30,
|
||||
max_retries=3,
|
||||
)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
|
||||
@ -234,6 +259,7 @@ class OpenSearchVector(BaseVector):
|
||||
},
|
||||
}
|
||||
|
||||
logger.info(f"Creating OpenSearch index {self._collection_name.lower()}")
|
||||
self._client.indices.create(index=self._collection_name.lower(), body=index_body)
|
||||
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
@ -252,9 +278,12 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
|
||||
open_search_config = OpenSearchConfig(
|
||||
host=dify_config.OPENSEARCH_HOST or "localhost",
|
||||
port=dify_config.OPENSEARCH_PORT,
|
||||
secure=dify_config.OPENSEARCH_SECURE,
|
||||
auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
|
||||
user=dify_config.OPENSEARCH_USER,
|
||||
password=dify_config.OPENSEARCH_PASSWORD,
|
||||
secure=dify_config.OPENSEARCH_SECURE,
|
||||
aws_region=dify_config.OPENSEARCH_AWS_REGION,
|
||||
aws_service=dify_config.OPENSEARCH_AWS_SERVICE,
|
||||
)
|
||||
|
||||
return OpenSearchVector(collection_name=collection_name, config=open_search_config)
|
||||
|
@ -23,13 +23,70 @@ def setup_mock_redis():
|
||||
ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock)
|
||||
|
||||
|
||||
class TestOpenSearchConfig:
|
||||
def test_to_opensearch_params(self):
|
||||
config = OpenSearchConfig(
|
||||
host="localhost",
|
||||
port=9200,
|
||||
secure=True,
|
||||
user="admin",
|
||||
password="password",
|
||||
)
|
||||
|
||||
params = config.to_opensearch_params()
|
||||
|
||||
assert params["hosts"] == [{"host": "localhost", "port": 9200}]
|
||||
assert params["use_ssl"] is True
|
||||
assert params["verify_certs"] is True
|
||||
assert params["connection_class"].__name__ == "Urllib3HttpConnection"
|
||||
assert params["http_auth"] == ("admin", "password")
|
||||
|
||||
@patch("boto3.Session")
|
||||
@patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth")
|
||||
def test_to_opensearch_params_with_aws_managed_iam(
|
||||
self, mock_aws_signer_auth: MagicMock, mock_boto_session: MagicMock
|
||||
):
|
||||
mock_credentials = MagicMock()
|
||||
mock_boto_session.return_value.get_credentials.return_value = mock_credentials
|
||||
|
||||
mock_auth_instance = MagicMock()
|
||||
mock_aws_signer_auth.return_value = mock_auth_instance
|
||||
|
||||
aws_region = "ap-southeast-2"
|
||||
aws_service = "aoss"
|
||||
host = f"aoss-endpoint.{aws_region}.aoss.amazonaws.com"
|
||||
port = 9201
|
||||
|
||||
config = OpenSearchConfig(
|
||||
host=host,
|
||||
port=port,
|
||||
secure=True,
|
||||
auth_method="aws_managed_iam",
|
||||
aws_region=aws_region,
|
||||
aws_service=aws_service,
|
||||
)
|
||||
|
||||
params = config.to_opensearch_params()
|
||||
|
||||
assert params["hosts"] == [{"host": host, "port": port}]
|
||||
assert params["use_ssl"] is True
|
||||
assert params["verify_certs"] is True
|
||||
assert params["connection_class"].__name__ == "Urllib3HttpConnection"
|
||||
assert params["http_auth"] is mock_auth_instance
|
||||
|
||||
mock_aws_signer_auth.assert_called_once_with(
|
||||
credentials=mock_credentials, region=aws_region, service=aws_service
|
||||
)
|
||||
assert mock_boto_session.return_value.get_credentials.called
|
||||
|
||||
|
||||
class TestOpenSearchVector:
|
||||
def setup_method(self):
|
||||
self.collection_name = "test_collection"
|
||||
self.example_doc_id = "example_doc_id"
|
||||
self.vector = OpenSearchVector(
|
||||
collection_name=self.collection_name,
|
||||
config=OpenSearchConfig(host="localhost", port=9200, user="admin", password="password", secure=False),
|
||||
config=OpenSearchConfig(host="localhost", port=9200, secure=False, user="admin", password="password"),
|
||||
)
|
||||
self.vector._client = MagicMock()
|
||||
|
||||
|
@ -526,9 +526,13 @@ RELYT_DATABASE=postgres
|
||||
# open search configuration, only available when VECTOR_STORE is `opensearch`
|
||||
OPENSEARCH_HOST=opensearch
|
||||
OPENSEARCH_PORT=9200
|
||||
OPENSEARCH_SECURE=true
|
||||
OPENSEARCH_AUTH_METHOD=basic
|
||||
OPENSEARCH_USER=admin
|
||||
OPENSEARCH_PASSWORD=admin
|
||||
OPENSEARCH_SECURE=true
|
||||
# If using AWS managed IAM, e.g. Managed Cluster or OpenSearch Serverless
|
||||
OPENSEARCH_AWS_REGION=ap-southeast-1
|
||||
OPENSEARCH_AWS_SERVICE=aoss
|
||||
|
||||
# tencent vector configurations, only available when VECTOR_STORE is `tencent`
|
||||
TENCENT_VECTOR_DB_URL=http://127.0.0.1
|
||||
|
@ -225,9 +225,12 @@ x-shared-env: &shared-api-worker-env
|
||||
RELYT_DATABASE: ${RELYT_DATABASE:-postgres}
|
||||
OPENSEARCH_HOST: ${OPENSEARCH_HOST:-opensearch}
|
||||
OPENSEARCH_PORT: ${OPENSEARCH_PORT:-9200}
|
||||
OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true}
|
||||
OPENSEARCH_AUTH_METHOD: ${OPENSEARCH_AUTH_METHOD:-basic}
|
||||
OPENSEARCH_USER: ${OPENSEARCH_USER:-admin}
|
||||
OPENSEARCH_PASSWORD: ${OPENSEARCH_PASSWORD:-admin}
|
||||
OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true}
|
||||
OPENSEARCH_AWS_REGION: ${OPENSEARCH_AWS_REGION:-ap-southeast-1}
|
||||
OPENSEARCH_AWS_SERVICE: ${OPENSEARCH_AWS_SERVICE:-aoss}
|
||||
TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-http://127.0.0.1}
|
||||
TENCENT_VECTOR_DB_API_KEY: ${TENCENT_VECTOR_DB_API_KEY:-dify}
|
||||
TENCENT_VECTOR_DB_TIMEOUT: ${TENCENT_VECTOR_DB_TIMEOUT:-30}
|
||||
|
Loading…
x
Reference in New Issue
Block a user