Triton Normalization Tutorial

1. Forward pass

In the formula above, x denotes a tensor.
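To make the per-row computation concrete, here is a minimal pure-Python sketch of a layer normalization forward pass. It assumes the standard definition y = (x - E[x]) / sqrt(Var[x] + eps) * w + b, which the kernel arguments below (W, B, Mean, Rstd, eps) suggest but which this text does not state explicitly. The function names `layer_norm_row` and `layer_norm_forward` are hypothetical and are not part of Triton or PyTorch.

```python
import math


def layer_norm_row(x, w, b, eps=1e-5):
    """Normalize one row (assumed standard layer norm, biased variance).

    Returns the normalized row together with the row mean and 1/std,
    mirroring the Mean and Rstd buffers the fused kernel writes out.
    """
    n = len(x)
    mean = sum(x) / n
    # Biased variance (divide by n), as is usual for layer norm.
    var = sum((v - mean) ** 2 for v in x) / n
    rstd = 1.0 / math.sqrt(var + eps)
    # Normalize, then apply the per-column weight and bias.
    y = [(v - mean) * rstd * wi + bi for v, wi, bi in zip(x, w, b)]
    return y, mean, rstd


def layer_norm_forward(X, w, b, eps=1e-5):
    """Apply layer_norm_row to every row of a 2D list X."""
    Y, Mean, Rstd = [], [], []
    for row in X:
        y, mean, rstd = layer_norm_row(row, w, b, eps)
        Y.append(y)
        Mean.append(mean)
        Rstd.append(rstd)
    return Y, Mean, Rstd


# Example: one row, identity weight and zero bias.
Y, Mean, Rstd = layer_norm_forward([[1.0, 2.0, 3.0, 4.0]], [1.0] * 4, [0.0] * 4)
```

Each row is handled independently here, which is why a fused kernel can assign one program instance per row; the sketch trades all performance for readability.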

import torch

import triton
import triton.language as tl

try:
    # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it
    # should not be added to extras_require in setup.py.
    import apex
    HAS_APEX = True
except ModuleNotFoundError:
    HAS_APEX = False


@triton.jit
def _layer_norm_fwd_fused(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride,  # how much to increase the pointer when moving by 1 row
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.
