"""
RingAI – Real-time IVR + Near-Real-Time STT (Whisper-base, CPU)

Pipeline:
Caller → Twilio Media Stream → WebSocket
→ μ-law decode → 8 kHz PCM → buffer (1s)
→ resample to 16 kHz → save WAV chunks (dataset)
→ rolling context window → Whisper-base (CPU)
→ incremental live transcription (console)
→ save FINAL call transcript to disk (call-level, not rolling-only)
"""

# Imports
from fastapi import FastAPI, WebSocket
from fastapi.responses import PlainTextResponse

import base64
import audioop
import time
import wave
import os
import csv
import asyncio
import tempfile
import whisper
import logging
import json


# =========================
# Audio constants
# =========================

# Inbound audio: Twilio streams at 8 kHz; after decoding each sample is
# 16-bit PCM, i.e. two bytes.
SAMPLE_RATE_IN = 8_000
BYTES_PER_SAMPLE = 2

# Outbound audio: Whisper expects 16 kHz.
SAMPLE_RATE_OUT = 16_000

# Audio is accumulated into chunks of roughly one second of 8 kHz PCM before
# being resampled and written out.
BUFFER_DURATION_SEC = 1
BUFFER_SIZE = BUFFER_DURATION_SEC * BYTES_PER_SAMPLE * SAMPLE_RATE_IN


# =========================
# Dataset paths
# =========================

# Everything the service records lives under one root directory:
#   audio/            per-call folders of WAV chunks
#   metadata/         one CSV per call describing its chunks
#   transcripts_raw/  final call transcripts (TXT + JSON)
DATASET_ROOT = "dataset"
AUDIO_ROOT = os.path.join(DATASET_ROOT, "audio")
META_ROOT = os.path.join(DATASET_ROOT, "metadata")
TRANSCRIPTS_ROOT = os.path.join(DATASET_ROOT, "transcripts_raw")

# Create the directory tree at import time so request handlers can write
# into it without checking first.
for _dataset_dir in (AUDIO_ROOT, META_ROOT, TRANSCRIPTS_ROOT):
    os.makedirs(_dataset_dir, exist_ok=True)
del _dataset_dir


# =========================
# Whisper (Realtime – CPU)
# =========================

# Model selection and decoding options passed to Whisper.
WHISPER_MODEL_NAME = "base"
WHISPER_LANGUAGE = "en"
# Half precision is disabled because inference runs on CPU.
WHISPER_FP16 = False

# Live transcription cadence: how often the live text is refreshed, and how
# many seconds of trailing audio Whisper is given as context each time.
TRANSCRIBE_EVERY_SEC = 3
ROLLING_WINDOW_SEC = 20

# 16 kHz * 2 bytes per sample = 32000 bytes of PCM per second.
BYTES_PER_SEC_16K = 2 * 16_000
ROLLING_MAX_BYTES = BYTES_PER_SEC_16K * ROLLING_WINDOW_SEC

# The model is loaded once at import time and shared by every call.
print(f"🔁 Loading Whisper model: {WHISPER_MODEL_NAME}")
whisper_model = whisper.load_model(WHISPER_MODEL_NAME)

# Only surface errors from the "whisper" logger to keep the console readable.
logging.getLogger("whisper").setLevel(logging.ERROR)


# =========================
# FastAPI app
# =========================

# Application instance; the HTTP and WebSocket routes defined below register
# themselves on it through the @app.post / @app.websocket decorators.
app = FastAPI()


# =========================
# Incoming call (TwiML)
# =========================

@app.post("/incoming-call")
async def incoming_call():
    """
    Webhook Twilio calls when a call comes in.

    Responds with a TwiML document that speaks a short greeting and then
    connects the call audio to this server's /media-stream WebSocket.
    """
    # One entry per line of the TwiML document; joined with newlines below.
    twiml_lines = (
        '<?xml version="1.0" encoding="UTF-8"?>',
        '<Response>',
        '  <Say voice="alice">Hello.</Say>',
        '  <Pause length="1"/>',
        '  <Say voice="alice">This is Ring A I.</Say>',
        '  <Pause length="1"/>',
        '  <Say voice="alice">Please speak clearly after the tone.</Say>',
        '',
        '  <Connect>',
        '    <Stream url="wss://ringai.southernaccountancy.com/media-stream"/>',
        '  </Connect>',
        '</Response>',
    )
    twiml_document = "\n".join(twiml_lines)
    return PlainTextResponse(content=twiml_document, media_type="application/xml")


# =========================
# WebSocket: Media Stream
# =========================

def _safe_path_component(value, fallback: str) -> str:
    """Reduce an untrusted identifier to characters safe in one path component.

    Only alphanumerics, "-" and "_" are kept, so values such as "../x" cannot
    escape the dataset directories. Returns *fallback* if nothing is left.
    """
    cleaned = "".join(ch for ch in str(value) if ch.isalnum() or ch in "-_")
    return cleaned or fallback


def _write_final_transcript(call_id: str, final_text: str) -> str:
    """Write the call transcript as TXT and JSON; return the TXT path."""
    # Human-friendly plain text
    txt_path = os.path.join(TRANSCRIPTS_ROOT, f"{call_id}.txt")
    with open(txt_path, "w", encoding="utf-8") as f:
        f.write(final_text + "\n")

    # Machine-friendly JSON
    json_path = os.path.join(TRANSCRIPTS_ROOT, f"{call_id}.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "call_id": call_id,
                "language": WHISPER_LANGUAGE,
                "model": WHISPER_MODEL_NAME,
                "final_transcript": final_text,
            },
            f,
            indent=2,
            ensure_ascii=False,
        )
    return txt_path


@app.websocket("/media-stream")
async def media_stream(ws: WebSocket):
    """
    Receive real-time call audio over a WebSocket and process it.

    Per connection this handler:
    - decodes base64 mu-law frames to 8 kHz PCM and buffers ~1 second,
    - resamples each buffered chunk to 16 kHz, saves it as a WAV file and
      appends a metadata row to the call's CSV,
    - keeps a rolling window of recent 16 kHz audio and periodically runs
      Whisper on it in a worker thread, printing newly recognised text,
    - when the stream ends (a "stop" event or a socket error/disconnect),
      waits for in-flight transcriptions, transcribes any audio that has not
      been covered yet, and writes the accumulated transcript to disk.

    Errors while receiving or processing messages are printed, not raised.
    """

    await ws.accept()
    print("🔊 Media stream connected")

    # ---- Chunk buffering (1 second) ----
    audio_buffer = b""
    buffer_start_time = None
    # Resampler state carried between chunks so chunk boundaries stay smooth.
    ratecv_state = None

    # ---- Call metadata ----
    call_id = "unknown_call"
    stream_id = "unknown_stream"
    chunk_index = 0

    # ---- Realtime STT state ----
    rolling_pcm_16k = b""
    last_transcribe_ts = 0.0
    # True while the rolling window holds audio no transcription has seen.
    has_untranscribed_audio = False

    # Whisper's latest output for the rolling window (changes over time)
    last_full_text = ""

    # Call-level transcript accumulator (only grows)
    call_transcript_parts = []

    # Prevent overlapping transcriptions
    transcribe_lock = asyncio.Lock()
    # Strong references to running tasks: keeps them from being garbage
    # collected and lets the handler wait for them before finalising.
    pending_tasks = set()

    async def run_transcribe(snapshot: bytes) -> None:
        """Transcribe *snapshot* and append any newly recognised text."""
        # The snapshot is a parameter (not a closed-over variable) so each
        # task transcribes exactly the audio it was scheduled with.
        nonlocal last_full_text
        try:
            async with transcribe_lock:
                text = await asyncio.to_thread(whisper_transcribe_bytes, snapshot)

                # Keep only the portion that is new since the previous pass
                delta = incremental_diff(last_full_text, text)
                if delta:
                    print(f"📝 LIVE: {delta}")
                    call_transcript_parts.append(delta)

                last_full_text = text
        except Exception as exc:
            # A failed pass must not take down the call or go unnoticed.
            print("❌ Transcription error:", exc)

    try:
        while True:
            message = await ws.receive_json()
            event = message.get("event")

            # -------------------------
            # Stream start
            # -------------------------
            if event == "start":
                start = message.get("start", {}) or {}
                stream_id = start.get("streamSid") or "unknown_stream"
                # The id comes from the client and is used in file paths.
                call_id = _safe_path_component(
                    start.get("callSid") or stream_id, "unknown_call"
                )

                os.makedirs(os.path.join(AUDIO_ROOT, call_id), exist_ok=True)

                print(f"▶️ Stream started | call_id={call_id}")

                # Reset per-call state for a clean call
                audio_buffer = b""
                rolling_pcm_16k = b""
                ratecv_state = None
                buffer_start_time = time.time()
                chunk_index = 0
                last_full_text = ""
                call_transcript_parts = []
                last_transcribe_ts = 0.0
                has_untranscribed_audio = False

            # -------------------------
            # Audio frame
            # -------------------------
            elif event == "media":
                payload = message["media"]["payload"]

                # mu-law -> PCM @ 8 kHz
                ulaw_audio = base64.b64decode(payload)
                pcm_8k = audioop.ulaw2lin(ulaw_audio, BYTES_PER_SAMPLE)

                if buffer_start_time is None:
                    # Media arrived before any "start" event: initialise the
                    # timer and output folder instead of failing later.
                    buffer_start_time = time.time()
                    os.makedirs(os.path.join(AUDIO_ROOT, call_id), exist_ok=True)

                audio_buffer += pcm_8k

                # Keep collecting until ~1 second is available
                if len(audio_buffer) < BUFFER_SIZE:
                    continue

                chunk_index += 1
                # Wall-clock time spent collecting this chunk
                duration = time.time() - buffer_start_time

                # Resample 8k -> 16k, continuing from the previous chunk
                pcm_16k, ratecv_state = audioop.ratecv(
                    audio_buffer,
                    BYTES_PER_SAMPLE,
                    1,
                    SAMPLE_RATE_IN,
                    SAMPLE_RATE_OUT,
                    ratecv_state,
                )

                # ---- Save WAV chunk (dataset) ----
                ts = time.strftime("%Y%m%d_%H%M%S")
                rel_path = f"{call_id}/chunk_{chunk_index:04d}_{ts}.wav"
                wav_path = os.path.join(AUDIO_ROOT, rel_path)

                save_wav_16k_mono(pcm_16k, wav_path)

                append_metadata_row(
                    os.path.join(META_ROOT, f"{call_id}.csv"),
                    {
                        "call_id": call_id,
                        "chunk_index": chunk_index,
                        "timestamp": ts,
                        "wav_path": os.path.join("audio", rel_path),
                        "duration_sec": round(duration, 2),
                    },
                )

                print(f"💾 Saved chunk {chunk_index:04d}")

                # ---- Rolling window for live STT ----
                rolling_pcm_16k += pcm_16k
                if len(rolling_pcm_16k) > ROLLING_MAX_BYTES:
                    rolling_pcm_16k = rolling_pcm_16k[-ROLLING_MAX_BYTES:]
                has_untranscribed_audio = True

                # ---- Periodic Whisper transcription (async) ----
                # Skip while a pass is still running so work cannot pile up
                # when Whisper is slower than real time; the next chunk
                # retries because the timestamp is left untouched.
                now = time.time()
                if (
                    now - last_transcribe_ts >= TRANSCRIBE_EVERY_SEC
                    and not pending_tasks
                ):
                    last_transcribe_ts = now
                    has_untranscribed_audio = False
                    task = asyncio.create_task(run_transcribe(rolling_pcm_16k))
                    pending_tasks.add(task)
                    task.add_done_callback(pending_tasks.discard)

                # Reset 1s chunk buffer
                audio_buffer = b""
                buffer_start_time = time.time()

            # -------------------------
            # Stream stop
            # -------------------------
            elif event == "stop":
                print(f"⏹️ Stream stopped | call_id={call_id}")
                break

    except Exception as e:
        # Includes client disconnects; finalisation below still runs.
        print("❌ WebSocket error:", e)

    # -------------------------
    # Finalise the call
    # -------------------------
    # Let in-flight transcriptions finish so their text is not lost.
    if pending_tasks:
        await asyncio.gather(*pending_tasks, return_exceptions=True)

    # Cover audio that arrived after the last scheduled transcription.
    if has_untranscribed_audio and rolling_pcm_16k:
        await run_transcribe(rolling_pcm_16k)

    # The final transcript is the accumulated deltas, not the rolling window.
    final_text = " ".join(call_transcript_parts).strip()

    if final_text:
        try:
            txt_path = _write_final_transcript(call_id, final_text)
            print(f"📄 Final transcript saved | {txt_path}")
        except OSError as exc:
            print("❌ Could not save transcript:", exc)
    else:
        print("⚠️ No transcript to save")


# =========================
# Helper functions
# =========================

def save_wav_16k_mono(pcm_audio: bytes, path: str):
    """Write raw 16-bit mono PCM to *path* as a WAV file at SAMPLE_RATE_OUT."""
    # Field order: (nchannels, sampwidth, framerate, nframes, comptype, compname).
    # nframes starts at 0, the writer's initial value; writeframes() fixes it up.
    wav_params = (1, 2, SAMPLE_RATE_OUT, 0, "NONE", "not compressed")
    with wave.open(path, "wb") as wav_file:
        wav_file.setparams(wav_params)
        wav_file.writeframes(pcm_audio)


def append_metadata_row(csv_path: str, row: dict):
    """Append one metadata row to the CSV at *csv_path*.

    The file is created if missing. A header row (the keys of *row*, in
    order) is written first whenever the file is empty, which covers both a
    brand-new file and a pre-existing zero-byte file.

    Args:
        csv_path: Destination CSV file.
        row: Mapping of column name to value for this chunk.
    """
    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(row))
        # Decide from the opened file itself rather than a separate
        # exists() check, which misses empty files and can race.
        if os.fstat(f.fileno()).st_size == 0:
            writer.writeheader()
        writer.writerow(row)


def incremental_diff(prev: str, new: str) -> str:
    """
    Return only the text in *new* that was not already in *prev*.

    Both inputs are stripped (None counts as empty). The result is the part
    of *new* after its longest common prefix with *prev*, with leading
    whitespace removed. If that prefix ends in the middle of a word of *new*
    (e.g. "bar" revised to "baz", or "hello" extended to "hellos"), the cut
    is moved back to the start of that word so the result never begins with
    a word fragment.

    This is what gets appended to the call-level transcript.
    """
    prev = (prev or "").strip()
    new = (new or "").strip()

    # Length of the longest common character prefix. This also covers the
    # simple case where `new` merely extends `prev`.
    cut = 0
    for old_ch, new_ch in zip(prev, new):
        if old_ch != new_ch:
            break
        cut += 1

    # If the cut splits a run of letters/digits in `new`, rewind to the
    # beginning of that word so the whole revised word is returned.
    if 0 < cut < len(new) and new[cut].isalnum():
        while cut > 0 and new[cut - 1].isalnum():
            cut -= 1

    return new[cut:].lstrip()


def whisper_transcribe_bytes(pcm_16k: bytes) -> str:
    """Run Whisper on raw PCM bytes and return the recognised text.

    Args:
        pcm_16k: 16-bit mono PCM sampled at 16 kHz, in native byte order
            (the format audioop produces). A dangling odd byte is ignored.

    Returns:
        The stripped transcript, or "" when there is no audio to transcribe.
    """
    # Whole 16-bit samples only; nothing to do for empty input.
    usable_bytes = len(pcm_16k) - (len(pcm_16k) % 2)
    if usable_bytes <= 0:
        return ""

    # Local import keeps this block self-contained; numpy is already a
    # dependency of Whisper itself.
    import numpy as np

    # Hand Whisper the waveform directly as float32 in [-1, 1), the same
    # representation its own audio loader produces. This avoids writing a
    # temporary WAV and decoding it again on every call, and does not rely
    # on reopening a still-open temp file by name (which fails on Windows).
    samples = np.frombuffer(pcm_16k, dtype=np.int16, count=usable_bytes // 2)
    audio = samples.astype(np.float32) / 32768.0

    result = whisper_model.transcribe(
        audio,
        language=WHISPER_LANGUAGE,
        task="transcribe",
        fp16=WHISPER_FP16,
        verbose=False,
        condition_on_previous_text=True,
    )
    return (result.get("text") or "").strip()
