package com.arms.api.sample; // import org.springframework.ai.chroma.vectorstore.ChromaVectorStore; import java.util.List; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import com.arms.api.sample.logging.SimpleLogAdvisor; import com.arms.api.sample.safeguard.SafeGuardAdvisor; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor; import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.document.Document; import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter; import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; import lombok.RequiredArgsConstructor; import reactor.core.publisher.Flux; @RestController @RequiredArgsConstructor class SampleController { private final ConcurrentHashMap streamStatus = new ConcurrentHashMap<>(); private final ChatModel chatModel; private final VectorStore vectorStore; private final MyPagePdfDocumentReader myPagePdfDocumentReader; private final SimpleLogAdvisor simpleLogAdvisor; private final SafeGuardAdvisor safeGuardAdvisor; @GetMapping("/ai/test") public List getTest(@RequestParam(value = "message") String message) { List results = vectorStore.similaritySearch(message); return results; } @GetMapping("/ai/my-page-pdf-document-reader") public List getMyPagePdfDocumentReader() { return myPagePdfDocumentReader.getDocsFromPdf(); } @GetMapping("/ai/stopStream") public String stopStream(@RequestParam("streamId") String streamId) { if (streamStatus.containsKey(streamId)) { streamStatus.get(streamId).set(true); return "스트림 " + streamId + " 중단"; } else { return "해당 stream ID를 찾을 수 없습니다."; } } @GetMapping(value = "/ai/generateStream") public Flux generateStream(@RequestParam(value = "message", defaultValue = "오늘 날씨 어때?") String message, @RequestParam("streamId") String streamId) { streamStatus.put(streamId, new AtomicBoolean(false)); ChatClient chatClient = ChatClient.builder(chatModel) .build(); PromptTemplate promptTemplate = new PromptTemplate(""" Context information is below. --------------------- {context} --------------------- Given the context information and no prior knowledge, answer the query. Follow these rules: 주어진 요청에 대한 제공된 내용은 정보를 바탕으로, 사전 지식 없이 사용자 댓글에 답변하세요. 그림이나 표가 들어가는 단어는 제외하고 찾아줘. 문장이 끝나면 개행해줘. 만약 요청에 답이 없다면, 영어로 대답하지 말고 한국어로 사용자에게 답변을 할 수 없다고 알려주세요. Query: {query} Answer: """); Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder() .documentRetriever(VectorStoreDocumentRetriever.builder() .similarityThreshold(0.5) .vectorStore(vectorStore) .build()) .queryAugmenter( ContextualQueryAugmenter.builder() .allowEmptyContext(false) .promptTemplate(promptTemplate) .build()) .build(); return chatClient.prompt() .advisors(simpleLogAdvisor, safeGuardAdvisor, retrievalAugmentationAdvisor) .user(message) .stream() .content() .takeUntil(data -> streamStatus.get(streamId).get()) .doOnComplete(() -> streamStatus.remove(streamId)) .doOnError(error -> streamStatus.remove(streamId)); } }