JAX: 快如 PyTorch,简单如 NumPy - 深度学习与数据科学

alt

JAX 是 TensorFlow 和 PyTorch 的新竞争对手。 JAX 强调简单性而不牺牲速度和可扩展性。由于 JAX 需要更少的样板代码,因此程序更短、更接近数学,因此更容易理解。

长话短说:

  • 使用 import jax.numpy 访问 NumPy 函数,使用 import jax.scipy 访问 SciPy 函数。
  • 通过使用 @jax.jit 进行装饰,可以加快即时编译速度。
  • 使用 jax.grad 求导。
  • 使用 jax.vmap 进行矢量化,并使用 jax.pmap 进行跨设备并行化。

函数式编程

alt

JAX 遵循函数式编程哲学。这意味着您的函数必须是独立的或纯粹的:不允许有副作用。本质上,纯函数看起来像数学函数(图 1)。有输入进来,有东西出来,但与外界没有沟通。

例子#1

以下代码片段是一个非功能纯的示例。

import jax.numpy as jnp

bias = jnp.array(0)
def impure_example(x):
   total = x + bias
   return total

注意 impure_example 之外的偏差。在编译期间(见下文),偏差可能会被缓存,因此不再反映偏差的变化。

例子#2

这是一个pure的例子。

def pure_example(x, weights, bias):
   activation = weights @ x + bias
   return activation

在这里,pure_example 是独立的:所有参数都作为参数传递。

确定性采样器

alt

在计算机中,不存在真正的随机性。相反,NumPy 和 TensorFlow 等库会跟踪伪随机数状态来生成“随机”样本。

函数式编程的直接后果是随机函数的工作方式不同。由于不再允许全局状态,因此每次采样随机数时都需要显式传入伪随机数生成器 (PRNG) 密钥

import jax

key = jax.random.PRNGKey(42)
u = jax.random.uniform(key)

此外,您有责任为任何后续调用推进“随机状态”。

key = jax.random.PRNGKey(43)

# Split off and consume subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)

# Split off and consume second subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)

..

jit

您可以通过即时编译 JAX 指令来加快代码速度。例如,要编译缩放指数线性单位 (SELU) 函数,请使用 jax.numpy 中的 NumPy 函数并将 jax.jit 装饰器添加到该函数,如下所示:

from jax import jit

@jit
def selu(x, α=1.67, λ=1.05):
 return λ * jnp.where(x > 0, x, α * jnp.exp(x) - α)

JAX 会跟踪您的指令并将其转换为 jaxpr。这使得加速线性代数 (XLA) 编译器能够为您的加速器生成非常高效的优化代码。

gard

JAX 最强大的功能之一是您可以轻松获取 gard。使用 jax.grad,您可以定义一个新函数,即符号导数。

from jax import grad

def f(x):
   return x + 0.5 * x**2

df_dx = grad(f)
d2f_dx2 = grad(grad(f))

正如您在示例中看到的,您不仅限于一阶导数。您可以通过简单地按顺序链接 grad 函数 n 次来获取 n 阶导数。

vmap 和 pmap

矩阵乘法使所有批次尺寸正确需要非常细心。 JAX 的矢量化映射函数 vmap 通过对函数进行矢量化来减轻这种负担。基本上,每个按元素应用函数 f 的代码块都是由 vmap 替换的候选者。让我们看一个例子。

计算线性函数:

def linear(x):
 return weights @ x

在一批示例 [x₁, x2,..] 中,我们可以天真地(没有 vmap)实现它,如下所示:

def naively_batched_linear(X_batched):
 return jnp.stack([linear(x) for x in X_batched])

相反,通过使用 vmap 对线性进行向量化,我们可以一次性计算整个批次:

def vmap_batched_linear(X_batched):
 return vmap(linear)(X_batched)

本文由 mdnice 多平台发布

相关推荐

  1. Python数据处理和常用库(NumPy、Pandas)

    2023-12-29 09:50:04       17 阅读
  2. python基本数据注释)

    2023-12-29 09:50:04       17 阅读
  3. 深度学习 - PyTorch简介

    2023-12-29 09:50:04       6 阅读
  4. pytorch深度学习

    2023-12-29 09:50:04       14 阅读
  5. pytorch深度学习

    2023-12-29 09:50:04       10 阅读
  6. Pytorch深度学习

    2023-12-29 09:50:04       14 阅读

最近更新

  1. TCP协议是安全的吗?

    2023-12-29 09:50:04       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2023-12-29 09:50:04       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2023-12-29 09:50:04       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2023-12-29 09:50:04       20 阅读

热门阅读

  1. catboost回归自动调参

    2023-12-29 09:50:04       27 阅读
  2. 7天玩转 Golang 标准库之 sort

    2023-12-29 09:50:04       37 阅读
  3. 多线程多进程的使用场景和常见问题处理

    2023-12-29 09:50:04       40 阅读
  4. MySQL数据库索引

    2023-12-29 09:50:04       37 阅读
  5. Presentation Error:编程中的细节之战

    2023-12-29 09:50:04       31 阅读
  6. 获取请求的真实ip

    2023-12-29 09:50:04       36 阅读
  7. opencv c++圆检测

    2023-12-29 09:50:04       38 阅读
  8. Docker Compose容器编排实战

    2023-12-29 09:50:04       32 阅读
  9. PHP:服务器端脚本语言的瑰宝

    2023-12-29 09:50:04       30 阅读