mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-23 22:50:17 +08:00
add alot of api (#23)
* clean rust version project * clean rust version project * build python version rag-flow * add alot of api
This commit is contained in:
parent
30791976d5
commit
3198faf2d2
@ -35,7 +35,7 @@ class Base(ABC):
|
||||
|
||||
|
||||
class HuEmbedding(Base):
|
||||
def __init__(self):
|
||||
def __init__(self, key="", model_name=""):
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
||||
|
@ -411,8 +411,11 @@ class TextChunker(HuChunker):
|
||||
flds = self.Fields()
|
||||
if self.is_binary_file(fnm):
|
||||
return flds
|
||||
txt = ""
|
||||
if isinstance(fnm, str):
|
||||
with open(fnm, "r") as f:
|
||||
txt = f.read()
|
||||
else: txt = fnm.decode("utf-8")
|
||||
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
|
||||
flds.table_chunks = []
|
||||
return flds
|
||||
|
@ -8,7 +8,7 @@ from rag.nlp import huqie, query
|
||||
import numpy as np
|
||||
|
||||
|
||||
def index_name(uid): return f"docgpt_{uid}"
|
||||
def index_name(uid): return f"ragflow_{uid}"
|
||||
|
||||
|
||||
class Dealer:
|
||||
|
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import hashlib
|
||||
import copy
|
||||
@ -24,9 +25,10 @@ from timeit import default_timer as timer
|
||||
|
||||
from rag.llm import EmbeddingModel, CvModel
|
||||
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
||||
from rag.utils import ELASTICSEARCH, num_tokens_from_string
|
||||
from rag.utils import ELASTICSEARCH
|
||||
from rag.utils import MINIO
|
||||
from rag.utils import rmSpace, findMaxDt
|
||||
from rag.utils import rmSpace, findMaxTm
|
||||
|
||||
from rag.nlp import huchunk, huqie, search
|
||||
from io import BytesIO
|
||||
import pandas as pd
|
||||
@ -47,6 +49,7 @@ from rag.nlp.huchunk import (
|
||||
from web_server.db import LLMType
|
||||
from web_server.db.services.document_service import DocumentService
|
||||
from web_server.db.services.llm_service import TenantLLMService
|
||||
from web_server.settings import database_logger
|
||||
from web_server.utils import get_format_time
|
||||
from web_server.utils.file_utils import get_project_base_directory
|
||||
|
||||
@ -83,7 +86,7 @@ def collect(comm, mod, tm):
|
||||
if len(docs) == 0:
|
||||
return pd.DataFrame()
|
||||
docs = pd.DataFrame(docs)
|
||||
mtm = str(docs["update_time"].max())[:19]
|
||||
mtm = docs["update_time"].max()
|
||||
cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
|
||||
return docs
|
||||
|
||||
@ -99,11 +102,12 @@ def set_progress(docid, prog, msg="Processing...", begin=False):
|
||||
cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
|
||||
|
||||
|
||||
def build(row):
|
||||
def build(row, cvmdl):
|
||||
if row["size"] > DOC_MAXIMUM_SIZE:
|
||||
set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
|
||||
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
|
||||
return []
|
||||
|
||||
res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
|
||||
if ELASTICSEARCH.getTotal(res) > 0:
|
||||
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
|
||||
@ -120,7 +124,8 @@ def build(row):
|
||||
set_progress(row["id"], random.randint(0, 20) /
|
||||
100., "Finished preparing! Start to slice file!", True)
|
||||
try:
|
||||
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]))
|
||||
cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
|
||||
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl)
|
||||
except Exception as e:
|
||||
if re.search("(No such file|not found)", str(e)):
|
||||
set_progress(
|
||||
@ -131,6 +136,9 @@ def build(row):
|
||||
row["id"], -1, f"Internal server error: %s" %
|
||||
str(e).replace(
|
||||
"'", ""))
|
||||
|
||||
cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))
|
||||
|
||||
return []
|
||||
|
||||
if not obj.text_chunks and not obj.table_chunks:
|
||||
@ -144,7 +152,7 @@ def build(row):
|
||||
"Finished slicing files. Start to embedding the content.")
|
||||
|
||||
doc = {
|
||||
"doc_id": row["did"],
|
||||
"doc_id": row["id"],
|
||||
"kb_id": [str(row["kb_id"])],
|
||||
"docnm_kwd": os.path.split(row["location"])[-1],
|
||||
"title_tks": huqie.qie(row["name"]),
|
||||
@ -164,10 +172,10 @@ def build(row):
|
||||
docs.append(d)
|
||||
continue
|
||||
|
||||
if isinstance(img, Image):
|
||||
img.save(output_buffer, format='JPEG')
|
||||
else:
|
||||
if isinstance(img, bytes):
|
||||
output_buffer = BytesIO(img)
|
||||
else:
|
||||
img.save(output_buffer, format='JPEG')
|
||||
|
||||
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
|
||||
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
|
||||
@ -215,15 +223,16 @@ def embedding(docs, mdl):
|
||||
|
||||
|
||||
def model_instance(tenant_id, llm_type):
|
||||
model_config = TenantLLMService.query(tenant_id=tenant_id, model_type=LLMType.EMBEDDING)
|
||||
if not model_config:return
|
||||
model_config = model_config[0]
|
||||
model_config = TenantLLMService.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
|
||||
if not model_config:
|
||||
model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
|
||||
else: model_config = model_config[0].to_dict()
|
||||
if llm_type == LLMType.EMBEDDING:
|
||||
if model_config.llm_factory not in EmbeddingModel: return
|
||||
return EmbeddingModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
|
||||
if model_config["llm_factory"] not in EmbeddingModel: return
|
||||
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
|
||||
if llm_type == LLMType.IMAGE2TEXT:
|
||||
if model_config.llm_factory not in CvModel: return
|
||||
return CvModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
|
||||
if model_config["llm_factory"] not in CvModel: return
|
||||
return CvModel[model_config.llm_factory](model_config["api_key"], model_config["llm_name"])
|
||||
|
||||
|
||||
def main(comm, mod):
|
||||
@ -231,7 +240,7 @@ def main(comm, mod):
|
||||
from rag.llm import HuEmbedding
|
||||
model = HuEmbedding()
|
||||
tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
|
||||
tm = findMaxDt(tm_fnm)
|
||||
tm = findMaxTm(tm_fnm)
|
||||
rows = collect(comm, mod, tm)
|
||||
if len(rows) == 0:
|
||||
return
|
||||
@ -247,7 +256,7 @@ def main(comm, mod):
|
||||
st_tm = timer()
|
||||
cks = build(r, cv_mdl)
|
||||
if not cks:
|
||||
tmf.write(str(r["updated_at"]) + "\n")
|
||||
tmf.write(str(r["update_time"]) + "\n")
|
||||
continue
|
||||
# TODO: exception handler
|
||||
## set_progress(r["did"], -1, "ERROR: ")
|
||||
@ -268,12 +277,19 @@ def main(comm, mod):
|
||||
cron_logger.error(str(es_r))
|
||||
else:
|
||||
set_progress(r["id"], 1., "Done!")
|
||||
DocumentService.update_by_id(r["id"], {"token_num": tk_count, "chunk_num": len(cks), "process_duation": timer()-st_tm})
|
||||
DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer()-st_tm)
|
||||
cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))
|
||||
|
||||
tmf.write(str(r["update_time"]) + "\n")
|
||||
tmf.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
peewee_logger = logging.getLogger('peewee')
|
||||
peewee_logger.propagate = False
|
||||
peewee_logger.addHandler(database_logger.handlers[0])
|
||||
peewee_logger.setLevel(database_logger.level)
|
||||
|
||||
from mpi4py import MPI
|
||||
comm = MPI.COMM_WORLD
|
||||
main(comm.Get_size(), comm.Get_rank())
|
||||
|
@ -40,6 +40,25 @@ def findMaxDt(fnm):
|
||||
print("WARNING: can't find " + fnm)
|
||||
return m
|
||||
|
||||
|
||||
def findMaxTm(fnm):
|
||||
m = 0
|
||||
try:
|
||||
with open(fnm, "r") as f:
|
||||
while True:
|
||||
l = f.readline()
|
||||
if not l:
|
||||
break
|
||||
l = l.strip("\n")
|
||||
if l == 'nan':
|
||||
continue
|
||||
if int(l) > m:
|
||||
m = int(l)
|
||||
except Exception as e:
|
||||
print("WARNING: can't find " + fnm)
|
||||
return m
|
||||
|
||||
|
||||
def num_tokens_from_string(string: str) -> int:
|
||||
"""Returns the number of tokens in a text string."""
|
||||
encoding = tiktoken.get_encoding('cl100k_base')
|
||||
|
@ -294,6 +294,7 @@ class HuEs:
|
||||
except Exception as e:
|
||||
es_logger.error("ES updateByQuery deleteByQuery: " +
|
||||
str(e) + "【Q】:" + str(query.to_dict()))
|
||||
if str(e).find("NotFoundError") > 0: return True
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
|
||||
|
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import base64
|
||||
import pathlib
|
||||
|
||||
from elasticsearch_dsl import Q
|
||||
@ -195,11 +196,15 @@ def rm():
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)):
|
||||
return get_json_result(data=False, retmsg='Remove from ES failure"', retcode=RetCode.SERVER_ERROR)
|
||||
|
||||
DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0)
|
||||
if not DocumentService.delete_by_id(req["doc_id"]):
|
||||
return get_data_error_result(
|
||||
retmsg="Database error (Document removal)!")
|
||||
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||
MINIO.rm(kb.id, doc.location)
|
||||
|
||||
MINIO.rm(doc.kb_id, doc.location)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -233,3 +238,43 @@ def rename():
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/get', methods=['GET'])
|
||||
@login_required
|
||||
def get():
|
||||
doc_id = request.args["doc_id"]
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
|
||||
blob = MINIO.get(doc.kb_id, doc.location)
|
||||
return get_json_result(data={"base64": base64.b64decode(blob)})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/change_parser', methods=['POST'])
|
||||
@login_required
|
||||
@validate_request("doc_id", "parser_id")
|
||||
def change_parser():
|
||||
req = request.json
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
if doc.parser_id.lower() == req["parser_id"].lower():
|
||||
return get_json_result(data=True)
|
||||
|
||||
e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": ""})
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, doc.process_duation*-1)
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
@ -29,7 +29,7 @@ from web_server.utils.api_utils import get_json_result
|
||||
|
||||
@manager.route('/create', methods=['post'])
|
||||
@login_required
|
||||
@validate_request("name", "description", "permission", "embd_id", "parser_id")
|
||||
@validate_request("name", "description", "permission", "parser_id")
|
||||
def create():
|
||||
req = request.json
|
||||
req["name"] = req["name"].strip()
|
||||
@ -46,7 +46,7 @@ def create():
|
||||
|
||||
@manager.route('/update', methods=['post'])
|
||||
@login_required
|
||||
@validate_request("kb_id", "name", "description", "permission", "embd_id", "parser_id")
|
||||
@validate_request("kb_id", "name", "description", "permission", "parser_id")
|
||||
def update():
|
||||
req = request.json
|
||||
req["name"] = req["name"].strip()
|
||||
@ -72,6 +72,18 @@ def update():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/detail', methods=['GET'])
|
||||
@login_required
|
||||
def detail():
|
||||
kb_id = request.args["kb_id"]
|
||||
try:
|
||||
kb = KnowledgebaseService.get_detail(kb_id)
|
||||
if not kb: return get_data_error_result(retmsg="Can't find this knowledgebase!")
|
||||
return get_json_result(data=kb)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/list', methods=['GET'])
|
||||
@login_required
|
||||
def list():
|
||||
|
95
web_server/apps/llm_app.py
Normal file
95
web_server/apps/llm_app.py
Normal file
@ -0,0 +1,95 @@
|
||||
#
|
||||
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from web_server.db.services import duplicate_name
|
||||
from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
|
||||
from web_server.db.services.user_service import TenantService, UserTenantService
|
||||
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from web_server.utils import get_uuid, get_format_time
|
||||
from web_server.db import StatusEnum, UserTenantRole
|
||||
from web_server.db.services.kb_service import KnowledgebaseService
|
||||
from web_server.db.db_models import Knowledgebase, TenantLLM
|
||||
from web_server.settings import stat_logger, RetCode
|
||||
from web_server.utils.api_utils import get_json_result
|
||||
|
||||
|
||||
@manager.route('/factories', methods=['GET'])
|
||||
@login_required
|
||||
def factories():
|
||||
try:
|
||||
fac = LLMFactoriesService.get_all()
|
||||
return get_json_result(data=fac.to_json())
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/set_api_key', methods=['POST'])
|
||||
@login_required
|
||||
@validate_request("llm_factory", "api_key")
|
||||
def set_api_key():
|
||||
req = request.json
|
||||
llm = {
|
||||
"tenant_id": current_user.id,
|
||||
"llm_factory": req["llm_factory"],
|
||||
"api_key": req["api_key"]
|
||||
}
|
||||
# TODO: Test api_key
|
||||
for n in ["model_type", "llm_name"]:
|
||||
if n in req: llm[n] = req[n]
|
||||
|
||||
TenantLLM.insert(**llm).on_conflict("replace").execute()
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/my_llms', methods=['GET'])
|
||||
@login_required
|
||||
def my_llms():
|
||||
try:
|
||||
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||
objs = [o.to_dict() for o in objs]
|
||||
for o in objs: del o["api_key"]
|
||||
return get_json_result(data=objs)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/list', methods=['GET'])
|
||||
@login_required
|
||||
def list():
|
||||
try:
|
||||
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||
objs = [o.to_dict() for o in objs if o.api_key]
|
||||
fct = {}
|
||||
for o in objs:
|
||||
if o["llm_factory"] not in fct: fct[o["llm_factory"]] = []
|
||||
if o["llm_name"]: fct[o["llm_factory"]].append(o["llm_name"])
|
||||
|
||||
llms = LLMService.get_all()
|
||||
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
|
||||
for m in llms:
|
||||
m["available"] = False
|
||||
if m["fid"] in fct and (not fct[m["fid"]] or m["llm_name"] in fct[m["fid"]]):
|
||||
m["available"] = True
|
||||
res = {}
|
||||
for m in llms:
|
||||
if m["fid"] not in res: res[m["fid"]] = []
|
||||
res[m["fid"]].append(m)
|
||||
|
||||
return get_json_result(data=res)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
@ -16,9 +16,12 @@
|
||||
from flask import request, session, redirect, url_for
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
from flask_login import login_required, current_user, login_user, logout_user
|
||||
|
||||
from web_server.db.db_models import TenantLLM
|
||||
from web_server.db.services.llm_service import TenantLLMService
|
||||
from web_server.utils.api_utils import server_error_response, validate_request
|
||||
from web_server.utils import get_uuid, get_format_time, decrypt, download_img
|
||||
from web_server.db import UserTenantRole
|
||||
from web_server.db import UserTenantRole, LLMType
|
||||
from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
|
||||
from web_server.db.services.user_service import UserService, TenantService, UserTenantService
|
||||
from web_server.settings import stat_logger
|
||||
@ -47,8 +50,9 @@ def login():
|
||||
avatar = download_img(userinfo["avatar_url"])
|
||||
except Exception as e:
|
||||
stat_logger.exception(e)
|
||||
user_id = get_uuid()
|
||||
try:
|
||||
users = user_register({
|
||||
users = user_register(user_id, {
|
||||
"access_token": session["access_token"],
|
||||
"email": userinfo["email"],
|
||||
"avatar": avatar,
|
||||
@ -63,6 +67,7 @@ def login():
|
||||
login_user(user)
|
||||
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
|
||||
except Exception as e:
|
||||
rollback_user_registration(user_id)
|
||||
stat_logger.exception(e)
|
||||
return server_error_response(e)
|
||||
elif not request.json:
|
||||
@ -162,7 +167,25 @@ def user_info():
|
||||
return get_json_result(data=current_user.to_dict())
|
||||
|
||||
|
||||
def user_register(user):
|
||||
def rollback_user_registration(user_id):
|
||||
try:
|
||||
TenantService.delete_by_id(user_id)
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
u = UserTenantService.query(tenant_id=user_id)
|
||||
if u:
|
||||
UserTenantService.delete_by_id(u[0].id)
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
TenantLLM.delete().where(TenantLLM.tenant_id==user_id).excute()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
def user_register(user_id, user):
|
||||
|
||||
user_id = get_uuid()
|
||||
user["id"] = user_id
|
||||
tenant = {
|
||||
@ -180,10 +203,12 @@ def user_register(user):
|
||||
"invited_by": user_id,
|
||||
"role": UserTenantRole.OWNER
|
||||
}
|
||||
tenant_llm = {"tenant_id": user_id, "llm_factory": "OpenAI", "api_key": "infiniflow API Key"}
|
||||
|
||||
if not UserService.save(**user):return
|
||||
TenantService.save(**tenant)
|
||||
UserTenantService.save(**usr_tenant)
|
||||
TenantLLMService.save(**tenant_llm)
|
||||
return UserService.query(email=user["email"])
|
||||
|
||||
|
||||
@ -203,14 +228,17 @@ def user_add():
|
||||
"last_login_time": get_format_time(),
|
||||
"is_superuser": False,
|
||||
}
|
||||
|
||||
user_id = get_uuid()
|
||||
try:
|
||||
users = user_register(user_dict)
|
||||
users = user_register(user_id, user_dict)
|
||||
if not users: raise Exception('Register user failure.')
|
||||
if len(users) > 1: raise Exception('Same E-mail exist!')
|
||||
user = users[0]
|
||||
login_user(user)
|
||||
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
|
||||
except Exception as e:
|
||||
rollback_user_registration(user_id)
|
||||
stat_logger.exception(e)
|
||||
return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
|
||||
|
||||
@ -220,7 +248,7 @@ def user_add():
|
||||
@login_required
|
||||
def tenant_info():
|
||||
try:
|
||||
tenants = TenantService.get_by_user_id(current_user.id)
|
||||
tenants = TenantService.get_by_user_id(current_user.id)[0]
|
||||
return get_json_result(data=tenants)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
@ -428,6 +428,7 @@ class LLMFactories(DataBaseModel):
|
||||
class LLM(DataBaseModel):
|
||||
# defautlt LLMs for every users
|
||||
llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
|
||||
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
|
||||
fid = CharField(max_length=128, null=False, help_text="LLM factory id")
|
||||
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
@ -442,8 +443,8 @@ class LLM(DataBaseModel):
|
||||
class TenantLLM(DataBaseModel):
|
||||
tenant_id = CharField(max_length=32, null=False)
|
||||
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
|
||||
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
|
||||
llm_name = CharField(max_length=128, null=False, help_text="LLM name")
|
||||
model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR")
|
||||
llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="")
|
||||
api_key = CharField(max_length=255, null=True, help_text="API KEY")
|
||||
api_base = CharField(max_length=255, null=True, help_text="API Base")
|
||||
|
||||
@ -452,7 +453,7 @@ class TenantLLM(DataBaseModel):
|
||||
|
||||
class Meta:
|
||||
db_table = "tenant_llm"
|
||||
primary_key = CompositeKey('tenant_id', 'llm_factory')
|
||||
primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name')
|
||||
|
||||
|
||||
class Knowledgebase(DataBaseModel):
|
||||
@ -464,7 +465,9 @@ class Knowledgebase(DataBaseModel):
|
||||
permission = CharField(max_length=16, null=False, help_text="me|team")
|
||||
created_by = CharField(max_length=32, null=False)
|
||||
doc_num = IntegerField(default=0)
|
||||
embd_id = CharField(max_length=32, null=False, help_text="default embedding model ID")
|
||||
token_num = IntegerField(default=0)
|
||||
chunk_num = IntegerField(default=0)
|
||||
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
|
||||
|
@ -13,12 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from peewee import Expression
|
||||
|
||||
from web_server.db import TenantPermission, FileType
|
||||
from web_server.db.db_models import DB, Knowledgebase
|
||||
from web_server.db.db_models import DB, Knowledgebase, Tenant
|
||||
from web_server.db.db_models import Document
|
||||
from web_server.db.services.common_service import CommonService
|
||||
from web_server.db.services.kb_service import KnowledgebaseService
|
||||
from web_server.utils import get_uuid, get_format_time
|
||||
from web_server.db.db_utils import StatusEnum
|
||||
|
||||
|
||||
@ -61,15 +62,28 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
|
||||
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, Knowledgebase.tenant_id]
|
||||
docs = cls.model.select(fields).join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)).where(
|
||||
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time]
|
||||
docs = cls.model.select(*fields) \
|
||||
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
|
||||
.where(
|
||||
cls.model.status == StatusEnum.VALID.value,
|
||||
cls.model.type != FileType.VIRTUAL,
|
||||
~(cls.model.type == FileType.VIRTUAL.value),
|
||||
cls.model.progress == 0,
|
||||
cls.model.update_time >= tm,
|
||||
cls.model.create_time %
|
||||
comm == mod).order_by(
|
||||
cls.model.update_time.asc()).paginate(
|
||||
1,
|
||||
items_per_page)
|
||||
(Expression(cls.model.create_time, "%%", comm) == mod))\
|
||||
.order_by(cls.model.update_time.asc())\
|
||||
.paginate(1, items_per_page)
|
||||
return list(docs.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
|
||||
num = cls.model.update(token_num=cls.model.token_num + token_num,
|
||||
chunk_num=cls.model.chunk_num + chunk_num,
|
||||
process_duation=cls.model.process_duation+duation).where(
|
||||
cls.model.id == doc_id).execute()
|
||||
if num == 0:raise LookupError("Document not found which is supposed to be there")
|
||||
num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute()
|
||||
return num
|
||||
|
||||
|
@ -17,7 +17,7 @@ import peewee
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
|
||||
from web_server.db import TenantPermission
|
||||
from web_server.db.db_models import DB, UserTenant
|
||||
from web_server.db.db_models import DB, UserTenant, Tenant
|
||||
from web_server.db.db_models import Knowledgebase
|
||||
from web_server.db.services.common_service import CommonService
|
||||
from web_server.utils import get_uuid, get_format_time
|
||||
@ -29,15 +29,42 @@ class KnowledgebaseService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc):
|
||||
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
||||
page_number, items_per_page, orderby, desc):
|
||||
kbs = cls.model.select().where(
|
||||
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
|
||||
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
|
||||
TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
|
||||
& (cls.model.status == StatusEnum.VALID.value)
|
||||
)
|
||||
if desc: kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else: kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
|
||||
if desc:
|
||||
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
kbs = kbs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(kbs.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_detail(cls, kb_id):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
Tenant.embd_id,
|
||||
cls.model.avatar,
|
||||
cls.model.name,
|
||||
cls.model.description,
|
||||
cls.model.permission,
|
||||
cls.model.doc_num,
|
||||
cls.model.token_num,
|
||||
cls.model.chunk_num,
|
||||
cls.model.parser_id]
|
||||
kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where(
|
||||
(cls.model.id == kb_id),
|
||||
(cls.model.status == StatusEnum.VALID.value)
|
||||
)
|
||||
if not kbs:
|
||||
return
|
||||
d = kbs[0].to_dict()
|
||||
d["embd_id"] = kbs[0].tenant.embd_id
|
||||
return d
|
||||
|
@ -33,3 +33,21 @@ class LLMService(CommonService):
|
||||
|
||||
class TenantLLMService(CommonService):
|
||||
model = TenantLLM
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_api_key(cls, tenant_id, model_type):
|
||||
objs = cls.query(tenant_id=tenant_id, model_type=model_type)
|
||||
if objs and len(objs)>0 and objs[0].llm_name:
|
||||
return objs[0]
|
||||
|
||||
fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
|
||||
objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
|
||||
(cls.model.tenant_id == tenant_id),
|
||||
(cls.model.model_type == model_type),
|
||||
(LLM.status == StatusEnum.VALID)
|
||||
)
|
||||
|
||||
if not objs:return
|
||||
return objs[0]
|
||||
|
||||
|
@ -79,7 +79,7 @@ class TenantService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_user_id(cls, user_id):
|
||||
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role]
|
||||
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role]
|
||||
return list(cls.model.select(*fields)\
|
||||
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
|
||||
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
||||
|
@ -143,7 +143,7 @@ def filename_type(filename):
|
||||
if re.match(r".*\.pdf$", filename):
|
||||
return FileType.PDF.value
|
||||
|
||||
if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)$", filename):
|
||||
if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
|
||||
return FileType.DOC.value
|
||||
|
||||
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
|
||||
|
Loading…
x
Reference in New Issue
Block a user