"""
Dataset Storage Module

Manages call data storage including audio chunks, transcripts, and metadata.
"""

import os
import csv
import json
import wave
from pathlib import Path
from typing import Optional, Union

from ..config import CALLS_ROOT, WHISPER_MODEL_NAME, WHISPER_LANGUAGE, log_audio


class CallDataStore:
    """
    Manager for call dataset storage.

    Handles directory structure, audio chunks, transcripts, and metadata.

    On-disk layout per call (relative to ``base_path``):
        <call_id>/metadata.csv
        <call_id>/audio/sessions/<session_id>/chunk_NNNN.wav
        <call_id>/audio/merged.wav
        <call_id>/transcript/utterances.jsonl
        <call_id>/transcript/final.txt, final.json
        <call_id>/llm/events.jsonl
    """

    def __init__(self, base_path: Union[str, Path] = CALLS_ROOT):
        """
        Initialize data store.

        Args:
            base_path: Root directory for call data (created if missing)
        """
        self.base_path = Path(base_path)
        self.base_path.mkdir(parents=True, exist_ok=True)

    def get_call_dir(self, call_id: str) -> Path:
        """Get root directory for a call, creating it if necessary."""
        path = self.base_path / call_id
        path.mkdir(parents=True, exist_ok=True)
        return path

    def get_audio_session_dir(self, call_id: str, session_id: str) -> Path:
        """Get audio session directory for a call, creating it if necessary."""
        path = self.get_call_dir(call_id) / "audio" / "sessions" / str(session_id)
        path.mkdir(parents=True, exist_ok=True)
        return path

    def get_transcript_dir(self, call_id: str) -> Path:
        """Get transcript directory for a call, creating it if necessary."""
        path = self.get_call_dir(call_id) / "transcript"
        path.mkdir(parents=True, exist_ok=True)
        return path

    def get_llm_dir(self, call_id: str) -> Path:
        """Get LLM events directory for a call, creating it if necessary."""
        path = self.get_call_dir(call_id) / "llm"
        path.mkdir(parents=True, exist_ok=True)
        return path

    def save_audio_chunk(
        self,
        call_id: str,
        session_id: str,
        chunk_index: int,
        pcm_16k: bytes,
        sample_rate: int = 16000,
    ) -> Path:
        """
        Save audio chunk to disk as mono 16-bit WAV.

        Args:
            call_id: Call identifier
            session_id: Session identifier
            chunk_index: Chunk number (zero-padded to 4 digits in the filename,
                so lexicographic sort matches numeric order up to 9999 chunks)
            pcm_16k: 16kHz 16-bit PCM audio data
            sample_rate: Sample rate (default 16kHz)

        Returns:
            Path to saved WAV file
        """
        session_dir = self.get_audio_session_dir(call_id, session_id)
        chunk_path = session_dir / f"chunk_{chunk_index:04d}.wav"

        with wave.open(str(chunk_path), "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)  # 16-bit samples
            wf.setframerate(sample_rate)
            wf.writeframes(pcm_16k)

        log_audio.info("call=%s chunk=%d saved to %s", call_id, chunk_index, chunk_path)
        return chunk_path

    def save_metadata_row(self, call_id: str, row: dict):
        """
        Append metadata row to CSV.

        The header is written when the file is new *or* empty, so a
        previously created-but-empty file still ends up with column names.

        Args:
            call_id: Call identifier
            row: Metadata dictionary; its keys define the CSV columns
        """
        meta_path = self.get_call_dir(call_id) / "metadata.csv"
        # exists() alone would skip the header for an existing-but-empty file
        has_header = meta_path.exists() and meta_path.stat().st_size > 0

        with open(meta_path, "a", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=row.keys())
            if not has_header:
                writer.writeheader()
            writer.writerow(row)

    def save_utterance(self, call_id: str, utterance: dict):
        """
        Append utterance to JSONL file.

        Args:
            call_id: Call identifier
            utterance: Utterance dictionary (ts, chunk, utter_ms, text, etc.)
        """
        segments_path = self.get_transcript_dir(call_id) / "utterances.jsonl"

        with open(segments_path, "a", encoding="utf-8") as f:
            json.dump(utterance, f, ensure_ascii=False)
            f.write("\n")

    def save_llm_event(self, call_id: str, event: dict):
        """
        Append LLM event to JSONL file.

        Args:
            call_id: Call identifier
            event: LLM event dictionary
        """
        events_path = self.get_llm_dir(call_id) / "events.jsonl"

        with open(events_path, "a", encoding="utf-8") as f:
            json.dump(event, f, ensure_ascii=False)
            f.write("\n")

    def save_final_transcript(self, call_id: str, text: str):
        """
        Save final transcript (both .txt and .json formats).

        Empty/falsy text is a no-op: nothing is written.

        Args:
            call_id: Call identifier
            text: Final transcript text
        """
        if not text:
            return

        transcript_dir = self.get_transcript_dir(call_id)

        # Save as plain text
        txt_path = transcript_dir / "final.txt"
        with open(txt_path, "w", encoding="utf-8") as f:
            f.write(text + "\n")

        # Save as JSON with model/language metadata for reproducibility
        json_path = transcript_dir / "final.json"
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(
                {
                    "call_id": call_id,
                    "model": WHISPER_MODEL_NAME,
                    "language": WHISPER_LANGUAGE,
                    "final_transcript": text,
                },
                f,
                indent=2,
                ensure_ascii=False,
            )

    def merge_call_audio(self, call_id: str) -> Optional[Path]:
        """
        Merge all audio sessions into single WAV file.

        Sessions and chunks are concatenated in lexicographic order, which
        matches recording order given the zero-padded chunk filenames.
        The output header is fixed at 16 kHz mono 16-bit, matching what
        save_audio_chunk writes by default.

        Args:
            call_id: Call identifier

        Returns:
            Path to merged WAV file, or None if no audio found
        """
        call_dir = self.get_call_dir(call_id)
        sessions_dir = call_dir / "audio" / "sessions"
        out_path = call_dir / "audio" / "merged.wav"

        if not sessions_dir.is_dir():
            log_audio.warning("call=%s no sessions directory found", call_id)
            return None

        # Collect all chunks from all sessions
        all_chunks = []
        for session_folder in sorted(sessions_dir.iterdir()):
            if not session_folder.is_dir():
                continue
            for chunk_file in sorted(session_folder.iterdir()):
                if chunk_file.suffix == ".wav":
                    all_chunks.append(chunk_file)

        if not all_chunks:
            log_audio.warning("call=%s no audio chunks found", call_id)
            return None

        # Read each chunk's PCM payload; join once at the end to avoid
        # quadratic bytes concatenation on long calls. Unreadable chunks
        # are logged and skipped (best-effort merge).
        parts = []
        for chunk_path in all_chunks:
            try:
                with wave.open(str(chunk_path), "rb") as wf:
                    parts.append(wf.readframes(wf.getnframes()))
            except Exception as e:
                log_audio.error("call=%s failed to read chunk %s: %r", call_id, chunk_path, e)

        pcm = b"".join(parts)
        if not pcm:
            return None

        # Write merged file
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with wave.open(str(out_path), "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(16000)
            wf.writeframes(pcm)

        # 16000 samples/s * 2 bytes/sample (mono)
        duration = len(pcm) / (16000 * 2)
        log_audio.info(
            "call=%s merged audio saved, duration=%.1fs, chunks=%d",
            call_id, duration, len(all_chunks)
        )

        return out_path

    def list_calls(self) -> list[str]:
        """List all call IDs in storage"""
        return [d.name for d in self.base_path.iterdir() if d.is_dir()]

    def get_call_metadata(self, call_id: str) -> dict:
        """
        Get summary metadata for a call.

        Read-only: unlike get_call_dir, this does not create any
        directories for unknown call IDs.
        """
        call_dir = self.base_path / call_id

        return {
            "call_id": call_id,
            "has_merged_audio": (call_dir / "audio" / "merged.wav").exists(),
            "has_final_transcript": (call_dir / "transcript" / "final.txt").exists(),
            "has_llm_events": (call_dir / "llm" / "events.jsonl").exists(),
        }


# Process-wide data store singleton (lazily created)
_data_store: Optional[CallDataStore] = None


def get_data_store() -> CallDataStore:
    """Return the shared CallDataStore, constructing it on first access."""
    global _data_store
    store = _data_store
    if store is None:
        store = CallDataStore()
        _data_store = store
    return store
