暮晨集雨

Mu Chen's Online Collection

谓暮晨集雨者，牧宸集语也。一个由 AI 加持的个人博客站。

FFT 加窗demo

import torch
import matplotlib.pyplot as plt

# --------------------------
# 1. Device selection (GPU when CUDA is available)
# --------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch.fft only goes through cuFFT for CUDA tensors, so only claim it then.
fft_backend = "cuFFT 已启用" if device.type == "cuda" else "cuFFT 未启用"
print(f"使用设备: {device} ({fft_backend})")

# --------------------------
# 2. Test signal (generated directly on `device`)
# --------------------------
SAMPLING_RATE = 1000  # samples per second
DURATION = 1  # seconds
FREQ1 = 50  # Hz
FREQ2 = 120  # Hz
NOISE_AMPLITUDE = 0.5  # scale applied to unit-variance Gaussian noise

NUM_SAMPLES = int(SAMPLING_RATE * DURATION)
# Sample instants k / fs for k = 0 .. N-1. torch.linspace(0, DURATION, N) would
# include the end point and space samples DURATION / (N - 1) apart, which does
# not match the 1 / SAMPLING_RATE spacing assumed by fftfreq below.
# Note: the FFT bin spacing is 1 / DURATION (1 Hz here), so integer tone
# frequencies land exactly on bins and the unwindowed spectrum shows little
# leakage; a non-integer frequency (e.g. 50.5 Hz) makes leakage more visible.
t = torch.arange(NUM_SAMPLES, device=device) / SAMPLING_RATE
# 50 Hz (amplitude 1.0) + 120 Hz (amplitude 0.5) sines plus noise
signal = (
    1.0 * torch.sin(2 * torch.pi * FREQ1 * t)
    + 0.5 * torch.sin(2 * torch.pi * FREQ2 * t)
    + NOISE_AMPLITUDE * torch.randn_like(t)
)

# --------------------------
# 3. Apply a Hann window (suppresses spectral leakage)
# --------------------------
window = torch.hann_window(len(signal), device=device)
signal_windowed = signal * window

# --------------------------
# 4. FFT (executed by cuFFT when the tensors live on a CUDA device)
# --------------------------
fft_raw = torch.fft.fft(signal)
fft_windowed = torch.fft.fft(signal_windowed)

# Frequency axis for the first N // 2 bins (DC up to, but excluding, Nyquist)
freq = torch.fft.fftfreq(len(t), 1 / SAMPLING_RATE, device=device)[: len(t) // 2]

# 计算幅度谱(dB)
def spectrum(fft_data, window=None):
    """Return the single-sided amplitude spectrum of ``fft_data`` in dB.

    Args:
        fft_data: Complex tensor holding a full (two-sided) FFT along its last
            dimension, e.g. the output of ``torch.fft.fft``. Leading
            dimensions are treated as a batch.
        window: Optional window that was multiplied into the signal before
            the FFT. When given, magnitudes are normalised by ``window.sum()``
            (the coherent gain) instead of the FFT length, so a windowed
            sinusoid of amplitude A still peaks near 20*log10(A) dB.

    Returns:
        Real tensor with the first ``n // 2`` bins of the last dimension (DC
        up to, but excluding, Nyquist), in dB relative to amplitude 1. The
        slicing matches ``torch.fft.fftfreq(n, d)[:n // 2]``.
    """
    n = fft_data.shape[-1]
    scale = n if window is None else window.sum()
    mag = torch.abs(fft_data[..., : n // 2]) / scale
    # Every bin except DC has a mirrored negative-frequency twin holding the
    # other half of the sinusoid's amplitude, so fold it in by doubling.
    # DC has no twin and must not be doubled.
    mag[..., 1:] *= 2
    # The small epsilon keeps log10 finite for empty bins.
    return 20 * torch.log10(mag + 1e-10)

spec_raw = spectrum(fft_raw)
# A window attenuates a sinusoid's FFT peak by its coherent gain (the window's
# mean, about 0.5 = -6 dB for Hann). Remove that offset so the windowed curve
# reports the same tone amplitudes as the raw one.
spec_windowed = spectrum(fft_windowed) - 20 * torch.log10(window.mean())

# --------------------------
# 4.1 Keep the strongest frequency components that together hold
#     ENERGY_KEEP_RATIO of the total energy, then iFFT back to the time domain
# --------------------------
ENERGY_KEEP_RATIO = 0.95

rfft_raw = torch.fft.rfft(signal)
# rfft returns only the non-negative frequencies. Each bin other than DC (and
# the Nyquist bin when the length is even) also stands for its mirrored
# negative-frequency bin, so by Parseval it carries 2 * |X|^2 of the energy.
bin_weight = torch.full(rfft_raw.shape, 2.0, device=rfft_raw.device)
bin_weight[0] = 1.0
if len(signal) % 2 == 0:
    bin_weight[-1] = 1.0
energy = bin_weight * torch.abs(rfft_raw) ** 2
total_energy = torch.sum(energy)

sorted_energy, sorted_idx = torch.sort(energy, descending=True)
cumulative_energy = torch.cumsum(sorted_energy, dim=0)
# searchsorted gives the first position whose running total reaches the target.
# +1 turns that index into a bin count. min() guards against float round-off
# pushing the target past the final cumulative value.
keep_count = min(
    int(torch.searchsorted(cumulative_energy, ENERGY_KEEP_RATIO * total_energy).item()) + 1,
    len(energy),
)

# Zero every bin outside the selected set, then invert back to the original length.
mask = torch.zeros_like(rfft_raw, dtype=torch.bool)
mask[sorted_idx[:keep_count]] = True
rfft_filtered = torch.where(mask, rfft_raw, torch.zeros_like(rfft_raw))
signal_ifft = torch.fft.irfft(rfft_filtered, n=len(signal))

# --------------------------
# 5. Move tensors to the CPU (as NumPy arrays) for plotting
# --------------------------
t_cpu = t.cpu().numpy()
sig_cpu = signal.cpu().numpy()
sig_win_cpu = signal_windowed.cpu().numpy()
sig_ifft_cpu = signal_ifft.cpu().numpy()
freq_cpu = freq.cpu().numpy()
spec_raw_cpu = spec_raw.cpu().numpy()
spec_win_cpu = spec_windowed.cpu().numpy()

# --------------------------
# Plotting
# --------------------------
plt.figure(figsize=(12, 10))
plt.subplot(411)
plt.plot(t_cpu, sig_cpu, 'gray', label='raw signal')
plt.title('Time Domain Signal (cuFFT)')
plt.legend()
plt.grid()

plt.subplot(412)
plt.plot(t_cpu, sig_win_cpu, 'blue', label='hanning windowed signal')
plt.legend()
plt.grid()

plt.subplot(413)
plt.plot(freq_cpu, spec_raw_cpu, 'r', alpha=0.6, label='raw cuFFT')
plt.plot(freq_cpu, spec_win_cpu, 'g', linewidth=2, label='windowed cuFFT')
# Reference markers at the two tone frequencies and their amplitudes in dB:
# 1.0 -> 0 dB and 0.5 -> 20*log10(0.5) ~= -6 dB.
plt.scatter([FREQ1, FREQ2], [0, -6], c='black', label='true frequencies')
plt.xlim(0, 200)
plt.title('Frequency Domain Spectrum (cuFFT)')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Amplitude (dB)')
plt.legend()
plt.grid()

plt.subplot(414)
plt.plot(t_cpu, sig_cpu, 'gray', alpha=0.6, label='raw signal')
# ':.0%' rounds to the nearest percent; int(ratio * 100) would truncate
# values such as 0.29 (28.999...) down to 28.
plt.plot(
    t_cpu,
    sig_ifft_cpu,
    'm',
    linewidth=1.5,
    label=f'iFFT reconstruction ({ENERGY_KEEP_RATIO:.0%} energy kept)',
)
plt.title('Time Domain Overlay: Raw vs Energy-Threshold iFFT')
plt.xlabel('Time (s)')
plt.legend()
plt.grid()

plt.tight_layout()
plt.show()


评论

发表回复