
Understanding Shapes in PyTorch Distributions Package

Reposted on 2021-04-28 10:02:43



Reposted from: https://bochang.me/blog/posts/pytorch-distributions/


The torch.distributions package implements various probability distributions, as well as methods for sampling and computing statistics. It generally follows the design of the TensorFlow distributions package (Dillon et al. 2017).

There are three types of "shapes" that are crucial to understanding the torch.distributions package: sample shape, batch shape, and event shape. The same definitions of shapes are also used in other packages, including GluonTS and Pyro.

In this blog post, we describe the different types of shapes and illustrate the differences among them with code examples. We also answer a few common questions about shapes in torch.distributions. All code examples are compatible with PyTorch v1.3.0.

Three types of shapes

The three types of shapes are defined as follows and illustrated in Figure 1.

  1. Sample shape describes independent, identically distributed draws from the distribution.
  2. Batch shape describes independent, but not identically distributed, draws. That is, we may have a set of different parameterizations of the same family of distributions. This enables the common machine-learning use case of a batch of examples, each modeled by its own distribution.
  3. Event shape describes the shape of a single draw (the event space) from the distribution; the dimensions within a single event may be dependent.

[Figure 1: Three groups of shapes (shape semantics). Reproduced from Dillon et al. (2017).]

The definitions above can be hard to grasp in the abstract. We take the Monte Carlo estimate of the evidence lower bound (ELBO) in the variational autoencoder (VAE) as an example to illustrate their differences. The average ELBO over a batch of $b$ observations, $x_i$ for $i = 1, \ldots, b$, is

$$\mathcal{L} = \frac{1}{b} \sum_{i=1}^{b} \mathbb{E}_{q(z \mid x_i)} \log \left[ \frac{p(x_i \mid z)\, p(z)}{q(z \mid x_i)} \right],$$

where $z \in \mathbb{R}^s$ is an $s$-dimensional latent vector. The ELBO can be estimated with Monte Carlo samples; specifically, for each $x_i$, $n$ samples are drawn, $z_{ij} \overset{\text{i.i.d.}}{\sim} q(z \mid x_i)$ for $j = 1, \ldots, n$. The estimate is then

$$\hat{\mathcal{L}} = \frac{1}{bn} \sum_{i=1}^{b} \sum_{j=1}^{n} \log \left[ \frac{p(x_i \mid z_{ij})\, p(z_{ij})}{q(z_{ij} \mid x_i)} \right].$$

All the Monte Carlo samples $z_{ij}$ can be compactly represented as a tensor of shape $(n, b, s)$, which corresponds to [sample_shape, batch_shape, event_shape].
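To make this correspondence concrete, here is a minimal sketch of computing $\hat{\mathcal{L}}$ with torch.distributions. Everything model-specific is a hypothetical choice for illustration: the random encoder outputs mu and sigma, the standard normal prior, the linear-Gaussian decoder with weight matrix W, and the sizes b, n, s, d. Independent (discussed later in this post) turns the per-coordinate Normal into a distribution over $s$-dimensional vectors.

```python
import torch
from torch.distributions import Independent, Normal

torch.manual_seed(0)
b, n, s, d = 4, 3, 2, 5  # batch size, MC samples per x_i, latent dim, data dim (illustrative)

x = torch.randn(b, d)  # hypothetical batch of observations x_1, ..., x_b

# Hypothetical encoder output: mean and std of q(z | x_i) for each i.
mu, sigma = torch.randn(b, s), torch.rand(b, s) + 0.1
q = Independent(Normal(mu, sigma), 1)  # batch_shape [b], event_shape [s]
prior = Independent(Normal(torch.zeros(s), torch.ones(s)), 1)  # batch [], event [s]

z = q.rsample(torch.Size([n]))  # (n, b, s) = sample + batch + event

# Hypothetical linear-Gaussian decoder p(x | z) = N(W z, I).
W = torch.randn(d, s)
likelihood = Independent(Normal(z @ W.t(), 1.0), 1)  # batch [n, b], event [d]

log_px = likelihood.log_prob(x)  # x broadcasts over the sample dim: (n, b)
log_pz = prior.log_prob(z)       # prior broadcasts over (n, b): (n, b)
log_qz = q.log_prob(z)           # event dim summed out: (n, b)

# Averaging over all n * b terms gives the estimate of the ELBO.
elbo_hat = (log_px + log_pz - log_qz).mean()
print(z.shape, log_px.shape, elbo_hat.shape)
```

Note that every log_prob call returns a tensor of shape sample_shape + batch_shape, so the three terms line up elementwise before averaging.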

We also provide mathematical notation for a few combinations of shapes in Table 1, for Gaussian random variables/vectors. Ignoring subscripts, $X$ represents a random variable, and $\mu$ and $\sigma$ are scalars; $\mathbf{X}$ denotes a two-dimensional random vector, $\boldsymbol{\mu}$ is a two-dimensional vector, and $\boldsymbol{\Sigma}$ is a $2 \times 2$ (not necessarily diagonal) matrix. Moreover, [] and [2] denote torch.Size([]) and torch.Size([2]), respectively. For each row, the corresponding PyTorch code example is listed under the same row number in the PyTorch code examples section.

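| no. | sample shape | batch shape | event shape | mathematical notation | code |
|---|---|---|---|---|---|
| 1 | [] | [] | [] | $X \sim \mathcal{N}(\mu, \sigma^2)$ | Row 1 |
| 2 | [2] | [] | [] | $X_1, X_2 \overset{\text{i.i.d.}}{\sim} \mathcal{N}(\mu, \sigma^2)$ | Row 2 |
| 3 | [] | [2] | [] | $X_1 \sim \mathcal{N}(\mu_1, \sigma_1^2)$; $X_2 \sim \mathcal{N}(\mu_2, \sigma_2^2)$ | Row 3 |
| 4 | [] | [] | [2] | $\mathbf{X} \sim \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\Sigma})$ | Row 4 |
| 5 | [] | [2] | [2] | $\mathbf{X}_1 \sim \mathcal{N}(\boldsymbol{\mu}_1, \boldsymbol{\Sigma}_1)$; $\mathbf{X}_2 \sim \mathcal{N}(\boldsymbol{\mu}_2, \boldsymbol{\Sigma}_2)$ | Row 5 |
| 6 | [2] | [] | [2] | $\mathbf{X}_1, \mathbf{X}_2 \overset{\text{i.i.d.}}{\sim} \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\Sigma})$ | Row 6 |
| 7 | [2] | [2] | [] | $X_{11}, X_{12} \overset{\text{i.i.d.}}{\sim} \mathcal{N}(\mu_1, \sigma_1^2)$; $X_{21}, X_{22} \overset{\text{i.i.d.}}{\sim} \mathcal{N}(\mu_2, \sigma_2^2)$ | Row 7 |
| 8 | [2] | [2] | [2] | $\mathbf{X}_{11}, \mathbf{X}_{12} \overset{\text{i.i.d.}}{\sim} \mathcal{N}(\boldsymbol{\mu}_1, \boldsymbol{\Sigma}_1)$; $\mathbf{X}_{21}, \mathbf{X}_{22} \overset{\text{i.i.d.}}{\sim} \mathcal{N}(\boldsymbol{\mu}_2, \boldsymbol{\Sigma}_2)$ | Row 8 |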

Table 1: Examples of various combinations of shapes.

This table is adapted from this blog post; you might find the visualization in that post helpful.

Every Distribution class has instance attributes batch_shape and event_shape. Furthermore, each class also has a method .sample, which takes sample_shape as an argument and generates samples from the distribution. Note that sample_shape is not an instance attribute because, conceptually, it is not associated with a distribution.
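The three shapes combine in a fixed way: a call to .sample(sample_shape) returns a tensor of shape sample_shape + batch_shape + event_shape, and .log_prob reduces only the event dimensions. A quick check, with shapes chosen arbitrarily for illustration:

```python
import torch
from torch.distributions.multivariate_normal import MultivariateNormal

# batch_shape [3] (three batch elements), event_shape [2].
dist = MultivariateNormal(torch.zeros(3, 2), torch.eye(2))
sample_shape = torch.Size([4])

x = dist.sample(sample_shape)
# A sample always has shape sample_shape + batch_shape + event_shape ...
print(x.shape)                 # torch.Size([4, 3, 2])
# ... while log_prob returns one value per (sample, batch) pair.
print(dist.log_prob(x).shape)  # torch.Size([4, 3])
```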

What is the difference between Normal and MultivariateNormal?

There are two distribution classes that correspond to normal distributions: the univariate normal

torch.distributions.normal.Normal(loc, scale, validate_args=None)

and the multivariate normal

torch.distributions.multivariate_normal.MultivariateNormal(loc, 
covariance_matrix=None, precision_matrix=None, scale_tril=None, 
validate_args=None)

Since the Normal class represents univariate normal distributions, the event_shape of a Normal instance is always torch.Size([]). Even if loc or scale are higher-order tensors, all of their dimensions go into batch_shape. For example,

>>> normal = Normal(torch.randn(5, 3, 2), torch.ones(5, 3, 2))
>>> (normal.batch_shape, normal.event_shape)
(torch.Size([5, 3, 2]), torch.Size([]))

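In contrast, for MultivariateNormal, the batch_shape and event_shape are inferred from the shapes of loc and covariance_matrix: the event_shape is the size of the last dimension, and the batch_shape is obtained by broadcasting the remaining leading dimensions (all but the last dimension of loc, all but the last two of covariance_matrix). In the following example, the covariance matrix torch.eye(2) is a 2×2 matrix and loc has shape [5, 3, 2], so the event_shape is [2] and the batch_shape is [5, 3].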

>>> mvn = MultivariateNormal(torch.randn(5, 3, 2), torch.eye(2))
>>> (mvn.batch_shape, mvn.event_shape)
(torch.Size([5, 3]), torch.Size([2]))
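The batch dimensions need not come from loc. As an illustration (the shapes are arbitrary), an unbatched loc combined with a stack of 5 × 3 identity covariance matrices yields the same batch_shape, because the leading dimensions of covariance_matrix are broadcast against those of loc:

```python
import torch
from torch.distributions.multivariate_normal import MultivariateNormal

# loc has no batch dimensions; the covariance tensor stacks 5 x 3 copies
# of a 2 x 2 identity matrix, giving shape (5, 3, 2, 2).
cov = torch.eye(2).repeat(5, 3, 1, 1)
mvn = MultivariateNormal(torch.zeros(2), cov)
print(mvn.batch_shape, mvn.event_shape)  # torch.Size([5, 3]) torch.Size([2])
```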

What does .expand do?

Every Distribution instance has an .expand method. Its docstring is as follows:

class Distribution(object):

    def expand(self, batch_shape, _instance=None):
        """
        Returns a new distribution instance (or populates an existing instance
        provided by a derived class) with batch dimensions expanded to
        `batch_shape`. This method calls :class:`~torch.Tensor.expand` on
        the distribution's parameters. As such, this does not allocate new
        memory for the expanded distribution instance. Additionally,
        this does not repeat any args checking or parameter broadcasting in
        `__init__.py`, when an instance is first created.

        Args:
            batch_shape (torch.Size): the desired expanded size.
            _instance: new instance provided by subclasses that
                need to override `.expand`.

        Returns:
            New distribution instance with batch dimensions expanded to
            `batch_size`.
        """
        raise NotImplementedError

It essentially creates a new distribution instance by expanding the batch_shape. For example, if we define a MultivariateNormal instance with a batch_shape of [] and an event_shape of [2],

>>> mvn = MultivariateNormal(torch.randn(2), torch.eye(2))
>>> (mvn.batch_shape, mvn.event_shape)
(torch.Size([]), torch.Size([2]))

it can be expanded to a new instance that has a batch_shape of [5]:

>>> new_batch_shape = torch.Size([5])
>>> expanded_mvn = mvn.expand(new_batch_shape)
>>> (expanded_mvn.batch_shape, expanded_mvn.event_shape)
(torch.Size([5]), torch.Size([2]))

Note that .expand does not allocate new memory: the parameters of the new instance are expanded views of the original ones, so every batch element has the same location parameter.

>>> expanded_mvn.loc
tensor([[-2.2299,  0.0122],
        [-2.2299,  0.0122],
        [-2.2299,  0.0122],
        [-2.2299,  0.0122],
        [-2.2299,  0.0122]])

Compare this with the following instance, which has the same batch_shape of [5] and event_shape of [2], but whose batch elements have different location parameters.

>>> batched_mvn = MultivariateNormal(torch.randn(5, 2), torch.eye(2))
>>> (batched_mvn.batch_shape, batched_mvn.event_shape)
(torch.Size([5]), torch.Size([2]))
>>> batched_mvn.loc
tensor([[-0.3935, -0.7844],
        [ 0.3310,  0.9311],
        [-0.8141, -0.2252],
        [ 2.4199, -0.5444],
        [ 0.5586,  1.0157]])
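To see that .expand shares memory rather than copying it, we can inspect the expanded loc directly. This is a quick check, assuming the view-based behavior described in the docstring above: the expanded tensor points to the same storage and has a stride of 0 along the new batch dimension. Sharing parameters does not mean sharing draws, though; each batch element is still sampled independently.

```python
import torch
from torch.distributions.multivariate_normal import MultivariateNormal

mvn = MultivariateNormal(torch.randn(2), torch.eye(2))
expanded_mvn = mvn.expand(torch.Size([5]))

# The expanded loc is a view of the original loc: same storage, and a
# stride of 0 along the new batch dimension (no data is copied).
print(expanded_mvn.loc.data_ptr() == mvn.loc.data_ptr())  # True
print(expanded_mvn.loc.stride())                           # (0, 1)

# Each of the 5 batch elements still gets its own independent draw.
x = expanded_mvn.sample()
print(x.shape)  # torch.Size([5, 2])
```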

What is the Independent class?

The Independent class does not introduce a new probability distribution. Instead, it creates a new distribution instance by "reinterpreting" some of the batch dimensions of an existing distribution as event dimensions.

torch.distributions.independent.Independent(base_distribution, 
reinterpreted_batch_ndims, validate_args=None)

The first argument, base_distribution, is self-explanatory; the second argument, reinterpreted_batch_ndims, is the number of batch dimensions (counted from the right) to be reinterpreted as event dimensions.

The usage of the Independent class can be illustrated by the following example. We start with a Normal instance with a batch_shape of [5, 3, 2] and an event_shape of [].

>>> loc = torch.zeros(5, 3, 2)
>>> scale = torch.ones(2)
>>> normal = Normal(loc, scale)
>>> [normal.batch_shape, normal.event_shape]
[torch.Size([5, 3, 2]), torch.Size([])]

An Independent instance can reinterpret the last batch dimension as an event dimension. As a result, the new batch_shape is [5, 3], and the event_shape becomes [2].

>>> normal_ind_1 = Independent(normal, 1)
>>> [normal_ind_1.batch_shape, normal_ind_1.event_shape]
[torch.Size([5, 3]), torch.Size([2])]

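The instance normal_ind_1 represents the same distribution as the following MultivariateNormal instance with a diagonal covariance matrix. Since Normal is parameterized by standard deviations while MultivariateNormal takes a covariance matrix, the diagonal holds scale ** 2 (which equals scale here only because scale is all ones):

>>> mvn = MultivariateNormal(loc, torch.diag(scale ** 2))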
>>> [mvn.batch_shape, mvn.event_shape]
[torch.Size([5, 3]), torch.Size([2])]
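The practical difference shows up in log_prob: Normal returns one log density per scalar, whereas Independent sums over the reinterpreted dimensions, matching MultivariateNormal. A short check, where the non-unit scale values are an arbitrary choice made so that the variance/standard-deviation distinction actually matters:

```python
import torch
from torch.distributions.normal import Normal
from torch.distributions.independent import Independent
from torch.distributions.multivariate_normal import MultivariateNormal

loc = torch.zeros(5, 3, 2)
scale = torch.tensor([0.5, 2.0])  # illustrative, non-unit standard deviations

normal = Normal(loc, scale)
normal_ind_1 = Independent(normal, 1)
# Normal takes standard deviations; MultivariateNormal takes a covariance matrix.
mvn = MultivariateNormal(loc, torch.diag(scale ** 2))

x = normal.sample()                    # shape (5, 3, 2)
print(normal.log_prob(x).shape)        # torch.Size([5, 3, 2]): one value per scalar
print(normal_ind_1.log_prob(x).shape)  # torch.Size([5, 3]): summed over the event dim
print(torch.allclose(normal_ind_1.log_prob(x), mvn.log_prob(x), atol=1e-5))  # True
```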

We can further reinterpret more batch dimensions as event dimensions:

>>> normal_ind_1_ind_1 = Independent(normal_ind_1, 1)
>>> [normal_ind_1_ind_1.batch_shape, normal_ind_1_ind_1.event_shape]
[torch.Size([5]), torch.Size([3, 2])]

or equivalently,

>>> normal_ind_2 = Independent(normal, 2)
>>> [normal_ind_2.batch_shape, normal_ind_2.event_shape]
[torch.Size([5]), torch.Size([3, 2])]
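The two constructions are equivalent not only in their shapes but also in their densities. As a sanity check (the sample_shape of [4] is arbitrary), both sum the base log densities over the last two dimensions and leave sample_shape + batch_shape:

```python
import torch
from torch.distributions.normal import Normal
from torch.distributions.independent import Independent

normal = Normal(torch.zeros(5, 3, 2), torch.ones(2))
normal_ind_1_ind_1 = Independent(Independent(normal, 1), 1)
normal_ind_2 = Independent(normal, 2)

x = normal.sample(torch.Size([4]))  # shape (4, 5, 3, 2)
lp_nested = normal_ind_1_ind_1.log_prob(x)
lp_direct = normal_ind_2.log_prob(x)
# Both reduce over the event dims [3, 2], leaving shape (4, 5).
print(lp_nested.shape, torch.allclose(lp_nested, lp_direct))
```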

PyTorch code examples

This section provides a code example for each row of Table 1. The following import statements are needed for the examples.

import torch
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal

Row 1: [], [], []

>>> dist = Normal(0.0, 1.0)
>>> sample_shape = torch.Size([])
>>> dist.sample(sample_shape)
tensor(-1.3349)
>>> (sample_shape, dist.batch_shape, dist.event_shape)
(torch.Size([]), torch.Size([]), torch.Size([]))


Row 2: [2], [], []

>>> dist = Normal(0.0, 1.0)
>>> sample_shape = torch.Size([2])
>>> dist.sample(sample_shape)
tensor([ 0.2786, -1.4113])
>>> (sample_shape, dist.batch_shape, dist.event_shape)
(torch.Size([2]), torch.Size([]), torch.Size([]))


Row 3: [], [2], []

>>> dist = Normal(torch.zeros(2), torch.ones(2))
>>> sample_shape = torch.Size([])
>>> dist.sample(sample_shape)
tensor([0.0101, 0.6976])
>>> (sample_shape, dist.batch_shape, dist.event_shape)
(torch.Size([]), torch.Size([2]), torch.Size([]))


Row 4: [], [], [2]

>>> dist = MultivariateNormal(torch.zeros(2), torch.eye(2))
>>> sample_shape = torch.Size([])
>>> dist.sample(sample_shape)
tensor([ 0.2880, -1.6795])
>>> (sample_shape, dist.batch_shape, dist.event_shape)
(torch.Size([]), torch.Size([]), torch.Size([2]))


Row 5: [], [2], [2]

>>> dist = MultivariateNormal(torch.zeros(2, 2), torch.eye(2))
>>> sample_shape = torch.Size([])
>>> dist.sample(sample_shape)
tensor([[-0.4703,  0.4152],
        [-1.6471, -0.6276]])
>>> (sample_shape, dist.batch_shape, dist.event_shape)
(torch.Size([]), torch.Size([2]), torch.Size([2]))


Row 6: [2], [], [2]

>>> dist = MultivariateNormal(torch.zeros(2), torch.eye(2))
>>> sample_shape = torch.Size([2])
>>> dist.sample(sample_shape)
tensor([[ 2.2040, -0.7195],
        [-0.4787,  0.0895]])
>>> (sample_shape, dist.batch_shape, dist.event_shape)
(torch.Size([2]), torch.Size([]), torch.Size([2]))


Row 7: [2], [2], []

>>> dist = Normal(torch.zeros(2), torch.ones(2))
>>> sample_shape = torch.Size([2])
>>> dist.sample(sample_shape)
tensor([[ 0.2639,  0.9083],
        [-0.7536,  0.5296]])
>>> (sample_shape, dist.batch_shape, dist.event_shape)
(torch.Size([2]), torch.Size([2]), torch.Size([]))


Row 8: [2], [2], [2]

>>> dist = MultivariateNormal(torch.zeros(2, 2), torch.eye(2))
>>> sample_shape = torch.Size([2])
>>> dist.sample(sample_shape)
tensor([[[ 0.4683,  0.6118],
         [ 1.0697, -0.0856]],

        [[-1.3001, -0.1734],
         [ 0.4705, -0.0404]]])
>>> (sample_shape, dist.batch_shape, dist.event_shape)
(torch.Size([2]), torch.Size([2]), torch.Size([2]))
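As a compact recap, the loop below rebuilds the eight distributions from the rows above and checks the single rule they all follow: the shape of a sample is sample_shape + batch_shape + event_shape. The list layout is just one convenient way to organize the check.

```python
import torch
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal

# (sample_shape, distribution) for each row of Table 1, in order.
rows = [
    ([],  Normal(0.0, 1.0)),
    ([2], Normal(0.0, 1.0)),
    ([],  Normal(torch.zeros(2), torch.ones(2))),
    ([],  MultivariateNormal(torch.zeros(2), torch.eye(2))),
    ([],  MultivariateNormal(torch.zeros(2, 2), torch.eye(2))),
    ([2], MultivariateNormal(torch.zeros(2), torch.eye(2))),
    ([2], Normal(torch.zeros(2), torch.ones(2))),
    ([2], MultivariateNormal(torch.zeros(2, 2), torch.eye(2))),
]

for no, (shape, dist) in enumerate(rows, start=1):
    sample_shape = torch.Size(shape)
    x = dist.sample(sample_shape)
    # Every row obeys the same rule: shape = sample + batch + event.
    assert x.shape == sample_shape + dist.batch_shape + dist.event_shape
    print(no, tuple(x.shape))
```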


References

  • Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., … & Saurous, R. A. (2017). Tensorflow distributions. arXiv preprint arXiv:1711.10604.

 

Written on Oct 20, 2019.

Source: https://blog.csdn.net/bat67/article/details/116225534
