diff --git a/api/models/model.py b/api/models/model.py index 9172410ebb..078b45c9ea 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -147,7 +147,7 @@ class AppModelConfig(db.Model): "suggested_questions": self.suggested_questions_list, "suggested_questions_after_answer": self.suggested_questions_after_answer_dict, "speech_to_text": self.speech_to_text_dict, - "retriever_resource": self.retriever_resource, + "retriever_resource": self.retriever_resource_dict, "more_like_this": self.more_like_this_dict, "sensitive_word_avoidance": self.sensitive_word_avoidance_dict, "model": self.model_dict, diff --git a/api/services/completion_service.py b/api/services/completion_service.py index ce9d45b325..f8cb0d9bc0 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -366,6 +366,7 @@ class CompletionService: generate_channel = list(pubsub.channels.keys())[0].decode('utf-8') if not streaming: try: + message_result = {} for message in pubsub.listen(): if message["type"] == "message": result = message["data"].decode('utf-8') @@ -373,7 +374,10 @@ class CompletionService: if result.get('error'): cls.handle_error(result) if result['event'] == 'message' and 'data' in result: - return cls.get_message_response_data(result.get('data')) + message_result['message'] = result.get('data') + if result['event'] == 'message_end' and 'data' in result: + message_result['message_end'] = result.get('data') + return cls.get_blocking_message_response_data(message_result) except ValueError as e: if e.args[0] != "I/O operation on closed file.": # ignore this error raise CompletionStoppedError() @@ -399,7 +403,6 @@ class CompletionService: if event == "end": logging.debug("{} finished".format(generate_channel)) break - if event == 'message': yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n" elif event == 'chain': @@ -441,6 +444,27 @@ class CompletionService: return response_data + @classmethod + def get_blocking_message_response_data(cls, data: dict): + message = data.get('message') + response_data = { + 'event': 'message', + 'task_id': message.get('task_id'), + 'id': message.get('message_id'), + 'answer': message.get('text'), + 'metadata': {}, + 'created_at': int(time.time()) + } + + if message.get('mode') == 'chat': + response_data['conversation_id'] = message.get('conversation_id') + if 'message_end' in data: + message_end = data.get('message_end') + if 'retriever_resources' in message_end: + response_data['metadata']['retriever_resources'] = message_end.get('retriever_resources') + + return response_data + @classmethod def get_message_end_data(cls, data: dict): response_data = {