feat: move main API to python

src/llm/main.py | 259

@@ -3,6 +3,10 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import requests
import json
import logging
from datetime import datetime
import uuid
from typing import Optional

from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
@@ -19,6 +23,7 @@ If you don't know the answer from the context provided, just say that you don't

Combine all the relevant information from the context below into a single, cohesive, and comprehensive answer.
Do not break the answer into sections based on the source texts. Synthesize them.
Do not start with "based on the context provided".
The answer should be thorough and well-explained.

CONTEXT:
@@ -30,25 +35,95 @@ QUESTION:
ANSWER:
"""

# --- Pydantic Models (Same as before) ---
# --- Logging setup to match uvicorn format ---
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s: %(name)s: %(message)s'
)
logger = logging.getLogger("dune_api")

# --- Conversation storage (in-memory for simplicity) ---
conversations = {}  # session_id -> list of {question, answer} pairs

# --- Pydantic Models ---
class AskRequest(BaseModel):
    question: str

class ConversationRequest(BaseModel):
    question: str
    session_id: Optional[str] = None

# --- Initialize FastAPI and load resources (Same as before) ---
app = FastAPI()
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={'trust_remote_code': True})
vector_store = Chroma(persist_directory=DB_PATH, embedding_function=embeddings)
retriever = vector_store.as_retriever(search_kwargs={"k": 8})
app = FastAPI(
    title="Dune Expert API",
    description="Ask questions about the Dune universe and get expert answers",
    version="1.0.0"
)

logger.info("Initializing Dune Expert API...")

try:
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={'trust_remote_code': True})
    vector_store = Chroma(persist_directory=DB_PATH, embedding_function=embeddings)
    retriever = vector_store.as_retriever(search_kwargs={"k": 8})
    logger.info("Successfully loaded embeddings and vector store")
except Exception as e:
    logger.error(f"Failed to initialize vector store: {e}")
    raise RuntimeError(f"Could not initialize vector store: {e}")

# --- Health check endpoint ---
@app.get("/health")
async def health_check():
    return {"status": "healthy", "service": "dune-expert-api", "timestamp": datetime.now().isoformat()}

# --- Conversation management endpoints ---
@app.post("/conversation/start")
async def start_conversation():
    """Start a new conversation and return a session ID"""
    session_id = str(uuid.uuid4())
    conversations[session_id] = []
    logger.info(f"Started new conversation: {session_id}")
    return {"session_id": session_id}

@app.get("/conversation/{session_id}/history")
async def get_conversation_history(session_id: str):
    """Get the history of a conversation"""
    if session_id not in conversations:
        raise HTTPException(status_code=404, detail="Conversation not found")
    return {"session_id": session_id, "history": conversations[session_id]}

@app.delete("/conversation/{session_id}")
async def clear_conversation(session_id: str):
    """Clear a conversation's history"""
    if session_id not in conversations:
        raise HTTPException(status_code=404, detail="Conversation not found")
    conversations[session_id] = []
    logger.info(f"Cleared conversation: {session_id}")
    return {"message": "Conversation cleared"}

# --- NEW: The Streaming Endpoint ---
@app.post("/ask-stream")
async def ask_question_stream(request: AskRequest):
    print(f"🔍 Streaming request for: {request.question}")
    # Basic input validation
    if not request.question.strip():
        logger.warning("Empty question received")
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    if len(request.question) > 1000:
        logger.warning(f"Question too long: {len(request.question)} characters")
        raise HTTPException(status_code=400, detail="Question too long (max 1000 chars)")

    logger.info(f"🔍 Streaming request for: {request.question[:50]}{'...' if len(request.question) > 50 else ''}")

    # 1. Retrieve context (this part is still blocking)
    retrieved_docs = retriever.invoke(request.question)
    context = "\n\n---\n\n".join([doc.page_content for doc in retrieved_docs])
    prompt = PROMPT_TEMPLATE.format(context=context, question=request.question)
    try:
        retrieved_docs = retriever.invoke(request.question)
        context = "\n\n---\n\n".join([doc.page_content for doc in retrieved_docs])
        prompt = PROMPT_TEMPLATE.format(context=context, question=request.question)
        logger.info(f"Retrieved {len(retrieved_docs)} documents for context")
    except Exception as e:
        logger.error(f"Failed to retrieve context: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve context from knowledge base")

    # 2. Define the generator for the streaming response
    async def stream_generator():
@@ -58,21 +133,169 @@ async def ask_question_stream(request: AskRequest):
                "prompt": prompt,
                "stream": True  # <-- The key change to enable streaming from Ollama
            }
            logger.info(f"Sending request to Ollama with model: {OLLAMA_MODEL}")

            # Use stream=True to get a streaming response from requests
            with requests.post(OLLAMA_API_URL, json=ollama_payload, stream=True) as response:
            with requests.post(OLLAMA_API_URL, json=ollama_payload, stream=True, timeout=30) as response:
                response.raise_for_status()
                # Ollama streams JSON objects separated by newlines
                for line in response.iter_lines():
                    if line:
                        chunk = json.loads(line)
                        # Yield the actual text part of the token
                        yield chunk.get("response", "")
                        try:
                            chunk = json.loads(line)
                            # Yield the actual text part of the token
                            llm_response = chunk.get("response", "")
                            if llm_response:  # Only log non-empty responses
                                logger.debug(f"LLM response chunk: {llm_response[:20]}...")
                            yield llm_response
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to parse JSON chunk: {line}")
                            continue

        except requests.exceptions.Timeout:
            error_msg = "Request to language model timed out"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except requests.exceptions.ConnectionError:
            error_msg = "Could not connect to the language model. Is Ollama running?"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except requests.RequestException as e:
            print(f"❌ Error communicating with Ollama: {e}")
            yield "Error: Could not connect to the language model."
            error_msg = f"Error communicating with Ollama: {e}"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except Exception as e:
            print(f"❌ An unexpected error occurred: {e}")
            yield "Error: An unexpected error occurred while generating the answer."
            error_msg = f"An unexpected error occurred: {e}"
            logger.error(error_msg)
            yield f"Error: {error_msg}"

    # 3. Return the generator wrapped in a StreamingResponse
    return StreamingResponse(stream_generator(), media_type="text/plain")

# --- Conversation-enabled streaming endpoint ---
@app.post("/ask-conversation")
async def ask_question_with_conversation(request: ConversationRequest):
    # Basic input validation
    if not request.question.strip():
        logger.warning("Empty question received")
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    if len(request.question) > 1000:
        logger.warning(f"Question too long: {len(request.question)} characters")
        raise HTTPException(status_code=400, detail="Question too long (max 1000 chars)")

    # Handle session
    session_id = request.session_id
    if session_id and session_id not in conversations:
        logger.warning(f"Unknown session ID: {session_id}")
        raise HTTPException(status_code=404, detail="Conversation session not found")
    elif not session_id:
        # Create new session if none provided
        session_id = str(uuid.uuid4())
        conversations[session_id] = []
        logger.info(f"Created new conversation session: {session_id}")

    logger.info(f"🔍 Conversation request [{session_id[:8]}...]: {request.question[:50]}{'...' if len(request.question) > 50 else ''}")

    # 1. Retrieve context from vector store
    try:
        retrieved_docs = retriever.invoke(request.question)
        context = "\n\n---\n\n".join([doc.page_content for doc in retrieved_docs])
        logger.info(f"Retrieved {len(retrieved_docs)} documents for context")
    except Exception as e:
        logger.error(f"Failed to retrieve context: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve context from knowledge base")

    # 2. Build conversation context from history
    conversation_history = conversations[session_id]
    conversation_context = ""
    if conversation_history:
        conversation_context = "\n\nPREVIOUS CONVERSATION:\n"
        # Include last 3 exchanges to keep context manageable
        for exchange in conversation_history[-3:]:
            conversation_context += f"Human: {exchange['question']}\nAssistant: {exchange['answer'][:200]}{'...' if len(exchange['answer']) > 200 else ''}\n\n"
        conversation_context += "CURRENT QUESTION:\n"

    # 3. Create enhanced prompt with conversation context
    enhanced_prompt = f"""
You are an expert lore master for the Dune universe.
Your task is to answer the user's question with as much detail and context as possible, based *only* on the provided text excerpts.
If you don't know the answer from the context provided, just say that you don't know, don't try to make up an answer.

Pay attention to the conversation history if provided - the user might be asking follow-up questions or referring to previous topics.

Combine all the relevant information from the context below into a single, cohesive, and comprehensive answer.
Do not break the answer into sections based on the source texts. Synthesize them.
Do not start with "based on the context provided".
The answer should be thorough and well-explained.

CONTEXT FROM DUNE BOOKS:
{context}
{conversation_context}
QUESTION:
{request.question}

ANSWER:
"""

    # 4. Collect the full response for conversation storage
    full_response = ""

    # 5. Define the generator for the streaming response
    async def stream_generator():
        nonlocal full_response
        try:
            ollama_payload = {
                "model": OLLAMA_MODEL,
                "prompt": enhanced_prompt,
                "stream": True
            }
            logger.info(f"Sending request to Ollama with model: {OLLAMA_MODEL}")

            with requests.post(OLLAMA_API_URL, json=ollama_payload, stream=True, timeout=30) as response:
                response.raise_for_status()

                for line in response.iter_lines():
                    if line:
                        try:
                            chunk = json.loads(line)
                            llm_response = chunk.get("response", "")
                            if llm_response:
                                full_response += llm_response
                                yield llm_response
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to parse JSON chunk: {line}")
                            continue

            # Store the complete exchange in conversation history
            conversations[session_id].append({
                "question": request.question,
                "answer": full_response
            })
            logger.info(f"Stored exchange in conversation {session_id[:8]}... (total: {len(conversations[session_id])} exchanges)")

            # Keep conversation history manageable (max 10 exchanges)
            if len(conversations[session_id]) > 10:
                conversations[session_id] = conversations[session_id][-10:]

        except requests.exceptions.Timeout:
            error_msg = "Request to language model timed out"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except requests.exceptions.ConnectionError:
            error_msg = "Could not connect to the language model. Is Ollama running?"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except requests.RequestException as e:
            error_msg = f"Error communicating with Ollama: {e}"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except Exception as e:
            error_msg = f"An unexpected error occurred: {e}"
            logger.error(error_msg)
            yield f"Error: {error_msg}"

    # 6. Return the response with session info in headers
    response = StreamingResponse(stream_generator(), media_type="text/plain")
    response.headers["X-Session-ID"] = session_id
    return response
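
For reference, a minimal client sketch for the endpoints added in this commit. This is not part of the commit; it assumes the API is served locally by uvicorn at http://localhost:8000 (adjust BASE_URL for your deployment) and that Ollama is reachable from the server.

# usage_sketch.py — illustrative client for the streaming and conversation endpoints (hypothetical file, not in the repo)
import requests

BASE_URL = "http://localhost:8000"  # assumption: default local uvicorn address

# One-shot question against /ask-stream, printing tokens as they arrive
with requests.post(f"{BASE_URL}/ask-stream", json={"question": "Who are the Fremen?"}, stream=True) as r:
    r.raise_for_status()
    for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
print()

# Conversational flow: start a session, ask a question within it, then inspect the stored history
session_id = requests.post(f"{BASE_URL}/conversation/start").json()["session_id"]
with requests.post(
    f"{BASE_URL}/ask-conversation",
    json={"question": "What is the spice melange?", "session_id": session_id},
    stream=True,
) as r:
    r.raise_for_status()
    for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
print()

history = requests.get(f"{BASE_URL}/conversation/{session_id}/history").json()
print(history["history"])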