diff --git a/api/.env.example b/api/.env.example index 4e2d76f810..b437beabd4 100644 --- a/api/.env.example +++ b/api/.env.example @@ -72,6 +72,7 @@ VECTOR_STORE=weaviate WEAVIATE_ENDPOINT=http://localhost:8080 WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih WEAVIATE_GRPC_ENABLED=false +WEAVIATE_BATCH_SIZE=100 # Qdrant configuration, use `path:` prefix for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode QDRANT_URL=path:storage/qdrant diff --git a/api/config.py b/api/config.py index bb11a85bf5..64afda88ec 100644 --- a/api/config.py +++ b/api/config.py @@ -43,6 +43,7 @@ DEFAULTS = { 'SENTRY_TRACES_SAMPLE_RATE': 1.0, 'SENTRY_PROFILES_SAMPLE_RATE': 1.0, 'WEAVIATE_GRPC_ENABLED': 'True', + 'WEAVIATE_BATCH_SIZE': 100, 'CELERY_BACKEND': 'database', 'PDF_PREVIEW': 'True', 'LOG_LEVEL': 'INFO', @@ -138,6 +139,7 @@ class Config: self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT') self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY') self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED') + self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE')) # qdrant settings self.QDRANT_URL = get_env('QDRANT_URL') diff --git a/api/core/vector_store/vector_store.py b/api/core/vector_store/vector_store.py index 56b5fd0f97..59a4c5060b 100644 --- a/api/core/vector_store/vector_store.py +++ b/api/core/vector_store/vector_store.py @@ -27,7 +27,8 @@ class VectorStore: self._client = WeaviateVectorStoreClient( endpoint=app.config['WEAVIATE_ENDPOINT'], api_key=app.config['WEAVIATE_API_KEY'], - grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED'] + grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED'], + batch_size=app.config['WEAVIATE_BATCH_SIZE'] ) elif self._vector_store == 'qdrant': self._client = QdrantVectorStoreClient( diff --git a/api/core/vector_store/weaviate_vector_store_client.py b/api/core/vector_store/weaviate_vector_store_client.py index f56162e000..0fe120de71 100644 --- a/api/core/vector_store/weaviate_vector_store_client.py +++ b/api/core/vector_store/weaviate_vector_store_client.py @@ -18,10 +18,10 @@ from llama_index.readers.weaviate.utils import ( class WeaviateVectorStoreClient(BaseVectorStoreClient): - def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool): - self._client = self.init_from_config(endpoint, api_key, grpc_enabled) + def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int): + self._client = self.init_from_config(endpoint, api_key, grpc_enabled, batch_size) - def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool): + def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int): auth_config = weaviate.auth.AuthApiKey(api_key=api_key) weaviate.connect.connection.has_grpc = grpc_enabled @@ -36,7 +36,7 @@ class WeaviateVectorStoreClient(BaseVectorStoreClient): client.batch.configure( # `batch_size` takes an `int` value to enable auto-batching # (`None` is used for manual batching) - batch_size=100, + batch_size=batch_size, # dynamically update the `batch_size` based on import speed dynamic=True, # `timeout_retries` takes an `int` value to retry on time outs