ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

动手实现堆叠式CapsNet(上)

2020-12-04 23:57:26  阅读:196  来源: 互联网

标签:dim self CapsNet 堆叠 动手 objects input part operatorname


前言

本次动手实现论文《stacked capsule autoencoders》的pytorch版本。这篇论文的原作者开源了TensorFlow版本[1],其细节和工程性都挺不错,是个参考的好范本(做研究建议直接参考原项目)。关于pytorch的实现,github也开源了相关例子[2,3,4],但这些都只实现了原文第二个实验。本文是对其原文第一个实验的复现笔记,后续也计划复现第二个实验。

全部复现代码会开源在https://github.com/QiangZiBro/stacked_capsule_autoencoders.pytorch,欢迎提issue。

复现目标

第一个实验

  • Set Transformer (直接使用原论文代码)
  • CCAE模型
  • 高斯混合模型的编程实现
  • Concellation数据集生成
  • CCAE训练
  • 可视化CCAE

前期准备

环境

  • 系统:ubuntu 18.04.04
  • 显卡:GP100
  • 环境管理:miniconda3
  • 相关第三方库:pytorch1.7

为了保证工程性以及少点重复工作,我们基于一个深度学习模板项目来进行本次实现。当然,为了可解释性,也会使用notebook进行相关可视化。同时会写一些必备的test case,来帮助我更加了解一些细节。

git clone https://github.com/QiangZiBro/pytorch-template
cd pytorch-template
python new_project.py ../stacked_capsule_autoencoders.pytorch
cd ../stacked_capsule_autoencoders.pytorch
wget https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore -O .gitignore

模型细节

概览

CCAE编码器为Set Transformer,解码器为mlp的自编码器,其输入是2维平面上的点集。

我们接下来总结CCAE模型的细节

Set Transformer


Set Transformer的编码器可以是连续的SAB或者连续的ISAB。使用ISAB的优点是其使用了诱导点 I ∈ R m × d I \in \mathbb{R}^{m \times d} I∈Rm×d,使得计算参数比SAB更少。
Z = Encoder ⁡ ( X ) = SAB ⁡ ( SAB ⁡ ( X ) ) ∈ R n × d Z=\operatorname{Encoder}(X)=\operatorname{SAB}(\operatorname{SAB}(X)) \in \mathbb{R}^{n \times d} Z=Encoder(X)=SAB(SAB(X))∈Rn×d

Z = Encoder ⁡ ( X ) = ISAB ⁡ m ( ISAB ⁡ m ( X ) ) ∈ R n × d Z=\operatorname{Encoder}(X)=\operatorname{ISAB}_{m}\left(\operatorname{ISAB}_{m}(X)\right) \in \mathbb{R}^{n \times d} Z=Encoder(X)=ISABm​(ISABm​(X))∈Rn×d

解码器
O = Decoder ⁡ ( Z ; λ ) = rFF ⁡ ( SAB ⁡ ( PMA ⁡ k ( Z ) ) ) ∈ R k × d O=\operatorname{Decoder}(Z ; \lambda)=\operatorname{rFF}\left(\operatorname{SAB}\left(\operatorname{PMA}_{k}(Z)\right)\right) \in \mathbb{R}^{k \times d} O=Decoder(Z;λ)=rFF(SAB(PMAk​(Z)))∈Rk×d


细节部分

  • rFF ⁡ ( x ) \operatorname{rFF}(x) rFF(x) 全连接 ,具体的讲,输入 n × d n \times d n×d维,输出也是 n × d n \times d n×d维。
  • 注意力机制: Att ⁡ ( Q , K , V ; ω ) = ω ( Q K ⊤ ) V ∈ R n × d v \operatorname{Att}(Q, K, V ; \omega)=\omega\left(Q K^{\top}\right) V \in \mathbb{R}^{n \times d_v} Att(Q,K,V;ω)=ω(QK⊤)V∈Rn×dv​,其中, ω \omega ω是激活函数。
  • 多头注意力机制

Multihead ⁡ ( Q , K , V ; λ , ω ) = c o n c a t ( Z 1 , … , Z h ) ⁡ W O ∈ R n × d , Z j = Att ⁡ ( Q W j Q , K W j K , V W j V ; ω j ) ​ \operatorname{Multihead} (Q, K, V ; \lambda, \omega) = \operatorname{concat(Z_1,…,Z_h)}W^O \in \mathbb{R}^{n \times d},\\ Z_{j}=\operatorname{Att}\left(Q W_{j}^{Q}, K W_{j}^{K}, V W_{j}^{V} ; \omega_{j}\right)​ Multihead(Q,K,V;λ,ω)=concat(Z1​,…,Zh​)WO∈Rn×d,Zj​=Att(QWjQ​,KWjK​,VWjV​;ωj​)​

  • 多头注意力模块(Multihead Attention Block) 输出维度和X维度相同
    MAB ⁡ ( X , Y ) = LayerNorm ⁡ ( H + rFF ⁡ ( H ) ) \operatorname{MAB}(X, Y)=\operatorname{LayerNorm}(H+\operatorname{rFF}(H)) MAB(X,Y)=LayerNorm(H+rFF(H))
    其中 H =  LayerNorm  ( X + Multihead ⁡ ( X , Y , Y ; ω ) ) H=\text { LayerNorm }(X+\operatorname{Multihead}(X, Y, Y ; \omega)) H= LayerNorm (X+Multihead(X,Y,Y;ω))

  • 集合注意力模块(Set Attention Block ),计算复杂度 O ( n 2 ) \mathcal{O}\left(n^{2}\right) O(n2)

SAB ⁡ ( X ) = MAB ⁡ ( X , X ) \operatorname{SAB}(X) = \operatorname{MAB}(X,X) SAB(X)=MAB(X,X)

  • 诱导集合注意力模块(Induced Set Attention Block )

ISAB ⁡ m ( X ) = MAB ⁡ ( X , H ) ∈ R n × d ​ \operatorname{ISAB}_m(X)=\operatorname{MAB}(X, H) \in \mathbb{R}^{n \times d}​ ISABm​(X)=MAB(X,H)∈Rn×d​

​ 其中 H = MAB ⁡ ( I , X ) ∈ R m × d H=\operatorname{MAB}(I,X) \in \mathbb{R}^{m \times d} H=MAB(I,X)∈Rm×d, I ∈ R m × d I \in \mathbb{R}^{m \times d} I∈Rm×d为可学习参数。

  • 多头注意力机制的池化(Pooling by Multihead Attention)。池化是一种常见的聚合(aggregation)操作。上面提到,池化可以是最大或是平均。这里提出的池化是应用一个MAB在一个可学习的矩阵 S ∈ R k × d S \in \mathbb{R}^{k \times d} S∈Rk×d上。在一些聚类任务上, k k k设为我们需要的类别数。使用基于注意力的池化的直觉是,每个实例对target的重要性都不一样

PMA ⁡ k ( Z ) = MAB ⁡ ( S , rFF ⁡ ( Z ) ) \operatorname{PMA}_{k}(Z)=\operatorname{MAB}(S, \operatorname{rFF}(Z)) PMAk​(Z)=MAB(S,rFF(Z))

H = SAB ⁡ ( PMA ⁡ k ( Z ) ) H=\operatorname{SAB}\left(\operatorname{PMA}_{k}(Z)\right) H=SAB(PMAk​(Z))

其中池化操作 PMA ⁡ k ( Z ) = MAB ⁡ ( S , rFF ⁡ ( Z ) ) ∈ R k × d \operatorname{PMA}_{k}(Z)=\operatorname{MAB}(S, \operatorname{rFF}(Z)) \in \mathbb{R}^{k \times d} PMAk​(Z)=MAB(S,rFF(Z))∈Rk×d, k k k表示输出集合中实例的个数, k < n k < n k<n。

CCAE

对M个2维输入点组成的集合 x 1 : M \mathbf{x_{1:M}} x1:M​,首先使用Set Transformer将这个集合编码为 K K K个 ( 2 × 2 + n c + 1 ) (2\times 2+n_c+1) (2×2+nc​+1)的object向量,这三个数分别表示OV矩阵大小、特殊向量(即特征)、存在概率。特殊向量的尺度是个超参,原文 n c = 16 n_c=16 nc​=16。
O V 1 : K , c 1 : K , a 1 : K = h c a p s ( x 1 : M ) = SetTransformer ⁡ ( x 1 : M ) \mathrm{OV}_{1: K}, \mathbf{c}_{1: K}, a_{1: K}=\mathrm{h}^{\mathrm{caps}}\left(\mathbf{x}_{1: M}\right) = \operatorname{SetTransformer}\left(\mathbf{x}_{1: M}\right) OV1:K​,c1:K​,a1:K​=hcaps(x1:M​)=SetTransformer(x1:M​)
对每个object向量,取其特殊向量,通过mlp解码出 N N N个part。其中,每个part长度为 ( 2 + 1 + 1 ) (2+1+1) (2+1+1),分别为OP矩阵、存在概率、和标准差;每个object应用一个单独的mlp,mlp结构为 n c , 128 , ( 2 + 1 + 1 ) × N n_c,128,(2+1+1)\times N nc​,128,(2+1+1)×N。
O P k , 1 : N , a k , 1 : N , λ k , 1 : N = h k p a r t ( c k ) = m l p k ⁡ ( c k ) \mathrm{OP}_{k, 1: N}, a_{k, 1: N}, \lambda_{k, 1: N}=\mathrm{h}_{\mathrm{k}}^{\mathrm{part}}\left(\mathbf{c}_{k}\right) = \operatorname{mlp_k}\left(\mathbf{c}_{k}\right) OPk,1:N​,ak,1:N​,λk,1:N​=hkpart​(ck​)=mlpk​(ck​)
在原文例子中, M = 3 , N = 4 M=3, N=4 M=3,N=4。

每个解码出的part都可以表示一个高斯分量。CCAE处理的数据是2维平面点,因此表示的高斯分量的均值是2维,协方差矩阵大小是 2 × 2 2 \times 2 2×2的矩阵。具体的讲,由第 i i i个object产生的第 j j j个part表示的高斯分量均值为
μ k , n = O V k O P k , n \mu_{k,n} = OV_k OP_{k,n} μk,n​=OVk​OPk,n​
其中 O V k OV_k OVk​是 2 × 2 2 \times 2 2×2的矩阵, O P k , n OP_{k,n} OPk,n​是长度为2的向量。而part只有一个标量的标准差 λ k , n \lambda_{k,n} λk,n​,即,原文将一个高斯分量假设为各向同性,通过标准差 λ k , n \lambda_{k,n} λk,n​计算到高斯模型的协方差矩阵:
Σ k , n = [ 1 λ k , n 0 0 1 λ k , n ] \Sigma_{k,n} = \begin{bmatrix} \frac{1}{\lambda_{k,n}} & 0\\ 0 & \frac{1}{\lambda_{k,n}} \end{bmatrix} Σk,n​=[λk,n​1​0​0λk,n​1​​]
对于这个高斯分量的存在概率,表示为
π k , n = a k a k , n ∑ i a i ∑ j a i , j \pi_{k,n} = \frac{a_{k} a_{k, n}}{\sum_{i} a_{i} \sum_{j} a_{i, j}} πk,n​=∑i​ai​∑j​ai,j​ak​ak,n​​
因此,给定每个高斯模型的三个参数:均值,协方差,概率。可以得到给定数据分布在整个高斯混合模型上的估计为:
p ( x 1 : M ) = ∏ m = 1 M ∑ k = 1 K ∑ n = 1 N a k a k , n ∑ i a i ∑ j a i , j p ( x m ∣ k , n ) p\left(\mathbf{x}_{1: M}\right)=\prod_{m=1}^{M} \sum_{k=1}^{K} \sum_{n=1}^{N} \frac{a_{k} a_{k, n}}{\sum_{i} a_{i} \sum_{j} a_{i, j}} p\left(\mathbf{x}_{m} \mid k, n\right) p(x1:M​)=m=1∏M​k=1∑K​n=1∑N​∑i​ai​∑j​ai,j​ak​ak,n​​p(xm​∣k,n)
其中,点 x m \mathbf{x}_{m} xm​在第 i i i个object产生的第 j j j个part表示的高斯计算得到的似然值为
p ( x m ∣ k , n ) = p ( x m ∣ μ k , n , Σ k , n ) = 1 ( 2 π ) D / 2 ∣ Σ k , n ∣ 1 / 2 exp ⁡ ( 1 2 ( x m − μ k , n ) T Σ k , n − 1 ( x m − μ k , n ) ) p(\mathbf{x}_{m}|k,n) = p(\mathbf{x}_{m}|\mu_{k,n} ,\Sigma_{k,n}) = \frac{1}{(2 \pi)^{D/2} |\Sigma_{k,n}|^{1/2}}\operatorname{exp}\left(\frac{1}{2} \left(\mathbf{x}_{m}-\mu_{k,n} \right)^T \Sigma_{k,n}^{-1} \left(\mathbf{x}_{m}-\mu_{k,n} \right) \right) p(xm​∣k,n)=p(xm​∣μk,n​,Σk,n​)=(2π)D/2∣Σk,n​∣1/21​exp(21​(xm​−μk,n​)TΣk,n−1​(xm​−μk,n​))
最大化 p ( x 1 : M ) p\left(\mathbf{x}_{1: M}\right) p(x1:M​),求得 μ k , n , λ k , n , π k , n \mu_{k,n},\lambda_{k,n},\pi_{k,n} μk,n​,λk,n​,πk,n​,在理论上即可得到表示这个数据分布的模型。原文使用反向传播优化参数,目标是最大化 log ⁡ p ( x 1 : M ) \operatorname{log }p\left(\mathbf{x}_{1: M}\right) logp(x1:M​),等价于最小化 − log ⁡ p ( x 1 : M ) -\operatorname{log }p\left(\mathbf{x}_{1: M}\right) −logp(x1:M​)。

数据集

数据(3个集群,两个正方形,一个三角形)是在线创建的,每一次创建后被随机平移、放缩、旋转到180度,最后所有点被标准化到-1到1之间。

决策

依据object $ a_{k}$和其概率最高的part a k , n a_{k,n} ak,n​,对每个点 x m x_m xm​,其类别决策为 k ⋆ = arg ⁡ max ⁡ k a k a k , n p ( x m ∣ k , n ) k^{\star}=\arg \max _{k} a_{k} a_{k, n} p\left(\mathbf{x}_{m} \mid k, n\right) k⋆=argmaxk​ak​ak,n​p(xm​∣k,n)。

一些实现细节

  • 多维矩阵 首先要明白将所有的part、object放在一个矩阵里,每个维度的含义。笔者设定:part为(B, n_objects, n_votes, (dim_input+1+1)),object为(B, n_objects, dim_input**2+dim_speical_features+1)。搞定这些,之后可以进行矩阵拆分,对应到原论文对应的变量里。

  • BatchMLP 在计算object到part的解码时用到。每个object capsule需要一个单独的MLP来解码到对应的part capsule,也就是说,输入的object维度为[B, n_objects, n_special_features],被多个MLP计算得到结果应该是(B, n_objects, n_votes*(dim_input+1+1))。pytorch里面只有单个的MLP,我们类似原作者也实现了个BatchMLP来完成这个功能。

  • 对概率的处理 对预测的 a k a_k ak​和 a k , n a_{k,n} ak,n​使用softmax等函数进行处理,对预测的标准差加上一个 ϵ = 1 0 − 6 \epsilon=10^{-6} ϵ=10−6防止分母为0.

代码部分

Set Transformer

关于Set Transformer的实现如下,笔者做了相关注释,具体每个模块实现这里不贴。简而言之,这个编码器将(B, N, dim_input)的输入转化为(B, num_outputs, dim_output)的输出。

import torch.nn as nn
from base import BaseModel
from model.modules.setmodules import ISAB,SAB,PMA


class SetTransformer(BaseModel):
    """
    """

    def __init__(self, dim_input, num_outputs, dim_output,
            num_inds=32, dim_hidden=128, num_heads=4, ln=True):
        """Set Transformer, An autoencoder model dealing with set data

        Input set X with N elements, each `dim_input` dimensions, output
        `num_outputs` elements, each `dim_output` dimensions.

        In short, choose:
        N --> num_outputs
        dim_input --> dim_output

        Hyper-parameters:
            num_inds
            dim_hidden
            num_heads
            ln
        Args:
            dim_input: Number of dimensions of one elem in input set X
            num_outputs: Number of output elements
            dim_output: Number of dimensions of one elem in output set
            num_inds: inducing points number
            dim_hidden: output dimension of one elem of middle layer
            num_heads: heads number of multi-heads attention in MAB
            ln: whether to use layer norm in MAB
        """
        super(SetTransformer, self).__init__()
        self.enc = nn.Sequential(
                ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln),
                ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
        self.dec = nn.Sequential(
                PMA(dim_hidden, num_heads, num_outputs, ln=ln),
                SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
                SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
                nn.Linear(dim_hidden, dim_output))

    def forward(self, X):
        """
        Args:
            X: (B, N, dim_input)

        Returns:
            output set with shape (B, num_outputs, dim_output)
        """
        return self.dec(self.enc(X))

CCAE

编码器核心部分如下,可以看到可以类似17版那样用矢量表示胶囊,不过这里每个胶囊用三种不同意义的变量表示,因此后续处理也不同。

objects = self.set_transformer(x)  # (B, n_objects, dim_input**2+dim_speical_features+1)
splits = [self.dim_input**2,self.dim_input**2+self.dim_speical_features]
ov_matrix,special_features,presence=objects[:,:,:splits[0]],objects[:,:,splits[0]:splits[1]],objects[:,:,splits[1]:]

ov_matrix = ov_matrix.reshape(B, self.n_objects, self.dim_input, self.dim_input)
presence = F.softmax(presence, dim=1)

解码器,注意到这里使用了一个BatchMLP,即使用多个MLP对每个object的特殊向量进行解码,每个object都可以解码出若干个part。


x = self.bmlp(x) # (B, n_objects, n_votes*(dim_input+1+1))
x_chunk = x.chunk(self.n_votes, dim=-1)
x_object_part = torch.stack(x_chunk, dim=2) # (B, n_objects, n_votes, (dim_input+1+1))

splits = [self.dim_input, self.dim_input+1]
op_matrix = x_object_part[:,:,:,:splits[0]]
standard_deviation = x_object_part[:,:,:,splits[0]:splits[1]]
presence = x_object_part[:,:,:,splits[1]:]
presence = F.softmax(presence, dim=2)

使用无监督的决策方式,参考上文原理部分

# (B, 1, n_objects, 1)
object_presence = res_dict.object_presence[:, None, ...]
# (B, 1, n_objects, n_votes)
part_presence = res_dict.part_presence[:, None, ...].squeeze(-1)
# (B, M, n_objects, n_votes)
likelihood = res_dict.likelihood
a_k_n_times_p = (part_presence * likelihood).max(dim=-1, keepdim=True)[0]
expr = object_presence * a_k_n_times_p
winners = expr.max(dim=-2)[1].squeeze(-1)

数据集

这里直接复用了原本数据生成代码,搭建了一个Dataloader

class CCAE_Dataloader(BaseDataLoader):
    def __init__(self,
            # for dataloader
            batch_size,
            shuffle=True,
            validation_split=0.0,
            num_workers=1,

            # for dataset
            shuffle_corners=True,
            gaussian_noise=0.,
            max_translation=1.,
            rotation_percent=0.0,
            which_patterns='basic',
            drop_prob=0.0,
            max_scale=3.,
            min_scale=.1
        ):
        self.dataset = CCAE_Dataset(
            shuffle_corners,
            gaussian_noise,
            max_translation,
            rotation_percent,
            which_patterns,
            drop_prob,
            max_scale,
            min_scale
        )
        super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)
        
class CCAE_Dataset(data.Dataset):
    def __init__(self,
            shuffle_corners=True,
            gaussian_noise=0.,
            max_translation=1.,
            rotation_percent=0.0,
            which_patterns='basic',
            drop_prob=0.0,
            max_scale=3.,
            min_scale=.1
        ):
        self.shuffle_corners = shuffle_corners
        self.scale = max_scale
        self.gaussian_noise = gaussian_noise
        self.max_translation = max_translation
        self.rotation_percent = rotation_percent
        self.which_patterns = which_patterns
        self.drop_prob = drop_prob

    def __len__(self):
        return 10000
    def __getitem__(self, item):
        data = create_numpy(
            1,
            self.shuffle_corners,
            self.gaussian_noise,
            self.max_translation,
            self.rotation_percent,
            self.scale,
            self.which_patterns,
            self.drop_prob)
        return data

损失

损失函数的计算方式如下

def ccae_loss(res_dict, target, epsilon = 1e-6):
    """
    
    Args:
        res_dict:
        target: input set with (B, k, dim_input)
        epsilon: avoiding nan for reciprocal of standard deviation
    Returns:
        log likelihood for input dataset(here "target") , (B,)
    """
    # retrieve the variable (Sorry for possible complication)
    op_matrix = res_dict.op_matrix # (B, n_objects, n_votes, dim_input)
    ov_matrix = res_dict.ov_matrix # (B, n_objects, dim_input, dim_input)
    standard_deviation = res_dict.standard_deviation # (B, n_objects, n_votes, 1)
    object_presence = res_dict.object_presence # (B, n_objects, 1)
    part_presence = res_dict.part_presence  # (B, n_objects, n_votes, 1)
    dim_input = res_dict.dim_input
    B, n_objects, n_votes, _ = standard_deviation.shape
    op_matrix = op_matrix[:,:,:,:,None] # (B, n_objects, n_votes, dim_input,1)
    ov_matrix = ov_matrix[:,:,None,:,:] # (B, n_objects, 1, dim_input,dim_input)

    # 防止分母为0
    standard_deviation = epsilon + standard_deviation[Ellipsis, None]
    # 计算mu
    mu = ov_matrix @ op_matrix # (B, n_objects, n_votes, dim_input,1)
    # 计算协方差
    identity = torch.eye(dim_input).repeat(B, n_objects, n_votes, 1, 1).to(standard_deviation.device)
    sigma = identity * (1/standard_deviation) # (B, n_objects, n_votes, dim_input,dim_input)

    # 计算数据集(即target)在混合模型上的似然估计
    # (B, k, n_objects, n_votes)
    gaussian_likelihood = gmm(mu, sigma).likelihood(target, object_presence=object_presence, part_presence=part_presence)

    # 计算似然估计的对数,作为损失目标
    log_likelihood = torch.log(gaussian_likelihood.sum((1,2,3))).mean()
    gaussian_likelihood = gaussian_likelihood.mean()
    res_dict.likelihood = -gaussian_likelihood
    res_dict.log_likelihood = -log_likelihood



    return res_dict

笔者又实现了一个高斯混合模型类来计算似然值,下面是计算损失的核心代码。

mu = ov_matrix @ op_matrix  # (B, n_objects, n_votes, dim_input,1)
identity = (
  torch.eye(dim_input)
  .repeat(B, n_objects, n_votes, 1, 1)
  .to(standard_deviation.device)
)
sigma = identity * (
  1 / standard_deviation
)  # (B, n_objects, n_votes, dim_input,dim_input)

# (B, k, n_objects, n_votes)
likelihood = gmm(mu, sigma).likelihood(
  target, object_presence=object_presence, part_presence=part_presence
)
log_likelihood = torch.log(likelihood.sum((1, 2, 3))).mean()

后续思考,这个损失函数有点写复杂了,直接在model里算好就不需要这么多代码了。

高斯混合模型的核心实现

class GuassianMixture(object):
    """
    GMM for part capsules
    """

    def __init__(self, mu, sigma):
        """
        Args:
            mu: (B, n_objects, n_votes, dim_input, 1)
            sigma: (B, n_objects, n_votes, dim_input, dim_input)

        After initialized:
            mu:   (B, 1, n_objects, n_votes, dim_input, 1)
            sigma:(B, 1, n_objects, n_votes, dim_input,dim_input)
            multiplier:(B, 1, n_objects, n_votes, 1, 1)
        """
        #  Converse shape to
        #  (Batch_size, num_of_points, num_of_objects, number_of_votes, ...)

        mu = mu[:, None, ...]  # (B, 1, n_objects, n_votes, dim_input, 1)
        sigma = sigma[:, None, ...]  # (B, 1, n_objects, n_votes, dim_input,dim_input)

        self.sigma = sigma
        self.mu = mu
        self.sigma_inv = sigma.inverse()
        D = self.sigma.shape[-1]
        sigma_det = torch.det(sigma)  # (B, 1, n_objects, n_votes)
        self.multiplier = (
            1 / ((2 * math.pi) ** (D / 2) * sigma_det.sqrt())[..., None, None]
        )

    def likelihood(self, x, object_presence=None, part_presence=None):
        diff = x - self.mu
        exp_result = torch.exp(-0.5 * diff.transpose(-1, -2) @ self.sigma_inv @ diff)

        denominator = object_presence.sum(dim=2, keepdim=True) * part_presence.sum(
          dim=3, keepdim=True
        )
        exp_result = (object_presence * part_presence / denominator) * exp_result
        gaussian_likelihood = self.multiplier * exp_result
        return gaussian_likelihood.squeeze(-1).squeeze(-1)

    def plot(self, choose):
        raise NotImplemented

目前的效果

  • 正确分类

原数据

image-20201204231106857

无监督分类结果

image-20201204231147582

  • 错误分类

image-20201204231236948

总结

本文使用pytorch实现了原论文第一个toy experiment,做了一个简单的展示,损失使用的是
p ( x 1 : M ) = ∏ m = 1 M ∑ k = 1 K ∑ n = 1 N a k a k , n ∑ i a i ∑ j a i , j p ( x m ∣ k , n ) p\left(\mathbf{x}_{1: M}\right)=\prod_{m=1}^{M} \sum_{k=1}^{K} \sum_{n=1}^{N} \frac{a_{k} a_{k, n}}{\sum_{i} a_{i} \sum_{j} a_{i, j}} p\left(\mathbf{x}_{m} \mid k, n\right) p(x1:M​)=m=1∏M​k=1∑K​n=1∑N​∑i​ai​∑j​ai,j​ak​ak,n​​p(xm​∣k,n)
未使用原文提出的sparsity loss。

工程方面

  • 参数传递的过程中,形状为1的维度应该压缩掉

实验方面

  • 写了个重大BUG:BatchMLP忘记使用激活,梯度变为nan,还是对激活函数的理解程度不深,写MLP竟然忘记带了
  • 还需要对无监督效果进行评估,

TODO

  • 实现无监督评估方法

  • 尝试用这个模型做指导性学习

参考资料

[1] https://github.com/google-research/google-research/tree/master/stacked_capsule_autoencoders

[2] https://github.com/phanideepgampa/stacked-capsule-networks

[3] https://github.com/MuhammadMomin93/Stacked-Capsule-Autoencoders-PyTorch

[4] https://github.com/Axquaris/StackedCapsuleAutoencoders

[5] Fitting a generative model using standard divergences between measures http://www.math.ens.fr/~feydy/Teaching/DataScience/fitting_a_generative_model.html

标签:dim,self,CapsNet,堆叠,动手,objects,input,part,operatorname
来源: https://blog.csdn.net/Qiang_brother/article/details/110675073

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有