"""
Dialogue State Management

Manages conversation state, cooldowns, and call persistence.
"""

import asyncio
import time
from dataclasses import dataclass, field
from typing import Optional, Dict, List

from ..config import (
    AI_SPEAK_COOLDOWN_SEC,
    MAX_CLARIFICATION_COUNT,
    STALE_CALL_TIMEOUT_SEC,
    log_call,
)


@dataclass
class DialogueState:
    """
    State machine for managing dialogue flow.

    Tracks when the AI last spoke, whether a user reply is pending, the
    number of clarification turns, the locked intent, and whether the
    conversation (graceful end) or call (hangup) has finished.
    """

    # Wall-clock time (time.time()) of the last AI utterance; 0.0 = never.
    last_ai_speak_ts: float = 0.0
    # Set when the AI speaks; cleared by accepted user input or by
    # end_conversation().
    awaiting_user: bool = False
    # Clarification turns since the last reset_clarify().
    clarify_count: int = 0
    # Set by end_conversation() (graceful end).
    conversation_complete: bool = False
    # Set by end_call() (hangup).
    call_ended: bool = False
    # First non-empty intent passed to lock_intent(); never overwritten.
    locked_intent: Optional[str] = None

    def can_speak(self, cooldown_sec: float = AI_SPEAK_COOLDOWN_SEC) -> bool:
        """Return True once ``cooldown_sec`` seconds have passed since the AI last spoke."""
        return time.time() - self.last_ai_speak_ts >= cooldown_sec

    def record_ai_speak(self) -> None:
        """Record that the AI just spoke; a user reply is now awaited."""
        self.last_ai_speak_ts = time.time()
        self.awaiting_user = True

    def record_user_input(self, *, min_gap_sec: float = 0.8) -> None:
        """
        Record that the user started speaking.

        User input arriving within ``min_gap_sec`` seconds of the last AI
        utterance does NOT clear ``awaiting_user``. Presumably this ignores
        echo/noise picked up right after the AI speaks
        (NOTE(review): confirm intent with the audio pipeline owners).

        Args:
            min_gap_sec: Minimum seconds since the last AI utterance before
                user input clears ``awaiting_user``. Defaults to 0.8.
        """
        if time.time() - self.last_ai_speak_ts > min_gap_sec:
            self.awaiting_user = False

    def increment_clarify(self) -> None:
        """Increment the clarification counter by one."""
        self.clarify_count += 1

    def reset_clarify(self) -> None:
        """Reset the clarification counter to zero."""
        self.clarify_count = 0

    def should_escalate(self, max_count: int = MAX_CLARIFICATION_COUNT) -> bool:
        """
        Return True when clarifications strictly exceed ``max_count``.

        With ``max_count == N`` this first returns True after N + 1
        increments, i.e. ``max_count`` clarifications are still allowed.
        """
        return self.clarify_count > max_count

    def lock_intent(self, intent: str) -> None:
        """
        Lock the conversation to ``intent``.

        Only the first non-empty intent is kept; later calls, and calls with
        an empty/None intent, are no-ops.
        """
        if intent and self.locked_intent is None:
            self.locked_intent = intent

    def end_conversation(self) -> None:
        """Mark the conversation as complete (graceful end) and stop awaiting the user."""
        self.conversation_complete = True
        self.awaiting_user = False

    def end_call(self) -> None:
        """Mark the call as ended (hangup). ``awaiting_user`` is left unchanged."""
        self.call_ended = True

    def is_active(self) -> bool:
        """Return True while neither the call has ended nor the conversation completed."""
        return not self.call_ended and not self.conversation_complete

    def to_dict(self) -> dict:
        """Export every state field as a plain dictionary (inverse of ``from_dict``)."""
        return {
            "last_ai_speak_ts": self.last_ai_speak_ts,
            "awaiting_user": self.awaiting_user,
            "clarify_count": self.clarify_count,
            "conversation_complete": self.conversation_complete,
            "call_ended": self.call_ended,
            "locked_intent": self.locked_intent,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "DialogueState":
        """
        Build a state from a dictionary produced by ``to_dict``.

        Missing keys fall back to the same defaults as the dataclass fields;
        unknown keys are ignored.
        """
        return cls(
            last_ai_speak_ts=data.get("last_ai_speak_ts", 0.0),
            awaiting_user=data.get("awaiting_user", False),
            clarify_count=data.get("clarify_count", 0),
            conversation_complete=data.get("conversation_complete", False),
            call_ended=data.get("call_ended", False),
            locked_intent=data.get("locked_intent"),
        )


@dataclass
class CallState:
    """
    Persistent state for a single call.

    Survives WebSocket reconnections within timeout period.
    """

    call_id: str
    # Transcribed text segments, in the order they were committed.
    committed_texts: List[str] = field(default_factory=list)
    # LLM messages as {"role", "content", "ts"} dicts, oldest first.
    llm_history: List[dict] = field(default_factory=list)
    dialogue_state: DialogueState = field(default_factory=DialogueState)
    # One lock per call; presumably serializes LLM requests for this call
    # (NOTE(review): confirm against callers).
    llm_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # Wall-clock time (time.time()) of the last activity on this call.
    last_activity: float = field(default_factory=time.time)

    def update_activity(self) -> None:
        """Refresh ``last_activity`` to the current wall-clock time."""
        self.last_activity = time.time()

    def is_stale(self, timeout_sec: float = STALE_CALL_TIMEOUT_SEC) -> bool:
        """Return True if more than ``timeout_sec`` seconds have passed since the last activity."""
        return time.time() - self.last_activity > timeout_sec

    def add_transcript(self, text: str) -> None:
        """Append a transcribed text segment to ``committed_texts``."""
        self.committed_texts.append(text)

    def get_final_transcript(self) -> str:
        """Return all committed segments joined by single spaces, with the ends stripped."""
        return " ".join(self.committed_texts).strip()

    def add_to_history(self, role: str, content: str) -> None:
        """Append a message with its role and a wall-clock timestamp to the LLM history."""
        self.llm_history.append({
            "role": role,
            "content": content,
            "ts": time.time()
        })

    def trim_history(self, max_length: int) -> None:
        """
        Keep only the newest ``max_length`` entries of the LLM history.

        The list is modified in place, so other references to
        ``llm_history`` observe the trimmed result. ``max_length <= 0``
        clears the history.
        """
        # An explicit count is used because ``lst[-0:]`` is the WHOLE list,
        # so slicing with ``[-max_length:]`` cannot express "keep none".
        excess = len(self.llm_history) - max(max_length, 0)
        if excess > 0:
            # Drop the oldest entries from the front.
            del self.llm_history[:excess]


class CallStateManager:
    """
    Registry of per-call state objects, keyed by call ID.

    Creates, looks up, flags and removes ``CallState`` entries. Staleness
    detection only reports IDs; removal is left to the caller.
    """

    def __init__(self):
        # call_id -> CallState
        self._states: Dict[str, CallState] = {}

    def get(self, call_id: str) -> Optional[CallState]:
        """Return the state for ``call_id``, or None if it is not registered."""
        return self._states.get(call_id)

    def get_or_create(self, call_id: str) -> CallState:
        """
        Return the state for ``call_id``, creating it on first use.

        Either way the state's activity timestamp is refreshed.
        """
        existing = self._states.get(call_id)
        if existing is None:
            existing = CallState(call_id=call_id)
            self._states[call_id] = existing
            log_call.info("call=%s new state created", call_id)
        else:
            log_call.info("call=%s state restored", call_id)

        existing.update_activity()
        return existing

    def remove(self, call_id: str):
        """Drop the state for ``call_id``; unknown IDs are silently ignored."""
        dropped = self._states.pop(call_id, None)
        if dropped is not None:
            log_call.info("call=%s state removed", call_id)

    def cleanup_stale(self, timeout_sec: float = STALE_CALL_TIMEOUT_SEC) -> List[str]:
        """
        Return the IDs of calls idle for more than ``timeout_sec`` seconds.

        Nothing is removed here - the caller should finalize each call
        before calling ``remove``. A single clock reading is shared by all
        comparisons.
        """
        snapshot = time.time()
        return [
            cid
            for cid, call in self._states.items()
            if snapshot - call.last_activity > timeout_sec
        ]

    def mark_ended(self, call_id: str):
        """Flag the call's dialogue as ended (e.g. from a webhook); unknown IDs are ignored."""
        target = self._states.get(call_id)
        if target is None:
            return
        target.dialogue_state.end_call()
        log_call.info("call=%s marked as ended", call_id)

    def __len__(self) -> int:
        """Number of registered calls."""
        return len(self._states)

    def __contains__(self, call_id: str) -> bool:
        """True if ``call_id`` has a registered state."""
        return call_id in self._states


# Module-level singleton: every importer of this module shares this instance.
call_state_manager: CallStateManager = CallStateManager()
