SpringAIChatRepository.java
package de.mirkosertic.powerstaff.profilesearch.command;
import de.mirkosertic.powerstaff.profilesearch.query.ProfileSearchQueryService;
import org.jspecify.annotations.NonNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import tools.jackson.core.JacksonException;
import tools.jackson.core.type.TypeReference;
import tools.jackson.databind.ObjectMapper;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
/**
 * Spring AI {@link ChatMemoryRepository} bound to exactly one profile-search chat.
 * <p>
 * History is read through {@link ProfileSearchQueryService} and written through
 * {@link ProfileSearchCommandService}. The conversation id is parsed as the numeric chat id
 * both services expect.
 * <p>
 * Loaded messages are wrapped in {@code Persistent*} message types so that
 * {@link #saveAll(String, List)} can skip them and store only messages that are not yet
 * persisted. Every message stored by this instance is also remembered and exposed via
 * {@link #getNewMessages()}. Tool invocations and tool results are additionally forwarded
 * to the {@link ChatProgressCollector}.
 * <p>
 * Tool calls and tool results are stored as a JSON array of objects:
 * <ul>
 *   <li>calls use the keys {@code id}, {@code type}, {@code name} and {@code arguments};</li>
 *   <li>results use the keys {@code id}, {@code name} and {@code responseData}.</li>
 * </ul>
 * <p>
 * NOTE(review): this class holds mutable per-exchange state without any synchronization.
 * Presumably a new instance is created per request — confirm with the callers.
 */
public class SpringAIChatRepository implements ChatMemoryRepository {

    private static final Logger logger = LoggerFactory.getLogger(SpringAIChatRepository.class);

    /** Jackson target type for persisted tool payloads: a JSON array of objects. */
    private static final TypeReference<ArrayList<Map<String, Object>>> PAYLOAD_TYPE =
            new TypeReference<ArrayList<Map<String, Object>>>() {
            };

    private final String conversationId;
    private final ProfileSearchQueryService queryService;
    private final ProfileSearchCommandService commandService;
    /** Messages persisted by this instance, in the order they were stored. */
    private final List<Message> newMessages;
    private final ObjectMapper objectMapper;
    private final ChatProgressCollector chatProgressCollector;
    /** Id of the most recently stored plain assistant answer, or -1 if none was stored yet. */
    private long lastAssistantMessageId = -1;

    /**
     * Creates a repository for a single conversation.
     *
     * @param conversationId        the only conversation id this repository accepts; must be a
     *                              numeric chat id because it is parsed with {@link Long#parseLong}
     * @param queryService          used to load the stored messages of the chat
     * @param commandService        used to append new messages to the chat
     * @param objectMapper          (de)serializes the JSON payload of tool calls and tool results
     * @param chatProgressCollector receives tool invocations and results, and supplies the
     *                              assistant thoughts stored with assistant messages
     */
    public SpringAIChatRepository(final String conversationId, final ProfileSearchQueryService queryService, final ProfileSearchCommandService commandService, final ObjectMapper objectMapper, final ChatProgressCollector chatProgressCollector) {
        this.conversationId = conversationId;
        this.queryService = queryService;
        this.commandService = commandService;
        this.newMessages = new ArrayList<>();
        this.objectMapper = objectMapper;
        this.chatProgressCollector = chatProgressCollector;
    }

    /** Returns the single conversation id this repository is bound to. */
    @Override
    public List<String> findConversationIds() {
        return List.of(conversationId);
    }

    /**
     * Loads the stored history of the chat and maps each stored row to a Spring AI message.
     * <p>
     * Rows with an unknown role are logged and skipped. A tool payload that cannot be parsed
     * is logged, and the corresponding message is returned without tool calls or tool
     * responses.
     *
     * @param s the conversation id; must equal the id this repository is bound to
     * @return a mutable list of {@code Persistent*} messages, in the order returned by the query service
     * @throws IllegalArgumentException if {@code s} is not this repository's conversation id
     */
    @Override
    public List<Message> findByConversationId(final @NonNull String s) {
        requireOwnConversation(s);
        return queryService.findMessagesByChat(Long.parseLong(s)).stream()
                .map(mv -> {
                    final var role = mv.role();
                    if (LlmService.ROLE_USER.equals(role)) {
                        return new PersistentUserMessage(mv.content());
                    } else if (LlmService.ROLE_ASSISTANT.equals(role)) {
                        return new PersistentAssistantMessage(mv.content());
                    } else if (LlmService.ROLE_TOOL_CALL.equals(role)) {
                        return new PersistentAssistantMessage(mv.content(), parseToolCalls(mv.jsonPayload()));
                    } else if (LlmService.ROLE_TOOL_RESULT.equals(role)) {
                        return new PersistentToolResponseMessage(parseToolResponses(mv.jsonPayload()));
                    } else {
                        logger.warn("Ignoring persistent message: {}", mv);
                        return null;
                    }
                })
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
    }

    /**
     * Persists every message in {@code list} that has not been stored yet.
     * <p>
     * Messages implementing {@code PersistentMessage} are skipped. Presumably these are the
     * {@code Persistent*} instances created by {@link #findByConversationId(String)} and re-submitted
     * by the chat memory. Each stored message is also appended to {@link #getNewMessages()}.
     *
     * @param s    the conversation id; must equal the id this repository is bound to
     * @param list the messages to persist
     * @throws IllegalArgumentException if {@code s} is not this repository's conversation id
     */
    @Override
    public void saveAll(final @NonNull String s, final List<Message> list) {
        requireOwnConversation(s);
        final long chatId = Long.parseLong(s);
        for (final Message message : list) {
            if (message instanceof PersistentMessage) {
                // Already stored; saving it again would duplicate the row.
                continue;
            }
            if (message instanceof final UserMessage x) {
                logger.info("Persisting UserMessage: {}", message);
                commandService.addMessage(chatId, LlmService.ROLE_USER, x.getText());
                newMessages.add(message);
            } else if (message instanceof final AssistantMessage x) {
                logger.info("Persisting AssistantMessage: {}", message);
                if (x.hasToolCalls()) {
                    final String jsonPayload = serializeToolCalls(x.getToolCalls());
                    commandService.addMessage(chatId, LlmService.ROLE_TOOL_CALL, x.getText(), jsonPayload, chatProgressCollector.getAssistantThoughtsAndReset());
                    chatProgressCollector.toolInvocation(x.getText(), jsonPayload);
                } else {
                    final var saved = commandService.addMessage(chatId, LlmService.ROLE_ASSISTANT, x.getText(), null, chatProgressCollector.getAssistantThoughtsAndReset());
                    // Only plain answers (no tool calls) are tracked as the "last assistant message".
                    lastAssistantMessageId = saved.getId();
                }
                newMessages.add(message);
            } else if (message instanceof final ToolResponseMessage x) {
                logger.info("Persisting ToolResponseMessage: {}", message);
                final List<String> toolCallNames = x.getResponses().stream()
                        .map(ToolResponseMessage.ToolResponse::name)
                        .collect(Collectors.toList());
                // The JSON array of tool names is stored as the message content.
                final String toolnames = objectMapper.writeValueAsString(toolCallNames);
                final String jsonPayload = serializeToolResponses(x.getResponses());
                commandService.addMessage(chatId, LlmService.ROLE_TOOL_RESULT, toolnames, jsonPayload, null);
                chatProgressCollector.toolResponses(toolnames, jsonPayload);
                newMessages.add(message);
            } else {
                logger.warn("Cannot persist message: {}", message);
            }
        }
    }

    /**
     * Logs the request but deletes nothing; stored messages are left untouched.
     * NOTE(review): presumably the chat history is intentionally retained — confirm.
     */
    @Override
    public void deleteByConversationId(final @NonNull String s) {
        logger.info("Deleting messages for conversationId: {}", s);
    }

    /**
     * Returns the messages persisted by this instance, in storage order.
     * This is the live internal list, not a copy.
     */
    public List<Message> getNewMessages() {
        return newMessages;
    }

    /** Returns the id of the last stored plain assistant answer, or -1 if none was stored. */
    public long getLastAssistantMessageId() {
        return lastAssistantMessageId;
    }

    /** Rejects any conversation id other than the one this repository is bound to. */
    private void requireOwnConversation(final String candidate) {
        if (!conversationId.equals(candidate)) {
            throw new IllegalArgumentException("Invalid conversationId: " + candidate);
        }
    }

    /** Parses a stored tool-call payload; returns an empty list if it is missing or malformed. */
    private List<AssistantMessage.ToolCall> parseToolCalls(final String jsonPayload) {
        final List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();
        try {
            for (final Map<String, Object> call : readPayload(jsonPayload)) {
                if (call == null) {
                    continue;
                }
                toolCalls.add(new AssistantMessage.ToolCall(
                        stringValue(call, "id"),
                        stringValue(call, "type"),
                        stringValue(call, "name"),
                        stringValue(call, "arguments")));
            }
        } catch (final JacksonException ex) {
            logger.warn("Cannot parse tool call payload: {}", jsonPayload, ex);
        }
        return toolCalls;
    }

    /** Parses a stored tool-result payload; returns an empty list if it is missing or malformed. */
    private List<ToolResponseMessage.ToolResponse> parseToolResponses(final String jsonPayload) {
        final List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();
        try {
            for (final Map<String, Object> response : readPayload(jsonPayload)) {
                if (response == null) {
                    continue;
                }
                toolResponses.add(new ToolResponseMessage.ToolResponse(
                        stringValue(response, "id"),
                        stringValue(response, "name"),
                        stringValue(response, "responseData")));
            }
        } catch (final JacksonException ex) {
            logger.warn("Cannot parse tool result payload: {}", jsonPayload, ex);
        }
        return toolResponses;
    }

    /**
     * Reads a stored JSON array of objects.
     * A null or blank payload yields an empty list instead of Jackson's IllegalArgumentException.
     */
    private List<Map<String, Object>> readPayload(final String jsonPayload) {
        if (jsonPayload == null || jsonPayload.isBlank()) {
            logger.warn("Tool message has no JSON payload");
            return List.of();
        }
        return objectMapper.readValue(jsonPayload, PAYLOAD_TYPE);
    }

    /**
     * Returns the entry's value as a string.
     * Absent keys and JSON {@code null} yield {@code null}, matching what {@code saveAll} may have stored.
     */
    private static String stringValue(final Map<String, Object> entry, final String key) {
        return Objects.toString(entry.get(key), null);
    }

    /** Serializes tool calls to the JSON array format read back by {@link #parseToolCalls(String)}. */
    private String serializeToolCalls(final List<AssistantMessage.ToolCall> toolCalls) {
        final List<Map<String, Object>> payload = new ArrayList<>(toolCalls.size());
        for (final AssistantMessage.ToolCall toolCall : toolCalls) {
            final Map<String, Object> call = new HashMap<>();
            call.put("name", toolCall.name());
            call.put("arguments", toolCall.arguments());
            call.put("type", toolCall.type());
            call.put("id", toolCall.id());
            payload.add(call);
        }
        return objectMapper.writeValueAsString(payload);
    }

    /** Serializes tool responses to the JSON array format read back by {@link #parseToolResponses(String)}. */
    private String serializeToolResponses(final List<ToolResponseMessage.ToolResponse> toolResponses) {
        final List<Map<String, Object>> payload = new ArrayList<>(toolResponses.size());
        for (final ToolResponseMessage.ToolResponse toolResponse : toolResponses) {
            final Map<String, Object> response = new HashMap<>();
            response.put("name", toolResponse.name());
            response.put("responseData", toolResponse.responseData());
            response.put("id", toolResponse.id());
            payload.add(response);
        }
        return objectMapper.writeValueAsString(payload);
    }
}