refactor: move app directories
llm/main.py (new file, 369 lines)
@@ -0,0 +1,369 @@
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import requests
import json
import logging
from datetime import datetime
import uuid
from typing import Optional

from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

# --- Configuration (Same as before) ---
DB_PATH = "dune_db"
EMBEDDING_MODEL_NAME = "nomic-ai/nomic-embed-text-v1.5"
OLLAMA_API_URL = "http://localhost:11434/api/generate"
OLLAMA_MODEL = "llama3:8b"
PROMPT_TEMPLATE = """
You are an expert lore master for the Dune universe.
Your task is to answer the user's question with as much detail and context as possible, based *only* on the provided text excerpts.
If you don't know the answer from the context provided, just say that you don't know; don't try to make up an answer.

Combine all the relevant information from the context below into a single, cohesive, and comprehensive answer.
Do not break the answer into sections based on the source texts. Synthesize them.
Do not start with "based on the context provided".
The answer should be thorough and well-explained.

CONTEXT:
{context}

QUESTION:
{question}

ANSWER:
"""
# --- Logging setup to match uvicorn format ---
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s: %(name)s: %(message)s'
)
logger = logging.getLogger("dune_api")

# --- Conversation storage (in-memory for simplicity) ---
conversations = {}  # session_id -> {title, created, exchanges: []}
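# NOTE: sessions live only in this process's memory, so restarting the server clears them.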

# --- Pydantic Models ---
class AskRequest(BaseModel):
    question: str

class ConversationRequest(BaseModel):
    question: str
    session_id: Optional[str] = None

class CreateConversationRequest(BaseModel):
    title: str

# --- Initialize FastAPI and load resources (Same as before) ---
app = FastAPI(
    title="Dune Expert API",
    description="Ask questions about the Dune universe and get expert answers",
    version="1.0.0"
)
# --- CORS middleware for React app ---
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:5173", "http://127.0.0.1:5173"],  # Vite dev server
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["X-Session-ID"],  # without this, browser JS cannot read the session header set below
)

logger.info("Initializing Dune Expert API...")

try:
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={'trust_remote_code': True})
    vector_store = Chroma(persist_directory=DB_PATH, embedding_function=embeddings)
    retriever = vector_store.as_retriever(search_kwargs={"k": 8})
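    # k=8: each query retrieves the eight most similar chunks from the vector store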
    logger.info("Successfully loaded embeddings and vector store")
except Exception as e:
    logger.error(f"Failed to initialize vector store: {e}")
    raise RuntimeError(f"Could not initialize vector store: {e}")

# --- Health check endpoint ---
@app.get("/health")
async def health_check():
    return {"status": "healthy", "service": "dune-expert-api", "timestamp": datetime.now().isoformat()}

# --- Conversation management endpoints ---
@app.post("/conversation/start")
async def start_conversation(request: CreateConversationRequest):
    """Start a new conversation with a title and return a session ID"""
    if not request.title.strip():
        raise HTTPException(status_code=400, detail="Title cannot be empty")

    if len(request.title) > 100:
        raise HTTPException(status_code=400, detail="Title too long (max 100 chars)")

    session_id = str(uuid.uuid4())
    conversations[session_id] = {
        "title": request.title.strip(),
        "created": datetime.now().isoformat(),
        "exchanges": []
    }
    logger.info(f"Started new conversation: {session_id} - '{request.title}'")
    return {"session_id": session_id, "title": request.title.strip()}

@app.get("/conversation/{session_id}/history")
async def get_conversation_history(session_id: str):
    """Get the history of a conversation"""
    if session_id not in conversations:
        raise HTTPException(status_code=404, detail="Conversation not found")

    conversation = conversations[session_id]
    return {
        "session_id": session_id,
        "title": conversation["title"],
        "created": conversation["created"],
        "exchange_count": len(conversation["exchanges"]),
        "history": conversation["exchanges"]
    }

@app.delete("/conversation/{session_id}")
async def clear_conversation(session_id: str):
    """Clear a conversation's history (the session and its title are kept)"""
    if session_id not in conversations:
        raise HTTPException(status_code=404, detail="Conversation not found")

    title = conversations[session_id]["title"]
    conversations[session_id]["exchanges"] = []
    logger.info(f"Cleared conversation: {session_id} - '{title}'")
    return {"message": "Conversation cleared", "title": title}

@app.get("/conversations")
async def get_all_conversations():
    """Get all stored conversations with their metadata"""
    conversation_summary = []

    for session_id, conversation in conversations.items():
        exchanges = conversation["exchanges"]
        if exchanges:  # Only include question previews for conversations that have exchanges
            first_question = exchanges[0]["question"]
            last_question = exchanges[-1]["question"]
            conversation_summary.append({
                "session_id": session_id,
                "title": conversation["title"],
                "created": conversation["created"],
                "exchange_count": len(exchanges),
                "first_question": first_question[:100] + "..." if len(first_question) > 100 else first_question,
                "last_question": last_question[:100] + "..." if len(last_question) > 100 else last_question
            })
        else:
            # Include empty conversations too, but without question previews
            conversation_summary.append({
                "session_id": session_id,
                "title": conversation["title"],
                "created": conversation["created"],
                "exchange_count": 0,
                "first_question": None,
                "last_question": None
            })

    logger.info(f"Retrieved {len(conversation_summary)} conversations")
    return {
        "total_conversations": len(conversation_summary),
        "conversations": conversation_summary
    }

# --- NEW: The Streaming Endpoint ---
@app.post("/ask-stream")
async def ask_question_stream(request: AskRequest):
    # Basic input validation
    if not request.question.strip():
        logger.warning("Empty question received")
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    if len(request.question) > 1000:
        logger.warning(f"Question too long: {len(request.question)} characters")
        raise HTTPException(status_code=400, detail="Question too long (max 1000 chars)")

    logger.info(f"🔍 Streaming request for: {request.question[:50]}{'...' if len(request.question) > 50 else ''}")

    # 1. Retrieve context (this part is still blocking)
    try:
        retrieved_docs = retriever.invoke(request.question)
        context = "\n\n---\n\n".join([doc.page_content for doc in retrieved_docs])
        prompt = PROMPT_TEMPLATE.format(context=context, question=request.question)
        logger.info(f"Retrieved {len(retrieved_docs)} documents for context")
    except Exception as e:
        logger.error(f"Failed to retrieve context: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve context from knowledge base")

    # 2. Define the generator for the streaming response
    async def stream_generator():
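        # NOTE: requests is synchronous, so the reads below block the event loop
        # while waiting on Ollama; an async HTTP client (e.g. httpx) would avoid that.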
        try:
            ollama_payload = {
                "model": OLLAMA_MODEL,
                "prompt": prompt,
                "stream": True  # <-- The key change to enable streaming from Ollama
            }
            logger.info(f"Sending request to Ollama with model: {OLLAMA_MODEL}")

            # Use stream=True to get a streaming response from requests
            with requests.post(OLLAMA_API_URL, json=ollama_payload, stream=True, timeout=30) as response:
                response.raise_for_status()
                # Ollama streams JSON objects separated by newlines
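                # e.g. (fields abridged): {"model": "llama3:8b", "response": "Paul", "done": false}
                # followed by a final object with "done": true once generation ends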
                for line in response.iter_lines():
                    if line:
                        try:
                            chunk = json.loads(line)
                            # Yield the actual text part of the token
                            llm_response = chunk.get("response", "")
                            if llm_response:  # Only log non-empty responses
                                logger.debug(f"LLM response chunk: {llm_response[:20]}...")
                                yield llm_response
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to parse JSON chunk: {line}")
                            continue

        except requests.exceptions.Timeout:
            error_msg = "Request to language model timed out"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except requests.exceptions.ConnectionError:
            error_msg = "Could not connect to the language model. Is Ollama running?"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except requests.RequestException as e:
            error_msg = f"Error communicating with Ollama: {e}"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except Exception as e:
            error_msg = f"An unexpected error occurred: {e}"
            logger.error(error_msg)
            yield f"Error: {error_msg}"

    # 3. Return the generator wrapped in a StreamingResponse
    return StreamingResponse(stream_generator(), media_type="text/plain")

# --- Conversation-enabled streaming endpoint ---
@app.post("/ask-conversation")
async def ask_question_with_conversation(request: ConversationRequest):
    # Basic input validation
    if not request.question.strip():
        logger.warning("Empty question received")
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    if len(request.question) > 1000:
        logger.warning(f"Question too long: {len(request.question)} characters")
        raise HTTPException(status_code=400, detail="Question too long (max 1000 chars)")

    # Handle session
    session_id = request.session_id
    if session_id and session_id not in conversations:
        logger.warning(f"Unknown session ID: {session_id}")
        raise HTTPException(status_code=404, detail="Conversation session not found")
    elif not session_id:
        # Cannot create a session without a title in this endpoint
        raise HTTPException(status_code=400, detail="Session ID required. Use /conversation/start to create a new conversation with a title.")

    logger.info(f"🔍 Conversation request [{session_id[:8]}...]: {request.question[:50]}{'...' if len(request.question) > 50 else ''}")

    # 1. Retrieve context from vector store
    try:
        retrieved_docs = retriever.invoke(request.question)
        context = "\n\n---\n\n".join([doc.page_content for doc in retrieved_docs])
        logger.info(f"Retrieved {len(retrieved_docs)} documents for context")
    except Exception as e:
        logger.error(f"Failed to retrieve context: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve context from knowledge base")

    # 2. Build conversation context from history
    conversation_data = conversations[session_id]
    conversation_history = conversation_data["exchanges"]
    conversation_context = ""
    if conversation_history:
        conversation_context = "\n\nPREVIOUS CONVERSATION:\n"
        # Include the last 3 exchanges (answers truncated to 200 chars) to keep the prompt manageable
        for exchange in conversation_history[-3:]:
            conversation_context += f"Human: {exchange['question']}\nAssistant: {exchange['answer'][:200]}{'...' if len(exchange['answer']) > 200 else ''}\n\n"
        conversation_context += "CURRENT QUESTION:\n"

    # 3. Create an enhanced prompt with conversation context
    enhanced_prompt = f"""
You are an expert lore master for the Dune universe.
Your task is to answer the user's question with as much detail and context as possible, based *only* on the provided text excerpts.
If you don't know the answer from the context provided, just say that you don't know; don't try to make up an answer.

Pay attention to the conversation history if provided - the user might be asking follow-up questions or referring to previous topics.

Combine all the relevant information from the context below into a single, cohesive, and comprehensive answer.
Do not break the answer into sections based on the source texts. Synthesize them.
Do not start with "based on the context provided".
The answer should be thorough and well-explained.

CONTEXT FROM DUNE BOOKS:
{context}
{conversation_context}
QUESTION:
{request.question}

ANSWER:
"""

    # 4. Collect the full response for conversation storage
    full_response = ""

    # 5. Define the generator for the streaming response
    async def stream_generator():
        nonlocal full_response
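        # nonlocal lets the generator accumulate tokens into full_response from the
        # enclosing scope, so the complete answer can be stored once the stream ends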
        try:
            ollama_payload = {
                "model": OLLAMA_MODEL,
                "prompt": enhanced_prompt,
                "stream": True
            }
            logger.info(f"Sending request to Ollama with model: {OLLAMA_MODEL}")

            with requests.post(OLLAMA_API_URL, json=ollama_payload, stream=True, timeout=30) as response:
                response.raise_for_status()

                for line in response.iter_lines():
                    if line:
                        try:
                            chunk = json.loads(line)
                            llm_response = chunk.get("response", "")
                            if llm_response:
                                full_response += llm_response
                                yield llm_response
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to parse JSON chunk: {line}")
                            continue

            # Store the complete exchange in conversation history
            conversations[session_id]["exchanges"].append({
                "question": request.question,
                "answer": full_response
            })
            exchange_count = len(conversations[session_id]["exchanges"])
            logger.info(f"Stored exchange in conversation {session_id[:8]}... (total: {exchange_count} exchanges)")

            # Keep conversation history manageable (max 10 exchanges)
            if exchange_count > 10:
                conversations[session_id]["exchanges"] = conversations[session_id]["exchanges"][-10:]

        except requests.exceptions.Timeout:
            error_msg = "Request to language model timed out"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except requests.exceptions.ConnectionError:
            error_msg = "Could not connect to the language model. Is Ollama running?"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except requests.RequestException as e:
            error_msg = f"Error communicating with Ollama: {e}"
            logger.error(error_msg)
            yield f"Error: {error_msg}"
        except Exception as e:
            error_msg = f"An unexpected error occurred: {e}"
            logger.error(error_msg)
            yield f"Error: {error_msg}"

    # 6. Return the response with session info in headers
    response = StreamingResponse(stream_generator(), media_type="text/plain")
    response.headers["X-Session-ID"] = session_id
    return response
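
For quick manual testing, here is a minimal client sketch. It is not part of the commit: it assumes the API runs at http://localhost:8000 (uvicorn's default port), and the question and title values are placeholders.

import requests

BASE_URL = "http://localhost:8000"  # assumed local dev address, not defined in this commit

# 1. Create a titled session first (required by /ask-conversation)
session = requests.post(f"{BASE_URL}/conversation/start", json={"title": "Fremen lore"}).json()

# 2. Stream the answer token by token within that session
with requests.post(
    f"{BASE_URL}/ask-conversation",
    json={"question": "Who are the Fremen?", "session_id": session["session_id"]},
    stream=True,
) as resp:
    resp.raise_for_status()
    resp.encoding = "utf-8"  # the endpoint streams plain text; decode it as UTF-8
    for token in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(token, end="", flush=True)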