41 changes: 39 additions & 2 deletions README.md
@@ -8,9 +8,46 @@ pip install flash-attn # This may take up to 10 mins.
pip install -e .
```

## Run
## Test simple server

```bash
ray start --head
python server.py [--tensor-parallel-size <N>]
python simple_server.py
```

The detailed arguments for `simple_server.py` can be listed with:
```bash
python simple_server.py --help
```
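
For example, a multi-GPU run might look like the following sketch (`--tensor-parallel-size` is assumed here; the exact flags, including model selection, are listed by `--help`):
```bash
# Run the simple server with tensor parallelism across N GPUs (assumed flag).
python simple_server.py --tensor-parallel-size <N>
```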

## FastAPI server

Install the following additional dependencies:
```bash
pip install fastapi uvicorn
```

To start the server:
```bash
ray start --head
python -m cacheflow.http_frontend.fastapi_frontend
```
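
The frontend binds to `localhost:10002` by default; `--host` and `--port` override this, and the usual server arguments are accepted as well. A minimal sketch:
```bash
# Expose the API on all interfaces, keeping the default port.
python -m cacheflow.http_frontend.fastapi_frontend --host 0.0.0.0 --port 10002
```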

To test the server:
```bash
python -m cacheflow.http_frontend.test_cli_client
```
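
The server streams its output as null-delimited JSON chunks, so the endpoint can also be exercised directly, e.g. with `curl` (a minimal sketch; the payload mirrors `test_cli_client.py`):
```bash
# -N disables buffering so the streamed chunks are printed as they arrive.
curl -N http://localhost:10002/generate \
    -H "Content-Type: application/json" \
    -d '{"prompt": "Ion Stoica is a", "n": 4, "use_beam_search": true, "temperature": 0.0}'
```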

## Gradio web server

Install the following additional dependencies:
```bash
pip install gradio
```

Start the server:
```bash
python -m cacheflow.http_frontend.fastapi_frontend
# In another terminal
python -m cacheflow.http_frontend.gradio_webserver
```
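
The web server listens on `localhost:10003` by default and forwards prompts to the FastAPI frontend at `http://localhost:10002/generate`; both can be changed (a sketch showing the default values):
```bash
python -m cacheflow.http_frontend.gradio_webserver \
    --host localhost --port 10003 \
    --model-url http://localhost:10002/generate
```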
152 changes: 152 additions & 0 deletions cacheflow/http_frontend/fastapi_frontend.py
@@ -0,0 +1,152 @@
import argparse
import asyncio
import time
from typing import List, Dict
import json

import ray
from transformers import AutoTokenizer
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import uvicorn

from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.master.server import (Server, add_server_arguments,
initialize_ray_cluster)
from cacheflow.worker.controller import DeviceID
from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory

app = FastAPI()

class FastAPIFrontend:
def __init__(
self,
model: str,
model_path: str,
pipeline_parallel_size: int,
tensor_parallel_size: int,
block_size: int,
dtype: str,
seed: int,
swap_space: int,
max_batch_size: int,
num_nodes: int,
num_devices_per_node: int,
distributed_init_method: str,
all_stage_devices: List[List[DeviceID]],
):
self.block_size = block_size

self.tokenizer = AutoTokenizer.from_pretrained(model)
self.seq_group_counter = Counter()
self.seq_counter = Counter()
remote_server_class = ray.remote(num_cpus=0)(Server)
self.server = remote_server_class.remote(
model=model,
model_path=model_path,
pipeline_parallel_size=pipeline_parallel_size,
tensor_parallel_size=tensor_parallel_size,
block_size=block_size,
dtype=dtype,
seed=seed,
swap_space=swap_space,
max_batch_size=max_batch_size,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
)

self.running_seq_groups: Dict[int, SequenceGroup] = {}
self.sequence_group_events: Dict[int, asyncio.Event] = {}
self.is_server_running = False

async def server_step(self):
self.is_server_running = True
updated_seq_groups = await self.server.step.remote()
self.is_server_running = False
for seq_group in updated_seq_groups:
group_id = seq_group.group_id
self.running_seq_groups[group_id] = seq_group
self.sequence_group_events[group_id].set()

async def generate(self, request_dict: Dict):
prompt = request_dict["prompt"]
sampling_params = SamplingParams.from_dict(request_dict)
sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
token_ids = self.tokenizer.encode(prompt)
seqs: List[Sequence] = []
for _ in range(sampling_params.n):
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, token_ids, block_size=self.block_size)
seqs.append(seq)

group_id = next(self.seq_group_counter)
seq_group = SequenceGroup(group_id, seqs)
group_event = asyncio.Event()
self.sequence_group_events[group_id] = group_event
await self.server.add_sequence_groups.remote([(seq_group, sampling_params)])
while True:
if not self.is_server_running:
await self.server_step()
# Wait for new output. Add a 1s timeout to prevent deadlock.
await asyncio.wait_for(group_event.wait(), timeout=1)
group_event.clear()
seq_group = self.running_seq_groups[group_id]
all_outputs = []
for seq in seq_group.seqs:
token_ids = seq.get_token_ids()
output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
all_outputs.append(output)
ret = {
"text": all_outputs,
"error": 0,
}
yield (json.dumps(ret) + "\0").encode("utf-8")
if seq_group.is_finished():
break


@app.post("/generate")
async def generate_stream(request: Request):
request_dict = await request.json()
return StreamingResponse(frontend.generate(request_dict))


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=10002)
parser = add_server_arguments(parser)
args = parser.parse_args()

# TODO(zhuohan): Support pipeline parallelism.
assert args.pipeline_parallel_size == 1, (
'Pipeline parallelism is not supported yet.')

(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_ray_cluster(
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))

frontend = FastAPIFrontend(
model=args.model,
model_path=args.model_path,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
max_batch_size=args.max_batch_size,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
)

uvicorn.run(app, host=args.host, port=args.port, log_level="info")
43 changes: 43 additions & 0 deletions cacheflow/http_frontend/gradio_webserver.py
@@ -0,0 +1,43 @@
import argparse
import json
import time

import gradio as gr
import requests


def http_bot(prompt):
headers = {"User-Agent": "Cacheflow Client"}
pload = {
"prompt": prompt,
"max_num_steps": 128,
}
response = requests.post(args.model_url, headers=headers, json=pload, stream=True)

for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"][0]
yield output


def build_demo():
with gr.Blocks() as demo:
gr.Markdown(
"# Cacheflow demo\n"
)
inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")# .style(container=False)
outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model")
inputbox.submit(http_bot, [inputbox], [outputbox])
return demo


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=10003)
parser.add_argument("--model-url", type=str, default="http://localhost:10002/generate")
args = parser.parse_args()

demo = build_demo()
demo.queue(concurrency_count=100).launch(server_name=args.host, server_port=args.port)
23 changes: 23 additions & 0 deletions cacheflow/http_frontend/test_cli_client.py
@@ -0,0 +1,23 @@
import requests
import json

def http_request():
prompt = "Ion Stoica is a"

headers = {"User-Agent": "Test Client"}
pload = {
"prompt": prompt,
"n": 4,
"use_beam_search": True,
"temperature": 0.0,
}
response = requests.post("http://localhost:10002/generate", headers=headers, json=pload, stream=True)

for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"]
yield output

for h in http_request():
print(h, flush=True)
26 changes: 14 additions & 12 deletions cacheflow/master/scheduler.py
@@ -1,7 +1,6 @@
from typing import Dict, List
from typing import Dict, List, Tuple

from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.master.frontend import Frontend
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
@@ -14,14 +14,12 @@ class Scheduler:

def __init__(
self,
frontend: Frontend,
controllers: List,
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
max_num_batched_tokens: int,
) -> None:
self.frontend = frontend
self.controllers = controllers
self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks
Expand All @@ -47,9 +44,12 @@ def __init__(
# Pending sequence groups (FIFO).
self.pending: List[SequenceGroup] = []

def _fetch_inputs(self) -> None:
inputs = self.frontend.get_inputs()
for seq_group, sampling_params in inputs:
def add_sequence_groups(
self,
sequence_groups: List[Tuple[SequenceGroup, SamplingParams]],
) -> None:
# Add sequence groups to the pending queue.
for seq_group, sampling_params in sequence_groups:
self.pending.append(seq_group)
self.sampling_params[seq_group.group_id] = sampling_params

@@ -104,7 +104,7 @@ def _swap_out(
seq.status = SequenceStatus.SWAPPED
self.swapped.append(seq_group)

def step(self) -> None:
def step(self) -> List[SequenceGroup]:
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {}
@@ -158,7 +158,6 @@ def step(self) -> None:
# 3. Join new sequences if possible.
# NOTE: Here we implicitly assume FCFS scheduling.
# TODO(woosuk): Add a batching policy to control the batch size.
self._fetch_inputs()
if not self.swapped:
for i, seq_group in enumerate(self.pending):
num_prompt_tokens = seq_group.seqs[0].get_len()
@@ -176,6 +175,8 @@ def step(self) -> None:

# 4. Create input data structures.
input_seq_groups: List[SequenceGroupInputs] = []
updated_seq_groups: List[SequenceGroup] = self.running.copy()

for seq_group in self.running:
group_id = seq_group.group_id
num_steps = self.num_steps[group_id]
@@ -219,6 +220,8 @@ def step(self) -> None:
blocks_to_copy,
)

return updated_seq_groups

def post_step(
self,
seq_outputs: Dict[int, SequenceOutputs],
@@ -268,13 +271,12 @@ def post_step(
running: List[SequenceGroup] = []
for seq_group in self.running:
if seq_group.is_finished():
self._return(seq_group)
self._free_seq_group(seq_group)
else:
running.append(seq_group)
self.running = running

def _return(self, seq_group: SequenceGroup) -> None:
def _free_seq_group(self, seq_group: SequenceGroup) -> None:
group_id = seq_group.group_id
del self.num_steps[group_id]
del self.sampling_params[group_id]
self.frontend.print_response(seq_group)