Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/pull.yaml → .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
name: QA
name: CI

on:
pull_request:
push:
branches:
- main

Expand Down Expand Up @@ -79,4 +80,5 @@ jobs:
v2-${{ runner.os }}-plts
- run: mix deps.get
- run: mix compile --warnings-as-errors
- run: mix deps.unlock --check-unused
- run: mix lint
3 changes: 2 additions & 1 deletion lib/mix/tasks/rag.gen_eval.ex
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
defmodule Mix.Tasks.Rag.GenEval do
use Igniter.Mix.Task
alias Igniter.Project

@example "mix rag.gen_eval"

Expand Down Expand Up @@ -40,7 +41,7 @@ defmodule Mix.Tasks.Rag.GenEval do
rag_module = Module.concat(root_module, "Rag")

igniter
|> Igniter.Project.Config.configure(
|> Project.Config.configure(
"config.exs",
app_name,
[:openai_key],
Expand Down
31 changes: 16 additions & 15 deletions lib/mix/tasks/rag.install.ex
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
defmodule Mix.Tasks.Rag.Install do
use Igniter.Mix.Task
alias Igniter.{Libs, Project}

@example "mix rag.install --vector-store pgvector"

Expand Down Expand Up @@ -57,10 +58,10 @@ defmodule Mix.Tasks.Rag.Install do

igniter =
igniter
|> Igniter.Project.Deps.add_dep({:text_chunker, "~> 0.3.1"})
|> Igniter.Project.Deps.add_dep({:bumblebee, "~> 0.6.0"})
|> Igniter.Project.Deps.add_dep({:exla, "~> 0.9.1"})
|> Igniter.Project.Config.configure("config.exs", :nx, [:default_backend], EXLA.Backend)
|> Project.Deps.add_dep({:text_chunker, "~> 0.3.1"})
|> Project.Deps.add_dep({:bumblebee, "~> 0.6.0"})
|> Project.Deps.add_dep({:exla, "~> 0.9.1"})
|> Project.Config.configure("config.exs", :nx, [:default_backend], EXLA.Backend)
|> Igniter.compose_task("rag.gen_eval")
|> Igniter.compose_task("rag.gen_servings")
|> Igniter.compose_task("rag.gen_rag_module")
Expand All @@ -79,15 +80,15 @@ defmodule Mix.Tasks.Rag.Install do

defp with_chroma(igniter) do
igniter
|> Igniter.Project.Deps.add_dep({:chroma, "~> 0.1.3"})
|> Project.Deps.add_dep({:chroma, "~> 0.1.3"})
|> Igniter.apply_and_fetch_dependencies()
|> Igniter.Project.Config.configure("config.exs", :chroma, [:host], "http://localhost:8000")
|> Igniter.Project.Config.configure("config.exs", :chroma, [:api_base], "api")
|> Igniter.Project.Config.configure("config.exs", :chroma, [:api_version], "v1")
|> Project.Config.configure("config.exs", :chroma, [:host], "http://localhost:8000")
|> Project.Config.configure("config.exs", :chroma, [:api_base], "api")
|> Project.Config.configure("config.exs", :chroma, [:api_version], "v1")
end

defp with_pgvector(igniter) do
app_name = Igniter.Project.Application.app_name(igniter)
app_name = Project.Application.app_name(igniter)

root_module =
app_name
Expand All @@ -99,23 +100,23 @@ defmodule Mix.Tasks.Rag.Install do
schema_module = Module.concat(root_module, "Rag.Chunk")

igniter
|> Igniter.Project.Deps.add_dep({:ecto, "~> 3.0"})
|> Igniter.Project.Deps.add_dep({:ecto_sql, "~> 3.10"})
|> Igniter.Project.Deps.add_dep({:pgvector, "~> 0.3.0"})
|> Project.Deps.add_dep({:ecto, "~> 3.0"})
|> Project.Deps.add_dep({:ecto_sql, "~> 3.10"})
|> Project.Deps.add_dep({:pgvector, "~> 0.3.0"})
|> Igniter.apply_and_fetch_dependencies()
|> Igniter.include_or_create_file(
"lib/postgrex_types.ex",
"""
Postgrex.Types.define(#{inspect(postgrex_types_module)}, [Pgvector.Extensions.Vector] ++ Ecto.Adapters.Postgres.extensions(), [])
"""
)
|> Igniter.Project.Config.configure(
|> Project.Config.configure(
"config.exs",
app_name,
[repo_module, :types],
postgrex_types_module
)
|> Igniter.Project.Module.create_module(
|> Project.Module.create_module(
schema_module,
"""
@moduledoc \"""
Expand All @@ -138,7 +139,7 @@ defmodule Mix.Tasks.Rag.Install do
end
"""
)
|> Igniter.Libs.Ecto.gen_migration(repo_module, "create_chunks_table",
|> Libs.Ecto.gen_migration(repo_module, "create_chunks_table",
body: """
def up() do
execute("CREATE EXTENSION IF NOT EXISTS vector")
Expand Down
34 changes: 15 additions & 19 deletions lib/rag/ai/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,14 @@ defmodule Rag.Ai.Nx do
end

def generate_embeddings(%__MODULE__{} = provider, texts, _opts) when is_list(texts) do
try do
embeddings =
Nx.Serving.batched_run(provider.embeddings_serving, texts)
|> Enum.map(&Nx.to_list(&1.embedding))

{:ok, embeddings}
rescue
error ->
{:error, error}
end
embeddings =
Nx.Serving.batched_run(provider.embeddings_serving, texts)
|> Enum.map(&Nx.to_list(&1.embedding))

{:ok, embeddings}
rescue
error ->
{:error, error}
end

@impl Rag.Ai.Provider
Expand All @@ -46,14 +44,12 @@ defmodule Rag.Ai.Nx do
end

def generate_text(%__MODULE__{} = provider, prompt, _opts) when is_binary(prompt) do
try do
%{results: [result]} =
Nx.Serving.batched_run(provider.text_serving, prompt)

{:ok, result.text}
rescue
error ->
{:error, error}
end
%{results: [result]} =
Nx.Serving.batched_run(provider.text_serving, prompt)

{:ok, result.text}
rescue
error ->
{:error, error}
end
end
3 changes: 2 additions & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ defmodule Rag.MixProject do
defp aliases do
[
lint: [
"format --check-formatted"
"format --check-formatted",
"credo --strict"
]
]
end
Expand Down
4 changes: 1 addition & 3 deletions test/rag/embedding/http_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ defmodule Rag.Embedding.HttpTest do
use ExUnit.Case
use Mimic

alias Rag.Embedding
alias Rag.Generation
alias Rag.Ai
alias Rag.{Ai, Embedding, Generation}

setup do
%{provider: Ai.OpenAI.new(%{})}
Expand Down
5 changes: 2 additions & 3 deletions test/rag/embedding/nx_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ defmodule Rag.Embedding.NxTest do
use ExUnit.Case
use Mimic

alias Rag.Embedding
alias Rag.Generation
alias Rag.{Ai, Embedding, Generation}

setup do
%{provider: Rag.Ai.Nx.new(%{embeddings_serving: TestEmbeddingsServing})}
%{provider: Ai.Nx.new(%{embeddings_serving: TestEmbeddingsServing})}
end

describe "generate_embedding/3" do
Expand Down
4 changes: 1 addition & 3 deletions test/rag/evaluation/http_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ defmodule Rag.Evaluation.HttpTest do
use ExUnit.Case
use Mimic

alias Rag.Generation
alias Rag.Evaluation
alias Rag.Ai
alias Rag.{Ai, Evaluation, Generation}

setup do
%{provider: Ai.OpenAI.new(%{})}
Expand Down
4 changes: 1 addition & 3 deletions test/rag/evaluation/nx_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ defmodule Rag.Evaluation.NxTest do
use ExUnit.Case
use Mimic

alias Rag.Generation
alias Rag.Evaluation
alias Rag.Ai
alias Rag.{Ai, Evaluation, Generation}

setup do
%{provider: Ai.Nx.new(%{text_serving: TestTextServing})}
Expand Down
7 changes: 3 additions & 4 deletions test/rag/evaluation_test.exs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
defmodule Rag.EvaluationTest do
use ExUnit.Case

alias Rag.Generation
alias Rag.Evaluation
alias Rag.{Ai, Evaluation, Generation}

describe "evaluate_rag_triad/2" do
test "takes a query, context, and response and returns an evaluation with scores and reasoning" do
Expand Down Expand Up @@ -207,7 +206,7 @@ defmodule Rag.EvaluationTest do
@tag :integration_test
test "openai evaluation" do
api_key = System.get_env("OPENAI_API_KEY")
provider = Rag.Ai.OpenAI.new(text_model: "gpt-4o-mini", api_key: api_key)
provider = Ai.OpenAI.new(text_model: "gpt-4o-mini", api_key: api_key)

query = "When was Elixir 1.18.1 released?"
context = "Elixir 1.18.1 was released on 2024-12-24"
Expand All @@ -222,7 +221,7 @@ defmodule Rag.EvaluationTest do
@tag :integration_test
test "cohere evaluation" do
api_key = System.get_env("COHERE_API_KEY")
provider = Rag.Ai.Cohere.new(text_model: "command-r-plus-08-2024", api_key: api_key)
provider = Ai.Cohere.new(text_model: "command-r-plus-08-2024", api_key: api_key)

query = "When was Elixir 1.18.1 released?"
context = "Elixir 1.18.1 was released on 2024-12-24"
Expand Down
3 changes: 1 addition & 2 deletions test/rag/generation/http_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ defmodule Rag.Generation.HttpTest do
use ExUnit.Case
use Mimic

alias Rag.Generation
alias Rag.Ai
alias Rag.{Ai, Generation}

setup do
%{provider: Ai.OpenAI.new(%{})}
Expand Down
3 changes: 1 addition & 2 deletions test/rag/generation/nx_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ defmodule Rag.Generation.NxTest do
use ExUnit.Case
use Mimic

alias Rag.Generation
alias Rag.Ai
alias Rag.{Ai, Generation}

setup do
%{provider: Ai.Nx.new(%{text_serving: TestTextServing})}
Expand Down