diff --git a/internal/app/app.go b/internal/app/app.go index 0ad8a9e..0960dc0 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -13,6 +13,7 @@ import ( "github.com/stig/goaichat/internal/chat" "github.com/stig/goaichat/internal/config" "github.com/stig/goaichat/internal/openai" + "github.com/stig/goaichat/internal/storage" ) // App encapsulates the Goaichat application runtime wiring. @@ -21,6 +22,8 @@ type App struct { config *config.Config openAI *openai.Client chat *chat.Service + store *storage.Manager + repo *storage.Repository input io.Reader output io.Writer } @@ -71,6 +74,9 @@ func (a *App) Run(ctx context.Context) error { if err := a.initOpenAIClient(); err != nil { return err } + if err := a.initStorage(); err != nil { + return err + } if err := a.initChatService(); err != nil { return err } @@ -108,7 +114,7 @@ func (a *App) initChatService() error { return nil } - service, err := chat.NewService(a.logger.With("component", "chat"), a.config.Model, a.openAI) + service, err := chat.NewService(a.logger.With("component", "chat"), a.config.Model, a.openAI, a.repo) if err != nil { return err } @@ -117,6 +123,32 @@ func (a *App) initChatService() error { return nil } +func (a *App) initStorage() error { + if a.store != nil { + return nil + } + + manager, err := storage.NewManager(*a.config) + if err != nil { + return err + } + + db, err := manager.Open() + if err != nil { + return err + } + + repo, err := storage.NewRepository(db) + if err != nil { + _ = manager.Close() + return err + } + + a.store = manager + a.repo = repo + return nil +} + func (a *App) runCLILoop(ctx context.Context) error { scanner := bufio.NewScanner(a.input) @@ -163,11 +195,14 @@ func (a *App) runCLILoop(ctx context.Context) error { if _, err := fmt.Fprintf(a.output, "AI: %s\n", reply); err != nil { return err } + + if err := a.maybeSuggestSessionName(ctx); err != nil { + a.logger.WarnContext(ctx, "session name suggestion failed", "error", err) + } } } func (a *App) handleCommand(ctx context.Context, input string) (handled bool, exit bool, err error) { - _ = ctx trimmed := strings.TrimSpace(input) if trimmed == "" { return true, false, errors.New("no input provided") @@ -175,6 +210,86 @@ func (a *App) handleCommand(ctx context.Context, input string) (handled bool, ex if !strings.HasPrefix(trimmed, "/") { return false, false, nil } + if strings.HasPrefix(trimmed, "/rename") { + parts := strings.Fields(trimmed) + if len(parts) < 2 { + _, err := fmt.Fprintln(a.output, "Usage: /rename ") + return true, false, err + } + rawName := strings.Join(parts[1:], " ") + normalized := chat.NormalizeSessionName(rawName) + if normalized == "" { + _, err := fmt.Fprintln(a.output, "Session name cannot be empty.") + return true, false, err + } + + if a.repo != nil { + existing, fetchErr := a.repo.GetSessionByName(ctx, normalized) + if fetchErr != nil { + _, err := fmt.Fprintf(a.output, "Failed to verify name availability: %v\n", fetchErr) + return true, false, err + } + if existing != nil { + currentID := a.chat.CurrentSessionID() + if currentID == 0 || existing.ID != currentID { + _, err := fmt.Fprintf(a.output, "Session name %q is already in use.\n", normalized) + return true, false, err + } + } + } + + applied, setErr := a.chat.SetSessionName(ctx, rawName) + if setErr != nil { + _, err := fmt.Fprintf(a.output, "Failed to rename session: %v\n", setErr) + return true, false, err + } + + _, err := fmt.Fprintf(a.output, "Session renamed to %q.\n", applied) + return true, false, err + } + + if strings.HasPrefix(trimmed, "/open") { + parts := strings.Fields(trimmed) + if len(parts) != 2 { + _, err := fmt.Fprintln(a.output, "Usage: /open ") + return true, false, err + } + if a.repo == nil { + _, err := fmt.Fprintln(a.output, "Storage not initialised; cannot open sessions.") + return true, false, err + } + + session, fetchErr := a.repo.GetSessionByName(ctx, parts[1]) + if fetchErr != nil { + _, err := fmt.Fprintf(a.output, "Failed to fetch session: %v\n", fetchErr) + return true, false, err + } + if session == nil { + _, err := fmt.Fprintf(a.output, "Session %q not found.\n", parts[1]) + return true, false, err + } + + messages, msgErr := a.repo.GetMessages(ctx, session.ID) + if msgErr != nil { + _, err := fmt.Fprintf(a.output, "Failed to load messages: %v\n", msgErr) + return true, false, err + } + + chatMessages := make([]openai.ChatMessage, 0, len(messages)) + for _, message := range messages { + role := strings.TrimSpace(message.Role) + if role == "" { + continue + } + chatMessages = append(chatMessages, openai.ChatMessage{Role: role, Content: message.Content}) + } + + summaryPresent := session.Summary.Valid && strings.TrimSpace(session.Summary.String) != "" + a.chat.RestoreSession(session.ID, session.Name, chatMessages, summaryPresent) + + _, err := fmt.Fprintf(a.output, "Loaded session %q with %d messages.\n", session.Name, len(chatMessages)) + return true, false, err + } switch trimmed { case "/exit": @@ -183,11 +298,57 @@ func (a *App) handleCommand(ctx context.Context, input string) (handled bool, ex a.chat.Reset() _, err := fmt.Fprintln(a.output, "History cleared.") return true, false, err + case "/list": + if a.repo == nil { + _, err := fmt.Fprintln(a.output, "History commands unavailable (storage not initialised).") + return true, false, err + } + sessions, listErr := a.repo.ListSessions(ctx) + if listErr != nil { + _, err := fmt.Fprintf(a.output, "Failed to list sessions: %v\n", listErr) + return true, false, err + } + if len(sessions) == 0 { + _, err := fmt.Fprintln(a.output, "No saved sessions.") + return true, false, err + } + for _, session := range sessions { + summary := session.Summary.String + if summary == "" { + summary = "(no summary)" + } + if _, err := fmt.Fprintf(a.output, "- %s [%s]: %s\n", session.Name, session.ModelName, summary); err != nil { + return true, false, err + } + } + return true, false, nil case "/help": - _, err := fmt.Fprintln(a.output, "Commands: /exit, /reset, /help (more coming soon)") + _, err := fmt.Fprintln(a.output, "Commands: /exit, /reset, /list, /open , /rename , /help (more coming soon)") return true, false, err default: _, err := fmt.Fprintf(a.output, "Unknown command %q. Try /help.\n", trimmed) return true, false, err } } + +func (a *App) maybeSuggestSessionName(ctx context.Context) error { + if a.chat == nil { + return nil + } + if !a.chat.ShouldSuggestSessionName() { + return nil + } + + suggestion, err := a.chat.SuggestSessionName(ctx) + if err != nil { + a.chat.MarkSessionNameSuggested() + return err + } + + a.chat.MarkSessionNameSuggested() + if _, err := fmt.Fprintf(a.output, "Suggested session name: %s\nUse /rename %s to apply it now.\n", suggestion, suggestion); err != nil { + return err + } + + return nil +} diff --git a/internal/chat/service.go b/internal/chat/service.go index cfdb321..c4ce4b9 100644 --- a/internal/chat/service.go +++ b/internal/chat/service.go @@ -3,30 +3,49 @@ package chat import ( "context" "errors" + "fmt" "log/slog" "strings" + "time" + "unicode" "github.com/stig/goaichat/internal/config" "github.com/stig/goaichat/internal/openai" ) +const defaultSessionPrefix = "session-" + // CompletionClient defines the subset of the OpenAI client used by the chat service. type CompletionClient interface { CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (*openai.ChatCompletionResponse, error) } +// MessageRepository defines persistence hooks used by the chat service. +type MessageRepository interface { + CreateSession(ctx context.Context, name, model string) (int64, error) + AddMessage(ctx context.Context, sessionID int64, role, content string, tokens *int64) (int64, error) + UpdateSessionSummary(ctx context.Context, sessionID int64, summary string) error + UpdateSessionName(ctx context.Context, sessionID int64, name string) error +} + // Service coordinates chat requests with the OpenAI client and maintains session history. type Service struct { - logger *slog.Logger - client CompletionClient - model string - temperature float64 - stream bool - history []openai.ChatMessage + logger *slog.Logger + client CompletionClient + repo MessageRepository + model string + temperature float64 + stream bool + history []openai.ChatMessage + sessionID int64 + summarySet bool + sessionName string + sessionNamed bool + nameSuggested bool } // NewService constructs a Service from configuration and an OpenAI-compatible client. -func NewService(logger *slog.Logger, modelCfg config.ModelConfig, client CompletionClient) (*Service, error) { +func NewService(logger *slog.Logger, modelCfg config.ModelConfig, client CompletionClient, repo MessageRepository) (*Service, error) { if logger == nil { return nil, errors.New("logger cannot be nil") } @@ -40,6 +59,7 @@ func NewService(logger *slog.Logger, modelCfg config.ModelConfig, client Complet return &Service{ logger: logger, client: client, + repo: repo, model: modelCfg.Name, temperature: modelCfg.Temperature, stream: modelCfg.Stream, @@ -63,6 +83,9 @@ func (s *Service) Send(ctx context.Context, input string) (string, error) { userMsg := openai.ChatMessage{Role: "user", Content: content} s.history = append(s.history, userMsg) + if err := s.persistMessage(ctx, userMsg); err != nil { + s.logger.WarnContext(ctx, "failed to persist user message", "error", err) + } messages := append([]openai.ChatMessage(nil), s.history...) temperature := s.temperature @@ -80,13 +103,15 @@ func (s *Service) Send(ctx context.Context, input string) (string, error) { if err != nil { return "", err } - if len(resp.Choices) == 0 { return "", errors.New("no choices returned from completion") } reply := resp.Choices[0].Message s.history = append(s.history, reply) + if err := s.persistMessage(ctx, reply); err != nil { + s.logger.WarnContext(ctx, "failed to persist assistant message", "error", err) + } return reply.Content, nil } @@ -108,4 +133,257 @@ func (s *Service) Reset() { return } s.history = s.history[:0] + s.sessionID = 0 + s.summarySet = false + s.sessionName = "" + s.sessionNamed = false + s.nameSuggested = false +} + +// RestoreSession replaces in-memory history with persisted messages for an existing session. +func (s *Service) RestoreSession(sessionID int64, name string, messages []openai.ChatMessage, summaryPresent bool) { + if s == nil { + return + } + + s.history = s.history[:0] + s.history = append(s.history, messages...) + s.sessionID = sessionID + s.summarySet = summaryPresent + s.sessionName = name + s.sessionNamed = !isAutoGeneratedName(name) + s.nameSuggested = false +} + +// CurrentSessionID returns the active session identifier, if any. +func (s *Service) CurrentSessionID() int64 { + if s == nil { + return 0 + } + return s.sessionID +} + +// SessionName returns the current session name. +func (s *Service) SessionName() string { + if s == nil { + return "" + } + return s.sessionName +} + +// SessionNamed reports whether the current session has a user-friendly name applied. +func (s *Service) SessionNamed() bool { + if s == nil { + return false + } + return s.sessionNamed +} + +// ShouldSuggestSessionName indicates whether the service would benefit from a generated name. +func (s *Service) ShouldSuggestSessionName() bool { + if s == nil { + return false + } + return !s.sessionNamed && !s.nameSuggested +} + +// MarkSessionNameSuggested prevents additional automatic suggestions for the current session. +func (s *Service) MarkSessionNameSuggested() { + if s == nil { + return + } + s.nameSuggested = true +} + +// SetSessionName normalizes and persists a new name for the active session, returning the stored value. +func (s *Service) SetSessionName(ctx context.Context, name string) (string, error) { + if s == nil { + return "", errors.New("service is nil") + } + normalized := NormalizeSessionName(name) + if normalized == "" { + return "", errors.New("session name cannot be empty") + } + + if s.repo != nil { + if s.sessionID <= 0 { + if err := s.ensureSession(ctx); err != nil { + return "", err + } + } + if err := s.repo.UpdateSessionName(ctx, s.sessionID, normalized); err != nil { + return "", err + } + } + + s.sessionName = normalized + s.sessionNamed = !isAutoGeneratedName(normalized) + s.nameSuggested = true + return normalized, nil +} + +func (s *Service) ensureSession(ctx context.Context) error { + if s == nil || s.repo == nil { + return nil + } + if s.sessionID > 0 { + return nil + } + + name := fmt.Sprintf("%s%s", defaultSessionPrefix, time.Now().Format("20060102-150405")) + s.logger.DebugContext(ctx, "creating chat session", "name", name) + + id, err := s.repo.CreateSession(ctx, name, s.model) + if err != nil { + return err + } + + s.sessionID = id + s.summarySet = false + s.sessionName = name + s.sessionNamed = false + s.nameSuggested = false + return nil +} + +func (s *Service) persistMessage(ctx context.Context, msg openai.ChatMessage) error { + if s.repo == nil { + return nil + } + + if err := s.ensureSession(ctx); err != nil { + return err + } + + var tokens *int64 + if _, err := s.repo.AddMessage(ctx, s.sessionID, msg.Role, msg.Content, tokens); err != nil { + return err + } + + if msg.Role == "assistant" && !s.summarySet && strings.TrimSpace(msg.Content) != "" { + const maxSummaryLen = 120 + summary := msg.Content + runes := []rune(summary) + if len(runes) > maxSummaryLen { + summary = string(runes[:maxSummaryLen]) + "..." + } + if err := s.repo.UpdateSessionSummary(ctx, s.sessionID, summary); err != nil { + s.logger.WarnContext(ctx, "failed to update session summary", "error", err) + } else { + s.summarySet = true + } + } + + return nil +} + +// SuggestSessionName uses the backing LLM to propose a descriptive session name. +func (s *Service) SuggestSessionName(ctx context.Context) (string, error) { + if s == nil { + return "", errors.New("service is nil") + } + if s.client == nil { + return "", errors.New("completion client is unavailable") + } + if len(s.history) == 0 { + return "", errors.New("no conversation history available") + } + + start := 0 + if len(s.history) > 10 { + start = len(s.history) - 10 + } + + var builder strings.Builder + for i := start; i < len(s.history); i++ { + msg := s.history[i] + builder.WriteString(msg.Role) + builder.WriteString(": ") + builder.WriteString(msg.Content) + builder.WriteString("\n") + } + + temp := 0.3 + req := openai.ChatCompletionRequest{ + Model: s.model, + Messages: []openai.ChatMessage{ + {Role: "system", Content: "You generate concise, descriptive chat session names. Respond with a lowercase kebab-case identifier (letters, digits, hyphens only, max 6 words)."}, + {Role: "user", Content: fmt.Sprintf("Conversation excerpt:\n%s\nProvide only the session name.", builder.String())}, + }, + Temperature: &temp, + } + + resp, err := s.client.CreateChatCompletion(ctx, req) + if err != nil { + return "", err + } + if len(resp.Choices) == 0 { + return "", errors.New("no session name suggestion returned") + } + + suggestion := NormalizeSessionName(resp.Choices[0].Message.Content) + if suggestion == "" { + return "", errors.New("empty session name suggestion") + } + + return suggestion, nil +} + +// NormalizeSessionName converts arbitrary text into a kebab-case identifier. +func NormalizeSessionName(name string) string { + trimmed := strings.TrimSpace(name) + if trimmed == "" { + return "" + } + + trimmed = strings.ToLower(trimmed) + var builder strings.Builder + lastHyphen := false + + for _, r := range trimmed { + if unicode.IsLetter(r) || unicode.IsDigit(r) { + builder.WriteRune(r) + lastHyphen = false + } else if r == ' ' || r == '-' || r == '_' || r == '.' || r == '/' || r == '\\' { + if !lastHyphen && builder.Len() > 0 { + builder.WriteRune('-') + lastHyphen = true + } + } + if builder.Len() >= 80 { + break + } + } + + slug := strings.Trim(builder.String(), "-") + if slug == "" { + return "" + } + + parts := strings.Split(slug, "-") + filtered := make([]string, 0, len(parts)) + for _, part := range parts { + if part == "" { + continue + } + filtered = append(filtered, part) + if len(filtered) >= 6 { + break + } + } + + result := strings.Join(filtered, "-") + if len(result) > 60 { + runes := []rune(result) + if len(runes) > 60 { + runes = runes[:60] + } + result = strings.Trim(string(runes), "-") + } + + return result +} + +func isAutoGeneratedName(name string) bool { + return strings.HasPrefix(name, defaultSessionPrefix) } diff --git a/internal/storage/repository.go b/internal/storage/repository.go new file mode 100644 index 0000000..3d0bddc --- /dev/null +++ b/internal/storage/repository.go @@ -0,0 +1,204 @@ +package storage + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" +) + +// Session represents a persisted chat session. +type Session struct { + ID int64 + Name string + ModelName string + Summary sql.NullString + CreatedAt time.Time + UpdatedAt time.Time +} + +// GetSessionByName fetches a session by its unique name, returning nil when not found. +func (r *Repository) GetSessionByName(ctx context.Context, name string) (*Session, error) { + if r == nil { + return nil, errors.New("repository is nil") + } + if strings.TrimSpace(name) == "" { + return nil, errors.New("session name cannot be empty") + } + + row := r.db.QueryRowContext(ctx, `SELECT id, name, model_name, summary, created_at, updated_at FROM sessions WHERE name = ?`, name) + var session Session + if err := row.Scan(&session.ID, &session.Name, &session.ModelName, &session.Summary, &session.CreatedAt, &session.UpdatedAt); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("get session by name: %w", err) + } + + return &session, nil +} + +// Message represents a persisted chat message. +type Message struct { + ID int64 + SessionID int64 + Role string + Content string + TokenCount sql.NullInt64 + CreatedAt time.Time +} + +// Repository exposes CRUD helpers for sessions and messages. +type Repository struct { + db *sql.DB +} + +// NewRepository constructs a repository bound to the provided DB. +func NewRepository(db *sql.DB) (*Repository, error) { + if db == nil { + return nil, errors.New("db cannot be nil") + } + return &Repository{db: db}, nil +} + +// CreateSession inserts a new session record. +func (r *Repository) CreateSession(ctx context.Context, name, model string) (int64, error) { + if r == nil { + return 0, errors.New("repository is nil") + } + if name == "" { + return 0, errors.New("session name cannot be empty") + } + + query := `INSERT INTO sessions (name, model_name) VALUES (?, ?)` + res, err := r.db.ExecContext(ctx, query, name, model) + if err != nil { + return 0, fmt.Errorf("insert session: %w", err) + } + + id, err := res.LastInsertId() + if err != nil { + return 0, fmt.Errorf("fetch session id: %w", err) + } + return id, nil +} + +// AddMessage persists a message linked to a session. +func (r *Repository) AddMessage(ctx context.Context, sessionID int64, role, content string, tokens *int64) (int64, error) { + if r == nil { + return 0, errors.New("repository is nil") + } + if sessionID <= 0 { + return 0, errors.New("sessionID must be positive") + } + if role == "" || content == "" { + return 0, errors.New("role and content must be provided") + } + + var tokenVal interface{} + if tokens != nil { + tokenVal = *tokens + } else { + tokenVal = nil + } + + query := `INSERT INTO messages (session_id, role, content, token_count) VALUES (?, ?, ?, ?)` + res, err := r.db.ExecContext(ctx, query, sessionID, role, content, tokenVal) + if err != nil { + return 0, fmt.Errorf("insert message: %w", err) + } + + id, err := res.LastInsertId() + if err != nil { + return 0, fmt.Errorf("fetch message id: %w", err) + } + return id, nil +} + +// ListSessions returns all sessions ordered by creation time descending. +func (r *Repository) ListSessions(ctx context.Context) ([]Session, error) { + rows, err := r.db.QueryContext(ctx, `SELECT id, name, model_name, summary, created_at, updated_at FROM sessions ORDER BY created_at DESC`) + if err != nil { + return nil, fmt.Errorf("query sessions: %w", err) + } + defer rows.Close() + + var sessions []Session + for rows.Next() { + var s Session + if err := rows.Scan(&s.ID, &s.Name, &s.ModelName, &s.Summary, &s.CreatedAt, &s.UpdatedAt); err != nil { + return nil, fmt.Errorf("scan session: %w", err) + } + sessions = append(sessions, s) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate sessions: %w", err) + } + + return sessions, nil +} + +// GetMessages returns messages for a given session ordered by creation time. +func (r *Repository) GetMessages(ctx context.Context, sessionID int64) ([]Message, error) { + if sessionID <= 0 { + return nil, errors.New("sessionID must be positive") + } + rows, err := r.db.QueryContext(ctx, `SELECT id, session_id, role, content, token_count, created_at FROM messages WHERE session_id = ? ORDER BY created_at ASC`, sessionID) + if err != nil { + return nil, fmt.Errorf("query messages: %w", err) + } + defer rows.Close() + + var messages []Message + for rows.Next() { + var m Message + if err := rows.Scan(&m.ID, &m.SessionID, &m.Role, &m.Content, &m.TokenCount, &m.CreatedAt); err != nil { + return nil, fmt.Errorf("scan message: %w", err) + } + messages = append(messages, m) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate messages: %w", err) + } + + return messages, nil +} + +// UpdateSessionSummary updates the summary column for a session. +func (r *Repository) UpdateSessionSummary(ctx context.Context, sessionID int64, summary string) error { + if sessionID <= 0 { + return errors.New("sessionID must be positive") + } + + _, err := r.db.ExecContext(ctx, `UPDATE sessions SET summary = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, summary, sessionID) + if err != nil { + return fmt.Errorf("update session summary: %w", err) + } + + return nil +} + +// UpdateSessionName updates the name of an existing session. +func (r *Repository) UpdateSessionName(ctx context.Context, sessionID int64, name string) error { + if r == nil { + return errors.New("repository is nil") + } + if sessionID <= 0 { + return errors.New("sessionID must be positive") + } + trimmed := strings.TrimSpace(name) + if trimmed == "" { + return errors.New("session name cannot be empty") + } + + _, err := r.db.ExecContext(ctx, `UPDATE sessions SET name = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, trimmed, sessionID) + if err != nil { + return fmt.Errorf("update session name: %w", err) + } + + return nil +}