Pruning One More Token is Enough
This is a brief for the research paper “Pruning One More Token is Enough: Leveraging Latency-Workload Non-Linearities for Vision Transformers on the Edge”, published at WACV 2025. Here is a link to the full paper. This post was written by the first author, Nick Eliopoulos, and lightly edited by me.
Summary
Vision transformers (ViTs) are a common architectural component of deep neural networks (DNNs). Thus, techniques that improve ViT efficiency will benefit many DNN applications. The primary cost of a ViT is processing a stream of input tokens, so one way to improve ViT efficiency is to do less work by removing some tokens. This approach is called token sparsification. Prior works have used token sparsification on high-resource systems such as GPU servers and found that this technique can reliably improve throughput without significant accuracy degradation. However, when we tried out these methods on low-resource systems such as IoT devices, we sometimes observed the opposite effect: we still saw decreased accuracy, but unfortunately we also saw increased latency. We wondered why, and took a look under the hood with some detailed measurements. In this paper we show that the GPU Tail Effect can explain the phenomenon we observed. To sum up: hardware characteristics and machine learning frameworks both interact with the workload size (number of tokens), making the relationship between token count and latency non-linear.
Background
Making Vision Transformers More Efficient
ViT architectures such as DINOv2 achieve state-of-the-art accuracy on multiple tasks. However, these architectures are generally considered too costly to deploy in low-resource environments such as IoT gadgets and other devices on the “network edge”.
The cost of a ViT can be roughly characterized by the following equation:
ViT inference time = # tokens x Cost to process each token
Unsurprisingly, researchers have thus explored two kinds of optimizations: removing tokens (e.g., token sparsification) or reducing processing costs (e.g., quantization or optimized kernels like FlashAttention or Liger). Our work is in the former vein: token sparsification.
Token sparsification involves identifying and removing unnecessary tokens (inputs) from a ViT model during inference. By unnecessary we don’t mean that they are uninformative (though sometimes this is the case), but that the ViT can perform its work without the information they contain. For a simple example, there might be two adjacent tokens that are correlated, in which case only one of them need be included for training and inference. Or an image like a movie still might have black borders added to achieve a particular aspect ratio, in which case those borders contain no information.
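To make this concrete, here is a minimal sketch of what token removal looks like in practice, assuming a PyTorch token tensor with the [CLS] token at index 0. The norm-based importance score and the keep_ratio are hypothetical stand-ins for whatever heuristic or learned scorer a real sparsification method would use; this is not the paper's method.

```python
import torch

def prune_tokens(tokens: torch.Tensor, keep_ratio: float = 0.75) -> torch.Tensor:
    """Illustrative token sparsification: keep the highest-scoring patch tokens.

    tokens: (batch, num_tokens, embed_dim); the [CLS] token is assumed to be
    at index 0 and is always kept.
    """
    B, N, D = tokens.shape
    cls_tok, patch_toks = tokens[:, :1], tokens[:, 1:]

    # Hypothetical importance score: L2 norm of each patch token.
    # Real methods use attention scores, learned predictors, token merging, etc.
    scores = patch_toks.norm(dim=-1)                   # (B, N - 1)
    n_keep = max(1, int((N - 1) * keep_ratio))
    keep_idx = scores.topk(n_keep, dim=-1).indices     # (B, n_keep)

    kept = torch.gather(patch_toks, 1, keep_idx.unsqueeze(-1).expand(-1, -1, D))
    return torch.cat([cls_tok, kept], dim=1)           # (B, 1 + n_keep, D)

x = torch.randn(2, 197, 384)          # e.g., a DeiT-S-sized token sequence
print(prune_tokens(x, 0.75).shape)    # torch.Size([2, 148, 384])
```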
To pick out unnecessary tokens, sparsification methods commonly employ either heuristics or small neural networks trained for the task. Successful token sparsification approaches can reduce latency while avoiding significant accuracy degradation. However, as we explain shortly, existing token sparsification methods implicitly assume that removing more tokens will always reduce latency, and will do so in a linear way — which is, perhaps surprisingly, not always correct in a low-resource context.
Latency vs. Workload Relationships
The premise of token pruning is that reducing the workload (tokens) can decrease latency. In general, in the large — for large input sizes — this is true! However, in the small — for the small input sizes characteristic of IoT/Edge tasks — the relationship becomes non-linear.
The next figure shows the size of the effect. See the caption for details. Consequently, in many cases it is possible to achieve large latency reductions without removing too many tokens — and conversely, you may see a latency degradation if you do not take the hardware performance into account.
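If you want a rough feel for this behavior on your own hardware, one way is to time a single ViT-style transformer block while sweeping the token count. The sketch below is a standalone illustration (not the measurement harness from the paper), using a DeiT-S-sized block; the exact shape of the curve will depend on your device, framework version, and precision.

```python
import time
import torch

# A single ViT-style transformer block (DeiT-S-like width), used only to
# probe how latency varies with token count on a given device.
device = "cuda" if torch.cuda.is_available() else "cpu"
block = torch.nn.TransformerEncoderLayer(
    d_model=384, nhead=6, dim_feedforward=1536, batch_first=True
).to(device).eval()

@torch.no_grad()
def time_forward(num_tokens: int, iters: int = 50) -> float:
    x = torch.randn(1, num_tokens, 384, device=device)
    for _ in range(5):                       # warm-up
        block(x)
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        block(x)
    if device == "cuda":
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e3   # milliseconds

for n in range(120, 200, 8):
    print(f"{n:4d} tokens: {time_forward(n):.3f} ms")
```

Depending on the device, the printed latencies may show plateaus and sudden jumps rather than a smooth line; that is the non-linearity discussed above.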
Problem Statement
This relationship can stem from how workloads are dispatched to the GPU by an ML framework, framework overhead, and kernel design decisions. Prior work on an earlier DNN architecture, convolutional neural networks (CNNs), has taken advantage of this observation to guide channel pruning approaches. However, there is little investigation into how to leverage latency-workload relationships for the newer class of DNN architecture, Vision Transformers. Furthermore, some of the existing work that does investigate these relationships lacks first-hand measurements of the GPU metrics that lead to this latency behavior.
We summarize the knowledge gaps as follows:
- Many efficiency techniques (sparsification among them) do not consider fine-grained hardware characteristics or behavior.
- There is no understanding of how to use latency-workload relationships to guide ViT token sparsification.
- There is a general lack of primary measurements regarding GPU behavior in work that investigates latency-workload relationships.
Approach
Framing as an optimization problem
In order to improve token sparsification by considering the latency vs. workload behavior, we want to choose a number of tokens to prune, R, based on this relationship. A good value of R should yield a large latency reduction, while avoiding significant accuracy degradation. This problem can be easily framed as a multi-objective optimization problem.
Let’s use n to denote the number of tokens we give to the network. Now we have a couple more terms to introduce:
- A(n): The accuracy of inference with n tokens
- L(n): The latency of inference with n tokens
- Utility of A: Higher accuracy is better, so for a given n we calculate the ratio of A(n) to the maximum accuracy over all candidate token counts. A larger ratio means we are closer to the maximum, hence higher utility.
- Utility of L: Lower latency is better, so we do the same thing but calculate 1 minus the ratio to invert the result. Lower latency relative to the maximum thus means higher utility.
With this arranged, we can formalize our target number of tokens to prune with a simple multi-objective optimization. A more complex optimization is possible, but we prefer to keep things simple when simple is good enough!
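As a concrete sketch of that optimization: the utility definitions below follow the description above, but the weighted-sum combination via a hyperparameter α and the toy numbers are illustrative assumptions, not the paper's exact formulation.

```python
def choose_num_tokens(A: dict, L: dict, alpha: float = 0.5) -> int:
    """Pick the token count n that maximizes a weighted sum of utilities.

    A[n]: accuracy measured with n tokens; L[n]: latency measured with n tokens.
    alpha trades off accuracy utility against latency utility.
    """
    a_max = max(A.values())
    l_max = max(L.values())

    def utility(n: int) -> float:
        u_acc = A[n] / a_max           # closer to the best accuracy -> higher utility
        u_lat = 1.0 - L[n] / l_max     # lower latency -> higher utility
        return alpha * u_acc + (1.0 - alpha) * u_lat

    return max(A, key=utility)

# Toy, made-up measurements just to show usage (token count -> accuracy %, latency ms):
A = {197: 79.8, 160: 79.1, 128: 77.5, 96: 74.0}
L = {197: 21.0, 160: 14.5, 128: 14.2, 96: 13.9}
n_star = choose_num_tokens(A, L, alpha=0.5)
print(f"keep {n_star} tokens, i.e. prune R = {197 - n_star}")
```

The number of tokens to prune, R, then falls out as the original token count minus the chosen n.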
Measuring latency and accuracy for the utility functions
There are (at least) two options for measuring latency and accuracy for the utility functions.
- Option 1 (tried and failed): We could attempt to model the latency and accuracy as a function of the number of tokens, n. However, in our experiments, we found that accuracy and latency were hard to predict. This is because of the complexity of the software and hardware stack that underlies any use of an ML framework. The underlying property of interest is called the GPU Tail Effect, which says basically that the current generation of GPUs (which handle tasks on blocks of data in a Single-Instruction/Multiple Data manner) will perform the same amount of work in a cycle whether or not the block of data is full. However, it is hard to predict the relationship between the number of tokens and the fullness of the data blocks during inference, because the ML framework (PyTorch etc.) may change its strategy depending on the number of tokens! Maybe we could learn this function with enough information, but a simpler approach worked too.
- Option 2 (we use this one): Instead of modeling behavior, we can simply measure it. We can perform a grid-search (i.e. some for loops) over the number of tokens n and measure both latency and accuracy degradation. Measuring L(n) is straightforward. Gauging accuracy degradation is a bit more complex because the accuracy change may depend on the layer at which pruning occurs (prior work has shown that the layer matters). A lower bound for A(n) is to randomly remove tokens at the first layer of a ViT model, such that n tokens remain. This is a lower bound because it’s the worst possible scheme: any pruning method should be better than random pruning, and removing tokens at the first layer of a ViT has been shown to reduce accuracy more than later pruning.
We record latency L(n) according to Algorithm 1 below and estimate accuracy degradation A(n) according to Algorithm 2. They aren’t scary, just for loops. Now, you might wonder if obtaining L(n) and A(n) in this manner is cost effective. We think it is. These measurements need to be done only once for a particular ViT model and device — thus we call them “offline” measurements, since they are not done at inference time. If you run an engineering shop that wants to apply this technique, you can spare the ~5 hours it takes to run this measurement, especially since this time is dwarfed by the time needed to train the model in the first place.
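For flavor, here is what the accuracy side of that offline measurement could look like, in the same spirit as the latency sweep earlier: randomly drop tokens right after the patch embedding and measure top-1 accuracy on a validation set. The embed, encode, and classify callables are hypothetical stand-ins for the pieces of a real ViT; the paper's Algorithms 1 and 2 are the authoritative versions.

```python
import torch

@torch.no_grad()
def estimate_accuracy_lower_bound(embed, encode, classify, loader, n_keep, device="cuda"):
    """Estimate A(n): top-1 accuracy when tokens are randomly dropped at the
    first layer so that n_keep patch tokens (plus [CLS]) remain.

    embed(images)    -> (B, N, D) token sequence, [CLS] assumed at index 0
    encode(tokens)   -> (B, M, D) encoded tokens
    classify(tokens) -> (B, num_classes) logits
    """
    correct = total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        tokens = embed(images)                                    # (B, N, D)
        B, N, D = tokens.shape
        # Randomly keep n_keep of the N - 1 patch tokens (worst-case pruning).
        keep = torch.stack([
            torch.randperm(N - 1, device=device)[:n_keep] + 1 for _ in range(B)
        ])                                                        # (B, n_keep)
        keep = torch.cat(
            [torch.zeros(B, 1, dtype=torch.long, device=device), keep], dim=1
        )
        pruned = torch.gather(tokens, 1, keep.unsqueeze(-1).expand(-1, -1, D))
        logits = classify(encode(pruned))
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total += labels.numel()
    return correct / total
```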
Putting it all together
Here’s a pretty figure to show how we integrate these ideas into the resulting system. We divide the approach into two parts: the calculation of the pruning schedule (using the multi-objective optimization we discussed) and the selection of where to prune the tokens within the model. I omit this second part here, but see the paper for details.
Results
Plan of action (“Experimental design”)
OK, so now we have a multi-objective optimization problem and we have defined ways to measure all of the variables of interest. We are ready to see how well it works.
- We want to try a range of hardware and a workload relevant to low-resource environments, and see how well our method stacks up against the state of the art.
- Since our approach and the competitors are all trying to balance accuracy and latency, we specifically want to learn whether, given a latency target, we get better accuracy than the competition; and similarly whether given an accuracy target, we get better latency than the competition.
We also ablate over a couple of design choices, namely: (1) the effect of the hyperparameter α that is used to weight the utility function; (2) the actual cost of the offline measurement of L(n) and A(n) — this is the engineering overhead of our technique; and (3) the DNN layer(s) at which we should actually prune. For brevity, I won’t describe these here.
Hardware
In order to capture a wide breadth of latency-workload behavior, we evaluate on two edge devices (NVIDIA AGX Orin and NVIDIA TX2) and a high-performance/server-class system with NVIDIA A100 GPUs. The next table characterizes these devices.
Workload
We used the ImageNet 1K dataset, applying vision transformers on the task of image classification. We used small batch sizes to imitate workloads within the capacity of IoT/Edge devices (think about tasks like facial recognition in airport traffic control, or traffic cameras detecting license plates for paying tolls — only a few pictures are captured per case, not a whole bunch, and the device may have limited memory to buffer the stream of pictures when under heavy load).
Comparison points
We compare our method to other state-of-the-art token sparsification works: DynamicViT (DyViT), Top-K, and Token Merging (ToMe). We evaluated these techniques on several state-of-the-art vision transformer models, such as DeiT-S, ViT-L, and DINOv2-G.
Evaluation 1: Holding latency constant, do we improve accuracy?
First, we chose pruning parameters such that all these methods achieve latency as close as possible to our method (within ~5% or 7 ms). In this evaluation, we want to know if our scheme improves accuracy (metric: top-1 accuracy loss, lower is better).
The next table shows some of the results of this experiment.
Evaluation 2: How does the Pareto frontier look?
Next we evaluated our method against the others on a fixed task, varying the hyperparameters of each. In this evaluation, each technique is expected to offer the ability to trade off between accuracy and latency as needed for an application context. By plotting performance when the techniques achieve comparable latency or accuracy, we can see how they compare. The next figure shows that our technique does better than the others on this workload, for all of the latency and accuracy targets we measured.
Many more experiments, data, and ablations are in the paper :-).
Conclusion
In this work, we illustrate how token sparsification can be improved by utilizing latency-workload relationships. We showed how to determine a token pruning schedule by leveraging these relationships; in comparison with state-of-the-art methods like ToMe, our method yields equal or greater latency reductions while maintaining higher accuracy. Generally speaking, our work illustrates the importance of carefully considering the latency behavior of DL models, as this behavior varies across devices and matters especially on low-resource devices.
Thanks for reading!
For more details, see the full paper (pre-print on arXiv):
https://arxiv.org/abs/2407.05941.
If you want to examine or build on our work, see our Github:
https://github.com/nickjeliopoulos/PruneOneMoreToken.