Skip to content

Commit fe7e278

Browse files
Crnaneoppolariss
authored andcommitted
feat: whitelist of ai_summary and add trace_id
1 parent 7b57965 commit fe7e278

4 files changed

Lines changed: 23 additions & 12 deletions

File tree

apis/hole/apis.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"treehole_next/utils/sensitive"
1515

1616
"github.com/gofiber/fiber/v2"
17+
"github.com/google/uuid"
1718
"github.com/opentreehole/go-common"
1819
"github.com/rs/zerolog/log"
1920
"gorm.io/gorm"
@@ -951,7 +952,13 @@ func DeleteHole(c *fiber.Ctx) error {
951952
return c.Status(204).JSON(nil)
952953
}
953954
func GenerateSummary(c *fiber.Ctx) error {
954-
955+
uid, _ := common.GetUserID(c)
956+
if config.Config.WhiteListUserIds != nil && !slices.Contains(config.Config.WhiteListUserIds, uid) {
957+
return c.Status(404).JSON(fiber.Map{
958+
"code": 404,
959+
"message": "Cannot GET " + c.Path(),
960+
})
961+
}
955962
id, _ := c.ParamsInt("id")
956963
forceRefresh := c.QueryBool("force_refresh")
957964
var cachedData Summary
@@ -1041,6 +1048,7 @@ func GenerateSummary(c *fiber.Ctx) error {
10411048
if cachedData.Data.HoleID != id && cachedData.Data.HoleID != 0 {
10421049
log.Error().
10431050
Int("hole_id", id).
1051+
Str("trace_id", cachedData.TraceID).
10441052
Int("ai_server_return_id", cachedData.Data.HoleID).
10451053
Msg("AISummary: hole id error")
10461054
cachedData.Data.HoleID = id
@@ -1099,9 +1107,10 @@ func GenerateSummary(c *fiber.Ctx) error {
10991107
}
11001108

11011109
requestBody := map[string]any{
1102-
"floors": summaryFloors,
1103-
"content": content,
1104-
"hole_id": hole.ID,
1110+
"floors": summaryFloors,
1111+
"content": content,
1112+
"hole_id": hole.ID,
1113+
"trace_id": uuid.NewString(),
11051114
}
11061115

11071116
requestJSON, err := json.Marshal(requestBody)
@@ -1128,6 +1137,7 @@ func GenerateSummary(c *fiber.Ctx) error {
11281137
if resp.StatusCode != http.StatusOK {
11291138
log.Error().
11301139
Str("url", config.Config.AISummaryURL+"/generate_summary").
1140+
Str("trace_id", requestBody["trace_id"].(string)).
11311141
Int("status", resp.StatusCode).
11321142
Str("req_body_base64", func() string {
11331143
if len(requestJSON) > config.Config.SummaryLogLimit {
@@ -1151,6 +1161,7 @@ func GenerateSummary(c *fiber.Ctx) error {
11511161
if err != nil {
11521162
log.Error().
11531163
Str("url", config.Config.AISummaryURL+"/generate_summary").
1164+
Str("trace_id", requestBody["trace_id"].(string)).
11541165
Int("status", resp.StatusCode).
11551166
Str("req_body_base64", func() string {
11561167
if len(requestJSON) > config.Config.SummaryLogLimit {
@@ -1184,6 +1195,7 @@ func GenerateSummary(c *fiber.Ctx) error {
11841195
response.Message = errCode2Message[response.Code]
11851196
default:
11861197
log.Error().Str("url", config.Config.AISummaryURL+"/generate_summary").
1198+
Str("trace_id", requestBody["trace_id"].(string)).
11871199
Int("status", resp.StatusCode).
11881200
Str("req_body_base64", func() string {
11891201
if len(requestJSON) > config.Config.SummaryLogLimit {

apis/hole/schemas.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ func (body ModifyModel) DoNothing() bool {
133133

134134
type Summary struct {
135135
Code int `json:"code"`
136+
TraceID string `json:"trace_id"`
136137
Message string `json:"message"`
137138
Data struct {
138139
HoleID int `json:"hole_id"`

config/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ var Config struct {
6060
SummaryContentLimit int64 `env:"SUMMARY_CONTENT_LIMIT" envDefault:"500"`
6161
SummarySteps int `env:"SUMMARY_STEPS" envDefault:"5"`
6262
SummaryLogLimit int `env:"SUMMARY_LOG_LIMIT" envDefault:"1000"`
63+
WhiteListUserIds []int `env:"WHITE_LIST_USER_IDS"`
6364
}
6465

6566
var DynamicConfig struct {

models/hole.go

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package models
22

33
import (
44
"fmt"
5-
"strconv"
65
"time"
76

87
"golang.org/x/exp/maps"
@@ -276,7 +275,8 @@ func (holes Holes) Preprocess(c *fiber.Ctx) error {
276275
hole.SetHoleFloor()
277276
floors = append(floors, hole.Floors...)
278277
// set ai_summary_available
279-
hole.AISummaryAvailable = true
278+
uid, _ := common.GetUserID(c)
279+
hole.AISummaryAvailable = config.Config.WhiteListUserIds == nil || slices.Contains(config.Config.WhiteListUserIds, uid)
280280
for _, tag := range hole.Tags {
281281
if len(tag.Name) > 0 && tag.Name[0] == '*' {
282282
hole.AISummaryAvailable = false
@@ -294,14 +294,11 @@ func (holes Holes) Preprocess(c *fiber.Ctx) error {
294294
return err
295295
}
296296

297-
var discard any
298-
hole.AISummaryAvailable = hole.Reply > config.Config.SummaryFloorLimit || contentSum >= config.Config.SummaryContentLimit || utils.GetCache("AISummary"+strconv.Itoa(hole.ID), &discard)
297+
// var discard any
298+
hole.AISummaryAvailable = hole.Reply > config.Config.SummaryFloorLimit || contentSum >= config.Config.SummaryContentLimit // || utils.GetCache("AISummary"+strconv.Itoa(hole.ID), &discard)
299299

300300
if hole.AISummaryAvailable {
301-
err = query.Where("is_sensitive = ?", true).Count(&sensitiveCount).Error
302-
if err != nil {
303-
return err
304-
}
301+
query.Where("is_sensitive = ? AND is_actual_sensitive IS NULL", true).Count(&sensitiveCount)
305302
if sensitiveCount > 0 {
306303
hole.AISummaryAvailable = false
307304
}

0 commit comments

Comments
 (0)