ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

tf.gather,tf.gather_nd,tf.boolean_mask

2021-06-03 10:58:36  阅读:251  来源: 互联网

标签:case gather mask tf data axis


函数定义链接:

tf.gather:https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/gather

tf.gather_nd:https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/gather_nd

tf.boolean_mask:https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/boolean_mask

 

区别

1.tf.gather

tf.gather(
    params, indices, validate_indices=None, name=None, axis=None, batch_dims=0
)

Input: param维度[p1,p2,p3,p4,....]
        indices维度[i1,i2,....]
        axis:指定维度
Output: 根据indices的数值,从params第axis维获取数据,输出数据维度[...,p(axis-1),i1,i2,...,p(axis+1),...]
data=tf.reshape(tf.range(24),(2,3,4))
'''<tf.Tensor: id=1372, shape=(2, 3, 4), dtype=int32, numpy=
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],
       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])>'''

#case 1:
tf.gather(data,[0,1],axis=0)
'''
<tf.Tensor: id=1375, shape=(2, 3, 4), dtype=int32, numpy=
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],
       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])>
'''

#case 2:
tf.gather(data,[0,1],axis=1)
'''
<tf.Tensor: id=1378, shape=(2, 2, 4), dtype=int32, numpy=
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7]],
       [[12, 13, 14, 15],
        [16, 17, 18, 19]]])>
'''

#case 3:
tf.gather(data,[[0,1]],axis=1)
'''
<tf.Tensor: id=1381, shape=(2, 1, 2, 4), dtype=int32, numpy=
array([[[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7]]],
       [[[12, 13, 14, 15],
         [16, 17, 18, 19]]]])>
'''

2.tf.gather_nd

tf.gather_nd(
    params, indices, name=None, batch_dims=0
)
Input: param维度[p1,p2,p3,p4,....]
        indices维度[i1,i2,....]
        
Output: 根据indices,获取params对应维度的元素并组成Tensor.
data=tf.reshape(tf.range(24),(2,3,4))
'''<tf.Tensor: id=1372, shape=(2, 3, 4), dtype=int32, numpy=
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],
       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])>'''

#case 1:
tf.gather_nd(data,[0])
'''
<tf.Tensor: id=1383, shape=(3, 4), dtype=int32, numpy=
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])>
'''

#case 2:
tf.gather(data,[0,1])
'''
<tf.Tensor: id=1385, shape=(4,), dtype=int32, numpy=array([4, 5, 6, 7])>
'''

#case 3:
tf.gather(data,[[0,1]])
'''
<tf.Tensor: id=1387, shape=(1, 4), dtype=int32, numpy=array([[4, 5, 6, 7]])>
'''

3.tf.boolean_mask

tf.boolean_mask(
    tensor, mask, name='boolean_mask', axis=None
)
Input: tensor维度[p1,p2,p3,p4,....,p(n)]
        mask:二值化,维度[p(axis),p(axis+1),...,p(axis+i)],axis+i<n,(注,mask维度需要跟tensor维度对应)
        axis:从该维度开始

        
Output: 根据mask,获取tensor对应维度的元素并组成Tensor.
data=tf.reshape(tf.range(24),(2,3,4))
'''<tf.Tensor: id=1372, shape=(2, 3, 4), dtype=int32, numpy=
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],
       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])>'''

#case 1:
tf.boolean_mask(data,[True,False],axis=0)
'''
<tf.Tensor: id=1451, shape=(1, 3, 4), dtype=int32, numpy=
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]]])>
'''

#case 2: 维度不对应
tf.boolean_mask(data,[True],axis=0)
'''
ValueError: Shapes (2,) and (1,) are incompatible
'''

#case 3:
tf.boolean_mask(data,[[True,False,False],[True,False,True]],axis=0)
'''
<tf.Tensor: id=1482, shape=(3, 4), dtype=int32, numpy=
array([[ 0,  1,  2,  3],
       [12, 13, 14, 15],
       [20, 21, 22, 23]])>
'''

 

 

 

标签:case,gather,mask,tf,data,axis
来源: https://blog.csdn.net/u014426939/article/details/117511948

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有