406 lines
10 KiB
Go
406 lines
10 KiB
Go
package chat
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"strings"
|
|
"time"
|
|
"unicode"
|
|
|
|
"github.com/stig/goaichat/internal/config"
|
|
"github.com/stig/goaichat/internal/openai"
|
|
)
|
|
|
|
const (
	// defaultSessionPrefix starts every session name generated by ensureSession;
	// isAutoGeneratedName checks for it to tell generated names from chosen ones.
	defaultSessionPrefix = "session-"
)
|
|
|
|
// CompletionClient defines the subset of the OpenAI client used by the chat service.
|
|
type CompletionClient interface {
|
|
CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (*openai.ChatCompletionResponse, error)
|
|
StreamChatCompletion(ctx context.Context, req openai.ChatCompletionRequest, handler openai.ChatCompletionStreamHandler) (*openai.ChatCompletionResponse, error)
|
|
}
|
|
|
|
// MessageRepository is the persistence backend the chat service writes to.
// Service treats it as optional: a nil repository disables persistence.
type MessageRepository interface {
	// CreateSession stores a new session with the given name and model and
	// returns its identifier.
	CreateSession(ctx context.Context, name, model string) (int64, error)

	// AddMessage appends one message to a session. tokens may be nil; the chat
	// service currently always passes nil.
	AddMessage(ctx context.Context, sessionID int64, role, content string, tokens *int64) (int64, error)

	// UpdateSessionSummary replaces the stored summary text of a session.
	UpdateSessionSummary(ctx context.Context, sessionID int64, summary string) error

	// UpdateSessionName replaces the stored name of a session.
	UpdateSessionName(ctx context.Context, sessionID int64, name string) error
}
|
|
|
|
// Service coordinates chat requests with the OpenAI client and maintains session history.
|
|
type Service struct {
|
|
logger *slog.Logger
|
|
client CompletionClient
|
|
repo MessageRepository
|
|
model string
|
|
temperature float64
|
|
stream bool
|
|
history []openai.ChatMessage
|
|
sessionID int64
|
|
summarySet bool
|
|
sessionName string
|
|
sessionNamed bool
|
|
nameSuggested bool
|
|
}
|
|
|
|
// NewService constructs a Service from configuration and an OpenAI-compatible client.
|
|
func NewService(logger *slog.Logger, modelCfg config.ModelConfig, client CompletionClient, repo MessageRepository) (*Service, error) {
|
|
if logger == nil {
|
|
return nil, errors.New("logger cannot be nil")
|
|
}
|
|
if client == nil {
|
|
return nil, errors.New("completion client cannot be nil")
|
|
}
|
|
if strings.TrimSpace(modelCfg.Name) == "" {
|
|
return nil, errors.New("model name cannot be empty")
|
|
}
|
|
|
|
return &Service{
|
|
logger: logger,
|
|
client: client,
|
|
repo: repo,
|
|
model: modelCfg.Name,
|
|
temperature: modelCfg.Temperature,
|
|
stream: modelCfg.Stream,
|
|
history: make([]openai.ChatMessage, 0, 16),
|
|
}, nil
|
|
}
|
|
|
|
// Send submits a user message and returns the assistant reply. When streamHandler is provided and streaming is enabled,
|
|
// partial responses are forwarded to the handler as they arrive.
|
|
func (s *Service) Send(ctx context.Context, input string, streamHandler openai.ChatCompletionStreamHandler) (string, error) {
|
|
if s == nil {
|
|
return "", errors.New("service is nil")
|
|
}
|
|
if ctx == nil {
|
|
return "", errors.New("context is nil")
|
|
}
|
|
|
|
content := strings.TrimSpace(input)
|
|
if content == "" {
|
|
return "", errors.New("input cannot be empty")
|
|
}
|
|
|
|
userMsg := openai.ChatMessage{Role: "user", Content: content}
|
|
s.history = append(s.history, userMsg)
|
|
if err := s.persistMessage(ctx, userMsg); err != nil {
|
|
s.logger.WarnContext(ctx, "failed to persist user message", "error", err)
|
|
}
|
|
|
|
messages := append([]openai.ChatMessage(nil), s.history...)
|
|
temperature := s.temperature
|
|
|
|
req := openai.ChatCompletionRequest{
|
|
Model: s.model,
|
|
Messages: messages,
|
|
Stream: s.stream,
|
|
Temperature: &temperature,
|
|
}
|
|
|
|
s.logger.DebugContext(ctx, "sending chat completion", "model", s.model, "message_count", len(messages))
|
|
|
|
var resp *openai.ChatCompletionResponse
|
|
var err error
|
|
if s.stream {
|
|
resp, err = s.client.StreamChatCompletion(ctx, req, streamHandler)
|
|
} else {
|
|
resp, err = s.client.CreateChatCompletion(ctx, req)
|
|
}
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if len(resp.Choices) == 0 {
|
|
return "", errors.New("no choices returned from completion")
|
|
}
|
|
|
|
reply := resp.Choices[0].Message
|
|
s.history = append(s.history, reply)
|
|
if err := s.persistMessage(ctx, reply); err != nil {
|
|
s.logger.WarnContext(ctx, "failed to persist assistant message", "error", err)
|
|
}
|
|
|
|
return reply.Content, nil
|
|
}
|
|
|
|
// History returns a copy of the current conversation history.
|
|
func (s *Service) History() []openai.ChatMessage {
|
|
if s == nil {
|
|
return nil
|
|
}
|
|
|
|
historyCopy := make([]openai.ChatMessage, len(s.history))
|
|
copy(historyCopy, s.history)
|
|
return historyCopy
|
|
}
|
|
|
|
// StreamingEnabled reports whether streaming completions are configured for this service.
|
|
func (s *Service) StreamingEnabled() bool {
|
|
if s == nil {
|
|
return false
|
|
}
|
|
return s.stream
|
|
}
|
|
|
|
// Reset clears the in-memory conversation history.
|
|
func (s *Service) Reset() {
|
|
if s == nil {
|
|
return
|
|
}
|
|
s.history = s.history[:0]
|
|
s.sessionID = 0
|
|
s.summarySet = false
|
|
s.sessionName = ""
|
|
s.sessionNamed = false
|
|
s.nameSuggested = false
|
|
}
|
|
|
|
// RestoreSession replaces in-memory history with persisted messages for an existing session.
|
|
func (s *Service) RestoreSession(sessionID int64, name string, messages []openai.ChatMessage, summaryPresent bool) {
|
|
if s == nil {
|
|
return
|
|
}
|
|
|
|
s.history = s.history[:0]
|
|
s.history = append(s.history, messages...)
|
|
s.sessionID = sessionID
|
|
s.summarySet = summaryPresent
|
|
s.sessionName = name
|
|
s.sessionNamed = !isAutoGeneratedName(name)
|
|
s.nameSuggested = false
|
|
}
|
|
|
|
// CurrentSessionID returns the active session identifier, if any.
|
|
func (s *Service) CurrentSessionID() int64 {
|
|
if s == nil {
|
|
return 0
|
|
}
|
|
return s.sessionID
|
|
}
|
|
|
|
// SessionName returns the current session name.
|
|
func (s *Service) SessionName() string {
|
|
if s == nil {
|
|
return ""
|
|
}
|
|
return s.sessionName
|
|
}
|
|
|
|
// SessionNamed reports whether the current session has a user-friendly name applied.
|
|
func (s *Service) SessionNamed() bool {
|
|
if s == nil {
|
|
return false
|
|
}
|
|
return s.sessionNamed
|
|
}
|
|
|
|
// ShouldSuggestSessionName indicates whether the service would benefit from a generated name.
|
|
func (s *Service) ShouldSuggestSessionName() bool {
|
|
if s == nil {
|
|
return false
|
|
}
|
|
return !s.sessionNamed && !s.nameSuggested
|
|
}
|
|
|
|
// MarkSessionNameSuggested prevents additional automatic suggestions for the current session.
|
|
func (s *Service) MarkSessionNameSuggested() {
|
|
if s == nil {
|
|
return
|
|
}
|
|
s.nameSuggested = true
|
|
}
|
|
|
|
// SetSessionName normalizes and persists a new name for the active session, returning the stored value.
|
|
func (s *Service) SetSessionName(ctx context.Context, name string) (string, error) {
|
|
if s == nil {
|
|
return "", errors.New("service is nil")
|
|
}
|
|
normalized := NormalizeSessionName(name)
|
|
if normalized == "" {
|
|
return "", errors.New("session name cannot be empty")
|
|
}
|
|
|
|
if s.repo != nil {
|
|
if s.sessionID <= 0 {
|
|
if err := s.ensureSession(ctx); err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
if err := s.repo.UpdateSessionName(ctx, s.sessionID, normalized); err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
|
|
s.sessionName = normalized
|
|
s.sessionNamed = !isAutoGeneratedName(normalized)
|
|
s.nameSuggested = true
|
|
return normalized, nil
|
|
}
|
|
|
|
func (s *Service) ensureSession(ctx context.Context) error {
|
|
if s == nil || s.repo == nil {
|
|
return nil
|
|
}
|
|
if s.sessionID > 0 {
|
|
return nil
|
|
}
|
|
|
|
name := fmt.Sprintf("%s%s", defaultSessionPrefix, time.Now().Format("20060102-150405"))
|
|
s.logger.DebugContext(ctx, "creating chat session", "name", name)
|
|
|
|
id, err := s.repo.CreateSession(ctx, name, s.model)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.sessionID = id
|
|
s.summarySet = false
|
|
s.sessionName = name
|
|
s.sessionNamed = false
|
|
s.nameSuggested = false
|
|
return nil
|
|
}
|
|
|
|
func (s *Service) persistMessage(ctx context.Context, msg openai.ChatMessage) error {
|
|
if s.repo == nil {
|
|
return nil
|
|
}
|
|
|
|
if err := s.ensureSession(ctx); err != nil {
|
|
return err
|
|
}
|
|
|
|
var tokens *int64
|
|
if _, err := s.repo.AddMessage(ctx, s.sessionID, msg.Role, msg.Content, tokens); err != nil {
|
|
return err
|
|
}
|
|
|
|
if msg.Role == "assistant" && !s.summarySet && strings.TrimSpace(msg.Content) != "" {
|
|
const maxSummaryLen = 120
|
|
summary := msg.Content
|
|
runes := []rune(summary)
|
|
if len(runes) > maxSummaryLen {
|
|
summary = string(runes[:maxSummaryLen]) + "..."
|
|
}
|
|
if err := s.repo.UpdateSessionSummary(ctx, s.sessionID, summary); err != nil {
|
|
s.logger.WarnContext(ctx, "failed to update session summary", "error", err)
|
|
} else {
|
|
s.summarySet = true
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SuggestSessionName uses the backing LLM to propose a descriptive session name.
|
|
func (s *Service) SuggestSessionName(ctx context.Context) (string, error) {
|
|
if s == nil {
|
|
return "", errors.New("service is nil")
|
|
}
|
|
if s.client == nil {
|
|
return "", errors.New("completion client is unavailable")
|
|
}
|
|
if len(s.history) == 0 {
|
|
return "", errors.New("no conversation history available")
|
|
}
|
|
|
|
start := 0
|
|
if len(s.history) > 10 {
|
|
start = len(s.history) - 10
|
|
}
|
|
|
|
var builder strings.Builder
|
|
for i := start; i < len(s.history); i++ {
|
|
msg := s.history[i]
|
|
builder.WriteString(msg.Role)
|
|
builder.WriteString(": ")
|
|
builder.WriteString(msg.Content)
|
|
builder.WriteString("\n")
|
|
}
|
|
|
|
temp := 0.3
|
|
req := openai.ChatCompletionRequest{
|
|
Model: s.model,
|
|
Messages: []openai.ChatMessage{
|
|
{Role: "system", Content: "You generate concise, descriptive chat session names. Respond with a lowercase kebab-case identifier (letters, digits, hyphens only, max 6 words)."},
|
|
{Role: "user", Content: fmt.Sprintf("Conversation excerpt:\n%s\nProvide only the session name.", builder.String())},
|
|
},
|
|
Temperature: &temp,
|
|
}
|
|
|
|
resp, err := s.client.CreateChatCompletion(ctx, req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if len(resp.Choices) == 0 {
|
|
return "", errors.New("no session name suggestion returned")
|
|
}
|
|
|
|
suggestion := NormalizeSessionName(resp.Choices[0].Message.Content)
|
|
if suggestion == "" {
|
|
return "", errors.New("empty session name suggestion")
|
|
}
|
|
|
|
return suggestion, nil
|
|
}
|
|
|
|
// NormalizeSessionName converts arbitrary text into a lowercase kebab-case
// identifier suitable as a session name.
//
// Character handling:
//   - letters and digits are kept (lowercased);
//   - any Unicode whitespace and the characters '-', '_', '.', '/' and '\' are
//     word separators and collapse into a single hyphen;
//   - every other character is dropped without separating words.
//
// Limits:
//   - scanning stops once 80 bytes of output have been produced;
//   - at most six words are kept;
//   - the result is cut to 60 runes, with no leading or trailing hyphen.
//
// Input without letters or digits yields "".
func NormalizeSessionName(name string) string {
	const (
		maxScanBytes = 80 // bounds the work done on very long inputs
		maxWords     = 6
		maxRunes     = 60
	)

	lowered := strings.ToLower(strings.TrimSpace(name))
	if lowered == "" {
		return ""
	}

	var builder strings.Builder
	lastHyphen := false
	for _, r := range lowered {
		switch {
		case unicode.IsLetter(r) || unicode.IsDigit(r):
			builder.WriteRune(r)
			lastHyphen = false
		case unicode.IsSpace(r) || strings.ContainsRune("-_./\\", r):
			// Tabs and newlines separate words just like spaces, so multi-line
			// input does not glue neighbouring words together.
			if !lastHyphen && builder.Len() > 0 {
				builder.WriteRune('-')
				lastHyphen = true
			}
		}
		if builder.Len() >= maxScanBytes {
			break
		}
	}

	// FieldsFunc drops the empty pieces a trailing hyphen would leave behind.
	words := strings.FieldsFunc(builder.String(), func(r rune) bool { return r == '-' })
	if len(words) > maxWords {
		words = words[:maxWords]
	}
	result := strings.Join(words, "-")

	// len counts bytes; only decode runes when the result might be too long.
	if len(result) > maxRunes {
		if runes := []rune(result); len(runes) > maxRunes {
			result = strings.Trim(string(runes[:maxRunes]), "-")
		}
	}
	return result
}
|
|
|
|
func isAutoGeneratedName(name string) bool {
|
|
return strings.HasPrefix(name, defaultSessionPrefix)
|
|
}
|