Mirror of https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git (synced 2025-06-04 11:24:00 +08:00)
add lighten control (#2567)
### What problem does this PR solve?

#2295

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):
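In short: the new `LIGHTEN` environment variable lets a deployment skip the heavy local-model stack (torch, FlagEmbedding, and the bundled Youdao/FastEmbed/BAAI factories) at both import time and default-configuration time. A minimal sketch of the flag's semantics, assuming it is set in the shell (e.g. `LIGHTEN=1`); note that `os.environ.get` returns the raw string or None, so any non-empty value, even "0", switches lighten mode on:

```python
import os

# os.environ.get returns the variable's string value, or None when unset;
# the code below only ever tests its truthiness.
LIGHTEN = os.environ.get('LIGHTEN')  # e.g. started with: LIGHTEN=1 python rag_flow_server.py

if not LIGHTEN:
    print("full mode: default local models configured, torch/FlagEmbedding importable")
else:
    print("lighten mode: no baked-in model defaults, heavy imports skipped")
```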
This commit is contained in:
parent 9251fb39af
commit 7bb28ca2bd
api/apps/llm_app.py

@@ -18,6 +18,7 @@ import json
 from flask import request
 from flask_login import login_required, current_user
 from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
+from api.settings import LIGHTEN
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.db import StatusEnum, LLMType
 from api.db.db_models import TenantLLM

@@ -319,13 +320,14 @@ def my_llms():
 @login_required
 def list_app():
     self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
+    weighted = ["Youdao","FastEmbed", "BAAI"] if LIGHTEN else []
     model_type = request.args.get("model_type")
     try:
         objs = TenantLLMService.query(tenant_id=current_user.id)
         facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
         llms = LLMService.get_all()
         llms = [m.to_dict()
-                for m in llms if m.status == StatusEnum.VALID.value]
+                for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted]
         for m in llms:
             m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deploied
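The `list_app` change hides the locally-served embedding factories from the model list when `LIGHTEN` is set. A self-contained sketch of the same filter, with made-up rows standing in for `LLMService.get_all()` and the literal `"1"` standing in for `StatusEnum.VALID.value` (both are assumptions for illustration):

```python
LIGHTEN = "1"  # pretend the env flag is set

# Factories whose models would have to run locally; hidden in lighten mode.
weighted = ["Youdao", "FastEmbed", "BAAI"] if LIGHTEN else []

llms = [  # stand-in rows; the real data comes from LLMService.get_all()
    {"fid": "BAAI", "llm_name": "bge-large-zh-v1.5", "status": "1"},
    {"fid": "OpenAI", "llm_name": "gpt-4o", "status": "1"},
]

# Same comprehension as the diff: keep valid models whose factory is not excluded.
visible = [m for m in llms if m["status"] == "1" and m["fid"] not in weighted]
print([m["fid"] for m in visible])  # -> ['OpenAI']
```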
api/settings.py

@@ -42,6 +42,7 @@ RAG_FLOW_SERVICE_NAME = "ragflow"
 SERVER_MODULE = "rag_flow_server.py"
 TEMP_DIRECTORY = os.path.join(get_project_base_directory(), "temp")
 RAG_FLOW_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
+LIGHTEN = os.environ.get('LIGHTEN')

 SUBPROCESS_STD_LOG_NAME = "std.log"

@@ -57,7 +58,12 @@ REQUEST_MAX_WAIT_SEC = 300
 USE_REGISTRY = get_base_config("use_registry")

-default_llm = {
+LLM = get_base_config("user_default_llm", {})
+LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
+LLM_BASE_URL = LLM.get("base_url")
+
+if not LIGHTEN:
+    default_llm = {
     "Tongyi-Qianwen": {
         "chat_model": "qwen-plus",
         "embedding_model": "text-embedding-v2",

@@ -113,21 +119,15 @@ default_llm = {
     "asr_model": "",
     "rerank_model": "BAAI/bge-reranker-v2-m3",
     }
 }
-LLM = get_base_config("user_default_llm", {})
-LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
-LLM_BASE_URL = LLM.get("base_url")
-
-if LLM_FACTORY not in default_llm:
-    print(
-        "\33[91m【ERROR】\33[0m:",
-        f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.")
-    LLM_FACTORY = "Tongyi-Qianwen"
-CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
-EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"]
-RERANK_MDL = default_llm["BAAI"]["rerank_model"]
-ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
-IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
+    CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
+    EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"]
+    RERANK_MDL = default_llm["BAAI"]["rerank_model"] if not LIGHTEN else ""
+    ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
+    IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
+else:
+    CHAT_MDL = EMBEDDING_MDL = RERANK_MDL = ASR_MDL = IMAGE2TEXT_MDL = ""

 API_KEY = LLM.get("api_key", "")
 PARSERS = LLM.get(
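Two things change in settings: the `user_default_llm` config is now read before the branch, and in lighten mode every default model id collapses to an empty string, so nothing touches the (now undefined) `default_llm` table. The old unknown-factory warning (the printed 【ERROR】 fallback to Tongyi-Qianwen) is dropped outright. A toy reconstruction of the resulting control flow, with a trimmed `default_llm` (the BAAI embedding model name below is illustrative, not taken from this diff):

```python
LIGHTEN = None                  # stand-in for os.environ.get('LIGHTEN'); None = full mode
LLM_FACTORY = "Tongyi-Qianwen"  # normally LLM.get("factory", "Tongyi-Qianwen")

if not LIGHTEN:
    default_llm = {  # trimmed stand-in for the real table in api/settings.py
        "Tongyi-Qianwen": {"chat_model": "qwen-plus",
                           "asr_model": "", "image2text_model": ""},
        "BAAI": {"embedding_model": "BAAI/bge-large-zh-v1.5",  # illustrative
                 "rerank_model": "BAAI/bge-reranker-v2-m3"},
    }
    CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
    EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"]
    RERANK_MDL = default_llm["BAAI"]["rerank_model"] if not LIGHTEN else ""
    ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
    IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
else:
    # Lighten: no baked-in model defaults; tenants must configure everything.
    CHAT_MDL = EMBEDDING_MDL = RERANK_MDL = ASR_MDL = IMAGE2TEXT_MDL = ""

print(CHAT_MDL or "<none>", EMBEDDING_MDL or "<none>")
```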
deepdoc/parser/pdf_parser.py

@@ -16,7 +16,6 @@ import random

 import xgboost as xgb
 from io import BytesIO
-import torch
 import re
 import pdfplumber
 import logging

@@ -25,6 +24,7 @@ import numpy as np
 from timeit import default_timer as timer
 from pypdf import PdfReader as pdf2_read

+from api.settings import LIGHTEN
 from api.utils.file_utils import get_project_base_directory
 from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
 from rag.nlp import rag_tokenizer

@@ -44,6 +44,8 @@ class RAGFlowPdfParser:
             self.tbl_det = TableStructureRecognizer()

         self.updown_cnt_mdl = xgb.Booster()
+        if not LIGHTEN:
+            import torch
         if torch.cuda.is_available():
             self.updown_cnt_mdl.set_param({"device": "cuda"})
         try:
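The parser now defers `import torch` until construction, so a lighten install never needs torch at module import. Note that, as this hunk reads, the `torch.cuda.is_available()` check still sits outside the new guard and would raise `NameError` when `LIGHTEN` is set; a fully scoped variant of the same lazy-import idea (a sketch, not the committed code) would be:

```python
import xgboost as xgb  # assumes xgboost is installed

LIGHTEN = None  # stand-in for api.settings.LIGHTEN

updown_cnt_mdl = xgb.Booster()
if not LIGHTEN:
    # torch is only imported (and only required) in full mode.
    import torch
    if torch.cuda.is_available():
        updown_cnt_mdl.set_param({"device": "cuda"})
```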
rag/llm/embedding_model.py

@@ -25,10 +25,10 @@ from abc import ABC
 from ollama import Client
 import dashscope
 from openai import OpenAI
-from FlagEmbedding import FlagModel
-import torch
 import numpy as np
 import asyncio

+from api.settings import LIGHTEN
 from api.utils.file_utils import get_home_cache_dir
 from rag.utils import num_tokens_from_string, truncate
 import google.generativeai as genai

@@ -60,8 +60,10 @@ class DefaultEmbedding(Base):
         ^_-

         """
-        if not DefaultEmbedding._model:
+        if not LIGHTEN and not DefaultEmbedding._model:
             with DefaultEmbedding._model_lock:
+                from FlagEmbedding import FlagModel
+                import torch
                 if not DefaultEmbedding._model:
                     try:
                         DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
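`DefaultEmbedding` keeps its double-checked locking singleton; the diff only moves the `FlagEmbedding`/`torch` imports inside the guarded block and adds the `LIGHTEN` short-circuit. A generic, dependency-free sketch of that pattern (all names below are illustrative):

```python
import threading

def _load_heavy_model():
    # Stand-in for: from FlagEmbedding import FlagModel; return FlagModel(...)
    return object()

class LazySingletonSketch:
    _model = None
    _model_lock = threading.Lock()

    def __init__(self, lighten=False):
        # The unlocked check keeps the hot path cheap; the locked re-check
        # ensures exactly one thread performs the expensive load.
        if not lighten and LazySingletonSketch._model is None:
            with LazySingletonSketch._model_lock:
                if LazySingletonSketch._model is None:
                    LazySingletonSketch._model = _load_heavy_model()
        self._model = LazySingletonSketch._model
```

The same guard-plus-deferred-import shape is applied to `DefaultRerank` in the next file.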
rag/llm/rerank_model.py

@@ -16,12 +16,12 @@
 import re
 import threading
 import requests
-import torch
-from FlagEmbedding import FlagReranker
 from huggingface_hub import snapshot_download
 import os
 from abc import ABC
 import numpy as np

+from api.settings import LIGHTEN
 from api.utils.file_utils import get_home_cache_dir
 from rag.utils import num_tokens_from_string, truncate
 import json

@@ -53,7 +53,9 @@ class DefaultRerank(Base):
         ^_-

         """
-        if not DefaultRerank._model:
+        if not LIGHTEN and not DefaultRerank._model:
+            import torch
+            from FlagEmbedding import FlagReranker
             with DefaultRerank._model_lock:
                 if not DefaultRerank._model:
                     try: