From 66f8d35632ea2ce13b1bab95d9098e50fb71cbda Mon Sep 17 00:00:00 2001 From: KevinHuSh Date: Thu, 25 Apr 2024 14:14:28 +0800 Subject: [PATCH] Refactor (#537) ### What problem does this PR solve? ### Type of change - [x] Refactoring --- api/apps/document_app.py | 3 ++- api/apps/llm_app.py | 4 +-- api/db/db_models.py | 2 +- api/db/init_data.py | 8 +++--- api/db/services/llm_service.py | 2 +- api/db/services/task_service.py | 20 +++++++++++++++ deepdoc/parser/pdf_parser.py | 34 +++++++++++++------------- rag/llm/__init__.py | 2 +- rag/llm/embedding_model.py | 12 ++++----- rag/svr/cache_file_svr.py | 43 +++++++++++++++++++++++++++++++++ rag/svr/task_broker.py | 2 +- rag/svr/task_executor.py | 6 +++++ rag/utils/minio_conn.py | 1 - rag/utils/redis_conn.py | 19 +++++++++++++++ 14 files changed, 124 insertions(+), 34 deletions(-) create mode 100644 rag/svr/cache_file_svr.py diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 8402d121f..bfcd44f42 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -58,7 +58,8 @@ def upload(): if not e: return get_data_error_result( retmsg="Can't find this knowledgebase!") - if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)): + MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0)) + if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER: return get_data_error_result( retmsg="Exceed the maximum file number of a free user!") diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index b98316272..473b748fb 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -28,7 +28,7 @@ from rag.llm import EmbeddingModel, ChatModel def factories(): try: fac = LLMFactoriesService.get_all() - return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["QAnything", "FastEmbed"]]) + return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed"]]) except Exception as e: return server_error_response(e) @@ -174,7 +174,7 @@ def list(): llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value] for m in llms: - m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["QAnything","FastEmbed"] + m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed"] llm_set = set([m["llm_name"] for m in llms]) for o in objs: diff --git a/api/db/db_models.py b/api/db/db_models.py index e6f2d287e..00e6a5887 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -697,7 +697,7 @@ class Dialog(DataBaseModel): null=True, default="Chinese", help_text="English|Chinese") - llm_id = CharField(max_length=32, null=False, help_text="default llm ID") + llm_id = CharField(max_length=128, null=False, help_text="default llm ID") llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7, "presence_penalty": 0.4, "max_tokens": 215}) prompt_type = CharField( diff --git a/api/db/init_data.py b/api/db/init_data.py index 14cb414a1..ecd311473 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -120,7 +120,7 @@ factory_infos = [{ "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", "status": "1", },{ - "name": "QAnything", + "name": "Youdao", "logo": "", "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", "status": "1", @@ -323,7 +323,7 @@ def init_llm_factory(): "max_tokens": 2147483648, "model_type": LLMType.EMBEDDING.value }, - # ------------------------ QAnything ----------------------- + # ------------------------ Youdao ----------------------- { "fid": factory_infos[7]["name"], "llm_name": "maidalun1020/bce-embedding-base_v1", @@ -347,7 +347,9 @@ def init_llm_factory(): LLMService.filter_delete([LLM.fid == "Local"]) LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"]) TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"]) - + LLMFactoriesService.filter_update([LLMFactoriesService.model.name == "QAnything"], {"name": "Youdao"}) + LLMService.filter_update([LLMService.model.fid == "QAnything"], {"fid": "Youdao"}) + TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"}) """ drop table llm; drop table llm_factories; diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index f565da132..7f66ca6aa 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -81,7 +81,7 @@ class TenantLLMService(CommonService): if not model_config: if llm_type == LLMType.EMBEDDING.value: llm = LLMService.query(llm_name=llm_name) - if llm and llm[0].fid in ["QAnything", "FastEmbed"]: + if llm and llm[0].fid in ["Youdao", "FastEmbed"]: model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""} if not model_config: if llm_name == "flag-embedding": diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index 68ac64e24..8c6bc6e8d 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -21,6 +21,7 @@ from api.db import StatusEnum, FileType, TaskStatus from api.db.db_models import Task, Document, Knowledgebase, Tenant from api.db.services.common_service import CommonService from api.db.services.document_service import DocumentService +from api.utils import current_timestamp class TaskService(CommonService): @@ -70,6 +71,25 @@ class TaskService(CommonService): cls.model.id == docs[0]["id"]).execute() return docs + @classmethod + @DB.connection_context() + def get_ongoing_doc_name(cls): + with DB.lock("get_task", -1): + docs = cls.model.select(*[Document.kb_id, Document.location]) \ + .join(Document, on=(cls.model.doc_id == Document.id)) \ + .where( + Document.status == StatusEnum.VALID.value, + Document.run == TaskStatus.RUNNING.value, + ~(Document.type == FileType.VIRTUAL.value), + cls.model.progress >= 0, + cls.model.progress < 1, + cls.model.create_time >= current_timestamp() - 180000 + ) + docs = list(docs.dicts()) + if not docs: return [] + + return list(set([(d["kb_id"], d["location"]) for d in docs])) + @classmethod @DB.connection_context() def do_cancel(cls, id): diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index 6bda36970..96f4bdd28 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -37,8 +37,8 @@ class HuParser: self.updown_cnt_mdl.set_param({"device": "cuda"}) try: model_dir = os.path.join( - get_project_base_directory(), - "rag/res/deepdoc") + get_project_base_directory(), + "rag/res/deepdoc") self.updown_cnt_mdl.load_model(os.path.join( model_dir, "updown_concat_xgb.model")) except Exception as e: @@ -49,7 +49,6 @@ class HuParser: self.updown_cnt_mdl.load_model(os.path.join( model_dir, "updown_concat_xgb.model")) - self.page_from = 0 """ If you have trouble downloading HuggingFace models, -_^ this might help!! @@ -76,7 +75,7 @@ class HuParser: def _y_dis( self, a, b): return ( - b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2 + b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2 def _match_proj(self, b): proj_patt = [ @@ -99,9 +98,9 @@ class HuParser: tks_down = huqie.qie(down["text"][:LEN]).split(" ") tks_up = huqie.qie(up["text"][-LEN:]).split(" ") tks_all = up["text"][-LEN:].strip() \ - + (" " if re.match(r"[a-zA-Z0-9]+", - up["text"][-1] + down["text"][0]) else "") \ - + down["text"][:LEN].strip() + + (" " if re.match(r"[a-zA-Z0-9]+", + up["text"][-1] + down["text"][0]) else "") \ + + down["text"][:LEN].strip() tks_all = huqie.qie(tks_all).split(" ") fea = [ up.get("R", -1) == down.get("R", -1), @@ -123,7 +122,7 @@ class HuParser: True if re.search(r"[,,][^。.]+$", up["text"]) else False, True if re.search(r"[,,][^。.]+$", up["text"]) else False, True if re.search(r"[\((][^\))]+$", up["text"]) - and re.search(r"[\))]", down["text"]) else False, + and re.search(r"[\))]", down["text"]) else False, self._match_proj(down), True if re.match(r"[A-Z]", down["text"]) else False, True if re.match(r"[A-Z]", up["text"][-1]) else False, @@ -185,7 +184,7 @@ class HuParser: continue for tb in tbls: # for table left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \ - tb["x1"] + MARGIN, tb["bottom"] + MARGIN + tb["x1"] + MARGIN, tb["bottom"] + MARGIN left *= ZM top *= ZM right *= ZM @@ -297,7 +296,7 @@ class HuParser: for b in bxs: if not b["text"]: left, right, top, bott = b["x0"] * ZM, b["x1"] * \ - ZM, b["top"] * ZM, b["bottom"] * ZM + ZM, b["top"] * ZM, b["bottom"] * ZM b["text"] = self.ocr.recognize(np.array(img), np.array([[left, top], [right, top], [right, bott], [left, bott]], dtype=np.float32)) @@ -622,7 +621,7 @@ class HuParser: i += 1 continue lout_no = str(self.boxes[i]["page_number"]) + \ - "-" + str(self.boxes[i]["layoutno"]) + "-" + str(self.boxes[i]["layoutno"]) if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", "title", "figure caption", @@ -975,6 +974,7 @@ class HuParser: self.outlines.append((a["/Title"], depth)) continue dfs(a, depth + 1) + dfs(outlines, 0) except Exception as e: logging.warning(f"Outlines exception: {e}") @@ -984,7 +984,7 @@ class HuParser: logging.info("Images converted.") self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join( random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in - range(len(self.page_chars))] + range(len(self.page_chars))] if sum([1 if e else 0 for e in self.is_english]) > len( self.page_images) / 2: self.is_english = True @@ -1012,9 +1012,9 @@ class HuParser: j += 1 self.__ocr(i + 1, img, chars, zoomin) - #if callback: - # callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="") - #print("OCR:", timer()-st) + if callback and i % 6 == 5: + callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="") + # print("OCR:", timer()-st) if not self.is_english and not any( [c for c in self.page_chars]) and self.boxes: @@ -1050,7 +1050,7 @@ class HuParser: left, right, top, bottom = float(left), float( right), float(top), float(bottom) poss.append(([int(p) - 1 for p in pn.split("-")], - left, right, top, bottom)) + left, right, top, bottom)) if not poss: if need_position: return None, None @@ -1076,7 +1076,7 @@ class HuParser: self.page_images[pns[0]].crop((left * ZM, top * ZM, right * ZM, min( - bottom, self.page_images[pns[0]].size[1]) + bottom, self.page_images[pns[0]].size[1]) )) ) if 0 < ii < len(poss) - 1: diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 7d6d7c441..3b035a435 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -25,7 +25,7 @@ EmbeddingModel = { "Tongyi-Qianwen": HuEmbedding, #QWenEmbed, "ZHIPU-AI": ZhipuEmbed, "FastEmbed": FastEmbed, - "QAnything": QAnythingEmbed + "Youdao": YoudaoEmbed } diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index e6e18fbed..597dbfdc9 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -229,19 +229,19 @@ class XinferenceEmbed(Base): return np.array(res.data[0].embedding), res.usage.total_tokens -class QAnythingEmbed(Base): +class YoudaoEmbed(Base): _client = None def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs): from BCEmbedding import EmbeddingModel as qanthing - if not QAnythingEmbed._client: + if not YoudaoEmbed._client: try: print("LOADING BCE...") - QAnythingEmbed._client = qanthing(model_name_or_path=os.path.join( + YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join( get_project_base_directory(), "rag/res/bce-embedding-base_v1")) except Exception as e: - QAnythingEmbed._client = qanthing( + YoudaoEmbed._client = qanthing( model_name_or_path=model_name.replace( "maidalun1020", "InfiniFlow")) @@ -251,10 +251,10 @@ class QAnythingEmbed(Base): for t in texts: token_count += num_tokens_from_string(t) for i in range(0, len(texts), batch_size): - embds = QAnythingEmbed._client.encode(texts[i:i + batch_size]) + embds = YoudaoEmbed._client.encode(texts[i:i + batch_size]) res.extend(embds) return np.array(res), token_count def encode_queries(self, text): - embds = QAnythingEmbed._client.encode([text]) + embds = YoudaoEmbed._client.encode([text]) return np.array(embds[0]), num_tokens_from_string(text) diff --git a/rag/svr/cache_file_svr.py b/rag/svr/cache_file_svr.py new file mode 100644 index 000000000..4eeaab629 --- /dev/null +++ b/rag/svr/cache_file_svr.py @@ -0,0 +1,43 @@ +import random +import time +import traceback + +from api.db.db_models import close_connection +from api.db.services.task_service import TaskService +from rag.utils import MINIO +from rag.utils.redis_conn import REDIS_CONN + + +def collect(): + doc_locations = TaskService.get_ongoing_doc_name() + #print(tasks) + if len(doc_locations) == 0: + time.sleep(1) + return + return doc_locations + +def main(): + locations = collect() + if not locations:return + print("TASKS:", len(locations)) + for kb_id, loc in locations: + try: + if REDIS_CONN.is_alive(): + try: + key = "{}/{}".format(kb_id, loc) + if REDIS_CONN.exist(key):continue + file_bin = MINIO.get(kb_id, loc) + REDIS_CONN.transaction(key, file_bin, 12 * 60) + print("CACHE:", loc) + except Exception as e: + traceback.print_stack(e) + except Exception as e: + traceback.print_stack(e) + + + +if __name__ == "__main__": + while True: + main() + close_connection() + time.sleep(1) \ No newline at end of file diff --git a/rag/svr/task_broker.py b/rag/svr/task_broker.py index 126d7e882..3e43fbff2 100644 --- a/rag/svr/task_broker.py +++ b/rag/svr/task_broker.py @@ -167,7 +167,7 @@ def update_progress(): info = { "process_duation": datetime.timestamp( datetime.now()) - - d["process_begin_at"].timestamp(), + d["process_begin_at"].timestamp(), "run": status} if prg != 0: info["progress"] = prg diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 7783a6308..b72b1c556 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -107,8 +107,14 @@ def get_minio_binary(bucket, name): global MINIO if REDIS_CONN.is_alive(): try: + for _ in range(30): + if REDIS_CONN.exist("{}/{}".format(bucket, name)): + time.sleep(1) + break + time.sleep(1) r = REDIS_CONN.get("{}/{}".format(bucket, name)) if r: return r + cron_logger.warning("Cache missing: {}".format(name)) except Exception as e: cron_logger.warning("Get redis[EXCEPTION]:" + str(e)) return MINIO.get(bucket, name) diff --git a/rag/utils/minio_conn.py b/rag/utils/minio_conn.py index d5b092779..a1a4e45a8 100644 --- a/rag/utils/minio_conn.py +++ b/rag/utils/minio_conn.py @@ -56,7 +56,6 @@ class HuMinio(object): except Exception as e: minio_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e)) - def get(self, bucket, fnm): for _ in range(1): try: diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py index 0fe30202d..5884d52aa 100644 --- a/rag/utils/redis_conn.py +++ b/rag/utils/redis_conn.py @@ -25,6 +25,14 @@ class RedisDB: def is_alive(self): return self.REDIS is not None + def exist(self, k): + if not self.REDIS: return + try: + return self.REDIS.exists(k) + except Exception as e: + logging.warning("[EXCEPTION]exist" + str(k) + "||" + str(e)) + self.__open__() + def get(self, k): if not self.REDIS: return try: @@ -51,5 +59,16 @@ class RedisDB: self.__open__() return False + def transaction(self, key, value, exp=3600): + try: + pipeline = self.REDIS.pipeline(transaction=True) + pipeline.set(key, value, exp, nx=True) + pipeline.execute() + return True + except Exception as e: + logging.warning("[EXCEPTION]set" + str(key) + "||" + str(e)) + self.__open__() + return False + REDIS_CONN = RedisDB() \ No newline at end of file