Why is ITL's first token so long? #62

Open
sunshenao opened this issue Jan 3, 2025 · 3 comments

@sunshenao

model: Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4
GPU: 8 * H20
prefill (P): 4 * H20
decode (D): 4 * H20
input_len: 1024, output_len: 6

sudo sh disagg_performance_benchmark.sh

This is what I get when qps=10.

[benchmark charts: mean_itl_ms, median_itl_ms]

The mean ITL is much larger for the disaggregated setup than for the non-disaggregated one, but the median ITL for the disaggregated setup is much smaller.
This is what I see in the generated result file:

[screenshot of the generated result file]

The first token of the decode part always takes a long time.
This phenomenon does not occur at lower QPS, e.g., when qps = 2:

[screenshot of results at qps = 2]

But when the QPS is small, the ITL increase is not very obvious, while the TTFT increases a lot:

[screenshot of TTFT results]

May I ask why this is so? Is there any way to reduce the time of the first decode token?
I look forward to your answer. Thank you.

@ShangmingCai
Collaborator

In disaggregated prefilling scenarios, the first token of the decode stage (i.e., TTFT) consists of the prefill stage overhead, the KVCache transfer cost, and the first-run overhead of the decode stage. Since layer-wise KVCache transfer is not ready yet, the unsatisfactory TTFT performance is expected for now.
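As a rough illustration of that breakdown (a sketch with made-up timings, not measured values or Mooncake's actual code):

```python
# Back-of-envelope decomposition of TTFT under disaggregated prefill.
# All numbers below are hypothetical placeholders, not measurements.
prefill_ms = 120.0           # prefill forward pass on the P node
kv_transfer_ms = 45.0        # blocking KVCache transfer from P to D
decode_first_step_ms = 30.0  # first decode forward pass on the D node

ttft_ms = prefill_ms + kv_transfer_ms + decode_first_step_ms
print(f"TTFT ~= {ttft_ms} ms")  # the benchmark attributes this whole gap to the first decode token
```

With layer-wise transfer, most of the KVCache transfer cost could overlap with the prefill forward pass instead of adding to TTFT.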

Also, the current implementation of KVCache transfer is not zero-copy. And there is a buffer_lock in the implementation of simple_buffer.py, which might cause trouble when QPS is large.
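A minimal sketch of why a single shared lock hurts at high QPS (illustration only, assuming a simplified buffer; this is not the actual simple_buffer.py code):

```python
import threading
import time

# Illustration only: one lock guarding the shared transfer buffer means
# concurrent requests serialize their KVCache copies instead of overlapping.
buffer_lock = threading.Lock()

def send_kvcache(copy_seconds: float) -> None:
    with buffer_lock:
        time.sleep(copy_seconds)  # stands in for the non-zero-copy buffer copy

threads = [threading.Thread(target=send_kvcache, args=(0.05,)) for _ in range(10)]
start = time.time()
for t in threads:
    t.start()
for t in threads:
    t.join()
# With the lock held across the whole copy, ten "concurrent" 50 ms transfers
# take roughly 0.5 s end to end rather than ~0.05 s.
print(f"elapsed: {time.time() - start:.2f}s")
```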

Please refer to the roadmaps of the disaggregated prefilling feature in vLLM (vllm-project/vllm#10818) and Mooncake (#44); there remains much work to do before this feature is production-ready.

BTW, how many GPUs do you use for the set of chunked prefill experiments?

@sunshenao
Author

Thanks for your answer. I started 2 chunked prefill instances with 4 * H20 each, just like the disagg_performance_benchmark.sh configuration.
My other question is: so far, my test results do not show any performance increase compared to the chunked prefill instances. What is the reason for this?

For example, here are my results running on A30:

gpu: A30
prefill: 1 * A30
decode: 1 * A30
model: Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4
mooncake.json:
{
    "prefill_url": "localhost:13003",
    "decode_url": "localhost:13103",
    "metadata_server": "localhost:2379",
    "metadata_backend": "etcd",
    "protocol": "tcp",
    "device_name": ""
}
input_len: 256, output_len: 6
sudo sh disagg_performance_benchmark.sh

[benchmark result screenshots for the A30 setup]

@ShangmingCai
Collaborator

Thanks for your answer. I started 2 chunked prefill instances with 4 * H20 each, just like the disagg_performance_benchmark.sh configuration. My other question is: so far, my test results do not show any performance increase compared to the chunked prefill instances. What is the reason for this?

Yes. Since PD separation is similar to pipeline processing at the current stage, even with high QPS there exist computing bubbles on both nodes. For the two chunked prefill instances, however, the resources of each GPU can be fully utilized, and because no KVCache is transmitted between nodes, performance is not limited by the network.
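A rough back-of-envelope model of that effect (all per-request times below are hypothetical, chosen only to illustrate the shape of the comparison):

```python
# Hypothetical per-request stage times on a single GPU (not measurements).
prefill_s = 0.10   # prefill time per request
decode_s = 0.06    # decode time per request
transfer_s = 0.03  # KVCache transfer between the P and D nodes

# 1P1D pipeline: steady-state throughput is bounded by the slower stage,
# while the faster stage idles (the "computing bubble").
pd_throughput = 1.0 / max(prefill_s, decode_s + transfer_s)

# Two chunked prefill instances: each GPU serves whole requests with no
# inter-node transfer, so both GPUs stay fully utilized.
chunked_throughput = 2.0 / (prefill_s + decode_s)

print(f"1P1D:       {pd_throughput:.1f} req/s")    # ~10 req/s with these numbers
print(f"2x chunked: {chunked_throughput:.1f} req/s")  # ~12.5 req/s
```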

Therefore, 1P1D tests will not show any increase in performance compared to 2 chunked prefill instances. To better evaluate the practicality of PD separation, we need the XpYd implementation, as well as a heterogeneous GPU environment and a high-speed network.
