网络编程
位置:首页>> 网络编程>> 网络编程>> numpy实现RNN原理实现

用 numpy 实现 RNN 的原理

作者:J k l  发布时间:2023-09-21 23:47:33 

标签:numpy,RNN

首先说明：本代码仅用于帮助理解 RNN 的前向计算原理，并未实现梯度下降（训练）部分；权重参数直接随机初始化而非训练得到，但这不影响理解。代码只使用 numpy 库，因此无法利用 GPU 加速。


import numpy as np

class Rnn():
    """Minimal multi-layer RNN forward pass written with plain numpy.

    Each layer computes ``h_t = tanh(Wih @ input_t + Whh @ h_{t-1})`` with no
    bias terms. Layer 0's input is the sequence element; deeper layers take the
    previous layer's new hidden state. There is no training code: the weights
    are drawn uniformly from [0, 1) once, when the instance is created, so
    repeated calls to :meth:`feed` use the same parameters.

    Note: ``bidirectional`` is stored but not implemented; only a
    forward-direction pass is computed.
    """

    def __init__(self, input_size, hidden_size, num_layers, bidirectional=False, seed=None):
        """Create the RNN and draw its weights.

        :param input_size: number of features per time step (last axis of x).
        :param hidden_size: size of every layer's hidden state.
        :param num_layers: number of stacked layers; must be >= 1.
        :param bidirectional: accepted and stored, but currently ignored.
        :param seed: optional seed for reproducible weights. When ``None`` the
            global ``np.random`` state is used, so ``np.random.seed`` still works.
        :raises ValueError: if ``num_layers`` < 1.
        """
        if num_layers < 1:
            raise ValueError(f'num_layers must be >= 1, got {num_layers}')
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        rand = np.random.random if seed is None else np.random.default_rng(seed).random
        # Wih[0]: [hidden_size, input_size]; Wih[i>0]: [hidden_size, hidden_size]
        # because deeper layers consume the previous layer's hidden state.
        self.Wih = [rand((hidden_size, input_size))] + [
            rand((hidden_size, hidden_size)) for _ in range(1, num_layers)
        ]
        # Whh[i]: [hidden_size, hidden_size] recurrent weights for every layer.
        self.Whh = [rand((hidden_size, hidden_size)) for _ in range(num_layers)]

    def feed(self, x):
        """Run the RNN over a whole sequence, starting from a zero hidden state.

        :param x: array-like of shape [seq, batch_size, input_size].
        :return: ``(out, hidden)``. ``out`` has shape
            [seq, hidden_size, batch_size] and holds the last layer's state at
            every step. ``hidden`` has shape [num_layers, hidden_size, batch_size]
            and holds every layer's state after the final step.
        :raises ValueError: if ``x`` is not 3-D or its last axis is not
            ``input_size``.
        """
        # Convert before reading .shape so that plain nested lists are accepted.
        x = np.asarray(x, dtype=float)
        if x.ndim != 3:
            raise ValueError(f'expected x of shape [seq, batch, feature], got {x.shape}')
        seq_len, batch, features = x.shape
        if features != self.input_size:
            raise ValueError(f'expected {self.input_size} input features, got {features}')

        # Hidden states are stored column-wise: [hidden_size, batch].
        hidden = [np.zeros((self.hidden_size, batch)) for _ in range(self.num_layers)]
        # Preallocate so an empty sequence still yields shape [0, hidden, batch].
        out = np.empty((seq_len, self.hidden_size, batch))
        for t in range(seq_len):
            # [batch, feature] -> [feature, batch] so Wih can left-multiply it.
            layer_input = x[t].T
            for layer in range(self.num_layers):
                hidden[layer] = np.tanh(
                    self.Wih[layer] @ layer_input + self.Whh[layer] @ hidden[layer]
                )
                layer_input = hidden[layer]
            out[t] = hidden[-1]

        return out, np.array(hidden)

def sigmoid(x):
    """Return the logistic sigmoid 1 / (1 + exp(-x)), element-wise.

    Computed from exp(-|x|), which always lies in (0, 1]. This means
    large-magnitude inputs neither overflow nor divide by zero, unlike the
    naive 1 / (1 + 1 / exp(x)) form.

    :param x: scalar or array-like of real numbers.
    :return: sigmoid of ``x``; a NumPy scalar for scalar input, else an array.
    """
    x = np.asarray(x, dtype=float)
    z = np.exp(-np.abs(x))
    # x >= 0: 1 / (1 + e^-x); x < 0: the algebraically equal e^x / (1 + e^x).
    result = np.where(x >= 0, 1.0 / (1.0 + z), z / (1.0 + z))
    # [()] unwraps a 0-d array into a scalar and leaves n-d arrays unchanged.
    return result[()]

if __name__ == '__main__':
    # Demo: 4-layer RNN with 1 input feature and hidden size 5, run on a random
    # sequence of length 6 with batch size 2.
    rnn = Rnn(1, 5, 4)
    # Named x (not `input`) to avoid shadowing the built-in input().
    x = np.random.random((6, 2, 1))  # [seq, batch, feature]
    out, h = rnn.feed(x)
    print(f'seq is {x.shape[0]}, batch_size is {x.shape[1]} ', 'out.shape ', out.shape, ' h.shape ', h.shape)

来源:https://blog.csdn.net/qq_43056256/article/details/114272542

0
投稿

猜你喜欢

手机版 网络编程 asp之家 www.aspxhome.com