mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-10 21:19:02 +08:00
Feat: support cross-lang search. (#7557)
### What problem does this PR solve? #7376 #4503 #5710 #7470 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
2fe332d01d
commit
2ccec93d71
@ -22,7 +22,7 @@ from flask_login import login_required, current_user
|
|||||||
from rag.app.qa import rmPrefix, beAdoc
|
from rag.app.qa import rmPrefix, beAdoc
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from rag.nlp import search, rag_tokenizer
|
from rag.nlp import search, rag_tokenizer
|
||||||
from rag.prompts import keyword_extraction
|
from rag.prompts import keyword_extraction, cross_languages
|
||||||
from rag.settings import PAGERANK_FLD
|
from rag.settings import PAGERANK_FLD
|
||||||
from rag.utils import rmSpace
|
from rag.utils import rmSpace
|
||||||
from api.db import LLMType, ParserType
|
from api.db import LLMType, ParserType
|
||||||
@ -275,6 +275,7 @@ def retrieval_test():
|
|||||||
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
||||||
use_kg = req.get("use_kg", False)
|
use_kg = req.get("use_kg", False)
|
||||||
top = int(req.get("top_k", 1024))
|
top = int(req.get("top_k", 1024))
|
||||||
|
langs = req.get("cross_languages", [])
|
||||||
tenant_ids = []
|
tenant_ids = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -294,6 +295,9 @@ def retrieval_test():
|
|||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Knowledgebase not found!")
|
return get_data_error_result(message="Knowledgebase not found!")
|
||||||
|
|
||||||
|
if langs:
|
||||||
|
question = cross_languages(kb.tenant_id, None, question, langs)
|
||||||
|
|
||||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||||
|
|
||||||
rerank_mdl = None
|
rerank_mdl = None
|
||||||
|
@ -36,7 +36,8 @@ from api.utils import current_timestamp, datetime_format
|
|||||||
from rag.app.resume import forbidden_select_fields4resume
|
from rag.app.resume import forbidden_select_fields4resume
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from rag.nlp.search import index_name
|
from rag.nlp.search import index_name
|
||||||
from rag.prompts import chunks_format, citation_prompt, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
|
from rag.prompts import chunks_format, citation_prompt, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in, \
|
||||||
|
cross_languages
|
||||||
from rag.utils import num_tokens_from_string, rmSpace
|
from rag.utils import num_tokens_from_string, rmSpace
|
||||||
from rag.utils.tavily_conn import Tavily
|
from rag.utils.tavily_conn import Tavily
|
||||||
|
|
||||||
@ -214,6 +215,9 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
else:
|
else:
|
||||||
questions = questions[-1:]
|
questions = questions[-1:]
|
||||||
|
|
||||||
|
if prompt_config.get("cross_languages"):
|
||||||
|
questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
|
||||||
|
|
||||||
refine_question_ts = timer()
|
refine_question_ts = timer()
|
||||||
|
|
||||||
rerank_mdl = None
|
rerank_mdl = None
|
||||||
|
@ -131,7 +131,7 @@ class DocumentService(CommonService):
|
|||||||
if types:
|
if types:
|
||||||
query = query.where(cls.model.type.in_(types))
|
query = query.where(cls.model.type.in_(types))
|
||||||
|
|
||||||
return query.scalar() or 0
|
return int(query.scalar()) or 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
|
@ -306,6 +306,60 @@ Output: What's the weather in Rochester on {tomorrow}?
|
|||||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||||
return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
|
return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
|
||||||
|
|
||||||
|
def cross_languages(tenant_id, llm_id, query, languages=None):
    """Translate *query* into each target language via a chat LLM.

    Args:
        tenant_id: Tenant that owns the LLM configuration.
        llm_id: Optional model id; when it resolves to an image2text
            model that bundle is used, otherwise a plain chat bundle.
        query: The user question to translate.
        languages: List of target language names. When empty or None,
            *query* is returned unchanged.

    Returns:
        The translations joined by newlines, or the original *query*
        when the model reports an error (answer contains **ERROR**).
    """
    # Imported lazily to avoid a circular import at module load time.
    from api.db.services.llm_service import LLMBundle

    # Avoid the mutable-default-argument pitfall and short-circuit when
    # there is nothing to translate into.
    if not languages:
        return query

    if llm_id and llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)

    sys_prompt = """
Act as a streamlined multilingual translator. Strictly output translations separated by ### without any explanations or formatting. Follow these rules:

1. Accept batch translation requests in format:
[source text]
===
[target languages separated by commas]

2. Always maintain:
- Original formatting (tables/lists/spacing)
- Technical terminology accuracy
- Cultural context appropriateness

3. Output format:
[language1 translation]
###
[language2 translation]

**Examples:**
Input:
Hello World! Let's discuss AI safety.
===
Chinese, French, Japanese

Output:
你好世界！让我们讨论人工智能安全问题。
###
Bonjour le monde ! Parlons de la sécurité de l'IA.
###
こんにちは世界！AIの安全性について話し合いましょう。
"""
    user_prompt = f"""
Input:
{query}
===
{', '.join(languages)}

Output:
"""

    # Low temperature: translation should be deterministic, not creative.
    ans = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.2})
    # Strip any chain-of-thought prefix emitted by reasoning models.
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
    # Fall back to the untranslated query on model failure.
    if ans.find("**ERROR**") >= 0:
        return query
    # The system prompt instructs the model to separate translations
    # with "###" (not "==="); split on that separator after dropping
    # the "Output:" echo and newline runs.
    cleaned = re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL)
    return "\n".join([part for part in cleaned.split("###") if part.strip()])
|
||||||
|
|
||||||
|
|
||||||
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
|
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user