[GitHub] kaxil closed pull request #3532: [AIRFLOW-2658] Add GCP specific k8s pod operator


kaxil closed pull request #3532: [AIRFLOW-2658] Add GCP specific k8s pod operator
URL: https://github.com/apache/incubator-airflow/pull/3532
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/airflow/contrib/operators/gcp_container_operator.py b/airflow/contrib/operators/gcp_container_operator.py
index 5648b4d8a0..615eac8a0f 100644
--- a/airflow/contrib/operators/gcp_container_operator.py
+++ b/airflow/contrib/operators/gcp_container_operator.py
@@ -17,8 +17,13 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+import os
+import subprocess
+import tempfile
+
 from airflow import AirflowException
 from airflow.contrib.hooks.gcp_container_hook import GKEClusterHook
+from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator
 from airflow.models import BaseOperator
 from airflow.utils.decorators import apply_defaults
 
@@ -170,3 +175,147 @@ def execute(self, context):
         hook = GKEClusterHook(self.project_id, self.location)
         create_op = hook.create_cluster(cluster=self.body)
         return create_op
+
+
+KUBE_CONFIG_ENV_VAR = "KUBECONFIG"
+G_APP_CRED = "GOOGLE_APPLICATION_CREDENTIALS"
+
+
+class GKEPodOperator(KubernetesPodOperator):
+    template_fields = ('project_id', 'location',
+                       'cluster_name') + KubernetesPodOperator.template_fields
+
+    @apply_defaults
+    def __init__(self,
+                 project_id,
+                 location,
+                 cluster_name,
+                 gcp_conn_id='google_cloud_default',
+                 *args,
+                 **kwargs):
+        """
+        Executes a task in a Kubernetes pod in the specified Google Kubernetes
+        Engine cluster.
+
+        This Operator assumes that the system has gcloud installed and either
+        has working default application credentials or has configured a
+        connection id with a service account.
+
+        The **minimum** required parameters to define this operator are
+        ``task_id``, ``project_id``, ``location``, ``cluster_name``, ``name``,
+        ``namespace``, and ``image``.
+
+        **Operator Creation**: ::
+
+            operator = GKEPodOperator(task_id='pod_op',
+                                      project_id='my-project',
+                                      location='us-central1-a',
+                                      cluster_name='my-cluster-name',
+                                      name='task-name',
+                                      namespace='default',
+                                      image='perl')
+
+        .. seealso::
+            For more detail about application authentication have a look at the reference:
+            https://cloud.google.com/docs/authentication/production#providing_credentials_to_your_application
+
+        :param project_id: The Google Developers Console project id
+        :type project_id: str
+        :param location: The name of the Google Kubernetes Engine zone in which the
+            cluster resides, e.g. 'us-central1-a'
+        :type location: str
+        :param cluster_name: The name of the Google Kubernetes Engine cluster the pod
+            should be spawned in
+        :type cluster_name: str
+        :param gcp_conn_id: The Google Cloud connection id to use. This allows
+            users to specify a service account.
+        :type gcp_conn_id: str
+        """
+        super(GKEPodOperator, self).__init__(*args, **kwargs)
+        self.project_id = project_id
+        self.location = location
+        self.cluster_name = cluster_name
+        self.gcp_conn_id = gcp_conn_id
+
+    def execute(self, context):
+        # Specifying a service account file allows the user to use non-default
+        # authentication for creating a Kubernetes Pod. This is done by setting the
+        # environment variable `GOOGLE_APPLICATION_CREDENTIALS` that gcloud looks at.
+        key_file = None
+
+        # If gcp_conn_id is not specified, gcloud will use the default
+        # service account credentials.
+        if self.gcp_conn_id:
+            from airflow.hooks.base_hook import BaseHook
+            # extras is a deserialized json object
+            extras = BaseHook.get_connection(self.gcp_conn_id).extra_dejson
+            # key_file only gets set if a temporary file is created from JSON
+            # specified in the connection's keyfile_dict; otherwise it is None
+            key_file = self._set_env_from_extras(extras=extras)
+
+        # Write config to a temp file and set the environment variable to point to it.
+        # This is to avoid race conditions of reading/writing a single file
+        with tempfile.NamedTemporaryFile() as conf_file:
+            os.environ[KUBE_CONFIG_ENV_VAR] = conf_file.name
+            # Attempt to get/update credentials
+            # We call gcloud directly instead of using google-cloud-python api
+            # because there is no way to write kubernetes config to a file, which is
+            # required by KubernetesPodOperator.
+            # The gcloud command looks at the env variable `KUBECONFIG` for where to save
+            # the kubernetes config file.
+            subprocess.check_call(
+                ["gcloud", "container", "clusters", "get-credentials",
+                 self.cluster_name,
+                 "--zone", self.location,
+                 "--project", self.project_id])
+
+            # Close the temporary key file once the credentials have been
+            # fetched; it is no longer needed after this point
+            if key_file:
+                key_file.close()
+
+            # Tell `KubernetesPodOperator` where the config file is located
+            self.config_file = os.environ[KUBE_CONFIG_ENV_VAR]
+            super(GKEPodOperator, self).execute(context)
+
+    def _set_env_from_extras(self, extras):
+        """
+        Sets the environment variable `GOOGLE_APPLICATION_CREDENTIALS` with either:
+
+        - The path to the keyfile from the specified connection id
+        - A generated file's path if the user specified JSON in the connection id. The
+            file object is returned so the caller can close it once it is no longer
+            needed.
+
+        The environment variable is used inside the gcloud command to determine the
+        correct service account to use.
+        """
+        key_path = self._get_field(extras, 'key_path', False)
+        keyfile_json_str = self._get_field(extras, 'keyfile_dict', False)
+
+        if not key_path and not keyfile_json_str:
+            self.log.info('Using gcloud with application default credentials.')
+        elif key_path:
+            os.environ[G_APP_CRED] = key_path
+        else:
+            # Write the service account JSON to a secure file for gcloud to reference
+            service_key = tempfile.NamedTemporaryFile(delete=False)
+            service_key.write(keyfile_json_str)
+            os.environ[G_APP_CRED] = service_key.name
+            # Return the file object so the caller can close it once it is
+            # no longer needed
+            return service_key
+
+    def _get_field(self, extras, field, default=None):
+        """
+        Fetches a field from extras, and returns it. This is some Airflow
+        magic. The google_cloud_platform hook type adds custom UI elements
+        to the hook page, which allow admins to specify service_account,
+        key_path, etc. They get formatted as shown below.
+        """
+        long_f = 'extra__google_cloud_platform__{}'.format(field)
+        if long_f in extras:
+            return extras[long_f]
+        else:
+            self.log.info('Field {} not found in extras.'.format(field))
+            return default
diff --git a/docs/code.rst b/docs/code.rst
index 4f1b301711..716a13f334 100644
--- a/docs/code.rst
+++ b/docs/code.rst
@@ -153,6 +153,7 @@ Operators
 .. autoclass:: airflow.contrib.operators.file_to_wasb.FileToWasbOperator
 .. autoclass:: airflow.contrib.operators.gcp_container_operator.GKEClusterCreateOperator
 .. autoclass:: airflow.contrib.operators.gcp_container_operator.GKEClusterDeleteOperator
+.. autoclass:: airflow.contrib.operators.gcp_container_operator.GKEPodOperator
 .. autoclass:: airflow.contrib.operators.gcs_download_operator.GoogleCloudStorageDownloadOperator
 .. autoclass:: airflow.contrib.operators.gcs_list_operator.GoogleCloudStorageListOperator
 .. autoclass:: airflow.contrib.operators.gcs_operator.GoogleCloudStorageCreateBucketOperator
diff --git a/docs/integration.rst b/docs/integration.rst
index 99dbafbd2b..4c513bf26d 100644
--- a/docs/integration.rst
+++ b/docs/integration.rst
@@ -776,6 +776,12 @@ GKEClusterDeleteOperator
 .. autoclass:: airflow.contrib.operators.gcp_container_operator.GKEClusterDeleteOperator
 .. _GKEClusterDeleteOperator:
 
+GKEPodOperator
+^^^^^^^^^^^^^^
+
+.. autoclass:: airflow.contrib.operators.gcp_container_operator.GKEPodOperator
+.. _GKEPodOperator:
+
 Google Kubernetes Engine Hook
 """""""""""""""""""""""""""""
 
diff --git a/tests/contrib/operators/test_gcp_container_operator.py b/tests/contrib/operators/test_gcp_container_operator.py
index 0f67290cbf..1685e9c6ac 100644
--- a/tests/contrib/operators/test_gcp_container_operator.py
+++ b/tests/contrib/operators/test_gcp_container_operator.py
@@ -17,11 +17,13 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import os
 import unittest
 
 from airflow import AirflowException
 from airflow.contrib.operators.gcp_container_operator import GKEClusterCreateOperator, \
-    GKEClusterDeleteOperator
+    GKEClusterDeleteOperator, GKEPodOperator
+from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator
 
 try:
     from unittest import mock
@@ -39,6 +41,15 @@
 PROJECT_BODY = {'name': 'test-name'}
 PROJECT_BODY_CREATE = {'name': 'test-name', 'initial_node_count': 1}
 
+TASK_NAME = 'test-task-name'
+NAMESPACE = 'default'
+IMAGE = 'bash'
+
+GCLOUD_COMMAND = "gcloud container clusters get-credentials {} --zone {} --project {}"
+KUBE_ENV_VAR = 'KUBECONFIG'
+GAC_ENV_VAR = 'GOOGLE_APPLICATION_CREDENTIALS'
+FILE_NAME = '/tmp/mock_name'
+
 
 class GoogleCloudPlatformContainerOperatorTest(unittest.TestCase):
 
@@ -123,3 +134,162 @@ def test_delete_execute_error_location(self, mock_hook):
 
             operator.execute(None)
             mock_hook.return_value.delete_cluster.assert_not_called()
+
+
+class GKEPodOperatorTest(unittest.TestCase):
+    def setUp(self):
+        self.gke_op = GKEPodOperator(project_id=PROJECT_ID,
+                                     location=PROJECT_LOCATION,
+                                     cluster_name=CLUSTER_NAME,
+                                     task_id=PROJECT_TASK_ID,
+                                     name=TASK_NAME,
+                                     namespace=NAMESPACE,
+                                     image=IMAGE)
+
+    def test_template_fields(self):
+        self.assertTrue(set(KubernetesPodOperator.template_fields).issubset(
+            GKEPodOperator.template_fields))
+
+    @mock.patch(
+        'airflow.contrib.operators.kubernetes_pod_operator.KubernetesPodOperator.execute')
+    @mock.patch('tempfile.NamedTemporaryFile')
+    @mock.patch("subprocess.check_call")
+    def test_execute_conn_id_none(self, proc_mock, file_mock, exec_mock):
+        self.gke_op.gcp_conn_id = None
+
+        file_mock.return_value.__enter__.return_value.name = FILE_NAME
+
+        self.gke_op.execute(None)
+
+        # Assert Environment Variable is being set correctly
+        self.assertIn(KUBE_ENV_VAR, os.environ)
+        self.assertEqual(os.environ[KUBE_ENV_VAR], FILE_NAME)
+
+        # Assert the gcloud command being called correctly
+        proc_mock.assert_called_with(
+            GCLOUD_COMMAND.format(CLUSTER_NAME, PROJECT_LOCATION, PROJECT_ID).split())
+
+        self.assertEqual(self.gke_op.config_file, FILE_NAME)
+
+    @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
+    @mock.patch(
+        'airflow.contrib.operators.kubernetes_pod_operator.KubernetesPodOperator.execute')
+    @mock.patch('tempfile.NamedTemporaryFile')
+    @mock.patch("subprocess.check_call")
+    @mock.patch.dict(os.environ, {})
+    def test_execute_conn_id_path(self, proc_mock, file_mock, exec_mock, get_con_mock):
+        # gcp_conn_id is defaulted to `google_cloud_default`
+
+        FILE_PATH = '/path/to/file'
+        KEYFILE_DICT = {"extra__google_cloud_platform__key_path": FILE_PATH}
+        get_con_mock.return_value.extra_dejson = KEYFILE_DICT
+        file_mock.return_value.__enter__.return_value.name = FILE_NAME
+
+        self.gke_op.execute(None)
+
+        # Assert Environment Variable is being set correctly
+        self.assertIn(KUBE_ENV_VAR, os.environ)
+        self.assertEqual(os.environ[KUBE_ENV_VAR], FILE_NAME)
+
+        self.assertIn(GAC_ENV_VAR, os.environ)
+        # since we passed in key_path we should get that path back
+        self.assertEqual(os.environ[GAC_ENV_VAR], FILE_PATH)
+
+        # Assert the gcloud command being called correctly
+        proc_mock.assert_called_with(
+            GCLOUD_COMMAND.format(CLUSTER_NAME, PROJECT_LOCATION, PROJECT_ID).split())
+
+        self.assertEqual(self.gke_op.config_file, FILE_NAME)
+
+    @mock.patch.dict(os.environ, {})
+    @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
+    @mock.patch(
+        'airflow.contrib.operators.kubernetes_pod_operator.KubernetesPodOperator.execute')
+    @mock.patch('tempfile.NamedTemporaryFile')
+    @mock.patch("subprocess.check_call")
+    def test_execute_conn_id_dict(self, proc_mock, file_mock, exec_mock, get_con_mock):
+        # gcp_conn_id is defaulted to `google_cloud_default`
+        FILE_PATH = '/path/to/file'
+
+        # This is used in the _set_env_from_extras method
+        file_mock.return_value.name = FILE_PATH
+        # This is used in the execute method
+        file_mock.return_value.__enter__.return_value.name = FILE_NAME
+
+        KEYFILE_DICT = {"extra__google_cloud_platform__keyfile_dict":
+                        '{"private_key": "r4nd0m_k3y"}'}
+        get_con_mock.return_value.extra_dejson = KEYFILE_DICT
+
+        self.gke_op.execute(None)
+
+        # Assert Environment Variable is being set correctly
+        self.assertIn(KUBE_ENV_VAR, os.environ)
+        self.assertEqual(os.environ[KUBE_ENV_VAR], FILE_NAME)
+
+        self.assertIn(GAC_ENV_VAR, os.environ)
+        # since we passed in keyfile_dict we should get a generated file's path
+        self.assertEqual(os.environ[GAC_ENV_VAR], FILE_PATH)
+
+        # Assert the gcloud command being called correctly
+        proc_mock.assert_called_with(
+            GCLOUD_COMMAND.format(CLUSTER_NAME, PROJECT_LOCATION, PROJECT_ID).split())
+
+        self.assertEqual(self.gke_op.config_file, FILE_NAME)
+
+    @mock.patch.dict(os.environ, {})
+    def test_set_env_from_extras_none(self):
+        extras = {}
+        self.gke_op._set_env_from_extras(extras)
+        # _set_env_from_extras should not edit os.environ if extras does not specify
+        self.assertNotIn(GAC_ENV_VAR, os.environ)
+
+    @mock.patch.dict(os.environ, {})
+    @mock.patch('tempfile.NamedTemporaryFile')
+    def test_set_env_from_extras_dict(self, file_mock):
+        file_mock.return_value.name = FILE_NAME
+
+        KEYFILE_DICT_STR = '{ \"test\": \"cluster\" }'
+        extras = {
+            'extra__google_cloud_platform__keyfile_dict': KEYFILE_DICT_STR,
+        }
+
+        self.gke_op._set_env_from_extras(extras)
+        self.assertEqual(os.environ[GAC_ENV_VAR], FILE_NAME)
+
+        file_mock.return_value.write.assert_called_once_with(KEYFILE_DICT_STR)
+
+    @mock.patch.dict(os.environ, {})
+    def test_set_env_from_extras_path(self):
+        TEST_PATH = '/test/path'
+
+        extras = {
+            'extra__google_cloud_platform__key_path': TEST_PATH,
+        }
+
+        self.gke_op._set_env_from_extras(extras)
+        self.assertEqual(os.environ[GAC_ENV_VAR], TEST_PATH)
+
+    def test_get_field(self):
+        FIELD_NAME = 'test_field'
+        FIELD_VALUE = 'test_field_value'
+        extras = {
+            'extra__google_cloud_platform__{}'.format(FIELD_NAME):
+                FIELD_VALUE
+        }
+
+        ret_val = self.gke_op._get_field(extras, FIELD_NAME)
+        self.assertEqual(FIELD_VALUE, ret_val)
+
+    @mock.patch('airflow.contrib.operators.gcp_container_operator.GKEPodOperator.log')
+    def test_get_field_fail(self, log_mock):
+        log_mock.info = mock.Mock()
+        LOG_STR = 'Field {} not found in extras.'
+        FIELD_NAME = 'test_field'
+        FIELD_VALUE = 'test_field_value'
+
+        extras = {}
+
+        ret_val = self.gke_op._get_field(extras, FIELD_NAME, default=FIELD_VALUE)
+        # Assert default is returned upon failure
+        self.assertEqual(FIELD_VALUE, ret_val)
+        log_mock.info.assert_called_with(LOG_STR.format(FIELD_NAME))
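
For readers wanting to try the new operator, here is a minimal usage sketch of
how GKEPodOperator could be wired into a DAG, based on the "Operator Creation"
example in the docstring above. The DAG id, start date, and schedule are
illustrative assumptions and not part of this pull request:

# Minimal usage sketch. Assumptions: dag_id, start_date and schedule_interval
# are illustrative; the operator arguments mirror the docstring example above.
from datetime import datetime

from airflow import DAG
from airflow.contrib.operators.gcp_container_operator import GKEPodOperator

dag = DAG(dag_id='example_gke_pod_operator',   # hypothetical DAG id
          start_date=datetime(2018, 1, 1),
          schedule_interval=None)

pod_task = GKEPodOperator(task_id='pod_op',
                          project_id='my-project',
                          location='us-central1-a',
                          cluster_name='my-cluster-name',
                          name='task-name',
                          namespace='default',
                          image='perl',
                          # gcp_conn_id defaults to 'google_cloud_default'; point it
                          # at a connection whose extras define key_path or
                          # keyfile_dict to run as a specific service account
                          gcp_conn_id='google_cloud_default',
                          dag=dag)

At runtime the operator shells out to `gcloud container clusters get-credentials`,
writes the resulting kubeconfig to a temporary file referenced by the KUBECONFIG
environment variable, and then delegates pod execution to KubernetesPodOperator.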


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@xxxxxxxxxxxxxxxx


With regards,
Apache Git Services