Add session naming workflow and rename command

This commit is contained in:
2025-10-01 16:32:39 +02:00
parent 14fb100dab
commit aaeb116ecb
3 changed files with 654 additions and 11 deletions

View File

@@ -13,6 +13,7 @@ import (
"github.com/stig/goaichat/internal/chat"
"github.com/stig/goaichat/internal/config"
"github.com/stig/goaichat/internal/openai"
"github.com/stig/goaichat/internal/storage"
)
// App encapsulates the Goaichat application runtime wiring.
@@ -21,6 +22,8 @@ type App struct {
config *config.Config
openAI *openai.Client
chat *chat.Service
store *storage.Manager
repo *storage.Repository
input io.Reader
output io.Writer
}
@@ -71,6 +74,9 @@ func (a *App) Run(ctx context.Context) error {
if err := a.initOpenAIClient(); err != nil {
return err
}
if err := a.initStorage(); err != nil {
return err
}
if err := a.initChatService(); err != nil {
return err
}
@@ -108,7 +114,7 @@ func (a *App) initChatService() error {
return nil
}
service, err := chat.NewService(a.logger.With("component", "chat"), a.config.Model, a.openAI)
service, err := chat.NewService(a.logger.With("component", "chat"), a.config.Model, a.openAI, a.repo)
if err != nil {
return err
}
@@ -117,6 +123,32 @@ func (a *App) initChatService() error {
return nil
}
func (a *App) initStorage() error {
if a.store != nil {
return nil
}
manager, err := storage.NewManager(*a.config)
if err != nil {
return err
}
db, err := manager.Open()
if err != nil {
return err
}
repo, err := storage.NewRepository(db)
if err != nil {
_ = manager.Close()
return err
}
a.store = manager
a.repo = repo
return nil
}
func (a *App) runCLILoop(ctx context.Context) error {
scanner := bufio.NewScanner(a.input)
@@ -163,11 +195,14 @@ func (a *App) runCLILoop(ctx context.Context) error {
if _, err := fmt.Fprintf(a.output, "AI: %s\n", reply); err != nil {
return err
}
if err := a.maybeSuggestSessionName(ctx); err != nil {
a.logger.WarnContext(ctx, "session name suggestion failed", "error", err)
}
}
}
func (a *App) handleCommand(ctx context.Context, input string) (handled bool, exit bool, err error) {
_ = ctx
trimmed := strings.TrimSpace(input)
if trimmed == "" {
return true, false, errors.New("no input provided")
@@ -175,6 +210,86 @@ func (a *App) handleCommand(ctx context.Context, input string) (handled bool, ex
if !strings.HasPrefix(trimmed, "/") {
return false, false, nil
}
if strings.HasPrefix(trimmed, "/rename") {
parts := strings.Fields(trimmed)
if len(parts) < 2 {
_, err := fmt.Fprintln(a.output, "Usage: /rename <session-name>")
return true, false, err
}
rawName := strings.Join(parts[1:], " ")
normalized := chat.NormalizeSessionName(rawName)
if normalized == "" {
_, err := fmt.Fprintln(a.output, "Session name cannot be empty.")
return true, false, err
}
if a.repo != nil {
existing, fetchErr := a.repo.GetSessionByName(ctx, normalized)
if fetchErr != nil {
_, err := fmt.Fprintf(a.output, "Failed to verify name availability: %v\n", fetchErr)
return true, false, err
}
if existing != nil {
currentID := a.chat.CurrentSessionID()
if currentID == 0 || existing.ID != currentID {
_, err := fmt.Fprintf(a.output, "Session name %q is already in use.\n", normalized)
return true, false, err
}
}
}
applied, setErr := a.chat.SetSessionName(ctx, rawName)
if setErr != nil {
_, err := fmt.Fprintf(a.output, "Failed to rename session: %v\n", setErr)
return true, false, err
}
_, err := fmt.Fprintf(a.output, "Session renamed to %q.\n", applied)
return true, false, err
}
if strings.HasPrefix(trimmed, "/open") {
parts := strings.Fields(trimmed)
if len(parts) != 2 {
_, err := fmt.Fprintln(a.output, "Usage: /open <session-name>")
return true, false, err
}
if a.repo == nil {
_, err := fmt.Fprintln(a.output, "Storage not initialised; cannot open sessions.")
return true, false, err
}
session, fetchErr := a.repo.GetSessionByName(ctx, parts[1])
if fetchErr != nil {
_, err := fmt.Fprintf(a.output, "Failed to fetch session: %v\n", fetchErr)
return true, false, err
}
if session == nil {
_, err := fmt.Fprintf(a.output, "Session %q not found.\n", parts[1])
return true, false, err
}
messages, msgErr := a.repo.GetMessages(ctx, session.ID)
if msgErr != nil {
_, err := fmt.Fprintf(a.output, "Failed to load messages: %v\n", msgErr)
return true, false, err
}
chatMessages := make([]openai.ChatMessage, 0, len(messages))
for _, message := range messages {
role := strings.TrimSpace(message.Role)
if role == "" {
continue
}
chatMessages = append(chatMessages, openai.ChatMessage{Role: role, Content: message.Content})
}
summaryPresent := session.Summary.Valid && strings.TrimSpace(session.Summary.String) != ""
a.chat.RestoreSession(session.ID, session.Name, chatMessages, summaryPresent)
_, err := fmt.Fprintf(a.output, "Loaded session %q with %d messages.\n", session.Name, len(chatMessages))
return true, false, err
}
switch trimmed {
case "/exit":
@@ -183,11 +298,57 @@ func (a *App) handleCommand(ctx context.Context, input string) (handled bool, ex
a.chat.Reset()
_, err := fmt.Fprintln(a.output, "History cleared.")
return true, false, err
case "/list":
if a.repo == nil {
_, err := fmt.Fprintln(a.output, "History commands unavailable (storage not initialised).")
return true, false, err
}
sessions, listErr := a.repo.ListSessions(ctx)
if listErr != nil {
_, err := fmt.Fprintf(a.output, "Failed to list sessions: %v\n", listErr)
return true, false, err
}
if len(sessions) == 0 {
_, err := fmt.Fprintln(a.output, "No saved sessions.")
return true, false, err
}
for _, session := range sessions {
summary := session.Summary.String
if summary == "" {
summary = "(no summary)"
}
if _, err := fmt.Fprintf(a.output, "- %s [%s]: %s\n", session.Name, session.ModelName, summary); err != nil {
return true, false, err
}
}
return true, false, nil
case "/help":
_, err := fmt.Fprintln(a.output, "Commands: /exit, /reset, /help (more coming soon)")
_, err := fmt.Fprintln(a.output, "Commands: /exit, /reset, /list, /open <name>, /rename <name>, /help (more coming soon)")
return true, false, err
default:
_, err := fmt.Fprintf(a.output, "Unknown command %q. Try /help.\n", trimmed)
return true, false, err
}
}
func (a *App) maybeSuggestSessionName(ctx context.Context) error {
if a.chat == nil {
return nil
}
if !a.chat.ShouldSuggestSessionName() {
return nil
}
suggestion, err := a.chat.SuggestSessionName(ctx)
if err != nil {
a.chat.MarkSessionNameSuggested()
return err
}
a.chat.MarkSessionNameSuggested()
if _, err := fmt.Fprintf(a.output, "Suggested session name: %s\nUse /rename %s to apply it now.\n", suggestion, suggestion); err != nil {
return err
}
return nil
}

View File

@@ -3,30 +3,49 @@ package chat
import (
"context"
"errors"
"fmt"
"log/slog"
"strings"
"time"
"unicode"
"github.com/stig/goaichat/internal/config"
"github.com/stig/goaichat/internal/openai"
)
const defaultSessionPrefix = "session-"
// CompletionClient defines the subset of the OpenAI client used by the chat service.
type CompletionClient interface {
CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (*openai.ChatCompletionResponse, error)
}
// MessageRepository defines persistence hooks used by the chat service.
type MessageRepository interface {
CreateSession(ctx context.Context, name, model string) (int64, error)
AddMessage(ctx context.Context, sessionID int64, role, content string, tokens *int64) (int64, error)
UpdateSessionSummary(ctx context.Context, sessionID int64, summary string) error
UpdateSessionName(ctx context.Context, sessionID int64, name string) error
}
// Service coordinates chat requests with the OpenAI client and maintains session history.
type Service struct {
logger *slog.Logger
client CompletionClient
model string
temperature float64
stream bool
history []openai.ChatMessage
logger *slog.Logger
client CompletionClient
repo MessageRepository
model string
temperature float64
stream bool
history []openai.ChatMessage
sessionID int64
summarySet bool
sessionName string
sessionNamed bool
nameSuggested bool
}
// NewService constructs a Service from configuration and an OpenAI-compatible client.
func NewService(logger *slog.Logger, modelCfg config.ModelConfig, client CompletionClient) (*Service, error) {
func NewService(logger *slog.Logger, modelCfg config.ModelConfig, client CompletionClient, repo MessageRepository) (*Service, error) {
if logger == nil {
return nil, errors.New("logger cannot be nil")
}
@@ -40,6 +59,7 @@ func NewService(logger *slog.Logger, modelCfg config.ModelConfig, client Complet
return &Service{
logger: logger,
client: client,
repo: repo,
model: modelCfg.Name,
temperature: modelCfg.Temperature,
stream: modelCfg.Stream,
@@ -63,6 +83,9 @@ func (s *Service) Send(ctx context.Context, input string) (string, error) {
userMsg := openai.ChatMessage{Role: "user", Content: content}
s.history = append(s.history, userMsg)
if err := s.persistMessage(ctx, userMsg); err != nil {
s.logger.WarnContext(ctx, "failed to persist user message", "error", err)
}
messages := append([]openai.ChatMessage(nil), s.history...)
temperature := s.temperature
@@ -80,13 +103,15 @@ func (s *Service) Send(ctx context.Context, input string) (string, error) {
if err != nil {
return "", err
}
if len(resp.Choices) == 0 {
return "", errors.New("no choices returned from completion")
}
reply := resp.Choices[0].Message
s.history = append(s.history, reply)
if err := s.persistMessage(ctx, reply); err != nil {
s.logger.WarnContext(ctx, "failed to persist assistant message", "error", err)
}
return reply.Content, nil
}
@@ -108,4 +133,257 @@ func (s *Service) Reset() {
return
}
s.history = s.history[:0]
s.sessionID = 0
s.summarySet = false
s.sessionName = ""
s.sessionNamed = false
s.nameSuggested = false
}
// RestoreSession replaces in-memory history with persisted messages for an existing session.
func (s *Service) RestoreSession(sessionID int64, name string, messages []openai.ChatMessage, summaryPresent bool) {
if s == nil {
return
}
s.history = s.history[:0]
s.history = append(s.history, messages...)
s.sessionID = sessionID
s.summarySet = summaryPresent
s.sessionName = name
s.sessionNamed = !isAutoGeneratedName(name)
s.nameSuggested = false
}
// CurrentSessionID returns the active session identifier, if any.
func (s *Service) CurrentSessionID() int64 {
if s == nil {
return 0
}
return s.sessionID
}
// SessionName returns the current session name.
func (s *Service) SessionName() string {
if s == nil {
return ""
}
return s.sessionName
}
// SessionNamed reports whether the current session has a user-friendly name applied.
func (s *Service) SessionNamed() bool {
if s == nil {
return false
}
return s.sessionNamed
}
// ShouldSuggestSessionName indicates whether the service would benefit from a generated name.
func (s *Service) ShouldSuggestSessionName() bool {
if s == nil {
return false
}
return !s.sessionNamed && !s.nameSuggested
}
// MarkSessionNameSuggested prevents additional automatic suggestions for the current session.
func (s *Service) MarkSessionNameSuggested() {
if s == nil {
return
}
s.nameSuggested = true
}
// SetSessionName normalizes and persists a new name for the active session, returning the stored value.
func (s *Service) SetSessionName(ctx context.Context, name string) (string, error) {
if s == nil {
return "", errors.New("service is nil")
}
normalized := NormalizeSessionName(name)
if normalized == "" {
return "", errors.New("session name cannot be empty")
}
if s.repo != nil {
if s.sessionID <= 0 {
if err := s.ensureSession(ctx); err != nil {
return "", err
}
}
if err := s.repo.UpdateSessionName(ctx, s.sessionID, normalized); err != nil {
return "", err
}
}
s.sessionName = normalized
s.sessionNamed = !isAutoGeneratedName(normalized)
s.nameSuggested = true
return normalized, nil
}
func (s *Service) ensureSession(ctx context.Context) error {
if s == nil || s.repo == nil {
return nil
}
if s.sessionID > 0 {
return nil
}
name := fmt.Sprintf("%s%s", defaultSessionPrefix, time.Now().Format("20060102-150405"))
s.logger.DebugContext(ctx, "creating chat session", "name", name)
id, err := s.repo.CreateSession(ctx, name, s.model)
if err != nil {
return err
}
s.sessionID = id
s.summarySet = false
s.sessionName = name
s.sessionNamed = false
s.nameSuggested = false
return nil
}
func (s *Service) persistMessage(ctx context.Context, msg openai.ChatMessage) error {
if s.repo == nil {
return nil
}
if err := s.ensureSession(ctx); err != nil {
return err
}
var tokens *int64
if _, err := s.repo.AddMessage(ctx, s.sessionID, msg.Role, msg.Content, tokens); err != nil {
return err
}
if msg.Role == "assistant" && !s.summarySet && strings.TrimSpace(msg.Content) != "" {
const maxSummaryLen = 120
summary := msg.Content
runes := []rune(summary)
if len(runes) > maxSummaryLen {
summary = string(runes[:maxSummaryLen]) + "..."
}
if err := s.repo.UpdateSessionSummary(ctx, s.sessionID, summary); err != nil {
s.logger.WarnContext(ctx, "failed to update session summary", "error", err)
} else {
s.summarySet = true
}
}
return nil
}
// SuggestSessionName uses the backing LLM to propose a descriptive session name.
func (s *Service) SuggestSessionName(ctx context.Context) (string, error) {
if s == nil {
return "", errors.New("service is nil")
}
if s.client == nil {
return "", errors.New("completion client is unavailable")
}
if len(s.history) == 0 {
return "", errors.New("no conversation history available")
}
start := 0
if len(s.history) > 10 {
start = len(s.history) - 10
}
var builder strings.Builder
for i := start; i < len(s.history); i++ {
msg := s.history[i]
builder.WriteString(msg.Role)
builder.WriteString(": ")
builder.WriteString(msg.Content)
builder.WriteString("\n")
}
temp := 0.3
req := openai.ChatCompletionRequest{
Model: s.model,
Messages: []openai.ChatMessage{
{Role: "system", Content: "You generate concise, descriptive chat session names. Respond with a lowercase kebab-case identifier (letters, digits, hyphens only, max 6 words)."},
{Role: "user", Content: fmt.Sprintf("Conversation excerpt:\n%s\nProvide only the session name.", builder.String())},
},
Temperature: &temp,
}
resp, err := s.client.CreateChatCompletion(ctx, req)
if err != nil {
return "", err
}
if len(resp.Choices) == 0 {
return "", errors.New("no session name suggestion returned")
}
suggestion := NormalizeSessionName(resp.Choices[0].Message.Content)
if suggestion == "" {
return "", errors.New("empty session name suggestion")
}
return suggestion, nil
}
// NormalizeSessionName converts arbitrary text into a kebab-case identifier.
func NormalizeSessionName(name string) string {
trimmed := strings.TrimSpace(name)
if trimmed == "" {
return ""
}
trimmed = strings.ToLower(trimmed)
var builder strings.Builder
lastHyphen := false
for _, r := range trimmed {
if unicode.IsLetter(r) || unicode.IsDigit(r) {
builder.WriteRune(r)
lastHyphen = false
} else if r == ' ' || r == '-' || r == '_' || r == '.' || r == '/' || r == '\\' {
if !lastHyphen && builder.Len() > 0 {
builder.WriteRune('-')
lastHyphen = true
}
}
if builder.Len() >= 80 {
break
}
}
slug := strings.Trim(builder.String(), "-")
if slug == "" {
return ""
}
parts := strings.Split(slug, "-")
filtered := make([]string, 0, len(parts))
for _, part := range parts {
if part == "" {
continue
}
filtered = append(filtered, part)
if len(filtered) >= 6 {
break
}
}
result := strings.Join(filtered, "-")
if len(result) > 60 {
runes := []rune(result)
if len(runes) > 60 {
runes = runes[:60]
}
result = strings.Trim(string(runes), "-")
}
return result
}
func isAutoGeneratedName(name string) bool {
return strings.HasPrefix(name, defaultSessionPrefix)
}

View File

@@ -0,0 +1,204 @@
package storage
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
)
// Session represents a persisted chat session.
type Session struct {
ID int64
Name string
ModelName string
Summary sql.NullString
CreatedAt time.Time
UpdatedAt time.Time
}
// GetSessionByName fetches a session by its unique name, returning nil when not found.
func (r *Repository) GetSessionByName(ctx context.Context, name string) (*Session, error) {
if r == nil {
return nil, errors.New("repository is nil")
}
if strings.TrimSpace(name) == "" {
return nil, errors.New("session name cannot be empty")
}
row := r.db.QueryRowContext(ctx, `SELECT id, name, model_name, summary, created_at, updated_at FROM sessions WHERE name = ?`, name)
var session Session
if err := row.Scan(&session.ID, &session.Name, &session.ModelName, &session.Summary, &session.CreatedAt, &session.UpdatedAt); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, fmt.Errorf("get session by name: %w", err)
}
return &session, nil
}
// Message represents a persisted chat message.
type Message struct {
ID int64
SessionID int64
Role string
Content string
TokenCount sql.NullInt64
CreatedAt time.Time
}
// Repository exposes CRUD helpers for sessions and messages.
type Repository struct {
db *sql.DB
}
// NewRepository constructs a repository bound to the provided DB.
func NewRepository(db *sql.DB) (*Repository, error) {
if db == nil {
return nil, errors.New("db cannot be nil")
}
return &Repository{db: db}, nil
}
// CreateSession inserts a new session record.
func (r *Repository) CreateSession(ctx context.Context, name, model string) (int64, error) {
if r == nil {
return 0, errors.New("repository is nil")
}
if name == "" {
return 0, errors.New("session name cannot be empty")
}
query := `INSERT INTO sessions (name, model_name) VALUES (?, ?)`
res, err := r.db.ExecContext(ctx, query, name, model)
if err != nil {
return 0, fmt.Errorf("insert session: %w", err)
}
id, err := res.LastInsertId()
if err != nil {
return 0, fmt.Errorf("fetch session id: %w", err)
}
return id, nil
}
// AddMessage persists a message linked to a session.
func (r *Repository) AddMessage(ctx context.Context, sessionID int64, role, content string, tokens *int64) (int64, error) {
if r == nil {
return 0, errors.New("repository is nil")
}
if sessionID <= 0 {
return 0, errors.New("sessionID must be positive")
}
if role == "" || content == "" {
return 0, errors.New("role and content must be provided")
}
var tokenVal interface{}
if tokens != nil {
tokenVal = *tokens
} else {
tokenVal = nil
}
query := `INSERT INTO messages (session_id, role, content, token_count) VALUES (?, ?, ?, ?)`
res, err := r.db.ExecContext(ctx, query, sessionID, role, content, tokenVal)
if err != nil {
return 0, fmt.Errorf("insert message: %w", err)
}
id, err := res.LastInsertId()
if err != nil {
return 0, fmt.Errorf("fetch message id: %w", err)
}
return id, nil
}
// ListSessions returns all sessions ordered by creation time descending.
func (r *Repository) ListSessions(ctx context.Context) ([]Session, error) {
rows, err := r.db.QueryContext(ctx, `SELECT id, name, model_name, summary, created_at, updated_at FROM sessions ORDER BY created_at DESC`)
if err != nil {
return nil, fmt.Errorf("query sessions: %w", err)
}
defer rows.Close()
var sessions []Session
for rows.Next() {
var s Session
if err := rows.Scan(&s.ID, &s.Name, &s.ModelName, &s.Summary, &s.CreatedAt, &s.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan session: %w", err)
}
sessions = append(sessions, s)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate sessions: %w", err)
}
return sessions, nil
}
// GetMessages returns messages for a given session ordered by creation time.
func (r *Repository) GetMessages(ctx context.Context, sessionID int64) ([]Message, error) {
if sessionID <= 0 {
return nil, errors.New("sessionID must be positive")
}
rows, err := r.db.QueryContext(ctx, `SELECT id, session_id, role, content, token_count, created_at FROM messages WHERE session_id = ? ORDER BY created_at ASC`, sessionID)
if err != nil {
return nil, fmt.Errorf("query messages: %w", err)
}
defer rows.Close()
var messages []Message
for rows.Next() {
var m Message
if err := rows.Scan(&m.ID, &m.SessionID, &m.Role, &m.Content, &m.TokenCount, &m.CreatedAt); err != nil {
return nil, fmt.Errorf("scan message: %w", err)
}
messages = append(messages, m)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate messages: %w", err)
}
return messages, nil
}
// UpdateSessionSummary updates the summary column for a session.
func (r *Repository) UpdateSessionSummary(ctx context.Context, sessionID int64, summary string) error {
if sessionID <= 0 {
return errors.New("sessionID must be positive")
}
_, err := r.db.ExecContext(ctx, `UPDATE sessions SET summary = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, summary, sessionID)
if err != nil {
return fmt.Errorf("update session summary: %w", err)
}
return nil
}
// UpdateSessionName updates the name of an existing session.
func (r *Repository) UpdateSessionName(ctx context.Context, sessionID int64, name string) error {
if r == nil {
return errors.New("repository is nil")
}
if sessionID <= 0 {
return errors.New("sessionID must be positive")
}
trimmed := strings.TrimSpace(name)
if trimmed == "" {
return errors.New("session name cannot be empty")
}
_, err := r.db.ExecContext(ctx, `UPDATE sessions SET name = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, trimmed, sessionID)
if err != nil {
return fmt.Errorf("update session name: %w", err)
}
return nil
}