diff --git a/core/group/filter.go b/core/group/filter.go index 314423e15..6638c8bfc 100644 --- a/core/group/filter.go +++ b/core/group/filter.go @@ -1,10 +1,16 @@ package group +import "github.com/raystack/frontier/core/authenticate" + type Filter struct { // only one filter gets applied at a time SU bool // super user + // Principal restricts results to groups the principal has a policy on. + // Intersected with GroupIDs when both are set. + Principal *authenticate.Principal + OrganizationID string State State WithMemberCount bool diff --git a/core/group/mocks/membership_service.go b/core/group/mocks/membership_service.go index b425239ae..cf94cadab 100644 --- a/core/group/mocks/membership_service.go +++ b/core/group/mocks/membership_service.go @@ -5,6 +5,8 @@ package mocks import ( context "context" + authenticate "github.com/raystack/frontier/core/authenticate" + mock "github.com/stretchr/testify/mock" ) @@ -21,6 +23,66 @@ func (_m *MembershipService) EXPECT() *MembershipService_Expecter { return &MembershipService_Expecter{mock: &_m.Mock} } +// ListGroupsByPrincipal provides a mock function with given fields: ctx, principal, orgID +func (_m *MembershipService) ListGroupsByPrincipal(ctx context.Context, principal authenticate.Principal, orgID string) ([]string, error) { + ret := _m.Called(ctx, principal, orgID) + + if len(ret) == 0 { + panic("no return value specified for ListGroupsByPrincipal") + } + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, string) ([]string, error)); ok { + return rf(ctx, principal, orgID) + } + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, string) []string); ok { + r0 = rf(ctx, principal, orgID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, authenticate.Principal, string) error); ok { + r1 = rf(ctx, principal, orgID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MembershipService_ListGroupsByPrincipal_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListGroupsByPrincipal' +type MembershipService_ListGroupsByPrincipal_Call struct { + *mock.Call +} + +// ListGroupsByPrincipal is a helper method to define mock.On call +// - ctx context.Context +// - principal authenticate.Principal +// - orgID string +func (_e *MembershipService_Expecter) ListGroupsByPrincipal(ctx interface{}, principal interface{}, orgID interface{}) *MembershipService_ListGroupsByPrincipal_Call { + return &MembershipService_ListGroupsByPrincipal_Call{Call: _e.mock.On("ListGroupsByPrincipal", ctx, principal, orgID)} +} + +func (_c *MembershipService_ListGroupsByPrincipal_Call) Run(run func(ctx context.Context, principal authenticate.Principal, orgID string)) *MembershipService_ListGroupsByPrincipal_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(authenticate.Principal), args[2].(string)) + }) + return _c +} + +func (_c *MembershipService_ListGroupsByPrincipal_Call) Return(_a0 []string, _a1 error) *MembershipService_ListGroupsByPrincipal_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MembershipService_ListGroupsByPrincipal_Call) RunAndReturn(run func(context.Context, authenticate.Principal, string) ([]string, error)) *MembershipService_ListGroupsByPrincipal_Call { + _c.Call.Return(run) + return _c +} + // OnGroupCreated provides a mock function with given fields: ctx, groupID, orgID, creatorID, creatorType func (_m *MembershipService) OnGroupCreated(ctx context.Context, groupID string, orgID string, creatorID string, creatorType string) error { ret := _m.Called(ctx, groupID, orgID, creatorID, creatorType) diff --git a/core/group/mocks/relation_service.go b/core/group/mocks/relation_service.go index 315ea8498..f052c4e3f 100644 --- a/core/group/mocks/relation_service.go +++ b/core/group/mocks/relation_service.go @@ -129,65 +129,6 @@ func (_c *RelationService_ListRelations_Call) RunAndReturn(run func(context.Cont return _c } -// LookupResources provides a mock function with given fields: ctx, rel -func (_m *RelationService) LookupResources(ctx context.Context, rel relation.Relation) ([]string, error) { - ret := _m.Called(ctx, rel) - - if len(ret) == 0 { - panic("no return value specified for LookupResources") - } - - var r0 []string - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, relation.Relation) ([]string, error)); ok { - return rf(ctx, rel) - } - if rf, ok := ret.Get(0).(func(context.Context, relation.Relation) []string); ok { - r0 = rf(ctx, rel) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]string) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, relation.Relation) error); ok { - r1 = rf(ctx, rel) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// RelationService_LookupResources_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LookupResources' -type RelationService_LookupResources_Call struct { - *mock.Call -} - -// LookupResources is a helper method to define mock.On call -// - ctx context.Context -// - rel relation.Relation -func (_e *RelationService_Expecter) LookupResources(ctx interface{}, rel interface{}) *RelationService_LookupResources_Call { - return &RelationService_LookupResources_Call{Call: _e.mock.On("LookupResources", ctx, rel)} -} - -func (_c *RelationService_LookupResources_Call) Run(run func(ctx context.Context, rel relation.Relation)) *RelationService_LookupResources_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(relation.Relation)) - }) - return _c -} - -func (_c *RelationService_LookupResources_Call) Return(_a0 []string, _a1 error) *RelationService_LookupResources_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *RelationService_LookupResources_Call) RunAndReturn(run func(context.Context, relation.Relation) ([]string, error)) *RelationService_LookupResources_Call { - _c.Call.Return(run) - return _c -} - // NewRelationService creates a new instance of RelationService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewRelationService(t interface { diff --git a/core/group/service.go b/core/group/service.go index bf71334ff..98e149675 100644 --- a/core/group/service.go +++ b/core/group/service.go @@ -21,7 +21,6 @@ import ( type RelationService interface { ListRelations(ctx context.Context, rel relation.Relation) ([]relation.Relation, error) - LookupResources(ctx context.Context, rel relation.Relation) ([]string, error) Delete(ctx context.Context, rel relation.Relation) error } @@ -37,6 +36,7 @@ type PolicyService interface { type MembershipService interface { OnGroupCreated(ctx context.Context, groupID, orgID, creatorID, creatorType string) error + ListGroupsByPrincipal(ctx context.Context, principal authenticate.Principal, orgID string) ([]string, error) } type Service struct { @@ -90,6 +90,23 @@ func (s Service) GetByIDs(ctx context.Context, ids []string) ([]Group, error) { } func (s Service) List(ctx context.Context, flt Filter) ([]Group, error) { + if flt.Principal != nil { + if s.membershipService == nil { + return nil, fmt.Errorf("group: membership service not wired") + } + ids, err := s.membershipService.ListGroupsByPrincipal(ctx, *flt.Principal, flt.OrganizationID) + if err != nil { + return nil, err + } + if len(flt.GroupIDs) > 0 { + ids = utils.Intersection(ids, flt.GroupIDs) + } + if len(ids) == 0 { + return []Group{}, nil + } + flt.GroupIDs = ids + } + if flt.OrganizationID == "" && len(flt.GroupIDs) == 0 && !flt.SU { return nil, ErrInvalidID } @@ -124,45 +141,6 @@ func (s Service) Update(ctx context.Context, grp Group) (Group, error) { return Group{}, ErrInvalidID } -func (s Service) ListByUser(ctx context.Context, principal authenticate.Principal, flt Filter) ([]Group, error) { - subjectID, subjectType := principal.ResolveSubject() - subjectIDs, err := s.relationService.LookupResources(ctx, relation.Relation{ - Object: relation.Object{Namespace: schema.GroupNamespace}, - Subject: relation.Subject{Namespace: subjectType, ID: subjectID}, - RelationName: schema.MembershipPermission, - }) - if err != nil { - return nil, err - } - subjectIDs, err = s.intersectPATScope(ctx, principal, schema.GroupNamespace, subjectIDs) - if err != nil { - return nil, err - } - if len(subjectIDs) == 0 { - // no groups - return nil, nil - } - flt.GroupIDs = subjectIDs - return s.List(ctx, flt) -} - -// intersectPATScope narrows resource IDs to only those the PAT is scoped to. -func (s Service) intersectPATScope(ctx context.Context, principal authenticate.Principal, - namespace string, resourceIDs []string) ([]string, error) { - if principal.PAT == nil || len(resourceIDs) == 0 { - return resourceIDs, nil - } - patIDs, err := s.relationService.LookupResources(ctx, relation.Relation{ - Object: relation.Object{Namespace: namespace}, - Subject: relation.Subject{ID: principal.PAT.ID, Namespace: schema.PATPrincipal}, - RelationName: schema.GetPermission, - }) - if err != nil { - return nil, err - } - return utils.Intersection(resourceIDs, patIDs), nil -} - // ListByOrganization will be useful for nested groups but we don't do that at the moment // so it will not be directly used func (s Service) ListByOrganization(ctx context.Context, id string) ([]Group, error) { diff --git a/core/group/service_test.go b/core/group/service_test.go index 130477dd3..1a8c0efd7 100644 --- a/core/group/service_test.go +++ b/core/group/service_test.go @@ -12,7 +12,6 @@ import ( "github.com/raystack/frontier/core/group" "github.com/raystack/frontier/core/group/mocks" "github.com/raystack/frontier/core/policy" - "github.com/raystack/frontier/core/relation" "github.com/raystack/frontier/core/user" pat "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/internal/bootstrap/schema" @@ -251,106 +250,105 @@ func TestService_Update(t *testing.T) { }) } -func TestService_ListByUser(t *testing.T) { +func TestService_List_PrincipalFilter(t *testing.T) { ctx := context.Background() - t.Run("should resolve PAT to user and intersect with PAT group scope", func(t *testing.T) { + t.Run("user principal — narrows GroupIDs via membership shim", func(t *testing.T) { mockRepo := mocks.NewRepository(t) mockRelationSvc := mocks.NewRelationService(t) mockAuthnSvc := mocks.NewAuthnService(t) mockPolicySvc := mocks.NewPolicyService(t) + mockMembershipSvc := mocks.NewMembershipService(t) svc := group.NewService(mockRepo, mockRelationSvc, mockAuthnSvc, mockPolicySvc) + svc.SetMembershipService(mockMembershipSvc) - // LookupResources for user's group memberships - mockRelationSvc.On("LookupResources", ctx, relation.Relation{ - Object: relation.Object{Namespace: schema.GroupNamespace}, - Subject: relation.Subject{Namespace: schema.UserPrincipal, ID: "user-123"}, - RelationName: schema.MembershipPermission, - }).Return([]string{"group-1", "group-2", "group-3"}, nil).Once() - - // LookupResources for PAT's group scope - mockRelationSvc.On("LookupResources", ctx, relation.Relation{ - Object: relation.Object{Namespace: schema.GroupNamespace}, - Subject: relation.Subject{ID: "pat-456", Namespace: schema.PATPrincipal}, - RelationName: schema.GetPermission, - }).Return([]string{"group-1", "group-3"}, nil).Once() - - // Repo should be called with intersection + principal := authenticate.Principal{ID: "user-123", Type: schema.UserPrincipal} + mockMembershipSvc.EXPECT().ListGroupsByPrincipal(ctx, principal, ""). + Return([]string{"group-1", "group-2"}, nil).Once() mockRepo.On("List", ctx, group.Filter{ - GroupIDs: []string{"group-1", "group-3"}, + Principal: &principal, + GroupIDs: []string{"group-1", "group-2"}, }).Return([]group.Group{ {ID: "group-1", Name: "group-one"}, - {ID: "group-3", Name: "group-three"}, + {ID: "group-2", Name: "group-two"}, }, nil).Once() - result, err := svc.ListByUser(ctx, authenticate.Principal{ - ID: "pat-456", - Type: schema.PATPrincipal, - PAT: &pat.PAT{ID: "pat-456", UserID: "user-123", OrgID: "org-1"}, - }, group.Filter{}) - + result, err := svc.List(ctx, group.Filter{Principal: &principal}) assert.NoError(t, err) assert.Len(t, result, 2) }) - t.Run("should return nil when PAT has no group scope overlap", func(t *testing.T) { + t.Run("Principal + OrganizationID — forwards orgID to shim and repo", func(t *testing.T) { mockRepo := mocks.NewRepository(t) mockRelationSvc := mocks.NewRelationService(t) mockAuthnSvc := mocks.NewAuthnService(t) mockPolicySvc := mocks.NewPolicyService(t) + mockMembershipSvc := mocks.NewMembershipService(t) svc := group.NewService(mockRepo, mockRelationSvc, mockAuthnSvc, mockPolicySvc) + svc.SetMembershipService(mockMembershipSvc) - mockRelationSvc.On("LookupResources", ctx, relation.Relation{ - Object: relation.Object{Namespace: schema.GroupNamespace}, - Subject: relation.Subject{Namespace: schema.UserPrincipal, ID: "user-123"}, - RelationName: schema.MembershipPermission, - }).Return([]string{"group-1"}, nil).Once() - - mockRelationSvc.On("LookupResources", ctx, relation.Relation{ - Object: relation.Object{Namespace: schema.GroupNamespace}, - Subject: relation.Subject{ID: "pat-456", Namespace: schema.PATPrincipal}, - RelationName: schema.GetPermission, - }).Return([]string{"group-2"}, nil).Once() - - result, err := svc.ListByUser(ctx, authenticate.Principal{ - ID: "pat-456", - Type: schema.PATPrincipal, - PAT: &pat.PAT{ID: "pat-456", UserID: "user-123", OrgID: "org-1"}, - }, group.Filter{}) + principal := authenticate.Principal{ID: "user-123", Type: schema.UserPrincipal} + mockMembershipSvc.EXPECT().ListGroupsByPrincipal(ctx, principal, "org-1"). + Return([]string{"group-1"}, nil).Once() + mockRepo.On("List", ctx, group.Filter{ + Principal: &principal, + OrganizationID: "org-1", + GroupIDs: []string{"group-1"}, + }).Return([]group.Group{{ID: "group-1", Name: "group-one"}}, nil).Once() + result, err := svc.List(ctx, group.Filter{Principal: &principal, OrganizationID: "org-1"}) assert.NoError(t, err) - assert.Nil(t, result) + assert.Len(t, result, 1) }) - t.Run("should pass through for regular user principal", func(t *testing.T) { + t.Run("PAT principal — shim handles PAT scoping, service stays oblivious", func(t *testing.T) { mockRepo := mocks.NewRepository(t) mockRelationSvc := mocks.NewRelationService(t) mockAuthnSvc := mocks.NewAuthnService(t) mockPolicySvc := mocks.NewPolicyService(t) + mockMembershipSvc := mocks.NewMembershipService(t) svc := group.NewService(mockRepo, mockRelationSvc, mockAuthnSvc, mockPolicySvc) + svc.SetMembershipService(mockMembershipSvc) - mockRelationSvc.On("LookupResources", ctx, relation.Relation{ - Object: relation.Object{Namespace: schema.GroupNamespace}, - Subject: relation.Subject{Namespace: schema.UserPrincipal, ID: "user-123"}, - RelationName: schema.MembershipPermission, - }).Return([]string{"group-1", "group-2"}, nil).Once() - + principal := authenticate.Principal{ + ID: "pat-456", + Type: schema.PATPrincipal, + PAT: &pat.PAT{ID: "pat-456", UserID: "user-123", OrgID: "org-1"}, + } + mockMembershipSvc.EXPECT().ListGroupsByPrincipal(ctx, principal, ""). + Return([]string{"group-1", "group-3"}, nil).Once() mockRepo.On("List", ctx, group.Filter{ - GroupIDs: []string{"group-1", "group-2"}, + Principal: &principal, + GroupIDs: []string{"group-1", "group-3"}, }).Return([]group.Group{ {ID: "group-1", Name: "group-one"}, - {ID: "group-2", Name: "group-two"}, + {ID: "group-3", Name: "group-three"}, }, nil).Once() - result, err := svc.ListByUser(ctx, authenticate.Principal{ - ID: "user-123", - Type: schema.UserPrincipal, - }, group.Filter{}) - + result, err := svc.List(ctx, group.Filter{Principal: &principal}) assert.NoError(t, err) assert.Len(t, result, 2) }) + + t.Run("empty membership result — short-circuits to empty slice", func(t *testing.T) { + mockRepo := mocks.NewRepository(t) + mockRelationSvc := mocks.NewRelationService(t) + mockAuthnSvc := mocks.NewAuthnService(t) + mockPolicySvc := mocks.NewPolicyService(t) + mockMembershipSvc := mocks.NewMembershipService(t) + + svc := group.NewService(mockRepo, mockRelationSvc, mockAuthnSvc, mockPolicySvc) + svc.SetMembershipService(mockMembershipSvc) + + principal := authenticate.Principal{ID: "user-123", Type: schema.UserPrincipal} + mockMembershipSvc.EXPECT().ListGroupsByPrincipal(ctx, principal, ""). + Return(nil, nil).Once() + + result, err := svc.List(ctx, group.Filter{Principal: &principal}) + assert.NoError(t, err) + assert.Empty(t, result) + }) } diff --git a/core/invitation/mocks/group_service.go b/core/invitation/mocks/group_service.go index 1bab713a0..8a22a9a78 100644 --- a/core/invitation/mocks/group_service.go +++ b/core/invitation/mocks/group_service.go @@ -5,8 +5,6 @@ package mocks import ( context "context" - authenticate "github.com/raystack/frontier/core/authenticate" - group "github.com/raystack/frontier/core/group" mock "github.com/stretchr/testify/mock" @@ -82,29 +80,29 @@ func (_c *GroupService_Get_Call) RunAndReturn(run func(context.Context, string) return _c } -// ListByUser provides a mock function with given fields: ctx, principal, flt -func (_m *GroupService) ListByUser(ctx context.Context, principal authenticate.Principal, flt group.Filter) ([]group.Group, error) { - ret := _m.Called(ctx, principal, flt) +// List provides a mock function with given fields: ctx, flt +func (_m *GroupService) List(ctx context.Context, flt group.Filter) ([]group.Group, error) { + ret := _m.Called(ctx, flt) if len(ret) == 0 { - panic("no return value specified for ListByUser") + panic("no return value specified for List") } var r0 []group.Group var r1 error - if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, group.Filter) ([]group.Group, error)); ok { - return rf(ctx, principal, flt) + if rf, ok := ret.Get(0).(func(context.Context, group.Filter) ([]group.Group, error)); ok { + return rf(ctx, flt) } - if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, group.Filter) []group.Group); ok { - r0 = rf(ctx, principal, flt) + if rf, ok := ret.Get(0).(func(context.Context, group.Filter) []group.Group); ok { + r0 = rf(ctx, flt) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]group.Group) } } - if rf, ok := ret.Get(1).(func(context.Context, authenticate.Principal, group.Filter) error); ok { - r1 = rf(ctx, principal, flt) + if rf, ok := ret.Get(1).(func(context.Context, group.Filter) error); ok { + r1 = rf(ctx, flt) } else { r1 = ret.Error(1) } @@ -112,32 +110,31 @@ func (_m *GroupService) ListByUser(ctx context.Context, principal authenticate.P return r0, r1 } -// GroupService_ListByUser_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListByUser' -type GroupService_ListByUser_Call struct { +// GroupService_List_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'List' +type GroupService_List_Call struct { *mock.Call } -// ListByUser is a helper method to define mock.On call +// List is a helper method to define mock.On call // - ctx context.Context -// - principal authenticate.Principal // - flt group.Filter -func (_e *GroupService_Expecter) ListByUser(ctx interface{}, principal interface{}, flt interface{}) *GroupService_ListByUser_Call { - return &GroupService_ListByUser_Call{Call: _e.mock.On("ListByUser", ctx, principal, flt)} +func (_e *GroupService_Expecter) List(ctx interface{}, flt interface{}) *GroupService_List_Call { + return &GroupService_List_Call{Call: _e.mock.On("List", ctx, flt)} } -func (_c *GroupService_ListByUser_Call) Run(run func(ctx context.Context, principal authenticate.Principal, flt group.Filter)) *GroupService_ListByUser_Call { +func (_c *GroupService_List_Call) Run(run func(ctx context.Context, flt group.Filter)) *GroupService_List_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(authenticate.Principal), args[2].(group.Filter)) + run(args[0].(context.Context), args[1].(group.Filter)) }) return _c } -func (_c *GroupService_ListByUser_Call) Return(_a0 []group.Group, _a1 error) *GroupService_ListByUser_Call { +func (_c *GroupService_List_Call) Return(_a0 []group.Group, _a1 error) *GroupService_List_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *GroupService_ListByUser_Call) RunAndReturn(run func(context.Context, authenticate.Principal, group.Filter) ([]group.Group, error)) *GroupService_ListByUser_Call { +func (_c *GroupService_List_Call) RunAndReturn(run func(context.Context, group.Filter) ([]group.Group, error)) *GroupService_List_Call { _c.Call.Return(run) return _c } diff --git a/core/invitation/service.go b/core/invitation/service.go index 144c2644c..936f0c620 100644 --- a/core/invitation/service.go +++ b/core/invitation/service.go @@ -54,7 +54,7 @@ type MembershipService interface { type GroupService interface { Get(ctx context.Context, id string) (group.Group, error) - ListByUser(ctx context.Context, principal authenticate.Principal, flt group.Filter) ([]group.Group, error) + List(ctx context.Context, flt group.Filter) ([]group.Group, error) } type RelationService interface { @@ -319,9 +319,8 @@ func (s Service) Accept(ctx context.Context, id uuid.UUID) error { // check if the invitation has a group membership if len(invite.GroupIDs) > 0 { - userGroups, err := s.groupSvc.ListByUser(ctx, authenticate.Principal{ - ID: userOb.ID, Type: schema.UserPrincipal, - }, group.Filter{}) + principal := authenticate.Principal{ID: userOb.ID, Type: schema.UserPrincipal} + userGroups, err := s.groupSvc.List(ctx, group.Filter{Principal: &principal}) if err != nil { return err } diff --git a/core/invitation/service_test.go b/core/invitation/service_test.go index ef9a8a79f..b3c4dee6f 100644 --- a/core/invitation/service_test.go +++ b/core/invitation/service_test.go @@ -3,12 +3,16 @@ package invitation_test import ( "context" "testing" + "time" auditMocks "github.com/raystack/frontier/core/auditrecord/mocks" + auditModels "github.com/raystack/frontier/core/auditrecord/models" "github.com/raystack/frontier/core/authenticate" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/google/go-cmp/cmp" + "github.com/google/uuid" + "github.com/raystack/frontier/core/group" "github.com/raystack/frontier/core/invitation" "github.com/raystack/frontier/core/invitation/mocks" "github.com/raystack/frontier/core/membership" @@ -86,3 +90,59 @@ func TestService_Create(t *testing.T) { }) } } + +func TestService_Accept_DedupesExistingGroupMembers(t *testing.T) { + ctx := context.Background() + inviteID := uuid.New() + userID := "user-id" + userEmail := "test@example.com" + orgID := "org-id" + + dialer, repo, orgService, groupService, userService, relationService, prefService, auditRecordRepo := mockService(t) + membershipSvc := mocks.NewMembershipService(t) + + repo.EXPECT().Get(ctx, inviteID).Return(invitation.Invitation{ + ID: inviteID, + UserEmailID: userEmail, + OrgID: orgID, + GroupIDs: []string{"g-alpha", "g-gamma"}, + ExpiresAt: time.Now().Add(time.Hour), + }, nil) + + userOb := user.User{ID: userID, Email: userEmail, Title: "Test User"} + userPrincipal := authenticate.Principal{ID: userID, Type: schema.UserPrincipal} + + // isUserOrgMember — already a member, so AddOrganizationMember is skipped + userService.EXPECT().GetByID(ctx, userEmail).Return(userOb, nil) + membershipSvc.EXPECT().ListResourcesByPrincipal(ctx, userPrincipal, schema.OrganizationNamespace, membership.ResourceFilter{}). + Return([]string{orgID}, nil) + + prefService.EXPECT().LoadPlatformPreferences(ctx).Return(map[string]string{}, nil) + + // User is already a member of g-alpha + groupService.EXPECT().List(ctx, group.Filter{Principal: &userPrincipal}). + Return([]group.Group{{ID: "g-alpha"}}, nil) + + // Both invite groups get looked up + groupService.EXPECT().Get(ctx, "g-alpha").Return(group.Group{ID: "g-alpha"}, nil) + groupService.EXPECT().Get(ctx, "g-gamma").Return(group.Group{ID: "g-gamma"}, nil) + + // Only g-gamma is added; g-alpha is skipped (no SetGroupMemberRole expectation for it, + // so the mock would fail if the code called it) + membershipSvc.EXPECT(). + SetGroupMemberRole(ctx, "g-gamma", userID, schema.UserPrincipal, schema.GroupMemberRole). + Return(nil) + + // Audit + delete tail + orgService.EXPECT().Get(ctx, orgID).Return(organization.Organization{ID: orgID, Title: "Test Org"}, nil) + auditRecordRepo.EXPECT().Create(ctx, mock.AnythingOfType("models.AuditRecord")). + Return(auditModels.AuditRecord{}, nil) + relationService.EXPECT().Delete(ctx, mock.AnythingOfType("relation.Relation")).Return(nil) + repo.EXPECT().Delete(ctx, inviteID).Return(nil) + + svc := invitation.NewService(dialer, repo, orgService, groupService, + userService, relationService, prefService, auditRecordRepo, membershipSvc) + + err := svc.Accept(ctx, inviteID) + assert.NoError(t, err) +} diff --git a/core/membership/service.go b/core/membership/service.go index dee69d28c..69662fdc8 100644 --- a/core/membership/service.go +++ b/core/membership/service.go @@ -1568,6 +1568,13 @@ func (s *Service) ListOrgsByPrincipal(ctx context.Context, principal authenticat return s.ListResourcesByPrincipal(ctx, principal, schema.OrganizationNamespace, ResourceFilter{}) } +// ListGroupsByPrincipal Shim for the group package (group → membership would cycle). PATs scope +// orgs and projects, not groups, so a PAT sees exactly its user's groups — resolve the PAT. +func (s *Service) ListGroupsByPrincipal(ctx context.Context, principal authenticate.Principal, orgID string) ([]string, error) { + subjectID, subjectType := principal.ResolveSubject() + return s.listResourcesForPrincipal(ctx, subjectID, subjectType, schema.GroupNamespace, ResourceFilter{OrgID: orgID}) +} + // ListResourcesByPrincipal returns the resource IDs of the given type on which // the principal has at least one policy. Reads Postgres policies — no SpiceDB. // With a PAT, runs the algorithm twice (user, then PAT-as-principal) and diff --git a/core/membership/service_test.go b/core/membership/service_test.go index 36701b484..a82e0b3e9 100644 --- a/core/membership/service_test.go +++ b/core/membership/service_test.go @@ -2289,3 +2289,107 @@ func TestService_ListResourcesByPrincipal(t *testing.T) { }) } } + +func TestService_ListGroupsByPrincipal(t *testing.T) { + ctx := context.Background() + + userID := uuid.New().String() + patID := uuid.New().String() + orgA := uuid.New().String() + groupA := uuid.New().String() + groupB := uuid.New().String() + roleGroupMemberID := uuid.New().String() + + tests := []struct { + name string + principal authenticate.Principal + orgID string + setup func(p *mocks.PolicyService, g *mocks.GroupService) + want []string + }{ + { + name: "user principal — reads user's group policies", + principal: authenticate.Principal{ID: userID, Type: schema.UserPrincipal}, + setup: func(p *mocks.PolicyService, _ *mocks.GroupService) { + p.EXPECT().List(ctx, policy.Filter{ + PrincipalID: userID, + PrincipalType: schema.UserPrincipal, + ResourceType: schema.GroupNamespace, + }).Return([]policy.Policy{ + {ResourceID: groupA, RoleID: roleGroupMemberID}, + {ResourceID: groupB, RoleID: roleGroupMemberID}, + }, nil) + }, + want: []string{groupA, groupB}, + }, + { + name: "PAT principal — resolves to underlying user, no PAT-side query", + principal: authenticate.Principal{ + ID: userID, + Type: schema.UserPrincipal, + PAT: &pat.PAT{ID: patID, UserID: userID, OrgID: orgA}, + }, + setup: func(p *mocks.PolicyService, _ *mocks.GroupService) { + // only the user-side lookup; no policy.List for PrincipalType=PAT + p.EXPECT().List(ctx, policy.Filter{ + PrincipalID: userID, + PrincipalType: schema.UserPrincipal, + ResourceType: schema.GroupNamespace, + }).Return([]policy.Policy{ + {ResourceID: groupA, RoleID: roleGroupMemberID}, + }, nil) + }, + want: []string{groupA}, + }, + { + name: "PAT principal + orgID — narrows result via groupService", + principal: authenticate.Principal{ + ID: userID, + Type: schema.UserPrincipal, + PAT: &pat.PAT{ID: patID, UserID: userID, OrgID: orgA}, + }, + orgID: orgA, + setup: func(p *mocks.PolicyService, g *mocks.GroupService) { + p.EXPECT().List(ctx, policy.Filter{ + PrincipalID: userID, + PrincipalType: schema.UserPrincipal, + ResourceType: schema.GroupNamespace, + }).Return([]policy.Policy{ + {ResourceID: groupA, RoleID: roleGroupMemberID}, + {ResourceID: groupB, RoleID: roleGroupMemberID}, + }, nil) + g.EXPECT().List(ctx, group.Filter{ + OrganizationID: orgA, + GroupIDs: []string{groupA, groupB}, + }).Return([]group.Group{{ID: groupA}}, nil) + }, + want: []string{groupA}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mp := mocks.NewPolicyService(t) + mg := mocks.NewGroupService(t) + + tt.setup(mp, mg) + + svc := membership.NewService( + slog.New(slog.NewTextHandler(io.Discard, nil)), + mp, + mocks.NewRelationService(t), + mocks.NewRoleService(t), + mocks.NewOrgService(t), + mocks.NewUserService(t), + mocks.NewProjectService(t), + mg, + mocks.NewServiceuserService(t), + mocks.NewAuditRecordRepository(t), + ) + + got, err := svc.ListGroupsByPrincipal(ctx, tt.principal, tt.orgID) + assert.NoError(t, err) + assert.ElementsMatch(t, tt.want, got) + }) + } +} diff --git a/core/project/mocks/group_service.go b/core/project/mocks/group_service.go index 1fb9b9456..df1e0bfc1 100644 --- a/core/project/mocks/group_service.go +++ b/core/project/mocks/group_service.go @@ -5,10 +5,7 @@ package mocks import ( context "context" - authenticate "github.com/raystack/frontier/core/authenticate" - group "github.com/raystack/frontier/core/group" - mock "github.com/stretchr/testify/mock" ) @@ -141,29 +138,29 @@ func (_c *GroupService_GetByIDs_Call) RunAndReturn(run func(context.Context, []s return _c } -// ListByUser provides a mock function with given fields: ctx, principal, flt -func (_m *GroupService) ListByUser(ctx context.Context, principal authenticate.Principal, flt group.Filter) ([]group.Group, error) { - ret := _m.Called(ctx, principal, flt) +// List provides a mock function with given fields: ctx, flt +func (_m *GroupService) List(ctx context.Context, flt group.Filter) ([]group.Group, error) { + ret := _m.Called(ctx, flt) if len(ret) == 0 { - panic("no return value specified for ListByUser") + panic("no return value specified for List") } var r0 []group.Group var r1 error - if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, group.Filter) ([]group.Group, error)); ok { - return rf(ctx, principal, flt) + if rf, ok := ret.Get(0).(func(context.Context, group.Filter) ([]group.Group, error)); ok { + return rf(ctx, flt) } - if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, group.Filter) []group.Group); ok { - r0 = rf(ctx, principal, flt) + if rf, ok := ret.Get(0).(func(context.Context, group.Filter) []group.Group); ok { + r0 = rf(ctx, flt) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]group.Group) } } - if rf, ok := ret.Get(1).(func(context.Context, authenticate.Principal, group.Filter) error); ok { - r1 = rf(ctx, principal, flt) + if rf, ok := ret.Get(1).(func(context.Context, group.Filter) error); ok { + r1 = rf(ctx, flt) } else { r1 = ret.Error(1) } @@ -171,32 +168,31 @@ func (_m *GroupService) ListByUser(ctx context.Context, principal authenticate.P return r0, r1 } -// GroupService_ListByUser_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListByUser' -type GroupService_ListByUser_Call struct { +// GroupService_List_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'List' +type GroupService_List_Call struct { *mock.Call } -// ListByUser is a helper method to define mock.On call +// List is a helper method to define mock.On call // - ctx context.Context -// - principal authenticate.Principal // - flt group.Filter -func (_e *GroupService_Expecter) ListByUser(ctx interface{}, principal interface{}, flt interface{}) *GroupService_ListByUser_Call { - return &GroupService_ListByUser_Call{Call: _e.mock.On("ListByUser", ctx, principal, flt)} +func (_e *GroupService_Expecter) List(ctx interface{}, flt interface{}) *GroupService_List_Call { + return &GroupService_List_Call{Call: _e.mock.On("List", ctx, flt)} } -func (_c *GroupService_ListByUser_Call) Run(run func(ctx context.Context, principal authenticate.Principal, flt group.Filter)) *GroupService_ListByUser_Call { +func (_c *GroupService_List_Call) Run(run func(ctx context.Context, flt group.Filter)) *GroupService_List_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(authenticate.Principal), args[2].(group.Filter)) + run(args[0].(context.Context), args[1].(group.Filter)) }) return _c } -func (_c *GroupService_ListByUser_Call) Return(_a0 []group.Group, _a1 error) *GroupService_ListByUser_Call { +func (_c *GroupService_List_Call) Return(_a0 []group.Group, _a1 error) *GroupService_List_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *GroupService_ListByUser_Call) RunAndReturn(run func(context.Context, authenticate.Principal, group.Filter) ([]group.Group, error)) *GroupService_ListByUser_Call { +func (_c *GroupService_List_Call) RunAndReturn(run func(context.Context, group.Filter) ([]group.Group, error)) *GroupService_List_Call { _c.Call.Return(run) return _c } diff --git a/core/project/service.go b/core/project/service.go index cdda10eb6..d6572059a 100644 --- a/core/project/service.go +++ b/core/project/service.go @@ -57,7 +57,7 @@ type AuthnService interface { type GroupService interface { Get(ctx context.Context, id string) (group.Group, error) GetByIDs(ctx context.Context, ids []string) ([]group.Group, error) - ListByUser(ctx context.Context, principal authenticate.Principal, flt group.Filter) ([]group.Group, error) + List(ctx context.Context, flt group.Filter) ([]group.Group, error) } type Service struct { @@ -200,8 +200,8 @@ func (s Service) listNonInheritedProjectIDs(ctx context.Context, principalID, pr } // projects added via group memberships - groups, err := s.groupService.ListByUser(ctx, - authenticate.Principal{ID: principalID, Type: principalType}, group.Filter{}) + principal := authenticate.Principal{ID: principalID, Type: principalType} + groups, err := s.groupService.List(ctx, group.Filter{Principal: &principal}) if err != nil { return nil, err } diff --git a/core/project/service_test.go b/core/project/service_test.go index ab45b1023..9bec3baa5 100644 --- a/core/project/service_test.go +++ b/core/project/service_test.go @@ -406,7 +406,9 @@ func TestService_ListByUser(t *testing.T) { }, }, nil) - groupService.EXPECT().ListByUser(ctx, authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, group.Filter{}).Return([]group.Group{}, nil) + groupService.EXPECT().List(ctx, group.Filter{ + Principal: &authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, + }).Return([]group.Group{}, nil) repo.EXPECT().List(ctx, project.Filter{ ProjectIDs: []string{"project-id"}, @@ -464,7 +466,9 @@ func TestService_ListByUser(t *testing.T) { }, }, nil) - groupService.EXPECT().ListByUser(ctx, authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, group.Filter{}).Return([]group.Group{ + groupService.EXPECT().List(ctx, group.Filter{ + Principal: &authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, + }).Return([]group.Group{ { ID: "group-id", }, @@ -613,7 +617,9 @@ func TestService_ListByUser(t *testing.T) { }, nil) // Group lookup uses user-only principal (no double PAT filtering) - groupService.EXPECT().ListByUser(ctx, authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, group.Filter{}).Return([]group.Group{}, nil) + groupService.EXPECT().List(ctx, group.Filter{ + Principal: &authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, + }).Return([]group.Group{}, nil) // PAT scope intersection relationService.EXPECT().LookupResources(ctx, relation.Relation{ @@ -643,6 +649,80 @@ func TestService_ListByUser(t *testing.T) { return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) }, }, + { + name: "PAT principal with non-inherited surfaces group-mediated projects intersected with PAT scope", + args: args{ + principal: authenticate.Principal{ + ID: "pat-456", + Type: schema.PATPrincipal, + PAT: &pat.PAT{ID: "pat-456", UserID: "user-id", OrgID: "org-1"}, + }, + flt: project.Filter{ + NonInherited: true, + }, + }, + want: []project.Project{ + { + ID: "project-via-group", + Name: "group-project", + Organization: organization.Organization{ID: "org-1"}, + }, + }, + wantErr: false, + setup: func() *project.Service { + repo, userService, suserService, relationService, policyService, authnService, groupService, roleService := mockService(t) + _ = roleService + // User has no direct project policies; everything comes via group + policyService.EXPECT().List(ctx, policy.Filter{ + PrincipalType: schema.UserPrincipal, + PrincipalID: "user-id", + ResourceType: schema.ProjectNamespace, + }).Return([]policy.Policy{}, nil) + + // User is a member of a group (PAT resolved to user before this call) + groupService.EXPECT().List(ctx, group.Filter{ + Principal: &authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, + }).Return([]group.Group{{ID: "group-id"}}, nil) + + // That group has policy on a project + policyService.EXPECT().List(ctx, policy.Filter{ + PrincipalType: schema.GroupPrincipal, + PrincipalIDs: []string{"group-id"}, + ResourceType: schema.ProjectNamespace, + }).Return([]policy.Policy{ + { + ResourceID: "project-via-group", + ResourceType: schema.ProjectNamespace, + PrincipalID: "group-id", + PrincipalType: schema.GroupPrincipal, + }, + }, nil) + + // PAT scope grants the same project + relationService.EXPECT().LookupResources(ctx, relation.Relation{ + Object: relation.Object{ + Namespace: schema.ProjectNamespace, + }, + Subject: relation.Subject{ + ID: "pat-456", + Namespace: schema.PATPrincipal, + }, + RelationName: schema.GetPermission, + }).Return([]string{"project-via-group"}, nil) + + repo.EXPECT().List(ctx, project.Filter{ + ProjectIDs: []string{"project-via-group"}, + NonInherited: true, + }).Return([]project.Project{ + { + ID: "project-via-group", + Name: "group-project", + Organization: organization.Organization{ID: "org-1"}, + }, + }, nil) + return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) + }, + }, { name: "PAT principal with no project overlap returns empty", args: args{ diff --git a/internal/api/v1beta1connect/interfaces.go b/internal/api/v1beta1connect/interfaces.go index 90630da99..ac952fc52 100644 --- a/internal/api/v1beta1connect/interfaces.go +++ b/internal/api/v1beta1connect/interfaces.go @@ -301,7 +301,6 @@ type GroupService interface { GetByIDs(ctx context.Context, ids []string) ([]group.Group, error) List(ctx context.Context, flt group.Filter) ([]group.Group, error) Update(ctx context.Context, grp group.Group) (group.Group, error) - ListByUser(ctx context.Context, principal authenticate.Principal, flt group.Filter) ([]group.Group, error) RemoveUsers(ctx context.Context, groupID string, userID []string) error Enable(ctx context.Context, id string) error Disable(ctx context.Context, id string) error diff --git a/internal/api/v1beta1connect/mocks/group_service.go b/internal/api/v1beta1connect/mocks/group_service.go index 2046ea845..6b5512354 100644 --- a/internal/api/v1beta1connect/mocks/group_service.go +++ b/internal/api/v1beta1connect/mocks/group_service.go @@ -5,10 +5,7 @@ package mocks import ( context "context" - authenticate "github.com/raystack/frontier/core/authenticate" - group "github.com/raystack/frontier/core/group" - mock "github.com/stretchr/testify/mock" ) @@ -351,66 +348,6 @@ func (_c *GroupService_List_Call) RunAndReturn(run func(context.Context, group.F return _c } -// ListByUser provides a mock function with given fields: ctx, principal, flt -func (_m *GroupService) ListByUser(ctx context.Context, principal authenticate.Principal, flt group.Filter) ([]group.Group, error) { - ret := _m.Called(ctx, principal, flt) - - if len(ret) == 0 { - panic("no return value specified for ListByUser") - } - - var r0 []group.Group - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, group.Filter) ([]group.Group, error)); ok { - return rf(ctx, principal, flt) - } - if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, group.Filter) []group.Group); ok { - r0 = rf(ctx, principal, flt) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]group.Group) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, authenticate.Principal, group.Filter) error); ok { - r1 = rf(ctx, principal, flt) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GroupService_ListByUser_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListByUser' -type GroupService_ListByUser_Call struct { - *mock.Call -} - -// ListByUser is a helper method to define mock.On call -// - ctx context.Context -// - principal authenticate.Principal -// - flt group.Filter -func (_e *GroupService_Expecter) ListByUser(ctx interface{}, principal interface{}, flt interface{}) *GroupService_ListByUser_Call { - return &GroupService_ListByUser_Call{Call: _e.mock.On("ListByUser", ctx, principal, flt)} -} - -func (_c *GroupService_ListByUser_Call) Run(run func(ctx context.Context, principal authenticate.Principal, flt group.Filter)) *GroupService_ListByUser_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(authenticate.Principal), args[2].(group.Filter)) - }) - return _c -} - -func (_c *GroupService_ListByUser_Call) Return(_a0 []group.Group, _a1 error) *GroupService_ListByUser_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *GroupService_ListByUser_Call) RunAndReturn(run func(context.Context, authenticate.Principal, group.Filter) ([]group.Group, error)) *GroupService_ListByUser_Call { - _c.Call.Return(run) - return _c -} - // RemoveUsers provides a mock function with given fields: ctx, groupID, userID func (_m *GroupService) RemoveUsers(ctx context.Context, groupID string, userID []string) error { ret := _m.Called(ctx, groupID, userID) diff --git a/internal/api/v1beta1connect/user.go b/internal/api/v1beta1connect/user.go index e779a2f11..8c859e135 100644 --- a/internal/api/v1beta1connect/user.go +++ b/internal/api/v1beta1connect/user.go @@ -464,11 +464,13 @@ func (h *ConnectHandler) ListUserGroups(ctx context.Context, request *connect.Re errorLogger := NewErrorLogger() var groups []*frontierv1beta1.Group - groupsList, err := h.groupService.ListByUser(ctx, authenticate.Principal{ - ID: request.Msg.GetId(), Type: schema.UserPrincipal, - }, group.Filter{OrganizationID: request.Msg.GetOrgId()}) + principal := authenticate.Principal{ID: request.Msg.GetId(), Type: schema.UserPrincipal} + groupsList, err := h.groupService.List(ctx, group.Filter{ + Principal: &principal, + OrganizationID: request.Msg.GetOrgId(), + }) if err != nil { - errorLogger.LogServiceError(ctx, request, "ListUserGroups.ListByUser", err, + errorLogger.LogServiceError(ctx, request, "ListUserGroups.List", err, "user_id", request.Msg.GetId(), "org_id", request.Msg.GetOrgId()) @@ -508,14 +510,13 @@ func (h *ConnectHandler) ListCurrentUserGroups(ctx context.Context, request *con var groupsPb []*frontierv1beta1.Group var accessPairsPb []*frontierv1beta1.ListCurrentUserGroupsResponse_AccessPair - groupsList, err := h.groupService.ListByUser(ctx, principal, - group.Filter{ - OrganizationID: request.Msg.GetOrgId(), - WithMemberCount: request.Msg.GetWithMemberCount(), - }, - ) + groupsList, err := h.groupService.List(ctx, group.Filter{ + Principal: &principal, + OrganizationID: request.Msg.GetOrgId(), + WithMemberCount: request.Msg.GetWithMemberCount(), + }) if err != nil { - errorLogger.LogServiceError(ctx, request, "ListCurrentUserGroups.ListByUser", err, + errorLogger.LogServiceError(ctx, request, "ListCurrentUserGroups.List", err, "principal_id", principal.ID, "principal_type", principal.Type, "org_id", request.Msg.GetOrgId()) diff --git a/internal/api/v1beta1connect/user_test.go b/internal/api/v1beta1connect/user_test.go index 4196f3f77..8dffaf9b6 100644 --- a/internal/api/v1beta1connect/user_test.go +++ b/internal/api/v1beta1connect/user_test.go @@ -914,7 +914,10 @@ func TestConnectHandler_ListUserGroups(t *testing.T) { { title: "should list user groups successfully", setup: func(gs *mocks.GroupService) { - gs.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: userID, Type: "app/user"}, group.Filter{OrganizationID: orgID}).Return([]group.Group{ + gs.EXPECT().List(mock.Anything, group.Filter{ + Principal: &authenticate.Principal{ID: userID, Type: "app/user"}, + OrganizationID: orgID, + }).Return([]group.Group{ { ID: "group-1", Name: "test-group-1", @@ -966,7 +969,10 @@ func TestConnectHandler_ListUserGroups(t *testing.T) { { title: "should return empty list when user has no groups", setup: func(gs *mocks.GroupService) { - gs.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: userID, Type: "app/user"}, group.Filter{OrganizationID: orgID}).Return([]group.Group{}, nil) + gs.EXPECT().List(mock.Anything, group.Filter{ + Principal: &authenticate.Principal{ID: userID, Type: "app/user"}, + OrganizationID: orgID, + }).Return([]group.Group{}, nil) }, req: &frontierv1beta1.ListUserGroupsRequest{ Id: userID, @@ -980,7 +986,10 @@ func TestConnectHandler_ListUserGroups(t *testing.T) { { title: "should return not found error for invalid user ID", setup: func(gs *mocks.GroupService) { - gs.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: "invalid-id", Type: "app/user"}, group.Filter{OrganizationID: orgID}).Return(nil, group.ErrInvalidID) + gs.EXPECT().List(mock.Anything, group.Filter{ + Principal: &authenticate.Principal{ID: "invalid-id", Type: "app/user"}, + OrganizationID: orgID, + }).Return(nil, group.ErrInvalidID) }, req: &frontierv1beta1.ListUserGroupsRequest{ Id: "invalid-id", @@ -992,7 +1001,10 @@ func TestConnectHandler_ListUserGroups(t *testing.T) { { title: "should return internal error for service failure", setup: func(gs *mocks.GroupService) { - gs.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: userID, Type: "app/user"}, group.Filter{OrganizationID: orgID}).Return(nil, errors.New("database error")) + gs.EXPECT().List(mock.Anything, group.Filter{ + Principal: &authenticate.Principal{ID: userID, Type: "app/user"}, + OrganizationID: orgID, + }).Return(nil, errors.New("database error")) }, req: &frontierv1beta1.ListUserGroupsRequest{ Id: userID, @@ -1061,7 +1073,10 @@ func TestConnectHandler_ListCurrentUserGroups(t *testing.T) { } as.EXPECT().GetPrincipal(mock.Anything).Return(mockPrincipal, nil) - gs.EXPECT().ListByUser(mock.Anything, mockPrincipal, group.Filter{OrganizationID: orgID}).Return([]group.Group{ + gs.EXPECT().List(mock.Anything, group.Filter{ + Principal: &mockPrincipal, + OrganizationID: orgID, + }).Return([]group.Group{ { ID: "group-1", Name: "test-group-1", @@ -1102,7 +1117,10 @@ func TestConnectHandler_ListCurrentUserGroups(t *testing.T) { User: &user.User{ID: "user-1", Email: "test@example.com"}, } as.EXPECT().GetPrincipal(mock.Anything).Return(mockPrincipal, nil) - gs.EXPECT().ListByUser(mock.Anything, mockPrincipal, group.Filter{OrganizationID: orgID}).Return([]group.Group{}, nil) + gs.EXPECT().List(mock.Anything, group.Filter{ + Principal: &mockPrincipal, + OrganizationID: orgID, + }).Return([]group.Group{}, nil) }, req: &frontierv1beta1.ListCurrentUserGroupsRequest{ OrgId: orgID, @@ -1133,7 +1151,10 @@ func TestConnectHandler_ListCurrentUserGroups(t *testing.T) { User: &user.User{ID: "user-1", Email: "test@example.com"}, } as.EXPECT().GetPrincipal(mock.Anything).Return(mockPrincipal, nil) - gs.EXPECT().ListByUser(mock.Anything, mockPrincipal, group.Filter{OrganizationID: orgID}).Return(nil, errors.New("database error")) + gs.EXPECT().List(mock.Anything, group.Filter{ + Principal: &mockPrincipal, + OrganizationID: orgID, + }).Return(nil, errors.New("database error")) }, req: &frontierv1beta1.ListCurrentUserGroupsRequest{ OrgId: orgID,