fix sequence2txt error and usage total token issue (#2961)

### What problem does this PR solve?

#1363

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Kevin Hu 2024-10-22 11:38:37 +08:00 committed by GitHub
parent 6a4858a7ee
commit b2524eec49
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 16 additions and 11 deletions

View File

@ -26,7 +26,6 @@ from api.db.services.dialog_service import DialogService, ConversationService, c
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
from api.settings import RetCode, retrievaler from api.settings import RetCode, retrievaler
from api.utils import get_uuid
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from graphrag.mind_map_extractor import MindMapExtractor from graphrag.mind_map_extractor import MindMapExtractor
@ -187,6 +186,7 @@ def completion():
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
ConversationService.update_by_id(conv.id, conv.to_dict()) ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as e: except Exception as e:
traceback.print_exc()
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e), yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}}, "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n" ensure_ascii=False) + "\n\n"

View File

@ -133,7 +133,8 @@ class TenantLLMService(CommonService):
if model_config["llm_factory"] not in Seq2txtModel: if model_config["llm_factory"] not in Seq2txtModel:
return return
return Seq2txtModel[model_config["llm_factory"]]( return Seq2txtModel[model_config["llm_factory"]](
model_config["api_key"], model_config["llm_name"], lang, key=model_config["api_key"], model_name=model_config["llm_name"],
lang=lang,
base_url=model_config["api_base"] base_url=model_config["api_base"]
) )
if llm_type == LLMType.TTS: if llm_type == LLMType.TTS:

View File

@ -197,6 +197,7 @@ def thumbnail_img(filename, blob):
pass pass
return None return None
def thumbnail(filename, blob): def thumbnail(filename, blob):
img = thumbnail_img(filename, blob) img = thumbnail_img(filename, blob)
if img is not None: if img is not None:
@ -205,6 +206,7 @@ def thumbnail(filename, blob):
else: else:
return '' return ''
def traversal_files(base): def traversal_files(base):
for root, ds, fs in os.walk(base): for root, ds, fs in os.walk(base):
for f in fs: for f in fs:

View File

@ -67,14 +67,16 @@ class Base(ABC):
if not resp.choices[0].delta.content: if not resp.choices[0].delta.content:
resp.choices[0].delta.content = "" resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content ans += resp.choices[0].delta.content
total_tokens = ( total_tokens += 1
( if not hasattr(resp, "usage") or not resp.usage:
total_tokens total_tokens = (
+ num_tokens_from_string(resp.choices[0].delta.content) total_tokens
) + num_tokens_from_string(resp.choices[0].delta.content)
if not hasattr(resp, "usage") or not resp.usage )
else resp.usage.get("total_tokens", total_tokens) elif isinstance(resp.usage, dict):
) total_tokens = resp.usage.get("total_tokens", total_tokens)
else: total_tokens = resp.usage.total_tokens
if resp.choices[0].finish_reason == "length": if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english( ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"

View File

@ -87,7 +87,7 @@ class AzureSeq2txt(Base):
class XinferenceSeq2txt(Base): class XinferenceSeq2txt(Base):
def __init__(self,key,model_name="whisper-small",**kwargs): def __init__(self, key, model_name="whisper-small", **kwargs):
self.base_url = kwargs.get('base_url', None) self.base_url = kwargs.get('base_url', None)
self.model_name = model_name self.model_name = model_name
self.key = key self.key = key