Skip to content

Commit da1619d

Browse files
authored
Update order of operation for request interceptor and serialization in gremlin-go (#3358)
* Add proper error handling for non-GraphBinary cases
* Update order of interceptor and serialization of request
* Export request and rename to RequestMessage for clarity and public access like other GLVs
* Add a WaitGroup (wg) to ensure in-flight goroutines complete, and drain the response body before close to prevent TCP RST errors
1 parent 06e03ba commit da1619d

11 files changed

Lines changed: 912 additions & 118 deletions

CHANGELOG.asciidoc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ image::https://raw.githubusercontent.com/apache/tinkerpop/master/docs/static/ima
4141
* Removed deprecated `Graph.traversal()` method in JS in favor of the anonymous `traversal()` function.
4242
* Replace `Bytecode` with `GremlinLang` & update serialization to GraphBinary 4 for `gremlin-go`.
4343
* Added `RequestInterceptor` to `gremlin-go` with `auth` reference implementations to replace `authInfo`.
44-
* Refactored GraphBinary serializers to use `io.Writer` and `io.Reader` instead of `*bytes.Buffer` for streaming capacities.
44+
* Refactored GraphBinary serializers in `gremlin-go` to use `io.Writer` and `io.Reader` instead of `*bytes.Buffer` for streaming capacities.
4545
* Refactored `httpProtocol` and `httpTransport` in `gremlin-go` into single `connection.go` that handles HTTP request and response.
46+
* Reordered interceptor chain in `gremlin-go` so interceptors access raw request before serialization.
47+
* Exported `request` in `gremlin-go` as `RequestMessage` with public `Gremlin`/`Fields` for clarity, access and consistency.
4648
* Refactored result handling in `gremlin-driver` by merging `ResultQueue` into `ResultSet`.
4749
* Replace `Bytecode` with `GremlinLang` in `gremlin-dotnet`.
4850
* Replace `WebSocket` with `HTTP` (non-streaming) in `gremlin-dotnet`.

gremlin-go/driver/auth.go

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ package gremlingo
2222
import (
2323
"context"
2424
"encoding/base64"
25+
"fmt"
2526
"sync"
2627
"time"
2728

@@ -39,26 +40,40 @@ func BasicAuth(username, password string) RequestInterceptor {
3940
}
4041
}
4142

42-
// Sigv4Auth returns a RequestInterceptor that signs requests using AWS SigV4.
43+
// SigV4Auth returns a RequestInterceptor that signs requests using AWS SigV4.
4344
// It passes a nil credentials provider to the WithCredentials variant, so the default AWS credential chain (env vars, shared config, IAM role, etc.) is used.
44-
func Sigv4Auth(region, service string) RequestInterceptor {
45-
return Sigv4AuthWithCredentials(region, service, nil)
45+
func SigV4Auth(region, service string) RequestInterceptor {
46+
return SigV4AuthWithCredentials(region, service, nil)
4647
}
4748

48-
// Sigv4AuthWithCredentials returns a RequestInterceptor that signs requests using AWS SigV4
49+
// SigV4AuthWithCredentials returns a RequestInterceptor that signs requests using AWS SigV4
4950
// with the provided credentials provider. If provider is nil, uses default credential chain.
51+
// If the request body has not been serialized yet (*RequestMessage), it is automatically
52+
// serialized to GraphBinary before signing.
5053
//
5154
// Caches the signer and credentials provider for efficiency.
52-
func Sigv4AuthWithCredentials(region, service string, credentialsProvider aws.CredentialsProvider) RequestInterceptor {
55+
func SigV4AuthWithCredentials(region, service string, credentialsProvider aws.CredentialsProvider) RequestInterceptor {
5356
// Create signer once - it's stateless and safe to reuse
5457
signer := v4.NewSigner()
58+
serialize := SerializeRequest()
5559

5660
// Cache for resolved credentials provider (lazy initialization)
5761
var cachedProvider aws.CredentialsProvider
5862
var providerOnce sync.Once
5963
var providerErr error
6064

6165
return func(req *HttpRequest) error {
66+
// If Body is still *RequestMessage, serialize it to GraphBinary before signing.
67+
if _, ok := req.Body.(*RequestMessage); ok {
68+
if err := serialize(req); err != nil {
69+
return fmt.Errorf("SigV4 auto-serialization failed: %w", err)
70+
}
71+
}
72+
73+
if _, ok := req.Body.([]byte); !ok {
74+
return fmt.Errorf("SigV4 signing requires body to be []byte; got %T", req.Body)
75+
}
76+
6277
ctx := context.Background()
6378

6479
// Resolve credentials provider once if not provided

gremlin-go/driver/auth_test.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import (
3030
)
3131

3232
func createMockRequest() *HttpRequest {
33-
req, _ := NewHttpRequest("POST", "https://localhost:8182/gremlin")
33+
req, _ := NewHttpRequest("POST", "https://test_url:8182/gremlin")
3434
req.Headers.Set("Content-Type", graphBinaryMimeType)
3535
req.Headers.Set("Accept", graphBinaryMimeType)
3636
req.Body = []byte(`{"gremlin":"g.V()"}`)
@@ -72,24 +72,24 @@ func (m *mockCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials
7272
}, nil
7373
}
7474

75-
func TestSigv4Auth(t *testing.T) {
75+
func TestSigV4Auth(t *testing.T) {
7676
t.Run("adds signed headers", func(t *testing.T) {
7777
req := createMockRequest()
7878
assert.Empty(t, req.Headers.Get("Authorization"))
7979
assert.Empty(t, req.Headers.Get("X-Amz-Date"))
8080

8181
provider := &mockCredentialsProvider{
82-
accessKey: "MOCK_ACCESS_KEY",
83-
secretKey: "MOCK_SECRET_KEY",
82+
accessKey: "MOCK_ID",
83+
secretKey: "MOCK_KEY",
8484
}
85-
interceptor := Sigv4AuthWithCredentials("us-west-2", "neptune-db", provider)
85+
interceptor := SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", provider)
8686
err := interceptor(req)
8787

8888
assert.NoError(t, err)
8989
assert.NotEmpty(t, req.Headers.Get("X-Amz-Date"))
9090
authHeader := req.Headers.Get("Authorization")
91-
assert.True(t, strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256 Credential=MOCK_ACCESS_KEY"))
92-
assert.Contains(t, authHeader, "us-west-2/neptune-db/aws4_request")
91+
assert.True(t, strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256 Credential=MOCK_ID"))
92+
assert.Contains(t, authHeader, "gremlin-east-1/tinkerpop-sigv4/aws4_request")
9393
assert.Contains(t, authHeader, "Signature=")
9494
})
9595

@@ -98,17 +98,17 @@ func TestSigv4Auth(t *testing.T) {
9898
assert.Empty(t, req.Headers.Get("X-Amz-Security-Token"))
9999

100100
provider := &mockCredentialsProvider{
101-
accessKey: "MOCK_ACCESS_KEY",
102-
secretKey: "MOCK_SECRET_KEY",
103-
sessionToken: "MOCK_SESSION_TOKEN",
101+
accessKey: "MOCK_ID",
102+
secretKey: "MOCK_KEY",
103+
sessionToken: "MOCK_TOKEN",
104104
}
105-
interceptor := Sigv4AuthWithCredentials("us-west-2", "neptune-db", provider)
105+
interceptor := SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", provider)
106106
err := interceptor(req)
107107

108108
assert.NoError(t, err)
109-
assert.Equal(t, "MOCK_SESSION_TOKEN", req.Headers.Get("X-Amz-Security-Token"))
109+
assert.Equal(t, "MOCK_TOKEN", req.Headers.Get("X-Amz-Security-Token"))
110110
authHeader := req.Headers.Get("Authorization")
111111
assert.True(t, strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256 Credential="))
112-
assert.Contains(t, authHeader, "Signature=")
112+
assert.Contains(t, authHeader, "gremlin-east-1/tinkerpop-sigv4/aws4_request")
113113
})
114114
}

gremlin-go/driver/connection.go

Lines changed: 104 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -22,69 +22,17 @@ package gremlingo
2222
import (
2323
"bytes"
2424
"compress/zlib"
25-
"crypto/sha256"
2625
"crypto/tls"
27-
"encoding/hex"
26+
"encoding/json"
27+
"fmt"
2828
"io"
2929
"net"
3030
"net/http"
31-
"net/url"
31+
"strings"
32+
"sync"
3233
"time"
3334
)
3435

35-
// Common HTTP header keys
36-
const (
37-
HeaderContentType = "Content-Type"
38-
HeaderAccept = "Accept"
39-
HeaderUserAgent = "User-Agent"
40-
HeaderAcceptEncoding = "Accept-Encoding"
41-
HeaderAuthorization = "Authorization"
42-
)
43-
44-
// HttpRequest represents an HTTP request that can be modified by interceptors.
45-
type HttpRequest struct {
46-
Method string
47-
URL *url.URL
48-
Headers http.Header
49-
Body []byte
50-
}
51-
52-
// NewHttpRequest creates a new HttpRequest with the given method and URL.
53-
func NewHttpRequest(method, rawURL string) (*HttpRequest, error) {
54-
u, err := url.Parse(rawURL)
55-
if err != nil {
56-
return nil, err
57-
}
58-
return &HttpRequest{
59-
Method: method,
60-
URL: u,
61-
Headers: make(http.Header),
62-
}, nil
63-
}
64-
65-
// ToStdRequest converts HttpRequest to a standard http.Request for signing.
66-
// Returns nil if the request cannot be created (invalid method or URL).
67-
func (r *HttpRequest) ToStdRequest() (*http.Request, error) {
68-
req, err := http.NewRequest(r.Method, r.URL.String(), bytes.NewReader(r.Body))
69-
if err != nil {
70-
return nil, err
71-
}
72-
req.Header = r.Headers
73-
return req, nil
74-
}
75-
76-
// PayloadHash returns the SHA256 hash of the request body for SigV4 signing.
77-
func (r *HttpRequest) PayloadHash() string {
78-
if len(r.Body) == 0 {
79-
return "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" // SHA256 of empty string
80-
}
81-
h := sha256.Sum256(r.Body)
82-
return hex.EncodeToString(h[:])
83-
}
84-
85-
// RequestInterceptor is a function that modifies an HTTP request before it is sent.
86-
type RequestInterceptor func(*HttpRequest) error
87-
8836
// connectionSettings holds configuration for the connection.
8937
type connectionSettings struct {
9038
tlsConfig *tls.Config
@@ -106,6 +54,7 @@ type connection struct {
10654
logHandler *logHandler
10755
serializer *GraphBinarySerializer
10856
interceptors []RequestInterceptor
57+
wg sync.WaitGroup
10958
}
11059

11160
// Connection pool defaults aligned with Java driver
@@ -171,21 +120,19 @@ func (c *connection) AddInterceptor(interceptor RequestInterceptor) {
171120
}
172121

173122
// submit returns a ResultSet immediately and runs the request on a goroutine tracked by c.wg; executeAndStream streams results into the ResultSet and reports failures through it.
174-
func (c *connection) submit(req *request) (ResultSet, error) {
123+
func (c *connection) submit(req *RequestMessage) (ResultSet, error) {
175124
rs := newChannelResultSet()
176125

177-
data, err := c.serializer.SerializeMessage(req)
178-
if err != nil {
179-
rs.Close()
180-
return rs, err
181-
}
182-
183-
go c.executeAndStream(data, rs)
126+
c.wg.Add(1) // registered before the goroutine starts so close() can wait for this request
127+
go func() {
128+
defer c.wg.Done()
129+
c.executeAndStream(req, rs)
130+
}()
184131

185132
return rs, nil // serialization now happens inside executeAndStream, so no error is produced here
186133
}
187134

188-
func (c *connection) executeAndStream(data []byte, rs ResultSet) {
135+
func (c *connection) executeAndStream(req *RequestMessage, rs ResultSet) {
189136
defer rs.Close()
190137

191138
// Create HttpRequest for interceptors
@@ -195,12 +142,15 @@ func (c *connection) executeAndStream(data []byte, rs ResultSet) {
195142
rs.setError(err)
196143
return
197144
}
198-
httpReq.Body = data
199145

200146
// Set default headers before interceptors
201147
c.setHttpRequestHeaders(httpReq)
202148

203-
// Apply interceptors
149+
// Set Body to the raw *RequestMessage so interceptors can inspect/modify it
150+
httpReq.Body = req
151+
152+
// Apply interceptors — they see *RequestMessage in Body (pre-serialization).
153+
// Interceptors may replace Body with []byte, io.Reader, or *http.Request.
204154
for _, interceptor := range c.interceptors {
205155
if err := interceptor(httpReq); err != nil {
206156
c.logHandler.logf(Error, failedToSendRequest, err.Error())
@@ -209,27 +159,90 @@ func (c *connection) executeAndStream(data []byte, rs ResultSet) {
209159
}
210160
}
211161

212-
// Create actual http.Request from HttpRequest
213-
req, err := http.NewRequest(httpReq.Method, httpReq.URL.String(), bytes.NewReader(httpReq.Body))
214-
if err != nil {
215-
c.logHandler.logf(Error, failedToSendRequest, err.Error())
216-
rs.setError(err)
162+
// After interceptors, serialize if Body is still *RequestMessage
163+
if r, ok := httpReq.Body.(*RequestMessage); ok {
164+
if c.serializer != nil {
165+
data, err := c.serializer.SerializeMessage(r)
166+
if err != nil {
167+
c.logHandler.logf(Error, failedToSendRequest, err.Error())
168+
rs.setError(err)
169+
return
170+
}
171+
httpReq.Body = data
172+
} else {
173+
errMsg := "request body was not serialized; either provide a serializer or add an interceptor that serializes the request"
174+
c.logHandler.logf(Error, failedToSendRequest, errMsg)
175+
rs.setError(fmt.Errorf("%s", errMsg))
176+
return
177+
}
178+
}
179+
180+
// Create actual http.Request from HttpRequest based on Body type
181+
var httpGoReq *http.Request
182+
switch body := httpReq.Body.(type) {
183+
case []byte:
184+
httpGoReq, err = http.NewRequest(httpReq.Method, httpReq.URL.String(), bytes.NewReader(body))
185+
if err != nil {
186+
c.logHandler.logf(Error, failedToSendRequest, err.Error())
187+
rs.setError(err)
188+
return
189+
}
190+
httpGoReq.Header = httpReq.Headers
191+
case io.Reader:
192+
httpGoReq, err = http.NewRequest(httpReq.Method, httpReq.URL.String(), body)
193+
if err != nil {
194+
c.logHandler.logf(Error, failedToSendRequest, err.Error())
195+
rs.setError(err)
196+
return
197+
}
198+
httpGoReq.Header = httpReq.Headers
199+
case *http.Request:
200+
httpGoReq = body
201+
default:
202+
errMsg := fmt.Sprintf("unsupported body type after interceptors: %T", body)
203+
c.logHandler.logf(Error, failedToSendRequest, errMsg)
204+
rs.setError(fmt.Errorf("%s", errMsg))
217205
return
218206
}
219-
req.Header = httpReq.Headers
220207

221-
resp, err := c.httpClient.Do(req)
208+
resp, err := c.httpClient.Do(httpGoReq)
222209
if err != nil {
223210
c.logHandler.logf(Error, failedToSendRequest, err.Error())
224211
rs.setError(err)
225212
return
226213
}
227214
defer func() {
215+
// Drain any unread bytes so the connection can be reused gracefully.
216+
// Without this, Go's HTTP client sends a TCP RST instead of FIN,
217+
// causing "Connection reset by peer" errors on the server.
218+
io.Copy(io.Discard, resp.Body)
228219
if err := resp.Body.Close(); err != nil {
229220
c.logHandler.logf(Debug, failedToCloseResponseBody, err.Error())
230221
}
231222
}()
232223

224+
// If the HTTP status indicates an error and the response is not GraphBinary,
225+
// read the body as a text/JSON error message instead of attempting binary
226+
// deserialization which would produce cryptic errors.
227+
contentType := resp.Header.Get(HeaderContentType)
228+
if resp.StatusCode >= 400 && !strings.Contains(contentType, graphBinaryMimeType) {
229+
bodyBytes, readErr := io.ReadAll(resp.Body)
230+
if readErr != nil {
231+
c.logHandler.logf(Error, failedToReceiveResponse, readErr.Error())
232+
rs.setError(fmt.Errorf("Gremlin Server returned HTTP %d and failed to read body: %w",
233+
resp.StatusCode, readErr))
234+
return
235+
}
236+
errorBody := string(bodyBytes)
237+
errorMsg := tryExtractJSONError(errorBody)
238+
if errorMsg == "" {
239+
errorMsg = fmt.Sprintf("Gremlin Server returned HTTP %d: %s", resp.StatusCode, errorBody)
240+
}
241+
c.logHandler.logf(Error, failedToReceiveResponse, errorMsg)
242+
rs.setError(fmt.Errorf("%s", errorMsg))
243+
return
244+
}
245+
233246
reader, zlibReader, err := c.getReader(resp)
234247
if err != nil {
235248
c.logHandler.logf(Error, failedToReceiveResponse, err.Error())
@@ -308,6 +321,23 @@ func (c *connection) streamToResultSet(reader io.Reader, rs ResultSet) {
308321
}
309322
}
310323

324+
// tryExtractJSONError attempts to extract an error message from a JSON response body.
325+
// The server sometimes responds with a JSON object containing a "message" field
326+
// even when it cannot produce a GraphBinary response. It returns the top-level "message" string, or "" when body does not unmarshal into a JSON object or has no string-valued "message" field.
327+
func tryExtractJSONError(body string) string {
328+
var obj map[string]interface{}
329+
if err := json.Unmarshal([]byte(body), &obj); err != nil {
330+
return "" // body could not be unmarshaled into a JSON object
331+
}
332+
if msg, ok := obj["message"]; ok {
333+
if s, ok := msg.(string); ok {
334+
return s
335+
}
336+
}
337+
return "" // no "message" key, or its value is not a string
338+
}
339+
311340
// close blocks until every goroutine registered on c.wg has finished, then closes the HTTP client's idle connections.
func (c *connection) close() {
341+
c.wg.Wait() // wait for in-flight submit goroutines before tearing down connections
312342
c.httpClient.CloseIdleConnections()
313343
}

0 commit comments

Comments
 (0)