Skip to content

Commit 0dbba9f

Browse files
Add some latent operation nodes.
This is a port of the ModelSamplerTonemapNoiseTest from the experiments repo. To replicate that node use LatentOperationTonemapReinhard and LatentApplyOperationCFG together.
1 parent f584758 commit 0dbba9f

1 file changed

Lines changed: 79 additions & 0 deletions

File tree

comfy_extras/nodes_latent.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,90 @@ def op(self, samples, seed_behavior):
145145

146146
return (samples_out,)
147147

148+
class LatentApplyOperation:
149+
@classmethod
150+
def INPUT_TYPES(s):
151+
return {"required": { "samples": ("LATENT",),
152+
"operation": ("LATENT_OPERATION",),
153+
}}
154+
155+
RETURN_TYPES = ("LATENT",)
156+
FUNCTION = "op"
157+
158+
CATEGORY = "latent/advanced/operations"
159+
EXPERIMENTAL = True
160+
161+
def op(self, samples, operation):
162+
samples_out = samples.copy()
163+
164+
s1 = samples["samples"]
165+
samples_out["samples"] = operation(latent=s1)
166+
return (samples_out,)
167+
168+
class LatentApplyOperationCFG:
169+
@classmethod
170+
def INPUT_TYPES(s):
171+
return {"required": { "model": ("MODEL",),
172+
"operation": ("LATENT_OPERATION",),
173+
}}
174+
RETURN_TYPES = ("MODEL",)
175+
FUNCTION = "patch"
176+
177+
CATEGORY = "latent/advanced/operations"
178+
EXPERIMENTAL = True
179+
180+
def patch(self, model, operation):
181+
m = model.clone()
182+
183+
def pre_cfg_function(args):
184+
conds_out = args["conds_out"]
185+
if len(conds_out) == 2:
186+
conds_out[0] = operation(latent=(conds_out[0] - conds_out[1])) + conds_out[1]
187+
else:
188+
conds_out[0] = operation(latent=conds_out[0])
189+
return conds_out
190+
191+
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
192+
return (m, )
193+
194+
class LatentOperationTonemapReinhard:
195+
@classmethod
196+
def INPUT_TYPES(s):
197+
return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
198+
}}
199+
200+
RETURN_TYPES = ("LATENT_OPERATION",)
201+
FUNCTION = "op"
202+
203+
CATEGORY = "latent/advanced/operations"
204+
EXPERIMENTAL = True
205+
206+
def op(self, multiplier):
207+
def tonemap_reinhard(latent, **kwargs):
208+
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
209+
normalized_latent = latent / latent_vector_magnitude
210+
211+
mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
212+
std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
213+
214+
top = (std * 5 + mean) * multiplier
215+
216+
#reinhard
217+
latent_vector_magnitude *= (1.0 / top)
218+
new_magnitude = latent_vector_magnitude / (latent_vector_magnitude + 1.0)
219+
new_magnitude *= top
220+
221+
return normalized_latent * new_magnitude
222+
return (tonemap_reinhard,)
223+
148224
NODE_CLASS_MAPPINGS = {
149225
"LatentAdd": LatentAdd,
150226
"LatentSubtract": LatentSubtract,
151227
"LatentMultiply": LatentMultiply,
152228
"LatentInterpolate": LatentInterpolate,
153229
"LatentBatch": LatentBatch,
154230
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
231+
"LatentApplyOperation": LatentApplyOperation,
232+
"LatentApplyOperationCFG": LatentApplyOperationCFG,
233+
"LatentOperationTonemapReinhard": LatentOperationTonemapReinhard,
155234
}

0 commit comments

Comments
 (0)