ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

numpy数组、pytorch张量的维度操作

2021-01-02 16:29:32  阅读:346  来源: 互联网

标签:tensor torch 张量 shape 维度 pytorch 拼接 print numpy


内容概要

写脚本时经常需要进行numpy的数组ndarray和pytorch张量tensor的维度操作,主要包括拼接、拆分、维度扩张和压缩等等。有关的方法很多,老是记不住,写一篇帖子记录一下。
参考:https://blog.csdn.net/pipisorry/article/details/108988615?utm_medium=distribute.pc_relevant.none-task-blog-baidujs_baidulandingword-2&spm=1001.2101.3001.4242

tensor

torch.cat(tensors, dim=0, *, out=None)

cat()方法用于拼接给定的张量,所有张量必须具有相同的形状(连接的维度除外)

# -*- coding: utf-8 -*-
import torch

x = torch.randint(1, 10, size=(2, 4))
print('x', x)
y = torch.randint(1, 10, size=(2, 4))
print('y', y)
a = torch.cat([x, y], dim=0)
print(a)
b = torch.cat([x, y], dim=1)
print(b)
# out
x tensor([[3, 7, 2, 3],
        [3, 7, 9, 1]])
y tensor([[3, 6, 5, 8],
        [5, 2, 8, 5]])
tensor([[3, 7, 2, 3],
        [3, 7, 9, 1],
        [3, 6, 5, 8],
        [5, 2, 8, 5]])
tensor([[3, 7, 2, 3, 3, 6, 5, 8],
        [3, 7, 9, 1, 5, 2, 8, 5]])

torch.stack(tensors, dim=0, *, out=None)

stack()方法也用于张量的拼接,与cat()方法不同,他会在新的维度上进行拼接

# -*- coding: utf-8 -*-
import torch

x = torch.randint(1, 10, size=(2, 4))
print('x', x)
y = torch.randint(1, 10, size=(2, 4))
print('y', y)
a = torch.stack([x, y], dim=0)
print(a)
print(a.shape)

#out
x tensor([[7, 7, 7, 1],
        [2, 6, 6, 2]])
y tensor([[3, 7, 1, 5],
        [5, 1, 8, 9]])
tensor([[[7, 7, 7, 1],
         [2, 6, 6, 2]],

        [[3, 7, 1, 5],
         [5, 1, 8, 9]]])
torch.Size([2, 2, 4])

注意使用cat()方法拼接张量时拼接的维度可以不一致,但是stack拼接由于会产生新的维度,所以用于拼接的张量的所有维度都必须是一直的。

torch.squeeze(input, dim=None, *, out=None)

squeeze()方法用于张量维度的压缩,使用时指定某个维度,若该维度的值为1则压缩此维度。
unsqueeze()方法用于扩张一个维度,用法和squeeze(),注意在使用这两种方法时,总是要指定dim

#-*- coding: utf-8 -*-
import torch

x = torch.randint(1, 10, size=(1, 2, 4))
print('x', x, 'shape=', x.shape)
x = x.squeeze(dim=0)
print('x', x, 'shape=', x.shape)
x = x.unsqueeze(dim=2)
print('x', x, 'shape=', x.shape)

#out
x tensor([[[1, 8, 3, 8],
         [3, 9, 7, 3]]]) shape= torch.Size([1, 2, 4])
x tensor([[1, 8, 3, 8],
        [3, 9, 7, 3]]) shape= torch.Size([2, 4])
x tensor([[[1],
         [8],
         [3],
         [8]],

        [[3],
         [9],
         [7],
         [3]]]) shape= torch.Size([2, 4, 1])

numpy

numpy.append()

append()方法与tensor的cat方法用法几乎一样,用于darray数组的拼接,被拼接的维度可以不一样。

#-*- coding: utf-8 -*-
import numpy as np

y = np.zeros(shape=(0, 5, 5))
print(y)
x = np.random.randint(1, 10, size=(1, 5, 5))
print(x)
y = np.append(y, x, axis=0)
print(y)

#out
[]
[[[9 1 7 7 9]
  [6 9 9 6 7]
  [7 6 1 6 6]
  [7 6 1 6 5]
  [4 6 6 3 4]]]
[[[9. 1. 7. 7. 9.]
  [6. 9. 9. 6. 7.]
  [7. 6. 1. 6. 6.]
  [7. 6. 1. 6. 5.]
  [4. 6. 6. 3. 4.]]]

numpy.stack()

stack()方法

#-*- coding: utf-8 -*-
import numpy as np

y = np.zeros(shape=(0, 5, 5))
print(y)
x = np.random.randint(1, 10, size=(1, 5, 5))
print(x)
y = np.stack((y, x), axis=0)
print(y)

#out
ValueError: all input arrays must have the same shape

使用stack()方法拼接数组要求数组的维度完全一致,否则会像上图一样报错,这与我的使用习惯不符,所以我一直用的append()方法,append方法似乎存在内存占用较多的问题,但是操作更灵活。

其他有关tensor和darray的维度操作方法我总结之后会更新到该贴,欢迎指正。

标签:tensor,torch,张量,shape,维度,pytorch,拼接,print,numpy
来源: https://blog.csdn.net/weixin_44132759/article/details/112099273

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有