Skip to content

Commit 9470a01

Browse files
authored
[Firebase AI] Add JsonSchema automatic generation (#1425)
* [Firebase AI] Add JsonSchema automatic generation * Update UIHandlerAutomated.cs * Fix feedback issues * Handle nullable better
1 parent e547024 commit 9470a01

3 files changed

Lines changed: 263 additions & 18 deletions

File tree

firebaseai/src/JsonSchema.cs

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
*/
1616

1717
using System;
18+
using System.Collections;
1819
using System.Collections.Generic;
1920
using System.Linq;
21+
using System.Reflection;
2022
using Firebase.AI.Internal;
2123

2224
namespace Firebase.AI
@@ -481,6 +483,10 @@ public static JsonSchema AnyOf(
481483
);
482484
}
483485

486+
/// <summary>
487+
/// Returns a `JsonSchema` that references a definition in a parent object.
488+
/// </summary>
489+
/// <param name="schemaReference">The path to the definition, typically "$/defs/class_name"</param>
484490
public static JsonSchema Ref(string schemaReference)
485491
{
486492
return new JsonSchema(null,
@@ -570,6 +576,187 @@ internal Dictionary<string, object> ToJson()
570576

571577
return json;
572578
}
579+
580+
/// <summary>
581+
/// Generates a JsonSchema for the given type, using reflection.
582+
/// Note that if the type implements: static JsonSchema ToJsonSchema(), that function
583+
/// will be called to generate the JsonSchema.
584+
/// </summary>
585+
/// <param name="type">The type to construct the JsonSchema of.</param>
586+
/// <param name="description">The description to use for the returned JsonSchema</param>
587+
public static JsonSchema FromType(Type type, string description = null)
588+
{
589+
return FromTypeInternal(type, null, description, new Dictionary<string, JsonSchema>(), true, out _);
590+
}
591+
592+
private static JsonSchema FromTypeInternal(Type type, MemberInfo memberInfo, string description,
593+
Dictionary<string, JsonSchema> definitions, bool topLevel, out bool optional)
594+
{
595+
optional = false;
596+
597+
if (type == null) return null;
598+
599+
// Handle Nullable<T> by unwrapping it.
600+
bool isNullableValueType = type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>);
601+
if (isNullableValueType)
602+
{
603+
type = System.Nullable.GetUnderlyingType(type);
604+
}
605+
606+
// If the given type has Schema info, pull it from that
607+
var schemaInfo = memberInfo == null ?
608+
type.GetCustomAttribute<SchemaInfoAttribute>() :
609+
memberInfo.GetCustomAttribute<SchemaInfoAttribute>();
610+
611+
optional = schemaInfo?.Optional ?? false;
612+
613+
// Check if there is a defined function on the type to make the JsonSchema
614+
var toSchemaMethod = type.GetMethod("ToJsonSchema",
615+
BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy);
616+
if (toSchemaMethod != null && toSchemaMethod.ReturnType == typeof(JsonSchema) &&
617+
toSchemaMethod.GetParameters().Length == 0)
618+
{
619+
return (JsonSchema)toSchemaMethod.Invoke(null, null);
620+
}
621+
622+
// If not provided a description, check the schemaInfo object
623+
description ??= schemaInfo?.Description;
624+
bool nullable = schemaInfo != null && schemaInfo.Nullable;
625+
626+
// Check for the commonly used attribute Range, for Min and Max
627+
float? min = null;
628+
float? max = null;
629+
var rangeAttr = memberInfo == null ?
630+
type.GetCustomAttribute<UnityEngine.RangeAttribute>() :
631+
memberInfo.GetCustomAttribute<UnityEngine.RangeAttribute>();
632+
if (rangeAttr != null)
633+
{
634+
min = rangeAttr.min;
635+
max = rangeAttr.max;
636+
}
637+
638+
if (type.IsPrimitive)
639+
{
640+
// Possible primitives:
641+
// bool
642+
// byte, sbyte, short, ushort, int, uint, long, ulong, float, double
643+
// char *Not clearly mapped*
644+
// IntPtr, UIntPtr *Not clearly mapped*
645+
if (type == typeof(bool))
646+
{
647+
return Boolean(description: description, nullable: nullable);
648+
}
649+
else if (type == typeof(float))
650+
{
651+
return Float(description: description, nullable: nullable,
652+
minimum: min, maximum: max);
653+
}
654+
else if (type == typeof(double))
655+
{
656+
return Double(description: description, nullable: nullable,
657+
minimum: min, maximum: max);
658+
}
659+
else if (type == typeof(long) || type == typeof(ulong))
660+
{
661+
return Long(description: description, nullable: nullable,
662+
minimum: (long?)min, maximum: (long?)max);
663+
}
664+
else
665+
{
666+
// Treat everything else as an Int. While there could be logic to add to set
667+
// minimum and maximums based on the type, it will likely be unnecessary.
668+
return Int(description: description, nullable: nullable,
669+
minimum: (int?)min, maximum: (int?)max);
670+
}
671+
}
672+
else if (type.IsEnum)
673+
{
674+
return Enum(type.GetEnumNames(), description: description, nullable: nullable);
675+
}
676+
else if (type == typeof(string))
677+
{
678+
return String(description: description, nullable: nullable);
679+
}
680+
else if (type.IsArray)
681+
{
682+
Type elementType = type.GetElementType();
683+
JsonSchema elementSchema = FromTypeInternal(elementType, null, null, definitions, false, out _);
684+
return Array(elementSchema, description: description, nullable: nullable);
685+
}
686+
else if (type.IsGenericType && typeof(IEnumerable).IsAssignableFrom(type))
687+
{
688+
// There isn't a great way to handle dictionaries, so bail out.
689+
if (typeof(IDictionary).IsAssignableFrom(type))
690+
{
691+
return null;
692+
}
693+
Type elementType = type.GetInterfaces()
694+
.FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>))
695+
?.GetGenericArguments()[0] ?? typeof(object);
696+
JsonSchema elementSchema = FromTypeInternal(elementType, null, null, definitions, false, out _);
697+
return Array(elementSchema, description, nullable: nullable);
698+
}
699+
else
700+
{
701+
// Assume it is an Object
702+
// If this is not at the top level, we want to add it to the definitions, and use a ref instead
703+
if (!topLevel)
704+
{
705+
string key = type.FullName;
706+
if (!definitions.ContainsKey(key))
707+
{
708+
definitions[key] = null; // Placeholder to prevent infinite recursion.
709+
JsonSchema jsonSchema = GenerateObject(type, schemaInfo, description, definitions, false);
710+
definitions[key] = jsonSchema;
711+
}
712+
713+
return Ref($"#/$defs/{key}");
714+
}
715+
716+
// Generate the top level object, which will include any found definitions.
717+
return GenerateObject(type, schemaInfo, description, definitions, topLevel);
718+
}
719+
}
720+
721+
private static JsonSchema GenerateObject(Type type, SchemaInfoAttribute schemaInfo, string description,
722+
Dictionary<string, JsonSchema> definitions, bool includeDefinitions)
723+
{
724+
bool nullable = schemaInfo != null && schemaInfo.Nullable;
725+
726+
Dictionary<string, JsonSchema> properties = new();
727+
List<string> optionalProperties = new();
728+
// Get the public Fields and Properties
729+
var infos = type.FindMembers(
730+
MemberTypes.Field | MemberTypes.Property,
731+
BindingFlags.Instance | BindingFlags.Public,
732+
null, null);
733+
foreach (var info in infos)
734+
{
735+
JsonSchema jsonSchema = null;
736+
bool optional = false;
737+
if (info is FieldInfo fieldInfo)
738+
{
739+
jsonSchema = FromTypeInternal(fieldInfo.FieldType, info, null, definitions, false, out optional);
740+
}
741+
else if (info is PropertyInfo propertyInfo)
742+
{
743+
jsonSchema = FromTypeInternal(propertyInfo.PropertyType, info, null, definitions, false, out optional);
744+
}
745+
746+
if (jsonSchema != null)
747+
{
748+
properties[info.Name] = jsonSchema;
749+
if (optional)
750+
{
751+
optionalProperties.Add(info.Name);
752+
}
753+
}
754+
}
755+
756+
return Object(properties, optionalProperties: optionalProperties,
757+
description: description, title: schemaInfo?.Title,
758+
nullable: nullable, schemaDefinitions: includeDefinitions ? definitions : null);
759+
}
573760
}
574761

575762
}

firebaseai/src/Serialization.cs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,40 @@
2222

2323
namespace Firebase.AI
2424
{
25+
/// <summary>
26+
/// Attribute that can be used to define various fields when generating
27+
/// the JsonSchema for it.
28+
/// </summary>
29+
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Property | AttributeTargets.Field)]
30+
public class SchemaInfoAttribute : Attribute
31+
{
32+
/// <summary>
33+
/// A human-readable explanation of the purpose of the schema or property.
34+
/// </summary>
35+
public string Description { get; set; } = null;
36+
37+
/// <summary>
38+
/// A human-readable name/summary for the schema or a specific property.
39+
/// </summary>
40+
public string Title { get; set; } = null;
41+
42+
/// <summary>
43+
/// Indicates if the value may be null.
44+
/// </summary>
45+
public bool Nullable { get; set; } = false;
46+
47+
/// <summary>
48+
/// The format of the data.
49+
/// </summary>
50+
public string Format { get; set; } = null;
51+
52+
/// <summary>
53+
/// Indicates that the property should be considered as optional.
54+
/// Properties are considered required by default.
55+
/// </summary>
56+
public bool Optional { get; set; } = false;
57+
}
58+
2559
/// <summary>
2660
/// Interface to define a method to construct the object from a Dictionary<string, object>.
2761
///
@@ -80,6 +114,10 @@ internal static object ObjectToType(object obj, Type type)
80114
{
81115
return t;
82116
}
117+
else if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>))
118+
{
119+
return ObjectToType(obj, Nullable.GetUnderlyingType(type));
120+
}
83121
else if (type.IsEnum)
84122
{
85123
if (obj is string str) return Enum.Parse(type, str);

firebaseai/testapp/Assets/Firebase/Sample/FirebaseAI/UIHandlerAutomated.cs

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,33 +1128,53 @@ async Task TestTemplateImagenGenerateImage(Backend backend)
11281128
AssertEq($"Image Height = Width", texture.height, texture.width);
11291129
}
11301130

1131+
// Class used for validating JsonSchema generation
1132+
public class SampleRecord
1133+
{
1134+
public enum MyColor
1135+
{
1136+
Red,
1137+
Green,
1138+
Blue
1139+
}
1140+
1141+
[SchemaInfo(Description = "The first and last name of a person")]
1142+
public string name;
1143+
1144+
public int age;
1145+
1146+
public bool Alive { get; set; }
1147+
1148+
[SchemaInfo(Description = "How much of their life they have left.")]
1149+
[Range(0, 1)]
1150+
public double Percent;
1151+
1152+
public MyColor eye_color;
1153+
1154+
[SchemaInfo(Optional = true)]
1155+
public char BloodType;
1156+
1157+
[SchemaInfo(Nullable = true)]
1158+
public SampleRecord[] Children;
1159+
1160+
public override string ToString()
1161+
{
1162+
return $"{name} {age} {Alive} {Percent} {eye_color} [{string.Join(", ", Children.Select(t => $"({t})"))}]";
1163+
}
1164+
}
1165+
11311166
async Task TestJsonSchemaStructureOutput(Backend backend)
11321167
{
11331168
var model = GetFirebaseAI(backend).GetGenerativeModel(TestModelName,
11341169
generationConfig: new GenerationConfig(
11351170
responseMimeType: "application/json",
1136-
responseJsonSchema: JsonSchema.Object(
1137-
properties: new Dictionary<string, JsonSchema>
1138-
{
1139-
{ "metadata", JsonSchema.Ref("#/$defs/metadata_schema") }
1140-
},
1141-
schemaDefinitions: new Dictionary<string, JsonSchema>
1142-
{
1143-
{
1144-
"metadata_schema", JsonSchema.Object(
1145-
properties: new Dictionary<string, JsonSchema> {
1146-
{ "id", JsonSchema.String() },
1147-
{ "data", JsonSchema.String() }
1148-
}
1149-
)
1150-
}
1151-
})));
1171+
responseJsonSchema: JsonSchema.FromType(typeof(SampleRecord))));
11521172

1153-
var response = await model.GenerateContentAsync(
1173+
var response = await model.GenerateObjectAsync<SampleRecord>(
11541174
"Hello, I am testing setting the response schema with an object, cause you give me some random values.");
11551175

11561176
// There isn't much guarantee on what this will respond with. We just want non-empty.
1157-
Assert("Response was empty.", !string.IsNullOrWhiteSpace(response.Text));
1177+
Assert("Response was missing a Name.", !string.IsNullOrWhiteSpace(response.Result.name));
11581178
}
11591179

11601180
// Test providing a file from a GCS bucket (Firebase Storage) to the model.

0 commit comments

Comments
 (0)