Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors

Expand Down Expand Up @@ -237,7 +238,7 @@ def send_kv_caches_and_hidden_states_hpu(
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
if self.rank != 0:
if not get_tp_group().is_first_rank:
# only the first rank will send kv cache
return
input_tokens_tensor_cpu = model_input.input_tokens.to("cpu") # shape: [batch_size, seq_len_padding_to_128]
Expand All @@ -259,7 +260,7 @@ def send_kv_caches_and_hidden_states_hpu(
current_tokens_cpu = input_tokens_tensor_cpu[idx][:slen]
store_key_prefix = self.tensor_hash(current_tokens_cpu)
logger.debug(f"send token len: {slen}, token: {current_tokens_cpu}")
keys, values = [], []
keys = []
start = 0
padded_total_size = (slen + self.block_size - 1) // self.block_size * self.block_size
current_slot_mapping = model_input.attn_metadata.slot_mapping[idx][start:padded_total_size]
Expand All @@ -269,23 +270,19 @@ def send_kv_caches_and_hidden_states_hpu(
for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]
key_cache = kv_cache[0].reshape(-1, num_kv_heads, self.k_v_head_size)
# value_cache = kv_cache[1].reshape(-1, num_kv_heads, v_head_size)

keys.append(key_cache.index_select(0, current_slot_mapping).unsqueeze(0))
# values.append(value_cache[current_slot_mapping].unsqueeze(0))

keys = torch.cat(keys, dim=0)
# values = torch.cat(values, dim=0)
# we pack kv together, only need send one tensor
kvcache_to_sent = keys
store_kvcache_key = f"{store_key_prefix}_{self.rank}"
self.kv_store.put(store_kvcache_key, kvcache_to_sent)

logger.debug(f"put kv cache key: {store_kvcache_key}")

hidden_key = f"{store_key_prefix}_hidden_{self.rank}"
self.kv_store.put(hidden_key,
hidden_or_intermediate_states[idx].unsqueeze(0).cpu())
layer_store_kvcache_key = f"{store_key_prefix}_l{layer_id}_tp0"
keys_tensor = key_cache.index_select(0, current_slot_mapping).unsqueeze(0)
self.kv_store.put(layer_store_kvcache_key, keys_tensor)
logger.warning(f"put kv cache key: {layer_store_kvcache_key}")
# only write hidden states at the last pipeline stage
if get_pp_group().is_last_rank:
hidden_key = f"{store_key_prefix}_hidden_tp0"
hidden_val = hidden_or_intermediate_states[idx] if hidden_or_intermediate_states is not None else None
if hidden_val is not None:
self.kv_store.put(hidden_key, hidden_val.unsqueeze(0).cpu())
logger.warning(f"put hidden key: {hidden_key}")
else:
logger.warning(f"hidden_or_intermediate_states[{idx}] is None, skip put for {hidden_key}")
# ==== graph should end here ======
htorch.core.mark_step()
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
Expand All @@ -312,8 +309,6 @@ def recv_kv_caches_and_hidden_states_hpu(
block_indices_list = attn_metadata.block_indices.tolist()

hidden_or_intermediate_states_for_one_req = []
input_tokens_list = []
num_computed_tokens_list = []
start_block_idx = 0

# For each sequence in the batch, we patch kv tensor together, so we recv
Expand Down Expand Up @@ -353,54 +348,43 @@ def recv_kv_caches_and_hidden_states_hpu(

# get roi for current seq
load_key_prefix = self.tensor_hash(current_tokens)
# For deepseek, we only need recv first rank
load_kvcache_key = f"{load_key_prefix}_0"
remote_kv = self.kv_store.get(load_kvcache_key)
hidden_key = f"{load_key_prefix}_hidden_0"
hidden = self.kv_store.get(hidden_key)

if remote_kv is None or hidden is None:
# didn't find any match.
logger.warning(f"Didn't find any match, load_key_prefix: {load_kvcache_key}")
bypass_model_exec = False
continue

# collecting data for rebuilding the input
input_tokens_list.append(current_tokens)
num_computed_tokens = current_tokens.shape[0]
num_computed_tokens_list.append(num_computed_tokens)

# it's padded to block size now.
key_values = remote_kv.to("hpu")
keys = key_values
# values = key_values[..., self.k_head_size:]

# get kv cache for each layer
htorch.core.mark_step()
torch.hpu.synchronize()
# put received KV caches into paged memory layer by layer
# for each layer, we need to pad the key and value to 128, so
# key shape should be [num_blocks, block_size, num_kv_heads(1,ommited), k_head_size]
# value shape should be [num_blocks, block_size, num_kv_heads(1,ommited), v_head_size]
for i in range(model_executable.model.start_layer,
model_executable.model.end_layer):
current_layer_idx = i - model_executable.model.start_layer
kv_cache = kv_caches[current_layer_idx]

key_cache, value_cache = kv_cache[0], kv_cache[1]
for layer_id in range(model_executable.model.start_layer, model_executable.model.end_layer):
layer_load_kvcache_key = f"{load_key_prefix}_l{layer_id}_tp0"
remote_kv = self.kv_store.get(layer_load_kvcache_key)
if remote_kv is None:
logger.warning(f"Didn't find kv cache, key: {layer_load_kvcache_key}")
bypass_model_exec = False
continue
kv_cache = kv_caches[layer_id - model_executable.model.start_layer]
key_cache = kv_cache[0]
key = remote_kv.squeeze(0) # [1, ...] -> [...]
num_blocks = block_indices_tensor.shape[0]
block_size = self.block_size
k_v_head_size = self.k_v_head_size
if key.dim() == 3 and key.shape[1] == 1:
key = key.squeeze(1)
key = key.view(num_blocks, block_size, k_v_head_size)
self.cache_k(
key,
key_cache,
block_indices_tensor,
None,
)

# [num_layers, seq_len, num_kv_heads, k/v_head_size] -> [seq_len, k/v_head_size]
key = keys[current_layer_idx].squeeze(-2).view(-1, self.block_size, self.k_v_head_size)
# value = values[current_layer_idx].squeeze(-2)
hidden_key = f"{load_key_prefix}_hidden_tp0"
hidden = self.kv_store.get(hidden_key)
if hidden is not None:
hidden_or_intermediate_states_for_one_req.append(hidden.to("hpu"))
else:
logger.warning(f"Didn't find hidden state, key: {hidden_key}")
bypass_model_exec = False

# ====== D2D =======
self.cache_k(key,
key_cache,
block_indices_tensor,
None,
)
start_block_idx = end_block_idx
hidden_or_intermediate_states_for_one_req.append(hidden.to("hpu"))
htorch.core.mark_step()

if not bypass_model_exec:
# Some of the KV cache is not retrieved
# Here we will fall back to normal model forwarding
Expand All @@ -412,7 +396,7 @@ def recv_kv_caches_and_hidden_states_hpu(
hidden_or_intermediate_states = None

else:
logger.debug(
logger.warning(
"[rank%d]: Successfully received all KVs and hidden "
"states, skip model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = torch.cat(
Expand Down
2 changes: 1 addition & 1 deletion vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1990,7 +1990,7 @@ def warmup_scenario(self,
intermediate_tensors = \
self.model.make_empty_intermediate_tensors(
batch_size=batch_size,
context_size=seq_len if is_prompt else 1,
context_size=inputs.input_tokens.shape[1],
dtype=self.model_config.dtype,
device=self.device)
self.execute_model(inputs,
Expand Down
17 changes: 10 additions & 7 deletions vllm/worker/worker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,13 +454,16 @@ def execute_model(
model_execute_time = time.perf_counter() - start_time
if not get_pp_group().is_last_rank:
# output is IntermediateTensors
assert isinstance(output, IntermediateTensors)
if (self.observability_config is not None
and self.observability_config.collect_model_execute_time):
output.tensors["model_execute_time"] = torch.tensor(
model_execute_time + orig_model_execute_time)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
if isinstance(output, IntermediateTensors):
if (self.observability_config is not None
and self.observability_config.collect_model_execute_time):
output.tensors["model_execute_time"] = torch.tensor(
model_execute_time + orig_model_execute_time)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
else:
get_pp_group().send_tensor_dict({},
all_gather_group=get_tp_group())
return [None]
if (self.observability_config is not None
and self.observability_config.collect_model_execute_time
Expand Down
Loading