"""
Vector Service — Multi Embedding Provider Architecture

Provides semantic FAQ search for the chatbot system using ChromaDB.
Supports dynamic switching between OpenAI and Gemini embeddings on a per-domain basis.
"""
import logging
from typing import List, Dict, Any, Optional

import chromadb
from chromadb import EmbeddingFunction, Documents, Embeddings

import core.config as cfg

# Shared logger for this module (named explicitly rather than via __name__).
logger = logging.getLogger("chatbot.vector_service")

# ── Embedding Functions ────────────────────────────────────────────────────────

class OpenAIEmbeddingFunction(EmbeddingFunction):
    """Custom ChromaDB EmbeddingFunction backed by the OpenAI embeddings API.

    The OpenAI client is created once, at construction time. If
    ``cfg.OPENAI_API_KEY`` is unset, still holds the ``sk-proj-your-`` template
    placeholder, or the ``openai`` client cannot be created, no client is kept
    and every call raises ``RuntimeError``.
    """

    # Embedding model used for both indexing and querying.
    _model = "text-embedding-3-small"
    # The embeddings endpoint limits how many inputs one request may carry, so
    # large batches (e.g. a full domain re-index) are split into several requests.
    _max_batch_size = 2048

    def __init__(self):
        self._client = None
        if cfg.OPENAI_API_KEY and not cfg.OPENAI_API_KEY.startswith("sk-proj-your-"):
            try:
                from openai import OpenAI
                self._client = OpenAI(api_key=cfg.OPENAI_API_KEY)
            except Exception as e:
                logger.warning(f"OpenAI embedding init failed: {e}")

    def __call__(self, input: Documents) -> Embeddings:
        """Embed ``input`` and return one vector per document, in input order.

        Raises:
            RuntimeError: if no client is available or any API request fails.
        """
        if not self._client:
            raise RuntimeError("OpenAI embedding not available. API key missing.")
        texts = list(input)
        step = max(1, self._max_batch_size)
        embeddings = []
        try:
            # We use the synchronous API for chromadb EmbeddingFunction
            for start in range(0, len(texts), step):
                response = self._client.embeddings.create(
                    input=texts[start:start + step],
                    model=self._model,
                )
                # Each returned item carries its position in the request; sort on
                # it so the output order is guaranteed to match the input order.
                ordered = sorted(response.data, key=lambda item: item.index)
                embeddings.extend(item.embedding for item in ordered)
        except Exception as e:
            logger.error(f"OpenAI embedding failed: {e}")
            raise RuntimeError(f"OpenAI embedding failed: {e}") from e
        return embeddings


_openai_ef = OpenAIEmbeddingFunction()


class GeminiEmbeddingFunction(EmbeddingFunction):
    """Custom ChromaDB EmbeddingFunction using the Gemini API with fallback to OpenAI.

    Each document is embedded with its own ``embed_content`` call using
    ``task_type="retrieval_document"``. If Gemini was not initialised, or any
    single call fails, the *whole* batch is re-embedded with the module-level
    OpenAI function, so one returned batch never mixes vectors from both providers.

    NOTE(review): the OpenAI fallback produces OpenAI vectors for a collection that
    normally holds Gemini vectors. Vectors from two different models are not
    comparable and may not even share a dimensionality, so the fallback can fail
    inside Chroma or silently degrade ranking — confirm it is intended.
    NOTE(review): Chroma presumably also calls this function to embed
    ``query_texts``; Gemini distinguishes ``retrieval_query`` from
    ``retrieval_document``, so queries may deserve their own task type.
    """

    def __init__(self):
        self._genai = None
        if cfg.GEMINI_API_KEY:
            try:
                import google.generativeai as genai
                genai.configure(api_key=cfg.GEMINI_API_KEY)
                self._genai = genai
            except Exception as e:
                logger.warning(f"Gemini embedding init failed: {e}")

    def __call__(self, input: Documents) -> Embeddings:
        """Embed ``input`` with Gemini, or with OpenAI when Gemini is unavailable or fails.

        Raises:
            RuntimeError: if OpenAI is used (directly or as the fallback) and fails.
        """
        if not self._genai:
            logger.warning("Gemini embedding not available (init failed). Falling back to OpenAI.")
            return _openai_ef(input)
        # Resolve the model resource name once per batch, accepting the setting
        # with or without the "models/" prefix (avoids "models/models/...").
        model_name = str(cfg.GEMINI_EMBEDDING_MODEL)
        if not model_name.startswith("models/"):
            model_name = f"models/{model_name}"
        embeddings = []
        for text in input:
            try:
                result = self._genai.embed_content(
                    model=model_name,
                    content=text,
                    task_type="retrieval_document",
                )
                embeddings.append(result["embedding"])
            except Exception as e:
                logger.error(f"Gemini embedding failed for text: {e}. Falling back to OpenAI.")
                try:
                    # Re-embed the entire batch so the result is single-provider.
                    return _openai_ef(input)
                except Exception as oe:
                    logger.critical(f"OpenAI fallback embedding also failed: {oe}")
                    raise RuntimeError(
                        f"Gemini and OpenAI fallback embeddings failed: {e} / {oe}"
                    ) from oe
        return embeddings


_gemini_ef = GeminiEmbeddingFunction()


# ── ChromaDB Client ────────────────────────────────────────────────────────────

# Single shared Chroma client for this module. Prefer on-disk persistence at
# cfg.CHROMA_PERSIST_DIRECTORY; if that client cannot be created, fall back to a
# non-persistent EphemeralClient so the service can still start.
# NOTE(review): with the ephemeral fallback, indexed FAQs are presumably lost on
# restart (and not shared between processes) — confirm this degradation is acceptable.
try:
    chroma_client = chromadb.PersistentClient(path=cfg.CHROMA_PERSIST_DIRECTORY)
    logger.info(f"ChromaDB initialized at: {cfg.CHROMA_PERSIST_DIRECTORY}")
except Exception as e:
    logger.warning(f"ChromaDB PersistentClient failed, using EphemeralClient: {e}")
    chroma_client = chromadb.EphemeralClient()


# ── Provider Resolution ────────────────────────────────────────────────────────

def get_provider_for_domain(domain_id: str, db) -> str:
    """Return the embedding provider configured for ``domain_id`` in Firestore.

    Reads the ``embedding_provider`` field of the ``domains/{domain_id}`` document.
    Falls back to ``"openai"`` when ``db`` is falsy, the document does not exist,
    the field is missing, null, blank or not a string, or the Firestore read
    raises (the error is logged as a warning).

    Args:
        domain_id: Firestore document id of the domain.
        db: Firestore client (anything exposing ``collection().document().get()``),
            or ``None`` to skip the lookup.

    Returns:
        The provider name exactly as stored (e.g. ``"openai"``, ``"gemini"``),
        or ``"openai"``.
    """
    default = "openai"
    if not db:
        return default
    try:
        snap = db.collection("domains").document(domain_id).get()
        data = (snap.to_dict() or {}) if snap.exists else {}
    except Exception as e:
        logger.warning(f"Failed to fetch embedding provider for {domain_id}: {e}")
        return default
    provider = data.get("embedding_provider")
    # A null / blank / non-string value would otherwise reach provider.lower()
    # downstream and crash collection lookup.
    if isinstance(provider, str) and provider.strip():
        return provider
    return default


def get_embedding_function(provider: str) -> EmbeddingFunction:
    """Map a provider name to its shared EmbeddingFunction instance.

    Matching is case-insensitive: ``"gemini"`` selects the Gemini function; any
    other value — including ``None`` or a non-string — selects the OpenAI one
    instead of raising on ``.lower()``.
    """
    normalized = provider.lower() if isinstance(provider, str) else ""
    if normalized == "gemini":
        return _gemini_ef
    return _openai_ef


# ── Collection Helper ──────────────────────────────────────────────────────────

def get_collection(domain_id: str, db=None):
    """Fetch (creating on first use) the cosine-distance Chroma collection of a domain.

    The collection is named ``domain_<id>`` with hyphens turned into underscores,
    and is bound to the embedding function of the domain's configured provider.
    """
    embedding_fn = get_embedding_function(get_provider_for_domain(domain_id, db))
    collection_name = f"domain_{domain_id.replace('-', '_')}"
    return chroma_client.get_or_create_collection(
        name=collection_name,
        embedding_function=embedding_fn,
        metadata={"hnsw:space": "cosine"},
    )


# ── Retrain Logic ──────────────────────────────────────────────────────────────

def reindex_domain(domain_id: str, db) -> None:
    """
    Rebuild a domain's vector collection from the FAQs stored in Firestore.

    Called when a provider is switched or a manual retrain is requested. The FAQs
    are read *before* the old collection is dropped, so a Firestore failure leaves
    the existing vectors searchable instead of wiping them. The old collection is
    then deleted and every FAQ is re-embedded with the domain's current provider.

    Args:
        domain_id: Domain whose collection is rebuilt.
        db: Firestore client; ``faqs`` documents are selected by their ``domain_id``.

    Raises:
        RuntimeError: if the Firestore fetch fails (the old collection is kept), or
            propagated from ``bulk_index_faqs`` if re-indexing fails (the old
            collection has already been deleted by then).
    """
    safe_name = f"domain_{domain_id.replace('-', '_')}"

    # 1. Fetch all FAQs for this domain first, so a Firestore outage cannot
    #    leave the domain with an empty index.
    faqs = []
    try:
        faq_docs = db.collection("faqs").where("domain_id", "==", domain_id).stream()
        for doc in faq_docs:
            data = doc.to_dict() or {}
            faqs.append({
                "id": doc.id,
                # Treat a missing or null field as empty text.
                "question": data.get("question") or "",
                "answer": data.get("answer") or "",
            })
    except Exception as e:
        logger.error(f"Failed to fetch FAQs for reindexing domain {domain_id}: {e}")
        raise RuntimeError(f"Firestore fetch failed: {e}") from e

    # 2. Delete the existing collection (it may hold the previous provider's vectors).
    try:
        chroma_client.delete_collection(name=safe_name)
        logger.info(f"Deleted old Chroma collection: {safe_name}")
    except Exception as e:
        logger.info(f"Collection {safe_name} did not exist or could not be deleted: {e}")

    # 3. Bulk index with the new active provider
    if faqs:
        bulk_index_faqs(domain_id, faqs, db)
        logger.info(f"Re-indexed {len(faqs)} FAQs for domain {domain_id}")


# ── FAQ Indexing ───────────────────────────────────────────────────────────────

def index_faq(domain_id: str, faq_id: str, question: str, answer: str, db=None) -> None:
    """Upsert one FAQ into the domain's vector collection.

    The embedded text is ``"Question: <q>\\nAnswer: <a>"``; the raw question and
    answer are also stored as metadata so search results can return them directly.

    Raises:
        RuntimeError: wrapping any failure from the collection lookup or upsert.
    """
    try:
        target = get_collection(domain_id, db)
        target.upsert(
            ids=[faq_id],
            documents=[f"Question: {question}\nAnswer: {answer}"],
            metadatas=[{"faq_id": faq_id, "question": question, "answer": answer}],
        )
        logger.info(f"Indexed FAQ {faq_id} for domain {domain_id}")
    except Exception as err:
        logger.error(f"Failed to index FAQ {faq_id}: {err}")
        raise RuntimeError(f"Vector indexing error: {err}")


def remove_faq(domain_id: str, faq_id: str, db=None) -> None:
    """Delete one FAQ's vector from the domain's collection, best effort.

    Any failure is logged and swallowed rather than raised. The lookup goes
    through ``get_collection`` (get-or-create), so an empty collection may be
    created as a side effect.
    """
    try:
        get_collection(domain_id, db).delete(ids=[faq_id])
        logger.info(f"Removed FAQ {faq_id} from domain {domain_id}")
    except Exception as err:
        logger.error(f"Failed to remove FAQ {faq_id}: {err}")


def bulk_index_faqs(
    domain_id: str,
    faqs: List[Dict[str, Any]],
    db=None,
    batch_size: int = 100,
) -> None:
    """Upsert many FAQs (dicts with ``id``, ``question``, ``answer``) in batches.

    Records are sent to Chroma ``batch_size`` at a time instead of in one request.
    Chroma caps the number of records a single upsert may carry, and the
    embedding providers cap how much input one embedding request may contain, so
    a large domain would otherwise fail to index at all. Upserts are idempotent,
    so re-running after a partial failure is safe.

    Args:
        domain_id: Domain whose collection receives the FAQs.
        faqs: FAQ records; an empty list is a no-op.
        db: Optional Firestore client used to resolve the domain's provider.
        batch_size: Maximum records per upsert call (values < 1 act as 1).

    Raises:
        RuntimeError: if any batch fails; batches before it remain indexed.
    """
    if not faqs:
        return
    records = list(faqs)
    step = max(1, batch_size)
    try:
        collection = get_collection(domain_id, db)
        for start in range(0, len(records), step):
            batch = records[start:start + step]
            collection.upsert(
                ids=[faq["id"] for faq in batch],
                documents=[f"Question: {faq['question']}\nAnswer: {faq['answer']}" for faq in batch],
                metadatas=[
                    {"faq_id": faq["id"], "question": faq["question"], "answer": faq["answer"]}
                    for faq in batch
                ],
            )
        logger.info(f"Bulk indexed {len(records)} FAQs for domain {domain_id}")
    except Exception as e:
        logger.error(f"Bulk indexing failed for domain {domain_id}: {e}")
        raise RuntimeError(f"Vector batch insert error: {e}") from e


# ── FAQ Search ─────────────────────────────────────────────────────────────────

def search_similar_faqs(domain_id: str, query: str, limit: int = 3, db=None) -> str:
    """
    Return the FAQ entries most similar to ``query`` as a single LLM context string.

    Up to ``limit`` matched documents are joined with blank lines. A fixed
    placeholder sentence is returned when the collection is empty or nothing
    matches; a different one when the search fails (the error is logged, not raised).
    """
    no_facts = "No documented facts available for this query."
    try:
        collection = get_collection(domain_id, db)
        total = collection.count()
        if total == 0:
            return no_facts
        hits = collection.query(query_texts=[query], n_results=min(limit, total))
        if not (hits and hits["documents"] and hits["documents"][0]):
            return no_facts
        return "\n\n".join(hits["documents"][0])
    except Exception as err:
        logger.error(f"Vector search failed for domain {domain_id}: {err}")
        return "Search temporarily unavailable."


def search_faq_similarity(
    domain_id: str,
    query: str,
    merged_domain_ids: Optional[List[str]] = None,
    db=None
) -> dict:
    """
    Find the single best FAQ match for ``query`` across a domain and its merged domains.

    ``domain_id`` is searched first, then each truthy, not-yet-seen id from
    ``merged_domain_ids`` in order. Each collection contributes its nearest
    record; confidence is ``1 - distance`` clamped to [0, 1], and the strictly
    highest confidence wins. With no match, every field is ``None`` and the
    score is ``0.0``.

    Raises:
        RuntimeError: on the first domain whose lookup or query fails.
    """

    def _top(results, key, default):
        # First entry of the first query's result column, or ``default`` if absent.
        column = results.get(key)
        if column and column[0]:
            return column[0][0]
        return default

    ordered_domains = [domain_id]
    for extra_id in merged_domain_ids or []:
        if extra_id and extra_id not in ordered_domains:
            ordered_domains.append(extra_id)

    winner = {
        "matched_faq_id": None,
        "question": None,
        "answer": None,
        "confidence_score": 0.0,
    }

    for current_id in ordered_domains:
        try:
            collection = get_collection(current_id, db)
            if collection.count() == 0:
                continue
            results = collection.query(query_texts=[query], n_results=1)
            if not (results and results["documents"] and results["documents"][0]):
                continue
            score = max(0.0, min(1.0, 1.0 - _top(results, "distances", 1.0)))
            if score <= winner["confidence_score"]:
                continue
            meta = _top(results, "metadatas", {})
            winner = {
                "matched_faq_id": _top(results, "ids", None),
                "question": meta.get("question"),
                "answer": meta.get("answer"),
                "confidence_score": float(score),
            }
        except Exception as err:
            logger.error(f"Similarity search error on domain {current_id}: {err}")
            raise RuntimeError(f"Embedding provider error: {err}")

    return winner
