├── .gitignore ├── .pre-commit-config.yaml ├── .vscode ├── extensions.json └── settings.json ├── LICENSE ├── README.rst ├── configs ├── ltl-strace │ └── iclr21 │ │ ├── t-sem-eval.json │ │ └── t.json ├── ltl-syn │ ├── dav23 │ │ ├── codet5-eval.json │ │ ├── codet5-sem-eval.json │ │ └── codet5.json │ ├── ht-ddp.json │ └── neurips21 │ │ ├── ht-eval.json │ │ ├── ht-sem-eval.json │ │ ├── ht.json │ │ ├── t-eval.json │ │ ├── t-sem-eval.json │ │ └── t.json └── prop-sat │ └── iclr21 │ └── t.json ├── docker ├── aalta │ ├── Dockerfile │ └── grpc-server.Dockerfile ├── abc_aiger │ ├── Dockerfile │ └── grpc-server.Dockerfile ├── avr │ └── Dockerfile ├── booleforce │ ├── Dockerfile │ └── grpc-server.Dockerfile ├── bosy │ └── grpc-server.Dockerfile ├── deps │ ├── cpu.Dockerfile │ ├── dev-cpu.Dockerfile │ ├── dev-gpu.Dockerfile │ └── gpu.Dockerfile ├── limboole │ ├── Dockerfile │ └── grpc-server.Dockerfile ├── ml2 │ ├── cpu.Dockerfile │ └── gpu.Dockerfile ├── neurosynt-grpc-server │ ├── cpu.Dockerfile │ └── gpu.Dockerfile ├── nusmv │ ├── Dockerfile │ └── grpc-server.Dockerfile ├── nuxmv │ └── grpc-server.Dockerfile ├── semml │ ├── Dockerfile │ └── grpc-server.Dockerfile ├── spot │ ├── Dockerfile │ └── grpc-server.Dockerfile ├── strix │ ├── Dockerfile │ ├── grpc-server.Dockerfile │ └── opt.Dockerfile └── syfco │ ├── Dockerfile │ └── grpc-server.Dockerfile ├── ml2 ├── __init__.py ├── aiger │ ├── __init__.py │ ├── aiger_circuit.py │ ├── aiger_tokenizer.py │ └── aiger_utils.py ├── artifact.py ├── configurable.py ├── data_gen │ ├── __init__.py │ ├── counting_data_gen_server.py │ ├── data_gen_args.py │ ├── data_gen_server.py │ ├── data_server.py │ ├── progress_actor.py │ └── progress_bar.py ├── datasets │ ├── __init__.py │ ├── csv_dataset.py │ ├── csv_dataset_writer.py │ ├── dataset.py │ ├── dataset_writer.py │ ├── generator_dataset.py │ ├── load_dataset.py │ ├── split_dataset.py │ ├── split_dataset_writer.py │ ├── stats.py │ └── utils.py ├── dtypes │ ├── __init__.py │ ├── binary_ast.py │ ├── binary_expr.py │ ├── cat.py │ ├── cat_seq.py │ ├── csv_dict.py │ ├── csv_dtype.py │ ├── csv_dtype_with_id.py │ ├── decomp_binary_expr.py │ ├── decomp_binary_expr_pair.py │ ├── decomp_dtype.py │ ├── dtype.py │ ├── hashable.py │ ├── pair.py │ ├── seq.py │ ├── string.py │ ├── supervised.py │ ├── tree.py │ └── validation_result.py ├── experiment │ ├── __init__.py │ ├── experiment.py │ └── run.py ├── gcp_bucket.py ├── globals.py ├── grpc │ ├── __init__.py │ ├── aalta │ │ ├── __init__.py │ │ ├── aalta.proto │ │ ├── aalta_pb2.py │ │ ├── aalta_pb2.pyi │ │ └── aalta_pb2_grpc.py │ ├── abc_aiger │ │ ├── __init__.py │ │ ├── abc_aiger.proto │ │ ├── abc_aiger_pb2.py │ │ ├── abc_aiger_pb2.pyi │ │ └── abc_aiger_pb2_grpc.py │ ├── aiger │ │ ├── __init__.py │ │ ├── aiger.proto │ │ ├── aiger_pb2.py │ │ ├── aiger_pb2.pyi │ │ └── aiger_pb2_grpc.py │ ├── booleforce │ │ ├── __init__.py │ │ ├── booleforce.proto │ │ ├── booleforce_pb2.py │ │ ├── booleforce_pb2.pyi │ │ └── booleforce_pb2_grpc.py │ ├── bosy │ │ ├── __init__.py │ │ ├── bosy.proto │ │ ├── bosy_pb2.py │ │ ├── bosy_pb2.pyi │ │ └── bosy_pb2_grpc.py │ ├── limboole │ │ ├── __init__.py │ │ ├── limboole.proto │ │ ├── limboole_pb2.py │ │ ├── limboole_pb2.pyi │ │ └── limboole_pb2_grpc.py │ ├── ltl │ │ ├── __init__.py │ │ ├── ltl.proto │ │ ├── ltl_equiv.proto │ │ ├── ltl_equiv_pb2.py │ │ ├── ltl_equiv_pb2.pyi │ │ ├── ltl_equiv_pb2_grpc.py │ │ ├── ltl_mc.proto │ │ ├── ltl_mc_pb2.py │ │ ├── ltl_mc_pb2.pyi │ │ ├── ltl_mc_pb2_grpc.py │ │ ├── ltl_pb2.py │ │ ├── ltl_pb2.pyi │ │ ├── ltl_pb2_grpc.py │ │ 
├── ltl_sat.proto │ │ ├── ltl_sat_pb2.py │ │ ├── ltl_sat_pb2.pyi │ │ ├── ltl_sat_pb2_grpc.py │ │ ├── ltl_syn.proto │ │ ├── ltl_syn_pb2.py │ │ ├── ltl_syn_pb2.pyi │ │ ├── ltl_syn_pb2_grpc.py │ │ ├── ltl_trace_mc.proto │ │ ├── ltl_trace_mc_pb2.py │ │ ├── ltl_trace_mc_pb2.pyi │ │ └── ltl_trace_mc_pb2_grpc.py │ ├── mealy │ │ ├── __init__.py │ │ ├── mealy.proto │ │ ├── mealy_pb2.py │ │ ├── mealy_pb2.pyi │ │ └── mealy_pb2_grpc.py │ ├── neurosynt │ │ ├── __init__.py │ │ ├── neurosynt.proto │ │ ├── neurosynt_pb2.py │ │ ├── neurosynt_pb2.pyi │ │ └── neurosynt_pb2_grpc.py │ ├── nusmv │ │ ├── __init__.py │ │ ├── nusmv.proto │ │ ├── nusmv_pb2.py │ │ ├── nusmv_pb2.pyi │ │ └── nusmv_pb2_grpc.py │ ├── nuxmv │ │ ├── __init__.py │ │ ├── nuxmv.proto │ │ ├── nuxmv_pb2.py │ │ ├── nuxmv_pb2.pyi │ │ └── nuxmv_pb2_grpc.py │ ├── prop │ │ ├── __init__.py │ │ ├── prop.proto │ │ ├── prop_pb2.py │ │ ├── prop_pb2.pyi │ │ └── prop_pb2_grpc.py │ ├── semml │ │ ├── __init__.py │ │ ├── semml.proto │ │ ├── semml_pb2.py │ │ ├── semml_pb2.pyi │ │ └── semml_pb2_grpc.py │ ├── spot │ │ ├── __init__.py │ │ ├── spot.proto │ │ ├── spot_pb2.py │ │ ├── spot_pb2.pyi │ │ └── spot_pb2_grpc.py │ ├── strix │ │ ├── __init__.py │ │ ├── strix.proto │ │ ├── strix_pb2.py │ │ ├── strix_pb2.pyi │ │ └── strix_pb2_grpc.py │ ├── syfco │ │ ├── __init__.py │ │ ├── syfco.proto │ │ ├── syfco_pb2.py │ │ ├── syfco_pb2.pyi │ │ └── syfco_pb2_grpc.py │ ├── system │ │ ├── __init__.py │ │ ├── system.proto │ │ ├── system_pb2.py │ │ ├── system_pb2.pyi │ │ └── system_pb2_grpc.py │ ├── tools │ │ ├── __init__.py │ │ ├── tools.proto │ │ ├── tools_pb2.py │ │ ├── tools_pb2.pyi │ │ └── tools_pb2_grpc.py │ └── trace │ │ ├── __init__.py │ │ ├── trace.proto │ │ ├── trace_pb2.py │ │ ├── trace_pb2.pyi │ │ └── trace_pb2_grpc.py ├── layers │ ├── __init__.py │ ├── attention.py │ └── positional_encoding.py ├── loading.py ├── ltl │ ├── __init__.py │ ├── ltl_equiv │ │ ├── __init__.py │ │ ├── decomp_ltl_equiv_problem.py │ │ ├── ltl_equiv_problem.py │ │ ├── ltl_equiv_status.py │ │ └── ltl_incl_status.py │ ├── ltl_formula.py │ ├── ltl_lexer.py │ ├── ltl_mc │ │ ├── __init__.py │ │ ├── decomp_ltl_mc_problem.py │ │ ├── ltl_mc_problem.py │ │ └── ltl_mc_status.py │ ├── ltl_parser.py │ ├── ltl_sat │ │ ├── __init__.py │ │ ├── decomp_ltl_sym_trace_problem.py │ │ ├── ltl_sat_dataset.py │ │ ├── ltl_sat_problem.py │ │ ├── ltl_sat_status.py │ │ ├── ltl_sat_sym_trace_problem.py │ │ ├── ltl_sat_trace_problem.py │ │ ├── ltl_sym_trace_problem.py │ │ └── ltl_trace_problem.py │ ├── ltl_spec │ │ ├── __init__.py │ │ ├── decomp_ltl_spec.py │ │ ├── decomp_ltl_spec_tokenizer.py │ │ ├── ltl_spec.py │ │ ├── ltl_spec_csv_dataset.py │ │ ├── ltl_spec_dataset.py │ │ ├── ltl_spec_pattern_csv_dataset.py │ │ ├── ltl_spec_pattern_dataset.py │ │ ├── ltl_spec_patterns │ │ │ ├── __init__.py │ │ │ ├── ltl_spec_pattern_grammar.py │ │ │ ├── ltl_spec_pattern_pcfg.py │ │ │ ├── ltl_spec_pattern_sampler.py │ │ │ └── ltl_spec_patterns.py │ │ └── ltl_spec_tokenizer.py │ └── ltl_syn │ │ ├── __init__.py │ │ ├── decomp_ltl_syn_problem.py │ │ ├── ltl_real_status.py │ │ ├── ltl_syn_data_gen_common.py │ │ ├── ltl_syn_data_gen_drop_repair.py │ │ ├── ltl_syn_data_gen_patterns.py │ │ ├── ltl_syn_data_gen_specs.py │ │ ├── ltl_syn_dataset.py │ │ ├── ltl_syn_eval_dataset.py │ │ ├── ltl_syn_problem.py │ │ ├── ltl_syn_solution_tokenizer.py │ │ ├── ltl_syn_status.py │ │ └── tf_syn_hier_transformer_pipeline.py ├── mealy │ ├── __init__.py │ ├── mealy_machine.py │ └── mealy_tokenizer.py ├── models │ ├── __init__.py │ ├── beam_search.py │ ├── 
tf_hierarchical_transformer.py │ ├── tf_transformer.py │ └── tf_transformer_metrics.py ├── optim │ ├── __init__.py │ └── tf_optim │ │ ├── __init__.py │ │ ├── tf_optimizers.py │ │ └── tf_transformer_lr_schedule.py ├── pipelines │ ├── __init__.py │ ├── beam_search_verification_pipeline.py │ ├── callbacks │ │ ├── __init__.py │ │ └── callback.py │ ├── hf_pipelines │ │ ├── __init__.py │ │ ├── hf_pt_expr2expr_pipeline.py │ │ ├── hf_pt_expr2text_pipeline.py │ │ └── hf_pt_text2text_pipeline.py │ ├── load_pipeline.py │ ├── loggers │ │ ├── __init__.py │ │ ├── csv_dataset_logger.py │ │ ├── csv_logger.py │ │ └── sample_logger.py │ ├── metrics │ │ ├── __init__.py │ │ ├── acc.py │ │ ├── acc_per_seq.py │ │ ├── counter.py │ │ ├── data_type_acc.py │ │ ├── equiv_acc.py │ │ ├── err_counter.py │ │ ├── load_metric.py │ │ ├── metric.py │ │ ├── metric_avg.py │ │ ├── metric_group.py │ │ ├── null_metric.py │ │ ├── sem_acc.py │ │ ├── sem_beam_acc.py │ │ ├── str_acc.py │ │ └── ver_status.py │ ├── model_pipeline.py │ ├── pipeline.py │ ├── samples │ │ ├── __init__.py │ │ ├── beam_search_sample.py │ │ ├── eval_sample.py │ │ ├── labeled_sample.py │ │ ├── portfolio_sample.py │ │ ├── sample.py │ │ └── verified_sample.py │ ├── seq2seq_pipeline.py │ ├── sl_pipeline.py │ ├── tf_hier_transformer_pipeline.py │ ├── tf_pipeline.py │ ├── tf_sl_pipeline.py │ ├── tf_transformer_pipeline.py │ └── verification_pipeline.py ├── prop │ ├── __init__.py │ ├── assignment.py │ ├── assignment_check_status.py │ ├── cnf │ │ ├── __init__.py │ │ ├── binarize_search_dataset.py │ │ ├── cnf_assign_problem.py │ │ ├── cnf_assignment.py │ │ ├── cnf_formula.py │ │ ├── cnf_formula_tokenizer.py │ │ ├── cnf_res_problem.py │ │ ├── cnf_sat_problem.py │ │ ├── cnf_sat_search_problem.py │ │ ├── convert_to_binary_res_proof.py │ │ ├── neurosat_data_gen_central.py │ │ ├── neurosat_data_gen_common.py │ │ ├── neurosat_data_gen_decentral.py │ │ ├── res_completion_problem.py │ │ ├── res_data_gen_central.py │ │ ├── res_data_gen_common.py │ │ ├── res_data_gen_decentral.py │ │ ├── res_proof.py │ │ ├── res_proof_status.py │ │ └── res_proof_tokenizer.py │ ├── prop_formula.py │ ├── prop_lexer.py │ ├── prop_parser.py │ ├── prop_sat_dataset.py │ ├── prop_sat_problem.py │ ├── prop_sat_status.py │ └── prop_valid_status.py ├── registry.py ├── tokenizers │ ├── __init__.py │ ├── cat_seq_tokenizers │ │ ├── __init__.py │ │ └── cat_seq_to_seq_tokenizer.py │ ├── cat_tokenizers │ │ ├── __init__.py │ │ └── cat_to_id_tokenizer.py │ ├── decomp_dtype_tokenizers │ │ ├── __init__.py │ │ ├── decomp_dtype_to_decomp_seq_pos_tokenizer.py │ │ └── decomp_dtype_to_decomp_seq_tokenizer.py │ ├── decomp_expr_pair_tokenizers │ │ ├── __init__.py │ │ └── decomp_expr_pair_to_decomp_seq_tpe_tokenizer.py │ ├── decomp_expr_tokenizers │ │ ├── __init__.py │ │ └── decomp_expr_to_decomp_seq_tpe_tokenizer.py │ ├── expr_tokenizers │ │ ├── __init__.py │ │ ├── expr_to_seq_tokenizer.py │ │ └── expr_to_seq_tpe_tokenizer.py │ ├── load_tokenizer.py │ ├── load_vocabulary.py │ ├── pair_tokenizers │ │ ├── __init__.py │ │ ├── cat_seq_pair_to_seq_tokenizer.py │ │ └── pair_to_seq_tokenizer.py │ ├── seq_tokenizers │ │ ├── __init__.py │ │ └── seq_to_seq_tokenizer.py │ ├── to_decomp_seq_pos_tokenizer.py │ ├── to_decomp_seq_tokenizer.py │ ├── to_id_tokenizer.py │ ├── to_seq_mask_tokenizer.py │ ├── to_seq_pos_tokenizer.py │ ├── to_seq_tokenizer.py │ ├── to_seq_tpe_tokenizer.py │ ├── tokenizer.py │ └── vocabulary.py ├── tools │ ├── __init__.py │ ├── aalta │ │ ├── __init__.py │ │ ├── aalta.py │ │ ├── aalta_grpc_server.py │ │ └── 
aalta_wrapper.py │ ├── abc_aiger │ │ ├── __init__.py │ │ ├── abc_aiger.py │ │ ├── abc_aiger_grpc_server.py │ │ ├── abc_wrapper.py │ │ ├── aiger_wrapper.py │ │ ├── graphviz_wrapper.py │ │ └── wrapper_helper.py │ ├── booleforce │ │ ├── __init__.py │ │ ├── booleforce.py │ │ ├── booleforce_grpc_server.py │ │ ├── booleforce_worker.py │ │ └── booleforce_wrapper.py │ ├── bosy │ │ ├── __init__.py │ │ ├── bosy.py │ │ ├── bosy_grpc_server.py │ │ ├── bosy_worker.py │ │ └── bosy_wrapper.py │ ├── grpc_service.py │ ├── limboole │ │ ├── __init__.py │ │ ├── limboole.py │ │ ├── limboole_grpc_server.py │ │ └── limboole_wrapper.py │ ├── ltl_tool │ │ ├── __init__.py │ │ ├── generic_model_checker.py │ │ ├── generic_synthesis_tool.py │ │ ├── pb2_converter.py │ │ ├── tool_ltl_conversion.py │ │ ├── tool_ltl_mc_problem.py │ │ └── tool_ltl_syn_problem.py │ ├── neurosynt │ │ ├── __init__.py │ │ ├── neurosynt.py │ │ ├── neurosynt_grpc_server.py │ │ └── pipeline_wrapper.py │ ├── nusmv │ │ ├── __init__.py │ │ ├── nusmv.py │ │ ├── nusmv_grpc_server.py │ │ └── nusmv_wrapper.py │ ├── nuxmv │ │ ├── __init__.py │ │ ├── nuxmv.py │ │ ├── nuxmv_grpc_server.py │ │ └── nuxmv_wrapper.py │ ├── semml │ │ ├── __init__.py │ │ ├── semml.py │ │ ├── semml_grpc_server.py │ │ └── semml_wrapper.py │ ├── spot │ │ ├── __init__.py │ │ ├── spot.py │ │ ├── spot_aiger_mc.py │ │ ├── spot_equiv_verifier.py │ │ ├── spot_grpc_server.py │ │ ├── spot_strace_mc.py │ │ └── spot_wrapper.py │ ├── strix │ │ ├── __init__.py │ │ ├── strix.py │ │ ├── strix_grpc_server.py │ │ ├── strix_worker.py │ │ └── strix_wrapper.py │ └── syfco │ │ ├── __init__.py │ │ ├── syfco.py │ │ ├── syfco_grpc_server.py │ │ └── tlsf_to_bosy.py ├── trace │ ├── __init__.py │ ├── sym_trace_to_seq_tokenizer.py │ ├── symbolic_trace.py │ ├── trace.py │ └── trace_mc_status.py ├── train │ ├── __init__.py │ ├── hf_seq2seq_trainer.py │ ├── keras_callbacks.py │ ├── keras_trainer.py │ ├── keras_trainer_ddp.py │ ├── keras_transformer_trainer.py │ ├── load_trainer.py │ └── trainer.py ├── utils │ ├── __init__.py │ ├── dict_utils.py │ ├── dist_utils.py │ ├── import_utils.py │ ├── list_utils.py │ ├── np_utils.py │ ├── pt_utils.py │ ├── tf_utils.py │ └── typing_utils.py └── verifier │ ├── __init__.py │ ├── equiv_status.py │ ├── equiv_verifier.py │ ├── load_verifier.py │ └── verifier.py ├── notebooks ├── aiger │ └── aiger_circuits.ipynb ├── datasets │ └── split_dataset.ipynb ├── ltl │ ├── ltl.ipynb │ ├── ltl_spec │ │ ├── decomp_ltl_spec.ipynb │ │ ├── ltl_spec_pattern_csv_dataset.ipynb │ │ └── ltl_spec_pattern_grammar.ipynb │ └── ltl_syn │ │ ├── ltl_syn_dataset.ipynb │ │ ├── ltl_syn_hier_pipe.ipynb │ │ └── ltl_syn_solution_tokenization.ipynb ├── mealy │ └── mealy_machine.ipynb ├── models │ └── count_params.ipynb ├── pipeline │ ├── tf_seq2seq_pipe_ltl_syn.ipynb │ └── tf_transformer_pipe_ltl_syn.ipynb ├── tokenizer │ └── expr_tokenizer.ipynb ├── tools │ ├── bosy.ipynb │ ├── neurosynt.ipynb │ ├── nusmv.ipynb │ ├── nuxmv.ipynb │ ├── strix.ipynb │ └── syfco.ipynb ├── trace │ └── symbolic_trace.ipynb └── train │ ├── keras_trainer_ltl_syn.ipynb │ └── keras_trainer_ltl_traces.ipynb ├── pyproject.toml ├── scripts ├── docker.sh └── protoc.sh ├── setup.cfg └── tests ├── aiger └── aiger_test.py ├── datasets └── csv_dataset_writer_test.py ├── gcp_bucket_test.py ├── layers └── attention_test.py ├── ltl ├── ltl_equiv │ └── ltl_equiv_test.py ├── ltl_formula_test.py ├── ltl_spec │ ├── decomp_ltl_spec_test.py │ ├── load_jarvis_test.py │ └── load_sc_test.py └── ltl_syn │ ├── load_scpa_test.py │ └── 
ltl_syn_solution_tokenizer_config_test.py ├── pipelines ├── metrics │ ├── acc_per_seq_test.py │ ├── acc_test.py │ └── err_counter_test.py ├── tf_transformer_pipeline_test.py └── tf_transformer_pipeline_test_config.json ├── prop ├── assignment_test.py ├── cnf │ ├── cnf_formula_test.py │ ├── cnf_sat_search_problem_test.py │ ├── formula.dimacs │ └── res_proof_test.py └── prop_formula_test.py ├── pytest.ini ├── tf_test.py ├── tokenizers ├── cat_seq_tokenizer │ └── cat_seq_to_seq_tokenizer_config_test.py ├── decomp_expr_pair_tokenizers │ └── decomp_expr_pair_to_decomp_seq_tpe_tokenizer_config_test.py ├── decomp_expr_tokenizers │ └── decomp_expr_to_decomp_seq_tpe_tokenizer_config_test.py ├── expr_tokenizer │ ├── expr_to_seq_tokenizer_config_test.py │ └── expr_to_seq_tpe_tokenizer_config_test.py ├── load_vocab_test.py ├── seq_tokenizer │ └── seq_to_seq_tokenizer_config_test.py └── tokenizer_config_test.py ├── tools ├── aalta_test.py ├── abc_aiger_test.py ├── booleforce_test.py ├── bosy_test.py ├── conftest.py ├── limboole_test.py ├── ltl_conversion_test.py ├── ltl_mc_test.py ├── ltl_syn_test.py ├── neurosynt_test.py ├── nusmv_test.py ├── nuxmv_test.py ├── semml_test.py ├── spot_test.py ├── strix_test.py └── syfco_test.py ├── trace ├── sym_trace_to_seq_tokenizer_config_test.py └── trace_test.py └── utils └── typing_utils_test.py /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: 21.10b0 4 | hooks: 5 | - id: black 6 | language_version: python3.8 7 | - repo: https://gitlab.com/pycqa/flake8 8 | rev: 3.9.2 9 | hooks: 10 | - id: flake8 11 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-python.black-formatter", 4 | "ms-python.flake8", 5 | "ms-python.isort", 6 | "ms-python.python" 7 | ] 8 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "clang-format.style": "{ IndentWidth : 4 }", 3 | "flake8.severity": { 4 | "E": "Warning", 5 | "F": "Warning" 6 | }, 7 | "isort.args": [ 8 | "--profile", 9 | "black" 10 | ], 11 | "[python]": { 12 | "editor.codeActionsOnSave": { 13 | "source.organizeImports": "explicit" 14 | }, 15 | "editor.defaultFormatter": "ms-python.black-formatter", 16 | "editor.formatOnSave": true, 17 | }, 18 | "python.testing.unittestEnabled": false, 19 | "python.testing.pytestEnabled": true, 20 | "python.testing.pytestArgs": [ 21 | "tests" 22 | ], 23 | "python.analysis.typeCheckingMode": "basic" 24 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 CISPA Helmholtz Center for Information Security 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this 
permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /configs/ltl-strace/iclr21/t-sem-eval.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "t-0-vp-eval", 3 | "auto_version": true, 4 | "project": "ltl-strace", 5 | "evaluation": [ 6 | { 7 | "type": "SupervisedEvalTask", 8 | "batch_size": 32, 9 | "dataset": { 10 | "base": "ltl-strace/rft-0/test", 11 | "sample": 256 12 | }, 13 | "pipeline": { 14 | "type": "BeamSearchVerificationPipeline", 15 | "verifier": { 16 | "type": "SpotSTraceMC", 17 | "start_containerized_service": true 18 | }, 19 | "pipeline": { 20 | "base": "ltl-strace/t-0/train/pipe", 21 | "beam_size": [3] 22 | } 23 | } 24 | } 25 | ] 26 | } 27 | -------------------------------------------------------------------------------- /configs/ltl-syn/dav23/codet5-eval.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "codet5-25-eval", 3 | "auto_version": true, 4 | "project": "ltl-syn", 5 | "upload": false, 6 | "evaluation": [ 7 | { 8 | "type": "SupervisedEvalTask", 9 | "batch_size": 256, 10 | "dataset": { 11 | "base": "ltl-syn/scpa-2/test", 12 | "sample": 1024 13 | }, 14 | "pipeline": { 15 | "base": "ltl-syn/codet5-25/train/pipe", 16 | "beam_size": 1 17 | } 18 | } 19 | ] 20 | } 21 | -------------------------------------------------------------------------------- /configs/ltl-syn/dav23/codet5-sem-eval.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "codet5-25-vp-eval", 3 | "auto_version": true, 4 | "project": "ltl-syn", 5 | "upload": false, 6 | "evaluation": [ 7 | { 8 | "type": "EvalTask", 9 | "batch_size": 256, 10 | "dataset": { 11 | "base": "ltl-spec/sc-0", 12 | "sample": 1024 13 | }, 14 | "pipeline": { 15 | "type": "BeamSearchVerificationPipeline", 16 | "verifier": { 17 | "type": "NuxmvMC", 18 | "start_containerized_service": true 19 | }, 20 | "pipeline": { 21 | "base": "ltl-syn/codet5-25/train/pipe", 22 | "beam_size": 1 23 | } 24 | } 25 | } 26 | ] 27 | } 28 | -------------------------------------------------------------------------------- /configs/ltl-syn/dav23/codet5.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "codet5", 3 | "auto_version": true, 4 | "project": "ltl-syn", 5 | "upload": false, 6 | "pipeline": { 7 | "type": "HFPTExpr2TextPipeline", 8 | "hf_checkpoint_name": "Salesforce/codet5-small", 9 | "input_dtype": "LTLSpec", 10 | "target_dtype": "LTLSynSolution", 11 | "hf_input_tokenizer": "Salesforce/codet5-small", 12 | "hf_target_tokenizer": "Salesforce/codet5-small", 13 | "max_input_length": 256, 14 | "max_target_length": 192, 15 | "input_notation": "prefix" 16 | }, 17 | "trainer": { 18 | "type": "HFSeq2SeqTrainer", 19 | "train_dataset": "ltl-syn/scpa-2/train", 20 | "val_dataset": { 21 | "base": "ltl-syn/scpa-2/val", 22 | "sample": 512 23 | }, 24 
| "batch_size": 128, 25 | "learning_rate": 0.0005, 26 | "steps": 20000, 27 | "stream_to_wandb": false, 28 | "val_freq": 500 29 | }, 30 | "evaluation": [ 31 | { 32 | "type": "SupervisedEvalTask", 33 | "batch_size": 64, 34 | "dataset": { 35 | "base": "ltl-syn/scpa-2/test", 36 | "sample": 1024 37 | }, 38 | "pipeline": { 39 | "beam_size": 2 40 | } 41 | } 42 | ] 43 | } 44 | -------------------------------------------------------------------------------- /configs/ltl-syn/neurips21/ht-eval.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ht-0-eval", 3 | "auto_version": true, 4 | "project": "ltl-syn", 5 | "evaluation": [ 6 | { 7 | "type": "SupervisedEvalTask", 8 | "batch_size": 32, 9 | "dataset": { 10 | "base": "ltl-syn/scpa-2/test", 11 | "sample": 1024 12 | }, 13 | "pipeline": { 14 | "base": "ltl-syn/ht-0/train/pipe", 15 | "beam_size": [1, 4, 8, 16] 16 | } 17 | } 18 | ] 19 | } 20 | -------------------------------------------------------------------------------- /configs/ltl-syn/neurips21/ht-sem-eval.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ht-0-vp-eval", 3 | "auto_version": true, 4 | "project": "ltl-syn", 5 | "evaluation": [ 6 | { 7 | "type": "SupervisedEvalTask", 8 | "batch_size": 32, 9 | "dataset": { 10 | "base": "ltl-syn/scpa-2/test", 11 | "sample": 1024 12 | }, 13 | "pipeline": { 14 | "type": "BeamSearchVerificationPipeline", 15 | "verifier": { 16 | "type": "NuxmvMC", 17 | "start_containerized_service": true 18 | }, 19 | "pipeline": { 20 | "base": "ltl-syn/ht-0/train/pipe", 21 | "beam_size": [1, 4, 8, 16] 22 | } 23 | } 24 | } 25 | ] 26 | } 27 | -------------------------------------------------------------------------------- /configs/ltl-syn/neurips21/t-eval.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "t-0-eval", 3 | "auto_version": true, 4 | "project": "ltl-syn", 5 | "evaluation": [ 6 | { 7 | "type": "SupervisedEvalTask", 8 | "batch_size": 32, 9 | "dataset": { 10 | "base": "ltl-syn/scpa-2/test", 11 | "sample": 1024 12 | }, 13 | "pipeline": { 14 | "base": "ltl-syn/t-0/train/pipe", 15 | "beam_size": [1, 4, 8, 16] 16 | } 17 | } 18 | ] 19 | } 20 | -------------------------------------------------------------------------------- /configs/ltl-syn/neurips21/t-sem-eval.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "t-0-vp-eval", 3 | "auto_version": true, 4 | "project": "ltl-syn", 5 | "evaluation": [ 6 | { 7 | "type": "SupervisedEvalTask", 8 | "batch_size": 32, 9 | "dataset": { 10 | "base": "ltl-syn/scpa-2/test", 11 | "sample": 1024 12 | }, 13 | "pipeline": { 14 | "type": "BeamSearchVerificationPipeline", 15 | "verifier": { 16 | "type": "NuxmvMC", 17 | "start_containerized_service": true 18 | }, 19 | "pipeline": { 20 | "base": "ltl-syn/t-0/train/pipe", 21 | "beam_size": [1, 4, 8, 16] 22 | } 23 | } 24 | } 25 | ] 26 | } 27 | -------------------------------------------------------------------------------- /docker/aalta/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 2 | 3 | RUN export DEBIAN_FRONTEND=noninteractive && \ 4 | apt-get -q update && \ 5 | apt-get -q install -y \ 6 | build-essential \ 7 | git \ 8 | zlib1g-dev 9 | 10 | RUN git clone https://bitbucket.org/jl86/aalta.git && \ 11 | cd aalta && \ 12 | make release -------------------------------------------------------------------------------- 
/docker/aalta/grpc-server.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CONTAINER_REGISTRY=ghcr.io/reactive-systems/ml2 2 | 3 | FROM $CONTAINER_REGISTRY/aalta 4 | 5 | ARG PYTHON_VERSION=3.8.8 6 | 7 | RUN apt-get -q update && \ 8 | apt-get -q install -y \ 9 | libbz2-dev \ 10 | libssl-dev \ 11 | libffi-dev \ 12 | openssl \ 13 | wget 14 | 15 | RUN wget https://www.python.org/ftp/python/$PYTHON_VERSION/Python-$PYTHON_VERSION.tgz && \ 16 | tar xzf Python-$PYTHON_VERSION.tgz && \ 17 | cd Python-$PYTHON_VERSION && \ 18 | ./configure --enable-optimizations && \ 19 | make && make install && \ 20 | cd / && \ 21 | rm Python-$PYTHON_VERSION.tgz && \ 22 | rm -r Python-$PYTHON_VERSION 23 | 24 | COPY ml2 ml2/ml2 25 | COPY LICENSE pyproject.toml setup.cfg ml2/ 26 | RUN pip3 install --no-cache-dir --upgrade pip && \ 27 | pip3 install --no-cache-dir ml2/ 28 | 29 | 30 | ENTRYPOINT [ "python3", "-m", "ml2.tools.aalta.aalta_grpc_server"] -------------------------------------------------------------------------------- /docker/abc_aiger/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:22.04 2 | 3 | RUN apt-get -q update \ 4 | && apt-get -q upgrade -y \ 5 | && apt-get install git build-essential libreadline-dev graphviz -y 6 | 7 | RUN git clone https://github.com/berkeley-abc/abc.git \ 8 | && cd abc \ 9 | && make -j 4 10 | 11 | RUN git clone https://github.com/arminbiere/lingeling.git && \ 12 | cd lingeling && \ 13 | ./configure.sh && \ 14 | make -j 4 15 | 16 | RUN git clone https://github.com/arminbiere/aiger.git \ 17 | && cd aiger \ 18 | && ./configure.sh \ 19 | && make -j 4 20 | 21 | 22 | -------------------------------------------------------------------------------- /docker/abc_aiger/grpc-server.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CONTAINER_REGISTRY=ghcr.io/reactive-systems/ml2 2 | 3 | FROM $CONTAINER_REGISTRY/abc_aiger 4 | 5 | RUN apt-get -q update && \ 6 | apt-get -q install -y \ 7 | libbz2-dev \ 8 | libssl-dev \ 9 | openssl \ 10 | wget \ 11 | python3.11 12 | 13 | COPY ml2 ml2/ml2 14 | COPY LICENSE pyproject.toml setup.cfg ml2/ 15 | RUN wget https://bootstrap.pypa.io/get-pip.py && \ 16 | python3.11 get-pip.py && \ 17 | python3.11 -m pip install --no-cache-dir --upgrade pip && \ 18 | python3.11 -m pip install pydot && \ 19 | # python3.11 -m pip install --no-cache-dir ml2/ 20 | python3.11 -m pip install -e ml2/ 21 | 22 | 23 | ENTRYPOINT [ "python3.11", "-m", "ml2.tools.abc_aiger.abc_aiger_grpc_server"] 24 | 25 | -------------------------------------------------------------------------------- /docker/avr/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:jammy 2 | 3 | RUN apt-get -q update && \ 4 | DEBIAN_FRONTEND=noninteractive \ 5 | apt-get -q install -y \ 6 | git \ 7 | python-is-python3 \ 8 | sudo \ 9 | tzdata \ 10 | wget 11 | 12 | RUN git clone https://github.com/aman-goel/avr.git 13 | 14 | RUN cd avr && \ 15 | chmod +x build.sh && \ 16 | chmod +x deps/build_deps.sh && \ 17 | chmod +x build/avr && \ 18 | yes | ./build.sh -------------------------------------------------------------------------------- /docker/booleforce/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:20.04 2 | 3 | RUN export DEBIAN_FRONTEND=noninteractive && \ 4 | apt-get -q update && \ 5 | apt-get -q install -y \ 6 | build-essential \ 7 | wget 8 
| 9 | RUN wget http://fmv.jku.at/booleforce/booleforce-1.3.tar.gz && \ 10 | tar xzf booleforce-1.3.tar.gz && \ 11 | rm booleforce-1.3.tar.gz && \ 12 | cd booleforce-1.3 && \ 13 | ./configure --trace && \ 14 | make -------------------------------------------------------------------------------- /docker/booleforce/grpc-server.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CONTAINER_REGISTRY=ghcr.io/reactive-systems/ml2 2 | 3 | FROM $CONTAINER_REGISTRY/booleforce 4 | 5 | RUN apt-get -q update && \ 6 | apt-get -q install -y \ 7 | python3 \ 8 | python3-pip 9 | 10 | COPY ml2 ml2/ml2 11 | COPY LICENSE pyproject.toml setup.cfg ml2/ 12 | RUN pip3 install --no-cache-dir --upgrade pip && \ 13 | pip3 install --no-cache-dir ml2/ 14 | 15 | 16 | ENTRYPOINT [ "python3", "-m", "ml2.tools.booleforce.booleforce_grpc_server"] -------------------------------------------------------------------------------- /docker/bosy/grpc-server.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CONTAINER_REGISTRY=ghcr.io/reactive-systems/ml2 2 | 3 | FROM $CONTAINER_REGISTRY/bosy 4 | 5 | RUN apt-get -q update && \ 6 | apt-get -q install -y \ 7 | libbz2-dev \ 8 | libssl-dev \ 9 | openssl \ 10 | wget \ 11 | python3 \ 12 | python3-pip 13 | 14 | # copying ml2 files into BoSy directory as BoSy can only be started in its own directory 15 | 16 | WORKDIR /root/bosy 17 | 18 | COPY ml2 ml2/ml2 19 | COPY LICENSE pyproject.toml setup.cfg ml2/ 20 | RUN pip3 install --no-cache-dir --upgrade pip && \ 21 | pip3 install --no-cache-dir ml2/ 22 | 23 | ENTRYPOINT [ "python3", "-m", "ml2.tools.bosy.bosy_grpc_server"] 24 | 25 | -------------------------------------------------------------------------------- /docker/deps/cpu.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM --platform=amd64 tensorflow/tensorflow:2.16.2 2 | 3 | # Google Cloud SDK 4 | 5 | # ENV CLOUDSDK_PYTHON=/usr/bin/python3 6 | 7 | RUN curl -fsSL -o google-cloud-cli.tar.gz https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-linux-x86_64.tar.gz && \ 8 | tar -xzf google-cloud-cli.tar.gz -C /usr/local && \ 9 | rm google-cloud-cli.tar.gz && \ 10 | /usr/local/google-cloud-sdk/install.sh -q --path-update true 11 | 12 | # Docker engine 13 | 14 | RUN curl -fsSL -o get-docker.sh https://get.docker.com && \ 15 | sh ./get-docker.sh && \ 16 | rm get-docker.sh 17 | 18 | # PyPI dependencies 19 | 20 | RUN pip --no-cache-dir install --upgrade pip 21 | # Installing backwards-compatible tf-keras package because Transformers does not yet support Keras 3 22 | RUN pip install tf-keras==2.16 23 | RUN pip --no-cache-dir install datasets docker google-cloud-storage grpcio jupyter matplotlib nltk numpy pandas ray[default] sly transformers torch tqdm wandb 24 | -------------------------------------------------------------------------------- /docker/deps/dev-cpu.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CONTAINER_REGISTRY=ghcr.io/reactive-systems/ml2 2 | 3 | FROM $CONTAINER_REGISTRY/deps:cpu 4 | 5 | RUN export DEBIAN_FRONTEND=noninteractive && \ 6 | apt-get -q update && \ 7 | apt-get -q install -y \ 8 | git \ 9 | screen \ 10 | tmux 11 | 12 | RUN pip --no-cache-dir install black flake8 flake8-quotes grpcio-tools isort mypy mypy-protobuf pre-commit pytest rbql rstcheck sphinx==4.0.2 13 | -------------------------------------------------------------------------------- 
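The dev dependency images above install grpcio-tools and mypy-protobuf, which are the tools used to regenerate the *_pb2.py, *_pb2_grpc.py and *.pyi stubs found under ml2/grpc from the .proto files; scripts/protoc.sh is presumably the supported entry point for this. A hedged sketch of such an invocation, with the include path and output directories assumed rather than taken from that script:

    python3 -m grpc_tools.protoc -I . \
        --python_out=. --grpc_python_out=. --mypy_out=. \
        ml2/grpc/aalta/aalta.proto
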
/docker/deps/dev-gpu.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CONTAINER_REGISTRY=ghcr.io/reactive-systems/ml2 2 | 3 | FROM $CONTAINER_REGISTRY/deps:gpu 4 | 5 | RUN export DEBIAN_FRONTEND=noninteractive && \ 6 | apt-get -q update && \ 7 | apt-get -q install -y \ 8 | git \ 9 | screen \ 10 | tmux 11 | 12 | RUN pip --no-cache-dir install black flake8 flake8-quotes grpcio-tools isort mypy mypy-protobuf pre-commit pytest rbql rstcheck sphinx==4.0.2 13 | -------------------------------------------------------------------------------- /docker/deps/gpu.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM --platform=amd64 tensorflow/tensorflow:2.16.2-gpu 2 | 3 | # Google Cloud SDK 4 | 5 | # ENV CLOUDSDK_PYTHON=/usr/bin/python3 6 | 7 | RUN curl -fsSL -o google-cloud-cli.tar.gz https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-linux-x86_64.tar.gz && \ 8 | tar -xzf google-cloud-cli.tar.gz -C /usr/local && \ 9 | rm google-cloud-cli.tar.gz && \ 10 | /usr/local/google-cloud-sdk/install.sh -q --path-update true 11 | 12 | # Docker engine 13 | 14 | RUN curl -fsSL -o get-docker.sh https://get.docker.com && \ 15 | sh ./get-docker.sh && \ 16 | rm get-docker.sh 17 | 18 | # PyPI dependencies 19 | 20 | RUN pip --no-cache-dir install --upgrade pip 21 | # Installing backwards-compatible tf-keras package because Transformers does not yet support Keras 3 22 | RUN pip install tf-keras==2.16 23 | RUN pip --no-cache-dir install torch --index-url https://download.pytorch.org/whl/cu118 24 | RUN pip --no-cache-dir install datasets docker google-cloud-storage grpcio jupyter matplotlib nltk numpy pandas ray[default] sly transformers[torch] tqdm wandb 25 | -------------------------------------------------------------------------------- /docker/limboole/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CONTAINER_REGISTRY=ghcr.io/reactive-systems/ml2 2 | 3 | FROM $CONTAINER_REGISTRY/lingeling:latest 4 | 5 | RUN export DEBIAN_FRONTEND=noninteractive && \ 6 | apt-get -q update && \ 7 | apt-get -q install -y \ 8 | wget 9 | 10 | RUN wget http://fmv.jku.at/limboole/limboole1.2.tgz && \ 11 | tar xzf limboole1.2.tgz && \ 12 | rm limboole1.2.tgz && \ 13 | cd limboole1.2 && \ 14 | ./configure.sh --lingeling && \ 15 | make -------------------------------------------------------------------------------- /docker/limboole/grpc-server.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CONTAINER_REGISTRY=ghcr.io/reactive-systems/ml2 2 | 3 | FROM $CONTAINER_REGISTRY/limboole 4 | 5 | ARG PYTHON_VERSION=3.8.12 6 | 7 | RUN apt-get -q update && \ 8 | apt-get -q install -y \ 9 | libbz2-dev \ 10 | libssl-dev \ 11 | libffi-dev \ 12 | openssl \ 13 | zlib1g-dev 14 | 15 | RUN wget https://www.python.org/ftp/python/$PYTHON_VERSION/Python-$PYTHON_VERSION.tgz && \ 16 | tar xzf Python-$PYTHON_VERSION.tgz && \ 17 | cd Python-$PYTHON_VERSION && \ 18 | ./configure --enable-optimizations && \ 19 | make && make install && \ 20 | cd / && \ 21 | rm Python-$PYTHON_VERSION.tgz && \ 22 | rm -r Python-$PYTHON_VERSION 23 | 24 | COPY ml2 ml2/ml2 25 | COPY LICENSE pyproject.toml setup.cfg ml2/ 26 | RUN pip3 install --no-cache-dir --upgrade pip && \ 27 | pip3 install --no-cache-dir ml2/ 28 | 29 | ENTRYPOINT [ "python3", "-m", "ml2.tools.limboole.limboole_grpc_server"] 
-------------------------------------------------------------------------------- /docker/ml2/cpu.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CONTAINER_REGISTRY=ghcr.io/reactive-systems/ml2 2 | 3 | FROM $CONTAINER_REGISTRY/deps:cpu 4 | 5 | COPY . /ml2 6 | 7 | RUN pip --no-cache-dir install /ml2 -------------------------------------------------------------------------------- /docker/ml2/gpu.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CONTAINER_REGISTRY=ghcr.io/reactive-systems/ml2 2 | 3 | FROM $CONTAINER_REGISTRY/deps:gpu 4 | 5 | COPY . /ml2 6 | 7 | RUN pip --no-cache-dir install /ml2 -------------------------------------------------------------------------------- /docker/neurosynt-grpc-server/cpu.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CONTAINER_REGISTRY=ghcr.io/reactive-systems/ml2 2 | 3 | FROM $CONTAINER_REGISTRY/deps:cpu 4 | 5 | COPY LICENSE \ 6 | README.rst \ 7 | pyproject.toml \ 8 | setup.cfg \ 9 | ml2/ 10 | 11 | COPY ml2 ml2/ml2 12 | 13 | RUN pip --no-cache-dir install /ml2[full] 14 | 15 | ENV ML2_GCP_BUCKET=ml2-public 16 | 17 | ENTRYPOINT [ "python3", "-m", "ml2.tools.neurosynt.neurosynt_grpc_server"] -------------------------------------------------------------------------------- /docker/neurosynt-grpc-server/gpu.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CONTAINER_REGISTRY=ghcr.io/reactive-systems/ml2 2 | 3 | FROM $CONTAINER_REGISTRY/deps:gpu 4 | 5 | COPY LICENSE \ 6 | README.rst \ 7 | pyproject.toml \ 8 | setup.cfg \ 9 | ml2/ 10 | 11 | COPY ml2 ml2/ml2 12 | 13 | RUN pip --no-cache-dir install /ml2[full] 14 | 15 | ENV ML2_GCP_BUCKET=ml2-public 16 | 17 | ENTRYPOINT [ "python3", "-m", "ml2.tools.neurosynt.neurosynt_grpc_server"] -------------------------------------------------------------------------------- /docker/nusmv/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:jammy 2 | 3 | RUN export DEBIAN_FRONTEND=noninteractive && \ 4 | apt-get -q update && \ 5 | apt-get -q install -y \ 6 | build-essential \ 7 | git \ 8 | python3-pip \ 9 | wget 10 | 11 | RUN git clone https://github.com/frederikschmitt/aiger.git && \ 12 | cd aiger && \ 13 | ./configure.sh && \ 14 | make 15 | 16 | RUN wget http://www.lrde.epita.fr/dload/spot/spot-2.11.5.tar.gz && \ 17 | tar xzf spot-2.11.5.tar.gz && \ 18 | rm spot-2.11.5.tar.gz && \ 19 | cd spot-2.11.5 && \ 20 | ./configure && \ 21 | make 22 | 23 | RUN wget https://nusmv.fbk.eu/distrib/NuSMV-2.6.0-linux64.tar.gz && \ 24 | tar xzf NuSMV-2.6.0-linux64.tar.gz && \ 25 | rm NuSMV-2.6.0-linux64.tar.gz 26 | 27 | 28 | -------------------------------------------------------------------------------- /docker/nusmv/grpc-server.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CONTAINER_REGISTRY=ghcr.io/reactive-systems/ml2 2 | 3 | FROM $CONTAINER_REGISTRY/nusmv 4 | 5 | RUN apt-get -q update && \ 6 | apt-get -q install -y \ 7 | python3-pip 8 | 9 | COPY ml2 ml2/ml2 10 | COPY LICENSE pyproject.toml setup.cfg ml2/ 11 | RUN pip install --no-cache-dir --upgrade pip && \ 12 | pip install --no-cache-dir ml2/ 13 | 14 | 15 | ENTRYPOINT [ "python3", "-m", "ml2.tools.nusmv.nusmv_grpc_server"] 16 | -------------------------------------------------------------------------------- /docker/nuxmv/grpc-server.Dockerfile: 
-------------------------------------------------------------------------------- 1 | ARG CONTAINER_REGISTRY=ghcr.io/reactive-systems/ml2 2 | 3 | FROM $CONTAINER_REGISTRY/strix-opt 4 | 5 | RUN apt-get -q update && \ 6 | apt-get -q install -y \ 7 | python3-pip 8 | 9 | COPY ml2 ml2/ml2 10 | COPY LICENSE pyproject.toml setup.cfg ml2/ 11 | RUN pip install --no-cache-dir --upgrade pip && \ 12 | pip install --no-cache-dir ml2/ 13 | 14 | ENTRYPOINT [ "python3", "-m", "ml2.tools.nuxmv.nuxmv_grpc_server"] 15 | 16 | -------------------------------------------------------------------------------- /docker/semml/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:22.04 2 | 3 | ENV LANG=C.UTF-8 4 | 5 | ARG TARGETPLATFORM 6 | 7 | RUN apt-get -q update \ 8 | && apt-get -q upgrade -y \ 9 | && apt-get install openjdk-17-jre unzip git -y 10 | 11 | RUN git clone https://gitlab.com/live-lab/software/semml.git \ 12 | && cd semml \ 13 | && git checkout artifact_tacas 14 | 15 | # Needs to be in one RUN command to avoid dynamic setting of JAVA_PLATFORM as ARG 16 | RUN if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ 17 | JAVA_PLATFORM="aarch64"; \ 18 | else \ 19 | JAVA_PLATFORM="amd64"; \ 20 | fi && \ 21 | PATH=/usr/lib/jvm/java-17-openjdk-${JAVA_PLATFORM}/bin:${PATH} && \ 22 | JAVA_HOME=/usr/lib/jvm/java-17-openjdk-${JAVA_PLATFORM} && \ 23 | cd semml && \ 24 | ./gradlew distZip && \ 25 | unzip build/distributions/semml-dev.zip 26 | 27 | # semml at /semml/semml-dev/bin/semml semmlMain --env i0,i1,i2 --sys o0,o1,o2 --formula 'G (i0 <-> o0)' -------------------------------------------------------------------------------- /docker/semml/grpc-server.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CONTAINER_REGISTRY=ghcr.io/reactive-systems/ml2 2 | 3 | FROM $CONTAINER_REGISTRY/semml 4 | 5 | RUN apt-get -q update && \ 6 | apt-get -q install -y \ 7 | libbz2-dev \ 8 | libssl-dev \ 9 | openssl \ 10 | wget \ 11 | python3.11 12 | 13 | COPY ml2 ml2/ml2 14 | COPY LICENSE pyproject.toml setup.cfg ml2/ 15 | RUN wget https://bootstrap.pypa.io/get-pip.py && \ 16 | python3.11 get-pip.py && \ 17 | python3.11 -m pip install --no-cache-dir --upgrade pip && \ 18 | python3.11 -m pip install --no-cache-dir ml2/ 19 | 20 | 21 | ENTRYPOINT [ "python3.11", "-m", "ml2.tools.semml.semml_grpc_server"] 22 | 23 | -------------------------------------------------------------------------------- /docker/spot/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:24.04 2 | 3 | ARG SPOT_VERSION=2.11.6 4 | 5 | RUN export DEBIAN_FRONTEND=noninteractive && \ 6 | apt-get -q update && \ 7 | apt-get -q install -y \ 8 | build-essential \ 9 | python3-pip \ 10 | wget 11 | 12 | RUN wget http://www.lrde.epita.fr/dload/spot/spot-$SPOT_VERSION.tar.gz && \ 13 | tar xzf spot-$SPOT_VERSION.tar.gz && \ 14 | cd spot-$SPOT_VERSION && \ 15 | ./configure --with-pythondir=/usr/local/lib/python3.12/dist-packages && \ 16 | make && \ 17 | make install && \ 18 | cd / && \ 19 | rm spot-$SPOT_VERSION.tar.gz && \ 20 | rm -r spot-$SPOT_VERSION 21 | 22 | ENV LD_LIBRARY_PATH="/usr/local/lib:${LD_LIBRARY_PATH}" -------------------------------------------------------------------------------- /docker/spot/grpc-server.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CONTAINER_REGISTRY=ghcr.io/reactive-systems/ml2 2 | 3 | FROM $CONTAINER_REGISTRY/spot 4 | 5 | COPY ml2 ml2/ml2 6 
| COPY LICENSE pyproject.toml setup.cfg ml2/ 7 | RUN pip install --break-system-packages --no-cache-dir ml2/ 8 | 9 | ENTRYPOINT [ "python3", "-m", "ml2.tools.spot.spot_grpc_server"] 10 | -------------------------------------------------------------------------------- /docker/strix/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM --platform=linux/amd64 ubuntu:jammy 2 | 3 | RUN apt-get -q update && \ 4 | DEBIAN_FRONTEND=noninteractive \ 5 | apt-get -q install -y \ 6 | build-essential \ 7 | curl \ 8 | git \ 9 | clang-14 \ 10 | zlib1g-dev 11 | 12 | RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y 13 | ENV PATH="/root/.cargo/bin:${PATH}" 14 | 15 | ARG GRAALVM_VERSION=21.0.0.2 16 | 17 | ENV GRAALVM_PKG=https://github.com/graalvm/graalvm-ce-builds/releases/download/vm-$GRAALVM_VERSION/graalvm-ce-java11-linux-amd64-$GRAALVM_VERSION.tar.gz 18 | 19 | RUN mkdir /usr/lib/jvm && \ 20 | curl --fail --silent --location --retry 3 ${GRAALVM_PKG} \ 21 | | gunzip \ 22 | | tar x -C /usr/lib/jvm && \ 23 | mv /usr/lib/jvm/graalvm-ce-java11-$GRAALVM_VERSION /usr/lib/jvm/java-11-graalvm 24 | 25 | ENV PATH=/usr/lib/jvm/java-11-graalvm/bin:${PATH} \ 26 | JAVA_HOME=/usr/lib/jvm/java-11-graalvm 27 | 28 | RUN /usr/lib/jvm/java-11-graalvm/bin/gu install native-image 29 | 30 | RUN git clone https://github.com/meyerphi/strix.git && \ 31 | cd strix && \ 32 | git submodule init && \ 33 | git submodule update 34 | 35 | RUN cd strix && \ 36 | cargo build --release 37 | -------------------------------------------------------------------------------- /docker/strix/grpc-server.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CONTAINER_REGISTRY=ghcr.io/reactive-systems/ml2 2 | 3 | FROM $CONTAINER_REGISTRY/strix 4 | 5 | ARG PYTHON_VERSION=3.8.8 6 | 7 | RUN apt-get -q update && \ 8 | apt-get -q install -y \ 9 | libbz2-dev \ 10 | libssl-dev \ 11 | openssl \ 12 | wget 13 | 14 | RUN wget https://www.python.org/ftp/python/$PYTHON_VERSION/Python-$PYTHON_VERSION.tgz && \ 15 | tar xzf Python-$PYTHON_VERSION.tgz && \ 16 | cd Python-$PYTHON_VERSION && \ 17 | ./configure && \ 18 | make && make install && \ 19 | cd / && \ 20 | rm Python-$PYTHON_VERSION.tgz && \ 21 | rm -r Python-$PYTHON_VERSION 22 | 23 | COPY ml2 ml2/ml2 24 | COPY LICENSE pyproject.toml setup.cfg ml2/ 25 | RUN pip3 install --no-cache-dir --upgrade pip && \ 26 | pip3 install --no-cache-dir ml2/ 27 | 28 | 29 | ENTRYPOINT [ "python3", "-m", "ml2.tools.strix.strix_grpc_server"] 30 | 31 | -------------------------------------------------------------------------------- /docker/strix/opt.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CONTAINER_REGISTRY=ghcr.io/reactive-systems/ml2 2 | 3 | FROM $CONTAINER_REGISTRY/strix 4 | 5 | ENV PATH=/root/.local/bin:${PATH} 6 | 7 | RUN apt-get -q update && \ 8 | apt-get -q install -y \ 9 | gettext-base \ 10 | haskell-stack \ 11 | sudo \ 12 | wget && \ 13 | stack upgrade 14 | 15 | RUN cd /strix/scripts && \ 16 | yes Y | ./install_dependencies.sh 17 | 18 | RUN git clone https://github.com/reactive-systems/syfco.git && \ 19 | cd syfco && \ 20 | stack install -------------------------------------------------------------------------------- /docker/syfco/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM haskell:9.0-slim 2 | 3 | RUN git clone https://github.com/reactive-systems/syfco.git && \ 4 | cd syfco && \ 5 | 
stack install -------------------------------------------------------------------------------- /docker/syfco/grpc-server.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CONTAINER_REGISTRY=ghcr.io/reactive-systems/ml2 2 | 3 | FROM $CONTAINER_REGISTRY/syfco 4 | 5 | ARG PYTHON_VERSION=3.11.11 6 | 7 | RUN apt-get -q update && \ 8 | apt-get -q install -y \ 9 | libbz2-dev \ 10 | libssl-dev \ 11 | libffi-dev \ 12 | openssl \ 13 | wget 14 | 15 | RUN wget https://www.python.org/ftp/python/$PYTHON_VERSION/Python-$PYTHON_VERSION.tgz && \ 16 | tar xzf Python-$PYTHON_VERSION.tgz && \ 17 | cd Python-$PYTHON_VERSION && \ 18 | ./configure --enable-optimizations && \ 19 | make && make install && \ 20 | cd / && \ 21 | rm Python-$PYTHON_VERSION.tgz && \ 22 | rm -r Python-$PYTHON_VERSION 23 | 24 | COPY ml2 ml2/ml2 25 | COPY LICENSE pyproject.toml setup.cfg ml2/ 26 | RUN pip3 install --no-cache-dir --upgrade pip && \ 27 | pip3 install --no-cache-dir ml2/ 28 | 29 | ENTRYPOINT [ "python3", "-m", "ml2.tools.syfco.syfco_grpc_server"] 30 | 31 | -------------------------------------------------------------------------------- /ml2/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ML2 - Machine Learning for Mathematics and Logics 3 | """ 4 | 5 | __version__ = "0.2.1" 6 | 7 | from . import aiger, datasets, dtypes, grpc, layers, ltl, models, optim, prop, tools, trace 8 | from .loading import load_artifact as load 9 | from .utils import is_pt_available, is_ray_available, is_tf_available 10 | 11 | if is_ray_available(): 12 | from . import data_gen 13 | 14 | if is_pt_available() and is_tf_available(): 15 | from . import experiment, pipelines, tokenizers, train 16 | -------------------------------------------------------------------------------- /ml2/aiger/__init__.py: -------------------------------------------------------------------------------- 1 | from ..utils import is_pt_available, is_tf_available 2 | from .aiger_circuit import AIGERCircuit 3 | from .aiger_utils import header_ints_from_str 4 | 5 | if is_pt_available() and is_tf_available(): 6 | from .aiger_tokenizer import AIGERToSeqTokenizer 7 | -------------------------------------------------------------------------------- /ml2/aiger/aiger_utils.py: -------------------------------------------------------------------------------- 1 | """AIGER utils""" 2 | 3 | 4 | def header_ints_from_str(aiger: str): 5 | header = aiger.split("\n")[0] 6 | header_strs = header.split(" ") 7 | # first position is format identifier string aag 8 | max_var_id = int(header_strs[1]) 9 | num_inputs = int(header_strs[2]) 10 | num_latches = int(header_strs[3]) 11 | num_outputs = int(header_strs[4]) 12 | num_and_gates = int(header_strs[5]) 13 | return max_var_id, num_inputs, num_latches, num_outputs, num_and_gates 14 | 15 | 16 | def reconstruct_header_ints(circuit: str, num_inputs: int, num_outputs: int): 17 | lines = circuit.split("\n") 18 | max_var = 0 19 | for line in lines: 20 | for var in line.split(" "): 21 | if var.isdigit() and int(var) > max_var: 22 | max_var = int(var) 23 | max_var_id = max_var // 2 24 | lines = lines[num_inputs:] 25 | num_latches = 0 26 | for line in lines: 27 | if (len(line.split(" "))) == 2: 28 | num_latches += 1 29 | else: 30 | break 31 | lines = lines[num_latches + num_outputs :] 32 | num_and_gates = 0 33 | for line in lines: 34 | if (len(line.split(" "))) == 3: 35 | num_and_gates += 1 36 | else: 37 | break 38 | return max_var_id, num_inputs, num_latches, 
num_outputs, num_and_gates 39 | -------------------------------------------------------------------------------- /ml2/data_gen/__init__.py: -------------------------------------------------------------------------------- 1 | from .counting_data_gen_server import CountingDataGenServer 2 | from .data_gen_args import add_data_gen_args, add_dist_data_gen_args 3 | from .data_gen_server import DataGenServer 4 | from .data_server import DataServer 5 | from .progress_actor import ProgressActor 6 | from .progress_bar import ( 7 | data_writing_progress_bar, 8 | key_data_writing_progress_bar, 9 | progress_bar, 10 | progress_bar_init, 11 | ) 12 | -------------------------------------------------------------------------------- /ml2/data_gen/counting_data_gen_server.py: -------------------------------------------------------------------------------- 1 | """Abstract counting data generation server class""" 2 | 3 | from typing import Any, List, Optional 4 | 5 | import ray 6 | from ray.util.queue import Queue 7 | 8 | from .data_gen_server import DataGenServer 9 | from .progress_actor import ProgressActor 10 | 11 | 12 | class CountingDataGenServer(DataGenServer): 13 | def __init__( 14 | self, 15 | num_samples: int, 16 | progress_actor: ProgressActor, 17 | batch_size: int = 1, 18 | sample_queue: Queue = None, 19 | ): 20 | super().__init__( 21 | batch_size=batch_size, 22 | progress_actor=progress_actor, 23 | sample_queue=sample_queue, 24 | ) 25 | 26 | self.num_samples = num_samples 27 | 28 | self.progress_actor.update.remote("samples", 0) 29 | self.progress_actor.update.remote("processing", 0) 30 | 31 | def has_problems(self) -> bool: 32 | finished = ray.get(self.progress_actor.get.remote("samples")) 33 | processing = ray.get(self.progress_actor.get.remote("processing")) 34 | return finished + processing < self.num_samples 35 | 36 | def get_problem_batch(self) -> Optional[List[Any]]: 37 | if not self.has_problems(): 38 | return None 39 | batch = super().get_problem_batch() 40 | if batch is not None: 41 | self.progress_actor.update.remote("processing", len(batch)) 42 | return batch 43 | 44 | def post_problem_batch(self, batch: List[Any]) -> None: 45 | super().post_problem_batch(batch) 46 | self.progress_actor.update.remote("processing", -len(batch)) 47 | -------------------------------------------------------------------------------- /ml2/data_gen/data_gen_args.py: -------------------------------------------------------------------------------- 1 | """Common data generation arguments""" 2 | 3 | 4 | def add_data_gen_frac_args(parser): 5 | parser.add_argument("--train-frac", type=float, default=0.8, metavar="fraction") 6 | parser.add_argument("--val-frac", type=float, default=0.1, metavar="fraction") 7 | parser.add_argument("--test-frac", type=float, default=0.1, metavar="fraction") 8 | 9 | 10 | def add_data_gen_args(parser): 11 | add_data_gen_frac_args(parser) 12 | parser.add_argument( 13 | "--add-to-wandb", action="store_true", help="add data to Weights and Biases" 14 | ) 15 | parser.add_argument("-n", "--num-samples", type=int, default=100, help="number of samples") 16 | parser.add_argument("--name", type=str, metavar="NAME", required=True, help="dataset name") 17 | parser.add_argument("--project", type=str, metavar="PROJECT", help="dataset project") 18 | parser.add_argument( 19 | "-u", "--upload", action="store_true", help="upload generated data to GCP storage bucket" 20 | ) 21 | 22 | 23 | def add_dist_data_gen_args(parser): 24 | add_data_gen_args(parser) 25 | parser.add_argument( 26 | "--batch-size", 
type=int, default=10, help="size of batches provided to worker" 27 | ) 28 | parser.add_argument("--num-workers", type=int, default=4, help="number of workers") 29 | parser.add_argument( 30 | "--sleep-workers", type=int, default=0, help="sleep time between worker startup" 31 | ) 32 | parser.add_argument( 33 | "--mem-lim-workers", type=str, default="2g", help="Memory limit per worker" 34 | ) 35 | -------------------------------------------------------------------------------- /ml2/data_gen/data_gen_server.py: -------------------------------------------------------------------------------- 1 | """Abstract data generation server class""" 2 | 3 | from typing import Any, List, Optional 4 | 5 | from ray.util.queue import Queue 6 | 7 | from .progress_actor import ProgressActor 8 | 9 | 10 | class DataGenServer(object): 11 | def __init__( 12 | self, 13 | batch_size: int = 1, 14 | progress_actor: ProgressActor = None, 15 | sample_queue: Queue = None, 16 | ): 17 | self.batch_size = batch_size 18 | self.progress_actor = progress_actor 19 | self.sample_queue = sample_queue 20 | 21 | def get_problem(self) -> Optional[Any]: 22 | raise NotImplementedError() 23 | 24 | def get_problem_batch(self) -> Optional[List[Any]]: 25 | batch = [] 26 | for _ in range(self.batch_size): 27 | p = self.get_problem() 28 | if p is None: 29 | break 30 | batch.append(p) 31 | if batch == []: 32 | return None 33 | return batch 34 | 35 | def post_problem(self, problem: Any) -> None: 36 | raise NotImplementedError() 37 | 38 | def post_problem_batch(self, batch: List[Any]) -> None: 39 | for problem in batch: 40 | self.post_problem(problem) 41 | 42 | def deregister_worker(self): 43 | self.progress_actor.update.remote("worker", -1) 44 | 45 | def register_worker(self): 46 | self.progress_actor.update.remote("worker") 47 | -------------------------------------------------------------------------------- /ml2/data_gen/data_server.py: -------------------------------------------------------------------------------- 1 | """"Data server that gets problems from workers and sends them to the queue""" 2 | 3 | from typing import Any, List 4 | 5 | import ray 6 | from ray.util.queue import Queue 7 | 8 | from .progress_actor import ProgressActor 9 | 10 | 11 | @ray.remote 12 | class DataServer(object): 13 | def __init__( 14 | self, 15 | num_samples: int, 16 | progress_actor: ProgressActor = None, 17 | sample_queue: Queue = None, 18 | ): 19 | self.num_samples = num_samples 20 | self.progress_actor = progress_actor 21 | self.sample_queue = sample_queue 22 | 23 | self.progress_actor.update.remote("samples", 0) 24 | 25 | def finished(self) -> bool: 26 | return ray.get(self.progress_actor.get.remote("samples")) >= self.num_samples 27 | 28 | def post_problem(self, problem: Any) -> None: 29 | self.sample_queue.put(problem, block=True) 30 | self.progress_actor.update.remote("samples") 31 | 32 | def post_problem_batch(self, batch: List[Any]) -> None: 33 | for problem in batch: 34 | if not self.finished(): 35 | self.post_problem(problem) 36 | -------------------------------------------------------------------------------- /ml2/data_gen/progress_actor.py: -------------------------------------------------------------------------------- 1 | """Ray actor class for tracking progress of data generation""" 2 | 3 | from asyncio import Event 4 | 5 | import ray 6 | 7 | 8 | @ray.remote 9 | class ProgressActor: 10 | def __init__(self): 11 | self.progress = {} 12 | self.event = Event() 13 | 14 | def get(self, key): 15 | return self.progress[key] 16 | 17 | def 
get_progress(self): 18 | return self.progress 19 | 20 | def update(self, key, delta=1): 21 | if key in self.progress: 22 | self.progress[key] += delta 23 | self.event.set() 24 | else: 25 | self.progress[key] = delta 26 | self.progress = {key: self.progress[key] for key in sorted(self.progress.keys())} 27 | self.event.set() 28 | 29 | def update_multi(self, keys, delta=1): 30 | for key in keys: 31 | self.update(key, delta) 32 | 33 | async def wait_for_update(self): 34 | await self.event.wait() 35 | self.event.clear() 36 | return self.progress 37 | -------------------------------------------------------------------------------- /ml2/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .csv_dataset import CSVDataset 2 | from .csv_dataset_writer import ContCSVDatasetWriter, CSVDatasetWriter 3 | from .dataset import Dataset, list_datasets 4 | from .dataset_writer import DatasetWriter 5 | from .generator_dataset import GeneratorDataset 6 | from .load_dataset import load_dataset 7 | from .split_dataset import SplitDataset 8 | from .split_dataset_writer import SplitDatasetWriter 9 | -------------------------------------------------------------------------------- /ml2/datasets/dataset_writer.py: -------------------------------------------------------------------------------- 1 | """Abstract dataset writer class""" 2 | 3 | from abc import abstractmethod 4 | from typing import Generic, TypeVar 5 | 6 | from ..artifact import Artifact 7 | from ..dtypes import DType 8 | 9 | T = TypeVar("T", bound=DType) 10 | 11 | 12 | class DatasetWriter(Artifact, Generic[T]): 13 | @abstractmethod 14 | def add_sample(self, sample: T, **kwargs) -> None: 15 | raise NotImplementedError() 16 | 17 | @abstractmethod 18 | def close(self) -> None: 19 | raise NotImplementedError() 20 | 21 | @abstractmethod 22 | def size(self, **kwargs) -> int: 23 | raise NotImplementedError() 24 | -------------------------------------------------------------------------------- /ml2/datasets/load_dataset.py: -------------------------------------------------------------------------------- 1 | """Utility to load dataset""" 2 | 3 | import logging 4 | 5 | from .dataset import Dataset 6 | 7 | logging.basicConfig(level=logging.INFO) 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def load_dataset(name: str, project: str = None, **kwargs) -> Dataset: 12 | from ..registry import type_from_str 13 | 14 | config = Dataset.fetch_config(name=name, project=project) 15 | if "type" not in config: 16 | raise Exception("Dataset type not specified in config") 17 | dataset_type = type_from_str(config["type"], bound=Dataset) 18 | return dataset_type.load(name=name, project=project, **kwargs) 19 | -------------------------------------------------------------------------------- /ml2/datasets/utils.py: -------------------------------------------------------------------------------- 1 | """Collection of useful functions related to data processing""" 2 | 3 | 4 | def int_to_abbrev_str(n: int): 5 | """Given an integer returns an abbreviated string representing the integer, e.g., '100K' given 100000""" 6 | if n > 0 and n % 10**6 == 0: 7 | return f"{n // 10**6}M" 8 | elif n > 0 and n % 10**3 == 0: 9 | return f"{n // 10**3}K" 10 | else: 11 | return f"{n}" 12 | 13 | 14 | def from_csv_str(s: str): 15 | """Escapes a string that is read from a csv file""" 16 | s = s.replace("\\n", "\n") 17 | return s.replace('"', '"') 18 | 19 | 20 | def to_csv_str(s: str): 21 | """Escapes a string that is supposed to be written to a 
csv file""" 22 | s = s.replace("\n", "\\n") 23 | return s.replace('"', '"') 24 | -------------------------------------------------------------------------------- /ml2/dtypes/__init__.py: -------------------------------------------------------------------------------- 1 | from .binary_ast import BinaryAST, TPEFormat 2 | from .binary_expr import BinaryExpr 3 | from .cat import Cat 4 | from .cat_seq import CatSeq, GenericCatSeq 5 | from .csv_dict import CSVDict 6 | from .csv_dtype import CSV, CSVLoggable 7 | from .csv_dtype_with_id import CSVWithId 8 | from .decomp_binary_expr import DecompBinaryExpr 9 | from .decomp_binary_expr_pair import DecompBinaryExprPair 10 | from .decomp_dtype import DecompDType, GenericDecompDType 11 | from .dtype import DType 12 | from .hashable import Hashable 13 | from .pair import GenericPair, Pair 14 | from .seq import Seq 15 | from .string import String 16 | from .supervised import GenericSupervised, Supervised 17 | from .tree import Tree 18 | from .validation_result import CSVLoggableValidationResult, ValidationResult 19 | -------------------------------------------------------------------------------- /ml2/dtypes/cat.py: -------------------------------------------------------------------------------- 1 | """Abstract categorical data type class""" 2 | 3 | from abc import abstractmethod 4 | 5 | from .dtype import DType 6 | 7 | 8 | class Cat(DType): 9 | def __eq__(self, other): 10 | if isinstance(other, self.__class__): 11 | return self.token() == other.token() 12 | return False 13 | 14 | def __repr__(self): 15 | return f"<{self.__class__.__name__}: {self.token()}>" 16 | 17 | def size(self, **kwargs) -> int: 18 | return 1 19 | 20 | @abstractmethod 21 | def token(self, **kwargs) -> str: 22 | raise NotImplementedError() 23 | 24 | @property 25 | def value(self) -> str: 26 | return self.token() 27 | 28 | @classmethod 29 | @abstractmethod 30 | def from_token(cls, token: str, **kwargs) -> "Cat": 31 | raise NotImplementedError() 32 | -------------------------------------------------------------------------------- /ml2/dtypes/csv_dict.py: -------------------------------------------------------------------------------- 1 | """Simple dict data type that inherits from CSV data type""" 2 | 3 | 4 | from typing import Dict, List 5 | 6 | from ..registry import register_type 7 | from .csv_dtype import CSV 8 | 9 | 10 | @register_type 11 | class CSVDict(CSV, dict): 12 | def _to_csv_fields(self, **kwargs) -> Dict[str, str]: 13 | assert all(isinstance(key, str) for key in self) 14 | return {key: str(value) for key, value in self.items()} 15 | 16 | @classmethod 17 | def _from_csv_fields(cls, fields: Dict[str, str], **kwargs) -> "CSVDict": 18 | return cls(fields) 19 | 20 | @classmethod 21 | def _csv_field_header(cls, **kwargs) -> List[str]: 22 | raise NotImplementedError("Header not supported for generic dicts") 23 | -------------------------------------------------------------------------------- /ml2/dtypes/decomp_dtype.py: -------------------------------------------------------------------------------- 1 | """Decomposed data type classes""" 2 | 3 | from abc import abstractmethod 4 | from typing import Generic, List, TypeVar 5 | 6 | from .dtype import DType 7 | 8 | T = TypeVar("T", bound=DType) 9 | 10 | 11 | class DecompDType(DType, Generic[T]): 12 | @abstractmethod 13 | def __iter__(self): 14 | raise NotImplementedError() 15 | 16 | @abstractmethod 17 | def __next__(self) -> T: 18 | raise NotImplementedError() 19 | 20 | @property 21 | @abstractmethod 22 | def len(self) -> int: 23 | 
raise NotImplementedError() 24 | 25 | @classmethod 26 | @abstractmethod 27 | def from_components(cls, components: List[T], **kwargs) -> "DecompDType[T]": 28 | raise NotImplementedError() 29 | 30 | 31 | class GenericDecompDType(list, DecompDType[T], Generic[T]): 32 | def __init__( 33 | self, 34 | components: List[T] = None, 35 | ) -> None: 36 | super().__init__(components if components is not None else []) 37 | 38 | def size(self, **kwargs) -> int: 39 | return sum([c.size(**kwargs) for c in self]) 40 | 41 | @classmethod 42 | def from_components(cls, components: List[T], **kwargs) -> "GenericDecompDType[T]": 43 | return cls(components=components) 44 | -------------------------------------------------------------------------------- /ml2/dtypes/dtype.py: -------------------------------------------------------------------------------- 1 | """Abstract data type class""" 2 | 3 | from abc import abstractmethod 4 | 5 | 6 | class DType(object): 7 | @abstractmethod 8 | def size(self, **kwargs) -> int: 9 | raise NotImplementedError() 10 | -------------------------------------------------------------------------------- /ml2/dtypes/hashable.py: -------------------------------------------------------------------------------- 1 | """Abstract Hashable class""" 2 | 3 | import logging 4 | from abc import abstractmethod 5 | from typing import Optional 6 | 7 | logging.basicConfig(level=logging.INFO) 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class Hashable(object): 12 | _unique_id_value: Optional[str] = None 13 | 14 | def __init__(self, unique_id: Optional[str] = None) -> None: 15 | self._unique_id_value = unique_id 16 | 17 | def unique_id(self, catch_error: bool = False) -> Optional[str]: 18 | """Shows a unique identifier for this object. Either initialized or generated by its content dynamically. 19 | 20 | Returns: 21 | Optional[str]: hexadecimal value of an unique identifier. None if there is no identifier. 22 | """ 23 | if self._unique_id_value is not None: 24 | return self._unique_id_value 25 | elif catch_error: 26 | try: 27 | return f"{self.cr_hash:x}" 28 | except Exception as e: 29 | print("Error when creating unique id. Ignored!", e.__class__.__name__) 30 | return None 31 | else: 32 | return f"{self.cr_hash:x}" 33 | 34 | @property 35 | @abstractmethod 36 | def cr_hash(self) -> int: 37 | """A hash method that has collision resistance. 38 | Should be implemented for all child classes at some point. 39 | 40 | For 2 objects a, b it holds: 41 | a == b <-> cr_hash(a) == cr_hash(b) 42 | 43 | Returns: 44 | int: a hash of the obj. 
45 | """ 46 | raise NotImplementedError 47 | -------------------------------------------------------------------------------- /ml2/dtypes/seq.py: -------------------------------------------------------------------------------- 1 | """Abstract sequence data type class""" 2 | 3 | from abc import abstractmethod 4 | from typing import List 5 | 6 | from .dtype import DType 7 | 8 | 9 | class Seq(DType): 10 | @abstractmethod 11 | def to_tokens(self, **kwargs) -> List[str]: 12 | raise NotImplementedError() 13 | 14 | def size(self, **kwargs) -> int: 15 | return len(self.to_tokens(**kwargs)) 16 | 17 | @classmethod 18 | @abstractmethod 19 | def from_tokens(cls, tokens: List[str], **kwargs) -> "Seq": 20 | raise NotImplementedError() 21 | -------------------------------------------------------------------------------- /ml2/dtypes/string.py: -------------------------------------------------------------------------------- 1 | """Simple string data type class""" 2 | 3 | 4 | from .dtype import DType 5 | 6 | 7 | class String(DType): 8 | def __init__(self, string: str) -> None: 9 | self._string = string 10 | 11 | def to_str(self, **kwargs) -> str: 12 | return self._string 13 | 14 | def size(self, **kwargs) -> int: 15 | return len(self.to_str(**kwargs)) 16 | 17 | @classmethod 18 | def from_str(cls, string: str, **kwargs) -> "String": 19 | return cls(string=string) 20 | -------------------------------------------------------------------------------- /ml2/dtypes/supervised.py: -------------------------------------------------------------------------------- 1 | """Supervised data type classes""" 2 | 3 | from abc import abstractmethod 4 | from typing import Generic, TypeVar 5 | 6 | from .dtype import DType 7 | from .pair import GenericPair, Pair 8 | 9 | # input type variable 10 | I = TypeVar("I", bound=DType) 11 | # target type variable 12 | T = TypeVar("T", bound=DType) 13 | 14 | 15 | class Supervised(Pair[I, T], Generic[I, T]): 16 | @property 17 | def fst(self) -> I: 18 | return self.input 19 | 20 | @property 21 | def snd(self) -> T: 22 | return self.target 23 | 24 | @property 25 | @abstractmethod 26 | def input(self) -> I: 27 | raise NotImplementedError() 28 | 29 | @property 30 | @abstractmethod 31 | def target(self) -> T: 32 | raise NotImplementedError() 33 | 34 | def __getitem__(self, key): 35 | if key == 0: 36 | return self.input 37 | elif key == 1: 38 | return self.target 39 | else: 40 | raise IndexError("Index out of range or not an integer.") 41 | 42 | 43 | # inheritance order such that properties fst and snd are inherited from GenericPair and not Supervised 44 | class GenericSupervised(GenericPair[I, T], Supervised[I, T], Generic[I, T]): 45 | def __init__(self, input: I, target: T) -> None: 46 | super().__init__(fst=input, snd=target) 47 | 48 | @property 49 | def input(self) -> I: 50 | return self.fst 51 | 52 | @property 53 | def target(self) -> T: 54 | return self.snd 55 | -------------------------------------------------------------------------------- /ml2/dtypes/validation_result.py: -------------------------------------------------------------------------------- 1 | """Abstract validation result class""" 2 | 3 | from abc import abstractmethod 4 | from typing import Optional 5 | 6 | from .csv_dtype import CSVLoggable 7 | from .dtype import DType 8 | 9 | 10 | class ValidationResult(DType): 11 | @property 12 | @abstractmethod 13 | def validation_success(self) -> Optional[bool]: 14 | """Returns true if the validation was successful""" 15 | raise NotImplementedError() 16 | 17 | @property 18 |
@abstractmethod 19 | def validation_status(self) -> Optional[str]: 20 | """Return more detailed status of validation""" 21 | raise NotImplementedError() 22 | 23 | 24 | class CSVLoggableValidationResult(ValidationResult, CSVLoggable): 25 | pass 26 | -------------------------------------------------------------------------------- /ml2/experiment/__init__.py: -------------------------------------------------------------------------------- 1 | from .experiment import Experiment 2 | -------------------------------------------------------------------------------- /ml2/experiment/run.py: -------------------------------------------------------------------------------- 1 | """Script to run experiment""" 2 | 3 | import argparse 4 | 5 | from .experiment import Experiment 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser(description="ML2 experiment") 9 | parser.add_argument("config_file") 10 | args = parser.parse_args() 11 | experiment = Experiment.from_config_file(args.config_file) 12 | experiment.run() 13 | -------------------------------------------------------------------------------- /ml2/globals.py: -------------------------------------------------------------------------------- 1 | """ML2 global variables""" 2 | 3 | import os 4 | 5 | # local storage directory 6 | 7 | LOCAL_STORAGE_DIR = os.path.expanduser(os.environ.get("ML2_LOCAL_STORAGE_DIR", "~/ml2-storage")) 8 | 9 | # Google Cloud Platform 10 | 11 | ML2_BUCKET = os.environ.get("ML2_GCP_BUCKET", "ml2-public") 12 | 13 | # Docker 14 | 15 | CONTAINER_REGISTRY = os.environ.get("ML2_CONTAINER_REGISTRY", "ghcr.io/reactive-systems/ml2") 16 | 17 | # Weights and Biases 18 | 19 | WANDB_ENTITY = os.environ.get("ML2_WANDB_ENTITY") 20 | 21 | # Propositional satisfiability 22 | 23 | PROP_SAT_ALIASES = {} 24 | PROP_SAT_PROJECT_NAME = "prop-sat" 25 | 26 | # LTL satisfiability 27 | 28 | LTL_SAT_ALIASES = {} 29 | LTL_SAT_PROJECT_NAME = "ltl-sat" 30 | 31 | # LTL specifications 32 | 33 | LTL_SPEC_ALIASES = {"sc20": "sc-0", "scp-ni5-no5": "scp-0", "scp-ni5-no5-ts25": "scp-1"} 34 | LTL_SPEC_PROJECT_NAME = "ltl-spec" 35 | 36 | # LTL synthesis 37 | 38 | LTL_SYN_ALIASES = {} 39 | LTL_SYN_PROJECT_NAME = "ltl-syn" 40 | -------------------------------------------------------------------------------- /ml2/grpc/__init__.py: -------------------------------------------------------------------------------- 1 | from .aalta import aalta_pb2, aalta_pb2_grpc 2 | from .abc_aiger import abc_aiger_pb2, abc_aiger_pb2_grpc 3 | from .aiger import aiger_pb2, aiger_pb2_grpc 4 | from .booleforce import booleforce_pb2, booleforce_pb2_grpc 5 | from .bosy import bosy_pb2, bosy_pb2_grpc 6 | from .limboole import limboole_pb2, limboole_pb2_grpc 7 | from .ltl import ltl_mc_pb2, ltl_pb2, ltl_syn_pb2 8 | from .mealy import mealy_pb2, mealy_pb2_grpc 9 | from .neurosynt import neurosynt_pb2, neurosynt_pb2_grpc 10 | from .nuxmv import nuxmv_pb2, nuxmv_pb2_grpc 11 | from .prop import prop_pb2 12 | from .spot import spot_pb2, spot_pb2_grpc 13 | from .strix import strix_pb2, strix_pb2_grpc 14 | from .syfco import syfco_pb2, syfco_pb2_grpc 15 | from .system import system_pb2 16 | from .tools import tools_pb2 17 | from .trace import trace_pb2 18 | -------------------------------------------------------------------------------- /ml2/grpc/aalta/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactive-systems/ml2/33d9696c94de6d27aa836ae8118118a7277ff35c/ml2/grpc/aalta/__init__.py 
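The generated aalta_pb2 and aalta_pb2_grpc modules re-exported in ml2/grpc/__init__.py above are ordinary protoc/gRPC outputs, so the Aalta service defined in ml2/grpc/aalta/aalta.proto below can be queried with a standard gRPC client. The following is a minimal sketch of such a client; the server address localhost:50051 and the example formula are assumptions (for instance a server started from the aalta grpc-server Docker image in this repository), while the message fields, the CheckSat method, and the AaltaStub class name follow from the .proto definition and the standard gRPC Python code generator.

import grpc

from ml2.grpc.aalta import aalta_pb2, aalta_pb2_grpc

# Address and port of a running Aalta gRPC server (assumption).
with grpc.insecure_channel("localhost:50051") as channel:
    stub = aalta_pb2_grpc.AaltaStub(channel)
    # LTLSatProblemAalta carries an LTL formula, a simplify flag, and a
    # timeout, as declared in aalta.proto.
    problem = aalta_pb2.LTLSatProblemAalta(formula="a U b", simplify=False, timeout=10.0)
    solution = stub.CheckSat(problem)
    # LTLSatSolutionAalta contains a status string and, if satisfiable, a trace.
    print(solution.status, solution.trace)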
-------------------------------------------------------------------------------- /ml2/grpc/aalta/aalta.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | message LTLSatProblemAalta { 4 | string formula = 1; 5 | bool simplify = 2; 6 | float timeout = 3; 7 | } 8 | 9 | message LTLSatSolutionAalta { 10 | string status = 1; 11 | string trace = 2; 12 | } 13 | 14 | service Aalta { 15 | rpc CheckSat(LTLSatProblemAalta) 16 | returns (LTLSatSolutionAalta) {} 17 | } -------------------------------------------------------------------------------- /ml2/grpc/abc_aiger/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactive-systems/ml2/33d9696c94de6d27aa836ae8118118a7277ff35c/ml2/grpc/abc_aiger/__init__.py -------------------------------------------------------------------------------- /ml2/grpc/aiger/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactive-systems/ml2/33d9696c94de6d27aa836ae8118118a7277ff35c/ml2/grpc/aiger/__init__.py -------------------------------------------------------------------------------- /ml2/grpc/aiger/aiger.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | message AigerCircuit { string circuit = 1; } 4 | 5 | message AigerBinaryCircuit { string circuit = 1; } -------------------------------------------------------------------------------- /ml2/grpc/aiger/aiger_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: ml2/grpc/aiger/aiger.proto 4 | # Protobuf Python Version: 4.25.1 5 | """Generated protocol buffer code.""" 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import descriptor_pool as _descriptor_pool 8 | from google.protobuf import symbol_database as _symbol_database 9 | from google.protobuf.internal import builder as _builder 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | 16 | 17 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1aml2/grpc/aiger/aiger.proto\"\x1f\n\x0c\x41igerCircuit\x12\x0f\n\x07\x63ircuit\x18\x01 \x01(\t\"%\n\x12\x41igerBinaryCircuit\x12\x0f\n\x07\x63ircuit\x18\x01 \x01(\tb\x06proto3') 18 | 19 | _globals = globals() 20 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 21 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ml2.grpc.aiger.aiger_pb2', _globals) 22 | if _descriptor._USE_C_DESCRIPTORS == False: 23 | DESCRIPTOR._options = None 24 | _globals['_AIGERCIRCUIT']._serialized_start=30 25 | _globals['_AIGERCIRCUIT']._serialized_end=61 26 | _globals['_AIGERBINARYCIRCUIT']._serialized_start=63 27 | _globals['_AIGERBINARYCIRCUIT']._serialized_end=100 28 | # @@protoc_insertion_point(module_scope) 29 | -------------------------------------------------------------------------------- /ml2/grpc/aiger/aiger_pb2.pyi: -------------------------------------------------------------------------------- 1 | """ 2 | @generated by mypy-protobuf. Do not edit manually! 
3 | isort:skip_file 4 | """ 5 | 6 | import builtins 7 | import google.protobuf.descriptor 8 | import google.protobuf.message 9 | import typing 10 | 11 | DESCRIPTOR: google.protobuf.descriptor.FileDescriptor 12 | 13 | @typing.final 14 | class AigerCircuit(google.protobuf.message.Message): 15 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 16 | 17 | CIRCUIT_FIELD_NUMBER: builtins.int 18 | circuit: builtins.str 19 | def __init__( 20 | self, 21 | *, 22 | circuit: builtins.str = ..., 23 | ) -> None: ... 24 | def ClearField(self, field_name: typing.Literal["circuit", b"circuit"]) -> None: ... 25 | 26 | global___AigerCircuit = AigerCircuit 27 | 28 | @typing.final 29 | class AigerBinaryCircuit(google.protobuf.message.Message): 30 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 31 | 32 | CIRCUIT_FIELD_NUMBER: builtins.int 33 | circuit: builtins.str 34 | def __init__( 35 | self, 36 | *, 37 | circuit: builtins.str = ..., 38 | ) -> None: ... 39 | def ClearField(self, field_name: typing.Literal["circuit", b"circuit"]) -> None: ... 40 | 41 | global___AigerBinaryCircuit = AigerBinaryCircuit 42 | -------------------------------------------------------------------------------- /ml2/grpc/aiger/aiger_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | """Client and server classes corresponding to protobuf-defined services.""" 3 | import grpc 4 | 5 | -------------------------------------------------------------------------------- /ml2/grpc/booleforce/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactive-systems/ml2/33d9696c94de6d27aa836ae8118118a7277ff35c/ml2/grpc/booleforce/__init__.py -------------------------------------------------------------------------------- /ml2/grpc/booleforce/booleforce.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "ml2/grpc/prop/prop.proto"; 4 | 5 | service BooleForce { 6 | // check satisfiability 7 | rpc CheckSat(CNFSatProblem) returns (CNFSatSolution) {} 8 | // check resolution proof 9 | rpc TraceCheck(ResProofCheckProblem) returns (ResProofCheckSolution) {} 10 | // binarize resolution proof 11 | rpc BinarizeResProof(ResProof) returns (ResProof) {} 12 | } 13 | -------------------------------------------------------------------------------- /ml2/grpc/booleforce/booleforce_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: ml2/grpc/booleforce/booleforce.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import descriptor_pool as _descriptor_pool 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | from ml2.grpc.prop import prop_pb2 as ml2_dot_grpc_dot_prop_dot_prop__pb2 16 | 17 | 18 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$ml2/grpc/booleforce/booleforce.proto\x1a\x18ml2/grpc/prop/prop.proto2\xa6\x01\n\nBooleForce\x12-\n\x08\x43heckSat\x12\x0e.CNFSatProblem\x1a\x0f.CNFSatSolution\"\x00\x12=\n\nTraceCheck\x12\x15.ResProofCheckProblem\x1a\x16.ResProofCheckSolution\"\x00\x12*\n\x10\x42inarizeResProof\x12\t.ResProof\x1a\t.ResProof\"\x00\x62\x06proto3') 19 | 20 | 21 | 22 | _BOOLEFORCE = DESCRIPTOR.services_by_name['BooleForce'] 23 | if _descriptor._USE_C_DESCRIPTORS == False: 24 | 25 | DESCRIPTOR._options = None 26 | _BOOLEFORCE._serialized_start=67 27 | _BOOLEFORCE._serialized_end=233 28 | # @@protoc_insertion_point(module_scope) 29 | -------------------------------------------------------------------------------- /ml2/grpc/booleforce/booleforce_pb2.pyi: -------------------------------------------------------------------------------- 1 | """ 2 | @generated by mypy-protobuf. Do not edit manually! 3 | isort:skip_file 4 | """ 5 | import google.protobuf.descriptor 6 | 7 | DESCRIPTOR: google.protobuf.descriptor.FileDescriptor 8 | -------------------------------------------------------------------------------- /ml2/grpc/bosy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactive-systems/ml2/33d9696c94de6d27aa836ae8118118a7277ff35c/ml2/grpc/bosy/__init__.py -------------------------------------------------------------------------------- /ml2/grpc/bosy/bosy.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "ml2/grpc/ltl/ltl_syn.proto"; 4 | import "ml2/grpc/tools/tools.proto"; 5 | 6 | // BoSy is a synthesis tool based on a various bounded synthesis encodings. 7 | service Bosy { 8 | // Setup call, which is typically called before the first model checking call 9 | // has happened. 10 | rpc Setup(SetupRequest) returns (SetupResponse) {} 11 | // Call to find out the identity and functionality of the server, i.e. the 12 | // tool that is running the server and what it is supposed to do. 13 | rpc Identify(IdentificationRequest) returns (IdentificationResponse) {} 14 | // Call to synthesize a single LTL specification 15 | rpc Synthesize(LTLSynProblem) returns (LTLSynSolution) {} 16 | // Call to synthesize a stream of LTL specifications. Same order of problems 17 | // and solutions is assumed 18 | rpc SynthesizeStream(stream LTLSynProblem) returns (stream LTLSynSolution) {} 19 | } -------------------------------------------------------------------------------- /ml2/grpc/bosy/bosy_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: ml2/grpc/bosy/bosy.proto 4 | # Protobuf Python Version: 4.25.1 5 | """Generated protocol buffer code.""" 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import descriptor_pool as _descriptor_pool 8 | from google.protobuf import symbol_database as _symbol_database 9 | from google.protobuf.internal import builder as _builder 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | from ml2.grpc.ltl import ltl_syn_pb2 as ml2_dot_grpc_dot_ltl_dot_ltl__syn__pb2 16 | from ml2.grpc.tools import tools_pb2 as ml2_dot_grpc_dot_tools_dot_tools__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18ml2/grpc/bosy/bosy.proto\x1a\x1aml2/grpc/ltl/ltl_syn.proto\x1a\x1aml2/grpc/tools/tools.proto2\xdb\x01\n\x04\x42osy\x12(\n\x05Setup\x12\r.SetupRequest\x1a\x0e.SetupResponse\"\x00\x12=\n\x08Identify\x12\x16.IdentificationRequest\x1a\x17.IdentificationResponse\"\x00\x12/\n\nSynthesize\x12\x0e.LTLSynProblem\x1a\x0f.LTLSynSolution\"\x00\x12\x39\n\x10SynthesizeStream\x12\x0e.LTLSynProblem\x1a\x0f.LTLSynSolution\"\x00(\x01\x30\x01\x62\x06proto3') 20 | 21 | _globals = globals() 22 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 23 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ml2.grpc.bosy.bosy_pb2', _globals) 24 | if _descriptor._USE_C_DESCRIPTORS == False: 25 | DESCRIPTOR._options = None 26 | _globals['_BOSY']._serialized_start=85 27 | _globals['_BOSY']._serialized_end=304 28 | # @@protoc_insertion_point(module_scope) 29 | -------------------------------------------------------------------------------- /ml2/grpc/bosy/bosy_pb2.pyi: -------------------------------------------------------------------------------- 1 | """ 2 | @generated by mypy-protobuf. Do not edit manually! 3 | isort:skip_file 4 | """ 5 | 6 | import google.protobuf.descriptor 7 | 8 | DESCRIPTOR: google.protobuf.descriptor.FileDescriptor 9 | -------------------------------------------------------------------------------- /ml2/grpc/limboole/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactive-systems/ml2/33d9696c94de6d27aa836ae8118118a7277ff35c/ml2/grpc/limboole/__init__.py -------------------------------------------------------------------------------- /ml2/grpc/limboole/limboole.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "ml2/grpc/prop/prop.proto"; 4 | 5 | service Limboole { 6 | // check satisfiability 7 | rpc CheckSat(PropSatProblem) returns (PropSatSolution) {} 8 | // check validity 9 | rpc CheckValid(PropSatProblem) returns (PropSatSolution) {} 10 | } 11 | -------------------------------------------------------------------------------- /ml2/grpc/limboole/limboole_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: ml2/grpc/limboole/limboole.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import descriptor_pool as _descriptor_pool 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | from ml2.grpc.prop import prop_pb2 as ml2_dot_grpc_dot_prop_dot_prop__pb2 16 | 17 | 18 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n ml2/grpc/limboole/limboole.proto\x1a\x18ml2/grpc/prop/prop.proto2n\n\x08Limboole\x12/\n\x08\x43heckSat\x12\x0f.PropSatProblem\x1a\x10.PropSatSolution\"\x00\x12\x31\n\nCheckValid\x12\x0f.PropSatProblem\x1a\x10.PropSatSolution\"\x00\x62\x06proto3') 19 | 20 | 21 | 22 | _LIMBOOLE = DESCRIPTOR.services_by_name['Limboole'] 23 | if _descriptor._USE_C_DESCRIPTORS == False: 24 | 25 | DESCRIPTOR._options = None 26 | _LIMBOOLE._serialized_start=62 27 | _LIMBOOLE._serialized_end=172 28 | # @@protoc_insertion_point(module_scope) 29 | -------------------------------------------------------------------------------- /ml2/grpc/limboole/limboole_pb2.pyi: -------------------------------------------------------------------------------- 1 | """ 2 | @generated by mypy-protobuf. Do not edit manually! 3 | isort:skip_file 4 | """ 5 | import google.protobuf.descriptor 6 | 7 | DESCRIPTOR: google.protobuf.descriptor.FileDescriptor 8 | -------------------------------------------------------------------------------- /ml2/grpc/ltl/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ltl_mc_pb2, ltl_mc_pb2_grpc, ltl_pb2, ltl_pb2_grpc, ltl_syn_pb2, ltl_syn_pb2_grpc 2 | -------------------------------------------------------------------------------- /ml2/grpc/ltl/ltl_equiv.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | import "ml2/grpc/trace/trace.proto"; 3 | 4 | message LTLEquivProblem { 5 | string formula1 = 1; 6 | string formula2 = 2; 7 | optional float timeout = 3; 8 | } 9 | 10 | message LTLEquivSolution { 11 | string status = 1; 12 | optional float time = 2; 13 | optional Trace exclusive_word = 3; 14 | } -------------------------------------------------------------------------------- /ml2/grpc/ltl/ltl_equiv_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | """Client and server classes corresponding to protobuf-defined services.""" 3 | import grpc 4 | 5 | -------------------------------------------------------------------------------- /ml2/grpc/ltl/ltl_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
2 | """Client and server classes corresponding to protobuf-defined services.""" 3 | import grpc 4 | 5 | -------------------------------------------------------------------------------- /ml2/grpc/ltl/ltl_sat.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | message LTLSatProblem { 4 | string formula = 1; 5 | bool simplify = 2; 6 | float timeout = 3; 7 | } 8 | 9 | message LTLSatSolution { 10 | string status = 1; 11 | string trace = 2; 12 | } -------------------------------------------------------------------------------- /ml2/grpc/ltl/ltl_sat_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: ml2/grpc/ltl/ltl_sat.proto 4 | # Protobuf Python Version: 4.25.1 5 | """Generated protocol buffer code.""" 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import descriptor_pool as _descriptor_pool 8 | from google.protobuf import symbol_database as _symbol_database 9 | from google.protobuf.internal import builder as _builder 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | 16 | 17 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1aml2/grpc/ltl/ltl_sat.proto\"C\n\rLTLSatProblem\x12\x0f\n\x07\x66ormula\x18\x01 \x01(\t\x12\x10\n\x08simplify\x18\x02 \x01(\x08\x12\x0f\n\x07timeout\x18\x03 \x01(\x02\"/\n\x0eLTLSatSolution\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\r\n\x05trace\x18\x02 \x01(\tb\x06proto3') 18 | 19 | _globals = globals() 20 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 21 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ml2.grpc.ltl.ltl_sat_pb2', _globals) 22 | if _descriptor._USE_C_DESCRIPTORS == False: 23 | DESCRIPTOR._options = None 24 | _globals['_LTLSATPROBLEM']._serialized_start=30 25 | _globals['_LTLSATPROBLEM']._serialized_end=97 26 | _globals['_LTLSATSOLUTION']._serialized_start=99 27 | _globals['_LTLSATSOLUTION']._serialized_end=146 28 | # @@protoc_insertion_point(module_scope) 29 | -------------------------------------------------------------------------------- /ml2/grpc/ltl/ltl_sat_pb2.pyi: -------------------------------------------------------------------------------- 1 | """ 2 | @generated by mypy-protobuf. Do not edit manually! 3 | isort:skip_file 4 | """ 5 | 6 | import builtins 7 | import google.protobuf.descriptor 8 | import google.protobuf.message 9 | import typing 10 | 11 | DESCRIPTOR: google.protobuf.descriptor.FileDescriptor 12 | 13 | @typing.final 14 | class LTLSatProblem(google.protobuf.message.Message): 15 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 16 | 17 | FORMULA_FIELD_NUMBER: builtins.int 18 | SIMPLIFY_FIELD_NUMBER: builtins.int 19 | TIMEOUT_FIELD_NUMBER: builtins.int 20 | formula: builtins.str 21 | simplify: builtins.bool 22 | timeout: builtins.float 23 | def __init__( 24 | self, 25 | *, 26 | formula: builtins.str = ..., 27 | simplify: builtins.bool = ..., 28 | timeout: builtins.float = ..., 29 | ) -> None: ... 30 | def ClearField(self, field_name: typing.Literal["formula", b"formula", "simplify", b"simplify", "timeout", b"timeout"]) -> None: ... 
31 | 32 | global___LTLSatProblem = LTLSatProblem 33 | 34 | @typing.final 35 | class LTLSatSolution(google.protobuf.message.Message): 36 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 37 | 38 | STATUS_FIELD_NUMBER: builtins.int 39 | TRACE_FIELD_NUMBER: builtins.int 40 | status: builtins.str 41 | trace: builtins.str 42 | def __init__( 43 | self, 44 | *, 45 | status: builtins.str = ..., 46 | trace: builtins.str = ..., 47 | ) -> None: ... 48 | def ClearField(self, field_name: typing.Literal["status", b"status", "trace", b"trace"]) -> None: ... 49 | 50 | global___LTLSatSolution = LTLSatSolution 51 | -------------------------------------------------------------------------------- /ml2/grpc/ltl/ltl_sat_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | """Client and server classes corresponding to protobuf-defined services.""" 3 | import grpc 4 | 5 | -------------------------------------------------------------------------------- /ml2/grpc/ltl/ltl_trace_mc.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | message LTLTraceMCProblem { 4 | string formula = 1; 5 | string trace = 2; 6 | string timeout = 3; 7 | } 8 | 9 | message LTLTraceMCSolution { string status = 1; } -------------------------------------------------------------------------------- /ml2/grpc/ltl/ltl_trace_mc_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: ml2/grpc/ltl/ltl_trace_mc.proto 4 | # Protobuf Python Version: 4.25.1 5 | """Generated protocol buffer code.""" 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import descriptor_pool as _descriptor_pool 8 | from google.protobuf import symbol_database as _symbol_database 9 | from google.protobuf.internal import builder as _builder 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | 16 | 17 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1fml2/grpc/ltl/ltl_trace_mc.proto\"D\n\x11LTLTraceMCProblem\x12\x0f\n\x07\x66ormula\x18\x01 \x01(\t\x12\r\n\x05trace\x18\x02 \x01(\t\x12\x0f\n\x07timeout\x18\x03 \x01(\t\"$\n\x12LTLTraceMCSolution\x12\x0e\n\x06status\x18\x01 \x01(\tb\x06proto3') 18 | 19 | _globals = globals() 20 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 21 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ml2.grpc.ltl.ltl_trace_mc_pb2', _globals) 22 | if _descriptor._USE_C_DESCRIPTORS == False: 23 | DESCRIPTOR._options = None 24 | _globals['_LTLTRACEMCPROBLEM']._serialized_start=35 25 | _globals['_LTLTRACEMCPROBLEM']._serialized_end=103 26 | _globals['_LTLTRACEMCSOLUTION']._serialized_start=105 27 | _globals['_LTLTRACEMCSOLUTION']._serialized_end=141 28 | # @@protoc_insertion_point(module_scope) 29 | -------------------------------------------------------------------------------- /ml2/grpc/ltl/ltl_trace_mc_pb2.pyi: -------------------------------------------------------------------------------- 1 | """ 2 | @generated by mypy-protobuf. Do not edit manually! 
3 | isort:skip_file 4 | """ 5 | 6 | import builtins 7 | import google.protobuf.descriptor 8 | import google.protobuf.message 9 | import typing 10 | 11 | DESCRIPTOR: google.protobuf.descriptor.FileDescriptor 12 | 13 | @typing.final 14 | class LTLTraceMCProblem(google.protobuf.message.Message): 15 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 16 | 17 | FORMULA_FIELD_NUMBER: builtins.int 18 | TRACE_FIELD_NUMBER: builtins.int 19 | TIMEOUT_FIELD_NUMBER: builtins.int 20 | formula: builtins.str 21 | trace: builtins.str 22 | timeout: builtins.str 23 | def __init__( 24 | self, 25 | *, 26 | formula: builtins.str = ..., 27 | trace: builtins.str = ..., 28 | timeout: builtins.str = ..., 29 | ) -> None: ... 30 | def ClearField(self, field_name: typing.Literal["formula", b"formula", "timeout", b"timeout", "trace", b"trace"]) -> None: ... 31 | 32 | global___LTLTraceMCProblem = LTLTraceMCProblem 33 | 34 | @typing.final 35 | class LTLTraceMCSolution(google.protobuf.message.Message): 36 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 37 | 38 | STATUS_FIELD_NUMBER: builtins.int 39 | status: builtins.str 40 | def __init__( 41 | self, 42 | *, 43 | status: builtins.str = ..., 44 | ) -> None: ... 45 | def ClearField(self, field_name: typing.Literal["status", b"status"]) -> None: ... 46 | 47 | global___LTLTraceMCSolution = LTLTraceMCSolution 48 | -------------------------------------------------------------------------------- /ml2/grpc/ltl/ltl_trace_mc_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | """Client and server classes corresponding to protobuf-defined services.""" 3 | import grpc 4 | 5 | -------------------------------------------------------------------------------- /ml2/grpc/mealy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactive-systems/ml2/33d9696c94de6d27aa836ae8118118a7277ff35c/ml2/grpc/mealy/__init__.py -------------------------------------------------------------------------------- /ml2/grpc/mealy/mealy.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | message MealyMachine { string machine = 1; } 4 | 5 | message MealyTransitions { string transitions = 1; } -------------------------------------------------------------------------------- /ml2/grpc/mealy/mealy_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: ml2/grpc/mealy/mealy.proto 4 | # Protobuf Python Version: 4.25.1 5 | """Generated protocol buffer code.""" 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import descriptor_pool as _descriptor_pool 8 | from google.protobuf import symbol_database as _symbol_database 9 | from google.protobuf.internal import builder as _builder 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | 16 | 17 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1aml2/grpc/mealy/mealy.proto\"\x1f\n\x0cMealyMachine\x12\x0f\n\x07machine\x18\x01 \x01(\t\"\'\n\x10MealyTransitions\x12\x13\n\x0btransitions\x18\x01 \x01(\tb\x06proto3') 18 | 19 | _globals = globals() 20 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 21 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ml2.grpc.mealy.mealy_pb2', _globals) 22 | if _descriptor._USE_C_DESCRIPTORS == False: 23 | DESCRIPTOR._options = None 24 | _globals['_MEALYMACHINE']._serialized_start=30 25 | _globals['_MEALYMACHINE']._serialized_end=61 26 | _globals['_MEALYTRANSITIONS']._serialized_start=63 27 | _globals['_MEALYTRANSITIONS']._serialized_end=102 28 | # @@protoc_insertion_point(module_scope) 29 | -------------------------------------------------------------------------------- /ml2/grpc/mealy/mealy_pb2.pyi: -------------------------------------------------------------------------------- 1 | """ 2 | @generated by mypy-protobuf. Do not edit manually! 3 | isort:skip_file 4 | """ 5 | 6 | import builtins 7 | import google.protobuf.descriptor 8 | import google.protobuf.message 9 | import typing 10 | 11 | DESCRIPTOR: google.protobuf.descriptor.FileDescriptor 12 | 13 | @typing.final 14 | class MealyMachine(google.protobuf.message.Message): 15 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 16 | 17 | MACHINE_FIELD_NUMBER: builtins.int 18 | machine: builtins.str 19 | def __init__( 20 | self, 21 | *, 22 | machine: builtins.str = ..., 23 | ) -> None: ... 24 | def ClearField(self, field_name: typing.Literal["machine", b"machine"]) -> None: ... 25 | 26 | global___MealyMachine = MealyMachine 27 | 28 | @typing.final 29 | class MealyTransitions(google.protobuf.message.Message): 30 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 31 | 32 | TRANSITIONS_FIELD_NUMBER: builtins.int 33 | transitions: builtins.str 34 | def __init__( 35 | self, 36 | *, 37 | transitions: builtins.str = ..., 38 | ) -> None: ... 39 | def ClearField(self, field_name: typing.Literal["transitions", b"transitions"]) -> None: ... 40 | 41 | global___MealyTransitions = MealyTransitions 42 | -------------------------------------------------------------------------------- /ml2/grpc/mealy/mealy_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
2 | """Client and server classes corresponding to protobuf-defined services.""" 3 | import grpc 4 | 5 | -------------------------------------------------------------------------------- /ml2/grpc/neurosynt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactive-systems/ml2/33d9696c94de6d27aa836ae8118118a7277ff35c/ml2/grpc/neurosynt/__init__.py -------------------------------------------------------------------------------- /ml2/grpc/neurosynt/neurosynt.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "ml2/grpc/ltl/ltl_syn.proto"; 4 | import "ml2/grpc/tools/tools.proto"; 5 | 6 | // message SetupMessage { 7 | // int32 mc_port = 1; 8 | // int32 beam_size = 2; 9 | // string model = 3; 10 | // string verifier = 4; 11 | // int32 batch_size = 5; 12 | // float alpha = 6; 13 | // int32 num_properties = 7; 14 | // int32 length_properties = 8; 15 | // } 16 | 17 | service NeuroSynt { 18 | // Setup call, which is typically called before the first model checking call 19 | // has happened. 20 | rpc Setup(SetupRequest) returns (SetupResponse) {} 21 | // Call to find out the identity and functionality of the server, i.e. the 22 | // tool that is running the server and what it is supposed to do. 23 | rpc Identify(IdentificationRequest) returns (IdentificationResponse) {} 24 | // Call to synthesize a single LTL specification 25 | rpc Synthesize(LTLSynProblem) returns (NeuralLTLSynSolution) {} 26 | // Call to synthesize a stream of LTL specifications. Same order of problems 27 | // and solutions is assumed 28 | rpc SynthesizeStream(stream LTLSynProblem) 29 | returns (stream NeuralLTLSynSolution) {} 30 | // Call to synthesize a stream of LTL specifications batch-wise. Same order is 31 | // not guaranteed 32 | rpc SynthesizeBatch(stream LTLSynProblem) 33 | returns (stream NeuralLTLSynSolutionSpecPair) {} 34 | } -------------------------------------------------------------------------------- /ml2/grpc/neurosynt/neurosynt_pb2.pyi: -------------------------------------------------------------------------------- 1 | """ 2 | @generated by mypy-protobuf. Do not edit manually! 3 | isort:skip_file 4 | """ 5 | 6 | import google.protobuf.descriptor 7 | 8 | DESCRIPTOR: google.protobuf.descriptor.FileDescriptor 9 | -------------------------------------------------------------------------------- /ml2/grpc/nusmv/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactive-systems/ml2/33d9696c94de6d27aa836ae8118118a7277ff35c/ml2/grpc/nusmv/__init__.py -------------------------------------------------------------------------------- /ml2/grpc/nusmv/nusmv.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "ml2/grpc/ltl/ltl_mc.proto"; 4 | import "ml2/grpc/tools/tools.proto"; 5 | 6 | // The NuSMV model checker. https://nusmv.fbk.eu 7 | service NuSMV { 8 | // Setup call, which is typically called before the first model checking call 9 | // has happened. 10 | rpc Setup(SetupRequest) returns (SetupResponse) {} 11 | // Call to find out the identity and functionality of the server, i.e. the 12 | // tool that is running the server and what it is supposed to do. 
13 | rpc Identify(IdentificationRequest) returns (IdentificationResponse) {} 14 | // Call to model-check a single problem 15 | rpc ModelCheck(LTLMCProblem) returns (LTLMCSolution) {} 16 | // Call to model-check a stream of problems. Same order of problems 17 | // and solutions is assumed 18 | rpc ModelCheckStream(stream LTLMCProblem) returns (stream LTLMCSolution) {} 19 | } -------------------------------------------------------------------------------- /ml2/grpc/nusmv/nusmv_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: ml2/grpc/nusmv/nusmv.proto 4 | # Protobuf Python Version: 4.25.1 5 | """Generated protocol buffer code.""" 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import descriptor_pool as _descriptor_pool 8 | from google.protobuf import symbol_database as _symbol_database 9 | from google.protobuf.internal import builder as _builder 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | from ml2.grpc.ltl import ltl_mc_pb2 as ml2_dot_grpc_dot_ltl_dot_ltl__mc__pb2 16 | from ml2.grpc.tools import tools_pb2 as ml2_dot_grpc_dot_tools_dot_tools__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1aml2/grpc/nusmv/nusmv.proto\x1a\x19ml2/grpc/ltl/ltl_mc.proto\x1a\x1aml2/grpc/tools/tools.proto2\xd8\x01\n\x05NuSMV\x12(\n\x05Setup\x12\r.SetupRequest\x1a\x0e.SetupResponse\"\x00\x12=\n\x08Identify\x12\x16.IdentificationRequest\x1a\x17.IdentificationResponse\"\x00\x12-\n\nModelCheck\x12\r.LTLMCProblem\x1a\x0e.LTLMCSolution\"\x00\x12\x37\n\x10ModelCheckStream\x12\r.LTLMCProblem\x1a\x0e.LTLMCSolution\"\x00(\x01\x30\x01\x62\x06proto3') 20 | 21 | _globals = globals() 22 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 23 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ml2.grpc.nusmv.nusmv_pb2', _globals) 24 | if _descriptor._USE_C_DESCRIPTORS == False: 25 | DESCRIPTOR._options = None 26 | _globals['_NUSMV']._serialized_start=86 27 | _globals['_NUSMV']._serialized_end=302 28 | # @@protoc_insertion_point(module_scope) 29 | -------------------------------------------------------------------------------- /ml2/grpc/nusmv/nusmv_pb2.pyi: -------------------------------------------------------------------------------- 1 | """ 2 | @generated by mypy-protobuf. Do not edit manually! 3 | isort:skip_file 4 | """ 5 | 6 | import google.protobuf.descriptor 7 | 8 | DESCRIPTOR: google.protobuf.descriptor.FileDescriptor 9 | -------------------------------------------------------------------------------- /ml2/grpc/nuxmv/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactive-systems/ml2/33d9696c94de6d27aa836ae8118118a7277ff35c/ml2/grpc/nuxmv/__init__.py -------------------------------------------------------------------------------- /ml2/grpc/nuxmv/nuxmv.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "ml2/grpc/ltl/ltl_mc.proto"; 4 | import "ml2/grpc/tools/tools.proto"; 5 | 6 | // The NuXmv model checker. https://nuxmv.fbk.eu/ 7 | service Nuxmv { 8 | // Setup call, which is typically called before the first model checking call 9 | // has happened. 10 | rpc Setup(SetupRequest) returns (SetupResponse) {} 11 | // Call to find out the identity and functionality of the server, i.e. 
the 12 | // tool that is running the server and what it is supposed to do. 13 | rpc Identify(IdentificationRequest) returns (IdentificationResponse) {} 14 | // Call to model-check a single problem 15 | rpc ModelCheck(LTLMCProblem) returns (LTLMCSolution) {} 16 | // Call to model-check a stream of problems. Same order of problems 17 | // and solutions is assumed 18 | rpc ModelCheckStream(stream LTLMCProblem) returns (stream LTLMCSolution) {} 19 | } -------------------------------------------------------------------------------- /ml2/grpc/nuxmv/nuxmv_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: ml2/grpc/nuxmv/nuxmv.proto 4 | # Protobuf Python Version: 4.25.1 5 | """Generated protocol buffer code.""" 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import descriptor_pool as _descriptor_pool 8 | from google.protobuf import symbol_database as _symbol_database 9 | from google.protobuf.internal import builder as _builder 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | from ml2.grpc.ltl import ltl_mc_pb2 as ml2_dot_grpc_dot_ltl_dot_ltl__mc__pb2 16 | from ml2.grpc.tools import tools_pb2 as ml2_dot_grpc_dot_tools_dot_tools__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1aml2/grpc/nuxmv/nuxmv.proto\x1a\x19ml2/grpc/ltl/ltl_mc.proto\x1a\x1aml2/grpc/tools/tools.proto2\xd8\x01\n\x05Nuxmv\x12(\n\x05Setup\x12\r.SetupRequest\x1a\x0e.SetupResponse\"\x00\x12=\n\x08Identify\x12\x16.IdentificationRequest\x1a\x17.IdentificationResponse\"\x00\x12-\n\nModelCheck\x12\r.LTLMCProblem\x1a\x0e.LTLMCSolution\"\x00\x12\x37\n\x10ModelCheckStream\x12\r.LTLMCProblem\x1a\x0e.LTLMCSolution\"\x00(\x01\x30\x01\x62\x06proto3') 20 | 21 | _globals = globals() 22 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 23 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ml2.grpc.nuxmv.nuxmv_pb2', _globals) 24 | if _descriptor._USE_C_DESCRIPTORS == False: 25 | DESCRIPTOR._options = None 26 | _globals['_NUXMV']._serialized_start=86 27 | _globals['_NUXMV']._serialized_end=302 28 | # @@protoc_insertion_point(module_scope) 29 | -------------------------------------------------------------------------------- /ml2/grpc/nuxmv/nuxmv_pb2.pyi: -------------------------------------------------------------------------------- 1 | """ 2 | @generated by mypy-protobuf. Do not edit manually! 3 | isort:skip_file 4 | """ 5 | 6 | import google.protobuf.descriptor 7 | 8 | DESCRIPTOR: google.protobuf.descriptor.FileDescriptor 9 | -------------------------------------------------------------------------------- /ml2/grpc/prop/__init__.py: -------------------------------------------------------------------------------- 1 | from . import prop_pb2, prop_pb2_grpc 2 | -------------------------------------------------------------------------------- /ml2/grpc/prop/prop_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
2 | """Client and server classes corresponding to protobuf-defined services.""" 3 | import grpc 4 | 5 | -------------------------------------------------------------------------------- /ml2/grpc/semml/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactive-systems/ml2/33d9696c94de6d27aa836ae8118118a7277ff35c/ml2/grpc/semml/__init__.py -------------------------------------------------------------------------------- /ml2/grpc/semml/semml.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "ml2/grpc/ltl/ltl_syn.proto"; 4 | import "ml2/grpc/tools/tools.proto"; 5 | 6 | // SemML: Enhancing Automata-Theoretic LTL Synthesis with Machine Learning 7 | service Semml { 8 | // Setup call, which is typically called before the first model checking call 9 | // has happened. 10 | rpc Setup(SetupRequest) returns (SetupResponse) {} 11 | // Call to find out the identity and functionality of the server, i.e. the 12 | // tool that is running the server and what it is supposed to do. 13 | rpc Identify(IdentificationRequest) returns (IdentificationResponse) {} 14 | // Call to synthesize a single LTL specification 15 | rpc Synthesize(LTLSynProblem) returns (LTLSynSolution) {} 16 | // Call to synthesize a stream of LTL specifications. Same order of problems 17 | // and solutions is assumed 18 | rpc SynthesizeStream(stream LTLSynProblem) returns (stream LTLSynSolution) {} 19 | } -------------------------------------------------------------------------------- /ml2/grpc/semml/semml_pb2.pyi: -------------------------------------------------------------------------------- 1 | """ 2 | @generated by mypy-protobuf. Do not edit manually! 3 | isort:skip_file 4 | """ 5 | 6 | import google.protobuf.descriptor 7 | 8 | DESCRIPTOR: google.protobuf.descriptor.FileDescriptor 9 | -------------------------------------------------------------------------------- /ml2/grpc/spot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactive-systems/ml2/33d9696c94de6d27aa836ae8118118a7277ff35c/ml2/grpc/spot/__init__.py -------------------------------------------------------------------------------- /ml2/grpc/strix/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactive-systems/ml2/33d9696c94de6d27aa836ae8118118a7277ff35c/ml2/grpc/strix/__init__.py -------------------------------------------------------------------------------- /ml2/grpc/strix/strix.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "ml2/grpc/ltl/ltl_syn.proto"; 4 | import "ml2/grpc/tools/tools.proto"; 5 | 6 | // Strix is a tool for reactive LTL synthesis combining a direct translation of 7 | // LTL formulas into deterministic parity automata (DPA) and an efficient, 8 | // multi-threaded explicit state solver for parity games. 9 | service Strix { 10 | // Setup call, which is typically called before the first model checking call 11 | // has happened. 12 | rpc Setup(SetupRequest) returns (SetupResponse) {} 13 | // Call to find out the identity and functionality of the server, i.e. the 14 | // tool that is running the server and what it is supposed to do. 
15 | rpc Identify(IdentificationRequest) returns (IdentificationResponse) {} 16 | // Call to synthesize a single LTL specification 17 | rpc Synthesize(LTLSynProblem) returns (LTLSynSolution) {} 18 | // Call to synthesize a stream of LTL specifications. Same order of problems 19 | // and solutions is assumed 20 | rpc SynthesizeStream(stream LTLSynProblem) returns (stream LTLSynSolution) {} 21 | } -------------------------------------------------------------------------------- /ml2/grpc/strix/strix_pb2.pyi: -------------------------------------------------------------------------------- 1 | """ 2 | @generated by mypy-protobuf. Do not edit manually! 3 | isort:skip_file 4 | """ 5 | 6 | import google.protobuf.descriptor 7 | 8 | DESCRIPTOR: google.protobuf.descriptor.FileDescriptor 9 | -------------------------------------------------------------------------------- /ml2/grpc/syfco/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactive-systems/ml2/33d9696c94de6d27aa836ae8118118a7277ff35c/ml2/grpc/syfco/__init__.py -------------------------------------------------------------------------------- /ml2/grpc/syfco/syfco.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "google/protobuf/duration.proto"; 4 | import "ml2/grpc/ltl/ltl.proto"; 5 | import "ml2/grpc/tools/tools.proto"; 6 | 7 | message TLSFFileString { string tlsf = 1; } 8 | 9 | message ConvertTLSFToSpecRequest { 10 | // Defines run- and tool-specific parameters. As Map (Dict in Python). 11 | // Typical examples are threads, timeouts etc. Can be empty. 12 | map<string, string> parameters = 1; 13 | // A string, read from a TLSF file 14 | TLSFFileString tlsf = 2; 15 | } 16 | 17 | message ConvertTLSFToSpecResponse { 18 | // The resulting decomposed LTL specification 19 | optional DecompLTLSpecification specification = 1; 20 | // Here additional information should be supplied if something went wrong 21 | string error = 2; 22 | // Tool that created the response 23 | string tool = 3; 24 | // How long the tool took to create the result. 25 | optional google.protobuf.Duration time = 4; 26 | } 27 | 28 | // Syfco: A tool for reading, manipulating and transforming synthesis 29 | // specifications in TLSF. 30 | service Syfco { 31 | // Setup call, which is typically called before the first model checking call 32 | // has happened. 33 | rpc Setup(SetupRequest) returns (SetupResponse) {} 34 | // Call to find out the identity and functionality of the server, i.e. the 35 | // tool that is running the server and what it is supposed to do.
36 | rpc Identify(IdentificationRequest) returns (IdentificationResponse) {} 37 | // Call to convert a TLSF specification into a decomposed LTL specification 38 | rpc ConvertTLSFToSpec(ConvertTLSFToSpecRequest) 39 | returns (ConvertTLSFToSpecResponse) {} 40 | } -------------------------------------------------------------------------------- /ml2/grpc/system/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactive-systems/ml2/33d9696c94de6d27aa836ae8118118a7277ff35c/ml2/grpc/system/__init__.py -------------------------------------------------------------------------------- /ml2/grpc/system/system.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | // All available system types 4 | // Can easily be extended without breaking backwards compatibility 5 | enum System { 6 | SYSTEM_AIGER = 0; 7 | SYSTEM_MEALY = 1; 8 | } -------------------------------------------------------------------------------- /ml2/grpc/system/system_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: ml2/grpc/system/system.proto 4 | # Protobuf Python Version: 4.25.1 5 | """Generated protocol buffer code.""" 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import descriptor_pool as _descriptor_pool 8 | from google.protobuf import symbol_database as _symbol_database 9 | from google.protobuf.internal import builder as _builder 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | 16 | 17 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1cml2/grpc/system/system.proto*,\n\x06System\x12\x10\n\x0cSYSTEM_AIGER\x10\x00\x12\x10\n\x0cSYSTEM_MEALY\x10\x01\x62\x06proto3') 18 | 19 | _globals = globals() 20 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 21 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ml2.grpc.system.system_pb2', _globals) 22 | if _descriptor._USE_C_DESCRIPTORS == False: 23 | DESCRIPTOR._options = None 24 | _globals['_SYSTEM']._serialized_start=32 25 | _globals['_SYSTEM']._serialized_end=76 26 | # @@protoc_insertion_point(module_scope) 27 | -------------------------------------------------------------------------------- /ml2/grpc/system/system_pb2.pyi: -------------------------------------------------------------------------------- 1 | """ 2 | @generated by mypy-protobuf. Do not edit manually!
3 | isort:skip_file 4 | """ 5 | 6 | import builtins 7 | import google.protobuf.descriptor 8 | import google.protobuf.internal.enum_type_wrapper 9 | import sys 10 | import typing 11 | 12 | if sys.version_info >= (3, 10): 13 | import typing as typing_extensions 14 | else: 15 | import typing_extensions 16 | 17 | DESCRIPTOR: google.protobuf.descriptor.FileDescriptor 18 | 19 | class _System: 20 | ValueType = typing.NewType("ValueType", builtins.int) 21 | V: typing_extensions.TypeAlias = ValueType 22 | 23 | class _SystemEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_System.ValueType], builtins.type): 24 | DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor 25 | SYSTEM_AIGER: _System.ValueType # 0 26 | SYSTEM_MEALY: _System.ValueType # 1 27 | 28 | class System(_System, metaclass=_SystemEnumTypeWrapper): 29 | """All available system types 30 | Can easily be extended without breaking backwards compatibility 31 | """ 32 | 33 | SYSTEM_AIGER: System.ValueType # 0 34 | SYSTEM_MEALY: System.ValueType # 1 35 | global___System = System 36 | -------------------------------------------------------------------------------- /ml2/grpc/system/system_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | """Client and server classes corresponding to protobuf-defined services.""" 3 | import grpc 4 | 5 | -------------------------------------------------------------------------------- /ml2/grpc/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactive-systems/ml2/33d9696c94de6d27aa836ae8118118a7277ff35c/ml2/grpc/tools/__init__.py -------------------------------------------------------------------------------- /ml2/grpc/tools/tools.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | message SetupRequest { 4 | // Defines tool-specific parameters. As Map (Dict in Python). 5 | // Typical examples are memory limits etc. Can be empty. 6 | map<string, string> parameters = 1; 7 | } 8 | 9 | message SetupResponse { 10 | bool success = 1; 11 | // If success is false, this should contain further information. 12 | string error = 2; 13 | } 14 | 15 | message IdentificationRequest {} 16 | 17 | // Announces itself by giving the name and version of the tool running on 18 | // the grpc server 19 | message IdentificationResponse { 20 | // what tool is running on the grpc server 21 | string tool = 1; 22 | // the purpose of the grpc server. A grpc server can have multiple 23 | // functionalities. 24 | repeated Functionality functionalities = 2; 25 | // the version the grpc server is currently running 26 | string version = 3; 27 | } 28 | 29 | // All available purposes of tools in ML2 / NeuroSynt 30 | // Can easily be extended to more purposes without breaking backwards 31 | // compatibility.
32 | enum Functionality { 33 | FUNCTIONALITY_OTHER = 0; 34 | FUNCTIONALITY_LTL_AIGER_MODELCHECKING = 1; 35 | FUNCTIONALITY_LTL_MEALY_MODELCHECKING = 2; 36 | FUNCTIONALITY_LTL_AIGER_SYNTHESIS = 3; 37 | FUNCTIONALITY_LTL_MEALY_SYNTHESIS = 4; 38 | FUNCTIONALITY_LTL_EQUIVALENCE = 5; 39 | FUNCTIONALITY_LTL_TRACE_MODELCHECKING = 6; 40 | FUNCTIONALITY_RANDLTL = 7; 41 | FUNCTIONALITY_AIGER_TO_MEALY = 8; 42 | FUNCTIONALITY_MEALY_TO_AIGER = 9; 43 | FUNCTIONALITY_TLSF_TO_SPEC = 10; 44 | FUNCTIONALITY_NEURAL_LTL_AIGER_SYNTHESIS = 11; 45 | } 46 | 47 | -------------------------------------------------------------------------------- /ml2/grpc/tools/tools_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | """Client and server classes corresponding to protobuf-defined services.""" 3 | import grpc 4 | 5 | -------------------------------------------------------------------------------- /ml2/grpc/trace/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactive-systems/ml2/33d9696c94de6d27aa836ae8118118a7277ff35c/ml2/grpc/trace/__init__.py -------------------------------------------------------------------------------- /ml2/grpc/trace/trace.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | message Trace { string trace = 1; } -------------------------------------------------------------------------------- /ml2/grpc/trace/trace_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: ml2/grpc/trace/trace.proto 4 | # Protobuf Python Version: 4.25.1 5 | """Generated protocol buffer code.""" 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import descriptor_pool as _descriptor_pool 8 | from google.protobuf import symbol_database as _symbol_database 9 | from google.protobuf.internal import builder as _builder 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | 16 | 17 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1aml2/grpc/trace/trace.proto\"\x16\n\x05Trace\x12\r\n\x05trace\x18\x01 \x01(\tb\x06proto3') 18 | 19 | _globals = globals() 20 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 21 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ml2.grpc.trace.trace_pb2', _globals) 22 | if _descriptor._USE_C_DESCRIPTORS == False: 23 | DESCRIPTOR._options = None 24 | _globals['_TRACE']._serialized_start=30 25 | _globals['_TRACE']._serialized_end=52 26 | # @@protoc_insertion_point(module_scope) 27 | -------------------------------------------------------------------------------- /ml2/grpc/trace/trace_pb2.pyi: -------------------------------------------------------------------------------- 1 | """ 2 | @generated by mypy-protobuf. Do not edit manually! 
3 | isort:skip_file 4 | """ 5 | 6 | import builtins 7 | import google.protobuf.descriptor 8 | import google.protobuf.message 9 | import typing 10 | 11 | DESCRIPTOR: google.protobuf.descriptor.FileDescriptor 12 | 13 | @typing.final 14 | class Trace(google.protobuf.message.Message): 15 | DESCRIPTOR: google.protobuf.descriptor.Descriptor 16 | 17 | TRACE_FIELD_NUMBER: builtins.int 18 | trace: builtins.str 19 | def __init__( 20 | self, 21 | *, 22 | trace: builtins.str = ..., 23 | ) -> None: ... 24 | def ClearField(self, field_name: typing.Literal["trace", b"trace"]) -> None: ... 25 | 26 | global___Trace = Trace 27 | -------------------------------------------------------------------------------- /ml2/grpc/trace/trace_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | """Client and server classes corresponding to protobuf-defined services.""" 3 | import grpc 4 | 5 | -------------------------------------------------------------------------------- /ml2/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactive-systems/ml2/33d9696c94de6d27aa836ae8118118a7277ff35c/ml2/layers/__init__.py -------------------------------------------------------------------------------- /ml2/layers/positional_encoding.py: -------------------------------------------------------------------------------- 1 | """Sinusoidal positional encoding described in 'Attention Is All You Need' (Vaswani et al., 2017)""" 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | def get_angles(position: int, i: int, d_embedding: int): 8 | """ 9 | Args: 10 | position: int, position 11 | i: int, dimension 12 | d_embedding: int, embedding dimension 13 | """ 14 | angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_embedding)) 15 | return position * angle_rates 16 | 17 | 18 | def positional_encoding(position: int, d_embedding: int): 19 | """ 20 | Returns a sinusoidal positional encoding 21 | Args: 22 | position: int, position 23 | d_embedding: int, embedding dimension 24 | """ 25 | angle_rads = get_angles( 26 | np.arange(position)[:, np.newaxis], np.arange(d_embedding)[np.newaxis, :], d_embedding 27 | ) 28 | angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2]) 29 | angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2]) 30 | pos_encoding = angle_rads[np.newaxis, ...]
31 | return tf.cast(pos_encoding, dtype=tf.float32) 32 | -------------------------------------------------------------------------------- /ml2/loading.py: -------------------------------------------------------------------------------- 1 | """Utility to load artifact""" 2 | 3 | from typing import Type, Union 4 | 5 | from .artifact import Artifact 6 | from .configurable import Configurable 7 | from .registry import type_from_str 8 | 9 | 10 | def get_artifact_type(config: Union[str, dict]) -> Type: 11 | if isinstance(config, str): 12 | config = Artifact.fetch_config(config) 13 | if "type" in config: 14 | config_type = config["type"] 15 | elif "base" in config: 16 | base_config = Artifact.fetch_config(config["base"]) 17 | config_type = base_config["type"] 18 | else: 19 | raise Exception("Could not determine artifact type") 20 | 21 | if isinstance(config_type, str): 22 | config_type = type_from_str(config_type, bound=Configurable) 23 | 24 | return config_type 25 | 26 | 27 | def load_artifact(config: Union[str, dict], **kwargs) -> Artifact: 28 | config_type = get_artifact_type(config) 29 | return config_type.from_config(config, **kwargs) 30 | -------------------------------------------------------------------------------- /ml2/ltl/__init__.py: -------------------------------------------------------------------------------- 1 | from ..utils import is_pt_available, is_tf_available 2 | from . import ltl_equiv, ltl_formula, ltl_mc, ltl_sat, ltl_syn 3 | from .ltl_formula import DecompLTLFormula, LTLFormula 4 | from .ltl_lexer import lex_ltl 5 | from .ltl_parser import ( 6 | LTLInfixParser, 7 | LTLPrefixParser, 8 | parse_infix_ltl, 9 | parse_ltl, 10 | parse_prefix_ltl, 11 | ) 12 | from .ltl_spec import ( 13 | DecompLTLSpec, 14 | LTLAssumptions, 15 | LTLGuarantees, 16 | LTLSpec, 17 | LTLSpecDataset, 18 | LTLSpecPatternDataset, 19 | ) 20 | 21 | if is_pt_available() and is_tf_available(): 22 | from .ltl_spec import LTLSpecToSeqTokenizer 23 | -------------------------------------------------------------------------------- /ml2/ltl/ltl_equiv/__init__.py: -------------------------------------------------------------------------------- 1 | from .decomp_ltl_equiv_problem import DecompLTLEquivProblem 2 | from .ltl_equiv_problem import LTLEquivProblem 3 | from .ltl_equiv_status import LTLEquivStatus 4 | from .ltl_incl_status import LTLInclStatus 5 | -------------------------------------------------------------------------------- /ml2/ltl/ltl_mc/__init__.py: -------------------------------------------------------------------------------- 1 | from .decomp_ltl_mc_problem import DecompLTLMCProblem 2 | from .ltl_mc_problem import LTLMCProblem, LTLMCSolution 3 | from .ltl_mc_status import LTLMCStatus 4 | -------------------------------------------------------------------------------- /ml2/ltl/ltl_sat/__init__.py: -------------------------------------------------------------------------------- 1 | from .decomp_ltl_sym_trace_problem import DecompLTLSymTraceProblem 2 | from .ltl_sat_dataset import LTLSatDataset 3 | from .ltl_sat_problem import LTLSatProblem 4 | from .ltl_sat_status import LTLSatStatus 5 | from .ltl_sat_sym_trace_problem import LTLSatSymTraceProblem, LTLSatSymTraceSolution 6 | from .ltl_sat_trace_problem import LTLSatTraceProblem, LTLSatTraceSolution 7 | from .ltl_sym_trace_problem import LTLSymTraceProblem 8 | from .ltl_trace_problem import LTLTraceProblem 9 | -------------------------------------------------------------------------------- /ml2/ltl/ltl_sat/decomp_ltl_sym_trace_problem.py: 
-------------------------------------------------------------------------------- 1 | """Decomposed LTL symbolic trace generation problem""" 2 | 3 | from typing import Dict, List 4 | 5 | from ...dtypes import CSV, Supervised 6 | from ...registry import register_type 7 | from ...trace import SymbolicTrace 8 | from ..ltl_formula import DecompLTLFormula 9 | 10 | 11 | @register_type 12 | class DecompLTLSymTraceProblem(CSV, Supervised[DecompLTLFormula, SymbolicTrace]): 13 | def __init__(self, formula: DecompLTLFormula, trace: SymbolicTrace) -> None: 14 | self.formula = formula 15 | self.trace = trace 16 | 17 | @property 18 | def input(self) -> DecompLTLFormula: 19 | return self.formula 20 | 21 | @property 22 | def target(self) -> SymbolicTrace: 23 | return self.trace 24 | 25 | def _to_csv_fields(self, notation: str = None, **kwargs) -> Dict[str, str]: 26 | formula_fields = self.formula.to_csv_fields(notation=notation, **kwargs) 27 | trace_fields = self.trace.to_csv_fields(**kwargs) 28 | return {**formula_fields, **trace_fields} 29 | 30 | @classmethod 31 | def _csv_field_header(cls, **kwargs) -> List[str]: 32 | return list( 33 | set( 34 | DecompLTLFormula.csv_field_header(**kwargs) 35 | + SymbolicTrace.csv_field_header(**kwargs) 36 | ) 37 | ) 38 | 39 | @classmethod 40 | def _from_csv_fields(cls, fields: Dict[str, str], **kwargs) -> "DecompLTLSymTraceProblem": 41 | formula = DecompLTLFormula.from_csv_fields(fields, **kwargs) 42 | trace = SymbolicTrace.from_csv_fields(fields, **kwargs) 43 | return cls(formula, trace) 44 | -------------------------------------------------------------------------------- /ml2/ltl/ltl_sat/ltl_sat_problem.py: -------------------------------------------------------------------------------- 1 | """LTL satisfiability problem""" 2 | 3 | from typing import Dict, List 4 | 5 | from ...dtypes import CSV, Supervised 6 | from ...registry import register_type 7 | from ..ltl_formula import LTLFormula 8 | from .ltl_sat_status import LTLSatStatus 9 | 10 | 11 | @register_type 12 | class LTLSatProblem(CSV, Supervised[LTLFormula, LTLSatStatus]): 13 | def __init__(self, formula: LTLFormula, status: LTLSatStatus) -> None: 14 | self.formula = formula 15 | self.status = status 16 | 17 | @property 18 | def input(self) -> LTLFormula: 19 | return self.formula 20 | 21 | @property 22 | def target(self) -> LTLSatStatus: 23 | return self.status 24 | 25 | def _to_csv_fields(self, notation: str = None, **kwargs) -> Dict[str, str]: 26 | formula_fields = self.formula.to_csv_fields(notation=notation, **kwargs) 27 | status_fields = self.status.to_csv_fields(**kwargs) 28 | return {**formula_fields, **status_fields} 29 | 30 | @classmethod 31 | def _csv_field_header(cls, **kwargs) -> List[str]: 32 | return list( 33 | set(LTLFormula.csv_field_header(**kwargs) + LTLSatStatus.csv_field_header(**kwargs)) 34 | ) 35 | 36 | @classmethod 37 | def _from_csv_fields(cls, fields: Dict[str, str], **kwargs) -> "LTLSatProblem": 38 | formula = LTLFormula.from_csv_fields(fields, **kwargs) 39 | status = LTLSatStatus.from_csv_fields(fields, **kwargs) 40 | return cls(formula, status) 41 | -------------------------------------------------------------------------------- /ml2/ltl/ltl_sat/ltl_sat_status.py: -------------------------------------------------------------------------------- 1 | """Status of an LTL satisfiability problem""" 2 | 3 | from typing import Dict, List 4 | 5 | from ...dtypes.cat import Cat 6 | from ...dtypes.csv_dtype import CSV 7 | from ...registry import register_type 8 | 9 | LTL_SAT_STATUS_TO_INT = { 10 | 
"satisfiable": 1, 11 | "unsatisfiable": 0, 12 | "error": -1, 13 | "timeout": -2, 14 | } 15 | 16 | INT_TO_LTL_SAT_STATUS = {v: k for k, v in LTL_SAT_STATUS_TO_INT.items()} 17 | 18 | 19 | @register_type 20 | class LTLSatStatus(Cat, CSV): 21 | def __init__(self, status: str) -> None: 22 | if status not in ["satisfiable", "unsatisfiable", "timeout", "error"]: 23 | raise ValueError(f"Invalid status {status}") 24 | self._status = status 25 | 26 | def token(self, **kwargs) -> str: 27 | return self._status 28 | 29 | def _to_csv_fields(self, **kwargs) -> Dict[str, str]: 30 | return {"sat": LTL_SAT_STATUS_TO_INT[self._status]} 31 | 32 | @classmethod 33 | def _csv_field_header(cls, **kwargs) -> List[str]: 34 | return ["sat"] 35 | 36 | @classmethod 37 | def _from_csv_fields(cls, fields: Dict[str, str], **kwargs) -> "LTLSatStatus": 38 | return cls(status=INT_TO_LTL_SAT_STATUS[int(fields["sat"])]) 39 | 40 | @classmethod 41 | def from_token(cls, token: str, **kwargs) -> "LTLSatStatus": 42 | return cls(status=token) 43 | -------------------------------------------------------------------------------- /ml2/ltl/ltl_sat/ltl_sym_trace_problem.py: -------------------------------------------------------------------------------- 1 | """LTL symbolic trace generation problem""" 2 | 3 | from typing import Dict, List 4 | 5 | from ...dtypes import CSV, Supervised 6 | from ...registry import register_type 7 | from ...trace import SymbolicTrace 8 | from ..ltl_formula import LTLFormula 9 | 10 | 11 | @register_type 12 | class LTLSymTraceProblem(CSV, Supervised[LTLFormula, SymbolicTrace]): 13 | def __init__(self, formula: LTLFormula, trace: SymbolicTrace) -> None: 14 | self.formula = formula 15 | self.trace = trace 16 | 17 | @property 18 | def input(self) -> LTLFormula: 19 | return self.formula 20 | 21 | @property 22 | def target(self) -> SymbolicTrace: 23 | return self.trace 24 | 25 | def _to_csv_fields(self, notation: str = None, **kwargs) -> Dict[str, str]: 26 | formula_fields = self.formula.to_csv_fields(notation=notation, **kwargs) 27 | trace_fields = self.trace.to_csv_fields(**kwargs) 28 | return {**formula_fields, **trace_fields} 29 | 30 | @classmethod 31 | def _csv_field_header(cls, **kwargs) -> List[str]: 32 | return list( 33 | set(LTLFormula.csv_field_header(**kwargs) + SymbolicTrace.csv_field_header(**kwargs)) 34 | ) 35 | 36 | @classmethod 37 | def _from_csv_fields(cls, fields: Dict[str, str], **kwargs) -> "LTLSymTraceProblem": 38 | formula = LTLFormula.from_csv_fields(fields, **kwargs) 39 | trace = SymbolicTrace.from_csv_fields(fields, **kwargs) 40 | return cls(formula, trace) 41 | -------------------------------------------------------------------------------- /ml2/ltl/ltl_sat/ltl_trace_problem.py: -------------------------------------------------------------------------------- 1 | """LTL trace generation problem""" 2 | 3 | from typing import Dict, List 4 | 5 | from ...dtypes import CSV, Supervised 6 | from ...registry import register_type 7 | from ...trace import Trace 8 | from ..ltl_formula import LTLFormula 9 | 10 | 11 | @register_type 12 | class LTLTraceProblem(CSV, Supervised[LTLFormula, Trace]): 13 | def __init__(self, formula: LTLFormula, trace: Trace) -> None: 14 | self.formula = formula 15 | self.trace = trace 16 | 17 | @property 18 | def input(self) -> LTLFormula: 19 | return self.formula 20 | 21 | @property 22 | def target(self) -> Trace: 23 | return self.trace 24 | 25 | def _to_csv_fields(self, notation: str = None, **kwargs) -> Dict[str, str]: 26 | formula_fields = 
self.formula.to_csv_fields(notation=notation, **kwargs) 27 | trace_fields = self.trace.to_csv_fields(**kwargs) 28 | return {**formula_fields, **trace_fields} 29 | 30 | @classmethod 31 | def _csv_field_header(cls, **kwargs) -> List[str]: 32 | return list(set(LTLFormula.csv_field_header(**kwargs) + Trace.csv_field_header(**kwargs))) 33 | 34 | @classmethod 35 | def _from_csv_fields(cls, fields: Dict[str, str], **kwargs) -> "LTLTraceProblem": 36 | formula = LTLFormula.from_csv_fields(fields, **kwargs) 37 | trace = Trace.from_csv_fields(fields, **kwargs) 38 | return cls(formula, trace) 39 | -------------------------------------------------------------------------------- /ml2/ltl/ltl_spec/__init__.py: -------------------------------------------------------------------------------- 1 | from ...utils import is_pt_available, is_tf_available 2 | from .decomp_ltl_spec import DecompLTLSpec, LTLAssumptions, LTLGuarantees, LTLProperties 3 | from .ltl_spec import LTLSpec 4 | from .ltl_spec_csv_dataset import LTLSpecCSVDataset 5 | from .ltl_spec_dataset import LTLSpecDataset 6 | from .ltl_spec_pattern_csv_dataset import LTLSpecPatternCSVDataset 7 | from .ltl_spec_pattern_dataset import LTLSpecPatternDataset 8 | 9 | if is_pt_available() and is_tf_available(): 10 | from .decomp_ltl_spec_tokenizer import DecompLTLSpecToSeqTPETokenizer 11 | from .ltl_spec_tokenizer import LTLSpecToSeqTokenizer, LTLSpecToSeqTPETokenizer 12 | -------------------------------------------------------------------------------- /ml2/ltl/ltl_spec/ltl_spec_patterns/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reactive-systems/ml2/33d9696c94de6d27aa836ae8118118a7277ff35c/ml2/ltl/ltl_spec/ltl_spec_patterns/__init__.py -------------------------------------------------------------------------------- /ml2/ltl/ltl_spec/ltl_spec_tokenizer.py: -------------------------------------------------------------------------------- 1 | """LTL specification tokenizer""" 2 | 3 | 4 | from copy import deepcopy 5 | from typing import Type 6 | 7 | from ...registry import register_type 8 | from ...tokenizers import ExprToSeqTokenizer, ExprToSeqTPETokenizer 9 | from .ltl_spec import LTLSpec 10 | 11 | 12 | class LTLSpecToSeqTokenizer(ExprToSeqTokenizer[LTLSpec]): 13 | def __init__(self, rename_aps: bool = False, dtype: Type[LTLSpec] = LTLSpec, **kwargs): 14 | self.rename_aps = rename_aps 15 | super().__init__(dtype=dtype, **kwargs) 16 | 17 | def preprocess_sample(self, x: LTLSpec) -> LTLSpec: 18 | if self.rename_aps: 19 | x = deepcopy(x) 20 | x.rename_aps(random=False) 21 | return x 22 | 23 | 24 | @register_type 25 | class LTLSpecToSeqTPETokenizer(ExprToSeqTPETokenizer[LTLSpec]): 26 | def __init__(self, rename_aps: bool = False, dtype: Type[LTLSpec] = LTLSpec, **kwargs): 27 | self.rename_aps = rename_aps 28 | super().__init__(dtype=dtype, **kwargs) 29 | 30 | def preprocess_sample(self, x: LTLSpec) -> LTLSpec: 31 | if self.rename_aps: 32 | x = deepcopy(x) 33 | x.rename_aps(random=False) 34 | return x 35 | -------------------------------------------------------------------------------- /ml2/ltl/ltl_syn/__init__.py: -------------------------------------------------------------------------------- 1 | from ...utils import is_pt_available, is_tf_available 2 | from .decomp_ltl_syn_problem import DecompLTLSynProblem 3 | from .ltl_real_status import LTLRealStatus 4 | from .ltl_syn_dataset import LTLSynDataset, LTLSynSplitDataset 5 | from .ltl_syn_eval_dataset import LTLSynEvalDataset 6 | 
from .ltl_syn_problem import LTLSynProblem, LTLSynSolution 7 | from .ltl_syn_status import LTLSynStatus 8 | 9 | if is_pt_available() and is_tf_available(): 10 | from .ltl_syn_solution_tokenizer import LTLSynSolutionToSeqTokenizer 11 | from .tf_syn_hier_transformer_pipeline import TFSynHierTransformerPipeline 12 | -------------------------------------------------------------------------------- /ml2/ltl/ltl_syn/tf_syn_hier_transformer_pipeline.py: -------------------------------------------------------------------------------- 1 | """TensorFlow LTL synthesis hierarchical Transformer pipeline""" 2 | 3 | import logging 4 | from typing import Optional, TypeVar 5 | 6 | from ...dtypes import DType 7 | from ...pipelines.tf_hier_transformer_pipeline import TFHierTransformerPipeline 8 | from ...registry import register_type 9 | from ..ltl_spec.ltl_spec import LTLSpec 10 | 11 | I = TypeVar("I", bound=LTLSpec) 12 | T = TypeVar("T", bound=DType) 13 | 14 | logging.basicConfig(level=logging.INFO) 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | @register_type 19 | class TFSynHierTransformerPipeline(TFHierTransformerPipeline[I, T]): 20 | def decode(self, prediction_encoding, input: Optional[I] = None) -> T: 21 | assert input is not None 22 | return self.target_tokenizer.decode( 23 | prediction_encoding, inputs=input.inputs, outputs=input.outputs 24 | ) 25 | -------------------------------------------------------------------------------- /ml2/mealy/__init__.py: -------------------------------------------------------------------------------- 1 | from ..utils import is_pt_available, is_tf_available 2 | from .mealy_machine import Condition, HoaHeader, MealyMachine, Transition 3 | 4 | if is_pt_available() and is_tf_available(): 5 | from .mealy_tokenizer import MealyToSeqTokenizer 6 | -------------------------------------------------------------------------------- /ml2/models/__init__.py: -------------------------------------------------------------------------------- 1 | from ..utils import is_tf_available 2 | 3 | if is_tf_available(): 4 | from . 
import tf_hierarchical_transformer, tf_transformer 5 | -------------------------------------------------------------------------------- /ml2/models/tf_transformer_metrics.py: -------------------------------------------------------------------------------- 1 | """TensorFlow Transformer metrics""" 2 | 3 | import tensorflow as tf 4 | 5 | 6 | class TransformerAccuracy(tf.keras.metrics.Metric): 7 | 8 | def __init__(self, name="acc", dtype_float: tf.DType = tf.float32, pad_id: int = 0, **kwargs): 9 | super().__init__(name=name, **kwargs) 10 | self.dtype_float = dtype_float 11 | self.pad_id = pad_id 12 | self.acc_mean = tf.keras.metrics.Mean("acc") 13 | self.acc_per_seq_mean = tf.keras.metrics.Mean("acc_per_seq") 14 | 15 | def update_state(self, y_true, y_pred, sample_weight=None): 16 | weights = tf.cast(tf.not_equal(y_true, self.pad_id), self.dtype_float) 17 | outputs = tf.cast(tf.argmax(y_pred, axis=-1), tf.int32) 18 | y_true = tf.cast(y_true, tf.int32) 19 | 20 | # accuracy 21 | correct_predictions = tf.cast(tf.equal(outputs, y_true), self.dtype_float) 22 | self.acc_mean.update_state(correct_predictions, weights) 23 | 24 | # accuracy per sequence 25 | incorrect_predictions = tf.cast(tf.not_equal(outputs, y_true), self.dtype_float) * weights 26 | correct_sequences = 1.0 - tf.minimum(1.0, tf.reduce_sum(incorrect_predictions, axis=-1)) 27 | self.acc_per_seq_mean.update_state(correct_sequences, tf.constant(1.0)) 28 | 29 | def reset_state(self): 30 | self.acc_mean.reset_state() 31 | self.acc_per_seq_mean.reset_state() 32 | 33 | def result(self): 34 | return { 35 | "acc": self.acc_mean.result(), 36 | "acc_per_seq": self.acc_per_seq_mean.result(), 37 | } 38 | -------------------------------------------------------------------------------- /ml2/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from ..utils import is_tf_available 2 | 3 | if is_tf_available(): 4 | from . import tf_optim 5 | -------------------------------------------------------------------------------- /ml2/optim/tf_optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .tf_optimizers import load_tf_optimizer_from_config, tf_optimizer_to_config 2 | from .tf_transformer_lr_schedule import TFTransformerLRSchedule 3 | -------------------------------------------------------------------------------- /ml2/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from ..utils import is_hf_available, is_pt_available, is_tf_available 2 | from . 
import loggers, samples 3 | 4 | if is_pt_available() and is_tf_available(): 5 | from .beam_search_verification_pipeline import BeamSearchVerificationPipeline 6 | from .load_pipeline import load_pipeline 7 | from .pipeline import EvalTask, Pipeline 8 | from .seq2seq_pipeline import Seq2SeqPipeline 9 | from .sl_pipeline import SLPipeline, SupervisedEvalTask 10 | from .tf_hier_transformer_pipeline import TFHierTransformerPipeline 11 | from .tf_pipeline import TFPipeline 12 | from .tf_sl_pipeline import TFSLPipeline 13 | from .tf_transformer_pipeline import TFTransformerPipeline 14 | from .verification_pipeline import VerificationPipeline 15 | 16 | if is_hf_available(): 17 | from .hf_pipelines import HFPTExpr2ExprPipeline, HFPTExpr2TextPipeline, HFPTText2TextPipeline 18 | -------------------------------------------------------------------------------- /ml2/pipelines/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .callback import Callback 2 | -------------------------------------------------------------------------------- /ml2/pipelines/callbacks/callback.py: -------------------------------------------------------------------------------- 1 | """Abstract Callback class""" 2 | 3 | 4 | from abc import abstractmethod 5 | from typing import Any, List 6 | 7 | from ...artifact import Artifact 8 | from ..samples import Sample 9 | 10 | 11 | class Callback(Artifact): 12 | @abstractmethod 13 | def add(self, sample: Sample, **kwargs) -> Any: 14 | raise NotImplementedError() 15 | 16 | def add_batch(self, sample_batch: List[Sample], **kwargs) -> Any: 17 | for sample in sample_batch: 18 | self.add(sample, **kwargs) 19 | -------------------------------------------------------------------------------- /ml2/pipelines/hf_pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .hf_pt_expr2expr_pipeline import HFPTExpr2ExprPipeline 2 | from .hf_pt_expr2text_pipeline import HFPTExpr2TextPipeline 3 | from .hf_pt_text2text_pipeline import HFPTText2TextPipeline 4 | -------------------------------------------------------------------------------- /ml2/pipelines/hf_pipelines/hf_pt_expr2expr_pipeline.py: -------------------------------------------------------------------------------- 1 | """HuggingFace PyTorch expression to expression pipeline""" 2 | 3 | import logging 4 | from typing import Generic, TypeVar 5 | 6 | from ...dtypes import String 7 | from ...registry import register_type 8 | from .hf_pt_text2text_pipeline import HFPTText2TextPipeline 9 | 10 | logging.basicConfig(level=logging.INFO) 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | I = TypeVar("I", bound=String) 15 | T = TypeVar("T", bound=String) 16 | 17 | 18 | @register_type 19 | class HFPTExpr2ExprPipeline(HFPTText2TextPipeline[I, T], Generic[I, T]): 20 | def __init__( 21 | self, 22 | input_notation: str = "infix", 23 | target_notation: str = "infix", 24 | input_kwargs: dict = None, 25 | target_kwargs: dict = None, 26 | **kwargs, 27 | ): 28 | self.input_notation = input_notation 29 | input_kwargs = input_kwargs if input_kwargs is not None else {} 30 | input_kwargs["notation"] = input_notation 31 | self.target_notation = target_notation 32 | target_kwargs = target_kwargs if target_kwargs is not None else {} 33 | target_kwargs["notation"] = target_notation 34 | super().__init__(input_kwargs=input_kwargs, target_kwargs=target_kwargs, **kwargs) 35 | -------------------------------------------------------------------------------- 
/ml2/pipelines/hf_pipelines/hf_pt_expr2text_pipeline.py: -------------------------------------------------------------------------------- 1 | """HuggingFace PyTorch expression to text pipeline""" 2 | 3 | import logging 4 | from typing import Generic, TypeVar 5 | 6 | from ...dtypes import String 7 | from ...registry import register_type 8 | from .hf_pt_text2text_pipeline import HFPTText2TextPipeline 9 | 10 | logging.basicConfig(level=logging.INFO) 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | I = TypeVar("I", bound=String) 15 | T = TypeVar("T", bound=String) 16 | 17 | 18 | @register_type 19 | class HFPTExpr2TextPipeline(HFPTText2TextPipeline[I, T], Generic[I, T]): 20 | def __init__( 21 | self, 22 | input_notation: str = "infix", 23 | input_kwargs: dict = None, 24 | **kwargs, 25 | ): 26 | self.input_notation = input_notation 27 | input_kwargs = input_kwargs if input_kwargs is not None else {} 28 | input_kwargs["notation"] = input_notation 29 | super().__init__(input_kwargs=input_kwargs, **kwargs) 30 | -------------------------------------------------------------------------------- /ml2/pipelines/load_pipeline.py: -------------------------------------------------------------------------------- 1 | """Utility to load pipeline""" 2 | 3 | 4 | import logging 5 | 6 | from .pipeline import Pipeline 7 | 8 | logging.basicConfig(level=logging.INFO) 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def load_pipeline(name: str, project: str = None, **kwargs) -> Pipeline: 13 | from ..registry import type_from_str 14 | 15 | config = Pipeline.fetch_config(name=name, project=project) 16 | if "type" not in config: 17 | raise Exception("Pipeline type not specified in config") 18 | pipeline_type = type_from_str(config["type"], bound=Pipeline) 19 | return pipeline_type.load(name=name, project=project, **kwargs) 20 | -------------------------------------------------------------------------------- /ml2/pipelines/loggers/__init__.py: -------------------------------------------------------------------------------- 1 | from .csv_dataset_logger import CSVToDatasetLogger 2 | from .csv_logger import CSVLogger 3 | from .sample_logger import SampleLogger 4 | -------------------------------------------------------------------------------- /ml2/pipelines/loggers/csv_dataset_logger.py: -------------------------------------------------------------------------------- 1 | """CSV Logger class""" 2 | 3 | 4 | from typing import Any 5 | 6 | from ...datasets import ContCSVDatasetWriter 7 | from ...dtypes import CSVDict 8 | from ...registry import register_type 9 | from ..samples import Sample 10 | from .csv_logger import CSVLogger 11 | 12 | 13 | @register_type 14 | class CSVToDatasetLogger(ContCSVDatasetWriter[CSVDict], CSVLogger): 15 | def __init__(self, name: str, **kwargs): 16 | super().__init__(name=name, dtype=CSVDict, **kwargs) 17 | 18 | def add(self, sample: Sample, **kwargs) -> Any: 19 | for fields in self.process_generic_sample(sample, **kwargs): 20 | self.add_sample(CSVDict(fields)) 21 | -------------------------------------------------------------------------------- /ml2/pipelines/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .acc import Acc 2 | from .acc_per_seq import AccPerSeq 3 | from .counter import Counter 4 | from .data_type_acc import DataTypeAcc 5 | from .equiv_acc import EquivAcc 6 | from .err_counter import ( 7 | ErrCounter, 8 | EvalErrCounter, 9 | EvalSupervisedErrCounter, 10 | SupervisedErrCounter, 11 | VerificationErrCounter, 
12 | VerificationSupervisedErrCounter, 13 | ) 14 | from .metric import Metric 15 | from .metric_avg import MetricAvg 16 | from .metric_group import MetricGroup 17 | from .null_metric import NullMetric 18 | from .sem_acc import SemAcc 19 | from .sem_beam_acc import SemBeamAcc 20 | from .str_acc import StrAcc 21 | from .ver_status import VerStatus 22 | -------------------------------------------------------------------------------- /ml2/pipelines/metrics/counter.py: -------------------------------------------------------------------------------- 1 | """Metric counting the number of samples""" 2 | 3 | from typing import Dict 4 | 5 | from ..samples import Sample 6 | from .metric import Metric 7 | 8 | 9 | class Counter(Metric): 10 | def __init__(self, name: str = "counter") -> None: 11 | self.num_samples = 0 12 | super().__init__(name=name) 13 | 14 | def add(self, sample: Sample) -> bool: 15 | self.num_samples += 1 16 | 17 | def compute(self) -> int: 18 | return self.num_samples 19 | 20 | def compute_dict(self) -> Dict[str, int]: 21 | return {"num_samples": self.compute()} 22 | 23 | def reset(self) -> None: 24 | self.num_samples = 0 25 | -------------------------------------------------------------------------------- /ml2/pipelines/metrics/data_type_acc.py: -------------------------------------------------------------------------------- 1 | """Data type accuracy""" 2 | 3 | 4 | from ..samples import LabeledSample 5 | from .metric import Metric 6 | 7 | 8 | class DataTypeAcc(Metric): 9 | def __init__(self, count_none: bool = True, name: str = "data-type-acc") -> None: 10 | self.count_none = count_none 11 | self.acc_not_norm = 0 12 | self.count = 0 13 | super().__init__(name=name) 14 | 15 | def add(self, sample: LabeledSample) -> bool: 16 | if sample.tar is None or sample.pred is None: 17 | if self.count_none: 18 | self.count += 1 19 | return False 20 | 21 | self.count += 1 22 | if sample.tar == sample.pred: 23 | self.acc_not_norm += 1 24 | return True 25 | else: 26 | return False 27 | 28 | def compute(self) -> float: 29 | if self.count > 0: 30 | return self.acc_not_norm / self.count 31 | return 0.0 32 | 33 | def reset(self) -> None: 34 | self.acc_not_norm = 0 35 | self.count = 0 36 | -------------------------------------------------------------------------------- /ml2/pipelines/metrics/equiv_acc.py: -------------------------------------------------------------------------------- 1 | """Equivalence accuracy""" 2 | 3 | from typing import Dict, Optional 4 | 5 | from ..samples import VerifiedSample 6 | from .metric import Metric 7 | 8 | 9 | class EquivAcc(Metric): 10 | def __init__( 11 | self, 12 | count_none: bool = True, 13 | name: str = "equiv_acc", 14 | ) -> None: 15 | 16 | self.count_success = 0 17 | self.count_total = 0 18 | self.count_none = count_none 19 | super().__init__(name=name) 20 | 21 | def add(self, sample: VerifiedSample) -> Optional[bool]: 22 | if sample.verification is None or sample.verification.equiv is None: 23 | if self.count_none: 24 | self.count_total += 1 25 | return None 26 | else: 27 | self.count_total += 1 28 | if sample.verification.equiv: 29 | self.count_success += 1 30 | return sample.verification.equiv 31 | 32 | def compute_dict(self) -> Dict[str, float]: 33 | return {"equiv_acc": self.count_success / self.count_total if self.count_total else 0} 34 | 35 | def reset(self) -> None: 36 | self.count_success = 0 37 | self.count_total = 0 38 | -------------------------------------------------------------------------------- /ml2/pipelines/metrics/load_metric.py: 
-------------------------------------------------------------------------------- 1 | """Utility to load metric""" 2 | -------------------------------------------------------------------------------- /ml2/pipelines/metrics/metric.py: -------------------------------------------------------------------------------- 1 | """Abstract metric class""" 2 | 3 | import json 4 | import os 5 | from abc import abstractmethod 6 | from typing import Any, Dict, List 7 | 8 | from ..samples import Sample 9 | 10 | 11 | class Metric: 12 | def __init__(self, name: str) -> None: 13 | self.name = name 14 | 15 | @abstractmethod 16 | def add(self, sample: Sample) -> Any: 17 | raise NotImplementedError() 18 | 19 | @abstractmethod 20 | def add_batch(self, sample_batch: List[Sample]) -> Any: 21 | for sample in sample_batch: 22 | self.add(sample) 23 | 24 | @abstractmethod 25 | def compute(self) -> Any: 26 | raise NotImplementedError() 27 | 28 | def compute_dict(self) -> Dict[str, Any]: 29 | return {self.name: self.compute()} 30 | 31 | @abstractmethod 32 | def reset(self) -> None: 33 | raise NotImplementedError() 34 | 35 | def save_to_path(self, path: str) -> None: 36 | filepath = os.path.join(path, self.name + ".json") 37 | with open(filepath, "w") as metric_file: 38 | json.dump(self.compute_dict(), metric_file, indent=2) 39 | -------------------------------------------------------------------------------- /ml2/pipelines/metrics/metric_avg.py: -------------------------------------------------------------------------------- 1 | """Metric average""" 2 | 3 | import json 4 | import numpy as np 5 | import os 6 | from typing import Any, Dict 7 | 8 | from .metric import Metric 9 | 10 | 11 | class MetricAvg: 12 | def __init__(self, name: str = "metric-avg") -> None: 13 | self.name = name 14 | self._metrics = [] 15 | 16 | def add_metric(self, metric: Metric) -> Any: 17 | self._metrics.append(metric) 18 | 19 | def compute_dict(self) -> Dict[str, Any]: 20 | result = {} 21 | computations = [m.compute_dict() for m in self._metrics] 22 | if len(computations) == 0: 23 | return result 24 | for key in computations[0]: 25 | result[key] = {} 26 | values = [c[key] for c in computations] 27 | result[key]["mean"] = float(np.mean(values)) 28 | result[key]["std"] = float(np.std(values)) 29 | result[key]["max"] = float(np.amax(values)) 30 | result[key]["min"] = float(np.amin(values)) 31 | return result 32 | 33 | def reset(self) -> None: 34 | self._metrics = [] 35 | 36 | def save_to_path(self, path: str) -> None: 37 | filepath = os.path.join(path, self.name + ".json") 38 | with open(filepath, "w") as metric_file: 39 | json.dump(self.compute_dict(), metric_file, indent=2) 40 | -------------------------------------------------------------------------------- /ml2/pipelines/metrics/metric_group.py: -------------------------------------------------------------------------------- 1 | """Metric Group""" 2 | 3 | from typing import Any, Dict, List 4 | 5 | from ..samples import Sample 6 | from .metric import Metric 7 | 8 | 9 | class MetricGroup(Metric): 10 | def __init__(self, metrics: List[Metric], name: str = "metrics") -> None: 11 | self.metrics = metrics 12 | super().__init__(name=name) 13 | 14 | def add(self, sample: Sample) -> Any: 15 | for metric in self.metrics: 16 | metric.add(sample) 17 | 18 | def compute(self) -> Dict[str, Any]: 19 | result = {} 20 | for metric in self.metrics: 21 | result = {**result, **metric.compute_dict()} 22 | return result 23 | 24 | def compute_dict(self) -> Dict[str, Any]: 25 | return self.compute() 26 | 27 | def 
reset(self) -> None: 28 | for metric in self.metrics: 29 | metric.reset() 30 | -------------------------------------------------------------------------------- /ml2/pipelines/metrics/null_metric.py: -------------------------------------------------------------------------------- 1 | """Null metric""" 2 | 3 | 4 | from typing import Any, Dict, List 5 | 6 | from ..samples import Sample 7 | from .metric import Metric 8 | 9 | 10 | class NullMetric(Metric): 11 | def __init__(self, name: str = "null-metric") -> None: 12 | super().__init__(name=name) 13 | 14 | def add(self, sample: Sample) -> Any: 15 | pass 16 | 17 | def add_batch(self, sample_batch: List[Sample]) -> Any: 18 | pass 19 | 20 | def compute(self) -> Any: 21 | return None 22 | 23 | def compute_dict(self) -> Dict[str, Any]: 24 | return {} 25 | 26 | def reset(self) -> None: 27 | pass 28 | -------------------------------------------------------------------------------- /ml2/pipelines/metrics/sem_acc.py: -------------------------------------------------------------------------------- 1 | """Semantic accuracy""" 2 | 3 | from typing import Dict, Optional 4 | 5 | from ..samples import VerifiedSample 6 | from .metric import Metric 7 | 8 | 9 | class SemAcc(Metric): 10 | def __init__( 11 | self, 12 | count_none: bool = True, 13 | name: str = "sem_acc", 14 | ) -> None: 15 | 16 | self.count_success = 0 17 | self.count_total = 0 18 | self.count_none = count_none 19 | super().__init__(name=name) 20 | 21 | def add(self, sample: VerifiedSample) -> Optional[bool]: 22 | if sample.verification is None or sample.verification.validation_success is None: 23 | if self.count_none: 24 | self.count_total += 1 25 | return None 26 | else: 27 | self.count_total += 1 28 | if sample.verification.validation_success: 29 | self.count_success += 1 30 | return sample.verification.validation_success 31 | 32 | def compute_dict(self) -> Dict[str, float]: 33 | return { 34 | "semantic_accuracy": self.count_success / self.count_total if self.count_total else 0 35 | } 36 | 37 | def reset(self) -> None: 38 | self.count_success = 0 39 | self.count_total = 0 40 | -------------------------------------------------------------------------------- /ml2/pipelines/metrics/sem_beam_acc.py: -------------------------------------------------------------------------------- 1 | """Semantic accuracy""" 2 | 3 | from typing import Dict, Optional 4 | 5 | from ..samples import VerifiedBeamSearchSample 6 | from .metric import Metric 7 | 8 | 9 | class SemBeamAcc(Metric): 10 | def __init__( 11 | self, 12 | count_none: bool = True, 13 | name: str = "sem_beam_acc", 14 | ) -> None: 15 | self.count_success = 0 16 | self.count_total = 0 17 | self.count_none = count_none 18 | super().__init__(name=name) 19 | 20 | def add(self, sample: VerifiedBeamSearchSample) -> Optional[bool]: 21 | all_beams_none = True 22 | has_success_beam = False 23 | for beam in sample.beams: 24 | if beam.verification is not None and beam.verification.validation_success is not None: 25 | all_beams_none = False 26 | has_success_beam = has_success_beam or beam.verification.validation_success 27 | 28 | if all_beams_none and self.count_none: 29 | self.count_total += 1 30 | return None 31 | else: 32 | self.count_total += 1 33 | if has_success_beam: 34 | self.count_success += 1 35 | return has_success_beam 36 | 37 | def compute_dict(self) -> Dict[str, float]: 38 | return { 39 | "semantic_accuracy": self.count_success / self.count_total if self.count_total else 0 40 | } 41 | 42 | def reset(self) -> None: 43 | self.count_success = 0 44 | 
self.count_total = 0 45 | -------------------------------------------------------------------------------- /ml2/pipelines/metrics/str_acc.py: -------------------------------------------------------------------------------- 1 | """String accuracy""" 2 | 3 | 4 | from ..samples import LabeledSample 5 | from .metric import Metric 6 | 7 | 8 | class StrAcc(Metric): 9 | def __init__(self, count_none: bool = True, name: str = "str-acc") -> None: 10 | self.count_none = count_none 11 | self.acc_not_norm = 0 12 | self.count = 0 13 | super().__init__(name=name) 14 | 15 | def add(self, sample: LabeledSample) -> bool: 16 | if sample.tar is None or sample.pred is None: 17 | if self.count_none: 18 | self.count += 1 19 | return False 20 | 21 | self.count += 1 22 | if sample.tar.to_str() == sample.pred.to_str(): 23 | self.acc_not_norm += 1 24 | return True 25 | else: 26 | return False 27 | 28 | def compute(self) -> float: 29 | if self.count > 0: 30 | return self.acc_not_norm / self.count 31 | return 0.0 32 | 33 | def reset(self) -> None: 34 | self.acc_not_norm = 0 35 | self.count = 0 36 | -------------------------------------------------------------------------------- /ml2/pipelines/metrics/ver_status.py: -------------------------------------------------------------------------------- 1 | """Verification status metric""" 2 | 3 | from typing import Dict 4 | 5 | from ..samples import VerifiedSample 6 | from .metric import Metric 7 | 8 | 9 | class VerStatus(Metric): 10 | def __init__( 11 | self, 12 | count_none: bool = True, 13 | name: str = "ver_status", 14 | ) -> None: 15 | self.count_none = count_none 16 | 17 | self.count_total = 0 18 | self.status_count: Dict[str, int] = {} 19 | super().__init__(name=name) 20 | 21 | def add(self, sample: VerifiedSample) -> str: 22 | if sample.verification is None or sample.verification.validation_status is None: 23 | if self.count_none: 24 | self.count_total += 1 25 | self.status_count["None"] = self.status_count.get("None", 0) + 1 26 | return "None" 27 | else: 28 | self.count_total += 1 29 | status = sample.verification.validation_status 30 | self.status_count[status] = self.status_count.get(status, 0) + 1 31 | return status 32 | 33 | def compute_dict(self) -> Dict[str, float]: 34 | if self.count_total > 0: 35 | acc_dict = {s: (c / self.count_total) for s, c in self.status_count.items()} 36 | return acc_dict 37 | return {} 38 | 39 | def reset(self) -> None: 40 | self.count_total = 0 41 | self.status_count = {} 42 | -------------------------------------------------------------------------------- /ml2/pipelines/samples/__init__.py: -------------------------------------------------------------------------------- 1 | from .beam_search_sample import Beam, BeamSearchLabeledSample, BeamSearchSample 2 | from .eval_sample import EvalLabeledSample, EvalSample 3 | from .labeled_sample import LabeledSample 4 | from .portfolio_sample import PortfolioSample, Result 5 | from .sample import EncodedSample, Sample 6 | from .verified_sample import ( 7 | VerifiedBeam, 8 | VerifiedBeamSearchLabeledSample, 9 | VerifiedBeamSearchSample, 10 | VerifiedLabeledSample, 11 | VerifiedSample, 12 | ) 13 | -------------------------------------------------------------------------------- /ml2/pipelines/samples/beam_search_sample.py: -------------------------------------------------------------------------------- 1 | """Beam search sample""" 2 | 3 | from dataclasses import dataclass, field 4 | from typing import Generic, List, TypeVar 5 | 6 | from ...dtypes import DType 7 | from .eval_sample import 
EvalLabeledSample, EvalSample 8 | 9 | I = TypeVar("I", bound=DType) 10 | T = TypeVar("T", bound=DType) 11 | 12 | 13 | @dataclass(eq=False) 14 | class Beam(Generic[T]): 15 | id: int 16 | pred: T = None 17 | pred_enc: ... = None 18 | pred_dec_err: str = None 19 | time: float = None 20 | 21 | 22 | @dataclass(eq=False) 23 | class BeamSearchSample(EvalSample[I, T], Generic[I, T]): 24 | beams: List[Beam[T]] = field(default_factory=list) 25 | 26 | def add_beam(self, beam: Beam[T]) -> None: 27 | self.beams.append(beam) 28 | if beam.id == 0: 29 | if self.pred is not None: 30 | print(self) 31 | print(beam) 32 | # assert self.pred is None 33 | # assert self.pred_enc is None 34 | # assert self.pred_dec_err is None 35 | self.pred = beam.pred 36 | self.pred_enc = beam.pred_enc 37 | self.pred_dec_err = beam.pred_dec_err 38 | self.time = beam.time 39 | 40 | 41 | @dataclass(eq=False) 42 | class BeamSearchLabeledSample(BeamSearchSample[I, T], EvalLabeledSample[I, T], Generic[I, T]): 43 | pass 44 | -------------------------------------------------------------------------------- /ml2/pipelines/samples/eval_sample.py: -------------------------------------------------------------------------------- 1 | """Evaluated sample""" 2 | 3 | from dataclasses import dataclass 4 | from typing import Any, Generic, TypeVar 5 | 6 | from ...dtypes import DType 7 | from .labeled_sample import LabeledSample 8 | from .sample import EncodedSample 9 | 10 | I = TypeVar("I", bound=DType) 11 | T = TypeVar("T", bound=DType) 12 | 13 | 14 | @dataclass(eq=False) 15 | class EvalSample(EncodedSample[I], Generic[I, T]): 16 | pred: T = None 17 | pred_enc: Any = None 18 | pred_dec_err: str = None 19 | time: float = None 20 | 21 | 22 | @dataclass(eq=False) 23 | class EvalLabeledSample(EvalSample[I, T], LabeledSample[I, T], Generic[I, T]): 24 | pass 25 | -------------------------------------------------------------------------------- /ml2/pipelines/samples/labeled_sample.py: -------------------------------------------------------------------------------- 1 | """Supervised sample""" 2 | 3 | from dataclasses import dataclass, field 4 | from typing import Any, Generic, TypeVar 5 | 6 | from ...dtypes import DType 7 | from .sample import EncodedSample 8 | 9 | I = TypeVar("I", bound=DType) 10 | T = TypeVar("T", bound=DType) 11 | 12 | 13 | def tar_err(): 14 | raise ValueError("Target is not specified") 15 | 16 | 17 | @dataclass(eq=False) 18 | class LabeledSample(EncodedSample[I], Generic[I, T]): 19 | tar: T = field(default_factory=lambda: tar_err()) 20 | tar_enc: Any = None 21 | tar_enc_err: str = None 22 | -------------------------------------------------------------------------------- /ml2/pipelines/samples/portfolio_sample.py: -------------------------------------------------------------------------------- 1 | """Portfolio sample""" 2 | 3 | from dataclasses import dataclass, field 4 | from typing import Generic, List, Optional, TypeVar 5 | 6 | from ...dtypes import DType 7 | from .sample import Sample 8 | 9 | I = TypeVar("I", bound=DType) 10 | T = TypeVar("T", bound=DType) 11 | 12 | 13 | @dataclass(eq=False) 14 | class Result(Generic[T]): 15 | id: int 16 | result: T 17 | name: Optional[str] = None 18 | time: Optional[float] = None 19 | 20 | 21 | @dataclass(eq=False) 22 | class PortfolioSample(Sample[I], Generic[I, T]): 23 | results: List[Result[T]] = field(default_factory=list) 24 | 25 | def add_result( 26 | self, result: T, name: Optional[str] = None, time: Optional[float] = None 27 | ) -> None: 28 |
self.results.append(Result(id=len(self.results), result=result, name=name, time=time)) 29 | -------------------------------------------------------------------------------- /ml2/pipelines/samples/sample.py: -------------------------------------------------------------------------------- 1 | """Sample""" 2 | 3 | from dataclasses import dataclass 4 | from typing import Any, Generic, TypeVar 5 | 6 | from ...dtypes import DType 7 | 8 | T = TypeVar("T", bound=DType) 9 | 10 | 11 | @dataclass(eq=False) 12 | class Sample(Generic[T]): 13 | inp: T 14 | id: int = None 15 | name: str = None 16 | 17 | 18 | @dataclass(eq=False) 19 | class EncodedSample(Sample[T]): 20 | inp_enc: Any = None 21 | inp_enc_err: str = None 22 | -------------------------------------------------------------------------------- /ml2/prop/__init__.py: -------------------------------------------------------------------------------- 1 | from . import cnf 2 | from .assignment import Assignment 3 | from .assignment_check_status import AssignmentCheckStatus 4 | from .prop_formula import PropFormula 5 | from .prop_lexer import lex_prop 6 | from .prop_parser import parse_prop 7 | from .prop_sat_dataset import PropSatDataset, PropSatSplitDataset 8 | from .prop_sat_problem import PropSatProblem, PropSatSolution 9 | from .prop_sat_status import PropSatStatus 10 | from .prop_valid_status import PropValidStatus 11 | -------------------------------------------------------------------------------- /ml2/prop/assignment_check_status.py: -------------------------------------------------------------------------------- 1 | """Status of an assignment check""" 2 | 3 | from typing import Dict, List 4 | 5 | from ..dtypes import CSV, Cat 6 | from ..registry import register_type 7 | 8 | ASSIGN_CHECK_STATUS_TO_INT = { 9 | "satisfying": 1, 10 | "unsatisfying": 0, 11 | "error": -1, 12 | "timeout": -2, 13 | } 14 | 15 | INT_TO_ASSIGN_CHECK_STATUS = {i: s for s, i in ASSIGN_CHECK_STATUS_TO_INT.items()} 16 | 17 | 18 | @register_type 19 | class AssignmentCheckStatus(Cat, CSV): 20 | def __init__(self, status: str) -> None: 21 | if status not in ["satisfying", "unsatisfying", "timeout", "error"]: 22 | raise ValueError(f"Invalid status {status}") 23 | self._status = status 24 | 25 | def token(self, **kwargs) -> str: 26 | return self._status 27 | 28 | def _to_csv_fields(self, **kwargs) -> Dict[str, str]: 29 | return {"satisfying": ASSIGN_CHECK_STATUS_TO_INT[self._status]} 30 | 31 | @classmethod 32 | def _csv_field_header(cls, **kwargs) -> List[str]: 33 | return ["satisfying"] 34 | 35 | @classmethod 36 | def _from_csv_fields(cls, fields: Dict[str, str], **kwargs) -> "AssignmentCheckStatus": 37 | return cls(status=INT_TO_ASSIGN_CHECK_STATUS[int(fields["satisfying"])], **kwargs) 38 | 39 | @classmethod 40 | def from_token(cls, token: str, **kwargs) -> "AssignmentCheckStatus": 41 | return cls(status=token) 42 | -------------------------------------------------------------------------------- /ml2/prop/cnf/__init__.py: -------------------------------------------------------------------------------- 1 | from ...utils import is_pt_available, is_tf_available 2 | from .cnf_assign_problem import CNFAssignProblem 3 | from .cnf_assignment import CNFAssignment 4 | from .cnf_formula import Clause, CNFFormula 5 | from .cnf_res_problem import CNFResProblem 6 | from .cnf_sat_problem import CNFSatProblem, CNFSatSolution 7 | from .cnf_sat_search_problem import CNFSatSearchProblem, CNFSatSearchSolution 8 | from .res_completion_problem import ResCompletionProblem 9 | from .res_proof import 
ResClause, ResProof 10 | from .res_proof_status import ResProofCheckStatus 11 | 12 | if is_pt_available() and is_tf_available(): 13 | from .cnf_formula_tokenizer import CNFFormulaTokenizer 14 | from .res_proof_tokenizer import ResProofTokenizer 15 | -------------------------------------------------------------------------------- /ml2/prop/cnf/cnf_formula_tokenizer.py: -------------------------------------------------------------------------------- 1 | """Tokenizer that encodes a CNF Formula into a sequence encoding""" 2 | 3 | import itertools 4 | from typing import List, Type, TypeVar 5 | 6 | from ...dtypes import Seq 7 | from ...registry import register_type 8 | from ...tokenizers.seq_tokenizers import SeqToSeqTokenizer 9 | from .cnf_formula import CNFFormula 10 | 11 | T = TypeVar("T", bound=Seq) 12 | 13 | 14 | @register_type 15 | class CNFFormulaTokenizer(SeqToSeqTokenizer[CNFFormula]): 16 | def __init__( 17 | self, 18 | dtype: Type[CNFFormula] = CNFFormula, 19 | enumerated: bool = False, 20 | pos_factor: int = 1, 21 | **kwargs, 22 | ): 23 | self.enumerated = enumerated 24 | self.pos_factor = pos_factor 25 | super().__init__(dtype=dtype, **kwargs) 26 | 27 | def encode_tokens(self, data: CNFFormula, **kwargs) -> List[str]: 28 | nested_tokens = [] 29 | for i, clause in enumerate(data.clauses): 30 | clause_tokens = clause.tokens(**kwargs) 31 | if self.enumerated: 32 | pos = (i + 1) * self.pos_factor 33 | clause_tokens = [str(pos)] + clause_tokens 34 | nested_tokens.append(clause_tokens) 35 | return list(itertools.chain.from_iterable(nested_tokens)) 36 | -------------------------------------------------------------------------------- /ml2/prop/cnf/cnf_res_problem.py: -------------------------------------------------------------------------------- 1 | """Propositional satisfiability problem in conjunctive normal form""" 2 | 3 | from typing import Dict 4 | 5 | from ...dtypes import CSV, Supervised 6 | from ...registry import register_type 7 | from .cnf_formula import CNFFormula 8 | from .res_proof import ResProof 9 | 10 | 11 | @register_type 12 | class CNFResProblem(CSV, Supervised[CNFFormula, ResProof]): 13 | def __init__( 14 | self, 15 | formula: CNFFormula, 16 | proof: ResProof = None, 17 | ) -> None: 18 | self.formula = formula 19 | self.proof = proof 20 | 21 | @property 22 | def input(self) -> CNFFormula: 23 | return self.formula 24 | 25 | @property 26 | def target(self) -> ResProof: 27 | return self.proof 28 | 29 | def _to_csv_fields(self, **kwargs) -> Dict[str, str]: 30 | formula_csv_fields = self.formula.to_csv_fields(**kwargs) 31 | proof_csv_fields = self.proof.to_csv_fields(**kwargs) if self.proof is not None else {} 32 | return {**formula_csv_fields, **proof_csv_fields} 33 | 34 | @classmethod 35 | def _from_csv_fields(cls, fields: Dict[str, str], **kwargs) -> "CNFResProblem": 36 | return cls( 37 | formula=CNFFormula.from_csv_fields(fields, **kwargs), 38 | proof=ResProof.from_csv_fields(fields, **kwargs), 39 | ) 40 | -------------------------------------------------------------------------------- /ml2/prop/cnf/res_completion_problem.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from ...dtypes import CSV, Supervised 4 | from ...registry import register_type 5 | from .res_proof import ResProof 6 | 7 | 8 | @register_type 9 | class ResCompletionProblem(CSV, Supervised[ResProof, ResProof]): 10 | def __init__( 11 | self, 12 | proof_start: ResProof, 13 | proof_end: ResProof, 14 | ) -> None: 15 | self.proof_start = 
proof_start 16 | self.proof_end = proof_end 17 | 18 | @property 19 | def input(self) -> ResProof: 20 | return self.proof_start 21 | 22 | @property 23 | def target(self) -> ResProof: 24 | return self.proof_end 25 | 26 | def _to_csv_fields(self, **kwargs) -> Dict[str, str]: 27 | return { 28 | "proof_start": self.proof_start.to_csv_fields(**kwargs)["res_proof"], 29 | "proof_end": self.proof_end.to_csv_fields(**kwargs)["res_proof"], 30 | } 31 | 32 | @classmethod 33 | def _from_csv_fields(cls, fields: Dict[str, str], **kwargs) -> "ResCompletionProblem": 34 | return cls( 35 | proof_start=ResProof.from_csv_fields({"res_proof": fields["proof_start"]}, **kwargs), 36 | proof_end=ResProof.from_csv_fields({"res_proof": fields["proof_end"]}, **kwargs), 37 | ) 38 | -------------------------------------------------------------------------------- /ml2/prop/cnf/res_data_gen_common.py: -------------------------------------------------------------------------------- 1 | """Common functionality for resolution data generation""" 2 | 3 | import random 4 | from copy import deepcopy 5 | from typing import Dict 6 | 7 | import numpy as np 8 | 9 | from .cnf_formula import Clause, CNFFormula 10 | from .cnf_sat_search_problem import CNFSatSearchProblem 11 | 12 | 13 | class CNFResDataGenProblem(CNFSatSearchProblem): 14 | def add_clause(self, p_k_2: float, p_geo: float) -> "CNFResDataGenProblem": 15 | k_base = 1 if random.random() < p_k_2 else 2 16 | k = k_base + np.random.geometric(p_geo) 17 | num_vars = self.formula.num_vars 18 | vars = np.random.choice(num_vars, size=min(num_vars, k), replace=False) 19 | lits = [v + 1 if random.random() < 0.5 else -v - 1 for v in vars] 20 | clause = Clause.from_list(lits) 21 | formula = deepcopy(self.formula) 22 | formula.add_clause(clause) 23 | return CNFResDataGenProblem(formula=formula, id=self.id, timeout=self.timeout) 24 | 25 | def _to_csv_fields(self, **kwargs) -> Dict[str, str]: 26 | formula_csv_fields = self.formula.to_csv_fields(**kwargs) 27 | proof_csv_fields = ( 28 | self.solution.res_proof.to_csv_fields(**kwargs) 29 | if self.solution.res_proof is not None 30 | else {} 31 | ) 32 | return {**formula_csv_fields, **proof_csv_fields} 33 | 34 | @classmethod 35 | def from_random(cls, min_n: int, max_n: int, **kwargs) -> "CNFResDataGenProblem": 36 | num_vars = random.randint(min_n, max_n) 37 | formula = CNFFormula(num_vars=num_vars) 38 | return cls(formula=formula, **kwargs) 39 | -------------------------------------------------------------------------------- /ml2/prop/cnf/res_proof_tokenizer.py: -------------------------------------------------------------------------------- 1 | """Tokenizer that encodes a resolution proof into a sequence encoding""" 2 | 3 | from typing import List, Type, TypeVar 4 | 5 | from ...dtypes import Seq 6 | from ...registry import register_type 7 | from ...tokenizers.seq_tokenizers import SeqToSeqTokenizer 8 | from ...tokenizers.tokenizer import TokenizationException 9 | from .res_proof import ResProof 10 | 11 | T = TypeVar("T", bound=Seq) 12 | 13 | 14 | @register_type 15 | class ResProofTokenizer(SeqToSeqTokenizer[ResProof]): 16 | def __init__( 17 | self, 18 | components: List[str] = None, 19 | dtype: Type[ResProof] = ResProof, 20 | pos_factor: int = 1, 21 | **kwargs, 22 | ): 23 | self.components = components if components is not None else ["id", "clause", "premises"] 24 | self.pos_factor = pos_factor 25 | 26 | super().__init__(dtype=dtype, **kwargs) 27 | 28 | def encode_tokens(self, data: ResProof, **kwargs) -> List[str]: 29 | int_tokens = [] 30 | 
for rc in data.res_clauses: 31 | for c in self.components: 32 | if c == "id": 33 | int_tokens.append(rc.id * self.pos_factor) 34 | elif c == "clause": 35 | int_tokens.extend(rc.clause.lits) 36 | int_tokens.append(0) 37 | elif c == "premises": 38 | int_tokens.extend(rc.premises) 39 | int_tokens.append(0) 40 | else: 41 | raise TokenizationException(f"Unknown component {c}") 42 | return [str(t) for t in int_tokens] 43 | -------------------------------------------------------------------------------- /ml2/prop/prop_lexer.py: -------------------------------------------------------------------------------- 1 | """Propositional logic lexer""" 2 | 3 | import logging 4 | import sly 5 | 6 | logging.basicConfig(level=logging.INFO) 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class PropPrefixLexer(sly.Lexer): 11 | 12 | ops = {NOT, AND, OR, XOR, IMPL, EQUIV} 13 | tokens = {V, CONST}.union(ops) 14 | 15 | ignore = " \t" 16 | 17 | CONST = r"true|false|1|0" 18 | V = r"[a-zA-Z_][a-zA-Z0-9_]*" 19 | 20 | NOT = r"!" 21 | AND = r"&(&)?" 22 | OR = r"\|(\|)?" 23 | XOR = r"\^" 24 | IMPL = r"->" 25 | EQUIV = r"<->" 26 | 27 | def error(self, t): 28 | # TODO figure out how to return None instead of skipping illegal characters 29 | logger.debug(f"Illegal character {t.value[0]}") 30 | self.index += 1 31 | 32 | 33 | class PropInfixLexer(sly.Lexer): 34 | 35 | ops = {NOT, AND, OR, XOR, IMPL, EQUIV} 36 | tokens = {V, CONST, LPAR, RPAR}.union(ops) 37 | 38 | ignore = " \t" 39 | 40 | CONST = r"true|false|1|0" 41 | V = r"[a-zA-Z_][a-zA-Z0-9_]*" 42 | LPAR = r"\(" 43 | RPAR = r"\)" 44 | 45 | NOT = r"!" 46 | AND = r"&(&)?" 47 | OR = r"\|(\|)?" 48 | XOR = r"\^" 49 | IMPL = r"->" 50 | EQUIV = r"<->" 51 | 52 | def error(self, t): 53 | logger.debug(f"Illegal character {t.value[0]}") 54 | self.index += 1 55 | 56 | 57 | PROP_INFIX_LEXER = None 58 | 59 | 60 | def lex_prop(formula: str) -> list: 61 | global PROP_INFIX_LEXER 62 | if PROP_INFIX_LEXER is None: 63 | PROP_INFIX_LEXER = PropInfixLexer() 64 | return [token.value for token in PROP_INFIX_LEXER.tokenize(formula)] 65 | -------------------------------------------------------------------------------- /ml2/prop/prop_sat_status.py: -------------------------------------------------------------------------------- 1 | """Status of an propositional satisfiability problem""" 2 | 3 | from typing import Dict, List 4 | 5 | from ..dtypes import CSV, Cat 6 | from ..registry import register_type 7 | 8 | PROP_SAT_STATUS_TO_INT = { 9 | "unsat": 0, 10 | "sat": 1, 11 | "error": -1, 12 | "timeout": -2, 13 | } 14 | 15 | INT_TO_PROP_SAT_STATUS = {i: s for s, i in PROP_SAT_STATUS_TO_INT.items()} 16 | 17 | 18 | @register_type 19 | class PropSatStatus(Cat, CSV): 20 | def __init__(self, status: str) -> None: 21 | if status not in ["sat", "unsat", "timeout", "error"]: 22 | raise ValueError(f"Invalid status {status}") 23 | self._status = status 24 | 25 | def token(self, **kwargs) -> str: 26 | return self._status 27 | 28 | def _to_csv_fields(self, **kwargs) -> Dict[str, str]: 29 | return {"sat": PROP_SAT_STATUS_TO_INT[self._status]} 30 | 31 | @classmethod 32 | def _csv_field_header(cls, **kwargs) -> List[str]: 33 | return ["sat"] 34 | 35 | @classmethod 36 | def _from_csv_fields(cls, fields: Dict[str, str], **kwargs) -> "PropSatStatus": 37 | return cls(status=INT_TO_PROP_SAT_STATUS[int(fields["sat"])], **kwargs) 38 | 39 | @classmethod 40 | def from_token(cls, token: str, **kwargs) -> "PropSatStatus": 41 | return cls(status=token) 42 | 
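The infix lexer above accepts both the single- and double-character spellings of conjunction and disjunction (`&`/`&&`, `|`/`||`), and `PropSatStatus` serializes its categorical value as a signed integer (`sat` -> 1, `unsat` -> 0, `error` -> -1, `timeout` -> -2). A minimal usage sketch follows (illustration only, not a file of the repository); the public `to_csv_fields`/`from_csv_fields` wrappers are assumed to delegate to the `_`-prefixed hooks shown above.

# Usage sketch (illustration only); lex_prop and PropSatStatus are re-exported
# by ml2/prop/__init__.py above.
from ml2.prop import PropSatStatus, lex_prop

# The lexer returns raw token values and tolerates & as well as &&.
print(lex_prop("a && (b -> ! c)"))  # expected: ['a', '&&', '(', 'b', '->', '!', 'c', ')']

# Status round trip through the single-column integer CSV encoding.
timeout = PropSatStatus("timeout")
print(timeout.to_csv_fields())  # expected to carry the value -2 under the key "sat"
assert PropSatStatus.from_csv_fields({"sat": "-2"}).token() == "timeout"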
-------------------------------------------------------------------------------- /ml2/prop/prop_valid_status.py: -------------------------------------------------------------------------------- 1 | """Status of a propositional validity problem""" 2 | 3 | from typing import Dict, List 4 | 5 | from ..dtypes import CSV, Cat 6 | from ..registry import register_type 7 | 8 | PROP_VALID_STATUS_TO_INT = { 9 | "invalid": 0, 10 | "valid": 1, 11 | "error": -1, 12 | "timeout": -2, 13 | } 14 | 15 | INT_TO_PROP_VALID_STATUS = {i: s for s, i in PROP_VALID_STATUS_TO_INT.items()} 16 | 17 | 18 | @register_type 19 | class PropValidStatus(Cat, CSV): 20 | def __init__(self, status: str) -> None: 21 | if status not in ["valid", "invalid", "timeout", "error"]: 22 | raise ValueError(f"Invalid status {status}") 23 | self._status = status 24 | 25 | def token(self, **kwargs) -> str: 26 | return self._status 27 | 28 | def _to_csv_fields(self, **kwargs) -> Dict[str, str]: 29 | return {"valid": PROP_VALID_STATUS_TO_INT[self._status]} 30 | 31 | @classmethod 32 | def _csv_field_header(cls, **kwargs) -> List[str]: 33 | return ["valid"] 34 | 35 | @classmethod 36 | def _from_csv_fields(cls, fields: Dict[str, str], **kwargs) -> "PropValidStatus": 37 | return cls(status=INT_TO_PROP_VALID_STATUS[int(fields["valid"])]) 38 | 39 | @classmethod 40 | def from_token(cls, token: str, **kwargs) -> "PropValidStatus": 41 | return cls(status=token) 42 | -------------------------------------------------------------------------------- /ml2/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .cat_seq_tokenizers import CatSeqToSeqTokenizer 2 | from .cat_tokenizers import CatToIdTokenizer 3 | from .decomp_dtype_tokenizers import ( 4 | DecompDTypeToDecompSeqPosTokenizer, 5 | DecompDTypeToDecompSeqTokenizer, 6 | ) 7 | from .decomp_expr_pair_tokenizers import DecompExprPairToDecompSeqTPETokenizer 8 | from .decomp_expr_tokenizers import DecompExprToDecompSeqTPETokenizer 9 | from .expr_tokenizers import ExprToSeqTokenizer, ExprToSeqTPETokenizer 10 | from .pair_tokenizers import CatSeqPairToSeqTokenizer, PairToSeqTokenizer 11 | from .seq_tokenizers import SeqToSeqTokenizer 12 | from .to_decomp_seq_pos_tokenizer import DecompSeqPosEncoding, ToDecompSeqPosTokenizer 13 | from .to_id_tokenizer import ToIdTokenizer 14 | from .to_seq_mask_tokenizer import ( 15 | ToSeq2DMaskTokenizer, 16 | ToSeq3DMaskTokenizer, 17 | ToSeq4DMaskTokenizer, 18 | ToSeqMaskTokenizer, 19 | ) 20 | from .to_seq_pos_tokenizer import SeqPosEncoding, ToSeqPosTokenizer 21 | from .to_seq_tokenizer import SeqEncoding, ToSeqTokenizer 22 | from .to_seq_tpe_tokenizer import ToSeqTPETokenizer 23 | from .tokenizer import ( 24 | EOS_TOKEN, 25 | PAD_TOKEN, 26 | START_TOKEN, 27 | NPEncoding, 28 | PTEncoding, 29 | TFEncoding, 30 | Tokenizer, 31 | ) 32 | from .vocabulary import Vocabulary 33 | -------------------------------------------------------------------------------- /ml2/tokenizers/cat_seq_tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .cat_seq_to_seq_tokenizer import CatSeqToSeqTokenizer 2 | -------------------------------------------------------------------------------- /ml2/tokenizers/cat_tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .cat_to_id_tokenizer import CatToIdTokenizer 2 | -------------------------------------------------------------------------------- 
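Before the individual tokenizer implementations, it may help to see how the clause tokenizer from ml2/prop/cnf/cnf_formula_tokenizer.py above flattens a formula. The sketch below is an illustration only: it assumes the tokenizer can be constructed with default arguments (the base ToSeqTokenizer may additionally expect a pad length or vocabulary) and that the optional TensorFlow/PyTorch dependencies guarded in the package __init__ are installed.

# Illustration only: flattening a two-clause formula with the enumerated encoding.
# Assumption: default construction of the tokenizer succeeds; pass pad/vocabulary
# keyword arguments if the base class requires them.
from ml2.prop.cnf import Clause, CNFFormula
from ml2.prop.cnf.cnf_formula_tokenizer import CNFFormulaTokenizer

formula = CNFFormula(num_vars=2)
formula.add_clause(Clause.from_list([1, -2]))
formula.add_clause(Clause.from_list([-1, 2]))

tokenizer = CNFFormulaTokenizer(enumerated=True, pos_factor=10)

# With enumerated=True the i-th clause is prefixed by (i + 1) * pos_factor,
# so the first clause's tokens follow a "10" and the second clause's a "20".
print(tokenizer.encode_tokens(formula))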
/ml2/tokenizers/cat_tokenizers/cat_to_id_tokenizer.py: -------------------------------------------------------------------------------- 1 | """Tokenizer that encodes a categorical data type into an index""" 2 | 3 | from typing import Generic, TypeVar 4 | 5 | from ...dtypes import Cat 6 | from ...registry import register_type 7 | from ..to_id_tokenizer import ToIdTokenizer 8 | 9 | T = TypeVar("T", bound=Cat) 10 | 11 | 12 | @register_type 13 | class CatToIdTokenizer(ToIdTokenizer[T], Generic[T]): 14 | def encode_token(self, data: T, **kwargs) -> str: 15 | return data.token(**kwargs) 16 | 17 | def decode_token(self, token: str, **kwargs) -> T: 18 | return self.dtype.from_token(token=token, **kwargs) 19 | -------------------------------------------------------------------------------- /ml2/tokenizers/decomp_dtype_tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .decomp_dtype_to_decomp_seq_pos_tokenizer import DecompDTypeToDecompSeqPosTokenizer 2 | from .decomp_dtype_to_decomp_seq_tokenizer import DecompDTypeToDecompSeqTokenizer 3 | -------------------------------------------------------------------------------- /ml2/tokenizers/decomp_expr_pair_tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .decomp_expr_pair_to_decomp_seq_tpe_tokenizer import DecompExprPairToDecompSeqTPETokenizer 2 | -------------------------------------------------------------------------------- /ml2/tokenizers/decomp_expr_tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .decomp_expr_to_decomp_seq_tpe_tokenizer import DecompExprToDecompSeqTPETokenizer 2 | -------------------------------------------------------------------------------- /ml2/tokenizers/expr_tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .expr_to_seq_tokenizer import ExprToSeqTokenizer 2 | from .expr_to_seq_tpe_tokenizer import ExprToSeqTPETokenizer 3 | -------------------------------------------------------------------------------- /ml2/tokenizers/expr_tokenizers/expr_to_seq_tokenizer.py: -------------------------------------------------------------------------------- 1 | """Tokenizer that encodes an expression into a sequence encoding""" 2 | 3 | from typing import Generic, List, TypeVar 4 | 5 | from ...dtypes import BinaryExpr 6 | from ...registry import register_type 7 | from ..to_seq_tokenizer import ToSeqTokenizer 8 | 9 | T = TypeVar("T", bound=BinaryExpr) 10 | 11 | 12 | @register_type 13 | class ExprToSeqTokenizer(ToSeqTokenizer[T], Generic[T]): 14 | def __init__(self, notation: str = "infix", **kwargs): 15 | self.notation = notation 16 | super().__init__(**kwargs) 17 | 18 | def encode_tokens(self, data: T, **kwargs) -> List[str]: 19 | return data.to_tokens(notation=self.notation, **kwargs) 20 | 21 | def decode_tokens(self, tokens: List[str], **kwargs) -> T: 22 | return self.dtype.from_tokens(tokens, notation=self.notation, **kwargs) 23 | 24 | def vocabulary_filename(self) -> str: 25 | return super().vocabulary_filename() + "-" + self.notation 26 | -------------------------------------------------------------------------------- /ml2/tokenizers/expr_tokenizers/expr_to_seq_tpe_tokenizer.py: -------------------------------------------------------------------------------- 1 | """Tokenizer that encodes an expression into a sequence encoding with tree positional encoding""" 2 | 3 | from typing import 
Generic, List, TypeVar 4 | 5 | from ...dtypes import BinaryExpr 6 | from ...registry import register_type 7 | from ..to_seq_tpe_tokenizer import ToSeqTPETokenizer 8 | 9 | T = TypeVar("T", bound=BinaryExpr) 10 | 11 | 12 | @register_type 13 | class ExprToSeqTPETokenizer(ToSeqTPETokenizer[T], Generic[T]): 14 | def __init__(self, notation: str = "infix", **kwargs): 15 | self.notation = notation 16 | super().__init__(**kwargs) 17 | 18 | def encode_tokens(self, data: T, **kwargs) -> List[str]: 19 | return data.to_tokens(notation=self.notation, **kwargs) 20 | 21 | def decode_tokens(self, tokens: List[str], **kwargs) -> T: 22 | return self.dtype.from_tokens(tokens, notation=self.notation, **kwargs) 23 | 24 | def encode_pos_enc(self, data: T, **kwargs) -> List[List[int]]: 25 | return data.ast.tree_positional_encoding(notation=self.notation, format=self.tpe_format) 26 | 27 | def vocabulary_filename(self) -> str: 28 | return super().vocabulary_filename() + "-" + self.notation 29 | -------------------------------------------------------------------------------- /ml2/tokenizers/load_tokenizer.py: -------------------------------------------------------------------------------- 1 | """Utility to load tokenizer""" 2 | 3 | import logging 4 | 5 | from ..tokenizers import Tokenizer 6 | 7 | logging.basicConfig(level=logging.INFO) 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def load_tokenizer(name: str, project: str = None, **kwargs) -> Tokenizer: 12 | from ..registry import type_from_str 13 | 14 | config = Tokenizer.fetch_config(name=name, project=project) 15 | if "type" not in config: 16 | raise Exception("Tokenizer type not specified in config") 17 | tokenizer_type = type_from_str(config["type"], bound=Tokenizer) 18 | return tokenizer_type.load(name=name, project=project, **kwargs) 19 | -------------------------------------------------------------------------------- /ml2/tokenizers/load_vocabulary.py: -------------------------------------------------------------------------------- 1 | """Utility to load vocabulary""" 2 | 3 | from .vocabulary import Vocabulary 4 | 5 | 6 | def load_vocabulary(name: str, project: str = None, **kwargs) -> Vocabulary: 7 | from ..registry import type_from_str 8 | 9 | config = Vocabulary.fetch_config(name=name, project=project) 10 | if "type" not in config: 11 | raise Exception("Vocabulary type not specified in config") 12 | vocabulary_type = type_from_str(config["type"], bound=Vocabulary) 13 | return vocabulary_type.load(name=name, project=project, **kwargs) 14 | -------------------------------------------------------------------------------- /ml2/tokenizers/pair_tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .cat_seq_pair_to_seq_tokenizer import CatSeqPairToSeqTokenizer 2 | from .pair_to_seq_tokenizer import PairToSeqTokenizer 3 | -------------------------------------------------------------------------------- /ml2/tokenizers/seq_tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .seq_to_seq_tokenizer import SeqToSeqTokenizer 2 | -------------------------------------------------------------------------------- /ml2/tokenizers/seq_tokenizers/seq_to_seq_tokenizer.py: -------------------------------------------------------------------------------- 1 | """Tokenizer that encodes a sequence into a sequence encoding""" 2 | 3 | from typing import Generic, List, TypeVar 4 | 5 | from ...dtypes import Seq 6 | from ...registry import register_type 7 | 
from ..to_seq_tokenizer import ToSeqTokenizer 8 | 9 | T = TypeVar("T", bound=Seq) 10 | 11 | 12 | @register_type 13 | class SeqToSeqTokenizer(ToSeqTokenizer[T], Generic[T]): 14 | 15 | def __init__(self, token_kwargs: dict = None, **kwargs) -> None: 16 | super().__init__(**kwargs) 17 | self.token_kwargs = token_kwargs if token_kwargs else {} 18 | 19 | def encode_tokens(self, data: T, **kwargs) -> List[str]: 20 | return data.to_tokens(**self.token_kwargs, **kwargs) 21 | 22 | def decode_tokens(self, tokens: List[str], **kwargs) -> T: 23 | return self.dtype.from_tokens(tokens, **self.token_kwargs, **kwargs) 24 | -------------------------------------------------------------------------------- /ml2/tokenizers/to_seq_tpe_tokenizer.py: -------------------------------------------------------------------------------- 1 | """Tokenizer that encodes into a sequence encoding with tree positional encoding""" 2 | 3 | import logging 4 | from typing import Any, Dict, Generic, TypeVar 5 | 6 | from ..dtypes import DType, TPEFormat 7 | from ..registry import register_type 8 | from .to_seq_pos_tokenizer import ToSeqPosTokenizer 9 | 10 | logging.basicConfig(level=logging.INFO) 11 | logger = logging.getLogger(__name__) 12 | 13 | T = TypeVar("T", bound=DType) 14 | 15 | 16 | @register_type 17 | class ToSeqTPETokenizer(ToSeqPosTokenizer[T], Generic[T]): 18 | def __init__(self, tpe_format: TPEFormat = TPEFormat.BRANCHUP, **kwargs) -> None: 19 | self.tpe_format = tpe_format 20 | super().__init__(**kwargs) 21 | 22 | def config_postprocessors(self, **kwargs) -> list: 23 | def postprocess_tpe_pad(config: Dict[str, Any], annotations: Dict[str, type]) -> None: 24 | if "tpe_pad" not in config: 25 | config["tpe_pad"] = config.pop("pos_pad") 26 | annotations["tpe_pad"] = annotations.pop("pos_pad") 27 | 28 | return super().config_postprocessors() + [postprocess_tpe_pad] 29 | 30 | @classmethod 31 | def config_preprocessors(cls) -> list: 32 | def preprocess_tpe_pad(config: Dict[str, Any], annotations: Dict[str, type]) -> None: 33 | if "tpe_pad" in config: 34 | config["pos_pad"] = config.pop("tpe_pad") 35 | 36 | return [preprocess_tpe_pad] + super().config_preprocessors() 37 | -------------------------------------------------------------------------------- /ml2/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import aalta, bosy, limboole, ltl_tool, neurosynt, nusmv, nuxmv, spot, strix, syfco 2 | from .grpc_service import GRPCService 3 | -------------------------------------------------------------------------------- /ml2/tools/aalta/__init__.py: -------------------------------------------------------------------------------- 1 | from .aalta import Aalta, AaltaTraceMC 2 | -------------------------------------------------------------------------------- /ml2/tools/aalta/aalta_wrapper.py: -------------------------------------------------------------------------------- 1 | """Aalta wrapper""" 2 | 3 | import logging 4 | import subprocess 5 | 6 | from ...ltl.ltl_sat.ltl_sat_status import LTLSatStatus 7 | 8 | logging.basicConfig(level=logging.INFO) 9 | logger = logging.getLogger(__name__) 10 | 11 | AALTA_BIN_PATH = "/aalta/aalta" 12 | 13 | 14 | def aalta_wrapper_str(formula: str, evidence: bool = True, timeout: float = None): 15 | try: 16 | args = [AALTA_BIN_PATH, "-c", "-l"] 17 | if evidence: 18 | args.append("-e") 19 | out = subprocess.run(args, capture_output=True, input=formula, text=True, timeout=timeout) 20 | except subprocess.TimeoutExpired: 21 | logger.debug("aalta timeout") 22 | return {"status": LTLSatStatus("timeout")} 23 | except subprocess.CalledProcessError: 24 | logger.error("subprocess called process error") 25 | return {"status": LTLSatStatus("error")} 26 | except Exception: 27 | logger.error("Unknown exception") 28 | return {"status": LTLSatStatus("error")} 29 | aalta_stdout = out.stdout 30 | aalta_stdout_lines = aalta_stdout.split("\n") 31 | if out.returncode == 0 and aalta_stdout_lines[1] == "sat": 32 | return {"status": LTLSatStatus("satisfiable"), "trace": "\n".join(aalta_stdout_lines[2:])} 33 | if out.returncode == 0 and aalta_stdout_lines[1] == "unsat": 34 | return {"status": LTLSatStatus("unsatisfiable")} 35 | return {"status": LTLSatStatus("error"), "message": aalta_stdout} 36 | -------------------------------------------------------------------------------- /ml2/tools/abc_aiger/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import abc_wrapper, aiger_wrapper, wrapper_helper 2 | from .abc_aiger import ABCAiger 3 | -------------------------------------------------------------------------------- /ml2/tools/booleforce/__init__.py: -------------------------------------------------------------------------------- 1 | from ...utils import is_ray_available 2 | from .booleforce import BooleForce, TraceCheckVerifier 3 | 4 | if is_ray_available(): 5 | from .booleforce_worker import booleforce_worker_fn 6 | -------------------------------------------------------------------------------- /ml2/tools/booleforce/booleforce_worker.py: -------------------------------------------------------------------------------- 1 | """BooleForce worker""" 2 | 3 | import ray 4 | 5 | from ...data_gen import DataGenServer 6 | from .booleforce import BooleForce 7 | 8 | 9 | @ray.remote 10 | def booleforce_worker_fn( 11 | server: DataGenServer, 12 | id: int, 13 | port: int, 14 | mem_limit: str, 15 | ) -> None: 16 | booleforce = BooleForce(port=port, mem_limit=mem_limit) 17 | server.register_worker.remote() 18 | while True: 19 | problems = ray.get(server.get_problem_batch.remote()) 20 | if problems is None: 21 | break 22 | for problem in problems: 23 | solution = booleforce.check_sat(formula=problem.formula, timeout=problem.timeout) 24 | problem.solution = solution 25 | ray.get(server.post_problem_batch.remote(problems)) 26 | server.deregister_worker.remote() 27 | -------------------------------------------------------------------------------- /ml2/tools/bosy/__init__.py: -------------------------------------------------------------------------------- 1 | from ...utils import is_ray_available 2 | from .bosy import BoSy 3 | 4 | if is_ray_available(): 5 | from .bosy_worker import add_bosy_args, bosy_worker_fn, bosy_worker_fn_dict 6 | -------------------------------------------------------------------------------- /ml2/tools/limboole/__init__.py: -------------------------------------------------------------------------------- 1 | from .limboole import Limboole 2 | -------------------------------------------------------------------------------- /ml2/tools/ltl_tool/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import pb2_converter 2 | from .tool_ltl_conversion import ToolLTLConversionRequest, ToolLTLConversionResponse 3 | from .tool_ltl_mc_problem import ToolLTLMCProblem, ToolLTLMCSolution, ToolLTLMCSolutionSymbolic 4 | from .tool_ltl_syn_problem import ToolLTLSynProblem, ToolLTLSynSolution 5 | -------------------------------------------------------------------------------- /ml2/tools/neurosynt/__init__.py: -------------------------------------------------------------------------------- 1 | from ...utils import is_pt_available, is_tf_available 2 | from .neurosynt import NeuroSynt 3 | 4 | if is_pt_available() and is_tf_available(): 5 | from .pipeline_wrapper import PipelineWrapper 6 | -------------------------------------------------------------------------------- /ml2/tools/nusmv/__init__.py: -------------------------------------------------------------------------------- 1 | from .nusmv import NuSMV, NuSMVMC 2 | from .nusmv_wrapper import nusmv_wrapper 3 | -------------------------------------------------------------------------------- /ml2/tools/nuxmv/__init__.py: -------------------------------------------------------------------------------- 1 | from .nuxmv import Nuxmv, NuxmvMC 2 | from .nuxmv_wrapper import nuxmv_wrapper 3 | -------------------------------------------------------------------------------- /ml2/tools/semml/__init__.py: -------------------------------------------------------------------------------- 1 | from . import semml_wrapper 2 | from .semml import Semml 3 | -------------------------------------------------------------------------------- /ml2/tools/spot/__init__.py: -------------------------------------------------------------------------------- 1 | from .spot import Spot 2 | from .spot_aiger_mc import SpotAIGERMC 3 | from .spot_equiv_verifier import SpotEquivVerifier 4 | from .spot_strace_mc import SpotSTraceMC 5 | -------------------------------------------------------------------------------- /ml2/tools/spot/spot_aiger_mc.py: -------------------------------------------------------------------------------- 1 | """Spot AIGER model checker""" 2 | 3 | from typing import Dict, Optional 4 | 5 | from ...aiger import AIGERCircuit 6 | from ...dtypes import CatSeq 7 | from ...ltl.ltl_mc import LTLMCSolution 8 | from ...ltl.ltl_spec.decomp_ltl_spec import DecompLTLSpec 9 | from ...ltl.ltl_syn.ltl_real_status import LTLRealStatus 10 | from ...registry import register_type 11 | from ...verifier import Verifier 12 | from ..ltl_tool.tool_ltl_mc_problem import ToolLTLMCProblem 13 | from .spot import Spot 14 | 15 | 16 | @register_type 17 | class SpotAIGERMC(Spot, Verifier): 18 | def verify( 19 | self, 20 | formula: DecompLTLSpec, 21 | solution: CatSeq[LTLRealStatus, AIGERCircuit], 22 | parameters: Optional[Dict[str, str]] = None, 23 | ) -> LTLMCSolution: 24 | if parameters is None: 25 | parameters = {} 26 | if "timeout" not in parameters: 27 | parameters["timeout"] = 120 28 | return self.model_check( 29 | problem=ToolLTLMCProblem.from_aiger_verification_pair( 30 | formula=formula, solution=solution, parameters=parameters 31 | ) 32 | ).to_LTLMCSolution() 33 | -------------------------------------------------------------------------------- /ml2/tools/spot/spot_equiv_verifier.py: -------------------------------------------------------------------------------- 1 | """Spot symbolic trace model checker""" 2 | 3 | from ...ltl.ltl_equiv import LTLEquivStatus 4 | from ...ltl.ltl_formula import LTLFormula 5 | from ...verifier import EquivVerifier 6 | from .spot import Spot 7 | 8 | 9 | class 
SpotEquivVerifier(Spot, EquivVerifier): 10 | def verify_equiv(self, x: LTLFormula, y: LTLFormula, **kwargs) -> LTLEquivStatus: 11 | return self.check_equiv(f=x, g=y, timeout=10) 12 | -------------------------------------------------------------------------------- /ml2/tools/spot/spot_strace_mc.py: -------------------------------------------------------------------------------- 1 | """Spot symbolic trace model checker""" 2 | 3 | from ...ltl import LTLFormula 4 | from ...registry import register_type 5 | from ...trace import SymbolicTrace, TraceMCStatus 6 | from ...verifier import Verifier 7 | from .spot import Spot 8 | 9 | 10 | @register_type 11 | class SpotSTraceMC(Spot, Verifier): 12 | def verify( 13 | self, problem: LTLFormula, solution: SymbolicTrace, timeout: int = 600 14 | ) -> TraceMCStatus: 15 | return self.mc_trace( 16 | formula=problem.to_str(notation="infix"), 17 | trace=solution.to_str(notation="infix", spot=True), 18 | timeout=timeout, 19 | ) 20 | -------------------------------------------------------------------------------- /ml2/tools/strix/__init__.py: -------------------------------------------------------------------------------- 1 | from ...utils import is_ray_available 2 | from . import strix_wrapper 3 | from .strix import Strix 4 | 5 | if is_ray_available(): 6 | from .strix_worker import add_strix_args, strix_worker_fn, strix_worker_fn_dict 7 | -------------------------------------------------------------------------------- /ml2/tools/syfco/__init__.py: -------------------------------------------------------------------------------- 1 | from . import tlsf_to_bosy 2 | from .syfco import Syfco 3 | -------------------------------------------------------------------------------- /ml2/trace/__init__.py: -------------------------------------------------------------------------------- 1 | from ..utils import is_pt_available, is_tf_available 2 | from .symbolic_trace import SymbolicTrace 3 | from .trace import Trace 4 | from .trace_mc_status import TraceMCStatus 5 | 6 | if is_pt_available() and is_tf_available(): 7 | from .sym_trace_to_seq_tokenizer import SymTraceToSeqTokenizer 8 | -------------------------------------------------------------------------------- /ml2/trace/sym_trace_to_seq_tokenizer.py: -------------------------------------------------------------------------------- 1 | """Tokenizer that encodes a symbolic trace into a sequence encoding""" 2 | 3 | import logging 4 | from typing import List, Type 5 | 6 | from ..registry import register_type 7 | from ..tokenizers import SeqToSeqTokenizer 8 | from .symbolic_trace import SymbolicTrace 9 | 10 | logging.basicConfig(level=logging.INFO) 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | @register_type 15 | class SymTraceToSeqTokenizer(SeqToSeqTokenizer): 16 | def __init__( 17 | self, notation: str = "infix", dtype: Type[SymbolicTrace] = SymbolicTrace, **kwargs 18 | ) -> None: 19 | self.notation = notation 20 | super().__init__(dtype=dtype, **kwargs) 21 | 22 | def encode_tokens(self, data: SymbolicTrace, **kwargs) -> List[str]: 23 | return data.to_tokens(notation=self.notation, **kwargs) 24 | 25 | def decode_tokens(self, tokens: List[str], **kwargs) -> SymbolicTrace: 26 | return SymbolicTrace.from_tokens(tokens=tokens, notation=self.notation, **kwargs) 27 | -------------------------------------------------------------------------------- /ml2/train/__init__.py: -------------------------------------------------------------------------------- 1 | from ..utils import is_hf_available, is_ray_available, is_tf_available 
2 | from .load_trainer import load_trainer 3 | from .trainer import Trainer 4 | 5 | if is_tf_available(): 6 | from .keras_trainer import KerasTrainer 7 | from .keras_trainer_ddp import KerasTrainerDDP 8 | from .keras_transformer_trainer import KerasTransformerTrainer 9 | 10 | if is_hf_available(): 11 | from .hf_seq2seq_trainer import HFSeq2SeqTrainer 12 | -------------------------------------------------------------------------------- /ml2/train/keras_transformer_trainer.py: -------------------------------------------------------------------------------- 1 | """Keras Transformer trainer""" 2 | 3 | import logging 4 | 5 | from keras.optimizers import Adam, Optimizer 6 | 7 | from ..optim.tf_optim.tf_transformer_lr_schedule import TFTransformerLRSchedule 8 | from ..pipelines import TFPipeline 9 | from ..registry import register_type 10 | from .keras_trainer import KerasTrainer 11 | 12 | logging.basicConfig(level=logging.INFO) 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | @register_type 17 | class KerasTransformerTrainer(KerasTrainer): 18 | def __init__( 19 | self, pipeline: TFPipeline, optimizer: Optimizer = None, warmup_steps: int = 4000, **kwargs 20 | ): 21 | if optimizer is None: 22 | learning_rate = TFTransformerLRSchedule( 23 | pipeline.model_config["d_embed_enc"], warmup_steps 24 | ) 25 | optimizer = Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9) 26 | 27 | self.warmup_steps = warmup_steps 28 | 29 | super().__init__(pipeline=pipeline, optimizer=optimizer, **kwargs) 30 | -------------------------------------------------------------------------------- /ml2/train/load_trainer.py: -------------------------------------------------------------------------------- 1 | """Utility to load trainer""" 2 | 3 | import logging 4 | 5 | from .trainer import Trainer 6 | 7 | logging.basicConfig(level=logging.INFO) 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def load_trainer(name: str, project: str = None, **kwargs) -> Trainer: 12 | from ..registry import type_from_str 13 | 14 | config = Trainer.fetch_config(name=name, project=project) 15 | if "type" not in config: 16 | raise Exception("Trainer type not specified in config") 17 | trainer_type = type_from_str(config["type"], bound=Trainer) 18 | return trainer_type.load(name=name, project=project, **kwargs) 19 | -------------------------------------------------------------------------------- /ml2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .import_utils import is_hf_available, is_pt_available, is_ray_available, is_tf_available 2 | -------------------------------------------------------------------------------- /ml2/utils/dict_utils.py: -------------------------------------------------------------------------------- 1 | """Dictionary utilities""" 2 | 3 | from typing import Callable, Dict 4 | 5 | 6 | def map_nested_dict(f: Callable, d: Dict) -> None: 7 | for k, v in d.items(): 8 | if isinstance(v, dict): 9 | map_nested_dict(f, v) 10 | else: 11 | d[k] = f(v) 12 | -------------------------------------------------------------------------------- /ml2/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | """Distribution and Architecture utilities""" 2 | 3 | import platform 4 | 5 | 6 | def architecture_is_apple_arm() -> bool: 7 | return platform.system() == "Darwin" and platform.machine() == "arm64" 8 | -------------------------------------------------------------------------------- 
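The dictionary utility above rewrites a (possibly nested) dictionary in place rather than returning a new one, which is easy to misuse. A minimal sketch, illustration only:

# Usage sketch (illustration only): map_nested_dict applies f to every leaf value
# and mutates the dictionary in place; its return value is None.
from ml2.utils.dict_utils import map_nested_dict

config = {"model": {"d_embed": "256", "dropout": "0.1"}, "warmup_steps": "4000"}
map_nested_dict(float, config)
# config is now {"model": {"d_embed": 256.0, "dropout": 0.1}, "warmup_steps": 4000.0}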
/ml2/utils/import_utils.py: -------------------------------------------------------------------------------- 1 | """Utilities for checking if a library is available""" 2 | 3 | import importlib 4 | 5 | 6 | def is_pt_available() -> bool: 7 | """Check if PyTorch is available""" 8 | return importlib.util.find_spec("torch") is not None 9 | 10 | 11 | def is_tf_available() -> bool: 12 | """Check if Tensorflow is available""" 13 | return importlib.util.find_spec("tensorflow") is not None 14 | 15 | 16 | def is_hf_available() -> bool: 17 | """Check if HuggingFace is available""" 18 | return importlib.util.find_spec("transformers") is not None 19 | 20 | 21 | def is_ray_available() -> bool: 22 | """Check if Ray is available""" 23 | return importlib.util.find_spec("ray") is not None 24 | -------------------------------------------------------------------------------- /ml2/utils/list_utils.py: -------------------------------------------------------------------------------- 1 | """List utilities""" 2 | 3 | from typing import List, TypeVar 4 | 5 | T = TypeVar("T") 6 | 7 | 8 | def join_lists(delimiter: T, lists: List[List[T]]) -> List[T]: 9 | result: List[T] = [] 10 | for l in lists: 11 | if result != []: 12 | result.append(delimiter) 13 | result.extend(l) 14 | return result 15 | -------------------------------------------------------------------------------- /ml2/utils/np_utils.py: -------------------------------------------------------------------------------- 1 | """NumPy utilities""" 2 | 3 | import numpy as np 4 | 5 | str_to_np_float_dtype = {"float16": np.float16, "float32": np.float32, "float64": np.float64} 6 | np_float_dtype_to_str = {c: s for s, c in str_to_np_float_dtype.items()} 7 | 8 | str_to_np_int_dtype = {"int16": np.int16, "int32": np.int32, "int64": np.int64} 9 | np_int_dtype_to_str = {c: s for s, c in str_to_np_int_dtype.items()} 10 | -------------------------------------------------------------------------------- /ml2/utils/pt_utils.py: -------------------------------------------------------------------------------- 1 | """PyTorch utilities""" 2 | 3 | import torch 4 | 5 | str_to_pt_float_dtype = { 6 | "float16": torch.float16, 7 | "float32": torch.float32, 8 | "float64": torch.float64, 9 | } 10 | pt_float_dtype_to_str = {c: s for s, c in str_to_pt_float_dtype.items()} 11 | 12 | str_to_pt_int_dtype = {"int16": torch.int16, "int32": torch.int32, "int64": torch.int64} 13 | pt_int_dtype_to_str = {c: s for s, c in str_to_pt_int_dtype.items()} 14 | -------------------------------------------------------------------------------- /ml2/utils/tf_utils.py: -------------------------------------------------------------------------------- 1 | """TensorFlow utilities""" 2 | 3 | import tensorflow as tf 4 | 5 | str_to_tf_float_dtype = {"float16": tf.float16, "float32": tf.float32, "float64": tf.float64} 6 | tf_float_dtype_to_str = {c: s for s, c in str_to_tf_float_dtype.items()} 7 | 8 | str_to_tf_int_dtype = {"int16": tf.int16, "int32": tf.int32, "int64": tf.int64} 9 | tf_int_dtype_to_str = {c: s for s, c in str_to_tf_int_dtype.items()} 10 | -------------------------------------------------------------------------------- /ml2/verifier/__init__.py: -------------------------------------------------------------------------------- 1 | from .load_verifier import load_verifier_from_general_config 2 | from .equiv_status import EquivStatus 3 | from .equiv_verifier import EquivVerifier 4 | from .verifier import Verifier 5 | -------------------------------------------------------------------------------- 
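The list and dtype utilities above are small but central to the tokenizer pipelines: join_lists behaves like str.join for token lists, and the NumPy/PyTorch/TensorFlow tables provide two-way lookups between dtype names and framework dtypes. A short sketch, illustration only:

# Usage sketch (illustration only): join_lists inserts the delimiter between
# consecutive sub-lists, analogous to str.join for lists of tokens.
from ml2.utils.list_utils import join_lists
from ml2.utils.np_utils import np_float_dtype_to_str, str_to_np_float_dtype

assert join_lists(";", [["a", "b"], ["c"], ["d"]]) == ["a", "b", ";", "c", ";", "d"]

# The dtype tables are plain two-way dictionaries:
dtype = str_to_np_float_dtype["float32"]
assert np_float_dtype_to_str[dtype] == "float32"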
/ml2/verifier/equiv_status.py: -------------------------------------------------------------------------------- 1 | """Abstract equivalence status class""" 2 | 3 | from abc import abstractmethod 4 | from typing import Optional 5 | 6 | from ..dtypes import DType 7 | 8 | 9 | class EquivStatus(DType): 10 | @property 11 | @abstractmethod 12 | def equiv(self) -> Optional[bool]: 13 | raise NotImplementedError() 14 | 15 | @property 16 | @abstractmethod 17 | def status(self) -> str: 18 | raise NotImplementedError() 19 | -------------------------------------------------------------------------------- /ml2/verifier/equiv_verifier.py: -------------------------------------------------------------------------------- 1 | """Abstract equivalence verifier class""" 2 | 3 | from abc import abstractmethod 4 | from typing import Generic, TypeVar 5 | 6 | from ..configurable import Configurable 7 | from ..dtypes import DType 8 | from .equiv_status import EquivStatus 9 | 10 | T = TypeVar("T", bound=DType) 11 | 12 | 13 | class EquivVerifier(Configurable, Generic[T]): 14 | @abstractmethod 15 | def verify_equiv(self, x: T, y: T, **kwargs) -> EquivStatus: 16 | raise NotImplementedError() 17 | -------------------------------------------------------------------------------- /ml2/verifier/load_verifier.py: -------------------------------------------------------------------------------- 1 | """Utility to load verifier""" 2 | 3 | import logging 4 | 5 | from .verifier import Verifier 6 | 7 | logging.basicConfig(level=logging.INFO) 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def load_verifier_from_general_config(config, log_name: str = "verifier") -> Verifier: 12 | from ..registry import type_from_str 13 | 14 | if isinstance(config, str): 15 | verifier = type_from_str(config, bound=Verifier)() 16 | 17 | elif isinstance(config, dict): 18 | if "type" not in config: 19 | raise Exception(f"Type not specified in {log_name} config {config}") 20 | verifier_type = type_from_str(config.pop("type"), bound=Verifier) 21 | verifier = verifier_type.from_config(config) 22 | 23 | else: 24 | raise Exception(f"Invalid {log_name} config {config}") 25 | return verifier 26 | -------------------------------------------------------------------------------- /ml2/verifier/verifier.py: -------------------------------------------------------------------------------- 1 | """Abstract verifier class""" 2 | 3 | from abc import abstractmethod 4 | from typing import Generic, TypeVar 5 | 6 | from ..configurable import Configurable 7 | from ..dtypes import DType 8 | 9 | P = TypeVar("P", bound=DType) 10 | S = TypeVar("S", bound=DType) 11 | 12 | 13 | class Verifier(Configurable, Generic[P, S]): 14 | @abstractmethod 15 | def verify(self, problem: P, solution: S, **kwargs): 16 | raise NotImplementedError() 17 | -------------------------------------------------------------------------------- /notebooks/models/count_params.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Count Model Parameters" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from keras.utils.layer_utils import count_params\n", 17 | "import ml2\n", 18 | "from ml2.experiment.experiment import Experiment" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "exp: Experiment = 
ml2.load(\"ltl-syn/ht-12\")\n", 28 | "count_params(exp.trainer.pipeline.eval_model.trainable_weights)" 29 | ] 30 | } 31 | ], 32 | "metadata": { 33 | "kernelspec": { 34 | "display_name": "ml2", 35 | "language": "python", 36 | "name": "python3" 37 | }, 38 | "language_info": { 39 | "codemirror_mode": { 40 | "name": "ipython", 41 | "version": 3 42 | }, 43 | "file_extension": ".py", 44 | "mimetype": "text/x-python", 45 | "name": "python", 46 | "nbconvert_exporter": "python", 47 | "pygments_lexer": "ipython3", 48 | "version": "3.8.15" 49 | }, 50 | "orig_nbformat": 4 51 | }, 52 | "nbformat": 4, 53 | "nbformat_minor": 2 54 | } 55 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=46.4", 4 | "wheel" 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [tool.black] 9 | line-length = 99 10 | 11 | [tool.isort] 12 | line_length = 99 13 | profile = "black" 14 | -------------------------------------------------------------------------------- /scripts/protoc.sh: -------------------------------------------------------------------------------- 1 | # Compiles protocol buffers files 2 | 3 | GRPC_PATH="../ml2/grpc" 4 | 5 | if [ $2 ] 6 | then 7 | PROTO_PATH=$GRPC_PATH/$1/$2.proto 8 | else 9 | PROTO_PATH=$GRPC_PATH/$1/$1.proto 10 | fi 11 | 12 | echo "Compiling protocol buffer at $PROTO_PATH" 13 | 14 | python -m grpc_tools.protoc --grpc_python_out=../ --python_out=../ --proto_path=../ --mypy_out=../ $PROTO_PATH -------------------------------------------------------------------------------- /tests/ltl/ltl_equiv/ltl_equiv_test.py: -------------------------------------------------------------------------------- 1 | """LTL formula equiv test""" 2 | 3 | from ml2.ltl.ltl_equiv import LTLEquivStatus, LTLInclStatus 4 | 5 | 6 | def test_ltl_equiv_1(): 7 | assert LTLEquivStatus("inequivalent").to_int() == 0 8 | assert LTLEquivStatus("error").to_int() == -1 9 | assert LTLEquivStatus("timeout").to_int() == -2 10 | assert LTLEquivStatus("equivalent").to_int() == 1 11 | assert LTLEquivStatus("equivalent").equiv 12 | 13 | 14 | def test_ltl_equiv_2(): 15 | try: 16 | LTLEquivStatus("satisfied") 17 | assert False 18 | except Exception as e: 19 | assert "invalid" in str(e).lower() 20 | 21 | 22 | def test_ltl_incl_1(): 23 | assert LTLInclStatus("incomparable").to_int() == 0 24 | assert LTLInclStatus("error").to_int() == -1 25 | assert LTLInclStatus("timeout").to_int() == -2 26 | assert LTLInclStatus("equivalent").to_int() == 1 27 | assert LTLInclStatus("equivalent").equiv 28 | assert LTLInclStatus("equivalent").left_in_right 29 | assert LTLInclStatus("equivalent").right_in_left 30 | assert LTLInclStatus("only_left_in_right").to_int() == 2 31 | assert LTLInclStatus("only_left_in_right").left_in_right 32 | assert LTLInclStatus("only_right_in_left").to_int() == 3 33 | assert LTLInclStatus("only_right_in_left").right_in_left 34 | 35 | 36 | def test_ltl_incl_2(): 37 | try: 38 | LTLInclStatus("inequivalent") 39 | assert False 40 | except Exception as e: 41 | assert "invalid" in str(e).lower() 42 | -------------------------------------------------------------------------------- /tests/ltl/ltl_spec/load_jarvis_test.py: -------------------------------------------------------------------------------- 1 | """jarvis-0 dataset test""" 2 | 3 | from ml2.datasets import load_dataset 4 | from ml2.ltl.ltl_spec import DecompLTLSpec 5 | 6 | 7 | def 
test_load_jarvis_0(): 8 | ds = load_dataset("ltl-spec/jarvis-0") 9 | assert ds.name == "jarvis-0" 10 | assert ds.project == "ltl-spec" 11 | assert ds.dtype == DecompLTLSpec 12 | assert ds.size == 189 13 | -------------------------------------------------------------------------------- /tests/ltl/ltl_spec/load_sc_test.py: -------------------------------------------------------------------------------- 1 | """sc-0 dataset test""" 2 | 3 | from ml2.datasets import load_dataset 4 | from ml2.ltl.ltl_spec import DecompLTLSpec 5 | 6 | 7 | def test_load_sc_0(): 8 | ds = load_dataset("ltl-spec/sc-0") 9 | assert ds.name == "sc-0" 10 | assert ds.project == "ltl-spec" 11 | assert ds.dtype == DecompLTLSpec 12 | assert ds.size == 346 13 | -------------------------------------------------------------------------------- /tests/ltl/ltl_syn/load_scpa_test.py: -------------------------------------------------------------------------------- 1 | """scpa-2 dataset test""" 2 | 3 | import pytest 4 | 5 | from ml2.datasets import load_dataset 6 | 7 | 8 | @pytest.mark.gcp 9 | def test_load_scpa_2(): 10 | ds = load_dataset("ltl-syn/scpa-2") 11 | assert ds.name == "scpa-2" 12 | assert "train" in ds 13 | assert ds["train"].size == 200000 14 | assert "val" in ds 15 | assert ds["val"].size == 25000 16 | assert "test" in ds 17 | assert ds["test"].size == 25000 18 | assert "timeouts" in ds 19 | assert ds["timeouts"].size == 13403 20 | 21 | 22 | @pytest.mark.gcp 23 | def test_load_with_sample_scpa_2(): 24 | ds = load_dataset("ltl-syn/scpa-2/val", sample=1000) 25 | assert ds.size == 1000 26 | 27 | 28 | @pytest.mark.gcp 29 | def test_load_with_shuffle(): 30 | ds_1 = load_dataset("ltl-syn/scpa-2/val") 31 | ds_2 = load_dataset("ltl-syn/scpa-2/val") 32 | ds_3 = load_dataset("ltl-syn/scpa-2/val", shuffle=True) 33 | assert ds_1[0] == ds_2[0] 34 | # TODO small chance this fails 35 | assert ds_1[0] != ds_3[0] 36 | -------------------------------------------------------------------------------- /tests/pipelines/metrics/acc_per_seq_test.py: -------------------------------------------------------------------------------- 1 | """Accuracy per sequence test""" 2 | 3 | from ml2.pipelines.metrics import AccPerSeq 4 | from ml2.pipelines.samples import EvalLabeledSample 5 | 6 | 7 | def test_acc_add(): 8 | metric = AccPerSeq() 9 | sample1 = EvalLabeledSample(inp=None, tar=None, tar_enc=[0, 0, 0], pred_enc=[0, 0, 0]) 10 | assert metric.add(sample1) == 1 11 | 12 | sample2 = EvalLabeledSample(inp=None, tar=None, tar_enc=[1, 1, 1], pred_enc=[0, 1, 1]) 13 | assert metric.add(sample2) == 0 14 | assert metric.compute() == 0.5 15 | 16 | metric.reset() 17 | assert metric.compute() == 0.0 18 | 19 | sample3 = EvalLabeledSample(inp=None, tar=None, tar_enc=[0], pred_enc=[1]) 20 | assert metric.add(sample3) == 0.0 21 | assert metric.compute() == 0.0 22 | -------------------------------------------------------------------------------- /tests/pipelines/metrics/acc_test.py: -------------------------------------------------------------------------------- 1 | """Accuracy test""" 2 | 3 | from ml2.pipelines.metrics import Acc 4 | from ml2.pipelines.samples import EvalLabeledSample 5 | 6 | 7 | def test_acc(): 8 | metric = Acc() 9 | sample1 = EvalLabeledSample(inp=None, tar=None, tar_enc=[0, 0, 1], pred_enc=[0, 0, 0]) 10 | assert metric.add(sample1) == 2 / 3 11 | 12 | sample2 = EvalLabeledSample(inp=None, tar=None, tar_enc=[2, 2, 2], pred_enc=[2, 2, 2]) 13 | assert metric.add(sample2) == 1.0 14 | assert metric.compute() == 5 / 6 15 | 16 | metric.reset() 17 | assert 
metric.compute() == 0.0 18 | 19 | sample3 = EvalLabeledSample(inp=None, tar=None, tar_enc=[-1], pred_enc=[-1]) 20 | assert metric.add(sample3) == 1.0 21 | assert metric.compute() == 1.0 22 | 23 | 24 | def test_pad_acc(): 25 | metric = Acc(pad_same_length=True) 26 | 27 | sample1 = EvalLabeledSample(inp=None, tar=None, tar_enc=[1, 1, 0, 0], pred_enc=[1, 1]) 28 | assert metric.add(sample1) == 1.0 29 | 30 | sample2 = EvalLabeledSample(inp=None, tar=None, tar_enc=[1, 3, 2, 1], pred_enc=[1, 1, 1]) 31 | assert metric.add(sample2) == 1 / 4 32 | 33 | assert metric.compute() == 5 / 8 34 | 35 | metric.reset() 36 | 37 | assert metric.compute() == 0.0 38 | -------------------------------------------------------------------------------- /tests/pipelines/metrics/err_counter_test.py: -------------------------------------------------------------------------------- 1 | """Error counting test""" 2 | 3 | 4 | from ml2.pipelines.metrics import ErrCounter, EvalSupervisedErrCounter, SupervisedErrCounter 5 | from ml2.pipelines.samples import EncodedSample, EvalLabeledSample, LabeledSample 6 | 7 | 8 | def test_err_counter(): 9 | metric = ErrCounter() 10 | sample1 = EncodedSample(inp=None, inp_enc_err=None) 11 | assert metric.add(sample1) is False 12 | 13 | sample2 = EncodedSample(inp=None, inp_enc_err="Parse error") 14 | assert metric.add(sample2) is True 15 | assert metric.compute_dict()["inp_enc_errs"] == 1 16 | 17 | metric.reset() 18 | assert metric.compute_dict()["inp_enc_errs"] == 0 19 | 20 | 21 | def test_supervised_err_counter(): 22 | metric = SupervisedErrCounter() 23 | sample1 = LabeledSample(inp=None, tar=None) 24 | assert metric.add(sample1) is False 25 | 26 | sample2 = LabeledSample(inp=None, tar=None, inp_enc_err="Parse error") 27 | assert metric.add(sample2) is True 28 | 29 | sample3 = LabeledSample(inp=None, tar=None, tar_enc_err="Lex error") 30 | assert metric.add(sample3) is True 31 | 32 | assert metric.compute_dict()["inp_enc_errs"] == 1 33 | assert metric.compute_dict()["tar_enc_errs"] == 1 34 | 35 | 36 | def test_eval_supervised_err_counter(): 37 | metric = EvalSupervisedErrCounter() 38 | 39 | sample1 = EvalLabeledSample(inp=None, tar=None, pred_dec_err="Parse error") 40 | assert metric.add(sample1) is True 41 | 42 | assert metric.compute_dict()["pred_dec_errs"] == 1 43 | -------------------------------------------------------------------------------- /tests/pipelines/tf_transformer_pipeline_test.py: -------------------------------------------------------------------------------- 1 | """TensorFlow Transformer pipeline test""" 2 | import pathlib 3 | 4 | import pytest 5 | 6 | 7 | @pytest.fixture() 8 | def test_config_path(request): 9 | path = pathlib.Path(request.node.fspath) 10 | return path.with_name("tf_transformer_pipeline_test_config.json") 11 | 12 | 13 | def test_tf_transformer_pipeline(test_config_path): 14 | pass 15 | -------------------------------------------------------------------------------- /tests/pipelines/tf_transformer_pipeline_test_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tf-transformer-test-pipeline", 3 | "project": "test-project", 4 | "type": "TFTransformerPipeline", 5 | "custom_pos_enc": true, 6 | "model_config": { 7 | "alpha": 0.5, 8 | "beam_size": 2, 9 | "d_embed": 256, 10 | "d_ff": 1024, 11 | "dropout": 0.0, 12 | "dtype_float": "float32", 13 | "dtype_int": "int32", 14 | "ff_activation": "relu", 15 | "num_heads": 4, 16 | "num_layers": 8 17 | }, 18 | "input_tokenizer": { 19 | "type": 
"ExprToSeqTPETokenizer", 20 | "dtype": "LTLFormula", 21 | "notation": "prefix", 22 | "start": false, 23 | "eos": false, 24 | "pad": 128, 25 | "tpe_format": "branch-down", 26 | "tpe_pad": 256 27 | }, 28 | "target_tokenizer": { 29 | "type": "SymTraceToSeqTokenizer", 30 | "notation": "prefix", 31 | "start": false, 32 | "eos": true, 33 | "pad": 128 34 | }, 35 | "max_input_length": 128, 36 | "max_target_length": 128 37 | } 38 | -------------------------------------------------------------------------------- /tests/prop/assignment_test.py: -------------------------------------------------------------------------------- 1 | """Assignment test""" 2 | 3 | 4 | from ml2.prop import Assignment 5 | 6 | 7 | def test_assignment_dict(): 8 | assign = Assignment() 9 | assign["a"] = True 10 | assert "a" in assign 11 | assert assign["a"] 12 | assign["a"] = False 13 | assert not assign["a"] 14 | 15 | 16 | def test_assignment_str(): 17 | a1 = Assignment.from_str("a1,b0,c1", delimiter=",", value_type="num") 18 | a2 = Assignment.from_str( 19 | "a=True,b=False,c=True", assign_op="=", delimiter=",", value_type="bool" 20 | ) 21 | a3 = Assignment.from_str("a, ! b, c", not_op="!", delimiter=",") 22 | assert a1 == a2 23 | assert a1 == a3 24 | assert a1.to_str(assign_op="=", delimiter=",", value_type="bool") == "a=True,b=False,c=True" 25 | 26 | 27 | def test_assignment_tokens(): 28 | a1 = Assignment.from_tokens(["a", "1", "b", "0", "c", "1"], value_type="num") 29 | a2 = Assignment.from_tokens(["a", "True", "b", "False", "c", "True"], value_type="bool") 30 | a3 = Assignment.from_tokens(["a", ",", "!", "b", ",", "c"], not_op="!", delimiter=",") 31 | assert a1 == a2 32 | assert a1 == a3 33 | assert a1.to_tokens(value_type="bool") == ["a", "True", "b", "False", "c", "True"] 34 | 35 | 36 | def test_assignment_empty(): 37 | assert Assignment.from_str(" ") == Assignment() 38 | -------------------------------------------------------------------------------- /tests/prop/cnf/cnf_formula_test.py: -------------------------------------------------------------------------------- 1 | """CNF formula test""" 2 | 3 | from pathlib import Path 4 | 5 | from ml2.prop.cnf.cnf_formula import CNFFormula 6 | 7 | CSV_FIELDS = { 8 | "formula": "p cnf 10 49\\n3 2 9 0\\n-8 4 -9 -10 2 0\\n-9 -6 0\\n-9 1 10 0\\n6 4 -8 3 0\\n5 -10 2 6 0\\n-3 -1 -8 0\\n-8 -5 -2 9 0\\n3 7 -1 6 4 -8 10 5 9 0\\n-5 -4 2 0\\n-4 8 -6 2 0\\n5 -9 -8 0\\n4 -9 -8 -10 0\\n6 -2 9 0\\n10 -4 1 0\\n6 -5 -7 9 -10 -1 8 4 3 0\\n6 -8 9 10 -5 0\\n4 2 10 -3 0\\n8 6 7 0\\n-4 1 8 0\\n10 9 -1 -2 0\\n-5 -1 3 0\\n5 -8 6 -7 0\\n-6 -1 0\\n10 9 0\\n-8 9 4 0\\n2 -6 8 0\\n7 8 1 0\\n-2 8 10 1 0\\n-8 7 -10 0\\n6 1 -10 7 4 0\\n-10 4 -7 0\\n-8 -7 3 0\\n7 4 3 9 10 -8 2 -1 6 -5 0\\n-9 6 4 8 0\\n2 7 -10 0\\n-3 5 -2 0\\n7 -10 -2 4 0\\n-4 -2 10 -8 0\\n-9 4 -3 -7 1 -8 -5 0\\n-4 6 3 0\\n-6 -7 0\\n10 -1 5 -7 -3 0\\n-1 10 0\\n-1 -3 -10 -7 0\\n2 6 5 -4 0\\n4 -1 -8 6 0\\n-3 2 -5 0\\n-7 -2 -10 0\\n" 9 | } 10 | 11 | DIMACS_FILEPATH = Path(__file__).parent / "formula.dimacs" 12 | 13 | 14 | def test_cnf_formula(): 15 | formula = CNFFormula.from_csv_fields(CSV_FIELDS) 16 | assert formula.num_clauses == 49 17 | assert formula.num_vars == 10 18 | assert formula.to_csv_fields() == CSV_FIELDS 19 | 20 | 21 | def test_cnf_formula_from_dimacs_file(): 22 | formula = CNFFormula.from_dimacs_file(DIMACS_FILEPATH) 23 | assert formula.num_clauses == 7 24 | assert formula.num_vars == 5 25 | with open(DIMACS_FILEPATH, "r") as file: 26 | assert formula.to_str(notation="dimacs") == file.read() 27 | 
-------------------------------------------------------------------------------- /tests/prop/cnf/formula.dimacs: -------------------------------------------------------------------------------- 1 | c lorem ipsum 2 | c dolor sit amet 3 | p cnf 5 7 4 | 2 -3 -5 0 5 | 1 2 -3 5 0 6 | 4 5 0 7 | 2 5 0 8 | 2 3 -5 0 9 | 4 1 0 10 | 4 2 0 11 | -------------------------------------------------------------------------------- /tests/prop/cnf/res_proof_test.py: -------------------------------------------------------------------------------- 1 | """Res proof test""" 2 | 3 | from ml2.prop.cnf.res_proof import ResProof 4 | 5 | CSV_FIELDS = { 6 | "res_proof": "3 -6 -9 0 0\\n4 10 -9 1 0 0\\n5 -8 4 3 6 0 0\\n6 5 -10 2 6 0 0\\n7 -1 -8 -3 0 0\\n10 -4 -5 2 0 0\\n12 5 -9 -8 0 0\\n13 4 -9 -8 -10 0 0\\n14 -2 9 6 0 0\\n19 7 6 8 0 0\\n20 -4 8 1 0 0\\n24 -1 -6 0 0\\n25 10 9 0 0\\n28 7 8 1 0 0\\n30 7 -10 -8 0 0\\n32 4 -10 -7 0 0\\n33 -8 -7 3 0 0\\n35 4 -9 8 6 0 0\\n41 -4 6 3 0 0\\n42 -6 -7 0 0\\n44 -1 10 0 0\\n45 -1 -10 -3 -7 0 0\\n48 2 -5 -3 0 0\\n49 -2 -10 -7 0 0\\n51 -6 7 0 30 28 25 24 3 0\\n52 7 3 0 5 41 19 51 0\\n53 3 0 25 32 52 41 42 35 33 0\\n54 -1 -8 0 7 53 0\\n59 -10 -1 -7 0 45 53 0\\n60 -5 2 0 48 53 0\\n61 6 9 0 6 60 14 25 0\\n62 6 8 0 44 20 59 19 35 61 0\\n63 8 0 51 42 62 0\\n64 -1 0 54 63 0\\n66 10 -9 0 4 64 0\\n68 5 -9 0 12 63 0\\n69 4 -10 -9 0 13 63 0\\n74 -10 7 0 30 63 0\\n78 6 0 10 49 74 66 61 68 69 0\\n80 -7 0 42 78 0\\n81 0 51 80 78 0\\n" 7 | } 8 | 9 | 10 | def test_res_proof(): 11 | proof = ResProof.from_csv_fields(CSV_FIELDS) 12 | assert proof.to_csv_fields() == CSV_FIELDS 13 | -------------------------------------------------------------------------------- /tests/prop/prop_formula_test.py: -------------------------------------------------------------------------------- 1 | """Propositional formula test""" 2 | 3 | from ml2.prop.prop_formula import PropFormula 4 | 5 | FORMULA_STR = "a <-> (! b & c)" 6 | 7 | 8 | def test_prop_formula(): 9 | formula = PropFormula.from_str(FORMULA_STR) 10 | assert formula.to_str() == FORMULA_STR 11 | assert formula.to_str(notation="prefix") == "<-> a & ! 
b c" 12 | -------------------------------------------------------------------------------- /tests/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | docker: requires running a Docker container 4 | gcp: requires access to Google Cloud Platform storage buckets 5 | tf: requires TensorFlow to be installed -------------------------------------------------------------------------------- /tests/tf_test.py: -------------------------------------------------------------------------------- 1 | """TensorFlow tests""" 2 | 3 | import pytest 4 | import tensorflow as tf 5 | 6 | 7 | # in some cases we observed numerical issues where, for example, the result of a dense layer depends on whether one or two vectors are passed 8 | @pytest.mark.tf 9 | class TFNumericIssueTest(tf.test.TestCase): 10 | def setUp(self): 11 | super().setUp() 12 | self.layer = tf.keras.layers.Dense(16) 13 | 14 | def test_numeric_issue(self): 15 | v1 = tf.random.normal([16]) 16 | v2 = tf.random.normal([16]) 17 | i1 = tf.stack([v1]) 18 | i2 = tf.stack([v1, v2]) 19 | o1 = self.layer(i1) 20 | o2 = self.layer(i2) 21 | self.assertAllEqual(o1[0], o2[0]) 22 | -------------------------------------------------------------------------------- /tests/tokenizers/expr_tokenizer/expr_to_seq_tokenizer_config_test.py: -------------------------------------------------------------------------------- 1 | """ExprToSeqTokenizer config test""" 2 | 3 | from ml2.ltl import LTLFormula 4 | from ml2.tokenizers import ExprToSeqTokenizer 5 | 6 | EXPR_TO_SEQ_TOKENIZER_CONFIG = { 7 | "name": "expr-to-seq-tokenizer", 8 | "project": "test", 9 | "dtype": "LTLFormula", 10 | "notation": "prefix", 11 | "pad": 16, 12 | "eos": False, 13 | "start": False, 14 | "vocabulary": { 15 | "name": "expr-to-seq-tokenizer/vocabulary", 16 | "project": "test", 17 | "token_to_id": { 18 | "<p>
": 0, 19 | "a": 3, 20 | "b": 4, 21 | "c": 5, 22 | "d": 6, 23 | "e": 7, 24 | "U": 8, 25 | "X": 9, 26 | "&": 10, 27 | "!": 11, 28 | }, 29 | }, 30 | } 31 | 32 | 33 | def test_expr_to_seq_tokenizer_config(): 34 | tokenizer = ExprToSeqTokenizer.from_config(EXPR_TO_SEQ_TOKENIZER_CONFIG) 35 | formula = LTLFormula.from_str("a U b & ! c") 36 | encoding = tokenizer.encode(formula) 37 | assert encoding.ids == [10, 8, 3, 4, 11, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 38 | config = tokenizer.get_config() 39 | assert config["name"] == "expr-to-seq-tokenizer" 40 | assert config["project"] == "test" 41 | assert config["dtype"] == "LTLFormula" 42 | assert config["notation"] == "prefix" 43 | assert config["pad"] == 16 44 | assert not config["eos"] 45 | assert not config["start"] 46 | assert config["type"] == "ExprToSeqTokenizer" 47 | -------------------------------------------------------------------------------- /tests/tokenizers/load_vocab_test.py: -------------------------------------------------------------------------------- 1 | """Load vocabulary test""" 2 | 3 | import pytest 4 | 5 | from ml2.tokenizers import Vocabulary 6 | 7 | 8 | @pytest.mark.gcp 9 | def test_load_vocab(): 10 | vocab = Vocabulary.load( 11 | name="ht-12/train/pipe/input-tokenizer/sub-tokenizer/vocab", project="ltl-syn" 12 | ) 13 | assert vocab.tokens_to_ids(["i0", "U", "o0"]) == [11, 8, 26] 14 | 15 | 16 | @pytest.mark.gcp 17 | def test_load_vocab_with_kwargs(): 18 | NEW_TOKEN_TO_ID = {"a": 0, "b": 1, "U": 2} 19 | vocab = Vocabulary.load( 20 | name="ht-12/train/pipe/input-tokenizer/sub-tokenizer/vocab", 21 | project="ltl-syn", 22 | token_to_id=NEW_TOKEN_TO_ID, 23 | ) 24 | assert vocab.tokens_to_ids(["a", "U", "b"]) == [0, 2, 1] 25 | -------------------------------------------------------------------------------- /tests/tokenizers/seq_tokenizer/seq_to_seq_tokenizer_config_test.py: -------------------------------------------------------------------------------- 1 | """SeqToSeqTokenizer config test""" 2 | 3 | from ml2.ltl import LTLFormula 4 | from ml2.tokenizers import SeqToSeqTokenizer, Vocabulary 5 | 6 | SEQ_TO_SEQ_TOKENIZER_CONFIG = { 7 | "dtype": "LTLFormula", 8 | "pad": 10, 9 | "eos": True, 10 | "start": True, 11 | "name": "seq-to-seq-tokenizer", 12 | "project": "test", 13 | } 14 | 15 | VOCAB_DICT = { 16 | "
<p>": 0, 17 | "<s>": 1, 18 | "<e>": 2, 19 | "a": 3, 20 | "b": 4, 21 | "c": 5, 22 | "d": 6, 23 | "e": 7, 24 | "U": 8, 25 | "X": 9, 26 | "&": 10, 27 | "!": 11, 28 | } 29 | 30 | 31 | def test_seq_to_seq_tokenizer_config(): 32 | vocabulary = Vocabulary(VOCAB_DICT) 33 | tokenizer = SeqToSeqTokenizer.from_config(SEQ_TO_SEQ_TOKENIZER_CONFIG, vocabulary=vocabulary) 34 | formula = LTLFormula.from_str("a U b & ! c") 35 | encoding = tokenizer.encode(formula) 36 | assert encoding.ids == [1, 3, 8, 4, 10, 11, 5, 2, 0, 0] 37 | config = tokenizer.get_config() 38 | assert config["dtype"] == "LTLFormula" 39 | assert config["pad"] == 10 40 | assert config["eos"] 41 | assert config["start"] 42 | assert config["name"] == "seq-to-seq-tokenizer" 43 | assert config["project"] == "test" 44 | assert config["type"] == "SeqToSeqTokenizer" 45 | -------------------------------------------------------------------------------- /tests/tools/aalta_test.py: -------------------------------------------------------------------------------- 1 | """Aalta test""" 2 | 3 | import pytest 4 | 5 | from ml2.ltl import LTLFormula 6 | from ml2.ltl.ltl_sat import LTLSatStatus 7 | from ml2.tools.aalta import Aalta 8 | 9 | 10 | @pytest.mark.docker 11 | def test_aalta(): 12 | aalta = Aalta() 13 | 14 | sat_formula = LTLFormula.from_str("a U b & G a") 15 | status, trace = aalta.check_sat(sat_formula) 16 | assert status == LTLSatStatus("satisfiable") 17 | 18 | unsat_formula = LTLFormula.from_str("a U b & G ! b") 19 | status, trace = aalta.check_sat(unsat_formula) 20 | assert status == LTLSatStatus("unsatisfiable") 21 | -------------------------------------------------------------------------------- /tests/tools/limboole_test.py: -------------------------------------------------------------------------------- 1 | """Limboole test""" 2 | 3 | import pytest 4 | 5 | from ml2.prop import PropSatStatus 6 | from ml2.prop.prop_formula import PropFormula 7 | from ml2.tools.limboole import Limboole 8 | 9 | 10 | @pytest.mark.docker 11 | def test_limboole(): 12 | limboole = Limboole() 13 | 14 | sat_formula = PropFormula.from_str("! a & (a <-> b) & (c -> b)") 15 | status, assignment = limboole.check_sat(sat_formula) 16 | assert status == PropSatStatus("sat") 17 | 18 | unsat_formula = PropFormula.from_str("a & (a -> b) & (b <-> c) & (! c | ! 
a)") 19 | status, assignment = limboole.check_sat(unsat_formula) 20 | assert status == PropSatStatus("unsat") 21 | -------------------------------------------------------------------------------- /tests/trace/sym_trace_to_seq_tokenizer_config_test.py: -------------------------------------------------------------------------------- 1 | """Symbolic trace to sequence tokenizer config test""" 2 | 3 | from ml2.tokenizers import Vocabulary 4 | from ml2.trace import SymbolicTrace, SymTraceToSeqTokenizer 5 | 6 | SYM_TRACE_TO_SEQ_TOKENIZER_CONFIG = { 7 | "notation": "prefix", 8 | "name": "sym-trace-to-seq-tokenizer", 9 | "project": "test", 10 | } 11 | 12 | 13 | def test_sym_trace_to_seq_tokenizer_config(): 14 | vocabulary = Vocabulary( 15 | token_to_id={}, name="sym-trace-to-seq-tokenizer/vocabulary", project="test" 16 | ) 17 | tokenizer = SymTraceToSeqTokenizer.from_config( 18 | SYM_TRACE_TO_SEQ_TOKENIZER_CONFIG, vocabulary=vocabulary 19 | ) 20 | assert tokenizer.dtype == SymbolicTrace 21 | assert tokenizer.notation == "prefix" 22 | config = tokenizer.get_config() 23 | assert "notation" in config and config["notation"] == "prefix" 24 | -------------------------------------------------------------------------------- /tests/trace/trace_test.py: -------------------------------------------------------------------------------- 1 | """Trace test""" 2 | 3 | from ml2.prop import Assignment 4 | from ml2.trace import Trace 5 | 6 | 7 | def test_trace_str(): 8 | t1 = Trace.from_str("a , ! b ; { b , c ; a }", notation="standard") 9 | t2 = Trace.from_str("a & ! b ; cycle{ b & c ; a }", notation="spot") 10 | t3 = Trace.from_str( 11 | " -> State: 1.1 <-\n a = TRUE\n b = FALSE\n -- Loop starts here\n -> State: 1.2 <-\n b = TRUE\n c=TRUE\n -> State: 1.3 <-\n a = TRUE\n", 12 | notation="nusmv", 13 | ) 14 | t4 = Trace.from_str("{a,(! b),}\n(\n{b,c,}\n{a,}\n)^w\n", notation="aalta") 15 | assert t1 == t2 16 | assert t1 == t3 17 | assert t1 == t4 18 | 19 | 20 | def test_trace_complete(): 21 | t = Trace.from_str("a , ! b ; b ; { ! a ; }", notation="standard") 22 | t.complete_by_predecessor() 23 | assert t == Trace.from_str("a , ! b ; a , b ; { ! a , b ; ! a , b }", notation="standard") 24 | 25 | 26 | def test_trace_empty(): 27 | t1 = Trace.from_str("{}", notation="standard") 28 | t2 = Trace.from_str("cycle{}", notation="spot") 29 | t3 = Trace.from_str("(\n{,}\n)^w\n", notation="aalta") 30 | for t in [t1, t2, t3]: 31 | assert t.prefix == [] 32 | assert t.cycle == [Assignment()] 33 | 34 | 35 | def test_trace_tokens(): 36 | tokens = ["a", ",", "!", "b", ";", "{", "b", ",", "c", ";", "a", "}"] 37 | t = Trace.from_tokens(tokens, notation="standard") 38 | assert t.to_tokens(notation="standard") == tokens 39 | --------------------------------------------------------------------------------
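Note on running this test suite: the markers declared in tests/pytest.ini (docker, gcp, tf) guard tests that need a running Docker daemon, access to GCP storage buckets, or a TensorFlow installation. The sketch below is a hypothetical helper, not part of the repository; it only relies on the standard pytest marker-expression deselection via "-m".

# run_local_tests.py -- hypothetical helper, not part of the repository
import sys

import pytest

if __name__ == "__main__":
    # deselect tests that require Docker, GCP, or TensorFlow (markers from tests/pytest.ini)
    sys.exit(pytest.main(["tests", "-m", "not docker and not gcp and not tf"]))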