From a3e6b105d023ea7f68c15aa728baefed7288b074 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stig-=C3=98rjan=20Smelror?= Date: Wed, 1 Oct 2025 20:09:40 +0200 Subject: [PATCH] Enable streaming UI updates for chat responses --- .gitignore | 1 + internal/app/app.go | 242 ++++++++++++++++++++++++++--------- internal/chat/service.go | 22 +++- internal/openai/client.go | 259 ++++++++++++++++++++++++-------------- internal/openai/types.go | 13 ++ 5 files changed, 375 insertions(+), 162 deletions(-) diff --git a/.gitignore b/.gitignore index 7227b2b..a60c844 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ *.dll *.so *.dylib +/goaichat # Test binary, built with `go test -c` *.test diff --git a/internal/app/app.go b/internal/app/app.go index 0960dc0..c2113bf 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -18,14 +18,16 @@ import ( // App encapsulates the Goaichat application runtime wiring. type App struct { - logger *slog.Logger - config *config.Config - openAI *openai.Client - chat *chat.Service - store *storage.Manager - repo *storage.Repository - input io.Reader - output io.Writer + logger *slog.Logger + config *config.Config + openAI *openai.Client + chat *chat.Service + store *storage.Manager + repo *storage.Repository + input io.Reader + output io.Writer + status string + streamBuffer strings.Builder } // New constructs a new App instance. @@ -152,12 +154,10 @@ func (a *App) initStorage() error { func (a *App) runCLILoop(ctx context.Context) error { scanner := bufio.NewScanner(a.input) - if _, err := fmt.Fprintln(a.output, "Type your message. Use /exit to quit, /reset to clear history."); err != nil { - return err - } + a.setStatus("Type your message. Use /exit to quit, /reset to clear history.") for { - if _, err := fmt.Fprint(a.output, "> "); err != nil { + if err := a.renderUI(); err != nil { return err } @@ -169,12 +169,11 @@ func (a *App) runCLILoop(ctx context.Context) error { } input := scanner.Text() + a.setStatus("") handled, exit, err := a.handleCommand(ctx, input) if err != nil { - if _, writeErr := fmt.Fprintf(a.output, "Command error: %v\n", err); writeErr != nil { - return writeErr - } + a.setStatus("Command error: %v", err) continue } if handled { @@ -184,17 +183,25 @@ func (a *App) runCLILoop(ctx context.Context) error { continue } - reply, err := a.chat.Send(ctx, input) - if err != nil { - if _, writeErr := fmt.Fprintf(a.output, "Error: %v\n", err); writeErr != nil { - return writeErr + streamEnabled := a.chat != nil && a.chat.StreamingEnabled() + a.clearStreamingContent() + var handler openai.ChatCompletionStreamHandler + if streamEnabled { + handler = func(event openai.ChatCompletionStreamEvent) error { + if event.Done || event.Content == "" { + return nil + } + a.appendStreamingContent(event.Content) + return nil } + } + if _, err := a.chat.Send(ctx, input, handler); err != nil { + a.setStatus("Error: %v", err) continue } + a.clearStreamingContent() - if _, err := fmt.Fprintf(a.output, "AI: %s\n", reply); err != nil { - return err - } + a.setStatus("") if err := a.maybeSuggestSessionName(ctx); err != nil { a.logger.WarnContext(ctx, "session name suggestion failed", "error", err) @@ -213,66 +220,66 @@ func (a *App) handleCommand(ctx context.Context, input string) (handled bool, ex if strings.HasPrefix(trimmed, "/rename") { parts := strings.Fields(trimmed) if len(parts) < 2 { - _, err := fmt.Fprintln(a.output, "Usage: /rename ") - return true, false, err + a.setStatus("Usage: /rename ") + return true, false, nil } rawName := strings.Join(parts[1:], " ") normalized := chat.NormalizeSessionName(rawName) if normalized == "" { - _, err := fmt.Fprintln(a.output, "Session name cannot be empty.") - return true, false, err + a.setStatus("Session name cannot be empty.") + return true, false, nil } if a.repo != nil { existing, fetchErr := a.repo.GetSessionByName(ctx, normalized) if fetchErr != nil { - _, err := fmt.Fprintf(a.output, "Failed to verify name availability: %v\n", fetchErr) - return true, false, err + a.setStatus("Failed to verify name availability: %v", fetchErr) + return true, false, nil } if existing != nil { currentID := a.chat.CurrentSessionID() if currentID == 0 || existing.ID != currentID { - _, err := fmt.Fprintf(a.output, "Session name %q is already in use.\n", normalized) - return true, false, err + a.setStatus("Session name %q is already in use.", normalized) + return true, false, nil } } } applied, setErr := a.chat.SetSessionName(ctx, rawName) if setErr != nil { - _, err := fmt.Fprintf(a.output, "Failed to rename session: %v\n", setErr) - return true, false, err + a.setStatus("Failed to rename session: %v", setErr) + return true, false, nil } - _, err := fmt.Fprintf(a.output, "Session renamed to %q.\n", applied) - return true, false, err + a.setStatus("Session renamed to %q.", applied) + return true, false, nil } if strings.HasPrefix(trimmed, "/open") { parts := strings.Fields(trimmed) if len(parts) != 2 { - _, err := fmt.Fprintln(a.output, "Usage: /open ") - return true, false, err + a.setStatus("Usage: /open ") + return true, false, nil } if a.repo == nil { - _, err := fmt.Fprintln(a.output, "Storage not initialised; cannot open sessions.") - return true, false, err + a.setStatus("Storage not initialised; cannot open sessions.") + return true, false, nil } session, fetchErr := a.repo.GetSessionByName(ctx, parts[1]) if fetchErr != nil { - _, err := fmt.Fprintf(a.output, "Failed to fetch session: %v\n", fetchErr) - return true, false, err + a.setStatus("Failed to fetch session: %v", fetchErr) + return true, false, nil } if session == nil { - _, err := fmt.Fprintf(a.output, "Session %q not found.\n", parts[1]) - return true, false, err + a.setStatus("Session %q not found.", parts[1]) + return true, false, nil } messages, msgErr := a.repo.GetMessages(ctx, session.ID) if msgErr != nil { - _, err := fmt.Fprintf(a.output, "Failed to load messages: %v\n", msgErr) - return true, false, err + a.setStatus("Failed to load messages: %v", msgErr) + return true, false, nil } chatMessages := make([]openai.ChatMessage, 0, len(messages)) @@ -287,8 +294,8 @@ func (a *App) handleCommand(ctx context.Context, input string) (handled bool, ex summaryPresent := session.Summary.Valid && strings.TrimSpace(session.Summary.String) != "" a.chat.RestoreSession(session.ID, session.Name, chatMessages, summaryPresent) - _, err := fmt.Fprintf(a.output, "Loaded session %q with %d messages.\n", session.Name, len(chatMessages)) - return true, false, err + a.setStatus("Loaded session %s with %d messages.", session.Name, len(chatMessages)) + return true, false, nil } switch trimmed { @@ -296,38 +303,38 @@ func (a *App) handleCommand(ctx context.Context, input string) (handled bool, ex return true, true, nil case "/reset": a.chat.Reset() - _, err := fmt.Fprintln(a.output, "History cleared.") - return true, false, err + a.setStatus("History cleared.") + return true, false, nil case "/list": if a.repo == nil { - _, err := fmt.Fprintln(a.output, "History commands unavailable (storage not initialised).") - return true, false, err + a.setStatus("History commands unavailable (storage not initialised).") + return true, false, nil } sessions, listErr := a.repo.ListSessions(ctx) if listErr != nil { - _, err := fmt.Fprintf(a.output, "Failed to list sessions: %v\n", listErr) - return true, false, err + a.setStatus("Failed to list sessions: %v", listErr) + return true, false, nil } if len(sessions) == 0 { - _, err := fmt.Fprintln(a.output, "No saved sessions.") - return true, false, err + a.setStatus("No saved sessions.") + return true, false, nil } + var lines []string for _, session := range sessions { summary := session.Summary.String if summary == "" { summary = "(no summary)" } - if _, err := fmt.Fprintf(a.output, "- %s [%s]: %s\n", session.Name, session.ModelName, summary); err != nil { - return true, false, err - } + lines = append(lines, fmt.Sprintf("- %s [%s]: %s", session.Name, session.ModelName, summary)) } + a.setStatus("%s", strings.Join(lines, "\n")) return true, false, nil case "/help": - _, err := fmt.Fprintln(a.output, "Commands: /exit, /reset, /list, /open , /rename , /help (more coming soon)") - return true, false, err + a.setStatus("Commands: /exit, /reset, /list, /open , /rename , /help (more coming soon)") + return true, false, nil default: - _, err := fmt.Fprintf(a.output, "Unknown command %q. Try /help.\n", trimmed) - return true, false, err + a.setStatus("Unknown command %q. Try /help.", trimmed) + return true, false, nil } } @@ -346,9 +353,120 @@ func (a *App) maybeSuggestSessionName(ctx context.Context) error { } a.chat.MarkSessionNameSuggested() - if _, err := fmt.Fprintf(a.output, "Suggested session name: %s\nUse /rename %s to apply it now.\n", suggestion, suggestion); err != nil { + a.setStatus("Suggested session name: %s\nUse /rename %s to apply it now.", suggestion, suggestion) + + return nil +} + +func (a *App) setStatus(msg string, args ...any) { + if a == nil { + return + } + if msg == "" { + a.status = "" + return + } + if len(args) == 0 { + a.status = msg + return + } + a.status = fmt.Sprintf(msg, args...) +} + +func (a *App) renderUI() error { + if a == nil { + return errors.New("app is nil") + } + if _, err := fmt.Fprint(a.output, "\033[2J\033[H"); err != nil { + return err + } + + sessionName := "(unnamed session)" + if a.chat != nil && strings.TrimSpace(a.chat.SessionName()) != "" { + sessionName = a.chat.SessionName() + } + + title := fmt.Sprintf("goaichat - %s", sessionName) + underline := strings.Repeat("=", len(title)) + if _, err := fmt.Fprintf(a.output, "%s\n%s\n\n", title, underline); err != nil { + return err + } + + if _, err := fmt.Fprintln(a.output, "Conversation"); err != nil { + return err + } + if _, err := fmt.Fprintln(a.output, strings.Repeat("-", len("Conversation"))); err != nil { + return err + } + + if a.chat != nil { + history := a.chat.History() + for _, msg := range history { + label := roleLabel(msg.Role) + if _, err := fmt.Fprintf(a.output, "%s: %s\n", label, msg.Content); err != nil { + return err + } + } + } + + if a.streamBuffer.Len() > 0 { + if _, err := fmt.Fprintf(a.output, "AI: %s\n", a.streamBuffer.String()); err != nil { + return err + } + } + + if _, err := fmt.Fprintln(a.output); err != nil { + return err + } + + if status := strings.TrimSpace(a.status); status != "" { + if _, err := fmt.Fprintln(a.output, "Status"); err != nil { + return err + } + if _, err := fmt.Fprintln(a.output, strings.Repeat("-", len("Status"))); err != nil { + return err + } + if _, err := fmt.Fprintf(a.output, "%s\n\n", status); err != nil { + return err + } + } + + if _, err := fmt.Fprint(a.output, "> "); err != nil { return err } return nil } + +func roleLabel(role string) string { + switch role { + case "assistant": + return "AI" + case "user": + return "You" + case "system": + return "System" + default: + if strings.TrimSpace(role) == "" { + return "Unknown" + } + return strings.Title(role) + } +} + +func (a *App) appendStreamingContent(chunk string) { + if a == nil || chunk == "" { + return + } + a.streamBuffer.WriteString(chunk) + if err := a.renderUI(); err != nil && a.logger != nil { + a.logger.Warn("stream render failed", "error", err) + } +} + +func (a *App) clearStreamingContent() { + if a == nil { + return + } + a.streamBuffer.Reset() +} diff --git a/internal/chat/service.go b/internal/chat/service.go index c4ce4b9..dc7e11a 100644 --- a/internal/chat/service.go +++ b/internal/chat/service.go @@ -18,6 +18,7 @@ const defaultSessionPrefix = "session-" // CompletionClient defines the subset of the OpenAI client used by the chat service. type CompletionClient interface { CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (*openai.ChatCompletionResponse, error) + StreamChatCompletion(ctx context.Context, req openai.ChatCompletionRequest, handler openai.ChatCompletionStreamHandler) (*openai.ChatCompletionResponse, error) } // MessageRepository defines persistence hooks used by the chat service. @@ -67,8 +68,9 @@ func NewService(logger *slog.Logger, modelCfg config.ModelConfig, client Complet }, nil } -// Send submits a user message and returns the assistant reply. -func (s *Service) Send(ctx context.Context, input string) (string, error) { +// Send submits a user message and returns the assistant reply. When streamHandler is provided and streaming is enabled, +// partial responses are forwarded to the handler as they arrive. +func (s *Service) Send(ctx context.Context, input string, streamHandler openai.ChatCompletionStreamHandler) (string, error) { if s == nil { return "", errors.New("service is nil") } @@ -99,7 +101,13 @@ func (s *Service) Send(ctx context.Context, input string) (string, error) { s.logger.DebugContext(ctx, "sending chat completion", "model", s.model, "message_count", len(messages)) - resp, err := s.client.CreateChatCompletion(ctx, req) + var resp *openai.ChatCompletionResponse + var err error + if s.stream { + resp, err = s.client.StreamChatCompletion(ctx, req, streamHandler) + } else { + resp, err = s.client.CreateChatCompletion(ctx, req) + } if err != nil { return "", err } @@ -127,6 +135,14 @@ func (s *Service) History() []openai.ChatMessage { return historyCopy } +// StreamingEnabled reports whether streaming completions are configured for this service. +func (s *Service) StreamingEnabled() bool { + if s == nil { + return false + } + return s.stream +} + // Reset clears the in-memory conversation history. func (s *Service) Reset() { if s == nil { diff --git a/internal/openai/client.go b/internal/openai/client.go index 739c2c9..b6c8883 100644 --- a/internal/openai/client.go +++ b/internal/openai/client.go @@ -87,15 +87,174 @@ func (c *Client) CreateChatCompletion(ctx context.Context, req ChatCompletionReq defer resp.Body.Close() if resp.StatusCode >= 200 && resp.StatusCode < 300 { - if req.Stream { - return decodeStream(resp.Body) - } return decodeSuccess(resp.Body) } return nil, decodeError(resp.Body, resp.StatusCode) } +// StreamChatCompletion issues a streaming chat completion request and invokes handler for each chunk. +func (c *Client) StreamChatCompletion(ctx context.Context, req ChatCompletionRequest, handler ChatCompletionStreamHandler) (*ChatCompletionResponse, error) { + if c == nil { + return nil, errors.New("client is nil") + } + + req.Stream = true + payload, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("encode request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/chat/completions", bytes.NewReader(payload)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", "Bearer "+c.apiKey) + httpReq.Header.Set("Accept", "text/event-stream") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("execute request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, decodeError(resp.Body, resp.StatusCode) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, 64*1024), 1<<20) + + type streamChunk struct { + ID string `json:"id"` + Object string `json:"object"` + Choices []struct { + Index int `json:"index"` + Message ChatMessage `json:"message"` + Delta ChatMessage `json:"delta"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage Usage `json:"usage"` + } + + var aggregated ChatCompletionResponse + var builder strings.Builder + role := "assistant" + finish := "" + var usage Usage + usageReceived := false + + for scanner.Scan() { + line := scanner.Text() + line = strings.TrimSpace(line) + if line == "" { + continue + } + if !strings.HasPrefix(line, "data:") { + continue + } + + payloadLine := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if payloadLine == "" { + continue + } + if payloadLine == "[DONE]" { + if handler != nil { + if err := handler(ChatCompletionStreamEvent{Done: true}); err != nil { + return nil, err + } + } + break + } + + var chunk streamChunk + if err := json.Unmarshal([]byte(payloadLine), &chunk); err != nil { + return nil, fmt.Errorf("decode stream response: %w", err) + } + + if aggregated.ID == "" { + aggregated.ID = chunk.ID + } + if aggregated.Object == "" { + aggregated.Object = chunk.Object + } + + if chunk.Usage != (Usage{}) { + usage = chunk.Usage + usageReceived = true + } + + var chunkText string + finishReason := "" + if len(chunk.Choices) > 0 { + choice := chunk.Choices[0] + if choice.Message.Role != "" { + role = choice.Message.Role + } + if choice.Delta.Role != "" { + role = choice.Delta.Role + } + if choice.Delta.Content != "" { + chunkText = choice.Delta.Content + } else if choice.Message.Content != "" && builder.Len() == 0 { + chunkText = choice.Message.Content + } + if choice.Message.Content != "" && builder.Len() == 0 && chunkText == "" { + chunkText = choice.Message.Content + } + if choice.FinishReason != "" { + finishReason = choice.FinishReason + finish = choice.FinishReason + } + } + + if chunkText != "" { + builder.WriteString(chunkText) + } + + if handler != nil { + event := ChatCompletionStreamEvent{ + ID: chunk.ID, + Role: role, + Content: chunkText, + FinishReason: finishReason, + } + if usageReceived { + u := usage + event.Usage = &u + } + if err := handler(event); err != nil { + return nil, err + } + } + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("read stream: %w", err) + } + + content := strings.TrimSpace(builder.String()) + if content == "" { + return nil, errors.New("stream response contained no content") + } + + aggregated.Choices = []ChatCompletionChoice{{ + Index: 0, + Message: ChatMessage{ + Role: role, + Content: content, + }, + FinishReason: finish, + }} + if usageReceived { + aggregated.Usage = usage + } + + return &aggregated, nil +} + func decodeSuccess(r io.Reader) (*ChatCompletionResponse, error) { data, err := io.ReadAll(r) if err != nil { @@ -118,100 +277,6 @@ func decodeSuccess(r io.Reader) (*ChatCompletionResponse, error) { return &response, nil } -func decodeStream(r io.Reader) (*ChatCompletionResponse, error) { - scanner := bufio.NewScanner(r) - var payloads []json.RawMessage - - for scanner.Scan() { - line := scanner.Text() - if strings.HasPrefix(line, "data: ") { - payload := strings.TrimPrefix(line, "data: ") - if payload == "[DONE]" { - break - } - payloads = append(payloads, json.RawMessage(payload)) - } - } - - if err := scanner.Err(); err != nil { - return nil, fmt.Errorf("read stream: %w", err) - } - - if len(payloads) == 0 { - return nil, errors.New("empty stream response") - } - - type streamChunk struct { - ID string `json:"id"` - Object string `json:"object"` - Choices []struct { - Index int `json:"index"` - Message ChatMessage `json:"message"` - Delta ChatMessage `json:"delta"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage Usage `json:"usage"` - } - - var aggregated ChatCompletionResponse - var builder strings.Builder - role := "assistant" - finish := "" - - for _, payload := range payloads { - var chunk streamChunk - if err := json.Unmarshal(payload, &chunk); err != nil { - return nil, fmt.Errorf("decode stream response: %w", err) - } - if aggregated.ID == "" { - aggregated.ID = chunk.ID - } - if aggregated.Object == "" { - aggregated.Object = chunk.Object - } - aggregated.Usage.PromptTokens += chunk.Usage.PromptTokens - aggregated.Usage.CompletionTokens += chunk.Usage.CompletionTokens - aggregated.Usage.TotalTokens += chunk.Usage.TotalTokens - - if len(chunk.Choices) == 0 { - continue - } - - choice := chunk.Choices[0] - if choice.Message.Role != "" { - role = choice.Message.Role - } - if choice.Delta.Role != "" { - role = choice.Delta.Role - } - if choice.Message.Content != "" { - builder.WriteString(choice.Message.Content) - } - if choice.Delta.Content != "" { - builder.WriteString(choice.Delta.Content) - } - if choice.FinishReason != "" { - finish = choice.FinishReason - } - } - - content := strings.TrimSpace(builder.String()) - if content == "" { - return nil, errors.New("stream response contained no content") - } - - aggregated.Choices = []ChatCompletionChoice{{ - Index: 0, - FinishReason: finish, - Message: ChatMessage{ - Role: role, - Content: content, - }, - }} - - return &aggregated, nil -} - func decodeError(r io.Reader, status int) error { var apiErr ErrorResponse if err := json.NewDecoder(r).Decode(&apiErr); err != nil { diff --git a/internal/openai/types.go b/internal/openai/types.go index 3a6314c..9e4ccf2 100644 --- a/internal/openai/types.go +++ b/internal/openai/types.go @@ -37,6 +37,19 @@ type ChatCompletionResponse struct { Usage Usage `json:"usage"` } +// ChatCompletionStreamEvent represents a single chunk in a streaming chat completion. +type ChatCompletionStreamEvent struct { + ID string `json:"-"` + Role string `json:"-"` + Content string `json:"-"` + FinishReason string `json:"-"` + Usage *Usage `json:"-"` + Done bool `json:"-"` +} + +// ChatCompletionStreamHandler consumes streaming completion events. +type ChatCompletionStreamHandler func(ChatCompletionStreamEvent) error + // APIError captures structured error responses returned by the API. type APIError struct { Message string `json:"message"`