# kwisatz-haderach/src/llm/main.py
# Python · 301 lines · 13 KiB

import json
import logging
import os
import uuid
from datetime import datetime
from typing import Optional

import requests
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from pydantic import BaseModel
# --- Configuration ---
# Deployment-specific settings can be overridden with environment variables
# (DUNE_DB_PATH, DUNE_EMBEDDING_MODEL, OLLAMA_API_URL, OLLAMA_MODEL); the
# literals below are the defaults used when a variable is unset.
# A relative DB_PATH is resolved against the process's working directory.
DB_PATH = os.environ.get("DUNE_DB_PATH", "dune_db")
# NOTE(review): presumably must match the embedding model the Chroma DB at
# DB_PATH was built with — confirm before overriding.
EMBEDDING_MODEL_NAME = os.environ.get(
    "DUNE_EMBEDDING_MODEL", "nomic-ai/nomic-embed-text-v1.5"
)
OLLAMA_API_URL = os.environ.get(
    "OLLAMA_API_URL", "http://localhost:11434/api/generate"
)
OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "llama3:8b")
# RAG prompt for the single-question endpoint; filled in with
# str.format(context=..., question=...).
PROMPT_TEMPLATE = """
You are an expert lore master for the Dune universe.
Your task is to answer the user's question with as much detail and context as possible, based *only* on the provided text excerpts.
If you don't know the answer from the context provided, just say that you don't know, don't try to make up an answer.
Combine all the relevant information from the context below into a single, cohesive, and comprehensive answer.
Do not break the answer into sections based on the source texts. Synthesize them.
Do not start with "based on the context provided".
The answer should be thorough and well-explained.
CONTEXT:
{context}
QUESTION:
{question}
ANSWER:
"""
# --- Logging ---
# Root logging config; records print as "LEVEL: logger-name: message",
# roughly in the spirit of uvicorn's console output.
logging.basicConfig(format='%(levelname)s: %(name)s: %(message)s', level=logging.INFO)
logger = logging.getLogger("dune_api")

# --- In-memory conversation store ---
# session_id (str) -> list of {"question": str, "answer": str} exchanges.
# Process-local only: lost on restart, and the endpoints in this module never
# remove a key (the DELETE endpoint only empties the list).
conversations = dict()
# --- Request body models ---
class AskRequest(BaseModel):
    """JSON body for POST /ask-stream: one stand-alone question."""

    question: str


class ConversationRequest(BaseModel):
    """JSON body for POST /ask-conversation.

    ``session_id`` may be omitted (or empty), in which case the endpoint
    starts a new conversation session.
    """

    question: str
    session_id: Optional[str] = None
# --- FastAPI application and retrieval resources (loaded once at import) ---
app = FastAPI(
    title="Dune Expert API",
    description="Ask questions about the Dune universe and get expert answers",
    version="1.0.0",
)

logger.info("Initializing Dune Expert API...")
try:
    # NOTE(review): trust_remote_code is presumably required because the nomic
    # embedding model ships custom modelling code — confirm.
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        model_kwargs={'trust_remote_code': True},
    )
    vector_store = Chroma(persist_directory=DB_PATH, embedding_function=embeddings)
    # Each query retrieves up to 8 documents.
    retriever = vector_store.as_retriever(search_kwargs={"k": 8})
except Exception as e:
    # logger.exception keeps the traceback that logger.error discarded, and
    # `from e` marks the original failure as the direct cause.
    logger.exception(f"Failed to initialize vector store: {e}")
    raise RuntimeError(f"Could not initialize vector store: {e}") from e
else:
    logger.info("Successfully loaded embeddings and vector store")
# --- Health check endpoint ---
@app.get("/health")
async def health_check():
    """Liveness probe: report that the API process is up.

    Returns a fixed status/service pair plus the server's current (naive,
    local) time in ISO-8601 form. It does not probe Ollama or the vector store.
    """
    timestamp = datetime.now().isoformat()
    return {"status": "healthy", "service": "dune-expert-api", "timestamp": timestamp}
# --- Conversation management endpoints ---
@app.post("/conversation/start")
async def start_conversation():
    """Open a new, empty conversation.

    Returns:
        ``{"session_id": <random UUID4 string>}`` to pass to later requests.
    """
    new_id = str(uuid.uuid4())
    conversations[new_id] = []
    logger.info(f"Started new conversation: {new_id}")
    return dict(session_id=new_id)
@app.get("/conversation/{session_id}/history")
async def get_conversation_history(session_id: str):
    """Return every stored question/answer exchange for ``session_id``.

    Raises:
        HTTPException: 404 when the session ID is unknown.
    """
    history = conversations.get(session_id)
    if history is None:
        raise HTTPException(status_code=404, detail="Conversation not found")
    return {"session_id": session_id, "history": history}
@app.delete("/conversation/{session_id}")
async def clear_conversation(session_id: str):
    """Empty a session's history; the session ID itself stays valid.

    Raises:
        HTTPException: 404 when the session ID is unknown.
    """
    if session_id in conversations:
        # Rebind to a fresh list rather than clearing the existing one in place.
        conversations[session_id] = []
        logger.info(f"Cleared conversation: {session_id}")
        return {"message": "Conversation cleared"}
    raise HTTPException(status_code=404, detail="Conversation not found")
# --- Streaming endpoint: one stand-alone question, no conversation memory ---
@app.post("/ask-stream")
async def ask_question_stream(request: AskRequest):
    """Answer a single question with retrieval-augmented generation, streamed as text.

    Context excerpts are fetched from the module-level ``retriever`` and
    substituted into ``PROMPT_TEMPLATE``. Ollama's tokens are then relayed to
    the client as ``text/plain`` chunks while they are generated.

    Raises:
        HTTPException: 400 if the question is blank or longer than 1000
            characters; 500 if retrieving context fails. Failures talking to
            Ollama occur after the response has started streaming, so they are
            reported in-band as a final ``"Error: ..."`` chunk instead.
    """
    question = request.question

    # Basic input validation
    if not question.strip():
        logger.warning("Empty question received")
        raise HTTPException(status_code=400, detail="Question cannot be empty")
    if len(question) > 1000:
        logger.warning(f"Question too long: {len(question)} characters")
        raise HTTPException(status_code=400, detail="Question too long (max 1000 chars)")
    logger.info(f"🔍 Streaming request for: {question[:50]}{'...' if len(question) > 50 else ''}")

    # 1. Retrieve context. The retriever is synchronous, so it runs in the
    #    worker thread pool instead of blocking the event loop for all clients.
    try:
        retrieved_docs = await run_in_threadpool(retriever.invoke, question)
        context = "\n\n---\n\n".join(doc.page_content for doc in retrieved_docs)
        prompt = PROMPT_TEMPLATE.format(context=context, question=question)
        logger.info(f"Retrieved {len(retrieved_docs)} documents for context")
    except Exception as e:
        logger.exception(f"Failed to retrieve context: {e}")
        raise HTTPException(
            status_code=500, detail="Failed to retrieve context from knowledge base"
        ) from e

    # 2. Relay Ollama's token stream. This is deliberately a plain (sync)
    #    generator: `requests` blocks, and StreamingResponse iterates sync
    #    iterators in a thread pool. Blocking inside an `async def` generator
    #    would stall the whole event loop for the entire generation.
    def stream_generator():
        ollama_payload = {
            "model": OLLAMA_MODEL,
            "prompt": prompt,
            "stream": True,  # Ollama emits one JSON object per line as it generates
        }
        try:
            logger.info(f"Sending request to Ollama with model: {OLLAMA_MODEL}")
            # With stream=True, timeout bounds the connect and each wait between
            # received bytes, not the total generation time.
            with requests.post(OLLAMA_API_URL, json=ollama_payload, stream=True, timeout=30) as response:
                response.raise_for_status()
                for line in response.iter_lines():
                    if not line:
                        continue
                    try:
                        chunk = json.loads(line)
                    except json.JSONDecodeError:
                        logger.warning(f"Failed to parse JSON chunk: {line}")
                        continue
                    # A failure after streaming began arrives as an "error" object
                    # in the stream rather than as an HTTP error status.
                    if chunk.get("error"):
                        error_msg = f"Ollama reported an error: {chunk['error']}"
                        logger.error(error_msg)
                        yield f"Error: {error_msg}"
                        return
                    llm_response = chunk.get("response", "")
                    if llm_response:  # skip empty (e.g. final "done") chunks
                        logger.debug(f"LLM response chunk: {llm_response[:20]}...")
                        yield llm_response
        except requests.exceptions.Timeout:
            error_msg = "Request to language model timed out"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except requests.exceptions.ConnectionError:
            error_msg = "Could not connect to the language model. Is Ollama running?"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except requests.RequestException as e:
            error_msg = f"Error communicating with Ollama: {e}"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except Exception as e:
            error_msg = f"An unexpected error occurred: {e}"
            logger.exception(error_msg)
            yield f"Error: {error_msg}"

    # 3. Return the generator wrapped in a StreamingResponse
    return StreamingResponse(stream_generator(), media_type="text/plain")
# --- Conversation-enabled streaming endpoint ---
# Prompt for /ask-conversation, filled in with str.format(). {conversation_context}
# is either "" or a "PREVIOUS CONVERSATION:" section built by the endpoint.
_CONVERSATION_PROMPT_TEMPLATE = """
You are an expert lore master for the Dune universe.
Your task is to answer the user's question with as much detail and context as possible, based *only* on the provided text excerpts.
If you don't know the answer from the context provided, just say that you don't know, don't try to make up an answer.
Pay attention to the conversation history if provided - the user might be asking follow-up questions or referring to previous topics.
Combine all the relevant information from the context below into a single, cohesive, and comprehensive answer.
Do not break the answer into sections based on the source texts. Synthesize them.
Do not start with "based on the context provided".
The answer should be thorough and well-explained.
CONTEXT FROM DUNE BOOKS:
{context}
{conversation_context}
QUESTION:
{question}
ANSWER:
"""


@app.post("/ask-conversation")
async def ask_question_with_conversation(request: ConversationRequest):
    """Answer a question within a multi-turn session, streamed as plain text.

    If ``request.session_id`` is missing or empty, a new session is created.
    Either way, the session ID is returned in the ``X-Session-ID`` response
    header. The last three stored exchanges (answers clipped to 200
    characters) are added to the prompt. A fully streamed answer is appended
    to the session history, which then keeps at most the 10 most recent
    exchanges.

    Raises:
        HTTPException: 400 for a blank or over-long (>1000 chars) question,
            404 for an unknown ``session_id``, 500 if retrieving context fails.
            Ollama failures are reported in-band as an ``"Error: ..."`` chunk,
            and the exchange is then not stored.
    """
    question = request.question

    # Basic input validation
    if not question.strip():
        logger.warning("Empty question received")
        raise HTTPException(status_code=400, detail="Question cannot be empty")
    if len(question) > 1000:
        logger.warning(f"Question too long: {len(question)} characters")
        raise HTTPException(status_code=400, detail="Question too long (max 1000 chars)")

    # Resolve the session: reject unknown IDs, open a new one when none is given.
    session_id = request.session_id
    if session_id and session_id not in conversations:
        logger.warning(f"Unknown session ID: {session_id}")
        raise HTTPException(status_code=404, detail="Conversation session not found")
    if not session_id:
        session_id = str(uuid.uuid4())
        conversations[session_id] = []
        logger.info(f"Created new conversation session: {session_id}")
    logger.info(f"🔍 Conversation request [{session_id[:8]}...]: {question[:50]}{'...' if len(question) > 50 else ''}")

    # 1. Retrieve context from the vector store. The retriever is synchronous,
    #    so it runs in the worker thread pool rather than on the event loop.
    try:
        retrieved_docs = await run_in_threadpool(retriever.invoke, question)
        context = "\n\n---\n\n".join(doc.page_content for doc in retrieved_docs)
        logger.info(f"Retrieved {len(retrieved_docs)} documents for context")
    except Exception as e:
        logger.exception(f"Failed to retrieve context: {e}")
        raise HTTPException(
            status_code=500, detail="Failed to retrieve context from knowledge base"
        ) from e

    # 2. Summarise the last 3 exchanges (answers clipped to 200 chars) to keep
    #    the prompt small. The question heading comes from the template itself.
    conversation_context = ""
    recent_exchanges = conversations[session_id][-3:]
    if recent_exchanges:
        exchange_texts = [
            f"Human: {exchange['question']}\n"
            f"Assistant: {exchange['answer'][:200]}{'...' if len(exchange['answer']) > 200 else ''}\n\n"
            for exchange in recent_exchanges
        ]
        conversation_context = "\n\nPREVIOUS CONVERSATION:\n" + "".join(exchange_texts)

    # 3. Build the prompt with conversation context.
    enhanced_prompt = _CONVERSATION_PROMPT_TEMPLATE.format(
        context=context,
        conversation_context=conversation_context,
        question=question,
    )

    # 4. Relay Ollama's token stream and record the finished exchange. This is
    #    deliberately a plain (sync) generator: `requests` blocks, and
    #    StreamingResponse iterates sync iterators in a thread pool. Blocking
    #    inside an `async def` generator would stall the whole event loop.
    def stream_generator():
        ollama_payload = {
            "model": OLLAMA_MODEL,
            "prompt": enhanced_prompt,
            "stream": True,
        }
        answer_parts = []
        try:
            logger.info(f"Sending request to Ollama with model: {OLLAMA_MODEL}")
            # With stream=True, timeout bounds the connect and each wait between
            # received bytes, not the total generation time.
            with requests.post(OLLAMA_API_URL, json=ollama_payload, stream=True, timeout=30) as response:
                response.raise_for_status()
                for line in response.iter_lines():
                    if not line:
                        continue
                    try:
                        chunk = json.loads(line)
                    except json.JSONDecodeError:
                        logger.warning(f"Failed to parse JSON chunk: {line}")
                        continue
                    # A failure after streaming began arrives as an "error" object
                    # in the stream; the partial answer is not stored.
                    if chunk.get("error"):
                        error_msg = f"Ollama reported an error: {chunk['error']}"
                        logger.error(error_msg)
                        yield f"Error: {error_msg}"
                        return
                    llm_response = chunk.get("response", "")
                    if llm_response:
                        answer_parts.append(llm_response)
                        yield llm_response

            # Only reached when the stream completed without an error.
            history = conversations[session_id]
            history.append({
                "question": question,
                "answer": "".join(answer_parts),
            })
            logger.info(f"Stored exchange in conversation {session_id[:8]}... (total: {len(history)} exchanges)")
            # Keep conversation history manageable (max 10 exchanges)
            if len(history) > 10:
                conversations[session_id] = history[-10:]
        except requests.exceptions.Timeout:
            error_msg = "Request to language model timed out"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except requests.exceptions.ConnectionError:
            error_msg = "Could not connect to the language model. Is Ollama running?"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except requests.RequestException as e:
            error_msg = f"Error communicating with Ollama: {e}"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except Exception as e:
            error_msg = f"An unexpected error occurred: {e}"
            logger.exception(error_msg)
            yield f"Error: {error_msg}"

    # 5. Return the stream with the session ID in a response header.
    response = StreamingResponse(stream_generator(), media_type="text/plain")
    response.headers["X-Session-ID"] = session_id
    return response