用不到100 行,Rust 使 Python 快 100 倍!

2023-06-0408:47:31编程语言入门到精通Comments997 views字数 13467阅读模式

作者 | Ohad Ravid文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

译者 | 平川文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

策划 | 刘燕文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

本文最初发布于 Ohad Ravid 的个人博客 Tea and Bits。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

不久前,在 $work,我们的一个核心 Python 库遇到了性能问题。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

这个库是我们 3D 处理管道的基础库。它相当大,也很复杂。它使用 NumPy 和其他科学相关的 Python 包进行广泛的数学和几何操作。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

我们的系统还必须在 CPU 资源有限的环境中运行,虽然它一开始表现得不错,但随着实际并发用户数的增长,我们开始遇到问题,系统难以满足负载增长的需求。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

我们的结论是,要处理增加的工作负载,系统至少要快 50 倍。我们认为,Rust 可以帮助我们实现这一目标。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

因为我们遇到的性能问题很常见,所以我们要在这篇(并不算短的)文章中重现并解决它们。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

泡壶茶或冲杯咖啡,我将介绍:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

(a)基本问题;
(b)用来解决这个问题的几次迭代优化。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

你也可以直接跳到文章末尾,查看最终代码。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

一个实用的例子

让我们创建一个小型库,用它重现我们最初遇到的性能问题(但其所完成的工作是任意的)。
假设你有一个多边形列表和一个点列表,都是二维的。出于业务原因,我们希望将每个点“匹配”到单个多边形。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

我们假想的库将:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

1、从点和多边形的初始列表开始(均为二维的)。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

2、对于每个点,计算它到多边形中心的距离,找到一个离它最近的、小得多的多边形子集。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

3、从这些多边形中,选择一个“最佳”的(我们将以“面积最小的”为“最佳”)。代码类似下面这样(这里 有完整代码):文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

from typing import List, Tuple
import numpy as np
from dataclasses import dataclass
from functools import cached_property

Point = np.array

@dataclass
class Polygon:
    x: np.array
    y: np.array

    @cached_property
    def center(self) -> Point: ...
    def area(self) -> float: ...

def find_close_polygons(polygon_subset: List[Polygon], point: Point, max_dist: float) -> List[Polygon]:
    ...

def select_best_polygon(polygon_sets: List[Tuple[Point, List[Polygon]]]) -> List[Tuple[Point, Polygon]]:
    ...

def main(polygons: List[Polygon], points: np.ndarray) -> List[Tuple[Point, Polygon]]:

主要的困难(性能方面)在于混合使用 Python 对象和 numpy 数组。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

我们马上会深入分析这个问题。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

值得注意的是,对于这个玩具库来说,把部分 / 所有内容转换为 向量化 的 numpy 或许是可行的,但对于真正的库来说,这几乎是不可能的,因为这会使代码的可读性和可修改性大大降低,并且收益非常有限(这里有一个部分向量化的版本,它的速度更快,但与我们想要实现的结果差很远)。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

此外,任何基于 JIT 的技巧(PyPy/numba)所能带来的收益都非常小(保准起见,我们会进行测量)。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

为什么不完全用 Rust™重写呢?文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

虽然完全重写很有吸引力,但存在几个问题:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

  1. 该库已经使用 numpy 进行了大量的计算,所以我们为什么要期望 Rust 更好呢?
  2. 它很大很复杂,对业务非常重要,而且高度的算法化,所以这将需要大约几个月的工作量,我们可怜的本地服务器如今已经奄奄一息了。
  3. 一群友好的研究人员正在积极地研究这个库,实现更好的算法,并做了大量的实验。要让他们学习一种新的编程语言,等待程序编译并与借用检查器进行斗争,他们肯定不会很乐意。他们会感激我们没有给他们带来太大的麻烦。
    性能分析工具初探

该介绍我们的性能分析器朋友了。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

Python 有一个内置的性能分析器(cProfile),但在这种情况下,它并不是真正适合这项工作的工具:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

  1. 这将给所有的 Python 代码带来大量的开销,而且对原生代码没有任何影响,因此结果可能会有偏差。
  2. 我们将无法看到原生栈帧,也就是说,我们将无法深入观察 Rust 代码。我们将使用py-spy(GitHub)。

py-spy是一个 采样分析器 #Statistical_profilers),可以看到原生栈帧。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

该项目还慷慨地将预先构建好的版本发布到了 PyPi,我们只需运行pip install py-spy 就可以开始工作了。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

我们还需要一些东西来度量。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

# measure.py
import time
import poly_match
import os

# 减少噪音,实际地改进性能
os.environ["OPENBLAS_NUM_THREADS"] = "1"

polygons, points = poly_match.generate_example()

# 随着代码的速度越来越快,我们会增加这个值
NUM_ITER = 10

t0 = time.perf_counter()
for _ in range(NUM_ITER):
    poly_match.main(polygons, points)
t1 = time.perf_counter()

took = (t1 - t0) / NUM_ITER
print(f"Took and avg of {took * 1000:.2f}ms per iteration")

虽然不是很科学,但也可以让我们走很远。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

要做好基准测试很难。话虽如此,不要因为追求完美的基准测试设置而倍感压力,特别是当你开始优化一个程序时。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

——Nicholas Nethercote,《Rust 性能手册》运行这个脚本可以得出基线:
$ python measure.py 平均每次迭代耗时 293.41 毫秒文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

对于原始库,为了涵盖了所有情况,我们用了 50 个不同的示例。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

这与系统的整体性能是匹配的,这意味着我们可以着手降低这个数值了。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

注:我们还可以使用 PyPy 进行度量(为了充分发挥 JIT 的能力,我们还在前面加了一个预热操作)。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

$ conda create -n pypyenv -c conda-forge pypy numpy && conda activate pypyenv
$ pypy measure_with_warmup.py
平均每次迭代耗时 1495.81 毫秒

先测量一下

那么,首先让我们看看是哪里慢。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

$ py-spy record --native -o profile.svg -- python measure.py
py-spy> 每秒采样 100 次。按 Control-C 退出。

平均每次迭代耗时 365.43 毫秒

py-spy> 进程退出,采样结束。
py-spy> 将火焰图数据写入'profile.svg'。样本数:391 错误数:0

可以看到,开销已经非常小。为了进行比较,我们使用cProfile得出了以下结果:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

$ python -m cProfile measure.py
平均每次迭代耗时 546.47 毫秒
         7.806 秒 7551778 次函数调用(7409483 次原始调用)
         ..

我们得到了下面这个漂亮的淡红色 火焰图:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

用不到100 行,Rust 使 Python 快 100 倍!

每个框都对应一个函数,我们可以看每个函数花费的相对时间,包括它正在调用的函数(沿着图 / 堆栈向下。注意: 原文中每个框都可以点击进入)。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

这里,我们主要得出了以下结论:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

  1. 绝大部分时间都花在了find_close_polygons上。
  2. 大部分时间都用在了norm ,这是一个 numpy 函数。

我们看一下 find_close_polygons:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

def find_close_polygons(
    polygon_subset: List[Polygon], point: np.array, max_dist: float
) -> List[Polygon]:
    close_polygons = []
    for poly in polygon_subset:
        if np.linalg.norm(poly.center - point) < max_dist:
            close_polygons.append(poly)

    return close_polygons

 

我们将在 Rust 中重写这个函数。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

在深入研究细节之前,有几点需要注意:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

  1. 该函数接受并返回复杂对象(Polyonnp.array)。
  2. 对象的大小非常重要(所以复制可能有开销)。
  3. 该函数会被大量调用(所以我们引入的开销可能会产生很大的影响)。

我的第一个 Rust 模块

pyo3是一个用于 Python 和 Rust 交互的 crate。它的文档非常完善。要了解基本设置,请点击 这里。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

我们将把这个 crate 命名为poly_match_rs,并添加名为find_close_polygons的函数。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

mkdir poly_match_rs && cd "$_"
pip install maturin
maturin init --bindings pyo3
maturin develop

一开始,这个 crate 看起来是下面这样:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

use pyo3::prelude::*;

#[pyfunction]
fn find_close_polygons() -> PyResult<()> {
    Ok(())
}

#[pymodule]
fn poly_match_rs(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(find_close_polygons, m)?)?;
    Ok(())
}

我们还需要记住,每次修改 Rust 库时都要执行maturin develop文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

就是这样!我们调用下新函数,看看会发生什么。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

>>> poly_match_rs.find_close_polygons(polygons, point, max_dist)
E TypeError: poly_match_rs.poly_match_rs.find_close_polygons() takes no arguments (3 given)

v1 —— 简单的 Rust 译文

我们将从匹配预期的 API 开始。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

PyO3 在从 Python 到 Rust 的转换方面非常聪明,所以这非常简单:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

#[pyfunction]
fn find_close_polygons(polygons: Vec<PyObject>, point: PyObject, max_dist: f64) -> PyResult<Vec<PyObject>> {
    Ok(vec![])
}

顾名思义,PyObject是一个通用的 Python 对象。我们将尝试和它交互。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

这样程序应该可以运行了(尽管不正确)。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

我将复制粘贴原始的 Python 函数,然后仅仅修正下语法。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

#[pyfunction]
fn find_close_polygons(polygons: Vec<PyObject>, point: PyObject, max_dist: f64) -> PyResult<Vec<PyObject>> {
    let mut close_polygons = vec![];

    for poly in polygons {
        if norm(poly.center - point) < max_dist {
            close_polygons.push(poly)
        }
    }

    Ok(close_polygons)
}

很酷,但编译不通过:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

% maturin develop
...

error[E0609]: no field `center` on type `Py<PyAny>`
 --> src/lib.rs:8:22
  |
8 |         if norm(poly.center - point) < max_dist {
  |                      ^^^^^^ unknown field



error[E0425]: cannot find function `norm` in this scope
 --> src/lib.rs:8:12
  |
8 |         if norm(poly.center - point) < max_dist {
  |            ^^^^ not found in this scope



error: aborting due to 2 previous errors ] 58/59: poly_match_rs

我们需要三个 crate 来实现我们的函数:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

# 针对 Rust 原生数组操作
ndarray = "0.15"

# 面向数组的`norm`函数
ndarray-linalg = "0.16"  

# 访问 numpy 基于`ndarray`创建的对象
numpy = "0.18"

首先,将不透明的通用对象point: PyObject 转换为我们可以使用的对象。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

就像我们向 PyO3 请求“PyObjectsVec”一样,我们可以请求 numpy-array,它会自动为我们转换参数。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

use numpy::PyReadonlyArray1;

#[pyfunction]
fn find_close_polygons(
    // 一个“拥有 GIL 锁”的对象,这样我们就可以访问 Python 管理的内存
    py: Python<'_>,
    polygons: Vec<PyObject>,
    // 一个我们可以访问的 numpy 数组的引用
    point: PyReadonlyArray1<f64>,
    max_dist: f64,
) -> PyResult<Vec<PyObject>> {
    // 转换为`ndarray::ArrayView1`,一个完全可操作的原生数组
    let point = point.as_array();
    ...
}

因为point现在是一个ArrayView1,所以我们可以使用它了,如下所示:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

// 使`norm`函数可用
use ndarray_linalg::Norm;

assert_eq!((point.to_owned() - point).norm(), 0.);

现在,我们只需要获取每个多边形的中心,并将其“强制转换”为ArrayView1文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

在 PyO3 中,这个过程如下所示:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

let center = poly
  .getattr(py, "center")?                 // Python 风格的 getattr,需要 GIL 令牌(`py`)。
  .extract::<PyReadonlyArray1<f64>>(py)?  // 告诉 PyO3 将结果转换成什么。
  .as_array()                             // 类似之前的`point`。
  .to_owned();                            // 需要让`-`的一侧被“owned”。

有点拗口,但总的来说,结果还是相当清晰的,就是对原始代码进行逐行翻译:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

use pyo3::prelude::*;

use ndarray_linalg::Norm;
use numpy::PyReadonlyArray1;

#[pyfunction]
fn find_close_polygons(
    py: Python<'_>,
    polygons: Vec<PyObject>,
    point: PyReadonlyArray1<f64>,
    max_dist: f64,
) -> PyResult<Vec<PyObject>> {
    let mut close_polygons = vec![];
    let point = point.as_array();
    for poly in polygons {
        let center = poly
            .getattr(py, "center")?
            .extract::<PyReadonlyArray1<f64>>(py)?
            .as_array()
            .to_owned();

        if (center - point).norm() < max_dist {
            close_polygons.push(poly)
        }
    }

    Ok(close_polygons)
}

与原来的版本进行比较:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

def find_close_polygons(
    polygon_subset: List[Polygon], point: np.array, max_dist: float
) -> List[Polygon]:
    close_polygons = []
    for poly in polygon_subset:
        if np.linalg.norm(poly.center - point) < max_dist:
            close_polygons.append(poly)

    return close_polygons

我们希望,与原来的函数相比,这个版本有一些优势,但是有多大的优势呢?文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

 $ (cd ./poly_match_rs/ && maturin develop)
$ python measure.py
平均每次迭代耗时 609.46 毫秒

所以……Rust 真的很慢吗?不!我们只是忘了要求速度!如果我们使用maturin develop --release来运行,得到的结果就会好很多:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

$ (cd ./poly_match_rs/ && maturin develop --release)
$ python measure.py
平均每次迭代耗时 23.44 毫秒

我们还想查看原生代码,因此,我们将在 release 配置中启用调试符号。既然这样,我们不妨要求最高速度。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

# 添加到 Cargo.toml
[profile.release]
debug = true       # 启用性能分析器的调试符号
lto = true         # 链接时优化
codegen-units = 1  # 编译慢,但运行快

v2 —— 使用 Rust 进一步重写

现在,使用py-spy中的--native 标志,我们可以同时查看 Python 和新的原生代码。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

再次运行py-spy :文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

$ py-spy record --native -o profile.svg -- python measure.py
py-spy> 采样过程每秒 100 次。按 Ctrl+C 键退出。

我们得到了下面这个火焰图(非红色的部分也加进来了,以供参考):文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

从分析器的输出中,我们可以看到一些有趣的东西:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

  1. find_close_polygons::...::trampoline (Python 直接调用的符号)和__pyfunction_find_close_polygons(我们的实际实现)的相对大小。
    耗时占比 95% 对 88%,所以开销很小。
  2. 实际逻辑(If (center - point).norm() < max_dist{…})即lib_v1.rs: 22(右侧的小框)占总运行时间的大约 9%。
    所以 10 倍的改进还是可能的!
  3. 大部分时间都消耗在了lib_v1.rs: 16,即poly.getattr(…).extract(…),放大一下就可以看到getattr,它在获取使用了as_array的底层数组。
    我们从中得出的结论是,我们需要专注于解决第 3 点,方法就是用 Rust 重写 Polygon。

看下我们的目标:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

@dataclass
class Polygon:
    x: np.array
    y: np.array
    _area: float = None

    @cached_property
    def center(self) -> np.array:
        centroid = np.array([self.x, self.y]).mean(axis=1)
        return centroid

    def area(self) -> float:
        if self._area is None:
            self._area = 0.5 * np.abs(
                np.dot(self.x, np.roll(self.y, 1)) - np.dot(self.y, np.roll(self.x, 1))
            )
        return self._area

我们将尽可能地保留现有的 API,但我们真的不需要area 那么快(目前)。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

实际的类可能会包含其他复杂的东西,比如merge方法,它使用了scipy.spatial中的ConvexHull文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

为了降低成本(并限制这篇本已很长的文章的篇幅),我们只将Polygon的“核心”功能迁移到 Rust,并在 Python 中通过子类化来实现 API 的其余部分。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

struct 看起来是这样的:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

// `Array1`是一维数组, `numpy` crate 可以很方便地操作
use ndarray::Array1;

// `subclass` 告诉 PyO3 可以在 Python 中将其子类化
#[pyclass(subclass)]
struct Polygon {
    x: Array1<f64>,
    y: Array1<f64>,
    center: Array1<f64>,
}

现在,我们需要实际地实现它。我们想把poly.{x, y, center} 暴露为:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

  1. 属性。
  2. numpy 数组。

我们还需要一个构造函数,以便 Python 可以创建新的 Polygon 。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

use numpy::{PyArray1, PyReadonlyArray1, ToPyArray};

#[pymethods]
impl Polygon {
    #[new]
    fn new(x: PyReadonlyArray1<f64>, y: PyReadonlyArray1<f64>) -> Polygon {
        let x = x.as_array();
        let y = y.as_array();
        let center = Array1::from_vec(vec![x.mean().unwrap(), y.mean().unwrap()]);

        Polygon {
            x: x.to_owned(),
            y: y.to_owned(),
            center,
        }
    }

    // 返回类型中的`Py<..>`是一种说明“Python 拥有一个对象”的方式。
    #[getter]               
    fn x(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
        Ok(self.x.to_pyarray(py).to_owned()) // 创建一个归 Python 所有的 numpy 版本的`x`。
    }

    // 对`y`和`center`做同样处理。
}

我们需要将新的结构体作为类添加到模块中:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

#[pymodule]
fn poly_match_rs(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<Polygon>()?; // new.
    m.add_function(wrap_pyfunction!(find_close_polygons, m)?)?;
    Ok(())
}

现在,我们可以更新 Python 代码来使用它了:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

class Polygon(poly_match_rs.Polygon):
    _area: float = None

    def area(self) -> float:
        ...

它可以编译通过,也确实可以工作,但慢了许多!(请记住,现在每次访问xycenter时都需要创建一个新的 numpy 数组)。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

为了真正地提高性能,我们需要从 Python Polygon 列表中extract 原来 Rust 版本的Polygon 。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

对于这种类型的操作,PyO3 非常灵活,为我们提供了多种方法。我们的一个限制是,还需要返回 Python- Polygon ,并且我们不想对实际的数据做任何克隆。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

也可以在每个PyObject 上手动调用.extract::<Polygon>(py)? ,但我们要求 PyO3 直接提供Py<Polygon>文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

这是对 Python 拥有的对象的引用,我们希望该对象包含一个原生pyclass结构的实例(或者,在我们的情况下是一个子类)。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

#[pyfunction]
fn find_close_polygons(
    py: Python<'_>,
    polygons: Vec<Py<Polygon>>,             // 引用 Python 拥有的对象
    point: PyReadonlyArray1<f64>,
    max_dist: f64,
) -> PyResult<Vec<Py<Polygon>>> {           // 返回相同的`Py`引用,未经修改
    let mut close_polygons = vec![];
    let point = point.as_array();
    for poly in polygons {
        let center = poly.borrow(py).center // 需要使用 GIL (`py`) 借用底层的`Polygon`
            .to_owned();

        if (center - point).norm() < max_dist {
            close_polygons.push(poly)
        }
    }

    Ok(close_polygons)
}

让我们看看这段代码会带来什么:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

$ python measure.py
平均每次迭代耗时 6.29 毫秒

就快达到目标了!再提升 2 倍就行了!文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

v3 —— 避免分配

让我们再一次启动剖析器。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

  1. 我们首先看下select_best_polygon,它现在调用了一些 Rust 代码(当它获取x & y向量)
    我们可以修复这个问题,但潜在的改善可能非常小(可能只有 10%)
  2. 我们看到,20% 的时间花在了extract_argument(在lib_v2.rs: 48下)上,所以我们还是付出了不小的开销!
    但大部分时间都在PyIterator::nextPyTypeInfo::is_type_of中,要改进并不容易。
  3. 我们看到,花在分配东西上的时间相当多!
    lib_v2.rs: 58是我们的if,我们还看到了drop_in_placeto_owned
    这行代码大约了占总时间的 35%,远超我们的预期:所有数据都到位的话,这应该是”快速位(fast bit)“。
  4. 让我们来处理下最后一点。

以下是有问题的代码片段:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

let center = poly.borrow(py).center
    .to_owned();

if (center - point).norm() < max_dist { ... }

我们希望避免to_owned 。但norm 需要一个 owned 对象,所以我们必须手动实现。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

(这里,我们之所以可以改进ndarray是因为我们知道数组实际上只包含 2 个f32)。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

看起来就像下面这样:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

use ndarray_linalg::Scalar;

let center = &poly.as_ref(py).borrow().center;

if ((center[0] - point[0]).square() + (center[1] - point[1]).square()).sqrt() < max_dist {
    close_polygons.push(poly)
}

但是,借用检查器不允许我们这样做:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

error[E0505]: cannot move out of `poly` because it is borrowed
  --> src/lib.rs:58:33
   |
55 |         let center = &poly.as_ref(py).borrow().center;
   |                       ------------------------
   |                       |
   |                       borrow of `poly` occurs here
   |                       a temporary with access to the borrow is created here ...
...
58 |             close_polygons.push(poly);
   |                                 ^^^^ move out of `poly` occurs here
59 |         }
60 |     }
   |     - ... and the borrow might be used here, when that temporary is dropped and runs the `Drop` code for type `PyRef`

像往常一样,借用检查器是正确的:我们正在进行内存犯罪。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

更简单的修复方法是直接克隆,并编译close_polygons.push(poly.clone())文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

实际上,这个克隆的成本非常低,因为我们只incrPython 对象的引用计数。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

然而,在这种情况下,我们也可以通过做一个典型的 Rust 技巧来缩短借用:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

let norm = {
    let center = &poly.as_ref(py).borrow().center;

    ((center[0] - point[0]).square() + (center[1] - point[1]).square()).sqrt()
};

if norm < max_dist {
    close_polygons.push(poly)
}

因为poly只在内部作用域内被借用,一旦到达close_polygons.push ,编译器就可以知道我们不再持有该引用,并将愉快地完成新版本的编译。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

最后,我们得到如下结果:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

$ python measure.py
平均每次迭代耗时 2.90 毫秒

比原来的代码快了 100 倍。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

小结

我们是从下面的 Python 代码开始的:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

@dataclass
class Polygon:
    x: np.array
    y: np.array
    _area: float = None

    @cached_property
    def center(self) -> np.array:
        centroid = np.array([self.x, self.y]).mean(axis=1)
        return centroid

    def area(self) -> float:
        ...

def find_close_polygons(
    polygon_subset: List[Polygon], point: np.array, max_dist: float
) -> List[Polygon]:
    close_polygons = []
    for poly in polygon_subset:
        if np.linalg.norm(poly.center - point) < max_dist:
            close_polygons.append(poly)

    return close_polygons

# Rest of file (main, select_best_polygon).

我们使用py-spy对它进行了性能分析,即使只是对 find_close_polyons 做 最简单的行对行翻译,性能提升也超过了 10 倍。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

我们又做了多轮 profile-write-measure 迭代,直到运行速度获得了 100 倍的改进,而且 API 与原来的库相同。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

最终的 Python 代码如下所示:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

import poly_match_rs
from poly_match_rs import find_close_polygons

class Polygon(poly_match_rs.Polygon):
    _area: float = None

    def area(self) -> float:
        ...

# 文件的其余部分没有变化 (main, select_best_polygon)。

它调用以下 Rust 代码:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

use pyo3::prelude::*;

use ndarray::Array1;
use ndarray_linalg::Scalar;
use numpy::{PyArray1, PyReadonlyArray1, ToPyArray};

#[pyclass(subclass)]
struct Polygon {
    x: Array1<f64>,
    y: Array1<f64>,
    center: Array1<f64>,
}

#[pymethods]
impl Polygon {
    #[new]
    fn new(x: PyReadonlyArray1<f64>, y: PyReadonlyArray1<f64>) -> Polygon {
        let x = x.as_array();
        let y = y.as_array();
        let center = Array1::from_vec(vec![x.mean().unwrap(), y.mean().unwrap()]);

        Polygon {
            x: x.to_owned(),
            y: y.to_owned(),
            center,
        }
    }

    #[getter]
    fn x(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
        Ok(self.x.to_pyarray(py).to_owned())
    }

    // 对`y`和`center`做同样处理。
}

#[pyfunction]
fn find_close_polygons(
    py: Python<'_>,
    polygons: Vec<Py<Polygon>>,
    point: PyReadonlyArray1<f64>,
    max_dist: f64,
) -> PyResult<Vec<Py<Polygon>>> {
    let mut close_polygons = vec![];
    let point = point.as_array();
    for poly in polygons {
        let norm = {
            let center = &poly.as_ref(py).borrow().center;

            ((center[0] - point[0]).square() + (center[1] - point[1]).square()).sqrt()
        };

        if norm < max_dist {
            close_polygons.push(poly)
        }
    }

    Ok(close_polygons)
}

#[pymodule]
fn poly_match_rs(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<Polygon>()?;
    m.add_function(wrap_pyfunction!(find_close_polygons, m)?)?;
    Ok(())
}

本文要点

  • Rust(在 pyo3 的帮助下)以最小的妥协为普通的 Python 代码解锁了真正的原生性能体验。
  • 对于研究人员来说,Python 是一个极好的 API,和 Rust 快速构建块一起组成了一个极其强大的组合。
  • 性能分析非常有趣,它能让你真正地了解代码中发生的一切。

最后:电脑的速度快得惊人。下次,当你等待某件事完成时,可以考虑启动一个分析器,你可能会了解到一些新东西。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

原文链接:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

https://ohadravid.github.io/posts/2023-03-rusty-python/文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ymba/44279.html

  • 本站内容整理自互联网,仅提供信息存储空间服务,以方便学习之用。如对文章、图片、字体等版权有疑问,请在下方留言,管理员看到后,将第一时间进行处理。
  • 转载请务必保留本文链接:https://www.cainiaoxueyuan.com/ymba/44279.html

Comment

匿名网友 填写信息

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定