diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py
index aff1ba866..b9442ba86 100644
--- a/api/apps/chunk_app.py
+++ b/api/apps/chunk_app.py
@@ -22,7 +22,7 @@ from flask_login import login_required, current_user
 from rag.app.qa import rmPrefix, beAdoc
 from rag.app.tag import label_question
 from rag.nlp import search, rag_tokenizer
-from rag.prompts import keyword_extraction
+from rag.prompts import keyword_extraction, cross_languages
 from rag.settings import PAGERANK_FLD
 from rag.utils import rmSpace
 from api.db import LLMType, ParserType
@@ -275,6 +275,7 @@ def retrieval_test():
     vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
     use_kg = req.get("use_kg", False)
     top = int(req.get("top_k", 1024))
+    langs = req.get("cross_languages", [])
 
     tenant_ids = []
     try:
@@ -294,6 +295,9 @@ def retrieval_test():
         if not e:
             return get_data_error_result(message="Knowledgebase not found!")
 
+        if langs:
+            question = cross_languages(kb.tenant_id, None, question, langs)
+
         embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
 
         rerank_mdl = None
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index c98756ae1..1d6ed6ca0 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -36,7 +36,8 @@ from api.utils import current_timestamp, datetime_format
 from rag.app.resume import forbidden_select_fields4resume
 from rag.app.tag import label_question
 from rag.nlp.search import index_name
-from rag.prompts import chunks_format, citation_prompt, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
+from rag.prompts import chunks_format, citation_prompt, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in, \
+    cross_languages
 from rag.utils import num_tokens_from_string, rmSpace
 from rag.utils.tavily_conn import Tavily
 
@@ -214,6 +215,9 @@ def chat(dialog, messages, stream=True, **kwargs):
     else:
         questions = questions[-1:]
 
+    if prompt_config.get("cross_languages"):
+        questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
+
     refine_question_ts = timer()
 
     rerank_mdl = None
diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py
index b4fb3e87a..284df853a 100644
--- a/api/db/services/document_service.py
+++ b/api/db/services/document_service.py
@@ -131,7 +131,7 @@ class DocumentService(CommonService):
         if types:
             query = query.where(cls.model.type.in_(types))
 
-        return query.scalar() or 0
+        return int(query.scalar() or 0)
 
     @classmethod
     @DB.connection_context()
diff --git a/rag/prompts.py b/rag/prompts.py
index 13c0fce88..7d061b810 100644
--- a/rag/prompts.py
+++ b/rag/prompts.py
@@ -306,6 +306,69 @@ Output: What's the weather in Rochester on {tomorrow}?
     ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
     return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
 
+
+def cross_languages(tenant_id, llm_id, query, languages=[]):
+    """Translate *query* into each of *languages* with the tenant's LLM.
+
+    Returns the translations joined by newlines; returns *query* unchanged
+    if the model reports **ERROR**, so callers can fall back gracefully.
+    *languages* is only read, never mutated, so the mutable default is safe.
+    """
+    from api.db.services.llm_service import LLMBundle
+
+    if llm_id and llm_id2llm_type(llm_id) == "image2text":
+        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
+    else:
+        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
+
+    sys_prompt = """
+Act as a streamlined multilingual translator. Strictly output translations separated by ### without any explanations or formatting. Follow these rules:
+
+1. Accept batch translation requests in format:
+[source text]
+===
+[target languages separated by commas]
+
+2. Always maintain:
+- Original formatting (tables/lists/spacing)
+- Technical terminology accuracy
+- Cultural context appropriateness
+
+3. Output format:
+[language1 translation]
+###
+[language2 translation]
+
+**Examples:**
+Input:
+Hello World! Let's discuss AI safety.
+===
+Chinese, French, Japanese
+
+Output:
+你好世界！让我们讨论人工智能安全问题。
+###
+Bonjour le monde ! Parlons de la sécurité de l'IA.
+###
+こんにちは世界！AIの安全性について話し合いましょう。
+"""
+    user_prompt = f"""
+Input:
+{query}
+===
+{', '.join(languages)}
+
+Output:
+"""
+
+    ans = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.2})
+    # Reasoning models may prefix a <think>...</think> block; strip it.
+    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
+    if ans.find("**ERROR**") >= 0:
+        return query
+    # Translations are separated by "###" (see sys_prompt), one per language.
+    return "\n".join([a for a in re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL).split("###") if a.strip()])
+
 
 def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
     prompt = f"""