// Client wraps a vector provider to implement the MCP tool interface.
type Client struct {
	// provider is the vector backend that tool calls are delegated to
	// (handleRAGSearch calls its Search method).
	provider VectorProvider
	// maxDocs is the result limit handleRAGSearch starts from when the caller
	// supplies no usable "limit" argument. It is a default, not a hard cap:
	// an explicit limit may exceed it.
	maxDocs int
}
// NewClient returns a Client backed by the "simple" provider, using
// ragDatabase as the provider's "database_path". It exists for legacy callers
// that only know a database location.
//
// The provider is first requested from CreateProviderFromConfig. If that
// fails, a provider is built directly with NewSimpleProvider instead, so the
// function always returns a usable *Client and never reports an error.
// The returned client's maxDocs is 10.
func NewClient(ragDatabase string) *Client {
	client := &Client{maxDocs: 10}

	cfg := map[string]interface{}{
		"provider":      "simple",
		"database_path": ragDatabase,
	}
	if configured, err := CreateProviderFromConfig(cfg); err == nil {
		client.provider = configured
		return client
	}

	// Factory failed: construct the simple provider by hand to stay backward
	// compatible with callers that expect a non-nil client.
	fallback := NewSimpleProvider(ragDatabase)
	// Best effort: this constructor has no error return, so an initialization
	// failure is intentionally not reported here.
	_ = fallback.Initialize(context.Background())
	client.provider = fallback
	return client
}
// CallTool is the MCP tool entry point for RAG operations. It routes the call
// to the handler matching toolName and returns that handler's result as is.
//
// Supported tool names are "rag_search", "rag_ingest" and "rag_stats"; any
// other name produces an error that lists the available tools.
func (c *Client) CallTool(ctx context.Context, toolName string, args map[string]interface{}) (string, error) {
	if toolName == "rag_search" {
		return c.handleRAGSearch(ctx, args)
	}
	if toolName == "rag_ingest" {
		return c.handleRAGIngest(ctx, args)
	}
	if toolName == "rag_stats" {
		return c.handleRAGStats(ctx, args)
	}

	// No handler matched: report the unknown name along with the valid ones.
	return "", fmt.Errorf("unknown RAG tool: %s. Available tools: rag_search, rag_ingest, rag_stats", toolName)
}
// handleRAGSearch serves the "rag_search" tool: it queries the provider and
// renders the hits as a plain-text report.
//
// Recognized args:
//   - "query": the search text, read through extractStringParam (the `true`
//     argument presumably marks it as required; confirm against that helper).
//   - "limit": optional result count, accepted as an int, a float64 (what JSON
//     numbers become when decoded into interface{}) or a decimal string.
//     Values of any other type, and strings that do not parse, are ignored
//     and c.maxDocs is used instead.
//
// The effective limit falls back to 3 when it is zero or negative and is
// capped at 20.
//
// It returns the formatted report, a "No relevant context found" message when
// the provider returns no results, or an error when the query argument is
// rejected or the provider search fails (the latter wrapped with %w).
func (c *Client) handleRAGSearch(ctx context.Context, args map[string]interface{}) (string, error) {
	const (
		fallbackLimit = 3  // used when the requested limit is not positive
		maxLimit      = 20 // upper bound on results for a single call
	)

	query, err := c.extractStringParam(args, "query", true)
	if err != nil {
		return "", err
	}

	// Start from the client default; a missing key yields a nil interface,
	// which matches no case below and leaves the default in place.
	limit := c.maxDocs
	switch v := args["limit"].(type) {
	case int:
		limit = v
	case float64:
		// Clamp in the float domain before converting: int(v) is
		// implementation-dependent for NaN, ±Inf and values outside the int
		// range, which made huge limits resolve differently per platform.
		switch {
		case v >= maxLimit:
			limit = maxLimit
		case v >= 1:
			limit = int(v) // truncates toward zero, as before
		default:
			// Zero, negative, fractional below one, or NaN (every comparison
			// with NaN is false): treated as non-positive.
			limit = 0
		}
	case string:
		if parsed, parseErr := strconv.Atoi(v); parseErr == nil {
			limit = parsed
		}
	}

	// Clamp the limit to the supported range.
	if limit <= 0 {
		limit = fallbackLimit
	}
	if limit > maxLimit {
		limit = maxLimit
	}

	results, err := c.provider.Search(ctx, query, SearchOptions{
		Limit: limit,
	})
	if err != nil {
		return "", fmt.Errorf("search failed: %w", err)
	}

	if len(results) == 0 {
		return "No relevant context found for query: '" + query + "'", nil
	}

	// Render one section per hit. Writes to a strings.Builder cannot fail, so
	// the Fprintf return values are not checked.
	var response strings.Builder
	fmt.Fprintf(&response, "Found %d relevant context(s) for '%s':\n", len(results), query)
	for i, result := range results {
		fmt.Fprintf(&response, "--- Context %d ---\n", i+1)

		// The source line (and its score) is only shown when a file name is
		// known; the score is omitted when it is not positive.
		if result.FileName != "" {
			fmt.Fprintf(&response, "Source: %s", result.FileName)
			if result.Score > 0 {
				fmt.Fprintf(&response, " (score: %.2f)", result.Score)
			}
			response.WriteString("\n")
		}

		fmt.Fprintf(&response, "Content: %s\n", result.Content)

		if len(result.Highlights) > 0 {
			fmt.Fprintf(&response, "Highlights: %s\n", strings.Join(result.Highlights, " | "))
		}
	}
	return response.String(), nil
}