SpringAILlmService.java
package de.mirkosertic.powerstaff.profilesearch.command;
import de.mirkosertic.powerstaff.auth.UserQueryService;
import de.mirkosertic.powerstaff.auth.PsUser;
import de.mirkosertic.powerstaff.auth.UserView;
import de.mirkosertic.powerstaff.profilesearch.query.LlmFreelancerContext;
import de.mirkosertic.powerstaff.profilesearch.query.LlmProjectContext;
import de.mirkosertic.powerstaff.profilesearch.query.ProfileSearchQueryService;
import io.modelcontextprotocol.client.McpSyncClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.ToolCallAdvisor;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
import tools.jackson.databind.ObjectMapper;
import java.security.Principal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
public class SpringAILlmService implements LlmService {
// Class-wide SLF4J logger.
private static final Logger logger = LoggerFactory.getLogger(SpringAILlmService.class);
// Injected chat client.
// NOTE(review): only assigned in the constructor; no method in this file reads it. Requests use the
// client built by chatClientFactory in resolveChatClient - confirm whether this field is still needed.
private final ChatClient chatClient;
// Builds the ChatClient for a given API token (see resolveChatClient).
private final LlmChatClientFactory chatClientFactory;
// Creates the MCP client used to expose tools to the model (see tryCreateMcpClient).
private final McpClientFactory mcpClientFactory;
// Handed to SpringAIChatRepository together with queryService when a conversation is processed.
private final ProfileSearchCommandService commandService;
// Handed to SpringAIChatRepository together with commandService when a conversation is processed.
private final ProfileSearchQueryService queryService;
// Serializes tool-call and tool-result payloads to JSON; also handed to SpringAIChatRepository.
private final ObjectMapper objectMapper;
// Looks up the per-user LLM API token and the per-user system prompt.
private final UserQueryService userQueryService;
/**
 * Creates the service and stores all collaborators unchanged.
 *
 * @param chatClient        chat client instance kept by this service
 * @param chatClientFactory factory that builds a chat client for a given API token
 * @param mcpClientFactory  factory for the MCP client that provides tools
 * @param commandService    command side of the profile search, passed on to the chat repository
 * @param queryService      query side of the profile search, passed on to the chat repository
 * @param objectMapper      JSON mapper used for tool-call and tool-result payloads
 * @param userQueryService  lookup for user-specific settings (API token, system prompt)
 */
public SpringAILlmService(final ChatClient chatClient, final LlmChatClientFactory chatClientFactory,
        final McpClientFactory mcpClientFactory, final ProfileSearchCommandService commandService,
        final ProfileSearchQueryService queryService, final ObjectMapper objectMapper,
        final UserQueryService userQueryService) {
    // User and persistence collaborators.
    this.userQueryService = userQueryService;
    this.queryService = queryService;
    this.commandService = commandService;
    this.objectMapper = objectMapper;
    // LLM and MCP collaborators.
    this.chatClient = chatClient;
    this.chatClientFactory = chatClientFactory;
    this.mcpClientFactory = mcpClientFactory;
}
/**
 * Builds the chat client for the given principal.
 *
 * <p>The user's own non-blank LLM API token is preferred; otherwise the first admin token reported by
 * the user query service is used. If neither exists, the factory is called with {@code null}.
 *
 * @param principal the authenticated caller whose name is used for the user lookup
 * @return the chat client created by the factory for the resolved token
 */
private ChatClient resolveChatClient(final Principal principal) {
    // Optional.map already yields an empty Optional for a null token, so only blank tokens need filtering.
    final Optional<String> ownToken = userQueryService.findByUsername(principal.getName())
            .map(UserView::llmApiToken)
            .filter(candidate -> !candidate.isBlank());
    final String apiToken = ownToken
            .or(userQueryService::findFirstAdminLlmApiToken)
            .orElse(null);
    return chatClientFactory.create(apiToken);
}
/**
 * Creates and initializes an MCP client, or returns {@code null} when MCP is unavailable.
 *
 * <p>A {@link McpConnectionException} is logged as a warning and turned into {@code null} so the chat
 * can continue without tools. Any other failure propagates to the caller unchanged. In every failure
 * case a client that was already created is closed, so a failed {@code initialize()} no longer leaks it.
 *
 * @return the initialized client, or {@code null} if the MCP connection could not be established
 */
private McpSyncClient tryCreateMcpClient() {
    McpSyncClient client = null;
    boolean initialized = false;
    try {
        client = mcpClientFactory.createClient();
        client.initialize();
        initialized = true;
        return client;
    } catch (final McpConnectionException e) {
        logger.warn("MCP nicht verfügbar, Chat läuft ohne Tools: {}", e.getMessage());
        return null;
    } finally {
        // Only release the client when it is NOT handed to the caller; no-op if creation itself failed.
        if (!initialized) {
            closeMcpClientQuietly(client);
        }
    }
}
/**
 * Closes the given MCP client on a best-effort basis.
 *
 * <p>A {@code null} client is ignored. Any exception raised by {@code close()} is logged as a warning
 * and not rethrown.
 *
 * @param client the client to close, may be {@code null}
 */
private static void closeMcpClientQuietly(final McpSyncClient client) {
    if (client == null) {
        return;
    }
    try {
        client.close();
    } catch (final Exception closeFailure) {
        // Deliberately swallowed: cleanup must not mask the outcome of the chat request.
        logger.warn("Fehler beim Schließen des MCP Clients", closeFailure);
    }
}
/**
 * Routes one streamed {@link ChatResponse} chunk to the matching {@link ChatProgressCollector} callbacks.
 * Package-private for tests.
 *
 * <ul>
 *   <li>Usage that is present and not an {@link EmptyUsage} is reported via {@code reportUsage}.</li>
 *   <li>For each result: a {@code null} text falls back to the {@code "reasoningContent"} metadata entry
 *       (reported as thinking token when non-empty); a non-empty text is reported as assistant token.</li>
 *   <li>A finish reason equal to {@code "STOP"} triggers {@code stopped()}.</li>
 * </ul>
 *
 * @param response  the chunk to inspect
 * @param collector receiver of the derived progress callbacks
 */
void routeTokenToCollector(final ChatResponse response, final ChatProgressCollector collector) {
    // Metadata and usage are null-checked here just like in sendMessage/sendMessageStreaming,
    // so a chunk without them cannot abort the stream with a NullPointerException.
    final var metadata = response.getMetadata();
    final Usage usage = metadata != null ? metadata.getUsage() : null;
    if (usage != null && !(usage instanceof EmptyUsage)) {
        collector.reportUsage(usage.getPromptTokens(), usage.getCompletionTokens(), usage.getTotalTokens());
    }
    for (final var result : response.getResults()) {
        final AssistantMessage output = result.getOutput();
        final String text = output.getText();
        if (text == null) {
            final Object reasoningContent = output.getMetadata().get("reasoningContent");
            if (reasoningContent != null && !reasoningContent.toString().isEmpty()) {
                collector.thinkingToken(reasoningContent.toString());
            } else {
                logger.debug("ChatResponse ohne Text und ohne reasoningContent – ignoriert");
            }
        } else if (!text.isEmpty()) {
            collector.assistantResponseToken(text);
        }
        // "STOP".equals(null) is false, so no separate null check is needed.
        if ("STOP".equals(result.getMetadata().getFinishReason())) {
            collector.stopped();
        }
    }
}
/**
 * Appends one {@code "attribute: value\n"} line to the builder.
 *
 * @param attribute label written before the colon
 * @param value     value to write; {@code null} is rendered as {@code --Unbekannt--}
 * @param sb        target builder that receives the line
 */
private static void nullsafeAttributeAppend(final String attribute, final String value, final StringBuilder sb) {
    final String shown;
    if (value == null) {
        shown = "--Unbekannt--";
    } else {
        shown = value;
    }
    sb.append(attribute);
    sb.append(": ");
    sb.append(shown);
    sb.append("\n");
}
/**
 * Builds the system prompt for a chat request.
 *
 * <p>The user's stored system prompt (or {@code PsUser.DEFAULT_SYSTEM_PROMPT} when none is stored) is
 * rendered as a template with the variable {@code user} set to the principal's name. A fixed hint about
 * the Lucene index is appended, followed by the attributes of the current project and the freelancers
 * already assigned to it when a project context is present.
 *
 * @param principal the authenticated caller
 * @param context   optional project the conversation refers to
 * @return the complete system prompt text
 */
private String resolveSystemPrompt(final Principal principal, final Optional<LlmProjectContext> context) {
    final String userName = principal.getName();
    // Optional.map wraps a null prompt into an empty Optional, so the default applies in that case too.
    final var template = userQueryService.findByUsername(userName)
            .map(UserView::profileSearchSystemPrompt)
            .orElse(PsUser.DEFAULT_SYSTEM_PROMPT);
    final String rendered = new PromptTemplate(template).render(Map.of("user", userName));

    final StringBuilder prompt = new StringBuilder(rendered);
    prompt.append("\nWenn Du nach Freiberuflern oder Kandidaten suchst, kannst Du im Lucene Index suchen!\n");
    if (context.isEmpty()) {
        return prompt.toString();
    }

    final LlmProjectContext project = context.get();
    prompt.append("\nDas aktuelle Projekt ist:\n");
    nullsafeAttributeAppend("**Projektnummer**", project.projectNumber(), prompt);
    nullsafeAttributeAppend("**Kurzbeschreibung**", project.descriptionShort(), prompt);
    nullsafeAttributeAppend("**Beschreibung**", project.descriptionLong(), prompt);
    nullsafeAttributeAppend("**Einsatzort**", project.workplace(), prompt);
    nullsafeAttributeAppend("**Benötigte Skills**", project.skills(), prompt);
    nullsafeAttributeAppend("**Dauer**", project.duration(), prompt);

    final var assigned = project.positions();
    if (assigned == null || assigned.isEmpty()) {
        return prompt.toString();
    }

    prompt.append("\n\nBereits dem aktuellen Projekt zugewiesen\n");
    for (final LlmFreelancerContext freelancer : assigned) {
        // One list item per assigned freelancer: code, status and skills on consecutive lines.
        prompt.append(" - ");
        nullsafeAttributeAppend("**Kodierung bzw. Dateiname**", freelancer.code(), prompt);
        prompt.append(" ");
        nullsafeAttributeAppend("**Status**", freelancer.positionStatus(), prompt);
        prompt.append(" ");
        nullsafeAttributeAppend("**Skills**", freelancer.skills(), prompt);
    }
    return prompt.toString();
}
/**
 * Sends a user message to the model and returns everything persisted for this turn as replies.
 *
 * <p>The call blocks until the response stream has ended. Thinking tokens are buffered in a local
 * collector so the chat repository can fetch them; the first persisted message is skipped because it
 * is the user message the frontend already shows.
 *
 * @param principal      the authenticated caller
 * @param sessionId      session identifier (not used by this implementation)
 * @param conversationId identifier of the conversation whose memory is read and extended
 * @param context        optional project the conversation refers to
 * @param userMessage    the text entered by the user
 * @return the replies derived from the newly persisted messages, in persisted order
 */
@Override
public List<Reply> sendMessage(final Principal principal, final String sessionId, final String conversationId, final Optional<LlmProjectContext> context, final String userMessage) {
    final var progressCollector = new ChatProgressCollector() {
        final StringBuilder assistantThoughts = new StringBuilder();

        @Override
        public void thinkingToken(final String token) {
            assistantThoughts.append(token);
        }

        @Override
        public String getAssistantThoughtsAndReset() {
            if (!assistantThoughts.isEmpty()) {
                final String result = assistantThoughts.toString();
                assistantThoughts.setLength(0);
                return result;
            }
            return "";
        }
    };
    final var chatRepository = new SpringAIChatRepository(conversationId, queryService, commandService, objectMapper, progressCollector);
    final ChatResponse lastChunk = streamChatResponse(principal, conversationId, context, userMessage, chatRepository, progressCollector);

    final Usage usage = usageOf(lastChunk);
    if (usage != null) {
        logger.info("Collected chat usage: {}", usage);
    }
    final int promptTokens = promptTokensOf(usage);
    final int completionTokens = completionTokensOf(usage);

    final List<Message> persistedMessages = chatRepository.getNewMessages();
    final List<Reply> replies = new ArrayList<>();
    // Skip the first message (index 0 = user message, already shown on frontend)
    for (int i = 1; i < persistedMessages.size(); i++) {
        toReply(persistedMessages.get(i), promptTokens, completionTokens).ifPresent(replies::add);
    }
    return replies;
}

/**
 * Sends a user message to the model and pushes progress events to the given sink while streaming.
 *
 * <p>The call blocks until the response stream has ended and then emits a final
 * {@code MessageComplete} event carrying the last assistant message id and the token counts.
 *
 * @param principal      the authenticated caller
 * @param sessionId      session identifier (not used by this implementation)
 * @param conversationId identifier of the conversation whose memory is read and extended
 * @param context        optional project the conversation refers to
 * @param userMessage    the text entered by the user
 * @param eventSink      receiver of all stream events
 */
@Override
public void sendMessageStreaming(final Principal principal, final String sessionId, final String conversationId,
        final Optional<LlmProjectContext> context, final String userMessage,
        final Consumer<LlmService.ChatStreamEvent> eventSink) {
    final var progressCollector = new StreamingChatProgressCollector(eventSink);
    final var chatRepository = new SpringAIChatRepository(conversationId, queryService, commandService, objectMapper, progressCollector);
    final ChatResponse lastChunk = streamChatResponse(principal, conversationId, context, userMessage, chatRepository, progressCollector);

    final Usage usage = usageOf(lastChunk);
    if (usage != null) {
        logger.info("Collected streaming chat usage: {}", usage);
    }
    // maxContextTokens is filled in by the controller (it knows ProfileSearchProperties)
    eventSink.accept(new LlmService.ChatStreamEvent.MessageComplete(
            chatRepository.getLastAssistantMessageId(), promptTokensOf(usage), completionTokensOf(usage), 0));
}

/**
 * Runs one streamed chat request and returns its last chunk; shared by both public entry points.
 *
 * <p>The MCP client, when available, is registered as tool provider and is always closed again once
 * the stream has ended or failed.
 *
 * @return the last chunk of the stream, or {@code null} if the stream emitted nothing
 */
private ChatResponse streamChatResponse(final Principal principal, final String conversationId,
        final Optional<LlmProjectContext> context, final String userMessage,
        final SpringAIChatRepository chatRepository, final ChatProgressCollector progressCollector) {
    final McpSyncClient mcpClient = tryCreateMcpClient();
    try {
        var promptSpec = resolveChatClient(principal).prompt()
                .advisors(
                        ToolCallAdvisor.builder()
                                .disableInternalConversationHistory()
                                .advisorOrder(BaseAdvisor.HIGHEST_PRECEDENCE + 300)
                                .build(),
                        MessageChatMemoryAdvisor.builder(
                                        MessageWindowChatMemory.builder()
                                                .maxMessages(10)
                                                .chatMemoryRepository(chatRepository)
                                                .build())
                                .conversationId(conversationId)
                                .build(),
                        new SimpleLoggerAdvisor()
                )
                .system(resolveSystemPrompt(principal, context))
                .user(userMessage);
        if (mcpClient != null) {
            promptSpec = promptSpec.toolCallbacks(SyncMcpToolCallbackProvider.builder().addMcpClient(mcpClient).build());
        }
        return promptSpec
                .stream()
                .chatResponse()
                // doOnNext is the side-effect hook; every chunk is still passed downstream unchanged.
                .doOnNext(chunk -> routeTokenToCollector(chunk, progressCollector))
                .blockLast();
    } finally {
        closeMcpClientQuietly(mcpClient);
    }
}

/** Returns the usage attached to the response, or {@code null} if response, metadata or usage is missing. */
private static Usage usageOf(final ChatResponse response) {
    if (response == null || response.getMetadata() == null) {
        return null;
    }
    return response.getMetadata().getUsage();
}

/** Returns the prompt token count, or 0 when the usage or the count is missing. */
private static int promptTokensOf(final Usage usage) {
    return usage != null && usage.getPromptTokens() != null ? usage.getPromptTokens() : 0;
}

/** Returns the completion token count, or 0 when the usage or the count is missing. */
private static int completionTokensOf(final Usage usage) {
    return usage != null && usage.getCompletionTokens() != null ? usage.getCompletionTokens() : 0;
}

/**
 * Maps one persisted message to a reply; unsupported message types are logged and yield an empty result.
 * Token counts are attached only to plain assistant answers (those without tool calls).
 */
private Optional<Reply> toReply(final Message message, final int promptTokens, final int completionTokens) {
    if (message instanceof final UserMessage msg) {
        return Optional.of(new Reply(-1, LlmService.ROLE_USER, msg.getText(), null, null, null));
    }
    if (message instanceof final AssistantMessage msg) {
        if (msg.hasToolCalls()) {
            return Optional.of(toolCallReply(msg));
        }
        return Optional.of(new Reply(-1, LlmService.ROLE_ASSISTANT, msg.getText(), null, promptTokens, completionTokens));
    }
    if (message instanceof final ToolResponseMessage msg) {
        return Optional.of(toolResultReply(msg));
    }
    logger.warn("Unsupported message type: {}", message);
    return Optional.empty();
}

/** Builds the reply for an assistant message that requests tool calls; the calls are serialized as JSON. */
private Reply toolCallReply(final AssistantMessage msg) {
    final List<Map<String, Object>> toolCallOptions = new ArrayList<>();
    for (final AssistantMessage.ToolCall toolCall : msg.getToolCalls()) {
        // HashMap (not Map.of) because individual values may be null.
        final Map<String, Object> call = new HashMap<>();
        call.put("name", toolCall.name());
        call.put("arguments", toolCall.arguments());
        call.put("type", toolCall.type());
        call.put("id", toolCall.id());
        toolCallOptions.add(call);
    }
    return new Reply(-1, LlmService.ROLE_TOOL_CALL, msg.getText(), objectMapper.writeValueAsString(toolCallOptions), null, null);
}

/** Builds the reply for tool responses: the tool names as JSON text and the full responses as JSON payload. */
private Reply toolResultReply(final ToolResponseMessage msg) {
    final List<Map<String, Object>> toolCallResponses = new ArrayList<>();
    final List<String> toolCallNames = new ArrayList<>();
    for (final ToolResponseMessage.ToolResponse toolResponse : msg.getResponses()) {
        final Map<String, Object> response = new HashMap<>();
        response.put("id", toolResponse.id());
        response.put("name", toolResponse.name());
        response.put("responseData", toolResponse.responseData());
        toolCallResponses.add(response);
        toolCallNames.add(toolResponse.name());
    }
    return new Reply(-1, LlmService.ROLE_TOOL_RESULT, objectMapper.writeValueAsString(toolCallNames), objectMapper.writeValueAsString(toolCallResponses), null, null);
}
/**
 * Progress collector that forwards every callback to an event sink as a {@code ChatStreamEvent}.
 * Thinking tokens are additionally buffered so they can be fetched and cleared in one step.
 * Package-private for tests.
 */
static final class StreamingChatProgressCollector implements ChatProgressCollector {

    private final Consumer<LlmService.ChatStreamEvent> sink;
    private final StringBuilder thoughtBuffer = new StringBuilder();

    /**
     * @param eventSink receiver of all events produced by this collector
     */
    StreamingChatProgressCollector(final Consumer<LlmService.ChatStreamEvent> eventSink) {
        this.sink = eventSink;
    }

    /** Buffers the token and then emits it as {@code ThinkingToken} event. */
    @Override
    public void thinkingToken(final String token) {
        thoughtBuffer.append(token);
        final LlmService.ChatStreamEvent event = new LlmService.ChatStreamEvent.ThinkingToken(token);
        sink.accept(event);
    }

    /** Emits the token as {@code ContentToken} event. */
    @Override
    public void assistantResponseToken(final String token) {
        final LlmService.ChatStreamEvent event = new LlmService.ChatStreamEvent.ContentToken(token);
        sink.accept(event);
    }

    /** Emits a {@code ToolCall} event for the invoked tool. */
    @Override
    public void toolInvocation(final String toolName, final String jsonPayload) {
        final LlmService.ChatStreamEvent event = new LlmService.ChatStreamEvent.ToolCall(toolName, jsonPayload);
        sink.accept(event);
    }

    /** Emits a {@code ToolResult} event for the returned tool responses. */
    @Override
    public void toolResponses(final String toolNames, final String jsonPayload) {
        final LlmService.ChatStreamEvent event = new LlmService.ChatStreamEvent.ToolResult(toolNames, jsonPayload);
        sink.accept(event);
    }

    /**
     * Returns the buffered thinking text and clears the buffer.
     *
     * @return the buffered text, or an empty string when nothing was buffered
     */
    @Override
    public String getAssistantThoughtsAndReset() {
        if (thoughtBuffer.isEmpty()) {
            return "";
        }
        final String thoughts = thoughtBuffer.toString();
        thoughtBuffer.setLength(0);
        return thoughts;
    }
}
}