package storage

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"time"
)

// Session is the in-memory form of one row of the sessions table.
type Session struct {
	ID        int64
	Name      string
	ModelName string
	Summary   sql.NullString // scanned from the summary column; Valid is false for NULL
	CreatedAt time.Time
	UpdatedAt time.Time
}

// GetSessionByName looks up the session whose name column equals name.
//
// A missing row is not treated as an error: the method returns (nil, nil),
// so callers must check the returned pointer as well as the error. A nil
// receiver, or a name that is empty once surrounding whitespace is removed,
// produces an error before any query is issued. Note that the whitespace
// trim is used for validation only; the query is run with name as given.
func (r *Repository) GetSessionByName(ctx context.Context, name string) (*Session, error) {
	switch {
	case r == nil:
		return nil, errors.New("repository is nil")
	case strings.TrimSpace(name) == "":
		return nil, errors.New("session name cannot be empty")
	}

	const query = `SELECT id, name, model_name, summary, created_at, updated_at FROM sessions WHERE name = ?`

	found := new(Session)
	err := r.db.QueryRowContext(ctx, query, name).Scan(
		&found.ID,
		&found.Name,
		&found.ModelName,
		&found.Summary,
		&found.CreatedAt,
		&found.UpdatedAt,
	)
	if errors.Is(err, sql.ErrNoRows) {
		// No matching session: report "not found" as a nil result.
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("get session by name: %w", err)
	}
	return found, nil
}

// Message is the in-memory form of one row of the messages table.
type Message struct {
	ID         int64
	SessionID  int64
	Role       string
	Content    string
	TokenCount sql.NullInt64 // scanned from the token_count column; Valid is false for NULL
	CreatedAt  time.Time
}

// Repository groups the session and message persistence helpers around a
// single *sql.DB handle.
type Repository struct {
	db *sql.DB
}

// NewRepository wraps db in a Repository. It returns an error instead of a
// Repository when db is nil.
func NewRepository(db *sql.DB) (*Repository, error) {
	if db != nil {
		return &Repository{db: db}, nil
	}
	return nil, errors.New("db cannot be nil")
}

// CreateSession inserts a new session record.
func (r *Repository) CreateSession(ctx context.Context, name, model string) (int64, error) { if r == nil { return 0, errors.New("repository is nil") } if name == "" { return 0, errors.New("session name cannot be empty") } query := `INSERT INTO sessions (name, model_name) VALUES (?, ?)` res, err := r.db.ExecContext(ctx, query, name, model) if err != nil { return 0, fmt.Errorf("insert session: %w", err) } id, err := res.LastInsertId() if err != nil { return 0, fmt.Errorf("fetch session id: %w", err) } return id, nil } // AddMessage persists a message linked to a session. func (r *Repository) AddMessage(ctx context.Context, sessionID int64, role, content string, tokens *int64) (int64, error) { if r == nil { return 0, errors.New("repository is nil") } if sessionID <= 0 { return 0, errors.New("sessionID must be positive") } if role == "" || content == "" { return 0, errors.New("role and content must be provided") } var tokenVal interface{} if tokens != nil { tokenVal = *tokens } else { tokenVal = nil } query := `INSERT INTO messages (session_id, role, content, token_count) VALUES (?, ?, ?, ?)` res, err := r.db.ExecContext(ctx, query, sessionID, role, content, tokenVal) if err != nil { return 0, fmt.Errorf("insert message: %w", err) } id, err := res.LastInsertId() if err != nil { return 0, fmt.Errorf("fetch message id: %w", err) } return id, nil } // ListSessions returns all sessions ordered by creation time descending. 
func (r *Repository) ListSessions(ctx context.Context) ([]Session, error) { rows, err := r.db.QueryContext(ctx, `SELECT id, name, model_name, summary, created_at, updated_at FROM sessions ORDER BY created_at DESC`) if err != nil { return nil, fmt.Errorf("query sessions: %w", err) } defer rows.Close() var sessions []Session for rows.Next() { var s Session if err := rows.Scan(&s.ID, &s.Name, &s.ModelName, &s.Summary, &s.CreatedAt, &s.UpdatedAt); err != nil { return nil, fmt.Errorf("scan session: %w", err) } sessions = append(sessions, s) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("iterate sessions: %w", err) } return sessions, nil } // GetMessages returns messages for a given session ordered by creation time. func (r *Repository) GetMessages(ctx context.Context, sessionID int64) ([]Message, error) { if sessionID <= 0 { return nil, errors.New("sessionID must be positive") } rows, err := r.db.QueryContext(ctx, `SELECT id, session_id, role, content, token_count, created_at FROM messages WHERE session_id = ? ORDER BY created_at ASC`, sessionID) if err != nil { return nil, fmt.Errorf("query messages: %w", err) } defer rows.Close() var messages []Message for rows.Next() { var m Message if err := rows.Scan(&m.ID, &m.SessionID, &m.Role, &m.Content, &m.TokenCount, &m.CreatedAt); err != nil { return nil, fmt.Errorf("scan message: %w", err) } messages = append(messages, m) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("iterate messages: %w", err) } return messages, nil } // UpdateSessionSummary updates the summary column for a session. 
func (r *Repository) UpdateSessionSummary(ctx context.Context, sessionID int64, summary string) error { if sessionID <= 0 { return errors.New("sessionID must be positive") } _, err := r.db.ExecContext(ctx, `UPDATE sessions SET summary = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, summary, sessionID) if err != nil { return fmt.Errorf("update session summary: %w", err) } return nil } // UpdateSessionName updates the name of an existing session. func (r *Repository) UpdateSessionName(ctx context.Context, sessionID int64, name string) error { if r == nil { return errors.New("repository is nil") } if sessionID <= 0 { return errors.New("sessionID must be positive") } trimmed := strings.TrimSpace(name) if trimmed == "" { return errors.New("session name cannot be empty") } _, err := r.db.ExecContext(ctx, `UPDATE sessions SET name = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, trimmed, sessionID) if err != nil { return fmt.Errorf("update session name: %w", err) } return nil }