# Source code for cpg_utils.hail_batch

"""Convenience functions related to Hail."""

import asyncio
import base64
import gzip
import inspect
import logging
import os
import tempfile
import textwrap
import uuid
from shlex import quote
from typing import Any, Literal

import toml
from deprecated import deprecated

import hail as hl
import hailtop.batch as hb
from hail.backend.service_backend import ServiceBackend as InternalServiceBackend
from hail.utils.java import Env
from hailtop.config import get_deploy_config

from cpg_utils import Path, to_path
from cpg_utils.config import (
    AR_GUID_NAME,
    config_retrieve,
    dataset_path,
    genome_build,
    get_config,
    set_config_paths,
    try_get_ar_guid,
)
from cpg_utils.config import (
    reference_path as ref_path,
)
from cpg_utils.constants import DEFAULT_GITHUB_ORGANISATION

# Template command strings.
#
# GCLOUD_AUTH_COMMAND is a shell snippet that points
# GOOGLE_APPLICATION_CREDENTIALS at /gsa-key/key.json and activates that
# service account with gcloud. The string is not raw, so the trailing
# backslash on the `gcloud` line joins it with the following `--key-file`
# line into a single shell line.
GCLOUD_AUTH_COMMAND = """\
export GOOGLE_APPLICATION_CREDENTIALS=/gsa-key/key.json
gcloud -q auth activate-service-account \
--key-file=$GOOGLE_APPLICATION_CREDENTIALS
"""


# Module-level cache for the batch created by get_batch(); reset_batch()
# sets it back to None.
_batch: 'Batch | None' = None


def reset_batch():
    """
    Forget the cached module-level batch.

    Clears the global `_batch` reference so that it is `None` again.
    Mainly useful for tests.
    """
    global _batch  # pylint: disable=global-statement
    _batch = None
def get_batch(
    name: str | None = None,
    *,
    default_python_image: str | None = None,
    attributes: dict[str, str] | None = None,
    **kwargs: Any,
) -> 'Batch':
    """
    Return the module-wide `Batch`, creating it on the first call.

    Wrapper around Hail's `Batch` class, which allows to register created jobs.
    This has been migrated (currently duplicated) out of cpg_workflows.

    The batch is cached in the module-level `_batch`; once it exists, every
    argument passed to later calls is ignored and the cached object is returned.

    Parameters
    ----------
    name : str, optional
        name for the batch; falls back to the `workflow/name` config entry
    default_python_image : str, optional
        default python image to use; falls back to `workflow/driver_image`
    attributes : dict, optional
        attributes forwarded to the `Batch` constructor
    kwargs
        any further keyword arguments, forwarded to the `Batch` constructor

    Returns
    -------
    The (possibly newly created) cached `Batch` instance
    """
    global _batch  # pylint: disable=global-statement

    # Fast path: the batch has already been built by an earlier call.
    if _batch is not None:
        return _batch

    backend: hb.Backend
    backend_kind = config_retrieve(['hail', 'backend'], default='batch')
    if backend_kind == 'local':
        logging.info('Initialising Hail Batch with local backend')
        backend = hb.LocalBackend(
            tmp_dir=tempfile.mkdtemp('batch-tmp'),
        )
    else:
        logging.info('Initialising Hail Batch with service backend')
        backend = hb.ServiceBackend(
            billing_project=config_retrieve(['hail', 'billing_project']),
            remote_tmpdir=dataset_path('batch-tmp', category='tmp'),
            token=os.environ.get('HAIL_TOKEN'),
        )

    # Config lookups are kept in the same order as the constructor arguments.
    batch_name = name or config_retrieve(['workflow', 'name'], default=None)
    pool_label = config_retrieve(['hail', 'pool_label'], default=None)
    cancel_after = config_retrieve(
        ['hail', 'cancel_after_n_failures'],
        default=None,
    )
    timeout = config_retrieve(['hail', 'default_timeout'], default=None)
    memory = config_retrieve(['hail', 'default_memory'], default=None)
    python_image = default_python_image or config_retrieve(
        ['workflow', 'driver_image'],
    )

    _batch = Batch(
        name=batch_name,
        backend=backend,
        pool_label=pool_label,
        cancel_after_n_failures=cancel_after,
        default_timeout=timeout,
        default_memory=memory,
        default_python_image=python_image,
        attributes=attributes,
        **kwargs,
    )
    return _batch
class Batch(hb.Batch):
    """
    Thin subclass of the Hail `Batch` class. The aim is to be able to register
    created jobs, in order to print statistics before submitting the Batch.
    """

    def __init__(
        self,
        name: str,
        backend: hb.backend.LocalBackend | hb.backend.ServiceBackend,
        *,
        pool_label: str | None = None,
        attributes: dict[str, str] | None = None,
        **kwargs: Any,
    ):
        """
        Create the batch and initialise the job statistics registry.

        Parameters
        ----------
        name : batch name, passed to `hb.Batch`
        backend : Hail Batch backend, passed to `hb.Batch`
        pool_label : if set, assigned in `run()` to every job that is not
            explicitly non-preemptible
        attributes : batch attributes; the AR GUID is added when it is
            available and not already present
        kwargs : forwarded to `hb.Batch`
        """
        # Work on a copy so that adding the AR GUID never mutates the dict
        # object owned by the caller.
        _attributes = dict(attributes or {})
        if AR_GUID_NAME not in _attributes:  # noqa: SIM102
            if ar_guid := try_get_ar_guid():
                _attributes[AR_GUID_NAME] = ar_guid

        super().__init__(name, backend, attributes=_attributes, **kwargs)

        # Job stats registry:
        self.job_by_label: dict = {}
        self.job_by_stage: dict = {}
        self.job_by_tool: dict = {}
        self.total_job_num = 0
        self.pool_label = pool_label

        dry_run = config_retrieve(['hail', 'dry_run'], default=False)
        if not dry_run and not isinstance(self._backend, hb.LocalBackend):
            self._copy_configs_to_remote()

    def _copy_configs_to_remote(self) -> None:
        """
        Combine all config files into a single entry
        Write that entry to a cloud path
        Set that cloud path as the config path

        This is crucial in production-pipelines as we combine remote
        and local files in the driver image, but we can only pass
        cloudpaths to the worker job containers
        """
        if not isinstance(self._backend, hb.backend.ServiceBackend):
            return

        remote_dir = to_path(self._backend.remote_tmpdir) / 'config'
        # A random file name avoids clashes between batches sharing a tmpdir.
        config_path = remote_dir / (str(uuid.uuid4()) + '.toml')
        with config_path.open('w') as f:
            toml.dump(dict(get_config()), f)
        set_config_paths([str(config_path)])

    def _pack_attribute(self, key: str, value: str) -> dict[str, str]:
        """
        Attributes are stored in a TEXT database field, which is limited to 64K.
        If necessary, compress the value and annotate the key accordingly.
        Eventually this may no longer suffice and we will need to split the value
        across several attributes or similar.

        Returns a single-entry dict: `{key: value}` for short values, or
        `{f'{key}_gzip': <base64 of gzipped value>}` for long ones.
        Raises ValueError if even the compressed form exceeds 65535 characters.
        """
        if len(value) <= 10000:  # noqa: PLR2004
            return {key: value}  # Store short values verbatim

        raw = value.encode()
        compressed_b64 = base64.standard_b64encode(gzip.compress(raw, compresslevel=9))
        if len(compressed_b64) > 65535:  # noqa: PLR2004
            raise ValueError(f'Job attribute {key!r} value is too large')

        return {f'{key}_gzip': compressed_b64.decode('ascii')}

    def _process_job_attributes(
        self,
        name: str | None = None,
        attributes: dict | None = None,
    ) -> tuple[str, dict[str, str]]:
        """
        Use job attributes to make the job name more descriptive, and add
        labels for Batch pre-submission stats.

        Returns the extended job name and a new attributes dict whose values
        are all strings. Raises ValueError if `name` is empty.
        """
        if not name:
            raise ValueError('Error: job name must be defined')

        self.total_job_num += 1

        # Multiple jobs in the batch might reference the same attributes dict
        # object. Avoid modifying the dict object (e.g. with pop() or update())
        # to avoid changing the attributes of subsequently processed jobs.
        attributes = attributes or {}
        stage = attributes.get('stage')
        dataset = attributes.get('dataset')
        sequencing_group = attributes.get('sequencing_group')
        participant_id = attributes.get('participant_id')
        sequencing_groups: set[str] = set(attributes.get('sequencing_groups') or [])
        if sequencing_group:
            sequencing_groups.add(sequencing_group)
        part = attributes.get('part')
        label = attributes.get('label', name)
        tool = attributes.get('tool')
        if not tool and name.endswith('Dataproc cluster'):
            tool = 'hailctl dataproc'

        # pylint: disable=W1116
        assert isinstance(stage, str | None)
        assert isinstance(dataset, str | None)
        assert isinstance(sequencing_group, str | None)
        assert isinstance(participant_id, str | None)
        assert isinstance(part, str | None)
        assert isinstance(label, str | None)

        name = make_job_name(
            name=name,
            sequencing_group=sequencing_group,
            participant_id=participant_id,
            dataset=dataset,
            part=part,
        )

        if label not in self.job_by_label:
            self.job_by_label[label] = {'job_n': 0, 'sequencing_groups': set()}
        self.job_by_label[label]['job_n'] += 1
        self.job_by_label[label]['sequencing_groups'] |= sequencing_groups

        if stage not in self.job_by_stage:
            self.job_by_stage[stage] = {'job_n': 0, 'sequencing_groups': set()}
        self.job_by_stage[stage]['job_n'] += 1
        self.job_by_stage[stage]['sequencing_groups'] |= sequencing_groups

        if tool not in self.job_by_tool:
            self.job_by_tool[tool] = {'job_n': 0, 'sequencing_groups': set()}
        self.job_by_tool[tool]['job_n'] += 1
        self.job_by_tool[tool]['sequencing_groups'] |= sequencing_groups

        # Ensure all the returned attribute values are presented as strings
        fixed_attrs = {
            k: str(v) for k, v in attributes.items() if k != 'sequencing_groups'
        }
        seqgroups_str = str(sorted(sequencing_groups))
        fixed_attrs.update(self._pack_attribute('sequencing_groups', seqgroups_str))
        return name, fixed_attrs

    def run(self, **kwargs: Any):
        """
        Execute a batch. Overridden to print pre-submission statistics.
        Pylint disables:
        - R1710: Either all return statements in a function should return an
          expression, or none of them should.
            - if no jobs are present, no batch is returned. Hail should have this
              behaviour...
        - W0221: Arguments number differs from overridden method
            - this wrapper makes use of **kwargs, which is being passed to the
              super().run() method

        Returns None when the batch has no jobs, otherwise whatever
        `hb.Batch.run` returns.
        """
        if not self._jobs:
            logging.error('No jobs to submit')
            return None

        for job in self._jobs:
            job.name, job.attributes = self._process_job_attributes(
                job.name,
                job.attributes,
            )
            # We only have dedicated pools for preemptible machines.
            # _preemptible defaults to None, so check explicitly for False.
            # pylint: disable=W0212
            if self.pool_label and job._preemptible is not False:
                job._pool_label = self.pool_label
            copy_common_env(job)

        logging.info(f'Will submit {self.total_job_num} jobs')

        def _print_stat(
            prefix: str,
            _d: dict,
            default_label: str | None = None,
        ) -> None:
            # Collect one line per entry and join with newlines; appending the
            # entries directly to one string ran them together on one line.
            lines = [prefix or ' ']
            for label, stat in _d.items():
                lbl = label or default_label
                msg = f'{stat["job_n"]} job'
                if stat['job_n'] > 1:
                    msg += 's'
                if (sg_count := len(stat['sequencing_groups'])) > 0:
                    msg += f' for {sg_count} sequencing group'
                    if sg_count > 1:
                        msg += 's'
                lines.append(f'  {lbl}: {msg}')
            logging.info('\n'.join(lines))

        _print_stat(
            'Split by stage:',
            self.job_by_stage,
            default_label='<not in stage>',
        )
        _print_stat(
            'Split by tool:',
            self.job_by_tool,
            default_label='<tool is not defined>',
        )

        # Explicit keyword arguments win over the config values.
        kwargs.setdefault('dry_run', config_retrieve(['hail', 'dry_run'], default=None))
        kwargs.setdefault(
            'delete_scratch_on_exit',
            config_retrieve(['hail', 'delete_scratch_on_exit'], default=None),
        )
        # Local backend does not support "wait"
        if isinstance(self._backend, hb.LocalBackend) and 'wait' in kwargs:
            del kwargs['wait']
        return super().run(**kwargs)
[docs] def make_job_name( name: str, sequencing_group: str | None = None, participant_id: str | None = None, dataset: str | None = None, part: str | None = None, ) -> str: """ Extend the descriptive job name to reflect job attributes. """ if sequencing_group and participant_id: sequencing_group = f'{sequencing_group}/{participant_id}' if sequencing_group and dataset: name = f'{dataset}/{sequencing_group}: {name}' elif dataset: name = f'{dataset}: {name}' if part: name += f', {part}' return name
# Git revision reported by DefaultOverrideServiceBackend.jar_spec.
# None until init_batch() assigns the configured
# workflow/default_jar_spec_revision value.
_default_override_revision = None


class DefaultOverrideServiceBackend(InternalServiceBackend):
    """
    Service backend whose `jar_spec` is a `git_revision` spec built from the
    module-level `_default_override_revision`.
    """

    @property
    def jar_spec(self) -> dict:
        # The global is read on every access, so a value assigned after the
        # backend object was created is still picked up.
        return {'type': 'git_revision', 'value': _default_override_revision}
def init_batch(**kwargs: Any):
    """
    Initializes the Hail Query Service from within Hail Batch.

    Reads the `workflow/dataset` and `hail/billing_project` config variables;
    the remote tmpdir is derived from the dataset name
    (`cpg-<dataset>-hail`). Does nothing if Hail is already initialised.

    If `workflow/default_jar_spec_revision` is configured, the active backend
    (when it is a service backend) is switched to
    `DefaultOverrideServiceBackend` so that its jar spec uses that revision.

    Parameters
    ----------
    kwargs : keyword arguments
        Forwarded directly to `hl.init_batch`. `token` defaults to the
        `HAIL_TOKEN` environment variable when not supplied.
    """
    global _default_override_revision

    # noinspection PyProtectedMember
    if Env._hc:  # pylint: disable=W0212
        return  # already initialised

    dataset = config_retrieve(['workflow', 'dataset'])
    kwargs.setdefault('token', os.environ.get('HAIL_TOKEN'))

    loop = asyncio.get_event_loop()
    init_coro = hl.init_batch(
        default_reference=genome_build(),
        billing_project=config_retrieve(['hail', 'billing_project']),
        remote_tmpdir=remote_tmpdir(f'cpg-{dataset}-hail'),
        **kwargs,
    )
    loop.run_until_complete(init_coro)

    revision = config_retrieve(['workflow', 'default_jar_spec_revision'], False)
    if not revision:
        return

    _default_override_revision = revision
    backend = Env.backend()
    if isinstance(backend, InternalServiceBackend):
        # Re-class the live backend object in place so that its jar_spec
        # property reports the override revision.
        backend.__class__ = DefaultOverrideServiceBackend
def copy_common_env(job: hb.batch.job.Job) -> None:
    """Copies common environment variables that we use to run Hail jobs.

    These variables are typically set up in the analysis-runner driver, but need to be
    passed through for "batch-in-batch" use cases.

    The environment variable values are extracted from the current process and
    copied to the environment dictionary of the given Hail Batch job. The AR GUID,
    when available, is also stored in the job's attributes.
    """
    # If possible, please don't add new environment variables here, but instead add
    # config variables.
    forwarded_keys = ('CPG_CONFIG_PATH',)
    for env_key in forwarded_keys:
        # Unset or empty variables are not forwarded.
        if env_val := os.getenv(env_key):
            job.env(env_key, env_val)

    if not job.attributes:
        job.attributes = {}
    if ar_guid := try_get_ar_guid():
        job.attributes[AR_GUID_NAME] = ar_guid
[docs] def remote_tmpdir(hail_bucket: str | None = None) -> str: """Returns the remote_tmpdir to use for Hail initialization. If `hail_bucket` is not specified explicitly, requires the `hail/bucket` config variable to be set. """ bucket = hail_bucket or config_retrieve(['hail', 'bucket'], default=None) assert bucket, 'hail_bucket was not set by argument or configuration' return f'gs://{bucket}/batch-tmp'
def fasta_res_group(b: hb.Batch, indices: list[str] | None = None):
    """
    Hail Batch resource group for fasta reference files.

    The fasta path is taken from the `workflow/ref_fasta` config entry, falling
    back to the `broad/ref_fasta` reference path. The group always contains
    `base`, `fai` (`<fasta>.fai`) and `dict` (fasta with its suffix replaced by
    `.dict`), plus one `<fasta>.<ext>` entry per requested index extension.

    @param b: Hail Batch object.
    @param indices: list of extensions to add to the base fasta file path.
    """
    fasta = to_path(
        config_retrieve(['workflow', 'ref_fasta'], default=None)
        or ref_path('broad/ref_fasta'),
    )
    fasta_str = str(fasta)

    resources = {
        'base': fasta_str,
        'fai': fasta_str + '.fai',
        'dict': str(fasta.with_suffix('.dict')),
    }
    # Extra index files are keyed by their extension.
    resources.update({ext: f'{fasta}.{ext}' for ext in indices or []})
    return b.read_input_group(**resources)
def authenticate_cloud_credentials_in_job(
    job: hb.batch.job.BashJob,
    print_all_statements: bool = True,
):
    """
    Takes a hail batch job, activates the appropriate service account

    Once multiple environments are supported this method will decide
    on which authentication method is appropriate

    Parameters
    ----------
    job
        * A hail BashJob
    print_all_statements
        * logging toggle: when true, `set -x` is added before the
          authentication command

    Returns
    -------
    None
    """
    auth_commands = []
    if print_all_statements:
        # Use "set -x" to print the commands for easier debugging.
        auth_commands.append('set -x')
    # activate the google service account
    auth_commands.append(GCLOUD_AUTH_COMMAND)

    for shell_cmd in auth_commands:
        job.command(shell_cmd)
def prepare_git_job(
    job: hb.batch.job.BashJob,
    repo_name: str,
    commit: str,
    organisation: str = DEFAULT_GITHUB_ORGANISATION,
    is_test: bool = True,
    print_all_statements: bool = True,
    get_deploy_token: bool = True,
):
    """
    Takes a hail batch job, and:
        * Activates the Google service account
        * Optionally sets up a git credentials store from a secret
        * Clones the repository (including submodules)
        * if not is_test: checks the desired commit is on 'main'
        * Checks out the specific commit and updates submodules

    Parameters
    ----------
    job                  - A hail BashJob
    repo_name            - The repository name to check out
    commit               - The commit hash to check out
    organisation         - The GitHub individual or organisation
    is_test              - CPG specific: only Main commits can run on Main data
    print_all_statements - logging toggle
    get_deploy_token     - when true, add the commands that try to configure
                           git credentials from the configured secret

    Returns
    -------
    The same job, with the commands added
    """
    authenticate_cloud_credentials_in_job(
        job,
        print_all_statements=print_all_statements,
    )

    # Note: for private GitHub repos we'd need to use a token to clone.
    #   - store the token on secret manager
    #   - The git_credentials_secret_{name,project} values are set by cpg-infrastructure
    #   - check at runtime whether we can get the token
    #   - if so, set up the git credentials store with that value
    if get_deploy_token:
        # The secret lookups run inside the job; if the Python snippet fails,
        # the `|| echo ''` fallback leaves the shell variable empty.
        job.command(
            """
# get secret names from config if they exist
secret_name=$(python3 -c '
try:
    from cpg_utils.config import config_retrieve
    print(config_retrieve(["infrastructure", "git_credentials_secret_name"], default=""))
except:
    pass
' || echo '')

secret_project=$(python3 -c '
try:
    from cpg_utils.config import config_retrieve
    print(config_retrieve(["infrastructure", "git_credentials_secret_project"], default=""))
except:
    pass
' || echo '')

if [ ! -z "$secret_name" ] && [ ! -z "$secret_project" ]; then
    # configure git credentials store if credentials are set
    gcloud --project $secret_project secrets versions access --secret $secret_name latest > ~/.git-credentials
    git config --global credential.helper "store"
else
    echo 'No git credentials secret found, unable to check out private repositories.'
fi
        """,
        )

    # Any job commands here are evaluated in a bash shell, so user arguments should
    # be escaped to avoid command injection.
    repo_path = f'https://github.com/{organisation}/{repo_name}.git'
    job.command(f'git clone --recurse-submodules {quote(repo_path)}')
    job.command(f'cd {quote(repo_name)}')
    # Except for the "test" access level, we check whether commits have been
    # reviewed by verifying that the given commit is in the main branch.
    if not is_test:
        job.command('git checkout main')
        job.command(
            f'git merge-base --is-ancestor {quote(commit)} HEAD || '
            '{ echo "error: commit not merged into main branch"; exit 1; }',
        )
    job.command(f'git checkout {quote(commit)}')
    job.command('git submodule update')

    return job
# Shell functions for pulling files onto an instance while tolerating
# transient errors:
#   fail        - print a message to stderr and exit with status 1
#   retry       - re-run a command (via eval) up to 10 times, 30s apart
#   retry_gs_cp - `gcloud storage cp` with retries; the destination defaults
#                 to /io/batch/<basename of the source>
RETRY_CMD = """\
function fail {
  echo $1 >&2
  exit 1
}

function retry {
  local n_attempts=10
  local delay=30
  local n=1
  while ! eval "$@"; do
    if [[ $n -lt $n_attempts ]]; then
      ((n++))
      echo "Command failed. Attempt $n/$n_attempts after ${delay}s..."
      sleep $delay;
    else
      fail "The command has failed after $n attempts."
    fi
  done
}

function retry_gs_cp {
  src=$1

  if [ -n "$2" ]; then
    dst=$2
  else
    dst=/io/batch/$(basename "$src")
  fi

  retry gcloud storage cp $src $dst
}
"""

# command that monitors the instance storage space
MONITOR_SPACE_CMD = 'df -h; du -sh /io; du -sh /io/batch'

# Template (for str.format) that appends `script_contents` to the file
# `script_name` through a quoted heredoc, so the shell does not expand
# anything inside the script. The trailing backslash keeps the string from
# ending in a newline.
ADD_SCRIPT_CMD = """\
cat <<'EOT' >> {script_name}
{script_contents}
EOT\
"""
[docs] def command( cmd: str | list[str], monitor_space: bool = False, setup_gcp: bool = False, define_retry_function: bool = False, rm_leading_space: bool = True, python_script_path: Path | None = None, ) -> str: """ Wraps a command for Batch. @param cmd: command to wrap (can be a list of commands) @param monitor_space: add a background process that checks the instance disk space every 5 minutes and prints it to the screen @param setup_gcp: authenticate on GCP @param define_retry_function: when set, adds bash functions `retry` that attempts to redo a command after every 30 seconds (useful to pull inputs and get around GoogleEgressBandwidth Quota or other google quotas) @param rm_leading_space: remove all leading spaces and tabs from the command lines @param python_script_path: if provided, copy this python script into the command """ if isinstance(cmd, list): cmd = '\n'.join(cmd) if define_retry_function: setup_gcp = True cmd = f"""\ set -o pipefail set -ex {GCLOUD_AUTH_COMMAND if setup_gcp else ''} {RETRY_CMD if define_retry_function else ''} {f'(while true; do {MONITOR_SPACE_CMD}; sleep 600; done) &' if monitor_space else ''} {{copy_script_cmd}} {cmd} {MONITOR_SPACE_CMD if monitor_space else ''} """ if rm_leading_space: # remove any leading spaces and tabs cmd = '\n'.join(line.strip() for line in cmd.split('\n')) # remove stretches of spaces cmd = '\n'.join(' '.join(line.split()) for line in cmd.split('\n')) else: # Remove only common leading space: cmd = textwrap.dedent(cmd) # We don't want the python script tabs to be stripped, so # we are inserting it after leading space is removed if python_script_path: with python_script_path.open() as f: script_contents = f.read() cmd = cmd.replace( '{copy_script_cmd}', ADD_SCRIPT_CMD.format( script_name=python_script_path.name, script_contents=script_contents, ), ) else: cmd = cmd.replace('{copy_script_cmd}', '') return cmd
[docs] def query_command( module: Any, func_name: str, *func_args: Any, setup_gcp: bool = False, setup_hail: bool = True, packages: list[str] | None = None, init_batch_args: dict[str, str | int] | None = None, ) -> str: """ Construct a command to run a python function inside a Hail Batch job. If hail_billing_project is provided, Hail Query will be also initialised. Run a Python Hail Query function inside a Hail Batch job. Constructs a command string to use with job.command(). If hail_billing_project is provided, Hail Query will be initialised. init_batch_args can be used to pass additional arguments to init_batch. this is a dict of args, which will be placed into the batch initiation command e.g. {'worker_memory': 'highmem'} -> 'init_batch(worker_memory="highmem")' """ # translate any input arguments into an embeddable String if init_batch_args: batch_overrides = ', '.join(f'{k}={v!r}' for k, v in init_batch_args.items()) else: batch_overrides = '' init_hail_code = f""" from cpg_utils.hail_batch import init_batch init_batch({batch_overrides}) """ # the code will be copied verbatim python_code = f""" {'' if not setup_hail else init_hail_code} {inspect.getsource(module)} """ # but the function call will be shell-expanded, as the arguments may # contain variables requiring expansion, ${BATCH_TMPDIR} in particular python_call = f""" {func_name}{func_args} """ return f"""\ set -o pipefail set -ex {GCLOUD_AUTH_COMMAND if setup_gcp else ''} {('pip3 install ' + ' '.join(packages)) if packages else ''} cat <<'EOT' > script.py {python_code} EOT cat <<EOT >> script.py {python_call} EOT python3 script.py """
def start_query_context(
    query_backend: Literal['spark', 'batch', 'local', 'spark_local'] | None = None,
    log_path: str | None = None,
    dataset: str | None = None,
    billing_project: str | None = None,
):
    """
    Start Hail Query context, depending on the backend class specified in
    the hail/query_backend TOML config value.

    @param query_backend: backend to use; defaults to the `hail/query_backend`
        config value, or 'spark' when that is not set
    @param log_path: Hail log location, used by the 'spark_local' backend only
        (defaults to `hail-log.txt` in the dataset's tmp category)
    @param dataset: dataset name, used by the 'batch' backend only
        (defaults to the `workflow/dataset` config value)
    @param billing_project: Hail billing project, used by the 'batch' backend
        only (defaults to the `hail/billing_project` config value)
    """
    backend_kind = query_backend or config_retrieve(
        ['hail', 'query_backend'],
        default='spark',
    )

    if backend_kind == 'spark':
        hl.init(default_reference=genome_build())
        return

    if backend_kind == 'spark_local':
        local_threads = 2  # https://stackoverflow.com/questions/32356143/what-does-setmaster-local-mean-in-spark
        hl.init(
            default_reference=genome_build(),
            master=f'local[{local_threads}]',  # local[2] means "run spark locally with 2 threads"
            quiet=True,
            log=log_path or dataset_path('hail-log.txt', category='tmp'),
        )
        return

    if backend_kind == 'local':
        hl.utils.java.Env.hc()  # force initialization
        return

    assert backend_kind == 'batch'
    if hl.utils.java.Env._hc:  # pylint: disable=W0212
        return  # already initialised

    dataset = dataset or config_retrieve(['workflow', 'dataset'])
    billing_project = billing_project or config_retrieve(
        ['hail', 'billing_project'],
    )

    loop = asyncio.get_event_loop()
    # NOTE(review): unlike the spark branches, the reference genome is fixed
    # to GRCh38 here rather than taken from genome_build().
    init_coro = hl.init_batch(
        billing_project=billing_project,
        remote_tmpdir=f'gs://cpg-{dataset}-hail/batch-tmp',
        token=os.environ.get('HAIL_TOKEN'),
        default_reference='GRCh38',
    )
    loop.run_until_complete(init_coro)
def run_batch_job_and_print_url(
    batch: Batch,
    wait: bool,
    environment: str,
) -> str | None:
    """
    Call batch.run(), return the URL, and wait for job to finish if wait=True.

    The batch is always submitted with `wait=False`; waiting, when requested,
    happens here so that the URL can be included in the failure message.

    Returns the batch URL, or None when `batch.run()` returned nothing.
    Raises ValueError for any environment other than 'gcp', and Exception if
    the awaited batch does not finish in the 'success' state.
    """
    if environment != 'gcp':
        raise ValueError(
            f'Unsupported Hail Batch deploy config environment: {environment}',
        )

    submitted = batch.run(wait=False)
    if not submitted:
        return None

    url = get_deploy_config().url('batch', f'/batches/{submitted.id}')

    if wait and submitted.wait()['state'] != 'success':
        raise Exception(f'{url} failed')

    return url
# The functions below have moved to cpg_utils.config; these deprecated wrappers
# only forward to the new location.
@deprecated('Use cpg_utils.config.image_path instead')
def image_path(*args, **kwargs):  # noqa: ANN002, ANN003
    """Deprecated wrapper: forwards all arguments to `cpg_utils.config.image_path`."""
    # Imported on call, not at module import time.
    from cpg_utils.config import image_path as _impl

    return _impl(*args, **kwargs)
@deprecated('Use cpg_utils.config.output_path instead')
def output_path(*args, **kwargs):  # noqa: ANN002, ANN003
    """Deprecated wrapper: forwards all arguments to `cpg_utils.config.output_path`."""
    # Imported on call, not at module import time.
    from cpg_utils.config import output_path as _impl

    return _impl(*args, **kwargs)
@deprecated('Use cpg_utils.config.web_url instead')
def web_url(*args, **kwargs):  # noqa: ANN002, ANN003
    """Deprecated wrapper: forwards all arguments to `cpg_utils.config.web_url`."""
    # Imported on call, not at module import time.
    from cpg_utils.config import web_url as _impl

    return _impl(*args, **kwargs)
# cpg_test_dataset_path
@deprecated('Use cpg_utils.config.dataset_path instead')
def cpg_test_dataset_path(*args, **kwargs):  # noqa: ANN002, ANN003
    """
    Deprecated wrapper: forwards all arguments to
    `cpg_utils.config.cpg_test_dataset_path`.
    """
    # Imported on call, not at module import time.
    from cpg_utils.config import cpg_test_dataset_path as _impl

    return _impl(*args, **kwargs)
@deprecated(
    'Use to_path(cpg_utils.config.reference_path) instead, note the '
    'config.reference_path does not return an AnyPath object',
)
def reference_path(*args, **kwargs):  # noqa: ANN002, ANN003
    """
    Deprecated wrapper: calls `cpg_utils.config.reference_path` with the given
    arguments and passes the result through `to_path`.
    """
    # Imported on call, not at module import time.
    from cpg_utils.config import reference_path as _impl

    raw_path = _impl(*args, **kwargs)
    return to_path(raw_path)
@deprecated('Use cpg_utils.config.get_cpg_namespace instead')
def cpg_namespace(*args, **kwargs):  # noqa: ANN002, ANN003
    """Deprecated wrapper: forwards all arguments to `cpg_utils.config.get_cpg_namespace`."""
    # Imported on call, not at module import time.
    from cpg_utils.config import get_cpg_namespace as _impl

    return _impl(*args, **kwargs)