import torch
import matplotlib.pyplot as plt
# --------------------------
# 1. Device selection: use the GPU when available. cuFFT is only involved when
#    the tensors live on a CUDA device.
# --------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Only claim cuFFT when tensors will actually be placed on a CUDA device;
# on the CPU fallback the FFTs do not go through cuFFT.
fft_backend = "cuFFT 已启用" if device.type == "cuda" else "CPU FFT, cuFFT 未启用"
print(f"使用设备: {device} ({fft_backend})")
# --------------------------
# 2. Generate the test signal (on `device`)
# --------------------------
SAMPLING_RATE = 1000      # samples per second (Hz)
DURATION = 1              # seconds
FREQ1 = 50                # Hz, generated with amplitude 1.0 below
FREQ2 = 120               # Hz, generated with amplitude 0.5 below
NOISE_AMPLITUDE = 0.5     # scale applied to standard-normal noise
N_SAMPLES = int(SAMPLING_RATE * DURATION)
# Sample instants are k / SAMPLING_RATE for k = 0 .. N_SAMPLES - 1.
# torch.linspace(0, DURATION, N_SAMPLES) would include the endpoint, giving a
# spacing of DURATION / (N_SAMPLES - 1) instead of 1 / SAMPLING_RATE, which
# would shift every peak relative to the fftfreq axis built from SAMPLING_RATE.
t = torch.arange(N_SAMPLES, device=device, dtype=torch.float32) / SAMPLING_RATE
# 50 Hz + 120 Hz sinusoids plus Gaussian noise
signal = (
    1.0 * torch.sin(2 * torch.pi * FREQ1 * t)
    + 0.5 * torch.sin(2 * torch.pi * FREQ2 * t)
    + NOISE_AMPLITUDE * torch.randn_like(t)
)
# --------------------------
# 3. Apply a Hann window to reduce spectral leakage
# --------------------------
# Periodic Hann window (torch's default form), one coefficient per sample,
# created on the same device as the signal.
window = torch.hann_window(signal.numel(), periodic=True, device=device)
signal_windowed = window * signal
# --------------------------
# 4. FFTs of the raw and windowed signals (for CUDA tensors PyTorch runs
#    these through cuFFT)
# --------------------------
n_fft = len(t)
fft_raw, fft_windowed = (torch.fft.fft(x) for x in (signal, signal_windowed))
# Frequency axis: the first n_fft // 2 entries of fftfreq are the
# non-negative frequencies 0, fs/n, 2*fs/n, ...
sample_spacing = 1 / SAMPLING_RATE
freq = torch.fft.fftfreq(n_fft, d=sample_spacing, device=device)[: n_fft // 2]
# 计算幅度谱(dB)
def spectrum(fft_data):
mag = torch.abs(fft_data)[:len(fft_data)//2] / len(fft_data)
mag = 2 * mag
return 20 * torch.log10(mag + 1e-10)
# Raw spectrum: no window was applied, so no amplitude correction is needed.
spec_raw = spectrum(fft_raw)
# The Hann window has coherent gain sum(w) / N (about 0.5), which would pull
# every tone down by roughly 6 dB. Dividing it out keeps the windowed peak
# heights on the same dB scale as the raw spectrum.
coherent_gain = window.sum() / len(window)
spec_windowed = spectrum(fft_windowed / coherent_gain)
# --------------------------
# 4.1 Keep the strongest frequency components that together hold
#     ENERGY_KEEP_RATIO of the total energy, then iFFT back to the time domain
# --------------------------
ENERGY_KEEP_RATIO = 0.95
rfft_raw = torch.fft.rfft(signal)
# rfft returns a one-sided spectrum: each bin except DC (and Nyquist, when the
# signal length is even) stands for a +f / -f pair. By Parseval it carries
# twice its |X|^2 of energy, so weight it accordingly before ranking.
energy = torch.abs(rfft_raw) ** 2
pair_weight = torch.full_like(energy, 2.0)
pair_weight[0] = 1.0
if len(signal) % 2 == 0:
    pair_weight[-1] = 1.0
energy = energy * pair_weight
total_energy = torch.sum(energy)
# Rank bins from strongest to weakest and find the smallest prefix whose
# cumulative energy reaches the requested fraction of the total.
sorted_energy, sorted_idx = torch.sort(energy, descending=True)
cumulative_energy = torch.cumsum(sorted_energy, dim=0)
keep_count = int(torch.searchsorted(cumulative_energy, ENERGY_KEEP_RATIO * total_energy).item()) + 1
# Float rounding in the cumulative sum can push the index past the end when
# the ratio is close to 1; never keep more bins than exist.
keep_count = min(keep_count, len(sorted_idx))
mask = torch.zeros_like(rfft_raw, dtype=torch.bool)
mask[sorted_idx[:keep_count]] = True
# Zero every discarded bin and invert back to a real signal of the original length.
rfft_filtered = torch.where(mask, rfft_raw, torch.zeros_like(rfft_raw))
signal_ifft = torch.fft.irfft(rfft_filtered, n=len(signal))
# --------------------------
# 5. Move every tensor needed for plotting to the CPU as a NumPy array
# --------------------------
(
    t_cpu,
    sig_cpu,
    sig_win_cpu,
    sig_ifft_cpu,
    freq_cpu,
    spec_raw_cpu,
    spec_win_cpu,
) = (
    tensor.cpu().numpy()
    for tensor in (t, signal, signal_windowed, signal_ifft, freq, spec_raw, spec_windowed)
)
# --------------------------
# Plotting: raw signal, windowed signal, spectra, and iFFT reconstruction
# --------------------------
plt.figure(figsize=(12,10))
plt.subplot(411)
plt.plot(t_cpu, sig_cpu, 'gray', label='raw signal')
plt.title('Time Domain Signal (cuFFT)')
plt.legend()
plt.grid()
plt.subplot(412)
plt.plot(t_cpu, sig_win_cpu, 'blue', label='hanning windowed signal')
plt.legend()
plt.grid()
plt.subplot(413)
plt.plot(freq_cpu, spec_raw_cpu, 'r', alpha=0.6, label='raw cuFFT')
plt.plot(freq_cpu, spec_win_cpu, 'g', linewidth=2, label='windowed cuFFT')
# Reference markers at the generated tone frequencies: amplitude 1.0 -> 0 dB,
# amplitude 0.5 -> about -6 dB (20*log10(0.5)). Using FREQ1/FREQ2 keeps the
# markers correct if the test-signal frequencies change.
plt.scatter([FREQ1, FREQ2], [0, -6], c='black', label='true frequencies')
# Show 0-200 Hz by default, widening the range if a tone would fall outside it.
plt.xlim(0, max(200, 1.5 * max(FREQ1, FREQ2)))
plt.title('Frequency Domain Spectrum (cuFFT)')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Amplitude (dB)')
plt.legend()
plt.grid()
plt.subplot(414)
plt.plot(t_cpu, sig_cpu, 'gray', alpha=0.6, label='raw signal')
plt.plot(t_cpu, sig_ifft_cpu, 'm', linewidth=1.5, label=f'iFFT reconstruction ({int(ENERGY_KEEP_RATIO*100)}% energy kept)')
plt.title('Time Domain Overlay: Raw vs Energy-Threshold iFFT')
plt.xlabel('Time (s)')
plt.legend()
plt.grid()
plt.tight_layout()
plt.show()
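# (Leftover web-page footer from the original post, kept as a comment so the
#  script stays valid Python: "发表回复" = "Leave a reply";
#  "要发表评论,您必须先登录。" = "You must be logged in to post a comment.")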