ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

Python数据分析(02) graphviz绘制KD二叉查找树

2021-02-09 14:59:32  阅读:197  来源: 互联网

标签:02 node right KD Python data index root dot


目录

前言

        我在另一篇博文机器学习(3) K近邻算法(KNN)介绍及C++实现中介绍了K近邻算法及KD树的实现方法,博文编写过程中需要显式绘制二叉树将其表示出来。初始方法是使用C++生成KD树,并根据graphviz的dot语言逐行编写KD树。通过查阅相关文献,发现使用Python绘制KD树的过程并不繁琐,于是本文介绍使用Python绘制graphviz二叉查找树的图形。另外,介绍博文中是绘制二维平面上KD树、收敛过程如何表示的图形化绘制方法

graphviz简介

        graphviz是一种便于绘制流程图、树形结构等的图形可视化软件。掌握基础的脚本语言就可以轻松绘制属于自己的流程图、二叉树图等内容。

graphviz安装

        安装graphviz流程:打开graphviz下载链接,依据网页提示选择属于自己的平台安装包。我安装的是windows10下的stable_windows_10_msbuild_Release_Win32_graphviz-2.46.0-win32.zip。下载完成后解压到C:/Software/graphviz等自己习惯的路径下,将C:/software/graphviz/bin加入到系统环境变量中,重启电脑以配置graphviz环境变量。使用Python绘制graphviz流程图,需要在安装python3环境后,在命令行pip install graphviz即可。
        安装python3环境流程:以windows10为例,打开Python3-Windows-下载地址,下载最新版本安装包或适合自己的版本安装包,安装到本地过程中记得配置环境变量Path,不再赘述。

graphviz语法

        我将着重讲述使用graphviz绘制二叉树涉及到的语法知识,更详细的语法知识参见graphviz官网说明文档。我将讲解两方面知识,第一是使用命令行编译dot文件,第二是使用python直接生成.gv文件。
        首先介绍命令行下如何使用graphviz语法编写一个二叉树。在自己的路径下生成一个demo.dot文档,文档中内容如下:

// demo.dot
digraph {
	node [shape=circle]
	1 [label="(7,2)"]
	2 [label="(5,4)"]
	3 [label="(2,3)"]
	4 [label="(6,6)", style="invis"]
	1 -> 2
	1 -> 4
	1 -> 3 [style="invis"]
}

        我将文档保存在了C:\File\demo.dot路径下。在命令行中执行:

cd C:\File
dot -Tpng demo.png -o demo.dot

        就会在路径C:\File下生成demo.png,图像如图所示。
demo.png

demo.png
        下面解释dot文件中每一行的含义。
digraph G{
	...
}
表示这是一个有向图,图中的边都带箭头。
...
	node [shape='circle']
...
表示图中的节点都是圆形。
1 [label='(7,2)']
声明一个节点,节点记为1,其内容为字符串"(7,2)"
4 [label="(6,6)", style="invis"]
声明一个节点,节点记为4,其内容为字符串"(6,6)",并且这个节点在图中不显示。
1 -> 2
声明一条从1指向2的边。
1 -> 3 [style="invis"]
声明一条从1指向3的边,并且这条边在图中不显示。

        据此,为了保证二叉树有序、对齐显示,我们在绘制二叉树的过程中,左右子树中间添加一个不可见的边和不可见节点,实现图形的对齐效果。如果使用C++强行绘制graphviz,就根据.dot文件的语法格式,向文件流中采用先根遍历的方法书写dot文本,使用文件流记得#include <fstream>。C++实现方法如下:

#include <fstream>
void drawKDTree(node* root, string path) {
	// 等价于先根序列。
	// path = "tree.dot"
	ofstream fout(path);
	string tab = "    ";
	fout << "digraph G{" << endl;
	fout << tab << "node[shape=circle]" << endl;
	int N = data.size()+1;
	preOrderDraw(root, fout, N);
	fout << "}" << endl;
	fout.close();
	
}
void preOrderDraw(node* root, ofstream& fout, int& nullIndex) {
	string tab = "    ";
	// 先根序列,绘制当前节点的内容。
	fout << tab << root->index << "[group=" << root->index << ", label=\"(" << data[root->index][0];
	for (int i = 1; i < n; i++) {
		fout << "," << data[root->index][i];
	}
	fout << ")\"]" << endl;
	// 绘制左节点的内容
	if (root->left) {
		// 当左节点非空的时候,需要绘制一条伸向左节点的有向边。
		fout << tab << root->index << " -> " << root->left->index << endl;
		// 递归遍历左子树。
		preOrderDraw(root->left, fout, nullIndex);
	}
	else {
		// 左节点为空的时候,为了保证图形的整洁有序,绘制左侧空节点占位。边与节点都为不可见[style=invis]。
		fout << tab << root->index << " -> _" << nullIndex << "[style=invis]" << endl;
		fout << tab << "_" << nullIndex++ << " [style=invis]" << endl;
	}
	// 为了二叉树的图形可以相当漂亮美观且对齐,设置一个中间空节点保证左右两侧对齐。
	fout << tab << root->index<<" -> "<< "_" << root->index << "[weight=10, group=" << root->index << ", style=invis]" << endl;
	fout << tab << "_" << root->index << "[style=invis]" << endl;
	// 同上,绘制右节点的内容。
	if (root->right) {
		// 当右节点非空的时候,需要绘制一条伸向右节点的有向边。
		fout << tab << root->index << " -> " << root->right->index << endl;
		// 递归遍历右子树。
		preOrderDraw(root->right, fout, nullIndex);
	}
	else {
		// 右节点为空的时候,为了保证图形的整洁有序,绘制右侧空节点占位。边与节点都为不可见[style=invis]。
		fout << tab << root->index << " -> _" << nullIndex << "[style=invis]" << endl;
		fout << tab << "_" << nullIndex++ << " [style=invis]" << endl;
	}
}

        下面介绍Python的graphviz语法。为了绘制同样一棵上面的树,我们只需要做这几行代码,即可生成一棵二叉树并展示出来。

// demo.py
from graphviz import Digraph
dot = Digraph(node_attr={'shape': 'circle'})
dot.node(1,"(7,2)")
dot.node(2,"(5,4)")
dot.node(3,"(2,3)")
dot.node(4,"(6,6)",style="invis")
dot.edge(1,2)
dot.edge(1,4)
dot.edge(1,3,style="invis")
dot.view()

        据此,使用先根遍历的方式,同样根据二叉树的节点,绘制边和点即可。

import numpy as np
from graphviz import Digraph
from matplotlib import pyplot as plt
from matplotlib.pyplot import MultipleLocator
#data = [[2,3],[6, 4],[9, 6],[4, 7],[8, 1],[7, 2], [8,2], [10,4], [6,6]]
data = [[7,2], [5,4], [9,6], [2,3], [4,7], [8,1]]
data = np.array(data)

# 节点
class node:
	def __init__(self, _data=None, _left=None, _right=None, _father=None, _dim=None, _index=None, _visiable=True):
		self.data = _data
		self.left = _left
		self.right = _right
		self.father = _father
		self.dim = _dim
		self.index = _index
		self.visiable = _visiable
	def getData(self):
		s = "("
		for i in range(self.data.size):
			if i!=0:
				s += ','
			s+=str(self.data[i])
		s += ")"
		return s
	def __str__(self):
		if(self.visiable):
			return str(self.index)
		else:
			return "_invis"+str(self.index)

dataIndex = 1
def drawKDTree(data, depth, k, dot):
	# 根据数据生成KD树
	dim = depth % k
	length = data.shape[0]
	if(length==0):
		return None, dot
	index = []
	for i in range(length):
		index.append([data[i][dim], i])
	index.sort()
	root = data[index[length//2][1]]
	left = [data[index[i][1]] for i in range(length//2)]
	left = np.array(left)
	right = [data[index[i][1]] for i in range(length//2+1, length)]
	right = np.array(right)
	global dataIndex
	root_node = node(_data=root, _dim=dim, _index=dataIndex)
	dataIndex+=1

	dot.node(str(root_node.index), root_node.getData())

	root_node.left, dot=drawKDTree(left, depth+1, k, dot)
	if(root_node.left==None):
		pass
		dot.node("_left"+str(root_node.index), root_node.getData(), style="invis")
		dot.edge(str(root_node.index), "_left"+str(root_node.index), style="invis")
	else:
		dot.edge(str(root_node.index), str(root_node.left.index))

	dot.node("_middle"+str(root_node.index), root_node.getData(), style="invis")
	dot.edge(str(root_node.index), "_middle"+str(root_node.index), style="invis", weight="10")

	root_node.right, dot=drawKDTree(right, depth+1, k, dot)

	if(root_node.right==None):
		pass
		dot.node("_right"+str(root_node.index), root_node.getData(), style="invis")
		dot.edge(str(root_node.index), "_right"+str(root_node.index), style="invis")
	else:
		dot.edge(str(root_node.index), str(root_node.right.index))

	if(root_node.left):
		root_node.left.father=root_node
	if(root_node.right):
		root_node.right.father=root_node
	
	return root_node, dot

dot = Digraph(node_attr={'shape': 'circle'})
_, dot = drawKDTree(data, 0, 2, dot)
dot.view()
print(dot.source)

绘制平面KNN模拟图

        只需要通过pyplot在生成KD树的过程中,控制节点的维度以及左右边界,即可绘制分类的直线段;通过绘制scatter散点图,将点标记在图中;通过计算半径,绘制以待查询节点为圆心的圆形。特别注意的是,由于pyplot不支持深拷贝、也无法撤销某一步操作,因此想要在同一个背景下绘制不同的图形,只有自己设置一个函数以保证每次都可以同样调用生成同一块背景,并在该背景上绘制新的图形。这里的Python函数不包括数据的预处理、标签、投票等内容,仅仅是用于绘制图形而用的脚本内容。

import numpy as np
from graphviz import Digraph
from matplotlib import pyplot as plt
from matplotlib.pyplot import MultipleLocator
#data = [[2,3],[6, 4],[9, 6],[4, 7],[8, 1],[7, 2], [8,2], [10,4], [6,6]]
data = [[7,2], [5,4], [9,6], [2,3], [4,7], [8,1]]
data = np.array(data)

# 节点
class node:
	def __init__(self, _data=None, _left=None, _right=None, _father=None, _dim=None, _index=None, _visiable=True):
		self.data = _data
		self.left = _left
		self.right = _right
		self.father = _father
		self.dim = _dim
		self.index = _index
		self.visiable = _visiable
	def getData(self):
		s = "("
		for i in range(self.data.size):
			if i!=0:
				s += ','
			s+=str(self.data[i])
		s += ")"
		return s
	def __str__(self):
		if(self.visiable):
			return str(self.index)
		else:
			return "_invis"+str(self.index)

# 生成KD树,并绘制一个完整的平面图形。
def createTree(data, depth, k, l, r, d, u):
	dim = depth % k
	length = data.shape[0]
	if(length==0):
		return None
	index = []
	for i in range(length):
		index.append([data[i][dim], i])
	index.sort()
	root = data[index[length//2][1]]
	left = [data[index[i][1]] for i in range(length//2)]
	left = np.array(left)
	right = [data[index[i][1]] for i in range(length//2+1, length)]
	right = np.array(right)
	root_node = node(_data=root, _dim=dim)

	if(dim == 0):
		plt.plot([root[0]]*(u-d+1), range(d, u+1))
		root_node.left=createTree(left, depth+1, k, l, root[0], d, u)
		root_node.right=createTree(right, depth+1, k, root[0], r, d, u)
		if(root_node.left):
			root_node.left.father=root_node
		if(root_node.right):
			root_node.right.father=root_node
	else:
		plt.plot(range(l, r+1), [root[1]]*(r-l+1))
		root_node.left=createTree(left, depth+1, k, l, r, d, root[1])
		root_node.right=createTree(right, depth+1, k, l, r, root[1], u)
		if(root_node.left):
			root_node.left.father=root_node
		if(root_node.right):
			root_node.right.father=root_node
	return root_node

# 绘制分类超平面
def drawOri(data):
	fig, ax = plt.subplots()
	fig.set_size_inches(5, 5)
	data = np.array(data)
	mmax = np.max(data)+1
	mmin = np.min(data)-1
	major_locator=MultipleLocator(1)
	plt.scatter(data[:,0], data[:,1])
	plt.xlim(mmin, mmax)
	plt.ylim(mmin, mmax)
	ax = plt.gca()
	ax.xaxis.set_major_locator(major_locator)
	ax.yaxis.set_major_locator(major_locator)
	return createTree(data, 0, 2, mmin, mmax, mmin, mmax)

# 绘制标记点及分类超平面
def drawPic(x, data):
	fig, ax = plt.subplots()
	fig.set_size_inches(5, 5)
	data = np.array(data)
	mmax = np.max(data)+1
	mmin = np.min(data)-1
	major_locator=MultipleLocator(1)
	plt.scatter(data[:,0], data[:,1])
	plt.scatter([x[0]], [x[1]], marker='x')
	plt.xlim(mmin, mmax)
	plt.ylim(mmin, mmax)
	ax = plt.gca()
	ax.xaxis.set_major_locator(major_locator)
	ax.yaxis.set_major_locator(major_locator)
	return createTree(data, 0, 2, mmin, mmax, mmin, mmax)

# 计算两点间的欧式距离
def distance(a, b):
	return ((a[0]-b[0])**2+(a[1]-b[1])**2)**0.5

# 寻找叶节点
def findLeaf(root, x, stack):
	if(root==None):
		return stack
	stack.append(root)
	if(x[root.dim]<=root.data[root.dim]):
		return findLeaf(root.left, x, stack)
	else:
		return findLeaf(root.right, x, stack)


# 寻找最近邻节点,并绘制图形
def searchNearest(root, x, differt_pic=True, show=False):
	plt.scatter([x[0]], [x[1]], marker='x')
	stack = []
	stack = findLeaf(root, x, stack)
	nearN = stack[-1]
	minD = distance(stack[-1].data, x)
	visted = set()
	path = 1
	while(stack):
		top = stack[-1]
		visted.add(top)
		stack.pop()
		dis = distance(top.data, x)
		if(dis < minD):
			minD = dis
			nearN = top
		if show:
			plt.show()
		if differt_pic:
			# 重新绘制一张底图。
			drawPic(x, data)
		ax = plt.gca()
		ax.scatter(top.data[0], top.data[1], marker='x', s=200)
		ax.plot([x[0], nearN.data[0]], [x[1], nearN.data[1]])
		theta = np.arange(0, 2*np.pi, 0.01)
		xx = x[0] + minD * np.cos(theta)
		yy = x[1] + minD * np.sin(theta)
		plt.plot(xx, yy)
		plt.savefig("{0}.png".format(path))
		path += 1

		left = x[top.dim] - minD
		right = x[top.dim] + minD
		if(left <= top.data[top.dim] and top.left != None and top.left not in visted):
			stack.append(top.left)
		if(right >= top.data[top.dim] and top.right != None and top.right not in visted):
			stack.append(top.right)
	return nearN


x = [4, 3]
drawOri(data)
plt.savefig("0.png")
root = drawPic(x, data, differt_pic=True, show=False)
searchNearest(root, x)

        绘制图形展示如下:
1
2

        至此,《统计学习方法》第三章的全部内容都更新完毕,在我的Gtihub中有详细代码,欢迎查阅。

标签:02,node,right,KD,Python,data,index,root,dot
来源: https://blog.csdn.net/ProfSnail/article/details/113769110

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有