Ansible源码分析之get_url模块

Ansible源码分析之get_url模块

在之前的command模块中提到过get_url其实是curl和wget命令的替代模块。再看该模块的文档描述:

  • 将文件从HTTP,HTTPS或FTP下载到远程服务器。远程服务器I(必须)可以直接访问远程资源。
  • 默认情况下,如果在目标主机上设置了环境变量C( _proxy),则请求将通过该代理发送。可以通过为此任务设置变量或使用use_proxy选项来覆盖此行为。
  • HTTP重定向可以从HTTP重定向到HTTPS,因此您应确保两种协议的代理环境均正确。
  • 从Ansible 2.4开始,使用C(-check)运行时,它将发出HEAD请求以验证URL,但不会下载整个文件或针对哈希进行验证。
  • 对于Windows目标,请改用 win_get_url 模块。
import datetime
import os
import re
import shutil
import tempfile
import traceback

from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.six.moves.urllib.parse import urlsplit
from ansible.module_utils._text import to_native
from ansible.module_utils.urls import fetch_url, url_argument_spec

从导入的模块来看,都是些常用的。shutil 模块提供了一系列对文件和文件集合的高阶操作。 特别是提供了一些支持文件拷贝和删除的函数。其次module_utils.urls的中fetch_url是封装的发起请求的模块也需要单独分析。

还是从main函数的代码开始看。

def main():
    argument_spec = url_argument_spec()

    argument_spec['url_username']['aliases'] = ['username']
    argument_spec['url_password']['aliases'] = ['password']

    argument_spec.update(
        url=dict(type='str', required=True),
        dest=dict(type='path', required=True),
        backup=dict(type='bool', default=False),
        sha256sum=dict(type='str', default=''),
        checksum=dict(type='str', default=''),
        timeout=dict(type='int', default=10),
        headers=dict(type='dict'),
        tmp_dest=dict(type='path'),
    )

    module = AnsibleModule(
        argument_spec=argument_spec,
        add_file_common_args=True,
        supports_check_mode=True,
        mutually_exclusive=[['checksum', 'sha256sum']],
    )

是不是感觉这里不一样,其他模块的main函数都是一上来就是module = AnsibleModule()实例化对象。这里的argument_spec是来自urls模块的赋值,跟进看代码。

def url_argument_spec():
    return dict(
        url=dict(type='str'),
        force=dict(type='bool', default=False, aliases=['thirsty'],
                   deprecated_aliases=[dict(name='thirsty', version='2.13', collection_name='ansible.builtin')]),
        http_agent=dict(type='str', default='ansible-httpget'),
        use_proxy=dict(type='bool', default=True),
        validate_certs=dict(type='bool', default=True),
        url_username=dict(type='str'),
        url_password=dict(type='str', no_log=True),
        force_basic_auth=dict(type='bool', default=False),
        client_cert=dict(type='path'),
        client_key=dict(type='path'),
        use_gssapi=dict(type='bool', default=False),
    )

创建一个参数规范,该规范可以与将通过urllib / urllib2请求内容的任何模块一起使用。然后get_url模块的argument_spec再在其基础上增加/修改相应的参数。同时checksum和sha256sum参数是互斥的。

    if module.params.get('thirsty'):
        module.deprecate('The alias "thirsty" has been deprecated and will be removed, use "force" instead',
                         version='2.13', collection_name='ansible.builtin')

    if module.params.get('sha256sum'):
        module.deprecate('The parameter "sha256sum" has been deprecated and will be removed, use "checksum" instead',
                         version='2.14', collection_name='ansible.builtin')

    url = module.params['url']
    dest = module.params['dest']
    backup = module.params['backup']
    force = module.params['force']
    sha256sum = module.params['sha256sum']
    checksum = module.params['checksum']
    use_proxy = module.params['use_proxy']
    timeout = module.params['timeout']
    headers = module.params['headers']
    tmp_dest = module.params['tmp_dest']

    result = dict(
        changed=False,
        checksum_dest=None,
        checksum_src=None,
        dest=dest,
        elapsed=0,
        url=url,
    )

当发现模块接收到thirsty和sha256sum这两个参数时,都会给用户提醒这两个参数已经被弃用了,分别用force和checksum代替。然后将模块参数赋值给对应变量,以及存储结果信息的字典。

dest_is_dir = os.path.isdir(dest)
last_mod_time = None

if sha256sum:
    checksum = 'sha256:%s' % (sha256sum)

if checksum:
    try:
        algorithm, checksum = checksum.split(':', 1)
    except ValueError:
        module.fail_json(msg="The checksum parameter has to be in format <algorithm>:<checksum>", **result)

    if is_url(checksum):
        checksum_url = checksum
        # download checksum file to checksum_tmpsrc
        checksum_tmpsrc, checksum_info = url_get(module, checksum_url, dest, use_proxy, last_mod_time, force, timeout, headers, tmp_dest)
        with open(checksum_tmpsrc) as f:
            lines = [line.rstrip('\n') for line in f]
        os.remove(checksum_tmpsrc)
        checksum_map = []
        for line in lines:
            parts = line.split(None, 1)
            if len(parts) == 2:
                checksum_map.append((parts[0], parts[1]))
        filename = url_filename(url)

        for cksum in (s for (s, f) in checksum_map if f.strip('./') == filename):
            checksum = cksum
            break
        else:
            checksum = None

        if checksum is None:
            module.fail_json(msg="Unable to find a checksum for file '%s' in '%s'" % (filename, checksum_url))
    checksum = re.sub(r'\W+', '', checksum).lower()
    try:
        int(checksum, 16)
    except ValueError:
        module.fail_json(msg='The checksum format is invalid', **result)

判断给出dest参数是否为路径。处理使用不推荐使用的sha256sum参数的解决方法,将sha256sum转化checksum的格式,然后判断checksum。指定校验和,解析算法和校验和。

然后调用is_url函数判断checksum,将校验和文件下载到checksum_tmpsrc,并读取该文件,将内容的的每一行放入列表。line.split(None,1) 和 line.split() 的效果是一样的,只不过后面要传参需要用None作为位置参数。

for cksum in (s for (s, f) in checksum_map if f.strip('./') == filename):

在校验和文件的每一行中查找与URL中文件名相对应的哈希,并返回找到的第一个哈希。不过这里用到了元组推导式,很多人都知道列表推导式,而元组推导式却略有耳闻。简单介绍一下:

元组推导式可以利用 range 区间、元组、列表、字典和集合等数据类型,快速生成一个满足指定需求的元组。

元组推导式的语法格式如下:
(表达式 for 迭代变量 in 可迭代对象 [if 条件表达式] )

例如,我们可以使用下面的代码生成一个包含数字 1~9 的元组:

a = (x for x in range(1,10))
print(a)

运行结果为:
<generator object <genexpr> at 0x0000020BAD136620>

从上面的执行结果可以看出,使用元组推导式生成的结果并不是一个元组,而是一个生成器对象,这一点和列表推导式是不同的。

回到代码,获取checksum之后,删除所有非字母数字字符,包括臭名昭著的Unicode零宽度空格,所以这里用\W来匹配替换的。然后用int转换确保校验和部分是十六进制的。

上面这段代码中,调用了该模块内的其他函数,我们依次看下。

def is_url(checksum):
    supported_schemes = ('http', 'https', 'ftp', 'file')

    return urlsplit(checksum).scheme in supported_schemes

如果校验和的值支持URL协议返回True ,否则返回False。再来看url_get函数的代码。

def url_get(module, url, dest, use_proxy, last_mod_time, force, timeout=10, headers=None, tmp_dest='', method='GET'):
    start = datetime.datetime.utcnow()
    rsp, info = fetch_url(module, url, use_proxy=use_proxy, force=force, last_mod_time=last_mod_time, timeout=timeout, headers=headers, method=method)
    elapsed = (datetime.datetime.utcnow() - start).seconds

    if info['status'] == 304:
        module.exit_json(url=url, dest=dest, changed=False, msg=info.get('msg', ''), status_code=info['status'], elapsed=elapsed)

 if info['status'] == -1:
        module.fail_json(msg=info['msg'], url=url, dest=dest, elapsed=elapsed)

    if info['status'] != 200 and not url.startswith('file:/') and not (url.startswith('ftp:/') and info.get('msg', '').startswith('OK')):
        module.fail_json(msg="Request failed", status_code=info['status'], response=info['msg'], url=url, dest=dest, elapsed=elapsed)

    if tmp_dest:
        tmp_dest_is_dir = os.path.isdir(tmp_dest)
        if not tmp_dest_is_dir:
            if os.path.exists(tmp_dest):
                module.fail_json(msg="%s is a file but should be a directory." % tmp_dest, elapsed=elapsed)
            else:
                module.fail_json(msg="%s directory does not exist." % tmp_dest, elapsed=elapsed)
    else:
        tmp_dest = module.tmpdir

    fd, tempname = tempfile.mkstemp(dir=tmp_dest)

    f = os.fdopen(fd, 'wb')
    try:
        shutil.copyfileobj(rsp, f)
    except Exception as e:
        os.remove(tempname)
        module.fail_json(msg="failed to create temporary content file: %s" % to_native(e), elapsed=elapsed, exception=traceback.format_exc())
    f.close()
    rsp.close()
    return tempname, info

从url下载数据并存储在一个临时文件中,返回(临时文件,有关请求的信息)。通过两次相减记录时间差。可以看到实际调用fetch_url发起的请求,我不进一步跟踪,只看fetch_url返回的两个参数。

(response info)的元组。使用response.read()读取数据。 info 包含“状态”和其他元数据。当发生HttpError(状态> = 400)时,“ info [‘body’]”将包含错误响应数据。

这里判断了请求的状态码以及上传对应的错误信息。

  • 304状态码:如果客户端发送了一个带条件的GET 请求且该请求已被允许,而文档的内容(自上次访问以来或者根据请求的条件)并没有改变。
  • fetch_url中的异常可能导致状态为-1,从而确保在所有情况下均向用户提供适当的错误。
  • 非200的状态码且协议不正确的情况。

接着判断tmp_dest参数且是否为一个存在的路径。创建一个临时文件并复制内容以进行基于校验和的替换。这里用到了tempfile.mkstemp()。函数 mkstemp() 仅仅就返回一个原始的OS文件描述符,你需要自己将它转换为一个真正的文件对象, 同样你还需要自己清理这些文件。

所以才用的是os.fdopen() 方法,用于通过文件描述符 fd 创建一个文件对象,并返回这个文件对象。该方法是内置函数 open() 的别名,可以接收一样的参数,唯一的区别是 fdopen() 的第一个参数必须是整型。shutil.copyfileobj() 的作用是将文件内容拷贝到另一个文件中。最后关闭文件操作。

再回到main函数中代码。

if not dest_is_dir and os.path.exists(dest):
    checksum_mismatch = False

    if not force and checksum != '':
        destination_checksum = module.digest_from_file(dest, algorithm)

        if checksum != destination_checksum:
            checksum_mismatch = True

    if not force and checksum and not checksum_mismatch:
        file_args = module.load_file_common_arguments(module.params, path=dest)
        result['changed'] = module.set_fs_attributes_if_different(file_args, False)
        if result['changed']:
            module.exit_json(msg="file already exists but file attributes changed", **result)
        module.exit_json(msg="file already exists", **result)

    mtime = os.path.getmtime(dest)
    last_mod_time = datetime.datetime.utcfromtimestamp(mtime)

    if checksum_mismatch:
        force = True

start = datetime.datetime.utcnow()
method = 'HEAD' if module.check_mode else 'GET'
tmpsrc, info = url_get(module, url, dest, use_proxy, last_mod_time, force, timeout, headers, tmp_dest, method)
result['elapsed'] = (datetime.datetime.utcnow() - start).seconds
result['src'] = tmpsrc

如果不强制下载并且有校验和,允许匹配校验和以跳过下载。module.digest_from_file的作用是返回本地文件的十六进制摘要以获取由名称指定的digest_method;如果文件不存在,则返回None。

除非校验和不匹配,否则不强制重新下载。除非校验和不匹配允许文件属性更改,不强制重新下载。如果文件已经存在,则会请求上次修改的时间。如果校验和不匹配,则必须强制下载,因为last_mod_time可能比远程的更新。

经过上面一堆判断后,开始下载到tmpsrc。module.check_mode属性为真就用HEAD方式请求,elapsed就是请求所耗费的时间。

if dest_is_dir:
    filename = extract_filename_from_headers(info)
    if not filename:
        filename = url_filename(info['url'])
    dest = os.path.join(dest, filename)
    result['dest'] = dest

if not os.path.exists(tmpsrc):
    os.remove(tmpsrc)
    module.fail_json(msg="Request failed", status_code=info['status'], response=info['msg'], **result)
if not os.access(tmpsrc, os.R_OK):
    os.remove(tmpsrc)
    module.fail_json(msg="Source %s is not readable" % (tmpsrc), **result)
result['checksum_src'] = module.sha1(tmpsrc)

当dest路径存在时,就调用extract_filename_from_headers去解析之前用url_get请求返回的header头中的字段获取文件名。如果没获取到就调用url_filename从url中截取path作为文件名。filename和dest拼接得到最后的文件路径。

接着判断如果没有tmpsrc文件,则会引发错误。os.access() 方法使用当前的uid/gid尝试访问路径,用来检测是否有访问权限的路径。最后对tmpsrc取哈希。

这里用到了extract_filename_from_headers和url_filename,分别看代码。

def extract_filename_from_headers(headers):
    cont_disp_regex = 'attachment; ?filename="?([^"]+)'
    res = None

    if 'content-disposition' in headers:
        cont_disp = headers['content-disposition']
        match = re.match(cont_disp_regex, cont_disp)
        if match:
            res = match.group(1)
            res = os.path.basename(res)
    return res

从给定的header字典中提取文件名。查找content-disposition标头并应用正则表达式。如果成功,则返回文件名,否则返回None。用os.path.basename处理是为避免匹配到一些有趣的结果。

def url_filename(url):
    fn = os.path.basename(urlsplit(url)[2])
    if fn == '':
        return 'index.html'
    return fn

对url做分隔,为空就返回index.html,非常简单。再回到main函数的代码来。

    if os.path.exists(dest):
        if not os.access(dest, os.W_OK):
            os.remove(tmpsrc)
            module.fail_json(msg="Destination %s is not writable" % (dest), **result)
        if not os.access(dest, os.R_OK):
            os.remove(tmpsrc)
            module.fail_json(msg="Destination %s is not readable" % (dest), **result)
        result['checksum_dest'] = module.sha1(dest)
    else:
        if not os.path.exists(os.path.dirname(dest)):
            os.remove(tmpsrc)
            module.fail_json(msg="Destination %s does not exist" % (os.path.dirname(dest)), **result)
        if not os.access(os.path.dirname(dest), os.W_OK):
            os.remove(tmpsrc)
            module.fail_json(msg="Destination %s is not writable" % (os.path.dirname(dest)), **result)

    if module.check_mode:
        if os.path.exists(tmpsrc):
            os.remove(tmpsrc)
        result['changed'] = ('checksum_dest' not in result or
                             result['checksum_src'] != result['checksum_dest'])
        module.exit_json(msg=info.get('msg', ''), **result)

如果没有dest复制的权限,则会引发错误。这段代码基本都是对权限的判断以及上报对应的错误。

    if module.check_mode:
        if os.path.exists(tmpsrc):
            os.remove(tmpsrc)
        result['changed'] = ('checksum_dest' not in result or
                             result['checksum_src'] != result['checksum_dest'])
        module.exit_json(msg=info.get('msg', ''), **result)

    backup_file = None
    if result['checksum_src'] != result['checksum_dest']:
        try:
            if backup:
                if os.path.exists(dest):
                    backup_file = module.backup_local(dest)
            module.atomic_move(tmpsrc, dest, unsafe_writes=module.params['unsafe_writes'])
        except Exception as e:
            if os.path.exists(tmpsrc):
                os.remove(tmpsrc)
            module.fail_json(msg="failed to copy %s to %s: %s" % (tmpsrc, dest, to_native(e)),
                             exception=traceback.format_exc(), **result)
        result['changed'] = True
    else:
        result['changed'] = False
        if os.path.exists(tmpsrc):
            os.remove(tmpsrc)

check_mode为真时,删除tmpsrc。changed的值是根据checksum_dest是否存在或者checksum_src和checksum_dest的值是否不相等确定的。

checksum_src和checksum_dest的值不相等时。通过backup_local对指定文件进行带日期标记的备份。

上面的代码中看到了很多os.remove(tmpsrc),其实我感觉太冗余了,每个判断判断分支里面都有remove的操作,我觉得不如在最后直接用try:os.remove(tmpsrc) except:pass来处理。

    if checksum != '':
        destination_checksum = module.digest_from_file(dest, algorithm)

        if checksum != destination_checksum:
            os.remove(dest)
            module.fail_json(msg="The checksum for %s did not match %s; it was %s." % (dest, checksum, destination_checksum), **result)

    file_args = module.load_file_common_arguments(module.params, path=dest)
    result['changed'] = module.set_fs_attributes_if_different(file_args, result['changed'])

    try:
        result['md5sum'] = module.md5(dest)
    except ValueError:
        result['md5sum'] = None

    if backup_file:
        result['backup_file'] = backup_file

    # Mission complete
    module.exit_json(msg=info.get('msg', ''), status_code=info.get('status', ''), **result)

判断checksum和destination_checksum是否相同,不同则上报错误。

允许文件属性更改,仅向后兼容,将在启用FIPS的系统上返回无。这个是注释里写的,我查了一下FIPS的含义:

openssl-fips是符合FIPS标准的Openssl。 联邦信息处理标准(Federal Information Processing Standards,FIPS)是一套描述文件处理、加密算法和其他信息技术标准(在非军用政府机构和与这些机构合作的政府承包商和供应商中应用的标准)的标准。

反正也不知道是啥,只不过异常捕获时md5sum的值为空。

通读了get_url的代码,感觉挺普通的,甚至有些冗余。毕竟学习别人的代码,还是得多反思,取其精华,去其糟粕。说不定之后自己写一个ansible的模块能写得更好呢哈哈。

赞赏

微信赞赏支付宝赞赏

Zgao

愿有一日,安全圈的师傅们都能用上Zgao写的工具。

发表评论