package com.arms.api.sample.rereading;

// import org.springframework.ai.chroma.vectorstore.ChromaVectorStore;

import com.arms.api.sample.logging.SimpleLogAdvisor;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * Sample controller that streams a chat completion built from a retrieval-augmentation
 * advisor, a re-reading advisor and a logging advisor, and lets a caller stop a running
 * stream by its id.
 *
 * <p>Each running stream is tracked in {@link #streamStatus} by the caller-supplied
 * {@code streamId}; the entry is removed again when the stream terminates for any reason.
 */
@RestController
@RequiredArgsConstructor
class SampleReReadingAdvisorController {

	/** Stop flags of the streams currently in flight, keyed by the caller-supplied stream id. */
	private final ConcurrentHashMap<String, AtomicBoolean> streamStatus = new ConcurrentHashMap<>();

    private final ChatModel chatModel;

    private final VectorStore vectorStore;

	private final SimpleLogAdvisor simpleLogAdvisor;

    /**
     * Requests that the stream registered under {@code streamId} stops.
     *
     * <p>The stream is not interrupted immediately: the flag is only checked when the
     * stream emits its next element.
     *
     * @param streamId id that was passed to {@code generateStream}
     * @return a message confirming the stop request, or a "not found" message when no
     *         stream is registered under that id
     */
    @GetMapping("/ai/reReadingAdvisor/stopStream")
    public String stopStream(@RequestParam("streamId") String streamId) {
        // Single lookup: a containsKey()/get() pair can race with the removal performed
        // when the stream terminates, which would make get() return null.
        AtomicBoolean stopRequested = streamStatus.get(streamId);
        if (stopRequested == null) {
            return "해당 stream ID를 찾을 수 없습니다.";
        }
        stopRequested.set(true);
        return "스트림 " + streamId + " 중단";
    }

    /**
     * Streams the model's answer to {@code message}.
     *
     * <p>Example: {@code localhost:31311/ai/reReadingAdvisor/generateStream?streamId=1234&message=...}
     *
     * <p>If {@code streamId} is already in use, the new stream takes over the registration,
     * so a later stop request targets the newest stream only.
     *
     * @param message  user question; a built-in default question is used when omitted
     * @param streamId caller-chosen id under which the stream can later be stopped
     * @return the generated content as a stream of text chunks
     */
    @GetMapping(value = "/ai/reReadingAdvisor/generateStream")
    public Flux<String> generateStream(@RequestParam(value = "message", defaultValue = "오늘 날씨 어때?") String message,
        @RequestParam("streamId") String streamId) {
        // Keep a direct reference to this request's flag so the pipeline never has to
        // look it up in the map again (the entry may be replaced or removed concurrently).
        AtomicBoolean stopRequested = new AtomicBoolean(false);
        streamStatus.put(streamId, stopRequested);

		ChatClient chatClient = ChatClient.builder(chatModel)
				.build();

		PromptTemplate promptTemplate = new PromptTemplate("""
			Context information is below.
	
			---------------------
			{context}
			---------------------
	
			Given the context information and no prior knowledge, answer the query.
	
			Follow these rules:
	
				주어진 요청에 대한 제공된 내용은 정보를 바탕으로, 사전 지식 없이 사용자 댓글에 답변하세요.
				그림이나 표가 들어가는 단어는 제외하고 찾아줘.
				문장이 끝나면 개행해줘.
				만약 요청에 답이 없다면, 영어로 대답하지 말고 한국어로 사용자에게 답변을 할 수 없다고 알려주세요.
	
			Query: {query}
	
			Answer:
		""");

		// Retrieves documents from the vector store (similarity threshold 0.5) and
		// augments the query with them using the prompt template above.
		Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
			.documentRetriever(VectorStoreDocumentRetriever.builder()
				.similarityThreshold(0.5)
				.vectorStore(vectorStore)
				.build())
			.queryAugmenter(
				ContextualQueryAugmenter.builder()
					.allowEmptyContext(false)
					.promptTemplate(promptTemplate)
				.build())
			.build();

		// Ordered directly after the retrieval advisor (its order + 1).
		Advisor reReadingAdvisor = new ReReadingAdvisor()
				.withOrder(retrievalAugmentationAdvisor.getOrder() + 1);

		return chatClient.prompt()
			.advisors(simpleLogAdvisor, retrievalAugmentationAdvisor, reReadingAdvisor)
			.user(message)
			.stream()
			.content()
			// The element that observes the stop flag is still emitted, then the stream completes.
			.takeUntil(data -> stopRequested.get())
			// doFinally also runs on cancellation (e.g. client disconnect), which
			// doOnComplete/doOnError would miss. The conditional remove leaves the entry
			// alone if a newer stream has re-registered the same id in the meantime.
			.doFinally(signal -> streamStatus.remove(streamId, stopRequested));
    }
}
