Skip to content

Commit 1cd4e45

Browse files
committed
Add API-level validation for generation parameters and guard retriever initialization
Invalid generation parameters such as negative top_k were previously forwarded to the model layer, which could result in runtime errors. This change adds ge/le constraints to generation parameters in PromptedLLMRequest so that invalid inputs are rejected with a 422 validation error at the API layer. Additionally, retriever initialization is now guarded to prevent startup errors when DOC_PATHS is empty.
1 parent 8d5ead8 commit 1cd4e45

File tree

1 file changed

+44
-13
lines changed

1 file changed

+44
-13
lines changed

app/routes/api.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44
from fastapi import APIRouter, Depends, HTTPException, Header, Query, Request
55
from sqlalchemy.orm import Session
6-
from pydantic import BaseModel, Field
6+
from pydantic import BaseModel, Field
77
import time
88
import logging
99
import os
@@ -21,17 +21,47 @@ class ChatMessage(BaseModel):
2121
content: str
2222

2323
class PromptedLLMRequest(BaseModel):
    """Request model for the ask-llm-prompted endpoint.

    Generation parameters carry ``ge``/``le`` constraints so that invalid
    values (e.g. a negative ``top_k``) are rejected with a 422 validation
    error at the API layer instead of being forwarded to the model layer.
    """

    # Chat mode: when True, `messages` is expected; otherwise `question`
    # (optionally with `custom_prompt`) is used. NOTE(review): the
    # cross-field requirement is presumably enforced in the route handler —
    # confirm, since no validator is visible here.
    chat: bool = False
    question: Optional[str] = None
    custom_prompt: Optional[str] = None
    messages: Optional[List[ChatMessage]] = None

    # Generation parameters — bounds enforced by Pydantic field validation.
    max_length: int = Field(
        1024,
        ge=1,
        le=2048,
        description="Must be between 1 and 2048",
    )

    # Whether to truncate over-long input before generation.
    truncation: bool = True

    repetition_penalty: float = Field(
        1.1,
        ge=1.0,
        le=1.5,
        description="Must be between 1.0 and 1.5",
    )

    temperature: float = Field(
        0.7,
        ge=0.0,
        le=1.0,
        description="Must be between 0 and 1",
    )

    top_p: float = Field(
        0.9,
        ge=0.0,
        le=1.0,
        description="Must be between 0 and 1",
    )

    top_k: int = Field(
        50,
        ge=1,
        le=100,
        description="Must be between 1 and 100",
    )
3565

3666
router = APIRouter(tags=["api"])
3767

@@ -40,7 +70,8 @@ class PromptedLLMRequest(BaseModel):
4070

4171
# load ai agent and document paths
agent = RAGAgent(model=settings.DEFAULT_MODEL)
# Guard: only build the vector store when document paths are configured;
# an empty DOC_PATHS previously caused startup errors in setup_vectorstore.
# When the guard is skipped, `agent.retriever` keeps whatever default
# RAGAgent assigns — presumably None; verify downstream handlers tolerate it.
if settings.DOC_PATHS:
    agent.retriever = agent.setup_vectorstore(settings.DOC_PATHS)

# user quotas tracking
# Module-level, in-memory store keyed by user; contents reset on restart.
user_quotas: Dict[str, Dict] = {}

0 commit comments

Comments
 (0)