|
8 | 8 | import numpy as np |
9 | 9 |
|
10 | 10 | from dask.distributed import Client, Future |
11 | | -from dask import array, delayed, compute |
12 | 11 | from ..data_misfit import L2DataMisfit |
13 | 12 |
|
14 | 13 | from simpeg.utils import validate_list_of_types |
@@ -88,16 +87,6 @@ def _deriv2(objfct, multiplier, _, v): |
88 | 87 | return multiplier * deriv2 |
89 | 88 |
|
90 | 89 |
|
91 | | -def _get_attr(objfct, key): |
92 | | - if isinstance(objfct, ComboObjectiveFunction): |
93 | | - attr = [] |
94 | | - for objfct_ in objfct.objfcts: |
95 | | - attr.append(_get_attr(objfct_, key)) |
96 | | - return attr |
97 | | - |
98 | | - return objfct.nP |
99 | | - |
100 | | - |
101 | 90 | def _store_model(objfct, model): |
102 | 91 |
|
103 | 92 | if isinstance(objfct, ComboObjectiveFunction): |
@@ -251,30 +240,6 @@ def __call__(self, m, f=None): |
251 | 240 | values = self.client.gather(values) |
252 | 241 | return np.sum(values) |
253 | 242 |
|
254 | | - @property |
255 | | - def nP(self): |
256 | | - """Number of model parameters. |
257 | | -
|
258 | | - Returns |
259 | | - ------- |
260 | | - int |
261 | | - Number of model parameters. |
262 | | - """ |
263 | | - if self._nP is None: |
264 | | - nP = [] |
265 | | - for futures in self._workloads: |
266 | | - for future, worker in zip(futures, self._workers, strict=True): |
267 | | - nP.append( |
268 | | - self.client.submit( |
269 | | - _get_attr, |
270 | | - future, |
271 | | - "nP", |
272 | | - workers=worker, |
273 | | - ) |
274 | | - ) |
275 | | - self._nP = np.sum(self.client.gather(nP)) |
276 | | - return self._nP |
277 | | - |
278 | 243 | @property |
279 | 244 | def client(self): |
280 | 245 | """ |
@@ -590,200 +555,3 @@ def residuals(self, m, f=None): |
590 | 555 | residuals += client.gather(future_residuals) |
591 | 556 |
|
592 | 557 | return residuals |
593 | | - |
594 | | - # |
595 | | - # def broadcast_updates(self, updates: dict): |
596 | | - # """ |
597 | | - # Set the attributes of the objective functions and simulations |
598 | | - # """ |
599 | | - # stores = [] |
600 | | - # client = self.client |
601 | | - # for fun, (key, value) in updates.items(): |
602 | | - # if fun not in self._lookup: |
603 | | - # continue |
604 | | - # |
605 | | - # future, worker = self._lookup[fun] |
606 | | - # |
607 | | - # stores.append( |
608 | | - # client.submit( |
609 | | - # _setter_broadcast, |
610 | | - # future, |
611 | | - # key, |
612 | | - # value, |
613 | | - # workers=worker, |
614 | | - # ) |
615 | | - # ) |
616 | | - # self.client.gather(stores) # blocking call to ensure all models were stored |
617 | | - |
618 | | - |
619 | | -class DaskComboMisfits(ComboObjectiveFunction): |
620 | | - """ |
621 | | - A composite objective function for distributed computing. |
622 | | - """ |
623 | | - |
624 | | - def __init__( |
625 | | - self, |
626 | | - objfcts: list[BaseObjectiveFunction], |
627 | | - multipliers=None, |
628 | | - worker: str | None = None, |
629 | | - **kwargs, |
630 | | - ): |
631 | | - self._model: np.ndarray | None = None |
632 | | - |
633 | | - super().__init__(objfcts=objfcts, multipliers=multipliers, **kwargs) |
634 | | - |
635 | | - def __call__(self, m, f=None): |
636 | | - self.model = m |
637 | | - |
638 | | - futures = [] |
639 | | - count = 0 |
640 | | - |
641 | | - delayed_call = delayed(_calc_objective) |
642 | | - for objfct in self.objfcts: |
643 | | - if self.multipliers[count] == 0.0: |
644 | | - continue |
645 | | - |
646 | | - futures.append(delayed_call(objfct, self.multipliers[count], m)) |
647 | | - count += 1 |
648 | | - |
649 | | - return np.sum(compute(futures)[0]) |
650 | | - |
651 | | - def deriv(self, m, f=None): |
652 | | - """ |
653 | | - First derivative of the composite objective function is the sum of the |
654 | | - derivatives of each objective function in the list, weighted by their |
655 | | - respective multplier. |
656 | | -
|
657 | | - :param numpy.ndarray m: model |
658 | | - :param SimPEG.Fields f: Fields object (if applicable) |
659 | | - """ |
660 | | - self.model = m |
661 | | - |
662 | | - futures = [] |
663 | | - |
664 | | - count = 0 |
665 | | - |
666 | | - delayed_call = delayed(_deriv) |
667 | | - for objfct in self.objfcts: |
668 | | - if self.multipliers[count] == 0.0: # don't evaluate the fct |
669 | | - continue |
670 | | - |
671 | | - futures.append( |
672 | | - array.from_delayed( |
673 | | - delayed_call( |
674 | | - objfct, |
675 | | - self.multipliers[count], |
676 | | - m, |
677 | | - ), |
678 | | - shape=m.shape, |
679 | | - dtype=float, |
680 | | - ) |
681 | | - ) |
682 | | - |
683 | | - count += 1 |
684 | | - |
685 | | - return array.vstack(futures).sum(axis=0).compute() |
686 | | - |
687 | | - def deriv2(self, m, v=None, f=None): |
688 | | - """ |
689 | | - Second derivative of the composite objective function is the sum of the |
690 | | - second derivatives of each objective function in the list, weighted by |
691 | | - their respective multplier. |
692 | | -
|
693 | | - :param numpy.ndarray m: model |
694 | | - :param numpy.ndarray v: vector we are multiplying by |
695 | | - :param SimPEG.Fields f: Fields object (if applicable) |
696 | | - """ |
697 | | - self.model = m |
698 | | - |
699 | | - futures = [] |
700 | | - count = 0 |
701 | | - |
702 | | - delayed_call = delayed(_deriv2) |
703 | | - for objfct in self.objfcts: |
704 | | - if self.multipliers[count] == 0.0: # don't evaluate the fct |
705 | | - continue |
706 | | - |
707 | | - futures.append( |
708 | | - array.from_delayed( |
709 | | - delayed_call(objfct, self.multipliers[count], m, v), |
710 | | - shape=m.shape, |
711 | | - dtype=float, |
712 | | - ) |
713 | | - ) |
714 | | - count += 1 |
715 | | - |
716 | | - return array.vstack(futures).sum(axis=0).compute() |
717 | | - |
718 | | - def get_dpred(self, m, f=None): |
719 | | - """ |
720 | | - Request calculation of predicted data from all simulations. |
721 | | - """ |
722 | | - self.model = m |
723 | | - |
724 | | - futures = [] |
725 | | - delayed_call = delayed(_calc_dpred) |
726 | | - |
727 | | - for objfct in self.objfcts: |
728 | | - futures.append(delayed_call(objfct, m)) |
729 | | - |
730 | | - return compute(futures)[0] |
731 | | - |
732 | | - def getJtJdiag(self, m, f=None): |
733 | | - """ |
734 | | - Request calculation of the diagonal of JtJ from all simulations. |
735 | | - """ |
736 | | - self.model = m |
737 | | - |
738 | | - if getattr(self, "_jtjdiag", None) is None: |
739 | | - |
740 | | - futures = [] |
741 | | - delayed_call = delayed(_get_jtj_diag) |
742 | | - |
743 | | - for objfct in self.objfcts: |
744 | | - futures.append( |
745 | | - array.from_delayed( |
746 | | - delayed_call(objfct, m), shape=m.shape, dtype=float |
747 | | - ) |
748 | | - ) |
749 | | - |
750 | | - self._jtjdiag = array.vstack(futures).sum(axis=0).compute() |
751 | | - |
752 | | - return self._jtjdiag |
753 | | - |
754 | | - def residuals(self, m, f=None): |
755 | | - """ |
756 | | - Compute the residual for the data misfit. |
757 | | - """ |
758 | | - self.model = m |
759 | | - |
760 | | - futures = [] |
761 | | - |
762 | | - delayed_call = delayed(_calc_residual) |
763 | | - for objfct in self.objfcts: |
764 | | - futures.append(delayed_call(objfct, m)) |
765 | | - |
766 | | - return compute(futures)[0] |
767 | | - |
768 | | - @property |
769 | | - def model(self): |
770 | | - return self._model |
771 | | - |
772 | | - @model.setter |
773 | | - def model(self, value): |
774 | | - # Only send the model to the internal simulations if it was updated. |
775 | | - if ( |
776 | | - isinstance(value, np.ndarray) |
777 | | - and isinstance(self.model, np.ndarray) |
778 | | - and np.allclose(value, self.model) |
779 | | - ): |
780 | | - return |
781 | | - |
782 | | - self._jtjdiag = None |
783 | | - |
784 | | - stores = [] |
785 | | - delayed_call = delayed(_store_model) |
786 | | - for objfct in self.objfcts: |
787 | | - stores.append(delayed_call(objfct, value)) |
788 | | - compute(stores) |
789 | | - self._model = value |
0 commit comments