mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-22 06:00:00 +08:00
fix sequence2txt error and usage total token issue (#2961)
### What problem does this PR solve?

#1363

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
parent
6a4858a7ee
commit
b2524eec49
@ -26,7 +26,6 @@ from api.db.services.dialog_service import DialogService, ConversationService, c
|
|||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
|
from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
|
||||||
from api.settings import RetCode, retrievaler
|
from api.settings import RetCode, retrievaler
|
||||||
from api.utils import get_uuid
|
|
||||||
from api.utils.api_utils import get_json_result
|
from api.utils.api_utils import get_json_result
|
||||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||||
from graphrag.mind_map_extractor import MindMapExtractor
|
from graphrag.mind_map_extractor import MindMapExtractor
|
||||||
@ -187,6 +186,7 @@ def completion():
|
|||||||
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
|
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
|
||||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||||
ensure_ascii=False) + "\n\n"
|
ensure_ascii=False) + "\n\n"
|
||||||
|
@ -133,7 +133,8 @@ class TenantLLMService(CommonService):
|
|||||||
if model_config["llm_factory"] not in Seq2txtModel:
|
if model_config["llm_factory"] not in Seq2txtModel:
|
||||||
return
|
return
|
||||||
return Seq2txtModel[model_config["llm_factory"]](
|
return Seq2txtModel[model_config["llm_factory"]](
|
||||||
model_config["api_key"], model_config["llm_name"], lang,
|
key=model_config["api_key"], model_name=model_config["llm_name"],
|
||||||
|
lang=lang,
|
||||||
base_url=model_config["api_base"]
|
base_url=model_config["api_base"]
|
||||||
)
|
)
|
||||||
if llm_type == LLMType.TTS:
|
if llm_type == LLMType.TTS:
|
||||||
|
@ -197,6 +197,7 @@ def thumbnail_img(filename, blob):
|
|||||||
pass
|
pass
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def thumbnail(filename, blob):
|
def thumbnail(filename, blob):
|
||||||
img = thumbnail_img(filename, blob)
|
img = thumbnail_img(filename, blob)
|
||||||
if img is not None:
|
if img is not None:
|
||||||
@ -205,6 +206,7 @@ def thumbnail(filename, blob):
|
|||||||
else:
|
else:
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
|
||||||
def traversal_files(base):
|
def traversal_files(base):
|
||||||
for root, ds, fs in os.walk(base):
|
for root, ds, fs in os.walk(base):
|
||||||
for f in fs:
|
for f in fs:
|
||||||
|
@ -67,14 +67,16 @@ class Base(ABC):
|
|||||||
if not resp.choices[0].delta.content:
|
if not resp.choices[0].delta.content:
|
||||||
resp.choices[0].delta.content = ""
|
resp.choices[0].delta.content = ""
|
||||||
ans += resp.choices[0].delta.content
|
ans += resp.choices[0].delta.content
|
||||||
total_tokens = (
|
total_tokens += 1
|
||||||
(
|
if not hasattr(resp, "usage") or not resp.usage:
|
||||||
total_tokens
|
total_tokens = (
|
||||||
+ num_tokens_from_string(resp.choices[0].delta.content)
|
total_tokens
|
||||||
)
|
+ num_tokens_from_string(resp.choices[0].delta.content)
|
||||||
if not hasattr(resp, "usage") or not resp.usage
|
)
|
||||||
else resp.usage.get("total_tokens", total_tokens)
|
elif isinstance(resp.usage, dict):
|
||||||
)
|
total_tokens = resp.usage.get("total_tokens", total_tokens)
|
||||||
|
else: total_tokens = resp.usage.total_tokens
|
||||||
|
|
||||||
if resp.choices[0].finish_reason == "length":
|
if resp.choices[0].finish_reason == "length":
|
||||||
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
||||||
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||||
|
@ -87,7 +87,7 @@ class AzureSeq2txt(Base):
|
|||||||
|
|
||||||
|
|
||||||
class XinferenceSeq2txt(Base):
|
class XinferenceSeq2txt(Base):
|
||||||
def __init__(self,key,model_name="whisper-small",**kwargs):
|
def __init__(self, key, model_name="whisper-small", **kwargs):
|
||||||
self.base_url = kwargs.get('base_url', None)
|
self.base_url = kwargs.get('base_url', None)
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.key = key
|
self.key = key
|
||||||
|
Loading…
x
Reference in New Issue
Block a user