import dis
import os
from copy import (
deepcopy,
)
from pathlib import (
Path,
)
from typing import (
Any,
Dict,
List,
Optional,
Set,
Type,
Union,
)
from dflow import (
InputArtifact,
InputParameter,
Inputs,
Outputs,
OPTemplate,
OutputArtifact,
OutputParameter,
Outputs,
Step,
Steps,
)
from dflow.python import (
OP,
PythonOPTemplate,
)
from pfd.op import ModelTestOP
from dpgen2.utils.step_config import init_executor
class Distillation(Steps):
    """Workflow super-OP chaining a perturbation step and a distillation loop.

    The constructor declares the inputs and outputs of the workflow and then
    delegates to ``_dist_cl``, which adds two steps to this ``Steps``
    instance: a ``<name>-pert-gen`` step built from ``pert_gen_op`` and a
    ``dist-loop`` step built from ``expl_dist_loop_op``. The two output
    artifacts (``dist_model`` and ``iter_data``) are sourced from the loop
    step.

    Parameters
    ----------
    name : str
        Name of the workflow; also used as the prefix of the perturbation
        step name.
    pert_gen_op : Type[OP]
        OP class wrapped in a ``PythonOPTemplate`` for the perturbation step.
    expl_dist_loop_op : OPTemplate
        Template used for the ``dist-loop`` step.
    pert_gen_config : dict
        Step configuration of the perturbation step. ``_dist_cl`` pops the
        ``"template_config"`` and ``"executor"`` keys from a deep copy of it
        and forwards the remaining entries to ``Step`` as keyword arguments.
    upload_python_packages : list of os.PathLike, optional
        Forwarded as ``python_packages`` to the perturbation step template.
    """

    def __init__(
        self,
        name: str,
        pert_gen_op: Type[OP],
        expl_dist_loop_op: OPTemplate,
        pert_gen_config: dict,
        upload_python_packages: Optional[List[os.PathLike]] = None,
    ):
        self._input_parameters = {
            # NOTE(review): "block_id" is declared here but is not forwarded
            # to any step by _dist_cl -- confirm whether that is intended.
            "block_id": InputParameter(),
            "type_map": InputParameter(),
            "mass_map": InputParameter(),
            # pert_gen
            "pert_config": InputParameter(),
            # exploration
            "scheduler": InputParameter(),
            "numb_models": InputParameter(type=int),
            "explore_config": InputParameter(),
            "converge_config": InputParameter(),
            "conf_filters_conv": InputParameter(),
            "test_size": InputParameter(),
            # training
            "template_script": InputParameter(),
            "train_config": InputParameter(),
            "type_map_train": InputParameter(),
            # other configurations
            "inference_config": InputParameter(),
        }
        self._input_artifacts = {
            "init_confs": InputArtifact(),
            "teacher_model": InputArtifact(),
            "init_data": InputArtifact(optional=True),
            "iter_data": InputArtifact(optional=True),  # empty list
            # NOTE(review): "validation_data" is declared but not passed to
            # the loop step by _dist_cl -- confirm whether that is intended.
            "validation_data": InputArtifact(optional=True),
        }
        # No output parameters are currently exposed.
        self._output_parameters = {}
        self._output_artifacts = {
            "dist_model": OutputArtifact(),
            "iter_data": OutputArtifact(),
        }

        super().__init__(
            name=name,
            inputs=Inputs(
                parameters=self._input_parameters,
                artifacts=self._input_artifacts,
            ),
            outputs=Outputs(
                parameters=self._output_parameters,
                artifacts=self._output_artifacts,
            ),
        )

        # _dist_cl adds the steps to this instance in place and returns the
        # very same object, so its return value is deliberately not bound:
        # rebinding the local name ``self`` would have no effect anyway.
        _dist_cl(
            self,
            name,
            pert_gen_op,
            expl_dist_loop_op,
            pert_gen_config,
            upload_python_packages=upload_python_packages,
        )

    @property
    def input_parameters(self) -> Dict[str, InputParameter]:
        """Input parameters declared by this workflow, keyed by name."""
        return self._input_parameters

    @property
    def input_artifacts(self) -> Dict[str, InputArtifact]:
        """Input artifacts declared by this workflow, keyed by name."""
        return self._input_artifacts

    @property
    def output_parameters(self) -> Dict[str, OutputParameter]:
        """Output parameters declared by this workflow (currently empty)."""
        return self._output_parameters

    @property
    def output_artifacts(self) -> Dict[str, OutputArtifact]:
        """Output artifacts declared by this workflow, keyed by name."""
        return self._output_artifacts
def _dist_cl(
    steps,
    name: str,
    pert_gen_op: Type[OP],
    expl_dist_loop_op: OPTemplate,
    pert_gen_step_config: dict,
    upload_python_packages: Optional[List[os.PathLike]] = None,
):
    """Add the perturbation step and the distillation loop to ``steps``.

    Parameters
    ----------
    steps
        Object exposing ``inputs``, ``outputs`` and ``add``; it is modified
        in place.
    name : str
        Prefix of the perturbation step name (``name + "-pert-gen"``).
    pert_gen_op : Type[OP]
        OP class wrapped in a ``PythonOPTemplate`` for the perturbation step.
    expl_dist_loop_op : OPTemplate
        Template of the ``dist-loop`` step.
    pert_gen_step_config : dict
        Must contain ``"template_config"`` and ``"executor"``; every other
        entry is passed to ``Step`` as a keyword argument. The caller's
        dictionary is not modified.
    upload_python_packages : list of os.PathLike, optional
        Passed as ``python_packages`` to the perturbation step template.

    Returns
    -------
    The same ``steps`` object that was passed in.

    Raises
    ------
    KeyError
        If ``"template_config"`` or ``"executor"`` is missing from
        ``pert_gen_step_config``.
    """
    # Pop from a private deep copy so the caller's configuration stays intact.
    step_kwargs = deepcopy(pert_gen_step_config)
    template_kwargs = step_kwargs.pop("template_config")
    executor = init_executor(step_kwargs.pop("executor"))

    pert_gen_template = PythonOPTemplate(
        pert_gen_op,
        python_packages=upload_python_packages,
        **template_kwargs
    )

    outer_params = steps.inputs.parameters
    outer_artifacts = steps.inputs.artifacts

    # Step 1: generate the perturbed starting systems.
    pert_gen = Step(
        name + "-pert-gen",
        template=pert_gen_template,
        parameters={"config": outer_params["pert_config"]},
        artifacts={"init_confs": outer_artifacts["init_confs"]},
        key="--".join(["init", "pert-gen"]),
        executor=executor,
        **step_kwargs
    )
    steps.add(pert_gen)

    # Step 2: the distillation loop. These workflow inputs are handed to the
    # loop under the same names.
    forwarded_parameters = (
        "type_map",
        "mass_map",
        "numb_models",
        "template_script",
        "train_config",
        "explore_config",
        "converge_config",
        "conf_filters_conv",
        "inference_config",
        "test_size",
        "type_map_train",
        "scheduler",
    )
    loop = Step(
        name="dist-loop",
        template=expl_dist_loop_op,
        parameters={
            param: outer_params[param] for param in forwarded_parameters
        },
        artifacts={
            # starting systems for model deviation
            "systems": pert_gen.outputs.artifacts["pert_sys"],
            "teacher_model": outer_artifacts["teacher_model"],
            # initial data for model finetune
            "init_data": outer_artifacts["init_data"],
            "iter_data": outer_artifacts["iter_data"],
        },
        # NOTE(review): this evaluates to the constant "test---loop"; no
        # workflow input (e.g. a block id) takes part in the key -- confirm
        # that a fixed key is intended.
        key="--".join(["%s" % "test", "-loop"]),
    )
    steps.add(loop)

    # Both workflow outputs are taken from the loop step.
    for artifact_name in ("dist_model", "iter_data"):
        steps.outputs.artifacts[artifact_name]._from = loop.outputs.artifacts[
            artifact_name
        ]
    return steps