-
Notifications
You must be signed in to change notification settings - Fork 3.4k
per-host http client pool #7301
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from 1 commit
95739a2
fa6ee67
5516e1c
ae6b21f
96360b8
ab8cf23
7024544
6c3d9b8
60d0627
32134d5
cef06fe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,10 +7,12 @@ import ( | |
| "net" | ||
| "net/http" | ||
| "net/http/cookiejar" | ||
| "net/http/httptrace" | ||
| "net/url" | ||
| "strconv" | ||
| "strings" | ||
| "sync" | ||
| "sync/atomic" | ||
| "time" | ||
|
|
||
| "github.com/pkg/errors" | ||
|
|
@@ -22,16 +24,48 @@ import ( | |
| "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" | ||
| "github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils" | ||
| "github.com/projectdiscovery/nuclei/v3/pkg/types" | ||
| "github.com/projectdiscovery/nuclei/v3/pkg/types/scanstrategy" | ||
| "github.com/projectdiscovery/rawhttp" | ||
| "github.com/projectdiscovery/retryablehttp-go" | ||
| urlutil "github.com/projectdiscovery/utils/url" | ||
| ) | ||
|
|
||
| var ( | ||
| forceMaxRedirects int | ||
| connStats ConnectionStats | ||
| ) | ||
|
|
||
| // ConnectionStats tracks HTTP connection reuse across the scan. | ||
| type ConnectionStats struct { | ||
| New atomic.Int64 | ||
| Reused atomic.Int64 | ||
| } | ||
|
|
||
| // GetConnectionStats returns the current connection statistics. | ||
| func GetConnectionStats() (newConns, reused int64) { | ||
| return connStats.New.Load(), connStats.Reused.Load() | ||
| } | ||
|
|
||
| // connTrackingTransport wraps an http.RoundTripper to track connection reuse | ||
| // via httptrace. Every request gets a GotConn callback that increments the | ||
| // appropriate counter before delegating to the underlying transport. | ||
| type connTrackingTransport struct { | ||
| base http.RoundTripper | ||
| } | ||
|
|
||
| func (t *connTrackingTransport) RoundTrip(req *http.Request) (*http.Response, error) { | ||
| trace := &httptrace.ClientTrace{ | ||
| GotConn: func(info httptrace.GotConnInfo) { | ||
| if info.Reused { | ||
| connStats.Reused.Add(1) | ||
| } else { | ||
| connStats.New.Add(1) | ||
| } | ||
| }, | ||
| } | ||
| req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) | ||
| return t.base.RoundTrip(req) | ||
| } | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
| // Init initializes the clientpool implementation | ||
| func Init(options *types.Options) error { | ||
| if options.ShouldFollowHTTPRedirects() { | ||
|
|
@@ -167,21 +201,15 @@ func GetRawHTTP(options *protocols.ExecutorOptions) *rawhttp.Client { | |
| return dialers.RawHTTPClient | ||
| } | ||
|
|
||
| // Get creates or gets a client for the protocol based on custom configuration | ||
| func Get(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) { | ||
| if configuration.HasStandardOptions() { | ||
| dialers := protocolstate.GetDialersWithId(options.ExecutionId) | ||
| if dialers == nil { | ||
| return nil, fmt.Errorf("dialers not initialized for %s", options.ExecutionId) | ||
| } | ||
| return dialers.DefaultHTTPClient, nil | ||
| } | ||
|
|
||
| return wrappedGet(options, configuration) | ||
| // Get creates or gets a client for the protocol based on custom configuration. | ||
| // The host parameter scopes the client to a specific target, enabling per-host | ||
| // connection reuse with keep-alive. Pass an empty string for non-scanning uses. | ||
| func Get(options *types.Options, configuration *Configuration, host string) (*retryablehttp.Client, error) { | ||
| return wrappedGet(options, configuration, host) | ||
| } | ||
|
|
||
| // wrappedGet wraps a get operation without normal client check | ||
| func wrappedGet(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) { | ||
| func wrappedGet(options *types.Options, configuration *Configuration, host string) (*retryablehttp.Client, error) { | ||
| var err error | ||
|
|
||
| dialers := protocolstate.GetDialersWithId(options.ExecutionId) | ||
|
|
@@ -190,28 +218,30 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl | |
| } | ||
|
|
||
| hash := configuration.Hash() | ||
| if host != "" { | ||
| hash += ":" + host | ||
| } | ||
| if client, ok := dialers.HTTPClientPool.Get(hash); ok { | ||
| return client, nil | ||
| } | ||
|
|
||
| // Multiple Host | ||
| retryableHttpOptions := retryablehttp.DefaultOptionsSpraying | ||
| disableKeepAlives := true | ||
| maxIdleConns := 0 | ||
| maxConnsPerHost := 0 | ||
| maxIdleConnsPerHost := -1 | ||
| // do not split given timeout into chunks for retry | ||
| // because this won't work on slow hosts | ||
| // Each client is scoped to a single host, so we optimize for connection | ||
| // reuse: keep-alive always on, small idle pool, and an idle timeout that | ||
| // lets the transport reclaim unused connections automatically. | ||
| retryableHttpOptions := retryablehttp.DefaultOptionsSingle | ||
| retryableHttpOptions.NoAdjustTimeout = true | ||
|
|
||
| if configuration.Threads > 0 || options.ScanStrategy == scanstrategy.HostSpray.String() { | ||
| // Single host | ||
| retryableHttpOptions = retryablehttp.DefaultOptionsSingle | ||
| disableKeepAlives = false | ||
| maxIdleConnsPerHost = 500 | ||
| maxConnsPerHost = 500 | ||
| maxIdleConns := 4 | ||
| maxIdleConnsPerHost := 4 | ||
| maxConnsPerHost := 25 | ||
| if configuration.Threads > 0 { | ||
| maxIdleConnsPerHost = configuration.Threads | ||
| maxIdleConns = configuration.Threads | ||
| maxConnsPerHost = configuration.Threads | ||
| } | ||
|
|
||
| disableKeepAlives := configuration.Connection != nil && configuration.Connection.DisableKeepAlive | ||
|
|
||
|
Comment on lines
235
to
+258
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Include
🩹 Proposed fix func (c *Configuration) Hash() string {
builder := &strings.Builder{}
builder.Grow(16)
@@
builder.WriteString("c")
builder.WriteString(strconv.FormatBool(c.Connection != nil))
+ if c.Connection != nil {
+ builder.WriteString("d")
+ builder.WriteString(strconv.FormatBool(c.Connection.DisableKeepAlive))
+ }
if c.Connection != nil && c.Connection.CustomMaxTimeout > 0 {
builder.WriteString("k")
builder.WriteString(c.Connection.CustomMaxTimeout.String())
}🤖 Prompt for AI Agents |
||
| retryableHttpOptions.RetryWaitMax = 10 * time.Second | ||
| retryableHttpOptions.RetryMax = options.Retries | ||
| retryableHttpOptions.Timeout = time.Duration(options.Timeout) * time.Second | ||
|
|
@@ -222,7 +252,6 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl | |
| maxRedirects := configuration.MaxRedirects | ||
|
|
||
| if forceMaxRedirects > 0 { | ||
| // by default we enable general redirects following | ||
| switch { | ||
| case options.FollowHostRedirects: | ||
| redirectFlow = FollowSameHostRedirect | ||
|
|
@@ -238,30 +267,23 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl | |
| maxRedirects = 0 | ||
| } | ||
|
|
||
| // override connection's settings if required | ||
| if configuration.Connection != nil { | ||
| disableKeepAlives = configuration.Connection.DisableKeepAlive | ||
| } | ||
|
|
||
| // Set the base TLS configuration definition | ||
| tlsConfig := &tls.Config{ | ||
| Renegotiation: tls.RenegotiateOnceAsClient, | ||
| InsecureSkipVerify: true, | ||
| MinVersion: tls.VersionTLS10, | ||
| ClientSessionCache: tls.NewLRUClientSessionCache(1024), | ||
| ClientSessionCache: tls.NewLRUClientSessionCache(128), | ||
| } | ||
|
|
||
| if options.SNI != "" { | ||
| tlsConfig.ServerName = options.SNI | ||
| } | ||
|
|
||
| // Add the client certificate authentication to the request if it's configured | ||
| tlsConfig, err = utils.AddConfiguredClientCertToRequest(tlsConfig, options) | ||
| if err != nil { | ||
| return nil, errors.Wrap(err, "could not create client certificate") | ||
| } | ||
|
|
||
| // responseHeaderTimeout is max timeout for response headers to be read | ||
| responseHeaderTimeout := options.GetTimeouts().HttpResponseHeaderTimeout | ||
| if configuration.ResponseHeaderTimeout != 0 { | ||
| responseHeaderTimeout = configuration.ResponseHeaderTimeout | ||
|
|
@@ -292,6 +314,7 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl | |
| MaxConnsPerHost: maxConnsPerHost, | ||
| TLSClientConfig: tlsConfig, | ||
| DisableKeepAlives: disableKeepAlives, | ||
| IdleConnTimeout: 30 * time.Second, | ||
| ResponseHeaderTimeout: responseHeaderTimeout, | ||
| } | ||
|
|
||
|
|
@@ -333,6 +356,11 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl | |
| } | ||
| } | ||
|
|
||
| // Each per-host client gets its own default cookie jar. This is safe | ||
| // because cookies are domain-scoped per RFC 6265, and same-host iterations | ||
| // (workflows, multi-step templates) hit the same cached client so cookies | ||
| // are retained across requests. Explicit jars from input.CookieJar bypass | ||
| // the cache entirely for full isolation. | ||
| var jar *cookiejar.Jar | ||
| if configuration.Connection != nil && configuration.Connection.HasCookieJar() { | ||
| jar = configuration.Connection.GetCookieJar() | ||
|
|
@@ -343,7 +371,7 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl | |
| } | ||
|
|
||
| httpclient := &http.Client{ | ||
| Transport: transport, | ||
| Transport: &connTrackingTransport{base: transport}, | ||
| CheckRedirect: makeCheckRedirectFunc(redirectFlow, maxRedirects), | ||
| } | ||
| if !configuration.NoTimeout { | ||
|
|
@@ -358,8 +386,11 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl | |
| } | ||
| client.CheckRetry = retryablehttp.HostSprayRetryPolicy() | ||
|
|
||
| // Only add to client pool if we don't have a cookie jar in place. | ||
| if jar == nil { | ||
| // Cache the client unless it has an explicit per-request cookie jar. | ||
| // Default jars (from DisableCookie=false) are fine to cache since they | ||
| // just provide standard cookie handling within a host's connection pool. | ||
| hasExplicitJar := configuration.Connection != nil && configuration.Connection.HasCookieJar() | ||
| if !hasExplicitJar { | ||
| if err := dialers.HTTPClientPool.Set(hash, client); err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Connection stats will be cumulative across multiple in-process scans.
GetConnectionStats()reads package-global counters, so a second SDK/embedded run in the same process will log totals from earlier executions too. If this is meant to diagnose one scan, reset or scope these stats byExecutionIdwhen a run starts or ends.🤖 Prompt for AI Agents