TensorFlow 2.0 (Part 3): Recurrent Neural Networks (RNN)

2020-12-01 09:04:44  Views: 335  Source: Internet


import numpy as np
import tensorflow as tf

class DataLoader():
    def __init__(self):
        path = tf.keras.utils.get_file('nietzsche.txt', origin='http://s3.amazonaws.com/text-data')
        with open(path, encoding='utf-8') as f:
            self.raw_text = f.read().lower()
        # build char <-> integer-id lookup tables over the corpus vocabulary
        self.chars = sorted(list(set(self.raw_text)))
        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
        self.indices_char = dict((i, c) for i, c in enumerate(self.chars))
        # encode the whole corpus as a list of char ids
        self.text = [self.char_indices[c] for c in self.raw_text]

    def get_batch(self, seq_length, batch_size):
        seq = []
        next_char = []
        for i in range(batch_size):
            # pick a random window of seq_length chars; the char right after it is the label
            index = np.random.randint(0, len(self.text) - seq_length)
            seq.append(self.text[index:index + seq_length])
            next_char.append(self.text[index + seq_length])
        return np.array(seq), np.array(next_char)    # [batch_size, seq_length], [batch_size]
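The batching logic above can be sketched in isolation, without downloading the corpus. The snippet below is a minimal standalone sketch that mirrors DataLoader's char-to-id mapping and random-window batching; the tiny corpus string is a hypothetical stand-in for the Nietzsche text.

```python
import numpy as np

raw_text = "hello world"  # hypothetical stand-in for the downloaded corpus
chars = sorted(set(raw_text))
char_indices = {c: i for i, c in enumerate(chars)}
text = [char_indices[c] for c in raw_text]

def get_batch(seq_length, batch_size):
    # each sample: seq_length consecutive char ids, plus the id of the char that follows
    seq, next_char = [], []
    for _ in range(batch_size):
        index = np.random.randint(0, len(text) - seq_length)
        seq.append(text[index:index + seq_length])
        next_char.append(text[index + seq_length])
    return np.array(seq), np.array(next_char)

x, y = get_batch(seq_length=4, batch_size=3)
print(x.shape, y.shape)  # (3, 4) (3,)
```

Each row of `x` is a sliding window of character ids and the matching entry of `y` is the next-character label, which is exactly the (input, target) pair the training loop later feeds to the model.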

class RNN(tf.keras.Model):
    def __init__(self, num_chars, batch_size, seq_length):
        super().__init__()
        self.num_chars = num_chars
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.cell = tf.keras.layers.LSTMCell(units=256)
        self.dense = tf.keras.layers.Dense(units=self.num_chars)

    def call(self, inputs, from_logits=False):
        # inputs: [batch_size, seq_length] char ids -> one-hot [batch_size, seq_length, num_chars]
        inputs = tf.one_hot(inputs, depth=self.num_chars)
        state = self.cell.get_initial_state(batch_size=self.batch_size, dtype=tf.float32)
        # unroll the LSTM cell manually over the time dimension
        for t in range(self.seq_length):
            output, state = self.cell(inputs[:, t, :], state)
        logits = self.dense(output)
        if from_logits:
            return logits
        else:
            return tf.nn.softmax(logits)

    def predict(self, inputs, temperature=1.):
        batch_size, _ = tf.shape(inputs)
        logits = self(inputs, from_logits=True)
        # temperature scaling: lower temperature sharpens the distribution,
        # higher temperature flattens it, giving more diverse samples
        prob = tf.nn.softmax(logits / temperature).numpy()
        return np.array([np.random.choice(self.num_chars, p=prob[i, :])
                         for i in range(batch_size.numpy())])

num_batches = 1000
seq_length = 40
batch_size = 50
learning_rate = 1e-3

data_loader = DataLoader()
model = RNN(num_chars=len(data_loader.chars), batch_size=batch_size, seq_length=seq_length)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
for batch_index in range(num_batches):
    x, y = data_loader.get_batch(seq_length, batch_size)
    with tf.GradientTape() as tape:
        y_pred = model(x)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
        print("batch %d: loss %f" % (batch_index, loss.numpy()))
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))

x_, _ = data_loader.get_batch(seq_length, 1)
for diversity in [0.2, 0.5, 1.0, 1.2]:
    x = x_
    print("diversity %f:" % diversity)
    for t in range(400):
        y_pred = model.predict(x, diversity)
        print(data_loader.indices_char[y_pred[0]], end='', flush=True)
        # slide the window: drop the first char, append the freshly sampled one
        x = np.concatenate([x[:, 1:], np.expand_dims(y_pred, axis=1)], axis=-1)
    print("\n")
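The effect of the diversity (temperature) knob used in `predict` can be seen in isolation. The following is a minimal NumPy sketch with hypothetical logits, no trained model needed: dividing the logits by the temperature before the softmax sharpens the distribution for small temperatures and flattens it toward uniform for large ones.

```python
import numpy as np

def sample_probs(logits, temperature):
    # softmax over temperature-scaled logits
    z = np.asarray(logits, dtype=np.float64) / temperature
    z -= z.max()            # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

logits = [2.0, 1.0, 0.1]    # hypothetical unnormalized scores for 3 chars
for t in [0.2, 0.5, 1.0, 1.2]:
    p = sample_probs(logits, t)
    print(t, np.round(p, 3))
```

At temperature 0.2 almost all probability mass sits on the top-scoring character (nearly greedy decoding), while at 1.2 the lower-scoring characters get a real chance of being sampled, which is why higher diversity values produce more varied but noisier generated text.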

Source: https://blog.csdn.net/weixin_43648821/article/details/106428047
