SkillAgentSearch skills...

ScalaLSTM

Using scala to implement tiny LSTM, mainly focusing on the BPTT process of training the network.

Install / Use

/learn @xuanyuansen/ScalaLSTM
About this skill

Quality Score

0/100

Supported Platforms

Universal

README

###深入理解LSTM的BPTT算法 ####LSTM网络结构 关于LSTM网络的结构可以阅读这篇文章:http://colah.github.io/posts/2015-08-Understanding-LSTMs/

这里需要注意文章最后提及的LSTM两种变形,第一种是加入peephole,使得gate layer能够回溯前一个cell的状态,这增加了一些复杂度;第二种是GRU,将gate layer和forget layer合并为一个update layer,降低了复杂度。 ####LSTM网络的训练 LSTM的训练使用了BPTT算法,需要重要理解的一点是BPTT算法相当于BP算法扩展到序列(时序)数据,另一个需要理解的点是LSTM是recurrent neural network(这里注意理解recurrent neural network和recursive neural network的区别),BPTT算法在计算中要注意这一点。

####LSTM的计算图Compute Graph

  • LSTM的BPTT算法可以参考这篇文章http://nicodjimenez.github.io/2014/08/08/lstm.html

  • 讲述很清晰,注意这篇文章里面最后的输出h(t)没有加入tanh变换。

  • 为了理解LSTM的recursive特性,可以参考下图。

  • 从LSTM的结构可以看到,当前cell的状态会受到前一个cell状态的影响,这体现了LSTM的recursive特性。同时在误差反向传播计算时,可以发现h(t)的误差不仅仅包含当前时刻T的误差,也包括T时刻后所有时刻的误差,即back propagation through time的含义。这样每个时刻变量的误差都可以经由h(t)和c(t+1)迭代计算。

  • 为了使整个直观计算过程,在参考神经网络计算图分解的基础上,LSTM的计算图如下图所示,从计算图上面可以直观地看出LSTM的forward propagation和back propagation过程。

  • 从图中可以看出,H(t-1)的误差由H(t)决定,且要对所有的gate layer求和,c(t-1)由c(t)决定,而c(t)的误差由两部分,一部分是h(t),领一部分是c(t+1)。

  • 如果所示,在计算的时候,需要传入h(t)和c(t+1),h(t)在更新的时候需要加上h(t+1)。

####SCALA实现

  • breeze库

####利用SPARK实现minibatch方式的训练

####几种常见的LSTM结构

  • 1、原始LSTM
  • 2、peephole LSTM
  • 3、GRU

Related Skills

View on GitHub
GitHub Stars20
CategoryDevelopment
Updated1y ago
Forks11

Languages

Scala

Security Score

75/100

Audited on Mar 21, 2025

No findings