Skip to content

Commit a74e0cd

Browse files
committed
Instantiate AgenticPolicyCompiler environment from YAML definitions
1 parent 4b92925 commit a74e0cd

18 files changed

Lines changed: 283 additions & 214 deletions

bundle/src/main/java/dev/cel/bundle/CelEnvironment.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import dev.cel.common.types.OptionalType;
4444
import dev.cel.common.types.SimpleType;
4545
import dev.cel.common.types.TypeParamType;
46+
import dev.cel.common.types.TypeType;
4647
import dev.cel.compiler.CelCompiler;
4748
import dev.cel.compiler.CelCompilerBuilder;
4849
import dev.cel.compiler.CelCompilerLibrary;
@@ -71,9 +72,10 @@ public abstract class CelEnvironment {
7172
"math", CanonicalCelExtension.MATH,
7273
"optional", CanonicalCelExtension.OPTIONAL,
7374
"protos", CanonicalCelExtension.PROTOS,
75+
"regex", CanonicalCelExtension.REGEX,
7476
"sets", CanonicalCelExtension.SETS,
7577
"strings", CanonicalCelExtension.STRINGS,
76-
"comprehensions", CanonicalCelExtension.COMPREHENSIONS);
78+
"two-var-comprehensions", CanonicalCelExtension.COMPREHENSIONS);
7779

7880
private static final ImmutableMap<String, ObjIntConsumer<CelOptions.Builder>> LIMIT_HANDLERS =
7981
ImmutableMap.of(
@@ -102,7 +104,7 @@ public abstract class CelEnvironment {
102104
/**
103105
* Container, which captures default namespace and aliases for value resolution.
104106
*/
105-
public abstract CelContainer container();
107+
public abstract Optional<CelContainer> container();
106108

107109
/**
108110
* An optional description of the environment (example: location of the file containing the config
@@ -226,7 +228,6 @@ public static Builder newBuilder() {
226228
return new AutoValue_CelEnvironment.Builder()
227229
.setName("")
228230
.setDescription("")
229-
.setContainer(CelContainer.ofName(""))
230231
.setVariables(ImmutableSet.of())
231232
.setFunctions(ImmutableSet.of())
232233
.setFeatures(ImmutableSet.of())
@@ -242,7 +243,6 @@ public CelCompiler extend(CelCompiler celCompiler, CelOptions celOptions)
242243
CelCompilerBuilder compilerBuilder =
243244
celCompiler
244245
.toCompilerBuilder()
245-
.setContainer(container())
246246
.setOptions(celOptions)
247247
.setTypeProvider(celTypeProvider)
248248
.addVarDeclarations(
@@ -254,6 +254,9 @@ public CelCompiler extend(CelCompiler celCompiler, CelOptions celOptions)
254254
.map(f -> f.toCelFunctionDecl(celTypeProvider))
255255
.collect(toImmutableList()));
256256

257+
258+
container().ifPresent(compilerBuilder::setContainer);
259+
257260
addAllCompilerExtensions(compilerBuilder, celOptions);
258261

259262
applyStandardLibrarySubset(compilerBuilder);
@@ -416,6 +419,9 @@ public abstract static class VariableDecl {
416419
/** The type of the variable. */
417420
public abstract TypeDecl type();
418421

422+
/** Description of the variable. */
423+
public abstract Optional<String> description();
424+
419425
/** Builder for {@link VariableDecl}. */
420426
@AutoValue.Builder
421427
public abstract static class Builder implements RequiredFieldsChecker {
@@ -428,6 +434,8 @@ public abstract static class Builder implements RequiredFieldsChecker {
428434

429435
public abstract VariableDecl.Builder setType(TypeDecl typeDecl);
430436

437+
public abstract VariableDecl.Builder setDescription(String name);
438+
431439
@Override
432440
public ImmutableList<RequiredField> requiredFields() {
433441
return ImmutableList.of(
@@ -667,6 +675,9 @@ public CelType toCelType(CelTypeProvider celTypeProvider) {
667675
CelType keyType = params().get(0).toCelType(celTypeProvider);
668676
CelType valueType = params().get(1).toCelType(celTypeProvider);
669677
return MapType.create(keyType, valueType);
678+
case "type":
679+
checkState(params().size() == 1, "Expected 1 parameter for type, got " + params().size());
680+
return TypeType.create(params().get(0).toCelType(celTypeProvider));
670681
default:
671682
if (isTypeParam()) {
672683
return TypeParamType.create(name());
@@ -838,10 +849,14 @@ enum CanonicalCelExtension {
838849
SETS(
839850
(options, version) -> CelExtensions.sets(options),
840851
(options, version) -> CelExtensions.sets(options)),
852+
REGEX(
853+
(options, version) -> CelExtensions.regex(),
854+
(options, version) -> CelExtensions.regex()),
841855
LISTS((options, version) -> CelExtensions.lists(), (options, version) -> CelExtensions.lists()),
842856
COMPREHENSIONS(
843857
(options, version) -> CelExtensions.comprehensions(),
844-
(options, version) -> CelExtensions.comprehensions());
858+
(options, version) -> CelExtensions.comprehensions())
859+
;
845860

846861
@SuppressWarnings("ImmutableEnumChecker")
847862
private final CompilerExtensionProvider compilerExtensionProvider;

bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlParser.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,9 @@ private VariableDecl parseVariable(ParserContext<Node> ctx, Node node) {
353353
case "name":
354354
builder.setName(newString(ctx, valueNode));
355355
break;
356+
case "description":
357+
builder.setDescription(newString(ctx, valueNode));
358+
break;
356359
case "type":
357360
if (typeDeclBuilder != null) {
358361
ctx.reportError(
@@ -428,6 +431,9 @@ private FunctionDecl parseFunction(ParserContext<Node> ctx, Node node) {
428431
case "overloads":
429432
builder.setOverloads(parseOverloads(ctx, valueNode));
430433
break;
434+
case "description":
435+
// TODO: Set description
436+
break;
431437
default:
432438
ctx.reportError(keyId, String.format("Unsupported function tag: %s", keyName));
433439
break;
@@ -479,6 +485,9 @@ private static ImmutableSet<OverloadDecl> parseOverloads(ParserContext<Node> ctx
479485
case "target":
480486
overloadDeclBuilder.setTarget(parseTypeDecl(ctx, valueNode));
481487
break;
488+
case "examples":
489+
// TODO: Set examples
490+
break;
482491
default:
483492
ctx.reportError(keyId, String.format("Unsupported overload tag: %s", fieldName));
484493
break;

bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlSerializer.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,9 @@ public Node representData(Object data) {
7979
if (!environment.description().isEmpty()) {
8080
configMap.put("description", environment.description());
8181
}
82-
if (!environment.container().name().isEmpty()
83-
|| !environment.container().abbreviations().isEmpty()
84-
|| !environment.container().aliases().isEmpty()) {
85-
configMap.put("container", environment.container());
82+
83+
if (environment.container().isPresent()) {
84+
configMap.put("container", environment.container().get());
8685
}
8786
if (!environment.extensions().isEmpty()) {
8887
configMap.put("extensions", environment.extensions().asList());

tools/ai/BUILD.bazel

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,21 @@ package(
55
default_visibility = ["//visibility:public"],
66
)
77

8+
java_library(
9+
name = "agentic_policy_environment",
10+
exports = ["//tools/src/main/java/dev/cel/tools/ai:agentic_policy_environment"],
11+
)
12+
813
java_library(
914
name = "agentic_policy_compiler",
1015
exports = ["//tools/src/main/java/dev/cel/tools/ai:agentic_policy_compiler"],
1116
)
1217

18+
alias(
19+
name = "ai_environments",
20+
actual = "//tools/src/main/resources/environment:ai_environments",
21+
)
22+
1323
alias(
1424
name = "test_policies",
1525
testonly = True,

tools/src/main/java/dev/cel/tools/ai/AgenticPolicyCompiler.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import static dev.cel.common.formats.YamlHelper.assertYamlType;
44

5+
import com.google.protobuf.Descriptors.FileDescriptor;
56
import dev.cel.bundle.Cel;
67
import dev.cel.common.CelAbstractSyntaxTree;
78
import dev.cel.common.formats.ValueString;
@@ -66,7 +67,9 @@ public void visitPolicyTag(
6667
break;
6768

6869
case "variables":
69-
if (!assertYamlType(ctx, id, node, YamlNodeType.LIST)) return;
70+
if (!assertYamlType(ctx, id, node, YamlNodeType.LIST)) {
71+
return;
72+
}
7073
List<Variable> parsedVariables = new ArrayList<>();
7174
SequenceNode varList = (SequenceNode) node;
7275

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
package dev.cel.tools.ai;
2+
3+
import static java.nio.charset.StandardCharsets.UTF_8;
4+
5+
import com.google.common.base.Ascii;
6+
import com.google.common.collect.ImmutableCollection;
7+
import com.google.common.collect.ImmutableList;
8+
import com.google.common.collect.ImmutableSet;
9+
import com.google.common.io.Resources;
10+
import com.google.protobuf.Descriptors.FileDescriptor;
11+
import dev.cel.bundle.Cel;
12+
import dev.cel.bundle.CelEnvironment;
13+
import dev.cel.bundle.CelEnvironmentException;
14+
import dev.cel.bundle.CelEnvironmentYamlParser;
15+
import dev.cel.bundle.CelFactory;
16+
import dev.cel.common.CelContainer;
17+
import dev.cel.common.CelOptions;
18+
import dev.cel.common.types.CelType;
19+
import dev.cel.common.types.CelTypeProvider;
20+
import dev.cel.common.types.OpaqueType;
21+
import dev.cel.expr.ai.Agent;
22+
import dev.cel.expr.ai.AgentMessage;
23+
import dev.cel.expr.ai.AgentMessage.Part;
24+
import dev.cel.expr.ai.ClassificationLabel;
25+
import dev.cel.expr.ai.Finding;
26+
import dev.cel.parser.CelStandardMacro;
27+
import dev.cel.runtime.CelFunctionBinding;
28+
import java.io.IOException;
29+
import java.net.URL;
30+
import java.util.ArrayList;
31+
import java.util.List;
32+
import java.util.Optional;
33+
34+
final class AgenticPolicyEnvironment {
35+
36+
private static final CelOptions CEL_OPTIONS =
37+
CelOptions.current()
38+
.enableTimestampEpoch(true)
39+
.populateMacroCalls(true)
40+
.build();
41+
42+
private static final Cel CEL_BASE_ENV =
43+
CelFactory.standardCelBuilder()
44+
.setContainer(CelContainer.ofName("cel.expr.ai")) // TODO: config?
45+
.addFileTypes(Agent.getDescriptor().getFile())
46+
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
47+
.setTypeProvider(new AgentTypeProvider())
48+
.addFunctionBindings(
49+
CelFunctionBinding.from(
50+
"AgentMessage_threatFindings",
51+
ImmutableList.of(AgentMessage.class),
52+
(args) -> getFindings((AgentMessage) args[0], "threats", ClassificationLabel.Category.THREAT)
53+
),
54+
CelFunctionBinding.from(
55+
"ai.finding_string_double",
56+
ImmutableList.of(String.class, Double.class),
57+
(args) -> Finding.newBuilder()
58+
.setValue((String) args[0])
59+
.setConfidence((Double) args[1])
60+
.build()
61+
),
62+
CelFunctionBinding.from(
63+
"optional_type(list(Finding))_hasAll_list(Finding)",
64+
ImmutableList.of(Optional.class, List.class),
65+
(args) -> hasAllFindings((Optional<List<Finding>>) args[0], (List<Finding>) args[1])
66+
)
67+
)
68+
.setOptions(CEL_OPTIONS)
69+
.build();
70+
71+
private static Optional<List<Finding>> getFindings(AgentMessage msg, String labelName, ClassificationLabel.Category category) {
72+
List<Finding> results = new ArrayList<>();
73+
74+
for (Part part : msg.getPartsList()) {
75+
if (part.hasPrompt()) {
76+
// TODO: Collect from classification
77+
results.add(Finding.newBuilder().setValue("prompt_injection").setConfidence(1.0d).build());
78+
} else if (part.hasToolCall()) {
79+
// TODO: Collect from classification
80+
}
81+
82+
}
83+
84+
if (results.isEmpty()) {
85+
return Optional.empty();
86+
}
87+
88+
return Optional.of(results);
89+
}
90+
91+
private static boolean hasAllFindings(Optional<List<Finding>> sourceOpt, List<Finding> required) {
92+
if (!sourceOpt.isPresent()) {
93+
return false;
94+
}
95+
List<Finding> source = sourceOpt.get();
96+
97+
return required.stream().allMatch(req ->
98+
source.stream().anyMatch(act ->
99+
act.getValue().equals(req.getValue()) &&
100+
act.getConfidence() >= req.getConfidence()
101+
)
102+
);
103+
}
104+
105+
static Cel newInstance() {
106+
Cel celEnv = CEL_BASE_ENV;
107+
108+
celEnv = extendFromConfig(celEnv, "environment/agent_env.yaml");
109+
celEnv = extendFromConfig(celEnv, "environment/common_env.yaml");
110+
return extendFromConfig(celEnv, "environment/tool_call_env.yaml");
111+
}
112+
113+
private static Cel extendFromConfig(Cel cel, String yamlConfigPath) {
114+
String yamlEnv;
115+
try {
116+
yamlEnv = readFile(yamlConfigPath);
117+
} catch (IOException e) {
118+
String errorMsg = String.format("Failed to read %s: %s", yamlConfigPath, e.getMessage());
119+
throw new IllegalArgumentException(errorMsg, e);
120+
}
121+
try {
122+
CelEnvironment env = CelEnvironmentYamlParser.newInstance().parse(yamlEnv);
123+
return env.extend(cel, CEL_OPTIONS);
124+
} catch (CelEnvironmentException e) {
125+
String errorMsg = String.format("Failed to extend CEL environment from %s: %s", yamlConfigPath, e.getMessage());
126+
throw new IllegalArgumentException(errorMsg, e);
127+
}
128+
}
129+
130+
private static String readFile(String path) throws IOException {
131+
URL url = Resources.getResource(Ascii.toLowerCase(path));
132+
return Resources.toString(url, UTF_8);
133+
}
134+
135+
private static final class AgentTypeProvider implements CelTypeProvider {
136+
private static final OpaqueType AGENT_MESSAGE_SET_TYPE = OpaqueType.create("cel.expr.ai.AgentMessageSet");
137+
138+
private static final ImmutableSet<CelType> ALL_TYPES = ImmutableSet.of(AGENT_MESSAGE_SET_TYPE);
139+
140+
@Override
141+
public ImmutableCollection<CelType> types() {
142+
return ALL_TYPES;
143+
}
144+
@Override
145+
public Optional<CelType> findType(String typeName) {
146+
if (typeName.equals(AGENT_MESSAGE_SET_TYPE.name())) {
147+
return Optional.of(AGENT_MESSAGE_SET_TYPE);
148+
}
149+
150+
return Optional.empty();
151+
}
152+
}
153+
154+
private AgenticPolicyEnvironment() {}
155+
}

tools/src/main/java/dev/cel/tools/ai/BUILD.bazel

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,17 @@ package(
66
"//:license",
77
],
88
default_visibility = ["//visibility:public"],
9-
# default_visibility = [
10-
# "//tools/ai:__pkg__",
11-
# ],
9+
# default_visibility = [
10+
# "//tools/ai:__pkg__",
11+
# ],
1212
)
1313

1414
java_library(
1515
name = "agentic_policy_compiler",
1616
srcs = ["AgenticPolicyCompiler.java"],
1717
deps = [
1818
":agent_context_java_proto",
19+
":agentic_policy_environment",
1920
"//bundle:cel",
2021
"//common:cel_ast",
2122
"//common/formats:value_string",
@@ -33,6 +34,28 @@ java_library(
3334
],
3435
)
3536

37+
java_library(
38+
name = "agentic_policy_environment",
39+
srcs = ["AgenticPolicyEnvironment.java"],
40+
resources = ["//tools/ai:ai_environments"],
41+
deps = [
42+
":agent_context_extensions_java_proto",
43+
":agent_context_java_proto",
44+
"//bundle:cel",
45+
"//bundle:environment",
46+
"//bundle:environment_exception",
47+
"//bundle:environment_yaml_parser",
48+
"//common:container",
49+
"//common:options",
50+
"//common/types",
51+
"//common/types:type_providers",
52+
"//parser:macro",
53+
"//runtime:function_binding",
54+
"@maven//:com_google_guava_guava",
55+
"@maven//:com_google_protobuf_protobuf_java",
56+
],
57+
)
58+
3659
proto_library(
3760
name = "agent_context_proto",
3861
srcs = ["agent_context.proto"],

0 commit comments

Comments
 (0)