From e3278c03570ff5645e8bdcb95ddd9e60370f0114 Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Tue, 10 Feb 2026 08:03:28 -0800 Subject: [PATCH] Add DataReader class for typed dataframe loading (#2220) * Add DataReader class for typed dataframe loading Introduce DataReader that wraps TableProvider and applies type coercion functions when loading dataframes from weakly-typed formats (e.g. CSV). - Add DataReader class with methods for each table type: entities, relationships, communities, community_reports, covariates, text_units, and documents - Add typed loading functions in dfs.py for community_reports, covariates, text_units, and documents (entities, relationships, communities already existed) - Integrate DataReader into all 17 indexing workflows replacing raw read_dataframe calls - Integrate DataReader into CLI query's _resolve_output_files for typed loading across all search types (global, local, drift, basic) - Export DataReader from data_model package __init__ * Fix column check --- .../patch-20260210011450472481.json | 4 + packages/graphrag/graphrag/cli/query.py | 8 +- .../graphrag/graphrag/data_model/__init__.py | 4 + .../graphrag/data_model/data_reader.py | 71 ++++++++++ packages/graphrag/graphrag/data_model/dfs.py | 129 ++++++++++++++++++ .../index/workflows/create_base_text_units.py | 4 +- .../index/workflows/create_communities.py | 6 +- .../workflows/create_community_reports.py | 16 ++- .../create_community_reports_text.py | 8 +- .../index/workflows/create_final_documents.py | 6 +- .../workflows/create_final_text_units.py | 14 +- .../index/workflows/extract_covariates.py | 4 +- .../graphrag/index/workflows/extract_graph.py | 4 +- .../index/workflows/extract_graph_nlp.py | 4 +- .../index/workflows/finalize_graph.py | 6 +- .../workflows/generate_text_embeddings.py | 10 +- .../graphrag/index/workflows/prune_graph.py | 6 +- .../index/workflows/update_communities.py | 5 +- .../workflows/update_community_reports.py | 11 +- .../index/workflows/update_covariates.py | 5 +- .../update_entities_relationships.py | 9 +- .../index/workflows/update_text_units.py | 5 +- 22 files changed, 285 insertions(+), 54 deletions(-) create mode 100644 .semversioner/next-release/patch-20260210011450472481.json create mode 100644 packages/graphrag/graphrag/data_model/data_reader.py create mode 100644 packages/graphrag/graphrag/data_model/dfs.py diff --git a/.semversioner/next-release/patch-20260210011450472481.json b/.semversioner/next-release/patch-20260210011450472481.json new file mode 100644 index 0000000000..3af259576d --- /dev/null +++ b/.semversioner/next-release/patch-20260210011450472481.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Add DataReader class for typed dataframe loading from TableProvider across indexing workflows and query CLI" +} diff --git a/packages/graphrag/graphrag/cli/query.py b/packages/graphrag/graphrag/cli/query.py index 21ec4b654c..cf4be9162a 100644 --- a/packages/graphrag/graphrag/cli/query.py +++ b/packages/graphrag/graphrag/cli/query.py @@ -15,6 +15,7 @@ from graphrag.callbacks.noop_query_callbacks import NoopQueryCallbacks from graphrag.config.load_config import load_config from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.data_model.data_reader import DataReader if TYPE_CHECKING: import pandas as pd @@ -375,12 +376,13 @@ def _resolve_output_files( output_list: list[str], optional_list: list[str] | None = None, ) -> dict[str, Any]: - """Read indexing output files to a dataframe dict.""" + """Read indexing output files to a dataframe dict, with correct column types.""" dataframe_dict = {} storage_obj = create_storage(config.output_storage) table_provider = create_table_provider(config.table_provider, storage=storage_obj) + reader = DataReader(table_provider) for name in output_list: - df_value = asyncio.run(table_provider.read_dataframe(name)) + df_value = asyncio.run(getattr(reader, name)()) dataframe_dict[name] = df_value # for optional output files, set the dict entry to None instead of erroring out if it does not exist @@ -388,7 +390,7 @@ def _resolve_output_files( for optional_file in optional_list: file_exists = asyncio.run(table_provider.has(optional_file)) if file_exists: - df_value = asyncio.run(table_provider.read_dataframe(optional_file)) + df_value = asyncio.run(getattr(reader, optional_file)()) dataframe_dict[optional_file] = df_value else: dataframe_dict[optional_file] = None diff --git a/packages/graphrag/graphrag/data_model/__init__.py b/packages/graphrag/graphrag/data_model/__init__.py index 3c0de524cd..ab84444c1b 100644 --- a/packages/graphrag/graphrag/data_model/__init__.py +++ b/packages/graphrag/graphrag/data_model/__init__.py @@ -2,3 +2,7 @@ # Licensed under the MIT License """Knowledge model package.""" + +from graphrag.data_model.data_reader import DataReader + +__all__ = ["DataReader"] diff --git a/packages/graphrag/graphrag/data_model/data_reader.py b/packages/graphrag/graphrag/data_model/data_reader.py new file mode 100644 index 0000000000..176b9ee4d5 --- /dev/null +++ b/packages/graphrag/graphrag/data_model/data_reader.py @@ -0,0 +1,71 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A DataReader that loads typed dataframes from a TableProvider.""" + +import pandas as pd +from graphrag_storage.tables import TableProvider + +from graphrag.data_model.dfs import ( + communities_typed, + community_reports_typed, + covariates_typed, + documents_typed, + entities_typed, + relationships_typed, + text_units_typed, +) + + +class DataReader: + """Reads dataframes from a TableProvider and applies correct column types. + + When loading from weakly-typed formats like CSV, list columns are stored as + plain strings. This class wraps a TableProvider, loading each table and + converting columns to their expected types before returning. + """ + + def __init__(self, table_provider: TableProvider) -> None: + """Initialize a DataReader with the given TableProvider. + + Args + ---- + table_provider: TableProvider + The table provider to load dataframes from. + """ + self._table_provider = table_provider + + async def entities(self) -> pd.DataFrame: + """Load and return the entities dataframe with correct types.""" + df = await self._table_provider.read_dataframe("entities") + return entities_typed(df) + + async def relationships(self) -> pd.DataFrame: + """Load and return the relationships dataframe with correct types.""" + df = await self._table_provider.read_dataframe("relationships") + return relationships_typed(df) + + async def communities(self) -> pd.DataFrame: + """Load and return the communities dataframe with correct types.""" + df = await self._table_provider.read_dataframe("communities") + return communities_typed(df) + + async def community_reports(self) -> pd.DataFrame: + """Load and return the community reports dataframe with correct types.""" + df = await self._table_provider.read_dataframe("community_reports") + return community_reports_typed(df) + + async def covariates(self) -> pd.DataFrame: + """Load and return the covariates dataframe with correct types.""" + df = await self._table_provider.read_dataframe("covariates") + return covariates_typed(df) + + async def text_units(self) -> pd.DataFrame: + """Load and return the text units dataframe with correct types.""" + df = await self._table_provider.read_dataframe("text_units") + return text_units_typed(df) + + async def documents(self) -> pd.DataFrame: + """Load and return the documents dataframe with correct types.""" + df = await self._table_provider.read_dataframe("documents") + return documents_typed(df) diff --git a/packages/graphrag/graphrag/data_model/dfs.py b/packages/graphrag/graphrag/data_model/dfs.py new file mode 100644 index 0000000000..d6d7e729fc --- /dev/null +++ b/packages/graphrag/graphrag/data_model/dfs.py @@ -0,0 +1,129 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A package containing dataframe processing utilities.""" + +from typing import Any + +import pandas as pd + +from graphrag.data_model.schemas import ( + COMMUNITY_CHILDREN, + COMMUNITY_ID, + COMMUNITY_LEVEL, + COVARIATE_IDS, + EDGE_DEGREE, + EDGE_WEIGHT, + ENTITY_IDS, + FINDINGS, + N_TOKENS, + NODE_DEGREE, + NODE_FREQUENCY, + PERIOD, + RATING, + RELATIONSHIP_IDS, + SHORT_ID, + SIZE, + TEXT_UNIT_IDS, +) + + +def _split_list_column(value: Any) -> list[Any]: + """Split a column containing a list string into an actual list.""" + if isinstance(value, str): + return [item.strip("[] '") for item in value.split(",")] if value else [] + return value + + +def entities_typed(df: pd.DataFrame) -> pd.DataFrame: + """Return the entities dataframe with correct types, in case it was stored in a weakly-typed format.""" + if SHORT_ID in df.columns: + df[SHORT_ID] = df[SHORT_ID].astype(int) + if TEXT_UNIT_IDS in df.columns: + df[TEXT_UNIT_IDS] = df[TEXT_UNIT_IDS].apply(_split_list_column) + if NODE_FREQUENCY in df.columns: + df[NODE_FREQUENCY] = df[NODE_FREQUENCY].astype(int) + if NODE_DEGREE in df.columns: + df[NODE_DEGREE] = df[NODE_DEGREE].astype(int) + + return df + + +def relationships_typed(df: pd.DataFrame) -> pd.DataFrame: + """Return the relationships dataframe with correct types, in case it was stored in a weakly-typed format.""" + if SHORT_ID in df.columns: + df[SHORT_ID] = df[SHORT_ID].astype(int) + if EDGE_WEIGHT in df.columns: + df[EDGE_WEIGHT] = df[EDGE_WEIGHT].astype(float) + if EDGE_DEGREE in df.columns: + df[EDGE_DEGREE] = df[EDGE_DEGREE].astype(int) + if TEXT_UNIT_IDS in df.columns: + df[TEXT_UNIT_IDS] = df[TEXT_UNIT_IDS].apply(_split_list_column) + + return df + + +def communities_typed(df: pd.DataFrame) -> pd.DataFrame: + """Return the communities dataframe with correct types, in case it was stored in a weakly-typed format.""" + if SHORT_ID in df.columns: + df[SHORT_ID] = df[SHORT_ID].astype(int) + df[COMMUNITY_ID] = df[COMMUNITY_ID].astype(int) + df[COMMUNITY_LEVEL] = df[COMMUNITY_LEVEL].astype(int) + df[COMMUNITY_CHILDREN] = df[COMMUNITY_CHILDREN].apply(_split_list_column) + if ENTITY_IDS in df.columns: + df[ENTITY_IDS] = df[ENTITY_IDS].apply(_split_list_column) + if RELATIONSHIP_IDS in df.columns: + df[RELATIONSHIP_IDS] = df[RELATIONSHIP_IDS].apply(_split_list_column) + if TEXT_UNIT_IDS in df.columns: + df[TEXT_UNIT_IDS] = df[TEXT_UNIT_IDS].apply(_split_list_column) + df[PERIOD] = df[PERIOD].astype(str) + df[SIZE] = df[SIZE].astype(int) + + return df + + +def community_reports_typed(df: pd.DataFrame) -> pd.DataFrame: + """Return the community reports dataframe with correct types, in case it was stored in a weakly-typed format.""" + if SHORT_ID in df.columns: + df[SHORT_ID] = df[SHORT_ID].astype(int) + df[COMMUNITY_ID] = df[COMMUNITY_ID].astype(int) + df[COMMUNITY_LEVEL] = df[COMMUNITY_LEVEL].astype(int) + df[COMMUNITY_CHILDREN] = df[COMMUNITY_CHILDREN].apply(_split_list_column) + df[RATING] = df[RATING].astype(float) + df[FINDINGS] = df[FINDINGS].apply(_split_list_column) + df[SIZE] = df[SIZE].astype(int) + + return df + + +def covariates_typed(df: pd.DataFrame) -> pd.DataFrame: + """Return the covariates dataframe with correct types, in case it was stored in a weakly-typed format.""" + if SHORT_ID in df.columns: + df[SHORT_ID] = df[SHORT_ID].astype(int) + + return df + + +def text_units_typed(df: pd.DataFrame) -> pd.DataFrame: + """Return the text units dataframe with correct types, in case it was stored in a weakly-typed format.""" + if SHORT_ID in df.columns: + df[SHORT_ID] = df[SHORT_ID].astype(int) + df[N_TOKENS] = df[N_TOKENS].astype(int) + if ENTITY_IDS in df.columns: + df[ENTITY_IDS] = df[ENTITY_IDS].apply(_split_list_column) + if RELATIONSHIP_IDS in df.columns: + df[RELATIONSHIP_IDS] = df[RELATIONSHIP_IDS].apply(_split_list_column) + if COVARIATE_IDS in df.columns: + df[COVARIATE_IDS] = df[COVARIATE_IDS].apply(_split_list_column) + + return df + + +def documents_typed(df: pd.DataFrame) -> pd.DataFrame: + """Return the documents dataframe with correct types, in case it was stored in a weakly-typed format.""" + if SHORT_ID in df.columns: + df[SHORT_ID] = df[SHORT_ID].astype(int) + if TEXT_UNIT_IDS in df.columns: + df[TEXT_UNIT_IDS] = df[TEXT_UNIT_IDS].apply(_split_list_column) + + return df diff --git a/packages/graphrag/graphrag/index/workflows/create_base_text_units.py b/packages/graphrag/graphrag/index/workflows/create_base_text_units.py index 2d53fd8e6f..196ab3f1b6 100644 --- a/packages/graphrag/graphrag/index/workflows/create_base_text_units.py +++ b/packages/graphrag/graphrag/index/workflows/create_base_text_units.py @@ -15,6 +15,7 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.data_model.data_reader import DataReader from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.utils.hashing import gen_sha512_hash @@ -30,7 +31,8 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform base text_units.""" logger.info("Workflow started: create_base_text_units") - documents = await context.output_table_provider.read_dataframe("documents") + reader = DataReader(context.output_table_provider) + documents = await reader.documents() tokenizer = get_tokenizer(encoding_model=config.chunking.encoding_model) chunker = create_chunker(config.chunking, tokenizer.encode, tokenizer.decode) diff --git a/packages/graphrag/graphrag/index/workflows/create_communities.py b/packages/graphrag/graphrag/index/workflows/create_communities.py index 7c3d7a6b33..a9c4fb2054 100644 --- a/packages/graphrag/graphrag/index/workflows/create_communities.py +++ b/packages/graphrag/graphrag/index/workflows/create_communities.py @@ -12,6 +12,7 @@ import pandas as pd from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.data_model.data_reader import DataReader from graphrag.data_model.schemas import COMMUNITIES_FINAL_COLUMNS from graphrag.index.operations.cluster_graph import cluster_graph from graphrag.index.operations.create_graph import create_graph @@ -27,8 +28,9 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform final communities.""" logger.info("Workflow started: create_communities") - entities = await context.output_table_provider.read_dataframe("entities") - relationships = await context.output_table_provider.read_dataframe("relationships") + reader = DataReader(context.output_table_provider) + entities = await reader.entities() + relationships = await reader.relationships() max_cluster_size = config.cluster_graph.max_cluster_size use_lcc = config.cluster_graph.use_lcc diff --git a/packages/graphrag/graphrag/index/workflows/create_community_reports.py b/packages/graphrag/graphrag/index/workflows/create_community_reports.py index c6f8b1decf..c9c8b964c1 100644 --- a/packages/graphrag/graphrag/index/workflows/create_community_reports.py +++ b/packages/graphrag/graphrag/index/workflows/create_community_reports.py @@ -15,6 +15,7 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.enums import AsyncType from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.data_model.data_reader import DataReader from graphrag.index.operations.finalize_community_reports import ( finalize_community_reports, ) @@ -43,15 +44,16 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform community reports.""" logger.info("Workflow started: create_community_reports") - edges = await context.output_table_provider.read_dataframe("relationships") - entities = await context.output_table_provider.read_dataframe("entities") - communities = await context.output_table_provider.read_dataframe("communities") + reader = DataReader(context.output_table_provider) + relationships = await reader.relationships() + entities = await reader.entities() + communities = await reader.communities() claims = None if config.extract_claims.enabled and await context.output_table_provider.has( "covariates" ): - claims = await context.output_table_provider.read_dataframe("covariates") + claims = await reader.covariates() model_config = config.get_completion_model_config( config.community_reports.completion_model_id @@ -67,7 +69,7 @@ async def run_workflow( tokenizer = model.tokenizer output = await create_community_reports( - edges_input=edges, + relationships=relationships, entities=entities, communities=communities, claims_input=claims, @@ -88,7 +90,7 @@ async def run_workflow( async def create_community_reports( - edges_input: pd.DataFrame, + relationships: pd.DataFrame, entities: pd.DataFrame, communities: pd.DataFrame, claims_input: pd.DataFrame | None, @@ -105,7 +107,7 @@ async def create_community_reports( nodes = explode_communities(communities, entities) nodes = _prep_nodes(nodes) - edges = _prep_edges(edges_input) + edges = _prep_edges(relationships) claims = None if claims_input is not None: diff --git a/packages/graphrag/graphrag/index/workflows/create_community_reports_text.py b/packages/graphrag/graphrag/index/workflows/create_community_reports_text.py index 52cb4b0f8e..ea58c995c1 100644 --- a/packages/graphrag/graphrag/index/workflows/create_community_reports_text.py +++ b/packages/graphrag/graphrag/index/workflows/create_community_reports_text.py @@ -14,6 +14,7 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.enums import AsyncType from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.data_model.data_reader import DataReader from graphrag.index.operations.finalize_community_reports import ( finalize_community_reports, ) @@ -42,9 +43,10 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform community reports.""" logger.info("Workflow started: create_community_reports_text") - entities = await context.output_table_provider.read_dataframe("entities") - communities = await context.output_table_provider.read_dataframe("communities") - text_units = await context.output_table_provider.read_dataframe("text_units") + reader = DataReader(context.output_table_provider) + entities = await reader.entities() + communities = await reader.communities() + text_units = await reader.text_units() model_config = config.get_completion_model_config( config.community_reports.completion_model_id diff --git a/packages/graphrag/graphrag/index/workflows/create_final_documents.py b/packages/graphrag/graphrag/index/workflows/create_final_documents.py index c799d1bb44..57c67229e9 100644 --- a/packages/graphrag/graphrag/index/workflows/create_final_documents.py +++ b/packages/graphrag/graphrag/index/workflows/create_final_documents.py @@ -8,6 +8,7 @@ import pandas as pd from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.data_model.data_reader import DataReader from graphrag.data_model.schemas import DOCUMENTS_FINAL_COLUMNS from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput @@ -21,8 +22,9 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform final documents.""" logger.info("Workflow started: create_final_documents") - documents = await context.output_table_provider.read_dataframe("documents") - text_units = await context.output_table_provider.read_dataframe("text_units") + reader = DataReader(context.output_table_provider) + documents = await reader.documents() + text_units = await reader.text_units() output = create_final_documents(documents, text_units) diff --git a/packages/graphrag/graphrag/index/workflows/create_final_text_units.py b/packages/graphrag/graphrag/index/workflows/create_final_text_units.py index 51fde7f72d..91e097c355 100644 --- a/packages/graphrag/graphrag/index/workflows/create_final_text_units.py +++ b/packages/graphrag/graphrag/index/workflows/create_final_text_units.py @@ -8,6 +8,7 @@ import pandas as pd from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.data_model.data_reader import DataReader from graphrag.data_model.schemas import TEXT_UNITS_FINAL_COLUMNS from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput @@ -21,19 +22,16 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform the text units.""" logger.info("Workflow started: create_final_text_units") - text_units = await context.output_table_provider.read_dataframe("text_units") - final_entities = await context.output_table_provider.read_dataframe("entities") - final_relationships = await context.output_table_provider.read_dataframe( - "relationships" - ) + reader = DataReader(context.output_table_provider) + text_units = await reader.text_units() + final_entities = await reader.entities() + final_relationships = await reader.relationships() final_covariates = None if config.extract_claims.enabled and await context.output_table_provider.has( "covariates" ): - final_covariates = await context.output_table_provider.read_dataframe( - "covariates" - ) + final_covariates = await reader.covariates() output = create_final_text_units( text_units, diff --git a/packages/graphrag/graphrag/index/workflows/extract_covariates.py b/packages/graphrag/graphrag/index/workflows/extract_covariates.py index f27d8590d1..a98e8a8851 100644 --- a/packages/graphrag/graphrag/index/workflows/extract_covariates.py +++ b/packages/graphrag/graphrag/index/workflows/extract_covariates.py @@ -15,6 +15,7 @@ from graphrag.config.defaults import DEFAULT_ENTITY_TYPES from graphrag.config.enums import AsyncType from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.data_model.data_reader import DataReader from graphrag.data_model.schemas import COVARIATES_FINAL_COLUMNS from graphrag.index.operations.extract_covariates.extract_covariates import ( extract_covariates as extractor, @@ -36,7 +37,8 @@ async def run_workflow( logger.info("Workflow started: extract_covariates") output = None if config.extract_claims.enabled: - text_units = await context.output_table_provider.read_dataframe("text_units") + reader = DataReader(context.output_table_provider) + text_units = await reader.text_units() model_config = config.get_completion_model_config( config.extract_claims.completion_model_id diff --git a/packages/graphrag/graphrag/index/workflows/extract_graph.py b/packages/graphrag/graphrag/index/workflows/extract_graph.py index 237bbe16cc..dc86b180fc 100644 --- a/packages/graphrag/graphrag/index/workflows/extract_graph.py +++ b/packages/graphrag/graphrag/index/workflows/extract_graph.py @@ -13,6 +13,7 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.enums import AsyncType from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.data_model.data_reader import DataReader from graphrag.index.operations.extract_graph.extract_graph import ( extract_graph as extractor, ) @@ -34,7 +35,8 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to create the base entity graph.""" logger.info("Workflow started: extract_graph") - text_units = await context.output_table_provider.read_dataframe("text_units") + reader = DataReader(context.output_table_provider) + text_units = await reader.text_units() extraction_model_config = config.get_completion_model_config( config.extract_graph.completion_model_id diff --git a/packages/graphrag/graphrag/index/workflows/extract_graph_nlp.py b/packages/graphrag/graphrag/index/workflows/extract_graph_nlp.py index c0cd069ac6..3bd51f9026 100644 --- a/packages/graphrag/graphrag/index/workflows/extract_graph_nlp.py +++ b/packages/graphrag/graphrag/index/workflows/extract_graph_nlp.py @@ -10,6 +10,7 @@ from graphrag.config.enums import AsyncType from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.data_model.data_reader import DataReader from graphrag.index.operations.build_noun_graph.build_noun_graph import build_noun_graph from graphrag.index.operations.build_noun_graph.np_extractors.base import ( BaseNounPhraseExtractor, @@ -29,7 +30,8 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to create the base entity graph.""" logger.info("Workflow started: extract_graph_nlp") - text_units = await context.output_table_provider.read_dataframe("text_units") + reader = DataReader(context.output_table_provider) + text_units = await reader.text_units() text_analyzer_config = config.extract_graph_nlp.text_analyzer text_analyzer = create_noun_phrase_extractor(text_analyzer_config) diff --git a/packages/graphrag/graphrag/index/workflows/finalize_graph.py b/packages/graphrag/graphrag/index/workflows/finalize_graph.py index 64029a8cb6..31fc9fddd4 100644 --- a/packages/graphrag/graphrag/index/workflows/finalize_graph.py +++ b/packages/graphrag/graphrag/index/workflows/finalize_graph.py @@ -8,6 +8,7 @@ import pandas as pd from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.data_model.data_reader import DataReader from graphrag.index.operations.create_graph import create_graph from graphrag.index.operations.finalize_entities import finalize_entities from graphrag.index.operations.finalize_relationships import finalize_relationships @@ -24,8 +25,9 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to create the base entity graph.""" logger.info("Workflow started: finalize_graph") - entities = await context.output_table_provider.read_dataframe("entities") - relationships = await context.output_table_provider.read_dataframe("relationships") + reader = DataReader(context.output_table_provider) + entities = await reader.entities() + relationships = await reader.relationships() final_entities, final_relationships = finalize_graph( entities, diff --git a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py index c1e42969ee..ebce58b914 100644 --- a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py +++ b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py @@ -22,6 +22,7 @@ text_unit_text_embedding, ) from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.data_model.data_reader import DataReader from graphrag.index.operations.embed_text.embed_text import embed_text from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput @@ -40,17 +41,16 @@ async def run_workflow( logger.info("Workflow started: generate_text_embeddings") embedded_fields = config.embed_text.names logger.info("Embedding the following fields: %s", embedded_fields) + reader = DataReader(context.output_table_provider) text_units = None entities = None community_reports = None if text_unit_text_embedding in embedded_fields: - text_units = await context.output_table_provider.read_dataframe("text_units") + text_units = await reader.text_units() if entity_description_embedding in embedded_fields: - entities = await context.output_table_provider.read_dataframe("entities") + entities = await reader.entities() if community_full_content_embedding in embedded_fields: - community_reports = await context.output_table_provider.read_dataframe( - "community_reports" - ) + community_reports = await reader.community_reports() model_config = config.get_embedding_model_config( config.embed_text.embedding_model_id diff --git a/packages/graphrag/graphrag/index/workflows/prune_graph.py b/packages/graphrag/graphrag/index/workflows/prune_graph.py index 483c9b18b3..f3720fd1ee 100644 --- a/packages/graphrag/graphrag/index/workflows/prune_graph.py +++ b/packages/graphrag/graphrag/index/workflows/prune_graph.py @@ -9,6 +9,7 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.config.models.prune_graph_config import PruneGraphConfig +from graphrag.data_model.data_reader import DataReader from graphrag.index.operations.create_graph import create_graph from graphrag.index.operations.graph_to_dataframes import graph_to_dataframes from graphrag.index.operations.prune_graph import prune_graph as prune_graph_operation @@ -24,8 +25,9 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to create the base entity graph.""" logger.info("Workflow started: prune_graph") - entities = await context.output_table_provider.read_dataframe("entities") - relationships = await context.output_table_provider.read_dataframe("relationships") + reader = DataReader(context.output_table_provider) + entities = await reader.entities() + relationships = await reader.relationships() pruned_entities, pruned_relationships = prune_graph( entities, diff --git a/packages/graphrag/graphrag/index/workflows/update_communities.py b/packages/graphrag/graphrag/index/workflows/update_communities.py index 7887706a86..170b5a81c8 100644 --- a/packages/graphrag/graphrag/index/workflows/update_communities.py +++ b/packages/graphrag/graphrag/index/workflows/update_communities.py @@ -8,6 +8,7 @@ from graphrag_storage.tables.table_provider import TableProvider from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.data_model.data_reader import DataReader from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput @@ -42,8 +43,8 @@ async def _update_communities( output_table_provider: TableProvider, ) -> dict: """Update the communities output.""" - old_communities = await previous_table_provider.read_dataframe("communities") - delta_communities = await delta_table_provider.read_dataframe("communities") + old_communities = await DataReader(previous_table_provider).communities() + delta_communities = await DataReader(delta_table_provider).communities() merged_communities, community_id_mapping = _update_and_merge_communities( old_communities, delta_communities ) diff --git a/packages/graphrag/graphrag/index/workflows/update_community_reports.py b/packages/graphrag/graphrag/index/workflows/update_community_reports.py index 9c9b0f2fec..29a2508031 100644 --- a/packages/graphrag/graphrag/index/workflows/update_community_reports.py +++ b/packages/graphrag/graphrag/index/workflows/update_community_reports.py @@ -9,6 +9,7 @@ from graphrag_storage.tables.table_provider import TableProvider from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.data_model.data_reader import DataReader from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput @@ -51,12 +52,10 @@ async def _update_community_reports( community_id_mapping: dict, ) -> pd.DataFrame: """Update the community reports output.""" - old_community_reports = await previous_table_provider.read_dataframe( - "community_reports" - ) - delta_community_reports = await delta_table_provider.read_dataframe( - "community_reports" - ) + old_community_reports = await DataReader( + previous_table_provider + ).community_reports() + delta_community_reports = await DataReader(delta_table_provider).community_reports() merged_community_reports = _update_and_merge_community_reports( old_community_reports, delta_community_reports, community_id_mapping ) diff --git a/packages/graphrag/graphrag/index/workflows/update_covariates.py b/packages/graphrag/graphrag/index/workflows/update_covariates.py index 2a79b52d74..650aeee987 100644 --- a/packages/graphrag/graphrag/index/workflows/update_covariates.py +++ b/packages/graphrag/graphrag/index/workflows/update_covariates.py @@ -10,6 +10,7 @@ from graphrag_storage.tables.table_provider import TableProvider from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.data_model.data_reader import DataReader from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput @@ -45,8 +46,8 @@ async def _update_covariates( output_table_provider: TableProvider, ) -> None: """Update the covariates output.""" - old_covariates = await previous_table_provider.read_dataframe("covariates") - delta_covariates = await delta_table_provider.read_dataframe("covariates") + old_covariates = await DataReader(previous_table_provider).covariates() + delta_covariates = await DataReader(delta_table_provider).covariates() merged_covariates = _merge_covariates(old_covariates, delta_covariates) await output_table_provider.write_dataframe("covariates", merged_covariates) diff --git a/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py b/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py index c7d1bcc416..6c3937a99b 100644 --- a/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py +++ b/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py @@ -13,6 +13,7 @@ from graphrag.cache.cache_key_creator import cache_key_creator from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.data_model.data_reader import DataReader from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput @@ -63,16 +64,16 @@ async def _update_entities_and_relationships( callbacks: WorkflowCallbacks, ) -> tuple[pd.DataFrame, pd.DataFrame, dict]: """Update Final Entities and Relationships output.""" - old_entities = await previous_table_provider.read_dataframe("entities") - delta_entities = await delta_table_provider.read_dataframe("entities") + old_entities = await DataReader(previous_table_provider).entities() + delta_entities = await DataReader(delta_table_provider).entities() merged_entities_df, entity_id_mapping = _group_and_resolve_entities( old_entities, delta_entities ) # Update Relationships - old_relationships = await previous_table_provider.read_dataframe("relationships") - delta_relationships = await delta_table_provider.read_dataframe("relationships") + old_relationships = await DataReader(previous_table_provider).relationships() + delta_relationships = await DataReader(delta_table_provider).relationships() merged_relationships_df = _update_and_merge_relationships( old_relationships, delta_relationships, diff --git a/packages/graphrag/graphrag/index/workflows/update_text_units.py b/packages/graphrag/graphrag/index/workflows/update_text_units.py index 02592b8aa4..c5095c27d8 100644 --- a/packages/graphrag/graphrag/index/workflows/update_text_units.py +++ b/packages/graphrag/graphrag/index/workflows/update_text_units.py @@ -10,6 +10,7 @@ from graphrag_storage.tables.table_provider import TableProvider from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.data_model.data_reader import DataReader from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput @@ -48,8 +49,8 @@ async def _update_text_units( entity_id_mapping: dict, ) -> pd.DataFrame: """Update the text units output.""" - old_text_units = await previous_table_provider.read_dataframe("text_units") - delta_text_units = await delta_table_provider.read_dataframe("text_units") + old_text_units = await DataReader(previous_table_provider).text_units() + delta_text_units = await DataReader(delta_table_provider).text_units() merged_text_units = _update_and_merge_text_units( old_text_units, delta_text_units, entity_id_mapping )