Skip to content

Commit a10a7bf

Browse files
author
yihuiwen
committed
server: save generated image to memory
1 parent 09f01d3 commit a10a7bf

7 files changed

Lines changed: 200 additions & 20 deletions

File tree

lightx2v/server/api/server.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,11 @@ async def _process_single_task(self, task_info: Any):
129129
result = await generation_service.generate_with_stop_event(message, task_info.stop_event)
130130

131131
if result:
132-
task_manager.complete_task(task_id, result.save_result_path)
132+
task_manager.complete_task(
133+
task_id,
134+
save_result_path=result.save_result_path or None,
135+
result_png=getattr(result, "result_png", None),
136+
)
133137
logger.info(f"Task {task_id} completed successfully")
134138
else:
135139
if task_info.stop_event.is_set():

lightx2v/server/api/tasks/image.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
from pathlib import Path
55

66
from fastapi import APIRouter, File, Form, HTTPException, Request, UploadFile
7+
from fastapi.responses import Response
78
from loguru import logger
89

910
from ...schema import ImageTaskRequest, TaskResponse
1011
from ...task_manager import TaskStatus, task_manager
1112
from ..deps import get_services, validate_url_async
12-
from .common import _stream_file_response
1313

1414
router = APIRouter()
1515

@@ -20,9 +20,6 @@ def _write_file_sync(file_path: Path, content: bytes) -> None:
2020

2121

2222
async def _wait_task_and_stream_result(task_id: str, timeout_seconds: int, poll_interval_seconds: float):
23-
services = get_services()
24-
assert services.file_service is not None, "File service is not initialized"
25-
2623
start_time = time.monotonic()
2724
while True:
2825
task_status = task_manager.get_task_status(task_id)
@@ -31,14 +28,14 @@ async def _wait_task_and_stream_result(task_id: str, timeout_seconds: int, poll_
3128

3229
status = task_status.get("status")
3330
if status == TaskStatus.COMPLETED.value:
34-
save_result_path = task_status.get("save_result_path")
35-
if not save_result_path:
36-
raise HTTPException(status_code=500, detail=f"Task completed but no result path found: {task_id}")
37-
38-
full_path = Path(save_result_path)
39-
if not full_path.is_absolute():
40-
full_path = services.file_service.output_video_dir / save_result_path
41-
return _stream_file_response(full_path)
31+
result_png = task_manager.get_task_result_png(task_id)
32+
if result_png:
33+
return Response(
34+
content=result_png,
35+
media_type="image/png",
36+
headers={"Content-Disposition": 'inline; filename="result.png"'},
37+
)
38+
raise HTTPException(status_code=500, detail=f"Task completed but no in-memory image found: {task_id}")
4239

4340
if status == TaskStatus.FAILED.value:
4441
raise HTTPException(status_code=500, detail=task_status.get("error", "Task failed"))
@@ -72,6 +69,7 @@ async def create_image_task(message: ImageTaskRequest):
7269
if not await validate_url_async(message.image_mask_path):
7370
raise HTTPException(status_code=400, detail=f"Image mask URL is not accessible: {message.image_mask_path}")
7471

72+
message.prefer_memory_result = False
7573
task_id = task_manager.create_task(message)
7674
message.task_id = task_id
7775

@@ -108,6 +106,7 @@ async def create_image_task_sync(
108106
if not await validate_url_async(message.image_mask_path):
109107
raise HTTPException(status_code=400, detail=f"Image mask URL is not accessible: {message.image_mask_path}")
110108

109+
message.prefer_memory_result = True
111110
task_id = task_manager.create_task(message)
112111
message.task_id = task_id
113112

@@ -184,6 +183,7 @@ async def save_file_async(file: UploadFile, target_dir: Path) -> str:
184183
)
185184

186185
try:
186+
message.prefer_memory_result = False
187187
task_id = task_manager.create_task(message)
188188
message.task_id = task_id
189189

lightx2v/server/schema.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class BaseTaskRequest(DisaggOverrideRequest):
4040
target_shape: list[int] = Field([], description="Return video or image shape")
4141
lora_name: Optional[str] = Field(None, description="LoRA filename to load from lora_dir, None to disable LoRA")
4242
lora_strength: float = Field(1.0, description="LoRA strength")
43+
# Internal switch: sync API sets this True to return image from memory only.
44+
prefer_memory_result: bool = Field(default=False, exclude=True)
4345

4446
def __init__(self, **data):
4547
super().__init__(**data)
@@ -83,6 +85,8 @@ class TaskResponse(BaseModel):
8385
task_id: str
8486
task_status: str
8587
save_result_path: str
88+
# Filled after image generation in-process; never serialized in JSON responses.
89+
result_png: Optional[bytes] = Field(default=None, exclude=True)
8690

8791

8892
class StopTaskResponse(BaseModel):

lightx2v/server/services/generation/image.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[A
4343

4444
self._prepare_output_path(message.save_result_path, task_data)
4545
task_data["seed"] = message.seed
46+
prefer_memory_result = bool(getattr(message, "prefer_memory_result", False))
47+
task_data.pop("prefer_memory_result", None)
48+
task_data["return_result_tensor"] = prefer_memory_result
4649

4750
result = await self.inference_service.submit_task_async(task_data)
4851

@@ -56,6 +59,17 @@ async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[A
5659
actual_save_path = self.file_service.get_output_path(message.save_result_path)
5760
if not actual_save_path.suffix:
5861
actual_save_path = actual_save_path.with_suffix(self.get_output_extension())
62+
if prefer_memory_result:
63+
result_png = result.get("result_png")
64+
if not result_png:
65+
raise RuntimeError("Image inference did not return in-memory PNG bytes (result_png)")
66+
return TaskResponse(
67+
task_id=message.task_id,
68+
task_status="completed",
69+
save_result_path="",
70+
result_png=result_png,
71+
)
72+
5973
return TaskResponse(
6074
task_id=message.task_id,
6175
task_status="completed",
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
"""Normalize image runner outputs to PNG bytes (in-memory, no disk)."""
2+
3+
from __future__ import annotations
4+
5+
import base64
6+
import os
7+
import time
8+
from io import BytesIO
9+
from typing import Any, Optional
10+
11+
import torch
12+
from PIL import Image
13+
from loguru import logger
14+
15+
try:
16+
from torchvision.io import encode_png as tv_encode_png
17+
except Exception:
18+
tv_encode_png = None
19+
20+
21+
def _get_png_compression_level() -> int:
22+
raw = os.getenv("LIGHTX2V_SYNC_PNG_COMPRESSION", "6")
23+
try:
24+
level = int(raw)
25+
except ValueError:
26+
logger.warning(f"Invalid LIGHTX2V_SYNC_PNG_COMPRESSION={raw}, fallback to 6")
27+
return 6
28+
if level < 0 or level > 9:
29+
logger.warning(f"LIGHTX2V_SYNC_PNG_COMPRESSION={level} out of range [0,9], clamped")
30+
level = max(0, min(9, level))
31+
return level
32+
33+
34+
PNG_COMPRESSION_LEVEL = _get_png_compression_level()
35+
36+
37+
def _pil_to_png_bytes(pil_image: Image.Image) -> bytes:
    """Encode a PIL image as PNG bytes, converting non-RGB(A) modes to RGB first."""
    img = pil_image if pil_image.mode in ("RGB", "RGBA") else pil_image.convert("RGB")
    out = BytesIO()
    img.save(out, format="PNG", compress_level=PNG_COMPRESSION_LEVEL)
    return out.getvalue()
44+
45+
46+
def _pil_images_structure_to_png(images: Any) -> bytes:
47+
first = images[0]
48+
if isinstance(first, list):
49+
pil_image = first[0]
50+
else:
51+
pil_image = first
52+
if not hasattr(pil_image, "save"):
53+
raise TypeError(f"Unexpected image element type: {type(pil_image)}")
54+
return _pil_to_png_bytes(pil_image)
55+
56+
57+
def _tensor_to_png_bytes(image_tensor: torch.Tensor) -> bytes:
    """Encode an image tensor to PNG bytes, logging per-stage timings.

    Accepts a 3-D image tensor or a 4-D batch (only the first item is
    encoded). Channel layout may be CHW or HWC with 1/3/4 channels;
    floating-point values are scaled/clamped to uint8. Uses torchvision's
    native PNG encoder when importable, otherwise falls back to PIL.

    Raises:
        TypeError: for unsupported ranks or channel layouts.
    """
    total_start = time.perf_counter()
    # Tag every log line with the input's metadata so timings can be correlated.
    task_tag = f"shape={tuple(image_tensor.shape)},dtype={image_tensor.dtype},device={image_tensor.device}"

    cpu_start = time.perf_counter()
    # Detach from autograd and copy to host memory before any encoding work.
    tensor = image_tensor.detach().cpu()
    cpu_ms = (time.perf_counter() - cpu_start) * 1000

    # A 4-D input is treated as a batch; only the first image is kept.
    if tensor.ndim == 4:
        tensor = tensor[0]
    if tensor.ndim != 3:
        raise TypeError(f"Unsupported tensor shape: {tuple(tensor.shape)}")

    prep_start = time.perf_counter()
    # Normalize layout once: keep CHW for fast PNG encoding path.
    # NOTE(review): an HWC image whose height happens to be 1/3/4 is
    # misread as CHW by this heuristic — confirm upstream layouts.
    if tensor.shape[0] in (1, 3, 4):
        tensor_chw = tensor
    elif tensor.shape[-1] in (1, 3, 4):
        tensor_chw = tensor.permute(2, 0, 1)
    else:
        raise TypeError(f"Unsupported tensor channel layout: {tuple(tensor.shape)}")

    if tensor_chw.dtype.is_floating_point:
        # Most runners output floats in [0, 1].
        # NOTE(review): values in [-1, 1] with max <= 1.0 take the first
        # branch and have negatives clamped to 0 — confirm runner output range.
        if float(tensor_chw.max()) <= 1.0:
            tensor_chw = (tensor_chw.clamp(0.0, 1.0) * 255.0).round()
        else:
            tensor_chw = tensor_chw.clamp(0.0, 255.0).round()

    tensor_chw = tensor_chw.to(torch.uint8)
    prep_ms = (time.perf_counter() - prep_start) * 1000

    # Fast path: encode PNG directly from CHW uint8 tensor.
    if tv_encode_png is not None:
        encode_start = time.perf_counter()
        png_bytes = tv_encode_png(tensor_chw, compression_level=PNG_COMPRESSION_LEVEL).numpy().tobytes()
        encode_ms = (time.perf_counter() - encode_start) * 1000
        total_ms = (time.perf_counter() - total_start) * 1000
        logger.info(f"Tensor->PNG(tv) cost total={total_ms:.2f}ms cpu_copy={cpu_ms:.2f}ms preprocess={prep_ms:.2f}ms encode={encode_ms:.2f}ms level={PNG_COMPRESSION_LEVEL} [{task_tag}]")
        return png_bytes

    # Fallback path: convert to HWC numpy (squeezing a single channel) and use PIL.
    encode_start = time.perf_counter()
    arr = tensor_chw.permute(1, 2, 0).numpy()
    if arr.shape[-1] == 1:
        arr = arr[:, :, 0]
    png_bytes = _pil_to_png_bytes(Image.fromarray(arr))
    encode_ms = (time.perf_counter() - encode_start) * 1000
    total_ms = (time.perf_counter() - total_start) * 1000
    logger.info(f"Tensor->PNG(pil) cost total={total_ms:.2f}ms cpu_copy={cpu_ms:.2f}ms preprocess={prep_ms:.2f}ms encode={encode_ms:.2f}ms level={PNG_COMPRESSION_LEVEL} [{task_tag}]")
    return png_bytes
107+
108+
109+
def encode_pipeline_return_to_png_bytes(pipeline_return: Any) -> Optional[bytes]:
    """Convert run_pipeline return value to a single PNG byte string, or None if not applicable."""
    if pipeline_return is None:
        return None
    payload = pipeline_return
    try:
        # e.g. BagelRunner returns (images, audio_or_none)
        if isinstance(payload, tuple) and payload:
            payload = payload[0]
        if isinstance(payload, dict):
            images = payload.get("images")
            if images is None:
                return None
            if isinstance(images, torch.Tensor):
                return _tensor_to_png_bytes(images)
            return _pil_images_structure_to_png(images)
        if isinstance(payload, list) and payload:
            head = payload[0]
            if isinstance(head, torch.Tensor):
                return _tensor_to_png_bytes(head)
            return _pil_images_structure_to_png(payload)
        if isinstance(payload, torch.Tensor):
            return _tensor_to_png_bytes(payload)
        if isinstance(payload, Image.Image):
            return _pil_to_png_bytes(payload)
        if isinstance(payload, str):
            # Treat a string payload as a base64-encoded image.
            decoded = base64.b64decode(payload)
            return _pil_to_png_bytes(Image.open(BytesIO(decoded)).convert("RGB"))
    except Exception as e:
        logger.exception(f"Failed to encode pipeline output to PNG: {e}")
        return None
    return None

lightx2v/server/services/inference/worker.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import os
3+
import time
34
from pathlib import Path
45
from typing import Any, Dict
56

@@ -11,6 +12,7 @@
1112
from lightx2v.utils.set_config import set_config, set_parallel_config
1213

1314
from ..distributed_utils import DistributedManager
15+
from .pipeline_image_encode import encode_pipeline_return_to_png_bytes
1416

1517

1618
class TorchrunInferenceWorker:
@@ -66,6 +68,7 @@ def init(self, args) -> bool:
6668
async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
6769
has_error = False
6870
error_msg = ""
71+
pipeline_return = None
6972

7073
try:
7174
if self.world_size > 1 and self.rank == 0:
@@ -79,7 +82,7 @@ async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
7982
self.switch_lora(lora_name, lora_strength)
8083

8184
task_data["task"] = self.runner.config["task"]
82-
task_data["return_result_tensor"] = False
85+
task_data["return_result_tensor"] = bool(task_data.get("return_result_tensor", False))
8386
task_data["negative_prompt"] = task_data.get("negative_prompt", "")
8487

8588
target_fps = task_data.pop("target_fps", None)
@@ -93,7 +96,7 @@ async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
9396
update_input_info_from_dict(self.input_info, task_data)
9497

9598
self.runner.set_config(task_data)
96-
self.runner.run_pipeline(self.input_info)
99+
pipeline_return = self.runner.run_pipeline(self.input_info)
97100

98101
await asyncio.sleep(0)
99102

@@ -114,12 +117,20 @@ async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
114117
"message": f"Inference failed: {error_msg}",
115118
}
116119
else:
117-
return {
120+
out: Dict[str, Any] = {
118121
"task_id": task_data["task_id"],
119122
"status": "success",
120-
"save_result_path": task_data["save_result_path"],
123+
"save_result_path": task_data.get("save_result_path"),
121124
"message": "Inference completed",
122125
}
126+
if task_data.get("return_result_tensor"):
127+
encode_start = time.perf_counter()
128+
png = encode_pipeline_return_to_png_bytes(pipeline_return)
129+
encode_elapsed_ms = (time.perf_counter() - encode_start) * 1000
130+
logger.info(f"Task {task_data.get('task_id')} encode result_png cost {encode_elapsed_ms:.2f} ms")
131+
if png:
132+
out["result_png"] = png
133+
return out
123134
else:
124135
return None
125136

lightx2v/server/task_manager.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class TaskInfo:
2828
end_time: Optional[datetime] = None
2929
error: Optional[str] = None
3030
save_result_path: Optional[str] = None
31+
result_png: Optional[bytes] = None
3132
stop_event: threading.Event = field(default_factory=threading.Event)
3233
thread: Optional[threading.Thread] = None
3334

@@ -81,7 +82,7 @@ def start_task(self, task_id: str) -> TaskInfo:
8182

8283
return task
8384

84-
def complete_task(self, task_id: str, save_result_path: Optional[str] = None):
85+
def complete_task(self, task_id: str, save_result_path: Optional[str] = None, result_png: Optional[bytes] = None):
8586
with self._lock:
8687
if task_id not in self._tasks:
8788
logger.warning(f"Task {task_id} not found for completion")
@@ -90,8 +91,8 @@ def complete_task(self, task_id: str, save_result_path: Optional[str] = None):
9091
task = self._tasks[task_id]
9192
task.status = TaskStatus.COMPLETED
9293
task.end_time = datetime.now()
93-
if save_result_path:
94-
task.save_result_path = save_result_path
94+
task.save_result_path = save_result_path
95+
task.result_png = result_png
9596

9697
self.completed_tasks += 1
9798
self._emit_queue_metrics_unlocked()
@@ -141,6 +142,13 @@ def get_task(self, task_id: str) -> Optional[TaskInfo]:
141142
with self._lock:
142143
return self._tasks.get(task_id)
143144

145+
def get_task_result_png(self, task_id: str) -> Optional[bytes]:
    """Return the in-memory PNG result for *task_id*, or None if the task is unknown or has no image."""
    with self._lock:
        task = self._tasks.get(task_id)
        return task.result_png if task else None
151+
144152
def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
145153
task = self.get_task(task_id)
146154
if not task:

0 commit comments

Comments
 (0)