@@ -69,7 +69,7 @@ def bootloader_shell(self):
6969 pass
7070 yield self .serial
7171
72- def flash (
72+ def flash ( # noqa: C901
7373 self ,
7474 path : PathBuf ,
7575 * ,
@@ -80,10 +80,19 @@ def flash(
8080 force_flash_bundle : str | None = None ,
8181 cacert_file : str | None = None ,
8282 insecure_tls : bool = False ,
83+ headers : dict [str , str ] | None = None ,
84+ bearer_token : str | None = None ,
8385 ):
86+ if bearer_token :
87+ bearer_token = self ._validate_bearer_token (bearer_token )
88+
89+ if headers :
90+ headers = self ._validate_header_dict (headers )
91+
8492 """Flash image to DUT"""
8593 should_download_to_httpd = True
8694 image_url = ""
95+ original_http_url = None
8796 operator_scheme = None
8897 # initrmafs cannot handle https yet, fallback to using the exporter's http server
8998 if path .startswith (("http://" , "https://" )) and not force_exporter_http :
@@ -93,7 +102,17 @@ def flash(
93102 else :
94103 # use the exporter's http server for the flasher image, we should download it first
95104 if operator is None :
96- path , operator , operator_scheme = operator_for_path (path )
105+ if path .startswith (("http://" , "https://" )) and bearer_token :
106+ parsed = urlparse (path )
107+ self .logger .info (f"Using Bearer token authentication for { parsed .netloc } " )
108+ original_http_url = path
109+ operator = Operator (
110+ "http" , root = "/" , endpoint = f"{ parsed .scheme } ://{ parsed .netloc } " , token = bearer_token
111+ )
112+ operator_scheme = "http"
113+ path = Path (parsed .path )
114+ else :
115+ path , operator , operator_scheme = operator_for_path (path )
97116 image_url = self .http .get_url () + "/" + path .name
98117
99118 # start counting time for the flash operation
@@ -106,7 +125,16 @@ def flash(
106125 # Start the storage write operation in the background
107126 storage_thread = threading .Thread (
108127 target = self ._transfer_bg_thread ,
109- args = (path , operator , operator_scheme , os_image_checksum , self .http .storage , error_queue , image_url ),
128+ args = (
129+ path ,
130+ operator ,
131+ operator_scheme ,
132+ os_image_checksum ,
133+ self .http .storage ,
134+ error_queue ,
135+ original_http_url ,
136+ headers ,
137+ ),
110138 name = "storage_transfer" ,
111139 )
112140 storage_thread .start ()
@@ -151,9 +179,17 @@ def flash(
151179 else :
152180 stored_cacert = self ._setup_flasher_ssl (console , manifest , cacert_file )
153181
154-
155- self ._flash_with_progress (console , manifest , path , image_url , target_device ,
156- insecure_tls , stored_cacert )
182+ header_args = self ._prepare_headers (headers , bearer_token )
183+ self ._flash_with_progress (
184+ console ,
185+ manifest ,
186+ path ,
187+ image_url ,
188+ target_device ,
189+ insecure_tls ,
190+ stored_cacert ,
191+ header_args ,
192+ )
157193
158194 total_time = time .time () - start_time
159195 # total time in minutes:seconds
@@ -221,7 +257,36 @@ def _curl_tls_args(self, insecure_tls: bool, stored_cacert: str | None) -> str:
221257 tls_args += f"--cacert { stored_cacert } "
222258 return tls_args .strip ()
223259
224- def _flash_with_progress (self , console , manifest , path , image_url , target_path , insecure_tls , stored_cacert ):
260+ def _curl_header_args (self , headers : dict [str , str ] | None ) -> str :
261+ """Generate header arguments for curl command"""
262+ if not headers :
263+ return ""
264+
265+ parts : list [str ] = []
266+
267+ def _sq (s : str ) -> str :
268+ return s .replace ("'" , "'\" '\" '" )
269+
270+ for k , v in headers .items ():
271+ k = str (k ).strip ()
272+ v = str (v ).strip ()
273+ if not k :
274+ continue
275+ parts .append (f"-H '{ _sq (k )} : { _sq (v )} '" )
276+
277+ return " " .join (parts )
278+
279+ def _flash_with_progress (
280+ self ,
281+ console ,
282+ manifest ,
283+ path ,
284+ image_url ,
285+ target_path ,
286+ insecure_tls ,
287+ stored_cacert ,
288+ header_args : str ,
289+ ):
225290 """Flash image to target device with progress monitoring.
226291
227292 Args:
@@ -240,11 +305,11 @@ def _flash_with_progress(self, console, manifest, path, image_url, target_path,
240305 tls_args = self ._curl_tls_args (insecure_tls , stored_cacert )
241306
242307 # Check if the image URL is accessible using curl and the TLS arguments
243- self ._check_url_access (console , prompt , image_url , tls_args )
308+ self ._check_url_access (console , prompt , image_url , tls_args , header_args )
244309
245310 # Flash the image, we run curl -> decompress -> dd in the background, so we can monitor dd's progress
246311 flash_cmd = (
247- f'( curl -fsSL { tls_args } "{ image_url } " | '
312+ f'( curl -fsSL { tls_args } { header_args } "{ image_url } " | '
248313 f"{ decompress_cmd } "
249314 f"dd of={ target_path } bs=64k iflag=fullblock oflag=direct) &"
250315 )
@@ -286,7 +351,7 @@ def _flash_with_progress(self, console, manifest, path, image_url, target_path,
286351 console .sendline ("sync" )
287352 console .expect (prompt , timeout = EXPECT_TIMEOUT_SYNC )
288353
289- def _check_url_access (self , console , prompt , image_url : str , tls_args : str ):
354+ def _check_url_access (self , console , prompt , image_url : str , tls_args : str , header_args : str ):
290355 """Check if the image URL is accessible using curl.
291356
292357 Args:
@@ -298,7 +363,9 @@ def _check_url_access(self, console, prompt, image_url: str, tls_args: str):
298363 Raises:
299364 RuntimeError: If the URL is not accessible
300365 """
301- console .sendline (f'curl --location --max-time 30 --fail -sS -r 0-0 -o /dev/null { tls_args } "{ image_url } "' )
366+ console .sendline (
367+ f'curl --location --max-time 30 --fail -sS -r 0-0 -o /dev/null { tls_args } { header_args } "{ image_url } "'
368+ )
302369 console .expect (prompt , timeout = EXPECT_TIMEOUT_DEFAULT )
303370 curl_output = console .before .decode (errors = "ignore" ).strip ()
304371 console .sendline ("echo $?" )
@@ -357,6 +424,7 @@ def _transfer_bg_thread(
357424 to_storage : OpendalClient ,
358425 error_queue ,
359426 original_url : str | None = None ,
427+ headers : dict [str , str ] | None = None ,
360428 ):
361429 """Transfer image to exporter storage in the background
362430 Args:
@@ -366,6 +434,7 @@ def _transfer_bg_thread(
366434 error_queue: Queue to put exceptions in if any
367435 known_hash: Known hash of the image
368436 original_url: Original URL for HTTP fallback
437+ headers: HTTP headers for requests
369438 """
370439 self .logger .info (f"Writing image to storage in the background: { src_path } " )
371440 try :
@@ -391,7 +460,9 @@ def _transfer_bg_thread(
391460 self .logger .info (f"Uploading image to storage: { filename } " )
392461 to_storage .write_from_path (filename , src_path , src_operator )
393462
394- metadata , metadata_json = self ._create_metadata_and_json (src_operator , src_path , file_hash , original_url )
463+ metadata , metadata_json = self ._create_metadata_and_json (
464+ src_operator , src_path , file_hash , original_url , headers
465+ )
395466 metadata_file = filename + ".metadata"
396467 to_storage .write_bytes (metadata_file , metadata_json .encode (errors = "ignore" ))
397468
@@ -414,7 +485,7 @@ def _sha256_file(self, src_operator, src_path) -> str:
414485 return m .hexdigest ()
415486
416487 def _create_metadata_and_json (
417- self , src_operator , src_path , file_hash = None , original_url = None
488+ self , src_operator , src_path , file_hash = None , original_url = None , headers : dict [ str , str ] | None = None
418489 ) -> tuple [Metadata | None , str ]:
419490 """Create a metadata json string from a metadata object"""
420491 metadata = None
@@ -435,7 +506,10 @@ def _create_metadata_and_json(
435506
436507 if original_url and original_url .startswith (("http://" , "https://" )):
437508 try :
438- response = requests .head (original_url )
509+ if headers :
510+ response = requests .head (original_url , headers = headers )
511+ else :
512+ response = requests .head (original_url )
439513
440514 http_metadata = {}
441515 if "content-length" in response .headers :
@@ -610,6 +684,71 @@ def manifest(self):
610684 self ._manifest = FlasherBundleManifestV1Alpha1 .from_string (yaml_str )
611685 return self ._manifest
612686
687+ def _validate_header_dict (self , header_map : dict [str , str ]) -> dict [str , str ]:
688+ token_re = re .compile (r"^[!#$%&'*+\-.^_`|~0-9A-Za-z]+$" )
689+ seen : set [str ] = set ()
690+ for key , value in header_map .items ():
691+ key = key .strip ()
692+ value = value .strip ()
693+ if not key :
694+ raise ArgumentError (f"Invalid header key: '{ key } '" )
695+
696+ if not token_re .match (key ):
697+ raise ArgumentError (f"Invalid header name '{ key } ': must be an HTTP token (RFC7230)" )
698+ if any (c in ("\r " , "\n " ) for c in key ) or any (c in ("\r " , "\n " ) for c in value ):
699+ raise ArgumentError ("Header names/values must not contain CR/LF" )
700+ kl = key .lower ()
701+ if kl in seen :
702+ raise ArgumentError (f"Duplicate header '{ key } '" )
703+ seen .add (kl )
704+ return header_map
705+
706+ def _parse_headers (self , headers : list [str ]) -> dict [str , str ]:
707+ header_map : dict [str , str ] = {}
708+ for h in headers :
709+ if ":" not in h :
710+ raise click .ClickException (f"Invalid header format: { h !r} . Expected 'Key: Value'." )
711+
712+ key , value = h .split (":" , 1 )
713+ header_map [key .strip ()] = value .strip ()
714+
715+ try :
716+ return self ._validate_header_dict (header_map )
717+ except ArgumentError as e :
718+ raise click .ClickException (str (e )) from e
719+
720+ def _prepare_headers (self , headers : dict [str , str ] | None , bearer_token : str | None ) -> str :
721+ all_headers = headers .copy () if headers else {}
722+ if bearer_token :
723+ if any (k .lower () == "authorization" for k in all_headers .keys ()):
724+ self .logger .warning ("Authorization header provided - ignoring bearer token" )
725+ else :
726+ all_headers ["Authorization" ] = f"Bearer { bearer_token } "
727+
728+ if bearer_token and "Authorization" not in (headers or {}):
729+ auth_header = {"Authorization" : all_headers ["Authorization" ]}
730+ self ._validate_header_dict (auth_header )
731+
732+ return self ._curl_header_args (all_headers )
733+
734+ def _validate_bearer_token (self , token : str | None ) -> str | None :
735+ if token is None :
736+ return None
737+
738+ token = token .strip ()
739+ if not token :
740+ raise click .ClickException ("Bearer token cannot be empty" )
741+
742+ # RFC 6750 allows token68 format (base64url-encoded) or other token formats
743+ # Basic validation: printable ASCII excluding whitespace and special chars that could cause issues
744+ if not all (32 < ord (c ) < 127 and c not in ' "\\ ' for c in token ):
745+ raise click .ClickException ("Bearer token contains invalid characters" )
746+
747+ if len (token ) > 4096 :
748+ raise click .ClickException ("Bearer token is too long (max 4096 characters)" )
749+
750+ return token
751+
613752 def cli (self ):
614753 @click .group
615754 def base ():
@@ -629,6 +768,17 @@ def base():
629768 @click .option ("--force-flash-bundle" , type = str , help = "Force use of a specific flasher OCI bundle" )
630769 @click .option ("--cacert" , type = click .Path (exists = True , dir_okay = False ), help = "CA certificate to use for HTTPS" )
631770 @click .option ("--insecure-tls" , is_flag = True , help = "Skip TLS certificate verification" )
771+ @click .option (
772+ "--header" ,
773+ "header" ,
774+ multiple = True ,
775+ help = "Custom HTTP header in 'Key: Value' format" ,
776+ )
777+ @click .option (
778+ "--bearer" ,
779+ type = str ,
780+ help = "Bearer token for HTTP authentication" ,
781+ )
632782 @debug_console_option
633783 def flash (
634784 file ,
@@ -640,6 +790,8 @@ def flash(
640790 force_flash_bundle ,
641791 cacert ,
642792 insecure_tls ,
793+ header ,
794+ bearer ,
643795 ):
644796 """Flash image to DUT from file"""
645797 if os_image_checksum_file and os .path .exists (os_image_checksum_file ):
@@ -648,13 +800,18 @@ def flash(
648800 self .logger .info (f"Read checksum from file: { os_image_checksum } " )
649801
650802 self .set_console_debug (console_debug )
803+
804+ headers = self ._parse_headers (header ) if header else None
805+
651806 self .flash (
652807 file ,
653808 partition = target ,
654809 force_exporter_http = force_exporter_http ,
655810 force_flash_bundle = force_flash_bundle ,
656811 cacert_file = cacert ,
657812 insecure_tls = insecure_tls ,
813+ headers = headers ,
814+ bearer_token = bearer ,
658815 )
659816
660817 @base .command ()
0 commit comments