107 lines
3.6 KiB
Python
107 lines
3.6 KiB
Python
"""
|
|
Tool Registry — ChromaDB-backed semantic search over tool catalog.
|
|
Keeps agent context lean by only surfacing relevant tools via RAG.
|
|
"""
|
|
|
|
import ast
import json
import logging
import os
from pathlib import Path
from typing import Optional

import chromadb
|
|
|
|
# Both locations are resolved relative to the directory containing this module.
# Persistent ChromaDB store opened by ToolRegistry by default.
DB_PATH = os.path.join(os.path.split(__file__)[0], ".chromadb")
# Directory that ToolRegistry.__init__ ensures exists.
TOOLS_DIR = os.path.join(os.path.split(__file__)[0], "created_tools")
|
|
|
|
class ToolRegistry:
    """Semantic index over a tool catalog, persisted in ChromaDB.

    Each tool is one record in the ``tool_registry`` collection. The record id
    is the tool name. The embedded document is ``"<name>: <description>"``,
    plus any tags. The stored metadata dict is what the lookup methods return.
    """

    # Reports best-effort failures instead of silently dropping them.
    _logger = logging.getLogger(__name__)

    def __init__(self, db_path: str = DB_PATH):
        """Open (or create) the persistent registry.

        Args:
            db_path: Directory of the ChromaDB persistent store.
        """
        self.client = chromadb.PersistentClient(path=db_path)
        # Cosine space: search() converts query distances to similarity as 1 - d.
        self.collection = self.client.get_or_create_collection(
            name="tool_registry",
            metadata={"hnsw:space": "cosine"},
        )
        os.makedirs(TOOLS_DIR, exist_ok=True)

    def register(self, name: str, description: str, source_path: str,
                 params: Optional[dict] = None,
                 tags: Optional[list[str]] = None) -> dict:
        """Register (or overwrite) a tool in the semantic index.

        Args:
            name: Tool name. It is used as the record id, so registering the
                same name again replaces the earlier entry.
            description: Description text that is embedded for search.
            source_path: Path to the tool's source file.
            params: Optional parameter description, stored JSON-encoded under
                ``params_json`` (``"{}"`` when omitted).
            tags: Optional tags. They are appended to the embedded text and
                stored comma-joined under ``tags``.

        Returns:
            The metadata dict that was stored for the tool.
        """
        doc = f"{name}: {description}"
        if tags:
            doc += f" tags: {', '.join(tags)}"

        metadata = {
            "name": name,
            "description": description,
            "source_path": source_path,
            "params_json": json.dumps(params or {}),
            "tags": ",".join(tags or []),
        }
        # upsert rather than add, so re-registering a name updates it in place.
        self.collection.upsert(
            ids=[name],
            documents=[doc],
            metadatas=[metadata],
        )
        return metadata

    def search(self, query: str, n: int = 5, threshold: float = 0.6) -> list[dict]:
        """Semantic search for tools.

        Args:
            query: Free-text description of the wanted capability.
            n: Maximum number of nearest candidates to retrieve before filtering.
            threshold: Minimum similarity (``1 - distance``) a match must reach.

        Returns:
            Copies of the matching tools' metadata, each with an added
            ``similarity`` key rounded to 3 decimals, in the order ChromaDB
            returned them. The list is empty when the registry is empty, when
            ``n`` is not positive, or when nothing reaches ``threshold``.
        """
        count = self.collection.count()
        # ChromaDB rejects non-positive n_results. Querying an empty index would
        # embed the query for nothing and may warn or raise, depending on the
        # chromadb version.
        if n <= 0 or count == 0:
            return []

        results = self.collection.query(
            query_texts=[query],
            n_results=min(n, count),
        )
        ids = results["ids"][0]
        if not ids:
            return []

        # Missing distances are treated as distance 1.0 (similarity 0).
        distances = results["distances"][0] if results.get("distances") else [1.0] * len(ids)
        metadatas = results["metadatas"][0] if results.get("metadatas") else [None] * len(ids)

        tools: list[dict] = []
        for tool_id, distance, meta in zip(ids, distances, metadatas):
            similarity = 1 - distance
            if similarity < threshold:
                continue
            # Copy, so the query result's dicts are never mutated. Ids are tool
            # names (see register), which gives a fallback for a missing record.
            match = dict(meta) if meta else {"name": tool_id}
            match["similarity"] = round(similarity, 3)
            tools.append(match)
        return tools

    def get(self, name: str) -> Optional[dict]:
        """Return the stored metadata for the tool ``name``, or None.

        This is best-effort: a backend failure is logged and reported as None
        rather than raised.
        """
        try:
            result = self.collection.get(ids=[name])
        except Exception:
            self._logger.exception("Tool registry lookup failed for %r", name)
            return None
        if result["ids"] and result["metadatas"]:
            return result["metadatas"][0]
        return None

    def list_all(self) -> list[dict]:
        """Return the metadata of every registered tool (empty list if none)."""
        return self.collection.get()["metadatas"] or []

    def delete(self, name: str):
        """Remove the tool ``name`` from the registry."""
        self.collection.delete(ids=[name])

    def index_existing_tools(self, tools_dir: Optional[str] = None) -> int:
        """Register every public ``*.py`` file in a directory as a tool.

        The description is chosen in this order:
        1. the first line of the module docstring;
        2. failing that, the first triple-double-quoted string anywhere in
           the file;
        3. failing that, ``"Tool: <stem>"``.

        Files whose name starts with ``_`` are skipped. Hyphens in the file
        stem become underscores in the tool name. Each tool is tagged
        ``"existing"``.

        Args:
            tools_dir: Directory to scan (not recursively). Defaults to the
                parent of this module's directory.

        Returns:
            The number of files that were registered successfully.
        """
        tools_dir = tools_dir or os.path.join(os.path.dirname(__file__), "..")
        count = 0
        # Sorting makes the scan order, and so which file wins a name clash
        # (e.g. "a-b.py" vs "a_b.py"), deterministic.
        for f in sorted(Path(tools_dir).glob("*.py")):
            if f.name.startswith("_"):
                continue
            try:
                # Python source is UTF-8 by default (PEP 3120); don't rely on
                # the locale encoding.
                source = f.read_text(encoding="utf-8", errors="replace")
                desc = self._first_docstring_line(source) or f"Tool: {f.stem}"
                name = f.stem.replace("-", "_")
                self.register(name, desc, str(f.resolve()), tags=["existing"])
            except Exception:
                # Best-effort bulk indexing: one bad file must not stop the rest.
                self._logger.warning("Skipping tool file %s", f, exc_info=True)
            else:
                count += 1
        return count

    @staticmethod
    def _first_docstring_line(source: str) -> str:
        """Return the first line of the best-guess docstring in ``source``, or ''."""
        try:
            doc = ast.get_docstring(ast.parse(source))
        except (SyntaxError, ValueError):
            # Unparseable source (syntax error, NUL bytes): use the heuristic below.
            doc = None
        if doc and doc.strip():
            return doc.strip().split("\n")[0]
        # Fallback for files that have no module docstring or don't parse:
        # take the first """-delimited string anywhere in the file.
        parts = source.split('"""')
        if len(parts) >= 3:
            return parts[1].strip().split("\n")[0]
        return ""
|