fix: huggingface and replicate. (#1888)

Author: Garfield Dai
Date: 2024-01-03 18:29:44 +08:00 (committed by GitHub)
Parent: ede69b4659
Commit: 91ee62d1ab
2 changed files with 54 additions and 25 deletions


@@ -154,20 +154,31 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
                 content=chunk.token.text
             )
 
-            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+            if chunk.details:
+                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
 
-            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
 
-            yield LLMResultChunk(
-                model=model,
-                prompt_messages=prompt_messages,
-                delta=LLMResultChunkDelta(
-                    index=index,
-                    message=assistant_prompt_message,
-                    usage=usage,
-                ),
-            )
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                        usage=usage,
+                        finish_reason=chunk.details.finish_reason,
+                    ),
+                )
+            else:
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                    ),
+                )
 
     def _handle_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any) -> LLMResult:
         if isinstance(response, str):
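Why the HuggingFace change works: in text-generation-inference style streaming, `chunk.details` is populated only on the final chunk (and only when details are requested); every earlier chunk carries just the token text. Computing `get_num_tokens` once, on that final chunk, also avoids re-tokenizing the whole prompt on every streamed token. A minimal sketch of that stream shape, assuming `huggingface_hub`'s `InferenceClient` (the model id is illustrative, not from the commit):

```python
# Minimal sketch of the stream shape the fix relies on. Requires network
# access to the HF Inference API; the model id is just an example.
from huggingface_hub import InferenceClient

client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")

for chunk in client.text_generation("Hello", max_new_tokens=16, stream=True, details=True):
    if chunk.details:
        # Only the final chunk carries details: this is where the commit
        # now computes usage and reads chunk.details.finish_reason.
        print("finish_reason:", chunk.details.finish_reason)
    else:
        # Intermediate chunks: token text only, no usage attached.
        print(chunk.token.text, end="")
```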


@@ -116,7 +116,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
         )
 
         for key, value in input_properties:
-            if key not in ['system_prompt', 'prompt']:
+            if key not in ['system_prompt', 'prompt'] and 'stop' not in key:
                 value_type = value.get('type')
 
                 if not value_type:
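The extra `'stop' not in key` guard keeps any stop-related input (for example a `stop_sequences` property) out of the auto-discovered parameter rules, since stop words are passed to the runtime separately. A hypothetical schema showing the effect of the filter (key names invented for illustration):

```python
# Hypothetical Replicate input schema; key names are examples only.
input_properties = {
    "prompt": {"type": "string"},
    "system_prompt": {"type": "string"},
    "stop_sequences": {"type": "string"},
    "temperature": {"type": "number"},
}.items()

tunable = [
    key for key, value in input_properties
    if key not in ['system_prompt', 'prompt'] and 'stop' not in key
]
print(tunable)  # ['temperature']; the stop_sequences key is now filtered out
```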
@@ -151,9 +151,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
         index = -1
         current_completion: str = ""
         stop_condition_reached = False
+
+        prediction_output_length = 10000
+        is_prediction_output_finished = False
+
         for output in prediction.output_iterator():
             current_completion += output
 
+            if not is_prediction_output_finished and prediction.status == 'succeeded':
+                prediction_output_length = len(prediction.output) - 1
+                is_prediction_output_finished = True
+
             if stop:
                 for s in stop:
                     if s in current_completion:
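The sentinel pair works because Replicate's `output_iterator()` keeps yielding while the prediction runs, and once `prediction.status` reaches `'succeeded'` the complete `prediction.output` list is available, making `len(prediction.output) - 1` the index of the final chunk; until then the placeholder `10000` keeps every index comparing as non-final. A self-contained simulation of that pattern (`FakePrediction` is invented here, standing in for the replicate client's prediction object):

```python
# Standalone simulation of the sentinel logic; FakePrediction mimics the
# small slice of the replicate Prediction API that the commit touches.
class FakePrediction:
    def __init__(self, chunks):
        self.output = []
        self._chunks = chunks
        self.status = 'processing'

    def output_iterator(self):
        for i, chunk in enumerate(self._chunks):
            self.output.append(chunk)
            if i == len(self._chunks) - 1:
                self.status = 'succeeded'
            yield chunk

prediction = FakePrediction(["Hel", "lo", "!"])
prediction_output_length = 10000          # placeholder until the run finishes
is_prediction_output_finished = False

for index, output in enumerate(prediction.output_iterator()):
    if not is_prediction_output_finished and prediction.status == 'succeeded':
        prediction_output_length = len(prediction.output) - 1
        is_prediction_output_finished = True
    is_final = index >= prediction_output_length
    print(index, repr(output), "final" if is_final else "")
```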
@@ -172,20 +180,30 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
                 content=output if output else ''
             )
 
-            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+            if index < prediction_output_length:
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message
+                    )
+                )
+            else:
+                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
 
-            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
 
-            yield LLMResultChunk(
-                model=model,
-                prompt_messages=prompt_messages,
-                delta=LLMResultChunkDelta(
-                    index=index,
-                    message=assistant_prompt_message,
-                    usage=usage
-                )
-            )
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                        usage=usage
+                    )
+                )
 
     def _handle_generate_response(self, model: str, credentials: dict, prediction: Prediction, stop: list[str],
                                   prompt_messages: list[PromptMessage]) -> LLMResult:
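Net effect for both providers: intermediate stream chunks go out with only the delta text, and exactly one final chunk carries usage (plus, on HuggingFace, a finish reason). A simplified consumer sketch of that contract; the dataclasses below are stand-ins invented for illustration, not the runtime's actual `LLMResultChunk`/`LLMResultChunkDelta` classes:

```python
# Simplified stand-ins for the runtime's chunk types, for illustration only.
from dataclasses import dataclass
from typing import Optional

@dataclass
class Usage:
    prompt_tokens: int
    completion_tokens: int

@dataclass
class Delta:
    index: int
    message: str
    usage: Optional[Usage] = None

def consume(stream):
    text, usage = "", None
    for chunk in stream:
        text += chunk.message
        if chunk.usage is not None:  # only the final chunk carries usage
            usage = chunk.usage
    return text, usage

final_text, final_usage = consume([
    Delta(0, "Hel"), Delta(1, "lo"),
    Delta(2, "!", Usage(prompt_tokens=5, completion_tokens=3)),
])
print(final_text, final_usage)
```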