Add session naming workflow and rename command
This commit is contained in:
204
internal/storage/repository.go
Normal file
204
internal/storage/repository.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Session represents a persisted chat session.
|
||||
type Session struct {
|
||||
ID int64
|
||||
Name string
|
||||
ModelName string
|
||||
Summary sql.NullString
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// GetSessionByName fetches a session by its unique name, returning nil when not found.
|
||||
func (r *Repository) GetSessionByName(ctx context.Context, name string) (*Session, error) {
|
||||
if r == nil {
|
||||
return nil, errors.New("repository is nil")
|
||||
}
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return nil, errors.New("session name cannot be empty")
|
||||
}
|
||||
|
||||
row := r.db.QueryRowContext(ctx, `SELECT id, name, model_name, summary, created_at, updated_at FROM sessions WHERE name = ?`, name)
|
||||
var session Session
|
||||
if err := row.Scan(&session.ID, &session.Name, &session.ModelName, &session.Summary, &session.CreatedAt, &session.UpdatedAt); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("get session by name: %w", err)
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// Message represents a persisted chat message.
|
||||
type Message struct {
|
||||
ID int64
|
||||
SessionID int64
|
||||
Role string
|
||||
Content string
|
||||
TokenCount sql.NullInt64
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// Repository exposes CRUD helpers for sessions and messages.
|
||||
type Repository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewRepository constructs a repository bound to the provided DB.
|
||||
func NewRepository(db *sql.DB) (*Repository, error) {
|
||||
if db == nil {
|
||||
return nil, errors.New("db cannot be nil")
|
||||
}
|
||||
return &Repository{db: db}, nil
|
||||
}
|
||||
|
||||
// CreateSession inserts a new session record.
|
||||
func (r *Repository) CreateSession(ctx context.Context, name, model string) (int64, error) {
|
||||
if r == nil {
|
||||
return 0, errors.New("repository is nil")
|
||||
}
|
||||
if name == "" {
|
||||
return 0, errors.New("session name cannot be empty")
|
||||
}
|
||||
|
||||
query := `INSERT INTO sessions (name, model_name) VALUES (?, ?)`
|
||||
res, err := r.db.ExecContext(ctx, query, name, model)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insert session: %w", err)
|
||||
}
|
||||
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("fetch session id: %w", err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// AddMessage persists a message linked to a session.
|
||||
func (r *Repository) AddMessage(ctx context.Context, sessionID int64, role, content string, tokens *int64) (int64, error) {
|
||||
if r == nil {
|
||||
return 0, errors.New("repository is nil")
|
||||
}
|
||||
if sessionID <= 0 {
|
||||
return 0, errors.New("sessionID must be positive")
|
||||
}
|
||||
if role == "" || content == "" {
|
||||
return 0, errors.New("role and content must be provided")
|
||||
}
|
||||
|
||||
var tokenVal interface{}
|
||||
if tokens != nil {
|
||||
tokenVal = *tokens
|
||||
} else {
|
||||
tokenVal = nil
|
||||
}
|
||||
|
||||
query := `INSERT INTO messages (session_id, role, content, token_count) VALUES (?, ?, ?, ?)`
|
||||
res, err := r.db.ExecContext(ctx, query, sessionID, role, content, tokenVal)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insert message: %w", err)
|
||||
}
|
||||
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("fetch message id: %w", err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// ListSessions returns all sessions ordered by creation time descending.
|
||||
func (r *Repository) ListSessions(ctx context.Context) ([]Session, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `SELECT id, name, model_name, summary, created_at, updated_at FROM sessions ORDER BY created_at DESC`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query sessions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessions []Session
|
||||
for rows.Next() {
|
||||
var s Session
|
||||
if err := rows.Scan(&s.ID, &s.Name, &s.ModelName, &s.Summary, &s.CreatedAt, &s.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan session: %w", err)
|
||||
}
|
||||
sessions = append(sessions, s)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate sessions: %w", err)
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// GetMessages returns messages for a given session ordered by creation time.
|
||||
func (r *Repository) GetMessages(ctx context.Context, sessionID int64) ([]Message, error) {
|
||||
if sessionID <= 0 {
|
||||
return nil, errors.New("sessionID must be positive")
|
||||
}
|
||||
rows, err := r.db.QueryContext(ctx, `SELECT id, session_id, role, content, token_count, created_at FROM messages WHERE session_id = ? ORDER BY created_at ASC`, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query messages: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var messages []Message
|
||||
for rows.Next() {
|
||||
var m Message
|
||||
if err := rows.Scan(&m.ID, &m.SessionID, &m.Role, &m.Content, &m.TokenCount, &m.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan message: %w", err)
|
||||
}
|
||||
messages = append(messages, m)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate messages: %w", err)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// UpdateSessionSummary updates the summary column for a session.
|
||||
func (r *Repository) UpdateSessionSummary(ctx context.Context, sessionID int64, summary string) error {
|
||||
if sessionID <= 0 {
|
||||
return errors.New("sessionID must be positive")
|
||||
}
|
||||
|
||||
_, err := r.db.ExecContext(ctx, `UPDATE sessions SET summary = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, summary, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update session summary: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateSessionName updates the name of an existing session.
|
||||
func (r *Repository) UpdateSessionName(ctx context.Context, sessionID int64, name string) error {
|
||||
if r == nil {
|
||||
return errors.New("repository is nil")
|
||||
}
|
||||
if sessionID <= 0 {
|
||||
return errors.New("sessionID must be positive")
|
||||
}
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
return errors.New("session name cannot be empty")
|
||||
}
|
||||
|
||||
_, err := r.db.ExecContext(ctx, `UPDATE sessions SET name = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, trimmed, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update session name: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
Reference in New Issue
Block a user