-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhost_balanced.py
More file actions
executable file
·124 lines (98 loc) · 3.71 KB
/
host_balanced.py
File metadata and controls
executable file
·124 lines (98 loc) · 3.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import argparse
import logging
import os
import signal
import torch
from flask import Flask, request, jsonify
from modules.host_common import *
from modules.scheduler_config import *
from modules.utils import *
# Flask application exposing this host's HTTP control endpoints.
app = Flask(__name__)
# Shared host state; assigned a CommonHost by __initialize_environment().
base = None
def __initialize_environment():
    """Create the shared CommonHost, wire up logging, and mark it ready."""
    global base
    base = CommonHost()
    base.local_rank = 0  # single-process host: always rank 0
    base.set_logger()
    base.initialized = True
# Parsed command-line arguments; populated by __run_host().
args = None
def __run_host():
    """Parse CLI options and start the blocking Flask server on args.port."""
    global args, base
    parser = argparse.ArgumentParser()
    # Register the generic host options shared by all host variants.
    for option_name, option_type in GENERIC_HOST_ARGS.items():
        parser.add_argument(f"--{option_name}", type=option_type, default=None)
    for toggle_name in GENERIC_HOST_ARGS_TOGGLES:
        parser.add_argument(f"--{toggle_name}", action="store_true")
    args = parser.parse_args()
    # Silence torch's internal logging (NOTE: torch._logging is a private API).
    torch._logging.set_logs(all=logging.CRITICAL)
    base.log("ℹ️ Starting Flask host")
    # Suppress werkzeug's per-request console output.
    logging.getLogger('werkzeug').disabled = True
    app.run(host="localhost", port=args.port)
@app.route("/<path>", methods=["GET", "POST"])
def handle_path(path):
    """Dispatch a single-segment URL path to the matching host operation."""
    global base
    # status endpoints
    if path == "initialize":
        return base.get_initialized_flask()
    if path == "applied":
        return base.get_applied()
    if path == "progress":
        return base.get_progress_flask()
    # generation endpoints
    if path == "apply":
        return __apply_pipeline_parallel(request.json)
    if path == "generate":
        return __generate_image_parallel(request.json)
    if path in ("offload", "sleep"):
        return "Operation not supported by this host", 500
    if path == "close":
        base.log("🛑 Received exit signal - shutting down")
        base.close_pipeline()
        # SIGTERM stops the server process; the raise also aborts this request.
        os.kill(os.getpid(), signal.SIGTERM)
        raise HostShutdown
    # unknown path
    return "", 404
def __apply_pipeline_parallel(data):
    """Set up the pipeline from request *data* using the 'balanced' backend."""
    global base
    # No gradients are needed for pipeline construction.
    with torch.no_grad():
        result = base.setup_pipeline(data, backend_name="balanced")
    return result
def __generate_image_parallel(data):
    """Run one inference pass on the shared pipeline and return the result.

    *data* is the decoded JSON request body (presumably containing prompt,
    height, width, etc. — schema defined by base.prepare_inputs; confirm
    against caller).  Returns a JSON-serializable dict with a base64-pickled
    payload, or an error message when the pipeline produced no output.
    """
    global base
    data = base.prepare_inputs(data)
    with torch.no_grad():
        torch.cuda.reset_peak_memory_stats()
        base.progress = 0
        # inference kwargs
        kwargs = base.setup_inference(data, can_use_compel=False)
        # inference
        with torch.inference_mode():
            output = base.pipe(**kwargs)
        # clean up
        clean()
        base.progress = 100
        # output
        if output is not None:
            if get_is_image_model(base.pipeline_type):
                if base.pipeline_type in ["sdup"]:
                    # NOTE(review): sdup takes the first image and falls through
                    # to the is_image=False return below — confirm intended.
                    output = output.images[0]
                else:
                    output_images = output.images
                    # flux packs its latents; unpack before VAE decoding.
                    if base.pipeline_type in ["flux"]: output_images = base.pipe._unpack_latents(output_images, data["height"], data["width"], base.pipe.vae_scale_factor)
                    # If the VAE was offloaded to CPU, move it next to the
                    # latents for decoding, then restore it afterwards.
                    flag = base.pipe.vae.device == torch.device("cpu")
                    if flag: base.pipe.vae = base.pipe.vae.to(device=output_images.device)
                    images = base.convert_latent_to_image(output_images)
                    latents = base.convert_latent_to_output_latent(output_images)
                    if flag: base.pipe.vae = base.pipe.vae.to(device="cpu")
                    return { "message": "OK", "output": pickle_and_encode_b64(images[0]), "latent": pickle_and_encode_b64(latents), "is_image": True }
            else:
                # video-style pipelines: take the first batch of frames.
                output = output.frames[0]
            return { "message": "OK", "output": pickle_and_encode_b64(output), "is_image": False }
        else:
            return { "message": "No image from pipeline", "output": None, "is_image": False }
if __name__ == "__main__":
    # Build the shared host state, then block serving HTTP requests.
    __initialize_environment()
    __run_host()