From 18106a4fc6c099b45360af0ae33ff01f021e77e7 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Fri, 25 Oct 2024 13:57:03 +0800 Subject: [PATCH] add tidb on qdrant type (#9831) Co-authored-by: Zhaofeng Miao <522856232@qq.com> --- api/configs/feature/__init__.py | 5 + api/configs/middleware/__init__.py | 7 + .../middleware/vdb/tidb_on_qdrant_config.py | 65 +++ api/controllers/console/datasets/datasets.py | 1 + .../datasource/vdb/tidb_on_qdrant/__init__.py | 0 .../vdb/tidb_on_qdrant/tidb_entities.py | 17 + .../tidb_on_qdrant/tidb_on_qdrant_vector.py | 526 ++++++++++++++++++ .../vdb/tidb_on_qdrant/tidb_service.py | 250 +++++++++ api/core/rag/datasource/vdb/vector_factory.py | 17 +- api/core/rag/datasource/vdb/vector_type.py | 1 + api/extensions/ext_celery.py | 11 + ...0956-0251a1c768cc_add_tidb_auth_binding.py | 51 ++ ..._10_22_0959-43fa78bc3b7d_add_white_list.py | 42 ++ api/models/dataset.py | 32 ++ api/schedule/create_tidb_serverless_task.py | 56 ++ .../update_tidb_serverless_status_task.py | 51 ++ api/services/auth/jina.py | 44 ++ .../jina-reader/base/checkbox-with-label.tsx | 40 ++ .../jina-reader/base/error-message.tsx | 30 + .../create/website/jina-reader/base/field.tsx | 54 ++ .../create/website/jina-reader/base/input.tsx | 58 ++ .../website/jina-reader/base/options-wrap.tsx | 55 ++ .../website/jina-reader/base/url-input.tsx | 48 ++ .../jina-reader/crawled-result-item.tsx | 40 ++ .../website/jina-reader/crawled-result.tsx | 87 +++ .../create/website/jina-reader/crawling.tsx | 37 ++ .../website/jina-reader/mock-crawl-result.ts | 24 + 27 files changed, 1648 insertions(+), 1 deletion(-) create mode 100644 api/configs/middleware/vdb/tidb_on_qdrant_config.py create mode 100644 api/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py create mode 100644 api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py create mode 100644 api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py create mode 100644 api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py create mode 100644 api/migrations/versions/2024_08_15_0956-0251a1c768cc_add_tidb_auth_binding.py create mode 100644 api/migrations/versions/2024_10_22_0959-43fa78bc3b7d_add_white_list.py create mode 100644 api/schedule/create_tidb_serverless_task.py create mode 100644 api/schedule/update_tidb_serverless_status_task.py create mode 100644 api/services/auth/jina.py create mode 100644 web/app/components/datasets/create/website/jina-reader/base/checkbox-with-label.tsx create mode 100644 web/app/components/datasets/create/website/jina-reader/base/error-message.tsx create mode 100644 web/app/components/datasets/create/website/jina-reader/base/field.tsx create mode 100644 web/app/components/datasets/create/website/jina-reader/base/input.tsx create mode 100644 web/app/components/datasets/create/website/jina-reader/base/options-wrap.tsx create mode 100644 web/app/components/datasets/create/website/jina-reader/base/url-input.tsx create mode 100644 web/app/components/datasets/create/website/jina-reader/crawled-result-item.tsx create mode 100644 web/app/components/datasets/create/website/jina-reader/crawled-result.tsx create mode 100644 web/app/components/datasets/create/website/jina-reader/crawling.tsx create mode 100644 web/app/components/datasets/create/website/jina-reader/mock-crawl-result.ts diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 4d6c9aedc1..0fa926038d 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -571,6 
+571,11 @@ class DataSetConfig(BaseSettings):
         default=False,
     )
 
+    TIDB_SERVERLESS_NUMBER: PositiveInt = Field(
+        description="Number of idle TiDB Serverless clusters to keep provisioned",
+        default=500,
+    )
+
 
 class WorkspaceConfig(BaseSettings):
     """
diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index 705e30b7ff..3d68e29d0e 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -27,6 +27,7 @@ from configs.middleware.vdb.pgvectors_config import PGVectoRSConfig
 from configs.middleware.vdb.qdrant_config import QdrantConfig
 from configs.middleware.vdb.relyt_config import RelytConfig
 from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
+from configs.middleware.vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
 from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
 from configs.middleware.vdb.upstash_config import UpstashConfig
 from configs.middleware.vdb.vikingdb_config import VikingDBConfig
@@ -54,6 +55,11 @@ class VectorStoreConfig(BaseSettings):
         default=None,
     )
 
+    VECTOR_STORE_WHITELIST_ENABLE: Optional[bool] = Field(
+        description="Enable the tenant whitelist for vector store selection.",
+        default=False,
+    )
+
 
 class KeywordStoreConfig(BaseSettings):
     KEYWORD_STORE: str = Field(
@@ -248,5 +254,6 @@ class MiddlewareConfig(
     InternalTestConfig,
     VikingDBConfig,
     UpstashConfig,
+    TidbOnQdrantConfig,
 ):
     pass
diff --git a/api/configs/middleware/vdb/tidb_on_qdrant_config.py b/api/configs/middleware/vdb/tidb_on_qdrant_config.py
new file mode 100644
index 0000000000..98268798ef
--- /dev/null
+++ b/api/configs/middleware/vdb/tidb_on_qdrant_config.py
@@ -0,0 +1,65 @@
+from typing import Optional
+
+from pydantic import Field, NonNegativeInt, PositiveInt
+from pydantic_settings import BaseSettings
+
+
+class TidbOnQdrantConfig(BaseSettings):
+    """
+    TiDB on Qdrant configs
+    """
+
+    TIDB_ON_QDRANT_URL: Optional[str] = Field(
+        description="TiDB on Qdrant URL",
+        default=None,
+    )
+
+    TIDB_ON_QDRANT_API_KEY: Optional[str] = Field(
+        description="TiDB on Qdrant API key",
+        default=None,
+    )
+
+    TIDB_ON_QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
+        description="TiDB on Qdrant client timeout in seconds",
+        default=20,
+    )
+
+    TIDB_ON_QDRANT_GRPC_ENABLED: bool = Field(
+        description="Whether to enable gRPC support for the TiDB on Qdrant connection",
+        default=False,
+    )
+
+    TIDB_ON_QDRANT_GRPC_PORT: PositiveInt = Field(
+        description="TiDB on Qdrant gRPC port",
+        default=6334,
+    )
+
+    TIDB_PUBLIC_KEY: Optional[str] = Field(
+        description="TiDB account public key",
+        default=None,
+    )
+
+    TIDB_PRIVATE_KEY: Optional[str] = Field(
+        description="TiDB account private key",
+        default=None,
+    )
+
+    TIDB_API_URL: Optional[str] = Field(
+        description="TiDB API URL",
+        default=None,
+    )
+
+    TIDB_IAM_API_URL: Optional[str] = Field(
+        description="TiDB IAM API URL",
+        default=None,
+    )
+
+    TIDB_REGION: Optional[str] = Field(
+        description="TiDB Serverless region",
+        default="regions/aws-us-east-1",
+    )
+
+    TIDB_PROJECT_ID: Optional[str] = Field(
+        description="TiDB project ID",
+        default=None,
+    )
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index b500c6c3fa..c8022efb8b 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -639,6 +639,7 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.ORACLE
                 | VectorType.ELASTICSEARCH
                 | VectorType.PGVECTOR
+                | VectorType.TIDB_ON_QDRANT
             ):
                 return {
                     "retrieval_method": [
diff --git
a/api/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py new file mode 100644 index 0000000000..1e62b3c589 --- /dev/null +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py @@ -0,0 +1,17 @@ +from typing import Optional + +from pydantic import BaseModel + + +class ClusterEntity(BaseModel): + """ + Model Config Entity. + """ + + name: str + cluster_id: str + displayName: str + region: str + spendingLimit: Optional[int] = 1000 + version: str + createdBy: str diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py new file mode 100644 index 0000000000..a38f84636e --- /dev/null +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -0,0 +1,526 @@ +import json +import os +import uuid +from collections.abc import Generator, Iterable, Sequence +from itertools import islice +from typing import TYPE_CHECKING, Any, Optional, Union, cast + +import qdrant_client +import requests +from flask import current_app +from pydantic import BaseModel +from qdrant_client.http import models as rest +from qdrant_client.http.models import ( + FilterSelector, + HnswConfigDiff, + PayloadSchemaType, + TextIndexParams, + TextIndexType, + TokenizerType, +) +from qdrant_client.local.qdrant_local import QdrantLocal +from requests.auth import HTTPDigestAuth + +from configs import dify_config +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset, TidbAuthBinding + +if TYPE_CHECKING: + from qdrant_client import grpc # noqa + from qdrant_client.conversions import common_types + from qdrant_client.http import models as rest + + DictFilter = dict[str, Union[str, int, bool, dict, list]] + MetadataFilter = Union[DictFilter, common_types.Filter] + + +class TidbOnQdrantConfig(BaseModel): + endpoint: str + api_key: Optional[str] = None + timeout: float = 20 + root_path: Optional[str] = None + grpc_port: int = 6334 + prefer_grpc: bool = False + + def to_qdrant_params(self): + if self.endpoint and self.endpoint.startswith("path:"): + path = self.endpoint.replace("path:", "") + if not os.path.isabs(path): + path = os.path.join(self.root_path, path) + + return {"path": path} + else: + return { + "url": self.endpoint, + "api_key": self.api_key, + "timeout": self.timeout, + "verify": False, + "grpc_port": self.grpc_port, + "prefer_grpc": self.prefer_grpc, + } + + +class TidbConfig(BaseModel): + api_url: str + public_key: str + private_key: str + + +class TidbOnQdrantVector(BaseVector): + def __init__(self, collection_name: str, group_id: str, config: TidbOnQdrantConfig, distance_func: str = "Cosine"): + super().__init__(collection_name) + self._client_config = config + self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params()) + self._distance_func = 
distance_func.upper() + self._group_id = group_id + + def get_type(self) -> str: + return VectorType.TIDB_ON_QDRANT + + def to_index_struct(self) -> dict: + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + if texts: + # get embedding vector size + vector_size = len(embeddings[0]) + # get collection name + collection_name = self._collection_name + # create collection + self.create_collection(collection_name, vector_size) + + self.add_texts(texts, embeddings, **kwargs) + + def create_collection(self, collection_name: str, vector_size: int): + lock_name = "vector_indexing_lock_{}".format(collection_name) + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) + if redis_client.get(collection_exist_cache_key): + return + collection_name = collection_name or uuid.uuid4().hex + all_collection_name = [] + collections_response = self._client.get_collections() + collection_list = collections_response.collections + for collection in collection_list: + all_collection_name.append(collection.name) + if collection_name not in all_collection_name: + from qdrant_client.http import models as rest + + vectors_config = rest.VectorParams( + size=vector_size, + distance=rest.Distance[self._distance_func], + ) + hnsw_config = HnswConfigDiff( + m=0, + payload_m=16, + ef_construct=100, + full_scan_threshold=10000, + max_indexing_threads=0, + on_disk=False, + ) + self._client.recreate_collection( + collection_name=collection_name, + vectors_config=vectors_config, + hnsw_config=hnsw_config, + timeout=int(self._client_config.timeout), + ) + + # create group_id payload index + self._client.create_payload_index( + collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD + ) + # create doc_id payload index + self._client.create_payload_index( + collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD + ) + # create full text index + text_index_params = TextIndexParams( + type=TextIndexType.TEXT, + tokenizer=TokenizerType.MULTILINGUAL, + min_token_len=2, + max_token_len=20, + lowercase=True, + ) + self._client.create_payload_index( + collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params + ) + redis_client.set(collection_exist_cache_key, 1, ex=3600) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + uuids = self._get_uuids(documents) + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + + added_ids = [] + for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id): + self._client.upsert(collection_name=self._collection_name, points=points) + added_ids.extend(batch_ids) + + return added_ids + + def _generate_rest_batches( + self, + texts: Iterable[str], + embeddings: list[list[float]], + metadatas: Optional[list[dict]] = None, + ids: Optional[Sequence[str]] = None, + batch_size: int = 64, + group_id: Optional[str] = None, + ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: + from qdrant_client.http import models as rest + + texts_iterator = iter(texts) + embeddings_iterator = iter(embeddings) + metadatas_iterator = iter(metadatas or []) + ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)]) + while batch_texts := list(islice(texts_iterator, batch_size)): + # Take the corresponding metadata and id 
for each text in a batch + batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None + batch_ids = list(islice(ids_iterator, batch_size)) + + # Generate the embeddings for all the texts in a batch + batch_embeddings = list(islice(embeddings_iterator, batch_size)) + + points = [ + rest.PointStruct( + id=point_id, + vector=vector, + payload=payload, + ) + for point_id, vector, payload in zip( + batch_ids, + batch_embeddings, + self._build_payloads( + batch_texts, + batch_metadatas, + Field.CONTENT_KEY.value, + Field.METADATA_KEY.value, + group_id, + Field.GROUP_KEY.value, + ), + ) + ] + + yield batch_ids, points + + @classmethod + def _build_payloads( + cls, + texts: Iterable[str], + metadatas: Optional[list[dict]], + content_payload_key: str, + metadata_payload_key: str, + group_id: str, + group_payload_key: str, + ) -> list[dict]: + payloads = [] + for i, text in enumerate(texts): + if text is None: + raise ValueError( + "At least one of the texts is None. Please remove it before " + "calling .from_texts or .add_texts on Qdrant instance." + ) + metadata = metadatas[i] if metadatas is not None else None + payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id}) + + return payloads + + def delete_by_metadata_field(self, key: str, value: str): + from qdrant_client.http import models + from qdrant_client.http.exceptions import UnexpectedResponse + + try: + filter = models.Filter( + must=[ + models.FieldCondition( + key=f"metadata.{key}", + match=models.MatchValue(value=value), + ), + ], + ) + + self._reload_if_needed() + + self._client.delete( + collection_name=self._collection_name, + points_selector=FilterSelector(filter=filter), + ) + except UnexpectedResponse as e: + # Collection does not exist, so return + if e.status_code == 404: + return + # Some other error occurred, so re-raise the exception + else: + raise e + + def delete(self): + from qdrant_client.http.exceptions import UnexpectedResponse + + try: + self._client.delete_collection(collection_name=self._collection_name) + except UnexpectedResponse as e: + # Collection does not exist, so return + if e.status_code == 404: + return + # Some other error occurred, so re-raise the exception + else: + raise e + + def delete_by_ids(self, ids: list[str]) -> None: + from qdrant_client.http import models + from qdrant_client.http.exceptions import UnexpectedResponse + + for node_id in ids: + try: + filter = models.Filter( + must=[ + models.FieldCondition( + key="metadata.doc_id", + match=models.MatchValue(value=node_id), + ), + ], + ) + self._client.delete( + collection_name=self._collection_name, + points_selector=FilterSelector(filter=filter), + ) + except UnexpectedResponse as e: + # Collection does not exist, so return + if e.status_code == 404: + return + # Some other error occurred, so re-raise the exception + else: + raise e + + def text_exists(self, id: str) -> bool: + all_collection_name = [] + collections_response = self._client.get_collections() + collection_list = collections_response.collections + for collection in collection_list: + all_collection_name.append(collection.name) + if self._collection_name not in all_collection_name: + return False + response = self._client.retrieve(collection_name=self._collection_name, ids=[id]) + + return len(response) > 0 + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + from qdrant_client.http import models + + filter = models.Filter( + must=[ + models.FieldCondition( + key="group_id", + 
match=models.MatchValue(value=self._group_id), + ), + ], + ) + results = self._client.search( + collection_name=self._collection_name, + query_vector=query_vector, + query_filter=filter, + limit=kwargs.get("top_k", 4), + with_payload=True, + with_vectors=True, + score_threshold=kwargs.get("score_threshold", 0.0), + ) + docs = [] + for result in results: + metadata = result.payload.get(Field.METADATA_KEY.value) or {} + # duplicate check score threshold + score_threshold = kwargs.get("score_threshold") or 0.0 + if result.score > score_threshold: + metadata["score"] = result.score + doc = Document( + page_content=result.payload.get(Field.CONTENT_KEY.value), + metadata=metadata, + ) + docs.append(doc) + # Sort the documents by score in descending order + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + """Return docs most similar by bm25. + Returns: + List of documents most similar to the query text and distance for each. + """ + from qdrant_client.http import models + + scroll_filter = models.Filter( + must=[ + models.FieldCondition( + key="page_content", + match=models.MatchText(text=query), + ) + ] + ) + response = self._client.scroll( + collection_name=self._collection_name, + scroll_filter=scroll_filter, + limit=kwargs.get("top_k", 2), + with_payload=True, + with_vectors=True, + ) + results = response[0] + documents = [] + for result in results: + if result: + document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value) + document.metadata["vector"] = result.vector + documents.append(document) + + return documents + + def _reload_if_needed(self): + if isinstance(self._client, QdrantLocal): + self._client = cast(QdrantLocal, self._client) + self._client._load() + + @classmethod + def _document_from_scored_point( + cls, + scored_point: Any, + content_payload_key: str, + metadata_payload_key: str, + ) -> Document: + return Document( + page_content=scored_point.payload.get(content_payload_key), + metadata=scored_point.payload.get(metadata_payload_key) or {}, + ) + + +class TidbOnQdrantVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector: + tidb_auth_binding = ( + db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() + ) + if not tidb_auth_binding: + idle_tidb_auth_binding = ( + db.session.query(TidbAuthBinding) + .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") + .limit(1) + .one_or_none() + ) + if idle_tidb_auth_binding: + idle_tidb_auth_binding.active = True + idle_tidb_auth_binding.tenant_id = dataset.tenant_id + db.session.commit() + TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}" + else: + with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): + tidb_auth_binding = ( + db.session.query(TidbAuthBinding) + .filter(TidbAuthBinding.tenant_id == dataset.tenant_id) + .one_or_none() + ) + if tidb_auth_binding: + TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" + + else: + new_cluster = TidbService.create_tidb_serverless_cluster( + dify_config.TIDB_PROJECT_ID, + dify_config.TIDB_API_URL, + dify_config.TIDB_IAM_API_URL, + dify_config.TIDB_PUBLIC_KEY, + dify_config.TIDB_PRIVATE_KEY, + dify_config.TIDB_REGION, + ) + new_tidb_auth_binding = TidbAuthBinding( + 
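+                            # persist the freshly provisioned cluster's credentials for this tenant
+                            # so later indexing runs reuse the same TiDB-on-Qdrant instance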
cluster_id=new_cluster["cluster_id"], + cluster_name=new_cluster["cluster_name"], + account=new_cluster["account"], + password=new_cluster["password"], + tenant_id=dataset.tenant_id, + active=True, + status="ACTIVE", + ) + db.session.add(new_tidb_auth_binding) + db.session.commit() + TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}" + + else: + TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" + + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_ON_QDRANT, collection_name)) + + config = current_app.config + + return TidbOnQdrantVector( + collection_name=collection_name, + group_id=dataset.id, + config=TidbOnQdrantConfig( + endpoint=dify_config.TIDB_ON_QDRANT_URL, + api_key=TIDB_ON_QDRANT_API_KEY, + root_path=config.root_path, + timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT, + grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT, + prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED, + ), + ) + + def create_tidb_serverless_cluster(self, tidb_config: TidbConfig, display_name: str, region: str): + """ + Creates a new TiDB Serverless cluster. + :param tidb_config: The configuration for the TiDB Cloud API. + :param display_name: The user-friendly display name of the cluster (required). + :param region: The region where the cluster will be created (required). + + :return: The response from the API. + """ + region_object = { + "name": region, + } + + labels = { + "tidb.cloud/project": "1372813089454548012", + } + cluster_data = {"displayName": display_name, "region": region_object, "labels": labels} + + response = requests.post( + f"{tidb_config.api_url}/clusters", + json=cluster_data, + auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key), + ) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + + def change_tidb_serverless_root_password(self, tidb_config: TidbConfig, cluster_id: str, new_password: str): + """ + Changes the root password of a specific TiDB Serverless cluster. + + :param tidb_config: The configuration for the TiDB Cloud API. + :param cluster_id: The ID of the cluster for which the password is to be changed (required). + :param new_password: The new password for the root user (required). + :return: The response from the API. 
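+
+        Example (illustrative sketch only; the URL, keys, and IDs below are placeholders):
+
+            config = TidbConfig(api_url="https://api.tidbcloud.com/api/v1beta1",
+                                public_key="<public-key>", private_key="<private-key>")
+            factory.change_tidb_serverless_root_password(config, "<cluster-id>", "<new-password>")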
+ """ + + body = {"password": new_password} + + response = requests.put( + f"{tidb_config.api_url}/clusters/{cluster_id}/password", + json=body, + auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key), + ) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py new file mode 100644 index 0000000000..f10d6339ee --- /dev/null +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -0,0 +1,250 @@ +import time +import uuid + +import requests +from requests.auth import HTTPDigestAuth + +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import TidbAuthBinding + + +class TidbService: + @staticmethod + def create_tidb_serverless_cluster( + project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str + ): + """ + Creates a new TiDB Serverless cluster. + :param project_id: The project ID of the TiDB Cloud project (required). + :param api_url: The URL of the TiDB Cloud API (required). + :param iam_url: The URL of the TiDB Cloud IAM API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param display_name: The user-friendly display name of the cluster (required). + :param region: The region where the cluster will be created (required). + + :return: The response from the API. + """ + + region_object = { + "name": region, + } + + labels = { + "tidb.cloud/project": project_id, + } + + spending_limit = { + "monthly": 100, + } + password = str(uuid.uuid4()).replace("-", "")[:16] + display_name = str(uuid.uuid4()).replace("-", "")[:16] + cluster_data = { + "displayName": display_name, + "region": region_object, + "labels": labels, + "spendingLimit": spending_limit, + "rootPassword": password, + } + + response = requests.post(f"{api_url}/clusters", json=cluster_data, auth=HTTPDigestAuth(public_key, private_key)) + + if response.status_code == 200: + response_data = response.json() + cluster_id = response_data["clusterId"] + retry_count = 0 + max_retries = 30 + while retry_count < max_retries: + cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id) + if cluster_response["state"] == "ACTIVE": + user_prefix = cluster_response["userPrefix"] + return { + "cluster_id": cluster_id, + "cluster_name": display_name, + "account": f"{user_prefix}.root", + "password": password, + } + time.sleep(30) # wait 30 seconds before retrying + retry_count += 1 + else: + response.raise_for_status() + + @staticmethod + def delete_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str): + """ + Deletes a specific TiDB Serverless cluster. + + :param api_url: The URL of the TiDB Cloud API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param cluster_id: The ID of the cluster to be deleted (required). + :return: The response from the API. 
+ """ + + response = requests.delete(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key)) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + + @staticmethod + def get_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str): + """ + Deletes a specific TiDB Serverless cluster. + + :param api_url: The URL of the TiDB Cloud API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param cluster_id: The ID of the cluster to be deleted (required). + :return: The response from the API. + """ + + response = requests.get(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key)) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + + @staticmethod + def change_tidb_serverless_root_password( + api_url: str, public_key: str, private_key: str, cluster_id: str, account: str, new_password: str + ): + """ + Changes the root password of a specific TiDB Serverless cluster. + + :param api_url: The URL of the TiDB Cloud API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param cluster_id: The ID of the cluster for which the password is to be changed (required).+ + :param account: The account for which the password is to be changed (required). + :param new_password: The new password for the root user (required). + :return: The response from the API. + """ + + body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []} + + response = requests.patch( + f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}", + json=body, + auth=HTTPDigestAuth(public_key, private_key), + ) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + + @staticmethod + def batch_update_tidb_serverless_cluster_status( + tidb_serverless_list: list[TidbAuthBinding], + project_id: str, + api_url: str, + iam_url: str, + public_key: str, + private_key: str, + ) -> list[dict]: + """ + Update the status of a new TiDB Serverless cluster. + :param project_id: The project ID of the TiDB Cloud project (required). + :param api_url: The URL of the TiDB Cloud API (required). + :param iam_url: The URL of the TiDB Cloud IAM API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param display_name: The user-friendly display name of the cluster (required). + :param region: The region where the cluster will be created (required). + + :return: The response from the API. 
+ """ + clusters = [] + tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list} + cluster_ids = [item.cluster_id for item in tidb_serverless_list] + params = {"clusterIds": cluster_ids, "view": "FULL"} + response = requests.get( + f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key) + ) + + if response.status_code == 200: + response_data = response.json() + cluster_infos = [] + for item in response_data["clusters"]: + state = item["state"] + userPrefix = item["userPrefix"] + if state == "ACTIVE" and len(userPrefix) > 0: + cluster_info = tidb_serverless_list_map[item["clusterId"]] + cluster_info.status = "ACTIVE" + cluster_info.account = f"{userPrefix}.root" + db.session.add(cluster_info) + db.session.commit() + else: + response.raise_for_status() + + @staticmethod + def batch_create_tidb_serverless_cluster( + batch_size: int, project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str + ) -> list[dict]: + """ + Creates a new TiDB Serverless cluster. + :param project_id: The project ID of the TiDB Cloud project (required). + :param api_url: The URL of the TiDB Cloud API (required). + :param iam_url: The URL of the TiDB Cloud IAM API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param display_name: The user-friendly display name of the cluster (required). + :param region: The region where the cluster will be created (required). + + :return: The response from the API. + """ + clusters = [] + for _ in range(batch_size): + region_object = { + "name": region, + } + + labels = { + "tidb.cloud/project": project_id, + } + + spending_limit = { + "monthly": 10, + } + password = str(uuid.uuid4()).replace("-", "")[:16] + display_name = str(uuid.uuid4()).replace("-", "") + cluster_data = { + "cluster": { + "displayName": display_name, + "region": region_object, + "labels": labels, + "spendingLimit": spending_limit, + "rootPassword": password, + } + } + cache_key = f"tidb_serverless_cluster_password:{display_name}" + redis_client.setex(cache_key, 3600, password) + clusters.append(cluster_data) + + request_body = {"requests": clusters} + response = requests.post( + f"{api_url}/clusters:batchCreate", json=request_body, auth=HTTPDigestAuth(public_key, private_key) + ) + + if response.status_code == 200: + response_data = response.json() + cluster_infos = [] + for item in response_data["clusters"]: + cache_key = f"tidb_serverless_cluster_password:{item['displayName']}" + password = redis_client.get(cache_key) + if not password: + continue + cluster_info = { + "cluster_id": item["clusterId"], + "cluster_name": item["displayName"], + "account": "root", + "password": password.decode("utf-8"), + } + cluster_infos.append(cluster_info) + return cluster_infos + else: + response.raise_for_status() diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 9ea3cf4b6b..59a5aadacd 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -9,8 +9,9 @@ from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document +from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.dataset import Dataset +from models.dataset import 
Dataset, Whitelist class AbstractVectorFactory(ABC): @@ -35,8 +36,18 @@ class Vector: def _init_vector(self) -> BaseVector: vector_type = dify_config.VECTOR_STORE + if self._dataset.index_struct_dict: vector_type = self._dataset.index_struct_dict["type"] + else: + if dify_config.VECTOR_STORE_WHITELIST_ENABLE: + whitelist = ( + db.session.query(Whitelist) + .filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db") + .one_or_none() + ) + if whitelist: + vector_type = VectorType.TIDB_ON_QDRANT if not vector_type: raise ValueError("Vector store must be specified.") @@ -115,6 +126,10 @@ class Vector: from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory return UpstashVectorFactory + case VectorType.TIDB_ON_QDRANT: + from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory + + return TidbOnQdrantVectorFactory case _: raise ValueError(f"Vector store {vector_type} is not supported.") diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index 30aa814553..3b6df94f78 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -19,3 +19,4 @@ class VectorType(str, Enum): BAIDU = "baidu" VIKINGDB = "vikingdb" UPSTASH = "upstash" + TIDB_ON_QDRANT = "tidb_on_qdrant" diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index b9b019373d..504899c276 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -1,6 +1,7 @@ from datetime import timedelta from celery import Celery, Task +from celery.schedules import crontab from flask import Flask from configs import dify_config @@ -55,6 +56,8 @@ def init_app(app: Flask) -> Celery: imports = [ "schedule.clean_embedding_cache_task", "schedule.clean_unused_datasets_task", + "schedule.create_tidb_serverless_task", + "schedule.update_tidb_serverless_status_task", ] day = dify_config.CELERY_BEAT_SCHEDULER_TIME beat_schedule = { @@ -66,6 +69,14 @@ def init_app(app: Flask) -> Celery: "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task", "schedule": timedelta(days=day), }, + "create_tidb_serverless_task": { + "task": "schedule.create_tidb_serverless_task.create_tidb_serverless_task", + "schedule": crontab(minute="0", hour="*"), + }, + "update_tidb_serverless_status_task": { + "task": "schedule.update_tidb_serverless_status_task.update_tidb_serverless_status_task", + "schedule": crontab(minute="30", hour="*"), + }, } celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) diff --git a/api/migrations/versions/2024_08_15_0956-0251a1c768cc_add_tidb_auth_binding.py b/api/migrations/versions/2024_08_15_0956-0251a1c768cc_add_tidb_auth_binding.py new file mode 100644 index 0000000000..ca2e410442 --- /dev/null +++ b/api/migrations/versions/2024_08_15_0956-0251a1c768cc_add_tidb_auth_binding.py @@ -0,0 +1,51 @@ +"""add-tidb-auth-binding + +Revision ID: 0251a1c768cc +Revises: 63a83fcf12ba +Create Date: 2024-08-15 09:56:59.012490 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '0251a1c768cc' +down_revision = 'bbadea11becb' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('tidb_auth_bindings', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=True), + sa.Column('cluster_id', sa.String(length=255), nullable=False), + sa.Column('cluster_name', sa.String(length=255), nullable=False), + sa.Column('active', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('status', sa.String(length=255), server_default=sa.text("'CREATING'::character varying"), nullable=False), + sa.Column('account', sa.String(length=255), nullable=False), + sa.Column('password', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tidb_auth_bindings_pkey') + ) + with op.batch_alter_table('tidb_auth_bindings', schema=None) as batch_op: + batch_op.create_index('tidb_auth_bindings_active_idx', ['active'], unique=False) + batch_op.create_index('tidb_auth_bindings_status_idx', ['status'], unique=False) + batch_op.create_index('tidb_auth_bindings_created_at_idx', ['created_at'], unique=False) + batch_op.create_index('tidb_auth_bindings_tenant_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tidb_auth_bindings', schema=None) as batch_op: + batch_op.drop_index('tidb_auth_bindings_tenant_idx') + batch_op.drop_index('tidb_auth_bindings_created_at_idx') + batch_op.drop_index('tidb_auth_bindings_active_idx') + batch_op.drop_index('tidb_auth_bindings_status_idx') + op.drop_table('tidb_auth_bindings') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_10_22_0959-43fa78bc3b7d_add_white_list.py b/api/migrations/versions/2024_10_22_0959-43fa78bc3b7d_add_white_list.py new file mode 100644 index 0000000000..9daf148bc4 --- /dev/null +++ b/api/migrations/versions/2024_10_22_0959-43fa78bc3b7d_add_white_list.py @@ -0,0 +1,42 @@ +"""add_white_list + +Revision ID: 43fa78bc3b7d +Revises: 0251a1c768cc +Create Date: 2024-10-22 09:59:23.713716 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '43fa78bc3b7d' +down_revision = '0251a1c768cc' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('whitelists', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=True), + sa.Column('category', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='whitelists_pkey') + ) + with op.batch_alter_table('whitelists', schema=None) as batch_op: + batch_op.create_index('whitelists_tenant_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
###
+
+    with op.batch_alter_table('whitelists', schema=None) as batch_op:
+        batch_op.drop_index('whitelists_tenant_idx')
+
+    op.drop_table('whitelists')
+    # ### end Alembic commands ###
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 4e2ccab7e8..459bfa47c0 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -704,6 +704,38 @@ class DatasetCollectionBinding(db.Model):
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 
 
+class TidbAuthBinding(db.Model):
+    __tablename__ = "tidb_auth_bindings"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
+        db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),
+        db.Index("tidb_auth_bindings_active_idx", "active"),
+        db.Index("tidb_auth_bindings_created_at_idx", "created_at"),
+        db.Index("tidb_auth_bindings_status_idx", "status"),
+    )
+    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
+    tenant_id = db.Column(StringUUID, nullable=True)
+    cluster_id = db.Column(db.String(255), nullable=False)
+    cluster_name = db.Column(db.String(255), nullable=False)
+    active = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
+    status = db.Column(db.String(255), nullable=False, server_default=db.text("'CREATING'::character varying"))
+    account = db.Column(db.String(255), nullable=False)
+    password = db.Column(db.String(255), nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+
+
+class Whitelist(db.Model):
+    __tablename__ = "whitelists"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
+        db.Index("whitelists_tenant_idx", "tenant_id"),
+    )
+    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
+    tenant_id = db.Column(StringUUID, nullable=True)
+    category = db.Column(db.String(255), nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+
+
 class DatasetPermission(db.Model):
     __tablename__ = "dataset_permissions"
     __table_args__ = (
diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py
new file mode 100644
index 0000000000..42d6c04beb
--- /dev/null
+++ b/api/schedule/create_tidb_serverless_task.py
@@ -0,0 +1,56 @@
+import time
+
+import click
+
+import app
+from configs import dify_config
+from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
+from extensions.ext_database import db
+from models.dataset import TidbAuthBinding
+
+
+@app.celery.task(queue="dataset")
+def create_tidb_serverless_task():
+    click.echo(click.style("Starting the create-TiDB-Serverless-clusters task.", fg="green"))
+    tidb_serverless_number = dify_config.TIDB_SERVERLESS_NUMBER
+    start_at = time.perf_counter()
+    while True:
+        try:
+            # check the number of idle tidb serverless clusters
+            idle_tidb_serverless_number = TidbAuthBinding.query.filter(TidbAuthBinding.active == False).count()
+            if idle_tidb_serverless_number >= tidb_serverless_number:
+                break
+            # create tidb serverless clusters in batches of 20
+            batch_size = 20
+            create_clusters(batch_size)
+
+        except Exception as e:
+            click.echo(click.style(f"Error: {e}", fg="red"))
+            break
+
+    end_at = time.perf_counter()
+    click.echo(click.style("Create TiDB Serverless task succeeded, latency: {}".format(end_at - start_at), fg="green"))
+
+
+def create_clusters(batch_size):
+    try:
+        new_clusters = TidbService.batch_create_tidb_serverless_cluster(
+            batch_size,
+            dify_config.TIDB_PROJECT_ID,
+            dify_config.TIDB_API_URL,
+            dify_config.TIDB_IAM_API_URL,
+            dify_config.TIDB_PUBLIC_KEY,
+            dify_config.TIDB_PRIVATE_KEY,
+            dify_config.TIDB_REGION,
+        )
+        for new_cluster in new_clusters:
+            tidb_auth_binding = TidbAuthBinding(
+                cluster_id=new_cluster["cluster_id"],
+                cluster_name=new_cluster["cluster_name"],
+                account=new_cluster["account"],
+                password=new_cluster["password"],
+            )
+            db.session.add(tidb_auth_binding)
+        db.session.commit()
+    except Exception as e:
+        click.echo(click.style(f"Error: {e}", fg="red"))
diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py
new file mode 100644
index 0000000000..07eca3173b
--- /dev/null
+++ b/api/schedule/update_tidb_serverless_status_task.py
@@ -0,0 +1,51 @@
+import time
+
+import click
+
+import app
+from configs import dify_config
+from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
+from models.dataset import TidbAuthBinding
+
+
+@app.celery.task(queue="dataset")
+def update_tidb_serverless_status_task():
+    click.echo(click.style("Starting the update-TiDB-Serverless-status task.", fg="green"))
+    start_at = time.perf_counter()
+    try:
+        # fetch the idle clusters that are still being provisioned
+        tidb_serverless_list = TidbAuthBinding.query.filter(
+            TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING"
+        ).all()
+        if len(tidb_serverless_list) > 0:
+            # sync their status from the TiDB Cloud API; the task runs periodically,
+            # so clusters that are not ACTIVE yet are picked up on the next run
+            update_clusters(tidb_serverless_list)
+
+    except Exception as e:
+        click.echo(click.style(f"Error: {e}", fg="red"))
+
+    end_at = time.perf_counter()
+    click.echo(
+        click.style("Update TiDB Serverless status task succeeded, latency: {}".format(end_at - start_at), fg="green")
+    )
+
+
+def update_clusters(tidb_serverless_list: list[TidbAuthBinding]):
+    try:
+        # query the TiDB Cloud API in batches of 20
+        for i in range(0, len(tidb_serverless_list), 20):
+            items = tidb_serverless_list[i : i + 20]
+            TidbService.batch_update_tidb_serverless_cluster_status(
+                items,
+                dify_config.TIDB_PROJECT_ID,
+                dify_config.TIDB_API_URL,
+                dify_config.TIDB_IAM_API_URL,
+                dify_config.TIDB_PUBLIC_KEY,
+                dify_config.TIDB_PRIVATE_KEY,
+            )
+    except Exception as e:
+        click.echo(click.style(f"Error: {e}", fg="red"))
diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py
new file mode 100644
index 0000000000..de898a1f94
--- /dev/null
+++ b/api/services/auth/jina.py
@@ -0,0 +1,44 @@
+import json
+
+import requests
+
+from services.auth.api_key_auth_base import ApiKeyAuthBase
+
+
+class JinaAuth(ApiKeyAuthBase):
+    def __init__(self, credentials: dict):
+        super().__init__(credentials)
+        auth_type = credentials.get("auth_type")
+        if auth_type != "bearer":
+            raise ValueError("Invalid auth type: Jina Reader requires Bearer auth")
+        self.api_key = credentials.get("config").get("api_key", None)
+
+        if not self.api_key:
+            raise ValueError("No API key provided")
+
+    def validate_credentials(self):
+        headers = self._prepare_headers()
+        options = {
+            "url": "https://example.com",
+        }
+        response = self._post_request("https://r.jina.ai", options, headers)
+        if response.status_code == 200:
+            return True
+        else:
+            self._handle_error(response)
+
+    def _prepare_headers(self):
+        return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
+
+    def _post_request(self, url, data, headers):
+        return requests.post(url, headers=headers, json=data)
+
+    def _handle_error(self, response):
+        if response.status_code in {402, 409, 500}:
+            error_message =
response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") + else: + if response.text: + error_message = json.loads(response.text).get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") + raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}") diff --git a/web/app/components/datasets/create/website/jina-reader/base/checkbox-with-label.tsx b/web/app/components/datasets/create/website/jina-reader/base/checkbox-with-label.tsx new file mode 100644 index 0000000000..25d40fe076 --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/base/checkbox-with-label.tsx @@ -0,0 +1,40 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import cn from '@/utils/classnames' +import Checkbox from '@/app/components/base/checkbox' +import Tooltip from '@/app/components/base/tooltip' + +type Props = { + className?: string + isChecked: boolean + onChange: (isChecked: boolean) => void + label: string + labelClassName?: string + tooltip?: string +} + +const CheckboxWithLabel: FC = ({ + className = '', + isChecked, + onChange, + label, + labelClassName, + tooltip, +}) => { + return ( + + ) +} +export default React.memo(CheckboxWithLabel) diff --git a/web/app/components/datasets/create/website/jina-reader/base/error-message.tsx b/web/app/components/datasets/create/website/jina-reader/base/error-message.tsx new file mode 100644 index 0000000000..aa337ec4bf --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/base/error-message.tsx @@ -0,0 +1,30 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import cn from '@/utils/classnames' +import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' + +type Props = { + className?: string + title: string + errorMsg?: string +} + +const ErrorMessage: FC = ({ + className, + title, + errorMsg, +}) => { + return ( +
+
+ +
{title}
+
+ {errorMsg && ( +
{errorMsg}
+ )} +
+ ) +} +export default React.memo(ErrorMessage) diff --git a/web/app/components/datasets/create/website/jina-reader/base/field.tsx b/web/app/components/datasets/create/website/jina-reader/base/field.tsx new file mode 100644 index 0000000000..5b5ca90c5d --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/base/field.tsx @@ -0,0 +1,54 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import Input from './input' +import cn from '@/utils/classnames' +import Tooltip from '@/app/components/base/tooltip' + +type Props = { + className?: string + label: string + labelClassName?: string + value: string | number + onChange: (value: string | number) => void + isRequired?: boolean + placeholder?: string + isNumber?: boolean + tooltip?: string +} + +const Field: FC = ({ + className, + label, + labelClassName, + value, + onChange, + isRequired = false, + placeholder = '', + isNumber = false, + tooltip, +}) => { + return ( +
+
+
{label}
+ {isRequired && *} + {tooltip && ( + {tooltip}
+ } + triggerClassName='ml-0.5 w-4 h-4' + /> + )} +
+ + + ) +} +export default React.memo(Field) diff --git a/web/app/components/datasets/create/website/jina-reader/base/input.tsx b/web/app/components/datasets/create/website/jina-reader/base/input.tsx new file mode 100644 index 0000000000..7d2d2b609f --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/base/input.tsx @@ -0,0 +1,58 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' + +type Props = { + value: string | number + onChange: (value: string | number) => void + placeholder?: string + isNumber?: boolean +} + +const MIN_VALUE = 0 + +const Input: FC = ({ + value, + onChange, + placeholder = '', + isNumber = false, +}) => { + const handleChange = useCallback((e: React.ChangeEvent) => { + const value = e.target.value + if (isNumber) { + let numberValue = parseInt(value, 10) // integer only + if (isNaN(numberValue)) { + onChange('') + return + } + if (numberValue < MIN_VALUE) + numberValue = MIN_VALUE + + onChange(numberValue) + return + } + onChange(value) + }, [isNumber, onChange]) + + const otherOption = (() => { + if (isNumber) { + return { + min: MIN_VALUE, + } + } + return { + + } + })() + return ( + + ) +} +export default React.memo(Input) diff --git a/web/app/components/datasets/create/website/jina-reader/base/options-wrap.tsx b/web/app/components/datasets/create/website/jina-reader/base/options-wrap.tsx new file mode 100644 index 0000000000..652401a20f --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/base/options-wrap.tsx @@ -0,0 +1,55 @@ +'use client' +import { useBoolean } from 'ahooks' +import type { FC } from 'react' +import React, { useEffect } from 'react' +import { useTranslation } from 'react-i18next' +import cn from '@/utils/classnames' +import { Settings04 } from '@/app/components/base/icons/src/vender/line/general' +import { ChevronRight } from '@/app/components/base/icons/src/vender/line/arrows' +const I18N_PREFIX = 'datasetCreation.stepOne.website' + +type Props = { + className?: string + children: React.ReactNode + controlFoldOptions?: number +} + +const OptionsWrap: FC = ({ + className = '', + children, + controlFoldOptions, +}) => { + const { t } = useTranslation() + + const [fold, { + toggle: foldToggle, + setTrue: foldHide, + }] = useBoolean(false) + + useEffect(() => { + if (controlFoldOptions) + foldHide() + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [controlFoldOptions]) + return ( +
+
+
+ +
{t(`${I18N_PREFIX}.options`)}
+
+ +
+ {!fold && ( +
+ {children} +
+ )} + +
+ ) +} +export default React.memo(OptionsWrap) diff --git a/web/app/components/datasets/create/website/jina-reader/base/url-input.tsx b/web/app/components/datasets/create/website/jina-reader/base/url-input.tsx new file mode 100644 index 0000000000..e6b0475874 --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/base/url-input.tsx @@ -0,0 +1,48 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import Input from './input' +import Button from '@/app/components/base/button' + +const I18N_PREFIX = 'datasetCreation.stepOne.website' + +type Props = { + isRunning: boolean + onRun: (url: string) => void +} + +const UrlInput: FC = ({ + isRunning, + onRun, +}) => { + const { t } = useTranslation() + const [url, setUrl] = useState('') + const handleUrlChange = useCallback((url: string | number) => { + setUrl(url as string) + }, []) + const handleOnRun = useCallback(() => { + if (isRunning) + return + onRun(url) + }, [isRunning, onRun, url]) + + return ( +
+ + +
+ ) +} +export default React.memo(UrlInput) diff --git a/web/app/components/datasets/create/website/jina-reader/crawled-result-item.tsx b/web/app/components/datasets/create/website/jina-reader/crawled-result-item.tsx new file mode 100644 index 0000000000..5531d3e140 --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/crawled-result-item.tsx @@ -0,0 +1,40 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import cn from '@/utils/classnames' +import type { CrawlResultItem as CrawlResultItemType } from '@/models/datasets' +import Checkbox from '@/app/components/base/checkbox' + +type Props = { + payload: CrawlResultItemType + isChecked: boolean + isPreview: boolean + onCheckChange: (checked: boolean) => void + onPreview: () => void +} + +const CrawledResultItem: FC = ({ + isPreview, + payload, + isChecked, + onCheckChange, + onPreview, +}) => { + const { t } = useTranslation() + + const handleCheckChange = useCallback(() => { + onCheckChange(!isChecked) + }, [isChecked, onCheckChange]) + return ( +
+
+ +
{payload.title}
+
{t('datasetCreation.stepOne.website.preview')}
+
+
{payload.source_url}
+
+ ) +} +export default React.memo(CrawledResultItem) diff --git a/web/app/components/datasets/create/website/jina-reader/crawled-result.tsx b/web/app/components/datasets/create/website/jina-reader/crawled-result.tsx new file mode 100644 index 0000000000..2bd51e4d73 --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/crawled-result.tsx @@ -0,0 +1,87 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import CheckboxWithLabel from './base/checkbox-with-label' +import CrawledResultItem from './crawled-result-item' +import cn from '@/utils/classnames' +import type { CrawlResultItem } from '@/models/datasets' + +const I18N_PREFIX = 'datasetCreation.stepOne.website' + +type Props = { + className?: string + list: CrawlResultItem[] + checkedList: CrawlResultItem[] + onSelectedChange: (selected: CrawlResultItem[]) => void + onPreview: (payload: CrawlResultItem) => void + usedTime: number +} + +const CrawledResult: FC = ({ + className = '', + list, + checkedList, + onSelectedChange, + onPreview, + usedTime, +}) => { + const { t } = useTranslation() + + const isCheckAll = checkedList.length === list.length + + const handleCheckedAll = useCallback(() => { + if (!isCheckAll) + onSelectedChange(list) + + else + onSelectedChange([]) + }, [isCheckAll, list, onSelectedChange]) + + const handleItemCheckChange = useCallback((item: CrawlResultItem) => { + return (checked: boolean) => { + if (checked) + onSelectedChange([...checkedList, item]) + + else + onSelectedChange(checkedList.filter(checkedItem => checkedItem.source_url !== item.source_url)) + } + }, [checkedList, onSelectedChange]) + + const [previewIndex, setPreviewIndex] = React.useState(-1) + const handlePreview = useCallback((index: number) => { + return () => { + setPreviewIndex(index) + onPreview(list[index]) + } + }, [list, onPreview]) + + return ( +
+
+ +
{t(`${I18N_PREFIX}.scrapTimeInfo`, { + total: list.length, + time: usedTime.toFixed(1), + })}
+
+
+ {list.map((item, index) => ( + checkedItem.source_url === item.source_url)} + onCheckChange={handleItemCheckChange(item)} + /> + ))} +
+
+ ) +} +export default React.memo(CrawledResult) diff --git a/web/app/components/datasets/create/website/jina-reader/crawling.tsx b/web/app/components/datasets/create/website/jina-reader/crawling.tsx new file mode 100644 index 0000000000..ee26e7671a --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/crawling.tsx @@ -0,0 +1,37 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import cn from '@/utils/classnames' +import { RowStruct } from '@/app/components/base/icons/src/public/other' + +type Props = { + className?: string + crawledNum: number + totalNum: number +} + +const Crawling: FC = ({ + className = '', + crawledNum, + totalNum, +}) => { + const { t } = useTranslation() + + return ( +
+
+ {t('datasetCreation.stepOne.website.totalPageScraped')} {crawledNum}/{totalNum} +
+ +
+ {['', '', '', ''].map((item, index) => ( +
+ +
+ ))} +
+
+ ) +} +export default React.memo(Crawling) diff --git a/web/app/components/datasets/create/website/jina-reader/mock-crawl-result.ts b/web/app/components/datasets/create/website/jina-reader/mock-crawl-result.ts new file mode 100644 index 0000000000..8fd5e6636f --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/mock-crawl-result.ts @@ -0,0 +1,24 @@ +import type { CrawlResultItem } from '@/models/datasets' + +const result: CrawlResultItem[] = [ + { + title: 'Start the frontend Docker container separately', + markdown: 'Markdown 1', + description: 'Description 1', + source_url: 'https://example.com/1', + }, + { + title: 'Advanced Tool Integration', + markdown: 'Markdown 2', + description: 'Description 2', + source_url: 'https://example.com/2', + }, + { + title: 'Local Source Code Start | English | Dify', + markdown: 'Markdown 3', + description: 'Description 3', + source_url: 'https://example.com/3', + }, +] + +export default result
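
How the pieces above fit together, as a minimal sketch (illustrative only, not part of the patch; assumes an app context with the database initialized and VECTOR_STORE_WHITELIST_ENABLE set to true):

    from extensions.ext_database import db
    from models.dataset import Whitelist

    def whitelist_tenant_for_tidb_on_qdrant(tenant_id: str) -> None:
        # When a dataset has no index_struct yet, Vector._init_vector() looks for a
        # "vector_db" whitelist row and routes the tenant to VectorType.TIDB_ON_QDRANT;
        # TidbOnQdrantVectorFactory then binds an idle TidbAuthBinding cluster to the
        # tenant (or provisions a new one) on first use.
        db.session.add(Whitelist(tenant_id=tenant_id, category="vector_db"))
        db.session.commit()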