Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 29 additions & 19 deletions lib/rag/generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,23 @@ defmodule Rag.Generation do
@doc """
Puts `query_embedding` in `generation.query_embedding`.
"""
@spec put_query_embedding(t(), query_embedding :: list(number())) :: t()
@spec put_query_embedding(t(), query_embedding :: list(number()) | (t() -> list(number()))) ::
t()
def put_query_embedding(%Generation{} = generation, query_embedding),
do: %{generation | query_embedding: query_embedding}
do: %{generation | query_embedding: get_value(generation, query_embedding)}

@doc """
Puts `retrieval_result` at `key` in `generation.retrieval_results`.
"""
@spec put_retrieval_result(t(), key :: atom(), retrieval_result :: map()) :: t()
@spec put_retrieval_result(t(), key :: atom(), retrieval_result :: map() | (t() -> map())) ::
t()
def put_retrieval_result(%Generation{} = generation, key, retrieval_result),
do: put_in(generation, [Access.key!(:retrieval_results), key], retrieval_result)
do:
put_in(
generation,
[Access.key!(:retrieval_results), get_value(generation, key)],
retrieval_result
)

@doc """
Gets the retrieval result at `key` in `generation.retrieval_results`.
Expand All @@ -66,38 +73,38 @@ defmodule Rag.Generation do
@doc """
Puts `context` in `generation.context`.
"""
@spec put_context(t(), context :: String.t()) :: t()
def put_context(%Generation{} = generation, context) when is_binary(context),
do: %{generation | context: context}
@spec put_context(t(), context :: String.t() | (t() -> String.t())) :: t()
def put_context(%Generation{} = generation, context),
do: %{generation | context: get_value(generation, context)}

@doc """
Puts `context_sources` in `generation.context_sources`.
"""
@spec put_context_sources(t(), context_sources :: list(String.t())) :: t()
def put_context_sources(%Generation{} = generation, context_sources)
when is_list(context_sources),
do: %{generation | context_sources: context_sources}
@spec put_context_sources(t(), context_sources :: list(String.t()) | (t() -> list(String.t()))) ::
t()
def put_context_sources(%Generation{} = generation, context_sources),
do: %{generation | context_sources: get_value(generation, context_sources)}

@doc """
Puts `prompt` in `generation.prompt`.
"""
@spec put_prompt(t(), prompt :: String.t()) :: t()
def put_prompt(%Generation{} = generation, prompt) when is_binary(prompt),
do: %{generation | prompt: prompt}
@spec put_prompt(t(), prompt :: String.t() | (t() -> String.t())) :: t()
def put_prompt(%Generation{} = generation, prompt),
do: %{generation | prompt: get_value(generation, prompt)}

@doc """
Puts `response` in `generation.response`.
"""
@spec put_response(t(), response :: String.t()) :: t()
def put_response(%Generation{} = generation, response) when is_binary(response),
do: %{generation | response: response}
@spec put_response(t(), response :: String.t() | (t() -> String.t())) :: t()
def put_response(%Generation{} = generation, response),
do: %{generation | response: get_value(generation, response)}

@doc """
Puts `evaluation` at `key` in `generation.evaluations`.
"""
@spec put_evaluation(t(), key :: atom(), evaluation :: any()) :: t()
@spec put_evaluation(t(), key :: atom(), evaluation :: any | (t() -> any())) :: t()
def put_evaluation(%Generation{} = generation, key, evaluation),
do: put_in(generation, [Access.key!(:evaluations), key], evaluation)
do: put_in(generation, [Access.key!(:evaluations), key], get_value(generation, evaluation))

@doc """
Gets the evaluation at `key` in `generation.evaluations`.
Expand Down Expand Up @@ -147,4 +154,7 @@ defmodule Rag.Generation do
{generation, %{metadata | generation: generation}}
end)
end

defp get_value(generation, value) when is_function(value, 1), do: value.(generation)
defp get_value(_generation, value), do: value
end