|
7 | 7 |
|
8 | 8 |
|
9 | 9 | ''' |
10 | | -Basic Fortnet IO Format Classes |
| 10 | +Basic Fortnet Input Format Class |
11 | 11 |
|
12 | | -These basic Python classes implement the Fortnet input and output file format. |
13 | | -The Fnetdata class enables to create compatible HDF5 datasets, whereas the |
14 | | -Fnetout class extracts certain properties of the HDF5 output for later analysis. |
| 12 | +This basic Python class implements the Fortnet input |
| 13 | +format and enables to create compatible HDF5 datasets. |
15 | 14 | ''' |
16 | 15 |
|
17 | 16 |
|
@@ -538,227 +537,5 @@ def hdf_append_targets(root, data): |
538 | 537 | targets[...] = tmp |
539 | 538 |
|
540 | 539 |
|
541 | | -class Fnetout: |
542 | | - '''Basic Fortnet Output Format Class.''' |
543 | | - |
544 | | - |
545 | | - def __init__(self, fname): |
546 | | - '''Initializes a Fnetout object. |
547 | | -
|
548 | | - Args: |
549 | | -
|
550 | | - fname (str): filename to extract data from |
551 | | -
|
552 | | - ''' |
553 | | - |
554 | | - self._fname = fname |
555 | | - |
556 | | - with h5py.File(self._fname, 'r') as fnetoutfile: |
557 | | - fnetout = fnetoutfile['fnetout'] |
558 | | - self._mode = fnetout.attrs.get('mode').decode('UTF-8').strip() |
559 | | - if not self._mode in ('validate', 'predict'): |
560 | | - raise FnetoutError('Invalid running mode specification.') |
561 | | - |
562 | | - output = fnetoutfile['fnetout']['output'] |
563 | | - self._ndatapoints = output.attrs.get('ndatapoints') |
564 | | - if len(self._ndatapoints) == 1: |
565 | | - # number of datapoints stored in array of size 1 |
566 | | - self._ndatapoints = self._ndatapoints[0] |
567 | | - else: |
568 | | - msg = "Error while reading fnetout file '" + self._fname + \ |
569 | | - "'. Unrecognized number of datapoints obtained." |
570 | | - raise FnetoutError(msg) |
571 | | - self._tforces = output.attrs.get('tforces') |
572 | | - if len(self._tforces) == 1: |
573 | | - # booleans stored in integer arrays of size 1 |
574 | | - self._tforces = bool(self._tforces[0]) |
575 | | - else: |
576 | | - msg = "Error while reading fnetout file '" + self._fname + \ |
577 | | - "'. Unrecognized force specification obtained." |
578 | | - raise FnetoutError(msg) |
579 | | - |
580 | | - self._targettype = \ |
581 | | - output.attrs.get('targettype').decode('UTF-8').strip() |
582 | | - if not self._targettype in ('atomic', 'global'): |
583 | | - raise FnetoutError('Invalid running mode obtained.') |
584 | | - |
585 | | - # get number of atomic or global predictions/targets |
586 | | - self._npredictions = np.shape( |
587 | | - np.array(output['datapoint1']['output']))[1] |
588 | | - |
589 | | - if self._mode == 'validate': |
590 | | - self._npredictions = int(self._npredictions / 2) |
591 | | - |
592 | | - |
593 | | - @property |
594 | | - def mode(self): |
595 | | - '''Defines property, providing the mode of the Fortnet run. |
596 | | -
|
597 | | - Returns: |
598 | | -
|
599 | | - mode (str): mode of the run that produced the Fnetout file |
600 | | -
|
601 | | - ''' |
602 | | - |
603 | | - return self._mode |
604 | | - |
605 | | - |
606 | | - @property |
607 | | - def ndatapoints(self): |
608 | | - '''Defines property, providing the number of datapoints. |
609 | | -
|
610 | | - Returns: |
611 | | -
|
612 | | - ndatapoints (int): total number of datapoints of the training |
613 | | -
|
614 | | - ''' |
615 | | - |
616 | | - return self._ndatapoints |
617 | | - |
618 | | - |
619 | | - @property |
620 | | - def targettype(self): |
621 | | - '''Defines property, providing the target type. |
622 | | -
|
623 | | - Returns: |
624 | | -
|
625 | | - targettype (str): type of targets the network was trained on |
626 | | -
|
627 | | - ''' |
628 | | - |
629 | | - return self._targettype |
630 | | - |
631 | | - |
632 | | - @property |
633 | | - def tforces(self): |
634 | | - '''Defines property, providing hint whether atomic forces are present. |
635 | | -
|
636 | | - Returns: |
637 | | -
|
638 | | - tforces (bool): true, if atomic forces are supplied |
639 | | -
|
640 | | - ''' |
641 | | - |
642 | | - return self._tforces |
643 | | - |
644 | | - |
645 | | - @property |
646 | | - def predictions(self): |
647 | | - '''Defines property, providing the predictions of Fortnet. |
648 | | -
|
649 | | - Returns: |
650 | | -
|
651 | | - predictions (list or 2darray): predictions of the network |
652 | | -
|
653 | | - ''' |
654 | | - |
655 | | - with h5py.File(self._fname, 'r') as fnetoutfile: |
656 | | - output = fnetoutfile['fnetout']['output'] |
657 | | - if self._targettype == 'atomic': |
658 | | - predictions = [] |
659 | | - for idata in range(self._ndatapoints): |
660 | | - dataname = 'datapoint' + str(idata + 1) |
661 | | - if self._mode == 'validate': |
662 | | - predictions.append( |
663 | | - np.array(output[dataname]['output'], |
664 | | - dtype=float)[:, :self._npredictions]) |
665 | | - else: |
666 | | - predictions.append( |
667 | | - np.array(output[dataname]['output'], dtype=float)) |
668 | | - else: |
669 | | - predictions = np.empty( |
670 | | - (self._ndatapoints, self._npredictions), dtype=float) |
671 | | - for idata in range(self._ndatapoints): |
672 | | - dataname = 'datapoint' + str(idata + 1) |
673 | | - if self._mode == 'validate': |
674 | | - predictions[idata, :] = \ |
675 | | - np.array(output[dataname]['output'], |
676 | | - dtype=float)[0, :self._npredictions] |
677 | | - else: |
678 | | - predictions[idata, :] = \ |
679 | | - np.array(output[dataname]['output'], |
680 | | - dtype=float)[0, :] |
681 | | - |
682 | | - return predictions |
683 | | - |
684 | | - |
685 | | - @property |
686 | | - def targets(self): |
687 | | - '''Defines property, providing the targets during training. |
688 | | -
|
689 | | - Returns: |
690 | | -
|
691 | | - targets (list or 2darray): targets during training |
692 | | -
|
693 | | - ''' |
694 | | - |
695 | | - if self._mode == 'predict': |
696 | | - return None |
697 | | - |
698 | | - with h5py.File(self._fname, 'r') as fnetoutfile: |
699 | | - output = fnetoutfile['fnetout']['output'] |
700 | | - if self._targettype == 'atomic': |
701 | | - targets = [] |
702 | | - for idata in range(self._ndatapoints): |
703 | | - dataname = 'datapoint' + str(idata + 1) |
704 | | - targets.append( |
705 | | - np.array(output[dataname]['output'], |
706 | | - dtype=float)[:, self._npredictions:]) |
707 | | - else: |
708 | | - targets = np.empty( |
709 | | - (self._ndatapoints, self._npredictions), dtype=float) |
710 | | - for idata in range(self._ndatapoints): |
711 | | - dataname = 'datapoint' + str(idata + 1) |
712 | | - targets[idata, :] = \ |
713 | | - np.array(output[dataname]['output'], |
714 | | - dtype=float)[0, self._npredictions:] |
715 | | - |
716 | | - return targets |
717 | | - |
718 | | - |
719 | | - @property |
720 | | - def forces(self): |
721 | | - '''Defines property, providing the atomic forces, if supplied. |
722 | | -
|
723 | | - Returns: |
724 | | -
|
725 | | - forces (list): atomic forces on atoms |
726 | | -
|
727 | | - ''' |
728 | | - |
729 | | - tmp1 = [] |
730 | | - |
731 | | - if self._targettype == 'atomic': |
732 | | - msg = "Error while extracting forces from fnetout file '" \ |
733 | | - + self._fname + \ |
734 | | - "'. Forces only supplied for global property targets." |
735 | | - raise FnetoutError(msg) |
736 | | - |
737 | | - with h5py.File(self._fname, 'r') as fnetoutfile: |
738 | | - output = fnetoutfile['fnetout']['output'] |
739 | | - for idata in range(self._ndatapoints): |
740 | | - dataname = 'datapoint' + str(idata + 1) |
741 | | - tmp1.append(np.array(output[dataname]['forces'], dtype=float)) |
742 | | - |
743 | | - # convert to shape np.shape(forces[iData][iTarget]) = (iAtom, 3) |
744 | | - forces = [] |
745 | | - for tmp2 in tmp1: |
746 | | - entry = [] |
747 | | - if not np.shape(tmp2)[1]%3 == 0: |
748 | | - msg = "Error while extracting forces from fnetout file '" \ |
749 | | - + self._fname + \ |
750 | | - "'. Expected three force components and global target." |
751 | | - raise FnetoutError(msg) |
752 | | - for jj in range(int(np.shape(tmp2)[1] / 3)): |
753 | | - entry.append(tmp2[:, 3 * jj:3 * (jj + 1)]) |
754 | | - forces.append(entry) |
755 | | - |
756 | | - return forces |
757 | | - |
758 | | - |
759 | 540 | class FnetdataError(Exception): |
760 | 541 | '''Exception thrown by the Fnetdata class.''' |
761 | | - |
762 | | - |
763 | | -class FnetoutError(Exception): |
764 | | - '''Exception thrown by the Fnetout class.''' |
0 commit comments