diff --git a/cla-backend-go/cmd/s3_upload/main.go b/cla-backend-go/cmd/s3_upload/main.go index ca47cded8..0ffc99a61 100644 --- a/cla-backend-go/cmd/s3_upload/main.go +++ b/cla-backend-go/cmd/s3_upload/main.go @@ -57,7 +57,7 @@ func init() { if err != nil { log.Fatal(err) } - signService = sign.NewService("", "", companyRepo, nil, nil, nil, nil, configFile.DocuSignPrivateKey, nil, nil, nil, nil, githubOrgService, nil, "", "", nil, nil, nil, nil, nil) + signService = sign.NewService("", "", companyRepo, nil, nil, nil, nil, configFile.DocuSignPrivateKey, nil, nil, nil, nil, githubOrgService, nil, "", "", nil, nil, nil, nil, nil, nil, configFile.SSS.Required) // projectRepo = repository.NewRepository(awsSession, stage, nil, nil, nil) utils.SetS3Storage(awsSession, configFile.SignatureFilesBucket) } diff --git a/cla-backend-go/cmd/server.go b/cla-backend-go/cmd/server.go index 678035e70..0d377c0dc 100644 --- a/cla-backend-go/cmd/server.go +++ b/cla-backend-go/cmd/server.go @@ -83,6 +83,7 @@ import ( "github.com/linuxfoundation/easycla/cla-backend-go/api_logs" "github.com/linuxfoundation/easycla/cla-backend-go/signatures" + "github.com/linuxfoundation/easycla/cla-backend-go/sss" "github.com/linuxfoundation/easycla/cla-backend-go/telemetry" v2Signatures "github.com/linuxfoundation/easycla/cla-backend-go/v2/signatures" @@ -448,7 +449,31 @@ func server(localMode bool) http.Handler { v2GithubActivityService := v2GithubActivity.NewService(gitV1Repository, githubOrganizationsRepo, eventsService, autoEnableService, emailService) v2ClaGroupService := cla_groups.NewService(v1ProjectService, templateService, v1ProjectClaGroupRepo, v1ClaManagerService, v1SignaturesService, metricsRepo, gerritService, v1RepositoriesService, eventsService) - v2SignService := sign.NewService(configFile.ClaAPIV4Base, configFile.ClaV1ApiURL, v1CompanyRepo, v1CLAGroupRepo, v1ProjectClaGroupRepo, v1CompanyService, v2ClaGroupService, configFile.DocuSignPrivateKey, usersService, v1SignaturesService, storeRepository, v1RepositoriesService, githubOrganizationsService, gitlabOrganizationsService, configFile.CLALandingPage, configFile.CLALogoURL, emailService, eventsService, gitlabActivityService, gitlabApp, gerritService) + + // Initialize SSS (Sanctions Screening Service) client if configured + var sssClient *sss.Client + if configFile.SSS.BaseURL != "" && configFile.SSS.Auth0Domain != "" && configFile.SSS.Auth0ClientID != "" && configFile.SSS.Auth0ClientSecret != "" && configFile.SSS.Auth0Audience != "" { + sssTimeout := time.Duration(configFile.SSS.RequestTimeoutSec) * time.Second + if sssTimeout <= 0 { + sssTimeout = 30 * time.Second // default timeout + } + sssConfig := sss.SSSConfig{ + BaseURL: configFile.SSS.BaseURL, + Auth0Domain: configFile.SSS.Auth0Domain, + Auth0ClientID: configFile.SSS.Auth0ClientID, + Auth0ClientSecret: configFile.SSS.Auth0ClientSecret, + Auth0Audience: configFile.SSS.Auth0Audience, + Timeout: sssTimeout, + } + var sssErr error + sssClient, sssErr = sss.NewClient(sssConfig) + if sssErr != nil { + log.WithFields(f).WithError(sssErr).Warnf("failed to initialize SSS client, screening will be unavailable: %v", sssErr) + sssClient = nil + } + } + + v2SignService := sign.NewService(configFile.ClaAPIV4Base, configFile.ClaV1ApiURL, v1CompanyRepo, v1CLAGroupRepo, v1ProjectClaGroupRepo, v1CompanyService, v2ClaGroupService, configFile.DocuSignPrivateKey, usersService, v1SignaturesService, storeRepository, v1RepositoriesService, githubOrganizationsService, gitlabOrganizationsService, configFile.CLALandingPage, configFile.CLALogoURL, emailService, eventsService, gitlabActivityService, gitlabApp, gerritService, sssClient, configFile.SSS.Required) sessionStore, err := dynastore.New(dynastore.Path("/"), dynastore.HTTPOnly(), dynastore.TableName(configFile.SessionStoreTableName), dynastore.DynamoDB(dynamodb.New(awsSession))) if err != nil { diff --git a/cla-backend-go/company/repository.go b/cla-backend-go/company/repository.go index 0d848bc71..7784d1372 100644 --- a/cla-backend-go/company/repository.go +++ b/cla-backend-go/company/repository.go @@ -53,6 +53,7 @@ type IRepository interface { //nolint ApproveCompanyAccessRequest(ctx context.Context, companyInviteID string) error RejectCompanyAccessRequest(ctx context.Context, companyInviteID string) error UpdateCompanyAccessList(ctx context.Context, companyID string, companyACL []string) error + UpdateCompanySanctionStatus(ctx context.Context, companyID string, sanctioned bool) error IsCCLAEnabledForCompany(ctx context.Context, companyID string) (bool, error) } @@ -1276,7 +1277,67 @@ func (repo repository) UpdateCompanyAccessList(ctx context.Context, companyID st return nil } -// CreateCompany creates a new company record +// UpdateCompanySanctionStatus updates the is_sanctioned flag for a company. +// It only performs the update if the value has changed to avoid unnecessary DynamoDB writes. +func (repo repository) UpdateCompanySanctionStatus(ctx context.Context, companyID string, sanctioned bool) error { + f := logrus.Fields{ + "functionName": "company.repository.UpdateCompanySanctionStatus", + utils.XREQUESTID: ctx.Value(utils.XREQUESTID), + "companyID": companyID, + "sanctioned": sanctioned, + } + + // Fetch current company to check if value has changed + currentCompany, err := repo.GetCompany(ctx, companyID) + if err != nil { + log.WithFields(f).Warnf("unable to fetch current company record to check sanction status, error: %v", err) + return err + } + if currentCompany == nil { + return fmt.Errorf("company not found: %s", companyID) + } + + // Avoid unnecessary writes - only update if value has changed + if currentCompany.IsSanctioned == sanctioned { + log.WithFields(f).Debugf("sanction status unchanged (current=%v, new=%v), skipping update", currentCompany.IsSanctioned, sanctioned) + return nil + } + + log.WithFields(f).Debugf("updating sanction status from %v to %v", currentCompany.IsSanctioned, sanctioned) + + _, now := utils.CurrentTime() + + input := &dynamodb.UpdateItemInput{ + ExpressionAttributeNames: map[string]*string{ + "#S": aws.String("is_sanctioned"), + "#M": aws.String("date_modified"), + }, + ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{ + ":s": { + BOOL: aws.Bool(sanctioned), + }, + ":m": { + S: aws.String(now), + }, + }, + TableName: aws.String(repo.companyTableName), + Key: map[string]*dynamodb.AttributeValue{ + "company_id": { + S: aws.String(companyID), + }, + }, + UpdateExpression: aws.String("SET #S = :s, #M = :m"), + } + + _, err = repo.dynamoDBClient.UpdateItem(input) + if err != nil { + log.WithFields(f).Warnf("error updating company sanction status, error: %v", err) + return err + } + + return nil +} + func (repo repository) CreateCompany(ctx context.Context, in *models.Company) (*models.Company, error) { f := logrus.Fields{ "functionName": "company.repository.CreateCompany", diff --git a/cla-backend-go/config/config.go b/cla-backend-go/config/config.go index 801e04240..37b96d7f0 100644 --- a/cla-backend-go/config/config.go +++ b/cla-backend-go/config/config.go @@ -98,6 +98,20 @@ type Config struct { // DocuSignPrivateKey is the private key for the DocuSign API DocuSignPrivateKey string `json:"docuSignPrivateKey"` + + // SSS (Sanctions Screening Service) configuration + SSS SSS `json:"sss"` +} + +// SSS model for Sanctions Screening Service configuration +type SSS struct { + BaseURL string `json:"base_url"` + Auth0Domain string `json:"auth0_domain"` + Auth0ClientID string `json:"auth0_client_id"` + Auth0ClientSecret string `json:"auth0_client_secret"` + Auth0Audience string `json:"auth0_audience"` + RequestTimeoutSec int `json:"request_timeout_sec"` + Required bool `json:"required"` } // Auth0 model diff --git a/cla-backend-go/config/local.go b/cla-backend-go/config/local.go index edccc145e..147bdeaa6 100644 --- a/cla-backend-go/config/local.go +++ b/cla-backend-go/config/local.go @@ -23,6 +23,7 @@ func loadLocalConfig(configFilePath string) (Config, error) { } localConfig := Config{} + localConfig.SSS.Required = true err = json.Unmarshal(content, &localConfig) if err != nil { return Config{}, err diff --git a/cla-backend-go/config/ssm.go b/cla-backend-go/config/ssm.go index 21f97cf51..931b4b23a 100644 --- a/cla-backend-go/config/ssm.go +++ b/cla-backend-go/config/ssm.go @@ -45,6 +45,7 @@ func loadSSMConfig(awsSession *session.Session, stage string) Config { //nolint } config := Config{} config.SignatureQueryDefaultValue = "all" + config.SSS.Required = true ssmClient := ssm.New(awsSession) @@ -268,5 +269,14 @@ func loadSSMConfig(awsSession *session.Session, stage string) Config { //nolint } } + sssRequiredKey := fmt.Sprintf("cla-sss-required-%s", stage) + if value, err := getSSMString(ssmClient, sssRequiredKey); err != nil { + log.WithFields(f).WithError(err).Warnf("unable to read optional SSS required flag %s - defaulting to true", sssRequiredKey) + } else if boolVal, err := strconv.ParseBool(value); err != nil { + log.WithFields(f).WithError(err).Warnf("unable to convert %s value to a boolean - defaulting to true", sssRequiredKey) + } else { + config.SSS.Required = boolVal + } + return config } diff --git a/cla-backend-go/v2/sign/helpers.go b/cla-backend-go/v2/sign/helpers.go index db1003604..33f9b089f 100644 --- a/cla-backend-go/v2/sign/helpers.go +++ b/cla-backend-go/v2/sign/helpers.go @@ -145,6 +145,23 @@ func (s service) hasUserSigned(ctx context.Context, user *models.User, projectID log.WithFields(f).WithError(compModelErr).Warnf("problem looking up company: %s", companyID) return &hasSigned, &companyAffiliation, compModelErr } + if companyModel == nil { + compModelErr = fmt.Errorf("company not found: %s", companyID) + log.WithFields(f).WithError(compModelErr).Warnf("company record is nil for company: %s", companyID) + return &hasSigned, &companyAffiliation, compModelErr + } + + // Check if company is sanctioned before allowing ECLA acknowledgement + sanctioned, sanctionErr := s.checkCompanyCompliance(ctx, companyModel) + if sanctionErr != nil { + log.WithFields(f).WithError(sanctionErr).Warnf("failed to check company compliance for company: %s", companyID) + return &hasSigned, &companyAffiliation, sanctionErr + } + if sanctioned { + sanctionedErr := fmt.Errorf("company %s is sanctioned", companyID) + log.WithFields(f).WithError(sanctionedErr).Error("company is sanctioned") + return &hasSigned, &companyAffiliation, sanctionedErr + } // Load the CLA Group - make sure it is valid claGroupModel, claGroupModelErr := s.claGroupService.GetCLAGroup(ctx, projectID) diff --git a/cla-backend-go/v2/sign/service.go b/cla-backend-go/v2/sign/service.go index f5b583d5e..5c65a8fa4 100644 --- a/cla-backend-go/v2/sign/service.go +++ b/cla-backend-go/v2/sign/service.go @@ -15,6 +15,7 @@ import ( "net/url" "strconv" "strings" + "sync" "time" "github.com/go-openapi/strfmt" @@ -28,6 +29,7 @@ import ( "github.com/linuxfoundation/easycla/cla-backend-go/projects_cla_groups" "github.com/linuxfoundation/easycla/cla-backend-go/repositories" "github.com/linuxfoundation/easycla/cla-backend-go/signatures" + "github.com/linuxfoundation/easycla/cla-backend-go/sss" "github.com/linuxfoundation/easycla/cla-backend-go/users" "github.com/linuxfoundation/easycla/cla-backend-go/v2/cla_groups" gitlab_activity "github.com/linuxfoundation/easycla/cla-backend-go/v2/gitlab-activity" @@ -59,6 +61,7 @@ const ( DontLoadRepoDetails = false DocSignFalse = "false" DocusignCompleted = "Completed" + complianceCacheTTL = 5 * time.Minute ) // errors @@ -117,12 +120,22 @@ type service struct { gitlabActivityService gitlab_activity.Service gitlabApp *gitlab_api.App gerritService gerrits.Service + sssClient *sss.Client + sssRequired bool + complianceCache map[string]complianceCacheEntry + complianceCacheMu sync.Mutex +} + +type complianceCacheEntry struct { + sanctioned bool + err error + expiresAt time.Time } // NewService returns an instance of v2 project service func NewService(apiURL, v1API string, compRepo company.IRepository, projectRepo ProjectRepo, pcgRepo projects_cla_groups.Repository, compService company.IService, claGroupService cla_groups.Service, docsignPrivateKey string, userService users.Service, signatureService signatures.SignatureService, storeRepository store.Repository, repositoryService repositories.Service, githubOrgService github_organizations.Service, gitlabOrgService gitlab_organizations.ServiceInterface, claLandingPage string, claLogoURL string, emailTemplateService emails.EmailTemplateService, eventsService events.Service, gitlabActivityService gitlab_activity.Service, gitlabApp *gitlab_api.App, - gerritService gerrits.Service) Service { + gerritService gerrits.Service, sssClient *sss.Client, sssRequired bool) Service { return &service{ ClaV4ApiURL: apiURL, ClaV1ApiURL: v1API, @@ -145,6 +158,9 @@ func NewService(apiURL, v1API string, compRepo company.IRepository, projectRepo gitlabApp: gitlabApp, gerritService: gerritService, eventsService: eventsService, + sssClient: sssClient, + sssRequired: sssRequired, + complianceCache: make(map[string]complianceCacheEntry), } } @@ -243,7 +259,19 @@ func (s *service) RequestCorporateSignature(ctx context.Context, lfUsername stri } // 1.5 Check if company is sanctioned - if comp != nil && comp.IsSanctioned { + if comp == nil { + if input.CompanySfid != nil { + return nil, fmt.Errorf("company not found for SFID %s", *input.CompanySfid) + } + return nil, fmt.Errorf("company not found") + } + + sanctioned, sanctionErr := s.checkCompanyCompliance(ctx, comp) + if sanctionErr != nil { + log.WithFields(f).WithError(sanctionErr).Error("failed to check company compliance") + return nil, sanctionErr + } + if sanctioned { if input.CompanySfid != nil { err = fmt.Errorf("company %s is sanctioned", *input.CompanySfid) } else { @@ -2936,3 +2964,255 @@ func (s *service) GetUserActiveSignature(ctx context.Context, userID string) (*m UserID: userID, }, nil } + +// checkCompanyCompliance queries the Sanctions Screening Service for the given company +// and persists the result. Returns (sanctioned, error). A nil sssClient is a no-op. +func (s *service) checkCompanyCompliance(ctx context.Context, company *v1Models.Company) (bool, error) { + f := logrus.Fields{ + "functionName": "sign.checkCompanyCompliance", + utils.XREQUESTID: ctx.Value(utils.XREQUESTID), + "companyID": company.CompanyID, + "companyName": company.CompanyName, + } + + // Check if company is already manually sanctioned - if so, always block + if company.IsSanctioned { + log.WithFields(f).Warnf("company is manually sanctioned, blocking") + return true, nil + } + + cacheKey := s.complianceCacheKey(company) + if cached, ok := s.getComplianceCache(cacheKey); ok { + log.WithFields(f).Debugf("using cached compliance result for organization/company: %s", cacheKey) + return cached.sanctioned, cached.err + } + + if s.sssClient == nil { + log.WithFields(f).Debug("SSS client not configured, skipping live compliance check") + s.setComplianceCache(cacheKey, false, nil) + return false, nil + } + + // Fetch org from organization service to get the website/domain. + orgClient := organizationService.GetClient() + if orgClient == nil { + resultErr := fmt.Errorf("checkCompanyCompliance: organization service client is not configured") + if !s.sssRequired { + log.WithFields(f).WithError(resultErr).Warn("SSS is not required; continuing without live compliance result") + s.setComplianceCache(cacheKey, false, nil) + return false, nil + } + s.setComplianceCache(cacheKey, false, resultErr) + return false, resultErr + } + org, err := orgClient.GetOrganization(ctx, company.CompanyExternalID) + if err != nil { + log.WithFields(f).WithError(err).Warnf("failed to get organization %s for domain resolution", company.CompanyExternalID) + resultErr := fmt.Errorf("checkCompanyCompliance: failed to get organization %s: %w", company.CompanyExternalID, err) + if !s.sssRequired { + log.WithFields(f).WithError(resultErr).Warn("SSS is not required; continuing without live compliance result") + s.setComplianceCache(cacheKey, false, nil) + return false, nil + } + s.setComplianceCache(cacheKey, false, resultErr) + return false, resultErr + } + if org == nil { + log.WithFields(f).Warnf("organization record is nil for %s", company.CompanyExternalID) + resultErr := fmt.Errorf("checkCompanyCompliance: organization record is nil for %s", company.CompanyExternalID) + if !s.sssRequired { + log.WithFields(f).WithError(resultErr).Warn("SSS is not required; continuing without live compliance result") + s.setComplianceCache(cacheKey, false, nil) + return false, nil + } + s.setComplianceCache(cacheKey, false, resultErr) + return false, resultErr + } + + // Resolve domain: prefer Domains field, fallback to Link field + domain := s.resolveDomain(f, org) + if domain == "" { + log.WithFields(f).Warnf("unable to resolve domain for organization %s; skipping SSS check", company.CompanyExternalID) + s.setComplianceCache(cacheKey, false, nil) + return false, nil + } + + log.WithFields(f).Debugf("resolved domain: %s for SSS check", domain) + + req := sss.OrganizationStatusRequest{ + Domain: domain, + OrgName: company.CompanyName, + } + if strings.HasPrefix(company.CompanyExternalID, "001") { + req.SFDCID = company.CompanyExternalID + } + + result, err := s.sssClient.GetOrganizationStatus(ctx, req) + if err != nil { + blocked, handledErr := s.handleSSSError(f, company.CompanyID, err) + s.setComplianceCache(cacheKey, blocked, handledErr) + return blocked, handledErr + } + + sanctioned := result.Status == sss.StatusFlagged + + // Only persist if flagged (never clear a manual sanction via SSS clean result) + if sanctioned { + log.WithFields(f).Warnf("SSS returned flagged status for company %s, persisting sanction", company.CompanyID) + if persistErr := s.companyRepo.UpdateCompanySanctionStatus(ctx, company.CompanyID, true); persistErr != nil { + log.WithFields(f).WithError(persistErr).Warnf("failed to persist sanction status for company %s", company.CompanyID) + resultErr := fmt.Errorf("failed to persist sanction status for company %s: %w", company.CompanyID, persistErr) + s.setComplianceCache(cacheKey, false, resultErr) + return false, resultErr + } + } else { + log.WithFields(f).Debugf("SSS returned clean status for company %s", company.CompanyID) + } + + // Return combined result: blocked if manually sanctioned OR sss flagged + blocked := company.IsSanctioned || sanctioned + s.setComplianceCache(cacheKey, blocked, nil) + return blocked, nil +} + +func (s *service) complianceCacheKey(company *v1Models.Company) string { + if company == nil { + return "" + } + if key := strings.TrimSpace(company.CompanyExternalID); key != "" { + return key + } + return strings.TrimSpace(company.CompanyID) +} + +func (s *service) getComplianceCache(key string) (complianceCacheEntry, bool) { + if key == "" || s.complianceCache == nil { + return complianceCacheEntry{}, false + } + s.complianceCacheMu.Lock() + defer s.complianceCacheMu.Unlock() + + entry, ok := s.complianceCache[key] + if !ok { + return complianceCacheEntry{}, false + } + if time.Now().After(entry.expiresAt) { + delete(s.complianceCache, key) + return complianceCacheEntry{}, false + } + return entry, true +} + +func (s *service) setComplianceCache(key string, sanctioned bool, err error) { + if key == "" { + return + } + s.complianceCacheMu.Lock() + defer s.complianceCacheMu.Unlock() + if s.complianceCache == nil { + s.complianceCache = make(map[string]complianceCacheEntry) + } + s.complianceCache[key] = complianceCacheEntry{ + sanctioned: sanctioned, + err: err, + expiresAt: time.Now().Add(complianceCacheTTL), + } +} + +// resolveDomain attempts to resolve the domain for an organization. +// Priority: 1) Domains field from org (if available), 2) Parse Link field +func (s *service) resolveDomain(f logrus.Fields, org interface{}) string { + if domainStruct, ok := org.(interface{ GetDomains() []string }); ok { + domains := domainStruct.GetDomains() + if len(domains) > 0 && strings.TrimSpace(domains[0]) != "" { + domain := strings.TrimSpace(domains[0]) + domain = strings.TrimPrefix(domain, "www.") + log.WithFields(f).Debugf("resolved domain from Domains field: %s", domain) + return domain + } + } + + if linkStruct, ok := org.(interface{ GetLink() string }); ok { + link := strings.TrimSpace(linkStruct.GetLink()) + if link != "" { + domain := s.parseDomain(link) + if domain != "" { + return domain + } + } + } + + return "" +} + +// parseDomain extracts the hostname from a URL string. +// If the URL lacks a scheme, it prepends https:// for parsing. +func (s *service) parseDomain(urlStr string) string { + urlStr = strings.TrimSpace(urlStr) + if urlStr == "" { + return "" + } + + // Prepend https:// if no scheme is present + if !strings.Contains(urlStr, "://") { + urlStr = "https://" + urlStr + } + + u, err := url.Parse(urlStr) + if err != nil { + return "" + } + + hostname := u.Hostname() + if hostname == "" { + return "" + } + + // Strip leading www. + hostname = strings.TrimPrefix(hostname, "www.") + return hostname +} + +// handleSSSError differentiates between various SSS error types and logs appropriately. +// Returns a non-nil error for SSS failures that should block signing. +func (s *service) handleSSSError(f logrus.Fields, companyID string, err error) (bool, error) { + var badReqErr *sss.BadRequestError + var authErr *sss.AuthError + var retryErr *sss.RetryableError + var notFoundErr *sss.NotFoundError + var timeoutErr *sss.TimeoutError + allowWhenOptional := func(message string) (bool, error) { + if s.sssRequired { + return false, fmt.Errorf("%s for company %s: %w", message, companyID, err) + } + log.WithFields(f).WithError(err).Warnf("%s for company %s; SSS is not required, continuing", message, companyID) + return false, nil + } + + switch { + case errors.As(err, &timeoutErr): + log.WithFields(f).WithError(err).Warnf("SSS request timed out for company %s", companyID) + return allowWhenOptional("SSS screening unavailable (timeout)") + + case errors.As(err, &authErr): + log.WithFields(f).WithError(err).Errorf("SSS authentication/configuration error for company %s", companyID) + return allowWhenOptional("SSS authentication error (check configuration)") + + case errors.As(err, &retryErr): + log.WithFields(f).WithError(err).Warnf("SSS request failed with retryable error for company %s", companyID) + return allowWhenOptional("SSS screening unavailable (transient failure)") + + case errors.As(err, ¬FoundErr): + log.WithFields(f).WithError(err).Warnf("SSS organization not found for company %s", companyID) + // Not found is not a blocking error - proceed without SSS result + return false, nil + + case errors.As(err, &badReqErr): + log.WithFields(f).WithError(err).Warnf("SSS bad request for company %s", companyID) + return false, fmt.Errorf("SSS bad request for company %s: %w", companyID, err) + + default: + log.WithFields(f).WithError(err).Warnf("SSS request failed with unexpected error for company %s", companyID) + return allowWhenOptional("SSS request failed") + } +} diff --git a/cla-backend-go/v2/sign/service_sss_test.go b/cla-backend-go/v2/sign/service_sss_test.go new file mode 100644 index 000000000..99ad505f1 --- /dev/null +++ b/cla-backend-go/v2/sign/service_sss_test.go @@ -0,0 +1,102 @@ +// Copyright The Linux Foundation and each contributor to CommunityBridge. +// SPDX-License-Identifier: MIT + +package sign + +import ( + "errors" + "testing" + "time" + + "github.com/linuxfoundation/easycla/cla-backend-go/gen/v1/models" + "github.com/linuxfoundation/easycla/cla-backend-go/sss" + "github.com/sirupsen/logrus" +) + +type testOrg struct { + domains []string + link string +} + +func (o testOrg) GetDomains() []string { + return o.domains +} + +func (o testOrg) GetLink() string { + return o.link +} + +func TestResolveDomainPrefersDomains(t *testing.T) { + svc := &service{} + + got := svc.resolveDomain(logrus.Fields{}, testOrg{ + domains: []string{"www.example.com"}, + link: "https://fallback.example.org/path", + }) + + if got != "example.com" { + t.Fatalf("expected domain from Domains field, got %q", got) + } +} + +func TestResolveDomainFallsBackToParsedLink(t *testing.T) { + svc := &service{} + + got := svc.resolveDomain(logrus.Fields{}, testOrg{ + link: "www.example.org/path?query=1", + }) + + if got != "example.org" { + t.Fatalf("expected parsed Link hostname, got %q", got) + } +} + +func TestHandleSSSErrorRequiredBlocksAvailabilityErrors(t *testing.T) { + svc := &service{sssRequired: true} + + _, err := svc.handleSSSError(logrus.Fields{}, "company-id", &sss.RetryableError{Message: "unavailable"}) + if err == nil { + t.Fatal("expected required SSS retryable error to block") + } +} + +func TestHandleSSSErrorOptionalAllowsAvailabilityErrors(t *testing.T) { + svc := &service{sssRequired: false} + + blocked, err := svc.handleSSSError(logrus.Fields{}, "company-id", &sss.AuthError{Message: "auth failed"}) + if err != nil { + t.Fatalf("expected optional SSS auth error to continue, got %v", err) + } + if blocked { + t.Fatal("expected optional SSS auth error not to block") + } +} + +func TestComplianceCacheKeyPrefersExternalID(t *testing.T) { + svc := &service{} + + got := svc.complianceCacheKey(&models.Company{ + CompanyID: "internal-id", + CompanyExternalID: "external-id", + }) + + if got != "external-id" { + t.Fatalf("expected external id cache key, got %q", got) + } +} + +func TestComplianceCacheExpires(t *testing.T) { + svc := &service{ + complianceCache: map[string]complianceCacheEntry{ + "company-id": { + sanctioned: true, + err: errors.New("cached"), + expiresAt: time.Now().Add(-time.Second), + }, + }, + } + + if _, ok := svc.getComplianceCache("company-id"); ok { + t.Fatal("expected expired cache entry to be ignored") + } +}