Add session naming workflow and rename command
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/stig/goaichat/internal/chat"
|
"github.com/stig/goaichat/internal/chat"
|
||||||
"github.com/stig/goaichat/internal/config"
|
"github.com/stig/goaichat/internal/config"
|
||||||
"github.com/stig/goaichat/internal/openai"
|
"github.com/stig/goaichat/internal/openai"
|
||||||
|
"github.com/stig/goaichat/internal/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
// App encapsulates the Goaichat application runtime wiring.
|
// App encapsulates the Goaichat application runtime wiring.
|
||||||
@@ -21,6 +22,8 @@ type App struct {
|
|||||||
config *config.Config
|
config *config.Config
|
||||||
openAI *openai.Client
|
openAI *openai.Client
|
||||||
chat *chat.Service
|
chat *chat.Service
|
||||||
|
store *storage.Manager
|
||||||
|
repo *storage.Repository
|
||||||
input io.Reader
|
input io.Reader
|
||||||
output io.Writer
|
output io.Writer
|
||||||
}
|
}
|
||||||
@@ -71,6 +74,9 @@ func (a *App) Run(ctx context.Context) error {
|
|||||||
if err := a.initOpenAIClient(); err != nil {
|
if err := a.initOpenAIClient(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if err := a.initStorage(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if err := a.initChatService(); err != nil {
|
if err := a.initChatService(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -108,7 +114,7 @@ func (a *App) initChatService() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
service, err := chat.NewService(a.logger.With("component", "chat"), a.config.Model, a.openAI)
|
service, err := chat.NewService(a.logger.With("component", "chat"), a.config.Model, a.openAI, a.repo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -117,6 +123,32 @@ func (a *App) initChatService() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *App) initStorage() error {
|
||||||
|
if a.store != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := storage.NewManager(*a.config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := manager.Open()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
repo, err := storage.NewRepository(db)
|
||||||
|
if err != nil {
|
||||||
|
_ = manager.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
a.store = manager
|
||||||
|
a.repo = repo
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *App) runCLILoop(ctx context.Context) error {
|
func (a *App) runCLILoop(ctx context.Context) error {
|
||||||
scanner := bufio.NewScanner(a.input)
|
scanner := bufio.NewScanner(a.input)
|
||||||
|
|
||||||
@@ -163,11 +195,14 @@ func (a *App) runCLILoop(ctx context.Context) error {
|
|||||||
if _, err := fmt.Fprintf(a.output, "AI: %s\n", reply); err != nil {
|
if _, err := fmt.Fprintf(a.output, "AI: %s\n", reply); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := a.maybeSuggestSessionName(ctx); err != nil {
|
||||||
|
a.logger.WarnContext(ctx, "session name suggestion failed", "error", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) handleCommand(ctx context.Context, input string) (handled bool, exit bool, err error) {
|
func (a *App) handleCommand(ctx context.Context, input string) (handled bool, exit bool, err error) {
|
||||||
_ = ctx
|
|
||||||
trimmed := strings.TrimSpace(input)
|
trimmed := strings.TrimSpace(input)
|
||||||
if trimmed == "" {
|
if trimmed == "" {
|
||||||
return true, false, errors.New("no input provided")
|
return true, false, errors.New("no input provided")
|
||||||
@@ -175,6 +210,86 @@ func (a *App) handleCommand(ctx context.Context, input string) (handled bool, ex
|
|||||||
if !strings.HasPrefix(trimmed, "/") {
|
if !strings.HasPrefix(trimmed, "/") {
|
||||||
return false, false, nil
|
return false, false, nil
|
||||||
}
|
}
|
||||||
|
if strings.HasPrefix(trimmed, "/rename") {
|
||||||
|
parts := strings.Fields(trimmed)
|
||||||
|
if len(parts) < 2 {
|
||||||
|
_, err := fmt.Fprintln(a.output, "Usage: /rename <session-name>")
|
||||||
|
return true, false, err
|
||||||
|
}
|
||||||
|
rawName := strings.Join(parts[1:], " ")
|
||||||
|
normalized := chat.NormalizeSessionName(rawName)
|
||||||
|
if normalized == "" {
|
||||||
|
_, err := fmt.Fprintln(a.output, "Session name cannot be empty.")
|
||||||
|
return true, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.repo != nil {
|
||||||
|
existing, fetchErr := a.repo.GetSessionByName(ctx, normalized)
|
||||||
|
if fetchErr != nil {
|
||||||
|
_, err := fmt.Fprintf(a.output, "Failed to verify name availability: %v\n", fetchErr)
|
||||||
|
return true, false, err
|
||||||
|
}
|
||||||
|
if existing != nil {
|
||||||
|
currentID := a.chat.CurrentSessionID()
|
||||||
|
if currentID == 0 || existing.ID != currentID {
|
||||||
|
_, err := fmt.Fprintf(a.output, "Session name %q is already in use.\n", normalized)
|
||||||
|
return true, false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
applied, setErr := a.chat.SetSessionName(ctx, rawName)
|
||||||
|
if setErr != nil {
|
||||||
|
_, err := fmt.Fprintf(a.output, "Failed to rename session: %v\n", setErr)
|
||||||
|
return true, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := fmt.Fprintf(a.output, "Session renamed to %q.\n", applied)
|
||||||
|
return true, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(trimmed, "/open") {
|
||||||
|
parts := strings.Fields(trimmed)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
_, err := fmt.Fprintln(a.output, "Usage: /open <session-name>")
|
||||||
|
return true, false, err
|
||||||
|
}
|
||||||
|
if a.repo == nil {
|
||||||
|
_, err := fmt.Fprintln(a.output, "Storage not initialised; cannot open sessions.")
|
||||||
|
return true, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
session, fetchErr := a.repo.GetSessionByName(ctx, parts[1])
|
||||||
|
if fetchErr != nil {
|
||||||
|
_, err := fmt.Fprintf(a.output, "Failed to fetch session: %v\n", fetchErr)
|
||||||
|
return true, false, err
|
||||||
|
}
|
||||||
|
if session == nil {
|
||||||
|
_, err := fmt.Fprintf(a.output, "Session %q not found.\n", parts[1])
|
||||||
|
return true, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
messages, msgErr := a.repo.GetMessages(ctx, session.ID)
|
||||||
|
if msgErr != nil {
|
||||||
|
_, err := fmt.Fprintf(a.output, "Failed to load messages: %v\n", msgErr)
|
||||||
|
return true, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
chatMessages := make([]openai.ChatMessage, 0, len(messages))
|
||||||
|
for _, message := range messages {
|
||||||
|
role := strings.TrimSpace(message.Role)
|
||||||
|
if role == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
chatMessages = append(chatMessages, openai.ChatMessage{Role: role, Content: message.Content})
|
||||||
|
}
|
||||||
|
|
||||||
|
summaryPresent := session.Summary.Valid && strings.TrimSpace(session.Summary.String) != ""
|
||||||
|
a.chat.RestoreSession(session.ID, session.Name, chatMessages, summaryPresent)
|
||||||
|
|
||||||
|
_, err := fmt.Fprintf(a.output, "Loaded session %q with %d messages.\n", session.Name, len(chatMessages))
|
||||||
|
return true, false, err
|
||||||
|
}
|
||||||
|
|
||||||
switch trimmed {
|
switch trimmed {
|
||||||
case "/exit":
|
case "/exit":
|
||||||
@@ -183,11 +298,57 @@ func (a *App) handleCommand(ctx context.Context, input string) (handled bool, ex
|
|||||||
a.chat.Reset()
|
a.chat.Reset()
|
||||||
_, err := fmt.Fprintln(a.output, "History cleared.")
|
_, err := fmt.Fprintln(a.output, "History cleared.")
|
||||||
return true, false, err
|
return true, false, err
|
||||||
|
case "/list":
|
||||||
|
if a.repo == nil {
|
||||||
|
_, err := fmt.Fprintln(a.output, "History commands unavailable (storage not initialised).")
|
||||||
|
return true, false, err
|
||||||
|
}
|
||||||
|
sessions, listErr := a.repo.ListSessions(ctx)
|
||||||
|
if listErr != nil {
|
||||||
|
_, err := fmt.Fprintf(a.output, "Failed to list sessions: %v\n", listErr)
|
||||||
|
return true, false, err
|
||||||
|
}
|
||||||
|
if len(sessions) == 0 {
|
||||||
|
_, err := fmt.Fprintln(a.output, "No saved sessions.")
|
||||||
|
return true, false, err
|
||||||
|
}
|
||||||
|
for _, session := range sessions {
|
||||||
|
summary := session.Summary.String
|
||||||
|
if summary == "" {
|
||||||
|
summary = "(no summary)"
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprintf(a.output, "- %s [%s]: %s\n", session.Name, session.ModelName, summary); err != nil {
|
||||||
|
return true, false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true, false, nil
|
||||||
case "/help":
|
case "/help":
|
||||||
_, err := fmt.Fprintln(a.output, "Commands: /exit, /reset, /help (more coming soon)")
|
_, err := fmt.Fprintln(a.output, "Commands: /exit, /reset, /list, /open <name>, /rename <name>, /help (more coming soon)")
|
||||||
return true, false, err
|
return true, false, err
|
||||||
default:
|
default:
|
||||||
_, err := fmt.Fprintf(a.output, "Unknown command %q. Try /help.\n", trimmed)
|
_, err := fmt.Fprintf(a.output, "Unknown command %q. Try /help.\n", trimmed)
|
||||||
return true, false, err
|
return true, false, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *App) maybeSuggestSessionName(ctx context.Context) error {
|
||||||
|
if a.chat == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !a.chat.ShouldSuggestSessionName() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
suggestion, err := a.chat.SuggestSessionName(ctx)
|
||||||
|
if err != nil {
|
||||||
|
a.chat.MarkSessionNameSuggested()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
a.chat.MarkSessionNameSuggested()
|
||||||
|
if _, err := fmt.Fprintf(a.output, "Suggested session name: %s\nUse /rename %s to apply it now.\n", suggestion, suggestion); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@@ -3,30 +3,49 @@ package chat
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
"github.com/stig/goaichat/internal/config"
|
"github.com/stig/goaichat/internal/config"
|
||||||
"github.com/stig/goaichat/internal/openai"
|
"github.com/stig/goaichat/internal/openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const defaultSessionPrefix = "session-"
|
||||||
|
|
||||||
// CompletionClient defines the subset of the OpenAI client used by the chat service.
|
// CompletionClient defines the subset of the OpenAI client used by the chat service.
|
||||||
type CompletionClient interface {
|
type CompletionClient interface {
|
||||||
CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (*openai.ChatCompletionResponse, error)
|
CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (*openai.ChatCompletionResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MessageRepository defines persistence hooks used by the chat service.
|
||||||
|
type MessageRepository interface {
|
||||||
|
CreateSession(ctx context.Context, name, model string) (int64, error)
|
||||||
|
AddMessage(ctx context.Context, sessionID int64, role, content string, tokens *int64) (int64, error)
|
||||||
|
UpdateSessionSummary(ctx context.Context, sessionID int64, summary string) error
|
||||||
|
UpdateSessionName(ctx context.Context, sessionID int64, name string) error
|
||||||
|
}
|
||||||
|
|
||||||
// Service coordinates chat requests with the OpenAI client and maintains session history.
|
// Service coordinates chat requests with the OpenAI client and maintains session history.
|
||||||
type Service struct {
|
type Service struct {
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
client CompletionClient
|
client CompletionClient
|
||||||
|
repo MessageRepository
|
||||||
model string
|
model string
|
||||||
temperature float64
|
temperature float64
|
||||||
stream bool
|
stream bool
|
||||||
history []openai.ChatMessage
|
history []openai.ChatMessage
|
||||||
|
sessionID int64
|
||||||
|
summarySet bool
|
||||||
|
sessionName string
|
||||||
|
sessionNamed bool
|
||||||
|
nameSuggested bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewService constructs a Service from configuration and an OpenAI-compatible client.
|
// NewService constructs a Service from configuration and an OpenAI-compatible client.
|
||||||
func NewService(logger *slog.Logger, modelCfg config.ModelConfig, client CompletionClient) (*Service, error) {
|
func NewService(logger *slog.Logger, modelCfg config.ModelConfig, client CompletionClient, repo MessageRepository) (*Service, error) {
|
||||||
if logger == nil {
|
if logger == nil {
|
||||||
return nil, errors.New("logger cannot be nil")
|
return nil, errors.New("logger cannot be nil")
|
||||||
}
|
}
|
||||||
@@ -40,6 +59,7 @@ func NewService(logger *slog.Logger, modelCfg config.ModelConfig, client Complet
|
|||||||
return &Service{
|
return &Service{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
client: client,
|
client: client,
|
||||||
|
repo: repo,
|
||||||
model: modelCfg.Name,
|
model: modelCfg.Name,
|
||||||
temperature: modelCfg.Temperature,
|
temperature: modelCfg.Temperature,
|
||||||
stream: modelCfg.Stream,
|
stream: modelCfg.Stream,
|
||||||
@@ -63,6 +83,9 @@ func (s *Service) Send(ctx context.Context, input string) (string, error) {
|
|||||||
|
|
||||||
userMsg := openai.ChatMessage{Role: "user", Content: content}
|
userMsg := openai.ChatMessage{Role: "user", Content: content}
|
||||||
s.history = append(s.history, userMsg)
|
s.history = append(s.history, userMsg)
|
||||||
|
if err := s.persistMessage(ctx, userMsg); err != nil {
|
||||||
|
s.logger.WarnContext(ctx, "failed to persist user message", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
messages := append([]openai.ChatMessage(nil), s.history...)
|
messages := append([]openai.ChatMessage(nil), s.history...)
|
||||||
temperature := s.temperature
|
temperature := s.temperature
|
||||||
@@ -80,13 +103,15 @@ func (s *Service) Send(ctx context.Context, input string) (string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(resp.Choices) == 0 {
|
if len(resp.Choices) == 0 {
|
||||||
return "", errors.New("no choices returned from completion")
|
return "", errors.New("no choices returned from completion")
|
||||||
}
|
}
|
||||||
|
|
||||||
reply := resp.Choices[0].Message
|
reply := resp.Choices[0].Message
|
||||||
s.history = append(s.history, reply)
|
s.history = append(s.history, reply)
|
||||||
|
if err := s.persistMessage(ctx, reply); err != nil {
|
||||||
|
s.logger.WarnContext(ctx, "failed to persist assistant message", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
return reply.Content, nil
|
return reply.Content, nil
|
||||||
}
|
}
|
||||||
@@ -108,4 +133,257 @@ func (s *Service) Reset() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.history = s.history[:0]
|
s.history = s.history[:0]
|
||||||
|
s.sessionID = 0
|
||||||
|
s.summarySet = false
|
||||||
|
s.sessionName = ""
|
||||||
|
s.sessionNamed = false
|
||||||
|
s.nameSuggested = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestoreSession replaces in-memory history with persisted messages for an existing session.
|
||||||
|
func (s *Service) RestoreSession(sessionID int64, name string, messages []openai.ChatMessage, summaryPresent bool) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.history = s.history[:0]
|
||||||
|
s.history = append(s.history, messages...)
|
||||||
|
s.sessionID = sessionID
|
||||||
|
s.summarySet = summaryPresent
|
||||||
|
s.sessionName = name
|
||||||
|
s.sessionNamed = !isAutoGeneratedName(name)
|
||||||
|
s.nameSuggested = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// CurrentSessionID returns the active session identifier, if any.
|
||||||
|
func (s *Service) CurrentSessionID() int64 {
|
||||||
|
if s == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return s.sessionID
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionName returns the current session name.
|
||||||
|
func (s *Service) SessionName() string {
|
||||||
|
if s == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return s.sessionName
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionNamed reports whether the current session has a user-friendly name applied.
|
||||||
|
func (s *Service) SessionNamed() bool {
|
||||||
|
if s == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.sessionNamed
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShouldSuggestSessionName indicates whether the service would benefit from a generated name.
|
||||||
|
func (s *Service) ShouldSuggestSessionName() bool {
|
||||||
|
if s == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !s.sessionNamed && !s.nameSuggested
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkSessionNameSuggested prevents additional automatic suggestions for the current session.
|
||||||
|
func (s *Service) MarkSessionNameSuggested() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.nameSuggested = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSessionName normalizes and persists a new name for the active session, returning the stored value.
|
||||||
|
func (s *Service) SetSessionName(ctx context.Context, name string) (string, error) {
|
||||||
|
if s == nil {
|
||||||
|
return "", errors.New("service is nil")
|
||||||
|
}
|
||||||
|
normalized := NormalizeSessionName(name)
|
||||||
|
if normalized == "" {
|
||||||
|
return "", errors.New("session name cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.repo != nil {
|
||||||
|
if s.sessionID <= 0 {
|
||||||
|
if err := s.ensureSession(ctx); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := s.repo.UpdateSessionName(ctx, s.sessionID, normalized); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sessionName = normalized
|
||||||
|
s.sessionNamed = !isAutoGeneratedName(normalized)
|
||||||
|
s.nameSuggested = true
|
||||||
|
return normalized, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) ensureSession(ctx context.Context) error {
|
||||||
|
if s == nil || s.repo == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if s.sessionID > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
name := fmt.Sprintf("%s%s", defaultSessionPrefix, time.Now().Format("20060102-150405"))
|
||||||
|
s.logger.DebugContext(ctx, "creating chat session", "name", name)
|
||||||
|
|
||||||
|
id, err := s.repo.CreateSession(ctx, name, s.model)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sessionID = id
|
||||||
|
s.summarySet = false
|
||||||
|
s.sessionName = name
|
||||||
|
s.sessionNamed = false
|
||||||
|
s.nameSuggested = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) persistMessage(ctx context.Context, msg openai.ChatMessage) error {
|
||||||
|
if s.repo == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.ensureSession(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokens *int64
|
||||||
|
if _, err := s.repo.AddMessage(ctx, s.sessionID, msg.Role, msg.Content, tokens); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg.Role == "assistant" && !s.summarySet && strings.TrimSpace(msg.Content) != "" {
|
||||||
|
const maxSummaryLen = 120
|
||||||
|
summary := msg.Content
|
||||||
|
runes := []rune(summary)
|
||||||
|
if len(runes) > maxSummaryLen {
|
||||||
|
summary = string(runes[:maxSummaryLen]) + "..."
|
||||||
|
}
|
||||||
|
if err := s.repo.UpdateSessionSummary(ctx, s.sessionID, summary); err != nil {
|
||||||
|
s.logger.WarnContext(ctx, "failed to update session summary", "error", err)
|
||||||
|
} else {
|
||||||
|
s.summarySet = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SuggestSessionName uses the backing LLM to propose a descriptive session name.
|
||||||
|
func (s *Service) SuggestSessionName(ctx context.Context) (string, error) {
|
||||||
|
if s == nil {
|
||||||
|
return "", errors.New("service is nil")
|
||||||
|
}
|
||||||
|
if s.client == nil {
|
||||||
|
return "", errors.New("completion client is unavailable")
|
||||||
|
}
|
||||||
|
if len(s.history) == 0 {
|
||||||
|
return "", errors.New("no conversation history available")
|
||||||
|
}
|
||||||
|
|
||||||
|
start := 0
|
||||||
|
if len(s.history) > 10 {
|
||||||
|
start = len(s.history) - 10
|
||||||
|
}
|
||||||
|
|
||||||
|
var builder strings.Builder
|
||||||
|
for i := start; i < len(s.history); i++ {
|
||||||
|
msg := s.history[i]
|
||||||
|
builder.WriteString(msg.Role)
|
||||||
|
builder.WriteString(": ")
|
||||||
|
builder.WriteString(msg.Content)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
temp := 0.3
|
||||||
|
req := openai.ChatCompletionRequest{
|
||||||
|
Model: s.model,
|
||||||
|
Messages: []openai.ChatMessage{
|
||||||
|
{Role: "system", Content: "You generate concise, descriptive chat session names. Respond with a lowercase kebab-case identifier (letters, digits, hyphens only, max 6 words)."},
|
||||||
|
{Role: "user", Content: fmt.Sprintf("Conversation excerpt:\n%s\nProvide only the session name.", builder.String())},
|
||||||
|
},
|
||||||
|
Temperature: &temp,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := s.client.CreateChatCompletion(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if len(resp.Choices) == 0 {
|
||||||
|
return "", errors.New("no session name suggestion returned")
|
||||||
|
}
|
||||||
|
|
||||||
|
suggestion := NormalizeSessionName(resp.Choices[0].Message.Content)
|
||||||
|
if suggestion == "" {
|
||||||
|
return "", errors.New("empty session name suggestion")
|
||||||
|
}
|
||||||
|
|
||||||
|
return suggestion, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeSessionName converts arbitrary text into a kebab-case identifier.
|
||||||
|
func NormalizeSessionName(name string) string {
|
||||||
|
trimmed := strings.TrimSpace(name)
|
||||||
|
if trimmed == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
trimmed = strings.ToLower(trimmed)
|
||||||
|
var builder strings.Builder
|
||||||
|
lastHyphen := false
|
||||||
|
|
||||||
|
for _, r := range trimmed {
|
||||||
|
if unicode.IsLetter(r) || unicode.IsDigit(r) {
|
||||||
|
builder.WriteRune(r)
|
||||||
|
lastHyphen = false
|
||||||
|
} else if r == ' ' || r == '-' || r == '_' || r == '.' || r == '/' || r == '\\' {
|
||||||
|
if !lastHyphen && builder.Len() > 0 {
|
||||||
|
builder.WriteRune('-')
|
||||||
|
lastHyphen = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if builder.Len() >= 80 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
slug := strings.Trim(builder.String(), "-")
|
||||||
|
if slug == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(slug, "-")
|
||||||
|
filtered := make([]string, 0, len(parts))
|
||||||
|
for _, part := range parts {
|
||||||
|
if part == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered = append(filtered, part)
|
||||||
|
if len(filtered) >= 6 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := strings.Join(filtered, "-")
|
||||||
|
if len(result) > 60 {
|
||||||
|
runes := []rune(result)
|
||||||
|
if len(runes) > 60 {
|
||||||
|
runes = runes[:60]
|
||||||
|
}
|
||||||
|
result = strings.Trim(string(runes), "-")
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func isAutoGeneratedName(name string) bool {
|
||||||
|
return strings.HasPrefix(name, defaultSessionPrefix)
|
||||||
}
|
}
|
||||||
|
204
internal/storage/repository.go
Normal file
204
internal/storage/repository.go
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Session represents a persisted chat session.
|
||||||
|
type Session struct {
|
||||||
|
ID int64
|
||||||
|
Name string
|
||||||
|
ModelName string
|
||||||
|
Summary sql.NullString
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSessionByName fetches a session by its unique name, returning nil when not found.
|
||||||
|
func (r *Repository) GetSessionByName(ctx context.Context, name string) (*Session, error) {
|
||||||
|
if r == nil {
|
||||||
|
return nil, errors.New("repository is nil")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(name) == "" {
|
||||||
|
return nil, errors.New("session name cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
row := r.db.QueryRowContext(ctx, `SELECT id, name, model_name, summary, created_at, updated_at FROM sessions WHERE name = ?`, name)
|
||||||
|
var session Session
|
||||||
|
if err := row.Scan(&session.ID, &session.Name, &session.ModelName, &session.Summary, &session.CreatedAt, &session.UpdatedAt); err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("get session by name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message represents a persisted chat message.
|
||||||
|
type Message struct {
|
||||||
|
ID int64
|
||||||
|
SessionID int64
|
||||||
|
Role string
|
||||||
|
Content string
|
||||||
|
TokenCount sql.NullInt64
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// Repository exposes CRUD helpers for sessions and messages.
|
||||||
|
type Repository struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRepository constructs a repository bound to the provided DB.
|
||||||
|
func NewRepository(db *sql.DB) (*Repository, error) {
|
||||||
|
if db == nil {
|
||||||
|
return nil, errors.New("db cannot be nil")
|
||||||
|
}
|
||||||
|
return &Repository{db: db}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSession inserts a new session record.
|
||||||
|
func (r *Repository) CreateSession(ctx context.Context, name, model string) (int64, error) {
|
||||||
|
if r == nil {
|
||||||
|
return 0, errors.New("repository is nil")
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
return 0, errors.New("session name cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
query := `INSERT INTO sessions (name, model_name) VALUES (?, ?)`
|
||||||
|
res, err := r.db.ExecContext(ctx, query, name, model)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("insert session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := res.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("fetch session id: %w", err)
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddMessage persists a message linked to a session.
|
||||||
|
func (r *Repository) AddMessage(ctx context.Context, sessionID int64, role, content string, tokens *int64) (int64, error) {
|
||||||
|
if r == nil {
|
||||||
|
return 0, errors.New("repository is nil")
|
||||||
|
}
|
||||||
|
if sessionID <= 0 {
|
||||||
|
return 0, errors.New("sessionID must be positive")
|
||||||
|
}
|
||||||
|
if role == "" || content == "" {
|
||||||
|
return 0, errors.New("role and content must be provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenVal interface{}
|
||||||
|
if tokens != nil {
|
||||||
|
tokenVal = *tokens
|
||||||
|
} else {
|
||||||
|
tokenVal = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
query := `INSERT INTO messages (session_id, role, content, token_count) VALUES (?, ?, ?, ?)`
|
||||||
|
res, err := r.db.ExecContext(ctx, query, sessionID, role, content, tokenVal)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("insert message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := res.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("fetch message id: %w", err)
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListSessions returns all sessions ordered by creation time descending.
|
||||||
|
func (r *Repository) ListSessions(ctx context.Context) ([]Session, error) {
|
||||||
|
rows, err := r.db.QueryContext(ctx, `SELECT id, name, model_name, summary, created_at, updated_at FROM sessions ORDER BY created_at DESC`)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("query sessions: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var sessions []Session
|
||||||
|
for rows.Next() {
|
||||||
|
var s Session
|
||||||
|
if err := rows.Scan(&s.ID, &s.Name, &s.ModelName, &s.Summary, &s.CreatedAt, &s.UpdatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan session: %w", err)
|
||||||
|
}
|
||||||
|
sessions = append(sessions, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("iterate sessions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sessions, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessages returns messages for a given session ordered by creation time.
|
||||||
|
func (r *Repository) GetMessages(ctx context.Context, sessionID int64) ([]Message, error) {
|
||||||
|
if sessionID <= 0 {
|
||||||
|
return nil, errors.New("sessionID must be positive")
|
||||||
|
}
|
||||||
|
rows, err := r.db.QueryContext(ctx, `SELECT id, session_id, role, content, token_count, created_at FROM messages WHERE session_id = ? ORDER BY created_at ASC`, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("query messages: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var messages []Message
|
||||||
|
for rows.Next() {
|
||||||
|
var m Message
|
||||||
|
if err := rows.Scan(&m.ID, &m.SessionID, &m.Role, &m.Content, &m.TokenCount, &m.CreatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan message: %w", err)
|
||||||
|
}
|
||||||
|
messages = append(messages, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("iterate messages: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSessionSummary updates the summary column for a session.
|
||||||
|
func (r *Repository) UpdateSessionSummary(ctx context.Context, sessionID int64, summary string) error {
|
||||||
|
if sessionID <= 0 {
|
||||||
|
return errors.New("sessionID must be positive")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := r.db.ExecContext(ctx, `UPDATE sessions SET summary = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, summary, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update session summary: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSessionName updates the name of an existing session.
|
||||||
|
func (r *Repository) UpdateSessionName(ctx context.Context, sessionID int64, name string) error {
|
||||||
|
if r == nil {
|
||||||
|
return errors.New("repository is nil")
|
||||||
|
}
|
||||||
|
if sessionID <= 0 {
|
||||||
|
return errors.New("sessionID must be positive")
|
||||||
|
}
|
||||||
|
trimmed := strings.TrimSpace(name)
|
||||||
|
if trimmed == "" {
|
||||||
|
return errors.New("session name cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := r.db.ExecContext(ctx, `UPDATE sessions SET name = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, trimmed, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update session name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
Reference in New Issue
Block a user