"""
WebSocket Chatbot Router — /ws/chat
Relays each incoming chat message to the service layer's linear `resolve_query` pipeline and sends the answer back to the client.
"""
import uuid
import json
import logging
from datetime import datetime
from typing import Optional

from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocketState

from database.firebase_client import get_firestore
from services.faq_service import resolve_query

# APIRouter carrying the /ws/chat endpoint defined below.
router = APIRouter(tags=["WebSocket Chat"])
# Named logger shared by every helper and the endpoint in this module.
logger = logging.getLogger("chatbot.websocket")


# ── WebSocket Message Helpers ──────────────────────────────────────────────────

async def _send(ws: WebSocket, msg: dict):
    """Safely sends a JSON message over WebSocket."""
    try:
        if ws.client_state == WebSocketState.CONNECTED:
            await ws.send_text(json.dumps(msg))
    except Exception as e:
        logger.warning(f"WebSocket send error: {e}")


async def _send_typing(ws: WebSocket):
    """Send a ``{"type": "typing"}`` frame to *ws* (best-effort via ``_send``)."""
    typing_frame = {"type": "typing"}
    await _send(ws, typing_frame)


async def _send_message(ws: WebSocket, text: str, source: str = "faq", message_id: Optional[str] = None):
    """Send a bot reply frame to *ws*.

    The frame always carries ``type`` ("message"), ``text`` and ``source``;
    ``message_id`` is appended only when it is truthy.
    """
    base = {"type": "message", "text": text, "source": source}
    optional = {"message_id": message_id} if message_id else {}
    await _send(ws, {**base, **optional})


async def _send_error(ws: WebSocket, text: str):
    """Send an ``error`` frame carrying *text* to *ws* (best-effort via ``_send``)."""
    error_frame = {"type": "error", "text": text}
    await _send(ws, error_frame)


# ── Domain Loader ──────────────────────────────────────────────────────────────

def _load_domain(db, domain_id: str) -> Optional[dict]:
    """Fetch ``domains/<domain_id>`` from Firestore and return it if usable.

    Returns the document's fields plus an ``"id"`` key holding the document
    ID. Returns ``None`` in three cases:
    - the document does not exist;
    - its ``is_active`` field is present and falsy (a missing field counts as active);
    - the lookup raised, in which case the error is logged.
    """
    try:
        snapshot = db.collection("domains").document(domain_id).get()
        if snapshot.exists:
            record = snapshot.to_dict()
            record["id"] = snapshot.id
            if record.get("is_active", True):
                return record
        return None
    except Exception as exc:
        logger.error(f"Domain load error for {domain_id}: {exc}")
        return None


def _log_conversation(db, domain_id: str, session_id: str, query: str, response: str, source: str):
    """Record one query/response exchange as a new ``conversations`` document.

    Every call writes to a freshly generated UUID4 document ID. ``created_at``
    is a naive UTC ISO-8601 timestamp. Failures are logged at WARNING level
    and swallowed, so a logging problem never interrupts the chat.
    """
    try:
        doc_ref = db.collection("conversations").document(str(uuid.uuid4()))
        doc_ref.set({
            "domain_id": domain_id,
            "session_id": session_id,
            "query": query,
            "response": response,
            "source": source,
            "created_at": datetime.utcnow().isoformat(),
        })
    except Exception as exc:
        logger.warning(f"Conversation log write failed: {exc}")


# ── WebSocket Endpoint ─────────────────────────────────────────────────────────

def _parse_client_message(raw: str) -> Optional[dict]:
    """Decode one client text frame into a JSON object.

    Returns the decoded dict, or ``None`` when *raw* is not valid JSON or
    decodes to something other than an object (list, string, number, ...).
    The message loop relies on dict access (``.get``), so anything else must
    be rejected here rather than crash the whole session.
    """
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return None
    return data if isinstance(data, dict) else None


async def _close_quietly(ws: WebSocket, code: int) -> None:
    """Close *ws* with *code* if the client is still connected.

    Best-effort: any error raised while closing is logged at DEBUG level and
    suppressed, since this only runs on a path that is already shutting down.
    """
    try:
        if ws.client_state == WebSocketState.CONNECTED:
            await ws.close(code=code)
    except Exception as e:
        logger.debug(f"WebSocket close error: {e}")


@router.websocket("/ws/chat")
async def websocket_chat(
    websocket: WebSocket,
    domain_id: str,
    session_id: Optional[str] = None,
):
    """
    Main chatbot WebSocket endpoint.

    Parameters:
        domain_id: ID of the ``domains`` document configuring this chatbot.
            The connection is closed with code 1008 when ``_load_domain``
            returns nothing (missing, inactive, or lookup error).
        session_id: Optional session identifier. A UUID4 is generated when it
            is omitted, and it is echoed back in the ``connected`` frame.

    Protocol (JSON text frames):
        - ``{"type": "ping"}`` is answered with ``{"type": "pong"}``.
        - ``{"type": "message", "text": ...}`` is answered with a ``typing``
          frame, then a ``message`` frame. ``type`` defaults to "message"
          when omitted.
        - Frames that are not JSON objects get an ``error`` reply.
        - Other ``type`` values, and empty or non-string ``text``, are ignored.

    Each query goes through ``resolve_query``, and each answered exchange is
    logged via ``_log_conversation``. On an unexpected error the client gets
    an ``error`` frame and the socket is closed with code 1011.
    """
    await websocket.accept()
    session_id = session_id or str(uuid.uuid4())
    db = get_firestore()

    # ── Validate Domain ────────────────────────────────────────────────────────
    # NOTE(review): _load_domain and _log_conversation are synchronous functions
    # called from this async handler. If the Firestore client blocks on I/O,
    # each call stalls the event loop for every connection; consider
    # offloading them (e.g. asyncio.to_thread) — confirm client thread-safety first.
    domain = _load_domain(db, domain_id)
    if not domain:
        await _send_error(websocket, "This chatbot is not available.")
        await websocket.close(code=1008)
        return

    subscriber_uid = domain.get("user_id", "")
    provider = domain.get("embedding_provider", "openai")
    custom_prompt = domain.get("custom_prompt")

    welcome = domain.get("widget_welcome_message") or domain.get("welcome_message") or "Welcome to Acme Support."
    await _send(websocket, {
        "type": "connected",
        "session_id": session_id,
        "message": welcome,
    })

    logger.info(f"WebSocket connected: domain={domain_id} session={session_id}")

    # ── Message Loop ───────────────────────────────────────────────────────────
    try:
        while True:
            raw = await websocket.receive_text()

            # Reject non-JSON *and* non-object JSON; both used to be possible
            # ways for one bad frame to end the session.
            data = _parse_client_message(raw)
            if data is None:
                await _send_error(websocket, "Invalid message format.")
                continue

            msg_type = data.get("type", "message")

            if msg_type == "ping":
                await _send(websocket, {"type": "pong"})
                continue

            if msg_type != "message":
                continue

            # A non-string "text" (null, number, object, ...) has no .strip();
            # treat it like an empty message and ignore it.
            text = data.get("text", "")
            query = text.strip() if isinstance(text, str) else ""
            if not query:
                continue

            # ── Show Typing Indicator ──────────────────────────────────────────
            await _send_typing(websocket)

            # ── Resolve Query via Service Layer Pipeline ───────────────────────
            try:
                result = await resolve_query(
                    domain_id=domain_id,
                    subscriber_uid=subscriber_uid,
                    query=query,
                    session_id=session_id,
                    db=db,
                    provider=provider,
                    custom_prompt=custom_prompt,
                    domain=domain,
                )
                answer = result["response"]
                source = result["source"]
                # Present in fallback cases; forwarded to the client as message_id.
                fq_id = result.get("failed_question_id")
            except Exception as e:
                logger.error(f"Error during query resolution: {e}", exc_info=True)
                await _send_message(
                    websocket,
                    "An error occurred while looking up the answer. Please try again.",
                    source="error",
                )
                continue

            # Both helpers below log and swallow their own failures.
            await _send_message(websocket, answer, source=source, message_id=fq_id)
            _log_conversation(db, domain_id, session_id, query, answer, source)

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: session={session_id}")
    except Exception as e:
        logger.error(f"WebSocket error on session {session_id}: {e}", exc_info=True)
        # _send already suppresses send errors, so no extra guard is needed.
        await _send_error(websocket, "A server error occurred.")
        # 1011 = internal server error; close explicitly instead of leaving the
        # socket to be torn down implicitly when the handler returns.
        await _close_quietly(websocket, code=1011)
