diff --git a/backend/main.py b/backend/main.py
index e896bf2..e33ce59 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -2,12 +2,15 @@
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from typing import List, Dict, Any
 import uuid
+import json
+import asyncio
 
 from . import storage
-from .council import run_full_council, generate_conversation_title
+from .council import run_full_council, generate_conversation_title, stage1_collect_responses, stage2_collect_rankings, stage3_synthesize_final, calculate_aggregate_rankings
 
 app = FastAPI(title="LLM Council API")
 
@@ -120,6 +123,77 @@ async def send_message(conversation_id: str, request: SendMessageRequest):
     }
 
 
+@app.post("/api/conversations/{conversation_id}/message/stream")
+async def send_message_stream(conversation_id: str, request: SendMessageRequest):
+    """
+    Send a message and stream the 3-stage council process.
+    Returns Server-Sent Events as each stage completes.
+    """
+    # Check if conversation exists
+    conversation = storage.get_conversation(conversation_id)
+    if conversation is None:
+        raise HTTPException(status_code=404, detail="Conversation not found")
+
+    # Check if this is the first message
+    is_first_message = len(conversation["messages"]) == 0
+
+    async def event_generator():
+        try:
+            # Add user message
+            storage.add_user_message(conversation_id, request.content)
+
+            # Start title generation in parallel (don't await yet)
+            title_task = None
+            if is_first_message:
+                title_task = asyncio.create_task(generate_conversation_title(request.content))
+
+            # Stage 1: Collect responses
+            yield f"data: {json.dumps({'type': 'stage1_start'})}\n\n"
+            stage1_results = await stage1_collect_responses(request.content)
+            yield f"data: {json.dumps({'type': 'stage1_complete', 'data': stage1_results})}\n\n"
+
+            # Stage 2: Collect rankings
+            yield f"data: {json.dumps({'type': 'stage2_start'})}\n\n"
+            stage2_results, label_to_model = await stage2_collect_rankings(request.content, stage1_results)
+            aggregate_rankings = calculate_aggregate_rankings(stage2_results, label_to_model)
+            yield f"data: {json.dumps({'type': 'stage2_complete', 'data': stage2_results, 'metadata': {'label_to_model': label_to_model, 'aggregate_rankings': aggregate_rankings}})}\n\n"
+
+            # Stage 3: Synthesize final answer
+            yield f"data: {json.dumps({'type': 'stage3_start'})}\n\n"
+            stage3_result = await stage3_synthesize_final(request.content, stage1_results, stage2_results)
+            yield f"data: {json.dumps({'type': 'stage3_complete', 'data': stage3_result})}\n\n"
+
+            # Wait for title generation if it was started
+            if title_task:
+                title = await title_task
+                storage.update_conversation_title(conversation_id, title)
+                yield f"data: {json.dumps({'type': 'title_complete', 'data': {'title': title}})}\n\n"
+
+            # Save complete assistant message
+            storage.add_assistant_message(
+                conversation_id,
+                stage1_results,
+                stage2_results,
+                stage3_result
+            )
+
+            # Send completion event
+            yield f"data: {json.dumps({'type': 'complete'})}\n\n"
+
+        except Exception as e:
+            # Send error event
+            yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
+
+    return StreamingResponse(
+        event_generator(),
+        media_type="text/event-stream",
+        headers={
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
+        }
+    )
+
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=8001)
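A quick way to sanity-check the new endpoint, assuming the backend is running locally on port 8001 and <conversation-id> is replaced with the id of an existing conversation, is to watch the raw stream with curl (-N disables output buffering):

curl -N -X POST http://localhost:8001/api/conversations/<conversation-id>/message/stream \
  -H "Content-Type: application/json" \
  -d '{"content": "Hello, council"}'

Each event arrives as a data: line of JSON; the expected order is stage1_start, stage1_complete, stage2_start, stage2_complete, stage3_start, stage3_complete, then title_complete (first message only) and finally complete.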
diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx
index 0396e3a..1954155 100644
--- a/frontend/src/App.jsx
+++ b/frontend/src/App.jsx
@@ -69,33 +69,114 @@ function App() {
         messages: [...prev.messages, userMessage],
       }));
 
-      // Send message and get council response
-      const response = await api.sendMessage(currentConversationId, content);
-
-      // Add assistant message to UI
+      // Create a partial assistant message that will be updated progressively
       const assistantMessage = {
         role: 'assistant',
-        stage1: response.stage1,
-        stage2: response.stage2,
-        stage3: response.stage3,
-        metadata: response.metadata,
+        stage1: null,
+        stage2: null,
+        stage3: null,
+        metadata: null,
+        loading: {
+          stage1: false,
+          stage2: false,
+          stage3: false,
+        },
       };
 
+      // Add the partial assistant message
       setCurrentConversation((prev) => ({
         ...prev,
         messages: [...prev.messages, assistantMessage],
       }));
 
-      // Reload conversations list to update message count
-      await loadConversations();
+      // Send message with streaming
+      await api.sendMessageStream(currentConversationId, content, (eventType, event) => {
+        switch (eventType) {
+          case 'stage1_start':
+            setCurrentConversation((prev) => {
+              const messages = [...prev.messages];
+              const lastMsg = messages[messages.length - 1];
+              lastMsg.loading.stage1 = true;
+              return { ...prev, messages };
+            });
+            break;
+
+          case 'stage1_complete':
+            setCurrentConversation((prev) => {
+              const messages = [...prev.messages];
+              const lastMsg = messages[messages.length - 1];
+              lastMsg.stage1 = event.data;
+              lastMsg.loading.stage1 = false;
+              return { ...prev, messages };
+            });
+            break;
+
+          case 'stage2_start':
+            setCurrentConversation((prev) => {
+              const messages = [...prev.messages];
+              const lastMsg = messages[messages.length - 1];
+              lastMsg.loading.stage2 = true;
+              return { ...prev, messages };
+            });
+            break;
+
+          case 'stage2_complete':
+            setCurrentConversation((prev) => {
+              const messages = [...prev.messages];
+              const lastMsg = messages[messages.length - 1];
+              lastMsg.stage2 = event.data;
+              lastMsg.metadata = event.metadata;
+              lastMsg.loading.stage2 = false;
+              return { ...prev, messages };
+            });
+            break;
+
+          case 'stage3_start':
+            setCurrentConversation((prev) => {
+              const messages = [...prev.messages];
+              const lastMsg = messages[messages.length - 1];
+              lastMsg.loading.stage3 = true;
+              return { ...prev, messages };
+            });
+            break;
+
+          case 'stage3_complete':
+            setCurrentConversation((prev) => {
+              const messages = [...prev.messages];
+              const lastMsg = messages[messages.length - 1];
+              lastMsg.stage3 = event.data;
+              lastMsg.loading.stage3 = false;
+              return { ...prev, messages };
+            });
+            break;
+
+          case 'title_complete':
+            // Reload conversations to get updated title
+            loadConversations();
+            break;
+
+          case 'complete':
+            // Stream complete, reload conversations list
+            loadConversations();
+            setIsLoading(false);
+            break;
+
+          case 'error':
+            console.error('Stream error:', event.message);
+            setIsLoading(false);
+            break;
+
+          default:
+            console.log('Unknown event type:', eventType);
+        }
+      });
     } catch (error) {
       console.error('Failed to send message:', error);
-      // Remove optimistic user message on error
+      // Remove optimistic messages on error
       setCurrentConversation((prev) => ({
         ...prev,
-        messages: prev.messages.slice(0, -1),
+        messages: prev.messages.slice(0, -2),
       }));
-    } finally {
       setIsLoading(false);
     }
   };
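The per-event handlers above mutate the last message object in place and rely on the new top-level state object to trigger a re-render. That works, but any child component memoized on message identity will not see the change. A fully immutable variant is a one-liner per case, sketched here assuming it lives in App.jsx where setCurrentConversation is in scope (updateLastMessage is a hypothetical helper, not part of the diff):

// Replace the last message with an updated copy instead of mutating it
const updateLastMessage = (updater) =>
  setCurrentConversation((prev) => {
    const messages = [...prev.messages];
    messages[messages.length - 1] = updater(messages[messages.length - 1]);
    return { ...prev, messages };
  });

// e.g. the 'stage1_complete' case becomes:
// updateLastMessage((m) => ({ ...m, stage1: event.data, loading: { ...m.loading, stage1: false } }));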
diff --git a/frontend/src/api.js b/frontend/src/api.js
index 479f0ef..87ec685 100644
--- a/frontend/src/api.js
+++ b/frontend/src/api.js
@@ -65,4 +65,51 @@ export const api = {
     }
     return response.json();
   },
+
+  /**
+   * Send a message and receive streaming updates.
+   * @param {string} conversationId - The conversation ID
+   * @param {string} content - The message content
+   * @param {function} onEvent - Callback function for each event: (eventType, data) => void
+   * @returns {Promise}
+   */
+  async sendMessageStream(conversationId, content, onEvent) {
+    const response = await fetch(
+      `${API_BASE}/api/conversations/${conversationId}/message/stream`,
+      {
+        method: 'POST',
+        headers: {
+          'Content-Type': 'application/json',
+        },
+        body: JSON.stringify({ content }),
+      }
+    );
+
+    if (!response.ok) {
+      throw new Error('Failed to send message');
+    }
+
+    const reader = response.body.getReader();
+    const decoder = new TextDecoder();
+
+    while (true) {
+      const { done, value } = await reader.read();
+      if (done) break;
+
+      const chunk = decoder.decode(value);
+      const lines = chunk.split('\n');
+
+      for (const line of lines) {
+        if (line.startsWith('data: ')) {
+          const data = line.slice(6);
+          try {
+            const event = JSON.parse(data);
+            onEvent(event.type, event);
+          } catch (e) {
+            console.error('Failed to parse SSE event:', e);
+          }
+        }
+      }
+    }
+  },
 };
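One caveat with the read loop in sendMessageStream: chunks returned by reader.read() are not guaranteed to line up with SSE event boundaries, so a data: line split across two chunks will fail to parse and that event is silently dropped. A minimal buffered variant of the same loop, reusing the reader, decoder, and onEvent names from the diff above (a sketch, not part of the patch):

// Accumulate chunks and only parse an event once its terminating blank line has arrived
let buffer = '';
while (true) {
  const { done, value } = await reader.read();
  if (done) break;
  buffer += decoder.decode(value, { stream: true });
  const events = buffer.split('\n\n');
  buffer = events.pop(); // keep the trailing, possibly incomplete event
  for (const evt of events) {
    for (const line of evt.split('\n')) {
      if (line.startsWith('data: ')) {
        const event = JSON.parse(line.slice(6));
        onEvent(event.type, event);
      }
    }
  }
}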
diff --git a/frontend/src/components/ChatInterface.css b/frontend/src/components/ChatInterface.css
index 531d2a3..0d01300 100644
--- a/frontend/src/components/ChatInterface.css
+++ b/frontend/src/components/ChatInterface.css
@@ -71,6 +71,20 @@
   font-size: 14px;
 }
 
+.stage-loading {
+  display: flex;
+  align-items: center;
+  gap: 12px;
+  padding: 16px;
+  margin: 12px 0;
+  background: #f9fafb;
+  border-radius: 8px;
+  border: 1px solid #e0e0e0;
+  color: #666;
+  font-size: 14px;
+  font-style: italic;
+}
+
 .spinner {
   width: 20px;
   height: 20px;
diff --git a/frontend/src/components/ChatInterface.jsx b/frontend/src/components/ChatInterface.jsx
index 951183f..3ae796c 100644
--- a/frontend/src/components/ChatInterface.jsx
+++ b/frontend/src/components/ChatInterface.jsx
@@ -71,13 +71,39 @@ export default function ChatInterface({
               ) : (
                 LLM Council
-
-
-
+
+              {/* Stage 1 */}
+              {msg.loading?.stage1 && (
+                <div className="stage-loading">
+                  <div className="spinner"></div>
+                  Running Stage 1: Collecting individual responses...
+                </div>
+              )}
+              {msg.stage1 && }
+
+              {/* Stage 2 */}
+              {msg.loading?.stage2 && (
+                <div className="stage-loading">
+                  <div className="spinner"></div>
+                  Running Stage 2: Peer rankings...
+                </div>
+              )}
+              {msg.stage2 && (
+
+              )}
+
+              {/* Stage 3 */}
+              {msg.loading?.stage3 && (
+                <div className="stage-loading">
+                  <div className="spinner"></div>
+                  Running Stage 3: Final synthesis...
+                </div>
+              )}
+              {msg.stage3 && }
               )}
@@ -94,24 +120,26 @@
-
-