|
25 | 25 | import tempfile |
26 | 26 | import time |
27 | 27 | import logging |
| 28 | +import os |
28 | 29 |
|
29 | 30 | logger = logging.getLogger(__name__) |
30 | 31 |
|
@@ -106,51 +107,90 @@ def __init__(self, voc_index_to_token: Optional[Mapping[int, str]] = None, |
106 | 107 | self.group_1_name = group_1_name |
107 | 108 | self.premise_engine = premise_engine |
108 | 109 |
|
| 110 | + |
| 111 | + |
| 112 | + |
def find_patterns(self, instances: List[PremiseInstance]):
    """
    Runs the external Premise program on the given instances and returns the parsed results.

    The Premise C++ program reads its input from files and writes its output to a file, so the
    instances (and, if configured, the embeddings) are written to temporary files, the program
    is invoked on them, and the result file is parsed afterwards. All temporary files are
    removed again, even if writing, running or parsing fails.

    If ``self.premise_engine`` is None, it is set to ``pypremise.io.get_premise_path()`` before
    the program is called. If ``self.voc_index_to_token`` is set, the feature indices in the
    results are converted to tokens via ``self._pattern_indices_to_tokens``.

    :param instances: the instances to search for patterns in.
    :return: the value returned by ``pypremise.io.parse_premise_result``.
    """
    import pypremise.io

    # Every temporary path created below is recorded here so the ``finally`` block can
    # remove exactly those files that really exist, no matter where a failure happens.
    temp_paths = []

    def new_temp_path() -> str:
        # The file is closed right away so that only its path is handed on and no handle of
        # this process is still open when the file is written by others (an open handle
        # blocks re-opening on Windows, see the tempfile documentation).
        handle = tempfile.NamedTemporaryFile(delete=False)
        handle.close()
        # Forward slashes only; presumably required by the Premise binary on Windows.
        path = os.path.abspath(handle.name).replace("\\", "/")
        temp_paths.append(path)
        return path

    def safe_remove(path: str) -> None:
        # Best-effort removal: clean-up must never hide the real error (or the result)
        # of the run, so problems are logged instead of raised.
        try:
            os.remove(path)
            return
        except FileNotFoundError:
            return  # already gone, nothing to do
        except PermissionError:
            # The file may still be held for a moment by another process; retry once.
            time.sleep(0.2)
        except OSError as error:
            logger.warning("Could not remove temporary file %s: %s", path, error)
            return
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as error:
            logger.warning("Could not remove temporary file %s: %s", path, error)

    try:
        feature_path = new_temp_path()
        label_path = new_temp_path()
        result_path = new_temp_path()

        pypremise.io.write_dat_content(instances, feature_path, label_path)

        # === embeddings ===
        if self.embedding_index_to_vector is not None:
            embedding_path = new_temp_path()
            max_feature_index = Premise._get_max_feature_index(instances)
            pypremise.io.write_embedding_file(
                self.embedding_index_to_vector,
                embedding_path,
                self.embedding_dimensionality,
                max_feature_index,
            )
        else:
            # An empty path is passed on when no embeddings are configured.
            embedding_path = ""

        # === call Premise C++ program ===
        start_time = time.time()
        if self.premise_engine is None:
            self.premise_engine = pypremise.io.get_premise_path()
        pypremise.io.call_premise_program(
            feature_path, label_path, result_path, embedding_path,
            self.embedding_dimensionality, self.max_neighbor_distance,
            self.fisher_p_value, self.clause_max_overlap, self.min_overlap,
            self.premise_engine,
        )
        logger.info("Premise ran for %.2f seconds.", time.time() - start_time)

        # === check result file ===
        # Purely diagnostic: an empty or unreadable result file is reported, parsing is
        # attempted regardless.
        try:
            size = os.path.getsize(result_path)
        except OSError as error:
            logger.error("Could not stat result file: %s", error)
        else:
            logger.info("Result file size: %s bytes", size)
            if size == 0:
                logger.warning("Result file is empty — check Premise stderr or parameters.")

        # === analyse results ===
        results = pypremise.io.parse_premise_result(
            result_path, self.group_0_name, self.group_1_name
        )
    finally:
        # clean up temporary files
        for path in temp_paths:
            safe_remove(path)

    # if we have a map from indices to tokens, use it to convert our patterns indices to tokens
    if self.voc_index_to_token is not None:
        self._pattern_indices_to_tokens(results)

    return results
153 | 191 |
|
| 192 | + |
| 193 | + |
154 | 194 | def _pattern_indices_to_tokens(self, results: List[PremiseResult]): |
155 | 195 | """ |
156 | 196 | Converts the features in the given patterns from their index representation to their token representation. |
|
0 commit comments