标签:case gather mask tf data axis
函数定义链接:
tf.gather:https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/gather
tf.gather_nd:https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/gather_nd
tf.boolean_mask:https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/boolean_mask
区别:
1.tf.gather
tf.gather(
params, indices, validate_indices=None, name=None, axis=None, batch_dims=0
)
Input: param维度[p1,p2,p3,p4,....]
indices维度[i1,i2,....]
axis:指定维度
Output: 根据indices的数值,从params第axis维获取数据,输出数据维度[...,p(axis-1),i1,i2,...,p(axis+1),...]
data=tf.reshape(tf.range(24),(2,3,4))
'''<tf.Tensor: id=1372, shape=(2, 3, 4), dtype=int32, numpy=
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])>'''
#case 1:
tf.gather(data,[0,1],axis=0)
'''
<tf.Tensor: id=1375, shape=(2, 3, 4), dtype=int32, numpy=
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])>
'''
#case 2:
tf.gather(data,[0,1],axis=1)
'''
<tf.Tensor: id=1378, shape=(2, 2, 4), dtype=int32, numpy=
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[12, 13, 14, 15],
[16, 17, 18, 19]]])>
'''
#case 3:
tf.gather(data,[[0,1]],axis=1)
'''
<tf.Tensor: id=1381, shape=(2, 1, 2, 4), dtype=int32, numpy=
array([[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]]],
[[[12, 13, 14, 15],
[16, 17, 18, 19]]]])>
'''
2.tf.gather_nd
tf.gather_nd(
params, indices, name=None, batch_dims=0
)
Input: param维度[p1,p2,p3,p4,....]
indices维度[i1,i2,....]
Output: 根据indices,获取params对应维度的元素并组成Tensor.
data=tf.reshape(tf.range(24),(2,3,4))
'''<tf.Tensor: id=1372, shape=(2, 3, 4), dtype=int32, numpy=
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])>'''
#case 1:
tf.gather_nd(data,[0])
'''
<tf.Tensor: id=1383, shape=(3, 4), dtype=int32, numpy=
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])>
'''
#case 2:
tf.gather(data,[0,1])
'''
<tf.Tensor: id=1385, shape=(4,), dtype=int32, numpy=array([4, 5, 6, 7])>
'''
#case 3:
tf.gather(data,[[0,1]])
'''
<tf.Tensor: id=1387, shape=(1, 4), dtype=int32, numpy=array([[4, 5, 6, 7]])>
'''
3.tf.boolean_mask
tf.boolean_mask(
tensor, mask, name='boolean_mask', axis=None
)
Input: tensor维度[p1,p2,p3,p4,....,p(n)]
mask:二值化,维度[p(axis),p(axis+1),...,p(axis+i)],axis+i<n,(注,mask维度需要跟tensor维度对应)
axis:从该维度开始
Output: 根据mask,获取tensor对应维度的元素并组成Tensor.
data=tf.reshape(tf.range(24),(2,3,4))
'''<tf.Tensor: id=1372, shape=(2, 3, 4), dtype=int32, numpy=
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])>'''
#case 1:
tf.boolean_mask(data,[True,False],axis=0)
'''
<tf.Tensor: id=1451, shape=(1, 3, 4), dtype=int32, numpy=
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]]])>
'''
#case 2: 维度不对应
tf.boolean_mask(data,[True],axis=0)
'''
ValueError: Shapes (2,) and (1,) are incompatible
'''
#case 3:
tf.boolean_mask(data,[[True,False,False],[True,False,True]],axis=0)
'''
<tf.Tensor: id=1482, shape=(3, 4), dtype=int32, numpy=
array([[ 0, 1, 2, 3],
[12, 13, 14, 15],
[20, 21, 22, 23]])>
'''
标签:case,gather,mask,tf,data,axis 来源: https://blog.csdn.net/u014426939/article/details/117511948
本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享; 2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关; 3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关; 4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除; 5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。