打开网易新闻 查看更多图片

JAX

自2018 年底谷歌的 JAX出现以来,它的受欢迎程度一直在稳步增长。DeepMind 202年宣布使用 JAX 来加速自己的相关研究,越来越多来自Google 大脑与其他项目也在使用 JAX。随着JAX越来越火,似乎 JAX 是下一个大型深度学习框架?让人不禁联想是否Google会使用JAX来代替TensorFlow的江湖地位。

什么是JAX

JAX 是Autograd和XLA的结合,JAX 本身不是一个深度学习的框架,他是一个高性能的数值计算库,更是结合了可组合的函数转换库,用于高性能机器学习研究。深度学习只是其中的一部分而已,但是你完全可以把自己的深度学习移植到JAX 上面。

借助Autograd的更新版本,JAX 可以自动区分原生 Python 和 NumPy 函数。它可以通过循环、分支、递归和闭包进行微分,并且可以对导数的导数进行导数。它支持反向模式微分(也称为反向传播)grad和正向模式微分,两者可以任意组合成任何顺序。

新功能是 JAX 使用XLA 在 GPU 和 TPU 上编译和运行 NumPy 程序。默认情况下,编译发生在后端,库调用得到及时编译和执行。但是 JAX 还允许使用单一功能 API 将自己的 Python 函数即时编译到 XLA 优化内核中与 jit.编译和自动微分可以任意组合,因此可以在不离开 Python 的使用环境下表达复杂的算法并获得最大的性能。甚至可以使用一次对多个GPU 或 TPU 内核进行编程pmap,并对整个事物进行区分。

JAx

——2——

JAX运行速度对比

我们用 NumPy 和 JAX 将矩阵的前三个幂相加来对比一下jax与numpy的运行速度,首先是我们的 NumPy 实现:

def fn(x): return x + x*x + x*x*x x = np.random.randn(10000, 10000).astype(dtype='float32') %timeit -n5 fn(x)5 loops, best of 5: 478 ms per loop

我们发现这个计算大约需要478 毫秒。

然后,我们用 JAX 实现这个计算:

jax_fn = jit(fn)x = jnp.array(x)%timeit jax_fn(x).block_until_ready()100 loops, best of 5: 5.54 ms per loop

JAX 仅在5.54 毫秒内执行此计算-比 NumPy 快 86 倍。

打开网易新闻 查看更多图片

JAX 与numpy对比速度

从根本上说,如果您从事与科学计算相关的任何领域,那么您应该学习 JAX。并把自己的代码移植到JAX上面

——3——

JAX特征

1. NumPy on Accelerators - NumPy 是使用 Python 进行科学计算的基础包之一,但它仅与 CPU 兼容。JAX 提供了 NumPy 的实现(具有几乎相同的 API),它可以非常轻松地在GPU和TPU 上运行。对于许多用户而言,仅此一项就足以证明使用 JAX 的合理性。

2. XLA ——XLA,即加速线性代数,是专为线性代数设计的全程序优化编译器。JAX是建立在 XLA 之上,显着提高了计算速度上限。

3. JIT - JAX 允许您使用 XLA 将自己的函数转换为即时 (JIT) 编译版本。这意味着您可以通过在计算函数中添加一个简单的函数装饰器来将计算速度提高几个数量级。

4. 自动微分——JAX 文档将 JAX 称为“Autograd 和 XLA,结合在一起” 。自动微分的能力在科学计算的许多领域都至关重要,JAX 提供了几个强大的自动微分工具。

5. 深度学习——虽然 JAX 本身不是一个深度学习框架,但它确实为深度学习的目的提供了一个绰绰有余的功能。在 JAX 之上构建了许多旨在构建深度学习功能的库,包括Flax、Haiku和Elegy。JAX 对 Hessians 的高效计算也与深度学习相关,因为它们使高阶优化技术更加可行。

6. 通用可微分编程范式——当然可以使用 JAX 来构建和训练深度学习模型,它也为通用可微编程提供了一个框架。这意味着 JAX 可以通过使用基于模型的机器学习方法来解决问题,从而利用通过数十年研究建立的给定领域的先验知识。

7.JAX 为此类功能转换合并了一个可扩展系统,并且具有典型的四个主要功能转换函数:

grad()用于评估输入函数的梯度函数vmap()用于操作的自动矢量化pmap()便于计算并行化,以及jit()将函数转换为即时编译的版本

JAX,TensorFlow,pytorch 对比

JAX,TensorFlow,pytorch 对比

Tensorflow

PyTorch

Jax

Developed by

Google

Facebook

Google

Flexible

No

Yes

Yes

Graph-Creation

Static/Dynamic

Dynamic

Static

Target Audience

Researchers,

Developers

Researchers,

Developers

Researchers

Low/High-level API

High Level

Both

Both

Development Stage

Mature( v2.4.1 )

Mature( v1.8.0 )

Developing( v0.1.55 )

TensorFlow

1、tensoflow是一个对用户非常友好的框架。高级 API -Keras 的可用性使模型层定义、损失函数和模型创建变得非常容易。TensorFlow2.0 带有动态图类型。这使得该库对用户更加友好,并且是对以前版本的重大升级。

2、Keras 的这种高级接口有一定的缺点。由于 TensorFlow 抽象了许多底层机制,因此研究人员在使用他们的模型可以做什么方面的自由度降低了。

3、Tensorflow 提供的最吸引人的东西之一是 TensorBoard,它实际上是 TensorFlow 可视化工具包。它允许用户可视化损失函数、模型图、分析等。

因此,如果开始使用深度学习或希望轻松部署您的模型,TensorFlow 可能是一个很好的开始使用的框架。TensorFlow Lite 使将 ML 模型部署到移动和边缘设备变得更加容易。

打开网易新闻 查看更多图片

pytorch

1、与 TensorFlow 不同,PyTorch 使用动态类型图,这意味着执行图是随时随地创建的。它允许我们随时修改和检查图的内部。

2、除了用户友好的高级 API 之外,PyTorch 确实有一个构建良好的 API,它允许对用户的机器学习模型进行越来越多的控制。我们可以在训练期间模型的前向和后向传递期间检查和修改输出。这对于梯度裁剪和神经风格迁移非常有效。

3、PyTorch 允许扩展他们的代码,轻松添加新的损失函数和用户定义的层。PyTorch autograd 足够强大,可以通过这些用户定义的层进行区分。用户还可以选择定义梯度的计算方式。

4、PyTorch 对数据并行性和 GPU 使用有广泛的支持。

5、PyTorch 比 TensorFlow 更 Pythonic。PyTorch 非常适合 python 生态系统,它允许使用 Python 调试器工具来调试 PyTorch 代码。PyTorch 因其高度的灵活性吸引了众多学术研究人员和工业界的关注

JAX

Jax 是来自 Google 的一个相对较新的机器学习库。它更像是一个 autograd 库,可以区分每个本机 python 和 NumPy 代码。

“Python+NumPy 程序的可组合转换:微分、向量化、JIT 到 GPU/TPU 等等”。该库利用 grad 函数转换将函数转换为返回原始函数梯度的函数。Jax 还提供了一个函数转换 JIT,用于对现有函数进行即时编译,并分别提供了用于矢量化和并行化的 vmap 和 pmap。

让我们看一下 JAX 的一些特性:

正如官方网站所描述的,JAX 能够对 Python+NumPy 程序进行可组合的转换:微分、向量化、JIT 到 GPU/TPU 等等。

与 PyTorch 相比,JAX 最重要的方面是如何计算梯度。在 Torch 中,图形是在前向传播期间创建的,而梯度是在后向传播期间计算的。另一方面,在 JAX 中,计算被表示为一个函数。使用grad()函数返回一个梯度函数,该函数直接计算给定输入的函数梯度。

JAX 是一个 autograd 工具,单独使用它几乎不是一个好主意。有各种基于 JAX 的 ML 库,其中值得注意的是 ObJax、Flax 和 Elegy。由于它们都使用相同的核心,并且接口只是 JAX 库的包装器,因此我们将它们放在同一个括号中。

Flax最初是在 PyTorch 生态系统下开发的。它更注重使用的灵活性。另一方面,Elegy更多的是受到 Keras 的启发。ObJAX主要是为面向研究的目的而设计的,它更注重简单性和可理解性。事实上,它与标语一致——由 研究人员为研究人员服务。

JAX 日益流行。许多研究人员在他们的实验中使用 JAX,从 PyTorch 吸引了一些流量。JAX 仍处于起步阶段,不建议刚开始探索深度学习的人(目前)。使用最先进的技术需要一些数学专业知识。访问官方存储库以了解有关这个有前途的新库的更多信息。

深度学习的成功很大程度上归功于自动分化。TensorFlow和PyTorch等流行库在训练期间跟踪神经网络参数的梯度,两者都包含用于实现深度学习常用神经网络功能的高级 API。JAX是 CPU、GPU 和 TPU 上的 NumPy,对于高性能机器学习研究具有出色的自动区分能力。除了深度学习框架外,JAX 还创建了一个超级精巧的线性代数库,具有自动微分和 XLA 支持。

从 PyTorch 或 Tensorflow 2 到 JAX 的转变无异于构造。PyTorch 在前向传递期间构建图形,在反向传递期间构建梯度。另一方面,JAX 允许用户将他们的计算表达为 Python 函数,并通过使用 grad() 对其进行转换,得到可以像计算函数一样评估的梯度函数——但它给出的不是输出,而是梯度函数作为输入的第一个参数的输出。

虽然 TensorFlow 和 Pytorch 已经编译了执行模式,但这些模式是后来添加的,因此留下了疤痕。例如,TensorFlow 的 Eager 模式与图形模式并非 100% 兼容,从而导致开发人员体验不佳。Pytorch有一个不好的历史,因为它们是在 Eager 模式下执行的,所以被迫使用不太直观的张量格式。JAX 以这两种模式出现——专注于渴望调试和 JIT 以执行繁重的计算。但是这些模式的简洁特性允许在需要时进行混合和匹配。

PyTorch 和 Tensorflow 是深度学习库,由用于现代深度学习方法的高级 API 组成。相比之下,JAX 是一个更注重功能的库,用于任意可微分编程。

虽然 JAX 对库调用使用即时编译,但 jit 函数转换可以用作自定义 Python 函数的装饰器。

实验者_在每个库中实现了一个简单的多层感知器,由一系列确定数值的加权连接组成,相当于输入张量和权重矩阵的矩阵乘法。结果表明,JAX 在数据验证中占主导地位。

JAX 具有比任何其他库更快 CPU 执行时间,并且对于仅使用矩阵乘法的实现而言,执行时间最短。实验还发现,虽然 JAX 在 matmul 方面优于其他库,但 PyTorch 在 Linear Layers 方面领先。PyTorch 在 GPU 上运行时的执行时间很快——PyTorch 和线性层用了 9.9 秒,批处理大小为 16,384,这与 JAX 以 1024 的批处理大小运行相对应。PyTorch 是最快的,在利用更高级别的神经网络 API 时,紧随其后的是 JAX 和 TensorFlow。对于实现全连接的神经层,PyTorch 的执行速度比 TensorFlow 更有效。另一方面,与类似的 Autograd 库相比,JAX 提供了一个加速器。当 MLP 实现仅限于矩阵乘法运算时,JAX 也是最快的。

——5——

JAX,TensorFlow,pytorch 如何选择

科学计算

若你开始尝试加速你的代码(即 GPU 和 TPU)上运行NumPy 。那么你应该学习一下JAX,且可以把自己的代码移植到到JAX。

如果你的大部分工作是在 Python 中使用大量自定义代码,那么现在就可以开始学习JAX

如果想构建某种基于模型/神经网络的混合系统,那么使用 JAX 可能是一个好的解决方案。

如果你的大部分工作不是在 Python 中,或者正在使用一些专门的软件进行研究(热力学、半导体等),那么 JAX 可能不是适合的工具,除非想从这些程序中导出数据以用于某种自定义计算处理。

打开网易新闻 查看更多图片

深度学习

虽然我们已经强调 JAX 是一个不是专门为深度学习构建的通用框架,但 JAX 速度很快并且具有自动微分能力,这意味着很多深度学习的代码可以移植到JAX 上面。

如果在 TPU 上进行训练,那么可以开始使用 JAX,尤其是如果当前正在使用 PyTorch。虽然PyTorch-XLA存在,但使用 JAX 进行 TPU 训练绝对是体验好得多。如果正在研究“非标准”架构/建模,例如SDE-Nets,那么绝对应该尝试 JAX。

如果你不是在构建奇特的架构,只是在 GPU 上训练常见的架构,那么你现在应该坚持使用 PyTorch 或 TensorFlow。虽然 PyTorch 仍然在研究领域占据主导地位,但使用 JAX 的论文数量一直在稳步增长,并且随着 DeepMind 和 Google 等重量级人物持续开发用于 JAX 的高级深度学习 API,在短短几年内JAX 就可以很容易看到爆炸性的采用率。

但是我们初学者,至少应该熟悉 JAX 的基础知识,尤其是您进行任何类型的机器学习研究时。