diff --git a/README.md b/README.md index d55ed3e..405f67e 100644 --- a/README.md +++ b/README.md @@ -255,6 +255,25 @@ config :samly, Samly.State, |:------------|:-----------| | `opts` | _(optional)_ The `:key` is the name of the session key where assertion is stored. Default is `:samly_assertion`. | +#### Relay State customization + +By default, Samly sets the relay state to a random string. +However, you can set any specific value to the relay state. + +Set a fixed string as the relay state using the following: +```elixir +config :samly, Samly.State, + relay_state: "any-specific-string" +``` + +Set the relay state to the result of an annonymous function using the follows: +```elixir +config :samly, Samly.State, + relay_state: fn conn -> + "#{conn.scheme}://#{conn.host}#{conn.request_path}" + end +``` + ## SAML Assertion Once authentication is completed successfully, IdP sends a "consume" SAML diff --git a/lib/samly/auth_handler.ex b/lib/samly/auth_handler.ex index a9cf580..6d03b19 100644 --- a/lib/samly/auth_handler.ex +++ b/lib/samly/auth_handler.ex @@ -67,7 +67,7 @@ defmodule Samly.AuthHandler do conn |> redirect(302, target_url) _ -> - relay_state = State.gen_id() + relay_state = State.create_relay_state(conn) {idp_signin_url, req_xml_frag} = Helper.gen_idp_signin_req(sp, idp_rec, Map.get(idp, :nameid_format)) @@ -109,7 +109,7 @@ defmodule Samly.AuthHandler do Helper.gen_idp_signout_req(sp, idp_rec, subject_rec, session_index) conn = State.delete_assertion(conn, assertion_key) - relay_state = State.gen_id() + relay_state = State.create_relay_state(conn) conn |> put_session("target_url", target_url) diff --git a/lib/samly/provider.ex b/lib/samly/provider.ex index f31ad14..b26a1b6 100644 --- a/lib/samly/provider.ex +++ b/lib/samly/provider.ex @@ -34,7 +34,8 @@ defmodule Samly.Provider do store_env = Application.get_env(:samly, Samly.State, []) store_provider = store_env[:store] || Samly.State.ETS store_opts = store_env[:opts] || [] - State.init(store_provider, store_opts) + relay_state = store_env[:relay_state] || (&Samly.State.gen_id/1) + State.init(store_provider, store_opts, relay_state) opts = Application.get_env(:samly, Samly.Provider, []) diff --git a/lib/samly/state.ex b/lib/samly/state.ex index 9b25295..85c983c 100644 --- a/lib/samly/state.ex +++ b/lib/samly/state.ex @@ -3,11 +3,14 @@ defmodule Samly.State do @state_store :state_store - def init(store_provider), do: init(store_provider, []) - - def init(store_provider, opts) do + def init(store_provider, opts \\ [], relay_state \\ &gen_id/1) do opts = store_provider.init(opts) - Application.put_env(:samly, @state_store, %{provider: store_provider, opts: opts}) + + Application.put_env(:samly, @state_store, %{ + provider: store_provider, + opts: opts, + relay_state: relay_state + }) end def get_assertion(conn, assertion_key) do @@ -25,7 +28,27 @@ defmodule Samly.State do store_provider.delete_assertion(conn, assertion_key, opts) end - def gen_id() do + @spec create_relay_state(Plug.Conn.t()) :: String.t() + def create_relay_state(conn) do + case Application.get_env(:samly, @state_store).relay_state do + relay_state when is_function(relay_state, 1) -> + relay_state.(conn) + + relay_state when is_binary(relay_state) -> + relay_state + + relay_state -> + raise "Invalid relay_state: expected a function of arity 1 or a string, got #{inspect(relay_state)}" + end + end + + @spec gen_id(Plug.Conn.t()) :: String.t() + def gen_id(_conn) do + gen_id() + end + + @spec gen_id :: String.t() + def gen_id do 24 |> :crypto.strong_rand_bytes() |> Base.url_encode64() end end diff --git a/test/samly_state_test.exs b/test/samly_state_test.exs index 9daac8e..844b5a4 100644 --- a/test/samly_state_test.exs +++ b/test/samly_state_test.exs @@ -2,6 +2,29 @@ defmodule Samly.StateTest do use ExUnit.Case, async: true use Plug.Test + test "create_relay_state" do + conn = conn(:get, "/relay-state-path") + + default_relay_state_length = Samly.State.gen_id() |> String.length() + assert Samly.State.init(Samly.State.ETS, [], &Samly.State.gen_id/1) == :ok + assert Samly.State.create_relay_state(conn) |> String.length() == default_relay_state_length + + assert Samly.State.init(Samly.State.ETS, [], "relay_state_string") == :ok + assert Samly.State.create_relay_state(conn) == "relay_state_string" + + relay_state_fun = fn conn -> "#{conn.scheme}://#{conn.host}#{conn.request_path}" end + assert Samly.State.init(Samly.State.ETS, [], relay_state_fun) == :ok + assert Samly.State.create_relay_state(conn) == relay_state_fun.(conn) + + for relay_state_param <- [1, fn -> "0arity" end, fn _, _ -> "2arity" end] do + assert Samly.State.init(Samly.State.ETS, [], relay_state_param) == :ok + + assert_raise RuntimeError, ~r/^Invalid relay_state/, fn -> + Samly.State.create_relay_state(conn) + end + end + end + describe "With Session Cache" do setup do opts =