Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ dependencies {
implementation 'org.springframework.boot:spring-boot-starter-data-elasticsearch'
implementation 'org.springframework.ai:spring-ai-openai-spring-boot-starter:0.8.0-SNAPSHOT'
testImplementation 'org.springframework.boot:spring-boot-starter-test'
}

tasks.named('test') {
useJUnitPlatform()
testImplementation 'org.junit.jupiter:junit-jupiter:5.9.2'
testRuntimeOnly 'org.junit.platform:junit-platform-launcher'
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package de.cofinpro.springai.chatgpt;

import org.springframework.ai.openai.client.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
package de.cofinpro.springai.retrieval_augmented_generation;

import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.openai.client.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.prompt.SystemPromptTemplate;
import org.springframework.ai.prompt.messages.UserMessage;
import org.springframework.ai.reader.JsonReader;
import org.springframework.ai.retriever.VectorStoreRetriever;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.core.io.Resource;

Expand All @@ -19,8 +17,6 @@ public abstract class AbstractRetrievalAugmentedGenerationService {

private final VectorStore vectorStore;

private final VectorStoreRetriever vectorStoreRetriever;

private final OpenAiChatClient openAiChatClient;

private final SystemPromptTemplate systemPromptTemplate;
Expand All @@ -31,7 +27,6 @@ public AbstractRetrievalAugmentedGenerationService(VectorStore vectorStore, Open
Resource bikesResource) {
this.vectorStore = vectorStore;
this.openAiChatClient = openAiChatClient;
vectorStoreRetriever = new VectorStoreRetriever(vectorStore);
systemPromptTemplate = new SystemPromptTemplate(systemPromptTemplateResource);
this.bikesResource = bikesResource;
}
Expand All @@ -46,7 +41,7 @@ public void ingestDocuments() {
}

public String retrievalAugmentedGeneration(String message) {
final var similarDocuments = vectorStoreRetriever.retrieve(message);
final var similarDocuments = vectorStore.similaritySearch(message);
final var joinedDocuments = similarDocuments.stream().map(Document::getContent).collect(Collectors.joining("\n"));
final var systemMessage = systemPromptTemplate.createMessage(Map.of("documents", joinedDocuments));
final var prompt = new Prompt(List.of(systemMessage, new UserMessage(message)));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package de.cofinpro.springai.retrieval_augmented_generation.elasticsearch;

import de.cofinpro.springai.retrieval_augmented_generation.AbstractRetrievalAugmentedGenerationService;
import org.springframework.ai.openai.client.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
package de.cofinpro.springai.retrieval_augmented_generation.simple;

import de.cofinpro.springai.retrieval_augmented_generation.AbstractRetrievalAugmentedGenerationService;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.openai.client.OpenAiChatClient;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.prompt.SystemPromptTemplate;
import org.springframework.ai.prompt.messages.UserMessage;
import org.springframework.ai.reader.JsonReader;
import org.springframework.ai.retriever.VectorStoreRetriever;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;

import java.io.File;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

@Service
public class SimpleRetrievalAugmentedGenerationService extends AbstractRetrievalAugmentedGenerationService {
Expand Down
4 changes: 4 additions & 0 deletions src/main/resources/application.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
spring:
elasticsearch:
uris: http://localhost:9200
ai:
azure:
openai:
api-key: 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