Feat: add vision LLM PDF parser (#6173)
### What problem does this PR solve?

Add vision LLM PDF parser.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
parent
897fe85b5c
commit
5cf610af40
@@ -15,13 +15,12 @@
 #
 import logging
 
-from api.db.services.user_service import TenantService
-from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
 from api import settings
 from api.db import LLMType
-from api.db.db_models import DB
-from api.db.db_models import LLMFactories, LLM, TenantLLM
+from api.db.db_models import DB, LLM, LLMFactories, TenantLLM
 from api.db.services.common_service import CommonService
+from api.db.services.user_service import TenantService
+from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel
 
 
 class LLMFactoriesService(CommonService):
@@ -266,6 +265,14 @@ class LLMBundle:
                 "LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
         return txt
 
+    def describe_with_prompt(self, image, prompt):
+        txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
+        if not TenantLLMService.increase_usage(
+                self.tenant_id, self.llm_type, used_tokens):
+            logging.error(
+                "LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
+        return txt
+
     def transcription(self, audio):
         txt, used_tokens = self.mdl.transcription(audio)
         if not TenantLLMService.increase_usage(
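For orientation, a minimal sketch of how the new `LLMBundle.describe_with_prompt` entry point might be called from application code; the tenant id, model name and image file below are placeholders, not values from this commit:

```python
# Hypothetical caller of the new LLMBundle.describe_with_prompt wrapper.
# "tenant-123", "gpt-4o" and page.jpg are placeholders for illustration only.
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from rag.prompts import vision_llm_describe_prompt

vision_model = LLMBundle("tenant-123", LLMType.IMAGE2TEXT, llm_name="gpt-4o", lang="English")
with open("page.jpg", "rb") as f:
    markdown_text = vision_model.describe_with_prompt(f.read(), vision_llm_describe_prompt(page=1))
print(markdown_text)  # token usage is recorded for the tenant via TenantLLMService.increase_usage
```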
@@ -17,26 +17,27 @@
 import logging
 import os
 import random
-from timeit import default_timer as timer
+import re
 import sys
 import threading
-import trio
-
-import xgboost as xgb
+from copy import deepcopy
 from io import BytesIO
-import re
-import pdfplumber
-from PIL import Image
+from timeit import default_timer as timer
 
 import numpy as np
+import pdfplumber
+import trio
+import xgboost as xgb
+from huggingface_hub import snapshot_download
+from PIL import Image
 from pypdf import PdfReader as pdf2_read
 
 from api import settings
 from api.utils.file_utils import get_project_base_directory
-from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
+from deepdoc.vision import OCR, LayoutRecognizer, Recognizer, TableStructureRecognizer
+from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
 from rag.nlp import rag_tokenizer
-from copy import deepcopy
-from huggingface_hub import snapshot_download
-
+from rag.prompts import vision_llm_describe_prompt
 from rag.settings import PARALLEL_DEVICES
 
 LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
@@ -45,7 +46,7 @@ if LOCK_KEY_pdfplumber not in sys.modules:
 
 
 class RAGFlowPdfParser:
-    def __init__(self):
+    def __init__(self, **kwargs):
         """
         If you have trouble downloading HuggingFace models, -_^ this might help!!
@@ -1240,5 +1241,52 @@ class PlainParser:
         raise NotImplementedError
 
 
+class VisionParser(RAGFlowPdfParser):
+    def __init__(self, vision_model, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.vision_model = vision_model
+
+    def __images__(self, fnm, zoomin=3, page_from=0, page_to=299, callback=None):
+        try:
+            with sys.modules[LOCK_KEY_pdfplumber]:
+                self.pdf = pdfplumber.open(fnm) if isinstance(
+                    fnm, str) else pdfplumber.open(BytesIO(fnm))
+                self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
+                                    enumerate(self.pdf.pages[page_from:page_to])]
+                self.total_page = len(self.pdf.pages)
+        except Exception:
+            self.page_images = None
+            self.total_page = 0
+            logging.exception("VisionParser __images__")
+
+    def __call__(self, filename, from_page=0, to_page=100000, **kwargs):
+        callback = kwargs.get("callback", lambda prog, msg: None)
+
+        self.__images__(fnm=filename, zoomin=3, page_from=from_page, page_to=to_page, **kwargs)
+
+        total_pdf_pages = self.total_page
+
+        start_page = max(0, from_page)
+        end_page = min(to_page, total_pdf_pages)
+
+        all_docs = []
+
+        for idx, img_binary in enumerate(self.page_images or []):
+            pdf_page_num = idx  # 0-based
+            if pdf_page_num < start_page or pdf_page_num >= end_page:
+                continue
+
+            docs = picture_vision_llm_chunk(
+                binary=img_binary,
+                vision_model=self.vision_model,
+                prompt=vision_llm_describe_prompt(page=pdf_page_num + 1),
+                callback=callback,
+            )
+
+            if docs:
+                all_docs.append(docs)
+        return [(doc, "") for doc in all_docs], []
+
+
 if __name__ == "__main__":
     pass
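A short, hedged sketch of driving the new VisionParser directly (outside the chunking pipeline); the tenant id, model name and sample.pdf path are assumptions for illustration:

```python
# Standalone use of VisionParser; tenant id, model name and sample.pdf are placeholders.
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser.pdf_parser import VisionParser

vision_model = LLMBundle("tenant-123", LLMType.IMAGE2TEXT, llm_name="gpt-4o", lang="English")
parser = VisionParser(vision_model=vision_model)
# Each selected page image is sent to the vision LLM and comes back as Markdown text.
sections, tables = parser("sample.pdf", from_page=0, to_page=3,
                          callback=lambda prog, msg: print(prog, msg))
for text, _ in sections:
    print(text[:200])
```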
@@ -26,8 +26,10 @@ from markdown import markdown
 from PIL import Image
 from tika import parser
 
+from api.db import LLMType
+from api.db.services.llm_service import LLMBundle
 from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownParser, PdfParser, TxtParser
-from deepdoc.parser.pdf_parser import PlainParser
+from deepdoc.parser.pdf_parser import PlainParser, VisionParser
 from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_docx, tokenize_table
 from rag.utils import num_tokens_from_string
@@ -237,9 +239,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         return res
 
     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
+        layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
+
+        if layout_recognizer == "DeepDOC":
             pdf_parser = Pdf()
-        if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
+        elif layout_recognizer == "Plain Text":
             pdf_parser = PlainParser()
+        else:
+            vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=layout_recognizer, lang=lang)
+            pdf_parser = VisionParser(vision_model=vision_model, **kwargs)
+
         sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
                                       callback=callback)
         res = tokenize_table(tables, doc, is_english)
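To make the new selection logic concrete, a small sketch of the `parser_config` values and the parser each one would now select; the vision model name is an illustrative assumption:

```python
# How parser_config["layout_recognize"] maps to a PDF parser after this change.
# "qwen-vl-max" stands in for any IMAGE2TEXT model name configured for the tenant.
parser_config = {"layout_recognize": "DeepDOC"}      # -> Pdf(): DeepDOC OCR/layout pipeline
parser_config = {"layout_recognize": "Plain Text"}   # -> PlainParser(): raw text extraction
parser_config = {"layout_recognize": "qwen-vl-max"}  # -> VisionParser() backed by that vision LLM
```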
@@ -21,8 +21,9 @@ from PIL import Image
 
 from api.db import LLMType
 from api.db.services.llm_service import LLMBundle
-from rag.nlp import tokenize
 from deepdoc.vision import OCR
+from rag.nlp import tokenize
+from rag.utils import clean_markdown_block
 
 ocr = OCR()
@@ -57,3 +58,32 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
         callback(prog=-1, msg=str(e))
 
     return []
+
+
+def vision_llm_chunk(binary, vision_model, prompt=None, callback=None):
+    """
+    A simple wrapper to process image to markdown texts via VLM.
+
+    Returns:
+        Simple markdown texts generated by VLM.
+    """
+    callback = callback or (lambda prog, msg: None)
+
+    img = binary
+    txt = ""
+
+    try:
+        img_binary = io.BytesIO()
+        img.save(img_binary, format='JPEG')
+        img_binary.seek(0)
+
+        ans = clean_markdown_block(vision_model.describe_with_prompt(img_binary.read(), prompt))
+
+        txt += "\n" + ans
+
+        return txt
+
+    except Exception as e:
+        callback(-1, str(e))
+
+        return []
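A hedged sketch of feeding a single rendered page through `vision_llm_chunk`; sample.pdf, the tenant id and the model name are placeholders:

```python
# One-page driver for vision_llm_chunk; identifiers and file names are placeholders.
import pdfplumber
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from rag.app.picture import vision_llm_chunk
from rag.prompts import vision_llm_describe_prompt

vision_model = LLMBundle("tenant-123", LLMType.IMAGE2TEXT, llm_name="gpt-4o")
with pdfplumber.open("sample.pdf") as pdf:
    page_img = pdf.pages[0].to_image(resolution=72 * 3).annotated  # same 3x zoom as VisionParser
    md = vision_llm_chunk(binary=page_img, vision_model=vision_model,
                          prompt=vision_llm_describe_prompt(page=1))
print(md)
```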
@@ -13,29 +13,34 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from openai.lib.azure import AzureOpenAI
-from zhipuai import ZhipuAI
-import io
-from abc import ABC
-from ollama import Client
-from PIL import Image
-from openai import OpenAI
-import os
 import base64
-from io import BytesIO
+import io
 import json
-import requests
+import os
+from abc import ABC
+from io import BytesIO
+
+import requests
+from ollama import Client
+from openai import OpenAI
+from openai.lib.azure import AzureOpenAI
+from PIL import Image
+from zhipuai import ZhipuAI
 
-from rag.nlp import is_english
 from api.utils import get_uuid
 from api.utils.file_utils import get_project_base_directory
+from rag.nlp import is_english
+from rag.prompts import vision_llm_describe_prompt
 
 
 class Base(ABC):
     def __init__(self, key, model_name):
         pass
 
-    def describe(self, image, max_tokens=300):
+    def describe(self, image):
         raise NotImplementedError("Please implement encode method!")
 
+    def describe_with_prompt(self, image, prompt=None):
+        raise NotImplementedError("Please implement encode method!")
+
     def chat(self, system, history, gen_conf, image=""):
@@ -122,6 +127,25 @@ class Base(ABC):
             }
         ]
 
+    def vision_llm_prompt(self, b64, prompt=None):
+        return [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{b64}"
+                        },
+                    },
+                    {
+                        "type": "text",
+                        "text": prompt if prompt else vision_llm_describe_prompt(),
+                    },
+                ],
+            }
+        ]
+
     def chat_prompt(self, text, b64):
         return [
             {
@@ -140,12 +164,12 @@ class Base(ABC):
 class GptV4(Base):
     def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
         if not base_url:
-            base_url="https://api.openai.com/v1"
+            base_url = "https://api.openai.com/v1"
         self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
         self.lang = lang
 
-    def describe(self, image, max_tokens=300):
+    def describe(self, image):
         b64 = self.image2base64(image)
         prompt = self.prompt(b64)
         for i in range(len(prompt)):
@@ -159,6 +183,16 @@ class GptV4(Base):
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
 
+    def describe_with_prompt(self, image, prompt=None):
+        b64 = self.image2base64(image)
+        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+
+        res = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=vision_prompt,
+        )
+        return res.choices[0].message.content.strip(), res.usage.total_tokens
+
 
 class AzureGptV4(Base):
     def __init__(self, key, model_name, lang="Chinese", **kwargs):
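At the provider level the flow is the same for every backend that gained `describe_with_prompt`; a minimal sketch using the GptV4 wrapper, assuming it is importable from the CV model module and using placeholder credentials and files:

```python
# Direct provider-level call; the import path, API key, model name and image are assumptions.
from rag.llm.cv_model import GptV4
from rag.prompts import vision_llm_describe_prompt

cv_mdl = GptV4(key="sk-...", model_name="gpt-4o-mini", lang="English")
with open("page-2.jpg", "rb") as f:
    text, used_tokens = cv_mdl.describe_with_prompt(f.read(), vision_llm_describe_prompt(page=2))
```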
@@ -168,7 +202,7 @@ class AzureGptV4(Base):
         self.model_name = model_name
         self.lang = lang
 
-    def describe(self, image, max_tokens=300):
+    def describe(self, image):
         b64 = self.image2base64(image)
         prompt = self.prompt(b64)
         for i in range(len(prompt)):
@@ -182,6 +216,16 @@ class AzureGptV4(Base):
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
 
+    def describe_with_prompt(self, image, prompt=None):
+        b64 = self.image2base64(image)
+        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+
+        res = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=vision_prompt,
+        )
+        return res.choices[0].message.content.strip(), res.usage.total_tokens
+
 
 class QWenCV(Base):
     def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
@@ -212,23 +256,57 @@ class QWenCV(Base):
             }
         ]
 
+    def vision_llm_prompt(self, binary, prompt=None):
+        # stupid as hell
+        tmp_dir = get_project_base_directory("tmp")
+        if not os.path.exists(tmp_dir):
+            os.mkdir(tmp_dir)
+        path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
+        Image.open(io.BytesIO(binary)).save(path)
+
+        return [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "image": f"file://{path}"
+                    },
+                    {
+                        "text": prompt if prompt else vision_llm_describe_prompt(),
+                    },
+                ],
+            }
+        ]
+
     def chat_prompt(self, text, b64):
         return [
             {"image": f"{b64}"},
             {"text": text},
         ]
 
-    def describe(self, image, max_tokens=300):
+    def describe(self, image):
         from http import HTTPStatus
+
         from dashscope import MultiModalConversation
-        response = MultiModalConversation.call(model=self.model_name,
-                                               messages=self.prompt(image))
+
+        response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image))
         if response.status_code == HTTPStatus.OK:
             return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
         return response.message, 0
 
+    def describe_with_prompt(self, image, prompt=None):
+        from http import HTTPStatus
+
+        from dashscope import MultiModalConversation
+
+        vision_prompt = self.vision_llm_prompt(image, prompt) if prompt else self.vision_llm_prompt(image)
+        response = MultiModalConversation.call(model=self.model_name, messages=vision_prompt)
+        if response.status_code == HTTPStatus.OK:
+            return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
+        return response.message, 0
+
     def chat(self, system, history, gen_conf, image=""):
         from http import HTTPStatus
+
         from dashscope import MultiModalConversation
         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@@ -254,6 +332,7 @@ class QWenCV(Base):
 
     def chat_streamly(self, system, history, gen_conf, image=""):
         from http import HTTPStatus
+
         from dashscope import MultiModalConversation
         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@@ -292,7 +371,7 @@ class Zhipu4V(Base):
         self.model_name = model_name
         self.lang = lang
 
-    def describe(self, image, max_tokens=1024):
+    def describe(self, image):
         b64 = self.image2base64(image)
 
         prompt = self.prompt(b64)
@@ -300,7 +379,17 @@ class Zhipu4V(Base):
 
         res = self.client.chat.completions.create(
             model=self.model_name,
-            messages=prompt
+            messages=prompt,
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
 
+    def describe_with_prompt(self, image, prompt=None):
+        b64 = self.image2base64(image)
+        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+
+        res = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=vision_prompt
+        )
+        return res.choices[0].message.content.strip(), res.usage.total_tokens
+
@@ -364,7 +453,7 @@ class OllamaCV(Base):
         self.model_name = model_name
         self.lang = lang
 
-    def describe(self, image, max_tokens=1024):
+    def describe(self, image):
         prompt = self.prompt("")
         try:
             response = self.client.generate(
@@ -377,6 +466,19 @@ class OllamaCV(Base):
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
+    def describe_with_prompt(self, image, prompt=None):
+        vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("")
+        try:
+            response = self.client.generate(
+                model=self.model_name,
+                prompt=vision_prompt[0]["content"][1]["text"],
+                images=[image],
+            )
+            ans = response["response"].strip()
+            return ans, 128
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
     def chat(self, system, history, gen_conf, image=""):
         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@@ -460,7 +562,7 @@ class XinferenceCV(Base):
         self.model_name = model_name
         self.lang = lang
 
-    def describe(self, image, max_tokens=300):
+    def describe(self, image):
         b64 = self.image2base64(image)
 
         res = self.client.chat.completions.create(
@@ -469,9 +571,20 @@ class XinferenceCV(Base):
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
 
+    def describe_with_prompt(self, image, prompt=None):
+        b64 = self.image2base64(image)
+        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+
+        res = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=vision_prompt,
+        )
+        return res.choices[0].message.content.strip(), res.usage.total_tokens
+
 
 class GeminiCV(Base):
     def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
-        from google.generativeai import client, GenerativeModel
+        from google.generativeai import GenerativeModel, client
         client.configure(api_key=key)
         _client = client.get_default_generative_client()
         self.model_name = model_name
@@ -479,17 +592,28 @@ class GeminiCV(Base):
         self.model._client = _client
         self.lang = lang
 
-    def describe(self, image, max_tokens=2048):
+    def describe(self, image):
         from PIL.Image import open
         prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
             "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
         b64 = self.image2base64(image)
         img = open(BytesIO(base64.b64decode(b64)))
-        input = [prompt,img]
+        input = [prompt, img]
         res = self.model.generate_content(
             input
         )
-        return res.text,res.usage_metadata.total_token_count
+        return res.text, res.usage_metadata.total_token_count
+
+    def describe_with_prompt(self, image, prompt=None):
+        from PIL.Image import open
+        b64 = self.image2base64(image)
+        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+        img = open(BytesIO(base64.b64decode(b64)))
+        input = [vision_prompt, img]
+        res = self.model.generate_content(
+            input,
+        )
+        return res.text, res.usage_metadata.total_token_count
 
     def chat(self, system, history, gen_conf, image=""):
         from transformers import GenerationConfig
@@ -566,7 +690,7 @@ class LocalCV(Base):
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
         pass
 
-    def describe(self, image, max_tokens=1024):
+    def describe(self, image):
         return "", 0
 
 
@@ -590,7 +714,7 @@ class NvidiaCV(Base):
         )
         self.key = key
 
-    def describe(self, image, max_tokens=1024):
+    def describe(self, image):
         b64 = self.image2base64(image)
         response = requests.post(
             url=self.base_url,
@@ -609,6 +733,27 @@ class NvidiaCV(Base):
             response["usage"]["total_tokens"],
         )
 
+    def describe_with_prompt(self, image, prompt=None):
+        b64 = self.image2base64(image)
+        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+
+        response = requests.post(
+            url=self.base_url,
+            headers={
+                "accept": "application/json",
+                "content-type": "application/json",
+                "Authorization": f"Bearer {self.key}",
+            },
+            json={
+                "messages": vision_prompt,
+            },
+        )
+        response = response.json()
+        return (
+            response["choices"][0]["message"]["content"].strip(),
+            response["usage"]["total_tokens"],
+        )
+
     def prompt(self, b64):
         return [
             {
@@ -622,6 +767,17 @@ class NvidiaCV(Base):
             }
         ]
 
+    def vision_llm_prompt(self, b64, prompt=None):
+        return [
+            {
+                "role": "user",
+                "content": (
+                    prompt if prompt else vision_llm_describe_prompt()
+                )
+                + f' <img src="data:image/jpeg;base64,{b64}"/>',
+            }
+        ]
+
     def chat_prompt(self, text, b64):
         return [
             {
@@ -634,7 +790,7 @@ class NvidiaCV(Base):
 class StepFunCV(GptV4):
     def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
         if not base_url:
-            base_url="https://api.stepfun.com/v1"
+            base_url = "https://api.stepfun.com/v1"
         self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
         self.lang = lang
@@ -666,18 +822,18 @@ class TogetherAICV(GptV4):
     def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
         if not base_url:
             base_url = "https://api.together.xyz/v1"
-        super().__init__(key, model_name,lang,base_url)
+        super().__init__(key, model_name, lang, base_url)
 
 
 class YiCV(GptV4):
-    def __init__(self, key, model_name, lang="Chinese",base_url="https://api.lingyiwanwu.com/v1",):
+    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1",):
         if not base_url:
             base_url = "https://api.lingyiwanwu.com/v1"
-        super().__init__(key, model_name,lang,base_url)
+        super().__init__(key, model_name, lang, base_url)
 
 
 class HunyuanCV(Base):
-    def __init__(self, key, model_name, lang="Chinese",base_url=None):
+    def __init__(self, key, model_name, lang="Chinese", base_url=None):
         from tencentcloud.common import credential
         from tencentcloud.hunyuan.v20230901 import hunyuan_client
 
@@ -689,11 +845,11 @@ class HunyuanCV(Base):
         self.client = hunyuan_client.HunyuanClient(cred, "")
         self.lang = lang
 
-    def describe(self, image, max_tokens=4096):
-        from tencentcloud.hunyuan.v20230901 import models
+    def describe(self, image):
+        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
+            TencentCloudSDKException,
+        )
+        from tencentcloud.hunyuan.v20230901 import models
 
         b64 = self.image2base64(image)
         req = models.ChatCompletionsRequest()
@@ -707,6 +863,23 @@ class HunyuanCV(Base):
         except TencentCloudSDKException as e:
             return ans + "\n**ERROR**: " + str(e), 0
 
+    def describe_with_prompt(self, image, prompt=None):
+        from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
+        from tencentcloud.hunyuan.v20230901 import models
+
+        b64 = self.image2base64(image)
+        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+        req = models.ChatCompletionsRequest()
+        params = {"Model": self.model_name, "Messages": vision_prompt}
+        req.from_json_string(json.dumps(params))
+        ans = ""
+        try:
+            response = self.client.ChatCompletions(req)
+            ans = response.Choices[0].Message.Content
+            return ans, response.Usage.TotalTokens
+        except TencentCloudSDKException as e:
+            return ans + "\n**ERROR**: " + str(e), 0
+
     def prompt(self, b64):
         return [
             {
@@ -18,13 +18,13 @@ import json
 import logging
 import re
 from collections import defaultdict
 
 import json_repair
 
 from api import settings
 from api.db import LLMType
 from api.db.services.document_service import DocumentService
 from api.db.services.llm_service import TenantLLMService, LLMBundle
 from rag.settings import TAG_FLD
-from rag.utils import num_tokens_from_string, encoder
+from rag.utils import encoder, num_tokens_from_string
 
 
 def chunks_format(reference):
@@ -44,6 +44,8 @@ def chunks_format(reference):
 
 
 def llm_id2llm_type(llm_id):
+    from api.db.services.llm_service import TenantLLMService
+
     llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
 
     llm_factories = settings.FACTORY_LLM_INFOS
@@ -92,6 +94,8 @@ def message_fit_in(msg, max_length=4000):
 
 
 def kb_prompt(kbinfos, max_tokens):
+    from api.db.services.document_service import DocumentService
+
     knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
     used_token_count = 0
     chunks_num = 0
@@ -223,6 +227,8 @@ Requirements:
 
 
 def full_question(tenant_id, llm_id, messages, language=None):
+    from api.db.services.llm_service import LLMBundle
+
     if llm_id2llm_type(llm_id) == "image2text":
         chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
     else:
@@ -358,3 +364,32 @@ Output:
     except Exception as e:
         logging.exception(f"JSON parsing error: {result} -> {e}")
         raise e
+
+
+def vision_llm_describe_prompt(page=None) -> str:
+    prompt_en = """
+INSTRUCTION:
+Transcribe the content from the provided PDF page image into clean Markdown format.
+- Only output the content transcribed from the image.
+- Do NOT output this instruction or any other explanation.
+- If the content is missing or you do not understand the input, return an empty string.
+
+RULES:
+1. Do NOT generate examples, demonstrations, or templates.
+2. Do NOT output any extra text such as 'Example', 'Example Output', or similar.
+3. Do NOT generate any tables, headings, or content that is not explicitly present in the image.
+4. Transcribe content word-for-word. Do NOT modify, translate, or omit any content.
+5. Do NOT explain Markdown or mention that you are using Markdown.
+6. Do NOT wrap the output in ```markdown or ``` blocks.
+7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
+8. Preserve the original language, information, and order exactly as shown in the image.
+"""
+
+    if page is not None:
+        prompt_en += f"\nAt the end of the transcription, add the page divider: `--- Page {page} ---`."
+
+    prompt_en += """
+FAILURE HANDLING:
+- If you do not detect valid content in the image, return an empty string.
+"""
+    return prompt_en
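A quick check of the page-divider behaviour of the new prompt helper:

```python
# vision_llm_describe_prompt appends the divider instruction only when a page number is given.
from rag.prompts import vision_llm_describe_prompt

prompt = vision_llm_describe_prompt(page=3)
assert "Transcribe the content from the provided PDF page image" in prompt
assert "--- Page 3 ---" in prompt

assert "--- Page" not in vision_llm_describe_prompt()
```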
@@ -16,7 +16,9 @@
 
 import os
+import re
 
 import tiktoken
+
 from api.utils.file_utils import get_project_base_directory
 
 
@@ -92,6 +94,12 @@ def truncate(string: str, max_len: int) -> str:
     return encoder.decode(encoder.encode(string)[:max_len])
 
 
+def clean_markdown_block(text):
+    text = re.sub(r'^\s*```markdown\s*\n?', '', text)
+    text = re.sub(r'\n?\s*```\s*$', '', text)
+    return text.strip()
+
+
 def get_float(v: str | None):
     if v is None:
         return float('-inf')
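The new `clean_markdown_block` helper strips the ```markdown fence that some vision LLMs wrap around their output; for example:

```python
# clean_markdown_block removes a leading ```markdown fence and a trailing ``` fence, then strips.
from rag.utils import clean_markdown_block

raw = "```markdown\n# Invoice\n\n| Item | Qty |\n|------|-----|\n| Pen | 2 |\n```"
print(clean_markdown_block(raw))
# -> "# Invoice\n\n| Item | Qty |\n|------|-----|\n| Pen | 2 |"
```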
@@ -99,3 +107,4 @@ def get_float(v: str | None):
         return float(v)
     except Exception:
         return float('-inf')
+