# Source code for airflow.providers.google.cloud.hooks.compute_ssh
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import random
import shlex
import time
from functools import cached_property
from io import StringIO
from typing import Any

from googleapiclient.errors import HttpError
from paramiko.ssh_exception import SSHException

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook
from airflow.providers.google.cloud.hooks.os_login import OSLoginHook
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.utils.types import NOTSET, ArgNotSet

# Paramiko should be imported after airflow.providers.ssh. Then the import will fail with
# cannot import "airflow.providers.ssh" and will be correctly discovered as optional feature
# TODO:(potiuk) We should add test harness detecting such cases shortly
import paramiko  # isort:skip

# Default command timeout (seconds) applied when neither the hook argument nor the
# connection extras provide `cmd_timeout`. Referenced by `_load_connection_config`;
# its definition was missing from this extraction and is restored here.
CMD_TIMEOUT = 10
class _GCloudAuthorizedSSHClient(paramiko.SSHClient):
    """SSH Client that maintains the context for gcloud authorization during the connection."""

    def __init__(self, google_hook, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): this extra SSHClient instance appears unused anywhere in the
        # class (the object itself already IS an SSHClient) — looks like dead code; confirm.
        self.ssh_client = paramiko.SSHClient()
        self.google_hook = google_hook
        # Holds the active `provide_authorized_gcloud()` context manager between
        # connect() and close()/__exit__(), so gcloud stays authorized for the
        # whole lifetime of the SSH connection.
        self.decorator = None

    def connect(self, *args, **kwargs):
        # Enter the gcloud-authorization context BEFORE opening the SSH connection;
        # it is exited again in close() / __exit__().
        self.decorator = self.google_hook.provide_authorized_gcloud()
        self.decorator.__enter__()
        return super().connect(*args, **kwargs)

    def close(self):
        # Leave the gcloud-authorization context exactly once (guarded by the
        # None reset, since both close() and __exit__() may run).
        if self.decorator:
            self.decorator.__exit__(None, None, None)
        self.decorator = None
        return super().close()

    def __exit__(self, type_, value, traceback):
        # Propagate the real exception info into the authorization context on
        # `with`-block exit, then reset so close() does not exit it a second time.
        if self.decorator:
            self.decorator.__exit__(type_, value, traceback)
        self.decorator = None
        return super().__exit__(type_, value, traceback)
class ComputeEngineSSHHook(SSHHook):
    """
    Hook to connect to a remote instance in compute engine.

    :param instance_name: The name of the Compute Engine instance
    :param zone: The zone of the Compute Engine instance
    :param user: The name of the user on which the login attempt will be made
    :param project_id: The project ID of the remote instance
    :param gcp_conn_id: The connection id to use when fetching connection info
    :param hostname: The hostname of the target instance. If it is not passed, it will be detected
        automatically.
    :param use_iap_tunnel: Whether to connect through IAP tunnel
    :param use_internal_ip: Whether to connect using internal IP
    :param use_oslogin: Whether to manage keys using OsLogin API. If false,
        keys are managed using instance metadata
    :param expire_time: The maximum amount of time in seconds before the private key expires
    :param max_retries: Maximum number of retries the process will try to establish connection to instance.
        Could be decreased/increased by user based on the amount of parallel SSH connections to the instance.
    :param impersonation_chain: Optional. The service account email to impersonate using short-term
        credentials. The provided service account must grant the originating account
        the Service Account Token Creator IAM role and have the sufficient rights to perform the request
    """
def__init__(self,gcp_conn_id:str="google_cloud_default",instance_name:str|None=None,zone:str|None=None,user:str|None="root",project_id:str=PROVIDE_PROJECT_ID,hostname:str|None=None,use_internal_ip:bool=False,use_iap_tunnel:bool=False,use_oslogin:bool=True,expire_time:int=300,cmd_timeout:int|ArgNotSet=NOTSET,max_retries:int=10,impersonation_chain:str|None=None,**kwargs,)->None:ifkwargs.get("delegate_to")isnotNone:raiseRuntimeError("The `delegate_to` parameter has been deprecated before and finally removed in this version"" of Google Provider. You MUST convert it to `impersonation_chain`")# Ignore original constructor# super().__init__()self.gcp_conn_id=gcp_conn_idself.instance_name=instance_nameself.zone=zoneself.user=userself.project_id=project_idself.hostname=hostnameself.use_internal_ip=use_internal_ipself.use_iap_tunnel=use_iap_tunnelself.use_oslogin=use_osloginself.expire_time=expire_timeself.cmd_timeout=cmd_timeoutself.max_retries=max_retriesself.impersonation_chain=impersonation_chainself._conn:Any|None=None@cached_propertydef_oslogin_hook(self)->OSLoginHook:returnOSLoginHook(gcp_conn_id=self.gcp_conn_id)@cached_propertydef_compute_hook(self)->ComputeEngineHook:ifself.impersonation_chain:returnComputeEngineHook(gcp_conn_id=self.gcp_conn_id,impersonation_chain=self.impersonation_chain)else:returnComputeEngineHook(gcp_conn_id=self.gcp_conn_id)def_load_connection_config(self):def_boolify(value):ifisinstance(value,bool):returnvalueifisinstance(value,str):ifvalue.lower()=="false":returnFalseelifvalue.lower()=="true":returnTruereturnFalsedefintify(key,value,default):ifvalueisNone:returndefaultifisinstance(value,str)andvalue.strip()=="":returndefaulttry:returnint(value)exceptValueError:raiseAirflowException(f"The {key} field should be a integer. "f'Current value: "{value}" (type: {type(value)}). 
'f"Please check the connection configuration.")conn=self.get_connection(self.gcp_conn_id)ifconnandconn.conn_type=="gcpssh":self.instance_name=self._compute_hook._get_field("instance_name",self.instance_name)self.zone=self._compute_hook._get_field("zone",self.zone)self.user=conn.loginifconn.loginelseself.user# self.project_id is skipped intentionallyself.hostname=conn.hostifconn.hostelseself.hostnameself.use_internal_ip=_boolify(self._compute_hook._get_field("use_internal_ip"))self.use_iap_tunnel=_boolify(self._compute_hook._get_field("use_iap_tunnel"))self.use_oslogin=_boolify(self._compute_hook._get_field("use_oslogin"))self.expire_time=intify("expire_time",self._compute_hook._get_field("expire_time"),self.expire_time,)ifconn.extraisnotNone:extra_options=conn.extra_dejsonif"cmd_timeout"inextra_optionsandself.cmd_timeoutisNOTSET:ifextra_options["cmd_timeout"]:self.cmd_timeout=int(extra_options["cmd_timeout"])else:self.cmd_timeout=Noneifself.cmd_timeoutisNOTSET:self.cmd_timeout=CMD_TIMEOUT
[docs]defget_conn(self)->paramiko.SSHClient:"""Return SSH connection."""self._load_connection_config()ifnotself.project_id:self.project_id=self._compute_hook.project_idmissing_fields=[kforkin["instance_name","zone","project_id"]ifnotgetattr(self,k)]ifnotself.instance_nameornotself.zoneornotself.project_id:raiseAirflowException(f"Required parameters are missing: {missing_fields}. These parameters be passed either as ""keyword parameter or as extra field in Airflow connection definition. Both are not set!")self.log.info("Connecting to instance: instance_name=%s, user=%s, zone=%s, ""use_internal_ip=%s, use_iap_tunnel=%s, use_os_login=%s",self.instance_name,self.user,self.zone,self.use_internal_ip,self.use_iap_tunnel,self.use_oslogin,)ifnotself.hostname:hostname=self._compute_hook.get_instance_address(zone=self.zone,resource_id=self.instance_name,project_id=self.project_id,use_internal_ip=self.use_internal_iporself.use_iap_tunnel,)else:hostname=self.hostnameprivkey,pubkey=self._generate_ssh_key(self.user)max_delay=10sshclient=Noneforretryinrange(self.max_retries+1):try:ifself.use_oslogin:user=self._authorize_os_login(pubkey)else:user=self.userself._authorize_compute_engine_instance_metadata(pubkey)proxy_command=Noneifself.use_iap_tunnel:proxy_command_args=["gcloud","compute","start-iap-tunnel",str(self.instance_name),"22","--listen-on-stdin",f"--project={self.project_id}",f"--zone={self.zone}","--verbosity=warning",]ifself.impersonation_chain:proxy_command_args.append(f"--impersonate-service-account={self.impersonation_chain}")proxy_command=" ".join(shlex.quote(arg)forarginproxy_command_args)sshclient=self._connect_to_instance(user,hostname,privkey,proxy_command)breakexcept(HttpError,AirflowException,SSHException)asexc:if(isinstance(exc,HttpError)andexc.resp.status==412)or(isinstance(exc,AirflowException)and"412 PRECONDITION FAILED"instr(exc)):self.log.info("Error occurred when trying to update instance metadata: 
%s",exc)elifisinstance(exc,SSHException):self.log.info("Error occurred when establishing SSH connection using Paramiko: %s",exc)else:raiseifretry==self.max_retries:raiseAirflowException("Maximum retries exceeded. Aborting operation.")delay=random.randint(0,max_delay)self.log.info("Failed establish SSH connection, waiting %s seconds to retry...",delay)time.sleep(delay)ifnotsshclient:raiseAirflowException("Unable to establish SSH connection.")returnsshclient
def_connect_to_instance(self,user,hostname,pkey,proxy_command)->paramiko.SSHClient:self.log.info("Opening remote connection to host: username=%s, hostname=%s",user,hostname)max_time_to_wait=5fortime_to_waitinrange(max_time_to_wait+1):try:client=_GCloudAuthorizedSSHClient(self._compute_hook)# Default is RejectPolicy# No known host checking since we are not storing privatekeyclient.set_missing_host_key_policy(paramiko.AutoAddPolicy())# nosec B507client.connect(hostname=hostname,username=user,pkey=pkey,sock=paramiko.ProxyCommand(proxy_command)ifproxy_commandelseNone,look_for_keys=False,)returnclientexceptparamiko.SSHException:iftime_to_wait==max_time_to_wait:raiseself.log.info("Failed to connect. Waiting %ds to retry",time_to_wait)time.sleep(time_to_wait)raiseAirflowException("Can not connect to instance")def_authorize_compute_engine_instance_metadata(self,pubkey):self.log.info("Appending SSH public key to instance metadata")instance_info=self._compute_hook.get_instance_info(zone=self.zone,resource_id=self.instance_name,project_id=self.project_id)keys=self.user+":"+pubkey+"\n"metadata=instance_info["metadata"]items=metadata.get("items",[])foriteminitems:ifitem.get("key")=="ssh-keys":keys+=item["value"]item["value"]=keysbreakelse:new_dict={"key":"ssh-keys","value":keys}metadata["items"]=[*items,new_dict]self._compute_hook.set_instance_metadata(zone=self.zone,resource_id=self.instance_name,metadata=metadata,project_id=self.project_id)def_authorize_os_login(self,pubkey):username=self._oslogin_hook._get_credentials_emailself.log.info("Importing SSH public key using OSLogin: 
user=%s",username)expiration=int((time.time()+self.expire_time)*1000000)ssh_public_key={"key":pubkey,"expiration_time_usec":expiration}response=self._oslogin_hook.import_ssh_public_key(user=username,ssh_public_key=ssh_public_key,project_id=self.project_id)profile=response.login_profileaccount=profile.posix_accounts[0]user=account.usernamereturnuserdef_generate_ssh_key(self,user):try:self.log.info("Generating ssh keys...")pkey_file=StringIO()pkey_obj=paramiko.RSAKey.generate(2048)pkey_obj.write_private_key(pkey_file)pubkey=f"{pkey_obj.get_name()}{pkey_obj.get_base64()}{user}"returnpkey_obj,pubkeyexcept(OSError,paramiko.SSHException)aserr:raiseAirflowException(f"Error encountered creating ssh keys, {err}")