Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues

    Published 11/20/2024 by Riccardo Grazzi, Julien Siems, Jorg K. H. Franke, Arber Zela, Frank Hutter, Massimiliano Pontil

    Overview

    • Research explores how negative eigenvalues enhance state tracking in Linear RNNs (LRNNs)
    • Demonstrates that negative eigenvalues let an LRNN's state flip sign from step to step, an oscillatory pattern that eigenvalues in [0, 1] cannot produce
    • Challenges the common design choice of restricting the eigenvalues of LRNN state-transition matrices to [0, 1]
    • Shows improved performance on state-tracking tasks such as parity

    Expanding eigenvalue range improves LRNN parity learning accuracy.

    Original caption: Figure 1: Extending the eigenvalue range of the state transition matrices of diagonal LRNNs improves performance from random guessing (range [0, 1]) to perfect score (range [−1, 1]) on learning parity. Trained on sequences up to length 40; Tested on lengths 40–256 (3 seeds).

    Table shows instances of LRNN layers, with learnable parameters and functions.

    Method   | $A(x_t)$                                        | $B(x_t)$                          | $\mathrm{dec}(H_t, x_t)$
    Mamba    | $\mathrm{Diag}(\exp(-\Delta_t \exp(w_{1,i})))$  | $k_{t,i}\, \Delta_t \odot x_t$    | $\psi(q_t^\top H_t^\top + w_2 \odot x_t)$
    GLA      | $\mathrm{Diag}(\alpha_t)$                       | $k_t v_t^\top$                    | $\psi(q_t^\top H_t^\top)$
    DeltaNet | $I - \beta_t k_t k_t^\top$                      | $\beta_t k_t v_t^\top$            | $\psi(q_t^\top H_t^\top)$

    Original caption: Table 1: Instances of LRNNs layers in (1), where $\alpha_t = \mathrm{sigmoid}(W_\alpha x_t)$, $\Delta_t = \mathrm{softplus}(W_\Delta x_t)$, $\beta_t = \mathrm{sigmoid}(w_\beta x_t)$, while $q_t, k_t \in \mathbb{R}^n$, $v_t \in \mathbb{R}^d$ are output of learnable possibly non-linear functions of $x_t$. Also $\psi: \mathbb{R}^d \to \mathbb{R}^d$ is another learnable function usually containing an MLP and a normalization, while $W_1 \in \mathbb{R}^{n \times d}$, $W_\Delta \in \mathbb{R}^{d \times l}$, $W_\alpha \in \mathbb{R}^{n \times l}$, $w_\beta \in \mathbb{R}^l$ and $w_2 \in \mathbb{R}^d$ are learnable parameters. For simplicity, we omitted 1D convolutions and for Mamba we wrote the matrices for the recursion of each row of $H_t$ and set $k_t = (k_{t,1}, \dots, k_{t,n})^\top$ and $W_1 = (w_{1,1}, \dots, w_{1,n})^\top$.
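    The layers in Table 1 plug into the recurrence the caption calls (1). The first line below is our reconstruction from the table's column headings, so the paper's exact notation may differ. The remaining lines work out the eigenvalue bounds implied by the gate definitions in the caption; the Mamba product is read entrywise, and the DeltaNet line assumes unit-norm keys, ‖k_t‖ = 1.

```latex
% Generic LRNN layer, reconstructed from the headings of Table 1
H_t = A(x_t)\, H_{t-1} + B(x_t), \qquad \hat{y}_t = \mathrm{dec}(H_t, x_t)

% Mamba: \Delta_t = \mathrm{softplus}(W_\Delta x_t) > 0 and \exp(w_{1,i}) > 0 entrywise
0 < \big[\exp(-\Delta_t \exp(w_{1,i}))\big]_j < 1 \quad \text{for every } j

% GLA: \alpha_t = \mathrm{sigmoid}(W_\alpha x_t)
0 < \alpha_{t,j} < 1 \quad \text{for every } j

% DeltaNet, assuming \|k_t\| = 1 and \beta_t = \mathrm{sigmoid}(w_\beta x_t) \in (0, 1)
(I - \beta_t k_t k_t^\top)\, k_t = (1 - \beta_t)\, k_t,
\qquad (I - \beta_t k_t k_t^\top)\, u = u \ \text{ for all } u \perp k_t
```

    Under these assumptions, every eigenvalue of A(x_t) lies in [0, 1] for all three layers: DeltaNet's eigenvalues are 1 (with multiplicity n − 1) and 1 − β_t.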

    Plain English Explanation

    Linear Recurrent Neural Networks (LRNNs), such as Mamba, GLA, and DeltaNet, are efficient models for processing sequences of information. Think of them like a person who keeps a running memory and updates it at every step. In common designs, each update can only keep or gradually forget that memory, because the eigenvalues of the state-transition matrix lie between 0 and 1.

    This research shows that allowing negative eigenvalues helps LRNNs track changing states much better. A negative eigenvalue flips the sign of part of the network's memory at each step, much like a pendulum swinging back and forth. With an eigenvalue of −1, for example, the memory toggles between two values every time it is applied, which is exactly what is needed to keep track of whether an even or odd number of events has occurred.
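    The pendulum picture can be made concrete with parity. The sketch below assumes a one-dimensional diagonal LRNN with hand-set (not learned) transition values and no input term: the state is multiplied by −1 on every input bit equal to 1 and by +1 otherwise, so its sign records whether an odd number of 1s has been seen. The helpers `run_scalar_lrnn` and `parity_via_sign` are hypothetical names for illustration, not code from the paper.

```python
from typing import Sequence

def run_scalar_lrnn(bits: Sequence[int], a_zero: float, a_one: float) -> float:
    """Run h_t = a(x_t) * h_{t-1} with h_0 = 1 and no input term B(x_t).

    a_zero and a_one are the hand-set transition values for inputs 0 and 1.
    """
    h = 1.0
    for x in bits:
        h *= a_one if x == 1 else a_zero
    return h

def parity_via_sign(bits: Sequence[int]) -> int:
    """Eigenvalue -1 on input 1 flips the sign, so the sign encodes parity."""
    h = run_scalar_lrnn(bits, a_zero=1.0, a_one=-1.0)
    return 1 if h < 0 else 0

print(parity_via_sign([1, 0, 1, 1]))  # three 1s, odd
print(parity_via_sign([1, 1, 0, 0]))  # two 1s, even
```

    If `a_one` is restricted to [0, 1], the product of transition values can never become negative, so the sign carries no parity information; the state can only shrink or stay constant.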

    The team found that these sign flips let LRNNs handle state-tracking tasks, such as keeping track of several changing pieces of information or checking whether a sequence follows the rules of a regular language. It's like giving the network the ability to juggle several balls instead of just holding onto one.

    Key Findings

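    State-tracking capabilities improve significantly when negative eigenvalues are allowed. The networks showed:

    • Perfect parity accuracy, instead of random guessing, for diagonal LRNNs whose eigenvalue range is extended from [0, 1] to [−1, 1] (Figure 1)
    • Generalization to longer sequences on parity: trained on lengths up to 40, tested on lengths 40–256
    • Consistent gains on state-tracking tasks for Mamba and DeltaNet with the extended range
    • Language-modeling performance and training stability comparable to the original models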

    Technical Explanation

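    The research enables oscillatory behavior in LRNNs by letting the state-transition matrix A(x_t) have eigenvalues in [−1, 1] instead of [0, 1]. For the diagonal (Mamba, GLA) and symmetric (DeltaNet) matrices in Table 1, every eigenvalue having magnitude at most 1 means the state cannot grow without bound, so the recurrence stays stable while components of the state can change sign from one step to the next.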
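    A minimal sketch of one way to widen the range, under stated assumptions: the affine map 2a − 1 sends a Mamba-style diagonal entry a ∈ (0, 1) to (−1, 1), and doubling DeltaNet's β turns the key-direction eigenvalue 1 − β into 1 − 2β ∈ (−1, 1). These rescalings are our illustration of the idea (we believe the paper uses a similar construction, but check it for the exact parameterization). The helper names are hypothetical, and scalars stand in for the full vectors and matrices of Table 1.

```python
import math

def softplus(z: float) -> float:
    return math.log1p(math.exp(z))

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

def mamba_entry(delta_pre: float, w: float, extended: bool) -> float:
    """One diagonal entry exp(-Delta * exp(w)) with Delta = softplus(delta_pre).

    Hypothetical extension: map a in (0, 1) to 2a - 1 in (-1, 1).
    """
    a = math.exp(-softplus(delta_pre) * math.exp(w))  # in (0, 1)
    return 2.0 * a - 1.0 if extended else a

def deltanet_key_eigenvalue(beta_pre: float, extended: bool) -> float:
    """Eigenvalue of I - s * beta * k k^T along a unit-norm key k.

    s = 2 in the hypothetical extended version, else 1. All eigenvalues
    for directions orthogonal to k stay equal to 1.
    """
    beta = sigmoid(beta_pre)  # in (0, 1)
    scale = 2.0 if extended else 1.0
    return 1.0 - scale * beta

# Usage: larger pre-activations push the extended values toward -1.
for z in (-3.0, 0.0, 3.0):
    print(z, mamba_entry(z, 0.0, extended=True), deltanet_key_eigenvalue(z, extended=True))
```

    Because |2a − 1| ≤ 1 and |1 − 2β| ≤ 1, the rescaling keeps every eigenvalue within the unit interval in magnitude; only the sign becomes free, which is what the stability argument above relies on.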

    The experiments compared models whose eigenvalues are restricted to [0, 1] with variants that allow negative values, on state-tracking tasks and on language modeling. The results show that negative eigenvalues enable state tracking that the restricted models cannot achieve, with parity as the clearest example (Figure 1).
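    A toy version of this comparison fits in a few lines. Assuming a scalar recurrence h_t = a(x_t) h_{t−1} with h_0 = 1 and a single-threshold readout (our simplification, not the paper's experimental setup), the sketch grid-searches the transition value for input 1 and reports the best parity accuracy on all-ones sequences of lengths 1 to 12. The function name `best_parity_accuracy` is hypothetical.

```python
def best_parity_accuracy(a_one: float, max_len: int = 12) -> float:
    """Best accuracy of any single-threshold readout of h_T on all-ones inputs.

    Toy model (an assumption for illustration): h_t = a_one * h_{t-1}, h_0 = 1,
    so after n ones the state is a_one ** n. The label is 1 if n is odd.
    """
    lengths = list(range(1, max_len + 1))
    states = [a_one ** n for n in lengths]
    labels = [n % 2 for n in lengths]
    # Candidate thresholds: below all states, above all states, and midpoints
    # between consecutive distinct states; together they cover every split.
    s = sorted(set(states))
    taus = [s[0] - 1.0, s[-1] + 1.0] + [(lo + hi) / 2 for lo, hi in zip(s, s[1:])]
    best = 0.0
    for tau in taus:
        for odd_if_below in (True, False):
            preds = [int((h < tau) == odd_if_below) for h in states]
            correct = sum(p == y for p, y in zip(preds, labels))
            best = max(best, correct / len(labels))
    return best

grid = [i / 20 for i in range(21)]  # a_one in {0.0, 0.05, ..., 1.0}
print("best with a_one in [0, 1]:", max(best_parity_accuracy(a) for a in grid))
print("with a_one = -1:", best_parity_accuracy(-1.0))
```

    With a_one in [0, 1], the states a_one, a_one², … never increase, so one threshold can split them only once and the best accuracy here is 7/12; a_one = −1 alternates the sign and reaches 1.0. The paper's impossibility result is far more general, covering finite-precision LRNNs rather than this scalar model.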

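    Performance on regular-language tasks, which require tracking the current state of a finite-state machine, showed marked improvement, particularly in tasks requiring maintenance of multiple state variables. On the theory side, the paper proves that LRNNs whose state-transition matrices are products of identity-minus-outer-product (generalized Householder) matrices, each with eigenvalues in [−1, 1], can in principle recognize any regular language.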
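    To see why identity-minus-outer-product matrices with negative eigenvalues help with finite-state tracking, consider following which item sits in which slot after a sequence of swaps. A matrix I − βkkᵀ with unit-norm k and β = 2 has eigenvalue 1 − β = −1 along k, and choosing k = (e_i − e_j)/√2 makes it swap coordinates i and j. The sketch below, with hypothetical helper names and plain Python lists instead of a tensor library, runs this linear recurrence on one-hot item columns and checks it against direct simulation. It illustrates the idea behind the paper's result; it is not the paper's construction.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]

def householder_swap(n: int, i: int, j: int) -> Matrix:
    """I - 2 k k^T with k = (e_i - e_j) / sqrt(2); it swaps coordinates i and j."""
    k = [0.0] * n
    k[i], k[j] = 1.0 / math.sqrt(2), -1.0 / math.sqrt(2)
    return [[(1.0 if r == c else 0.0) - 2.0 * k[r] * k[c] for c in range(n)]
            for r in range(n)]

def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[r][m] * b[m][c] for m in range(len(b))) for c in range(len(b[0]))]
            for r in range(len(a))]

def track_swaps(n: int, swaps: Sequence[Tuple[int, int]]) -> List[int]:
    """Linear recurrence H_t = A(x_t) H_{t-1}, H_0 = I, one reflection per swap.

    Column c of H is a (numerically) one-hot vector locating item c.
    Returns the slot of each item.
    """
    h = [[1.0 if r == c else 0.0 for c in range(n)] for r in range(n)]
    for i, j in swaps:
        h = matmul(householder_swap(n, i, j), h)
    # Read out: the row holding the largest entry of column c is item c's slot.
    return [max(range(n), key=lambda r: h[r][c]) for c in range(n)]

def track_swaps_directly(n: int, swaps: Sequence[Tuple[int, int]]) -> List[int]:
    """Reference simulation: move whichever items occupy slots i and j."""
    slot_of = list(range(n))
    for i, j in swaps:
        for item in range(n):
            if slot_of[item] == i:
                slot_of[item] = j
            elif slot_of[item] == j:
                slot_of[item] = i
    return slot_of

swaps = [(0, 1), (1, 2), (0, 2), (0, 1)]  # hypothetical swap sequence
print(track_swaps(3, swaps), track_swaps_directly(3, swaps))
```

    Each swap matrix has eigenvalues 1 (with multiplicity n − 1) and −1, so this construction relies on exactly the negative part of the extended range: no matrix whose eigenvalues all lie in [0, 1] can perform a swap.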

    Critical Analysis

    While the results are promising, several limitations exist:

    • The relationship between eigenvalue patterns and specific tasks needs further exploration
    • Scaling properties for very long sequences remain unclear
    • The impact on training stability requires additional investigation
    • Potential trade-offs between oscillatory behavior and memory persistence remain to be characterized

    Conclusion

    This work clarifies how the eigenvalue range of the state-transition matrix limits what LRNNs can compute. Including negative eigenvalues opens new possibilities for sequence modeling without changing the cost of training or inference, and suggests that simpler architectures might be more capable than previously thought. This could lead to more efficient and effective neural network designs for sequence processing tasks.

    Full paper


    Read original: arXiv:2411.12537



    This summary was produced with help from an AI and may contain inaccuracies - check out the links to read the original source documents!
