Skip to content

Commit 3e9270e

Browse files
committed
Allow csr matric for transform
1 parent 79ba55a commit 3e9270e

1 file changed

Lines changed: 10 additions & 3 deletions

File tree

simpeg/directives/_save_geoh5.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from pathlib import Path
55

66
import numpy as np
7-
7+
from scipy.sparse import csc_matrix, csr_matrix
88
from .directives import InversionDirective
99
from simpeg.maps import IdentityMap
1010

@@ -188,7 +188,12 @@ def transforms(self, funcs: list | tuple):
188188

189189
for fun in funcs:
190190
if not any(
191-
[isinstance(fun, (IdentityMap, np.ndarray, float)), callable(fun)]
191+
[
192+
isinstance(
193+
fun, (IdentityMap, np.ndarray, csr_matrix, csc_matrix, float)
194+
),
195+
callable(fun),
196+
]
192197
):
193198
raise TypeError(
194199
"Input transformation must be of type"
@@ -212,7 +217,9 @@ def apply_transformations(self, prop: np.ndarray) -> np.ndarray:
212217
"""
213218
prop = prop.flatten()
214219
for fun in self.transforms:
215-
if isinstance(fun, (IdentityMap, np.ndarray, float)):
220+
if isinstance(
221+
fun, (IdentityMap, np.ndarray, csr_matrix, csc_matrix, float)
222+
):
216223
prop = fun * prop
217224
else:
218225
prop = fun(prop)

0 commit comments

Comments
 (0)