goaichat/internal/chat/service.go

package chat

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"strings"
	"time"
	"unicode"

	"github.com/stig/goaichat/internal/config"
	"github.com/stig/goaichat/internal/openai"
)

// defaultSessionPrefix prefixes auto-generated session names; isAutoGeneratedName relies on it.
const defaultSessionPrefix = "session-"

// CompletionClient defines the subset of the OpenAI client used by the chat service.
type CompletionClient interface {
	CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (*openai.ChatCompletionResponse, error)
	StreamChatCompletion(ctx context.Context, req openai.ChatCompletionRequest, handler openai.ChatCompletionStreamHandler) (*openai.ChatCompletionResponse, error)
}

// MessageRepository defines persistence hooks used by the chat service.
type MessageRepository interface {
	CreateSession(ctx context.Context, name, model string) (int64, error)
	AddMessage(ctx context.Context, sessionID int64, role, content string, tokens *int64) (int64, error)
	UpdateSessionSummary(ctx context.Context, sessionID int64, summary string) error
	UpdateSessionName(ctx context.Context, sessionID int64, name string) error
}

// Service coordinates chat requests with the OpenAI client and maintains session history.
type Service struct {
	logger        *slog.Logger
	client        CompletionClient
	repo          MessageRepository
	model         string
	temperature   float64
	stream        bool
	streamNotice  string
	history       []openai.ChatMessage
	sessionID     int64
	summarySet    bool
	sessionName   string
	sessionNamed  bool
	nameSuggested bool
}

// NewService constructs a Service from configuration and an OpenAI-compatible client.
func NewService(logger *slog.Logger, modelCfg config.ModelConfig, client CompletionClient, repo MessageRepository) (*Service, error) {
	if logger == nil {
		return nil, errors.New("logger cannot be nil")
	}
	if client == nil {
		return nil, errors.New("completion client cannot be nil")
	}
	if strings.TrimSpace(modelCfg.Name) == "" {
		return nil, errors.New("model name cannot be empty")
	}
	return &Service{
		logger:      logger,
		client:      client,
		repo:        repo,
		model:       modelCfg.Name,
		temperature: modelCfg.Temperature,
		stream:      modelCfg.Stream,
		history:     make([]openai.ChatMessage, 0, 16),
	}, nil
}
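
// A minimal construction sketch (illustrative; the model name is a placeholder and the
// nil repository simply disables persistence):
//
//	logger := slog.Default()
//	// client must be a non-nil CompletionClient, e.g. the project's OpenAI client.
//	svc, err := NewService(logger, config.ModelConfig{Name: "gpt-4o-mini", Temperature: 0.7, Stream: true}, client, nil)
//	if err != nil {
//		// handle configuration error
//	}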

// Send submits a user message and returns the assistant reply. When streamHandler is provided and streaming is enabled,
// partial responses are forwarded to the handler as they arrive.
func (s *Service) Send(ctx context.Context, input string, streamHandler openai.ChatCompletionStreamHandler) (string, error) {
	if s == nil {
		return "", errors.New("service is nil")
	}
	if ctx == nil {
		return "", errors.New("context is nil")
	}
	content := strings.TrimSpace(input)
	if content == "" {
		return "", errors.New("input cannot be empty")
	}
	userMsg := openai.ChatMessage{Role: "user", Content: content}
	s.history = append(s.history, userMsg)
	if err := s.persistMessage(ctx, userMsg); err != nil {
		s.logger.WarnContext(ctx, "failed to persist user message", "error", err)
	}
	// Send a copy of the history so the request payload is isolated from later appends.
	messages := append([]openai.ChatMessage(nil), s.history...)
	temperature := s.temperature
	s.streamNotice = ""
	req := openai.ChatCompletionRequest{
		Model:       s.model,
		Messages:    messages,
		Stream:      s.stream,
		Temperature: &temperature,
	}
	s.logger.DebugContext(ctx, "sending chat completion", "model", s.model, "message_count", len(messages))
	var resp *openai.ChatCompletionResponse
	var err error
	if s.stream {
		resp, err = s.client.StreamChatCompletion(ctx, req, streamHandler)
		if err != nil {
			// On a streaming failure, try to fall back to a buffered completion.
			resp, err = s.handleStreamingFailure(ctx, req, err)
			if err != nil {
				return "", s.translateProviderError(err)
			}
		}
	} else {
		resp, err = s.client.CreateChatCompletion(ctx, req)
		if err != nil {
			return "", s.translateProviderError(err)
		}
	}
	if len(resp.Choices) == 0 {
		return "", errors.New("no choices returned from completion")
	}
	reply := resp.Choices[0].Message
	s.history = append(s.history, reply)
	if err := s.persistMessage(ctx, reply); err != nil {
		s.logger.WarnContext(ctx, "failed to persist assistant message", "error", err)
	}
	return reply.Content, nil
}
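
// A sketch of the intended call pattern (illustrative; svc, ctx, and the prompt are
// placeholders, and the nil handler assumes a non-streaming setup):
//
//	reply, err := svc.Send(ctx, "summarize the last deploy log", nil)
//	if err != nil {
//		return err
//	}
//	if notice := svc.ConsumeStreamingNotice(); notice != "" {
//		fmt.Println(notice) // surface a one-time streaming fallback warning
//	}
//	fmt.Println(reply)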

// History returns a copy of the current conversation history.
func (s *Service) History() []openai.ChatMessage {
	if s == nil {
		return nil
	}
	historyCopy := make([]openai.ChatMessage, len(s.history))
	copy(historyCopy, s.history)
	return historyCopy
}

// ConsumeStreamingNotice returns any pending streaming notice and clears it.
func (s *Service) ConsumeStreamingNotice() string {
	if s == nil {
		return ""
	}
	notice := s.streamNotice
	s.streamNotice = ""
	return notice
}

// StreamingEnabled reports whether streaming completions are configured for this service.
func (s *Service) StreamingEnabled() bool {
	if s == nil {
		return false
	}
	return s.stream
}

// translateProviderError maps well-known provider status codes to user-facing guidance.
func (s *Service) translateProviderError(err error) error {
	var reqErr *openai.RequestError
	if !errors.As(err, &reqErr) {
		return err
	}
	if guidance, ok := providerStatusGuidance(reqErr.StatusCode()); ok {
		return errors.New(guidance)
	}
	return err
}

// handleStreamingFailure retries a failed streaming request as a buffered completion when the
// provider rejected it with a 4xx status. Streaming stays disabled for the remainder of the
// service's lifetime, and a one-time notice is stored for the UI to surface.
func (s *Service) handleStreamingFailure(ctx context.Context, req openai.ChatCompletionRequest, streamErr error) (*openai.ChatCompletionResponse, error) {
	if s == nil {
		return nil, streamErr
	}
	var reqErr *openai.RequestError
	if !errors.As(streamErr, &reqErr) {
		return nil, streamErr
	}
	status := reqErr.StatusCode()
	if status < 400 || status >= 500 {
		return nil, streamErr
	}
	guidance, hasGuidance := providerStatusGuidance(status)
	message := guidance
	if !hasGuidance {
		message = strings.TrimSpace(reqErr.Message())
		if message == "" {
			message = strings.TrimSpace(streamErr.Error())
		}
		if message == "" {
			message = "Streaming is unavailable"
		}
	}
	message = fmt.Sprintf("%s\nStreaming has been disabled; responses will be fully buffered.", message)
	s.logger.WarnContext(ctx, "streaming disabled", "status", status, "error", strings.TrimSpace(reqErr.Message()))
	s.stream = false
	s.streamNotice = message
	req.Stream = false
	resp, err := s.client.CreateChatCompletion(ctx, req)
	if err != nil {
		return nil, s.translateProviderError(err)
	}
	return resp, nil
}

// providerStatusGuidance returns user-facing guidance for well-known provider status codes.
func providerStatusGuidance(status int) (string, bool) {
	switch status {
	case 401:
		return "Incorrect API key provided.\nVerify API key, clear browser cache, or generate a new key.", true
	case 429:
		return "Rate limit reached.\nPace requests and implement exponential backoff.", true
	case 500:
		return "Server error.\nRetry after a brief wait; contact support if persistent.", true
	case 503:
		return "Engine overloaded.\nRetry request after a brief wait; contact support if persistent.", true
	default:
		return "", false
	}
}

// Reset clears the in-memory conversation history and all session state so the next
// message starts a fresh session.
func (s *Service) Reset() {
	if s == nil {
		return
	}
	s.history = s.history[:0]
	s.sessionID = 0
	s.summarySet = false
	s.sessionName = ""
	s.sessionNamed = false
	s.nameSuggested = false
}

// RestoreSession replaces in-memory history with persisted messages for an existing session.
func (s *Service) RestoreSession(sessionID int64, name string, messages []openai.ChatMessage, summaryPresent bool) {
	if s == nil {
		return
	}
	s.history = s.history[:0]
	s.history = append(s.history, messages...)
	s.sessionID = sessionID
	s.summarySet = summaryPresent
	s.sessionName = name
	s.sessionNamed = !isAutoGeneratedName(name)
	s.nameSuggested = false
}

// CurrentSessionID returns the active session identifier, if any.
func (s *Service) CurrentSessionID() int64 {
	if s == nil {
		return 0
	}
	return s.sessionID
}

// SessionName returns the current session name.
func (s *Service) SessionName() string {
	if s == nil {
		return ""
	}
	return s.sessionName
}

// SessionNamed reports whether the current session has a user-friendly name applied.
func (s *Service) SessionNamed() bool {
	if s == nil {
		return false
	}
	return s.sessionNamed
}

// ShouldSuggestSessionName reports whether an automatic name suggestion should be made:
// the session has no user-friendly name yet and no suggestion has been offered.
func (s *Service) ShouldSuggestSessionName() bool {
	if s == nil {
		return false
	}
	return !s.sessionNamed && !s.nameSuggested
}

// MarkSessionNameSuggested prevents additional automatic suggestions for the current session.
func (s *Service) MarkSessionNameSuggested() {
	if s == nil {
		return
	}
	s.nameSuggested = true
}

// SetSessionName normalizes and persists a new name for the active session, returning the stored value.
func (s *Service) SetSessionName(ctx context.Context, name string) (string, error) {
	if s == nil {
		return "", errors.New("service is nil")
	}
	normalized := NormalizeSessionName(name)
	if normalized == "" {
		return "", errors.New("session name cannot be empty")
	}
	if s.repo != nil {
		if s.sessionID <= 0 {
			if err := s.ensureSession(ctx); err != nil {
				return "", err
			}
		}
		if err := s.repo.UpdateSessionName(ctx, s.sessionID, normalized); err != nil {
			return "", err
		}
	}
	s.sessionName = normalized
	s.sessionNamed = !isAutoGeneratedName(normalized)
	s.nameSuggested = true
	return normalized, nil
}

// ensureSession lazily creates a persisted session the first time one is needed.
func (s *Service) ensureSession(ctx context.Context) error {
	if s == nil || s.repo == nil {
		return nil
	}
	if s.sessionID > 0 {
		return nil
	}
	name := fmt.Sprintf("%s%s", defaultSessionPrefix, time.Now().Format("20060102-150405"))
	s.logger.DebugContext(ctx, "creating chat session", "name", name)
	id, err := s.repo.CreateSession(ctx, name, s.model)
	if err != nil {
		return err
	}
	s.sessionID = id
	s.summarySet = false
	s.sessionName = name
	s.sessionNamed = false
	s.nameSuggested = false
	return nil
}

// persistMessage stores a message for the active session and, for the first non-empty
// assistant reply, records a truncated copy as the session summary.
func (s *Service) persistMessage(ctx context.Context, msg openai.ChatMessage) error {
	if s.repo == nil {
		return nil
	}
	if err := s.ensureSession(ctx); err != nil {
		return err
	}
	var tokens *int64 // token usage is not recorded here
	if _, err := s.repo.AddMessage(ctx, s.sessionID, msg.Role, msg.Content, tokens); err != nil {
		return err
	}
	if msg.Role == "assistant" && !s.summarySet && strings.TrimSpace(msg.Content) != "" {
		const maxSummaryLen = 120
		summary := msg.Content
		runes := []rune(summary)
		if len(runes) > maxSummaryLen {
			summary = string(runes[:maxSummaryLen]) + "..."
		}
		if err := s.repo.UpdateSessionSummary(ctx, s.sessionID, summary); err != nil {
			s.logger.WarnContext(ctx, "failed to update session summary", "error", err)
		} else {
			s.summarySet = true
		}
	}
	return nil
}

// SuggestSessionName uses the backing LLM to propose a descriptive session name.
func (s *Service) SuggestSessionName(ctx context.Context) (string, error) {
	if s == nil {
		return "", errors.New("service is nil")
	}
	if s.client == nil {
		return "", errors.New("completion client is unavailable")
	}
	if len(s.history) == 0 {
		return "", errors.New("no conversation history available")
	}
	// Build an excerpt from at most the last ten messages to keep the prompt small.
	start := 0
	if len(s.history) > 10 {
		start = len(s.history) - 10
	}
	var builder strings.Builder
	for i := start; i < len(s.history); i++ {
		msg := s.history[i]
		builder.WriteString(msg.Role)
		builder.WriteString(": ")
		builder.WriteString(msg.Content)
		builder.WriteString("\n")
	}
	temp := 0.3
	req := openai.ChatCompletionRequest{
		Model: s.model,
		Messages: []openai.ChatMessage{
			{Role: "system", Content: "You generate concise, descriptive chat session names. Respond with a lowercase kebab-case identifier (letters, digits, hyphens only, max 6 words)."},
			{Role: "user", Content: fmt.Sprintf("Conversation excerpt:\n%s\nProvide only the session name.", builder.String())},
		},
		Temperature: &temp,
	}
	resp, err := s.client.CreateChatCompletion(ctx, req)
	if err != nil {
		return "", err
	}
	if len(resp.Choices) == 0 {
		return "", errors.New("no session name suggestion returned")
	}
	suggestion := NormalizeSessionName(resp.Choices[0].Message.Content)
	if suggestion == "" {
		return "", errors.New("empty session name suggestion")
	}
	return suggestion, nil
}

// NormalizeSessionName converts arbitrary text into a kebab-case identifier.
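// The result is capped at six hyphen-separated words and roughly 60 characters.
// For example (illustrative), "  Plan Q3 Roadmap! " becomes "plan-q3-roadmap".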
func NormalizeSessionName(name string) string {
	trimmed := strings.TrimSpace(name)
	if trimmed == "" {
		return ""
	}
	trimmed = strings.ToLower(trimmed)
	var builder strings.Builder
	lastHyphen := false
	for _, r := range trimmed {
		if unicode.IsLetter(r) || unicode.IsDigit(r) {
			builder.WriteRune(r)
			lastHyphen = false
		} else if r == ' ' || r == '-' || r == '_' || r == '.' || r == '/' || r == '\\' {
			// Collapse runs of separators into a single hyphen; other characters are dropped.
			if !lastHyphen && builder.Len() > 0 {
				builder.WriteRune('-')
				lastHyphen = true
			}
		}
		if builder.Len() >= 80 {
			break
		}
	}
	slug := strings.Trim(builder.String(), "-")
	if slug == "" {
		return ""
	}
	// Keep at most six non-empty words.
	parts := strings.Split(slug, "-")
	filtered := make([]string, 0, len(parts))
	for _, part := range parts {
		if part == "" {
			continue
		}
		filtered = append(filtered, part)
		if len(filtered) >= 6 {
			break
		}
	}
	result := strings.Join(filtered, "-")
	// Enforce the overall length cap without leaving a dangling hyphen.
	if len(result) > 60 {
		runes := []rune(result)
		if len(runes) > 60 {
			runes = runes[:60]
		}
		result = strings.Trim(string(runes), "-")
	}
	return result
}

// isAutoGeneratedName reports whether the name carries the automatic session-name prefix.
func isAutoGeneratedName(name string) bool {
	return strings.HasPrefix(name, defaultSessionPrefix)
}
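
// A minimal sketch of a test double for CompletionClient (illustrative only; this type is
// not part of the package and the canned response is a placeholder):
//
//	type fakeClient struct {
//		resp *openai.ChatCompletionResponse
//	}
//
//	func (f *fakeClient) CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (*openai.ChatCompletionResponse, error) {
//		return f.resp, nil
//	}
//
//	func (f *fakeClient) StreamChatCompletion(ctx context.Context, req openai.ChatCompletionRequest, h openai.ChatCompletionStreamHandler) (*openai.ChatCompletionResponse, error) {
//		return f.resp, nil
//	}
//
// A nil MessageRepository already disables persistence, so tests that do not exercise
// storage can pass nil for the repo argument.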