diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..5685892 Binary files /dev/null and b/.DS_Store differ diff --git a/conversational_analytics/README.md b/conversational_analytics/README.md new file mode 100644 index 0000000..1e7fc0d --- /dev/null +++ b/conversational_analytics/README.md @@ -0,0 +1,24 @@ +# Spanner Conversational Analytics Examples + +This directory contains examples demonstrating how to use the Conversational Analytics API with Google Cloud Spanner. These examples show how to build data agents that can understand natural language questions and generate SQL queries to retrieve answers from your Spanner database. + +**Prerequisites:** + +* A Google Cloud Project with Spanner and an active Spanner instance. +* The Cloud AI Companion API enabled. +* Appropriate IAM permissions to access Spanner and the Conversational Analytics API. See [Conversational Analytics API access control with IAM](https://docs.cloud.google.com/gemini/data-agents/conversational-analytics-api/access-control) for details. +* `gcloud` CLI installed and configured. +* Python 3.7+ (for SDK examples). + +## Examples + +This guide provides examples for interacting with the API using both the Python Client Library (SDK) and HTTP requests (curl). + +### 1. Using the Python SDK + +This example shows how to set up a Data Agent and have a conversation to query your Spanner database. + +**Installation:** + +```bash +pip install google-cloud-aiplatform diff --git a/conversational_analytics/agent_spanner_http_example.ipynb b/conversational_analytics/agent_spanner_http_example.ipynb new file mode 100644 index 0000000..c652db6 --- /dev/null +++ b/conversational_analytics/agent_spanner_http_example.ipynb @@ -0,0 +1,724 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [], + "metadata": { + "id": "kUl_qiynkVjI" + } + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# # Spanner Data Agent End-to-End Example\n", + "#\n", + "# This notebook demonstrates how to create and interact with a Data Agent for Spanner using the Gemini Data Analytics API via HTTP requests in Python.\n" + ], + "metadata": { + "id": "aKXN5h8dkaBd" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Define helper functions\n", + "\n", + "import json as json_lib\n", + "import textwrap\n", + "import pandas as pd\n", + "from IPython.display import display, HTML\n", + "import altair as alt\n", + "from pygments import formatters, highlight, lexers\n", + "import requests\n", + "\n", + "def is_json(my_str):\n", + " \"\"\"Checks if a string is valid JSON.\"\"\"\n", + " try:\n", + " json_lib.loads(my_str)\n", + " except ValueError:\n", + " return False\n", + " return True\n", + "\n", + "def handle_text_response(resp):\n", + " \"\"\"Handles and prints text responses, wrapping long lines.\"\"\"\n", + " parts = resp.get('parts', [])\n", + " full_text = \"\".join(parts)\n", + " if \"\\n\" not in full_text and len(full_text) > 80:\n", + " wrapped_text = textwrap.fill(full_text, width=80)\n", + " print(wrapped_text)\n", + " else:\n", + " print(full_text)\n", + "\n", + "def get_property(data, field_name, default=''):\n", + " \"\"\"Safely gets a property from a dictionary.\"\"\"\n", + " return data.get(field_name, default)\n", + "\n", + "def display_schema(data):\n", + " \"\"\"Displays schema information in a DataFrame.\"\"\"\n", + " fields = data.get('fields', [])\n", + " df = pd.DataFrame({\n", + " \"Column\": [get_property(field, 'name') for field in fields],\n", + " \"Type\": [get_property(field, 'type') for field in fields],\n", + " \"Description\": [get_property(field, 'description', '-') for field in fields],\n", + " \"Mode\": [get_property(field, 'mode') for field in fields]\n", + " })\n", + " display(df)\n", + "\n", + "def display_section_title(text):\n", + " \"\"\"Displays a formatted section title in HTML.\"\"\"\n", + " display(HTML(f'

{text}

'))\n", + "\n", + "def format_bq_table_ref(table_ref):\n", + " \"\"\"Formats a BigQuery table reference for display.\"\"\"\n", + " return f\"{table_ref.get('projectId')}.{table_ref.get('datasetId')}.{table_ref.get('tableId')}\"\n", + "\n", + "def format_looker_table_ref(table_ref):\n", + " \"\"\"Formats a Looker table reference for display.\"\"\"\n", + " return f\"lookmlModel: {table_ref.get('lookmlModel')}, explore: {table_ref.get('explore')}, lookerInstanceUri: {table_ref.get('lookerInstanceUri')}\"\n", + "\n", + "def display_datasource(datasource):\n", + " \"\"\"Displays information about a datasource, including its schema.\"\"\"\n", + " source_name = ''\n", + " if 'studioDatasourceId' in datasource:\n", + " source_name = datasource['studioDatasourceId']\n", + " elif 'lookerExploreReference' in datasource:\n", + " source_name = format_looker_table_ref(datasource['lookerExploreReference'])\n", + " elif 'bigqueryTableReference' in datasource:\n", + " source_name = format_bq_table_ref(datasource['bigqueryTableReference'])\n", + " else:\n", + " source_name = \"Unknown Datasource\"\n", + "\n", + " print(source_name)\n", + " if 'schema' in datasource:\n", + " display_schema(datasource['schema'])\n", + "\n", + "def handle_schema_response(resp):\n", + " \"\"\"Handles responses related to schema resolution.\"\"\"\n", + " if 'query' in resp:\n", + " print(get_property(resp['query'], 'question'))\n", + " elif 'result' in resp:\n", + " display_section_title('Schema resolved')\n", + " print('Data sources:')\n", + " for datasource in get_property(resp['result'], 'datasources', []):\n", + " display_datasource(datasource)\n", + "\n", + "def handle_data_response(resp):\n", + " \"\"\"Handles responses containing data or SQL queries.\"\"\"\n", + " if 'query' in resp:\n", + " query = resp['query']\n", + " display_section_title('Retrieval query')\n", + " print(f\"Query name: {get_property(query, 'name')}\")\n", + " if 'question' in query:\n", + " print(f\"Question: {get_property(query, 'question')}\")\n", + " if 'datasources' in query:\n", + " print('Data sources:')\n", + " for datasource in get_property(query, 'datasources', []):\n", + " display_datasource(datasource)\n", + " elif 'generatedSql' in resp:\n", + " display_section_title('SQL generated')\n", + " print(resp['generatedSql'])\n", + " elif 'result' in resp:\n", + " display_section_title('Data retrieved')\n", + " result = resp['result']\n", + " schema = result.get('schema', {})\n", + " fields = [get_property(field, 'name') for field in schema.get('fields', [])]\n", + " data = result.get('data', [])\n", + "\n", + " data_dict = {field: [get_property(el, field) for el in data] for field in fields}\n", + " display(pd.DataFrame(data_dict))\n", + "\n", + "def handle_chart_response(resp):\n", + " \"\"\"Handles responses for generating charts.\"\"\"\n", + " if 'query' in resp:\n", + " print(get_property(resp['query'], 'instructions'))\n", + " elif 'result' in resp:\n", + " vegaConfig = get_property(resp['result'], 'vegaConfig')\n", + " if vegaConfig:\n", + " alt.Chart.from_json(json_lib.dumps(vegaConfig)).display()\n", + "\n", + "def handle_error(resp):\n", + " \"\"\"Handles error responses.\"\"\"\n", + " display_section_title('Error')\n", + " print(f\"Code: {get_property(resp, 'code')}\")\n", + " print(f\"Message: {get_property(resp, 'message')}\")\n", + "\n" + ], + "metadata": { + "id": "ozC3VvPfE6qp" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "\n", + "def get_stream(url, payload):\n", + " \"\"\"\n", + " Posts data to a URL and processes the streaming JSON response line by line.\n", + "\n", + " Args:\n", + " url (str): The URL to send the POST request to.\n", + " payload (dict): The JSON payload to send in the request body.\n", + " \"\"\"\n", + " s = requests.Session()\n", + " acc = '' # Accumulator for JSON parts\n", + "\n", + " try:\n", + " with s.post(url, json=payload, headers=headers, stream=True) as resp:\n", + " resp.raise_for_status() # Raise an exception for bad status codes\n", + " for line in resp.iter_lines():\n", + " if not line:\n", + " continue\n", + "\n", + " try:\n", + " decoded_line = line.decode('utf-8')\n", + " except UnicodeDecodeError:\n", + " print(f\"Warning: Could not decode line: {line}\")\n", + " continue\n", + "\n", + " # This custom JSON assembly logic seems fragile.\n", + " # It's attempting to piece together a JSON object\n", + " # from lines that might not be complete JSON objects themselves.\n", + " if decoded_line == '[{':\n", + " acc = '{'\n", + " elif decoded_line == '}]':\n", + " acc += '}'\n", + " elif decoded_line == ',':\n", + " continue\n", + " else:\n", + " acc += decoded_line\n", + "\n", + " if not is_json(acc):\n", + " continue\n", + "\n", + " try:\n", + " data_json = json_lib.loads(acc)\n", + " except json_lib.JSONDecodeError:\n", + " print(f\"Warning: Could not decode accumulated JSON: {acc}\")\n", + " acc = '' # Reset accumulator on error\n", + " continue\n", + "\n", + " if 'error' in data_json:\n", + " handle_error(data_json['error'])\n", + " acc = ''\n", + " continue\n", + "\n", + " if 'systemMessage' in data_json:\n", + " system_message = data_json['systemMessage']\n", + " if 'text' in system_message:\n", + " handle_text_response(system_message['text'])\n", + " elif 'schema' in system_message:\n", + " handle_schema_response(system_message['schema'])\n", + " elif 'data' in system_message:\n", + " handle_data_response(system_message['data'])\n", + " elif 'chart' in system_message:\n", + " handle_chart_response(system_message['chart'])\n", + " else:\n", + " # Fallback for unhandled systemMessage types\n", + " colored_json = highlight(acc, lexers.JsonLexer(), formatters.TerminalFormatter())\n", + " print(colored_json)\n", + " else:\n", + " # Fallback for responses without systemMessage or error\n", + " colored_json = highlight(acc, lexers.JsonLexer(), formatters.TerminalFormatter())\n", + " print(colored_json)\n", + "\n", + " acc = '' # Reset accumulator after processing a complete JSON object\n", + "\n", + " except requests.exceptions.RequestException as e:\n", + " print(f\"Error during request: {e}\")\n", + " except Exception as e:\n", + " print(f\"An unexpected error occurred: {e}\")\n", + "\n" + ], + "metadata": { + "id": "zx1t7jToHKJH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def get_stream_multi_turn(url, json, conversation_messages):\n", + " s = requests.Session()\n", + "\n", + " acc = ''\n", + "\n", + " with s.post(url, json=json, headers=headers, stream=True) as resp:\n", + " for line in resp.iter_lines():\n", + " if not line:\n", + " continue\n", + "\n", + " decoded_line = str(line, encoding='utf-8')\n", + "\n", + " if decoded_line == '[{':\n", + " acc = '{'\n", + " elif decoded_line == '}]':\n", + " acc += '}'\n", + " elif decoded_line == ',':\n", + " continue\n", + " else:\n", + " acc += decoded_line\n", + "\n", + " if not is_json(acc):\n", + " continue\n", + "\n", + " data_json = json_lib.loads(acc)\n", + " # Store the response that will be used in the next iteration\n", + " conversation_messages.append(data_json)\n", + "\n", + " if not 'systemMessage' in data_json:\n", + " if 'error' in data_json:\n", + " handle_error(data_json['error'])\n", + " continue\n", + "\n", + " if 'text' in data_json['systemMessage']:\n", + " handle_text_response(data_json['systemMessage']['text'])\n", + " elif 'schema' in data_json['systemMessage']:\n", + " handle_schema_response(data_json['systemMessage']['schema'])\n", + " elif 'data' in data_json['systemMessage']:\n", + " handle_data_response(data_json['systemMessage']['data'])\n", + " elif 'chart' in data_json['systemMessage']:\n", + " handle_chart_response(data_json['systemMessage']['chart'])\n", + " else:\n", + " colored_json = highlight(acc, lexers.JsonLexer(), formatters.TerminalFormatter())\n", + " print(colored_json)\n", + " print('\\n')\n", + " acc = ''" + ], + "metadata": { + "id": "fQ7_GcurkS2M" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# # Spanner Data Agent End-to-End Example\n", + "#\n", + "# This notebook demonstrates how to create and interact with a Data Agent for Spanner using the Gemini Data Analytics API via HTTP requests in Python.\n", + "\n", + "# %% [markdown]\n", + "# ## 1. Setup and Authentication\n", + "#\n", + "# Import necessary libraries and authenticate your user. This will be used to authorize the API requests.\n", + "\n", + "# %% code\n", + "import json\n", + "import requests\n", + "from google.colab import auth\n", + "import os\n", + "\n", + "# Authenticate user and get access token\n", + "auth.authenticate_user()\n", + "access_token = !gcloud auth application-default print-access-token\n", + "\n", + "if not access_token or not access_token[0]:\n", + " raise ValueError(\"Failed to get access token. Please ensure you are authenticated.\")\n", + "\n", + "headers = {\n", + " \"Authorization\": f\"Bearer {access_token[0]}\",\n", + " \"Content-Type\": \"application/json\",\n", + " \"x-server-timeout\": \"300\", # Custom timeout up to 600s\n", + "}\n", + "print(\"Headers configured.\")" + ], + "metadata": { + "id": "8PuSOsrXkipJ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 2. Configure Data Source Details\n", + "#\n", + "# Update the variables below with your specific Google Cloud project, Spanner instance details, and desired agent settings.\n", + "\n", + "# %% code\n", + "# Replace with your actual project and Spanner details\n", + "billing_project = \"project-id\" # Your billing project ID\n", + "location = \"global\" # The region of your Spanner instance\n", + "\n", + "# Spanner connection details\n", + "spanner_project_id = \"project-id\" # Project ID of the Spanner instance\n", + "spanner_instance_id = \"instance-id\" # Your spanner instance ID\n", + "spanner_database_id = \"database-id\" # Your database name\n", + "engine = \"GOOGLE_SQL\"\n", + "system_instruction = \"Help the user analyze data from the spanner database\"\n", + "\n", + "# Data Agent ID\n", + "data_agent_id = \"spanner_agent_e2e_example\"" + ], + "metadata": { + "id": "36Ffl_9SlOYF" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 3. Define Spanner Data Source\n", + "#\n", + "# This dictionary structure tells the Data Agent how to connect to your Spanner database.\n", + "# Its better to provide context_set_id similar to https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/data-agent-authored-context-databases#providing_context_with_querydata\n", + "\n", + "# %% code\n", + "# Spanner data source definition\n", + "spanner_data_sources = {\n", + " \"spanner_reference\": {\n", + " \"database_reference\": {\n", + " \"project_id\": spanner_project_id,\n", + " \"region\": location,\n", + " \"engine\": engine,\n", + " \"instance_id\": spanner_instance_id,\n", + " \"database_id\": spanner_database_id,\n", + " },\n", + " # Optional: Include this if you have pre-authored context for the agent\n", + " # \"agent_context_reference\": {\n", + " # \"context_set_id\": f\"projects/{billing_project}/locations/{location}/contextSets/your_context_set_id\"\n", + " # }\n", + " }\n", + "}\n", + "print(\"Spanner data source configured:\")\n", + "print(json.dumps(spanner_data_sources, indent=2))\n" + ], + "metadata": { + "id": "Mz0_sMECVSZT" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 4. Create the Data Agent\n", + "#\n", + "# Send a POST request to the Gemini Data Analytics API to create the Data Agent resource.\n", + "# Make sure roles/geminidataanalytics.dataAgentCreator is granted\n", + "\n", + "# %% code\n", + "data_agent_url = f\"https://geminidataanalytics.googleapis.com/v1beta/projects/{billing_project}/locations/global/dataAgents\"\n", + "\n", + "data_agent_payload = {\n", + " \"name\": f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\",\n", + " \"description\": \"This is an example Spanner data agent created from Colab.\", # Optional\n", + " \"data_analytics_agent\": {\n", + " \"published_context\": {\n", + " \"datasource_references\": spanner_data_sources,\n", + " \"system_instruction\": system_instruction,\n", + " }\n", + " }\n", + "}\n", + "\n", + "params = {\"data_agent_id\": data_agent_id} # Optional\n", + "\n", + "print(f\"Creating Data Agent: {data_agent_id}...\")\n", + "data_agent_response = requests.post(\n", + " data_agent_url, params=params, json=data_agent_payload, headers=headers\n", + ")\n", + "\n", + "if data_agent_response.status_code == 200:\n", + " print(\"Data Agent created successfully!\")\n", + " print(json.dumps(data_agent_response.json(), indent=2))\n", + "elif data_agent_response.status_code == 409:\n", + " print(f\"Data Agent '{data_agent_id}' already exists.\")\n", + " # Optionally, you could add code here to fetch the existing agent\n", + "else:\n", + " print(f\"Error creating Data Agent: {data_agent_response.status_code}\")\n", + " print(data_agent_response.text)" + ], + "metadata": { + "id": "ir8wraonVhnN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 5. Create the conversation\n", + "#\n", + "# Send a POST request to the Gemini Data Analytics API to create conversation.\n", + "\n", + "# %% code\n", + "conversation_url = f\"https://geminidataanalytics.googleapis.com/v1beta/projects/{billing_project}/locations/global/conversations\"\n", + "\n", + "conversation_id = \"conversation_agent_spanner_example\"\n", + "\n", + "conversation_payload = {\n", + " \"agents\": [\n", + " f\"projects/{billing_project}/locations/global/dataAgents/{data_agent_id}\"\n", + " ],\n", + " \"name\": f\"projects/{billing_project}/locations/global/conversations/{conversation_id}\"\n", + "}\n", + "\n", + "params = {\n", + " \"conversation_id\": conversation_id\n", + "}\n", + "\n", + "conversation_response = requests.post(conversation_url, headers=headers, params=params, json=conversation_payload)\n", + "\n", + "if conversation_response.status_code == 200:\n", + " print(\"Conversation created successfully!\")\n", + " print(json.dumps(conversation_response.json(), indent=2))\n", + "else:\n", + " print(f\"Error creating Conversation: {conversation_response.status_code}\")\n", + " print(conversation_response.text)" + ], + "metadata": { + "id": "Fcim8o3Q4RXM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 6. Chat with the API by using conversation (stateful)\n", + "#\n", + "# Send a POST request to the Gemini Data Analytics API to chat using a conversation\n", + "\n", + "# %% code\n", + "\n", + "chat_url = f\"https://geminidataanalytics.googleapis.com/v1beta/projects/{billing_project}/locations/global:chat\"\n", + "\n", + "# The natural language question to ask the data agent\n", + "user_prompt = \"Which artist has the most cards in the database?\" # Replace with your question\n", + "\n", + "\n", + "# Construct the payload\n", + "chat_payload = {\n", + " \"parent\": f\"projects/{billing_project}/locations/global\",\n", + " \"messages\": [\n", + " {\n", + " \"userMessage\": {\n", + " \"text\": user_prompt\n", + " }\n", + " }\n", + " ],\n", + " \"conversation_reference\": {\n", + " \"conversation\": f\"projects/{billing_project}/locations/global/conversations/{conversation_id}\",\n", + " \"data_agent_context\": {\n", + " \"data_agent\": f\"projects/{billing_project}/locations/global/dataAgents/{data_agent_id}\",\n", + " }\n", + " }\n", + "}\n", + "\n", + "print(f\"Sending prompt to :chat: '{user_prompt}'\")\n", + "print(f\"Endpoint: {chat_url}\")\n", + "print(f\"Payload: {json.dumps(chat_payload, indent=2)}\")\n", + "\n", + "\n", + "get_stream(chat_url, chat_payload)" + ], + "metadata": { + "id": "hnzypmKM6o54" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 7. Chat with the API by using data agent (stateless)\n", + "#\n", + "# Send a POST request to the Gemini Data Analytics API to chat without conversation\n", + "# %% code\n", + "\n", + "chat_url = f\"https://geminidataanalytics.googleapis.com/v1beta/projects/{billing_project}/locations/{location}:chat\"\n", + "\n", + "\n", + "# Construct the payload\n", + "chat_payload = {\n", + " \"parent\": f\"projects/{billing_project}/locations/global\",\n", + " \"messages\": [\n", + " {\n", + " \"userMessage\": {\n", + " \"text\": user_prompt\n", + " }\n", + " }\n", + " ],\n", + " \"data_agent_context\": {\n", + " \"data_agent\": f\"projects/{billing_project}/locations/global/dataAgents/{data_agent_id}\",\n", + " }\n", + "}\n", + "\n", + "print(f\"Sending prompt to :chat: '{user_prompt}'\")\n", + "print(f\"Endpoint: {chat_url}\")\n", + "print(f\"Payload: {json.dumps(chat_payload, indent=2)}\")\n", + "\n", + "get_stream(chat_url, chat_payload)\n" + ], + "metadata": { + "id": "gMxaruBA-Z2M" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 8. Chat with the API by using inline context (stateless)\n", + "#\n", + "# Send a POST request to the Gemini Data Analytics API to chat with inline context\n", + "# %% code\n", + "\n", + "chat_url = f\"https://geminidataanalytics.googleapis.com/v1beta/projects/{billing_project}/locations/us-east4:chat\"\n", + "\n", + "\n", + "# Construct the payload\n", + "chat_payload = {\n", + " \"parent\": f\"projects/{billing_project}/locations/{location}\",\n", + " \"messages\": [\n", + " {\n", + " \"userMessage\": {\n", + " \"text\": user_prompt\n", + " }\n", + " }\n", + " ],\n", + " \"inline_context\": {\n", + " \"datasource_references\": spanner_data_sources\n", + " # The \"options\" field for python analysis is not shown in the Spanner example\n", + " }\n", + "}\n", + "\n", + "print(f\"Sending prompt to :chat: '{user_prompt}'\")\n", + "print(f\"Endpoint: {chat_url}\")\n", + "print(f\"Payload: {json.dumps(chat_payload, indent=2)}\")\n", + "\n", + "get_stream(chat_url, chat_payload)\n" + ], + "metadata": { + "id": "CPG0ge21BeHY" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 9. Chat with the API multi turn conversation\n", + "#\n", + "# Send a POST request to the Gemini Data Analytics API to chat with multi turn conversation\n", + "# %% code\n", + "\n", + "\n", + "chat_url = f\"https://geminidataanalytics.googleapis.com/v1beta/projects/{billing_project}/locations/{location}:chat\"\n", + "\n", + "# List that is used to track previous turns and is reused across requests\n", + "conversation_messages = []\n", + "\n", + "# Helper function for calling the API\n", + "def multi_turn_conversation(msg):\n", + " userMessage = {\n", + " \"userMessage\": {\n", + " \"text\": msg\n", + " }\n", + " }\n", + "\n", + " # Send a multi-turn request by including previous turns and the new message\n", + " conversation_messages.append(userMessage)\n", + " print(f\"Current conversation history: {json.dumps(conversation_messages, indent=2)}\")\n", + "\n", + " # Construct the payload for Spanner\n", + " chat_payload = {\n", + " \"parent\": f\"projects/{billing_project}/locations/{location}\",\n", + " \"messages\": conversation_messages,\n", + " \"data_agent_context\": {\n", + " \"data_agent\": f\"projects/{billing_project}/locations/global/dataAgents/{data_agent_id}\",\n", + " # \"credentials\": looker_credentials\n", + " },\n", + "\n", + " }\n", + "\n", + " # Call the get_stream_multi_turn helper function to stream the response\n", + " get_stream_multi_turn(chat_url, chat_payload, conversation_messages)\n", + "\n", + "# Send first-turn request\n", + "print(\"--- Turn 1 ---\")\n", + "multi_turn_conversation(\"Which artist has the most cards in the database?\")\n", + "\n", + "# Send follow-up-turn request\n", + "print(\"\\n--- Turn 2 ---\")\n", + "multi_turn_conversation(\"How many unique artists are represented in the database?\")" + ], + "metadata": { + "id": "NUzvL7VECTZL" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "m7QJvKgvjRkv" + }, + "outputs": [], + "source": [ + "\n", + "# %% [markdown]\n", + "# ## 6. Clean up (Optional)\n", + "#\n", + "# Delete the Data Agent resource if it's no longer needed.\n", + "\n", + "# %% code\n", + "def delete_data_agent(project, agent_id):\n", + " delete_url = f\"https://geminidataanalytics.googleapis.com/v1beta/projects/{project}/locations/global/dataAgents/{agent_id}\"\n", + " print(f\"Deleting Data Agent: {agent_id}...\")\n", + " delete_response = requests.delete(delete_url, headers=headers)\n", + " if delete_response.status_code == 200:\n", + " print(f\"Data Agent {agent_id} deleted successfully.\")\n", + " elif delete_response.status_code == 404:\n", + " print(f\"Data Agent {agent_id} not found.\")\n", + " else:\n", + " print(f\"Error deleting Data Agent {agent_id}: {delete_response.status_code}\")\n", + " print(delete_response.text)\n", + "\n", + "# Uncomment to delete the agent:\n", + "delete_data_agent(billing_project, data_agent_id)\n", + "\n" + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "oym_sAQCjxCN" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/conversational_analytics/agent_spanner_sdk_example.ipynb b/conversational_analytics/agent_spanner_sdk_example.ipynb new file mode 100644 index 0000000..b26cfea --- /dev/null +++ b/conversational_analytics/agent_spanner_sdk_example.ipynb @@ -0,0 +1,832 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyM1RPSBgSG0l+5cRdf2KRek" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TZ5__KYykwgr" + }, + "outputs": [], + "source": [ + "# %% [markdown]\n", + "# ## 1. Define helper function\n", + "#\n", + "# %% code\n", + "\n", + "import json as json_lib\n", + "import textwrap\n", + "import time\n", + "\n", + "import altair as alt\n", + "import IPython\n", + "import pandas as pd\n", + "import proto\n", + "import requests\n", + "from google.iam.v1 import iam_policy_pb2\n", + "from google.iam.v1 import policy_pb2\n", + "from google.protobuf import field_mask_pb2\n", + "from google.protobuf.json_format import MessageToDict, MessageToJson\n", + "from IPython.display import HTML, display\n", + "from pygments import formatters, highlight, lexers\n", + "\n", + "\n", + "def handle_text_response(resp):\n", + " parts = resp.parts\n", + " full_text = \"\".join(parts)\n", + " if \"\\n\" not in full_text and len(full_text) > 80:\n", + " wrapped_text = textwrap.fill(full_text, width=80)\n", + " print(wrapped_text)\n", + " else:\n", + " print(full_text)\n", + "\n", + "\n", + "def display_schema(data):\n", + " fields = getattr(data, \"fields\")\n", + " df = pd.DataFrame({\n", + " \"Column\": map(lambda field: getattr(field, \"name\"), fields),\n", + " \"Type\": map(lambda field: getattr(field, \"type\"), fields),\n", + " \"Description\": map(lambda field: getattr(field, \"description\", \"-\"), fields),\n", + " \"Mode\": map(lambda field: getattr(field, \"mode\"), fields),\n", + " })\n", + " display(df)\n", + "\n", + "\n", + "def display_section_title(text):\n", + " display(HTML(\"

{}

\".format(text)))\n", + "\n", + "\n", + "def format_looker_table_ref(table_ref):\n", + " return \"lookmlModel: {}, explore: {}, lookerInstanceUri: {}\".format(\n", + " table_ref.lookml_model, table_ref.explore, table_ref.looker_instance_uri\n", + " )\n", + "\n", + "\n", + "def format_bq_table_ref(table_ref):\n", + " return \"{}.{}.{}\".format(\n", + " table_ref.project_id, table_ref.dataset_id, table_ref.table_id\n", + " )\n", + "\n", + "\n", + "def display_datasource(datasource):\n", + " source_name = \"\"\n", + " if \"studio_datasource_id\" in datasource:\n", + " source_name = getattr(datasource, \"studio_datasource_id\")\n", + " elif \"looker_explore_reference\" in datasource:\n", + " source_name = format_looker_table_ref(\n", + " getattr(datasource, \"looker_explore_reference\")\n", + " )\n", + " else:\n", + " source_name = format_bq_table_ref(\n", + " getattr(datasource, \"bigquery_table_reference\")\n", + " )\n", + "\n", + " print(source_name)\n", + " display_schema(datasource.schema)\n", + "\n", + "\n", + "def handle_schema_response(resp):\n", + " if \"query\" in resp:\n", + " print(resp.query.question)\n", + " elif \"result\" in resp:\n", + " display_section_title(\"Schema resolved\")\n", + " print(\"Data sources:\")\n", + " for datasource in resp.result.datasources:\n", + " display_datasource(datasource)\n", + "\n", + "\n", + "def handle_data_response(resp):\n", + " if \"query\" in resp:\n", + " query = resp.query\n", + " display_section_title(\"Retrieval query\")\n", + " print(f\"Query name: {query.name}\")\n", + " if \"question\" in query:\n", + " print(f\"Question: {query.question}\")\n", + " if \"datasources\" in query:\n", + " print(\"Data sources:\")\n", + " for datasource in query.datasources:\n", + " display_datasource(datasource)\n", + " elif \"generated_sql\" in resp:\n", + " display_section_title(\"SQL generated\")\n", + " print(resp.generated_sql)\n", + " elif \"result\" in resp:\n", + " display_section_title(\"Data retrieved\")\n", + "\n", + " fields = [field.name for field in resp.result.schema.fields]\n", + " d = {}\n", + " for el in resp.result.data:\n", + " for field in fields:\n", + " if field in d:\n", + " d[field].append(el[field])\n", + " else:\n", + " d[field] = [el[field]]\n", + "\n", + " display(pd.DataFrame(d))\n", + "\n", + "\n", + "def handle_chart_response(resp):\n", + " def _value_to_dict(v):\n", + " if isinstance(v, proto.marshal.collections.maps.MapComposite):\n", + " return _map_to_dict(v)\n", + " elif isinstance(v, proto.marshal.collections.RepeatedComposite):\n", + " return [_value_to_dict(el) for el in v]\n", + " elif isinstance(v, (int, float, str, bool)):\n", + " return v\n", + " else:\n", + " return MessageToDict(v)\n", + "\n", + " def _map_to_dict(d):\n", + " out = {}\n", + " for k in d:\n", + " if isinstance(d[k], proto.marshal.collections.maps.MapComposite):\n", + " out[k] = _map_to_dict(d[k])\n", + " else:\n", + " out[k] = _value_to_dict(d[k])\n", + " return out\n", + "\n", + " if \"query\" in resp:\n", + " print(resp.query.instructions)\n", + " elif \"result\" in resp:\n", + " vegaConfig = resp.result.vega_config\n", + " vegaConfig_dict = _map_to_dict(vegaConfig)\n", + " alt.Chart.from_json(json_lib.dumps(vegaConfig_dict)).display()\n", + "\n", + "\n", + "def show_message(msg):\n", + " m = msg.system_message\n", + " if \"text\" in m:\n", + " handle_text_response(getattr(m, \"text\"))\n", + " elif \"schema\" in m:\n", + " handle_schema_response(getattr(m, \"schema\"))\n", + " elif \"data\" in m:\n", + " handle_data_response(getattr(m, \"data\"))\n", + " elif \"chart\" in m:\n", + " handle_chart_response(getattr(m, \"chart\"))\n", + " print(\"\\n\")\n" + ] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 2. Authenticate GDA API\n", + "#\n", + "# %% code\n", + "\n", + "%pip install google-cloud-geminidataanalytics\n", + "from google.colab import auth\n", + "auth.authenticate_user()\n", + "\n", + "from google.cloud import geminidataanalytics\n", + "\n", + "data_agent_client = geminidataanalytics.DataAgentServiceClient()\n", + "data_chat_client = geminidataanalytics.DataChatServiceClient()\n" + ], + "metadata": { + "id": "U5WW4RPTpGsX" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 3. Configure Data Source Details\n", + "#\n", + "# Update the variables below with your specific Google Cloud project, spanner instance details, and desired agent settings.\n", + "\n", + "# %% code\n", + "# Replace with your actual project and spanner details\n", + "billing_project = \"project\" # Your billing project ID\n", + "location = \"location\" # The region of your spanner instance\n", + "\n", + "# spanner connection details\n", + "spanner_project_id = \"project-id\" # Project ID of the spanner instance\n", + "engine = \"GOOGLE_SQL\"\n", + "spanner_instance_id = \"instance-id\" # Your spanner instance ID\n", + "spanner_database_id = \"database-id\" # Your database name (e.g., postgres)\n", + "\n", + "system_instruction = \"Help the user analyze data from the spanner database\"\n", + "\n", + "# Data Agent ID\n", + "data_agent_id = \"spanner_agent_e2e_example\"" + ], + "metadata": { + "id": "XYsWJRaKqHPk" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 4. Connect to spanner source\n", + "#\n", + "# Update the variables below with your specific Google Cloud project, spanner instance details, and desired agent settings.\n", + "# # Its better to provide context_set_id similar to https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/data-agent-authored-context-databases#providing_context_with_querydata\n", + "\n", + "# %% code\n", + "# Replace with your actual project and spanner details\n", + "spanner_ref_1 = geminidataanalytics.SpannerReference()\n", + "spanner_ref_1.database_reference = geminidataanalytics.SpannerDatabaseReference()\n", + "spanner_ref_1.database_reference.project_id = billing_project\n", + "spanner_ref_1.database_reference.region = \"us-central1\" # Example region\n", + "spanner_ref_1.database_reference.engine = engine\n", + "spanner_ref_1.database_reference.instance_id = spanner_instance_id\n", + "spanner_ref_1.database_reference.database_id = spanner_database_id\n", + "# optional set agent context reference\n", + "# spanner_ref_1.agent_context_reference.context_set_id = f\"projects/{billing_project}/locations/{location}/contextSets/your_context_set_id\"\n", + "\n", + "datasource_references = geminidataanalytics.DatasourceReferences()\n", + "datasource_references.spanner_reference = spanner_ref_1\n", + "\n", + "print(\"spanner DatasourceReferences created\")" + ], + "metadata": { + "id": "kkuyz0qXrMDC" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 5. Stateful chat\n", + "#\n", + "\n", + "# %% code\n", + "# Set up context for stateful chat\n", + "published_context = geminidataanalytics.Context()\n", + "published_context.system_instruction = system_instruction\n", + "published_context.datasource_references = datasource_references\n", + "# Optional: To enable advanced analysis with Python, include the following line:\n", + "published_context.options.analysis.python.enabled = True\n" + ], + "metadata": { + "id": "jSC7Bi0ExV81" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 6. Create a syncronize data agent\n", + "#\n", + "\n", + "# %% code\n", + "\n", + "data_agent = geminidataanalytics.DataAgent()\n", + "data_agent.data_analytics_agent.published_context = published_context\n", + "data_agent.name = f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\" # Optional\n", + "\n", + "request = geminidataanalytics.CreateDataAgentRequest(\n", + " parent=f\"projects/{billing_project}/locations/{location}\",\n", + " data_agent_id=data_agent_id, # Optional\n", + " data_agent=data_agent,\n", + ")\n", + "\n", + "try:\n", + " response = data_agent_client.create_data_agent_sync(request=request)\n", + " print(\"Data Agent created\")\n", + " print(response)\n", + "except Exception as e:\n", + " print(f\"Error creating Data Agent: {e}\")\n" + ], + "metadata": { + "id": "N934aEpXxlmW" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 7. Create a asyncronize data agent\n", + "#\n", + "\n", + "# data_agent_id = \"data_agent_spanner\"\n", + "\n", + "# data_agent = geminidataanalytics.DataAgent()\n", + "# data_agent.data_analytics_agent.published_context = published_context\n", + "# data_agent.name = f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\" # Optional\n", + "\n", + "# request = geminidataanalytics.CreateDataAgentRequest(\n", + "# parent=f\"projects/{billing_project}/locations/{location}\",\n", + "# data_agent_id=data_agent_id, # Optional\n", + "# data_agent=data_agent,\n", + "# )\n", + "\n", + "# try:\n", + "# data_agent_client.create_data_agent(request=request)\n", + "# print(\"Data Agent created\")\n", + "# except Exception as e:\n", + "# print(f\"Error creating Data Agent: {e}\")\n", + "\n" + ], + "metadata": { + "id": "nHCranZXyBOY" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 8. Create a conversation\n", + "#\n", + "\n", + "# %% code\n", + "# Initialize request arguments\n", + "conversation_id = \"conversation_spanner_example\"\n", + "\n", + "conversation = geminidataanalytics.Conversation()\n", + "conversation.agents = [f'projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}']\n", + "conversation.name = f\"projects/{billing_project}/locations/{location}/conversations/{conversation_id}\"\n", + "\n", + "request = geminidataanalytics.CreateConversationRequest(\n", + " parent=f\"projects/{billing_project}/locations/{location}\",\n", + " conversation_id=conversation_id,\n", + " conversation=conversation,\n", + ")\n", + "\n", + "# Make the request\n", + "response = data_chat_client.create_conversation(request=request)\n", + "\n", + "# Handle the response\n", + "print(response)\n" + ], + "metadata": { + "id": "t6K4__0BypcM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 9. get a data agent\n", + "#\n", + "\n", + "# %% code\n", + "\n", + "# Initialize request arguments\n", + "request = geminidataanalytics.GetDataAgentRequest(\n", + " name=f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\",\n", + ")\n", + "\n", + "# Make the request\n", + "response = data_agent_client.get_data_agent(request=request)\n", + "\n", + "# Handle the response\n", + "print(response)\n" + ], + "metadata": { + "id": "Ye5_Tx23y-r-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 10. list data agent\n", + "#\n", + "\n", + "# %% code\n", + "creator_filter = \"YOUR-CREATOR-FILTER\" #optional\n", + "\n", + "request = geminidataanalytics.ListDataAgentsRequest(\n", + " parent=f\"projects/{billing_project}/locations/{location}\",\n", + " # creator_filter=creator_filter, optional\n", + ")\n", + "\n", + "# Make the request\n", + "page_result = data_agent_client.list_data_agents(request=request)\n", + "\n", + "# Handle the response\n", + "for response in page_result:\n", + " print(response)\n" + ], + "metadata": { + "collapsed": true, + "id": "aDBTrJHszLkn" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 11. update a agent\n", + "#\n", + "\n", + "# %% code\n", + "\n", + "data_agent = geminidataanalytics.DataAgent()\n", + "data_agent.data_analytics_agent.published_context = published_context\n", + "data_agent.name = f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\"\n", + "data_agent.description = \"Updated description of the data agent.\"\n", + "\n", + "update_mask = field_mask_pb2.FieldMask(paths=['description', 'data_analytics_agent.published_context'])\n", + "\n", + "request = geminidataanalytics.UpdateDataAgentRequest(\n", + " data_agent=data_agent,\n", + " update_mask=update_mask,\n", + ")\n", + "\n", + "try:\n", + " # Make the request\n", + " response = data_agent_client.update_data_agent_sync(request=request)\n", + " print(\"Data Agent Updated\")\n", + " print(response)\n", + "except Exception as e:\n", + " print(f\"Error updating Data Agent: {e}\")\n" + ], + "metadata": { + "collapsed": true, + "id": "yxLspk9RzU0A" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "\n", + "# %% [markdown]\n", + "# ## 12. Get IAM policy\n", + "#\n", + "\n", + "# %% code\n", + "\n", + "resource = f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\"\n", + "request = iam_policy_pb2.GetIamPolicyRequest(\n", + " resource=resource,\n", + " )\n", + "try:\n", + " response = data_agent_client.get_iam_policy(request=request)\n", + " print(\"IAM Policy fetched successfully!\")\n", + " print(f\"Response: {response}\")\n", + "except Exception as e:\n", + " print(f\"Error setting IAM policy: {e}\")\n" + ], + "metadata": { + "collapsed": true, + "id": "q0hJ-qKh0gOx" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 13. Update IAM policy\n", + "#\n", + "\n", + "# %% code\n", + "role = \"roles/geminidataanalytics.dataAgentEditor\"\n", + "users = \"tzha@google.com\" # replace with your email\n", + "\n", + "resource = f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\"\n", + "\n", + "# Construct the IAM policy\n", + "binding = policy_pb2.Binding(\n", + " role=role,\n", + " members= [f\"user:{i.strip()}\" for i in users.split(\",\")]\n", + ")\n", + "\n", + "policy = policy_pb2.Policy(bindings=[binding])\n", + "\n", + "# Create the request\n", + "request = iam_policy_pb2.SetIamPolicyRequest(\n", + " resource=resource,\n", + " policy=policy\n", + ")\n", + "\n", + "# Send the request\n", + "try:\n", + " response = data_agent_client.set_iam_policy(request=request)\n", + " print(\"IAM Policy set successfully!\")\n", + " print(f\"Response: {response}\")\n", + "except Exception as e:\n", + " print(f\"Error setting IAM policy: {e}\")\n" + ], + "metadata": { + "id": "9_RQbHrg00EY" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 14. Get a conversation\n", + "#\n", + "\n", + "# %% code\"\n", + "\n", + "request = geminidataanalytics.GetConversationRequest(\n", + " name = f\"projects/{billing_project}/locations/{location}/conversations/{conversation_id}\"\n", + ")\n", + "\n", + "# Make the request\n", + "response = data_chat_client.get_conversation(request=request)\n", + "\n", + "# Handle the response\n", + "print(response)\n" + ], + "metadata": { + "id": "GO4ObQfG1Qi7" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 15. list conversation\n", + "#\n", + "\n", + "# %% code\"\n", + "request = geminidataanalytics.ListConversationsRequest(\n", + " parent=f\"projects/{billing_project}/locations/{location}\",\n", + ")\n", + "\n", + "# Make the request\n", + "response = data_chat_client.list_conversations(request=request)\n", + "\n", + "# Handle the response\n", + "print(response)\n" + ], + "metadata": { + "collapsed": true, + "id": "TrxPjSGQ1cp0" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 16. list messages in conversation\n", + "#\n", + "\n", + "# %% code\"\n", + "request = geminidataanalytics.ListMessagesRequest(\n", + " parent=f\"projects/{billing_project}/locations/{location}/conversations/{conversation_id}\",\n", + ")\n", + "\n", + "# Make the request\n", + "response = data_chat_client.list_messages(request=request)\n", + "\n", + "# Handle the response\n", + "print(response)\n" + ], + "metadata": { + "id": "6LWgpjhO1lXV" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 17. Stateful chat with agent and conversation\n", + "#\n", + "\n", + "# %% code\"\n", + "# Create a request that contains a single user message (your question)\n", + "question = \"How many cards has John Avon illustrated in total?\"\n", + "messages = [geminidataanalytics.Message()]\n", + "messages[0].user_message.text = question\n", + "\n", + "\n", + "# Create a conversation_reference\n", + "conversation_reference = geminidataanalytics.ConversationReference()\n", + "conversation_reference.conversation = f\"projects/{billing_project}/locations/{location}/conversations/{conversation_id}\"\n", + "conversation_reference.data_agent_context.data_agent = f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\"\n", + "# conversation_reference.data_agent_context.credentials = credentials\n", + "\n", + "# Form the request\n", + "request = geminidataanalytics.ChatRequest(\n", + " parent = f\"projects/{billing_project}/locations/{location}\",\n", + " messages = messages,\n", + " conversation_reference = conversation_reference\n", + ")\n", + "\n", + "# Make the request\n", + "stream = data_chat_client.chat(request=request, timeout=300) #custom timeout up to 600s\n", + "\n", + "# Handle the response\n", + "for response in stream:\n", + " show_message(response)\n" + ], + "metadata": { + "collapsed": true, + "id": "xLo-hiiw16sH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 18. Stateless chat without conversation\n", + "#\n", + "\n", + "# %% code\n", + "# Create a request that contains a single user message (your question)\n", + "question = \"How many cards has John Avon illustrated in total?\"\n", + "\n", + "messages = [geminidataanalytics.Message()]\n", + "messages[0].user_message.text = question\n", + "\n", + "\n", + "data_agent_context = geminidataanalytics.DataAgentContext()\n", + "data_agent_context.data_agent = f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\"\n", + "# data_agent_context.credentials = credentials\n", + "\n", + "# Form the request\n", + "request = geminidataanalytics.ChatRequest(\n", + " parent=f\"projects/{billing_project}/locations/{location}\",\n", + " messages=messages,\n", + " data_agent_context = data_agent_context\n", + ")\n", + "\n", + "# Make the request\n", + "stream = data_chat_client.chat(request=request, timeout=300) #custom timeout up to 600s\n", + "\n", + "# Handle the response\n", + "for response in stream:\n", + " show_message(response)\n" + ], + "metadata": { + "collapsed": true, + "id": "awhj4bd52oiz" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 19. Ask questions with inline context\n", + "#\n", + "\n", + "# %% code\n", + "# Create a request that contains a single user message (your question)\n", + "messages = [geminidataanalytics.Message()]\n", + "messages[0].user_message.text = question\n", + "\n", + "inline_context = geminidataanalytics.Context()\n", + "inline_context.system_instruction = system_instruction\n", + "inline_context.datasource_references = datasource_references\n", + "# Optional: To enable advanced analysis with Python, include the following line:\n", + "inline_context.options.analysis.python.enabled = True\n", + "\n", + "request = geminidataanalytics.ChatRequest(\n", + " inline_context=inline_context,\n", + " parent=f\"projects/{billing_project}/locations/{location}\",\n", + " messages=messages,\n", + ")\n", + "\n", + "# Make the request\n", + "stream = data_chat_client.chat(request=request, timeout=300) #custom timeout up to 600s\n", + "\n", + "# Handle the response\n", + "for response in stream:\n", + " show_message(response)\n" + ], + "metadata": { + "collapsed": true, + "id": "3Oju_c-I3sHR" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 20. multi turn conversation\n", + "#\n", + "\n", + "# %% code\n", + "# List that is used to track previous turns and is reused across requests\n", + "conversation_messages = []\n", + "\n", + "\n", + "# Use data agent context\n", + "data_agent_context = geminidataanalytics.DataAgentContext()\n", + "data_agent_context.data_agent = f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\"\n", + "# data_agent_context.credentials = credentials\n", + "\n", + "# Helper function for calling the API\n", + "def multi_turn_Conversation(msg):\n", + "\n", + " message = geminidataanalytics.Message()\n", + " message.user_message.text = msg\n", + "\n", + " # Send a multi-turn request by including previous turns and the new message\n", + " conversation_messages.append(message)\n", + "\n", + " request = geminidataanalytics.ChatRequest(\n", + " parent=f\"projects/{billing_project}/locations/{location}\",\n", + " messages=conversation_messages,\n", + " # Use data agent context\n", + " data_agent_context=data_agent_context,\n", + " # Use inline context\n", + " # inline_context=inline_context,\n", + " )\n", + "\n", + " # Make the request\n", + " stream = data_chat_client.chat(request=request, timeout=300) #custom timeout up to 600s\n", + "\n", + " # Handle the response\n", + " for response in stream:\n", + " show_message(response)\n", + " conversation_messages.append(response)\n", + "\n", + "# Send the first turn request\n", + "multi_turn_Conversation(\"How many cards has John Avon illustrated in total?\")\n", + "\n", + "# Send follow-up turn request\n", + "multi_turn_Conversation(\"Which artist has the most cards?\")\n" + ], + "metadata": { + "collapsed": true, + "id": "n4BHxhWE5GVy" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 21. Delete agent and conversation\n", + "#\n", + "\n", + "# %% code\n", + "request = geminidataanalytics.DeleteDataAgentRequest(\n", + " name=f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\",\n", + ")\n", + "\n", + "try:\n", + " # Make the request\n", + " data_agent_client.delete_data_agent_sync(request=request)\n", + " print(\"Data Agent Deleted\")\n", + "except Exception as e:\n", + " print(f\"Error deleting Data Agent: {e}\")\n", + "\n", + "\n", + "\n", + "request = geminidataanalytics.DeleteConversationRequest(\n", + " name = f\"projects/{billing_project}/locations/{location}/conversations/{conversation_id}\"\n", + ")\n", + "\n", + "# Make the request\n", + "response = data_chat_client.delete_conversation(request=request)\n", + "\n", + "# Handle the response\n", + "print(response)\n" + ], + "metadata": { + "id": "LONnCDRQ1LZy" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file