ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

正确编写numpy meshgrid函数

2019-11-10 21:09:16  阅读:169  来源: 互联网

标签:matplotlib python numpy


下面的代码有效.但是,行计算rez可能不是以numpy开发人员想要的方式编写的.如何在不转换列表的情况下有效地重写它?

另外,如果您碰巧知道如何在python中编写数学函数的良好指南,该指南也可以与数字和np.arrays一起使用,请分享:)

import numpy as np
from matplotlib import pyplot as plt

def L2normSquared(X, Y, x, y):
    return sum((X-x)**2 + (Y-y)**2)

X = np.random.normal(0, 1, 10)
Y = np.random.normal(0, 1, 10)

PX = np.arange(-1, 1, 0.1)
PY = np.arange(-1, 1, 0.1)
PXm, PYm = np.meshgrid(PX, PY)

rez = np.array([[L2normSquared(X, Y, x, y) for y in PY] for x in PX])

print(rez.shape)

plt.imshow(rez, cmap='jet', interpolation='nearest', origin='lower', extent=[-1, 1, -1, 1])
plt.show()

编辑:让我解释一下我要做什么

我在2D中生成10个随机点.然后,我为2D中的任意点(x,y)定义一个损失函数.该函数的结果是所有10个固定点到该任意点的欧几里得距离的总和.最后,我想使用2d imshow方法绘制此损失函数

编辑2:根据尼尔斯·沃纳(Nils Werner)的回答,可以使用3D阵列和广播,产生以下代码

import numpy as np
from matplotlib import pyplot as plt

def L2normSquared(X, Y, x, y):
    return np.sum((X-x)**2 + (Y-y)**2, axis=0)

X = np.random.normal(0, 1, 10)
Y = np.random.normal(0, 1, 10)

PX = np.arange(-1, 1, 0.1)
PY = np.arange(-1, 1, 0.1)
PXm, PYm = np.meshgrid(PX, PY)


rez = L2normSquared(X[:, None, None], Y[:, None, None], PXm, PYm)

print(rez.shape)

plt.imshow(rez, cmap='jet', interpolation='nearest', origin='lower', extent=[-1, 1, -1, 1])
plt.show()

但是,此代码实际上比列表理解要慢(对于10000个随机坐标,步长0.01,大约慢2-3倍).对于更大的输入,会发生内存崩溃,这使我相信这种方法会在内部导致3D数组动态编程,这在内存分配方面无法很好地扩展.

编辑3:
非常抱歉,但我的最小例子太少了.在我面临的原始问题中,坐标X和Y不会解耦,从而可以分别计算它们.原始功能之一是

def gaussian(X, Y, x, y):
  return sum(np.exp(-(X-x)**2 -(Y-y)**2))

解决方法:

meshgrid背后的想法是,您将获得两个或多个数组,您可以直接将这些数组传递给操作.因此,理想情况下,我们根本不需要for循环.

但是,由于您在X和PX之间做一个“外部差异”,然后在X轴上求和,因此还需要使用broadcasting首先做外积,最后在正确的轴上求和:

import numpy as np
from matplotlib import pyplot as plt

def L2normSquared(X, Y, x, y):
    return np.sum((X-x)**2 + (Y-y)**2, axis=0)

X = np.random.normal(0, 1, 10)
Y = np.random.normal(0, 1, 10)

PX = np.arange(-1, 1, 0.1)
PY = np.arange(-1, 1, 0.1)
PXm, PYm = np.meshgrid(PX, PY)

rez = L2normSquared(X[:, None, None], Y[:, None, None], PXm, PYm)

标签:matplotlib,python,numpy
来源: https://codeday.me/bug/20191110/2015131.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有