"""Question answering with sources over documents."""

from __future__ import annotations

import inspect
import re
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

from langchain_core._api import deprecated
from langchain_core.callbacks import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from pydantic import ConfigDict, model_validator

from langchain.chains import ReduceDocumentsChain
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
from langchain.chains.qa_with_sources.map_reduce_prompt import (
    COMBINE_PROMPT,
    EXAMPLE_PROMPT,
    QUESTION_PROMPT,
)


@deprecated(
    since="0.2.13",
    removal="1.0",
    message=(
        "This class is deprecated. Refer to this guide on retrieval and question "
        "answering with sources: "
        "https://python.langchain.com/v0.2/docs/how_to/qa_sources/"
    ),
)
class BaseQAWithSourcesChain(Chain, ABC):
    """Question answering chain with sources over documents.

    Subclasses provide the documents through ``_get_docs`` / ``_aget_docs``.
    This base class runs ``combine_documents_chain`` over those documents and
    splits the generated text into an answer part and a sources part.
    """

    combine_documents_chain: BaseCombineDocumentsChain
    """Chain to use to combine documents."""
    question_key: str = "question"  #: :meta private:
    input_docs_key: str = "docs"  #: :meta private:
    answer_key: str = "answer"  #: :meta private:
    sources_answer_key: str = "sources"  #: :meta private:
    return_source_documents: bool = False
    """Return the source documents."""

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        document_prompt: BasePromptTemplate = EXAMPLE_PROMPT,
        question_prompt: BasePromptTemplate = QUESTION_PROMPT,
        combine_prompt: BasePromptTemplate = COMBINE_PROMPT,
        **kwargs: Any,
    ) -> BaseQAWithSourcesChain:
        """Construct the chain from an LLM.

        Builds a ``MapReduceDocumentsChain`` whose per-document step uses
        ``question_prompt`` and whose reduce step stuffs the results into
        ``combine_prompt``.

        Args:
            llm: Language model used for both the per-document and combine steps.
            document_prompt: Prompt used to format each document in the combine
                step.
            question_prompt: Prompt for the per-document LLM chain.
            combine_prompt: Prompt for the combining LLM chain.
            **kwargs: Extra keyword arguments forwarded to the class constructor.

        Returns:
            An instance of ``cls`` wrapping the assembled map-reduce chain.
        """
        llm_question_chain = LLMChain(llm=llm, prompt=question_prompt)
        llm_combine_chain = LLMChain(llm=llm, prompt=combine_prompt)
        combine_results_chain = StuffDocumentsChain(
            llm_chain=llm_combine_chain,
            document_prompt=document_prompt,
            document_variable_name="summaries",
        )
        reduce_documents_chain = ReduceDocumentsChain(  # type: ignore[misc]
            combine_documents_chain=combine_results_chain
        )
        combine_documents_chain = MapReduceDocumentsChain(
            llm_chain=llm_question_chain,
            reduce_documents_chain=reduce_documents_chain,
            document_variable_name="context",
        )
        return cls(
            combine_documents_chain=combine_documents_chain,
            **kwargs,
        )

    @classmethod
    def from_chain_type(
        cls,
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        chain_type_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> BaseQAWithSourcesChain:
        """Load chain from chain type.

        Args:
            llm: Language model passed to ``load_qa_with_sources_chain``.
            chain_type: Name of the combine-documents strategy to load.
            chain_type_kwargs: Optional keyword arguments forwarded to
                ``load_qa_with_sources_chain``; ``None`` means no extras.
            **kwargs: Extra keyword arguments forwarded to the class constructor.

        Returns:
            An instance of ``cls`` wrapping the loaded combine-documents chain.
        """
        _chain_kwargs = chain_type_kwargs or {}
        combine_documents_chain = load_qa_with_sources_chain(
            llm, chain_type=chain_type, **_chain_kwargs
        )
        return cls(combine_documents_chain=combine_documents_chain, **kwargs)

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.question_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        Always contains the answer and sources keys; ``"source_documents"`` is
        appended when ``return_source_documents`` is set.

        :meta private:
        """
        _output_keys = [self.answer_key, self.sources_answer_key]
        if self.return_source_documents:
            _output_keys = _output_keys + ["source_documents"]
        return _output_keys

    @model_validator(mode="before")
    @classmethod
    def validate_naming(cls, values: Dict) -> Any:
        """Fix backwards compatibility in naming.

        Accepts the legacy ``combine_document_chain`` key and stores its value
        under ``combine_documents_chain``. The rename is applied to a shallow
        copy so the mapping supplied by the caller is left untouched.
        """
        if "combine_document_chain" in values:
            values = dict(values)
            values["combine_documents_chain"] = values.pop("combine_document_chain")
        return values

    def _split_sources(self, answer: str) -> Tuple[str, str]:
        """Split sources from answer.

        Args:
            answer: Raw text produced by the combine-documents chain.

        Returns:
            ``(answer, sources)``. When the text contains ``SOURCE:`` or
            ``SOURCES:`` (case-insensitive), it is split on those markers and on
            ``QUESTION:`` followed by whitespace; the first piece is the answer
            (returned without stripping) and the first line of the second
            piece, stripped, is the sources string. Otherwise the text is
            returned unchanged with an empty sources string.
        """
        if re.search(r"SOURCES?:", answer, re.IGNORECASE):
            # The search above guarantees at least two pieces, so the
            # two-element unpacking cannot fail.
            answer, sources = re.split(
                r"SOURCES?:|QUESTION:\s", answer, flags=re.IGNORECASE
            )[:2]
            sources = re.split(r"\n", sources)[0].strip()
        else:
            sources = ""
        return answer, sources

    def _build_result(self, answer: str, docs: List[Document]) -> Dict[str, Any]:
        """Assemble the output mapping shared by ``_call`` and ``_acall``.

        Args:
            answer: Raw text produced by the combine-documents chain.
            docs: Documents the answer was generated from.

        Returns:
            A dict with the answer and sources under ``answer_key`` and
            ``sources_answer_key``, plus ``docs`` under ``"source_documents"``
            when ``return_source_documents`` is set.
        """
        answer, sources = self._split_sources(answer)
        result: Dict[str, Any] = {
            self.answer_key: answer,
            self.sources_answer_key: sources,
        }
        if self.return_source_documents:
            result["source_documents"] = docs
        return result

    @abstractmethod
    def _get_docs(
        self,
        inputs: Dict[str, Any],
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs to run questioning over."""

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Fetch documents, run the combine chain and split out the sources.

        Args:
            inputs: Chain inputs; forwarded as keyword arguments to the
                combine-documents chain after the documents are fetched.
            run_manager: Callback manager for this run; a no-op manager is used
                when omitted.

        Returns:
            The mapping described by ``output_keys``.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        # Support ``_get_docs`` implementations that do not accept ``run_manager``.
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(inputs, run_manager=_run_manager)
        else:
            docs = self._get_docs(inputs)  # type: ignore[call-arg]

        answer = self.combine_documents_chain.run(
            input_documents=docs, callbacks=_run_manager.get_child(), **inputs
        )
        return self._build_result(answer, docs)

    @abstractmethod
    async def _aget_docs(
        self,
        inputs: Dict[str, Any],
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs to run questioning over."""

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Async counterpart of ``_call``.

        Args:
            inputs: Chain inputs; forwarded as keyword arguments to the
                combine-documents chain after the documents are fetched.
            run_manager: Async callback manager for this run; a no-op manager is
                used when omitted.

        Returns:
            The mapping described by ``output_keys``.
        """
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        # Support ``_aget_docs`` implementations that do not accept ``run_manager``.
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._aget_docs).parameters
        )
        if accepts_run_manager:
            docs = await self._aget_docs(inputs, run_manager=_run_manager)
        else:
            docs = await self._aget_docs(inputs)  # type: ignore[call-arg]
        answer = await self.combine_documents_chain.arun(
            input_documents=docs, callbacks=_run_manager.get_child(), **inputs
        )
        return self._build_result(answer, docs)


@deprecated(
    since="0.2.13",
    removal="1.0",
    message=(
        "This class is deprecated. Refer to this guide on retrieval and question "
        "answering with sources: "
        "https://python.langchain.com/v0.2/docs/how_to/qa_sources/"
    ),
)
class QAWithSourcesChain(BaseQAWithSourcesChain):
    """Question answering with sources over documents supplied in the inputs."""

    input_docs_key: str = "docs"  #: :meta private:

    @property
    def _chain_type(self) -> str:
        return "qa_with_sources_chain"

    @property
    def input_keys(self) -> List[str]:
        """Keys the chain expects: the documents key, then the question key.

        :meta private:
        """
        expected = [self.input_docs_key]
        expected.append(self.question_key)
        return expected

    def _get_docs(
        self,
        inputs: Dict[str, Any],
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Take the documents out of ``inputs``.

        The entry stored under ``input_docs_key`` is removed from ``inputs`` and
        returned; ``run_manager`` is not used.
        """
        documents = inputs.pop(self.input_docs_key)
        return documents

    async def _aget_docs(
        self,
        inputs: Dict[str, Any],
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Take the documents out of ``inputs`` (async variant).

        The entry stored under ``input_docs_key`` is removed from ``inputs`` and
        returned; ``run_manager`` is not used.
        """
        documents = inputs.pop(self.input_docs_key)
        return documents
