Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/fides/api/service/connectors/saas_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from fides.api.models.connectionconfig import ConnectionConfig, ConnectionTestStatus
from fides.api.models.policy import Policy
from fides.api.models.privacy_request import PrivacyRequest, RequestTask
from fides.api.models.privacy_request.request_task import AsyncTaskType
from fides.api.schemas.consentable_item import (
ConsentableItem,
build_consent_item_hierarchy,
Expand Down Expand Up @@ -277,11 +278,20 @@ def retrieve_data(

# Delegate async requests
with get_db() as db:
# Guard clause to ensure we only run async access requests for access requests
if self.guard_access_request(policy):
if async_dsr_strategy := _get_async_dsr_strategy(
db, request_task, query_config, ActionType.access
if async_dsr_strategy := _get_async_dsr_strategy(
db, request_task, query_config, ActionType.access
):
check_guard_access_request = self.guard_access_request(policy)
# Guard clause only applies to polling requests
# Callback requests should always proceed
if (async_dsr_strategy.type == AsyncTaskType.polling) and (
not check_guard_access_request
):
logger.info(
f"Skipping async access request for policy: {policy.name}"
)
return []
if check_guard_access_request:
return async_dsr_strategy.async_retrieve_data(
client=self.create_client(),
request_task_id=request_task.id,
Expand Down
66 changes: 66 additions & 0 deletions tests/fixtures/saas/test_data/saas_async_polling_config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
saas_config:
fides_key: saas_async_polling_config
name: Async Polling Example Custom Connector
type: async_polling_example
description: Test Async Polling Config
version: 0.0.1

connector_params:
- name: domain
- name: api_token
label: API token

client_config:
protocol: http
host: <domain>
authentication:
strategy: bearer
configuration:
token: <api_token>

test_request:
method: GET
path: /

endpoints:
- name: user
requests:
read:
method: GET
path: /api/v1/user
query_params:
- name: query
value: <email>
param_values:
- name: email
identity: email
correlation_id_path: request_id
async_config:
strategy: polling
configuration:
status_request:
method: GET
path: /api/v1/user/status
status_path: status
status_completed_value: completed
result_request:
method: GET
path: /api/v1/user/result
update:
method: DELETE
path: /api/v1/user/<user_id>
correlation_id_path: correlation_id
async_config:
strategy: polling
configuration:
status_request:
method: GET
path: /api/v1/user/<correlation_id>/status
status_path: status
status_completed_value: completed
param_values:
- name: user_id
references:
- dataset: saas_async_polling_config
field: user.id
direction: from
15 changes: 15 additions & 0 deletions tests/fixtures/saas/test_data/saas_async_polling_dataset.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
dataset:
- fides_key: saas_async_polling_config
name: async_polling_example
description: A sample dataset for async polling
collections:
- name: user
fields:
- name: id
data_categories: [user.unique_id]
fidesops_meta:
primary_key: True
- name: system_id
data_categories: [system]
- name: state
data_categories: [user.contact.address.state]
212 changes: 211 additions & 1 deletion tests/ops/service/connectors/test_saas_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from starlette.status import HTTP_200_OK, HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND

from fides.api.common_exceptions import (
AwaitingAsyncProcessing,
AwaitingAsyncTask,
ClientUnsuccessfulException,
ConnectionException,
Expand All @@ -22,7 +23,7 @@
from fides.api.graph.graph import DatasetGraph, Node
from fides.api.graph.traversal import Traversal, TraversalNode
from fides.api.models.consent_automation import ConsentAutomation
from fides.api.models.policy import Policy
from fides.api.models.policy import ActionType, Policy, Rule, RuleTarget
from fides.api.models.privacy_notice import UserConsentPreference
from fides.api.models.privacy_request import PrivacyRequest, RequestTask
from fides.api.models.worker_task import ExecutionLogStatus
Expand Down Expand Up @@ -1121,6 +1122,23 @@ def async_graph(self, saas_example_async_dataset_config, db, privacy_request):
db, privacy_request, traversal_nodes, end_nodes, graph
)

@pytest.fixture(scope="function")
def async_graph_polling(
self, saas_async_polling_example_dataset_config, db, privacy_request
):
# Build proper async graph with persisted request tasks for polling tests
async_graph = saas_async_polling_example_dataset_config.get_graph()
graph = DatasetGraph(async_graph)
traversal = Traversal(graph, {"email": "customer-1@example.com"})
traversal_nodes = {}
end_nodes = traversal.traverse(traversal_nodes, collect_tasks_fn)
persist_new_access_request_tasks(
db, privacy_request, traversal, traversal_nodes, end_nodes, graph
)
persist_initial_erasure_request_tasks(
db, privacy_request, traversal_nodes, end_nodes, graph
)

@mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send")
def test_read_request_expects_async_results(
self,
Expand Down Expand Up @@ -1306,3 +1324,195 @@ def test_callback_succeeded_mask_data(
)
== 5
)

@mock.patch(
"fides.api.service.connectors.saas_connector.SaaSConnector.create_client"
)
def test_guard_access_request_with_access_policy(
self,
mock_create_client,
privacy_request,
saas_async_polling_example_connection_config,
async_graph_polling,
):
"""
Test that guard_access_request allows async access requests to run
when the policy has access rules (access request scenario).
"""
connector: SaaSConnector = get_connector(
saas_async_polling_example_connection_config
)
mock_create_client.return_value = mock.MagicMock()

# Get access request task
request_task = privacy_request.access_tasks.filter(
RequestTask.collection_name == "user"
).first()
execution_node = ExecutionNode(request_task)

# Policy has access rules, so guard should return True and async_retrieve_data should be called
with pytest.raises(AwaitingAsyncProcessing):
connector.retrieve_data(
execution_node,
privacy_request.policy,
privacy_request,
request_task,
{},
)

@mock.patch(
"fides.api.service.connectors.saas_connector.SaaSConnector.create_client"
)
def test_guard_access_request_with_erasure_only_policy(
self,
mock_create_client,
db,
privacy_request,
saas_async_polling_example_connection_config,
async_graph_polling,
oauth_client,
):
"""
Test that guard_access_request skips async access requests
when the policy has no access rules (erasure-only request scenario).
This test ensures coverage of the logger.info and return [] lines.
Uses polling async strategy to test the guard clause.
"""
# Create an erasure-only policy (no access rules)
erasure_only_policy = Policy.create(
db=db,
data={
"name": "Erasure Only Policy",
"key": "erasure_only_policy_test",
"client_id": oauth_client.id,
},
)

erasure_rule = Rule.create(
db=db,
data={
"action_type": ActionType.erasure,
"name": "Erasure Rule",
"key": "erasure_rule_test",
"policy_id": erasure_only_policy.id,
"masking_strategy": {
"strategy": "null_rewrite",
"configuration": {},
},
"client_id": oauth_client.id,
},
)

RuleTarget.create(
db=db,
data={
"data_category": "user.name",
"rule_id": erasure_rule.id,
"client_id": oauth_client.id,
},
)

connector: SaaSConnector = get_connector(
saas_async_polling_example_connection_config
)

# Get access request task
request_task = privacy_request.access_tasks.filter(
RequestTask.collection_name == "user"
).first()
execution_node = ExecutionNode(request_task)

# Verify guard_access_request returns False for erasure-only policy
assert connector.guard_access_request(erasure_only_policy) is False

result = connector.retrieve_data(
execution_node,
erasure_only_policy,
privacy_request,
request_task,
{},
)

# Should return empty list without calling async_retrieve_data
assert result == []

@mock.patch(
"fides.api.service.connectors.saas_connector.SaaSConnector.create_client"
)
def test_callback_requests_ignore_guard_clause(
self,
mock_create_client,
db,
privacy_request,
saas_async_example_connection_config,
async_graph,
oauth_client,
):
"""
Test that callback requests ignore the guard clause entirely.
Even if guard_access_request returns False (erasure-only policy),
callback requests should still proceed and raise AwaitingAsyncTask.
"""
# Create an erasure-only policy (no access rules)
erasure_only_policy = Policy.create(
db=db,
data={
"name": "Erasure Only Policy Callback Test",
"key": "erasure_only_policy_callback_test",
"client_id": oauth_client.id,
},
)

erasure_rule = Rule.create(
db=db,
data={
"action_type": ActionType.erasure,
"name": "Erasure Rule Callback Test",
"key": "erasure_rule_callback_test",
"policy_id": erasure_only_policy.id,
"masking_strategy": {
"strategy": "null_rewrite",
"configuration": {},
},
"client_id": oauth_client.id,
},
)

RuleTarget.create(
db=db,
data={
"data_category": "user.name",
"rule_id": erasure_rule.id,
"client_id": oauth_client.id,
},
)

connector: SaaSConnector = get_connector(saas_async_example_connection_config)
# Mock the client and its send method to allow async callback flow
mock_client = mock.MagicMock()
mock_send_response = mock.MagicMock()
mock_send_response.json.return_value = {"id": "123"}
mock_client.send.return_value = mock_send_response
mock_create_client.return_value = mock_client

# Get access request task
request_task = privacy_request.access_tasks.filter(
RequestTask.collection_name == "user"
).first()
execution_node = ExecutionNode(request_task)

# Verify guard_access_request returns False for erasure-only policy
assert connector.guard_access_request(erasure_only_policy) is False

# Even though guard_access_request returns False, callback requests
# should ignore the guard and always proceed with common requests.
connector.retrieve_data(
execution_node,
erasure_only_policy,
privacy_request,
request_task,
{},
)

# Verify that the async callback flow was triggered (client.send was called)
assert mock_client.send.called
Loading