Source code for pfd.entrypoint.download

import logging
from typing import (
    Dict,
    List,
    Optional,
    Union,
)
from pathlib import Path
from dflow import Workflow, download_artifact

from .args import normalize as normalize_args
from .common import (
    global_config_workflow,
)

from pfd.utils.download_pfd_artifacts import (
    download_dpgen2_artifacts,
    download_dpgen2_artifacts_by_def,
)


# For each supported ``task.type``: the name of the workflow step that holds
# the final result ("steps") and the name of that step's output artifact
# which is downloaded as the resulting model ("artifact").
END_OUTPUT = dict(
    finetune=dict(steps="finetune", artifact="fine_tuned_model"),
    dist=dict(steps="distillation", artifact="dist_model"),
)


def download_end_result(
    workflow_id,
    wf_config: Optional[Dict] = None,
    prefix: Optional[str] = None,
):
    """Download the final model and iteration data of a finished workflow.

    The task type ``wf_config["task"]["type"]`` selects, through
    ``END_OUTPUT``, which workflow step holds the end result and which of
    its output artifacts is the model.

    Parameters
    ----------
    workflow_id
        ID of the dflow workflow to query.
    wf_config : dict, optional
        Input configuration; it must contain ``task.type``. ``None`` is
        treated as an empty configuration (and is therefore rejected).
    prefix : str, optional
        Local output directory, ``./results`` by default. The model artifact
        is downloaded to ``<prefix>/model`` and the step's ``iter_data``
        artifact to ``<prefix>/data``.

    Returns
    -------
    None
        Results are written to disk. If the final step cannot be found,
        or has not succeeded, a warning is logged and nothing is downloaded.

    Raises
    ------
    RuntimeError
        If ``task.type`` is missing or not a key of ``END_OUTPUT``, or if
        the workflow status is ``"Failed"``.
    """
    # ``None`` sentinel instead of a shared mutable ``{}`` default.
    if wf_config is None:
        wf_config = {}
    try:
        task_type = wf_config["task"]["type"]
    except KeyError as err:
        raise RuntimeError("Illegal input file!") from err
    # Reject unknown task types before contacting the workflow server.
    if task_type not in END_OUTPUT:
        raise RuntimeError(
            f"Unsupported task type {task_type!r}; "
            f"expected one of {sorted(END_OUTPUT)}"
        )
    end_output = END_OUTPUT[task_type]

    wf_config = normalize_args(wf_config)
    global_config_workflow(wf_config)

    wf = Workflow(id=workflow_id)
    step_info = wf.query()
    wf_status = wf.query_status()
    if wf_status == "Failed":
        raise RuntimeError(f"Workflow failed (ID: {wf.id}, UID: {wf.uid})")

    try:
        wf_post = step_info.get_step(name=end_output["steps"])[0]
    except IndexError:
        # No matching step was returned, so the final step is not available yet.
        logging.warning("The workflow may not have finished!")
        return

    if wf_post["phase"] != "Succeeded":
        # Previously this case returned silently; make it visible.
        logging.warning(
            "Final step %s is in phase %s; nothing downloaded.",
            end_output["steps"],
            wf_post["phase"],
        )
        return

    print(f"Workflow finished (ID: {wf.id}, UID: {wf.uid})")
    print("Retrieving completed tasks to local...")
    path = Path(prefix) if prefix is not None else Path("./results")

    # Download the output model, then the iteration data.
    download_artifact(
        artifact=wf_post.outputs.artifacts[end_output["artifact"]],
        path=path / "model",
    )
    download_artifact(
        artifact=wf_post.outputs.artifacts["iter_data"],
        path=path / "data",
    )


def download_by_def(
    workflow_id,
    wf_config: Optional[Dict] = None,
    iterations: Optional[List[int]] = None,
    step_defs: Optional[List[str]] = None,
    prefix: Optional[str] = None,
    chk_pnt: bool = False,
):
    """Download workflow artifacts selected by iteration and step definition.

    The configuration is normalized and applied as the global workflow
    configuration. Then the workflow ``workflow_id`` and the selection
    arguments are handed to ``download_dpgen2_artifacts_by_def``.

    Parameters
    ----------
    workflow_id
        ID of the dflow workflow to download from.
    wf_config : dict, optional
        Input configuration. ``None`` is treated as an empty configuration.
    iterations : list of int, optional
        Iterations to download; forwarded unchanged. (Presumably ``None``
        means "all" -- confirm in ``download_dpgen2_artifacts_by_def``.)
    step_defs : list of str, optional
        Step definitions to download; forwarded unchanged.
    prefix : str, optional
        Local output prefix; forwarded unchanged.
    chk_pnt : bool
        Forwarded unchanged to ``download_dpgen2_artifacts_by_def``.
    """
    # ``None`` sentinel instead of a shared mutable ``{}`` default.
    if wf_config is None:
        wf_config = {}
    wf_config = normalize_args(wf_config)
    global_config_workflow(wf_config)
    wf = Workflow(id=workflow_id)
    download_dpgen2_artifacts_by_def(wf, iterations, step_defs, prefix, chk_pnt)


def download(
    workflow_id,
    wf_config: Optional[Dict] = None,
    wf_keys: Optional[List] = None,
    prefix: Optional[str] = None,
    chk_pnt: bool = False,
):
    """Download the artifacts of the given workflow steps, one step key at a time.

    Parameters
    ----------
    workflow_id
        ID of the dflow workflow to download from.
    wf_config : dict, optional
        Input configuration. ``None`` is treated as an empty configuration.
    wf_keys : list, optional
        Step keys to download. When ``None``, every key reported by
        ``Workflow.query_keys_of_steps()`` is used.
    prefix : str, optional
        Local output prefix; forwarded to ``download_dpgen2_artifacts``.
    chk_pnt : bool
        Forwarded to ``download_dpgen2_artifacts``.

    Raises
    ------
    RuntimeError
        If ``wf_keys`` is ``None`` and the workflow query also returns ``None``.
    """
    # ``None`` sentinel instead of a shared mutable ``{}`` default.
    if wf_config is None:
        wf_config = {}
    wf_config = normalize_args(wf_config)
    global_config_workflow(wf_config)
    wf = Workflow(id=workflow_id)
    if wf_keys is None:
        wf_keys = wf.query_keys_of_steps()
        # Explicit check: an ``assert`` would be stripped under ``python -O``.
        if wf_keys is None:
            raise RuntimeError(
                f"Could not query the step keys of workflow {workflow_id}"
            )
    for key in wf_keys:
        download_dpgen2_artifacts(wf, key, prefix=prefix, chk_pnt=chk_pnt)
        logging.info("step %s downloaded", key)