Add session naming workflow and rename command
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/stig/goaichat/internal/chat"
|
||||
"github.com/stig/goaichat/internal/config"
|
||||
"github.com/stig/goaichat/internal/openai"
|
||||
"github.com/stig/goaichat/internal/storage"
|
||||
)
|
||||
|
||||
// App encapsulates the Goaichat application runtime wiring.
|
||||
@@ -21,6 +22,8 @@ type App struct {
|
||||
config *config.Config
|
||||
openAI *openai.Client
|
||||
chat *chat.Service
|
||||
store *storage.Manager
|
||||
repo *storage.Repository
|
||||
input io.Reader
|
||||
output io.Writer
|
||||
}
|
||||
@@ -71,6 +74,9 @@ func (a *App) Run(ctx context.Context) error {
|
||||
if err := a.initOpenAIClient(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := a.initStorage(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := a.initChatService(); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -108,7 +114,7 @@ func (a *App) initChatService() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
service, err := chat.NewService(a.logger.With("component", "chat"), a.config.Model, a.openAI)
|
||||
service, err := chat.NewService(a.logger.With("component", "chat"), a.config.Model, a.openAI, a.repo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -117,6 +123,32 @@ func (a *App) initChatService() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *App) initStorage() error {
|
||||
if a.store != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
manager, err := storage.NewManager(*a.config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db, err := manager.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
repo, err := storage.NewRepository(db)
|
||||
if err != nil {
|
||||
_ = manager.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
a.store = manager
|
||||
a.repo = repo
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *App) runCLILoop(ctx context.Context) error {
|
||||
scanner := bufio.NewScanner(a.input)
|
||||
|
||||
@@ -163,11 +195,14 @@ func (a *App) runCLILoop(ctx context.Context) error {
|
||||
if _, err := fmt.Fprintf(a.output, "AI: %s\n", reply); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := a.maybeSuggestSessionName(ctx); err != nil {
|
||||
a.logger.WarnContext(ctx, "session name suggestion failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) handleCommand(ctx context.Context, input string) (handled bool, exit bool, err error) {
|
||||
_ = ctx
|
||||
trimmed := strings.TrimSpace(input)
|
||||
if trimmed == "" {
|
||||
return true, false, errors.New("no input provided")
|
||||
@@ -175,6 +210,86 @@ func (a *App) handleCommand(ctx context.Context, input string) (handled bool, ex
|
||||
if !strings.HasPrefix(trimmed, "/") {
|
||||
return false, false, nil
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "/rename") {
|
||||
parts := strings.Fields(trimmed)
|
||||
if len(parts) < 2 {
|
||||
_, err := fmt.Fprintln(a.output, "Usage: /rename <session-name>")
|
||||
return true, false, err
|
||||
}
|
||||
rawName := strings.Join(parts[1:], " ")
|
||||
normalized := chat.NormalizeSessionName(rawName)
|
||||
if normalized == "" {
|
||||
_, err := fmt.Fprintln(a.output, "Session name cannot be empty.")
|
||||
return true, false, err
|
||||
}
|
||||
|
||||
if a.repo != nil {
|
||||
existing, fetchErr := a.repo.GetSessionByName(ctx, normalized)
|
||||
if fetchErr != nil {
|
||||
_, err := fmt.Fprintf(a.output, "Failed to verify name availability: %v\n", fetchErr)
|
||||
return true, false, err
|
||||
}
|
||||
if existing != nil {
|
||||
currentID := a.chat.CurrentSessionID()
|
||||
if currentID == 0 || existing.ID != currentID {
|
||||
_, err := fmt.Fprintf(a.output, "Session name %q is already in use.\n", normalized)
|
||||
return true, false, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
applied, setErr := a.chat.SetSessionName(ctx, rawName)
|
||||
if setErr != nil {
|
||||
_, err := fmt.Fprintf(a.output, "Failed to rename session: %v\n", setErr)
|
||||
return true, false, err
|
||||
}
|
||||
|
||||
_, err := fmt.Fprintf(a.output, "Session renamed to %q.\n", applied)
|
||||
return true, false, err
|
||||
}
|
||||
|
||||
if strings.HasPrefix(trimmed, "/open") {
|
||||
parts := strings.Fields(trimmed)
|
||||
if len(parts) != 2 {
|
||||
_, err := fmt.Fprintln(a.output, "Usage: /open <session-name>")
|
||||
return true, false, err
|
||||
}
|
||||
if a.repo == nil {
|
||||
_, err := fmt.Fprintln(a.output, "Storage not initialised; cannot open sessions.")
|
||||
return true, false, err
|
||||
}
|
||||
|
||||
session, fetchErr := a.repo.GetSessionByName(ctx, parts[1])
|
||||
if fetchErr != nil {
|
||||
_, err := fmt.Fprintf(a.output, "Failed to fetch session: %v\n", fetchErr)
|
||||
return true, false, err
|
||||
}
|
||||
if session == nil {
|
||||
_, err := fmt.Fprintf(a.output, "Session %q not found.\n", parts[1])
|
||||
return true, false, err
|
||||
}
|
||||
|
||||
messages, msgErr := a.repo.GetMessages(ctx, session.ID)
|
||||
if msgErr != nil {
|
||||
_, err := fmt.Fprintf(a.output, "Failed to load messages: %v\n", msgErr)
|
||||
return true, false, err
|
||||
}
|
||||
|
||||
chatMessages := make([]openai.ChatMessage, 0, len(messages))
|
||||
for _, message := range messages {
|
||||
role := strings.TrimSpace(message.Role)
|
||||
if role == "" {
|
||||
continue
|
||||
}
|
||||
chatMessages = append(chatMessages, openai.ChatMessage{Role: role, Content: message.Content})
|
||||
}
|
||||
|
||||
summaryPresent := session.Summary.Valid && strings.TrimSpace(session.Summary.String) != ""
|
||||
a.chat.RestoreSession(session.ID, session.Name, chatMessages, summaryPresent)
|
||||
|
||||
_, err := fmt.Fprintf(a.output, "Loaded session %q with %d messages.\n", session.Name, len(chatMessages))
|
||||
return true, false, err
|
||||
}
|
||||
|
||||
switch trimmed {
|
||||
case "/exit":
|
||||
@@ -183,11 +298,57 @@ func (a *App) handleCommand(ctx context.Context, input string) (handled bool, ex
|
||||
a.chat.Reset()
|
||||
_, err := fmt.Fprintln(a.output, "History cleared.")
|
||||
return true, false, err
|
||||
case "/list":
|
||||
if a.repo == nil {
|
||||
_, err := fmt.Fprintln(a.output, "History commands unavailable (storage not initialised).")
|
||||
return true, false, err
|
||||
}
|
||||
sessions, listErr := a.repo.ListSessions(ctx)
|
||||
if listErr != nil {
|
||||
_, err := fmt.Fprintf(a.output, "Failed to list sessions: %v\n", listErr)
|
||||
return true, false, err
|
||||
}
|
||||
if len(sessions) == 0 {
|
||||
_, err := fmt.Fprintln(a.output, "No saved sessions.")
|
||||
return true, false, err
|
||||
}
|
||||
for _, session := range sessions {
|
||||
summary := session.Summary.String
|
||||
if summary == "" {
|
||||
summary = "(no summary)"
|
||||
}
|
||||
if _, err := fmt.Fprintf(a.output, "- %s [%s]: %s\n", session.Name, session.ModelName, summary); err != nil {
|
||||
return true, false, err
|
||||
}
|
||||
}
|
||||
return true, false, nil
|
||||
case "/help":
|
||||
_, err := fmt.Fprintln(a.output, "Commands: /exit, /reset, /help (more coming soon)")
|
||||
_, err := fmt.Fprintln(a.output, "Commands: /exit, /reset, /list, /open <name>, /rename <name>, /help (more coming soon)")
|
||||
return true, false, err
|
||||
default:
|
||||
_, err := fmt.Fprintf(a.output, "Unknown command %q. Try /help.\n", trimmed)
|
||||
return true, false, err
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) maybeSuggestSessionName(ctx context.Context) error {
|
||||
if a.chat == nil {
|
||||
return nil
|
||||
}
|
||||
if !a.chat.ShouldSuggestSessionName() {
|
||||
return nil
|
||||
}
|
||||
|
||||
suggestion, err := a.chat.SuggestSessionName(ctx)
|
||||
if err != nil {
|
||||
a.chat.MarkSessionNameSuggested()
|
||||
return err
|
||||
}
|
||||
|
||||
a.chat.MarkSessionNameSuggested()
|
||||
if _, err := fmt.Fprintf(a.output, "Suggested session name: %s\nUse /rename %s to apply it now.\n", suggestion, suggestion); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -3,30 +3,49 @@ package chat
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/stig/goaichat/internal/config"
|
||||
"github.com/stig/goaichat/internal/openai"
|
||||
)
|
||||
|
||||
const defaultSessionPrefix = "session-"
|
||||
|
||||
// CompletionClient defines the subset of the OpenAI client used by the chat service.
|
||||
type CompletionClient interface {
|
||||
CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (*openai.ChatCompletionResponse, error)
|
||||
}
|
||||
|
||||
// MessageRepository defines persistence hooks used by the chat service.
|
||||
type MessageRepository interface {
|
||||
CreateSession(ctx context.Context, name, model string) (int64, error)
|
||||
AddMessage(ctx context.Context, sessionID int64, role, content string, tokens *int64) (int64, error)
|
||||
UpdateSessionSummary(ctx context.Context, sessionID int64, summary string) error
|
||||
UpdateSessionName(ctx context.Context, sessionID int64, name string) error
|
||||
}
|
||||
|
||||
// Service coordinates chat requests with the OpenAI client and maintains session history.
|
||||
type Service struct {
|
||||
logger *slog.Logger
|
||||
client CompletionClient
|
||||
model string
|
||||
temperature float64
|
||||
stream bool
|
||||
history []openai.ChatMessage
|
||||
logger *slog.Logger
|
||||
client CompletionClient
|
||||
repo MessageRepository
|
||||
model string
|
||||
temperature float64
|
||||
stream bool
|
||||
history []openai.ChatMessage
|
||||
sessionID int64
|
||||
summarySet bool
|
||||
sessionName string
|
||||
sessionNamed bool
|
||||
nameSuggested bool
|
||||
}
|
||||
|
||||
// NewService constructs a Service from configuration and an OpenAI-compatible client.
|
||||
func NewService(logger *slog.Logger, modelCfg config.ModelConfig, client CompletionClient) (*Service, error) {
|
||||
func NewService(logger *slog.Logger, modelCfg config.ModelConfig, client CompletionClient, repo MessageRepository) (*Service, error) {
|
||||
if logger == nil {
|
||||
return nil, errors.New("logger cannot be nil")
|
||||
}
|
||||
@@ -40,6 +59,7 @@ func NewService(logger *slog.Logger, modelCfg config.ModelConfig, client Complet
|
||||
return &Service{
|
||||
logger: logger,
|
||||
client: client,
|
||||
repo: repo,
|
||||
model: modelCfg.Name,
|
||||
temperature: modelCfg.Temperature,
|
||||
stream: modelCfg.Stream,
|
||||
@@ -63,6 +83,9 @@ func (s *Service) Send(ctx context.Context, input string) (string, error) {
|
||||
|
||||
userMsg := openai.ChatMessage{Role: "user", Content: content}
|
||||
s.history = append(s.history, userMsg)
|
||||
if err := s.persistMessage(ctx, userMsg); err != nil {
|
||||
s.logger.WarnContext(ctx, "failed to persist user message", "error", err)
|
||||
}
|
||||
|
||||
messages := append([]openai.ChatMessage(nil), s.history...)
|
||||
temperature := s.temperature
|
||||
@@ -80,13 +103,15 @@ func (s *Service) Send(ctx context.Context, input string) (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(resp.Choices) == 0 {
|
||||
return "", errors.New("no choices returned from completion")
|
||||
}
|
||||
|
||||
reply := resp.Choices[0].Message
|
||||
s.history = append(s.history, reply)
|
||||
if err := s.persistMessage(ctx, reply); err != nil {
|
||||
s.logger.WarnContext(ctx, "failed to persist assistant message", "error", err)
|
||||
}
|
||||
|
||||
return reply.Content, nil
|
||||
}
|
||||
@@ -108,4 +133,257 @@ func (s *Service) Reset() {
|
||||
return
|
||||
}
|
||||
s.history = s.history[:0]
|
||||
s.sessionID = 0
|
||||
s.summarySet = false
|
||||
s.sessionName = ""
|
||||
s.sessionNamed = false
|
||||
s.nameSuggested = false
|
||||
}
|
||||
|
||||
// RestoreSession replaces in-memory history with persisted messages for an existing session.
|
||||
func (s *Service) RestoreSession(sessionID int64, name string, messages []openai.ChatMessage, summaryPresent bool) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.history = s.history[:0]
|
||||
s.history = append(s.history, messages...)
|
||||
s.sessionID = sessionID
|
||||
s.summarySet = summaryPresent
|
||||
s.sessionName = name
|
||||
s.sessionNamed = !isAutoGeneratedName(name)
|
||||
s.nameSuggested = false
|
||||
}
|
||||
|
||||
// CurrentSessionID returns the active session identifier, if any.
|
||||
func (s *Service) CurrentSessionID() int64 {
|
||||
if s == nil {
|
||||
return 0
|
||||
}
|
||||
return s.sessionID
|
||||
}
|
||||
|
||||
// SessionName returns the current session name.
|
||||
func (s *Service) SessionName() string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return s.sessionName
|
||||
}
|
||||
|
||||
// SessionNamed reports whether the current session has a user-friendly name applied.
|
||||
func (s *Service) SessionNamed() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
return s.sessionNamed
|
||||
}
|
||||
|
||||
// ShouldSuggestSessionName indicates whether the service would benefit from a generated name.
|
||||
func (s *Service) ShouldSuggestSessionName() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
return !s.sessionNamed && !s.nameSuggested
|
||||
}
|
||||
|
||||
// MarkSessionNameSuggested prevents additional automatic suggestions for the current session.
|
||||
func (s *Service) MarkSessionNameSuggested() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.nameSuggested = true
|
||||
}
|
||||
|
||||
// SetSessionName normalizes and persists a new name for the active session, returning the stored value.
|
||||
func (s *Service) SetSessionName(ctx context.Context, name string) (string, error) {
|
||||
if s == nil {
|
||||
return "", errors.New("service is nil")
|
||||
}
|
||||
normalized := NormalizeSessionName(name)
|
||||
if normalized == "" {
|
||||
return "", errors.New("session name cannot be empty")
|
||||
}
|
||||
|
||||
if s.repo != nil {
|
||||
if s.sessionID <= 0 {
|
||||
if err := s.ensureSession(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if err := s.repo.UpdateSessionName(ctx, s.sessionID, normalized); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
s.sessionName = normalized
|
||||
s.sessionNamed = !isAutoGeneratedName(normalized)
|
||||
s.nameSuggested = true
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func (s *Service) ensureSession(ctx context.Context) error {
|
||||
if s == nil || s.repo == nil {
|
||||
return nil
|
||||
}
|
||||
if s.sessionID > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("%s%s", defaultSessionPrefix, time.Now().Format("20060102-150405"))
|
||||
s.logger.DebugContext(ctx, "creating chat session", "name", name)
|
||||
|
||||
id, err := s.repo.CreateSession(ctx, name, s.model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.sessionID = id
|
||||
s.summarySet = false
|
||||
s.sessionName = name
|
||||
s.sessionNamed = false
|
||||
s.nameSuggested = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) persistMessage(ctx context.Context, msg openai.ChatMessage) error {
|
||||
if s.repo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.ensureSession(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var tokens *int64
|
||||
if _, err := s.repo.AddMessage(ctx, s.sessionID, msg.Role, msg.Content, tokens); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if msg.Role == "assistant" && !s.summarySet && strings.TrimSpace(msg.Content) != "" {
|
||||
const maxSummaryLen = 120
|
||||
summary := msg.Content
|
||||
runes := []rune(summary)
|
||||
if len(runes) > maxSummaryLen {
|
||||
summary = string(runes[:maxSummaryLen]) + "..."
|
||||
}
|
||||
if err := s.repo.UpdateSessionSummary(ctx, s.sessionID, summary); err != nil {
|
||||
s.logger.WarnContext(ctx, "failed to update session summary", "error", err)
|
||||
} else {
|
||||
s.summarySet = true
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SuggestSessionName uses the backing LLM to propose a descriptive session name.
|
||||
func (s *Service) SuggestSessionName(ctx context.Context) (string, error) {
|
||||
if s == nil {
|
||||
return "", errors.New("service is nil")
|
||||
}
|
||||
if s.client == nil {
|
||||
return "", errors.New("completion client is unavailable")
|
||||
}
|
||||
if len(s.history) == 0 {
|
||||
return "", errors.New("no conversation history available")
|
||||
}
|
||||
|
||||
start := 0
|
||||
if len(s.history) > 10 {
|
||||
start = len(s.history) - 10
|
||||
}
|
||||
|
||||
var builder strings.Builder
|
||||
for i := start; i < len(s.history); i++ {
|
||||
msg := s.history[i]
|
||||
builder.WriteString(msg.Role)
|
||||
builder.WriteString(": ")
|
||||
builder.WriteString(msg.Content)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
temp := 0.3
|
||||
req := openai.ChatCompletionRequest{
|
||||
Model: s.model,
|
||||
Messages: []openai.ChatMessage{
|
||||
{Role: "system", Content: "You generate concise, descriptive chat session names. Respond with a lowercase kebab-case identifier (letters, digits, hyphens only, max 6 words)."},
|
||||
{Role: "user", Content: fmt.Sprintf("Conversation excerpt:\n%s\nProvide only the session name.", builder.String())},
|
||||
},
|
||||
Temperature: &temp,
|
||||
}
|
||||
|
||||
resp, err := s.client.CreateChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return "", errors.New("no session name suggestion returned")
|
||||
}
|
||||
|
||||
suggestion := NormalizeSessionName(resp.Choices[0].Message.Content)
|
||||
if suggestion == "" {
|
||||
return "", errors.New("empty session name suggestion")
|
||||
}
|
||||
|
||||
return suggestion, nil
|
||||
}
|
||||
|
||||
// NormalizeSessionName converts arbitrary text into a kebab-case identifier.
|
||||
func NormalizeSessionName(name string) string {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
trimmed = strings.ToLower(trimmed)
|
||||
var builder strings.Builder
|
||||
lastHyphen := false
|
||||
|
||||
for _, r := range trimmed {
|
||||
if unicode.IsLetter(r) || unicode.IsDigit(r) {
|
||||
builder.WriteRune(r)
|
||||
lastHyphen = false
|
||||
} else if r == ' ' || r == '-' || r == '_' || r == '.' || r == '/' || r == '\\' {
|
||||
if !lastHyphen && builder.Len() > 0 {
|
||||
builder.WriteRune('-')
|
||||
lastHyphen = true
|
||||
}
|
||||
}
|
||||
if builder.Len() >= 80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
slug := strings.Trim(builder.String(), "-")
|
||||
if slug == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := strings.Split(slug, "-")
|
||||
filtered := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, part)
|
||||
if len(filtered) >= 6 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
result := strings.Join(filtered, "-")
|
||||
if len(result) > 60 {
|
||||
runes := []rune(result)
|
||||
if len(runes) > 60 {
|
||||
runes = runes[:60]
|
||||
}
|
||||
result = strings.Trim(string(runes), "-")
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func isAutoGeneratedName(name string) bool {
|
||||
return strings.HasPrefix(name, defaultSessionPrefix)
|
||||
}
|
||||
|
204
internal/storage/repository.go
Normal file
204
internal/storage/repository.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Session represents a persisted chat session.
|
||||
type Session struct {
|
||||
ID int64
|
||||
Name string
|
||||
ModelName string
|
||||
Summary sql.NullString
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// GetSessionByName fetches a session by its unique name, returning nil when not found.
|
||||
func (r *Repository) GetSessionByName(ctx context.Context, name string) (*Session, error) {
|
||||
if r == nil {
|
||||
return nil, errors.New("repository is nil")
|
||||
}
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return nil, errors.New("session name cannot be empty")
|
||||
}
|
||||
|
||||
row := r.db.QueryRowContext(ctx, `SELECT id, name, model_name, summary, created_at, updated_at FROM sessions WHERE name = ?`, name)
|
||||
var session Session
|
||||
if err := row.Scan(&session.ID, &session.Name, &session.ModelName, &session.Summary, &session.CreatedAt, &session.UpdatedAt); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("get session by name: %w", err)
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// Message represents a persisted chat message.
|
||||
type Message struct {
|
||||
ID int64
|
||||
SessionID int64
|
||||
Role string
|
||||
Content string
|
||||
TokenCount sql.NullInt64
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// Repository exposes CRUD helpers for sessions and messages.
|
||||
type Repository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewRepository constructs a repository bound to the provided DB.
|
||||
func NewRepository(db *sql.DB) (*Repository, error) {
|
||||
if db == nil {
|
||||
return nil, errors.New("db cannot be nil")
|
||||
}
|
||||
return &Repository{db: db}, nil
|
||||
}
|
||||
|
||||
// CreateSession inserts a new session record.
|
||||
func (r *Repository) CreateSession(ctx context.Context, name, model string) (int64, error) {
|
||||
if r == nil {
|
||||
return 0, errors.New("repository is nil")
|
||||
}
|
||||
if name == "" {
|
||||
return 0, errors.New("session name cannot be empty")
|
||||
}
|
||||
|
||||
query := `INSERT INTO sessions (name, model_name) VALUES (?, ?)`
|
||||
res, err := r.db.ExecContext(ctx, query, name, model)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insert session: %w", err)
|
||||
}
|
||||
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("fetch session id: %w", err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// AddMessage persists a message linked to a session.
|
||||
func (r *Repository) AddMessage(ctx context.Context, sessionID int64, role, content string, tokens *int64) (int64, error) {
|
||||
if r == nil {
|
||||
return 0, errors.New("repository is nil")
|
||||
}
|
||||
if sessionID <= 0 {
|
||||
return 0, errors.New("sessionID must be positive")
|
||||
}
|
||||
if role == "" || content == "" {
|
||||
return 0, errors.New("role and content must be provided")
|
||||
}
|
||||
|
||||
var tokenVal interface{}
|
||||
if tokens != nil {
|
||||
tokenVal = *tokens
|
||||
} else {
|
||||
tokenVal = nil
|
||||
}
|
||||
|
||||
query := `INSERT INTO messages (session_id, role, content, token_count) VALUES (?, ?, ?, ?)`
|
||||
res, err := r.db.ExecContext(ctx, query, sessionID, role, content, tokenVal)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insert message: %w", err)
|
||||
}
|
||||
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("fetch message id: %w", err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// ListSessions returns all sessions ordered by creation time descending.
|
||||
func (r *Repository) ListSessions(ctx context.Context) ([]Session, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `SELECT id, name, model_name, summary, created_at, updated_at FROM sessions ORDER BY created_at DESC`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query sessions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessions []Session
|
||||
for rows.Next() {
|
||||
var s Session
|
||||
if err := rows.Scan(&s.ID, &s.Name, &s.ModelName, &s.Summary, &s.CreatedAt, &s.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan session: %w", err)
|
||||
}
|
||||
sessions = append(sessions, s)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate sessions: %w", err)
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// GetMessages returns messages for a given session ordered by creation time.
|
||||
func (r *Repository) GetMessages(ctx context.Context, sessionID int64) ([]Message, error) {
|
||||
if sessionID <= 0 {
|
||||
return nil, errors.New("sessionID must be positive")
|
||||
}
|
||||
rows, err := r.db.QueryContext(ctx, `SELECT id, session_id, role, content, token_count, created_at FROM messages WHERE session_id = ? ORDER BY created_at ASC`, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query messages: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var messages []Message
|
||||
for rows.Next() {
|
||||
var m Message
|
||||
if err := rows.Scan(&m.ID, &m.SessionID, &m.Role, &m.Content, &m.TokenCount, &m.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan message: %w", err)
|
||||
}
|
||||
messages = append(messages, m)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate messages: %w", err)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// UpdateSessionSummary updates the summary column for a session.
|
||||
func (r *Repository) UpdateSessionSummary(ctx context.Context, sessionID int64, summary string) error {
|
||||
if sessionID <= 0 {
|
||||
return errors.New("sessionID must be positive")
|
||||
}
|
||||
|
||||
_, err := r.db.ExecContext(ctx, `UPDATE sessions SET summary = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, summary, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update session summary: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateSessionName updates the name of an existing session.
|
||||
func (r *Repository) UpdateSessionName(ctx context.Context, sessionID int64, name string) error {
|
||||
if r == nil {
|
||||
return errors.New("repository is nil")
|
||||
}
|
||||
if sessionID <= 0 {
|
||||
return errors.New("sessionID must be positive")
|
||||
}
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
return errors.New("session name cannot be empty")
|
||||
}
|
||||
|
||||
_, err := r.db.ExecContext(ctx, `UPDATE sessions SET name = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, trimmed, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update session name: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
Reference in New Issue
Block a user