import time
import threading
from typing import Any, Optional
from datetime import datetime, timedelta
from utils.logger import logger

class Cache:
    """
    Thread-safe in-memory key/value cache with optional per-entry TTL.

    Every public method holds a single ``threading.Lock`` while touching the
    underlying dict. Expired entries are evicted lazily by ``get`` and in bulk
    by ``cleanup_expired``; no background thread is started.

    Expiry deadlines are measured with ``time.monotonic()`` so that wall-clock
    adjustments (NTP corrections, manual clock changes) cannot make entries
    live longer or expire earlier than their TTL.
    """

    def __init__(self):
        # key -> {'value': Any, 'expires_at': Optional[float]}; 'expires_at' is a
        # time.monotonic() deadline, or None for entries that never expire.
        self._cache = {}
        self._lock = threading.Lock()

    @staticmethod
    def _is_expired(item: dict, now: float) -> bool:
        """Return True if *item* has a deadline that is at or before *now*."""
        expires_at = item['expires_at']
        # `<=` so that a deadline equal to "now" (e.g. ttl=0 on a coarse clock)
        # is already treated as expired.
        return expires_at is not None and expires_at <= now

    def get(self, key: str) -> Optional[Any]:
        """
        Get a value from cache if it exists and hasn't expired

        An expired entry encountered here is removed as a side effect.

        Args:
            key (str): Cache key

        Returns:
            Optional[Any]: Cached value if exists and valid, None otherwise.
                A cached ``None`` value is indistinguishable from a miss.
        """
        with self._lock:
            # Stored items are always dicts, so None unambiguously means "absent".
            item = self._cache.get(key)
            if item is None:
                return None
            if self._is_expired(item, time.monotonic()):
                del self._cache[key]
                return None
            return item['value']

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """
        Set a value in cache with optional TTL

        Args:
            key (str): Cache key
            value (Any): Value to cache
            ttl (Optional[int]): Time to live in seconds. If None, item won't
                expire. A TTL of 0 or less stores an entry that is already
                expired, so the next ``get`` misses.
        """
        with self._lock:
            # Compare against None explicitly so that ttl=0 means "expire
            # immediately" instead of falling through to "never expire".
            expires_at = time.monotonic() + ttl if ttl is not None else None
            self._cache[key] = {
                'value': value,
                'expires_at': expires_at
            }
        # Log outside the lock to keep the critical section short.
        logger.debug(f"Cached value for key: {key}")

    def clear(self, key: str) -> None:
        """
        Remove a specific key from cache

        Removing a key that is not present is a no-op.

        Args:
            key (str): Cache key to remove
        """
        with self._lock:
            # Items are dicts, never None, so a None result means nothing was removed.
            removed = self._cache.pop(key, None) is not None
        if removed:
            logger.debug(f"Cleared cache for key: {key}")

    def clear_all(self) -> None:
        """Clear all cached items"""
        with self._lock:
            self._cache.clear()
        logger.debug("Cleared all cache")

    def cleanup_expired(self) -> None:
        """Remove all expired items from cache"""
        with self._lock:
            now = time.monotonic()
            # Collect first: deleting while iterating over the dict would raise.
            expired_keys = [
                key for key, item in self._cache.items()
                if self._is_expired(item, now)
            ]
            for key in expired_keys:
                del self._cache[key]
        if expired_keys:
            logger.debug(f"Cleaned up {len(expired_keys)} expired cache entries")

# Single process-wide Cache shared by the module-level convenience functions.
_cache: Cache = Cache()

# Convenience functions
def get_cache(key: str) -> Optional[Any]:
    """
    Look up *key* in the module-wide cache.

    Args:
        key (str): Cache key

    Returns:
        Optional[Any]: The cached value, or None on a miss or expired entry.
    """
    cached_value = _cache.get(key)
    return cached_value

def set_cache(key: str, value: Any, ttl: Optional[int] = None) -> None:
    """
    Store *value* under *key* in the module-wide cache.

    Args:
        key (str): Cache key
        value (Any): Value to cache
        ttl (Optional[int]): Time to live in seconds, forwarded unchanged to
            ``Cache.set``; None means the entry does not expire.
    """
    _cache.set(key=key, value=value, ttl=ttl)

def clear_cache(key: str) -> None:
    """
    Drop *key* from the module-wide cache; a key that is not present is ignored.

    Args:
        key (str): Cache key to remove
    """
    _cache.clear(key=key)

def clear_all_cache() -> None:
    """
    Empty the module-wide cache.

    Every entry is discarded, whether or not it has expired.
    """
    _cache.clear_all()

def cleanup_expired_cache() -> None:
    """
    Evict expired entries from the module-wide cache.

    Entries without a TTL, and entries whose TTL has not yet elapsed, are kept.
    """
    _cache.cleanup_expired()