PyTorch CPU 矩阵乘法执行路径
- CPU:AMD EPYC 7763 64-Core Processor(仅支持 AVX2)
- Pytorch 版本:torch 2.6.0a0+git60d1c71
关于 PyTorch C++ 源码调试,可以参考之前的一篇文章 VSCode 配置 PyTorch C++ 源码开发环境(编译与调试)
Python 层
我们探究的是张量的 mm 方法,一个简单的示例如下
|
|
其 python 层的定义为
|
|
好吧,python 这没什么好看的
C++ 前端绑定
PyTorch 使用 pybind11 将 C++ 和 CUDA 代码集成到 Python 项目中,其 Tensor 对象的方法实现路径为:
torch/csrc/autograd/generated/python_variable_methods.cpp
其中有一个 PyMethodDef 变量,这用于在 Python 中绑定 C/C++ 函数,让它们能作为 Python 方法调用。
|
|
variable_methods 数组包含了一个绑定到 THPVariable_mm 函数的 Python 方法,方法名为 “mm” 。因此 THPVariable_mm 是实际实现的 C++ 函数。
|
|
C++层(算子分发)
THPVariable_mm 是处理与 Python 交互的函数,实际执行的是 self.mm(),即 Tensor::mm() 。
|
|
at::_ops::mm:这是代码生成器自动生成的一个命名空间,用于封装 mm 操作的所有调用入口。
call 方法实际上会调用 Dispatcher 系统的入口,将 mm 操作的两个矩阵 self 和 mat2 传递给 Dispatcher。紧接着 Dispatcher 系统就开始分发,这里涉及一系列 call 函数,包括从 dispatchTable_ 中获得具体的 kernel_function(lookup 函数) 。在此只列调用堆栈,不详细展开。
|
|
经过一系列 call 的跳转,来到了 at::Tensor mm(c10::DispatchKeySet ks, const at::Tensor & self, const at::Tensor & mat2),这个函数主要提供了自动求导支持。
|
|
at::redispatch::mm() 又是一系列的 redispatch 函数,这里也只提供调用堆栈
|
|
可以看到 frame #2 中
|
|
其中 op.impl() 跳转到 mm_out_cpu()
|
|
最终调用 addmm_impl_cpu_(const_cast<Tensor&>(result), result, self, mat2, 0, 1) ,这里有两个需要注意的传参细节:
- 传入了两个 result
- 传入了两个固定的标量 0 和 1
重新布局矩阵内存
addmm_impl_cpu_ 是一个在 CPU 上实现的矩阵乘加运算,运算表达式:result = beta _ self + alpha _ m1 * m2 ,这里 beta 和 alpha 都是标量,m1 和 m2 是矩阵。addmm 常用于神经网络中的线性层运算,形式类似于 y = Wx + b 。
我们下面以 CPU 上的 int8 运算为例,示例代码:
|
|
一个 2×3 的矩阵 A 和一个 3×4 的矩阵 B 进行相乘,结果是一个 2×4 的矩阵 C 。
在上层 mm_out_cpu 中,self 是 A,mat2 是 B,result 是 C。因此根据其调用时的传参,进入到 addmm_impl_cpu_ 中的参数信息如下:
- result 和 self 是 C;
- m1 是 A,m2 是 B;
- beta 是 0,alpha 是 1,即简化为 C = A * B 。
下面我们结合例子详细解释 addmm_impl_cpu_ 函数,部分注释中给出了变量此时的值。
|
|
addmm_impl_cpu_ 探索了三个矩阵的内存布局,把它们从按行存储转换成了按列存储(Fortran contiguous),以此来优化后续计算的访存。注意,这里的 transpose_c 是用来判断 transpose_a 和 transpose_b 的值,实际没有传到下一层函数中。可以认为 transpose_a 和 transpose_b 参数的值隐式包含了 transpose_c 这一信息。
函数最后使用 gemm 来实现 addmm,gemm (general matrix multiplication) 是通用矩阵乘法,用于计算矩阵 C = alpha _ A _ B + beta * C 。关于 GEMM 的更多信息可以看 Basic Linear Algebra Subprograms - Wikipedia 。
gemm 实际运算
int8 实现
at::native::cpublas::gemm() 是一个模板函数,在 int8 数据类型下,由于找不到对应的优化实现,所以只能回退到默认实现(数据类型是 scalar_t),也是纯软件实现。
|
|
normalize_last_dims 函数的作用是标准化与矩阵乘法相关的步长,以确保矩阵在内存中的布局正确,特别是在处理转置情况时。
|
|
在本例中参数如下,所以不进行任何操作。
(at::native::TransposeType) transa = NoTranspose
(at::native::TransposeType) transb = NoTranspose
(int64_t) m = 4
(int64_t) n = 2
(int64_t) k = 3
(int64_t *) lda = 0x00007ffe2bfe9838
(int64_t *) ldb = 0x00007ffe2bfe9848
(int64_t *) ldc = 0x00007ffe2bfe9860
gemm_stub 是一种回退机制,用于在没有 BLAS 或 MKL-DNN 等算子库加速的情况下执行矩阵乘法。我们的 int8 在 CPU 上没有特定的优化路径,因此只能选择 gemm_stub() 。
|
|
cpublas_gemm_impl 是一个用于在 CPU 上执行 gemm 的实现。它使用模板和宏来支持不同的数据类型(例如 float、double 等),并通过 gemm_core_ 函数来执行实际的矩阵乘法。_AT_DISPATCH_GEMM_TYPES 是一个模板分发宏,根据 type 确定 scalar_t 的实际类型,然后在该类型上执行后续代码块。
|
|
gemm_core_ 是一个模板化的核心矩阵乘法实现,用于处理各种不同的矩阵转置组合情况。它根据矩阵 A 和 B 是否转置的情况,调用相应的子函数来执行具体的矩阵乘法计算。
|
|
在本例中调用的是 gemm_notrans_,这是 pytorch 矩阵乘法执行路径的最后一层了,我们仔细看一下这个函数
|
|
模板参数和函数参数:
- scalar_t:用于表示输入矩阵和输出矩阵元素的类型,例如 float 或 double。
- opmath_t:表示计算过程中使用的累加数据类型。这里与 scalar_t 相同(因为使用了 std::is_same_v 保证,如果不同则是另外一个实现)。
- m, n, k:表示参与运算的矩阵的形状。k 为矩阵 A 和矩阵 B 的共享维度。
- alpha 和 beta:标量,用于缩放矩阵 A * B 和 C。
- a、b、c:指向矩阵 A、B 和 C 数据的指针。
- lda、ldb、ldc:表示矩阵 A、B、C 的行步长。
这个函数先计算矩阵 C 的缩放,然后再计算矩阵 A 和矩阵 B 的乘,其中使用了循环展开减少循环控制开销。
此时请忘掉一般的矩阵乘法形式,因为我们已经进入到其访存优化的内部实现中!
回到我们的例子,模拟一下具体的计算过程
|
|
原本的矩阵数组是
- 2*3 矩阵 A = [1, -2, 0, -1, 2, -2]
- 3*4 矩阵 B = [2, -1, 1, 0, 0, -2, 1, 2, -1, 1, -2, 0]
- 计算后的 2*4 矩阵 C = [2, 3, -1, -4, 0, -5, 5, 4](作为对照结果)
在 addmm_impl_cpu_ 函数中经过内存布局调整后(矩阵 A 与 B 互换):
- 矩阵 A = [2, -1, 1, 0, 0, -2, 1, 2, -1, 1, -2, 0]
- 矩阵 B = [1, -2, 0, -1, 2, -2]
进入 gemm_notrans_ 的参数分别为:
(int64_t) m = 4
(int64_t) n = 2
(int64_t) k = 3
(signed char) alpha = '\x01'
(const signed char *) a = 0x00000000077eaec0
(int64_t) lda = 4
(const signed char *) b = 0x00000000077eae40
(int64_t) ldb = 3
(signed char) beta = '\0'
(signed char *) c = 0x00000000077f2900
(int64_t) ldc = 4
计算过程:
- 当 l = 0 且 j = 0 时:
- val = b[0 + 0 * 3] = b[0]
- i_m = m / 4 = 4 / 4 = 1,i_i 的值仅取 0,进行循环展开
- c[0 * 4+ 0 * 4 + 0] += a[0 * 4 + 0 + 0 * 4] _ val,即计算 c[0] += a[0] _ b[0]
- c[0 * 4+ 0 * 4 + 1] += a[0 * 4 + 1 + 0 * 4] _ val,即计算 c[1] += a[1] _ b[0]
- c[0 * 4+ 0 * 4 + 2] += a[0 * 4 + 2 + 0 * 4] _ val,即计算 c[2] += a[2] _ b[0]
- c[0 * 4+ 0 * 4 + 3] += a[0 * 4 + 3 + 0 * 4] _ val,即计算 c[3] += a[3] _ b[0]
- 当 l = 0 且 j = 1 时:
- val = b[0 + 1 * 3] = b[3]
- i_m = m / 4 = 4 / 4 = 1,i_i 的值仅取 0,进行循环展开
- c[1 * 4 + 0 * 4 + 0] += a[0 * 4 + 0 + 0 * 4] _ val,即计算 c[4] += a[0] _ b[3]
- c[1 * 4+ 0 * 4 + 1] += a[0 * 4 + 1 + 0 * 4] _ val,即计算 c[5] += a[1] _ b[3]
- c[1 * 4+ 0 * 4 + 2] += a[0 * 4 + 2 + 0 * 4] _ val,即计算 c[6] += a[2] _ b[3]
- c[1 * 4+ 0 * 4 + 3] += a[0 * 4 + 3 + 0 * 4] _ val,即计算 c[7] += a[3] _ b[3]
- 当 l = 1 且 j = 0 时:
- val = b[1 + 0 * 3] = b[1]
- i_m = m / 4 = 4 / 4 = 1,i_i 的值仅取 0,进行循环展开
- c[0 * 4 + 0 * 4 + 0] += a[0 * 4 + 0 + 1 * 4] _ val,即计算 c[0] += a[4] _ b[1]
- c[0 * 4+ 0 * 4 + 1] += a[0 * 4 + 1 + 1 * 4] _ val,即计算 c[1] += a[5] _ b[1]
- c[0 * 4+ 0 * 4 + 2] += a[0 * 4 + 2 + 1 * 4] _ val,即计算 c[2] += a[6] _ b[1]
- c[0 * 4+ 0 * 4 + 3] += a[0 * 4 + 3 + 1 * 4] _ val,即计算 c[3] += a[7] _ b[1]
- 当 l = 1 且 j = 1 时:
- val = b[1 + 1 * 3] = b[4]
- i_m = m / 4 = 4 / 4 = 1,i_i 的值仅取 0,进行循环展开
- c[1 * 4 + 0 * 4 + 0] += a[0 * 4 + 0 + 1 * 4] _ val,即计算 c[4] += a[4] _ b[4]
- c[1 * 4+ 0 * 4 + 1] += a[0 * 4 + 1 + 1 * 4] _ val,即计算 c[5] += a[5] _ b[4]
- c[1 * 4+ 0 * 4 + 2] += a[0 * 4 + 2 + 1 * 4] _ val,即计算 c[6] += a[6] _ b[4]
- c[1 * 4+ 0 * 4 + 3] += a[0 * 4 + 3 + 1 * 4] _ val,即计算 c[7] += a[7] _ b[4]
- 当 l = 2 且 j = 0 时:
- val = b[2 + 0 * 3] = b[2]
- i_m = m / 4 = 4 / 4 = 1,i_i 的值仅取 0,进行循环展开
- c[0 * 4 + 0 * 4 + 0] += a[0 * 4 + 0 + 2 * 4] _ val,即计算 c[0] += a[8] _ b[2]
- c[0 * 4+ 0 * 4 + 1] += a[0 * 4 + 1 + 2 * 4] _ val,即计算 c[1] += a[9] _ b[2]
- c[0 * 4+ 0 * 4 + 2] += a[0 * 4 + 2 + 2 * 4] _ val,即计算 c[2] += a[10] _ b[2]
- c[0 * 4+ 0 * 4 + 3] += a[0 * 4 + 3 + 2 * 4] _ val,即计算 c[3] += a[11] _ b[2]
- 当 l = 2 且 j = 1 时:
- val = b[2 + 1 * 3] = b[5]
- i_m = m / 4 = 4 / 4 = 1,i_i 的值仅取 0,进行循环展开
- c[1 * 4 + 0 * 4 + 0] += a[0 * 4 + 0 + 2 * 4] _ val,即计算 c[4] += a[8] _ b[5]
- c[1 * 4+ 0 * 4 + 1] += a[0 * 4 + 1 + 2 * 4] _ val,即计算 c[5] += a[9] _ b[5]
- c[1 * 4+ 0 * 4 + 2] += a[0 * 4 + 2 + 2 * 4] _ val,即计算 c[6] += a[10] _ b[5]
- c[1 * 4+ 0 * 4 + 3] += a[0 * 4 + 3 + 2 * 4] _ val,即计算 c[7] += a[11] _ b[5]
把累加表达式合并得到:
$$ \begin{align} c[0] &= a[0]*b[0] + a[4]*b[1] + a[8]*b[2] = 2 + 0 + 0 = 2 \\\\ c[1] &= a[1]*b[0] + a[5]*b[1] + a[9]*b[2] = -1 + 4 + 0 = 3 \\\\ c[2] &= a[2]*b[0] + a[6]*b[1] + a[10]*b[2] = 1 - 2 + 0 = -1 \\\\ c[3] &= a[3]*b[0] + a[7]*b[1] + a[11]*b[2] = 0 - 4 + 0 = -4 \\\\ \\\\ c[4] &= a[0]*b[3] + a[4]*b[4] + a[8]*b[5] = -2 + 0 + 2 = 0 \\\\ c[5] &= a[1]*b[3] + a[5]*b[4] + a[9]*b[5] = 1 - 4 - 2 = -5 \\\\ c[6] &= a[2]*b[3] + a[6]*b[4] + a[10]*b[5] = -1 + 2 + 4 =5 \\\\ c[7] &= a[3]*b[3] + a[7]*b[4] + a[11]*b[5] = 0 + 4 + 0 =4 \\\\ \end{align} $$可以看到这和期望的结果一致。
所以,为什么要这么计算?
我们先来看看正常的矩阵乘法计算过程:
$$ \begin{pmatrix} a[0] & a[1] & a[2] \\\\ a[3] & a[4] & a[5] \end{pmatrix}\begin{pmatrix} b[0] & b[1] & b[2] & b[3] \\\\ b[4] & b[5] & b[6] & b[7] \\\\ b[8] & b[9] & b[10] & b[11] \end{pmatrix}=\begin{pmatrix} c[0] & c[1] & c[2] & c[3] \\\\ c[4] & c[5] & c[6] & c[7] \end{pmatrix} $$ $$ \begin{align} c[0] &= a[0]*b[0] + a[1]*b[4] + a[2]*b[8] = 2 + 0 + 0 = 2 \\\\ c[1] &= a[0]*b[1] + a[1]*b[5] + a[2]*b[9] = -1 + 4 + 0 = 3 \\\\ c[2] &= a[0]*b[2] + a[1]*b[6] + a[2]*b[10] = 1 - 2 + 0 = -1 \\\\ c[3] &= a[0]*b[3] + a[1]*b[7] + a[2]*b[11] = 0 - 4 + 0 = -4 \\\\ \\\\ c[4] &= a[3]*b[0] + a[4]*b[4] + a[5]*b[8] = -2 + 0 + 2 = 0 \\\\ c[5] &= a[3]*b[1] + a[4]*b[5] + a[5]*b[9] = 1 - 4 - 2 = -5 \\\\ c[6] &= a[3]*b[2] + a[4]*b[6] + a[5]*b[10] = -1 + 2 + 4 = 5 \\\\ c[7] &= a[3]*b[3] + a[4]*b[7] + a[5]*b[11] = 0 + 4 + 0 = 4 \\\\ \end{align} $$注意这里和上面 PyTorch 算法的区别就是 a 和 b 反一下。结合之前 addmm_impl_cpu_ 中矩阵的互换,我们就可以证明 PyTorch 算法的正确性。
相应代码
|
|
矩阵 C 的每一个元素在第三个循环中计算完成,变量 k 在迭代时,对于矩阵 A 来说是顺序访存,但对于矩阵 B 是跳跃访存(矩阵 B 按行存储这里却按列读取,即循环变量乘上了步长)。对于较大的矩阵,这种访问模式会导致处理器不断跳跃到新的缓存行,这样一来,缓存行中的大部分数据都无法被重用,导致频繁的缓存未命中。
而 PyTorch 中的做法:矩阵 C 的每一个元素是分组计算累加得到的(即 24(m*n*k) 次计算由于循环展开分成了 6(k*n) 组(六次内部循环),矩阵 C 的每个元素需要在 3 组计算中累加得到,比如 c[0] 需要参与第 1,3,5 组计算)。矩阵 B 中参与运算的元素使用临时变量保存,对矩阵 C 和矩阵 A 的访存都是连续的,这就避免了大量的缓存未命中,从而提高计算速度。
两者区别如下图所示:
-
一般矩阵乘法:b[0]->b[4]->b[8] 这样的访存是跨越行的,会导致缓存利用率下降。跳跃访存的总次数为 3*8 -1 = 23 次。
-
PyTorch 算法:内层循环都是顺序访存,只有中间循环(b[l + j * ldb],此时循环变量乘上了步长)才会跳跃访存(b[0]->b[3])。跳跃访存的总次数为 k*n = 3*2-1 = 5 次。
虽然内层循环中也有 c[j * ldc + i_i * 4 + 0]、a[i_i * 4 + 0 + l * lda] 这样的变量乘步长的形式,但在内层循环执行过程中,j 和 l 实际上是固定的常量,因此不能称为循环变量。内层循环的循环变量是 i_i 。
最后,作者数学、算法功底不好,对上述 PyTorch 矩阵乘法如有错误或不合理的解释,请指正。
float32 实现
示例代码(PyTorch 默认数据类型是 float32)
|
|
float32 数据类型在 addmm_impl_cpu_ 中调用的 gemm 就是另一种实现。
|
|
这个 gemm 实现会优先使用 mkldnn 的 gemm 实现
|
|
mkldnn_bf32_gemm() 是一个包装函数,实际执行在 mkldnn_gemm() 。
|
|
可以看到这里根据 bf16_usable/fp16_usable/bf32_usable 来判断 mkldnn 是否可用?
为什么我 float32 的数据类型会进行这些判断?这是因为 PyTorch 中有一个控制 float32 矩阵乘法的内部精度的设置(float32MatmulPrecision)。当 float32MatmulPrecision 为 medium 时,就会使用 bf16 进行内部计算,从而提供加速效果。如 use_mkldnn_bf32_matmul() 函数所述。
|
|
如果系统上有支持 mkldnn 的硬件设备,并且 float32MatmulPrecision 为 medium,那么 mkldnn_gemm 就会调用 ideep::matmul_forward::compute 来执行具体的计算。ideep(Intel Deep Learning Boost library)是由英特尔开发的深度学习库,专门优化了在 Intel 硬件(特别是支持 Intel® DL Boost 和 AVX-512 的硬件)上进行高效的深度学习运算。关于 mkldnn 和 ideep 的关系在这不展开叙述。
|
|
由于作者的机器不支持 mkldnn,因此关于 mkldnn 的内容就止步于这个第三方库中的函数定义,不再往下展开。
在我们的例子中,采用的是 sgemm_ 实现
|
|
sgemm* 函数是一个用于在 CPU 上执行单精度浮点数矩阵乘法的封装。sgemm 是一个典型的 BLAS(Basic Linear Algebra Subprograms)函数,专门用于进行单精度浮点矩阵相乘。PyTorch 通过 sgemm* 函数,利用了底层的 BLAS 实现来加速矩阵计算。
可以使用以下命令确定 PyTorch 目前使用的 BLAS 库
$ python -c "import torch;print(torch.__config__.show())"
PyTorch built with:
- GCC 11.4
- C++ Version: 201703
- Intel(R) oneAPI Math Kernel Library Version 2025.0-Product Build 20241009 for Intel(R) 64 architecture applications
- Intel(R) MKL-DNN v3.5.3 (Git Hash 66f0cb9eb66affd2da3bf5f8d897376f04aae6af)
- OpenMP 201511 (a.k.a. OpenMP 4.5)
- LAPACK is enabled (usually provided by MKL)
- NNPACK is enabled
- CPU capability usage: AVX2
- CUDA Runtime 12.4
- NVCC architecture flags: -gencode;arch=compute_86,code=sm_86
- Build settings: BLAS_INFO=mkl, BUILD_TYPE=Debug, CUDA_VERSION=12.4, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=1 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DLIBKINETO_NOXPUPTI=ON -DUSE_FBGEMM -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, TORCH_VERSION=2.6.0, USE_CUDA=ON, USE_CUDNN=OFF, USE_CUSPARSELT=OFF, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,
根据输出信息
- Intel(R) oneAPI Math Kernel Library Version 2025.0-Product Build 20241009 for Intel(R) 64 architecture applications
- Build settings: BLAS_INFO=mkl
可以发现当前 PyTorch 构建使用的是 Intel oneAPI Math Kernel Library (MKL) 。因此下一步会调用 mkl 的 sgemm 实现继续执行。其头文件包含代码如下:
|
|
mkl 是我们在安装 PyTorch 依赖时安装的,包括其静态链接库和头文件,之前的安装命令如下
pip install mkl-static mkl-include
我们可以在系统中找到它们(当前 conda 虚拟环境为 pytorch_dev)
(pytorch_dev) $ find $(conda info --base)/envs/pytorch_dev/include -name "mkl*.h"
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_blas_64.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vsl.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_spblas.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_lapacke.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vml_types.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vml.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_direct_types.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_df_types.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_lapack.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_lapack_omp_variant.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_spblas_omp_offload.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_scalapack.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_cdft.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_df_functions.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_pardiso.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vml_functions.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_cdft_types.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_sparse_qr.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vsl_types.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vsl_functions_64.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_direct_blas.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vsl_defines.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_service.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_pblas.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_omp_variant.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_df_defines.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_dfti.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_cblas.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_direct_blas_kernels.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_trig_transforms.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_dss.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_sparse_handle.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_blas_omp_variant.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_version.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vml_defines.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_trans_names.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_cblas_64.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_poisson.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vml_omp_offload.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_solvers_ee.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vsl_omp_variant.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_df.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_omp_offload.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_trans.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vsl_omp_offload.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_blacs.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_dfti_omp_offload.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_cluster_sparse_solver.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_spblas_omp_variant.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vsl_functions.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_compact.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_blas.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_types.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_lapack_omp_offload.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_rci.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_direct_lapack.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_blas_omp_offload.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_direct_call.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vml_omp_variant.h
(pytorch_dev) $ find $(conda info --base)/envs/pytorch_dev/lib -name "libmkl*"
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_gf_ilp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_tbb_thread.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_intel_lp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_lapack95_lp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_gnu_thread.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_lapack95_ilp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_blas95_ilp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_blas95_lp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_blacs_openmpi_lp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_blacs_intelmpi_ilp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_scalapack_lp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_intel_ilp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_blacs_openmpi_ilp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_cdft_core.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_core.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_scalapack_ilp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_intel_thread.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_blacs_intelmpi_lp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_gf_lp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_sequential.a
由于 mkl 不开源,因此静态库中发生了什么我们就不得而知了,只能获得返回的计算结果。
本文也就告一段落,下一步可能会将 PyTorch 的 BLAS 库换成 OpenBLAS,比较其和 mkl 性能差距。OpenBLAS 作为开源项目,有非常多的内容可供学习。
本站不记录浏览量,但如果您觉得本内容有帮助,请点个小红心,让我知道您的喜欢。