diff --git a/.gitignore b/.gitignore index b36d61cf..d2641bba 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,4 @@ venv heraenv .venv .python-version +hera/tests/env.template diff --git a/hera/bin/hera-project b/hera/bin/hera-project index aacbaece..fe6e9e8e 100755 --- a/hera/bin/hera-project +++ b/hera/bin/hera-project @@ -137,25 +137,133 @@ if __name__ == "__main__": help='overwrite the existing workflow with the same name') repository_load.set_defaults(func=CLI.repository_load) - # ----------------- NEW: Project Measurements (display helper) - # Adds: hera-project project measurements list [--project ...] [--type ...] [--contains ...] - measurements_parser = project_subparsers.add_parser('measurements', help='Measurements commands') + # ----------------- Project Measurements (display helper) ----------------- + # Adds: + # hera-project project measurements list --project ... [--type ...] [--contains ...] + # hera-project project measurements list --project ... --shortcut ds|exp|sim|cache|all + measurements_parser = project_subparsers.add_parser( + 'measurements', + help='Measurements commands (inspect measurement documents in a project)' + ) measurements_sub = measurements_parser.add_subparsers(help='Measurements sub-commands') # project measurements list - meas_list = measurements_sub.add_parser('list', help='List project measurements') - # --project is optional: if your Project() can infer from CWD; otherwise pass it explicitly - meas_list.add_argument('--project', required=False, help='Project name (optional if auto-detected)') - # Filter by 'type' (e.g., ToolkitDataSource, Experiment_rawData) - meas_list.add_argument('--type', required=False, help="Filter by 'type' field") - # Substring filter on datasourceName/resource - meas_list.add_argument('--contains', required=False, help='Substring filter on datasourceName or resource') + meas_list = measurements_sub.add_parser( + 'list', + help='List project measurements (with filters or shortcut groups)' + ) + + # --project is optional if your Project() can infer from CWD + meas_list.add_argument( + '--project', + required=False, + help='Project name (optional if auto-detected)' + ) + + # Explicit filter by type (kept for backwards compatibility) + meas_list.add_argument( + '--type', + required=False, + help="Filter by 'type' field (e.g. ToolkitDataSource, Experiment_rawData, Simulations, Cache)" + ) + + # New: shortcut between common groups: ds/exp/sim/cache/all + meas_list.add_argument( + '--shortcut', + required=False, + choices=['ds', 'exp', 'sim', 'cache', 'all'], + help=( + "Shortcut for common groups:\n" + " ds = ToolkitDataSource (dynamic toolkits)\n" + " exp = Experiment_rawData (experiments)\n" + " sim = Simulations (Simulation_* types)\n" + " cache = Cache (Cache_* types)\n" + " all = all of the above" + ) + ) + + # Optional substring filter on datasourceName/resource + meas_list.add_argument( + '--contains', + required=False, + help='Substring filter on datasourceName or resource' + ) + meas_list.set_defaults(func=CLI.project_measurements_list) - # ----------------- Exec + # ----------------- Convenience alias: project simulations ----------------- + # hera-project project simulations list --project ... [--contains ...] 
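# --- Usage sketch (not part of the patch): how the alias wiring dispatches. ---
# A standalone, runnable sketch of the set_defaults() pattern this script uses;
# the handler below is an illustrative stand-in for CLI.project_measurements_list.
import argparse

def project_measurements_list(args):
    print(f"project={args.project} shortcut={args.shortcut} contains={args.contains}")

parser = argparse.ArgumentParser(prog="hera-project")
subparsers = parser.add_subparsers(help="sub-commands")

sims = subparsers.add_parser("simulations")
sims_sub = sims.add_subparsers(help="Simulations sub-commands")
sims_list = sims_sub.add_parser("list")
sims_list.add_argument("--project", required=False)
sims_list.add_argument("--contains", required=False)
sims_list.set_defaults(func=project_measurements_list, shortcut="sim", type=None)

args = parser.parse_args(["simulations", "list", "--project", "demo"])
func = getattr(args, "func", None)
if isinstance(func, staticmethod):  # defensive unwrap, mirroring the script above
    func = func.__func__
if func is None:
    parser.print_help()
else:
    func(args)  # prints: project=demo shortcut=sim contains=None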
+ sims_parser = project_subparsers.add_parser( + 'simulations', + help='Simulation documents (alias of "measurements list --shortcut sim")' + ) + sims_sub = sims_parser.add_subparsers(help='Simulations sub-commands') + + sims_list = sims_sub.add_parser( + 'list', + help='List simulation documents in the project' + ) + sims_list.add_argument( + '--project', + required=False, + help='Project name (optional if auto-detected)' + ) + sims_list.add_argument( + '--contains', + required=False, + help='Substring filter on datasourceName or resource' + ) + # Reuse the same handler, pre-setting shortcut="sim" + sims_list.set_defaults( + func=CLI.project_measurements_list, + shortcut='sim', + type=None, + ) + + # ----------------- Convenience alias: project cache ----------------- + # hera-project project cache list --project ... [--contains ...] + cache_parser = project_subparsers.add_parser( + 'cache', + help='Cache documents (alias of "measurements list --shortcut cache")' + ) + cache_sub = cache_parser.add_subparsers(help='Cache sub-commands') + + cache_list = cache_sub.add_parser( + 'list', + help='List cache documents in the project' + ) + cache_list.add_argument( + '--project', + required=False, + help='Project name (optional if auto-detected)' + ) + cache_list.add_argument( + '--contains', + required=False, + help='Substring filter on datasourceName or resource' + ) + # Reuse the same handler, pre-setting shortcut="cache" + cache_list.set_defaults( + func=CLI.project_measurements_list, + shortcut='cache', + type=None, + ) + + + # ----------------- Exec ----------------- parsed = parser.parse_args() logger.debug(f"Got {parsed} in the command line") - if 'func' not in parsed: + + # If no sub-command was selected – print help and exit + func = getattr(parsed, "func", None) + if func is None: parser.print_help() else: - parsed.func(parsed) + # argparse stored whatever we passed in set_defaults(...). + # In some setups this might be a `staticmethod` object, + # so we unwrap it before calling. + if isinstance(func, staticmethod): + func = func.__func__ + + # Now call the underlying function with `parsed` namespace + func(parsed) diff --git a/hera/measurements/experiment/experiment.py b/hera/measurements/experiment/experiment.py index 75e4ae49..dea1497f 100644 --- a/hera/measurements/experiment/experiment.py +++ b/hera/measurements/experiment/experiment.py @@ -1,130 +1,165 @@ import os -import pydoc import sys +import logging +import pydoc +import pandas as pd + from hera import toolkit from .presentation import experimentPresentation from .analysis import experimentAnalysis -from hera.measurements.GIS.utils import WSG84,ITM,convertCRS -import pandas as pd +from hera.measurements.GIS.utils import WSG84, ITM, convertCRS try: from argos.experimentSetup import dataObjects as argosDataObjects except ImportError: - print("Must have argos installed and in the path. ") + # Argos is optional; if it is not installed, experiment toolkits cannot be used. + print("Must have argos installed and in the path.") -from .dataEngine import dataEngineFactory, PARQUETHERA, PANDASDB,DASKDB -from hera.utils import loadJSON -import logging +from .dataEngine import dataEngineFactory, PARQUETHERA, PANDASDB, DASKDB +from hera.utils import loadJSON -# The name of the property. This is has to be similar ot the from the argosweb interface. -# Dont change! -TRIALSTART = 'TrialStart' -TRIALEND = 'TrialEnd' +# The name of the properties. These must match the Argos web interface. +# Do not change unless the Argos schema changes. 
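# --- Sketch (not part of the patch): a more defensive optional-import guard. ---
# The try/except above only prints a warning; a variant that records availability
# in a flag lets later code fail with a clear error instead of a NameError.
# HAS_ARGOS and requireArgos() are illustrations, not part of the module.
try:
    from argos.experimentSetup import dataObjects as argosDataObjects
    HAS_ARGOS = True
except ImportError:
    argosDataObjects = None
    HAS_ARGOS = False
    print("Must have argos installed and in the path.")

def requireArgos():
    if not HAS_ARGOS:
        raise ImportError("argos is required to use the experiment toolkits")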
+TRIALSTART = "TrialStart" +TRIALEND = "TrialEnd" class experimentHome(toolkit.abstractToolkit): """ - This is the object that function as a factory/home to the other experiments. - It is responsible for getting the right toolkit for the requested experiment. - + This object functions as a factory/home for the other experiments. + It is responsible for getting the right toolkit for the requested experiment. """ - DOCTYPE_ENTITIES = 'EntitiesData' - CODE_DIRECTORY = 'code' + DOCTYPE_ENTITIES = "EntitiesData" + CODE_DIRECTORY = "code" def __init__(self, projectName, filesDirectory=None): - super().__init__(projectName=projectName, toolkitName="experimentToolKit", filesDirectory=filesDirectory) + super().__init__( + projectName=projectName, + toolkitName="experimentToolKit", + filesDirectory=filesDirectory, + ) self.logger = logging.getLogger() self.logger.info("Init experiment toolkit") - @property def experimentMap(self): - return self.experimentMap() + """ + Backward-compatibility alias. + Historically this was a property, so we keep the interface, + even though today the real logic is in getExperimentsMap(). + """ + return self.getExperimentsMap() - # def getExperimentsMap(self): """ - Get dictionary of experiments map of project. + Get a dictionary mapping experiment name -> datasource document. Returns ------- - dict + dict + Keys are experiment names (datasourceName), + values are the matching datasource documents. """ - M=dict() + experiments_map = {} for experiment in self.getDataSourceMap(): - experimentName=experiment['datasourceName'] - M[experimentName]=experiment - - return M + experimentName = experiment["datasourceName"] + experiments_map[experimentName] = experiment + return experiments_map @property def experimentsTable(self): + """ + Return a tabular view (DataFrame-like) of experiment datasources. + """ return self.getDataSourceTable() def getExperimentsTable(self): + """ + Backward-compatible alias for experimentsTable. + """ return self.getDataSourceTable() - def getExperiment(self,experimentName,filesDirectory=None): + def getExperiment(self, experimentName, filesDirectory=None): """ Get the specific experiment class. Parameters ---------- experimentName : str - The name of the experimen - filesDirectory: str - The directory to save the cache/intermediate files. - If None, use the [current directory]/experimentCache. + The name of the experiment. + filesDirectory : str, optional + The directory to save cache/intermediate files. + If None, uses the current working directory / 'experimentCache'. Returns ------- - experimentSetupWithData + experimentSetupWithData """ - self.logger.info(f"Getting experiment {experimentName}") - L = self.getDataSourceDocument(datasourceName=experimentName) - if L: - self.logger.info(f"Found experiment. Loading") - experimentPath=L.getData() - sys.path.append(os.path.join(experimentPath,self.CODE_DIRECTORY)) - self.logger.debug(f"Adding path {os.path.join(experimentPath,self.CODE_DIRECTORY)} to classpath") + ds_doc = self.getDataSourceDocument(datasourceName=experimentName) + + if ds_doc: + self.logger.info("Found experiment. Loading") + experimentPath = ds_doc.getData() + + # Add experiment's 'code' directory to sys.path so we can import its toolkit. 
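# --- Sketch (not part of the patch): the dynamic-import convention used below. ---
# getExperiment() assumes the experiment unpacks as <experimentPath>/code/<name>/<name>.py
# defining a class <name>; all paths and names here are hypothetical.
import os
import sys
import pydoc

experimentPath = "/data/experiments/myExperiment"
sys.path.append(os.path.join(experimentPath, "code"))

toolkitCls = pydoc.locate("myExperiment.myExperiment")  # "<module>.<Class>"
if toolkitCls is None:
    print("toolkit not found (expected unless the layout above really exists)")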
+ sys.path.append(os.path.join(experimentPath, self.CODE_DIRECTORY)) + self.logger.debug( + f"Adding path {os.path.join(experimentPath, self.CODE_DIRECTORY)} to sys.path" + ) + toolkitName = f"{experimentName}.{experimentName}" - self.logger.debug(f"Loading toolkits: {toolkitName}") + self.logger.debug(f"Loading toolkit: {toolkitName}") toolkitCls = pydoc.locate(toolkitName) if toolkitCls is None: - err = f"Cannot find toolkit {toolkitName} in {os.path.join(experimentPath,self.CODE_DIRECTORY)}" + err = ( + f"Cannot find toolkit {toolkitName} in " + f"{os.path.join(experimentPath, self.CODE_DIRECTORY)}" + ) self.logger.error(err) raise ValueError(err) - return toolkitCls(projectName=self.projectName, - pathToExperiment=experimentPath,filesDirectory=filesDirectory) - else: - err = f"Experiment {experimentName} not found in Project {self.projectName}. Please load the experiment to the project. " - self.logger.error(err) - raise ValueError(err) + return toolkitCls( + projectName=self.projectName, + pathToExperiment=experimentPath, + filesDirectory=filesDirectory, + ) + err = ( + f"Experiment {experimentName} not found in Project {self.projectName}. " + f"Please load the experiment to the project." + ) + self.logger.error(err) + raise ValueError(err) def keys(self): """ - Get the experiments names of project. + Get the experiment names of the project. Returns ------- - list + list of str """ - return [x for x in self.getExperimentsMap()] + return [name for name in self.getExperimentsMap()] def __getitem__(self, item): + """ + Allow experimentHome['expName'] syntax to return a specific experiment. + """ return self.getExperiment(item) def experimentDataType(self): - return self._experimentDataType + """ + Backward-compatibility hook for experiment data type. + """ + return getattr(self, "_experimentDataType", None) -class experimentSetupWithData(argosDataObjects.ExperimentZipFile,toolkit.abstractToolkit): + +class experimentSetupWithData(argosDataObjects.ExperimentZipFile, toolkit.abstractToolkit): """ - A class that unifies the argos.experiment setup with the data. + A class that unifies the argos.experiment setup with the data. """ _configuration = None @@ -148,73 +183,101 @@ def configuration(self): @property def name(self): - return self.configuration['experimentName'] + return self.configuration["experimentName"] def _initTrialSets(self): + """ + Initialize trial sets from the experiment setup metadata. + """ experimentSetup = self.setup - for trialset in experimentSetup['trialSets']: - self.trialSet[trialset['name']] = TrialSetWithData(experiment = self, TrialSetSetup=trialset,experimentData= self._experimentData) + for trialset in experimentSetup["trialSets"]: + self.trialSet[trialset["name"]] = TrialSetWithData( + experiment=self, + TrialSetSetup=trialset, + experimentData=self._experimentData, + ) def _initEntitiesTypes(self): + """ + Initialize entity types from the experiment setup metadata. + """ experimentSetup = self.setup - for entityType in experimentSetup['entityTypes']: - self.entityType[entityType['name']] = EntityTypeWithData(experiment=self, metadata = entityType, experimentData= self._experimentData) + for entityType in experimentSetup["entityTypes"]: + self.entityType[entityType["name"]] = EntityTypeWithData( + experiment=self, + metadata=entityType, + experimentData=self._experimentData, + ) def getExperimentData(self): """ - Get the parquet Data Engine of experiment. Acessing data of experiment is through this class (using .getData()). 
+ Get the Data Engine of the experiment. + Accessing experiment data is done through this object (using .getData()). Returns ------- - parquetDataEngineHera , pandasDataEngineDB or daskDataEngineDB. + parquetDataEngineHera or pandasDataEngineDB or daskDataEngineDB """ return self._experimentData - def __init__(self, projectName, pathToExperiment, dataType=PARQUETHERA, dataSourceConfiguration=dict(), filesDirectory=None,defaultTrialSetName=None): + def __init__( + self, + projectName, + pathToExperiment, + dataType=PARQUETHERA, + dataSourceConfiguration=dict(), + filesDirectory=None, + defaultTrialSetName=None, + ): """ - Initializes the specific experiment toolkit. + Initialize the specific experiment toolkit. Parameters ---------- - projectName: str - The project name to work with. - - pathToExperiment: - The path to the experiment data. - - dataType: str - Define how the data is retrieved: dask or pandas directly from the mongoDB, or through the - parquet. - + projectName : str + The project name to work with. + pathToExperiment : str + The path to the experiment data. + dataType : str + How data is retrieved: dask/pandas from MongoDB, or parquet. dataSourceConfiguration : dict - overwrite the datasources configuration of the experiment. - See ... for structure. - - filesDirectory: str - The directory to save the cache/intermediate files. - If None, use the [current directory]/experimentCache. - - defaultTrialSet: str - A default trialset to use if not supplied. - + Override datasources configuration of the experiment. + filesDirectory : str, optional + Directory to save cache/intermediate files. + If None, uses [current directory]/experimentCache. + defaultTrialSetName : str, optional + Default trial set to use if not supplied. """ - # setup the configuration file name - configurationFileName = os.path.join(pathToExperiment, 'runtimeExperimentData', "Datasources_Configurations.json") + # Locate configuration file + configurationFileName = os.path.join( + pathToExperiment, "runtimeExperimentData", "Datasources_Configurations.json" + ) if not os.path.isfile(configurationFileName): - raise ValueError(f" The configuration file doesn't exist. Looking for {configurationFileName}") + raise ValueError( + f"The configuration file does not exist. Looking for {configurationFileName}" + ) self._configuration = loadJSON(configurationFileName) dataSourceConfiguration = dict() if dataSourceConfiguration is None else dataSourceConfiguration self._configuration.update(dataSourceConfiguration) - experimentName = self.configuration['experimentName'] - setupFile = os.path.join(pathToExperiment, 'runtimeExperimentData', f"{experimentName}.zip" ) + experimentName = self.configuration["experimentName"] + setupFile = os.path.join( + pathToExperiment, "runtimeExperimentData", f"{experimentName}.zip" + ) if not os.path.isfile(setupFile): - raise ValueError(f"The experiment setup file doesn't exist. Looking for {setupFile} ") - - # Now initialize the data engine. - self._experimentData = dataEngineFactory().getDataEngine(projectName,self._configuration,experimentObj=self, dataType = dataType) + raise ValueError( + f"The experiment setup file does not exist. 
Looking for {setupFile}" + ) + + # Initialize the data engine + self._experimentData = dataEngineFactory().getDataEngine( + projectName, + self._configuration, + experimentObj=self, + dataType=dataType, + ) self.entityType = dict() self.trialSet = dict() @@ -222,14 +285,19 @@ def __init__(self, projectName, pathToExperiment, dataType=PARQUETHERA, dataSour filesDirectory = os.getcwd() cacheDir = os.path.join(filesDirectory, "experimentCache") - os.makedirs(cacheDir,exist_ok=True) + os.makedirs(cacheDir, exist_ok=True) - argosDataObjects.ExperimentZipFile.__init__(self,setupFile) - toolkit.abstractToolkit.__init__(self,projectName=projectName,toolkitName=f"{experimentName}Toolkit",filesDirectory=cacheDir) + argosDataObjects.ExperimentZipFile.__init__(self, setupFile) + toolkit.abstractToolkit.__init__( + self, + projectName=projectName, + toolkitName=f"{experimentName}Toolkit", + filesDirectory=cacheDir, + ) self._defaultTrialSetName = defaultTrialSetName - self._analysis = experimentAnalysis(self,) - self._presentation = experimentPresentation(self,self.analysis) + self._analysis = experimentAnalysis(self) + self._presentation = experimentPresentation(self, self.analysis) @property def defaultTrialSet(self): @@ -239,168 +307,275 @@ def defaultTrialSet(self): def trialsOfDefaultTrialSet(self): return self.trialSet[self.defaultTrialSet] - def _initAnalysisAndPresentation(self,analysisCLS,presentationCLS): + def _initAnalysisAndPresentation(self, analysisCLS, presentationCLS): """ - Initializes the analysis and the presentation classes - and sets the datalayer. + Initialize the analysis and presentation classes and set the data layer. Parameters ---------- - analysisCLS : class - The analysis class. It is recommended that it will inherit from - .analysis.experimentAnalysis - + analysisCLS : class + The analysis class, recommended to inherit from .analysis.experimentAnalysis. presentationCLS : class - The presentation class. It is recommended that it will inherit from - .presentation.experimentPresentation + The presentation class, recommended to inherit from .presentation.experimentPresentation. + """ + self._analysis = analysisCLS(self) + self._presentation = presentationCLS(self, self._analysis) + + def getDataFromDateRange( + self, + deviceType, + startTime, + endTime, + deviceName=None, + withMetadata=True, + ): + """ + Retrieve data for a given device type and time range. 
+ Parameters + ---------- + deviceType : str + startTime : datetime-like + endTime : datetime-like + deviceName : str, optional + withMetadata : bool Returns ------- - + DataFrame """ - self._analysis = analysisCLS(self) - self._presentation = presentationCLS(self,self._analysis) - - def getDataFromDateRange(self,deviceType,startTime , endTime ,deviceName = None,withMetadata = True): - data = self._experimentData.getData(deviceType=deviceType,deviceName=deviceName,startTime=startTime,endTime=endTime) + data = self._experimentData.getData( + deviceType=deviceType, + deviceName=deviceName, + startTime=startTime, + endTime=endTime, + ) if len(data) == 0: - raise ValueError(f"There is no data for {deviceType} between the dates {startTime} and {endTime}") + raise ValueError( + f"There is no data for {deviceType} between the dates {startTime} and {endTime}" + ) if withMetadata: devicemetadata = self.entitiesTable() if len(devicemetadata) > 0: - data = data.reset_index().merge(devicemetadata, left_on="deviceName", right_on="entityName").set_index( - "timestamp") + data = ( + data.reset_index() + .merge( + devicemetadata, + left_on="deviceName", + right_on="entityName", + ) + .set_index("timestamp") + ) return data - def _process_row(self,row): + def _process_row(self, row): + """ + Helper for coordinate conversion of a single row (Longitude, Latitude). + """ pp = convertCRS([[row.Longitude, row.Latitude]], inputCRS=WSG84, outputCRS=ITM) return pd.Series([pp.x[0], pp.y[0]]) - def get_devices_image_coordinates(self,trialSetName,trialName,deviceType,outputCRS=ITM): - devices_df = self.trialSet[trialSetName][trialName].entitiesTable.query("deviceTypeName==@deviceType") + def get_devices_image_coordinates( + self, + trialSetName, + trialName, + deviceType, + outputCRS=ITM, + ): + """ + Compute bounding box of devices in image coordinates (ITM or original). - if outputCRS==ITM: - devices_df[['ITM_Latitude', 'ITM_Longitude']] = devices_df.apply(self._process_row, axis=1) - latitudes = devices_df['ITM_Latitude'] - longitudes = devices_df['ITM_Longitude'] + Returns + ------- + (min_latitude, min_longitude, max_latitude, max_longitude) + """ + devices_df = self.trialSet[trialSetName][trialName].entitiesTable.query( + "deviceTypeName==@deviceType" + ) + + if outputCRS == ITM: + devices_df[["ITM_Latitude", "ITM_Longitude"]] = devices_df.apply( + self._process_row, axis=1 + ) + latitudes = devices_df["ITM_Latitude"] + longitudes = devices_df["ITM_Longitude"] else: - latitudes = devices_df['Latitude'] - longitudes = devices_df['Longitude'] + latitudes = devices_df["Latitude"] + longitudes = devices_df["Longitude"] + min_latitude, max_latitude = min(latitudes), max(latitudes) min_longitude, max_longitude = min(longitudes), max(longitudes) - return min_latitude,min_longitude,max_latitude,max_longitude + return min_latitude, min_longitude, max_latitude, max_longitude + class TrialSetWithData(argosDataObjects.TrialSet): + """ + TrialSet that is aware of the experiment data engine. 
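# --- Sketch (not part of the patch): the metadata-join pattern used above. ---
# getDataFromDateRange(withMetadata=True) joins readings with the entities table
# via reset_index/merge/set_index; the frames below are synthetic.
import pandas as pd

data = pd.DataFrame(
    {
        "timestamp": pd.to_datetime(["2024-01-01 00:00", "2024-01-01 00:01"]),
        "deviceName": ["sonic01", "sonic01"],
        "u": [1.2, 1.4],
    }
).set_index("timestamp")

devicemetadata = pd.DataFrame(
    {"entityName": ["sonic01"], "Latitude": [32.1], "Longitude": [34.8]}
)

merged = (
    data.reset_index()
    .merge(devicemetadata, left_on="deviceName", right_on="entityName")
    .set_index("timestamp")
)
print(merged)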
+ """ def _initTrials(self): - for trial in self._metadata['trials']: - self[trial['name']] = TrialWithdata(trialSet=self,metadata=trial, experimentData =self._experimentData ) - - def __init__(self, experiment:experimentSetupWithData, TrialSetSetup: dict, experimentData: dataEngineFactory): + for trial in self._metadata["trials"]: + self[trial["name"]] = TrialWithdata( + trialSet=self, + metadata=trial, + experimentData=self._experimentData, + ) + + def __init__(self, experiment: experimentSetupWithData, TrialSetSetup: dict, experimentData: dataEngineFactory): """ - The initialization of the experiment. - - The object that handles the retrieval of the data is different to support - access to db, and pandas in the different stages of the experiment. - - Parameters - ---------- - experiment : the data of the experiment. - TrialSetSetup : The data of the trials. - experimentData : a link to the object that handles the retrieval of the data. - + Initialize the TrialSet with a link to the shared experiment data engine. """ self._experimentData = experimentData super().__init__(experiment, TrialSetSetup) class TrialWithdata(argosDataObjects.Trial): + """ + Trial object that knows how to pull its data from the shared experiment data engine. + """ - def getData(self,deviceType,deviceName = None,startTime = None, endTime = None,withMetadata = False): - + def getData( + self, + deviceType, + deviceName=None, + startTime=None, + endTime=None, + withMetadata=False, + ): + """ + Retrieve trial data for a given device type and time range. + If startTime/endTime are not provided, the trial properties TRIALSTART/TRIALEND are used. + """ startTime = self.properties[TRIALSTART] if startTime is None else startTime endTime = self.properties[TRIALEND] if endTime is None else endTime - data = self._experimentData.getData(deviceType=deviceType,deviceName=deviceName,startTime=startTime,endTime=endTime) + data = self._experimentData.getData( + deviceType=deviceType, + deviceName=deviceName, + startTime=startTime, + endTime=endTime, + ) if len(data) == 0: - raise ValueError(f"There is no data for {deviceType} between the dates {startTime} and {endTime}") + raise ValueError( + f"There is no data for {deviceType} between the dates {startTime} and {endTime}" + ) if withMetadata: devicemetadata = self.entitiesTable() if len(devicemetadata) > 0: - data = data.reset_index().merge(devicemetadata, left_on="deviceName", right_on="entityName").set_index("timestamp") + data = ( + data.reset_index() + .merge( + devicemetadata, + left_on="deviceName", + right_on="entityName", + ) + .set_index("timestamp") + ) return data - - def __init__(self, trialSet: TrialSetWithData, metadata: dict, experimentData: dataEngineFactory): + def __init__( + self, + trialSet: TrialSetWithData, + metadata: dict, + experimentData: dataEngineFactory, + ): self._experimentData = experimentData super().__init__(trialSet, metadata) class EntityTypeWithData(argosDataObjects.EntityType): + """ + EntityType that knows how to pull its data from the shared experiment data engine. 
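# --- Hedged usage sketch (not part of the patch): the trial-level API above. ---
# Requires a configured hera project and the argos package; the project,
# experiment, trial-set, trial and device-type names are hypothetical.
from hera.measurements.experiment.experiment import experimentHome

exp = experimentHome(projectName="demoProject").getExperiment("myExperiment")
trial = exp.trialSet["morningRuns"]["trial1"]

# With no explicit range, getData() falls back to the trial's
# TrialStart/TrialEnd properties:
df = trial.getData(deviceType="Sonic", withMetadata=True)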
+ """ def _initEntities(self): - for entity in self._metadata['entities']: - self[entity['name']] = EntityWithData(entityType=self, metadata=entity,experimentData =self._experimentData) - - def __init__(self, experiment:experimentSetupWithData, metadata: dict, experimentData: dataEngineFactory): + for entity in self._metadata["entities"]: + self[entity["name"]] = EntityWithData( + entityType=self, + metadata=entity, + experimentData=self._experimentData, + ) + + def __init__( + self, + experiment: experimentSetupWithData, + metadata: dict, + experimentData: dataEngineFactory, + ): """ - The iniitialization of the object with the data. - - Parameters - ---------- - experiment : The data of the experiment. - metadata - experimentData + Initialize the EntityType with a link to the shared experiment data engine. """ self._experimentData = experimentData super().__init__(experiment, metadata) def getData(self, startTime=None, endTime=None): - return self._experimentData.getData(self.name,startTime = startTime,endTime = endTime ) + """ + Retrieve all data for this entity type (optionally in a time range). + """ + return self._experimentData.getData( + self.name, + startTime=startTime, + endTime=endTime, + ) - def getDataTrial(self,trialSetName,trialName): + def getDataTrial(self, trialSetName, trialName): """ - Returns the device data from the trial. + Return the device data for this entity type in a specific trial. + Parameters ---------- - trialSetName : str - The name of the trial set - + trialSetName : str trialName : str - The name of the trial. Returns ------- - + DataFrame """ trial = self.experiment.trialSet[trialSetName][trialName] startTime = trial.properties[TRIALSTART] endTime = trial.properties[TRIALEND] - StoreDataPerDevice = self.properties['StoreDataPerDevice'] - data = self._experimentData.getData(deviceType=self.entityType,deviceName=self.name,startTime=startTime,endTime=endTime, - perDevice=StoreDataPerDevice) + StoreDataPerDevice = self.properties["StoreDataPerDevice"] + data = self._experimentData.getData( + deviceType=self.entityType, + deviceName=self.name, + startTime=startTime, + endTime=endTime, + perDevice=StoreDataPerDevice, + ) return data class EntityWithData(argosDataObjects.Entity): + """ + Entity that knows how to pull its data from the shared experiment data engine. + """ - def __init__(self, entityType: EntityTypeWithData, metadata: dict, experimentData): + def __init__( + self, + entityType: EntityTypeWithData, + metadata: dict, + experimentData, + ): self._experimentData = experimentData super().__init__(entityType, metadata) - def getData(self,startTime=None, endTime=None): - StoreDataPerDevice = self.properties['StoreDataPerDevice'] - - return self._experimentData.getData(deviceType=self.entityType, - deviceName=self.name, - startTime=startTime, - endTime=endTime, - perDevice=StoreDataPerDevice) + def getData(self, startTime=None, endTime=None): + """ + Retrieve data for this specific entity (device). + """ + StoreDataPerDevice = self.properties["StoreDataPerDevice"] + + return self._experimentData.getData( + deviceType=self.entityType, + deviceName=self.name, + startTime=startTime, + endTime=endTime, + perDevice=StoreDataPerDevice, + ) diff --git a/hera/tests/env.template b/hera/tests/env.template deleted file mode 100644 index bb7b2a9e..00000000 --- a/hera/tests/env.template +++ /dev/null @@ -1,22 +0,0 @@ -# ======================= -# Local developer settings -# ======================= - -# ---- Python ---- -# Path to your Python executable. 
-PYTHON_BIN=python3
-
-# Add the repo root so its modules are importable.
-# IMPORTANT: use default expansion so set -u won't fail if PYTHONPATH is unset.
-PYTHONPATH=/home/ilay/hera:${PYTHONPATH:-}
-
-# ---- Data roots ----
-# Root folder that contains the unit-test datasets (HGT, GeoTIFF, SHP, Parquet, etc.).
-HERA_DATA_PATH=/home/ilay/hera_unittest_data
-
-# Some tests read from this var as well. Keep equal to HERA_DATA_PATH unless you know otherwise.
-HERA_UNITTEST_DATA=/home/ilay/hera_unittest_data
-
-# ---- Test runner defaults (optional) ----
-# Named result set for expected outputs. Can be overridden by --result-set.
-RESULT_SET=BASELINE
diff --git a/hera/tests/run_all_definitions.py b/hera/tests/run_all_definitions.py
index a7ea5611..49b86489 100644
--- a/hera/tests/run_all_definitions.py
+++ b/hera/tests/run_all_definitions.py
@@ -5,6 +5,7 @@
 import glob
 import traceback
 import importlib
+import os
 from pathlib import Path
 import pandas as pd
@@ -34,6 +35,25 @@ def _require_result_set():
         sys.exit(2)
     return rs
 
+
+def _get_base_expected_dir() -> Path:
+    """
+    Resolve the base directory for expected results:
+      Prefer $HERA_UNITTEST_DATA/expected/<RESULT_SET>.
+      Fall back to tests/expected/<RESULT_SET> if HERA_UNITTEST_DATA is not set.
+      Create the directory in prepare mode.
+    """
+    rs = os.environ.get("RESULT_SET", "BASELINE")
+    env_root = os.environ.get("HERA_UNITTEST_DATA")
+    if env_root:
+        base = Path(env_root) / "expected" / rs
+    else:
+        base = Path("tests") / "expected" / rs
+
+    if os.environ.get("PREPARE_EXPECTED_OUTPUT", "0") == "1":
+        base.mkdir(parents=True, exist_ok=True)
+    return base
+
 def _expected_base_dir():
     """
     Return Path('<root>/expected/<RESULT_SET>').
diff --git a/hera/tests/run_all_json_tests.sh b/hera/tests/run_all_json_tests.sh
index 72eb4ec5..46e758c7 100755
--- a/hera/tests/run_all_json_tests.sh
+++ b/hera/tests/run_all_json_tests.sh
@@ -49,13 +49,19 @@ if [[ -z "${RESULT_SET_ARG}" && -z "${RESULT_SET:-}" ]]; then
 fi
 export RESULT_SET="${RESULT_SET_ARG:-${RESULT_SET}}"
 
-EXP_DIR="${REPO_ROOT}/tests/expected/${RESULT_SET}"
+# Expected directory must live under HERA_UNITTEST_DATA
+if [[ -z "${HERA_UNITTEST_DATA:-}" ]]; then
+  echo "HERA_UNITTEST_DATA is not set (define it in your local tests/env.template)"
+  exit 1
+fi
+EXP_DIR="${HERA_UNITTEST_DATA}/expected/${RESULT_SET}"
+
 if [[ "${MODE}" == "prepare" ]]; then
   mkdir -p "${EXP_DIR}"
   export PREPARE_EXPECTED_OUTPUT=1
 else
   if [[ ! -d "${EXP_DIR}" ]]; then
-d "${EXP_DIR}" ]]; then - echo "Expected results set '${RESULT_SET}' not found at tests/expected/${RESULT_SET}" + echo "Expected results set '${RESULT_SET}' not found at ${EXP_DIR}" echo "First create it: $0 prepare --result-set ${RESULT_SET}" exit 3 fi @@ -91,7 +97,7 @@ else exit 5 fi -echo "MODE=${MODE} | RESULT_SET=${RESULT_SET} | EXP_DIR=tests/expected/${RESULT_SET}" +echo "MODE=${MODE} | RESULT_SET=${RESULT_SET} | EXP_DIR=${EXP_DIR}" echo "RUNNER=${RUNNER}" echo "JSON_DIR=${JSON_DIR}" echo diff --git a/hera/toolkit.py b/hera/toolkit.py index 40758269..9e0071f9 100644 --- a/hera/toolkit.py +++ b/hera/toolkit.py @@ -1,13 +1,18 @@ from hera.datalayer import Project from hera.datalayer.datahandler import datatypes # for datatypes.CLASS -from hera.datalayer.datahandler import DataHandler_Class # הוסף אם לא קיים -import inspect import os -import pandas +import inspect import pydoc +import pandas as pd +from typing import Optional, List, Dict, Any + from hera.utils.logging import get_classMethod_logger +# --------------------------------------------------------------------------- +# Constants for Toolkit data sources +# --------------------------------------------------------------------------- + TOOLKIT_DATASOURCE_TYPE = "ToolkitDataSource" TOOLKIT_TOOLKITNAME_FIELD = "toolkit" TOOLKIT_DATASOURCE_NAME = "datasourceName" @@ -19,29 +24,263 @@ TOOLKIT_SAVEMODE_FILEANDDB = "DB" TOOLKIT_SAVEMODE_FILEANDDB_REPLACE = "DB_overwrite" -import pydoc -import pandas as pd -from typing import Optional -from hera.utils.data.toolkit_repository import ToolkitRepository # new import for DB integration + +# ====================================================================== +# abstractToolkit +# ====================================================================== + +class abstractToolkit(Project): + """ + Base class for Toolkits. + + * Inherits from Project – ולכן יש גישה לכל פונקציות ה־datalayer. + * מחזיק toolkitName ו־projectName. + * מוסיף מנגנון data sources שנשמרים כ-measurement documents מסוג + TOOLKIT_DATASOURCE_TYPE. + """ + + _toolkitname = None + _projectName = None + + _analysis = None # holds the datalayer layer. + _presentation = None # holds the presentation layer + + @property + def presentation(self): + """Access to the presentation layer.""" + return self._presentation + + @property + def analysis(self): + """Access to the datalayer layer.""" + return self._analysis + + @property + def toolkitName(self): + """The name of the toolkit.""" + return self._toolkitname + + @property + def projectName(self): + """The name of the project.""" + return self._projectName + + def __init__(self, toolkitName: str, projectName: Optional[str] = None, filesDirectory: Optional[str] = None): + """ + Initializes a new toolkit. + + Parameters + ---------- + toolkitName : str + The name of the toolkit. + + projectName : str or None + The project that the toolkit works in. + If None – Project's automatic project-name mechanism is used. + + filesDirectory : str + Directory to save datasource files. 
+ """ + super().__init__(projectName=projectName, filesDirectory=filesDirectory) + logger = get_classMethod_logger(self, "init") + self._toolkitname = toolkitName + self._projectName = projectName + + @property + def classLoggerName(self): + return str(get_classMethod_logger(self, "{the_function_name}")).split(" ")[1] + + # ------------------------------------------------------------------ + # Data sources API + # ------------------------------------------------------------------ + + def getDataSourceList(self, **filters) -> List[str]: + """Returns a list of data source names for this toolkit.""" + docList = self.getMeasurementsDocuments( + type=TOOLKIT_DATASOURCE_TYPE, + toolkit=self.toolkitName, + **filters, + ) + + ret = [] + for doc in docList: + ret.append(doc["desc"]["datasourceName"]) + return ret + + def getDataSourceMap(self, **filters) -> List[Dict[str, Any]]: + """ + Return list of all data sources and their versions related to this toolkit. + """ + docList = self.getMeasurementsDocuments( + type=TOOLKIT_DATASOURCE_TYPE, + toolkit=self.toolkitName, + **filters, + ) + + ret = [] + for doc in docList: + dta = dict(dataFormat=doc["dataFormat"], resource=doc["resource"]) + dta.update(doc.desc) + ret.append(dta) + return ret + + def getDataSourceTable(self, **filters) -> pd.DataFrame: + """Build a pandas DataFrame from all data sources of this toolkit.""" + tables = [] + for sourceMap in self.getDataSourceMap(**filters): + table = pd.json_normalize(sourceMap) + tables.append(table) + + if not tables: + return pd.DataFrame() + else: + return pd.concat(tables, ignore_index=True) + + def getDataSourceDocumentsList(self, **kwargs): + """ + Return all the data source documents associated with this toolkit. + """ + queryDict = { + "type": TOOLKIT_DATASOURCE_TYPE, + TOOLKIT_TOOLKITNAME_FIELD: self.toolkitName, + } + queryDict.update(**kwargs) + return self.getMeasurementsDocuments(**queryDict) + + def getDataSourceDocument(self, datasourceName: Optional[str], version=None, **filters): + """ + Return the document of the datasource. + If version is not specified, return the latest version or default version (if configured). + """ + if datasourceName is not None: + filters[TOOLKIT_DATASOURCE_NAME] = datasourceName + if version is not None: + filters[TOOLKIT_DATASOURCE_VERSION] = version + else: + try: + defaultVersion = self.getConfig()[f"{datasourceName}_defaultVersion"] + filters[TOOLKIT_DATASOURCE_VERSION] = defaultVersion + except Exception: + pass + + filters[TOOLKIT_TOOLKITNAME_FIELD] = self.toolkitName + + docList = self.getMeasurementsDocuments( + type=TOOLKIT_DATASOURCE_TYPE, + **filters, + ) + + if len(docList) == 0: + ret = None + elif len(docList) == 1: + ret = docList[0] + else: + versionsList = [doc["desc"]["version"] for doc in docList] + latestVersion = max(versionsList) + docList = [doc for doc in docList if doc["desc"]["version"] == latestVersion] + ret = docList[0] + return ret + + def getDataSourceDocuments(self, datasourceName, version=None, **filters): + """ + Returns a list with the datasource (for API symmetry with measurements/cache). + """ + doc = self.getDataSourceDocument(datasourceName=datasourceName, version=version, **filters) + return [] if doc is None else [doc] + + def getDataSourceData(self, datasourceName=None, version=None, **filters): + """ + Returns the data from the datasource (or None if not found). 
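# --- Hedged usage sketch (not part of the patch): the read-side API above. ---
# Requires a configured hera project/database; "myToolkit", "demoProject" and
# "srtmGlobal" are hypothetical names.
from hera.toolkit import abstractToolkit

tk = abstractToolkit(toolkitName="myToolkit", projectName="demoProject")

print(tk.getDataSourceList())        # datasource names only
print(tk.getDataSourceTable())       # flat pandas DataFrame view

doc = tk.getDataSourceDocument("srtmGlobal")  # latest version, or None
if doc is not None:
    data = doc.getData()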
+ """ + filters[TOOLKIT_TOOLKITNAME_FIELD] = self.toolkitName + doc = self.getDataSourceDocument(datasourceName=datasourceName, version=version, **filters) + return None if doc is None else doc.getData() + + def addDataSource( + self, + dataSourceName: str, + resource: str, + dataFormat: str, + version=(0, 0, 1), + overwrite: bool = False, + **kwargs, + ): + """ + Adds a resource to the toolkit. + The type is always TOOLKIT_DATASOURCE_TYPE. + The toolkit name is added automatically to the description. + """ + kwargs[TOOLKIT_TOOLKITNAME_FIELD] = self.toolkitName + kwargs[TOOLKIT_DATASOURCE_NAME] = dataSourceName + kwargs[TOOLKIT_DATASOURCE_VERSION] = version + + if (self.getDataSourceDocument(dataSourceName, version=version) is None) or overwrite: + if self.getDataSourceDocument(dataSourceName, version=version) is not None: + # Exists -> delete and re-add + delargs = { + TOOLKIT_DATASOURCE_NAME: dataSourceName, + TOOLKIT_DATASOURCE_VERSION: version, + } + self.deleteDataSource(**delargs) + + doc = self.addMeasurementsDocument( + type=TOOLKIT_DATASOURCE_TYPE, + resource=resource, + dataFormat=dataFormat, + desc=kwargs, + ) + else: + raise ValueError( + f"Record {dataSourceName} (version {version}) already exists in project {self.projectName}. " + f"use overwrite=True to overwrite the existing document" + ) + + return doc + + def deleteDataSource(self, datasourceName, version=None, **filters): + """Delete a data source document.""" + doc = self.getDataSourceDocument(datasourceName=datasourceName, version=version, **filters) + if doc is not None: + doc.delete() + return doc + + def setDataSourceDefaultVersion(self, datasourceName: str, version: tuple): + """Set the default version for a given data source.""" + if len( + self.getMeasurementsDocuments( + type=TOOLKIT_DATASOURCE_TYPE, + **{"datasourceName": datasourceName, "version": version}, + ) + ) == 0: + raise ValueError(f"No DataSource with name={datasourceName} and version={version}.") + + self.setConfig(**{f"{datasourceName}_defaultVersion": version}) + print(f"{version} for dataSource {datasourceName} is now set to default.") -class ToolkitHome: +# ====================================================================== +# ToolkitHome +# ====================================================================== + +class ToolkitHome(abstractToolkit): """ Central registry for available toolkits (static + dynamic). - Provides: - - getToolkit(toolkitName, ...): locate & instantiate a toolkit class - - getToolkitTable(projectName): table of all toolkits (static + DB) - - registerToolkit(toolkitclass, ...): register a class into project datasources (dataFormat=Class) + + Responsibilities: + - getToolkit(toolkitName, ...): locate & instantiate a toolkit class. + - getToolkitTable(projectName): table of all toolkits (static + DB). + - registerToolkit(...): register a toolkit class as a ToolkitDataSource + using the abstractToolkit data source interface. 
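# --- Hedged usage sketch (not part of the patch): writing data sources. ---
# Requires a configured hera project/database; names are hypothetical and
# "parquet" stands in for whatever dataFormat label the project actually uses.
from hera.toolkit import abstractToolkit

tk = abstractToolkit(toolkitName="myToolkit", projectName="demoProject")

tk.addDataSource(
    dataSourceName="srtmGlobal",
    resource="/data/srtm",   # stored as the document's resource
    dataFormat="parquet",    # assumed label
    version=(0, 0, 2),
    overwrite=True,          # replaces an existing (name, version) document
)
tk.setDataSourceDefaultVersion("srtmGlobal", (0, 0, 2))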
""" - # -------- Save modes (kept for compatibility) -------- - TOOLKIT_SAVEMODE_NOSAVE = None - TOOLKIT_SAVEMODE_ONLYFILE = "File" - TOOLKIT_SAVEMODE_ONLYFILE_REPLACE = "File_overwrite" - TOOLKIT_SAVEMODE_FILEANDDB = "DB" - TOOLKIT_SAVEMODE_FILEANDDB_REPLACE = "DB_overwrite" + # Save modes (kept for compatibility) + TOOLKIT_SAVEMODE_NOSAVE = TOOLKIT_SAVEMODE_NOSAVE + TOOLKIT_SAVEMODE_ONLYFILE = TOOLKIT_SAVEMODE_ONLYFILE + TOOLKIT_SAVEMODE_ONLYFILE_REPLACE = TOOLKIT_SAVEMODE_ONLYFILE_REPLACE + TOOLKIT_SAVEMODE_FILEANDDB = TOOLKIT_SAVEMODE_FILEANDDB + TOOLKIT_SAVEMODE_FILEANDDB_REPLACE = TOOLKIT_SAVEMODE_FILEANDDB_REPLACE - # -------- Static toolkit identifiers -------- + # Static toolkit identifiers GIS_BUILDINGS = "GIS_Buildings" GIS_TILES = "GIS_Tiles" GIS_LANDCOVER = "GIS_LandCover" @@ -66,161 +305,209 @@ class ToolkitHome: GAUSSIANDISPERSION = "GaussianDispersion" MACHINELEARNING_DEEPLEARNING = "machine_deep_learning" - _toolkits = None + _toolkits: Dict[str, Dict[str, Any]] = None + + def __init__(self, projectName: Optional[str] = None, filesDirectory: Optional[str] = None): + """ + ToolkitHome itself הוא Toolkit (abstractToolkit): + - projectName יטען אוטומטית מ־caseConfiguration אם None. + - כל ה־ToolkitDataSource הדינמיים נרשמים תחת toolkitName="ToolkitHome". + """ + super().__init__(toolkitName="ToolkitHome", projectName=projectName, filesDirectory=filesDirectory) - def __init__(self): # Static built-in toolkits (internal source) self._toolkits = dict( GIS_Buildings=dict( cls="hera.measurements.GIS.vector.buildings.toolkit.BuildingsToolkit", desc=None, - type="measurements" + type="measurements", ), GIS_Tiles=dict( cls="hera.measurements.GIS.raster.tiles.TilesToolkit", desc=None, - type="measurements" + type="measurements", ), GIS_Vector_Topography=dict( cls="hera.measurements.GIS.vector.topography.TopographyToolkit", desc=None, - type="measurements" + type="measurements", ), GIS_Raster_Topography=dict( cls="hera.measurements.GIS.raster.topography.TopographyToolkit", desc=None, - type="measurements" + type="measurements", ), GIS_Demography=dict( cls="hera.measurements.GIS.vector.demography.DemographyToolkit", desc=None, - type="measurements" + type="measurements", ), GIS_LandCover=dict( cls="hera.measurements.GIS.raster.landcover.LandCoverToolkit", desc=None, - type="measurements" + type="measurements", ), RiskAssessment=dict( cls="hera.riskassessment.riskToolkit.RiskToolkit", desc=None, - type="riskassessment" + type="riskassessment", ), LSM=dict( cls="hera.simulations.LSM.toolkit.LSMToolkit", desc=None, - type="simulations" + type="simulations", ), OF_LSM=dict( cls="hera.simulations.openFoam.LSM.toolkit.OFLSMToolkit", desc=None, - type="simulations" + type="simulations", ), MeteoHighFreq=dict( cls="hera.measurements.meteorology.highfreqdata.toolkit.HighFreqToolKit", desc=None, - type="measurements" + type="measurements", ), MeteoLowFreq=dict( cls="hera.measurements.meteorology.lowfreqdata.toolkit.lowFreqToolKit", desc=None, - type="measurements" + type="measurements", ), hermesWorkflows=dict( cls="hera.simulations.hermesWorkflowToolkit.hermesWorkflowToolkit", desc=None, - type="simulations" + type="simulations", ), OpenFOAM=dict( cls="hera.simulations.openFoam.toolkit.OFToolkit", desc=None, - type="simulations" + type="simulations", ), WindProfile=dict( cls="hera.simulations.windProfile.toolkit.WindProfileToolkit", desc=None, - type="simulations" + type="simulations", ), GaussianDispersion=dict( cls="hera.simulations.gaussian.toolkit.gaussianToolkit", desc=None, - 
type="simulations" + type="simulations", ), machine_deep_learning=dict( cls="hera.simulations.machineLearningDeepLearning.toolkit.machineLearningDeepLearningToolkit", desc=None, - type="simulations" + type="simulations", ), experiment=dict( cls="hera.measurements.experiment.experiment.experimentHome", desc=None, - type="measurements" + type="measurements", ), ) - # --- Place this near the top of the file imports if needed --- - from hera.datalayer import Project - - # --- Inside class ToolkitHome, replace ONLY the "not found" branch in getToolkit(...) --- - - def getToolkit(self, toolkitName, projectName=None, filesDirectory=None, **kwargs): - """ - Locate a toolkit class by name (static registry or DB), then instantiate it. - If not found anywhere, return a lightweight fallback that wraps Project so that - repository JSON loading can still proceed without a concrete Toolkit class. - """ - # 1) Static registry (unchanged) - if toolkitName in self._toolkits: - clsName = self._toolkits[toolkitName]['cls'] - toolkitClass = pydoc.locate(clsName) - if toolkitClass is None: - raise ImportError(f"Cannot locate class: {clsName}") - return toolkitClass(projectName, filesDirectory=filesDirectory, **kwargs) - - # 2) Dynamic registry via DB (unchanged) - # - # repo = ToolkitRepository(projectName or "DefaultProject") - # doc = repo.getToolkitDocument(toolkitName) - # if doc: - # desc = getattr(doc, "desc", None) or (doc.get("desc", {}) if isinstance(doc, dict) else {}) - # resource = getattr(doc, "resource", None) or (doc.get("resource", "") if isinstance(doc, dict) else "") - # classpath = desc.get("classpath") or desc.get("cls") - # if classpath: - # norm_desc = dict(desc) - # norm_desc["classpath"] = classpath - # norm_desc.pop("cls", None) - # # Use the dynamic Class loader path when classpath exists - # return DataHandler_Class.getData(resource=resource, desc=norm_desc) - # # If there is a dynamic doc but no classpath, we'll fall through to the shim - - # Return the shim instance - #return _FallbackToolkit(toolkitName=toolkitName, projectName=projectName, filesDirectory=filesDirectory) - - # hera/toolkit.py (inside class ToolkitHome) - # ----------------------------------------------------------------------------- - # Auto-register a missing toolkit using classpath hints (from repository JSON - # or from the Toolkit document in DB) and then return an instance via getToolkit. - # ----------------------------------------------------------------------------- - def auto_register_and_get(self, - toolkitName: str, - projectName: str, - repositoryJSON: dict = None, - repositoryName: str = None, - params: dict = None, - version: tuple = (0, 0, 1)): + # Optional: keep a handle to the experiment toolkit (if available) + self.experimentTK = None + try: + self.experimentTK = self.getToolkit(self.EXPERIMENT) + except Exception: + self.experimentTK = None + + # ------------------------------------------------------------------ + # Internal helper for repository config (uses generic dataToolkit) + # ------------------------------------------------------------------ + + def _get_data_toolkit(self, projectName: str = None): + """ + Helper that returns a dataToolkit instance. + + We import dataToolkit lazily here to avoid circular imports + between hera.toolkit and hera.utils.data.toolkit. + dataToolkit itself works on the DEFAULT project internally. 
+        """
+        from hera.utils.data.toolkit import dataToolkit
+        return dataToolkit()
+
+    # ------------------------------------------------------------------
+    # Main API: getToolkit
+    # ------------------------------------------------------------------
+
+    def getToolkit(self, toolkitName: str, filesDirectory: Optional[str] = None, **kwargs):
+        """
+        Locate and instantiate a toolkit by name.
+
+        Resolution order:
+          1) Static registry (self._toolkits).
+          2) Dynamic ToolkitDataSource document registered via ToolkitHome
+             (type='ToolkitDataSource', toolkit='ToolkitHome').
+          3) Experiment toolkits, via experimentHome.getExperiment(...).
+        """
+        # 1) Static built-in toolkits
+        if toolkitName in (self._toolkits or {}):
+            info = self._toolkits[toolkitName]
+            cls_path = info.get("cls")
+            toolkit_cls = pydoc.locate(cls_path)
+            if toolkit_cls is None:
+                raise ImportError(f"Cannot locate class: {cls_path}")
+
+            # Static toolkits are also abstractToolkit derivatives
+            return toolkit_cls(
+                projectName=self.projectName,
+                filesDirectory=filesDirectory,
+                **kwargs,
+            )
+
+        # 2) Dynamic toolkits registered as ToolkitDataSource of ToolkitHome
+        doc = self.getDataSourceDocument(datasourceName=toolkitName)
+        if doc is not None:
+            tk = doc.getData()
+            if hasattr(tk, "setFilesDirectory") and filesDirectory is not None:
+                tk.setFilesDirectory(filesDirectory)
+            return tk
+
+        # 3) Experiment toolkits fallback (experimentHome)
+        #    experimentTK is an experimentHome instance when available.
+        if self.experimentTK is not None:
+            try:
+                # Direct call to experimentHome.getExperiment(...)
+                return self.experimentTK.getExperiment(
+                    experimentName=toolkitName,
+                    filesDirectory=filesDirectory,
+                )
+            except Exception:
+                # experimentHome does not recognize this experiment name
+                pass
+
+        # Nothing found in any registry
+        raise ValueError(
+            f"Toolkit '{toolkitName}' not found in static registry, ToolkitDataSource, "
+            f"or experiment toolkit in project '{self.projectName}'."
+        )
+
+    # ------------------------------------------------------------------
+    # Auto-register + get (kept for compatibility; now uses the datasource API)
+    # ------------------------------------------------------------------
+
+    def auto_register_and_get(
+        self,
+        toolkitName: str,
+        repositoryJSON: dict = None,
+        repositoryName: Optional[str] = None,
+        params: Optional[dict] = None,
+        version: tuple = (0, 0, 1),
+    ):
         """
         Attempts to auto-register a missing toolkit and return an instance.
+
         1) Try to find a classpath hint in the repositoryJSON (if provided).
-           We look for keys like: repositoryJSON[toolkitName]["Registry"]["classpath"]
-           or ...["Registry"]["cls"].
-        2) If not found, try the DB-backed Toolkit document (ToolkitRepository).
+        2) If no hint is found, fail with an error (the DB-backed
+           ToolkitRepository lookup has been removed).
-        3) Import the class, choose a repository to register into:
-           - repositoryName argument if provided,
-           - else the project's default repository (must exist).
+        3) Import the class, choose a repository to register into
+           (explicit repositoryName or project's default).
         4) Register via registerToolkit(...), then getToolkit(...) and return it.
         """
+        from importlib import import_module
+
         params = params or {}
         classpath_hint = None
+        projectName = self.projectName
 
         # 1) Classpath hint in the repository JSON
         if repositoryJSON:
@@ -231,14 +518,6 @@
             except Exception:
                 pass
 
-        # 2) If still not found, try DB Toolkit document
-        if not classpath_hint:
-            from hera.utils.data.toolkit_repository import ToolkitRepository
-            repo = ToolkitRepository(projectName)
-            doc = repo.getToolkitDocument(toolkitName)
-            if doc and getattr(doc, "desc", None):
-                classpath_hint = doc.desc.get("classpath") or doc.desc.get("cls")
-
         if not classpath_hint:
             raise ValueError(
                 f"auto_register_and_get: no classpath hint found for toolkit '{toolkitName}'. "
@@ -249,7 +528,6 @@
         mod_name, _, cls_name = classpath_hint.rpartition(".")
         if not mod_name or not cls_name:
             raise ValueError(f"Invalid classpath hint: '{classpath_hint}'")
-        from importlib import import_module
         try:
             mod = import_module(mod_name)
             toolkit_cls = getattr(mod, cls_name)
@@ -257,7 +535,14 @@
             raise ImportError(f"Failed to import '{classpath_hint}' for toolkit '{toolkitName}'") from exc
 
         # Decide target repository for registration
-        repo_to_use = repositoryName or self.getDefaultRepository(projectName=projectName)
+        repo_to_use = (repositoryName or "").strip()
+        if not repo_to_use:
+            if projectName is None:
+                raise ValueError(
+                    "auto_register_and_get: projectName is None and no repositoryName provided. "
+                    "Cannot resolve default repository."
+                )
+            repo_to_use = self.getDefaultRepository(projectName=projectName)
         if not repo_to_use:
             raise ValueError(
                 f"auto_register_and_get: no target repository for project '{projectName}'. "
@@ -271,149 +556,240 @@
             params=params,
             version=version,
             overwrite=True,
-            projectName=projectName,
             repositoryName=repo_to_use,
         )
 
         # Return an instance
-        return self.getToolkit(toolkitName=toolkitName, projectName=projectName)
+        return self.getToolkit(toolkitName=toolkitName)
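# --- Hedged usage sketch (not part of the patch): getToolkit() resolution. ---
# "demoProject", "myDynamicToolkit" and "myExperiment" are hypothetical names.
from hera.toolkit import ToolkitHome

home = ToolkitHome(projectName="demoProject")

tk = home.getToolkit(ToolkitHome.GIS_LANDCOVER)  # 1) static registry
# tk = home.getToolkit("myDynamicToolkit")       # 2) ToolkitDataSource document
# tk = home.getToolkit("myExperiment")           # 3) experimentHome fallback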
""" + from importlib import import_module + params = params or {} classpath_hint = None + projectName = self.projectName # 1) Classpath hint in the repository JSON if repositoryJSON: @@ -231,14 +518,6 @@ def auto_register_and_get(self, except Exception: pass - # 2) If still not found, try DB Toolkit document - if not classpath_hint: - from hera.utils.data.toolkit_repository import ToolkitRepository - repo = ToolkitRepository(projectName) - doc = repo.getToolkitDocument(toolkitName) - if doc and getattr(doc, "desc", None): - classpath_hint = doc.desc.get("classpath") or doc.desc.get("cls") - if not classpath_hint: raise ValueError( f"auto_register_and_get: no classpath hint found for toolkit '{toolkitName}'. " @@ -249,7 +528,6 @@ def auto_register_and_get(self, mod_name, _, cls_name = classpath_hint.rpartition(".") if not mod_name or not cls_name: raise ValueError(f"Invalid classpath hint: '{classpath_hint}'") - from importlib import import_module try: mod = import_module(mod_name) toolkit_cls = getattr(mod, cls_name) @@ -257,7 +535,14 @@ def auto_register_and_get(self, raise ImportError(f"Failed to import '{classpath_hint}' for toolkit '{toolkitName}'") from exc # Decide target repository for registration - repo_to_use = repositoryName or self.getDefaultRepository(projectName=projectName) + repo_to_use = (repositoryName or "").strip() + if not repo_to_use: + if projectName is None: + raise ValueError( + "auto_register_and_get: projectName is None and no repositoryName provided. " + "Cannot resolve default repository." + ) + repo_to_use = self.getDefaultRepository(projectName=projectName) if not repo_to_use: raise ValueError( f"auto_register_and_get: no target repository for project '{projectName}'. " @@ -271,149 +556,240 @@ def auto_register_and_get(self, params=params, version=version, overwrite=True, - projectName=projectName, repositoryName=repo_to_use, ) # Return an instance - return self.getToolkit(toolkitName=toolkitName, projectName=projectName) + return self.getToolkit(toolkitName=toolkitName) + + # ------------------------------------------------------------------ + # Listing toolkits (static + dynamic) + # ------------------------------------------------------------------ + + from typing import Optional, List, Dict + + def getToolkitDocuments(self, name: Optional[str] = None, projectName: Optional[str] = None) -> List[Dict]: + """ + Single source of truth for listing toolkits. + + This method returns a uniform list of "document-like" dictionaries coming from: + 1) The static in-memory registry (self._toolkits). + 2) Dynamic DB-backed toolkit records (type='ToolkitDataSource'). + 3) Experiments that are exposed via the 'experiment' toolkit + (so that experiments appear as toolkits in the same view). + + Each returned dict has the general shape: + { + "toolkit": , + "desc": { + "classpath": , + "type": , + "source": , + "repositoryName": , + "version": + } + } + """ + docs: List[Dict] = [] + + # ------------------------------------------------------------------ + # 1) Static toolkits from the in-memory registry + # ------------------------------------------------------------------ + for tk_name, info in (self._toolkits or {}).items(): + if name and tk_name != name: + continue + + docs.append( + { + "toolkit": tk_name, + "desc": { + # Fully-qualified classpath of the toolkit implementation + "classpath": info.get("cls", ""), + # Logical type of the toolkit (e.g. 'measurements', 'simulations', ...) 
+ "type": info.get("type", "measurements"), + # Static entries are considered 'internal' + "source": "internal", + # Static toolkits do not come from a specific repository + "repositoryName": "", + # No explicit version for static entries + "version": "", + }, + } + ) + + # ------------------------------------------------------------------ + # 2) Dynamic toolkits stored in the DB as ToolkitDataSource documents + # ------------------------------------------------------------------ + if projectName: + try: + # The dataToolkit is used as a generic interface to measurements + # and to the underlying MongoDB-backed documents. + dt = self._get_data_toolkit(projectName=projectName) + dyn_docs = dt.getMeasurementsDocuments(type=TOOLKIT_DATASOURCE_TYPE) or [] + + for d in dyn_docs: + try: + desc = getattr(d, "desc", {}) or {} + tk_name = desc.get("datasourceName") or getattr(d, "datasourceName", None) + if not tk_name: + continue + if name and tk_name != name: + continue + + docs.append( + { + "toolkit": tk_name, + "desc": { + # Dynamic entries may carry a classpath for dynamic import + "classpath": desc.get("classpath", ""), + # Toolkit logical type; default to 'measurements' if missing + "type": desc.get("type", "") or "measurements", + # DB-backed entries are marked as coming from the DB + "source": desc.get("source", "") or "db", + # Repository is taken from the desc or from the document itself + "repositoryName": desc.get("repository", "") or getattr(d, "repository", ""), + # Version may be saved as a list or any other structure + "version": tuple(desc.get("version", ())) or getattr(d, "version", ""), + }, + } + ) + except Exception: + # Be forgiving in case some records are partially formed + continue + except Exception: + # If the project or DB is not available, we still return the static list. + pass + + # ------------------------------------------------------------------ + # 3) Experiments exposed as toolkits (via the 'experiment' toolkit) + # ------------------------------------------------------------------ + docs.extend(self.getExperimentToolkitDocuments(name=name)) + + return docs + + def getExperimentToolkitDocuments(self, name: Optional[str] = None) -> List[Dict]: + """ + Return experiment definitions as "toolkit-like" documents. + + The 'experiment' toolkit (experimentHome) exposes all experiments + of the project via getExperimentsMap(). This helper converts them + into the same normalized "document-like" shape used by + getToolkitDocuments(), so that experiments appear in the unified + toolkits table and can be discovered via the same CLI. + + Notes: + - Experiments do not have a direct classpath here; they are + instantiated via experimentHome.getExperiment(...), so the + 'classpath' field is left empty. + - The 'type' is marked as 'experiment'. + - The 'source' is marked as 'experiment' to distinguish them from + static or DB-backed toolkits. + """ + # If the experiment toolkit is not available, return an empty list + if self.experimentTK is None: + return [] + + try: + # experimentHome.getExperimentsMap() returns a dictionary where: + # keys = experiment names + # values = experiment metadata / configuration + exp_map = self.experimentTK.getExperimentsMap() + except Exception: + # If anything goes wrong while querying experiments, do not + # break the unified toolkit listing. 
+ return [] + + docs: List[Dict] = [] + for exp_name in exp_map.keys(): + if name and exp_name != name: + continue + + docs.append( + { + "toolkit": exp_name, + "desc": { + # Experiments are not imported via a direct classpath + "classpath": "", + # Logical type to mark this as an experiment + "type": "experiment", + # Source tag to differentiate experiments from static/DB toolkits + "source": "experiment", + # Experiments are not associated with a repository name here + "repositoryName": "", + # No explicit version is tracked at this layer + "version": "", + }, + } + ) + + return docs + def getToolkitTable(self, projectName: Optional[str]): """ Build a DataFrame from getToolkitDocuments(...). - This avoids duplicated logic and guarantees both static + DB rows are represented. """ - import pandas as pd - docs = self.getToolkitDocuments(name=None, projectName=projectName) or [] rows = [] for d in docs: desc = d.get("desc", {}) - rows.append({ - "toolkit": d.get("toolkit", ""), - "cls": desc.get("classpath", ""), - "source": desc.get("source", ""), - "type": desc.get("type", ""), - "repositoryName": desc.get("repositoryName", ""), - "version": desc.get("version", ""), - }) + rows.append( + { + "toolkit": d.get("toolkit", ""), + "cls": desc.get("classpath", ""), + "source": desc.get("source", ""), + "type": desc.get("type", ""), + "repositoryName": desc.get("repositoryName", ""), + "version": desc.get("version", ""), + } + ) if not rows: - return pd.DataFrame(columns=["toolkit", "cls", "source", "type", "repositoryName", "version"]) + return pd.DataFrame( + columns=["toolkit", "cls", "source", "type", "repositoryName", "version"] + ) - # Drop duplicates by (toolkit, source) to avoid double rows for the same name/source df = pd.DataFrame(rows).drop_duplicates(subset=["toolkit", "source"], keep="first") return df - # בתוך class ToolkitHome (בקובץ hera/toolkit.py) + # ------------------------------------------------------------------ + # Registration helpers (default repository config) + # ------------------------------------------------------------------ - def registerToolkit( - self, - toolkitclass, - *, - projectName, - repositoryName, # <<< חדש: דרישת רפוזיטורי - datasource_name=None, - params=None, - version=(0, 0, 1), - overwrite=False, - ): + def setDefaultRepository(self, *, projectName: str, repositoryName: str, overwrite: bool = True): """ - Register a toolkit class as a datasource document in the given project & repository. - - It stores: - - resource: the directory that contains the module file (DataHandler_Class adds to sys.path) - - dataFormat: datatypes.CLASS - - desc: { - 'toolkit' : , - 'datasourceName': , - 'repository' : , # <<< נשמר במסמך - 'version' : (major, minor, patch), - 'classpath' : '', - 'parameters' : { ... } - } + Persist default repository name for a project so future calls can omit --repository. + The configuration is stored as a measurement document with type='RepositoryConfig'. 
""" - if projectName is None: - raise ValueError("registerToolkit: 'projectName' is required") + if not projectName: + raise ValueError("setDefaultRepository: 'projectName' is required") if not repositoryName: - raise ValueError("registerToolkit: 'repositoryName' is required") + raise ValueError("setDefaultRepository: 'repositoryName' is required") - import inspect, os - module_path = inspect.getfile(toolkitclass) - resource_dir = os.path.dirname(os.path.abspath(module_path)) - classpath = f"{toolkitclass.__module__}.{toolkitclass.__qualname__}" + dt = self._get_data_toolkit(projectName=projectName) - ds_name = datasource_name or toolkitclass.__name__ - params = params or {} - - desc = { - "toolkit": ds_name, - "datasourceName": ds_name, - "repository": repositoryName, # <<< שדה רפוזיטורי - "version": tuple(version), - "classpath": classpath, - "parameters": params, - } - - proj = Project(projectName=projectName) - - # בדיקת קיום לפי (type, repository, datasourceName, version) - existing = proj.getMeasurementsDocuments( - type="ToolkitDataSource", - repository=repositoryName, # <<< סינון לפי רפוזיטורי - datasourceName=ds_name, - version=tuple(version), - ) - if existing: - if not overwrite: - raise ValueError( - f"Toolkit datasource '{ds_name}' (version {version}) already exists in " - f"repository '{repositoryName}' of project '{projectName}'. " - f"Use overwrite=True to replace." - ) - for doc in existing: - doc.delete() - - doc = proj.addMeasurementsDocument( - type="ToolkitDataSource", - resource=resource_dir, - dataFormat=datatypes.CLASS, - desc=desc, - ) - return doc - - def setDefaultRepository(self, *, projectName: str, repositoryName: str, overwrite: bool = True): - """ - Persist default repository name for a project so future calls can omit --repository. - We store it as a tiny Project document under type='RepositoryConfig'. - """ - if not projectName: - raise ValueError("setDefaultRepository: 'projectName' is required") - if not repositoryName: - raise ValueError("setDefaultRepository: 'repositoryName' is required") - - proj = Project(projectName=projectName) # delete previous config if exists (by type) if overwrite: - old = proj.getMeasurementsDocuments(type="RepositoryConfig") + old = dt.getMeasurementsDocuments(type="RepositoryConfig") for d in old: d.delete() desc = {"defaultRepository": repositoryName} - # Try to pick a dataFormat constant if available. Fallback: omit the arg. 
df_arg = {} try: from hera.datalayer import datatypes as _dt - dfmt = getattr(_dt, "JSON", None) or getattr(_dt, "json", None) or getattr(_dt, "TEXT", None) + dfmt = getattr(_dt, "JSON", None) or getattr(_dt, "json", None) or getattr( + _dt, "TEXT", None + ) if dfmt is not None: df_arg["dataFormat"] = dfmt except Exception: pass - return proj.addMeasurementsDocument( + return dt.addMeasurementsDocument( type="RepositoryConfig", - resource=".", # trivial + resource=".", desc=desc, **df_arg, ) @@ -424,145 +800,79 @@ def getDefaultRepository(self, *, projectName: str) -> str: """ if not projectName: raise ValueError("getDefaultRepository: 'projectName' is required") - proj = Project(projectName=projectName) - docs = proj.getMeasurementsDocuments(type="RepositoryConfig") + + dt = self._get_data_toolkit(projectName=projectName) + docs = dt.getMeasurementsDocuments(type="RepositoryConfig") if not docs: return "" - # Take the newest (or first) return docs[0].desc.get("defaultRepository", "") or "" - def getDatasourceDocument( - self, - *, - projectName: str, - datasourceName: str, - repositoryName: str = None, - version=None, # tuple like (0,0,1) or None + # ------------------------------------------------------------------ + # Registration of toolkits using datasource interface + # ------------------------------------------------------------------ + + def registerToolkit( + self, + toolkitclass, + *, + repositoryName: str, + datasource_name: Optional[str] = None, + params: Optional[dict] = None, + version=(0, 0, 1), + overwrite: bool = False, ): """ - Fetch a ToolkitDataSource by (repository, datasourceName [, version]). - If repositoryName is None or '', fallback to the project's defaultRepository. - """ - if not projectName: - raise ValueError("getDatasourceDocument: 'projectName' is required") - if not datasourceName: - raise ValueError("getDatasourceDocument: 'datasourceName' is required") - - repo = (repositoryName or "").strip() - if not repo: - repo = self.getDefaultRepository(projectName=projectName) - if not repo: - raise ValueError( - "Repository name is not provided and no defaultRepository is set for the project. " - "Call setDefaultRepository(...) first, or pass repositoryName explicitly." - ) - - proj = Project(projectName=projectName) - - q = { - "type": "ToolkitDataSource", - "repository": repo, - "datasourceName": datasourceName, - } - if version is not None: - q["version"] = tuple(version) - - docs = proj.getMeasurementsDocuments(**q) - return docs[0] if docs else None + Register a toolkit class as a ToolkitDataSource *of ToolkitHome*. - # --- inside class ToolkitHome (toolkit.py) --- - from typing import Optional, List, Dict - - def getToolkitDocuments(self, name: Optional[str] = None, projectName: Optional[str] = None) -> List[Dict]: + We use abstractToolkit.addDataSource so: + - type = TOOLKIT_DATASOURCE_TYPE + - toolkit field = "ToolkitHome" + - datasourceName = + - extra metadata in desc: + 'repository', 'classpath', 'parameters' """ - Single source of truth for listing toolkits. - Returns a uniform list of "document-like" dicts coming from: - 1) Static registry (self._toolkits) - 2) Dynamic DB documents (type='ToolkitDataSource') of the given project - - Each item looks like: - { - "toolkit": "", - "desc": { - "classpath": "", - "type": "", # e.g. 
"measurements" - "source": "internal" | "db", - "repositoryName": "", - "version": (major, minor, patch) | "" - } - } - """ - docs: List[Dict] = [] + if not repositoryName: + raise ValueError("registerToolkit: 'repositoryName' is required") - # 1) Static: normalize entries to the unified shape - for tk_name, info in (self._toolkits or {}).items(): - if name and tk_name != name: - continue - docs.append({ - "toolkit": tk_name, - "desc": { - "classpath": info.get("cls", ""), - "type": info.get("type", "measurements"), - "source": "internal", - "repositoryName": "", - "version": "", - } - }) + module_path = inspect.getfile(toolkitclass) + resource_dir = os.path.dirname(os.path.abspath(module_path)) + classpath = f"{toolkitclass.__module__}.{toolkitclass.__qualname__}" - # 2) Dynamic (DB): query the project's measurements for type='ToolkitDataSource' - if projectName: - try: - from hera.datalayer import Project - proj = Project(projectName=projectName) - dyn_docs = proj.getMeasurementsDocuments(type="ToolkitDataSource") or [] - for d in dyn_docs: - try: - # Many implementations store all useful fields under desc - desc = getattr(d, "desc", {}) or {} - tk_name = desc.get("datasourceName") or getattr(d, "datasourceName", None) - if not tk_name: - continue - if name and tk_name != name: - continue + ds_name = datasource_name or toolkitclass.__name__ + params = params or {} - docs.append({ - "toolkit": tk_name, - "desc": { - "classpath": desc.get("classpath", ""), - # Fallback to "measurements" if type is not provided - "type": desc.get("type", "") or "measurements", - "source": "db", - # Repository and version may appear either on desc or the document - "repositoryName": desc.get("repository", "") or getattr(d, "repository", ""), - "version": tuple(desc.get("version", ())) or getattr(d, "version", ""), - } - }) - except Exception: - # Be forgiving for partially-formed rows - pass - except Exception: - # No project/DB available: static list is still valuable - pass + extra_desc = { + "repository": repositoryName, + "classpath": classpath, + "parameters": params, + "type": "ToolkitDataSource", + "source": "db", + } - return docs + return self.addDataSource( + dataSourceName=ds_name, + resource=resource_dir, + dataFormat=datatypes.CLASS, + version=tuple(version), + overwrite=overwrite, + **extra_desc, + ) - # --- Add inside class ToolkitHome (toolkit.py) --- + # ------------------------------------------------------------------ + # JSON import helpers (unchanged, still valid עם הממשק החדש) + # ------------------------------------------------------------------ - def import_toolkits_from_json(self, *, projectName: str, json_path: str, - default_repository: str = None, overwrite: bool = True) -> list: + def import_toolkits_from_json( + self, + *, + projectName: str, + json_path: str, + default_repository: str = None, + overwrite: bool = True, + ) -> list: """ - Read a simple JSON file and register all Toolkits it declares into the given project. - Each entry under 'toolkits' should include: - - name (datasourceName) - - classpath (module.Class) - - version [major,minor,patch] (optional, defaults to [0,0,1]) - - parameters {} (optional) - The repository is taken from: - - 'repository' at the root of the JSON - - else default_repository argument - - else project's default repository (getDefaultRepository) - - Returns a list of toolkit names that were registered. + Read a JSON file and register all Toolkits it declares into the given project. + (Uses registerToolkit -> datasource interface.) 
""" import json from pydoc import locate @@ -595,15 +905,12 @@ def import_toolkits_from_json(self, *, projectName: str, json_path: str, if not name or not classpath: raise ValueError(f"Toolkit entry missing name/classpath: {item}") - # locate class tk_class = locate(classpath) if tk_class is None: raise ImportError(f"Cannot locate class by classpath: {classpath}") - # register self.registerToolkit( toolkitclass=tk_class, - projectName=projectName, repositoryName=repo_to_use, datasource_name=name, params=params, @@ -616,18 +923,10 @@ def import_toolkits_from_json(self, *, projectName: str, json_path: str, def import_experiments_from_json(self, *, projectName: str, json_path: str) -> list: """ - From the same JSON, load raw experiment measurements (legacy path). - Each entry under 'experiments' is: - { - "name": "ExpName", - "data": [ - { "type": "Experiment_rawData", "resource": "...", "dataFormat": "parquet", "desc": {...}, "isRelativePath": true } - ] - } - Returns a list of experiment names that were loaded. + Load experiments from JSON into the given project as measurements documents. + (לוגיקה קיימת – לא קשורה ישירות ל-datasource של toolkits.) """ import json - import os from hera.datalayer import Project if not projectName: @@ -638,14 +937,16 @@ def import_experiments_from_json(self, *, projectName: str, json_path: str) -> l experiments = payload.get("experiments") or [] if not isinstance(experiments, list): - raise ValueError("'experiments' must be a list") + raise ValueError("'experiments' must be a list in the JSON file") proj = Project(projectName=projectName) loaded = [] + base_dir = os.path.dirname(os.path.abspath(json_path)) for exp in experiments: exp_name = exp.get("name") data_items = exp.get("data", []) + for di in data_items: typ = di.get("type") resource = di.get("resource") @@ -653,11 +954,12 @@ def import_experiments_from_json(self, *, projectName: str, json_path: str) -> l desc = di.get("desc", {}) is_rel = bool(di.get("isRelativePath", False)) + if not typ or not resource or not data_fmt: + continue + res_path = resource if is_rel: - # שמירה על התנהגות "יחסית" לקובץ ה-JSON - base = os.path.dirname(os.path.abspath(json_path)) - res_path = os.path.abspath(os.path.join(base, resource)) + res_path = os.path.abspath(os.path.join(base_dir, resource)) proj.addMeasurementsDocument( type=typ, @@ -665,346 +967,8 @@ def import_experiments_from_json(self, *, projectName: str, json_path: str) -> l dataFormat=data_fmt, desc=desc, ) + if exp_name and exp_name not in loaded: loaded.append(exp_name) return loaded - - -class abstractToolkit(Project): - """ - A base class for Toolkits. - - * Like project, it is initialized with a project name. - If the toolkit works on data, it should be present in that project. - - * Inherits from project and therefore exposes all the datalayer functions. - - * Holds the toolkit name, and references to the datalayer and presentation layers. - - * Adds a mechanism (setConfig,getConfig) for saving configuration in the DB. the settings are specific for a project. - - * Adds a mechanism to list, get and add data sources. - - A data source will always be saved as a measurement document. - - Each source has the following properties in the description (except for the other properties): - * name : str - * toolkit : str - * projectName :str - * version : tuple (major version, minor varsion, bug fix). - * the type is TOOLKIT_DATASOURCE_TYPE. - * metadata: dict with additional metadata of the datasource. 
- - - The toolkit can have a default source for the project. - A default data source is defined with its name and version - If the version is not supplied, takes the latest version. - - - - """ - _toolkitname = None - _projectName = None - - _analysis = None # holds the datalayer layer. - _presentation = None # holds the presentation layer - - @property - def presentation(self): - """ - Access to the presentation layer - :return: - """ - return self._presentation - - @property - def analysis(self): - """ - Access to the datalayer layer - :return: - """ - return self._analysis - - @property - def toolkitName(self): - """ - The name of the toolkit name - :return: - """ - return self._toolkitname - - @property - def projectName(self): - """ - The name of the project - :return: - """ - return self._projectName - - def __init__(self, toolkitName, projectName, filesDirectory=None): - """ - Initializes a new toolkit. - - Parameters - ---------- - - toolkitName: str - The name of the toolkit - - projectName: str - The project that the toolkit works in. - - filesDirectory: str - The directory to save datasource - - """ - super().__init__(projectName=projectName, filesDirectory=filesDirectory) - logger = get_classMethod_logger(self, "init") - self._toolkitname = toolkitName - - @property - def classLoggerName(self): - return str(get_classMethod_logger(self, "{the_function_name}")).split(" ")[1] - - def getDataSourceList(self, **filters): - """ - Returns a list of the data source names - Parameters - ---------- - filters - - Returns - ------- - - """ - docList = self.getMeasurementsDocuments(type=TOOLKIT_DATASOURCE_TYPE, - toolkit=self.toolkitName, - **filters) - - ret = [] - for doc in docList: - ret.append(doc['desc']['datasourceName']) - - return ret - - def getDataSourceMap(self, **filters): - """ - Return the list of all data sources and their versions that are related to this toolkit - - Parameters - ---------- - asPandas: bool - If true, convert to pandas. - - filters: parameters - Additional parameters to query the templates - - Returns - ------- - list of dicts or pandas - """ - docList = self.getMeasurementsDocuments(type=TOOLKIT_DATASOURCE_TYPE, - toolkit=self.toolkitName, - **filters) - - ret = [] - for doc in docList: - dta = dict(dataFormat=doc['dataFormat'], - resource=doc['resource']) - dta.update(doc.desc) - ret.append(dta) - return ret - - def getDataSourceTable(self, **filters): - - Table = [] - for sourceMap in self.getDataSourceMap(**filters): - table = pandas.json_normalize(sourceMap) - Table.append(table) - - if len(Table) == 0: - return pandas.DataFrame() - else: - return pandas.concat((Table), ignore_index=True) - - def getDataSourceDocumentsList(self, **kwargs): - """ - Return all the datasources associated with this toolkit. - - Returns - ------- - List of docs. - """ - queryDict = {"type": TOOLKIT_DATASOURCE_TYPE, - TOOLKIT_TOOLKITNAME_FIELD: self.toolkitName} - queryDict.update(**kwargs) - return self.getMeasurementsDocuments(**queryDict) - - def getDataSourceDocument(self, datasourceName, version=None, **filters): - """ - Return the document of the datasource. - If version is not specified, return the latest version. - - Returns a single document. - - Parameters - ---------- - datasourceName: str - The datasourceName of the source - if None, return the default source (if set). - - version: tuple - The version of the source. - if not found, return the latest source - - - filters: - Additional parameters to the query. - - Returns - ------- - The document of the source. 
(None if not found) - """ - if datasourceName is not None: - filters[TOOLKIT_DATASOURCE_NAME] = datasourceName - if version is not None: - filters[TOOLKIT_DATASOURCE_VERSION] = version - else: - try: - defaultVersion = self.getConfig()[f"{datasourceName}_defaultVersion"] - filters[TOOLKIT_DATASOURCE_VERSION] = defaultVersion - except: - pass - - filters[TOOLKIT_TOOLKITNAME_FIELD] = self.toolkitName # {'toolkit' : self.toolkitName} - - docList = self.getMeasurementsDocuments(type=TOOLKIT_DATASOURCE_TYPE, **filters) - - if len(docList) == 0: - ret = None - - elif len(docList) == 1: - ret = docList[0] - - elif len(docList) > 1: - versionsList = [doc['desc']['version'] for doc in docList] - latestVersion = max(versionsList) - docList = [doc for doc in docList if doc['desc']['version'] == latestVersion] - ret = docList[0] - return ret - - def getDataSourceDocuments(self, datasourceName, version=None, **filters): - """ - Returns a list with the datasource. This is for the complteness of the interface. - That is, making it similar to the Measurement, Cache and Simulation document retrieval. - - Parameters - ---------- - datasourceName: str - The datasourceName of the source - if None, return the default source (if set). - - version: tuple - The version of the source. - if not found, return the latest source - - - filters: - Additional parameters to the query. - - Returns - ------- - A list that containes the document of the source. (empty list if not found) - """ - doc = self.getDataSourceDocument(datasourceName=datasourceName, version=version, **filters) - return [] if doc is None else [doc] - - def getDataSourceData(self, datasourceName=None, version=None, **filters): - """ - Returns the data from the datasource. - - Parameters - ---------- - - datasourceName: str - The datasourceName of the source - if None, return the default source (if set). - - version: tuple - The version of the source. - if not found, return the latest source - - - filters: dict - additional filters to the query. - - Returns - ------- - The data of the source. (None if not found) - """ - filters[TOOLKIT_TOOLKITNAME_FIELD] = self.toolkitName # {'toolkit' : self.toolkitName} - doc = self.getDataSourceDocument(datasourceName=datasourceName, version=version, **filters) - return None if doc is None else doc.getData() - - def addDataSource(self, dataSourceName, resource, dataFormat, version=(0, 0, 1), overwrite=False, **kwargs): - """ - Adds a resource to the toolkit. - The type is always TOOLKIT_DATASOURCE_TYPE. - The toolkit name is added to the description. - - Parameters - ---------- - dataSourceName: str - The name of the data source - - version: tuple (of int) - A 3-tuple of the version - - resource: str - The resource - - dataFormat: str - A string of a datatypes. - - kwargs: dict - The parameters - - Returns - ------- - The document of the datasource. 
- """ - - kwargs[TOOLKIT_TOOLKITNAME_FIELD] = self.toolkitName - kwargs[TOOLKIT_DATASOURCE_NAME] = dataSourceName - kwargs[TOOLKIT_DATASOURCE_VERSION] = version - if (self.getDataSourceDocument(dataSourceName, version=version) is None) or overwrite: - if self.getDataSourceDocument(dataSourceName, version=version) is not None: # not None = Exist - # print("Delete existing, and add new data source.") - delargs = {TOOLKIT_DATASOURCE_NAME: dataSourceName, - TOOLKIT_DATASOURCE_VERSION: version} - - self.deleteDataSource(**delargs) - # else: - # print("Does not exist: add data source.") - - doc = self.addMeasurementsDocument(type=TOOLKIT_DATASOURCE_TYPE, - resource=resource, - dataFormat=dataFormat, - desc=kwargs) - else: - raise ValueError( - f"Record {dataSourceName} (version {version}) already exists in project {self.projectName}. use overwrite=True to overwrite on the existing document") - - return doc - - def deleteDataSource(self, datasourceName, version=None, **filters): - - doc = self.getDataSourceDocument(datasourceName=datasourceName, version=version, **filters) - doc.delete() - - return doc - - def setDataSourceDefaultVersion(self, datasourceName: str, version: tuple): - if len(self.getMeasurementsDocuments(type="ToolkitDataSource", **{"datasourceName": datasourceName, - "version": version})) == 0: - raise ValueError(f"No DataSource with name={datasourceName} and version={version}.") - - self.setConfig(**{f"{datasourceName}_defaultVersion": version}) - print(f"{version} for dataSource {datasourceName} is now set to default.") \ No newline at end of file diff --git a/hera/utils/data/CLI.py b/hera/utils/data/CLI.py index 77a07fcc..de78e743 100644 --- a/hera/utils/data/CLI.py +++ b/hera/utils/data/CLI.py @@ -3,23 +3,19 @@ import getpass import json import logging -from ...datalayer import getProjectList,Project,createProjectDirectory,removeConnection,addOrUpdateDatabase,getMongoJSON +from ...datalayer import getProjectList, Project, createProjectDirectory, removeConnection, addOrUpdateDatabase, getMongoJSON from ...datalayer import All as datalayer_All from .. import loadJSON from .toolkit import dataToolkit import pandas +from ...toolkit import ToolkitHome +from pydoc import locate # for resolving classpath -> class from tabulate import tabulate + def project_list(arguments): """ List all the projects of the user. - Parameters - ---------- - arguments - - Returns - ------- - """ connectionName = getpass.getuser() if arguments.connectionName is None else arguments.connectionName @@ -27,10 +23,10 @@ def project_list(arguments): ttl = f"Projects in the connection {connectionName}" print("\n") print(ttl) - print("-"*len(ttl)) + print("-" * len(ttl)) projList = [] for projName in projectList: - projDesct = {"Project Name" : projName} + projDesct = {"Project Name": projName} if arguments.fulldetails: proj = Project(projectName=projName, connectionName=connectionName) cacheCount = len(proj.getCacheDocuments()) @@ -46,69 +42,58 @@ def project_list(arguments): df = pandas.DataFrame(projList).sort_values("Project Name") with pandas.option_context('display.max_rows', None, - 'display.max_columns', None, - 'display.width', 1000, - 'display.precision', 3, - 'display.colheader_justify', 'center'): + 'display.max_columns', None, + 'display.width', 1000, + 'display.precision', 3, + 'display.colheader_justify', 'center'): print(df) + print("-" * len(ttl)) - print("-"*len(ttl)) def project_create(arguments): """ Creating a directory and a project. 
- - The project is a caseConfiguration file with the configuration name. - - Parameters - ---------- - arguments : - -- directory: the directory to use - -- database: the name of the DB to use - - Returns - ------- - """ if arguments.directory is None: - directory = os.path.join(os.getcwd(),arguments.projectName) + directory = os.path.join(os.getcwd(), arguments.projectName) else: directory = arguments.directory - createProjectDirectory(outputPath=directory,projectName=arguments.projectName) + createProjectDirectory(outputPath=directory, projectName=arguments.projectName) print(f"Created project {arguments.projectName} in directory {directory}") if arguments.loadRepositories: dtk = dataToolkit() - dtk.loadAllDatasourcesInAllRepositoriesToProject(projectName=arguments.projectName,overwrite=arguments.overwrite) + dtk.loadAllDatasourcesInAllRepositoriesToProject(projectName=arguments.projectName, + overwrite=arguments.overwrite) + def project_dump(arguments): - fullQuery=dict(projectName = arguments.projectName) + fullQuery = dict(projectName=arguments.projectName) for queryElement in arguments.query: fullQuery[queryElement.split('=')[0]] = eval(queryElement.split('=')[1]) - docList = [] for doc in datalayer_All.getDocuments(**fullQuery): docDict = doc.asDict() - if ('docid' not in docDict['desc']): + if 'docid' not in docDict['desc']: docDict['desc']['docid'] = str(doc.id) docList.append(docDict) - outStr = json.dumps(docList,indent=4) + outStr = json.dumps(docList, indent=4) outputFileName = arguments.fileName if outputFileName is not None: - with open(outputFileName,"w") as outputFile: + with open(outputFileName, "w") as outputFile: outputFile.write(outStr) - if arguments.outputFormat=='json': + if arguments.outputFormat == 'json': print(outStr) - elif arguments.outputFormat=='table': + elif arguments.outputFormat == 'table': df = pandas.DataFrame(docList) with pandas.option_context('display.max_rows', None, 'display.max_columns', None, @@ -123,9 +108,9 @@ def project_dump(arguments): def project_load(arguments): docsDict = loadJSON(arguments.file) - proj = Project(projectName=arguments.projectName) + proj = Project(projectName=arguments.projectName) - for indx,doc in enumerate(docsDict): + for indx, doc in enumerate(docsDict): print(f"Loading document {indx}/{len(docsDict)}") proj.addDocumentFromDict(docsDict.get(doc)) @@ -134,7 +119,7 @@ def repository_list(argumets): dtk = dataToolkit() repDataframe = dtk.getRepositoryTable() - if len(repDataframe) ==0: + if len(repDataframe) == 0: print("The user does not have repositories.") else: with pandas.option_context('display.max_rows', None, @@ -145,6 +130,7 @@ def repository_list(argumets): 'display.colheader_justify', 'center'): print(repDataframe) + def repository_add(argumets): logger = logging.getLogger("hera.bin.repository_add") dtk = dataToolkit() @@ -156,6 +142,7 @@ def repository_add(argumets): repositoryPath=argumets.repositoryName, overwrite=argumets.overwrite) + def repository_remove(arguments): logger = logging.getLogger("hera.bin.repository_remove") dtk = dataToolkit() @@ -172,20 +159,20 @@ def repository_show(arguments): datasourceName = arguments.repositoryName logger.info(f"Listing the datasource {datasourceName}") repositoryData = dtk.getDataSourceData(datasourceName=datasourceName) - dataTypeList = ['DataSource','Measurements','Cache','Simulations'] + dataTypeList = ['DataSource', 'Measurements', 'Cache', 'Simulations'] for toolkitName, toolDesc in repositoryData.items(): ttl = f"\t\t\033[1mToolkit:\033[0m {toolkitName}" - 
print("#"*(2*len(ttl.expandtabs()))) + print("#" * (2 * len(ttl.expandtabs()))) print(ttl) - print("#"*(2*len(ttl.expandtabs()))) + print("#" * (2 * len(ttl.expandtabs()))) for datatype in dataTypeList: - print("="*len(datatype)) + print("=" * len(datatype)) print(f"{datatype}") - print("="*len(datatype)) + print("=" * len(datatype)) - for repName,repItems in toolDesc.get(datatype,{}).items(): + for repName, repItems in toolDesc.get(datatype, {}).items(): ttl = f"\033[1mName:\033[0m {repName}" print(f"\t{ttl}") print("-" * (2 * len(ttl.expandtabs()))) @@ -196,9 +183,10 @@ def repository_show(arguments): 'display.max_colwidth', None, 'display.precision', 3, 'display.colheader_justify', 'center'): - print(pandas.DataFrame.from_dict(repItems['item'],orient='index',columns=['Value'])) + print(pandas.DataFrame.from_dict(repItems['item'], orient='index', columns=['Value'])) print("\n") + def repository_load(arguments): logger = logging.getLogger("hera.bin.repository_load") dtk = dataToolkit() @@ -210,12 +198,13 @@ def repository_load(arguments): projectName = None logger.info(f"Loading the repository {repositoryFile} to the project {projectName if projectName is not None else 'default project'}") - repositoryJSON= loadJSON(repositoryFile) + repositoryJSON = loadJSON(repositoryFile) dtk.loadAllDatasourcesInRepositoryJSONToProject(projectName=projectName, repositoryJSON=repositoryJSON, basedir=os.path.dirname(os.path.abspath(arguments.repositoryName)), overwrite=arguments.overwrite) + def display_datasource_versions(arguments): proj = Project(projectName=arguments.projectName) datasources = [] @@ -229,7 +218,7 @@ def display_datasource_versions(arguments): d['version'] = document['desc']['version'] if arguments.datasource: - if arguments.datasource==d['datasourceName']: + if arguments.datasource == d['datasourceName']: datasources.append(d) else: datasources.append(d) @@ -244,7 +233,7 @@ def display_datasource_versions(arguments): d['datasourceName'] = document['desc']['datasourceName'] if arguments.datasource: - if arguments.datasource==d['datasourceName']: + if arguments.datasource == d['datasourceName']: default_version = config.get(f"{arguments.datasource}_defaultVersion") else: default_version = None @@ -255,12 +244,10 @@ def display_datasource_versions(arguments): d['DEFAULT_VERSION'] = default_version datasources.append(d) - - except: pass - if len(datasources)!=0: + if len(datasources) != 0: headers = datasources[0].keys() rows = [d.values() for d in datasources] print(tabulate(rows, headers=headers, tablefmt="grid")) @@ -270,11 +257,13 @@ def display_datasource_versions(arguments): else: print(f"No data to display. 
Are you sure datasource {arguments.datasource} and project {arguments.projectName} exists?")

+
 def update_datasource_default_version(arguments):
     logger = logging.getLogger("hera.bin.update_datasource_version")
     arguments.version = tuple(int(item.strip()) for item in arguments.version.split(','))
     proj = Project(projectName=arguments.projectName)
-    proj.setDataSourceDefaultVersion(datasourceName=arguments.datasource,version=arguments.version)
+    proj.setDataSourceDefaultVersion(datasourceName=arguments.datasource, version=arguments.version)
+

 def update(arguments):
     logger = logging.getLogger("hera.bin.update")
@@ -303,39 +292,30 @@ def update(arguments):
     dtk = dataToolkit()
     dtk.loadAllDatasourcesInAllRepositoriesToProject(projectName=projectName, overwrite=arguments.overwrite)
+

 def db_list(arguments):
     """
     List the databases in the
-    Parameters
-    ----------
-    arguments
-
-    Returns
-    -------
-
-
-
-
     """
     dbconfig = getMongoJSON()
     conList = []
-    for connectionName,connectionData in dbconfig.items():
-        condict = {"Connection Name" : connectionName}
+    for connectionName, connectionData in dbconfig.items():
+        condict = {"Connection Name": connectionName}
         if arguments.fulldetails:
             condict.update(connectionData)
         conList.append(condict)

-    df = pandas.DataFrame(conList).rename(columns=dict(dbIP="IP",dbName="databaseName"))
+    df = pandas.DataFrame(conList).rename(columns=dict(dbIP="IP", dbName="databaseName"))
     with pandas.option_context('display.max_rows', None,
-                           'display.max_columns', None,
-                           'display.width', 1000,
-                           'display.precision', 3,
-                           'display.colheader_justify', 'center'):
+                               'display.max_columns', None,
+                               'display.width', 1000,
+                               'display.precision', 3,
+                               'display.colheader_justify', 'center'):
         print(df)

+
 def db_create(arguments):
     addOrUpdateDatabase(connectionName=arguments.connectionName,
                         username=arguments.username,
@@ -343,28 +323,23 @@ def db_create(arguments):
                         databaseIP=arguments.IP,
                         databaseName=arguments.databaseName)

+
 def db_remove(arguments):
     removeConnection(arguments.connectionName)

-import logging
-from tabulate import tabulate
-from ...toolkit import ToolkitHome
-from pydoc import locate  # for resolving classpath -> class

-# --- in hera/utils/data/CLI.py ---
+# --- Toolkit related CLI ---
+
 def toolkit_list(arguments):
     """
     Print a combined list of toolkits (static + dynamic from DB) for a project.
     Uses ToolkitHome.getToolkitDocuments(...) as the single source of truth.
     """
-    import logging
-    from tabulate import tabulate
     logger = logging.getLogger("hera.utils.CLI.toolkit_list")
     project = arguments.project

     try:
-        from ...toolkit import ToolkitHome
-        th = ToolkitHome()
+        th = ToolkitHome(projectName=project)
         docs = th.getToolkitDocuments(name=None, projectName=project) or []

         rows = []
@@ -408,17 +383,16 @@ def toolkit_register(arguments):
         version = (0, 0, 1)

     try:
-        th = ToolkitHome()
+        th = ToolkitHome(projectName=project)

         # Resolve classpath -> class
         toolkit_cls = locate(cls_path)
         if toolkit_cls is None:
             raise ImportError(f"Cannot locate class by classpath: {cls_path}")

-        # Call registerToolkit with a class object
+        # Call registerToolkit with a class object (no projectName – ToolkitHome already knows it)
         th.registerToolkit(
             toolkitclass=toolkit_cls,
-            projectName=project,
             repositoryName=repository,
             datasource_name=name,
             version=version,
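
For context, a hedged sketch of how this handler might be invoked; the exact
subcommand and flag names live in the argument parser, which is outside this
hunk, so all names below are assumptions:

    hera-toolkit register --project MyProject --repository labRepo \
        --name windTk --classpath mypkg.toolkits.WindToolkit

The handler resolves the classpath with pydoc.locate and stores the class as a
ToolkitDataSource document through ToolkitHome.registerToolkit.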
@@ -433,41 +407,30 @@ def toolkit_register(arguments):
 def toolkit_load(arguments):
     """
     Instantiate a toolkit by name.
-    Delegates to ToolkitHome.getToolkit (static -> DB -> dynamic import).
+    Delegates to ToolkitHome.getToolkit (static + dynamic + experiments).
     """
     logger = logging.getLogger("hera.utils.CLI.toolkit_load")
     project = arguments.project
     name = arguments.name

     try:
-        th = ToolkitHome()
-        try:
-            tk = th.getToolkit(toolkitName=name, projectName=project)
-        except Exception as ex:
-            # Optional: try auto-register (if you הוספת ב-toolkit.py)
-            auto = getattr(th, "auto_register_and_get", None)
-            if callable(auto):
-                tk = auto(toolkitName=name, projectName=project)
-            else:
-                raise ex
-
+        # ToolkitHome is itself a toolkit; the projectName is supplied at construction
+        th = ToolkitHome(projectName=project)
+        tk = th.getToolkit(toolkitName=name)
         print(f"Loaded toolkit: {getattr(tk, 'name', name)}")
     except Exception as e:
         logger.exception(e)
         print(f"[ERROR] {e}")


-# --- Add to hera/utils/data/CLI.py ---
 def toolkit_default_repo_show(arguments):
     """
     Show the project's default repository,
     via ToolkitHome.getDefaultRepository(projectName=...).
     """
-    import logging
     logger = logging.getLogger("hera.utils.CLI.toolkit_default_repo_show")
     project = getattr(arguments, "project", None) or "DefaultProject"
     try:
-        from ...toolkit import ToolkitHome
-        th = ToolkitHome()
-        repo = th.getDefaultRepository(projectName=project)  # אתה כבר מימשת
+        th = ToolkitHome(projectName=project)
+        repo = th.getDefaultRepository(projectName=project)
         print(repo if repo else "")
     except Exception as e:
         logger.exception(e)
@@ -478,7 +441,6 @@ def toolkit_default_repo_set(arguments):
     """
     Set the project's default repository,
     via ToolkitHome.setDefaultRepository(projectName=..., repositoryName=...).
     """
-    import logging
     logger = logging.getLogger("hera.utils.CLI.toolkit_default_repo_set")
     project = getattr(arguments, "project", None) or "DefaultProject"
     repo_name = getattr(arguments, "repository", None)
     if not repo_name:
         print("[ERROR] --repository is required")
         return
     try:
-        from ...toolkit import ToolkitHome
-        th = ToolkitHome()
-        th.setDefaultRepository(projectName=project, repositoryName=repo_name)  # אתה כבר מימשת
+        th = ToolkitHome(projectName=project)
+        th.setDefaultRepository(projectName=project, repositoryName=repo_name)
         print(f"Default repository set to '{repo_name}' for project '{project}'.")
     except Exception as e:
         logger.exception(e)
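
As a reference, a minimal sketch of the default-repository workflow these two
handlers expose, using only the API shown above (project and repository names
are hypothetical):

    th = ToolkitHome(projectName="MyProject")
    th.setDefaultRepository(projectName="MyProject", repositoryName="labRepo")
    th.getDefaultRepository(projectName="MyProject")   # -> "labRepo"

After the default is set, registration calls may omit an explicit repositoryName.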
@@ -499,24 +460,18 @@ def toolkit_import_json(arguments):
     """
     Import a JSON repository that declares toolkits (and optionally experiments),
     and register them into the project.
-    Usage:
-        hera-toolkit import-json --project <project> --file <path> [--no-experiments]
     """
-    import logging
     logger = logging.getLogger("hera.utils.CLI.toolkit_import_json")
     project = getattr(arguments, "project", None)
     json_path = getattr(arguments, "file", None)
     no_experiments = getattr(arguments, "no_experiments", False)

     try:
-        from ...toolkit import ToolkitHome
-        th = ToolkitHome()
+        th = ToolkitHome(projectName=project)

-        # טולקיטים
         registered = th.import_toolkits_from_json(projectName=project, json_path=json_path)
         print(f"Registered toolkits: {registered}" if registered else "No toolkits in JSON.")

-        # ניסויים (אופציונלי)
         if not no_experiments:
             exps = th.import_experiments_from_json(projectName=project, json_path=json_path)
             if exps:
@@ -526,61 +481,71 @@ def toolkit_import_json(arguments):
         logger.exception(e)
         print(f"[ERROR] {e}")

-# --- in hera/utils/data/CLI.py ---
-def project_measurements_list(arguments):
+
+# NOTE: the following looks like it should live inside a class, but it is kept as-is for now
+def project_measurements_list(args):
     """
-    Print project measurements documents in a concise table.
-    Filters:
-      --project: project name (if omitted, Project may derive from CWD by your implementation)
-      --type: filter by 'type' field (e.g., ToolkitDataSource, Experiment_rawData)
-      --contains: substring filter on 'datasourceName' or 'resource'
+    Implementation for:
+        hera-project project measurements list
+    (used also by 'simulations list' and 'cache list' via shortcut)
     """
-    import logging
-    from tabulate import tabulate
-    logger = logging.getLogger("hera.utils.CLI.project_measurements_list")
-
-    project = getattr(arguments, "project", None)
-    typ = getattr(arguments, "type", None)
-    contains = getattr(arguments, "contains", None)
-
-    try:
-        from hera.datalayer import Project
-        proj = Project(projectName=project) if project else Project()
-
-        # Build query dict dynamically
-        query = {}
-        if typ:
-            query["type"] = typ
-
-        docs = proj.getMeasurementsDocuments(**query) or []
-
-        rows = []
+    from hera.datalayer import Project
+
+    project = getattr(args, "project", None)
+    mtype = getattr(args, "type", None)
+    shortcut = getattr(args, "shortcut", None)
+    contains = getattr(args, "contains", None)
+
+    # Map shortcuts to measurement types
+    shortcut_map = {
+        "ds": "ToolkitDataSource",      # dynamic toolkits
+        "exp": "Experiment_rawData",    # experiments
+        "sim": "Simulation_rawData",    # (if any)
+        "cache": "Cache_rawData",       # (if any)
+        "all": None,                    # no type filter
+    }
+
+    if shortcut:
+        if shortcut not in shortcut_map:
+            print(f"Unknown shortcut '{shortcut}'. Valid: {', '.join(shortcut_map.keys())}")
+            return
+        mtype = shortcut_map[shortcut]
+
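+    # Example (editorial; project name hypothetical): a call such as
+    # "hera-project project simulations list --project MyProject" reaches this
+    # handler with shortcut="sim" preset, so mtype resolves to
+    # "Simulation_rawData" and only documents of that type are queried below.
+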
+    # Open project (or default)
+    proj = Project(projectName=project) if project else Project()
+
+    query = {}
+    if mtype:
+        query["type"] = mtype
+
+    docs = proj.getMeasurementsDocuments(**query)
+
+    # Optional substring filter
+    if contains:
+        filtered = []
         for d in docs:
-            # Pull common fields safely (desc is a dict in most implementations)
-            desc = getattr(d, "desc", {}) or {}
-            rows.append({
-                "type": getattr(d, "type", "") or desc.get("type", ""),
-                "datasourceName": getattr(d, "datasourceName", "") or desc.get("datasourceName", ""),
-                "resource": getattr(d, "resource", "") or desc.get("resource", ""),
-                "dataFormat": getattr(d, "dataFormat", "") or desc.get("dataFormat", ""),
-                "version": getattr(d, "version", "") or desc.get("version", ""),
-                "repository": getattr(d, "repository", "") or desc.get("repository", ""),
-            })
-
-        # Optional substring filter
-        if contains:
-            needle = str(contains).lower()
-            rows = [r for r in rows if needle in str(r.get("datasourceName", "")).lower()
-                    or needle in str(r.get("resource", "")).lower()]
-
-        if rows:
-            print(tabulate(rows, headers="keys", tablefmt="github"))
-        else:
-            print("No measurement documents found for the given filters.")
-
-    except Exception as e:
-        logger.exception(e)
-        print(f"[ERROR] {e}")
-
-
+            name = getattr(d, "datasourceName", "") or d.desc.get("datasourceName", "")
+            resource = getattr(d, "resource", "") or ""
+            if contains in str(name) or contains in str(resource):
+                filtered.append(d)
+        docs = filtered
+
+    if not docs:
+        print("No measurements found for given filters.")
+        return
+
+    # Build rows for pretty table
+    rows = []
+    for d in docs:
+        rows.append({
+            "type": getattr(d, "type", ""),
+            "datasourceName": getattr(d, "datasourceName", "") or d.desc.get("datasourceName", ""),
+            "resource": getattr(d, "resource", ""),
+            "dataFormat": getattr(d, "dataFormat", ""),
+            "version": d.desc.get("version", ""),
+            "repository": d.desc.get("repository", ""),
+        })
+
+    import pandas as pd
+    df = pd.DataFrame(rows)
+    print(df.to_markdown(index=False))
diff --git a/hera/utils/data/cli_toolkit_repository.py b/hera/utils/data/cli_toolkit_repository.py
index 80cf1680..0f9c7ff6 100644
--- a/hera/utils/data/cli_toolkit_repository.py
+++ b/hera/utils/data/cli_toolkit_repository.py
@@ -22,7 +22,8 @@ from typing import Any, Dict, Tuple

 from hera.toolkit import ToolkitHome
-from hera.utils.data.toolkit_repository import ToolkitRepository
+from hera.datalayer import Project
+

 # ------------------------------- Utilities -----------------------------------

@@ -161,6 +162,7 @@ def cmd_add_doc(args: argparse.Namespace) -> None:
     if mod_name and cls_name:
         try:
             # Try importing to verify the classpath really exists
+            from importlib import import_module
             mod = import_module(mod_name)
             getattr(mod, cls_name)
             keep_classpath = True
@@ -179,26 +181,29 @@ def cmd_add_doc(args: argparse.Namespace) -> None:
     if keep_classpath:
         desc["classpath"] = classpath  # only if verified importable

-    # Upsert via Project API
-    repo = ToolkitRepository(args.project)          # helper for lookups
-    existing = repo.getToolkitDocument(args.name)   # returns measurements doc or None
+    # Upsert via Project API (no ToolkitRepository)
+    proj = Project(projectName=args.project)
+    existing = proj.getMeasurementsDocuments(
+        type="ToolkitDataSource",
+        datasourceName=args.name,
+    )

     if existing and not args.overwrite:
         print(f"(exists) Toolkit '{args.name}' already present; use --overwrite to replace")
         return

-    # If exists and overwrite requested -> delete old doc
+    # If exists and overwrite requested -> delete old docs
     if existing:
-        try:
-            existing.delete()
-        except Exception:
-            pass
+        for d in existing:
+            try:
+                d.delete()
+            except Exception:
+                pass

-    # Insert a fresh measurements document with type="ToolkitDataSource"
-    proj = repo._project  # underlying Project
     proj.addMeasurementsDocument(
         type="ToolkitDataSource",
         dataFormat="Class",
+        resource=".",
         desc=desc,
     )
diff --git a/hera/utils/data/toolkit.py b/hera/utils/data/toolkit.py
index 79cdab1c..a644002c 100644
--- a/hera/utils/data/toolkit.py
+++ b/hera/utils/data/toolkit.py
@@ -1,39 +1,53 @@
 import json
 import argparse
-from hera import toolkitHome,toolkit
+import pathlib
+import os
+
 from hera.utils import loadJSON, dictToMongoQuery
 from hera.utils.logging import get_classMethod_logger
-import pathlib
-import os
+from hera.toolkit import abstractToolkit, ToolkitHome


-class dataToolkit(toolkit.abstractToolkit):
+class dataToolkit(abstractToolkit):
     """
-    A toolkit to handle the data (replacing the function of hera-data).
+    Toolkit for managing data repositories (replacing the old hera-data).
+
+    It is initialized only with the DEFAULT project.

-    It is initialized only with the DEFAULT project.
+    The structure of a datasource file is:

-    The structure of a datasource file is:
-    {
-        <toolkitName>: {
-                <datasourceName>: {
-                    "resource": <the resource>,
-                    "dataFormat": <the data format>,
-                    "desc": {
-                        < any other parameters/ metadata descriptions of the datasource>
+        {
+            "<toolkitName>": {
+                "<datasourceName>": {
+                    "resource": "<the resource>",
+                    "dataFormat": "<the data format>",
+                    "desc": {
+                        ... metadata ...
+                    }
                 },
-                .
-                .
-
+                ...
             },
-            .
-            .
-
-    }
+            ...
+        }
     """

     def __init__(self):
-        super().__init__(toolkitName="heradata", projectName=self.DEFAULTPROJECT, filesDirectory=None)
+        # DEFAULTPROJECT comes from the base Project class (inherited via abstractToolkit)
+        super().__init__(toolkitName="heradata",
+                         projectName=self.DEFAULTPROJECT,
+                         filesDirectory=None)
+

     def addRepository(self, repositoryName, repositoryPath, overwrite=False):
         """
@@ -114,37 +128,108 @@ def loadAllDatasourcesInRepositoryJSONToProject(self,
                                                     auto_register_missing: bool = True):
         """
-        Iterate through the repository JSON and for each toolkit:
-          - Ensure we can instantiate it (ToolkitHome.getToolkit).
-          - If missing and auto_register_missing=True, attempt auto-register using:
-              * Registry.classpath or Registry.cls in the JSON, or
-              * Toolkit document from DB.
-          - After we have the instance, proceed with per-toolkit loading logic.
+        Iterate through the repository JSON and for each toolkit:
+          - Try to get an instance via ToolkitHome.getToolkit.
+          - If missing and auto_register_missing=True, attempt auto-register ONLY if there is
+            a clear classpath hint in the JSON (Registry.classpath or Registry.cls).
+          - After we have a valid instance, dispatch to the appropriate handler per section.
""" logger = get_classMethod_logger(self, "loadAllDatasourcesInRepositoryJSONToProject") - handlerDict = dict(Config = self._handle_Config, - Datasource = self._handle_DataSource, - Measurements = lambda toolkit, itemName, docTypeDict, overwrite,basedir: self._DocumentHandler(toolkit, itemName, docTypeDict, overwrite,"Measurements",basedir), - Simulations = lambda toolkit, itemName, docTypeDict, overwrite,basedir: self._DocumentHandler(toolkit, itemName, docTypeDict, overwrite,"Simulations",basedir), - Cache = lambda toolkit, itemName, itemDesc, overwrite,basedir: self._DocumentHandler(toolkit, itemName, itemDesc, overwrite,"Cache",basedir), - Function = self._handle_Function) + handlerDict = dict( + Config=self._handle_Config, + Datasource=self._handle_DataSource, + Measurements=lambda toolkit, itemName, docTypeDict, overwrite, basedir: self._DocumentHandler( + toolkit, itemName, docTypeDict, overwrite, "Measurements", basedir + ), + Simulations=lambda toolkit, itemName, docTypeDict, overwrite, basedir: self._DocumentHandler( + toolkit, itemName, docTypeDict, overwrite, "Simulations", basedir + ), + Cache=lambda toolkit, itemName, itemDesc, overwrite, basedir: self._DocumentHandler( + toolkit, itemName, itemDesc, overwrite, "Cache", basedir + ), + Function=self._handle_Function, + ) + + tk_home = ToolkitHome(projectName=projectName) - # repositoryJSON is expected to be a dict mapping: toolkitName -> section for toolkitName, toolkitDict in (repositoryJSON or {}).items(): - toolkit = toolkitHome.getToolkit(toolkitName=toolkitName, projectName=projectName) + # 1) Try static/dynamic resolution via ToolkitHome.getToolkit + try: + toolkit = tk_home.getToolkit(toolkitName=toolkitName) + except Exception as e: + logger.info(f"Toolkit '{toolkitName}' not found via getToolkit: {e}") + toolkit = None + + # 2) Optional auto-registration, but only if there is a clear classpath hint + if toolkit is None and auto_register_missing: + registry = {} + classpath_hint = None + + if isinstance(toolkitDict, dict): + registry = toolkitDict.get("Registry", {}) or {} + if isinstance(registry, dict): + classpath_hint = registry.get("classpath") or registry.get("cls") + + if not classpath_hint: + # No classpath -> do not attempt auto-registration for this key + logger.info( + f"No classpath hint (Registry.classpath/cls) for key '{toolkitName}' in repository JSON; " + f"skipping auto-registration." + ) + else: + auto = getattr(tk_home, "auto_register_and_get", None) + if callable(auto): + try: + toolkit = auto( + toolkitName=toolkitName, + repositoryJSON=repositoryJSON, + repositoryName=None, + ) + logger.info( + f"Auto-registered toolkit '{toolkitName}' via auto_register_and_get " + f"using classpath '{classpath_hint}'" + ) + except Exception as e: + logger.error(f"Failed to auto-register toolkit '{toolkitName}': {e}") + # Skip this key but continue with others + continue + else: + logger.error("auto_register_and_get is not available on ToolkitHome") + continue + # 3) If we still do not have a toolkit instance, skip this key quietly + if toolkit is None: + logger.info( + f"Skipping key '{toolkitName}' in repository JSON – " + f"no matching toolkit and no auto-registration performed." 
+                )
+                continue
+
+            # 4) Dispatch sections (Config, Datasource, Measurements, Simulations, Cache, Function)
             for key, docTypeDict in toolkitDict.items():
                 logger.info(f"Loading document type {key} to toolkit {toolkitName}")
-                handler = handlerDict.get(key.title(),None)
+                handler = handlerDict.get(key.title(), None)
                 if handler is None:
-                    err = f"Unkonw Handler {key.title()}. The handler must be {','.join(handlerDict.keys())}. "
+                    err = (
+                        f"Unknown handler {key.title()}. "
+                        f"The handler must be one of {', '.join(handlerDict.keys())}."
+                    )
                     logger.error(err)
                     raise ValueError(err)
+
                 try:
-                    handler(toolkit=toolkit, itemName=key, docTypeDict=docTypeDict, overwrite=overwrite,basedir=basedir)
+                    handler(
+                        toolkit=toolkit,
+                        itemName=key,
+                        docTypeDict=docTypeDict,
+                        overwrite=overwrite,
+                        basedir=basedir,
+                    )
                 except Exception as e:
-                    err = f"The error {e} occured while adding *{key}* to toolkit {toolkitName}... skipping!!!"
+                    err = (
+                        f"The error {e} occurred while adding *{key}* to toolkit {toolkitName}; skipping."
+                    )
                     logger.error(err)
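
For reference, a minimal repository JSON that would exercise the auto-registration
path above (a Registry.classpath hint plus one datasource section) might look like
this; all names are hypothetical, and the exact per-section shape depends on the
per-type handlers that are outside this diff:

    {
        "myToolkit": {
            "Registry": {
                "classpath": "mypkg.toolkits.MyToolkit"
            },
            "Datasource": {
                "mySource": {
                    "resource": "data/mySource.parquet",
                    "dataFormat": "parquet",
                    "desc": {}
                }
            }
        }
    }

Note that the dispatch loop above raises ValueError for any key not present in
handlerDict, so a top-level "Registry" key would need to be filtered out before
dispatch for this shape to load cleanly.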