Skip to content

Commit

Permalink
make data field optional but check for remember memory
Browse files Browse the repository at this point in the history
  • Loading branch information
salman1993 committed Jan 16, 2025
1 parent c4f831e commit 3b88960
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
6 changes: 3 additions & 3 deletions crates/goose-mcp/examples/mcp.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// An example script to run an MCP server
use anyhow::Result;
use goose_mcp::DeveloperRouter;
use goose_mcp::{DeveloperRouter, MemoryRouter};
use mcp_server::router::RouterService;
use mcp_server::{ByteTransport, Server};
use tokio::io::{stdin, stdout};
Expand All @@ -10,7 +10,7 @@ use tracing_subscriber::{self, EnvFilter};
#[tokio::main]
async fn main() -> Result<()> {
// Set up file appender for logging
let file_appender = RollingFileAppender::new(Rotation::DAILY, "logs", "mcp-server.log");
let file_appender = RollingFileAppender::new(Rotation::DAILY, "logs", "goose-mcp-example.log");

// Initialize the tracing subscriber with file and stdout logging
tracing_subscriber::fmt()
Expand All @@ -25,7 +25,7 @@ async fn main() -> Result<()> {
tracing::info!("Starting MCP server");

// Create an instance of our counter router
let router = RouterService(DeveloperRouter::new());
let router = RouterService(MemoryRouter::new());

// Create and run the server
let server = Server::new(router);
Expand Down
24 changes: 13 additions & 11 deletions crates/goose-mcp/src/memory/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,13 +342,17 @@ impl MemoryRouter {
match tool_call.name.as_str() {
"remember_memory" => {
let args = MemoryArgs::from_value(&tool_call.arguments)?;
self.remember(
"context",
args.category,
args.data,
&args.tags,
args.is_global,
)?;
let data = args
.data
.as_deref()
.filter(|d| !d.is_empty())
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"Data must exist when remembering a memory",
)
})?;
self.remember("context", args.category, data, &args.tags, args.is_global)?;
Ok(format!("Stored memory in category: {}", args.category))
}
"retrieve_memories" => {
Expand Down Expand Up @@ -496,7 +500,7 @@ impl Router for MemoryRouter {
#[derive(Debug)]
struct MemoryArgs<'a> {
category: &'a str,
data: &'a str,
data: Option<&'a str>,
tags: Vec<&'a str>,
is_global: bool,
}
Expand All @@ -515,9 +519,7 @@ impl<'a> MemoryArgs<'a> {
));
}

let data = args["data"]
.as_str()
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "Data must be a string"))?;
let data = args.get("data").and_then(|d| d.as_str());

let tags = match &args["tags"] {
Value::Array(arr) => arr.iter().filter_map(|v| v.as_str()).collect(),
Expand Down

0 comments on commit 3b88960

Please sign in to comment.