diff --git a/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py b/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py index 67681dff8f6c..2b3223c30194 100644 --- a/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py @@ -14,6 +14,7 @@ from vllm import _custom_ops as ops from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.sequence import IntermediateTensors @@ -237,7 +238,7 @@ def send_kv_caches_and_hidden_states_hpu( hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors], ) -> None: - if self.rank != 0: + if not get_tp_group().is_first_rank: # only the first rank will send kv cache return input_tokens_tensor_cpu = model_input.input_tokens.to("cpu") # shape: [batch_size, seq_len_padding_to_128] @@ -259,7 +260,7 @@ def send_kv_caches_and_hidden_states_hpu( current_tokens_cpu = input_tokens_tensor_cpu[idx][:slen] store_key_prefix = self.tensor_hash(current_tokens_cpu) logger.debug(f"send token len: {slen}, token: {current_tokens_cpu}") - keys, values = [], [] + keys = [] start = 0 padded_total_size = (slen + self.block_size - 1) // self.block_size * self.block_size current_slot_mapping = model_input.attn_metadata.slot_mapping[idx][start:padded_total_size] @@ -269,23 +270,19 @@ def send_kv_caches_and_hidden_states_hpu( for layer_id in range(start_layer, end_layer): kv_cache = kv_caches[layer_id - start_layer] key_cache = kv_cache[0].reshape(-1, num_kv_heads, self.k_v_head_size) - # value_cache = kv_cache[1].reshape(-1, num_kv_heads, v_head_size) - - keys.append(key_cache.index_select(0, current_slot_mapping).unsqueeze(0)) - # values.append(value_cache[current_slot_mapping].unsqueeze(0)) - - keys = torch.cat(keys, dim=0) - # values = torch.cat(values, dim=0) - # we pack kv together, only need send one tensor - kvcache_to_sent = keys - store_kvcache_key = f"{store_key_prefix}_{self.rank}" - self.kv_store.put(store_kvcache_key, kvcache_to_sent) - - logger.debug(f"put kv cache key: {store_kvcache_key}") - - hidden_key = f"{store_key_prefix}_hidden_{self.rank}" - self.kv_store.put(hidden_key, - hidden_or_intermediate_states[idx].unsqueeze(0).cpu()) + layer_store_kvcache_key = f"{store_key_prefix}_l{layer_id}_tp0" + keys_tensor = key_cache.index_select(0, current_slot_mapping).unsqueeze(0) + self.kv_store.put(layer_store_kvcache_key, keys_tensor) + logger.warning(f"put kv cache key: {layer_store_kvcache_key}") + # only write hidden states at the last pipeline stage + if get_pp_group().is_last_rank: + hidden_key = f"{store_key_prefix}_hidden_tp0" + hidden_val = hidden_or_intermediate_states[idx] if hidden_or_intermediate_states is not None else None + if hidden_val is not None: + self.kv_store.put(hidden_key, hidden_val.unsqueeze(0).cpu()) + logger.warning(f"put hidden key: {hidden_key}") + else: + logger.warning(f"hidden_or_intermediate_states[{idx}] is None, skip put for {hidden_key}") # ==== graph should end here ====== htorch.core.mark_step() logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) @@ -312,8 +309,6 @@ def recv_kv_caches_and_hidden_states_hpu( block_indices_list = attn_metadata.block_indices.tolist() hidden_or_intermediate_states_for_one_req = [] - input_tokens_list = [] - num_computed_tokens_list = [] start_block_idx = 0 # For each sequence in the batch, we patch kv tensor together, so we recv @@ -353,54 +348,43 @@ def recv_kv_caches_and_hidden_states_hpu( # get roi for current seq load_key_prefix = self.tensor_hash(current_tokens) - # For deepseek, we only need recv first rank - load_kvcache_key = f"{load_key_prefix}_0" - remote_kv = self.kv_store.get(load_kvcache_key) - hidden_key = f"{load_key_prefix}_hidden_0" - hidden = self.kv_store.get(hidden_key) - - if remote_kv is None or hidden is None: - # didn't find any match. - logger.warning(f"Didn't find any match, load_key_prefix: {load_kvcache_key}") - bypass_model_exec = False - continue - - # collecting data for rebuilding the input - input_tokens_list.append(current_tokens) - num_computed_tokens = current_tokens.shape[0] - num_computed_tokens_list.append(num_computed_tokens) - - # it's padded to block size now. - key_values = remote_kv.to("hpu") - keys = key_values - # values = key_values[..., self.k_head_size:] - + # get kv cache for each layer htorch.core.mark_step() torch.hpu.synchronize() - # put received KV caches into paged memory layer by layer - # for each layer, we need to pad the key and value to 128, so - # key shape should be [num_blocks, block_size, num_kv_heads(1,ommited), k_head_size] - # value shape should be [num_blocks, block_size, num_kv_heads(1,ommited), v_head_size] - for i in range(model_executable.model.start_layer, - model_executable.model.end_layer): - current_layer_idx = i - model_executable.model.start_layer - kv_cache = kv_caches[current_layer_idx] - - key_cache, value_cache = kv_cache[0], kv_cache[1] + for layer_id in range(model_executable.model.start_layer, model_executable.model.end_layer): + layer_load_kvcache_key = f"{load_key_prefix}_l{layer_id}_tp0" + remote_kv = self.kv_store.get(layer_load_kvcache_key) + if remote_kv is None: + logger.warning(f"Didn't find kv cache, key: {layer_load_kvcache_key}") + bypass_model_exec = False + continue + kv_cache = kv_caches[layer_id - model_executable.model.start_layer] + key_cache = kv_cache[0] + key = remote_kv.squeeze(0) # [1, ...] -> [...] + num_blocks = block_indices_tensor.shape[0] + block_size = self.block_size + k_v_head_size = self.k_v_head_size + if key.dim() == 3 and key.shape[1] == 1: + key = key.squeeze(1) + key = key.view(num_blocks, block_size, k_v_head_size) + self.cache_k( + key, + key_cache, + block_indices_tensor, + None, + ) - # [num_layers, seq_len, num_kv_heads, k/v_head_size] -> [seq_len, k/v_head_size] - key = keys[current_layer_idx].squeeze(-2).view(-1, self.block_size, self.k_v_head_size) - # value = values[current_layer_idx].squeeze(-2) + hidden_key = f"{load_key_prefix}_hidden_tp0" + hidden = self.kv_store.get(hidden_key) + if hidden is not None: + hidden_or_intermediate_states_for_one_req.append(hidden.to("hpu")) + else: + logger.warning(f"Didn't find hidden state, key: {hidden_key}") + bypass_model_exec = False - # ====== D2D ======= - self.cache_k(key, - key_cache, - block_indices_tensor, - None, - ) start_block_idx = end_block_idx - hidden_or_intermediate_states_for_one_req.append(hidden.to("hpu")) htorch.core.mark_step() + if not bypass_model_exec: # Some of the KV cache is not retrieved # Here we will fall back to normal model forwarding @@ -412,7 +396,7 @@ def recv_kv_caches_and_hidden_states_hpu( hidden_or_intermediate_states = None else: - logger.debug( + logger.warning( "[rank%d]: Successfully received all KVs and hidden " "states, skip model forwarding.", torch.distributed.get_rank()) hidden_or_intermediate_states = torch.cat( diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 902eb0b50cf8..d1cc16f5501d 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -1990,7 +1990,7 @@ def warmup_scenario(self, intermediate_tensors = \ self.model.make_empty_intermediate_tensors( batch_size=batch_size, - context_size=seq_len if is_prompt else 1, + context_size=inputs.input_tokens.shape[1], dtype=self.model_config.dtype, device=self.device) self.execute_model(inputs, diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 9d2ddb4615ee..0d5ed11f5111 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -454,13 +454,16 @@ def execute_model( model_execute_time = time.perf_counter() - start_time if not get_pp_group().is_last_rank: # output is IntermediateTensors - assert isinstance(output, IntermediateTensors) - if (self.observability_config is not None - and self.observability_config.collect_model_execute_time): - output.tensors["model_execute_time"] = torch.tensor( - model_execute_time + orig_model_execute_time) - get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group()) + if isinstance(output, IntermediateTensors): + if (self.observability_config is not None + and self.observability_config.collect_model_execute_time): + output.tensors["model_execute_time"] = torch.tensor( + model_execute_time + orig_model_execute_time) + get_pp_group().send_tensor_dict(output.tensors, + all_gather_group=get_tp_group()) + else: + get_pp_group().send_tensor_dict({}, + all_gather_group=get_tp_group()) return [None] if (self.observability_config is not None and self.observability_config.collect_model_execute_time