import enum
import functools
import inspect
import json
import logging
import textwrap
import uuid
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar

from openai import OpenAI
from pydantic import BaseModel, validate_call

from instructor.function_calls import openai_schema


# Type variable for the return value of functions wrapped by `Instructions.distil`.
T_Retval = TypeVar("T_Retval")


class FinetuneFormat(enum.Enum):
    """
    Layout used by :meth:`Instructions.track` when it logs a call.

    ``MESSAGES`` logs an OpenAI chat payload (messages plus a function schema);
    ``RAW`` logs the function name, source, arguments, response and JSON schema.
    """

    MESSAGES = "messages"
    RAW = "raw"


def get_signature_from_fn(fn: Callable[..., Any]) -> str:
    """
    Render the ``def`` line of a function, followed by its docstring if it has one.

    The result is ``def <name><signature>``, then a newline, then (only when
    ``inspect.getdoc`` finds a non-empty docstring) that docstring wrapped in
    triple double quotes, each on its own line. A function without a docstring
    therefore yields its ``def`` line plus a trailing newline.

    :Example:

    >>> def my_function(a: int, b: int) -> int:
    >>>     return a + b
    >>>
    >>> get_signature_from_fn(my_function)
    "def my_function(a: int, b: int) -> int" followed by a newline

    :param fn: Function to describe.
    :return: The ``def`` line and optional quoted docstring, newline-separated.
    """
    header = "def " + fn.__name__ + str(inspect.signature(fn))
    doc = inspect.getdoc(fn)
    quoted_doc = '"""\n' + doc + '\n"""' if doc else ""
    return "\n".join((header, quoted_doc))


@functools.lru_cache()
def format_function(func: Callable[..., Any]) -> str:
    """
    Return the complete source of ``func``: decorators, signature, docstring and body.

    The source is dedented, so methods and nested functions start at column 0.
    Results are cached per function object with ``functools.lru_cache``.

    :param func: Function whose source should be rendered. It must be defined in
        a file that :mod:`inspect` can read.
    :raises OSError: If the source code cannot be retrieved.
    :raises TypeError: If ``func`` is a built-in or otherwise has no Python source.
    :return: The dedented source text of ``func``.
    """
    # ``getsource`` already contains the literal docstring and the ``def`` line,
    # so nothing is prepended or appended: the function appears exactly once.
    return textwrap.dedent(inspect.getsource(func))


def is_return_type_base_model_or_instance(func: Callable[..., Any]) -> bool:
    """
    Report whether ``func`` is annotated to return a pydantic ``BaseModel`` subclass.

    Only class annotations are accepted: an annotation that is not a class
    (for example a string forward reference) yields ``False``.

    :param func: Callable whose return annotation is inspected.
    :raises AssertionError: If ``func`` has no return annotation at all.
    :return: True when the return annotation is a class deriving from ``BaseModel``.
    """
    annotation = inspect.signature(func).return_annotation
    has_annotation = annotation != inspect.Signature.empty
    assert has_annotation, "Must have a return type hint that is a pydantic BaseModel"
    if not inspect.isclass(annotation):
        return False
    return issubclass(annotation, BaseModel)


class Instructions:
    """
    Record decorated function calls as fine-tuning data, or dispatch them to a model.

    Each tracked call is serialised to a JSON string and emitted with
    ``self.logger.info``. Attach handlers through ``log_handlers`` (for example a
    ``logging.FileHandler``) to persist them. The logger's effective level must
    allow INFO records for anything to be written.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        id: Optional[str] = None,
        log_handlers: Optional[List[logging.Handler]] = None,
        finetune_format: FinetuneFormat = FinetuneFormat.MESSAGES,
        indent: int = 2,
        include_code_body: bool = False,
        openai_client: Optional[OpenAI] = None,
    ) -> None:
        """
        Instructions for distillation and dispatch.

        :param name: Name of the instructions; also the name of the logger that
            tracked calls are written to.
        :param id: ID of the instructions. A random UUID4 string is used when omitted.
        :param log_handlers: Handlers attached to that logger.
        :param finetune_format: Default format used by :meth:`distil` when it is
            not given one explicitly.
        :param indent: JSON indentation of the logged response in MESSAGES format.
        :param include_code_body: If True, prompts contain the function's full
            source instead of only its signature and docstring.
        :param openai_client: Client used in dispatch mode. When omitted, a
            default ``OpenAI()`` client is constructed here, even if only distil
            mode is ever used.
        """
        self.name = name
        self.id = id or str(uuid.uuid4())
        # Always random, independent of any user-supplied ``id``.
        self.unique_id = str(uuid.uuid4())
        self.finetune_format = finetune_format
        self.indent = indent
        self.include_code_body = include_code_body
        self.client = openai_client or OpenAI()

        # NOTE(review): with ``name=None`` this is the root logger, so the handlers
        # below also receive records propagated from every other logger in the
        # process, which would mix unrelated lines into a fine-tuning log file.
        # Confirm callers always pass a name.
        self.logger = logging.getLogger(self.name)
        for handler in log_handlers or []:
            self.logger.addHandler(handler)

    def distil(
        self,
        *args: Any,
        name: Optional[str] = None,
        mode: str = "distil",
        model: str = "gpt-3.5-turbo",
        fine_tune_format: Optional[FinetuneFormat] = None,
    ) -> Callable[
        [Callable[..., Any]],
        Callable[[Callable[..., T_Retval]], Callable[..., T_Retval]],
    ]:
        """
        Decorator to track the function call and response, supports distillation and dispatch modes.

        It can be applied bare or called with keyword arguments first.

        :Example:

        >>> @instructions.distil
        >>> def my_function() -> MyModel:
        >>>     return MyModel()
        >>>
        >>> @instructions.distil(name="my_function")
        >>> def my_function() -> MyModel:
        >>>     return MyModel()

        :param args: When used bare, the single function being decorated.
        :param name: Name recorded for the function. Defaults to ``fn.__name__``.
        :param mode: ``"distil"`` runs the function and logs the call and response
            via :meth:`track`. ``"dispatch"`` does not run the function body; it
            asks ``model`` to predict the result instead.
        :param model: Model name passed to the client in dispatch mode.
        :param fine_tune_format: Format passed to :meth:`track` in distil mode.
            Defaults to the instance's ``finetune_format``.
        :raises AssertionError: If ``mode`` is not supported, or if the decorated
            function's return annotation is not a ``pydantic.BaseModel`` subclass.
        :return: The wrapped function when used bare, otherwise a decorator.
        """
        allowed_modes = {"distil", "dispatch"}
        assert mode in allowed_modes, f"Must be in {allowed_modes}"

        if fine_tune_format is None:
            fine_tune_format = self.finetune_format

        def _wrap_distil(
            fn: Callable[..., Any],
        ) -> Callable[[Callable[..., T_Retval]], Callable[..., T_Retval]]:
            msg = f"Return type hint for {fn} must subclass `pydantic.BaseModel'"
            assert is_return_type_base_model_or_instance(fn), msg
            return_base_model = inspect.signature(fn).return_annotation
            # Resolved here rather than inside ``_dispatch``: assigning to
            # ``name`` there would make it a local of ``_dispatch`` and raise
            # UnboundLocalError on every call.
            fn_name = name if name else fn.__name__

            @functools.wraps(fn)
            def _dispatch(*args: Any, **kwargs: Any) -> Any:
                openai_kwargs = self.openai_kwargs(
                    name=fn_name,
                    fn=fn,
                    args=args,
                    kwargs=kwargs,
                    base_model=return_base_model,
                )
                # NOTE(review): ``response_model`` is not a parameter of the stock
                # OpenAI client. This presumably requires a client patched by
                # instructor to be passed as ``openai_client``; confirm.
                return self.client.chat.completions.create(
                    **openai_kwargs, model=model, response_model=return_base_model
                )

            @functools.wraps(fn)
            def _distil(*args: Any, **kwargs: Any) -> Any:
                resp = fn(*args, **kwargs)
                self.track(
                    fn,
                    args,
                    kwargs,
                    resp,
                    name=fn_name,
                    finetune_format=fine_tune_format,
                )

                return resp

            return _dispatch if mode == "dispatch" else _distil

        # Bare ``@distil`` usage: the decorated function is the only positional arg.
        if len(args) == 1 and callable(args[0]):
            return _wrap_distil(args[0])

        return _wrap_distil

    @validate_call  # type: ignore[misc]
    def track(
        self,
        fn: Callable[..., Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        resp: BaseModel,
        name: Optional[str] = None,
        finetune_format: FinetuneFormat = FinetuneFormat.MESSAGES,
    ) -> None:
        """
        Log one function call and its response through ``self.logger``, for later finetuning.

        In ``MESSAGES`` format the record is a chat payload: the prompt from
        :meth:`openai_kwargs`, plus an assistant ``function_call`` holding the
        response, plus the response model's function schema. In ``RAW`` format it
        is a dict with the function name, source, arguments, dumped response and
        JSON schema. In RAW format, ``args`` and ``kwargs`` must be
        JSON-serializable, otherwise ``json.dumps`` raises ``TypeError``.

        :param fn: Function to track.
        :param args: Arguments passed to the function.
        :param kwargs: Keyword arguments passed to the function.
        :param resp: Response returned by the function.
        :param name: Name of the function to track. Defaults to the function name.
        :param finetune_format: Format to use for finetuning. Defaults to
            ``FinetuneFormat.MESSAGES``.
        """
        name = name if name else fn.__name__
        base_model: Type[BaseModel] = type(resp)

        if finetune_format == FinetuneFormat.MESSAGES:
            openai_function_call = openai_schema(base_model).openai_schema
            openai_kwargs = self.openai_kwargs(name, fn, args, kwargs, base_model)
            openai_kwargs["messages"].append(
                {
                    "role": "assistant",
                    "function_call": {
                        "name": base_model.__name__,
                        "arguments": resp.model_dump_json(indent=self.indent),
                    },
                }
            )
            openai_kwargs["functions"] = [openai_function_call]
            self.logger.info(json.dumps(openai_kwargs))

        elif finetune_format == FinetuneFormat.RAW:
            function_body = dict(
                fn_name=name,
                fn_repr=format_function(fn),
                args=args,
                kwargs=kwargs,
                resp=resp.model_dump(),
                schema=base_model.model_json_schema(),
            )
            self.logger.info(json.dumps(function_body))

    def openai_kwargs(
        self,
        name: str,
        fn: Callable[..., Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        base_model: Type[BaseModel],
    ) -> Dict[str, Any]:
        """
        Build the chat ``messages`` payload that asks a model to predict ``fn``'s result.

        The system message contains the function's signature and docstring, or its
        full source when ``include_code_body`` is set. The user message asks for
        ``name(<call args>)``. Both :meth:`track` and dispatch mode use this
        method, so logged prompts and dispatched prompts have the same shape.

        NOTE(review): positional args are rendered with ``str`` while keyword
        values go through ``json.dumps``, so string positionals appear unquoted.
        This is left unchanged because altering it would change prompts relative
        to previously collected data. Non-JSON-serializable keyword values make
        ``json.dumps`` raise ``TypeError``.

        :param name: Name shown in the user message.
        :param fn: Function being described.
        :param args: Positional arguments of the call.
        :param kwargs: Keyword arguments of the call.
        :param base_model: Return model of ``fn``; not used by this method.
        :return: A dict with a single ``"messages"`` key.
        """
        if self.include_code_body:
            func_def = format_function(fn)
        else:
            func_def = get_signature_from_fn(fn)

        str_args = ", ".join(map(str, args))
        str_kwargs = (
            ", ".join(f"{k}={json.dumps(v)}" for k, v in kwargs.items()) or None
        )
        # Drop empty parts so a call with no args renders as ``name()``.
        call_args = ", ".join(filter(None, [str_args, str_kwargs]))

        function_body = {
            "messages": [
                {
                    "role": "system",
                    "content": f"Predict the results of this function:\n\n{func_def}",
                },
                {
                    "role": "user",
                    "content": f"Return `{name}({call_args})`",
                },
            ],
        }
        return function_body
