From 6a1fadd087dc90fb83b8097c059eee1fb889507b Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Tue, 13 Jan 2026 09:43:27 +0100 Subject: [PATCH] Add more validation when creating class proxies. --- bytebuddy-proxy-support/build.gradle.kts | 3 + .../proxysupport/ByteBuddyProxyFactory.java | 124 ++++++++++++------ .../ByteBuddyProxyFactoryTest.java | 102 ++++++++++++++ .../ReflectionServiceDefinitionFactory.java | 28 ++-- 4 files changed, 207 insertions(+), 50 deletions(-) create mode 100644 bytebuddy-proxy-support/src/test/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactoryTest.java diff --git a/bytebuddy-proxy-support/build.gradle.kts b/bytebuddy-proxy-support/build.gradle.kts index 902703ce..837261a9 100644 --- a/bytebuddy-proxy-support/build.gradle.kts +++ b/bytebuddy-proxy-support/build.gradle.kts @@ -13,6 +13,9 @@ dependencies { implementation(project(":common")) implementation(libs.bytebuddy) implementation(libs.objenesis) + + testImplementation(libs.junit.jupiter) + testImplementation(libs.assertj) } tasks.withType { isFailOnError = false } diff --git a/bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java b/bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java index 738210de..53accf67 100644 --- a/bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java +++ b/bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java @@ -11,17 +11,20 @@ import static net.bytebuddy.matcher.ElementMatchers.*; import dev.restate.common.reflections.ProxyFactory; +import dev.restate.common.reflections.ReflectionUtils; import dev.restate.sdk.annotation.Exclusive; import dev.restate.sdk.annotation.Handler; import dev.restate.sdk.annotation.Shared; import dev.restate.sdk.annotation.Workflow; import java.lang.reflect.Field; +import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import net.bytebuddy.ByteBuddy; import net.bytebuddy.TypeCache; import net.bytebuddy.description.modifier.Visibility; import net.bytebuddy.dynamic.scaffold.TypeValidation; +import net.bytebuddy.implementation.ExceptionMethod; import net.bytebuddy.implementation.InvocationHandlerAdapter; import org.jspecify.annotations.Nullable; import org.objenesis.Objenesis; @@ -39,6 +42,7 @@ public final class ByteBuddyProxyFactory implements ProxyFactory { private static final String INTERCEPTOR_FIELD_NAME = "$$interceptor$$"; private final Objenesis objenesis = new ObjenesisStd(); + private final ByteBuddy byteBuddy = new ByteBuddy().with(TypeValidation.ENABLED); private final TypeCache> proxyClassCache = new TypeCache.WithInlineExpunction<>(TypeCache.Sort.SOFT); @@ -62,8 +66,24 @@ public final class ByteBuddyProxyFactory implements ProxyFactory { // Set the interceptor field Field interceptorField = proxyClass.getDeclaredField(INTERCEPTOR_FIELD_NAME); - interceptorField.setAccessible(true); - interceptorField.set(proxyInstance, interceptor); + interceptorField.set( + proxyInstance, + (InvocationHandler) + (proxy, method, args) -> { + MethodInvocation invocation = + new MethodInvocation() { + @Override + public Object[] getArguments() { + return args != null ? args : new Object[0]; + } + + @Override + public Method getMethod() { + return method; + } + }; + return interceptor.invoke(invocation); + }); return proxyInstance; } catch (Exception e) { @@ -71,55 +91,81 @@ public final class ByteBuddyProxyFactory implements ProxyFactory { } } - private Class generateProxyClass(Class clazz) { - ByteBuddy byteBuddy = new ByteBuddy().with(TypeValidation.ENABLED); + private Class generateProxyClass(Class clazz) throws NoSuchFieldException { + if (!clazz.isInterface()) { + // We perform here some additional validation of the handlers that won't be executed by + // bytebuddy and can easily lead to strange behavior + var methods = + ReflectionUtils.getUniqueDeclaredMethods( + clazz, + method -> + ReflectionUtils.findAnnotation(method, Handler.class) != null + || ReflectionUtils.findAnnotation(method, Shared.class) != null + || ReflectionUtils.findAnnotation(method, Workflow.class) != null + || ReflectionUtils.findAnnotation(method, Exclusive.class) != null); + for (var method : methods) { + validateMethod(method); + } + } var builder = clazz.isInterface() ? byteBuddy.subclass(Object.class).implement(clazz) : byteBuddy.subclass(clazz); + var annotationMatcher = + isAnnotatedWith(Handler.class) + .or(isAnnotatedWith(Exclusive.class)) + .or(isAnnotatedWith(Shared.class)) + .or(isAnnotatedWith(Workflow.class)); try (var unloaded = builder // Add a field to store the interceptor - .defineField(INTERCEPTOR_FIELD_NAME, MethodInterceptor.class, Visibility.PUBLIC) + .defineField(INTERCEPTOR_FIELD_NAME, InvocationHandler.class, Visibility.PUBLIC) // Intercept all methods - .method( - isMethod() - .and( - isAnnotatedWith(Handler.class) - .or(isAnnotatedWith(Exclusive.class)) - .or(isAnnotatedWith(Shared.class)) - .or(isAnnotatedWith(Workflow.class)))) + .method(annotationMatcher) + .intercept(InvocationHandlerAdapter.toField(INTERCEPTOR_FIELD_NAME)) + .method(not(annotationMatcher)) .intercept( - InvocationHandlerAdapter.of( - (proxy, method, args) -> { - // Get the interceptor from the field - Field field = proxy.getClass().getDeclaredField(INTERCEPTOR_FIELD_NAME); - field.setAccessible(true); - MethodInterceptor interceptor = (MethodInterceptor) field.get(proxy); - - if (interceptor == null) { - throw new IllegalStateException( - "Interceptor not set on proxy instance. This is a bug, please contact the developers."); - } - - MethodInvocation invocation = - new MethodInvocation() { - @Override - public Object[] getArguments() { - return args != null ? args : new Object[0]; - } - - @Override - public Method getMethod() { - return method; - } - }; - return interceptor.invoke(invocation); - })) + ExceptionMethod.throwing( + UnsupportedOperationException.class, + "Calling a method not annotated with a Restate handler annotation on the proxy class")) .make()) { - return unloaded.load(clazz.getClassLoader()).getLoaded(); + + var proxyClazz = unloaded.load(clazz.getClassLoader()).getLoaded(); + + // Make sure the field is accessible + Field interceptorField = proxyClazz.getDeclaredField(INTERCEPTOR_FIELD_NAME); + interceptorField.setAccessible(true); + return proxyClazz; + } + } + + private static void validateMethod(Method method) { + if (!Modifier.isPublic(method.getModifiers())) { + throw new IllegalArgumentException( + "Method '" + + method.getDeclaringClass().getSimpleName() + + "#" + + method.getName() + + "' MUST be public to be used as Restate handler. Modifiers:" + + Modifier.toString(method.getModifiers())); + } + if (Modifier.isStatic(method.getModifiers())) { + throw new IllegalArgumentException( + "Method '" + + method.getDeclaringClass().getSimpleName() + + "#" + + method.getName() + + "' is static, cannot be used as Restate handler"); + } + if (Modifier.isFinal(method.getModifiers())) { + throw new IllegalArgumentException( + "Method '" + + method.getDeclaringClass().getSimpleName() + + "#" + + method.getName() + + "' is final, cannot be used as Restate handler"); } } } diff --git a/bytebuddy-proxy-support/src/test/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactoryTest.java b/bytebuddy-proxy-support/src/test/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactoryTest.java new file mode 100644 index 00000000..2609edda --- /dev/null +++ b/bytebuddy-proxy-support/src/test/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactoryTest.java @@ -0,0 +1,102 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.bytebuddy.proxysupport; + +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Fail.fail; + +import dev.restate.sdk.annotation.Handler; +import dev.restate.sdk.annotation.Service; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +public class ByteBuddyProxyFactoryTest { + + @Service + public static class InvokeNonRestateMethod { + public void somethingElse() {} + } + + @Test + @DisplayName("Invoking non restate method should fail") + public void badCallToNonRestateMethod() { + var proxyFactory = new ByteBuddyProxyFactory(); + var proxy = + proxyFactory.createProxy( + InvokeNonRestateMethod.class, + invocation -> fail("Unexpected call to method interceptor")); + + assertThatCode(() -> proxy.somethingElse()) + .hasMessageContaining( + "Calling a method not annotated with a Restate handler annotation on the proxy class") + .isInstanceOf(UnsupportedOperationException.class); + } + + @Service + public static class PackagePrivateMethod { + @Handler + void handler() { + fail("This code should not be executed"); + } + } + + @Test + @DisplayName("Package private method should fail") + public void packagePrivateMethod() { + var proxyFactory = new ByteBuddyProxyFactory(); + assertThatCode( + () -> + proxyFactory.createProxy( + PackagePrivateMethod.class, + invocation -> fail("Unexpected call to method interceptor"))) + .cause() + .cause() + .hasMessageContaining("MUST be public to be used as Restate handler"); + } + + @Service + public static class FinalMethod { + @Handler + public final void handler() { + fail("This code should not be executed"); + } + } + + @Test + @DisplayName("Final method should fail") + public void finalMethod() { + var proxyFactory = new ByteBuddyProxyFactory(); + assertThatCode( + () -> + proxyFactory.createProxy( + FinalMethod.class, invocation -> fail("Unexpected call to method interceptor"))) + .cause() + .cause() + .hasMessageContaining("is final"); + } + + @Service + public static final class FinalClass { + @Handler + public void handler() { + fail("This code should not be executed"); + } + } + + @Test + @DisplayName("Final class should fail") + public void finalClass() { + var proxyFactory = new ByteBuddyProxyFactory(); + assertThatCode( + () -> + proxyFactory.createProxy( + FinalClass.class, invocation -> fail("Unexpected call to method interceptor"))) + .hasMessageContaining("is final, cannot be proxied"); + } +} diff --git a/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java b/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java index d9028509..a250298f 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java +++ b/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java @@ -119,17 +119,7 @@ public ServiceDefinition create( var handlerName = handlerInfo.name(); var genericParameterTypes = method.getGenericParameterTypes(); var parameterCount = method.getParameterCount(); - - if (!Modifier.isPublic(method.getModifiers())) { - throw new MalformedRestateServiceException( - serviceName, - "Handler method '" - + handlerName - + "' MUST be public, but method '" - + method.getName() - + "' has modifiers: " - + Modifier.toString(method.getModifiers())); - } + validateMethod(method, serviceName); if ((parameterCount == 1 || parameterCount == 2) && (genericParameterTypes[0].equals(Context.class) @@ -231,6 +221,22 @@ public ServiceDefinition create( return handlerDefinition; } + private static void validateMethod(Method method, String serviceName) { + if (!Modifier.isPublic(method.getModifiers())) { + throw new MalformedRestateServiceException( + serviceName, + "Method '" + + method.getName() + + "' MUST be public to be used as Restate handler. Modifiers:" + + Modifier.toString(method.getModifiers())); + } + if (Modifier.isStatic(method.getModifiers())) { + throw new MalformedRestateServiceException( + serviceName, + "Method '" + method.getName() + "' is static, cannot be used as Restate handler"); + } + } + @SuppressWarnings({"unchecked", "rawtypes"}) private Serde resolveInputSerde( Method method, SerdeFactory serdeFactory, String serviceName) {