# Mirror of https://github.com/langgenius/dify.git
# (via https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git),
# synced 2025-05-14 04:18:17 +08:00. Original file: 117 lines, 3.9 KiB, Python.
from collections.abc import Generator
|
|
|
|
import google.generativeai.types.generation_types as generation_config_types
|
|
import pytest
|
|
from _pytest.monkeypatch import MonkeyPatch
|
|
from google.ai import generativelanguage as glm
|
|
from google.ai.generativelanguage_v1beta.types import content as gag_content
|
|
from google.generativeai import GenerativeModel
|
|
from google.generativeai.client import _ClientManager, configure
|
|
from google.generativeai.types import GenerateContentResponse, content_types, safety_types
|
|
from google.generativeai.types.generation_types import BaseGenerateContentResponse
|
|
|
|
# API key captured by the patched ``make_client`` from the client options;
# the patched ``generate_content`` rejects keys shorter than 16 characters.
current_api_key: str = ""
|
|
|
|
|
|
class MockGoogleResponseClass:
    """Iterable stand-in for a streamed ``generate_content`` result.

    Iterating yields ``len("it's google!") + 1`` empty ``GenerateContentResponse``
    chunks. Only the final chunk carries ``done=True``, and the instance's
    ``_done`` flag is set to ``True`` right before that final chunk is yielded.
    """

    # Class-level default; shadowed on the instance once the final chunk is produced.
    _done = False

    def __iter__(self):
        text = "it's google!"
        last_index = len(text)
        for position in range(last_index + 1):
            finished = position == last_index
            if finished:
                self._done = True
            yield GenerateContentResponse(
                done=finished, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
            )
|
|
|
|
|
|
class MockGoogleResponseCandidateClass:
    """Fake response candidate exposing ``finish_reason`` and ``content``."""

    # Shared by all instances; the mock always reports a normal stop.
    finish_reason = "stop"

    @property
    def content(self) -> gag_content.Content:
        """Build a fresh ``Content`` holding a single text part."""
        text_part = gag_content.Part(text="it's google!")
        return gag_content.Content(parts=[text_part])
|
|
|
|
|
|
class MockGoogleClass:
    """Replacement implementations patched onto the google-generativeai SDK.

    The ``setup_google_mock`` fixture installs these members on real SDK classes
    (``BaseGenerateContentResponse``, ``GenerativeModel``, ``_ClientManager``).
    In the instance-style functions below, ``self`` is therefore an instance of
    the patched SDK class, not of ``MockGoogleClass``.
    """

    @staticmethod
    def generate_content_sync() -> GenerateContentResponse:
        """Return a single, already-finished (``done=True``) empty response."""
        return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[])

    @staticmethod
    def generate_content_stream() -> Generator[GenerateContentResponse, None, None]:
        """Return an iterable of streamed chunks.

        NOTE(review): this returns a ``MockGoogleResponseClass`` instance, which
        is iterable but not a generator, despite the annotation.
        """
        return MockGoogleResponseClass()

    def generate_content(
        self: GenerativeModel,
        contents: content_types.ContentsType,
        *,
        generation_config: generation_config_types.GenerationConfigType | None = None,
        safety_settings: safety_types.SafetySettingOptions | None = None,
        stream: bool = False,
        **kwargs,
    ) -> GenerateContentResponse:
        """Fake ``GenerativeModel.generate_content``.

        :param contents: ignored by the mock.
        :param generation_config: ignored by the mock.
        :param safety_settings: ignored by the mock.
        :param stream: when true, return the streaming mock instead of a single response.
        :return: a mock response, or an iterable of mock chunks when ``stream`` is true.
        :raises Exception: if the captured ``current_api_key`` is shorter than 16 characters.
        """
        # ``current_api_key`` is only read here, so no ``global`` declaration is needed.
        if len(current_api_key) < 16:
            raise Exception("Invalid API key")

        if stream:
            return MockGoogleClass.generate_content_stream()

        return MockGoogleClass.generate_content_sync()

    @property
    def generative_response_text(self) -> str:
        """Fixed text returned for every patched response's ``text``."""
        return "it's google!"

    @property
    def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
        """A single fake candidate for every patched response's ``candidates``."""
        return [MockGoogleResponseCandidateClass()]

    def make_client(self: _ClientManager, name: str):
        """Fake ``_ClientManager.make_client``: build a client without running its constructor.

        Also records ``client_options.api_key`` from ``self.client_config`` into the
        module-level ``current_api_key`` so ``generate_content`` can validate it.

        :param name: service name, optionally suffixed with ``_async`` to select
            the async client class from ``glm``.
        :return: an instance of the selected client class whose ``__init__`` was skipped.
        """
        global current_api_key

        if name.endswith("_async"):
            name = name.split("_")[0]
            cls = getattr(glm, name.title() + "ServiceAsyncClient")
        else:
            cls = getattr(glm, name.title() + "ServiceClient")

        # Attempt to configure using defaults.
        if not self.client_config:
            configure()

        client_options = self.client_config.get("client_options", None)
        if client_options:
            current_api_key = client_options.api_key

        def nop(self, *args, **kwargs):
            pass

        # Temporarily swap in a no-op constructor so instantiation skips the real
        # __init__. Restore it even if instantiation fails, so the class is never
        # left permanently patched.
        original_init = cls.__init__
        cls.__init__ = nop
        try:
            client: glm.GenerativeServiceClient = cls(**self.client_config)
        finally:
            cls.__init__ = original_init

        # The SDK code path this mirrors presumably goes on to apply
        # ``self.default_metadata`` to the client; the mock does not emulate that.
        # It must still return the client instead of falling through to an
        # implicit ``None`` whenever default metadata is configured.
        return client
|
|
|
|
|
|
@pytest.fixture
def setup_google_mock(request, monkeypatch: MonkeyPatch):
    """Patch the google-generativeai SDK with the fakes from ``MockGoogleClass``.

    Patches are applied in a fixed order, the test body runs at the ``yield``,
    and every patch is reverted afterwards.
    """
    patches = (
        (BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text),
        (BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates),
        (GenerativeModel, "generate_content", MockGoogleClass.generate_content),
        (_ClientManager, "make_client", MockGoogleClass.make_client),
    )
    for target, attr_name, replacement in patches:
        monkeypatch.setattr(target, attr_name, replacement)

    yield

    # Revert explicitly at this point instead of waiting for pytest's teardown.
    monkeypatch.undo()
|