diff --git a/plugins/nf-k8s/src/main/nextflow/k8s/K8sTaskHandler.groovy b/plugins/nf-k8s/src/main/nextflow/k8s/K8sTaskHandler.groovy index ae035ee78a..270e20dad6 100644 --- a/plugins/nf-k8s/src/main/nextflow/k8s/K8sTaskHandler.groovy +++ b/plugins/nf-k8s/src/main/nextflow/k8s/K8sTaskHandler.groovy @@ -70,8 +70,6 @@ class K8sTaskHandler extends TaskHandler implements FusionAwareTask { private ResourceType resourceType = ResourceType.Pod - private K8sClient client - private String podName private BashWrapperBuilder builder @@ -93,7 +91,6 @@ class K8sTaskHandler extends TaskHandler implements FusionAwareTask { K8sTaskHandler( TaskRun task, K8sExecutor executor ) { super(task) this.executor = executor - this.client = executor.getClient() this.outputFile = task.workDir.resolve(TaskRun.CMD_OUTFILE) this.errorFile = task.workDir.resolve(TaskRun.CMD_ERRFILE) this.exitFile = task.workDir.resolve(TaskRun.CMD_EXIT) @@ -116,6 +113,8 @@ class K8sTaskHandler extends TaskHandler implements FusionAwareTask { protected K8sConfig getK8sConfig() { executor.getK8sConfig() } + protected K8sClient getClient() { executor.getClient() } + protected boolean useJobResource() { resourceType==ResourceType.Job } protected List getContainerMounts() { diff --git a/plugins/nf-k8s/src/main/nextflow/k8s/client/ClientConfig.groovy b/plugins/nf-k8s/src/main/nextflow/k8s/client/ClientConfig.groovy index aed479e130..c293a57771 100644 --- a/plugins/nf-k8s/src/main/nextflow/k8s/client/ClientConfig.groovy +++ b/plugins/nf-k8s/src/main/nextflow/k8s/client/ClientConfig.groovy @@ -49,6 +49,13 @@ class ClientConfig { String token + /** + * Filesystem path of the token, when the token was loaded from a file. + * Used to re-read the token after expiry — kubelet rotates projected + * service-account tokens in place by overwriting the mounted file. + */ + Path tokenPath + byte[] sslCert byte[] clientCert @@ -108,8 +115,10 @@ class ClientConfig { if( opts.token ) result.token = opts.token - else if( opts.tokenFile ) - result.token = Paths.get(opts.tokenFile.toString()).getText('UTF-8') + else if( opts.tokenFile ) { + result.tokenPath = Paths.get(opts.tokenFile.toString()) + result.token = result.tokenPath.getText('UTF-8') + } result.namespace = namespace ?: opts.namespace ?: 'default' @@ -143,7 +152,8 @@ class ClientConfig { result.token = user.token else if( user.tokenFile ) { - result.token = Paths.get(user.tokenFile.toString()).getText('UTF-8') + result.tokenPath = Paths.get(user.tokenFile.toString()) + result.token = result.tokenPath.getText('UTF-8') } if( user."client-certificate" ) diff --git a/plugins/nf-k8s/src/main/nextflow/k8s/client/ConfigDiscovery.groovy b/plugins/nf-k8s/src/main/nextflow/k8s/client/ConfigDiscovery.groovy index 9f85dfc535..202f311e6d 100644 --- a/plugins/nf-k8s/src/main/nextflow/k8s/client/ConfigDiscovery.groovy +++ b/plugins/nf-k8s/src/main/nextflow/k8s/client/ConfigDiscovery.groovy @@ -84,12 +84,13 @@ class ConfigDiscovery { final server = formatHostName(host, port) final cert = path('/var/run/secrets/kubernetes.io/serviceaccount/ca.crt').bytes - final token = path('/var/run/secrets/kubernetes.io/serviceaccount/token').text + final tokenFile = path('/var/run/secrets/kubernetes.io/serviceaccount/token') final namespace = path('/var/run/secrets/kubernetes.io/serviceaccount/namespace').text return new ClientConfig( server: server, - token: token, + token: tokenFile.text, + tokenPath: tokenFile, namespace: cfgNamespace ?: namespace, serviceAccount: serviceAccount, sslCert: cert, diff --git a/plugins/nf-k8s/src/main/nextflow/k8s/client/K8sClient.groovy b/plugins/nf-k8s/src/main/nextflow/k8s/client/K8sClient.groovy index 48c95929c4..bbce0569b2 100644 --- a/plugins/nf-k8s/src/main/nextflow/k8s/client/K8sClient.groovy +++ b/plugins/nf-k8s/src/main/nextflow/k8s/client/K8sClient.groovy @@ -741,6 +741,9 @@ class K8sClient { @Override void accept(ExecutionAttemptedEvent event) throws Throwable { log.debug("K8s response error - attempt: ${event.attemptCount}; reason: ${event.lastFailure.message}") + final t = event.lastFailure + if( t instanceof K8sResponseException && t.response.code == 401 ) + refreshToken() } } return RetryPolicy.builder() @@ -752,6 +755,25 @@ class K8sClient { .build() } + /** + * Reload the service-account token from {@link ClientConfig#tokenPath} so that + * a request retried after a 401 picks up a token rotated in place by kubelet. + */ + protected void refreshToken() { + if( !config.tokenPath ) + return + try { + final newToken = config.tokenPath.getText('UTF-8') + if( newToken && newToken != config.token ) { + log.debug "[K8s] Refreshing service-account token from ${config.tokenPath}" + config.token = newToken + } + } + catch( Exception e ) { + log.warn "[K8s] Unable to refresh service-account token from ${config.tokenPath} - cause: ${e.message}" + } + } + final private static List RETRY_CODES = List.of(408, 429, 500, 502, 503, 504) /** @@ -767,6 +789,9 @@ class K8sClient { boolean test(Throwable t) { if ( t instanceof K8sResponseException && t.response.code in RETRY_CODES ) return true + // 401 is retried only when the token was loaded from a file and can be re-read from disk + if ( t instanceof K8sResponseException && t.response.code == 401 && config.tokenPath ) + return true if( t instanceof SocketException || t.cause instanceof SocketException ) return true if( t instanceof SocketTimeoutException || t.cause instanceof SocketTimeoutException ) diff --git a/plugins/nf-k8s/src/test/nextflow/k8s/K8sTaskHandlerTest.groovy b/plugins/nf-k8s/src/test/nextflow/k8s/K8sTaskHandlerTest.groovy index e63e36b0f4..76fca3c4dc 100644 --- a/plugins/nf-k8s/src/test/nextflow/k8s/K8sTaskHandlerTest.groovy +++ b/plugins/nf-k8s/src/test/nextflow/k8s/K8sTaskHandlerTest.groovy @@ -62,7 +62,8 @@ class K8sTaskHandlerTest extends Specification { def task = Mock(TaskRun) def client = Mock(K8sClient) def builder = Mock(K8sWrapperBuilder) - def handler = Spy(new K8sTaskHandler(builder:builder, client: client)) + def handler = Spy(new K8sTaskHandler(builder:builder)) + handler.getClient() >> client Map result when: @@ -162,7 +163,8 @@ class K8sTaskHandlerTest extends Specification { def task = Mock(TaskRun) def client = Mock(K8sClient) def builder = Mock(K8sWrapperBuilder) - def handler = Spy(new K8sTaskHandler(builder: builder, client:client)) + def handler = Spy(new K8sTaskHandler(builder: builder)) + handler.getClient() >> client Map result when: @@ -198,7 +200,8 @@ class K8sTaskHandlerTest extends Specification { def client = Mock(K8sClient) def builder = Mock(K8sWrapperBuilder) def config = Mock(ClientConfig) - def handler = Spy(new K8sTaskHandler(builder: builder, client: client)) + def handler = Spy(new K8sTaskHandler(builder: builder)) + handler.getClient() >> client Map result when: @@ -233,7 +236,8 @@ class K8sTaskHandlerTest extends Specification { def client = Mock(K8sClient) def builder = Mock(K8sWrapperBuilder) def config = Mock(TaskConfig) - def handler = Spy(new K8sTaskHandler(builder:builder, client:client)) + def handler = Spy(new K8sTaskHandler(builder:builder)) + handler.getClient() >> client def podOptions = Mock(PodOptions) and: Map result @@ -281,7 +285,8 @@ class K8sTaskHandlerTest extends Specification { def task = Mock(TaskRun) def client = Mock(K8sClient) def builder = Mock(K8sWrapperBuilder) - def handler = Spy(new K8sTaskHandler(builder:builder, client:client)) + def handler = Spy(new K8sTaskHandler(builder:builder)) + handler.getClient() >> client def podOptions = Mock(PodOptions) and: Map result @@ -354,7 +359,8 @@ class K8sTaskHandlerTest extends Specification { def task = Mock(TaskRun) def client = Mock(K8sClient) def builder = Mock(K8sWrapperBuilder) - def handler = Spy(new K8sTaskHandler(client: client, task:task)) + def handler = Spy(new K8sTaskHandler(task:task)) + handler.getClient() >> client def POD_NAME = 'new-pod-id' def REQUEST = [foo: 'bar'] @@ -391,7 +397,8 @@ class K8sTaskHandlerTest extends Specification { def builder = Mock(K8sWrapperBuilder) def config = Mock(TaskConfig) def executor = Mock(K8sExecutor) - def handler = Spy(new K8sTaskHandler(builder: builder, client: client, executor: executor)) + def handler = Spy(new K8sTaskHandler(builder: builder, executor: executor)) + handler.getClient() >> client def podOptions = Mock(PodOptions) and: Map result @@ -440,7 +447,8 @@ class K8sTaskHandlerTest extends Specification { given: def POD_NAME = 'pod-xyz' def client = Mock(K8sClient) - def handler = Spy(new K8sTaskHandler(client: client, podName: POD_NAME, status: TaskStatus.SUBMITTED)) + def handler = Spy(new K8sTaskHandler(podName: POD_NAME, status: TaskStatus.SUBMITTED)) + handler.getClient() >> client when: def result = handler.checkIfRunning() @@ -479,7 +487,8 @@ class K8sTaskHandlerTest extends Specification { finishedAt: "2018-01-13T10:19:36Z" ] def noExitCodeState = [terminated: noExitCodeTermState] and: - def handler = Spy(new K8sTaskHandler(task: task, client:client, podName: POD_NAME, outputFile: OUT_FILE, errorFile: ERR_FILE)) + def handler = Spy(new K8sTaskHandler(task: task, podName: POD_NAME, outputFile: OUT_FILE, errorFile: ERR_FILE)) + handler.getClient() >> client when: def result = handler.checkIfCompleted() @@ -539,7 +548,8 @@ class K8sTaskHandlerTest extends Specification { finishedAt: "2018-01-13T10:19:36Z", exitCode: 137 ] def task = new TaskRun() - def handler = Spy(new K8sTaskHandler(task: task, client:client, podName: POD_NAME, outputFile: OUT_FILE, errorFile: ERR_FILE)) + def handler = Spy(new K8sTaskHandler(task: task, podName: POD_NAME, outputFile: OUT_FILE, errorFile: ERR_FILE)) + handler.getClient() >> client when: def result = handler.checkIfCompleted() @@ -558,7 +568,8 @@ class K8sTaskHandlerTest extends Specification { given: def POD_NAME = 'pod-xyz' def client = Mock(K8sClient) - def handler = Spy(new K8sTaskHandler(client:client, podName: POD_NAME)) + def handler = Spy(new K8sTaskHandler(podName: POD_NAME)) + handler.getClient() >> client when: handler.killTask() @@ -577,7 +588,8 @@ class K8sTaskHandlerTest extends Specification { given: def POD_NAME = 'pod-xyz' def client = Mock(K8sClient) - def handler = Spy(new K8sTaskHandler(client:client, podName: POD_NAME)) + def handler = Spy(new K8sTaskHandler(podName: POD_NAME)) + handler.getClient() >> client and: Map STATE1 = [status:'pending'] Map STATE2 = [status:'running'] @@ -646,7 +658,8 @@ class K8sTaskHandlerTest extends Specification { given: def POD_NAME = 'pod-xyz' def client = Mock(K8sClient) - def handler = Spy(new K8sTaskHandler(client:client, podName: POD_NAME)) + def handler = Spy(new K8sTaskHandler(podName: POD_NAME)) + handler.getClient() >> client when: def state = handler.getState() @@ -664,7 +677,8 @@ class K8sTaskHandlerTest extends Specification { given: def POD_NAME = 'pod-xyz' def client = Mock(K8sClient) - def handler = Spy(new K8sTaskHandler(client:client, podName: POD_NAME)) + def handler = Spy(new K8sTaskHandler(podName: POD_NAME)) + handler.getClient() >> client when: def state = handler.getState() @@ -782,7 +796,8 @@ class K8sTaskHandlerTest extends Specification { def POD_NAME = 'the-pod-name' def executor = Mock(K8sExecutor) def client = Mock(K8sClient) - def handler = Spy(new K8sTaskHandler(podName: POD_NAME, executor:executor, client:client)) + def handler = Spy(new K8sTaskHandler(podName: POD_NAME, executor:executor)) + handler.getClient() >> client handler.useJobResource() >> false and: def TASK_OK = Mock(TaskRun); TASK_OK.isSuccess() >> true @@ -814,7 +829,8 @@ class K8sTaskHandlerTest extends Specification { def POD_NAME = 'the-job-name' def executor = Mock(K8sExecutor) def client = Mock(K8sClient) - def handler = Spy(new K8sTaskHandler(podName: POD_NAME, executor:executor, client:client)) + def handler = Spy(new K8sTaskHandler(podName: POD_NAME, executor:executor)) + handler.getClient() >> client handler.useJobResource() >> true and: def TASK_OK = Mock(TaskRun); TASK_OK.isSuccess() >> true @@ -846,7 +862,8 @@ class K8sTaskHandlerTest extends Specification { def executor = Mock(K8sExecutor) def client = Mock(K8sClient) and: - def handler = Spy(new K8sTaskHandler(executor: executor, client: client, podName: POD_NAME)) + def handler = Spy(new K8sTaskHandler(executor: executor, podName: POD_NAME)) + handler.getClient() >> client when: handler.saveJobLogOnError(task) @@ -976,7 +993,8 @@ class K8sTaskHandlerTest extends Specification { def client = Mock(K8sClient) def builder = Mock(K8sWrapperBuilder) def launcher = Mock(FusionScriptLauncher) - def handler = Spy(new K8sTaskHandler(builder:builder, client: client)) + def handler = Spy(new K8sTaskHandler(builder:builder)) + handler.getClient() >> client Map result when: @@ -1022,7 +1040,8 @@ class K8sTaskHandlerTest extends Specification { def launcher = Mock(FusionScriptLauncher) def k8sConfig = Spy(K8sConfig) def exec = Mock(K8sExecutor) { getK8sConfig()>>k8sConfig } - def handler = Spy(new K8sTaskHandler(builder:builder, client: client, executor: exec)) + def handler = Spy(new K8sTaskHandler(builder:builder, executor: exec)) + handler.getClient() >> client Map result when: diff --git a/plugins/nf-k8s/src/test/nextflow/k8s/client/ClientConfigTest.groovy b/plugins/nf-k8s/src/test/nextflow/k8s/client/ClientConfigTest.groovy index f2305a5564..a2ef03d2b7 100644 --- a/plugins/nf-k8s/src/test/nextflow/k8s/client/ClientConfigTest.groovy +++ b/plugins/nf-k8s/src/test/nextflow/k8s/client/ClientConfigTest.groovy @@ -117,4 +117,60 @@ class ClientConfigTest extends Specification { folder?.deleteDir() } + def 'should preserve token file path when reading token from tokenFile in nextflow config' () { + + given: + def folder = Files.createTempDirectory('test') + def tokenFile = folder.resolve('token') + tokenFile.text = 'file-token' + + def MAP = [ + server: 'foo.com', + tokenFile: tokenFile ] + + when: + def result = ClientConfig.fromNextflowConfig(MAP, null, null) + + then: + result.token == 'file-token' + result.tokenPath == tokenFile + + cleanup: + folder?.deleteDir() + } + + def 'should preserve token file path when reading token from tokenFile in kubeconfig' () { + + given: + def folder = Files.createTempDirectory('test') + def tokenFile = folder.resolve('token') + tokenFile.text = 'file-token' + + def user = [ tokenFile: tokenFile.toString() ] + def cluster = [ server: 'https://foo:6443' ] + + when: + def result = ClientConfig.fromUserAndCluster(user, cluster, folder) + + then: + result.token == 'file-token' + result.tokenPath == tokenFile + + cleanup: + folder?.deleteDir() + } + + def 'should not set token path when token is provided inline' () { + + given: + def MAP = [ server: 'foo.com', token: 'inline-token' ] + + when: + def result = ClientConfig.fromNextflowConfig(MAP, null, null) + + then: + result.token == 'inline-token' + result.tokenPath == null + } + } diff --git a/plugins/nf-k8s/src/test/nextflow/k8s/client/ConfigDiscoveryTest.groovy b/plugins/nf-k8s/src/test/nextflow/k8s/client/ConfigDiscoveryTest.groovy index 38ac946839..774ca210b0 100644 --- a/plugins/nf-k8s/src/test/nextflow/k8s/client/ConfigDiscoveryTest.groovy +++ b/plugins/nf-k8s/src/test/nextflow/k8s/client/ConfigDiscoveryTest.groovy @@ -387,6 +387,7 @@ class ConfigDiscoveryTest extends Specification { config.server == 'foo.com:4343' config.namespace == 'foo-namespace' config.token == 'my-token' + config.tokenPath == TOKEN_FILE config.sslCert == CERT_FILE.text.bytes config.isFromCluster diff --git a/plugins/nf-k8s/src/test/nextflow/k8s/client/K8sClientTest.groovy b/plugins/nf-k8s/src/test/nextflow/k8s/client/K8sClientTest.groovy index 263a6e3cad..ea0458785f 100644 --- a/plugins/nf-k8s/src/test/nextflow/k8s/client/K8sClientTest.groovy +++ b/plugins/nf-k8s/src/test/nextflow/k8s/client/K8sClientTest.groovy @@ -16,6 +16,8 @@ package nextflow.k8s.client +import java.nio.file.Files + import nextflow.exception.K8sOutOfCpuException import nextflow.exception.K8sOutOfMemoryException @@ -1102,4 +1104,78 @@ class K8sClientTest extends Specification { result.terminated.exitCode == null result.terminated.exitcode == null } + + def 'should re-read token from disk and retry on 401 when tokenPath is set' () { + + given: + def folder = Files.createTempDirectory('test') + def tokenFile = folder.resolve('token') + // file already holds the rotated token (kubelet has written it) + tokenFile.text = 'fresh-token' + + final client = Spy(K8sClient) + client.config.server = 'host.com:443' + client.config.token = 'stale-token' + client.config.tokenPath = tokenFile + // shorten retry delay so the test is fast + client.config.retryConfig.delay = nextflow.util.Duration.of('1ms') + + def CONN_401 = Mock(HttpsURLConnection) + def CONN_200 = Mock(HttpsURLConnection) + + when: + def resp = client.makeRequest('GET', '/api/v1/pods/foo/status') + + then: + // first attempt: stale token, 401 + 2 * client.createConnection0("https://host.com:443/api/v1/pods/foo/status") >>> [CONN_401, CONN_200] + 2 * client.setupHttpsConn(_) >> null + 1 * CONN_401.setRequestProperty("Authorization", "Bearer stale-token") + 1 * CONN_401.setRequestMethod('GET') + 1 * CONN_401.setRequestProperty("Content-Type", "application/json") + 1 * CONN_401.getResponseCode() >> 401 + 1 * CONN_401.getErrorStream() >> { new ByteArrayInputStream('{"kind":"Status","status":"Failure","message":"Unauthorized","code":401}'.bytes) } + + and: + // second attempt: fresh token from disk, 200 + 1 * CONN_200.setRequestProperty("Authorization", "Bearer fresh-token") + 1 * CONN_200.setRequestMethod('GET') + 1 * CONN_200.setRequestProperty("Content-Type", "application/json") + 1 * CONN_200.getResponseCode() >> 200 + 1 * CONN_200.getInputStream() >> { new ByteArrayInputStream('{"ok":true}'.bytes) } + + and: + resp.text == '{"ok":true}' + client.config.token == 'fresh-token' + + cleanup: + folder?.deleteDir() + } + + def 'should not retry 401 when tokenPath is not set' () { + + given: + final client = Spy(K8sClient) + client.config.server = 'host.com:443' + client.config.token = 'inline-token' + // tokenPath is null — token can't be re-read from disk + + def CONN = Mock(HttpsURLConnection) + + when: + client.makeRequest('GET', '/api/v1/pods/foo/status') + + then: + // exactly one attempt — no retry because there's no file to re-read + 1 * client.createConnection0(_) >> CONN + 1 * client.setupHttpsConn(_) >> null + 1 * CONN.setRequestMethod('GET') + 1 * CONN.setRequestProperty("Authorization", "Bearer inline-token") + 1 * CONN.setRequestProperty("Content-Type", "application/json") + 1 * CONN.getResponseCode() >> 401 + 1 * CONN.getErrorStream() >> { new ByteArrayInputStream('{"kind":"Status","status":"Failure","message":"Unauthorized","code":401}'.bytes) } + + and: + thrown(K8sResponseException) + } }