From a22cf8c3dd2624883f04cc8d21b31a01be220fbb Mon Sep 17 00:00:00 2001 From: Tian Zhang Date: Thu, 26 Mar 2026 13:10:11 -0700 Subject: [PATCH] initial commit for ca spanner colab --- .DS_Store | Bin 0 -> 8196 bytes conversational_analytics/README.md | 24 + .../agent_spanner_http_example.ipynb | 724 +++++++++++++++ .../agent_spanner_sdk_example.ipynb | 832 ++++++++++++++++++ 4 files changed, 1580 insertions(+) create mode 100644 .DS_Store create mode 100644 conversational_analytics/README.md create mode 100644 conversational_analytics/agent_spanner_http_example.ipynb create mode 100644 conversational_analytics/agent_spanner_sdk_example.ipynb diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5685892f7337906b816fe87c194739d9558098d4 GIT binary patch literal 8196 zcmeHM%Wl&^6g|^MYIq3*n{1H$1F6(MaEk+~#6!YJEFrN&QsLN9N{ZNd-+y4ujy-$C zkMt+Hf^+W-l|5+|0feZSk!FrFbI%>mnHx{81Hc?C^F5#kpvx+F=LVZgii}HLDO;}R zF42$=yhVu$?@?g9VjVhy0zrYGKu{nk5ES?q6yTXHrMu?bx1*sA3IqlIO9kZl5U~oz z4r`10=wPEM0I|bvExcwPpq$8I?69_|RQj~pgK<>jQVe719FN4D7(1*jS~`rS!?>A^ z%TSEYPF|$uFtNqZ1_gows|rZ&eh)+B%-#E~{5{8nS$v=Eg5z^c4~r@r71f9g`I7Ck z?ta(8D_{pB+LiQi#CGCYfA*}q$S}onoxQp5*d|} zcM(g>IM0yl=yOH9H@(BL$3FMiJ!O|TA!mVO^52nHsiy+3bge^r|DgRDjy->m_z>D} z=(DlEVzQTJ)ojlD`dK#iF{_fq?OFHhr}MM3FW+R}e!)-vNBqpEzrUuFbWs*nIpI3w zFz7ILEAEx2|CWrmcm>V$GQViOA3yJ2Y^)}<6!lG@Z zG2FDe#+V&&rUXxjyNt@*V@L!y8Mhc)WBQgj?rX&5M6MNck15gHWZuLJ{oaAA~5hRqLz!g$p+umzW^8dm5_y1RDMd4P1 z0zrXI6)>G7O%BM?Q~eUQobnOtYgQ>@S6fsHo7N8je!lp_kk=7Sg%dlhEwYCtya-Sk LLK77Ds|x%8 80:\n", + " wrapped_text = textwrap.fill(full_text, width=80)\n", + " print(wrapped_text)\n", + " else:\n", + " print(full_text)\n", + "\n", + "def get_property(data, field_name, default=''):\n", + " \"\"\"Safely gets a property from a dictionary.\"\"\"\n", + " return data.get(field_name, default)\n", + "\n", + "def display_schema(data):\n", + " \"\"\"Displays schema information in a DataFrame.\"\"\"\n", + " fields = data.get('fields', [])\n", + " df = pd.DataFrame({\n", + " \"Column\": [get_property(field, 'name') for field in fields],\n", + " \"Type\": [get_property(field, 'type') for field in fields],\n", + " \"Description\": [get_property(field, 'description', '-') for field in fields],\n", + " \"Mode\": [get_property(field, 'mode') for field in fields]\n", + " })\n", + " display(df)\n", + "\n", + "def display_section_title(text):\n", + " \"\"\"Displays a formatted section title in HTML.\"\"\"\n", + " display(HTML(f'

{text}

'))\n", + "\n", + "def format_bq_table_ref(table_ref):\n", + " \"\"\"Formats a BigQuery table reference for display.\"\"\"\n", + " return f\"{table_ref.get('projectId')}.{table_ref.get('datasetId')}.{table_ref.get('tableId')}\"\n", + "\n", + "def format_looker_table_ref(table_ref):\n", + " \"\"\"Formats a Looker table reference for display.\"\"\"\n", + " return f\"lookmlModel: {table_ref.get('lookmlModel')}, explore: {table_ref.get('explore')}, lookerInstanceUri: {table_ref.get('lookerInstanceUri')}\"\n", + "\n", + "def display_datasource(datasource):\n", + " \"\"\"Displays information about a datasource, including its schema.\"\"\"\n", + " source_name = ''\n", + " if 'studioDatasourceId' in datasource:\n", + " source_name = datasource['studioDatasourceId']\n", + " elif 'lookerExploreReference' in datasource:\n", + " source_name = format_looker_table_ref(datasource['lookerExploreReference'])\n", + " elif 'bigqueryTableReference' in datasource:\n", + " source_name = format_bq_table_ref(datasource['bigqueryTableReference'])\n", + " else:\n", + " source_name = \"Unknown Datasource\"\n", + "\n", + " print(source_name)\n", + " if 'schema' in datasource:\n", + " display_schema(datasource['schema'])\n", + "\n", + "def handle_schema_response(resp):\n", + " \"\"\"Handles responses related to schema resolution.\"\"\"\n", + " if 'query' in resp:\n", + " print(get_property(resp['query'], 'question'))\n", + " elif 'result' in resp:\n", + " display_section_title('Schema resolved')\n", + " print('Data sources:')\n", + " for datasource in get_property(resp['result'], 'datasources', []):\n", + " display_datasource(datasource)\n", + "\n", + "def handle_data_response(resp):\n", + " \"\"\"Handles responses containing data or SQL queries.\"\"\"\n", + " if 'query' in resp:\n", + " query = resp['query']\n", + " display_section_title('Retrieval query')\n", + " print(f\"Query name: {get_property(query, 'name')}\")\n", + " if 'question' in query:\n", + " print(f\"Question: {get_property(query, 'question')}\")\n", + " if 'datasources' in query:\n", + " print('Data sources:')\n", + " for datasource in get_property(query, 'datasources', []):\n", + " display_datasource(datasource)\n", + " elif 'generatedSql' in resp:\n", + " display_section_title('SQL generated')\n", + " print(resp['generatedSql'])\n", + " elif 'result' in resp:\n", + " display_section_title('Data retrieved')\n", + " result = resp['result']\n", + " schema = result.get('schema', {})\n", + " fields = [get_property(field, 'name') for field in schema.get('fields', [])]\n", + " data = result.get('data', [])\n", + "\n", + " data_dict = {field: [get_property(el, field) for el in data] for field in fields}\n", + " display(pd.DataFrame(data_dict))\n", + "\n", + "def handle_chart_response(resp):\n", + " \"\"\"Handles responses for generating charts.\"\"\"\n", + " if 'query' in resp:\n", + " print(get_property(resp['query'], 'instructions'))\n", + " elif 'result' in resp:\n", + " vegaConfig = get_property(resp['result'], 'vegaConfig')\n", + " if vegaConfig:\n", + " alt.Chart.from_json(json_lib.dumps(vegaConfig)).display()\n", + "\n", + "def handle_error(resp):\n", + " \"\"\"Handles error responses.\"\"\"\n", + " display_section_title('Error')\n", + " print(f\"Code: {get_property(resp, 'code')}\")\n", + " print(f\"Message: {get_property(resp, 'message')}\")\n", + "\n" + ], + "metadata": { + "id": "ozC3VvPfE6qp" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "\n", + "def get_stream(url, payload):\n", + " \"\"\"\n", + " Posts data to a URL and processes the streaming JSON response line by line.\n", + "\n", + " Args:\n", + " url (str): The URL to send the POST request to.\n", + " payload (dict): The JSON payload to send in the request body.\n", + " \"\"\"\n", + " s = requests.Session()\n", + " acc = '' # Accumulator for JSON parts\n", + "\n", + " try:\n", + " with s.post(url, json=payload, headers=headers, stream=True) as resp:\n", + " resp.raise_for_status() # Raise an exception for bad status codes\n", + " for line in resp.iter_lines():\n", + " if not line:\n", + " continue\n", + "\n", + " try:\n", + " decoded_line = line.decode('utf-8')\n", + " except UnicodeDecodeError:\n", + " print(f\"Warning: Could not decode line: {line}\")\n", + " continue\n", + "\n", + " # This custom JSON assembly logic seems fragile.\n", + " # It's attempting to piece together a JSON object\n", + " # from lines that might not be complete JSON objects themselves.\n", + " if decoded_line == '[{':\n", + " acc = '{'\n", + " elif decoded_line == '}]':\n", + " acc += '}'\n", + " elif decoded_line == ',':\n", + " continue\n", + " else:\n", + " acc += decoded_line\n", + "\n", + " if not is_json(acc):\n", + " continue\n", + "\n", + " try:\n", + " data_json = json_lib.loads(acc)\n", + " except json_lib.JSONDecodeError:\n", + " print(f\"Warning: Could not decode accumulated JSON: {acc}\")\n", + " acc = '' # Reset accumulator on error\n", + " continue\n", + "\n", + " if 'error' in data_json:\n", + " handle_error(data_json['error'])\n", + " acc = ''\n", + " continue\n", + "\n", + " if 'systemMessage' in data_json:\n", + " system_message = data_json['systemMessage']\n", + " if 'text' in system_message:\n", + " handle_text_response(system_message['text'])\n", + " elif 'schema' in system_message:\n", + " handle_schema_response(system_message['schema'])\n", + " elif 'data' in system_message:\n", + " handle_data_response(system_message['data'])\n", + " elif 'chart' in system_message:\n", + " handle_chart_response(system_message['chart'])\n", + " else:\n", + " # Fallback for unhandled systemMessage types\n", + " colored_json = highlight(acc, lexers.JsonLexer(), formatters.TerminalFormatter())\n", + " print(colored_json)\n", + " else:\n", + " # Fallback for responses without systemMessage or error\n", + " colored_json = highlight(acc, lexers.JsonLexer(), formatters.TerminalFormatter())\n", + " print(colored_json)\n", + "\n", + " acc = '' # Reset accumulator after processing a complete JSON object\n", + "\n", + " except requests.exceptions.RequestException as e:\n", + " print(f\"Error during request: {e}\")\n", + " except Exception as e:\n", + " print(f\"An unexpected error occurred: {e}\")\n", + "\n" + ], + "metadata": { + "id": "zx1t7jToHKJH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def get_stream_multi_turn(url, json, conversation_messages):\n", + " s = requests.Session()\n", + "\n", + " acc = ''\n", + "\n", + " with s.post(url, json=json, headers=headers, stream=True) as resp:\n", + " for line in resp.iter_lines():\n", + " if not line:\n", + " continue\n", + "\n", + " decoded_line = str(line, encoding='utf-8')\n", + "\n", + " if decoded_line == '[{':\n", + " acc = '{'\n", + " elif decoded_line == '}]':\n", + " acc += '}'\n", + " elif decoded_line == ',':\n", + " continue\n", + " else:\n", + " acc += decoded_line\n", + "\n", + " if not is_json(acc):\n", + " continue\n", + "\n", + " data_json = json_lib.loads(acc)\n", + " # Store the response that will be used in the next iteration\n", + " conversation_messages.append(data_json)\n", + "\n", + " if not 'systemMessage' in data_json:\n", + " if 'error' in data_json:\n", + " handle_error(data_json['error'])\n", + " continue\n", + "\n", + " if 'text' in data_json['systemMessage']:\n", + " handle_text_response(data_json['systemMessage']['text'])\n", + " elif 'schema' in data_json['systemMessage']:\n", + " handle_schema_response(data_json['systemMessage']['schema'])\n", + " elif 'data' in data_json['systemMessage']:\n", + " handle_data_response(data_json['systemMessage']['data'])\n", + " elif 'chart' in data_json['systemMessage']:\n", + " handle_chart_response(data_json['systemMessage']['chart'])\n", + " else:\n", + " colored_json = highlight(acc, lexers.JsonLexer(), formatters.TerminalFormatter())\n", + " print(colored_json)\n", + " print('\\n')\n", + " acc = ''" + ], + "metadata": { + "id": "fQ7_GcurkS2M" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# # Spanner Data Agent End-to-End Example\n", + "#\n", + "# This notebook demonstrates how to create and interact with a Data Agent for Spanner using the Gemini Data Analytics API via HTTP requests in Python.\n", + "\n", + "# %% [markdown]\n", + "# ## 1. Setup and Authentication\n", + "#\n", + "# Import necessary libraries and authenticate your user. This will be used to authorize the API requests.\n", + "\n", + "# %% code\n", + "import json\n", + "import requests\n", + "from google.colab import auth\n", + "import os\n", + "\n", + "# Authenticate user and get access token\n", + "auth.authenticate_user()\n", + "access_token = !gcloud auth application-default print-access-token\n", + "\n", + "if not access_token or not access_token[0]:\n", + " raise ValueError(\"Failed to get access token. Please ensure you are authenticated.\")\n", + "\n", + "headers = {\n", + " \"Authorization\": f\"Bearer {access_token[0]}\",\n", + " \"Content-Type\": \"application/json\",\n", + " \"x-server-timeout\": \"300\", # Custom timeout up to 600s\n", + "}\n", + "print(\"Headers configured.\")" + ], + "metadata": { + "id": "8PuSOsrXkipJ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 2. Configure Data Source Details\n", + "#\n", + "# Update the variables below with your specific Google Cloud project, Spanner instance details, and desired agent settings.\n", + "\n", + "# %% code\n", + "# Replace with your actual project and Spanner details\n", + "billing_project = \"project-id\" # Your billing project ID\n", + "location = \"global\" # The region of your Spanner instance\n", + "\n", + "# Spanner connection details\n", + "spanner_project_id = \"project-id\" # Project ID of the Spanner instance\n", + "spanner_instance_id = \"instance-id\" # Your spanner instance ID\n", + "spanner_database_id = \"database-id\" # Your database name\n", + "engine = \"GOOGLE_SQL\"\n", + "system_instruction = \"Help the user analyze data from the spanner database\"\n", + "\n", + "# Data Agent ID\n", + "data_agent_id = \"spanner_agent_e2e_example\"" + ], + "metadata": { + "id": "36Ffl_9SlOYF" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 3. Define Spanner Data Source\n", + "#\n", + "# This dictionary structure tells the Data Agent how to connect to your Spanner database.\n", + "# Its better to provide context_set_id similar to https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/data-agent-authored-context-databases#providing_context_with_querydata\n", + "\n", + "# %% code\n", + "# Spanner data source definition\n", + "spanner_data_sources = {\n", + " \"spanner_reference\": {\n", + " \"database_reference\": {\n", + " \"project_id\": spanner_project_id,\n", + " \"region\": location,\n", + " \"engine\": engine,\n", + " \"instance_id\": spanner_instance_id,\n", + " \"database_id\": spanner_database_id,\n", + " },\n", + " # Optional: Include this if you have pre-authored context for the agent\n", + " # \"agent_context_reference\": {\n", + " # \"context_set_id\": f\"projects/{billing_project}/locations/{location}/contextSets/your_context_set_id\"\n", + " # }\n", + " }\n", + "}\n", + "print(\"Spanner data source configured:\")\n", + "print(json.dumps(spanner_data_sources, indent=2))\n" + ], + "metadata": { + "id": "Mz0_sMECVSZT" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 4. Create the Data Agent\n", + "#\n", + "# Send a POST request to the Gemini Data Analytics API to create the Data Agent resource.\n", + "# Make sure roles/geminidataanalytics.dataAgentCreator is granted\n", + "\n", + "# %% code\n", + "data_agent_url = f\"https://geminidataanalytics.googleapis.com/v1beta/projects/{billing_project}/locations/global/dataAgents\"\n", + "\n", + "data_agent_payload = {\n", + " \"name\": f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\",\n", + " \"description\": \"This is an example Spanner data agent created from Colab.\", # Optional\n", + " \"data_analytics_agent\": {\n", + " \"published_context\": {\n", + " \"datasource_references\": spanner_data_sources,\n", + " \"system_instruction\": system_instruction,\n", + " }\n", + " }\n", + "}\n", + "\n", + "params = {\"data_agent_id\": data_agent_id} # Optional\n", + "\n", + "print(f\"Creating Data Agent: {data_agent_id}...\")\n", + "data_agent_response = requests.post(\n", + " data_agent_url, params=params, json=data_agent_payload, headers=headers\n", + ")\n", + "\n", + "if data_agent_response.status_code == 200:\n", + " print(\"Data Agent created successfully!\")\n", + " print(json.dumps(data_agent_response.json(), indent=2))\n", + "elif data_agent_response.status_code == 409:\n", + " print(f\"Data Agent '{data_agent_id}' already exists.\")\n", + " # Optionally, you could add code here to fetch the existing agent\n", + "else:\n", + " print(f\"Error creating Data Agent: {data_agent_response.status_code}\")\n", + " print(data_agent_response.text)" + ], + "metadata": { + "id": "ir8wraonVhnN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 5. Create the conversation\n", + "#\n", + "# Send a POST request to the Gemini Data Analytics API to create conversation.\n", + "\n", + "# %% code\n", + "conversation_url = f\"https://geminidataanalytics.googleapis.com/v1beta/projects/{billing_project}/locations/global/conversations\"\n", + "\n", + "conversation_id = \"conversation_agent_spanner_example\"\n", + "\n", + "conversation_payload = {\n", + " \"agents\": [\n", + " f\"projects/{billing_project}/locations/global/dataAgents/{data_agent_id}\"\n", + " ],\n", + " \"name\": f\"projects/{billing_project}/locations/global/conversations/{conversation_id}\"\n", + "}\n", + "\n", + "params = {\n", + " \"conversation_id\": conversation_id\n", + "}\n", + "\n", + "conversation_response = requests.post(conversation_url, headers=headers, params=params, json=conversation_payload)\n", + "\n", + "if conversation_response.status_code == 200:\n", + " print(\"Conversation created successfully!\")\n", + " print(json.dumps(conversation_response.json(), indent=2))\n", + "else:\n", + " print(f\"Error creating Conversation: {conversation_response.status_code}\")\n", + " print(conversation_response.text)" + ], + "metadata": { + "id": "Fcim8o3Q4RXM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 6. Chat with the API by using conversation (stateful)\n", + "#\n", + "# Send a POST request to the Gemini Data Analytics API to chat using a conversation\n", + "\n", + "# %% code\n", + "\n", + "chat_url = f\"https://geminidataanalytics.googleapis.com/v1beta/projects/{billing_project}/locations/global:chat\"\n", + "\n", + "# The natural language question to ask the data agent\n", + "user_prompt = \"Which artist has the most cards in the database?\" # Replace with your question\n", + "\n", + "\n", + "# Construct the payload\n", + "chat_payload = {\n", + " \"parent\": f\"projects/{billing_project}/locations/global\",\n", + " \"messages\": [\n", + " {\n", + " \"userMessage\": {\n", + " \"text\": user_prompt\n", + " }\n", + " }\n", + " ],\n", + " \"conversation_reference\": {\n", + " \"conversation\": f\"projects/{billing_project}/locations/global/conversations/{conversation_id}\",\n", + " \"data_agent_context\": {\n", + " \"data_agent\": f\"projects/{billing_project}/locations/global/dataAgents/{data_agent_id}\",\n", + " }\n", + " }\n", + "}\n", + "\n", + "print(f\"Sending prompt to :chat: '{user_prompt}'\")\n", + "print(f\"Endpoint: {chat_url}\")\n", + "print(f\"Payload: {json.dumps(chat_payload, indent=2)}\")\n", + "\n", + "\n", + "get_stream(chat_url, chat_payload)" + ], + "metadata": { + "id": "hnzypmKM6o54" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 7. Chat with the API by using data agent (stateless)\n", + "#\n", + "# Send a POST request to the Gemini Data Analytics API to chat without conversation\n", + "# %% code\n", + "\n", + "chat_url = f\"https://geminidataanalytics.googleapis.com/v1beta/projects/{billing_project}/locations/{location}:chat\"\n", + "\n", + "\n", + "# Construct the payload\n", + "chat_payload = {\n", + " \"parent\": f\"projects/{billing_project}/locations/global\",\n", + " \"messages\": [\n", + " {\n", + " \"userMessage\": {\n", + " \"text\": user_prompt\n", + " }\n", + " }\n", + " ],\n", + " \"data_agent_context\": {\n", + " \"data_agent\": f\"projects/{billing_project}/locations/global/dataAgents/{data_agent_id}\",\n", + " }\n", + "}\n", + "\n", + "print(f\"Sending prompt to :chat: '{user_prompt}'\")\n", + "print(f\"Endpoint: {chat_url}\")\n", + "print(f\"Payload: {json.dumps(chat_payload, indent=2)}\")\n", + "\n", + "get_stream(chat_url, chat_payload)\n" + ], + "metadata": { + "id": "gMxaruBA-Z2M" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 8. Chat with the API by using inline context (stateless)\n", + "#\n", + "# Send a POST request to the Gemini Data Analytics API to chat with inline context\n", + "# %% code\n", + "\n", + "chat_url = f\"https://geminidataanalytics.googleapis.com/v1beta/projects/{billing_project}/locations/us-east4:chat\"\n", + "\n", + "\n", + "# Construct the payload\n", + "chat_payload = {\n", + " \"parent\": f\"projects/{billing_project}/locations/{location}\",\n", + " \"messages\": [\n", + " {\n", + " \"userMessage\": {\n", + " \"text\": user_prompt\n", + " }\n", + " }\n", + " ],\n", + " \"inline_context\": {\n", + " \"datasource_references\": spanner_data_sources\n", + " # The \"options\" field for python analysis is not shown in the Spanner example\n", + " }\n", + "}\n", + "\n", + "print(f\"Sending prompt to :chat: '{user_prompt}'\")\n", + "print(f\"Endpoint: {chat_url}\")\n", + "print(f\"Payload: {json.dumps(chat_payload, indent=2)}\")\n", + "\n", + "get_stream(chat_url, chat_payload)\n" + ], + "metadata": { + "id": "CPG0ge21BeHY" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 9. Chat with the API multi turn conversation\n", + "#\n", + "# Send a POST request to the Gemini Data Analytics API to chat with multi turn conversation\n", + "# %% code\n", + "\n", + "\n", + "chat_url = f\"https://geminidataanalytics.googleapis.com/v1beta/projects/{billing_project}/locations/{location}:chat\"\n", + "\n", + "# List that is used to track previous turns and is reused across requests\n", + "conversation_messages = []\n", + "\n", + "# Helper function for calling the API\n", + "def multi_turn_conversation(msg):\n", + " userMessage = {\n", + " \"userMessage\": {\n", + " \"text\": msg\n", + " }\n", + " }\n", + "\n", + " # Send a multi-turn request by including previous turns and the new message\n", + " conversation_messages.append(userMessage)\n", + " print(f\"Current conversation history: {json.dumps(conversation_messages, indent=2)}\")\n", + "\n", + " # Construct the payload for Spanner\n", + " chat_payload = {\n", + " \"parent\": f\"projects/{billing_project}/locations/{location}\",\n", + " \"messages\": conversation_messages,\n", + " \"data_agent_context\": {\n", + " \"data_agent\": f\"projects/{billing_project}/locations/global/dataAgents/{data_agent_id}\",\n", + " # \"credentials\": looker_credentials\n", + " },\n", + "\n", + " }\n", + "\n", + " # Call the get_stream_multi_turn helper function to stream the response\n", + " get_stream_multi_turn(chat_url, chat_payload, conversation_messages)\n", + "\n", + "# Send first-turn request\n", + "print(\"--- Turn 1 ---\")\n", + "multi_turn_conversation(\"Which artist has the most cards in the database?\")\n", + "\n", + "# Send follow-up-turn request\n", + "print(\"\\n--- Turn 2 ---\")\n", + "multi_turn_conversation(\"How many unique artists are represented in the database?\")" + ], + "metadata": { + "id": "NUzvL7VECTZL" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "m7QJvKgvjRkv" + }, + "outputs": [], + "source": [ + "\n", + "# %% [markdown]\n", + "# ## 6. Clean up (Optional)\n", + "#\n", + "# Delete the Data Agent resource if it's no longer needed.\n", + "\n", + "# %% code\n", + "def delete_data_agent(project, agent_id):\n", + " delete_url = f\"https://geminidataanalytics.googleapis.com/v1beta/projects/{project}/locations/global/dataAgents/{agent_id}\"\n", + " print(f\"Deleting Data Agent: {agent_id}...\")\n", + " delete_response = requests.delete(delete_url, headers=headers)\n", + " if delete_response.status_code == 200:\n", + " print(f\"Data Agent {agent_id} deleted successfully.\")\n", + " elif delete_response.status_code == 404:\n", + " print(f\"Data Agent {agent_id} not found.\")\n", + " else:\n", + " print(f\"Error deleting Data Agent {agent_id}: {delete_response.status_code}\")\n", + " print(delete_response.text)\n", + "\n", + "# Uncomment to delete the agent:\n", + "delete_data_agent(billing_project, data_agent_id)\n", + "\n" + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "oym_sAQCjxCN" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/conversational_analytics/agent_spanner_sdk_example.ipynb b/conversational_analytics/agent_spanner_sdk_example.ipynb new file mode 100644 index 0000000..b26cfea --- /dev/null +++ b/conversational_analytics/agent_spanner_sdk_example.ipynb @@ -0,0 +1,832 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyM1RPSBgSG0l+5cRdf2KRek" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TZ5__KYykwgr" + }, + "outputs": [], + "source": [ + "# %% [markdown]\n", + "# ## 1. Define helper function\n", + "#\n", + "# %% code\n", + "\n", + "import json as json_lib\n", + "import textwrap\n", + "import time\n", + "\n", + "import altair as alt\n", + "import IPython\n", + "import pandas as pd\n", + "import proto\n", + "import requests\n", + "from google.iam.v1 import iam_policy_pb2\n", + "from google.iam.v1 import policy_pb2\n", + "from google.protobuf import field_mask_pb2\n", + "from google.protobuf.json_format import MessageToDict, MessageToJson\n", + "from IPython.display import HTML, display\n", + "from pygments import formatters, highlight, lexers\n", + "\n", + "\n", + "def handle_text_response(resp):\n", + " parts = resp.parts\n", + " full_text = \"\".join(parts)\n", + " if \"\\n\" not in full_text and len(full_text) > 80:\n", + " wrapped_text = textwrap.fill(full_text, width=80)\n", + " print(wrapped_text)\n", + " else:\n", + " print(full_text)\n", + "\n", + "\n", + "def display_schema(data):\n", + " fields = getattr(data, \"fields\")\n", + " df = pd.DataFrame({\n", + " \"Column\": map(lambda field: getattr(field, \"name\"), fields),\n", + " \"Type\": map(lambda field: getattr(field, \"type\"), fields),\n", + " \"Description\": map(lambda field: getattr(field, \"description\", \"-\"), fields),\n", + " \"Mode\": map(lambda field: getattr(field, \"mode\"), fields),\n", + " })\n", + " display(df)\n", + "\n", + "\n", + "def display_section_title(text):\n", + " display(HTML(\"

{}

\".format(text)))\n", + "\n", + "\n", + "def format_looker_table_ref(table_ref):\n", + " return \"lookmlModel: {}, explore: {}, lookerInstanceUri: {}\".format(\n", + " table_ref.lookml_model, table_ref.explore, table_ref.looker_instance_uri\n", + " )\n", + "\n", + "\n", + "def format_bq_table_ref(table_ref):\n", + " return \"{}.{}.{}\".format(\n", + " table_ref.project_id, table_ref.dataset_id, table_ref.table_id\n", + " )\n", + "\n", + "\n", + "def display_datasource(datasource):\n", + " source_name = \"\"\n", + " if \"studio_datasource_id\" in datasource:\n", + " source_name = getattr(datasource, \"studio_datasource_id\")\n", + " elif \"looker_explore_reference\" in datasource:\n", + " source_name = format_looker_table_ref(\n", + " getattr(datasource, \"looker_explore_reference\")\n", + " )\n", + " else:\n", + " source_name = format_bq_table_ref(\n", + " getattr(datasource, \"bigquery_table_reference\")\n", + " )\n", + "\n", + " print(source_name)\n", + " display_schema(datasource.schema)\n", + "\n", + "\n", + "def handle_schema_response(resp):\n", + " if \"query\" in resp:\n", + " print(resp.query.question)\n", + " elif \"result\" in resp:\n", + " display_section_title(\"Schema resolved\")\n", + " print(\"Data sources:\")\n", + " for datasource in resp.result.datasources:\n", + " display_datasource(datasource)\n", + "\n", + "\n", + "def handle_data_response(resp):\n", + " if \"query\" in resp:\n", + " query = resp.query\n", + " display_section_title(\"Retrieval query\")\n", + " print(f\"Query name: {query.name}\")\n", + " if \"question\" in query:\n", + " print(f\"Question: {query.question}\")\n", + " if \"datasources\" in query:\n", + " print(\"Data sources:\")\n", + " for datasource in query.datasources:\n", + " display_datasource(datasource)\n", + " elif \"generated_sql\" in resp:\n", + " display_section_title(\"SQL generated\")\n", + " print(resp.generated_sql)\n", + " elif \"result\" in resp:\n", + " display_section_title(\"Data retrieved\")\n", + "\n", + " fields = [field.name for field in resp.result.schema.fields]\n", + " d = {}\n", + " for el in resp.result.data:\n", + " for field in fields:\n", + " if field in d:\n", + " d[field].append(el[field])\n", + " else:\n", + " d[field] = [el[field]]\n", + "\n", + " display(pd.DataFrame(d))\n", + "\n", + "\n", + "def handle_chart_response(resp):\n", + " def _value_to_dict(v):\n", + " if isinstance(v, proto.marshal.collections.maps.MapComposite):\n", + " return _map_to_dict(v)\n", + " elif isinstance(v, proto.marshal.collections.RepeatedComposite):\n", + " return [_value_to_dict(el) for el in v]\n", + " elif isinstance(v, (int, float, str, bool)):\n", + " return v\n", + " else:\n", + " return MessageToDict(v)\n", + "\n", + " def _map_to_dict(d):\n", + " out = {}\n", + " for k in d:\n", + " if isinstance(d[k], proto.marshal.collections.maps.MapComposite):\n", + " out[k] = _map_to_dict(d[k])\n", + " else:\n", + " out[k] = _value_to_dict(d[k])\n", + " return out\n", + "\n", + " if \"query\" in resp:\n", + " print(resp.query.instructions)\n", + " elif \"result\" in resp:\n", + " vegaConfig = resp.result.vega_config\n", + " vegaConfig_dict = _map_to_dict(vegaConfig)\n", + " alt.Chart.from_json(json_lib.dumps(vegaConfig_dict)).display()\n", + "\n", + "\n", + "def show_message(msg):\n", + " m = msg.system_message\n", + " if \"text\" in m:\n", + " handle_text_response(getattr(m, \"text\"))\n", + " elif \"schema\" in m:\n", + " handle_schema_response(getattr(m, \"schema\"))\n", + " elif \"data\" in m:\n", + " handle_data_response(getattr(m, \"data\"))\n", + " elif \"chart\" in m:\n", + " handle_chart_response(getattr(m, \"chart\"))\n", + " print(\"\\n\")\n" + ] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 2. Authenticate GDA API\n", + "#\n", + "# %% code\n", + "\n", + "%pip install google-cloud-geminidataanalytics\n", + "from google.colab import auth\n", + "auth.authenticate_user()\n", + "\n", + "from google.cloud import geminidataanalytics\n", + "\n", + "data_agent_client = geminidataanalytics.DataAgentServiceClient()\n", + "data_chat_client = geminidataanalytics.DataChatServiceClient()\n" + ], + "metadata": { + "id": "U5WW4RPTpGsX" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 3. Configure Data Source Details\n", + "#\n", + "# Update the variables below with your specific Google Cloud project, spanner instance details, and desired agent settings.\n", + "\n", + "# %% code\n", + "# Replace with your actual project and spanner details\n", + "billing_project = \"project\" # Your billing project ID\n", + "location = \"location\" # The region of your spanner instance\n", + "\n", + "# spanner connection details\n", + "spanner_project_id = \"project-id\" # Project ID of the spanner instance\n", + "engine = \"GOOGLE_SQL\"\n", + "spanner_instance_id = \"instance-id\" # Your spanner instance ID\n", + "spanner_database_id = \"database-id\" # Your database name (e.g., postgres)\n", + "\n", + "system_instruction = \"Help the user analyze data from the spanner database\"\n", + "\n", + "# Data Agent ID\n", + "data_agent_id = \"spanner_agent_e2e_example\"" + ], + "metadata": { + "id": "XYsWJRaKqHPk" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 4. Connect to spanner source\n", + "#\n", + "# Update the variables below with your specific Google Cloud project, spanner instance details, and desired agent settings.\n", + "# # Its better to provide context_set_id similar to https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/data-agent-authored-context-databases#providing_context_with_querydata\n", + "\n", + "# %% code\n", + "# Replace with your actual project and spanner details\n", + "spanner_ref_1 = geminidataanalytics.SpannerReference()\n", + "spanner_ref_1.database_reference = geminidataanalytics.SpannerDatabaseReference()\n", + "spanner_ref_1.database_reference.project_id = billing_project\n", + "spanner_ref_1.database_reference.region = \"us-central1\" # Example region\n", + "spanner_ref_1.database_reference.engine = engine\n", + "spanner_ref_1.database_reference.instance_id = spanner_instance_id\n", + "spanner_ref_1.database_reference.database_id = spanner_database_id\n", + "# optional set agent context reference\n", + "# spanner_ref_1.agent_context_reference.context_set_id = f\"projects/{billing_project}/locations/{location}/contextSets/your_context_set_id\"\n", + "\n", + "datasource_references = geminidataanalytics.DatasourceReferences()\n", + "datasource_references.spanner_reference = spanner_ref_1\n", + "\n", + "print(\"spanner DatasourceReferences created\")" + ], + "metadata": { + "id": "kkuyz0qXrMDC" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 5. Stateful chat\n", + "#\n", + "\n", + "# %% code\n", + "# Set up context for stateful chat\n", + "published_context = geminidataanalytics.Context()\n", + "published_context.system_instruction = system_instruction\n", + "published_context.datasource_references = datasource_references\n", + "# Optional: To enable advanced analysis with Python, include the following line:\n", + "published_context.options.analysis.python.enabled = True\n" + ], + "metadata": { + "id": "jSC7Bi0ExV81" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 6. Create a syncronize data agent\n", + "#\n", + "\n", + "# %% code\n", + "\n", + "data_agent = geminidataanalytics.DataAgent()\n", + "data_agent.data_analytics_agent.published_context = published_context\n", + "data_agent.name = f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\" # Optional\n", + "\n", + "request = geminidataanalytics.CreateDataAgentRequest(\n", + " parent=f\"projects/{billing_project}/locations/{location}\",\n", + " data_agent_id=data_agent_id, # Optional\n", + " data_agent=data_agent,\n", + ")\n", + "\n", + "try:\n", + " response = data_agent_client.create_data_agent_sync(request=request)\n", + " print(\"Data Agent created\")\n", + " print(response)\n", + "except Exception as e:\n", + " print(f\"Error creating Data Agent: {e}\")\n" + ], + "metadata": { + "id": "N934aEpXxlmW" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 7. Create a asyncronize data agent\n", + "#\n", + "\n", + "# data_agent_id = \"data_agent_spanner\"\n", + "\n", + "# data_agent = geminidataanalytics.DataAgent()\n", + "# data_agent.data_analytics_agent.published_context = published_context\n", + "# data_agent.name = f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\" # Optional\n", + "\n", + "# request = geminidataanalytics.CreateDataAgentRequest(\n", + "# parent=f\"projects/{billing_project}/locations/{location}\",\n", + "# data_agent_id=data_agent_id, # Optional\n", + "# data_agent=data_agent,\n", + "# )\n", + "\n", + "# try:\n", + "# data_agent_client.create_data_agent(request=request)\n", + "# print(\"Data Agent created\")\n", + "# except Exception as e:\n", + "# print(f\"Error creating Data Agent: {e}\")\n", + "\n" + ], + "metadata": { + "id": "nHCranZXyBOY" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 8. Create a conversation\n", + "#\n", + "\n", + "# %% code\n", + "# Initialize request arguments\n", + "conversation_id = \"conversation_spanner_example\"\n", + "\n", + "conversation = geminidataanalytics.Conversation()\n", + "conversation.agents = [f'projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}']\n", + "conversation.name = f\"projects/{billing_project}/locations/{location}/conversations/{conversation_id}\"\n", + "\n", + "request = geminidataanalytics.CreateConversationRequest(\n", + " parent=f\"projects/{billing_project}/locations/{location}\",\n", + " conversation_id=conversation_id,\n", + " conversation=conversation,\n", + ")\n", + "\n", + "# Make the request\n", + "response = data_chat_client.create_conversation(request=request)\n", + "\n", + "# Handle the response\n", + "print(response)\n" + ], + "metadata": { + "id": "t6K4__0BypcM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 9. get a data agent\n", + "#\n", + "\n", + "# %% code\n", + "\n", + "# Initialize request arguments\n", + "request = geminidataanalytics.GetDataAgentRequest(\n", + " name=f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\",\n", + ")\n", + "\n", + "# Make the request\n", + "response = data_agent_client.get_data_agent(request=request)\n", + "\n", + "# Handle the response\n", + "print(response)\n" + ], + "metadata": { + "id": "Ye5_Tx23y-r-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 10. list data agent\n", + "#\n", + "\n", + "# %% code\n", + "creator_filter = \"YOUR-CREATOR-FILTER\" #optional\n", + "\n", + "request = geminidataanalytics.ListDataAgentsRequest(\n", + " parent=f\"projects/{billing_project}/locations/{location}\",\n", + " # creator_filter=creator_filter, optional\n", + ")\n", + "\n", + "# Make the request\n", + "page_result = data_agent_client.list_data_agents(request=request)\n", + "\n", + "# Handle the response\n", + "for response in page_result:\n", + " print(response)\n" + ], + "metadata": { + "collapsed": true, + "id": "aDBTrJHszLkn" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 11. update a agent\n", + "#\n", + "\n", + "# %% code\n", + "\n", + "data_agent = geminidataanalytics.DataAgent()\n", + "data_agent.data_analytics_agent.published_context = published_context\n", + "data_agent.name = f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\"\n", + "data_agent.description = \"Updated description of the data agent.\"\n", + "\n", + "update_mask = field_mask_pb2.FieldMask(paths=['description', 'data_analytics_agent.published_context'])\n", + "\n", + "request = geminidataanalytics.UpdateDataAgentRequest(\n", + " data_agent=data_agent,\n", + " update_mask=update_mask,\n", + ")\n", + "\n", + "try:\n", + " # Make the request\n", + " response = data_agent_client.update_data_agent_sync(request=request)\n", + " print(\"Data Agent Updated\")\n", + " print(response)\n", + "except Exception as e:\n", + " print(f\"Error updating Data Agent: {e}\")\n" + ], + "metadata": { + "collapsed": true, + "id": "yxLspk9RzU0A" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "\n", + "# %% [markdown]\n", + "# ## 12. Get IAM policy\n", + "#\n", + "\n", + "# %% code\n", + "\n", + "resource = f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\"\n", + "request = iam_policy_pb2.GetIamPolicyRequest(\n", + " resource=resource,\n", + " )\n", + "try:\n", + " response = data_agent_client.get_iam_policy(request=request)\n", + " print(\"IAM Policy fetched successfully!\")\n", + " print(f\"Response: {response}\")\n", + "except Exception as e:\n", + " print(f\"Error setting IAM policy: {e}\")\n" + ], + "metadata": { + "collapsed": true, + "id": "q0hJ-qKh0gOx" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 13. Update IAM policy\n", + "#\n", + "\n", + "# %% code\n", + "role = \"roles/geminidataanalytics.dataAgentEditor\"\n", + "users = \"tzha@google.com\" # replace with your email\n", + "\n", + "resource = f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\"\n", + "\n", + "# Construct the IAM policy\n", + "binding = policy_pb2.Binding(\n", + " role=role,\n", + " members= [f\"user:{i.strip()}\" for i in users.split(\",\")]\n", + ")\n", + "\n", + "policy = policy_pb2.Policy(bindings=[binding])\n", + "\n", + "# Create the request\n", + "request = iam_policy_pb2.SetIamPolicyRequest(\n", + " resource=resource,\n", + " policy=policy\n", + ")\n", + "\n", + "# Send the request\n", + "try:\n", + " response = data_agent_client.set_iam_policy(request=request)\n", + " print(\"IAM Policy set successfully!\")\n", + " print(f\"Response: {response}\")\n", + "except Exception as e:\n", + " print(f\"Error setting IAM policy: {e}\")\n" + ], + "metadata": { + "id": "9_RQbHrg00EY" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 14. Get a conversation\n", + "#\n", + "\n", + "# %% code\"\n", + "\n", + "request = geminidataanalytics.GetConversationRequest(\n", + " name = f\"projects/{billing_project}/locations/{location}/conversations/{conversation_id}\"\n", + ")\n", + "\n", + "# Make the request\n", + "response = data_chat_client.get_conversation(request=request)\n", + "\n", + "# Handle the response\n", + "print(response)\n" + ], + "metadata": { + "id": "GO4ObQfG1Qi7" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 15. list conversation\n", + "#\n", + "\n", + "# %% code\"\n", + "request = geminidataanalytics.ListConversationsRequest(\n", + " parent=f\"projects/{billing_project}/locations/{location}\",\n", + ")\n", + "\n", + "# Make the request\n", + "response = data_chat_client.list_conversations(request=request)\n", + "\n", + "# Handle the response\n", + "print(response)\n" + ], + "metadata": { + "collapsed": true, + "id": "TrxPjSGQ1cp0" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 16. list messages in conversation\n", + "#\n", + "\n", + "# %% code\"\n", + "request = geminidataanalytics.ListMessagesRequest(\n", + " parent=f\"projects/{billing_project}/locations/{location}/conversations/{conversation_id}\",\n", + ")\n", + "\n", + "# Make the request\n", + "response = data_chat_client.list_messages(request=request)\n", + "\n", + "# Handle the response\n", + "print(response)\n" + ], + "metadata": { + "id": "6LWgpjhO1lXV" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 17. Stateful chat with agent and conversation\n", + "#\n", + "\n", + "# %% code\"\n", + "# Create a request that contains a single user message (your question)\n", + "question = \"How many cards has John Avon illustrated in total?\"\n", + "messages = [geminidataanalytics.Message()]\n", + "messages[0].user_message.text = question\n", + "\n", + "\n", + "# Create a conversation_reference\n", + "conversation_reference = geminidataanalytics.ConversationReference()\n", + "conversation_reference.conversation = f\"projects/{billing_project}/locations/{location}/conversations/{conversation_id}\"\n", + "conversation_reference.data_agent_context.data_agent = f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\"\n", + "# conversation_reference.data_agent_context.credentials = credentials\n", + "\n", + "# Form the request\n", + "request = geminidataanalytics.ChatRequest(\n", + " parent = f\"projects/{billing_project}/locations/{location}\",\n", + " messages = messages,\n", + " conversation_reference = conversation_reference\n", + ")\n", + "\n", + "# Make the request\n", + "stream = data_chat_client.chat(request=request, timeout=300) #custom timeout up to 600s\n", + "\n", + "# Handle the response\n", + "for response in stream:\n", + " show_message(response)\n" + ], + "metadata": { + "collapsed": true, + "id": "xLo-hiiw16sH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 18. Stateless chat without conversation\n", + "#\n", + "\n", + "# %% code\n", + "# Create a request that contains a single user message (your question)\n", + "question = \"How many cards has John Avon illustrated in total?\"\n", + "\n", + "messages = [geminidataanalytics.Message()]\n", + "messages[0].user_message.text = question\n", + "\n", + "\n", + "data_agent_context = geminidataanalytics.DataAgentContext()\n", + "data_agent_context.data_agent = f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\"\n", + "# data_agent_context.credentials = credentials\n", + "\n", + "# Form the request\n", + "request = geminidataanalytics.ChatRequest(\n", + " parent=f\"projects/{billing_project}/locations/{location}\",\n", + " messages=messages,\n", + " data_agent_context = data_agent_context\n", + ")\n", + "\n", + "# Make the request\n", + "stream = data_chat_client.chat(request=request, timeout=300) #custom timeout up to 600s\n", + "\n", + "# Handle the response\n", + "for response in stream:\n", + " show_message(response)\n" + ], + "metadata": { + "collapsed": true, + "id": "awhj4bd52oiz" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 19. Ask questions with inline context\n", + "#\n", + "\n", + "# %% code\n", + "# Create a request that contains a single user message (your question)\n", + "messages = [geminidataanalytics.Message()]\n", + "messages[0].user_message.text = question\n", + "\n", + "inline_context = geminidataanalytics.Context()\n", + "inline_context.system_instruction = system_instruction\n", + "inline_context.datasource_references = datasource_references\n", + "# Optional: To enable advanced analysis with Python, include the following line:\n", + "inline_context.options.analysis.python.enabled = True\n", + "\n", + "request = geminidataanalytics.ChatRequest(\n", + " inline_context=inline_context,\n", + " parent=f\"projects/{billing_project}/locations/{location}\",\n", + " messages=messages,\n", + ")\n", + "\n", + "# Make the request\n", + "stream = data_chat_client.chat(request=request, timeout=300) #custom timeout up to 600s\n", + "\n", + "# Handle the response\n", + "for response in stream:\n", + " show_message(response)\n" + ], + "metadata": { + "collapsed": true, + "id": "3Oju_c-I3sHR" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 20. multi turn conversation\n", + "#\n", + "\n", + "# %% code\n", + "# List that is used to track previous turns and is reused across requests\n", + "conversation_messages = []\n", + "\n", + "\n", + "# Use data agent context\n", + "data_agent_context = geminidataanalytics.DataAgentContext()\n", + "data_agent_context.data_agent = f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\"\n", + "# data_agent_context.credentials = credentials\n", + "\n", + "# Helper function for calling the API\n", + "def multi_turn_Conversation(msg):\n", + "\n", + " message = geminidataanalytics.Message()\n", + " message.user_message.text = msg\n", + "\n", + " # Send a multi-turn request by including previous turns and the new message\n", + " conversation_messages.append(message)\n", + "\n", + " request = geminidataanalytics.ChatRequest(\n", + " parent=f\"projects/{billing_project}/locations/{location}\",\n", + " messages=conversation_messages,\n", + " # Use data agent context\n", + " data_agent_context=data_agent_context,\n", + " # Use inline context\n", + " # inline_context=inline_context,\n", + " )\n", + "\n", + " # Make the request\n", + " stream = data_chat_client.chat(request=request, timeout=300) #custom timeout up to 600s\n", + "\n", + " # Handle the response\n", + " for response in stream:\n", + " show_message(response)\n", + " conversation_messages.append(response)\n", + "\n", + "# Send the first turn request\n", + "multi_turn_Conversation(\"How many cards has John Avon illustrated in total?\")\n", + "\n", + "# Send follow-up turn request\n", + "multi_turn_Conversation(\"Which artist has the most cards?\")\n" + ], + "metadata": { + "collapsed": true, + "id": "n4BHxhWE5GVy" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# %% [markdown]\n", + "# ## 21. Delete agent and conversation\n", + "#\n", + "\n", + "# %% code\n", + "request = geminidataanalytics.DeleteDataAgentRequest(\n", + " name=f\"projects/{billing_project}/locations/{location}/dataAgents/{data_agent_id}\",\n", + ")\n", + "\n", + "try:\n", + " # Make the request\n", + " data_agent_client.delete_data_agent_sync(request=request)\n", + " print(\"Data Agent Deleted\")\n", + "except Exception as e:\n", + " print(f\"Error deleting Data Agent: {e}\")\n", + "\n", + "\n", + "\n", + "request = geminidataanalytics.DeleteConversationRequest(\n", + " name = f\"projects/{billing_project}/locations/{location}/conversations/{conversation_id}\"\n", + ")\n", + "\n", + "# Make the request\n", + "response = data_chat_client.delete_conversation(request=request)\n", + "\n", + "# Handle the response\n", + "print(response)\n" + ], + "metadata": { + "id": "LONnCDRQ1LZy" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file