### What problem does this PR solve?


### Type of change

- [x] Refactoring
This commit is contained in:
KevinHuSh 2024-06-12 11:02:15 +08:00 committed by GitHub
parent 2cc89211f6
commit abcd3d2469
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 73 additions and 83 deletions

View File

@ -85,7 +85,6 @@ def register_page(page_path):
url_prefix = f'/api/{API_VERSION}/{page_name}' if "_api" in path else f'/{API_VERSION}/{page_name}' url_prefix = f'/api/{API_VERSION}/{page_name}' if "_api" in path else f'/{API_VERSION}/{page_name}'
app.register_blueprint(page.manager, url_prefix=url_prefix) app.register_blueprint(page.manager, url_prefix=url_prefix)
print(f'API file: {page_path}, URL: {url_prefix}')
return url_prefix return url_prefix

View File

@ -40,6 +40,7 @@ from api.utils.api_utils import get_json_result
from rag.utils.minio_conn import MINIO from rag.utils.minio_conn import MINIO
from api.utils.file_utils import filename_type, thumbnail from api.utils.file_utils import filename_type, thumbnail
from api.utils.web_utils import html2pdf, is_valid_url from api.utils.web_utils import html2pdf, is_valid_url
from api.utils.web_utils import html2pdf, is_valid_url
@manager.route('/upload', methods=['POST']) @manager.route('/upload', methods=['POST'])
@ -117,6 +118,68 @@ def upload():
return get_json_result(data=True) return get_json_result(data=True)
@manager.route('/web_crawl', methods=['POST'])
@login_required
@validate_request("kb_id", "name", "url")
def web_crawl():
    """Store the PDF produced from a URL as a new document in a knowledgebase.

    Form fields read from the request:
        kb_id: id of the target knowledgebase.
        name:  base name of the document; ".pdf" is appended to it.
        url:   address handed to ``html2pdf`` after ``is_valid_url`` accepts it.

    Returns:
        ``get_json_result(data=True)`` on success. Argument problems are
        reported through ``get_json_result`` with ``RetCode.ARGUMENT_ERROR``;
        every other failure is reported through ``server_error_response``.
    """
    kb_id = request.form.get("kb_id")
    if not kb_id:
        return get_json_result(
            data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
    name = request.form.get("name")
    url = request.form.get("url")
    if not is_valid_url(url):
        return get_json_result(
            data=False, retmsg='The URL format is invalid', retcode=RetCode.ARGUMENT_ERROR)

    found, kb = KnowledgebaseService.get_by_id(kb_id)
    if not found:
        # Report the failure the same way as every other error path in this
        # handler instead of letting the exception escape the view function.
        return server_error_response(LookupError("Can't find this knowledgebase!"))

    # Fetch the content before any folder is created, so a failed download
    # leaves nothing behind; an exception here is reported, not propagated.
    try:
        blob = html2pdf(url)
    except Exception as e:
        return server_error_response(e)
    if not blob:
        return server_error_response(ValueError("Download failure."))

    root_folder = FileService.get_root_folder(current_user.id)
    pf_id = root_folder["id"]
    FileService.init_knowledgebase_docs(pf_id, current_user.id)
    kb_root_folder = FileService.get_kb_folder(current_user.id)
    kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])

    try:
        filename = duplicate_name(
            DocumentService.query,
            name=name + ".pdf",
            kb_id=kb.id)
        filetype = filename_type(filename)
        if filetype == FileType.OTHER.value:
            raise RuntimeError("This type of file has not been supported yet!")

        # Append "_" until the object name is unused in the storage bucket.
        location = filename
        while MINIO.obj_exist(kb_id, location):
            location += "_"
        MINIO.put(kb_id, location, blob)

        doc = {
            "id": get_uuid(),
            "kb_id": kb.id,
            "parser_id": kb.parser_id,
            "parser_config": kb.parser_config,
            "created_by": current_user.id,
            "type": filetype,
            "name": filename,
            "location": location,
            "size": len(blob),
            "thumbnail": thumbnail(filename, blob)
        }
        # Override the knowledgebase's default parser for special file kinds.
        if doc["type"] == FileType.VISUAL:
            doc["parser_id"] = ParserType.PICTURE.value
        if re.search(r"\.(ppt|pptx|pages)$", filename):
            doc["parser_id"] = ParserType.PRESENTATION.value
        DocumentService.insert(doc)
        FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
    except Exception as e:
        return server_error_response(e)
    return get_json_result(data=True)
@manager.route('/create', methods=['POST']) @manager.route('/create', methods=['POST'])
@login_required @login_required
@validate_request("name", "kb_id") @validate_request("name", "kb_id")
@ -417,69 +480,3 @@ def get_image(image_id):
return response return response
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/web_crawl', methods=['POST'])
@login_required
def web_crawl():
    """Store the PDF produced from a URL as a new document in a knowledgebase.

    Form fields read from the request:
        kb_id: id of the target knowledgebase.
        name:  base name of the document; ".pdf" is appended to it.
        url:   address handed to ``html2pdf`` after ``is_valid_url`` accepts it.

    Returns:
        ``get_json_result(data=True)`` on success, otherwise a
        ``get_json_result(data=False, ...)`` response carrying an error
        message and a ``RetCode``.
    """
    kb_id = request.form.get("kb_id")
    if not kb_id:
        return get_json_result(
            data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
    name = request.form.get("name")
    url = request.form.get("url")
    if not name:
        return get_json_result(
            data=False, retmsg='Lack of "name"', retcode=RetCode.ARGUMENT_ERROR)
    if not url:
        return get_json_result(
            data=False, retmsg='Lack of "url"', retcode=RetCode.ARGUMENT_ERROR)
    if not is_valid_url(url):
        return get_json_result(
            data=False, retmsg='The URL format is invalid', retcode=RetCode.ARGUMENT_ERROR)

    found, kb = KnowledgebaseService.get_by_id(kb_id)
    if not found:
        # Answer with a JSON error like every other failure in this handler
        # instead of raising out of the view function.
        return get_json_result(
            data=False, retmsg="Can't find this knowledgebase!", retcode=RetCode.DATA_ERROR)

    root_folder = FileService.get_root_folder(current_user.id)
    pf_id = root_folder["id"]
    FileService.init_knowledgebase_docs(pf_id, current_user.id)
    kb_root_folder = FileService.get_kb_folder(current_user.id)
    kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])

    try:
        filename = duplicate_name(
            DocumentService.query,
            name=name + ".pdf",
            kb_id=kb.id)
        filetype = filename_type(filename)
        if filetype == FileType.OTHER.value:
            raise RuntimeError("This type of file has not been supported yet!")

        # Append "_" until the object name is unused in the storage bucket.
        location = filename
        while MINIO.obj_exist(kb_id, location):
            location += "_"
        blob = html2pdf(url)
        MINIO.put(kb_id, location, blob)

        doc = {
            "id": get_uuid(),
            "kb_id": kb.id,
            "parser_id": kb.parser_id,
            "parser_config": kb.parser_config,
            "created_by": current_user.id,
            "type": filetype,
            "name": filename,
            "location": location,
            "size": len(blob),
            "thumbnail": thumbnail(filename, blob)
        }
        # Override the knowledgebase's default parser for special file kinds.
        if doc["type"] == FileType.VISUAL:
            doc["parser_id"] = ParserType.PICTURE.value
        if re.search(r"\.(ppt|pptx|pages)$", filename):
            doc["parser_id"] = ParserType.PRESENTATION.value
        DocumentService.insert(doc)
        FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
    except Exception as e:
        # The message must be a string: the exception object itself cannot
        # be placed in a JSON response body.
        return get_json_result(
            data=False, retmsg=str(e), retcode=RetCode.SERVER_ERROR)
    return get_json_result(data=True)

View File

@ -112,14 +112,15 @@ def chat(dialog, messages, stream=True, **kwargs):
prompt_config["system"] = prompt_config["system"].replace( prompt_config["system"] = prompt_config["system"].replace(
"{%s}" % p["key"], " ") "{%s}" % p["key"], " ")
rerank_mdl = None
if dialog.rerank_id:
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
for _ in range(len(questions) // 2): for _ in range(len(questions) // 2):
questions.append(questions[-1]) questions.append(questions[-1])
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]: if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
else: else:
rerank_mdl = None
if dialog.rerank_id:
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold, dialog.similarity_threshold,
dialog.vector_similarity_weight, dialog.vector_similarity_weight,

View File

@ -248,11 +248,12 @@ def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
def construct_json_result(code=RetCode.SUCCESS, message='success', data=None):
    """Build a JSON response carrying a status code, a message and optional data.

    Args:
        code: status code placed under the "code" key.
        message: text placed under the "message" key.
        data: payload placed under the "data" key; the key is omitted
            entirely when ``data`` is ``None``.

    Returns:
        The value returned by ``jsonify`` for the assembled dict.
    """
    payload = {"code": code, "message": message}
    # Identity test: "== None" would invoke a custom __eq__ on the payload.
    if data is not None:
        payload["data"] = data
    return jsonify(payload)
def construct_error_response(e): def construct_error_response(e):
stat_logger.exception(e) stat_logger.exception(e)
try: try:

View File

@ -154,11 +154,6 @@ class LoggerFactory(object):
delay=True) delay=True)
if level: if level:
handler.level = level handler.level = level
else:
handler.level = LoggerFactory.LEVEL
formatter = logging.Formatter(LoggerFactory.LOG_FORMAT)
handler.setFormatter(formatter)
return handler return handler

View File

@ -78,5 +78,3 @@ def __get_pdf_from_html(
def is_valid_url(url: str) -> bool:
    """Return True when ``url`` starts with a supported scheme and URL-like text.

    Accepted schemes are http, https, ftp and file. The pattern is matched
    with ``re.match``, so it is anchored at the start of the string only;
    text after the matched prefix is not inspected.

    Args:
        url: candidate URL. Any non-string value (for example ``None`` from
            a missing form field) is reported as invalid instead of raising.

    Returns:
        True if the beginning of ``url`` matches the pattern, else False.
    """
    if not isinstance(url, str):
        return False
    # NOTE(review): "file://" is accepted here; confirm that is intended for
    # callers that go on to fetch the URL.
    return bool(re.match(r"(https?|ftp|file)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]", url))

View File

@ -26,9 +26,8 @@ import dashscope
from openai import OpenAI from openai import OpenAI
from FlagEmbedding import FlagModel from FlagEmbedding import FlagModel
import torch import torch
import asyncio
import numpy as np import numpy as np
import asyncio
from api.utils.file_utils import get_home_cache_dir from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate from rag.utils import num_tokens_from_string, truncate
@ -317,12 +316,12 @@ class InfinityEmbed(Base):
engine_kwargs: dict = {}, engine_kwargs: dict = {},
key = None, key = None,
): ):
from infinity_emb import EngineArgs from infinity_emb import EngineArgs
from infinity_emb.engine import AsyncEngineArray from infinity_emb.engine import AsyncEngineArray
self._default_model = model_names[0] self._default_model = model_names[0]
self.engine_array = AsyncEngineArray.from_args([EngineArgs(model_name_or_path = model_name, **engine_kwargs) for model_name in model_names]) self.engine_array = AsyncEngineArray.from_args([EngineArgs(model_name_or_path = model_name, **engine_kwargs) for model_name in model_names])
async def _embed(self, sentences: list[str], model_name: str = ""): async def _embed(self, sentences: list[str], model_name: str = ""):
if not model_name: if not model_name: