From 12540a4a4900870d8be24a7c2e08e583aef1f7af Mon Sep 17 00:00:00 2001 From: Adam Wulf Date: Mon, 13 May 2024 23:19:15 -0500 Subject: [PATCH] Refactor so that Memory context is built by the MemoryTool, which opens the ability for later tools to also add context to the request. --- Heat/Conversations/ConversationInput.swift | 4 +--- .../Conversations/ConversationViewModel.swift | 19 ++++-------------- Heat/Launcher/LauncherViewModel.swift | 7 +++---- .../Sources/HeatKit/Tools/ContextTool.swift | 5 +++++ .../Sources/HeatKit/Tools/MemoryTool.swift | 20 +++++++++++++++++++ 5 files changed, 33 insertions(+), 22 deletions(-) create mode 100644 HeatKit/Sources/HeatKit/Tools/ContextTool.swift diff --git a/Heat/Conversations/ConversationInput.swift b/Heat/Conversations/ConversationInput.swift index 555a240..eea408c 100644 --- a/Heat/Conversations/ConversationInput.swift +++ b/Heat/Conversations/ConversationInput.swift @@ -11,8 +11,6 @@ struct ConversationInput: View { @Environment(ConversationViewModel.self) var conversationViewModel @Environment(\.modelContext) private var modelContext - @Query(sort: \Memory.created, order: .forward) var memories: [Memory] - @State var imagePickerViewModel: ImagePickerViewModel @State var content: String @State var command: String @@ -184,7 +182,7 @@ struct ConversationInput: View { guard !content.isEmpty else { return } do { - try conversationViewModel.generate(content, context: memories.map { $0.content }) + try conversationViewModel.generate(content) } catch let error as KitError { conversationViewModel.error = error } catch { diff --git a/Heat/Conversations/ConversationViewModel.swift b/Heat/Conversations/ConversationViewModel.swift index 9244cf0..12d7233 100644 --- a/Heat/Conversations/ConversationViewModel.swift +++ b/Heat/Conversations/ConversationViewModel.swift @@ -43,7 +43,7 @@ final class ConversationViewModel { conversationID = conversation.id } - func generate(_ content: String, context: [String] = []) throws { + func generate(_ content: String) throws { guard !content.isEmpty else { return } guard let conversation else { throw KitError.missingConversation @@ -54,13 +54,12 @@ final class ConversationViewModel { let toolService = try store.preferredToolService() let toolModel = try store.preferredToolModel() - - let context = prepareContext(context) - + let contextTools: [ContextTool.Type] = [MemoryTool.self] + generateTask = Task { await MessageManager() .append(messages: messages) - .append(message: context) + .append(messages: contextTools.compactMap({ $0.prepareContext() })) .append(message: .init(role: .user, content: content)) { message in self.store.upsert(suggestions: [], conversationID: conversation.id) self.store.upsert(message: message, conversationID: conversation.id) @@ -253,15 +252,5 @@ final class ConversationViewModel { } return nil } - - private func prepareContext(_ context: [String]) -> Message? { - guard !context.isEmpty else { return nil } - - return Message(role: .system, content: """ - Some things to remember about who the user is. Use these to better relate to the user when responding: - - \(context.joined(separator: "\n")) - """) - } } diff --git a/Heat/Launcher/LauncherViewModel.swift b/Heat/Launcher/LauncherViewModel.swift index 485640b..f66bad8 100644 --- a/Heat/Launcher/LauncherViewModel.swift +++ b/Heat/Launcher/LauncherViewModel.swift @@ -35,7 +35,7 @@ final class LauncherViewModel { conversationID = conversation.id } - func generate(_ content: String, context: [String] = []) throws { + func generate(_ content: String) throws { guard !content.isEmpty else { return } guard let conversation else { throw KitError.missingConversation @@ -43,13 +43,12 @@ final class LauncherViewModel { let chatService = try store.preferredChatService() let chatModel = try store.preferredChatModel() - - let context = prepareContext(context) + let contextTools: [ContextTool.Type] = [MemoryTool.self] generateTask = Task { await MessageManager() .append(messages: messages) - .append(message: context) + .append(messages: contextTools.compactMap({ $0.prepareContext() })) .append(message: .init(role: .user, content: content)) { message in self.store.upsert(suggestions: [], conversationID: conversation.id) self.store.upsert(message: message, conversationID: conversation.id) diff --git a/HeatKit/Sources/HeatKit/Tools/ContextTool.swift b/HeatKit/Sources/HeatKit/Tools/ContextTool.swift new file mode 100644 index 0000000..2d35e58 --- /dev/null +++ b/HeatKit/Sources/HeatKit/Tools/ContextTool.swift @@ -0,0 +1,5 @@ +import GenKit + +public protocol ContextTool { + static func prepareContext() -> Message? +} diff --git a/HeatKit/Sources/HeatKit/Tools/MemoryTool.swift b/HeatKit/Sources/HeatKit/Tools/MemoryTool.swift index 476e98c..cbddec8 100644 --- a/HeatKit/Sources/HeatKit/Tools/MemoryTool.swift +++ b/HeatKit/Sources/HeatKit/Tools/MemoryTool.swift @@ -70,3 +70,23 @@ extension MemoryTool { } } } + +extension MemoryTool: ContextTool { + public static func prepareContext() -> Message? { + guard let container = try? ModelContainer(for: Memory.self) else { return nil } + let fetchRequest = FetchDescriptor(sortBy: [SortDescriptor(\Memory.created, order: .forward)]) + let context = ModelContext(container) + do { + let memories: [Memory] = try context.fetch(fetchRequest) + guard !memories.isEmpty else { return nil } + return Message(role: .system, content: """ + Some things to remember about who the user is. Use these to better relate to the user when responding: + + \(memories.map({ $0.content }).joined(separator: "\n")) + """) + } catch { + print("Failed to fetch memories: \(error)") + return nil + } + } +}