05Nvidia剪枝方案介绍

发布时间 2023-07-06 16:11:01作者: DemonSlayer

Nvidia剪枝方案介绍

目前大多数的剪枝研究处于以下两个方面

  1. 绝大多数剪枝是非结构化的,属于细粒度稀疏。而细粒度稀疏其实没有那么好的加速效果
  2. Coarse-grained sparsity的稀疏效果有限

("Coarse-grained sparsity"是一种稀疏性类型,它指的是在较大的数据块或数据结构中存在稀疏性,而不是在单个元素级别。在深度学习和神经网络中,这通常意味着在层级别或通道级别进行稀疏化,而不是在单个权重或神经元级别。

例如,对于卷积神经网络,粗粒度稀疏性可能意味着整个过滤器或通道被置零或被剪枝,而不是单个权重。这种稀疏性类型的一个优点是,它可以更容易地利用硬件加速器的并行性,因为整个数据块可以一次性地被加载、处理或跳过。

相反,"fine-grained sparsity"则是指在单个元素级别存在稀疏性,例如单个权重或神经元被置零或被剪枝。这种稀疏性类型可能更难以优化,因为它可能需要更复杂的索引和数据管理策略。

"coarse-grained sparsity"是一种在更大的数据结构级别实现稀疏性的策略,它可以更容易地与硬件优化相结合。)

面临的挑战

  1. 精度丢失
  2. 没有一个通用的剪枝方案去针对不同的网络
  3. Lack of speedup(由于剪完之后结构发生了改变,可能无法使用矩阵加速,可能无法利用内存加速,存储开销变大)

    下面是一个demo,流程是加载预训练模型 -> 测试预训练模型 -> 剪枝 ->测试剪枝后的模型 -> 再训练剪枝后的模型 -> 测试再训练后的模型 -> 保存剪枝和再训练后的模型
torch.manual_seed(42)
get_model("./model.pt")
# get_model("None")
print("-------orig---------")
test()
print(model[2].state_dict())
ASP.prune_trained_model(model, optimizer)
print("-------pruned---------")
test()
print(model[2].state_dict())
train()
print("-------retrain---------")
test()
print(model[2].state_dict())
torch.save(model, "./model_sparse.pt")

构建加载模型的函数get_model()

#如果有则直接加载模型和优化器,没有则构建一个简单的模型并train一下然后保存下来
def get_model(f):
    global model, optimizer
    if os.path.exists(f):
        model = torch.load(f).cuda()
        optimizer = optim.Adam(model.parameters(), lr=0.01)
    else:
        model = nn.Sequential(
            nn.Linear(8, 16),
            nn.PReLU(),
            nn.Linear(16, 8),
        ).cuda()
        optimizer = optim.Adam(model.parameters(), lr=0.01)
        train()
        torch.save(model, f)

ASP( Automatic Sparsity Pruning)复现,该方法是Nvidia在2020年提出并首次引入Nvidia的Ampere架构中。在这种方法中,权重的重要性是通过一种称为 "mask" 的机制来确定的。这些 mask 是在训练过程中学习的,并且在训练结束时,权重被乘以相应的 mask。这样,不重要的权重(即,对应于 mask 中的零的权重)就被剪枝掉了。

class ASP:
    model = None
    verbosity = 0
    optimizer = None
    sparse_parameters = []
    calculate_mask = None

    @classmethod
    def init_model_for_pruning(
        cls,
        model,
        mask_calculator="m4n2_1d",
        verbosity=3,
        whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d],
        custom_layer_dict={},
    ):
        assert cls.model is None, "ASP has been initialized already."
        cls.model = model
        cls.verbosity = verbosity

        if isinstance(mask_calculator, str):
            def create_mask_from_pattern(param):
                return create_mask(param, mask_calculator).bool()

            cls.calculate_mask = create_mask_from_pattern

        # function to extract variables that will be sparsified.
        # idea is that you will add one of these functions for each module type that can be sparsified.

        sparse_parameter_list = {
            torch.nn.Linear: ["weight"],
            torch.nn.Conv1d: ["weight"],
            torch.nn.Conv2d: ["weight"],
        }
        if (
            custom_layer_dict
        ):  # Update default list to include user supplied custom (layer type : parameter tensor), make sure this tensor type is something ASP knows how to prune
            sparse_parameter_list.update(custom_layer_dict)
            whitelist += list(custom_layer_dict.keys())

        for module_type in whitelist:
            assert module_type in sparse_parameter_list, (
                "Module %s :: Don't know how to sparsify module." % module.dtype()
            )

        # find all sparse modules, extract sparse parameters and decorate
        def add_sparse_attributes(module_name, module):
            sparse_parameters = sparse_parameter_list[type(module)]
            for p_name, p in module.named_parameters():
                if p_name in sparse_parameters and p.requires_grad:
                    # check for NVIDIA's TC compatibility: we check along the horizontal direction
                    if p.dtype == torch.float32 and (
                        (p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0
                    ):  # User defines FP32 and APEX internally uses FP16 math
                        print(
                            "[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity"
                            % (module_name, p_name, str(p.size()), str(p.dtype))
                        )
                        continue
                    if p.dtype == torch.float16 and (
                        (p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0
                    ):  # For Conv2d dim= K x CRS; we prune along C
                        print(
                            "[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity"
                            % (module_name, p_name, str(p.size()), str(p.dtype))
                        )
                        continue

                    if cls.verbosity >= 3:
                        print(
                            "[ASP] Sparsifying %s::%s of size=%s and type=%s for sparsity"
                            % (module_name, p_name, str(p.size()), str(p.dtype))
                        )

                    mask = torch.ones_like(p).bool()
                    buffname = p_name.split(".")[-1]  # buffer names cannot contain "."
                    module.register_buffer("__%s_mma_mask" % buffname, mask)
                    cls.sparse_parameters.append(
                        (module_name, module, p_name, p, mask)
                    )
                else:
                    if cls.verbosity >= 3:
                        print(
                            "[ASP] Not sparsifying %s::%s of size=%s and type=%s"
                            % (module_name, p_name, str(p.size()), str(p.dtype))
                        )

        for name, sparse_module in eligible_modules(
            model, tuple(whitelist)
        ):
            add_sparse_attributes(name, sparse_module)

    @classmethod
    def init_optimizer_for_pruning(cls, optimizer):
        assert cls.optimizer is None, "ASP has initialized optimizer already."
        assert (
            cls.calculate_mask is not None
        ), "Called ASP.init_optimizer_for_pruning before ASP.init_model_for_pruning."

        # store pointer to original optimizer step method
        cls.optimizer = optimizer
        cls.optimizer.__step = optimizer.step

        def __step(opt_self, *args, **kwargs):
            # prune gradients before step method
            with torch.no_grad():
                for (
                    module_name,
                    module,
                    p_name,
                    p,
                    mask,
                ) in cls.sparse_parameters:
                    if p.grad is not None:  # thx pjudd
                        p.grad.mul_(mask)
            # call original optimizer step method
            rval = opt_self.__step(*args, **kwargs)
            # prune parameters after step method
            with torch.no_grad():
                for (
                    module_name,
                    module,
                    p_name,
                    p,
                    mask,
                ) in cls.sparse_parameters:
                    p.mul_(mask)
            return rval

        cls.optimizer.step = types.MethodType(__step, cls.optimizer)

    @classmethod
    def compute_sparse_masks(cls): #!aaaa
        with torch.no_grad():
            for module_name, module, p_name, p, mask in cls.sparse_parameters:
                mask.set_(cls.calculate_mask(p)) # torch.Size([8, 16]) # mask = cls.calculate_mask(p) # in place op
                p.mul_(
                    mask
                )  # in-place multiplication, so pruned weights are 0-values, hence checkpoint will have 0s for pruned weights

    @classmethod
    def prune_trained_model(cls, model, optimizer):
        # add mask buffers to model (init_model_for_pruning), augment optimizer (init_optimizer_for_pruning) and compute masks (compute_sparse_masks)
        cls.init_model_for_pruning(
            model,
            mask_calculator="m4n2_1d",
            verbosity=2,
            whitelist=[torch.nn.Linear, torch.nn.Conv2d],
        )
        cls.init_optimizer_for_pruning(optimizer)
        cls.compute_sparse_masks()

构建mask

def create_mask(tensor, pattern="m4n2_1d", density=0.5): #! 0
    # Reshape tensor and mask.
    shape = tensor.shape
    ttype = tensor.type()
    t = tensor.float().contiguous()

    # len(shape) == 2:
    t = t.view(shape[0], shape[1])
    func = getattr(sys.modules[__name__], pattern, None) # getattr() asks for the name of a thing we're looking for (like a function or an attribute in a module), and if it finds it, we can use it later in our code.
    mask = func(t, density) # func here is m4n2_1d func
    return mask.view(shape).type(ttype)
param = torch.randn(8, 16).to("cuda:0")

    def create_mask_from_pattern(param):
        return create_mask(param, "m4n2_1d").bool() #工厂模式

    mask = create_mask_from_pattern(param)

首先是取到权重矩阵,然后分割成每4个一组,然后乘以01的全排列(m个位置里选出n个1),假设是4,则有6种排列,那么结果是n*6的矩阵,然后在每一个维度上取一个最大值

#从m个位置里选出n个位置为1并生成所有的排列
def compute_valid_1d_patterns(m, n): 
    patterns = torch.zeros(m) # [0,0,0,0]
    patterns[:n] = 1
    valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist()))))
    return valid_patterns
def mn_1d_best(matrix, m, n): 
    patterns = compute_valid_1d_patterns(m, n).cuda()
    #首先把权重矩阵复制出来,全部填上1,并更改为4个一组
    mask = torch.cuda.IntTensor(matrix.shape).fill_(1).view(-1, m)
    mat, shape = reshape_1d(matrix, m) # matrix: [8, 16] ==>  mat[32, 4]
    #做矩阵乘法,并对每一行取一个最大值的索引
    #在PyTorch中,torch.argmax()函数返回输入张量中沿指定维度最大值的索引。dim参数就是用来指定这个维度的。
    pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1) # 32x4@4x6=32x6
    #pmax是索引,根据索引把对应01排列取出来
    mask[:] = patterns[pmax[:]]
    #然后将mask还原成matrix的形状
    mask = mask.view(matrix.shape)
    return mask