package chat import ( "context" "errors" "fmt" "log/slog" "strings" "time" "unicode" "github.com/stig/goaichat/internal/config" "github.com/stig/goaichat/internal/openai" ) const defaultSessionPrefix = "session-" // CompletionClient defines the subset of the OpenAI client used by the chat service. type CompletionClient interface { CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (*openai.ChatCompletionResponse, error) StreamChatCompletion(ctx context.Context, req openai.ChatCompletionRequest, handler openai.ChatCompletionStreamHandler) (*openai.ChatCompletionResponse, error) } // MessageRepository defines persistence hooks used by the chat service. type MessageRepository interface { CreateSession(ctx context.Context, name, model string) (int64, error) AddMessage(ctx context.Context, sessionID int64, role, content string, tokens *int64) (int64, error) UpdateSessionSummary(ctx context.Context, sessionID int64, summary string) error UpdateSessionName(ctx context.Context, sessionID int64, name string) error } // Service coordinates chat requests with the OpenAI client and maintains session history. type Service struct { logger *slog.Logger client CompletionClient repo MessageRepository model string temperature float64 stream bool streamNotice string history []openai.ChatMessage sessionID int64 summarySet bool sessionName string sessionNamed bool nameSuggested bool } // NewService constructs a Service from configuration and an OpenAI-compatible client. func NewService(logger *slog.Logger, modelCfg config.ModelConfig, client CompletionClient, repo MessageRepository) (*Service, error) { if logger == nil { return nil, errors.New("logger cannot be nil") } if client == nil { return nil, errors.New("completion client cannot be nil") } if strings.TrimSpace(modelCfg.Name) == "" { return nil, errors.New("model name cannot be empty") } return &Service{ logger: logger, client: client, repo: repo, model: modelCfg.Name, temperature: modelCfg.Temperature, stream: modelCfg.Stream, history: make([]openai.ChatMessage, 0, 16), }, nil } // Send submits a user message and returns the assistant reply. When streamHandler is provided and streaming is enabled, // partial responses are forwarded to the handler as they arrive. func (s *Service) Send(ctx context.Context, input string, streamHandler openai.ChatCompletionStreamHandler) (string, error) { if s == nil { return "", errors.New("service is nil") } if ctx == nil { return "", errors.New("context is nil") } content := strings.TrimSpace(input) if content == "" { return "", errors.New("input cannot be empty") } userMsg := openai.ChatMessage{Role: "user", Content: content} s.history = append(s.history, userMsg) if err := s.persistMessage(ctx, userMsg); err != nil { s.logger.WarnContext(ctx, "failed to persist user message", "error", err) } messages := append([]openai.ChatMessage(nil), s.history...) temperature := s.temperature s.streamNotice = "" req := openai.ChatCompletionRequest{ Model: s.model, Messages: messages, Stream: s.stream, Temperature: &temperature, } s.logger.DebugContext(ctx, "sending chat completion", "model", s.model, "message_count", len(messages)) var resp *openai.ChatCompletionResponse var err error if s.stream { resp, err = s.client.StreamChatCompletion(ctx, req, streamHandler) if err != nil { resp, err = s.handleStreamingFailure(ctx, req, err) if err != nil { return "", s.translateProviderError(err) } } } else { resp, err = s.client.CreateChatCompletion(ctx, req) if err != nil { return "", s.translateProviderError(err) } } if len(resp.Choices) == 0 { return "", errors.New("no choices returned from completion") } reply := resp.Choices[0].Message s.history = append(s.history, reply) if err := s.persistMessage(ctx, reply); err != nil { s.logger.WarnContext(ctx, "failed to persist assistant message", "error", err) } return reply.Content, nil } // History returns a copy of the current conversation history. func (s *Service) History() []openai.ChatMessage { if s == nil { return nil } historyCopy := make([]openai.ChatMessage, len(s.history)) copy(historyCopy, s.history) return historyCopy } // ConsumeStreamingNotice returns any pending streaming notice and clears it. func (s *Service) ConsumeStreamingNotice() string { if s == nil { return "" } notice := s.streamNotice s.streamNotice = "" return notice } // StreamingEnabled reports whether streaming completions are configured for this service. func (s *Service) StreamingEnabled() bool { if s == nil { return false } return s.stream } func (s *Service) translateProviderError(err error) error { var reqErr *openai.RequestError if !errors.As(err, &reqErr) { return err } if guidance, ok := providerStatusGuidance(reqErr.StatusCode()); ok { return errors.New(guidance) } return err } func (s *Service) handleStreamingFailure(ctx context.Context, req openai.ChatCompletionRequest, streamErr error) (*openai.ChatCompletionResponse, error) { if s == nil { return nil, streamErr } var reqErr *openai.RequestError if !errors.As(streamErr, &reqErr) { return nil, streamErr } status := reqErr.StatusCode() if status < 400 || status >= 500 { return nil, streamErr } guidance, hasGuidance := providerStatusGuidance(status) message := guidance if !hasGuidance { message = strings.TrimSpace(reqErr.Message()) if message == "" { message = strings.TrimSpace(streamErr.Error()) } if message == "" { message = "Streaming is unavailable" } } message = fmt.Sprintf("%s\nStreaming has been disabled; responses will be fully buffered.", message) s.logger.WarnContext(ctx, "streaming disabled", "status", status, "error", strings.TrimSpace(reqErr.Message())) s.stream = false s.streamNotice = message req.Stream = false resp, err := s.client.CreateChatCompletion(ctx, req) if err != nil { return nil, s.translateProviderError(err) } return resp, nil } func providerStatusGuidance(status int) (string, bool) { switch status { case 401: return "Incorrect API key provided.\nVerify API key, clear browser cache, or generate a new key.", true case 429: return "Rate limit reached.\nPace requests and implement exponential backoff.", true case 500: return "Server error.\nRetry after a brief wait; contact support if persistent.", true case 503: return "Engine overloaded.\nRetry request after a brief wait; contact support if persistent.", true default: return "", false } } // Reset clears the in-memory conversation history. func (s *Service) Reset() { if s == nil { return } s.history = s.history[:0] s.sessionID = 0 s.summarySet = false s.sessionName = "" s.sessionNamed = false s.nameSuggested = false } // RestoreSession replaces in-memory history with persisted messages for an existing session. func (s *Service) RestoreSession(sessionID int64, name string, messages []openai.ChatMessage, summaryPresent bool) { if s == nil { return } s.history = s.history[:0] s.history = append(s.history, messages...) s.sessionID = sessionID s.summarySet = summaryPresent s.sessionName = name s.sessionNamed = !isAutoGeneratedName(name) s.nameSuggested = false } // CurrentSessionID returns the active session identifier, if any. func (s *Service) CurrentSessionID() int64 { if s == nil { return 0 } return s.sessionID } // SessionName returns the current session name. func (s *Service) SessionName() string { if s == nil { return "" } return s.sessionName } // SessionNamed reports whether the current session has a user-friendly name applied. func (s *Service) SessionNamed() bool { if s == nil { return false } return s.sessionNamed } // ShouldSuggestSessionName indicates whether the service would benefit from a generated name. func (s *Service) ShouldSuggestSessionName() bool { if s == nil { return false } return !s.sessionNamed && !s.nameSuggested } // MarkSessionNameSuggested prevents additional automatic suggestions for the current session. func (s *Service) MarkSessionNameSuggested() { if s == nil { return } s.nameSuggested = true } // SetSessionName normalizes and persists a new name for the active session, returning the stored value. func (s *Service) SetSessionName(ctx context.Context, name string) (string, error) { if s == nil { return "", errors.New("service is nil") } normalized := NormalizeSessionName(name) if normalized == "" { return "", errors.New("session name cannot be empty") } if s.repo != nil { if s.sessionID <= 0 { if err := s.ensureSession(ctx); err != nil { return "", err } } if err := s.repo.UpdateSessionName(ctx, s.sessionID, normalized); err != nil { return "", err } } s.sessionName = normalized s.sessionNamed = !isAutoGeneratedName(normalized) s.nameSuggested = true return normalized, nil } func (s *Service) ensureSession(ctx context.Context) error { if s == nil || s.repo == nil { return nil } if s.sessionID > 0 { return nil } name := fmt.Sprintf("%s%s", defaultSessionPrefix, time.Now().Format("20060102-150405")) s.logger.DebugContext(ctx, "creating chat session", "name", name) id, err := s.repo.CreateSession(ctx, name, s.model) if err != nil { return err } s.sessionID = id s.summarySet = false s.sessionName = name s.sessionNamed = false s.nameSuggested = false return nil } func (s *Service) persistMessage(ctx context.Context, msg openai.ChatMessage) error { if s.repo == nil { return nil } if err := s.ensureSession(ctx); err != nil { return err } var tokens *int64 if _, err := s.repo.AddMessage(ctx, s.sessionID, msg.Role, msg.Content, tokens); err != nil { return err } if msg.Role == "assistant" && !s.summarySet && strings.TrimSpace(msg.Content) != "" { const maxSummaryLen = 120 summary := msg.Content runes := []rune(summary) if len(runes) > maxSummaryLen { summary = string(runes[:maxSummaryLen]) + "..." } if err := s.repo.UpdateSessionSummary(ctx, s.sessionID, summary); err != nil { s.logger.WarnContext(ctx, "failed to update session summary", "error", err) } else { s.summarySet = true } } return nil } // SuggestSessionName uses the backing LLM to propose a descriptive session name. func (s *Service) SuggestSessionName(ctx context.Context) (string, error) { if s == nil { return "", errors.New("service is nil") } if s.client == nil { return "", errors.New("completion client is unavailable") } if len(s.history) == 0 { return "", errors.New("no conversation history available") } start := 0 if len(s.history) > 10 { start = len(s.history) - 10 } var builder strings.Builder for i := start; i < len(s.history); i++ { msg := s.history[i] builder.WriteString(msg.Role) builder.WriteString(": ") builder.WriteString(msg.Content) builder.WriteString("\n") } temp := 0.3 req := openai.ChatCompletionRequest{ Model: s.model, Messages: []openai.ChatMessage{ {Role: "system", Content: "You generate concise, descriptive chat session names. Respond with a lowercase kebab-case identifier (letters, digits, hyphens only, max 6 words)."}, {Role: "user", Content: fmt.Sprintf("Conversation excerpt:\n%s\nProvide only the session name.", builder.String())}, }, Temperature: &temp, } resp, err := s.client.CreateChatCompletion(ctx, req) if err != nil { return "", err } if len(resp.Choices) == 0 { return "", errors.New("no session name suggestion returned") } suggestion := NormalizeSessionName(resp.Choices[0].Message.Content) if suggestion == "" { return "", errors.New("empty session name suggestion") } return suggestion, nil } // NormalizeSessionName converts arbitrary text into a kebab-case identifier. func NormalizeSessionName(name string) string { trimmed := strings.TrimSpace(name) if trimmed == "" { return "" } trimmed = strings.ToLower(trimmed) var builder strings.Builder lastHyphen := false for _, r := range trimmed { if unicode.IsLetter(r) || unicode.IsDigit(r) { builder.WriteRune(r) lastHyphen = false } else if r == ' ' || r == '-' || r == '_' || r == '.' || r == '/' || r == '\\' { if !lastHyphen && builder.Len() > 0 { builder.WriteRune('-') lastHyphen = true } } if builder.Len() >= 80 { break } } slug := strings.Trim(builder.String(), "-") if slug == "" { return "" } parts := strings.Split(slug, "-") filtered := make([]string, 0, len(parts)) for _, part := range parts { if part == "" { continue } filtered = append(filtered, part) if len(filtered) >= 6 { break } } result := strings.Join(filtered, "-") if len(result) > 60 { runes := []rune(result) if len(runes) > 60 { runes = runes[:60] } result = strings.Trim(string(runes), "-") } return result } func isAutoGeneratedName(name string) bool { return strings.HasPrefix(name, defaultSessionPrefix) }