Skip to content

Commit dc32dcf

Browse files
committed
check result validation and tests
1 parent 0316914 commit dc32dcf

3 files changed

Lines changed: 690 additions & 3 deletions

File tree

dbq_validator.go

Lines changed: 201 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"context"
1919
"fmt"
2020
"log/slog"
21+
"strconv"
2122
"time"
2223
)
2324

@@ -83,7 +84,7 @@ func (d DbqDataValidatorImpl) RunCheck(ctx context.Context, adapter DbqDataSourc
8384
"check_query", checkQuery)
8485

8586
startTime := time.Now()
86-
queryResult, pass, err := adapter.ExecuteQuery(ctx, checkQuery)
87+
queryResult, err := adapter.ExecuteQuery(ctx, checkQuery)
8788
elapsed := time.Since(startTime).Milliseconds()
8889

8990
if err != nil {
@@ -95,9 +96,206 @@ func (d DbqDataValidatorImpl) RunCheck(ctx context.Context, adapter DbqDataSourc
9596
"check_expression", check.Expression,
9697
"duration_ms", elapsed)
9798

98-
// todo: do actual check validation based on the query result
9999
result.QueryResultValue = queryResult
100-
result.Pass = pass
100+
result.Pass = d.validateResult(queryResult, check.ParsedCheck)
101101

102102
return result
103103
}
104+
105+
// validateResult checks if the query result meets the check criteria
106+
func (d DbqDataValidatorImpl) validateResult(queryResult string, parsedCheck *CheckExpression) bool {
107+
if parsedCheck == nil {
108+
// If there's no parsed check, consider it a pass (raw queries without validation)
109+
return true
110+
}
111+
112+
// If there's no operator, just check if we got a result (for functions like raw_query)
113+
if parsedCheck.Operator == "" {
114+
return queryResult != ""
115+
}
116+
117+
// Convert query result to float64 for numeric comparisons
118+
actualValue, err := strconv.ParseFloat(queryResult, 64)
119+
if err != nil {
120+
d.logger.Warn("Failed to parse query result as number, treating as string comparison",
121+
"result", queryResult,
122+
"error", err)
123+
return d.validateStringResult(queryResult, parsedCheck)
124+
}
125+
126+
switch parsedCheck.Operator {
127+
case "between":
128+
return d.validateBetweenRange(actualValue, parsedCheck.ThresholdValue)
129+
case ">":
130+
return d.validateGreaterThan(actualValue, parsedCheck.ThresholdValue)
131+
case ">=":
132+
return d.validateGreaterThanOrEqual(actualValue, parsedCheck.ThresholdValue)
133+
case "<":
134+
return d.validateLessThan(actualValue, parsedCheck.ThresholdValue)
135+
case "<=":
136+
return d.validateLessThanOrEqual(actualValue, parsedCheck.ThresholdValue)
137+
case "==", "=":
138+
return d.validateEqual(actualValue, parsedCheck.ThresholdValue)
139+
case "!=", "<>":
140+
return d.validateNotEqual(actualValue, parsedCheck.ThresholdValue)
141+
default:
142+
d.logger.Warn("Unknown operator, defaulting to true",
143+
"operator", parsedCheck.Operator)
144+
return true
145+
}
146+
}
147+
148+
// validateStringResult handles string-based comparisons when numeric parsing fails
149+
func (d DbqDataValidatorImpl) validateStringResult(queryResult string, parsedCheck *CheckExpression) bool {
150+
switch parsedCheck.Operator {
151+
case "==", "=":
152+
if thresholdStr, ok := parsedCheck.ThresholdValue.(string); ok {
153+
return queryResult == thresholdStr
154+
}
155+
return queryResult == fmt.Sprintf("%v", parsedCheck.ThresholdValue)
156+
case "!=", "<>":
157+
if thresholdStr, ok := parsedCheck.ThresholdValue.(string); ok {
158+
return queryResult != thresholdStr
159+
}
160+
return queryResult != fmt.Sprintf("%v", parsedCheck.ThresholdValue)
161+
default:
162+
d.logger.Warn("String comparison not supported for operator, defaulting to false",
163+
"operator", parsedCheck.Operator,
164+
"result", queryResult)
165+
return false
166+
}
167+
}
168+
169+
// validateBetweenRange checks if value is within the specified range
170+
func (d DbqDataValidatorImpl) validateBetweenRange(actualValue float64, thresholdValue interface{}) bool {
171+
betweenRange, ok := thresholdValue.(BetweenRange)
172+
if !ok {
173+
d.logger.Warn("Invalid threshold value for between operator",
174+
"value", thresholdValue)
175+
return false
176+
}
177+
178+
minVal, err := d.convertToFloat64(betweenRange.Min)
179+
if err != nil {
180+
d.logger.Warn("Failed to convert min value to float64",
181+
"min", betweenRange.Min,
182+
"error", err)
183+
return false
184+
}
185+
186+
maxVal, err := d.convertToFloat64(betweenRange.Max)
187+
if err != nil {
188+
d.logger.Warn("Failed to convert max value to float64",
189+
"max", betweenRange.Max,
190+
"error", err)
191+
return false
192+
}
193+
194+
return actualValue >= minVal && actualValue <= maxVal
195+
}
196+
197+
// validateGreaterThan checks if actual > threshold
198+
func (d DbqDataValidatorImpl) validateGreaterThan(actualValue float64, thresholdValue interface{}) bool {
199+
threshold, err := d.convertToFloat64(thresholdValue)
200+
if err != nil {
201+
d.logger.Warn("Failed to convert threshold to float64 for > comparison",
202+
"threshold", thresholdValue,
203+
"error", err)
204+
return false
205+
}
206+
return actualValue > threshold
207+
}
208+
209+
// validateGreaterThanOrEqual checks if actual >= threshold
210+
func (d DbqDataValidatorImpl) validateGreaterThanOrEqual(actualValue float64, thresholdValue interface{}) bool {
211+
threshold, err := d.convertToFloat64(thresholdValue)
212+
if err != nil {
213+
d.logger.Warn("Failed to convert threshold to float64 for >= comparison",
214+
"threshold", thresholdValue,
215+
"error", err)
216+
return false
217+
}
218+
return actualValue >= threshold
219+
}
220+
221+
// validateLessThan checks if actual < threshold
222+
func (d DbqDataValidatorImpl) validateLessThan(actualValue float64, thresholdValue interface{}) bool {
223+
threshold, err := d.convertToFloat64(thresholdValue)
224+
if err != nil {
225+
d.logger.Warn("Failed to convert threshold to float64 for < comparison",
226+
"threshold", thresholdValue,
227+
"error", err)
228+
return false
229+
}
230+
return actualValue < threshold
231+
}
232+
233+
// validateLessThanOrEqual checks if actual <= threshold
234+
func (d DbqDataValidatorImpl) validateLessThanOrEqual(actualValue float64, thresholdValue interface{}) bool {
235+
threshold, err := d.convertToFloat64(thresholdValue)
236+
if err != nil {
237+
d.logger.Warn("Failed to convert threshold to float64 for <= comparison",
238+
"threshold", thresholdValue,
239+
"error", err)
240+
return false
241+
}
242+
return actualValue <= threshold
243+
}
244+
245+
// validateEqual checks if actual == threshold
246+
func (d DbqDataValidatorImpl) validateEqual(actualValue float64, thresholdValue interface{}) bool {
247+
threshold, err := d.convertToFloat64(thresholdValue)
248+
if err != nil {
249+
d.logger.Warn("Failed to convert threshold to float64 for == comparison",
250+
"threshold", thresholdValue,
251+
"error", err)
252+
return false
253+
}
254+
return actualValue == threshold
255+
}
256+
257+
// validateNotEqual checks if actual != threshold
258+
func (d DbqDataValidatorImpl) validateNotEqual(actualValue float64, thresholdValue interface{}) bool {
259+
threshold, err := d.convertToFloat64(thresholdValue)
260+
if err != nil {
261+
d.logger.Warn("Failed to convert threshold to float64 for != comparison",
262+
"threshold", thresholdValue,
263+
"error", err)
264+
return false
265+
}
266+
return actualValue != threshold
267+
}
268+
269+
// convertToFloat64 converts various types to float64
270+
func (d DbqDataValidatorImpl) convertToFloat64(value interface{}) (float64, error) {
271+
switch v := value.(type) {
272+
case float64:
273+
return v, nil
274+
case float32:
275+
return float64(v), nil
276+
case int:
277+
return float64(v), nil
278+
case int8:
279+
return float64(v), nil
280+
case int16:
281+
return float64(v), nil
282+
case int32:
283+
return float64(v), nil
284+
case int64:
285+
return float64(v), nil
286+
case uint:
287+
return float64(v), nil
288+
case uint8:
289+
return float64(v), nil
290+
case uint16:
291+
return float64(v), nil
292+
case uint32:
293+
return float64(v), nil
294+
case uint64:
295+
return float64(v), nil
296+
case string:
297+
return strconv.ParseFloat(v, 64)
298+
default:
299+
return 0, fmt.Errorf("unsupported type: %T", value)
300+
}
301+
}

dbq_validator_integration_test.go

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
// Copyright 2025 The DBQ Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package dbqcore
16+
17+
import (
18+
"context"
19+
"io"
20+
"log/slog"
21+
"testing"
22+
)
23+
24+
// MockAdapter for testing the validation logic
25+
type MockAdapter struct {
26+
queryResult string
27+
queryError error
28+
}
29+
30+
func (m *MockAdapter) InterpretDataQualityCheck(check *DataQualityCheck, dataset string, defaultWhere string) (string, error) {
31+
return "SELECT COUNT(*) FROM " + dataset, nil
32+
}
33+
34+
func (m *MockAdapter) ExecuteQuery(ctx context.Context, query string) (string, error) {
35+
return m.queryResult, m.queryError
36+
}
37+
38+
func TestDbqDataValidator_RunCheck_Integration(t *testing.T) {
39+
validator := NewDbqDataValidator(slog.New(slog.NewTextHandler(io.Discard, nil)))
40+
41+
tests := []struct {
42+
name string
43+
check *DataQualityCheck
44+
queryResult string
45+
expectedPass bool
46+
}{
47+
{
48+
name: "row_count > 100 - pass",
49+
check: &DataQualityCheck{
50+
Expression: "row_count > 100",
51+
ParsedCheck: &CheckExpression{
52+
FunctionName: "row_count",
53+
Operator: ">",
54+
ThresholdValue: 100,
55+
},
56+
},
57+
queryResult: "150",
58+
expectedPass: true,
59+
},
60+
{
61+
name: "row_count > 100 - fail",
62+
check: &DataQualityCheck{
63+
Expression: "row_count > 100",
64+
ParsedCheck: &CheckExpression{
65+
FunctionName: "row_count",
66+
Operator: ">",
67+
ThresholdValue: 100,
68+
},
69+
},
70+
queryResult: "50",
71+
expectedPass: false,
72+
},
73+
{
74+
name: "avg between 3.0 and 5.0 - pass",
75+
check: &DataQualityCheck{
76+
Expression: "avg(rating) between 3.0 and 5.0",
77+
ParsedCheck: &CheckExpression{
78+
FunctionName: "avg",
79+
Operator: "between",
80+
ThresholdValue: BetweenRange{Min: 3.0, Max: 5.0},
81+
},
82+
},
83+
queryResult: "4.2",
84+
expectedPass: true,
85+
},
86+
{
87+
name: "avg between 3.0 and 5.0 - fail",
88+
check: &DataQualityCheck{
89+
Expression: "avg(rating) between 3.0 and 5.0",
90+
ParsedCheck: &CheckExpression{
91+
FunctionName: "avg",
92+
Operator: "between",
93+
ThresholdValue: BetweenRange{Min: 3.0, Max: 5.0},
94+
},
95+
},
96+
queryResult: "6.0",
97+
expectedPass: false,
98+
},
99+
{
100+
name: "freshness < 3600 - pass",
101+
check: &DataQualityCheck{
102+
Expression: "freshness(last_updated) < 3600",
103+
ParsedCheck: &CheckExpression{
104+
FunctionName: "freshness",
105+
Operator: "<",
106+
ThresholdValue: 3600,
107+
},
108+
},
109+
queryResult: "1800",
110+
expectedPass: true,
111+
},
112+
{
113+
name: "freshness < 3600 - fail",
114+
check: &DataQualityCheck{
115+
Expression: "freshness(last_updated) < 3600",
116+
ParsedCheck: &CheckExpression{
117+
FunctionName: "freshness",
118+
Operator: "<",
119+
ThresholdValue: 3600,
120+
},
121+
},
122+
queryResult: "7200",
123+
expectedPass: false,
124+
},
125+
}
126+
127+
for _, tt := range tests {
128+
t.Run(tt.name, func(t *testing.T) {
129+
adapter := &MockAdapter{
130+
queryResult: tt.queryResult,
131+
queryError: nil,
132+
}
133+
134+
result := validator.RunCheck(context.Background(), adapter, tt.check, "test_table", "")
135+
136+
if result.Error != "" {
137+
t.Errorf("Unexpected error: %s", result.Error)
138+
}
139+
140+
if result.Pass != tt.expectedPass {
141+
t.Errorf("Expected Pass = %v, got %v", tt.expectedPass, result.Pass)
142+
}
143+
144+
if result.QueryResultValue != tt.queryResult {
145+
t.Errorf("Expected QueryResultValue = %s, got %s", tt.queryResult, result.QueryResultValue)
146+
}
147+
148+
if result.CheckID != tt.check.Expression {
149+
t.Errorf("Expected CheckID = %s, got %s", tt.check.Expression, result.CheckID)
150+
}
151+
})
152+
}
153+
}

0 commit comments

Comments
 (0)