The Expressive Power of Transformers with Chain of Thought

2310.07923

Published 4/15/2024 by William Merrill, Ashish Sabharwal

Abstract

Recent theoretical work has identified surprisingly simple reasoning problems, such as checking if two nodes in a graph are connected or simulating finite-state machines, that are provably unsolvable by standard transformers that answer immediately after reading their input. However, in practice, transformers' reasoning can be improved by allowing them to use a chain of thought or scratchpad, i.e., generate and condition on a sequence of intermediate tokens before answering. Motivated by this, we ask: Does such intermediate generation fundamentally extend the computational power of a decoder-only transformer? We show that the answer is yes, but the amount of increase depends crucially on the amount of intermediate generation. For instance, we find that transformer decoders with a logarithmic number of decoding steps (w.r.t. the input length) push the limits of standard transformers only slightly, while a linear number of decoding steps, assuming projected pre-norm (a slight generalization of standard pre-norm), adds a clear new ability (under standard complexity conjectures): recognizing all regular languages. Our results also imply that linear steps keep transformer decoders within context-sensitive languages, and polynomial steps with generalized pre-norm make them recognize exactly the class of polynomial-time solvable problems -- the first exact characterization of a type of transformers in terms of standard complexity classes. Together, this provides a nuanced framework for understanding how the length of a transformer's chain of thought or scratchpad impacts its reasoning power.

Overview

  • Researchers have found that standard transformer models, which provide immediate outputs, are limited in their ability to solve certain simple reasoning problems.
  • However, transformers can improve their reasoning by generating and conditioning on a sequence of intermediate tokens before answering, known as a "chain of thought" or "scratchpad."
  • This paper explores whether this intermediate generation fundamentally extends the computational power of a decoder-only transformer.

Plain English Explanation

Transformers are a type of artificial intelligence model that has become widely used for tasks like language processing and generation. These models typically provide an immediate output after reading their input.

However, recent theoretical work has identified some surprisingly simple reasoning problems, such as checking if two nodes in a graph are connected or simulating finite-state machines, that standard transformers provably cannot solve when they must answer immediately after reading their input.

To address this limitation, the researchers in this paper explored whether allowing transformers to use a "chain of thought" or "scratchpad" could fundamentally increase their computational power. This means the transformer would generate and condition on a sequence of intermediate tokens before providing a final answer.

The key finding is that the answer is yes, but the gain in computational power depends heavily on how many intermediate tokens are generated. For example, a transformer with a logarithmic number of decoding steps (relative to the input length) only slightly improves on standard transformers, while a linear number of decoding steps, assuming projected pre-norm (a slight generalization of standard pre-norm), lets the transformer recognize all regular languages. Under standard complexity conjectures, that is a clear new ability.
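To get a feel for how different these step budgets are, the short sketch below computes them for a given input length. The function name `step_budgets` and the choice of degree 2 for the polynomial case are illustrative assumptions, not something taken from the paper.

```python
import math


def step_budgets(n: int, degree: int = 2) -> dict:
    """Return illustrative decoding-step budgets for an input of length n.

    `degree` is an assumed exponent for the polynomial regime; the paper's
    result concerns polynomial steps in general, not any fixed degree.
    """
    if n < 1:
        raise ValueError("input length must be at least 1")
    return {
        # Logarithmic regime: roughly log2(n) intermediate tokens.
        "logarithmic": max(1, math.ceil(math.log2(n))),
        # Linear regime: about one intermediate token per input token.
        "linear": n,
        # Polynomial regime: n raised to a fixed power.
        "polynomial": n ** degree,
    }


if __name__ == "__main__":
    for length in (16, 1024):
        print(length, step_budgets(length))
```

For an input of 1024 tokens, this gives 10 steps in the logarithmic regime, 1024 in the linear regime, and 1,048,576 in the (assumed quadratic) polynomial regime.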

The researchers also show that linear decoding steps keep transformers within the class of context-sensitive languages, while polynomial steps with generalized pre-norm make them recognize exactly the class of problems solvable in polynomial time. This provides a nuanced framework for understanding how the length of a transformer's "chain of thought" impacts its reasoning power.

Technical Explanation

This paper examines whether allowing transformer decoders to generate and condition on a sequence of intermediate tokens, rather than providing a single immediate output, can fundamentally extend their computational power.

The paper starts from prior theoretical results showing that standard transformer decoders, which answer immediately after reading their input, provably cannot solve certain simple reasoning problems, such as checking graph connectivity and simulating finite-state machines. Intuitively, a model that must answer immediately has no "scratchpad" in which to record and update intermediate state.

To overcome this limitation, the authors consider transformer decoders that are allowed to generate and condition on a sequence of intermediate tokens before providing a final output. This mimics the "chain of thought" or "scratchpad" that humans often use when solving complex reasoning problems.
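As a rough illustration of why a scratchpad helps with finite-state machines, the sketch below mimics the decoding loop in plain Python: each step sees the input plus everything generated so far and emits one new token, here the machine's next state. This is not a transformer and not the paper's construction; `decode_with_scratchpad`, `make_dfa_step`, and the parity automaton are hypothetical names chosen for the example.

```python
from typing import Callable, Dict, List, Tuple

StepFn = Callable[[List[str], List[str]], str]


def decode_with_scratchpad(step: StepFn, input_tokens: List[str], num_steps: int) -> List[str]:
    """Run `num_steps` decoding steps and return the generated tokens.

    Each step may read the full input and all previously generated tokens,
    which is the only "memory" the loop carries forward.
    """
    scratchpad: List[str] = []
    for _ in range(num_steps):
        scratchpad.append(step(input_tokens, scratchpad))
    return scratchpad


def make_dfa_step(transitions: Dict[Tuple[str, str], str], start: str) -> StepFn:
    """Build a step function that writes the next DFA state to the scratchpad."""

    def step(input_tokens: List[str], scratchpad: List[str]) -> str:
        position = len(scratchpad)  # one input symbol is consumed per step
        previous_state = scratchpad[-1] if scratchpad else start
        return transitions[(previous_state, input_tokens[position])]

    return step


# Toy automaton (an assumption for this example): tracks whether the
# number of "1" symbols seen so far is even or odd.
PARITY: Dict[Tuple[str, str], str] = {
    ("even", "0"): "even",
    ("even", "1"): "odd",
    ("odd", "0"): "odd",
    ("odd", "1"): "even",
}


def accepts_even_ones(bits: str) -> bool:
    """Accept strings over {0, 1} that contain an even number of 1s."""
    tokens = list(bits)
    # Linear budget: exactly one intermediate token per input symbol.
    states = decode_with_scratchpad(make_dfa_step(PARITY, "even"), tokens, len(tokens))
    final_state = states[-1] if states else "even"
    return final_state == "even"


if __name__ == "__main__":
    print(decode_with_scratchpad(make_dfa_step(PARITY, "even"), list("110"), 3))
    print(accepts_even_ones("1001"), accepts_even_ones("1101"))
```

Each step only needs the previous state and one input symbol, so the number of generated tokens grows linearly with the input. That matches the linear-step regime in which the paper shows regular languages become recognizable.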

The key results are:

  • A logarithmic number of decoding steps (relative to input length) only slightly extends the power of standard transformers.
  • A linear number of decoding steps, assuming projected pre-norm (a slight generalization of standard pre-norm), allows transformers to recognize all regular languages, a clear new ability under standard complexity conjectures.
  • Linear decoding steps keep transformers within the class of context-sensitive languages.
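  • Polynomial decoding steps with generalized pre-norm make transformers recognize exactly the class of problems solvable in polynomial time, which the authors describe as the first exact characterization of a type of transformer in terms of standard complexity classes.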

These findings provide a nuanced framework for understanding how the length of a transformer's "chain of thought" impacts its reasoning capabilities, ranging from slight improvements to the ability to solve more complex problems.
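The linear and polynomial results can be written compactly as class containments. In the sketch below, $\mathsf{CoT}(t(n))$ is shorthand assumed here for the languages recognized by transformer decoders that take $t(n)$ intermediate steps on inputs of length $n$; $\mathrm{REG}$ and $\mathrm{CSL}$ denote the regular and context-sensitive languages, and $\mathsf{P}$ is polynomial time. The logarithmic case is left out because the abstract describes it only qualitatively.

```latex
\begin{align*}
  % Linear steps: the lower bound assumes projected pre-norm.
  \mathrm{REG} \;\subseteq\; \mathsf{CoT}(O(n)) \;\subseteq\; \mathrm{CSL} \\
  % Polynomial steps: equality holds with generalized pre-norm.
  \mathsf{CoT}(\mathrm{poly}(n)) \;=\; \mathsf{P}
\end{align*}
```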

Critical Analysis

The paper provides a rigorous theoretical analysis of how the ability to generate and condition on intermediate tokens can extend the computational power of transformer decoders. The insights offered are valuable for understanding the fundamental limitations and capabilities of these widely used models.

One potential limitation of the research is that it focuses solely on the theoretical computational power of transformers, without considering practical aspects such as training dynamics, sample efficiency, and real-world performance. While the theoretical results are important, it would be valuable to see how these insights translate to actual transformer-based systems and their performance on relevant tasks.

Additionally, the paper does not explore the implications of these findings for the development of more capable reasoning systems. Further research could investigate how the insights from this work could be leveraged to design transformer architectures or training approaches that better support robust reasoning and problem-solving abilities.

Overall, this paper offers a significant contribution to the understanding of transformer models and their computational limitations and capabilities. The findings provide a solid foundation for future research in this area.

Conclusion

This paper investigates whether allowing transformer decoders to generate and condition on a sequence of intermediate tokens, rather than providing a single immediate output, can fundamentally extend their computational power. The key finding is that it can, but the extent of the increase depends heavily on the length of the intermediate generation process.

The researchers show that a logarithmic number of decoding steps only slightly improves on standard transformers, while a linear number of steps can enable transformers to recognize all regular languages. They also characterize the computational classes that transformers with different step counts can represent, providing a nuanced framework for understanding the impact of "chain of thought" or "scratchpad" generation on reasoning abilities.

These insights contribute to a deeper understanding of the fundamental limitations and capabilities of transformer models, which are widely used in various artificial intelligence applications. The work also suggests promising directions for future research on developing more capable reasoning systems, by leveraging the potential of intermediate generation to push the boundaries of what transformers can achieve.



This summary was produced with help from an AI and may contain inaccuracies - check out the links to read the original source documents!

Related Papers

Chain of Thought Empowers Transformers to Solve Inherently Serial Problems

Zhiyuan Li, Hong Liu, Denny Zhou, Tengyu Ma

Instructing the model to generate a sequence of intermediate steps, a.k.a., a chain of thought (CoT), is a highly effective method to improve the accuracy of large language models (LLMs) on arithmetics and symbolic reasoning tasks. However, the mechanism behind CoT remains unclear. This work provides a theoretical understanding of the power of CoT for decoder-only transformers through the lens of expressiveness. Conceptually, CoT empowers the model with the ability to perform inherently serial computation, which is otherwise lacking in transformers, especially when depth is low. Given input length $n$, previous works have shown that constant-depth transformers with finite precision $\mathsf{poly}(n)$ embedding size can only solve problems in $\mathsf{TC}^0$ without CoT. We first show an even tighter expressiveness upper bound for constant-depth transformers with constant-bit precision, which can only solve problems in $\mathsf{AC}^0$, a proper subset of $\mathsf{TC}^0$. However, with $T$ steps of CoT, constant-depth transformers using constant-bit precision and $O(\log n)$ embedding size can solve any problem solvable by boolean circuits of size $T$. Empirically, enabling CoT dramatically improves the accuracy for tasks that are hard for parallel computation, including the composition of permutation groups, iterated squaring, and circuit value problems, especially for low-depth transformers.

5/8/2024

What Formal Languages Can Transformers Express? A Survey

Lena Strobl, William Merrill, Gail Weiss, David Chiang, Dana Angluin

As transformers have gained prominence in natural language processing, some researchers have investigated theoretically what problems they can and cannot solve, by treating problems as formal languages. Exploring such questions can help clarify the power of transformers relative to other models of computation, their fundamental capabilities and limits, and the impact of architectural choices. Work in this subarea has made considerable progress in recent years. Here, we undertake a comprehensive survey of this work, documenting the diverse assumptions that underlie different results and providing a unified framework for harmonizing seemingly contradictory findings.

5/8/2024

When can transformers reason with abstract symbols?

Enric Boix-Adsera, Omid Saremi, Emmanuel Abbe, Samy Bengio, Etai Littwin, Joshua Susskind

We investigate the capabilities of transformer models on relational reasoning tasks. In these tasks, models are trained on a set of strings encoding abstract relations, and are then tested out-of-distribution on data that contains symbols that did not appear in the training dataset. We prove that for any relational reasoning task in a large family of tasks, transformers learn the abstract relations and generalize to the test set when trained by gradient descent on sufficiently large quantities of training data. This is in contrast to classical fully-connected networks, which we prove fail to learn to reason. Our results inspire modifications of the transformer architecture that add only two trainable parameters per head, and that we empirically demonstrate improve data efficiency for learning to reason.

4/17/2024

General Purpose Verification for Chain of Thought Prompting

Robert Vacareanu, Anurag Pratik, Evangelia Spiliopoulou, Zheng Qi, Giovanni Paolini, Neha Anna John, Jie Ma, Yassine Benajiba, Miguel Ballesteros

Many of the recent capabilities demonstrated by Large Language Models (LLMs) arise primarily from their ability to exploit contextual information. In this paper, we explore ways to improve reasoning capabilities of LLMs through (1) exploration of different chains of thought and (2) validation of the individual steps of the reasoning process. We propose three general principles that a model should adhere to while reasoning: (i) Relevance, (ii) Mathematical Accuracy, and (iii) Logical Consistency. We apply these constraints to the reasoning steps generated by the LLM to improve the accuracy of the final generation. The constraints are applied in the form of verifiers: the model itself is asked to verify if the generated steps satisfy each constraint. To further steer the generations towards high-quality solutions, we use the perplexity of the reasoning steps as an additional verifier. We evaluate our method on 4 distinct types of reasoning tasks, spanning a total of 9 different datasets. Experiments show that our method is always better than vanilla generation, and, in 6 out of the 9 datasets, it is better than best-of N sampling which samples N reasoning chains and picks the lowest perplexity generation.

5/2/2024