mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 16:19:02 +08:00
refactor: Add an enumeration type and use the factory pattern to obtain the corresponding class (#9356)
This commit is contained in:
parent
5908fd6552
commit
cd7ab6231f
@ -1,8 +1,8 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.keyword.jieba.jieba import Jieba
|
|
||||||
from core.rag.datasource.keyword.keyword_base import BaseKeyword
|
from core.rag.datasource.keyword.keyword_base import BaseKeyword
|
||||||
|
from core.rag.datasource.keyword.keyword_type import KeyWordType
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
|
|
||||||
@ -13,16 +13,19 @@ class Keyword:
|
|||||||
self._keyword_processor = self._init_keyword()
|
self._keyword_processor = self._init_keyword()
|
||||||
|
|
||||||
def _init_keyword(self) -> BaseKeyword:
|
def _init_keyword(self) -> BaseKeyword:
|
||||||
config = dify_config
|
keyword_type = dify_config.KEYWORD_STORE
|
||||||
keyword_type = config.KEYWORD_STORE
|
keyword_factory = self.get_keyword_factory(keyword_type)
|
||||||
|
return keyword_factory(self._dataset)
|
||||||
|
|
||||||
if not keyword_type:
|
@staticmethod
|
||||||
raise ValueError("Keyword store must be specified.")
|
def get_keyword_factory(keyword_type: str) -> type[BaseKeyword]:
|
||||||
|
match keyword_type:
|
||||||
|
case KeyWordType.JIEBA:
|
||||||
|
from core.rag.datasource.keyword.jieba.jieba import Jieba
|
||||||
|
|
||||||
if keyword_type == "jieba":
|
return Jieba
|
||||||
return Jieba(dataset=self._dataset)
|
case _:
|
||||||
else:
|
raise ValueError(f"Keyword store {keyword_type} is not supported.")
|
||||||
raise ValueError(f"Keyword store {keyword_type} is not supported.")
|
|
||||||
|
|
||||||
def create(self, texts: list[Document], **kwargs):
|
def create(self, texts: list[Document], **kwargs):
|
||||||
self._keyword_processor.create(texts, **kwargs)
|
self._keyword_processor.create(texts, **kwargs)
|
||||||
|
5
api/core/rag/datasource/keyword/keyword_type.py
Normal file
5
api/core/rag/datasource/keyword/keyword_type.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class KeyWordType(str, Enum):
|
||||||
|
JIEBA = "jieba"
|
@ -1,15 +1,25 @@
|
|||||||
from services.auth.firecrawl import FirecrawlAuth
|
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||||
from services.auth.jina import JinaAuth
|
from services.auth.auth_type import AuthType
|
||||||
|
|
||||||
|
|
||||||
class ApiKeyAuthFactory:
|
class ApiKeyAuthFactory:
|
||||||
def __init__(self, provider: str, credentials: dict):
|
def __init__(self, provider: str, credentials: dict):
|
||||||
if provider == "firecrawl":
|
auth_factory = self.get_apikey_auth_factory(provider)
|
||||||
self.auth = FirecrawlAuth(credentials)
|
self.auth = auth_factory(credentials)
|
||||||
elif provider == "jinareader":
|
|
||||||
self.auth = JinaAuth(credentials)
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid provider")
|
|
||||||
|
|
||||||
def validate_credentials(self):
|
def validate_credentials(self):
|
||||||
return self.auth.validate_credentials()
|
return self.auth.validate_credentials()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_apikey_auth_factory(provider: str) -> type[ApiKeyAuthBase]:
|
||||||
|
match provider:
|
||||||
|
case AuthType.FIRECRAWL:
|
||||||
|
from services.auth.firecrawl.firecrawl import FirecrawlAuth
|
||||||
|
|
||||||
|
return FirecrawlAuth
|
||||||
|
case AuthType.JINA:
|
||||||
|
from services.auth.jina.jina import JinaAuth
|
||||||
|
|
||||||
|
return JinaAuth
|
||||||
|
case _:
|
||||||
|
raise ValueError("Invalid provider")
|
||||||
|
6
api/services/auth/auth_type.py
Normal file
6
api/services/auth/auth_type.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class AuthType(str, Enum):
|
||||||
|
FIRECRAWL = "firecrawl"
|
||||||
|
JINA = "jinareader"
|
0
api/services/auth/firecrawl/__init__.py
Normal file
0
api/services/auth/firecrawl/__init__.py
Normal file
0
api/services/auth/jina/__init__.py
Normal file
0
api/services/auth/jina/__init__.py
Normal file
Loading…
x
Reference in New Issue
Block a user