以pytorch的forward hook为例探究hook机制

NiaxHahaha 2024-10-10 08:09:00 阅读 78

对pytorch中hook机制的简单介绍和源码分析

在看pytorch的nn.Module部分的源码的时候,看到了一堆"钩子",也就是hook,然后去研究了一下这是啥玩意。

基本概念

在深度学习中,hook 是一种可以在模型的不同阶段插入自定义代码的机制。通过自定义数据在通过模型的特定层的额外行为,可以用来监控状态,协助调试,获得中间结果。

以前向hook为例

前向hook是模型在forward过程中会调用的hook,通过torch.nn.Module的register_forward_hook() 函数,将一个自定义的hook函数注册给模型的一个层

该层在进行前向之后,根据其输入和输出会进行相应的行为。

<code>

import torch

import torch.nn as nn

# 定义模型

class SimpleModel(nn.Module):

def __init__(self):

super(SimpleModel, self).__init__()

self.fc1 = nn.Linear(10, 5)

self.fc2 = nn.Linear(5, 2)

def forward(self, x):

x = self.fc1(x)

return self.fc2(x)

model = SimpleModel()

# 自定义的forward hook

def my_forward_hook(module, input, output):

print(f"层: {module}")

print(f"输入: {input}")

print(f"输出: {output}")

# 为模型的fc1层注册hook

hook = model.fc1.register_forward_hook(my_forward_hook)

# 移除这个hook

hook.remove()

接口

hook函数的格式

需要是一个接受三个特定参数,返回None的函数

def hook_function(module, input, output):

# 自定义逻辑

return None

  • module: 触发钩子的模型层,事实上是调用register_forward_hook的nn.Module实例
  • input: 传递给该层的输入张量(可能是元组),是前向传播时该层接收到的输入。
  • output: 该层的输出张量,是前向传播时该层生成的输出。

    函数内部可以做自定义行为,可以在函数内部对output进行修改,从而改变模型的输出。

注册hook

hook = model.fc1.register_forward_hook(my_forward_hook)

hook.remove()

对于定义好的hook函数,将其作为参数,调用需要注册的模型层的注册函数即可。

如果不再需要这个hook,调用remove函数。

简单的源码讨论

还是以forward hook为例。一个nn.Module具有成员_forward_hooks,这是一个有序字典,在__init__()函数调用的时候被初始化

self._forward_hooks = OrderedDict()

注册钩子的register函数。

每个hook对应一个RemovableHandle对象,以其id作为键注册到hook字典中,利用其remove函数实现移除。

def register_forward_hook(

self,

hook: Union[

Callable[[T, Tuple[Any, ...], Any], Optional[Any]],

Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],

],

*,

prepend: bool = False,

with_kwargs: bool = False,

always_call: bool = False,

) -> RemovableHandle:

handle = RemovableHandle(

self._forward_hooks,

extra_dict=[

self._forward_hooks_with_kwargs,

self._forward_hooks_always_called,

],

)

self._forward_hooks[handle.id] = hook

if with_kwargs:

self._forward_hooks_with_kwargs[handle.id] = True

if always_call:

self._forward_hooks_always_called[handle.id] = True

if prepend:

self._forward_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined]

return handle

#简化版的handle类

class RemovableHandle:

def __init__(self, hooks_dict, handle_id):

self.hooks_dict = hooks_dict

self.id = handle_id

def remove(self):

del self.hooks_dict[self.id]


上一篇: Go基础知识

下一篇: Rust 中的 HashMap 实战指南:理解与优化技巧

本文标签

hook    模型    函数    self   


声明

本文内容仅代表作者观点,或转载于其他网站,本站不以此文作为商业用途
如有涉及侵权,请联系本站进行删除
转载本站原创文章,请注明来源及作者。