diff --git a/README.md b/README.md index 74cb8d4..72875ab 100644 --- a/README.md +++ b/README.md @@ -1,33 +1,24 @@ # powermem-java-sdk -PowerMem 的 Java SDK(JDK 11 / Maven)。 +Java SDK for PowerMem (JDK 11 / Maven). -[中文](README_CN.md) | [English](README_EN.md) | [PowerMem 主仓库](https://github.com/oceanbase/powermem) +[中文](README_CN.md) | [PowerMem Main Repository](https://github.com/oceanbase/powermem) -## 项目简介 +## Overview -PowerMem 是一套面向 AI 应用的智能记忆系统,解决“长期记忆、检索与更新”的工程化落地问题。本目录提供其 **Java SDK**,方便在 Java 应用中直接集成记忆能力,并对接不同存储与模型服务(如 SQLite、OceanBase、Qwen 等)。 +PowerMem is an intelligent memory system for AI applications. This module provides a **Java SDK** so you can integrate memory capabilities into Java services and connect to different storages and model providers (SQLite, OceanBase, Qwen, etc.). -> 详细的使用与配置请查看: -> - `README_CN.md`(中文使用指南,包含 `.env` 示例 / 测试 / 打包说明) -> - `README_EN.md`(English guide) +## Key Features -## 核心特点 +- **Developer-friendly**: JDK 11 + Maven; loads config from `.env` automatically (env vars override). +- **Pluggable storage backends**: **SQLite** (local dev) and **OceanBase (MySQL mode)**. +- **Hybrid Search**: OceanBase supports **FULLTEXT + Vector** with fusion strategies (RRF / Weighted Fusion). +- **Optional reranker**: second-stage reranking for hybrid/search candidates; writes `_fusion_score` / `_rerank_score` into result `metadata`. +- **Graph Store (Python parity)**: `graph_store` support (entity/relation extraction, persistence, multi-hop retrieval, BM25 rerank) and returns `relations` from `add/search/get_all`. +- **Best-effort schema evolution**: auto `ALTER TABLE` for missing required columns when using existing OceanBase tables. +- **Testable**: offline unit tests by default; OceanBase integration tests are env-gated. 
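As background for the **Hybrid Search** bullet above: Reciprocal Rank Fusion (RRF) scores each candidate by summing reciprocal ranks across the FULLTEXT and vector result lists. A minimal standalone sketch follows — the `RrfSketch` class, its method shape, and the conventional damping constant `k = 60` are illustrative assumptions for this example, not the SDK's actual API:

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Illustrative sketch of Reciprocal Rank Fusion (RRF), one of the fusion
// strategies listed above. Names and k = 60 are assumptions, not SDK API.
public class RrfSketch {
    /** Fuse two ranked ID lists (e.g. FULLTEXT and vector hits) into one score map. */
    static Map<String, Double> rrf(List<String> ftsRanked, List<String> vecRanked, int k) {
        Map<String, Double> scores = new HashMap<>();
        for (List<String> ranked : List.of(ftsRanked, vecRanked)) {
            for (int i = 0; i < ranked.size(); i++) {
                // rank is 1-based; k damps the dominance of the very top ranks
                scores.merge(ranked.get(i), 1.0 / (k + i + 1), Double::sum);
            }
        }
        return scores;
    }

    public static void main(String[] args) {
        Map<String, Double> s = rrf(List.of("a", "b", "c"), List.of("b", "a"), 60);
        // "c" appears in only one list, so it scores below "a" and "b"
        System.out.println(s.get("b") > s.get("c")); // prints "true"
    }
}
```

Weighted Fusion would instead combine per-list scores using weights (see the `OCEANBASE_VECTOR_WEIGHT` / `OCEANBASE_FTS_WEIGHT` settings later in this diff); the exact formula is implementation-defined.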
-- **开发者友好**:JDK 11 + Maven;配置自动从 `.env` 加载(环境变量优先)。 -- **存储后端可插拔**:支持 **SQLite**(本地开发)与 **OceanBase(MySQL mode)**。 -- **混合检索(Hybrid Search)**:OceanBase 支持 **FULLTEXT + Vector** 并提供融合策略(RRF / Weighted Fusion)。 -- **可选 Reranker**:支持在 hybrid/search 候选集上进行二次精排,并把 `_fusion_score` / `_rerank_score` 写入返回 `metadata` 便于观测。 -- **兼容与演进**:OceanBase 表存在时会 best-effort 执行 `ALTER TABLE` 补齐必要列,减少线上环境升级成本。 -- **可测试**:离线单测默认可跑;OceanBase 集成测试通过环境变量开关避免误连线上库。 - -## 快速开始(入口) - -- **中文完整指南**:`README_CN.md` -- **English guide**:`README_EN.md` -- **配置示例**:`.env.example` - -## 最小示例(Java) +## Minimal Example (Java) ```java import com.oceanbase.powermem.sdk.config.ConfigLoader; @@ -38,9 +29,16 @@ import com.oceanbase.powermem.sdk.model.SearchMemoriesRequest; var cfg = ConfigLoader.fromEnvAndDotEnv(); var mem = new Memory(cfg); -mem.add(AddMemoryRequest.ofText("用户喜欢简洁的中文回答", "user123")); +mem.add(AddMemoryRequest.ofText("User prefers concise answers", "user123")); -var resp = mem.search(SearchMemoriesRequest.ofQuery("用户偏好是什么?", "user123")); +var resp = mem.search(SearchMemoriesRequest.ofQuery("What are the user's preferences?", "user123")); System.out.println(resp.getResults()); ``` +## Quick Links + +- **English guide (this file)**: `README.md` +- **中文完整指南**: `README_CN.md` +- **Config template**: `.env.example` +- **Test reports**: `target/surefire-reports/` + diff --git a/README_CN.md b/README_CN.md index 71d58db..4574f08 100644 --- a/README_CN.md +++ b/README_CN.md @@ -2,9 +2,46 @@ PowerMem 的 Java SDK(JDK 11 / Maven)。 -> 本仓库提供 `README_CN.md` / `README_EN.md` 两份文档,请保持同步更新。 +> 本仓库提供 `README_CN.md`(中文)与 `README.md`(English)两份文档,请保持同步更新。 -[README](README.md) | [English](README_EN.md) | [PowerMem 主仓库](https://github.com/oceanbase/powermem) +[English](README.md) | [PowerMem 主仓库](https://github.com/oceanbase/powermem) + +## 项目简介 + +PowerMem 是一套面向 AI 应用的智能记忆系统,解决“长期记忆、检索与更新”的工程化落地问题。本目录提供其 **Java SDK**,方便在 Java 应用中直接集成记忆能力,并对接不同存储与模型服务(如 SQLite、OceanBase、Qwen 等)。 + +## 核心特点 + +- **开发者友好**:JDK 11 
+ Maven;配置自动从 `.env` 加载(环境变量优先)。 +- **存储后端可插拔**:支持 **SQLite**(本地开发)与 **OceanBase(MySQL mode)**。 +- **混合检索(Hybrid Search)**:OceanBase 支持 **FULLTEXT + Vector** 并提供融合策略(RRF / Weighted Fusion)。 +- **可选 Reranker**:支持在 hybrid/search 候选集上二次精排,并把 `_fusion_score` / `_rerank_score` 写入返回 `metadata` 便于观测。 +- **Graph Store(对齐 Python)**:支持 `graph_store`(实体/关系抽取、图存储、多跳检索、BM25 重排),并在 `add/search/get_all` 返回 `relations`。 +- **兼容与演进**:OceanBase 表存在时 best-effort 执行 `ALTER TABLE` 补齐必要列,减少升级成本。 +- **可测试**:离线单测默认可跑;OceanBase 集成测试通过环境变量开关避免误连线上库。 + +## 最小示例(Java) + +```java +import com.oceanbase.powermem.sdk.config.ConfigLoader; +import com.oceanbase.powermem.sdk.core.Memory; +import com.oceanbase.powermem.sdk.model.AddMemoryRequest; +import com.oceanbase.powermem.sdk.model.SearchMemoriesRequest; + +var cfg = ConfigLoader.fromEnvAndDotEnv(); +var mem = new Memory(cfg); + +mem.add(AddMemoryRequest.ofText("用户喜欢简洁的中文回答", "user123")); + +var resp = mem.search(SearchMemoriesRequest.ofQuery("用户偏好是什么?", "user123")); +System.out.println(resp.getResults()); +``` + +## 快速入口 + +- **主仓库**:[PowerMem](https://github.com/oceanbase/powermem) +- **配置示例**:`.env.example` +- **测试报告**:`target/surefire-reports/` ## 快速开始 @@ -161,6 +198,46 @@ RERANKER_BASE_URL=https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rer 集成测试默认不会跑(避免误连线上库),只有在设置了环境变量时才会启用: - `OceanBaseVectorStoreIT` - `OceanBaseMemoryE2eIT` +- `OceanBaseGraphStoreIT` + +## Graph Store(关系抽取 / 图检索) + +Python 版本支持 `graph_store`,并在 `add/search/get_all` 返回 `relations`。 + +Java 版本提供了 **GraphStore 接口与返回结构对齐**,并内置一个 **`memory`(InMemoryGraphStore)** 实现用于离线测试: +- 开启方式(代码配置):`cfg.getGraphStore().setEnabled(true); cfg.getGraphStore().setProvider("memory");` +- 返回结构: + - `add`:`relations={"deleted_entities":[...], "added_entities":[{"source","relationship","target"}, ...]}` + - `search/get_all`:`relations=[{"source","relationship","destination"}, ...]` + +同时,Java 版本已提供 **`OceanBaseGraphStore`(对齐 Python 的 schema/流程)**: +- **Schema**:`graph_entities` + 
`graph_relationships` 两表(可通过 `GRAPH_STORE_ENTITIES_TABLE` / `GRAPH_STORE_RELATIONSHIPS_TABLE` 修改表名) +- **抽取**:支持 LLM tools(`extract_entities` / `establish_relationships` / `delete_graph_memory`),并有解析兜底 +- **ANN**:best-effort `VECTOR(dims)` + `CREATE VECTOR INDEX`,失败则回退到 `embedding_json` brute-force +- **检索**:多跳扩展 + BM25 重排(tokenizer 为轻量实现,便于离线运行) + +### GraphStore 独立 LLM/Embedding 配置(对齐 Python:`graph_store.llm.*`) + +你可以为 GraphStore 单独配置 LLM/Embedder(优先级高于全局 `LLM_*` / `EMBEDDING_*`): + +```dotenv +# 开启 graph store +GRAPH_STORE_ENABLED=true +GRAPH_STORE_PROVIDER=oceanbase + +# graph_store.llm.* +GRAPH_STORE_LLM_PROVIDER=qwen +GRAPH_STORE_LLM_API_KEY=sk-xxx +GRAPH_STORE_LLM_MODEL=qwen-plus +GRAPH_STORE_LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + +# graph_store.embedder.* +GRAPH_STORE_EMBEDDING_PROVIDER=qwen +GRAPH_STORE_EMBEDDING_API_KEY=sk-xxx +GRAPH_STORE_EMBEDDING_MODEL=text-embedding-v4 +GRAPH_STORE_EMBEDDING_DIMS=1536 +GRAPH_STORE_EMBEDDING_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 +``` ## 测试 @@ -181,6 +258,7 @@ export OCEANBASE_PORT=2881 mvn -Dtest=OceanBaseVectorStoreIT test mvn -Dtest=OceanBaseMemoryE2eIT test +mvn -Dtest=OceanBaseGraphStoreIT test ``` ## 作为正式 Java SDK 打包与使用 diff --git a/README_EN.md b/README_EN.md deleted file mode 100644 index 0aacf39..0000000 --- a/README_EN.md +++ /dev/null @@ -1,233 +0,0 @@ -# powermem-java-sdk - -Java SDK for PowerMem (JDK 11 / Maven). - -> This repository maintains **two README files**: `README_EN.md` and `README_CN.md`. Please keep them in sync. - -[README](README.md) | [中文](README_CN.md) | [PowerMem Main Repository](https://github.com/oceanbase/powermem) - -## Quickstart - -### 1) Prepare `.env` - -An `.env.example` is provided at the repo root (covers all supported configs). 
Copy it to `.env` and edit as needed: - -```bash -cp .env.example .env -``` - -Config loading rules: -- **Default search**: `${user.dir}/.env` (also `config/.env`, `conf/.env`) -- **Precedence**: OS environment variables **override** `.env` keys - -### 2) Build & run - -```bash -mvn -q -DskipTests package -``` - -This is a **library SDK** project. It does not ship a runnable `Main` entry by default. -Use **unit tests / integration tests** to validate behavior. - -### 3) Run tests - -```bash -# Unit tests (offline) -mvn test -``` - -If you use `-q` (quiet), Maven prints nothing on success. Use the **exit code** to tell (0 means success). - -#### Where to find test reports - -- **Surefire reports**: `target/surefire-reports/` - - `*.txt`: plain text output - - `*.xml`: JUnit XML (CI-friendly) - -## `.env` Examples - -### SQLite + Mock (offline) - -```dotenv -DATABASE_PROVIDER=sqlite -SQLITE_PATH=./data/powermem_dev.db -EMBEDDING_PROVIDER=mock -LLM_PROVIDER=mock -``` - -### SQLite + Qwen (real embedding/LLM) - -```dotenv -DATABASE_PROVIDER=sqlite -SQLITE_PATH=./data/powermem_dev.db - -LLM_PROVIDER=qwen -LLM_API_KEY=sk-xxx -LLM_MODEL=qwen-plus -QWEN_LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 - -EMBEDDING_PROVIDER=qwen -EMBEDDING_API_KEY=sk-xxx -EMBEDDING_MODEL=text-embedding-v4 -EMBEDDING_DIMS=1536 -QWEN_EMBEDDING_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 -``` - -### OceanBase + Mock (offline connectivity smoke test) - -```dotenv -DATABASE_PROVIDER=oceanbase -OCEANBASE_HOST=127.0.0.1 -OCEANBASE_PORT=2881 -OCEANBASE_USER=root@sys -OCEANBASE_PASSWORD=your_password -OCEANBASE_DATABASE=ai_work -OCEANBASE_COLLECTION=memories_java_debug - -# MockEmbedder uses fixed 10 dims -OCEANBASE_EMBEDDING_MODEL_DIMS=10 -EMBEDDING_PROVIDER=mock -LLM_PROVIDER=mock -``` - -Notes: -- **Strongly recommended** to use a dedicated `OCEANBASE_COLLECTION` (e.g. `memories_java_debug`) to avoid clashing with existing tables. 
-- If the table already exists, the SDK will best-effort `ALTER TABLE` to add required columns (e.g. `payload`/`vector`) to avoid errors like `Unknown column 'payload'`. - -### OceanBase + Qwen (real vector search) - -```dotenv -DATABASE_PROVIDER=oceanbase -OCEANBASE_HOST=127.0.0.1 -OCEANBASE_PORT=2881 -OCEANBASE_USER=root@sys -OCEANBASE_PASSWORD=your_password -OCEANBASE_DATABASE=ai_work -OCEANBASE_COLLECTION=memories_java_debug - -# Must match your embedding dims -OCEANBASE_EMBEDDING_MODEL_DIMS=1536 - -EMBEDDING_PROVIDER=qwen -EMBEDDING_API_KEY=sk-xxx -EMBEDDING_MODEL=text-embedding-v4 -EMBEDDING_DIMS=1536 -QWEN_EMBEDDING_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 - -LLM_PROVIDER=mock -``` - -### OceanBase Hybrid + Reranker (Python parity: coarse fusion + fine rerank) - -> The reranker reorders fused candidates (default: `topK*3`) and writes `_fusion_score` / `_rerank_score` into the returned `metadata`. -> -> **FULLTEXT parser support & downgrade strategy (best-effort)**: -> - **Supported values**: `ik` / `ngram` / `ngram2` / `beng` / `space` (same as the Python implementation) -> - If an unsupported value is configured, the SDK logs a warning and ignores it -> - If FULLTEXT index creation fails, the SDK retries in this order: no parser → configured parser → other supported parsers -> - If all attempts fail, FTS falls back to `LIKE` (with warnings explaining why) - -```dotenv -DATABASE_PROVIDER=oceanbase -OCEANBASE_HOST=127.0.0.1 -OCEANBASE_PORT=2881 -OCEANBASE_USER=root@sys -OCEANBASE_PASSWORD=your_password -OCEANBASE_DATABASE=ai_work -OCEANBASE_COLLECTION=memories_java_debug - -OCEANBASE_HYBRID_SEARCH=true -OCEANBASE_FUSION_METHOD=rrf -OCEANBASE_VECTOR_WEIGHT=0.5 -OCEANBASE_FTS_WEIGHT=0.5 -OCEANBASE_FULLTEXT_PARSER=ik - -EMBEDDING_PROVIDER=qwen -EMBEDDING_API_KEY=sk-xxx -EMBEDDING_MODEL=text-embedding-v4 -EMBEDDING_DIMS=1536 -QWEN_EMBEDDING_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 - -RERANKER_PROVIDER=qwen 
-RERANKER_API_KEY=sk-xxx -RERANKER_MODEL=gte-rerank-v2 -RERANKER_TOP_K=10 -RERANKER_BASE_URL=https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank -``` - -## OceanBase validation (integration tests) - -Integration tests are gated by env vars to avoid accidental connections: -- `OceanBaseVectorStoreIT` -- `OceanBaseMemoryE2eIT` - -## Tests - -Unit tests (offline): - -```bash -mvn test -``` - -OceanBase integration test (requires OceanBase env vars): - -```bash -export OCEANBASE_HOST=xxx -export OCEANBASE_USER=xxx -export OCEANBASE_PASSWORD=xxx -export OCEANBASE_DATABASE=xxx -export OCEANBASE_PORT=2881 - -mvn -Dtest=OceanBaseVectorStoreIT test -mvn -Dtest=OceanBaseMemoryE2eIT test -``` - -## Packaging & using as a real Java SDK - -### 1) Install locally (use from other local projects) - -```bash -mvn -DskipTests install -``` - -Then add dependency in your app (Maven): - -```xml - - com.oceanbase.powermem.sdk - powermem-java-sdk - 1.0-SNAPSHOT - -``` - -(Gradle) - -```gradle -implementation "com.oceanbase.powermem.sdk:powermem-java-sdk:1.0-SNAPSHOT" -``` - -### 2) Publish to a repository (team/production) - -- **Private Maven repo (recommended)**: configure `distributionManagement` in `pom.xml`, then run `mvn deploy` -- **Maven Central**: requires full POM metadata (license/scm/developers), sources+javadocs, plus signing and release workflow - -### 3) Minimal usage - -```java -import com.oceanbase.powermem.sdk.config.ConfigLoader; -import com.oceanbase.powermem.sdk.core.Memory; -import com.oceanbase.powermem.sdk.model.AddMemoryRequest; -import com.oceanbase.powermem.sdk.model.SearchMemoriesRequest; - -var cfg = ConfigLoader.fromEnvAndDotEnv(); -var mem = new Memory(cfg); - -var add = AddMemoryRequest.ofText("User prefers concise answers", "user123"); -add.setInfer(false); -mem.add(add); - -var search = SearchMemoriesRequest.ofQuery("What are the user's preferences?", "user123"); -var resp = mem.search(search); 
-System.out.println(resp.getResults()); -``` - diff --git a/src/main/java/com/oceanbase/powermem/sdk/config/ConfigLoader.java b/src/main/java/com/oceanbase/powermem/sdk/config/ConfigLoader.java index 886c202..4e46d11 100644 --- a/src/main/java/com/oceanbase/powermem/sdk/config/ConfigLoader.java +++ b/src/main/java/com/oceanbase/powermem/sdk/config/ConfigLoader.java @@ -136,6 +136,50 @@ public static MemoryConfig fromMap(Map values) { setIfPresent(values, v -> vector.setPoolSize(parseInt(v)), "DATABASE_POOL_SIZE"); setIfPresent(values, v -> vector.setMaxOverflow(parseInt(v)), "DATABASE_MAX_OVERFLOW"); + // Graph store (optional). Mirrors Python graph_store.enable/provider and uses OceanBase by default. + GraphStoreConfig graph = config.getGraphStore(); + setIfPresent(values, v -> graph.setEnabled(parseBoolean(v)), "GRAPH_STORE_ENABLED", "graph_store.enabled"); + setIfPresent(values, graph::setProvider, "GRAPH_STORE_PROVIDER", "graph_store.provider"); + // default to OceanBase connection if graph-specific env isn't set + setIfPresent(values, graph::setHost, "GRAPH_STORE_HOST", "OCEANBASE_HOST"); + setIfPresent(values, v -> graph.setPort(parseInt(v)), "GRAPH_STORE_PORT", "OCEANBASE_PORT"); + setIfPresent(values, graph::setUser, "GRAPH_STORE_USER", "OCEANBASE_USER"); + setIfPresent(values, graph::setPassword, "GRAPH_STORE_PASSWORD", "OCEANBASE_PASSWORD"); + setIfPresent(values, graph::setDatabase, "GRAPH_STORE_DATABASE", "OCEANBASE_DATABASE"); + setIfPresent(values, v -> graph.setTimeoutSeconds(parseInt(v)), "GRAPH_STORE_TIMEOUT"); + // Python parity options + setIfPresent(values, graph::setEntitiesTable, "GRAPH_STORE_ENTITIES_TABLE", "GRAPH_STORE_TABLE_ENTITIES"); + setIfPresent(values, graph::setRelationshipsTable, "GRAPH_STORE_RELATIONSHIPS_TABLE", "GRAPH_STORE_TABLE_RELATIONSHIPS"); + setIfPresent(values, v -> graph.setEmbeddingModelDims(parseInt(v)), "GRAPH_STORE_EMBEDDING_MODEL_DIMS", "OCEANBASE_EMBEDDING_MODEL_DIMS"); + setIfPresent(values, graph::setIndexType, 
"GRAPH_STORE_INDEX_TYPE", "OCEANBASE_INDEX_TYPE"); + setIfPresent(values, graph::setMetricType, "GRAPH_STORE_VECTOR_METRIC_TYPE", "GRAPH_STORE_METRIC_TYPE", "OCEANBASE_VECTOR_METRIC_TYPE", "OCEANBASE_METRIC_TYPE"); + setIfPresent(values, graph::setVectorIndexName, "GRAPH_STORE_VIDX_NAME", "OCEANBASE_VIDX_NAME"); + setIfPresent(values, v -> graph.setSimilarityThreshold(parseDouble(v)), "GRAPH_STORE_SIMILARITY_THRESHOLD"); + setIfPresent(values, v -> graph.setMaxHops(parseInt(v)), "GRAPH_STORE_MAX_HOPS"); + setIfPresent(values, v -> graph.setSearchLimit(parseInt(v)), "GRAPH_STORE_SEARCH_LIMIT"); + setIfPresent(values, v -> graph.setBm25TopN(parseInt(v)), "GRAPH_STORE_BM25_TOP_N"); + setIfPresent(values, graph::setCustomPrompt, "GRAPH_STORE_CUSTOM_PROMPT", "graph_store.custom_prompt"); + setIfPresent(values, graph::setCustomExtractRelationsPrompt, "GRAPH_STORE_CUSTOM_EXTRACT_RELATIONS_PROMPT"); + setIfPresent(values, graph::setCustomDeleteRelationsPrompt, "GRAPH_STORE_CUSTOM_DELETE_RELATIONS_PROMPT"); + + // graph_store.llm.* overrides (if set, GraphStoreFactory will create a dedicated LLM instance) + LlmConfig graphLlm = graph.getLlm(); + setIfPresent(values, graphLlm::setProvider, "GRAPH_STORE_LLM_PROVIDER"); + setIfPresent(values, graphLlm::setApiKey, "GRAPH_STORE_LLM_API_KEY"); + setIfPresent(values, graphLlm::setModel, "GRAPH_STORE_LLM_MODEL"); + setIfPresent(values, graphLlm::setBaseUrl, "GRAPH_STORE_LLM_BASE_URL"); + setIfPresent(values, v -> graphLlm.setTemperature(parseDouble(v)), "GRAPH_STORE_LLM_TEMPERATURE"); + setIfPresent(values, v -> graphLlm.setMaxTokens(parseInt(v)), "GRAPH_STORE_LLM_MAX_TOKENS"); + setIfPresent(values, v -> graphLlm.setTopP(parseDouble(v)), "GRAPH_STORE_LLM_TOP_P"); + + // graph_store.embedder.* overrides (if set, GraphStoreFactory will create a dedicated Embedder instance) + EmbedderConfig graphEmb = graph.getEmbedder(); + setIfPresent(values, graphEmb::setProvider, "GRAPH_STORE_EMBEDDING_PROVIDER", "GRAPH_STORE_EMBEDDER_PROVIDER"); 
+ setIfPresent(values, graphEmb::setApiKey, "GRAPH_STORE_EMBEDDING_API_KEY", "GRAPH_STORE_EMBEDDER_API_KEY"); + setIfPresent(values, graphEmb::setModel, "GRAPH_STORE_EMBEDDING_MODEL", "GRAPH_STORE_EMBEDDER_MODEL"); + setIfPresent(values, v -> graphEmb.setEmbeddingDims(parseInt(v)), "GRAPH_STORE_EMBEDDING_DIMS", "GRAPH_STORE_EMBEDDER_DIMS"); + setIfPresent(values, graphEmb::setBaseUrl, "GRAPH_STORE_EMBEDDING_BASE_URL", "GRAPH_STORE_EMBEDDER_BASE_URL"); + LlmConfig llm = config.getLlm(); setIfPresent(values, llm::setProvider, "LLM_PROVIDER", "llm.provider"); setIfPresent(values, llm::setApiKey, "LLM_API_KEY"); diff --git a/src/main/java/com/oceanbase/powermem/sdk/config/GraphStoreConfig.java b/src/main/java/com/oceanbase/powermem/sdk/config/GraphStoreConfig.java index dd3761f..97781c5 100644 --- a/src/main/java/com/oceanbase/powermem/sdk/config/GraphStoreConfig.java +++ b/src/main/java/com/oceanbase/powermem/sdk/config/GraphStoreConfig.java @@ -13,6 +13,40 @@ public class GraphStoreConfig { private String user = "root"; private String password; private String database = "powermem"; + // Python parity: entities/relationships tables (defaults in powermem.storage.oceanbase.constants) + private String entitiesTable = "graph_entities"; + private String relationshipsTable = "graph_relationships"; + + // Vector / ANN configuration (best-effort in JDBC). + private int embeddingModelDims = 0; + private String indexType = "HNSW"; + private String metricType = "l2"; + private String vectorIndexName = "vidx"; + private double similarityThreshold = 0.7; + + // Graph search behavior. + private int maxHops = 3; + private int searchLimit = 100; + private int bm25TopN = 15; + + // Prompts customization (Python: custom_prompt or custom_*_prompt) + private String customPrompt; + private String customExtractRelationsPrompt; + private String customDeleteRelationsPrompt; + + /** + * Optional per-graph-store LLM config override (Python parity: graph_store.llm.*). 
+ * If provider/apiKey is not set, the Memory-level LLM is used. + */ + private LlmConfig llm; + + /** + * Optional per-graph-store embedder config override (Python parity: graph_store.embedder.*). + * If provider/apiKey is not set, the Memory-level Embedder is used. + */ + private EmbedderConfig embedder; + + private int timeoutSeconds = 10; public GraphStoreConfig() {} @@ -71,5 +105,139 @@ public String getDatabase() { public void setDatabase(String database) { this.database = database; } + + public String getEntitiesTable() { + return entitiesTable; + } + + public void setEntitiesTable(String entitiesTable) { + this.entitiesTable = entitiesTable; + } + + public String getRelationshipsTable() { + return relationshipsTable; + } + + public void setRelationshipsTable(String relationshipsTable) { + this.relationshipsTable = relationshipsTable; + } + + public int getEmbeddingModelDims() { + return embeddingModelDims; + } + + public void setEmbeddingModelDims(int embeddingModelDims) { + this.embeddingModelDims = embeddingModelDims; + } + + public String getIndexType() { + return indexType; + } + + public void setIndexType(String indexType) { + this.indexType = indexType; + } + + public String getMetricType() { + return metricType; + } + + public void setMetricType(String metricType) { + this.metricType = metricType; + } + + public String getVectorIndexName() { + return vectorIndexName; + } + + public void setVectorIndexName(String vectorIndexName) { + this.vectorIndexName = vectorIndexName; + } + + public double getSimilarityThreshold() { + return similarityThreshold; + } + + public void setSimilarityThreshold(double similarityThreshold) { + this.similarityThreshold = similarityThreshold; + } + + public int getMaxHops() { + return maxHops; + } + + public void setMaxHops(int maxHops) { + this.maxHops = maxHops; + } + + public int getSearchLimit() { + return searchLimit; + } + + public void setSearchLimit(int searchLimit) { + this.searchLimit = searchLimit; + } + + 
public int getBm25TopN() { + return bm25TopN; + } + + public void setBm25TopN(int bm25TopN) { + this.bm25TopN = bm25TopN; + } + + public String getCustomPrompt() { + return customPrompt; + } + + public void setCustomPrompt(String customPrompt) { + this.customPrompt = customPrompt; + } + + public String getCustomExtractRelationsPrompt() { + return customExtractRelationsPrompt; + } + + public void setCustomExtractRelationsPrompt(String customExtractRelationsPrompt) { + this.customExtractRelationsPrompt = customExtractRelationsPrompt; + } + + public String getCustomDeleteRelationsPrompt() { + return customDeleteRelationsPrompt; + } + + public void setCustomDeleteRelationsPrompt(String customDeleteRelationsPrompt) { + this.customDeleteRelationsPrompt = customDeleteRelationsPrompt; + } + + public LlmConfig getLlm() { + if (llm == null) { + llm = new LlmConfig(); + } + return llm; + } + + public void setLlm(LlmConfig llm) { + this.llm = llm; + } + + public EmbedderConfig getEmbedder() { + if (embedder == null) { + embedder = new EmbedderConfig(); + } + return embedder; + } + + public void setEmbedder(EmbedderConfig embedder) { + this.embedder = embedder; + } + + public int getTimeoutSeconds() { + return timeoutSeconds; + } + + public void setTimeoutSeconds(int timeoutSeconds) { + this.timeoutSeconds = timeoutSeconds; + } } diff --git a/src/main/java/com/oceanbase/powermem/sdk/core/Memory.java b/src/main/java/com/oceanbase/powermem/sdk/core/Memory.java index 50ab8d0..a42e8a3 100644 --- a/src/main/java/com/oceanbase/powermem/sdk/core/Memory.java +++ b/src/main/java/com/oceanbase/powermem/sdk/core/Memory.java @@ -12,6 +12,7 @@ public class Memory implements MemoryBase { private final com.oceanbase.powermem.sdk.config.MemoryConfig config; private final com.oceanbase.powermem.sdk.storage.base.VectorStore vectorStore; + private final com.oceanbase.powermem.sdk.storage.base.GraphStore graphStore; private final com.oceanbase.powermem.sdk.integrations.embeddings.Embedder 
embedder;
     private final com.oceanbase.powermem.sdk.intelligence.IntelligenceManager intelligence;
     private final com.oceanbase.powermem.sdk.integrations.llm.LLM llm;
@@ -32,12 +33,17 @@ public Memory(com.oceanbase.powermem.sdk.config.MemoryConfig config) {
         this.intelligence = new com.oceanbase.powermem.sdk.intelligence.IntelligenceManager(this.config.getIntelligentMemory());
         this.plugin = new com.oceanbase.powermem.sdk.intelligence.plugin.EbbinghausIntelligencePlugin(this.config.getIntelligentMemory());
         this.reranker = com.oceanbase.powermem.sdk.integrations.rerank.RerankFactory.fromConfig(this.config.getReranker());
+        this.graphStore = com.oceanbase.powermem.sdk.storage.factory.GraphStoreFactory.fromConfig(
+                this.config.getGraphStore(), this.embedder, this.llm);
     }
 
     @Override
     public com.oceanbase.powermem.sdk.model.AddMemoryResponse add(com.oceanbase.powermem.sdk.model.AddMemoryRequest request) {
         com.oceanbase.powermem.sdk.util.Preconditions.requireNonNull(request, "AddMemoryRequest is required");
+        // Graph store (optional): add raw conversation text and return relations summary.
+        java.util.Map<String, Object> graphResult = maybeAddToGraph(request);
+
         java.util.List<com.oceanbase.powermem.sdk.model.Message> msgs = request.getMessages();
         String normalized = com.oceanbase.powermem.sdk.util.PowermemUtils.normalizeInput(request.getText(), msgs);
         if (normalized.isBlank()) {
@@ -67,10 +73,15 @@ public com.oceanbase.powermem.sdk.model.AddMemoryResponse add(com.oceanbase.powe
             counts.put("NONE", 0);
             // Python benchmark/server: add() 返回不包含 action_counts;保持为 null 以省略序列化
             resp.setActionCounts(null);
-            resp.setRelations(null);
+            // Python parity: when graph is enabled, return relations even if empty.
+            resp.setRelations(graphResult == null ? java.util.Collections.emptyMap() : graphResult);
             return resp;
         }
-        return intelligentAdd(request, normalized);
+        com.oceanbase.powermem.sdk.model.AddMemoryResponse resp = intelligentAdd(request, normalized);
+        if (resp != null) {
+            resp.setRelations(graphResult == null ? java.util.Collections.emptyMap() : graphResult);
+        }
+        return resp;
     }
 
     private com.oceanbase.powermem.sdk.model.AddMemoryResponse intelligentAdd(com.oceanbase.powermem.sdk.model.AddMemoryRequest request, String normalized) {
@@ -325,7 +336,14 @@ public com.oceanbase.powermem.sdk.model.SearchMemoriesResponse search(com.oceanb
             results = filtered;
         }
         com.oceanbase.powermem.sdk.model.SearchMemoriesResponse resp = new com.oceanbase.powermem.sdk.model.SearchMemoriesResponse(results);
-        resp.setRelations(null);
+        if (graphStore != null && config != null && config.getGraphStore() != null && config.getGraphStore().isEnabled()) {
+            java.util.Map<String, Object> gf = buildGraphFilters(request.getUserId(), request.getAgentId(), request.getRunId(), request.getFilters());
+            java.util.List<java.util.Map<String, Object>> gr = graphStore.search(request.getQuery(), gf, limit);
+            // Python parity: when graph enabled, relations should always be present (empty list allowed).
+            resp.setRelations(gr == null ? java.util.Collections.emptyList() : gr);
+        } else {
+            resp.setRelations(null);
+        }
         return resp;
     }
@@ -473,7 +491,14 @@ public com.oceanbase.powermem.sdk.model.GetAllMemoriesResponse getAll(com.oceanb
             results.add(m);
         }
         resp.setResults(results);
-        resp.setRelations(null);
+        if (graphStore != null && config != null && config.getGraphStore() != null && config.getGraphStore().isEnabled()) {
+            java.util.Map<String, Object> gf = buildGraphFilters(request.getUserId(), request.getAgentId(), request.getRunId(), null);
+            int lim = request.getLimit() > 0 ? request.getLimit() : 100;
+            java.util.List<java.util.Map<String, Object>> gr = graphStore.getAll(gf, lim + Math.max(0, request.getOffset()));
+            resp.setRelations(gr == null ? java.util.Collections.emptyList() : gr);
+        } else {
+            resp.setRelations(null);
+        }
         return resp;
     }
@@ -482,7 +507,55 @@ public com.oceanbase.powermem.sdk.model.DeleteAllMemoriesResponse deleteAll(
             com.oceanbase.powermem.sdk.model.DeleteAllMemoriesRequest request) {
         com.oceanbase.powermem.sdk.util.Preconditions.requireNonNull(request, "DeleteAllMemoriesRequest is required");
         int deleted = storage.clearMemories(request.getUserId(), request.getAgentId(), request.getRunId());
+        if (graphStore != null && config != null && config.getGraphStore() != null && config.getGraphStore().isEnabled()) {
+            java.util.Map<String, Object> gf = buildGraphFilters(request.getUserId(), request.getAgentId(), request.getRunId(), null);
+            try {
+                graphStore.deleteAll(gf);
+            } catch (Exception ignored) {
+                // best-effort, graph should not break deleteAll
+            }
+        }
         return new com.oceanbase.powermem.sdk.model.DeleteAllMemoriesResponse(deleted);
     }
+
+    private java.util.Map<String, Object> maybeAddToGraph(com.oceanbase.powermem.sdk.model.AddMemoryRequest request) {
+        if (graphStore == null || config == null || config.getGraphStore() == null || !config.getGraphStore().isEnabled()) {
+            return null;
+        }
+        String data;
+        if (request.getMessages() != null && !request.getMessages().isEmpty()) {
+            StringBuilder sb = new StringBuilder();
+            for (com.oceanbase.powermem.sdk.model.Message m : request.getMessages()) {
+                if (m == null) continue;
+                if ("system".equalsIgnoreCase(m.getRole())) continue;
+                if (m.getContent() == null || m.getContent().isBlank()) continue;
+                if (sb.length() > 0) sb.append("\n");
+                sb.append(m.getContent());
+            }
+            data = sb.toString();
+        } else {
+            data = request.getText();
+        }
+        if (data == null || data.isBlank()) {
+            return null;
+        }
+        java.util.Map<String, Object> gf = buildGraphFilters(request.getUserId(), request.getAgentId(), request.getRunId(), request.getFilters());
+        return graphStore.add(data, gf);
+    }
+
+    private static java.util.Map<String, Object> buildGraphFilters(
+            String userId,
+            String agentId,
+            String runId,
+            java.util.Map<String, Object> extraFilters) {
+        java.util.Map<String, Object> gf = extraFilters == null ? new java.util.HashMap<>() : new java.util.HashMap<>(extraFilters);
+        gf.put("user_id", userId == null ? "user" : userId);
+        gf.put("agent_id", agentId);
+        gf.put("run_id", runId);
+        if (gf.get("user_id") == null) {
+            gf.put("user_id", "user");
+        }
+        return gf;
+    }
 }
diff --git a/src/main/java/com/oceanbase/powermem/sdk/prompts/graph/GraphPrompts.java b/src/main/java/com/oceanbase/powermem/sdk/prompts/graph/GraphPrompts.java
index f72db30..e374fb8 100644
--- a/src/main/java/com/oceanbase/powermem/sdk/prompts/graph/GraphPrompts.java
+++ b/src/main/java/com/oceanbase/powermem/sdk/prompts/graph/GraphPrompts.java
@@ -7,5 +7,84 @@
  */
 public final class GraphPrompts {
     private GraphPrompts() {}
+
+    // Default prompts (mirrors Python).
+    public static final String EXTRACT_RELATIONS_PROMPT =
+            "You are an advanced algorithm designed to extract structured information from text to construct knowledge graphs. " +
+            "Your goal is to capture comprehensive and accurate information. Follow these key principles:\n\n" +
+            "1. Extract only explicitly stated information from the text.\n" +
+            "2. Establish relationships among the entities provided.\n" +
+            "3. Use \"USER_ID\" as the source entity for any self-references (e.g., \"I,\" \"me,\" \"my,\" etc.) 
in user messages.\n" + + "CUSTOM_PROMPT\n\n" + + "Relationships:\n" + + " - Use consistent, general, and timeless relationship types.\n" + + " - Example: Prefer \"professor\" over \"became_professor.\"\n" + + " - Relationships should only be established among the entities explicitly mentioned in the user message.\n\n" + + "Entity Consistency:\n" + + " - Ensure that relationships are coherent and logically align with the context of the message.\n" + + " - Maintain consistent naming for entities across the extracted data.\n\n" + + "Strive to construct a coherent and easily understandable knowledge graph by eshtablishing all the relationships among the entities " + + "and adherence to the user's context.\n\n" + + "Adhere strictly to these guidelines to ensure high-quality knowledge graph extraction."; + + public static final String DELETE_RELATIONS_SYSTEM_PROMPT = + "You are a graph memory manager specializing in identifying, managing, and optimizing relationships within graph-based memories. " + + "Your primary task is to analyze a list of existing relationships and determine which ones should be deleted based on the new information provided.\n" + + "Input:\n" + + "1. Existing Graph Memories: A list of current graph memories, each containing source, relationship, and destination information.\n" + + "2. New Text: The new information to be integrated into the existing graph structure.\n" + + "3. Use \"USER_ID\" as node for any self-references (e.g., \"I,\" \"me,\" \"my,\" etc.) in user messages.\n\n" + + "Guidelines:\n" + + "1. Identification: Use the new information to evaluate existing relationships in the memory graph.\n" + + "2. Deletion Criteria: Delete a relationship only if it meets at least one of these conditions:\n" + + " - Outdated or Inaccurate: The new information is more recent or accurate.\n" + + " - Contradictory: The new information conflicts with or negates the existing information.\n" + + "3. 
DO NOT DELETE if there is a possibility of same type of relationship but different destination nodes.\n" +
+            "4. Comprehensive Analysis:\n" +
+            "   - Thoroughly examine each existing relationship against the new information and delete as necessary.\n" +
+            "   - Multiple deletions may be required based on the new information.\n" +
+            "5. Semantic Integrity:\n" +
+            "   - Ensure that deletions maintain or improve the overall semantic structure of the graph.\n" +
+            "   - Avoid deleting relationships that are NOT contradictory/outdated to the new information.\n" +
+            "6. Temporal Awareness: Prioritize recency when timestamps are available.\n" +
+            "7. Necessity Principle: Only DELETE relationships that must be deleted and are contradictory/outdated to the new information to maintain an accurate and coherent memory graph.\n\n" +
+            "Note: DO NOT DELETE if there is a possibility of same type of relationship but different destination nodes.\n\n" +
+            "For example:\n" +
+            "Existing Memory: alice -- loves_to_eat -- pizza\n" +
+            "New Information: Alice also loves to eat burger.\n\n" +
+            "Do not delete in the above example because there is a possibility that Alice loves to eat both pizza and burger.\n\n" +
+            "Memory Format:\n" +
+            "source -- relationship -- destination\n\n" +
+            "Provide a list of deletion instructions, each specifying the relationship to be deleted.";
+
+    public static String extractRelationsSystemPrompt(String userIdentity, String customPrompt) {
+        String p = EXTRACT_RELATIONS_PROMPT;
+        if (userIdentity != null && !userIdentity.isBlank()) {
+            p = p.replace("USER_ID", userIdentity);
+        }
+        if (customPrompt != null && !customPrompt.isBlank()) {
+            p = p.replace("CUSTOM_PROMPT", "4. 
" + customPrompt); + } else { + p = p.replace("CUSTOM_PROMPT", ""); + } + return p; + } + + public static String deleteRelationsSystemPrompt(String userIdForSelfRef, String customPrompt) { + String p = DELETE_RELATIONS_SYSTEM_PROMPT; + String uid = (userIdForSelfRef == null || userIdForSelfRef.isBlank()) ? "USER_ID" : userIdForSelfRef; + p = p.replace("USER_ID", uid); + // Currently we don't have a separate custom delete prompt; allow appending. + if (customPrompt != null && !customPrompt.isBlank()) { + p = p + "\n\nCUSTOM:\n" + customPrompt; + } + return p; + } + + public static String deleteRelationsUserPrompt(String existingMemories, String newText) { + String em = existingMemories == null ? "" : existingMemories; + String nt = newText == null ? "" : newText; + return "Here are the existing memories: " + em + " \n\n New Information: " + nt; + } } diff --git a/src/main/java/com/oceanbase/powermem/sdk/prompts/graph/GraphToolsPrompts.java b/src/main/java/com/oceanbase/powermem/sdk/prompts/graph/GraphToolsPrompts.java index 499d4e6..819fe56 100644 --- a/src/main/java/com/oceanbase/powermem/sdk/prompts/graph/GraphToolsPrompts.java +++ b/src/main/java/com/oceanbase/powermem/sdk/prompts/graph/GraphToolsPrompts.java @@ -7,5 +7,114 @@ */ public final class GraphToolsPrompts { private GraphToolsPrompts() {} + + public static java.util.Map extractEntitiesTool(boolean structured) { + java.util.Map tool = new java.util.HashMap<>(); + tool.put("type", "function"); + java.util.Map fn = new java.util.HashMap<>(); + fn.put("name", "extract_entities"); + fn.put("description", "Extract entities and their types from the text."); + if (structured) { + fn.put("strict", Boolean.TRUE); + } + java.util.Map params = new java.util.HashMap<>(); + params.put("type", "object"); + java.util.Map props = new java.util.HashMap<>(); + java.util.Map entities = new java.util.HashMap<>(); + entities.put("type", "array"); + entities.put("description", "An array of entities with their types."); + 
java.util.Map items = new java.util.HashMap<>(); + items.put("type", "object"); + java.util.Map itemProps = new java.util.HashMap<>(); + itemProps.put("entity", java.util.Map.of("type", "string", "description", "The name or identifier of the entity.")); + itemProps.put("entity_type", java.util.Map.of("type", "string", "description", "The type or category of the entity.")); + items.put("properties", itemProps); + items.put("required", java.util.List.of("entity", "entity_type")); + items.put("additionalProperties", Boolean.FALSE); + entities.put("items", items); + props.put("entities", entities); + params.put("properties", props); + params.put("required", java.util.List.of("entities")); + params.put("additionalProperties", Boolean.FALSE); + fn.put("parameters", params); + tool.put("function", fn); + return tool; + } + + public static java.util.Map establishRelationsTool(boolean structured) { + java.util.Map tool = new java.util.HashMap<>(); + tool.put("type", "function"); + java.util.Map fn = new java.util.HashMap<>(); + // Python uses "establish_relationships" for non-structured and "establish_relations" for structured. + fn.put("name", structured ? 
"establish_relations" : "establish_relationships"); + fn.put("description", "Establish relationships among the entities based on the provided text."); + if (structured) { + fn.put("strict", Boolean.TRUE); + } + java.util.Map params = new java.util.HashMap<>(); + params.put("type", "object"); + java.util.Map props = new java.util.HashMap<>(); + + java.util.Map entities = new java.util.HashMap<>(); + entities.put("type", "array"); + java.util.Map items = new java.util.HashMap<>(); + items.put("type", "object"); + java.util.Map itemProps = new java.util.HashMap<>(); + itemProps.put("source", java.util.Map.of("type", "string", "description", "The source entity of the relationship.")); + itemProps.put("relationship", java.util.Map.of("type", "string", "description", "The relationship between the source and destination entities.")); + itemProps.put("destination", java.util.Map.of("type", "string", "description", "The destination entity of the relationship.")); + items.put("properties", itemProps); + items.put("required", java.util.List.of("source", "relationship", "destination")); + items.put("additionalProperties", Boolean.FALSE); + entities.put("items", items); + + props.put("entities", entities); + params.put("properties", props); + params.put("required", java.util.List.of("entities")); + params.put("additionalProperties", Boolean.FALSE); + fn.put("parameters", params); + tool.put("function", fn); + return tool; + } + + public static java.util.Map deleteGraphMemoryTool(boolean structured) { + java.util.Map tool = new java.util.HashMap<>(); + tool.put("type", "function"); + java.util.Map fn = new java.util.HashMap<>(); + fn.put("name", "delete_graph_memory"); + fn.put("description", "Delete the relationship between two nodes. 
This function deletes the existing relationship."); + if (structured) { + fn.put("strict", Boolean.TRUE); + } + java.util.Map params = new java.util.HashMap<>(); + params.put("type", "object"); + java.util.Map props = new java.util.HashMap<>(); + props.put("source", java.util.Map.of("type", "string", "description", "The identifier of the source node in the relationship.")); + props.put("relationship", java.util.Map.of("type", "string", "description", "The existing relationship between the source and destination nodes that needs to be deleted.")); + props.put("destination", java.util.Map.of("type", "string", "description", "The identifier of the destination node in the relationship.")); + params.put("properties", props); + params.put("required", java.util.List.of("source", "relationship", "destination")); + params.put("additionalProperties", Boolean.FALSE); + fn.put("parameters", params); + tool.put("function", fn); + return tool; + } + + public static java.util.List> tools(java.util.List> toolList) { + return toolList == null ? java.util.Collections.emptyList() : toolList; + } + + /** + * OpenAI-style tool_choice object: {"type":"function","function":{"name":"..."}}. + */ + public static java.util.Map toolChoiceFunction(String name) { + if (name == null || name.isBlank()) { + return null; + } + java.util.Map m = new java.util.HashMap<>(); + m.put("type", "function"); + m.put("function", java.util.Map.of("name", name)); + return m; + } } diff --git a/src/main/java/com/oceanbase/powermem/sdk/storage/base/GraphStore.java b/src/main/java/com/oceanbase/powermem/sdk/storage/base/GraphStore.java index 314db16..0af9b1e 100644 --- a/src/main/java/com/oceanbase/powermem/sdk/storage/base/GraphStore.java +++ b/src/main/java/com/oceanbase/powermem/sdk/storage/base/GraphStore.java @@ -7,5 +7,39 @@ * (GraphStoreFactory).

 */
public interface GraphStore {
+    /**
+     * Add raw text data to graph store (entity/relation extraction + upsert).
+     *
+     * <p>Python parity: {@code GraphStoreBase.add(data, filters)} returns a dict like:
+     * {@code {"deleted_entities": [...], "added_entities": [{"source":..,"relationship":..,"target":..}, ...]}}.
+     *
+     * @param data raw text content
+     * @param filters scope filters (user_id/agent_id/run_id, etc.)
+     * @return relations summary (may be empty)
+     */
+    java.util.Map<String, Object> add(String data, java.util.Map<String, Object> filters);
+
+    /**
+     * Search graph store for relevant relations.
+     *
+     * <p>Python parity: {@code GraphStoreBase.search(query, filters)} returns a list of dicts like
+     * {@code [{"source":..,"relationship":..,"destination":..}, ...]}.
+     */
+    java.util.List<java.util.Map<String, Object>> search(String query, java.util.Map<String, Object> filters, int limit);
+
+    /**
+     * Delete all graph data in the given scope.
+     */
+    void deleteAll(java.util.Map<String, Object> filters);
+
+    /**
+     * Retrieve all relations in the given scope.
+     */
+    java.util.List<java.util.Map<String, Object>> getAll(java.util.Map<String, Object> filters, int limit);
+
+    /**
+     * Reset graph store state.
+     */
+    void reset();
 }
diff --git a/src/main/java/com/oceanbase/powermem/sdk/storage/factory/GraphStoreFactory.java b/src/main/java/com/oceanbase/powermem/sdk/storage/factory/GraphStoreFactory.java
index 90b9933..069932c 100644
--- a/src/main/java/com/oceanbase/powermem/sdk/storage/factory/GraphStoreFactory.java
+++ b/src/main/java/com/oceanbase/powermem/sdk/storage/factory/GraphStoreFactory.java
@@ -8,5 +8,53 @@
  */
 public final class GraphStoreFactory {
     private GraphStoreFactory() {}
+
+    public static com.oceanbase.powermem.sdk.storage.base.GraphStore fromConfig(
+            com.oceanbase.powermem.sdk.config.GraphStoreConfig config) {
+        return fromConfig(config, null, null);
+    }
+
+    public static com.oceanbase.powermem.sdk.storage.base.GraphStore fromConfig(
+            com.oceanbase.powermem.sdk.config.GraphStoreConfig config,
+            com.oceanbase.powermem.sdk.integrations.embeddings.Embedder embedder,
+            com.oceanbase.powermem.sdk.integrations.llm.LLM llm) {
+        if (config == null || !config.isEnabled()) {
+            return null;
+        }
+        String provider = config.getProvider();
+        if (provider == null || provider.isBlank()) {
+            provider = "oceanbase";
+        }
+        String p = provider.trim().toLowerCase();
+
+        // Offline-friendly: in-memory implementation for tests and local experiments.
+        if ("memory".equals(p) || "inmemory".equals(p) || "in-memory".equals(p) || "mock".equals(p)) {
+            return new com.oceanbase.powermem.sdk.storage.memory.InMemoryGraphStore();
+        }
+
+        // graph_store.llm/embedder override: if configured, create dedicated instances.
+ com.oceanbase.powermem.sdk.integrations.llm.LLM effLlm = llm; + try { + com.oceanbase.powermem.sdk.config.LlmConfig gl = config.getLlm(); + if (gl != null && gl.getProvider() != null && !gl.getProvider().isBlank()) { + effLlm = com.oceanbase.powermem.sdk.integrations.llm.LLMFactory.fromConfig(gl); + } + } catch (Exception ignored) {} + + com.oceanbase.powermem.sdk.integrations.embeddings.Embedder effEmb = embedder; + try { + com.oceanbase.powermem.sdk.config.EmbedderConfig ge = config.getEmbedder(); + if (ge != null && ge.getProvider() != null && !ge.getProvider().isBlank()) { + effEmb = com.oceanbase.powermem.sdk.integrations.embeddings.EmbedderFactory.fromConfig(ge); + } + } catch (Exception ignored) {} + + if ("oceanbase".equals(p)) { + return new com.oceanbase.powermem.sdk.storage.oceanbase.OceanBaseGraphStore(config, effEmb, effLlm); + } + + // Fallback: treat unknown as oceanbase (but likely not implemented). + return new com.oceanbase.powermem.sdk.storage.oceanbase.OceanBaseGraphStore(config, effEmb, effLlm); + } } diff --git a/src/main/java/com/oceanbase/powermem/sdk/storage/memory/InMemoryGraphStore.java b/src/main/java/com/oceanbase/powermem/sdk/storage/memory/InMemoryGraphStore.java new file mode 100644 index 0000000..7159b33 --- /dev/null +++ b/src/main/java/com/oceanbase/powermem/sdk/storage/memory/InMemoryGraphStore.java @@ -0,0 +1,233 @@ +package com.oceanbase.powermem.sdk.storage.memory; + +import com.oceanbase.powermem.sdk.storage.base.GraphStore; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * In-memory graph store for offline usage and tests. + * + *

This implementation is NOT intended for production. It exists to: + * - provide Python-like return shapes for {@code relations} + * - enable unit/integration tests without a real graph database

+ */ +public class InMemoryGraphStore implements GraphStore { + + private static final class Triple { + final String source; + final String relationship; + final String destination; + + Triple(String source, String relationship, String destination) { + this.source = source; + this.relationship = relationship; + this.destination = destination; + } + } + + // scopeKey -> triples + private final Map> store = new ConcurrentHashMap<>(); + + @Override + public Map add(String data, Map filters) { + if (data == null || data.isBlank()) { + return Collections.emptyMap(); + } + String scopeKey = scopeKey(filters); + + List newTriples = extractTriples(data); + if (newTriples.isEmpty()) { + Map out = new HashMap<>(); + out.put("deleted_entities", Collections.emptyList()); + out.put("added_entities", Collections.emptyList()); + return out; + } + + store.compute(scopeKey, (k, existing) -> { + List next = existing == null ? new ArrayList<>() : new ArrayList<>(existing); + next.addAll(newTriples); + return next; + }); + + // Python parity: {"deleted_entities":[...], "added_entities":[{"source":..,"relationship":..,"target":..}]} + List> added = new ArrayList<>(); + for (Triple t : newTriples) { + Map m = new HashMap<>(); + m.put("source", t.source); + m.put("relationship", t.relationship); + m.put("target", t.destination); + added.add(m); + } + Map out = new HashMap<>(); + out.put("deleted_entities", Collections.emptyList()); + out.put("added_entities", added); + return out; + } + + @Override + public List> search(String query, Map filters, int limit) { + if (query == null) { + query = ""; + } + int top = limit > 0 ? 
limit : 100; + String q = query.trim().toLowerCase(Locale.ROOT); + String scopeKey = scopeKey(filters); + + List triples = store.getOrDefault(scopeKey, Collections.emptyList()); + if (triples.isEmpty()) { + return Collections.emptyList(); + } + + List> out = new ArrayList<>(); + for (Triple t : triples) { + if (!q.isEmpty()) { + String combined = (t.source + " " + t.relationship + " " + t.destination).toLowerCase(Locale.ROOT); + if (!combined.contains(q)) { + continue; + } + } + Map m = new HashMap<>(); + m.put("source", t.source); + m.put("relationship", t.relationship); + m.put("destination", t.destination); + out.add(m); + if (out.size() >= top) { + break; + } + } + return out; + } + + @Override + public void deleteAll(Map filters) { + store.remove(scopeKey(filters)); + } + + @Override + public List> getAll(Map filters, int limit) { + return search("", filters, limit); + } + + @Override + public void reset() { + store.clear(); + } + + private static String scopeKey(Map filters) { + String userId = valueAsString(filters == null ? null : filters.get("user_id")); + String agentId = valueAsString(filters == null ? null : filters.get("agent_id")); + String runId = valueAsString(filters == null ? null : filters.get("run_id")); + + // Python parity: graph requires user_id; when absent, Memory uses fallback "user". + if (userId == null || userId.isBlank()) { + userId = "user"; + } + if (agentId == null) agentId = ""; + if (runId == null) runId = ""; + return userId + "|" + agentId + "|" + runId; + } + + private static String valueAsString(Object v) { + if (v == null) return null; + String s = String.valueOf(v); + return s; + } + + /** + * Extremely small heuristic triple extractor. + * + *

Goal: enable tests and provide a reasonable default behavior offline. + * For production, use a real graph store (OceanBaseGraphStore) with LLM-based extraction.

+ */ + private static List extractTriples(String data) { + List out = new ArrayList<>(); + String[] lines = data.split("[\\n\\r]+"); + for (String raw : lines) { + if (raw == null) continue; + String s = raw.trim(); + if (s.isEmpty()) continue; + + // Simple Chinese patterns: "用户喜欢X" / "我喜欢X" / "用户偏好X" + Triple t = parseChinesePreference(s); + if (t != null) { + out.add(t); + continue; + } + + // English fallback: "User likes X" + Triple e = parseEnglishPreference(s); + if (e != null) { + out.add(e); + } + } + return out; + } + + private static Triple parseChinesePreference(String s) { + String subject = null; + String rest = null; + + // normalize punctuation + String x = s.replace(":", ":").replace(",", ",").replace("。", ".").trim(); + + if (x.startsWith("用户")) { + subject = "用户"; + rest = x.substring(2); + } else if (x.startsWith("我")) { + subject = "我"; + rest = x.substring(1); + } else { + // unknown subject, skip + return null; + } + + if (rest == null) return null; + rest = rest.trim(); + + String rel; + int idx; + if ((idx = rest.indexOf("喜欢")) >= 0) { + rel = "喜欢"; + } else if ((idx = rest.indexOf("偏好")) >= 0) { + rel = "偏好"; + } else if ((idx = rest.indexOf("讨厌")) >= 0) { + rel = "讨厌"; + } else { + return null; + } + + String obj = rest.substring(idx + rel.length()).trim(); + if (obj.isEmpty()) { + return null; + } + return new Triple(subject, rel, obj); + } + + private static Triple parseEnglishPreference(String s) { + String x = s.trim(); + String lower = x.toLowerCase(Locale.ROOT); + if (!lower.startsWith("user ")) { + return null; + } + // e.g. 
"User likes coffee" + int idx = lower.indexOf(" likes "); + String rel = "likes"; + if (idx < 0) { + idx = lower.indexOf(" prefers "); + rel = "prefers"; + } + if (idx < 0) { + return null; + } + String obj = x.substring(idx + (" " + rel + " ").length()).trim(); + if (obj.isEmpty()) return null; + return new Triple("User", rel, obj); + } +} + diff --git a/src/main/java/com/oceanbase/powermem/sdk/storage/oceanbase/OceanBaseConstants.java b/src/main/java/com/oceanbase/powermem/sdk/storage/oceanbase/OceanBaseConstants.java index c0dc60b..90dd753 100644 --- a/src/main/java/com/oceanbase/powermem/sdk/storage/oceanbase/OceanBaseConstants.java +++ b/src/main/java/com/oceanbase/powermem/sdk/storage/oceanbase/OceanBaseConstants.java @@ -7,5 +7,20 @@ */ public final class OceanBaseConstants { private OceanBaseConstants() {} + + // Graph storage table names + public static final String TABLE_GRAPH_ENTITIES = "graph_entities"; + public static final String TABLE_GRAPH_RELATIONSHIPS = "graph_relationships"; + + // Vector config defaults + public static final String DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE = "l2"; + public static final String DEFAULT_VIDX_NAME = "vidx"; + public static final String DEFAULT_INDEX_TYPE = "HNSW"; + + // Graph search defaults + public static final double DEFAULT_SIMILARITY_THRESHOLD = 0.7; + public static final int DEFAULT_SEARCH_LIMIT = 100; + public static final int DEFAULT_BM25_TOP_N = 15; + public static final int DEFAULT_MAX_HOPS = 3; } diff --git a/src/main/java/com/oceanbase/powermem/sdk/storage/oceanbase/OceanBaseGraphStore.java b/src/main/java/com/oceanbase/powermem/sdk/storage/oceanbase/OceanBaseGraphStore.java index 9b02d1e..ad663e1 100644 --- a/src/main/java/com/oceanbase/powermem/sdk/storage/oceanbase/OceanBaseGraphStore.java +++ b/src/main/java/com/oceanbase/powermem/sdk/storage/oceanbase/OceanBaseGraphStore.java @@ -1,15 +1,1168 @@ package com.oceanbase.powermem.sdk.storage.oceanbase; +import 
com.oceanbase.powermem.sdk.exception.ApiException; +import com.oceanbase.powermem.sdk.integrations.embeddings.Embedder; +import com.oceanbase.powermem.sdk.integrations.llm.LLM; +import com.oceanbase.powermem.sdk.integrations.llm.LlmResponse; +import com.oceanbase.powermem.sdk.json.JacksonJsonCodec; +import com.oceanbase.powermem.sdk.json.JsonCodec; +import com.oceanbase.powermem.sdk.model.Message; +import com.oceanbase.powermem.sdk.prompts.graph.GraphPrompts; +import com.oceanbase.powermem.sdk.prompts.graph.GraphToolsPrompts; import com.oceanbase.powermem.sdk.storage.base.GraphStore; +import com.oceanbase.powermem.sdk.util.Bm25; +import com.oceanbase.powermem.sdk.util.LlmJsonUtils; +import com.oceanbase.powermem.sdk.util.SnowflakeIdGenerator; +import com.oceanbase.powermem.sdk.util.TextTokenizer; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.Statement; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.logging.Level; +import java.util.logging.Logger; /** - * OceanBase graph store implementation (Java migration target). + * OceanBase graph store implementation (Python parity). * - *

Used to store entities/relations and support graph-based retrieval (multi-hop, relation queries).

+ *

Python reference: {@code src/powermem/storage/oceanbase/oceanbase_graph.py}

* - *

Python reference: {@code src/powermem/storage/oceanbase/oceanbase_graph.py} and - * {@code src/powermem/storage/oceanbase/oceanbase_graph.py}.

+ *

Implements: + * - entities + relationships tables (graph_entities / graph_relationships) + * - LLM tools based extraction (extract_entities + establish_relationships + delete_graph_memory) + * - ANN-style entity resolution via VECTOR distance (best-effort) with embedding_json brute-force fallback + * - BM25 rerank of candidate relationships + * - multi-hop expansion

*/ public class OceanBaseGraphStore implements GraphStore { + private static final Logger LOG = Logger.getLogger(OceanBaseGraphStore.class.getName()); + + private final com.oceanbase.powermem.sdk.config.GraphStoreConfig config; + private final Embedder embedder; + private final LLM llm; + private final JsonCodec json = new JacksonJsonCodec(); + private final SnowflakeIdGenerator idGenerator = SnowflakeIdGenerator.defaultGenerator(); + + private volatile boolean initialized; + private volatile boolean hasVectorColumn; + + public OceanBaseGraphStore(com.oceanbase.powermem.sdk.config.GraphStoreConfig config) { + this(config, null, null); + } + + public OceanBaseGraphStore(com.oceanbase.powermem.sdk.config.GraphStoreConfig config, Embedder embedder, LLM llm) { + this.config = config; + this.embedder = embedder; + this.llm = llm; + ensureInitialized(); + } + + @Override + public Map add(String data, Map filters) { + if (data == null || data.isBlank()) { + return Collections.emptyMap(); + } + ensureInitialized(); + Map scope = normalizeScope(filters); + + Map entityTypeMap = extractEntitiesWithTypes(data, scope); + List toAdd = extractRelations(data, entityTypeMap, scope); + if (toAdd.isEmpty()) { + return Map.of("deleted_entities", Collections.emptyList(), "added_entities", Collections.emptyList()); + } + + // Neighborhood search then delete decision (Python parity). 
+ List> neighborhood = searchGraphNeighborhood(entityTypeMap.keySet(), scope, safeSearchLimit()); + String existingMemories = formatMemoriesForPrompt(neighborhood); + List toDelete = decideDeletions(existingMemories, data, scope); + + List> deleted = deleteRelations(toDelete, scope); + List> added = addEntitiesAndRelationships(toAdd, entityTypeMap, scope); + + Map out = new HashMap<>(); + out.put("deleted_entities", deleted); + out.put("added_entities", added); + return out; + } + + @Override + public List> search(String query, Map filters, int limit) { + ensureInitialized(); + Map scope = normalizeScope(filters); + int top = limit > 0 ? limit : safeBm25TopN(); + + Map seed = extractEntitiesWithTypes(query, scope); + List> candidates = multiHopNeighborhood(seed.keySet(), scope, safeMaxHops(), safeSearchLimit()); + if (candidates.isEmpty()) { + return Collections.emptyList(); + } + + List> reranked = bm25Rerank(query, candidates, top); + // Ensure Python search shape: destination key. + List> out = new ArrayList<>(); + for (Map r : reranked) { + if (r == null) continue; + Map m = new HashMap<>(); + m.put("source", r.get("source")); + m.put("relationship", r.get("relationship")); + Object dst = r.get("destination"); + if (dst == null) dst = r.get("target"); + m.put("destination", dst); + out.add(m); + } + return out; + } + + @Override + public void deleteAll(Map filters) { + ensureInitialized(); + Map scope = normalizeScope(filters); + try (Connection c = openConnection()) { + Set entityIds = new HashSet<>(); + + StringBuilder where = new StringBuilder(" WHERE user_id = ?"); + List args = new ArrayList<>(); + args.add(String.valueOf(scope.get("user_id"))); + if (asNullableString(scope.get("agent_id")) != null) { + where.append(" AND agent_id = ?"); + args.add(asNullableString(scope.get("agent_id"))); + } + if (asNullableString(scope.get("run_id")) != null) { + where.append(" AND run_id = ?"); + args.add(asNullableString(scope.get("run_id"))); + } + + String q = 
"SELECT source_entity_id, destination_entity_id FROM " + relationshipsTable() + where; + try (PreparedStatement ps = c.prepareStatement(q)) { + for (int i = 0; i < args.size(); i++) ps.setString(i + 1, String.valueOf(args.get(i))); + try (ResultSet rs = ps.executeQuery()) { + while (rs.next()) { + entityIds.add(rs.getLong(1)); + entityIds.add(rs.getLong(2)); + } + } + } + + try (PreparedStatement ps = c.prepareStatement("DELETE FROM " + relationshipsTable() + where)) { + for (int i = 0; i < args.size(); i++) ps.setString(i + 1, String.valueOf(args.get(i))); + ps.executeUpdate(); + } + + if (!entityIds.isEmpty()) { + String in = placeholders(entityIds.size()); + String delEnt = "DELETE FROM " + entitiesTable() + " WHERE user_id = ?" + + (asNullableString(scope.get("agent_id")) != null ? " AND agent_id = ?" : "") + + (asNullableString(scope.get("run_id")) != null ? " AND run_id = ?" : "") + + " AND id IN (" + in + ")"; + try (PreparedStatement ps = c.prepareStatement(delEnt)) { + int p = 1; + ps.setString(p++, String.valueOf(scope.get("user_id"))); + if (asNullableString(scope.get("agent_id")) != null) ps.setString(p++, asNullableString(scope.get("agent_id"))); + if (asNullableString(scope.get("run_id")) != null) ps.setString(p++, asNullableString(scope.get("run_id"))); + for (Long id : entityIds) ps.setLong(p++, id); + ps.executeUpdate(); + } + } + } catch (Exception ex) { + LOG.log(Level.WARNING, "OceanBaseGraphStore.deleteAll failed: " + ex.getMessage(), ex); + } + } + + @Override + public List> getAll(Map filters, int limit) { + ensureInitialized(); + Map scope = normalizeScope(filters); + int top = limit > 0 ? limit : safeSearchLimit(); + + String sql = "SELECT se.name AS source, r.relationship_type AS relationship, de.name AS target " + + "FROM " + relationshipsTable() + " r " + + "JOIN " + entitiesTable() + " se ON se.id = r.source_entity_id " + + "JOIN " + entitiesTable() + " de ON de.id = r.destination_entity_id " + + "WHERE r.user_id = ?" 
+ + (asNullableString(scope.get("agent_id")) != null ? " AND r.agent_id = ?" : "") + + (asNullableString(scope.get("run_id")) != null ? " AND r.run_id = ?" : "") + + " ORDER BY r.updated_at DESC LIMIT " + top; + + List> out = new ArrayList<>(); + try (Connection c = openConnection(); PreparedStatement ps = c.prepareStatement(sql)) { + int p = 1; + ps.setString(p++, String.valueOf(scope.get("user_id"))); + if (asNullableString(scope.get("agent_id")) != null) ps.setString(p++, asNullableString(scope.get("agent_id"))); + if (asNullableString(scope.get("run_id")) != null) ps.setString(p++, asNullableString(scope.get("run_id"))); + try (ResultSet rs = ps.executeQuery()) { + while (rs.next()) { + Map m = new HashMap<>(); + m.put("source", rs.getString("source")); + m.put("relationship", rs.getString("relationship")); + m.put("target", rs.getString("target")); + out.add(m); + } + } + } catch (Exception ex) { + throw new ApiException("OceanBaseGraphStore.getAll failed: " + ex.getMessage(), ex); + } + return out; + } + + @Override + public void reset() { + ensureInitialized(); + try (Connection c = openConnection(); Statement st = c.createStatement()) { + st.execute("TRUNCATE TABLE " + relationshipsTable()); + st.execute("TRUNCATE TABLE " + entitiesTable()); + } catch (Exception ex) { + LOG.log(Level.WARNING, "OceanBaseGraphStore.reset failed: " + ex.getMessage(), ex); + } + } + + // ---------------- init/schema ---------------- + + private synchronized void ensureInitialized() { + if (initialized) return; + try (Connection c = openConnection(); Statement st = c.createStatement()) { + String ent = entitiesTable(); + String rel = relationshipsTable(); + + st.execute("CREATE TABLE IF NOT EXISTS " + ent + " (\n" + + " id BIGINT PRIMARY KEY,\n" + + " user_id VARCHAR(128) NOT NULL,\n" + + " agent_id VARCHAR(128) NULL,\n" + + " run_id VARCHAR(128) NULL,\n" + + " name VARCHAR(512) NOT NULL,\n" + + " entity_type VARCHAR(128) NULL,\n" + + " embedding_json LONGTEXT NULL,\n" + + " 
created_at TIMESTAMP NULL,\n" + + " updated_at TIMESTAMP NULL\n" + + ")"); + + st.execute("CREATE TABLE IF NOT EXISTS " + rel + " (\n" + + " id BIGINT PRIMARY KEY,\n" + + " user_id VARCHAR(128) NOT NULL,\n" + + " agent_id VARCHAR(128) NULL,\n" + + " run_id VARCHAR(128) NULL,\n" + + " source_entity_id BIGINT NOT NULL,\n" + + " relationship_type VARCHAR(255) NOT NULL,\n" + + " destination_entity_id BIGINT NOT NULL,\n" + + " created_at TIMESTAMP NULL,\n" + + " updated_at TIMESTAMP NULL\n" + + ")"); + + try { + st.execute("CREATE INDEX IF NOT EXISTS idx_" + ent + "_scope ON " + ent + " (user_id, agent_id, run_id)"); + st.execute("CREATE INDEX IF NOT EXISTS idx_" + rel + "_scope ON " + rel + " (user_id, agent_id, run_id)"); + st.execute("CREATE INDEX IF NOT EXISTS idx_" + rel + "_src ON " + rel + " (source_entity_id)"); + st.execute("CREATE INDEX IF NOT EXISTS idx_" + rel + "_dst ON " + rel + " (destination_entity_id)"); + } catch (Exception ignored) {} + + int dims = config == null ? 0 : config.getEmbeddingModelDims(); + if (dims > 0) { + try { + try (PreparedStatement ps = c.prepareStatement("ALTER TABLE " + ent + " ADD COLUMN embedding VECTOR(" + dims + ")")) { + ps.execute(); + } + } catch (Exception ignored) {} + try (PreparedStatement ps = c.prepareStatement( + "SELECT 1 FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME=? AND COLUMN_NAME='embedding'")) { + ps.setString(1, ent); + try (ResultSet rs = ps.executeQuery()) { + hasVectorColumn = rs.next(); + } + } catch (Exception ignored) { + hasVectorColumn = false; + } + + if (hasVectorColumn) { + Integer existingDims = tryGetVectorDims(c, ent, "embedding"); + if (existingDims != null && existingDims > 0 && existingDims != dims) { + throw new RuntimeException("Vector dimension mismatch: existing VECTOR(" + existingDims + "), requested dims=" + dims); + } + String indexType = config.getIndexType() == null ? 
OceanBaseConstants.DEFAULT_INDEX_TYPE : config.getIndexType().trim().toUpperCase(Locale.ROOT); + String metricType = config.getMetricType() == null ? OceanBaseConstants.DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE : config.getMetricType().trim().toLowerCase(Locale.ROOT); + String vidx = config.getVectorIndexName() == null || config.getVectorIndexName().isBlank() + ? OceanBaseConstants.DEFAULT_VIDX_NAME + : config.getVectorIndexName().trim(); + String ddl = "CREATE VECTOR INDEX " + vidx + " ON " + ent + " (embedding) USING " + indexType + + " WITH (metric_type='" + metricType + "')"; + try (PreparedStatement ps = c.prepareStatement(ddl)) { + ps.execute(); + } catch (Exception ignored) {} + } + } + } catch (Exception ex) { + throw new ApiException("Failed to initialize OceanBaseGraphStore: " + ex.getMessage(), ex); + } + initialized = true; + } + + private static Integer tryGetVectorDims(Connection c, String table, String col) { + try (PreparedStatement ps = c.prepareStatement( + "SELECT COLUMN_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME=? AND COLUMN_NAME=?")) { + ps.setString(1, table); + ps.setString(2, col); + try (ResultSet rs = ps.executeQuery()) { + if (!rs.next()) return null; + String t = rs.getString(1); + if (t == null) return null; + String s = t.trim().toLowerCase(Locale.ROOT); + int l = s.indexOf('('); + int r = s.indexOf(')', l + 1); + if (l < 0 || r <= l) return null; + return Integer.parseInt(s.substring(l + 1, r).trim()); + } + } catch (Exception ignored) { + return null; + } + } + + // ---------------- extraction (LLM tools + fallbacks) ---------------- + + private Map extractEntitiesWithTypes(String text, Map scope) { + if (llm != null) { + try { + String system = "You are a smart assistant who understands entities and their types in a given text. " + + "If user message contains self reference such as 'I', 'me', 'my' etc. " + + "then use " + String.valueOf(scope.get("user_id")) + " as the source entity. 
" + + "Extract all the entities from the text. " + + "***DO NOT*** answer the question itself if the given text is a question."; + List msgs = List.of( + new Message("system", system), + new Message("user", text == null ? "" : text) + ); + LlmResponse resp = llm.generateResponseWithTools( + msgs, + null, + List.of(GraphToolsPrompts.extractEntitiesTool(false)), + GraphToolsPrompts.toolChoiceFunction("extract_entities") + ); + Map parsed = parseExtractEntities(resp); + if (!parsed.isEmpty()) return normalizeEntityTypeMap(parsed); + } catch (Exception ex) { + LOG.log(Level.WARNING, "Graph extract_entities failed; fallback. cause=" + ex.getMessage(), ex); + } + } + // heuristic fallback + Map out = new HashMap<>(); + for (RelationTriple t : heuristicTriples(text)) { + out.put(t.source, "entity"); + out.put(t.destination, "entity"); + } + // If still empty (e.g., a pure query), fall back to token-based entities so search can work offline. + if (out.isEmpty() && text != null && !text.isBlank()) { + for (String tok : TextTokenizer.tokenize(text)) { + if (tok == null || tok.isBlank()) continue; + // skip ultra-short noise tokens + if (tok.length() <= 1) continue; + out.put(normalizeName(tok), "entity"); + } + } + return normalizeEntityTypeMap(out); + } + + private List extractRelations(String text, Map entityTypeMap, Map scope) { + if (llm != null && text != null && !text.isBlank() && entityTypeMap != null && !entityTypeMap.isEmpty()) { + try { + String userIdentity = "user_id: " + scope.get("user_id") + + (asNullableString(scope.get("agent_id")) != null ? ", agent_id: " + scope.get("agent_id") : "") + + (asNullableString(scope.get("run_id")) != null ? ", run_id: " + scope.get("run_id") : ""); + String custom = config == null ? null : (config.getCustomExtractRelationsPrompt() != null ? 
config.getCustomExtractRelationsPrompt() : config.getCustomPrompt()); + String system = GraphPrompts.extractRelationsSystemPrompt(userIdentity, custom); + String userMsg = "List of entities: " + entityTypeMap.keySet() + ". \n\nText: " + text; + List msgs = List.of( + new Message("system", system), + new Message("user", userMsg) + ); + LlmResponse resp = llm.generateResponseWithTools( + msgs, + null, + List.of(GraphToolsPrompts.establishRelationsTool(false)), + null + ); + List parsed = parseEstablishRelations(resp); + if (!parsed.isEmpty()) return normalizeRelationTriples(parsed); + } catch (Exception ex) { + LOG.log(Level.WARNING, "Graph establish_relationships failed; fallback. cause=" + ex.getMessage(), ex); + } + } + return normalizeRelationTriples(heuristicTriples(text)); + } + + private List decideDeletions(String existingMemories, String newText, Map scope) { + if (llm == null) return Collections.emptyList(); + try { + String custom = config == null ? null : config.getCustomDeleteRelationsPrompt(); + String system = GraphPrompts.deleteRelationsSystemPrompt(String.valueOf(scope.get("user_id")), custom); + String user = GraphPrompts.deleteRelationsUserPrompt(existingMemories, newText); + List msgs = List.of(new Message("system", system), new Message("user", user)); + LlmResponse resp = llm.generateResponseWithTools( + msgs, + null, + List.of(GraphToolsPrompts.deleteGraphMemoryTool(false)), + null + ); + return normalizeRelationTriples(parseDeleteGraphMemory(resp)); + } catch (Exception ex) { + LOG.log(Level.WARNING, "Graph delete decision failed; ignore deletes. 
cause=" + ex.getMessage(), ex); + return Collections.emptyList(); + } + } + + // ---------------- entity/relationship operations ---------------- + + private List> deleteRelations(List toDelete, Map scope) { + if (toDelete == null || toDelete.isEmpty()) return Collections.emptyList(); + List> out = new ArrayList<>(); + for (RelationTriple t : toDelete) { + int n = deleteRelationByNames(t, scope); + out.add(Map.of("deleted_count", n)); + } + return out; + } + + private int deleteRelationByNames(RelationTriple t, Map scope) { + if (t == null) return 0; + try (Connection c = openConnection()) { + List srcIds = findEntityIdsByName(c, t.source, scope); + List dstIds = findEntityIdsByName(c, t.destination, scope); + if (srcIds.isEmpty() || dstIds.isEmpty()) return 0; + + String relTable = relationshipsTable(); + StringBuilder sql = new StringBuilder("DELETE FROM ").append(relTable) + .append(" WHERE relationship_type = ? AND user_id = ?"); + List args = new ArrayList<>(); + args.add(t.relationship); + args.add(String.valueOf(scope.get("user_id"))); + if (asNullableString(scope.get("agent_id")) != null) { + sql.append(" AND agent_id = ?"); + args.add(asNullableString(scope.get("agent_id"))); + } + if (asNullableString(scope.get("run_id")) != null) { + sql.append(" AND run_id = ?"); + args.add(asNullableString(scope.get("run_id"))); + } + + sql.append(" AND source_entity_id IN (").append(placeholders(srcIds.size())).append(")"); + sql.append(" AND destination_entity_id IN (").append(placeholders(dstIds.size())).append(")"); + + try (PreparedStatement ps = c.prepareStatement(sql.toString())) { + int p = 1; + ps.setString(p++, String.valueOf(args.get(0))); + ps.setString(p++, String.valueOf(args.get(1))); + for (int i = 2; i < args.size(); i++) ps.setString(p++, String.valueOf(args.get(i))); + for (Long id : srcIds) ps.setLong(p++, id); + for (Long id : dstIds) ps.setLong(p++, id); + return ps.executeUpdate(); + } + } catch (Exception ex) { + LOG.log(Level.WARNING, 
"deleteRelation failed: " + ex.getMessage(), ex); + return 0; + } + } + + private List> addEntitiesAndRelationships(List triples, Map entityTypeMap, Map scope) { + if (triples == null || triples.isEmpty()) return Collections.emptyList(); + if (embedder == null) throw new ApiException("OceanBaseGraphStore requires an embedder"); + + List> out = new ArrayList<>(); + for (RelationTriple t : triples) { + long srcId = getOrCreateEntity(t.source, entityTypeMap.getOrDefault(t.source, "entity"), scope); + long dstId = getOrCreateEntity(t.destination, entityTypeMap.getOrDefault(t.destination, "entity"), scope); + Map rel = createRelationshipIfNotExists(srcId, dstId, t.relationship, scope); + if (rel != null && !rel.isEmpty()) out.add(rel); + } + return out; + } + + private long getOrCreateEntity(String name, String type, Map scope) { + String n = normalizeName(name); + float[] vec = embedder.embed(n, "search"); + EntityRow hit = searchSimilarEntity(vec, scope, similarityThreshold(), 1); + if (hit != null) return hit.id; + return createEntity(n, type, vec, scope); + } + + private long createEntity(String name, String type, float[] embedding, Map scope) { + long id = parseLongOrZero(idGenerator.nextId()); + Instant now = Instant.now(); + String embeddingJson = json.toJson(embedding == null ? 
new float[0] : embedding); + + List cols = new ArrayList<>(); + cols.add("id"); + cols.add("user_id"); + cols.add("agent_id"); + cols.add("run_id"); + cols.add("name"); + cols.add("entity_type"); + cols.add("embedding_json"); + cols.add("created_at"); + cols.add("updated_at"); + if (hasVectorColumn) cols.add("embedding"); + + StringBuilder sql = new StringBuilder("INSERT INTO ").append(entitiesTable()).append(" ("); + for (int i = 0; i < cols.size(); i++) { + if (i > 0) sql.append(", "); + sql.append(cols.get(i)); + } + sql.append(") VALUES ("); + for (int i = 0; i < cols.size(); i++) { + if (i > 0) sql.append(", "); + sql.append("?"); + } + sql.append(")"); + + try (Connection c = openConnection(); PreparedStatement ps = c.prepareStatement(sql.toString())) { + int p = 1; + ps.setLong(p++, id); + ps.setString(p++, String.valueOf(scope.get("user_id"))); + ps.setString(p++, asNullableString(scope.get("agent_id"))); + ps.setString(p++, asNullableString(scope.get("run_id"))); + ps.setString(p++, name); + ps.setString(p++, type); + ps.setString(p++, embeddingJson); + ps.setTimestamp(p++, java.sql.Timestamp.from(now)); + ps.setTimestamp(p++, java.sql.Timestamp.from(now)); + if (hasVectorColumn) ps.setString(p, embeddingJson); + ps.executeUpdate(); + return id; + } catch (Exception ex) { + throw new ApiException("createEntity failed: " + ex.getMessage(), ex); + } + } + + private Map createRelationshipIfNotExists(long srcId, long dstId, String relType, Map scope) { + long id = parseLongOrZero(idGenerator.nextId()); + Instant now = Instant.now(); + + try (Connection c = openConnection()) { + String exists = "SELECT 1 FROM " + relationshipsTable() + " WHERE user_id=?" + + (asNullableString(scope.get("agent_id")) != null ? " AND agent_id=?" : "") + + (asNullableString(scope.get("run_id")) != null ? " AND run_id=?" : "") + + " AND source_entity_id=? AND destination_entity_id=? AND relationship_type=? 
LIMIT 1"; + try (PreparedStatement ps = c.prepareStatement(exists)) { + int p = 1; + ps.setString(p++, String.valueOf(scope.get("user_id"))); + if (asNullableString(scope.get("agent_id")) != null) ps.setString(p++, asNullableString(scope.get("agent_id"))); + if (asNullableString(scope.get("run_id")) != null) ps.setString(p++, asNullableString(scope.get("run_id"))); + ps.setLong(p++, srcId); + ps.setLong(p++, dstId); + ps.setString(p, relType); + try (ResultSet rs = ps.executeQuery()) { + if (rs.next()) return Collections.emptyMap(); + } + } + + String insert = "INSERT INTO " + relationshipsTable() + + " (id, user_id, agent_id, run_id, source_entity_id, relationship_type, destination_entity_id, created_at, updated_at) " + + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + try (PreparedStatement ps = c.prepareStatement(insert)) { + int p = 1; + ps.setLong(p++, id); + ps.setString(p++, String.valueOf(scope.get("user_id"))); + ps.setString(p++, asNullableString(scope.get("agent_id"))); + ps.setString(p++, asNullableString(scope.get("run_id"))); + ps.setLong(p++, srcId); + ps.setString(p++, relType); + ps.setLong(p++, dstId); + ps.setTimestamp(p++, java.sql.Timestamp.from(now)); + ps.setTimestamp(p, java.sql.Timestamp.from(now)); + ps.executeUpdate(); + } + + Map out = new HashMap<>(); + out.put("source", getEntityNameById(c, srcId)); + out.put("relationship", relType); + out.put("target", getEntityNameById(c, dstId)); + return out; + } catch (Exception ex) { + throw new ApiException("createRelationship failed: " + ex.getMessage(), ex); + } + } + + private String getEntityNameById(Connection c, long id) throws Exception { + try (PreparedStatement ps = c.prepareStatement("SELECT name FROM " + entitiesTable() + " WHERE id=? 
LIMIT 1")) { + ps.setLong(1, id); + try (ResultSet rs = ps.executeQuery()) { + if (rs.next()) return rs.getString(1); + } + } + return null; + } + + private List findEntityIdsByName(Connection c, String name, Map scope) throws Exception { + String sql = "SELECT id FROM " + entitiesTable() + " WHERE user_id=?" + + (asNullableString(scope.get("agent_id")) != null ? " AND agent_id=?" : "") + + (asNullableString(scope.get("run_id")) != null ? " AND run_id=?" : "") + + " AND name=? LIMIT " + safeSearchLimit(); + List out = new ArrayList<>(); + try (PreparedStatement ps = c.prepareStatement(sql)) { + int p = 1; + ps.setString(p++, String.valueOf(scope.get("user_id"))); + if (asNullableString(scope.get("agent_id")) != null) ps.setString(p++, asNullableString(scope.get("agent_id"))); + if (asNullableString(scope.get("run_id")) != null) ps.setString(p++, asNullableString(scope.get("run_id"))); + ps.setString(p, name); + try (ResultSet rs = ps.executeQuery()) { + while (rs.next()) out.add(rs.getLong(1)); + } + } + return out; + } + + // ---------------- ANN lookup ---------------- + + private EntityRow searchSimilarEntity(float[] queryEmbedding, Map scope, double threshold, int limit) { + if (queryEmbedding == null || queryEmbedding.length == 0) return null; + int k = Math.max(1, limit); + + if (hasVectorColumn) { + String metricType = config.getMetricType() == null ? OceanBaseConstants.DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE : config.getMetricType().trim().toLowerCase(Locale.ROOT); + String distFunc = "l2_distance"; + if ("cosine".equals(metricType)) distFunc = "cosine_distance"; + if ("inner_product".equals(metricType) || "ip".equals(metricType)) distFunc = "inner_product"; + + String sql = "SELECT id, name, entity_type, " + distFunc + "(embedding, ?) AS d FROM " + entitiesTable() + + " WHERE user_id=?" + + (asNullableString(scope.get("agent_id")) != null ? " AND agent_id=?" : "") + + (asNullableString(scope.get("run_id")) != null ? " AND run_id=?" 
: "") + + " ORDER BY d ASC LIMIT " + k; + try (Connection c = openConnection(); PreparedStatement ps = c.prepareStatement(sql)) { + int p = 1; + ps.setString(p++, json.toJson(queryEmbedding)); + ps.setString(p++, String.valueOf(scope.get("user_id"))); + if (asNullableString(scope.get("agent_id")) != null) ps.setString(p++, asNullableString(scope.get("agent_id"))); + if (asNullableString(scope.get("run_id")) != null) ps.setString(p++, asNullableString(scope.get("run_id"))); + try (ResultSet rs = ps.executeQuery()) { + if (rs.next()) { + double d = rs.getDouble("d"); + if (threshold > 0 && d >= threshold) return null; + return new EntityRow(rs.getLong("id"), rs.getString("name"), rs.getString("entity_type"), d); + } + } + } catch (Exception ex) { + LOG.log(Level.WARNING, "Graph SQL distance failed; fallback to brute-force. cause=" + ex.getMessage(), ex); + } + } + + // Brute-force fallback via embedding_json + try (Connection c = openConnection()) { + String sql = "SELECT id, name, entity_type, embedding_json FROM " + entitiesTable() + + " WHERE user_id=?" + + (asNullableString(scope.get("agent_id")) != null ? " AND agent_id=?" : "") + + (asNullableString(scope.get("run_id")) != null ? " AND run_id=?" 
: "") + + " LIMIT " + safeSearchLimit(); + try (PreparedStatement ps = c.prepareStatement(sql)) { + int p = 1; + ps.setString(p++, String.valueOf(scope.get("user_id"))); + if (asNullableString(scope.get("agent_id")) != null) ps.setString(p++, asNullableString(scope.get("agent_id"))); + if (asNullableString(scope.get("run_id")) != null) ps.setString(p++, asNullableString(scope.get("run_id"))); + EntityRow best = null; + try (ResultSet rs = ps.executeQuery()) { + while (rs.next()) { + float[] vec = json.fromJson(rs.getString("embedding_json"), float[].class); + if (vec == null || vec.length == 0) continue; + double d = l2Distance(queryEmbedding, vec); + if (threshold > 0 && d >= threshold) continue; + if (best == null || d < best.distance) { + best = new EntityRow(rs.getLong("id"), rs.getString("name"), rs.getString("entity_type"), d); + } + } + } + return best; + } + } catch (Exception ex) { + LOG.log(Level.WARNING, "Graph brute-force entity search failed: " + ex.getMessage(), ex); + return null; + } + } + + private static double l2Distance(float[] a, float[] b) { + if (a == null || b == null) return Double.POSITIVE_INFINITY; + int n = Math.min(a.length, b.length); + double s = 0.0; + for (int i = 0; i < n; i++) { + double d = (double) a[i] - (double) b[i]; + s += d * d; + } + return Math.sqrt(s); + } + + // ---------------- neighborhood + BM25 + multi-hop ---------------- + + private List> searchGraphNeighborhood(Set seedNames, Map scope, int limit) { + if (seedNames == null || seedNames.isEmpty() || embedder == null) return Collections.emptyList(); + Set ids = new HashSet<>(); + for (String n : seedNames) { + if (n == null || n.isBlank()) continue; + float[] vec = embedder.embed(n, "search"); + EntityRow hit = searchSimilarEntity(vec, scope, similarityThreshold(), 1); + if (hit != null) ids.add(hit.id); + } + if (ids.isEmpty()) return Collections.emptyList(); + return fetchRelationsByEntityIds(ids, scope, limit); + } + + private List> multiHopNeighborhood(Set 
seedNames, Map scope, int maxHops, int limit) { + if (seedNames == null || seedNames.isEmpty() || embedder == null) return Collections.emptyList(); + int hops = Math.max(1, maxHops); + int cap = Math.max(1, limit); + + Set frontier = new HashSet<>(); + Set visitedNodes = new HashSet<>(); + for (String s : seedNames) { + if (s == null || s.isBlank()) continue; + float[] vec = embedder.embed(s, "search"); + EntityRow hit = searchSimilarEntity(vec, scope, similarityThreshold(), 1); + if (hit != null) { + frontier.add(hit.id); + visitedNodes.add(hit.id); + } + } + if (frontier.isEmpty()) return Collections.emptyList(); + + List> edges = new ArrayList<>(); + Set dedup = new HashSet<>(); + + for (int hop = 0; hop < hops && !frontier.isEmpty() && edges.size() < cap; hop++) { + List> batch = fetchRelationsByEntityIds(frontier, scope, cap - edges.size()); + Set next = new HashSet<>(); + for (Map e : batch) { + if (e == null) continue; + String key = String.valueOf(e.get("source")) + "|" + String.valueOf(e.get("relationship")) + "|" + String.valueOf(e.get("destination")); + if (dedup.add(key)) edges.add(e); + // Prefer id-based traversal for parity (avoid name->id roundtrip). 
+ Object sid = e.get("_src_id"); + Object did = e.get("_dst_id"); + if (sid instanceof Number) next.add(((Number) sid).longValue()); + if (did instanceof Number) next.add(((Number) did).longValue()); + } + next.removeAll(visitedNodes); + visitedNodes.addAll(next); + frontier = next; + } + return edges; + } + + private List> fetchRelationsByEntityIds(Set ids, Map scope, int limit) { + if (ids == null || ids.isEmpty()) return Collections.emptyList(); + int top = Math.max(1, limit); + String in = placeholders(ids.size()); + String sql = "SELECT r.source_entity_id AS src_id, r.destination_entity_id AS dst_id, " + + "se.name AS source, r.relationship_type AS relationship, de.name AS destination " + + "FROM " + relationshipsTable() + " r " + + "JOIN " + entitiesTable() + " se ON se.id = r.source_entity_id " + + "JOIN " + entitiesTable() + " de ON de.id = r.destination_entity_id " + + "WHERE r.user_id=?" + + (asNullableString(scope.get("agent_id")) != null ? " AND r.agent_id=?" : "") + + (asNullableString(scope.get("run_id")) != null ? " AND r.run_id=?" 
: "") + + " AND (r.source_entity_id IN (" + in + ") OR r.destination_entity_id IN (" + in + "))" + + " ORDER BY r.updated_at DESC LIMIT " + top; + + List> out = new ArrayList<>(); + try (Connection c = openConnection(); PreparedStatement ps = c.prepareStatement(sql)) { + int p = 1; + ps.setString(p++, String.valueOf(scope.get("user_id"))); + if (asNullableString(scope.get("agent_id")) != null) ps.setString(p++, asNullableString(scope.get("agent_id"))); + if (asNullableString(scope.get("run_id")) != null) ps.setString(p++, asNullableString(scope.get("run_id"))); + for (Long id : ids) ps.setLong(p++, id); + for (Long id : ids) ps.setLong(p++, id); + try (ResultSet rs = ps.executeQuery()) { + while (rs.next()) { + Map m = new HashMap<>(); + m.put("_src_id", rs.getLong("src_id")); + m.put("_dst_id", rs.getLong("dst_id")); + m.put("source", rs.getString("source")); + m.put("relationship", rs.getString("relationship")); + m.put("destination", rs.getString("destination")); + out.add(m); + } + } + } catch (Exception ex) { + LOG.log(Level.WARNING, "fetchRelationsByEntityIds failed: " + ex.getMessage(), ex); + } + return out; + } + + private List> bm25Rerank(String query, List> candidates, int topN) { + if (candidates == null || candidates.isEmpty()) return Collections.emptyList(); + int top = Math.max(1, topN); + + List> corpus = new ArrayList<>(); + for (Map r : candidates) { + String combined = String.valueOf(r.get("source")) + " " + String.valueOf(r.get("relationship")) + " " + String.valueOf(r.get("destination")); + corpus.add(TextTokenizer.tokenize(combined)); + } + Bm25 bm25 = new Bm25(corpus); + double[] scores = bm25.getScores(TextTokenizer.tokenize(query == null ? 
"" : query)); + + List idx = new ArrayList<>(); + for (int i = 0; i < candidates.size(); i++) idx.add(i); + idx.sort((a, b) -> Double.compare(scores[b], scores[a])); + + List> out = new ArrayList<>(); + for (int i : idx) { + out.add(candidates.get(i)); + if (out.size() >= top) break; + } + return out; + } + + private static String formatMemoriesForPrompt(List> rels) { + if (rels == null || rels.isEmpty()) return ""; + StringBuilder sb = new StringBuilder(); + for (Map r : rels) { + if (r == null) continue; + if (sb.length() > 0) sb.append("\n"); + sb.append(String.valueOf(r.get("source"))).append(" -- ") + .append(String.valueOf(r.get("relationship"))).append(" -- ") + .append(String.valueOf(r.get("destination"))); + } + return sb.toString(); + } + + // ---------------- tool parsing helpers ---------------- + + private Map parseExtractEntities(LlmResponse resp) { + Map out = new HashMap<>(); + for (Map tc : extractToolCalls(resp)) { + String name = toolCallName(tc); + if (!"extract_entities".equals(name)) continue; + Map args = toolCallArguments(tc); + Object ents = args.get("entities"); + if (ents instanceof List) { + for (Object e : (List) ents) { + if (!(e instanceof Map)) continue; + Map em = (Map) e; + Object en = em.get("entity"); + Object et = em.get("entity_type"); + if (en == null) continue; + out.put(normalizeName(String.valueOf(en)), et == null ? 
"entity" : normalizeName(String.valueOf(et))); + } + } + } + return out; + } + + private List parseEstablishRelations(LlmResponse resp) { + List out = new ArrayList<>(); + for (Map tc : extractToolCalls(resp)) { + String name = toolCallName(tc); + if (!"establish_relationships".equals(name) && !"establish_relations".equals(name)) continue; + Map args = toolCallArguments(tc); + Object ents = args.get("entities"); + if (ents instanceof List) { + for (Object e : (List) ents) { + if (!(e instanceof Map)) continue; + Map em = (Map) e; + String src = normalizeName(String.valueOf(em.get("source"))); + String rel = normalizeName(String.valueOf(em.get("relationship"))); + String dst = normalizeName(String.valueOf(em.get("destination"))); + if (!src.isBlank() && !rel.isBlank() && !dst.isBlank()) out.add(new RelationTriple(src, rel, dst)); + } + } + } + return out; + } + + private List parseDeleteGraphMemory(LlmResponse resp) { + List out = new ArrayList<>(); + for (Map tc : extractToolCalls(resp)) { + String name = toolCallName(tc); + if (!"delete_graph_memory".equals(name)) continue; + Map args = toolCallArguments(tc); + String src = normalizeName(String.valueOf(args.get("source"))); + String rel = normalizeName(String.valueOf(args.get("relationship"))); + String dst = normalizeName(String.valueOf(args.get("destination"))); + if (!src.isBlank() && !rel.isBlank() && !dst.isBlank()) out.add(new RelationTriple(src, rel, dst)); + } + return out; + } + + private List> extractToolCalls(LlmResponse resp) { + if (resp == null) return Collections.emptyList(); + if (resp.getToolCalls() != null && !resp.getToolCalls().isEmpty()) return resp.getToolCalls(); + Map m = LlmJsonUtils.parseJsonObjectLoose(resp.getContent()); + Object tc = m.get("tool_calls"); + if (tc instanceof List) { + List> out = new ArrayList<>(); + for (Object o : (List) tc) { + if (o instanceof Map) { + @SuppressWarnings("unchecked") + Map mm = (Map) o; + out.add(mm); + } + } + return out; + } + return 
Collections.emptyList();
+    }
+
+    private static String toolCallName(Map<String, Object> tc) {
+        if (tc == null) return null;
+        Object fn = tc.get("function");
+        if (fn instanceof Map) {
+            Object n = ((Map<?, ?>) fn).get("name");
+            return n == null ? null : String.valueOf(n);
+        }
+        Object n = tc.get("name");
+        return n == null ? null : String.valueOf(n);
+    }
+
+    private static Map<String, Object> toolCallArguments(Map<String, Object> tc) {
+        if (tc == null) return Collections.emptyMap();
+        Object fn = tc.get("function");
+        Object args = null;
+        if (fn instanceof Map) args = ((Map<?, ?>) fn).get("arguments");
+        if (args == null) args = tc.get("arguments");
+        if (args instanceof Map) {
+            @SuppressWarnings("unchecked")
+            Map<String, Object> m = (Map<String, Object>) args;
+            return m;
+        }
+        if (args instanceof String) return LlmJsonUtils.parseJsonObjectLoose((String) args);
+        return Collections.emptyMap();
+    }
+
+    // ---------------- fallbacks / normalization ----------------
+
+    private static Map<String, String> normalizeEntityTypeMap(Map<String, String> m) {
+        Map<String, String> out = new HashMap<>();
+        if (m == null) return out;
+        for (Map.Entry<String, String> e : m.entrySet()) {
+            if (e.getKey() == null) continue;
+            String k = normalizeName(e.getKey());
+            if (k.isBlank()) continue;
+            String v = e.getValue() == null ? "entity" : normalizeName(e.getValue());
+            if (v.isBlank()) v = "entity";
+            out.put(k, v);
+        }
+        return out;
+    }
+
+    private static List<RelationTriple> normalizeRelationTriples(List<RelationTriple> in) {
+        if (in == null || in.isEmpty()) return Collections.emptyList();
+        List<RelationTriple> out = new ArrayList<>();
+        for (RelationTriple t : in) {
+            if (t == null) continue;
+            String s = normalizeName(t.source);
+            String r = normalizeName(t.relationship);
+            String d = normalizeName(t.destination);
+            if (s.isBlank() || r.isBlank() || d.isBlank()) continue;
+            out.add(new RelationTriple(s, r, d));
+        }
+        return out;
+    }
+
+    private static String normalizeName(String s) {
+        if (s == null) return "";
+        // Trim first so leading/trailing whitespace is dropped rather than turned into underscores.
+        return s.trim().toLowerCase(Locale.ROOT).replace(" ", "_");
+    }
+
+    private static List<RelationTriple> heuristicTriples(String data) {
+        if (data == null || data.isBlank()) return Collections.emptyList();
+        List<RelationTriple> out = new ArrayList<>();
+        String[] lines = data.split("[\\n\\r]+");
+        for (String raw : lines) {
+            if (raw == null) continue;
+            String s = raw.trim();
+            if (s.isEmpty()) continue;
+            // Allow explicit triple format: source -- relationship -- destination
+            RelationTriple explicit = parseExplicitTriple(s);
+            if (explicit != null) {
+                out.add(explicit);
+                continue;
+            }
+            RelationTriple t = parseChinesePreference(s);
+            if (t != null) {
+                out.add(t);
+                continue;
+            }
+            RelationTriple e = parseEnglishPreference(s);
+            if (e != null) out.add(e);
+        }
+        return out;
+    }
+
+    private static RelationTriple parseExplicitTriple(String s) {
+        if (s == null) return null;
+        String x = s.trim();
+        if (!x.contains("--")) return null;
+        String[] parts = x.split("\\s*--\\s*");
+        if (parts.length < 3) return null;
+        String src = normalizeName(parts[0]);
+        String rel = normalizeName(parts[1]);
+        String dst = normalizeName(parts[2]);
+        if (src.isBlank() || rel.isBlank() || dst.isBlank()) return null;
+        return new RelationTriple(src, rel, dst);
+    }
+
+    private static RelationTriple parseChinesePreference(String s) {
+        // Normalize full-width Chinese punctuation to ASCII before matching.
+        String x = s.replace("：", ":").replace("，", ",").replace("。", ".").trim();
+        String subject;
+        String rest;
+        if (x.startsWith("用户")) {
+            subject = "用户";
+            rest = x.substring(2);
+        } else if (x.startsWith("我")) {
+            subject = "我";
+            rest = x.substring(1);
+        } else {
+            return null;
+        }
+        rest = rest.trim();
+        String rel;
+        int idx;
+        if ((idx = rest.indexOf("喜欢")) >= 0) rel = "喜欢";
+        else if ((idx = rest.indexOf("偏好")) >= 0) rel = "偏好";
+        else if ((idx = rest.indexOf("讨厌")) >= 0) rel = "讨厌";
+        else return null;
+        String obj = rest.substring(idx + rel.length()).trim();
+        if (obj.isEmpty()) return null;
+        return new RelationTriple(normalizeName(subject), normalizeName(rel), normalizeName(obj));
+    }
+
+    private static RelationTriple parseEnglishPreference(String s) {
+        String x = s.trim();
+        String lower = x.toLowerCase(Locale.ROOT);
+        if (!lower.startsWith("user ")) return null;
+        int idx = lower.indexOf(" likes ");
+        String rel = "likes";
+        if (idx < 0) {
+            idx = lower.indexOf(" prefers ");
+            rel = "prefers";
+        }
+        if (idx < 0) return null;
+        String obj = x.substring(idx + (" " + rel + " ").length()).trim();
+        if (obj.isEmpty()) return null;
+        return new RelationTriple("user", normalizeName(rel), normalizeName(obj));
+    }
+
+    // ---------------- config helpers ----------------
+
+    private int safeSearchLimit() {
+        int v = config == null ? 0 : config.getSearchLimit();
+        return v > 0 ? v : OceanBaseConstants.DEFAULT_SEARCH_LIMIT;
+    }
+
+    private int safeBm25TopN() {
+        int v = config == null ? 0 : config.getBm25TopN();
+        return v > 0 ? v : OceanBaseConstants.DEFAULT_BM25_TOP_N;
+    }
+
+    private int safeMaxHops() {
+        int v = config == null ? 0 : config.getMaxHops();
+        return v > 0 ? v : OceanBaseConstants.DEFAULT_MAX_HOPS;
+    }
+
+    private double similarityThreshold() {
+        double v = config == null ? 0.0 : config.getSimilarityThreshold();
+        return v > 0.0 ?
v : OceanBaseConstants.DEFAULT_SIMILARITY_THRESHOLD; + } + + // ---------------- connection + SQL helpers ---------------- + + private String entitiesTable() { + String tn = config == null ? null : config.getEntitiesTable(); + if (tn == null || tn.isBlank()) tn = OceanBaseConstants.TABLE_GRAPH_ENTITIES; + String safe = tn.trim(); + if (!safe.matches("[A-Za-z0-9_]+")) safe = OceanBaseConstants.TABLE_GRAPH_ENTITIES; + return safe; + } + + private String relationshipsTable() { + String tn = config == null ? null : config.getRelationshipsTable(); + if (tn == null || tn.isBlank()) tn = OceanBaseConstants.TABLE_GRAPH_RELATIONSHIPS; + String safe = tn.trim(); + if (!safe.matches("[A-Za-z0-9_]+")) safe = OceanBaseConstants.TABLE_GRAPH_RELATIONSHIPS; + return safe; + } + + private String jdbcUrl() { + String host = config.getHost() == null || config.getHost().isBlank() ? "127.0.0.1" : config.getHost(); + int port = config.getPort() > 0 ? config.getPort() : 2881; + String db = config.getDatabase() == null || config.getDatabase().isBlank() ? "ai_work" : config.getDatabase(); + int timeout = Math.max(1, config.getTimeoutSeconds()); + return "jdbc:mysql://" + host + ":" + port + "/" + db + + "?useUnicode=true&characterEncoding=UTF-8" + + "&useSSL=false&allowPublicKeyRetrieval=true&serverTimezone=UTC" + + "&connectTimeout=" + (timeout * 1000) + + "&socketTimeout=" + (timeout * 1000); + } + + private Connection openConnection() throws Exception { + String user = config.getUser(); + String pass = config.getPassword(); + if (user == null || user.isBlank()) { + throw new ApiException("OceanBase user is required (graph_store.user)"); + } + if (pass == null) pass = ""; + return DriverManager.getConnection(jdbcUrl(), user, pass); + } + + private static Map normalizeScope(Map filters) { + Map f = filters == null ? 
new HashMap<>() : new HashMap<>(filters); + Object uid = f.get("user_id"); + if (uid == null || String.valueOf(uid).isBlank()) { + f.put("user_id", "user"); // Python parity default + } + return f; + } + + private static String asNullableString(Object v) { + if (v == null) return null; + String s = String.valueOf(v); + return s.isBlank() ? null : s; + } + + private static long parseLongOrZero(String s) { + try { + return Long.parseLong(s); + } catch (Exception ex) { + return 0L; + } + } + + private static String placeholders(int n) { + if (n <= 0) return ""; + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < n; i++) { + if (i > 0) sb.append(","); + sb.append("?"); + } + return sb.toString(); + } + + // ---------------- data types ---------------- + + private static final class RelationTriple { + final String source; + final String relationship; + final String destination; + RelationTriple(String source, String relationship, String destination) { + this.source = source; + this.relationship = relationship; + this.destination = destination; + } + } + + private static final class EntityRow { + final long id; + final String name; + final String entityType; + final double distance; + EntityRow(long id, String name, String entityType, double distance) { + this.id = id; + this.name = name; + this.entityType = entityType; + this.distance = distance; + } + } } diff --git a/src/main/java/com/oceanbase/powermem/sdk/util/Bm25.java b/src/main/java/com/oceanbase/powermem/sdk/util/Bm25.java new file mode 100644 index 0000000..16527b3 --- /dev/null +++ b/src/main/java/com/oceanbase/powermem/sdk/util/Bm25.java @@ -0,0 +1,82 @@ +package com.oceanbase.powermem.sdk.util; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Minimal BM25 scorer (Okapi BM25). + * + *

<p>Used for graph triple reranking (Python parity: rank_bm25.BM25Okapi).
+ */
+public final class Bm25 {
+    private final double k1;
+    private final double b;
+    private final List<List<String>> corpus;
+    private final Map<String, Integer> docFreq = new HashMap<>();
+    private final List<Map<String, Integer>> termFreqs = new ArrayList<>();
+    private final double avgDocLen;
+
+    public Bm25(List<List<String>> tokenizedCorpus) {
+        this(tokenizedCorpus, 1.5, 0.75);
+    }
+
+    public Bm25(List<List<String>> tokenizedCorpus, double k1, double b) {
+        this.k1 = k1;
+        this.b = b;
+        this.corpus = tokenizedCorpus == null ? new ArrayList<>() : tokenizedCorpus;
+
+        long totalLen = 0;
+        for (List<String> doc : this.corpus) {
+            Map<String, Integer> tf = new HashMap<>();
+            if (doc != null) {
+                totalLen += doc.size();
+                for (String t : doc) {
+                    if (t == null || t.isBlank()) continue;
+                    tf.put(t, tf.getOrDefault(t, 0) + 1);
+                }
+            }
+            termFreqs.add(tf);
+            for (String term : tf.keySet()) {
+                docFreq.put(term, docFreq.getOrDefault(term, 0) + 1);
+            }
+        }
+        this.avgDocLen = this.corpus.isEmpty() ? 0.0 : (double) totalLen / (double) this.corpus.size();
+    }
+
+    public double[] getScores(List<String> queryTokens) {
+        int nDocs = corpus.size();
+        double[] scores = new double[nDocs];
+        if (nDocs == 0) {
+            return scores;
+        }
+        if (queryTokens == null || queryTokens.isEmpty()) {
+            return scores;
+        }
+
+        for (int i = 0; i < nDocs; i++) {
+            List<String> doc = corpus.get(i);
+            Map<String, Integer> tf = termFreqs.get(i);
+            int docLen = doc == null ? 0 : doc.size();
+            double denomNorm = (1.0 - b) + b * (avgDocLen > 0 ? ((double) docLen / avgDocLen) : 1.0);
+
+            double s = 0.0;
+            for (String q : queryTokens) {
+                if (q == null || q.isBlank()) continue;
+                Integer df = docFreq.get(q);
+                if (df == null || df == 0) continue;
+                // IDF with smoothing (Okapi)
+                double idf = Math.log(1.0 + (nDocs - df + 0.5) / (df + 0.5));
+                int f = tf.getOrDefault(q, 0);
+                if (f == 0) continue;
+                double numer = f * (k1 + 1.0);
+                double denom = f + k1 * denomNorm;
+                s += idf * (numer / denom);
+            }
+            scores[i] = s;
+        }
+        return scores;
+    }
+}
+
diff --git a/src/main/java/com/oceanbase/powermem/sdk/util/TextTokenizer.java b/src/main/java/com/oceanbase/powermem/sdk/util/TextTokenizer.java
new file mode 100644
index 0000000..7dfdc74
--- /dev/null
+++ b/src/main/java/com/oceanbase/powermem/sdk/util/TextTokenizer.java
@@ -0,0 +1,62 @@
+package com.oceanbase.powermem.sdk.util;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+
+/**
+ * Minimal tokenizer used for BM25 in graph store.
+ *
+ * <p>Python parity: jieba is used when available; here we use a lightweight heuristic:
+ * - ASCII words: split by non-alphanumerics
+ * - CJK: treat each CJK character as a token
+ * - Also keep numeric tokens
+ */
+public final class TextTokenizer {
+    private TextTokenizer() {}
+
+    public static List<String> tokenize(String text) {
+        List<String> out = new ArrayList<>();
+        if (text == null || text.isBlank()) {
+            return out;
+        }
+        String s = text.toLowerCase(Locale.ROOT);
+
+        StringBuilder ascii = new StringBuilder();
+        for (int i = 0; i < s.length(); i++) {
+            char c = s.charAt(i);
+            if (isCjk(c)) {
+                flushAscii(ascii, out);
+                out.add(String.valueOf(c));
+            } else if (Character.isLetterOrDigit(c)) {
+                ascii.append(c);
+            } else {
+                flushAscii(ascii, out);
+            }
+        }
+        flushAscii(ascii, out);
+        return out;
+    }
+
+    private static void flushAscii(StringBuilder sb, List<String> out) {
+        if (sb.length() == 0) return;
+        String tok = sb.toString();
+        sb.setLength(0);
+        if (!tok.isBlank()) {
+            out.add(tok);
+        }
+    }
+
+    private static boolean isCjk(char c) {
+        Character.UnicodeBlock b = Character.UnicodeBlock.of(c);
+        return b == Character.UnicodeBlock.CJK_UNIFIED_IDEOGRAPHS
+                || b == Character.UnicodeBlock.CJK_UNIFIED_IDEOGRAPHS_EXTENSION_A
+                || b == Character.UnicodeBlock.CJK_UNIFIED_IDEOGRAPHS_EXTENSION_B
+                || b == Character.UnicodeBlock.CJK_COMPATIBILITY_IDEOGRAPHS
+                || b == Character.UnicodeBlock.CJK_SYMBOLS_AND_PUNCTUATION
+                || b == Character.UnicodeBlock.HIRAGANA
+                || b == Character.UnicodeBlock.KATAKANA
+                || b == Character.UnicodeBlock.HANGUL_SYLLABLES;
+    }
+}
+
diff --git a/src/test/java/com/oceanbase/powermem/GraphStoreConfigOverrideTest.java b/src/test/java/com/oceanbase/powermem/GraphStoreConfigOverrideTest.java
new file mode 100644
index 0000000..2b98db6
--- /dev/null
+++ b/src/test/java/com/oceanbase/powermem/GraphStoreConfigOverrideTest.java
@@ -0,0 +1,61 @@
+package com.oceanbase.powermem;
+
+import com.oceanbase.powermem.sdk.config.MemoryConfig;
+import com.oceanbase.powermem.sdk.config.ConfigLoader;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+/**
+ * Offline test for graph_store.llm/embedder override parsing.
+ *

We primarily validate that ConfigLoader wires the nested config objects and does not affect + * the top-level llm/embedder config.

+ */ +public class GraphStoreConfigOverrideTest { + + @Test + void testGraphStoreNestedLlmAndEmbedder_areParsed() { + java.util.Map m = new java.util.HashMap<>(); + m.put("GRAPH_STORE_ENABLED", "true"); + m.put("GRAPH_STORE_PROVIDER", "oceanbase"); + + m.put("LLM_PROVIDER", "mock"); + m.put("EMBEDDING_PROVIDER", "mock"); + + m.put("GRAPH_STORE_LLM_PROVIDER", "openai"); + m.put("GRAPH_STORE_LLM_API_KEY", "k1"); + m.put("GRAPH_STORE_LLM_MODEL", "gpt-4o-mini"); + m.put("GRAPH_STORE_LLM_BASE_URL", "https://example.com/v1"); + + m.put("GRAPH_STORE_EMBEDDING_PROVIDER", "qwen"); + m.put("GRAPH_STORE_EMBEDDING_API_KEY", "k2"); + m.put("GRAPH_STORE_EMBEDDING_MODEL", "text-embedding-v4"); + m.put("GRAPH_STORE_EMBEDDING_DIMS", "1536"); + m.put("GRAPH_STORE_EMBEDDING_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"); + + MemoryConfig cfg = ConfigLoader.fromMap(m); + + assertNotNull(cfg.getGraphStore()); + assertTrue(cfg.getGraphStore().isEnabled()); + + // top-level remains mock + assertEquals("mock", cfg.getLlm().getProvider()); + assertEquals("mock", cfg.getEmbedder().getProvider()); + + // graph_store overrides + assertNotNull(cfg.getGraphStore().getLlm()); + assertEquals("openai", cfg.getGraphStore().getLlm().getProvider()); + assertEquals("k1", cfg.getGraphStore().getLlm().getApiKey()); + assertEquals("gpt-4o-mini", cfg.getGraphStore().getLlm().getModel()); + assertEquals("https://example.com/v1", cfg.getGraphStore().getLlm().getBaseUrl()); + + assertNotNull(cfg.getGraphStore().getEmbedder()); + assertEquals("qwen", cfg.getGraphStore().getEmbedder().getProvider()); + assertEquals("k2", cfg.getGraphStore().getEmbedder().getApiKey()); + assertEquals("text-embedding-v4", cfg.getGraphStore().getEmbedder().getModel()); + assertEquals(1536, cfg.getGraphStore().getEmbedder().getEmbeddingDims()); + assertEquals("https://dashscope.aliyuncs.com/compatible-mode/v1", cfg.getGraphStore().getEmbedder().getBaseUrl()); + } +} + diff --git 
a/src/test/java/com/oceanbase/powermem/MemoryE2eTest.java b/src/test/java/com/oceanbase/powermem/MemoryE2eTest.java index b7de54f..135089b 100644 --- a/src/test/java/com/oceanbase/powermem/MemoryE2eTest.java +++ b/src/test/java/com/oceanbase/powermem/MemoryE2eTest.java @@ -45,6 +45,32 @@ private static Memory newMemoryWithTempSqlite(Path dbPath) { return new Memory(cfg); } + private static Memory newMemoryWithTempSqliteAndInMemoryGraph(Path dbPath) { + MemoryConfig cfg = new MemoryConfig(); + VectorStoreConfig vs = VectorStoreConfig.sqlite(dbPath.toString()); + vs.setCollectionName("memories"); + cfg.setVectorStore(vs); + + // offline deterministic + EmbedderConfig emb = new EmbedderConfig(); + emb.setProvider("mock"); + cfg.setEmbedder(emb); + + LlmConfig llm = new LlmConfig(); + llm.setProvider("mock"); + cfg.setLlm(llm); + + // enable intelligent memory plugin (ebbinghaus lifecycle hooks) + cfg.getIntelligentMemory().setEnabled(true); + cfg.getIntelligentMemory().setDecayEnabled(true); + + // enable graph store (in-memory) + cfg.getGraphStore().setEnabled(true); + cfg.getGraphStore().setProvider("memory"); + + return new Memory(cfg); + } + @Test void testCrudAndSearch_withMetadataAndFilters(@TempDir Path tmp) { Memory mem = newMemoryWithTempSqlite(tmp.resolve("powermem_test.db")); @@ -247,5 +273,50 @@ void testCrudAndSearch_withAgentIdOnly_userIdOptional(@TempDir Path tmp) { assertNotNull(delResp); assertTrue(delResp.isDeleted()); } + + @Test + void testGraphStoreRelations_areReturnedWhenEnabled(@TempDir Path tmp) { + Memory mem = newMemoryWithTempSqliteAndInMemoryGraph(tmp.resolve("powermem_graph.db")); + + String userId = null; // Python parity: user_id optional; graph uses default "user" + String agentId = "a_graph"; + String runId = "r_graph"; + + DeleteAllMemoriesRequest delAll = new DeleteAllMemoriesRequest(); + delAll.setUserId(userId); + delAll.setAgentId(agentId); + delAll.setRunId(runId); + mem.deleteAll(delAll); + + AddMemoryRequest add = new 
AddMemoryRequest(); + add.setUserId(userId); + add.setAgentId(agentId); + add.setRunId(runId); + add.setInfer(false); + add.setText("用户喜欢简洁的中文回答"); + var addResp = mem.add(add); + assertNotNull(addResp); + assertNotNull(addResp.getResults()); + assertEquals(1, addResp.getResults().size()); + assertNotNull(addResp.getRelations(), "relations should be present when graph store is enabled"); + + @SuppressWarnings("unchecked") + Map rel = (Map) addResp.getRelations(); + assertTrue(rel.containsKey("added_entities")); + + // Search should return relations list (source/relationship/destination) + SearchMemoriesRequest search = SearchMemoriesRequest.ofQuery("喜欢 简洁"); + search.setUserId(userId); + search.setAgentId(agentId); + search.setRunId(runId); + var searchResp = mem.search(search); + assertNotNull(searchResp); + assertNotNull(searchResp.getRelations(), "search.relations should be present when graph store is enabled"); + + @SuppressWarnings("unchecked") + List> graphHits = (List>) searchResp.getRelations(); + assertFalse(graphHits.isEmpty()); + assertEquals("用户", String.valueOf(graphHits.get(0).get("source"))); + } } diff --git a/src/test/java/com/oceanbase/powermem/OceanBaseGraphStoreIT.java b/src/test/java/com/oceanbase/powermem/OceanBaseGraphStoreIT.java new file mode 100644 index 0000000..2c5f4db --- /dev/null +++ b/src/test/java/com/oceanbase/powermem/OceanBaseGraphStoreIT.java @@ -0,0 +1,96 @@ +package com.oceanbase.powermem; + +import com.oceanbase.powermem.sdk.config.EmbedderConfig; +import com.oceanbase.powermem.sdk.config.MemoryConfig; +import com.oceanbase.powermem.sdk.config.VectorStoreConfig; +import com.oceanbase.powermem.sdk.core.Memory; +import com.oceanbase.powermem.sdk.model.AddMemoryRequest; +import com.oceanbase.powermem.sdk.model.DeleteAllMemoriesRequest; +import com.oceanbase.powermem.sdk.model.SearchMemoriesRequest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import 
org.junit.jupiter.api.io.TempDir; + +import java.nio.file.Path; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * OceanBaseGraphStore integration test. + * + *

Runs only when required OceanBase env vars are set.

+ */ +@EnabledIfEnvironmentVariable(named = "OCEANBASE_HOST", matches = ".+") +@EnabledIfEnvironmentVariable(named = "OCEANBASE_USER", matches = ".+") +@EnabledIfEnvironmentVariable(named = "OCEANBASE_DATABASE", matches = ".+") +public class OceanBaseGraphStoreIT { + + @Test + void testGraphRelations_addAndSearch(@TempDir Path tmp) { + String host = System.getenv("OCEANBASE_HOST"); + String user = System.getenv("OCEANBASE_USER"); + String password = System.getenv().getOrDefault("OCEANBASE_PASSWORD", ""); + String db = System.getenv("OCEANBASE_DATABASE"); + int port = 2881; + try { + String p = System.getenv("OCEANBASE_PORT"); + if (p != null && !p.isBlank()) port = Integer.parseInt(p); + } catch (Exception ignored) {} + + // Vector store can stay local (SQLite). Graph store uses OceanBase. + MemoryConfig cfg = new MemoryConfig(); + VectorStoreConfig vs = VectorStoreConfig.sqlite(tmp.resolve("powermem_graph_it.db").toString()); + vs.setCollectionName("memories_graph_it"); + cfg.setVectorStore(vs); + + EmbedderConfig emb = new EmbedderConfig(); + emb.setProvider("mock"); + cfg.setEmbedder(emb); + cfg.getLlm().setProvider("mock"); + + cfg.getGraphStore().setEnabled(true); + cfg.getGraphStore().setProvider("oceanbase"); + cfg.getGraphStore().setHost(host); + cfg.getGraphStore().setPort(port); + cfg.getGraphStore().setUser(user); + cfg.getGraphStore().setPassword(password); + cfg.getGraphStore().setDatabase(db); + cfg.getGraphStore().setEntitiesTable("graph_entities_java_it"); + cfg.getGraphStore().setRelationshipsTable("graph_relationships_java_it"); + cfg.getGraphStore().setTimeoutSeconds(10); + + Memory mem = new Memory(cfg); + + String userId = null; // Python parity: optional + String agentId = "it_agent_graph"; + String runId = "it_run_graph"; + + DeleteAllMemoriesRequest delAll = new DeleteAllMemoriesRequest(); + delAll.setUserId(userId); + delAll.setAgentId(agentId); + delAll.setRunId(runId); + mem.deleteAll(delAll); + + AddMemoryRequest add = new 
AddMemoryRequest(); + add.setUserId(userId); + add.setAgentId(agentId); + add.setRunId(runId); + add.setInfer(false); + add.setText("user -- likes -- coffee\ncoffee -- is_a -- beverage"); + var addResp = mem.add(add); + assertNotNull(addResp); + assertNotNull(addResp.getRelations()); + + // Multi-hop: user -> coffee -> beverage + SearchMemoriesRequest search = SearchMemoriesRequest.ofQuery("user beverage"); + search.setUserId(userId); + search.setAgentId(agentId); + search.setRunId(runId); + var sr = mem.search(search); + assertNotNull(sr); + assertNotNull(sr.getRelations()); + assertTrue(sr.getRelations() instanceof java.util.List); + assertFalse(((java.util.List) sr.getRelations()).isEmpty()); + } +} +
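For reviewers skimming the diff, the interplay of `TextTokenizer.tokenize` and `Bm25.getScores` may be easier to see in a standalone sketch. The snippet below re-implements the same Okapi BM25 formula and the same CJK-per-character tokenization heuristic inline, so it runs without the SDK on the classpath; the class name `Bm25Sketch` and the sample triples are illustrative only, not part of the SDK API.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

public class Bm25Sketch {
    // Per-character CJK + ASCII-word tokenization, mirroring TextTokenizer's heuristic.
    static List<String> tokenize(String text) {
        List<String> out = new ArrayList<>();
        StringBuilder ascii = new StringBuilder();
        for (char c : text.toLowerCase().toCharArray()) {
            if (Character.UnicodeBlock.of(c) == Character.UnicodeBlock.CJK_UNIFIED_IDEOGRAPHS) {
                if (ascii.length() > 0) { out.add(ascii.toString()); ascii.setLength(0); }
                out.add(String.valueOf(c)); // each CJK character is its own token
            } else if (Character.isLetterOrDigit(c)) {
                ascii.append(c);
            } else if (ascii.length() > 0) {
                out.add(ascii.toString()); ascii.setLength(0);
            }
        }
        if (ascii.length() > 0) out.add(ascii.toString());
        return out;
    }

    // Okapi BM25 with the same smoothed IDF as the Bm25 class:
    // idf = ln(1 + (N - df + 0.5) / (df + 0.5))
    static double[] score(List<List<String>> corpus, List<String> query, double k1, double b) {
        int n = corpus.size();
        double avg = corpus.stream().mapToInt(List::size).average().orElse(0.0);
        Map<String, Integer> df = new HashMap<>();
        for (List<String> doc : corpus) {
            for (String t : new HashSet<>(doc)) df.merge(t, 1, Integer::sum);
        }
        double[] scores = new double[n];
        for (int i = 0; i < n; i++) {
            List<String> doc = corpus.get(i);
            double norm = (1.0 - b) + b * (avg > 0 ? doc.size() / avg : 1.0);
            for (String q : query) {
                Integer d = df.get(q);
                if (d == null) continue;
                long f = doc.stream().filter(q::equals).count();
                if (f == 0) continue;
                double idf = Math.log(1.0 + (n - d + 0.5) / (d + 0.5));
                scores[i] += idf * (f * (k1 + 1.0)) / (f + k1 * norm);
            }
        }
        return scores;
    }

    public static void main(String[] args) {
        List<List<String>> corpus = List.of(
                tokenize("user -- likes -- coffee"),
                tokenize("coffee -- is_a -- beverage"),
                tokenize("user -- lives_in -- berlin"));
        double[] s = score(corpus, tokenize("user coffee"), 1.5, 0.75);
        // The triple mentioning both query terms should rank first.
        System.out.println(s[0] > s[1] && s[0] > s[2]);
    }
}
```

The defaults `k1 = 1.5, b = 0.75` match the single-argument `Bm25` constructor in the diff.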
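The `GRAPH_STORE_*` keys exercised in `GraphStoreConfigOverrideTest` correspond to a `.env` fragment along these lines (all values illustrative, taken from the test fixture; assuming `.env` loading handles these keys the same way it handles the top-level `LLM_*`/`EMBEDDING_*` ones):

```
GRAPH_STORE_ENABLED=true
GRAPH_STORE_PROVIDER=oceanbase

# optional: override the LLM used for entity/relation extraction
GRAPH_STORE_LLM_PROVIDER=openai
GRAPH_STORE_LLM_API_KEY=k1
GRAPH_STORE_LLM_MODEL=gpt-4o-mini
GRAPH_STORE_LLM_BASE_URL=https://example.com/v1

# optional: override the embedder used for graph entity vectors
GRAPH_STORE_EMBEDDING_PROVIDER=qwen
GRAPH_STORE_EMBEDDING_API_KEY=k2
GRAPH_STORE_EMBEDDING_MODEL=text-embedding-v4
GRAPH_STORE_EMBEDDING_DIMS=1536
GRAPH_STORE_EMBEDDING_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
```

When these nested keys are absent, the test asserts that the top-level llm/embedder configuration is left untouched.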