Pruning One More Token is Enough

James Davis
10 min read · Nov 8, 2024


This is a brief for the research paper “Pruning One More Token is Enough: Leveraging Latency-Workload Non-Linearities for Vision Transformers on the Edge”, published at WACV 2025. Here is a link to the full paper. This post was written by the first author, Nick Eliopoulos, and lightly edited by me.

Summary

Vision transformers (ViTs) are a common architectural component of deep neural networks (DNNs). Thus, techniques that improve ViT efficiency will benefit many DNN applications. The primary cost of a ViT is processing a stream of input tokens, so one way to improve ViT efficiency is to do less work by removing some tokens. This approach is called token sparsification. Prior works have applied token sparsification on high-resource systems such as GPU servers and found that this technique can reliably improve throughput without significant accuracy degradation. However, when we tried out these methods on low-resource systems such as IoT devices, we sometimes observed the opposite effect: we still saw decreased accuracy, but unfortunately we also saw increased latency. We wondered why, and took a look under the hood with some detailed measurements. In this paper we show that the GPU Tail Effect can explain the phenomenon we observed. To sum up: hardware characteristics and machine learning frameworks both interact with the workload size (number of tokens), making latency a non-linear function of token count.

Background

Making Vision Transformers More Efficient

ViT architectures such as DINOv2 achieve state-of-the-art accuracy on multiple tasks. However, these architectures are generally considered too costly to deploy in low-resource environments such as IoT gadgets and other devices on the “network edge”.

The cost of a ViT can be roughly characterized by the following equation:

ViT inference time = # tokens x Cost to process each token

Unsurprisingly, researchers have thus explored two kinds of optimizations: removing tokens (e.g. token sparsification) or reducing the cost of processing each token (e.g. quantization, or optimized kernels like FlashAttention or Liger). Our work is in the former vein: token sparsification.

Token sparsification involves identifying and removing unnecessary tokens (inputs) from a ViT model during inference time. By unnecessary we don’t mean that they are uninformative (though sometimes this is the case), but that the ViT can perform its work without the information they contain. For a simple example, there might be two adjacent tokens that are correlated, in which case only one of them need be included for training and inference. Or an image like a movie still might have black edges added to achieve a particular aspect ratio, in which case the black edges contain no information.

To pick out unnecessary tokens, sparsification methods commonly employ either heuristics or small neural networks trained for the task. Successful token sparsification approaches can reduce latency while avoiding significant accuracy degradation. However, as we explain shortly, existing token sparsification methods implicitly assume that removing more tokens will always reduce latency, and will do so linearly. Perhaps surprisingly, this is not always correct in a low-resource context.
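To make the “heuristic” flavor concrete, here is a minimal PyTorch sketch of one common heuristic (in the spirit of Top-K style methods, not our method’s code): rank the patch tokens by their attention to the [CLS] token and keep only the most-attended ones. The tensor shapes and the `keep_k` value are illustrative assumptions.

```python
import torch

def topk_prune(tokens: torch.Tensor, cls_attn: torch.Tensor, keep_k: int) -> torch.Tensor:
    """Keep the CLS token plus the keep_k patch tokens that the CLS token attends to most.

    tokens:   (B, N, D)  -- CLS token at index 0, followed by N-1 patch tokens
    cls_attn: (B, N-1)   -- attention weights from CLS to each patch token (heads averaged)
    """
    # Indices of the keep_k most-attended patch tokens (offset by 1 to skip the CLS slot)
    topk = cls_attn.topk(keep_k, dim=-1).indices + 1                # (B, keep_k)
    idx = torch.cat([torch.zeros_like(topk[:, :1]), topk], dim=1)   # prepend CLS index 0
    idx = idx.sort(dim=-1).values                                   # preserve original token order
    # Gather the surviving tokens along the token dimension
    return tokens.gather(1, idx.unsqueeze(-1).expand(-1, -1, tokens.size(-1)))

# Example: batch of 2, 197 tokens (1 CLS + 196 patches), 384-dim embeddings
x = torch.randn(2, 197, 384)
attn = torch.rand(2, 196)
pruned = topk_prune(x, attn, keep_k=128)   # -> shape (2, 129, 384)
```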

Latency vs. Workload Relationships

The premise of token pruning is that reducing the workload (tokens) can decrease latency. In general, in the large — for large input sizes — this is true! However, in the small — for the small input sizes characteristic of IoT/Edge tasks — the relationship becomes non-linear.

The next figure shows the size of the effect. See the caption for details. Consequently, in many cases it is possible to achieve large latency reductions without removing too many tokens — and conversely, you may see a latency degradation if you do not take the hardware performance into account.

Non-linear relationship between token count and ViT inference time. METHOD: The left side shows a measurement on an IoT-grade device, the AGX Orin. The right side shows a measurement on a server-grade device, the NVIDIA A100. ANALYSIS: The top row shows that on small workloads (ImageNet 1K configured with batch size 1 or 2), the effect of reducing tokens is hard to model. Removing a small number of tokens (taking a step leftward on the x-axis) may lead to a large decrease in the inference time (a big step down on the y-axis). And it may actually increase the inference time, too — that’s what the sawtooth pattern indicates.
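If you want to check for this behavior on your own device, a rough sweep like the sketch below is usually enough to expose the sawtooth. This is my illustration, not the paper’s benchmarking harness; the timm model and token counts are arbitrary choices, and it assumes a CUDA-capable GPU.

```python
import torch, timm

# Any ViT whose transformer blocks accept (B, N, D) token tensors will do; vit_small is arbitrary.
model = timm.create_model("vit_small_patch16_224", pretrained=False).cuda().eval()
embed_dim = model.embed_dim

@torch.no_grad()
def time_tokens(n_tokens: int, batch: int = 1, iters: int = 50) -> float:
    """Average latency (ms) of the transformer blocks for a given token count."""
    x = torch.randn(batch, n_tokens, embed_dim, device="cuda")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(10):                  # warm-up
        model.blocks(x)
    start.record()
    for _ in range(iters):
        model.blocks(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

for n in range(197, 96, -4):             # sweep the token count downward
    print(n, round(time_tokens(n), 3), "ms")
```

On server-grade GPUs the printed latencies tend to fall smoothly; on edge devices you will often see the jumps and plateaus discussed above.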

Problem Statement

This relationship can stem from how workloads are dispatched to the GPU by an ML framework, from framework overhead, and from kernel design decisions. Prior work on an earlier DNN architecture, convolutional neural networks (CNNs), has taken advantage of this observation to guide pruning approaches (channel pruning). However, there is little investigation into how to leverage latency-workload relationships for the newer class of DNN architecture, Vision Transformers. Furthermore, some of the existing work that does investigate these relationships lacks first-hand measurements of the GPU metrics that lead to this latency behavior.

We summarize the knowledge gaps as follows:

  1. Many efficiency techniques (sparsification among them) do not consider fine-grained hardware characteristics or behavior.
  2. There is no understanding of how to use latency-workload relationships to guide ViT token sparsification.
  3. There is a general lack of primary measurements regarding GPU behavior in work that investigates latency-workload relationships.

Approach

Framing as an optimization problem

To improve token sparsification by taking latency vs. workload behavior into account, we want to choose the number of tokens to prune, R, based on this relationship. A good value of R should yield a large latency reduction while avoiding significant accuracy degradation. This problem can be easily framed as a multi-objective optimization problem.

Let’s use n to denote the number of tokens we give to the network. Now we have a couple more terms to introduce:

  • A(n): The accuracy of inference with n tokens
  • L(n): The latency of inference with n tokens
  • Utility of A: Higher accuracy is better, so we take the ratio of A(n) to the maximum accuracy observed over all candidate token counts. A ratio closer to 1 means we are closer to the maximum, hence higher utility.
  • Utility of L: Lower latency is better, so we do the same thing but take 1 minus the ratio to invert it. Lower latency relative to the maximum latency thus means higher utility.
Utility functions for Latency and Accuracy.
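In symbols, one way to write these utilities (my rendering of the description above; the exact notation is in the paper) is:

```latex
U_A(n) = \frac{A(n)}{\max_j A(j)}, \qquad
U_L(n) = 1 - \frac{L(n)}{\max_j L(j)}
```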

With this arranged, we can formalize our target number of tokens to prune with a simple multi-objective optimization. A more complex optimization is possible, but we prefer to keep things simple when simple is good enough!

The equation to choose R based on accuracy and latency scores from our grid-search. R is the number of tokens we prune from the total available, N. This formulation takes a linear combination of the utility for accuracy and latency, with a single parameter alpha to weight them. In the full paper we do a short ablation on alpha values.
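Written out (again, my rendering; the paper gives the precise formulation), the rule picks the R that maximizes the weighted sum of the two utilities, where N - R tokens remain after pruning:

```latex
R^{*} = \arg\max_{R \in \{0, \ldots, N\}} \;
        \alpha \, U_A(N - R) + (1 - \alpha) \, U_L(N - R)
```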

Measuring latency and accuracy for the utility functions

There are (at least) two options for measuring latency and accuracy for the utility functions.

  1. Option 1 (tried and failed): We could attempt to model the latency and accuracy as a function of the number of tokens, n. However, in our experiments, we found that accuracy and latency were hard to predict. This is because of the complexity of the software and hardware stack that underlies any use of an ML framework. The underlying property of interest is called the GPU Tail Effect, which says basically that the current generation of GPUs (which handle tasks on blocks of data in a Single-Instruction/Multiple Data manner) will perform the same amount of work in a cycle whether or not the block of data is full. However, it is hard to predict the relationship between the number of tokens and the fullness of the data blocks during inference, because the ML framework (PyTorch etc.) may change its strategy depending on the number of tokens! Maybe we could learn this function with enough information, but a simpler approach worked too.
  2. Option 2 (we use this one): Instead of modeling behavior, we can simply measure it. We can perform a grid-search (i.e. some for loops) over the number of tokens n and measure both latency and accuracy degradation. Measuring L(n) is straightforward. Gauging accuracy degradation is a bit more complex because the accuracy change may depend on the layer at which pruning occurs (prior work has shown that the layer matters). A lower bound for A(n) is to randomly remove tokens at the first layer of a ViT model, such that n tokens remain. This is a lower bound because it is essentially the worst-case scheme: any pruning method should be better than random pruning, and removing tokens at the first layer of a ViT has been shown to reduce accuracy more than pruning at later layers.

We record latency L(n) according to Algorithm 1 below, and estimate accuracy degradation A(n) according to Algorithm 2. They aren’t scary: just for loops. Now, you might wonder if obtaining L(n) and A(n) in this manner is cost effective. We think it is. These measurements need to be made only once for a particular ViT model and device — thus we call them “offline” measurements, since they are not done at inference time. If you run an engineering shop that wants to apply this technique, you can spare the ~5 hours it takes to run this measurement, especially since this time is dwarfed by the time needed to train the model in the first place.
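As a rough illustration of what this offline grid search looks like in practice (a sketch in the spirit of Algorithms 1 and 2, not our actual implementation; the `keep_n_tokens` argument, `model`, and `val_loader` are assumptions):

```python
import torch

@torch.no_grad()
def offline_grid_search(model, val_loader, token_counts, device="cuda", timing_iters=20):
    """Measure latency L(n) and a lower-bound accuracy A(n) for each candidate token count n.

    Assumes the model accepts a keep_n_tokens argument that randomly drops tokens after the
    first layer, leaving n tokens (the worst-case pruning scheme described above).
    """
    model.eval().to(device)
    latency, accuracy = {}, {}
    x0, _ = next(iter(val_loader))
    x0 = x0.to(device)
    for n in token_counts:
        # ---- L(n): average forward-pass time with n tokens kept ----
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        model(x0, keep_n_tokens=n)                       # warm-up
        start.record()
        for _ in range(timing_iters):
            model(x0, keep_n_tokens=n)
        end.record()
        torch.cuda.synchronize()
        latency[n] = start.elapsed_time(end) / timing_iters

        # ---- A(n): top-1 accuracy with n random tokens kept at the first layer (lower bound) ----
        correct = total = 0
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            pred = model(x, keep_n_tokens=n).argmax(dim=-1)
            correct += (pred == y).sum().item()
            total += y.numel()
        accuracy[n] = correct / total
    return latency, accuracy
```

Given the resulting L(n) and A(n) tables, choosing R is just a matter of evaluating the weighted utility for each candidate n and keeping the best one.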

Putting it all together

Here’s a pretty figure to show how we integrate these ideas into the resulting system. We divide the approach into two parts: the calculation of the pruning schedule (using the multi-objective optimization we just discussed) and the selection of where within the model to prune the tokens. I omit the second part here; see the paper for details.

Illustration of our method to decide a pruning schedule (left) and how we prune according to the schedule at inference time (right).

Results

Plan of action (“Experimental design”)

OK, so now we have a multi-objective optimization problem and we have defined ways to measure all of the variables of interest. We are ready to see how well it works.

  • We want to try a range of hardware and a workload relevant to low-resource environments, and see how our method stacks up against the state of the art.
  • Since our approach and the competitors are all trying to balance accuracy and latency, we specifically want to learn whether, given a latency target, we get better accuracy than the competition; and similarly whether given an accuracy target, we get better latency than the competition.

We also ablate over a couple of design choices, namely: (1) the effect of the hyperparameter α that weights the utility function; (2) the actual cost of the offline measurement of L(n) and A(n), which is the engineering overhead of our technique; and (3) the DNN layer(s) at which we should actually prune. For brevity, I won’t describe these here.

Hardware

In order to capture a wide breadth of latency-workload behavior, we evaluate on two edge devices (NVIDIA AGX Orin and NVIDIA TX2) and a high-performance/server-class system with NVIDIA A100 GPUs. The next table characterizes these devices.

Summary of device hardware information. TX2 and Orin are IoT/Edge-grade, while A100 is a common server-grade GPU.

Workload

We used the ImageNet 1K dataset, applying vision transformers on the task of image classification. We used small batch sizes to imitate workloads within the capacity of IoT/Edge devices (think about tasks like facial recognition in airport traffic control, or traffic cameras detecting license plates for paying tolls — only a few pictures are captured per case, not a whole bunch, and the device may have limited memory to buffer the stream of pictures when under heavy load).

Comparison points

We compare our method to other state-of-the-art token sparsification works: DynamicViT (DyViT), Top-K, and Token Merging (ToMe). We evaluated these techniques on several state-of-the-art vision transformer models, such as DeiT-S, ViT-L, and DINOv2-G.

Evaluation 1: Holding latency constant, do we improve accuracy?

First, we chose pruning parameters such that all of these methods achieve latency similar to our method’s (as close as possible, within ~5% or 7 ms). In this evaluation, we want to know whether our scheme improves accuracy (measurement: top-1 loss, lower is better).

The next table shows some of the results of this experiment.

Comparison with existing methods on TX2. METHOD: The task is inference over the ImageNet 1K dataset, with batch size 2. We parameterized the comparison approaches to have similar latency and measured the effect on the top-1 loss. ANALYSIS: At this latency point, our method obtains better accuracy (smaller loss) than the other training-free methods. DyViT does better than ours, but it requires training, which increases the cost of applying it.

Evaluation 2: How does the Pareto frontier look?

Next we evaluated our method against the others on a fixed task, varying the hyperparameters of each. In this evaluation, each technique is expected to offer the ability to trade off accuracy and latency as needed for an application context. Plotting the operating points each technique achieves at comparable latency or accuracy shows how they stack up. The next figure shows that our technique does better than the others on this workload, for all of the latency and accuracy targets we measured.

Illustration of accuracy-latency tradeoffs of the surveyed methods. METHOD: This figure shows results on the AGX Orin, using the DINOv2-G base model with a batch size of 2. ANALYSIS: Our pruning schedule and mechanism generate points that expand the Pareto front. The number of tokens removed at each layer (r) of Top-K and ToMe is evaluated from r = 5 to r = 8 in increments of 1.
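If you want to reproduce this kind of plot, the Pareto front over (latency, accuracy) operating points takes only a few lines to compute; the helper below is a generic sketch, not code from our repository.

```python
def pareto_front(points):
    """Return the (latency_ms, top1_acc) points not dominated by any other point.

    A point dominates another if it has lower-or-equal latency and higher-or-equal accuracy,
    with at least one strict inequality.
    """
    front = []
    for lat, acc in points:
        dominated = any(
            (l2 <= lat and a2 >= acc) and (l2 < lat or a2 > acc)
            for l2, a2 in points
        )
        if not dominated:
            front.append((lat, acc))
    return sorted(front)

# Example with made-up numbers:
print(pareto_front([(30.0, 0.78), (28.5, 0.77), (35.0, 0.79), (28.5, 0.75)]))
# -> [(28.5, 0.77), (30.0, 0.78), (35.0, 0.79)]
```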

Many more experiments, data, and ablations are in the paper :-).

Conclusion

In this work, we illustrate how token sparsification can be improved by utilizing latency-workload relationships. We showed how to determine a token pruning schedule by leveraging these relationships; in comparison with state-of-the-art methods like ToMe, our method yields equal or greater latency reductions while maintaining higher accuracy. Generally speaking, our work illustrates the importance of carefully considering the latency behavior of DL models, as it varies across devices, and especially on low-resource devices.

Thanks for reading!

For more details, see the full paper (pre-print on arXiv):
https://arxiv.org/abs/2407.05941.

If you want to examine or build on our work, see our GitHub:
https://github.com/nickjeliopoulos/PruneOneMoreToken.


Written by James Davis

I am a professor in ECE@Purdue. My research assistants and I blog here about research findings and engineering tips.