Files
pipecat/tests/vllm-inference-test.py
2024-06-20 19:33:33 -04:00

87 lines
2.8 KiB
Python

import asyncio
import time
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.utils import random_uuid
# Decoding settings reused for every generate() call in this script:
# temperature / nucleus sampling, capped at 4096 new tokens per request.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=4096)
# Pre-rendered chat prompt, written out with the model's special header and
# end-of-turn tokens.  Adjacent literals are concatenated at compile time, so
# the value is a single string; each literal is one segment of the template.
# Both instruction turns are tagged `system`, and the prompt ends with an open
# `assistant` header so the model continues from there.
prompt = (
    "<|begin_of_text|>"
    "<|start_header_id|>system<|end_header_id|>\n\n"
    "You are a helpful LLM in a WebRTC call. "
    "Your goal is to demonstrate your capabilities in a succinct way. "
    "Your output will be converted to audio so don't include special characters in your answers. "
    "Respond to what the user said in a creative and helpful way."
    "<|eot_id|>"
    "<|start_header_id|>system<|end_header_id|>\n\n"
    "Please introduce yourself to the user."
    "<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
async def main():
    """Start an async vLLM engine, stream one completion, and print throughput.

    Steps:
      1. Build an ``AsyncLLMEngine`` and print how long construction took.
      2. Submit the module-level ``prompt`` with ``sampling_params`` under a
         fresh request id.
      3. Print each newly generated slice of text as it arrives.
      4. Print the token count, elapsed time and tokens/second.

    Returns:
        None.
    """
    print("🥶 cold starting inference")
    boot_began_ns = time.monotonic_ns()
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(
            model="meta-llama/Meta-Llama-3-8B-Instruct",
            enable_prefix_caching=True,
            gpu_memory_utilization=0.90,
            enforce_eager=False,  # False means slower starts but faster inference
            disable_log_stats=True,  # disable logging so we can stream tokens
            disable_log_requests=True,
        )
    )
    duration_s = (time.monotonic_ns() - boot_began_ns) / 1e9
    print(f"🏎️ engine started in {duration_s:.0f}s")

    request_id = random_uuid()
    stream = engine.generate(prompt, sampling_params, request_id)

    # `chars_emitted` is how much of the cumulative text has been printed so
    # far; `num_tokens` tracks the token count of the latest printed snapshot.
    chars_emitted = 0
    num_tokens = 0
    generation_began_ns = time.monotonic_ns()
    async for snapshot in stream:
        completion = snapshot.outputs[0]
        cumulative_text = completion.text
        # Skip snapshots whose text currently ends in U+FFFD (the Unicode
        # replacement character); the pending text is printed with a later
        # snapshot.  Presumably this marks a partially decoded character.
        if cumulative_text.endswith("\ufffd"):
            continue
        text_delta = cumulative_text[chars_emitted:]
        chars_emitted = len(cumulative_text)
        num_tokens = len(completion.token_ids)
        print(text_delta)

    duration_s = (time.monotonic_ns() - generation_began_ns) / 1e9
    print(
        f"\n\tGenerated {num_tokens} tokens in {duration_s:.1f}s,"
        f" throughput = {num_tokens / duration_s:.0f} tokens/second.\n"
    )
async def xmain():
    """Generate from the shared prompt twice with the ``LLM`` class and print results.

    Alternative to ``main()`` that goes through ``LLM.generate`` instead of
    the async engine.  The module-level ``prompt`` is submitted two times in
    a row with the module-level ``sampling_params``; for every returned
    output, the prompt it carries and its first completion's text are printed.

    Note: although declared ``async``, this coroutine contains no ``await``;
    the ``llm.generate`` calls run inline.

    Returns:
        None.
    """
    llm = LLM(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        enable_prefix_caching=True,
    )
    # Two identical passes over the same prompt (presumably so the second one
    # can benefit from prefix caching -- confirm intent).
    for _ in range(2):
        outputs = llm.generate(prompt, sampling_params)
        for output in outputs:
            # Deliberately NOT named `prompt`: assigning to that name anywhere
            # in this function would make it function-local and the
            # `llm.generate(prompt, ...)` call above would raise
            # UnboundLocalError instead of reading the module-level prompt.
            echoed_prompt = output.prompt
            generated_text = output.outputs[0].text
            print(f"Prompt: {echoed_prompt!r}, Generated text: {generated_text!r}")
# Script entry point: when executed directly, drive main() to completion on a
# fresh asyncio event loop.
if __name__ == "__main__":
    entry_coroutine = main()
    asyncio.run(entry_coroutine)