from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import requests
import json
import logging
from datetime import datetime
import uuid
from typing import Optional

from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

# --- Configuration (Same as before) ---
DB_PATH = "dune_db"
EMBEDDING_MODEL_NAME = "nomic-ai/nomic-embed-text-v1.5"
OLLAMA_API_URL = "http://localhost:11434/api/generate"
OLLAMA_MODEL = "llama3:8b"
PROMPT_TEMPLATE = """
You are an expert lore master for the Dune universe.
Your task is to answer the user's question with as much detail and context as possible, based *only* on the provided text excerpts.
If you don't know the answer from the context provided, just say that you don't know; don't try to make up an answer.

Combine all the relevant information from the context below into a single, cohesive, and comprehensive answer.
Do not break the answer into sections based on the source texts. Synthesize them.
Do not start with "based on the context provided".
The answer should be thorough and well-explained.

CONTEXT:
{context}

QUESTION:
{question}

ANSWER:
"""

# --- Logging setup to match uvicorn format ---
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s: %(name)s: %(message)s'
)
logger = logging.getLogger("dune_api")

# --- Conversation storage (in-memory for simplicity) ---
conversations = {}  # session_id -> {title, created, exchanges: []}
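# NOTE: a module-level dict is per-process and non-persistent: restarting the
# server drops all conversations, and running uvicorn with multiple workers
# would give each worker its own separate store.
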
# --- Pydantic Models ---
class AskRequest(BaseModel):
    question: str

class ConversationRequest(BaseModel):
    question: str
    session_id: Optional[str] = None

class CreateConversationRequest(BaseModel):
    title: str

# --- Initialize FastAPI and load resources (Same as before) ---
app = FastAPI(
    title="Dune Expert API",
    description="Ask questions about the Dune universe and get expert answers",
    version="1.0.0"
)

# --- CORS middleware for React app ---
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:5173", "http://127.0.0.1:5173"],  # Vite dev server
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    # Expose the custom header set by /ask-conversation; non-safelisted
    # response headers are otherwise hidden from cross-origin JavaScript,
    # so the React client could never read the session id.
    expose_headers=["X-Session-ID"],
)

logger.info("Initializing Dune Expert API...")
|
|
|
|
try:
|
|
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={'trust_remote_code': True})
|
|
vector_store = Chroma(persist_directory=DB_PATH, embedding_function=embeddings)
|
|
retriever = vector_store.as_retriever(search_kwargs={"k": 8})
|
|
logger.info("Successfully loaded embeddings and vector store")
|
|
except Exception as e:
|
|
logger.error(f"Failed to initialize vector store: {e}")
|
|
raise RuntimeError(f"Could not initialize vector store: {e}")
|
|
|
|
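# NOTE: this assumes the Chroma store at DB_PATH was already populated by a
# separate ingestion step; opening an empty directory succeeds, but the
# retriever would simply return no context.
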
# --- Health check endpoint ---
@app.get("/health")
async def health_check():
    return {"status": "healthy", "service": "dune-expert-api", "timestamp": datetime.now().isoformat()}

# --- Conversation management endpoints ---
@app.post("/conversation/start")
async def start_conversation(request: CreateConversationRequest):
    """Start a new conversation with a title and return a session ID"""
    if not request.title.strip():
        raise HTTPException(status_code=400, detail="Title cannot be empty")

    if len(request.title) > 100:
        raise HTTPException(status_code=400, detail="Title too long (max 100 chars)")

    session_id = str(uuid.uuid4())
    conversations[session_id] = {
        "title": request.title.strip(),
        "created": datetime.now().isoformat(),
        "exchanges": []
    }
    logger.info(f"Started new conversation: {session_id} - '{request.title}'")
    return {"session_id": session_id, "title": request.title.strip()}
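# A minimal usage sketch (assuming the default uvicorn port 8000):
#
#   curl -X POST http://localhost:8000/conversation/start \
#        -H "Content-Type: application/json" \
#        -d '{"title": "Spice questions"}'
#
# -> {"session_id": "<uuid4>", "title": "Spice questions"}
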
@app.get("/conversation/{session_id}/history")
|
|
async def get_conversation_history(session_id: str):
|
|
"""Get the history of a conversation"""
|
|
if session_id not in conversations:
|
|
raise HTTPException(status_code=404, detail="Conversation not found")
|
|
|
|
conversation = conversations[session_id]
|
|
return {
|
|
"session_id": session_id,
|
|
"title": conversation["title"],
|
|
"created": conversation["created"],
|
|
"exchange_count": len(conversation["exchanges"]),
|
|
"history": conversation["exchanges"]
|
|
}
|
|
|
|
@app.delete("/conversation/{session_id}")
|
|
async def clear_conversation(session_id: str):
|
|
"""Clear a conversation's history"""
|
|
if session_id not in conversations:
|
|
raise HTTPException(status_code=404, detail="Conversation not found")
|
|
|
|
title = conversations[session_id]["title"]
|
|
conversations[session_id]["exchanges"] = []
|
|
logger.info(f"Cleared conversation: {session_id} - '{title}'")
|
|
return {"message": "Conversation cleared", "title": title}
|
|
|
|
@app.get("/conversations")
|
|
async def get_all_conversations():
|
|
"""Get all stored conversations with their metadata"""
|
|
conversation_summary = []
|
|
|
|
for session_id, conversation in conversations.items():
|
|
exchanges = conversation["exchanges"]
|
|
if exchanges: # Only include conversations that have exchanges
|
|
first_question = exchanges[0]["question"]
|
|
last_question = exchanges[-1]["question"]
|
|
conversation_summary.append({
|
|
"session_id": session_id,
|
|
"title": conversation["title"],
|
|
"created": conversation["created"],
|
|
"exchange_count": len(exchanges),
|
|
"first_question": first_question[:100] + "..." if len(first_question) > 100 else first_question,
|
|
"last_question": last_question[:100] + "..." if len(last_question) > 100 else last_question
|
|
})
|
|
else:
|
|
# Include empty conversations too, but with different info
|
|
conversation_summary.append({
|
|
"session_id": session_id,
|
|
"title": conversation["title"],
|
|
"created": conversation["created"],
|
|
"exchange_count": 0,
|
|
"first_question": None,
|
|
"last_question": None
|
|
})
|
|
|
|
logger.info(f"Retrieved {len(conversation_summary)} conversations")
|
|
return {
|
|
"total_conversations": len(conversation_summary),
|
|
"conversations": conversation_summary
|
|
}
|
|
|
|
# --- NEW: The Streaming Endpoint ---
@app.post("/ask-stream")
async def ask_question_stream(request: AskRequest):
    # Basic input validation
    if not request.question.strip():
        logger.warning("Empty question received")
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    if len(request.question) > 1000:
        logger.warning(f"Question too long: {len(request.question)} characters")
        raise HTTPException(status_code=400, detail="Question too long (max 1000 chars)")

    logger.info(f"🔍 Streaming request for: {request.question[:50]}{'...' if len(request.question) > 50 else ''}")

    # 1. Retrieve context (this part is still blocking)
    try:
        retrieved_docs = retriever.invoke(request.question)
        context = "\n\n---\n\n".join([doc.page_content for doc in retrieved_docs])
        prompt = PROMPT_TEMPLATE.format(context=context, question=request.question)
        logger.info(f"Retrieved {len(retrieved_docs)} documents for context")
    except Exception as e:
        logger.error(f"Failed to retrieve context: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve context from knowledge base")

    # 2. Define the generator for the streaming response.
    # A plain (sync) generator is used deliberately: it makes blocking
    # `requests` calls, and Starlette iterates sync generators in a
    # threadpool, whereas an `async def` generator here would block the
    # event loop on every read.
    def stream_generator():
        try:
            ollama_payload = {
                "model": OLLAMA_MODEL,
                "prompt": prompt,
                "stream": True  # <-- The key change to enable streaming from Ollama
            }
            logger.info(f"Sending request to Ollama with model: {OLLAMA_MODEL}")

            # Use stream=True to get a streaming response from requests
            with requests.post(OLLAMA_API_URL, json=ollama_payload, stream=True, timeout=30) as response:
                response.raise_for_status()
                # Ollama streams newline-delimited JSON objects, roughly:
                #   {"model": "llama3:8b", "response": "<token>", "done": false}
                # with a final chunk where "done" is true.
                for line in response.iter_lines():
                    if line:
                        try:
                            chunk = json.loads(line)
                            # Yield the actual text part of the token
                            llm_response = chunk.get("response", "")
                            if llm_response:  # Only log/yield non-empty responses
                                logger.debug(f"LLM response chunk: {llm_response[:20]}...")
                                yield llm_response
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to parse JSON chunk: {line}")
                            continue

        except requests.exceptions.Timeout:
            error_msg = "Request to language model timed out"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except requests.exceptions.ConnectionError:
            error_msg = "Could not connect to the language model. Is Ollama running?"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except requests.RequestException as e:
            error_msg = f"Error communicating with Ollama: {e}"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except Exception as e:
            error_msg = f"An unexpected error occurred: {e}"
            logger.error(error_msg)
            yield f"Error: {error_msg}"

    # 3. Return the generator wrapped in a StreamingResponse
    return StreamingResponse(stream_generator(), media_type="text/plain")
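# A minimal client-side sketch for consuming the stream (not part of the API;
# assumes the server runs on localhost:8000 under uvicorn):
#
#   import requests
#   with requests.post("http://localhost:8000/ask-stream",
#                      json={"question": "Who are the Fremen?"},
#                      stream=True) as r:
#       r.raise_for_status()
#       for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)
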
# --- Conversation-enabled streaming endpoint ---
@app.post("/ask-conversation")
async def ask_question_with_conversation(request: ConversationRequest):
    # Basic input validation
    if not request.question.strip():
        logger.warning("Empty question received")
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    if len(request.question) > 1000:
        logger.warning(f"Question too long: {len(request.question)} characters")
        raise HTTPException(status_code=400, detail="Question too long (max 1000 chars)")

    # Handle session
    session_id = request.session_id
    if session_id and session_id not in conversations:
        logger.warning(f"Unknown session ID: {session_id}")
        raise HTTPException(status_code=404, detail="Conversation session not found")
    elif not session_id:
        # Cannot create a session without a title in this endpoint
        raise HTTPException(status_code=400, detail="Session ID required. Use /conversation/start to create a new conversation with a title.")

    logger.info(f"🔍 Conversation request [{session_id[:8]}...]: {request.question[:50]}{'...' if len(request.question) > 50 else ''}")

    # 1. Retrieve context from vector store
    try:
        retrieved_docs = retriever.invoke(request.question)
        context = "\n\n---\n\n".join([doc.page_content for doc in retrieved_docs])
        logger.info(f"Retrieved {len(retrieved_docs)} documents for context")
    except Exception as e:
        logger.error(f"Failed to retrieve context: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve context from knowledge base")

    # 2. Build conversation context from history
    conversation_data = conversations[session_id]
    conversation_history = conversation_data["exchanges"]
    conversation_context = ""
    if conversation_history:
        conversation_context = "\n\nPREVIOUS CONVERSATION:\n"
        # Include the last 3 exchanges to keep the prompt size manageable;
        # stored answers are truncated to 200 characters for the same reason.
        for exchange in conversation_history[-3:]:
            conversation_context += f"Human: {exchange['question']}\nAssistant: {exchange['answer'][:200]}{'...' if len(exchange['answer']) > 200 else ''}\n\n"
        conversation_context += "CURRENT QUESTION:\n"

    # 3. Create enhanced prompt with conversation context
    enhanced_prompt = f"""
You are an expert lore master for the Dune universe.
Your task is to answer the user's question with as much detail and context as possible, based *only* on the provided text excerpts.
If you don't know the answer from the context provided, just say that you don't know; don't try to make up an answer.

Pay attention to the conversation history if provided - the user might be asking follow-up questions or referring to previous topics.

Combine all the relevant information from the context below into a single, cohesive, and comprehensive answer.
Do not break the answer into sections based on the source texts. Synthesize them.
Do not start with "based on the context provided".
The answer should be thorough and well-explained.

CONTEXT FROM DUNE BOOKS:
{context}
{conversation_context}
QUESTION:
{request.question}

ANSWER:
"""

    # 4. Collect the full response for conversation storage
    full_response = ""

    # 5. Define the generator for the streaming response (sync, for the same
    # threadpool reason as in /ask-stream)
    def stream_generator():
        nonlocal full_response
        try:
            ollama_payload = {
                "model": OLLAMA_MODEL,
                "prompt": enhanced_prompt,
                "stream": True
            }
            logger.info(f"Sending request to Ollama with model: {OLLAMA_MODEL}")

            with requests.post(OLLAMA_API_URL, json=ollama_payload, stream=True, timeout=30) as response:
                response.raise_for_status()

                for line in response.iter_lines():
                    if line:
                        try:
                            chunk = json.loads(line)
                            llm_response = chunk.get("response", "")
                            if llm_response:
                                full_response += llm_response
                                yield llm_response
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to parse JSON chunk: {line}")
                            continue

            # Store the complete exchange in conversation history
            conversations[session_id]["exchanges"].append({
                "question": request.question,
                "answer": full_response
            })
            exchange_count = len(conversations[session_id]["exchanges"])
            logger.info(f"Stored exchange in conversation {session_id[:8]}... (total: {exchange_count} exchanges)")

            # Keep conversation history manageable (max 10 exchanges)
            if exchange_count > 10:
                conversations[session_id]["exchanges"] = conversations[session_id]["exchanges"][-10:]

        except requests.exceptions.Timeout:
            error_msg = "Request to language model timed out"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except requests.exceptions.ConnectionError:
            error_msg = "Could not connect to the language model. Is Ollama running?"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except requests.RequestException as e:
            error_msg = f"Error communicating with Ollama: {e}"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except Exception as e:
            error_msg = f"An unexpected error occurred: {e}"
            logger.error(error_msg)
            yield f"Error: {error_msg}"

    # 6. Return the response with session info in headers
    response = StreamingResponse(stream_generator(), media_type="text/plain")
    response.headers["X-Session-ID"] = session_id
    return response
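
# To run (assuming this file is saved as main.py):
#
#   uvicorn main:app --reload
#
# A follow-up question reuses the session id returned by /conversation/start;
# the same id is echoed back in the X-Session-ID response header. The -N flag
# keeps curl from buffering the streamed answer:
#
#   curl -N -X POST http://localhost:8000/ask-conversation \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Who is Paul Atreides?", "session_id": "<uuid4>"}'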