From 5261010902086d46dafd511551eb62470926aa69 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Wed, 3 Dec 2025 15:23:46 +0200 Subject: [PATCH 1/9] removed unused func --- cmd/src/login.go | 41 ++ internal/oauthdevice/device_flow.go | 304 ++++++++++++++ internal/oauthdevice/device_flow_test.go | 509 +++++++++++++++++++++++ 3 files changed, 854 insertions(+) create mode 100644 internal/oauthdevice/device_flow.go create mode 100644 internal/oauthdevice/device_flow_test.go diff --git a/cmd/src/login.go b/cmd/src/login.go index ab5a097c71..52d881cc58 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -28,6 +28,10 @@ Examples: Authenticate to Sourcegraph.com: $ src login https://sourcegraph.com + + Use OAuth device flow to authenticate: + + $ src login --device-flow https://sourcegraph.com ` flagSet := flag.NewFlagSet("login", flag.ExitOnError) @@ -122,6 +126,43 @@ func loginCmd(ctx context.Context, cfg *config, client api.Client, endpointArg s } fmt.Fprintln(out) fmt.Fprintf(out, "✔️ Authenticated as %s on %s\n", result.CurrentUser.Username, endpointArg) + + if p.useDeviceFlow { + fmt.Fprintln(out) + fmt.Fprintf(out, "To use this access token, set the following environment variables in your terminal:\n\n") + fmt.Fprintf(out, " export SRC_ENDPOINT=%s\n", endpointArg) + fmt.Fprintf(out, " export SRC_ACCESS_TOKEN=%s\n", cfg.AccessToken) + } + fmt.Fprintln(out) return nil } + +func runDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauthdevice.Client) (string, error) { + authResp, err := client.Start(ctx, endpoint, nil) + if err != nil { + return "", err + } + + fmt.Fprintln(out) + fmt.Fprintf(out, "To authenticate, visit %s and enter the code: %s\n", authResp.VerificationURI, authResp.UserCode) + if authResp.VerificationURIComplete != "" { + fmt.Fprintln(out) + fmt.Fprintf(out, "Alternatively, you can open: %s\n", authResp.VerificationURIComplete) + } + fmt.Fprintln(out) + fmt.Fprint(out, "Waiting for authorization...") + defer fmt.Fprintf(out, "DONE\n\n") + + interval := time.Duration(authResp.Interval) * time.Second + if interval <= 0 { + interval = 5 * time.Second + } + + tokenResp, err := client.Poll(ctx, endpoint, authResp.DeviceCode, interval, authResp.ExpiresIn) + if err != nil { + return "", err + } + + return tokenResp.AccessToken, nil +} diff --git a/internal/oauthdevice/device_flow.go b/internal/oauthdevice/device_flow.go new file mode 100644 index 0000000000..851779cc3a --- /dev/null +++ b/internal/oauthdevice/device_flow.go @@ -0,0 +1,304 @@ +// Package oauthdevice implements the OAuth 2.0 Device Authorization Grant (RFC 8628) +// for authenticating with Sourcegraph instances. +package oauthdevice + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "testing" + "time" + + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +const ( + ClientID = "sgo_cid_sourcegraph-cli" + + wellKnownPath = "/.well-known/openid-configuration" + + GrantTypeDeviceCode string = "urn:ietf:params:oauth:grant-type:device_code" + + ScopeOpenID string = "openid" + ScopeProfile string = "profile" + ScopeEmail string = "email" + ScopeOfflineAccess string = "offline_access" + ScopeUserAll string = "user:all" +) + +var defaultScopes = []string{ScopeEmail, ScopeOfflineAccess, ScopeOpenID, ScopeProfile, ScopeUserAll} + +// OIDCConfiguration represents the relevant fields from the OpenID Connect +// Discovery document at /.well-known/openid-configuration +type OIDCConfiguration struct { + Issuer string `json:"issuer,omitempty"` + TokenEndpoint string `json:"token_endpoint,omitempty"` + DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint,omitempty"` +} + +type DeviceAuthResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete,omitempty"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +type TokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in,omitempty"` + Scope string `json:"scope,omitempty"` +} + +type ErrorResponse struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` +} + +type Client interface { + Discover(ctx context.Context, endpoint string) (*OIDCConfiguration, error) + Start(ctx context.Context, endpoint string, scopes []string) (*DeviceAuthResponse, error) + Poll(ctx context.Context, endpoint, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error) +} + +type httpClient struct { + client *http.Client + // cached OIDC configuration per endpoint + configCache map[string]*OIDCConfiguration +} + +func NewClient() Client { + return &httpClient{ + client: &http.Client{ + Timeout: 30 * time.Second, + }, + configCache: make(map[string]*OIDCConfiguration), + } +} + +func NewClientWithHTTPClient(c *http.Client) Client { + return &httpClient{ + client: c, + configCache: make(map[string]*OIDCConfiguration), + } +} + +// Discover fetches the openid-configuration which contains all the routes a client should +// use for authorization, device flows, tokens etc. +// +// Before making any requests, the configCache is checked and if there is a cache hit, the +// cached config is returned. +func (c *httpClient) Discover(ctx context.Context, endpoint string) (*OIDCConfiguration, error) { + endpoint = strings.TrimRight(endpoint, "/") + + if config, ok := c.configCache[endpoint]; ok { + return config, nil + } + + reqURL := endpoint + wellKnownPath + + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) + if err != nil { + return nil, errors.Wrap(err, "creating discovery request") + } + req.Header.Set("Accept", "application/json") + + resp, err := c.client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "discovery request failed") + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "reading discovery response") + } + + if resp.StatusCode != http.StatusOK { + return nil, errors.Newf("discovery failed with status %d: %s", resp.StatusCode, string(body)) + } + + var config OIDCConfiguration + if err := json.Unmarshal(body, &config); err != nil { + return nil, errors.Wrap(err, "parsing discovery response") + } + + c.configCache[endpoint] = &config + + return &config, nil +} + +// Start starts the OAuth device flow with the given endpoint. If no scopes are given the default scopes are used. +// +// Default Scopes: "openid" "profile" "email" "offline_access" "user:all" +func (c *httpClient) Start(ctx context.Context, endpoint string, scopes []string) (*DeviceAuthResponse, error) { + endpoint = strings.TrimRight(endpoint, "/") + + // Discover OIDC configuration + config, err := c.Discover(ctx, endpoint) + if err != nil { + return nil, errors.Wrap(err, "OIDC discovery failed") + } + + if config.DeviceAuthorizationEndpoint == "" { + return nil, errors.New("device authorization endpoint not found in OIDC configuration; the server may not support device flow") + } + + data := url.Values{} + data.Set("client_id", ClientID) + if len(scopes) > 0 { + data.Set("scope", strings.Join(scopes, " ")) + } else { + data.Set("scope", strings.Join(defaultScopes, " ")) + } + + req, err := http.NewRequestWithContext(ctx, "POST", config.DeviceAuthorizationEndpoint, strings.NewReader(data.Encode())) + if err != nil { + return nil, errors.Wrap(err, "creating device auth request") + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := c.client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "device auth request failed") + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "reading device auth response") + } + + if resp.StatusCode != http.StatusOK { + var errResp ErrorResponse + if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error != "" { + return nil, errors.Newf("device auth failed: %s: %s", errResp.Error, errResp.ErrorDescription) + } + return nil, errors.Newf("device auth failed with status %d: %s", resp.StatusCode, string(body)) + } + + var authResp DeviceAuthResponse + if err := json.Unmarshal(body, &authResp); err != nil { + return nil, errors.Wrap(err, "parsing device auth response") + } + + return &authResp, nil +} + +// Poll polls the OAuth token endpoint until the device has been authorized or not +// +// We poll as long as the authorization is pending. If the server tells us to slow down, we will wait 5 secs extra. +// +// Polling will stop when: +// - Device is authorized, and a token is returned +// - Device code has expried +// - User denied authorization +func (c *httpClient) Poll(ctx context.Context, endpoint, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error) { + endpoint = strings.TrimRight(endpoint, "/") + + // Discover OIDC configuration (should be cached from Start) + config, err := c.Discover(ctx, endpoint) + if err != nil { + return nil, errors.Wrap(err, "OIDC discovery failed") + } + + if config.TokenEndpoint == "" { + return nil, errors.New("token endpoint not found in OIDC configuration") + } + + deadline := time.Now().Add(time.Duration(expiresIn) * time.Second) + + for { + if time.Now().After(deadline) { + return nil, errors.New("device code expired") + } + + if !testing.Testing() { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(interval): + } + } + + tokenResp, err := c.pollOnce(ctx, config.TokenEndpoint, deviceCode) + if err != nil { + var pollErr *PollError + if errors.As(err, &pollErr) { + switch pollErr.Code { + case "authorization_pending": + continue + case "slow_down": + interval += 5 * time.Second + continue + case "expired_token": + return nil, errors.New("device code expired") + case "access_denied": + return nil, errors.New("authorization was denied by the user") + } + } + return nil, err + } + + return tokenResp, nil + } +} + +type PollError struct { + Code string + Description string +} + +func (e *PollError) Error() string { + if e.Description != "" { + return fmt.Sprintf("%s: %s", e.Code, e.Description) + } + return e.Code +} + +func (c *httpClient) pollOnce(ctx context.Context, tokenEndpoint, deviceCode string) (*TokenResponse, error) { + data := url.Values{} + data.Set("client_id", ClientID) + data.Set("device_code", deviceCode) + data.Set("grant_type", GrantTypeDeviceCode) + + req, err := http.NewRequestWithContext(ctx, "POST", tokenEndpoint, strings.NewReader(data.Encode())) + if err != nil { + return nil, errors.Wrap(err, "creating token request") + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := c.client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "token request failed") + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "reading token response") + } + + if resp.StatusCode != http.StatusOK { + var errResp ErrorResponse + if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error != "" { + return nil, &PollError{Code: errResp.Error, Description: errResp.ErrorDescription} + } + return nil, errors.Newf("token request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var tokenResp TokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, errors.Wrap(err, "parsing token response") + } + + return &tokenResp, nil +} diff --git a/internal/oauthdevice/device_flow_test.go b/internal/oauthdevice/device_flow_test.go new file mode 100644 index 0000000000..02b3923d88 --- /dev/null +++ b/internal/oauthdevice/device_flow_test.go @@ -0,0 +1,509 @@ +package oauthdevice + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" +) + +const ( + testDeviceAuthPath = "/device/code" + testTokenPath = "/token" +) + +type testServerOptions struct { + handlers map[string]http.HandlerFunc + wellKnownFunc func(w http.ResponseWriter, r *http.Request) +} + +func newTestServer(t *testing.T, opts testServerOptions) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case wellKnownPath: + if opts.wellKnownFunc != nil { + opts.wellKnownFunc(w, r) + } else { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(OIDCConfiguration{ + Issuer: "http://" + r.Host, + DeviceAuthorizationEndpoint: "http://" + r.Host + testDeviceAuthPath, + TokenEndpoint: "http://" + r.Host + testTokenPath, + }) + } + default: + if handler, ok := opts.handlers[r.URL.Path]; ok { + handler(w, r) + } else { + t.Errorf("unexpected path: %s", r.URL.Path) + http.Error(w, "not found", http.StatusNotFound) + } + } + })) +} + +func TestDiscover_Success(t *testing.T) { + server := newTestServer(t, testServerOptions{}) + defer server.Close() + + client := NewClient() + config, err := client.Discover(context.Background(), server.URL) + if err != nil { + t.Fatalf("Discover() error = %v", err) + } + + if config.DeviceAuthorizationEndpoint != server.URL+testDeviceAuthPath { + t.Errorf("DeviceAuthorizationEndpoint = %q, want %q", config.DeviceAuthorizationEndpoint, server.URL+testDeviceAuthPath) + } + if config.TokenEndpoint != server.URL+testTokenPath { + t.Errorf("TokenEndpoint = %q, want %q", config.TokenEndpoint, server.URL+testTokenPath) + } +} + +func TestDiscover_Caching(t *testing.T) { + var callCount int32 + server := newTestServer(t, testServerOptions{ + wellKnownFunc: func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&callCount, 1) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(OIDCConfiguration{ + DeviceAuthorizationEndpoint: "http://example.com/device", + TokenEndpoint: "http://example.com/token", + }) + }, + }) + defer server.Close() + + client := NewClient() + + // Populate the cache + _, err := client.Discover(context.Background(), server.URL) + if err != nil { + t.Fatalf("Discover() error = %v", err) + } + + // Second call should use cache + _, err = client.Discover(context.Background(), server.URL) + if err != nil { + t.Fatalf("Discover() error = %v", err) + } + + if atomic.LoadInt32(&callCount) != 1 { + t.Errorf("callCount = %d, want 1 (second call should use cache)", callCount) + } +} + +func TestDiscover_Error(t *testing.T) { + server := newTestServer(t, testServerOptions{ + wellKnownFunc: func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "not found", http.StatusNotFound) + }, + }) + defer server.Close() + + client := NewClient() + _, err := client.Discover(context.Background(), server.URL) + if err == nil { + t.Fatal("Discover() expected error, got nil") + } + + if !strings.Contains(err.Error(), "404") { + t.Errorf("error = %q, want to contain '404'", err.Error()) + } +} + +func TestStart_Success(t *testing.T) { + wantResponse := DeviceAuthResponse{ + DeviceCode: "test-device-code", + UserCode: "ABCD-1234", + VerificationURI: "https://example.com/device", + VerificationURIComplete: "https://example.com/device?user_code=ABCD-1234", + ExpiresIn: 1800, + Interval: 5, + } + + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testDeviceAuthPath: func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("unexpected method: %s", r.Method) + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + if got := r.FormValue("client_id"); got != ClientID { + t.Errorf("unexpected client_id: got %q, want %q", got, ClientID) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(wantResponse) + }, + }, + }) + defer server.Close() + + client := NewClient() + resp, err := client.Start(context.Background(), server.URL, nil) + if err != nil { + t.Fatalf("Start() error = %v", err) + } + + if resp.DeviceCode != wantResponse.DeviceCode { + t.Errorf("DeviceCode = %q, want %q", resp.DeviceCode, wantResponse.DeviceCode) + } + if resp.UserCode != wantResponse.UserCode { + t.Errorf("UserCode = %q, want %q", resp.UserCode, wantResponse.UserCode) + } + if resp.VerificationURI != wantResponse.VerificationURI { + t.Errorf("VerificationURI = %q, want %q", resp.VerificationURI, wantResponse.VerificationURI) + } + if resp.VerificationURIComplete != wantResponse.VerificationURIComplete { + t.Errorf("VerificationURIComplete = %q, want %q", resp.VerificationURIComplete, wantResponse.VerificationURIComplete) + } + if resp.ExpiresIn != wantResponse.ExpiresIn { + t.Errorf("ExpiresIn = %d, want %d", resp.ExpiresIn, wantResponse.ExpiresIn) + } + if resp.Interval != wantResponse.Interval { + t.Errorf("Interval = %d, want %d", resp.Interval, wantResponse.Interval) + } +} + +func TestStart_WithScopes(t *testing.T) { + var receivedScope string + + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testDeviceAuthPath: func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + http.Error(w, "bad request", http.StatusBadRequest) + return + } + receivedScope = r.FormValue("scope") + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(DeviceAuthResponse{ + DeviceCode: "test-device-code", + UserCode: "ABCD-1234", + VerificationURI: "https://example.com/device", + ExpiresIn: 1800, + Interval: 5, + }) + }, + }, + }) + defer server.Close() + + client := NewClient() + _, err := client.Start(context.Background(), server.URL, []string{"read", "write"}) + if err != nil { + t.Fatalf("Start() error = %v", err) + } + + if receivedScope != "read write" { + t.Errorf("scope = %q, want %q", receivedScope, "read write") + } +} + +func TestStart_Error(t *testing.T) { + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testDeviceAuthPath: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(ErrorResponse{ + Error: "invalid_client", + ErrorDescription: "Unknown client", + }) + }, + }, + }) + defer server.Close() + + client := NewClient() + _, err := client.Start(context.Background(), server.URL, nil) + if err == nil { + t.Fatal("Start() expected error, got nil") + } + + wantErr := "device auth failed: invalid_client: Unknown client" + if err.Error() != wantErr { + t.Errorf("error = %q, want %q", err.Error(), wantErr) + } +} + +func TestStart_NoDeviceEndpoint(t *testing.T) { + server := newTestServer(t, testServerOptions{ + wellKnownFunc: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(OIDCConfiguration{ + TokenEndpoint: "http://example.com/token", + }) + }, + }) + defer server.Close() + + client := NewClient() + _, err := client.Start(context.Background(), server.URL, nil) + if err == nil { + t.Fatal("Start() expected error, got nil") + } + + if !strings.Contains(err.Error(), "device authorization endpoint not found") { + t.Errorf("error = %q, want to contain 'device authorization endpoint not found'", err.Error()) + } +} + +func TestPoll_Success(t *testing.T) { + wantToken := TokenResponse{ + AccessToken: "test-access-token", + TokenType: "Bearer", + ExpiresIn: 3600, + Scope: "read write", + } + + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("unexpected method: %s", r.Method) + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + if got := r.FormValue("client_id"); got != ClientID { + t.Errorf("unexpected client_id: got %q, want %q", got, ClientID) + } + if got := r.FormValue("grant_type"); got != GrantTypeDeviceCode { + t.Errorf("unexpected grant_type: got %q", got) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(wantToken) + }, + }, + }) + defer server.Close() + + client := NewClient().(*httpClient) + resp, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) + if err != nil { + t.Fatalf("Poll() error = %v", err) + } + + if resp.AccessToken != wantToken.AccessToken { + t.Errorf("AccessToken = %q, want %q", resp.AccessToken, wantToken.AccessToken) + } + if resp.TokenType != wantToken.TokenType { + t.Errorf("TokenType = %q, want %q", resp.TokenType, wantToken.TokenType) + } +} + +func TestPoll_AuthorizationPending(t *testing.T) { + var callCount int32 + + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt32(&callCount, 1) + + w.Header().Set("Content-Type", "application/json") + + if count < 3 { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(ErrorResponse{ + Error: "authorization_pending", + ErrorDescription: "The user has not yet completed authorization", + }) + return + } + + json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "test-access-token", + TokenType: "Bearer", + }) + }, + }, + }) + defer server.Close() + + client := NewClient().(*httpClient) + resp, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) + if err != nil { + t.Fatalf("Poll() error = %v", err) + } + + if resp.AccessToken != "test-access-token" { + t.Errorf("AccessToken = %q, want %q", resp.AccessToken, "test-access-token") + } + + if atomic.LoadInt32(&callCount) != 3 { + t.Errorf("callCount = %d, want 3", callCount) + } +} + +func TestPoll_SlowDown(t *testing.T) { + var callCount int32 + + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt32(&callCount, 1) + + w.Header().Set("Content-Type", "application/json") + + if count == 1 { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(ErrorResponse{ + Error: "slow_down", + }) + return + } + + json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "test-access-token", + TokenType: "Bearer", + }) + }, + }, + }) + defer server.Close() + + client := NewClient().(*httpClient) + resp, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) + if err != nil { + t.Fatalf("Poll() error = %v", err) + } + + if resp.AccessToken != "test-access-token" { + t.Errorf("AccessToken = %q, want %q", resp.AccessToken, "test-access-token") + } + + if atomic.LoadInt32(&callCount) != 2 { + t.Errorf("callCount = %d, want 2", callCount) + } +} + +func TestPoll_ExpiredToken(t *testing.T) { + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(ErrorResponse{ + Error: "expired_token", + ErrorDescription: "The device code has expired", + }) + }, + }, + }) + defer server.Close() + + client := NewClient().(*httpClient) + _, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) + if err == nil { + t.Fatal("Poll() expected error, got nil") + } + + wantErr := "device code expired" + if err.Error() != wantErr { + t.Errorf("error = %q, want %q", err.Error(), wantErr) + } +} + +func TestPoll_AccessDenied(t *testing.T) { + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(ErrorResponse{ + Error: "access_denied", + ErrorDescription: "The user denied the request", + }) + }, + }, + }) + defer server.Close() + + client := NewClient().(*httpClient) + _, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) + if err == nil { + t.Fatal("Poll() expected error, got nil") + } + + wantErr := "authorization was denied by the user" + if err.Error() != wantErr { + t.Errorf("error = %q, want %q", err.Error(), wantErr) + } +} + +func TestPoll_Timeout(t *testing.T) { + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(ErrorResponse{ + Error: "authorization_pending", + }) + }, + }, + }) + defer server.Close() + + client := NewClient().(*httpClient) + _, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 0) + if err == nil { + t.Fatal("Poll() expected error, got nil") + } + + wantErr := "device code expired" + if err.Error() != wantErr { + t.Errorf("error = %q, want %q", err.Error(), wantErr) + } +} + +func TestPoll_ContextCancellation(t *testing.T) { + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(ErrorResponse{ + Error: "authorization_pending", + }) + }, + }, + }) + defer server.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + client := NewClient().(*httpClient) + _, err := client.Poll(ctx, server.URL, "test-device-code", 10*time.Millisecond, 3600) + if err == nil { + t.Fatal("Poll() expected error, got nil") + } + + if err != context.Canceled && !strings.Contains(err.Error(), "context canceled") { + t.Errorf("error = %v, want context.Canceled or wrapped context canceled error", err) + } +} From 28b88ae0aed57434e3f27fa36e47198db06c3af9 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Wed, 3 Dec 2025 15:58:10 +0200 Subject: [PATCH 2/9] add refresh token to device response unmarshall --- internal/oauthdevice/device_flow.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/internal/oauthdevice/device_flow.go b/internal/oauthdevice/device_flow.go index 851779cc3a..5e3de8971c 100644 --- a/internal/oauthdevice/device_flow.go +++ b/internal/oauthdevice/device_flow.go @@ -50,10 +50,11 @@ type DeviceAuthResponse struct { } type TokenResponse struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in,omitempty"` - Scope string `json:"scope,omitempty"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in,omitempty"` + Scope string `json:"scope,omitempty"` } type ErrorResponse struct { From bbf003d0ce6a701a24feef116a6250568d54ccd4 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Wed, 3 Dec 2025 15:10:24 +0200 Subject: [PATCH 3/9] make NewClient take ClientID as param --- internal/oauthdevice/device_flow.go | 14 +++++---- internal/oauthdevice/device_flow_test.go | 36 ++++++++++++------------ 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/internal/oauthdevice/device_flow.go b/internal/oauthdevice/device_flow.go index 5e3de8971c..c278dd4ba3 100644 --- a/internal/oauthdevice/device_flow.go +++ b/internal/oauthdevice/device_flow.go @@ -17,8 +17,10 @@ import ( ) const ( - ClientID = "sgo_cid_sourcegraph-cli" + // DefaultClientID is a predefined Client ID built into Sourcegraph + DefaultClientID = "sgo_cid_sourcegraph-cli" + // wellKnownPath is the path on the sourcegraph server where clients can discover OAuth configuration wellKnownPath = "/.well-known/openid-configuration" GrantTypeDeviceCode string = "urn:ietf:params:oauth:grant-type:device_code" @@ -69,13 +71,15 @@ type Client interface { } type httpClient struct { - client *http.Client + clientID string + client *http.Client // cached OIDC configuration per endpoint configCache map[string]*OIDCConfiguration } -func NewClient() Client { +func NewClient(clientID string) Client { return &httpClient{ + clientID: clientID, client: &http.Client{ Timeout: 30 * time.Second, }, @@ -152,7 +156,7 @@ func (c *httpClient) Start(ctx context.Context, endpoint string, scopes []string } data := url.Values{} - data.Set("client_id", ClientID) + data.Set("client_id", DefaultClientID) if len(scopes) > 0 { data.Set("scope", strings.Join(scopes, " ")) } else { @@ -266,7 +270,7 @@ func (e *PollError) Error() string { func (c *httpClient) pollOnce(ctx context.Context, tokenEndpoint, deviceCode string) (*TokenResponse, error) { data := url.Values{} - data.Set("client_id", ClientID) + data.Set("client_id", DefaultClientID) data.Set("device_code", deviceCode) data.Set("grant_type", GrantTypeDeviceCode) diff --git a/internal/oauthdevice/device_flow_test.go b/internal/oauthdevice/device_flow_test.go index 02b3923d88..e60e1f9b1a 100644 --- a/internal/oauthdevice/device_flow_test.go +++ b/internal/oauthdevice/device_flow_test.go @@ -50,7 +50,7 @@ func TestDiscover_Success(t *testing.T) { server := newTestServer(t, testServerOptions{}) defer server.Close() - client := NewClient() + client := NewClient(DefaultClientID) config, err := client.Discover(context.Background(), server.URL) if err != nil { t.Fatalf("Discover() error = %v", err) @@ -78,7 +78,7 @@ func TestDiscover_Caching(t *testing.T) { }) defer server.Close() - client := NewClient() + client := NewClient(DefaultClientID) // Populate the cache _, err := client.Discover(context.Background(), server.URL) @@ -105,7 +105,7 @@ func TestDiscover_Error(t *testing.T) { }) defer server.Close() - client := NewClient() + client := NewClient(DefaultClientID) _, err := client.Discover(context.Background(), server.URL) if err == nil { t.Fatal("Discover() expected error, got nil") @@ -141,8 +141,8 @@ func TestStart_Success(t *testing.T) { return } - if got := r.FormValue("client_id"); got != ClientID { - t.Errorf("unexpected client_id: got %q, want %q", got, ClientID) + if got := r.FormValue("client_id"); got != DefaultClientID { + t.Errorf("unexpected client_id: got %q, want %q", got, DefaultClientID) } w.Header().Set("Content-Type", "application/json") @@ -152,7 +152,7 @@ func TestStart_Success(t *testing.T) { }) defer server.Close() - client := NewClient() + client := NewClient(DefaultClientID) resp, err := client.Start(context.Background(), server.URL, nil) if err != nil { t.Fatalf("Start() error = %v", err) @@ -204,7 +204,7 @@ func TestStart_WithScopes(t *testing.T) { }) defer server.Close() - client := NewClient() + client := NewClient(DefaultClientID) _, err := client.Start(context.Background(), server.URL, []string{"read", "write"}) if err != nil { t.Fatalf("Start() error = %v", err) @@ -230,7 +230,7 @@ func TestStart_Error(t *testing.T) { }) defer server.Close() - client := NewClient() + client := NewClient(DefaultClientID) _, err := client.Start(context.Background(), server.URL, nil) if err == nil { t.Fatal("Start() expected error, got nil") @@ -253,7 +253,7 @@ func TestStart_NoDeviceEndpoint(t *testing.T) { }) defer server.Close() - client := NewClient() + client := NewClient(DefaultClientID) _, err := client.Start(context.Background(), server.URL, nil) if err == nil { t.Fatal("Start() expected error, got nil") @@ -287,8 +287,8 @@ func TestPoll_Success(t *testing.T) { return } - if got := r.FormValue("client_id"); got != ClientID { - t.Errorf("unexpected client_id: got %q, want %q", got, ClientID) + if got := r.FormValue("client_id"); got != DefaultClientID { + t.Errorf("unexpected client_id: got %q, want %q", got, DefaultClientID) } if got := r.FormValue("grant_type"); got != GrantTypeDeviceCode { t.Errorf("unexpected grant_type: got %q", got) @@ -301,7 +301,7 @@ func TestPoll_Success(t *testing.T) { }) defer server.Close() - client := NewClient().(*httpClient) + client := NewClient(DefaultClientID).(*httpClient) resp, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) if err != nil { t.Fatalf("Poll() error = %v", err) @@ -343,7 +343,7 @@ func TestPoll_AuthorizationPending(t *testing.T) { }) defer server.Close() - client := NewClient().(*httpClient) + client := NewClient(DefaultClientID).(*httpClient) resp, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) if err != nil { t.Fatalf("Poll() error = %v", err) @@ -385,7 +385,7 @@ func TestPoll_SlowDown(t *testing.T) { }) defer server.Close() - client := NewClient().(*httpClient) + client := NewClient(DefaultClientID).(*httpClient) resp, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) if err != nil { t.Fatalf("Poll() error = %v", err) @@ -415,7 +415,7 @@ func TestPoll_ExpiredToken(t *testing.T) { }) defer server.Close() - client := NewClient().(*httpClient) + client := NewClient(DefaultClientID).(*httpClient) _, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) if err == nil { t.Fatal("Poll() expected error, got nil") @@ -442,7 +442,7 @@ func TestPoll_AccessDenied(t *testing.T) { }) defer server.Close() - client := NewClient().(*httpClient) + client := NewClient(DefaultClientID).(*httpClient) _, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) if err == nil { t.Fatal("Poll() expected error, got nil") @@ -468,7 +468,7 @@ func TestPoll_Timeout(t *testing.T) { }) defer server.Close() - client := NewClient().(*httpClient) + client := NewClient(DefaultClientID).(*httpClient) _, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 0) if err == nil { t.Fatal("Poll() expected error, got nil") @@ -497,7 +497,7 @@ func TestPoll_ContextCancellation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - client := NewClient().(*httpClient) + client := NewClient(DefaultClientID).(*httpClient) _, err := client.Poll(ctx, server.URL, "test-device-code", 10*time.Millisecond, 3600) if err == nil { t.Fatal("Poll() expected error, got nil") From 67ffa4c7f7eaf503fbe60e9f17c9fa3c02f544c8 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Wed, 3 Dec 2025 15:10:36 +0200 Subject: [PATCH 4/9] add flag to set client-id for device-flow --- cmd/src/login.go | 58 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 52 insertions(+), 6 deletions(-) diff --git a/cmd/src/login.go b/cmd/src/login.go index 52d881cc58..e42632c3b0 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -7,9 +7,11 @@ import ( "io" "os" "strings" + "time" "github.com/sourcegraph/src-cli/internal/api" "github.com/sourcegraph/src-cli/internal/cmderrors" + "github.com/sourcegraph/src-cli/internal/oauthdevice" ) func init() { @@ -17,7 +19,7 @@ func init() { Usage: - src login SOURCEGRAPH_URL + src login [flags] SOURCEGRAPH_URL Examples: @@ -32,6 +34,11 @@ Examples: Use OAuth device flow to authenticate: $ src login --device-flow https://sourcegraph.com + + + Override the default client id used during device flow when authenticating: + + $ src login --device-flow https://sourcegraph.com --client-id sgo_my_own_client_id ` flagSet := flag.NewFlagSet("login", flag.ExitOnError) @@ -41,7 +48,9 @@ Examples: } var ( - apiFlags = api.NewFlags(flagSet) + apiFlags = api.NewFlags(flagSet) + useDeviceFlow = flagSet.Bool("device-flow", false, "Use OAuth device flow to obtain an access token interactively") + OAuthClientID = flagSet.String("client-id", oauthdevice.DefaultClientID, "Client ID to use with OAuth device flow. Will use the predefined src cli client ID if not specified.") ) handler := func(args []string) error { @@ -56,9 +65,21 @@ Examples: return cmderrors.Usage("expected exactly one argument: the Sourcegraph URL, or SRC_ENDPOINT to be set") } + if *OAuthClientID == "" { + return cmderrors.Usage("no value specified for client-id") + } + client := cfg.apiClient(apiFlags, io.Discard) - return loginCmd(context.Background(), cfg, client, endpoint, os.Stdout) + return loginCmd(context.Background(), loginParams{ + cfg: cfg, + client: client, + endpoint: endpoint, + out: os.Stdout, + useDeviceFlow: *useDeviceFlow, + apiFlags: apiFlags, + deviceFlowClient: oauthdevice.NewClient(*OAuthClientID), + }) } commands = append(commands, &command{ @@ -68,8 +89,21 @@ Examples: }) } -func loginCmd(ctx context.Context, cfg *config, client api.Client, endpointArg string, out io.Writer) error { - endpointArg = cleanEndpoint(endpointArg) +type loginParams struct { + cfg *config + client api.Client + endpoint string + out io.Writer + useDeviceFlow bool + apiFlags *api.Flags + deviceFlowClient oauthdevice.Client +} + +func loginCmd(ctx context.Context, p loginParams) error { + endpointArg := cleanEndpoint(p.endpoint) + cfg := p.cfg + client := p.client + out := p.out printProblem := func(problem string) { fmt.Fprintf(out, "❌ Problem: %s\n", problem) @@ -90,7 +124,19 @@ func loginCmd(ctx context.Context, cfg *config, client api.Client, endpointArg s noToken := cfg.AccessToken == "" endpointConflict := endpointArg != cfg.Endpoint - if noToken || endpointConflict { + + if p.useDeviceFlow { + token, err := runDeviceFlow(ctx, endpointArg, out, p.deviceFlowClient) + if err != nil { + printProblem(fmt.Sprintf("Device flow authentication failed: %s", err)) + fmt.Fprintln(out, createAccessTokenMessage) + return cmderrors.ExitCode1 + } + + cfg.AccessToken = token + cfg.Endpoint = endpointArg + client = cfg.apiClient(p.apiFlags, out) + } else if noToken || endpointConflict { fmt.Fprintln(out) switch { case noToken: From f2fe1d208eaf10d456b3eab65c2cc349913a469f Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Fri, 23 Jan 2026 12:43:18 +0200 Subject: [PATCH 5/9] rename from deviceflow to OAuth --- cmd/src/login.go | 28 +++++++++---------- .../device_flow.go => oauth/flow.go} | 4 +-- .../flow_test.go} | 2 +- 3 files changed, 17 insertions(+), 17 deletions(-) rename internal/{oauthdevice/device_flow.go => oauth/flow.go} (98%) rename internal/{oauthdevice/device_flow_test.go => oauth/flow_test.go} (99%) diff --git a/cmd/src/login.go b/cmd/src/login.go index e42632c3b0..4cc6199fe5 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -11,7 +11,7 @@ import ( "github.com/sourcegraph/src-cli/internal/api" "github.com/sourcegraph/src-cli/internal/cmderrors" - "github.com/sourcegraph/src-cli/internal/oauthdevice" + "github.com/sourcegraph/src-cli/internal/oauth" ) func init() { @@ -33,12 +33,12 @@ Examples: Use OAuth device flow to authenticate: - $ src login --device-flow https://sourcegraph.com + $ src login --oauth https://sourcegraph.com Override the default client id used during device flow when authenticating: - $ src login --device-flow https://sourcegraph.com --client-id sgo_my_own_client_id + $ src login --oauth https://sourcegraph.com --client-id sgo_my_own_client_id ` flagSet := flag.NewFlagSet("login", flag.ExitOnError) @@ -49,8 +49,8 @@ Examples: var ( apiFlags = api.NewFlags(flagSet) - useDeviceFlow = flagSet.Bool("device-flow", false, "Use OAuth device flow to obtain an access token interactively") - OAuthClientID = flagSet.String("client-id", oauthdevice.DefaultClientID, "Client ID to use with OAuth device flow. Will use the predefined src cli client ID if not specified.") + useOAuth = flagSet.Bool("oauth", false, "Use OAuth device flow to obtain an access token interactively") + OAuthClientID = flagSet.String("client-id", oauth.DefaultClientID, "Client ID to use with OAuth device flow. Will use the predefined src cli client ID if not specified.") ) handler := func(args []string) error { @@ -76,9 +76,9 @@ Examples: client: client, endpoint: endpoint, out: os.Stdout, - useDeviceFlow: *useDeviceFlow, + useOAuth: *useOAuth, apiFlags: apiFlags, - deviceFlowClient: oauthdevice.NewClient(*OAuthClientID), + deviceFlowClient: oauth.NewClient(*OAuthClientID), }) } @@ -94,9 +94,9 @@ type loginParams struct { client api.Client endpoint string out io.Writer - useDeviceFlow bool + useOAuth bool apiFlags *api.Flags - deviceFlowClient oauthdevice.Client + deviceFlowClient oauth.Client } func loginCmd(ctx context.Context, p loginParams) error { @@ -125,10 +125,10 @@ func loginCmd(ctx context.Context, p loginParams) error { noToken := cfg.AccessToken == "" endpointConflict := endpointArg != cfg.Endpoint - if p.useDeviceFlow { - token, err := runDeviceFlow(ctx, endpointArg, out, p.deviceFlowClient) + if p.useOAuth { + token, err := runOAuthDeviceFlow(ctx, endpointArg, out, p.deviceFlowClient) if err != nil { - printProblem(fmt.Sprintf("Device flow authentication failed: %s", err)) + printProblem(fmt.Sprintf("OAuth Device flow authentication failed: %s", err)) fmt.Fprintln(out, createAccessTokenMessage) return cmderrors.ExitCode1 } @@ -173,7 +173,7 @@ func loginCmd(ctx context.Context, p loginParams) error { fmt.Fprintln(out) fmt.Fprintf(out, "✔️ Authenticated as %s on %s\n", result.CurrentUser.Username, endpointArg) - if p.useDeviceFlow { + if p.useOAuth { fmt.Fprintln(out) fmt.Fprintf(out, "To use this access token, set the following environment variables in your terminal:\n\n") fmt.Fprintf(out, " export SRC_ENDPOINT=%s\n", endpointArg) @@ -184,7 +184,7 @@ func loginCmd(ctx context.Context, p loginParams) error { return nil } -func runDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauthdevice.Client) (string, error) { +func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauth.Client) (string, error) { authResp, err := client.Start(ctx, endpoint, nil) if err != nil { return "", err diff --git a/internal/oauthdevice/device_flow.go b/internal/oauth/flow.go similarity index 98% rename from internal/oauthdevice/device_flow.go rename to internal/oauth/flow.go index c278dd4ba3..b9a960a867 100644 --- a/internal/oauthdevice/device_flow.go +++ b/internal/oauth/flow.go @@ -1,6 +1,6 @@ -// Package oauthdevice implements the OAuth 2.0 Device Authorization Grant (RFC 8628) +// Package oauthimplements the OAuth 2.0 Device Authorization Grant (RFC 8628) // for authenticating with Sourcegraph instances. -package oauthdevice +package oauth import ( "context" diff --git a/internal/oauthdevice/device_flow_test.go b/internal/oauth/flow_test.go similarity index 99% rename from internal/oauthdevice/device_flow_test.go rename to internal/oauth/flow_test.go index e60e1f9b1a..46e3a97036 100644 --- a/internal/oauthdevice/device_flow_test.go +++ b/internal/oauth/flow_test.go @@ -1,4 +1,4 @@ -package oauthdevice +package oauth import ( "context" From 7207e4b7e4dec2b2abca2da2bced3be15f006737 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Thu, 26 Feb 2026 13:26:52 +0200 Subject: [PATCH 6/9] remove --client-id flag --- cmd/src/login.go | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/cmd/src/login.go b/cmd/src/login.go index 4cc6199fe5..598794cbe9 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -38,7 +38,7 @@ Examples: Override the default client id used during device flow when authenticating: - $ src login --oauth https://sourcegraph.com --client-id sgo_my_own_client_id + $ src login --oauth https://sourcegraph.com ` flagSet := flag.NewFlagSet("login", flag.ExitOnError) @@ -48,9 +48,8 @@ Examples: } var ( - apiFlags = api.NewFlags(flagSet) - useOAuth = flagSet.Bool("oauth", false, "Use OAuth device flow to obtain an access token interactively") - OAuthClientID = flagSet.String("client-id", oauth.DefaultClientID, "Client ID to use with OAuth device flow. Will use the predefined src cli client ID if not specified.") + apiFlags = api.NewFlags(flagSet) + useOAuth = flagSet.Bool("oauth", false, "Use OAuth device flow to obtain an access token interactively") ) handler := func(args []string) error { @@ -65,10 +64,6 @@ Examples: return cmderrors.Usage("expected exactly one argument: the Sourcegraph URL, or SRC_ENDPOINT to be set") } - if *OAuthClientID == "" { - return cmderrors.Usage("no value specified for client-id") - } - client := cfg.apiClient(apiFlags, io.Discard) return loginCmd(context.Background(), loginParams{ @@ -78,7 +73,7 @@ Examples: out: os.Stdout, useOAuth: *useOAuth, apiFlags: apiFlags, - deviceFlowClient: oauth.NewClient(*OAuthClientID), + deviceFlowClient: oauth.NewClient(oauth.DefaultClientID), }) } From 9c9a6d453ab44fa32d89f8410817fe26095c4044 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Thu, 26 Feb 2026 13:39:03 +0200 Subject: [PATCH 7/9] fix test --- cmd/src/login_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/src/login_test.go b/cmd/src/login_test.go index 37fbf7a703..ef7d01e019 100644 --- a/cmd/src/login_test.go +++ b/cmd/src/login_test.go @@ -18,7 +18,7 @@ func TestLogin(t *testing.T) { t.Helper() var out bytes.Buffer - err = loginCmd(context.Background(), cfg, cfg.apiClient(nil, io.Discard), endpointArg, &out) + err = loginCmd(context.Background(), loginParams{cfg: cfg, client: cfg.apiClient(nil, io.Discard), endpoint: endpointArg, out: &out}) return strings.TrimSpace(out.String()), err } From 61a10705c9472257025beadefa9d8fa00124ca4b Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Thu, 26 Feb 2026 14:25:25 +0200 Subject: [PATCH 8/9] automatically open browser --- cmd/src/login.go | 38 ++++++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/cmd/src/login.go b/cmd/src/login.go index 598794cbe9..e85a0bd4e9 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -6,6 +6,8 @@ import ( "fmt" "io" "os" + "os/exec" + "runtime" "strings" "time" @@ -170,9 +172,7 @@ func loginCmd(ctx context.Context, p loginParams) error { if p.useOAuth { fmt.Fprintln(out) - fmt.Fprintf(out, "To use this access token, set the following environment variables in your terminal:\n\n") - fmt.Fprintf(out, " export SRC_ENDPOINT=%s\n", endpointArg) - fmt.Fprintf(out, " export SRC_ACCESS_TOKEN=%s\n", cfg.AccessToken) + fmt.Fprintf(out, "Authenticated with OAuth credentials") } fmt.Fprintln(out) @@ -185,12 +185,16 @@ func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, cli return "", err } - fmt.Fprintln(out) - fmt.Fprintf(out, "To authenticate, visit %s and enter the code: %s\n", authResp.VerificationURI, authResp.UserCode) - if authResp.VerificationURIComplete != "" { - fmt.Fprintln(out) - fmt.Fprintf(out, "Alternatively, you can open: %s\n", authResp.VerificationURIComplete) + authURL := authResp.VerificationURIComplete + msg := fmt.Sprintf("If your browser did not open automatically, visit %s.", authURL) + if authURL == "" { + authURL = authResp.VerificationURI + msg = fmt.Sprintf("If your browser did not open automatically, visit %s and enter the user code %s", authURL, authResp.DeviceCode) } + _ = openInBrowser(authURL) + fmt.Fprintln(out) + fmt.Fprint(out, msg) + fmt.Fprintln(out) fmt.Fprint(out, "Waiting for authorization...") defer fmt.Fprintf(out, "DONE\n\n") @@ -207,3 +211,21 @@ func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, cli return tokenResp.AccessToken, nil } + +func openInBrowser(url string) error { + if url == "" { + return nil + } + + var cmd *exec.Cmd + switch runtime.GOOS { + case "darwin": + cmd = exec.Command("open", url) + case "windows": + // "start" is a cmd.exe built-in; the empty string is the window title. + cmd = exec.Command("cmd", "/c", "start", "", url) + default: + cmd = exec.Command("xdg-open", url) + } + return cmd.Run() +} From e2cb6203e17a650dfb53cdd8aec20dd52cb73699 Mon Sep 17 00:00:00 2001 From: William Bezuidenhout Date: Mon, 9 Mar 2026 12:03:22 +0200 Subject: [PATCH 9/9] feat(oauth): Add refresh to oauthdevice.Client (#1227) * add refresh to oauthdevice.Client * oauthdevice: add RefreshToken field and Refresh method * feat(oauth): Use keyring to store oauth token (#1228) * add refresh to oauthdevice.Client * add OAuth Transport and use it if no access token * secrets: switch to zalando/go-keyring and add context support * secrets: scope keyring by endpoint --- cmd/src/login.go | 52 +++++--- cmd/src/login_test.go | 17 ++- cmd/src/main.go | 15 ++- go.mod | 4 + go.sum | 10 ++ internal/api/api.go | 52 ++++++-- internal/oauth/flow.go | 146 ++++++++++++++++++++-- internal/oauth/flow_test.go | 66 +++++++++- internal/oauth/http_transport.go | 87 +++++++++++++ internal/oauth/http_transport_test.go | 172 ++++++++++++++++++++++++++ internal/secrets/keyring.go | 77 ++++++++++++ internal/secrets/keyring_test.go | 58 +++++++++ 12 files changed, 704 insertions(+), 52 deletions(-) create mode 100644 internal/oauth/http_transport.go create mode 100644 internal/oauth/http_transport_test.go create mode 100644 internal/secrets/keyring.go create mode 100644 internal/secrets/keyring_test.go diff --git a/cmd/src/login.go b/cmd/src/login.go index e85a0bd4e9..5a73ef4cc8 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -112,7 +112,9 @@ func loginCmd(ctx context.Context, p loginParams) error { export SRC_ACCESS_TOKEN=(your access token) To verify that it's working, run the login command again. -`, endpointArg, endpointArg) + + Alternatively, you can try logging in using OAuth by running: src login --oauth %s +`, endpointArg, endpointArg, endpointArg) if cfg.ConfigFilePath != "" { fmt.Fprintln(out) @@ -121,6 +123,17 @@ func loginCmd(ctx context.Context, p loginParams) error { noToken := cfg.AccessToken == "" endpointConflict := endpointArg != cfg.Endpoint + if !p.useOAuth && (noToken || endpointConflict) { + fmt.Fprintln(out) + switch { + case noToken: + printProblem("No access token is configured.") + case endpointConflict: + printProblem(fmt.Sprintf("The configured endpoint is %s, not %s.", cfg.Endpoint, endpointArg)) + } + fmt.Fprintln(out, createAccessTokenMessage) + return cmderrors.ExitCode1 + } if p.useOAuth { token, err := runOAuthDeviceFlow(ctx, endpointArg, out, p.deviceFlowClient) @@ -130,19 +143,20 @@ func loginCmd(ctx context.Context, p loginParams) error { return cmderrors.ExitCode1 } - cfg.AccessToken = token - cfg.Endpoint = endpointArg - client = cfg.apiClient(p.apiFlags, out) - } else if noToken || endpointConflict { - fmt.Fprintln(out) - switch { - case noToken: - printProblem("No access token is configured.") - case endpointConflict: - printProblem(fmt.Sprintf("The configured endpoint is %s, not %s.", cfg.Endpoint, endpointArg)) + if err := oauth.StoreToken(ctx, token); err != nil { + fmt.Fprintln(out) + fmt.Fprintf(out, "⚠️ Warning: Failed to store token in keyring store: %q. Continuing with this session only.\n", err) } - fmt.Fprintln(out, createAccessTokenMessage) - return cmderrors.ExitCode1 + + client = api.NewClient(api.ClientOpts{ + Endpoint: cfg.Endpoint, + AdditionalHeaders: cfg.AdditionalHeaders, + Flags: p.apiFlags, + Out: out, + ProxyURL: cfg.ProxyURL, + ProxyPath: cfg.ProxyPath, + OAuthToken: token, + }) } // See if the user is already authenticated. @@ -179,10 +193,10 @@ func loginCmd(ctx context.Context, p loginParams) error { return nil } -func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauth.Client) (string, error) { +func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauth.Client) (*oauth.Token, error) { authResp, err := client.Start(ctx, endpoint, nil) if err != nil { - return "", err + return nil, err } authURL := authResp.VerificationURIComplete @@ -204,12 +218,14 @@ func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, cli interval = 5 * time.Second } - tokenResp, err := client.Poll(ctx, endpoint, authResp.DeviceCode, interval, authResp.ExpiresIn) + resp, err := client.Poll(ctx, endpoint, authResp.DeviceCode, interval, authResp.ExpiresIn) if err != nil { - return "", err + return nil, err } - return tokenResp.AccessToken, nil + token := resp.Token(endpoint) + token.ClientID = client.ClientID() + return token, nil } func openInBrowser(url string) error { diff --git a/cmd/src/login_test.go b/cmd/src/login_test.go index ef7d01e019..ab7a15056a 100644 --- a/cmd/src/login_test.go +++ b/cmd/src/login_test.go @@ -11,6 +11,7 @@ import ( "testing" "github.com/sourcegraph/src-cli/internal/cmderrors" + "github.com/sourcegraph/src-cli/internal/oauth" ) func TestLogin(t *testing.T) { @@ -18,7 +19,13 @@ func TestLogin(t *testing.T) { t.Helper() var out bytes.Buffer - err = loginCmd(context.Background(), loginParams{cfg: cfg, client: cfg.apiClient(nil, io.Discard), endpoint: endpointArg, out: &out}) + err = loginCmd(context.Background(), loginParams{ + cfg: cfg, + client: cfg.apiClient(nil, io.Discard), + endpoint: endpointArg, + out: &out, + deviceFlowClient: oauth.NewClient(oauth.DefaultClientID), + }) return strings.TrimSpace(out.String()), err } @@ -27,7 +34,7 @@ func TestLogin(t *testing.T) { if err != cmderrors.ExitCode1 { t.Fatal(err) } - wantOut := "❌ Problem: No access token is configured.\n\n🛠 To fix: Create an access token by going to https://sourcegraph.example.com/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=https://sourcegraph.example.com\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again." + wantOut := "❌ Problem: No access token is configured.\n\n🛠 To fix: Create an access token by going to https://sourcegraph.example.com/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=https://sourcegraph.example.com\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n Alternatively, you can try logging in using OAuth by running: src login --oauth https://sourcegraph.example.com" if out != wantOut { t.Errorf("got output %q, want %q", out, wantOut) } @@ -38,7 +45,7 @@ func TestLogin(t *testing.T) { if err != cmderrors.ExitCode1 { t.Fatal(err) } - wantOut := "❌ Problem: No access token is configured.\n\n🛠 To fix: Create an access token by going to https://sourcegraph.example.com/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=https://sourcegraph.example.com\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again." + wantOut := "❌ Problem: No access token is configured.\n\n🛠 To fix: Create an access token by going to https://sourcegraph.example.com/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=https://sourcegraph.example.com\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n Alternatively, you can try logging in using OAuth by running: src login --oauth https://sourcegraph.example.com" if out != wantOut { t.Errorf("got output %q, want %q", out, wantOut) } @@ -49,7 +56,7 @@ func TestLogin(t *testing.T) { if err != cmderrors.ExitCode1 { t.Fatal(err) } - wantOut := "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove f. See https://github.com/sourcegraph/src-cli#readme for more information.\n\n❌ Problem: No access token is configured.\n\n🛠 To fix: Create an access token by going to https://example.com/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=https://example.com\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again." + wantOut := "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove f. See https://github.com/sourcegraph/src-cli#readme for more information.\n\n❌ Problem: No access token is configured.\n\n🛠 To fix: Create an access token by going to https://example.com/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=https://example.com\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n Alternatively, you can try logging in using OAuth by running: src login --oauth https://example.com" if out != wantOut { t.Errorf("got output %q, want %q", out, wantOut) } @@ -67,7 +74,7 @@ func TestLogin(t *testing.T) { if err != cmderrors.ExitCode1 { t.Fatal(err) } - wantOut := "❌ Problem: Invalid access token.\n\n🛠 To fix: Create an access token by going to $ENDPOINT/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=$ENDPOINT\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n (If you need to supply custom HTTP request headers, see information about SRC_HEADER_* and SRC_HEADERS env vars at https://github.com/sourcegraph/src-cli/blob/main/AUTH_PROXY.md)" + wantOut := "❌ Problem: Invalid access token.\n\n🛠 To fix: Create an access token by going to $ENDPOINT/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=$ENDPOINT\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n Alternatively, you can try logging in using OAuth by running: src login --oauth $ENDPOINT\n\n (If you need to supply custom HTTP request headers, see information about SRC_HEADER_* and SRC_HEADERS env vars at https://github.com/sourcegraph/src-cli/blob/main/AUTH_PROXY.md)" wantOut = strings.ReplaceAll(wantOut, "$ENDPOINT", endpoint) if out != wantOut { t.Errorf("got output %q, want %q", out, wantOut) diff --git a/cmd/src/main.go b/cmd/src/main.go index edfb1073d7..41e5c55cd0 100644 --- a/cmd/src/main.go +++ b/cmd/src/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "flag" "io" @@ -15,6 +16,7 @@ import ( "github.com/sourcegraph/sourcegraph/lib/errors" "github.com/sourcegraph/src-cli/internal/api" + "github.com/sourcegraph/src-cli/internal/oauth" ) const usageText = `src is a tool that provides access to Sourcegraph instances. @@ -122,7 +124,7 @@ type config struct { // apiClient returns an api.Client built from the configuration. func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client { - return api.NewClient(api.ClientOpts{ + opts := api.ClientOpts{ Endpoint: c.Endpoint, AccessToken: c.AccessToken, AdditionalHeaders: c.AdditionalHeaders, @@ -130,7 +132,16 @@ func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client { Out: out, ProxyURL: c.ProxyURL, ProxyPath: c.ProxyPath, - }) + } + + // Only use OAuth if we do not have SRC_ACCESS_TOKEN set + if c.AccessToken == "" { + if t, err := oauth.LoadToken(context.Background(), c.Endpoint); err == nil { + opts.OAuthToken = t + } + } + + return api.NewClient(opts) } // readConfig reads the config file from the given path. diff --git a/go.mod b/go.mod index 3c9e3eb338..0a04972e5a 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( github.com/sourcegraph/sourcegraph/lib v0.0.0-20240709083501-1af563b61442 github.com/stretchr/testify v1.11.1 github.com/tliron/glsp v0.2.2 + github.com/zalando/go-keyring v0.2.6 golang.org/x/sync v0.18.0 google.golang.org/api v0.256.0 google.golang.org/protobuf v1.36.10 @@ -41,6 +42,7 @@ require ( ) require ( + al.essio.dev/pkg/shellescape v1.5.1 // indirect cel.dev/expr v0.24.0 // indirect cloud.google.com/go/auth v0.17.0 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect @@ -64,6 +66,7 @@ require ( github.com/clipperhouse/uax29/v2 v2.2.0 // indirect github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443 // indirect github.com/containerd/stargz-snapshotter/estargz v0.14.3 // indirect + github.com/danieljoos/wincred v1.2.2 // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/cli v24.0.4+incompatible // indirect github.com/docker/distribution v2.8.2+incompatible // indirect @@ -78,6 +81,7 @@ require ( github.com/go-chi/chi/v5 v5.2.2 // indirect github.com/go-jose/go-jose/v4 v4.1.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/godbus/dbus/v5 v5.1.0 // indirect github.com/gofrs/uuid/v5 v5.0.0 // indirect github.com/google/gnostic-models v0.6.9 // indirect github.com/google/go-containerregistry v0.19.1 // indirect diff --git a/go.sum b/go.sum index f47d1d10c9..6cbdc71412 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +al.essio.dev/pkg/shellescape v1.5.1 h1:86HrALUujYS/h+GtqoB26SBEdkWfmMI6FubjXlsXyho= +al.essio.dev/pkg/shellescape v1.5.1/go.mod h1:6sIqp7X2P6mThCQ7twERpZTuigpr6KbZWtls1U8I890= cel.dev/expr v0.24.0 h1:56OvJKSH3hDGL0ml5uSxZmz3/3Pq4tJ+fb1unVLAFcY= cel.dev/expr v0.24.0/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= cloud.google.com/go v0.120.0 h1:wc6bgG9DHyKqF5/vQvX1CiZrtHnxJjBlKUyF9nP6meA= @@ -139,6 +141,8 @@ github.com/creack/goselect v0.1.2/go.mod h1:a/NhLweNvqIYMuxcMOuWY516Cimucms3DglD github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= +github.com/danieljoos/wincred v1.2.2 h1:774zMFJrqaeYCK2W57BgAem/MLi6mtSE47MB6BOJ0i0= +github.com/danieljoos/wincred v1.2.2/go.mod h1:w7w4Utbrz8lqeMbDAK0lkNJUv5sAOkFi7nd/ogr0Uh8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -212,6 +216,8 @@ github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1v github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= +github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= +github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/gofrs/uuid/v5 v5.0.0 h1:p544++a97kEL+svbcFbCQVM9KFu0Yo25UoISXGNNH9M= @@ -243,6 +249,8 @@ github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db h1:097atOisP2aRj7vFgY github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.3.7 h1:zrn2Ee/nWmHulBx5sAVrGgAa0f2/R35S4DJwfFaUPFQ= @@ -495,6 +503,8 @@ github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= github.com/yuin/goldmark-emoji v1.0.5 h1:EMVWyCGPlXJfUXBXpuMu+ii3TIaxbVBnEX9uaDC4cIk= github.com/yuin/goldmark-emoji v1.0.5/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U= +github.com/zalando/go-keyring v0.2.6 h1:r7Yc3+H+Ux0+M72zacZoItR3UDxeWfKTcabvkI8ua9s= +github.com/zalando/go-keyring v0.2.6/go.mod h1:2TCrxYrbUNYfNS/Kgy/LSrkSQzZ5UPVH85RwfczwvcI= github.com/zeebo/errs v1.4.0 h1:XNdoD/RRMKP7HD0UhJnIzUy74ISdGGxURlYG8HSWSfM= github.com/zeebo/errs v1.4.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= diff --git a/internal/api/api.go b/internal/api/api.go index 5f750c1d4a..ef9f822a7a 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -18,6 +18,7 @@ import ( "github.com/kballard/go-shellquote" "github.com/mattn/go-isatty" + "github.com/sourcegraph/src-cli/internal/oauth" "github.com/sourcegraph/src-cli/internal/version" ) @@ -85,21 +86,35 @@ type ClientOpts struct { ProxyURL *url.URL ProxyPath string + + OAuthToken *oauth.Token } -func buildTransport(opts ClientOpts, flags *Flags) *http.Transport { - transport := http.DefaultTransport.(*http.Transport).Clone() +func buildTransport(opts ClientOpts, flags *Flags) http.RoundTripper { + var transport http.RoundTripper + { + tp := http.DefaultTransport.(*http.Transport).Clone() - if flags.insecureSkipVerify != nil && *flags.insecureSkipVerify { - transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } + if flags.insecureSkipVerify != nil && *flags.insecureSkipVerify { + tp.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } - if transport.TLSClientConfig == nil { - transport.TLSClientConfig = &tls.Config{} + if tp.TLSClientConfig == nil { + tp.TLSClientConfig = &tls.Config{} + } + + if opts.ProxyURL != nil || opts.ProxyPath != "" { + tp = withProxyTransport(tp, opts.ProxyURL, opts.ProxyPath) + } + + transport = tp } - if opts.ProxyURL != nil || opts.ProxyPath != "" { - transport = withProxyTransport(transport, opts.ProxyURL, opts.ProxyPath) + if opts.AccessToken == "" && opts.OAuthToken != nil { + transport = &oauth.Transport{ + Base: transport, + Token: opts.OAuthToken, + } } return transport @@ -168,6 +183,7 @@ func (c *client) createHTTPRequest(ctx context.Context, method, p string, body i } else { req.Header.Set("User-Agent", "src-cli/"+version.BuildTag) } + if c.opts.AccessToken != "" { req.Header.Set("Authorization", "token "+c.opts.AccessToken) } @@ -249,10 +265,20 @@ func (r *request) do(ctx context.Context, result any) (bool, error) { // confirm the status code. You can test this easily with e.g. an invalid // endpoint like -endpoint=https://google.com if resp.StatusCode != http.StatusOK { - if resp.StatusCode == http.StatusUnauthorized && isatty.IsCygwinTerminal(os.Stdout.Fd()) { - fmt.Println("You may need to specify or update your access token to use this endpoint.") - fmt.Println("See https://github.com/sourcegraph/src-cli#readme") - fmt.Println("") + if resp.StatusCode == http.StatusUnauthorized { + if oauth.IsOAuthTransport(r.client.httpClient.Transport) { + fmt.Println("The OAuth token is invalid. Please check that the Sourcegraph CLI client is still authorized.") + fmt.Println("") + fmt.Printf("To re-authorize, run: src login --oauth %s\n", r.client.opts.Endpoint) + fmt.Println("") + fmt.Println("Learn more at https://github.com/sourcegraph/src-cli#readme") + fmt.Println("") + } + if isatty.IsCygwinTerminal(os.Stdout.Fd()) { + fmt.Println("You may need to specify or update your access token to use this endpoint.") + fmt.Println("See https://github.com/sourcegraph/src-cli#readme") + fmt.Println("") + } } body, err := io.ReadAll(resp.Body) if err != nil { diff --git a/internal/oauth/flow.go b/internal/oauth/flow.go index b9a960a867..f34d9ae656 100644 --- a/internal/oauth/flow.go +++ b/internal/oauth/flow.go @@ -3,6 +3,7 @@ package oauth import ( + "cmp" "context" "encoding/json" "fmt" @@ -13,6 +14,8 @@ import ( "testing" "time" + "github.com/sourcegraph/src-cli/internal/secrets" + "github.com/sourcegraph/sourcegraph/lib/errors" ) @@ -30,6 +33,8 @@ const ( ScopeEmail string = "email" ScopeOfflineAccess string = "offline_access" ScopeUserAll string = "user:all" + + oauthKey string = "src:oauth" ) var defaultScopes = []string{ScopeEmail, ScopeOfflineAccess, ScopeOpenID, ScopeProfile, ScopeUserAll} @@ -54,20 +59,30 @@ type DeviceAuthResponse struct { type TokenResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token,omitempty"` - TokenType string `json:"token_type"` ExpiresIn int `json:"expires_in,omitempty"` + TokenType string `json:"token_type"` Scope string `json:"scope,omitempty"` } +type Token struct { + Endpoint string `json:"endpoint"` + ClientID string `json:"client_id,omitempty"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresAt time.Time `json:"expires_at"` +} + type ErrorResponse struct { Error string `json:"error"` ErrorDescription string `json:"error_description,omitempty"` } type Client interface { + ClientID() string Discover(ctx context.Context, endpoint string) (*OIDCConfiguration, error) Start(ctx context.Context, endpoint string, scopes []string) (*DeviceAuthResponse, error) Poll(ctx context.Context, endpoint, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error) + Refresh(ctx context.Context, token *Token) (*TokenResponse, error) } type httpClient struct { @@ -78,22 +93,23 @@ type httpClient struct { } func NewClient(clientID string) Client { - return &httpClient{ - clientID: clientID, - client: &http.Client{ - Timeout: 30 * time.Second, - }, - configCache: make(map[string]*OIDCConfiguration), - } + return NewClientWithHTTPClient(clientID, &http.Client{ + Timeout: 30 * time.Second, + }) } -func NewClientWithHTTPClient(c *http.Client) Client { +func NewClientWithHTTPClient(clientID string, c *http.Client) Client { return &httpClient{ + clientID: cmp.Or(clientID, DefaultClientID), client: c, configCache: make(map[string]*OIDCConfiguration), } } +func (c *httpClient) ClientID() string { + return c.clientID +} + // Discover fetches the openid-configuration which contains all the routes a client should // use for authorization, device flows, tokens etc. // @@ -145,7 +161,7 @@ func (c *httpClient) Discover(ctx context.Context, endpoint string) (*OIDCConfig func (c *httpClient) Start(ctx context.Context, endpoint string, scopes []string) (*DeviceAuthResponse, error) { endpoint = strings.TrimRight(endpoint, "/") - // Discover OIDC configuration + // Discover OIDC configuration - caches on first call config, err := c.Discover(ctx, endpoint) if err != nil { return nil, errors.Wrap(err, "OIDC discovery failed") @@ -156,7 +172,7 @@ func (c *httpClient) Start(ctx context.Context, endpoint string, scopes []string } data := url.Values{} - data.Set("client_id", DefaultClientID) + data.Set("client_id", c.clientID) if len(scopes) > 0 { data.Set("scope", strings.Join(scopes, " ")) } else { @@ -208,7 +224,7 @@ func (c *httpClient) Start(ctx context.Context, endpoint string, scopes []string func (c *httpClient) Poll(ctx context.Context, endpoint, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error) { endpoint = strings.TrimRight(endpoint, "/") - // Discover OIDC configuration (should be cached from Start) + // Discover OIDC configuration - caches on first call config, err := c.Discover(ctx, endpoint) if err != nil { return nil, errors.Wrap(err, "OIDC discovery failed") @@ -270,7 +286,7 @@ func (e *PollError) Error() string { func (c *httpClient) pollOnce(ctx context.Context, tokenEndpoint, deviceCode string) (*TokenResponse, error) { data := url.Values{} - data.Set("client_id", DefaultClientID) + data.Set("client_id", c.clientID) data.Set("device_code", deviceCode) data.Set("grant_type", GrantTypeDeviceCode) @@ -307,3 +323,107 @@ func (c *httpClient) pollOnce(ctx context.Context, tokenEndpoint, deviceCode str return &tokenResp, nil } + +// Refresh exchanges a refresh token for a new access token. +func (c *httpClient) Refresh(ctx context.Context, token *Token) (*TokenResponse, error) { + config, err := c.Discover(ctx, token.Endpoint) + if err != nil { + return nil, errors.Wrap(err, "failed to discover OIDC configuration") + } + + if config.TokenEndpoint == "" { + return nil, errors.New("OIDC configuration has no token endpoint") + } + + data := url.Values{} + data.Set("client_id", c.clientID) + data.Set("grant_type", "refresh_token") + data.Set("refresh_token", token.RefreshToken) + + req, err := http.NewRequestWithContext(ctx, "POST", config.TokenEndpoint, strings.NewReader(data.Encode())) + if err != nil { + return nil, errors.Wrap(err, "creating refresh token request") + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := c.client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "refresh token request failed") + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "reading refresh token response") + } + + if resp.StatusCode != http.StatusOK { + var errResp ErrorResponse + if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error != "" { + return nil, errors.Newf("refresh token failed: %s: %s", errResp.Error, errResp.ErrorDescription) + } + return nil, errors.Newf("refresh token failed with status %d: %s", resp.StatusCode, string(body)) + } + + var tokenResp TokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, errors.Wrap(err, "parsing refresh token response") + } + + return &tokenResp, err +} + +func (t *TokenResponse) Token(endpoint string) *Token { + return &Token{ + Endpoint: strings.TrimRight(endpoint, "/"), + RefreshToken: t.RefreshToken, + AccessToken: t.AccessToken, + ExpiresAt: time.Now().Add(time.Second * time.Duration(t.ExpiresIn)), + } +} + +func (t *Token) HasExpired() bool { + return time.Now().After(t.ExpiresAt) +} + +func (t *Token) ExpiringIn(d time.Duration) bool { + future := time.Now().Add(d) + return future.After(t.ExpiresAt) +} + +func StoreToken(ctx context.Context, token *Token) error { + if token.Endpoint == "" { + return errors.New("token endpoint cannot be empty when storing the token") + } + + store, err := secrets.Open(ctx, token.Endpoint) + if err != nil { + return err + } + data, err := json.Marshal(token) + if err != nil { + return errors.Wrap(err, "failed to marshal token") + } + + return store.Put(oauthKey, data) +} + +func LoadToken(ctx context.Context, endpoint string) (*Token, error) { + store, err := secrets.Open(ctx, endpoint) + if err != nil { + return nil, err + } + + data, err := store.Get(oauthKey) + if err != nil { + return nil, err + } + + var t Token + if err := json.Unmarshal(data, &t); err != nil { + return nil, errors.Wrap(err, "failed to unmarshal token") + } + + return &t, nil +} diff --git a/internal/oauth/flow_test.go b/internal/oauth/flow_test.go index 46e3a97036..0b1ad5dc93 100644 --- a/internal/oauth/flow_test.go +++ b/internal/oauth/flow_test.go @@ -267,9 +267,9 @@ func TestStart_NoDeviceEndpoint(t *testing.T) { func TestPoll_Success(t *testing.T) { wantToken := TokenResponse{ AccessToken: "test-access-token", - TokenType: "Bearer", ExpiresIn: 3600, Scope: "read write", + TokenType: "Bearer", } server := newTestServer(t, testServerOptions{ @@ -313,6 +313,7 @@ func TestPoll_Success(t *testing.T) { if resp.TokenType != wantToken.TokenType { t.Errorf("TokenType = %q, want %q", resp.TokenType, wantToken.TokenType) } + } func TestPoll_AuthorizationPending(t *testing.T) { @@ -507,3 +508,66 @@ func TestPoll_ContextCancellation(t *testing.T) { t.Errorf("error = %v, want context.Canceled or wrapped context canceled error", err) } } + +func TestRefresh_Success(t *testing.T) { + server := newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + if got := r.FormValue("grant_type"); got != "refresh_token" { + t.Errorf("grant_type = %q, want %q", got, "refresh_token") + } + if got := r.FormValue("refresh_token"); got != "test-refresh-token" { + t.Errorf("refresh_token = %q, want %q", got, "test-refresh-token") + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "new-access-token", + RefreshToken: "new-refresh-token", + ExpiresIn: 3600, + TokenType: "Bearer", + }) + }, + }, + }) + defer server.Close() + + client := NewClient(DefaultClientID) + token := &Token{ + Endpoint: server.URL, + AccessToken: "new-access-token", + RefreshToken: "test-refresh-token", + ExpiresAt: time.Now().Add(time.Second * time.Duration(3600)), + } + resp, err := client.Refresh(context.Background(), token) + if err != nil { + t.Fatalf("Refresh() error = %v", err) + } + + if resp.AccessToken != "new-access-token" { + t.Errorf("AccessToken = %q, want %q", resp.AccessToken, "new-access-token") + } + if resp.RefreshToken != "new-refresh-token" { + t.Errorf("RefreshToken = %q, want %q", resp.RefreshToken, "new-refresh-token") + } +} + +func TestRefresh_DiscoverFailure(t *testing.T) { + client := NewClient(DefaultClientID) + token := &Token{ + Endpoint: "http://127.0.0.1:1", + RefreshToken: "test-refresh-token", + } + + _, err := client.Refresh(context.Background(), token) + if err == nil { + t.Fatal("Refresh() expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to discover OIDC configuration") { + t.Errorf("error = %q, want discovery failure context", err.Error()) + } +} diff --git a/internal/oauth/http_transport.go b/internal/oauth/http_transport.go new file mode 100644 index 0000000000..8aea3d75a4 --- /dev/null +++ b/internal/oauth/http_transport.go @@ -0,0 +1,87 @@ +package oauth + +import ( + "context" + "net/http" + "sync" + "time" +) + +var _ http.Transport + +var _ http.RoundTripper = (*Transport)(nil) + +type Transport struct { + Base http.RoundTripper + Token *Token + + mu sync.Mutex +} + +// storeRefreshedTokenFn is the function the transport should use to persist the token - mainly used during +// tests to swap out the implementation out with a mock +var storeRefreshedTokenFn = StoreToken + +// RoundTrip implements http.RoundTripper. +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + ctx := req.Context() + + if err := t.refreshToken(ctx); err != nil { + return nil, err + } + + req2 := req.Clone(req.Context()) + req2.Header.Set("Authorization", "Bearer "+t.Token.AccessToken) + + if t.Base != nil { + return t.Base.RoundTrip(req2) + } + return http.DefaultTransport.RoundTrip(req2) +} + +// refreshToken checks if the token has expired or expiring soon and refreshes it. Once the token is +// refreshed, the in-memory token is updated and a best effort is made to store the token. +// If storing the token fails, no error is returned. +func (t *Transport) refreshToken(ctx context.Context) error { + t.mu.Lock() + defer t.mu.Unlock() + + prevToken := t.Token + token, err := maybeRefresh(ctx, t.Token) + if err != nil { + return err + } + t.Token = token + if token != prevToken { + // try to save the token if we fail let the request continue with in memory token + _ = storeRefreshedTokenFn(ctx, token) + } + + return nil +} + +// maybeRefresh conditionally refreshes the token. If the token has expired or is expriing in the next 30s +// it will be refreshed and the updated token will be returned. Otherwise, no refresh occurs and the original +// token is returned. +func maybeRefresh(ctx context.Context, token *Token) (*Token, error) { + // token has NOT expired and is NOT about to expire in 30s + if !(token.HasExpired() || token.ExpiringIn(time.Duration(30)*time.Second)) { + return token, nil + } + client := NewClient(token.ClientID) + + resp, err := client.Refresh(ctx, token) + if err != nil { + return nil, err + } + + next := resp.Token(token.Endpoint) + next.ClientID = token.ClientID + return next, nil +} + +// IsOAuthTransport checks wether the underlying type of the given RoundTripper is a OAuthTransport +func IsOAuthTransport(trp http.RoundTripper) bool { + _, ok := trp.(*Transport) + return ok +} diff --git a/internal/oauth/http_transport_test.go b/internal/oauth/http_transport_test.go new file mode 100644 index 0000000000..4dac832d05 --- /dev/null +++ b/internal/oauth/http_transport_test.go @@ -0,0 +1,172 @@ +package oauth + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func newRefreshServer(t *testing.T, accessToken string) *httptest.Server { + t.Helper() + return newTestServer(t, testServerOptions{ + handlers: map[string]http.HandlerFunc{ + testTokenPath: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"` + accessToken + `","refresh_token":"new-refresh","expires_in":3600}`)) + }, + }, + }) +} + +func TestMaybeRefresh(t *testing.T) { + server := newRefreshServer(t, "new-token") + defer server.Close() + + tests := []struct { + name string + token *Token + wantAccess string + wantSame bool + }{ + { + name: "unchanged when still valid", + token: &Token{ + AccessToken: "valid-token", + ExpiresAt: time.Now().Add(time.Hour), + }, + wantAccess: "valid-token", + wantSame: true, + }, + { + name: "refreshes expired token", + token: &Token{ + Endpoint: server.URL, + AccessToken: "expired-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(-time.Hour), + }, + wantAccess: "new-token", + }, + { + name: "refreshes token expiring soon", + token: &Token{ + Endpoint: server.URL, + AccessToken: "expiring-soon-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(10 * time.Second), + }, + wantAccess: "new-token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := maybeRefresh(context.Background(), tt.token) + if err != nil { + t.Fatalf("maybeRefresh() error = %v", err) + } + if got.AccessToken != tt.wantAccess { + t.Errorf("AccessToken = %q, want %q", got.AccessToken, tt.wantAccess) + } + if tt.wantSame && got != tt.token { + t.Errorf("token pointer changed for unexpired token") + } + }) + } +} + +func TestTransportRoundTrip(t *testing.T) { + tests := []struct { + name string + token *Token + persistErr error + wantAuthHeader string + wantStoreCalls int + }{ + { + name: "uses existing token without persisting", + token: &Token{ + AccessToken: "valid-token", + ExpiresAt: time.Now().Add(time.Hour), + }, + wantAuthHeader: "Bearer valid-token", + wantStoreCalls: 0, + }, + { + name: "persists refreshed token", + token: &Token{ + AccessToken: "expired-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(-time.Hour), + }, + wantAuthHeader: "Bearer new-token", + wantStoreCalls: 1, + }, + { + name: "ignores persist failures", + token: &Token{ + AccessToken: "expired-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(-time.Hour), + }, + persistErr: errors.New("persist failed"), + wantAuthHeader: "Bearer new-token", + wantStoreCalls: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.wantStoreCalls > 0 { + server := newRefreshServer(t, "new-token") + defer server.Close() + tt.token.Endpoint = server.URL + } + + originalStoreFn := storeRefreshedTokenFn + defer func() { storeRefreshedTokenFn = originalStoreFn }() + + var storeCalls int + var storedToken *Token + storeRefreshedTokenFn = func(_ context.Context, token *Token) error { + storeCalls++ + storedToken = token + return tt.persistErr + } + + var capturedAuth string + tr := &Transport{ + Base: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + capturedAuth = req.Header.Get("Authorization") + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }), + Token: tt.token, + } + + _, err := tr.RoundTrip(httptest.NewRequest(http.MethodGet, "http://example.com", nil)) + if err != nil { + t.Fatalf("RoundTrip() error = %v", err) + } + + if capturedAuth != tt.wantAuthHeader { + t.Errorf("Authorization = %q, want %q", capturedAuth, tt.wantAuthHeader) + } + if storeCalls != tt.wantStoreCalls { + t.Errorf("store calls = %d, want %d", storeCalls, tt.wantStoreCalls) + } + if tt.wantStoreCalls > 0 && (storedToken == nil || storedToken.AccessToken != "new-token") { + t.Errorf("stored token = %#v, want access token %q", storedToken, "new-token") + } + }) + } +} diff --git a/internal/secrets/keyring.go b/internal/secrets/keyring.go new file mode 100644 index 0000000000..9464b1054c --- /dev/null +++ b/internal/secrets/keyring.go @@ -0,0 +1,77 @@ +package secrets + +import ( + "context" + "strings" + + "github.com/sourcegraph/sourcegraph/lib/errors" + "github.com/zalando/go-keyring" +) + +var ErrSecretNotFound = errors.New("secret not found") + +const serviceNamePrefix = "Sourcegraph CLI" + +type keyringStore struct { + ctx context.Context + serviceName string +} + +// Open opens the system keyring for the Sourcegraph CLI. +func Open(ctx context.Context, endpoint string) (*keyringStore, error) { + endpoint = strings.TrimRight(strings.TrimSpace(endpoint), "/") + if endpoint == "" { + return nil, errors.New("endpoint cannot be empty") + } + + serviceName := serviceNamePrefix + " <" + endpoint + ">" + + return &keyringStore{ctx: ctx, serviceName: serviceName}, nil +} + +// withContext runs fn in a goroutine and returns its result, or ctx.Err() if the context is cancelled first. +func withContext[T any](ctx context.Context, fn func() (T, error)) (T, error) { + type result struct { + val T + err error + } + ch := make(chan result, 1) + go func() { + val, err := fn() + ch <- result{val, err} + }() + + select { + case <-ctx.Done(): + var zero T + return zero, ctx.Err() + case r := <-ch: + return r.val, r.err + } +} + +// Put stores a key-value pair in the keyring. +func (k *keyringStore) Put(key string, data []byte) error { + _, err := withContext(k.ctx, func() (struct{}, error) { + err := keyring.Set(k.serviceName, key, string(data)) + if err != nil { + return struct{}{}, errors.Wrap(err, "storing item in keyring") + } + return struct{}{}, nil + }) + return err +} + +// Get retrieves a value by key from the keyring. +func (k *keyringStore) Get(key string) ([]byte, error) { + return withContext(k.ctx, func() ([]byte, error) { + secret, err := keyring.Get(k.serviceName, key) + if err != nil { + if err == keyring.ErrNotFound { + return nil, ErrSecretNotFound + } + return nil, errors.Wrap(err, "getting item from keyring") + } + return []byte(secret), nil + }) +} diff --git a/internal/secrets/keyring_test.go b/internal/secrets/keyring_test.go new file mode 100644 index 0000000000..65d2002782 --- /dev/null +++ b/internal/secrets/keyring_test.go @@ -0,0 +1,58 @@ +package secrets + +import ( + "context" + "testing" +) + +func TestOpen(t *testing.T) { + tests := []struct { + name string + endpoint string + wantServiceName string + wantErr bool + }{ + { + name: "normalized endpoint", + endpoint: " https://sourcegraph.example.com/ ", + wantServiceName: "Sourcegraph CLI ", + }, + { + name: "normalized endpoint with path", + endpoint: " https://sourcegraph.example.com/sourcegraph/ ", + wantServiceName: "Sourcegraph CLI ", + }, + { + name: "normalized endpoint with nested path", + endpoint: "https://sourcegraph.example.com/custom/path///", + wantServiceName: "Sourcegraph CLI ", + }, + { + name: "empty endpoint", + endpoint: " / ", + wantErr: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + store, err := Open(context.Background(), test.endpoint) + if test.wantErr { + if err == nil { + t.Fatal("Open() error = nil, want non-nil") + } + if store != nil { + t.Fatalf("Open() store = %v, want nil", store) + } + return + } + + if err != nil { + t.Fatalf("Open() error = %v, want nil", err) + } + if got := store.serviceName; got != test.wantServiceName { + t.Fatalf("Open() serviceName = %q, want %q", got, test.wantServiceName) + } + }) + } +}