Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ Dies geschieht mithilfe von `docker compose -f elasticsearch/docker-compose.yml
### REST-Requests absetzen
In der Datei rest-api.http sind mehrere REST-Requests dokumentiert. Diese können in IntelliJ per Klick aufgerufen werden

### Milvus Vektordatenbank starten
Damit die Anwendung mit der Vektordatenbank Milvus verwendet werden kann, muss per Docker oder minikube/helm eine Milvus instanz gestartet werden. Anschließend können host und port in der application.yml gesetzt werden.
Durch die Konfiguration von milvus host & port werden alle notwendigen Beans von Spring-AI automatisch initialisiert

## Anfrage an ChatGPT
Um eine Anfrage an ChatGPT zu stellen, kann der Endpunkt `GET localhost:8080/chatgpt?message=MeineAnfrage` aufgerufen werden.
Die Anfrage wird über Spring AI direkt an ChatGPT weitergeleitet und das Ergebnis zurückgegeben.
Expand Down
12 changes: 7 additions & 5 deletions build.gradle
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
plugins {
id 'java'
id 'org.springframework.boot' version '3.2.1'
id 'org.springframework.boot' version '3.2.2'
id 'io.spring.dependency-management' version '1.1.4'
}

Expand All @@ -22,10 +22,12 @@ dependencies {
implementation 'org.springframework.boot:spring-boot-starter'
implementation 'org.springframework.boot:spring-boot-starter-web'
implementation 'org.springframework.boot:spring-boot-starter-data-elasticsearch'

// Spring AI Dependencies
implementation 'org.springframework.ai:spring-ai-openai-spring-boot-starter:0.8.0-SNAPSHOT'
testImplementation 'org.springframework.boot:spring-boot-starter-test'
}
implementation 'org.springframework.ai:spring-ai-milvus-store:0.8.0-SNAPSHOT'

tasks.named('test') {
useJUnitPlatform()
testImplementation 'org.springframework.boot:spring-boot-starter-test'
testImplementation 'org.junit.jupiter:junit-jupiter:5.9.2'
testRuntimeOnly 'org.junit.platform:junit-platform-launcher'
}
8 changes: 8 additions & 0 deletions rest-api.http
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,11 @@ POST localhost:8080/rag/elasticsearch/ingest
GET localhost:8080/rag/elasticsearch?message=ultimate%20mountain%20bike


### Ingest data to milvus vector store
POST localhost:8080/rag/milvus/ingest

### RAG-Query using milvus vector store
GET localhost:8080/rag/milvus?message=ultimate%20mountain%20bike



Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package de.cofinpro.springai.chatgpt;

import org.springframework.ai.openai.client.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
Expand All @@ -20,6 +20,6 @@ public ChatGPTController(OpenAiChatClient openAiChatClient) {

@GetMapping
public String queryGpt(@RequestParam("message") String message) {
return openAiChatClient.generate(message);
return openAiChatClient.call(message);
}
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
package de.cofinpro.springai.retrieval_augmented_generation;

import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.openai.client.OpenAiChatClient;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.prompt.SystemPromptTemplate;
import org.springframework.ai.prompt.messages.UserMessage;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.reader.JsonReader;
import org.springframework.ai.retriever.VectorStoreRetriever;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.core.io.Resource;

Expand All @@ -19,8 +17,6 @@ public abstract class AbstractRetrievalAugmentedGenerationService {

private final VectorStore vectorStore;

private final VectorStoreRetriever vectorStoreRetriever;

private final OpenAiChatClient openAiChatClient;

private final SystemPromptTemplate systemPromptTemplate;
Expand All @@ -31,7 +27,6 @@ public AbstractRetrievalAugmentedGenerationService(VectorStore vectorStore, Open
Resource bikesResource) {
this.vectorStore = vectorStore;
this.openAiChatClient = openAiChatClient;
vectorStoreRetriever = new VectorStoreRetriever(vectorStore);
systemPromptTemplate = new SystemPromptTemplate(systemPromptTemplateResource);
this.bikesResource = bikesResource;
}
Expand All @@ -46,10 +41,10 @@ public void ingestDocuments() {
}

public String retrievalAugmentedGeneration(String message) {
final var similarDocuments = vectorStoreRetriever.retrieve(message);
final var similarDocuments = vectorStore.similaritySearch(message);
final var joinedDocuments = similarDocuments.stream().map(Document::getContent).collect(Collectors.joining("\n"));
final var systemMessage = systemPromptTemplate.createMessage(Map.of("documents", joinedDocuments));
final var prompt = new Prompt(List.of(systemMessage, new UserMessage(message)));
return openAiChatClient.generate(prompt).getGeneration().getContent();
return openAiChatClient.call(prompt).getResult().toString();
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package de.cofinpro.springai.retrieval_augmented_generation.elasticsearch;

import de.cofinpro.springai.retrieval_augmented_generation.AbstractRetrievalAugmentedGenerationService;
import org.springframework.ai.openai.client.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package de.cofinpro.springai.retrieval_augmented_generation.milvus;

import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;

@RestController
@RequestMapping("/rag/milvus")
public class MilvusRetrievalAugmentedGenerationController {

private final MilvusRetrievalAugmentedGenerationService milvusRetrievalAugmentedGenerationService;

public MilvusRetrievalAugmentedGenerationController(MilvusRetrievalAugmentedGenerationService milvusRetrievalAugmentedGenerationService) {
this.milvusRetrievalAugmentedGenerationService = milvusRetrievalAugmentedGenerationService;
}

@PostMapping("/ingest")
public ResponseEntity<Void> ingestDocuments() {
milvusRetrievalAugmentedGenerationService.ingestDocuments();
return new ResponseEntity<>(HttpStatus.OK);
}

@GetMapping
public String retrievalAugmentedGeneration(@RequestParam("message") String message) {
return milvusRetrievalAugmentedGenerationService.retrievalAugmentedGeneration(message);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package de.cofinpro.springai.retrieval_augmented_generation.milvus;

import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.reader.JsonReader;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;

import java.util.List;
import java.util.Map;

@Service
public class MilvusRetrievalAugmentedGenerationService {

private final OpenAiChatClient openAiChatClient;
private final Resource bikesResource;
private final SystemPromptTemplate systemPromptTemplate;

private final VectorStore vectorStore;

@Autowired
public MilvusRetrievalAugmentedGenerationService(OpenAiChatClient openAiChatClient,
VectorStore vectorStore,
@Value("classpath:/bikes.json") Resource bikesResource,
@Value("classpath:/system-prompt-template") Resource systemPromptTemplateResource) {
this.openAiChatClient = openAiChatClient;
this.vectorStore = vectorStore;
this.bikesResource = bikesResource;
this.systemPromptTemplate = new SystemPromptTemplate(systemPromptTemplateResource);
}

public void ingestDocuments() {
final var jsonReader = new JsonReader(bikesResource, "name", "price", "shortDescription");
vectorStore.add(jsonReader.get());
}

public String retrievalAugmentedGeneration(String message) {
final var test = vectorStore.similaritySearch(SearchRequest.query(message));
final var systemMessage = systemPromptTemplate.createMessage(Map.of("documents", test));
final var prompt = new Prompt(List.of(systemMessage, new UserMessage(message)));
return openAiChatClient.call(prompt).getResult().toString();
}
}
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
package de.cofinpro.springai.retrieval_augmented_generation.simple;

import de.cofinpro.springai.retrieval_augmented_generation.AbstractRetrievalAugmentedGenerationService;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.openai.client.OpenAiChatClient;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.prompt.SystemPromptTemplate;
import org.springframework.ai.prompt.messages.UserMessage;
import org.springframework.ai.reader.JsonReader;
import org.springframework.ai.retriever.VectorStoreRetriever;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;

import java.io.File;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

@Service
public class SimpleRetrievalAugmentedGenerationService extends AbstractRetrievalAugmentedGenerationService {
Expand Down
13 changes: 13 additions & 0 deletions src/main/resources/application.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
spring:
elasticsearch:
uris: http://localhost:9200
ai:
openai:
api-key:
chat:
options:
temperature: 0.5
user:
vectorstore:
milvus:
client:
host: localhost
port: 56875