Files
goaichat/internal/storage/repository.go

205 lines
5.7 KiB
Go

package storage
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
)
// Session is one stored chat session, mirroring the columns read from the
// sessions table.
type Session struct {
	// ID holds the id column.
	ID int64
	// Name and ModelName hold the name and model_name columns.
	Name, ModelName string
	// Summary holds the summary column; Valid is false when the column is NULL.
	Summary sql.NullString
	// CreatedAt and UpdatedAt hold the created_at and updated_at columns.
	CreatedAt, UpdatedAt time.Time
}
// GetSessionByName looks up the session whose name column equals name.
//
// The lookup uses name exactly as supplied; trimming is applied only to decide
// whether the argument is blank. The result is (nil, nil) when no row matches.
// An error is returned when the receiver is nil, when name is empty or
// whitespace-only, or when the query or scan fails for any reason other than
// a missing row.
func (r *Repository) GetSessionByName(ctx context.Context, name string) (*Session, error) {
	switch {
	case r == nil:
		return nil, errors.New("repository is nil")
	case strings.TrimSpace(name) == "":
		return nil, errors.New("session name cannot be empty")
	}

	const query = `SELECT id, name, model_name, summary, created_at, updated_at FROM sessions WHERE name = ?`

	found := new(Session)
	err := r.db.QueryRowContext(ctx, query, name).Scan(
		&found.ID,
		&found.Name,
		&found.ModelName,
		&found.Summary,
		&found.CreatedAt,
		&found.UpdatedAt,
	)
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, sql.ErrNoRows):
		// A missing session is reported as "no value, no error".
		return nil, nil
	default:
		return nil, fmt.Errorf("get session by name: %w", err)
	}
}
// Message is one stored chat message, mirroring the columns read from the
// messages table.
type Message struct {
	// ID and SessionID hold the id and session_id columns.
	ID, SessionID int64
	// Role and Content hold the role and content columns.
	Role, Content string
	// TokenCount holds the token_count column; Valid is false when it is NULL.
	TokenCount sql.NullInt64
	// CreatedAt holds the created_at column.
	CreatedAt time.Time
}
// Repository exposes CRUD helpers for sessions and messages.
//
// Every method issues its SQL through the wrapped database handle. Use
// NewRepository to obtain a value whose handle is guaranteed to be non-nil.
type Repository struct {
	// db is the handle all queries and statements are executed on.
	db *sql.DB
}
// NewRepository wraps db in a Repository.
//
// It returns an error, and no Repository, when db is nil.
func NewRepository(db *sql.DB) (*Repository, error) {
	if db != nil {
		return &Repository{db: db}, nil
	}
	return nil, errors.New("db cannot be nil")
}
// CreateSession inserts a new row into the sessions table and returns the id
// reported for the insert.
//
// name is stored exactly as given, but it must contain at least one
// non-whitespace character: GetSessionByName rejects blank names, so a
// whitespace-only session could never be looked up by name again. model is
// written to the model_name column without validation.
//
// An error is returned when the receiver is nil, when name is blank, or when
// the insert or the retrieval of the generated id fails.
func (r *Repository) CreateSession(ctx context.Context, name, model string) (int64, error) {
	if r == nil {
		return 0, errors.New("repository is nil")
	}
	// Trim only for validation so the blank-name rule matches the other
	// session methods; the stored value is left untouched.
	if strings.TrimSpace(name) == "" {
		return 0, errors.New("session name cannot be empty")
	}

	const query = `INSERT INTO sessions (name, model_name) VALUES (?, ?)`
	res, err := r.db.ExecContext(ctx, query, name, model)
	if err != nil {
		return 0, fmt.Errorf("insert session: %w", err)
	}
	id, err := res.LastInsertId()
	if err != nil {
		return 0, fmt.Errorf("fetch session id: %w", err)
	}
	return id, nil
}
// AddMessage inserts a row into the messages table for the given session and
// returns the id reported for the insert.
//
// tokens is optional: a nil pointer binds a nil value for token_count,
// otherwise the pointed-to count is stored.
//
// An error is returned when the receiver is nil, when sessionID is not
// positive, when role or content is empty, or when the insert or the retrieval
// of the generated id fails.
func (r *Repository) AddMessage(ctx context.Context, sessionID int64, role, content string, tokens *int64) (int64, error) {
	switch {
	case r == nil:
		return 0, errors.New("repository is nil")
	case sessionID <= 0:
		return 0, errors.New("sessionID must be positive")
	case role == "" || content == "":
		return 0, errors.New("role and content must be provided")
	}

	// Stays a nil interface unless the caller supplied a count.
	var tokenCount interface{}
	if tokens != nil {
		tokenCount = *tokens
	}

	const insert = `INSERT INTO messages (session_id, role, content, token_count) VALUES (?, ?, ?, ?)`
	result, err := r.db.ExecContext(ctx, insert, sessionID, role, content, tokenCount)
	if err != nil {
		return 0, fmt.Errorf("insert message: %w", err)
	}

	messageID, err := result.LastInsertId()
	if err != nil {
		return 0, fmt.Errorf("fetch message id: %w", err)
	}
	return messageID, nil
}
// ListSessions returns every row of the sessions table, newest first
// (ordered by created_at descending).
//
// The returned slice is nil when the table holds no rows. An error is returned
// when the receiver is nil, or when the query, a row scan, or the row
// iteration fails.
func (r *Repository) ListSessions(ctx context.Context) ([]Session, error) {
	// Guard the receiver like the other Repository methods do, so a nil
	// repository yields an error instead of a nil-pointer panic on r.db.
	if r == nil {
		return nil, errors.New("repository is nil")
	}

	rows, err := r.db.QueryContext(ctx, `SELECT id, name, model_name, summary, created_at, updated_at FROM sessions ORDER BY created_at DESC`)
	if err != nil {
		return nil, fmt.Errorf("query sessions: %w", err)
	}
	defer rows.Close()

	var sessions []Session
	for rows.Next() {
		var s Session
		if err := rows.Scan(&s.ID, &s.Name, &s.ModelName, &s.Summary, &s.CreatedAt, &s.UpdatedAt); err != nil {
			return nil, fmt.Errorf("scan session: %w", err)
		}
		sessions = append(sessions, s)
	}
	// rows.Next returning false can mean either "done" or "failed"; Err tells
	// the two apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate sessions: %w", err)
	}
	return sessions, nil
}
// GetMessages returns the messages stored for sessionID, oldest first.
//
// Rows are ordered by created_at ascending, with id as a tie-breaker so that
// messages sharing a created_at value come back in a deterministic order.
// NOTE(review): this assumes ids grow with insertion order — confirm against
// the messages schema.
//
// The returned slice is nil when the session has no messages. An error is
// returned when the receiver is nil, when sessionID is not positive, or when
// the query, a row scan, or the row iteration fails.
func (r *Repository) GetMessages(ctx context.Context, sessionID int64) ([]Message, error) {
	// Guard the receiver like the other Repository methods do, so a nil
	// repository yields an error instead of a nil-pointer panic on r.db.
	if r == nil {
		return nil, errors.New("repository is nil")
	}
	if sessionID <= 0 {
		return nil, errors.New("sessionID must be positive")
	}

	rows, err := r.db.QueryContext(ctx, `SELECT id, session_id, role, content, token_count, created_at FROM messages WHERE session_id = ? ORDER BY created_at ASC, id ASC`, sessionID)
	if err != nil {
		return nil, fmt.Errorf("query messages: %w", err)
	}
	defer rows.Close()

	var messages []Message
	for rows.Next() {
		var m Message
		if err := rows.Scan(&m.ID, &m.SessionID, &m.Role, &m.Content, &m.TokenCount, &m.CreatedAt); err != nil {
			return nil, fmt.Errorf("scan message: %w", err)
		}
		messages = append(messages, m)
	}
	// rows.Next returning false can mean either "done" or "failed"; Err tells
	// the two apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate messages: %w", err)
	}
	return messages, nil
}
// UpdateSessionSummary writes summary to the summary column of the session
// with the given id and sets its updated_at column to CURRENT_TIMESTAMP.
//
// summary is stored as given, including the empty string. The number of
// affected rows is not inspected, so an id that matches no row is not reported
// as an error. An error is returned when the receiver is nil, when sessionID
// is not positive, or when the statement fails.
func (r *Repository) UpdateSessionSummary(ctx context.Context, sessionID int64, summary string) error {
	// Guard the receiver like the other Repository methods do, so a nil
	// repository yields an error instead of a nil-pointer panic on r.db.
	if r == nil {
		return errors.New("repository is nil")
	}
	if sessionID <= 0 {
		return errors.New("sessionID must be positive")
	}

	_, err := r.db.ExecContext(ctx, `UPDATE sessions SET summary = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, summary, sessionID)
	if err != nil {
		return fmt.Errorf("update session summary: %w", err)
	}
	return nil
}
// UpdateSessionName renames the session with the given id and sets its
// updated_at column to CURRENT_TIMESTAMP.
//
// Surrounding whitespace is stripped from name before it is stored. The number
// of affected rows is not inspected, so an id that matches no row is not
// reported as an error. An error is returned when the receiver is nil, when
// sessionID is not positive, when name is blank after trimming, or when the
// statement fails.
func (r *Repository) UpdateSessionName(ctx context.Context, sessionID int64, name string) error {
	newName := strings.TrimSpace(name)
	switch {
	case r == nil:
		return errors.New("repository is nil")
	case sessionID <= 0:
		return errors.New("sessionID must be positive")
	case newName == "":
		return errors.New("session name cannot be empty")
	}

	const stmt = `UPDATE sessions SET name = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`
	if _, err := r.db.ExecContext(ctx, stmt, newName, sessionID); err != nil {
		return fmt.Errorf("update session name: %w", err)
	}
	return nil
}