6161from pathlib import Path
6262from shutil import which
6363from subprocess import Popen
64- from typing import IO , Any , Generic , Sequence , TypeVar
64+ from typing import IO , Any , Generic , Sequence , Type , TypeVar
6565
6666from diopter .utils import CommandOutput , run_cmd , run_cmd_async , temporary_file
6767
@@ -152,6 +152,9 @@ def __del__(self) -> None:
152152 os .remove (self .filename )
153153
154154
155+ SourceType = TypeVar ("SourceType" , bound = "Source" )
156+
157+
155158@dataclass (frozen = True , kw_only = True )
156159class Source (ABC ):
157160 """A base class for C or C++source programs together
@@ -261,7 +264,7 @@ def to_json_dict_impl(self) -> dict[str, Any]:
261264 raise NotImplementedError
262265
263266 @classmethod
264- def from_json_dict (cls , d : dict [str , Any ]) -> Source :
267+ def from_json_dict (cls : Type [ SourceType ] , d : dict [str , Any ]) -> SourceType :
265268 """Returns a source parsed from a json dictionary.
266269
267270 Args:
@@ -1292,8 +1295,8 @@ def compile_program_async(
12921295 def preprocess_program (
12931296 self ,
12941297 program : ProgramType ,
1295- make_compiler_agnostic : bool = False ,
12961298 additional_flags : tuple [str , ...] = tuple (),
1299+ do_not_expand_includes : tuple [str , ...] = tuple (),
12971300 timeout : int | None = None ,
12981301 ) -> ProgramType :
12991302 """Preprocesses the program
@@ -1303,9 +1306,9 @@ def preprocess_program(
13031306 input program
13041307 additional_flags (tuple[str, ...]):
13051308 additional flags used for the compilation
1306- make_compiler_agnostic (bool ):
1307- if true will try to remove certain constructs (e.g., attributes)
1308- such that the resulting program can be compiled with both gcc and clang
1309+ do_not_expand_includes (tuple[str, ...] ):
1310+ include directives that should not be expanded,
1311+ this only works with <...> includes
13091312 timeout (int | None):
13101313 timeout in seconds for the compilation command
13111314
@@ -1314,6 +1317,18 @@ def preprocess_program(
13141317 the prepocessed program
13151318 """
13161319
1320+ if do_not_expand_includes :
1321+ tmpdir = tempfile .TemporaryDirectory ()
1322+ for include in do_not_expand_includes :
1323+ p = Path (tmpdir .name ) / Path (include )
1324+ if not p .parent .exists ():
1325+ p .parent .mkdir (parents = True )
1326+ with open (p , "w" ) as f :
1327+ print (
1328+ f"//unpreprocessed_include_#include <{ include } >" , file = f , end = ""
1329+ )
1330+ additional_flags += ("-I" , tmpdir .name , "-C" )
1331+
13171332 result = self .compile_program (
13181333 program ,
13191334 ASMCompilationOutput (None ),
@@ -1322,26 +1337,10 @@ def preprocess_program(
13221337 )
13231338 preprocessed_source = result .output .read ()
13241339
1325- if make_compiler_agnostic :
1326- # remove malloc attributes with args, clang doesn't understand these
1327- preprocessed_source = re .sub (
1328- r"__attribute__ \(\(__malloc__ \(.*, .*\)\)\)" , r"" , preprocessed_source
1329- )
1330- # remove f128 builtins builtins, clang doesn't understand these
1331- preprocessed_source = re .sub (
1332- r"extern int [^;]*f128[^;]*;" , r"" , preprocessed_source
1333- )
1334- # remove Float*** typedefs, gcc doesn't like these
1335- preprocessed_source = re .sub (
1336- r"typedef [^;]*_Float\d+x?;" , r"" , preprocessed_source
1337- )
1338- # replace remaining FloatX types with the standard ones
1339- preprocessed_source = re .sub (r"_Float32x" , r"double" , preprocessed_source )
1340- preprocessed_source = re .sub (
1341- r"_Float64x" , r"long double" , preprocessed_source
1340+ if do_not_expand_includes :
1341+ preprocessed_source = preprocessed_source .replace (
1342+ "//unpreprocessed_include_" , ""
13421343 )
1343- preprocessed_source = re .sub (r"_Float32" , r"float" , preprocessed_source )
1344- preprocessed_source = re .sub (r"_Float64" , r"double" , preprocessed_source )
13451344
13461345 return program .with_preprocessed_code (preprocessed_source )
13471346
0 commit comments