diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 74515f995..be650a61c 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -167,11 +167,13 @@ class TenantLLMService(CommonService):
         else:
             assert False, "LLM type error"
 
+        llm_name = mdlnm.split("@")[0] if "@" in mdlnm else mdlnm
+
         num = 0
         try:
-            for u in cls.query(tenant_id=tenant_id, llm_name=mdlnm):
+            for u in cls.query(tenant_id=tenant_id, llm_name=llm_name):
                 num += cls.model.update(used_tokens=u.used_tokens + used_tokens)\
-                    .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
+                    .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name)\
                     .execute()
         except Exception as e:
             pass
@@ -207,7 +209,7 @@ class LLMBundle(object):
         if not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens):
             database_logger.error(
-                "Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
+                "Can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
         return emd, used_tokens
 
     def encode_queries(self, query: str):
@@ -215,7 +217,7 @@
         if not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens):
             database_logger.error(
-                "Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
+                "Can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
         return emd, used_tokens
 
     def similarity(self, query: str, texts: list):
@@ -223,7 +225,7 @@
         if not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens):
             database_logger.error(
-                "Can't update token usage for {}/RERANK".format(self.tenant_id))
+                "Can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))
         return sim, used_tokens
 
     def describe(self, image, max_tokens=300):
@@ -231,7 +233,7 @@
         if not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens):
             database_logger.error(
-                "Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id))
+                "Can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
         return txt
 
     def transcription(self, audio):
@@ -239,7 +241,7 @@
         if not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens):
             database_logger.error(
-                "Can't update token usage for {}/SEQUENCE2TXT".format(self.tenant_id))
+                "Can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
         return txt
 
     def tts(self, text):
@@ -254,10 +256,10 @@
 
     def chat(self, system, history, gen_conf):
         txt, used_tokens = self.mdl.chat(system, history, gen_conf)
-        if not TenantLLMService.increase_usage(
+        if isinstance(txt, int) and not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens, self.llm_name):
             database_logger.error(
-                "Can't update token usage for {}/CHAT".format(self.tenant_id))
+                "Can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
         return txt
 
     def chat_streamly(self, system, history, gen_conf):
@@ -266,6 +268,6 @@
             if not TenantLLMService.increase_usage(
                     self.tenant_id, self.llm_type, txt, self.llm_name):
                 database_logger.error(
-                    "Can't update token usage for {}/CHAT".format(self.tenant_id))
+                    "Can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
             return
         yield txt
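
Note (reviewer sketch, not part of the patch): the first hunk keys token accounting on the bare model name, since tenant-level model names may carry a "@<factory>" suffix while the stored llm_name does not. A minimal standalone illustration of that normalization, with the example name chosen for illustration only:

    # Hypothetical stand-alone version of the llm_name normalization added above.
    def normalize_llm_name(mdlnm: str) -> str:
        # "qwen-turbo@Tongyi-Qianwen" -> "qwen-turbo"; plain names pass through unchanged.
        return mdlnm.split("@")[0] if "@" in mdlnm else mdlnm

    assert normalize_llm_name("qwen-turbo@Tongyi-Qianwen") == "qwen-turbo"
    assert normalize_llm_name("qwen-turbo") == "qwen-turbo"
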
diff --git a/conf/llm_factories.json b/conf/llm_factories.json
index 89d9334ed..7e7aee46b 100644
--- a/conf/llm_factories.json
+++ b/conf/llm_factories.json
@@ -89,9 +89,15 @@
         {
             "name": "Tongyi-Qianwen",
             "logo": "",
-            "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+            "tags": "LLM,TEXT EMBEDDING,TEXT RE-RANK,SPEECH2TEXT,MODERATION",
             "status": "1",
             "llm": [
+                {
+                    "llm_name": "qwen-long",
+                    "tags": "LLM,CHAT,10000K",
+                    "max_tokens": 1000000,
+                    "model_type": "chat"
+                },
                 {
                     "llm_name": "qwen-turbo",
                     "tags": "LLM,CHAT,8K",
@@ -139,6 +145,12 @@
                     "tags": "LLM,CHAT,IMAGE2TEXT",
                     "max_tokens": 765,
                     "model_type": "image2text"
+                },
+                {
+                    "llm_name": "gte-rerank",
+                    "tags": "RE-RANK,4k",
+                    "max_tokens": 4000,
+                    "model_type": "rerank"
                 }
             ]
         },
diff --git a/graphrag/graph_extractor.py b/graphrag/graph_extractor.py
index dce15cbf7..e56a24780 100644
--- a/graphrag/graph_extractor.py
+++ b/graphrag/graph_extractor.py
@@ -164,6 +164,7 @@ class GraphExtractor:
         text = perform_variable_replacements(self._extraction_prompt, variables=variables)
         gen_conf = {"temperature": 0.3}
         response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
+        if response.find("**ERROR**") >= 0: raise Exception(response)
         token_count = num_tokens_from_string(text + response)
 
         results = response or ""
diff --git a/graphrag/index.py b/graphrag/index.py
index 02d978f71..018b438b9 100644
--- a/graphrag/index.py
+++ b/graphrag/index.py
@@ -13,6 +13,7 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 #
+import os
 from concurrent.futures import ThreadPoolExecutor
 import json
 from functools import reduce
@@ -64,7 +65,8 @@ def build_knowledge_graph_chunks(tenant_id: str, chunks: List[str], callback, en
     texts, graphs = [], []
     cnt = 0
     threads = []
-    exe = ThreadPoolExecutor(max_workers=50)
+    max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 50))
+    exe = ThreadPoolExecutor(max_workers=max_workers)
     for i in range(len(chunks)):
         tkn_cnt = num_tokens_from_string(chunks[i])
         if cnt+tkn_cnt >= left_token_count and texts:
diff --git a/graphrag/mind_map_extractor.py b/graphrag/mind_map_extractor.py
index 2bedf9639..d25b24c24 100644
--- a/graphrag/mind_map_extractor.py
+++ b/graphrag/mind_map_extractor.py
@@ -16,6 +16,7 @@
 
 import collections
 import logging
+import os
 import re
 import logging
 import traceback
@@ -89,7 +90,8 @@ class MindMapExtractor:
         prompt_variables = {}
 
         try:
-            exe = ThreadPoolExecutor(max_workers=12)
+            max_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12))
+            exe = ThreadPoolExecutor(max_workers=max_workers)
             threads = []
             token_count = max(self._llm.max_length * 0.8, self._llm.max_length-512)
             texts = []
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index 8ad059fb0..11f248d17 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -122,7 +122,8 @@ RerankModel = {
     "TogetherAI": TogetherAIRerank,
     "SILICONFLOW": SILICONFLOWRerank,
     "BaiduYiyan": BaiduYiyanRerank,
-    "Voyage AI": VoyageRerank
+    "Voyage AI": VoyageRerank,
+    "Tongyi-Qianwen": QWenRerank,
 }
 
 Seq2txtModel = {
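
Note (reviewer sketch, not part of the patch): both graphrag changes follow the same pattern, replacing a hard-coded ThreadPoolExecutor size with an environment override that falls back to the old default. A minimal standalone illustration of that pattern:

    # Illustration of the env-override pattern used in graphrag/index.py and
    # graphrag/mind_map_extractor.py; variable names match the patch and the
    # defaults match the previous hard-coded values (50 and 12).
    import os
    from concurrent.futures import ThreadPoolExecutor

    graph_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 50))
    mindmap_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12))

    with ThreadPoolExecutor(max_workers=graph_workers) as exe:
        # the pool is used exactly as before, only its size is now configurable
        print(exe.submit(sum, [1, 2, 3]).result())
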
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index bf91dbf0f..f5b584b29 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -31,7 +31,8 @@ import asyncio
 
 class Base(ABC):
     def __init__(self, key, model_name, base_url):
-        self.client = OpenAI(api_key=key, base_url=base_url)
+        timeout = int(os.environ.get('LM_TIMEOUT_SECONDS', 600))
+        self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
         self.model_name = model_name
 
     def chat(self, system, history, gen_conf):
@@ -216,28 +217,39 @@ class QWenChat(Base):
         self.model_name = model_name
 
     def chat(self, system, history, gen_conf):
-        from http import HTTPStatus
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        response = Generation.call(
-            self.model_name,
-            messages=history,
-            result_format='message',
-            **gen_conf
-        )
-        ans = ""
-        tk_count = 0
-        if response.status_code == HTTPStatus.OK:
-            ans += response.output.choices[0]['message']['content']
-            tk_count += response.usage.total_tokens
-            if response.output.choices[0].get("finish_reason", "") == "length":
-                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
-            return ans, tk_count
+        stream_flag = str(os.environ.get('QWEN_CHAT_BY_STREAM', 'true')).lower() == 'true'
+        if not stream_flag:
+            from http import HTTPStatus
+            if system:
+                history.insert(0, {"role": "system", "content": system})
 
-        return "**ERROR**: " + response.message, tk_count
+            response = Generation.call(
+                self.model_name,
+                messages=history,
+                result_format='message',
+                **gen_conf
+            )
+            ans = ""
+            tk_count = 0
+            if response.status_code == HTTPStatus.OK:
+                ans += response.output.choices[0]['message']['content']
+                tk_count += response.usage.total_tokens
+                if response.output.choices[0].get("finish_reason", "") == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                return ans, tk_count
 
-    def chat_streamly(self, system, history, gen_conf):
+            return "**ERROR**: " + response.message, tk_count
+        else:
+            g = self._chat_streamly(system, history, gen_conf, incremental_output=True)
+            result_list = list(g)
+            error_msg_list = [item for item in result_list if str(item).find("**ERROR**") >= 0]
+            if len(error_msg_list) > 0:
+                return "**ERROR**: " + "".join(error_msg_list) , 0
+            else:
+                return "".join(result_list[:-1]), result_list[-1]
+
+    def _chat_streamly(self, system, history, gen_conf, incremental_output=False):
         from http import HTTPStatus
         if system:
             history.insert(0, {"role": "system", "content": system})
@@ -249,6 +261,7 @@ class QWenChat(Base):
             messages=history,
             result_format='message',
             stream=True,
+            incremental_output=incremental_output,
             **gen_conf
         )
         for resp in response:
@@ -267,6 +280,9 @@
 
         yield tk_count
 
+    def chat_streamly(self, system, history, gen_conf):
+        return self._chat_streamly(system, history, gen_conf)
+
 
 class ZhipuChat(Base):
     def __init__(self, key, model_name="glm-3-turbo", **kwargs):
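
Note (reviewer sketch, not part of the patch): with QWEN_CHAT_BY_STREAM enabled (the default), QWenChat.chat() now drives the streaming path internally. _chat_streamly yields text chunks and ends with an int token count, so the wrapper joins everything but the last item as the answer and returns the last item as the usage figure. A minimal standalone illustration of that generator contract, using a fake generator in place of the DashScope call:

    # Hypothetical stand-in for _chat_streamly: text chunks first, token count last,
    # errors surfaced as strings containing "**ERROR**".
    def fake_chat_streamly():
        yield "Hello, "
        yield "world."
        yield 12  # final item: total token count

    result_list = list(fake_chat_streamly())
    error_msg_list = [item for item in result_list if str(item).find("**ERROR**") >= 0]
    if error_msg_list:
        answer, tokens = "**ERROR**: " + "".join(error_msg_list), 0
    else:
        answer, tokens = "".join(result_list[:-1]), result_list[-1]
    assert (answer, tokens) == ("Hello, world.", 12)
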
diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py
index 7ef8bcf5d..f68869c12 100644
--- a/rag/llm/rerank_model.py
+++ b/rag/llm/rerank_model.py
@@ -390,3 +390,27 @@ class VoyageRerank(Base):
         for r in res.results:
             rank[r.index] = r.relevance_score
         return rank, res.total_tokens
+
+class QWenRerank(Base):
+    def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs):
+        import dashscope
+        self.api_key = key
+        self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name
+
+    def similarity(self, query: str, texts: list):
+        import dashscope
+        from http import HTTPStatus
+        resp = dashscope.TextReRank.call(
+            api_key=self.api_key,
+            model=self.model_name,
+            query=query,
+            documents=texts,
+            top_n=len(texts),
+            return_documents=False
+        )
+        rank = np.zeros(len(texts), dtype=float)
+        if resp.status_code == HTTPStatus.OK:
+            for r in resp.output.results:
+                rank[r.index] = r.relevance_score
+            return rank, resp.usage.total_tokens
+        return rank, 0
\ No newline at end of file
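
Note (reviewer sketch, not part of the patch): the new class is registered under "Tongyi-Qianwen" in RerankModel and calls DashScope's TextReRank endpoint, returning one relevance score per input text plus the reported token usage. A minimal usage sketch, assuming the dashscope and numpy packages are installed and a valid key is available in DASHSCOPE_API_KEY:

    # Hypothetical standalone usage of the QWenRerank class added above.
    import os
    from rag.llm.rerank_model import QWenRerank

    mdl = QWenRerank(key=os.environ["DASHSCOPE_API_KEY"], model_name="gte-rerank")
    rank, total_tokens = mdl.similarity(
        query="What is retrieval-augmented generation?",
        texts=["RAG augments an LLM with retrieved context.", "Bananas are yellow."],
    )
    print(rank, total_tokens)  # scores aligned with the input texts, plus token usage
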