From 50f209204e7b81dd7eafea0381416271f60a5dd1 Mon Sep 17 00:00:00 2001
From: Yingfeng
Date: Thu, 2 Jan 2025 13:44:44 +0800
Subject: [PATCH] Synchronize with enterprise version (#4325)

### Type of change

- [x] Refactoring
---
 agent/templates/customer_service.json |   6 +-
 rag/app/knowledge_graph.py            |   4 +-
 rag/app/manual.py                     |   4 +-
 rag/llm/chat_model.py                 |   2 +-
 rag/llm/tts_model.py                  |  29 ++++++-
 rag/svr/cache_file_svr.py             | 118 +++++++++++++-------
 6 files changed, 94 insertions(+), 69 deletions(-)

diff --git a/agent/templates/customer_service.json b/agent/templates/customer_service.json
index e8aa89b63..edc9931c1 100644
--- a/agent/templates/customer_service.json
+++ b/agent/templates/customer_service.json
@@ -336,7 +336,7 @@
         "parameters": [],
         "presencePenaltyEnabled": true,
         "presence_penalty": 0.4,
-        "prompt": "Role: You are a customer support. \n\nTask: Please answer the question based on content of knowledge base. \n\nReuirements & restrictions:\n - DO NOT make things up when all knowledge base content is irrelevant to the question. \n - Answers need to consider chat history.\n - Request about customer's contact information like, Wechat number, LINE number, twitter, discord, etc,. , when knowlegebase content can't answer his question. So, product expert could contact him soon to solve his problem.\n\n Knowledge base content is as following:\n {input}\n The above is the content of knowledge base.",
+        "prompt": "Role: You are a customer support. \n\nTask: Please answer the question based on content of knowledge base. \n\nRequirements & restrictions:\n - DO NOT make things up when all knowledge base content is irrelevant to the question. \n - Answers need to consider chat history.\n - Request about customer's contact information like, Wechat number, LINE number, twitter, discord, etc,. , when knowledge base content can't answer his question. So, product expert could contact him soon to solve his problem.\n\n Knowledge base content is as following:\n {input}\n The above is the content of knowledge base.",
         "temperature": 0.1,
         "temperatureEnabled": true,
         "topPEnabled": true,
@@ -603,7 +603,7 @@
     {
       "data": {
         "form": {
-          "text": "Static messages.\nDefine replys after recieve user's contact information."
+          "text": "Static messages.\nDefine response after receive user's contact information."
         },
         "label": "Note",
         "name": "N: What else?"
@@ -691,7 +691,7 @@
     {
      "data": {
        "form": {
-          "text": "Complete questions by conversation history.\nUser: What's RAGFlow?\nAssistant: RAGFlow is xxx.\nUser: How to deloy it?\n\nRefine it: How to deploy RAGFlow?"
+          "text": "Complete questions by conversation history.\nUser: What's RAGFlow?\nAssistant: RAGFlow is xxx.\nUser: How to deploy it?\n\nRefine it: How to deploy RAGFlow?"
        },
        "label": "Note",
        "name": "N: Refine Question"
diff --git a/rag/app/knowledge_graph.py b/rag/app/knowledge_graph.py
index b252d5615..f4330f482 100644
--- a/rag/app/knowledge_graph.py
+++ b/rag/app/knowledge_graph.py
@@ -9,7 +9,7 @@ def chunk(filename, binary, tenant_id, from_page=0, to_page=100000,
           lang="Chinese", callback=None, **kwargs):
     parser_config = kwargs.get(
         "parser_config", {
-            "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": True})
+            "chunk_token_num": 512, "delimiter": "\n!?;。;!?", "layout_recognize": True})
     eng = lang.lower() == "english"
 
     parser_config["layout_recognize"] = True
@@ -29,4 +29,4 @@ def chunk(filename, binary, tenant_id, from_page=0, to_page=100000,
     doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
     chunks.extend(tokenize_chunks(sections, doc, eng))
 
-    return chunks
\ No newline at end of file
+    return chunks
diff --git a/rag/app/manual.py b/rag/app/manual.py
index c60df258e..8a3480907 100644
--- a/rag/app/manual.py
+++ b/rag/app/manual.py
@@ -256,7 +256,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
         return res
 
-    elif re.search(r"\.docx$", filename, re.IGNORECASE):
+    elif re.search(r"\.docx?$", filename, re.IGNORECASE):
         docx_parser = Docx()
         ti_list, tbls = docx_parser(filename, binary,
                                     from_page=0, to_page=10000, callback=callback)
@@ -279,4 +279,4 @@ if __name__ == "__main__":
 
         pass
 
-    chunk(sys.argv[1], callback=dummy)
\ No newline at end of file
+    chunk(sys.argv[1], callback=dummy)
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index cf038cb43..f7e12b4d7 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -24,7 +24,6 @@ import openai
 from ollama import Client
 from rag.nlp import is_chinese, is_english
 from rag.utils import num_tokens_from_string
-from groq import Groq
 import os
 import json
 import requests
@@ -840,6 +839,7 @@ class GeminiChat(Base):
 
 
 class GroqChat:
     def __init__(self, key, model_name, base_url=''):
+        from groq import Groq
         self.client = Groq(api_key=key)
         self.model_name = model_name
diff --git a/rag/llm/tts_model.py b/rag/llm/tts_model.py
index 3ce0e4487..fa82cc827 100644
--- a/rag/llm/tts_model.py
+++ b/rag/llm/tts_model.py
@@ -299,8 +299,6 @@ class SparkTTS:
             yield audio_chunk
 
 
-
-
 class XinferenceTTS:
     def __init__(self, key, model_name, **kwargs):
         self.base_url = kwargs.get("base_url", None)
@@ -330,3 +328,30 @@ class XinferenceTTS:
         for chunk in response.iter_content(chunk_size=1024):
             if chunk:
                 yield chunk
+
+
+class OllamaTTS(Base):
+    def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"):
+        if not base_url:
+            base_url = "https://api.ollama.ai/v1"
+        self.model_name = model_name
+        self.base_url = base_url
+        self.headers = {
+            "Content-Type": "application/json"
+        }
+
+    def tts(self, text, voice="standard-voice"):
+        payload = {
+            "model": self.model_name,
+            "voice": voice,
+            "input": text
+        }
+
+        response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True)
+
+        if response.status_code != 200:
+            raise Exception(f"**Error**: {response.status_code}, {response.text}")
+
+        for chunk in response.iter_content():
+            if chunk:
+                yield chunk
diff --git a/rag/svr/cache_file_svr.py b/rag/svr/cache_file_svr.py
index 8b96a2af5..81be82f0b 100644
--- a/rag/svr/cache_file_svr.py
+++ b/rag/svr/cache_file_svr.py
@@ -1,60 +1,60 @@
-#
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import logging
-import time
-import traceback
-
-from api.db.db_models import close_connection
-from api.db.services.task_service import TaskService
-from rag.utils.storage_factory import STORAGE_IMPL
-from rag.utils.redis_conn import REDIS_CONN
-
-
-def collect():
-    doc_locations = TaskService.get_ongoing_doc_name()
-    logging.debug(doc_locations)
-    if len(doc_locations) == 0:
-        time.sleep(1)
-        return
-    return doc_locations
-
-def main():
-    locations = collect()
-    if not locations:
-        return
-    logging.info(f"TASKS: {len(locations)}")
-    for kb_id, loc in locations:
-        try:
-            if REDIS_CONN.is_alive():
-                try:
-                    key = "{}/{}".format(kb_id, loc)
-                    if REDIS_CONN.exist(key):
-                        continue
-                    file_bin = STORAGE_IMPL.get(kb_id, loc)
-                    REDIS_CONN.transaction(key, file_bin, 12 * 60)
-                    logging.info("CACHE: {}".format(loc))
-                except Exception as e:
-                    traceback.print_stack(e)
-        except Exception as e:
-            traceback.print_stack(e)
-
-
-
-if __name__ == "__main__":
-    while True:
-        main()
-        close_connection()
+#
+# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import logging
+import time
+import traceback
+
+from api.db.db_models import close_connection
+from api.db.services.task_service import TaskService
+from rag.utils.minio_conn import MINIOs
+from rag.utils.redis_conn import REDIS_CONN
+
+
+def collect():
+    doc_locations = TaskService.get_ongoing_doc_name()
+    logging.debug(doc_locations)
+    if len(doc_locations) == 0:
+        time.sleep(1)
+        return
+    return doc_locations
+
+
+def main():
+    locations = collect()
+    if not locations:
+        return
+    logging.info(f"TASKS: {len(locations)}")
+    for kb_id, loc in locations:
+        try:
+            if REDIS_CONN.is_alive():
+                try:
+                    key = "{}/{}".format(kb_id, loc)
+                    if REDIS_CONN.exist(key):
+                        continue
+                    file_bin = MINIOs.get(kb_id, loc)
+                    REDIS_CONN.transaction(key, file_bin, 12 * 60)
+                    logging.info("CACHE: {}".format(loc))
+                except Exception as e:
+                    traceback.print_stack(e)
+        except Exception as e:
+            traceback.print_stack(e)
+
+
+if __name__ == "__main__":
+    while True:
+        main()
+        close_connection()
         time.sleep(1)
\ No newline at end of file
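Usage note: the `OllamaTTS` class added to `rag/llm/tts_model.py` exposes `tts()` as a generator that streams raw audio bytes from the `/audio/tts` endpoint. Below is a minimal consumption sketch, assuming the class is importable exactly as added in this patch and that the configured `base_url` is reachable; the output file name is an illustrative placeholder, and the `voice` value is simply the default from the patch.

```python
# Minimal sketch: write the byte chunks yielded by the new OllamaTTS.tts()
# generator to a local file. Assumes network access to the configured base_url;
# "out.wav" is an illustrative placeholder, not part of the patch.
from rag.llm.tts_model import OllamaTTS


def synthesize_to_file(text, path="out.wav"):
    # key is accepted by __init__ but is not placed in the request headers in this patch
    tts = OllamaTTS(key=None, model_name="ollama-tts")
    with open(path, "wb") as f:
        for chunk in tts.tts(text, voice="standard-voice"):  # yields audio bytes as they arrive
            f.write(chunk)


if __name__ == "__main__":
    synthesize_to_file("Hello from RAGFlow.")
```

The same consumption pattern applies to the other streaming TTS classes in this module (for example `XinferenceTTS`), since their `tts()` methods are also generators of byte chunks.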