"""
Voice Activity Detection (VAD) Module

Energy-based VAD using RMS with adaptive noise floor estimation.
"""

import audioop
from dataclasses import dataclass, field
from typing import Optional

from ..config import (
    SAMPLE_RATE_IN,
    BYTES_PER_SAMPLE,
    VAD_FRAME_MS,
    END_SILENCE_MS,
    MIN_UTTERANCE_MS,
    MIN_RMS_THRESHOLD,
    NOISE_MULTIPLIER,
    NOISE_EMA_ALPHA,
    FRAME_BYTES_8K,
    MIN_PCM_BYTES_8K,
)


@dataclass
class VADDetector:
    """
    Voice Activity Detector using energy-based RMS thresholding.

    Audio is consumed in fixed-size frames. Each frame's RMS is compared with
    an adaptive threshold: the larger of ``min_rms_threshold`` and
    ``noise_multiplier`` times an exponential moving average (EMA) of the frame
    RMS seen while no speech segment is in progress. A segment starts on the
    first frame at or above the threshold. It ends once ``end_silence_ms`` of
    consecutive below-threshold frames have accumulated.

    Attributes:
        frame_ms: Duration of each analysis frame in milliseconds
        end_silence_ms: Silence duration (ms) to mark end of utterance
        min_utterance_ms: Minimum utterance length (ms) to accept
        min_rms_threshold: Absolute minimum RMS threshold
        noise_multiplier: Factor to multiply noise floor for threshold
        noise_ema_alpha: EMA smoothing factor for noise estimation (the weight
            given to the newest frame's RMS)
    """

    frame_ms: int = VAD_FRAME_MS
    end_silence_ms: int = END_SILENCE_MS
    min_utterance_ms: int = MIN_UTTERANCE_MS
    min_rms_threshold: int = MIN_RMS_THRESHOLD
    noise_multiplier: float = NOISE_MULTIPLIER
    noise_ema_alpha: float = NOISE_EMA_ALPHA

    # Internal state. These are ordinary dataclass fields, so they also show up
    # as underscore-prefixed __init__ parameters; callers should not set them.
    # Leftover bytes (shorter than one frame) carried between add_audio calls.
    _vad_buffer: bytes = field(default=b"", repr=False)
    # Frames of the current segment, including trailing silence frames.
    _speech_buffer: bytes = field(default=b"", repr=False)
    _speaking: bool = field(default=False, repr=False)
    # Trailing silence (ms) accumulated since the last speech frame.
    _silence_ms: int = field(default=0, repr=False)
    # EMA estimate of the background RMS level.
    _noise_rms: float = field(default=0.0, repr=False)
    # True after a successful commit; cleared when a new segment starts.
    _buffer_committed: bool = field(default=False, repr=False)

    @property
    def frame_bytes(self) -> int:
        """Number of bytes per VAD frame.

        The size is a whole number of samples times ``BYTES_PER_SAMPLE``, so a
        frame never splits a sample. ``audioop.rms`` rejects fragments whose
        length is not a multiple of the sample width. Integer floor division
        also avoids truncating a float that rounded just below an integer.
        """
        samples_per_frame = int(SAMPLE_RATE_IN * self.frame_ms // 1000)
        return samples_per_frame * BYTES_PER_SAMPLE

    @property
    def min_pcm_bytes(self) -> int:
        """Minimum PCM bytes for a valid utterance (``MIN_PCM_BYTES_8K``, ~400ms)"""
        return MIN_PCM_BYTES_8K

    @property
    def is_speaking(self) -> bool:
        """Whether a speech segment is currently in progress"""
        return self._speaking

    @property
    def speech_buffer(self) -> bytes:
        """Audio accumulated for the current (uncommitted) segment"""
        return self._speech_buffer

    @property
    def silence_ms(self) -> int:
        """Trailing silence accumulated in the current segment, in milliseconds"""
        return self._silence_ms

    def reset(self):
        """Reset per-utterance state.

        Keeps the noise floor estimate and any leftover partial frame.
        """
        self._speech_buffer = b""
        self._speaking = False
        self._silence_ms = 0
        self._buffer_committed = False

    def reset_all(self):
        """Full reset, including the partial-frame buffer and noise estimate"""
        self.reset()
        self._vad_buffer = b""
        self._noise_rms = 0.0

    def add_audio(self, pcm_8k: bytes) -> list[dict]:
        """
        Add audio data and process for voice activity.

        The new bytes are appended to any partial frame left over from the
        previous call. Every complete frame is analysed, and the remaining
        partial frame is kept for the next call.

        Args:
            pcm_8k: 8kHz 16-bit PCM audio data

        Returns:
            List of detected utterances (possibly empty), each containing:
                - pcm_8k: The speech audio buffer
                - duration_ms: Duration in milliseconds

        Raises:
            ValueError: If the configured frame size is not positive. Such a
                size would otherwise make the framing loop spin forever.
        """
        frame_bytes = self.frame_bytes
        if frame_bytes <= 0:
            raise ValueError(
                f"frame_ms={self.frame_ms} yields a non-positive frame size "
                f"({frame_bytes} bytes)"
            )

        data = self._vad_buffer + pcm_8k
        utterances = []
        offset = 0
        try:
            # Advance an offset rather than re-slicing the remainder after
            # every frame. Re-slicing copies the whole tail each iteration.
            while len(data) - offset >= frame_bytes:
                frame = data[offset : offset + frame_bytes]
                offset += frame_bytes
                result = self._process_frame(frame)
                if result is not None:
                    utterances.append(result)
        finally:
            # Keep only unconsumed bytes, even if frame processing raised.
            self._vad_buffer = data[offset:]

        return utterances

    def _process_frame(self, frame: bytes) -> Optional[dict]:
        """
        Process a single audio frame for VAD.

        Args:
            frame: One frame (``frame_bytes`` bytes) of PCM audio.

        Returns:
            Utterance dict if this frame completed a speech segment that passed
            the length checks, None otherwise.
        """
        # NOTE(review): audioop is deprecated since Python 3.11 and removed in
        # 3.13 (PEP 594); this call needs a replacement before upgrading.
        rms = audioop.rms(frame, BYTES_PER_SAMPLE)

        # Update the noise floor while no segment is in progress. This runs
        # before classification, so the frame that starts a segment also
        # contributes to the estimate.
        if not self._speaking:
            self._noise_rms = (
                (1.0 - self.noise_ema_alpha) * self._noise_rms
                + self.noise_ema_alpha * rms
            )

        if rms >= self.get_threshold():
            if not self._speaking:
                # Speech start
                self._speaking = True
                self._buffer_committed = False
            self._speech_buffer += frame
            self._silence_ms = 0
            return None

        if self._speaking:
            # Trailing silence: the frame is kept in the segment while the
            # silence counter runs toward end_silence_ms.
            self._silence_ms += self.frame_ms
            self._speech_buffer += frame
            if self._silence_ms >= self.end_silence_ms:
                return self._commit_utterance()

        return None

    def _commit_utterance(self) -> Optional[dict]:
        """
        Close the current segment and return it if it is long enough.

        Always ends the speaking state. If the buffered audio is shorter than
        ``min_utterance_ms`` or ``min_pcm_bytes``, it is discarded via
        ``reset()`` and None is returned. Otherwise the buffer is returned in
        the result dict and the internal buffer is cleared.
        """
        self._speaking = False

        # Treats the buffer as single-channel audio at SAMPLE_RATE_IN.
        duration_ms = int(
            (len(self._speech_buffer) / (BYTES_PER_SAMPLE * SAMPLE_RATE_IN)) * 1000
        )

        too_short = (
            duration_ms < self.min_utterance_ms
            or len(self._speech_buffer) < self.min_pcm_bytes
        )
        if too_short:
            self.reset()
            return None

        result = {
            "pcm_8k": self._speech_buffer,
            "duration_ms": duration_ms,
        }

        self._buffer_committed = True
        self._speech_buffer = b""
        self._silence_ms = 0

        return result

    def force_commit(self) -> Optional[dict]:
        """
        Force commit current speech buffer (for wall-clock timeout).

        Returns utterance dict if buffer is valid, None otherwise (empty
        buffer, already committed, or too short).
        """
        if not self._speech_buffer or self._buffer_committed:
            return None

        return self._commit_utterance()

    def get_threshold(self) -> int:
        """Get current adaptive RMS threshold"""
        return max(self.min_rms_threshold, int(self._noise_rms * self.noise_multiplier))

    def get_noise_floor(self) -> float:
        """Get current estimated noise RMS"""
        return self._noise_rms
