Apply DataProto to the vllm inference pipeline, aligning its API with the sglang inferencer introduced in "Unified data exchange protocol across modules" (#960). This unifies data exchange across inference engines and modernizes the vllm integration.
Remove the Ray dependency from the vllm integration, paving the way for a Ray-free lmflow implementation.
Detailed Description
DataProto integration
VLLMInferencer now returns DataProto instead of list[VLLMInferenceResultWithInput], with prompts in non_tensor_batch["inputs"] and generated text in non_tensor_batch["outputs"] (consumed as sketched after this list)
prepare_inputs_for_inference creates DataProto for both sglang and vllm through a unified code path
__vllm_inference in HFDecoderModel extracts prompts and sampling params from DataProto, converts to vllm.SamplingParams, and stores outputs back into the proto
Inference results are saved/loaded as pickle via DataProto.save_to_disk / load_from_disk
inference_results_path now accepts a directory — results are automatically saved as inference_results.pkl inside it
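A minimal sketch of the consumer side described in the items above, loading a saved run and reading the new DataProto fields. The import path for DataProto is an assumption (adjust it to wherever this PR defines the class); everything else uses only the names listed above.

```python
# Sketch only: the DataProto import path below is an assumption.
from lmflow.utils.data_utils import DataProto  # assumed location

# inference_results_path points at a directory; the inferencer writes
# inference_results.pkl inside it (see the item above).
results = DataProto.load_from_disk("outputs/inference_results.pkl")

prompts = results.non_tensor_batch["inputs"]       # original prompt strings
generations = results.non_tensor_batch["outputs"]  # generated text, one entry per prompt

for prompt, text in zip(prompts, generations):
    print(prompt, "->", text)
```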
API alignment with sglang and modernization
VLLMInferencer now mirrors SGLangInferencer
Removed InferencerWithOffloading base class and all Ray-based distributed inference code -- vllm >= 0.8 supports data_parallel_size natively in vllm.LLM(), using a multiprocessing backend with no Ray dependency
Added --inference_data_parallel_size argument
Total GPUs used = tensor_parallel_size × data_parallel_size (see the sketch after this list)
Removed use_beam_search from sampling params (dropped in vLLM V1) and added a deprecation warning
Fixed deactivate_model_for_inference — old cleanup code referenced llm_engine.model_executor.driver_worker which no longer exists in V1
Added --inference_max_model_len to cap context length (prompt and output) for models with large defaults
Bumped vllm version constraint from >=0.4.3 to >=0.8.0 in setup.py
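The knobs above compose at the vLLM layer roughly as sketched below. The data_parallel_size keyword follows the claim above that vllm >= 0.8 accepts it natively in vllm.LLM(); the model name and concrete values are placeholders, and the mapping to LMFlow's CLI flags is indicated in comments rather than taken from the actual argument parser.

```python
from vllm import LLM, SamplingParams

tensor_parallel_size = 2   # GPUs per model replica
data_parallel_size = 2     # --inference_data_parallel_size: number of replicas
# Total GPUs used = tensor_parallel_size * data_parallel_size = 4

llm = LLM(
    model="Qwen/Qwen3-0.6B",                 # placeholder model
    tensor_parallel_size=tensor_parallel_size,
    data_parallel_size=data_parallel_size,   # per the note above: vllm >= 0.8, multiprocessing backend, no Ray
    max_model_len=4096,                      # --inference_max_model_len: caps prompt + output length
)

# Sampling settings carried in the DataProto are converted to vllm.SamplingParams;
# use_beam_search is intentionally absent, since it was dropped in vLLM V1 (see above).
sampling_params = SamplingParams(n=1, temperature=0.7, top_p=0.9, max_tokens=512)

outputs = llm.generate(["What is 2 + 2?"], sampling_params)
print(outputs[0].outputs[0].text)
```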
MemorySafeVLLMInferencer is updated to return DataProto. iterative_dpo_aligner.py consumes MemorySafeVLLMInferencer and will need a separate update to handle DataProto instead of list[VLLMInferenceResultWithInput].
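One possible shape for that follow-up, shown as a sketch: the helper below is hypothetical rather than part of this PR, and the "input"/"output" keys of the old per-sample entries are assumptions.

```python
# Hypothetical adapter, not part of this PR: unpack the DataProto returned by
# MemorySafeVLLMInferencer into per-sample dicts shaped like the old
# list[VLLMInferenceResultWithInput] ("input"/"output" keys are assumptions).
def dataproto_to_legacy_results(results) -> list[dict]:
    inputs = results.non_tensor_batch["inputs"]
    outputs = results.non_tensor_batch["outputs"]
    return [{"input": prompt, "output": generation}
            for prompt, generation in zip(inputs, outputs)]
```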
Tests
6 unit tests pass (no GPU): sampling params parsing, DataProto save/load round-trip, DataProto repeat logic (illustrated after this list)
2 GPU integration tests pass: full inference pipeline + save/load with Qwen3-0.6B on RTX 4090
Ran scripts/run_vllm_inference.sh end-to-end with the target model
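For illustration, a round-trip check in the spirit of the unit tests listed above. This is not the actual test file; the DataProto import path and constructor signature are assumptions.

```python
# Illustrative pytest sketch; the import path and DataProto constructor are assumptions.
import numpy as np

from lmflow.utils.data_utils import DataProto  # assumed location


def test_dataproto_save_load_roundtrip(tmp_path):
    proto = DataProto(non_tensor_batch={
        "inputs": np.array(["hello"], dtype=object),
        "outputs": np.array(["hi there"], dtype=object),
    })
    path = str(tmp_path / "inference_results.pkl")
    proto.save_to_disk(path)
    loaded = DataProto.load_from_disk(path)
    assert list(loaded.non_tensor_batch["outputs"]) == ["hi there"]
```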
examples/vllm_inference.py was flipped to release_gpu=True, but the method it invokes (HFModelMixin.deactivate_model_for_inference) still documents itself as a placeholder that "cannot release all gpu resources by our observation" for vLLM. Either the docstring is stale (and should be updated to reflect that vLLM >= 0.8 now releases properly, matching the new setup.py pin) or the example is misleading to users. Also note that MemorySafeVLLMInferencer was added in this same PR precisely because in-process release was inadequate — worth a comment explaining when a user should pick release_gpu=True vs. MemorySafeVLLMInferencer.
Updated the deactivate_model_for_inference docstring to describe vllm>=0.8 best-effort GPU release behavior, and clarified when to use release_gpu=True versus MemorySafeVLLMInferencer (single-GPU inference vs. TP>1 / colocated training+inference).
Added a matching note in examples/vllm_inference.py.
Re-checked at 7408430. The docstring/example mismatch I flagged is fully fixed: deactivate_model_for_inference now scopes when in-process release is reliable (vllm >= 0.8, single-GPU, inference-only) vs. when MemorySafeVLLMInferencer should be used (tp_size > 1, CUDA graphs, colocated training+inference), and examples/vllm_inference.py carries a matching comment. Thanks for the quick turnaround.
Files changed
src/lmflow/pipeline/vllm_inferencer.py
src/lmflow/models/hf_decoder_model.py
src/lmflow/models/hf_model_mixin.py
src/lmflow/args.py
src/lmflow/pipeline/sglang_inferencer.py
src/lmflow/pipeline/utils/memory_safe_vllm_inference.py
examples/vllm_inference.py
scripts/run_vllm_inference.sh
scripts/run_sglang_inference.sh
setup.py
tests/pipeline/test_vllm_inferencer.py