numpy实现RNN原理实现
作者:J k l 发布时间:2023-09-21 23:47:33
标签:numpy,RNN
首先说明代码只是帮助理解,并未写出梯度下降部分,默认参数已经被固定,不影响理解。代码主要实现RNN原理,只使用numpy库,不可用于GPU加速。
import numpy as np
class Rnn():
def __init__(self, input_size, hidden_size, num_layers, bidirectional=False):
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bidirectional = bidirectional
def feed(self, x):
'''
:param x: [seq, batch_size, embedding]
:return: out, hidden
'''
# x.shape [sep, batch, feature]
# hidden.shape [hidden_size, batch]
# Whh0.shape [hidden_size, hidden_size] Wih0.shape [hidden_size, feature]
# Whh1.shape [hidden_size, hidden_size] Wih1.size [hidden_size, hidden_size]
out = []
x, hidden = np.array(x), [np.zeros((self.hidden_size, x.shape[1])) for i in range(self.num_layers)]
Wih = [np.random.random((self.hidden_size, self.hidden_size)) for i in range(1, self.num_layers)]
Wih.insert(0, np.random.random((self.hidden_size, x.shape[2])))
Whh = [np.random.random((self.hidden_size, self.hidden_size)) for i in range(self.num_layers)]
time = x.shape[0]
for i in range(time):
hidden[0] = np.tanh((np.dot(Wih[0], np.transpose(x[i, ...], (1, 0))) +
np.dot(Whh[0], hidden[0])
))
for i in range(1, self.num_layers):
hidden[i] = np.tanh((np.dot(Wih[i], hidden[i-1]) +
np.dot(Whh[i], hidden[i])
))
out.append(hidden[self.num_layers-1])
return np.array(out), np.array(hidden)
def sigmoid(x):
return 1.0/(1.0 + 1.0/np.exp(x))
if __name__ == '__main__':
rnn = Rnn(1, 5, 4)
input = np.random.random((6, 2, 1))
out, h = rnn.feed(input)
print(f'seq is {input.shape[0]}, batch_size is {input.shape[1]} ', 'out.shape ', out.shape, ' h.shape ', h.shape)
# print(sigmoid(np.random.random((2, 3))))
#
# element-wise multiplication
# print(np.array([1, 2])*np.array([2, 1]))
来源:https://blog.csdn.net/qq_43056256/article/details/114272542
0
投稿
猜你喜欢
- 本文实例讲述了PHP实现将科学计数法转换为原始数字字符串的方法,分享给大家供大家参考。具体实现代码如下:function NumToStr(
- 前言Tkinter是python内置的标准GUI库,基于Tkinter实现了简易人员管理系统,所用数据库为Mongodb代码时间宝贵!直接上
- 联合结果集 新建临时工数据表 代码如下:CREATE TABLE T_TempEmployee (FIdCardNumber VARCHAR
- 《lnmp一键安装包》中需要获取ip地址,有2种情况:如果服务器只有私网地址没有公网地址,这个时候获取的IP(即私网地址)不能用来判断服务器
- #过滤式特征选择#根据方差进行选择,方差越小,代表该属性识别能力很差,可以剔除from sklearn.feature_selection
- 1. 用SimpleITK读取dicom序列:import SimpleITK as sitkimport numpy as npimg_p
- 前言图像分割是指根据灰度、色彩、空间纹理、几何形状等特征把图像划分成若干个互不相交的区域。最简单的图像分割就是将物体从背景中分割出来1.图像
- 为什么,这么简单的一个python,我还要特意来写一篇文章呢?是因为留念下,在使用了Anaconda2和Anaconda3的基础上,现在需安
- “深入认识Python内建类型”这部分的内容会从源码角度为大家介绍Python中各种常用的内建类型。
- 问题描述python的pandas库中有一个十分便利的isnull()函数,它可以用来判断缺失值,我们通过几个例子学习它的使用方法。首先我们
- 其主要的优点便是无需再手工添加大量的信息了,可以指定对某一个站信息的截取进行批量录入,达到省时省力的目的。与其单纯的ASP小偷程序不同的是:
- 50个常用sql语句 Student(S#,Sname,Sage,Ssex) 学生表 Course(C#,Cname,T#) 课程表 SC(
- 古巴比伦王颁布了汉摩拉比法典,刻在黑色的玄武岩,距今已经三千七百多年,你在橱窗前…熟悉吧?没错,这就是周董的爱在西元前歌词。前不久工作不是很
- 在使用python函数print()时,如下代码会出现输出无法显示的问题:分三次在一行输出 123print(1, end="&q
- 概述虽然Python的强项在人工智能,数据处理方面,但是对于日常简单的应用,Python也提供了非常友好的支持(如:Tkinter),本文主
- 项目介绍go-admin 是一个中后台管理系统,基于(gin, gorm, Casbin, Vue, Element UI)实现。主要目的是
- 我们在前面的几节中分别讲了提高网站性能中内容、服务器、JavaScript和CSS等方面的内容。除此之外,图片和Coockie也是我们网站中
- 如何做一个检索结果带链接的检索?具体代码和说明如下:<% data=request.form("search_da
- 在批评Python的讨论中,常常说起Python多线程是多么的难用。还有人对 global interpreter lock(也被亲切的称为
- 蜗牛很慢。蜗牛快递会怎样?答案是:当然也会很慢。但是蜗牛尽了他的全力,为了它的兔子朋友,以生命在奔跑。每天都是24个小时,快的只是速度,却不