diff --git a/build.gradle b/build.gradle index 2ba85f1..287edce 100644 --- a/build.gradle +++ b/build.gradle @@ -24,8 +24,6 @@ dependencies { implementation 'org.springframework.boot:spring-boot-starter-data-elasticsearch' implementation 'org.springframework.ai:spring-ai-openai-spring-boot-starter:0.8.0-SNAPSHOT' testImplementation 'org.springframework.boot:spring-boot-starter-test' -} - -tasks.named('test') { - useJUnitPlatform() + testImplementation 'org.junit.jupiter:junit-jupiter:5.9.2' + testRuntimeOnly 'org.junit.platform:junit-platform-launcher' } diff --git a/src/main/java/de/cofinpro/springai/chatgpt/ChatGPTController.java b/src/main/java/de/cofinpro/springai/chatgpt/ChatGPTController.java index 9189ca4..4d63edc 100644 --- a/src/main/java/de/cofinpro/springai/chatgpt/ChatGPTController.java +++ b/src/main/java/de/cofinpro/springai/chatgpt/ChatGPTController.java @@ -1,6 +1,6 @@ package de.cofinpro.springai.chatgpt; -import org.springframework.ai.openai.client.OpenAiChatClient; +import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestMapping; diff --git a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/AbstractRetrievalAugmentedGenerationService.java b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/AbstractRetrievalAugmentedGenerationService.java index cb36041..468fa45 100644 --- a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/AbstractRetrievalAugmentedGenerationService.java +++ b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/AbstractRetrievalAugmentedGenerationService.java @@ -1,13 +1,11 @@ package de.cofinpro.springai.retrieval_augmented_generation; import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.EmbeddingClient; -import org.springframework.ai.openai.client.OpenAiChatClient; +import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.prompt.Prompt; import org.springframework.ai.prompt.SystemPromptTemplate; import org.springframework.ai.prompt.messages.UserMessage; import org.springframework.ai.reader.JsonReader; -import org.springframework.ai.retriever.VectorStoreRetriever; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.core.io.Resource; @@ -19,8 +17,6 @@ public abstract class AbstractRetrievalAugmentedGenerationService { private final VectorStore vectorStore; - private final VectorStoreRetriever vectorStoreRetriever; - private final OpenAiChatClient openAiChatClient; private final SystemPromptTemplate systemPromptTemplate; @@ -31,7 +27,6 @@ public AbstractRetrievalAugmentedGenerationService(VectorStore vectorStore, Open Resource bikesResource) { this.vectorStore = vectorStore; this.openAiChatClient = openAiChatClient; - vectorStoreRetriever = new VectorStoreRetriever(vectorStore); systemPromptTemplate = new SystemPromptTemplate(systemPromptTemplateResource); this.bikesResource = bikesResource; } @@ -46,7 +41,7 @@ public void ingestDocuments() { } public String retrievalAugmentedGeneration(String message) { - final var similarDocuments = vectorStoreRetriever.retrieve(message); + final var similarDocuments = vectorStore.similaritySearch(message); final var joinedDocuments = similarDocuments.stream().map(Document::getContent).collect(Collectors.joining("\n")); final var systemMessage = systemPromptTemplate.createMessage(Map.of("documents", joinedDocuments)); final var prompt = new Prompt(List.of(systemMessage, new UserMessage(message))); diff --git a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/elasticsearch/ElasticsearchRetrievalAugmentedGenerationService.java b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/elasticsearch/ElasticsearchRetrievalAugmentedGenerationService.java index b2ef7b6..2e89f48 100644 --- a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/elasticsearch/ElasticsearchRetrievalAugmentedGenerationService.java +++ b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/elasticsearch/ElasticsearchRetrievalAugmentedGenerationService.java @@ -1,7 +1,7 @@ package de.cofinpro.springai.retrieval_augmented_generation.elasticsearch; import de.cofinpro.springai.retrieval_augmented_generation.AbstractRetrievalAugmentedGenerationService; -import org.springframework.ai.openai.client.OpenAiChatClient; +import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.Resource; import org.springframework.stereotype.Service; diff --git a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/simple/SimpleRetrievalAugmentedGenerationService.java b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/simple/SimpleRetrievalAugmentedGenerationService.java index d418275..a3e11a6 100644 --- a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/simple/SimpleRetrievalAugmentedGenerationService.java +++ b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/simple/SimpleRetrievalAugmentedGenerationService.java @@ -1,23 +1,14 @@ package de.cofinpro.springai.retrieval_augmented_generation.simple; import de.cofinpro.springai.retrieval_augmented_generation.AbstractRetrievalAugmentedGenerationService; -import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingClient; -import org.springframework.ai.openai.client.OpenAiChatClient; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.UserMessage; -import org.springframework.ai.reader.JsonReader; -import org.springframework.ai.retriever.VectorStoreRetriever; +import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.vectorstore.SimpleVectorStore; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.Resource; import org.springframework.stereotype.Service; import java.io.File; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; @Service public class SimpleRetrievalAugmentedGenerationService extends AbstractRetrievalAugmentedGenerationService { diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index ea2ddd5..a92bc0e 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -1,3 +1,7 @@ spring: elasticsearch: uris: http://localhost:9200 + ai: + azure: + openai: + api-key: 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