@@ -70,7 +70,7 @@ def bootloader_shell(self):
7070 pass
7171 yield self .serial
7272
73- def flash (
73+ def flash ( # noqa: C901
7474 self ,
7575 path : PathBuf ,
7676 * ,
@@ -81,10 +81,19 @@ def flash(
8181 force_flash_bundle : str | None = None ,
8282 cacert_file : str | None = None ,
8383 insecure_tls : bool = False ,
84+ headers : dict [str , str ] | None = None ,
85+ bearer_token : str | None = None ,
8486 ):
87+ if bearer_token :
88+ bearer_token = self ._validate_bearer_token (bearer_token )
89+
90+ if headers :
91+ headers = self ._validate_header_dict (headers )
92+
8593 """Flash image to DUT"""
8694 should_download_to_httpd = True
8795 image_url = ""
96+ original_http_url = None
8897 operator_scheme = None
8998 # initrmafs cannot handle https yet, fallback to using the exporter's http server
9099 if path .startswith (("http://" , "https://" )) and not force_exporter_http :
@@ -94,7 +103,17 @@ def flash(
94103 else :
95104 # use the exporter's http server for the flasher image, we should download it first
96105 if operator is None :
97- path , operator , operator_scheme = operator_for_path (path )
106+ if path .startswith (("http://" , "https://" )) and bearer_token :
107+ parsed = urlparse (path )
108+ self .logger .info (f"Using Bearer token authentication for { parsed .netloc } " )
109+ original_http_url = path
110+ operator = Operator (
111+ "http" , root = "/" , endpoint = f"{ parsed .scheme } ://{ parsed .netloc } " , token = bearer_token
112+ )
113+ operator_scheme = "http"
114+ path = Path (parsed .path )
115+ else :
116+ path , operator , operator_scheme = operator_for_path (path )
98117 image_url = self .http .get_url () + "/" + path .name
99118
100119 # start counting time for the flash operation
@@ -107,7 +126,16 @@ def flash(
107126 # Start the storage write operation in the background
108127 storage_thread = threading .Thread (
109128 target = self ._transfer_bg_thread ,
110- args = (path , operator , operator_scheme , os_image_checksum , self .http .storage , error_queue , image_url ),
129+ args = (
130+ path ,
131+ operator ,
132+ operator_scheme ,
133+ os_image_checksum ,
134+ self .http .storage ,
135+ error_queue ,
136+ original_http_url ,
137+ headers ,
138+ ),
111139 name = "storage_transfer" ,
112140 )
113141 storage_thread .start ()
@@ -152,9 +180,17 @@ def flash(
152180 else :
153181 stored_cacert = self ._setup_flasher_ssl (console , manifest , cacert_file )
154182
155-
156- self ._flash_with_progress (console , manifest , path , image_url , target_device ,
157- insecure_tls , stored_cacert )
183+ header_args = self ._prepare_headers (headers , bearer_token )
184+ self ._flash_with_progress (
185+ console ,
186+ manifest ,
187+ path ,
188+ image_url ,
189+ target_device ,
190+ insecure_tls ,
191+ stored_cacert ,
192+ header_args ,
193+ )
158194
159195 total_time = time .time () - start_time
160196 # total time in minutes:seconds
@@ -222,7 +258,36 @@ def _curl_tls_args(self, insecure_tls: bool, stored_cacert: str | None) -> str:
222258 tls_args += f"--cacert { stored_cacert } "
223259 return tls_args .strip ()
224260
225- def _flash_with_progress (self , console , manifest , path , image_url , target_path , insecure_tls , stored_cacert ):
261+ def _curl_header_args (self , headers : dict [str , str ] | None ) -> str :
262+ """Generate header arguments for curl command"""
263+ if not headers :
264+ return ""
265+
266+ parts : list [str ] = []
267+
268+ def _sq (s : str ) -> str :
269+ return s .replace ("'" , "'\" '\" '" )
270+
271+ for k , v in headers .items ():
272+ k = str (k ).strip ()
273+ v = str (v ).strip ()
274+ if not k :
275+ continue
276+ parts .append (f"-H '{ _sq (k )} : { _sq (v )} '" )
277+
278+ return " " .join (parts )
279+
280+ def _flash_with_progress (
281+ self ,
282+ console ,
283+ manifest ,
284+ path ,
285+ image_url ,
286+ target_path ,
287+ insecure_tls ,
288+ stored_cacert ,
289+ header_args : str ,
290+ ):
226291 """Flash image to target device with progress monitoring.
227292
228293 Args:
@@ -241,11 +306,11 @@ def _flash_with_progress(self, console, manifest, path, image_url, target_path,
241306 tls_args = self ._curl_tls_args (insecure_tls , stored_cacert )
242307
243308 # Check if the image URL is accessible using curl and the TLS arguments
244- self ._check_url_access (console , prompt , image_url , tls_args )
309+ self ._check_url_access (console , prompt , image_url , tls_args , header_args )
245310
246311 # Flash the image, we run curl -> decompress -> dd in the background, so we can monitor dd's progress
247312 flash_cmd = (
248- f'( curl -fsSL { tls_args } "{ image_url } " | '
313+ f'( curl -fsSL { tls_args } { header_args } "{ image_url } " | '
249314 f"{ decompress_cmd } "
250315 f"dd of={ target_path } bs=64k iflag=fullblock oflag=direct) &"
251316 )
@@ -287,7 +352,7 @@ def _flash_with_progress(self, console, manifest, path, image_url, target_path,
287352 console .sendline ("sync" )
288353 console .expect (prompt , timeout = EXPECT_TIMEOUT_SYNC )
289354
290- def _check_url_access (self , console , prompt , image_url : str , tls_args : str ):
355+ def _check_url_access (self , console , prompt , image_url : str , tls_args : str , header_args : str ):
291356 """Check if the image URL is accessible using curl.
292357
293358 Args:
@@ -299,7 +364,9 @@ def _check_url_access(self, console, prompt, image_url: str, tls_args: str):
299364 Raises:
300365 RuntimeError: If the URL is not accessible
301366 """
302- console .sendline (f'curl --location --max-time 30 --fail -sS -r 0-0 -o /dev/null { tls_args } "{ image_url } "' )
367+ console .sendline (
368+ f'curl --location --max-time 30 --fail -sS -r 0-0 -o /dev/null { tls_args } { header_args } "{ image_url } "'
369+ )
303370 console .expect (prompt , timeout = EXPECT_TIMEOUT_DEFAULT )
304371 curl_output = console .before .decode (errors = "ignore" ).strip ()
305372 console .sendline ("echo $?" )
@@ -358,6 +425,7 @@ def _transfer_bg_thread(
358425 to_storage : OpendalClient ,
359426 error_queue ,
360427 original_url : str | None = None ,
428+ headers : dict [str , str ] | None = None ,
361429 ):
362430 """Transfer image to exporter storage in the background
363431 Args:
@@ -367,6 +435,7 @@ def _transfer_bg_thread(
367435 error_queue: Queue to put exceptions in if any
368436 known_hash: Known hash of the image
369437 original_url: Original URL for HTTP fallback
438+ headers: HTTP headers for requests
370439 """
371440 self .logger .info (f"Writing image to storage in the background: { src_path } " )
372441 try :
@@ -392,7 +461,9 @@ def _transfer_bg_thread(
392461 self .logger .info (f"Uploading image to storage: { filename } " )
393462 to_storage .write_from_path (filename , src_path , src_operator )
394463
395- metadata , metadata_json = self ._create_metadata_and_json (src_operator , src_path , file_hash , original_url )
464+ metadata , metadata_json = self ._create_metadata_and_json (
465+ src_operator , src_path , file_hash , original_url , headers
466+ )
396467 metadata_file = filename + ".metadata"
397468 to_storage .write_bytes (metadata_file , metadata_json .encode (errors = "ignore" ))
398469
@@ -415,7 +486,7 @@ def _sha256_file(self, src_operator, src_path) -> str:
415486 return m .hexdigest ()
416487
417488 def _create_metadata_and_json (
418- self , src_operator , src_path , file_hash = None , original_url = None
489+ self , src_operator , src_path , file_hash = None , original_url = None , headers : dict [ str , str ] | None = None
419490 ) -> tuple [Metadata | None , str ]:
420491 """Create a metadata json string from a metadata object"""
421492 metadata = None
@@ -436,7 +507,10 @@ def _create_metadata_and_json(
436507
437508 if original_url and original_url .startswith (("http://" , "https://" )):
438509 try :
439- response = requests .head (original_url )
510+ if headers :
511+ response = requests .head (original_url , headers = headers )
512+ else :
513+ response = requests .head (original_url )
440514
441515 http_metadata = {}
442516 if "content-length" in response .headers :
@@ -611,6 +685,71 @@ def manifest(self):
611685 self ._manifest = FlasherBundleManifestV1Alpha1 .from_string (yaml_str )
612686 return self ._manifest
613687
688+ def _validate_header_dict (self , header_map : dict [str , str ]) -> dict [str , str ]:
689+ token_re = re .compile (r"^[!#$%&'*+\-.^_`|~0-9A-Za-z]+$" )
690+ seen : set [str ] = set ()
691+ for key , value in header_map .items ():
692+ key = key .strip ()
693+ value = value .strip ()
694+ if not key :
695+ raise ArgumentError (f"Invalid header key: '{ key } '" )
696+
697+ if not token_re .match (key ):
698+ raise ArgumentError (f"Invalid header name '{ key } ': must be an HTTP token (RFC7230)" )
699+ if any (c in ("\r " , "\n " ) for c in key ) or any (c in ("\r " , "\n " ) for c in value ):
700+ raise ArgumentError ("Header names/values must not contain CR/LF" )
701+ kl = key .lower ()
702+ if kl in seen :
703+ raise ArgumentError (f"Duplicate header '{ key } '" )
704+ seen .add (kl )
705+ return header_map
706+
707+ def _parse_headers (self , headers : list [str ]) -> dict [str , str ]:
708+ header_map : dict [str , str ] = {}
709+ for h in headers :
710+ if ":" not in h :
711+ raise click .ClickException (f"Invalid header format: { h !r} . Expected 'Key: Value'." )
712+
713+ key , value = h .split (":" , 1 )
714+ header_map [key .strip ()] = value .strip ()
715+
716+ try :
717+ return self ._validate_header_dict (header_map )
718+ except ArgumentError as e :
719+ raise click .ClickException (str (e )) from e
720+
721+ def _prepare_headers (self , headers : dict [str , str ] | None , bearer_token : str | None ) -> str :
722+ all_headers = headers .copy () if headers else {}
723+ if bearer_token :
724+ if any (k .lower () == "authorization" for k in all_headers .keys ()):
725+ self .logger .warning ("Authorization header provided - ignoring bearer token" )
726+ else :
727+ all_headers ["Authorization" ] = f"Bearer { bearer_token } "
728+
729+ if bearer_token and "Authorization" not in (headers or {}):
730+ auth_header = {"Authorization" : all_headers ["Authorization" ]}
731+ self ._validate_header_dict (auth_header )
732+
733+ return self ._curl_header_args (all_headers )
734+
735+ def _validate_bearer_token (self , token : str | None ) -> str | None :
736+ if token is None :
737+ return None
738+
739+ token = token .strip ()
740+ if not token :
741+ raise click .ClickException ("Bearer token cannot be empty" )
742+
743+ # RFC 6750 allows token68 format (base64url-encoded) or other token formats
744+ # Basic validation: printable ASCII excluding whitespace and special chars that could cause issues
745+ if not all (32 < ord (c ) < 127 and c not in ' "\\ ' for c in token ):
746+ raise click .ClickException ("Bearer token contains invalid characters" )
747+
748+ if len (token ) > 4096 :
749+ raise click .ClickException ("Bearer token is too long (max 4096 characters)" )
750+
751+ return token
752+
614753 def cli (self ):
615754 @driver_click_group (self )
616755 def base ():
@@ -630,6 +769,17 @@ def base():
630769 @click .option ("--force-flash-bundle" , type = str , help = "Force use of a specific flasher OCI bundle" )
631770 @click .option ("--cacert" , type = click .Path (exists = True , dir_okay = False ), help = "CA certificate to use for HTTPS" )
632771 @click .option ("--insecure-tls" , is_flag = True , help = "Skip TLS certificate verification" )
772+ @click .option (
773+ "--header" ,
774+ "header" ,
775+ multiple = True ,
776+ help = "Custom HTTP header in 'Key: Value' format" ,
777+ )
778+ @click .option (
779+ "--bearer" ,
780+ type = str ,
781+ help = "Bearer token for HTTP authentication" ,
782+ )
633783 @debug_console_option
634784 def flash (
635785 file ,
@@ -641,6 +791,8 @@ def flash(
641791 force_flash_bundle ,
642792 cacert ,
643793 insecure_tls ,
794+ header ,
795+ bearer ,
644796 ):
645797 """Flash image to DUT from file"""
646798 if os_image_checksum_file and os .path .exists (os_image_checksum_file ):
@@ -649,13 +801,18 @@ def flash(
649801 self .logger .info (f"Read checksum from file: { os_image_checksum } " )
650802
651803 self .set_console_debug (console_debug )
804+
805+ headers = self ._parse_headers (header ) if header else None
806+
652807 self .flash (
653808 file ,
654809 partition = target ,
655810 force_exporter_http = force_exporter_http ,
656811 force_flash_bundle = force_flash_bundle ,
657812 cacert_file = cacert ,
658813 insecure_tls = insecure_tls ,
814+ headers = headers ,
815+ bearer_token = bearer ,
659816 )
660817
661818 @base .command ()
0 commit comments