代码解读 | Hybrid Transformers for Music Source Separation[04]

一、背景

        0、Hybrid Transformer 论文解读

        1、代码复现|Demucs Music Source Separation_demucs架构原理-CSDN博客

        2、Hybrid Transformer 各个模块对应的代码具体在工程的哪个地方

        3、Hybrid Transformer 各个模块的底层到底是个啥(初步感受)?

        4、Hybrid Transformer 各个模块处理后,数据的维度大小是咋变换的?


        从模块上划分,Hybrid Transformer Demucs 共包含 (STFT模块、时域编码模块、频域编码模块、Cross-Domain Transformer Encoder模块、时域解码模块、频域解码模块、ISTFT模块)7个模块。

        本篇目标:拆解STFT模块的底层。

二、拆解STFT模块底层

2.1 torch.stft

import torch as th


def spectro(x, n_fft=512, hop_length=None, pad=0):
    *other, length = x.shape
    x = x.reshape(-1, length)
    is_mps = x.device.type == 'mps'
    if is_mps:
        x = x.cpu()
    z = th.stft(x,
                n_fft * (1 + pad),
                hop_length or n_fft // 4,
                window=th.hann_window(n_fft).to(x),
                win_length=n_fft,
                normalized=True,
                center=True,
                return_complex=True,
                pad_mode='reflect')
    _, freqs, frame = z.shape
    return z.view(*other, freqs, frame)

        核心代码,长上面这样。

        简单说一下为啥使用短时傅里叶变换(STFT),而不直接使用傅里叶变换(FT)。原因:傅立叶变换只能告诉我们信号当中有哪些频率成分。当我们还想知道各个成分出现的时间的时候,就得用到STFT了(这也就是时频分析。所谓时频分析,就是既要考虑到频率特征,又要考虑到时间序列变化)。

        上述公式就是torch.stft的底层公式,一句话总结:首先窗函数×时域信号,然后进行傅里叶变换其中,\omega表示频率,m表示滑动窗口的下标,input是一个时间序列,hop_length表示窗移大小,win_length表示窗长,window表示窗函数。


        具体的,torch.stft函数中各个参数的意义如下所示。

参数名称 说明
input (Tensor):the input tensor 输入
n_fft (int): size of Fourier transform 傅里叶变换大小(决定频率分辨率)
hop_length (int, optional): the distance between neighboring sliding window frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``) 窗移,默认大小floor(n_fft / 4)
win_length (int, optional): the size of window frame and STFT filter. Default: ``None`` (treated as equal to :attr:`n_fft`) 窗长,默认大小n_fft
window (Tensor, optional): the optional window function. Default: ``None`` (treated as window of all :math:`1` s) 窗函数
center (bool, optional): whether to pad :attr:`input` on both sides so that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. Default: ``True``

是否对input两侧进行填充,

以至于在t帧的是居中的

pad_mode (string, optional): controls the padding method used when :attr:`center` is ``True``. Default: ``"reflect"`` 填充模式
normalized (bool, optional): controls whether to return the normalized STFT results Default: ``False`` 是否归一化
onesided (bool, optional): controls whether to return half of results to avoid redundancy for real inputs. Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise. 控制是否返回一半结果
return_complex (bool, optional): whether to return a complex tensor, or a real tensor with an extra last dimension for the real and imaginary components. 返回值是否设置为复数
  • n_fft 关注的是频率分辨率,即能够分辨的最小频率间隔。n_fft 越大,频率分辨率越高,但计算量也越大。
  • win_length 关注的是时间分辨率,即能够分辨的最小时间间隔。win_length 越大,时间分辨率越低,但可以更好地捕捉到低频信号的特征。

2.2 STFT整个模块干了啥

        上图是htdemucs调用STFT模块的入口。

       1、为了保持输出大小=输入大小/hop_length,先对输入信息进行填充(使用pad1d函数),然后进行STFT变换(核心代码见2.1)。

        2、拿到STFT结果后,进入_magnitude函数。当cac为True的时候,_magnitude函数把复数维度移动到通道维度。当cac为False的时候,_magnitude函数计算出幅度值。

        done,STFT模块讲解完成。


        感谢阅读,最近开始写公众号(分享好用的AI工具),欢迎大家一起见证我的成长(桂圆学AI)

相关推荐

  1. ChatPDF代码解读2

    2024-06-12 04:06:02       10 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-06-12 04:06:02       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-06-12 04:06:02       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-06-12 04:06:02       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-06-12 04:06:02       20 阅读

热门阅读

  1. AIGC涉及到的算法(一)

    2024-06-12 04:06:02       7 阅读
  2. 集线器(HUB)简介

    2024-06-12 04:06:02       10 阅读
  3. dp类总结

    2024-06-12 04:06:02       9 阅读
  4. Spring是什么??IOC又是什么??

    2024-06-12 04:06:02       9 阅读
  5. 学习PLC+LabVIEW

    2024-06-12 04:06:02       8 阅读
  6. 【VUE3】自定义防抖指令

    2024-06-12 04:06:02       8 阅读
  7. controller_manager卡在loading_controller

    2024-06-12 04:06:02       9 阅读
  8. 中继器简介

    2024-06-12 04:06:02       9 阅读