jklincn


PyTorch CPU 矩阵乘法执行路径


  • CPU:AMD EPYC 7763 64-Core Processor(仅支持 AVX2)
  • Pytorch 版本:torch 2.6.0a0+git60d1c71

关于 PyTorch C++ 源码调试,可以参考之前的一篇文章 VSCode 配置 PyTorch C++ 源码开发环境(编译与调试)

Python 层

我们探究的是张量的 mm 方法,一个简单的示例如下

1
2
3
4
5
import torch

A = torch.tensor([[1, -2, 0], [-1, 2, -2]])
B = torch.tensor([[2, -1, 1, 0], [0, -2, 1, 2], [-1, 1, -2, 0]])
C = A.mm(B)

其 python 层的定义为

1
2
3
4
5
6
7
8
// .../site-packages/torch/_C/__init__.pyi
    def mm(self, mat2: Tensor) -> Tensor:
        r"""
        mm(mat2) -> Tensor

        See :func:`torch.mm`
        """
        ...

好吧,python 这没什么好看的

C++ 前端绑定

PyTorch 使用 pybind11 将 C++ 和 CUDA 代码集成到 Python 项目中,其 Tensor 对象的方法实现路径为:

torch/csrc/autograd/generated/python_variable_methods.cpp

其中有一个 PyMethodDef 变量,这用于在 Python 中绑定 C/C++ 函数,让它们能作为 Python 方法调用。

1
2
3
4
5
PyMethodDef variable_methods[] = {
  ...
  {"mm", castPyCFunctionWithKeywords(THPVariable_mm), METH_VARARGS | METH_KEYWORDS, NULL},
  ...
};

variable_methods 数组包含了一个绑定到 THPVariable_mm 函数的 Python 方法,方法名为 “mm” 。因此 THPVariable_mm 是实际实现的 C++ 函数。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
static PyObject * THPVariable_mm(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  // 用于捕获 C++ 层抛出的异常,并将其转换为 Python 异常
  HANDLE_TH_ERRORS
  // 将 Python 对象 self_ 解包为 C++ 中的 Tensor 引用。
  const Tensor& self = THPVariable_Unpack(self_);
  // 解析 Python 调用的参数。这里定义了一个字符串 "mm(Tensor mat2)",表示 mm 函数接收一个 Tensor 参数 mat2。
  static PythonArgParser parser({
    "mm(Tensor mat2)",
  }, /*traceable=*/true);

  // 解析传递的 args 和 kwargs,将其解析并存储在 ParsedArgs 对象 parsed_args 中。
  // _r 包含了解析后的参数。
  ParsedArgs<1> parsed_args;
  auto _r = parser.parse(self_, args, kwargs, parsed_args);
  // 检查 self 或 mat2 是否实现了 __torch_function__,即是否有自定义的行为。如果存在自定义实现,就会调用 handle_torch_function,执行 __torch_function__ 钩子中的自定义逻辑,并返回相应结果,而不执行后续代码。
  if(_r.has_torch_function()) {
    return handle_torch_function(_r, self_, args, kwargs, THPVariableClass, "torch.Tensor");
  }
  // aten::mm(Tensor self, Tensor mat2) -> Tensor
  // 定义了一个 dispatch_mm Lambda 函数,将 self 和 mat2 作为输入,执行 mm 操作。
  auto dispatch_mm = [](const at::Tensor & self, const at::Tensor & mat2) -> at::Tensor {
    // 在执行 self.mm(mat2) 前释放 Python 的全局解释器锁(GIL),允许其他线程并行执行。
    pybind11::gil_scoped_release no_gil;
    // 执行矩阵乘法
    return self.mm(mat2);
  };
  // 将 dispatch_mm(self, _r.tensor(0)) 的结果包装成 Python 对象,返回给 Python 层。_r.tensor(0) 代表 mat2 参数。
  return wrap(dispatch_mm(self, _r.tensor(0)));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

C++层(算子分发)

THPVariable_mm 是处理与 Python 交互的函数,实际执行的是 self.mm(),即 Tensor::mm() 。

1
2
3
4
// build/aten/src/ATen/core/TensorBody.h
inline at::Tensor Tensor::mm(const at::Tensor & mat2) const {
    return at::_ops::mm::call(const_cast<Tensor&>(*this), mat2);
}

at::_ops::mm:这是代码生成器自动生成的一个命名空间,用于封装 mm 操作的所有调用入口。

call 方法实际上会调用 Dispatcher 系统的入口,将 mm 操作的两个矩阵 self 和 mat2 传递给 Dispatcher。紧接着 Dispatcher 系统就开始分发,这里涉及一系列 call 函数,包括从 dispatchTable_ 中获得具体的 kernel_function(lookup 函数) 。在此只列调用堆栈,不详细展开。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
    frame #1: 0x00007f086add4e4a libtorch_cpu.so`torch::autograd::VariableType::(anonymous namespace)::mm(ks=(repr_ = 137438986241), self=0x00007f089a711dc8, mat2=0x00007ffdfff7da20) at VariableType_3.cpp:13461:6
    frame #2: 0x00007f086ae8e59b libtorch_cpu.so`c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(c10::DispatchKeySet, const at::Tensor&, const at::Tensor&), torch::autograd::VariableType::(anonymous namespace)::mm>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, const at::Tensor&, const at::Tensor&> >, at::Tensor(c10::DispatchKeySet, const at::Tensor&, const at::Tensor&)>::call(c10::OperatorKernel *, c10::DispatchKeySet, const at::Tensor &, const at::Tensor &) [inlined] operator(args#2=0x00007ffdfff7da20, args#1=0x00007f089a711dc8, args#0=(repr_ = 137438986241), this=0x0000000003c6a490) at WrapFunctionIntoFunctor.h:13:72
    frame #3: 0x00007f086ae8e55b libtorch_cpu.so`c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(c10::DispatchKeySet, const at::Tensor&, const at::Tensor&), torch::autograd::VariableType::(anonymous namespace)::mm>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, const at::Tensor&, const at::Tensor&> >, at::Tensor(c10::DispatchKeySet, const at::Tensor&, const at::Tensor&)>::call(functor=0x0000000003c6a490, dispatchKeySet=(repr_ = 137438986241), args#0=0x00007f089a711dc8, args#1=0x00007ffdfff7da20) at make_boxed_from_unboxed_functor.h:485:79
    frame #4: 0x00007f0866e79832 libtorch_cpu.so`at::Tensor c10::callUnboxedKernelFunction<at::Tensor, at::Tensor const&, at::Tensor const&>(unboxed_kernel_func=0x00007f086ae8e4e6, functor=0x0000000003c6a490, dispatchKeySet=(repr_ = 137438986241), (null)=0x00007f089a711dc8, (null)=0x00007ffdfff7da20) at KernelFunction_impl.h:64:72
    frame #5: 0x00007f0867727386 libtorch_cpu.so`at::_ops::mm::call(at::Tensor const&, at::Tensor const&) [inlined] at::Tensor c10::KernelFunction::call<at::Tensor, at::Tensor const&, at::Tensor const&>((null)=0x00007ffdfff7da20, (null)=0x00007f089a711dc8, dispatchKeySet=(repr_ = 137438986241), opHandle=0x00007f0883f8a490, this=0x0000000001901598) const at KernelFunction_impl.h:116:87
    frame #6: 0x00007f0867727309 libtorch_cpu.so`at::_ops::mm::call(at::Tensor const&, at::Tensor const&) [inlined] at::Tensor c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&>((null)=0x00007ffdfff7da20, (null)=0x00007f089a711dc8, op=0x00007f0883f8a490, this=0x00007f0883f56600) const at Dispatcher.h:698:97
    frame #7: 0x00007f0867727113 libtorch_cpu.so`at::_ops::mm::call(at::Tensor const&, at::Tensor const&) [inlined] c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&)>::call(args#1=0x00007ffdfff7da20, args#0=0x00007f089a711dc8, this=<unavailable>) const at Dispatcher.h:531:97
    frame #8: 0x00007f0867726fec libtorch_cpu.so`at::_ops::mm::call(self=0x00007f089a711dc8, mat2=0x00007ffdfff7da20) at Operators_3.cpp:3459:30
    frame #9: 0x00007f088a121d2e libtorch_python.so`at::Tensor::mm(this=0x00007f089a711dc8, mat2=0x00007ffdfff7da20) const at TensorBody.h:2999:63
    frame #10: 0x00007f088a0c9362 libtorch_python.so`operator(__closure=0x00007ffdfff7da1f, self=0x00007f089a711dc8, mat2=0x00007ffdfff7da20) at python_variable_methods.cpp:11289:24
    frame #11: 0x00007f088a0c961a libtorch_python.so`torch::autograd::THPVariable_mm(self_=0x00007f089a711db0, args=0x00007f075d317eb0, kwargs=0x0000000000000000) at python_variable_methods.cpp:11291:26
    frame #12: 0x0000000000552a7a python3.12`method_vectorcall_VARARGS_KEYWORDS(func=<unavailable>, args=0x00007f089c5ce638, nargsf=<unavailable>, kwnames=<unavailable>) at descrobject.c:365:14
    frame #13: 0x000000000093ff80 python3.12`PyTupleIter_Type + 416

经过一系列 call 的跳转,来到了 at::Tensor mm(c10::DispatchKeySet ks, const at::Tensor & self, const at::Tensor & mat2),这个函数主要提供了自动求导支持。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
at::Tensor mm(c10::DispatchKeySet ks, const at::Tensor & self, const at::Tensor & mat2) {
  // 将输入 self 和 mat2 转换为基础实现的引用(self_ 和 mat2_),便于后续的低级调用。
  auto& self_ = unpack(self, "self", 0);
  auto& mat2_ = unpack(mat2, "mat2", 1);
  [[maybe_unused]] auto _any_requires_grad = compute_requires_grad( self, mat2 );

  [[maybe_unused]] auto _any_has_forward_grad_result = (isFwGradDefined(self) || isFwGradDefined(mat2));

  // 自动求导逻辑
  std::shared_ptr<MmBackward0> grad_fn;
  // 如果需要梯度,函数会为反向传播创建一个 MmBackward0 节点,这个对象将保存输入张量和其元数据,以便在计算梯度时使用。
  if (_any_requires_grad) {
    grad_fn = std::shared_ptr<MmBackward0>(new MmBackward0(), deleteNode);
    grad_fn->set_next_edges(collect_next_edges( self, mat2 ));
    if (grad_fn->should_compute_output(0)) {
      grad_fn->mat2_ = SavedVariable(mat2, false);
    }
    grad_fn->mat2_layout = mat2.layout();
    grad_fn->mat2_sym_sizes = mat2.sym_sizes().vec();
    grad_fn->mat2_sym_strides = strides_or_error(mat2, "mat2").vec();
    if (grad_fn->should_compute_output(1)) {
      grad_fn->self_ = SavedVariable(self, false);
    }
    grad_fn->self_layout = self.layout();
    grad_fn->self_sym_sizes = self.sym_sizes().vec();
    grad_fn->self_sym_strides = strides_or_error(self, "self").vec();
  }

  // 调试信息保存
  #ifndef NDEBUG
  auto self__storage_saved =
    self_.has_storage() ? ::std::optional<Storage>(self_.storage()) : ::std::nullopt;
  c10::intrusive_ptr<TensorImpl> self__impl_saved;
  if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();
  auto mat2__storage_saved =
    mat2_.has_storage() ? ::std::optional<Storage>(mat2_.storage()) : ::std::nullopt;
  c10::intrusive_ptr<TensorImpl> mat2__impl_saved;
  if (mat2_.defined()) mat2__impl_saved = mat2_.getIntrusivePtr();
  #endif

  // 执行矩阵乘法操作(通过 redispatch 调用具体实现)
  auto _tmp = ([&]() {
    // 设置调度上下文,确保不会在当前操作中触发自动求导的 in-place 操作或视图。
    at::AutoDispatchBelowADInplaceOrView guard;
    // 根据 DispatchKeySet ks 的设备信息,redispatch 会选择适当的后端实现,例如 mm_out_cpu 或 mm_out_cuda。在这一步,redispatch 负责根据 ks 中的键集调度到实际的底层计算函数。
    return at::redispatch::mm(ks & c10::after_autograd_keyset, self_, mat2_);
  })();
  auto result = std::move(_tmp);

  // 断言检查(调试模式)
  #ifndef NDEBUG
  if (self__storage_saved.has_value() &&
      !at::impl::dispatch_mode_enabled() &&
      !at::impl::tensor_has_dispatch(self_) &&
      !at::impl::tensor_has_dispatch(self_))
    TORCH_INTERNAL_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));
  if (self__impl_saved && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(self_))
    TORCH_INTERNAL_ASSERT(self__impl_saved == self_.getIntrusivePtr());
  if (mat2__storage_saved.has_value() &&
      !at::impl::dispatch_mode_enabled() &&
      !at::impl::tensor_has_dispatch(mat2_) &&
      !at::impl::tensor_has_dispatch(mat2_))
    TORCH_INTERNAL_ASSERT(mat2__storage_saved.value().is_alias_of(mat2_.storage()));
  if (mat2__impl_saved && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(mat2_))
    TORCH_INTERNAL_ASSERT(mat2__impl_saved == mat2_.getIntrusivePtr());
  if (result.has_storage() && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(result)) {
    TORCH_INTERNAL_ASSERT(result.storage().use_count() == 1, "function: mm");
  }
  if (!at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(result))
    TORCH_INTERNAL_ASSERT(result.use_count() <= 1, "function: mm");
  #endif

  // 如果需要求导,则调用 set_history 将 grad_fn 关联到输出张量 result。这一步标记了计算图的依赖关系,使得反向传播时可以追溯到对应的前向操作。
  if (grad_fn) {
      set_history(flatten_tensor_args( result ), grad_fn);
  }

  // 如果开启了前向模式自动微分(Forward AD),该部分代码会计算前向梯度信息,并将其存储在 result 中。result_new_fw_grad_opt 包含了 result 的前向梯度信息。
  ::std::optional<at::Tensor> result_new_fw_grad_opt = ::std::nullopt;
  if (_any_has_forward_grad_result && (result.defined())) {
      auto self_t_raw = toNonOptFwGrad(self);
      auto self_tensor = toNonOptTensor(self);
      auto self_t = (self_t_raw.defined() || !self_tensor.defined())
        ? self_t_raw : at::_efficientzerotensor_symint(self_tensor.sym_sizes(), self_tensor.options());
      auto self_p = toNonOptPrimal(self);
      auto mat2_t_raw = toNonOptFwGrad(mat2);
      auto mat2_tensor = toNonOptTensor(mat2);
      auto mat2_t = (mat2_t_raw.defined() || !mat2_tensor.defined())
        ? mat2_t_raw : at::_efficientzerotensor_symint(mat2_tensor.sym_sizes(), mat2_tensor.options());
      auto mat2_p = toNonOptPrimal(mat2);
      result_new_fw_grad_opt = at::mm(self_t, mat2_p) + at::mm(self_p, mat2_t);
  }
  if (result_new_fw_grad_opt.has_value() && result_new_fw_grad_opt.value().defined() && result.defined()) {
    // The hardcoded 0 here will need to be updated once we support multiple levels.
    result._set_fw_grad(result_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ false);
  }

  // 返回计算结果
  return result;
}

at::redispatch::mm() 又是一系列的 redispatch 函数,这里也只提供调用堆栈

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
    * frame #0: 0x00007f54fa1e3a73 libtorch_cpu.so`at::native::addmm_impl_cpu_(result=0x00007ffd9a1a3468, self=0x00007ffd9a1a3468, m1=Tensor @ 0x00007ffd9a1a33c0, m2=Tensor @ 0x00007ffd9a1a33c8, beta=0x00007ffd9a1a33d0, alpha=0x00007ffd9a1a33f0) at LinearAlgebra.cpp:1507:8
    frame #1: 0x00007f54fa1e4cfd libtorch_cpu.so`at::native::structured_mm_out_cpu::impl(this=0x00007ffd9a1a3460, self=0x00007f552cfc7088, mat2=0x00007ffd9a1a3cf0, result=0x00007ffd9a1a3468) at LinearAlgebra.cpp:1635:20
    frame #2: 0x00007f54fbc5e384 libtorch_cpu.so`at::(anonymous namespace)::wrapper_CPU_mm(self=0x00007f552cfc7088, mat2=0x00007ffd9a1a3cf0) at RegisterCPU.cpp:8595:8
    frame #3: 0x00007f54fbe1eea0 libtorch_cpu.so`c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(const at::Tensor&, const at::Tensor&), at::(anonymous namespace)::wrapper_CPU_mm>, at::Tensor, c10::guts::typelist::typelist<const at::Tensor&, const at::Tensor&> >, at::Tensor(const at::Tensor&, const at::Tensor&)>::call(c10::OperatorKernel *, c10::DispatchKeySet, const at::Tensor &, const at::Tensor &) [inlined] operator(args#1=0x00007ffd9a1a3cf0, args#0=0x00007f552cfc7088, this=0x000000000356eb80) at WrapFunctionIntoFunctor.h:13:72
    frame #4: 0x00007f54fbe1ee72 libtorch_cpu.so`c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(const at::Tensor&, const at::Tensor&), at::(anonymous namespace)::wrapper_CPU_mm>, at::Tensor, c10::guts::typelist::typelist<const at::Tensor&, const at::Tensor&> >, at::Tensor(const at::Tensor&, const at::Tensor&)>::call(functor=0x000000000356eb80, (null)=(repr_ = 32769), args#0=0x00007f552cfc7088, args#1=0x00007ffd9a1a3cf0) at make_boxed_from_unboxed_functor.h:468:63
    frame #5: 0x00007f54fae79832 libtorch_cpu.so`at::Tensor c10::callUnboxedKernelFunction<at::Tensor, at::Tensor const&, at::Tensor const&>(unboxed_kernel_func=0x00007f54fbe1ee07, functor=0x000000000356eb80, dispatchKeySet=(repr_ = 32769), (null)=0x00007f552cfc7088, (null)=0x00007ffd9a1a3cf0) at KernelFunction_impl.h:64:72
    frame #6: 0x00007f54fad0e1f4 libtorch_cpu.so`at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, at::Tensor const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&)> const&, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) const [inlined] at::Tensor c10::KernelFunction::call<at::Tensor, at::Tensor const&, at::Tensor const&>((null)=0x00007ffd9a1a3cf0, (null)=0x00007f552cfc7088, dispatchKeySet=(repr_ = 32769), opHandle=0x00007f5517f8a4b0, this=0x0000000002886a18) const at KernelFunction_impl.h:116:87
    frame #7: 0x00007f54fad0e183 libtorch_cpu.so`at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, at::Tensor const&>(this=0x00007f5517f56600, op=0x00007f5517f8a4b0, currentDispatchKeySet=(repr_ = 32769), (null)=0x00007f552cfc7088, (null)=0x00007ffd9a1a3cf0) const at Dispatcher.h:714:102
    frame #8: 0x00007f54fb727614 libtorch_cpu.so`at::_ops::mm::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) [inlined] c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&)>::redispatch(args#1=0x00007ffd9a1a3cf0, args#0=0x00007f552cfc7088, currentDispatchKeySet=(repr_ = 32769), this=<unavailable>) const at Dispatcher.h:536:126
    frame #9: 0x00007f54fb72757b libtorch_cpu.so`at::_ops::mm::redispatch(dispatchKeySet=(repr_ = 32769), self=0x00007f552cfc7088, mat2=0x00007ffd9a1a3cf0) at Operators_3.cpp:3466:52
    frame #10: 0x00007f54feec5f32 libtorch_cpu.so`at::redispatch::mm(dispatchKeySet=(repr_ = 32769), self=0x00007f552cfc7088, mat2=0x00007ffd9a1a3cf0) at RedispatchFunctions.h:5222:67
    frame #11: 0x00007f54fedd46de libtorch_cpu.so`operator(__closure=0x00007ffd9a1a3930) at VariableType_3.cpp:13460:76
    frame #12: 0x00007f54fedd4e4a libtorch_cpu.so`torch::autograd::VariableType::(anonymous namespace)::mm(ks=(repr_ = 137438986241), self=0x00007f552cfc7088, mat2=0x00007ffd9a1a3cf0) at VariableType_3.cpp:13461:6
    frame #13: 0x00007f54fee8e59b libtorch_cpu.so`c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(c10::DispatchKeySet, const at::Tensor&, const at::Tensor&), torch::autograd::VariableType::(anonymous namespace)::mm>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, const at::Tensor&, const at::Tensor&> >, at::Tensor(c10::DispatchKeySet, const at::Tensor&, const at::Tensor&)>::call(c10::OperatorKernel *, c10::DispatchKeySet, const at::Tensor &, const at::Tensor &) [inlined] operator(args#2=0x00007ffd9a1a3cf0, args#1=0x00007f552cfc7088, args#0=(repr_ = 137438986241), this=0x0000000004bf0490) at WrapFunctionIntoFunctor.h:13:72
    frame #14: 0x00007f54fee8e55b libtorch_cpu.so`c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(c10::DispatchKeySet, const at::Tensor&, const at::Tensor&), torch::autograd::VariableType::(anonymous namespace)::mm>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, const at::Tensor&, const at::Tensor&> >, at::Tensor(c10::DispatchKeySet, const at::Tensor&, const at::Tensor&)>::call(functor=0x0000000004bf0490, dispatchKeySet=(repr_ = 137438986241), args#0=0x00007f552cfc7088, args#1=0x00007ffd9a1a3cf0) at make_boxed_from_unboxed_functor.h:485:79
    frame #15: 0x00007f54fae79832 libtorch_cpu.so`at::Tensor c10::callUnboxedKernelFunction<at::Tensor, at::Tensor const&, at::Tensor const&>(unboxed_kernel_func=0x00007f54fee8e4e6, functor=0x0000000004bf0490, dispatchKeySet=(repr_ = 137438986241), (null)=0x00007f552cfc7088, (null)=0x00007ffd9a1a3cf0) at KernelFunction_impl.h:64:72
    frame #16: 0x00007f54fb727386 libtorch_cpu.so`at::_ops::mm::call(at::Tensor const&, at::Tensor const&) [inlined] at::Tensor c10::KernelFunction::call<at::Tensor, at::Tensor const&, at::Tensor const&>((null)=0x00007ffd9a1a3cf0, (null)=0x00007f552cfc7088, dispatchKeySet=(repr_ = 137438986241), opHandle=0x00007f5517f8a490, this=0x0000000002887598) const at KernelFunction_impl.h:116:87
    frame #17: 0x00007f54fb727309 libtorch_cpu.so`at::_ops::mm::call(at::Tensor const&, at::Tensor const&) [inlined] at::Tensor c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&>((null)=0x00007ffd9a1a3cf0, (null)=0x00007f552cfc7088, op=0x00007f5517f8a490, this=0x00007f5517f56600) const at Dispatcher.h:698:97
    frame #18: 0x00007f54fb727113 libtorch_cpu.so`at::_ops::mm::call(at::Tensor const&, at::Tensor const&) [inlined] c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&)>::call(args#1=0x00007ffd9a1a3cf0, args#0=0x00007f552cfc7088, this=<unavailable>) const at Dispatcher.h:531:97
    frame #19: 0x00007f54fb726fec libtorch_cpu.so`at::_ops::mm::call(self=0x00007f552cfc7088, mat2=0x00007ffd9a1a3cf0) at Operators_3.cpp:3459:30
    frame #20: 0x00007f552491ed2e libtorch_python.so`at::Tensor::mm(this=0x00007f552cfc7088, mat2=0x00007ffd9a1a3cf0) const at TensorBody.h:2999:63
    frame #21: 0x00007f55248c6362 libtorch_python.so`operator(__closure=0x00007ffd9a1a3cef, self=0x00007f552cfc7088, mat2=0x00007ffd9a1a3cf0) at python_variable_methods.cpp:11289:24
    frame #22: 0x00007f55248c661a libtorch_python.so`torch::autograd::THPVariable_mm(self_=0x00007f552cfc7070, args=0x00007f53efebc850, kwargs=0x0000000000000000) at python_variable_methods.cpp:11291:26
    frame #23: 0x0000000000552a7a python3.12`method_vectorcall_VARARGS_KEYWORDS(func=<unavailable>, args=0x00007f552ee82638, nargsf=<unavailable>, kwnames=<unavailable>) at descrobject.c:365:14

可以看到 frame #2 中

1
2
3
4
5
6
7
// build/aten/src/ATen/RegisterCPU.cpp
at::Tensor wrapper_CPU_mm(const at::Tensor & self, const at::Tensor & mat2) {
    structured_mm_out_cpu_functional op;
    op.meta(self, mat2);
    op.impl(self, mat2, op.outputs_[0]);
    return std::move(op.outputs_[0]);
}

其中 op.impl() 跳转到 mm_out_cpu()

1
2
3
4
5
6
7
// aten/src/ATen/native/LinearAlgebra.cpp
TORCH_IMPL_FUNC(mm_out_cpu)(const Tensor & self, const Tensor & mat2, const Tensor & result) {
  {
    at::NoNamesGuard guard;
    addmm_impl_cpu_(const_cast<Tensor&>(result), result, self, mat2, 0, 1);
  }
}

最终调用 addmm_impl_cpu_(const_cast<Tensor&>(result), result, self, mat2, 0, 1) ,这里有两个需要注意的传参细节:

  1. 传入了两个 result
  2. 传入了两个固定的标量 0 和 1

重新布局矩阵内存

addmm_impl_cpu_ 是一个在 CPU 上实现的矩阵乘加运算,运算表达式:result = beta _ self + alpha _ m1 * m2 ,这里 beta 和 alpha 都是标量,m1 和 m2 是矩阵。addmm 常用于神经网络中的线性层运算,形式类似于 y = Wx + b 。

我们下面以 CPU 上的 int8 运算为例,示例代码:

1
2
3
4
5
import torch

A = torch.tensor([[1, -2, 0], [-1, 2, -2]], dtype=torch.int8)
B = torch.tensor([[2, -1, 1, 0], [0, -2, 1, 2], [-1, 1, -2, 0]], dtype=torch.int8)
C = A.mm(B)

一个 2×3 的矩阵 A 和一个 3×4 的矩阵 B 进行相乘,结果是一个 2×4 的矩阵 C 。

在上层 mm_out_cpu 中,self 是 A,mat2 是 B,result 是 C。因此根据其调用时的传参,进入到 addmm_impl_cpu_ 中的参数信息如下:

  • result 和 self 是 C;
  • m1 是 A,m2 是 B;
  • beta 是 0,alpha 是 1,即简化为 C = A * B 。

下面我们结合例子详细解释 addmm_impl_cpu_ 函数,部分注释中给出了变量此时的值。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
// aten/src/ATen/native/LinearAlgebra.cpp
static void addmm_impl_cpu_(
    Tensor &result, const Tensor &self, Tensor m1, Tensor m2, const Scalar& beta, const Scalar& alpha) {
  // 张量维度检查,确保都是二维张量,即矩阵
  TORCH_INTERNAL_ASSERT(self.dim() == 2 && m1.dim() == 2 && m2.dim() == 2);
  // 数据类型检查,确保相乘的两个矩阵数据类型一致
  TORCH_CHECK(
    m1.dtype() == m2.dtype(),
    "expected m1 and m2 to have the same dtype, but got: ", m1.dtype(), " != ", m2.dtype()
  )

  // 使用临时变量:当需要张量信息的时候,数组访问会比调用其方法更快
  const auto self_sizes = self.sizes(); // [2, 4]:矩阵 C 的形状是 2*4
  auto m1_strides = m1.strides(); // [3, 1]:矩阵 A 第 0 个维度步长是 3(相邻行之间的距离,即列数),第 1 个维度步长是 1(行内相邻元素的距离)
  auto m1_sizes = m1.sizes(); // [2, 3]:矩阵 A 的形状是 2*3
  auto m2_strides = m2.strides(); // [4, 1]:矩阵 B 第 0 个维度步长是 4,第 1 个维度步长是 1
  auto m2_sizes = m2.sizes(); // [3, 4]:矩阵 B 的形状是 3*4

  // 检查做完矩阵乘法后的矩阵形状是否和 self 一致,否则无法相加
  TORCH_CHECK(
      self_sizes[0] == m1_sizes[0] && self_sizes[1] == m2_sizes[1],
      "input shape is incompatible with matrix multiplication (",
      m1_sizes[0], "x", m1_sizes[1], " @ ", m2_sizes[0], "x", m2_sizes[1], " != ",
      self_sizes[0], "x", self_sizes[1], ")");

  // 调整 result 形状,并获取形状和步长信息
  at::native::resize_output(result, self_sizes);
  const auto result_strides = result.strides(); // [4, 1]:最终输出矩阵的第 0 个维度步长是 4,第 1 个维度步长是 1
  const auto result_sizes = result.sizes(); // [2, 4]:最终输出矩阵的形状是 2*4

  // 输出的矩阵如果为空,则直接返回
  if (result.numel() == 0) {
    return;
  }

  // 如果 m1 的列数为 0(空矩阵情况),则无法进行矩阵乘法,只做 result = beta * self
  // 根据 beta 的值来更新 result 的值:
  //   1. 如果 beta 为 0,则将 result 置零;
  //   2. 否则将 self 的值按 beta 缩放后赋给 result(代码是先赋值再缩放)
  if (m1_sizes[1] == 0) {
    if (beta.toComplexDouble() == 0.0) {
      result.zero_();
    } else {
      if (!self.is_same(result)) {
        result.copy_(self);
      }
      result.mul_(beta);
    }
    return;
  }

  // 如果 beta 不为 0,self 和 result 不是同一个矩阵,则这里将 self 赋给 result
  if (beta.toComplexDouble() != 0.0 && !self.is_same(result)) {
    result.copy_(self);
  }

  // 检查 result 的内存布局,并决定是否进行转置或重新排列使其变为按列存储(也叫作 FORTRAN 连续),然后记为张量 c
  bool transpose_c = false;
  Tensor c;

  // 第一种情况:如果矩阵是按列存储,并且保证内存不会重叠(与运算的第二个条件,使用 max 避免了无效比较,比如行数为 0),则无需处理
  if (result_strides[0] == 1 &&
      (result_sizes[1] == 1 || result_strides[1] >= std::max(int64_t{1}, result_sizes[0]))) {
    transpose_c = false;
    c = result.resolve_conj();
  } else if (result_strides[1] == 1 &&
             (result_sizes[0] == 1 || result_strides[0] >= std::max(int64_t{1}, result_sizes[1]))) {
    // 第二种情况:如果矩阵是按行存储的,并且也能保证内存不会重叠,则进行转置
    // 交换矩阵
    std::swap(m1, m2);
    std::swap(m1_sizes, m2_sizes);
    std::swap(m1_strides, m2_strides);
    // 设置转置符号
    transpose_c = true;
    c = result.resolve_conj();
  } else {
    // 第三种情况:既不按列连续,也不按行连续,需要调整内存布局
    transpose_c = false;
    // 将张量 c 的内存布局调整为 Fortran 连续,即列优先
    c = result.resolve_conj().transpose(0, 1).contiguous().transpose_(0, 1);
  }

  // 在本例中,符合上述第二种情况,因此 m1 和 m2 进行了交换,当前值如下:
  // m1 = [[2, -1, 1, 0], [0, -2, 1, 2], [-1, 1, -2, 0]]
  // m1_strides = [4, 1]
  // m1_sizes = [3, 4]
  // m2 = [[1, -2, 0], [-1, 2, -2]]
  // m2_strides = [3, 1]
  // m2_sizes = [2, 3]

  const int64_t m = result_sizes[transpose_c ? 1 : 0]; // m = 4
  const int64_t n = result_sizes[transpose_c ? 0 : 1]; // n = 2
  const int64_t k = m1_sizes[transpose_c ? 0 : 1];     // k = 3

  // 同上,检查 m1 的内存布局,然后记为张量 a
  // 在本例中,符合第一种情况,即不转置
  bool transpose_a = false;
  Tensor a;

  if (m1_strides[transpose_c ? 1 : 0] == 1 &&
      m1_strides[transpose_c ? 0 : 1] >= std::max(int64_t{1}, m)) {
    transpose_a = false;
    a = m1.resolve_conj();
  } else if (m1_strides[transpose_c ? 0 : 1] == 1 &&
             m1_strides[transpose_c ? 1 : 0] >= std::max(int64_t{1}, k)) {
    transpose_a = true;
    a = m1;
  } else {
    transpose_a = !transpose_c;
    a = m1.clone(at::MemoryFormat::Contiguous);
  }

  // 同上,检查矩阵 m2 的内存布局,然后记为张量 b
  // 在本例中,符合第一种情况,即不转置
  bool transpose_b = false;
  Tensor b;

  if (m2_strides[transpose_c ? 1 : 0] == 1 &&
      m2_strides[transpose_c ? 0 : 1] >= std::max(int64_t{1}, k)) {
    transpose_b = false;
    b = m2.resolve_conj();
  } else if (m2_strides[transpose_c ? 0 : 1] == 1 &&
             m2_strides[transpose_c ? 1 : 0] >= std::max(int64_t{1}, n)) {
    transpose_b = true;
    b = m2;
  } else {
    transpose_b = !transpose_c;
    b = m2.clone(at::MemoryFormat::Contiguous);
  }

  const int64_t lda = a.strides()[(transpose_a == transpose_c) ? 1 : 0]; // lda = 4
  const int64_t ldb = b.strides()[(transpose_b == transpose_c) ? 1 : 0]; // ldb = 3
  const int64_t ldc = c.strides()[transpose_c ? 0 : 1];                  // ldc = 4

  // 确保张量 c 的共轭状态已经被解决,因为在调用 GEMM 函数时,无法指定 c 的共轭状态。
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c.is_conj());

  bool dispatched = false;
// 如果是 aarch64 平台并且启用 MKL-DNN ,则会尝试调用 MKL-DNN 的矩阵乘法来加速运算。本例中是 x86 平台,跳过。
#if defined(__aarch64__) && AT_MKLDNN_ACL_ENABLED()
  if (transpose_c) {
    bool apply_heur = apply_mkldnn_matmul_heur(b.sizes()[0], b.sizes()[1], a.sizes()[1]);
    if (apply_heur && transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
      try {
        mkldnn_matmul(b, a, c, beta.to<float>(), alpha.to<float>());
        dispatched = true;
      } catch (const std::exception& e) {
        TORCH_WARN("mkldnn_matmul failed, switching to BLAS gemm:", e.what());
        at::globalContext().setUserEnabledMkldnn(false);
      }
    }
  }
#endif

  // 使用 BLAS gemm 函数执行实际运算,这里的判断语句是为了兼容上面的 aarch64
  if(!dispatched) {
    _AT_DISPATCH_ADDMM_TYPES(result.scalar_type(), "addmm_impl_cpu_", [&]{
          using opmath_t = at::opmath_type<scalar_t>;
          // at::native::cpublas::gemm 是 PyTorch 对 BLAS 中 gemm 函数的封装
          at::native::cpublas::gemm(
              transpose_a ? a.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,
              transpose_b ? b.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,
              m, n, k,
              alpha.to<opmath_t>(),
              a.const_data_ptr<scalar_t>(), lda,
              b.const_data_ptr<scalar_t>(), ldb,
              beta.to<opmath_t>(),
              c.mutable_data_ptr<scalar_t>(), ldc);
        });
  }

  // 如果 c 和 result 不是同一个矩阵,则将计算结果从 c 拷贝回 result。
  if (!c.is_same(result)) {
    result.copy_(c);
  }
}

addmm_impl_cpu_ 探索了三个矩阵的内存布局,把它们从按行存储转换成了按列存储(Fortran contiguous),以此来优化后续计算的访存。注意,这里的 transpose_c 是用来判断 transpose_a 和 transpose_b 的值,实际没有传到下一层函数中。可以认为 transpose_a 和 transpose_b 参数的值隐式包含了 transpose_c 这一信息。

函数最后使用 gemm 来实现 addmm,gemm (general matrix multiplication) 是通用矩阵乘法,用于计算矩阵 C = alpha _ A _ B + beta * C 。关于 GEMM 的更多信息可以看 Basic Linear Algebra Subprograms - Wikipedia

gemm 实际运算

int8 实现

at::native::cpublas::gemm() 是一个模板函数,在 int8 数据类型下,由于找不到对应的优化实现,所以只能回退到默认实现(数据类型是 scalar_t),也是纯软件实现。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
// aten/src/ATen/native/CPUBlas.h
template <typename scalar_t>
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    at::opmath_type<scalar_t> alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    at::opmath_type<scalar_t> beta,
    scalar_t *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
  gemm_stub(
    kCPU, c10::CppTypeToScalarType<scalar_t>::value,
    transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

normalize_last_dims 函数的作用是标准化与矩阵乘法相关的步长,以确保矩阵在内存中的布局正确,特别是在处理转置情况时。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
void normalize_last_dims(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    int64_t *lda, int64_t *ldb, int64_t *ldc) {
  if (n == 1) {
    *ldc = m;
  }

  if(transa != TransposeType::NoTranspose) {
    if (m == 1) {
      *lda = k;
    }
  } else if(k == 1) {
    *lda = m;
  }

  if(transb != TransposeType::NoTranspose) {
    if (k == 1) {
      *ldb = n;
    }
  } else if (n == 1) {
    *ldb = k;
  }
}

在本例中参数如下,所以不进行任何操作。

(at::native::TransposeType) transa = NoTranspose
(at::native::TransposeType) transb = NoTranspose
(int64_t) m = 4
(int64_t) n = 2
(int64_t) k = 3
(int64_t *) lda = 0x00007ffe2bfe9838
(int64_t *) ldb = 0x00007ffe2bfe9848
(int64_t *) ldc = 0x00007ffe2bfe9860

gemm_stub 是一种回退机制,用于在没有 BLAS 或 MKL-DNN 等算子库加速的情况下执行矩阵乘法。我们的 int8 在 CPU 上没有特定的优化路径,因此只能选择 gemm_stub() 。

1
2
3
// aten/src/ATen/native/cpu/BlasKernel.cpp
// 将 cpublas_gemm_impl 注册到 gemm_stub
REGISTER_DISPATCH(cpublas::gemm_stub, &cpublas::cpublas_gemm_impl);

cpublas_gemm_impl 是一个用于在 CPU 上执行 gemm 的实现。它使用模板和宏来支持不同的数据类型(例如 float、double 等),并通过 gemm_core_ 函数来执行实际的矩阵乘法。_AT_DISPATCH_GEMM_TYPES 是一个模板分发宏,根据 type 确定 scalar_t 的实际类型,然后在该类型上执行后续代码块。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
// aten/src/ATen/native/cpu/BlasKernel.cpp
void cpublas_gemm_impl(
    at::ScalarType type,
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const Scalar& alpha,
    const void *a, int64_t lda,
    const void *b, int64_t ldb,
    const Scalar& beta,
    void *c, int64_t ldc) {
  _AT_DISPATCH_GEMM_TYPES(type, "cpublas_gemm_impl", [&]{
        using opmath_t = at::opmath_type<scalar_t>;
        gemm_core_(
            transa, transb, m, n, k,
            alpha.to<opmath_t>(),
            static_cast<const scalar_t *>(a), lda,
            static_cast<const scalar_t *>(b), ldb,
            beta.to<opmath_t>(),
            static_cast<scalar_t *>(c), ldc);
      });
}

gemm_core_ 是一个模板化的核心矩阵乘法实现,用于处理各种不同的矩阵转置组合情况。它根据矩阵 A 和 B 是否转置的情况,调用相应的子函数来执行具体的矩阵乘法计算。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
// aten/src/ATen/native/cpu/BlasKernel.cpp
template <typename scalar_t, typename opmath_t>
void gemm_core_(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    opmath_t beta,
    scalar_t *c, int64_t ldc) {
  if (transa == TransposeType::NoTranspose &&
      transb == TransposeType::NoTranspose) {
    // gemm_notrans_:当 A 和 B 都不需要转置时调用。
    return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  } else if (
      transa != TransposeType::NoTranspose &&
      transb == TransposeType::NoTranspose) {
    // gemm_transa_:当 A 需要转置而 B 不需要转置时调用。
    gemm_transa_(transa, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  } else if (
      transa == TransposeType::NoTranspose &&
      transb != TransposeType::NoTranspose) {
    // gemm_transb_:当 B 需要转置而 A 不需要转置时调用。
    gemm_transb_(transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  } else {
    // gemm_transab_:当 A 和 B 都需要转置时调用。
    gemm_transab_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  }
}

在本例中调用的是 gemm_notrans_,这是 pytorch 矩阵乘法执行路径的最后一层了,我们仔细看一下这个函数

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
template <typename scalar_t, typename opmath_t>
std::enable_if_t<std::is_same_v<scalar_t, opmath_t>, void>
gemm_notrans_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // c *= beta
  scale_(m, n, beta, c, ldc);

  // c += alpha * (a @ b)
  for (const auto l : c10::irange(k)) {
    for (const auto j : c10::irange(n)) {
      // 避免后续重复访问和计算
      opmath_t val = b[l + j * ldb] * alpha;
      // 循环展开
      int64_t i_m = m / 4;
      for (const auto i_i : c10::irange(i_m)) {
        c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val;
        c[j * ldc + i_i * 4 + 1] += a[i_i * 4 + 1 + l * lda] * val;
        c[j * ldc + i_i * 4 + 2] += a[i_i * 4 + 2 + l * lda] * val;
        c[j * ldc + i_i * 4 + 3] += a[i_i * 4 + 3 + l * lda] * val;
      }
      // 处理剩余的不足 4 的行数
      int64_t i = i_m * 4;
      for (; i < m; i++)
        c[j * ldc + i] += a[i + l * lda] * val;
    }
  }
}

模板参数和函数参数:

  • scalar_t:用于表示输入矩阵和输出矩阵元素的类型,例如 float 或 double。
  • opmath_t:表示计算过程中使用的累加数据类型。这里与 scalar_t 相同(因为使用了 std::is_same_v 保证,如果不同则是另外一个实现)。
  • m, n, k:表示参与运算的矩阵的形状。k 为矩阵 A 和矩阵 B 的共享维度。
  • alpha 和 beta:标量,用于缩放矩阵 A * B 和 C。
  • a、b、c:指向矩阵 A、B 和 C 数据的指针。
  • lda、ldb、ldc:表示矩阵 A、B、C 的行步长。

这个函数先计算矩阵 C 的缩放,然后再计算矩阵 A 和矩阵 B 的乘,其中使用了循环展开减少循环控制开销。

此时请忘掉一般的矩阵乘法形式,因为我们已经进入到其访存优化的内部实现中!

回到我们的例子,模拟一下具体的计算过程

1
2
3
4
5
import torch

A = torch.tensor([[1, -2, 0], [-1, 2, -2]], dtype=torch.int8)
B = torch.tensor([[2, -1, 1, 0], [0, -2, 1, 2], [-1, 1, -2, 0]], dtype=torch.int8)
R = A.mm(B)

原本的矩阵数组是

  • 2*3 矩阵 A = [1, -2, 0, -1, 2, -2]
  • 3*4 矩阵 B = [2, -1, 1, 0, 0, -2, 1, 2, -1, 1, -2, 0]
  • 计算后的 2*4 矩阵 C = [2, 3, -1, -4, 0, -5, 5, 4](作为对照结果)

在 addmm_impl_cpu_ 函数中经过内存布局调整后(矩阵 A 与 B 互换)

  • 矩阵 A = [2, -1, 1, 0, 0, -2, 1, 2, -1, 1, -2, 0]
  • 矩阵 B = [1, -2, 0, -1, 2, -2]

进入 gemm_notrans_ 的参数分别为:

(int64_t) m = 4
(int64_t) n = 2
(int64_t) k = 3
(signed char) alpha = '\x01'
(const signed char *) a = 0x00000000077eaec0
(int64_t) lda = 4
(const signed char *) b = 0x00000000077eae40
(int64_t) ldb = 3
(signed char) beta = '\0'
(signed char *) c = 0x00000000077f2900
(int64_t) ldc = 4

计算过程:

  1. 当 l = 0 且 j = 0 时:
    • val = b[0 + 0 * 3] = b[0]
    • i_m = m / 4 = 4 / 4 = 1,i_i 的值仅取 0,进行循环展开
      • c[0 * 4+ 0 * 4 + 0] += a[0 * 4 + 0 + 0 * 4] _ val,即计算 c[0] += a[0] _ b[0]
      • c[0 * 4+ 0 * 4 + 1] += a[0 * 4 + 1 + 0 * 4] _ val,即计算 c[1] += a[1] _ b[0]
      • c[0 * 4+ 0 * 4 + 2] += a[0 * 4 + 2 + 0 * 4] _ val,即计算 c[2] += a[2] _ b[0]
      • c[0 * 4+ 0 * 4 + 3] += a[0 * 4 + 3 + 0 * 4] _ val,即计算 c[3] += a[3] _ b[0]
  2. 当 l = 0 且 j = 1 时:
    • val = b[0 + 1 * 3] = b[3]
    • i_m = m / 4 = 4 / 4 = 1,i_i 的值仅取 0,进行循环展开
      • c[1 * 4 + 0 * 4 + 0] += a[0 * 4 + 0 + 0 * 4] _ val,即计算 c[4] += a[0] _ b[3]
      • c[1 * 4+ 0 * 4 + 1] += a[0 * 4 + 1 + 0 * 4] _ val,即计算 c[5] += a[1] _ b[3]
      • c[1 * 4+ 0 * 4 + 2] += a[0 * 4 + 2 + 0 * 4] _ val,即计算 c[6] += a[2] _ b[3]
      • c[1 * 4+ 0 * 4 + 3] += a[0 * 4 + 3 + 0 * 4] _ val,即计算 c[7] += a[3] _ b[3]
  3. 当 l = 1 且 j = 0 时:
    • val = b[1 + 0 * 3] = b[1]
    • i_m = m / 4 = 4 / 4 = 1,i_i 的值仅取 0,进行循环展开
      • c[0 * 4 + 0 * 4 + 0] += a[0 * 4 + 0 + 1 * 4] _ val,即计算 c[0] += a[4] _ b[1]
      • c[0 * 4+ 0 * 4 + 1] += a[0 * 4 + 1 + 1 * 4] _ val,即计算 c[1] += a[5] _ b[1]
      • c[0 * 4+ 0 * 4 + 2] += a[0 * 4 + 2 + 1 * 4] _ val,即计算 c[2] += a[6] _ b[1]
      • c[0 * 4+ 0 * 4 + 3] += a[0 * 4 + 3 + 1 * 4] _ val,即计算 c[3] += a[7] _ b[1]
  4. 当 l = 1 且 j = 1 时:
    • val = b[1 + 1 * 3] = b[4]
    • i_m = m / 4 = 4 / 4 = 1,i_i 的值仅取 0,进行循环展开
      • c[1 * 4 + 0 * 4 + 0] += a[0 * 4 + 0 + 1 * 4] _ val,即计算 c[4] += a[4] _ b[4]
      • c[1 * 4+ 0 * 4 + 1] += a[0 * 4 + 1 + 1 * 4] _ val,即计算 c[5] += a[5] _ b[4]
      • c[1 * 4+ 0 * 4 + 2] += a[0 * 4 + 2 + 1 * 4] _ val,即计算 c[6] += a[6] _ b[4]
      • c[1 * 4+ 0 * 4 + 3] += a[0 * 4 + 3 + 1 * 4] _ val,即计算 c[7] += a[7] _ b[4]
  5. 当 l = 2 且 j = 0 时:
    • val = b[2 + 0 * 3] = b[2]
    • i_m = m / 4 = 4 / 4 = 1,i_i 的值仅取 0,进行循环展开
      • c[0 * 4 + 0 * 4 + 0] += a[0 * 4 + 0 + 2 * 4] _ val,即计算 c[0] += a[8] _ b[2]
      • c[0 * 4+ 0 * 4 + 1] += a[0 * 4 + 1 + 2 * 4] _ val,即计算 c[1] += a[9] _ b[2]
      • c[0 * 4+ 0 * 4 + 2] += a[0 * 4 + 2 + 2 * 4] _ val,即计算 c[2] += a[10] _ b[2]
      • c[0 * 4+ 0 * 4 + 3] += a[0 * 4 + 3 + 2 * 4] _ val,即计算 c[3] += a[11] _ b[2]
  6. 当 l = 2 且 j = 1 时:
    • val = b[2 + 1 * 3] = b[5]
    • i_m = m / 4 = 4 / 4 = 1,i_i 的值仅取 0,进行循环展开
      • c[1 * 4 + 0 * 4 + 0] += a[0 * 4 + 0 + 2 * 4] _ val,即计算 c[4] += a[8] _ b[5]
      • c[1 * 4+ 0 * 4 + 1] += a[0 * 4 + 1 + 2 * 4] _ val,即计算 c[5] += a[9] _ b[5]
      • c[1 * 4+ 0 * 4 + 2] += a[0 * 4 + 2 + 2 * 4] _ val,即计算 c[6] += a[10] _ b[5]
      • c[1 * 4+ 0 * 4 + 3] += a[0 * 4 + 3 + 2 * 4] _ val,即计算 c[7] += a[11] _ b[5]

把累加表达式合并得到:

$$ \begin{align} c[0] &= a[0]*b[0] + a[4]*b[1] + a[8]*b[2] = 2 + 0 + 0 = 2 \\\\ c[1] &= a[1]*b[0] + a[5]*b[1] + a[9]*b[2] = -1 + 4 + 0 = 3 \\\\ c[2] &= a[2]*b[0] + a[6]*b[1] + a[10]*b[2] = 1 - 2 + 0 = -1 \\\\ c[3] &= a[3]*b[0] + a[7]*b[1] + a[11]*b[2] = 0 - 4 + 0 = -4 \\\\ \\\\ c[4] &= a[0]*b[3] + a[4]*b[4] + a[8]*b[5] = -2 + 0 + 2 = 0 \\\\ c[5] &= a[1]*b[3] + a[5]*b[4] + a[9]*b[5] = 1 - 4 - 2 = -5 \\\\ c[6] &= a[2]*b[3] + a[6]*b[4] + a[10]*b[5] = -1 + 2 + 4 =5 \\\\ c[7] &= a[3]*b[3] + a[7]*b[4] + a[11]*b[5] = 0 + 4 + 0 =4 \\\\ \end{align} $$

可以看到这和期望的结果一致。

所以,为什么要这么计算?

我们先来看看正常的矩阵乘法计算过程:

$$ \begin{pmatrix} a[0] & a[1] & a[2] \\\\ a[3] & a[4] & a[5] \end{pmatrix}\begin{pmatrix} b[0] & b[1] & b[2] & b[3] \\\\ b[4] & b[5] & b[6] & b[7] \\\\ b[8] & b[9] & b[10] & b[11] \end{pmatrix}=\begin{pmatrix} c[0] & c[1] & c[2] & c[3] \\\\ c[4] & c[5] & c[6] & c[7] \end{pmatrix} $$
$$ \begin{align} c[0] &= a[0]*b[0] + a[1]*b[4] + a[2]*b[8] = 2 + 0 + 0 = 2 \\\\ c[1] &= a[0]*b[1] + a[1]*b[5] + a[2]*b[9] = -1 + 4 + 0 = 3 \\\\ c[2] &= a[0]*b[2] + a[1]*b[6] + a[2]*b[10] = 1 - 2 + 0 = -1 \\\\ c[3] &= a[0]*b[3] + a[1]*b[7] + a[2]*b[11] = 0 - 4 + 0 = -4 \\\\ \\\\ c[4] &= a[3]*b[0] + a[4]*b[4] + a[5]*b[8] = -2 + 0 + 2 = 0 \\\\ c[5] &= a[3]*b[1] + a[4]*b[5] + a[5]*b[9] = 1 - 4 - 2 = -5 \\\\ c[6] &= a[3]*b[2] + a[4]*b[6] + a[5]*b[10] = -1 + 2 + 4 = 5 \\\\ c[7] &= a[3]*b[3] + a[4]*b[7] + a[5]*b[11] = 0 + 4 + 0 = 4 \\\\ \end{align} $$

注意这里和上面 PyTorch 算法的区别就是 a 和 b 反一下。结合之前 addmm_impl_cpu_ 中矩阵的互换,我们就可以证明 PyTorch 算法的正确性。

相应代码

1
2
3
4
5
6
7
for (int i = 0; i < m; ++i) {              // 遍历 A 的每一行
	for (int j = 0; j < p; ++j) {          // 遍历 B 的每一列
		for (int k = 0; k < n; ++k) {      // 计算 A 的第 i 行与 B 的第 j 列的点积
		C[i * ldc + j] += A[i * lda + k] * B[k * ldb + j]; // 即 C[i][j] += A[i][k] * B[k][j];
        }
    }
}

矩阵 C 的每一个元素在第三个循环中计算完成,变量 k 在迭代时,对于矩阵 A 来说是顺序访存,但对于矩阵 B 是跳跃访存(矩阵 B 按行存储这里却按列读取,即循环变量乘上了步长)。对于较大的矩阵,这种访问模式会导致处理器不断跳跃到新的缓存行,这样一来,缓存行中的大部分数据都无法被重用,导致频繁的缓存未命中。

而 PyTorch 中的做法:矩阵 C 的每一个元素是分组计算累加得到的(即 24(m*n*k) 次计算由于循环展开分成了 6(k*n) 组(六次内部循环),矩阵 C 的每个元素需要在 3 组计算中累加得到,比如 c[0] 需要参与第 1,3,5 组计算)。矩阵 B 中参与运算的元素使用临时变量保存,对矩阵 C 和矩阵 A 的访存都是连续的,这就避免了大量的缓存未命中,从而提高计算速度。

两者区别如下图所示:

  • 一般矩阵乘法:b[0]->b[4]->b[8] 这样的访存是跨越行的,会导致缓存利用率下降。跳跃访存的总次数为 3*8 -1 = 23 次。

  • PyTorch 算法:内层循环都是顺序访存,只有中间循环(b[l + j * ldb],此时循环变量乘上了步长)才会跳跃访存(b[0]->b[3])。跳跃访存的总次数为 k*n = 3*2-1 = 5 次。

    虽然内层循环中也有 c[j * ldc + i_i * 4 + 0]、a[i_i * 4 + 0 + l * lda] 这样的变量乘步长的形式,但在内层循环执行过程中,j 和 l 实际上是固定的常量,因此不能称为循环变量。内层循环的循环变量是 i_i 。

image

最后,作者数学、算法功底不好,对上述 PyTorch 矩阵乘法如有错误或不合理的解释,请指正。

float32 实现

示例代码(PyTorch 默认数据类型是 float32)

1
2
3
4
5
import torch

A = torch.rand(10, 10)
B = torch.rand(10, 10)
R = A.mm(B)

float32 数据类型在 addmm_impl_cpu_ 中调用的 gemm 就是另一种实现。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
// aten/src/ATen/native/CPUBlas.cpp
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const float alpha,
    const float *a, int64_t lda,
    const float *b, int64_t ldb,
    const float beta,
    float *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);

// 尝试 MKL-DNN 优化路径
#if AT_MKLDNN_ENABLED()
   if (mkldnn_bf32_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) {
     return;
   }
#endif

// 尝试 BLAS 库路径
#if AT_BUILD_WITH_BLAS()
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
    float alpha_ = alpha, beta_ = beta;
    // cblas_sgemm:用于单精度浮点数的矩阵乘法运算,适用于 iOS 平台和 Apple Accelerate 框架。
    #if C10_IOS
    CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa);
    CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb);
    cblas_sgemm(CblasColMajor,
      transa_, transb_,
      m_, n_, k_,
      alpha_,
      a, lda_,
      b, ldb_,
      beta_,
      c, ldc_);
    #else
    // sgemm_:这是一个 BLAS 的 C 接口,用于执行矩阵乘法和加法操作。
    char transa_ = to_blas(transa), transb_ = to_blas(transb);
    sgemm_(
        &transa_, &transb_,
        &m_, &n_, &k_,
        &alpha_,
        a, &lda_,
        b, &ldb_,
        &beta_,
        c, &ldc_);
    #endif
    return;
  }
#endif
  // 如果 MKL-DNN 和 BLAS 都不可用或不适用,则像 int8 一样使用 gemm_stub 作为默认实现。
  gemm_stub(
      at::kCPU, at::kFloat,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

这个 gemm 实现会优先使用 mkldnn 的 gemm 实现

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
// aten/src/ATen/native/mkldnn/Matmul.cpp
bool mkldnn_bf32_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const float *a, int64_t lda,
    const float *b, int64_t ldb,
    float beta,
    float *c, int64_t ldc){
      return mkldnn_gemm<float>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    }

mkldnn_bf32_gemm() 是一个包装函数,实际执行在 mkldnn_gemm() 。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
// aten/src/ATen/native/mkldnn/Matmul.cpp
template<typename scalar_t>
inline typename std::enable_if_t<
    std::is_same_v<scalar_t, float> ||
    std::is_same_v<scalar_t, c10::Half> ||
    std::is_same_v<scalar_t, c10::BFloat16>,
    bool>
mkldnn_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const scalar_t *a_data, int64_t lda,
    const scalar_t *b_data, int64_t ldb,
    float beta,
    scalar_t *c_data, int64_t ldc) {
  bool bf16_usable = std::is_same_v<scalar_t, c10::BFloat16> && use_mkldnn_bf16_matmul();
  bool fp16_usable = std::is_same_v<scalar_t, c10::Half> && use_mkldnn_fp16_matmul();
  bool bf32_usable = std::is_same_v<scalar_t, float> && use_mkldnn_bf32_matmul();
  if ( !(bf16_usable || fp16_usable || bf32_usable) ||
      (m * n * k <= 16 * 16 * 16) || (alpha == 0.0f)) {
    return false;
  }

  ideep::attr_t op_attr;
  // Use mkldnn post ops to perform the add.
  if (beta != 0.0f) {
    op_attr = ideep::attr_t::fuse_sum();
  }
  if (bf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); // bf32 path

  // NOTE: View as c-contiguous to avoid extra reordering in mkldnn
  // Use identity: C = AB <=> C^T = B^T A^T
  ideep::tensor::dims a_strides{{lda, 1}}, b_strides{{ldb, 1}}, c_strides{{ldc, 1}};
  if (transa != TransposeType::NoTranspose) {
    std::swap(a_strides[0], a_strides[1]);
  }
  if (transb != TransposeType::NoTranspose) {
    std::swap(b_strides[0], b_strides[1]);
  }

  auto idtype = ideep::tensor::data_type::bf16;
  if constexpr (std::is_same_v<scalar_t, c10::Half>) {
    idtype = ideep::tensor::data_type::f16;
  }
  if constexpr (std::is_same_v<scalar_t, float>) {
    idtype = ideep::tensor::data_type::f32;
  }

  ideep::tensor a({
      /*sizes=*/{k, m},
      idtype,
      /*strides=*/a_strides},
    const_cast<scalar_t*>(a_data));
  ideep::tensor b({
      /*sizes=*/{n, k},
      idtype,
      /*strides=*/b_strides},
    const_cast<scalar_t*>(b_data));
  ideep::tensor c({
      /*sizes=*/{n, m},
      idtype,
      /*strides=*/c_strides},
    c_data);

  ideep::matmul_forward::compute(
      b, a, c, alpha, beta,
      ideep::scale_t(), ideep::scale_t(), ideep::scale_t(), op_attr);

  if (c.get_data_handle() != c_data){
    // ideep will query onednn expect format of output
    // if given output format is not expected, ideep will re-init an output buffer
    // under this case, we need copy the re-inited buffer back to given buffer
    ideep::tensor real_output({
        /*sizes=*/{n, m},
        idtype,
        /*strides=*/c_strides},
      c_data);
    c.reorder_to(real_output);
  }

  return true;
}

可以看到这里根据 bf16_usable/fp16_usable/bf32_usable 来判断 mkldnn 是否可用?

为什么我 float32 的数据类型会进行这些判断?这是因为 PyTorch 中有一个控制 float32 矩阵乘法的内部精度的设置(float32MatmulPrecision)。当 float32MatmulPrecision 为 medium 时,就会使用 bf16 进行内部计算,从而提供加速效果。如 use_mkldnn_bf32_matmul() 函数所述。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
static bool use_mkldnn_bf16_matmul() {
  return at::globalContext().userEnabledMkldnn() && mkldnn_bf16_device_check();
}

static bool use_mkldnn_fp16_matmul() {
  return at::globalContext().userEnabledMkldnn() && mkldnn_fp16_device_check();
}

static bool use_mkldnn_bf32_matmul() {
  return use_mkldnn_bf16_matmul() && at::globalContext().float32MatmulPrecision() == at::Float32MatmulPrecision::MEDIUM;
}

如果系统上有支持 mkldnn 的硬件设备,并且 float32MatmulPrecision 为 medium,那么 mkldnn_gemm 就会调用 ideep::matmul_forward::compute 来执行具体的计算。ideep(Intel Deep Learning Boost library)是由英特尔开发的深度学习库,专门优化了在 Intel 硬件(特别是支持 Intel® DL Boost 和 AVX-512 的硬件)上进行高效的深度学习运算。关于 mkldnn 和 ideep 的关系在这不展开叙述。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
// third_party/ideep/include/ideep/operators/matmul.hpp
  static void compute(
      const tensor& src,
      const tensor& weights,
      tensor& dst,
      const float dst_coeff = 1.0f,
      const float sum_coeff = 1.0f,
      const scale_t& src_scales = scale_t(),
      const scale_t& weights_scales = scale_t(),
      const scale_t& dst_scales = scale_t(),
      const attr_t& attr = attr_t(),
      const std::vector<tensor>& bin_post_params = {},
      const data_type dst_type = data_type::undef,
      const lowp_kind alowp_kind = u8s8,
      const engine& aengine = engine::cpu_engine()) {
    // Consider fp32 only for IPEX
    static tensor dummy_bias;
    compute_impl</*with_bias=*/false, /*reorder_src*/true, /*reorder_weight*/true>(
        src, weights, dummy_bias, dst,
        src_scales, weights_scales, dst_scales,
        IDEEP_EMPTY_ZP, IDEEP_EMPTY_ZP,
        dst_coeff, sum_coeff, attr, bin_post_params,
        dst_type, alowp_kind, aengine);
  }

由于作者的机器不支持 mkldnn,因此关于 mkldnn 的内容就止步于这个第三方库中的函数定义,不再往下展开。

在我们的例子中,采用的是 sgemm_ 实现

1
2
3
// aten/src/ATen/native/CPUBlas.cpp
#include <ATen/native/mkl/LinearAlgebra.h>
extern "C" void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, const float *a, int *lda, const float *b, int *ldb, float *beta, float *c, int *ldc);

sgemm* 函数是一个用于在 CPU 上执行单精度浮点数矩阵乘法的封装。sgemm 是一个典型的 BLAS(Basic Linear Algebra Subprograms)函数,专门用于进行单精度浮点矩阵相乘。PyTorch 通过 sgemm* 函数,利用了底层的 BLAS 实现来加速矩阵计算。

可以使用以下命令确定 PyTorch 目前使用的 BLAS 库

$ python -c "import torch;print(torch.__config__.show())"
PyTorch built with:
  - GCC 11.4
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2025.0-Product Build 20241009 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.5.3 (Git Hash 66f0cb9eb66affd2da3bf5f8d897376f04aae6af)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 12.4
  - NVCC architecture flags: -gencode;arch=compute_86,code=sm_86
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Debug, CUDA_VERSION=12.4, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=1 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DLIBKINETO_NOXPUPTI=ON -DUSE_FBGEMM -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, TORCH_VERSION=2.6.0, USE_CUDA=ON, USE_CUDNN=OFF, USE_CUSPARSELT=OFF, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,

根据输出信息

  • Intel(R) oneAPI Math Kernel Library Version 2025.0-Product Build 20241009 for Intel(R) 64 architecture applications
  • Build settings: BLAS_INFO=mkl

可以发现当前 PyTorch 构建使用的是 Intel oneAPI Math Kernel Library (MKL) 。因此下一步会调用 mkl 的 sgemm 实现继续执行。其头文件包含代码如下:

1
2
// ATen/native/mkl/LinearAlgebra.h
#include <mkl.h>

mkl 是我们在安装 PyTorch 依赖时安装的,包括其静态链接库和头文件,之前的安装命令如下

pip install mkl-static mkl-include

我们可以在系统中找到它们(当前 conda 虚拟环境为 pytorch_dev)

(pytorch_dev) $ find $(conda info --base)/envs/pytorch_dev/include -name "mkl*.h"
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_blas_64.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vsl.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_spblas.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_lapacke.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vml_types.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vml.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_direct_types.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_df_types.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_lapack.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_lapack_omp_variant.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_spblas_omp_offload.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_scalapack.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_cdft.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_df_functions.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_pardiso.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vml_functions.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_cdft_types.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_sparse_qr.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vsl_types.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vsl_functions_64.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_direct_blas.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vsl_defines.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_service.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_pblas.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_omp_variant.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_df_defines.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_dfti.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_cblas.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_direct_blas_kernels.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_trig_transforms.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_dss.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_sparse_handle.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_blas_omp_variant.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_version.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vml_defines.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_trans_names.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_cblas_64.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_poisson.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vml_omp_offload.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_solvers_ee.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vsl_omp_variant.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_df.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_omp_offload.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_trans.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vsl_omp_offload.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_blacs.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_dfti_omp_offload.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_cluster_sparse_solver.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_spblas_omp_variant.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vsl_functions.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_compact.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_blas.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_types.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_lapack_omp_offload.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_rci.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_direct_lapack.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_blas_omp_offload.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_direct_call.h
/home/lin/miniconda3/envs/pytorch_dev/include/mkl_vml_omp_variant.h

(pytorch_dev) $ find $(conda info --base)/envs/pytorch_dev/lib -name "libmkl*"
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_gf_ilp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_tbb_thread.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_intel_lp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_lapack95_lp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_gnu_thread.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_lapack95_ilp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_blas95_ilp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_blas95_lp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_blacs_openmpi_lp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_blacs_intelmpi_ilp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_scalapack_lp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_intel_ilp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_blacs_openmpi_ilp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_cdft_core.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_core.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_scalapack_ilp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_intel_thread.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_blacs_intelmpi_lp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_gf_lp64.a
/home/lin/miniconda3/envs/pytorch_dev/lib/libmkl_sequential.a

由于 mkl 不开源,因此静态库中发生了什么我们就不得而知了,只能获得返回的计算结果。

本文也就告一段落,下一步可能会将 PyTorch 的 BLAS 库换成 OpenBLAS,比较其和 mkl 性能差距。OpenBLAS 作为开源项目,有非常多的内容可供学习。


本站不记录浏览量,但如果您觉得本内容有帮助,请点个小红心,让我知道您的喜欢。