#!/usr/bin/env python3
"""
Merge Call Audio Utility

Merges audio chunks from all sessions of a call into a single WAV file.

Usage:
    python -m tools.merge_calls              # Merge all calls
    python -m tools.merge_calls <call_id>    # Merge specific call
"""

import argparse
import os
import re
import sys
import wave
from pathlib import Path

# Project paths
# Repository root: two directories above this file (the module is run as
# ``tools.merge_calls``, so presumably it lives at <root>/tools/ — confirm).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Each call is a sub-directory of this folder, named by its call_id.
CALLS_ROOT = PROJECT_ROOT / "dataset" / "calls"

# Audio format every chunk must have to be merged; chunks that differ are
# skipped with a warning. The merged WAV is written in this same format.
EXPECTED_SAMPLE_RATE = 16000  # Hz
EXPECTED_CHANNELS = 1  # mono
EXPECTED_SAMPLE_WIDTH = 2  # bytes (16-bit)


def _natural_sort_key(path: Path):
    """Return a sort key for *path* that compares digit runs numerically.

    Plain lexicographic sorting places ``chunk_10.wav`` before ``chunk_2.wav``;
    this key places it after. Zero-padded or fixed-width numeric names keep
    the same order they had lexicographically. The raw name is appended as a
    final tie-breaker so the order is total (e.g. ``a01`` vs ``a1``).
    """
    # With a capturing group, re.split alternates text / digit runs, so odd
    # indices are always digit runs and each position compares like types.
    parts = re.split(r"(\d+)", path.name, flags=re.ASCII)
    return [int(p) if i % 2 else p for i, p in enumerate(parts)], path.name


def _collect_chunk_paths(sessions_dir: Path) -> "list[Path]":
    """Return every ``.wav`` entry directly inside each session folder.

    Session folders and the chunks within each are ordered by natural
    (numeric-aware) name; this assumes names encode recording order — confirm
    against the writer of these files.
    """
    chunks = []
    for session_folder in sorted(sessions_dir.iterdir(), key=_natural_sort_key):
        if not session_folder.is_dir():
            continue
        chunks.extend(
            chunk_file
            for chunk_file in sorted(session_folder.iterdir(), key=_natural_sort_key)
            if chunk_file.suffix == ".wav"
        )
    return chunks


def _read_chunk_pcm(wav_path: Path) -> "bytes | None":
    """Return the raw PCM frames of one chunk, or None if it must be skipped.

    A chunk is skipped (with a printed warning) when its format differs from
    the EXPECTED_* constants or when it cannot be read at all. A trailing
    partial frame (truncated file) is dropped so that chunks concatenated
    after this one stay frame-aligned.
    """
    frame_size = EXPECTED_SAMPLE_WIDTH * EXPECTED_CHANNELS
    try:
        with wave.open(str(wav_path), "rb") as wf:
            sr = wf.getframerate()
            ch = wf.getnchannels()
            sw = wf.getsampwidth()
            expected = (
                EXPECTED_SAMPLE_RATE,
                EXPECTED_CHANNELS,
                EXPECTED_SAMPLE_WIDTH,
            )
            if (sr, ch, sw) != expected:
                print(
                    f"  Warning: Incompatible WAV format in {wav_path}: "
                    f"{sr}Hz, {ch}ch, {sw * 8}bit (skipped)"
                )
                return None
            pcm = wf.readframes(wf.getnframes())
    except Exception as e:  # best effort: one bad chunk must not sink the call
        print(f"  Warning: Failed to read {wav_path}: {e}")
        return None
    return pcm[: len(pcm) - len(pcm) % frame_size]


def _write_merged_wav(output_path: Path, pcm: bytes) -> None:
    """Write *pcm* as a WAV file at *output_path* in the EXPECTED_* format.

    The data is written to a sibling ``<name>.tmp`` file first and moved into
    place with ``os.replace``, so an interrupted write never leaves a
    truncated file at *output_path*. The temp file is removed on failure and
    the exception is re-raised.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = output_path.with_name(output_path.name + ".tmp")
    try:
        with wave.open(str(tmp_path), "wb") as wf:
            wf.setnchannels(EXPECTED_CHANNELS)
            wf.setsampwidth(EXPECTED_SAMPLE_WIDTH)
            wf.setframerate(EXPECTED_SAMPLE_RATE)
            wf.writeframes(pcm)
        os.replace(tmp_path, output_path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise


def merge_call(call_id: str) -> bool:
    """
    Merge all audio sessions for a call into merged.wav.

    Chunks are read from ``CALLS_ROOT/<call_id>/audio/sessions/<session>/*.wav``
    (sessions and chunks in natural name order), concatenated, and written to
    ``CALLS_ROOT/<call_id>/audio/merged.wav``. Chunks with an unexpected
    format or that cannot be read are skipped with a warning.

    Args:
        call_id: Call identifier (directory name)

    Returns:
        True if merged.wav was written; False if there is no sessions
        directory, no chunk files, no usable audio, or the output could not
        be written.
    """
    call_dir = CALLS_ROOT / call_id
    sessions_dir = call_dir / "audio" / "sessions"
    output_path = call_dir / "audio" / "merged.wav"

    if not sessions_dir.is_dir():
        print(f"No sessions directory found for call {call_id}")
        return False

    all_chunks = _collect_chunk_paths(sessions_dir)
    if not all_chunks:
        print(f"No chunks found for call {call_id}")
        return False

    # Gather parts and join once; repeated ``bytes +=`` is quadratic.
    pcm_parts = [
        pcm for pcm in map(_read_chunk_pcm, all_chunks) if pcm is not None
    ]
    merged_pcm = b"".join(pcm_parts)
    if not merged_pcm:
        print(f"No valid audio data for call {call_id}")
        return False

    try:
        _write_merged_wav(output_path, merged_pcm)
    except (OSError, wave.Error) as e:
        print(f"Failed to write {output_path}: {e}")
        return False

    # Derive duration from the bytes actually merged, not header frame counts,
    # so truncated chunks are not over-counted.
    total_frames = len(merged_pcm) // (EXPECTED_SAMPLE_WIDTH * EXPECTED_CHANNELS)
    duration_sec = total_frames / EXPECTED_SAMPLE_RATE

    print(
        f"Merged call {call_id} | "
        f"chunks={len(pcm_parts)} | "
        f"skipped={len(all_chunks) - len(pcm_parts)} | "
        f"duration={duration_sec:.1f}s | "
        f"file={output_path}"
    )

    return True


def merge_all_calls():
    """Merge audio for every call directory under CALLS_ROOT.

    Calls are processed in sorted directory-name order. Each call is isolated:
    an unexpected exception while merging one call is printed and counted as
    a failure instead of aborting the remaining calls. A summary line with
    merged/failed counts is printed at the end.
    """
    if not CALLS_ROOT.is_dir():
        print(f"Calls directory not found: {CALLS_ROOT}")
        return

    call_ids = [d.name for d in sorted(CALLS_ROOT.iterdir()) if d.is_dir()]

    if not call_ids:
        print("No call folders found")
        return

    print(f"Found {len(call_ids)} calls to process\n")

    success = 0
    failed = 0

    for call_id in call_ids:
        try:
            merged = merge_call(call_id)
        except Exception as e:  # batch boundary: keep going with other calls
            print(f"Error merging call {call_id}: {e}")
            merged = False
        if merged:
            success += 1
        else:
            failed += 1

    print(f"\nDone: {success} merged, {failed} failed")


def main(argv=None) -> int:
    """Command-line entry point.

    With no positional arguments every call under CALLS_ROOT is merged;
    otherwise each given call_id is merged in turn.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit status: 1 if any explicitly requested call failed to
        merge, otherwise 0. Merge-all mode always returns 0 because
        merge_all_calls does not report per-call status.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Merge audio chunks from all sessions of a call into a single "
            "WAV file."
        ),
    )
    parser.add_argument(
        "call_ids",
        nargs="*",
        metavar="call_id",
        help="call directory name(s) to merge; omit to merge all calls",
    )
    args = parser.parse_args(argv)

    if not args.call_ids:
        merge_all_calls()
        return 0

    # Build the full list (no short-circuit) so every id is attempted.
    results = [merge_call(call_id) for call_id in args.call_ids]
    return 0 if all(results) else 1


if __name__ == "__main__":
    # Propagate main()'s status so shell scripts can detect failed merges
    # (sys.exit(None) exits with status 0).
    sys.exit(main())
