mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-06-04 11:24:00 +08:00
make sure the models will not be load twice (#422)
### What problem does this PR solve? #381 ### Type of change - [x] Refactoring
This commit is contained in:
parent
e8570da856
commit
453c29170f
@ -105,8 +105,8 @@ def stats():
|
|||||||
res = {
|
res = {
|
||||||
"pv": [(o["dt"], o["pv"]) for o in objs],
|
"pv": [(o["dt"], o["pv"]) for o in objs],
|
||||||
"uv": [(o["dt"], o["uv"]) for o in objs],
|
"uv": [(o["dt"], o["uv"]) for o in objs],
|
||||||
"speed": [(o["dt"], o["tokens"]/o["duration"]) for o in objs],
|
"speed": [(o["dt"], float(o["tokens"])/float(o["duration"])) for o in objs],
|
||||||
"tokens": [(o["dt"], o["tokens"]/1000.) for o in objs],
|
"tokens": [(o["dt"], float(o["tokens"])/1000.) for o in objs],
|
||||||
"round": [(o["dt"], o["round"]) for o in objs],
|
"round": [(o["dt"], o["round"]) for o in objs],
|
||||||
"thumb_up": [(o["dt"], o["thumb_up"]) for o in objs]
|
"thumb_up": [(o["dt"], o["thumb_up"]) for o in objs]
|
||||||
}
|
}
|
||||||
@ -115,8 +115,7 @@ def stats():
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/new_conversation', methods=['POST'])
|
@manager.route('/new_conversation', methods=['GET'])
|
||||||
@validate_request("user_id")
|
|
||||||
def set_conversation():
|
def set_conversation():
|
||||||
token = request.headers.get('Authorization').split()[1]
|
token = request.headers.get('Authorization').split()[1]
|
||||||
objs = APIToken.query(token=token)
|
objs = APIToken.query(token=token)
|
||||||
@ -131,7 +130,7 @@ def set_conversation():
|
|||||||
conv = {
|
conv = {
|
||||||
"id": get_uuid(),
|
"id": get_uuid(),
|
||||||
"dialog_id": dia.id,
|
"dialog_id": dia.id,
|
||||||
"user_id": req["user_id"],
|
"user_id": request.args.get("user_id", ""),
|
||||||
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
|
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
|
||||||
}
|
}
|
||||||
API4ConversationService.save(**conv)
|
API4ConversationService.save(**conv)
|
||||||
|
@ -629,7 +629,7 @@ class Document(DataBaseModel):
|
|||||||
max_length=128,
|
max_length=128,
|
||||||
null=False,
|
null=False,
|
||||||
default="local",
|
default="local",
|
||||||
help_text="where dose this document from")
|
help_text="where dose this document come from")
|
||||||
type = CharField(max_length=32, null=False, help_text="file extension")
|
type = CharField(max_length=32, null=False, help_text="file extension")
|
||||||
created_by = CharField(
|
created_by = CharField(
|
||||||
max_length=32,
|
max_length=32,
|
||||||
|
@ -43,7 +43,9 @@ class HuParser:
|
|||||||
model_dir, "updown_concat_xgb.model"))
|
model_dir, "updown_concat_xgb.model"))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
model_dir = snapshot_download(
|
model_dir = snapshot_download(
|
||||||
repo_id="InfiniFlow/text_concat_xgb_v1.0")
|
repo_id="InfiniFlow/text_concat_xgb_v1.0",
|
||||||
|
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
||||||
|
local_dir_use_symlinks=False)
|
||||||
self.updown_cnt_mdl.load_model(os.path.join(
|
self.updown_cnt_mdl.load_model(os.path.join(
|
||||||
model_dir, "updown_concat_xgb.model"))
|
model_dir, "updown_concat_xgb.model"))
|
||||||
|
|
||||||
|
@ -43,7 +43,9 @@ class LayoutRecognizer(Recognizer):
|
|||||||
"rag/res/deepdoc")
|
"rag/res/deepdoc")
|
||||||
super().__init__(self.labels, domain, model_dir)
|
super().__init__(self.labels, domain, model_dir)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
|
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
|
||||||
|
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
||||||
|
local_dir_use_symlinks=False)
|
||||||
super().__init__(self.labels, domain, model_dir)
|
super().__init__(self.labels, domain, model_dir)
|
||||||
|
|
||||||
self.garbage_layouts = ["footer", "header", "reference"]
|
self.garbage_layouts = ["footer", "header", "reference"]
|
||||||
|
@ -486,7 +486,9 @@ class OCR(object):
|
|||||||
self.text_detector = TextDetector(model_dir)
|
self.text_detector = TextDetector(model_dir)
|
||||||
self.text_recognizer = TextRecognizer(model_dir)
|
self.text_recognizer = TextRecognizer(model_dir)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
|
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
|
||||||
|
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
||||||
|
local_dir_use_symlinks=False)
|
||||||
self.text_detector = TextDetector(model_dir)
|
self.text_detector = TextDetector(model_dir)
|
||||||
self.text_recognizer = TextRecognizer(model_dir)
|
self.text_recognizer = TextRecognizer(model_dir)
|
||||||
|
|
||||||
|
@ -41,7 +41,9 @@ class Recognizer(object):
|
|||||||
"rag/res/deepdoc")
|
"rag/res/deepdoc")
|
||||||
model_file_path = os.path.join(model_dir, task_name + ".onnx")
|
model_file_path = os.path.join(model_dir, task_name + ".onnx")
|
||||||
if not os.path.exists(model_file_path):
|
if not os.path.exists(model_file_path):
|
||||||
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
|
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
|
||||||
|
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
||||||
|
local_dir_use_symlinks=False)
|
||||||
model_file_path = os.path.join(model_dir, task_name + ".onnx")
|
model_file_path = os.path.join(model_dir, task_name + ".onnx")
|
||||||
else:
|
else:
|
||||||
model_file_path = os.path.join(model_dir, task_name + ".onnx")
|
model_file_path = os.path.join(model_dir, task_name + ".onnx")
|
||||||
|
@ -39,7 +39,9 @@ class TableStructureRecognizer(Recognizer):
|
|||||||
get_project_base_directory(),
|
get_project_base_directory(),
|
||||||
"rag/res/deepdoc"))
|
"rag/res/deepdoc"))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
super().__init__(self.labels, "tsr", snapshot_download(repo_id="InfiniFlow/deepdoc"))
|
super().__init__(self.labels, "tsr", snapshot_download(repo_id="InfiniFlow/deepdoc",
|
||||||
|
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
||||||
|
local_dir_use_symlinks=False))
|
||||||
|
|
||||||
def __call__(self, images, thr=0.2):
|
def __call__(self, images, thr=0.2):
|
||||||
tbls = super().__call__(images, thr)
|
tbls = super().__call__(images, thr)
|
||||||
|
@ -14,6 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
from zhipuai import ZhipuAI
|
from zhipuai import ZhipuAI
|
||||||
import os
|
import os
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
@ -35,7 +37,10 @@ try:
|
|||||||
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
||||||
use_fp16=torch.cuda.is_available())
|
use_fp16=torch.cuda.is_available())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
flag_model = FlagModel("BAAI/bge-large-zh-v1.5",
|
model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
|
||||||
|
local_dir=os.path.join(get_project_base_directory(), "rag/res/bge-large-zh-v1.5"),
|
||||||
|
local_dir_use_symlinks=False)
|
||||||
|
flag_model = FlagModel(model_dir,
|
||||||
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
||||||
use_fp16=torch.cuda.is_available())
|
use_fp16=torch.cuda.is_available())
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user