Pytorch1.9 released 里发布了一个 Inference Mode API(Beta).

推理模式API ( Inference Mode API ) 可以显著加速推理工作负载的速度,同时保持安全,并确保永远不会计算不正确的梯度. 在不需要 autograd 时,其提供了最好的性能.

inference_mode — PyTorch 1.9.0 documentation

定义如:

class torch.inference_mode(mode=True)
#参数 mode: 是否开启推断模式

InferenceMode 是类似与 no_grad 的上下文管理器(context manager),主要用于确定不需要与 autograd 交互时使用. 这种模式下运行的代码,通过禁用试图跟踪(view tracking) 和版本计数器缓冲(version counter bumps) 来获得更好的性能.

InferenceMode 上下文管理器是局部线程的(thread local),其不会影响其他线程中的计算.

InferenceMode 也起到装饰器的作用.

使用示例:

import torch

x = torch.ones(1, 2, 3, requires_grad=True)
with torch.inference_mode():
  y = x * x

#
y.requires_grad # False
y._version
#Traceback (most recent call last):
#File "<stdin>", line 1, in <module>
#RuntimeError: Inference tensors do not track version counter.

@torch.inference_mode()
def func(x):
  return x * x
out = func(x)
out.requires_grad #False
Last modification:June 16th, 2021 at 02:00 pm