Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions solid/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ type Config struct {
S3Bucket string `mapstructure:"s3_bucket"`
S3Region string `mapstructure:"s3_region"`
S3Prefix string `mapstructure:"s3_prefix"`

APIKeyAccess bool `mapstructure:"api_key_access"`
}

func LoadConfig(path string) (Config, error) {
Expand All @@ -42,6 +44,8 @@ func LoadConfig(path string) (Config, error) {
viper.SetDefault("s3_region", "")
viper.SetDefault("s3_prefix", "artifacts")

viper.SetDefault("api_key_access", "true")

viper.AutomaticEnv()

config := Config{}
Expand Down
15 changes: 15 additions & 0 deletions solid/controllers/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package controllers

import (
"context"

"github.com/zeddo123/mlsolid/solid/types"
)

func (c *Controller) CreateAPIKey(ctx context.Context, perm types.Permissions) (string, error) {
return c.Redis.CreateAPIKey(ctx, perm)
}

func (c *Controller) GetPermissions(ctx context.Context, key string) (types.Permissions, error) {
return c.Redis.APIKeyPermissions(ctx, key)
}
32 changes: 30 additions & 2 deletions solid/controllers/controller.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,39 @@
package controllers

import (
"context"
"fmt"

"github.com/zeddo123/mlsolid/solid"
"github.com/zeddo123/mlsolid/solid/s3"
"github.com/zeddo123/mlsolid/solid/store"
"github.com/zeddo123/mlsolid/solid/types"
)

type Controller struct {
Redis store.RedisStore
S3 s3.ObjectStore
Redis store.RedisStore
S3 s3.ObjectStore
Config solid.Config
APIKey string
}

func (c *Controller) Permissions() (types.Permissions, error) {
return c.GetPermissions(context.Background(), c.APIKey)
}

func (c *Controller) HasPermission(perm types.PermissionType) error {
if !c.Config.APIKeyAccess {
return nil
}

perms, err := c.Permissions()
if err != nil {
return err
}

if !perms.HasPermission(perm) {
return fmt.Errorf("unauthorized")
}

return nil
}
12 changes: 12 additions & 0 deletions solid/controllers/exp.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,25 @@ import (
)

func (c *Controller) Exp(ctx context.Context, expID string) (*types.Experiment, error) {
if err := c.HasPermission(types.PushExperimentsPermission); err != nil {
return nil, err
}

return c.Redis.Exp(ctx, types.NormalizeID(expID))
}

func (c *Controller) ExpRuns(ctx context.Context, expID string) ([]string, error) {
if err := c.HasPermission(types.PushExperimentsPermission); err != nil {
return nil, err
}

return c.Redis.ExpRunIDs(ctx, expID)
}

func (c *Controller) Exps(ctx context.Context) ([]string, error) {
if err := c.HasPermission(types.PushExperimentsPermission); err != nil {
return nil, err
}

return c.Redis.Exps(ctx)
}
28 changes: 28 additions & 0 deletions solid/controllers/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ import (
)

func (c *Controller) ModelRegistry(ctx context.Context, name string) (*types.ModelRegistry, error) {
if err := c.HasPermission(types.PullRegistryPermission); err != nil {
return nil, err
}

registry, err := c.Redis.ModelRegistry(ctx, name)
if err != nil {
return nil, err
Expand All @@ -16,6 +20,10 @@ func (c *Controller) ModelRegistry(ctx context.Context, name string) (*types.Mod
}

func (c *Controller) CreateModelRegistry(ctx context.Context, name string) error {
if err := c.HasPermission(types.PushRegistryPermission); err != nil {
return err
}

if name == "" {
return types.NewBadRequest("model registry name cannot be empty")
}
Expand All @@ -24,14 +32,26 @@ func (c *Controller) CreateModelRegistry(ctx context.Context, name string) error
}

func (c *Controller) LastModelEntry(ctx context.Context, registryName string) (types.ModelEntry, error) {
if err := c.HasPermission(types.PullRegistryPermission); err != nil {
return types.ModelEntry{}, err
}

return c.Redis.LastModel(ctx, registryName)
}

func (c *Controller) TaggedModel(ctx context.Context, registryName string, tag string) (types.ModelEntry, error) {
if err := c.HasPermission(types.PullRegistryPermission); err != nil {
return types.ModelEntry{}, err
}

return c.Redis.ModelByTag(ctx, registryName, tag)
}

func (c *Controller) AddModelEntry(ctx context.Context, registryName string, url string, tags ...string) error {
if err := c.HasPermission(types.PushRegistryPermission); err != nil {
return err
}

registry, err := c.Redis.ModelRegistry(ctx, registryName)
if err != nil {
return err
Expand All @@ -45,6 +65,10 @@ func (c *Controller) AddModelEntry(ctx context.Context, registryName string, url
func (c *Controller) AddArtifactToRegistry(ctx context.Context, registryName string, runID string,
artifactID string, tags ...string,
) error {
if err := c.HasPermission(types.PushRegistryPermission); err != nil {
return err
}

artifact, err := c.Redis.Artifact(ctx, runID, artifactID)
if err != nil {
return err
Expand All @@ -54,6 +78,10 @@ func (c *Controller) AddArtifactToRegistry(ctx context.Context, registryName str
}

func (c *Controller) TagModel(ctx context.Context, registryName string, version int, tags ...string) error {
if err := c.HasPermission(types.PushRegistryPermission); err != nil {
return err
}

registry, err := c.Redis.ModelRegistry(ctx, registryName)
if err != nil {
return err
Expand Down
24 changes: 24 additions & 0 deletions solid/controllers/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ import (
)

func (c *Controller) CreateRun(ctx context.Context, run types.Run) error {
if err := c.HasPermission(types.PushExperimentsPermission); err != nil {
return err
}

if c.S3 == nil {
return types.NewInternalErr("object store is not configured")
}
Expand Down Expand Up @@ -43,6 +47,10 @@ func (c *Controller) CreateRun(ctx context.Context, run types.Run) error {
}

func (c *Controller) Run(ctx context.Context, runID string) (*types.Run, error) {
if err := c.HasPermission(types.PushExperimentsPermission); err != nil {
return nil, err
}

id := types.NormalizeID(runID)

ok, err := c.Redis.RunExists(ctx, id)
Expand All @@ -58,6 +66,10 @@ func (c *Controller) Run(ctx context.Context, runID string) (*types.Run, error)
}

func (c *Controller) Runs(ctx context.Context, ids []string) ([]*types.Run, error) {
if err := c.HasPermission(types.PushExperimentsPermission); err != nil {
return nil, err
}

normalized := make([]string, len(ids))

for i, id := range ids {
Expand All @@ -73,6 +85,10 @@ func (c *Controller) Runs(ctx context.Context, ids []string) ([]*types.Run, erro
}

func (c *Controller) AddMetrics(ctx context.Context, runID string, m []types.Metric) error {
if err := c.HasPermission(types.PushExperimentsPermission); err != nil {
return err
}

ok, err := c.Redis.RunExists(ctx, types.NormalizeID(runID))
if err != nil {
return err
Expand All @@ -97,6 +113,10 @@ func (c *Controller) AddMetrics(ctx context.Context, runID string, m []types.Met
}

func (c *Controller) AddArtifacts(ctx context.Context, runID string, as []types.Artifact) error {
if err := c.HasPermission(types.PushExperimentsPermission); err != nil {
return err
}

ids := types.ArtifactIDs(as)
artifactsMap := types.ArtifactIDMap(as)

Expand Down Expand Up @@ -126,6 +146,10 @@ func (c *Controller) AddArtifacts(ctx context.Context, runID string, as []types.
func (c *Controller) Artifact(ctx context.Context, runID string,
artifact string,
) (*types.SavedArtifact, io.ReadCloser, error) {
if err := c.HasPermission(types.PushExperimentsPermission); err != nil {
return nil, nil, err
}

a, err := c.Redis.Artifact(ctx, runID, artifact)
if err != nil {
return nil, nil, err
Expand Down
43 changes: 43 additions & 0 deletions solid/store/api_key_access.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package store

import (
"context"

"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"github.com/zeddo123/mlsolid/solid/types"
)

func (r *RedisStore) CreateAPIKey(ctx context.Context, perm types.Permissions) (string, error) {
key := uuid.NewString()
redisKey := r.makeAPIKeyKey(key)

fn := func(tx *redis.Tx) error {
_, err := tx.Pipelined(ctx, func(p redis.Pipeliner) error {
p.HSet(ctx, redisKey, perm.Mapping())
p.Expire(ctx, redisKey, perm.Expiry)

return nil
})

return err
}

err := r.runTx(ctx, fn, TransactionMaxTries, redisKey)
if err != nil {
return "", err
}

return key, nil
}

func (r *RedisStore) APIKeyPermissions(ctx context.Context, key string) (types.Permissions, error) {
redisKey := r.makeAPIKeyKey(key)

m, err := r.Client.HGetAll(ctx, redisKey).Result()
if err != nil {
return types.Permissions{}, err
}

return types.NewPermissions(m)
}
6 changes: 6 additions & 0 deletions solid/store/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
// Example
// tag:registry:yolov12:prod
ModelRegistryTagKeyPattern = "tag:registry:%s:%s"
// APIKeyKeyPattern pattern for an api key
APIKeyKeyPattern = "apikey:%s"

TransactionMaxTries = 10
)
Expand Down Expand Up @@ -76,9 +78,13 @@
return fmt.Sprintf(ModelRegistryTagKeyPattern, name, tag)
}

func (r *RedisStore) makeAPIKeyKey(key string) string {
return fmt.Sprintf(APIKeyKeyPattern, key)
}

// runTx runs a transaction function with an optimistic locks on the keys passed as argument.
func (r *RedisStore) runTx(ctx context.Context, fn func(tx *redis.Tx) error,
maxRetries int, keys ...string,

Check failure on line 87 in solid/store/redis.go

View workflow job for this annotation

GitHub Actions / lint

(*RedisStore).runTx - maxRetries always receives TransactionMaxTries (10) (unparam)
) error {
for range maxRetries {
err := r.Client.Watch(ctx, fn, keys...)
Expand Down
65 changes: 65 additions & 0 deletions solid/types/api_key.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package types

import (
"strconv"
"time"
)

type PermissionType string

const (
PushRegistryPermission PermissionType = "PushRegistry"
PullRegistryPermission PermissionType = "PullRegistry"
PushExperimentsPermission PermissionType = "PushExperiments"
)

type Permissions struct {
PullRegistry bool
PushRegistry bool
PushExperiments bool
Expiry time.Duration
}

func NewPermissions(m map[string]string) (Permissions, error) {
pullRegistry, err := strconv.ParseBool(m["PullRegistry"])
if err != nil {
return Permissions{}, err
}

pushRegistry, err := strconv.ParseBool(m["PushRegistry"])
if err != nil {
return Permissions{}, err
}

pushExperiments, err := strconv.ParseBool(m["PushExperiments"])
if err != nil {
return Permissions{}, err
}

return Permissions{
PullRegistry: pullRegistry,
PushRegistry: pushRegistry,
PushExperiments: pushExperiments,
}, nil
}

func (p *Permissions) Mapping() map[string]string {
return map[string]string{
"PullRegistry": strconv.FormatBool(p.PullRegistry),
"PushRegistry": strconv.FormatBool(p.PushRegistry),
"PushExperiments": strconv.FormatBool(p.PushExperiments),
}
}

func (p *Permissions) HasPermission(perm PermissionType) bool {
switch perm {
case PullRegistryPermission:
return p.PullRegistry
case PushRegistryPermission:
return p.PushRegistry
case PushExperimentsPermission:
return p.PushExperiments
default:
return false
}
}
Loading