Skip to content

vllm.distributed.kv_transfer.kv_connector.v1.offloading.worker

OffloadingConnectorWorker

Implementation of the worker-side methods of the offloading KV connector.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py
class OffloadingConnectorWorker:
    """Implementation of Worker side methods.

    Turns the scheduler-built ``OffloadingConnectorMetadata`` into
    asynchronous transfer jobs on an ``OffloadingWorker``. Completion is
    reported back through ``get_finished`` and
    ``build_connector_worker_meta``.
    """

    def __init__(self, spec: OffloadingSpec):
        self.spec = spec
        self.worker = OffloadingWorker()

        self.kv_connector_stats = OffloadingConnectorStats()
        # job_id -> req_id for in-flight loads.
        self._load_jobs: dict[int, ReqId] = {}
        # Store jobs queued by prepare_store_kv; they are submitted on the next
        # handle_preemptions / start_kv_transfers call.
        self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = []
        # Completed job ids accumulated until build_connector_worker_meta.
        self._connector_worker_meta = OffloadingWorkerMetadata()

    def _register_handlers(self, kv_caches: CanonicalKVCaches):
        """Register every (src, dst, handler) triple the spec provides."""
        for src_cls, dst_cls, handler in self.spec.get_handlers(kv_caches):
            self.worker.register_handler(src_cls, dst_cls, handler)

    def _submit_deferred_store_jobs(self) -> None:
        """Submit all store jobs queued by ``prepare_store_kv``, in order.

        The queue is cleared only after every job has been submitted.
        """
        for job_id, transfer_spec in self._unsubmitted_store_jobs:
            success = self.worker.transfer_async(job_id, transfer_spec)
            assert success
        self._unsubmitted_store_jobs.clear()

    def register_kv_caches(
        self, kv_caches: dict[str, torch.Tensor | list[torch.Tensor]]
    ):
        """Describe the allocated KV caches canonically and register handlers.

        Steps:
            1. Build an int8 ``(num_blocks, page_size_bytes)`` view over each
               layer's region.
            2. Merge layers that share a slot (per
               ``kv_cache_config.kv_cache_tensors``) into one
               ``CanonicalKVCacheTensor``.
            3. Collect per-layer references for each KV cache group.
            4. Register the spec's transfer handlers for the result.

        Args:
            kv_caches: layer_name -> the layer's KV cache, either a single
                tensor or a list of tensors (see the comment in the loop).

        Raises:
            NotImplementedError: if a layer's spec is neither an
                ``AttentionSpec`` nor a ``MambaSpec``.
        """
        num_blocks = self.spec.kv_cache_config.num_blocks

        # layer_name -> (num_blocks, page_size_bytes) int8 view.
        # Standardized layouts always have num_blocks as the leading dim.
        tensors_per_block: dict[str, tuple[torch.Tensor, ...]] = {}
        # layer_name -> size of (un-padded) page in bytes
        unpadded_page_size_bytes: dict[str, int] = {}
        # layer_name -> size of page in bytes
        page_size_bytes: dict[str, int] = {}
        for kv_cache_group in self.spec.kv_cache_config.kv_cache_groups:
            group_layer_names = kv_cache_group.layer_names
            group_kv_cache_spec = kv_cache_group.kv_cache_spec
            # A UniformTypeKVCacheSpecs group may carry per-layer specs.
            # Otherwise every layer uses the group's spec.
            if isinstance(group_kv_cache_spec, UniformTypeKVCacheSpecs):
                per_layer_specs = group_kv_cache_spec.kv_cache_specs
            else:
                per_layer_specs = {}
            for layer_name in group_layer_names:
                layer_kv_cache_spec = per_layer_specs.get(
                    layer_name, group_kv_cache_spec
                )
                layer_kv_cache = kv_caches[layer_name]
                # AttentionSpec yields a single tensor; MambaSpec yields a
                # list of typed state tensors that share one underlying
                # buffer. Either way, the first tensor's storage_offset
                # marks the start of this layer's region.
                ref = (
                    layer_kv_cache[0]
                    if isinstance(layer_kv_cache, list)
                    else layer_kv_cache
                )
                page = layer_kv_cache_spec.page_size_bytes
                # Byte offset of this layer's region inside the raw storage.
                offset = ref.storage_offset() * ref.element_size()
                # Reinterpret the whole storage as bytes, slice this layer's
                # num_blocks pages and view them as one row per block.
                tensors_per_block[layer_name] = (
                    torch.tensor([], dtype=torch.int8, device=ref.device)
                    .set_(ref.untyped_storage())
                    .view(-1)[offset : offset + num_blocks * page]
                    .view(num_blocks, page),
                )
                page_size_bytes[layer_name] = page

                if isinstance(layer_kv_cache_spec, AttentionSpec):
                    unpadded_page_size_bytes[layer_name] = (
                        layer_kv_cache_spec.real_page_size_bytes
                    )
                elif isinstance(layer_kv_cache_spec, MambaSpec):
                    # Same spec with padding disabled gives the unpadded size.
                    unpadded_page_size_bytes[layer_name] = replace(
                        layer_kv_cache_spec, page_size_padded=None
                    ).page_size_bytes
                else:
                    raise NotImplementedError

        block_tensors: list[CanonicalKVCacheTensor] = []
        block_data_refs: dict[str, list[CanonicalKVCacheRef]] = defaultdict(list)
        for kv_cache_tensor in self.spec.kv_cache_config.kv_cache_tensors:
            for slot_layers in kv_cache_tensor.shared_by:
                # Filter to layers that were actually processed above.
                # Some slots may have no corresponding model layer (reserved
                # memory with no group layer at that index).
                tensor_layer_names = [n for n in slot_layers if n in tensors_per_block]
                if not tensor_layer_names:
                    continue

                # Verify all layers in the slot reference the same tensors.
                assert len({len(tensors_per_block[n]) for n in tensor_layer_names}) == 1
                assert (
                    len(
                        {tensors_per_block[n][0].data_ptr() for n in tensor_layer_names}
                    )
                    == 1
                )
                assert (
                    len({tensors_per_block[n][0].stride() for n in tensor_layer_names})
                    == 1
                )

                # Emit one canonical tensor per slot (taken from its first
                # layer) and point every sharing layer at it. Each reference
                # carries that layer's own unpadded page size.
                first_layer_name = tensor_layer_names[0]
                for tensor in tensors_per_block[first_layer_name]:
                    block_tensors.append(
                        CanonicalKVCacheTensor(
                            tensor=tensor,
                            page_size_bytes=page_size_bytes[first_layer_name],
                        )
                    )

                    curr_tensor_idx = len(block_tensors) - 1
                    for layer_name in tensor_layer_names:
                        block_data_refs[layer_name].append(
                            CanonicalKVCacheRef(
                                tensor_idx=curr_tensor_idx,
                                page_size_bytes=unpadded_page_size_bytes[layer_name],
                            )
                        )

        # Concatenate per-layer refs in group / layer order.
        group_data_refs: list[list[CanonicalKVCacheRef]] = []
        for kv_cache_group in self.spec.kv_cache_config.kv_cache_groups:
            group_refs: list[CanonicalKVCacheRef] = []
            for layer_name in kv_cache_group.layer_names:
                group_refs += block_data_refs[layer_name]
            group_data_refs.append(group_refs)

        canonical_kv_caches = CanonicalKVCaches(
            tensors=block_tensors,
            group_data_refs=group_data_refs,
        )

        self._register_handlers(canonical_kv_caches)

    def handle_preemptions(self, kv_connector_metadata: OffloadingConnectorMetadata):
        """Submit deferred stores, then wait on ``jobs_to_flush`` (if any)."""
        self._submit_deferred_store_jobs()

        if kv_connector_metadata.jobs_to_flush:
            self.worker.wait(kv_connector_metadata.jobs_to_flush)

    def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
        """Submit deferred stores, then submit this step's load jobs.

        Each load job is recorded in ``_load_jobs`` so that ``get_finished``
        can map its completion back to the request.
        """
        self._submit_deferred_store_jobs()

        for job_id, entry in metadata.load_jobs.items():
            self._load_jobs[job_id] = entry.req_id
            success = self.worker.transfer_async(job_id, entry.transfer_spec)
            assert success

    def prepare_store_kv(self, metadata: OffloadingConnectorMetadata):
        """Queue this step's store jobs for submission on the next step."""
        for job_id, entry in metadata.store_jobs.items():
            # NOTE(orozery): defer the store to the beginning of the next
            # engine step, so that offloading starts AFTER transfers related
            # to token sampling, thereby avoiding delays to token generation.
            self._unsubmitted_store_jobs.append((job_id, entry.transfer_spec))

    def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """Collect completed transfers and report requests that finished loading.

        For every completed job this method does three things:
            - Records it in the pending worker metadata, so the scheduler
              can observe store completion via
              ``kv_connector_worker_meta.completed_jobs``.
            - Records its transfer stats, when they are available.
            - Maps it back to a request id if it is a load job.

        Stores never emit finished_sending. The scheduler fences any block
        reuse via jobs_to_flush instead. Loads still emit finished_recving so
        the base scheduler can resume requests blocked on remote KV (and free
        aborted-during-load reqs).

        Args:
            finished_req_ids: not used by this implementation.

        Returns:
            ``(finished_sending, finished_recving)``. The first set is always
            empty. The second holds the request ids whose load jobs completed.
        """
        recving_req_ids: set[str] = set()
        for result in self.worker.get_finished():
            completed_job = result.job_id
            # Job failures are not supported yet.
            assert result.success

            # Zero/absent timing or missing size/type means no stats entry.
            stats_available = (
                result.transfer_time
                and result.transfer_size is not None
                and result.transfer_type is not None
            )
            if stats_available:
                self.kv_connector_stats.record_transfer(
                    num_bytes=result.transfer_size,
                    time=result.transfer_time,
                    transfer_type=result.transfer_type,
                )

            self._connector_worker_meta.mark_completed(completed_job)
            load_req_id = self._load_jobs.pop(completed_job, None)
            if load_req_id is not None:
                recving_req_ids.add(load_req_id)

        return set(), recving_req_ids

    def build_connector_worker_meta(self) -> OffloadingWorkerMetadata | None:
        """Hand off the completed-job metadata and start a fresh accumulator.

        Returns:
            The metadata collected since the previous call. Returns ``None``
            if no job completed in the meantime; nothing is reset then.
        """
        pending = self._connector_worker_meta
        if pending.completed_jobs:
            self._connector_worker_meta = OffloadingWorkerMetadata()
            return pending
        return None

    def get_kv_connector_stats(self) -> KVConnectorStats | None:
        """Return the transfer stats gathered since the last call, then reset.

        Returns:
            The accumulated stats object. Returns ``None`` when nothing has
            been recorded; the existing object is kept then.
        """
        collected = self.kv_connector_stats
        if collected.is_empty():
            return None
        # Swap in a fresh accumulator so the next iteration starts from zero.
        self.kv_connector_stats = OffloadingConnectorStats()
        return collected

    def shutdown(self) -> None:
        """Drop all pending bookkeeping and shut down the transfer worker."""
        self._unsubmitted_store_jobs.clear()
        self._load_jobs.clear()
        self._connector_worker_meta = OffloadingWorkerMetadata()
        self.worker.shutdown()

build_connector_worker_meta

build_connector_worker_meta() -> (
    OffloadingWorkerMetadata | None
)

Return completed transfer job IDs since the last call.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py
def build_connector_worker_meta(self) -> OffloadingWorkerMetadata | None:
    """Hand off the completed-job metadata and start a fresh accumulator.

    Returns:
        The metadata collected since the previous call. Returns ``None`` if
        no job completed in the meantime; nothing is reset in that case.
    """
    pending = self._connector_worker_meta
    if pending.completed_jobs:
        self._connector_worker_meta = OffloadingWorkerMetadata()
        return pending
    return None

get_finished

get_finished(
    finished_req_ids: set[str],
) -> tuple[set[str], set[str]]

Returns:

Type Description
tuple[set[str], set[str]]

tuple of (finished_sending, finished_recving). Stores never emit finished_sending — the scheduler tracks store completion via kv_connector_worker_meta.completed_jobs and fences any block reuse via jobs_to_flush. Loads still emit finished_recving so the base scheduler can resume requests blocked on remote KV (and free aborted-during-load reqs).

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
    """Collect completed transfers and report requests that finished loading.

    For every completed job this function does three things:
        - Records it in the pending worker metadata, so the scheduler can
          observe store completion via
          ``kv_connector_worker_meta.completed_jobs``.
        - Records its transfer stats, when they are available.
        - Maps it back to a request id if it is a load job.

    Stores never emit finished_sending. The scheduler fences any block reuse
    via jobs_to_flush instead. Loads still emit finished_recving so the base
    scheduler can resume requests blocked on remote KV (and free
    aborted-during-load reqs).

    Args:
        finished_req_ids: not used by this implementation.

    Returns:
        ``(finished_sending, finished_recving)``. The first set is always
        empty. The second holds the request ids whose load jobs completed.
    """
    recving_req_ids: set[str] = set()
    for result in self.worker.get_finished():
        completed_job = result.job_id
        # Job failures are not supported yet.
        assert result.success

        # Zero/absent timing or missing size/type means no stats entry.
        stats_available = (
            result.transfer_time
            and result.transfer_size is not None
            and result.transfer_type is not None
        )
        if stats_available:
            self.kv_connector_stats.record_transfer(
                num_bytes=result.transfer_size,
                time=result.transfer_time,
                transfer_type=result.transfer_type,
            )

        self._connector_worker_meta.mark_completed(completed_job)
        load_req_id = self._load_jobs.pop(completed_job, None)
        if load_req_id is not None:
            recving_req_ids.add(load_req_id)

    return set(), recving_req_ids

get_kv_connector_stats

get_kv_connector_stats() -> KVConnectorStats | None

Get the KV transfer stats for the connector.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py
def get_kv_connector_stats(self) -> KVConnectorStats | None:
    """Return the transfer stats gathered since the last call, then reset.

    Returns:
        The accumulated stats object. Returns ``None`` when nothing has been
        recorded; the existing object is kept in that case.
    """
    collected = self.kv_connector_stats
    if collected.is_empty():
        return None
    # Swap in a fresh accumulator so the next iteration starts from zero.
    self.kv_connector_stats = OffloadingConnectorStats()
    return collected