import warnings

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# ----------------------------
# Parameters (natural units: hbar = m = 1)
# ----------------------------
L = 1.0                      # well width; the grid runs from x = 0 to x = L
hbar, m = 1.0, 1.0

# Initial wave packet
x0 = L / 3                   # centre of the Gaussian
sigma = L / 20               # std. dev. of |psi0|^2 (psi0 ~ exp(-(x-x0)^2 / (4 sigma^2)))
k0 = 80.0                    # carrier wave number of the packet

# Resolution
N = 250                      # number of energy eigenstates kept in the expansion
Nx = 1200                    # number of spatial grid points

x = np.linspace(0, L, Nx)
dx = np.diff(x[:2])[0]       # uniform grid spacing (x[1] - x[0])

# ----------------------------
# Infinite-well energy eigenbasis
#   phi_n(x) = sqrt(2/L) sin(n pi x / L),   n = 1, 2, ..., N
#   E_n      = n^2 pi^2 hbar^2 / (2 m L^2)
# ----------------------------
n = np.arange(N) + 1          # quantum numbers 1..N

E = (n**2 * np.pi**2 * hbar**2) / (2 * m * L**2)

# Row i of phi holds phi_{n[i]} sampled on the grid x, giving shape (N, Nx).
phi = np.sqrt(2 / L) * np.sin(n[:, np.newaxis] * (np.pi * x / L)[np.newaxis, :])

# ----------------------------
# Initial Gaussian packet
# ----------------------------
# np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; prefer the
# new name when it exists so the script runs cleanly on old and new NumPy.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz

psi0 = np.exp(-(x - x0)**2 / (4 * sigma**2)) * np.exp(1j * k0 * x)

# Enforce zero at the walls. The sine basis already vanishes there; this only
# removes whatever Gaussian tail reaches x = 0 and x = L.
psi0[0] = 0
psi0[-1] = 0

# Normalize so that the grid integral of |psi0|^2 equals 1.
psi0 /= np.sqrt(_trapezoid(np.abs(psi0)**2, x))

# ----------------------------
# Expansion coefficients
#   c_n = integral phi_n^*(x) psi0(x) dx
# phi is real, so no conjugate is needed. All N rows are integrated at once
# along the spatial axis instead of looping over n in Python.
# ----------------------------
c = _trapezoid(phi * psi0, x, axis=1)

# The truncated basis (N states on an Nx-point grid) only represents psi0
# faithfully if it captures essentially all of its norm. If sum |c_n|^2
# drifts from 1, the animation is silently wrong, so say so.
captured = np.sum(np.abs(c)**2)
if not np.isclose(captured, 1.0, atol=1e-3):
    warnings.warn(
        f"Eigenbasis truncation: sum |c_n|^2 = {captured:.6f} (expected 1). "
        "Increase N or Nx, or reduce k0.",
        stacklevel=1,
    )

# ----------------------------
# Time evolution
# psi(x,t) = sum_n c_n phi_n(x) exp(-i E_n t / hbar)
# ----------------------------
def psi_t(t):
    """Return psi(x, t) on the grid by summing the truncated eigen-expansion.

    Parameters
    ----------
    t : float
        Time at which to evaluate the wavefunction.

    Returns
    -------
    numpy.ndarray
        Complex array with one value per grid point (one per column of ``phi``).
    """
    # Fold the time-dependent phase into the coefficients first, then contract
    # over n with one matrix-vector product. This avoids allocating two
    # (N, Nx) complex temporaries on every animation frame.
    weights = c * np.exp(-1j * E * t / hbar)
    return weights @ phi

# ----------------------------
# Animation
# ----------------------------
fig, ax = plt.subplots(figsize=(8, 4))

line, = ax.plot([], [], lw=2)

# Classical bounce period: the packet's group velocity is v = hbar*k0/m, so a
# full round trip (L there and L back) takes 2L/v = 2mL/(hbar*k0).
T_cl = 2 * m * L / (hbar * k0)
t_max = 4 * T_cl
frames = 300
times = np.linspace(0, t_max, frames)

# Size the y-axis from the densities that will actually be drawn. When the
# packet reflects off a wall, the incident and reflected parts interfere and
# |psi|^2 can exceed the initial peak, so scaling from psi0 alone clips those
# frames. This costs one extra evaluation of psi_t per frame at startup.
peak_density = max(np.max(np.abs(psi_t(t))**2) for t in times)

ax.set_xlim(0, L)
ax.set_ylim(0, 1.3 * peak_density)
ax.set_xlabel("x")
ax.set_ylabel(r"$|\psi(x,t)|^2$")
ax.set_title("Gaussian wave packet in an infinite square well")

time_text = ax.text(0.02, 0.9, "", transform=ax.transAxes)

def init():
    """Reset the animated artists to an empty state.

    Passed to FuncAnimation as ``init_func``. Because the animation blits,
    the artists that change every frame are returned so Matplotlib knows
    which ones to redraw.
    """
    line.set_data([], [])
    time_text.set_text("")
    artists = (line, time_text)
    return artists

def update(frame):
    """Draw |psi(x, t)|^2 for the animation frame with index ``frame``.

    ``frame`` indexes the ``times`` array. The density curve and the time
    label are refreshed, and both artists are returned for blitting.
    """
    now = times[frame]
    prob = np.abs(psi_t(now)) ** 2

    line.set_data(x, prob)
    time_text.set_text(f"t = {now:.4f}")
    return (line, time_text)

# Keep a reference to the Animation object: it owns the timer that drives
# the frames, and if nothing references it, it is garbage-collected and the
# animation stops.
ani = FuncAnimation(
    fig,
    update,
    init_func=init,
    frames=frames,
    blit=True,       # redraw only the artists returned by init/update
    interval=30,     # delay between frames in milliseconds
)

# To export the animation, uncomment one of the lines below. Call save before
# plt.show(): once the interactive window is closed, the figure may already
# have been torn down.
# ani.save("gaussian_packet_in_box.mp4", writer="ffmpeg", fps=30)
# ani.save("gaussian_packet_in_box.gif", writer="pillow", fps=30)

plt.show()
