Skip to content

Commit 5c375f0

Browse files
author
gushiqiao
committed
update server for return tensor
1 parent 11678e1 commit 5c375f0

14 files changed

Lines changed: 422 additions & 19 deletions

File tree

configs/ltx2/ltx2_3.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"audio_mel_bins":16,
1515
"double_precision_rope": true,
1616
"use_tiling_vae": false,
17-
"dit_original_ckpt": "Lightricks/LTX-2.3ltx-2.3-22b-dev.safetensors",
17+
"dit_original_ckpt": "/data/nvme4/models/ltx-2.3/ltx-2.3-22b-dev.safetensors",
1818
"caption_proj_before_connector": true,
1919
"cross_attention_adaln": true,
2020
"apply_gated_attention": true,

configs/ltx2/ltx2_3_distill_upsample_offload.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
"audio_fps": 24000,
1616
"audio_mel_bins":16,
1717
"double_precision_rope": true,
18-
"dit_original_ckpt": "Lightricks/LTX-2.3ltx-2.3-22b-distilled.safetensors",
18+
"dit_original_ckpt": "/data/nvme4/models/ltx-2.3/ltx-2.3-22b-distilled.safetensors",
1919
"skip_fp8_block_index" : [0, 43, 44, 45, 46, 47],
2020
"distilled_sigma_values": [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0],
2121
"caption_proj_before_connector": true,
2222
"cross_attention_adaln": true,
2323
"apply_gated_attention": true,
2424
"use_upsampler": true,
25-
"upsampler_original_ckpt": "Lightricks/LTX-2.3ltx-2.3-spatial-upscaler-x2-1.1.safetensors",
25+
"upsampler_original_ckpt": "/data/nvme4/models/ltx-2.3/ltx-2.3-spatial-upscaler-x2-1.1.safetensors",
2626
"distilled_sigma_values_upsample": [0.909375, 0.725, 0.421875, 0.0]
2727
}

lightx2v/models/networks/flux2_klein/infer/pre_infer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import torch
22
import torch.nn.functional as F
3-
from diffusers.models.transformers.transformer_flux2 import Flux2PosEmbed
3+
4+
try:
5+
from diffusers.models.transformers.transformer_flux2 import Flux2PosEmbed
6+
except ImportError:
7+
Flux2PosEmbed = None
48

59
from .module_io import Flux2KleinPreInferModuleOutput
610

lightx2v/models/networks/wan/model.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,90 @@ def infer(self, inputs):
197197
elif self.offload_granularity != "model":
198198
self.pre_weight.to_cpu()
199199
self.transformer_weights.non_block_weights_to_cpu()
200+
201+
@torch.no_grad()
202+
def infer_tensor_once(self, latents, timestep, context, context_null=None):
203+
"""
204+
Run one WAN forward pass from explicit tensors.
205+
206+
Args:
207+
latents: noisy latents, shape [C,F,H,W] or [1,F,C,H,W].
208+
timestep: timestep tensor (scalar / [1] / [1,F]); first value is used.
209+
context: conditional text embeddings, shape [L,D] or [1,L,D].
210+
context_null: optional unconditional text embeddings, same shape as context.
211+
Returns:
212+
noise prediction tensor with shape [C,F,H,W].
213+
"""
214+
if self.cpu_offload:
215+
if self.offload_granularity == "model" and "wan2.2_moe" not in self.config["model_cls"]:
216+
self.to_cuda()
217+
elif self.offload_granularity != "model":
218+
self.pre_weight.to_cuda()
219+
self.transformer_weights.non_block_weights_to_cuda()
220+
221+
if latents.ndim == 5:
222+
# [B,F,C,H,W] -> [C,F,H,W], only batch size 1 is supported.
223+
if latents.shape[0] != 1:
224+
raise ValueError(f"Expected batch size 1 for 5D latents, got shape {tuple(latents.shape)}")
225+
latents = latents.squeeze(0).permute(1, 0, 2, 3).contiguous()
226+
elif latents.ndim != 4:
227+
raise ValueError(f"Expected latents ndim in [4,5], got {latents.ndim}")
228+
229+
if context.ndim == 2:
230+
context = context.unsqueeze(0)
231+
if context.ndim != 3:
232+
raise ValueError(f"Expected context ndim in [2,3], got {context.ndim}")
233+
234+
if context_null is None:
235+
context_null = context
236+
elif context_null.ndim == 2:
237+
context_null = context_null.unsqueeze(0)
238+
239+
timestep = timestep.flatten()
240+
if timestep.numel() == 0:
241+
raise ValueError("Empty timestep tensor")
242+
timestep = timestep[:1].to(torch.int64).contiguous()
243+
244+
self.scheduler.prepare(seed=0, latent_shape=[1, 1, 1, 1], image_encoder_output={})
245+
self.scheduler.latents = latents.to(AI_DEVICE)
246+
self.scheduler.timestep_input = timestep.to(AI_DEVICE)
247+
248+
inputs = {
249+
"text_encoder_output": {
250+
"context": context.to(AI_DEVICE),
251+
"context_null": context_null.to(AI_DEVICE),
252+
},
253+
"image_encoder_output": {},
254+
}
255+
256+
def _convert_flow_pred_to_x0(flow_pred, xt, timestep_tensor):
257+
original_dtype = flow_pred.dtype
258+
flow_pred, xt, sigmas, timesteps = map(
259+
lambda x: x.double().to(flow_pred.device),
260+
[flow_pred, xt, self.scheduler.sigmas, self.scheduler.timesteps],
261+
)
262+
timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep_tensor.unsqueeze(1)).abs(), dim=1)
263+
sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1)
264+
x0_pred = xt - sigma_t * flow_pred
265+
return x0_pred.to(original_dtype)
266+
267+
timestep_for_x0 = timestep.flatten()[:1]
268+
if self.config.get("enable_cfg", False):
269+
noise_pred_cond = self._infer_cond_uncond(inputs, infer_condition=True)
270+
noise_pred_uncond = self._infer_cond_uncond(inputs, infer_condition=False)
271+
pred_x0_cond = _convert_flow_pred_to_x0(noise_pred_cond, self.scheduler.latents, timestep_for_x0)
272+
pred_x0_uncond = _convert_flow_pred_to_x0(noise_pred_uncond, self.scheduler.latents, timestep_for_x0)
273+
noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
274+
pred_x0 = pred_x0_uncond + self.scheduler.sample_guide_scale * (pred_x0_cond - pred_x0_uncond)
275+
else:
276+
noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)
277+
pred_x0 = _convert_flow_pred_to_x0(noise_pred, self.scheduler.latents, timestep_for_x0)
278+
279+
if self.cpu_offload:
280+
if self.offload_granularity == "model" and "wan2.2_moe" not in self.config["model_cls"]:
281+
self.to_cpu()
282+
elif self.offload_granularity != "model":
283+
self.pre_weight.to_cpu()
284+
self.transformer_weights.non_block_weights_to_cpu()
285+
286+
return noise_pred, pred_x0

lightx2v/models/schedulers/flux2_klein/scheduler.py

100644100755
Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,19 @@
44

55
import numpy as np
66
import torch
7-
from diffusers.pipelines.flux2.pipeline_flux2 import compute_empirical_mu
8-
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
9-
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
107

8+
# Optional diffusers imports: each guarded separately so the module stays
# importable when diffusers (or this diffusers version) is unavailable.
try:
    from diffusers.pipelines.flux2.pipeline_flux2 import compute_empirical_mu
except ImportError:
    compute_empirical_mu = None
try:
    from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
except ImportError:
    retrieve_timesteps = None
try:
    # FIX: FlowMatchEulerDiscreteScheduler lives in diffusers.schedulers, not in
    # the flux2 pipeline module; the previous path raised ImportError on a
    # healthy install and silently left the name as None.
    from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
except ImportError:
    FlowMatchEulerDiscreteScheduler = None
1120
from lightx2v.models.schedulers.scheduler import BaseScheduler
1221
from lightx2v.utils.envs import GET_DTYPE
1322
from lightx2v_platform.base.global_var import AI_DEVICE

lightx2v/models/video_encoders/hf/flux2_klein/vae.py

100644100755
Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
11
import os
22

33
import torch
4-
from diffusers.models import AutoencoderKLFlux2
5-
from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
4+
5+
try:
6+
from diffusers.models import AutoencoderKLFlux2
7+
except ImportError:
8+
AutoencoderKLFlux2 = None
9+
try:
10+
from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
11+
except ImportError:
12+
Flux2ImageProcessor = None
613

714
from lightx2v.utils.envs import GET_DTYPE
815
from lightx2v_platform.base.global_var import AI_DEVICE

lightx2v/server/api/tasks/video.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
66
from loguru import logger
77

8-
from ...schema import TaskResponse, VideoTaskRequest
8+
from ...schema import TaskResponse, VideoTaskRequest, WanTensorInferRequest, WanTensorInferResponse
99
from ...task_manager import task_manager
1010
from ..deps import get_services, validate_url_async
1111

@@ -107,3 +107,32 @@ async def save_file_async(file: UploadFile, target_dir: Path) -> str:
107107
except Exception as e:
108108
logger.error(f"Failed to create video form task: {e}")
109109
raise HTTPException(status_code=500, detail=str(e))
110+
111+
112+
@router.post("/tensor_infer", response_model=WanTensorInferResponse)
async def tensor_infer_wan(message: WanTensorInferRequest):
    """Forward a single-step tensor-inference request to the inference service."""
    services = get_services()
    assert services.inference_service is not None, "Inference service is not initialized"

    try:
        # Relay the request fields verbatim to the worker payload.
        field_names = (
            "task_id",
            "noisy_tensor",
            "context_tensor",
            "timestep_tensor",
            "context_null_tensor",
            "return_pred_x0",
        )
        payload = {name: getattr(message, name) for name in field_names}

        result = await services.inference_service.submit_tensor_infer_async(payload)
        if result is None:
            raise HTTPException(status_code=500, detail="Tensor infer request failed")
        if result.get("status") != "success":
            raise HTTPException(status_code=500, detail=result.get("error", "Tensor infer failed"))
        return WanTensorInferResponse(**result)
    except HTTPException:
        # Re-raise framework errors untouched so status codes survive.
        raise
    except Exception as e:
        logger.error(f"Failed to process tensor infer request: {e}")
        raise HTTPException(status_code=500, detail=str(e))

lightx2v/server/schema.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,23 @@ class TaskResponse(BaseModel):
7272
save_result_path: str
7373

7474

75+
class WanTensorInferRequest(BaseTaskRequest):
    """Request payload for the single-step WAN tensor-inference endpoint.

    Tensors are serialized with ``torch.save`` and base64-encoded by the
    client; the worker decodes them with ``torch.load``.
    """

    noisy_tensor: str = Field(..., description="Base64-encoded torch tensor, shape [1,F,C,H,W] or [C,F,H,W]")
    context_tensor: str = Field(..., description="Base64-encoded torch tensor, shape [1,L,D] or [L,D]")
    timestep_tensor: str = Field(..., description="Base64-encoded torch tensor, scalar or [1] / [1,F]")
    # Empty string means "no unconditional context"; the worker treats any
    # falsy value as absent.
    context_null_tensor: str = Field("", description="Optional base64 tensor for unconditional context")
    return_pred_x0: bool = Field(False, description="Whether to also return pred_x0")


class WanTensorInferResponse(BaseModel):
    """Response payload mirroring the worker's tensor-inference result dict."""

    task_id: str
    status: str
    noise_pred_tensor: str = Field("", description="Base64-encoded torch tensor")
    pred_x0_tensor: str = Field("", description="Base64-encoded torch tensor")
    message: str = Field("", description="Execution message")
    error: str = Field("", description="Error message when status=failed")
90+
91+
7592
class StopTaskResponse(BaseModel):
7693
stop_status: str
7794
reason: str

lightx2v/server/services/inference/service.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,32 @@ async def submit_task_async(self, task_data: dict) -> Optional[dict]:
7272
"message": f"Task processing failed: {str(e)}",
7373
}
7474

75+
async def submit_tensor_infer_async(self, task_data: dict) -> Optional[dict]:
    """Forward a tensor-inference payload to the rank-0 worker.

    Returns the worker's result dict, a failure dict on exception, or
    ``None`` when the service is not running or this process is not rank 0.
    """
    if not self.is_running or not self.worker:
        logger.error("Inference service is not started")
        return None

    if self.worker.rank != 0:
        # Only rank 0 serves HTTP tensor-infer requests.
        return None

    try:
        if self.worker.processing:
            # NOTE(review): this only logs — it does not actually wait, so a
            # tensor-infer request can overlap a task already in flight.
            logger.info("Waiting for previous task to complete before tensor infer request")

        self.worker.processing = True
        try:
            return await self.worker.process_tensor_request(task_data)
        finally:
            # FIX: clear the busy flag even on BaseException (e.g. asyncio
            # CancelledError); the previous version left it stuck at True.
            self.worker.processing = False
    except Exception as e:
        logger.error(f"Failed to process tensor infer request: {str(e)}")
        return {
            "task_id": task_data.get("task_id", "unknown"),
            "status": "failed",
            "error": str(e),
            "message": f"Tensor infer processing failed: {str(e)}",
        }
100+
75101
def server_metadata(self):
76102
assert hasattr(self, "args"), "Distributed inference service has not been started. Call start_distributed_inference() first."
77103
return {"nproc_per_node": self.worker.world_size, "model_cls": self.args.model_cls, "model_path": self.args.model_path}

lightx2v/server/services/inference/worker.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import asyncio
2+
import base64
23
import os
4+
from io import BytesIO
35
from pathlib import Path
46
from typing import Any, Dict
57

@@ -123,6 +125,95 @@ async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
123125
else:
124126
return None
125127

128+
@staticmethod
def _decode_tensor_base64(tensor_b64: str, device: str | torch.device) -> torch.Tensor:
    """Decode a base64 string produced by ``torch.save`` into a tensor on *device*.

    SECURITY FIX: ``torch.load`` deserializes pickle data, so a crafted
    payload from a network client could execute arbitrary code.
    ``weights_only=True`` restricts deserialization to tensors and
    primitive containers.
    """
    tensor_bytes = base64.b64decode(tensor_b64)
    buffer = BytesIO(tensor_bytes)
    return torch.load(buffer, map_location=device, weights_only=True)
133+
134+
@staticmethod
def _encode_tensor_base64(tensor: torch.Tensor) -> str:
    """Serialize *tensor* (detached, moved to CPU) with ``torch.save`` and base64-encode it."""
    with BytesIO() as sink:
        torch.save(tensor.detach().cpu(), sink)
        raw = sink.getvalue()
    return base64.b64encode(raw).decode("utf-8")
139+
140+
@staticmethod
def _lookup_sigma_from_scheduler(scheduler, timestep_tensor: torch.Tensor, target_device: torch.device, target_dtype: torch.dtype) -> torch.Tensor:
    """Map each timestep to its scheduler sigma via nearest-timestep lookup.

    Mirrors the Self-Forcing wan_wrapper logic: for every queried timestep,
    find the closest entry in ``scheduler.timesteps`` and return the
    corresponding ``scheduler.sigmas`` value, cast to *target_dtype*.
    """
    # Work in float64 on the target device for a stable nearest-match.
    ts = scheduler.timesteps.to(target_device, dtype=torch.float64)
    sg = scheduler.sigmas.to(target_device, dtype=torch.float64)
    query = timestep_tensor.flatten().to(target_device, dtype=torch.float64)
    distances = (ts.unsqueeze(0) - query.unsqueeze(1)).abs()
    nearest = torch.argmin(distances, dim=1)
    return sg[nearest].to(target_dtype)
149+
150+
def _ensure_tensor_infer_scheduler_ready(self) -> None:
    """Prepare the runner's scheduler once so sigmas/timesteps are populated."""
    scheduler = self.runner.model.scheduler
    has_timesteps = getattr(scheduler, "timesteps", None) is not None
    has_sigmas = getattr(scheduler, "sigmas", None) is not None
    if has_timesteps and has_sigmas:
        return
    # Only the scheduler metadata is needed here, so a tiny latent shape suffices.
    scheduler.prepare(seed=0, latent_shape=[16, 1, 2, 2], image_encoder_output={})
160+
161+
async def process_tensor_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
    """Decode tensors from *task_data*, run one WAN forward pass, and return encoded results.

    Always returns a result dict with ``status`` of ``"success"`` or
    ``"failed"``; exceptions are captured into the failure dict.
    """
    task_id = task_data.get("task_id", "unknown")

    if self.world_size > 1:
        reason = "tensor infer endpoint currently supports WORLD_SIZE=1 only"
        return {"task_id": task_id, "status": "failed", "error": reason, "message": reason}

    try:
        if not hasattr(self.runner, "model"):
            raise RuntimeError("Runner model is not initialized")
        if not hasattr(self.runner.model, "infer_tensor_once"):
            raise RuntimeError(f"Current model class does not support tensor infer: {type(self.runner.model).__name__}")

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self._ensure_tensor_infer_scheduler_ready()

        decode = self._decode_tensor_base64
        noisy = decode(task_data["noisy_tensor"], device=device)
        context = decode(task_data["context_tensor"], device=device)
        timestep = decode(task_data["timestep_tensor"], device=device)
        # A falsy (empty/missing) field means no unconditional context.
        context_null = None
        if task_data.get("context_null_tensor"):
            context_null = decode(task_data["context_null_tensor"], device=device)

        # NOTE(review): infer_tensor_once is synchronous GPU work and will
        # block the event loop for its duration.
        noise_pred, pred_x0 = self.runner.model.infer_tensor_once(
            latents=noisy,
            timestep=timestep,
            context=context,
            context_null=context_null,
        )
        if not bool(task_data.get("return_pred_x0", False)):
            pred_x0 = None

        return {
            "task_id": task_id,
            "status": "success",
            "noise_pred_tensor": self._encode_tensor_base64(noise_pred),
            "pred_x0_tensor": self._encode_tensor_base64(pred_x0) if pred_x0 is not None else "",
            "message": "Tensor infer completed",
            "error": "",
        }
    except Exception as e:
        logger.exception(f"Rank {self.rank} tensor inference failed: {e}")
        return {
            "task_id": task_id,
            "status": "failed",
            "noise_pred_tensor": "",
            "pred_x0_tensor": "",
            "message": f"Tensor infer failed: {e}",
            "error": str(e),
        }
216+
126217
def switch_lora(self, lora_name: str, lora_strength: float):
127218
try:
128219
if lora_name is None:

0 commit comments

Comments
 (0)