Mirror of https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git, synced 2025-08-14 19:35:52 +08:00.
Refactor (#537)

### What problem does this PR solve?

### Type of change

- [x] Refactoring

parent cf9b554c3a
commit 66f8d35632
```diff
@@ -58,7 +58,8 @@ def upload():
     if not e:
         return get_data_error_result(
             retmsg="Can't find this knowledgebase!")
-    if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)):
+    MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
+    if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER:
         return get_data_error_result(
             retmsg="Exceed the maximum file number of a free user!")
```
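The quota is now opt-in: leaving `MAX_FILE_NUM_PER_USER` unset (or setting it to `0`) disables the check entirely, where the old code capped every tenant at 8192 files by default. A minimal sketch of the new guard, with a hypothetical `doc_count` standing in for the `DocumentService.get_doc_count(...)` call:

```python
import os

def exceeds_quota(doc_count: int) -> bool:
    # 0 (the new default) means "no per-user file cap".
    limit = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
    return limit > 0 and doc_count >= limit
```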
```diff
@@ -28,7 +28,7 @@ from rag.llm import EmbeddingModel, ChatModel
 def factories():
     try:
         fac = LLMFactoriesService.get_all()
-        return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["QAnything", "FastEmbed"]])
+        return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed"]])
     except Exception as e:
         return server_error_response(e)
 
```
```diff
@@ -174,7 +174,7 @@ def list():
     llms = [m.to_dict()
             for m in llms if m.status == StatusEnum.VALID.value]
     for m in llms:
-        m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["QAnything","FastEmbed"]
+        m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed"]
 
     llm_set = set([m["llm_name"] for m in llms])
     for o in objs:
```
```diff
@@ -697,7 +697,7 @@ class Dialog(DataBaseModel):
         null=True,
         default="Chinese",
         help_text="English|Chinese")
-    llm_id = CharField(max_length=32, null=False, help_text="default llm ID")
+    llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
     llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
                                                  "presence_penalty": 0.4, "max_tokens": 215})
     prompt_type = CharField(
```
```diff
@@ -120,7 +120,7 @@ factory_infos = [{
     "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
     "status": "1",
 },{
-    "name": "QAnything",
+    "name": "Youdao",
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
     "status": "1",
```
```diff
@@ -323,7 +323,7 @@ def init_llm_factory():
         "max_tokens": 2147483648,
         "model_type": LLMType.EMBEDDING.value
     },
-    # ------------------------ QAnything -----------------------
+    # ------------------------ Youdao -----------------------
     {
         "fid": factory_infos[7]["name"],
         "llm_name": "maidalun1020/bce-embedding-base_v1",
```
```diff
@@ -347,7 +347,9 @@ def init_llm_factory():
     LLMService.filter_delete([LLM.fid == "Local"])
     LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"])
     TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"])
+    LLMFactoriesService.filter_update([LLMFactoriesService.model.name == "QAnything"], {"name": "Youdao"})
+    LLMService.filter_update([LLMService.model.fid == "QAnything"], {"fid": "Youdao"})
+    TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
 
     """
     drop table llm;
     drop table llm_factories;
```
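The three new `filter_update` calls migrate rows created before the rename. A hedged sketch of the same idea as a generic helper; the table and column names here are assumptions read off the peewee models, not the project's API:

```python
# Hypothetical: the old->new factory rename applied per affected table/column.
RENAMES = [
    ("llm_factories", "name"),
    ("llm", "fid"),
    ("tenant_llm", "llm_factory"),
]

def rename_factory(execute, old="QAnything", new="Youdao"):
    for table, column in RENAMES:
        execute(f"UPDATE {table} SET {column} = %s WHERE {column} = %s", (new, old))
```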
```diff
@@ -81,7 +81,7 @@ class TenantLLMService(CommonService):
         if not model_config:
             if llm_type == LLMType.EMBEDDING.value:
                 llm = LLMService.query(llm_name=llm_name)
-                if llm and llm[0].fid in ["QAnything", "FastEmbed"]:
+                if llm and llm[0].fid in ["Youdao", "FastEmbed"]:
                     model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
         if not model_config:
             if llm_name == "flag-embedding":
```
```diff
@@ -21,6 +21,7 @@ from api.db import StatusEnum, FileType, TaskStatus
 from api.db.db_models import Task, Document, Knowledgebase, Tenant
 from api.db.services.common_service import CommonService
 from api.db.services.document_service import DocumentService
+from api.utils import current_timestamp
 
 
 class TaskService(CommonService):
```
```diff
@@ -70,6 +71,25 @@ class TaskService(CommonService):
             cls.model.id == docs[0]["id"]).execute()
         return docs
 
+    @classmethod
+    @DB.connection_context()
+    def get_ongoing_doc_name(cls):
+        with DB.lock("get_task", -1):
+            docs = cls.model.select(*[Document.kb_id, Document.location]) \
+                .join(Document, on=(cls.model.doc_id == Document.id)) \
+                .where(
+                    Document.status == StatusEnum.VALID.value,
+                    Document.run == TaskStatus.RUNNING.value,
+                    ~(Document.type == FileType.VIRTUAL.value),
+                    cls.model.progress >= 0,
+                    cls.model.progress < 1,
+                    cls.model.create_time >= current_timestamp() - 180000
+                )
+            docs = list(docs.dicts())
+            if not docs: return []
+
+            return list(set([(d["kb_id"], d["location"]) for d in docs]))
+
     @classmethod
     @DB.connection_context()
     def do_cancel(cls, id):
```
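`get_ongoing_doc_name` is the feed for the new cache service: under the `get_task` lock it collects the distinct `(kb_id, location)` pairs of documents that are valid, currently running, not virtual, mid-parse (progress in `[0, 1)`), and whose tasks were created within the last 180000 ms (3 minutes, assuming `current_timestamp()` returns milliseconds). The same filter over plain dicts, as a standalone sketch with assumed enum values:

```python
import time

def ongoing_locations(tasks):
    """tasks: dicts with kb_id, location, status, run, type, progress, create_time (ms)."""
    now_ms = int(time.time() * 1000)
    hits = {
        (t["kb_id"], t["location"])
        for t in tasks
        if t["status"] == "1"                    # StatusEnum.VALID (assumed value)
        and t["run"] == "1"                      # TaskStatus.RUNNING (assumed value)
        and t["type"] != "virtual"               # FileType.VIRTUAL (assumed value)
        and 0 <= t["progress"] < 1
        and t["create_time"] >= now_ms - 180000  # started within the last 3 minutes
    }
    return list(hits)
```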
```diff
@@ -49,7 +49,6 @@ class HuParser:
         self.updown_cnt_mdl.load_model(os.path.join(
             model_dir, "updown_concat_xgb.model"))
-
 
         self.page_from = 0
         """
         If you have trouble downloading HuggingFace models, -_^ this might help!!
```
```diff
@@ -975,6 +974,7 @@ class HuParser:
                     self.outlines.append((a["/Title"], depth))
                     continue
                 dfs(a, depth + 1)
+
             dfs(outlines, 0)
         except Exception as e:
             logging.warning(f"Outlines exception: {e}")
```
```diff
@@ -1012,9 +1012,9 @@ class HuParser:
                 j += 1
 
             self.__ocr(i + 1, img, chars, zoomin)
-            #if callback:
-            # callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
-            #print("OCR:", timer()-st)
+            if callback and i % 6 == 5:
+                callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
+            # print("OCR:", timer()-st)
 
         if not self.is_english and not any(
                 [c for c in self.page_chars]) and self.boxes:
```
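The OCR progress callback is re-enabled here but throttled to every sixth page, which keeps progress reporting cheap on long PDFs. The throttling pattern in isolation (a sketch; `report_progress` is hypothetical):

```python
def report_progress(i, total, callback, every=6):
    # Fire the callback only on every Nth page (i is 0-based).
    if callback and i % every == every - 1:
        callback(prog=(i + 1) * 0.6 / total, msg="")
```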
```diff
@@ -25,7 +25,7 @@ EmbeddingModel = {
     "Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
     "ZHIPU-AI": ZhipuEmbed,
     "FastEmbed": FastEmbed,
-    "QAnything": QAnythingEmbed
+    "Youdao": YoudaoEmbed
 }
 
 
```
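Callers resolve embedders by factory name from this registry, so the key rename is what makes the change visible end to end. A hedged usage sketch (the `batch_size` keyword is assumed from the `encode` body shown below):

```python
from rag.llm import EmbeddingModel

mdl_cls = EmbeddingModel["Youdao"]   # was EmbeddingModel["QAnything"]
mdl = mdl_cls(key="", model_name="maidalun1020/bce-embedding-base_v1")
vectors, token_count = mdl.encode(["hello world"], batch_size=16)
```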
```diff
@@ -229,19 +229,19 @@ class XinferenceEmbed(Base):
         return np.array(res.data[0].embedding), res.usage.total_tokens
 
 
-class QAnythingEmbed(Base):
+class YoudaoEmbed(Base):
     _client = None
 
     def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
         from BCEmbedding import EmbeddingModel as qanthing
-        if not QAnythingEmbed._client:
+        if not YoudaoEmbed._client:
             try:
                 print("LOADING BCE...")
-                QAnythingEmbed._client = qanthing(model_name_or_path=os.path.join(
+                YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(
                     get_project_base_directory(),
                     "rag/res/bce-embedding-base_v1"))
             except Exception as e:
-                QAnythingEmbed._client = qanthing(
+                YoudaoEmbed._client = qanthing(
                     model_name_or_path=model_name.replace(
                         "maidalun1020", "InfiniFlow"))
 
@@ -251,10 +251,10 @@ class QAnythingEmbed(Base):
         for t in texts:
             token_count += num_tokens_from_string(t)
         for i in range(0, len(texts), batch_size):
-            embds = QAnythingEmbed._client.encode(texts[i:i + batch_size])
+            embds = YoudaoEmbed._client.encode(texts[i:i + batch_size])
             res.extend(embds)
         return np.array(res), token_count
 
     def encode_queries(self, text):
-        embds = QAnythingEmbed._client.encode([text])
+        embds = YoudaoEmbed._client.encode([text])
         return np.array(embds[0]), num_tokens_from_string(text)
```
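Only the class name changes; the behavior is the same lazy singleton: prefer bundled weights under `rag/res/bce-embedding-base_v1`, and fall back to the `InfiniFlow/bce-embedding-base_v1` hub copy if the local load fails. The pattern in isolation, with a hypothetical `load_model`:

```python
class LazySingleton:
    _client = None

    def __init__(self, local_path, remote_name, load_model):
        if LazySingleton._client is None:
            try:
                LazySingleton._client = load_model(local_path)    # bundled weights first
            except Exception:
                LazySingleton._client = load_model(remote_name)   # fall back to the hub
```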
rag/svr/cache_file_svr.py (new file, 43 lines)

```diff
@@ -0,0 +1,43 @@
+import random
+import time
+import traceback
+
+from api.db.db_models import close_connection
+from api.db.services.task_service import TaskService
+from rag.utils import MINIO
+from rag.utils.redis_conn import REDIS_CONN
+
+
+def collect():
+    doc_locations = TaskService.get_ongoing_doc_name()
+    #print(tasks)
+    if len(doc_locations) == 0:
+        time.sleep(1)
+        return
+    return doc_locations
+
+def main():
+    locations = collect()
+    if not locations:return
+    print("TASKS:", len(locations))
+    for kb_id, loc in locations:
+        try:
+            if REDIS_CONN.is_alive():
+                try:
+                    key = "{}/{}".format(kb_id, loc)
+                    if REDIS_CONN.exist(key):continue
+                    file_bin = MINIO.get(kb_id, loc)
+                    REDIS_CONN.transaction(key, file_bin, 12 * 60)
+                    print("CACHE:", loc)
+                except Exception as e:
+                    traceback.print_stack(e)
+        except Exception as e:
+            traceback.print_stack(e)
+
+
+
+if __name__ == "__main__":
+    while True:
+        main()
+        close_connection()
+        time.sleep(1)
```
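This new daemon pre-stages the source files of in-flight documents from MinIO into Redis (keyed `kb_id/location`, 12-minute TTL) so task executors can read them without hitting object storage. Two small caveats worth noting: `random` is imported but unused, and `traceback.print_stack(e)` expects a frame rather than an exception (`traceback.print_exc()` would be the conventional call). A sketch of the write/read pairing, using only helpers introduced in this commit:

```python
from rag.utils import MINIO
from rag.utils.redis_conn import REDIS_CONN

def stage(kb_id: str, location: str) -> None:
    # Writer side (this daemon): stage the blob once, with a 12-minute TTL.
    key = "{}/{}".format(kb_id, location)
    if not REDIS_CONN.exist(key):
        REDIS_CONN.transaction(key, MINIO.get(kb_id, location), 12 * 60)

def fetch(kb_id: str, location: str) -> bytes:
    # Reader side (task executor): prefer the cache, fall back to MinIO.
    return REDIS_CONN.get("{}/{}".format(kb_id, location)) or MINIO.get(kb_id, location)
```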
```diff
@@ -107,8 +107,14 @@ def get_minio_binary(bucket, name):
     global MINIO
     if REDIS_CONN.is_alive():
         try:
+            for _ in range(30):
+                if REDIS_CONN.exist("{}/{}".format(bucket, name)):
+                    time.sleep(1)
+                    break
+                time.sleep(1)
             r = REDIS_CONN.get("{}/{}".format(bucket, name))
             if r: return r
+            cron_logger.warning("Cache missing: {}".format(name))
         except Exception as e:
             cron_logger.warning("Get redis[EXCEPTION]:" + str(e))
     return MINIO.get(bucket, name)
```
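On the reader side, `get_minio_binary` now polls Redis for up to roughly 30 seconds, giving the prefetcher time to stage the file, then logs a miss and falls back to MinIO as before. The waiting logic in isolation, as a hedged helper (`exists` and `read` stand in for the `REDIS_CONN` calls):

```python
import time

def wait_for_cache(exists, read, key, attempts=30, interval=1.0):
    # Poll until the prefetcher has staged the blob, then read it.
    for _ in range(attempts):
        if exists(key):
            time.sleep(interval)  # mirrors the commit: a short grace period after the key appears
            break
        time.sleep(interval)
    return read(key)  # may still be None on a miss; the caller falls back to MinIO
```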
```diff
@@ -56,7 +56,6 @@ class HuMinio(object):
         except Exception as e:
             minio_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e))
-
 
     def get(self, bucket, fnm):
         for _ in range(1):
             try:
```
```diff
@@ -25,6 +25,14 @@ class RedisDB:
     def is_alive(self):
         return self.REDIS is not None
 
+    def exist(self, k):
+        if not self.REDIS: return
+        try:
+            return self.REDIS.exists(k)
+        except Exception as e:
+            logging.warning("[EXCEPTION]exist" + str(k) + "||" + str(e))
+            self.__open__()
+
     def get(self, k):
         if not self.REDIS: return
         try:
```
```diff
@@ -51,5 +59,16 @@ class RedisDB:
             self.__open__()
         return False
 
+    def transaction(self, key, value, exp=3600):
+        try:
+            pipeline = self.REDIS.pipeline(transaction=True)
+            pipeline.set(key, value, exp, nx=True)
+            pipeline.execute()
+            return True
+        except Exception as e:
+            logging.warning("[EXCEPTION]set" + str(key) + "||" + str(e))
+            self.__open__()
+            return False
+
 
 REDIS_CONN = RedisDB()
```
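`transaction` issues a single `SET key value EX exp NX` through a pipeline: the first writer wins, concurrent writers leave the existing entry untouched, and the key expires on its own. The third positional argument of redis-py's `set` is `ex` (TTL in seconds), so `12 * 60` in the cache service means 12 minutes. A short usage sketch (the key is illustrative):

```python
from rag.utils.redis_conn import REDIS_CONN

# Stage a blob for 12 minutes; NX leaves an existing key untouched.
# Returns False only when the Redis call itself fails, not when NX declines to set.
ok = REDIS_CONN.transaction("kb_id/example.pdf", b"<file bytes>", 12 * 60)
```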