import os
import re
import wave
from glob import glob

# Input root (relative to the current working directory): one subdirectory
# per call, each holding that call's "*.wav" chunks.
AUDIO_ROOT = "dataset/audio"
# Output root: merge_call writes one "<call_id>.wav" file per call here.
MERGED_ROOT = "dataset/merged"

# Required format for every input chunk. merge_call raises ValueError for any
# chunk that differs, and writes the merged file with these same parameters.
EXPECTED_SAMPLE_RATE = 16000  # frames per second (Hz)
EXPECTED_CHANNELS = 1
EXPECTED_SAMPLE_WIDTH = 2  # bytes (16-bit)

# Import-time side effect: make sure the output directory exists so that
# merge_call can open its output file without checking first.
os.makedirs(MERGED_ROOT, exist_ok=True)


# Matches runs of ASCII digits; the capturing group keeps them in re.split output.
_DIGIT_RUN = re.compile(r"([0-9]+)")


def _natural_sort_key(path: str) -> tuple:
    """Return a sort key that orders embedded numbers numerically.

    ``chunk_2.wav`` sorts before ``chunk_10.wav`` (plain string sorting puts
    it after). The full path is appended as a tie-breaker so names such as
    ``chunk_1`` / ``chunk_01`` still get a deterministic order.
    """
    parts = _DIGIT_RUN.split(os.path.basename(path))
    # re.split with one capturing group alternates text, digits, text, ...,
    # so odd indices are always digit runs. Keys therefore line up and ints
    # are only ever compared with ints.
    key = tuple(int(part) if i % 2 else part for i, part in enumerate(parts))
    return key, path


def merge_call(call_id: str):
    """Concatenate all WAV chunks of one call into ``MERGED_ROOT/<call_id>.wav``.

    Chunks are the ``*.wav`` files in ``AUDIO_ROOT/<call_id>``. They are joined
    in natural filename order, so numbered chunks stay in sequence. Every
    chunk must be EXPECTED_SAMPLE_RATE Hz, EXPECTED_CHANNELS channel(s) and
    EXPECTED_SAMPLE_WIDTH bytes per sample. The output file uses that format.

    If the call directory is missing or holds no chunks, a message is printed
    and the function returns without writing anything.

    Args:
        call_id: Name of the call's subdirectory under AUDIO_ROOT; also the
            stem of the output file.

    Raises:
        ValueError: A chunk has a different sample rate, channel count or
            sample width. Raised before the output file is opened, so no
            partial output is written.
        Errors raised by ``wave.open`` on an unreadable chunk (for example
        ``wave.Error``) propagate unchanged.
    """
    call_dir = os.path.join(AUDIO_ROOT, call_id)
    if not os.path.isdir(call_dir):
        print(f"Call directory not found: {call_id}")
        return

    chunk_files = sorted(
        glob(os.path.join(call_dir, "*.wav")), key=_natural_sort_key
    )
    if not chunk_files:
        print(f"No chunks found for call {call_id}")
        return

    # Collect the parts and join them once; `bytes += bytes` copies the
    # whole buffer on every chunk (quadratic in the call length).
    pcm_parts = []
    for wav_path in chunk_files:
        with wave.open(wav_path, "rb") as wf:
            sr = wf.getframerate()
            ch = wf.getnchannels()
            sw = wf.getsampwidth()

            if (
                sr != EXPECTED_SAMPLE_RATE
                or ch != EXPECTED_CHANNELS
                or sw != EXPECTED_SAMPLE_WIDTH
            ):
                raise ValueError(
                    f"Incompatible WAV format in {wav_path}: "
                    f"{sr}Hz, {ch}ch, {sw * 8}bit"
                )

            pcm_parts.append(wf.readframes(wf.getnframes()))

    merged_pcm = b"".join(pcm_parts)
    # Count frames from the bytes actually read. For a truncated chunk the
    # header's frame count overstates what readframes returns.
    total_frames = len(merged_pcm) // (EXPECTED_SAMPLE_WIDTH * EXPECTED_CHANNELS)

    out_path = os.path.join(MERGED_ROOT, f"{call_id}.wav")

    with wave.open(out_path, "wb") as wf:
        wf.setnchannels(EXPECTED_CHANNELS)
        wf.setsampwidth(EXPECTED_SAMPLE_WIDTH)
        wf.setframerate(EXPECTED_SAMPLE_RATE)
        wf.writeframes(merged_pcm)

    duration_sec = total_frames / EXPECTED_SAMPLE_RATE

    print(
        f"Merged call {call_id} | "
        f"chunks={len(chunk_files)} | "
        f"duration≈{duration_sec:.2f}s | "
        f"file={out_path}"
    )


def merge_all_calls():
    """Run merge_call for every call directory directly under ``AUDIO_ROOT``.

    Each immediate subdirectory of AUDIO_ROOT is treated as a call id; plain
    files there are ignored. Calls are processed in sorted name order, so
    repeated runs behave and log identically (``os.listdir`` order is
    arbitrary).

    If AUDIO_ROOT does not exist or has no subdirectories, a message is
    printed and the function returns. Exceptions raised by merge_call (for
    example ValueError for an incompatible chunk) propagate and stop the run.
    """
    if not os.path.isdir(AUDIO_ROOT):
        print(f"Audio root not found: {AUDIO_ROOT}")
        return

    call_ids = sorted(
        entry
        for entry in os.listdir(AUDIO_ROOT)
        if os.path.isdir(os.path.join(AUDIO_ROOT, entry))
    )

    if not call_ids:
        print("No call folders found")
        return

    for call_id in call_ids:
        merge_call(call_id)


# Script entry point: when run directly, merge every call under AUDIO_ROOT.
if __name__ == "__main__":
    merge_all_calls()
