diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 910a6b2a2..274627291 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -15,13 +15,12 @@ # import logging -from api.db.services.user_service import TenantService -from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel from api import settings from api.db import LLMType -from api.db.db_models import DB -from api.db.db_models import LLMFactories, LLM, TenantLLM +from api.db.db_models import DB, LLM, LLMFactories, TenantLLM from api.db.services.common_service import CommonService +from api.db.services.user_service import TenantService +from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel class LLMFactoriesService(CommonService): @@ -266,6 +265,14 @@ class LLMBundle: "LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens)) return txt + def describe_with_prompt(self, image, prompt): + txt, used_tokens = self.mdl.describe_with_prompt(image, prompt) + if not TenantLLMService.increase_usage( + self.tenant_id, self.llm_type, used_tokens): + logging.error( + "LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens)) + return txt + def transcription(self, audio): txt, used_tokens = self.mdl.transcription(audio) if not TenantLLMService.increase_usage( diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index 2205d0e61..8c434a6eb 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -17,26 +17,27 @@ import logging import os import random -from timeit import default_timer as timer +import re import sys import threading -import trio - -import xgboost as xgb +from copy import deepcopy from io import BytesIO -import re -import pdfplumber -from PIL import Image +from timeit import default_timer as timer + import numpy as np +import pdfplumber +import trio +import xgboost as xgb +from huggingface_hub import snapshot_download +from PIL import Image from pypdf import PdfReader as pdf2_read from api import settings from api.utils.file_utils import get_project_base_directory -from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer +from deepdoc.vision import OCR, LayoutRecognizer, Recognizer, TableStructureRecognizer +from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk from rag.nlp import rag_tokenizer -from copy import deepcopy -from huggingface_hub import snapshot_download - +from rag.prompts import vision_llm_describe_prompt from rag.settings import PARALLEL_DEVICES LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber" @@ -45,7 +46,7 @@ if LOCK_KEY_pdfplumber not in sys.modules: class RAGFlowPdfParser: - def __init__(self): + def __init__(self, **kwargs): """ If you have trouble downloading HuggingFace models, -_^ this might help!! @@ -57,12 +58,12 @@ class RAGFlowPdfParser: ^_- """ - + self.ocr = OCR() self.parallel_limiter = None if PARALLEL_DEVICES is not None and PARALLEL_DEVICES > 1: self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)] - + if hasattr(self, "model_speciess"): self.layouter = LayoutRecognizer("layout." + self.model_speciess) else: @@ -106,7 +107,7 @@ class RAGFlowPdfParser: def _y_dis( self, a, b): return ( - b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2 + b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2 def _match_proj(self, b): proj_patt = [ @@ -129,9 +130,9 @@ class RAGFlowPdfParser: tks_down = rag_tokenizer.tokenize(down["text"][:LEN]).split() tks_up = rag_tokenizer.tokenize(up["text"][-LEN:]).split() tks_all = up["text"][-LEN:].strip() \ - + (" " if re.match(r"[a-zA-Z0-9]+", - up["text"][-1] + down["text"][0]) else "") \ - + down["text"][:LEN].strip() + + (" " if re.match(r"[a-zA-Z0-9]+", + up["text"][-1] + down["text"][0]) else "") \ + + down["text"][:LEN].strip() tks_all = rag_tokenizer.tokenize(tks_all).split() fea = [ up.get("R", -1) == down.get("R", -1), @@ -153,7 +154,7 @@ class RAGFlowPdfParser: True if re.search(r"[,,][^。.]+$", up["text"]) else False, True if re.search(r"[,,][^。.]+$", up["text"]) else False, True if re.search(r"[\((][^\))]+$", up["text"]) - and re.search(r"[\))]", down["text"]) else False, + and re.search(r"[\))]", down["text"]) else False, self._match_proj(down), True if re.match(r"[A-Z]", down["text"]) else False, True if re.match(r"[A-Z]", up["text"][-1]) else False, @@ -215,7 +216,7 @@ class RAGFlowPdfParser: continue for tb in tbls: # for table left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \ - tb["x1"] + MARGIN, tb["bottom"] + MARGIN + tb["x1"] + MARGIN, tb["bottom"] + MARGIN left *= ZM top *= ZM right *= ZM @@ -309,7 +310,7 @@ class RAGFlowPdfParser: "page_number": pagenum} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]], self.mean_height[-1] / 3 ) - + # merge chars in the same rect for c in Recognizer.sort_Y_firstly( chars, self.mean_height[pagenum - 1] // 4): @@ -457,7 +458,7 @@ class RAGFlowPdfParser: b_["text"], any(feats), any(concatting_feats), - )) + )) i += 1 continue # merge up and down @@ -665,7 +666,7 @@ class RAGFlowPdfParser: i += 1 continue lout_no = str(self.boxes[i]["page_number"]) + \ - "-" + str(self.boxes[i]["layoutno"]) + "-" + str(self.boxes[i]["layoutno"]) if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", "title", "figure caption", @@ -968,7 +969,7 @@ class RAGFlowPdfParser: fnm) if not binary else pdfplumber.open(BytesIO(binary)) total_page = len(pdf.pages) pdf.close() - return total_page + return total_page except Exception: logging.exception("total_page_number") @@ -994,7 +995,7 @@ class RAGFlowPdfParser: except Exception as e: logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}") self.page_chars = [[] for _ in range(page_to - page_from)] # If failed to extract, using empty list instead. - + self.total_page = len(self.pdf.pages) except Exception: logging.exception("RAGFlowPdfParser __images__") @@ -1023,7 +1024,7 @@ class RAGFlowPdfParser: logging.debug("Images converted.") self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join( random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in - range(len(self.page_chars))] + range(len(self.page_chars))] if sum([1 if e else 0 for e in self.is_english]) > len( self.page_images) / 2: self.is_english = True @@ -1036,7 +1037,7 @@ class RAGFlowPdfParser: if chars[j]["text"] and chars[j + 1]["text"] \ and re.match(r"[0-9a-zA-Z,.:;!%]+", chars[j]["text"] + chars[j + 1]["text"]) \ and chars[j + 1]["x0"] - chars[j]["x1"] >= min(chars[j + 1]["width"], - chars[j]["width"]) / 2: + chars[j]["width"]) / 2: chars[j]["text"] += " " j += 1 @@ -1045,7 +1046,7 @@ class RAGFlowPdfParser: await trio.to_thread.run_sync(lambda: self.__ocr(i + 1, img, chars, zoomin, id)) else: self.__ocr(i + 1, img, chars, zoomin, id) - + if callback and i % 6 == 5: callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="") @@ -1060,14 +1061,14 @@ class RAGFlowPdfParser: ) self.page_cum_height.append(img.size[1] / zoomin) return chars - + if self.parallel_limiter: async with trio.open_nursery() as nursery: for i, img in enumerate(self.page_images): chars = __ocr_preprocess() nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars, - self.parallel_limiter[i % PARALLEL_DEVICES]) + self.parallel_limiter[i % PARALLEL_DEVICES]) await trio.sleep(0.1) else: for i, img in enumerate(self.page_images): @@ -1075,9 +1076,9 @@ class RAGFlowPdfParser: await __img_ocr(i, 0, img, chars, None) start = timer() - + trio.run(__img_ocr_launcher) - + logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s") if not self.is_english and not any( @@ -1142,7 +1143,7 @@ class RAGFlowPdfParser: self.page_images[pns[0]].crop((left * ZM, top * ZM, right * ZM, min( - bottom, self.page_images[pns[0]].size[1]) + bottom, self.page_images[pns[0]].size[1]) )) ) if 0 < ii < len(poss) - 1: @@ -1240,5 +1241,52 @@ class PlainParser: raise NotImplementedError +class VisionParser(RAGFlowPdfParser): + def __init__(self, vision_model, *args, **kwargs): + super().__init__(*args, **kwargs) + self.vision_model = vision_model + + def __images__(self, fnm, zoomin=3, page_from=0, page_to=299, callback=None): + try: + with sys.modules[LOCK_KEY_pdfplumber]: + self.pdf = pdfplumber.open(fnm) if isinstance( + fnm, str) else pdfplumber.open(BytesIO(fnm)) + self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in + enumerate(self.pdf.pages[page_from:page_to])] + self.total_page = len(self.pdf.pages) + except Exception: + self.page_images = None + self.total_page = 0 + logging.exception("VisionParser __images__") + + def __call__(self, filename, from_page=0, to_page=100000, **kwargs): + callback = kwargs.get("callback", lambda prog, msg: None) + + self.__images__(fnm=filename, zoomin=3, page_from=from_page, page_to=to_page, **kwargs) + + total_pdf_pages = self.total_page + + start_page = max(0, from_page) + end_page = min(to_page, total_pdf_pages) + + all_docs = [] + + for idx, img_binary in enumerate(self.page_images or []): + pdf_page_num = idx # 0-based + if pdf_page_num < start_page or pdf_page_num >= end_page: + continue + + docs = picture_vision_llm_chunk( + binary=img_binary, + vision_model=self.vision_model, + prompt=vision_llm_describe_prompt(page=pdf_page_num+1), + callback=callback, + ) + + if docs: + all_docs.append(docs) + return [(doc, "") for doc in all_docs], [] + + if __name__ == "__main__": pass diff --git a/rag/app/naive.py b/rag/app/naive.py index 133632e50..c9ae6a3fb 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -26,8 +26,10 @@ from markdown import markdown from PIL import Image from tika import parser +from api.db import LLMType +from api.db.services.llm_service import LLMBundle from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownParser, PdfParser, TxtParser -from deepdoc.parser.pdf_parser import PlainParser +from deepdoc.parser.pdf_parser import PlainParser, VisionParser from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_docx, tokenize_table from rag.utils import num_tokens_from_string @@ -237,9 +239,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, return res elif re.search(r"\.pdf$", filename, re.IGNORECASE): - pdf_parser = Pdf() - if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text": + layout_recognizer = parser_config.get("layout_recognize", "DeepDOC") + + if layout_recognizer == "DeepDOC": + pdf_parser = Pdf() + elif layout_recognizer == "Plain Text": pdf_parser = PlainParser() + else: + vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=layout_recognizer, lang=lang) + pdf_parser = VisionParser(vision_model=vision_model, **kwargs) + sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback) res = tokenize_table(tables, doc, is_english) diff --git a/rag/app/picture.py b/rag/app/picture.py index 578da79ff..94e9e2ef3 100644 --- a/rag/app/picture.py +++ b/rag/app/picture.py @@ -21,8 +21,9 @@ from PIL import Image from api.db import LLMType from api.db.services.llm_service import LLMBundle -from rag.nlp import tokenize from deepdoc.vision import OCR +from rag.nlp import tokenize +from rag.utils import clean_markdown_block ocr = OCR() @@ -57,3 +58,32 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs): callback(prog=-1, msg=str(e)) return [] + + +def vision_llm_chunk(binary, vision_model, prompt=None, callback=None): + """ + A simple wrapper to process image to markdown texts via VLM. + + Returns: + Simple markdown texts generated by VLM. + """ + callback = callback or (lambda prog, msg: None) + + img = binary + txt = "" + + try: + img_binary = io.BytesIO() + img.save(img_binary, format='JPEG') + img_binary.seek(0) + + ans = clean_markdown_block(vision_model.describe_with_prompt(img_binary.read(), prompt)) + + txt += "\n" + ans + + return txt + + except Exception as e: + callback(-1, str(e)) + + return [] diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index 9f26d8044..da52e75ba 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -13,31 +13,36 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from openai.lib.azure import AzureOpenAI -from zhipuai import ZhipuAI -import io -from abc import ABC -from ollama import Client -from PIL import Image -from openai import OpenAI -import os import base64 -from io import BytesIO +import io import json -import requests +import os +from abc import ABC +from io import BytesIO + +import requests +from ollama import Client +from openai import OpenAI +from openai.lib.azure import AzureOpenAI +from PIL import Image +from zhipuai import ZhipuAI -from rag.nlp import is_english from api.utils import get_uuid from api.utils.file_utils import get_project_base_directory +from rag.nlp import is_english +from rag.prompts import vision_llm_describe_prompt class Base(ABC): def __init__(self, key, model_name): pass - def describe(self, image, max_tokens=300): + def describe(self, image): raise NotImplementedError("Please implement encode method!") - + + def describe_with_prompt(self, image, prompt=None): + raise NotImplementedError("Please implement encode method!") + def chat(self, system, history, gen_conf, image=""): if system: history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] @@ -90,7 +95,7 @@ class Base(ABC): yield ans + "\n**ERROR**: " + str(e) yield tk_count - + def image2base64(self, image): if isinstance(image, bytes): return base64.b64encode(image).decode("utf-8") @@ -122,6 +127,25 @@ class Base(ABC): } ] + def vision_llm_prompt(self, b64, prompt=None): + return [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{b64}" + }, + }, + { + "type": "text", + "text": prompt if prompt else vision_llm_describe_prompt(), + }, + ], + } + ] + def chat_prompt(self, text, b64): return [ { @@ -140,12 +164,12 @@ class Base(ABC): class GptV4(Base): def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"): if not base_url: - base_url="https://api.openai.com/v1" + base_url = "https://api.openai.com/v1" self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name self.lang = lang - def describe(self, image, max_tokens=300): + def describe(self, image): b64 = self.image2base64(image) prompt = self.prompt(b64) for i in range(len(prompt)): @@ -159,6 +183,16 @@ class GptV4(Base): ) return res.choices[0].message.content.strip(), res.usage.total_tokens + def describe_with_prompt(self, image, prompt=None): + b64 = self.image2base64(image) + vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64) + + res = self.client.chat.completions.create( + model=self.model_name, + messages=vision_prompt, + ) + return res.choices[0].message.content.strip(), res.usage.total_tokens + class AzureGptV4(Base): def __init__(self, key, model_name, lang="Chinese", **kwargs): @@ -168,7 +202,7 @@ class AzureGptV4(Base): self.model_name = model_name self.lang = lang - def describe(self, image, max_tokens=300): + def describe(self, image): b64 = self.image2base64(image) prompt = self.prompt(b64) for i in range(len(prompt)): @@ -182,6 +216,16 @@ class AzureGptV4(Base): ) return res.choices[0].message.content.strip(), res.usage.total_tokens + def describe_with_prompt(self, image, prompt=None): + b64 = self.image2base64(image) + vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64) + + res = self.client.chat.completions.create( + model=self.model_name, + messages=vision_prompt, + ) + return res.choices[0].message.content.strip(), res.usage.total_tokens + class QWenCV(Base): def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs): @@ -212,23 +256,57 @@ class QWenCV(Base): } ] + def vision_llm_prompt(self, binary, prompt=None): + # stupid as hell + tmp_dir = get_project_base_directory("tmp") + if not os.path.exists(tmp_dir): + os.mkdir(tmp_dir) + path = os.path.join(tmp_dir, "%s.jpg" % get_uuid()) + Image.open(io.BytesIO(binary)).save(path) + + return [ + { + "role": "user", + "content": [ + { + "image": f"file://{path}" + }, + { + "text": prompt if prompt else vision_llm_describe_prompt(), + }, + ], + } + ] + def chat_prompt(self, text, b64): return [ {"image": f"{b64}"}, {"text": text}, ] - - def describe(self, image, max_tokens=300): + + def describe(self, image): from http import HTTPStatus + from dashscope import MultiModalConversation - response = MultiModalConversation.call(model=self.model_name, - messages=self.prompt(image)) + response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image)) + if response.status_code == HTTPStatus.OK: + return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens + return response.message, 0 + + def describe_with_prompt(self, image, prompt=None): + from http import HTTPStatus + + from dashscope import MultiModalConversation + + vision_prompt = self.vision_llm_prompt(image, prompt) if prompt else self.vision_llm_prompt(image) + response = MultiModalConversation.call(model=self.model_name, messages=vision_prompt) if response.status_code == HTTPStatus.OK: return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens return response.message, 0 def chat(self, system, history, gen_conf, image=""): from http import HTTPStatus + from dashscope import MultiModalConversation if system: history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] @@ -254,6 +332,7 @@ class QWenCV(Base): def chat_streamly(self, system, history, gen_conf, image=""): from http import HTTPStatus + from dashscope import MultiModalConversation if system: history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] @@ -292,15 +371,25 @@ class Zhipu4V(Base): self.model_name = model_name self.lang = lang - def describe(self, image, max_tokens=1024): + def describe(self, image): b64 = self.image2base64(image) prompt = self.prompt(b64) prompt[0]["content"][1]["type"] = "text" - + res = self.client.chat.completions.create( model=self.model_name, - messages=prompt + messages=prompt, + ) + return res.choices[0].message.content.strip(), res.usage.total_tokens + + def describe_with_prompt(self, image, prompt=None): + b64 = self.image2base64(image) + vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64) + + res = self.client.chat.completions.create( + model=self.model_name, + messages=vision_prompt ) return res.choices[0].message.content.strip(), res.usage.total_tokens @@ -334,7 +423,7 @@ class Zhipu4V(Base): his["content"] = self.chat_prompt(his["content"], image) response = self.client.chat.completions.create( - model=self.model_name, + model=self.model_name, messages=history, temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7), @@ -364,7 +453,7 @@ class OllamaCV(Base): self.model_name = model_name self.lang = lang - def describe(self, image, max_tokens=1024): + def describe(self, image): prompt = self.prompt("") try: response = self.client.generate( @@ -377,6 +466,19 @@ class OllamaCV(Base): except Exception as e: return "**ERROR**: " + str(e), 0 + def describe_with_prompt(self, image, prompt=None): + vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("") + try: + response = self.client.generate( + model=self.model_name, + prompt=vision_prompt[0]["content"][1]["text"], + images=[image], + ) + ans = response["response"].strip() + return ans, 128 + except Exception as e: + return "**ERROR**: " + str(e), 0 + def chat(self, system, history, gen_conf, image=""): if system: history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] @@ -460,7 +562,7 @@ class XinferenceCV(Base): self.model_name = model_name self.lang = lang - def describe(self, image, max_tokens=300): + def describe(self, image): b64 = self.image2base64(image) res = self.client.chat.completions.create( @@ -469,27 +571,49 @@ class XinferenceCV(Base): ) return res.choices[0].message.content.strip(), res.usage.total_tokens + def describe_with_prompt(self, image, prompt=None): + b64 = self.image2base64(image) + vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64) + + res = self.client.chat.completions.create( + model=self.model_name, + messages=vision_prompt, + ) + return res.choices[0].message.content.strip(), res.usage.total_tokens + + class GeminiCV(Base): def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs): - from google.generativeai import client, GenerativeModel + from google.generativeai import GenerativeModel, client client.configure(api_key=key) _client = client.get_default_generative_client() self.model_name = model_name self.model = GenerativeModel(model_name=self.model_name) self.model._client = _client - self.lang = lang + self.lang = lang - def describe(self, image, max_tokens=2048): + def describe(self, image): from PIL.Image import open prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \ "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out." - b64 = self.image2base64(image) - img = open(BytesIO(base64.b64decode(b64))) - input = [prompt,img] + b64 = self.image2base64(image) + img = open(BytesIO(base64.b64decode(b64))) + input = [prompt, img] res = self.model.generate_content( input ) - return res.text,res.usage_metadata.total_token_count + return res.text, res.usage_metadata.total_token_count + + def describe_with_prompt(self, image, prompt=None): + from PIL.Image import open + b64 = self.image2base64(image) + vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64) + img = open(BytesIO(base64.b64decode(b64))) + input = [vision_prompt, img] + res = self.model.generate_content( + input, + ) + return res.text, res.usage_metadata.total_token_count def chat(self, system, history, gen_conf, image=""): from transformers import GenerationConfig @@ -566,7 +690,7 @@ class LocalCV(Base): def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs): pass - def describe(self, image, max_tokens=1024): + def describe(self, image): return "", 0 @@ -590,7 +714,7 @@ class NvidiaCV(Base): ) self.key = key - def describe(self, image, max_tokens=1024): + def describe(self, image): b64 = self.image2base64(image) response = requests.post( url=self.base_url, @@ -609,6 +733,27 @@ class NvidiaCV(Base): response["usage"]["total_tokens"], ) + def describe_with_prompt(self, image, prompt=None): + b64 = self.image2base64(image) + vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64) + + response = requests.post( + url=self.base_url, + headers={ + "accept": "application/json", + "content-type": "application/json", + "Authorization": f"Bearer {self.key}", + }, + json={ + "messages": vision_prompt, + }, + ) + response = response.json() + return ( + response["choices"][0]["message"]["content"].strip(), + response["usage"]["total_tokens"], + ) + def prompt(self, b64): return [ { @@ -622,6 +767,17 @@ class NvidiaCV(Base): } ] + def vision_llm_prompt(self, b64, prompt=None): + return [ + { + "role": "user", + "content": ( + prompt if prompt else vision_llm_describe_prompt() + ) + + f' ', + } + ] + def chat_prompt(self, text, b64): return [ { @@ -634,7 +790,7 @@ class NvidiaCV(Base): class StepFunCV(GptV4): def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"): if not base_url: - base_url="https://api.stepfun.com/v1" + base_url = "https://api.stepfun.com/v1" self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name self.lang = lang @@ -666,18 +822,18 @@ class TogetherAICV(GptV4): def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"): if not base_url: base_url = "https://api.together.xyz/v1" - super().__init__(key, model_name,lang,base_url) + super().__init__(key, model_name, lang, base_url) class YiCV(GptV4): - def __init__(self, key, model_name, lang="Chinese",base_url="https://api.lingyiwanwu.com/v1",): + def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1",): if not base_url: base_url = "https://api.lingyiwanwu.com/v1" - super().__init__(key, model_name,lang,base_url) + super().__init__(key, model_name, lang, base_url) class HunyuanCV(Base): - def __init__(self, key, model_name, lang="Chinese",base_url=None): + def __init__(self, key, model_name, lang="Chinese", base_url=None): from tencentcloud.common import credential from tencentcloud.hunyuan.v20230901 import hunyuan_client @@ -689,12 +845,12 @@ class HunyuanCV(Base): self.client = hunyuan_client.HunyuanClient(cred, "") self.lang = lang - def describe(self, image, max_tokens=4096): - from tencentcloud.hunyuan.v20230901 import models + def describe(self, image): from tencentcloud.common.exception.tencent_cloud_sdk_exception import ( TencentCloudSDKException, ) - + from tencentcloud.hunyuan.v20230901 import models + b64 = self.image2base64(image) req = models.ChatCompletionsRequest() params = {"Model": self.model_name, "Messages": self.prompt(b64)} @@ -706,7 +862,24 @@ class HunyuanCV(Base): return ans, response.Usage.TotalTokens except TencentCloudSDKException as e: return ans + "\n**ERROR**: " + str(e), 0 - + + def describe_with_prompt(self, image, prompt=None): + from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException + from tencentcloud.hunyuan.v20230901 import models + + b64 = self.image2base64(image) + vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64) + req = models.ChatCompletionsRequest() + params = {"Model": self.model_name, "Messages": vision_prompt} + req.from_json_string(json.dumps(params)) + ans = "" + try: + response = self.client.ChatCompletions(req) + ans = response.Choices[0].Message.Content + return ans, response.Usage.TotalTokens + except TencentCloudSDKException as e: + return ans + "\n**ERROR**: " + str(e), 0 + def prompt(self, b64): return [ { @@ -725,4 +898,4 @@ class HunyuanCV(Base): }, ], } - ] \ No newline at end of file + ] diff --git a/rag/prompts.py b/rag/prompts.py index 839a55fc1..2baa82f7e 100644 --- a/rag/prompts.py +++ b/rag/prompts.py @@ -18,13 +18,13 @@ import json import logging import re from collections import defaultdict + import json_repair + from api import settings from api.db import LLMType -from api.db.services.document_service import DocumentService -from api.db.services.llm_service import TenantLLMService, LLMBundle from rag.settings import TAG_FLD -from rag.utils import num_tokens_from_string, encoder +from rag.utils import encoder, num_tokens_from_string def chunks_format(reference): @@ -44,9 +44,11 @@ def chunks_format(reference): def llm_id2llm_type(llm_id): + from api.db.services.llm_service import TenantLLMService + llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id) - llm_factories = settings.FACTORY_LLM_INFOS + llm_factories = settings.FACTORY_LLM_INFOS for llm_factory in llm_factories: for llm in llm_factory["llm"]: if llm_id == llm["llm_name"]: @@ -92,6 +94,8 @@ def message_fit_in(msg, max_length=4000): def kb_prompt(kbinfos, max_tokens): + from api.db.services.document_service import DocumentService + knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]] used_token_count = 0 chunks_num = 0 @@ -166,15 +170,15 @@ Overall, while Musk enjoys Dogecoin and often promotes it, he also warns against def keyword_extraction(chat_mdl, content, topn=3): prompt = f""" -Role: You're a text analyzer. +Role: You're a text analyzer. Task: extract the most important keywords/phrases of a given piece of text content. -Requirements: +Requirements: - Summarize the text content, and give top {topn} important keywords/phrases. - The keywords MUST be in language of the given piece of text content. - The keywords are delimited by ENGLISH COMMA. - Keywords ONLY in output. -### Text Content +### Text Content {content} """ @@ -194,9 +198,9 @@ Requirements: def question_proposal(chat_mdl, content, topn=3): prompt = f""" -Role: You're a text analyzer. +Role: You're a text analyzer. Task: propose {topn} questions about a given piece of text content. -Requirements: +Requirements: - Understand and summarize the text content, and propose top {topn} important questions. - The questions SHOULD NOT have overlapping meanings. - The questions SHOULD cover the main content of the text as much as possible. @@ -204,7 +208,7 @@ Requirements: - One question per line. - Question ONLY in output. -### Text Content +### Text Content {content} """ @@ -223,6 +227,8 @@ Requirements: def full_question(tenant_id, llm_id, messages, language=None): + from api.db.services.llm_service import LLMBundle + if llm_id2llm_type(llm_id) == "image2text": chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id) else: @@ -239,7 +245,7 @@ def full_question(tenant_id, llm_id, messages, language=None): prompt = f""" Role: A helpful assistant -Task and steps: +Task and steps: 1. Generate a full user question that would follow the conversation. 2. If the user's question involves relative date, you need to convert it into absolute date based on the current date, which is {today}. For example: 'yesterday' would be converted to {yesterday}. @@ -300,11 +306,11 @@ Output: What's the weather in Rochester on {tomorrow}? def content_tagging(chat_mdl, content, all_tags, examples, topn=3): prompt = f""" -Role: You're a text analyzer. +Role: You're a text analyzer. Task: Tag (put on some labels) to a given piece of text content based on the examples and the entire tag set. -Steps:: +Steps:: - Comprehend the tag/label set. - Comprehend examples which all consist of both text content and assigned tags with relevance score in format of JSON. - Summarize the text content, and tag it with top {topn} most relevant tags from the set of tag/label and the corresponding relevance score. @@ -358,3 +364,32 @@ Output: except Exception as e: logging.exception(f"JSON parsing error: {result} -> {e}") raise e + + +def vision_llm_describe_prompt(page=None) -> str: + prompt_en = """ +INSTRUCTION: +Transcribe the content from the provided PDF page image into clean Markdown format. +- Only output the content transcribed from the image. +- Do NOT output this instruction or any other explanation. +- If the content is missing or you do not understand the input, return an empty string. + +RULES: +1. Do NOT generate examples, demonstrations, or templates. +2. Do NOT output any extra text such as 'Example', 'Example Output', or similar. +3. Do NOT generate any tables, headings, or content that is not explicitly present in the image. +4. Transcribe content word-for-word. Do NOT modify, translate, or omit any content. +5. Do NOT explain Markdown or mention that you are using Markdown. +6. Do NOT wrap the output in ```markdown or ``` blocks. +7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image. +8. Preserve the original language, information, and order exactly as shown in the image. +""" + + if page is not None: + prompt_en += f"\nAt the end of the transcription, add the page divider: `--- Page {page} ---`." + + prompt_en += """ +FAILURE HANDLING: +- If you do not detect valid content in the image, return an empty string. +""" + return prompt_en diff --git a/rag/utils/__init__.py b/rag/utils/__init__.py index f80d56b7d..c9a007d1b 100644 --- a/rag/utils/__init__.py +++ b/rag/utils/__init__.py @@ -16,7 +16,9 @@ import os import re + import tiktoken + from api.utils.file_utils import get_project_base_directory @@ -54,7 +56,7 @@ def findMaxDt(fnm): pass return m - + def findMaxTm(fnm): m = 0 try: @@ -91,11 +93,18 @@ def truncate(string: str, max_len: int) -> str: """Returns truncated text if the length of text exceed max_len.""" return encoder.decode(encoder.encode(string)[:max_len]) + +def clean_markdown_block(text): + text = re.sub(r'^\s*```markdown\s*\n?', '', text) + text = re.sub(r'\n?\s*```\s*$', '', text) + return text.strip() + def get_float(v: str | None): if v is None: return float('-inf') try: return float(v) except Exception: - return float('-inf') \ No newline at end of file + return float('-inf') +