Bases: ABC, Generic[TInitInfo, TUpdateInfo]
Base class for weight transfer engines that handle transport of model weights from a trainer to inference workers.
This abstraction separates weight transfer transport logic from the worker implementation, allowing different backends (NCCL, CUDA IPC[TODO], RDMA[TODO]) to be plugged in.
Subclasses should define:
init_info_cls: Type of backend-specific initialization info
update_info_cls: Type of backend-specific update info
Source code in vllm/distributed/weight_transfer/base.py
class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]):
    """Abstract transport layer that moves model weights from a trainer
    process to inference workers.

    Keeping the transport behind this interface separates it from the
    worker implementation, so different backends (NCCL, CUDA IPC[TODO],
    RDMA[TODO]) can be plugged in.

    Subclasses should define:
        init_info_cls: type of the backend-specific initialization info
        update_info_cls: type of the backend-specific update info
    """

    # Concrete backends override these class attributes with their own
    # info dataclass types.
    init_info_cls: type[TInitInfo]
    update_info_cls: type[TUpdateInfo]

    def __init__(
        self, config: WeightTransferConfig, parallel_config: ParallelConfig
    ) -> None:
        """Store configuration for later use by the transport.

        Args:
            config: The configuration for the weight transfer engine.
            parallel_config: The configuration for the parallel setup.
        """
        self.config = config
        self.parallel_config = parallel_config

    def parse_init_info(self, init_dict: dict[str, Any]) -> TInitInfo:
        """Construct typed init info from a dict, with validation.

        Args:
            init_dict: Backend-specific initialization parameters.

        Returns:
            The backend-specific init info dataclass instance.

        Raises:
            ValueError: If ``init_dict`` is invalid for this backend.
        """
        try:
            info = self.init_info_cls(**init_dict)
        except TypeError as err:
            # Surface constructor mismatches as a backend-tagged ValueError.
            raise ValueError(
                f"Invalid init_info for {self.__class__.__name__}: {err}"
            ) from err
        return info

    def parse_update_info(self, update_dict: dict[str, Any]) -> TUpdateInfo:
        """Construct typed update info from a dict, with validation.

        Args:
            update_dict: Backend-specific update parameters.

        Returns:
            The backend-specific update info dataclass instance.

        Raises:
            ValueError: If ``update_dict`` is invalid for this backend.
        """
        try:
            info = self.update_info_cls(**update_dict)
        except TypeError as err:
            # Surface constructor mismatches as a backend-tagged ValueError.
            raise ValueError(
                f"Invalid update_info for {self.__class__.__name__}: {err}"
            ) from err
        return info

    @abstractmethod
    def init_transfer_engine(self, init_info: TInitInfo) -> None:
        """Initialize the weight transfer mechanism.

        Called once at the beginning of training.

        Args:
            init_info: Backend-specific initialization info.
        """
        raise NotImplementedError

    @abstractmethod
    def receive_weights(
        self,
        update_info: TUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        """Receive weights from the trainer and load them incrementally.

        Args:
            update_info: Backend-specific update info containing parameter
                metadata and any backend-specific data.
            load_weights: Callable that loads weights into the model; it is
                invoked incrementally for each weight to avoid OOM.
        """
        raise NotImplementedError

    @abstractmethod
    def shutdown(self) -> None:
        """Shut down the engine; called when the worker is shutting down."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def trainer_send_weights(
        iterator: Iterator[tuple[str, torch.Tensor]],
        trainer_args: dict[str, Any] | Any,
    ) -> None:
        """Send weights from the trainer to the inference workers.

        A static method callable from the trainer process to push weights
        to all inference workers.

        Args:
            iterator: Yields ``(name, tensor)`` pairs of model parameters.
                Tensors should be on the appropriate device for the backend.
            trainer_args: Backend-specific arguments needed to send weights:
                - NCCL: contains 'group', 'src', 'packed', etc.
                - IPC: contains 'mode' ('http' or 'ray'),
                  'llm_handle' (for Ray), 'url' (for HTTP), etc.

        Example:
            >>> param_iter = ((n, p) for n, p in model.named_parameters())
            >>> engine.trainer_send_weights(param_iter, trainer_args)
        """
        raise NotImplementedError
__init__
Initialize the weight transfer engine.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | WeightTransferConfig | The configuration for the weight transfer engine | required |
| parallel_config | ParallelConfig | The configuration for the parallel setup | required |
Source code in vllm/distributed/weight_transfer/base.py
def __init__(
    self, config: WeightTransferConfig, parallel_config: ParallelConfig
) -> None:
    """Record the transfer and parallel configuration on the engine.

    Args:
        config: The configuration for the weight transfer engine.
        parallel_config: The configuration for the parallel setup.
    """
    # The two assignments are independent; order is irrelevant.
    self.parallel_config = parallel_config
    self.config = config
init_transfer_engine abstractmethod
init_transfer_engine(init_info: TInitInfo) -> None
Initialize the weight transfer mechanism. This is called once at the beginning of training.
Parameters:
| Name | Type | Description | Default |
init_info | TInitInfo | Backend-specific initialization info | required |
Source code in vllm/distributed/weight_transfer/base.py
@abstractmethod
def init_transfer_engine(self, init_info: TInitInfo) -> None:
    """Set up the weight transfer mechanism.

    Invoked exactly once, at the beginning of training.

    Args:
        init_info: Backend-specific initialization info.
    """
    raise NotImplementedError
parse_init_info
parse_init_info(init_dict: dict[str, Any]) -> TInitInfo
Construct typed init info from dict with validation.
Parameters:
| Name | Type | Description | Default |
init_dict | dict[str, Any] | Dictionary containing backend-specific initialization parameters | required |
Returns:
| Type | Description |
TInitInfo | Typed backend-specific init info dataclass |
Raises:
| Type | Description |
ValueError | If init_dict is invalid for this backend |
Source code in vllm/distributed/weight_transfer/base.py
def parse_init_info(self, init_dict: dict[str, Any]) -> TInitInfo:
    """Build typed init info from a raw dict, validating its fields.

    Args:
        init_dict: Backend-specific initialization parameters.

    Returns:
        The backend-specific init info dataclass instance.

    Raises:
        ValueError: If ``init_dict`` is invalid for this backend.
    """
    try:
        info = self.init_info_cls(**init_dict)
    except TypeError as err:
        # A constructor mismatch means the dict is invalid for this backend.
        raise ValueError(
            f"Invalid init_info for {self.__class__.__name__}: {err}"
        ) from err
    return info
parse_update_info
parse_update_info(
update_dict: dict[str, Any],
) -> TUpdateInfo
Construct typed update info from dict with validation.
Parameters:
| Name | Type | Description | Default |
update_dict | dict[str, Any] | Dictionary containing backend-specific update parameters | required |
Returns:
| Type | Description |
TUpdateInfo | Typed backend-specific update info dataclass |
Raises:
| Type | Description |
ValueError | If update_dict is invalid for this backend |
Source code in vllm/distributed/weight_transfer/base.py
def parse_update_info(self, update_dict: dict[str, Any]) -> TUpdateInfo:
    """Build typed update info from a raw dict, validating its fields.

    Args:
        update_dict: Backend-specific update parameters.

    Returns:
        The backend-specific update info dataclass instance.

    Raises:
        ValueError: If ``update_dict`` is invalid for this backend.
    """
    try:
        info = self.update_info_cls(**update_dict)
    except TypeError as err:
        # A constructor mismatch means the dict is invalid for this backend.
        raise ValueError(
            f"Invalid update_info for {self.__class__.__name__}: {err}"
        ) from err
    return info
receive_weights abstractmethod
Receive weights from the trainer and load them incrementally.
Parameters:
| Name | Type | Description | Default |
update_info | TUpdateInfo | Backend-specific update info containing parameter metadata and any backend-specific data | required |
load_weights | Callable[[list[tuple[str, Tensor]]], None] | Callable that loads weights into the model. Called incrementally for each weight to avoid OOM. | required |
Source code in vllm/distributed/weight_transfer/base.py
@abstractmethod
def receive_weights(
    self,
    update_info: TUpdateInfo,
    load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
) -> None:
    """Receive weights from the trainer and load them incrementally.

    Args:
        update_info: Backend-specific update info carrying parameter
            metadata and any backend-specific data.
        load_weights: Callable that loads weights into the model; it is
            invoked incrementally for each weight to avoid OOM.
    """
    raise NotImplementedError
shutdown abstractmethod
Shutdown the weight transfer engine. This should be called when the worker is shutting down.
Source code in vllm/distributed/weight_transfer/base.py
@abstractmethod
def shutdown(self) -> None:
    """Shut down the engine; called when the worker is shutting down."""
    raise NotImplementedError
trainer_send_weights abstractmethod staticmethod
Send weights from trainer to inference workers.
This is a static method that can be called from the trainer process to send weights to all inference workers.
Parameters:
| Name | Type | Description | Default |
iterator | Iterator[tuple[str, Tensor]] | Iterator of model parameters. Returns (name, tensor) tuples. The tensors should be on the appropriate device for the backend. | required |
trainer_args | dict[str, Any] | Any | Dictionary containing backend-specific arguments needed to send weights. The structure depends on the backend: - NCCL: Contains 'group', 'src', 'packed', etc. - IPC: Contains 'mode' ('http' or 'ray'), 'llm_handle' (for Ray), 'url' (for HTTP), etc. | required |
Example
param_iter = ((n, p) for n, p in model.named_parameters())
engine.trainer_send_weights(param_iter, trainer_args)
Source code in vllm/distributed/weight_transfer/base.py
@staticmethod
@abstractmethod
def trainer_send_weights(
    iterator: Iterator[tuple[str, torch.Tensor]],
    trainer_args: dict[str, Any] | Any,
) -> None:
    """Send weights from the trainer to the inference workers.

    A static method callable from the trainer process to push weights
    to all inference workers.

    Args:
        iterator: Yields ``(name, tensor)`` pairs of model parameters.
            Tensors should be on the appropriate device for the backend.
        trainer_args: Backend-specific arguments needed to send weights:
            - NCCL: contains 'group', 'src', 'packed', etc.
            - IPC: contains 'mode' ('http' or 'ray'),
              'llm_handle' (for Ray), 'url' (for HTTP), etc.

    Example:
        >>> param_iter = ((n, p) for n, p in model.named_parameters())
        >>> engine.trainer_send_weights(param_iter, trainer_args)
    """
    raise NotImplementedError