-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathdecompress.py
More file actions
163 lines (138 loc) · 4.89 KB
/
decompress.py
File metadata and controls
163 lines (138 loc) · 4.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import argparse
from pathlib import Path
import yaml
from easydict import EasyDict as edict
import zlib
import struct
from lib import image_utils
from lib.diffc.denoise import denoise
from lib.diffc.decode import decode
from lib.diffc.rcc.gaussian_channel_simulator import GaussianChannelSimulator
def parse_args():
    """Parse the command-line options of the decompression script.

    Returns:
        The ``argparse.Namespace`` with ``config``, ``input_path``,
        ``input_dir`` and ``output_dir`` attributes. ``config`` and
        ``output_dir`` are required; the two input options default to
        ``None``.
    """
    cli = argparse.ArgumentParser(
        description="Decompress DiffC-compressed images"
    )

    # (flag, add_argument keyword options), registered in this order so the
    # generated --help output keeps the same layout.
    option_table = (
        ("--config", {
            "help": "Path to the compression config file",
            "required": True,
        }),
        ("--input_path", {
            "default": None,
            "help": "Path to a single .diffc file to decompress",
        }),
        ("--input_dir", {
            "default": None,
            "help": "Path to a directory containing .diffc files to decompress",
        }),
        ("--output_dir", {
            "required": True,
            "help": "Directory to output the decompressed images to",
        }),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()
def get_noise_prediction_model(model_name, config):
    """Instantiate the noise-prediction model selected by ``model_name``.

    Each model module is imported only inside the branch that needs it, so
    selecting one model does not import the others.

    Args:
        model_name: One of ``"SD1.5"``, ``"SD2.1"``, ``"SDXL"`` or ``"Flux"``.
        config: Mapping-like config; only its ``"use_refiner"`` entry is
            read here (default ``False``), and only for ``"SDXL"``.

    Returns:
        A freshly constructed model object for the requested name.

    Raises:
        ValueError: If ``model_name`` is not one of the recognized names.
    """
    if model_name == "SD1.5":
        from lib.models.SD15 import SD15Model
        return SD15Model()

    if model_name == "SD2.1":
        from lib.models.SD21 import SD21Model
        return SD21Model()

    if model_name == "SDXL":
        from lib.models.SDXL import SDXLModel
        return SDXLModel(use_refiner=config.get("use_refiner", False))

    if model_name == 'Flux':
        from lib.models.Flux import FluxModel
        return FluxModel()

    raise ValueError(f"Unrecognized model: {model_name}")
def read_diffc_file(file_path):
    """Read a .diffc file and split it into its header fields and payload.

    Layout as read by this function (all integers little-endian):
        4 bytes  unsigned caption length (size of the compressed caption)
        2 bytes  unsigned width
        2 bytes  unsigned height
        2 bytes  unsigned step index
        N bytes  zlib-compressed UTF-8 caption (N = caption length)
        rest     image payload

    Args:
        file_path: Path of the .diffc file to open in binary mode.

    Returns:
        A tuple ``(caption, width, height, step_idx, image_bytes)`` where
        ``caption`` is a ``str`` and ``image_bytes`` is a list of ints, one
        per remaining byte of the file.
    """
    def read_field(stream, fmt):
        # Read exactly as many bytes as the format needs and unpack one value.
        return struct.unpack(fmt, stream.read(struct.calcsize(fmt)))[0]

    with open(file_path, 'rb') as stream:
        caption_length = read_field(stream, '<I')
        width = read_field(stream, '<H')
        height = read_field(stream, '<H')
        step_idx = read_field(stream, '<H')

        caption_blob = stream.read(caption_length)
        caption = zlib.decompress(caption_blob).decode('utf-8')

        # Everything after the caption is the image payload.
        image_bytes = list(stream.read())

    return caption, width, height, step_idx, image_bytes
def decompress_file(input_path, output_path, noise_prediction_model,
                    gaussian_channel_simulator, config):
    """Decompress a single .diffc file and save the reconstructed image.

    Steps, in order: read the file, unpack the per-step chunk seeds,
    configure the model with the stored caption and size, decode a noisy
    reconstruction, denoise it, convert the latent to an image and save it.

    Args:
        input_path: Path of the .diffc file to read.
        output_path: Where the reconstructed image is saved.
        noise_prediction_model: Model object providing ``configure`` and
            ``latent_to_image``; also passed to ``decode`` and ``denoise``.
        gaussian_channel_simulator: Object providing
            ``decompress_chunk_seeds``; also passed to ``decode``.
        config: Config object; reads ``manual_dkl_per_step``,
            ``encoding_timesteps``, ``denoising_guidance_scale`` and
            ``denoising_timesteps``.
    """
    caption, width, height, step_idx, payload = read_diffc_file(input_path)

    # Seeds are unpacked using only the per-step entries up to and
    # including the step index stored in the file.
    seeds_per_step = gaussian_channel_simulator.decompress_chunk_seeds(
        payload,
        config.manual_dkl_per_step[:step_idx+1],
    )
    # Timestep at the stored step index; denoising starts from here.
    start_timestep = config.encoding_timesteps[step_idx]

    # The caption and image size stored in the file condition the model.
    noise_prediction_model.configure(
        caption, config.denoising_guidance_scale, width, height
    )

    # NOTE(review): decode receives the full timestep / per-step lists, not
    # the truncated ones - presumably it stops at len(seeds_per_step);
    # confirm against lib.diffc.decode.
    noisy_latent = decode(
        width,
        height,
        config.encoding_timesteps,
        noise_prediction_model,
        gaussian_channel_simulator,
        seeds_per_step,
        config.manual_dkl_per_step,
        seed=0,
    )

    clean_latent = denoise(
        noisy_latent,
        start_timestep,
        config.denoising_timesteps,
        noise_prediction_model,
    )

    # Latent -> image tensor -> PIL image -> file on disk.
    image_tensor = noise_prediction_model.latent_to_image(clean_latent)
    pil_image = image_utils.torch_to_pil_img(image_tensor)
    pil_image.save(output_path)
def main():
    """Entry point: decompress one .diffc file or every one in a directory.

    Parses the command line, loads the YAML config, builds the channel
    simulator and noise-prediction model once, then writes
    ``{stem}_decompressed.png`` into the output directory for each input.

    Raises:
        ValueError: If not exactly one of ``--input_path`` / ``--input_dir``
            is given, or if the config has no ``manual_dkl_per_step``.
    """
    args = parse_args()

    # Validate the input selection first so that an invalid invocation fails
    # before any file is read or any directory is created.
    if bool(args.input_path) == bool(args.input_dir):
        raise ValueError("Must specify exactly one of --input_path or --input_dir")

    # Load config
    with open(args.config, "r") as f:
        config = edict(yaml.safe_load(f))

    # Explicit check instead of `assert`: asserts are removed under
    # `python -O`, and .get() also reports a missing key with this message
    # rather than an AttributeError from the attribute lookup.
    if config.get("manual_dkl_per_step") is None:
        raise ValueError(
            "Config must specify a manual_dkl_per_step to perform decompression."
        )

    # Set up output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    # Get input paths
    if args.input_path:
        input_paths = [Path(args.input_path)]
    else:
        # glob yields entries in no guaranteed order; sort so that runs are
        # reproducible.
        input_paths = sorted(Path(args.input_dir).glob("*.diffc"))

    # Initialize models once and reuse them for every file.
    gaussian_channel_simulator = GaussianChannelSimulator(
        config.max_chunk_size,
        config.chunk_padding
    )
    noise_prediction_model = get_noise_prediction_model(config.model, config)

    # Process each file
    for input_path in input_paths:
        # Create output path: {original_name}_decompressed.png
        output_path = output_dir / f"{input_path.stem}_decompressed.png"
        decompress_file(
            input_path,
            output_path,
            noise_prediction_model,
            gaussian_channel_simulator,
            config
        )


if __name__ == "__main__":
    main()