mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-22 14:10:01 +08:00
321 lines
10 KiB
Python
321 lines
10 KiB
Python
#
|
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
import datetime
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
from collections import defaultdict
|
|
import json_repair
|
|
from api.db import LLMType
|
|
from api.db.services.document_service import DocumentService
|
|
from api.db.services.llm_service import TenantLLMService, LLMBundle
|
|
from api.utils.file_utils import get_project_base_directory
|
|
from rag.settings import TAG_FLD
|
|
from rag.utils import num_tokens_from_string, encoder
|
|
|
|
|
|
def chunks_format(reference):
    """Normalize the chunks of a retrieval reference into uniform dicts.

    A chunk may store the same piece of information under one of two key
    names. For every output field the first (preferred) key is used when it
    is present in the chunk, otherwise the value of the fallback key.

    Args:
        reference: mapping that may hold a list of chunk dicts under "chunks".

    Returns:
        A list with one dict per chunk, keyed by "id", "content",
        "document_id", "document_name", "dataset_id", "image_id",
        "positions" and "url". Fields found under neither key are None.
    """
    # (output key, preferred source key, fallback source key)
    aliases = (
        ("id", "chunk_id", "id"),
        ("content", "content", "content_with_weight"),
        ("document_id", "doc_id", "document_id"),
        ("document_name", "docnm_kwd", "document_name"),
        ("dataset_id", "kb_id", "dataset_id"),
        ("image_id", "image_id", "img_id"),
        ("positions", "positions", "position_int"),
    )

    formatted = []
    for chunk in reference.get("chunks", []):
        item = {}
        for out_key, primary, fallback in aliases:
            item[out_key] = chunk.get(primary, chunk.get(fallback))
        # "url" has no alternative spelling.
        item["url"] = chunk.get("url")
        formatted.append(item)
    return formatted
|
|
|
|
|
|
def llm_id2llm_type(llm_id):
    """Look up the model type registered for an LLM in conf/llm_factories.json.

    Args:
        llm_id: model identifier. It is first passed through
            TenantLLMService.split_model_name_and_factory and only the
            model-name part is used for the lookup.

    Returns:
        The last entry of the matching model's "model_type" value
        (presumably a comma-separated list, e.g. ending in "image2text"),
        or None when no factory lists the model.
    """
    llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
    conf_path = os.path.join(get_project_base_directory(), "conf", "llm_factories.json")
    # Context manager so the file handle is closed deterministically.
    with open(conf_path, "r", encoding="utf-8") as f:
        llm_factories = json.load(f)

    for llm_factory in llm_factories["factory_llm_infos"]:
        for llm in llm_factory["llm"]:
            if llm_id == llm["llm_name"]:
                # Split on commas before taking the last element; indexing
                # the stripped string directly would return only its last
                # character (e.g. "t" for "image2text").
                return llm["model_type"].strip(",").split(",")[-1].strip()
    return None
|
|
|
|
|
|
def message_fit_in(msg, max_length=4000):
    """Shrink a chat message list so that it fits a token budget.

    The following steps are applied in order until the messages fit:
      1. Return the list untouched if its total token count is already
         below ``max_length``.
      2. Keep only the system messages and the last message.
      3. Truncate one message: the first one when it accounts for more than
         80% of the tokens of first + last, otherwise the last one. The
         truncated message keeps the tokens left over after the other one.

    Only the first and last kept messages are considered in step 3; system
    messages between them are not part of the truncation budget.

    Args:
        msg: list of dicts, each with a "role" and a "content" key.
        max_length: token budget.

    Returns:
        A ``(token_count, messages)`` tuple. When step 3 was needed the
        count is reported as ``max_length`` and the "content" of the
        truncated message dict is modified in place.
    """
    def count(messages):
        return sum(num_tokens_from_string(m["content"]) for m in messages)

    def truncate(message, budget):
        # Clamp at 0: a negative slice bound would drop tokens from the end
        # of the content instead of producing an empty result.
        tokens = encoder.encode(message["content"])[:max(budget, 0)]
        message["content"] = encoder.decode(tokens)

    c = count(msg)
    if c < max_length or not msg:
        return c, msg

    # Always keep the last message, even when it is the only one, so the
    # caller never ends up with an empty conversation.
    kept = [m for m in msg[:-1] if m["role"] == "system"]
    kept.append(msg[-1])
    c = count(kept)
    if c < max_length:
        return c, kept

    if len(kept) == 1:
        truncate(kept[0], max_length)
        return max_length, kept

    head, tail = kept[0], kept[-1]
    ll = num_tokens_from_string(head["content"])
    ll2 = num_tokens_from_string(tail["content"])
    if ll + ll2 > 0 and ll / (ll + ll2) > 0.8:
        # The first message dominates: cut it down to what the last leaves.
        truncate(head, max_length - ll2)
    else:
        # Otherwise cut the last message down to what the first leaves.
        truncate(tail, max_length - ll)
    return max_length, kept
|
|
|
|
|
|
def kb_prompt(kbinfos, max_tokens):
    """Render retrieved chunks as per-document knowledge text blocks.

    Chunks are taken in order until their accumulated token count exceeds
    97% of ``max_tokens``. The selected chunks are grouped by document name
    and each group is rendered as one text block containing the document
    name, the document's meta fields and the numbered chunk contents.

    NOTE(review): the chunk that crosses the budget is still included
    (the counter is incremented before the budget check). This mirrors the
    existing behaviour; confirm whether it is intended.

    Args:
        kbinfos: mapping with a "chunks" list; every chunk is read for
            "content_with_weight", "doc_id", "docnm_kwd" and optionally "url".
        max_tokens: token budget for the knowledge text.

    Returns:
        A list of strings, one per document, in first-seen order.
    """
    chunks = kbinfos["chunks"]
    used_token_count = 0
    chunks_num = 0
    for ck in chunks:
        used_token_count += num_tokens_from_string(ck["content_with_weight"])
        chunks_num += 1
        if max_tokens * 0.97 < used_token_count:
            # Report used/total; only warn when chunks are actually left out.
            if chunks_num < len(chunks):
                logging.warning(f"Not all the retrieval into prompt: {chunks_num}/{len(chunks)}")
            break

    selected = chunks[:chunks_num]
    docs = DocumentService.get_by_ids([ck["doc_id"] for ck in selected])
    docs = {d.id: d.meta_fields for d in docs}

    doc2chunks = defaultdict(lambda: {"chunks": [], "meta": {}})
    for ck in selected:
        entry = doc2chunks[ck["docnm_kwd"]]
        entry["chunks"].append((f"URL: {ck['url']}\n" if "url" in ck else "") + ck["content_with_weight"])
        # Fall back to an empty mapping when the document is unknown or has
        # no meta fields, so the rendering loop below can always iterate.
        entry["meta"] = docs.get(ck["doc_id"]) or {}

    knowledges = []
    for nm, cks_meta in doc2chunks.items():
        txt = f"Document: {nm} \n"
        for k, v in cks_meta["meta"].items():
            txt += f"{k}: {v}\n"
        txt += "Relevant fragments as following:\n"
        for i, chunk in enumerate(cks_meta["chunks"], 1):
            txt += f"{i}. {chunk}\n"
        knowledges.append(txt)
    return knowledges
|
|
|
|
|
|
def keyword_extraction(chat_mdl, content, topn=3):
    """Ask a chat model for the most important keywords of a text.

    Args:
        chat_mdl: chat model object; its ``max_length`` attribute is used as
            the token budget and its ``chat`` method is called with the
            system prompt, the remaining messages and generation options.
        content: text to extract keywords from.
        topn: number of keywords/phrases requested in the prompt.

    Returns:
        The model's answer with any ``<think>...</think>`` section removed,
        or an empty string when the answer contains "**ERROR**".
    """
    prompt = f"""
Role: You're a text analyzer.
Task: extract the most important keywords/phrases of a given piece of text content.
Requirements:
- Summarize the text content, and give top {topn} important keywords/phrases.
- The keywords MUST be in language of the given piece of text content.
- The keywords are delimited by ENGLISH COMMA.
- Keywords ONLY in output.

### Text Content
{content}

"""
    msg = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "Output: "}
    ]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    # Send the system prompt as fitted to the token budget; passing the
    # original `prompt` would silently discard the truncation.
    kwd = chat_mdl.chat(msg[0]["content"], msg[1:], {"temperature": 0.2})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        return ""
    return kwd
|
|
|
|
|
|
def question_proposal(chat_mdl, content, topn=3):
    """Ask a chat model to propose questions about a text.

    Args:
        chat_mdl: chat model object; its ``max_length`` attribute is used as
            the token budget and its ``chat`` method is called with the
            system prompt, the remaining messages and generation options.
        content: text the questions should be about.
        topn: number of questions requested in the prompt.

    Returns:
        The model's answer (one question per line is requested) with any
        ``<think>...</think>`` section removed, or an empty string when the
        answer contains "**ERROR**".
    """
    prompt = f"""
Role: You're a text analyzer.
Task: propose {topn} questions about a given piece of text content.
Requirements:
- Understand and summarize the text content, and propose top {topn} important questions.
- The questions SHOULD NOT have overlapping meanings.
- The questions SHOULD cover the main content of the text as much as possible.
- The questions MUST be in language of the given piece of text content.
- One question per line.
- Question ONLY in output.

### Text Content
{content}

"""
    msg = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "Output: "}
    ]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    # Send the system prompt as fitted to the token budget; passing the
    # original `prompt` would silently discard the truncation.
    kwd = chat_mdl.chat(msg[0]["content"], msg[1:], {"temperature": 0.2})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        return ""
    return kwd
|
|
|
|
|
|
def full_question(tenant_id, llm_id, messages, language=None):
    """Rewrite the latest user question into a self-contained question.

    The user/assistant turns of ``messages`` are rendered as a transcript
    and a chat model is asked to produce the full question, resolving
    relative dates against the current date.

    Args:
        tenant_id: tenant used to build the LLMBundle.
        llm_id: model identifier; an "image2text" model type selects
            LLMType.IMAGE2TEXT, anything else LLMType.CHAT.
        messages: list of dicts with "role" and "content"; roles other than
            "user" and "assistant" are ignored.
        language: when given, the prompt demands output in this language;
            otherwise in the language of the user's question.

    Returns:
        The refined question with any ``<think>...</think>`` section
        removed, or the content of the last message when the model answer
        contains "**ERROR**".
    """
    if llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)

    conv = "\n".join(
        "{}: {}".format(m["role"].upper(), m["content"])
        for m in messages
        if m["role"] in ["user", "assistant"]
    )

    # Sample the date once so all three values stay consistent even if the
    # call happens right at midnight.
    current_date = datetime.date.today()
    one_day = datetime.timedelta(days=1)
    today = current_date.isoformat()
    yesterday = (current_date - one_day).isoformat()
    tomorrow = (current_date + one_day).isoformat()

    prompt = f"""
Role: A helpful assistant

Task and steps:
1. Generate a full user question that would follow the conversation.
2. If the user's question involves relative date, you need to convert it into absolute date based on the current date, which is {today}. For example: 'yesterday' would be converted to {yesterday}.

Requirements & Restrictions:
- If the user's latest question is completely, don't do anything, just return the original question.
- DON'T generate anything except a refined question."""
    if language:
        prompt += f"""
- Text generated MUST be in {language}."""
    else:
        prompt += """
- Text generated MUST be in the same language of the original user's question.
"""
    prompt += f"""

######################
-Examples-
######################

# Example 1
## Conversation
USER: What is the name of Donald Trump's father?
ASSISTANT: Fred Trump.
USER: And his mother?
###############
Output: What's the name of Donald Trump's mother?

------------
# Example 2
## Conversation
USER: What is the name of Donald Trump's father?
ASSISTANT: Fred Trump.
USER: And his mother?
ASSISTANT: Mary Trump.
User: What's her full name?
###############
Output: What's the full name of Donald Trump's mother Mary Trump?

------------
# Example 3
## Conversation
USER: What's the weather today in London?
ASSISTANT: Cloudy.
USER: What's about tomorrow in Rochester?
###############
Output: What's the weather in Rochester on {tomorrow}?

######################
# Real Data
## Conversation
{conv}
###############
"""
    ans = chat_mdl.chat(prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.2})
    # The sibling prompt helpers accept a tuple result from chat(); unwrap
    # it here as well so re.sub always receives a string.
    if isinstance(ans, tuple):
        ans = ans[0]
    ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
    return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
|
|
|
|
|
|
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
    """Ask a chat model to tag a text with tags from a fixed tag set.

    Args:
        chat_mdl: chat model object; its ``max_length`` attribute is used as
            the token budget and its ``chat`` method is called with the
            system prompt, the remaining messages and generation options.
        content: text to tag.
        all_tags: iterable of tag strings, joined with ", " into the prompt.
        examples: iterable of mappings, each read for "content" and for the
            TAG_FLD entry (serialized as JSON into the prompt).
        topn: number of tags requested in the prompt.

    Returns:
        The model output parsed by ``json_repair.loads``; the prompt asks
        for a JSON object mapping tag to relevance score.

    Raises:
        Exception: when the model answer contains "**ERROR**" (the answer
            is the exception message), or when the answer cannot be parsed
            even after the fallback extraction.
    """
    prompt = f"""
Role: You're a text analyzer.

Task: Tag (put on some labels) to a given piece of text content based on the examples and the entire tag set.

Steps::
- Comprehend the tag/label set.
- Comprehend examples which all consist of both text content and assigned tags with relevance score in format of JSON.
- Summarize the text content, and tag it with top {topn} most relevant tags from the set of tag/label and the corresponding relevance score.

Requirements
- The tags MUST be from the tag set.
- The output MUST be in JSON format only, the key is tag and the value is its relevance score.
- The relevance score must be range from 1 to 10.
- Keywords ONLY in output.

# TAG SET
{", ".join(all_tags)}

"""
    for i, ex in enumerate(examples):
        prompt += """
# Examples {}
### Text Content
{}

Output:
{}

""".format(i, ex["content"], json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False))

    prompt += f"""
# Real Data
### Text Content
{content}

"""
    msg = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "Output: "}
    ]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    # Send the system prompt as fitted to the token budget; passing the
    # original `prompt` would silently discard the truncation.
    system_prompt = msg[0]["content"]
    kwd = chat_mdl.chat(system_prompt, msg[1:], {"temperature": 0.5})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"<think>.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        raise Exception(kwd)

    try:
        return json_repair.loads(kwd)
    except json_repair.JSONDecodeError:
        # Fallback: drop an echoed prompt and role markers, then parse only
        # the first {...} span of what is left.
        result = kwd
        try:
            result = kwd.replace(system_prompt[:-1], '').replace('user', '').replace('model', '').strip()
            result = '{' + result.split('{')[1].split('}')[0] + '}'
            return json_repair.loads(result)
        except Exception as e:
            logging.exception(f"JSON parsing error: {result} -> {e}")
            raise
|