package openai import ( "bufio" "bytes" "context" "encoding/json" "errors" "fmt" "io" "net/http" "strings" "time" ) const defaultTimeout = 30 * time.Second // Client wraps HTTP access to the OpenAI-compatible Chat Completions API. type Client struct { apiKey string baseURL string httpClient *http.Client } // ClientOption customizes client construction. type ClientOption func(*Client) // WithHTTPClient overrides the default HTTP client. func WithHTTPClient(hc *http.Client) ClientOption { return func(c *Client) { c.httpClient = hc } } // WithBaseURL overrides the default base URL. func WithBaseURL(url string) ClientOption { return func(c *Client) { c.baseURL = url } } // NewClient creates a Client with the provided API key and options. func NewClient(apiKey string, opts ...ClientOption) (*Client, error) { apiKey = strings.TrimSpace(apiKey) if apiKey == "" { return nil, errors.New("api key cannot be empty") } client := &Client{ apiKey: apiKey, baseURL: "https://api.openai.com/v1", httpClient: &http.Client{ Timeout: defaultTimeout, }, } for _, opt := range opts { opt(client) } return client, nil } // CreateChatCompletion issues a chat completion request. func (c *Client) CreateChatCompletion(ctx context.Context, req ChatCompletionRequest) (*ChatCompletionResponse, error) { if c == nil { return nil, errors.New("client is nil") } payload, err := json.Marshal(req) if err != nil { return nil, fmt.Errorf("encode request: %w", err) } httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/chat/completions", bytes.NewReader(payload)) if err != nil { return nil, fmt.Errorf("create request: %w", err) } httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Authorization", "Bearer "+c.apiKey) resp, err := c.httpClient.Do(httpReq) if err != nil { return nil, fmt.Errorf("execute request: %w", err) } defer resp.Body.Close() if resp.StatusCode >= 200 && resp.StatusCode < 300 { return decodeSuccess(resp.Body) } return nil, decodeError(resp.Body, resp.StatusCode) } // StreamChatCompletion issues a streaming chat completion request and invokes handler for each chunk. func (c *Client) StreamChatCompletion(ctx context.Context, req ChatCompletionRequest, handler ChatCompletionStreamHandler) (*ChatCompletionResponse, error) { if c == nil { return nil, errors.New("client is nil") } req.Stream = true payload, err := json.Marshal(req) if err != nil { return nil, fmt.Errorf("encode request: %w", err) } httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/chat/completions", bytes.NewReader(payload)) if err != nil { return nil, fmt.Errorf("create request: %w", err) } httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Authorization", "Bearer "+c.apiKey) httpReq.Header.Set("Accept", "text/event-stream") resp, err := c.httpClient.Do(httpReq) if err != nil { return nil, fmt.Errorf("execute request: %w", err) } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { return nil, decodeError(resp.Body, resp.StatusCode) } scanner := bufio.NewScanner(resp.Body) scanner.Buffer(make([]byte, 0, 64*1024), 1<<20) type streamChunk struct { ID string `json:"id"` Object string `json:"object"` Choices []struct { Index int `json:"index"` Message ChatMessage `json:"message"` Delta ChatMessage `json:"delta"` FinishReason string `json:"finish_reason"` } `json:"choices"` Usage Usage `json:"usage"` } var aggregated ChatCompletionResponse var builder strings.Builder role := "assistant" finish := "" var usage Usage usageReceived := false for scanner.Scan() { line := scanner.Text() line = strings.TrimSpace(line) if line == "" { continue } if !strings.HasPrefix(line, "data:") { continue } payloadLine := strings.TrimSpace(strings.TrimPrefix(line, "data:")) if payloadLine == "" { continue } if payloadLine == "[DONE]" { if handler != nil { if err := handler(ChatCompletionStreamEvent{Done: true}); err != nil { return nil, err } } break } var chunk streamChunk if err := json.Unmarshal([]byte(payloadLine), &chunk); err != nil { return nil, fmt.Errorf("decode stream response: %w", err) } if aggregated.ID == "" { aggregated.ID = chunk.ID } if aggregated.Object == "" { aggregated.Object = chunk.Object } if chunk.Usage != (Usage{}) { usage = chunk.Usage usageReceived = true } var chunkText string finishReason := "" if len(chunk.Choices) > 0 { choice := chunk.Choices[0] if choice.Message.Role != "" { role = choice.Message.Role } if choice.Delta.Role != "" { role = choice.Delta.Role } if choice.Delta.Content != "" { chunkText = choice.Delta.Content } else if choice.Message.Content != "" && builder.Len() == 0 { chunkText = choice.Message.Content } if choice.Message.Content != "" && builder.Len() == 0 && chunkText == "" { chunkText = choice.Message.Content } if choice.FinishReason != "" { finishReason = choice.FinishReason finish = choice.FinishReason } } if chunkText != "" { builder.WriteString(chunkText) } if handler != nil { event := ChatCompletionStreamEvent{ ID: chunk.ID, Role: role, Content: chunkText, FinishReason: finishReason, } if usageReceived { u := usage event.Usage = &u } if err := handler(event); err != nil { return nil, err } } } if err := scanner.Err(); err != nil { return nil, fmt.Errorf("read stream: %w", err) } content := strings.TrimSpace(builder.String()) if content == "" { return nil, errors.New("stream response contained no content") } aggregated.Choices = []ChatCompletionChoice{{ Index: 0, Message: ChatMessage{ Role: role, Content: content, }, FinishReason: finish, }} if usageReceived { aggregated.Usage = usage } return &aggregated, nil } func decodeSuccess(r io.Reader) (*ChatCompletionResponse, error) { data, err := io.ReadAll(r) if err != nil { return nil, fmt.Errorf("read response: %w", err) } var response ChatCompletionResponse if err := json.Unmarshal(data, &response); err != nil { trimmed := bytes.TrimSpace(data) if len(trimmed) == 0 { return nil, fmt.Errorf("decode response: %w", err) } return &ChatCompletionResponse{ Choices: []ChatCompletionChoice{{ Message: ChatMessage{Role: "assistant", Content: string(trimmed)}, }}, }, nil } return &response, nil } func decodeError(r io.Reader, status int) error { var apiErr ErrorResponse if err := json.NewDecoder(r).Decode(&apiErr); err != nil { return fmt.Errorf("api error (status %d): failed to decode body: %w", status, err) } return fmt.Errorf("api error (status %d): %s", status, apiErr.Error.Message) }