# pip install soundfile sounddevice numpy
import numpy as np
import soundfile as sf

try:
    import sounddevice as sd   # optional (for playback)
    HAVE_SD = True
except Exception:
    HAVE_SD = False

def rms_energy(x):
    """
    Return the total energy of `x`: the sum of its squared samples over all channels.

    Note: despite the name, no mean and no square root are taken -- this is
    plain energy, not an RMS value. Samples are promoted to float64 before
    squaring so integer PCM input cannot overflow.
    """
    as_f64 = x.astype(np.float64)
    return np.square(as_f64).sum()

def add_noise_at_snr(signal, snr_db, rng=None):
    """
    Add white Gaussian noise to `signal` at the requested global SNR.

    The SNR is measured over the whole array (every sample of every channel):
        SNR_dB = 10*log10(E_signal / E_noise),
    where E is the sum of squared samples as computed by `rms_energy`.

    Parameters
    ----------
    signal : array_like
        Clean samples of any shape, e.g. (N,) mono or (N, C) multichannel.
    snr_db : float
        Target SNR in dB.
    rng : None, int, or numpy.random.Generator, optional
        Randomness source, passed through `np.random.default_rng`:
        - a Generator is used as-is;
        - an int seeds a fresh Generator, which gives reproducible noise;
        - None seeds a fresh Generator from OS entropy.

    Returns
    -------
    numpy.ndarray
        `signal + a * noise`, with the same shape as `signal` (floating point).
        An all-zero signal has no energy to reference, so `a` is 0 and the
        samples come back unchanged (as floats). Empty input yields an empty
        array.
    """
    signal = np.asarray(signal)
    rng = np.random.default_rng(rng)
    noise = rng.standard_normal(size=signal.shape)

    if signal.size == 0:
        # Both arrays are empty: nothing to scale. This skips the 0/0 in the
        # scale factor below, which would only emit RuntimeWarnings and nan.
        return signal + noise

    Es = rms_energy(signal)
    En = rms_energy(noise)

    # Noise amplitude chosen so that Es / (a**2 * En) == 10**(snr_db / 10).
    a = np.sqrt(Es / En) * 10 ** (-snr_db / 20.0)

    return signal + a * noise

def normalize_for_wav(x, margin=1.1):
    """
    Peak-normalize `x` so that its largest magnitude becomes 1/margin.

    With the default margin of 1.1, the peak lands at about 0.909 of full
    scale. This leaves headroom before conversion to 16-bit PCM.

    Parameters
    ----------
    x : array_like
        Samples of any shape; the peak is taken over all of them.
    margin : float, default 1.1
        Extra divisor applied on top of the peak; must be > 0. Values below 1
        push the peak above full scale.

    Returns
    -------
    numpy.ndarray
        The scaled samples. Empty or all-zero input is returned unscaled,
        since there is no peak to normalize against.

    Raises
    ------
    ValueError
        If `margin` is not positive. A zero margin would divide by zero and a
        negative one would invert the signal.
    """
    if margin <= 0:
        raise ValueError(f"margin must be positive, got {margin!r}")
    x = np.asarray(x)
    if x.size == 0:
        # np.max raises on an empty array; there is nothing to scale anyway.
        return x
    peak = np.max(np.abs(x))
    if peak == 0:
        return x
    return x / (margin * peak)

# --- Parameters ---
in_wav  = "furelise-1000z.wav"
out_wav = "furelise-1000z-noise.wav"
SNR_DB  = 10    # target global SNR in dB
SEED    = None  # set to an int for a reproducible noise realization

# --- Read (soundfile returns float64 by default; integer PCM is scaled into [-1, 1]) ---
s, fs = sf.read(in_wav, always_2d=False)  # shape: (N,) or (N, C)

# --- Add white Gaussian noise at target SNR ---
sn = add_noise_at_snr(s, SNR_DB, rng=np.random.default_rng(SEED))

# --- Normalize and save as 16-bit PCM ---
# Save before any (optional) playback, so that an audio-device problem
# cannot stop the output file from being written.
sn_norm = normalize_for_wav(sn, margin=1.1)
sf.write(out_wav, sn_norm, fs, subtype="PCM_16")

print(f"Saved noisy file to: {out_wav}")

# --- (Optional) Listen: clean first, then noisy ---
if HAVE_SD:
    print("Playing clean audio...")
    sd.play(s, fs)
    sd.wait()

    # Play the normalized signal (exactly what was saved). The raw sum `sn`
    # can exceed [-1, 1], and out-of-range float samples typically clip on
    # the audio device. You would then hear distortion that is not in the file.
    print("Playing noisy audio...")
    sd.play(sn_norm, fs)
    sd.wait()

# --- Next steps:
# 1) Spectral subtraction
# 2) Wiener filtering
# You can process `sn_norm` with those methods and save additional outputs.
