# Feat: add vision LLM PDF parser (#6173)

### What problem does this PR solve?

Add vision LLM PDF parser

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>

parent: 897fe85b5c
commit: 5cf610af40
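The diff below touches seven files. At its core: when a knowledge base's `layout_recognize` setting names an image2text model instead of the built-in "DeepDOC" or "Plain Text" modes, each PDF page is rendered to an image and transcribed to Markdown by that vision model. A minimal sketch of the resulting flow, assuming a ragflow checkout with its dependencies on `PYTHONPATH`; the stub class is a stand-in for the `LLMBundle` image2text wrapper and is not part of the diff:

```python
from deepdoc.parser.pdf_parser import VisionParser


class StubVisionModel:
    """Stands in for LLMBundle(tenant_id, LLMType.IMAGE2TEXT, ...)."""

    def describe_with_prompt(self, binary, prompt):
        # The bundle returns just the transcribed text; token accounting
        # happens inside LLMBundle itself.
        return "## Transcribed page\n..."


parser = VisionParser(vision_model=StubVisionModel())
# "some.pdf" is a placeholder; per the diff, the parser returns
# ([(markdown, ""), ...], []) — one markdown section per page, no tables.
sections, tables = parser("some.pdf", from_page=0, to_page=2)
```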
```diff
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -15,13 +15,12 @@
 #
 import logging
 
-from api.db.services.user_service import TenantService
-from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
 from api import settings
 from api.db import LLMType
-from api.db.db_models import DB
-from api.db.db_models import LLMFactories, LLM, TenantLLM
+from api.db.db_models import DB, LLM, LLMFactories, TenantLLM
 from api.db.services.common_service import CommonService
+from api.db.services.user_service import TenantService
+from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel
 
 
 class LLMFactoriesService(CommonService):
```
```diff
@@ -266,6 +265,14 @@ class LLMBundle:
                 "LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
         return txt
 
+    def describe_with_prompt(self, image, prompt):
+        txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
+        if not TenantLLMService.increase_usage(
+                self.tenant_id, self.llm_type, used_tokens):
+            logging.error(
+                "LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
+        return txt
+
     def transcription(self, audio):
         txt, used_tokens = self.mdl.transcription(audio)
         if not TenantLLMService.increase_usage(
```
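The bundle method mirrors `describe`: the raw CV model returns a `(text, used_tokens)` tuple, the bundle meters the tokens and hands back only the text. A contract sketch with stand-in classes (not from the diff):

```python
class StubCvModel:
    def describe_with_prompt(self, image, prompt):
        return "markdown transcription", 128  # (text, used_tokens)


class StubBundle:
    def __init__(self, mdl):
        self.mdl = mdl

    def describe_with_prompt(self, image, prompt):
        txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
        # The real bundle calls TenantLLMService.increase_usage() here
        # and logs an error when the usage update fails.
        return txt


print(StubBundle(StubCvModel()).describe_with_prompt(b"", "transcribe"))
```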
```diff
--- a/deepdoc/parser/pdf_parser.py
+++ b/deepdoc/parser/pdf_parser.py
@@ -17,26 +17,27 @@
 import logging
 import os
 import random
-from timeit import default_timer as timer
+import re
 import sys
 import threading
-import trio
-
-import xgboost as xgb
+from copy import deepcopy
 from io import BytesIO
-import re
-import pdfplumber
-from PIL import Image
+from timeit import default_timer as timer
+
 import numpy as np
+import pdfplumber
+import trio
+import xgboost as xgb
+from huggingface_hub import snapshot_download
+from PIL import Image
 from pypdf import PdfReader as pdf2_read
 
 from api import settings
 from api.utils.file_utils import get_project_base_directory
-from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
-from copy import deepcopy
-from huggingface_hub import snapshot_download
+from deepdoc.vision import OCR, LayoutRecognizer, Recognizer, TableStructureRecognizer
+from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
 from rag.nlp import rag_tokenizer
+from rag.prompts import vision_llm_describe_prompt
 
 from rag.settings import PARALLEL_DEVICES
 
 LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
```
```diff
@@ -45,7 +46,7 @@ if LOCK_KEY_pdfplumber not in sys.modules:
 
 
 class RAGFlowPdfParser:
-    def __init__(self):
+    def __init__(self, **kwargs):
         """
         If you have trouble downloading HuggingFace models, -_^ this might help!!
 
@@ -106,7 +107,7 @@ class RAGFlowPdfParser:
     def _y_dis(
             self, a, b):
         return (
             b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2
 
     def _match_proj(self, b):
         proj_patt = [
@@ -129,9 +130,9 @@ class RAGFlowPdfParser:
         tks_down = rag_tokenizer.tokenize(down["text"][:LEN]).split()
         tks_up = rag_tokenizer.tokenize(up["text"][-LEN:]).split()
         tks_all = up["text"][-LEN:].strip() \
             + (" " if re.match(r"[a-zA-Z0-9]+",
                                up["text"][-1] + down["text"][0]) else "") \
             + down["text"][:LEN].strip()
         tks_all = rag_tokenizer.tokenize(tks_all).split()
         fea = [
             up.get("R", -1) == down.get("R", -1),
@@ -153,7 +154,7 @@ class RAGFlowPdfParser:
             True if re.search(r"[,,][^。.]+$", up["text"]) else False,
             True if re.search(r"[,,][^。.]+$", up["text"]) else False,
             True if re.search(r"[\((][^\))]+$", up["text"])
             and re.search(r"[\))]", down["text"]) else False,
             self._match_proj(down),
             True if re.match(r"[A-Z]", down["text"]) else False,
             True if re.match(r"[A-Z]", up["text"][-1]) else False,
@@ -215,7 +216,7 @@ class RAGFlowPdfParser:
                 continue
             for tb in tbls:  # for table
                 left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \
                     tb["x1"] + MARGIN, tb["bottom"] + MARGIN
                 left *= ZM
                 top *= ZM
                 right *= ZM
@@ -457,7 +458,7 @@ class RAGFlowPdfParser:
                     b_["text"],
                     any(feats),
                     any(concatting_feats),
                 ))
                 i += 1
                 continue
             # merge up and down
@@ -665,7 +666,7 @@ class RAGFlowPdfParser:
                 i += 1
                 continue
             lout_no = str(self.boxes[i]["page_number"]) + \
                 "-" + str(self.boxes[i]["layoutno"])
             if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption",
                                                                                                       "title",
                                                                                                       "figure caption",
@@ -1023,7 +1024,7 @@ class RAGFlowPdfParser:
         logging.debug("Images converted.")
         self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join(
             random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in
             range(len(self.page_chars))]
         if sum([1 if e else 0 for e in self.is_english]) > len(
                 self.page_images) / 2:
             self.is_english = True
@@ -1036,7 +1037,7 @@ class RAGFlowPdfParser:
                 if chars[j]["text"] and chars[j + 1]["text"] \
                         and re.match(r"[0-9a-zA-Z,.:;!%]+", chars[j]["text"] + chars[j + 1]["text"]) \
                         and chars[j + 1]["x0"] - chars[j]["x1"] >= min(chars[j + 1]["width"],
                                                                        chars[j]["width"]) / 2:
                     chars[j]["text"] += " "
                 j += 1
 
@@ -1067,7 +1068,7 @@ class RAGFlowPdfParser:
                     chars = __ocr_preprocess()
 
                     nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars,
                                        self.parallel_limiter[i % PARALLEL_DEVICES])
                     await trio.sleep(0.1)
         else:
             for i, img in enumerate(self.page_images):
@@ -1142,7 +1143,7 @@ class RAGFlowPdfParser:
                 self.page_images[pns[0]].crop((left * ZM, top * ZM,
                                                right *
                                                ZM, min(
                                                    bottom, self.page_images[pns[0]].size[1])
                                                ))
             )
             if 0 < ii < len(poss) - 1:
@@ -1240,5 +1241,52 @@ class PlainParser:
         raise NotImplementedError
 
 
+class VisionParser(RAGFlowPdfParser):
+    def __init__(self, vision_model, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.vision_model = vision_model
+
+    def __images__(self, fnm, zoomin=3, page_from=0, page_to=299, callback=None):
+        try:
+            with sys.modules[LOCK_KEY_pdfplumber]:
+                self.pdf = pdfplumber.open(fnm) if isinstance(
+                    fnm, str) else pdfplumber.open(BytesIO(fnm))
+                self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
+                                    enumerate(self.pdf.pages[page_from:page_to])]
+                self.total_page = len(self.pdf.pages)
+        except Exception:
+            self.page_images = None
+            self.total_page = 0
+            logging.exception("VisionParser __images__")
+
+    def __call__(self, filename, from_page=0, to_page=100000, **kwargs):
+        callback = kwargs.get("callback", lambda prog, msg: None)
+
+        self.__images__(fnm=filename, zoomin=3, page_from=from_page, page_to=to_page, **kwargs)
+
+        total_pdf_pages = self.total_page
+
+        start_page = max(0, from_page)
+        end_page = min(to_page, total_pdf_pages)
+
+        all_docs = []
+
+        for idx, img_binary in enumerate(self.page_images or []):
+            pdf_page_num = idx  # 0-based
+            if pdf_page_num < start_page or pdf_page_num >= end_page:
+                continue
+
+            docs = picture_vision_llm_chunk(
+                binary=img_binary,
+                vision_model=self.vision_model,
+                prompt=vision_llm_describe_prompt(page=pdf_page_num + 1),
+                callback=callback,
+            )
+
+            if docs:
+                all_docs.append(docs)
+        return [(doc, "") for doc in all_docs], []
+
+
 if __name__ == "__main__":
     pass
```
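`VisionParser` skips OCR and the layout models entirely: it only renders each page to a PIL image and defers everything else to the vision model. The rendering step in isolation (a sketch; `some.pdf` is a placeholder, and `resolution=216` matches the diff's `72 * zoomin` with the default `zoomin=3`):

```python
import pdfplumber

with pdfplumber.open("some.pdf") as pdf:
    # .annotated is the PIL.Image that pdfplumber exposes on a rendered page,
    # the same attribute the diff uses above.
    images = [page.to_image(resolution=216).annotated for page in pdf.pages[:3]]
```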
```diff
--- a/rag/app/naive.py
+++ b/rag/app/naive.py
@@ -26,8 +26,10 @@ from markdown import markdown
 from PIL import Image
 from tika import parser
 
+from api.db import LLMType
+from api.db.services.llm_service import LLMBundle
 from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownParser, PdfParser, TxtParser
-from deepdoc.parser.pdf_parser import PlainParser
+from deepdoc.parser.pdf_parser import PlainParser, VisionParser
 from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_docx, tokenize_table
 from rag.utils import num_tokens_from_string
 
@@ -237,9 +239,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         return res
 
     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf()
-        if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
+        layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
+
+        if layout_recognizer == "DeepDOC":
+            pdf_parser = Pdf()
+        elif layout_recognizer == "Plain Text":
             pdf_parser = PlainParser()
+        else:
+            vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=layout_recognizer, lang=lang)
+            pdf_parser = VisionParser(vision_model=vision_model, **kwargs)
+
         sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
                                       callback=callback)
         res = tokenize_table(tables, doc, is_english)
```
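`layout_recognize` thus becomes a three-way switch, and the vision branch additionally needs `tenant_id` in `kwargs` to resolve the model through `LLMBundle`. Illustrative configs (the model name is an example, not from the diff):

```python
parser_config = {"layout_recognize": "DeepDOC"}     # default DeepDOC layout pipeline
parser_config = {"layout_recognize": "Plain Text"}  # plain text extraction, no layout model
parser_config = {"layout_recognize": "gpt-4o"}      # any other value: resolved as an
                                                    # image2text model, routed to VisionParser
```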
```diff
--- a/rag/app/picture.py
+++ b/rag/app/picture.py
@@ -21,8 +21,9 @@ from PIL import Image
 
 from api.db import LLMType
 from api.db.services.llm_service import LLMBundle
-from rag.nlp import tokenize
 from deepdoc.vision import OCR
+from rag.nlp import tokenize
+from rag.utils import clean_markdown_block
 
 ocr = OCR()
 
@@ -57,3 +58,32 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
         callback(prog=-1, msg=str(e))
 
     return []
+
+
+def vision_llm_chunk(binary, vision_model, prompt=None, callback=None):
+    """
+    A simple wrapper to process image to markdown texts via VLM.
+
+    Returns:
+        Simple markdown texts generated by VLM.
+    """
+    callback = callback or (lambda prog, msg: None)
+
+    img = binary
+    txt = ""
+
+    try:
+        img_binary = io.BytesIO()
+        img.save(img_binary, format='JPEG')
+        img_binary.seek(0)
+
+        ans = clean_markdown_block(vision_model.describe_with_prompt(img_binary.read(), prompt))
+
+        txt += "\n" + ans
+
+        return txt
+
+    except Exception as e:
+        callback(-1, str(e))
+
+        return []
```
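Note that `binary` here is actually a PIL image (the page render handed over by `VisionParser`), which the wrapper re-encodes to JPEG bytes before calling the model. The round-trip in isolation:

```python
import io

from PIL import Image

img = Image.new("RGB", (64, 64), "white")  # stand-in for a rendered PDF page
buf = io.BytesIO()
img.save(buf, format="JPEG")
buf.seek(0)
jpeg_bytes = buf.read()  # what vision_model.describe_with_prompt() receives
```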
```diff
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -13,29 +13,34 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from openai.lib.azure import AzureOpenAI
-from zhipuai import ZhipuAI
-import io
-from abc import ABC
-from ollama import Client
-from PIL import Image
-from openai import OpenAI
-import os
 import base64
-from io import BytesIO
+import io
 import json
-import requests
+import os
+from abc import ABC
+from io import BytesIO
+
+import requests
+from ollama import Client
+from openai import OpenAI
+from openai.lib.azure import AzureOpenAI
+from PIL import Image
+from zhipuai import ZhipuAI
 
-from rag.nlp import is_english
 from api.utils import get_uuid
 from api.utils.file_utils import get_project_base_directory
+from rag.nlp import is_english
+from rag.prompts import vision_llm_describe_prompt
 
 
 class Base(ABC):
     def __init__(self, key, model_name):
         pass
 
-    def describe(self, image, max_tokens=300):
+    def describe(self, image):
+        raise NotImplementedError("Please implement encode method!")
+
+    def describe_with_prompt(self, image, prompt=None):
         raise NotImplementedError("Please implement encode method!")
 
     def chat(self, system, history, gen_conf, image=""):
@@ -122,6 +127,25 @@ class Base(ABC):
             }
         ]
 
+    def vision_llm_prompt(self, b64, prompt=None):
+        return [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{b64}"
+                        },
+                    },
+                    {
+                        "type": "text",
+                        "text": prompt if prompt else vision_llm_describe_prompt(),
+                    },
+                ],
+            }
+        ]
+
     def chat_prompt(self, text, b64):
         return [
             {
```
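For the OpenAI-compatible providers, `vision_llm_prompt` builds a single user message that pairs an `image_url` part with a text part. What it evaluates to, sketched with a dummy base64 payload and a shortened prompt (the real default is `vision_llm_describe_prompt()`):

```python
import base64

b64 = base64.b64encode(b"not-a-real-jpeg").decode()
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
            {"type": "text", "text": "Transcribe this page into Markdown."},
        ],
    }
]
```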
@ -140,12 +164,12 @@ class Base(ABC):
|
|||||||
class GptV4(Base):
|
class GptV4(Base):
|
||||||
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
|
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
|
||||||
if not base_url:
|
if not base_url:
|
||||||
base_url="https://api.openai.com/v1"
|
base_url = "https://api.openai.com/v1"
|
||||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.lang = lang
|
self.lang = lang
|
||||||
|
|
||||||
def describe(self, image, max_tokens=300):
|
def describe(self, image):
|
||||||
b64 = self.image2base64(image)
|
b64 = self.image2base64(image)
|
||||||
prompt = self.prompt(b64)
|
prompt = self.prompt(b64)
|
||||||
for i in range(len(prompt)):
|
for i in range(len(prompt)):
|
||||||
@ -159,6 +183,16 @@ class GptV4(Base):
|
|||||||
)
|
)
|
||||||
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
||||||
|
|
||||||
|
def describe_with_prompt(self, image, prompt=None):
|
||||||
|
b64 = self.image2base64(image)
|
||||||
|
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
|
||||||
|
|
||||||
|
res = self.client.chat.completions.create(
|
||||||
|
model=self.model_name,
|
||||||
|
messages=vision_prompt,
|
||||||
|
)
|
||||||
|
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
||||||
|
|
||||||
|
|
||||||
class AzureGptV4(Base):
|
class AzureGptV4(Base):
|
||||||
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
||||||
@ -168,7 +202,7 @@ class AzureGptV4(Base):
|
|||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.lang = lang
|
self.lang = lang
|
||||||
|
|
||||||
def describe(self, image, max_tokens=300):
|
def describe(self, image):
|
||||||
b64 = self.image2base64(image)
|
b64 = self.image2base64(image)
|
||||||
prompt = self.prompt(b64)
|
prompt = self.prompt(b64)
|
||||||
for i in range(len(prompt)):
|
for i in range(len(prompt)):
|
||||||
@ -182,6 +216,16 @@ class AzureGptV4(Base):
|
|||||||
)
|
)
|
||||||
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
||||||
|
|
||||||
|
def describe_with_prompt(self, image, prompt=None):
|
||||||
|
b64 = self.image2base64(image)
|
||||||
|
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
|
||||||
|
|
||||||
|
res = self.client.chat.completions.create(
|
||||||
|
model=self.model_name,
|
||||||
|
messages=vision_prompt,
|
||||||
|
)
|
||||||
|
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
||||||
|
|
||||||
|
|
||||||
class QWenCV(Base):
|
class QWenCV(Base):
|
||||||
def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
|
def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
|
||||||
@ -212,23 +256,57 @@ class QWenCV(Base):
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def vision_llm_prompt(self, binary, prompt=None):
|
||||||
|
# stupid as hell
|
||||||
|
tmp_dir = get_project_base_directory("tmp")
|
||||||
|
if not os.path.exists(tmp_dir):
|
||||||
|
os.mkdir(tmp_dir)
|
||||||
|
path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
|
||||||
|
Image.open(io.BytesIO(binary)).save(path)
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"image": f"file://{path}"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": prompt if prompt else vision_llm_describe_prompt(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
def chat_prompt(self, text, b64):
|
def chat_prompt(self, text, b64):
|
||||||
return [
|
return [
|
||||||
{"image": f"{b64}"},
|
{"image": f"{b64}"},
|
||||||
{"text": text},
|
{"text": text},
|
||||||
]
|
]
|
||||||
|
|
||||||
def describe(self, image, max_tokens=300):
|
def describe(self, image):
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
|
||||||
from dashscope import MultiModalConversation
|
from dashscope import MultiModalConversation
|
||||||
response = MultiModalConversation.call(model=self.model_name,
|
response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image))
|
||||||
messages=self.prompt(image))
|
if response.status_code == HTTPStatus.OK:
|
||||||
|
return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
|
||||||
|
return response.message, 0
|
||||||
|
|
||||||
|
def describe_with_prompt(self, image, prompt=None):
|
||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
|
from dashscope import MultiModalConversation
|
||||||
|
|
||||||
|
vision_prompt = self.vision_llm_prompt(image, prompt) if prompt else self.vision_llm_prompt(image)
|
||||||
|
response = MultiModalConversation.call(model=self.model_name, messages=vision_prompt)
|
||||||
if response.status_code == HTTPStatus.OK:
|
if response.status_code == HTTPStatus.OK:
|
||||||
return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
|
return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
|
||||||
return response.message, 0
|
return response.message, 0
|
||||||
|
|
||||||
def chat(self, system, history, gen_conf, image=""):
|
def chat(self, system, history, gen_conf, image=""):
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
|
||||||
from dashscope import MultiModalConversation
|
from dashscope import MultiModalConversation
|
||||||
if system:
|
if system:
|
||||||
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
||||||
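DashScope's multimodal API is fed local `file://` URLs rather than inline base64, hence QWenCV's temp-file detour (which the diff itself flags with a wry comment). The message it ends up sending looks roughly like this (path illustrative):

```python
messages = [
    {
        "role": "user",
        "content": [
            {"image": "file:///path/to/ragflow/tmp/<uuid>.jpg"},  # illustrative temp path
            {"text": "Transcribe this page into Markdown."},
        ],
    }
]
```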
@ -254,6 +332,7 @@ class QWenCV(Base):
|
|||||||
|
|
||||||
def chat_streamly(self, system, history, gen_conf, image=""):
|
def chat_streamly(self, system, history, gen_conf, image=""):
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
|
||||||
from dashscope import MultiModalConversation
|
from dashscope import MultiModalConversation
|
||||||
if system:
|
if system:
|
||||||
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
||||||
@ -292,7 +371,7 @@ class Zhipu4V(Base):
|
|||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.lang = lang
|
self.lang = lang
|
||||||
|
|
||||||
def describe(self, image, max_tokens=1024):
|
def describe(self, image):
|
||||||
b64 = self.image2base64(image)
|
b64 = self.image2base64(image)
|
||||||
|
|
||||||
prompt = self.prompt(b64)
|
prompt = self.prompt(b64)
|
||||||
@ -300,7 +379,17 @@ class Zhipu4V(Base):
|
|||||||
|
|
||||||
res = self.client.chat.completions.create(
|
res = self.client.chat.completions.create(
|
||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
messages=prompt
|
messages=prompt,
|
||||||
|
)
|
||||||
|
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
||||||
|
|
||||||
|
def describe_with_prompt(self, image, prompt=None):
|
||||||
|
b64 = self.image2base64(image)
|
||||||
|
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
|
||||||
|
|
||||||
|
res = self.client.chat.completions.create(
|
||||||
|
model=self.model_name,
|
||||||
|
messages=vision_prompt
|
||||||
)
|
)
|
||||||
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
||||||
|
|
||||||
```diff
@@ -364,7 +453,7 @@ class OllamaCV(Base):
         self.model_name = model_name
         self.lang = lang
 
-    def describe(self, image, max_tokens=1024):
+    def describe(self, image):
         prompt = self.prompt("")
         try:
             response = self.client.generate(
@@ -377,6 +466,19 @@ class OllamaCV(Base):
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
+    def describe_with_prompt(self, image, prompt=None):
+        vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("")
+        try:
+            response = self.client.generate(
+                model=self.model_name,
+                prompt=vision_prompt[0]["content"][1]["text"],
+                images=[image],
+            )
+            ans = response["response"].strip()
+            return ans, 128
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
     def chat(self, system, history, gen_conf, image=""):
         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
```
```diff
@@ -460,7 +562,7 @@ class XinferenceCV(Base):
         self.model_name = model_name
         self.lang = lang
 
-    def describe(self, image, max_tokens=300):
+    def describe(self, image):
         b64 = self.image2base64(image)
 
         res = self.client.chat.completions.create(
@@ -469,9 +571,20 @@ class XinferenceCV(Base):
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
 
+    def describe_with_prompt(self, image, prompt=None):
+        b64 = self.image2base64(image)
+        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+
+        res = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=vision_prompt,
+        )
+        return res.choices[0].message.content.strip(), res.usage.total_tokens
+
+
 class GeminiCV(Base):
     def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
-        from google.generativeai import client, GenerativeModel
+        from google.generativeai import GenerativeModel, client
         client.configure(api_key=key)
         _client = client.get_default_generative_client()
         self.model_name = model_name
@@ -479,17 +592,28 @@ class GeminiCV(Base):
         self.model._client = _client
         self.lang = lang
 
-    def describe(self, image, max_tokens=2048):
+    def describe(self, image):
         from PIL.Image import open
         prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
             "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
         b64 = self.image2base64(image)
         img = open(BytesIO(base64.b64decode(b64)))
-        input = [prompt,img]
+        input = [prompt, img]
         res = self.model.generate_content(
             input
         )
-        return res.text,res.usage_metadata.total_token_count
+        return res.text, res.usage_metadata.total_token_count
+
+    def describe_with_prompt(self, image, prompt=None):
+        from PIL.Image import open
+        b64 = self.image2base64(image)
+        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+        img = open(BytesIO(base64.b64decode(b64)))
+        input = [vision_prompt, img]
+        res = self.model.generate_content(
+            input,
+        )
+        return res.text, res.usage_metadata.total_token_count
 
     def chat(self, system, history, gen_conf, image=""):
         from transformers import GenerationConfig
```
```diff
@@ -566,7 +690,7 @@ class LocalCV(Base):
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
         pass
 
-    def describe(self, image, max_tokens=1024):
+    def describe(self, image):
         return "", 0
 
 
@@ -590,7 +714,7 @@ class NvidiaCV(Base):
         )
         self.key = key
 
-    def describe(self, image, max_tokens=1024):
+    def describe(self, image):
         b64 = self.image2base64(image)
         response = requests.post(
             url=self.base_url,
@@ -609,6 +733,27 @@ class NvidiaCV(Base):
             response["usage"]["total_tokens"],
         )
 
+    def describe_with_prompt(self, image, prompt=None):
+        b64 = self.image2base64(image)
+        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+
+        response = requests.post(
+            url=self.base_url,
+            headers={
+                "accept": "application/json",
+                "content-type": "application/json",
+                "Authorization": f"Bearer {self.key}",
+            },
+            json={
+                "messages": vision_prompt,
+            },
+        )
+        response = response.json()
+        return (
+            response["choices"][0]["message"]["content"].strip(),
+            response["usage"]["total_tokens"],
+        )
+
     def prompt(self, b64):
         return [
             {
@@ -622,6 +767,17 @@ class NvidiaCV(Base):
             }
         ]
 
+    def vision_llm_prompt(self, b64, prompt=None):
+        return [
+            {
+                "role": "user",
+                "content": (
+                    prompt if prompt else vision_llm_describe_prompt()
+                )
+                + f' <img src="data:image/jpeg;base64,{b64}"/>',
+            }
+        ]
+
     def chat_prompt(self, text, b64):
         return [
             {
```
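NvidiaCV overrides `vision_llm_prompt` because, as its other prompt builders show, these endpoints take the image inline as an HTML tag inside the text content rather than a separate `image_url` part. The resulting message, sketched with a dummy payload:

```python
b64 = "aGVsbG8="  # dummy base64, not a real JPEG
messages = [
    {
        "role": "user",
        "content": f'Transcribe this page into Markdown. <img src="data:image/jpeg;base64,{b64}"/>',
    }
]
```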
@ -634,7 +790,7 @@ class NvidiaCV(Base):
|
|||||||
class StepFunCV(GptV4):
|
class StepFunCV(GptV4):
|
||||||
def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
|
def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
|
||||||
if not base_url:
|
if not base_url:
|
||||||
base_url="https://api.stepfun.com/v1"
|
base_url = "https://api.stepfun.com/v1"
|
||||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.lang = lang
|
self.lang = lang
|
||||||
@ -666,18 +822,18 @@ class TogetherAICV(GptV4):
|
|||||||
def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
|
def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
|
||||||
if not base_url:
|
if not base_url:
|
||||||
base_url = "https://api.together.xyz/v1"
|
base_url = "https://api.together.xyz/v1"
|
||||||
super().__init__(key, model_name,lang,base_url)
|
super().__init__(key, model_name, lang, base_url)
|
||||||
|
|
||||||
|
|
||||||
class YiCV(GptV4):
|
class YiCV(GptV4):
|
||||||
def __init__(self, key, model_name, lang="Chinese",base_url="https://api.lingyiwanwu.com/v1",):
|
def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1",):
|
||||||
if not base_url:
|
if not base_url:
|
||||||
base_url = "https://api.lingyiwanwu.com/v1"
|
base_url = "https://api.lingyiwanwu.com/v1"
|
||||||
super().__init__(key, model_name,lang,base_url)
|
super().__init__(key, model_name, lang, base_url)
|
||||||
|
|
||||||
|
|
||||||
class HunyuanCV(Base):
|
class HunyuanCV(Base):
|
||||||
def __init__(self, key, model_name, lang="Chinese",base_url=None):
|
def __init__(self, key, model_name, lang="Chinese", base_url=None):
|
||||||
from tencentcloud.common import credential
|
from tencentcloud.common import credential
|
||||||
from tencentcloud.hunyuan.v20230901 import hunyuan_client
|
from tencentcloud.hunyuan.v20230901 import hunyuan_client
|
||||||
|
|
||||||
@ -689,11 +845,11 @@ class HunyuanCV(Base):
|
|||||||
self.client = hunyuan_client.HunyuanClient(cred, "")
|
self.client = hunyuan_client.HunyuanClient(cred, "")
|
||||||
self.lang = lang
|
self.lang = lang
|
||||||
|
|
||||||
def describe(self, image, max_tokens=4096):
|
def describe(self, image):
|
||||||
from tencentcloud.hunyuan.v20230901 import models
|
|
||||||
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
|
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
|
||||||
TencentCloudSDKException,
|
TencentCloudSDKException,
|
||||||
)
|
)
|
||||||
|
from tencentcloud.hunyuan.v20230901 import models
|
||||||
|
|
||||||
b64 = self.image2base64(image)
|
b64 = self.image2base64(image)
|
||||||
req = models.ChatCompletionsRequest()
|
req = models.ChatCompletionsRequest()
|
||||||
@ -707,6 +863,23 @@ class HunyuanCV(Base):
|
|||||||
except TencentCloudSDKException as e:
|
except TencentCloudSDKException as e:
|
||||||
return ans + "\n**ERROR**: " + str(e), 0
|
return ans + "\n**ERROR**: " + str(e), 0
|
||||||
|
|
||||||
|
def describe_with_prompt(self, image, prompt=None):
|
||||||
|
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
|
||||||
|
from tencentcloud.hunyuan.v20230901 import models
|
||||||
|
|
||||||
|
b64 = self.image2base64(image)
|
||||||
|
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
|
||||||
|
req = models.ChatCompletionsRequest()
|
||||||
|
params = {"Model": self.model_name, "Messages": vision_prompt}
|
||||||
|
req.from_json_string(json.dumps(params))
|
||||||
|
ans = ""
|
||||||
|
try:
|
||||||
|
response = self.client.ChatCompletions(req)
|
||||||
|
ans = response.Choices[0].Message.Content
|
||||||
|
return ans, response.Usage.TotalTokens
|
||||||
|
except TencentCloudSDKException as e:
|
||||||
|
return ans + "\n**ERROR**: " + str(e), 0
|
||||||
|
|
||||||
def prompt(self, b64):
|
def prompt(self, b64):
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
```diff
--- a/rag/prompts.py
+++ b/rag/prompts.py
@@ -18,13 +18,13 @@ import json
 import logging
 import re
 from collections import defaultdict
 
 import json_repair
 
 from api import settings
 from api.db import LLMType
-from api.db.services.document_service import DocumentService
-from api.db.services.llm_service import TenantLLMService, LLMBundle
 from rag.settings import TAG_FLD
-from rag.utils import num_tokens_from_string, encoder
+from rag.utils import encoder, num_tokens_from_string
 
 
 def chunks_format(reference):
@@ -44,6 +44,8 @@ def chunks_format(reference):
 
 
 def llm_id2llm_type(llm_id):
+    from api.db.services.llm_service import TenantLLMService
+
    llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
 
    llm_factories = settings.FACTORY_LLM_INFOS
@@ -92,6 +94,8 @@ def message_fit_in(msg, max_length=4000):
 
 
 def kb_prompt(kbinfos, max_tokens):
+    from api.db.services.document_service import DocumentService
+
     knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
     used_token_count = 0
     chunks_num = 0
@@ -223,6 +227,8 @@ Requirements:
 
 
 def full_question(tenant_id, llm_id, messages, language=None):
+    from api.db.services.llm_service import LLMBundle
+
     if llm_id2llm_type(llm_id) == "image2text":
         chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
     else:
@@ -358,3 +364,32 @@ Output:
     except Exception as e:
         logging.exception(f"JSON parsing error: {result} -> {e}")
         raise e
+
+
+def vision_llm_describe_prompt(page=None) -> str:
+    prompt_en = """
+INSTRUCTION:
+Transcribe the content from the provided PDF page image into clean Markdown format.
+- Only output the content transcribed from the image.
+- Do NOT output this instruction or any other explanation.
+- If the content is missing or you do not understand the input, return an empty string.
+
+RULES:
+1. Do NOT generate examples, demonstrations, or templates.
+2. Do NOT output any extra text such as 'Example', 'Example Output', or similar.
+3. Do NOT generate any tables, headings, or content that is not explicitly present in the image.
+4. Transcribe content word-for-word. Do NOT modify, translate, or omit any content.
+5. Do NOT explain Markdown or mention that you are using Markdown.
+6. Do NOT wrap the output in ```markdown or ``` blocks.
+7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
+8. Preserve the original language, information, and order exactly as shown in the image.
+"""
+
+    if page is not None:
+        prompt_en += f"\nAt the end of the transcription, add the page divider: `--- Page {page} ---`."
+
+    prompt_en += """
+FAILURE HANDLING:
+ - If you do not detect valid content in the image, return an empty string.
+"""
+    return prompt_en
```
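`vision_llm_describe_prompt` is the default transcription instruction; passing `page` only appends a per-page divider to the output contract, which is what lets `VisionParser` keep page boundaries in the merged text. (The `DocumentService`/`LLMBundle` imports moved into function bodies above look like the usual fix for an import cycle, presumably because `rag.prompts` is now imported from lower-level modules.) A quick check, assuming a ragflow checkout:

```python
from rag.prompts import vision_llm_describe_prompt

p = vision_llm_describe_prompt(page=3)
assert "--- Page 3 ---" in p
assert "FAILURE HANDLING" in p
```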
```diff
--- a/rag/utils/__init__.py
+++ b/rag/utils/__init__.py
@@ -16,7 +16,9 @@
 
 import os
 import re
+
 import tiktoken
+
 from api.utils.file_utils import get_project_base_directory
 
 
@@ -92,6 +94,12 @@ def truncate(string: str, max_len: int) -> str:
     return encoder.decode(encoder.encode(string)[:max_len])
 
 
+def clean_markdown_block(text):
+    text = re.sub(r'^\s*```markdown\s*\n?', '', text)
+    text = re.sub(r'\n?\s*```\s*$', '', text)
+    return text.strip()
+
+
 def get_float(v: str | None):
     if v is None:
         return float('-inf')
```
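`clean_markdown_block` guards against rule 6 of the prompt being ignored: VLMs often wrap their output in a ```markdown fence anyway. A standalone check (the function body is copied from the diff):

```python
import re


def clean_markdown_block(text):
    text = re.sub(r'^\s*```markdown\s*\n?', '', text)
    text = re.sub(r'\n?\s*```\s*$', '', text)
    return text.strip()


assert clean_markdown_block("```markdown\n# Title\n```") == "# Title"
assert clean_markdown_block("plain text") == "plain text"
```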
```diff
@@ -99,3 +107,4 @@ def get_float(v: str | None):
         return float(v)
     except Exception:
         return float('-inf')
+
```