Rethinking LLM inference: Why developer AI needs a different approach
TL;DR: We believe that full codebase context is critical for developer AI. But processing all this context usually comes at the cost of latency. At Augment, we’re tackling this challenge head-on, pushing the boundaries of what’s possible for LLM inference. This post breaks down the challenges of inference for coding, explains Augment’s approach to optimizing LLM inference, and shows how building our own inference stack delivers superior quality and speed to our customers.
Why context matters
For coding, context is everything. Changes in your code base depend not only on the current file, but also on dependencies and call-sites, READMEs, build files, third-party libraries, and so much more. Large language models have shown incredible flexibility in using additional context. At Augment, we have learned time and time again that providing more relevant context improves the quality of our products.
This example demonstrates how important small hints in the context are for coding. Augment has mastered the art of retrieving the most relevant pieces of information for each completion and each chat request in large code bases.
Optimize for decoding or context processing?
Our internal research has shown that code predictions continue to improve in quality when context length increases well beyond 10,000 tokens. However, adding so much context into the prompt means that we need to process a lot of tokens, posing a challenge for low latency responses.
Note that the balance of context tokens and output tokens in AI for code is very different from standard AI chat applications: A typical chat question has on the order of 100 input tokens and 100s of output tokens. But we usually have many thousands of context tokens and often only 10s of output tokens.
The open source community offers capable inference solutions today, such as vLLM and TensorRT-LLM. But existing solutions and benchmarks target the chat use case with short prompts (little context) and long answers. The same appears to be the case for proprietary services: An excellent blog post by friendli.ai suggests that the time to first token (TTFT) offered by these services dramatically increases for inputs beyond 1k tokens. They report the best TTFT from Llama3 70B for a 10k token prompt is around 1000ms.
While developing our inference stack, we faced multiple decision points that required balancing optimizations between context processing and decoding speed. Because the user experience of our products primarily depends on context, we obsessively prioritized context processing speed. The result is an inference stack able to serve requests with 10k input tokens to Llama3 70B with a TTFT of less than 300ms (3x faster), and that’s not even counting our unique approach to caching.
Another point of reference is a blog post by fireworks.ai, which compares their service to vLLM and TensorRT-LLM for the Llama 8B model with 30k context tokens and 256 output tokens using 8 GPUs (AMD MI300 GPUs for their service and NVIDIA H100 for TensorRT-LLM and vLLM). They show that these options stay below 5 requests/s. In contrast, our inference system can serve almost 10 requests/s on 8 H100s. All numbers are taken without speculative decoding and caching.
On FLOPS and memory bandwidth
High-powered GPUs such as NVIDIA’s H100 series are expensive, but they can do an incredible number of math operations per second, as measured in FLOPS (floating-point operations per second). However, if you compare the compute needed to process a single request with the total FLOPS of the device and the time it takes to answer the request, you end up at a FLOPS utilization below 1%. Most of that compute capability is wasted.
The main reason is that decoding is inefficient in terms of FLOPS utilization. Consider the runtime of a forward pass as a function of batch size.
As a first approximation, the achievable runtime is the larger of two terms: the time it takes the GPU to do all the calculations (FLOPS limit) and the time it takes the GPU to load the model weights from GPU memory into the cache (memory bandwidth limit). The compute requirements grow linearly with the batch size, but the model weights need to be loaded only once per batch, so the bandwidth limit is not significantly affected. The crossover point for these two curves is where we transition from the memory-bandwidth-limited regime into the FLOPS-limited regime. For high-powered GPUs, this crossover point is very high, and we need several hundred or even several thousand tokens in the batch to be close to the FLOPS limit.
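The first-order model above can be written down in a few lines. This is a minimal sketch, not a description of our stack: the function names and all parameters are hypothetical, and the numbers in the usage example are round placeholders rather than the specifications of any real GPU or model.

```python
def batch_runtime_s(batch_tokens, flops_per_token, weight_bytes,
                    peak_flops, mem_bandwidth):
    """First-order runtime of one forward pass: the slower of the two limits."""
    compute_s = batch_tokens * flops_per_token / peak_flops  # grows linearly with batch size
    memory_s = weight_bytes / mem_bandwidth                  # weights are loaded once per batch
    return max(compute_s, memory_s)


def crossover_tokens(flops_per_token, weight_bytes, peak_flops, mem_bandwidth):
    """Batch size at which compute time equals weight-loading time."""
    return (weight_bytes / mem_bandwidth) * peak_flops / flops_per_token


# Placeholder device and model (assumptions, not measurements):
# 1e15 FLOPS peak, 2e12 bytes/s bandwidth, 1e10 one-byte weights,
# and roughly 2 FLOPs per weight per token.
args = dict(flops_per_token=2e10, weight_bytes=1e10,
            peak_flops=1e15, mem_bandwidth=2e12)

print(crossover_tokens(**args))        # roughly 250 tokens
print(batch_runtime_s(1, **args))      # memory-bandwidth bound
print(batch_runtime_s(100, **args))    # still memory-bandwidth bound, same runtime
print(batch_runtime_s(1000, **args))   # FLOPS bound, runtime now grows with batch size
```

Under this model, every batch smaller than the crossover point costs the same wall-clock time, which is why small batches waste compute.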
Context processing is naturally FLOPS bound, but decoding can have an unintuitively large impact on our FLOPS utilization: Say the crossover point is at 512 tokens, and we do a decoding step for a single token. The time this step takes is roughly the same as it would take us to process 512 tokens. So we used only about 1/512 ≈ 0.2% of the FLOPS the GPU could have provided. In a naive implementation, even a small number of decoding tokens per request can severely reduce average utilization.
Token-level batching
Traditional batching strategies group decoding steps of multiple requests into the same batch. Note that this does not fundamentally change the low utilization issue: Consider batching together 10 requests that are decoding. If our crossover point is 512 tokens to go from memory bandwidth limited to FLOPS limited, then we are still only utilizing 10/512 or about 2% of the FLOPS available.
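The utilization arithmetic above is easy to reproduce. The sketch below uses the same illustrative crossover point of 512 tokens. The helper names are hypothetical, and the 10,000-token context with 50 output tokens is an assumed request shape chosen only to illustrate the effect on a whole request.

```python
CROSSOVER = 512  # illustrative crossover point used in the text


def batch_utilization(tokens_in_batch, crossover=CROSSOVER):
    """Fraction of peak FLOPS used by one batch under the first-order model,
    where any batch at or below the crossover takes the same time."""
    return min(tokens_in_batch / crossover, 1.0)


def naive_request_utilization(context_tokens, output_tokens, crossover=CROSSOVER):
    """Average utilization when a request is served alone: context is processed
    in crossover-sized batches, then each output token costs one full batch time."""
    prefill_steps = context_tokens / crossover
    decode_steps = output_tokens
    useful_steps = (context_tokens + output_tokens) / crossover
    return useful_steps / (prefill_steps + decode_steps)


print(f"{batch_utilization(1):.2%}")    # 0.20%
print(f"{batch_utilization(10):.2%}")   # 1.95%

# Assumed request shape: 10,000 context tokens, 50 output tokens.
print(f"{naive_request_utilization(10_000, 50):.0%}")
```

Even though decoding accounts for well under 1% of the tokens in this assumed request, it takes most of the batch steps, which pulls the average utilization far below the FLOPS limit.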
Our batching strategy addresses this problem head-on. To avoid low FLOPS utilization in decoding steps and remain in the FLOPS-bound regime whenever possible, we allow decoding steps to “piggyback” on the context processing of other requests.
We construct batches that mix and match tokens from multiple requests. To construct the next batch, we ask all the requests in the inference queue which tokens they want to process. If a request is in the decoding phase, it will contribute a single token, and if a request is in the context processing phase, it will contribute possibly a lot of tokens.
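A minimal sketch of this batch construction is shown below. It is an illustration under stated assumptions, not our production scheduler: `Request`, `build_batch`, and `batch_budget` are hypothetical names, and serving decoding requests before context chunks is one possible ordering policy that we assume here for simplicity.

```python
from dataclasses import dataclass


@dataclass
class Request:
    request_id: str
    context_remaining: int  # context tokens not yet processed
    output_remaining: int   # output tokens still to decode

    def wanted_tokens(self):
        """Tokens this request would like to process in the next batch."""
        if self.context_remaining > 0:
            return self.context_remaining  # context phase: possibly many tokens
        return 1 if self.output_remaining > 0 else 0  # decoding phase: one token


def build_batch(queue, batch_budget):
    """Fill one batch with up to `batch_budget` tokens.

    Returns a dict mapping request_id to the number of tokens contributed.
    Decoding requests go first (assumed policy), so their single tokens
    piggyback on the context chunks that fill the rest of the batch.
    """
    batch = {}
    free = batch_budget
    for req in sorted(queue, key=lambda r: r.context_remaining > 0):
        if free == 0:
            break
        take = min(req.wanted_tokens(), free)
        if take > 0:
            batch[req.request_id] = take
            free -= take
    return batch


queue = [
    Request("a", context_remaining=0, output_remaining=5),      # decoding
    Request("b", context_remaining=1500, output_remaining=10),  # context processing
    Request("c", context_remaining=0, output_remaining=3),      # decoding
]
batch = build_batch(queue, batch_budget=512)
print(batch)  # {'a': 1, 'c': 1, 'b': 510}
```

Request "b" cannot fit its full context into one batch, so it contributes a 510-token chunk now and the rest in later batches, while "a" and "c" each advance by one output token in the same forward pass.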
Batch sizes far beyond the crossover point yield only marginally higher FLOPS utilization, and the overall request latency also depends on decoding speed. The core issue is that a Transformer can generate only one output token per request per batch (ignoring speculative decoding), so the overall request latency grows linearly with the runtime of a single batch. To maximize throughput and minimize latency, we thus pick a batch size near the crossover point, where we are close to the minimal runtime and close to the optimal FLOPS utilization.
The academic literature has recently caught on to this approach as well, and we want to highlight the excellent papers on Sarathi and DeepSpeed-FastGen. The academic literature calls the technique “chunked prefill”.
Requirements of a production inference system
Cancellations
Besides pure speed and predictability, an inference system for a production environment comes with many non-obvious requirements. One of them is request cancellation: We fire an inference request on each and every keystroke. This guarantees the best experience, but it can quickly overwhelm the backend, as successive keystrokes can arrive as close as 50ms apart and outpace the speed of an inference server. In a naive implementation, this means that a user might have multiple requests in flight, while we know that only the most recent request is actually needed.
To avoid this problem, it is important to reliably cancel work that’s identified as unnecessary. Our batching strategy enables this as each batch is relatively small and we can stop the computation of a request after any batch. It also plays well with caching, as we can reuse the partially completed context processing of the previous request if a prefix of the next request matches the last one.
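The sketch below illustrates both ideas: checking for cancellation between batches, and reusing already-processed context when the next request shares a prefix. It is a simplified model under stated assumptions. `ContextJob`, `common_prefix_len`, and `process_chunk` are hypothetical names, tokens are plain integers, and the "cache" is just the list of tokens that were already processed.

```python
import threading


def common_prefix_len(a, b):
    """Number of leading tokens that two token sequences share."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


class ContextJob:
    def __init__(self, tokens, cached_tokens=()):
        self.tokens = list(tokens)
        # Reuse only the previously processed work that matches our prefix.
        self.processed = common_prefix_len(self.tokens, cached_tokens)
        self.cancelled = threading.Event()

    def run(self, chunk_size, process_chunk):
        """Process context chunk by chunk. Returns True if finished,
        False if the job was cancelled between two chunks."""
        while self.processed < len(self.tokens):
            if self.cancelled.is_set():  # checked after every batch
                return False
            chunk = self.tokens[self.processed:self.processed + chunk_size]
            process_chunk(chunk)
            self.processed += len(chunk)
        return True


first = ContextJob(range(1000))
chunks_seen = []


def fake_forward(chunk):
    chunks_seen.append(len(chunk))
    if len(chunks_seen) == 2:
        first.cancelled.set()  # a newer keystroke arrives mid-request


finished = first.run(chunk_size=256, process_chunk=fake_forward)

# The follow-up request extends the same context by one token and starts
# from the 512 tokens the cancelled request had already processed.
second = ContextJob(list(range(1000)) + [7],
                    cached_tokens=first.tokens[:first.processed])
print(finished, first.processed, second.processed)  # False 512 512
```

Because each chunk is small, the wasted work after a cancellation is bounded by one batch, and whatever was completed before the cancellation remains usable.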
Deployment sizes
There are alternative batching strategies that split context processing and decoding into separate groups of GPUs and send the prefilled KV caches between them. But this increases the size of the smallest unit that can be deployed independently, which adds painful complexity when scaling to multiple data centers and varying workloads. In contrast, our batching strategy allows us to colocate the processing of context and output tokens and thereby keeps deployments compact and uniform.
Our optimization process
The most important step in any optimization work comes at the beginning: choosing what to optimize. We focused on a single hardware platform (NVIDIA H100) and a single model architecture. Our batching strategy has the additional advantage that it allows us to narrow down the number of batch sizes to a handful of cases. This means the workload is almost static: one batch of tokens is the same as any other batch.
Once you have that focus, optimizing LLMs on GPUs has the same shape as every software optimization project: break down the workload into its component steps; measure the time of each component; compare those times to the theoretical ideal; sort by how much end-to-end speedup you can achieve per-component; work on the first one; then go back to the beginning. They may be old-fashioned, but spreadsheets are your friend!
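The "spreadsheet" step of this loop can be sketched in a few lines. The component names and timings below are made-up placeholders, not our measurements, and `rank_by_end_to_end_speedup` is a hypothetical helper. It ranks each component by the end-to-end speedup you would get if that component alone reached its theoretical ideal.

```python
def rank_by_end_to_end_speedup(measured_ms, ideal_ms):
    """Return rows (name, measured, ideal, speedup), best opportunity first.

    `speedup` is the end-to-end speedup if only this component were
    brought down to its ideal time and everything else stayed the same.
    """
    total = sum(measured_ms.values())
    rows = []
    for name, measured in measured_ms.items():
        saving = max(measured - ideal_ms[name], 0.0)
        rows.append((name, measured, ideal_ms[name], total / (total - saving)))
    return sorted(rows, key=lambda row: row[3], reverse=True)


# Placeholder numbers for illustration only.
measured_ms = {"matmul": 6.0, "attention": 2.0, "communication": 1.5, "other": 0.5}
ideal_ms = {"matmul": 5.5, "attention": 1.5, "communication": 0.5, "other": 0.1}

ranked = rank_by_end_to_end_speedup(measured_ms, ideal_ms)
for name, measured, ideal, speedup in ranked:
    print(f"{name:14s} {measured:5.1f} ms -> {ideal:5.1f} ms  ({speedup:.3f}x end-to-end)")
```

With these placeholder numbers, the smallest of the three named components is the best target, because it is furthest from its ideal. That is the kind of conclusion the measure-and-sort loop is meant to surface.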
Consider our breakdown of the batch time of the Llama 3.1 8B model. The numbers are for 4 NVIDIA H100 GPUs, the model is in FP8 precision, the batch size is 2048, and we have 8192 tokens in the KV cache.
What’s remarkable is that almost three quarters of the total time is spent on matrix multiplies and self-attention, which are limited by the math throughput (FLOPS) of the GPU. And yet, there is still room to improve! At least 20% of the time remains in work that can be optimized away.
The highlights of our optimization journey include:
- CUDA Graphs: modern GPUs are so fast for inference workloads that they can finish the work faster than the host CPU can submit it. As a result, you see idle gaps in execution where the GPU is waiting to receive its next unit of work. The solution is to pre-define all of the work and submit it in one go. On NVIDIA GPUs, the technology to do so is called CUDA Graphs. The focus on fixed-size batches makes this much easier, since the workload never changes even as request sizes do.
- FP8: the latest GPUs support 2x faster math throughput with 8-bit floating point values (compared to 16-bit floating point). This is usually called “weights and activations” quantization, since you are converting both the model weights and the runtime values to 8-bit floating point.
- FlashAttention-3: the FlashAttention project has galvanized the open source community to refine hardware-optimized implementations of self-attention that are much faster than prior alternatives. We are fortunate to collaborate with experts at Colfax International to refine the latest FlashAttention-3 release for the inference case. The result is the fastest implementation of self-attention that we are aware of.
- Efficient communication: in multi-GPU execution, there is frequent collective communication among the GPUs to share the split-up activations they all need. This time is pure overhead, so we moved to custom implementations of all_reduce and all_gather (based on those from TRT-LLM) that have much lower latency for inference than the out-of-the-box NCCL implementation.
- Custom CUDA kernels: after optimizing the big things like matrix multiplies and self-attention, you are left with overhead time spent doing miscellaneous…stuff. One of the best fixes for this overhead is kernel fusion: combine multiple operations into a single function so you don’t have to load activations over-and-over. We wrote our own CUDA kernels for various fusions that were too complex to express with native PyTorch functionality.
There is always more to optimize
For code completions and chat there is a latency below which users experience the product as “snappy”. Can we stop optimizing once we have reached this magic threshold?
Unfortunately, there are multiple sources of additional latency that contribute to the user experience: Network, request queueing, and retrieval from a million embeddings add up to a significant amount of time, especially in the higher percentiles of request latencies. To give our users the best experience, we had to learn how to serve context-rich code completions with a time to first token below 120ms, while keeping the FLOPS utilization of our GPUs over 25%. We believe that this is an outstanding number, considering that even highly optimized training jobs achieve only 38% to 43% FLOPS utilization (Llama 3.1 paper, page 10).
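For readers who want to estimate FLOPS utilization for their own deployment, here is a rough sketch. It relies on the common approximation that a forward pass costs about 2 FLOPs per model parameter per token, which ignores attention FLOPs and therefore slightly underestimates the work done. The function name and all input values are hypothetical placeholders, not our measurements or the specifications of any real GPU.

```python
def flops_utilization(n_params, tokens_processed, elapsed_s,
                      n_gpus, peak_flops_per_gpu):
    """Approximate fraction of peak FLOPS spent on useful model math."""
    useful_flops = 2.0 * n_params * tokens_processed  # ~2 FLOPs per parameter per token
    available_flops = elapsed_s * n_gpus * peak_flops_per_gpu
    return useful_flops / available_flops


# Placeholder inputs: a 1e10-parameter model processing 1,000 tokens
# in 0.1 s on a single device with a peak of 1e15 FLOPS.
utilization = flops_utilization(
    n_params=1e10,
    tokens_processed=1_000,
    elapsed_s=0.1,
    n_gpus=1,
    peak_flops_per_gpu=1e15,
)
print(f"{utilization:.0%}")  # 20%
```

To measure a serving system rather than a single batch, count all tokens processed over a window of wall-clock time, including the time the GPUs sit idle.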
Today, our total completion latency stays below 220ms, setting a new bar for snappy code completions, while offering full codebase awareness on every keystroke.
Yet we believe that there is always more to optimize and that it is best to think about inference optimizations in terms of providing options: Any inference speedup can be seen as a budget that can be reinvested to drive value for our users. For example, we can spend this budget on increasing the model size or adding even more context to further improve the understanding of large code bases.
There is a lot more to talk about, and we are planning to discuss additional challenges in future blog posts. For example, our strategy to develop an inference stack in-house enables us to optimize models for cache reuse, saving large fractions of work.
If LLM inference and building the most context-aware AI for code excites you, let’s talk! We’re looking for engineers who thrive on optimizing the future of AI-driven development. To experience the benefits of deep context, low latency AI for coding, sign up for Augment.