205 lines
5.7 KiB
Go
205 lines
5.7 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Session is one chat session as read from the sessions table.
type Session struct {
	// ID is the row identifier (the value CreateSession returns).
	ID int64
	// Name is the session name used by GetSessionByName lookups.
	Name string
	// ModelName mirrors the model_name column.
	ModelName string
	// Summary mirrors the nullable summary column; Valid is false for NULL.
	Summary sql.NullString
	// CreatedAt and UpdatedAt mirror the created_at and updated_at columns.
	// The Update* methods of Repository refresh updated_at.
	CreatedAt time.Time
	UpdatedAt time.Time
}
|
|
|
|
// GetSessionByName fetches a session by its unique name, returning nil when not found.
|
|
func (r *Repository) GetSessionByName(ctx context.Context, name string) (*Session, error) {
|
|
if r == nil {
|
|
return nil, errors.New("repository is nil")
|
|
}
|
|
if strings.TrimSpace(name) == "" {
|
|
return nil, errors.New("session name cannot be empty")
|
|
}
|
|
|
|
row := r.db.QueryRowContext(ctx, `SELECT id, name, model_name, summary, created_at, updated_at FROM sessions WHERE name = ?`, name)
|
|
var session Session
|
|
if err := row.Scan(&session.ID, &session.Name, &session.ModelName, &session.Summary, &session.CreatedAt, &session.UpdatedAt); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("get session by name: %w", err)
|
|
}
|
|
|
|
return &session, nil
|
|
}
|
|
|
|
// Message is one chat message as read from the messages table.
type Message struct {
	// ID is the row identifier (the value AddMessage returns).
	ID int64
	// SessionID is the ID of the session the message belongs to.
	SessionID int64
	// Role and Content are the message's role label and text; AddMessage
	// requires both to be non-empty.
	Role    string
	Content string
	// TokenCount mirrors the nullable token_count column; Valid is false for NULL.
	TokenCount sql.NullInt64
	// CreatedAt mirrors the created_at column.
	CreatedAt time.Time
}
|
|
|
|
// Repository groups the session and message queries of this package around a
// single *sql.DB. Build it with NewRepository, which refuses a nil handle.
type Repository struct {
	db *sql.DB // handle used by every query method
}
|
|
|
|
// NewRepository constructs a repository bound to the provided DB.
|
|
func NewRepository(db *sql.DB) (*Repository, error) {
|
|
if db == nil {
|
|
return nil, errors.New("db cannot be nil")
|
|
}
|
|
return &Repository{db: db}, nil
|
|
}
|
|
|
|
// CreateSession inserts a new session record.
|
|
func (r *Repository) CreateSession(ctx context.Context, name, model string) (int64, error) {
|
|
if r == nil {
|
|
return 0, errors.New("repository is nil")
|
|
}
|
|
if name == "" {
|
|
return 0, errors.New("session name cannot be empty")
|
|
}
|
|
|
|
query := `INSERT INTO sessions (name, model_name) VALUES (?, ?)`
|
|
res, err := r.db.ExecContext(ctx, query, name, model)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("insert session: %w", err)
|
|
}
|
|
|
|
id, err := res.LastInsertId()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("fetch session id: %w", err)
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
// AddMessage persists a message linked to a session.
|
|
func (r *Repository) AddMessage(ctx context.Context, sessionID int64, role, content string, tokens *int64) (int64, error) {
|
|
if r == nil {
|
|
return 0, errors.New("repository is nil")
|
|
}
|
|
if sessionID <= 0 {
|
|
return 0, errors.New("sessionID must be positive")
|
|
}
|
|
if role == "" || content == "" {
|
|
return 0, errors.New("role and content must be provided")
|
|
}
|
|
|
|
var tokenVal interface{}
|
|
if tokens != nil {
|
|
tokenVal = *tokens
|
|
} else {
|
|
tokenVal = nil
|
|
}
|
|
|
|
query := `INSERT INTO messages (session_id, role, content, token_count) VALUES (?, ?, ?, ?)`
|
|
res, err := r.db.ExecContext(ctx, query, sessionID, role, content, tokenVal)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("insert message: %w", err)
|
|
}
|
|
|
|
id, err := res.LastInsertId()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("fetch message id: %w", err)
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
// ListSessions returns all sessions ordered by creation time descending.
|
|
func (r *Repository) ListSessions(ctx context.Context) ([]Session, error) {
|
|
rows, err := r.db.QueryContext(ctx, `SELECT id, name, model_name, summary, created_at, updated_at FROM sessions ORDER BY created_at DESC`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query sessions: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var sessions []Session
|
|
for rows.Next() {
|
|
var s Session
|
|
if err := rows.Scan(&s.ID, &s.Name, &s.ModelName, &s.Summary, &s.CreatedAt, &s.UpdatedAt); err != nil {
|
|
return nil, fmt.Errorf("scan session: %w", err)
|
|
}
|
|
sessions = append(sessions, s)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate sessions: %w", err)
|
|
}
|
|
|
|
return sessions, nil
|
|
}
|
|
|
|
// GetMessages returns messages for a given session ordered by creation time.
|
|
func (r *Repository) GetMessages(ctx context.Context, sessionID int64) ([]Message, error) {
|
|
if sessionID <= 0 {
|
|
return nil, errors.New("sessionID must be positive")
|
|
}
|
|
rows, err := r.db.QueryContext(ctx, `SELECT id, session_id, role, content, token_count, created_at FROM messages WHERE session_id = ? ORDER BY created_at ASC`, sessionID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query messages: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var messages []Message
|
|
for rows.Next() {
|
|
var m Message
|
|
if err := rows.Scan(&m.ID, &m.SessionID, &m.Role, &m.Content, &m.TokenCount, &m.CreatedAt); err != nil {
|
|
return nil, fmt.Errorf("scan message: %w", err)
|
|
}
|
|
messages = append(messages, m)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate messages: %w", err)
|
|
}
|
|
|
|
return messages, nil
|
|
}
|
|
|
|
// UpdateSessionSummary updates the summary column for a session.
|
|
func (r *Repository) UpdateSessionSummary(ctx context.Context, sessionID int64, summary string) error {
|
|
if sessionID <= 0 {
|
|
return errors.New("sessionID must be positive")
|
|
}
|
|
|
|
_, err := r.db.ExecContext(ctx, `UPDATE sessions SET summary = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, summary, sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("update session summary: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UpdateSessionName updates the name of an existing session.
|
|
func (r *Repository) UpdateSessionName(ctx context.Context, sessionID int64, name string) error {
|
|
if r == nil {
|
|
return errors.New("repository is nil")
|
|
}
|
|
if sessionID <= 0 {
|
|
return errors.New("sessionID must be positive")
|
|
}
|
|
trimmed := strings.TrimSpace(name)
|
|
if trimmed == "" {
|
|
return errors.New("session name cannot be empty")
|
|
}
|
|
|
|
_, err := r.db.ExecContext(ctx, `UPDATE sessions SET name = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, trimmed, sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("update session name: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|