diff --git a/.pylintrc b/.pylintrc index 780195824..8f3bb63eb 100644 --- a/.pylintrc +++ b/.pylintrc @@ -29,5 +29,4 @@ disable = no-value-for-parameter, unexpected-keyword-arg, inconsistent-return-statements, - duplicate-code, - attribute-defined-outside-init, + duplicate-code diff --git a/aikido_zen/background_process/cloud_connection_manager/__init__.py b/aikido_zen/background_process/cloud_connection_manager/__init__.py index ea92f5074..d6f3fc16f 100644 --- a/aikido_zen/background_process/cloud_connection_manager/__init__.py +++ b/aikido_zen/background_process/cloud_connection_manager/__init__.py @@ -35,13 +35,7 @@ def __init__(self, block, api, token, serverless): self.token = token # Should be instance of the Token class! self.routes = Routes(200) self.hostnames = Hostnames(200) - self.conf = ServiceConfig( - endpoints=[], - last_updated_at=-1, # Has not been updated yet - blocked_uids=[], - bypassed_ips=[], - received_any_stats=True, - ) + self.conf = ServiceConfig() self.firewall_lists = FirewallLists() self.rate_limiter = RateLimiter( max_items=5000, time_to_live_in_ms=120 * 60 * 1000 # 120 minutes diff --git a/aikido_zen/background_process/cloud_connection_manager/update_service_config.py b/aikido_zen/background_process/cloud_connection_manager/update_service_config.py index e09f2c3d1..165f88c7d 100644 --- a/aikido_zen/background_process/cloud_connection_manager/update_service_config.py +++ b/aikido_zen/background_process/cloud_connection_manager/update_service_config.py @@ -15,13 +15,14 @@ def update_service_config(connection_manager, res): logger.debug("Updating blocking, setting blocking to : %s", res["block"]) connection_manager.block = bool(res["block"]) - connection_manager.conf.update( - endpoints=res.get("endpoints", []), - last_updated_at=res.get("configUpdatedAt", get_unixtime_ms()), - blocked_uids=res.get("blockedUserIds", []), - bypassed_ips=res.get("allowedIPAddresses", []), - received_any_stats=res.get("receivedAnyStats", True), + connection_manager.conf.set_endpoints(res.get("endpoints", [])) + connection_manager.conf.set_last_updated_at( + res.get("configUpdatedAt", get_unixtime_ms()) ) + connection_manager.conf.set_blocked_user_ids(res.get("blockedUserIds", [])) + connection_manager.conf.set_bypassed_ips(res.get("allowedIPAddresses", [])) + if res.get("receivedAnyStats", True): + connection_manager.conf.enable_received_any_stats() # Handle outbound request blocking configuration if "blockNewOutgoingRequests" in res: diff --git a/aikido_zen/background_process/cloud_connection_manager/update_service_config_test.py b/aikido_zen/background_process/cloud_connection_manager/update_service_config_test.py index be6ddfed7..178d0a3b6 100644 --- a/aikido_zen/background_process/cloud_connection_manager/update_service_config_test.py +++ b/aikido_zen/background_process/cloud_connection_manager/update_service_config_test.py @@ -11,13 +11,7 @@ def test_update_service_config_outbound_blocking(): # Create a mock connection manager with a real ServiceConfig connection_manager = MagicMock() - connection_manager.conf = ServiceConfig( - endpoints=[], - last_updated_at=0, - blocked_uids=set(), - bypassed_ips=[], - received_any_stats=False, - ) + connection_manager.conf = ServiceConfig() connection_manager.block = False # Test response with blockNewOutgoingRequests @@ -45,13 +39,7 @@ def test_update_service_config_outbound_blocking_false(): # Create a mock connection manager with a real ServiceConfig connection_manager = MagicMock() - connection_manager.conf = ServiceConfig( - endpoints=[], - last_updated_at=0, - blocked_uids=set(), - bypassed_ips=[], - received_any_stats=False, - ) + connection_manager.conf = ServiceConfig() connection_manager.block = True # Test response with blockNewOutgoingRequests=False @@ -69,13 +57,7 @@ def test_update_service_config_outbound_blocking_missing(): # Create a mock connection manager with a real ServiceConfig connection_manager = MagicMock() - connection_manager.conf = ServiceConfig( - endpoints=[], - last_updated_at=0, - blocked_uids=set(), - bypassed_ips=[], - received_any_stats=False, - ) + connection_manager.conf = ServiceConfig() connection_manager.block = False # Test response without outbound blocking fields @@ -97,13 +79,7 @@ def test_update_service_config_failure(): # Create a mock connection manager with a real ServiceConfig connection_manager = MagicMock() - connection_manager.conf = ServiceConfig( - endpoints=[], - last_updated_at=0, - blocked_uids=set(), - bypassed_ips=[], - received_any_stats=False, - ) + connection_manager.conf = ServiceConfig() connection_manager.block = False # Set initial values @@ -127,13 +103,7 @@ def test_update_service_config_complete(): # Create a mock connection manager with a real ServiceConfig connection_manager = MagicMock() - connection_manager.conf = ServiceConfig( - endpoints=[], - last_updated_at=0, - blocked_uids=set(), - bypassed_ips=[], - received_any_stats=False, - ) + connection_manager.conf = ServiceConfig() connection_manager.block = False # Test complete response @@ -174,13 +144,7 @@ def test_update_service_config_domains_only(): # Create a mock connection manager with a real ServiceConfig connection_manager = MagicMock() - connection_manager.conf = ServiceConfig( - endpoints=[], - last_updated_at=0, - blocked_uids=set(), - bypassed_ips=[], - received_any_stats=False, - ) + connection_manager.conf = ServiceConfig() connection_manager.block = False # Test response with only domains @@ -207,13 +171,7 @@ def test_update_service_config_block_new_outgoing_requests_only(): # Create a mock connection manager with a real ServiceConfig connection_manager = MagicMock() - connection_manager.conf = ServiceConfig( - endpoints=[], - last_updated_at=0, - blocked_uids=set(), - bypassed_ips=[], - received_any_stats=False, - ) + connection_manager.conf = ServiceConfig() connection_manager.block = False # Set initial domains diff --git a/aikido_zen/background_process/service_config.py b/aikido_zen/background_process/service_config.py index 283369d7f..c94715699 100644 --- a/aikido_zen/background_process/service_config.py +++ b/aikido_zen/background_process/service_config.py @@ -6,42 +6,18 @@ from aikido_zen.helpers.match_endpoints import match_endpoints -# noinspection PyAttributeOutsideInit class ServiceConfig: - """Class holding the config of the connection_manager""" - - def __init__( - self, - endpoints, - last_updated_at: int, - blocked_uids, - bypassed_ips, - received_any_stats: bool, - ): - # Init the class using update function : - self.update( - endpoints, last_updated_at, blocked_uids, bypassed_ips, received_any_stats - ) + + def __init__(self): + self.endpoints = [] + self.bypassed_ips = IPMatcher() + self.blocked_uids = set() + self.last_updated_at = -1 + self.received_any_stats = False self.block_new_outgoing_requests = False self.outbound_domains = {} - def update( - self, - endpoints, - last_updated_at: int, - blocked_uids, - bypassed_ips, - received_any_stats: bool, - ): - self.last_updated_at = last_updated_at - self.received_any_stats = bool(received_any_stats) - self.blocked_uids = set(blocked_uids) - self.set_endpoints(endpoints) - self.set_bypassed_ips(bypassed_ips) - def set_endpoints(self, endpoints): - """Sets non-graphql endpoints""" - self.endpoints = [ endpoint for endpoint in endpoints if not endpoint.get("graphql") ] @@ -68,7 +44,6 @@ def get_endpoints(self, route_metadata): return match_endpoints(route_metadata, self.endpoints) def set_bypassed_ips(self, bypassed_ips): - """Creates an IPMatcher from the given bypassed ip set""" self.bypassed_ips = IPMatcher() for ip in bypassed_ips: self.bypassed_ips.add(ip) @@ -77,6 +52,15 @@ def is_bypassed_ip(self, ip): """Checks if the IP is on the bypass list""" return self.bypassed_ips.has(ip) + def set_blocked_user_ids(self, blocked_user_ids): + self.blocked_uids = set(blocked_user_ids) + + def enable_received_any_stats(self): + self.received_any_stats = True + + def set_last_updated_at(self, last_updated_at: int): + self.last_updated_at = last_updated_at + def update_outbound_domains(self, domains): self.outbound_domains = { domain["hostname"]: domain["mode"] for domain in domains diff --git a/aikido_zen/background_process/service_config_test.py b/aikido_zen/background_process/service_config_test.py index d34611693..f036cc38d 100644 --- a/aikido_zen/background_process/service_config_test.py +++ b/aikido_zen/background_process/service_config_test.py @@ -5,13 +5,7 @@ def test_service_config_outbound_blocking_initialization(): """Test that ServiceConfig initializes outbound blocking fields correctly""" - config = ServiceConfig( - endpoints=[], - last_updated_at=0, - blocked_uids=set(), - bypassed_ips=[], - received_any_stats=False, - ) + config = ServiceConfig() # Test initial values assert hasattr(config, "block_new_outgoing_requests") @@ -22,13 +16,7 @@ def test_service_config_outbound_blocking_initialization(): def test_service_config_set_block_new_outgoing_requests(): """Test the set_block_new_outgoing_requests method""" - config = ServiceConfig( - endpoints=[], - last_updated_at=0, - blocked_uids=set(), - bypassed_ips=[], - received_any_stats=False, - ) + config = ServiceConfig() # Test setting to True config.set_block_new_outgoing_requests(True) @@ -54,13 +42,7 @@ def test_service_config_set_block_new_outgoing_requests(): def test_service_config_update_domains(): """Test the update_outbound_domains method""" - config = ServiceConfig( - endpoints=[], - last_updated_at=0, - blocked_uids=set(), - bypassed_ips=[], - received_any_stats=False, - ) + config = ServiceConfig() # Test initial state assert config.outbound_domains == {} @@ -89,13 +71,7 @@ def test_service_config_update_domains(): def test_service_config_should_block_outgoing_request(): """Test the should_block_outgoing_request method""" - config = ServiceConfig( - endpoints=[], - last_updated_at=0, - blocked_uids=set(), - bypassed_ips=[], - received_any_stats=False, - ) + config = ServiceConfig() # Test with block_new_outgoing_requests = False (default) # Only block if mode is "block" @@ -138,6 +114,7 @@ def test_service_config_should_block_outgoing_request(): def test_service_config_initialization(): + service_config = ServiceConfig() endpoints = [ { "graphql": False, @@ -185,26 +162,33 @@ def test_service_config_initialization(): "force_protection_off": False, }, ] - last_updated_at = "2023-10-01" - service_config = ServiceConfig( - endpoints, - last_updated_at, - ["0", "0", "1", "5"], - ["127.0.0.1", "123.1.2.0/24", "132.1.0.0/16"], - True, - ) - # Check that non-GraphQL endpoints are correctly filtered - assert len(service_config.endpoints) == 3 + assert len(service_config.endpoints) == 0 + service_config.set_endpoints(endpoints) + assert ( + len(service_config.endpoints) == 3 + ) # Check that non-GraphQL endpoints are correctly filtered assert service_config.endpoints[0]["route"] == "/v1" assert service_config.endpoints[1]["route"] == "/v3" assert service_config.endpoints[2]["route"] == "/admin" - assert service_config.last_updated_at == last_updated_at + + service_config.set_last_updated_at(37982562953) + assert service_config.last_updated_at == 37982562953 + + assert isinstance(service_config.bypassed_ips, IPMatcher) + service_config.set_bypassed_ips(["127.0.0.1", "123.1.2.0/24", "132.1.0.0/16"]) assert isinstance(service_config.bypassed_ips, IPMatcher) assert service_config.bypassed_ips.has("127.0.0.1") assert service_config.bypassed_ips.has("123.1.2.2") assert not service_config.bypassed_ips.has("1.1.1.1") - assert service_config.blocked_uids == set(["1", "0", "5"]) + + assert len(service_config.blocked_uids) == 0 + service_config.set_blocked_user_ids({"0", "0", "1", "5"}) + assert service_config.blocked_uids == {"1", "0", "5"} + + assert not service_config.received_any_stats + service_config.enable_received_any_stats() + assert service_config.received_any_stats == True v1_endpoint = service_config.get_endpoints( { @@ -230,41 +214,9 @@ def test_service_config_initialization(): assert not admin_endpoint["allowedIPAddresses"].has("192.168.0.1") -# Sample data for testing -sample_endpoints = [ - {"url": "http://example.com/api/v1", "graphql": False, "context": "user"}, - {"url": "http://example.com/api/v2", "graphql": True, "context": "admin"}, - {"url": "http://example.com/api/v3", "graphql": False, "context": "guest"}, -] - - -@pytest.fixture -def service_config(): - return ServiceConfig( - endpoints=sample_endpoints, - last_updated_at="2023-10-01T00:00:00Z", - blocked_uids=["user1", "user2"], - bypassed_ips=["192.168.1.1", "10.0.0.1"], - received_any_stats=True, - ) - - -def test_initialization(service_config): - assert len(service_config.endpoints) == 2 # Only non-graphql endpoints - assert service_config.last_updated_at == "2023-10-01T00:00:00Z" - assert isinstance(service_config.bypassed_ips, IPMatcher) - assert service_config.blocked_uids == {"user1", "user2"} - - def test_ip_blocking(): - config = ServiceConfig( - endpoints=sample_endpoints, - last_updated_at="2023-10-01T00:00:00Z", - blocked_uids=["user1", "user2"], - bypassed_ips=["192.168.1.1", "10.0.0.0/16", "::1/128"], - received_any_stats=True, - ) - + config = ServiceConfig() + config.set_bypassed_ips(["192.168.1.1", "10.0.0.0/16", "::1/128"]) assert config.is_bypassed_ip("192.168.1.1") assert config.is_bypassed_ip("10.0.0.1") assert config.is_bypassed_ip("10.0.1.2") @@ -276,38 +228,39 @@ def test_ip_blocking(): def test_service_config_with_empty_allowlist(): - endpoints = [ - { - "graphql": False, - "method": "GET", - "route": "/admin", - "rate_limiting": { - "enabled": False, - "max_requests": 10, - "window_size_in_ms": 1000, - }, - "allowedIPAddresses": [], - "force_protection_off": False, - }, - ] - last_updated_at = "2023-10-01" - service_config = ServiceConfig( - endpoints, - last_updated_at, - ["0", "0", "1", "5"], - ["127.0.0.1", "123.1.2.0/24", "132.1.0.0/16"], - True, - ) + service_config = ServiceConfig() # Check that non-GraphQL endpoints are correctly filtered + service_config.set_endpoints( + [ + { + "graphql": False, + "method": "GET", + "route": "/admin", + "rate_limiting": { + "enabled": False, + "max_requests": 10, + "window_size_in_ms": 1000, + }, + "allowedIPAddresses": [], + "force_protection_off": False, + }, + ] + ) assert len(service_config.endpoints) == 1 assert service_config.endpoints[0]["route"] == "/admin" - assert service_config.last_updated_at == last_updated_at + + service_config.set_last_updated_at(29839537) + assert service_config.last_updated_at == 29839537 + + service_config.set_blocked_user_ids({"0", "0", "1", "5"}) + assert service_config.blocked_uids == {"1", "0", "5"} + + service_config.set_bypassed_ips(["127.0.0.1", "123.1.2.0/24", "132.1.0.0/16"]) assert isinstance(service_config.bypassed_ips, IPMatcher) - assert service_config.bypassed_ips.has("127.0.0.1") - assert service_config.bypassed_ips.has("123.1.2.2") - assert not service_config.bypassed_ips.has("1.1.1.1") - assert service_config.blocked_uids == set(["1", "0", "5"]) + assert service_config.is_bypassed_ip("127.0.0.1") + assert service_config.is_bypassed_ip("123.1.2.2") + assert not service_config.is_bypassed_ip("1.1.1.1") admin_endpoint = service_config.get_endpoints( { diff --git a/aikido_zen/context/__init__.py b/aikido_zen/context/__init__.py index bb8eeb3b3..a1c7310bd 100644 --- a/aikido_zen/context/__init__.py +++ b/aikido_zen/context/__init__.py @@ -50,7 +50,7 @@ def __init__(self, context_obj=None, body=None, req=None, source=None): self.parsed_userinput = {} self.xml = {} self.outgoing_req_redirects = [] - self.set_body(body) + self.body = Context.parse_body_object(body) self.headers: Headers = Headers() self.cookies = {} self.query = {} @@ -107,26 +107,28 @@ def set_cookies(self, cookies): self.cookies = cookies def set_body(self, body): + self.body = Context.parse_body_object(body) + + @staticmethod + def parse_body_object(body): + """Parses the body and checks if it's possibly JSON""" try: - self.set_body_internal(body) + if isinstance(body, (str, bytes)) and len(body) == 0: + # Make sure that empty bodies like b"" don't get sent. + return None + if isinstance(body, bytes): + body = body.decode("utf-8") # Decode byte input to string. + if not isinstance(body, str): + return body + if body.strip()[0] in ["{", "[", '"']: + # Might be JSON, but might not have been parsed correctly by server because of wrong headers + parsed_body = json.loads(body) + if parsed_body: + return parsed_body + return body except Exception as e: - logger.debug("Exception occurred whilst setting body: %s", e) - - def set_body_internal(self, body): - """Sets the body and checks if it's possibly JSON""" - self.body = body - if isinstance(self.body, (str, bytes)) and len(body) == 0: - # Make sure that empty bodies like b"" don't get sent. - self.body = None - if isinstance(self.body, bytes): - self.body = self.body.decode("utf-8") # Decode byte input to string. - if not isinstance(self.body, str): - return - if self.body.strip()[0] in ["{", "[", '"']: - # Might be JSON, but might not have been parsed correctly by server because of wrong headers - parsed_body = json.loads(self.body) - if parsed_body: - self.body = parsed_body + logger.debug("Exception occurred whilst parsing body: %s", e) + return body def get_route_metadata(self): """Returns a route_metadata object""" diff --git a/aikido_zen/ratelimiting/init_test.py b/aikido_zen/ratelimiting/init_test.py index 966d8c7ac..017db8124 100644 --- a/aikido_zen/ratelimiting/init_test.py +++ b/aikido_zen/ratelimiting/init_test.py @@ -17,13 +17,10 @@ def user(): def create_connection_manager(endpoints=[], bypassed_ips=[]): cm = MagicMock() - cm.conf = ServiceConfig( - endpoints=endpoints, - last_updated_at=1, - blocked_uids=[], - bypassed_ips=bypassed_ips, - received_any_stats=True, - ) + cm.conf = ServiceConfig() + cm.conf.set_endpoints(endpoints) + cm.conf.enable_received_any_stats() + cm.conf.set_bypassed_ips(bypassed_ips) cm.rate_limiter = RateLimiter( max_items=5000, time_to_live_in_ms=120 * 60 * 1000 # 120 minutes ) diff --git a/aikido_zen/sinks/tests/socket_test.py b/aikido_zen/sinks/tests/socket_test.py index 6e8bfa45e..a7ebc2d56 100644 --- a/aikido_zen/sinks/tests/socket_test.py +++ b/aikido_zen/sinks/tests/socket_test.py @@ -119,13 +119,7 @@ def test_socket_getaddrinfo_no_cache(): def test_service_config_should_block_outgoing_request(): """Test the should_block_outgoing_request method""" - config = ServiceConfig( - endpoints=[], - last_updated_at=0, - blocked_uids=set(), - bypassed_ips=[], - received_any_stats=False, - ) + config = ServiceConfig() # Test with no blocking configured assert not config.should_block_outgoing_request("example.com") @@ -148,13 +142,7 @@ def test_service_config_should_block_outgoing_request(): def test_service_config_update_domains(): """Test the update_outbound_domains method""" - config = ServiceConfig( - endpoints=[], - last_updated_at=0, - blocked_uids=set(), - bypassed_ips=[], - received_any_stats=False, - ) + config = ServiceConfig() # Test initial state assert config.outbound_domains == {} @@ -175,13 +163,7 @@ def test_service_config_update_domains(): def test_service_config_set_block_new_outgoing_requests(): """Test the set_block_new_outgoing_requests method""" - config = ServiceConfig( - endpoints=[], - last_updated_at=0, - blocked_uids=set(), - bypassed_ips=[], - received_any_stats=False, - ) + config = ServiceConfig() # Test initial state assert not config.block_new_outgoing_requests diff --git a/aikido_zen/sources/functions/request_handler_test.py b/aikido_zen/sources/functions/request_handler_test.py index 50503d4ad..dc8c3f3ad 100644 --- a/aikido_zen/sources/functions/request_handler_test.py +++ b/aikido_zen/sources/functions/request_handler_test.py @@ -154,19 +154,16 @@ def set_context(remote_address, user_agent="", route="/posts/:number"): def create_service_config(): - config = ServiceConfig( - endpoints=[ + config = ServiceConfig() + config.set_endpoints( + [ { "method": "POST", "route": "/posts/:number", "graphql": False, "allowedIPAddresses": ["1.1.1.1", "2.2.2.2", "3.3.3.3"], } - ], - last_updated_at=None, - blocked_uids=set(), - bypassed_ips=[], - received_any_stats=False, + ] ) get_cache().config = config return config diff --git a/aikido_zen/thread/thread_cache.py b/aikido_zen/thread/thread_cache.py index 0072bac3c..10f3f12e1 100644 --- a/aikido_zen/thread/thread_cache.py +++ b/aikido_zen/thread/thread_cache.py @@ -17,11 +17,13 @@ class ThreadCache: """ def __init__(self): + self.config = ServiceConfig() + self.middleware_installed = False + self.routes = Routes(max_size=1000) self.hostnames = Hostnames(200) self.users = Users(1000) self.stats = Statistics() self.ai_stats = AIStatistics() - self.reset() # Initialize values def is_bypassed_ip(self, ip): """Checks the given IP against the list of bypassed ips""" @@ -35,16 +37,9 @@ def get_endpoints(self): return self.config.endpoints def reset(self): - """Empties out all values of the cache""" - self.routes = Routes(max_size=1000) - self.config = ServiceConfig( - endpoints=[], - blocked_uids=set(), - bypassed_ips=[], - last_updated_at=-1, - received_any_stats=False, - ) self.middleware_installed = False + self.config = ServiceConfig() + self.routes.clear() self.hostnames.clear() self.users.clear() self.stats.clear() diff --git a/aikido_zen/thread/thread_cache_test.py b/aikido_zen/thread/thread_cache_test.py index 78e69cecb..968470bd2 100644 --- a/aikido_zen/thread/thread_cache_test.py +++ b/aikido_zen/thread/thread_cache_test.py @@ -153,26 +153,27 @@ def increment_in_thread(): def test_parses_routes_correctly(mock_get_comms, thread_cache: ThreadCache): """Test renewing the cache multiple times if TTL has expired.""" mock_get_comms.return_value = MagicMock() + service_config = ServiceConfig() + service_config.set_endpoints( + [ + { + "graphql": False, + "method": "POST", + "route": "/v2", + "rate_limiting": { + "enabled": False, + }, + "force_protection_off": False, + } + ] + ) + service_config.set_bypassed_ips(["192.168.1.1"]) + service_config.set_blocked_user_ids({"user123"}) + service_config.enable_received_any_stats() mock_get_comms.return_value.send_data_to_bg_process.return_value = { "success": True, "data": { - "config": ServiceConfig( - endpoints=[ - { - "graphql": False, - "method": "POST", - "route": "/v2", - "rate_limiting": { - "enabled": False, - }, - "force_protection_off": False, - } - ], - bypassed_ips=["192.168.1.1"], - blocked_uids={"user123"}, - last_updated_at=-1, - received_any_stats=True, - ), + "config": service_config, "routes": { "POST:/body": { "method": "POST",