Skip to content

Commit ccb6881

Browse files
committed
Add function to consistently strip the copyright banner from templates
1 parent b4218ef commit ccb6881

1 file changed

Lines changed: 17 additions & 6 deletions

File tree

madmatrix/model_handling.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,17 @@
2929
# FD gauge check
3030
fd_gauge = (unitary_gauge == 3)
3131

32+
def strip_banner(file_text, banner_mark):
33+
# skip leading lines that start with '!'
34+
start = 0
35+
file_lines = file_text.split('\n')
36+
for i, line in enumerate(file_lines):
37+
if line.startswith(banner_mark):
38+
start = i + 1
39+
else:
40+
break
41+
return '\n'.join(file_lines[start:])
42+
3243
# AV - define a custom ALOHAWriter
3344
# (NB: enable this via MadMatrixUFOModelConverter.aloha_writer)
3445
class MadMatrixALOHAWriter(aloha_writers.ALOHAWriterForGPU):
@@ -888,7 +899,7 @@ def read_aloha_template_files(self, ext):
888899

889900
if ext == 'h': file = open(pjoin(path, helas_temp_file)).read()
890901
else: file = open(pjoin(path, self.helas_cc)).read()
891-
file = '\n'.join( file.split('\n')[9:] ) # skip first 9 lines in helas.h/cu (copyright including ALOHA)
902+
file = strip_banner(file, banner_mark = "!") # skip first 9 lines in helas.h/cu (copyright including ALOHA)
892903
out.append( file )
893904
return out
894905

@@ -1353,7 +1364,7 @@ def write_aloha_routines(self):
13531364
if (fd_gauge): replace_dict['nwave'] += 1
13541365
file_h = self.read_template_file(self.aloha_template_h) % replace_dict
13551366
file_cc = self.read_template_file(self.aloha_template_cc) % replace_dict
1356-
file_cc = '\n'.join( file_cc.split('\n')[9:] ) # skip first 9 lines in cpp_hel_amps_cc.inc (copyright including ALOHA)
1367+
file_cc = strip_banner(file_cc, banner_mark = "!") # skip first 9 lines in cpp_hel_amps_cc.inc (copyright including ALOHA)
13571368
# Write the HelAmps_sm.h and HelAmps_sm.cc files
13581369
###MadMatrixwriters.CPPWriter(model_h_file).writelines(file_h)
13591370
###MadMatrixwriters.CPPWriter(model_cc_file).writelines(file_cc)
@@ -1461,7 +1472,7 @@ def get_process_class_definitions(self, write=True):
14611472

14621473
if( write ): # ZW: added dict return for uses in child exporters. Default argument is True so no need to modify other calls to this function
14631474
file = self.read_template_file(self.process_class_template) % replace_dict
1464-
file = '\n'.join( file.split('\n')[8:] ) # skip first 8 lines in process_class.inc (copyright)
1475+
file = strip_banner(file, banner_mark = "!") # skip first 8 lines in process_class.inc (copyright)
14651476
return file
14661477
else:
14671478
return replace_dict
@@ -1633,7 +1644,7 @@ def get_process_function_definitions(self, write=True):
16331644
file_lines = file.split('\n')
16341645
file_lines = [l.replace('cIPC, cIPD','cIPC') for l in file_lines] # remove cIPD from OpenMP pragma
16351646
file = '\n'.join( file_lines )
1636-
file = '\n'.join( file.split('\n')[8:] ) # skip first 8 lines in process_function_definitions.inc (copyright)
1647+
file = strip_banner(file, banner_mark = "!") # skip first 8 lines in process_function_definitions.inc (copyright)
16371648
return file
16381649

16391650
# AV - modify export_cpp.OneProcessExporterCPP method (add debug printouts for multichannel #342)
@@ -1664,7 +1675,7 @@ def get_sigmaKin_lines(self, color_amplitudes, write=True):
16641675

16651676
if write:
16661677
file = self.read_template_file(self.process_sigmaKin_function_template) % replace_dict
1667-
file = '\n'.join( file.split('\n')[8:] ) # skip first 8 lines in process_sigmaKin_function.inc (copyright)
1678+
file = strip_banner(file, banner_mark = "!") # skip first 8 lines in process_sigmaKin_function.inc (copyright)
16681679
return file, replace_dict
16691680
else:
16701681
return replace_dict
@@ -1816,7 +1827,7 @@ def get_all_sigmaKin_lines(self, color_amplitudes, class_name):
18161827
file_extend = []
18171828
for i, me in enumerate(self.matrix_elements):
18181829
file = self.get_matrix_single_process( i, me, color_amplitudes[i], class_name )
1819-
file = '\n'.join( file.split('\n')[8:] ) # skip first 8 lines in process_matrix.inc (copyright)
1830+
file = strip_banner(file, banner_mark = "!") # skip first 8 lines in process_matrix.inc (copyright)
18201831
file_extend.append( file )
18211832
assert i == 0, "more than one ME in get_all_sigmaKin_lines" # AV sanity check (added for color_sum.cc but valid independently)
18221833
ret_lines.extend( file_extend )

0 commit comments

Comments
 (0)