From 3b88960e32dcaf14466aa737059e19c0f1efed32 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Thu, 16 Jan 2025 11:32:15 -0500 Subject: [PATCH] make data field optional but check for remember memory --- crates/goose-mcp/examples/mcp.rs | 6 +++--- crates/goose-mcp/src/memory/mod.rs | 24 +++++++++++++----------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/crates/goose-mcp/examples/mcp.rs b/crates/goose-mcp/examples/mcp.rs index 15e401a06..072479892 100644 --- a/crates/goose-mcp/examples/mcp.rs +++ b/crates/goose-mcp/examples/mcp.rs @@ -1,6 +1,6 @@ // An example script to run an MCP server use anyhow::Result; -use goose_mcp::DeveloperRouter; +use goose_mcp::{DeveloperRouter, MemoryRouter}; use mcp_server::router::RouterService; use mcp_server::{ByteTransport, Server}; use tokio::io::{stdin, stdout}; @@ -10,7 +10,7 @@ use tracing_subscriber::{self, EnvFilter}; #[tokio::main] async fn main() -> Result<()> { // Set up file appender for logging - let file_appender = RollingFileAppender::new(Rotation::DAILY, "logs", "mcp-server.log"); + let file_appender = RollingFileAppender::new(Rotation::DAILY, "logs", "goose-mcp-example.log"); // Initialize the tracing subscriber with file and stdout logging tracing_subscriber::fmt() @@ -25,7 +25,7 @@ async fn main() -> Result<()> { tracing::info!("Starting MCP server"); // Create an instance of our counter router - let router = RouterService(DeveloperRouter::new()); + let router = RouterService(MemoryRouter::new()); // Create and run the server let server = Server::new(router); diff --git a/crates/goose-mcp/src/memory/mod.rs b/crates/goose-mcp/src/memory/mod.rs index b6a20a17a..f157c9f6b 100644 --- a/crates/goose-mcp/src/memory/mod.rs +++ b/crates/goose-mcp/src/memory/mod.rs @@ -342,13 +342,17 @@ impl MemoryRouter { match tool_call.name.as_str() { "remember_memory" => { let args = MemoryArgs::from_value(&tool_call.arguments)?; - self.remember( - "context", - args.category, - args.data, - &args.tags, - args.is_global, - )?; + let data = args + .data + .as_deref() + .filter(|d| !d.is_empty()) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "Data must exist when remembering a memory", + ) + })?; + self.remember("context", args.category, data, &args.tags, args.is_global)?; Ok(format!("Stored memory in category: {}", args.category)) } "retrieve_memories" => { @@ -496,7 +500,7 @@ impl Router for MemoryRouter { #[derive(Debug)] struct MemoryArgs<'a> { category: &'a str, - data: &'a str, + data: Option<&'a str>, tags: Vec<&'a str>, is_global: bool, } @@ -515,9 +519,7 @@ impl<'a> MemoryArgs<'a> { )); } - let data = args["data"] - .as_str() - .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "Data must be a string"))?; + let data = args.get("data").and_then(|d| d.as_str()); let tags = match &args["tags"] { Value::Array(arr) => arr.iter().filter_map(|v| v.as_str()).collect(),