mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 23:55:52 +08:00
feat: add AWS Managed IAM auth for OpenSearch vector DB (#18963)
This commit is contained in:
parent
8b4ea01810
commit
8266815cda
@ -1,4 +1,5 @@
|
|||||||
from typing import Optional
|
import enum
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
from pydantic import Field, PositiveInt
|
from pydantic import Field, PositiveInt
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
@ -9,6 +10,14 @@ class OpenSearchConfig(BaseSettings):
|
|||||||
Configuration settings for OpenSearch
|
Configuration settings for OpenSearch
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
class AuthMethod(enum.StrEnum):
|
||||||
|
"""
|
||||||
|
Authentication method for OpenSearch
|
||||||
|
"""
|
||||||
|
|
||||||
|
BASIC = "basic"
|
||||||
|
AWS_MANAGED_IAM = "aws_managed_iam"
|
||||||
|
|
||||||
OPENSEARCH_HOST: Optional[str] = Field(
|
OPENSEARCH_HOST: Optional[str] = Field(
|
||||||
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
|
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
|
||||||
default=None,
|
default=None,
|
||||||
@ -19,6 +28,16 @@ class OpenSearchConfig(BaseSettings):
|
|||||||
default=9200,
|
default=9200,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
OPENSEARCH_SECURE: bool = Field(
|
||||||
|
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
OPENSEARCH_AUTH_METHOD: AuthMethod = Field(
|
||||||
|
description="Authentication method for OpenSearch connection (default is 'basic')",
|
||||||
|
default=AuthMethod.BASIC,
|
||||||
|
)
|
||||||
|
|
||||||
OPENSEARCH_USER: Optional[str] = Field(
|
OPENSEARCH_USER: Optional[str] = Field(
|
||||||
description="Username for authenticating with OpenSearch",
|
description="Username for authenticating with OpenSearch",
|
||||||
default=None,
|
default=None,
|
||||||
@ -29,7 +48,11 @@ class OpenSearchConfig(BaseSettings):
|
|||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
OPENSEARCH_SECURE: bool = Field(
|
OPENSEARCH_AWS_REGION: Optional[str] = Field(
|
||||||
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
|
description="AWS region for OpenSearch (e.g. 'us-west-2')",
|
||||||
default=False,
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
OPENSEARCH_AWS_SERVICE: Optional[Literal["es", "aoss"]] = Field(
|
||||||
|
description="AWS service for OpenSearch (e.g. 'aoss' for OpenSearch Serverless)", default=None
|
||||||
)
|
)
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import ssl
|
from typing import Any, Literal, Optional
|
||||||
from typing import Any, Optional
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from opensearchpy import OpenSearch, helpers
|
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
||||||
from opensearchpy.helpers import BulkIndexError
|
from opensearchpy.helpers import BulkIndexError
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
@ -24,9 +23,12 @@ logger = logging.getLogger(__name__)
|
|||||||
class OpenSearchConfig(BaseModel):
|
class OpenSearchConfig(BaseModel):
|
||||||
host: str
|
host: str
|
||||||
port: int
|
port: int
|
||||||
|
secure: bool = False
|
||||||
|
auth_method: Literal["basic", "aws_managed_iam"] = "basic"
|
||||||
user: Optional[str] = None
|
user: Optional[str] = None
|
||||||
password: Optional[str] = None
|
password: Optional[str] = None
|
||||||
secure: bool = False
|
aws_region: Optional[str] = None
|
||||||
|
aws_service: Optional[str] = None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -35,24 +37,40 @@ class OpenSearchConfig(BaseModel):
|
|||||||
raise ValueError("config OPENSEARCH_HOST is required")
|
raise ValueError("config OPENSEARCH_HOST is required")
|
||||||
if not values.get("port"):
|
if not values.get("port"):
|
||||||
raise ValueError("config OPENSEARCH_PORT is required")
|
raise ValueError("config OPENSEARCH_PORT is required")
|
||||||
|
if values.get("auth_method") == "aws_managed_iam":
|
||||||
|
if not values.get("aws_region"):
|
||||||
|
raise ValueError("config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method")
|
||||||
|
if not values.get("aws_service"):
|
||||||
|
raise ValueError("config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def create_ssl_context(self) -> ssl.SSLContext:
|
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
|
||||||
ssl_context = ssl.create_default_context()
|
import boto3 # type: ignore
|
||||||
ssl_context.check_hostname = False
|
|
||||||
ssl_context.verify_mode = ssl.CERT_NONE # Disable Certificate Validation
|
return Urllib3AWSV4SignerAuth(
|
||||||
return ssl_context
|
credentials=boto3.Session().get_credentials(),
|
||||||
|
region=self.aws_region,
|
||||||
|
service=self.aws_service, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
def to_opensearch_params(self) -> dict[str, Any]:
|
def to_opensearch_params(self) -> dict[str, Any]:
|
||||||
params = {
|
params = {
|
||||||
"hosts": [{"host": self.host, "port": self.port}],
|
"hosts": [{"host": self.host, "port": self.port}],
|
||||||
"use_ssl": self.secure,
|
"use_ssl": self.secure,
|
||||||
"verify_certs": self.secure,
|
"verify_certs": self.secure,
|
||||||
|
"connection_class": Urllib3HttpConnection,
|
||||||
|
"pool_maxsize": 20,
|
||||||
}
|
}
|
||||||
if self.user and self.password:
|
|
||||||
|
if self.auth_method == "basic":
|
||||||
|
logger.info("Using basic authentication for OpenSearch Vector DB")
|
||||||
|
|
||||||
params["http_auth"] = (self.user, self.password)
|
params["http_auth"] = (self.user, self.password)
|
||||||
if self.secure:
|
elif self.auth_method == "aws_managed_iam":
|
||||||
params["ssl_context"] = self.create_ssl_context()
|
logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
|
||||||
|
|
||||||
|
params["http_auth"] = self.create_aws_managed_iam_auth()
|
||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
@ -76,16 +94,23 @@ class OpenSearchVector(BaseVector):
|
|||||||
action = {
|
action = {
|
||||||
"_op_type": "index",
|
"_op_type": "index",
|
||||||
"_index": self._collection_name.lower(),
|
"_index": self._collection_name.lower(),
|
||||||
"_id": uuid4().hex,
|
|
||||||
"_source": {
|
"_source": {
|
||||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||||
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
|
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
|
||||||
Field.METADATA_KEY.value: documents[i].metadata,
|
Field.METADATA_KEY.value: documents[i].metadata,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
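# OpenSearch Serverless (`aoss`) reportedly rejects client-supplied document IDs on index requests, so `_id` is set only for other services; see https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377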
|
||||||
|
if self._client_config.aws_service not in ["aoss"]:
|
||||||
|
action["_id"] = uuid4().hex
|
||||||
actions.append(action)
|
actions.append(action)
|
||||||
|
|
||||||
helpers.bulk(self._client, actions)
|
helpers.bulk(
|
||||||
|
client=self._client,
|
||||||
|
actions=actions,
|
||||||
|
timeout=30,
|
||||||
|
max_retries=3,
|
||||||
|
)
|
||||||
|
|
||||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||||
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
|
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
|
||||||
@ -234,6 +259,7 @@ class OpenSearchVector(BaseVector):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.info(f"Creating OpenSearch index {self._collection_name.lower()}")
|
||||||
self._client.indices.create(index=self._collection_name.lower(), body=index_body)
|
self._client.indices.create(index=self._collection_name.lower(), body=index_body)
|
||||||
|
|
||||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||||
@ -252,9 +278,12 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
|
|||||||
open_search_config = OpenSearchConfig(
|
open_search_config = OpenSearchConfig(
|
||||||
host=dify_config.OPENSEARCH_HOST or "localhost",
|
host=dify_config.OPENSEARCH_HOST or "localhost",
|
||||||
port=dify_config.OPENSEARCH_PORT,
|
port=dify_config.OPENSEARCH_PORT,
|
||||||
|
secure=dify_config.OPENSEARCH_SECURE,
|
||||||
|
auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
|
||||||
user=dify_config.OPENSEARCH_USER,
|
user=dify_config.OPENSEARCH_USER,
|
||||||
password=dify_config.OPENSEARCH_PASSWORD,
|
password=dify_config.OPENSEARCH_PASSWORD,
|
||||||
secure=dify_config.OPENSEARCH_SECURE,
|
aws_region=dify_config.OPENSEARCH_AWS_REGION,
|
||||||
|
aws_service=dify_config.OPENSEARCH_AWS_SERVICE,
|
||||||
)
|
)
|
||||||
|
|
||||||
return OpenSearchVector(collection_name=collection_name, config=open_search_config)
|
return OpenSearchVector(collection_name=collection_name, config=open_search_config)
|
||||||
|
@ -23,13 +23,70 @@ def setup_mock_redis():
|
|||||||
ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock)
|
ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenSearchConfig:
|
||||||
|
def test_to_opensearch_params(self):
|
||||||
|
config = OpenSearchConfig(
|
||||||
|
host="localhost",
|
||||||
|
port=9200,
|
||||||
|
secure=True,
|
||||||
|
user="admin",
|
||||||
|
password="password",
|
||||||
|
)
|
||||||
|
|
||||||
|
params = config.to_opensearch_params()
|
||||||
|
|
||||||
|
assert params["hosts"] == [{"host": "localhost", "port": 9200}]
|
||||||
|
assert params["use_ssl"] is True
|
||||||
|
assert params["verify_certs"] is True
|
||||||
|
assert params["connection_class"].__name__ == "Urllib3HttpConnection"
|
||||||
|
assert params["http_auth"] == ("admin", "password")
|
||||||
|
|
||||||
|
@patch("boto3.Session")
|
||||||
|
@patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth")
|
||||||
|
def test_to_opensearch_params_with_aws_managed_iam(
|
||||||
|
self, mock_aws_signer_auth: MagicMock, mock_boto_session: MagicMock
|
||||||
|
):
|
||||||
|
mock_credentials = MagicMock()
|
||||||
|
mock_boto_session.return_value.get_credentials.return_value = mock_credentials
|
||||||
|
|
||||||
|
mock_auth_instance = MagicMock()
|
||||||
|
mock_aws_signer_auth.return_value = mock_auth_instance
|
||||||
|
|
||||||
|
aws_region = "ap-southeast-2"
|
||||||
|
aws_service = "aoss"
|
||||||
|
host = f"aoss-endpoint.{aws_region}.aoss.amazonaws.com"
|
||||||
|
port = 9201
|
||||||
|
|
||||||
|
config = OpenSearchConfig(
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
secure=True,
|
||||||
|
auth_method="aws_managed_iam",
|
||||||
|
aws_region=aws_region,
|
||||||
|
aws_service=aws_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
params = config.to_opensearch_params()
|
||||||
|
|
||||||
|
assert params["hosts"] == [{"host": host, "port": port}]
|
||||||
|
assert params["use_ssl"] is True
|
||||||
|
assert params["verify_certs"] is True
|
||||||
|
assert params["connection_class"].__name__ == "Urllib3HttpConnection"
|
||||||
|
assert params["http_auth"] is mock_auth_instance
|
||||||
|
|
||||||
|
mock_aws_signer_auth.assert_called_once_with(
|
||||||
|
credentials=mock_credentials, region=aws_region, service=aws_service
|
||||||
|
)
|
||||||
|
assert mock_boto_session.return_value.get_credentials.called
|
||||||
|
|
||||||
|
|
||||||
class TestOpenSearchVector:
|
class TestOpenSearchVector:
|
||||||
def setup_method(self):
|
def setup_method(self):
|
||||||
self.collection_name = "test_collection"
|
self.collection_name = "test_collection"
|
||||||
self.example_doc_id = "example_doc_id"
|
self.example_doc_id = "example_doc_id"
|
||||||
self.vector = OpenSearchVector(
|
self.vector = OpenSearchVector(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
config=OpenSearchConfig(host="localhost", port=9200, user="admin", password="password", secure=False),
|
config=OpenSearchConfig(host="localhost", port=9200, secure=False, user="admin", password="password"),
|
||||||
)
|
)
|
||||||
self.vector._client = MagicMock()
|
self.vector._client = MagicMock()
|
||||||
|
|
||||||
|
@ -526,9 +526,13 @@ RELYT_DATABASE=postgres
|
|||||||
# open search configuration, only available when VECTOR_STORE is `opensearch`
|
# open search configuration, only available when VECTOR_STORE is `opensearch`
|
||||||
OPENSEARCH_HOST=opensearch
|
OPENSEARCH_HOST=opensearch
|
||||||
OPENSEARCH_PORT=9200
|
OPENSEARCH_PORT=9200
|
||||||
|
OPENSEARCH_SECURE=true
|
||||||
|
OPENSEARCH_AUTH_METHOD=basic
|
||||||
OPENSEARCH_USER=admin
|
OPENSEARCH_USER=admin
|
||||||
OPENSEARCH_PASSWORD=admin
|
OPENSEARCH_PASSWORD=admin
|
||||||
OPENSEARCH_SECURE=true
|
# Required only when OPENSEARCH_AUTH_METHOD=aws_managed_iam; OPENSEARCH_AWS_SERVICE is `es` for a managed cluster or `aoss` for OpenSearch Serverless
|
||||||
|
OPENSEARCH_AWS_REGION=ap-southeast-1
|
||||||
|
OPENSEARCH_AWS_SERVICE=aoss
|
||||||
|
|
||||||
# tencent vector configurations, only available when VECTOR_STORE is `tencent`
|
# tencent vector configurations, only available when VECTOR_STORE is `tencent`
|
||||||
TENCENT_VECTOR_DB_URL=http://127.0.0.1
|
TENCENT_VECTOR_DB_URL=http://127.0.0.1
|
||||||
|
@ -225,9 +225,12 @@ x-shared-env: &shared-api-worker-env
|
|||||||
RELYT_DATABASE: ${RELYT_DATABASE:-postgres}
|
RELYT_DATABASE: ${RELYT_DATABASE:-postgres}
|
||||||
OPENSEARCH_HOST: ${OPENSEARCH_HOST:-opensearch}
|
OPENSEARCH_HOST: ${OPENSEARCH_HOST:-opensearch}
|
||||||
OPENSEARCH_PORT: ${OPENSEARCH_PORT:-9200}
|
OPENSEARCH_PORT: ${OPENSEARCH_PORT:-9200}
|
||||||
|
OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true}
|
||||||
|
OPENSEARCH_AUTH_METHOD: ${OPENSEARCH_AUTH_METHOD:-basic}
|
||||||
OPENSEARCH_USER: ${OPENSEARCH_USER:-admin}
|
OPENSEARCH_USER: ${OPENSEARCH_USER:-admin}
|
||||||
OPENSEARCH_PASSWORD: ${OPENSEARCH_PASSWORD:-admin}
|
OPENSEARCH_PASSWORD: ${OPENSEARCH_PASSWORD:-admin}
|
||||||
OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true}
|
OPENSEARCH_AWS_REGION: ${OPENSEARCH_AWS_REGION:-ap-southeast-1}
|
||||||
|
OPENSEARCH_AWS_SERVICE: ${OPENSEARCH_AWS_SERVICE:-aoss}
|
||||||
TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-http://127.0.0.1}
|
TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-http://127.0.0.1}
|
||||||
TENCENT_VECTOR_DB_API_KEY: ${TENCENT_VECTOR_DB_API_KEY:-dify}
|
TENCENT_VECTOR_DB_API_KEY: ${TENCENT_VECTOR_DB_API_KEY:-dify}
|
||||||
TENCENT_VECTOR_DB_TIMEOUT: ${TENCENT_VECTOR_DB_TIMEOUT:-30}
|
TENCENT_VECTOR_DB_TIMEOUT: ${TENCENT_VECTOR_DB_TIMEOUT:-30}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user