feat: add basic frontend
This commit is contained in:
106
src/llm/main.py
106
src/llm/main.py
@@ -1,5 +1,6 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
import requests
|
||||
import json
|
||||
@@ -43,7 +44,7 @@ logging.basicConfig(
|
||||
logger = logging.getLogger("dune_api")
|
||||
|
||||
# --- Conversation storage (in-memory for simplicity) ---
|
||||
conversations = {} # session_id -> list of {question, answer} pairs
|
||||
conversations = {} # session_id -> {title, created, exchanges: []}
|
||||
|
||||
# --- Pydantic Models ---
|
||||
class AskRequest(BaseModel):
|
||||
@@ -53,6 +54,9 @@ class ConversationRequest(BaseModel):
|
||||
question: str
|
||||
session_id: Optional[str] = None
|
||||
|
||||
class CreateConversationRequest(BaseModel):
|
||||
title: str
|
||||
|
||||
# --- Initialize FastAPI and load resources (Same as before) ---
|
||||
app = FastAPI(
|
||||
title="Dune Expert API",
|
||||
@@ -60,6 +64,15 @@ app = FastAPI(
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# --- CORS middleware for React app ---
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://localhost:5173", "http://127.0.0.1:5173"], # Vite dev server
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
logger.info("Initializing Dune Expert API...")
|
||||
|
||||
try:
|
||||
@@ -78,28 +91,83 @@ async def health_check():
|
||||
|
||||
# --- Conversation management endpoints ---
|
||||
@app.post("/conversation/start")
|
||||
async def start_conversation():
|
||||
"""Start a new conversation and return a session ID"""
|
||||
async def start_conversation(request: CreateConversationRequest):
|
||||
"""Start a new conversation with a title and return a session ID"""
|
||||
if not request.title.strip():
|
||||
raise HTTPException(status_code=400, detail="Title cannot be empty")
|
||||
|
||||
if len(request.title) > 100:
|
||||
raise HTTPException(status_code=400, detail="Title too long (max 100 chars)")
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
conversations[session_id] = []
|
||||
logger.info(f"Started new conversation: {session_id}")
|
||||
return {"session_id": session_id}
|
||||
conversations[session_id] = {
|
||||
"title": request.title.strip(),
|
||||
"created": datetime.now().isoformat(),
|
||||
"exchanges": []
|
||||
}
|
||||
logger.info(f"Started new conversation: {session_id} - '{request.title}'")
|
||||
return {"session_id": session_id, "title": request.title.strip()}
|
||||
|
||||
@app.get("/conversation/{session_id}/history")
|
||||
async def get_conversation_history(session_id: str):
|
||||
"""Get the history of a conversation"""
|
||||
if session_id not in conversations:
|
||||
raise HTTPException(status_code=404, detail="Conversation not found")
|
||||
return {"session_id": session_id, "history": conversations[session_id]}
|
||||
|
||||
conversation = conversations[session_id]
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"title": conversation["title"],
|
||||
"created": conversation["created"],
|
||||
"exchange_count": len(conversation["exchanges"]),
|
||||
"history": conversation["exchanges"]
|
||||
}
|
||||
|
||||
@app.delete("/conversation/{session_id}")
|
||||
async def clear_conversation(session_id: str):
|
||||
"""Clear a conversation's history"""
|
||||
if session_id not in conversations:
|
||||
raise HTTPException(status_code=404, detail="Conversation not found")
|
||||
conversations[session_id] = []
|
||||
logger.info(f"Cleared conversation: {session_id}")
|
||||
return {"message": "Conversation cleared"}
|
||||
|
||||
title = conversations[session_id]["title"]
|
||||
conversations[session_id]["exchanges"] = []
|
||||
logger.info(f"Cleared conversation: {session_id} - '{title}'")
|
||||
return {"message": "Conversation cleared", "title": title}
|
||||
|
||||
@app.get("/conversations")
|
||||
async def get_all_conversations():
|
||||
"""Get all stored conversations with their metadata"""
|
||||
conversation_summary = []
|
||||
|
||||
for session_id, conversation in conversations.items():
|
||||
exchanges = conversation["exchanges"]
|
||||
if exchanges: # Only include conversations that have exchanges
|
||||
first_question = exchanges[0]["question"]
|
||||
last_question = exchanges[-1]["question"]
|
||||
conversation_summary.append({
|
||||
"session_id": session_id,
|
||||
"title": conversation["title"],
|
||||
"created": conversation["created"],
|
||||
"exchange_count": len(exchanges),
|
||||
"first_question": first_question[:100] + "..." if len(first_question) > 100 else first_question,
|
||||
"last_question": last_question[:100] + "..." if len(last_question) > 100 else last_question
|
||||
})
|
||||
else:
|
||||
# Include empty conversations too, but with different info
|
||||
conversation_summary.append({
|
||||
"session_id": session_id,
|
||||
"title": conversation["title"],
|
||||
"created": conversation["created"],
|
||||
"exchange_count": 0,
|
||||
"first_question": None,
|
||||
"last_question": None
|
||||
})
|
||||
|
||||
logger.info(f"Retrieved {len(conversation_summary)} conversations")
|
||||
return {
|
||||
"total_conversations": len(conversation_summary),
|
||||
"conversations": conversation_summary
|
||||
}
|
||||
|
||||
# --- NEW: The Streaming Endpoint ---
|
||||
@app.post("/ask-stream")
|
||||
@@ -190,10 +258,8 @@ async def ask_question_with_conversation(request: ConversationRequest):
|
||||
logger.warning(f"Unknown session ID: {session_id}")
|
||||
raise HTTPException(status_code=404, detail="Conversation session not found")
|
||||
elif not session_id:
|
||||
# Create new session if none provided
|
||||
session_id = str(uuid.uuid4())
|
||||
conversations[session_id] = []
|
||||
logger.info(f"Created new conversation session: {session_id}")
|
||||
# Cannot create session without title in conversation endpoint
|
||||
raise HTTPException(status_code=400, detail="Session ID required. Use /conversation/start to create a new conversation with a title.")
|
||||
|
||||
logger.info(f"🔍 Conversation request [{session_id[:8]}...]: {request.question[:50]}{'...' if len(request.question) > 50 else ''}")
|
||||
|
||||
@@ -207,7 +273,8 @@ async def ask_question_with_conversation(request: ConversationRequest):
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve context from knowledge base")
|
||||
|
||||
# 2. Build conversation context from history
|
||||
conversation_history = conversations[session_id]
|
||||
conversation_data = conversations[session_id]
|
||||
conversation_history = conversation_data["exchanges"]
|
||||
conversation_context = ""
|
||||
if conversation_history:
|
||||
conversation_context = "\n\nPREVIOUS CONVERSATION:\n"
|
||||
@@ -268,15 +335,16 @@ ANSWER:
|
||||
continue
|
||||
|
||||
# Store the complete exchange in conversation history
|
||||
conversations[session_id].append({
|
||||
conversations[session_id]["exchanges"].append({
|
||||
"question": request.question,
|
||||
"answer": full_response
|
||||
})
|
||||
logger.info(f"Stored exchange in conversation {session_id[:8]}... (total: {len(conversations[session_id])} exchanges)")
|
||||
exchange_count = len(conversations[session_id]["exchanges"])
|
||||
logger.info(f"Stored exchange in conversation {session_id[:8]}... (total: {exchange_count} exchanges)")
|
||||
|
||||
# Keep conversation history manageable (max 10 exchanges)
|
||||
if len(conversations[session_id]) > 10:
|
||||
conversations[session_id] = conversations[session_id][-10:]
|
||||
if exchange_count > 10:
|
||||
conversations[session_id]["exchanges"] = conversations[session_id]["exchanges"][-10:]
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
error_msg = "Request to language model timed out"
|
||||
|
||||
Reference in New Issue
Block a user