diff --git a/internal/app/deps.go b/internal/app/deps.go index 817e481..e28ec72 100644 --- a/internal/app/deps.go +++ b/internal/app/deps.go @@ -10,6 +10,7 @@ import ( "github.com/GaIsBAX/Webhix/internal/config" "github.com/GaIsBAX/Webhix/internal/core" "github.com/GaIsBAX/Webhix/internal/hub" + "github.com/GaIsBAX/Webhix/internal/notify" "github.com/GaIsBAX/Webhix/internal/repos" "github.com/GaIsBAX/Webhix/internal/server" "github.com/GaIsBAX/Webhix/internal/store" @@ -38,7 +39,12 @@ func newDependencies(ctx context.Context, cfg *config.Config) (*dependencies, er } repos := newRepositories(infra.db) - services := newServices(repos) + + notifyRegistry := notify.NewRegistry(map[string]notify.Provider{ + "telegram": notify.NewTelegramProvider(), + }) + + services := newServices(repos, notifyRegistry) deps.mux = mux deps.cfg = cfg @@ -46,7 +52,7 @@ func newDependencies(ctx context.Context, cfg *config.Config) (*dependencies, er deps.infra = infra deps.repos = repos deps.services = services - deps.handlers = newHandlers(&deps) + deps.handlers = newHandlers(&deps, notifyRegistry) deps.handlers.registerRoutes() staticFS, err := fs.Sub(web.Static, "static") @@ -63,10 +69,10 @@ type services struct { serve *core.Serve } -func newServices(repos *repositories) *services { +func newServices(repos *repositories, sender core.NotificationSender) *services { hook := core.NewHook(repos.hook, func() string { return pkg.GeneratePrefixedString("ho") - }) + }, sender) serve := core.NewServe(repos.serve) return &services{ @@ -129,12 +135,14 @@ type handlers struct { hook *server.Hook } -func newHandlers(deps *dependencies) *handlers { +func newHandlers(deps *dependencies, registry *notify.Registry) *handlers { return &handlers{ hook: server.NewHook(&server.HookDeps{ - Mux: deps.mux, - Service: deps.services.hook, - Hub: deps.infra.hub, + Mux: deps.mux, + Service: deps.services.hook, + Notifications: deps.services.hook, + Registry: registry, + Hub: deps.infra.hub, Opts: server.HookOptions{ BaseURL: deps.cfg.BaseURL, MaxBodySize: deps.cfg.MaxBodySize, diff --git a/internal/cli/apiclient/client.go b/internal/cli/apiclient/client.go new file mode 100644 index 0000000..967c9cc --- /dev/null +++ b/internal/cli/apiclient/client.go @@ -0,0 +1,100 @@ +package apiclient + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "time" +) + +type Client struct { + server *string + authToken *string + http *http.Client +} + +func New(server, authToken *string) *Client { + return &Client{ + server: server, + authToken: authToken, + http: &http.Client{Timeout: 30 * time.Second}, + } +} + +type apiResponse struct { + Success bool `json:"success"` + Body json.RawMessage `json:"body"` + Error *apiError `json:"error"` +} + +type apiError struct { + Message string `json:"message"` +} + +func (c *Client) Get(ctx context.Context, path string, out any) error { + return c.do(ctx, http.MethodGet, path, nil, out) +} + +func (c *Client) Put(ctx context.Context, path string, body any) error { + return c.do(ctx, http.MethodPut, path, body, nil) +} + +func (c *Client) Post(ctx context.Context, path string, body any) error { + return c.do(ctx, http.MethodPost, path, body, nil) +} + +func (c *Client) Delete(ctx context.Context, path string) error { + return c.do(ctx, http.MethodDelete, path, nil, nil) +} + +func (c *Client) do(ctx context.Context, method, path string, body any, out any) error { + var r io.Reader + if body != nil { + data, err := json.Marshal(body) + if err != nil { + return err + } + r = bytes.NewReader(data) + } + + req, err := http.NewRequestWithContext(ctx, method, *c.server+path, r) + if err != nil { + return err + } + if *c.authToken != "" { + req.Header.Set("Authorization", "Bearer "+*c.authToken) + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + + resp, err := c.http.Do(req) + if err != nil { + return err + } + defer func() { + if err := resp.Body.Close(); err != nil { + slog.Warn("close response body", "err", err) + } + }() + + var ar apiResponse + if err := json.NewDecoder(resp.Body).Decode(&ar); err != nil { + return fmt.Errorf("server returned %d", resp.StatusCode) + } + if !ar.Success { + if ar.Error != nil { + return errors.New(ar.Error.Message) + } + return fmt.Errorf("server returned %d", resp.StatusCode) + } + if out != nil { + return json.Unmarshal(ar.Body, out) + } + return nil +} diff --git a/internal/cli/notify/command.go b/internal/cli/notify/command.go new file mode 100644 index 0000000..de80eb5 --- /dev/null +++ b/internal/cli/notify/command.go @@ -0,0 +1,68 @@ +package notify + +import ( + "context" + "net/url" + + "github.com/GaIsBAX/Webhix/internal/cli/apiclient" + "github.com/GaIsBAX/Webhix/internal/cli/notify/telegram" + "github.com/GaIsBAX/Webhix/internal/config" + "github.com/spf13/cobra" +) + +type notificationChannel struct { + Provider string `json:"provider"` + Config map[string]string `json:"config"` + Redacted []string `json:"redacted"` +} + +func NewCommand(ctx context.Context, cfg *config.Config) *cobra.Command { + opts := DefaultOptions() + if cfg.SecretKey != "" { + opts.AuthToken = cfg.SecretKey + } + + cmd := &cobra.Command{ + Use: "notify", + Short: "Manage endpoint notifications", + } + + RegisterFlags(cmd, &opts) + + client := apiclient.New(&opts.Server, &opts.AuthToken) + + cmd.AddCommand(newListCmd(ctx, client)) + cmd.AddCommand(telegram.NewCommand(ctx, client)) + + return cmd +} + +func newListCmd(ctx context.Context, client *apiclient.Client) *cobra.Command { + return &cobra.Command{ + Use: "list ", + Short: "List all configured notification channels", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + var channels []notificationChannel + if err := client.Get(ctx, "/api/endpoints/"+url.PathEscape(args[0])+"/notifications", &channels); err != nil { + return err + } + + if len(channels) == 0 { + cmd.Println("No notifications configured.") + return nil + } + + for _, ch := range channels { + cmd.Printf("Provider: %s\n", ch.Provider) + for k, v := range ch.Config { + cmd.Printf(" %s: %s\n", k, v) + } + for _, k := range ch.Redacted { + cmd.Printf(" %s: [set]\n", k) + } + } + return nil + }, + } +} diff --git a/internal/cli/notify/flags.go b/internal/cli/notify/flags.go new file mode 100644 index 0000000..05f9f5e --- /dev/null +++ b/internal/cli/notify/flags.go @@ -0,0 +1,13 @@ +package notify + +import "github.com/spf13/cobra" + +const ( + flagServer = "server" + flagAuthToken = "auth-token" +) + +func RegisterFlags(cmd *cobra.Command, opt *Options) { + cmd.PersistentFlags().StringVar(&opt.Server, flagServer, opt.Server, "Webhix server URL") + cmd.PersistentFlags().StringVar(&opt.AuthToken, flagAuthToken, opt.AuthToken, "auth token (env: WEBHIX_SECRET_KEY)") +} diff --git a/internal/cli/notify/options.go b/internal/cli/notify/options.go new file mode 100644 index 0000000..97a9099 --- /dev/null +++ b/internal/cli/notify/options.go @@ -0,0 +1,12 @@ +package notify + +type Options struct { + Server string + AuthToken string +} + +func DefaultOptions() Options { + return Options{ + Server: "http://localhost:8080", + } +} diff --git a/internal/cli/notify/telegram/command.go b/internal/cli/notify/telegram/command.go new file mode 100644 index 0000000..26b5911 --- /dev/null +++ b/internal/cli/notify/telegram/command.go @@ -0,0 +1,99 @@ +package telegram + +import ( + "context" + "net/url" + + "github.com/GaIsBAX/Webhix/internal/cli/apiclient" + "github.com/spf13/cobra" +) + +type Options struct { + BotToken string + ChatID string + ProxyURL string +} + +func NewCommand(ctx context.Context, client *apiclient.Client) *cobra.Command { + cmd := &cobra.Command{ + Use: "telegram", + Short: "Manage Telegram notifications", + } + + cmd.AddCommand(newSetCmd(ctx, client)) + cmd.AddCommand(newTestCmd(ctx, client)) + cmd.AddCommand(newRemoveCmd(ctx, client)) + + return cmd +} + +func newSetCmd(ctx context.Context, client *apiclient.Client) *cobra.Command { + opts := Options{} + + cmd := &cobra.Command{ + Use: "set ", + Short: "Configure Telegram notifications for an endpoint", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + cfg := map[string]string{"bot_token": opts.BotToken, "chat_id": opts.ChatID} + if opts.ProxyURL != "" { + cfg["proxy_url"] = opts.ProxyURL + } + + body := map[string]any{"provider": "telegram", "config": cfg} + path := "/api/endpoints/" + url.PathEscape(args[0]) + "/notifications/telegram" + if err := client.Put(ctx, path, body); err != nil { + return err + } + + cmd.Println("Telegram notifications configured.") + return nil + }, + } + + RegisterFlags(cmd, &opts) + must(cmd.MarkFlagRequired(flagBotToken)) + must(cmd.MarkFlagRequired(flagChatID)) + + return cmd +} + +func newTestCmd(ctx context.Context, client *apiclient.Client) *cobra.Command { + return &cobra.Command{ + Use: "test ", + Short: "Send a test Telegram message", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + path := "/api/endpoints/" + url.PathEscape(args[0]) + "/notifications/telegram/test" + if err := client.Post(ctx, path, nil); err != nil { + return err + } + + cmd.Println("Test message sent.") + return nil + }, + } +} + +func newRemoveCmd(ctx context.Context, client *apiclient.Client) *cobra.Command { + return &cobra.Command{ + Use: "remove ", + Short: "Remove Telegram notifications from an endpoint", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + path := "/api/endpoints/" + url.PathEscape(args[0]) + "/notifications/telegram" + if err := client.Delete(ctx, path); err != nil { + return err + } + + cmd.Println("Telegram notifications removed.") + return nil + }, + } +} + +func must(err error) { + if err != nil { + panic(err) + } +} diff --git a/internal/cli/notify/telegram/flags.go b/internal/cli/notify/telegram/flags.go new file mode 100644 index 0000000..5d50721 --- /dev/null +++ b/internal/cli/notify/telegram/flags.go @@ -0,0 +1,15 @@ +package telegram + +import "github.com/spf13/cobra" + +const ( + flagBotToken = "bot-token" + flagChatID = "chat" + flagProxyURL = "proxy" +) + +func RegisterFlags(cmd *cobra.Command, opt *Options) { + cmd.Flags().StringVar(&opt.BotToken, flagBotToken, opt.BotToken, "Telegram bot token") + cmd.Flags().StringVar(&opt.ChatID, flagChatID, opt.ChatID, "Telegram chat ID") + cmd.Flags().StringVar(&opt.ProxyURL, flagProxyURL, opt.ProxyURL, "Proxy URL (e.g. socks5://127.0.0.1:1080)") +} diff --git a/internal/cli/root.go b/internal/cli/root.go index 6280d06..75dcba8 100644 --- a/internal/cli/root.go +++ b/internal/cli/root.go @@ -4,6 +4,7 @@ import ( "context" "github.com/GaIsBAX/Webhix/internal/cli/forward" + "github.com/GaIsBAX/Webhix/internal/cli/notify" "github.com/GaIsBAX/Webhix/internal/cli/serve" "github.com/GaIsBAX/Webhix/internal/cli/tunnel" "github.com/GaIsBAX/Webhix/internal/cli/version" @@ -29,6 +30,7 @@ func NewRootCommand( cmd.AddCommand(serve.NewCommand(ctx, cfg, serveFactory)) cmd.AddCommand(forward.NewCommand(ctx, cfg)) + cmd.AddCommand(notify.NewCommand(ctx, cfg)) cmd.AddCommand(tunnel.NewCommand(ctx, cfg)) cmd.AddCommand(version.NewCommand(ctx)) diff --git a/internal/core/hook_core.go b/internal/core/hook_core.go index 0488a69..2e16b4c 100644 --- a/internal/core/hook_core.go +++ b/internal/core/hook_core.go @@ -3,30 +3,46 @@ package core import ( "context" "errors" + "fmt" + "html" + "log/slog" + "sync" "github.com/GaIsBAX/Webhix/internal/domain" ) -const defaultHookResponseStatusCode int64 = 200 +const ( + defaultHookResponseStatusCode int64 = 200 + maxConcurrentNotifications int = 64 +) type TokenGenerator func() string +type NotificationSender interface { + Send(ctx context.Context, provider string, config map[string]string, message string) error +} + type HookRepository interface { - CreateHook(ctx context.Context, token string) (domain.Hook, error) + CreateHook(ctx context.Context, token, name string) (domain.Hook, error) GetHookByToken(ctx context.Context, token string) (domain.Hook, error) ListHooks(ctx context.Context) ([]domain.Hook, error) CreateWebhookRequest(ctx context.Context, params domain.CreateWebhookRequestParams) (domain.WebhookRequest, error) ListWebhookRequests(ctx context.Context, hookID int64) ([]domain.WebhookRequest, error) GetHookResponse(ctx context.Context, hookID int64) (domain.HookResponse, error) UpsertHookResponse(ctx context.Context, hookID int64, params domain.UpsertHookResponseParams) (domain.HookResponse, error) + ListNotificationChannels(ctx context.Context, hookID int64) ([]domain.NotificationChannel, error) + UpsertNotificationChannel(ctx context.Context, hookID int64, provider string, config map[string]string) (domain.NotificationChannel, error) + DeleteNotificationChannel(ctx context.Context, hookID int64, provider string) error } type Hook struct { repo HookRepository generateToken TokenGenerator + sender NotificationSender + notifySem chan struct{} } -func NewHook(repo HookRepository, generateToken TokenGenerator) *Hook { +func NewHook(repo HookRepository, generateToken TokenGenerator, sender NotificationSender) *Hook { if generateToken == nil { generateToken = func() string { return "" } } @@ -34,6 +50,8 @@ func NewHook(repo HookRepository, generateToken TokenGenerator) *Hook { return &Hook{ repo: repo, generateToken: generateToken, + sender: sender, + notifySem: make(chan struct{}, maxConcurrentNotifications), } } @@ -41,12 +59,8 @@ func (s *Hook) ListHooks(ctx context.Context) ([]domain.Hook, error) { return s.repo.ListHooks(ctx) } -func (s *Hook) CreateHook(ctx context.Context, token string) (domain.Hook, error) { - if token == "" { - token = s.generateToken() - } - - return s.repo.CreateHook(ctx, token) +func (s *Hook) CreateHook(ctx context.Context, name string) (domain.Hook, error) { + return s.repo.CreateHook(ctx, s.generateToken(), name) } func (s *Hook) ReceiveWebhook(ctx context.Context, token string, params domain.CreateWebhookRequestParams) (domain.WebhookRequest, domain.HookResponse, error) { @@ -108,6 +122,78 @@ func (s *Hook) SetHookResponse(ctx context.Context, token string, params domain. return s.repo.UpsertHookResponse(ctx, hook.ID, params) } +func (s *Hook) ListChannels(ctx context.Context, token string) ([]domain.NotificationChannel, error) { + hook, err := s.repo.GetHookByToken(ctx, token) + if err != nil { + return nil, err + } + return s.repo.ListNotificationChannels(ctx, hook.ID) +} + +func (s *Hook) UpsertChannel(ctx context.Context, token, provider string, config map[string]string) (domain.NotificationChannel, error) { + hook, err := s.repo.GetHookByToken(ctx, token) + if err != nil { + return domain.NotificationChannel{}, err + } + return s.repo.UpsertNotificationChannel(ctx, hook.ID, provider, config) +} + +func (s *Hook) DeleteChannel(ctx context.Context, token, provider string) error { + hook, err := s.repo.GetHookByToken(ctx, token) + if err != nil { + return err + } + return s.repo.DeleteNotificationChannel(ctx, hook.ID, provider) +} + +func (s *Hook) GetChannelsForHookID(ctx context.Context, hookID int64) ([]domain.NotificationChannel, error) { + return s.repo.ListNotificationChannels(ctx, hookID) +} + +func (s *Hook) DispatchNotifications(ctx context.Context, req domain.WebhookRequest, token string) { + ctx = context.WithoutCancel(ctx) + go func() { + select { + case s.notifySem <- struct{}{}: + defer func() { <-s.notifySem }() + s.sendNotifications(ctx, req, token) + default: + slog.Warn("notification queue full, dropping", "token", token) + } + }() +} + +func (s *Hook) sendNotifications(ctx context.Context, req domain.WebhookRequest, token string) { + channels, err := s.repo.ListNotificationChannels(ctx, req.HookID) + if err != nil { + slog.Warn("fetch notification channels", "hookID", req.HookID, "err", err) + return + } + if len(channels) == 0 { + return + } + + msg := fmt.Sprintf( + "📨 New webhook\nEndpoint: /r/%s\nMethod: %s\nPath: %s", + html.EscapeString(token), html.EscapeString(req.Method), html.EscapeString(req.Path), + ) + + var wg sync.WaitGroup + for _, ch := range channels { + if !ch.Enabled { + continue + } + wg.Add(1) + go func(ch domain.NotificationChannel) { + defer wg.Done() + if err := s.sender.Send(ctx, ch.Provider, ch.Config, msg); err != nil { + slog.Warn("notification failed", "provider", ch.Provider, "token", token, "err", err) + } + }(ch) + } + wg.Wait() +} + func defaultHookResponse() domain.HookResponse { return domain.HookResponse{ StatusCode: defaultHookResponseStatusCode, diff --git a/internal/domain/hook.go b/internal/domain/hook.go index 98911e2..c90c672 100644 --- a/internal/domain/hook.go +++ b/internal/domain/hook.go @@ -61,3 +61,13 @@ func (p UpsertHookResponseParams) Validate() error { } return nil } + +type NotificationChannel struct { + ID int64 + HookID int64 + Provider string + Config map[string]string + Enabled bool + CreatedAt time.Time + UpdatedAt time.Time +} diff --git a/internal/notify/provider.go b/internal/notify/provider.go new file mode 100644 index 0000000..5492710 --- /dev/null +++ b/internal/notify/provider.go @@ -0,0 +1,60 @@ +package notify + +import ( + "context" + "fmt" + "sort" +) + +type Config map[string]string + +type Provider interface { + Send(ctx context.Context, config Config, message string) error + ValidateConfig(config Config) error + SecretKeys() []string +} + +type Registry struct { + providers map[string]Provider +} + +func NewRegistry(providers map[string]Provider) *Registry { + return &Registry{providers: providers} +} + +func (r *Registry) Send(ctx context.Context, provider string, config map[string]string, message string) error { + p, ok := r.providers[provider] + if !ok { + return fmt.Errorf("unknown provider: %s", provider) + } + + return p.Send(ctx, Config(config), message) +} + +func (r *Registry) ValidateConfig(provider string, config map[string]string) error { + p, ok := r.providers[provider] + if !ok { + return fmt.Errorf("unknown provider %q", provider) + } + + return p.ValidateConfig(Config(config)) +} + +func (r *Registry) SecretKeys(provider string) []string { + p, ok := r.providers[provider] + if !ok { + return nil + } + + return p.SecretKeys() +} + +func (r *Registry) KnownProviders() []string { + keys := make([]string, 0, len(r.providers)) + for k := range r.providers { + keys = append(keys, k) + } + + sort.Strings(keys) + return keys +} diff --git a/internal/notify/telegram.go b/internal/notify/telegram.go new file mode 100644 index 0000000..e4d6201 --- /dev/null +++ b/internal/notify/telegram.go @@ -0,0 +1,120 @@ +package notify + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "net/url" + "sync" + "time" +) + +var ( + defaultTelegramClient = &http.Client{Timeout: 10 * time.Second} + proxyClients sync.Map +) + +func NewTelegramProvider() Provider { + return telegramProvider{} +} + +type telegramProvider struct{} + +func (telegramProvider) Send(ctx context.Context, config Config, message string) error { + return sendMessage(ctx, config["bot_token"], config["chat_id"], message, config["proxy_url"]) +} + +func (telegramProvider) ValidateConfig(config Config) error { + for _, k := range []string{"bot_token", "chat_id"} { + if config[k] == "" { + return fmt.Errorf("telegram: %q is required", k) + } + } + return nil +} + +func (telegramProvider) SecretKeys() []string { + return []string{"bot_token"} +} + +func sendMessage(ctx context.Context, botToken, chatID, text, proxyURL string) error { + if botToken == "" || chatID == "" { + return fmt.Errorf("telegram: bot_token and chat_id are required") + } + + client := defaultTelegramClient + if proxyURL != "" { + if v, ok := proxyClients.Load(proxyURL); ok { + if c, ok := v.(*http.Client); ok { + client = c + } + } else { + proxy, err := validateProxy(proxyURL) + if err != nil { + return err + } + c := &http.Client{ + Timeout: 10 * time.Second, + Transport: &http.Transport{Proxy: http.ProxyURL(proxy)}, + } + proxyClients.Store(proxyURL, c) + client = c + } + } + + payload, err := json.Marshal(map[string]string{ + "chat_id": chatID, + "text": text, + "parse_mode": "HTML", + }) + if err != nil { + return err + } + + apiURL := fmt.Sprintf("https://api.telegram.org/bot%s/sendMessage", botToken) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(payload)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + return err + } + defer func() { + if err := resp.Body.Close(); err != nil { + slog.Warn("close response body", "err", err) + } + }() + + if resp.StatusCode == http.StatusOK { + return nil + } + + var tgErr struct { + Description string `json:"description"` + } + if err := json.NewDecoder(resp.Body).Decode(&tgErr); err == nil && tgErr.Description != "" { + return fmt.Errorf("telegram: %s", tgErr.Description) + } + return fmt.Errorf("telegram API returned %d", resp.StatusCode) +} + +// validateProxy parses and validates the proxy URL scheme. +// Returns the parsed URL so the caller doesn't parse it twice. +func validateProxy(rawURL string) (*url.URL, error) { + u, err := url.Parse(rawURL) + if err != nil { + return nil, fmt.Errorf("invalid proxy URL: %w", err) + } + switch u.Scheme { + case "http", "https", "socks5": + return u, nil + default: + return nil, fmt.Errorf("proxy scheme %q not allowed: use http, https, or socks5", u.Scheme) + } +} diff --git a/internal/repos/hook.go b/internal/repos/hook.go index 99d39d0..316f510 100644 --- a/internal/repos/hook.go +++ b/internal/repos/hook.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/json" "errors" + "log/slog" "github.com/GaIsBAX/Webhix/internal/domain" "github.com/GaIsBAX/Webhix/internal/store/sqlc" @@ -20,10 +21,10 @@ func NewHook(db sqlc.DBTX) *Hook { } } -func (r *Hook) CreateHook(ctx context.Context, token string) (domain.Hook, error) { +func (r *Hook) CreateHook(ctx context.Context, token, name string) (domain.Hook, error) { hook, err := r.q.CreateHook(ctx, sqlc.CreateHookParams{ Token: token, - Name: sql.NullString{}, + Name: sql.NullString{String: name, Valid: name != ""}, }) if err != nil { return domain.Hook{}, err @@ -123,6 +124,86 @@ func (r *Hook) UpsertHookResponse(ctx context.Context, hookID int64, params doma return toDomainHookResponse(row), nil } +func (r *Hook) ListNotificationChannels(ctx context.Context, hookID int64) ([]domain.NotificationChannel, error) { + rows, err := r.q.ListNotificationChannels(ctx, hookID) + if err != nil { + return nil, err + } + + result := make([]domain.NotificationChannel, len(rows)) + for i, row := range rows { + result[i] = toDomainChannel(row) + } + + return result, nil +} + +func (r *Hook) GetNotificationChannel(ctx context.Context, hookID int64, provider string) (domain.NotificationChannel, error) { + row, err := r.q.GetNotificationChannel(ctx, sqlc.GetNotificationChannelParams{ + HookID: hookID, + Provider: provider, + }) + + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return domain.NotificationChannel{}, domain.ErrNotFound + } + return domain.NotificationChannel{}, err + } + + return toDomainChannel(row), nil +} + +func (r *Hook) UpsertNotificationChannel(ctx context.Context, hookID int64, provider string, config map[string]string) (domain.NotificationChannel, error) { + configJSON, err := json.Marshal(config) + if err != nil { + return domain.NotificationChannel{}, err + } + + row, err := r.q.UpsertNotificationChannel(ctx, sqlc.UpsertNotificationChannelParams{ + HookID: hookID, + Provider: provider, + Config: string(configJSON), + }) + + if err != nil { + return domain.NotificationChannel{}, err + } + + return toDomainChannel(row), nil +} + +func (r *Hook) DeleteNotificationChannel(ctx context.Context, hookID int64, provider string) error { + n, err := r.q.DeleteNotificationChannel(ctx, sqlc.DeleteNotificationChannelParams{ + HookID: hookID, + Provider: provider, + }) + if err != nil { + return err + } + if n == 0 { + return domain.ErrNotFound + } + return nil +} + +func toDomainChannel(row sqlc.HookNotificationChannel) domain.NotificationChannel { + cfg := map[string]string{} + if err := json.Unmarshal([]byte(row.Config), &cfg); err != nil { + slog.Warn("parse notification channel config", "err", err) + } + + return domain.NotificationChannel{ + ID: row.ID, + HookID: row.HookID, + Provider: row.Provider, + Config: cfg, + Enabled: row.Enabled != 0, + CreatedAt: row.CreatedAt, + UpdatedAt: row.UpdatedAt, + } +} + func toDomainHookResponse(row sqlc.HookResponse) domain.HookResponse { headers := map[string]string{} if err := json.Unmarshal([]byte(row.Headers), &headers); err != nil { diff --git a/internal/server/contract.go b/internal/server/contract.go index 11f5f73..a16982e 100644 --- a/internal/server/contract.go +++ b/internal/server/contract.go @@ -47,6 +47,12 @@ type HookResponseContract struct { Body string `json:"body"` } +type NotificationContract struct { + Provider string `json:"provider"` + Config map[string]string `json:"config"` + Redacted []string `json:"redacted,omitempty"` +} + type SetHookResponseRequestContract struct { StatusCode int `json:"statusCode"` Headers map[string]string `json:"headers"` @@ -122,3 +128,22 @@ func toWebhookRequestContract(req domain.WebhookRequest) WebhookRequestContract ReceivedAt: req.ReceivedAt, } } + +func toNotificationContract(ch domain.NotificationChannel, secretKeys []string) NotificationContract { + secretSet := make(map[string]bool, len(secretKeys)) + for _, s := range secretKeys { + secretSet[s] = true + } + + cfg := make(map[string]string, len(ch.Config)) + var redacted []string + for k, v := range ch.Config { + if secretSet[k] && v != "" { + redacted = append(redacted, k) + } else { + cfg[k] = v + } + } + + return NotificationContract{Provider: ch.Provider, Config: cfg, Redacted: redacted} +} diff --git a/internal/server/handler.go b/internal/server/handler.go index 1364e48..b5269e9 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "html" "io" "log/slog" "net/http" @@ -16,11 +17,19 @@ const DefaultMaxBodySize int64 = 5 << 20 // 5MB type HookService interface { ListHooks(ctx context.Context) ([]domain.Hook, error) - CreateHook(ctx context.Context, token string) (domain.Hook, error) + CreateHook(ctx context.Context, name string) (domain.Hook, error) ReceiveWebhook(ctx context.Context, token string, params domain.CreateWebhookRequestParams) (domain.WebhookRequest, domain.HookResponse, error) ListWebhookRequests(ctx context.Context, token string) ([]domain.WebhookRequest, error) GetHookResponse(ctx context.Context, token string) (domain.HookResponse, error) SetHookResponse(ctx context.Context, token string, params domain.UpsertHookResponseParams) (domain.HookResponse, error) + DispatchNotifications(ctx context.Context, req domain.WebhookRequest, token string) +} + +type NotificationService interface { + ListChannels(ctx context.Context, token string) ([]domain.NotificationChannel, error) + UpsertChannel(ctx context.Context, token, provider string, config map[string]string) (domain.NotificationChannel, error) + DeleteChannel(ctx context.Context, token, provider string) error + GetChannelsForHookID(ctx context.Context, hookID int64) ([]domain.NotificationChannel, error) } type EventBroker interface { @@ -29,6 +38,12 @@ type EventBroker interface { Publish(token string, data []byte) } +type NotificationRegistry interface { + Send(ctx context.Context, provider string, config map[string]string, message string) error + ValidateConfig(provider string, config map[string]string) error + SecretKeys(provider string) []string +} + type HookOptions struct { BaseURL string MaxBodySize int64 @@ -36,10 +51,12 @@ type HookOptions struct { } type HookDeps struct { - Mux *http.ServeMux - Service HookService - Hub EventBroker - Opts HookOptions + Mux *http.ServeMux + Service HookService + Notifications NotificationService + Registry NotificationRegistry + Hub EventBroker + Opts HookOptions } type Hook struct { @@ -61,6 +78,10 @@ func (h *Hook) RegisterRoutes() { h.deps.Mux.HandleFunc("GET /api/endpoints/{token}/events", h.StreamEvents) h.deps.Mux.HandleFunc("GET /api/endpoints/{token}/response", h.GetResponse) h.deps.Mux.HandleFunc("PUT /api/endpoints/{token}/response", h.SetResponse) + h.deps.Mux.HandleFunc("GET /api/endpoints/{token}/notifications", h.GetNotification) + h.deps.Mux.HandleFunc("PUT /api/endpoints/{token}/notifications/{provider}", h.SetNotification) + h.deps.Mux.HandleFunc("DELETE /api/endpoints/{token}/notifications/{provider}", h.DeleteNotification) + h.deps.Mux.HandleFunc("POST /api/endpoints/{token}/notifications/{provider}/test", h.TestNotification) h.deps.Mux.HandleFunc("/r/{token}", h.ReceiveWebhook) } @@ -188,6 +209,7 @@ func (h *Hook) ReceiveWebhook(w http.ResponseWriter, r *http.Request) { } h.deps.Hub.Publish(token, data) + h.deps.Service.DispatchNotifications(r.Context(), req, token) if customResp.StatusCode > 0 { for k, v := range customResp.Headers { @@ -340,6 +362,152 @@ func (h *Hook) SetResponse(w http.ResponseWriter, r *http.Request) { SendSuccess(w, http.StatusOK, data) } +func (h *Hook) GetNotification(w http.ResponseWriter, r *http.Request) { + token := r.PathValue("token") + + channels, err := h.deps.Notifications.ListChannels(r.Context(), token) + if err != nil { + if errors.Is(err, domain.ErrNotFound) { + SendError(w, http.StatusNotFound, ErrNotFound) + return + } + slog.Error("list notification channels", "err", err) + SendError(w, http.StatusInternalServerError, ErrInternal) + return + } + + contracts := make([]NotificationContract, len(channels)) + for i, ch := range channels { + contracts[i] = toNotificationContract(ch, h.deps.Registry.SecretKeys(ch.Provider)) + } + + data, err := json.Marshal(contracts) + if err != nil { + SendError(w, http.StatusInternalServerError, ErrInternal) + return + } + + SendSuccess(w, http.StatusOK, data) +} + +func (h *Hook) SetNotification(w http.ResponseWriter, r *http.Request) { + if h.readOnly(w) { + return + } + + token := r.PathValue("token") + provider := r.PathValue("provider") + + contract, err := DecodeRequest[NotificationContract](r) + if err != nil { + SendError(w, http.StatusBadRequest, ErrBadRequest) + return + } + + if contract.Config == nil { + contract.Config = make(map[string]string) + } + + if secrets := h.deps.Registry.SecretKeys(provider); len(secrets) > 0 { + if existing, err := h.deps.Notifications.ListChannels(r.Context(), token); err == nil { + for _, exc := range existing { + if exc.Provider != provider { + continue + } + for _, key := range secrets { + if contract.Config[key] == "" && exc.Config[key] != "" { + contract.Config[key] = exc.Config[key] + } + } + break + } + } + } + + if err := h.deps.Registry.ValidateConfig(provider, contract.Config); err != nil { + SendError(w, http.StatusBadRequest, WithDetails(ErrBadRequest, ErrorDetailContract{Message: err.Error()})) + return + } + + ch, err := h.deps.Notifications.UpsertChannel(r.Context(), token, provider, contract.Config) + if err != nil { + if errors.Is(err, domain.ErrNotFound) { + SendError(w, http.StatusNotFound, ErrNotFound) + return + } + slog.Error("upsert notification channel", "err", err) + SendError(w, http.StatusInternalServerError, ErrInternal) + return + } + + data, err := json.Marshal(toNotificationContract(ch, h.deps.Registry.SecretKeys(ch.Provider))) + if err != nil { + SendError(w, http.StatusInternalServerError, ErrInternal) + return + } + + SendSuccess(w, http.StatusOK, data) +} + +func (h *Hook) DeleteNotification(w http.ResponseWriter, r *http.Request) { + if h.readOnly(w) { + return + } + + token := r.PathValue("token") + provider := r.PathValue("provider") + + if err := h.deps.Notifications.DeleteChannel(r.Context(), token, provider); err != nil { + if errors.Is(err, domain.ErrNotFound) { + SendError(w, http.StatusNotFound, ErrNotFound) + return + } + slog.Error("delete notification channel", "err", err) + SendError(w, http.StatusInternalServerError, ErrInternal) + return + } + + SendSuccess(w, http.StatusOK, []byte(`{}`)) +} + +func (h *Hook) TestNotification(w http.ResponseWriter, r *http.Request) { + token := r.PathValue("token") + provider := r.PathValue("provider") + + channels, err := h.deps.Notifications.ListChannels(r.Context(), token) + if err != nil { + if errors.Is(err, domain.ErrNotFound) { + SendError(w, http.StatusNotFound, ErrNotFound) + return + } + slog.Error("list channels for test", "err", err) + SendError(w, http.StatusInternalServerError, ErrInternal) + return + } + + for _, ch := range channels { + if ch.Provider != provider { + continue + } + msg := fmt.Sprintf("✅ Webhix test notification for endpoint /r/%s", html.EscapeString(token)) + if err := h.deps.Registry.Send(r.Context(), ch.Provider, ch.Config, msg); err != nil { + slog.Error("test notification", "provider", provider, "err", err) + SendError(w, http.StatusBadGateway, WithDetails(ErrInternal, ErrorDetailContract{ + Field: provider, + Message: err.Error(), + })) + return + } + SendSuccess(w, http.StatusOK, []byte(`{"sent":true}`)) + return + } + + SendError(w, http.StatusNotFound, WithDetails(ErrNotFound, ErrorDetailContract{ + Field: "provider", + Message: provider + " is not configured for this endpoint", + })) +} + func (h *Hook) readOnly(w http.ResponseWriter) bool { if !h.deps.Opts.ReadOnly { return false diff --git a/internal/store/migrations/20260529000000_add_hook_notification_channels.sql b/internal/store/migrations/20260529000000_add_hook_notification_channels.sql new file mode 100644 index 0000000..db6220b --- /dev/null +++ b/internal/store/migrations/20260529000000_add_hook_notification_channels.sql @@ -0,0 +1,15 @@ +-- +goose Up +CREATE TABLE hook_notification_channels ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + hook_id INTEGER NOT NULL, + provider TEXT NOT NULL, + config TEXT NOT NULL DEFAULT '{}', + enabled INTEGER NOT NULL DEFAULT 1, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (hook_id) REFERENCES hooks(id) ON DELETE CASCADE, + UNIQUE(hook_id, provider) +); + +-- +goose Down +DROP TABLE IF EXISTS hook_notification_channels; diff --git a/internal/store/query/hooks.sql b/internal/store/query/hooks.sql index 49ea469..a5b6356 100644 --- a/internal/store/query/hooks.sql +++ b/internal/store/query/hooks.sql @@ -75,3 +75,27 @@ ORDER BY created_at DESC; SELECT id, hook_id, status_code, headers, body, created_at, updated_at FROM hook_responses WHERE hook_id = ?; + +-- name: ListNotificationChannels :many +SELECT id, hook_id, provider, config, enabled, created_at, updated_at +FROM hook_notification_channels +WHERE hook_id = ? +ORDER BY provider; + +-- name: GetNotificationChannel :one +SELECT id, hook_id, provider, config, enabled, created_at, updated_at +FROM hook_notification_channels +WHERE hook_id = ? AND provider = ?; + +-- name: UpsertNotificationChannel :one +INSERT INTO hook_notification_channels (hook_id, provider, config) +VALUES (?, ?, ?) +ON CONFLICT (hook_id, provider) DO UPDATE SET + config = excluded.config, + enabled = 1, + updated_at = CURRENT_TIMESTAMP +RETURNING id, hook_id, provider, config, enabled, created_at, updated_at; + +-- name: DeleteNotificationChannel :execrows +DELETE FROM hook_notification_channels +WHERE hook_id = ? AND provider = ?; diff --git a/internal/store/sqlc/hooks.sql.go b/internal/store/sqlc/hooks.sql.go index 8a792b6..eff583d 100644 --- a/internal/store/sqlc/hooks.sql.go +++ b/internal/store/sqlc/hooks.sql.go @@ -92,6 +92,24 @@ func (q *Queries) CreateWebhookRequest(ctx context.Context, arg CreateWebhookReq return i, err } +const deleteNotificationChannel = `-- name: DeleteNotificationChannel :execrows +DELETE FROM hook_notification_channels +WHERE hook_id = ? AND provider = ? +` + +type DeleteNotificationChannelParams struct { + HookID int64 `json:"hook_id"` + Provider string `json:"provider"` +} + +func (q *Queries) DeleteNotificationChannel(ctx context.Context, arg DeleteNotificationChannelParams) (int64, error) { + result, err := q.db.ExecContext(ctx, deleteNotificationChannel, arg.HookID, arg.Provider) + if err != nil { + return 0, err + } + return result.RowsAffected() +} + const deleteWebhookRequestsOlderThan = `-- name: DeleteWebhookRequestsOlderThan :execresult DELETE FROM webhook_requests WHERE received_at < datetime('now', ?) @@ -153,10 +171,37 @@ func (q *Queries) GetHookResponseByHookID(ctx context.Context, hookID int64) (Ho return i, err } +const getNotificationChannel = `-- name: GetNotificationChannel :one +SELECT id, hook_id, provider, config, enabled, created_at, updated_at +FROM hook_notification_channels +WHERE hook_id = ? AND provider = ? +` + +type GetNotificationChannelParams struct { + HookID int64 `json:"hook_id"` + Provider string `json:"provider"` +} + +func (q *Queries) GetNotificationChannel(ctx context.Context, arg GetNotificationChannelParams) (HookNotificationChannel, error) { + row := q.db.QueryRowContext(ctx, getNotificationChannel, arg.HookID, arg.Provider) + var i HookNotificationChannel + err := row.Scan( + &i.ID, + &i.HookID, + &i.Provider, + &i.Config, + &i.Enabled, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const listHooks = `-- name: ListHooks :many SELECT id, token, name, created_at, updated_at FROM hooks -ORDER BY created_at DESC` +ORDER BY created_at DESC +` func (q *Queries) ListHooks(ctx context.Context) ([]Hook, error) { rows, err := q.db.QueryContext(ctx, listHooks) @@ -187,6 +232,44 @@ func (q *Queries) ListHooks(ctx context.Context) ([]Hook, error) { return items, nil } +const listNotificationChannels = `-- name: ListNotificationChannels :many +SELECT id, hook_id, provider, config, enabled, created_at, updated_at +FROM hook_notification_channels +WHERE hook_id = ? +ORDER BY provider +` + +func (q *Queries) ListNotificationChannels(ctx context.Context, hookID int64) ([]HookNotificationChannel, error) { + rows, err := q.db.QueryContext(ctx, listNotificationChannels, hookID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []HookNotificationChannel + for rows.Next() { + var i HookNotificationChannel + if err := rows.Scan( + &i.ID, + &i.HookID, + &i.Provider, + &i.Config, + &i.Enabled, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listWebhookRequestsByHookID = `-- name: ListWebhookRequestsByHookID :many SELECT id, hook_id, method, path, query, headers, body, remote_addr, content_type, body_size, received_at FROM webhook_requests @@ -337,3 +420,34 @@ func (q *Queries) UpsertHookResponse(ctx context.Context, arg UpsertHookResponse ) return i, err } + +const upsertNotificationChannel = `-- name: UpsertNotificationChannel :one +INSERT INTO hook_notification_channels (hook_id, provider, config) +VALUES (?, ?, ?) +ON CONFLICT (hook_id, provider) DO UPDATE SET + config = excluded.config, + enabled = 1, + updated_at = CURRENT_TIMESTAMP +RETURNING id, hook_id, provider, config, enabled, created_at, updated_at +` + +type UpsertNotificationChannelParams struct { + HookID int64 `json:"hook_id"` + Provider string `json:"provider"` + Config string `json:"config"` +} + +func (q *Queries) UpsertNotificationChannel(ctx context.Context, arg UpsertNotificationChannelParams) (HookNotificationChannel, error) { + row := q.db.QueryRowContext(ctx, upsertNotificationChannel, arg.HookID, arg.Provider, arg.Config) + var i HookNotificationChannel + err := row.Scan( + &i.ID, + &i.HookID, + &i.Provider, + &i.Config, + &i.Enabled, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/internal/store/sqlc/models.go b/internal/store/sqlc/models.go index 0c8fe84..9d98d7c 100644 --- a/internal/store/sqlc/models.go +++ b/internal/store/sqlc/models.go @@ -17,6 +17,16 @@ type Hook struct { UpdatedAt time.Time `json:"updated_at"` } +type HookNotificationChannel struct { + ID int64 `json:"id"` + HookID int64 `json:"hook_id"` + Provider string `json:"provider"` + Config string `json:"config"` + Enabled int64 `json:"enabled"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + type HookResponse struct { ID int64 `json:"id"` HookID int64 `json:"hook_id"` diff --git a/internal/web/ui/src/features/endpoint-session/api/endpoint-api.ts b/internal/web/ui/src/features/endpoint-session/api/endpoint-api.ts index 0fd6ab2..e0a4d90 100644 --- a/internal/web/ui/src/features/endpoint-session/api/endpoint-api.ts +++ b/internal/web/ui/src/features/endpoint-session/api/endpoint-api.ts @@ -57,6 +57,30 @@ export async function saveHookResponse(token: string, data: HookResponse): Promi if (!json.success) throw new Error(json.error?.message || 'Failed to save'); } +export interface Notification { + telegramBotToken: string; + telegramChatId: string; + proxyUrl: string; +} + +export async function fetchNotification(token: string): Promise { + const response = await fetch(`/api/endpoints/${token}/notifications`); + const json = (await response.json()) as ApiResponse; + if (!json.success || !json.body) + return { telegramBotToken: '', telegramChatId: '', proxyUrl: '' }; + return json.body; +} + +export async function saveNotification(token: string, data: Notification): Promise { + const response = await fetch(`/api/endpoints/${token}/notifications`, { + method: 'PUT', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(data), + }); + const json = (await response.json()) as ApiResponse; + if (!json.success) throw new Error(json.error?.message || 'Failed to save'); +} + export function connectEvents( token: string, handlers: { diff --git a/internal/web/ui/src/widgets/request-detail/request-detail.ts b/internal/web/ui/src/widgets/request-detail/request-detail.ts index 88bb1a2..3004bcc 100644 --- a/internal/web/ui/src/widgets/request-detail/request-detail.ts +++ b/internal/web/ui/src/widgets/request-detail/request-detail.ts @@ -12,6 +12,8 @@ import { import { fetchHookResponse, saveHookResponse, + fetchNotification, + saveNotification, } from '../../features/endpoint-session/api/endpoint-api'; export function renderSelectedDetail(elements: Elements, state: AppState): void { @@ -237,7 +239,57 @@ function createSettingsForm(token: string | null): HTMLDivElement { saveBtn.className = 'settings-save-btn'; saveBtn.textContent = 'Save'; - wrap.append(statusLabel, statusInput, headersLabel, headersInput, bodyLabel, bodyInput, saveBtn); + // Telegram notifications section + const divider = document.createElement('hr'); + divider.className = 'settings-divider'; + + const tgTitle = document.createElement('h4'); + tgTitle.className = 'settings-section-title'; + tgTitle.textContent = 'Telegram Notifications'; + + const tgTokenLabel = document.createElement('label'); + tgTokenLabel.textContent = 'Bot Token'; + const tgTokenInput = document.createElement('input'); + tgTokenInput.type = 'text'; + tgTokenInput.className = 'settings-input'; + tgTokenInput.placeholder = '123456:ABC-DEF...'; + + const tgChatLabel = document.createElement('label'); + tgChatLabel.textContent = 'Chat ID'; + const tgChatInput = document.createElement('input'); + tgChatInput.type = 'text'; + tgChatInput.className = 'settings-input'; + tgChatInput.placeholder = '-1001234567890'; + + const tgProxyLabel = document.createElement('label'); + tgProxyLabel.textContent = 'Proxy URL (optional)'; + const tgProxyInput = document.createElement('input'); + tgProxyInput.type = 'text'; + tgProxyInput.className = 'settings-input'; + tgProxyInput.placeholder = 'socks5://127.0.0.1:1080'; + + const tgSaveBtn = document.createElement('button'); + tgSaveBtn.className = 'settings-save-btn'; + tgSaveBtn.textContent = 'Save Notifications'; + + wrap.append( + statusLabel, + statusInput, + headersLabel, + headersInput, + bodyLabel, + bodyInput, + saveBtn, + divider, + tgTitle, + tgTokenLabel, + tgTokenInput, + tgChatLabel, + tgChatInput, + tgProxyLabel, + tgProxyInput, + tgSaveBtn, + ); if (token) { void fetchHookResponse(token).then((resp) => { @@ -248,6 +300,12 @@ function createSettingsForm(token: string | null): HTMLDivElement { bodyInput.value = resp.body || ''; }); + void fetchNotification(token).then((n) => { + tgTokenInput.value = n.telegramBotToken; + tgChatInput.value = n.telegramChatId; + tgProxyInput.value = n.proxyUrl; + }); + saveBtn.addEventListener('click', () => { let headers: Record = {}; try { @@ -278,6 +336,26 @@ function createSettingsForm(token: string | null): HTMLDivElement { saveBtn.disabled = false; }); }); + + tgSaveBtn.addEventListener('click', () => { + tgSaveBtn.disabled = true; + void saveNotification(token, { + telegramBotToken: tgTokenInput.value.trim(), + telegramChatId: tgChatInput.value.trim(), + proxyUrl: tgProxyInput.value.trim(), + }) + .then(() => { + tgSaveBtn.textContent = 'Saved!'; + setTimeout(() => (tgSaveBtn.textContent = 'Save Notifications'), 2000); + }) + .catch(() => { + tgSaveBtn.textContent = 'Error'; + setTimeout(() => (tgSaveBtn.textContent = 'Save Notifications'), 2000); + }) + .finally(() => { + tgSaveBtn.disabled = false; + }); + }); } return wrap;