This is another potential improvement to the transformer architecture from Facebook (the other one that comes to mind is this one from the same authors: https://arxiv.org/abs/2405.18719), but note that it comes with a major problem that might not be obvious at first glance: it's just not usable in practice without a ton of work. It modifies the innards of the attention mechanism, so it is incompatible with Flash Attention (or any other optimized attention library), and you do not want to train anything beyond toy models without Flash Attention (the performance hit is just way too big).
There's PyTorch's FlexAttention, which could maybe make this practical, but currently it's just way too buggy.
People familiar with exotic RNNs and improvements to LSTMs know this problem all too well. The moment your LSTM isn't a bog-standard LSTM, it loses all the speed-ups from cuDNN and becomes borderline unusable for anything but toy models.
These would be inherently temporary problems though, right? If it eventually became clear that alternate methods were the way forward, NVIDIA would be highly motivated to do the optimization work, wouldn't they? Any new step function that can forestall the asymptotic plateauing of AI progress is something they desperately need.
Why do you say FlexAttention is too buggy? I have heard about a lot of successful usages of it, and never heard about any such problems.
Also note that, depending on your model dimensions and sequence lengths, the attention computation often plays only a minor role (maybe 10% overall or so), and the MLP computation dominates.
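To make that concrete, here's a rough back-of-the-envelope FLOP count for a single layer (my own sketch; the exact constants depend on the block layout):

```python
# Rough forward-pass FLOP estimate for one transformer layer, to see when the
# n^2 attention term matters. Assumed: model dim d, sequence length n, 4x MLP.
def layer_flops(d: int, n: int) -> dict:
    proj = 8 * n * d * d   # Q/K/V/output projections (four d x d matmuls)
    attn = 4 * n * n * d   # Q @ K^T plus scores @ V (the quadratic part)
    mlp  = 16 * n * d * d  # up- and down-projection with a 4d hidden size
    total = proj + attn + mlp
    return {"attention_n2_share": attn / total, "mlp_share": mlp / total}

# At d=4096, n=2048 the quadratic attention term is under 10% of the layer;
# it only starts to dominate once n grows well past d.
print(layer_flops(d=4096, n=2048))
```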
Last time I tried it I encountered both showstopper bugs (it was completely, obviously broken) and subtle correctness bugs (it looked like it was working, but since I'm paranoid I have unit tests for everything, and numerically the errors were too big compared to what you'd get with eager attention or Flash Attention). It was also too slow for my taste compared to Flash Attention, so I just dropped it. And I wasn't even doing anything particularly exotic with it.
Maybe it's better now, but I'd still consider using FlexAttention without a corresponding unit test checking its accuracy against an equivalent eager implementation completely irresponsible.
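The kind of check I mean is cheap to write. A minimal sketch, assuming PyTorch 2.5+ with torch.nn.attention.flex_attention (the tolerances here are my own guess and should be tuned for your dtype):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def eager_attention(q, k, v):
    # Plain reference implementation, computed in fp32 for comparison.
    scale = q.shape[-1] ** -0.5
    scores = (q.float() @ k.float().transpose(-2, -1)) * scale
    return (scores.softmax(dim=-1) @ v.float()).to(q.dtype)

def test_flex_matches_eager():
    torch.manual_seed(0)
    q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
               for _ in range(3))
    out_flex = flex_attention(q, k, v)   # no score_mod: plain attention
    out_ref = eager_attention(q, k, v)
    # bf16 tolerances are a judgment call; tighten them for fp32.
    torch.testing.assert_close(out_flex, out_ref, atol=2e-2, rtol=2e-2)
```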
How does this compare with the Byte Latent Transformer [1]? This works via convolution post-embedding, while BLT works via attention at embedding time?
1. https://ai.meta.com/research/publications/byte-latent-transf...
As I understand it, BLT uses a small nn to tokenize but doesn’t change the attention mechanism. MTA uses traditional BPE for tokenization but changes the attention mechanism. You could use both (latency be damned!)
Sure, you can get better model performance by throwing more compute at the problem in different places. Does it improve perf on an isoFLOP basis?
That's... not always a given for SOTA sized models. When the ROI on more training stops, it is nice to have alternatives, whether that is RL-tuned reasoning models or alternative architectures that improve specific areas of weakness.
It's a valid criticism that this method would increase compute requirements, but sometimes an improvement in the end result justifies the compute needed. For things like code generation in large datasets, many people would be willing to "pay" with more compute if the results were better. And this doesn't seem to require more memory bandwidth, so it could be particularly good for local models.
I read the paper and the results don't really convince me that that's the case. But the problem still remains of being able to use information from different parts of the model without squishing it to a single value with the softmax.
There’s no one-size-fits-all answer here, but in my experience, for long contexts, conv-based methods outperform strictly attention-based methods. See Evo 2:
“With the current implementation of Evo2, we do not have the heavily optimized kernels in place for convolution operators like we do for attention layers in a model like llama2. Even with this shortcoming, we see that the benefit from including more convolutional layers makes up for the earlier stage of optimization at around the 64k context length. Beyond that point we see an improvement in performance even compared to a highly optimized transformer model.“
https://docs.nvidia.com/bionemo-framework/latest/models/evo2...
So, we're proposing a multiplicative increase of something that already scales quadratically with the context size?
I think we've already got a bit of a bottleneck in terms of memory bandwidth utilization.
If you have a bottleneck in terms of memory bandwidth utilization, this method is great - it would utilize the idle compute.
LLaMa 3 already has RoPE encoding which can handle arbitrarily long contexts (within reason)
https://arxiv.org/abs/2104.09864
The difference RoPE makes vs. traditional positional encoding is that you only care about relative distances between tokens, and the attention can be attenuated over great distances.
Instead of making the model look at every token in the entire sequence all at once (which gets expensive fast), you can break the text into logical chunks—like sentences or paragraphs—and run self-attention within each chunk. That keeps things efficient while still capturing local meaning. Then, for each chunk, you create a summary—either by pooling or using a small learned head—and pass those summaries into a second layer of attention that operates on a much smaller scale. This gives you higher-level context across the document, kind of like moving from sentences to sections to the whole thing. Optionally, you can even send that higher-level context back down to influence the lower layers. This approach shows up in models like Longformer and BigBird (which use attention windows), hierarchical models (like HANs), and newer architectures like RetNet and Mamba that compress information over time or scale. RoPE fits neatly into this by helping each chunk handle relative positions more naturally.
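Not any particular model's code, just a minimal sketch of the chunk-then-summarize idea (module names and sizes are made up):

```python
import torch
import torch.nn as nn

class TwoLevelAttention(nn.Module):
    """Sketch: local self-attention within chunks, then attention over chunk summaries."""
    def __init__(self, d_model: int = 256, n_heads: int = 4, chunk: int = 64):
        super().__init__()
        self.chunk = chunk
        self.local_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.global_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(self, x):  # x: (batch, seq, d_model), seq divisible by chunk
        b, n, d = x.shape
        chunks = x.view(b * n // self.chunk, self.chunk, d)
        local, _ = self.local_attn(chunks, chunks, chunks)          # attention inside each chunk
        summaries = local.mean(dim=1).view(b, n // self.chunk, d)   # pooled chunk summaries
        ctx, _ = self.global_attn(summaries, summaries, summaries)  # attention across chunks
        # broadcast the chunk-level context back down to every token in its chunk
        ctx = ctx.repeat_interleave(self.chunk, dim=1)
        return local.view(b, n, d) + ctx
```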
RoPE is kind of perfect for this setup because it handles relative positions directly in the attention mechanism, which means each chunk can still understand the order and spacing of tokens without relying on fixed position embeddings. It’s especially useful when you're working with long sequences or chunked inputs, because it doesn’t care where the chunk is in the overall document—it just cares about how tokens relate to each other within that chunk. RoPE also makes it easier for models to generalize to longer inputs than they were trained on, since the rotational math behind it naturally extends beyond the original context window. Plus, because it's baked into the dot product itself, it adds no extra memory or computation, and plays well with hierarchical or multi-scale attention setups. Basically, it’s a clean, efficient way to inject positional awareness that doesn’t break when you start slicing things up.
PS: LLaMA's RoPE may be a bit off but it still works great: https://discuss.huggingface.co/t/is-llama-rotary-embedding-i...
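A minimal toy sketch of the rotation (not LLaMA's exact layout, which pairs the dimensions differently), just to show why only relative offsets survive in the q·k scores:

```python
import torch

def rope(x, pos, base: float = 10000.0):
    # x: (seq, dim) with even dim; pos: (seq,) absolute positions.
    dim = x.shape[-1]
    freqs = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    angles = pos[:, None].float() * freqs[None, :]   # (seq, dim/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    # rotate each (x1, x2) pair by its position-dependent angle
    return torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).flatten(-2)

q, k = torch.randn(16, 64), torch.randn(16, 64)
pos = torch.arange(16)

# Shifting every position by the same offset leaves the q.k scores (nearly)
# unchanged, because the rotations only depend on relative distances.
s1 = rope(q, pos) @ rope(k, pos).T
s2 = rope(q, pos + 100) @ rope(k, pos + 100).T
print(torch.allclose(s1, s2, atol=1e-3))   # True, up to float error
```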
> allowing nearby queries and keys to affect each other's attention weights for more precise attention
If it is only nearby tokens, it's multiplicative by a constant, right? Not making it scale cubically with context length or anything.
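Right - a small fixed kernel over the n x n score map adds a constant factor, not another power of n. A quick back-of-the-envelope check (kernel sizes here are just example values, not taken from the paper):

```python
# Cost of convolving the n x n attention score map with a fixed (c_q x c_k)
# kernel: roughly c_q * c_k multiply-adds per score, i.e. a constant factor
# on top of the existing O(n^2) attention, not a jump to O(n^3).
n, d = 4096, 128          # context length, head dim
c_q, c_k = 6, 11          # example query/key kernel sizes
base = n * n * d          # Q @ K^T for one head
conv = n * n * c_q * c_k  # key-query convolution over the score map
print(conv / base)        # ~0.5x extra work here; constant in n either way
```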
DeepSeek got a training performance increase by predicting two tokens at a time, though that doesn't go into the final model inference like this does. They did say it can be used for speculative decoding to reduce inference costs, though.
They may get away with fewer attention heads with this new approach, too.
Maybe Sam was right about needing one trillion dollars!
Interesting. So they convolve the k,v, q vectors? I have been trying the opposite.
I have been working on a classification problem on audio data (with a context size somewhere between 1000 and 3000, with the potential to expand later), and I have been experimenting with adding attention onto a CNN for that task.
I tried training a vanilla transformer, but at the sizes I am aiming for (5-30M parameters), the training is incredibly unstable and doesn't achieve the performance of an LSTM.
So I went back to CNNs, which are fast to train but don't achieve the losses of LSTMs (which are much slower to train, and for higher context sizes you get into the vanishing gradient problem). A CNN-GRU hybrid worked much better, giving me my best result.
The GRU layer I used had a size of 512. For increasing context sizes, I'd have to make the convolutional layers deeper so as not to let the GRU grow too large. Instead, I decided to swap out the GRU for a MultiHeadAttention layer. The results are great - better than the CNN-GRU (my previous best). Plus, for equivalent sizes the model is faster to train, though it hogs a lot of memory.
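A stripped-down sketch of that kind of setup - a 1D conv front-end feeding a MultiHeadAttention layer (sizes here are placeholders, not the actual model):

```python
import torch
import torch.nn as nn

class ConvAttnClassifier(nn.Module):
    """Sketch: conv layers shorten the time axis, then self-attention + pooling."""
    def __init__(self, n_features: int = 64, d_model: int = 256,
                 n_heads: int = 8, n_classes: int = 10):
        super().__init__()
        self.conv = nn.Sequential(  # downsamples the time axis 4x
            nn.Conv1d(n_features, d_model, kernel_size=5, stride=2, padding=2),
            nn.GELU(),
            nn.Conv1d(d_model, d_model, kernel_size=5, stride=2, padding=2),
            nn.GELU(),
        )
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.head = nn.Linear(d_model, n_classes)

    def forward(self, x):  # x: (batch, time, n_features)
        h = self.conv(x.transpose(1, 2)).transpose(1, 2)  # (batch, time/4, d_model)
        h, _ = self.attn(h, h, h)                          # stands in for the GRU layer
        return self.head(h.mean(dim=1))                    # mean-pool over time

model = ConvAttnClassifier()
logits = model(torch.randn(8, 2000, 64))   # e.g. a ~2000-step context
print(logits.shape)                         # torch.Size([8, 10])
```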
What codec were you using for the audio data?
We have to move past tokenization for the next leap in capabilities. All this work done on tokens, especially in the RL optimization context, is just local optimization alchemy.
LLMs in their entirety are unlikely to move past tokenization - it is the inescapable core from the roots of NLP and Markov Chains.
The future of AI and all of ML in general likely does exist beyond tokenization, but I find it unlikely we will get there without moving past LLMs as a whole.
We need to focus on the strengths of LLMs and abandon the incredibly wasteful effort being put into making them produce convincing facsimiles of things they can't do, just because the output is in natural language and easily fools humans at first glance.
They won't move past tokenization, but you can take it down to the byte level and make it arbitrarily flexible and adaptive:
https://ai.meta.com/research/publications/byte-latent-transf...
This is valid but also hard to back up with any alternatives. At the end of the day it’s just a neural network with backprop. New architectures will likely only be marginally better. So either we add new algorithms on top of it like RL, create a new learning algorithm (for example forward-forward), or we figure out how to use more energy-efficient compute (analog etc.) to scale several more orders of magnitude. It’s gonna take some time.
Yeah, that's fair - it's very easy to tell that LLMs are not the end state, but it's near impossible to know what comes next.
Personally I think LLMs will be relegated to transforming output and input from whatever new logic system is brought forth, rather than pretending they're doing logic by aggregating static corpora like we are now.
Why is there an expectation that “nearby” tokens are relevant to increase the information in the similarities? That seems like it would hold true within individual words, but the whole point of attention was to solve long range dependencies. Reintroducing local windows seems like a step backwards to me.
Maybe it's helpful to find the right point in the long context, but then have easy access to the local structure around that point.
eg, yes, the magically relevant point is the third word of the fifth paragraph on page 183 of the document, but then having a good representation of all of that page is more helpful than the single word.
This doesn’t answer your question, but one thing to keep in mind is that past the very first layer, every “token” position is a weighted average of every previous position, so adjacency isn’t necessarily related to adjacent input tokens.
A borderline tautological answer might be “because the network learns that putting related things next to each other increases the usefulness of the convolutions”
It's a little more inductive bias. That's not necessarily a step backwards. You need the right amount of inductive bias for a given data size and model capacity, no more and no less. Transformers already make the inductive bias of temporal locality by being causal.
So, why would this extract more semantic meaning than multi-head attention? Isn't the whole point of multiple heads similar to how CNNs use multiple types of filters to extract different semantic relationships?
Achieved by “applying convolution operations over queries, keys and heads, allowing nearby queries and keys to affect each other's attention weights for more precise attention”
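Concretely, my reading of the key-query part of that (a sketch, not the authors' code - details like where the masking sits and the extra convolution across heads are simplified away):

```python
import torch
import torch.nn.functional as F

def key_query_conv_attention(q, k, v, kernel):
    # q, k, v: (batch, heads, seq, head_dim); kernel: (heads, 1, c_q, c_k),
    # one small 2D kernel per head. A sketch, not the paper's exact formulation.
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale                  # (b, h, n, n)
    n = scores.shape[-1]
    causal = torch.ones(n, n, dtype=torch.bool, device=q.device).tril()
    # Zero out future keys before mixing so the convolution can't leak them backward.
    scores = scores.masked_fill(~causal, 0.0)
    # Convolve over the (query, key) axes so nearby positions influence each
    # other's scores; groups=heads gives every head its own kernel.
    mixed = F.conv2d(scores, kernel, padding="same", groups=scores.shape[1])
    # Re-mask with -inf before the softmax so future positions get zero weight.
    mixed = mixed.masked_fill(~causal, float("-inf"))
    return F.softmax(mixed, dim=-1) @ v

b, h, n, d = 2, 8, 128, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))
kernel = torch.randn(h, 1, 5, 5) * 0.1
out = key_query_conv_attention(q, k, v, kernel)   # (2, 8, 128, 64)
```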
Cool to see convolutions making such a comeback lately in the LLM world. See also the recent StripedHyena 2 architecture, which uses the conv-based Hyena operator to great success:
https://arxiv.org/abs/2503.01868
The null hypothesis is more compute or bigger network = better results. Conv operations make sense on images because the data is naturally 2 dimensional, so applying an operation across a sliding window makes sense.
Skimming the paper, I don’t see them testing against e.g. a normal decoder with an extra layer or something.
I don’t see the same logic applying on an embedding, where the individual indexes matter. Adjacent indexes in an embedding have no relationship, unlike adjacent pixels in an image.
They do have a weak relationship, in that earlier index tokens were encountered earlier during the formation of the vocabulary, so they are similar in typicality
No, if you check the diagram (page 2) these are literally indexes into the KV vectors, not positional indexes in the text. If it was the text I would agree with you.
Convolutions are used in many non-image applications, including language (eg dilated convolutions have been popular for some time) and 1D cases. The paper I linked references the hyena operator, which is literally a convolution replacement for attention (though it’s often used in hybrid architectures like the one I linked).
There is a planet-wise eternal 100% safe AI solution that can be a billion dollar startup, too:
Put all the GPUs in cloud(s) controlled by international scientists (now you can use your GPU on any device, can earn money by renting it out when you don’t need it, nothing changes except you need to be online to use it, but we’ll have 5G and better worldwide; you can develop, sell or release free math-proven safe AI models in this cloud “AI App Store”, etc.).
Because the main risk is an AI agent botnet - current GPUs are like nukes that are 100% unprotected - any hacker can make a virus with an AI agent component just to steal money; this AI will not be aligned at all and will become a perpetual and eventually autonomous botnet.