diff --git a/nhs_aws_helpers/dynamodb_model_store/base_model_store.py b/nhs_aws_helpers/dynamodb_model_store/base_model_store.py index 9669729..80f8fb6 100644 --- a/nhs_aws_helpers/dynamodb_model_store/base_model_store.py +++ b/nhs_aws_helpers/dynamodb_model_store/base_model_store.py @@ -561,6 +561,25 @@ async def paginate_models( last_evaluated_key=last_evaluated_key, ) + async def query_models_from_index( + self, + model_type: Type[TBaseModel_co], + index_name: str, + max_concurrency=10, + **kwargs, + ) -> PagedItems[TBaseModel_co]: + items, last_evaluated_key = await self.query_items(IndexName=index_name, **kwargs) + model_keys = [model_type.model_key_from_item(item) for item in items] + models = [ + cast(TBaseModel_co, record) + for record in await self.batch_get_model(model_keys, max_concurrency=max_concurrency) + if record + ] + return PagedItems( + items=models, + last_evaluated_key=last_evaluated_key, + ) + async def paginate_models_from_index( self, paginator_type: Literal["query", "scan", "list_backups", "list_tables"], diff --git a/tests/base_model_store_tests.py b/tests/base_model_store_tests.py index d802fe4..699c4c9 100644 --- a/tests/base_model_store_tests.py +++ b/tests/base_model_store_tests.py @@ -4,7 +4,18 @@ from dataclasses import dataclass, field from datetime import date, datetime from enum import Enum -from typing import Any, Dict, Generator, List, Mapping, Optional, Type, TypedDict, Union, cast +from typing import ( + Any, + Dict, + Generator, + List, + Mapping, + Optional, + Type, + TypedDict, + Union, + cast, +) from uuid import uuid4 import petname # type: ignore[import] @@ -744,6 +755,24 @@ async def test_paginate_models_from_index(store: MyModelStore): assert all(isinstance(model, AnotherModel) for model in page.items) +async def test_query_models_from_index(store: MyModelStore): + async with store.batch_writer() as writer: + for i in range(100): + await writer.put_item(AnotherModel(id="1", sk_field=f"SK:{i}")) + + + response = await store.query_models_from_index( + model_type=AnotherModel, + index_name="gsi_model_type", + KeyConditionExpression=Key("model_type").eq(AnotherModel.__name__), + Limit=40, + ) + + assert len(response.items) == 40 + assert all(isinstance(model, AnotherModel) for model in response.items) + assert response.last_evaluated_key == {"model_type": "AnotherModel", "my_pk": "BB#1", "my_sk": "SK:47"} + + async def test_paged_items(): has_size = PagedItems(items=[1, 2, 3], last_evaluated_key={"my_pk": 123}) assert has_size