diff --git a/lib/rag/generation.ex b/lib/rag/generation.ex index 6ddd35b..abd9b7a 100644 --- a/lib/rag/generation.ex +++ b/lib/rag/generation.ex @@ -45,16 +45,23 @@ defmodule Rag.Generation do @doc """ Puts `query_embedding` in `generation.query_embedding`. """ - @spec put_query_embedding(t(), query_embedding :: list(number())) :: t() + @spec put_query_embedding(t(), query_embedding :: list(number()) | (t() -> list(number()))) :: + t() def put_query_embedding(%Generation{} = generation, query_embedding), - do: %{generation | query_embedding: query_embedding} + do: %{generation | query_embedding: get_value(generation, query_embedding)} @doc """ Puts `retrieval_result` at `key` in `generation.retrieval_results`. """ - @spec put_retrieval_result(t(), key :: atom(), retrieval_result :: map()) :: t() + @spec put_retrieval_result(t(), key :: atom(), retrieval_result :: map() | (t() -> map())) :: + t() def put_retrieval_result(%Generation{} = generation, key, retrieval_result), - do: put_in(generation, [Access.key!(:retrieval_results), key], retrieval_result) + do: + put_in( + generation, + [Access.key!(:retrieval_results), get_value(generation, key)], + retrieval_result + ) @doc """ Gets the retrieval result at `key` in `generation.retrieval_results`. @@ -66,38 +73,38 @@ defmodule Rag.Generation do @doc """ Puts `context` in `generation.context`. """ - @spec put_context(t(), context :: String.t()) :: t() - def put_context(%Generation{} = generation, context) when is_binary(context), - do: %{generation | context: context} + @spec put_context(t(), context :: String.t() | (t() -> String.t())) :: t() + def put_context(%Generation{} = generation, context), + do: %{generation | context: get_value(generation, context)} @doc """ Puts `context_sources` in `generation.context_sources`. """ - @spec put_context_sources(t(), context_sources :: list(String.t())) :: t() - def put_context_sources(%Generation{} = generation, context_sources) - when is_list(context_sources), - do: %{generation | context_sources: context_sources} + @spec put_context_sources(t(), context_sources :: list(String.t()) | (t() -> list(String.t()))) :: + t() + def put_context_sources(%Generation{} = generation, context_sources), + do: %{generation | context_sources: get_value(generation, context_sources)} @doc """ Puts `prompt` in `generation.prompt`. """ - @spec put_prompt(t(), prompt :: String.t()) :: t() - def put_prompt(%Generation{} = generation, prompt) when is_binary(prompt), - do: %{generation | prompt: prompt} + @spec put_prompt(t(), prompt :: String.t() | (t() -> String.t())) :: t() + def put_prompt(%Generation{} = generation, prompt), + do: %{generation | prompt: get_value(generation, prompt)} @doc """ Puts `response` in `generation.response`. """ - @spec put_response(t(), response :: String.t()) :: t() - def put_response(%Generation{} = generation, response) when is_binary(response), - do: %{generation | response: response} + @spec put_response(t(), response :: String.t() | (t() -> String.t())) :: t() + def put_response(%Generation{} = generation, response), + do: %{generation | response: get_value(generation, response)} @doc """ Puts `evaluation` at `key` in `generation.evaluations`. """ - @spec put_evaluation(t(), key :: atom(), evaluation :: any()) :: t() + @spec put_evaluation(t(), key :: atom(), evaluation :: any | (t() -> any())) :: t() def put_evaluation(%Generation{} = generation, key, evaluation), - do: put_in(generation, [Access.key!(:evaluations), key], evaluation) + do: put_in(generation, [Access.key!(:evaluations), key], get_value(generation, evaluation)) @doc """ Gets the evaluation at `key` in `generation.evaluations`. @@ -147,4 +154,7 @@ defmodule Rag.Generation do {generation, %{metadata | generation: generation}} end) end + + defp get_value(generation, value) when is_function(value, 1), do: value.(generation) + defp get_value(_generation, value), do: value end