refine admin initialization (#75)

This commit is contained in:
KevinHuSh 2024-02-27 14:57:34 +08:00 committed by GitHub
parent d1c600d5d3
commit 4568a4b2cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 91 additions and 34 deletions

View File

@ -20,7 +20,7 @@ from flask_login import login_required, current_user
from elasticsearch_dsl import Q from elasticsearch_dsl import Q
from rag.app.qa import rmPrefix, beAdoc from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import search, huqie, retrievaler from rag.nlp import search, huqie
from rag.utils import ELASTICSEARCH, rmSpace from rag.utils import ELASTICSEARCH, rmSpace
from api.db import LLMType, ParserType from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
@ -28,7 +28,7 @@ from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.settings import RetCode from api.settings import RetCode, retrievaler
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
import hashlib import hashlib
import re import re

View File

@ -21,13 +21,11 @@ from api.db.services.dialog_service import DialogService, ConversationService
from api.db import LLMType from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, LLMBundle from api.db.services.llm_service import LLMService, LLMBundle
from api.settings import access_logger, stat_logger from api.settings import access_logger, stat_logger, retrievaler
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
from rag.app.resume import forbidden_select_fields4resume from rag.app.resume import forbidden_select_fields4resume
from rag.llm import ChatModel
from rag.nlp import retrievaler
from rag.nlp.search import index_name from rag.nlp.search import index_name
from rag.utils import num_tokens_from_string, encoder, rmSpace from rag.utils import num_tokens_from_string, encoder, rmSpace

View File

@ -16,10 +16,12 @@
import time import time
import uuid import uuid
from api.db import LLMType from api.db import LLMType, UserTenantRole
from api.db.db_models import init_database_tables as init_web_db from api.db.db_models import init_database_tables as init_web_db
from api.db.services import UserService from api.db.services import UserService
from api.db.services.llm_service import LLMFactoriesService, LLMService from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
from api.db.services.user_service import TenantService, UserTenantService
from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY
def init_superuser(): def init_superuser():
@ -32,8 +34,44 @@ def init_superuser():
"creator": "system", "creator": "system",
"status": "1", "status": "1",
} }
tenant = {
"id": user_info["id"],
"name": user_info["nickname"] + "s Kingdom",
"llm_id": CHAT_MDL,
"embd_id": EMBEDDING_MDL,
"asr_id": ASR_MDL,
"parser_ids": PARSERS,
"img2txt_id": IMAGE2TEXT_MDL
}
usr_tenant = {
"tenant_id": user_info["id"],
"user_id": user_info["id"],
"invited_by": user_info["id"],
"role": UserTenantRole.OWNER
}
tenant_llm = []
for llm in LLMService.query(fid=LLM_FACTORY):
tenant_llm.append(
{"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
"api_key": API_KEY})
if not UserService.save(**user_info):
print("【ERROR】can't init admin.")
return
TenantService.save(**tenant)
UserTenantService.save(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)
UserService.save(**user_info) UserService.save(**user_info)
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
if msg.find("ERROR: ") == 0:
print("【ERROR】: '{}' dosen't work. {}".format(tenant["llm_id"]), msg)
embd_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["embd_id"])
v,c = embd_mdl.encode(["Hello!"])
if c == 0:
print("【ERROR】: '{}' dosen't work...".format(tenant["embd_id"]))
def init_llm_factory(): def init_llm_factory():
factory_infos = [{ factory_infos = [{
@ -171,10 +209,10 @@ def init_llm_factory():
def init_web_data(): def init_web_data():
start_time = time.time() start_time = time.time()
if not UserService.get_all().count():
init_superuser()
if not LLMService.get_all().count():init_llm_factory() if not LLMService.get_all().count():init_llm_factory()
if not UserService.get_all().count():
init_superuser()
print("init web data success:{}".format(time.time() - start_time)) print("init web data success:{}".format(time.time() - start_time))

View File

@ -21,8 +21,10 @@ from api.utils import get_base_config,decrypt_database_config
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from api.utils.log_utils import LoggerFactory, getLogger from api.utils.log_utils import LoggerFactory, getLogger
from rag.nlp import search
from rag.utils import ELASTICSEARCH
# Server
API_VERSION = "v1" API_VERSION = "v1"
RAG_FLOW_SERVICE_NAME = "ragflow" RAG_FLOW_SERVICE_NAME = "ragflow"
SERVER_MODULE = "rag_flow_server.py" SERVER_MODULE = "rag_flow_server.py"
@ -116,6 +118,8 @@ AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60 # s
PRIVILEGE_COMMAND_WHITELIST = [] PRIVILEGE_COMMAND_WHITELIST = []
CHECK_NODES_IDENTITY = False CHECK_NODES_IDENTITY = False
retrievaler = search.Dealer(ELASTICSEARCH)
class CustomEnum(Enum): class CustomEnum(Enum):
@classmethod @classmethod
def valid(cls, value): def valid(cls, value):

View File

@ -230,7 +230,7 @@ class HuParser:
b["H_right"] = headers[ii]["x1"] b["H_right"] = headers[ii]["x1"]
b["H"] = ii b["H"] = ii
ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3) ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
if ii is not None: if ii is not None:
b["C"] = ii b["C"] = ii
b["C_left"] = clmns[ii]["x0"] b["C_left"] = clmns[ii]["x0"]

View File

@ -37,7 +37,7 @@ class LayoutRecognizer(Recognizer):
super().__init__(self.labels, domain, super().__init__(self.labels, domain,
os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16): def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
def __is_garbage(b): def __is_garbage(b):
patt = [r"^•+$", r"(版权归©|免责条款|地址[:])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$", patt = [r"^•+$", r"(版权归©|免责条款|地址[:])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",

View File

@ -2,7 +2,6 @@ import copy
import numpy as np import numpy as np
import cv2 import cv2
import paddle
from shapely.geometry import Polygon from shapely.geometry import Polygon
import pyclipper import pyclipper
@ -215,7 +214,7 @@ class DBPostProcess(object):
def __call__(self, outs_dict, shape_list): def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps'] pred = outs_dict['maps']
if isinstance(pred, paddle.Tensor): if not isinstance(pred, np.ndarray):
pred = pred.numpy() pred = pred.numpy()
pred = pred[:, 0, :, :] pred = pred[:, 0, :, :]
segmentation = pred > self.thresh segmentation = pred > self.thresh
@ -339,7 +338,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple) or isinstance(preds, list): if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1] preds = preds[-1]
if isinstance(preds, paddle.Tensor): if not isinstance(preds, np.ndarray):
preds = preds.numpy() preds = preds.numpy()
preds_idx = preds.argmax(axis=2) preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2) preds_prob = preds.max(axis=2)

View File

@ -259,6 +259,18 @@ class Recognizer(object):
return max_overlaped_i return max_overlaped_i
@staticmethod
def find_horizontally_tightest_fit(box, boxes):
    """Return the index of the entry in ``boxes`` horizontally closest to ``box``.

    The distance between ``box`` and a candidate is the smallest of:
    the gap between their left edges (``x0``), the gap between their
    right edges (``x1``), and the gap between their horizontal centers.

    Args:
        box: mapping with numeric ``"x0"`` and ``"x1"`` entries.
        boxes: sequence of mappings with the same two entries.

    Returns:
        The index of the closest candidate (the first one on ties), or
        ``None`` when ``boxes`` is empty.
    """
    if not boxes:
        return None

    # Start from +inf rather than a magic pixel bound so that a non-empty
    # candidate list with finite coordinates always yields an index.
    best_i, best_dis = None, float("inf")
    for i, b in enumerate(boxes):
        dis = min(
            abs(box["x0"] - b["x0"]),  # left-edge distance
            abs(box["x1"] - b["x1"]),  # right-edge distance
            abs(box["x0"] + box["x1"] - b["x1"] - b["x0"]) / 2,  # center distance
        )
        # Strict comparison keeps the earliest candidate when distances tie.
        if dis < best_dis:
            best_i, best_dis = i, dis
    return best_i
@staticmethod @staticmethod
def find_overlapped_with_threashold(box, boxes, thr=0.3): def find_overlapped_with_threashold(box, boxes, thr=0.3):
if not boxes: if not boxes:

View File

@ -74,6 +74,7 @@ def get_table_html(img, tb_cpns, ocr):
clmns = sorted([r for r in tb_cpns if re.match( clmns = sorted([r for r in tb_cpns if re.match(
r"table column$", r["label"])], key=lambda x: x["x0"]) r"table column$", r["label"])], key=lambda x: x["x0"])
clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5) clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5)
for b in boxes: for b in boxes:
ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3) ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
if ii is not None: if ii is not None:
@ -89,7 +90,7 @@ def get_table_html(img, tb_cpns, ocr):
b["H_right"] = headers[ii]["x1"] b["H_right"] = headers[ii]["x1"]
b["H"] = ii b["H"] = ii
ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3) ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
if ii is not None: if ii is not None:
b["C"] = ii b["C"] = ii
b["C_left"] = clmns[ii]["x0"] b["C_left"] = clmns[ii]["x0"]
@ -102,6 +103,7 @@ def get_table_html(img, tb_cpns, ocr):
b["H_left"] = spans[ii]["x0"] b["H_left"] = spans[ii]["x0"]
b["H_right"] = spans[ii]["x1"] b["H_right"] = spans[ii]["x1"]
b["SP"] = ii b["SP"] = ii
html = """ html = """
<html> <html>
<head> <head>

View File

@ -14,7 +14,6 @@ import logging
import os import os
import re import re
from collections import Counter from collections import Counter
from copy import deepcopy
import numpy as np import numpy as np
@ -37,7 +36,7 @@ class TableStructureRecognizer(Recognizer):
super().__init__(self.labels, "tsr", super().__init__(self.labels, "tsr",
os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
def __call__(self, images, thr=0.5): def __call__(self, images, thr=0.2):
tbls = super().__call__(images, thr) tbls = super().__call__(images, thr)
res = [] res = []
# align left&right for rows, align top&bottom for columns # align left&right for rows, align top&bottom for columns
@ -56,8 +55,8 @@ class TableStructureRecognizer(Recognizer):
"row") > 0 or b["label"].find("header") > 0] "row") > 0 or b["label"].find("header") > 0]
if not left: if not left:
continue continue
left = np.median(left) if len(left) > 4 else np.min(left) left = np.mean(left) if len(left) > 4 else np.min(left)
right = np.median(right) if len(right) > 4 else np.max(right) right = np.mean(right) if len(right) > 4 else np.max(right)
for b in lts: for b in lts:
if b["label"].find("row") > 0 or b["label"].find("header") > 0: if b["label"].find("row") > 0 or b["label"].find("header") > 0:
if b["x0"] > left: if b["x0"] > left:
@ -129,6 +128,7 @@ class TableStructureRecognizer(Recognizer):
i = 0 i = 0
while i < len(boxes): while i < len(boxes):
if TableStructureRecognizer.is_caption(boxes[i]): if TableStructureRecognizer.is_caption(boxes[i]):
if is_english: cap + " "
cap += boxes[i]["text"] cap += boxes[i]["text"]
boxes.pop(i) boxes.pop(i)
i -= 1 i -= 1
@ -398,7 +398,7 @@ class TableStructureRecognizer(Recognizer):
for i in range(clmno): for i in range(clmno):
if not tbl[r][i]: if not tbl[r][i]:
continue continue
txt = "".join([a["text"].strip() for a in tbl[r][i]]) txt = " ".join([a["text"].strip() for a in tbl[r][i]])
headers[r][i] = txt headers[r][i] = txt
hdrset.add(txt) hdrset.add(txt)
if all([not t for t in headers[r]]): if all([not t for t in headers[r]]):

View File

@ -15,7 +15,7 @@
# #
from abc import ABC from abc import ABC
from openai import OpenAI from openai import OpenAI
import os import openai
class Base(ABC): class Base(ABC):
@ -33,11 +33,14 @@ class GptTurbo(Base):
def chat(self, system, history, gen_conf): def chat(self, system, history, gen_conf):
if system: history.insert(0, {"role": "system", "content": system}) if system: history.insert(0, {"role": "system", "content": system})
try:
res = self.client.chat.completions.create( res = self.client.chat.completions.create(
model=self.model_name, model=self.model_name,
messages=history, messages=history,
**gen_conf) **gen_conf)
return res.choices[0].message.content.strip(), res.usage.completion_tokens return res.choices[0].message.content.strip(), res.usage.completion_tokens
except openai.APIError as e:
return "ERROR: "+str(e), 0
from dashscope import Generation from dashscope import Generation
@ -58,7 +61,7 @@ class QWenChat(Base):
) )
if response.status_code == HTTPStatus.OK: if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'], response.usage.output_tokens return response.output.choices[0]['message']['content'], response.usage.output_tokens
return response.message, 0 return "ERROR: " + response.message, 0
from zhipuai import ZhipuAI from zhipuai import ZhipuAI
@ -77,4 +80,4 @@ class ZhipuChat(Base):
) )
if response.status_code == HTTPStatus.OK: if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'], response.usage.completion_tokens return response.output.choices[0]['message']['content'], response.usage.completion_tokens
return response.message, 0 return "ERROR: " + response.message, 0

View File

@ -1,7 +1,4 @@
from . import search
from rag.utils import ELASTICSEARCH
retrievaler = search.Dealer(ELASTICSEARCH)
from nltk.stem import PorterStemmer from nltk.stem import PorterStemmer
stemmer = PorterStemmer() stemmer = PorterStemmer()
@ -39,10 +36,12 @@ BULLET_PATTERN = [[
] ]
] ]
def random_choices(arr, k): def random_choices(arr, k):
k = min(len(arr), k) k = min(len(arr), k)
return random.choices(arr, k=k) return random.choices(arr, k=k)
def bullets_category(sections): def bullets_category(sections):
global BULLET_PATTERN global BULLET_PATTERN
hits = [0] * len(BULLET_PATTERN) hits = [0] * len(BULLET_PATTERN)

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import json import json
import re import re
from elasticsearch_dsl import Q, Search, A from elasticsearch_dsl import Q, Search
from typing import List, Optional, Dict, Union from typing import List, Optional, Dict, Union
from dataclasses import dataclass from dataclasses import dataclass
@ -183,6 +183,7 @@ class Dealer:
def insert_citations(self, answer, chunks, chunk_v, def insert_citations(self, answer, chunks, chunk_v,
embd_mdl, tkweight=0.3, vtweight=0.7): embd_mdl, tkweight=0.3, vtweight=0.7):
assert len(chunks) == len(chunk_v)
pieces = re.split(r"([;。?!\n]|[a-z][.?;!][ \n])", answer) pieces = re.split(r"([;。?!\n]|[a-z][.?;!][ \n])", answer)
for i in range(1, len(pieces)): for i in range(1, len(pieces)):
if re.match(r"[a-z][.?;!][ \n]", pieces[i]): if re.match(r"[a-z][.?;!][ \n]", pieces[i]):
@ -216,7 +217,7 @@ class Dealer:
if mx < 0.55: if mx < 0.55:
continue continue
cites[idx[i]] = list( cites[idx[i]] = list(
set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4] set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
res = "" res = ""
for i, p in enumerate(pieces): for i, p in enumerate(pieces):
@ -225,6 +226,7 @@ class Dealer:
continue continue
if i not in cites: if i not in cites:
continue continue
assert int(cites[i]) < len(chunk_v)
res += "##%s$$" % "$".join(cites[i]) res += "##%s$$" % "$".join(cites[i])
return res return res