This commit is contained in:
Timothy Jaeryang Baek 2025-02-12 23:26:47 -08:00
parent 2b7f9d14d0
commit 83e5db7be7
4 changed files with 16 additions and 11 deletions

View File

@ -139,7 +139,7 @@ async def update_task_config(
async def generate_title( async def generate_title(
request: Request, form_data: dict, user=Depends(get_verified_user) request: Request, form_data: dict, user=Depends(get_verified_user)
): ):
if request.state.direct and request.state.model: if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = { models = {
request.state.model["id"]: request.state.model, request.state.model["id"]: request.state.model,
} }
@ -231,7 +231,7 @@ async def generate_chat_tags(
content={"detail": "Tags generation is disabled"}, content={"detail": "Tags generation is disabled"},
) )
if request.state.direct and request.state.model: if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = { models = {
request.state.model["id"]: request.state.model, request.state.model["id"]: request.state.model,
} }
@ -293,7 +293,7 @@ async def generate_chat_tags(
async def generate_image_prompt( async def generate_image_prompt(
request: Request, form_data: dict, user=Depends(get_verified_user) request: Request, form_data: dict, user=Depends(get_verified_user)
): ):
if request.state.direct and request.state.model: if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = { models = {
request.state.model["id"]: request.state.model, request.state.model["id"]: request.state.model,
} }
@ -374,7 +374,7 @@ async def generate_queries(
detail=f"Query generation is disabled", detail=f"Query generation is disabled",
) )
if request.state.direct and request.state.model: if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = { models = {
request.state.model["id"]: request.state.model, request.state.model["id"]: request.state.model,
} }
@ -455,7 +455,7 @@ async def generate_autocompletion(
detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}", detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
) )
if request.state.direct and request.state.model: if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = { models = {
request.state.model["id"]: request.state.model, request.state.model["id"]: request.state.model,
} }
@ -518,7 +518,7 @@ async def generate_emoji(
request: Request, form_data: dict, user=Depends(get_verified_user) request: Request, form_data: dict, user=Depends(get_verified_user)
): ):
if request.state.direct and request.state.model: if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = { models = {
request.state.model["id"]: request.state.model, request.state.model["id"]: request.state.model,
} }
@ -587,7 +587,7 @@ async def generate_moa_response(
request: Request, form_data: dict, user=Depends(get_verified_user) request: Request, form_data: dict, user=Depends(get_verified_user)
): ):
if request.state.direct and request.state.model: if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = { models = {
request.state.model["id"]: request.state.model, request.state.model["id"]: request.state.model,
} }

View File

@ -165,7 +165,7 @@ async def generate_chat_completion(
if BYPASS_MODEL_ACCESS_CONTROL: if BYPASS_MODEL_ACCESS_CONTROL:
bypass_filter = True bypass_filter = True
if request.state.direct and request.state.model: if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = { models = {
request.state.model["id"]: request.state.model, request.state.model["id"]: request.state.model,
} }
@ -284,7 +284,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
if not request.app.state.MODELS: if not request.app.state.MODELS:
await get_all_models(request) await get_all_models(request)
if request.state.direct and request.state.model: if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = { models = {
request.state.model["id"]: request.state.model, request.state.model["id"]: request.state.model,
} }
@ -350,7 +350,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
if not request.app.state.MODELS: if not request.app.state.MODELS:
await get_all_models(request) await get_all_models(request)
if request.state.direct and request.state.model: if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = { models = {
request.state.model["id"]: request.state.model, request.state.model["id"]: request.state.model,
} }

View File

@ -622,7 +622,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
# Initialize events to store additional event to be sent to the client # Initialize events to store additional event to be sent to the client
# Initialize contexts and citation # Initialize contexts and citation
if request.state.direct and request.state.model: if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = { models = {
request.state.model["id"]: request.state.model, request.state.model["id"]: request.state.model,
} }

View File

@ -273,6 +273,11 @@
const API_CONFIG = directConnections.OPENAI_API_CONFIGS[urlIdx]; const API_CONFIG = directConnections.OPENAI_API_CONFIGS[urlIdx];
try { try {
if (API_CONFIG?.prefix_id) {
const prefixId = API_CONFIG.prefix_id;
form_data['model'] = form_data['model'].replace(`${prefixId}.`, ``);
}
const [res, controller] = await chatCompletion( const [res, controller] = await chatCompletion(
OPENAI_API_KEY, OPENAI_API_KEY,
form_data, form_data,