11# from __future__ import annotations
22
33import json
4- import os
54import time
65import threading
76
109from urllib3 import Retry
1110
1211from google .oauth2 import service_account
12+ from google .oauth2 .credentials import Credentials
1313import google .auth .transport .requests
1414
1515from pyfcm .errors import (
1616 AuthenticationError ,
1717 InvalidDataError ,
18- FCMError ,
1918 FCMSenderIdMismatchError ,
2019 FCMServerError ,
2120 FCMNotRegisteredError ,
2524
2625
2726class BaseAPI (object ):
28- FCM_END_POINT = "https://fcm.googleapis.com/v1/projects"
27+ FCM_END_POINT_BASE = "https://fcm.googleapis.com/v1/projects"
2928
3029 def __init__ (
3130 self ,
32- service_account_file : str ,
33- project_id : str ,
34- credentials = None ,
35- proxy_dict = None ,
36- env = None ,
31+ service_account_file : str | None = None ,
32+ project_id : str | None = None ,
33+ credentials : Credentials | None = None ,
34+ proxy_dict : dict | None = None ,
35+ env : str | None = None ,
3736 json_encoder = None ,
3837 adapter = None ,
3938 ):
@@ -48,25 +47,23 @@ def __init__(
4847 json_encoder (BaseJSONEncoder): JSON encoder
4948 adapter (BaseAdapter): adapter instance
5049 """
51- self .service_account_file = service_account_file
52- self .project_id = project_id
53- self .FCM_END_POINT = self .FCM_END_POINT + f"/{ self .project_id } /messages:send"
54- self .FCM_REQ_PROXIES = None
55- self .custom_adapter = adapter
56- self .thread_local = threading .local ()
57- self .credentials = credentials
58-
59- if not service_account_file and not credentials :
50+ if not (service_account_file or credentials ):
6051 raise AuthenticationError (
6152 "Please provide a service account file path or credentials in the constructor"
6253 )
6354
55+ self ._service_account_file = service_account_file
56+ self ._fcm_end_point = None
57+ self ._project_id = project_id
58+ self .credentials = credentials
59+ self .custom_adapter = adapter
60+ self .thread_local = threading .local ()
61+
6462 if (
6563 proxy_dict
6664 and isinstance (proxy_dict , dict )
6765 and (("http" in proxy_dict ) or ("https" in proxy_dict ))
6866 ):
69- self .FCM_REQ_PROXIES = proxy_dict
7067 self .requests_session .proxies .update (proxy_dict )
7168
7269 if env == "app_engine" :
@@ -79,6 +76,23 @@ def __init__(
7976
8077 self .json_encoder = json_encoder
8178
79+ @property
80+ def fcm_end_point (self ) -> str :
81+ if self ._fcm_end_point is not None :
82+ return self ._fcm_end_point
83+ if self .credentials is None :
84+ self ._initialize_credentials ()
85+ # prefer the project ID scoped to the supplied credentials.
86+ # If, for some reason, the credentials do not specify a project id,
87+ # we'll check for an explicitly supplied one, and raise an error otherwise
88+ project_id = getattr (self .credentials , "project_id" , None ) or self ._project_id
89+ if not project_id :
90+ raise AuthenticationError (
91+ "Please provide a project_id either explicitly or through Google credentials."
92+ )
93+ self ._fcm_end_point = self .FCM_END_POINT_BASE + f"/{ project_id } /messages:send"
94+ return self ._fcm_end_point
95+
8296 @property
8397 def requests_session (self ):
8498 if getattr (self .thread_local , "requests_session" , None ) is None :
@@ -101,7 +115,7 @@ def requests_session(self):
101115
102116 def send_request (self , payload = None , timeout = None ):
103117 response = self .requests_session .post (
104- self .FCM_END_POINT , data = payload , timeout = timeout
118+ self .fcm_end_point , data = payload , timeout = timeout
105119 )
106120 if (
107121 "Retry-After" in response .headers
@@ -110,17 +124,21 @@ def send_request(self, payload=None, timeout=None):
110124 sleep_time = int (response .headers ["Retry-After" ])
111125 time .sleep (sleep_time )
112126 return self .send_request (payload , timeout )
127+
128+ if self ._is_access_token_expired (response ):
129+ self .thread_local .token_expiry = 0
130+ return self .send_request (payload , timeout )
131+
113132 return response
114133
115134 def send_async_request (self , params_list , timeout ):
116-
117135 import asyncio
118136 from .async_fcm import fetch_tasks
119137
120138 payloads = [self .parse_payload (** params ) for params in params_list ]
121139 responses = asyncio .new_event_loop ().run_until_complete (
122140 fetch_tasks (
123- end_point = self .FCM_END_POINT ,
141+ end_point = self .fcm_end_point ,
124142 headers = self .request_headers (),
125143 payloads = payloads ,
126144 timeout = timeout ,
@@ -129,25 +147,56 @@ def send_async_request(self, params_list, timeout):
129147
130148 return responses
131149
150+ def _is_access_token_expired (self , response ):
151+ """
152+ Check if the response indicates an expired access token
153+
154+ Args:
155+ response: HTTP response object
156+
157+ Returns:
158+ bool: True if access token is expired, False otherwise
159+ """
160+ if response .status_code != 401 :
161+ return False
162+
163+ try :
164+ error_response = response .json ()
165+ error_details = error_response .get ("error" , {}).get ("details" , [])
166+ for detail in error_details :
167+ if detail .get ("reason" ) == "ACCESS_TOKEN_EXPIRED" :
168+ return True
169+ except (ValueError , AttributeError ):
170+ pass
171+
172+ return False
173+
174+ def _initialize_credentials (self ):
175+ """
176+ Initialize credentials and FCM endpoint if not already initialized.
177+ """
178+ if self .credentials is None :
179+ self .credentials = service_account .Credentials .from_service_account_file (
180+ self ._service_account_file ,
181+ scopes = ["https://www.googleapis.com/auth/firebase.messaging" ],
182+ )
183+ self ._service_account_file = None
184+
132185 def _get_access_token (self ):
133186 """
134187 Generates access token from credentials.
135188 If token expires then new access token is generated.
136189 Returns:
137190 str: Access token
138191 """
192+ if self .credentials is None :
193+ self ._initialize_credentials ()
194+
139195 # get OAuth 2.0 access token
140196 try :
141- if self .service_account_file :
142- credentials = service_account .Credentials .from_service_account_file (
143- self .service_account_file ,
144- scopes = ["https://www.googleapis.com/auth/firebase.messaging" ],
145- )
146- else :
147- credentials = self .credentials
148197 request = google .auth .transport .requests .Request ()
149- credentials .refresh (request )
150- return credentials .token
198+ self . credentials .refresh (request )
199+ return self . credentials .token
151200 except Exception as e :
152201 raise InvalidDataError (e )
153202
@@ -195,7 +244,6 @@ def parse_response(self, response):
195244 FCMSenderIdMismatchError: the authenticated sender is different from the sender registered to the token
196245 FCMNotRegisteredError: device token is missing, not registered, or invalid
197246 """
198-
199247 if response .status_code == 200 :
200248 if (
201249 "content-length" in response .headers
@@ -221,10 +269,11 @@ def parse_response(self, response):
221269 raise FCMNotRegisteredError ("Token not registered" )
222270 else :
223271 raise FCMServerError (
224- f"FCM server error: Unexpected status code { response .status_code } . The server might be temporarily unavailable."
272+ f"FCM server error: Unexpected status code { response .status_code } . "
273+ "The server might be temporarily unavailable."
225274 )
226275
227- def parse_payload (
276+ def parse_payload ( # noqa: C901
228277 self ,
229278 fcm_token = None ,
230279 notification_title = None ,
0 commit comments