ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

(转)SignSGD 及其 MXNet 实现解读

2021-02-02 20:01:55  阅读:362  来源: 互联网

标签:解读 wd MXNet weight self SignSGD lr grad momentum


原文:https://zhuanlan.zhihu.com/p/112346480

论文笔记:SIGNSGD: compressed optimisation for non-convex problems

这是一篇来自 Caltech,Amazon AI 和 UC Irvine 的文章。

名字非常的直白,方法也异常的简单(简单并不简单)。

总结起来就是:

SGD里面,梯度真正有用的是方向而不是大小。所以,即使你只保留梯度的符号来对模型进行更新,也能得到收敛的效果。甚至有些情况下,这么做能减少梯度的噪声,使得收敛速度更快。

根据上面的结论,进而衍生出了三种算法

SignSGD

直接把 gradient 求 sign

Signum

把 momentum 求 sign

SignMajorityVote

在 distributed training 下的应用

MXNet 实现

作者给出了 MXNet 的实现,并且这个优化器也被 MXNet 收录了。(估计是因为作者当时在 Amazon AI 实习,然后二组是 Yuxiang Wang,当时也在 Amazon AI 工作。)

mxnet.optimizer - Apache MXNet documentation

mxnet.optimizer.signum - Apache MXNet documentation

下面来一起看一下代码,关键部分我已经注释出来了。

其中函数 fused_step 的原理和 step 应该是一样的,只是 MXNet 为了提高效率而提出一种混合计算图的方法(效率比较高,但是不再是清晰的python代码了)。具体可以看这里,

MXNet Graph Optimization and Quantization based on subgraph and MKL-DNN

    # coding: utf-8
    # Licensed to the Apache Software Foundation (ASF) under one
    # or more contributor license agreements.  See the NOTICE file
    # distributed with this work for additional information
    # regarding copyright ownership.  The ASF licenses this file
    # to you under the Apache License, Version 2.0 (the
    # "License"); you may not use this file except in compliance
    # with the License.  You may obtain a copy of the License at
    #
    #   http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing,
    # software distributed under the License is distributed on an
    # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    # KIND, either express or implied.  See the License for the
    # specific language governing permissions and limitations
    # under the License.
    """Signum optimizer."""
    from __future__ import absolute_import
    from ..ndarray import (zeros, clip)
    from ..ndarray import (signsgd_update, signum_update)
    from .optimizer import Optimizer, register
    
    __all__ = ['Signum']
    
    
    @register
    class Signum(Optimizer):
        r"""The Signum optimizer that takes the sign of gradient or momentum.
    
        The optimizer updates the weight by::
    
            rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
            state = momentum * state + (1-momentum)*rescaled_grad
            weight = (1 - lr * wd_lh) * weight - lr * sign(state)
    
        References
        ----------
        Jeremy Bernstein, Yu-Xiang Wang, Kamyar Azizzadenesheli & Anima Anandkumar. (2018).
        signSGD: Compressed Optimisation for Non-Convex Problems. In ICML'18.
    
        See: https://arxiv.org/abs/1802.04434
    
        For details of the update algorithm see
        :class:`~mxnet.ndarray.signsgd_update` and :class:`~mxnet.ndarray.signum_update`.
    
        This optimizer accepts the following parameters in addition to those accepted
        by :class:`.Optimizer`.
    
        Parameters
        ----------
        learning_rate : float, default 0.01
            The initial learning rate. If None, the optimization will use the
            learning rate from ``lr_scheduler``. If not None, it will overwrite
            the learning rate in ``lr_scheduler``. If None and ``lr_scheduler``
            is also None, then it will be set to 0.01 by default.
        momentum : float, optional
           The momentum value.
        wd_lh : float, optional
           The amount of decoupled weight decay regularization, see details in the original paper at:\
           https://arxiv.org/abs/1711.05101
        use_fused_step : bool, default True
            Whether or not to use fused kernels for optimizer.
            When use_fused_step=False, step is called,
            otherwise, fused_step is called.
        """
        def __init__(self, learning_rate=0.01, momentum=0.9, wd_lh=0.0, use_fused_step=True, **kwargs):
            super(Signum, self).__init__(learning_rate=learning_rate,
                                         use_fused_step=use_fused_step,
                                         **kwargs)
            # 这两个量都是 float numbers
            self.momentum = momentum
            self.wd_lh = wd_lh
    
        def create_state(self, index, weight):
            momentum = None
            if self.momentum != 0.0:   # 如果有 momentum,否则直接返回 None
    		# 相当于 pytorch 里面的 zero_like, 
    		# 这个函数会为每个参数都 call 一遍,为每个参数初始化一个 momentum
                momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
            return momentum
    
    
        def step(self, indices, weights, grads, states):
            """Perform an optimization step using gradients and states.
    
             Parameters
             ----------
             indices : list of int
                 List of unique indices of the parameters into the individual learning rates
                 and weight decays. Learning rates and weight decay may be set via `set_lr_mult()`
                 and `set_wd_mult()`, respectively.
             weights : list of NDArray
                 List of parameters to be updated.
             grads : list of NDArray
                 List of gradients of the objective with respect to this parameter.
             states : List of any obj
                 List of state returned by `create_state()`.
             """
            for index, weight, grad, state in zip(indices, weights, grads, states):
                self._update_count(index)
                lr = self._get_lr(index)
                wd = self._get_wd(index)
    
                if state is not None:  # 如果有 momentum 的话,就是 signum
                    # preprocess grad
                    # rescaled_grad = rescale_grad * clip(grad, clip_gradient) 
                    #                 + wd * weight 
                    # 这个地方实际上是跟文档里面的公式不符的,但是不是很影响结果
                    grad *= self.rescale_grad
                    if self.clip_gradient is not None:
                        grad = clip(grad, - self.clip_gradient, self.clip_gradient)
                    grad += wd * weight
    
                    # update mom, 这里算的其实是 -momentum
                    mom = state
                    mom[:] *= self.momentum
                    mom[:] -= (1 - self.momentum) * grad
    
                    # update weight
                    weight[:] *= 1 - lr * self.wd_lh
                    weight[:] += lr * ((mom > 0) - (mom < 0))
                else:                 # 如果没有 momentum 的话,就是 signsgd
                    # update weight
                    weight[:] *= 1 - lr * (wd + self.wd_lh)
                    weight[:] -= lr * ((grad > 0) - (grad < 0))
    
    
        def fused_step(self, indices, weights, grads, states):
            """Perform a fused optimization step using gradients and states.
            Fused kernel is used for update.
    
            Parameters
            ----------
            indices : list of int
                List of unique indices of the parameters into the individual learning rates
                and weight decays. Learning rates and weight decay may be set via `set_lr_mult()`
                and `set_wd_mult()`, respectively.
            weights : list of NDArray
                List of parameters to be updated.
            grads : list of NDArray
                List of gradients of the objective with respect to this parameter.
            states : List of any obj
                List of state returned by `create_state()`.
            """
            for index, weight, grad, state in zip(indices, weights, grads, states):
                self._update_count(index)
                lr = self._get_lr(index)
                wd = self._get_wd(index)
    
                kwargs = {'rescale_grad': self.rescale_grad}
                if self.momentum > 0:
                    kwargs['momentum'] = self.momentum
                if self.clip_gradient:
                    kwargs['clip_gradient'] = self.clip_gradient
    
                # update weight with fused kernel
                if state is not None:
                    if self.wd_lh:
                        kwargs['wd_lh'] = self.wd_lh
                    signum_update(weight, grad, state, out=weight,
                                  lr=lr, wd=wd, **kwargs)
                else:
                    wd += self.wd_lh
                    signsgd_update(weight, grad, out=weight,
                                   lr=lr, wd=wd, **kwargs)

 

 

标签:解读,wd,MXNet,weight,self,SignSGD,lr,grad,momentum
来源: https://blog.csdn.net/jollyjumper/article/details/113572966

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有