From 66a743225facfa9f3d22782ac7e13ad221494ea5 Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Thu, 7 May 2026 13:13:05 +0800 Subject: [PATCH 01/23] =?UTF-8?q?feat:=20=E8=BF=81=E7=A7=BB=20czsc=20?= =?UTF-8?q?=E6=A0=B8=E5=BF=83=E8=87=B3=20Rust=20+=20PyO3=20=E6=B7=B7?= =?UTF-8?q?=E5=90=88=E6=9E=B6=E6=9E=84=EF=BC=8C=E5=B9=B6=E5=AE=8C=E5=96=84?= =?UTF-8?q?=E4=B8=AD=E6=96=87=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将缠论核心算法(CZSC、BarGenerator、信号、交易器等)迁移到 Rust workspace (czsc-core / czsc-signals / czsc-trader / czsc-utils / czsc-ta 等 9 个 crate), 通过 maturin + PyO3 暴露为 czsc._native 扩展模块;Python 端裁剪为薄封装层, 统一以 czsc.* 暴露公共 API。 回测/绩效/mock 数据切换到外部 wbt 包;新增 parity / compat / smoke / integration 测试套件保证 Rust 与原 Python 实现行为一致;为全部 Python 与测试代码补充 详细的中文 docstring 与行内注释。 --- .cargo/config.toml | 2 + .github/workflows/code-quality.yml | 93 +- .github/workflows/python-publish.yml | 460 +- Cargo.lock | 3302 +++++++++++ Cargo.toml | 33 + crates/czsc-core/Cargo.toml | 39 + crates/czsc-core/src/analyze/errors.rs | 13 + crates/czsc-core/src/analyze/mod.rs | 913 +++ crates/czsc-core/src/analyze/utils.rs | 374 ++ crates/czsc-core/src/lib.rs | 11 + crates/czsc-core/src/objects/bar.rs | 690 +++ crates/czsc-core/src/objects/bi.rs | 604 ++ crates/czsc-core/src/objects/direction.rs | 160 + crates/czsc-core/src/objects/errors.rs | 37 + crates/czsc-core/src/objects/event.rs | 605 ++ crates/czsc-core/src/objects/fake_bi.rs | 165 + crates/czsc-core/src/objects/freq.rs | 377 ++ crates/czsc-core/src/objects/fx.rs | 399 ++ crates/czsc-core/src/objects/mark.rs | 130 + crates/czsc-core/src/objects/market.rs | 49 + crates/czsc-core/src/objects/mod.rs | 20 + crates/czsc-core/src/objects/operate.rs | 252 + crates/czsc-core/src/objects/position.rs | 2055 +++++++ crates/czsc-core/src/objects/signal.rs | 567 ++ crates/czsc-core/src/objects/state.rs | 23 + crates/czsc-core/src/objects/zs.rs | 235 + crates/czsc-core/src/python/mod.rs | 115 + crates/czsc-core/src/utils/common.rs | 119 + crates/czsc-core/src/utils/corr.rs | 225 + crates/czsc-core/src/utils/errors.rs | 13 + crates/czsc-core/src/utils/mod.rs | 7 + crates/czsc-core/src/utils/rounded.rs | 36 + crates/czsc-core/tests/test_analyze_error.rs | 29 + crates/czsc-core/tests/test_analyze_utils.rs | 114 + crates/czsc-core/tests/test_bi.rs | 95 + crates/czsc-core/tests/test_czsc_analyzer.rs | 101 + crates/czsc-core/tests/test_event.rs | 78 + crates/czsc-core/tests/test_freq.rs | 59 + crates/czsc-core/tests/test_fx.rs | 73 + crates/czsc-core/tests/test_mark_direction.rs | 44 + crates/czsc-core/tests/test_market.rs | 31 + crates/czsc-core/tests/test_object_error.rs | 35 + crates/czsc-core/tests/test_operate.rs | 48 + crates/czsc-core/tests/test_position.rs | 62 + crates/czsc-core/tests/test_raw_bar.rs | 63 + crates/czsc-core/tests/test_signal.rs | 58 + crates/czsc-core/tests/test_state.rs | 30 + crates/czsc-core/tests/test_utils_common.rs | 26 + crates/czsc-core/tests/test_zs.rs | 115 + crates/czsc-python/Cargo.toml | 37 + crates/czsc-python/build.rs | 14 + crates/czsc-python/src/errors.rs | 44 + crates/czsc-python/src/lib.rs | 65 + crates/czsc-python/src/signals_dispatcher.rs | 177 + crates/czsc-python/src/trader/api.rs | 1439 +++++ crates/czsc-python/src/trader/czsc_signals.rs | 293 + crates/czsc-python/src/trader/czsc_trader.rs | 414 ++ crates/czsc-python/src/trader/generate.rs | 109 + crates/czsc-python/src/trader/mod.rs | 14 + crates/czsc-python/src/trader/research.rs | 632 ++ crates/czsc-python/src/utils/df_convert.rs | 16 + crates/czsc-python/src/utils/mod.rs | 4 + crates/czsc-signal-macros/Cargo.toml | 17 + crates/czsc-signal-macros/src/lib.rs | 470 ++ .../czsc-signal-macros/tests/test_export.rs | 15 + crates/czsc-signals/Cargo.toml | 22 + crates/czsc-signals/src/ang.rs | 765 +++ crates/czsc-signals/src/bar.rs | 3350 +++++++++++ crates/czsc-signals/src/byi.rs | 412 ++ crates/czsc-signals/src/cat.rs | 322 ++ crates/czsc-signals/src/clv.rs | 59 + crates/czsc-signals/src/coo.rs | 345 ++ crates/czsc-signals/src/cvolp.rs | 117 + crates/czsc-signals/src/cxt.rs | 2972 ++++++++++ crates/czsc-signals/src/cxt_trader.rs | 158 + crates/czsc-signals/src/jcc.rs | 1336 +++++ crates/czsc-signals/src/kcatr.rs | 78 + crates/czsc-signals/src/lib.rs | 116 + crates/czsc-signals/src/ntmdk.rs | 54 + crates/czsc-signals/src/obv.rs | 277 + crates/czsc-signals/src/params.rs | 63 + crates/czsc-signals/src/pos.rs | 1202 ++++ crates/czsc-signals/src/pressure.rs | 391 ++ crates/czsc-signals/src/registry.rs | 584 ++ crates/czsc-signals/src/tas.rs | 5118 +++++++++++++++++ crates/czsc-signals/src/types.rs | 122 + crates/czsc-signals/src/utils/cxt.rs | 335 ++ crates/czsc-signals/src/utils/math.rs | 97 + crates/czsc-signals/src/utils/mod.rs | 5 + crates/czsc-signals/src/utils/sig.rs | 844 +++ crates/czsc-signals/src/utils/ta.rs | 1662 ++++++ crates/czsc-signals/src/utils/zdy.rs | 63 + crates/czsc-signals/src/vol.rs | 432 ++ crates/czsc-signals/src/xl.rs | 463 ++ crates/czsc-signals/src/zdy.rs | 948 +++ crates/czsc-signals/src/zdy_trader.rs | 322 ++ crates/czsc-ta/Cargo.toml | 21 + crates/czsc-ta/src/lib.rs | 16 + crates/czsc-ta/src/mixed/chip_dist.rs | 133 + crates/czsc-ta/src/mixed/mod.rs | 3 + crates/czsc-ta/src/pure.rs | 1537 +++++ crates/czsc-ta/src/python.rs | 243 + crates/czsc-ta/tests/test_pure.rs | 108 + crates/czsc-trader/Cargo.toml | 37 + .../czsc-trader/src/engine_v2/catalog/mod.rs | 70 + .../src/engine_v2/compiler/event.rs | 104 + .../czsc-trader/src/engine_v2/compiler/mod.rs | 130 + .../src/engine_v2/compiler/optimize.rs | 35 + .../src/engine_v2/compiler/position.rs | 29 + .../src/engine_v2/compiler/signal.rs | 54 + crates/czsc-trader/src/engine_v2/mod.rs | 7 + .../src/engine_v2/runtime/executor.rs | 446 ++ .../czsc-trader/src/engine_v2/runtime/mod.rs | 3 + crates/czsc-trader/src/engine_v2/scheduler.rs | 40 + crates/czsc-trader/src/lib.rs | 13 + crates/czsc-trader/src/optimize.rs | 384 ++ .../czsc-trader/src/signals/czsc_signals.rs | 393 ++ crates/czsc-trader/src/signals/mod.rs | 5 + crates/czsc-trader/src/signals/sig_parse.rs | 459 ++ crates/czsc-trader/src/trader.rs | 207 + crates/czsc-utils/Cargo.toml | 34 + crates/czsc-utils/data/minutes_split.feather | Bin 0 -> 175674 bytes crates/czsc-utils/src/bar_generator.rs | 999 ++++ crates/czsc-utils/src/errors.rs | 41 + crates/czsc-utils/src/freq_data.rs | 503 ++ crates/czsc-utils/src/lib.rs | 18 + crates/czsc-utils/src/python/mod.rs | 51 + crates/czsc-utils/src/trading_time.rs | 48 + crates/czsc-utils/tests/test_bar_generator.rs | 77 + crates/czsc-utils/tests/test_errors.rs | 41 + crates/czsc-utils/tests/test_freq_data.rs | 67 + crates/czsc-utils/tests/test_trading_time.rs | 54 + crates/error-macros/Cargo.toml | 25 + crates/error-macros/src/err.rs | 35 + crates/error-macros/src/lib.rs | 16 + crates/error-macros/tests/test_derive.rs | 25 + crates/error-support/Cargo.toml | 14 + crates/error-support/src/lib.rs | 38 + crates/error-support/tests/test_chain.rs | 26 + czsc/__init__.py | 302 +- czsc/_compat.py | 418 ++ czsc/_format_standard_kline.py | 119 + czsc/_utils/__init__.py | 0 czsc/_utils/_df_convert.py | 77 + czsc/connectors/cooperation.py | 1026 ++++ czsc/connectors/jq_connector.py | 716 +++ czsc/core.py | 86 - czsc/core.pyi | 41 - czsc/eda.py | 4 +- czsc/envs.py | 128 +- czsc/mock.py | 580 +- czsc/models.py | 119 + czsc/research.py | 321 ++ czsc/sensors/__init__.py | 15 + czsc/signals/__init__.py | 33 + czsc/signals/_helpers.py | 184 + czsc/signals/bar.py | 122 + czsc/signals/cvolp.py | 118 + czsc/signals/cxt.py | 118 + czsc/signals/obv.py | 118 + czsc/signals/pressure.py | 120 + czsc/signals/tas.py | 120 + czsc/signals/vol.py | 119 + czsc/strategies.py | 331 ++ czsc/svc/backtest.py | 217 +- czsc/svc/base.py | 73 +- czsc/svc/factor.py | 204 +- czsc/svc/price_analysis.py | 99 +- czsc/svc/returns.py | 154 +- czsc/svc/statistics.py | 162 +- czsc/svc/strategy.py | 278 +- czsc/traders/__init__.py | 46 +- czsc/traders/__init__.pyi | 27 +- czsc/traders/_rs_signals.py | 136 - czsc/traders/base.py | 669 +-- czsc/traders/base.pyi | 73 - czsc/traders/cwc.py | 875 --- czsc/traders/cwc.pyi | 59 - czsc/traders/optimize.py | 500 ++ czsc/traders/sig_parse.py | 232 +- czsc/utils/__init__.py | 219 +- czsc/utils/analysis/__init__.py | 52 +- czsc/utils/analysis/events.py | 61 - czsc/utils/analysis/events.pyi | 3 - czsc/utils/backtest_report.pyi | 36 - czsc/utils/bi_info.py | 62 - czsc/utils/bi_info.pyi | 9 - czsc/utils/echarts_plot.py | 867 --- czsc/utils/echarts_plot.pyi | 31 - czsc/utils/feature_utils.py | 112 - czsc/utils/features.py | 108 - czsc/utils/features.pyi | 6 - czsc/utils/holds_concepts_effect.py | 75 - czsc/utils/html_report_builder.pyi | 28 - czsc/utils/mark_czsc_status.py | 476 -- czsc/utils/mark_czsc_status.pyi | 3 - czsc/utils/pdf_report_builder.pyi | 56 - czsc/utils/plotting/kline.py | 358 +- czsc/utils/sig.py | 178 +- czsc/utils/ta.py | 881 +-- czsc/utils/ta.pyi | 94 - czsc/utils/word_writer.pyi | 16 - docs/MIGRATION_NOTES.md | 758 +++ .../plans/2026-05-03-rust-czsc-migration.md | 942 +++ .../2026-05-03-rust-czsc-migration-design.md | 425 +- examples/develop/czsc_benchmark.py | 53 +- examples/develop/test_trading_view_kline.py | 55 +- pyproject.toml | 50 +- rust-toolchain.toml | 2 + scripts/cargo_test_all.sh | 30 + test/compat/__init__.py | 13 + test/compat/snapshots/api_v1.json | 80 + test/compat/test_public_api.py | 159 + test/integration/__init__.py | 8 + test/integration/test_weight_backtest.py | 100 + test/parity/__init__.py | 15 + test/parity/_signal_defaults.py | 105 + test/parity/bench_optimize.py | 199 + test/parity/compare_optimize_full.py | 451 ++ test/parity/conftest.py | 107 + test/parity/test_all_signals.py | 239 + test/parity/test_czsc_core.py | 135 + test/parity/test_examples.py | 509 ++ test/parity/test_optimize.py | 217 + test/parity/test_performance.py | 261 + test/parity/test_run_research.py | 181 + test/parity/test_signals_registry.py | 93 + test/smoke/__init__.py | 11 + test/smoke/test_install.py | 144 + test/test_analyze.py | 90 - test/test_analyze_boundary.py | 129 - test/test_api_surface.py | 69 - test/test_eda.py | 406 -- test/test_envs.py | 187 +- test/test_mark_czsc_status.py | 34 - test/test_mock_quality.py | 202 - test/test_plotly_plot.py | 42 +- test/test_rs.py | 45 - test/test_rs_analyze.py | 73 - test/test_sig.py | 288 - test/test_utils.py | 84 +- test/test_utils_features.py | 144 - test/test_utils_refactored.py | 124 +- test/test_utils_ta.py | 503 -- test/unit/__init__.py | 12 + test/unit/snapshots/core_parity_seed42.json | 277 + test/unit/test_core_parity.py | 141 + test/unit/test_pickle.py | 137 + test/unit/test_signals_parity.py | 128 + test/unit/test_ta_parity.py | 166 + test/unit/test_trading_time.py | 106 + tests/test_workspace_layout.sh | 41 + uv.lock | 66 +- 253 files changed, 59589 insertions(+), 8684 deletions(-) create mode 100644 .cargo/config.toml create mode 100644 Cargo.lock create mode 100644 Cargo.toml create mode 100644 crates/czsc-core/Cargo.toml create mode 100644 crates/czsc-core/src/analyze/errors.rs create mode 100644 crates/czsc-core/src/analyze/mod.rs create mode 100644 crates/czsc-core/src/analyze/utils.rs create mode 100644 crates/czsc-core/src/lib.rs create mode 100644 crates/czsc-core/src/objects/bar.rs create mode 100644 crates/czsc-core/src/objects/bi.rs create mode 100644 crates/czsc-core/src/objects/direction.rs create mode 100644 crates/czsc-core/src/objects/errors.rs create mode 100644 crates/czsc-core/src/objects/event.rs create mode 100644 crates/czsc-core/src/objects/fake_bi.rs create mode 100644 crates/czsc-core/src/objects/freq.rs create mode 100644 crates/czsc-core/src/objects/fx.rs create mode 100644 crates/czsc-core/src/objects/mark.rs create mode 100644 crates/czsc-core/src/objects/market.rs create mode 100644 crates/czsc-core/src/objects/mod.rs create mode 100644 crates/czsc-core/src/objects/operate.rs create mode 100644 crates/czsc-core/src/objects/position.rs create mode 100644 crates/czsc-core/src/objects/signal.rs create mode 100644 crates/czsc-core/src/objects/state.rs create mode 100644 crates/czsc-core/src/objects/zs.rs create mode 100644 crates/czsc-core/src/python/mod.rs create mode 100644 crates/czsc-core/src/utils/common.rs create mode 100644 crates/czsc-core/src/utils/corr.rs create mode 100644 crates/czsc-core/src/utils/errors.rs create mode 100644 crates/czsc-core/src/utils/mod.rs create mode 100644 crates/czsc-core/src/utils/rounded.rs create mode 100644 crates/czsc-core/tests/test_analyze_error.rs create mode 100644 crates/czsc-core/tests/test_analyze_utils.rs create mode 100644 crates/czsc-core/tests/test_bi.rs create mode 100644 crates/czsc-core/tests/test_czsc_analyzer.rs create mode 100644 crates/czsc-core/tests/test_event.rs create mode 100644 crates/czsc-core/tests/test_freq.rs create mode 100644 crates/czsc-core/tests/test_fx.rs create mode 100644 crates/czsc-core/tests/test_mark_direction.rs create mode 100644 crates/czsc-core/tests/test_market.rs create mode 100644 crates/czsc-core/tests/test_object_error.rs create mode 100644 crates/czsc-core/tests/test_operate.rs create mode 100644 crates/czsc-core/tests/test_position.rs create mode 100644 crates/czsc-core/tests/test_raw_bar.rs create mode 100644 crates/czsc-core/tests/test_signal.rs create mode 100644 crates/czsc-core/tests/test_state.rs create mode 100644 crates/czsc-core/tests/test_utils_common.rs create mode 100644 crates/czsc-core/tests/test_zs.rs create mode 100644 crates/czsc-python/Cargo.toml create mode 100644 crates/czsc-python/build.rs create mode 100644 crates/czsc-python/src/errors.rs create mode 100644 crates/czsc-python/src/lib.rs create mode 100644 crates/czsc-python/src/signals_dispatcher.rs create mode 100644 crates/czsc-python/src/trader/api.rs create mode 100644 crates/czsc-python/src/trader/czsc_signals.rs create mode 100644 crates/czsc-python/src/trader/czsc_trader.rs create mode 100644 crates/czsc-python/src/trader/generate.rs create mode 100644 crates/czsc-python/src/trader/mod.rs create mode 100644 crates/czsc-python/src/trader/research.rs create mode 100644 crates/czsc-python/src/utils/df_convert.rs create mode 100644 crates/czsc-python/src/utils/mod.rs create mode 100644 crates/czsc-signal-macros/Cargo.toml create mode 100644 crates/czsc-signal-macros/src/lib.rs create mode 100644 crates/czsc-signal-macros/tests/test_export.rs create mode 100644 crates/czsc-signals/Cargo.toml create mode 100644 crates/czsc-signals/src/ang.rs create mode 100644 crates/czsc-signals/src/bar.rs create mode 100644 crates/czsc-signals/src/byi.rs create mode 100644 crates/czsc-signals/src/cat.rs create mode 100644 crates/czsc-signals/src/clv.rs create mode 100644 crates/czsc-signals/src/coo.rs create mode 100644 crates/czsc-signals/src/cvolp.rs create mode 100644 crates/czsc-signals/src/cxt.rs create mode 100644 crates/czsc-signals/src/cxt_trader.rs create mode 100644 crates/czsc-signals/src/jcc.rs create mode 100644 crates/czsc-signals/src/kcatr.rs create mode 100644 crates/czsc-signals/src/lib.rs create mode 100644 crates/czsc-signals/src/ntmdk.rs create mode 100644 crates/czsc-signals/src/obv.rs create mode 100644 crates/czsc-signals/src/params.rs create mode 100644 crates/czsc-signals/src/pos.rs create mode 100644 crates/czsc-signals/src/pressure.rs create mode 100644 crates/czsc-signals/src/registry.rs create mode 100644 crates/czsc-signals/src/tas.rs create mode 100644 crates/czsc-signals/src/types.rs create mode 100644 crates/czsc-signals/src/utils/cxt.rs create mode 100644 crates/czsc-signals/src/utils/math.rs create mode 100644 crates/czsc-signals/src/utils/mod.rs create mode 100644 crates/czsc-signals/src/utils/sig.rs create mode 100644 crates/czsc-signals/src/utils/ta.rs create mode 100644 crates/czsc-signals/src/utils/zdy.rs create mode 100644 crates/czsc-signals/src/vol.rs create mode 100644 crates/czsc-signals/src/xl.rs create mode 100644 crates/czsc-signals/src/zdy.rs create mode 100644 crates/czsc-signals/src/zdy_trader.rs create mode 100644 crates/czsc-ta/Cargo.toml create mode 100644 crates/czsc-ta/src/lib.rs create mode 100644 crates/czsc-ta/src/mixed/chip_dist.rs create mode 100644 crates/czsc-ta/src/mixed/mod.rs create mode 100644 crates/czsc-ta/src/pure.rs create mode 100644 crates/czsc-ta/src/python.rs create mode 100644 crates/czsc-ta/tests/test_pure.rs create mode 100644 crates/czsc-trader/Cargo.toml create mode 100644 crates/czsc-trader/src/engine_v2/catalog/mod.rs create mode 100644 crates/czsc-trader/src/engine_v2/compiler/event.rs create mode 100644 crates/czsc-trader/src/engine_v2/compiler/mod.rs create mode 100644 crates/czsc-trader/src/engine_v2/compiler/optimize.rs create mode 100644 crates/czsc-trader/src/engine_v2/compiler/position.rs create mode 100644 crates/czsc-trader/src/engine_v2/compiler/signal.rs create mode 100644 crates/czsc-trader/src/engine_v2/mod.rs create mode 100644 crates/czsc-trader/src/engine_v2/runtime/executor.rs create mode 100644 crates/czsc-trader/src/engine_v2/runtime/mod.rs create mode 100644 crates/czsc-trader/src/engine_v2/scheduler.rs create mode 100644 crates/czsc-trader/src/lib.rs create mode 100644 crates/czsc-trader/src/optimize.rs create mode 100644 crates/czsc-trader/src/signals/czsc_signals.rs create mode 100644 crates/czsc-trader/src/signals/mod.rs create mode 100644 crates/czsc-trader/src/signals/sig_parse.rs create mode 100644 crates/czsc-trader/src/trader.rs create mode 100644 crates/czsc-utils/Cargo.toml create mode 100644 crates/czsc-utils/data/minutes_split.feather create mode 100644 crates/czsc-utils/src/bar_generator.rs create mode 100644 crates/czsc-utils/src/errors.rs create mode 100644 crates/czsc-utils/src/freq_data.rs create mode 100644 crates/czsc-utils/src/lib.rs create mode 100644 crates/czsc-utils/src/python/mod.rs create mode 100644 crates/czsc-utils/src/trading_time.rs create mode 100644 crates/czsc-utils/tests/test_bar_generator.rs create mode 100644 crates/czsc-utils/tests/test_errors.rs create mode 100644 crates/czsc-utils/tests/test_freq_data.rs create mode 100644 crates/czsc-utils/tests/test_trading_time.rs create mode 100644 crates/error-macros/Cargo.toml create mode 100644 crates/error-macros/src/err.rs create mode 100644 crates/error-macros/src/lib.rs create mode 100644 crates/error-macros/tests/test_derive.rs create mode 100644 crates/error-support/Cargo.toml create mode 100644 crates/error-support/src/lib.rs create mode 100644 crates/error-support/tests/test_chain.rs create mode 100644 czsc/_compat.py create mode 100644 czsc/_format_standard_kline.py create mode 100644 czsc/_utils/__init__.py create mode 100644 czsc/_utils/_df_convert.py create mode 100644 czsc/connectors/cooperation.py create mode 100644 czsc/connectors/jq_connector.py delete mode 100644 czsc/core.py delete mode 100644 czsc/core.pyi create mode 100644 czsc/models.py create mode 100644 czsc/research.py create mode 100644 czsc/sensors/__init__.py create mode 100644 czsc/signals/__init__.py create mode 100644 czsc/signals/_helpers.py create mode 100644 czsc/signals/bar.py create mode 100644 czsc/signals/cvolp.py create mode 100644 czsc/signals/cxt.py create mode 100644 czsc/signals/obv.py create mode 100644 czsc/signals/pressure.py create mode 100644 czsc/signals/tas.py create mode 100644 czsc/signals/vol.py create mode 100644 czsc/strategies.py delete mode 100644 czsc/traders/_rs_signals.py delete mode 100644 czsc/traders/base.pyi delete mode 100644 czsc/traders/cwc.py delete mode 100644 czsc/traders/cwc.pyi create mode 100644 czsc/traders/optimize.py delete mode 100644 czsc/utils/analysis/events.py delete mode 100644 czsc/utils/analysis/events.pyi delete mode 100644 czsc/utils/backtest_report.pyi delete mode 100644 czsc/utils/bi_info.py delete mode 100644 czsc/utils/bi_info.pyi delete mode 100644 czsc/utils/echarts_plot.py delete mode 100644 czsc/utils/echarts_plot.pyi delete mode 100644 czsc/utils/feature_utils.py delete mode 100644 czsc/utils/features.py delete mode 100644 czsc/utils/features.pyi delete mode 100644 czsc/utils/holds_concepts_effect.py delete mode 100644 czsc/utils/html_report_builder.pyi delete mode 100644 czsc/utils/mark_czsc_status.py delete mode 100644 czsc/utils/mark_czsc_status.pyi delete mode 100644 czsc/utils/pdf_report_builder.pyi delete mode 100644 czsc/utils/ta.pyi delete mode 100644 czsc/utils/word_writer.pyi create mode 100644 docs/MIGRATION_NOTES.md create mode 100644 docs/superpowers/plans/2026-05-03-rust-czsc-migration.md create mode 100644 rust-toolchain.toml create mode 100755 scripts/cargo_test_all.sh create mode 100644 test/compat/__init__.py create mode 100644 test/compat/snapshots/api_v1.json create mode 100644 test/compat/test_public_api.py create mode 100644 test/integration/__init__.py create mode 100644 test/integration/test_weight_backtest.py create mode 100644 test/parity/__init__.py create mode 100644 test/parity/_signal_defaults.py create mode 100644 test/parity/bench_optimize.py create mode 100644 test/parity/compare_optimize_full.py create mode 100644 test/parity/conftest.py create mode 100644 test/parity/test_all_signals.py create mode 100644 test/parity/test_czsc_core.py create mode 100644 test/parity/test_examples.py create mode 100644 test/parity/test_optimize.py create mode 100644 test/parity/test_performance.py create mode 100644 test/parity/test_run_research.py create mode 100644 test/parity/test_signals_registry.py create mode 100644 test/smoke/__init__.py create mode 100644 test/smoke/test_install.py delete mode 100644 test/test_analyze.py delete mode 100644 test/test_analyze_boundary.py delete mode 100644 test/test_api_surface.py delete mode 100644 test/test_eda.py delete mode 100644 test/test_mark_czsc_status.py delete mode 100644 test/test_mock_quality.py delete mode 100644 test/test_rs.py delete mode 100644 test/test_rs_analyze.py delete mode 100644 test/test_sig.py delete mode 100644 test/test_utils_features.py delete mode 100644 test/test_utils_ta.py create mode 100644 test/unit/__init__.py create mode 100644 test/unit/snapshots/core_parity_seed42.json create mode 100644 test/unit/test_core_parity.py create mode 100644 test/unit/test_pickle.py create mode 100644 test/unit/test_signals_parity.py create mode 100644 test/unit/test_ta_parity.py create mode 100644 test/unit/test_trading_time.py create mode 100755 tests/test_workspace_layout.sh diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 000000000..44fe421cd --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[build] +incremental = true diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index cb76e06df..0478f8d41 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -7,10 +7,52 @@ on: branches: [ master ] jobs: + # ------------------------------------------------------------------ + # 1) Rust per-crate tests — must pass before any Python work runs. + # + # `cargo test --workspace` cannot link against libpython when the + # pyo3 `extension-module` feature is enabled (see MIGRATION_NOTES + # §3.1). We therefore run cargo test per crate, skipping crates + # that pull in pyo3/extension-module (czsc-python, czsc-signals, + # czsc-trader) — they're exercised end-to-end by the Python test + # matrix below via maturin develop + pytest. + # ------------------------------------------------------------------ + rust-tests: + name: Rust per-crate tests + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo + uses: Swatinem/rust-cache@v2 + + - name: cargo build --workspace + run: cargo build --workspace --release + + - name: cargo test (per-crate) + run: | + set -e + for crate in error-macros error-support czsc-core czsc-utils czsc-ta czsc-signal-macros; do + echo "::group::cargo test -p $crate" + cargo test -p "$crate" --no-fail-fast + echo "::endgroup::" + done + + # ------------------------------------------------------------------ + # 2) Python test matrix — single abi3 wheel covers 3.10/3.11/3.12/3.13. + # Each Python version installs the project (which triggers maturin + # to build czsc._native against that interpreter), then runs pytest. + # ------------------------------------------------------------------ test: - name: Test Suite + name: Python tests (py${{ matrix.python-version }}) + needs: rust-tests runs-on: ubuntu-latest strategy: + fail-fast: false matrix: python-version: ['3.10', '3.11', '3.12', '3.13'] @@ -18,6 +60,12 @@ jobs: - name: Checkout code uses: actions/checkout@v4 + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo + uses: Swatinem/rust-cache@v2 + - name: Install uv uses: astral-sh/setup-uv@v2 with: @@ -31,15 +79,22 @@ jobs: sudo apt-get update sudo apt-get install -y libxml2-dev libxslt-dev - - name: Install project dependencies + - name: Install project dependencies (Python only) run: | + # uv sync 装 Python 依赖到 .venv;czsc._native 还需要 maturin develop 单独构建。 uv sync --extra all - - name: Check project can be imported + - name: Build czsc._native (maturin develop) + run: | + uv pip install maturin + uv run maturin develop --release + + - name: Smoke-import czsc run: | uv run python -c "import czsc; print(f'CZSC version: {czsc.__version__}')" + uv run python -c "from czsc import CZSC, RawBar, Freq; print('core imports OK')" - - name: Run tests with pytest + - name: Run pytest run: | uv run pytest test/ -v --cov=czsc --cov-report=xml --cov-report=term @@ -68,13 +123,19 @@ jobs: - name: Set up Python run: uv python install 3.11 - - name: Install dependencies - run: uv sync --extra all + - name: Install dev tools (no native build needed) + run: | + uv sync --extra all - - name: Check code formatting with ruff + - name: Check Python formatting with ruff run: | uv run ruff format --check --diff czsc/ test/ + - name: Check Rust formatting with cargo fmt + run: | + rustup component add rustfmt + cargo fmt --all -- --check + linting: name: Code Linting runs-on: ubuntu-latest @@ -83,6 +144,14 @@ jobs: - name: Checkout code uses: actions/checkout@v4 + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + components: clippy + + - name: Cache cargo + uses: Swatinem/rust-cache@v2 + - name: Install uv uses: astral-sh/setup-uv@v2 with: @@ -94,14 +163,18 @@ jobs: - name: Install dependencies run: uv sync --extra all - - name: Lint with ruff + - name: Lint Python with ruff run: | uv run ruff check czsc/ test/ - - name: Type checking with basedpyright + - name: Type-check Python with basedpyright run: | uv run basedpyright czsc/ || true + - name: Lint Rust with cargo clippy + run: | + cargo clippy --workspace --all-targets -- -D warnings || true + security: name: Security Audit runs-on: ubuntu-latest @@ -175,4 +248,4 @@ jobs: name: dependency-reports path: | licenses.json - retention-days: 7 \ No newline at end of file + retention-days: 7 diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 0d2945ce9..2ab413726 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -1,15 +1,23 @@ -# This workflow will build and upload a Python Package using UV when a release is created -# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python +# Build & Publish Python Package +# +# czsc 1.0.0+ is a mixed Rust/Python package: the czsc._native extension +# is produced by the czsc-python crate (PyO3) and the pure-Python tree +# under czsc/ is bundled alongside via maturin's `python-source = "."`. +# We build per-platform abi3 wheels (Python 3.10+ ABI) so a single wheel +# per (OS, arch) covers Python 3.10/3.11/3.12/3.13. +# +# Trusted Publishing (OIDC) is configured at https://pypi.org/p/czsc and +# https://test.pypi.org/p/czsc — no API token secrets are kept in the repo. name: Build & Publish Python Package on: push: tags: - - 'v*' # 监听所有 v* 格式的标签,如 v1.0.0, v0.9.70 + - 'v*' # 监听 v* 格式的 tag,例如 v1.0.0 release: types: [published] - workflow_dispatch: # Allow manual triggering + workflow_dispatch: inputs: publish_to_testpypi: description: 'Publish to TestPyPI instead of PyPI' @@ -18,233 +26,275 @@ on: type: boolean jobs: - pre-publish-checks: + # ------------------------------------------------------------------ + # 1) Build per-platform abi3 wheels (Linux / macOS / Windows × x86_64+arm64) + # ------------------------------------------------------------------ + build-wheels: + name: Wheel ${{ matrix.platform.os }} ${{ matrix.platform.target }} + runs-on: ${{ matrix.platform.os }} + strategy: + fail-fast: false + matrix: + platform: + - { os: ubuntu-latest, target: x86_64, manylinux: '2014' } + - { os: macos-13, target: x86_64, manylinux: '' } + - { os: macos-14, target: aarch64, manylinux: '' } + - { os: windows-latest, target: x64, manylinux: '' } + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Build wheel (maturin) + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.platform.target }} + manylinux: ${{ matrix.platform.manylinux }} + args: --release --strip --out dist + sccache: 'true' + + - name: List built wheels + shell: bash + run: ls -la dist/ + + - name: Upload wheel artifact + uses: actions/upload-artifact@v4 + with: + name: wheels-${{ matrix.platform.os }}-${{ matrix.platform.target }} + path: dist/*.whl + retention-days: 30 + if-no-files-found: error + + # ------------------------------------------------------------------ + # 2) Build source distribution (sdist) + # ------------------------------------------------------------------ + build-sdist: + name: Build sdist runs-on: ubuntu-latest steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Install uv - uses: astral-sh/setup-uv@v2 - with: - version: "latest" - - - name: Set up Python - run: uv python install 3.11 - - - name: Install dependencies - run: uv sync --extra all - - # - name: Run tests - # run: | - # uv run pytest test/ -v - - # - name: Lint check - # run: | - # uv add --dev flake8 black isort - # uv run flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # uv run black --check . - # uv run isort --check-only . - - # - name: Security audit - # run: | - # uv add --dev safety - # uv run safety check - - build: - needs: pre-publish-checks + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Build sdist (maturin) + uses: PyO3/maturin-action@v1 + with: + command: sdist + args: --out dist + + - name: List sdist + run: ls -la dist/ + + - name: Upload sdist artifact + uses: actions/upload-artifact@v4 + with: + name: sdist + path: dist/*.tar.gz + retention-days: 30 + if-no-files-found: error + + # ------------------------------------------------------------------ + # 3) Verify wheels can be installed and `import czsc` works + # ------------------------------------------------------------------ + smoke-test: + name: Smoke test (Linux x86_64) + needs: build-wheels runs-on: ubuntu-latest steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Install uv - uses: astral-sh/setup-uv@v2 - with: - version: "latest" - - - name: Set up Python - run: uv python install 3.11 - - - name: Install dependencies - run: uv sync - - - name: Build package - run: uv build - - - name: Check package metadata - run: | - uv add --dev twine - uv run twine check dist/* - - - name: List build artifacts - run: ls -la dist/ - - - name: Upload build artifacts - uses: actions/upload-artifact@v4 - with: - name: python-package-distributions - path: dist/ - retention-days: 30 - if-no-files-found: error - + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Download Linux wheel + uses: actions/download-artifact@v4 + with: + name: wheels-ubuntu-latest-x86_64 + path: dist/ + + - name: Install + import czsc + run: | + python -m pip install --upgrade pip + # 安装 czsc wheel + PyPI 上的运行时依赖 + python -m pip install dist/*.whl + python -c "import czsc; print('czsc version:', getattr(czsc, '__version__', 'unknown'))" + python -c "from czsc import CZSC, RawBar, Freq; print('core imports OK')" + + # ------------------------------------------------------------------ + # 4) (optional) Publish to TestPyPI via workflow_dispatch + # ------------------------------------------------------------------ publish-to-testpypi: name: Publish to TestPyPI if: github.event.inputs.publish_to_testpypi == 'true' - needs: build + needs: [build-wheels, build-sdist, smoke-test] runs-on: ubuntu-latest environment: name: testpypi url: https://test.pypi.org/p/czsc permissions: - id-token: write # IMPORTANT: mandatory for trusted publishing - + id-token: write # Trusted Publishing steps: - - name: Download build artifacts - uses: actions/download-artifact@v4 - with: - name: python-package-distributions - path: dist/ - - - name: Publish to TestPyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - repository-url: https://test.pypi.org/legacy/ - verbose: true - + - name: Download all artifacts + uses: actions/download-artifact@v4 + with: + path: dist-staging/ + + - name: Flatten dist + run: | + mkdir -p dist + find dist-staging -type f \( -name "*.whl" -o -name "*.tar.gz" \) -exec mv {} dist/ \; + ls -la dist/ + + - name: Publish to TestPyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + repository-url: https://test.pypi.org/legacy/ + verbose: true + + # ------------------------------------------------------------------ + # 5) Publish to PyPI on tag push / GitHub release + # ------------------------------------------------------------------ publish-to-pypi: name: Publish to PyPI if: (github.event_name == 'release' && github.event.action == 'published') || (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')) - needs: build + needs: [build-wheels, build-sdist, smoke-test] runs-on: ubuntu-latest environment: name: pypi url: https://pypi.org/p/czsc permissions: - id-token: write # IMPORTANT: mandatory for trusted publishing - + id-token: write # Trusted Publishing steps: - - name: Extract version from tag - id: extract_version - run: | - if [[ $GITHUB_REF == refs/tags/v* ]]; then - VERSION=${GITHUB_REF#refs/tags/v} - echo "version=$VERSION" >> $GITHUB_OUTPUT - echo "tag_name=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT - elif [[ $GITHUB_EVENT_NAME == "release" ]]; then - echo "version=${{ github.event.release.tag_name }}" >> $GITHUB_OUTPUT - echo "tag_name=${{ github.event.release.tag_name }}" >> $GITHUB_OUTPUT - fi - - - name: Download build artifacts - uses: actions/download-artifact@v4 - with: - name: python-package-distributions - path: dist/ - - - name: Verify version consistency - run: | - # 验证构建的包版本与 tag 版本一致 - BUILT_VERSION=$(ls dist/*.whl | grep -oP 'czsc-\K[0-9]+\.[0-9]+\.[0-9]+' | head -1) - TAG_VERSION=${{ steps.extract_version.outputs.version }} - - echo "Built version: $BUILT_VERSION" - echo "Tag version: $TAG_VERSION" - - if [ "$BUILT_VERSION" != "$TAG_VERSION" ]; then - echo "❌ Version mismatch: built $BUILT_VERSION != tag $TAG_VERSION" - exit 1 - else - echo "✅ Version consistency verified" - fi - - - name: Publish to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - verbose: true - + - name: Extract version from tag + id: extract_version + run: | + if [[ $GITHUB_REF == refs/tags/v* ]]; then + VERSION=${GITHUB_REF#refs/tags/v} + echo "version=$VERSION" >> "$GITHUB_OUTPUT" + echo "tag_name=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT" + elif [[ $GITHUB_EVENT_NAME == "release" ]]; then + echo "version=${{ github.event.release.tag_name }}" | sed 's/v//' >> "$GITHUB_OUTPUT" + echo "tag_name=${{ github.event.release.tag_name }}" >> "$GITHUB_OUTPUT" + fi + + - name: Download all artifacts + uses: actions/download-artifact@v4 + with: + path: dist-staging/ + + - name: Flatten dist + run: | + mkdir -p dist + find dist-staging -type f \( -name "*.whl" -o -name "*.tar.gz" \) -exec mv {} dist/ \; + ls -la dist/ + + - name: Verify version consistency + run: | + # 至少有一个 wheel 文件,把版本和 tag 比对 + BUILT_VERSION=$(ls dist/*.whl | head -1 | grep -oP 'czsc-\K[0-9]+\.[0-9]+\.[0-9]+(?:[a-z0-9.+-]*)?') + TAG_VERSION=${{ steps.extract_version.outputs.version }} + echo "Built version: $BUILT_VERSION" + echo "Tag version: $TAG_VERSION" + if [ "$BUILT_VERSION" != "$TAG_VERSION" ]; then + echo "::error::Version mismatch: built $BUILT_VERSION != tag $TAG_VERSION" + exit 1 + fi + echo "Version consistency verified" + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + verbose: true + + # ------------------------------------------------------------------ + # 6) Sign + create GitHub Release on tag push + # ------------------------------------------------------------------ create-github-release: name: Sign and upload to GitHub Release if: (github.event_name == 'release' && github.event.action == 'published') || (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')) needs: publish-to-pypi runs-on: ubuntu-latest permissions: - contents: write # IMPORTANT: mandatory for making GitHub Releases - id-token: write # IMPORTANT: mandatory for sigstore - + contents: write # 创建 GitHub Release + id-token: write # sigstore steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Download build artifacts - uses: actions/download-artifact@v4 - with: - name: python-package-distributions - path: dist/ - - - name: Extract version info - id: version_info - run: | - if [[ $GITHUB_REF == refs/tags/v* ]]; then - TAG_NAME=${GITHUB_REF#refs/tags/} - VERSION=${GITHUB_REF#refs/tags/v} - elif [[ $GITHUB_EVENT_NAME == "release" ]]; then - TAG_NAME=${{ github.event.release.tag_name }} - VERSION=${TAG_NAME#v} - fi - - echo "tag_name=$TAG_NAME" >> $GITHUB_OUTPUT - echo "version=$VERSION" >> $GITHUB_OUTPUT - - - name: Create Release (if tag push) - if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') - env: - GITHUB_TOKEN: ${{ github.token }} - run: | - TAG_NAME=${{ steps.version_info.outputs.tag_name }} - VERSION=${{ steps.version_info.outputs.version }} - - # 检查 Release 是否已存在 - if ! gh release view "$TAG_NAME" > /dev/null 2>&1; then - echo "Creating release for $TAG_NAME..." - - # 创建发布说明 - cat > release_notes.md << EOF - 🚀 czsc $VERSION - - ### 更新内容 - - 更新到版本 $VERSION - - 详细变更请查看提交历史 - - ### 安装方式 - \`\`\`bash - pip install czsc==$VERSION - \`\`\` - - ### 文档 - - 项目文档: README.md - - API 参考文档请查看源码 - EOF - - gh release create "$TAG_NAME" \ - --title "Release $VERSION" \ - --notes-file release_notes.md \ - --draft=false \ - --prerelease=false - else - echo "Release $TAG_NAME already exists" - fi - - - name: Sign the dists with Sigstore - uses: sigstore/gh-action-sigstore-python@v3.2.0 - with: - inputs: >- - ./dist/*.tar.gz - ./dist/*.whl - - - name: Upload to GitHub Release - env: - GITHUB_TOKEN: ${{ github.token }} - run: | - TAG_NAME=${{ steps.version_info.outputs.tag_name }} - gh release upload "$TAG_NAME" dist/** --clobber + - name: Checkout code + uses: actions/checkout@v4 + + - name: Download all artifacts + uses: actions/download-artifact@v4 + with: + path: dist-staging/ + + - name: Flatten dist + run: | + mkdir -p dist + find dist-staging -type f \( -name "*.whl" -o -name "*.tar.gz" \) -exec mv {} dist/ \; + ls -la dist/ + + - name: Extract version info + id: version_info + run: | + if [[ $GITHUB_REF == refs/tags/v* ]]; then + TAG_NAME=${GITHUB_REF#refs/tags/} + VERSION=${GITHUB_REF#refs/tags/v} + elif [[ $GITHUB_EVENT_NAME == "release" ]]; then + TAG_NAME=${{ github.event.release.tag_name }} + VERSION=${TAG_NAME#v} + fi + echo "tag_name=$TAG_NAME" >> "$GITHUB_OUTPUT" + echo "version=$VERSION" >> "$GITHUB_OUTPUT" + + - name: Create Release (if tag push) + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + TAG_NAME=${{ steps.version_info.outputs.tag_name }} + VERSION=${{ steps.version_info.outputs.version }} + if ! gh release view "$TAG_NAME" > /dev/null 2>&1; then + cat > release_notes.md << EOF + 🚀 czsc $VERSION + + ### 更新内容 + 详细变更请查看提交历史以及 \`docs/MIGRATION_NOTES.md\`。 + + ### 安装方式 + \`\`\`bash + pip install czsc==$VERSION + \`\`\` + + ### 文档 + - 项目文档: README.md + - API 参考: https://czsc.readthedocs.io/ + EOF + gh release create "$TAG_NAME" \ + --title "Release $VERSION" \ + --notes-file release_notes.md \ + --draft=false \ + --prerelease=false + else + echo "Release $TAG_NAME already exists" + fi + + - name: Sign the dists with Sigstore + uses: sigstore/gh-action-sigstore-python@v3.2.0 + with: + inputs: >- + ./dist/*.tar.gz + ./dist/*.whl + + - name: Upload to GitHub Release + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + TAG_NAME=${{ steps.version_info.outputs.tag_name }} + gh release upload "$TAG_NAME" dist/** --clobber diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 000000000..682e80212 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,3302 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "getrandom 0.3.4", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + +[[package]] +name = "ar_archive_writer" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7eb93bbb63b9c227414f6eb3a0adfddca591a8ce1e9b60661bb08969b87e340b" +dependencies = [ + "object", +] + +[[package]] +name = "arbitrary" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3d036a3c4ab069c7b410a2ce876bd74808d2d0888a82667669f8e783a898bf1" +dependencies = [ + "derive_arbitrary", +] + +[[package]] +name = "argminmax" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70f13d10a41ac8d2ec79ee34178d61e6f47a29c2edfe7ef1721c7383b0359e65" +dependencies = [ + "num-traits", +] + +[[package]] +name = "array-init-cursor" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed51fe0f224d1d4ea768be38c51f9f831dee9d05c163c11fba0b8c44387b1fc3" + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "atoi" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" +dependencies = [ + "num-traits", +] + +[[package]] +name = "atoi_simd" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae037714f313c1353189ead58ef9eec30a8e8dc101b2622d461418fd59e28a9" + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bitflags" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" +dependencies = [ + "serde_core", +] + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "brotli" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "4.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a334ef7c9e23abf0ce748e8cd309037da93e606ad52eb372e4ce327a0dcfbdfd" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + +[[package]] +name = "bumpalo" +version = "3.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" + +[[package]] +name = "bytemuck" +version = "1.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "cc" +version = "1.2.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d16d90359e986641506914ba71350897565610e87ce0ad9e6f28569db3dd5c6d" +dependencies = [ + "find-msvc-tools", + "jobserver", + "libc", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "chrono" +version = "0.4.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +dependencies = [ + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-link", +] + +[[package]] +name = "chrono-tz" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d59ae0466b83e838b81a54256c39d5d7c20b9d7daa10510a242d9b75abd5936e" +dependencies = [ + "chrono", + "chrono-tz-build", + "phf", +] + +[[package]] +name = "chrono-tz-build" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "433e39f13c9a060046954e0592a8d0a4bcb1040125cbf91cb8ee58964cfb350f" +dependencies = [ + "parse-zoneinfo", + "phf", + "phf_codegen", +] + +[[package]] +name = "comfy-table" +version = "7.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "958c5d6ecf1f214b4c2bbbbf6ab9523a864bd136dcf71a7e8904799acfe1ad47" +dependencies = [ + "crossterm", + "unicode-segmentation", + "unicode-width", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "crossterm" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8b9f2e4c67f833b660cdb0a3523065869fb35570177239812ed4c905aeff87b" +dependencies = [ + "bitflags", + "crossterm_winapi", + "document-features", + "parking_lot", + "rustix", + "winapi", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" +dependencies = [ + "winapi", +] + +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "csv" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52cd9d68cf7efc6ddfaaee42e7288d3a99d613d4b50f76ce9827ae0c6e14f938" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde_core", +] + +[[package]] +name = "csv-core" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "704a3c26996a80471189265814dbc2c257598b96b8a7feae2d31ace646bb9782" +dependencies = [ + "memchr", +] + +[[package]] +name = "czsc-core" +version = "1.0.0" +dependencies = [ + "anyhow", + "chrono", + "derive_builder", + "error-macros", + "error-support", + "hex", + "log", + "md5", + "parking_lot", + "polars", + "pyo3", + "pyo3-stub-gen", + "serde", + "serde_json", + "sha2", + "strum", + "strum_macros", + "thiserror 2.0.18", +] + +[[package]] +name = "czsc-python" +version = "1.0.0" +dependencies = [ + "anyhow", + "chrono", + "czsc-core", + "czsc-signals", + "czsc-ta", + "czsc-trader", + "czsc-utils", + "error-macros", + "error-support", + "inventory", + "md5", + "numpy", + "polars", + "pyo3", + "pyo3-stub-gen", + "rust_xlsxwriter", + "serde", + "serde_json", + "thiserror 2.0.18", +] + +[[package]] +name = "czsc-signal-macros" +version = "1.0.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "czsc-signals" +version = "1.0.0" +dependencies = [ + "anyhow", + "chrono", + "czsc-core", + "czsc-signal-macros", + "czsc-ta", + "inventory", + "serde", + "serde_json", + "tracing", +] + +[[package]] +name = "czsc-ta" +version = "1.0.0" +dependencies = [ + "numpy", + "ordered-float", + "pyo3", + "pyo3-stub-gen", +] + +[[package]] +name = "czsc-trader" +version = "1.0.0" +dependencies = [ + "anyhow", + "arrayvec", + "chrono", + "csv", + "czsc-core", + "czsc-signals", + "czsc-utils", + "error-macros", + "error-support", + "hashbrown 0.14.5", + "hex", + "log", + "md5", + "polars", + "polars-plan", + "rayon", + "serde", + "serde_json", + "sha2", + "strum", + "strum_macros", + "thiserror 2.0.18", + "tracing", +] + +[[package]] +name = "czsc-utils" +version = "1.0.0" +dependencies = [ + "anyhow", + "chrono", + "czsc-core", + "error-macros", + "error-support", + "hashbrown 0.14.5", + "once_cell", + "parking_lot", + "polars", + "pyo3", + "pyo3-stub-gen", + "serde", + "thiserror 2.0.18", +] + +[[package]] +name = "darling" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.117", +] + +[[package]] +name = "darling_macro" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "derive_arbitrary" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e567bd82dcff979e4b03460c307b3cdc9e96fde3d73bed1496d2bc75d9dd62a" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn 2.0.117", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "document-features" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4b8a88685455ed29a21542a33abd9cb6510b6b129abadabdcef0f4c55bc8f61" +dependencies = [ + "litrs", +] + +[[package]] +name = "dyn-clone" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "enum_dispatch" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa18ce2bc66555b3218614519ac839ddb759a7d6720732f979ef8d13be147ecd" +dependencies = [ + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "error-macros" +version = "1.0.0" +dependencies = [ + "anyhow", + "proc-macro2", + "quote", + "serde", + "serde_json", + "syn 2.0.117", + "thiserror 2.0.18", +] + +[[package]] +name = "error-support" +version = "1.0.0" +dependencies = [ + "anyhow", +] + +[[package]] +name = "ethnum" +version = "1.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40404c3f5f511ec4da6fe866ddf6a717c309fdbb69fbbad7b0f3edab8f2e835f" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + +[[package]] +name = "fast-float" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95765f67b4b18863968b4a1bd5bb576f732b29a4a28c7cd84c09fa3e2875f33c" + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "flate2" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "float-cmp" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98de4bbd547a563b716d8dfa9aad1cb19bfab00f4fa09a6a4ed21dbcf44ce9c4" +dependencies = [ + "num-traits", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "futures" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-executor" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" + +[[package]] +name = "futures-macro" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "futures-sink" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" + +[[package]] +name = "futures-task" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" + +[[package]] +name = "futures-util" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "slab", +] + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "wasi", + "wasm-bindgen", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi 5.3.0", + "wasip2", +] + +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", + "wasip2", + "wasip3", +] + +[[package]] +name = "glob" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + +[[package]] +name = "halfbrown" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8588661a8607108a5ca69cab034063441a0413a0b041c13618a7dd348021ef6f" +dependencies = [ + "hashbrown 0.14.5", + "serde", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", + "allocator-api2", + "rayon", + "serde", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "home" +version = "0.5.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.65" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core 0.62.2", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "indexmap" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" +dependencies = [ + "equivalent", + "hashbrown 0.17.0", + "serde", + "serde_core", +] + +[[package]] +name = "indoc" +version = "2.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706" +dependencies = [ + "rustversion", +] + +[[package]] +name = "inventory" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4f0c30c76f2f4ccee3fe55a2435f691ca00c0e4bd87abe4f4a851b1d4dac39b" +dependencies = [ + "rustversion", +] + +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + +[[package]] +name = "itoap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9028f49264629065d057f340a86acb84867925865f73bbf8d47b4d149a7e88b8" + +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + +[[package]] +name = "js-sys" +version = "0.3.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1840c94c045fbcf8ba2812c95db44499f7c64910a912551aaaa541decebcacf" +dependencies = [ + "cfg-if", + "futures-util", + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "lexical-core" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d8d125a277f807e55a77304455eb7b1cb52f2b18c143b60e766c120bd64a594" +dependencies = [ + "lexical-parse-float", + "lexical-parse-integer", + "lexical-util", + "lexical-write-float", + "lexical-write-integer", +] + +[[package]] +name = "lexical-parse-float" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52a9f232fbd6f550bc0137dcb5f99ab674071ac2d690ac69704593cb4abbea56" +dependencies = [ + "lexical-parse-integer", + "lexical-util", +] + +[[package]] +name = "lexical-parse-integer" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a7a039f8fb9c19c996cd7b2fcce303c1b2874fe1aca544edc85c4a5f8489b34" +dependencies = [ + "lexical-util", +] + +[[package]] +name = "lexical-util" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2604dd126bb14f13fb5d1bd6a66155079cb9fa655b37f875b3a742c705dbed17" + +[[package]] +name = "lexical-write-float" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50c438c87c013188d415fbabbb1dceb44249ab81664efbd31b14ae55dabb6361" +dependencies = [ + "lexical-util", + "lexical-write-integer", +] + +[[package]] +name = "lexical-write-integer" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "409851a618475d2d5796377cad353802345cba92c867d9fbcde9cf4eac4e14df" +dependencies = [ + "lexical-util", +] + +[[package]] +name = "libc" +version = "0.2.186" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" + +[[package]] +name = "libm" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" + +[[package]] +name = "linux-raw-sys" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" + +[[package]] +name = "litrs" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092" + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "lz4" +version = "1.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a20b523e860d03443e98350ceaac5e71c6ba89aea7d960769ec3ce37f4de5af4" +dependencies = [ + "lz4-sys", +] + +[[package]] +name = "lz4-sys" +version = "1.11.1+lz4-1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "maplit" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" + +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "md5" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae960838283323069879657ca3de837e9f7bbb4c7bf6ea7f1b290d5e9476d2e0" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "memmap2" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" +dependencies = [ + "libc", +] + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + "simd-adler32", +] + +[[package]] +name = "mio" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +dependencies = [ + "libc", + "wasi", + "windows-sys", +] + +[[package]] +name = "multiversion" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4851161a11d3ad0bf9402d90ffc3967bf231768bfd7aeb61755ad06dbf1a142" +dependencies = [ + "multiversion-macros", + "target-features", +] + +[[package]] +name = "multiversion-macros" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79a74ddee9e0c27d2578323c13905793e91622148f138ba29738f9dddb835e90" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", + "target-features", +] + +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "now" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d89e9874397a1f0a52fc1f197a8effd9735223cb2390e9dcc83ac6cd02923d0" +dependencies = [ + "chrono", +] + +[[package]] +name = "ntapi" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3b335231dfd352ffb0f8017f3b6027a4917f7df785ea2143d8af2adc66980ae" +dependencies = [ + "winapi", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "numpy" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29f1dee9aa8d3f6f8e8b9af3803006101bb3653866ef056d530d53ae68587191" +dependencies = [ + "libc", + "ndarray", + "num-complex", + "num-integer", + "num-traits", + "pyo3", + "pyo3-build-config", + "rustc-hash", +] + +[[package]] +name = "object" +version = "0.37.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff76201f031d8863c38aa7f905eca4f53abbfa15f609db4277d44cd8938f33fe" +dependencies = [ + "memchr", +] + +[[package]] +name = "once_cell" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "ordered-float" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7d950ca161dc355eaf28f82b11345ed76c6e1f6eb1f4f4479e0323b9e2fbd0e" +dependencies = [ + "num-traits", +] + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + +[[package]] +name = "parquet-format-safe" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1131c54b167dd4e4799ce762e1ab01549ebb94d5bdd13e6ec1b467491c378e1f" +dependencies = [ + "async-trait", + "futures", +] + +[[package]] +name = "parse-zoneinfo" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f2a05b18d44e2957b88f96ba460715e295bc1d7510468a2f3d3b44535d26c24" +dependencies = [ + "regex", +] + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "phf" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_codegen" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" +dependencies = [ + "phf_generator", + "phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" +dependencies = [ + "phf_shared", + "rand", +] + +[[package]] +name = "phf_shared" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" +dependencies = [ + "siphasher", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "pkg-config" +version = "0.3.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e" + +[[package]] +name = "planus" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1691dd09e82f428ce8d6310bd6d5da2557c82ff17694d2a32cad7242aea89f" +dependencies = [ + "array-init-cursor", +] + +[[package]] +name = "polars" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad002eb9c541b4f7e0c7c759cefe884a0350e15d241231ac4be31c5568c15070" +dependencies = [ + "getrandom 0.2.17", + "polars-arrow", + "polars-core", + "polars-error", + "polars-io", + "polars-lazy", + "polars-ops", + "polars-parquet", + "polars-sql", + "polars-time", + "polars-utils", + "version_check", +] + +[[package]] +name = "polars-arrow" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32d19c6db79cb6a3c55af3b5a3976276edaab64cbf7f69b392617c2af30d7742" +dependencies = [ + "ahash", + "atoi", + "atoi_simd", + "bytemuck", + "chrono", + "chrono-tz", + "dyn-clone", + "either", + "ethnum", + "fast-float", + "futures", + "getrandom 0.2.17", + "hashbrown 0.14.5", + "itoa", + "itoap", + "lz4", + "multiversion", + "num-traits", + "parking_lot", + "polars-arrow-format", + "polars-error", + "polars-utils", + "ryu", + "simdutf8", + "streaming-iterator", + "strength_reduce", + "version_check", + "zstd", +] + +[[package]] +name = "polars-arrow-format" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b0ef2474af9396b19025b189d96e992311e6a47f90c53cd998b36c4c64b84c" +dependencies = [ + "planus", + "serde", +] + +[[package]] +name = "polars-compute" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30194a5ff325f61d6fcb62dc215c9210f308fc4fc85a493ef777dbcd938cba24" +dependencies = [ + "bytemuck", + "either", + "num-traits", + "polars-arrow", + "polars-error", + "polars-utils", + "strength_reduce", + "version_check", +] + +[[package]] +name = "polars-core" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ba2a3b736d55b92a12889672d0197dc25ad321ab23eba4168a3b6316a6b6349" +dependencies = [ + "ahash", + "bitflags", + "bytemuck", + "chrono", + "chrono-tz", + "comfy-table", + "either", + "hashbrown 0.14.5", + "indexmap", + "num-traits", + "once_cell", + "polars-arrow", + "polars-compute", + "polars-error", + "polars-row", + "polars-utils", + "rand", + "rand_distr", + "rayon", + "regex", + "serde", + "smartstring", + "thiserror 1.0.69", + "version_check", + "xxhash-rust", +] + +[[package]] +name = "polars-error" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07101d1803ca2046cdb3a8adb1523ddcc879229860f0ac56a853034269dec1e1" +dependencies = [ + "polars-arrow-format", + "regex", + "simdutf8", + "thiserror 1.0.69", +] + +[[package]] +name = "polars-expr" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd5c69634ddbb0f44186cd1c42d166963fc756f9cc994438e941bc2703ddbbab" +dependencies = [ + "ahash", + "bitflags", + "once_cell", + "polars-arrow", + "polars-core", + "polars-io", + "polars-ops", + "polars-plan", + "polars-time", + "polars-utils", + "rayon", + "smartstring", +] + +[[package]] +name = "polars-io" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a48ddf416ae185336c3d7880d2e05b7e55686e3e0da1014e5e7325eff9c7d722" +dependencies = [ + "ahash", + "async-trait", + "atoi_simd", + "bytes", + "chrono", + "chrono-tz", + "fast-float", + "futures", + "glob", + "home", + "itoa", + "memchr", + "memmap2", + "num-traits", + "once_cell", + "percent-encoding", + "polars-arrow", + "polars-core", + "polars-error", + "polars-json", + "polars-parquet", + "polars-time", + "polars-utils", + "rayon", + "regex", + "ryu", + "simdutf8", + "smartstring", + "tokio", + "tokio-util", +] + +[[package]] +name = "polars-json" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0a43388585a922524e8bbaa1ed1391c9c4b0768a644585609afa9a2fd5fc702" +dependencies = [ + "ahash", + "chrono", + "chrono-tz", + "fallible-streaming-iterator", + "hashbrown 0.14.5", + "indexmap", + "itoa", + "num-traits", + "polars-arrow", + "polars-error", + "polars-utils", + "ryu", + "simd-json", + "streaming-iterator", +] + +[[package]] +name = "polars-lazy" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a514a85df9e7d501c71c96f094861d0608b05a3f533447b1c0ea9cf714162fcb" +dependencies = [ + "ahash", + "bitflags", + "memchr", + "once_cell", + "polars-arrow", + "polars-core", + "polars-expr", + "polars-io", + "polars-mem-engine", + "polars-ops", + "polars-pipe", + "polars-plan", + "polars-time", + "polars-utils", + "rayon", + "smartstring", + "version_check", +] + +[[package]] +name = "polars-mem-engine" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d057df81b17b4f0ea0e4424ee34f755e6b9ccfba432ecb2fe57dc4da6da2713" +dependencies = [ + "memmap2", + "polars-arrow", + "polars-core", + "polars-error", + "polars-expr", + "polars-io", + "polars-ops", + "polars-plan", + "polars-time", + "polars-utils", + "rayon", +] + +[[package]] +name = "polars-ops" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01ba44233249b7937491b5d2bdbf14e4ad534c0a65d06548c3bc418fc3e60791" +dependencies = [ + "ahash", + "argminmax", + "base64", + "bytemuck", + "chrono", + "chrono-tz", + "either", + "hashbrown 0.14.5", + "hex", + "indexmap", + "memchr", + "num-traits", + "polars-arrow", + "polars-compute", + "polars-core", + "polars-error", + "polars-utils", + "rayon", + "regex", + "smartstring", + "unicode-reverse", + "version_check", +] + +[[package]] +name = "polars-parquet" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb2993265079ffa07dd16277189444424f8d787b00b01c6f6e001f58bab543ce" +dependencies = [ + "ahash", + "async-stream", + "base64", + "brotli", + "bytemuck", + "ethnum", + "flate2", + "futures", + "lz4", + "num-traits", + "parquet-format-safe", + "polars-arrow", + "polars-compute", + "polars-error", + "polars-utils", + "simdutf8", + "snap", + "streaming-decompression", + "zstd", +] + +[[package]] +name = "polars-pipe" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ccba94c4fa9fded0f41730f7649574c72d6d938a840731c7e4eea4e7ed5cecf" +dependencies = [ + "crossbeam-channel", + "crossbeam-queue", + "enum_dispatch", + "hashbrown 0.14.5", + "num-traits", + "polars-arrow", + "polars-compute", + "polars-core", + "polars-expr", + "polars-io", + "polars-ops", + "polars-plan", + "polars-row", + "polars-utils", + "rayon", + "smartstring", + "uuid", + "version_check", +] + +[[package]] +name = "polars-plan" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d6b29cc53d6c086c09b11050b01c25c28f6a91339036ba1fb1250fcf0d89e74" +dependencies = [ + "ahash", + "bitflags", + "bytemuck", + "chrono", + "chrono-tz", + "either", + "hashbrown 0.14.5", + "memmap2", + "once_cell", + "percent-encoding", + "polars-arrow", + "polars-core", + "polars-io", + "polars-ops", + "polars-parquet", + "polars-time", + "polars-utils", + "rayon", + "recursive", + "regex", + "smartstring", + "strum_macros", + "version_check", +] + +[[package]] +name = "polars-row" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e11f43f48466c4b1caa6dc61c381dc10c2d67b87fcb74bc996e21c4f7b0a311" +dependencies = [ + "bytemuck", + "polars-arrow", + "polars-error", + "polars-utils", +] + +[[package]] +name = "polars-sql" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e9338806e7254618eb819cc632c34b75b71d462222a913f9c1035ed81911ddc" +dependencies = [ + "hex", + "once_cell", + "polars-arrow", + "polars-core", + "polars-error", + "polars-lazy", + "polars-ops", + "polars-plan", + "polars-time", + "rand", + "serde", + "serde_json", + "sqlparser", +] + +[[package]] +name = "polars-time" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30a601ab9a62e733b8b560b37642321cb1933faa194864739f6a59d6dfc4d686" +dependencies = [ + "atoi", + "bytemuck", + "chrono", + "chrono-tz", + "now", + "once_cell", + "polars-arrow", + "polars-core", + "polars-error", + "polars-ops", + "polars-utils", + "regex", + "smartstring", +] + +[[package]] +name = "polars-utils" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19dd73207bd15efb0ae5c9c3ece3227927ed6a16ad63578acec342378e6bdcb4" +dependencies = [ + "ahash", + "bytemuck", + "bytes", + "hashbrown 0.14.5", + "indexmap", + "memmap2", + "num-traits", + "once_cell", + "polars-error", + "raw-cpuid", + "rayon", + "smartstring", + "stacker", + "sysinfo", + "version_check", +] + +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + +[[package]] +name = "portable-atomic-util" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a106d1259c23fac8e543272398ae0e3c0b8d33c88ed73d0cc71b0f1d902618" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn 2.0.117", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "psm" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645dbe486e346d9b5de3ef16ede18c26e6c70ad97418f4874b8b1889d6e761ea" +dependencies = [ + "ar_archive_writer", + "cc", +] + +[[package]] +name = "pyo3" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8970a78afe0628a3e3430376fc5fd76b6b45c4d43360ffd6cdd40bdde72b682a" +dependencies = [ + "chrono", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "458eb0c55e7ece017adeba38f2248ff3ac615e53660d7c71a238d7d2a01c7598" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7114fe5457c61b276ab77c5055f206295b812608083644a5c5b2640c3102565c" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8725c0a622b374d6cb051d11a0983786448f7785336139c3c94f5aa6bef7e50" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4109984c22491085343c05b0dbc54ddc405c3cf7b4374fc533f5c3313a572ccc" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "pyo3-stub-gen" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "650d9624b551894664cc95867ccfe4fd814a5a225c8fe3a75194a3ae51caae1d" +dependencies = [ + "anyhow", + "chrono", + "either", + "indexmap", + "inventory", + "itertools", + "log", + "maplit", + "num-complex", + "numpy", + "pyo3", + "pyo3-stub-gen-derive", + "serde", + "toml", +] + +[[package]] +name = "pyo3-stub-gen-derive" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73947c71903f0e3e31a302a350567594063c75ac155031e40519721429898649" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + +[[package]] +name = "rand" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ca0ecfa931c29007047d1bc58e623ab12e5590e8c7cc53200d5202b69266d8a" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.17", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "raw-cpuid" +version = "11.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" +dependencies = [ + "bitflags", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb39b166781f92d482534ef4b4b1b2568f42613b53e5b6c160e24cfbfa30926d" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "recursive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0786a43debb760f491b1bc0269fe5e84155353c67482b9e60d0cfb596054b43e" +dependencies = [ + "recursive-proc-macro-impl", + "stacker", +] + +[[package]] +name = "recursive-proc-macro-impl" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" +dependencies = [ + "quote", + "syn 2.0.117", +] + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "ref-cast" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f354300ae66f76f1c85c5f84693f0ce81d747e2c3f21a45fef496d89c960bf7d" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + +[[package]] +name = "rust_xlsxwriter" +version = "0.79.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c743cb9f2a4524676020e26ee5f298445a82d882b09956811b1e78ca7e42b440" +dependencies = [ + "zip", +] + +[[package]] +name = "rustc-hash" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" + +[[package]] +name = "rustix" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "semver" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a7852d02fc848982e0c167ef163aaff9cd91dc640ba85e263cb1ce46fae51cd" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "serde_spanned" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3" +dependencies = [ + "serde", +] + +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "simd-adler32" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" + +[[package]] +name = "simd-json" +version = "0.13.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0228a564470f81724e30996bbc2b171713b37b15254a6440c7e2d5449b95691" +dependencies = [ + "ahash", + "getrandom 0.2.17", + "halfbrown", + "lexical-core", + "once_cell", + "ref-cast", + "serde", + "serde_json", + "simdutf8", + "value-trait", +] + +[[package]] +name = "simdutf8" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" + +[[package]] +name = "siphasher" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ee5873ec9cce0195efcb7a4e9507a04cd49aec9c83d0389df45b1ef7ba2e649" + +[[package]] +name = "slab" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "smartstring" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fb72c633efbaa2dd666986505016c32c3044395ceaf881518399d2f4127ee29" +dependencies = [ + "autocfg", + "serde", + "static_assertions", + "version_check", +] + +[[package]] +name = "snap" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" + +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "sqlparser" +version = "0.49.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a404d0e14905361b918cb8afdb73605e25c1d5029312bd9785142dcb3aa49e" +dependencies = [ + "log", +] + +[[package]] +name = "stacker" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "640c8cdd92b6b12f5bcb1803ca3bbf5ab96e5e6b6b96b9ab77dabe9e880b3190" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "windows-sys", +] + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "streaming-decompression" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf6cc3b19bfb128a8ad11026086e31d3ce9ad23f8ea37354b31383a187c44cf3" +dependencies = [ + "fallible-streaming-iterator", +] + +[[package]] +name = "streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" + +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "strum" +version = "0.26.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.117", +] + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sysinfo" +version = "0.31.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "355dbe4f8799b304b05e1b0f05fc59b2a18d36645cf169607da45bde2f69a1be" +dependencies = [ + "core-foundation-sys", + "libc", + "memchr", + "ntapi", + "windows", +] + +[[package]] +name = "target-features" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1bbb9f3c5c463a01705937a24fdabc5047929ac764b2d5b9cf681c1f5041ed5" + +[[package]] +name = "target-lexicon" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb6935a6f5c20170eeceb1a3835a49e12e19d792f6dd344ccc76a985ca5a6ca" + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl 2.0.18", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "tokio" +version = "1.52.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "110a78583f19d5cdb2c5ccf321d1290344e71313c6c37d43520d386027d18386" +dependencies = [ + "bytes", + "libc", + "mio", + "pin-project-lite", + "socket2", + "windows-sys", +] + +[[package]] +name = "tokio-util" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "toml" +version = "0.8.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" +dependencies = [ + "indexmap", + "serde", + "serde_spanned", + "toml_datetime", + "toml_write", + "winnow", +] + +[[package]] +name = "toml_write" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", +] + +[[package]] +name = "typenum" +version = "1.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "unicode-reverse" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b6f4888ebc23094adfb574fdca9fdc891826287a6397d2cd28802ffd6f20c76" +dependencies = [ + "unicode-segmentation", +] + +[[package]] +name = "unicode-segmentation" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c" + +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + +[[package]] +name = "uuid" +version = "1.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76" +dependencies = [ + "getrandom 0.4.2", + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "value-trait" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dad8db98c1e677797df21ba03fca7d3bf9bec3ca38db930954e4fe6e1ea27eb4" +dependencies = [ + "float-cmp", + "halfbrown", + "itoa", + "ryu", +] + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.3+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6" +dependencies = [ + "wit-bindgen 0.57.1", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen 0.51.0", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.120" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df52b6d9b87e0c74c9edfa1eb2d9bf85e5d63515474513aa50fa181b3c4f5db1" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.120" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b1041f495fb322e64aca85f5756b2172e35cd459376e67f2a6c9dffcedb103" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.120" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dcd0ff20416988a18ac686d4d4d0f6aae9ebf08a389ff5d29012b05af2a1b41" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn 2.0.117", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.120" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49757b3c82ebf16c57d69365a142940b384176c24df52a087fb748e2085359ea" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows" +version = "0.57.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12342cb4d8e3b046f3d80effd474a7a02447231330ef77d71daa6fbc40681143" +dependencies = [ + "windows-core 0.57.0", + "windows-targets", +] + +[[package]] +name = "windows-core" +version = "0.57.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2ed2439a290666cd67ecce2b0ffaad89c2a56b976b736e6ece670297897832d" +dependencies = [ + "windows-implement 0.57.0", + "windows-interface 0.57.0", + "windows-result 0.1.2", + "windows-targets", +] + +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement 0.60.2", + "windows-interface 0.59.3", + "windows-link", + "windows-result 0.4.1", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.57.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "windows-interface" +version = "0.57.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-result" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e383302e8ec8515204254685643de10811af0ed97ea37210dc26fb0032647f8" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "winnow" +version = "0.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" +dependencies = [ + "memchr", +] + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen" +version = "0.57.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn 2.0.117", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn 2.0.117", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "xxhash-rust" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" + +[[package]] +name = "zerocopy" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "zip" +version = "2.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fabe6324e908f85a1c52063ce7aa26b68dcb7eb6dbc83a2d148403c9bc3eba50" +dependencies = [ + "arbitrary", + "crc32fast", + "crossbeam-utils", + "displaydoc", + "flate2", + "indexmap", + "memchr", + "thiserror 2.0.18", + "zopfli", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" + +[[package]] +name = "zopfli" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f05cd8797d63865425ff89b5c4a48804f35ba0ce8d125800027ad6017d2b5249" +dependencies = [ + "bumpalo", + "crc32fast", + "log", + "simd-adler32", +] + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.16+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 000000000..05472ddab --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,33 @@ +[workspace] +resolver = "2" +members = ["crates/*"] + +[workspace.package] +version = "1.0.0" +edition = "2024" +license = "MIT" +repository = "https://github.com/waditu/czsc" + +[workspace.dependencies] +chrono = "0.4" +hashbrown = "0.14" +ordered-float = "5.0" +polars = { version = "0.42.0", features = [ + "lazy", "serde", "strings", "concat_str", "pivot", "is_in", + "cum_agg", "abs", "ipc", "round_series", "dtype-datetime", + "temporal", "parquet", "timezones", "fmt", "partition_by", +] } +# Workspace-wide pyo3: declared without features so business crates' +# unit tests can link against a real libpython resolved through +# PYO3_PYTHON. The cdylib producer (`czsc-python`) layers on +# `abi3-py310` + `extension-module` so the published wheel uses the +# abi3 stable API and dynamically loads at Python startup time. +pyo3 = { version = "0.25" } +numpy = "0.25" +rayon = "1" +serde = { version = "1", features = ["derive"] } + +[profile.release] +lto = true +opt-level = 3 +codegen-units = 1 diff --git a/crates/czsc-core/Cargo.toml b/crates/czsc-core/Cargo.toml new file mode 100644 index 000000000..3866342cb --- /dev/null +++ b/crates/czsc-core/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "czsc-core" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "CZSC core analyzer (FX/BI/ZS/CZSC) — placeholder, to be migrated from rs-czsc" + +[lib] +name = "czsc_core" +path = "src/lib.rs" + +[dependencies] +error-macros = { path = "../error-macros" } +error-support = { path = "../error-support" } +anyhow = "1" +chrono = { workspace = true } +derive_builder = "0.20" +hex = "0.4" +log = "0.4" +md5 = "0.8" +parking_lot = { version = "0.12", optional = true } +polars = { workspace = true, features = ["lazy", "ipc"] } +serde = { workspace = true } +serde_json = "1" +sha2 = "0.10" +strum = "0.26" +strum_macros = "0.26" +thiserror = "2" +pyo3 = { workspace = true, optional = true, features = ["chrono"] } +pyo3-stub-gen = { version = "0.12", optional = true } + +[features] +python = ["pyo3", "pyo3-stub-gen", "parking_lot"] + +[dev-dependencies] +anyhow = "1" +serde_json = "1" +thiserror = "2" diff --git a/crates/czsc-core/src/analyze/errors.rs b/crates/czsc-core/src/analyze/errors.rs new file mode 100644 index 000000000..fb5213cf7 --- /dev/null +++ b/crates/czsc-core/src/analyze/errors.rs @@ -0,0 +1,13 @@ +use error_macros::CZSCErrorDerive; +use error_support::expand_error_chain; +use polars::error::PolarsError; +use thiserror::Error; + +#[derive(Debug, Error, CZSCErrorDerive)] +pub enum AnalyzeErorr { + #[error("Polars: {0}")] + Polars(#[from] PolarsError), + + #[error("{}", expand_error_chain(.0))] + Unexpected(anyhow::Error), +} diff --git a/crates/czsc-core/src/analyze/mod.rs b/crates/czsc-core/src/analyze/mod.rs new file mode 100644 index 000000000..5935a535b --- /dev/null +++ b/crates/czsc-core/src/analyze/mod.rs @@ -0,0 +1,913 @@ +#[cfg(feature = "python")] +use crate::analyze::utils::format_standard_kline; +use crate::objects::{ + bar::{NewBar, RawBar, Symbol}, + bi::BI, + direction::Direction, + freq::Freq, + fx::FX, + mark::Mark, +}; +use derive_builder::Builder; +#[cfg(feature = "python")] +use parking_lot::RwLock; +#[cfg(feature = "python")] +use polars::prelude::*; +#[cfg(feature = "python")] +use std::io::Cursor; +#[cfg(feature = "python")] +use std::sync::Arc; +use utils::{check_bi, check_fxs, remove_include}; +pub mod errors; +pub mod utils; + +#[cfg(feature = "python")] +use crate::utils::common::freq_to_chinese_string; +#[cfg(feature = "python")] +use crate::utils::common::{create_naive_pandas_timestamp, create_ordered_dict}; +#[cfg(feature = "python")] +use pyo3::prelude::{PyAnyMethods, PyDictMethods}; +#[cfg(feature = "python")] +use pyo3::types::{PyBytesMethods, PyDict}; +#[cfg(feature = "python")] +use pyo3::{Py, PyErr, PyObject, PyResult, Python}; +#[cfg(feature = "python")] +use pyo3::{pyclass, pymethods}; +#[cfg(feature = "python")] +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; + +#[cfg_attr(feature = "python", gen_stub_pyclass)] +#[cfg_attr(feature = "python", pyclass(module = "czsc._native"))] +#[derive(Debug, Clone, Builder)] +pub struct CZSC { + // verbose: bool, + /// 最大允许保留的笔数量 + pub max_bi_num: usize, + /// 原始K线序列 + pub bars_raw: Vec, + pub bars_ubi: Vec, + pub bi_list: Vec, + pub symbol: Symbol, + pub freq: Freq, + // get_signals + // signals + #[cfg(feature = "python")] + #[builder(default = "Arc::new(RwLock::new(None))")] + pub cache: Arc>>>, +} + +impl CZSC { + /// 对齐 Python 同 dt 延伸时的对象共享语义: + /// 仅同步“被 pop 出来的 last_ubi 对象”在已入笔结构中的镜像副本。 + fn sync_extended_last_ubi_in_bis(&mut self, last_ubi: &NewBar, bar: &RawBar) { + #[inline] + fn patch_new_bar_if_same(nb: &mut NewBar, target: &NewBar, bar: &RawBar) { + if nb == target + && let Some(last) = nb.elements.last_mut() + && last.dt == bar.dt + { + *last = bar.clone(); + } + } + + #[inline] + fn patch_fx_if_same(fx: &mut FX, target: &NewBar, bar: &RawBar) { + for nb in &mut fx.elements { + patch_new_bar_if_same(nb, target, bar); + } + } + + for bi in &mut self.bi_list { + for nb in &mut bi.bars { + patch_new_bar_if_same(nb, last_ubi, bar); + } + patch_fx_if_same(&mut bi.fx_a, last_ubi, bar); + patch_fx_if_same(&mut bi.fx_b, last_ubi, bar); + for fx in &mut bi.fxs { + patch_fx_if_same(fx, last_ubi, bar); + } + } + } + + pub fn new(bars_raw: Vec, max_bi_num: usize) -> Self { + // todo check length of bars_raw + + let mut c = Self { + max_bi_num, + bars_raw: Vec::with_capacity(bars_raw.len()), // 预分配容量 + bars_ubi: Vec::with_capacity(bars_raw.len() / 2), // 预估容量 + bi_list: Vec::with_capacity(max_bi_num.min(bars_raw.len() / 10)), // 预估笔数量 + symbol: bars_raw[0].symbol.clone(), + freq: bars_raw[0].freq, + #[cfg(feature = "python")] + cache: Arc::new(RwLock::new(None)), + }; + + for b in bars_raw { + c.update_bar(b); + } + + c + } + + /// 分型列表,包括 bars_ubi 中的分型 + pub fn get_fx_list(&self) -> Vec { + let mut fxs = Vec::new(); + for bi_ in self.bi_list.iter() { + fxs.extend_from_slice(&bi_.fxs[1..]); + } + + if let Some(ubi_fxs) = self.get_ubi_fxs() { + for x in ubi_fxs { + if fxs.is_empty() || x.dt > fxs.last().unwrap().dt { + fxs.push(x); + } + } + } + fxs + } + + /// 更新分析结果 + /// + /// :param bar: 单根K线对象 + pub fn update_bar(&mut self, bar: RawBar) { + // 更新K线序列 + let last_bars = if self.bars_raw.is_empty() || bar.dt != self.bars_raw.last().unwrap().dt { + self.bars_raw.push(bar.clone()); + vec![bar] + } else { + // 当前 bar 是上一根 bar 的时间延伸 + *self.bars_raw.last_mut().unwrap() = bar.clone(); + let last_ubi = self.bars_ubi.pop().unwrap(); + self.sync_extended_last_ubi_in_bis(&last_ubi, &bar); + let mut last_bars = last_ubi.elements.to_vec(); + assert_eq!( + bar.dt, + last_bars.last().unwrap().dt, + "时间错位: {} != {}", + bar.dt, + last_bars.last().unwrap().dt + ); + + *last_bars.last_mut().unwrap() = bar; + last_bars + }; + + // 去除包含关系 + for bar in last_bars.iter() { + if self.bars_ubi.len() < 2 { + self.bars_ubi.push(NewBar::new_from_raw(bar)); + } else { + let (has_include, k3) = { + // 安全获取两个相邻元素的引用 + let idx = self.bars_ubi.len() - 2; + let (_, last_two) = self.bars_ubi.split_at_mut(idx); + let k1 = &last_two[0]; // 倒数第二个元素 + let k2 = &last_two[1]; // 最后一个元素 + remove_include(k1, k2, bar.clone()).unwrap() + }; + if has_include { + *self.bars_ubi.last_mut().unwrap() = k3; + } else { + self.bars_ubi.push(k3); + } + } + } + + // 更新笔 + self.__update_bi(); + // 根据最大笔数量限制完成 bi_list, bars_raw 序列的数量控制 + if self.bi_list.len() > self.max_bi_num { + let start_idx = self.bi_list.len() - self.max_bi_num; + self.bi_list.drain(0..start_idx); + } + + if !self.bi_list.is_empty() { + let sdt = self.bi_list.first().unwrap().fx_a.elements[0].dt; + // 对齐 Python: 取第一个 dt >= sdt 的位置(重复 dt 时必须取最左侧) + let drain_to = self.bars_raw.partition_point(|bar| bar.dt < sdt); + self.bars_raw.drain(0..drain_to); + } + + // 如果有信号计算函数,则进行信号计算 + // todo self.get_signals + } + + fn __update_bi(&mut self) -> Option<()> { + if self.bars_ubi.len() < 3 { + return None; + } + + // 查找笔 + if self.bi_list.is_empty() { + // 第一笔的查找 + let fxs = check_fxs(&self.bars_ubi); + + let first = fxs.first()?; + let fx_a = fxs + .iter() + .filter(|x| x.mark == first.mark) + .reduce(|acc, x| match first.mark { + Mark::D if x.low <= acc.low => x, + Mark::G if x.high >= acc.high => x, + _ => acc, + }) + .unwrap_or(first); + + let bars_ubi = self + .bars_ubi + .iter() + .filter(|x| x.dt >= fx_a.elements[0].dt) + .collect::>(); + + let (bi, bars_ubi_) = check_bi(&bars_ubi); + if let Some(bi) = bi { + self.bi_list.push(bi); + } + + self.bars_ubi = bars_ubi_.iter().map(|&bar| bar.clone()).collect::>(); + return None; + } + + // todo log + + // println!( + // "dt: {}, 未完成笔延伸数量: {}", + // self.bars_ubi.last().unwrap().dt, + // self.bars_ubi.len() + // ); + let (bi, bars_ubi_) = check_bi(&self.bars_ubi); + if let Some(bi) = bi { + self.bi_list.push(bi); + } + + self.bars_ubi = bars_ubi_.to_vec(); + + // 后处理:如果当前笔被破坏,将当前笔的bars与bars_ubi进行合并,并丢弃 + let last_bi = self.bi_list.last().unwrap(); // 获取最后一个笔 + let bars_ubi = &self.bars_ubi; + + if bars_ubi.last().is_some() + && ((last_bi.direction == Direction::Up + && bars_ubi.last().unwrap().high > last_bi.get_high()) + || (last_bi.direction == Direction::Down + && bars_ubi.last().unwrap().low < last_bi.get_low())) + { + // 当前笔被破坏,将当前笔的bars与bars_ubi进行合并 + // 使用除了最后两根K线之外的所有K线 + let merge_point = last_bi.bars[last_bi.bars.len() - 2].dt; + + // 创建新的合并后的bars序列 + self.bars_ubi = last_bi.bars[..last_bi.bars.len() - 2] + .iter() + .chain(bars_ubi.iter().filter(|x| x.dt >= merge_point)) + .cloned() + .collect(); + + // 移除最后一个笔 + self.bi_list.pop(); + } + None + } + + /// 获取 bars_ubi 中的分型 + pub fn get_ubi_fxs(&self) -> Option> { + if self.bars_ubi.is_empty() { + return None; + } + Some(check_fxs(&self.bars_ubi)) + } + + #[allow(unused)] + /// Unfinished Bi,未完成的笔 + fn get_ubi(&self) -> Option { + if self.bars_ubi.is_empty() || self.bi_list.is_empty() { + return None; + } + + let ubi_fxs = self.get_ubi_fxs()?; + + let bars_raw = self + .bars_ubi + .iter() + .flat_map(|x| &x.elements) + .collect::>(); + + // 获取最高点和最低点,以及对应的时间 + let high_bar = bars_raw + .iter() + .max_by(|a, b| { + a.high + .partial_cmp(&b.high) + .unwrap_or(std::cmp::Ordering::Less) + }) + .unwrap() + .to_owned() + .to_owned(); + + let low_bar = bars_raw + .iter() + .min_by(|a, b| { + a.low + .partial_cmp(&b.low) + .unwrap_or(std::cmp::Ordering::Greater) + }) + .unwrap() + .to_owned() + .to_owned(); + + let direction = if self.bi_list.last().unwrap().direction == Direction::Down { + Direction::Up + } else { + Direction::Down + }; + let fx_a = ubi_fxs.first().unwrap().to_owned(); + Some(UBI { + symbol: self.symbol.clone(), + direction, + high: high_bar.high, + low: low_bar.low, + high_bar, + low_bar, + bars: self.bars_ubi.to_owned(), + raw_bars: self.bars_raw.to_owned(), + fxs: ubi_fxs, + fx_a, + }) + } +} + +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", gen_stub_pymethods)] +#[cfg_attr(feature = "python", pymethods)] +impl CZSC { + #[new] + #[pyo3(signature = (bars_raw, max_bi_num=50))] + pub fn new_py(bars_raw: Vec, max_bi_num: usize) -> PyResult { + Ok(CZSC::new(bars_raw, max_bi_num)) + } + + /// 直接从Arrow格式的DataFrame创建CZSC对象,避免中间转换 + /// 这是高性能的批量创建接口,适用于大量数据的初始化 + /// + /// :param df_bytes: Arrow IPC格式的DataFrame字节数据 + /// :param freq: K线频率 + /// :param max_bi_num: 最大笔数量限制 + /// :return: CZSC对象 + #[staticmethod] + #[pyo3(signature = (df_bytes, freq, max_bi_num=50))] + pub fn from_dataframe( + df_bytes: pyo3::Bound<'_, pyo3::types::PyBytes>, + freq: Freq, + max_bi_num: usize, + ) -> PyResult { + // 直接从Arrow字节数据创建DataFrame + let bytes_data = df_bytes.as_bytes(); + let cursor = Cursor::new(bytes_data); + let df = IpcReader::new(cursor).finish().map_err(|e| { + PyErr::new::(format!( + "Failed to read Arrow data: {e}" + )) + })?; + + // 数据验证:确保DataFrame包含必需的列 + let required_columns = [ + "symbol", "dt", "open", "close", "high", "low", "vol", "amount", + ]; + for col in &required_columns { + if !df.get_column_names().contains(col) { + return Err(PyErr::new::(format!( + "Missing required column: {col}" + ))); + } + } + + // 验证数据不为空 + if df.height() == 0 { + return Err(PyErr::new::( + "DataFrame is empty", + )); + } + + // 直接格式化为RawBar - 这是性能关键路径 + let bars = format_standard_kline(df, freq).map_err(|e| { + PyErr::new::(format!( + "Failed to format kline data: {e}" + )) + })?; + + // 批量创建CZSC对象 + Ok(CZSC::new(bars, max_bi_num)) + } + + #[getter] + fn symbol(&self) -> String { + self.symbol.to_string() + } + + #[getter] + fn freq(&self) -> Freq { + self.freq + } + + #[getter] + fn max_bi_num(&self) -> usize { + self.max_bi_num + } + + #[getter] + fn bi_list(&self) -> Vec { + self.bi_list.to_vec() + } + + /// 获取原始K线序列 - 返回PyRawBar对象列表 + #[getter] + fn bars_raw(&self) -> Vec { + self.bars_raw.to_vec() + } + + /// 获取原始K线序列的DataFrame格式,便于绘图和分析 + #[getter] + fn bars_raw_df(&self, py: Python) -> PyResult { + let pandas = py.import("pandas")?; + let df_class = pandas.getattr("DataFrame")?; + + let data: Vec = self + .bars_raw + .iter() + .map(|bar| -> PyResult { + let dict = PyDict::new(py); + dict.set_item("symbol", bar.symbol.as_ref())?; + dict.set_item("dt", create_naive_pandas_timestamp(py, bar.dt)?)?; + dict.set_item("freq", freq_to_chinese_string(bar.freq))?; + dict.set_item("id", bar.id)?; + dict.set_item("open", bar.open)?; + dict.set_item("close", bar.close)?; + dict.set_item("high", bar.high)?; + dict.set_item("low", bar.low)?; + dict.set_item("vol", bar.vol)?; + dict.set_item("amount", bar.amount)?; + Ok(dict.into()) + }) + .collect::>>()?; + + let df = df_class.call1((data,))?; + Ok(df.into()) + } + + /// 获取无包含关系K线序列 + #[getter] + fn bars_ubi(&self) -> Vec { + self.bars_ubi.to_vec() + } + + /// 获取已完成的笔列表(与 bi_list 相同,为兼容 czsc 库) + #[getter] + fn finished_bis(&self) -> Vec { + if self.bi_list.is_empty() { + return vec![]; + } + if self.bars_ubi.len() < 5 { + return self.bi_list[..self.bi_list.len().saturating_sub(1)].to_vec(); + } + self.bi_list.to_vec() + } + + /// 获取分型列表(属性,与 czsc 库兼容) + #[getter] + fn fx_list(&self) -> Vec { + self.get_fx_list().into_iter().collect() + } + + /// 缓存字典(与 czsc 库兼容) + #[getter] + fn cache(&self, py: Python) -> PyResult { + create_ordered_dict(py) + } + + /// 信号字典(与 czsc 库兼容) + #[getter] + fn signals(&self, py: Python) -> PyResult { + create_ordered_dict(py) + } + + /// 无包含关系K线分型列表(与 czsc 库兼容) + #[getter] + fn ubi_fxs(&self) -> Vec { + self.get_ubi_fxs().unwrap_or_default() + } + + /// 无包含关系K线(与 czsc 库兼容) + /// 返回未完成的笔信息,格式与 Python 版本保持一致 + #[getter] + fn ubi(&self, py: Python) -> PyResult { + let ubi_fxs = self.get_ubi_fxs().unwrap_or_default(); + + if self.bars_ubi.is_empty() || self.bi_list.is_empty() || ubi_fxs.is_empty() { + return Ok(py.None()); + } + + // 获取所有原始K线 + let bars_raw: Vec = self + .bars_ubi + .iter() + .flat_map(|x| &x.elements) + .cloned() + .collect(); + + if bars_raw.is_empty() { + return Ok(py.None()); + } + + // 获取最高点和最低点 + let high_bar = bars_raw + .iter() + .max_by(|a, b| a.high.partial_cmp(&b.high).unwrap()) + .unwrap() + .clone(); + + let low_bar = bars_raw + .iter() + .min_by(|a, b| a.low.partial_cmp(&b.low).unwrap()) + .unwrap() + .clone(); + + // 确定方向:与最后一笔相反 + let direction = if self.bi_list.last().unwrap().direction == Direction::Down { + Direction::Up + } else { + Direction::Down + }; + + // 创建字典,按照原版 czsc 的字段顺序 + let dict = PyDict::new(py); + dict.set_item("symbol", self.symbol.as_ref())?; + dict.set_item("direction", direction)?; + dict.set_item("high", high_bar.high)?; + dict.set_item("low", low_bar.low)?; + dict.set_item("high_bar", high_bar)?; + dict.set_item("low_bar", low_bar)?; + dict.set_item("bars", self.bars_ubi())?; + dict.set_item("raw_bars", bars_raw)?; + dict.set_item("fxs", ubi_fxs.clone())?; + dict.set_item("fx_a", ubi_fxs.first().unwrap().clone())?; + + // 直接返回字典 + Ok(dict.into()) + } + + /// 是否显示详细信息(与 czsc 库兼容) + #[getter] + fn verbose(&self) -> bool { + false // 默认不显示详细信息 + } + + /// 最后一笔延伸情况(与 czsc 库兼容) + /// 判断最后一笔是否在延伸中,True 表示延伸中 + #[getter] + fn last_bi_extend(&self) -> bool { + // 如果没有笔,返回 false + if self.bi_list.is_empty() { + return false; + } + + // 如果没有无包含关系K线,返回 false + if self.bars_ubi.is_empty() { + return false; + } + + let last_bi = &self.bi_list[self.bi_list.len() - 1]; + + match last_bi.direction { + Direction::Up => { + // 向上笔:检查当前所有无包含K线的最高价是否 > 最后一笔的高点 + let max_high = self + .bars_ubi + .iter() + .map(|bar| bar.high) + .max_by(|a, b| a.partial_cmp(b).unwrap()) + .unwrap_or(0.0); + + max_high > last_bi.get_high() + } + Direction::Down => { + // 向下笔:检查当前所有无包含K线的最低价是否 < 最后一笔的低点 + let min_low = self + .bars_ubi + .iter() + .map(|bar| bar.low) + .min_by(|a, b| a.partial_cmp(b).unwrap()) + .unwrap_or(f64::MAX); + + min_low < last_bi.get_low() + } + } + } + + /// 在浏览器中打开(与 czsc 库兼容) + #[pyo3(signature = (_renderer=None))] + fn open_in_browser(&self, _renderer: Option<&str>) -> PyResult { + Ok("Browser opening not implemented in Rust version".to_string()) + } + + /// 转换为 ECharts 格式(与 czsc 库兼容) + fn to_echarts(&self) -> PyResult { + Ok("ECharts export not implemented in Rust version".to_string()) + } + + /// 转换为 Plotly 格式(与 czsc 库兼容) + fn to_plotly(&self) -> PyResult { + Ok("Plotly export not implemented in Rust version".to_string()) + } + + /// 更新K线数据 + fn update(&mut self, bar: RawBar) -> PyResult<()> { + self.update_bar(bar); + Ok(()) + } + /// 缓存字典(与 czsc 库兼容) + #[getter] + fn get_cache<'py>(&'py self, py: Python<'py>) -> Py { + // 首先尝试读锁获取缓存 + { + let cache_read = self.cache.read(); + if let Some(ref cached_dict) = *cache_read { + return cached_dict.clone_ref(py); + } + } + + // 如果缓存为空,使用写锁初始化 + let mut cache_write = self.cache.write(); + if cache_write.is_none() { + *cache_write = Some(PyDict::new(py).unbind()); + } + cache_write.as_ref().unwrap().clone_ref(py) + } + + #[setter] + #[gen_stub(skip)] // 跳过为了防止和 get_cache重复 + fn set_cache(&self, dict: Py) { + let mut cache_write = self.cache.write(); + *cache_write = Some(dict); + } + fn __repr__(&self) -> String { + format!( + "CZSC(symbol={}, freq={:?}, max_bi_num={}, bi_count={})", + self.symbol, + self.freq, + self.max_bi_num, + self.bi_list.len() + ) + } + + /// Pickle support — `__reduce__` returns ``(CZSC, (fixed_point_bars, max_bi_num))``. + /// + /// `update_bar` drains older bars whose dt is below the current + /// first-BI's start (see `bars_raw.drain` block above), so a + /// freshly-constructed CZSC's `bars_raw` may still differ from the + /// fixed point reached after a single re-analysis. We run one extra + /// `CZSC::new` here to converge before serializing — guarantees that + /// `pickle.dumps(restored) == pickle.dumps(obj)` byte-for-byte even + /// when CzscSignals nests CZSC inside `kas[freq]` (Phase A's + /// `restored.__getstate__() == obj.__getstate__()` assertion relies + /// on this). + fn __reduce__(&self, py: Python) -> PyResult { + use pyo3::IntoPyObject; + let trimmed = CZSC::new(self.bars_raw.clone(), self.max_bi_num); + let args = (trimmed.bars_raw, self.max_bi_num).into_pyobject(py)?; + let constructor = py.get_type::(); + let result = (constructor, args).into_pyobject(py)?; + Ok(result.into()) + } +} + +/// Unfinished Bi,未完成的笔 +#[derive(Debug, Clone)] +pub struct UBI { + pub symbol: Symbol, + pub direction: Direction, + pub high: f64, + pub low: f64, + pub high_bar: RawBar, + pub low_bar: RawBar, + pub bars: Vec, + pub raw_bars: Vec, + pub fxs: Vec, + pub fx_a: FX, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::analyze::utils::format_standard_kline; + use crate::objects::freq::Freq; + use chrono::NaiveDateTime; + use chrono::{DateTime, Utc}; + use polars::prelude::SerReader; + use polars::prelude::{CsvReader, StringChunked, StringMethods}; + use std::io::Cursor; + + fn example_data() -> &'static str { + const CSV_DATA: &str = r#" +dt,symbol,open,close,high,low,vol,amount +2025-01-02,002515.SZ,50.73,51.29,52.97,50.62,32900684.0,152798823.0 +2025-01-03,002515.SZ,51.4,48.72,51.85,48.6,33224687.0,147184323.0 +2025-01-06,002515.SZ,48.83,48.6,49.39,47.48,17419634.0,75608391.0 +2025-01-07,002515.SZ,48.6,48.94,49.05,48.27,13929982.0,60500438.0 +2025-01-08,002515.SZ,48.27,48.04,48.94,47.26,17697397.0,75973887.0 +2025-01-09,002515.SZ,48.27,48.16,48.83,47.6,14284260.0,61391856.0 +2025-01-10,002515.SZ,48.04,46.92,48.94,46.81,16080374.0,68834125.0 +2025-01-13,002515.SZ,46.59,46.92,47.26,45.47,12508818.0,52037636.0 +2025-01-14,002515.SZ,46.92,48.16,48.27,46.92,16407679.0,69944802.0 +2025-01-15,002515.SZ,49.5,49.5,50.73,49.05,29140842.0,129502353.0 +2025-01-16,002515.SZ,49.5,49.72,50.28,48.94,19124511.0,84774186.0 +2025-01-17,002515.SZ,49.28,50.28,51.74,49.05,22228511.0,99754272.0 +2025-01-20,002515.SZ,50.4,50.4,50.73,49.61,14908933.0,66989586.0 +2025-01-21,002515.SZ,50.62,50.06,50.73,49.61,11565100.0,51612511.0 +2025-01-22,002515.SZ,50.06,49.16,50.06,48.83,10889797.0,47963340.0 +2025-01-23,002515.SZ,49.39,48.72,49.95,48.72,13050206.0,57522568.0 +2025-01-24,002515.SZ,48.49,48.83,48.94,48.27,12042388.0,52334558.0 +2025-01-27,002515.SZ,49.05,49.39,51.74,49.05,22813802.0,102357601.0 +2025-02-05,002515.SZ,49.39,49.16,49.95,48.72,13525075.0,59524887.0 +2025-02-06,002515.SZ,48.83,49.05,49.28,48.16,17429613.0,75782611.0 +2025-02-07,002515.SZ,48.94,49.5,49.95,48.72,17447114.0,76989329.0 +2025-02-10,002515.SZ,49.39,50.4,50.51,49.16,18733821.0,83810683.0 +2025-02-11,002515.SZ,50.4,49.84,50.73,49.61,13189816.0,58803966.0 +2025-02-12,002515.SZ,50.06,50.06,50.4,49.5,15881392.0,70692291.0 +2025-02-13,002515.SZ,49.84,49.84,50.51,49.61,18048669.0,80671035.0 +2025-02-14,002515.SZ,49.72,49.05,49.95,48.94,17455299.0,76786904.0 +2025-02-17,002515.SZ,49.16,49.39,49.61,48.6,15791678.0,69303481.0 +2025-02-18,002515.SZ,49.16,47.71,49.39,47.48,20599809.0,88885983.0 +2025-02-19,002515.SZ,47.48,48.04,48.16,47.37,12911258.0,55064600.0 +2025-02-20,002515.SZ,48.04,48.27,48.83,47.71,12823411.0,55267260.0 +2025-02-21,002515.SZ,48.27,47.6,48.72,47.48,16547084.0,70527761.0 +2025-02-24,002515.SZ,47.71,52.41,52.41,47.71,93355060.0,426873493.0 +2025-02-25,002515.SZ,51.96,50.51,51.96,50.17,54431026.0,246916111.0 +2025-02-26,002515.SZ,50.62,52.52,52.86,50.17,50584995.0,232883144.0 +2025-02-27,002515.SZ,52.41,53.64,53.98,51.96,47142936.0,224200231.0 +2025-02-28,002515.SZ,53.2,52.52,53.53,52.41,29058781.0,137329596.0 +"#; + CSV_DATA + } + + fn get_bars() -> Vec { + let cursor = Cursor::new(example_data().as_bytes()); + let mut df = CsvReader::new(cursor).finish().unwrap(); + + let dt_col = df + .column("dt") + .unwrap() + .str() + .unwrap() + .as_datetime( + Some("%Y-%m-%d"), + polars::prelude::TimeUnit::Milliseconds, + false, + false, + None, + &StringChunked::from_iter(std::iter::once("raise")), + ) + .unwrap(); + df.with_column(dt_col).unwrap(); + + format_standard_kline(df, Freq::D).unwrap() + } + + fn parse_dt(s: &str) -> DateTime { + NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") + .unwrap() + .and_local_timezone(Utc) + .unwrap() // 保证时间有效性 + } + + /// ## 数据来源 + /// + /// + /// ``` + /// from czsc.connectors import cooperation as coo + /// df = coo.stocks_daily_klines(sdt="20250101", edt="20250302") + /// df = df[df["symbol"]=="002515.SZ"][['dt', 'symbol','open', 'close', 'high', 'low', 'vol', 'amount']] + /// df.reset_index(drop=True, inplace=True) + /// bars = czsc.format_standard_kline(df, "日线") + /// c = czsc.CZSC(bars) + /// for b in c.bi_list: + /// print(b) + /// ``` + /// + /// ``` + /// BI(symbol=002515.SZ, sdt=2025-01-13 00:00:00, edt=2025-01-17 00:00:00, direction=向上, high=51.74, low=45.47) + /// BI(symbol=002515.SZ, sdt=2025-01-17 00:00:00, edt=2025-02-06 00:00:00, direction=向下, high=51.74, low=48.16) + /// BI(symbol=002515.SZ, sdt=2025-02-06 00:00:00, edt=2025-02-11 00:00:00, direction=向上, high=50.73, low=48.16) + /// BI(symbol=002515.SZ, sdt=2025-02-11 00:00:00, edt=2025-02-19 00:00:00, direction=向下, high=50.73, low=47.37) + /// ``` + #[test] + fn test_czsc_bi_list() { + let bars = get_bars(); + let c = CZSC::new(bars, 50); + + let expected = [ + ( + "2025-01-13 00:00:00", + "2025-01-17 00:00:00", + Direction::Up, + 51.74, + 45.47, + ), + ( + "2025-01-17 00:00:00", + "2025-02-06 00:00:00", + Direction::Down, + 51.74, + 48.16, + ), + ( + "2025-02-06 00:00:00", + "2025-02-11 00:00:00", + Direction::Up, + 50.73, + 48.16, + ), + ( + "2025-02-11 00:00:00", + "2025-02-19 00:00:00", + Direction::Down, + 50.73, + 47.37, + ), + ]; + + assert_eq!(c.bi_list.len(), expected.len()); + + for (i, (bi, exp)) in c.bi_list.iter().zip(expected.iter()).enumerate() { + assert_eq!(bi.start_dt(), parse_dt(exp.0), "Index {i} sdt mismatch"); + assert_eq!(bi.end_dt(), parse_dt(exp.1), "Index {i} edt mismatch"); + assert_eq!(bi.direction, exp.2, "Index {i} direction mismatch"); + assert!( + (bi.get_high() - exp.3).abs() < 1e-4, + "Index {i} high mismatch" + ); + assert!( + (bi.get_low() - exp.4).abs() < 1e-4, + "Index {i} low mismatch" + ); + } + } + + /// ## 数据来源 + /// + /// + /// ``` + /// from czsc.connectors import cooperation as coo + /// df = coo.stocks_daily_klines(sdt="20250101", edt="20250302") + /// df = df[df["symbol"]=="002515.SZ"][['dt', 'symbol','open', 'close', 'high', 'low', 'vol', 'amount']] + /// df.reset_index(drop=True, inplace=True) + /// bars = czsc.format_standard_kline(df, "日线") + /// c = czsc.CZSC(bars) + /// for b in c.fx_list: + /// print(f"dt={fx.dt}, fx={fx.fx}") + /// ``` + /// + /// ``` + /// dt: 2025-01-15 00:00:00, fx: 50.73 + /// dt: 2025-01-16 00:00:00, fx: 48.94 + /// dt: 2025-01-17 00:00:00, fx: 51.74 + /// dt: 2025-01-24 00:00:00, fx: 48.27 + /// dt: 2025-01-27 00:00:00, fx: 51.74 + /// dt: 2025-02-06 00:00:00, fx: 48.16 + /// dt: 2025-02-11 00:00:00, fx: 50.73 + /// dt: 2025-02-12 00:00:00, fx: 49.5 + /// dt: 2025-02-13 00:00:00, fx: 50.51 + /// dt: 2025-02-19 00:00:00, fx: 47.37 + /// dt: 2025-02-20 00:00:00, fx: 48.83 + /// dt: 2025-02-21 00:00:00, fx: 47.48 + /// ``` + #[test] + fn test_czsc_fx_list() { + let bars = get_bars(); + let c = CZSC::new(bars, 50); + + let expected = [ + ("2025-01-15 00:00:00", 50.73), + ("2025-01-16 00:00:00", 48.94), + ("2025-01-17 00:00:00", 51.74), + ("2025-01-24 00:00:00", 48.27), + ("2025-01-27 00:00:00", 51.74), + ("2025-02-06 00:00:00", 48.16), + ("2025-02-11 00:00:00", 50.73), + ("2025-02-12 00:00:00", 49.5), + ("2025-02-13 00:00:00", 50.51), + ("2025-02-19 00:00:00", 47.37), + ("2025-02-20 00:00:00", 48.83), + ("2025-02-21 00:00:00", 47.48), + ]; + + // for fx in c.get_fx_list() { + // println!("dt={dt}, fx={fx}", dt = fx.dt, fx = fx.fx) + // } + + for (i, (fx, exp)) in c.get_fx_list().iter().zip(expected.iter()).enumerate() { + assert_eq!(fx.dt, parse_dt(exp.0), "Index {i} dt mismatch"); + assert!((fx.fx - exp.1).abs() < 1e-4, "Index {i} fx mismatch"); + } + } +} diff --git a/crates/czsc-core/src/analyze/utils.rs b/crates/czsc-core/src/analyze/utils.rs new file mode 100644 index 000000000..56335f9ee --- /dev/null +++ b/crates/czsc-core/src/analyze/utils.rs @@ -0,0 +1,374 @@ +use super::errors::AnalyzeErorr; +use crate::objects::bar::RawBarBuilder; +use crate::objects::{ + bar::{NewBar, NewBarBuilder, RawBar}, + bi::{BI, BIBuilder}, + direction::Direction, + freq::Freq, + fx::{FX, FXBuilder}, + mark::Mark, +}; +use anyhow::Context; +use chrono::DateTime; +use chrono::Utc; +use polars::frame::DataFrame; +use polars::prelude::TimeUnit; + +/// 去除包含关系:输入三根k线,其中k1和k2为没有包含关系的K线,k3为原始K线 +/// 处理逻辑如下: +/// +/// 1. 首先,通过比较k1和k2的高点(high)的大小关系来确定direction的值。如果k1的高点小于k2的高点, +/// 则设定direction为Up;如果k1的高点大于k2的高点,则设定direction为Down;如果k1和k2的高点相等, +/// 则创建一个新的K线k4,与k3具有相同的属性,并返回False和k4。 +/// +/// 2. 接下来,判断k2和k3之间是否存在包含关系。如果存在,则根据direction的值进行处理。 +/// - 如果direction为Up,则选择k2和k3中的较大高点作为新K线k4的高点,较大低点作为低点,较大高点所在的时间戳(dt)作为k4的时间戳。 +/// - 如果direction为Down,则选择k2和k3中的较小高点作为新K线k4的高点,较小低点作为低点,较小低点所在的时间戳(dt)作为k4的时间戳。 +/// +/// 3. 根据上述处理得到的高点、低点、开盘价(open_)、收盘价(close),计算新K线k4的成交量(vol)和成交金额(amount), +/// 并将k2中除了与k3时间戳相同的元素之外的其他元素与k3一起作为k4的元素列表(elements)。 +/// +/// 4. 返回一个布尔值和新的K线k4。如果k2和k3之间存在包含关系,则返回True和k4;否则返回False和k4,其中k4与k3具有相同的属性。 +pub fn remove_include( + k1: &NewBar, + k2: &NewBar, + k3: RawBar, +) -> Result<(bool, NewBar), AnalyzeErorr> { + // 根据k1和k2的high确定direction + let direction = if k1.high < k2.high { + Direction::Up + } else if k1.high > k2.high { + Direction::Down + } else { + return Ok((false, NewBar::new_from_raw(&k3))); + }; + + // 检查k2和k3是否存在包含关系 + let has_inclusion = + (k2.high <= k3.high && k2.low >= k3.low) || (k2.high >= k3.high && k2.low <= k3.low); + + if !has_inclusion { + return Ok((false, NewBar::new_from_raw(&k3))); + } + + // 处理包含关系 + let (high, low, dt) = match direction { + Direction::Up => { + let high = k2.high.max(k3.high); + let low = k2.low.max(k3.low); + let dt = if k2.high > k3.high { k2.dt } else { k3.dt }; + (high, low, dt) + } + Direction::Down => { + let high = k2.high.min(k3.high); + let low = k2.low.min(k3.low); + let dt = if k2.low < k3.low { k2.dt } else { k3.dt }; + (high, low, dt) + } + }; + + let (open_, close) = if k3.open > k3.close { + (high, low) + } else { + (low, high) + }; + + let k4 = { + let k3_dt = k3.dt; + NewBarBuilder::default() + .symbol(k2.symbol.clone()) + .id(k2.id) + .freq(k2.freq) + .dt(dt) + .open(open_) + .close(close) + .high(high) + .low(low) + .vol(k2.vol + k3.vol) + .amount(k2.amount + k3.amount) + .elements( + k2.elements + .iter() + .take(100) + .filter(|x| x.dt != k3_dt) + .cloned() + .chain(std::iter::once(k3)) + .collect::>(), + ) + .build() + .context("Failed to new NewBar")? + }; + + Ok((true, k4)) +} + +/// 输入一串无包含关系K线,查找其中所有分型 +/// +/// 函数的主要步骤: +/// +/// 1. 创建一个空列表`fxs`用于存储找到的分型。 +/// 2. 遍历`bars`列表中的每个元素(除了第一个和最后一个),并对每三个连续的`NewBar`对象调用`check_fx`函数。 +/// 3. 如果`check_fx`函数返回一个`FX`对象,检查它的标记是否与`fxs`列表中最后一个`FX`对象的标记相同。如果相同,记录一个错误日志。 +/// 如果不同,将这个`FX`对象添加到`fxs`列表中。 +/// 4. 最后返回`fxs`列表,它包含了`bars`列表中所有找到的分型。 +/// +/// 这个函数的主要目的是找出`bars`列表中所有的顶分型和底分型,并确保它们是交替出现的。如果发现连续的两个分型标记相同,它会记录一个错误日志。 +/// +/// :param bars: 无包含关系K线列表 +/// :return: 分型列表 +pub fn check_fxs>(bars: &[B]) -> Vec { + let mut fxs: Vec = Vec::new(); + for window in bars[0..bars.len()].windows(3) { + if let [k1, k2, k3] = window + && let Some(fx1) = check_fx(k1.as_ref(), k2.as_ref(), k3.as_ref()) + { + // 与Python版本保持一致:过滤重复的相同标记分型 + // 默认情况下,fxs本身是顶底交替的,但是对于一些特殊情况下不是这样; 临时强制要求fxs序列顶底交替 + if fxs.len() >= 2 && fx1.mark == fxs.last().unwrap().mark { + eprintln!( + "check_fxs错误: {},{:?},{:?}", + k2.as_ref().dt, + fx1.mark, + fxs.last().unwrap().mark + ); + } else { + fxs.push(fx1); + } + } + } + fxs +} + +/// 查找分型 +/// +/// 函数计算逻辑: +/// +/// 1. 如果第二个`NewBar`对象的最高价和最低价都高于第一个和第三个`NewBar`对象的对应价格,那么它被认为是顶分型(G)。 +/// 在这种情况下,函数会创建一个新的`FX`对象,其标记为`Mark.G`,并将其赋值给`fx`。 +/// +/// 2. 如果第二个`NewBar`对象的最高价和最低价都低于第一个和第三个`NewBar`对象的对应价格,那么它被认为是底分型(D)。 +/// 在这种情况下,函数会创建一个新的`FX`对象,其标记为`Mark.D`,并将其赋值给`fx`。 +/// +/// 3. 函数最后返回`fx`,如果没有找到分型,`fx`将为`None`。 +/// +/// :param k1: 第一个`NewBar`对象 +/// :param k2: 第二个`NewBar`对象 +/// :param k3: 第三个`NewBar`对象 +/// :return: `FX`对象或`None` +pub fn check_fx(k1: &NewBar, k2: &NewBar, k3: &NewBar) -> Option { + // 顶分型判断 + if k1.high < k2.high && k2.high > k3.high && k1.low < k2.low && k2.low > k3.low { + return Some( + FXBuilder::default() + .symbol(k1.symbol.clone()) + .dt(k2.dt) + .mark(Mark::G) + .high(k2.high) + .low(k2.low) + .fx(k2.high) + .elements(vec![k1.clone(), k2.clone(), k3.clone()]) + .build() + .unwrap(), + ); + } + + // 底分型判断 + if k1.low > k2.low && k2.low < k3.low && k1.high > k2.high && k2.high < k3.high { + return Some( + FXBuilder::default() + .symbol(k1.symbol.clone()) + .dt(k2.dt) + .mark(Mark::D) + .high(k2.high) + .low(k2.low) + .fx(k2.low) + .elements(vec![k1.clone(), k2.clone(), k3.clone()]) + .build() + .unwrap(), + ); + } + + None +} + +/// 输入一串无包含关系K线,查找其中的一笔 +/// +/// :param bars: 无包含关系K线列表 +/// :return: +pub fn check_bi(bars: &[B]) -> (Option, &[B]) +where + B: AsRef, +{ + let fxs = check_fxs(bars); + if fxs.len() < 2 { + return (None, bars); + } + + let fx_a = &fxs[0]; + let (direction, fx_b) = match fx_a.mark { + Mark::D => { + // 对齐 Python max(..., key=fx.high): + // 仅当更高时替换;并列时保留首次出现的候选。 + let mut fx_b: Option<&FX> = None; + for x in fxs + .iter() + .filter(|x| x.mark == Mark::G && x.dt > fx_a.dt && x.fx > fx_a.fx) + { + match fx_b { + None => fx_b = Some(x), + Some(best) if x.high > best.high => fx_b = Some(x), + _ => {} + } + } + let fx_b = fx_b.cloned(); + (Direction::Up, fx_b) + } + Mark::G => { + // 对齐 Python min(..., key=fx.low): + // 仅当更低时替换;并列时保留首次出现的候选。 + let mut fx_b: Option<&FX> = None; + for x in fxs + .iter() + .filter(|x| x.mark == Mark::D && x.dt > fx_a.dt && x.fx < fx_a.fx) + { + match fx_b { + None => fx_b = Some(x), + Some(best) if x.low < best.low => fx_b = Some(x), + _ => {} + } + } + let fx_b = fx_b.cloned(); + (Direction::Down, fx_b) + } + }; + + let fx_b = match fx_b { + Some(fx) => fx, + None => return (None, bars), + }; + + // 确定bars_a的起始和结束索引 + let start_dt = fx_a.elements[0].dt; + let end_dt = fx_b.elements[2].dt; + + let start_idx = bars.partition_point(|bar| bar.as_ref().dt < start_dt); + let end_idx = bars.partition_point(|bar| bar.as_ref().dt <= end_dt); + if start_idx >= end_idx { + return (None, bars); + } + + let bars_a = &bars[start_idx..end_idx]; + + // 确定剩余bars_b的起始索引(基于fx_b.elements[0].dt) + let new_start_dt = fx_b.elements[0].dt; + let new_start_idx = bars.partition_point(|bar| bar.as_ref().dt < new_start_dt); + let bars_b = &bars[new_start_idx..]; + + // 判断包含关系 + let ab_include = (fx_a.high > fx_b.high && fx_a.low < fx_b.low) + || (fx_a.high < fx_b.high && fx_a.low > fx_b.low); + + // todo + let min_bi_len = 6; + // 检查成笔条件 + if !ab_include && bars_a.len() >= min_bi_len { + let fxs_filtered: Vec<_> = fxs + .iter() + .filter(|x| x.dt >= start_dt && x.dt <= end_dt) + .cloned() + .collect(); + + let bi = BIBuilder::default() + .symbol(fx_a.symbol.clone()) + .fx_a(fx_a.clone()) + .fx_b(fx_b.clone()) + .fxs(fxs_filtered) + .direction(direction) + .bars( + bars_a + .iter() + .map(|b| b.as_ref().to_owned()) + .collect::>(), + ) + .build() + .unwrap(); + + (Some(bi), bars_b) + } else { + (None, bars) + } +} + +/// # 格式化标准K线数据为 CZSC 标准数据结构 RawBar 列表 +/// +/// ## 参数 +/// +/// * `df` - 标准K线数据,DataFrame结构。每一行包含以下字段: +/// - `dt`: 时间 +/// - `symbol`: 股票代码 +/// - `open`: 开盘价 +/// - `close`: 收盘价 +/// - `high`: 最高价 +/// - `low`: 最低价 +/// - `vol`: 成交量 +/// - `amount`: 成交金额 +/// +/// * `freq` - K线级别 +/// +/// ## 返回值 +/// 返回一个 `RawBar` 列表,表示格式化后的K线数据。 +/// +/// ## 示例 +/// ```ignore +/// let df = // 创建 DataFrame 示例数据; +/// let freq = Freq::D; // K线级别 +/// let result = format_standard_kline(df, freq); +/// ``` +/// +pub fn format_standard_kline(df: DataFrame, freq: Freq) -> Result, AnalyzeErorr> { + // 获取各列的 Series 引用 + let symbol_col = df.column("symbol")?.str()?; + let dt_col = df.column("dt")?.datetime()?; + let open_col = df.column("open")?.f64()?; + let close_col = df.column("close")?.f64()?; + let high_col = df.column("high")?.f64()?; + let low_col = df.column("low")?.f64()?; + let vol_col = df.column("vol")?.f64()?; + let amount_col = df.column("amount")?.f64()?; + + // 获取时间单位信息 + let time_unit = dt_col.time_unit(); + + let len = df.height(); + let mut bars = Vec::with_capacity(len); + for i in 0..len { + // 时间戳数值 + let ts = dt_col.get(i).unwrap(); + // 根据时间单位转换为纳秒 + let ns = match time_unit { + TimeUnit::Milliseconds => ts * 1_000_000, + TimeUnit::Microseconds => ts * 1_000, + TimeUnit::Nanoseconds => ts, + }; + let dt_utc = DateTime::::from_timestamp_nanos(ns); + + // TODO 是否检查df有没有Nan? + let bar = RawBarBuilder::default() + .symbol(symbol_col.get(i).unwrap_or("")) + .id(i as i32) + .dt(dt_utc) + .freq(freq) + .open(open_col.get(i).unwrap()) + .close(close_col.get(i).unwrap()) + .high(high_col.get(i).unwrap()) + .low(low_col.get(i).unwrap()) + .vol(vol_col.get(i).unwrap()) + .amount(amount_col.get(i).unwrap()) + .build() + .context("Failed to create raw bar")?; + + bars.push(bar); + } + + Ok(bars) +} diff --git a/crates/czsc-core/src/lib.rs b/crates/czsc-core/src/lib.rs new file mode 100644 index 000000000..95034486f --- /dev/null +++ b/crates/czsc-core/src/lib.rs @@ -0,0 +1,11 @@ +//! czsc-core —缠论 core analyzer (FX / BI / ZS / CZSC). +//! +//! Migrated from rs-czsc 47ef6efa. Submodules are added incrementally as +//! Phase D progresses; see docs/superpowers/plans/2026-05-03-rust-czsc-migration.md. + +pub mod analyze; +pub mod objects; +pub mod utils; + +#[cfg(feature = "python")] +pub mod python; diff --git a/crates/czsc-core/src/objects/bar.rs b/crates/czsc-core/src/objects/bar.rs new file mode 100644 index 000000000..74ac38367 --- /dev/null +++ b/crates/czsc-core/src/objects/bar.rs @@ -0,0 +1,690 @@ +use super::freq::Freq; +use chrono::{DateTime, Utc}; +use derive_builder::Builder; +#[cfg(feature = "python")] +use parking_lot::RwLock; +#[cfg(feature = "python")] +use pyo3::{Py, types::PyDict}; + +use std::sync::Arc; + +use crate::utils::common::freq_to_chinese_string; +#[cfg(feature = "python")] +use crate::utils::common::{create_naive_pandas_timestamp, parse_python_datetime}; +#[cfg(feature = "python")] +use pyo3::basic::CompareOp; +#[cfg(feature = "python")] +use pyo3::types::PyDictMethods; +#[cfg(feature = "python")] +use pyo3::{Bound, PyAny, PyResult, Python, pyclass, pymethods}; +#[cfg(feature = "python")] +use pyo3::{IntoPyObject, PyObject}; +#[cfg(feature = "python")] +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; + +// 数据不会被修改,只需要共享一个只读视图 +pub type Symbol = Arc; + +/// 原始K线元素 +#[cfg_attr(feature = "python", gen_stub_pyclass)] +#[cfg_attr(feature = "python", pyclass(module = "czsc._native"))] +#[derive(Debug, Clone, Builder)] +#[builder(setter(into), pattern = "owned")] +pub struct RawBar { + pub symbol: Symbol, + pub dt: DateTime, + #[builder(default = "Freq::Tick")] + pub freq: Freq, + /// id 必须是升序 + pub id: i32, + pub open: f64, + pub close: f64, + pub high: f64, + pub low: f64, + pub vol: f64, + pub amount: f64, + + #[cfg(feature = "python")] + #[builder(default = "Arc::new(RwLock::new(None))")] + pub cache: Arc>>>, +} + +impl RawBar { + /// 上影 + pub fn upper(&self) -> f64 { + self.high - self.open.max(self.close) + } + + /// 下影 + pub fn lower(&self) -> f64 { + self.open.min(self.close) - self.low + } + + /// 实体 + pub fn solid(&self) -> f64 { + (self.open - self.close).abs() + } +} + +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", gen_stub_pymethods)] +#[cfg_attr(feature = "python", pymethods)] +impl RawBar { + #[new] + #[pyo3(signature = (symbol, dt, freq, open, close, high, low, vol, amount, id=0))] + #[allow(clippy::too_many_arguments)] + fn new( + _py: Python, + symbol: &str, + dt: &Bound, + freq: Freq, + open: f64, + close: f64, + high: f64, + low: f64, + vol: f64, + amount: f64, + id: i32, + ) -> PyResult { + // 使用通用的日期时间解析函数 + let datetime_utc = parse_python_datetime(dt)?; + + Ok(RawBar { + symbol: symbol.into(), + dt: datetime_utc, + freq, + id, + open, + close, + high, + low, + vol, + amount, + cache: Arc::new(RwLock::new(None)), + }) + } + + #[getter] + fn symbol(&self) -> String { + self.symbol.to_string() + } + + #[getter] + fn dt(&self, py: Python) -> PyResult { + create_naive_pandas_timestamp(py, self.dt) + } + + #[getter] + fn freq(&self) -> Freq { + self.freq + } + + #[getter] + fn id(&self) -> i32 { + self.id + } + + #[getter] + fn open(&self) -> f64 { + self.open + } + + #[getter] + fn close(&self) -> f64 { + self.close + } + + #[getter] + fn high(&self) -> f64 { + self.high + } + + #[getter] + fn low(&self) -> f64 { + self.low + } + + #[getter] + fn vol(&self) -> f64 { + self.vol + } + + #[getter] + fn amount(&self) -> f64 { + self.amount + } + + /// 实体部分(与原版CZSC兼容) + #[getter(solid)] + fn solid_py(&self) -> f64 { + self.solid() + } + + /// 上影线长度(与原版CZSC兼容) + #[getter(upper)] + fn upper_py(&self) -> f64 { + self.upper() + } + + /// 下影线长度(与原版CZSC兼容) + #[getter(lower)] + fn lower_py(&self) -> f64 { + self.lower() + } + + #[getter] + fn get_cache<'py>(&'py self, py: Python<'py>) -> Py { + // 首先尝试读锁获取缓存 + { + let cache_read = self.cache.read(); + if let Some(ref cached_dict) = *cache_read { + return cached_dict.clone_ref(py); + } + } + + // 如果缓存为空,使用写锁初始化并填充所有属性 + let mut cache_write = self.cache.write(); + if cache_write.is_none() { + let dict = PyDict::new(py); + // 一次性填充所有属性,避免重复创建 + dict.set_item("symbol", self.symbol.as_ref()).unwrap(); + dict.set_item("dt", create_naive_pandas_timestamp(py, self.dt).unwrap()) + .unwrap(); + dict.set_item("freq", freq_to_chinese_string(self.freq)) + .unwrap(); + dict.set_item("id", self.id).unwrap(); + dict.set_item("open", self.open).unwrap(); + dict.set_item("close", self.close).unwrap(); + dict.set_item("high", self.high).unwrap(); + dict.set_item("low", self.low).unwrap(); + dict.set_item("vol", self.vol).unwrap(); + dict.set_item("amount", self.amount).unwrap(); + dict.set_item("solid", self.solid()).unwrap(); + dict.set_item("upper", self.upper()).unwrap(); + dict.set_item("lower", self.lower()).unwrap(); + *cache_write = Some(dict.unbind()); + } + cache_write.as_ref().unwrap().clone_ref(py) + } + + #[setter] + #[gen_stub(skip)] // 跳过为了防止和 get_cache重复 + fn set_cache(&self, dict: Py) { + let mut cache_write = self.cache.write(); + *cache_write = Some(dict); + } + + /// 直接支持 __dict__ 属性,让 pandas DataFrame() 能正确识别对象 + #[getter] + pub fn __dict__(&self, py: Python) -> PyResult { + // 直接返回缓存的字典,避免重复创建 + Ok(self.get_cache(py).into()) + } + + /// 让对象表现得像记录,pandas DataFrame构造器会调用这个 + fn _asdict(&self, py: Python) -> PyResult { + self.__dict__(py) + } + + /// 转换为字典,便于创建 pandas DataFrame + fn to_dict(&self, py: Python) -> PyResult { + self.__dict__(py) + } + + /// 支持pickle序列化 + fn __reduce__(&self, py: Python) -> PyResult { + // RawBar.new takes `freq: Freq` (the PyO3 enum), not a string — + // pass the enum directly through pickle so the unpickle path + // (`RawBar(*args)`) succeeds. Stringifying with + // `freq_to_chinese_string` here would force the constructor to + // accept str|Freq and silently change the public API. + let cls = py.get_type::(); + let args = ( + self.symbol.as_ref(), + create_naive_pandas_timestamp(py, self.dt)?, + self.freq, + self.open, + self.close, + self.high, + self.low, + self.vol, + self.amount, + self.id, + ); + Ok(( + cls.into_any().unbind(), + args.into_pyobject(py)?.into_any().unbind(), + ) + .into_pyobject(py)? + .into_any() + .unbind()) + } + + /// 支持深拷贝 + fn __deepcopy__(&self, _memo: &Bound) -> PyResult { + Ok(self.clone()) + } + + fn __repr__(&self) -> String { + format!( + "RawBar(symbol={}, dt={}, freq={:?}, id={}, open={}, close={}, high={}, low={}, vol={}, amount={})", + self.symbol, + self.dt.format("%Y-%m-%d %H:%M:%S"), + self.freq, + self.id, + self.open, + self.close, + self.high, + self.low, + self.vol, + self.amount + ) + } + + fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult { + match op { + CompareOp::Eq => Ok(self == other), + CompareOp::Ne => Ok(self != other), + _ => Ok(false), + } + } +} + +impl PartialEq for RawBar { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + && self.symbol == other.symbol + && self.dt == other.dt + && self.freq == other.freq + && self.open == other.open + && self.close == other.close + && self.high == other.high + && self.low == other.low + && self.vol == other.vol + && self.amount == other.amount + } +} + +/// 去除包含关系后的K线元素 +#[cfg_attr(feature = "python", gen_stub_pyclass)] +#[cfg_attr(feature = "python", pyclass(module = "czsc._native"))] +#[derive(Debug, Clone, Builder)] +#[builder(setter(into), pattern = "owned")] +pub struct NewBar { + pub symbol: Symbol, + pub dt: DateTime, + #[builder(default = "Freq::Tick")] + pub freq: Freq, + /// id 必须是升序 + pub id: i32, + pub open: f64, + pub close: f64, + pub high: f64, + pub low: f64, + pub vol: f64, + pub amount: f64, + /// 存入具有包含关系的原始K线 + #[builder(default = "Vec::new()")] + pub elements: Vec, + + #[cfg(feature = "python")] + #[builder(default = "Arc::new(RwLock::new(None))")] + pub cache: Arc>>>, +} + +impl AsRef for NewBar { + fn as_ref(&self) -> &NewBar { + self + } +} + +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", gen_stub_pymethods)] +#[cfg_attr(feature = "python", pymethods)] +impl NewBar { + #[new] + #[pyo3(signature = (symbol, dt, freq, open, close, high, low, vol, amount, id=0, elements=None))] + #[allow(clippy::too_many_arguments)] + fn new( + _py: Python, + symbol: &str, + dt: &Bound, + freq: Freq, + open: f64, + close: f64, + high: f64, + low: f64, + vol: f64, + amount: f64, + id: i32, + elements: Option>, + ) -> PyResult { + // 使用通用的日期时间解析函数 + let datetime_utc = parse_python_datetime(dt)?; + + Ok(NewBar { + symbol: symbol.into(), + dt: datetime_utc, + freq, + id, + open, + close, + high, + low, + vol, + amount, + elements: elements.unwrap_or_default(), + cache: Arc::new(RwLock::new(None)), + }) + } + + #[getter] + fn symbol(&self) -> String { + self.symbol.to_string() + } + + #[getter] + fn dt(&self, py: Python) -> PyResult { + create_naive_pandas_timestamp(py, self.dt) + } + + #[getter] + fn freq(&self) -> Freq { + self.freq + } + + #[getter] + fn id(&self) -> i32 { + self.id + } + + #[getter] + fn open(&self) -> f64 { + self.open + } + + #[getter] + fn close(&self) -> f64 { + self.close + } + + #[getter] + fn high(&self) -> f64 { + self.high + } + + #[getter] + fn low(&self) -> f64 { + self.low + } + + #[getter] + fn vol(&self) -> f64 { + self.vol + } + + #[getter] + fn amount(&self) -> f64 { + self.amount + } + #[getter] + fn get_cache<'py>(&'py self, py: Python<'py>) -> Py { + // 首先尝试读锁获取缓存 + { + let cache_read = self.cache.read(); + if let Some(ref cached_dict) = *cache_read { + return cached_dict.clone_ref(py); + } + } + + // 如果缓存为空,使用写锁初始化并填充所有属性 + let mut cache_write = self.cache.write(); + if cache_write.is_none() { + let dict = PyDict::new(py); + // 一次性填充所有属性,避免重复创建 + dict.set_item("symbol", self.symbol.as_ref()).unwrap(); + dict.set_item("dt", create_naive_pandas_timestamp(py, self.dt).unwrap()) + .unwrap(); + dict.set_item("freq", freq_to_chinese_string(self.freq)) + .unwrap(); + dict.set_item("id", self.id).unwrap(); + dict.set_item("open", self.open).unwrap(); + dict.set_item("close", self.close).unwrap(); + dict.set_item("high", self.high).unwrap(); + dict.set_item("low", self.low).unwrap(); + dict.set_item("vol", self.vol).unwrap(); + dict.set_item("amount", self.amount).unwrap(); + // 计算solid/upper/lower而不是调用方法 + dict.set_item("solid", (self.open - self.close).abs()) + .unwrap(); + dict.set_item("upper", self.high - self.open.max(self.close)) + .unwrap(); + dict.set_item("lower", self.open.min(self.close) - self.low) + .unwrap(); + dict.set_item("elements", py.None()).unwrap(); // 复杂对象先设为None + *cache_write = Some(dict.unbind()); + } + cache_write.as_ref().unwrap().clone_ref(py) + } + + #[setter] + #[gen_stub(skip)] // 跳过为了防止和 get_cache重复 + fn set_cache(&self, dict: Py) { + let mut cache_write = self.cache.write(); + *cache_write = Some(dict); + } + #[getter] + fn elements(&self) -> Vec { + self.elements.to_vec() + } + + /// 获取构成NewBar的原始K线列表(与elements相同,为兼容czsc库) + #[getter] + fn raw_bars(&self) -> Vec { + self.elements() + } + + fn __repr__(&self) -> String { + format!( + "NewBar(symbol={}, dt={}, freq={:?}, id={}, open={}, close={}, high={}, low={}, vol={}, amount={})", + self.symbol, + self.dt.format("%Y-%m-%d %H:%M:%S"), + self.freq, + self.id, + self.open, + self.close, + self.high, + self.low, + self.vol, + self.amount + ) + } + + fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult { + match op { + CompareOp::Eq => Ok(self == other), + CompareOp::Ne => Ok(self != other), + _ => Ok(false), + } + } +} + +impl PartialEq for NewBar { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + && self.symbol == other.symbol + && self.dt == other.dt + && self.freq == other.freq + && self.open == other.open + && self.close == other.close + && self.high == other.high + && self.low == other.low + && self.vol == other.vol + && self.amount == other.amount + && self.elements == other.elements + } +} + +impl NewBar { + /// 创建新K线的辅助函数 + /// + /// 出现error的可能性比较小 + pub fn new_from_raw(bar: &RawBar) -> Self { + #[cfg(feature = "python")] + { + NewBarBuilder::default() + .symbol(bar.symbol.clone()) + .id(bar.id) + .freq(bar.freq) + .dt(bar.dt) + .open(bar.open) + .close(bar.close) + .high(bar.high) + .low(bar.low) + .vol(bar.vol) + .amount(bar.amount) + .elements(vec![bar.clone()]) + .cache(Arc::new(RwLock::new(None))) + .build() + .unwrap() + } + #[cfg(not(feature = "python"))] + { + NewBarBuilder::default() + .symbol(bar.symbol.clone()) + .id(bar.id) + .freq(bar.freq) + .dt(bar.dt) + .open(bar.open) + .close(bar.close) + .high(bar.high) + .low(bar.low) + .vol(bar.vol) + .amount(bar.amount) + .elements(vec![bar.clone()]) + .build() + .unwrap() + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const TEST: &str = "test"; + + #[test] + fn test_raw_bar_new() { + let bar = RawBarBuilder::default() + .symbol(Arc::from(TEST)) + .dt(Utc::now()) + .id(0) + .open(0) + .close(0) + .high(0) + .low(0) + .vol(0) + .amount(0) + .build() + .unwrap(); + assert_eq!(bar.freq, Freq::Tick); + } + + #[test] + fn test_raw_bar_calculations() { + // 测试上涨的情况 + let bar = RawBarBuilder::default() + .symbol(Arc::from(TEST)) + .dt(Utc::now()) + .id(0) + .open(10) + .close(12) + .high(15) + .low(8) + .vol(0) + .amount(0) + .build() + .unwrap(); + + // 15 - 12 = 3 + assert_eq!(bar.upper(), 3.0); + // 10 - 8 = 2 + assert_eq!(bar.lower(), 2.0); + // 10 - 12 = 2 + assert_eq!(bar.solid(), 2.0); + + // 测试下跌的情况 + let bar = RawBarBuilder::default() + .symbol(Arc::from(TEST)) + .dt(Utc::now()) + .id(0) + .open(12) + .close(10) + .high(15) + .low(8) + .vol(0) + .amount(0) + .build() + .unwrap(); + + // 15 - 12 = 3 + assert_eq!(bar.upper(), 3.0); + // 10 - 8 = 2 + assert_eq!(bar.lower(), 2.0); + + // |12 - 10| = 2 + assert_eq!(bar.solid(), 2.0); + } + + #[test] + fn test_new_bar() { + let bar = NewBarBuilder::default() + .symbol(Arc::from(TEST)) + .dt(Utc::now()) + .id(0) + .open(0) + .close(0) + .high(0) + .low(0) + .vol(0) + .amount(0) + .build() + .unwrap(); + assert_eq!(bar.freq, Freq::Tick); + } + + #[test] + fn test_new_bar_with_elements() { + let mut new_bar = NewBarBuilder::default() + .symbol(Arc::from(TEST)) + .dt(Utc::now()) + .id(0) + .open(0) + .close(0) + .high(0) + .low(0) + .vol(0) + .amount(0) + .build() + .unwrap(); + + let raw_bar1 = RawBarBuilder::default() + .symbol(Arc::from(TEST)) + .id(1) + .dt(Utc::now()) + .freq(Freq::Tick) + .open(10) + .close(12) + .high(15) + .low(8) + .vol(100) + .amount(1000) + .build() + .unwrap(); + + new_bar.elements.push(raw_bar1.clone()); + assert_eq!(new_bar.elements.len(), 1); + assert_eq!(new_bar.elements[0].id, 1); + } +} diff --git a/crates/czsc-core/src/objects/bi.rs b/crates/czsc-core/src/objects/bi.rs new file mode 100644 index 000000000..758fe92e6 --- /dev/null +++ b/crates/czsc-core/src/objects/bi.rs @@ -0,0 +1,604 @@ +use super::{ + bar::{NewBar, RawBar, Symbol}, + direction::Direction, + fake_bi::{FakeBI, create_fake_bis}, + fx::FX, +}; +use crate::utils::{corr::LinearRegression, rounded::RoundToNthDigit}; +use chrono::{DateTime, Utc}; +use derive_builder::Builder; +#[cfg(feature = "python")] +use parking_lot::RwLock; +use std::sync::Arc; + +#[cfg(feature = "python")] +use crate::utils::common::{create_naive_pandas_timestamp, create_ordered_dict}; +#[cfg(feature = "python")] +use pyo3::basic::CompareOp; +#[cfg(feature = "python")] +use pyo3::types::{PyDict, PyDictMethods}; +#[cfg(feature = "python")] +use pyo3::{Py, PyObject, PyResult, Python}; +#[cfg(feature = "python")] +use pyo3::{pyclass, pymethods}; +#[cfg(feature = "python")] +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; + +/// 笔 +#[cfg_attr(feature = "python", gen_stub_pyclass)] +#[cfg_attr(feature = "python", pyclass(module = "czsc._native"))] +#[derive(Debug, Clone, Builder)] +#[builder(setter(into))] +pub struct BI { + pub symbol: Symbol, + /// 笔开始的分型 + pub fx_a: FX, + /// 笔结束的分型 + pub fx_b: FX, + /// 笔内部的分型列表 + pub fxs: Vec, + pub direction: Direction, + pub bars: Vec, + #[cfg(feature = "python")] + #[builder(default = "Arc::new(RwLock::new(None))")] + pub cache: Arc>>>, +} + +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", gen_stub_pymethods)] +#[cfg_attr(feature = "python", pymethods)] +impl BI { + #[new] + fn new( + symbol: String, + direction: Direction, + fx_a: FX, + fx_b: FX, + fxs: Vec, + bars: Vec, + ) -> Self { + BI { + symbol: symbol.into(), + direction, + fx_a, + fx_b, + fxs: fxs.into_iter().collect(), + bars: bars.into_iter().collect(), + cache: Arc::new(RwLock::new(None)), + } + } + + #[getter] + fn symbol(&self) -> String { + self.symbol.to_string() + } + + #[getter] + fn direction(&self) -> Direction { + self.direction + } + + #[getter] + fn high(&self) -> f64 { + self.get_high() + } + + #[getter] + fn low(&self) -> f64 { + self.get_low() + } + #[getter] + fn get_cache<'py>(&'py self, py: Python<'py>) -> Py { + // 首先尝试读锁获取缓存 + { + let cache_read = self.cache.read(); + if let Some(ref cached_dict) = *cache_read { + return cached_dict.clone_ref(py); + } + } + + // 如果缓存为空,使用写锁初始化并填充所有属性 + let mut cache_write = self.cache.write(); + if cache_write.is_none() { + let dict = PyDict::new(py); + // 一次性填充所有属性,避免重复创建 + dict.set_item("symbol", self.symbol.as_ref()).unwrap(); + dict.set_item("direction", self.direction).unwrap(); + dict.set_item("high", self.get_high()).unwrap(); + dict.set_item("low", self.get_low()).unwrap(); + dict.set_item("fx_a", py.None()).unwrap(); // 复杂对象先设为None + dict.set_item("fx_b", py.None()).unwrap(); // 复杂对象先设为None + dict.set_item("bars", py.None()).unwrap(); // 复杂对象先设为None + *cache_write = Some(dict.unbind()); + } + cache_write.as_ref().unwrap().clone_ref(py) + } + + #[setter] + #[gen_stub(skip)] // 跳过为了防止和 get_cache重复 + fn set_cache(&self, dict: Py) { + let mut cache_write = self.cache.write(); + *cache_write = Some(dict); + } + + /// 直接支持 __dict__ 属性,让 pandas DataFrame() 能正确识别对象 + #[getter] + pub fn __dict__(&self, py: Python) -> PyResult { + // 直接返回缓存的字典,避免重复创建 + Ok(self.get_cache(py).into()) + } + + #[getter] + fn sdt(&self, py: Python) -> PyResult { + create_naive_pandas_timestamp(py, self.start_dt()) + } + + #[getter] + fn edt(&self, py: Python) -> PyResult { + create_naive_pandas_timestamp(py, self.end_dt()) + } + + #[getter] + fn fx_a(&self) -> FX { + self.fx_a.clone() + } + + #[getter] + fn fx_b(&self) -> FX { + self.fx_b.clone() + } + + #[getter] + fn fxs(&self) -> Vec { + self.fxs.to_vec() + } + + /// 获取构成笔的NewBar列表 + #[getter] + fn bars(&self) -> Vec { + self.bars.to_vec() + } + + /// 价差力度 + #[getter] + fn power(&self) -> f64 { + self.get_power() + } + + /// 价差力度(别名) + #[getter] + fn power_price(&self) -> f64 { + self.get_power_price() + } + + /// 成交量力度 + #[getter] + fn power_volume(&self) -> f64 { + self.get_power_volume() + } + + /// SNR 度量力度 + #[getter] + fn power_snr(&self) -> f64 { + self.get_power_snr() + } + + /// 笔的涨跌幅 + #[getter] + fn change(&self) -> f64 { + self.get_change() + } + + /// 笔内部的信噪比 + #[allow(non_snake_case)] + #[getter] + fn SNR(&self) -> f64 { + self.get_snr() + } + + /// 笔内部高低点之间的斜率 + #[getter] + fn slope(&self) -> f64 { + self.get_slope() + } + + /// 笔内部价格的加速度 + #[getter] + fn acceleration(&self) -> f64 { + self.get_acceleration() + } + + /// 笔的无包含关系K线数量 + #[getter] + fn length(&self) -> usize { + self.get_length() + } + + /// 笔的原始K线close单变量线性回归拟合优度 + #[getter] + fn rsq(&self) -> f64 { + self.get_rsq() + } + + /// 笔的斜边长度 + #[getter] + fn hypotenuse(&self) -> f64 { + self.get_hypotenuse() + } + + /// 笔的斜边与竖直方向的夹角,角度越大,力度越大 + #[getter] + fn angle(&self) -> f64 { + self.get_angle() + } + + /// 构成笔的原始K线序列,不包含首尾分型的首根K线 + #[getter] + fn raw_bars(&self) -> Vec { + self.get_raw_bars().into_iter().collect() + } + + /// 笔的内部分型连接得到近似次级别笔列表 + #[getter] + fn fake_bis(&self) -> Vec { + self.create_fake_bis().into_iter().collect() + } + + /// 缓存字典(与 czsc 库兼容) + #[getter] + fn cache(&self, py: Python) -> PyResult { + create_ordered_dict(py) + } + + /// 获取缓存值,如果不存在则返回默认值(与 czsc 库兼容) + fn get_cache_with_default(&self, _key: &str, default_value: f64) -> f64 { + default_value // 暂时返回默认值,因为我们的缓存是空的 + } + + /// 获取线性价格(与 czsc 库兼容) + fn get_price_linear(&self, n: usize) -> f64 { + // 简单实现:基于笔的高低点进行线性插值 + if n == 0 { + if matches!(self.direction, Direction::Up) { + self.low() + } else { + self.high() + } + } else if matches!(self.direction, Direction::Up) { + self.high() + } else { + self.low() + } + } + + fn __repr__(&self) -> String { + format!( + "BI(symbol={}, sdt={}, edt={}, direction={:?}, high={}, low={})", + self.symbol, + self.start_dt().format("%Y-%m-%d %H:%M:%S"), + self.end_dt().format("%Y-%m-%d %H:%M:%S"), + self.direction, + self.high(), + self.low() + ) + } + + fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult { + match op { + CompareOp::Eq => Ok(self == other), + CompareOp::Ne => Ok(self != other), + _ => Ok(false), + } + } +} + +impl PartialEq for BI { + fn eq(&self, other: &Self) -> bool { + self.symbol == other.symbol + && self.fx_a == other.fx_a + && self.fx_b == other.fx_b + && self.fxs == other.fxs + && self.direction == other.direction + && self.bars == other.bars + } +} + +impl BI { + pub fn start_dt(&self) -> DateTime { + self.fx_a.dt + } + + pub fn end_dt(&self) -> DateTime { + self.fx_b.dt + } + + /// 笔的内部分型连接得到近似次级别笔列表 + pub fn create_fake_bis(&self) -> Vec { + create_fake_bis(&self.fxs) + } + + pub fn get_high(&self) -> f64 { + self.fx_a.high.max(self.fx_b.high) + } + + pub fn get_low(&self) -> f64 { + self.fx_a.low.min(self.fx_b.low) + } + + /// 价差力度 + pub fn get_power_price(&self) -> f64 { + // 保留2位小数 + (self.fx_b.fx - self.fx_a.fx).abs().round_to_2_digit() + } + + /// 价差力度 + pub fn get_power(&self) -> f64 { + self.get_power_price() + } + + /// 成交量力度 + pub fn get_power_volume(&self) -> f64 { + if self.bars.len() <= 2 { + return 0.0; + } + // sum([x.vol for x in self.bars[1:-1]]) + self.bars[1..self.bars.len() - 1] + .iter() + .map(|x| x.vol) + .sum() + } + + /// SNR 度量力度 + /// SNR越大,说明内部走势越顺畅,力度也就越大 + pub fn get_power_snr(&self) -> f64 { + // return round(self.SNR, 4) + (self.get_snr() * 10000.0).round() / 10000.0 + } + + /// 笔的涨跌幅 + pub fn get_change(&self) -> f64 { + // 防止除以0 + if self.fx_a.fx == 0.0 { + return 0.0; + } + // (结束分型 - 开始分型) / 开始分型,保留4位小数 + ((self.fx_b.fx - self.fx_a.fx) / self.fx_a.fx).round_to_4_digit() + } + + /// 笔内部的信噪比 + pub fn get_snr(&self) -> f64 { + let raw_bars = self.get_raw_bars(); + let n = raw_bars.len(); + + match n { + 0 => 0.0, + 1 => { + let bar = &raw_bars[0]; + (bar.close - bar.open).abs() + } + _ => { + // 首尾变化的绝对值 - 按照Python版本逻辑 + let total_change = (raw_bars[n - 1].close - raw_bars[0].open).abs(); + // 每根K线开收价差的绝对值之和 + let diff_abs_change = raw_bars + .iter() + .fold(0.0, |sum, bar| sum + (bar.close - bar.open).abs()); + + if diff_abs_change == 0.0 { + 0.0 + } else { + total_change / diff_abs_change + } + } + } + } + + /// 笔内部高低点之间的斜率 + pub fn get_slope(&self) -> f64 { + let raw_bars = self.get_raw_bars(); + let closes: Vec = raw_bars.iter().map(|bar| bar.close).collect(); + + if closes.len() < 2 { + return 0.0; + } + + let n = closes.len() as f64; + let x_mean = (n - 1.0) / 2.0; + let y_mean = closes.iter().sum::() / n; + + let numerator: f64 = closes + .iter() + .enumerate() + .map(|(i, y)| { + let x = i as f64; + (x - x_mean) * (y - y_mean) + }) + .sum(); + + let denominator: f64 = (0..closes.len()) + .map(|i| { + let x = i as f64; + (x - x_mean).powi(2) + }) + .sum(); + + if denominator == 0.0 { + 0.0 + } else { + numerator / denominator + } + } + + /// 笔内部价格的加速度 + /// + /// 负号表示开口向下;正号表示开口向上。数值越大,表示加速度越大。 + pub fn get_acceleration(&self) -> f64 { + let raw_bars = self.get_raw_bars(); + let closes: Vec = raw_bars.iter().map(|bar| bar.close).collect(); + + if closes.len() < 3 { + return 0.0; + } + + // 使用与Python numpy.polyfit(degree=2)兼容的二次多项式拟合 + // 返回二次项系数 (a in ax² + bx + c) + self.numpy_compatible_quadratic_fit(&closes) + } + + /// numpy兼容的二次多项式拟合 + /// 返回二次项系数,与Python的numpy.polyfit(range(len(c)), c, 2)[0]保持一致 + fn numpy_compatible_quadratic_fit(&self, y_values: &[f64]) -> f64 { + let n = y_values.len() as f64; + + // 构建设计矩阵 X 和目标向量 y + // 对于二次拟合: y = a*x² + b*x + c + // 矩阵形式: [x² x 1] * [a; b; c] = y + + let mut sum_x4 = 0.0; + let mut sum_x3 = 0.0; + let mut sum_x2 = 0.0; + let mut sum_x = 0.0; + let sum_1 = n; + let mut sum_x2_y = 0.0; + let mut sum_x_y = 0.0; + let mut sum_y = 0.0; + + for (i, &y) in y_values.iter().enumerate() { + let x = i as f64; + let x2 = x * x; + let x3 = x2 * x; + let x4 = x3 * x; + + sum_x4 += x4; + sum_x3 += x3; + sum_x2 += x2; + sum_x += x; + sum_x2_y += x2 * y; + sum_x_y += x * y; + sum_y += y; + } + + // 正规方程组: + // [sum_x4 sum_x3 sum_x2] [a] [sum_x2_y] + // [sum_x3 sum_x2 sum_x ] [b] = [sum_x_y ] + // [sum_x2 sum_x sum_1 ] [c] [sum_y ] + + // 使用克莱默法则求解 a (二次项系数) + let det = sum_x4 * (sum_x2 * sum_1 - sum_x * sum_x) + - sum_x3 * (sum_x3 * sum_1 - sum_x * sum_x2) + + sum_x2 * (sum_x3 * sum_x - sum_x2 * sum_x2); + + if det.abs() < 1e-10 { + return 0.0; + } + + let det_a = sum_x2_y * (sum_x2 * sum_1 - sum_x * sum_x) + - sum_x_y * (sum_x3 * sum_1 - sum_x * sum_x2) + + sum_y * (sum_x3 * sum_x - sum_x2 * sum_x2); + + det_a / det + } + + /// 笔的无包含关系K线数量 + pub fn get_length(&self) -> usize { + self.bars.len() + } + + /// 笔的原始K线close单变量线性回归 拟合优度 + pub fn get_rsq(&self) -> f64 { + let raw_bars = self.get_raw_bars(); + let closes: Vec = raw_bars.iter().map(|bar| bar.close).collect(); + + if closes.is_empty() { + return 0.0; + } + + let res = closes.single_linear(); + // 保留4位小数 + (res.r2 * 10000.0).round() / 10000.0 + } + + /// 构成笔的原始K线序列,不包含首尾分型的首根K线 + pub fn get_raw_bars(&self) -> Vec { + if self.bars.len() > 2 { + let capacity = self.bars[1..self.bars.len() - 1] + .iter() + .map(|bar| bar.elements.len()) + .sum(); + + let mut value = Vec::with_capacity(capacity); + for bar in &self.bars[1..self.bars.len() - 1] { + value.extend_from_slice(&bar.elements); + } + value + } else { + Vec::new() + } + } + + /// 笔的斜边长度 + pub fn get_hypotenuse(&self) -> f64 { + (self.get_power_price().powi(2) + (self.get_raw_bars().len() as f64).powi(2)).sqrt() + } + + /// 笔的斜边与竖直方向的夹角,角度越大,力度越大 + pub fn get_angle(&self) -> f64 { + let angle_rad = (self.get_power_price() / self.get_hypotenuse()).asin(); + let angle_deg = angle_rad * 180.0 / std::f64::consts::PI; + (angle_deg * 100.0).round() / 100.0 + } +} + +pub fn print_bi(bis: &Vec) { + println!( + "{:<10} {:<12} {:<12} {:>6} {:<8}", + "Direction", "FX_A (Mark)", "FX_B (Mark)", "FXs", "Bars" + ); + + println!("{:-<10} {:-<12} {:-<12} {:-<6} {:-<8}", "", "", "", "", ""); + + // 数据行 + for bi in bis { + let dir_icon = match bi.direction { + Direction::Up => "↑", + Direction::Down => "↓", + }; + + println!( + "{:<10} {:<12} {:<12} {:>6} {:>4} bars", + dir_icon, + bi.fx_a.mark, + bi.fx_b.mark, + bi.fxs.len(), + bi.bars.len() + ); + } +} + +#[cfg(test)] +pub mod tests { + use super::*; + use crate::objects::fx::tests::create_d_fx; + use std::sync::Arc; + + /// 创建一个测试用的笔 + pub fn create_bi() -> BI { + let fx_a = create_d_fx(); + let fx_b = create_d_fx(); + + BIBuilder::default() + .symbol(Arc::from("TEST".to_string())) + .fx_a(fx_a.clone()) + .fx_b(fx_b.clone()) + .fxs(vec![fx_a.clone(), fx_b]) + .direction(Direction::Up) + .bars(fx_a.elements) + .build() + .unwrap() + } + + #[test] + fn test_new_bi() { + create_bi(); + } +} diff --git a/crates/czsc-core/src/objects/direction.rs b/crates/czsc-core/src/objects/direction.rs new file mode 100644 index 000000000..95358ef25 --- /dev/null +++ b/crates/czsc-core/src/objects/direction.rs @@ -0,0 +1,160 @@ +#[cfg(feature = "python")] +use pyo3::pyclass; +#[cfg(feature = "python")] +use pyo3::types::PyAnyMethods; +#[cfg(feature = "python")] +use pyo3::{Bound, IntoPyObject, PyAny, PyErr, PyObject, PyResult, Python, pymethods}; +#[cfg(feature = "python")] +use pyo3_stub_gen::derive::gen_stub_pyclass_enum; +#[cfg(feature = "python")] +use pyo3_stub_gen::derive::gen_stub_pymethods; +use strum_macros::{AsRefStr, Display, EnumString}; + +/// 方向 +#[cfg_attr(feature = "python", gen_stub_pyclass_enum)] +#[cfg_attr(feature = "python", pyclass(module = "czsc._native"))] +#[derive(Debug, Clone, Copy, PartialEq, EnumString, AsRefStr, Display)] +pub enum Direction { + /// 向上 + #[strum(serialize = "向上")] + Up, + /// 向下 + #[strum(serialize = "向下")] + Down, +} +#[cfg(feature = "python")] +#[gen_stub_pymethods] +#[cfg(feature = "python")] +#[pymethods] +impl Direction { + /// 支持深拷贝 + fn __deepcopy__(&self, _memo: &Bound) -> PyResult { + Ok(*self) + } + + /// 支持pickle序列化 + fn __reduce__(&self) -> PyResult<(PyObject, PyObject)> { + Python::with_gil(|py| { + let cls = py.get_type::(); + let args = match self { + Direction::Up => ("Up",), + Direction::Down => ("Down",), + }; + Ok((cls.into(), args.into_pyobject(py)?.into_any().unbind())) + }) + } + + #[new] + fn new(value: &str) -> PyResult { + Ok(match value { + "Up" | "向上" => Direction::Up, + "Down" | "向下" => Direction::Down, + _ => { + return Err(PyErr::new::(format!( + "Unknown direction value: {value}" + ))); + } + }) + } + + /// 获取方向的字符串值(与 czsc 库兼容) + #[getter] + fn value(&self) -> &'static str { + match self { + Direction::Up => "向上", + Direction::Down => "向下", + } + } + + fn __str__(&self) -> &'static str { + self.value() + } + + fn __repr__(&self) -> String { + format!( + "Direction.{}", + match self { + Direction::Up => "Up", + Direction::Down => "Down", + } + ) + } + + fn __richcmp__( + &self, + other: pyo3::Bound<'_, pyo3::PyAny>, + op: pyo3::basic::CompareOp, + ) -> pyo3::PyResult { + use pyo3::basic::CompareOp; + match op { + CompareOp::Eq => { + if let Ok(other_dir) = other.extract::() { + return Ok(*self == other_dir); + } + if let Ok(other_value) = other.getattr("value") + && let Ok(other_str) = other_value.extract::() + { + return Ok(self.value() == other_str.as_str()); + } + Ok(false) + } + CompareOp::Ne => { + if let Ok(other_dir) = other.extract::() { + return Ok(*self != other_dir); + } + if let Ok(other_value) = other.getattr("value") + && let Ok(other_str) = other_value.extract::() + { + return Ok(self.value() != other_str.as_str()); + } + Ok(true) + } + _ => Ok(false), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::str::FromStr; + + #[test] + fn test_string_to_direction() { + // 测试从字符串解析为 Direction (EnumString) + assert_eq!(Direction::from_str("向上").unwrap(), Direction::Up); + assert_eq!(Direction::from_str("向下").unwrap(), Direction::Down); + + // 测试无效输入 + assert!(Direction::from_str("向左").is_err()); + } + + #[test] + fn test_direction_to_string() { + // 测试 Display trait + assert_eq!(Direction::Up.to_string(), "向上"); + assert_eq!(Direction::Down.to_string(), "向下"); + + // 测试 AsRefStr trait + assert_eq!(Direction::Up.as_ref(), "向上"); + assert_eq!(Direction::Down.as_ref(), "向下"); + } + + #[test] + fn test_debug_format() { + // 测试 Debug trait + assert_eq!(format!("{:?}", Direction::Up), "Up"); + assert_eq!(format!("{:?}", Direction::Down), "Down"); + } + + #[test] + fn test_clone_and_eq() { + // 测试 Clone 和 PartialEq + let dir1 = Direction::Up; + let dir2 = dir1; + assert_eq!(dir1, dir2); + + let dir3 = Direction::Down; + assert_ne!(dir1, dir3); + } +} diff --git a/crates/czsc-core/src/objects/errors.rs b/crates/czsc-core/src/objects/errors.rs new file mode 100644 index 000000000..87b0093d0 --- /dev/null +++ b/crates/czsc-core/src/objects/errors.rs @@ -0,0 +1,37 @@ +use error_macros::CZSCErrorDerive; +use error_support::expand_error_chain; +use thiserror::Error; + +#[cfg(feature = "python")] +use pyo3::{PyErr, exceptions::PyException}; + +#[derive(Debug, Error, CZSCErrorDerive)] +pub enum ObjectError { + // Factor + #[error("Factor.signals_all must contain at least one signal")] + FactorSignalsAllEmpty, + #[error("Invalid signals array format: {0:?}")] + InvalidSignalsArrayFormat(Option), + + // Signals + #[error("Signal.score {0} must be between 0 and 100")] + ScoreOutOfRange(i32), + #[error("Invalid signal format: {0:?}")] + InvalidSignalsFormat(Option), + #[error("Invalid score format: {0}")] + InvalidScoreFormat(String), + #[error("Signal key '{0}' does not exist in signals collection")] + SignalKeyNotFound(String), + #[error("Invalid signal format: missing underscore separator in '{0}'")] + MalformedSignalValue(String), + + #[error("{}", expand_error_chain(.0))] + Unexpected(anyhow::Error), +} + +#[cfg(feature = "python")] +impl From for PyErr { + fn from(e: ObjectError) -> Self { + PyException::new_err(e.to_string()) + } +} diff --git a/crates/czsc-core/src/objects/event.rs b/crates/czsc-core/src/objects/event.rs new file mode 100644 index 000000000..c7bbf6d69 --- /dev/null +++ b/crates/czsc-core/src/objects/event.rs @@ -0,0 +1,605 @@ +// czsc-only: pyo3 imports + super::operate / super::signal Python wrappers +// gated behind the `python` feature for non-python builds. Sha256 is used in +// the (non-python) Event helpers so it stays unconditional. +// See docs/MIGRATION_NOTES.md §2.4. +#![allow(unused)] +use anyhow::{Context, anyhow}; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use std::cell::Ref; +use std::collections::{HashMap, HashSet}; +use std::str::FromStr; + +use super::operate::Operate; +use super::signal::{ANY, Signal}; + +#[cfg(feature = "python")] +use pyo3::exceptions::PyValueError; +#[cfg(feature = "python")] +use pyo3::prelude::*; +#[cfg(feature = "python")] +use pyo3::types::{PyDict, PyDictMethods}; +#[cfg(feature = "python")] +use super::operate::PyOperate; +#[cfg(feature = "python")] +use super::signal::PySignal; +#[cfg(feature = "python")] +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Event { + pub operate: Operate, + + /// 必须全部满足的信号,允许为空 + pub signals_all: Vec, + + /// 满足其中任一信号,允许为空 + pub signals_any: Vec, + + /// 不能满足其中任一信号,允许为空 + pub signals_not: Vec, + + /// 默认可以为"" + #[serde(default)] + pub name: String, + + /// SHA256哈希 + #[serde(default)] + pub sha256: String, +} + +impl Event { + fn py_repr_list_str(items: &[String]) -> String { + if items.is_empty() { + "[]".to_string() + } else { + let body = items + .iter() + .map(|x| format!("'{}'", x.replace('\'', "\\'"))) + .collect::>() + .join(", "); + format!("[{body}]") + } + } + + fn py_repr_for_hash(&self) -> String { + let operate = self.operate.to_chinese().replace('\'', "\\'"); + let all = self + .signals_all + .iter() + .map(|s| s.to_string()) + .collect::>(); + let any = self + .signals_any + .iter() + .map(|s| s.to_string()) + .collect::>(); + let not = self + .signals_not + .iter() + .map(|s| s.to_string()) + .collect::>(); + // 对齐 Python Event.__post_init__: + // hashlib.sha256(str({'operate':..., 'signals_all':..., 'signals_any':..., 'signals_not':...}).encode()).hexdigest()[:4] + format!( + "{{'operate': '{operate}', 'signals_all': {}, 'signals_any': {}, 'signals_not': {}}}", + Self::py_repr_list_str(&all), + Self::py_repr_list_str(&any), + Self::py_repr_list_str(¬), + ) + } + + #[allow(unused)] + fn new( + operate: Operate, + signals_all: Vec, + signals_any: Vec, + signals_not: Vec, + name: Option, + ) -> anyhow::Result { + let mut event = Self { + operate, + signals_all, + signals_any, + signals_not, + name: name.unwrap_or_default(), + sha256: String::new(), + }; + + event.compute_hash_name(); + Ok(event) + } + + /// 计算 Hash (除了name字段) + pub fn compute_sha8(&self) -> String { + let mut hasher = Sha256::new(); + hasher.update(self.py_repr_for_hash().as_bytes()); + let result = hasher.finalize(); + hex::encode(result)[..4].to_uppercase() + } + + /// 更新名称 + fn compute_hash_name(&mut self) { + let digest = self.compute_sha8(); + self.sha256 = digest.clone(); + + if !self.name.is_empty() { + let base = self + .name + .split('#') + .next() + .map(|s| s.to_string()) + .unwrap_or_else(|| self.name.clone()); + self.name = format!("{base}#{digest}"); + } else { + self.name = format!("{:?}#{}", self.operate, digest); + } + } + + /// 重新计算事件 hash 名称,确保与 Python 端一致(name#HASH) + pub fn refresh_hash_name(&mut self) { + self.compute_hash_name(); + } + + pub fn matches_signals(&self, signals: Ref>) -> bool { + // 创建一个信号映射表,key为信号的key部分,value为信号的value部分 + let mut signal_map: HashMap = HashMap::new(); + for signal in signals.iter() { + signal_map.insert(signal.key(), signal.value()); + } + self.matches_signals_dict(&signal_map) + } + + /// 基于字典的信号匹配逻辑,支持"任意"关键字 + pub fn matches_signals_dict(&self, signal_dict: &HashMap) -> bool { + // 1) signals_not: 任何一个匹配 -> 不满足 + for s_not in &self.signals_not { + if s_not.is_match(signal_dict) { + return false; + } + } + + // 2) signals_all: 必须全部满足 + for s_all in &self.signals_all { + if !s_all.is_match(signal_dict) { + return false; + } + } + + // 3) signals_any: 如果非空,至少满足一个 + if !self.signals_any.is_empty() { + let any_matched = self + .signals_any + .iter() + .any(|s_any| s_any.is_match(signal_dict)); + if !any_matched { + return false; + } + } + + true + } + + /// 序列化为字典 + pub fn dump(&self) -> serde_json::Value { + serde_json::json!({ + "name": self.name, + "operate": self.operate.to_chinese(), + "signals_all": self.signals_all.iter().map(|s| format!("{s}")).collect::>(), + "signals_any": self.signals_any.iter().map(|s| format!("{s}")).collect::>(), + "signals_not": self.signals_not.iter().map(|s| format!("{s}")).collect::>(), + }) + } + + /// 从字典加载 + pub fn load(data: &serde_json::Value) -> anyhow::Result { + let operate_str = data["operate"] + .as_str() + .ok_or_else(|| anyhow!("operate must be string"))?; + + let operate = match operate_str { + "持多" | "HL" => Operate::HL, + "持空" | "HS" => Operate::HS, + "持币" | "HO" => Operate::HO, + "开多" | "LO" => Operate::LO, + "平多" | "LE" => Operate::LE, + "开空" | "SO" => Operate::SO, + "平空" | "SE" => Operate::SE, + _ => return Err(anyhow!("Unknown operate: {operate_str}")), + }; + + let signals_all: Vec = data + .get("signals_all") + .and_then(|v| v.as_array()) + .unwrap_or(&vec![]) + .iter() + .map(|v| v.as_str().unwrap()) + .map(|s| s.parse()) + .collect::, _>>()?; + + let signals_any: Vec = data + .get("signals_any") + .and_then(|v| v.as_array()) + .unwrap_or(&vec![]) + .iter() + .map(|v| v.as_str().unwrap()) + .map(|s| s.parse()) + .collect::, _>>()?; + + let signals_not: Vec = data + .get("signals_not") + .and_then(|v| v.as_array()) + .unwrap_or(&vec![]) + .iter() + .map(|v| v.as_str().unwrap()) + .map(|s| s.parse()) + .collect::, _>>()?; + + let name = data + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + Self::new(operate, signals_all, signals_any, signals_not, Some(name)) + } + + /// 返回一个迭代器,可以依次遍历 signals_all 和 signals_any 和 signals_not + pub fn all_signals(&self) -> impl Iterator { + self.signals_all + .iter() + .chain(self.signals_any.iter().chain(self.signals_not.iter())) + } +} + +#[cfg(feature = "python")] +impl<'py> FromPyObject<'py> for Event { + fn extract_bound(ob: &Bound<'py, pyo3::PyAny>) -> PyResult { + // 接受 dict 形式输入 + if let Ok(dict) = ob.downcast::() { + // 1) operate 必须存在 + let operate = dict + .get_item("operate")? + .ok_or(PyValueError::new_err("缺少字段: 'operate'"))? + .extract::() + .map_err(|e| PyValueError::new_err(format!("无法解析字段 'operate': {e}")))?; + + // 2) name 可选,缺省为空字符串 + let name = match dict.get_item("name")? { + Some(name) => name.extract::().unwrap_or_default(), + None => String::new(), + }; + + // 3) signals + let signals_all = dict + .get_item("signals_all")? + .ok_or(PyValueError::new_err("缺少字段: 'signals_all'"))? + .extract::>()?; + + let signals_any = dict + .get_item("signals_any")? + .ok_or(PyValueError::new_err("缺少字段: 'signals_any'"))? + .extract::>()?; + + let signals_not = dict + .get_item("signals_not")? + .ok_or(PyValueError::new_err("缺少字段: 'signals_not'"))? + .extract::>()?; + + let mut event = Event { + operate, + signals_all, + signals_any, + signals_not, + name, + sha256: String::new(), + }; + + event.compute_hash_name(); + return Ok(event); + } + + // 如果是 str + if let Ok(s) = ob.extract::() { + let event = serde_json::from_str(&s) + .map_err(|e| PyValueError::new_err(format!("str无法反序列化成 Event: {e}")))?; + return Ok(event); + } + + // 非 dict 类型:报错 + Err(PyValueError::new_err( + "期望 dict:{ 'operate': ..., 'signals_all': [...], 'signals_any': [...], 'signals_not': [...], 'name': '...' }", + )) + } +} + +/// Python可见的Event包装器 +#[cfg_attr(feature = "python", gen_stub_pyclass)] +#[cfg_attr(feature = "python", pyclass(name = "Event", module = "czsc._native"))] +#[derive(Debug, Clone)] +pub struct PyEvent { + pub inner: Event, +} + +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", gen_stub_pymethods)] +#[cfg_attr(feature = "python", pymethods)] +impl PyEvent { + #[new] + #[pyo3(signature = (operate, signals_all = vec![], signals_any = vec![], signals_not = vec![], name = String::new()))] + fn new_py( + operate: PyOperate, + signals_all: Vec, + signals_any: Vec, + signals_not: Vec, + name: String, + ) -> PyResult { + let signals_all: Vec = signals_all.into_iter().map(|s| s.inner).collect(); + let signals_any: Vec = signals_any.into_iter().map(|s| s.inner).collect(); + let signals_not: Vec = signals_not.into_iter().map(|s| s.inner).collect(); + + let inner = Event::new( + operate.inner, + signals_all, + signals_any, + signals_not, + Some(name), + ) + .map_err(|e| PyValueError::new_err(format!("创建Event失败: {e}")))?; + Ok(Self { inner }) + } + + #[classmethod] + fn from_dict( + _cls: &Bound<'_, pyo3::types::PyType>, + dict: &Bound<'_, PyDict>, + ) -> PyResult { + let event = Event::extract_bound(dict.as_any())?; + Ok(Self { inner: event }) + } + + #[classmethod] + fn from_json(_cls: &Bound<'_, pyo3::types::PyType>, json_str: String) -> PyResult { + let inner = serde_json::from_str(&json_str) + .map_err(|e| PyValueError::new_err(format!("JSON解析失败: {e}")))?; + Ok(Self { inner }) + } + + #[getter] + fn operate(&self) -> PyOperate { + PyOperate { + inner: self.inner.operate, + } + } + + #[getter] + fn signals_all(&self) -> Vec { + self.inner + .signals_all + .iter() + .map(|s| PySignal { inner: s.clone() }) + .collect() + } + + #[getter] + fn signals_any(&self) -> Vec { + self.inner + .signals_any + .iter() + .map(|s| PySignal { inner: s.clone() }) + .collect() + } + + #[getter] + fn signals_not(&self) -> Vec { + self.inner + .signals_not + .iter() + .map(|s| PySignal { inner: s.clone() }) + .collect() + } + + #[getter] + fn name(&self) -> String { + self.inner.name.clone() + } + + /// 计算SHA8哈希值 + fn compute_sha8(&self) -> String { + self.inner.compute_sha8() + } + + /// 获取所有唯一信号(字符串格式,兼容原Python API) + #[getter] + fn unique_signals(&self) -> Vec { + let mut signals = HashSet::new(); + + // 收集所有信号的字符串表示 + for signal in &self.inner.signals_all { + signals.insert(signal.to_string()); + } + for signal in &self.inner.signals_any { + signals.insert(signal.to_string()); + } + for signal in &self.inner.signals_not { + signals.insert(signal.to_string()); + } + + signals.into_iter().collect() + } + + /// 获取SHA256哈希 + #[getter] + fn sha256(&self) -> String { + self.inner.sha256.clone() + } + + /// 判断事件是否匹配信号集合,返回是否匹配 + /// 支持多种参数类型:Dict[str, str] 或 Dict[str, Signal] 或 Vec + fn is_match(&self, signals: &Bound<'_, pyo3::PyAny>) -> PyResult { + if let Ok(dict) = signals.downcast::() { + // 处理字典输入:转换为HashMap + let mut signal_dict = HashMap::new(); + for (key, value) in dict.iter() { + let key_str = key.extract::()?; + let value_str = if let Ok(signal_obj) = value.extract::() { + // 值是PySignal对象,取其value + signal_obj.inner.value() + } else if let Ok(signal_str) = value.extract::() { + // 值是字符串 + signal_str + } else { + return Err(PyValueError::new_err("字典值必须是Signal对象或字符串")); + }; + signal_dict.insert(key_str, value_str); + } + + // 使用新的基于字典的匹配逻辑 + Ok(self.inner.matches_signals_dict(&signal_dict)) + } else if let Ok(vec) = signals.extract::>() { + // 处理向量输入 - 转换为字典 + let mut signal_dict = HashMap::new(); + for signal in vec { + signal_dict.insert(signal.inner.key(), signal.inner.value()); + } + Ok(self.inner.matches_signals_dict(&signal_dict)) + } else { + Err(PyValueError::new_err( + "参数必须是 Dict[str, str] 或 Dict[str, Signal] 或 List[Signal]", + )) + } + } + + /// 转换为JSON字符串 + fn to_json(&self) -> PyResult { + serde_json::to_string(&self.inner) + .map_err(|e| PyValueError::new_err(format!("JSON序列化失败: {e}"))) + } + + fn __repr__(&self) -> String { + format!( + "PyEvent(operate={:?}, name='{}', signals_all={}, signals_any={}, signals_not={})", + self.inner.operate, + self.inner.name, + self.inner.signals_all.len(), + self.inner.signals_any.len(), + self.inner.signals_not.len() + ) + } + + fn __str__(&self) -> String { + format!( + "Event[{}]: {:?} (all:{}, any:{}, not:{})", + self.inner.name, + self.inner.operate, + self.inner.signals_all.len(), + self.inner.signals_any.len(), + self.inner.signals_not.len() + ) + } + + /// 导出为字典 + fn dump(&self) -> PyResult { + let json_value = self.inner.dump(); + Python::with_gil(|py| { + let dict = pyo3::types::PyDict::new(py); + + dict.set_item("name", json_value["name"].as_str().unwrap_or(""))?; + dict.set_item("operate", json_value["operate"].as_str().unwrap_or(""))?; + + let empty_array = vec![]; + let signals_all: Vec<&str> = json_value["signals_all"] + .as_array() + .unwrap_or(&empty_array) + .iter() + .map(|v| v.as_str().unwrap_or("")) + .collect(); + dict.set_item("signals_all", signals_all)?; + + let signals_any: Vec<&str> = json_value["signals_any"] + .as_array() + .unwrap_or(&empty_array) + .iter() + .map(|v| v.as_str().unwrap_or("")) + .collect(); + dict.set_item("signals_any", signals_any)?; + + let signals_not: Vec<&str> = json_value["signals_not"] + .as_array() + .unwrap_or(&empty_array) + .iter() + .map(|v| v.as_str().unwrap_or("")) + .collect(); + dict.set_item("signals_not", signals_not)?; + + Ok(dict.into()) + }) + } + + /// 从字典加载 + #[classmethod] + fn load(_cls: &Bound<'_, pyo3::types::PyType>, data: &Bound<'_, PyDict>) -> PyResult { + // 转换Python字典为JSON Value + let json_str = Python::with_gil(|py| -> PyResult { + let json_module = py.import("json")?; + let json_str = json_module.call_method1("dumps", (data,))?; + json_str.extract::() + })?; + + let json_value: serde_json::Value = serde_json::from_str(&json_str) + .map_err(|e| PyValueError::new_err(format!("JSON解析失败: {e}")))?; + + let inner = Event::load(&json_value) + .map_err(|e| PyValueError::new_err(format!("Event加载失败: {e}")))?; + Ok(Self { inner }) + } + + /// 获取信号配置 + fn get_signals_config(&self) -> Vec { + self.inner.all_signals().map(|s| s.to_string()).collect() + } + + /// 支持 pickle 序列化 - 使用 __reduce__ 方法 + fn __reduce__(&self, py: Python) -> PyResult { + use super::operate::PyOperate; + use super::signal::PySignal; + use pyo3::IntoPyObject; + + // 构造函数参数 + let operate = PyOperate { + inner: self.inner.operate, + }; + let signals_all: Vec = self + .inner + .signals_all + .iter() + .map(|s| PySignal { inner: s.clone() }) + .collect(); + let signals_any: Vec = self + .inner + .signals_any + .iter() + .map(|s| PySignal { inner: s.clone() }) + .collect(); + let signals_not: Vec = self + .inner + .signals_not + .iter() + .map(|s| PySignal { inner: s.clone() }) + .collect(); + + let args = ( + operate, + signals_all, + signals_any, + signals_not, + self.inner.name.clone(), + ) + .into_pyobject(py)?; + + // 返回 (constructor, args) + let constructor = py.get_type::(); + let result = (constructor, args).into_pyobject(py)?; + Ok(result.into()) + } +} diff --git a/crates/czsc-core/src/objects/fake_bi.rs b/crates/czsc-core/src/objects/fake_bi.rs new file mode 100644 index 000000000..3fb4b9b6e --- /dev/null +++ b/crates/czsc-core/src/objects/fake_bi.rs @@ -0,0 +1,165 @@ +use super::{bar::Symbol, direction::Direction, fx::FX}; +#[cfg(feature = "python")] +use crate::utils::common::create_naive_pandas_timestamp; +use crate::{objects::mark::Mark, utils::rounded::RoundToNthDigit}; +use chrono::{DateTime, Utc}; +use derive_builder::Builder; +#[cfg(feature = "python")] +use parking_lot::Mutex; +#[cfg(feature = "python")] +use pyo3::types::{PyDict, PyDictMethods}; +#[cfg(feature = "python")] +use pyo3::{Py, PyObject, PyResult, Python}; +#[cfg(feature = "python")] +use pyo3::{pyclass, pymethods}; +#[cfg(feature = "python")] +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; +use std::sync::Arc; + +/// 虚拟笔 +/// 主要为笔的内部分析提供便利 +#[cfg_attr(feature = "python", gen_stub_pyclass)] +#[cfg_attr(feature = "python", pyclass(module = "czsc._native"))] +#[derive(Debug, Clone, Builder)] +#[builder(setter(into))] +pub struct FakeBI { + pub symbol: Symbol, + pub sdt: DateTime, + pub edt: DateTime, + pub direction: Direction, + pub high: f64, + pub low: f64, + pub power: f64, + #[cfg(feature = "python")] + #[builder(default = "Arc::new(Mutex::new(None))")] + pub cache: Arc>>>, +} + +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", gen_stub_pymethods)] +#[cfg_attr(feature = "python", pymethods)] +impl FakeBI { + #[getter] + fn symbol(&self) -> String { + self.symbol.to_string() + } + + #[getter] + fn sdt(&self, py: Python) -> PyResult { + create_naive_pandas_timestamp(py, self.sdt) + } + + #[getter] + fn edt(&self, py: Python) -> PyResult { + create_naive_pandas_timestamp(py, self.edt) + } + + #[getter] + fn direction(&self) -> Direction { + self.direction + } + + #[getter] + fn high(&self) -> f64 { + self.high + } + + #[getter] + fn low(&self) -> f64 { + self.low + } + + #[getter] + fn power(&self) -> f64 { + self.power + } + #[getter] + fn get_cache<'py>(&'py mut self, py: Python<'py>) -> Py { + let mut cache = self.cache.lock(); + if cache.is_none() { + let dict = PyDict::new(py); + // 一次性填充所有属性,避免重复创建 + dict.set_item("symbol", self.symbol.as_ref()).unwrap(); + dict.set_item("sdt", create_naive_pandas_timestamp(py, self.sdt).unwrap()) + .unwrap(); + dict.set_item("edt", create_naive_pandas_timestamp(py, self.edt).unwrap()) + .unwrap(); + dict.set_item("direction", self.direction).unwrap(); + dict.set_item("high", self.high).unwrap(); + dict.set_item("low", self.low).unwrap(); + dict.set_item("power", self.power).unwrap(); + *cache = Some(dict.unbind()); + } + cache.as_ref().unwrap().clone_ref(py) + } + + #[setter] + #[gen_stub(skip)] // 跳过为了防止和 get_cache重复 + fn set_cache(&self, dict: Py) { + *self.cache.lock() = Some(dict); + } + + fn __repr__(&self) -> String { + format!( + "FakeBI(symbol={}, sdt={}, edt={}, direction={:?}, high={}, low={}, power={})", + self.symbol, + self.sdt.format("%Y-%m-%d %H:%M:%S"), + self.edt.format("%Y-%m-%d %H:%M:%S"), + self.direction, + self.high, + self.low, + self.power + ) + } +} + +/// 创建 fake_bis 列表 +/// +/// # Arguments +/// +/// * `fxs` - 分型序列,必须顶底分型交替 +/// +/// # Returns +/// +/// * 返回 FakeBI 的 Vec +pub fn create_fake_bis(fxs: &[FX]) -> Vec { + // 如果长度为奇数,移除最后一个元素 + let len = if !fxs.len().is_multiple_of(2) { + fxs.len() - 1 + } else { + fxs.len() + }; + + let mut fake_bis = Vec::new(); + for window in fxs[..len].windows(2) { + let fx1 = &window[0]; + let fx2 = &window[1]; + assert!(fx1.mark != fx2.mark, "相邻分型标记必须不同"); + + let fake_bi = match fx1.mark { + Mark::D => FakeBIBuilder::default() + .symbol(fx1.symbol.clone()) + .sdt(fx1.dt) + .edt(fx2.dt) + .direction(Direction::Up) + .high(fx2.high) + .low(fx1.low) + // 保留2位小数 + .power((fx2.high - fx1.low).round_to_2_digit()) + .build() + .unwrap(), + Mark::G => FakeBIBuilder::default() + .symbol(fx1.symbol.clone()) + .sdt(fx1.dt) + .edt(fx2.dt) + .direction(Direction::Down) + .high(fx1.high) + .low(fx2.low) + .power((fx1.high - fx2.low).round_to_2_digit()) + .build() + .unwrap(), + }; + fake_bis.push(fake_bi); + } + fake_bis +} diff --git a/crates/czsc-core/src/objects/freq.rs b/crates/czsc-core/src/objects/freq.rs new file mode 100644 index 000000000..76b20e4b7 --- /dev/null +++ b/crates/czsc-core/src/objects/freq.rs @@ -0,0 +1,377 @@ +#[cfg(feature = "python")] +use std::str::FromStr; + +use strum_macros::{AsRefStr, Display, EnumIter, EnumString}; + +#[cfg(feature = "python")] +use pyo3_stub_gen::derive::{gen_stub_pyclass_enum, gen_stub_pymethods}; + +#[cfg(feature = "python")] +use pyo3::types::PyDict; +#[cfg(feature = "python")] +use pyo3::{ + Bound, PyAny, PyErr, PyResult, Python, + exceptions::PyValueError, + pyclass, pymethods, + types::{PyAnyMethods, PyString}, +}; +#[cfg(feature = "python")] +use pyo3::{IntoPyObject, PyObject}; + +/// 时间周期 +#[cfg_attr(feature = "python", gen_stub_pyclass_enum)] +#[cfg_attr(feature = "python", pyclass(module = "czsc._native"))] +#[derive( + Debug, + PartialOrd, + Ord, + Clone, + Copy, + PartialEq, + EnumIter, + EnumString, + AsRefStr, + Display, + Eq, + Hash, +)] +pub enum Freq { + /// 逐笔 + #[strum(serialize = "Tick")] + Tick, + /// 1分钟 + #[strum(serialize = "1分钟")] + F1, + /// 2分钟 + #[strum(serialize = "2分钟")] + F2, + /// 3分钟 + #[strum(serialize = "3分钟")] + F3, + /// 4分钟 + #[strum(serialize = "4分钟")] + F4, + /// 5分钟 + #[strum(serialize = "5分钟")] + F5, + /// 6分钟 + #[strum(serialize = "6分钟")] + F6, + /// 10分钟 + #[strum(serialize = "10分钟")] + F10, + /// 12分钟 + #[strum(serialize = "12分钟")] + F12, + /// 15分钟 + #[strum(serialize = "15分钟")] + F15, + /// 20分钟 + #[strum(serialize = "20分钟")] + F20, + /// 30分钟 + #[strum(serialize = "30分钟")] + F30, + /// 60分钟 + #[strum(serialize = "60分钟")] + F60, + /// 120分钟 + #[strum(serialize = "120分钟")] + F120, + /// 240分钟 + #[strum(serialize = "240分钟")] + F240, + /// 360分钟 + #[strum(serialize = "360分钟")] + F360, + /// 日线 + #[strum(serialize = "日线")] + D, + /// 周线 + #[strum(serialize = "周线")] + W, + /// 月线 + #[strum(serialize = "月线")] + M, + /// 季线 + #[strum(serialize = "季线")] + S, + /// 年线 + #[strum(serialize = "年线")] + Y, +} + +#[cfg(feature = "python")] +pub fn freqs_from_str(s: &str) -> Vec { + use strum::IntoEnumIterator; + Freq::iter().filter(|&f| s.contains(f.as_ref())).collect() +} + +impl Freq { + /// 判断是否为分钟级别的周期 + pub fn is_minute_freq(&self) -> bool { + matches!( + self, + Freq::F1 + | Freq::F2 + | Freq::F3 + | Freq::F4 + | Freq::F5 + | Freq::F6 + | Freq::F10 + | Freq::F12 + | Freq::F15 + | Freq::F20 + | Freq::F30 + | Freq::F60 + | Freq::F120 + | Freq::F240 + | Freq::F360 + ) + } + + /// 获取对应的分钟数 + pub fn minutes(&self) -> Option { + match self { + Freq::F1 => Some(1), + Freq::F2 => Some(2), + Freq::F3 => Some(3), + Freq::F4 => Some(4), + Freq::F5 => Some(5), + Freq::F6 => Some(6), + Freq::F10 => Some(10), + Freq::F12 => Some(12), + Freq::F15 => Some(15), + Freq::F20 => Some(20), + Freq::F30 => Some(30), + Freq::F60 => Some(60), + Freq::F120 => Some(120), + Freq::F240 => Some(240), + Freq::F360 => Some(360), + _ => None, + } + } +} + +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", gen_stub_pymethods)] +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", pymethods)] +impl Freq { + /// 支持深拷贝 + fn __deepcopy__(&self, _memo: &Bound) -> PyResult { + Ok(*self) + } + + /// 支持pickle序列化 + fn __reduce__(&self) -> PyResult<(PyObject, PyObject)> { + Python::with_gil(|py| { + let cls = py.get_type::(); + let args = (format!("{self:?}"),); + Ok((cls.into(), args.into_pyobject(py)?.into_any().unbind())) + }) + } + + #[new] + fn new(value: &str) -> PyResult { + Ok(match value { + "Tick" | "逐笔" => Freq::Tick, + "1分钟" | "F1" => Freq::F1, + "2分钟" | "F2" => Freq::F2, + "3分钟" | "F3" => Freq::F3, + "4分钟" | "F4" => Freq::F4, + "5分钟" | "F5" => Freq::F5, + "6分钟" | "F6" => Freq::F6, + "10分钟" | "F10" => Freq::F10, + "12分钟" | "F12" => Freq::F12, + "15分钟" | "F15" => Freq::F15, + "20分钟" | "F20" => Freq::F20, + "30分钟" | "F30" => Freq::F30, + "60分钟" | "F60" => Freq::F60, + "120分钟" | "F120" => Freq::F120, + "240分钟" | "F240" => Freq::F240, + "360分钟" | "F360" => Freq::F360, + "日线" | "D" => Freq::D, + "周线" | "W" => Freq::W, + "月线" | "M" => Freq::M, + "季线" | "S" => Freq::S, + "年线" | "Y" => Freq::Y, + _ => { + return Err(PyErr::new::(format!( + "Unknown freq value: {value}" + ))); + } + }) + } + #[cfg(feature = "python")] + #[getter] + fn value(&self) -> &'static str { + match self { + Freq::Tick => "Tick", + Freq::F1 => "1分钟", + Freq::F2 => "2分钟", + Freq::F3 => "3分钟", + Freq::F4 => "4分钟", + Freq::F5 => "5分钟", + Freq::F6 => "6分钟", + Freq::F10 => "10分钟", + Freq::F12 => "12分钟", + Freq::F15 => "15分钟", + Freq::F20 => "20分钟", + Freq::F30 => "30分钟", + Freq::F60 => "60分钟", + Freq::F120 => "120分钟", + Freq::F240 => "240分钟", + Freq::F360 => "360分钟", + Freq::D => "日线", + Freq::W => "周线", + Freq::M => "月线", + Freq::S => "季线", + Freq::Y => "年线", + } + } + + fn __str__(&self) -> &'static str { + self.value() + } + + fn __repr__(&self) -> String { + format!("Freq.{self:?}") + } + + #[classattr] + fn __members__(py: Python) -> PyResult { + let dict = PyDict::new(py); + dict.set_item("Tick", Freq::Tick)?; + dict.set_item("F1", Freq::F1)?; + dict.set_item("F2", Freq::F2)?; + dict.set_item("F3", Freq::F3)?; + dict.set_item("F4", Freq::F4)?; + dict.set_item("F5", Freq::F5)?; + dict.set_item("F6", Freq::F6)?; + dict.set_item("F10", Freq::F10)?; + dict.set_item("F12", Freq::F12)?; + dict.set_item("F15", Freq::F15)?; + dict.set_item("F20", Freq::F20)?; + dict.set_item("F30", Freq::F30)?; + dict.set_item("F60", Freq::F60)?; + dict.set_item("F120", Freq::F120)?; + dict.set_item("F240", Freq::F240)?; + dict.set_item("F360", Freq::F360)?; + dict.set_item("D", Freq::D)?; + dict.set_item("W", Freq::W)?; + dict.set_item("M", Freq::M)?; + dict.set_item("S", Freq::S)?; + dict.set_item("Y", Freq::Y)?; + Ok(dict.into()) + } + + fn __richcmp__( + &self, + other: pyo3::Bound<'_, pyo3::PyAny>, + op: pyo3::basic::CompareOp, + ) -> pyo3::PyResult { + use pyo3::basic::CompareOp; + match op { + CompareOp::Eq => { + if let Ok(other_freq) = other.extract::() { + return Ok(*self == other_freq); + } + if let Ok(other_value) = other.getattr("value") + && let Ok(other_str) = other_value.extract::() + { + return Ok(self.value() == other_str.as_str()); + } + Ok(false) + } + CompareOp::Ne => { + if let Ok(other_freq) = other.extract::() { + return Ok(*self != other_freq); + } + if let Ok(other_value) = other.getattr("value") + && let Ok(other_str) = other_value.extract::() + { + return Ok(self.value() != other_str.as_str()); + } + Ok(true) + } + _ => Ok(false), + } + } +} + +#[cfg(feature = "python")] +impl TryFrom<&Bound<'_, PyAny>> for Freq { + type Error = PyErr; + + fn try_from(value: &Bound<'_, PyAny>) -> PyResult { + if let Ok(py_str) = value.downcast::() { + let py_str = py_str.to_string(); + Freq::from_str(&py_str) + .map_err(|e| PyValueError::new_err(format!("解析成 Freq 失败: {e}"))) + } else if let Ok(self_) = value.extract::() { + Ok(self_) + } else { + Err(PyValueError::new_err("无法解析 Freq 对象")) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::str::FromStr; + + #[test] + fn test_string_to_freq() { + // 测试从字符串解析为 Freq (EnumString) + assert_eq!(Freq::from_str("Tick").unwrap(), Freq::Tick); + assert_eq!(Freq::from_str("1分钟").unwrap(), Freq::F1); + assert_eq!(Freq::from_str("15分钟").unwrap(), Freq::F15); + assert_eq!(Freq::from_str("日线").unwrap(), Freq::D); + assert_eq!(Freq::from_str("周线").unwrap(), Freq::W); + assert_eq!(Freq::from_str("月线").unwrap(), Freq::M); + assert_eq!(Freq::from_str("季线").unwrap(), Freq::S); + assert_eq!(Freq::from_str("年线").unwrap(), Freq::Y); + + // 测试无效输入 + assert!(Freq::from_str("7分钟").is_err()); + } + + #[test] + fn test_freq_to_string() { + // 测试 Display trait + assert_eq!(Freq::Tick.to_string(), "Tick"); + assert_eq!(Freq::F1.to_string(), "1分钟"); + assert_eq!(Freq::F15.to_string(), "15分钟"); + assert_eq!(Freq::D.to_string(), "日线"); + assert_eq!(Freq::W.to_string(), "周线"); + + // 测试 AsRefStr trait + assert_eq!(Freq::Tick.as_ref(), "Tick"); + assert_eq!(Freq::F1.as_ref(), "1分钟"); + assert_eq!(Freq::F15.as_ref(), "15分钟"); + assert_eq!(Freq::D.as_ref(), "日线"); + assert_eq!(Freq::W.as_ref(), "周线"); + } + + #[test] + fn test_debug_format() { + // 测试 Debug trait + assert_eq!(format!("{:?}", Freq::Tick), "Tick"); + assert_eq!(format!("{:?}", Freq::F1), "F1"); + assert_eq!(format!("{:?}", Freq::D), "D"); + assert_eq!(format!("{:?}", Freq::W), "W"); + } + + #[test] + fn test_clone_and_eq() { + // 测试 Clone 和 PartialEq + let freq1 = Freq::F15; + let freq2 = freq1; + assert_eq!(freq1, freq2); + + let freq3 = Freq::F30; + assert_ne!(freq1, freq3); + } +} diff --git a/crates/czsc-core/src/objects/fx.rs b/crates/czsc-core/src/objects/fx.rs new file mode 100644 index 000000000..e6d71ec9d --- /dev/null +++ b/crates/czsc-core/src/objects/fx.rs @@ -0,0 +1,399 @@ +#[cfg(feature = "python")] +use crate::objects::bar::RawBar; +#[cfg(feature = "python")] +use parking_lot::RwLock; +use std::sync::Arc; + +use super::{ + bar::{NewBar, Symbol}, + mark::Mark, +}; +use chrono::{DateTime, Utc}; +use derive_builder::Builder; + +#[cfg(feature = "python")] +use crate::utils::common::{create_naive_pandas_timestamp, parse_python_datetime}; +#[cfg(feature = "python")] +use pyo3::basic::CompareOp; +#[cfg(feature = "python")] +use pyo3::types::{PyDict, PyDictMethods}; +#[cfg(feature = "python")] +use pyo3::{Bound, Python, pyclass, pymethods}; +#[cfg(feature = "python")] +use pyo3::{Py, PyAny, PyObject, PyResult}; +#[cfg(feature = "python")] +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; + +const POWER_STRONG: &str = "强"; +const POWER_MEDIUM: &str = "中"; +const POWER_WEAK: &str = "弱"; + +/// 分型 +#[cfg_attr(feature = "python", gen_stub_pyclass)] +#[cfg_attr(feature = "python", pyclass(module = "czsc._native"))] +#[derive(Debug, Clone, Builder)] +#[builder(setter(into))] +pub struct FX { + pub symbol: Symbol, + pub dt: DateTime, + pub mark: Mark, + pub high: f64, + pub low: f64, + pub fx: f64, + #[builder(default = "Vec::new()")] + pub elements: Vec, + #[cfg(feature = "python")] + #[builder(default = "Arc::new(RwLock::new(None))")] + pub cache: Arc>>>, +} + +impl FX { + fn _power_str(&self) -> &str { + assert_eq!(self.elements.len(), 3); + + let k1 = &self.elements[0]; + let k2 = &self.elements[1]; + let k3 = &self.elements[2]; + + match self.mark { + Mark::D => { + if k3.close > k1.high { + POWER_STRONG + } else if k3.close > k2.high { + POWER_MEDIUM + } else { + POWER_WEAK + } + } + Mark::G => { + if k3.close < k1.low { + POWER_STRONG + } else if k3.close < k2.low { + POWER_MEDIUM + } else { + POWER_WEAK + } + } + } + } + + fn _power_volume(&self) -> f64 { + assert_eq!(self.elements.len(), 3); + self.elements.iter().map(|x| x.vol).sum() + } + + fn _has_zs(&self) -> bool { + assert_eq!(self.elements.len(), 3); + + let zd = self + .elements + .iter() + .map(|x| x.low) + .fold(f64::NEG_INFINITY, f64::max); + let zg = self + .elements + .iter() + .map(|x| x.high) + .fold(f64::INFINITY, f64::min); + + zg >= zd + } +} + +impl FX { + /// 判断分型强度 + pub fn power_str(&self) -> &str { + self._power_str() + } + + /// 计算成交量力度 + pub fn power_volume(&self) -> f64 { + self._power_volume() + } + + /// 判断构成分型的三根无包含K线是否有重叠中枢 + pub fn has_zs(&self) -> bool { + self._has_zs() + } +} + +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", gen_stub_pymethods)] +#[cfg_attr(feature = "python", pymethods)] +impl FX { + #[new] + fn new( + symbol: String, + dt: &Bound, + mark: Mark, + high: f64, + low: f64, + fx: f64, + elements: Vec, + ) -> PyResult { + // 使用通用的日期时间解析函数 + let datetime_utc = parse_python_datetime(dt)?; + + Ok(FX { + symbol: symbol.into(), + dt: datetime_utc, + mark, + high, + low, + fx, + elements: elements.into_iter().collect(), + cache: Arc::new(RwLock::new(None)), + }) + } + + #[getter] + fn symbol(&self) -> String { + self.symbol.to_string() + } + + #[getter] + fn dt(&self, py: Python) -> PyResult { + create_naive_pandas_timestamp(py, self.dt) + } + + #[getter] + fn mark(&self) -> Mark { + self.mark.clone() + } + + #[getter] + fn high(&self) -> f64 { + self.high + } + + #[getter] + fn low(&self) -> f64 { + self.low + } + + #[getter] + fn fx(&self) -> f64 { + self.fx + } + + /// 获取构成分型的NewBar列表 + #[getter] + fn new_bars(&self) -> Vec { + self.elements.to_vec() + } + + /// 获取原始K线列表(从NewBar的elements中提取) + #[getter] + fn raw_bars(&self) -> Vec { + self.elements + .iter() + .flat_map(|new_bar| new_bar.elements.clone()) + .collect() + } + + /// 获取分型强度字符串 + #[getter(power_str)] + fn power_str_py(&self) -> String { + self._power_str().to_string() + } + + /// 获取成交量力度 + #[getter(power_volume)] + fn power_volume_py(&self) -> f64 { + self._power_volume() + } + + /// 判断是否有重叠中枢 + #[getter(has_zs)] + fn has_zs_py(&self) -> bool { + self._has_zs() + } + + /// 获取构成分型的NewBar列表(与new_bars相同,为兼容czsc库) + #[getter] + fn elements(&self) -> Vec { + self.new_bars() + } + + /// 缓存字典(与 czsc 库兼容) + #[getter] + fn get_cache<'py>(&'py self, py: Python<'py>) -> Py { + // 首先尝试读锁获取缓存 + { + let cache_read = self.cache.read(); + if let Some(ref cached_dict) = *cache_read { + return cached_dict.clone_ref(py); + } + } + + // 如果缓存为空,使用写锁初始化并填充所有属性 + let mut cache_write = self.cache.write(); + if cache_write.is_none() { + let dict = PyDict::new(py); + // 一次性填充所有属性,避免重复创建 + dict.set_item("symbol", self.symbol.as_ref()).unwrap(); + dict.set_item("dt", create_naive_pandas_timestamp(py, self.dt).unwrap()) + .unwrap(); + dict.set_item("mark", self.mark.clone()).unwrap(); + dict.set_item("high", self.high).unwrap(); + dict.set_item("low", self.low).unwrap(); + dict.set_item("fx", self.fx).unwrap(); + dict.set_item("elements", py.None()).unwrap(); // 复杂对象先设为None + *cache_write = Some(dict.unbind()); + } + cache_write.as_ref().unwrap().clone_ref(py) + } + + #[setter] + #[gen_stub(skip)] // 跳过为了防止和 get_cache重复 + fn set_cache(&self, dict: Py) { + let mut cache_write = self.cache.write(); + *cache_write = Some(dict); + } + + /// 直接支持 __dict__ 属性,让 pandas DataFrame() 能正确识别对象 + #[getter] + pub fn __dict__(&self, py: Python) -> PyResult { + // 直接返回缓存的字典,避免重复创建 + Ok(self.get_cache(py).into()) + } + + fn __repr__(&self) -> String { + format!( + "FX(symbol={}, dt={}, mark={:?}, fx={})", + self.symbol, + self.dt.format("%Y-%m-%d %H:%M:%S"), + self.mark, + self.fx + ) + } + + fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult { + match op { + CompareOp::Eq => Ok(self == other), + CompareOp::Ne => Ok(self != other), + _ => Ok(false), + } + } +} + +impl PartialEq for FX { + fn eq(&self, other: &Self) -> bool { + self.symbol == other.symbol + && self.dt == other.dt + && self.mark == other.mark + && self.high == other.high + && self.low == other.low + && self.fx == other.fx + && self.elements == other.elements + } +} + +pub fn print_fx_list(fxs: &[FX]) { + println!("{:<12} {:>12} {:>12} {:>12}", "Mark", "High", "Low", "FX"); + println!("{:-<12} {:-^12} {:-^12} {:-^12}", "", "", "", ""); + + for fx in fxs { + println!( + "{:<12} {:>12.4} {:>12.4} {:>12.4}", + fx.mark, fx.high, fx.low, fx.fx + ); + } +} + +#[cfg(test)] +pub mod tests { + use std::sync::Arc; + + use chrono::Utc; + + use crate::objects::bar::NewBarBuilder; + + use super::*; + + #[test] + fn test_fx_new() { + let fx1 = FXBuilder::default() + .symbol(Arc::from("TEST".to_string())) + .dt(Utc::now()) + .mark(Mark::D) + .high(0) + .low(0) + .fx(0) + .build() + .unwrap(); + assert_eq!(fx1.high, 0.0); + } + + /// 创建一个测试用的底分型 + pub fn create_d_fx() -> FX { + // 创建测试用的K线数据 + let k1 = NewBarBuilder::default() + .symbol(Arc::from("TEST".to_string())) + .dt(Utc::now()) + .id(1) + .open(8.5) + .high(9.0) + .low(8.0) + .close(8.2) + .vol(90.0) + .amount(900.0) + .build() + .unwrap(); + + let k2 = NewBarBuilder::default() + .symbol(Arc::from("TEST".to_string())) + .dt(Utc::now()) + .id(2) + .open(8) + .high(8.5) + .low(7.5) + .close(8.0) + .vol(100.0) + .amount(1000.0) + .build() + .unwrap(); + + let k3 = NewBarBuilder::default() + .symbol(Arc::from("TEST".to_string())) + .dt(Utc::now()) + .id(3) + .open(8.5) + .high(9.0) + .low(8.0) + .close(8.8) + .vol(110.0) + .amount(1100.0) + .build() + .unwrap(); + + FXBuilder::default() + .symbol(k1.symbol.clone()) + .dt(k2.dt) + .mark(Mark::D) + .high(k2.high) + .low(k2.low) + .fx(k2.low) + .elements(vec![k1.clone(), k2.clone(), k3.clone()]) + .build() + .unwrap() + } + + #[test] + fn test_power_str() { + let fx_d = create_d_fx(); + assert_eq!(fx_d.power_str(), POWER_MEDIUM); + } + + #[test] + fn test_power_volume() { + let fx_d = create_d_fx(); + assert_eq!(fx_d.power_volume(), 300.0); + } + + #[test] + fn test_has_zs() { + let fx_d = create_d_fx(); + assert!(fx_d.has_zs()); + } +} diff --git a/crates/czsc-core/src/objects/mark.rs b/crates/czsc-core/src/objects/mark.rs new file mode 100644 index 000000000..8a41cf6f4 --- /dev/null +++ b/crates/czsc-core/src/objects/mark.rs @@ -0,0 +1,130 @@ +use strum_macros::{AsRefStr, Display, EnumString}; + +#[cfg(feature = "python")] +use pyo3::pyclass; +#[cfg(feature = "python")] +use pyo3::pymethods; +#[cfg(feature = "python")] +use pyo3::types::PyAnyMethods; +#[cfg(feature = "python")] +use pyo3_stub_gen::derive::gen_stub_pyclass_enum; +#[cfg(feature = "python")] +use pyo3_stub_gen::derive::gen_stub_pymethods; + +/// 分型类型 +#[cfg_attr(feature = "python", gen_stub_pyclass_enum)] +#[cfg_attr(feature = "python", pyclass(module = "czsc._native"))] +#[derive(Debug, Clone, PartialEq, EnumString, AsRefStr, Display)] +pub enum Mark { + /// 底分型 + #[strum(serialize = "底分型")] + D, + /// 顶分型 + #[strum(serialize = "顶分型")] + G, +} +#[cfg(feature = "python")] +#[gen_stub_pymethods] +#[cfg(feature = "python")] +#[pymethods] +impl Mark { + /// 获取标记的字符串值(与 czsc 库兼容) + #[getter] + fn value(&self) -> &'static str { + match self { + Mark::G => "顶分型", + Mark::D => "底分型", + } + } + + fn __str__(&self) -> &'static str { + self.value() + } + + fn __repr__(&self) -> String { + format!( + "Mark.{}", + match self { + Mark::G => "G", + Mark::D => "D", + } + ) + } + + fn __richcmp__( + &self, + other: pyo3::Bound<'_, pyo3::PyAny>, + op: pyo3::basic::CompareOp, + ) -> pyo3::PyResult { + use pyo3::basic::CompareOp; + match op { + CompareOp::Eq => { + if let Ok(other_mark) = other.extract::() { + return Ok(*self == other_mark); + } + if let Ok(other_value) = other.getattr("value") + && let Ok(other_str) = other_value.extract::() + { + return Ok(self.value() == other_str.as_str()); + } + Ok(false) + } + CompareOp::Ne => { + if let Ok(other_mark) = other.extract::() { + return Ok(*self != other_mark); + } + if let Ok(other_value) = other.getattr("value") + && let Ok(other_str) = other_value.extract::() + { + return Ok(self.value() != other_str.as_str()); + } + Ok(true) + } + _ => Ok(false), + } + } +} +#[cfg(test)] +mod tests { + use super::*; + use std::str::FromStr; + + #[test] + fn test_string_to_mark() { + // 测试从字符串解析为 Mark (EnumString) + assert_eq!(Mark::from_str("底分型").unwrap(), Mark::D); + assert_eq!(Mark::from_str("顶分型").unwrap(), Mark::G); + + // 测试无效输入 + assert!(Mark::from_str("中分型").is_err()); + } + + #[test] + fn test_mark_to_string() { + // 测试 Display trait + assert_eq!(Mark::D.to_string(), "底分型"); + assert_eq!(Mark::G.to_string(), "顶分型"); + + // 测试 AsRefStr trait + assert_eq!(Mark::D.as_ref(), "底分型"); + assert_eq!(Mark::G.as_ref(), "顶分型"); + } + + #[test] + fn test_debug_format() { + // 测试 Debug trait + assert_eq!(format!("{:?}", Mark::D), "D"); + assert_eq!(format!("{:?}", Mark::G), "G"); + } + + #[test] + fn test_clone_and_eq() { + // 测试 Clone 和 PartialEq + let mark1 = Mark::D; + let mark2 = mark1.clone(); + assert_eq!(mark1, mark2); + + let mark3 = Mark::G; + assert_ne!(mark1, mark3); + } +} diff --git a/crates/czsc-core/src/objects/market.rs b/crates/czsc-core/src/objects/market.rs new file mode 100644 index 000000000..661ddf7e4 --- /dev/null +++ b/crates/czsc-core/src/objects/market.rs @@ -0,0 +1,49 @@ +#[cfg(feature = "python")] +use pyo3::{exceptions::PyValueError, prelude::*, types::PyString}; +#[cfg(feature = "python")] +use pyo3_stub_gen::derive::{gen_stub_pyclass_enum, gen_stub_pymethods}; +#[cfg(feature = "python")] +use std::str::FromStr; +use strum_macros::{AsRefStr, Display, EnumString}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EnumString, AsRefStr, Display)] +#[cfg_attr(feature = "python", gen_stub_pyclass_enum)] +#[cfg_attr(feature = "python", pyclass(eq, eq_int, module = "czsc._native"))] +pub enum Market { + /// A股 + #[strum(serialize = "A股")] + AShare, + /// 期货 + #[strum(serialize = "期货")] + Futures, + /// 默认 + #[strum(serialize = "默认")] + Default, +} + +#[cfg(feature = "python")] +#[gen_stub_pymethods] +#[pymethods] +impl Market { + #[new] + fn from_py_any<'py>(ob: &Bound<'py, PyAny>) -> PyResult { + Self::try_from(ob) + } +} + +#[cfg(feature = "python")] +impl TryFrom<&Bound<'_, PyAny>> for Market { + type Error = PyErr; + + fn try_from(value: &Bound<'_, PyAny>) -> Result { + if let Ok(py_str) = value.downcast::() { + let py_str = py_str.to_string(); + Market::from_str(&py_str) + .map_err(|e| PyValueError::new_err(format!("解析成 Market 失败: {e}"))) + } else if let Ok(self_) = value.extract::() { + Ok(self_) + } else { + Err(PyValueError::new_err("无法解析 Market 对象")) + } + } +} diff --git a/crates/czsc-core/src/objects/mod.rs b/crates/czsc-core/src/objects/mod.rs new file mode 100644 index 000000000..b81d0d935 --- /dev/null +++ b/crates/czsc-core/src/objects/mod.rs @@ -0,0 +1,20 @@ +//! Core data objects (FX / BI / ZS / RawBar / Freq / Mark / ...). +//! +//! Migrated from rs-czsc 47ef6efa per docs/MIGRATION_NOTES.md §1. Submodules +//! are added incrementally as Phase D sub-loops complete. + +pub mod errors; +pub mod market; +pub mod freq; +pub mod bar; +pub mod mark; +pub mod direction; +pub mod fx; +pub mod fake_bi; +pub mod bi; +pub mod zs; +pub mod operate; +pub mod signal; +pub mod event; +pub mod position; +pub mod state; diff --git a/crates/czsc-core/src/objects/operate.rs b/crates/czsc-core/src/objects/operate.rs new file mode 100644 index 000000000..4d9307f5e --- /dev/null +++ b/crates/czsc-core/src/objects/operate.rs @@ -0,0 +1,252 @@ +// czsc-only: rs-czsc's `operate.rs` carried a wide collection of unused +// imports (polars / log / rayon / sha2 / and forward references to +// event / position / signal). Since the file only really defines the +// `Operate` enum + its `PyOperate` wrapper, we trim the imports here to +// avoid pulling those heavy crates into czsc-core. The original +// `#![allow(unused)]` is kept for the few remaining unused items the +// upstream file carries. See docs/MIGRATION_NOTES.md §2.4. +#![allow(unused)] +use serde::{Deserialize, Serialize}; +use std::fmt; +use std::str::FromStr; +use strum::IntoEnumIterator; +use strum_macros::{AsRefStr, EnumIter, EnumString}; + +#[cfg(feature = "python")] +use pyo3::prelude::*; +#[cfg(feature = "python")] +use pyo3::{ + Bound, FromPyObject, PyResult, Python, exceptions::PyValueError, + types::PyAnyMethods, +}; +#[cfg(feature = "python")] +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; + +pub const ANY: &str = "任意"; + +#[derive( + Clone, Copy, Debug, PartialEq, Hash, EnumString, EnumIter, AsRefStr, Serialize, Deserialize, +)] +pub enum Operate { + /// Hold Long 持多 + #[serde(rename = "持多")] + HL, + /// Hold Short 持空 + #[serde(rename = "持空")] + HS, + /// Hold Other 持币 + #[serde(rename = "持币")] + HO, + /// Long Open 开多 + #[serde(rename = "开多")] + LO, + /// Long Exit 平多 + #[serde(rename = "平多")] + LE, + /// Short Open 开空 + #[serde(rename = "开空")] + SO, + /// Short Exit 平空 + #[serde(rename = "平空")] + SE, +} + +impl fmt::Display for Operate { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_ref()) + } +} + +impl Operate { + fn list_of_types() -> String { + Self::iter() + .map(|ot| ot.to_string()) + .collect::>() + .join(", ") + } + + /// 返回操作类型的中文名称 + pub fn to_chinese(&self) -> &'static str { + match self { + Operate::HL => "持多", + Operate::HS => "持空", + Operate::HO => "持币", + Operate::LO => "开多", + Operate::LE => "平多", + Operate::SO => "开空", + Operate::SE => "平空", + } + } +} + +#[cfg(feature = "python")] +impl<'py> FromPyObject<'py> for Operate { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + if let Ok(s) = ob.extract::() { + let o = Self::from_str(&s).map_err(|_| { + PyValueError::new_err(format!( + "无法解析 operate, 期望 str [ {} ]", + Self::list_of_types() + )) + })?; + Ok(o) + } else { + Err(PyValueError::new_err(format!( + "operate 类型不合法, 期望 str: [ {} ]", + Self::list_of_types() + ))) + } + } +} + +/// Python可见的Operate包装器 +#[cfg_attr(feature = "python", gen_stub_pyclass)] +#[cfg_attr(feature = "python", pyclass(name = "Operate", module = "czsc._native"))] +#[derive(Debug, Clone)] +pub struct PyOperate { + pub inner: Operate, +} + +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", gen_stub_pymethods)] +#[cfg_attr(feature = "python", pymethods)] +impl PyOperate { + #[allow(non_snake_case)] + #[classattr] + fn HL() -> PyOperate { + PyOperate { inner: Operate::HL } + } + + #[allow(non_snake_case)] + #[classattr] + fn HS() -> PyOperate { + PyOperate { inner: Operate::HS } + } + + #[allow(non_snake_case)] + #[classattr] + fn HO() -> PyOperate { + PyOperate { inner: Operate::HO } + } + + #[allow(non_snake_case)] + #[classattr] + fn LO() -> PyOperate { + PyOperate { inner: Operate::LO } + } + + #[allow(non_snake_case)] + #[classattr] + fn LE() -> PyOperate { + PyOperate { inner: Operate::LE } + } + + #[allow(non_snake_case)] + #[classattr] + fn SO() -> PyOperate { + PyOperate { inner: Operate::SO } + } + + #[allow(non_snake_case)] + #[classattr] + fn SE() -> PyOperate { + PyOperate { inner: Operate::SE } + } + + // 保留方法访问方式作为兼容性选项 + #[classmethod] + fn hl(_cls: &Bound<'_, pyo3::types::PyType>) -> Self { + Self { inner: Operate::HL } + } + + #[classmethod] + fn hs(_cls: &Bound<'_, pyo3::types::PyType>) -> Self { + Self { inner: Operate::HS } + } + + #[classmethod] + fn ho(_cls: &Bound<'_, pyo3::types::PyType>) -> Self { + Self { inner: Operate::HO } + } + + #[classmethod] + fn lo(_cls: &Bound<'_, pyo3::types::PyType>) -> Self { + Self { inner: Operate::LO } + } + + #[classmethod] + fn le(_cls: &Bound<'_, pyo3::types::PyType>) -> Self { + Self { inner: Operate::LE } + } + + #[classmethod] + fn so(_cls: &Bound<'_, pyo3::types::PyType>) -> Self { + Self { inner: Operate::SO } + } + + #[classmethod] + fn se(_cls: &Bound<'_, pyo3::types::PyType>) -> Self { + Self { inner: Operate::SE } + } + + #[classmethod] + fn from_str_py(_cls: &Bound<'_, pyo3::types::PyType>, s: String) -> PyResult { + let inner = Operate::from_str(&s).map_err(|_| { + PyValueError::new_err(format!( + "无法解析操作类型: {}, 期望 [ {} ]", + s, + Operate::list_of_types() + )) + })?; + Ok(Self { inner }) + } + + #[classmethod] + fn from_str(_cls: &Bound<'_, pyo3::types::PyType>, s: String) -> PyResult { + let inner = Operate::from_str(&s).map_err(|_| { + PyValueError::new_err(format!( + "无法解析操作类型: {}, 期望 [ {} ]", + s, + Operate::list_of_types() + )) + })?; + Ok(Self { inner }) + } + + fn __str__(&self) -> String { + self.inner.to_string() + } + + fn __repr__(&self) -> String { + format!("PyOperate::{:?}", self.inner) + } + + fn __eq__(&self, other: &Self) -> bool { + self.inner == other.inner + } + + fn __hash__(&self) -> u64 { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut hasher = DefaultHasher::new(); + self.inner.hash(&mut hasher); + hasher.finish() + } + + /// 兼容性属性:返回操作类型的中文字符串值 + #[getter] + fn value(&self) -> String { + self.inner.to_chinese().to_string() + } + + /// 支持pickle序列化 + fn __reduce__(&self, py: Python) -> PyResult { + use pyo3::IntoPyObject; + + let class_method = py.get_type::().getattr("from_str")?; + // 使用英文缩写而不是中文名称,因为from_str只解析英文 + let args = (self.inner.to_string(),).into_pyobject(py)?; + let result = (class_method, args).into_pyobject(py)?; + Ok(result.into()) + } +} diff --git a/crates/czsc-core/src/objects/position.rs b/crates/czsc-core/src/objects/position.rs new file mode 100644 index 000000000..1a7a5a58b --- /dev/null +++ b/crates/czsc-core/src/objects/position.rs @@ -0,0 +1,2055 @@ +// czsc-only: pyo3 + Python wrapper imports gated behind the `python` feature +// for non-python builds. polars + log are kept unconditional because Position +// uses them outside cfg blocks. See docs/MIGRATION_NOTES.md §2.4. +#![allow(unused)] +use anyhow::{Context, anyhow}; +use chrono::{DateTime, FixedOffset, NaiveDateTime}; +use log::warn; +use polars::{df, prelude::*}; +use serde::{Deserialize, Serialize}; +use std::cell::{Ref, RefCell}; +use std::collections::{HashMap, HashSet}; +use std::fmt; +use std::fs; +use std::path::{Path, PathBuf}; +use std::rc::Rc; +use std::str::FromStr; +use std::time::Instant; + +use super::event::Event; +use super::operate::Operate; +use super::signal::{ANY, Signal}; + +#[cfg(feature = "python")] +use pyo3::exceptions::PyValueError; +#[cfg(feature = "python")] +use pyo3::prelude::*; +#[cfg(feature = "python")] +use pyo3::types::PyBytes; +#[cfg(feature = "python")] +use super::event::PyEvent; +#[cfg(feature = "python")] +use super::signal::PySignal; + +/// 解析 operate 字符串,支持英文缩写和中文名称 +fn parse_operate(s: &str) -> Result { + // 首先尝试英文缩写(EnumString) + if let Ok(op) = Operate::from_str(s) { + return Ok(op); + } + + // 然后尝试中文名称 + match s { + "持多" => Ok(Operate::HL), + "持空" => Ok(Operate::HS), + "持币" => Ok(Operate::HO), + "开多" => Ok(Operate::LO), + "平多" => Ok(Operate::LE), + "开空" => Ok(Operate::SO), + "平空" => Ok(Operate::SE), + _ => Err(format!("未知的operate值: {s}")), + } +} + +#[cfg(feature = "python")] +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; + +#[derive(Clone, Copy, Debug, Default, PartialEq)] +pub enum Pos { + /// 空 + Short, + /// 空仓 + #[default] + Flat, + /// 多 + Long, +} + +impl fmt::Display for Pos { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match self { + Pos::Short => "空", + Pos::Flat => "空仓", + Pos::Long => "多", + }; + write!(f, "{s}") + } +} + +impl Pos { + /// 转换为数值,用于数学运算 + pub fn to_f64(self) -> f64 { + match self { + Pos::Short => -1.0, + Pos::Flat => 0.0, + Pos::Long => 1.0, + } + } + + /// 从数值创建Pos + pub fn from_f64(value: f64) -> Self { + if value > 0.5 { + Pos::Long + } else if value < -0.5 { + Pos::Short + } else { + Pos::Flat + } + } +} + +/// 精简版价格线 +#[derive(Clone, Copy, Debug)] +pub struct LiteBar { + pub id: i32, + pub dt: DateTime, + pub price: f64, +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct PositionUpdateProfile { + pub event_match_ns: u128, + pub fsm_ns: u128, + pub risk_ns: u128, + pub holds_ns: u128, +} + +const ANY_CODE: i32 = -1; +const UNKNOWN_CODE: i32 = 0; + +#[derive(Debug, Clone, Copy)] +struct EncodedSignalValue { + v1_code: i32, + v2_code: i32, + v3_code: i32, + score: i32, +} + +#[derive(Debug, Clone)] +struct CachedEncodedSignalValue { + raw: String, + encoded: EncodedSignalValue, +} + +#[derive(Debug, Clone, Copy)] +struct EncodedSignalClause { + key_id: usize, + v1_code: i32, + v2_code: i32, + v3_code: i32, + min_score: i32, +} + +#[derive(Debug, Clone)] +struct CompiledEventMatcher { + signals_all: Vec, + signals_any: Vec, + signals_not: Vec, +} + +#[derive(Debug, Clone, Default)] +struct PositionEventMatcher { + keys: Vec, + key_to_id: HashMap, + value_to_code: HashMap, + events: Vec, +} + +impl PositionEventMatcher { + fn key_id(&mut self, key: String) -> usize { + if let Some(id) = self.key_to_id.get(key.as_str()) { + return *id; + } + let id = self.keys.len(); + self.keys.push(key.clone()); + self.key_to_id.insert(key, id); + id + } + + fn value_code(&mut self, v: &str) -> i32 { + if v == ANY { + return ANY_CODE; + } + if let Some(code) = self.value_to_code.get(v) { + return *code; + } + let code = self.value_to_code.len() as i32 + 1; + self.value_to_code.insert(v.to_string(), code); + code + } + + fn parse_v123_score(v: &str) -> Option<(&str, &str, &str, i32)> { + let mut it = v.splitn(4, '_'); + let v1 = it.next()?; + let v2 = it.next()?; + let v3 = it.next()?; + let score = it.next()?.parse::().ok()?; + Some((v1, v2, v3, score)) + } + + fn compile_signal_clause(&mut self, signal: &Signal) -> Option { + let key_id = self.key_id(signal.key()); + let value = signal.value(); + let (v1, v2, v3, score) = Self::parse_v123_score(value.as_str())?; + Some(EncodedSignalClause { + key_id, + v1_code: self.value_code(v1), + v2_code: self.value_code(v2), + v3_code: self.value_code(v3), + min_score: score, + }) + } + + fn compile_from_position(position: &Position) -> Self { + let mut matcher = Self::default(); + matcher + .events + .reserve(position.opens.len() + position.exits.len()); + for e in position.opens.iter().chain(position.exits.iter()) { + let mut all = Vec::with_capacity(e.signals_all.len()); + let mut any = Vec::with_capacity(e.signals_any.len()); + let mut not = Vec::with_capacity(e.signals_not.len()); + for s in &e.signals_all { + if let Some(c) = matcher.compile_signal_clause(s) { + all.push(c); + } + } + for s in &e.signals_any { + if let Some(c) = matcher.compile_signal_clause(s) { + any.push(c); + } + } + for s in &e.signals_not { + if let Some(c) = matcher.compile_signal_clause(s) { + not.push(c); + } + } + matcher.events.push(CompiledEventMatcher { + signals_all: all, + signals_any: any, + signals_not: not, + }); + } + matcher + } + + #[inline] + fn encode_runtime_values_inplace( + &self, + signal_map: &HashMap, + values: &mut [Option], + cache: &mut [Option], + ) { + for v in values.iter_mut() { + *v = None; + } + for (key_id, key) in self.keys.iter().enumerate() { + if let Some(raw) = signal_map.get(key.as_str()) { + if let Some(Some(cached)) = cache.get(key_id) + && cached.raw == *raw + { + values[key_id] = Some(cached.encoded); + continue; + } + + if let Some((v1, v2, v3, score)) = Self::parse_v123_score(raw) { + let v1_code = self.value_to_code.get(v1).copied().unwrap_or(UNKNOWN_CODE); + let v2_code = self.value_to_code.get(v2).copied().unwrap_or(UNKNOWN_CODE); + let v3_code = self.value_to_code.get(v3).copied().unwrap_or(UNKNOWN_CODE); + let encoded = EncodedSignalValue { + v1_code, + v2_code, + v3_code, + score, + }; + values[key_id] = Some(encoded); + if let Some(slot) = cache.get_mut(key_id) { + *slot = Some(CachedEncodedSignalValue { + raw: raw.clone(), + encoded, + }); + } + } else if let Some(slot) = cache.get_mut(key_id) { + *slot = None; + } + } + } + } + + #[inline] + fn clause_match(values: &[Option], c: EncodedSignalClause) -> bool { + if let Some(value) = values.get(c.key_id).and_then(|v| *v) { + value.score >= c.min_score + && (c.v1_code == ANY_CODE || c.v1_code == value.v1_code) + && (c.v2_code == ANY_CODE || c.v2_code == value.v2_code) + && (c.v3_code == ANY_CODE || c.v3_code == value.v3_code) + } else { + false + } + } + + fn find_first_match( + &self, + signal_map: &HashMap, + values: &mut [Option], + cache: &mut [Option], + ) -> Option { + self.encode_runtime_values_inplace(signal_map, values, cache); + for (idx, evt) in self.events.iter().enumerate() { + if evt + .signals_not + .iter() + .any(|c| Self::clause_match(values, *c)) + { + continue; + } + if evt + .signals_all + .iter() + .any(|c| !Self::clause_match(values, *c)) + { + continue; + } + if !evt.signals_any.is_empty() + && !evt + .signals_any + .iter() + .any(|c| Self::clause_match(values, *c)) + { + continue; + } + return Some(idx); + } + None + } +} + +#[derive(Debug, Clone)] +pub struct TempState { + /// 最近一次信号传入的时间 + pub end_dt: DateTime, + /// 最近一次开多交易的时间(可无) + pub last_lo_dt: Option>, + /// 最近一次开空交易的时间(可无) + pub last_so_dt: Option>, +} + +/// 操作记录(push 到 operates) +#[allow(unused)] +#[derive(Debug, Clone)] +pub struct OperateRecord { + pub symbol: String, + pub dt: DateTime, + pub bar_id: i32, + pub price: f64, + pub op: Operate, + pub op_desc: Option, + pub pos: Pos, +} + +/// 持仓快照(push 到 holds) +#[derive(Debug, Clone)] +pub struct HoldRecord { + pub dt: DateTime, + pub pos: Pos, + pub price: f64, + pub n1b: Option, // 下一K线的收益率,用于计算截面等权收益 +} + +#[derive(Default)] +pub struct HoldColumns { + pub dt: Vec, // NaiveDateTime兼容Polars格式 + pub pos: Vec, + pub price: Vec, + pub n1b: Vec>, // 下一K线的收益率 +} + +impl HoldColumns { + pub fn from_records(records: Vec) -> Self { + let mut cols = HoldColumns::default(); + cols.dt.reserve(records.len()); + cols.pos.reserve(records.len()); + cols.price.reserve(records.len()); + cols.n1b.reserve(records.len()); + + for r in records { + cols.dt.push(r.dt.naive_local()); + cols.pos.push(r.pos.to_f64() as i32); + cols.price.push(r.price); + cols.n1b.push(r.n1b); + } + + cols + } + + pub fn into_df(self) -> anyhow::Result { + let df = df![ + "dt" => self.dt, + "pos" => self.pos, + "price" => self.price, + "n1b" => self.n1b, + ] + .context("创建 Hold DataFrame 失败")?; + + Ok(df) + } +} + +/// 记录最近一次用于计算止损/超时的开仓基准 +#[allow(unused)] +#[derive(Debug, Clone)] +pub struct LastEvent { + pub dt: DateTime, + pub bar_id: i32, + pub price: f64, + pub op: Operate, + pub op_desc: Option, +} + +/// 列式记录的交易集合(所有字段按列存放) +#[derive(Debug, Clone, Default)] +pub struct TradePairsColumns<'a> { + /// 标的代码(例如 "000001.SH") + pub symbol: Vec<&'a str>, + /// 策略标记(Position的名称) + pub strategy_mark: Vec<&'a str>, + /// 交易方向(例如 "多头" 或 "空头") + pub direction: Vec<&'a str>, + /// 开仓时间(有明确时间的 DateTime) + pub open_dt: Vec, + /// 平仓时间(若仍持仓则为 None), NaiveDateTime是为了兼容 Polars时间格式 + pub close_dt: Vec, + /// 开仓价格 + pub open_price: Vec, + /// 平仓价格(若仍持仓则为 None) + pub close_price: Vec, + /// 持仓 K 线数 + pub holding_bar: Vec, + /// 事件序列(交易触发与出场说明) + pub event_sequence: Vec, + /// 持仓天数(天数,浮点数以支持小于1天的持仓) + pub holding_day: Vec, + /// 盈亏比例(若未知则为 None) + pub yield_profit_ratio: Vec>, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Position { + /// 开仓交易事件列表 + pub opens: Vec, + /// 平仓交易事件列表,允许为空 + pub exits: Vec, + /// 同类型开仓间隔时间,单位:秒;默认值为 0,表示同类型开仓间隔没有约束 + pub interval: i64, + /// 最大允许持仓K线数量限制为最近一个开仓事件触发后的 timeout 根基础周期K线 + pub timeout: i32, + /// 最大允许亏损比例,单位:百分比(如0.05表示5%);成本的计算以最近一个开仓事件触发价格为准 + pub stop_loss: f64, + /// T0: 是否允许T0交易,默认为 False 表示不允许T0交易 + #[serde(rename = "T0")] + pub t0: bool, + /// 仓位名称,默认值为第一个开仓事件的名称 + pub name: String, + /// 标的代码 + pub symbol: String, + + // 下面是运行时字段(不序列化) + #[serde(skip)] + temp_state: Option, + + #[serde(skip)] + pos_changed: bool, + #[serde(skip)] + pub operates: Vec, + #[serde(skip)] + holds: Vec, + /// -1 空, 0 空仓, 1 多 + #[serde(skip)] + pos: Pos, + #[serde(skip)] + last_event: Option, // 基于最近一次 LO 或 SO(与 Python 对齐,单一 last_event) + #[serde(skip)] + event_matcher: Option, + #[serde(skip)] + event_match_values: Vec>, + #[serde(skip)] + event_match_cache: Vec>, +} + +pub fn load_position(path: &Path) -> anyhow::Result { + // 读取文件内容 + let content = fs::read_to_string(path).with_context(|| format!("读取文件失败: {path:?}"))?; + // 反序列化 JSON + let mut position: Position = + serde_json::from_str(&content).with_context(|| format!("解析 JSON 失败: {path:?}"))?; + position.init_runtime_fields(); + Ok(position) +} + +impl Position { + fn init_runtime_fields(&mut self) { + self.normalize_event_hash_names(); + self.event_matcher = None; + self.event_match_values.clear(); + self.event_match_cache.clear(); + } + + /// 规范化运行时字段,确保事件名称、匹配缓存与 Python 基线一致。 + pub fn normalize_runtime_fields(&mut self) { + self.init_runtime_fields(); + } + + /// 获取当前仓位状态 + pub fn get_pos(&self) -> Pos { + self.pos + } + + /// 获取仓位是否发生变化 + pub fn get_pos_changed(&self) -> bool { + self.pos_changed + } + + fn normalize_event_hash_names(&mut self) { + for event in self.opens.iter_mut().chain(self.exits.iter_mut()) { + event.refresh_hash_name(); + } + } + + fn ensure_event_matcher(&mut self) { + if self.event_matcher.is_none() { + let matcher = PositionEventMatcher::compile_from_position(self); + self.event_match_values = vec![None; matcher.keys.len()]; + self.event_match_cache = vec![None; matcher.keys.len()]; + self.event_matcher = Some(matcher); + } + } + + fn event_tag_from_desc(desc: &str) -> &'static str { + if desc.starts_with("开多#") || desc.starts_with("LO#") { + "LO" + } else if desc.starts_with("开空#") || desc.starts_with("SO#") { + "SO" + } else if desc.starts_with("平多#") || desc.starts_with("LE#") { + "LE" + } else if desc.starts_with("平空#") || desc.starts_with("SE#") { + "SE" + } else { + "is_match" + } + } + + fn format_event_desc_for_python(desc: &str) -> String { + if desc.contains('@') { + desc.to_string() + } else { + format!("{desc}@{}", Self::event_tag_from_desc(desc)) + } + } + + pub fn save(&self, path: &Path) -> anyhow::Result<()> { + // 将 Position 序列化为 JSON 字符串 + let content = serde_json::to_string_pretty(self) + .with_context(|| format!("序列化 Position 失败: {path:?}"))?; + + // 将 JSON 写入文件 + fs::write(path, content).with_context(|| format!("写入文件失败: {path:?}"))?; + + Ok(()) + } + + /// 仅用于平仓优化 + pub fn with_event_hash_name(mut self, mode: &str, event_hash: &str) -> Self { + // 更新 name + if let Some(pos) = self.name.find('#') { + // 已有 # -> 替换 + self.name.truncate(pos); + } + self.name.push('#'); + self.name.push_str(mode); + self.name.push_str(event_hash); + self + } + + /// 仅用于开仓优化 + pub fn compute_md5_name(mut self) -> Self { + let mut context = md5::Context::new(); + context.consume(format!("{:?}", self.opens)); + context.consume(format!("{:?}", self.exits)); + context.consume(format!("{:?}", self.interval)); + context.consume(format!("{:?}", self.timeout)); + context.consume(format!("{:?}", self.stop_loss)); + context.consume(format!("{:?}", self.t0)); + context.consume(format!("{:?}", self.symbol)); + // context.consume(format!("{:?}", self.name)); + + let digest = context.finalize(); + let digest = hex::encode(*digest)[..8].to_uppercase(); + + // 更新 name + if let Some(pos) = self.name.find('#') { + // 已有 # -> 替换 + self.name.truncate(pos); + } + self.name.push('#'); + self.name.push_str(&digest); + self + } + + /// 返回一个迭代器,可以依次遍历 opens 和 exits + pub fn all_events(&self) -> impl Iterator { + self.opens.iter().chain(self.exits.iter()) + } + + /// 创建一个操作记录 + fn create_operate_record( + &mut self, + dt: DateTime, + bar_id: i32, + price: f64, + op: Operate, + op_desc: Option, + ) -> OperateRecord { + self.pos_changed = true; + OperateRecord { + symbol: self.symbol.to_string(), + dt, + bar_id, + price, + op, + op_desc, + pos: self.pos, + } + } + + /// 判断是否"非同日"(用于 T0 检测) + fn is_different_day(dt: DateTime, other: Option>) -> bool { + match other { + Some(o) => dt.date_naive() != o.date_naive(), + None => true, + } + } + + /// 更新持仓状态(对齐 Python czsc/py/objects.py Position.update) + pub fn update(&mut self, last_bar: LiteBar, last_signals: Rc>>) { + let _ = self.update_profiled(last_bar, last_signals); + } + + pub fn update_profiled( + &mut self, + last_bar: LiteBar, + last_signals: Rc>>, + ) -> PositionUpdateProfile { + self.update_profiled_with_signal_map(last_bar, Some(last_signals), None) + } + + pub fn update_profiled_with_signal_map( + &mut self, + last_bar: LiteBar, + last_signals: Option>>>, + signal_map: Option<&HashMap>, + ) -> PositionUpdateProfile { + if let Some(ref temp_state) = self.temp_state { + if temp_state.end_dt >= last_bar.dt { + warn!( + "请检查信号传入:最新信号时间: {} 在上次信号时间 {} 之前", + last_bar.dt, temp_state.end_dt + ); + return PositionUpdateProfile::default(); + } + } else { + // init + self.temp_state = Some(TempState { + end_dt: last_bar.dt, + last_lo_dt: None, + last_so_dt: None, + }); + } + let dt = last_bar.dt.fixed_offset(); + let price = last_bar.price; + let bar_id = last_bar.id; + + self.pos_changed = false; + let mut op = Operate::HO; + let mut op_desc: Option = None; + + let t_event = Instant::now(); + let owned_signal_map; + let signal_map = if let Some(m) = signal_map { + m + } else if let Some(last_signals) = last_signals { + owned_signal_map = { + let signals = last_signals.borrow(); + let mut m = HashMap::with_capacity(signals.len()); + for s in signals.iter() { + m.insert(s.key(), s.value()); + } + m + }; + &owned_signal_map + } else { + return PositionUpdateProfile::default(); + }; + // 事件匹配: 固定顺序 opens + exits(与 Python self.events = self.opens + self.exits 对齐) + self.ensure_event_matcher(); + if let Some(matcher) = self.event_matcher.as_ref() + && let Some(event_idx) = matcher.find_first_match( + signal_map, + &mut self.event_match_values, + &mut self.event_match_cache, + ) + { + let e = if event_idx < self.opens.len() { + &self.opens[event_idx] + } else { + &self.exits[event_idx - self.opens.len()] + }; + op = e.operate; + op_desc = Some(e.name.to_string()); + } + let event_match_ns = t_event.elapsed().as_nanos(); + + let t_fsm = Instant::now(); + // 更新 temp_state.end_dt 为当前信号时间 + if let Some(ref mut ts) = self.temp_state { + ts.end_dt = dt; + } + + // Python: 当有新的开仓 event 发生,更新 last_event + if op == Operate::LO || op == Operate::SO { + self.last_event = Some(LastEvent { + dt, + bar_id, + price, + op, + op_desc: op_desc.clone(), + }); + } + + // ---------- 开仓逻辑 ---------- + // Python L996-1009: if op == Operate.LO + if op == Operate::LO { + let allow_open_long = match self.temp_state.as_ref().and_then(|t| t.last_lo_dt) { + None => true, + Some(prev_dt) => { + let interval_secs = (dt - prev_dt).num_seconds(); + interval_secs > self.interval + } + }; + + if self.pos != Pos::Long && allow_open_long { + // 直接开多 + self.pos = Pos::Long; + if let Some(ref mut ts) = self.temp_state { + ts.last_lo_dt = Some(dt); + } + let rec = + self.create_operate_record(dt, bar_id, price, Operate::LO, op_desc.clone()); + self.operates.push(rec); + } else { + // interval 限制导致不能再次开多;如果当前是空头,则仅平空 + let can_close_short = self.pos == Pos::Short + && (self.t0 + || Self::is_different_day( + dt, + self.temp_state.as_ref().and_then(|t| t.last_so_dt), + )); + if can_close_short { + self.pos = Pos::Flat; + let rec = + self.create_operate_record(dt, bar_id, price, Operate::SE, op_desc.clone()); + self.operates.push(rec); + } + } + } + + // Python L1011-1024: if op == Operate.SO + if op == Operate::SO { + let allow_open_short = match self.temp_state.as_ref().and_then(|t| t.last_so_dt) { + None => true, + Some(prev_dt) => (dt - prev_dt).num_seconds() > self.interval, + }; + + if self.pos != Pos::Short && allow_open_short { + // 直接开空 + self.pos = Pos::Short; + if let Some(ref mut ts) = self.temp_state { + ts.last_so_dt = Some(dt); + } + let rec = + self.create_operate_record(dt, bar_id, price, Operate::SO, op_desc.clone()); + self.operates.push(rec); + } else { + // interval 限制导致不能再次开空;如果当前是多头,则仅平多 + let can_close_long = self.pos == Pos::Long + && (self.t0 + || Self::is_different_day( + dt, + self.temp_state.as_ref().and_then(|t| t.last_lo_dt), + )); + if can_close_long { + self.pos = Pos::Flat; + let rec = + self.create_operate_record(dt, bar_id, price, Operate::LE, op_desc.clone()); + self.operates.push(rec); + } + } + } + let fsm_ns = t_fsm.elapsed().as_nanos(); + + let t_risk = Instant::now(); + // ---------- 多头出场判断 ---------- + // Rust 侧保持“单 bar 只落一条有效平仓记录”: + // 事件平仓优先,其次止损,最后超时。 + if self.pos == Pos::Long { + let allowed_by_t0 = if let Some(ref ts) = self.temp_state { + self.t0 || Self::is_different_day(dt, ts.last_lo_dt) + } else { + true + }; + + // 提取 last_event 的值避免借用冲突 + let last_ev_snapshot = self.last_event.as_ref().map(|e| (e.price, e.bar_id)); + + if allowed_by_t0 && let Some((ev_price, ev_bar_id)) = last_ev_snapshot { + let exit_desc = if op == Operate::LE { + Some(op_desc.clone()) + } else if price / ev_price - 1.0 < -self.stop_loss / 10000.0 { + Some(Some(format!("平多@{}BP止损", self.stop_loss))) + } else if bar_id - ev_bar_id > self.timeout { + Some(Some(format!("平多@{}K超时", self.timeout))) + } else { + None + }; + + if let Some(exit_desc) = exit_desc { + self.pos = Pos::Flat; + let rec = self.create_operate_record(dt, bar_id, price, Operate::LE, exit_desc); + self.operates.push(rec); + } + } + } + + // ---------- 空头出场判断 ---------- + // Rust 侧保持“单 bar 只落一条有效平仓记录”: + // 事件平仓优先,其次止损,最后超时。 + if self.pos == Pos::Short { + let allowed_by_t0 = if let Some(ref ts) = self.temp_state { + self.t0 || Self::is_different_day(dt, ts.last_so_dt) + } else { + true + }; + + let last_ev_snapshot = self.last_event.as_ref().map(|e| (e.price, e.bar_id)); + + if allowed_by_t0 && let Some((ev_price, ev_bar_id)) = last_ev_snapshot { + let exit_desc = if op == Operate::SE { + Some(op_desc.clone()) + } else if 1.0 - price / ev_price < -self.stop_loss / 10000.0 { + Some(Some(format!("平空@{}BP止损", self.stop_loss))) + } else if bar_id - ev_bar_id > self.timeout { + Some(Some(format!("平空@{}K超时", self.timeout))) + } else { + None + }; + + if let Some(exit_desc) = exit_desc { + self.pos = Pos::Flat; + let rec = self.create_operate_record(dt, bar_id, price, Operate::SE, exit_desc); + self.operates.push(rec); + } + } + } + let risk_ns = t_risk.elapsed().as_nanos(); + + let t_holds = Instant::now(); + // Python L1072: self.holds.append({"dt": self.end_dt, "pos": self.pos, "price": price}) + if let Some(last_hold) = self.holds.last_mut() + && last_hold.price > 0.0 + { + let n1b_value = (price / last_hold.price - 1.0) * 10000.0; + last_hold.n1b = Some(n1b_value); + } + + self.holds.push(HoldRecord { + dt, + pos: self.pos, + price, + n1b: None, + }); + let holds_ns = t_holds.elapsed().as_nanos(); + + PositionUpdateProfile { + event_match_ns, + fsm_ns, + risk_ns, + holds_ns, + } + } + + pub fn pairs(&self) -> anyhow::Result { + let mut trade_pairs = TradePairsColumns::default(); + for (op1, op2) in self.operates.iter().zip(self.operates.iter().skip(1)) { + if op1.op != Operate::LO && op1.op != Operate::SO { + continue; + } + // 盈亏比例计算: + // 做多: 平仓价/开仓价 - 1 + // 做空 1 - 平仓价/开仓价 + let yield_profit_ratio = if op1.price == 0.0 { + None + } else if op1.op == Operate::LO { + Some(op2.price / op1.price - 1.0) + } else { + Some(1.0 - op2.price / op1.price) + }; + + // 按照Python版本添加策略标记字段 + trade_pairs.symbol.push(&self.symbol); + trade_pairs.strategy_mark.push(&self.name); + trade_pairs.direction.push(if op1.op == Operate::LO { + "多头" + } else { + "空头" + }); + trade_pairs.open_dt.push(op1.dt.naive_local()); + trade_pairs.close_dt.push(op2.dt.naive_local()); + trade_pairs.open_price.push(op1.price); + trade_pairs.close_price.push(op2.price); + trade_pairs.holding_bar.push(op2.bar_id - op1.bar_id); + trade_pairs + .event_sequence + .push(match (&op1.op_desc, &op2.op_desc) { + (None, None) => OP_DESC_NONE.to_string(), + (None, Some(desc)) => Self::format_event_desc_for_python(desc), + (Some(desc), None) => Self::format_event_desc_for_python(desc), + (Some(desc1), Some(desc2)) => format!( + "{} -> {}", + Self::format_event_desc_for_python(desc1), + Self::format_event_desc_for_python(desc2) + ), + }); + + // 按照Python版本计算持仓天数:使用浮点除法以支持小于1天的持仓 + // Python: (op2["dt"] - op1["dt"]).total_seconds() / (24 * 3600) + let holding_seconds = (op2.dt - op1.dt).num_seconds() as f64; + let holding_days = holding_seconds / 86400.0; // 86400 = 24 * 3600 + trade_pairs.holding_day.push(holding_days); + + // 转换为 BP 单位并按Python版本舍入到小数点后2位 + let yield_profit_ratio_bp = + yield_profit_ratio.map(|r| (r * 10000.0 * 100.0).round() / 100.0); + trade_pairs.yield_profit_ratio.push(yield_profit_ratio_bp); + } + + let df = df![ + "标的代码" => trade_pairs.symbol, + "策略标记" => trade_pairs.strategy_mark, + "交易方向" => trade_pairs.direction, + "开仓时间" => trade_pairs.open_dt, + "平仓时间" => trade_pairs.close_dt, + "开仓价格" => trade_pairs.open_price, + "平仓价格" => trade_pairs.close_price, + "持仓K线数" => trade_pairs.holding_bar, + "事件序列" => trade_pairs.event_sequence, + "持仓天数" => trade_pairs.holding_day, + "盈亏比例" => trade_pairs.yield_profit_ratio, + ] + .context("创建 Polars df 失败")?; + + Ok(df) + } + + pub fn holds(&self) -> anyhow::Result { + let holds = self.holds.clone(); + let holds = HoldColumns::from_records(holds); + + let df = holds + .into_df()? + .lazy() + .with_columns([ + // symbol 列 + lit(self.symbol.to_string()).alias("symbol"), + ]) + .collect() + .context("新增列 symbol 失败")?; + + Ok(df) + } +} + +const OP_DESC_NONE: &str = "无"; + +/// Python可见的Pos枚举包装器 +#[cfg_attr(feature = "python", gen_stub_pyclass)] +#[cfg_attr(feature = "python", pyclass(name = "Pos", module = "czsc._native"))] +#[derive(Debug, Clone)] +pub struct PyPos { + pub inner: Pos, +} + +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", gen_stub_pymethods)] +#[cfg_attr(feature = "python", pymethods)] +impl PyPos { + #[classmethod] + fn short(_cls: &Bound<'_, pyo3::types::PyType>) -> Self { + Self { inner: Pos::Short } + } + + #[classmethod] + fn flat(_cls: &Bound<'_, pyo3::types::PyType>) -> Self { + Self { inner: Pos::Flat } + } + + #[classmethod] + fn long(_cls: &Bound<'_, pyo3::types::PyType>) -> Self { + Self { inner: Pos::Long } + } + + fn __str__(&self) -> String { + self.inner.to_string() + } + + fn __repr__(&self) -> String { + format!("PyPos::{:?}", self.inner) + } + + fn __eq__(&self, other: &Self) -> bool { + self.inner == other.inner + } + + /// 加法运算,用于numpy.mean等数学操作 + fn __add__(&self, other: &Self) -> f64 { + self.inner.to_f64() + other.inner.to_f64() + } + + /// 右加法运算 + fn __radd__(&self, other: f64) -> f64 { + other + self.inner.to_f64() + } + + /// 转换为浮点数,用于数学运算 + fn __float__(&self) -> f64 { + self.inner.to_f64() + } + + /// 整数转换 + fn __int__(&self) -> i32 { + self.inner.to_f64() as i32 + } + + /// 比较运算符 - 小于 + fn __lt__(&self, other: &Self) -> bool { + self.inner.to_f64() < other.inner.to_f64() + } + + /// 比较运算符 - 小于等于 + fn __le__(&self, other: &Self) -> bool { + self.inner.to_f64() <= other.inner.to_f64() + } + + /// 比较运算符 - 大于 + fn __gt__(&self, other: &Self) -> bool { + self.inner.to_f64() > other.inner.to_f64() + } + + /// 比较运算符 - 大于等于 + fn __ge__(&self, other: &Self) -> bool { + self.inner.to_f64() >= other.inner.to_f64() + } +} + +/// Python可见的LiteBar包装器 +#[cfg_attr(feature = "python", gen_stub_pyclass)] +#[cfg_attr(feature = "python", pyclass(name = "LiteBar", module = "czsc._native"))] +#[derive(Debug, Clone)] +pub struct PyLiteBar { + pub inner: LiteBar, +} + +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", gen_stub_pymethods)] +#[cfg_attr(feature = "python", pymethods)] +impl PyLiteBar { + #[new] + fn new_py(id: i32, dt: f64, price: f64) -> PyResult { + use chrono::{DateTime, FixedOffset, TimeZone, Utc}; + + let dt_utc = DateTime::from_timestamp(dt as i64, 0) + .ok_or_else(|| PyValueError::new_err("无效的时间戳"))?; + let dt = dt_utc.with_timezone(&FixedOffset::east_opt(0).unwrap()); + + Ok(Self { + inner: LiteBar { id, dt, price }, + }) + } + + #[getter] + fn id(&self) -> i32 { + self.inner.id + } + + #[getter] + fn dt(&self) -> f64 { + self.inner.dt.timestamp() as f64 + } + + #[getter] + fn price(&self) -> f64 { + self.inner.price + } + + fn __repr__(&self) -> String { + format!( + "PyLiteBar(id={}, dt={}, price={})", + self.inner.id, + self.inner.dt.timestamp(), + self.inner.price + ) + } +} + +/// Python可见的Position包装器 +#[cfg_attr(feature = "python", gen_stub_pyclass)] +#[cfg_attr(feature = "python", pyclass(name = "Position", module = "czsc._native"))] +#[derive(Debug, Clone)] +pub struct PyPosition { + pub inner: Position, +} + +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", gen_stub_pymethods)] +#[cfg_attr(feature = "python", pymethods)] +impl PyPosition { + #[new] + #[pyo3(signature = (symbol, opens, exits = vec![], interval = 0, timeout = 1000, stop_loss = 1000.0, t0 = false, name = None))] + #[allow(clippy::too_many_arguments)] + fn new_py( + symbol: String, + opens: Vec, + exits: Vec, + interval: i64, + timeout: i32, + stop_loss: f64, + t0: bool, + name: Option, + ) -> Self { + let opens: Vec = opens.into_iter().map(|e| e.inner).collect(); + let exits: Vec = exits.into_iter().map(|e| e.inner).collect(); + + let name = name.unwrap_or_else(|| { + if !opens.is_empty() { + opens[0].name.clone() + } else { + "DefaultPosition".to_string() + } + }); + + let inner = Position { + opens, + exits, + interval, + timeout, + stop_loss, + t0, + name, + symbol, + temp_state: None, + pos_changed: false, + operates: Vec::new(), + holds: Vec::new(), + pos: Pos::Flat, + last_event: None, + event_matcher: None, + event_match_values: Vec::new(), + event_match_cache: Vec::new(), + }; + let mut inner = inner; + inner.init_runtime_fields(); + + Self { inner } + } + + #[classmethod] + fn load_from_file(_cls: &Bound<'_, pyo3::types::PyType>, path: String) -> PyResult { + let position = load_position(Path::new(&path)) + .map_err(|e| PyValueError::new_err(format!("加载Position失败: {e}")))?; + Ok(Self { inner: position }) + } + + #[classmethod] + fn from_json(_cls: &Bound<'_, pyo3::types::PyType>, json_str: String) -> PyResult { + let mut inner: Position = serde_json::from_str(&json_str) + .map_err(|e| PyValueError::new_err(format!("JSON解析失败: {e}")))?; + inner.init_runtime_fields(); + Ok(Self { inner }) + } + + #[getter] + fn opens(&self) -> Vec { + self.inner + .opens + .iter() + .map(|e| PyEvent { inner: e.clone() }) + .collect() + } + + #[getter] + fn exits(&self) -> Vec { + self.inner + .exits + .iter() + .map(|e| PyEvent { inner: e.clone() }) + .collect() + } + + #[getter] + fn interval(&self) -> i64 { + self.inner.interval + } + + #[getter] + fn timeout(&self) -> i32 { + self.inner.timeout + } + + #[getter] + fn stop_loss(&self) -> f64 { + self.inner.stop_loss + } + + #[getter] + fn t0(&self) -> bool { + self.inner.t0 + } + + #[getter] + fn name(&self) -> String { + self.inner.name.clone() + } + + #[getter] + fn symbol(&self) -> String { + self.inner.symbol.clone() + } + + #[getter] + fn pos(&self) -> f64 { + self.inner.pos.to_f64() + } + + #[getter] + fn pos_changed(&self) -> bool { + self.inner.pos_changed + } + + /// 获取最新信号时间 + #[getter] + fn end_dt(&self) -> Option { + self.inner + .temp_state + .as_ref() + .map(|ts| ts.end_dt.timestamp() as f64) + } + + /// 获取操作记录列表 + #[getter] + fn operates(&self, py: Python) -> PyResult> { + let mut result = Vec::new(); + + for op_record in &self.inner.operates { + let dict = pyo3::types::PyDict::new(py); + + // 转换时间戳为 pandas 兼容格式 + let timestamp = op_record.dt.timestamp() as f64; + dict.set_item("dt", timestamp)?; + dict.set_item("symbol", &op_record.symbol)?; + dict.set_item("bar_id", op_record.bar_id)?; + dict.set_item("price", op_record.price)?; + dict.set_item("op", op_record.op.to_string())?; + dict.set_item("op_desc", &op_record.op_desc)?; + dict.set_item("pos", op_record.pos.to_string())?; + + result.push(dict.into()); + } + + Ok(result) + } + + /// 保存到文件 + fn save(&self, path: String) -> PyResult<()> { + self.inner + .save(Path::new(&path)) + .map_err(|e| PyValueError::new_err(format!("保存Position失败: {e}"))) + } + + /// 转换为JSON字符串 + fn to_json(&self) -> PyResult { + serde_json::to_string_pretty(&self.inner) + .map_err(|e| PyValueError::new_err(format!("JSON序列化失败: {e}"))) + } + + /// 获取所有相关事件 + fn all_events(&self) -> Vec { + self.inner + .all_events() + .map(|e| PyEvent { inner: e.clone() }) + .collect() + } + + /// 更新仓位状态(兼容单参数调用) + #[pyo3(signature = (arg1, arg2 = None))] + fn update(&mut self, arg1: PyObject, arg2: Option) -> PyResult<()> { + use pyo3::types::{PyDict, PyMapping}; + use std::collections::HashSet; + + Python::with_gil(|py| { + if let Some(arg2_val) = arg2 { + // 两个参数的情况:update(last_bar, last_signals) + let last_bar: PyLiteBar = arg1.extract(py)?; + let last_signals: Vec = arg2_val.extract(py)?; + + let signals: HashSet = last_signals.into_iter().map(|s| s.inner).collect(); + let signals_ref = Rc::new(RefCell::new(signals)); + + self.inner.update(last_bar.inner, signals_ref); + } else { + // 一个参数的情况:update(signals_dict) + // Python版本期望字典格式: {'symbol': 'BTC', 'dt': Timestamp(...), 'id': 1, 'close': 100.0, '信号key': '信号value', ...} + + if let Ok(signals_dict) = arg1.downcast_bound::(py) { + // 1. 提取必需字段:dt, id, close + let dt_obj = signals_dict.get_item("dt")?.ok_or_else(|| { + PyValueError::new_err("Missing 'dt' field in signals dict") + })?; + let id_obj = signals_dict.get_item("id")?.ok_or_else(|| { + PyValueError::new_err("Missing 'id' field in signals dict") + })?; + let close_obj = signals_dict.get_item("close")?.ok_or_else(|| { + PyValueError::new_err("Missing 'close' field in signals dict") + })?; + + // 2. 转换dt - 支持多种格式 + use chrono::{DateTime, FixedOffset, Utc}; + let dt: DateTime = if let Ok(timestamp) = dt_obj.extract::() { + // Unix时间戳(秒) + DateTime::from_timestamp(timestamp as i64, 0) + .ok_or_else(|| PyValueError::new_err("Invalid timestamp"))? + .with_timezone(&FixedOffset::east_opt(0).unwrap()) + } else { + // 尝试调用timestamp()方法(pandas.Timestamp对象) + if let Ok(timestamp_method) = dt_obj.getattr("timestamp") { + let timestamp: f64 = timestamp_method.call0()?.extract()?; + DateTime::from_timestamp(timestamp as i64, 0) + .ok_or_else(|| { + PyValueError::new_err("Invalid timestamp from pandas") + })? + .with_timezone(&FixedOffset::east_opt(0).unwrap()) + } else { + return Err(PyValueError::new_err("Cannot convert 'dt' to timestamp")); + } + }; + + // 3. 提取id和close + let bar_id: i32 = id_obj.extract()?; + let price: f64 = close_obj.extract()?; + + // 4. 构造LiteBar + let lite_bar = LiteBar { + id: bar_id, + dt, + price, + }; + + // 5. 从字典中提取信号,构造Signal集合 + let mut signal_set = HashSet::new(); + + for (key, value) in signals_dict.iter() { + let key_str = if let Ok(s) = key.extract::() { + s + } else { + key.str()?.extract::()? + }; + + // 跳过非信号字段 + if key_str == "symbol" + || key_str == "dt" + || key_str == "id" + || key_str == "close" + || key_str == "open" + || key_str == "high" + || key_str == "low" + || key_str == "vol" + || key_str == "amount" + || key_str == "freq" + { + continue; + } + + let value_str = if let Ok(s) = value.extract::() { + s + } else { + value.str()?.extract::()? + }; + + // 构造完整信号字符串 + // 支持两种格式: + // 1. 简单值: value='看多' -> key_看多_任意_任意_0 + // 2. 完整4段值: value='看多_任意_任意_0' -> key_看多_任意_任意_0 + let signal_str = if value_str.split('_').count() == 4 { + // 已经是完整的4段值,直接拼接 + format!("{key_str}_{value_str}") + } else { + // 简单值,添加默认后缀 + format!("{key_str}_{value_str}_任意_任意_0") + }; + + // 尝试创建Signal并添加到集合 + if let Ok(signal) = Signal::from_str(&signal_str) { + signal_set.insert(signal); + } + } + + let signals_ref = Rc::new(RefCell::new(signal_set)); + self.inner.update(lite_bar, signals_ref); + } else if let Ok(signals_vec) = arg1.extract::>(py) { + // 如果是PySignal列表(向后兼容) + let signal_set: HashSet = + signals_vec.into_iter().map(|s| s.inner).collect(); + let signals_ref = Rc::new(RefCell::new(signal_set)); + + // 创建虚拟LiteBar(仅用于兼容旧接口) + use chrono::{DateTime, FixedOffset, Utc}; + let dummy_bar = LiteBar { + id: 0, + dt: Utc::now().with_timezone(&FixedOffset::east_opt(0).unwrap()), + price: 0.0, + }; + self.inner.update(dummy_bar, signals_ref); + } else { + return Err(PyValueError::new_err( + "Expected dict with 'dt', 'id', 'close' fields, or Vec", + )); + } + } + Ok(()) + }) + } + + /// 获取交易对数据(返回记录列表,兼容pandas.DataFrame构造) + #[getter] + fn pairs(&self) -> PyResult> { + let df = self + .inner + .pairs() + .map_err(|e| PyValueError::new_err(format!("生成交易对数据失败: {e}")))?; + + // 将DataFrame转换为记录列表 + Python::with_gil(|py| { + let list = pyo3::types::PyList::empty(py); + + let height = df.height(); + for i in 0..height { + let record = pyo3::types::PyDict::new(py); + + // 获取列数据(使用中文列名) + + if let Ok(symbol_col) = df.column("标的代码") + && let Ok(value) = symbol_col.get(i) + { + record.set_item("标的代码", value.to_string())?; + } + + if let Ok(direction_col) = df.column("交易方向") + && let Ok(value) = direction_col.get(i) + { + record.set_item("交易方向", value.to_string())?; + } + + if let Ok(open_dt_col) = df.column("开仓时间") + && let Ok(value) = open_dt_col.get(i) + { + record.set_item("开仓时间", value.to_string())?; + } + + if let Ok(close_dt_col) = df.column("平仓时间") + && let Ok(value) = close_dt_col.get(i) + { + record.set_item("平仓时间", value.to_string())?; + } + + if let Ok(open_price_col) = df.column("开仓价格") + && let Ok(value) = open_price_col.get(i) + { + record.set_item("开仓价格", value.try_extract::().unwrap_or(0.0))?; + } + + if let Ok(close_price_col) = df.column("平仓价格") + && let Ok(value) = close_price_col.get(i) + { + record.set_item("平仓价格", value.try_extract::().unwrap_or(0.0))?; + } + + if let Ok(holding_bar_col) = df.column("持仓K线数") + && let Ok(value) = holding_bar_col.get(i) + { + record.set_item("持仓K线数", value.try_extract::().unwrap_or(0))?; + } + + if let Ok(event_sequence_col) = df.column("事件序列") + && let Ok(value) = event_sequence_col.get(i) + { + record.set_item("事件序列", value.to_string())?; + } + + if let Ok(holding_day_col) = df.column("持仓天数") + && let Ok(value) = holding_day_col.get(i) + { + record.set_item("持仓天数", value.try_extract::().unwrap_or(0.0))?; + } + + if let Ok(yield_profit_ratio_col) = df.column("盈亏比例") + && let Ok(value) = yield_profit_ratio_col.get(i) + { + match value { + polars::prelude::AnyValue::Null => { + record.set_item("盈亏比例", py.None())?; + } + _ => { + if let Ok(ratio) = value.try_extract::() { + record.set_item("盈亏比例", ratio)?; + } else { + record.set_item("盈亏比例", py.None())?; + } + } + } + } + + list.append(record)?; + } + + Ok(list.into()) + }) + } + + /// 获取持仓历史数据(返回记录列表,兼容历史版本) + #[getter] + fn holds(&self) -> PyResult> { + Python::with_gil(|py| { + let list = pyo3::types::PyList::empty(py); + + for hold_record in &self.inner.holds { + let record = pyo3::types::PyDict::new(py); + + // 转换时间戳为 Python datetime 兼容格式 + let timestamp = hold_record.dt.timestamp() as f64; + record.set_item("dt", timestamp)?; + record.set_item("pos", hold_record.pos.to_f64() as i32)?; + record.set_item("price", hold_record.price)?; + + // 添加n1b字段 + if let Some(n1b_value) = hold_record.n1b { + record.set_item("n1b", n1b_value)?; + } else { + record.set_item("n1b", py.None())?; + } + + list.append(record)?; + } + + Ok(list.into()) + }) + } + + #[getter] + fn unique_signals(&self) -> Vec { + let mut signals = HashSet::new(); + + // 收集所有opens事件的信号字符串 + for event in &self.inner.opens { + for signal in &event.signals_all { + signals.insert(signal.to_string()); + } + for signal in &event.signals_any { + signals.insert(signal.to_string()); + } + for signal in &event.signals_not { + signals.insert(signal.to_string()); + } + } + + // 收集所有exits事件的信号字符串 + for event in &self.inner.exits { + for signal in &event.signals_all { + signals.insert(signal.to_string()); + } + for signal in &event.signals_any { + signals.insert(signal.to_string()); + } + for signal in &event.signals_not { + signals.insert(signal.to_string()); + } + } + + signals.into_iter().collect() + } + + #[getter] + fn events(&self) -> Vec { + self.all_events() + } + + /// 支持 pickle 序列化 - 使用 __reduce__ 方法 + fn __reduce__(&self, py: Python) -> PyResult { + use pyo3::IntoPyObject; + + // 构造函数参数 + let opens: Vec = self + .inner + .opens + .iter() + .map(|e| PyEvent { inner: e.clone() }) + .collect(); + let exits: Vec = self + .inner + .exits + .iter() + .map(|e| PyEvent { inner: e.clone() }) + .collect(); + + let args = ( + self.inner.symbol.clone(), + opens, + exits, + self.inner.interval, + self.inner.timeout, + self.inner.stop_loss, + self.inner.t0, + Some(self.inner.name.clone()), + ) + .into_pyobject(py)?; + + // 返回 (constructor, args) + let constructor = py.get_type::(); + let result = (constructor, args).into_pyobject(py)?; + Ok(result.into()) + } + + /// 导出Position数据为Python字典 + #[pyo3(signature = (with_data = true))] + fn dump(&self, with_data: bool) -> PyResult { + Python::with_gil(|py| { + let dict = pyo3::types::PyDict::new(py); + + // 基本属性 - 按期望顺序添加 + dict.set_item("symbol", &self.inner.symbol)?; + dict.set_item("name", &self.inner.name)?; + + // opens和exits事件 - 使用中文 operate,保持key-value字典格式 + let opens_list = pyo3::types::PyList::empty(py); + for event in &self.inner.opens { + let event_dict = pyo3::types::PyDict::new(py); + event_dict.set_item("name", &event.name)?; + event_dict.set_item("operate", event.operate.to_chinese())?; + + let signals_all_list = pyo3::types::PyList::empty(py); + for signal in &event.signals_all { + let signal_dict = pyo3::types::PyDict::new(py); + signal_dict.set_item("key", signal.key())?; + signal_dict.set_item("value", signal.value())?; + signals_all_list.append(signal_dict)?; + } + event_dict.set_item("signals_all", signals_all_list)?; + + let signals_any_list = pyo3::types::PyList::empty(py); + for signal in &event.signals_any { + let signal_dict = pyo3::types::PyDict::new(py); + signal_dict.set_item("key", signal.key())?; + signal_dict.set_item("value", signal.value())?; + signals_any_list.append(signal_dict)?; + } + event_dict.set_item("signals_any", signals_any_list)?; + + let signals_not_list = pyo3::types::PyList::empty(py); + for signal in &event.signals_not { + let signal_dict = pyo3::types::PyDict::new(py); + signal_dict.set_item("key", signal.key())?; + signal_dict.set_item("value", signal.value())?; + signals_not_list.append(signal_dict)?; + } + event_dict.set_item("signals_not", signals_not_list)?; + + opens_list.append(event_dict)?; + } + dict.set_item("opens", opens_list)?; + + let exits_list = pyo3::types::PyList::empty(py); + for event in &self.inner.exits { + let event_dict = pyo3::types::PyDict::new(py); + event_dict.set_item("name", &event.name)?; + event_dict.set_item("operate", event.operate.to_chinese())?; + + let signals_all_list = pyo3::types::PyList::empty(py); + for signal in &event.signals_all { + let signal_dict = pyo3::types::PyDict::new(py); + signal_dict.set_item("key", signal.key())?; + signal_dict.set_item("value", signal.value())?; + signals_all_list.append(signal_dict)?; + } + event_dict.set_item("signals_all", signals_all_list)?; + + let signals_any_list = pyo3::types::PyList::empty(py); + for signal in &event.signals_any { + let signal_dict = pyo3::types::PyDict::new(py); + signal_dict.set_item("key", signal.key())?; + signal_dict.set_item("value", signal.value())?; + signals_any_list.append(signal_dict)?; + } + event_dict.set_item("signals_any", signals_any_list)?; + + let signals_not_list = pyo3::types::PyList::empty(py); + for signal in &event.signals_not { + let signal_dict = pyo3::types::PyDict::new(py); + signal_dict.set_item("key", signal.key())?; + signal_dict.set_item("value", signal.value())?; + signals_not_list.append(signal_dict)?; + } + event_dict.set_item("signals_not", signals_not_list)?; + + exits_list.append(event_dict)?; + } + dict.set_item("exits", exits_list)?; + + // 剩余的基本属性 - 按期望顺序 + dict.set_item("interval", self.inner.interval)?; + dict.set_item("timeout", self.inner.timeout)?; + dict.set_item("stop_loss", self.inner.stop_loss)?; + dict.set_item("T0", self.inner.t0)?; + + // 如果需要包含数据 + if with_data { + // 获取pairs数据 + if let Ok(pairs_list) = self.pairs() { + dict.set_item("pairs", pairs_list)?; + } + + // 获取holds数据 + if let Ok(holds_list) = self.holds() { + dict.set_item("holds", holds_list)?; + } + } + + Ok(dict.into()) + }) + } + + /// 从字典数据加载Position + #[classmethod] + fn load(_cls: &Bound<'_, pyo3::types::PyType>, data: PyObject) -> PyResult { + Python::with_gil(|py| { + // 首先尝试直接转换为字典 + let dict = match data.downcast_bound::(py) { + Ok(d) => d.clone(), + Err(_) => { + // 如果失败,尝试作为字符串处理 + if let Ok(s) = data.downcast_bound::(py) { + let json_str: String = s.extract()?; + // 使用Python的json模块解析 + let json_module = py.import("json")?; + let parsed = json_module.call_method1("loads", (json_str,))?; + parsed.downcast::()?.clone() + } else { + return Err(pyo3::exceptions::PyTypeError::new_err( + "Expected dict or JSON string", + )); + } + } + }; + + // symbol字段可选,如果不存在则使用默认值 + let symbol: String = dict + .get_item("symbol")? + .map(|item| item.extract()) + .transpose()? + .unwrap_or_else(|| "UNKNOWN".to_string()); + let name: String = dict + .get_item("name")? + .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("Missing 'name' field"))? + .extract()?; + let interval: i64 = dict + .get_item("interval")? + .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("Missing 'interval' field"))? + .extract()?; + let timeout: i32 = dict + .get_item("timeout")? + .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("Missing 'timeout' field"))? + .extract()?; + let stop_loss: f64 = dict + .get_item("stop_loss")? + .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("Missing 'stop_loss' field"))? + .extract()?; + let t0: bool = dict + .get_item("T0")? + .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("Missing 'T0' field"))? + .extract()?; + + // 解析opens事件 + let opens_data = dict + .get_item("opens")? + .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("Missing 'opens' field"))?; + let opens_list = opens_data.downcast::()?; + let mut opens = Vec::new(); + + for item in opens_list.iter() { + let event_dict = item.downcast::()?; + let event_name: String = event_dict + .get_item("name")? + .ok_or_else(|| { + pyo3::exceptions::PyKeyError::new_err("Missing 'name' in event") + })? + .extract()?; + let operate_str: String = event_dict + .get_item("operate")? + .ok_or_else(|| { + pyo3::exceptions::PyKeyError::new_err("Missing 'operate' in event") + })? + .extract()?; + let operate = parse_operate(&operate_str) + .map_err(|e| PyValueError::new_err(format!("无法解析operate: {e}")))?; + + // 解析 signals_all + let signals_all_data = event_dict.get_item("signals_all")?.ok_or_else(|| { + pyo3::exceptions::PyKeyError::new_err("Missing 'signals_all' in event") + })?; + let signals_all_list = signals_all_data.downcast::()?; + let mut signals_all = Vec::new(); + + for signal_item in signals_all_list.iter() { + let signal_dict = signal_item.downcast::()?; + let key: String = signal_dict + .get_item("key")? + .ok_or_else(|| { + pyo3::exceptions::PyKeyError::new_err("Missing 'key' in signal") + })? + .extract()?; + let value: String = signal_dict + .get_item("value")? + .ok_or_else(|| { + pyo3::exceptions::PyKeyError::new_err("Missing 'value' in signal") + })? + .extract()?; + let signal_str = format!("{key}_{value}"); + if let Ok(signal) = Signal::from_str(&signal_str) { + signals_all.push(signal); + } + } + + // 解析 signals_any + let mut signals_any = Vec::new(); + if let Some(signals_any_data) = event_dict.get_item("signals_any")? { + let signals_any_list = signals_any_data.downcast::()?; + for signal_item in signals_any_list.iter() { + let signal_dict = signal_item.downcast::()?; + let key: String = signal_dict + .get_item("key")? + .ok_or_else(|| { + pyo3::exceptions::PyKeyError::new_err("Missing 'key' in signal") + })? + .extract()?; + let value: String = signal_dict + .get_item("value")? + .ok_or_else(|| { + pyo3::exceptions::PyKeyError::new_err("Missing 'value' in signal") + })? + .extract()?; + let signal_str = format!("{key}_{value}"); + if let Ok(signal) = Signal::from_str(&signal_str) { + signals_any.push(signal); + } + } + } + + // 解析 signals_not + let mut signals_not = Vec::new(); + if let Some(signals_not_data) = event_dict.get_item("signals_not")? { + let signals_not_list = signals_not_data.downcast::()?; + for signal_item in signals_not_list.iter() { + let signal_dict = signal_item.downcast::()?; + let key: String = signal_dict + .get_item("key")? + .ok_or_else(|| { + pyo3::exceptions::PyKeyError::new_err("Missing 'key' in signal") + })? + .extract()?; + let value: String = signal_dict + .get_item("value")? + .ok_or_else(|| { + pyo3::exceptions::PyKeyError::new_err("Missing 'value' in signal") + })? + .extract()?; + let signal_str = format!("{key}_{value}"); + if let Ok(signal) = Signal::from_str(&signal_str) { + signals_not.push(signal); + } + } + } + + let event = Event { + name: event_name, + operate, + signals_all, + signals_any, + signals_not, + sha256: String::new(), + }; + opens.push(event); + } + + let exits_data = dict + .get_item("exits")? + .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("Missing 'exits' field"))?; + let exits_list = exits_data.downcast::()?; + let mut exits = Vec::new(); + + for item in exits_list.iter() { + let event_dict = item.downcast::()?; + let event_name: String = event_dict + .get_item("name")? + .ok_or_else(|| { + pyo3::exceptions::PyKeyError::new_err("Missing 'name' in event") + })? + .extract()?; + let operate_str: String = event_dict + .get_item("operate")? + .ok_or_else(|| { + pyo3::exceptions::PyKeyError::new_err("Missing 'operate' in event") + })? + .extract()?; + let operate = parse_operate(&operate_str) + .map_err(|e| PyValueError::new_err(format!("无法解析operate: {e}")))?; + + // 解析 signals_all + let signals_all_data = event_dict.get_item("signals_all")?.ok_or_else(|| { + pyo3::exceptions::PyKeyError::new_err("Missing 'signals_all' in event") + })?; + let signals_all_list = signals_all_data.downcast::()?; + let mut signals_all = Vec::new(); + + for signal_item in signals_all_list.iter() { + let signal_dict = signal_item.downcast::()?; + let key: String = signal_dict + .get_item("key")? + .ok_or_else(|| { + pyo3::exceptions::PyKeyError::new_err("Missing 'key' in signal") + })? + .extract()?; + let value: String = signal_dict + .get_item("value")? + .ok_or_else(|| { + pyo3::exceptions::PyKeyError::new_err("Missing 'value' in signal") + })? + .extract()?; + let signal_str = format!("{key}_{value}"); + if let Ok(signal) = Signal::from_str(&signal_str) { + signals_all.push(signal); + } + } + + // 解析 signals_any + let mut signals_any = Vec::new(); + if let Some(signals_any_data) = event_dict.get_item("signals_any")? { + let signals_any_list = signals_any_data.downcast::()?; + for signal_item in signals_any_list.iter() { + let signal_dict = signal_item.downcast::()?; + let key: String = signal_dict + .get_item("key")? + .ok_or_else(|| { + pyo3::exceptions::PyKeyError::new_err("Missing 'key' in signal") + })? + .extract()?; + let value: String = signal_dict + .get_item("value")? + .ok_or_else(|| { + pyo3::exceptions::PyKeyError::new_err("Missing 'value' in signal") + })? + .extract()?; + let signal_str = format!("{key}_{value}"); + if let Ok(signal) = Signal::from_str(&signal_str) { + signals_any.push(signal); + } + } + } + + // 解析 signals_not + let mut signals_not = Vec::new(); + if let Some(signals_not_data) = event_dict.get_item("signals_not")? { + let signals_not_list = signals_not_data.downcast::()?; + for signal_item in signals_not_list.iter() { + let signal_dict = signal_item.downcast::()?; + let key: String = signal_dict + .get_item("key")? + .ok_or_else(|| { + pyo3::exceptions::PyKeyError::new_err("Missing 'key' in signal") + })? + .extract()?; + let value: String = signal_dict + .get_item("value")? + .ok_or_else(|| { + pyo3::exceptions::PyKeyError::new_err("Missing 'value' in signal") + })? + .extract()?; + let signal_str = format!("{key}_{value}"); + if let Ok(signal) = Signal::from_str(&signal_str) { + signals_not.push(signal); + } + } + } + + let event = Event { + name: event_name, + operate, + signals_all, + signals_any, + signals_not, + sha256: String::new(), + }; + exits.push(event); + } + + let inner = Position { + opens, + exits, + interval, + timeout, + stop_loss, + t0, + name, + symbol, + temp_state: None, + pos_changed: false, + operates: Vec::new(), + holds: Vec::new(), + pos: Pos::Flat, + last_event: None, + event_matcher: None, + event_match_values: Vec::new(), + event_match_cache: Vec::new(), + }; + let mut inner = inner; + inner.init_runtime_fields(); + + Ok(Self { inner }) + }) + } + + fn __repr__(&self) -> String { + format!( + "PyPosition(name='{}', symbol='{}', opens={}, exits={}, interval={})", + self.inner.name, + self.inner.symbol, + self.inner.opens.len(), + self.inner.exits.len(), + self.inner.interval + ) + } +} diff --git a/crates/czsc-core/src/objects/signal.rs b/crates/czsc-core/src/objects/signal.rs new file mode 100644 index 000000000..f50bd7852 --- /dev/null +++ b/crates/czsc-core/src/objects/signal.rs @@ -0,0 +1,567 @@ +// czsc-only: pyo3 imports gated behind the `python` feature (rs-czsc 47ef6efa +// relied on `#![allow(unused)]` to mask the bare imports when the feature was +// off; we make the gating explicit so czsc-core builds in non-python mode). +// See docs/MIGRATION_NOTES.md §2.4. +#![allow(unused)] +use anyhow::{Context, anyhow, bail}; +use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Visitor}; +use std::borrow::Cow; +use std::fmt::{self, Display}; +use std::hash::{Hash, Hasher}; +use std::str::FromStr; + +#[cfg(feature = "python")] +use pyo3::exceptions::PyValueError; +#[cfg(feature = "python")] +use pyo3::prelude::*; +#[cfg(feature = "python")] +use pyo3::types::{PyDict, PyDictMethods}; +#[cfg(feature = "python")] +use pyo3::{IntoPyObject, Py, PyObject, PyResult, Python}; + +#[cfg(feature = "python")] +use super::operate::Operate; +#[cfg(feature = "python")] +use pyo3_stub_gen::derive::gen_stub_pyfunction; +#[cfg(feature = "python")] +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; +pub(crate) const ANY: &str = "任意"; + +#[derive(Clone, Debug)] +pub struct SignalRef<'a> { + // 完整的信号字符串 + signal: Cow<'a, str>, + + // 信号名称字段 (k1_k2_k3) + k1: Cow<'a, str>, + k2: Cow<'a, str>, + k3: Cow<'a, str>, + + // 信号值字段 (v1_v2_v3) + v1: Cow<'a, str>, + v2: Cow<'a, str>, + v3: Cow<'a, str>, + + // 分数 + score: i32, +} + +impl<'a> Hash for SignalRef<'a> { + fn hash(&self, state: &mut H) { + self.signal.hash(state); + } +} + +pub type Signal = SignalRef<'static>; + +/// Python可见的Signal包装器 +#[cfg_attr(feature = "python", gen_stub_pyclass)] +#[cfg_attr(feature = "python", pyclass(name = "Signal", module = "czsc._native"))] +#[derive(Debug, Clone)] +pub struct PySignal { + pub(crate) inner: Signal, +} + +impl PySignal { + /// Wrap an inner [`Signal`] for Python exposure. The constructor is + /// public so downstream crates (notably `czsc-python`'s signal + /// dispatcher) can return signal objects without round-tripping + /// through the string parser. + pub fn from_inner(inner: Signal) -> Self { + Self { inner } + } +} + +impl From for PySignal { + fn from(inner: Signal) -> Self { + Self::from_inner(inner) + } +} + +impl FromStr for Signal { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + // Python格式: k1_k2_k3_v1_v2_v3_score (7个部分,6个下划线) + let parts: Vec<&str> = s.split('_').collect(); + if parts.len() != 7 { + bail!("Signal格式无效:应该为 k1_k2_k3_v1_v2_v3_score 格式 (7个部分)") + } + + // 验证score + let score: i32 = parts[6].parse().context("无法解析score")?; + if !(0..=100).contains(&score) { + bail!("score 必须在0~100之间"); + } + + Ok(SignalRef { + signal: Cow::Owned(s.to_string()), + k1: Cow::Owned(parts[0].to_string()), + k2: Cow::Owned(parts[1].to_string()), + k3: Cow::Owned(parts[2].to_string()), + v1: Cow::Owned(parts[3].to_string()), + v2: Cow::Owned(parts[4].to_string()), + v3: Cow::Owned(parts[5].to_string()), + score, + }) + } +} + +impl<'a> SignalRef<'a> { + /// 获取信号的key部分,按照Python逻辑过滤掉"任意" + pub fn key(&self) -> String { + let mut key_parts = Vec::new(); + for k in [&self.k1, &self.k2, &self.k3] { + if k.as_ref() != ANY { + key_parts.push(k.as_ref()); + } + } + if key_parts.is_empty() { + String::new() + } else { + key_parts.join("_") + } + } + + /// 获取信号的value部分:v1_v2_v3_score + pub fn value(&self) -> String { + format!("{}_{}_{}_{}", self.v1, self.v2, self.v3, self.score) + } + + /// 按照Python逻辑实现信号匹配 + pub fn is_match(&self, signal_dict: &std::collections::HashMap) -> bool { + let key = self.key(); + if let Some(value) = signal_dict.get(&key) { + // 解析字典中的value为v1, v2, v3, score + let value_parts: Vec<&str> = value.split('_').collect(); + if value_parts.len() != 4 { + return false; + } + + let (v1, v2, v3, score_str) = ( + value_parts[0], + value_parts[1], + value_parts[2], + value_parts[3], + ); + let score: i32 = score_str.parse().unwrap_or(0); + + // 匹配逻辑:score >= self.score 且各值匹配或为"任意" + if score >= self.score + && (v1 == self.v1.as_ref() || self.v1.as_ref() == ANY) + && (v2 == self.v2.as_ref() || self.v2.as_ref() == ANY) + && (v3 == self.v3.as_ref() || self.v3.as_ref() == ANY) + { + return true; + } + } + false + } +} + +// 这个impl块在前面已经有了,删除这个重复的块 + +impl<'a> Display for SignalRef<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.signal) + } +} + +impl<'a> PartialEq for SignalRef<'a> { + fn eq(&self, other: &Self) -> bool { + self.signal == other.signal + } +} + +impl<'a> Eq for SignalRef<'a> {} + +#[cfg(feature = "python")] +impl<'py> FromPyObject<'py> for Signal { + fn extract_bound(ob: &Bound<'py, pyo3::PyAny>) -> PyResult { + // 如果是 str,直接解析 + if let Ok(s) = ob.extract::() { + let signal = Self::from_str(&s).map_err(|err| { + pyo3::exceptions::PyValueError::new_err(format!("无法解析 Signal:{err}")) + })?; + return Ok(signal); + } + + // 非法类型 + Err(pyo3::exceptions::PyValueError::new_err( + "期望 str:示例 'k1_k2_k3_v1_v2_v3_score'。", + )) + } +} + +impl<'a> Serialize for SignalRef<'a> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&format!("{self}")) + } +} + +impl<'de> Deserialize<'de> for SignalRef<'static> { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct SignalVisitor<'a> { + marker: std::marker::PhantomData<&'a ()>, + } + + impl<'de, 'a> Visitor<'de> for SignalVisitor<'a> { + type Value = SignalRef<'a>; + + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str("键(3 个段)_值(4 个段,最后一段是分数)") + } + + fn visit_str(self, s: &str) -> Result + where + E: serde::de::Error, + { + Signal::from_str(s).map_err(|e| E::custom(e)) + } + } + + deserializer.deserialize_str(SignalVisitor { + marker: std::marker::PhantomData, + }) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub(crate) struct ParsedSignalDoc { + /// 参数模板,例如: Some("{freq}_D1_表里关系V230101") + pub param_template: Option, + /// Signal 列表,顺序保留 + pub signals: Vec, +} + +/// 从 doc 中查找第一个参数模板以及所有 Signal('...') 字符串。 +/// 实现思路:手工扫描(find + 索引切片),兼容中/英引号。 +pub(crate) fn parse_signal_doc(doc: &str) -> ParsedSignalDoc { + let mut param_template: Option = None; + let mut signals: Vec = Vec::new(); + + // Helper: 给定起始索引,查找第一个引号并提取匹配的内容(支持中英文引号) + fn extract_quoted(s: &str, start: usize) -> Option<(String, usize)> { + // 支持的开引号及其对应的闭引号 + let pairs: &[(char, char)] = &[ + ('"', '"'), + ('\'', '\''), + ('"', '"'), + ('‘', '’'), + // 兼容性:有时候会只有右侧引号,仍然处理成成对查找同字符 + ('"', '"'), + ('’', '’'), + ]; + let hay = &s[start..]; + let chars = hay.char_indices(); + // 找到第一个开引号(在 pairs 中) + for (i, ch) in chars { + if let Some(&(_, closing)) = pairs.iter().find(|(o, _)| *o == ch) { + // 从 i+1 开始继续寻找 closing + let rest = &hay[i + ch.len_utf8()..]; + if let Some(j) = rest.find(closing) { + let content = &rest[..j]; + // 返回 content 以及全局结束索引 (start + i + len(open) + j + len(close)) + let end_idx = start + i + ch.len_utf8() + j + closing.len_utf8(); + return Some((content.to_string(), end_idx)); + } else { + // 未找到 matching closing,尝试继续扫描(放弃这个开引号) + continue; + } + } + } + None + } + + // 1) 提取参数模板:标签可能是 "参数模板:" 或 "参数模板:"(注意中文冒号) + let label_candidates = ["参数模板:", "参数模板:"]; + if let Some((label_start, label)) = label_candidates + .iter() + .filter_map(|lab| doc.find(lab).map(|p| (p, *lab))) + .min_by_key(|(p, _)| *p) + { + // 从 label 之后开始找首个引号内容 + let start_idx = label_start + label.len(); + if let Some((s, _end)) = extract_quoted(doc, start_idx) { + // 可能含有左右空白,trim + param_template = Some(s.trim().to_string()); + } + } + + // 2) 提取所有 Signal(...) 内容(支持 Signal('...') 或 Signal("...") + let mut pos = 0usize; + let needle = "Signal("; + while let Some(found) = doc[pos..].find(needle) { + let abs = pos + found + needle.len(); // index right after "(" + if let Some((content, end_idx)) = extract_quoted(doc, abs) { + let s = Signal::from_str(&content).ok(); + if let Some(signal) = s { + signals.push(signal); + } + pos = end_idx; // 继续从结束位置后搜索 + } else { + // 如果未能找到成对引号,跳过这个 "Signal(" 并继续向后查找,避免死循环 + pos = abs; + } + } + + ParsedSignalDoc { + param_template, + signals, + } +} + +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", gen_stub_pymethods)] +#[cfg_attr(feature = "python", pymethods)] +impl PySignal { + #[new] + #[pyo3(signature = (*args, signal = None, key = None, value = None, k1 = None, k2 = None, k3 = None, v1 = None, v2 = None, v3 = None, score = None))] + #[allow(clippy::too_many_arguments)] + fn new_py( + args: &Bound<'_, pyo3::types::PyTuple>, + signal: Option, + key: Option, + value: Option, + k1: Option, + k2: Option, + k3: Option, + v1: Option, + v2: Option, + v3: Option, + score: Option, + ) -> PyResult { + // 首先检查位置参数 + if args.len() == 1 { + // 如果有一个位置参数,当作signal字符串处理 + let signal_str: String = args.get_item(0)?.extract()?; + let inner = Signal::from_str(&signal_str) + .map_err(|e| PyValueError::new_err(format!("从位置参数创建Signal失败: {e}")))?; + return Ok(Self { inner }); + } else if args.len() > 1 { + return Err(PyValueError::new_err("Signal构造函数最多接受1个位置参数")); + } + // 方式1: 如果提供了signal参数,直接解析 + if let Some(signal_str) = signal { + let inner = Signal::from_str(&signal_str) + .map_err(|e| PyValueError::new_err(format!("从signal字符串创建Signal失败: {e}")))?; + return Ok(Self { inner }); + } + + // 方式2: 如果提供了key和value,组合创建 + if let (Some(key_str), Some(value_str)) = (key, value) { + // 重构key+value为完整的信号字符串 + let signal_str = format!("{key_str}_{value_str}"); + let inner = Signal::from_str(&signal_str) + .map_err(|e| PyValueError::new_err(format!("从key+value创建Signal失败: {e}")))?; + return Ok(Self { inner }); + } + + // 方式3: 如果提供了k1, k2, k3, v1, v2, v3, score,组合成signal字符串 + if k1.is_some() + || k2.is_some() + || k3.is_some() + || v1.is_some() + || v2.is_some() + || v3.is_some() + || score.is_some() + { + let k1_val = k1.unwrap_or_else(|| "任意".to_string()); + let k2_val = k2.unwrap_or_else(|| "任意".to_string()); + let k3_val = k3.unwrap_or_else(|| "任意".to_string()); + let v1_val = v1.unwrap_or_else(|| "任意".to_string()); + let v2_val = v2.unwrap_or_else(|| "任意".to_string()); + let v3_val = v3.unwrap_or_else(|| "任意".to_string()); + let score_val = score.unwrap_or(0); + + let signal_str = + format!("{k1_val}_{k2_val}_{k3_val}_{v1_val}_{v2_val}_{v3_val}_{score_val}"); + + let inner = Signal::from_str(&signal_str) + .map_err(|e| PyValueError::new_err(format!("从分字段创建Signal失败: {e}")))?; + return Ok(Self { inner }); + } + + // 默认情况:创建空的默认Signal + let default_signal = "任意_任意_任意_任意_任意_任意_0"; + let inner = Signal::from_str(default_signal) + .map_err(|e| PyValueError::new_err(format!("创建默认Signal失败: {e}")))?; + Ok(Self { inner }) + } + + #[classmethod] + fn from_string(_cls: &Bound<'_, pyo3::types::PyType>, s: String) -> PyResult { + let inner = Signal::from_str(&s).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!("从字符串解析Signal失败: {e}")) + })?; + Ok(Self { inner }) + } + + #[getter] + fn key(&self) -> String { + self.inner.key().to_string() + } + + #[getter] + fn value(&self) -> String { + self.inner.value().to_string() + } + + #[getter] + fn k3(&self) -> String { + self.inner.k3.to_string() + } + + #[getter] + fn v1(&self) -> String { + self.inner.v1.to_string() + } + + #[getter] + fn v2(&self) -> String { + self.inner.v2.to_string() + } + + #[getter] + fn v3(&self) -> String { + self.inner.v3.to_string() + } + + #[getter] + fn score(&self) -> i32 { + self.inner.score + } + + /// 新增k1和k2属性getter,匹配Python版本 + #[getter] + fn k1(&self) -> String { + self.inner.k1.to_string() + } + + #[getter] + fn k2(&self) -> String { + self.inner.k2.to_string() + } + + /// 添加to_json方法以匹配Python版本 + fn to_json(&self) -> String { + format!("{}", self.inner) + } + + fn __str__(&self) -> String { + format!("Signal('{}')", self.inner) + } + + fn __repr__(&self) -> String { + format!("Signal('{}')", self.inner) + } + + fn __eq__(&self, other: &Self) -> bool { + self.inner == other.inner + } + + fn __hash__(&self) -> u64 { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut hasher = DefaultHasher::new(); + self.inner.hash(&mut hasher); + hasher.finish() + } + + /// 检查Signal是否匹配另一个Signal + fn matches(&self, other: &Self) -> bool { + self.inner == other.inner + } + + /// 判断信号是否与信号字典中的值匹配(Python版本is_match逻辑) + fn is_match(&self, signals_dict: std::collections::HashMap) -> PyResult { + let key = self.inner.key(); + let value = signals_dict + .get(&key) + .ok_or_else(|| PyValueError::new_err(format!("{key} 不在信号列表中")))?; + + let parts: Vec<&str> = value.split('_').collect(); + if parts.len() != 4 { + return Err(PyValueError::new_err("信号值格式错误")); + } + + let v1 = parts[0]; + let v2 = parts[1]; + let v3 = parts[2]; + let score: i32 = parts[3] + .parse() + .map_err(|_| PyValueError::new_err("分数解析失败"))?; + + let self_v1 = self.inner.v1.as_ref(); + let self_v2 = self.inner.v2.as_ref(); + let self_v3 = self.inner.v3.as_ref(); + let self_score = self.inner.score; + + // Python版本匹配逻辑 + if score >= self_score + && (v1 == self_v1 || self_v1 == "任意") + && (v2 == self_v2 || self_v2 == "任意") + && (v3 == self_v3 || self_v3 == "任意") + { + return Ok(true); + } + + Ok(false) + } + + /// 获取Signal的完整字符串表示 + #[allow(clippy::inherent_to_string)] + fn to_string(&self) -> String { + format!("{}", self.inner) + } +} + +/// Python可见的ParsedSignalDoc包装器 +#[cfg_attr(feature = "python", gen_stub_pyclass)] +#[cfg_attr(feature = "python", pyclass(name = "ParsedSignalDoc", module = "czsc._native"))] +#[derive(Debug, Clone)] +pub struct PyParsedSignalDoc { + pub(crate) inner: ParsedSignalDoc, +} + +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", gen_stub_pymethods)] +#[cfg_attr(feature = "python", pymethods)] +impl PyParsedSignalDoc { + #[getter] + fn param_template(&self) -> Option { + self.inner.param_template.clone() + } + + #[getter] + fn signals(&self) -> Vec { + self.inner + .signals + .iter() + .map(|s| PySignal { inner: s.clone() }) + .collect() + } + + fn __repr__(&self) -> String { + format!( + "PyParsedSignalDoc(param_template={:?}, signals_count={})", + self.inner.param_template, + self.inner.signals.len() + ) + } +} + +/// 解析文档中的Signal信息 +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", gen_stub_pyfunction)] +#[pyfunction(name = "parse_signal_doc")] +pub fn parse_signal_doc_py(doc: String) -> PyParsedSignalDoc { + let inner = parse_signal_doc(&doc); + PyParsedSignalDoc { inner } +} diff --git a/crates/czsc-core/src/objects/state.rs b/crates/czsc-core/src/objects/state.rs new file mode 100644 index 000000000..178955adb --- /dev/null +++ b/crates/czsc-core/src/objects/state.rs @@ -0,0 +1,23 @@ +use crate::analyze::CZSC; +use crate::objects::position::Position; + +/// 交易员状态接口 +/// +/// 将 `CzscTrader` 的运行时状态抽象为 trait,允许 `czsc-signals` 中的 pos 系列 +/// 信号函数访问仓位和K线数据,而无需直接依赖 `czsc-trader` crate,从根本上解除 +/// czsc-signals ↔ czsc-trader 循环依赖。 +/// +/// 依赖链: +/// czsc-core (定义 TraderState) +/// ↑ +/// czsc-signals (pos.rs 使用 &dyn TraderState) +/// ↑ +/// czsc-trader (CzscTrader 实现 TraderState) +pub trait TraderState { + /// 按名称查询仓位 + fn get_position(&self, name: &str) -> Option<&Position>; + /// 按频率查询 CZSC 解析器 + fn get_czsc(&self, freq: &str) -> Option<&CZSC>; + /// 获取当前最新价格(通常为基础周期最新 close) + fn latest_price(&self) -> Option; +} diff --git a/crates/czsc-core/src/objects/zs.rs b/crates/czsc-core/src/objects/zs.rs new file mode 100644 index 000000000..f74d216b6 --- /dev/null +++ b/crates/czsc-core/src/objects/zs.rs @@ -0,0 +1,235 @@ +use super::{bi::BI, direction::Direction}; +#[cfg(feature = "python")] +use crate::utils::common::create_naive_pandas_timestamp; +use chrono::{DateTime, Utc}; +use core::f64; +use derive_builder::Builder; +#[cfg(feature = "python")] +use parking_lot::RwLock; +#[cfg(feature = "python")] +use pyo3::types::{PyDict, PyDictMethods}; +#[cfg(feature = "python")] +use pyo3::{Py, PyObject, PyResult, Python, pyclass, pymethods}; +#[cfg(feature = "python")] +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; +use std::sync::Arc; + +#[cfg_attr(feature = "python", gen_stub_pyclass)] +#[cfg_attr(feature = "python", pyclass(module = "czsc._native"))] +#[derive(Debug, Clone, Builder)] +pub struct ZS { + pub bis: Vec, + /// 中枢开始时间 + pub sdt: DateTime, + /// 中枢结束时间 + pub edt: DateTime, + /// 中枢第一笔方向,sdir 是 start direction 的缩写 + pub sdir: Direction, + /// 中枢倒一笔方向,edir 是 end direction 的缩写 + pub edir: Direction, + /// 中枢上沿 + pub zg: f64, + /// 中枢下沿 + pub zd: f64, + /// 中枢中轴 + pub zz: f64, + /// 中枢最高点 + pub gg: f64, + /// 中枢最低点 + pub dd: f64, + #[cfg(feature = "python")] + #[builder(default = "Arc::new(RwLock::new(None))")] + pub cache: Arc>>>, +} + +impl ZS { + pub fn new(bis: Vec) -> Self { + let sdt = bis.first().unwrap().start_dt(); + let edt = bis.last().unwrap().end_dt(); + let sdir = bis.first().unwrap().direction; + let edir = bis.last().unwrap().direction; + let zg = bis + .iter() + .take(3) + .map(|x| x.get_high()) + .fold(f64::INFINITY, f64::min); + let zd = bis + .iter() + .take(3) + .map(|x| x.get_low()) + .fold(f64::NEG_INFINITY, f64::max); + let gg = bis + .iter() + .map(|x| x.get_high()) + .fold(f64::NEG_INFINITY, f64::max); + let dd = bis + .iter() + .map(|x| x.get_low()) + .fold(f64::INFINITY, f64::min); + + let zz = zd + (zg - zd) * 0.5; + + ZS { + bis, + sdt, + edt, + sdir, + edir, + zg, + zd, + zz, + gg, + dd, + #[cfg(feature = "python")] + cache: Arc::new(RwLock::new(None)), + } + } + + /// 中枢是否有效 + pub fn is_valid(&self) -> bool { + let zg = self.zg; + let zd = self.zd; + if zg < zd { + return false; + } + + self.bis.iter().all(|bi| { + // 情况1: 笔的高点在中枢区间内 + let high_in_range = (bi.get_high() <= zg) && (bi.get_high() >= zd); + // 情况2: 笔的低点在中枢区间内 + let low_in_range = (bi.get_low() <= zg) && (bi.get_low() >= zd); + // 情况3: 笔完全包含中枢区间 + let contains_range = (bi.get_high() >= zg) && (bi.get_low() <= zd); + high_in_range || low_in_range || contains_range + }) + } +} +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", gen_stub_pymethods)] +#[cfg_attr(feature = "python", pymethods)] +impl ZS { + #[new] + fn new_py(bis: Vec) -> Self { + Self::new(bis) + } + + /// 获取构成中枢的笔列表 + #[getter] + fn bis(&self) -> Vec { + self.bis.clone() + } + + /// 中枢开始时间 + #[getter] + fn sdt(&self, py: Python) -> PyResult { + create_naive_pandas_timestamp(py, self.sdt) + } + + /// 中枢结束时间 + #[getter] + fn edt(&self, py: Python) -> PyResult { + create_naive_pandas_timestamp(py, self.edt) + } + + /// 中枢第一笔方向 + #[getter] + fn sdir(&self) -> Direction { + self.sdir + } + + /// 中枢倒一笔方向 + #[getter] + fn edir(&self) -> Direction { + self.edir + } + + /// 中枢上沿 + #[getter] + fn zg(&self) -> f64 { + self.zg + } + + /// 中枢下沿 + #[getter] + fn zd(&self) -> f64 { + self.zd + } + + /// 中枢中轴 + #[getter] + fn zz(&self) -> f64 { + self.zz + } + + /// 中枢最高点 + #[getter] + fn gg(&self) -> f64 { + self.gg + } + + /// 中枢最低点 + #[getter] + fn dd(&self) -> f64 { + self.dd + } + + /// 中枢是否有效 + #[pyo3(name = "is_valid")] + fn is_valid_py(&self) -> bool { + self.is_valid() + } + + #[getter] + fn get_cache<'py>(&'py self, py: Python<'py>) -> Py { + // 首先尝试读锁获取缓存 + { + let cache_read = self.cache.read(); + if let Some(ref cached_dict) = *cache_read { + return cached_dict.clone_ref(py); + } + } + + // 如果缓存为空,使用写锁初始化并填充所有属性 + let mut cache_write = self.cache.write(); + if cache_write.is_none() { + let dict = PyDict::new(py); + // 一次性填充所有属性,避免重复创建 + dict.set_item("sdt", create_naive_pandas_timestamp(py, self.sdt).unwrap()) + .unwrap(); + dict.set_item("edt", create_naive_pandas_timestamp(py, self.edt).unwrap()) + .unwrap(); + dict.set_item("sdir", self.sdir).unwrap(); + dict.set_item("edir", self.edir).unwrap(); + dict.set_item("zg", self.zg).unwrap(); + dict.set_item("zd", self.zd).unwrap(); + dict.set_item("zz", self.zz).unwrap(); + dict.set_item("gg", self.gg).unwrap(); + dict.set_item("dd", self.dd).unwrap(); + dict.set_item("bis", py.None()).unwrap(); // 复杂对象先设为None + *cache_write = Some(dict.unbind()); + } + cache_write.as_ref().unwrap().clone_ref(py) + } + + #[setter] + #[gen_stub(skip)] // 跳过为了防止和 get_cache重复 + fn set_cache(&self, dict: Py) { + let mut cache_write = self.cache.write(); + *cache_write = Some(dict); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::objects::bi::tests::create_bi; + + #[test] + fn test_new_zs() { + let bi1 = create_bi(); + + let zs1 = ZS::new(vec![bi1.clone(), bi1.clone(), bi1]); + + println!("{:?}", zs1.sdt); + } +} diff --git a/crates/czsc-core/src/python/mod.rs b/crates/czsc-core/src/python/mod.rs new file mode 100644 index 000000000..b7c74765e --- /dev/null +++ b/crates/czsc-core/src/python/mod.rs @@ -0,0 +1,115 @@ +//! PyO3 binding registry for czsc-core. +//! +//! Phase D's per-type sub-loops add `#[cfg_attr(feature = "python", pyclass)]` +//! to each migrated type. This module collects them into a single +//! `register()` entrypoint that `czsc-python` calls from the +//! `_native` aggregator. +//! +//! Pickle (`__getstate__` / `__setstate__`) per design doc §2.4 will +//! land on a follow-up pass once Phase E/F/G land and the per-class +//! identity tests can fully exercise it. + +use pyo3::prelude::*; + +use crate::analyze::CZSC; +use crate::analyze::utils as analyze_utils; +use crate::objects::bar::{NewBar, RawBar}; +use crate::objects::bi::BI; +use crate::objects::direction::Direction; +use crate::objects::event::PyEvent; +use crate::objects::fake_bi::FakeBI; +use crate::objects::freq::Freq; +use crate::objects::fx::FX; +use crate::objects::mark::Mark; +use crate::objects::market::Market; +use crate::objects::operate::PyOperate; +use crate::objects::position::{PyLiteBar, PyPos, PyPosition}; +use crate::objects::signal::{PyParsedSignalDoc, PySignal, parse_signal_doc_py}; +use crate::objects::zs::ZS; + +/// Python-friendly thin wrapper around `analyze::utils::check_fx`. +#[pyfunction] +#[pyo3(name = "check_fx")] +fn check_fx_py(k1: NewBar, k2: NewBar, k3: NewBar) -> Option { + analyze_utils::check_fx(&k1, &k2, &k3) +} + +/// Python-friendly thin wrapper around `analyze::utils::check_fxs`. +#[pyfunction] +#[pyo3(name = "check_fxs")] +fn check_fxs_py(bars: Vec) -> Vec { + analyze_utils::check_fxs(&bars) +} + +/// Python-friendly thin wrapper around `analyze::utils::check_bi`. +/// Drops the unused remainder slice; Python callers only ever consume +/// the optional BI value. +#[pyfunction] +#[pyo3(name = "check_bi")] +fn check_bi_py(bars: Vec) -> Option { + let (bi, _) = analyze_utils::check_bi(&bars); + bi +} + +/// Python-friendly thin wrapper around `analyze::utils::remove_include`. +#[pyfunction] +#[pyo3(name = "remove_include")] +fn remove_include_py(k1: NewBar, k2: NewBar, k3: RawBar) -> PyResult<(bool, NewBar)> { + analyze_utils::remove_include(&k1, &k2, k3) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string())) +} + +/// Python-friendly thin wrapper around `analyze::utils::format_standard_kline`. +/// Polars DataFrame is bridged via the standard pyo3-polars / arrow path; for +/// now we accept a list of pre-built RawBars to avoid the polars/python coupling +/// during D.A. The full DataFrame entrypoint will be added when Phase E/F wire +/// the polars Python bridge (see design doc §2.3). +#[pyfunction] +#[pyo3(name = "format_standard_kline")] +fn format_standard_kline_py(bars: Vec) -> Vec { + bars +} + +/// Add the migrated czsc-core types onto the parent module that czsc-python +/// passes in. Lives behind the `python` feature so plain Rust consumers +/// don't pull pyo3 in transitively. +pub fn register(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { + // Enums + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + // Bar primitives + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + // Chan-theory data structures + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + // Signal / Event / Position + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + // Analyzer (CZSC) + m.add_class::()?; + + // Free functions: signal-doc parser + analyze helpers (the 4 promotions + // from design doc §2.5) + m.add_function(wrap_pyfunction!(parse_signal_doc_py, m)?)?; + m.add_function(wrap_pyfunction!(check_fx_py, m)?)?; + m.add_function(wrap_pyfunction!(check_fxs_py, m)?)?; + m.add_function(wrap_pyfunction!(check_bi_py, m)?)?; + m.add_function(wrap_pyfunction!(remove_include_py, m)?)?; + m.add_function(wrap_pyfunction!(format_standard_kline_py, m)?)?; + + Ok(()) +} diff --git a/crates/czsc-core/src/utils/common.rs b/crates/czsc-core/src/utils/common.rs new file mode 100644 index 000000000..d124cf698 --- /dev/null +++ b/crates/czsc-core/src/utils/common.rs @@ -0,0 +1,119 @@ +use crate::objects::freq::Freq; +use chrono::{DateTime, Utc}; + +#[cfg(feature = "python")] +use crate::objects::errors::ObjectError; +#[cfg(feature = "python")] +use anyhow::anyhow; +#[cfg(feature = "python")] +use pyo3::prelude::*; +#[cfg(feature = "python")] +use pyo3::types::PyAnyMethods; + +/// 创建 OrderedDict,与 czsc 库兼容 +#[cfg(feature = "python")] +pub fn create_ordered_dict(py: Python) -> PyResult { + let collections = py.import("collections")?; + let ordered_dict = collections.getattr("OrderedDict")?; + let result = ordered_dict.call0()?; + Ok(result.into()) +} + +/// 创建不带时区信息的 pandas Timestamp,与原版CZSC保持一致 +#[cfg(feature = "python")] +pub fn create_naive_pandas_timestamp(py: Python, dt: DateTime) -> PyResult { + let pandas = py.import("pandas")?; + let timestamp_cls = pandas.getattr("Timestamp")?; + let dt_naive = dt.naive_utc(); + let iso_string = dt_naive.format("%Y-%m-%d %H:%M:%S").to_string(); + + // 创建不带时区的naive时间戳 + let ts = timestamp_cls.call((iso_string,), None)?; + + Ok(ts.into()) +} + +/// 通用的日期时间解析函数,支持多种Python日期时间格式 +/// 这个函数被RawBar、NewBar和FX共同使用,避免重复代码 +#[cfg(feature = "python")] +pub fn parse_python_datetime(dt: &Bound) -> PyResult> { + // 尝试解析dt参数,支持多种输入格式 + let datetime_utc = if dt.hasattr("timestamp")? { + // 如果是Python datetime对象(有timestamp方法) + let timestamp = dt.call_method0("timestamp")?; + let timestamp_f64: f64 = timestamp.extract()?; + DateTime::from_timestamp( + timestamp_f64 as i64, + (timestamp_f64.fract() * 1_000_000_000.0) as u32, + ) + .ok_or(ObjectError::Unexpected(anyhow!( + "Invalid datetime for building object" + )))? + } else if dt.hasattr("tz_localize")? { + // 如果是pandas Timestamp,可能没有时区信息 + let localized_dt = if dt.getattr("tz")?.is_none() { + // 如果没有时区,添加UTC时区 + dt.call_method1("tz_localize", ("UTC",))? + } else { + dt.clone() + }; + let timestamp = localized_dt.call_method0("timestamp")?; + let timestamp_f64: f64 = timestamp.extract()?; + DateTime::from_timestamp( + timestamp_f64 as i64, + (timestamp_f64.fract() * 1_000_000_000.0) as u32, + ) + .ok_or(ObjectError::Unexpected(anyhow!( + "Invalid datetime for building object" + )))? + } else if let Ok(timestamp) = dt.extract::() { + // 如果是时间戳(保持向后兼容) + DateTime::from_timestamp(timestamp, 0).ok_or(ObjectError::Unexpected(anyhow!( + "Invalid timestamp for building object" + )))? + } else if let Ok(timestamp_f64) = dt.extract::() { + // 如果是浮点数时间戳 + DateTime::from_timestamp( + timestamp_f64 as i64, + (timestamp_f64.fract() * 1_000_000_000.0) as u32, + ) + .ok_or(ObjectError::Unexpected(anyhow!( + "Invalid timestamp for building object" + )))? + } else { + return Err(ObjectError::Unexpected(anyhow!( + "dt parameter must be a Python datetime object, pandas Timestamp, integer timestamp, or float timestamp" + )) + .into()); + }; + + Ok(datetime_utc) +} + +/// 将频率枚举转换为中文字符串 +/// 这个函数被多个结构体的Python绑定共同使用,避免重复代码 +pub fn freq_to_chinese_string(freq: Freq) -> &'static str { + match freq { + Freq::Tick => "Tick", + Freq::F1 => "1分钟", + Freq::F2 => "2分钟", + Freq::F3 => "3分钟", + Freq::F4 => "4分钟", + Freq::F5 => "5分钟", + Freq::F6 => "6分钟", + Freq::F10 => "10分钟", + Freq::F12 => "12分钟", + Freq::F15 => "15分钟", + Freq::F20 => "20分钟", + Freq::F30 => "30分钟", + Freq::F60 => "60分钟", + Freq::F120 => "120分钟", + Freq::F240 => "240分钟", + Freq::F360 => "360分钟", + Freq::D => "日线", + Freq::W => "周线", + Freq::M => "月线", + Freq::S => "季线", + Freq::Y => "年线", + } +} diff --git a/crates/czsc-core/src/utils/corr.rs b/crates/czsc-core/src/utils/corr.rs new file mode 100644 index 000000000..df8e20544 --- /dev/null +++ b/crates/czsc-core/src/utils/corr.rs @@ -0,0 +1,225 @@ +use super::rounded::RoundToNthDigit; + +#[derive(Debug)] +pub struct LinearResult { + /// 标识斜率 + pub slope: f64, + /// 截距 + pub intercept: f64, + /// 拟合优度 + pub r2: f64, +} + +/// 单变量线性拟合 +pub trait LinearRegression { + fn single_linear(&self) -> LinearResult; +} + +impl LinearRegression for [f64] { + /// 单变量线性拟合 + /// https://en.wikipedia.org/wiki/Linear_regression + fn single_linear(&self) -> LinearResult { + if self.is_empty() { + return LinearResult { + slope: 0.0, + intercept: 0.0, + r2: 0.0, + }; + } + + // 数据点的数量 + let sample_size = self.len() as f64; + + // x 值的总和, 隐式生成的索引序列 [0, 1, 2, ..., n-1],其和为等差数列公式 (n-1)*n/2 + let sum_x = (sample_size - 1.0) * sample_size / 2.0; + // x 值的平方和,使用公式 (n-1)*n*(2n-1)/6,对应索引序列的平方和 + let sum_x_squared = (sample_size - 1.0) * sample_size * (2.0 * sample_size - 1.0) / 6.0; + + // sum_xy: x 和 y 的乘积之和(Σ(x*y)) + let (sum_xy, sum_y) = self.iter().enumerate().fold((0.0, 0.0), |acc, (i, &y1)| { + (acc.0 + (i as f64) * y1, acc.1 + y1) + }); + + // 线性回归斜率公式的分母部分: n*Σx² - (Σx)^2 + let denominator = sample_size * sum_x_squared - sum_x * sum_x; + if denominator == 0.0 { + return LinearResult { + slope: 0.0, + intercept: 0.0, + r2: 0.0, + }; + } + + let y_intercept = (1.0 / denominator) * (sum_x_squared * sum_y - sum_x * sum_xy); + let slope = (1.0 / denominator) * (sample_size * sum_xy - sum_x * sum_y); + + let y_mean = sum_y / sample_size; + let (ss_tot, ss_err) = self.iter().enumerate().fold((0.0, 0.0), |acc, (i, &y1)| { + let y_diff = y1 - y_mean; + let predicted = slope * (i as f64) + y_intercept; + let err = y1 - predicted; + (acc.0 + y_diff * y_diff, acc.1 + err * err) + }); + + let rsq = 1.0 - ss_err / (ss_tot + 0.00001); + + LinearResult { + slope: slope.round_to_4_digit(), + intercept: y_intercept.round_to_4_digit(), + r2: rsq.round_to_4_digit(), + } + } +} + +/// ## 计算两个向量的Pearson Corr +/// +/// [wiki](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient) +/// +/// - 当数据为空或长度不一致时返回 None。 +/// +/// - 分母为0时返回 None避免除0错误 +pub fn pearson_corr(x: &[f64], y: &[f64]) -> Option { + if x.len() != y.len() || x.is_empty() { + return None; + } + + let n = x.len(); + // x̄ = (Σxᵢ) / n + let mean_x = x.iter().sum::() / n as f64; + // ȳ = (Σyᵢ) / n + let mean_y = y.iter().sum::() / n as f64; + + // cov(X, Y): Σ(xᵢ - x̄)(yᵢ - ȳ) + let mut cov = 0.0; + // Σ(xᵢ - x̄)² + let mut sum_x_sq = 0.0; + // Σ(yᵢ - ȳ)² + let mut sum_y_sq = 0.0; + + for i in 0..n { + let diff_x = x[i] - mean_x; + let diff_y = y[i] - mean_y; + cov += diff_x * diff_y; + + sum_x_sq += diff_x * diff_x; + sum_y_sq += diff_y * diff_y; + } + + let denominator = (sum_x_sq * sum_y_sq).sqrt(); + if denominator == 0.0 { + // 防止分母为0 + return None; + } + + // cov(X, Y) / (σXσY) + Some(cov / denominator) +} + +/// 计算两个向量的Spearman Rank Corr +/// +/// [wiki](https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient) +/// +/// > Spearman's coefficient is appropriate for both continuous and discrete ordinal variables. +/// +/// - 当数据为空或长度不一致时返回 None。 +pub fn spearman_rank_corr(x: &[f64], y: &[f64]) -> Option { + if x.len() != y.len() || x.is_empty() { + return None; + } + + fn compute_ranks(data: &[f64]) -> Vec { + // 将数据与原始索引绑定:[(值, 原始索引)] + let mut indexed_data: Vec<(f64, usize)> = + data.iter().enumerate().map(|(i, &val)| (val, i)).collect(); + + // 按值排序(从小到大) + indexed_data.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); + + let n = data.len(); + let mut ranks = vec![0.0; n]; + let mut i = 0; + + // 遍历排序后的数据,处理重复值 + while i < n { + let current_val = indexed_data[i].0; + let mut j = i; + + // 找到所有与当前值相同的元素(重复值的结束位置) + while j < n && indexed_data[j].0 == current_val { + j += 1; + } + + // 计算平均排名(从1开始) + let avg_rank = (i + 1 + j) as f64 / 2.0; + + // 将平均Rank赋给所有重复值的原始索引 + for item in &indexed_data[i..j] { + let original_index = item.1; + ranks[original_index] = avg_rank; + } + + // 跳过已处理的重复值 + i = j; + } + + ranks + } + + // R[X] + let x_ranks = compute_ranks(x); + // R[Y] + let y_ranks = compute_ranks(y); + + // cov(R[X], R[Y]) / (σR[X]σR[Y]) + pearson_corr(&x_ranks, &y_ranks) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_single_linear() { + let y = [1.0, 2.0, 3.0, 4.0, 5.0]; + + let res2 = y.single_linear(); + + println!("{res2:?}"); + } + + ///``` + /// import pandas as pd + /// a = [1.0, 2.0, -3.0, -1.0, 5.0, 1.0, -7.0, 6.0, 16.0] + /// b = [1.0, 2.0, -3.0, -1.0, 5.0, 1.0, -7.0, -6.0, 16.0] + /// df = pd.DataFrame({'col_a': a, 'col_b': b}) + /// df['col_a'].corr(df['col_b']) + ///``` + /// np.float64(0.8214654924226573) + #[test] + fn test_pearson_corr() { + let a = [1.0, 2.0, -3.0, -1.0, 5.0, 1.0, -7.0, 6.0, 16.0]; + let b = [1.0, 2.0, -3.0, -1.0, 5.0, 1.0, -7.0, -6.0, 16.0]; + assert_eq!( + pearson_corr(&a, &b).map(|x| x.round_to_4_digit()), + Some(0.8215) + ) + } + + ///``` + /// import pandas as pd + /// a = [1.0, 2.0, -3.0, -1.0, 5.0, 1.0, -7.0, 6.0, 16.0] + /// b = [1.0, 2.0, -3.0, -1.0, 5.0, 1.0, -7.0, -6.0, 16.0] + /// df = pd.DataFrame({'col_a': a, 'col_b': b}) + /// df['col_a'].corr(df['col_b'], method="spearman") + ///``` + /// np.float64(0.6470588235294119) + #[test] + fn test_spearman_rank_corr() { + let a = [1.0, 2.0, -3.0, -1.0, 5.0, 1.0, -7.0, 6.0, 16.0]; + let b = [1.0, 2.0, -3.0, -1.0, 5.0, 1.0, -7.0, -6.0, 16.0]; + assert_eq!( + spearman_rank_corr(&a, &b).map(|x| x.round_to_4_digit()), + Some(0.6471) + ) + } +} diff --git a/crates/czsc-core/src/utils/errors.rs b/crates/czsc-core/src/utils/errors.rs new file mode 100644 index 000000000..b905d3cfe --- /dev/null +++ b/crates/czsc-core/src/utils/errors.rs @@ -0,0 +1,13 @@ +use error_macros::CZSCErrorDerive; +use error_support::expand_error_chain; +use polars::error::PolarsError; +use thiserror::Error; + +#[derive(Debug, Error, CZSCErrorDerive)] +pub enum CoreUtilsErorr { + #[error("Polars: {0}")] + Polars(#[from] PolarsError), + + #[error("{}", expand_error_chain(.0))] + Unexpected(anyhow::Error), +} diff --git a/crates/czsc-core/src/utils/mod.rs b/crates/czsc-core/src/utils/mod.rs new file mode 100644 index 000000000..baa322f57 --- /dev/null +++ b/crates/czsc-core/src/utils/mod.rs @@ -0,0 +1,7 @@ +//! Internal utilities for czsc-core. Migration is incremental: members +//! land as their consumers (BI / ZS / ...) get migrated. + +pub mod common; +pub mod corr; +pub mod errors; +pub mod rounded; diff --git a/crates/czsc-core/src/utils/rounded.rs b/crates/czsc-core/src/utils/rounded.rs new file mode 100644 index 000000000..e0665d2d1 --- /dev/null +++ b/crates/czsc-core/src/utils/rounded.rs @@ -0,0 +1,36 @@ +pub trait RoundToNthDigit { + fn round_to_nth_digit(&self, nth: usize) -> Self; + fn round_to_2_digit(&self) -> Self; + fn round_to_3_digit(&self) -> Self; + fn round_to_4_digit(&self) -> Self; +} + +impl RoundToNthDigit for f64 { + fn round_to_nth_digit(&self, nth: usize) -> f64 { + let scale = 10_f64.powi(nth as i32); + let scaled = *self * scale; + (scaled).round() / scale + } + + fn round_to_2_digit(&self) -> f64 { + (*self * 100.0).round() / 100.0 + } + + fn round_to_3_digit(&self) -> f64 { + (*self * 1000.0).round() / 1000.0 + } + + fn round_to_4_digit(&self) -> f64 { + (*self * 10000.0).round() / 10000.0 + } +} + +pub fn min_max(x: f64, min_val: f64, max_val: f64) -> f64 { + if x < min_val { + min_val + } else if x > max_val { + max_val + } else { + x + } +} diff --git a/crates/czsc-core/tests/test_analyze_error.rs b/crates/czsc-core/tests/test_analyze_error.rs new file mode 100644 index 000000000..7aaa224a7 --- /dev/null +++ b/crates/czsc-core/tests/test_analyze_error.rs @@ -0,0 +1,29 @@ +//! Phase D.E — RED test: AnalyzeErorr formats with thiserror, accepts +//! a PolarsError via the From blanket, and round-trips through serde +//! via CZSCErrorDerive. + +use czsc_core::analyze::errors::AnalyzeErorr; + +#[test] +fn from_polars_routes_to_polars_variant() { + use polars::error::PolarsError; + let pe = PolarsError::ComputeError("compute boom".into()); + let err: AnalyzeErorr = pe.into(); + assert!(matches!(err, AnalyzeErorr::Polars(_))); + assert!(err.to_string().contains("compute boom")); +} + +#[test] +fn from_anyhow_routes_to_unexpected() { + let any: anyhow::Error = anyhow::anyhow!("ka-boom"); + let err: AnalyzeErorr = any.into(); + assert!(matches!(err, AnalyzeErorr::Unexpected(_))); +} + +#[test] +fn serialize_emits_string_payload() { + let pe = polars::error::PolarsError::ComputeError("x".into()); + let err: AnalyzeErorr = pe.into(); + let json = serde_json::to_string(&err).unwrap(); + assert!(json.starts_with('"'), "expected JSON string, got {json}"); +} diff --git a/crates/czsc-core/tests/test_analyze_utils.rs b/crates/czsc-core/tests/test_analyze_utils.rs new file mode 100644 index 000000000..c86218f74 --- /dev/null +++ b/crates/czsc-core/tests/test_analyze_utils.rs @@ -0,0 +1,114 @@ +//! Phase D.U — RED test: analyze::utils helpers (check_fx / check_fxs / +//! check_bi / remove_include / format_standard_kline) are publicly callable +//! and produce the expected shapes per the rs-czsc 47ef6efa baseline. +//! +//! This test also locks the visibility promotions required by the design +//! doc §2.5: all four `pub(crate)` helpers must now be `pub`. + +use std::sync::Arc; + +use chrono::{TimeZone, Utc}; +use czsc_core::analyze::utils::{check_bi, check_fx, check_fxs, format_standard_kline}; +use czsc_core::objects::bar::{NewBar, NewBarBuilder, RawBar}; +use czsc_core::objects::freq::Freq; +use czsc_core::objects::mark::Mark; + +fn nb(ts: i64, high: f64, low: f64) -> NewBar { + NewBarBuilder::default() + .symbol(Arc::::from("000001")) + .dt(Utc.timestamp_opt(ts, 0).unwrap()) + .freq(Freq::F30) + .id(0) + .open((high + low) / 2.0) + .close((high + low) / 2.0) + .high(high) + .low(low) + .vol(100.0) + .amount(100.0 * (high + low) / 2.0) + .elements(Vec::new()) + .build() + .unwrap() +} + +#[test] +fn check_fx_detects_top_pattern() { + // top fx: middle bar engulfs both neighbours from above + let k1 = nb(1, 11.0, 9.0); + let k2 = nb(2, 12.0, 10.0); + let k3 = nb(3, 11.5, 9.5); + let fx = check_fx(&k1, &k2, &k3).expect("expected top fx"); + assert_eq!(fx.mark, Mark::G); + assert!((fx.fx - 12.0).abs() < f64::EPSILON); +} + +#[test] +fn check_fx_detects_bottom_pattern() { + let k1 = nb(1, 11.0, 9.5); + let k2 = nb(2, 10.5, 8.0); + let k3 = nb(3, 11.0, 9.0); + let fx = check_fx(&k1, &k2, &k3).expect("expected bottom fx"); + assert_eq!(fx.mark, Mark::D); + assert!((fx.fx - 8.0).abs() < f64::EPSILON); +} + +#[test] +fn check_fx_returns_none_when_no_pattern() { + // strictly increasing — neither top nor bottom + let k1 = nb(1, 10.0, 9.0); + let k2 = nb(2, 11.0, 10.0); + let k3 = nb(3, 12.0, 11.0); + assert!(check_fx(&k1, &k2, &k3).is_none()); +} + +#[test] +fn check_fxs_extracts_fx_from_sequence() { + // 5 bars: ascending, peak, descending → exactly one top fx in the middle + let bars = vec![ + nb(1, 10.0, 9.0), + nb(2, 11.0, 10.0), + nb(3, 12.0, 11.0), + nb(4, 11.5, 10.5), + nb(5, 11.0, 10.0), + ]; + let fxs = check_fxs(&bars); + assert!(!fxs.is_empty(), "expected at least one fx in peak sequence"); +} + +#[test] +fn check_bi_returns_tuple_with_remainder() { + let bars: Vec = (0..6).map(|i| nb(i + 1, 10.0 + i as f64, 9.0 + i as f64)).collect(); + let (bi, remainder) = check_bi(&bars); + // The function signature contract: always returns (Option, &[NewBar]) + let _ = bi; + assert!(remainder.len() <= bars.len()); +} + +#[test] +fn format_standard_kline_builds_raw_bars_from_dataframe() { + use polars::prelude::*; + + let df = df! { + "symbol" => ["000001", "000001", "000001"], + "dt" => [1_700_000_000_000i64, 1_700_001_800_000, 1_700_003_600_000], + "open" => [10.0_f64, 10.5, 11.0], + "close" => [10.5_f64, 11.0, 10.8], + "high" => [11.0_f64, 11.5, 11.2], + "low" => [9.5_f64, 10.0, 10.5], + "vol" => [100.0_f64, 200.0, 150.0], + "amount" => [1000.0_f64, 2000.0, 1500.0], + } + .unwrap() + .lazy() + .with_column( + col("dt") + .cast(DataType::Datetime(TimeUnit::Milliseconds, None)) + .alias("dt"), + ) + .collect() + .unwrap(); + + let bars: Vec = format_standard_kline(df, Freq::F30).unwrap(); + assert_eq!(bars.len(), 3); + assert_eq!(&*bars[0].symbol, "000001"); + assert_eq!(bars[0].freq, Freq::F30); +} diff --git a/crates/czsc-core/tests/test_bi.rs b/crates/czsc-core/tests/test_bi.rs new file mode 100644 index 000000000..ab7a4d08b --- /dev/null +++ b/crates/czsc-core/tests/test_bi.rs @@ -0,0 +1,95 @@ +//! Phase D.8 — RED test: BI (笔) constructs via BIBuilder, surfaces +//! direction / endpoints, and answers length / SNR / power_price helpers. + +use std::sync::Arc; + +use chrono::{TimeZone, Utc}; +use czsc_core::objects::bar::{NewBar, NewBarBuilder}; +use czsc_core::objects::bi::{BI, BIBuilder}; +use czsc_core::objects::direction::Direction; +use czsc_core::objects::freq::Freq; +use czsc_core::objects::fx::{FX, FXBuilder}; +use czsc_core::objects::mark::Mark; + +fn nb(ts: i64, high: f64, low: f64, vol: f64) -> NewBar { + NewBarBuilder::default() + .symbol(Arc::::from("000001")) + .dt(Utc.timestamp_opt(ts, 0).unwrap()) + .freq(Freq::F30) + .id(0) + .open((high + low) / 2.0) + .close((high + low) / 2.0) + .high(high) + .low(low) + .vol(vol) + .amount(vol * (high + low) / 2.0) + .elements(Vec::new()) + .build() + .unwrap() +} + +fn fx(ts: i64, mark: Mark, level: f64) -> FX { + let k1 = nb(ts - 1800, level - 0.5, level - 1.5, 100.0); + let k2 = nb(ts, level + 0.5, level - 0.5, 200.0); + let k3 = nb(ts + 1800, level - 0.2, level - 1.0, 100.0); + FXBuilder::default() + .symbol(Arc::::from("000001")) + .dt(Utc.timestamp_opt(ts, 0).unwrap()) + .mark(mark) + .high(k2.high) + .low(k2.low) + .fx(if matches!(level, l if l > 5.0) { k2.high } else { k2.low }) + .elements(vec![k1, k2, k3]) + .build() + .unwrap() +} + +fn sample_bi_up() -> BI { + // up bi: starts at bottom fx, ends at top fx + let fx_a = fx(1_700_000_000, Mark::D, 9.0); + let fx_b = fx(1_700_007_200, Mark::G, 12.0); + let bars: Vec = (0..5) + .map(|i| nb(1_700_000_000 + i * 1800, 11.0 + i as f64 * 0.2, 9.5 + i as f64 * 0.2, 100.0)) + .collect(); + BIBuilder::default() + .symbol(Arc::::from("000001")) + .fx_a(fx_a) + .fx_b(fx_b.clone()) + .fxs(vec![fx_b]) + .direction(Direction::Up) + .bars(bars) + .build() + .unwrap() +} + +#[test] +fn builder_populates_fields() { + let bi = sample_bi_up(); + assert_eq!(bi.direction, Direction::Up); + assert_eq!(bi.bars.len(), 5); +} + +#[test] +fn length_is_bars_count() { + let bi = sample_bi_up(); + assert_eq!(bi.get_length(), 5); +} + +#[test] +fn high_low_endpoints_match_fxs() { + let bi = sample_bi_up(); + assert!(bi.get_low() < bi.get_high(), "low must be < high"); +} + +#[test] +fn power_price_is_finite() { + let bi = sample_bi_up(); + assert!(bi.get_power_price().is_finite()); +} + +#[test] +fn equality_matches_rs_czsc_baseline() { + let a = sample_bi_up(); + let b = sample_bi_up(); + assert_eq!(a, b); +} diff --git a/crates/czsc-core/tests/test_czsc_analyzer.rs b/crates/czsc-core/tests/test_czsc_analyzer.rs new file mode 100644 index 000000000..4867386ce --- /dev/null +++ b/crates/czsc-core/tests/test_czsc_analyzer.rs @@ -0,0 +1,101 @@ +//! Phase D.A — RED test: CZSC analyzer constructs from a RawBar feed, +//! exposes bars_raw / bars_ubi / bi_list / fx_list, and survives an +//! incremental update_bar feed. + +use std::sync::Arc; + +use chrono::{TimeZone, Utc}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::bar::{RawBar, RawBarBuilder}; +use czsc_core::objects::freq::Freq; + +fn rb(ts: i64, open: f64, close: f64, high: f64, low: f64) -> RawBar { + RawBarBuilder::default() + .symbol(Arc::::from("000001")) + .dt(Utc.timestamp_opt(ts, 0).unwrap()) + .freq(Freq::F30) + .id(0) + .open(open) + .close(close) + .high(high) + .low(low) + .vol(1000.0_f64) + .amount(1_000_000.0_f64) + .build() + .unwrap() +} + +fn synthetic_zigzag(n: usize) -> Vec { + // Build a sine-like zigzag so that the analyzer can produce fxs/bis. + (0..n) + .map(|i| { + let phase = (i as f64) * 0.7; + let mid = 100.0 + 5.0 * phase.sin(); + let half = 1.0 + 0.5 * phase.cos().abs(); + rb( + 1_700_000_000 + (i as i64) * 1800, + mid - 0.2, + mid + 0.2, + mid + half, + mid - half, + ) + }) + .collect() +} + +#[test] +fn new_populates_symbol_and_freq() { + let bars = synthetic_zigzag(50); + let c = CZSC::new(bars, 50); + assert_eq!(&*c.symbol, "000001"); + assert_eq!(c.freq, Freq::F30); + assert_eq!(c.max_bi_num, 50); +} + +#[test] +fn new_consumes_all_bars_and_builds_ubi() { + let bars = synthetic_zigzag(40); + let c = CZSC::new(bars, 50); + // bars_ubi is the merged-bar (NewBar) sequence; for 40 raw zigzag + // bars we expect non-empty merged sequence + assert!(!c.bars_ubi.is_empty(), "bars_ubi should not be empty"); +} + +#[test] +fn fx_and_bi_lists_are_consistent_with_zigzag() { + let bars = synthetic_zigzag(60); + let c = CZSC::new(bars, 50); + let fxs = c.get_fx_list(); + // A 60-bar zigzag should produce at least 2 fxs (or zero — the + // exact count depends on the synthetic shape; we only assert + // non-negative invariants). + assert!(fxs.len() <= 60); + assert!(c.bi_list.len() <= 50); +} + +#[test] +fn update_bar_appends_incrementally() { + let bars = synthetic_zigzag(30); + let mut c = CZSC::new(bars, 50); + let extra = rb( + 1_700_000_000 + 30 * 1800, + 102.0, + 103.0, + 104.0, + 101.0, + ); + c.update_bar(extra); + assert_eq!(c.freq, Freq::F30); + // bars_raw monotonically grows (modulo the analyzer's internal pruning) + assert!(c.bars_raw.iter().any(|b| b.dt + == Utc.timestamp_opt(1_700_000_000 + 30 * 1800, 0).unwrap())); +} + +#[test] +fn analyzer_clones_independently() { + let bars = synthetic_zigzag(20); + let c = CZSC::new(bars, 50); + let d = c.clone(); + assert_eq!(d.bi_list.len(), c.bi_list.len()); + assert_eq!(&*d.symbol, &*c.symbol); +} diff --git a/crates/czsc-core/tests/test_event.rs b/crates/czsc-core/tests/test_event.rs new file mode 100644 index 000000000..ec9d79f35 --- /dev/null +++ b/crates/czsc-core/tests/test_event.rs @@ -0,0 +1,78 @@ +//! Phase D.10c — RED test: Event struct constructs from Operate + signals, +//! computes a stable sha8 hash, refresh_hash_name keeps it in sync, and +//! all_signals iterates the union. + +use std::str::FromStr; + +use czsc_core::objects::event::Event; +use czsc_core::objects::operate::Operate; +use czsc_core::objects::signal::Signal; + +fn sample_event() -> Event { + Event { + operate: Operate::LO, + signals_all: vec![ + Signal::from_str("30分钟_D1_前高_看多_强_任意_0").unwrap(), + ], + signals_any: vec![ + Signal::from_str("日线_D1_趋势_看多_中_任意_0").unwrap(), + Signal::from_str("日线_D2_趋势_看多_弱_任意_0").unwrap(), + ], + signals_not: vec![], + name: String::new(), + sha256: String::new(), + } +} + +#[test] +fn struct_holds_signals() { + let e = sample_event(); + assert_eq!(e.signals_all.len(), 1); + assert_eq!(e.signals_any.len(), 2); + assert_eq!(e.signals_not.len(), 0); + assert_eq!(e.operate, Operate::LO); +} + +#[test] +fn compute_sha8_returns_4_hex_chars() { + let e = sample_event(); + let h = e.compute_sha8(); + assert_eq!(h.len(), 4, "sha8 prefix must be 4 chars, got {h:?}"); + assert!(h.chars().all(|c| c.is_ascii_hexdigit() && c.is_ascii_uppercase() || c.is_ascii_digit()), + "expected uppercase hex, got {h:?}"); +} + +#[test] +fn compute_sha8_is_deterministic() { + let e = sample_event(); + let a = e.compute_sha8(); + let b = e.compute_sha8(); + assert_eq!(a, b, "sha8 must be a pure function of the event payload"); +} + +#[test] +fn refresh_hash_name_updates_sha256_field() { + let mut e = sample_event(); + assert_eq!(e.sha256, ""); + e.refresh_hash_name(); + assert!(!e.sha256.is_empty(), "refresh should populate sha256"); +} + +#[test] +fn all_signals_iterates_union() { + let e = sample_event(); + let all: Vec<_> = e.all_signals().collect(); + // 1 (signals_all) + 2 (signals_any) + 0 (signals_not) = 3 + assert_eq!(all.len(), 3); +} + +#[test] +fn dump_load_roundtrip_via_json() { + let mut e = sample_event(); + e.refresh_hash_name(); + let dumped = e.dump(); + let loaded = Event::load(&dumped).unwrap(); + assert_eq!(loaded.operate, e.operate); + assert_eq!(loaded.signals_all.len(), e.signals_all.len()); + assert_eq!(loaded.signals_any.len(), e.signals_any.len()); +} diff --git a/crates/czsc-core/tests/test_freq.rs b/crates/czsc-core/tests/test_freq.rs new file mode 100644 index 000000000..6e9a1f9e2 --- /dev/null +++ b/crates/czsc-core/tests/test_freq.rs @@ -0,0 +1,59 @@ +//! Phase D.3 — RED test: Freq enum FromStr / Display, minutes() / is_minute_freq(). + +use std::str::FromStr; + +use czsc_core::objects::freq::Freq; + +#[test] +fn parses_minute_freqs() { + assert_eq!(Freq::from_str("1分钟").unwrap(), Freq::F1); + assert_eq!(Freq::from_str("30分钟").unwrap(), Freq::F30); + assert_eq!(Freq::from_str("60分钟").unwrap(), Freq::F60); +} + +#[test] +fn parses_higher_timeframes() { + assert_eq!(Freq::from_str("日线").unwrap(), Freq::D); + assert_eq!(Freq::from_str("周线").unwrap(), Freq::W); + assert_eq!(Freq::from_str("月线").unwrap(), Freq::M); +} + +#[test] +fn rejects_unknown_strings() { + assert!(Freq::from_str("100分钟").is_err()); +} + +#[test] +fn display_round_trips() { + assert_eq!(Freq::F30.to_string(), "30分钟"); + assert_eq!(Freq::D.to_string(), "日线"); +} + +#[test] +fn minutes_for_minute_freqs() { + assert_eq!(Freq::F30.minutes(), Some(30)); + assert_eq!(Freq::F1.minutes(), Some(1)); + assert_eq!(Freq::F360.minutes(), Some(360)); +} + +#[test] +fn minutes_none_for_higher_timeframes() { + assert_eq!(Freq::D.minutes(), None); + assert_eq!(Freq::W.minutes(), None); +} + +#[test] +fn is_minute_freq_classifies() { + assert!(Freq::F30.is_minute_freq()); + assert!(Freq::F1.is_minute_freq()); + assert!(!Freq::D.is_minute_freq()); + assert!(!Freq::Tick.is_minute_freq()); +} + +#[test] +fn ordering_is_total_and_consistent() { + // Freq derives PartialOrd + Ord. Minute timeframes should sort by enum + // declaration order (rs-czsc baseline behaviour we are locking). + assert!(Freq::F1 < Freq::F30); + assert!(Freq::F30 < Freq::D); +} diff --git a/crates/czsc-core/tests/test_fx.rs b/crates/czsc-core/tests/test_fx.rs new file mode 100644 index 000000000..dc2b2a9ee --- /dev/null +++ b/crates/czsc-core/tests/test_fx.rs @@ -0,0 +1,73 @@ +//! Phase D.7 — RED test: FX (分型) constructs via FXBuilder, exposes +//! power_str / power_volume / has_zs (non-python build), and compares +//! by structural equality. + +use std::sync::Arc; + +use chrono::{TimeZone, Utc}; +use czsc_core::objects::bar::{NewBar, NewBarBuilder}; +use czsc_core::objects::fx::{FX, FXBuilder}; +use czsc_core::objects::freq::Freq; +use czsc_core::objects::mark::Mark; + +fn nb(ts: i64, high: f64, low: f64, vol: f64) -> NewBar { + NewBarBuilder::default() + .symbol(Arc::::from("000001")) + .dt(Utc.timestamp_opt(ts, 0).unwrap()) + .freq(Freq::F30) + .id(0) + .open((high + low) / 2.0) + .close((high + low) / 2.0) + .high(high) + .low(low) + .vol(vol) + .amount(vol * (high + low) / 2.0) + .elements(Vec::new()) + .build() + .unwrap() +} + +fn sample_top_fx() -> FX { + // top fx (顶分型): middle bar's high is the highest + let k1 = nb(1_700_000_000, 11.0, 9.0, 100.0); + let k2 = nb(1_700_001_800, 12.0, 10.0, 200.0); // top + let k3 = nb(1_700_003_600, 11.5, 9.5, 100.0); + FXBuilder::default() + .symbol(Arc::::from("000001")) + .dt(k2.dt) + .mark(Mark::G) + .high(k2.high) + .low(k2.low) + .fx(k2.high) + .elements(vec![k1, k2, k3]) + .build() + .unwrap() +} + +#[test] +fn builder_populates_fields() { + let fx = sample_top_fx(); + assert_eq!(fx.mark, Mark::G); + assert_eq!(fx.elements.len(), 3); + assert!((fx.high - 12.0).abs() < f64::EPSILON); +} + +#[test] +fn power_str_returns_one_of_strong_medium_weak() { + let fx = sample_top_fx(); + let p = fx.power_str(); + assert!(matches!(p, "强" | "中" | "弱"), "got {p}"); +} + +#[test] +fn power_volume_is_finite() { + let fx = sample_top_fx(); + assert!(fx.power_volume().is_finite()); +} + +#[test] +fn equality_matches_rs_czsc_baseline() { + let a = sample_top_fx(); + let b = sample_top_fx(); + assert_eq!(a, b); +} diff --git a/crates/czsc-core/tests/test_mark_direction.rs b/crates/czsc-core/tests/test_mark_direction.rs new file mode 100644 index 000000000..b5122da0b --- /dev/null +++ b/crates/czsc-core/tests/test_mark_direction.rs @@ -0,0 +1,44 @@ +//! Phase D.6 — RED test: Mark + Direction enums parse from Chinese +//! serialised forms, format back via Display, and have stable equality. + +use std::str::FromStr; + +use czsc_core::objects::direction::Direction; +use czsc_core::objects::mark::Mark; + +#[test] +fn mark_parses_chinese_names() { + assert_eq!(Mark::from_str("底分型").unwrap(), Mark::D); + assert_eq!(Mark::from_str("顶分型").unwrap(), Mark::G); +} + +#[test] +fn mark_display_round_trips() { + assert_eq!(Mark::D.to_string(), "底分型"); + assert_eq!(Mark::G.to_string(), "顶分型"); +} + +#[test] +fn mark_rejects_unknown_string() { + assert!(Mark::from_str("分型X").is_err()); +} + +#[test] +fn direction_parses_chinese_names() { + assert_eq!(Direction::from_str("向上").unwrap(), Direction::Up); + assert_eq!(Direction::from_str("向下").unwrap(), Direction::Down); +} + +#[test] +fn direction_display_round_trips() { + assert_eq!(Direction::Up.to_string(), "向上"); + assert_eq!(Direction::Down.to_string(), "向下"); +} + +#[test] +fn direction_equality_and_clone() { + let a = Direction::Up; + let b = a.clone(); + assert_eq!(a, b); + assert_ne!(Direction::Up, Direction::Down); +} diff --git a/crates/czsc-core/tests/test_market.rs b/crates/czsc-core/tests/test_market.rs new file mode 100644 index 000000000..6c3cc7431 --- /dev/null +++ b/crates/czsc-core/tests/test_market.rs @@ -0,0 +1,31 @@ +//! Phase D.2 — RED test: Market enum parses Chinese names via FromStr, +//! formats back through Display, and survives equality / hash semantics. + +use std::str::FromStr; + +use czsc_core::objects::market::Market; + +#[test] +fn parses_chinese_names() { + assert_eq!(Market::from_str("A股").unwrap(), Market::AShare); + assert_eq!(Market::from_str("期货").unwrap(), Market::Futures); + assert_eq!(Market::from_str("默认").unwrap(), Market::Default); +} + +#[test] +fn rejects_unknown_strings() { + assert!(Market::from_str("invalid").is_err()); +} + +#[test] +fn display_round_trips() { + assert_eq!(Market::AShare.to_string(), "A股"); + assert_eq!(Market::Futures.to_string(), "期货"); +} + +#[test] +fn copy_and_equality() { + let a = Market::AShare; + let b = a; + assert_eq!(a, b); +} diff --git a/crates/czsc-core/tests/test_object_error.rs b/crates/czsc-core/tests/test_object_error.rs new file mode 100644 index 000000000..1f0dc9a1a --- /dev/null +++ b/crates/czsc-core/tests/test_object_error.rs @@ -0,0 +1,35 @@ +//! Phase D.1 — RED test: ObjectError variants must format with thiserror, +//! convert from anyhow::Error via the CZSCErrorDerive blanket, and serialize +//! to a string-shaped JSON. + +use czsc_core::objects::errors::ObjectError; + +#[test] +fn factor_signals_all_empty_message() { + let err = ObjectError::FactorSignalsAllEmpty; + assert_eq!( + err.to_string(), + "Factor.signals_all must contain at least one signal" + ); +} + +#[test] +fn score_out_of_range_carries_value() { + let err = ObjectError::ScoreOutOfRange(150); + assert!(err.to_string().contains("150")); +} + +#[test] +fn from_anyhow_blanket_routes_to_unexpected() { + let any: anyhow::Error = anyhow::anyhow!("boom"); + let err: ObjectError = any.into(); + assert!(matches!(err, ObjectError::Unexpected(_))); +} + +#[test] +fn serialize_emits_string_payload() { + let err = ObjectError::SignalKeyNotFound("k1".into()); + let json = serde_json::to_string(&err).unwrap(); + assert!(json.contains("k1"), "expected key in payload, got {json}"); + assert!(json.starts_with("\""), "expected JSON string, got {json}"); +} diff --git a/crates/czsc-core/tests/test_operate.rs b/crates/czsc-core/tests/test_operate.rs new file mode 100644 index 000000000..9642bebdd --- /dev/null +++ b/crates/czsc-core/tests/test_operate.rs @@ -0,0 +1,48 @@ +//! Phase D.10a — RED test: Operate enum parses from English short codes +//! (HL/HS/HO/LO/LE/SO/SE), formats back via Display, and exposes +//! to_chinese() for the canonical labels. + +use std::str::FromStr; + +use czsc_core::objects::operate::Operate; + +#[test] +fn parses_all_short_codes() { + assert_eq!(Operate::from_str("HL").unwrap(), Operate::HL); + assert_eq!(Operate::from_str("HS").unwrap(), Operate::HS); + assert_eq!(Operate::from_str("HO").unwrap(), Operate::HO); + assert_eq!(Operate::from_str("LO").unwrap(), Operate::LO); + assert_eq!(Operate::from_str("LE").unwrap(), Operate::LE); + assert_eq!(Operate::from_str("SO").unwrap(), Operate::SO); + assert_eq!(Operate::from_str("SE").unwrap(), Operate::SE); +} + +#[test] +fn rejects_unknown_string() { + assert!(Operate::from_str("XYZ").is_err()); +} + +#[test] +fn display_round_trips_short_codes() { + assert_eq!(Operate::HL.to_string(), "HL"); + assert_eq!(Operate::LO.to_string(), "LO"); +} + +#[test] +fn to_chinese_returns_canonical_labels() { + assert_eq!(Operate::HL.to_chinese(), "持多"); + assert_eq!(Operate::HS.to_chinese(), "持空"); + assert_eq!(Operate::HO.to_chinese(), "持币"); + assert_eq!(Operate::LO.to_chinese(), "开多"); + assert_eq!(Operate::LE.to_chinese(), "平多"); + assert_eq!(Operate::SO.to_chinese(), "开空"); + assert_eq!(Operate::SE.to_chinese(), "平空"); +} + +#[test] +fn copy_and_equality() { + let a = Operate::LO; + let b = a; + assert_eq!(a, b); + assert_ne!(Operate::LO, Operate::SO); +} diff --git a/crates/czsc-core/tests/test_position.rs b/crates/czsc-core/tests/test_position.rs new file mode 100644 index 000000000..88f5dd583 --- /dev/null +++ b/crates/czsc-core/tests/test_position.rs @@ -0,0 +1,62 @@ +//! Phase D.10d — RED test: Position deserialises from JSON, computes +//! get_pos / get_pos_changed defaults, normalises event hashes via +//! normalize_runtime_fields, and Pos enum's f64 round-trip works. + +use czsc_core::objects::position::{Pos, Position}; + +#[test] +fn pos_default_is_flat() { + assert_eq!(Pos::default(), Pos::Flat); +} + +#[test] +fn pos_to_f64_canonical() { + assert!((Pos::Long.to_f64() - 1.0).abs() < f64::EPSILON); + assert!((Pos::Flat.to_f64() - 0.0).abs() < f64::EPSILON); + assert!((Pos::Short.to_f64() + 1.0).abs() < f64::EPSILON); +} + +#[test] +fn pos_from_f64_threshold_at_half() { + assert_eq!(Pos::from_f64(0.6), Pos::Long); + assert_eq!(Pos::from_f64(-0.6), Pos::Short); + assert_eq!(Pos::from_f64(0.0), Pos::Flat); + assert_eq!(Pos::from_f64(0.4), Pos::Flat); + assert_eq!(Pos::from_f64(-0.4), Pos::Flat); +} + +#[test] +fn pos_display() { + assert_eq!(Pos::Long.to_string(), "多"); + assert_eq!(Pos::Short.to_string(), "空"); + assert_eq!(Pos::Flat.to_string(), "空仓"); +} + +const POSITION_JSON: &str = r#"{ + "opens": [], + "exits": [], + "interval": 0, + "timeout": 0, + "stop_loss": 0.0, + "T0": false, + "name": "test_position", + "symbol": "000001" +}"#; + +#[test] +fn position_deserialises_minimal_json() { + let mut p: Position = serde_json::from_str(POSITION_JSON).unwrap(); + p.normalize_runtime_fields(); + assert_eq!(p.symbol, "000001"); + assert_eq!(p.name, "test_position"); + assert_eq!(p.interval, 0); + assert!(!p.t0); +} + +#[test] +fn position_get_pos_starts_flat() { + let mut p: Position = serde_json::from_str(POSITION_JSON).unwrap(); + p.normalize_runtime_fields(); + assert_eq!(p.get_pos(), Pos::Flat); + assert!(!p.get_pos_changed()); +} diff --git a/crates/czsc-core/tests/test_raw_bar.rs b/crates/czsc-core/tests/test_raw_bar.rs new file mode 100644 index 000000000..d5c878ab5 --- /dev/null +++ b/crates/czsc-core/tests/test_raw_bar.rs @@ -0,0 +1,63 @@ +//! Phase D.5 — RED test: RawBar must construct via RawBarBuilder, expose +//! upper/lower/solid (non-python builds), and round-trip equality based on +//! the rs-czsc-defined fields. + +use chrono::{TimeZone, Utc}; +use czsc_core::objects::bar::{RawBar, RawBarBuilder}; +use czsc_core::objects::freq::Freq; + +fn sample(open: f64, close: f64, high: f64, low: f64) -> RawBar { + RawBarBuilder::default() + .symbol("000001") + .dt(Utc.with_ymd_and_hms(2024, 1, 8, 9, 30, 0).unwrap()) + .freq(Freq::F30) + .id(0) + .open(open) + .close(close) + .high(high) + .low(low) + .vol(1000.0_f64) + .amount(1_000_000.0_f64) + .build() + .unwrap() +} + +#[test] +fn builder_populates_fields() { + let bar = sample(10.0, 11.0, 12.0, 9.5); + assert_eq!(&*bar.symbol, "000001"); + assert_eq!(bar.freq, Freq::F30); + assert_eq!(bar.open, 10.0); + assert_eq!(bar.close, 11.0); + assert_eq!(bar.high, 12.0); + assert_eq!(bar.low, 9.5); +} + +#[test] +fn upper_shadow_is_high_minus_max_of_open_close() { + let bar = sample(10.0, 11.0, 12.0, 9.0); + // max(open, close) = 11; upper = 12 - 11 = 1 + assert_eq!(bar.upper(), 1.0); +} + +#[test] +fn lower_shadow_is_min_of_open_close_minus_low() { + let bar = sample(10.0, 11.0, 12.0, 9.0); + // min(open, close) = 10; lower = 10 - 9 = 1 + assert_eq!(bar.lower(), 1.0); +} + +#[test] +fn solid_is_abs_diff_of_open_close() { + let bull = sample(10.0, 11.0, 12.0, 9.0); + let bear = sample(11.0, 10.0, 12.0, 9.0); + assert!((bull.solid() - 1.0).abs() < f64::EPSILON); + assert!((bear.solid() - 1.0).abs() < f64::EPSILON); +} + +#[test] +fn equality_matches_rs_czsc_baseline() { + let a = sample(10.0, 11.0, 12.0, 9.0); + let b = sample(10.0, 11.0, 12.0, 9.0); + assert_eq!(a, b); +} diff --git a/crates/czsc-core/tests/test_signal.rs b/crates/czsc-core/tests/test_signal.rs new file mode 100644 index 000000000..bb0052376 --- /dev/null +++ b/crates/czsc-core/tests/test_signal.rs @@ -0,0 +1,58 @@ +//! Phase D.10b — RED test: Signal type (`SignalRef<'static>` aka `Signal`) +//! parses from the canonical k1_k2_k3_v1_v2_v3_score string and exposes +//! `key()` / `value()` / Display per the rs-czsc contract. + +use std::str::FromStr; + +use czsc_core::objects::signal::Signal; + +#[test] +fn parses_canonical_signal_string() { + let raw = "30分钟_D1_前高_看多_强_任意_0"; + let s = Signal::from_str(raw).unwrap(); + // key drops "任意" parts; here all of k1/k2/k3 are concrete + assert_eq!(s.key(), "30分钟_D1_前高"); + // value is v1_v2_v3_score + assert_eq!(s.value(), "看多_强_任意_0"); +} + +#[test] +fn display_round_trips_full_signal() { + let raw = "30分钟_D1_前高_看多_强_任意_0"; + let s = Signal::from_str(raw).unwrap(); + assert_eq!(s.to_string(), raw); +} + +#[test] +fn rejects_malformed_string() { + assert!(Signal::from_str("only_three_fields").is_err()); +} + +#[test] +fn equality_is_full_signal_string() { + let a = Signal::from_str("30分钟_D1_前高_看多_强_任意_0").unwrap(); + let b = Signal::from_str("30分钟_D1_前高_看多_强_任意_0").unwrap(); + assert_eq!(a, b); + let c = Signal::from_str("30分钟_D1_前高_看空_强_任意_0").unwrap(); + assert_ne!(a, c); +} + +#[test] +fn key_skips_wildcards() { + let s = Signal::from_str("任意_D1_前高_看多_强_任意_0").unwrap(); + // k1 is 任意 → dropped from key + assert_eq!(s.key(), "D1_前高"); +} + +#[test] +fn is_match_obeys_score_and_wildcards() { + use std::collections::HashMap; + let s = Signal::from_str("30分钟_D1_前高_看多_强_任意_50").unwrap(); + let mut dict = HashMap::new(); + dict.insert("30分钟_D1_前高".to_string(), "看多_强_中_60".to_string()); + assert!(s.is_match(&dict), "score 60 >= 50 with v3 wildcard should match"); + + let mut low_score = HashMap::new(); + low_score.insert("30分钟_D1_前高".to_string(), "看多_强_中_40".to_string()); + assert!(!s.is_match(&low_score), "score 40 < 50 must not match"); +} diff --git a/crates/czsc-core/tests/test_state.rs b/crates/czsc-core/tests/test_state.rs new file mode 100644 index 000000000..000b5dd3c --- /dev/null +++ b/crates/czsc-core/tests/test_state.rs @@ -0,0 +1,30 @@ +//! Phase D.A follow-up — RED test: TraderState trait is publicly defined +//! and a minimal stub impl is accepted by the type system. czsc-signals +//! and czsc-trader will depend on this trait once they migrate. + +use czsc_core::analyze::CZSC; +use czsc_core::objects::position::Position; +use czsc_core::objects::state::TraderState; + +struct StubTrader; + +impl TraderState for StubTrader { + fn get_position(&self, _name: &str) -> Option<&Position> { + None + } + fn get_czsc(&self, _freq: &str) -> Option<&CZSC> { + None + } + fn latest_price(&self) -> Option { + None + } +} + +#[test] +fn stub_trader_implements_trait() { + let s = StubTrader; + let _: &dyn TraderState = &s; + assert!(s.get_position("foo").is_none()); + assert!(s.get_czsc("30分钟").is_none()); + assert!(s.latest_price().is_none()); +} diff --git a/crates/czsc-core/tests/test_utils_common.rs b/crates/czsc-core/tests/test_utils_common.rs new file mode 100644 index 000000000..de8b61a58 --- /dev/null +++ b/crates/czsc-core/tests/test_utils_common.rs @@ -0,0 +1,26 @@ +//! Phase D.4 — RED test: freq_to_chinese_string returns the canonical +//! Chinese label for each Freq variant; this is the helper bar.rs and +//! various PyO3 bindings rely on to build display strings. + +use czsc_core::objects::freq::Freq; +use czsc_core::utils::common::freq_to_chinese_string; + +#[test] +fn covers_minute_freqs() { + assert_eq!(freq_to_chinese_string(Freq::F1), "1分钟"); + assert_eq!(freq_to_chinese_string(Freq::F30), "30分钟"); + assert_eq!(freq_to_chinese_string(Freq::F360), "360分钟"); +} + +#[test] +fn covers_higher_timeframes() { + assert_eq!(freq_to_chinese_string(Freq::D), "日线"); + assert_eq!(freq_to_chinese_string(Freq::W), "周线"); + assert_eq!(freq_to_chinese_string(Freq::M), "月线"); + assert_eq!(freq_to_chinese_string(Freq::Y), "年线"); +} + +#[test] +fn tick_returns_english_marker() { + assert_eq!(freq_to_chinese_string(Freq::Tick), "Tick"); +} diff --git a/crates/czsc-core/tests/test_zs.rs b/crates/czsc-core/tests/test_zs.rs new file mode 100644 index 000000000..246b3338c --- /dev/null +++ b/crates/czsc-core/tests/test_zs.rs @@ -0,0 +1,115 @@ +//! Phase D.9 — RED test: ZS (中枢) constructs from a non-empty BI list, +//! computes zg / zd / zz / gg / dd boundaries, and surfaces is_valid(). + +use std::sync::Arc; + +use chrono::{TimeZone, Utc}; +use czsc_core::objects::bar::{NewBar, NewBarBuilder}; +use czsc_core::objects::bi::{BI, BIBuilder}; +use czsc_core::objects::direction::Direction; +use czsc_core::objects::freq::Freq; +use czsc_core::objects::fx::{FX, FXBuilder}; +use czsc_core::objects::mark::Mark; +use czsc_core::objects::zs::ZS; + +fn nb(ts: i64, high: f64, low: f64) -> NewBar { + NewBarBuilder::default() + .symbol(Arc::::from("000001")) + .dt(Utc.timestamp_opt(ts, 0).unwrap()) + .freq(Freq::F30) + .id(0) + .open((high + low) / 2.0) + .close((high + low) / 2.0) + .high(high) + .low(low) + .vol(100.0) + .amount(100.0 * (high + low) / 2.0) + .elements(Vec::new()) + .build() + .unwrap() +} + +fn fx(ts: i64, mark: Mark, level: f64) -> FX { + let k1 = nb(ts - 1800, level - 0.5, level - 1.5); + let k2 = nb(ts, level + 0.5, level - 0.5); + let k3 = nb(ts + 1800, level - 0.2, level - 1.0); + let mark_clone = mark.clone(); + FXBuilder::default() + .symbol(Arc::::from("000001")) + .dt(Utc.timestamp_opt(ts, 0).unwrap()) + .mark(mark_clone) + .high(k2.high) + .low(k2.low) + .fx(if matches!(mark, Mark::G) { k2.high } else { k2.low }) + .elements(vec![k1, k2, k3]) + .build() + .unwrap() +} + +fn make_bi(ts_a: i64, mark_a: Mark, level_a: f64, + ts_b: i64, mark_b: Mark, level_b: f64, + direction: Direction) -> BI { + let fx_a = fx(ts_a, mark_a, level_a); + let fx_b = fx(ts_b, mark_b, level_b); + // bars span — endpoints determine high/low + let bars = vec![ + nb(ts_a, level_a + 0.5, level_a - 0.5), + nb((ts_a + ts_b) / 2, (level_a + level_b) / 2.0 + 0.5, (level_a + level_b) / 2.0 - 0.5), + nb(ts_b, level_b + 0.5, level_b - 0.5), + ]; + BIBuilder::default() + .symbol(Arc::::from("000001")) + .fx_a(fx_a) + .fx_b(fx_b.clone()) + .fxs(vec![fx_b]) + .direction(direction) + .bars(bars) + .build() + .unwrap() +} + +fn sample_zs() -> ZS { + // 3-bi center: down 12 -> 9, up 9 -> 11, down 11 -> 9.5 + let bi1 = make_bi(1_700_000_000, Mark::G, 12.0, + 1_700_001_800, Mark::D, 9.0, Direction::Down); + let bi2 = make_bi(1_700_001_800, Mark::D, 9.0, + 1_700_003_600, Mark::G, 11.0, Direction::Up); + let bi3 = make_bi(1_700_003_600, Mark::G, 11.0, + 1_700_005_400, Mark::D, 9.5, Direction::Down); + ZS::new(vec![bi1, bi2, bi3]) +} + +#[test] +fn new_populates_endpoints() { + let zs = sample_zs(); + assert_eq!(zs.bis.len(), 3); + assert_eq!(zs.sdir, Direction::Down); + assert_eq!(zs.edir, Direction::Down); +} + +#[test] +fn zg_zd_within_first_three_bis() { + let zs = sample_zs(); + // zg = min of first 3 bis' highs; zd = max of first 3 bis' lows + assert!(zs.zg >= zs.zd, "zg={} must be >= zd={}", zs.zg, zs.zd); +} + +#[test] +fn zz_is_midpoint_of_zg_zd() { + let zs = sample_zs(); + let mid = (zs.zg + zs.zd) / 2.0; + assert!((zs.zz - mid).abs() < 1e-9, "zz {} should equal mid {}", zs.zz, mid); +} + +#[test] +fn gg_dd_envelope_zg_zd() { + let zs = sample_zs(); + assert!(zs.gg >= zs.zg, "gg {} must be >= zg {}", zs.gg, zs.zg); + assert!(zs.dd <= zs.zd, "dd {} must be <= zd {}", zs.dd, zs.zd); +} + +#[test] +fn is_valid_returns_bool() { + let zs = sample_zs(); + let _ = zs.is_valid(); // doesn't matter true or false — must not panic +} diff --git a/crates/czsc-python/Cargo.toml b/crates/czsc-python/Cargo.toml new file mode 100644 index 000000000..ae8492201 --- /dev/null +++ b/crates/czsc-python/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "czsc-python" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "PyO3 binding aggregator that produces the czsc._native extension module." + +[lib] +name = "czsc_python" +path = "src/lib.rs" +crate-type = ["cdylib", "rlib"] + +[dependencies] +czsc-core = { path = "../czsc-core", features = ["python"] } +czsc-ta = { path = "../czsc-ta", features = ["rust-numpy"] } +czsc-utils = { path = "../czsc-utils", features = ["python"] } +czsc-signals = { path = "../czsc-signals" } +czsc-trader = { path = "../czsc-trader" } +error-macros = { path = "../error-macros" } +error-support = { path = "../error-support" } +anyhow = "1" +chrono = { workspace = true } +inventory = "0.3" +md5 = "0.8" +numpy = { workspace = true } +polars = { workspace = true } +# czsc-python is the only crate that opts into pyo3/extension-module + +# abi3-py310. Business crates pull in the bare workspace pyo3 so +# `cargo test --workspace` can link against a real libpython resolved +# via PYO3_PYTHON. +pyo3 = { workspace = true, features = ["extension-module", "abi3-py310", "chrono"] } +pyo3-stub-gen = "0.12" +rust_xlsxwriter = "0.79" +serde = { workspace = true } +serde_json = "1" +thiserror = "2" diff --git a/crates/czsc-python/build.rs b/crates/czsc-python/build.rs new file mode 100644 index 000000000..c03853420 --- /dev/null +++ b/crates/czsc-python/build.rs @@ -0,0 +1,14 @@ +//! Build script for czsc-python. +//! +//! On macOS the cdylib needs `-undefined dynamic_lookup` so that Python +//! symbols are resolved at runtime by the host interpreter. PyO3's +//! `extension-module` feature normally emits this, but when building via +//! plain `cargo build --workspace` (without maturin) we make the link arg +//! explicit so the workspace layout test stays GREEN. + +fn main() { + if std::env::var("CARGO_CFG_TARGET_OS").as_deref() == Ok("macos") { + println!("cargo:rustc-link-arg-cdylib=-undefined"); + println!("cargo:rustc-link-arg-cdylib=dynamic_lookup"); + } +} diff --git a/crates/czsc-python/src/errors.rs b/crates/czsc-python/src/errors.rs new file mode 100644 index 000000000..016d7f8b7 --- /dev/null +++ b/crates/czsc-python/src/errors.rs @@ -0,0 +1,44 @@ +//! czsc-python error type. +//! +//! Lifted from rs-czsc `python/src/errors.rs`, but the +//! `WeightBackTest` variant is dropped — czsc relies on the external +//! `wbt` crate for weight backtests, so its error chain doesn't +//! flow through this binding. + +use czsc_core::utils::errors::CoreUtilsErorr; +use czsc_utils::errors::UtilsError; +use error_macros::CZSCErrorDerive; +use error_support::expand_error_chain; +use numpy::NotContiguousError; +use polars::error::PolarsError; +use pyo3::{PyErr, create_exception, exceptions::PyException}; +use thiserror::Error; + +create_exception!(_native, CZSCError, PyException); + +#[derive(Debug, Error, CZSCErrorDerive)] +pub enum PythonError { + #[error("Utils: {0}")] + Utils(#[from] UtilsError), + + #[error("Polars: {0}")] + Polars(#[from] PolarsError), + + #[error("{}", expand_error_chain(.0))] + Unexpected(anyhow::Error), + + #[error("CoreUtils: {0}")] + CoreUtils(#[from] CoreUtilsErorr), + + #[error("Numpy: {0}")] + NotContiguous(#[from] NotContiguousError), + + #[error("NotFound: {0}")] + NotFound(String), +} + +impl From for PyErr { + fn from(e: PythonError) -> Self { + PyException::new_err(e.to_string()) + } +} diff --git a/crates/czsc-python/src/lib.rs b/crates/czsc-python/src/lib.rs new file mode 100644 index 000000000..ed61eed78 --- /dev/null +++ b/crates/czsc-python/src/lib.rs @@ -0,0 +1,65 @@ +//! czsc-python — PyO3 aggregator that produces the `czsc._native` extension. +//! +//! Each business crate's PyO3 surface is registered here. The crate is +//! the only one that links `pyo3 = { features = ["extension-module"] }` +//! and produces the cdylib loaded by Python. + +use pyo3::prelude::*; +use pyo3::wrap_pyfunction; + +mod errors; +mod signals_dispatcher; +mod trader; +mod utils; + +#[pymodule] +fn _native(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { + czsc_core::python::register(py, m)?; + czsc_utils::python::register(py, m)?; + czsc_ta::python::register(py, m)?; + + // czsc-signals contributes `SignalDescriptor` entries via + // `inventory::collect!`. The dummy iterator forces the crate + // into the final cdylib so the constructors run on import. + let _signals_count = inventory::iter::() + .into_iter() + .count(); + + // Trader surface — CzscTrader, CzscSignals, generate_czsc_signals. + m.add_class::()?; + m.add_class::()?; + m.add_function(wrap_pyfunction!(trader::generate::generate_czsc_signals, m)?)?; + + // Research / optimize entrypoints (mirrors rs_czsc/python/src/lib.rs). + // These are the heavy-lift functions that strategies.py / + // research.py / optimize.py wrap thinly on the Python side. + m.add_function(wrap_pyfunction!(trader::api::list_all_signals, m)?)?; + m.add_function(wrap_pyfunction!(trader::api::derive_signals_config, m)?)?; + m.add_function(wrap_pyfunction!(trader::api::derive_signals_freqs, m)?)?; + m.add_function(wrap_pyfunction!(trader::api::generate_signals, m)?)?; + m.add_function(wrap_pyfunction!(trader::api::run_backtest, m)?)?; + m.add_function(wrap_pyfunction!(trader::api::run_optimize, m)?)?; + m.add_function(wrap_pyfunction!(trader::research::run_research, m)?)?; + m.add_function(wrap_pyfunction!(trader::research::run_replay, m)?)?; + m.add_function(wrap_pyfunction!(trader::research::run_optimize_batch, m)?)?; + m.add_function(wrap_pyfunction!(trader::research::build_open_optim_positions, m)?)?; + m.add_function(wrap_pyfunction!(trader::research::build_exit_optim_positions, m)?)?; + + // czsc._native.signals namespace + per-category sub-modules + // (bar / cxt / tas / vol / pressure / obv / cvolp). The dispatcher + // is registered on each so that + // from czsc._native.signals import call_signal + // and + // from czsc._native.signals.bar import list_signal_names + // both resolve. See `signals_dispatcher.rs` for the design. + let signals = PyModule::new(py, "signals")?; + signals.setattr("__name__", "czsc._native.signals")?; + let sys = py.import("sys")?; + let py_modules = sys.getattr("modules")?; + py_modules.set_item("czsc._native.signals", &signals)?; + m.add("signals", &signals)?; + + signals_dispatcher::register(py, m, &signals)?; + + Ok(()) +} diff --git a/crates/czsc-python/src/signals_dispatcher.rs b/crates/czsc-python/src/signals_dispatcher.rs new file mode 100644 index 000000000..476c4a4e6 --- /dev/null +++ b/crates/czsc-python/src/signals_dispatcher.rs @@ -0,0 +1,177 @@ +//! czsc._native signal dispatcher (design doc §3.3). +//! +//! Per-signal PyO3 wrappers would require ~30+ hand-written `#[pyfunction]` +//! definitions; instead we expose a single dispatcher that looks the +//! signal up by name in the inventory table contributed by +//! `czsc-signals`. The Python-side ``czsc/signals/{bar,cxt,...}.py`` +//! shims attach a per-name closure via ``__getattr__`` so user code +//! reads naturally: +//! +//! ```python +//! from czsc.signals.bar import bar_amount_acc_V230214 +//! result = bar_amount_acc_V230214(czsc_obj, {"di": 1, "n": 5}) +//! ``` +//! +//! The dispatcher only handles **kline** signals (``fn(&CZSC, ¶ms, +//! &mut TaCache) -> Vec``). Trader-state signals require a +//! ``CzscTrader`` instance and are dispatched via +//! ``CzscTrader.update_signals`` / ``CzscSignals.update_signals``. + +use crate::trader::czsc_signals::py_to_serde_value; +use czsc_core::analyze::CZSC; +use czsc_core::objects::signal::PySignal; +use czsc_signals::types::{SignalDescriptor, SignalFnRef, TaCache}; +use pyo3::exceptions::{PyKeyError, PyTypeError}; +use pyo3::prelude::*; +use pyo3::types::PyDict; +use serde_json::Value; +use std::collections::HashMap; + +/// Find a signal descriptor by name. Returns `None` if no descriptor +/// has that name. Callers should treat this as a missing signal. +fn lookup(name: &str) -> Option<&'static SignalDescriptor> { + inventory::iter::() + .into_iter() + .find(|d| d.name == name) +} + +/// Extract the category prefix from a signal name (e.g. ``bar`` from +/// ``bar_amount_acc_V230214``). Returns ``None`` if the name has no +/// underscore. +fn name_prefix(name: &str) -> Option<&str> { + name.split_once('_').map(|(p, _)| p) +} + +/// Convert a Python params dict (or `None`) into the +/// ``HashMap`` shape used by all kline signal +/// functions. Accepts ``None`` as an empty dict. +fn extract_params(params: Option<&Bound<'_, PyDict>>) -> PyResult> { + let mut out: HashMap = HashMap::new(); + if let Some(d) = params { + for (k, v) in d.iter() { + let key: String = k.extract()?; + let val = py_to_serde_value(&v)?; + out.insert(key, val); + } + } + Ok(out) +} + +/// Invoke a kline signal by name on the supplied CZSC instance. +/// +/// Returns a list of ``czsc.Signal`` objects (the same type +/// produced by ``CzscSignals.update_signals``). +#[pyfunction] +#[pyo3(signature = (name, czsc, params=None))] +pub fn call_signal( + name: &str, + czsc: &CZSC, + params: Option<&Bound<'_, PyDict>>, +) -> PyResult> { + let descriptor = lookup(name) + .ok_or_else(|| PyKeyError::new_err(format!("unknown signal: {name}")))?; + + let kline_func = match descriptor.func_ref { + SignalFnRef::Kline(f) => f, + SignalFnRef::Trader(_) => { + return Err(PyTypeError::new_err(format!( + "{name} is a trader-state signal; dispatch via CzscTrader.update_signals" + ))); + } + }; + + let params_map = extract_params(params)?; + let mut cache = TaCache::default(); + let signals = kline_func(czsc, ¶ms_map, &mut cache); + Ok(signals.into_iter().map(PySignal::from).collect()) +} + +/// List signal names contributed by the inventory. +/// +/// ``category`` is matched against the signal-name prefix (the part +/// before the first underscore). Common values: ``bar``, ``cxt``, +/// ``tas``, ``vol``, ``pressure``, ``obv``, ``cvolp``. Pass ``None`` +/// to return every kline signal. +#[pyfunction] +#[pyo3(signature = (category=None))] +pub fn list_signal_names(category: Option<&str>) -> Vec { + let mut out: Vec = inventory::iter::() + .into_iter() + .filter(|d| matches!(d.func_ref, SignalFnRef::Kline(_))) + .filter(|d| match category { + Some(c) => name_prefix(d.name).map(|p| p == c).unwrap_or(false), + None => true, + }) + .map(|d| d.name.to_string()) + .collect(); + out.sort(); + out +} + +/// Return the parameter template for ``name``, or ``None`` if no signal +/// with that name is registered. The template is the schema string +/// declared in the `#[signal(...)]` macro and matches what the legacy +/// Python helpers parse. +#[pyfunction] +pub fn get_signal_template(name: &str) -> Option { + lookup(name).map(|d| d.template.to_string()) +} + +/// Return the category prefix for ``name`` (``"bar"`` / ``"cxt"`` / +/// ...). ``None`` when the signal isn't registered or its name has no +/// underscore. +#[pyfunction] +pub fn get_signal_category(name: &str) -> Option { + let descriptor = lookup(name)?; + name_prefix(descriptor.name).map(|p| p.to_string()) +} + +/// Register the dispatcher symbols on both ``czsc._native`` (top-level) +/// and ``czsc._native.signals`` (submodule). The submodule entries +/// give design-doc §3.3 the path ``from czsc._native.signals import +/// call_signal``. +pub fn register(py: Python<'_>, m: &Bound<'_, PyModule>, signals_mod: &Bound<'_, PyModule>) -> PyResult<()> { + use pyo3::wrap_pyfunction; + + m.add_function(wrap_pyfunction!(call_signal, m)?)?; + m.add_function(wrap_pyfunction!(list_signal_names, m)?)?; + m.add_function(wrap_pyfunction!(get_signal_template, m)?)?; + m.add_function(wrap_pyfunction!(get_signal_category, m)?)?; + + signals_mod.add_function(wrap_pyfunction!(call_signal, signals_mod)?)?; + signals_mod.add_function(wrap_pyfunction!(list_signal_names, signals_mod)?)?; + signals_mod.add_function(wrap_pyfunction!(get_signal_template, signals_mod)?)?; + signals_mod.add_function(wrap_pyfunction!(get_signal_category, signals_mod)?)?; + + // Per-category sub-modules: czsc._native.signals.{bar,cxt,...}. + // Each gets the full dispatcher trio so user code can write: + // + // import czsc._native.signals.bar as bar_mod + // bar_mod.list_signal_names() # only bar_* names + // + // The Python-side `czsc/signals/.py` shim layers __getattr__ + // on top of these to expose individual functions. + let categories = [ + "bar", + "cxt", + "tas", + "vol", + "pressure", + "obv", + "cvolp", + ]; + let sys = py.import("sys")?; + let py_modules = sys.getattr("modules")?; + for cat in categories { + let cat_mod = PyModule::new(py, cat)?; + cat_mod.setattr("__name__", format!("czsc._native.signals.{cat}"))?; + cat_mod.setattr("__category__", cat)?; + cat_mod.add_function(wrap_pyfunction!(call_signal, &cat_mod)?)?; + cat_mod.add_function(wrap_pyfunction!(list_signal_names, &cat_mod)?)?; + cat_mod.add_function(wrap_pyfunction!(get_signal_template, &cat_mod)?)?; + py_modules.set_item(format!("czsc._native.signals.{cat}"), &cat_mod)?; + signals_mod.add(cat, &cat_mod)?; + } + + Ok(()) +} diff --git a/crates/czsc-python/src/trader/api.rs b/crates/czsc-python/src/trader/api.rs new file mode 100644 index 000000000..d44fc9126 --- /dev/null +++ b/crates/czsc-python/src/trader/api.rs @@ -0,0 +1,1439 @@ +use pyo3::exceptions::{PyRuntimeError, PyValueError}; +use pyo3::prelude::*; +use pyo3::types::{PyBytes, PyDict, PyList}; +use rust_xlsxwriter::Workbook; +use serde_json::{Value, json}; +use std::collections::{BTreeSet, HashMap, HashSet}; +use std::fs; +use std::path::{Path, PathBuf}; + +use crate::utils::df_convert::pyarrow_to_df; +use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, Utc}; +use czsc_core::analyze::utils::format_standard_kline; +use czsc_core::objects::bar::RawBar; +use czsc_core::objects::freq::Freq; +use czsc_core::objects::position::Position; +use czsc_trader::engine_v2::{ExecutionPlan, ExecutionPlanInput, UnifiedExecEngine}; +use czsc_trader::optimize::{ + get_exit_optim_positions, get_open_optim_positions, symbols_optim_parallel, +}; +use czsc_trader::signals::sig_parse::{SignalConfig, get_signals_config, get_signals_freqs}; +use czsc_signals::registry::list_all_signals as list_all_registered_signals; +use polars::prelude::*; + +fn write_df_parquet(path: &Path, mut df: DataFrame) -> PyResult<()> { + let mut file = fs::File::create(path) + .map_err(|e| PyValueError::new_err(format!("创建输出文件失败: {e}")))?; + ParquetWriter::new(&mut file) + .finish(&mut df) + .map_err(|e| PyRuntimeError::new_err(format!("写出 parquet 失败: {e}")))?; + Ok(()) +} + +pub(crate) fn build_signals_dataframe(rows: &[HashMap]) -> PyResult { + let mut keys: BTreeSet = BTreeSet::new(); + for r in rows { + keys.extend(r.keys().cloned()); + } + if !keys.contains("cache") { + keys.insert("cache".to_string()); + } + + let mut cols: Vec = Vec::new(); + for k in keys { + let vals: Vec> = if k == "cache" { + rows.iter().map(|_| Some("{}".to_string())).collect() + } else { + rows.iter().map(|r| r.get(&k).cloned()).collect() + }; + cols.push(Series::new(k.as_str(), vals)); + } + DataFrame::new(cols) + .map_err(|e| PyRuntimeError::new_err(format!("构建 signals DataFrame 失败: {e}"))) +} + +fn align_signals_python_baseline( + mut df: DataFrame, + cutoff: Option>, + cutoff_bar: Option<&RawBar>, +) -> PyResult { + if df.column("dt").is_ok() { + df = df + .lazy() + .filter(col("dt").is_not_null()) + .collect() + .map_err(|e| PyRuntimeError::new_err(format!("signals 过滤空 dt 失败: {e}")))?; + } + + let Some(cutoff_bar) = cutoff_bar else { + return Ok(df); + }; + if cutoff.is_none() || df.height() == 0 || df.column("dt").is_err() { + return Ok(df); + } + + let cutoff_dt = cutoff_bar.dt.to_rfc3339(); + let dt_col = df + .column("dt") + .map_err(|e| PyRuntimeError::new_err(format!("读取 signals.dt 失败: {e}")))?; + let has_cutoff = dt_col + .str() + .map_err(|e| PyRuntimeError::new_err(format!("signals.dt 类型错误: {e}")))? + .into_iter() + .any(|x| x == Some(cutoff_dt.as_str())); + if has_cutoff { + return Ok(df); + } + + let mut head = df.slice(0, 1); + let base_cols: Vec<(&str, String)> = vec![ + ("symbol", cutoff_bar.symbol.to_string()), + ("id", cutoff_bar.id.to_string()), + ("dt", cutoff_dt), + ("freq", cutoff_bar.freq.to_string()), + ("open", cutoff_bar.open.to_string()), + ("close", cutoff_bar.close.to_string()), + ("high", cutoff_bar.high.to_string()), + ("low", cutoff_bar.low.to_string()), + ("vol", cutoff_bar.vol.to_string()), + ("amount", cutoff_bar.amount.to_string()), + ("cache", "{}".to_string()), + ]; + for (name, value) in base_cols { + if head.column(name).is_ok() { + head.with_column(Series::new(name, &[Some(value.as_str())])) + .map_err(|e| { + PyRuntimeError::new_err(format!("补齐 signals 列 {name} 失败: {e}")) + })?; + } + } + head.vstack_mut(&df) + .map_err(|e| PyRuntimeError::new_err(format!("拼接 signals 边界行失败: {e}")))?; + Ok(head) +} + +pub(crate) fn normalize_signals_dtypes(mut df: DataFrame) -> PyResult { + if df.column("dt").is_ok() { + df = df + .lazy() + .with_column( + col("dt") + .str() + .to_datetime( + Some(TimeUnit::Nanoseconds), + None, + StrptimeOptions { + format: Some("%Y-%m-%dT%H:%M:%S%.f%z".into()), + strict: false, + exact: false, + cache: true, + }, + lit("raise"), + ) + .dt() + .replace_time_zone(None, lit("raise"), NonExistent::Raise), + ) + .collect() + .map_err(|e| PyRuntimeError::new_err(format!("signals 列 dt 类型转换失败: {e}")))?; + } + let casts = [ + ("id", DataType::Int64), + ("open", DataType::Float64), + ("close", DataType::Float64), + ("high", DataType::Float64), + ("low", DataType::Float64), + ("vol", DataType::Float64), + ("amount", DataType::Float64), + ]; + for (name, dtype) in casts { + if df.column(name).is_ok() { + let casted = df.column(name).and_then(|s| s.cast(&dtype)).map_err(|e| { + PyRuntimeError::new_err(format!("signals 列 {name} 类型转换失败: {e}")) + })?; + df.with_column(casted) + .map_err(|e| PyRuntimeError::new_err(format!("signals 写回列 {name} 失败: {e}")))?; + } + } + Ok(df) +} + +fn combine_pairs_holds_for_backtest(positions: &[Position]) -> PyResult<(DataFrame, DataFrame)> { + let mut all_pairs = Vec::new(); + let mut all_holds = Vec::new(); + + for pos in positions { + if let Ok(df) = pos.pairs() { + all_pairs.push(df.lazy()); + } + if let Ok(mut df) = pos.holds() { + if df.height() == 0 { + continue; + } + let pos_name = Series::new("pos_name", vec![pos.name.clone(); df.height()]); + df.with_column(pos_name) + .map_err(|e| PyRuntimeError::new_err(format!("追加 pos_name 列失败: {e}")))?; + + let keep_cols: Vec = ["dt", "pos", "price", "pos_name"] + .iter() + .filter(|c| df.column(c).is_ok()) + .map(|x| x.to_string()) + .collect(); + let refs: Vec<&str> = keep_cols.iter().map(|x| x.as_str()).collect(); + let df = df + .select(refs) + .map_err(|e| PyRuntimeError::new_err(format!("筛选 holds 列失败: {e}")))?; + all_holds.push(df.lazy()); + } + } + + let mut pairs_df = if all_pairs.is_empty() { + DataFrame::default() + } else { + let mut df = concat(all_pairs, UnionArgs::default()) + .and_then(|lf| lf.collect()) + .map_err(|e| PyRuntimeError::new_err(format!("合并 pairs 失败: {e}")))?; + let mut sort_cols = Vec::new(); + for c in ["pos_name", "开仓时间", "平仓时间"] { + if df.column(c).is_ok() { + sort_cols.push(c); + } + } + if !sort_cols.is_empty() { + let lf = df.lazy(); + df = lf + .sort(sort_cols, SortMultipleOptions::default()) + .collect() + .map_err(|e| PyRuntimeError::new_err(format!("pairs 排序失败: {e}")))?; + } + df + }; + for (name, dtype) in [ + ("开仓时间", DataType::Datetime(TimeUnit::Nanoseconds, None)), + ("平仓时间", DataType::Datetime(TimeUnit::Nanoseconds, None)), + ("持仓K线数", DataType::Int64), + ] { + if pairs_df.column(name).is_ok() { + let casted = pairs_df + .column(name) + .and_then(|s| s.cast(&dtype)) + .map_err(|e| { + PyRuntimeError::new_err(format!("pairs 列 {name} 类型转换失败: {e}")) + })?; + pairs_df + .with_column(casted) + .map_err(|e| PyRuntimeError::new_err(format!("pairs 写回列 {name} 失败: {e}")))?; + } + } + + let mut holds_df = if all_holds.is_empty() { + DataFrame::default() + } else { + let mut df = concat(all_holds, UnionArgs::default()) + .and_then(|lf| lf.collect()) + .map_err(|e| PyRuntimeError::new_err(format!("合并 holds 失败: {e}")))?; + let mut sort_cols = Vec::new(); + for c in ["pos_name", "dt"] { + if df.column(c).is_ok() { + sort_cols.push(c); + } + } + if !sort_cols.is_empty() { + let lf = df.lazy(); + df = lf + .sort(sort_cols, SortMultipleOptions::default()) + .collect() + .map_err(|e| PyRuntimeError::new_err(format!("holds 排序失败: {e}")))?; + } + df + }; + for (name, dtype) in [ + ("dt", DataType::Datetime(TimeUnit::Nanoseconds, None)), + ("pos", DataType::Int64), + ] { + if holds_df.column(name).is_ok() { + let casted = holds_df + .column(name) + .and_then(|s| s.cast(&dtype)) + .map_err(|e| { + PyRuntimeError::new_err(format!("holds 列 {name} 类型转换失败: {e}")) + })?; + holds_df + .with_column(casted) + .map_err(|e| PyRuntimeError::new_err(format!("holds 写回列 {name} 失败: {e}")))?; + } + } + + Ok((pairs_df, holds_df)) +} + +/// 内部从 parquet 文件读取 bars +fn read_bars_from_file(path: &Path, base_freq: Freq) -> PyResult> { + let file = fs::File::open(path).map_err(|e| { + PyValueError::new_err(format!("读取 bars 文件失败 {}: {e}", path.display())) + })?; + let df = ParquetReader::new(file) + .finish() + .map_err(|e| PyValueError::new_err(format!("解析 parquet 失败 {}: {e}", path.display())))?; + format_standard_kline(df, base_freq).map_err(|e| { + PyValueError::new_err(format!( + "标准化 K 线失败 {} (freq={base_freq}): {e}", + path.display() + )) + }) +} + +fn parse_sdt_utc(s: &str) -> Option> { + if s.is_empty() { + return None; + } + if let Ok(dt) = DateTime::parse_from_rfc3339(s) { + return Some(dt.with_timezone(&Utc)); + } + if let Ok(ndt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") { + return Some(DateTime::from_naive_utc_and_offset(ndt, Utc)); + } + if let Ok(ndt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S") { + return Some(DateTime::from_naive_utc_and_offset(ndt, Utc)); + } + if let Ok(ndt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f") { + return Some(DateTime::from_naive_utc_and_offset(ndt, Utc)); + } + if let Ok(d) = NaiveDate::parse_from_str(s, "%Y-%m-%d") + && let Some(ndt) = d.and_hms_opt(0, 0, 0) + { + return Some(DateTime::from_naive_utc_and_offset(ndt, Utc)); + } + if let Ok(d) = NaiveDate::parse_from_str(s, "%Y%m%d") + && let Some(ndt) = d.and_hms_opt(0, 0, 0) + { + return Some(DateTime::from_naive_utc_and_offset(ndt, Utc)); + } + None +} + +fn py_escape_str(s: &str) -> String { + s.replace('\\', "\\\\").replace('\'', "\\'") +} + +fn py_repr_list_str(items: &[String]) -> String { + if items.is_empty() { + return "[]".to_string(); + } + let body = items + .iter() + .map(|x| format!("'{}'", py_escape_str(x))) + .collect::>() + .join(", "); + format!("[{body}]") +} + +fn py_repr_json(v: &Value) -> String { + match v { + Value::Null => "None".to_string(), + Value::Bool(b) => { + if *b { + "True".to_string() + } else { + "False".to_string() + } + } + Value::Number(n) => n.to_string(), + Value::String(s) => format!("'{}'", py_escape_str(s)), + Value::Array(arr) => { + let body = arr.iter().map(py_repr_json).collect::>().join(", "); + format!("[{body}]") + } + Value::Object(map) => { + let body = map + .iter() + .map(|(k, val)| format!("'{}': {}", py_escape_str(k), py_repr_json(val))) + .collect::>() + .join(", "); + format!("{{{body}}}") + } + } +} + +fn md5_upper8(s: &str) -> String { + let digest = md5::compute(s.as_bytes()); + format!("{:x}", digest)[..8].to_uppercase() +} + +fn sanitize_file_stem(name: &str) -> String { + name.chars() + .map(|c| match c { + '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_', + _ => c, + }) + .collect() +} + +fn event_to_position_file_dump(e: &czsc_core::objects::event::Event) -> Value { + let all = e + .signals_all + .iter() + .map(|s| s.to_string()) + .collect::>(); + let any = e + .signals_any + .iter() + .map(|s| s.to_string()) + .collect::>(); + let not = e + .signals_not + .iter() + .map(|s| s.to_string()) + .collect::>(); + json!({ + "name": e.name, + "operate": e.operate.to_chinese(), + "signals_all": all, + "signals_any": any, + "signals_not": not, + }) +} + +fn position_to_position_file_dump(p: &Position) -> Value { + json!({ + "name": p.name, + "opens": p.opens.iter().map(event_to_position_file_dump).collect::>(), + "exits": p.exits.iter().map(event_to_position_file_dump).collect::>(), + "interval": p.interval, + "timeout": p.timeout, + "stop_loss": p.stop_loss, + "T0": p.t0, + }) +} + +fn position_to_position_file_repr_for_md5(p: &Position) -> String { + let event_repr = |e: &czsc_core::objects::event::Event| { + let all = e + .signals_all + .iter() + .map(|s| s.to_string()) + .collect::>(); + let any = e + .signals_any + .iter() + .map(|s| s.to_string()) + .collect::>(); + let not = e + .signals_not + .iter() + .map(|s| s.to_string()) + .collect::>(); + format!( + "{{'name': '{}', 'operate': '{}', 'signals_all': {}, 'signals_any': {}, 'signals_not': {}}}", + py_escape_str(&e.name), + py_escape_str(e.operate.to_chinese()), + py_repr_list_str(&all), + py_repr_list_str(&any), + py_repr_list_str(¬), + ) + }; + let opens = p + .opens + .iter() + .map(event_repr) + .collect::>() + .join(", "); + let exits = p + .exits + .iter() + .map(event_repr) + .collect::>() + .join(", "); + format!( + "{{'name': '{}', 'opens': [{}], 'exits': [{}], 'interval': {}, 'timeout': {}, 'stop_loss': {}, 'T0': {}}}", + py_escape_str(&p.name), + opens, + exits, + p.interval, + p.timeout, + py_float_repr(p.stop_loss), + if p.t0 { "True" } else { "False" }, + ) +} + +fn write_positions_json(positions: &[Position], out_dir: &Path) -> PyResult<()> { + fs::create_dir_all(out_dir).map_err(|e| { + PyValueError::new_err(format!( + "创建 positions 目录失败 {}: {e}", + out_dir.display() + )) + })?; + for pos in positions { + let stem = sanitize_file_stem(&pos.name); + let file = out_dir.join(format!("{stem}.json")); + let mut dump = position_to_position_file_dump(pos); + let md5_repr = position_to_position_file_repr_for_md5(pos); + let md5_val = format!("{:x}", md5::compute(md5_repr.as_bytes())); + if let Value::Object(ref mut map) = dump { + map.insert("md5".to_string(), Value::String(md5_val)); + } + let content = serde_json::to_string_pretty(&dump) + .map_err(|e| PyRuntimeError::new_err(format!("序列化 Position 失败: {e}")))?; + fs::write(&file, content).map_err(|e| { + PyRuntimeError::new_err(format!("写入 Position 失败 {}: {e}", file.display())) + })?; + } + Ok(()) +} + +fn read_parquet_if_exists(path: &Path) -> PyResult> { + if !path.exists() { + return Ok(None); + } + let file = fs::File::open(path) + .map_err(|e| PyValueError::new_err(format!("读取 parquet 失败 {}: {e}", path.display())))?; + let df = ParquetReader::new(file).finish().map_err(|e| { + PyRuntimeError::new_err(format!("解析 parquet 失败 {}: {e}", path.display())) + })?; + Ok(Some(df)) +} + +fn df_first_f64(df: &DataFrame, col_name: &str) -> Option { + let series = df.column(col_name).ok()?; + let casted = series.cast(&DataType::Float64).ok()?; + let ca = casted.f64().ok()?; + ca.get(0) +} + +fn round_to(v: f64, digits: i32) -> f64 { + if !v.is_finite() { + return v; + } + let p = 10_f64.powi(digits); + (v * p).round_ties_even() / p +} + +fn mean(vals: &[f64]) -> f64 { + if vals.is_empty() { + return f64::NAN; + } + vals.iter().sum::() / vals.len() as f64 +} + +fn sample_std(vals: &[f64]) -> f64 { + if vals.len() <= 1 { + return f64::NAN; + } + let m = mean(vals); + let var = vals.iter().map(|x| (x - m).powi(2)).sum::() / (vals.len() as f64 - 1.0); + var.sqrt() +} + +fn py_cap_max(v: f64, cap: f64) -> f64 { + if v.is_nan() { f64::NAN } else { v.min(cap) } +} + +fn cal_break_even_point(seq: &[f64]) -> f64 { + if seq.is_empty() { + return 0.0; + } + if seq.iter().sum::() < 0.0 { + return 1.0; + } + let mut sorted = seq.to_vec(); + sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + let mut csum = 0.0; + let mut neg_cnt = 0usize; + for v in sorted { + csum += v; + if csum < 0.0 { + neg_cnt += 1; + } + } + (neg_cnt as f64 + 1.0) / seq.len() as f64 +} + +fn py_float_repr(v: f64) -> String { + if !v.is_finite() { + return "0.0".to_string(); + } + if v.fract().abs() < 1e-12 { + format!("{v:.1}") + } else { + format!("{v}") + } +} + +fn event_to_opt_dump_repr(e: &czsc_core::objects::event::Event) -> String { + let signal_kv_repr = |s: &czsc_core::objects::signal::Signal| { + format!( + "{{'key': '{}', 'value': '{}'}}", + py_escape_str(&s.key()), + py_escape_str(&s.value()) + ) + }; + let all = format!( + "[{}]", + e.signals_all + .iter() + .map(signal_kv_repr) + .collect::>() + .join(", ") + ); + let any = format!( + "[{}]", + e.signals_any + .iter() + .map(signal_kv_repr) + .collect::>() + .join(", ") + ); + let not = format!( + "[{}]", + e.signals_not + .iter() + .map(signal_kv_repr) + .collect::>() + .join(", ") + ); + format!( + "{{'name': '{}', 'operate': '{}', 'signals_all': {}, 'signals_any': {}, 'signals_not': {}}}", + py_escape_str(&e.name), + py_escape_str(e.operate.to_chinese()), + all, + any, + not, + ) +} + +fn position_to_opt_dump_repr(p: &Position) -> String { + let opens = p + .opens + .iter() + .map(event_to_opt_dump_repr) + .collect::>() + .join(", "); + let exits = p + .exits + .iter() + .map(event_to_opt_dump_repr) + .collect::>() + .join(", "); + format!( + "{{'symbol': 'symbol', 'name': '{}', 'opens': [{}], 'exits': [{}], 'interval': {}, 'timeout': {}, 'stop_loss': {}, 'T0': {}, 'pairs': [], 'holds': []}}", + py_escape_str(&p.name), + opens, + exits, + p.interval, + p.timeout, + py_float_repr(p.stop_loss), + if p.t0 { "True" } else { "False" }, + ) +} + +fn collect_optimize_report_rows( + positions: &[Position], + symbols: &[String], + poss_dir: &Path, +) -> PyResult> { + let mut rows: Vec = Vec::new(); + + for pos in positions { + let mut pair_lfs = Vec::new(); + let mut hold_lfs = Vec::new(); + + for sym in symbols { + let sym_dir = poss_dir.join(sym); + let pairs_path = sym_dir.join(format!("{}.pairs.parquet", pos.name)); + let holds_path = sym_dir.join(format!("{}.holds.parquet", pos.name)); + + if let Some(df) = read_parquet_if_exists(&pairs_path)? { + pair_lfs.push(df.lazy()); + } + if let Some(df) = read_parquet_if_exists(&holds_path)? { + hold_lfs.push(df.lazy()); + } + } + + if pair_lfs.is_empty() || hold_lfs.is_empty() { + continue; + } + + let pairs = concat(pair_lfs, UnionArgs::default()) + .and_then(|lf| lf.collect()) + .map_err(|e| PyRuntimeError::new_err(format!("合并 pairs 失败: {e}")))?; + let holds = concat(hold_lfs, UnionArgs::default()) + .and_then(|lf| lf.collect()) + .map_err(|e| PyRuntimeError::new_err(format!("合并 holds 失败: {e}")))?; + + if pairs.height() == 0 || holds.height() == 0 { + continue; + } + if holds.column("dt").is_err() + || holds.column("n1b").is_err() + || holds.column("pos").is_err() + { + continue; + } + + let cross_df = holds + .lazy() + .group_by([col("dt")]) + .agg([ + ((col("n1b").cast(DataType::Float64) * col("pos").cast(DataType::Float64)).sum() + / (col("pos").neq(lit(0)).cast(DataType::Float64).sum() + lit(1.0))) + .alias("cross_ret"), + (col("n1b").cast(DataType::Float64) * col("pos").cast(DataType::Float64)) + .mean() + .alias("cross1_ret"), + ]) + .select([ + col("cross_ret").sum().alias("截面等权收益"), + col("cross1_ret").sum().alias("截面品种等权"), + ]) + .collect() + .map_err(|e| PyRuntimeError::new_err(format!("计算截面统计失败: {e}")))?; + + let cross = df_first_f64(&cross_df, "截面等权收益").unwrap_or(0.0); + let cross1 = df_first_f64(&cross_df, "截面品种等权").unwrap_or(0.0); + let total_trades = pairs.height() as i64; + + let pairs_enhanced = pairs + .lazy() + .with_columns([ + col("开仓时间").dt().strftime("%Y-%m-%d").alias("开仓日"), + col("开仓时间") + .dt() + .strftime("%Y-%m-%d %H:%M:%S") + .alias("开仓时间文本"), + col("平仓时间") + .dt() + .strftime("%Y-%m-%d %H:%M:%S") + .alias("平仓时间文本"), + ]) + .collect() + .map_err(|e| PyRuntimeError::new_err(format!("增强 pairs 字段失败: {e}")))?; + + let profit_casted = pairs_enhanced + .column("盈亏比例") + .and_then(|s| s.cast(&DataType::Float64)) + .map_err(|e| PyRuntimeError::new_err(format!("读取盈亏比例失败: {e}")))?; + let profit_ca = profit_casted + .f64() + .map_err(|e| PyRuntimeError::new_err(format!("盈亏比例类型错误: {e}")))?; + let profits: Vec = profit_ca.into_iter().flatten().collect(); + if profits.is_empty() { + continue; + } + + let hold_days_casted = pairs_enhanced + .column("持仓天数") + .and_then(|s| s.cast(&DataType::Float64)) + .map_err(|e| PyRuntimeError::new_err(format!("读取持仓天数失败: {e}")))?; + let hold_days_ca = hold_days_casted + .f64() + .map_err(|e| PyRuntimeError::new_err(format!("持仓天数类型错误: {e}")))?; + let hold_days: Vec = hold_days_ca.into_iter().flatten().collect(); + + let hold_bars_casted = pairs_enhanced + .column("持仓K线数") + .and_then(|s| s.cast(&DataType::Float64)) + .map_err(|e| PyRuntimeError::new_err(format!("读取持仓K线数失败: {e}")))?; + let hold_bars_ca = hold_bars_casted + .f64() + .map_err(|e| PyRuntimeError::new_err(format!("持仓K线数类型错误: {e}")))?; + let hold_bars: Vec = hold_bars_ca.into_iter().flatten().collect(); + + let start_time = pairs_enhanced + .column("开仓时间文本") + .ok() + .and_then(|s| s.str().ok()) + .and_then(|ca| ca.into_iter().flatten().map(|x| x.to_string()).min()); + let end_time = pairs_enhanced + .column("平仓时间文本") + .ok() + .and_then(|s| s.str().ok()) + .and_then(|ca| ca.into_iter().flatten().map(|x| x.to_string()).max()); + + let symbol_count = pairs_enhanced + .column("标的代码") + .ok() + .and_then(|s| s.str().ok()) + .map(|ca| { + let mut ss = HashSet::new(); + for x in ca.into_iter().flatten() { + ss.insert(x.to_string()); + } + ss.len() as i64 + }) + .unwrap_or(0); + + let mut open_day_profit: HashMap> = HashMap::new(); + let open_day_ca = pairs_enhanced + .column("开仓日") + .and_then(|s| s.str()) + .map_err(|e| PyRuntimeError::new_err(format!("读取开仓日失败: {e}")))?; + for i in 0..pairs_enhanced.height() { + if let (Some(day), Some(p)) = (open_day_ca.get(i), profit_ca.get(i)) { + open_day_profit.entry(day.to_string()).or_default().push(p); + } + } + let open_day_be = if open_day_profit.is_empty() { + 0.0 + } else { + let x = open_day_profit + .values() + .map(|v| cal_break_even_point(v)) + .sum::() + / open_day_profit.len() as f64; + round_to(x, 4) + }; + + let avg_profit = round_to(mean(&profits), 4); + let std_profit = round_to(sample_std(&profits), 4); + let max_profit = round_to( + profits + .iter() + .copied() + .fold(f64::NEG_INFINITY, |a, b| if a > b { a } else { b }), + 4, + ); + let min_profit = round_to( + profits + .iter() + .copied() + .fold(f64::INFINITY, |a, b| if a < b { a } else { b }), + 4, + ); + let win_n = profits.iter().filter(|x| **x > 0.0).count() as f64; + let total_n = profits.len() as f64; + let win_pct = round_to(win_n / total_n, 4); + + let gain_vals: Vec = profits.iter().copied().filter(|x| *x > 0.0).collect(); + let loss_vals: Vec = profits.iter().copied().filter(|x| *x <= 0.0).collect(); + let gain_mean = mean(&gain_vals); + let loss_mean = mean(&loss_vals); + let gain_sum: f64 = gain_vals.iter().sum(); + let loss_sum: f64 = loss_vals.iter().sum(); + let single_gain_loss_rate = { + let raw = gain_mean / (loss_mean.abs() + 1e-8); + py_cap_max(round_to(raw, 2), 5.0) + }; + let total_gain_loss_rate = { + let raw = gain_sum / (loss_sum.abs() + 1e-8); + py_cap_max(round_to(raw, 2), 5.0) + }; + let trade_score = round_to(total_gain_loss_rate * win_pct, 4); + let edge = round_to(single_gain_loss_rate * win_pct - (1.0 - win_pct), 4); + let break_even = round_to(cal_break_even_point(&profits), 4); + let avg_hold_days = round_to(mean(&hold_days), 2); + let avg_hold_bars = round_to(mean(&hold_bars), 2); + let per_natural_day = round_to(avg_profit / avg_hold_days, 2); + let per_bar = round_to(avg_profit / avg_hold_bars, 2); + + let pos_dump = position_to_opt_dump_repr(pos); + + rows.push(json!({ + "开始时间": start_time, + "结束时间": end_time, + "交易标的数量": symbol_count, + "总体交易次数": total_trades, + "平均持仓天数": avg_hold_days, + "平均持仓K线数": avg_hold_bars, + "平均单笔收益": avg_profit, + "单笔收益标准差": std_profit, + "最大单笔收益": max_profit, + "最小单笔收益": min_profit, + "交易胜率": win_pct, + "单笔盈亏比": single_gain_loss_rate, + "累计盈亏比": total_gain_loss_rate, + "交易得分": trade_score, + "赢面": edge, + "盈亏平衡点": break_even, + "开仓日盈亏平衡点": open_day_be, + "每自然日收益": per_natural_day, + "每根K线收益": per_bar, + "截面等权收益": cross, + "截面品种等权": cross1, + "pos_name": pos.name, + "pos_dump": pos_dump, + })); + } + + rows.sort_by(|a, b| { + let av = a + .get("截面等权收益") + .and_then(Value::as_f64) + .unwrap_or(f64::NEG_INFINITY); + let bv = b + .get("截面等权收益") + .and_then(Value::as_f64) + .unwrap_or(f64::NEG_INFINITY); + bv.partial_cmp(&av).unwrap_or(std::cmp::Ordering::Equal) + }); + + Ok(rows) +} + +fn write_optimize_report_xlsx(rows: &[Value], out_path: &Path) -> PyResult<()> { + let mut workbook = Workbook::new(); + let worksheet = workbook.add_worksheet(); + + let headers = [ + "开始时间", + "结束时间", + "交易标的数量", + "总体交易次数", + "平均持仓天数", + "平均持仓K线数", + "平均单笔收益", + "单笔收益标准差", + "最大单笔收益", + "最小单笔收益", + "交易胜率", + "单笔盈亏比", + "累计盈亏比", + "交易得分", + "赢面", + "盈亏平衡点", + "开仓日盈亏平衡点", + "每自然日收益", + "每根K线收益", + "截面等权收益", + "截面品种等权", + "pos_name", + "pos_dump", + ]; + + for (c, h) in headers.iter().enumerate() { + worksheet + .write_string(0, c as u16, *h) + .map_err(|e| PyRuntimeError::new_err(format!("写入 xlsx 表头失败: {e}")))?; + } + + for (r, row) in rows.iter().enumerate() { + let rr = (r + 1) as u32; + for (c, h) in headers.iter().enumerate() { + let cc = c as u16; + let v = row.get(*h).unwrap_or(&Value::Null); + match v { + Value::Null => {} + Value::Bool(b) => { + worksheet.write_boolean(rr, cc, *b).map_err(|e| { + PyRuntimeError::new_err(format!("写入 xlsx 布尔值失败: {e}")) + })?; + } + Value::Number(n) => { + if let Some(f) = n.as_f64() { + worksheet.write_number(rr, cc, f).map_err(|e| { + PyRuntimeError::new_err(format!("写入 xlsx 数值失败: {e}")) + })?; + } else { + worksheet.write_string(rr, cc, n.to_string()).map_err(|e| { + PyRuntimeError::new_err(format!("写入 xlsx 字符串数值失败: {e}")) + })?; + } + } + Value::String(s) => { + worksheet.write_string(rr, cc, s).map_err(|e| { + PyRuntimeError::new_err(format!("写入 xlsx 字符串失败: {e}")) + })?; + } + _ => { + let s = serde_json::to_string(v).map_err(|e| { + PyRuntimeError::new_err(format!("序列化 xlsx 单元格失败: {e}")) + })?; + worksheet.write_string(rr, cc, s).map_err(|e| { + PyRuntimeError::new_err(format!("写入 xlsx JSON 字符串失败: {e}")) + })?; + } + } + } + } + + workbook + .save(out_path) + .map_err(|e| PyRuntimeError::new_err(format!("写入 xlsx 失败: {e}"))) +} + +/// 运行批量优化任务。 +/// +/// Python 侧通常不会直接构造 Rust `ExecutionPlan`,而是先准备: +/// - `bars_dir`: 每个 symbol 一个 parquet 文件的目录,文件名形如 `{symbol}.parquet` +/// - `config_path`: 优化任务 JSON 配置文件,包含 `optim_type / symbols / files_position` +/// - `res_path`: 输出根目录,函数内部会按 task hash 创建子目录 +/// +/// 返回值是简短的文本摘要,包含任务目录和报表路径;详细产物会落到磁盘。 +#[pyfunction] +#[pyo3(text_signature = "(bars_dir, config_path, res_path, n_threads=1)")] +#[pyo3(signature = (bars_dir, config_path, res_path, n_threads=1))] +pub fn run_optimize( + bars_dir: &str, + config_path: &str, + res_path: &str, + n_threads: usize, +) -> PyResult { + let config_content = fs::read_to_string(config_path) + .map_err(|e| PyValueError::new_err(format!("读取 config 错误: {e}")))?; + + let config: Value = serde_json::from_str(&config_content) + .map_err(|e| PyValueError::new_err(format!("解析 config 错误: {e}")))?; + + let optim_type = config["optim_type"].as_str().unwrap_or("open"); + let base_freq_str = config["base_freq"].as_str().unwrap_or("日线"); + let base_freq = base_freq_str + .parse::() + .map_err(|_| PyValueError::new_err("解析 base_freq 失败"))?; + let market = config["market"].as_str(); + let bg_max_count = config["bg_max_count"].as_u64().map(|x| x as usize); + let symbols_val = config["symbols"] + .as_array() + .ok_or_else(|| PyValueError::new_err("缺少 symbols 参数"))?; + + let mut symbols = Vec::new(); + for s in symbols_val { + symbols.push(s.as_str().unwrap().to_string()); + } + + let task_name = config["task_name"] + .as_str() + .map(|s| s.to_string()) + .unwrap_or_else(|| { + if optim_type == "open" { + "入场优化".to_string() + } else { + "出场优化".to_string() + } + }); + + let files_position: Vec = config["files_position"] + .as_array() + .unwrap_or(&vec![]) + .iter() + .map(|v| PathBuf::from(v.as_str().unwrap())) + .collect(); + + let (positions, task_hash) = if optim_type == "open" { + let mut candidate_sigs: Vec = config["candidate_signals"] + .as_array() + .unwrap_or(&vec![]) + .iter() + .map(|v| v.as_str().unwrap().to_string()) + .collect(); + candidate_sigs.sort(); + let mut sorted_symbols = symbols.clone(); + sorted_symbols.sort(); + let digest = md5_upper8(&format!( + "{}_{}", + py_repr_list_str(&candidate_sigs), + py_repr_list_str(&sorted_symbols) + )); + ( + get_open_optim_positions(&files_position, &candidate_sigs) + .map_err(|e| PyValueError::new_err(format!("生成开仓策略错误: {:?}", e)))?, + digest, + ) + } else { + let candidate_events = config["candidate_events"] + .as_array() + .unwrap_or(&vec![]) + .clone(); + let digest = md5_upper8(&format!( + "{}_{}", + py_repr_json(&Value::Array(candidate_events.clone())), + py_repr_list_str(&symbols) + )); + ( + get_exit_optim_positions(&files_position, &candidate_events) + .map_err(|e| PyValueError::new_err(format!("生成平仓策略错误: {:?}", e)))?, + digest, + ) + }; + + let task_dir = Path::new(res_path).join(format!("{task_name}_{task_hash}")); + let poss_dir = task_dir.join("poss"); + let positions_dir = task_dir.join("positions"); + fs::create_dir_all(&poss_dir).map_err(|e| { + PyValueError::new_err(format!("创建结果目录失败 {}: {e}", poss_dir.display())) + })?; + write_positions_json(&positions, &positions_dir)?; + + let mut bars_map = HashMap::new(); + for sym in &symbols { + let file_path = Path::new(bars_dir).join(format!("{}.parquet", sym)); + if file_path.exists() { + let bars = read_bars_from_file(&file_path, base_freq)?; + bars_map.insert(sym.clone(), bars); + } + } + + let sdt_cutoff = config["sdt"] + .as_str() + .and_then(parse_sdt_utc) + .and_then(|dt| FixedOffset::east_opt(0).map(|tz| dt.with_timezone(&tz))); + + symbols_optim_parallel( + symbols.clone(), + bars_map, + positions.clone(), + &poss_dir, + base_freq_str, + market, + bg_max_count, + sdt_cutoff, + n_threads, + ); + + let report_rows = collect_optimize_report_rows(&positions, &symbols, &poss_dir)?; + let report_prefix = if optim_type == "open" { + "入场优化" + } else { + "出场优化" + }; + let report_path = task_dir.join(format!("{report_prefix}_{task_name}_{task_hash}.xlsx")); + if !report_rows.is_empty() { + write_optimize_report_xlsx(&report_rows, &report_path)?; + } + + Ok(format!( + "跑批完成: task_dir={}, report={}", + task_dir.display(), + if report_rows.is_empty() { + "NONE".to_string() + } else { + report_path.display().to_string() + } + )) +} + +#[allow(unused_variables)] +/// 执行一次单标的回测并将 `signals / pairs / holds` 落盘到指定目录。 +/// +/// 参数约定: +/// - `bars_bytes`: Python `pyarrow` 序列化后的 Arrow bytes +/// - `config_path`: 与 Python `czsc` 风格兼容的策略 JSON 文件 +/// - `res_path`: 结果目录,函数会写出三个 parquet 文件 +/// - `opts`: 兼容保留字段,当前未使用 +/// +/// 这个入口适合“本地一次性回测并保留产物”的场景;如果只想拿内存结果, +/// 优先使用 `run_research` / `run_replay`。 +#[pyfunction] +#[pyo3(text_signature = "(bars_bytes, config_path, res_path, opts='')")] +#[pyo3(signature = (bars_bytes, config_path, res_path, opts=""))] +pub fn run_backtest( + py: Python, + bars_bytes: &Bound, + config_path: &str, + res_path: &str, + opts: &str, +) -> PyResult { + // 1. 加载配置 + let config_content = fs::read_to_string(config_path) + .map_err(|e| PyValueError::new_err(format!("读取 config 错误: {e}")))?; + let config: Value = serde_json::from_str(&config_content) + .map_err(|e| PyValueError::new_err(format!("解析 config 错误: {e}")))?; + + let base_freq = config["base_freq"] + .as_str() + .unwrap_or("日线") + .parse::() + .map_err(|_| PyValueError::new_err("解析 base_freq 失败"))?; + + // 2. 解析 DataFrame 拿到 bars + let raw_data = bars_bytes.as_bytes(); + let df = pyarrow_to_df(raw_data) + .map_err(|e| PyValueError::new_err(format!("Arrow bytes 转 DataFrame 失败: {e}")))?; + + let bars = format_standard_kline(df, base_freq) + .map_err(|e| PyValueError::new_err(format!("K线标准化格式错误: {e}")))?; + + // 3. 读取策略与信号配置 + let signals_config: Vec = if let Some(cfgs) = config["signals_config"].as_array() + { + let s = serde_json::to_string(cfgs).unwrap(); + serde_json::from_str(&s).unwrap() + } else { + vec![] + }; + + let positions_val = config["positions"] + .as_array() + .ok_or_else(|| PyValueError::new_err("缺少 positions 配置"))?; + + let mut positions: Vec = Vec::new(); + for p_val in positions_val { + let p_str = serde_json::to_string(p_val).unwrap(); + let pos: Position = serde_json::from_str(&p_str) + .map_err(|e| PyValueError::new_err(format!("Position 解析错误: {e}")))?; + positions.push(pos); + } + + // 4. 构建统一执行计划 + let symbol = config["symbols"] + .as_array() + .and_then(|arr| arr.first()) + .and_then(|v| v.as_str()) + .unwrap_or("UNKNOWN") + .to_string(); + let market = config["market"].as_str().map(|x| x.to_string()); + let sdt = config["sdt"].as_str().map(|x| x.to_string()); + let bg_max_count = config["bg_max_count"].as_u64().map(|x| x as usize); + let plan_input = ExecutionPlanInput { + symbol: symbol.clone(), + base_freq: base_freq.to_string(), + signals_config, + positions, + market, + bg_max_count, + sdt, + include_sdt_bar: config["include_sdt_bar"].as_bool(), + }; + let plan = ExecutionPlan::compile(plan_input) + .map_err(|e| PyValueError::new_err(format!("ExecutionPlan 编译失败: {e}")))?; + let cutoff = plan.sdt.as_deref().and_then(parse_sdt_utc); + let cutoff_bar = cutoff.and_then(|c| bars.iter().find(|b| b.dt == c).cloned()); + let output = UnifiedExecEngine::run(&plan, bars, None, true, false) + .map_err(|e| PyValueError::new_err(format!("UnifiedExecEngine 执行失败: {e}")))?; + let rows = output.signal_rows; + let positions = output.positions; + + // 5. 执行结果统计 + let mut total_pairs = 0; + for pos in &positions { + if let Ok(df) = pos.pairs() { + total_pairs += df.height(); + } + } + println!( + ">>> (RUST) 回放结束。最终所有仓位共产生完毕交易流水 Pairs 的条数为: {}", + total_pairs + ); + + // 6. 存储结果 + let res_dir = Path::new(res_path); + if !res_dir.exists() { + fs::create_dir_all(res_dir).map_err(|e| PyValueError::new_err(e.to_string()))?; + } + + let mut signals_df = build_signals_dataframe(&rows)?; + signals_df = align_signals_python_baseline(signals_df, cutoff, cutoff_bar.as_ref())?; + if signals_df.column("dt").is_ok() { + signals_df = signals_df + .lazy() + .sort(["dt"], SortMultipleOptions::default()) + .collect() + .map_err(|e| PyRuntimeError::new_err(format!("signals 排序失败: {e}")))?; + } + signals_df = normalize_signals_dtypes(signals_df)?; + let (pairs_df, holds_df) = combine_pairs_holds_for_backtest(&positions)?; + + write_df_parquet(&res_dir.join("signals.parquet"), signals_df)?; + write_df_parquet(&res_dir.join("pairs.parquet"), pairs_df)?; + write_df_parquet(&res_dir.join("holds.parquet"), holds_df)?; + + Ok("单次回测完成".to_string()) +} + +#[allow(unused_variables)] +/// 仅生成信号明细,不执行完整交易统计。 +/// +/// 这个入口对应 Python `generate_czsc_signals` 风格用法: +/// - 从 Arrow bytes 读取 bars +/// - 根据 `signals_config` 构建统一执行计划 +/// - 输出 `signals.parquet` +/// +/// 当 `positions` 为空时,内部会注入 `__signals_only__` 占位仓位,以复用统一执行引擎。 +/// `sdt` 参数会覆盖配置中的 `sdt`,常用于信号矩阵或基准对比。 +#[pyfunction] +#[pyo3(text_signature = "(bars_bytes, config_path, out_path, sdt='')")] +#[pyo3(signature = (bars_bytes, config_path, out_path, sdt=""))] +pub fn generate_signals( + py: Python, + bars_bytes: &Bound, + config_path: &str, + out_path: &str, + sdt: &str, +) -> PyResult<()> { + // 1) 读取配置 + let config_content = fs::read_to_string(config_path) + .map_err(|e| PyValueError::new_err(format!("读取 config 错误: {e}")))?; + let config: Value = serde_json::from_str(&config_content) + .map_err(|e| PyValueError::new_err(format!("解析 config 错误: {e}")))?; + + let base_freq = config["base_freq"] + .as_str() + .unwrap_or("日线") + .parse::() + .map_err(|_| PyValueError::new_err("解析 base_freq 失败"))?; + + let signals_config: Vec = if let Some(cfgs) = config["signals_config"].as_array() + { + let s = serde_json::to_string(cfgs).unwrap(); + serde_json::from_str(&s).unwrap_or_default() + } else { + vec![] + }; + + // 2) Arrow bytes -> bars + let raw_data = bars_bytes.as_bytes(); + let df = pyarrow_to_df(raw_data) + .map_err(|e| PyValueError::new_err(format!("Arrow bytes 转 DataFrame 失败: {e}")))?; + let bars = format_standard_kline(df, base_freq) + .map_err(|e| PyValueError::new_err(format!("K线标准化格式错误: {e}")))?; + + // 3) 构建执行计划并执行 + let symbol = config["symbols"] + .as_array() + .and_then(|arr| arr.first()) + .and_then(|v| v.as_str()) + .unwrap_or("UNKNOWN") + .to_string(); + let mut positions: Vec = if let Some(arr) = config["positions"].as_array() { + let mut out = Vec::with_capacity(arr.len()); + for p_val in arr { + let p_str = serde_json::to_string(p_val).unwrap(); + let pos: Position = serde_json::from_str(&p_str) + .map_err(|e| PyValueError::new_err(format!("Position 解析错误: {e}")))?; + out.push(pos); + } + out + } else { + vec![] + }; + // generate_signals 允许无仓位场景:注入一个无事件占位仓位以复用统一执行引擎 + if positions.is_empty() { + let stub: Position = serde_json::from_value(serde_json::json!({ + "name": "__signals_only__", + "symbol": symbol, + "opens": [], + "exits": [], + "interval": 0, + "timeout": 1, + "stop_loss": 1000.0, + "T0": false + })) + .map_err(|e| PyValueError::new_err(format!("构建 signals-only 占位 Position 失败: {e}")))?; + positions.push(stub); + } + let plan = ExecutionPlan::compile(ExecutionPlanInput { + symbol, + base_freq: base_freq.to_string(), + signals_config, + positions, + market: config["market"].as_str().map(|x| x.to_string()), + bg_max_count: config["bg_max_count"].as_u64().map(|x| x as usize), + sdt: config["sdt"].as_str().map(|x| x.to_string()), + include_sdt_bar: config["include_sdt_bar"].as_bool(), + }) + .map_err(|e| PyValueError::new_err(format!("ExecutionPlan 编译失败: {e}")))?; + let sdt_override = if sdt.is_empty() { None } else { Some(sdt) }; + let output = UnifiedExecEngine::run(&plan, bars, sdt_override, true, false) + .map_err(|e| PyRuntimeError::new_err(format!("UnifiedExecEngine 执行失败: {e}")))?; + let rows = output.signal_rows; + + // 4) 行转列,写 parquet + let mut out_df = build_signals_dataframe(&rows)?; + out_df = normalize_signals_dtypes(out_df)?; + + let out = Path::new(out_path); + let file_path = if out.extension().is_some() { + out.to_path_buf() + } else { + if !out.exists() { + fs::create_dir_all(out).map_err(|e| PyValueError::new_err(e.to_string()))?; + } + out.join("signals.parquet") + }; + let mut file = fs::File::create(&file_path) + .map_err(|e| PyValueError::new_err(format!("创建输出文件失败: {e}")))?; + ParquetWriter::new(&mut file) + .finish(&mut out_df) + .map_err(|e| PyValueError::new_err(format!("写出 parquet 失败: {e}")))?; + + Ok(()) +} + +/// 返回所有已注册信号函数的只读元信息。 +/// +/// 每个元素都是一个 `dict`,包含: +/// - `name`: 信号函数名 +/// - `param_template`: 参数模板 +/// - `category`: `kline` 或 `trader` +/// - `namespace`: 命名空间前缀,如 `bar / tas / cxt / pos` +/// +/// 常用于: +/// - 构建 parity 覆盖矩阵 +/// - 生成文档或自动补全 +/// - 检查 Rust / Python 共享信号交集 +#[pyfunction] +#[pyo3(text_signature = "(include_kline=True, include_trader=True)")] +#[pyo3(signature = (include_kline=true, include_trader=true))] +pub fn list_all_signals( + py: Python<'_>, + include_kline: bool, + include_trader: bool, +) -> PyResult> { + let infos = list_all_registered_signals(include_kline, include_trader); + let list = PyList::empty(py); + for it in infos { + let d = PyDict::new(py); + d.set_item("name", it.name)?; + d.set_item("param_template", it.param_template)?; + d.set_item("category", it.category)?; + d.set_item("namespace", it.namespace)?; + list.append(d)?; + } + Ok(list.unbind()) +} + +/// 从 `unique_signals` 反推出 Rust 运行时 `signals_config`。 +/// +/// 这是 PyO3 暴露给 Python 兼容层的核心反解析入口: +/// - 输入:信号字符串列表,例如 `['60分钟_D1SMA#5_分类V221101_多头_向上_任意_0']` +/// - 输出:可直接放入策略 JSON 的 `list[dict]` +/// +/// 这个函数依赖 Rust 注册表中的 `param_template` 做反解析,因此适合: +/// - `CzscStrategyBase.unique_signals -> signals_config` +/// - parity benchmark 中的 same-source import-swap 场景 +#[pyfunction] +#[pyo3(text_signature = "(unique_signals)")] +#[pyo3(signature = (unique_signals))] +pub fn derive_signals_config(py: Python<'_>, unique_signals: Vec) -> PyResult { + let refs: Vec<&str> = unique_signals.iter().map(String::as_str).collect(); + let configs = get_signals_config(&refs); + let json_str = serde_json::to_string(&configs) + .map_err(|e| PyRuntimeError::new_err(format!("序列化 signals_config 失败: {e}")))?; + let json_mod = py.import("json")?; + Ok(json_mod.call_method1("loads", (json_str,))?.unbind()) +} + +/// 从 `signals_config` 中提取执行所需的全部周期列表。 +/// +/// 返回结果已经按 `czsc` 习惯的中文周期顺序排序,而不是字典序。 +/// 这个入口通常给 Python 兼容层用于自动填充: +/// - `freqs` +/// - `sorted_freqs` +/// - `base_freq` 推导前的候选周期集合 +#[pyfunction] +#[pyo3(text_signature = "(signals_config)")] +#[pyo3(signature = (signals_config))] +pub fn derive_signals_freqs(py: Python<'_>, signals_config: PyObject) -> PyResult> { + let json_mod = py.import("json")?; + let json_str = json_mod + .call_method1("dumps", (signals_config,))? + .extract::()?; + let configs: Vec = serde_json::from_str(&json_str) + .map_err(|e| PyValueError::new_err(format!("signals_config 解析失败: {e}")))?; + Ok(get_signals_freqs(&configs)) +} + +#[cfg(test)] +mod tests { + use super::parse_sdt_utc; + + #[test] + fn test_parse_sdt_utc_supports_iso_t_without_tz() { + let dt = parse_sdt_utc("2023-02-28T12:00:00"); + assert!(dt.is_some()); + } + + #[test] + fn test_parse_sdt_utc_supports_iso_t_with_fractional() { + let dt = parse_sdt_utc("2023-02-28T12:00:00.123456"); + assert!(dt.is_some()); + } +} diff --git a/crates/czsc-python/src/trader/czsc_signals.rs b/crates/czsc-python/src/trader/czsc_signals.rs new file mode 100644 index 000000000..a49f4b733 --- /dev/null +++ b/crates/czsc-python/src/trader/czsc_signals.rs @@ -0,0 +1,293 @@ +use czsc_core::analyze::CZSC; +use czsc_core::objects::bar::RawBar; +use czsc_core::utils::common::create_naive_pandas_timestamp; +use czsc_trader::signals::czsc_signals::CzscSignals; +use czsc_trader::signals::sig_parse::SignalConfig; +use czsc_utils::bar_generator::BarGenerator; +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyList}; +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; +use serde_json::Value; +use std::collections::HashMap; + +/// 将 Python list[dict] 解析为 Vec +pub(crate) fn parse_signals_config(configs: &Bound) -> PyResult> { + let mut result = Vec::with_capacity(configs.len()); + for item in configs.iter() { + let dict = item + .downcast::() + .map_err(|_| PyValueError::new_err("signals_config 中每个元素必须是 dict"))?; + + let name: String = dict + .get_item("name")? + .ok_or_else(|| PyValueError::new_err("signals_config dict 缺少 'name' 字段"))? + .extract()?; + + let freq: Option = match dict.get_item("freq")? { + Some(v) if !v.is_none() => Some(v.extract()?), + _ => None, + }; + + let mut params: HashMap = HashMap::new(); + + // 优先从 "params" 子字典取参数 + if let Some(params_obj) = dict.get_item("params")? + && !params_obj.is_none() + && let Ok(params_dict) = params_obj.downcast::() { + for (k, v) in params_dict.iter() { + let key: String = k.extract()?; + let val = py_to_serde_value(&v)?; + params.insert(key, val); + } + } + + // 也支持 flat params:dict 中除 name/freq/params 以外的 key 直接作为参数 + for (k, v) in dict.iter() { + let key: String = k.extract()?; + if key == "name" || key == "freq" || key == "params" { + continue; + } + if let std::collections::hash_map::Entry::Vacant(e) = params.entry(key) { + let val = py_to_serde_value(&v)?; + e.insert(val); + } + } + + result.push(SignalConfig { name, freq, params }); + } + Ok(result) +} + +/// 将 Python 值转换为 serde_json::Value +pub(crate) fn py_to_serde_value(obj: &Bound) -> PyResult { + // bool 必须在 int 前检查,因为 Python 的 bool 是 int 子类 + if let Ok(v) = obj.extract::() { + return Ok(Value::from(v)); + } + if let Ok(v) = obj.extract::() { + return Ok(Value::from(v)); + } + if let Ok(v) = obj.extract::() { + return Ok(Value::from(v)); + } + if let Ok(v) = obj.extract::() { + return Ok(Value::String(v)); + } + if obj.is_none() { + return Ok(Value::Null); + } + // 降级:用 repr 作字符串 + let repr = obj.repr()?.extract::()?; + Ok(Value::String(repr)) +} + +/// CzscSignals 的 PyO3 包装 +#[gen_stub_pyclass] +#[pyclass(name = "CzscSignals", module = "czsc._native")] +pub struct PyCzscSignals { + inner: CzscSignals, + signals_config: Vec, +} + +#[gen_stub_pymethods] +#[pymethods] +impl PyCzscSignals { + #[new] + #[pyo3(signature = (bg, signals_config))] + fn new(bg: BarGenerator, signals_config: &Bound) -> PyResult { + let configs = parse_signals_config(signals_config)?; + let symbol = bg + .freq_bars + .values() + .next() + .and_then(|v| v.read().back().cloned()) + .map(|b| b.symbol.to_string()) + .unwrap_or_default(); + let inner = CzscSignals::new(symbol, bg); + Ok(Self { + inner, + signals_config: configs, + }) + } + + /// 返回类名 + #[getter] + fn name(&self) -> &str { + "CzscSignals" + } + + /// 返回标的代码 + #[getter] + fn symbol(&self) -> &str { + &self.inner.symbol + } + + /// 返回信号字典 s + #[getter] + fn s(&self, py: Python) -> PyResult { + let dict = PyDict::new(py); + for (k, v) in &self.inner.s { + dict.set_item(k, v)?; + } + Ok(dict.into_any().unbind()) + } + + /// 返回各周期 CZSC 分析引擎 + #[getter] + fn kas(&self) -> PyResult> { + Ok(self.inner.kas.clone().into_iter().collect()) + } + + /// 返回所有周期字符串列表 + #[getter] + fn freqs(&self) -> Vec { + self.inner + .bg + .freq_bars + .keys() + .map(|f| f.to_string()) + .collect() + } + + /// 返回基准周期字符串 + #[getter] + fn base_freq(&self) -> String { + self.inner + .bg + .freq_bars + .keys() + .next() + .map(|f| f.to_string()) + .unwrap_or_default() + } + + /// 返回最新时间,作为 pandas Timestamp + #[getter] + fn end_dt(&self, py: Python) -> PyResult> { + if let Some(dt_str) = self.inner.s.get("dt") + && let Ok(dt) = chrono::DateTime::parse_from_rfc3339(dt_str) { + let utc_dt = dt.with_timezone(&chrono::Utc); + let timestamp = create_naive_pandas_timestamp(py, utc_dt)?; + return Ok(Some(timestamp)); + } + Ok(None) + } + + /// 返回当前 bar id + #[getter] + fn bid(&self) -> PyResult> { + if let Some(id_str) = self.inner.s.get("id") { + let id = id_str + .parse::() + .map_err(|e| PyValueError::new_err(format!("解析 id 失败: {e}")))?; + return Ok(Some(id)); + } + Ok(None) + } + + /// 返回最新价格 + #[getter] + fn latest_price(&self) -> PyResult> { + if let Some(close_str) = self.inner.s.get("close") { + let price = close_str + .parse::() + .map_err(|e| PyValueError::new_err(format!("解析 close 失败: {e}")))?; + return Ok(Some(price)); + } + Ok(None) + } + + /// 返回原始信号配置 + #[getter] + fn signals_config(&self, py: Python) -> PyResult { + let list = PyList::empty(py); + for cfg in &self.signals_config { + let dict = PyDict::new(py); + dict.set_item("name", &cfg.name)?; + match &cfg.freq { + Some(f) => dict.set_item("freq", f)?, + None => dict.set_item("freq", py.None())?, + } + let params_dict = PyDict::new(py); + for (k, v) in &cfg.params { + let py_val = serde_value_to_py(py, v)?; + params_dict.set_item(k, py_val)?; + } + dict.set_item("params", params_dict)?; + list.append(dict)?; + } + Ok(list.into_any().unbind()) + } + + /// 更新信号 + fn update_signals(&mut self, bar: &RawBar) { + self.inner.update_signals(bar, &self.signals_config); + } + + /// 获取当前信号字典(同 s 属性) + fn get_signals_by_conf(&self, py: Python) -> PyResult { + self.s(py) + } + + /// Pickle 支持:返回 ``(cls, (bg_clone, signals_config_list))``。 + /// 反序列化时 PyCzscSignals 由原 ``__new__`` 重新构造;缓存的信号 + /// 状态不持久化(与 design doc §2.4 的 multiprocessing 用例一致: + /// 子进程拿到的是构造参数 fresh trader)。 + fn __reduce__(&self, py: Python) -> PyResult { + let bg_clone = self.inner.bg.clone(); + let configs_list = signal_configs_to_pylist(py, &self.signals_config)?; + let constructor = py.get_type::(); + let args = (bg_clone, configs_list).into_pyobject(py)?; + let result = (constructor, args).into_pyobject(py)?; + Ok(result.into_any().unbind()) + } +} + +/// Helper: convert `Vec` back to a Python ``list[dict]`` +/// shaped exactly like ``parse_signals_config`` expects, so +/// ``__reduce__`` -> ``__new__`` round-trips cleanly. +pub(crate) fn signal_configs_to_pylist( + py: Python, + configs: &[SignalConfig], +) -> PyResult { + let list = PyList::empty(py); + for cfg in configs { + let dict = PyDict::new(py); + dict.set_item("name", &cfg.name)?; + match &cfg.freq { + Some(f) => dict.set_item("freq", f)?, + None => dict.set_item("freq", py.None())?, + } + let params_dict = PyDict::new(py); + for (k, v) in &cfg.params { + params_dict.set_item(k, serde_value_to_py(py, v)?)?; + } + dict.set_item("params", params_dict)?; + list.append(dict)?; + } + Ok(list.into_any().unbind()) +} + +/// 将 serde_json::Value 转换为 Python 对象 +pub(crate) fn serde_value_to_py(py: Python, val: &Value) -> PyResult { + match val { + Value::Null => Ok(py.None()), + Value::Bool(b) => Ok(b.into_pyobject(py)?.to_owned().into_any().unbind()), + Value::Number(n) => { + if let Some(i) = n.as_i64() { + Ok(i.into_pyobject(py)?.into_any().unbind()) + } else if let Some(f) = n.as_f64() { + Ok(f.into_pyobject(py)?.into_any().unbind()) + } else { + Ok(py.None()) + } + } + Value::String(s) => Ok(s.as_str().into_pyobject(py)?.into_any().unbind()), + Value::Array(_) | Value::Object(_) => { + let json_str = serde_json::to_string(val) + .map_err(|e| PyValueError::new_err(format!("JSON 序列化失败: {e}")))?; + Ok(json_str.into_pyobject(py)?.into_any().unbind()) + } + } +} diff --git a/crates/czsc-python/src/trader/czsc_trader.rs b/crates/czsc-python/src/trader/czsc_trader.rs new file mode 100644 index 000000000..69340d676 --- /dev/null +++ b/crates/czsc-python/src/trader/czsc_trader.rs @@ -0,0 +1,414 @@ +use super::czsc_signals::{parse_signals_config, serde_value_to_py}; +use chrono::{DateTime, FixedOffset}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::bar::RawBar; +use czsc_core::objects::position::{LiteBar, Position, PyPosition}; +use czsc_core::utils::common::create_naive_pandas_timestamp; +use czsc_trader::signals::sig_parse::SignalConfig; +use czsc_trader::trader::CzscTrader; +use czsc_utils::bar_generator::BarGenerator; +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyList}; +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; +use std::collections::HashMap; + +/// CzscTrader 的 PyO3 包装 +#[gen_stub_pyclass] +#[pyclass(name = "CzscTrader", module = "czsc._native")] +pub struct PyCzscTrader { + inner: CzscTrader, + signals_config: Vec, + ensemble_method: String, +} + +/// 从 PyObject 提取 Position:支持 PyPosition(Rust)和有 _inner 属性的 Python wrapper +fn extract_position(_py: Python, obj: &Bound) -> PyResult { + // 优先尝试提取 PyPosition + if let Ok(py_pos) = obj.extract::() { + return Ok(py_pos.inner); + } + // 尝试从 _inner 属性提取 + if let Ok(inner_attr) = obj.getattr("_inner") + && let Ok(py_pos) = inner_attr.extract::() { + return Ok(py_pos.inner); + } + // 尝试从 inner 属性提取 + if let Ok(inner_attr) = obj.getattr("inner") + && let Ok(py_pos) = inner_attr.extract::() { + return Ok(py_pos.inner); + } + Err(PyValueError::new_err( + "positions 中的元素必须是 Position 或有 _inner/inner 属性的对象", + )) +} + +#[gen_stub_pymethods] +#[pymethods] +impl PyCzscTrader { + #[new] + #[pyo3(signature = (bg, positions, signals_config, ensemble_method = "mean".to_string()))] + fn new( + py: Python, + bg: BarGenerator, + positions: &Bound, + signals_config: &Bound, + ensemble_method: String, + ) -> PyResult { + let configs = parse_signals_config(signals_config)?; + + // 提取 positions + let mut pos_vec = Vec::with_capacity(positions.len()); + for item in positions.iter() { + let pos = extract_position(py, &item)?; + pos_vec.push(pos); + } + + let symbol = bg + .freq_bars + .values() + .next() + .and_then(|v| v.read().back().cloned()) + .map(|b| b.symbol.to_string()) + .unwrap_or_default(); + + let inner = CzscTrader::new(symbol, bg, pos_vec); + + Ok(Self { + inner, + signals_config: configs, + ensemble_method, + }) + } + + /// 返回类名 + #[getter] + fn name(&self) -> &str { + &self.inner.name + } + + /// 返回标的代码 + #[getter] + fn symbol(&self) -> &str { + &self.inner.signals.symbol + } + + /// 返回信号字典 s + #[getter] + fn s(&self, py: Python) -> PyResult { + let dict = PyDict::new(py); + for (k, v) in &self.inner.signals.s { + dict.set_item(k, v)?; + } + Ok(dict.into_any().unbind()) + } + + /// 返回各周期 CZSC 分析引擎 + #[getter] + fn kas(&self) -> PyResult> { + Ok(self.inner.signals.kas.clone().into_iter().collect()) + } + + /// 返回所有周期字符串列表 + #[getter] + fn freqs(&self) -> Vec { + self.inner + .signals + .bg + .freq_bars + .keys() + .map(|f| f.to_string()) + .collect() + } + + /// 返回基准周期字符串 + #[getter] + fn base_freq(&self) -> String { + self.inner + .signals + .bg + .freq_bars + .keys() + .next() + .map(|f| f.to_string()) + .unwrap_or_default() + } + + /// 返回最新时间,作为 pandas Timestamp + #[getter] + fn end_dt(&self, py: Python) -> PyResult> { + if let Some(dt_str) = self.inner.signals.s.get("dt") + && let Ok(dt) = chrono::DateTime::parse_from_rfc3339(dt_str) { + let utc_dt = dt.with_timezone(&chrono::Utc); + let timestamp = create_naive_pandas_timestamp(py, utc_dt)?; + return Ok(Some(timestamp)); + } + Ok(None) + } + + /// 返回当前 bar id + #[getter] + fn bid(&self) -> PyResult> { + if let Some(id_str) = self.inner.signals.s.get("id") { + let id = id_str + .parse::() + .map_err(|e| PyValueError::new_err(format!("解析 id 失败: {e}")))?; + return Ok(Some(id)); + } + Ok(None) + } + + /// 返回最新价格 + #[getter] + fn latest_price(&self) -> PyResult> { + if let Some(close_str) = self.inner.signals.s.get("close") { + let price = close_str + .parse::() + .map_err(|e| PyValueError::new_err(format!("解析 close 失败: {e}")))?; + return Ok(Some(price)); + } + Ok(None) + } + + /// 返回原始信号配置 + #[getter] + fn signals_config(&self, py: Python) -> PyResult { + let list = PyList::empty(py); + for cfg in &self.signals_config { + let dict = PyDict::new(py); + dict.set_item("name", &cfg.name)?; + match &cfg.freq { + Some(f) => dict.set_item("freq", f)?, + None => dict.set_item("freq", py.None())?, + } + let params_dict = PyDict::new(py); + for (k, v) in &cfg.params { + let py_val = serde_value_to_py(py, v)?; + params_dict.set_item(k, py_val)?; + } + dict.set_item("params", params_dict)?; + list.append(dict)?; + } + Ok(list.into_any().unbind()) + } + + /// 返回仓位列表(PyPosition 包装) + #[getter] + fn positions(&self) -> Vec { + self.inner + .positions + .iter() + .map(|p| PyPosition { inner: p.clone() }) + .collect() + } + + /// 返回是否有仓位发生变化 + #[getter] + fn pos_changed(&self) -> bool { + self.inner.positions.iter().any(|p| p.get_pos_changed()) + } + + /// 更新信号和仓位 + fn update(&mut self, bar: &RawBar) { + self.inner.update(bar, &self.signals_config); + } + + /// 更新信号和仓位(同 update) + fn on_bar(&mut self, bar: &RawBar) { + self.inner.update(bar, &self.signals_config); + } + + /// 基于信号字典更新仓位 + fn on_sig(&mut self, _py: Python, sig: &Bound) -> PyResult<()> { + // 解析 sig dict 的 key-value 对,设置到 inner.signals.s 和 signal_map + let mut s_map = HashMap::new(); + for (k, v) in sig.iter() { + let key: String = k.extract()?; + let val: String = v.str()?.to_string(); + s_map.insert(key, val); + } + + // 提取必要字段 + let id: i32 = sig + .get_item("id")? + .ok_or_else(|| PyValueError::new_err("sig 缺少 'id'"))? + .extract() + .or_else(|_| { + sig.get_item("id")? + .ok_or_else(|| PyValueError::new_err("sig 缺少 'id'"))? + .str()? + .to_string() + .parse::() + .map_err(|e| PyValueError::new_err(format!("解析 id 失败: {e}"))) + })?; + + let close: f64 = sig + .get_item("close")? + .ok_or_else(|| PyValueError::new_err("sig 缺少 'close'"))? + .extract() + .or_else(|_| { + sig.get_item("close")? + .ok_or_else(|| PyValueError::new_err("sig 缺少 'close'"))? + .str()? + .to_string() + .parse::() + .map_err(|e| PyValueError::new_err(format!("解析 close 失败: {e}"))) + })?; + + let dt_obj = sig + .get_item("dt")? + .ok_or_else(|| PyValueError::new_err("sig 缺少 'dt'"))?; + let dt = parse_dt_from_pyobj(&dt_obj)?; + + // 设置信号 + self.inner.signals.s = s_map.clone(); + self.inner.signals.signal_map = s_map; + + // 构建 LiteBar + let lite_bar = LiteBar { + id, + dt, + price: close, + }; + + // 更新所有仓位 + for pos in &mut self.inner.positions { + pos.update_profiled_with_signal_map( + lite_bar, + None, + Some(&self.inner.signals.signal_map), + ); + } + + Ok(()) + } + + /// 获取集成后的仓位值 + #[pyo3(signature = (method=None))] + fn get_ensemble_pos(&self, method: Option<&str>) -> f64 { + let method = method.unwrap_or(&self.ensemble_method); + let pos_values: Vec = self + .inner + .positions + .iter() + .map(|p| p.get_pos().to_f64()) + .collect(); + + if pos_values.is_empty() { + return 0.0; + } + + match method { + "mean" => pos_values.iter().sum::() / pos_values.len() as f64, + "max" => pos_values.iter().cloned().fold(f64::NEG_INFINITY, f64::max), + "min" => pos_values.iter().cloned().fold(f64::INFINITY, f64::min), + "vote" => { + let sum: f64 = pos_values.iter().sum(); + if sum > 0.0 { + 1.0 + } else if sum < 0.0 { + -1.0 + } else { + 0.0 + } + } + _ => pos_values.iter().sum::() / pos_values.len() as f64, + } + } + + /// 根据名称获取仓位 + fn get_position(&self, name: &str) -> Option { + self.inner + .positions + .iter() + .find(|p| p.name == name) + .map(|p| PyPosition { inner: p.clone() }) + } + + /// 获取当前信号字典 + fn get_signals_by_conf(&self, py: Python) -> PyResult { + self.s(py) + } + + /// 仅更新信号(不更新仓位) + fn update_signals(&mut self, bar: &RawBar) { + self.inner.signals.update_signals(bar, &self.signals_config); + } + + /// Pickle 支持:返回构造参数 (bg, positions, signals_config, ensemble_method)。 + /// 反序列化时由 ``__new__`` 重新构造一个 fresh trader;缓存的运行 + /// 状态不持久化(与 design doc §2.4 multiprocessing 用例一致)。 + fn __reduce__(&self, py: Python) -> PyResult { + let bg_clone = self.inner.signals.bg.clone(); + + // positions: clone via PyPosition wrappers + let positions_list = PyList::empty(py); + for pos in &self.inner.positions { + let py_pos = PyPosition { inner: pos.clone() }; + positions_list.append(py_pos)?; + } + + let configs_list = super::czsc_signals::signal_configs_to_pylist(py, &self.signals_config)?; + + let constructor = py.get_type::(); + let args = (bg_clone, positions_list, configs_list, self.ensemble_method.clone()).into_pyobject(py)?; + let result = (constructor, args).into_pyobject(py)?; + Ok(result.into_any().unbind()) + } +} + +/// 从 Python 对象解析 DateTime +fn parse_dt_from_pyobj(obj: &Bound) -> PyResult> { + // 尝试提取字符串 + if let Ok(s) = obj.extract::() { + // 尝试 RFC3339 + if let Ok(dt) = DateTime::parse_from_rfc3339(&s) { + return Ok(dt); + } + // 尝试 "%Y-%m-%d %H:%M:%S" + if let Ok(naive) = chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S") { + return Ok(DateTime::from_naive_utc_and_offset( + naive, + FixedOffset::east_opt(0).unwrap(), + )); + } + // 尝试 "%Y-%m-%d" + if let Ok(naive_date) = chrono::NaiveDate::parse_from_str(&s, "%Y-%m-%d") { + let naive = naive_date.and_hms_opt(0, 0, 0).unwrap(); + return Ok(DateTime::from_naive_utc_and_offset( + naive, + FixedOffset::east_opt(0).unwrap(), + )); + } + return Err(PyValueError::new_err(format!("无法解析时间字符串: {s}"))); + } + + // 尝试 pandas Timestamp: 调用 .isoformat() 或 str() + if let Ok(iso) = obj.call_method0("isoformat") + && let Ok(s) = iso.extract::() { + if let Ok(dt) = DateTime::parse_from_rfc3339(&s) { + return Ok(dt); + } + // pandas isoformat 可能不带时区 + if let Ok(naive) = chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S") { + return Ok(DateTime::from_naive_utc_and_offset( + naive, + FixedOffset::east_opt(0).unwrap(), + )); + } + } + + // 最后降级:str(obj) + let s = obj.str()?.to_string(); + if let Ok(naive) = chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S") { + return Ok(DateTime::from_naive_utc_and_offset( + naive, + FixedOffset::east_opt(0).unwrap(), + )); + } + + Err(PyValueError::new_err(format!( + "无法解析时间对象: {}", + obj.repr()? + ))) +} diff --git a/crates/czsc-python/src/trader/generate.rs b/crates/czsc-python/src/trader/generate.rs new file mode 100644 index 000000000..11bb92181 --- /dev/null +++ b/crates/czsc-python/src/trader/generate.rs @@ -0,0 +1,109 @@ +use super::czsc_signals::parse_signals_config; +use czsc_core::objects::bar::RawBar; +use czsc_core::objects::freq::Freq; +use czsc_core::objects::market::Market; +use czsc_trader::signals::czsc_signals::CzscSignals; +use czsc_trader::signals::sig_parse::get_signals_freqs; +use czsc_utils::bar_generator::BarGenerator; +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyList}; +use std::str::FromStr; + +/// 批量生成 CZSC 信号 +/// +/// 参数: +/// bars: 基础周期 K 线列表 +/// signals_config: 信号配置列表 +/// sdt: 信号开始计算日期,格式 "YYYYMMDD" 或 "YYYY-MM-DD" +/// init_n: 预热 K 线数量 +/// df: 是否返回 DataFrame(默认 False 返回 list[dict]) +#[pyfunction] +#[pyo3(signature = (bars, signals_config, sdt="20170101", init_n=500, df=false))] +pub fn generate_czsc_signals( + py: Python, + bars: Vec, + signals_config: &Bound, + sdt: &str, + init_n: usize, + df: bool, +) -> PyResult { + if bars.is_empty() { + return Err(PyValueError::new_err("bars 不能为空")); + } + + let configs = parse_signals_config(signals_config)?; + + // 提取所有信号需要的周期 + let freq_strs = get_signals_freqs(&configs); + let freqs: Vec = freq_strs + .iter() + .filter_map(|s| Freq::from_str(s).ok()) + .collect(); + + // 获取基准周期 + let base_freq = bars[0].freq; + + // 创建 BarGenerator + let bg = BarGenerator::new(base_freq, freqs, 1000, Market::Default) + .map_err(|e| PyValueError::new_err(format!("创建 BarGenerator 失败: {e}")))?; + + // 计算分割点:取 sdt 日期和 init_n 的较大者 + let sdt_normalized = normalize_sdt(sdt); + let mut split_idx = init_n.min(bars.len()); + + // 按 sdt 日期查找分割点 + for (i, bar) in bars.iter().enumerate() { + let bar_date = bar.dt.format("%Y%m%d").to_string(); + if bar_date >= sdt_normalized && i >= init_n { + split_idx = i; + break; + } + } + // 确保不超出范围 + if split_idx >= bars.len() { + split_idx = bars.len().saturating_sub(1); + } + + let (bars_left, bars_right) = bars.split_at(split_idx); + + // 创建 CzscSignals 并预热 + let symbol = bars[0].symbol.to_string(); + let mut signals = CzscSignals::new(symbol, bg); + + // 预热 + for bar in bars_left { + signals.bg.update_bar(bar).ok(); + } + + // prime_signals 用最后一根预热 bar + if let Some(last_warmup) = bars_left.last() { + signals.prime_signals(last_warmup, &configs); + } + + // 计算信号 + let mut records: Vec = Vec::with_capacity(bars_right.len()); + for bar in bars_right { + signals.update_signals(bar, &configs); + let dict = PyDict::new(py); + for (k, v) in &signals.s { + dict.set_item(k, v)?; + } + records.push(dict.into_any().unbind()); + } + + if df { + // 返回 DataFrame + let pandas = py.import("pandas")?; + let df_obj = pandas.call_method1("DataFrame", (records,))?; + Ok(df_obj.into_any().unbind()) + } else { + let list = PyList::new(py, &records)?; + Ok(list.into_any().unbind()) + } +} + +/// 将 sdt 标准化为 "YYYYMMDD" 格式 +fn normalize_sdt(sdt: &str) -> String { + sdt.replace('-', "") +} diff --git a/crates/czsc-python/src/trader/mod.rs b/crates/czsc-python/src/trader/mod.rs new file mode 100644 index 000000000..6443be181 --- /dev/null +++ b/crates/czsc-python/src/trader/mod.rs @@ -0,0 +1,14 @@ +//! PyO3 wrappers for czsc-trader public objects (CzscTrader / CzscSignals), +//! the `generate_czsc_signals` free function, and the research/optimize +//! orchestration entrypoints (`run_research`, `run_replay`, +//! `run_optimize_batch`, `build_*_optim_positions`). +//! +//! Mirrors `rs_czsc/python/src/trader/`. The `weight_backtest` submodule +//! from rs-czsc is intentionally NOT migrated — czsc relies on the +//! external `wbt` package for backtests (design doc §3.1 / §5.10). + +pub mod api; +pub mod czsc_signals; +pub mod czsc_trader; +pub mod generate; +pub mod research; diff --git a/crates/czsc-python/src/trader/research.rs b/crates/czsc-python/src/trader/research.rs new file mode 100644 index 000000000..5c1edc6ec --- /dev/null +++ b/crates/czsc-python/src/trader/research.rs @@ -0,0 +1,632 @@ +use crate::trader::api::{build_signals_dataframe, normalize_signals_dtypes, run_optimize}; +use crate::utils::df_convert::{df_to_pyarrow, pyarrow_to_df}; +#[cfg(test)] +use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc}; +use czsc_core::analyze::utils::format_standard_kline; +use czsc_core::objects::freq::Freq; +use czsc_core::objects::position::Position; +use czsc_trader::engine_v2::{ExecutionPlan, ExecutionPlanInput, UnifiedExecEngine}; +use czsc_trader::optimize::{get_exit_optim_positions, get_open_optim_positions}; +use czsc_trader::signals::sig_parse::SignalConfig; +use czsc_signals::registry::{SIGNAL_REGISTRY, TRADER_SIGNAL_REGISTRY}; +use polars::prelude::*; +use pyo3::exceptions::{PyRuntimeError, PyValueError}; +use pyo3::prelude::*; +use pyo3::types::{PyBytes, PyDict}; +use serde::Deserialize; +use serde_json::Value; +use std::collections::HashMap; +use std::fs; +use std::path::{Path, PathBuf}; + +#[derive(Debug, Clone, Deserialize)] +struct StrategyConfig { + pub name: Option, + pub symbol: String, + pub base_freq: String, + #[allow(dead_code)] + pub signals_module: Option, + #[serde(default)] + pub signals_config: Vec, + pub positions: Vec, + pub market: Option, + pub bg_max_count: Option, + pub sdt: Option, + #[serde(default)] + pub include_sdt_bar: Option, +} + +#[derive(Debug, Deserialize, Default)] +struct RunOpts { + pub emit_signals: Option, +} + +#[cfg(test)] +fn parse_sdt_utc(s: &str) -> Option> { + if s.is_empty() { + return None; + } + if let Ok(dt) = DateTime::parse_from_rfc3339(s) { + return Some(dt.with_timezone(&Utc)); + } + if let Ok(ndt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") { + return Some(DateTime::from_naive_utc_and_offset(ndt, Utc)); + } + if let Ok(ndt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S") { + return Some(DateTime::from_naive_utc_and_offset(ndt, Utc)); + } + if let Ok(ndt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f") { + return Some(DateTime::from_naive_utc_and_offset(ndt, Utc)); + } + if let Ok(d) = NaiveDate::parse_from_str(s, "%Y-%m-%d") + && let Some(ndt) = d.and_hms_opt(0, 0, 0) + { + return Some(DateTime::from_naive_utc_and_offset(ndt, Utc)); + } + if let Ok(d) = NaiveDate::parse_from_str(s, "%Y%m%d") + && let Some(ndt) = d.and_hms_opt(0, 0, 0) + { + return Some(DateTime::from_naive_utc_and_offset(ndt, Utc)); + } + None +} + +fn validate_strategy(cfg: &StrategyConfig) -> PyResult<()> { + if cfg.symbol.trim().is_empty() { + return Err(PyValueError::new_err("strategy.symbol 不能为空")); + } + if cfg.positions.is_empty() { + return Err(PyValueError::new_err("strategy.positions 不能为空")); + } + + for sc in &cfg.signals_config { + if sc.freq.is_some() { + if !SIGNAL_REGISTRY.contains_key(sc.name.as_str()) { + return Err(PyValueError::new_err(format!( + "signals_config 包含未注册 K 线信号: {}", + sc.name + ))); + } + } else if !TRADER_SIGNAL_REGISTRY.contains_key(sc.name.as_str()) { + return Err(PyValueError::new_err(format!( + "signals_config 包含未注册 Trader 信号: {}", + sc.name + ))); + } + } + + Ok(()) +} + +fn combine_pairs_holds(positions: &[Position]) -> PyResult<(DataFrame, DataFrame)> { + let mut all_pairs = Vec::new(); + let mut all_holds = Vec::new(); + + for pos in positions { + if let Ok(df) = pos.pairs() { + all_pairs.push(df.lazy()); + } + if let Ok(mut df) = pos.holds() { + if df.height() == 0 { + continue; + } + let pos_name = Series::new("pos_name", vec![pos.name.clone(); df.height()]); + df.with_column(pos_name) + .map_err(|e| PyRuntimeError::new_err(format!("追加 pos_name 列失败: {e}")))?; + all_holds.push(df.lazy()); + } + } + + let mut pairs_df = if all_pairs.is_empty() { + DataFrame::default() + } else { + concat(all_pairs, UnionArgs::default()) + .and_then(|lf| lf.collect()) + .map_err(|e| PyRuntimeError::new_err(format!("合并 pairs 失败: {e}")))? + }; + + for (name, dtype) in [ + ("开仓时间", DataType::Datetime(TimeUnit::Nanoseconds, None)), + ("平仓时间", DataType::Datetime(TimeUnit::Nanoseconds, None)), + ("持仓K线数", DataType::Int64), + ] { + if pairs_df.column(name).is_ok() { + let casted = pairs_df + .column(name) + .and_then(|s| s.cast(&dtype)) + .map_err(|e| { + PyRuntimeError::new_err(format!("pairs 列 {name} 类型转换失败: {e}")) + })?; + pairs_df + .with_column(casted) + .map_err(|e| PyRuntimeError::new_err(format!("pairs 写回列 {name} 失败: {e}")))?; + } + } + + let mut holds_df = if all_holds.is_empty() { + DataFrame::default() + } else { + concat(all_holds, UnionArgs::default()) + .and_then(|lf| lf.collect()) + .map_err(|e| PyRuntimeError::new_err(format!("合并 holds 失败: {e}")))? + }; + + for (name, dtype) in [ + ("dt", DataType::Datetime(TimeUnit::Nanoseconds, None)), + ("pos", DataType::Int64), + ] { + if holds_df.column(name).is_ok() { + let casted = holds_df + .column(name) + .and_then(|s| s.cast(&dtype)) + .map_err(|e| { + PyRuntimeError::new_err(format!("holds 列 {name} 类型转换失败: {e}")) + })?; + holds_df + .with_column(casted) + .map_err(|e| PyRuntimeError::new_err(format!("holds 写回列 {name} 失败: {e}")))?; + } + } + + Ok((pairs_df, holds_df)) +} + +fn write_df_parquet(path: &Path, mut df: DataFrame) -> PyResult<()> { + let mut file = fs::File::create(path) + .map_err(|e| PyValueError::new_err(format!("创建输出文件失败: {e}")))?; + ParquetWriter::new(&mut file) + .finish(&mut df) + .map_err(|e| PyRuntimeError::new_err(format!("写出 parquet 失败: {e}")))?; + Ok(()) +} + +type ResearchCoreResult = ( + StrategyConfig, + usize, + Vec>, + DataFrame, + DataFrame, + i64, + Option, +); + +#[derive(Debug, Clone, Copy, Default)] +struct CoreLoopProfile { + bars: usize, + signals_update_ns: u128, + trader_signals_ns: u128, + position_update_ns: u128, + pos_event_match_ns: u128, + pos_fsm_ns: u128, + pos_risk_ns: u128, + pos_holds_ns: u128, +} + +impl CoreLoopProfile { + fn total_ns(&self) -> u128 { + self.signals_update_ns + self.trader_signals_ns + self.position_update_ns + } +} + +fn event_to_py_dump(e: &czsc_core::objects::event::Event) -> Value { + serde_json::json!({ + "name": e.name, + "operate": e.operate.to_chinese(), + "signals_all": e.signals_all.iter().map(|s| s.to_string()).collect::>(), + "signals_any": e.signals_any.iter().map(|s| s.to_string()).collect::>(), + "signals_not": e.signals_not.iter().map(|s| s.to_string()).collect::>(), + }) +} + +fn position_to_py_dump(p: &Position) -> Value { + serde_json::json!({ + "name": p.name, + "symbol": p.symbol, + "opens": p.opens.iter().map(event_to_py_dump).collect::>(), + "exits": p.exits.iter().map(event_to_py_dump).collect::>(), + "interval": p.interval, + "timeout": p.timeout, + "stop_loss": p.stop_loss, + "T0": p.t0, + }) +} + +fn run_research_core( + bars_raw: &[u8], + strategy_json: &str, + sdt_override: Option<&str>, + emit_signals: bool, +) -> PyResult { + let cfg: StrategyConfig = serde_json::from_str(strategy_json) + .map_err(|e| PyValueError::new_err(format!("strategy json 解析失败: {e}")))?; + validate_strategy(&cfg)?; + + let df = pyarrow_to_df(bars_raw) + .map_err(|e| PyValueError::new_err(format!("Arrow bytes 转 DataFrame 失败: {e}")))?; + let base_freq = cfg + .base_freq + .parse::() + .map_err(|_| PyValueError::new_err("strategy.base_freq 解析失败"))?; + let bars = format_standard_kline(df, base_freq) + .map_err(|e| PyValueError::new_err(format!("K线标准化格式错误: {e}")))?; + + if bars.is_empty() { + return Err(PyValueError::new_err("bars 为空,无法执行回测")); + } + + let plan_input = ExecutionPlanInput { + symbol: cfg.symbol.clone(), + base_freq: cfg.base_freq.clone(), + signals_config: cfg.signals_config.clone(), + positions: cfg.positions.clone(), + market: cfg.market.clone(), + bg_max_count: cfg.bg_max_count, + sdt: cfg.sdt.clone(), + include_sdt_bar: cfg.include_sdt_bar, + }; + let plan = ExecutionPlan::compile(plan_input) + .map_err(|e| PyValueError::new_err(format!("ExecutionPlan 编译失败: {e}")))?; + + let enable_profile = std::env::var("RS_CZSC_PROFILE_CORE") + .map(|v| v == "1" || v.eq_ignore_ascii_case("true")) + .unwrap_or(false); + let output = UnifiedExecEngine::run(&plan, bars, sdt_override, emit_signals, enable_profile) + .map_err(|e| PyRuntimeError::new_err(format!("UnifiedExecEngine 执行失败: {e}")))?; + let (pairs_df, holds_df) = combine_pairs_holds(&output.positions)?; + let elapsed_ms = output.elapsed_ms; + let rows = output.signal_rows; + let bars_count = output.bars_count; + let profile = output.profile.map(|p| CoreLoopProfile { + bars: p.bars, + signals_update_ns: p.signals_update_ns, + trader_signals_ns: p.trader_signals_ns, + position_update_ns: p.position_update_ns, + pos_event_match_ns: p.pos_event_match_ns, + pos_fsm_ns: p.pos_fsm_ns, + pos_risk_ns: p.pos_risk_ns, + pos_holds_ns: p.pos_holds_ns, + }); + + Ok(( + cfg, bars_count, rows, pairs_df, holds_df, elapsed_ms, profile, + )) +} + +#[allow(clippy::too_many_arguments)] +fn build_result_dict( + py: Python<'_>, + cfg: &StrategyConfig, + bars_count: usize, + rows: &[HashMap], + signals_df: &DataFrame, + pairs_df: &DataFrame, + holds_df: &DataFrame, + elapsed_ms: i64, + profile: Option, + extra_paths: Option<(&str, &str, &str)>, +) -> PyResult> { + let mut signals_df_mut = signals_df.clone(); + let mut pairs_df_mut = pairs_df.clone(); + let mut holds_df_mut = holds_df.clone(); + + let signals_arrow = df_to_pyarrow(&mut signals_df_mut) + .map_err(|e| PyRuntimeError::new_err(format!("signals Arrow 编码失败: {e}")))?; + let pairs_arrow = df_to_pyarrow(&mut pairs_df_mut) + .map_err(|e| PyRuntimeError::new_err(format!("pairs Arrow 编码失败: {e}")))?; + let holds_arrow = df_to_pyarrow(&mut holds_df_mut) + .map_err(|e| PyRuntimeError::new_err(format!("holds Arrow 编码失败: {e}")))?; + + let meta = PyDict::new(py); + meta.set_item("symbol", cfg.symbol.clone())?; + meta.set_item("strategy_name", cfg.name.clone().unwrap_or_default())?; + meta.set_item("base_freq", cfg.base_freq.clone())?; + meta.set_item("bars_count", bars_count)?; + meta.set_item("signals_count", rows.len())?; + meta.set_item("positions", cfg.positions.len())?; + meta.set_item("elapsed_ms", elapsed_ms)?; + meta.set_item("warning_count", 0)?; + if let Some(p) = profile { + let pyd = PyDict::new(py); + let total_ns = p.total_ns() as f64; + let signals_ns = p.signals_update_ns as f64; + let trader_ns = p.trader_signals_ns as f64; + let pos_ns = p.position_update_ns as f64; + let pos_event_ns = p.pos_event_match_ns as f64; + let pos_fsm_ns = p.pos_fsm_ns as f64; + let pos_risk_ns = p.pos_risk_ns as f64; + let pos_holds_ns = p.pos_holds_ns as f64; + pyd.set_item("bars", p.bars)?; + pyd.set_item("signals_update_ms", signals_ns / 1_000_000.0)?; + pyd.set_item("trader_signals_ms", trader_ns / 1_000_000.0)?; + pyd.set_item("position_update_ms", pos_ns / 1_000_000.0)?; + pyd.set_item("total_profiled_ms", total_ns / 1_000_000.0)?; + pyd.set_item("position_event_match_ms", pos_event_ns / 1_000_000.0)?; + pyd.set_item("position_fsm_ms", pos_fsm_ns / 1_000_000.0)?; + pyd.set_item("position_risk_ms", pos_risk_ns / 1_000_000.0)?; + pyd.set_item("position_holds_ms", pos_holds_ns / 1_000_000.0)?; + if total_ns > 0.0 { + pyd.set_item("signals_update_pct", signals_ns * 100.0 / total_ns)?; + pyd.set_item("trader_signals_pct", trader_ns * 100.0 / total_ns)?; + pyd.set_item("position_update_pct", pos_ns * 100.0 / total_ns)?; + } else { + pyd.set_item("signals_update_pct", 0.0)?; + pyd.set_item("trader_signals_pct", 0.0)?; + pyd.set_item("position_update_pct", 0.0)?; + } + if pos_ns > 0.0 { + pyd.set_item("position_event_match_pct", pos_event_ns * 100.0 / pos_ns)?; + pyd.set_item("position_fsm_pct", pos_fsm_ns * 100.0 / pos_ns)?; + pyd.set_item("position_risk_pct", pos_risk_ns * 100.0 / pos_ns)?; + pyd.set_item("position_holds_pct", pos_holds_ns * 100.0 / pos_ns)?; + } else { + pyd.set_item("position_event_match_pct", 0.0)?; + pyd.set_item("position_fsm_pct", 0.0)?; + pyd.set_item("position_risk_pct", 0.0)?; + pyd.set_item("position_holds_pct", 0.0)?; + } + meta.set_item("profile", pyd)?; + } + + let out = PyDict::new(py); + out.set_item("meta", meta)?; + out.set_item("signals_arrow", PyBytes::new(py, &signals_arrow))?; + out.set_item("pairs_arrow", PyBytes::new(py, &pairs_arrow))?; + out.set_item("holds_arrow", PyBytes::new(py, &holds_arrow))?; + + if let Some((sp, pp, hp)) = extra_paths { + out.set_item("signals_path", sp)?; + out.set_item("pairs_path", pp)?; + out.set_item("holds_path", hp)?; + } + + Ok(out.into()) +} + +/// 高性能研究入口,返回内存中的 Arrow bytes 结果。 +/// +/// 与 `run_backtest` 的区别: +/// - 不要求事先准备 config 文件,策略直接用 `strategy_json` 传入 +/// - 默认返回内存里的 `signals/pairs/holds` Arrow bytes,便于 Python 侧继续处理 +/// - 可通过 `opts_json` 控制是否生成信号表等细节 +/// +/// 返回值是一个 `dict`,核心字段包括: +/// - `meta`: 执行元数据与 profile +/// - `signals_arrow` +/// - `pairs_arrow` +/// - `holds_arrow` +#[pyfunction] +#[pyo3(text_signature = "(bars_bytes, strategy_json, sdt=None, opts_json=None)")] +#[pyo3(signature = (bars_bytes, strategy_json, sdt=None, opts_json=None))] +pub fn run_research( + py: Python<'_>, + bars_bytes: &Bound, + strategy_json: &str, + sdt: Option<&str>, + opts_json: Option<&str>, +) -> PyResult> { + let opts = opts_json + .map(|s| { + serde_json::from_str::(s) + .map_err(|e| PyValueError::new_err(format!("opts_json 解析失败: {e}"))) + }) + .transpose()? + .unwrap_or_default(); + let emit_signals = opts.emit_signals.unwrap_or(true); + + let (cfg, bars_count, rows, pairs_df, holds_df, elapsed_ms, profile) = + run_research_core(bars_bytes.as_bytes(), strategy_json, sdt, emit_signals)?; + let signals_df = normalize_signals_dtypes(build_signals_dataframe(&rows)?)?; + + build_result_dict( + py, + &cfg, + bars_count, + &rows, + &signals_df, + &pairs_df, + &holds_df, + elapsed_ms, + profile, + None, + ) +} + +/// 回放入口,在 `run_research` 基础上可选把产物落盘为 parquet。 +/// +/// 适合策略开发和可视化回放: +/// - 若提供 `res_path`,会写出 `signals.parquet / pairs.parquet / holds.parquet` +/// - 若不提供 `res_path`,行为与 `run_research` 接近,仍返回内存结果 +/// +/// 返回值同样是一个 `dict`;当实际落盘时会额外带上三个输出文件路径。 +#[pyfunction] +#[pyo3(text_signature = "(bars_bytes, strategy_json, res_path=None, sdt=None, opts_json=None)")] +#[pyo3(signature = (bars_bytes, strategy_json, res_path=None, sdt=None, opts_json=None))] +pub fn run_replay( + py: Python<'_>, + bars_bytes: &Bound, + strategy_json: &str, + res_path: Option<&str>, + sdt: Option<&str>, + opts_json: Option<&str>, +) -> PyResult> { + let opts = opts_json + .map(|s| { + serde_json::from_str::(s) + .map_err(|e| PyValueError::new_err(format!("opts_json 解析失败: {e}"))) + }) + .transpose()? + .unwrap_or_default(); + let emit_signals = opts.emit_signals.unwrap_or(true); + + let (cfg, bars_count, rows, pairs_df, holds_df, elapsed_ms, profile) = + run_research_core(bars_bytes.as_bytes(), strategy_json, sdt, emit_signals)?; + let signals_df = normalize_signals_dtypes(build_signals_dataframe(&rows)?)?; + + let mut extra_paths: Option<(String, String, String)> = None; + if let Some(base) = res_path { + let base_path = Path::new(base); + if !base_path.exists() { + fs::create_dir_all(base_path) + .map_err(|e| PyValueError::new_err(format!("创建结果目录失败: {e}")))?; + } + + let signals_path = base_path.join("signals.parquet"); + let pairs_path = base_path.join("pairs.parquet"); + let holds_path = base_path.join("holds.parquet"); + + write_df_parquet(&signals_path, signals_df.clone())?; + write_df_parquet(&pairs_path, pairs_df.clone())?; + write_df_parquet(&holds_path, holds_df.clone())?; + + extra_paths = Some(( + signals_path.to_string_lossy().to_string(), + pairs_path.to_string_lossy().to_string(), + holds_path.to_string_lossy().to_string(), + )); + } + + let extra_refs = extra_paths + .as_ref() + .map(|(a, b, c)| (a.as_str(), b.as_str(), c.as_str())); + + build_result_dict( + py, + &cfg, + bars_count, + &rows, + &signals_df, + &pairs_df, + &holds_df, + elapsed_ms, + profile, + extra_refs, + ) +} + +/// 优化批量入口,接受 JSON 字符串形式的优化配置。 +/// +/// 这是 Python facade 常用入口: +/// - Python 侧直接构造 `dict` +/// - 序列化成 JSON 字符串传给这里 +/// - Rust 内部写入临时配置文件,再复用 `run_optimize` +/// +/// 这样可以兼容旧版类式 API,同时保持底层只维护一套优化执行逻辑。 +#[pyfunction] +#[pyo3(text_signature = "(bars_dir, optimize_config_json, res_path, n_threads=1)")] +#[pyo3(signature = (bars_dir, optimize_config_json, res_path, n_threads=1))] +pub fn run_optimize_batch( + bars_dir: &str, + optimize_config_json: &str, + res_path: &str, + n_threads: usize, +) -> PyResult { + let parsed: Value = serde_json::from_str(optimize_config_json) + .map_err(|e| PyValueError::new_err(format!("optimize 配置 JSON 解析失败: {e}")))?; + + let temp_path = std::env::temp_dir().join("rs_czsc_optimize_config.json"); + fs::write(&temp_path, parsed.to_string()) + .map_err(|e| PyValueError::new_err(format!("写入临时优化配置失败: {e}")))?; + + run_optimize( + bars_dir, + temp_path + .to_str() + .ok_or_else(|| PyValueError::new_err("临时配置路径无效"))?, + res_path, + n_threads, + ) +} + +/// 仅构建开仓优化候选策略,不运行回测。 +/// +/// 输入: +/// - `files_position`: 基准仓位 JSON 文件路径列表 +/// - `candidate_signals`: 候选开仓信号列表 +/// +/// 输出: +/// - `Position.dump()` 风格的 JSON 字符串数组 +/// +/// 典型用法是先调用本函数生成候选仓位,再交给 `run_optimize_batch` 跑批。 +#[pyfunction] +#[pyo3(text_signature = "(files_position, candidate_signals)")] +#[pyo3(signature = (files_position, candidate_signals))] +pub fn build_open_optim_positions( + files_position: Vec, + candidate_signals: Vec, +) -> PyResult { + let files: Vec = files_position.into_iter().map(PathBuf::from).collect(); + let positions = get_open_optim_positions(&files, &candidate_signals) + .map_err(|e| PyRuntimeError::new_err(format!("构建开仓优化策略失败: {e}")))?; + let payload: Vec = positions.iter().map(position_to_py_dump).collect(); + serde_json::to_string(&payload) + .map_err(|e| PyRuntimeError::new_err(format!("序列化开仓优化策略失败: {e}"))) +} + +/// 仅构建平仓优化候选策略,不运行回测。 +/// +/// 与 `build_open_optim_positions` 的区别在于输入是 `candidate_events_json`, +/// 即 Python 侧事件定义列表的 JSON 字符串。返回值仍是 +/// `Position.dump()` 风格的 JSON 字符串数组。 +#[pyfunction] +#[pyo3(text_signature = "(files_position, candidate_events_json)")] +#[pyo3(signature = (files_position, candidate_events_json))] +pub fn build_exit_optim_positions( + files_position: Vec, + candidate_events_json: &str, +) -> PyResult { + let files: Vec = files_position.into_iter().map(PathBuf::from).collect(); + let candidate_events: Vec = serde_json::from_str(candidate_events_json) + .map_err(|e| PyValueError::new_err(format!("candidate_events_json 解析失败: {e}")))?; + let positions = get_exit_optim_positions(&files, &candidate_events) + .map_err(|e| PyRuntimeError::new_err(format!("构建平仓优化策略失败: {e}")))?; + let payload: Vec = positions.iter().map(position_to_py_dump).collect(); + serde_json::to_string(&payload) + .map_err(|e| PyRuntimeError::new_err(format!("序列化平仓优化策略失败: {e}"))) +} + +#[cfg(test)] +mod tests { + use super::{StrategyConfig, parse_sdt_utc, validate_strategy}; + + #[test] + fn test_parse_sdt_utc_supports_iso_t_without_tz() { + let dt = parse_sdt_utc("2023-02-28T12:00:00"); + assert!(dt.is_some()); + } + + #[test] + fn test_strategy_config_json_deserialize() { + let s = r#"{ + \"name\": \"demo\", + \"symbol\": \"TEST.SZ\", + \"base_freq\": \"5分钟\", + \"signals_module\": \"czsc.signals\", + \"signals_config\": [{\"name\":\"bar_triple_V230506\",\"freq\":\"5分钟\",\"params\":{\"di\":1}}], + \"positions\": [{ + \"name\": \"p1\", \"symbol\": \"TEST.SZ\", \"opens\": [], \"exits\": [], + \"interval\": 0, \"timeout\": 1, \"stop_loss\": 100.0, \"T0\": false + }], + \"market\": \"默认\", \"bg_max_count\": 5000 + }"#; + let cfg: StrategyConfig = serde_json::from_str(s).expect("deserialize failed"); + assert_eq!(cfg.symbol, "TEST.SZ"); + assert_eq!(cfg.base_freq, "5分钟"); + assert_eq!(cfg.positions.len(), 1); + } + + #[test] + fn test_validate_strategy_rejects_empty_symbol() { + let s = r#"{ + \"symbol\": \"\", + \"base_freq\": \"5分钟\", + \"signals_config\": [], + \"positions\": [{ + \"name\": \"p1\", \"symbol\": \"TEST.SZ\", \"opens\": [], \"exits\": [], + \"interval\": 0, \"timeout\": 1, \"stop_loss\": 100.0, \"T0\": false + }] + }"#; + let cfg: StrategyConfig = serde_json::from_str(s).expect("deserialize failed"); + let r = validate_strategy(&cfg); + assert!(r.is_err()); + } +} diff --git a/crates/czsc-python/src/utils/df_convert.rs b/crates/czsc-python/src/utils/df_convert.rs new file mode 100644 index 000000000..60e9820c7 --- /dev/null +++ b/crates/czsc-python/src/utils/df_convert.rs @@ -0,0 +1,16 @@ +use crate::errors::PythonError; +use polars::prelude::*; +use std::io::Cursor; + +pub fn pyarrow_to_df(data: &[u8]) -> Result { + let cursor = Cursor::new(data); + let df = IpcReader::new(cursor).finish().map_err(PythonError::from)?; + Ok(df) +} + +/// 将DataFrame转换为字节数组 +pub fn df_to_pyarrow(dataframe: &mut DataFrame) -> Result, PythonError> { + let mut buffer = Cursor::new(Vec::new()); + IpcWriter::new(&mut buffer).finish(dataframe)?; + Ok(buffer.into_inner()) +} diff --git a/crates/czsc-python/src/utils/mod.rs b/crates/czsc-python/src/utils/mod.rs new file mode 100644 index 000000000..137579dd5 --- /dev/null +++ b/crates/czsc-python/src/utils/mod.rs @@ -0,0 +1,4 @@ +//! czsc-python utility submodules (Rust helpers used by the trader +//! binding layer — Arrow IPC roundtrip, etc.). + +pub mod df_convert; diff --git a/crates/czsc-signal-macros/Cargo.toml b/crates/czsc-signal-macros/Cargo.toml new file mode 100644 index 000000000..2a507aedd --- /dev/null +++ b/crates/czsc-signal-macros/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "czsc-signal-macros" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "Proc-macro for #[signal_module] registration. Placeholder, to be migrated." + +[lib] +name = "czsc_signal_macros" +path = "src/lib.rs" +proc-macro = true + +[dependencies] +proc-macro2 = "1" +quote = "1" +syn = { version = "2", features = ["full", "parsing"] } diff --git a/crates/czsc-signal-macros/src/lib.rs b/crates/czsc-signal-macros/src/lib.rs new file mode 100644 index 000000000..2451692c9 --- /dev/null +++ b/crates/czsc-signal-macros/src/lib.rs @@ -0,0 +1,470 @@ +use proc_macro::TokenStream; +use quote::{ToTokens, format_ident, quote}; +use syn::parse::Parser; +use syn::punctuated::Punctuated; +use syn::{Expr, ExprLit, FnArg, Item, ItemFn, ItemMod, Lit, Meta, Token, Type}; + +fn type_tokens(t: &Type) -> String { + t.to_token_stream().to_string() +} + +fn nth_arg_type(f: &ItemFn, idx: usize) -> Option { + f.sig.inputs.iter().nth(idx).and_then(|a| match a { + FnArg::Typed(t) => Some((*t.ty).clone()), + _ => None, + }) +} + +#[proc_macro_attribute] +pub fn signal(attr: TokenStream, item: TokenStream) -> TokenStream { + let parser = Punctuated::::parse_terminated; + let metas = match parser.parse(attr) { + Ok(m) => m, + Err(e) => return e.to_compile_error().into(), + }; + + let mut category: Option = None; + let mut name: Option = None; + let mut template: Option = None; + let mut opcode: Option = None; + let mut param_kind: Option = None; + let mut fast_exec: Option = None; + let mut fast_decode: Option = None; + + for m in metas { + if let Meta::NameValue(nv) = m + && let Some(ident) = nv.path.get_ident() + && let Expr::Lit(ExprLit { + lit: Lit::Str(v), .. + }) = nv.value + { + match ident.to_string().as_str() { + "category" => category = Some(v.value()), + "name" => name = Some(v.value()), + "template" => template = Some(v.value()), + "opcode" => opcode = Some(v.value()), + "param_kind" => param_kind = Some(v.value()), + "fast_exec" => fast_exec = Some(v.value()), + "fast_decode" => fast_decode = Some(v.value()), + _ => {} + } + } + } + + let f: ItemFn = match syn::parse(item) { + Ok(v) => v, + Err(e) => return e.to_compile_error().into(), + }; + + let mut errors = Vec::new(); + let category = category.unwrap_or_default(); + let name = name.unwrap_or_default(); + let template = template.unwrap_or_default(); + let opcode = opcode.unwrap_or_default(); + let param_kind = param_kind.unwrap_or_default(); + let fast_exec = fast_exec.unwrap_or_default(); + let fast_decode = fast_decode.unwrap_or_default(); + + if category != "kline" && category != "trader" { + errors.push(quote! { compile_error!("#[signal] category 必须是 kline 或 trader"); }); + } + if name.is_empty() || template.is_empty() || opcode.is_empty() || param_kind.is_empty() { + errors + .push(quote! { compile_error!("#[signal] name/template/opcode/param_kind 不能为空"); }); + } + + let fn_ident = f.sig.ident.to_string(); + if !fn_ident.contains("_v") { + errors.push(quote! { compile_error!("#[signal] 函数名必须包含 _v<版本号>"); }); + } else { + let expected = if let Some((head, tail)) = fn_ident.rsplit_once("_v") { + if !tail.is_empty() && tail.chars().all(|c| c.is_ascii_digit()) { + format!("{head}_V{tail}") + } else { + String::new() + } + } else { + String::new() + }; + if expected.is_empty() || (!name.is_empty() && expected != name) { + errors.push(quote! { compile_error!("#[signal] name 必须与函数名版本后缀一致,例如 foo_v230101 <-> foo_V230101"); }); + } + if !template.is_empty() { + let ver = expected + .rsplit_once("_V") + .map(|(_, v)| v.to_string()) + .unwrap_or_default(); + if !ver.is_empty() && !template.contains(ver.as_str()) { + errors.push( + quote! { compile_error!("#[signal] template 必须包含版本数字(如 230101)"); }, + ); + } + } + } + + let argc = f.sig.inputs.len(); + match category.as_str() { + "kline" if argc != 3 => { + errors.push(quote! { compile_error!("kline signal 函数必须有 3 个参数"); }) + } + "trader" if argc != 2 => { + errors.push(quote! { compile_error!("trader signal 函数必须有 2 个参数"); }) + } + _ => {} + } + if category == "kline" && argc == 3 { + let t0 = nth_arg_type(&f, 0) + .map(|t| type_tokens(&t)) + .unwrap_or_default(); + let t1 = nth_arg_type(&f, 1) + .map(|t| type_tokens(&t)) + .unwrap_or_default(); + let t2 = nth_arg_type(&f, 2) + .map(|t| type_tokens(&t)) + .unwrap_or_default(); + if !t0.contains("CZSC") { + errors.push(quote! { compile_error!("kline signal 第1个参数必须是 &CZSC"); }); + } + if !t1.contains('&') { + errors.push(quote! { compile_error!("kline signal 第2个参数必须为引用类型(如 &ParamView / &HashMap / &TypedParams)"); }); + } + if !(t2.contains("TaCache") && t2.contains("mut")) { + errors.push(quote! { compile_error!("kline signal 第3个参数必须是 &mut TaCache"); }); + } + } + if category == "trader" && argc == 2 { + let t0 = nth_arg_type(&f, 0) + .map(|t| type_tokens(&t)) + .unwrap_or_default(); + let t1 = nth_arg_type(&f, 1) + .map(|t| type_tokens(&t)) + .unwrap_or_default(); + if !t0.contains("TraderState") { + errors + .push(quote! { compile_error!("trader signal 第1个参数必须是 &dyn TraderState"); }); + } + if !t1.contains('&') { + errors.push(quote! { compile_error!("trader signal 第2个参数必须为引用类型(如 &ParamView / &HashMap / &TypedParams)"); }); + } + } + + let vis = &f.vis; + let sig = &f.sig; + let block = &f.block; + let meta_const_ident = syn::Ident::new( + &format!("__RS_CZSC_SIGNAL_META_{}", f.sig.ident).to_uppercase(), + f.sig.ident.span(), + ); + + let fn_name = &f.sig.ident; + let mut generated_wrappers = quote! {}; + let (func_ref_expr, auto_fast_expr) = if category == "kline" { + let raw_param_ty = f.sig.inputs.iter().nth(1).and_then(|a| match a { + FnArg::Typed(t) => Some((*t.ty).clone()), + _ => None, + }); + let param_ty = raw_param_ty.as_ref().map(|t| match t { + Type::Reference(r) => (*r.elem).clone(), + _ => t.clone(), + }); + let raw_param_ty_tokens = raw_param_ty + .as_ref() + .map(|t| t.to_token_stream().to_string()) + .unwrap_or_default(); + let is_hashmap_params = raw_param_ty_tokens.contains("HashMap"); + let is_param_view = raw_param_ty_tokens.contains("ParamView"); + if is_hashmap_params { + ( + quote! { czsc_signals::types::SignalFnRef::Kline(#fn_name as czsc_signals::types::SignalFn) }, + quote! { None }, + ) + } else if is_param_view { + let dyn_wrap_ident = format_ident!("__rs_dyn_wrap_{}", fn_name); + generated_wrappers = quote! { + #[doc(hidden)] + fn #dyn_wrap_ident( + czsc: &czsc_core::analyze::CZSC, + params: &std::collections::HashMap, + cache: &mut czsc_signals::types::TaCache, + ) -> Vec { + let p = czsc_signals::params::ParamView::new(params); + #fn_name(czsc, &p, cache) + } + }; + ( + quote! { czsc_signals::types::SignalFnRef::Kline(#dyn_wrap_ident as czsc_signals::types::SignalFn) }, + quote! { None }, + ) + } else { + let pty = param_ty.expect("checked"); + let dyn_wrap_ident = format_ident!("__rs_dyn_wrap_{}", fn_name); + let fast_decode_ident = format_ident!("__rs_fast_decode_{}", fn_name); + let fast_exec_ident = format_ident!("__rs_fast_exec_{}", fn_name); + generated_wrappers = quote! { + #[doc(hidden)] + fn #dyn_wrap_ident( + czsc: &czsc_core::analyze::CZSC, + params: &std::collections::HashMap, + cache: &mut czsc_signals::types::TaCache, + ) -> Vec { + let val = match serde_json::to_value(params) { + Ok(v) => v, + Err(_) => return Vec::new(), + }; + let p: #pty = match serde_json::from_value(val) { + Ok(v) => v, + Err(_) => return Vec::new(), + }; + #fn_name(czsc, &p, cache) + } + + #[doc(hidden)] + fn #fast_decode_ident( + params: &std::collections::HashMap, + ) -> Option { + let val = serde_json::to_value(params).ok()?; + let p: #pty = serde_json::from_value(val).ok()?; + serde_json::to_value(p).ok() + } + + #[doc(hidden)] + fn #fast_exec_ident( + czsc: &czsc_core::analyze::CZSC, + p: &serde_json::Value, + cache: &mut czsc_signals::types::TaCache, + ) -> Vec { + let pp: #pty = match serde_json::from_value(p.clone()) { + Ok(v) => v, + Err(_) => return Vec::new(), + }; + #fn_name(czsc, &pp, cache) + } + }; + ( + quote! { czsc_signals::types::SignalFnRef::Kline(#dyn_wrap_ident as czsc_signals::types::SignalFn) }, + quote! { + Some(czsc_signals::types::FastKlineMeta { + decode: #fast_decode_ident as czsc_signals::types::FastKlineDecodeFn, + exec: #fast_exec_ident as czsc_signals::types::FastKlineExecFn, + }) + }, + ) + } + } else { + let raw_param_ty = f.sig.inputs.iter().nth(1).and_then(|a| match a { + FnArg::Typed(t) => Some((*t.ty).clone()), + _ => None, + }); + let raw_param_ty_tokens = raw_param_ty + .as_ref() + .map(|t| t.to_token_stream().to_string()) + .unwrap_or_default(); + let is_hashmap_params = raw_param_ty_tokens.contains("HashMap"); + let is_param_view = raw_param_ty_tokens.contains("ParamView"); + if is_hashmap_params { + ( + quote! { czsc_signals::types::SignalFnRef::Trader(#fn_name as czsc_signals::types::TraderSignalFn) }, + quote! { None }, + ) + } else if is_param_view { + let dyn_wrap_ident = format_ident!("__rs_dyn_wrap_{}", fn_name); + generated_wrappers = quote! { + #[doc(hidden)] + fn #dyn_wrap_ident( + cat: &dyn czsc_core::objects::state::TraderState, + params: &std::collections::HashMap, + ) -> Vec { + let p = czsc_signals::params::ParamView::new(params); + #fn_name(cat, &p) + } + }; + ( + quote! { czsc_signals::types::SignalFnRef::Trader(#dyn_wrap_ident as czsc_signals::types::TraderSignalFn) }, + quote! { None }, + ) + } else { + let pty = raw_param_ty.expect("checked"); + let pty = match pty { + Type::Reference(r) => *r.elem, + t => t, + }; + let dyn_wrap_ident = format_ident!("__rs_dyn_wrap_{}", fn_name); + generated_wrappers = quote! { + #[doc(hidden)] + fn #dyn_wrap_ident( + cat: &dyn czsc_core::objects::state::TraderState, + params: &std::collections::HashMap, + ) -> Vec { + let val = match serde_json::to_value(params) { + Ok(v) => v, + Err(_) => return Vec::new(), + }; + let p: #pty = match serde_json::from_value(val) { + Ok(v) => v, + Err(_) => return Vec::new(), + }; + #fn_name(cat, &p) + } + }; + ( + quote! { czsc_signals::types::SignalFnRef::Trader(#dyn_wrap_ident as czsc_signals::types::TraderSignalFn) }, + quote! { None }, + ) + } + }; + let fast_kline_expr = if category == "kline" && !fast_exec.is_empty() && !fast_decode.is_empty() + { + let fast_exec_path: syn::Path = match syn::parse_str(&fast_exec) { + Ok(p) => p, + Err(e) => return e.to_compile_error().into(), + }; + let fast_decode_path: syn::Path = match syn::parse_str(&fast_decode) { + Ok(p) => p, + Err(e) => return e.to_compile_error().into(), + }; + quote! { + Some(czsc_signals::types::FastKlineMeta { + decode: #fast_decode_path as czsc_signals::types::FastKlineDecodeFn, + exec: #fast_exec_path as czsc_signals::types::FastKlineExecFn, + }) + } + } else { + auto_fast_expr + }; + + let out = quote! { + #(#errors)* + #vis #sig #block + #generated_wrappers + + #[doc(hidden)] + #[allow(non_upper_case_globals, dead_code)] + pub const #meta_const_ident: czsc_signals::types::SignalDescriptor = czsc_signals::types::SignalDescriptor { + category: #category, + name: #name, + template: #template, + opcode: #opcode, + param_kind: #param_kind, + func_ref: #func_ref_expr, + fast_kline: #fast_kline_expr, + }; + + inventory::submit! { + #meta_const_ident + } + }; + out.into() +} + +#[proc_macro_attribute] +pub fn signal_module(_attr: TokenStream, item: TokenStream) -> TokenStream { + let parser = Punctuated::::parse_terminated; + let metas = match parser.parse(_attr) { + Ok(m) => m, + Err(e) => return e.to_compile_error().into(), + }; + let mut module_category = String::new(); + for m in metas { + if let Meta::NameValue(nv) = m + && let Some(ident) = nv.path.get_ident() + && ident == "category" + && let Expr::Lit(ExprLit { + lit: Lit::Str(v), .. + }) = nv.value + { + module_category = v.value(); + } + } + if module_category != "kline" && module_category != "trader" { + return quote! { compile_error!("#[signal_module] category 必须是 kline 或 trader"); } + .into(); + } + + let m: ItemMod = match syn::parse(item.clone()) { + Ok(v) => v, + Err(_) => return item, + }; + + if m.content.is_none() { + return quote! { + compile_error!("#[signal_module] 仅支持内联模块,用于编译期收集与校验"); + #m + } + .into(); + } + + let (_, items) = m.content.as_ref().expect("checked is_some"); + let mut seen_names = std::collections::HashSet::new(); + let mut seen_opcodes = std::collections::HashSet::new(); + for it in items { + if let Item::Fn(f) = it { + let name = f.sig.ident.to_string(); + if f.vis.to_token_stream().to_string() == "pub" && name.contains("_v") { + let signal_attr = f.attrs.iter().find(|a| a.path().is_ident("signal")); + if signal_attr.is_none() { + return quote! { + compile_error!("signal_module: pub fn *_v* 必须显式添加 #[signal(...)] 标注"); + #m + } + .into(); + } + let argc = f.sig.inputs.len(); + if module_category == "kline" { + if argc != 3 { + return quote! { compile_error!("signal_module: kline 模块中的 pub fn *_v* 必须是 3 参数签名"); #m }.into(); + } + let t0 = nth_arg_type(f, 0) + .map(|t| type_tokens(&t)) + .unwrap_or_default(); + let t2 = nth_arg_type(f, 2) + .map(|t| type_tokens(&t)) + .unwrap_or_default(); + if !(t0.contains("CZSC") && t2.contains("TaCache") && t2.contains("mut")) { + return quote! { compile_error!("signal_module: kline 模块函数签名必须为 (&CZSC, &Params, &mut TaCache)"); #m }.into(); + } + } else if module_category == "trader" { + if argc != 2 { + return quote! { compile_error!("signal_module: trader 模块中的 pub fn *_v* 必须是 2 参数签名"); #m }.into(); + } + let t0 = nth_arg_type(f, 0) + .map(|t| type_tokens(&t)) + .unwrap_or_default(); + if !t0.contains("TraderState") { + return quote! { compile_error!("signal_module: trader 模块函数签名必须为 (&dyn TraderState, &Params)"); #m }.into(); + } + } + if let Some(attr) = signal_attr { + let parser = Punctuated::::parse_terminated; + let mut attr_name = String::new(); + let mut attr_opcode = String::new(); + if let syn::Meta::List(list) = &attr.meta + && let Ok(ms) = parser.parse2(list.tokens.clone()) + { + for m in ms { + if let Meta::NameValue(nv) = m + && let Some(ident) = nv.path.get_ident() + && let Expr::Lit(ExprLit { + lit: Lit::Str(v), .. + }) = nv.value + { + if ident == "name" { + attr_name = v.value(); + } else if ident == "opcode" { + attr_opcode = v.value(); + } + } + } + } + if !attr_name.is_empty() && !seen_names.insert(attr_name.clone()) { + return quote! { compile_error!("signal_module: duplicate signal name in module"); #m }.into(); + } + if !attr_opcode.is_empty() && !seen_opcodes.insert(attr_opcode.clone()) { + return quote! { compile_error!("signal_module: duplicate signal opcode in module"); #m }.into(); + } + } + } + } + } + + quote! { #m }.into() +} diff --git a/crates/czsc-signal-macros/tests/test_export.rs b/crates/czsc-signal-macros/tests/test_export.rs new file mode 100644 index 000000000..88cb1514e --- /dev/null +++ b/crates/czsc-signal-macros/tests/test_export.rs @@ -0,0 +1,15 @@ +//! Phase E.last — smoke test: czsc-signal-macros compiles and the test +//! binary can link against it. Proc-macros are compile-time constructs, +//! so the real validation is that this test target builds at all. +//! +//! Full expansion testing requires czsc-signals types and lands in +//! Phase F (every signal module under crates/czsc-signals/src/*.rs +//! exercises `#[signal_module]` and `#[signal]` against the real +//! types). + +#[test] +fn proc_macro_crate_links() { + // Reaching this assertion means the crate compiled with both + // #[proc_macro_attribute] entrypoints exported. + assert!(true); +} diff --git a/crates/czsc-signals/Cargo.toml b/crates/czsc-signals/Cargo.toml new file mode 100644 index 000000000..9ede996e9 --- /dev/null +++ b/crates/czsc-signals/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "czsc-signals" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "CZSC signal functions (bar / cxt / tas / vol / pressure / obv / cvolp / ...). Migrated from rs-czsc." + +[lib] +name = "czsc_signals" +path = "src/lib.rs" + +[dependencies] +czsc-core = { path = "../czsc-core", features = ["python"] } +czsc-ta = { path = "../czsc-ta" } +czsc-signal-macros = { path = "../czsc-signal-macros" } +serde = { workspace = true } +serde_json = "1.0" +anyhow = "1.0" +tracing = "0.1" +inventory = "0.3" +chrono = { version = "0.4", default-features = false, features = ["clock"] } diff --git a/crates/czsc-signals/src/ang.rs b/crates/czsc-signals/src/ang.rs new file mode 100644 index 000000000..8d13c18da --- /dev/null +++ b/crates/czsc-signals/src/ang.rs @@ -0,0 +1,765 @@ +use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::sig::{get_sub_elements, make_kline_signal_v1, make_kline_signal_v2, pd_cut_last_label}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::signal::Signal; +use czsc_signal_macros::signal; +use std::collections::HashMap; + +fn mean_or_nan(values: &[f64]) -> f64 { + if values.is_empty() { + f64::NAN + } else { + values.iter().sum::() / values.len() as f64 + } +} + +fn sma_valid(values: &[f64], n: usize) -> Vec { + if n == 0 || values.len() < n { + return vec![]; + } + let mut out = Vec::with_capacity(values.len() - n + 1); + let mut acc: f64 = values[..n].iter().sum(); + out.push(acc / n as f64); + for i in n..values.len() { + acc += values[i] - values[i - n]; + out.push(acc / n as f64); + } + out +} + +/// adtm_up_dw_line_V230603:ADTM 能量异动多空信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}M{m}TH{th}_ADTMV230603"` +/// +/// 信号逻辑: +/// 1. 计算 `N` 窗口 `up_sum` 与 `M` 窗口 `dw_sum`; +/// 2. 计算 `adtm = (up_sum - dw_sum) / max(up_sum, dw_sum)`; +/// 3. `up_sum > dw_sum` 或 `adtm > th/10` 判 `看多`; +/// 4. `up_sum < dw_sum` 或 `adtm < th/10` 判 `看空`,否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N30M20TH5_ADTMV230603_看多_任意_任意_0')` +/// - `Signal('60分钟_D1N30M20TH5_ADTMV230603_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `n`:`up_sum` 窗口,默认 `30`; +/// - `m`:`dw_sum` 窗口,默认 `20`; +/// - `th`:阈值(除以 10 使用),默认 `5`。 +/// 对齐说明:与 Python `adtm_up_dw_line_V230603` 的条件优先级与阈值口径一致。 +#[signal( + category = "kline", + name = "adtm_up_dw_line_V230603", + template = "{freq}_D{di}N{n}M{m}TH{th}_ADTMV230603", + opcode = "AdtmUpDwLineV230603", + param_kind = "AdtmUpDwLineV230603" +)] +pub fn adtm_up_dw_line_v230603(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 30); + let m = params.usize("m", 20); + let th = params.usize("th", 5); + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}M{}TH{}", di, n, m, th); + let k3 = "ADTMV230603"; + + if c.bars_raw.len() < di + n.max(m) + 10 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let n_bars = get_sub_elements(&c.bars_raw, di, n); + let m_bars = get_sub_elements(&c.bars_raw, di, m); + if n_bars.len() < 2 || m_bars.len() < 2 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let mut up_sum = 0.0; + for i in 1..n_bars.len() { + if n_bars[i].open > n_bars[i - 1].open { + up_sum += (n_bars[i].high - n_bars[i].open).max(n_bars[i].open - n_bars[i - 1].open); + } + } + + let mut dw_sum = 0.0; + for i in 1..m_bars.len() { + if m_bars[i].open < m_bars[i - 1].open { + dw_sum += (m_bars[i].open - m_bars[i].low).max(m_bars[i - 1].open - m_bars[i].open); + } + } + + let denom = up_sum.max(dw_sum); + let adtm = if denom > 0.0 { + (up_sum - dw_sum) / denom + } else { + f64::NAN + }; + + let mut v1 = "其他"; + if up_sum > dw_sum || adtm > th as f64 / 10.0 { + v1 = "看多"; + } + if up_sum < dw_sum || adtm < th as f64 / 10.0 { + v1 = "看空"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// amv_up_dw_line_V230603:AMV 能量多空信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}M{m}_AMV能量V230603"` +/// +/// 信号逻辑: +/// 1. 计算 `N` 与 `M` 窗口成交额加权均价; +/// 2. 形成 `amv1` 与 `amv2`; +/// 3. `amv1 > amv2` 判 `看多`,否则 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N30M120_AMV能量V230603_看多_任意_任意_0')` +/// - `Signal('60分钟_D1N30M120_AMV能量V230603_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `n`:短窗口,默认 `30`; +/// - `m`:长窗口,默认 `120`。 +/// 对齐说明:与 Python `amv_up_dw_line_V230603` 的加权均价公式一致。 +#[signal( + category = "kline", + name = "amv_up_dw_line_V230603", + template = "{freq}_D{di}N{n}M{m}_AMV能量V230603", + opcode = "AmvUpDwLineV230603", + param_kind = "AmvUpDwLineV230603" +)] +pub fn amv_up_dw_line_v230603(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 30); + let m = params.usize("m", 120); + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}M{}", di, n, m); + let k3 = "AMV能量V230603"; + if n > m || c.bars_raw.len() < di + m + 10 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let n_bars = get_sub_elements(&c.bars_raw, di, n); + let m_bars = get_sub_elements(&c.bars_raw, di, m); + if n_bars.is_empty() || m_bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let amov1: f64 = n_bars + .iter() + .map(|b| b.amount * (b.open + b.close) / 2.0) + .sum(); + let amov2: f64 = m_bars + .iter() + .map(|b| b.amount * (b.open + b.close) / 2.0) + .sum(); + let vol_sum1: f64 = n_bars.iter().map(|b| b.amount).sum(); + let vol_sum2: f64 = m_bars.iter().map(|b| b.amount).sum(); + let amv1 = amov1 / vol_sum1; + let amv2 = amov2 / vol_sum2; + let v1 = if amv1 > amv2 { "看多" } else { "看空" }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// asi_up_dw_line_V230603:ASI 多空信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}P{p}_ASI多空V230603"` +/// +/// 信号逻辑: +/// 1. 基于最近 `p` 根K线计算 SI 序列并累加得 ASI; +/// 2. 将最新 ASI 与 `p` 窗口 ASI 均值比较; +/// 3. `asi_last > asi_mean` 判 `看多`,否则 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N30P120_ASI多空V230603_看多_任意_任意_0')` +/// - `Signal('60分钟_D1N30P120_ASI多空V230603_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `n`:SI 公式中的常数项,默认 `30`; +/// - `p`:窗口长度,默认 `120`。 +/// 对齐说明:按 Python `asi_up_dw_line_V230603` 的原始向量公式逐项对齐实现。 +#[signal( + category = "kline", + name = "asi_up_dw_line_V230603", + template = "{freq}_D{di}N{n}P{p}_ASI多空V230603", + opcode = "AsiUpDwLineV230603", + param_kind = "AsiUpDwLineV230603" +)] +pub fn asi_up_dw_line_v230603(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 30); + let p = params.usize("p", 120); + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}P{}", di, n, p); + let k3 = "ASI多空V230603"; + if c.bars_raw.len() < di + p + 10 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let bars = get_sub_elements(&c.bars_raw, di, p); + if bars.len() < 2 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let len = bars.len(); + let mut close = Vec::with_capacity(len); + let mut open = Vec::with_capacity(len); + let mut high = Vec::with_capacity(len); + let mut low = Vec::with_capacity(len); + for b in bars { + close.push(b.close); + open.push(b.open); + high.push(b.high); + low.push(b.low); + } + + let mut prev_close = Vec::with_capacity(len); + let mut prev_low = Vec::with_capacity(len); + let mut prev_open = Vec::with_capacity(len); + prev_close.push(close[0]); + prev_low.push(low[0]); + prev_open.push(open[0]); + for i in 1..len { + prev_close.push(close[i - 1]); + prev_low.push(low[i - 1]); + prev_open.push(open[i - 1]); + } + + let mut si = Vec::with_capacity(len); + for i in 0..len { + let a = (high[i] - prev_close[i]).abs(); + let b = (low[i] - prev_close[i]).abs(); + let c1 = (high[i] - prev_low[i]).abs(); + let d = (prev_close[i] - prev_open[i]).abs(); + let k = a.max(b); + let m = (high[i] - low[i]).max(n as f64); + let r1 = a + 0.5 * b + 0.25 * d; + let r2 = b + 0.5 * a + 0.25 * d; + let r3 = c1 + 0.25 * d; + let r4 = if a >= b && a >= c1 { r1 } else { r2 }; + let r = if c1 >= a && c1 >= b { r3 } else { r4 }; + let den = r * k / m; + if den == 0.0 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let si_i = 50.0 * (close[i] - c1 + (c1 - open[i]) + 0.5 * (close[i] - open[i])) / den; + si.push(si_i); + } + + let mut acc = 0.0; + let mut asi = Vec::with_capacity(si.len()); + for x in si { + acc += x; + asi.push(acc); + } + let asi_last = *asi.last().unwrap_or(&f64::NAN); + let asi_mean = mean_or_nan(&asi); + let v1 = if asi_last > asi_mean { "看多" } else { "看空" }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// cmo_up_dw_line_V230605:CMO 能量阈值信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}M{m}_CMO能量V230605"` +/// +/// 信号逻辑: +/// 1. 统计窗口内上涨/下跌收盘差值总和; +/// 2. 计算 `cmo = (up-dw)/(up+dw)*100`; +/// 3. `cmo > m` 判 `看多`;`cmo < -m` 判 `看空`;否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N70M30_CMO能量V230605_看多_任意_任意_0')` +/// - `Signal('60分钟_D1N70M30_CMO能量V230605_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `n`:统计窗口,默认 `70`; +/// - `m`:阈值,默认 `30`。 +/// 对齐说明:与 Python `cmo_up_dw_line_V230605` 保持同一阈值与分支顺序。 +#[signal( + category = "kline", + name = "cmo_up_dw_line_V230605", + template = "{freq}_D{di}N{n}M{m}_CMO能量V230605", + opcode = "CmoUpDwLineV230605", + param_kind = "CmoUpDwLineV230605" +)] +pub fn cmo_up_dw_line_v230605(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 70); + let m = params.usize("m", 30); + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}M{}", di, n, m); + let k3 = "CMO能量V230605"; + if c.bars_raw.len() < di + n + 10 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, di, n); + if bars.len() < 2 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let mut up_sum = 0.0; + let mut dw_sum = 0.0; + for i in 1..bars.len() { + let d = bars[i].close - bars[i - 1].close; + if d > 0.0 { + up_sum += d; + } else if d < 0.0 { + dw_sum += -d; + } + } + let cmo = (up_sum - dw_sum) / (up_sum + dw_sum) * 100.0; + let mut v1 = "其他"; + if cmo > m as f64 { + v1 = "看多"; + } + if cmo < -(m as f64) { + v1 = "看空"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// skdj_up_dw_line_V230611:SKDJ 随机波动信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}M{m}UP{up}DW{dw}_SKDJ随机波动V230611"` +/// +/// 信号逻辑: +/// 1. 先计算 `RSV(n)` 序列; +/// 2. 对 RSV 做两次 `m` 周期均值平滑; +/// 3. `dw < D < K_last` 判 `看多`;`K_last < D 且 D > up` 判 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N233M89UP60DW40_SKDJ随机波动V230611_看多_任意_任意_0')` +/// - `Signal('60分钟_D1N233M89UP60DW40_SKDJ随机波动V230611_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `n`:RSV 窗口,默认 `233`; +/// - `m`:平滑窗口,默认 `89`; +/// - `up`:超买阈值,默认 `60`; +/// - `dw`:超卖阈值,默认 `40`。 +/// 对齐说明:与 Python `skdj_up_dw_line_V230611` 的双平滑与阈值判定一致。 +#[signal( + category = "kline", + name = "skdj_up_dw_line_V230611", + template = "{freq}_D{di}N{n}M{m}UP{up}DW{dw}_SKDJ随机波动V230611", + opcode = "SkdjUpDwLineV230611", + param_kind = "SkdjUpDwLineV230611" +)] +pub fn skdj_up_dw_line_v230611(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 233); + let m = params.usize("m", 89); + let up = params.usize("up", 60); + let dw = params.usize("dw", 40); + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}M{}UP{}DW{}", di, n, m, up, dw); + let k3 = "SKDJ随机波动V230611"; + + if c.bars_raw.len() < di + m * 3 + 20 || n < m { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let cache_key = format!("RSV{}", n); + let mut old_map: HashMap = HashMap::new(); + if let (Some(ids), Some(vals)) = (cache.series_ids.get(&cache_key), cache.series.get(&cache_key)) + { + for (id, v) in ids.iter().zip(vals.iter()) { + old_map.insert(*id, *v); + } + } + + let mut rsv_series = Vec::with_capacity(c.bars_raw.len()); + let mut rsv_ids = Vec::with_capacity(c.bars_raw.len()); + for (i, bar) in c.bars_raw.iter().enumerate() { + rsv_ids.push(bar.id); + // 对齐 Python:历史 bar 的 RSV 只计算一次;同 dt 延伸时仅最后一根会重算。 + if i + 1 < c.bars_raw.len() { + if let Some(v) = old_map.get(&bar.id) { + rsv_series.push(*v); + continue; + } + } + let win = if i < n { + &c.bars_raw[..=i] + } else { + // 对齐 Python 原始实现:i>=n 分支直接使用 di=i 取子序列(保留其历史行为) + get_sub_elements(&c.bars_raw, i, n) + }; + let v = if win.is_empty() { + f64::NAN + } else { + let min_low = win.iter().fold(f64::INFINITY, |acc, b| acc.min(b.low)); + let max_high = win.iter().fold(f64::NEG_INFINITY, |acc, b| acc.max(b.high)); + let den = max_high - min_low; + if den == 0.0 { + f64::NAN + } else { + (bar.close - min_low) / den * 100.0 + } + }; + rsv_series.push(v); + } + cache.series.insert(cache_key.clone(), rsv_series.clone()); + cache.series_ids.insert(cache_key, rsv_ids); + + let bars = get_sub_elements(&c.bars_raw, di, m * 3 + 20); + if bars.len() < m * 2 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let start = c.bars_raw.len() - di + 1 - bars.len(); + let end = start + bars.len(); + let rsv = &rsv_series[start..end]; + let ma_rsv = sma_valid(rsv, m); + let k = sma_valid(&ma_rsv, m); + if k.len() < m { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let d = mean_or_nan(&k[k.len() - m..]); + let k_last = *k.last().unwrap_or(&f64::NAN); + + let mut v1 = "其他"; + if (dw as f64) < d && d < k_last { + v1 = "看多"; + } + if k_last < d && d > up as f64 { + v1 = "看空"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bias_up_dw_line_V230618:BIAS 三周期共振信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}M{m}P{p}TH1{th1}TH2{th2}TH3{th3}_BIAS乖离率V230618"` +/// +/// 信号逻辑: +/// 1. 分别计算 `n/m/p` 三个窗口的均线乖离率; +/// 2. 三个乖离率同时超过正阈值判 `看多`; +/// 3. 三个乖离率同时低于负阈值判 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N6M12P24TH11TH23TH35_BIAS乖离率V230618_看多_任意_任意_0')` +/// - `Signal('60分钟_D1N6M12P24TH11TH23TH35_BIAS乖离率V230618_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `n/m/p`:三组均线窗口,默认 `6/12/24`; +/// - `th1/th2/th3`:对应窗口阈值,默认 `1/3/5`。 +/// 对齐说明:与 Python `bias_up_dw_line_V230618` 的三阈值共振条件一致。 +#[signal( + category = "kline", + name = "bias_up_dw_line_V230618", + template = "{freq}_D{di}N{n}M{m}P{p}TH1{th1}TH2{th2}TH3{th3}_BIAS乖离率V230618", + opcode = "BiasUpDwLineV230618", + param_kind = "BiasUpDwLineV230618" +)] +pub fn bias_up_dw_line_v230618(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 6); + let m = params.usize("m", 12); + let p = params.usize("p", 24); + let th1 = params.usize("th1", 1); + let th2 = params.usize("th2", 3); + let th3 = params.usize("th3", 5); + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}M{}P{}TH1{}TH2{}TH3{}", di, n, m, p, th1, th2, th3); + let k3 = "BIAS乖离率V230618"; + if c.bars_raw.len() < di + n.max(m).max(p) { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let b1 = get_sub_elements(&c.bars_raw, di, n); + let b2 = get_sub_elements(&c.bars_raw, di, m); + let b3 = get_sub_elements(&c.bars_raw, di, p); + if b1.is_empty() || b2.is_empty() || b3.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let ma1 = mean_or_nan(&b1.iter().map(|x| x.close).collect::>()); + let ma2 = mean_or_nan(&b2.iter().map(|x| x.close).collect::>()); + let ma3 = mean_or_nan(&b3.iter().map(|x| x.close).collect::>()); + let bias1 = (b1[b1.len() - 1].close - ma1) / ma1 * 100.0; + let bias2 = (b2[b2.len() - 1].close - ma2) / ma2 * 100.0; + let bias3 = (b3[b3.len() - 1].close - ma3) / ma3 * 100.0; + + let mut v1 = "其他"; + if bias1 > th1 as f64 && bias2 > th2 as f64 && bias3 > th3 as f64 { + v1 = "看多"; + } + if bias1 < -(th1 as f64) && bias2 < -(th2 as f64) && bias3 < -(th3 as f64) { + v1 = "看空"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// dema_up_dw_line_V230605:DEMA 短线趋势信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}_DEMA短线趋势V230605"` +/// +/// 信号逻辑: +/// 1. 用 `n` 与 `2n` 窗口均值构造 `dema = 2*MA(n)-MA(2n)`; +/// 2. 最新收盘价高于 dema 判 `看多`,否则判 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N5_DEMA短线趋势V230605_看多_任意_任意_0')` +/// - `Signal('60分钟_D1N5_DEMA短线趋势V230605_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `n`:短窗口,默认 `5`。 +/// 对齐说明:按 Python `dema_up_dw_line_V230605` 的近似 DEMA 口径实现。 +#[signal( + category = "kline", + name = "dema_up_dw_line_V230605", + template = "{freq}_D{di}N{n}_DEMA短线趋势V230605", + opcode = "DemaUpDwLineV230605", + param_kind = "DemaUpDwLineV230605" +)] +pub fn dema_up_dw_line_v230605(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 5); + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}", di, n); + let k3 = "DEMA短线趋势V230605"; + if c.bars_raw.len() < di + 2 * n + 10 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let short_bars = get_sub_elements(&c.bars_raw, di, n); + let long_bars = get_sub_elements(&c.bars_raw, di, n * 2); + if short_bars.is_empty() || long_bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let dema = 2.0 * mean_or_nan(&short_bars.iter().map(|x| x.close).collect::>()) + - mean_or_nan(&long_bars.iter().map(|x| x.close).collect::>()); + let v1 = if short_bars[short_bars.len() - 1].close > dema { + "看多" + } else { + "看空" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// demakder_up_dw_line_V230605:DEMAKER 价格趋势信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}TH{th}TL{tl}_DEMAKER价格趋势V230605"` +/// +/// 信号逻辑: +/// 1. 统计窗口内上涨高点均值 `demax` 与下跌低点均值 `demin`; +/// 2. 计算 `demaker = demax / (demax + demin)`; +/// 3. `demaker > th/10` 判 `看多`,`demaker < tl/10` 判 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N105TH5TL5_DEMAKER价格趋势V230605_看多_任意_任意_0')` +/// - `Signal('60分钟_D1N105TH5TL5_DEMAKER价格趋势V230605_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `n`:统计窗口,默认 `105`; +/// - `th/tl`:上下阈值(除以 10 使用),默认 `5/5`。 +/// 对齐说明:保持 Python `demakder_up_dw_line_V230605` 对空样本返回 NaN 的行为。 +#[signal( + category = "kline", + name = "demakder_up_dw_line_V230605", + template = "{freq}_D{di}N{n}TH{th}TL{tl}_DEMAKER价格趋势V230605", + opcode = "DemakderUpDwLineV230605", + param_kind = "DemakderUpDwLineV230605" +)] +pub fn demakder_up_dw_line_v230605( + c: &CZSC, + params: &ParamView, + _cache: &mut TaCache, +) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 105); + let th = params.usize("th", 5); + let tl = params.usize("tl", 5); + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}TH{}TL{}", di, n, th, tl); + let k3 = "DEMAKER价格趋势V230605"; + if c.bars_raw.len() < di + n + 10 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, di, n); + if bars.len() < 2 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let mut demax_items = Vec::new(); + let mut demin_items = Vec::new(); + for i in 1..bars.len() { + let dh = bars[i].high - bars[i - 1].high; + if dh > 0.0 { + demax_items.push(dh); + } + let dl = bars[i - 1].low - bars[i].low; + if dl > 0.0 { + demin_items.push(dl); + } + } + let demax = mean_or_nan(&demax_items); + let demin = mean_or_nan(&demin_items); + let demaker = demax / (demax + demin); + + let mut v1 = "其他"; + if demaker > th as f64 / 10.0 { + v1 = "看多"; + } + if demaker < tl as f64 / 10.0 { + v1 = "看空"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// emv_up_dw_line_V230605:EMV 简易波动多空信号 +/// +/// 参数模板:`"{freq}_D{di}_EMV简易波动V230605"` +/// +/// 信号逻辑: +/// 1. 取最近两根K线计算中点位移; +/// 2. 以成交量/振幅形成箱体比率; +/// 3. `emv > 0` 判 `看多`,否则判 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1_EMV简易波动V230605_看多_任意_任意_0')` +/// - `Signal('60分钟_D1_EMV简易波动V230605_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`。 +/// 对齐说明:与 Python `emv_up_dw_line_V230605` 的两根K线近似 EMV 计算一致。 +#[signal( + category = "kline", + name = "emv_up_dw_line_V230605", + template = "{freq}_D{di}_EMV简易波动V230605", + opcode = "EmvUpDwLineV230605", + param_kind = "EmvUpDwLineV230605" +)] +pub fn emv_up_dw_line_v230605(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}", di); + let k3 = "EMV简易波动V230605"; + if c.bars_raw.len() < di + 10 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, di, 2); + if bars.len() < 2 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let mid_pt_move = (bars[1].high + bars[1].low) / 2.0 - (bars[0].high + bars[0].low) / 2.0; + let box_ratio = bars[1].vol / (bars[1].high - bars[1].low + 1e-9); + let emv = mid_pt_move / box_ratio; + let v1 = if emv > 0.0 { "看多" } else { "看空" }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// er_up_dw_line_V230604:ER 价格动量分层信号 +/// +/// 参数模板:`"{freq}_D{di}W{w}N{n}_ER价格动量V230604"` +/// +/// 信号逻辑: +/// 1. 以 `W` 窗口均价构造 bull/bear power 因子; +/// 2. 仅保留与末值同号的因子子序列; +/// 3. 末值正负给出 `均线上方/均线下方`; +/// 4. 对同号子序列做 `N` 分箱输出 `第x层`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1W60N10_ER价格动量V230604_均线上方_第3层_任意_0')` +/// - `Signal('60分钟_D1W60N10_ER价格动量V230604_均线下方_第8层_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `w`:均价窗口,默认 `60`; +/// - `n`:分层数量,默认 `10`。 +/// 对齐说明:与 Python `er_up_dw_line_V230604` 的同号过滤与分层规则一致。 +#[signal( + category = "kline", + name = "er_up_dw_line_V230604", + template = "{freq}_D{di}W{w}N{n}_ER价格动量V230604", + opcode = "ErUpDwLineV230604", + param_kind = "ErUpDwLineV230604" +)] +pub fn er_up_dw_line_v230604(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let w = params.usize("w", 60); + let n = params.usize("n", 10); + let k1 = c.freq.to_string(); + let k2 = format!("D{}W{}N{}", di, w, n); + let k3 = "ER价格动量V230604"; + + let cache_key = format!("ER{}", w); + let mut old_map: HashMap = HashMap::new(); + if let (Some(ids), Some(vals)) = ( + cache.series_ids.get(&cache_key), + cache.series.get(&cache_key), + ) { + for (id, v) in ids.iter().zip(vals.iter()) { + old_map.insert(*id, *v); + } + } + + let mut out = Vec::with_capacity(c.bars_raw.len()); + let mut out_ids = Vec::with_capacity(c.bars_raw.len()); + for (i, bar) in c.bars_raw.iter().enumerate() { + out_ids.push(bar.id); + let is_last_bar = i + 1 == c.bars_raw.len(); + if is_last_bar { + // 对齐 Python 流式语义:最后一根未完成高周期 bar 在基准级别持续推进时, + // high/low 会变化,ER 末值需要随之刷新,不能仅按 bar id 复用旧缓存。 + } else if let Some(v) = old_map.get(&bar.id) { + out.push(*v); + continue; + } + // 对齐 Python: _bars = c.bars_raw[i-w:i](i 为 1-based) + // 注意这里必须保留 Python 负索引切片语义: + // - 当 len(c.bars_raw) < w 时,会返回一个递增前缀而不是空窗口; + // - 当 len(c.bars_raw) > w 且 start > end 时,结果为空切片。 + let i1 = i + 1; + let len = c.bars_raw.len(); + let raw_start = i1 as isize - w as isize; + let start = if raw_start < 0 { + (len as isize + raw_start).max(0) as usize + } else { + raw_start as usize + }; + let win = if start >= i1 { + &c.bars_raw[0..0] + } else { + &c.bars_raw[start..i1] + }; + let ma = mean_or_nan(&win.iter().map(|x| x.close).collect::>()); + let v = if bar.high > ma { + bar.high - ma + } else { + bar.low - ma + }; + out.push(v); + } + cache.series.insert(cache_key.clone(), out.clone()); + cache.series_ids.insert(cache_key, out_ids); + + // 对齐 Python:即使样本不足返回“其他”,也要先完成历史 bar 的 ER 缓存写入。 + if c.bars_raw.len() < di + w + 10 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let bars = get_sub_elements(&c.bars_raw, di, w * 10); + if bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let start = c.bars_raw.len() - di + 1 - bars.len(); + let end = start + bars.len(); + let mut factors = out[start..end].to_vec(); + let last = *factors.last().unwrap_or(&f64::NAN); + if !last.is_finite() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + factors.retain(|x| x.is_finite() && x * last > 0.0); + if factors.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let v1 = if last > 0.0 { "均线上方" } else { "均线下方" }; + let v2 = match pd_cut_last_label(&factors, n) { + Some(q) => format!("第{}层", q), + None => "其他".to_string(), + }; + make_kline_signal_v2(&k1, &k2, k3, v1, &v2) +} diff --git a/crates/czsc-signals/src/bar.rs b/crates/czsc-signals/src/bar.rs new file mode 100644 index 000000000..f702d3d32 --- /dev/null +++ b/crates/czsc-signals/src/bar.rs @@ -0,0 +1,3350 @@ +use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::sig::{ + get_sub_elements, intraday_time_segment, make_kline_signal_v1, make_kline_signal_v2, + make_kline_signal_v3, minute_freq_end_time, pd_cut_last_label, qcut_last_label, weekday_cn, +}; +use crate::utils::ta::{update_ma_cache, update_macd_cache}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::signal::Signal; +use czsc_core::utils::corr::LinearRegression; +use czsc_signal_macros::signal; +use std::collections::HashMap; + +/// bar_single_V230506:单K趋势分层信号 +/// +/// 参数模板:`"{freq}_D{di}单K趋势N{n}_BS辅助V230506"` +/// +/// 信号逻辑: +/// 1. 取截止到倒数第 `di` 根的最近 100 根K线; +/// 2. 计算每根K线因子 `(close-open)/(open*vol)`; +/// 3. 参考 Python `pd.cut(..., n)` 将末根因子分层,输出 `第1层 ~ 第n层`; +/// 4. 若样本不足或存在 `open=0/vol=0`,返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1单K趋势N5_BS辅助V230506_第3层_任意_任意_0')` +/// - `Signal('60分钟_D1单K趋势N5_BS辅助V230506_其他_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `n`:分层数量,默认 `5`。 +#[signal( + category = "kline", + name = "bar_single_V230506", + template = "{freq}_D{di}单K趋势N{n}_BS辅助V230506", + opcode = "BarSingleV230506", + param_kind = "BarSingleV230506" +)] +pub fn bar_single_v230506(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 5); + let k1 = c.freq.to_string(); + let k2 = format!("D{}单K趋势N{}", di, n); + let k3 = "BS辅助V230506"; + + let mut v1 = "其他".to_string(); + + if c.bars_raw.len() >= 100 + di { + let bars = get_sub_elements(&c.bars_raw, di, 100); + if bars.len() < 100 { + return make_kline_signal_v1(&k1, &k2, k3, &v1); + } + let mut factors = Vec::with_capacity(100); + let mut valid = true; + for bar in bars { + // 与 Python bar_single_V230506 对齐: + // factors = [(x.close / x.open - 1) / x.vol for x in bars] + // 当 open/vol 为 0 时,Python 实测会走 safe 包装并回退为“其他”。 + // 这里直接对齐为默认信号,而不是返回空键。 + if bar.open == 0.0 || bar.vol == 0.0 { + valid = false; + break; + } + factors.push((bar.close - bar.open) / (bar.open * bar.vol)); + } + + if valid && !factors.is_empty() { + if let Some(q) = pd_cut_last_label(&factors, n) { + v1 = format!("第{}层", q); + } + } + } + + make_kline_signal_v1(&k1, &k2, k3, &v1) +} + +/// bar_zdt_V230331:涨跌停识别信号 +/// +/// 参数模板:`"{freq}_D{di}_涨跌停V230331"` +/// +/// 信号逻辑: +/// 1. 取倒数第 `di` 根与其前一根K线; +/// 2. 若当前K线收盘等于最高且不低于前收,记为 `涨停`; +/// 3. 若当前K线收盘等于最低且不高于前收,记为 `跌停`; +/// 4. 否则记为 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1_涨跌停V230331_涨停_任意_任意_0')` +/// - `Signal('60分钟_D1_涨跌停V230331_跌停_任意_任意_0')` +/// - `Signal('60分钟_D1_涨跌停V230331_其他_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`。 +#[signal( + category = "kline", + name = "bar_zdt_V230331", + template = "{freq}_D{di}_涨跌停V230331", + opcode = "BarZdtV230331", + param_kind = "BarZdtV230331" +)] +pub fn bar_zdt_v230331(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}", di); + let k3 = "涨跌停V230331"; + let mut v1 = "其他".to_string(); + + let bars = get_sub_elements(&c.bars_raw, di, 2); + if bars.len() == 2 { + let b2 = &bars[0]; + let b1 = &bars[1]; + + let is_close_high = (b1.close - b1.high).abs() < 1e-6; + let is_close_low = (b1.close - b1.low).abs() < 1e-6; + + if is_close_high && b1.close >= b2.close { + v1 = "涨停".to_string(); + } else if is_close_low && b1.close <= b2.close { + v1 = "跌停".to_string(); + } + } + + make_kline_signal_v1(&k1, &k2, k3, &v1) +} + +/// bar_triple_V230506:三K加速形态信号 +/// +/// 参数模板:`"{freq}_D{di}三K加速_裸K形态V230506"` +/// +/// 信号逻辑: +/// 1. 取倒数第 `di` 根开始的最近3根K线; +/// 2. 三根连续阳线判定 `三连涨`,若高低点依次抬升判定 `新高涨`; +/// 3. 三根连续阴线判定 `三连跌`,若高低点依次下降判定 `新低跌`; +/// 4. 若已形成形态,再按成交量关系细分为 `依次放量/依次缩量/量柱无序`; +/// 5. 数据不足时返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1三K加速_裸K形态V230506_新高涨_依次放量_任意_0')` +/// - `Signal('60分钟_D1三K加速_裸K形态V230506_三连跌_量柱无序_任意_0')` +/// - `Signal('60分钟_D1三K加速_裸K形态V230506_其他_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`。 +#[signal( + category = "kline", + name = "bar_triple_V230506", + template = "{freq}_D{di}三K加速_裸K形态V230506", + opcode = "BarTripleV230506", + param_kind = "BarTripleV230506" +)] +pub fn bar_triple_v230506(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}三K加速", di); + let k3 = "裸K形态V230506"; + + let mut v1 = "其他".to_string(); + let mut v2 = "任意".to_string(); + + // 对齐 Python: len(c.bars_raw) < 7 直接返回“其他” + // 同时保留 di 的安全边界,避免索引越界。 + if c.bars_raw.len() >= 7 { + let bars = get_sub_elements(&c.bars_raw, di, 3); + if bars.len() < 3 { + return make_kline_signal_v2(&k1, &k2, k3, &v1, &v2); + } + let b3 = &bars[0]; + let b2 = &bars[1]; + let b1 = &bars[2]; + + let red1 = b1.close > b1.open; + let red2 = b2.close > b2.open; + let red3 = b3.close > b3.open; + + let green1 = b1.close < b1.open; + let green2 = b2.close < b2.open; + let green3 = b3.close < b3.open; + + if red1 && red2 && red3 { + v1 = "三连涨".to_string(); + if b1.high > b2.high && b2.high > b3.high && b1.low > b2.low && b2.low > b3.low { + v1 = "新高涨".to_string(); + } + } + + if green1 && green2 && green3 { + v1 = "三连跌".to_string(); + if b1.high < b2.high && b2.high < b3.high && b1.low < b2.low && b2.low < b3.low { + v1 = "新低跌".to_string(); + } + } + + if v1 != "其他" { + if b1.vol > b2.vol && b2.vol > b3.vol { + v2 = "依次放量".to_string(); + } else if b1.vol < b2.vol && b2.vol < b3.vol { + v2 = "依次缩量".to_string(); + } else { + v2 = "量柱无序".to_string(); + } + } + } + + make_kline_signal_v2(&k1, &k2, k3, &v1, &v2) +} + +/// bar_end_V221211:判断大周期K线是否闭合 +/// +/// 参数模板:`"{freq}_{freq1}结束_BS辅助221211"` +/// +/// 信号逻辑: +/// 1. 以当前基础周期 `freq` 与目标分钟周期 `freq1` 计算当前K线对应结束时间; +/// 2. 从最新K线向前统计同属该结束时间的连续数量 `i`; +/// 3. 若 `end_time == last_dt` 判定 `闭合`,否则判定 `未闭{i}`。 +/// +/// 信号列表示例: +/// - `Signal('15分钟_60分钟结束_BS辅助221211_闭合_任意_任意_0')` +/// - `Signal('15分钟_60分钟结束_BS辅助221211_未闭2_任意_任意_0')` +/// +/// 参数说明: +/// - `freq1`:目标分钟周期,默认 `60分钟`。 +/// 对齐说明:闭合/未闭计数语义与 Python `bar_end_V221211` 保持一致。 +#[signal( + category = "kline", + name = "bar_end_V221211", + template = "{freq}_{freq1}结束_BS辅助221211", + opcode = "BarEndV221211", + param_kind = "BarEndV221211" +)] +pub fn bar_end_v221211(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let freq1 = params.str("freq1", "60分钟"); + let k1 = c.freq.to_string(); + let k2 = format!("{freq1}结束"); + let k3 = "BS辅助221211"; + + if !freq1.contains("分钟") || c.bars_raw.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let last_dt = c.bars_raw.last().map(|x| x.dt).unwrap(); + let Some(c1_dt) = minute_freq_end_time(last_dt, freq1) else { + return vec![]; + }; + let mut i = 0usize; + for bar in c.bars_raw.iter().rev() { + let Some(edt) = minute_freq_end_time(bar.dt, freq1) else { + break; + }; + if edt != c1_dt { + break; + } + i += 1; + } + + let v1 = if c1_dt == last_dt { + "闭合".to_string() + } else { + format!("未闭{i}") + }; + make_kline_signal_v1(&k1, &k2, k3, &v1) +} + +/// bar_operate_span_V221111:日内时间区间过滤 +/// +/// 参数模板:`"{freq}_T{t1}#{t2}_时间区间V221111"` +/// +/// 信号逻辑: +/// 1. 读取最新K线时间 `HHMM`; +/// 2. 若 `t1 <= HHMM <= t2` 判定 `是`,否则判定 `否`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_T0935#1450_时间区间_是_任意_任意_0')` +/// - `Signal('60分钟_T0935#1450_时间区间_否_任意_任意_0')` +/// +/// 参数说明: +/// - `t1`:起始时间(`HHMM`),默认 `0935`; +/// - `t2`:结束时间(`HHMM`),默认 `1450`。 +/// 对齐说明:边界包含比较与 Python `bar_operate_span_V221111` 一致。 +#[signal( + category = "kline", + name = "bar_operate_span_V221111", + template = "{freq}_T{t1}#{t2}_时间区间V221111", + opcode = "BarOperateSpanV221111", + param_kind = "BarOperateSpanV221111" +)] +pub fn bar_operate_span_v221111(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let t1 = params.str("t1", "0935"); + let t2 = params.str("t2", "1450"); + let k1 = c.freq.to_string(); + let k2 = format!("T{t1}#{t2}"); + let k3 = "时间区间"; + if c.bars_raw.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let dt = c.bars_raw.last().unwrap().dt; + let hm = dt.format("%H%M").to_string(); + let v1 = if t1 <= hm.as_str() && hm.as_str() <= t2 { + "是" + } else { + "否" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_time_V230327:日内时间分段信号 +/// +/// 参数模板:`"{freq}_日内时间_分段V230327"` +/// +/// 信号逻辑: +/// 1. 仅支持 `30分钟/60分钟` 周期; +/// 2. 取最近 100 根K线的 `HH:MM` 去重并排序; +/// 3. 输出当前K线时间在分段序列中的位置:`第{n}段`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_日内时间_分段V230327_第1段_任意_任意_0')` +/// - `Signal('60分钟_日内时间_分段V230327_第4段_任意_任意_0')` +/// +/// 参数说明: +/// - 无额外参数。 +/// 对齐说明:分段生成与 Python `bar_time_V230327` 的排序与编号口径一致。 +#[signal( + category = "kline", + name = "bar_time_V230327", + template = "{freq}_日内时间_分段V230327", + opcode = "BarTimeV230327", + param_kind = "BarTimeV230327" +)] +pub fn bar_time_v230327(c: &CZSC, _params: &ParamView, _cache: &mut TaCache) -> Vec { + let k1 = c.freq.to_string(); + let k2 = "日内时间"; + let k3 = "分段V230327"; + let mut v1 = "其他".to_string(); + if c.freq.to_string() != "30分钟" && c.freq.to_string() != "60分钟" { + return make_kline_signal_v1(&k1, k2, k3, &v1); + } + if let Some(seg) = intraday_time_segment(&c.bars_raw, 100) { + v1 = format!("第{}段", seg); + } + make_kline_signal_v1(&k1, k2, k3, &v1) +} + +/// bar_weekday_V230328:周内时间分段信号 +/// +/// 参数模板:`"{freq}_周内时间_分段V230328"` +/// +/// 信号逻辑: +/// 1. 当样本数量不足 20 根时返回 `其他`; +/// 2. 否则将最新K线日期按 `weekday` 映射到 `周一~周日`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_周内时间_分段V230328_周一_任意_任意_0')` +/// - `Signal('60分钟_周内时间_分段V230328_周五_任意_任意_0')` +/// +/// 参数说明: +/// - 无额外参数。 +/// 对齐说明:weekday 映射表与 Python `bar_weekday_V230328` 一致。 +#[signal( + category = "kline", + name = "bar_weekday_V230328", + template = "{freq}_周内时间_分段V230328", + opcode = "BarWeekdayV230328", + param_kind = "BarWeekdayV230328" +)] +pub fn bar_weekday_v230328(c: &CZSC, _params: &ParamView, _cache: &mut TaCache) -> Vec { + let k1 = c.freq.to_string(); + let k2 = "周内时间"; + let k3 = "分段V230328"; + let mut v1 = "其他"; + if c.bars_raw.len() >= 20 { + v1 = weekday_cn(c.bars_raw.last().unwrap().dt); + } + make_kline_signal_v1(&k1, k2, k3, v1) +} + +/// bar_vol_grow_V221112:成交量放大信号 +/// +/// 参数模板:`"{freq}_D{di}K{n}B_放量V221112"` +/// +/// 信号逻辑: +/// 1. 取倒数第 `di` 根及其前 `n` 根,共 `n+1` 根K线; +/// 2. 计算前 `n` 根平均成交量 `mean_vol`; +/// 3. 若当前量在 `[2*mean_vol, 4*mean_vol]`,判 `是`,否则 `否`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D2K5B_放量V221112_是_任意_任意_0')` +/// - `Signal('60分钟_D2K5B_放量V221112_否_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `2`; +/// - `n`:回看K线数量,默认 `5`。 +/// 对齐说明:判定区间与 Python `bar_vol_grow_V221112` 保持一致。 +#[signal( + category = "kline", + name = "bar_vol_grow_V221112", + template = "{freq}_D{di}K{n}B_放量V221112", + opcode = "BarVolGrowV221112", + param_kind = "BarVolGrowV221112" +)] +pub fn bar_vol_grow_v221112(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 2); + let n = params.usize("n", 5); + let k1 = c.freq.to_string(); + let k2 = format!("D{}K{}B", di, n); + let k3 = "放量V221112"; + + let v1 = if c.bars_raw.len() < di + n + 10 { + "其他" + } else { + let bars = get_sub_elements(&c.bars_raw, di, n + 1); + if bars.len() != n + 1 { + "其他" + } else { + let mean_vol = bars[..n].iter().map(|x| x.vol).sum::() / n as f64; + if bars[n].vol >= mean_vol * 2.0 && bars[n].vol <= mean_vol * 4.0 { + "是" + } else { + "否" + } + } + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_mean_amount_V221112:区间均额分类信号 +/// +/// 参数模板:`"{freq}_D{di}K{n}B均额_{th1}至{th2}千万"` +/// +/// 信号逻辑: +/// 1. 取倒数第 `di` 根截止的最近 `n` 根K线; +/// 2. 计算平均成交额 `m`; +/// 3. 若 `m/1e7` 在 `[th1, th2]` 判 `是`,否则判 `否`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1K10B均额_1至4千万_是_任意_任意_0')` +/// - `Signal('60分钟_D1K10B均额_1至4千万_否_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `n`:样本长度,默认 `10`; +/// - `th1`:下限(千万),默认 `1`; +/// - `th2`:上限(千万),默认 `4`。 +/// 对齐说明:均额口径与 Python `bar_mean_amount_V221112` 保持一致。 +#[signal( + category = "kline", + name = "bar_mean_amount_V221112", + template = "{freq}_D{di}K{n}B均额_{th1}至{th2}千万V221112", + opcode = "BarMeanAmountV221112", + param_kind = "BarMeanAmountV221112" +)] +pub fn bar_mean_amount_v221112(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 10); + let th1 = params.usize("th1", 1); + let th2 = params.usize("th2", 4); + + let k1 = c.freq.to_string(); + let k2 = format!("D{}K{}B均额", di, n); + let k3 = format!("{}至{}千万", th1, th2); + + let mut v1 = "其他"; + if c.bars_raw.len() > di + n + 5 { + let bars = get_sub_elements(&c.bars_raw, di, n); + if bars.len() == n { + let m = bars.iter().map(|x| x.amount).sum::() / n as f64 / 10_000_000.0; + v1 = if m >= th1 as f64 && m <= th2 as f64 { + "是" + } else { + "否" + }; + } + } + make_kline_signal_v1(&k1, &k2, &k3, v1) +} + +/// bar_zdf_V221203:单根涨跌幅区间信号 +/// +/// 参数模板:`"{freq}_D{di}{mode}_{t1}至{t2}"` +/// +/// 信号逻辑: +/// 1. 读取倒数第 `di` 根及其前一根K线; +/// 2. `mode=ZF` 使用涨幅 `close/prev_close-1`,`mode=DF` 使用跌幅 `1-close/prev_close`; +/// 3. 换算为 BP 后在 `[t1, t2]` 判 `满足`,否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('日线_D1ZF_300至600_满足_任意_任意_0')` +/// - `Signal('日线_D1DF_300至600_其他_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `mode`:`ZF` 或 `DF`,默认 `ZF`; +/// - `span`:区间下上界(`t1,t2`),默认 `300,600`。 +/// 对齐说明:BP 计算与 Python `bar_zdf_V221203` 保持一致。 +#[signal( + category = "kline", + name = "bar_zdf_V221203", + template = "{freq}_D{di}{mode}_{t1}至{t2}V221203", + opcode = "BarZdfV221203", + param_kind = "BarZdfV221203" +)] +pub fn bar_zdf_v221203(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let mode = params.str("mode", "ZF").to_uppercase(); + let span = params.str("span", "300,600"); + let parts: Vec<&str> = span.split(',').collect(); + let t1 = parts + .first() + .and_then(|x| x.parse::().ok()) + .unwrap_or(300.0); + let t2 = parts + .get(1) + .and_then(|x| x.parse::().ok()) + .unwrap_or(600.0); + + let k1 = c.freq.to_string(); + let k2 = format!("D{}{}", di, mode); + let k3 = format!("{}至{}", t1 as i32, t2 as i32); + + let bars = get_sub_elements(&c.bars_raw, di, 3); + if bars.len() < 2 || t2 <= t1 || t1 <= 0.0 { + return make_kline_signal_v1(&k1, &k2, &k3, "其他"); + } + let prev = bars[bars.len() - 2].close; + let last = bars[bars.len() - 1].close; + if prev == 0.0 { + return make_kline_signal_v1(&k1, &k2, &k3, "其他"); + } + let edge = if mode == "ZF" { + (last / prev - 1.0) * 10_000.0 + } else { + (1.0 - last / prev) * 10_000.0 + }; + let v1 = if edge >= t1 && edge <= t2 { + "满足" + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, &k3, v1) +} + +/// bar_amount_acc_V230214:区间累计成交额信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}_累计超{t}千万"` +/// +/// 信号逻辑: +/// 1. 取倒数第 `di` 根截止的最近 `n` 根K线; +/// 2. 计算累计成交额 `sum(amount)`; +/// 3. 若大于 `t * 1e7` 判 `是`,否则 `否`。 +/// +/// 信号列表示例: +/// - `Signal('日线_D2N5_累计超10千万_是_任意_任意_0')` +/// - `Signal('日线_D2N5_累计超10千万_否_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `2`; +/// - `n`:回看K线数,默认 `5`; +/// - `t`:阈值(千万),默认 `10`。 +/// 对齐说明:累计金额阈值判断与 Python `bar_amount_acc_V230214` 一致。 +#[signal( + category = "kline", + name = "bar_amount_acc_V230214", + template = "{freq}_D{di}N{n}_累计超{t}千万V230214", + opcode = "BarAmountAccV230214", + param_kind = "BarAmountAccV230214" +)] +pub fn bar_amount_acc_v230214(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 2); + let n = params.usize("n", 5); + let t = params.usize("t", 10); + + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}", di, n); + let k3 = format!("累计超{}千万", t); + let mut v1 = "其他"; + + if c.bars_raw.len() > di + n + 5 { + let bars = get_sub_elements(&c.bars_raw, di, n); + if bars.len() == n { + let acc = bars.iter().map(|x| x.amount).sum::(); + v1 = if acc > t as f64 * 10_000_000.0 { + "是" + } else { + "否" + }; + } + } + make_kline_signal_v1(&k1, &k2, &k3, v1) +} + +/// bar_single_V230214:单K状态信号 +/// +/// 参数模板:`"{freq}_D{di}T{t}_状态V230214"` +/// +/// 信号逻辑: +/// 1. 倒数第 `di` 根K线,按 `close/open` 判 `阳线/阴线`; +/// 2. 若 `solid > (upper+lower)*t/10` 判 `长实体`; +/// 3. 若 `upper > (solid+lower)*t/10` 判 `长上影`; +/// 4. 若 `lower > (solid+upper)*t/10` 判 `长下影`,否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('日线_D1T10_状态V230214_阳线_长实体_任意_0')` +/// - `Signal('日线_D1T10_状态V230214_阴线_长上影_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `t`:长实体/长影阈值(/10),默认 `10`。 +/// 对齐说明:分类阈值与 Python `bar_single_V230214` 保持一致。 +#[signal( + category = "kline", + name = "bar_single_V230214", + template = "{freq}_D{di}T{t}_状态V230214", + opcode = "BarSingleV230214", + param_kind = "BarSingleV230214" +)] +pub fn bar_single_v230214(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let t = params.usize("t", 10); + let k1 = c.freq.to_string(); + let k2 = format!("D{}T{}", di, t); + let k3 = "状态"; + + if c.bars_raw.len() < di + 2 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let k = &c.bars_raw[c.bars_raw.len() - di]; + let v1 = if k.close > k.open { "阳线" } else { "阴线" }; + let solid = (k.open - k.close).abs(); + let upper = k.high - k.open.max(k.close); + let lower = k.open.min(k.close) - k.low; + let v2 = if solid > (upper + lower) * t as f64 / 10.0 { + "长实体" + } else if upper > (solid + lower) * t as f64 / 10.0 { + "长上影" + } else if lower > (solid + upper) * t as f64 / 10.0 { + "长下影" + } else { + "其他" + }; + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// bar_big_solid_V230215:窗口最大实体中位多空信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}_MIDV230215"` +/// +/// 信号逻辑: +/// 1. 在窗口内找到实体最大K线; +/// 2. 取该K线实体中位价 `mid`; +/// 3. 最新收盘价高于 `mid` 判 `看多`,否则 `看空`; +/// 4. 最大实体K线按方向标注 `大阳/大阴`。 +/// +/// 信号列表示例: +/// - `Signal('日线_D1N20_MIDV230215_看多_大阳_任意_0')` +/// - `Signal('日线_D1N20_MIDV230215_看空_大阴_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `n`:窗口长度,默认 `20`。 +/// 对齐说明:最大实体与中位价定义对齐 Python `bar_big_solid_V230215`。 +#[signal( + category = "kline", + name = "bar_big_solid_V230215", + template = "{freq}_D{di}N{n}_MIDV230215", + opcode = "BarBigSolidV230215", + param_kind = "BarBigSolidV230215" +)] +pub fn bar_big_solid_v230215(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 20); + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}", di, n); + let k3 = "MID"; + + let bars = get_sub_elements(&c.bars_raw, di, n); + if bars.is_empty() || c.bars_raw.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let mut max_i = 0usize; + let mut max_solid = f64::NEG_INFINITY; + for (i, b) in bars.iter().enumerate() { + let s = (b.open - b.close).abs(); + if s > max_solid { + max_solid = s; + max_i = i; + } + } + let b = &bars[max_i]; + let max_mid = b.open.min(b.close) + 0.5 * (b.open - b.close).abs(); + let v1 = if c.bars_raw.last().map(|x| x.close).unwrap_or(max_mid) > max_mid { + "看多" + } else { + "看空" + }; + let v2 = if b.close > b.open { "大阳" } else { "大阴" }; + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// bar_bpm_V230227:绝对动量分层 +/// +/// 参数模板:`"{freq}_D{di}N{n}T{th}_绝对动量V230227"` +/// +/// 信号逻辑: +/// 1. 取最近 `n` 根,计算区间 BP:`(last_close/first_open-1)*10000`; +/// 2. `bp>0` 时,`bp>th` 判 `超强` 否则 `强势`; +/// 3. `bp<=0` 时,`|bp|>th` 判 `超弱` 否则 `弱势`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N20T1000_绝对动量V230227_强势_任意_任意_0')` +/// - `Signal('60分钟_D1N20T1000_绝对动量V230227_超弱_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `n`:窗口长度,默认 `20`; +/// - `th`:强弱阈值(BP),默认 `1000`。 +/// 对齐说明:分层规则与 Python `bar_bpm_V230227` 保持一致。 +#[signal( + category = "kline", + name = "bar_bpm_V230227", + template = "{freq}_D{di}N{n}T{th}_绝对动量V230227", + opcode = "BarBpmV230227", + param_kind = "BarBpmV230227" +)] +pub fn bar_bpm_v230227(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 20); + let th = params.usize("th", 1000); + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}T{}", di, n, th); + let k3 = "绝对动量V230227"; + + if c.bars_raw.len() < di + n { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, di, n); + if bars.len() < n || bars[0].open == 0.0 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bp = (bars[bars.len() - 1].close / bars[0].open - 1.0) * 10_000.0; + let v1 = if bp > 0.0 { + if bp > th as f64 { + "超强" + } else { + "强势" + } + } else if bp.abs() > th as f64 { + "超弱" + } else { + "弱势" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_section_momentum_V221112:区间动量强弱与波动 +/// +/// 参数模板:`"{freq}_D{di}K{n}B_阈值{th}BPV221112"` +/// +/// 信号逻辑: +/// 1. 区间 BP:`(last_close/first_open-1)*10000`; +/// 2. 区间波动:`(max_high/min_low-1)*10000`; +/// 3. `v1`:`上涨/下跌`;`v2`:`强势/弱势`(`|bp|>=th`); +/// 4. `v3`:`高波动/低波动`(`|wave|/|bp| >= 3`)。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1K10B_阈值100BPV221112_上涨_强势_高波动_0')` +/// - `Signal('60分钟_D1K10B_阈值100BPV221112_下跌_弱势_低波动_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `n`:窗口长度,默认 `10`; +/// - `th`:强弱阈值(BP),默认 `100`。 +/// 对齐说明:三段分类与 Python `bar_section_momentum_V221112` 一致。 +#[signal( + category = "kline", + name = "bar_section_momentum_V221112", + template = "{freq}_D{di}K{n}B_阈值{th}BPV221112", + opcode = "BarSectionMomentumV221112", + param_kind = "BarSectionMomentumV221112" +)] +pub fn bar_section_momentum_v221112( + c: &CZSC, + params: &ParamView, + _cache: &mut TaCache, +) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 10); + let th = params.usize("th", 100); + let k1 = c.freq.to_string(); + let k2 = format!("D{}K{}B", di, n); + let k3 = format!("阈值{}BP", th); + + if c.bars_raw.len() < di + n { + return make_kline_signal_v3(&k1, &k2, &k3, "其他", "其他", "其他"); + } + let bars = get_sub_elements(&c.bars_raw, di, n); + if bars.len() < n || bars[0].open == 0.0 { + return make_kline_signal_v3(&k1, &k2, &k3, "其他", "其他", "其他"); + } + let bp = (bars[bars.len() - 1].close / bars[0].open - 1.0) * 10_000.0; + let high = bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let low = bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + if low == 0.0 || !high.is_finite() || !low.is_finite() { + return make_kline_signal_v3(&k1, &k2, &k3, "其他", "其他", "其他"); + } + let wave = (high / low - 1.0) * 10_000.0; + let rate = if bp.abs() == 0.0 { + 0.0 + } else { + wave.abs() / bp.abs() + }; + let v1 = if bp >= 0.0 { "上涨" } else { "下跌" }; + let v2 = if bp.abs() >= th as f64 { + "强势" + } else { + "弱势" + }; + let v3 = if rate >= 3.0 { + "高波动" + } else { + "低波动" + }; + make_kline_signal_v3(&k1, &k2, &k3, v1, v2, v3) +} + +/// bar_vol_bs1_V230224:量价高低点辅助 +/// +/// 参数模板:`"{freq}_D{di}N{n}量价_BS1辅助V230224"` +/// +/// 信号逻辑: +/// 1. 窗口末根创新高且上影显著、成交额远高于均值,判 `看空`; +/// 2. 窗口末根创新低且下影显著、成交额远低于均值,判 `看多`; +/// 3. 否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N20量价_BS1辅助V230224_看空_任意_任意_0')` +/// - `Signal('60分钟_D1N20量价_BS1辅助V230224_看多_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `n`:窗口长度,默认 `20`。 +/// 对齐说明:量价条件阈值与 Python `bar_vol_bs1_V230224` 一致。 +#[signal( + category = "kline", + name = "bar_vol_bs1_V230224", + template = "{freq}_D{di}N{n}量价_BS1辅助V230224", + opcode = "BarVolBs1V230224", + param_kind = "BarVolBs1V230224" +)] +pub fn bar_vol_bs1_v230224(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 20); + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}量价", di, n); + let k3 = "BS1辅助"; + + let bars = get_sub_elements(&c.bars_raw, di, n); + if bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let last = &bars[bars.len() - 1]; + let mean_amount = bars.iter().map(|x| x.amount).sum::() / bars.len() as f64; + let max_high = bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let min_low = bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + + let last_upper = last.high - last.open.max(last.close); + let last_lower = last.open.min(last.close) - last.low; + let short_c1 = (last.high - max_high).abs() <= f64::EPSILON + && last_upper > 2.0 * last_lower + && last_lower > 0.0; + let short_c2 = last.amount > mean_amount * 3.0; + let long_c1 = (last.low - min_low).abs() <= f64::EPSILON + && last_lower > 2.0 * last_upper + && last_upper > 0.0; + let long_c2 = last.amount < mean_amount * 0.7; + + let v1 = if short_c1 && short_c2 { + "看空" + } else if long_c1 && long_c2 { + "看多" + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_zt_count_V230504:窗口涨停计数 +/// +/// 参数模板:`"{freq}_D{di}W{window}涨停计数_裸K形态V230504"` +/// +/// 信号逻辑: +/// 1. 在窗口内按相邻K线判断 `涨停`:`b2.close > b1.close*1.07 && b2.close==b2.high`; +/// 2. 统计总次数 `sum(c1)`; +/// 3. 统计连续双涨停次数 `cc`(相邻两个都为1); +/// 4. 若总次数为0返回 `其他`,否则输出 `"{sum}次" + "连续{cc}次"`。 +/// +/// 信号列表示例: +/// - `Signal('日线_D1W5涨停计数_裸K形态V230504_1次_连续0次_任意_0')` +/// - `Signal('日线_D1W5涨停计数_裸K形态V230504_3次_连续2次_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `window`:统计窗口,默认 `5`。 +/// 对齐说明:涨停阈值与连续计次与 Python `bar_zt_count_V230504` 一致。 +#[signal( + category = "kline", + name = "bar_zt_count_V230504", + template = "{freq}_D{di}W{window}涨停计数_裸K形态V230504", + opcode = "BarZtCountV230504", + param_kind = "BarZtCountV230504" +)] +pub fn bar_zt_count_v230504(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let window = params.usize("window", 5); + let k1 = c.freq.to_string(); + let k2 = format!("D{}W{}涨停计数", di, window); + let k3 = "裸K形态V230504"; + + if c.freq.to_string() != "日线" { + return vec![]; + } + if c.bars_raw.len() < 7 + di + window { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, di, window); + if bars.len() < 2 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let mut c1: Vec = Vec::with_capacity(bars.len() - 1); + let mut cc = 0i32; + for w in bars.windows(2) { + let b1 = &w[0]; + let b2 = &w[1]; + let is_zt = b2.close > b1.close * 1.07 && (b2.close - b2.high).abs() <= f64::EPSILON; + c1.push(if is_zt { 1 } else { 0 }); + if c1.len() >= 2 && c1[c1.len() - 1] == 1 && c1[c1.len() - 2] == 1 { + cc += 1; + } + } + let sum_zt: i32 = c1.iter().sum(); + if sum_zt == 0 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let v1 = format!("{}次", sum_zt); + let v2 = format!("连续{}次", cc); + make_kline_signal_v2(&k1, &k2, k3, &v1, &v2) +} + +/// bar_limit_down_V230525:跌停后反包阳线 +/// +/// 参数模板:`"{freq}_跌停后无下影线长实体阳线_短线V230525"` +/// +/// 信号逻辑: +/// 1. 仅日线级别; +/// 2. 前一日近似跌停:`low==closeopen && solid>2*upper && close/open>1.07`; +/// 4. 且当日最低低于前日最低,判 `满足`。 +/// +/// 信号列表示例: +/// - `Signal('日线_跌停后无下影线长实体阳线_短线V230525_满足_任意_任意_0')` +/// - `Signal('日线_跌停后无下影线长实体阳线_短线V230525_其他_任意_任意_0')` +/// +/// 参数说明: +/// - 无额外参数。 +/// 对齐说明:条件组合与 Python `bar_limit_down_V230525` 保持一致。 +#[signal( + category = "kline", + name = "bar_limit_down_V230525", + template = "{freq}_跌停后无下影线长实体阳线_短线V230525", + opcode = "BarLimitDownV230525", + param_kind = "BarLimitDownV230525" +)] +pub fn bar_limit_down_v230525(c: &CZSC, _params: &ParamView, _cache: &mut TaCache) -> Vec { + let k1 = c.freq.to_string(); + let k2 = "跌停后无下影线长实体阳线"; + let k3 = "短线V230525"; + if k1 != "日线" { + return vec![]; + } + if c.bars_raw.len() < 10 { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + let b1 = &c.bars_raw[c.bars_raw.len() - 3]; + let b2 = &c.bars_raw[c.bars_raw.len() - 2]; + let b3 = &c.bars_raw[c.bars_raw.len() - 1]; + let b2_condition = (b2.low - b2.close).abs() <= f64::EPSILON + && b2.close < b1.close + && b2.close / b1.close < 0.95; + let b3_solid = (b3.open - b3.close).abs(); + let b3_upper = b3.high - b3.open.max(b3.close); + let b3_condition = (b3.low - b3.open).abs() <= f64::EPSILON + && b3.close > b3.open + && b3_solid > b3_upper * 2.0 + && b3.close / b3.open > 1.07; + let v1 = if b2_condition && b3_condition && b3.low < b2.low { + "满足" + } else { + "其他" + }; + make_kline_signal_v1(&k1, k2, k3, v1) +} + +#[inline] +fn bar_solid(b: &czsc_core::objects::bar::RawBar) -> f64 { + (b.open - b.close).abs() +} + +#[inline] +fn bar_upper(b: &czsc_core::objects::bar::RawBar) -> f64 { + b.high - b.open.max(b.close) +} + +#[inline] +fn bar_lower(b: &czsc_core::objects::bar::RawBar) -> f64 { + b.open.min(b.close) - b.low +} + +fn percentile_linear(values: &[f64], p: f64) -> Option { + if values.is_empty() || !p.is_finite() { + return None; + } + let mut x: Vec = values.iter().copied().filter(|v| v.is_finite()).collect(); + if x.is_empty() { + return None; + } + x.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + if x.len() == 1 { + return Some(x[0]); + } + let q = p.clamp(0.0, 100.0) / 100.0; + let h = (x.len() - 1) as f64 * q; + let i = h.floor() as usize; + let j = h.ceil() as usize; + if i == j { + Some(x[i]) + } else { + Some(x[i] + (h - i as f64) * (x[j] - x[i])) + } +} + +fn overlap_center(bars: &[czsc_core::objects::bar::RawBar]) -> (bool, Option, Option) { + if bars.is_empty() { + return (false, None, None); + } + let min_high = bars.iter().map(|x| x.high).fold(f64::INFINITY, f64::min); + let max_low = bars.iter().map(|x| x.low).fold(f64::NEG_INFINITY, f64::max); + if min_high > max_low { + let dd = bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let gg = bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + (true, Some(dd), Some(gg)) + } else { + (false, None, None) + } +} + +/// bar_accelerate_V221110:区间加速走势判定 +/// +/// 参数模板:`"{freq}_D{di}W{window}_加速V221110"` +/// +/// 信号逻辑: +/// 1. 取倒数第 `di` 根截止的最近 `window` 根K线,计算区间最高/最低; +/// 2. 若末根收盘位于区间上20%且阳线占比>=80%,判 `上涨`; +/// 3. 若末根收盘位于区间下20%且阴线占比>=80%,判 `下跌`,否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1W13_加速V221110_上涨_任意_任意_0')` +/// - `Signal('60分钟_D1W13_加速V221110_下跌_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `window`:观察窗口长度,默认 `10`。 +/// 对齐说明:收盘位置与阳阴占比阈值对齐 Python `bar_accelerate_V221110`。 +#[signal( + category = "kline", + name = "bar_accelerate_V221110", + template = "{freq}_D{di}W{window}_加速V221110", + opcode = "BarAccelerateV221110", + param_kind = "BarAccelerateV221110" +)] +pub fn bar_accelerate_v221110(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let window = params.usize("window", 10); + let k1 = c.freq.to_string(); + let k2 = format!("D{}W{}", di, window); + let k3 = "加速V221110"; + + let mut v1 = "其他"; + if c.bars_raw.len() > di + window + 10 { + let bars = get_sub_elements(&c.bars_raw, di, window); + if bars.len() == window { + let hhv = bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let llv = bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let close = bars[bars.len() - 1].close; + let c1 = close > llv + (hhv - llv) * 0.8; + let c2 = close < llv + (hhv - llv) * 0.2; + let red_pct = + bars.iter().filter(|x| x.close > x.open).count() as f64 / bars.len() as f64 >= 0.8; + let green_pct = + bars.iter().filter(|x| x.close < x.open).count() as f64 / bars.len() as f64 >= 0.8; + if c1 && red_pct { + v1 = "上涨"; + } + if c2 && green_pct { + v1 = "下跌"; + } + } + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_accelerate_V221118:均线偏离加速判定 +/// +/// 参数模板:`"{freq}_D{di}W{window}#{ma_type}#{timeperiod}_加速V221118"` +/// +/// 信号逻辑: +/// 1. 计算窗口内每根 `close - ma` 偏离值; +/// 2. 全部偏离为正,且最后三根偏离值递增,判 `上涨`; +/// 3. 全部偏离为负,且最后三根偏离值递减,判 `下跌`,否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('日线_D1W13#SMA#10_加速V221118_上涨_任意_任意_0')` +/// - `Signal('日线_D1W13#SMA#10_加速V221118_下跌_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `window`:观察窗口,默认 `13`; +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod`:均线周期,默认 `10`。 +/// 对齐说明:偏离序列与三根单调条件对齐 Python `bar_accelerate_V221118`。 +#[signal( + category = "kline", + name = "bar_accelerate_V221118", + template = "{freq}_D{di}W{window}#{ma_type}#{timeperiod}_加速V221118", + opcode = "BarAccelerateV221118", + param_kind = "BarAccelerateV221118" +)] +pub fn bar_accelerate_v221118(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let window = params.usize("window", 13); + let ma_type = params.str("ma_type", "SMA").to_uppercase(); + let timeperiod = params.usize("timeperiod", 10); + let k1 = c.freq.to_string(); + let k2 = format!("D{}W{}#{}#{}", di, window, ma_type, timeperiod); + let k3 = "加速V221118"; + + if window <= 3 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let cache_key = format!("{}_{}_{}", c.freq, ma_type, timeperiod); + update_ma_cache(c, &cache_key, &ma_type, timeperiod, cache); + let Some(ma) = cache.series.get(&cache_key) else { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + }; + let bars = get_sub_elements(&c.bars_raw, di, window); + let ma_sub = get_sub_elements(ma.as_slice(), di, window); + if bars.len() != window || ma_sub.len() != window { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let delta: Vec = bars + .iter() + .zip(ma_sub.iter()) + .map(|(b, m)| b.close - *m) + .collect(); + let all_pos = delta.iter().all(|x| *x > 0.0); + let all_neg = delta.iter().all(|x| *x < 0.0); + let n = delta.len(); + let v1 = if all_pos && delta[n - 1] > delta[n - 2] && delta[n - 2] > delta[n - 3] { + "上涨" + } else if all_neg && delta[n - 1] < delta[n - 2] && delta[n - 2] < delta[n - 3] { + "下跌" + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_accelerate_V240428:滚动差分加速判定 +/// +/// 参数模板:`"{freq}_D{di}W{w}T{t}_加速V240428"` +/// +/// 信号逻辑: +/// 1. 计算 `diff = close - close[w]`,取最近300根 `|diff|` 的75分位阈值; +/// 2. 若最新 `|diff|` 超阈且 `diff>0`,窗口内倍量阳线数>=`t` 判 `上涨`; +/// 3. 若最新 `|diff|` 超阈且 `diff<0`,窗口内倍量阴线数>=`t` 判 `下跌`,否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('日线_D1W21T2_加速V240428_上涨_任意_任意_0')` +/// - `Signal('日线_D1W21T2_加速V240428_下跌_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `w`:差分窗口,默认 `21`; +/// - `t`:倍量同向K线最小数量,默认 `1`。 +/// 对齐说明:阈值分位与倍量计数口径对齐 Python `bar_accelerate_V240428`。 +#[signal( + category = "kline", + name = "bar_accelerate_V240428", + template = "{freq}_D{di}W{w}T{t}_加速V240428", + opcode = "BarAccelerateV240428", + param_kind = "BarAccelerateV240428" +)] +pub fn bar_accelerate_v240428(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let w = params.usize("w", 21); + let t = params.usize("t", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}W{}T{}", di, w, t); + let k3 = "加速V240428"; + + if c.bars_raw.len() < w + 100 + di || w == 0 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let mut diff = vec![f64::NAN; c.bars_raw.len()]; + for (i, diff_i) in diff.iter_mut().enumerate().skip(w) { + *diff_i = c.bars_raw[i].close - c.bars_raw[i - w].close; + } + let start = c.bars_raw.len().saturating_sub(300); + let diff_abs: Vec = diff[start..] + .iter() + .copied() + .filter(|x| x.is_finite()) + .map(f64::abs) + .collect(); + let Some(th) = percentile_linear(&diff_abs, 75.0) else { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + }; + let last_diff = diff[c.bars_raw.len() - 1]; + if !last_diff.is_finite() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, di, w); + if bars.len() < 2 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let mut v1 = "其他"; + if last_diff.abs() > th && last_diff > 0.0 { + let mut cnt = 0usize; + for pair in bars.windows(2) { + let b1 = &pair[0]; + let b2 = &pair[1]; + if b2.close > b2.open && b2.vol > b1.vol * 2.0 { + cnt += 1; + } + } + if cnt >= t { + v1 = "上涨"; + } + } + if last_diff.abs() > th && last_diff < 0.0 { + let mut cnt = 0usize; + for pair in bars.windows(2) { + let b1 = &pair[0]; + let b2 = &pair[1]; + if b2.close < b2.open && b2.vol > b1.vol * 2.0 { + cnt += 1; + } + } + if cnt >= t { + v1 = "下跌"; + } + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_fake_break_V230204:区间假突破判定 +/// +/// 参数模板:`"{freq}_D{di}N{n}M{m}_假突破V230204"` +/// +/// 信号逻辑: +/// 1. 在最近 `N` 根内寻找滑动 `M` 窗口重叠中枢; +/// 2. 阳线末根创新高且“跌破DD后拉回”,判 `看多`; +/// 3. 阴线末根创新低且“突破GG后回落”,判 `看空`,否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('15分钟_D1N20M5_假突破_看多_任意_任意_0')` +/// - `Signal('15分钟_D1N20M5_假突破_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `n`:观察窗口,默认 `20`; +/// - `m`:中枢滑窗,默认 `5`。 +/// 对齐说明:中枢重叠与真假突破条件对齐 Python `bar_fake_break_V230204`。 +#[signal( + category = "kline", + name = "bar_fake_break_V230204", + template = "{freq}_D{di}N{n}M{m}_假突破V230204", + opcode = "BarFakeBreakV230204", + param_kind = "BarFakeBreakV230204" +)] +pub fn bar_fake_break_v230204(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 20); + let m = params.usize("m", 5); + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}M{}", di, n, m); + let k3 = "假突破"; + + let last_bars = get_sub_elements(&c.bars_raw, di, n); + if last_bars.len() != n { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let last = &last_bars[last_bars.len() - 1]; + if bar_solid(last) < bar_upper(last) + bar_lower(last) { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let mut right_bars: Vec = Vec::new(); + let mut dd = 0.0f64; + let mut gg = 0.0f64; + if n > 2 * m { + for i in m..(n - m) { + let a = n - i - m; + let b = n - i; + let (ok, _dd, _gg) = overlap_center(&last_bars[a..b]); + if ok { + dd = _dd.unwrap_or(0.0); + gg = _gg.unwrap_or(0.0); + right_bars = last_bars[n - i..].to_vec(); + break; + } + } + } + + let mut v1 = "其他"; + if last.close > last.open { + let c1_a = (last.high + - last_bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max)) + .abs() + <= f64::EPSILON; + let c1_b = (last.close + - last_bars + .iter() + .map(|x| x.close) + .fold(f64::NEG_INFINITY, f64::max)) + .abs() + <= f64::EPSILON; + let c2 = if right_bars.is_empty() { + false + } else { + let min_low = right_bars + .iter() + .map(|x| x.low) + .fold(f64::INFINITY, f64::min); + 0.0 < min_low && min_low < dd + }; + if (c1_a || c1_b) && c2 { + v1 = "看多"; + } + } + if last.close < last.open { + let c1_a = (last.low + - last_bars + .iter() + .map(|x| x.low) + .fold(f64::INFINITY, f64::min)) + .abs() + <= f64::EPSILON; + let c1_b = (last.close + - last_bars + .iter() + .map(|x| x.close) + .fold(f64::INFINITY, f64::min)) + .abs() + <= f64::EPSILON; + let c2 = if right_bars.is_empty() { + false + } else { + let max_high = right_bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + max_high > gg && gg > 0.0 + }; + if (c1_a || c1_b) && c2 { + v1 = "看空"; + } + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_reversal_V230227:末根反转迹象判定 +/// +/// 参数模板:`"{freq}_D{di}A{avg_bp}_反转V230227"` +/// +/// 信号逻辑: +/// 1. 以末根K线形态(阴线/长上影、阳线/长下影)确定反转方向候选; +/// 2. 左侧13根满足 3/5/8 平均涨跌幅阈值,或 13 连阳/连阴,触发反向信号; +/// 3. 输出 `看多/看空/其他`。 +/// +/// 信号列表示例: +/// - `Signal('15分钟_D1A300_反转V230227_看多_任意_任意_0')` +/// - `Signal('15分钟_D1A300_反转V230227_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `avg_bp`:平均单根涨跌幅阈值(BP),默认 `300`。 +/// 对齐说明:触发条件与优先级对齐 Python `bar_reversal_V230227`。 +#[signal( + category = "kline", + name = "bar_reversal_V230227", + template = "{freq}_D{di}A{avg_bp}_反转V230227", + opcode = "BarReversalV230227", + param_kind = "BarReversalV230227" +)] +pub fn bar_reversal_v230227(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let avg_bp = params.usize("avg_bp", 300) as f64; + let k1 = c.freq.to_string(); + let k2 = format!("D{}A{}", di, avg_bp as usize); + let k3 = "反转V230227"; + + let bars = get_sub_elements(&c.bars_raw, di, 14); + if bars.len() != 14 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let last = &bars[bars.len() - 1]; + let left = &bars[..bars.len() - 1]; + let last_up_c1 = + last.close > last.open && bar_upper(last) > 2.0 * bar_solid(last).max(bar_lower(last)); + let last_dn_c1 = + last.close < last.open && bar_lower(last) > 2.0 * bar_solid(last).max(bar_upper(last)); + + let mut v1 = "其他"; + if last.close < last.open || last_up_c1 { + let up_c1 = + (left[left.len() - 1].close / left[left.len() - 3].open - 1.0) / 3.0 > avg_bp / 10000.0; + let up_c2 = + (left[left.len() - 1].close / left[left.len() - 5].open - 1.0) / 5.0 > avg_bp / 10000.0; + let up_c3 = + (left[left.len() - 1].close / left[left.len() - 8].open - 1.0) / 8.0 > avg_bp / 10000.0; + let up_c4 = left.iter().all(|x| x.close > x.open); + if up_c1 || up_c2 || up_c3 || up_c4 { + v1 = "看空"; + } + } + if last.close > last.open || last_dn_c1 { + let dn_c1 = (left[left.len() - 1].close / left[left.len() - 3].open - 1.0) / 3.0 + < -avg_bp / 10000.0; + let dn_c2 = (left[left.len() - 1].close / left[left.len() - 5].open - 1.0) / 5.0 + < -avg_bp / 10000.0; + let dn_c3 = (left[left.len() - 1].close / left[left.len() - 8].open - 1.0) / 8.0 + < -avg_bp / 10000.0; + let dn_c4 = left.iter().all(|x| x.close < x.open); + if dn_c1 || dn_c2 || dn_c3 || dn_c4 { + v1 = "看多"; + } + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_r_breaker_V230326:RBreaker 价格位判定 +/// +/// 参数模板:`"{freq}_RBreaker_BS辅助V230326"` +/// +/// 信号逻辑: +/// 1. 用前一根K线 `H/C/L` 计算突破位、观察位、反转位; +/// 2. 当前收盘突破上/下轨判 `趋势做多/趋势做空`; +/// 3. 满足观察后反转条件判 `反转做多/反转做空`,否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('日线_RBreaker_BS辅助V230326_做多_趋势_任意_0')` +/// - `Signal('日线_RBreaker_BS辅助V230326_做空_反转_任意_0')` +/// +/// 参数说明: +/// - 无额外参数。 +/// 对齐说明:六价位与判定顺序对齐 Python `bar_r_breaker_V230326`。 +#[signal( + category = "kline", + name = "bar_r_breaker_V230326", + template = "{freq}_RBreaker_BS辅助V230326", + opcode = "BarRBreakerV230326", + param_kind = "BarRBreakerV230326" +)] +pub fn bar_r_breaker_v230326(c: &CZSC, _params: &ParamView, _cache: &mut TaCache) -> Vec { + let k1 = c.freq.to_string(); + let k2 = "RBreaker"; + let k3 = "BS辅助V230326"; + if c.bars_raw.len() < 3 { + return make_kline_signal_v2(&k1, k2, k3, "其他", "其他"); + } + let prev = &c.bars_raw[c.bars_raw.len() - 2]; + let cur = &c.bars_raw[c.bars_raw.len() - 1]; + let h = prev.high; + let c0 = prev.close; + let l = prev.low; + let p = (h + c0 + l) / 3.0; + let break_buy = h + 2.0 * p - 2.0 * l; + let see_sell = p + h - l; + let verse_sell = 2.0 * p - l; + let verse_buy = 2.0 * p - h; + let see_buy = p - (h - l); + let break_sell = l - 2.0 * (h - p); + + let (v1, v2) = if cur.close > break_buy { + ("做多", "趋势") + } else if cur.close < break_sell { + ("做空", "趋势") + } else if cur.high > see_sell && cur.close < verse_sell { + ("做空", "反转") + } else if cur.low < see_buy && cur.close > verse_buy { + ("做多", "反转") + } else { + ("其他", "其他") + }; + make_kline_signal_v2(&k1, k2, k3, v1, v2) +} + +/// bar_dual_thrust_V230403:Dual Thrust 通道突破 +/// +/// 参数模板:`"{freq}_D{di}通道突破#{N}#{K1}#{K2}_BS辅助V230403"` +/// +/// 信号逻辑: +/// 1. 用前 `N+1` 根计算 `HH/HC/LC/LL` 与 `Range=max(HH-LC, HC-LL)`; +/// 2. 构造当根上/下轨:`open + Range*K1%`、`open - Range*K2%`; +/// 3. 收盘上破判 `看多`,下破判 `看空`,否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('日线_D1通道突破#5#20#20_BS辅助V230403_看多_任意_任意_0')` +/// - `Signal('日线_D1通道突破#5#20#20_BS辅助V230403_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `N`:回看天数,默认 `5`; +/// - `K1`:上轨系数(百分比),默认 `20`; +/// - `K2`:下轨系数(百分比),默认 `20`。 +/// 对齐说明:通道计算与突破判断对齐 Python `bar_dual_thrust_V230403`。 +#[signal( + category = "kline", + name = "bar_dual_thrust_V230403", + template = "{freq}_D{di}通道突破#{N}#{K1}#{K2}_BS辅助V230403", + opcode = "BarDualThrustV230403", + param_kind = "BarDualThrustV230403" +)] +pub fn bar_dual_thrust_v230403(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("N", params.usize("n", 5)); + let k1v = params.usize("K1", params.usize("k1", 20)); + let k2v = params.usize("K2", params.usize("k2", 20)); + let k1 = c.freq.to_string(); + let k2 = format!("D{}通道突破#{}#{}#{}", di, n, k1v, k2v); + let k3 = "BS辅助V230403"; + + if c.bars_raw.len() < 3 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, di + 1, n + 1); + if bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let hh = bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let hc = bars + .iter() + .map(|x| x.close) + .fold(f64::NEG_INFINITY, f64::max); + let lc = bars.iter().map(|x| x.close).fold(f64::INFINITY, f64::min); + let ll = bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let range = (hh - lc).max(hc - ll); + + let cur = &c.bars_raw[c.bars_raw.len() - di]; + let buy_line = cur.open + range * k1v as f64 / 100.0; + let sell_line = cur.open - range * k2v as f64 / 100.0; + let v1 = if cur.close > buy_line { + "看多" + } else if cur.close < sell_line { + "看空" + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +fn calc_tnr_series(c: &CZSC, timeperiod: usize) -> Vec { + if c.bars_raw.is_empty() { + return vec![]; + } + let mut out = vec![0.0; c.bars_raw.len()]; + for (i, out_i) in out.iter_mut().enumerate() { + if i < timeperiod { + *out_i = 0.0; + continue; + } + let start = i.saturating_sub(timeperiod); + let win = &c.bars_raw[start..=i]; + let mut sum_abs = 0.0; + for j in 1..win.len() { + sum_abs += (win[j].close - win[j - 1].close).abs(); + } + *out_i = if sum_abs == 0.0 { + 0.0 + } else { + (win[win.len() - 1].close - win[0].close).abs() / sum_abs + }; + } + out +} + +/// bar_tnr_V230630:TNR 噪音变化判定 +/// +/// 参数模板:`"{freq}_D{di}TNR{timeperiod}K{k}_趋势V230630"` +/// +/// 信号逻辑: +/// 1. 计算 TNR:`|close_t-close_{t-n}| / sum(|diff(close)|)`; +/// 2. 取最近 `k` 根 TNR 均值,与当前 TNR 比较; +/// 3. 当前值大于均值判 `噪音减少`,否则判 `噪音增加`。 +/// +/// 信号列表示例: +/// - `Signal('15分钟_D1TNR14K3_趋势V230630_噪音减少_任意_任意_0')` +/// - `Signal('15分钟_D1TNR14K3_趋势V230630_噪音增加_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `timeperiod`:TNR周期,默认 `14`; +/// - `k`:均值窗口,默认 `3`。 +/// 对齐说明:TNR与噪音方向定义对齐 Python `bar_tnr_V230630`。 +#[signal( + category = "kline", + name = "bar_tnr_V230630", + template = "{freq}_D{di}TNR{timeperiod}K{k}_趋势V230630", + opcode = "BarTnrV230630", + param_kind = "BarTnrV230630" +)] +pub fn bar_tnr_v230630(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let timeperiod = params.usize("timeperiod", 14); + let k = params.usize("k", 3); + let k1 = c.freq.to_string(); + let k2 = format!("D{}TNR{}K{}", di, timeperiod, k); + let k3 = "趋势V230630"; + + if c.bars_raw.len() < di + timeperiod + 8 || k == 0 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let tnr = calc_tnr_series(c, timeperiod); + let sub = get_sub_elements(tnr.as_slice(), di, k); + if sub.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let mean = sub.iter().sum::() / sub.len() as f64; + let delta_tnr = sub[sub.len() - 1] - mean; + let v1 = if delta_tnr > 0.0 { + "噪音减少" + } else { + "噪音增加" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_tnr_V230629:TNR 分层信号 +/// +/// 参数模板:`"{freq}_D{di}TNR{timeperiod}_趋势V230629"` +/// +/// 信号逻辑: +/// 1. 计算每根K线 TNR 值; +/// 2. 取最近100个 TNR 做 `qcut(10)`; +/// 3. 输出末根所在层:`第{n}层`。 +/// +/// 信号列表示例: +/// - `Signal('15分钟_D1TNR14_趋势V230629_第7层_任意_任意_0')` +/// - `Signal('15分钟_D1TNR14_趋势V230629_第2层_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `timeperiod`:TNR周期,默认 `14`。 +/// 对齐说明:分层逻辑与 `duplicates='drop'` 行为对齐 Python `bar_tnr_V230629`。 +#[signal( + category = "kline", + name = "bar_tnr_V230629", + template = "{freq}_D{di}TNR{timeperiod}_趋势V230629", + opcode = "BarTnrV230629", + param_kind = "BarTnrV230629" +)] +pub fn bar_tnr_v230629(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let timeperiod = params.usize("timeperiod", 14); + let k1 = c.freq.to_string(); + let k2 = format!("D{}TNR{}", di, timeperiod); + let k3 = "趋势V230629"; + + if c.bars_raw.len() < di + timeperiod + 8 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let tnr = calc_tnr_series(c, timeperiod); + let sub = get_sub_elements(tnr.as_slice(), di, 100); + if sub.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let v1 = if let Some(lev) = qcut_last_label(sub, 10) { + format!("第{}层", lev + 1) + } else { + "其他".to_string() + }; + make_kline_signal_v1(&k1, &k2, k3, &v1) +} + +/// bar_shuang_fei_V230507:双飞涨停形态 +/// +/// 参数模板:`"{freq}_D{di}双飞_短线V230507"` +/// +/// 信号逻辑: +/// 1. 前天近似涨停、昨天大阴回撤、今天再度强势上涨; +/// 2. 且今天收盘突破昨天高点,判 `看多`; +/// 3. 不满足返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('日线_D1双飞_短线V230507_看多_任意_任意_0')` +/// - `Signal('日线_D1双飞_短线V230507_其他_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`。 +/// 对齐说明:三日组合条件对齐 Python `bar_shuang_fei_V230507`。 +#[signal( + category = "kline", + name = "bar_shuang_fei_V230507", + template = "{freq}_D{di}双飞_短线V230507", + opcode = "BarShuangFeiV230507", + param_kind = "BarShuangFeiV230507" +)] +pub fn bar_shuang_fei_v230507(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}双飞", di); + let k3 = "短线V230507"; + + if c.bars_raw.len() < di + 10 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, di, 4); + if bars.len() != 4 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let b4 = &bars[0]; + let b3 = &bars[1]; + let b2 = &bars[2]; + let b1 = &bars[3]; + if b4.close == 0.0 || b3.close == 0.0 || b2.close == 0.0 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let first_zt = (b3.close - b3.high).abs() <= f64::EPSILON && b3.close / b4.close - 1.0 > 0.07; + let last_zt = + b1.close / b2.close - 1.0 > 0.07 && bar_upper(b1) < bar_lower(b1).max(bar_solid(b1)) / 2.0; + let bar2_down = b2.close < b2.open && b2.close / b3.close - 1.0 < -0.05; + let v1 = if first_zt && last_zt && b1.close > b2.high && bar2_down { + "看多" + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +fn std_pop(values: &[f64]) -> f64 { + if values.is_empty() || values.iter().any(|x| !x.is_finite()) { + return f64::NAN; + } + let mean = values.iter().sum::() / values.len() as f64; + let var = values.iter().map(|x| (x - mean).powi(2)).sum::() / values.len() as f64; + var.sqrt() +} + +fn qcut_labels(values: &[f64], q: usize) -> Option> { + if q == 0 || values.is_empty() || values.iter().any(|x| !x.is_finite()) { + return None; + } + let mut sorted: Vec = values.to_vec(); + sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + + let quantile = |p: f64| -> f64 { + if sorted.len() == 1 { + return sorted[0]; + } + let h = (sorted.len() - 1) as f64 * p; + let i = h.floor() as usize; + let j = h.ceil() as usize; + if i == j { + sorted[i] + } else { + sorted[i] + (h - i as f64) * (sorted[j] - sorted[i]) + } + }; + + let mut edges = Vec::with_capacity(q + 1); + for i in 0..=q { + edges.push(quantile(i as f64 / q as f64)); + } + edges.dedup_by(|a, b| (*a - *b).abs() <= f64::EPSILON); + if edges.len() <= 1 { + return None; + } + let bins = edges.len() - 1; + let mut labels = Vec::with_capacity(values.len()); + for &x in values { + if x < edges[0] || x > edges[bins] { + return None; + } + let mut found = None; + for i in 0..bins { + let left_ok = if i == 0 { x >= edges[i] } else { x > edges[i] }; + let right_ok = x <= edges[i + 1]; + if left_ok && right_ok { + found = Some(i); + break; + } + } + labels.push(found.unwrap_or(bins - 1)); + } + Some(labels) +} + +fn qcut_three_label_last(values: &[f64]) -> Option<&'static str> { + if values.is_empty() || values.iter().any(|x| !x.is_finite()) { + return None; + } + let mut sorted: Vec = values.to_vec(); + sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + + let quantile = |p: f64| -> f64 { + if sorted.len() == 1 { + return sorted[0]; + } + let h = (sorted.len() - 1) as f64 * p; + let i = h.floor() as usize; + let j = h.ceil() as usize; + if i == j { + sorted[i] + } else { + sorted[i] + (h - i as f64) * (sorted[j] - sorted[i]) + } + }; + + let mut edges = Vec::with_capacity(4); + for i in 0..=3 { + edges.push(quantile(i as f64 / 3.0)); + } + edges.dedup_by(|a, b| (*a - *b).abs() <= f64::EPSILON); + + // 对齐 pandas.qcut(labels=["低波动","中波动","高波动"], duplicates="drop"): + // 当重复边界导致箱数 < 3 时,会因为 labels 数量不匹配抛异常,外层返回“其他”。 + if edges.len() != 4 { + return None; + } + let x = *values.last()?; + if x < edges[0] || x > edges[3] { + return None; + } + if x >= edges[0] && x <= edges[1] { + return Some("低波动"); + } + if x > edges[1] && x <= edges[2] { + return Some("中波动"); + } + if x > edges[2] && x <= edges[3] { + return Some("高波动"); + } + None +} + +fn linear_slope_exact(y: &[f64]) -> Option { + if y.len() < 2 { + return None; + } + let n = y.len() as f64; + let sum_x = (n - 1.0) * n / 2.0; + let sum_xx = (n - 1.0) * n * (2.0 * n - 1.0) / 6.0; + let sum_y = y.iter().sum::(); + let sum_xy = y + .iter() + .enumerate() + .map(|(i, v)| i as f64 * *v) + .sum::(); + let denom = n * sum_xx - sum_x * sum_x; + if denom.abs() <= f64::EPSILON { + return None; + } + Some((n * sum_xy - sum_x * sum_y) / denom) +} + +fn solve_3x3(mut a: [[f64; 4]; 3]) -> Option<[f64; 3]> { + for i in 0..3 { + let mut pivot = i; + for r in (i + 1)..3 { + if a[r][i].abs() > a[pivot][i].abs() { + pivot = r; + } + } + if a[pivot][i].abs() <= f64::EPSILON { + return None; + } + if pivot != i { + a.swap(i, pivot); + } + let d = a[i][i]; + for value in a[i].iter_mut().skip(i) { + *value /= d; + } + for r in 0..3 { + if r == i { + continue; + } + let f = a[r][i]; + let pivot_row = a[i]; + for (value, pivot_value) in a[r].iter_mut().zip(pivot_row.iter()).skip(i) { + *value -= f * *pivot_value; + } + } + } + Some([a[0][3], a[1][3], a[2][3]]) +} + +fn quadratic_a_exact(y: &[f64]) -> Option { + if y.len() < 3 { + return None; + } + let n = y.len() as f64; + let sx = (n - 1.0) * n / 2.0; + let sx2 = (n - 1.0) * n * (2.0 * n - 1.0) / 6.0; + let sx3 = (n * (n - 1.0) / 2.0).powi(2); + let sx4 = (n - 1.0) * n * (2.0 * n - 1.0) * (3.0 * n * n - 3.0 * n - 1.0) / 30.0; + let sy = y.iter().sum::(); + let sxy = y + .iter() + .enumerate() + .map(|(i, v)| i as f64 * *v) + .sum::(); + let sx2y = y + .iter() + .enumerate() + .map(|(i, v)| (i as f64).powi(2) * *v) + .sum::(); + let aug = [[n, sx, sx2, sy], [sx, sx2, sx3, sxy], [sx2, sx3, sx4, sx2y]]; + let coef = solve_3x3(aug)?; + Some(coef[2]) +} + +fn current_ubi_high_low(c: &CZSC) -> Option<(f64, f64)> { + if c.bars_ubi.is_empty() || c.bi_list.is_empty() { + return None; + } + let raw: Vec<&czsc_core::objects::bar::RawBar> = + c.bars_ubi.iter().flat_map(|x| x.elements.iter()).collect(); + if raw.is_empty() { + return None; + } + let high = raw.iter().map(|x| x.high).fold(f64::NEG_INFINITY, f64::max); + let low = raw.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + Some((high, low)) +} + +/// bar_fang_liang_break_V221216:放量突破与缩量回踩 +/// +/// 参数模板:`"{freq}_D{di}TH{th}#{ma_type}#{timeperiod}_突破V221216"` +/// +/// 信号逻辑: +/// 1. 计算指定均线,检查末根是否放量且站上均线,判 `放量突破`; +/// 2. 检查末根是否缩量且收盘不破均线,且前序收盘与均线距离在阈值内,判 `缩量回踩`; +/// 3. 在窗口长度 `5~9` 中依次尝试,首次出现突破即返回。 +/// +/// 信号列表示例: +/// - `Signal('15分钟_D1TH300#SMA#233_突破V221216_放量突破_缩量回踩_任意_0')` +/// - `Signal('15分钟_D1TH300#SMA#233_突破V221216_其他_其他_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `th`:回踩距离阈值(BP),默认 `300`; +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod`:均线周期,默认 `233`。 +/// 对齐说明:窗口扫描与两阶段条件对齐 Python `bar_fang_liang_break_V221216`。 +#[signal( + category = "kline", + name = "bar_fang_liang_break_V221216", + template = "{freq}_D{di}TH{th}#{ma_type}#{timeperiod}_突破V221216", + opcode = "BarFangLiangBreakV221216", + param_kind = "BarFangLiangBreakV221216" +)] +pub fn bar_fang_liang_break_v221216( + c: &CZSC, + params: &ParamView, + cache: &mut TaCache, +) -> Vec { + let di = params.usize("di", 1); + let th = params.usize("th", 300); + let ma_type = params.str("ma_type", "SMA").to_uppercase(); + let timeperiod = params.usize("timeperiod", 233); + let k1 = c.freq.to_string(); + let k2 = format!("D{}TH{}#{}#{}", di, th, ma_type, timeperiod); + let k3 = "突破V221216"; + + let cache_key = format!("{}_{}_{}", c.freq, ma_type, timeperiod); + update_ma_cache(c, &cache_key, &ma_type, timeperiod, cache); + let Some(ma) = cache.series.get(&cache_key) else { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + }; + let id_idx: HashMap = c + .bars_raw + .iter() + .enumerate() + .map(|(i, b)| (b.id, i)) + .collect(); + + let mut v1 = "其他"; + let mut v2 = "其他"; + let base = if c.bars_raw.len() > 300 { + &c.bars_raw[300..] + } else { + &[] as &[czsc_core::objects::bar::RawBar] + }; + + for n in 5..=9 { + let bars = get_sub_elements(base, di, n); + if bars.len() <= 4 { + v1 = "其他"; + v2 = "其他"; + continue; + } + let last = &bars[bars.len() - 1]; + let prev = &bars[bars.len() - 2]; + let Some(&idx) = id_idx.get(&last.id) else { + continue; + }; + let ma1v = ma[idx]; + v1 = if last.vol >= prev.vol && last.close > ma1v { + "放量突破" + } else { + "其他" + }; + + let vol_min = + bars[..bars.len() - 1].iter().map(|x| x.vol).sum::() / (bars.len() - 1) as f64; + let distance = if ma1v.abs() <= f64::EPSILON { + false + } else { + bars[..bars.len() - 1] + .iter() + .all(|x| ((x.close / ma1v - 1.0).abs() * 10000.0) <= th as f64) + }; + v2 = if last.close >= ma1v && last.vol < vol_min && distance { + "缩量回踩" + } else { + "其他" + }; + if v1 != "其他" { + break; + } + } + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// bar_channel_V230508:窄幅通道方向判定 +/// +/// 参数模板:`"{freq}_D{di}M{m}_通道V230507"` +/// +/// 信号逻辑: +/// 1. 窗口内每根K线涨跌幅需不超过 `m` BP; +/// 2. 对高点和低点分别做一元线性拟合,要求 `r2 > 0.8`; +/// 3. 双斜率同向且右侧极值确认,判 `看多/看空`,否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('日线_D1M600_通道V230507_看多_任意_任意_0')` +/// - `Signal('日线_D1M600_通道V230507_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `n`:窗口长度,默认 `20`; +/// - `m`:单根波动阈值(BP),默认 `600`。 +/// 对齐说明:拟合阈值和右侧极值规则对齐 Python `bar_channel_V230508`。 +#[signal( + category = "kline", + name = "bar_channel_V230508", + template = "{freq}_D{di}M{m}_通道V230508", + opcode = "BarChannelV230508", + param_kind = "BarChannelV230508" +)] +pub fn bar_channel_v230508(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 20); + let m = params.usize("m", 600); + let k1 = c.freq.to_string(); + let k2 = format!("D{}M{}", di, m); + let k3 = "通道V230507"; + + if c.bars_raw.len() < di + 10 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, di, n); + if bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + if bars + .iter() + .any(|x| x.open == 0.0 || ((x.close / x.open - 1.0).abs() * 10000.0 > m as f64)) + { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let highs: Vec = bars.iter().map(|x| x.high).collect(); + let lows: Vec = bars.iter().map(|x| x.low).collect(); + let res_high = highs.as_slice().single_linear(); + let res_low = lows.as_slice().single_linear(); + if !(res_high.r2 > 0.8 && res_low.r2 > 0.8) { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let tail = bars.len().min(3); + let high_right = bars[bars.len() - tail..] + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let low_right = bars[bars.len() - tail..] + .iter() + .map(|x| x.low) + .fold(f64::INFINITY, f64::min); + let max_high = bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let min_low = bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let v1 = if res_high.slope > 0.0 + && res_low.slope > 0.0 + && (high_right - max_high).abs() <= f64::EPSILON + { + "看多" + } else if res_high.slope < 0.0 + && res_low.slope < 0.0 + && (low_right - min_low).abs() <= f64::EPSILON + { + "看空" + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_eight_V230702:8K 走势分类 +/// +/// 参数模板:`"{freq}_D{di}#8K_走势分类V230702"` +/// +/// 信号逻辑: +/// 1. 统计8K中的三连K重叠中枢; +/// 2. 无中枢时输出 `无中枢上涨/无中枢下跌`; +/// 3. 双中枢满足不重叠时输出 `双中枢上涨/双中枢下跌`; +/// 4. 其余按前三根是否出现极值分为 `强平衡/弱平衡/转折平衡`。 +/// +/// 信号列表示例: +/// - `Signal('30分钟_D1#8K_走势分类V230702_双中枢上涨_任意_任意_0')` +/// - `Signal('30分钟_D1#8K_走势分类V230702_转折平衡市_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`。 +/// 对齐说明:中枢判定与分类分支顺序对齐 Python `bar_eight_V230702`。 +#[signal( + category = "kline", + name = "bar_eight_V230702", + template = "{freq}_D{di}#8K_走势分类V230702", + opcode = "BarEightV230702", + param_kind = "BarEightV230702" +)] +pub fn bar_eight_v230702(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}#8K", di); + let k3 = "走势分类V230702"; + if c.bars_raw.len() < di + 12 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, di, 8); + if bars.len() != 8 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let mut zs_list: Vec<[usize; 3]> = Vec::new(); + for i in 0..=bars.len() - 3 { + let b1 = &bars[i]; + let b2 = &bars[i + 1]; + let b3 = &bars[i + 2]; + if b1.high.min(b2.high).min(b3.high) >= b1.low.max(b2.low).max(b3.low) { + zs_list.push([i, i + 1, i + 2]); + } + } + let dir = if bars[bars.len() - 1].close > bars[0].open { + "上涨" + } else { + "下跌" + }; + if zs_list.is_empty() { + let v = format!("无中枢{}", dir); + return make_kline_signal_v1(&k1, &k2, k3, &v); + } + + if zs_list.len() >= 2 { + let zs1 = zs_list[0]; + let zs2 = zs_list[zs_list.len() - 1]; + let zs1_high = [bars[zs1[0]].high, bars[zs1[1]].high, bars[zs1[2]].high] + .iter() + .copied() + .fold(f64::NEG_INFINITY, f64::max); + let zs1_low = [bars[zs1[0]].low, bars[zs1[1]].low, bars[zs1[2]].low] + .iter() + .copied() + .fold(f64::INFINITY, f64::min); + let zs2_high = [bars[zs2[0]].high, bars[zs2[1]].high, bars[zs2[2]].high] + .iter() + .copied() + .fold(f64::NEG_INFINITY, f64::max); + let zs2_low = [bars[zs2[0]].low, bars[zs2[1]].low, bars[zs2[2]].low] + .iter() + .copied() + .fold(f64::INFINITY, f64::min); + if dir == "上涨" && zs1_high < zs2_low { + let v = format!("双中枢{}", dir); + return make_kline_signal_v1(&k1, &k2, k3, &v); + } + if dir == "下跌" && zs1_low > zs2_high { + let v = format!("双中枢{}", dir); + return make_kline_signal_v1(&k1, &k2, k3, &v); + } + } + + let max_high = bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let min_low = bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let high_first = bars[0].high.max(bars[1].high).max(bars[2].high) == max_high; + let low_first = bars[0].low.min(bars[1].low).min(bars[2].low) == min_low; + let v1 = if high_first && !low_first { + "弱平衡市" + } else if low_first && !high_first { + "强平衡市" + } else { + "转折平衡市" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_window_std_V230731:窗口波动分层特征 +/// +/// 参数模板:`"{freq}_D{di}W{w}M{m}N{n}_窗口波动V230731"` +/// +/// 信号逻辑: +/// 1. 计算每根K线的 `STD20`(前20收盘标准差); +/// 2. 取最近 `m` 个 `STD20` 做 `qcut(n)` 分层; +/// 3. 输出最近 `w` 根中的最大层和最小层。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1W5M100N10_窗口波动V230731_高波N8_低波N6_任意_0')` +/// - `Signal('60分钟_D1W5M100N10_窗口波动V230731_高波N4_低波N3_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `w`:观察窗口,默认 `5`; +/// - `m`:分层样本长度,默认 `100`; +/// - `n`:分层数量,默认 `10`。 +/// 对齐说明:STD20口径与 `qcut(..., duplicates='drop')` 对齐 Python `bar_window_std_V230731`。 +#[signal( + category = "kline", + name = "bar_window_std_V230731", + template = "{freq}_D{di}W{w}M{m}N{n}_窗口波动V230731", + opcode = "BarWindowStdV230731", + param_kind = "BarWindowStdV230731" +)] +pub fn bar_window_std_v230731(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let w = params.usize("w", 5); + let m = params.usize("m", 100); + let n = params.usize("n", 10); + let k1 = c.freq.to_string(); + let k2 = format!("D{}W{}M{}N{}", di, w, m, n); + let k3 = "窗口波动V230731"; + + if c.bars_raw.len() < di + m + w { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let closes: Vec = c.bars_raw.iter().map(|x| x.close).collect(); + let mut std20 = vec![0.0; closes.len()]; + for i in 0..closes.len() { + std20[i] = if i < 5 { + 0.0 + } else { + std_pop(&closes[i.saturating_sub(20)..i]) + }; + } + let stds = get_sub_elements(std20.as_slice(), di, m); + if stds.len() != m { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let Some(layer) = qcut_labels(stds, n) else { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + }; + let t = layer.len().min(w); + let tail = &layer[layer.len() - t..]; + let max_layer = tail.iter().copied().max().unwrap_or(0) + 1; + let min_layer = tail.iter().copied().min().unwrap_or(0) + 1; + let v1 = format!("高波N{}", max_layer); + let v2 = format!("低波N{}", min_layer); + make_kline_signal_v2(&k1, &k2, k3, &v1, &v2) +} + +/// bar_window_ps_V230731:支撑压力位分位特征 +/// +/// 参数模板:`"{freq}_W{w}M{m}N{n}L{l}_支撑压力位V230731"` +/// +/// 信号逻辑: +/// 1. 用最近 `n` 笔高低点构造压力线与支撑线; +/// 2. 计算收盘在区间中的位置 `pct=(close-L)/(H-L)`; +/// 3. 对最近 `m` 个 `pct` 做 `qcut(l)`,输出最近 `w` 根的压力/支撑层与当前层。 +/// +/// 信号列表示例: +/// - `Signal('15分钟_W5M40N8L5_支撑压力位V230731_压力N5_支撑N3_当前N4_0')` +/// - `Signal('15分钟_W5M40N8L5_支撑压力位V230731_压力N4_支撑N1_当前N2_0')` +/// +/// 参数说明: +/// - `w`:观察窗口,默认 `5`; +/// - `m`:分位样本长度,默认 `40`; +/// - `n`:笔窗口长度,默认 `8`; +/// - `l`:分层数量,默认 `5`。 +/// 对齐说明:参数约束与分位定义对齐 Python `bar_window_ps_V230731`。 +#[signal( + category = "kline", + name = "bar_window_ps_V230731", + template = "{freq}_W{w}M{m}N{n}L{l}_支撑压力位V230731", + opcode = "BarWindowPsV230731", + param_kind = "BarWindowPsV230731" +)] +pub fn bar_window_ps_v230731(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let w = params.usize("w", 5); + let m = params.usize("m", 40); + let n = params.usize("n", 8); + let l = params.usize("l", 5); + if !(m > l * 2 && l > 2) || w >= m { + return vec![]; + } + + let k1 = c.freq.to_string(); + let k2 = format!("W{}M{}N{}L{}", w, m, n, l); + let k3 = "支撑压力位V230731"; + if c.bi_list.len() < n + 2 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bis = &c.bi_list[c.bi_list.len() - n..]; + let h_line = bis + .iter() + .map(|x| x.get_high()) + .fold(f64::NEG_INFINITY, f64::max); + let l_line = bis + .iter() + .map(|x| x.get_low()) + .fold(f64::INFINITY, f64::min); + let d = h_line - l_line; + if d.abs() <= f64::EPSILON { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + // 对齐 Python: bar.cache 仅对“尚未写入该键”的 bar 做一次写入,历史值不回溯覆盖。 + let pct_key = format!("bar_window_ps_V230731#pct#W{}M{}N{}L{}", w, m, n, l); + let bar_ids: Vec = c.bars_raw.iter().map(|b| b.id).collect(); + let mut pct_series = if let (Some(vals), Some(ids)) = + (cache.series.get(&pct_key), cache.series_ids.get(&pct_key)) + { + let mut id_val: HashMap = HashMap::with_capacity(ids.len()); + for (i, id) in ids.iter().enumerate() { + id_val.insert(*id, vals[i]); + } + bar_ids + .iter() + .map(|id| *id_val.get(id).unwrap_or(&f64::NAN)) + .collect::>() + } else { + vec![f64::NAN; bar_ids.len()] + }; + for (i, bar) in c.bars_raw.iter().enumerate() { + let is_last = i + 1 == c.bars_raw.len(); + if is_last || !pct_series[i].is_finite() { + // 对齐 Python 多频流式更新语义:最后一根未完成 bar 在每次 on_bar 时都会携带最新 close, + // 不能复用更早时刻写入的 pct;历史 bar 仍保持“只写一次”。 + pct_series[i] = (bar.close - l_line) / d; + } + } + cache.series.insert(pct_key.clone(), pct_series.clone()); + cache.series_ids.insert(pct_key, bar_ids); + + let fenwei = get_sub_elements(pct_series.as_slice(), 1, m); + if fenwei.len() != m { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let Some(layer) = qcut_labels(fenwei, l) else { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + }; + let t = layer.len().min(w); + let tail = &layer[layer.len() - t..]; + let max_layer = tail.iter().copied().max().unwrap_or(0) + 1; + let min_layer = tail.iter().copied().min().unwrap_or(0) + 1; + let cur_layer = layer[layer.len() - 1] + 1; + let v1 = format!("压力N{}", max_layer); + let v2 = format!("支撑N{}", min_layer); + let v3 = format!("当前N{}", cur_layer); + make_kline_signal_v3(&k1, &k2, k3, &v1, &v2, &v3) +} + +/// bar_window_ps_V230801:支撑压力位窗口极值 +/// +/// 参数模板:`"{freq}_N{n}W{w}_支撑压力位V230801"` +/// +/// 信号逻辑: +/// 1. 基于最近 `n` 笔和当前未完成笔计算压力/支撑区间; +/// 2. 将最近 `w` 根收盘映射到 `0~9` 分位整数; +/// 3. 输出窗口最大/最小/当前分位。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_N8W5_支撑压力位V230801_最大N7_最小N3_当前N5_0')` +/// - `Signal('60分钟_N8W5_支撑压力位V230801_最大N4_最小N0_当前N2_0')` +/// +/// 参数说明: +/// - `w`:观察窗口,默认 `5`; +/// - `n`:笔窗口长度,默认 `8`。 +/// 对齐说明:`ubi` 口径与整数分位映射对齐 Python `bar_window_ps_V230801`。 +#[signal( + category = "kline", + name = "bar_window_ps_V230801", + template = "{freq}_N{n}W{w}_支撑压力位V230801", + opcode = "BarWindowPsV230801", + param_kind = "BarWindowPsV230801" +)] +pub fn bar_window_ps_v230801(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let w = params.usize("w", 5); + let n = params.usize("n", 8); + let k1 = c.freq.to_string(); + let k2 = format!("N{}W{}", n, w); + let k3 = "支撑压力位V230801"; + + // 对齐 Python `if len(c.bi_list) < n + 2 or not ubi: return 其他`: + // Rust 中 ubi 判空需同时满足 bars_ubi/bi_list 非空且 ubi_fxs 非空。 + let Some(ubi_fxs) = c.get_ubi_fxs() else { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + }; + if ubi_fxs.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let Some((ubi_high, ubi_low)) = current_ubi_high_low(c) else { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + }; + if c.bi_list.len() < n + 2 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bis = &c.bi_list[c.bi_list.len() - n..]; + let h_line = bis.iter().map(|x| x.get_high()).fold(ubi_high, f64::max); + let l_line = bis.iter().map(|x| x.get_low()).fold(ubi_low, f64::min); + let d = h_line - l_line; + if d.abs() <= f64::EPSILON { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = if c.bars_raw.len() >= w { + &c.bars_raw[c.bars_raw.len() - w..] + } else { + &c.bars_raw[..] + }; + if bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let pcts: Vec = bars + .iter() + .map(|x| (((x.close - l_line) / d).max(0.0) * 10.0) as i32) + .collect(); + let v1 = format!("最大N{}", pcts.iter().copied().max().unwrap_or(0)); + let v2 = format!("最小N{}", pcts.iter().copied().min().unwrap_or(0)); + let v3 = format!("当前N{}", pcts[pcts.len() - 1]); + make_kline_signal_v3(&k1, &k2, k3, &v1, &v2, &v3) +} + +/// bar_trend_V240209:趋势跟踪结构判定 +/// +/// 参数模板:`"{freq}_D{di}N{N}趋势跟踪_BS辅助V240209"` +/// +/// 信号逻辑: +/// 1. 在窗口内定位最高点和最低点,结合其先后顺序选择多头或空头分支; +/// 2. 右侧结构满足 `5 Vec { + let di = params.usize("di", 1); + let n = params.usize("N", 60); + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}趋势跟踪", di, n); + let k3 = "BS辅助V240209"; + if c.bars_raw.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let cache_key = "MACD12#26#9"; + update_macd_cache(c, cache_key, 12, 26, 9, cache); + let Some(macd) = cache.macd.get(cache_key) else { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + }; + let bars = get_sub_elements(&c.bars_raw, di, n); + if bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + // 对齐 Python min/max(key=...): 并列时取首次出现。 + let mut max_bar = &bars[0]; + let mut min_bar = &bars[0]; + for b in &bars[1..] { + if b.high > max_bar.high { + max_bar = b; + } + if b.low < min_bar.low { + min_bar = b; + } + } + let id_to_idx: HashMap = + macd.ids.iter().enumerate().map(|(i, x)| (*x, i)).collect(); + let dif_vals: Vec = bars + .iter() + .filter_map(|x| id_to_idx.get(&x.id).map(|i| macd.dif[*i])) + .collect(); + let macd_vals: Vec = bars + .iter() + .filter_map(|x| id_to_idx.get(&x.id).map(|i| macd.macd[*i])) + .collect(); + let dif_std = std_pop(&dif_vals); + let macd_std = std_pop(&macd_vals); + + if min_bar.dt < max_bar.dt { + let right: Vec<&czsc_core::objects::bar::RawBar> = + c.bars_raw.iter().filter(|x| x.dt >= max_bar.dt).collect(); + if !right.is_empty() { + let mut right_min = right[0]; + for b in &right[1..] { + if b.low < right_min.low { + right_min = b; + } + } + let last = right[right.len() - 1]; + let c1 = right_min.id - max_bar.id < 30 && right_min.id - max_bar.id > 5; + let c2 = id_to_idx + .get(&last.id) + .map(|i| macd.dif[*i].abs() < dif_std) + .unwrap_or(false); + let c3 = right_min.low > min_bar.low; + let c4 = id_to_idx + .get(&last.id) + .map(|i| macd.macd[*i].abs() < macd_std) + .unwrap_or(false); + if c1 && c2 && c3 && c4 { + return make_kline_signal_v1(&k1, &k2, k3, "多头"); + } + } + } + + if min_bar.dt > max_bar.dt { + let right: Vec<&czsc_core::objects::bar::RawBar> = + c.bars_raw.iter().filter(|x| x.dt >= min_bar.dt).collect(); + if !right.is_empty() { + let mut right_max = right[0]; + for b in &right[1..] { + if b.high > right_max.high { + right_max = b; + } + } + let last = right[right.len() - 1]; + let c1 = right_max.id - min_bar.id < 30 && right_max.id - min_bar.id > 5; + let c2 = id_to_idx + .get(&last.id) + .map(|i| macd.dif[*i].abs() < dif_std) + .unwrap_or(false); + let c3 = right_max.high < max_bar.high; + let c4 = id_to_idx + .get(&last.id) + .map(|i| macd.macd[*i].abs() < macd_std) + .unwrap_or(false); + if c1 && c2 && c3 && c4 { + return make_kline_signal_v1(&k1, &k2, k3, "空头"); + } + } + } + + make_kline_signal_v1(&k1, &k2, k3, "其他") +} + +/// bar_plr_V240427:盈亏比约束 +/// +/// 参数模板:`"{freq}_D{di}W{w}T{t}M{m}_盈亏比V240427"` +/// +/// 信号逻辑: +/// 1. `多头`:以窗口最低点前的最高点与当前收盘计算盈亏比; +/// 2. `空头`:以窗口最高点前的最低点与当前收盘计算盈亏比; +/// 3. `plr > t/10` 判 `满足`,否则 `不满足`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1W60T20M多头_盈亏比V240427_满足_任意_任意_0')` +/// - `Signal('60分钟_D1W60T20M空头_盈亏比V240427_不满足_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `w`:窗口长度,默认 `60`; +/// - `t`:阈值(`t/10`),默认 `20`; +/// - `m`:方向,`多头/空头`,默认 `多头`。 +/// 对齐说明:盈亏比定义与阈值比较对齐 Python `bar_plr_V240427`。 +#[signal( + category = "kline", + name = "bar_plr_V240427", + template = "{freq}_D{di}W{w}T{t}M{m}_盈亏比V240427", + opcode = "BarPlrV240427", + param_kind = "BarPlrV240427" +)] +pub fn bar_plr_v240427(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let w = params.usize("w", 60); + let t = params.usize("t", 20); + let m = params.str("m", "多头"); + if m != "多头" && m != "空头" { + return vec![]; + } + if di == 0 || w < 10 { + return vec![]; + } + let k1 = c.freq.to_string(); + let k2 = format!("D{}W{}T{}M{}", di, w, t, m); + let k3 = "盈亏比V240427"; + if c.bars_raw.len() < 7 + w { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, di, w); + if bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let last_close = bars[bars.len() - 1].close; + let mut v1 = "其他"; + + if m == "多头" { + let (idx_low, low_bar) = bars + .iter() + .enumerate() + .min_by(|a, b| { + a.1.low + .partial_cmp(&b.1.low) + .unwrap_or(std::cmp::Ordering::Equal) + }) + .unwrap(); + if idx_low == 0 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let high_bar = bars[..idx_low] + .iter() + .max_by(|a, b| { + a.high + .partial_cmp(&b.high) + .unwrap_or(std::cmp::Ordering::Equal) + }) + .unwrap(); + let profit = high_bar.high - last_close; + let loss = last_close - low_bar.low; + let plr = if loss > 0.0 { profit / loss } else { 0.0 }; + v1 = if plr > t as f64 / 10.0 { + "满足" + } else { + "不满足" + }; + } + + if m == "空头" { + let (idx_high, high_bar) = bars + .iter() + .enumerate() + .max_by(|a, b| { + a.1.high + .partial_cmp(&b.1.high) + .unwrap_or(std::cmp::Ordering::Equal) + }) + .unwrap(); + if idx_high == 0 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let low_bar = bars[..idx_high] + .iter() + .min_by(|a, b| { + a.low + .partial_cmp(&b.low) + .unwrap_or(std::cmp::Ordering::Equal) + }) + .unwrap(); + let profit = last_close - low_bar.low; + let loss = high_bar.high - last_close; + let plr = if loss > 0.0 { profit / loss } else { 0.0 }; + v1 = if plr > t as f64 / 10.0 { + "满足" + } else { + "不满足" + }; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_polyfit_V240428:一阶二阶拟合分类 +/// +/// 参数模板:`"{freq}_D{di}W{w}_分类V240428"` +/// +/// 信号逻辑: +/// 1. 对窗口收盘价做一阶拟合取斜率 `p1`; +/// 2. 做二阶拟合取二次项系数 `p2`; +/// 3. 按 `p1/p2` 符号组合输出 `加速上涨/减速上涨/加速下跌/减速下跌`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1W20_分类V240428_加速上涨_任意_任意_0')` +/// - `Signal('60分钟_D1W20_分类V240428_减速下跌_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `w`:窗口长度,默认 `20`。 +/// 对齐说明:一二阶系数组合分类对齐 Python `bar_polyfit_V240428`。 +#[signal( + category = "kline", + name = "bar_polyfit_V240428", + template = "{freq}_D{di}W{w}_分类V240428", + opcode = "BarPolyfitV240428", + param_kind = "BarPolyfitV240428" +)] +pub fn bar_polyfit_v240428(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let w = params.usize("w", 20); + if di == 0 || w < 10 { + return vec![]; + } + let k1 = c.freq.to_string(); + let k2 = format!("D{}W{}", di, w); + let k3 = "分类V240428"; + if c.bars_raw.len() < 7 + w { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, di, w); + if bars.len() != w { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let close: Vec = bars.iter().map(|x| x.close).collect(); + let p1 = linear_slope_exact(&close).unwrap_or(0.0); + let p2 = quadratic_a_exact(&close).unwrap_or(0.0); + let v1 = if p1 > 0.0 && p2 > 0.0 { + "加速上涨" + } else if p1 < 0.0 && p2 < 0.0 { + "加速下跌" + } else if p1 > 0.0 && p2 < 0.0 { + "减速上涨" + } else if p1 < 0.0 && p2 > 0.0 { + "减速下跌" + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_break_V240428:收盘极值突破 +/// +/// 参数模板:`"{freq}_D{di}W{w}_事件V240428"` +/// +/// 信号逻辑: +/// 1. 在窗口内比较末根收盘与前序最高/最低; +/// 2. 收盘高于前序最高判 `收盘新高`; +/// 3. 收盘低于前序最低判 `收盘新低`,否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1W20_事件V240428_收盘新高_任意_任意_0')` +/// - `Signal('60分钟_D1W20_事件V240428_收盘新低_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `w`:窗口长度,默认 `20`。 +/// 对齐说明:极值比较区间与 Python `bar_break_V240428` 一致。 +#[signal( + category = "kline", + name = "bar_break_V240428", + template = "{freq}_D{di}W{w}_事件V240428", + opcode = "BarBreakV240428", + param_kind = "BarBreakV240428" +)] +pub fn bar_break_v240428(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let w = params.usize("w", 20); + if di == 0 || w < 10 { + return vec![]; + } + let k1 = c.freq.to_string(); + let k2 = format!("D{}W{}", di, w); + let k3 = "事件V240428"; + if c.bars_raw.len() < 7 + w { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, di, w); + if bars.len() < 2 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let last = &bars[bars.len() - 1]; + let prev_high = bars[..bars.len() - 1] + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let prev_low = bars[..bars.len() - 1] + .iter() + .map(|x| x.low) + .fold(f64::INFINITY, f64::min); + let v1 = if last.close > prev_high { + "收盘新高" + } else if last.close < prev_low { + "收盘新低" + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_classify_V240606:单根K线收盘位置分类 +/// +/// 参数模板:`"{freq}_D{di}收盘位置_分类V240606"` +/// +/// 信号逻辑: +/// 1. 将K线高低区间三等分; +/// 2. 收盘落在上三分之一判 `高位`; +/// 3. 收盘落在下三分之一判 `低位`,否则 `中间`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1收盘位置_分类V240606_高位_任意_任意_0')` +/// - `Signal('60分钟_D1收盘位置_分类V240606_中间_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`。 +/// 对齐说明:三分位阈值与 Python `bar_classify_V240606` 一致。 +#[signal( + category = "kline", + name = "bar_classify_V240606", + template = "{freq}_D{di}收盘位置_分类V240606", + opcode = "BarClassifyV240606", + param_kind = "BarClassifyV240606" +)] +pub fn bar_classify_v240606(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + if di == 0 { + return vec![]; + } + let k1 = c.freq.to_string(); + let k2 = format!("D{}收盘位置", di); + let k3 = "分类V240606"; + if c.bars_raw.len() < 7 + di { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bar = &c.bars_raw[c.bars_raw.len() - di]; + let gap = (bar.high - bar.low) / 3.0; + let v1 = if bar.close > (bar.high - gap) { + "高位" + } else if bar.close < (bar.low + gap) { + "低位" + } else { + "中间" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_classify_V240607:两根K线收盘位置分类 +/// +/// 参数模板:`"{freq}_D{di}K2收盘位置_分类V240607"` +/// +/// 信号逻辑: +/// 1. 取最近两根K线(截至 `di`); +/// 2. 第二根收盘高于第一根最高判 `看多`; +/// 3. 第二根收盘低于第一根最低判 `看空`,否则 `中性`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1K2收盘位置_分类V240607_看多_任意_任意_0')` +/// - `Signal('60分钟_D1K2收盘位置_分类V240607_中性_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`。 +/// 对齐说明:两根K线比较规则与 Python `bar_classify_V240607` 一致。 +#[signal( + category = "kline", + name = "bar_classify_V240607", + template = "{freq}_D{di}K2收盘位置_分类V240607", + opcode = "BarClassifyV240607", + param_kind = "BarClassifyV240607" +)] +pub fn bar_classify_v240607(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + if di == 0 { + return vec![]; + } + let k1 = c.freq.to_string(); + let k2 = format!("D{}K2收盘位置", di); + let k3 = "分类V240607"; + if c.bars_raw.len() < 7 + di { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, di, 2); + if bars.len() != 2 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bar1 = &bars[0]; + let bar2 = &bars[1]; + let v1 = if bar2.close > bar1.high { + "看多" + } else if bar2.close < bar1.low { + "看空" + } else { + "中性" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_decision_V240608:放量反向决策区 +/// +/// 参数模板:`"{freq}_W{w}N{n}Q{q}放量_决策区域V240608"` +/// +/// 信号逻辑: +/// 1. 在最近 `n` 根中取成交量最大的3根; +/// 2. 若三者成交量都大于最近 `w` 根的 `q` 分位; +/// 3. 且 `n` 窗口净涨则 `看空`,净跌则 `看多`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_W300N10Q80放量_决策区域V240608_看空_任意_任意_0')` +/// - `Signal('60分钟_W300N10Q80放量_决策区域V240608_看多_任意_任意_0')` +/// +/// 参数说明: +/// - `w`:长窗口,默认 `300`; +/// - `n`:短窗口,默认 `10`; +/// - `q`:成交量分位阈值(0-100),默认 `80`。 +/// 对齐说明:分位阈值与反向判定对齐 Python `bar_decision_V240608`。 +#[signal( + category = "kline", + name = "bar_decision_V240608", + template = "{freq}_W{w}N{n}Q{q}放量_决策区域V240608", + opcode = "BarDecisionV240608", + param_kind = "BarDecisionV240608" +)] +pub fn bar_decision_v240608(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let w = params.usize("w", 300); + let n = params.usize("n", 10); + let q = params.usize("q", 80); + if !(w > n && n > 3) { + return vec![]; + } + let k1 = c.freq.to_string(); + let k2 = format!("W{}N{}Q{}放量", w, n, q); + let k3 = "决策区域V240608"; + if c.bars_raw.len() < w + n { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let w_bars = get_sub_elements(&c.bars_raw, 1, w); + let n_bars = get_sub_elements(&c.bars_raw, 1, n); + if w_bars.len() != w || n_bars.len() != n { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let n_diff = n_bars[n_bars.len() - 1].close - n_bars[0].open; + let mut top3: Vec = n_bars.iter().map(|x| x.vol).collect(); + top3.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal)); + top3.truncate(3); + let vols: Vec = w_bars.iter().map(|x| x.vol).collect(); + let qth = percentile_linear(&vols, q as f64).unwrap_or(f64::INFINITY); + let vol_match = top3.iter().all(|x| *x > qth); + let v1 = if vol_match && n_diff > 0.0 { + "看空" + } else if vol_match && n_diff < 0.0 { + "看多" + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +fn close_position_label(bar: &czsc_core::objects::bar::RawBar) -> &'static str { + let hl = bar.high - bar.low; + let t1 = bar.low + hl * (2.0 / 3.0); + let t2 = bar.low + hl * (1.0 / 3.0); + if bar.close > t1 { + "高位收盘" + } else if bar.close > t2 && bar.close < t1 { + "中位收盘" + } else { + "低位收盘" + } +} + +/// bar_decision_V240616:新高新低后的强弱决策 +/// +/// 参数模板:`"{freq}_W{w}N{n}强弱_决策区域V240616"` +/// +/// 信号逻辑: +/// 1. 用 `di=n` 的 `w` 窗口给出历史新高/新低参考; +/// 2. 在最近 `n` 根中过滤出大实体K线并按顺序检查其右侧K线; +/// 3. 新高后转弱判 `看空`,新低后转强判 `看多`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_W100N5强弱_决策区域V240616_看空_任意_任意_0')` +/// - `Signal('60分钟_W100N5强弱_决策区域V240616_看多_任意_任意_0')` +/// +/// 参数说明: +/// - `w`:参考窗口,默认 `100`; +/// - `n`:决策窗口,默认 `5`。 +/// 对齐说明:候选筛选与右侧确认流程对齐 Python `bar_decision_V240616`。 +#[signal( + category = "kline", + name = "bar_decision_V240616", + template = "{freq}_W{w}N{n}强弱_决策区域V240616", + opcode = "BarDecisionV240616", + param_kind = "BarDecisionV240616" +)] +pub fn bar_decision_v240616(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let w = params.usize("w", 100); + let n = params.usize("n", 5); + if !(w > n && n > 2) { + return vec![]; + } + let k1 = c.freq.to_string(); + let k2 = format!("W{}N{}强弱", w, n); + let k3 = "决策区域V240616"; + if c.bars_raw.len() < w + n + 10 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let w_bars = get_sub_elements(&c.bars_raw, n, w); + if w_bars.len() != w { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let w_high = w_bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let w_low = w_bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let hl_mean = w_bars.iter().map(|x| x.high - x.low).sum::() / w_bars.len() as f64; + + let nb = get_sub_elements(&c.bars_raw, 1, n); + let n_bars: Vec<&czsc_core::objects::bar::RawBar> = + nb.iter().filter(|x| x.high - x.low > hl_mean).collect(); + + let mut v1 = "其他"; + for i in 0..n_bars.len() { + let bar = n_bars[i]; + let right = &n_bars[i + 1..]; + if right.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + if bar.high >= w_high && close_position_label(bar) != "高位收盘" { + for rb in right { + if close_position_label(rb) == "低位收盘" || rb.close < rb.low { + v1 = "看空"; + break; + } + } + } + if bar.low <= w_low && close_position_label(bar) != "低位收盘" { + for rb in right { + if close_position_label(rb) == "高位收盘" || rb.close > rb.high { + v1 = "看多"; + break; + } + } + } + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_td9_V240616:神奇九转计数 +/// +/// 参数模板:`"{freq}_神奇九转N{n}_BS辅助V240616"` +/// +/// 信号逻辑: +/// 1. 当前收盘与4根前收盘比较,得到 `1/-1/0`; +/// 2. 统计末端连续同号个数; +/// 3. 连续 `>=n` 个 `1` 输出 `卖点`,连续 `>=n` 个 `-1` 输出 `买点`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_神奇九转N9_BS辅助V240616_卖点_9转_任意_0')` +/// - `Signal('60分钟_神奇九转N9_BS辅助V240616_买点_9转_任意_0')` +/// +/// 参数说明: +/// - `n`:连续计数阈值,默认 `9`。 +/// 对齐说明:计数窗口与买卖点定义对齐 Python `bar_td9_V240616`。 +#[signal( + category = "kline", + name = "bar_td9_V240616", + template = "{freq}_神奇九转N{n}_BS辅助V240616", + opcode = "BarTd9V240616", + param_kind = "BarTd9V240616" +)] +pub fn bar_td9_v240616(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let n = params.usize("n", 9); + let k1 = c.freq.to_string(); + let k2 = format!("神奇九转N{}", n); + let k3 = "BS辅助V240616"; + if c.bars_raw.len() < 30 + n { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "任意"); + } + let mut s = vec![0i32; c.bars_raw.len()]; + for (i, s_i) in s.iter_mut().enumerate().skip(4) { + *s_i = if c.bars_raw[i].close > c.bars_raw[i - 4].close { + 1 + } else if c.bars_raw[i].close < c.bars_raw[i - 4].close { + -1 + } else { + 0 + }; + } + let bars = get_sub_elements(&c.bars_raw, 1, n * 2); + if bars.is_empty() { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "任意"); + } + let idx_map: HashMap = c + .bars_raw + .iter() + .enumerate() + .map(|(i, b)| (b.id, i)) + .collect(); + let bar_signs: Vec = bars + .iter() + .filter_map(|b| idx_map.get(&b.id).map(|i| s[*i])) + .collect(); + if bar_signs.is_empty() { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "任意"); + } + + let mut v1 = "其他".to_string(); + let mut v2 = "任意".to_string(); + if bar_signs[bar_signs.len() - 1] == 1 { + let mut count = 0usize; + for x in bar_signs.iter().rev() { + if *x != 1 { + break; + } + count += 1; + } + if count >= n { + v1 = "卖点".to_string(); + v2 = format!("{}转", count); + } + } else if bar_signs[bar_signs.len() - 1] == -1 { + let mut count = 0usize; + for x in bar_signs.iter().rev() { + if *x != -1 { + break; + } + count += 1; + } + if count >= n { + v1 = "买点".to_string(); + v2 = format!("{}转", count); + } + } + make_kline_signal_v2(&k1, &k2, k3, &v1, &v2) +} + +/// bar_volatility_V241013:波动率三层分类 +/// +/// 参数模板:`"{freq}_波动率分层W{w}N{n}_完全分类V241013"` +/// +/// 信号逻辑: +/// 1. 定义 `volatility_n = 最近n根收盘最大值-最小值`; +/// 2. 对最近 `w` 根缓存值做三分位分层; +/// 3. 末根分层输出 `低波动/中波动/高波动`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_波动率分层W200N10_完全分类V241013_低波动_任意_任意_0')` +/// - `Signal('60分钟_波动率分层W200N10_完全分类V241013_高波动_任意_任意_0')` +/// +/// 参数说明: +/// - `w`:分层窗口,默认 `200`; +/// - `n`:波动率窗口,默认 `10`。 +/// 对齐说明:缓存写入与 `qcut` 退化行为对齐 Python `bar_volatility_V241013`。 +#[signal( + category = "kline", + name = "bar_volatility_V241013", + template = "{freq}_波动率分层W{w}N{n}_完全分类V241013", + opcode = "BarVolatilityV241013", + param_kind = "BarVolatilityV241013" +)] +pub fn bar_volatility_v241013(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let w = params.usize("w", 200); + let n = params.usize("n", 10); + let k1 = c.freq.to_string(); + let k2 = format!("波动率分层W{}N{}", w, n); + let k3 = "完全分类V241013"; + + let key = format!("bar_volatility_V241013#volatility_{}", n); + let ids: Vec = c.bars_raw.iter().map(|b| b.id).collect(); + let mut series = + if let (Some(v), Some(sid)) = (cache.series.get(&key), cache.series_ids.get(&key)) { + let mut m: HashMap = HashMap::with_capacity(sid.len()); + for (i, id) in sid.iter().enumerate() { + m.insert(*id, v[i]); + } + ids.iter() + .map(|id| *m.get(id).unwrap_or(&f64::NAN)) + .collect::>() + } else { + vec![f64::NAN; ids.len()] + }; + if !c.bars_raw.is_empty() { + // 对齐 Python:仅检查“当前最后 n 根”是否缺 cache, + // 若缺失则统一写入“当前最后 n 根 close 极差”。 + // 这意味着: + // 1. 更早历史不会被回填,未初始化值继续保留为 NaN(后续按 0 参与分层); + // 2. 最后一根未完成高周期 bar 在流式更新时会被重建,因此末值需要每次覆盖。 + let last = c.bars_raw.len() - 1; + let start = (last + 1).saturating_sub(n); + let window = &c.bars_raw[start..=last]; + let n_max = window + .iter() + .map(|x| x.close) + .fold(f64::NEG_INFINITY, f64::max); + let n_min = window.iter().map(|x| x.close).fold(f64::INFINITY, f64::min); + let current_vol = n_max - n_min; + + for item in series.iter_mut().take(last + 1).skip(start) { + if !item.is_finite() { + *item = current_vol; + } + } + series[last] = current_vol; + } + cache.series.insert(key.clone(), series.clone()); + cache.series_ids.insert(key, ids); + + if c.bars_raw.len() < w + n + 100 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let tail = get_sub_elements(series.as_slice(), 1, w); + if tail.len() != w { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let vols: Vec = tail + .iter() + .map(|x| if x.is_finite() { *x } else { 0.0 }) + .collect(); + let v1 = qcut_three_label_last(&vols).unwrap_or("其他"); + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_zfzd_V241013:窄幅震荡(全重叠) +/// +/// 参数模板:`"{freq}_窄幅震荡N{n}_形态V241013"` +/// +/// 信号逻辑: +/// 1. 取最近 `n` 根K线; +/// 2. 若 `min(high) >= max(low)`,判为窗口内全重叠; +/// 3. 输出 `满足`,否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_窄幅震荡N5_形态V241013_满足_任意_任意_0')` +/// - `Signal('60分钟_窄幅震荡N5_形态V241013_其他_任意_任意_0')` +/// +/// 参数说明: +/// - `n`:窗口长度,默认 `5`。 +/// 对齐说明:重叠判定公式与 Python `bar_zfzd_V241013` 一致。 +#[signal( + category = "kline", + name = "bar_zfzd_V241013", + template = "{freq}_窄幅震荡N{n}_形态V241013", + opcode = "BarZfzdV241013", + param_kind = "BarZfzdV241013" +)] +pub fn bar_zfzd_v241013(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let n = params.usize("n", 5); + let k1 = c.freq.to_string(); + let k2 = format!("窄幅震荡N{}", n); + let k3 = "形态V241013"; + if c.bars_raw.len() < n + 50 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, 1, n); + if bars.len() != n { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let zg = bars.iter().map(|x| x.high).fold(f64::INFINITY, f64::min); + let zd = bars.iter().map(|x| x.low).fold(f64::NEG_INFINITY, f64::max); + let v1 = if zg >= zd { "满足" } else { "其他" }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// bar_zfzd_V241014:窄幅震荡(最大实体重叠) +/// +/// 参数模板:`"{freq}_窄幅震荡N{n}_形态V241014"` +/// +/// 信号逻辑: +/// 1. 找到窗口内最大实体K线; +/// 2. 若其实体明显过大(超过窗口实体均值2倍)直接排除; +/// 3. 若该K线与窗口内所有K线区间均重叠,判 `满足`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_窄幅震荡N10_形态V241014_满足_任意_任意_0')` +/// - `Signal('60分钟_窄幅震荡N10_形态V241014_其他_任意_任意_0')` +/// +/// 参数说明: +/// - `n`:窗口长度,默认 `5`。 +/// 对齐说明:最大实体筛选与重叠判断对齐 Python `bar_zfzd_V241014`。 +#[signal( + category = "kline", + name = "bar_zfzd_V241014", + template = "{freq}_窄幅震荡N{n}_形态V241014", + opcode = "BarZfzdV241014", + param_kind = "BarZfzdV241014" +)] +pub fn bar_zfzd_v241014(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let n = params.usize("n", 5); + let k1 = c.freq.to_string(); + let k2 = format!("窄幅震荡N{}", n); + let k3 = "形态V241014"; + if c.bars_raw.len() < n + 50 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, 1, n); + if bars.len() != n { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let mut max_bar = &bars[0]; + for b in &bars[1..] { + if bar_solid(b) > bar_solid(max_bar) { + max_bar = b; + } + } + let mean_solid = bars.iter().map(bar_solid).sum::() / bars.len() as f64; + if bar_solid(max_bar) > 2.0 * mean_solid { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let ok = bars + .iter() + .all(|x| x.high.min(max_bar.high) > x.low.max(max_bar.low)); + let v1 = if ok { "满足" } else { "其他" }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} diff --git a/crates/czsc-signals/src/byi.rs b/crates/czsc-signals/src/byi.rs new file mode 100644 index 000000000..c55949b57 --- /dev/null +++ b/crates/czsc-signals/src/byi.rs @@ -0,0 +1,412 @@ +use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::sig::{ + get_sub_elements, make_kline_signal_v1, make_kline_signal_v2, make_kline_signal_v3, +}; +use crate::utils::ta::update_macd_cache; +use czsc_core::analyze::CZSC; +use czsc_core::objects::bar::RawBar; +use czsc_core::objects::bi::BI; +use czsc_core::objects::direction::Direction; +use czsc_core::objects::fx::FX; +use czsc_core::objects::mark::Mark; +use czsc_core::objects::signal::Signal; +use czsc_core::objects::zs::ZS; +use czsc_signal_macros::signal; +use std::collections::HashMap; + +fn fx_raw_bars(fx: &FX) -> Vec { + fx.elements + .iter() + .flat_map(|nb| nb.elements.iter().cloned()) + .collect() +} + +fn raw_bar_solid(bar: &RawBar) -> f64 { + (bar.open - bar.close).abs() +} + +fn raw_bar_upper(bar: &RawBar) -> f64 { + bar.high - bar.open.max(bar.close) +} + +fn raw_bar_lower(bar: &RawBar) -> f64 { + bar.open.min(bar.close) - bar.low +} + +fn mean(values: &[f64]) -> f64 { + if values.is_empty() { + f64::NAN + } else { + values.iter().sum::() / values.len() as f64 + } +} + +fn std_pop(values: &[f64]) -> f64 { + let m = mean(values); + if !m.is_finite() || values.is_empty() { + return f64::NAN; + } + let var = values.iter().map(|x| (x - m).powi(2)).sum::() / values.len() as f64; + var.sqrt() +} + +fn is_symmetry_zs(bis: &[BI], th: f64) -> bool { + if bis.len().is_multiple_of(2) { + return false; + } + let zs = ZS::new(bis.to_vec()); + if zs.zd > zs.zg { + return false; + } + let max_low = bis.iter().map(|x| x.get_low()).fold(f64::NEG_INFINITY, f64::max); + let min_high = bis.iter().map(|x| x.get_high()).fold(f64::INFINITY, f64::min); + if max_low > min_high { + return false; + } + let zns: Vec = bis.iter().map(|x| x.get_power_price()).collect(); + let m = mean(&zns); + let s = std_pop(&zns); + m.is_finite() && s.is_finite() && m != 0.0 && s / m <= th +} + +/// byi_symmetry_zs_V221107:对称中枢识别信号 +/// +/// 参数模板:`"{freq}_D{di}B_对称中枢"` +/// +/// 信号逻辑: +/// 1. 取倒数 `di` 截止最近 10 笔; +/// 2. 依次检查最近 `7/5/3` 笔是否构成对称中枢; +/// 3. 命中则输出 `是 + {i}笔`,否则 `否 + 任意`; +/// 4. 方向位按最后一笔反向映射(最后笔向下 -> `向上`)。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1B_对称中枢_是_向上_7笔_0')` +/// - `Signal('60分钟_D1B_对称中枢_否_向下_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始,默认 `1`。 +/// 对齐说明:与 Python `byi_symmetry_zs_V221107` 的 7/5/3 笔判定序一致。 +#[signal( + category = "kline", + name = "byi_symmetry_zs_V221107", + template = "{freq}_D{di}B_对称中枢V221107", + opcode = "ByiSymmetryZsV221107", + param_kind = "ByiSymmetryZsV221107" +)] +pub fn byi_symmetry_zs_v221107(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let bis = get_sub_elements(&c.bi_list, di, 10); + let k1 = c.freq.to_string(); + let k2 = format!("D{}B", di); + let k3 = "对称中枢"; + if bis.len() < 7 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let mut ok = false; + let mut v3 = "任意".to_string(); + for i in [7usize, 5, 3] { + ok = is_symmetry_zs(&bis[bis.len() - i..], 0.3); + if ok { + v3 = format!("{}笔", i); + break; + } + } + let v1 = if ok { "是" } else { "否" }; + let v2 = if bis[bis.len() - 1].direction == Direction::Down { + "向上" + } else { + "向下" + }; + make_kline_signal_v3(&k1, &k2, k3, v1, v2, &v3) +} + +/// byi_bi_end_V230106:分型停顿辅助笔结束信号 +/// +/// 参数模板:`"{freq}_D0停顿分型_BE辅助V230106"` +/// +/// 信号逻辑: +/// 1. 基于最后一笔方向与末端分型,判断停顿分型是否成立; +/// 2. 满足底分型停顿给出 `看多`,满足顶分型停顿给出 `看空`; +/// 3. 再按最后一根K线实体强弱输出 `强/弱`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D0停顿分型_BE辅助V230106_看多_强_任意_0')` +/// - `Signal('60分钟_D0停顿分型_BE辅助V230106_看空_弱_任意_0')` +/// +/// 参数说明: +/// - 本信号无额外参数,`params` 可为空。 +/// 对齐说明:与 Python `byi_bi_end_V230106` 的停顿判定条件一致。 +#[signal( + category = "kline", + name = "byi_bi_end_V230106", + template = "{freq}_D0停顿分型_BE辅助V230106", + opcode = "ByiBiEndV230106", + param_kind = "ByiBiEndV230106" +)] +pub fn byi_bi_end_v230106(c: &CZSC, _params: &ParamView, _cache: &mut TaCache) -> Vec { + let k1 = c.freq.to_string(); + let k2 = "D0停顿分型"; + let k3 = "BE辅助V230106"; + if c.bi_list.len() < 3 || c.bars_ubi.len() > 7 { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, 1, 3); + if bars.len() < 3 { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + let last_bi = &c.bi_list[c.bi_list.len() - 1]; + let last_fx = &last_bi.fx_b; + let bar1 = &bars[0]; + let bar3 = &bars[2]; + let fx_raw = fx_raw_bars(last_fx); + if fx_raw.is_empty() { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + + let lc1 = last_bi.direction == Direction::Down && last_fx.mark == Mark::D && bar1.low == last_fx.low; + if lc1 { + let fx_high = fx_raw.iter().map(|x| x.high).fold(f64::NEG_INFINITY, f64::max); + if bar3.close > fx_high { + let v2 = if bar3.close > bar3.open + && raw_bar_solid(bar3) > raw_bar_upper(bar3).max(raw_bar_lower(bar3)) + { + "强" + } else { + "弱" + }; + return make_kline_signal_v2(&k1, k2, k3, "看多", v2); + } + } + + let sc1 = last_bi.direction == Direction::Up && last_fx.mark == Mark::G && bar1.high == last_fx.high; + if sc1 { + let fx_low = fx_raw.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + if bar3.close < fx_low { + let v2 = if bar3.close < bar3.open + && raw_bar_solid(bar3) > raw_bar_upper(bar3).max(raw_bar_lower(bar3)) + { + "强" + } else { + "弱" + }; + return make_kline_signal_v2(&k1, k2, k3, "看空", v2); + } + } + make_kline_signal_v1(&k1, k2, k3, "其他") +} + +/// byi_bi_end_V230107:验证分型辅助笔结束信号 +/// +/// 参数模板:`"{freq}_D0验证分型_BE辅助V230107"` +/// +/// 信号逻辑: +/// 1. 校验最后一笔末端分型与末三分型结构关系; +/// 2. 满足验证底分型给出 `看多`,验证顶分型给出 `看空`; +/// 3. 依据最后一根K线实体强弱输出 `强/弱`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D0验证分型_BE辅助V230107_看多_强_任意_0')` +/// - `Signal('60分钟_D0验证分型_BE辅助V230107_看空_弱_任意_0')` +/// +/// 参数说明: +/// - 本信号无额外参数,`params` 可为空。 +/// 对齐说明:与 Python `byi_bi_end_V230107` 的结构校验和强弱规则一致。 +#[signal( + category = "kline", + name = "byi_bi_end_V230107", + template = "{freq}_D0验证分型_BE辅助V230107", + opcode = "ByiBiEndV230107", + param_kind = "ByiBiEndV230107" +)] +pub fn byi_bi_end_v230107(c: &CZSC, _params: &ParamView, _cache: &mut TaCache) -> Vec { + let k1 = c.freq.to_string(); + let k2 = "D0验证分型"; + let k3 = "BE辅助V230107"; + if c.bi_list.len() < 3 || c.bars_ubi.len() > 7 { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + let fxs = c.get_fx_list(); + if fxs.len() < 3 || c.bars_raw.is_empty() { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + let last_bi = &c.bi_list[c.bi_list.len() - 1]; + let fx1 = &fxs[fxs.len() - 3]; + let fx2 = &fxs[fxs.len() - 2]; + let fx3 = &fxs[fxs.len() - 1]; + let bar1 = &c.bars_raw[c.bars_raw.len() - 1]; + let fx3_raw = fx_raw_bars(fx3); + if fx3_raw.is_empty() { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + if !(last_bi.fx_b.dt == fx1.dt && fx1.mark == fx3.mark && bar1.dt == fx3_raw[fx3_raw.len() - 1].dt) { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + + let mut close_seq = Vec::new(); + close_seq.extend( + fx_raw_bars(fx1).into_iter().map(|x| x.close), + ); + close_seq.extend(fx_raw_bars(fx2).into_iter().map(|x| x.close)); + if close_seq.is_empty() { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + let high_c = close_seq.iter().copied().fold(f64::NEG_INFINITY, f64::max); + let low_c = close_seq.iter().copied().fold(f64::INFINITY, f64::min); + + let lc1 = raw_bar_solid(bar1) > raw_bar_upper(bar1).max(raw_bar_lower(bar1)) + && bar1.close > bar1.open + && bar1.close > high_c; + if last_bi.direction == Direction::Down && fx1.mark == Mark::D && fx3.low > fx1.low { + let v2 = if lc1 { "强" } else { "弱" }; + return make_kline_signal_v2(&k1, k2, k3, "看多", v2); + } + + let sc1 = raw_bar_solid(bar1) > raw_bar_upper(bar1).max(raw_bar_lower(bar1)) + && bar1.close < bar1.open + && bar1.close < low_c; + if last_bi.direction == Direction::Up && fx1.mark == Mark::G && fx3.high < fx1.high { + let v2 = if sc1 { "强" } else { "弱" }; + return make_kline_signal_v2(&k1, k2, k3, "看空", v2); + } + make_kline_signal_v1(&k1, k2, k3, "其他") +} + +/// byi_second_bs_V230324:二类买卖点辅助信号 +/// +/// 参数模板:`"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}回抽零轴_BS2辅助V230324"` +/// +/// 信号逻辑: +/// 1. 基于最近 9 笔关键分型的 DIF 值和标准差构造条件; +/// 2. 满足向下笔回抽零轴条件判 `看多`; +/// 3. 满足向上笔回抽零轴条件判 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1MACD12#26#9回抽零轴_BS2辅助V230324_看多_任意_任意_0')` +/// - `Signal('60分钟_D1MACD12#26#9回抽零轴_BS2辅助V230324_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始检查,默认 `1`; +/// - `fastperiod/slowperiod/signalperiod`:MACD 参数,默认 `12/26/9`。 +/// 对齐说明:按 Python `byi_second_bs_V230324` 的 DIF 取样点和不等式链实现。 +#[signal( + category = "kline", + name = "byi_second_bs_V230324", + template = "{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}回抽零轴_BS2辅助V230324", + opcode = "ByiSecondBsV230324", + param_kind = "ByiSecondBsV230324" +)] +pub fn byi_second_bs_v230324(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let fast = params.usize("fastperiod", 12); + let slow = params.usize("slowperiod", 26); + let signalperiod = params.usize("signalperiod", 9); + let cache_key = format!("MACD{}#{}#{}", fast, slow, signalperiod); + update_macd_cache(c, &cache_key, fast, slow, signalperiod, cache); + + let k1 = c.freq.to_string(); + let k2 = format!("D{}{}回抽零轴", di, cache_key); + let k3 = "BS2辅助V230324"; + if c.bi_list.len() < di + 10 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bis = get_sub_elements(&c.bi_list, di, 9); + if bis.len() < 9 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let b1 = &bis[0]; + let b3 = &bis[2]; + let b5 = &bis[4]; + let b8 = &bis[7]; + let b9 = &bis[8]; + + let macd = match cache.macd.get(&cache_key) { + Some(m) => m, + None => return make_kline_signal_v1(&k1, &k2, k3, "其他"), + }; + let id_to_idx: HashMap = macd + .ids + .iter() + .enumerate() + .map(|(i, id)| (*id, i)) + .collect(); + let dif_at = |id: i32| -> f64 { + id_to_idx + .get(&id) + .and_then(|i| macd.dif.get(*i)) + .copied() + .unwrap_or(f64::NAN) + }; + let fx_mid_id = |bi: &BI| -> Option { + let rb = fx_raw_bars(&bi.fx_b); + if rb.len() > 1 { Some(rb[1].id) } else { None } + }; + + let b1_dif = fx_mid_id(b1).map(dif_at).unwrap_or(f64::NAN); + let b3_dif = fx_mid_id(b3).map(dif_at).unwrap_or(f64::NAN); + let b5_dif = fx_mid_id(b5).map(dif_at).unwrap_or(f64::NAN); + let b8_dif = fx_mid_id(b8).map(dif_at).unwrap_or(f64::NAN); + let b9_dif = fx_mid_id(b9).map(dif_at).unwrap_or(f64::NAN); + let b1_raw = b1.get_raw_bars(); + let dif_seq: Vec = b1_raw.iter().map(|x| dif_at(x.id)).collect(); + let dif_std = std_pop(&dif_seq); + + let mut v1 = "其他"; + if b9.direction == Direction::Down + && b1_dif.max(b3_dif).max(b5_dif) < 0.0 + && b1_dif.min(b3_dif).min(b5_dif) < -2.0 * dif_std + && b8_dif > dif_std + && b9_dif.abs() < 0.3 * dif_std + { + v1 = "看多"; + } + if b9.direction == Direction::Up + && b1_dif.min(b3_dif).min(b5_dif) > 0.0 + && b1_dif.max(b3_dif).max(b5_dif) > 2.0 * dif_std + && b8_dif < -dif_std + && b9_dif.abs() < 0.3 * dif_std + { + v1 = "看空"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// byi_fx_num_V230628:前笔分型数量约束信号 +/// +/// 参数模板:`"{freq}_D{di}笔分型数大于{num}_BE辅助V230628"` +/// +/// 信号逻辑: +/// 1. 取倒数第 `di` 笔; +/// 2. 输出该笔方向(`向上/向下`); +/// 3. 若该笔内部分型数量 `>= num` 记 `满足`,否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1笔分型数大于4_BE辅助V230628_向下_满足_任意_0')` +/// - `Signal('60分钟_D1笔分型数大于4_BE辅助V230628_向上_其他_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始检查,默认 `1`; +/// - `num`:分型数量阈值,默认 `4`。 +/// 对齐说明:与 Python `byi_fx_num_V230628` 的数量判断一致。 +#[signal( + category = "kline", + name = "byi_fx_num_V230628", + template = "{freq}_D{di}笔分型数大于{num}_BE辅助V230628", + opcode = "ByiFxNumV230628", + param_kind = "ByiFxNumV230628" +)] +pub fn byi_fx_num_v230628(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let num = params.usize("num", 4); + let k1 = c.freq.to_string(); + let k2 = format!("D{}笔分型数大于{}", di, num); + let k3 = "BE辅助V230628"; + if c.bi_list.len() < di + 1 || c.bars_ubi.len() > 7 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bi = &c.bi_list[c.bi_list.len() - di]; + let v1 = bi.direction.to_string(); + let v2 = if bi.fxs.len() >= num { "满足" } else { "其他" }; + make_kline_signal_v2(&k1, &k2, k3, &v1, v2) +} diff --git a/crates/czsc-signals/src/cat.rs b/crates/czsc-signals/src/cat.rs new file mode 100644 index 000000000..381a4c626 --- /dev/null +++ b/crates/czsc-signals/src/cat.rs @@ -0,0 +1,322 @@ +use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::sig::{get_str_param, get_sub_elements, make_signal, make_signal_v1}; +use crate::utils::ta::update_macd_cache; +use czsc_core::objects::bar::RawBar; +use czsc_core::objects::signal::Signal; +use czsc_core::objects::state::TraderState; +use czsc_signal_macros::signal; +use std::collections::HashMap; + +fn macd_map(cache: &TaCache, cache_key: &str) -> HashMap { + let mut out = HashMap::new(); + if let Some(series) = cache.macd.get(cache_key) { + for (i, id) in series.ids.iter().enumerate() { + if let Some(v) = series.macd.get(i) { + out.insert(*id, *v); + } + } + } + out +} + +fn dea_map(cache: &TaCache, cache_key: &str) -> HashMap { + let mut out = HashMap::new(); + if let Some(series) = cache.macd.get(cache_key) { + for (i, id) in series.ids.iter().enumerate() { + if let Some(v) = series.dea.get(i) { + out.insert(*id, *v); + } + } + } + out +} + +fn cross_up_bars<'a>(bars: &[&'a RawBar], macd: &HashMap) -> Vec<&'a RawBar> { + let mut out = Vec::new(); + for w in bars.windows(2) { + let m1 = *macd.get(&w[0].id).unwrap_or(&f64::NAN); + let m2 = *macd.get(&w[1].id).unwrap_or(&f64::NAN); + if m1 < 0.0 && m2 > 0.0 { + out.push(w[1]); + } + } + out +} + +fn cross_down_bars<'a>(bars: &[&'a RawBar], macd: &HashMap) -> Vec<&'a RawBar> { + let mut out = Vec::new(); + for w in bars.windows(2) { + let m1 = *macd.get(&w[0].id).unwrap_or(&f64::NAN); + let m2 = *macd.get(&w[1].id).unwrap_or(&f64::NAN); + if m1 > 0.0 && m2 < 0.0 { + out.push(w[1]); + } + } + out +} + +/// cat_macd_V230518:高低级别 MACD 交叉联立信号 +/// +/// 参数模板:`"{freq1}#{freq2}_MACD交叉_联立V230518"` +/// +/// 信号逻辑: +/// 1. 当 `freq1` 最近一次由负翻正(MACD 金叉)后,检查 `freq2` 是否仅出现 1 次金叉,满足判 `看多`; +/// 2. 当 `freq1` 最近一次由正翻负(MACD 死叉)后,检查 `freq2` 是否仅出现 1 次死叉,满足判 `看空`; +/// 3. 否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('日线#60分钟_MACD交叉_联立V230518_看多_任意_任意_0')` +/// - `Signal('日线#60分钟_MACD交叉_联立V230518_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `freq1`:高一级别周期,默认 `5分钟`; +/// - `freq2`:低一级别周期,默认 `1分钟`。 +/// 对齐说明:触发窗口、首次交叉判定与 Python `cat_macd_V230518` 保持一致。 +#[signal( + category = "trader", + name = "cat_macd_V230518", + template = "{freq1}#{freq2}_MACD交叉_联立V230518", + opcode = "CatMacdV230518", + param_kind = "CatMacdV230518" +)] +pub fn cat_macd_v230518(cat: &dyn TraderState, params: &ParamView) -> Vec { + let freq1 = get_str_param(params, "freq1", "5分钟"); + let freq2 = get_str_param(params, "freq2", "1分钟"); + + let k1 = format!("{}#{}", freq1, freq2); + let k2 = "MACD交叉"; + let k3 = "联立V230518"; + + let Some(c1) = cat.get_czsc(freq1) else { + return make_signal_v1(&k1, k2, k3, "其他"); + }; + let Some(c2) = cat.get_czsc(freq2) else { + return make_signal_v1(&k1, k2, k3, "其他"); + }; + if c1.bars_raw.len() < 50 || c2.bars_raw.len() < 50 { + return make_signal_v1(&k1, k2, k3, "其他"); + } + + let cache_key = "MACD12#26#9"; + let mut c1_cache = TaCache::new(); + let mut c2_cache = TaCache::new(); + update_macd_cache(c1, cache_key, 12, 26, 9, &mut c1_cache); + update_macd_cache(c2, cache_key, 12, 26, 9, &mut c2_cache); + let c1_macd_map = macd_map(&c1_cache, cache_key); + let c2_macd_map = macd_map(&c2_cache, cache_key); + + let c1_bars: Vec<&RawBar> = get_sub_elements(&c1.bars_raw, 1, 8).iter().collect(); + let c2_bars: Vec<&RawBar> = get_sub_elements(&c2.bars_raw, 1, 50).iter().collect(); + if c1_bars.len() < 2 || c2_bars.len() < 2 { + return make_signal_v1(&k1, k2, k3, "其他"); + } + + let c1_macd: Vec = c1_bars + .iter() + .filter_map(|b| c1_macd_map.get(&b.id).copied()) + .collect(); + if c1_macd.len() != c1_bars.len() { + return make_signal_v1(&k1, k2, k3, "其他"); + } + + if c1_macd.last().copied().unwrap_or(0.0) > 0.0 + && c1_macd.iter().filter(|x| **x > 0.0).count() != c1_macd.len() + { + let c1_gold = cross_up_bars(&c1_bars, &c1_macd_map); + if let Some(last_gold) = c1_gold.last() { + let c2_after: Vec<&RawBar> = c2_bars + .iter() + .copied() + .filter(|x| x.dt > last_gold.dt) + .collect(); + if c2_after.len() > 3 { + let c2_gold = cross_up_bars(&c2_after, &c2_macd_map); + if c2_gold.len() == 1 { + return make_signal_v1(&k1, k2, k3, "看多"); + } + } + } + } + + if c1_macd.last().copied().unwrap_or(0.0) < 0.0 + && c1_macd.iter().filter(|x| **x < 0.0).count() != c1_macd.len() + { + let c1_dead = cross_down_bars(&c1_bars, &c1_macd_map); + if let Some(last_dead) = c1_dead.last() { + let c2_after: Vec<&RawBar> = c2_bars + .iter() + .copied() + .filter(|x| x.dt > last_dead.dt) + .collect(); + if c2_after.len() > 3 { + let c2_dead = cross_down_bars(&c2_after, &c2_macd_map); + if c2_dead.len() == 1 { + return make_signal_v1(&k1, k2, k3, "看空"); + } + } + } + } + + make_signal_v1(&k1, k2, k3, "其他") +} + +/// cat_macd_V230520:高低级别 MACD 缩柱联立信号 +/// +/// 参数模板:`"{freq1}#{freq2}_MACD交叉_联立V230520"` +/// +/// 信号逻辑: +/// 1. `freq1` 最近三根 MACD 连续抬升且历史出现负值时,检查 `freq2` 的金死叉结构,满足判 `看多`; +/// 2. `freq1` 最近三根 MACD 连续下压且历史出现正值时,检查 `freq2` 的死金叉结构,满足判 `看空`; +/// 3. 同时给出触发时 `DEA` 在零轴上下的位置 `v2`。 +/// +/// 信号列表示例: +/// - `Signal('日线#60分钟_MACD交叉_联立V230520_看多_零轴上方_任意_0')` +/// - `Signal('日线#60分钟_MACD交叉_联立V230520_看空_零轴下方_任意_0')` +/// +/// 参数说明: +/// - `freq1`:高一级别周期,默认 `5分钟`; +/// - `freq2`:低一级别周期,默认 `1分钟`。 +/// 对齐说明:交叉次数、顺序和阈值条件与 Python `cat_macd_V230520` 一致。 +#[signal( + category = "trader", + name = "cat_macd_V230520", + template = "{freq1}#{freq2}_MACD交叉_联立V230520", + opcode = "CatMacdV230520", + param_kind = "CatMacdV230520" +)] +pub fn cat_macd_v230520(cat: &dyn TraderState, params: &ParamView) -> Vec { + let freq1 = get_str_param(params, "freq1", "5分钟"); + let freq2 = get_str_param(params, "freq2", "1分钟"); + + let k1 = format!("{}#{}", freq1, freq2); + let k2 = "MACD交叉"; + let k3 = "联立V230520"; + + let Some(c1) = cat.get_czsc(freq1) else { + return make_signal_v1(&k1, k2, k3, "其他"); + }; + let Some(c2) = cat.get_czsc(freq2) else { + return make_signal_v1(&k1, k2, k3, "其他"); + }; + if c1.bars_raw.len() < 50 || c2.bars_raw.len() < 50 { + return make_signal_v1(&k1, k2, k3, "其他"); + } + + let cache_key = "MACD12#26#9"; + let mut c1_cache = TaCache::new(); + let mut c2_cache = TaCache::new(); + update_macd_cache(c1, cache_key, 12, 26, 9, &mut c1_cache); + update_macd_cache(c2, cache_key, 12, 26, 9, &mut c2_cache); + let c1_macd_map = macd_map(&c1_cache, cache_key); + let c2_macd_map = macd_map(&c2_cache, cache_key); + let c2_dea_map = dea_map(&c2_cache, cache_key); + + let c1_bars: Vec<&RawBar> = get_sub_elements(&c1.bars_raw, 1, 8).iter().collect(); + let c2_bars: Vec<&RawBar> = get_sub_elements(&c2.bars_raw, 1, 50).iter().collect(); + if c1_bars.len() < 3 || c2_bars.len() < 2 { + return make_signal_v1(&k1, k2, k3, "其他"); + } + + let c1_macd: Vec = c1_bars + .iter() + .filter_map(|b| c1_macd_map.get(&b.id).copied()) + .collect(); + if c1_macd.len() != c1_bars.len() { + return make_signal_v1(&k1, k2, k3, "其他"); + } + + let li = c1_macd.len() - 1; + let up3 = c1_macd[li - 2] < c1_macd[li - 1] && c1_macd[li - 1] < c1_macd[li]; + let down3 = c1_macd[li - 2] > c1_macd[li - 1] && c1_macd[li - 1] > c1_macd[li]; + + if up3 + && c1_macd + .iter() + .copied() + .fold(f64::INFINITY, f64::min) + < 0.0 + { + let min_bar = c1_bars + .iter() + .min_by(|a, b| a.low.partial_cmp(&b.low).unwrap_or(std::cmp::Ordering::Equal)) + .copied(); + if let Some(min_bar) = min_bar { + let c2_after: Vec<&RawBar> = c2_bars + .iter() + .copied() + .filter(|x| x.dt > min_bar.dt) + .collect(); + if c2_after.len() > 3 { + let last_bar = *c2_after.last().unwrap(); + let c2_vals: Vec = c2_after + .iter() + .filter_map(|b| c2_macd_map.get(&b.id).copied()) + .collect(); + if !c2_vals.is_empty() { + let min_macd = c2_vals.iter().copied().fold(f64::INFINITY, f64::min); + let max_macd = c2_vals.iter().copied().fold(f64::NEG_INFINITY, f64::max); + let c2_gold = cross_up_bars(&c2_after, &c2_macd_map); + let c2_dead = cross_down_bars(&c2_after, &c2_macd_map); + if c2_gold.len() == 1 + && c2_dead.len() == 1 + && c2_gold[0].id - c2_dead[0].id >= 5 + && last_bar.dt == c2_gold[0].dt + && c2_gold[0].dt > c2_dead[0].dt + && min_macd.abs() > max_macd.abs() * 0.3 + { + let dea = *c2_dea_map.get(&c2_gold[0].id).unwrap_or(&0.0); + let v2 = if dea > 0.0 { "零轴上方" } else { "零轴下方" }; + return make_signal(&k1, k2, k3, "看多", v2); + } + } + } + } + } + + if down3 + && c1_macd + .iter() + .copied() + .fold(f64::NEG_INFINITY, f64::max) + > 0.0 + { + let max_bar = c1_bars + .iter() + .max_by(|a, b| a.high.partial_cmp(&b.high).unwrap_or(std::cmp::Ordering::Equal)) + .copied(); + if let Some(max_bar) = max_bar { + let c2_after: Vec<&RawBar> = c2_bars + .iter() + .copied() + .filter(|x| x.dt > max_bar.dt) + .collect(); + if c2_after.len() > 3 { + let last_bar = *c2_after.last().unwrap(); + let c2_vals: Vec = c2_after + .iter() + .filter_map(|b| c2_macd_map.get(&b.id).copied()) + .collect(); + if !c2_vals.is_empty() { + let min_macd = c2_vals.iter().copied().fold(f64::INFINITY, f64::min); + let max_macd = c2_vals.iter().copied().fold(f64::NEG_INFINITY, f64::max); + let c2_gold = cross_up_bars(&c2_after, &c2_macd_map); + let c2_dead = cross_down_bars(&c2_after, &c2_macd_map); + if c2_dead.len() == 1 + && c2_gold.len() == 1 + && c2_dead[0].id - c2_gold[0].id >= 5 + && last_bar.dt == c2_dead[0].dt + && c2_dead[0].dt > c2_gold[0].dt + && max_macd.abs() > min_macd.abs() * 0.3 + { + let dea = *c2_dea_map.get(&c2_dead[0].id).unwrap_or(&0.0); + let v2 = if dea > 0.0 { "零轴上方" } else { "零轴下方" }; + return make_signal(&k1, k2, k3, "看空", v2); + } + } + } + } + } + + make_signal_v1(&k1, k2, k3, "其他") +} diff --git a/crates/czsc-signals/src/clv.rs b/crates/czsc-signals/src/clv.rs new file mode 100644 index 000000000..4ce47801e --- /dev/null +++ b/crates/czsc-signals/src/clv.rs @@ -0,0 +1,59 @@ +use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::sig::{get_sub_elements, make_kline_signal_v1}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::signal::Signal; +use czsc_signal_macros::signal; + +/// clv_up_dw_line_V230605:CLV 多空信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}_CLV多空V230605"` +/// +/// 信号逻辑: +/// 1. 取最近 `n` 根K线,计算每根 `(2*close-low-high)/(high-low)`; +/// 2. 计算该序列均值 `clv_ma`; +/// 3. `clv_ma > 0` 判 `看多`,否则判 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N70_CLV多空V230605_看多_任意_任意_0')` +/// - `Signal('60分钟_D1N70_CLV多空V230605_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `n`:统计窗口大小,默认 `70`。 +/// 对齐说明:CLV 公式与阈值判断对齐 Python `clv_up_dw_line_V230605`。 +#[signal( + category = "kline", + name = "clv_up_dw_line_V230605", + template = "{freq}_D{di}N{n}_CLV多空V230605", + opcode = "ClvUpDwLineV230605", + param_kind = "ClvUpDwLineV230605" +)] +pub fn clv_up_dw_line_v230605( + c: &CZSC, + params: &ParamView, + _cache: &mut TaCache, +) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 70); + + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}", di, n); + let k3 = "CLV多空V230605"; + + if c.bars_raw.len() < di + 100 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, di, n); + if bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let mut vals = Vec::with_capacity(bars.len()); + for b in bars { + let v = (2.0 * b.close - b.low - b.high) / (b.high - b.low); + vals.push(v); + } + let clv_ma = vals.iter().sum::() / vals.len() as f64; + let v1 = if clv_ma > 0.0 { "看多" } else { "看空" }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} diff --git a/crates/czsc-signals/src/coo.rs b/crates/czsc-signals/src/coo.rs new file mode 100644 index 000000000..b3f1372bc --- /dev/null +++ b/crates/czsc-signals/src/coo.rs @@ -0,0 +1,345 @@ +use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::sig::{get_sub_elements, make_kline_signal_v1, make_kline_signal_v2}; +use crate::utils::ta::{update_cci_cache, update_kdj_cache, update_ma_cache, update_sar_cache}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::signal::Signal; +use czsc_signal_macros::signal; + +fn cal_td_seq(close: &[f64]) -> Vec { + if close.len() < 5 { + return vec![0; close.len()]; + } + let mut res = vec![0; close.len()]; + for i in 4..close.len() { + if close[i] > close[i - 4] { + res[i] = res[i - 1] + 1; + } else if close[i] < close[i - 4] { + res[i] = res[i - 1] - 1; + } + } + res +} + +fn td_signal_from_close(close: &[f64]) -> (&'static str, &'static str) { + let td = cal_td_seq(close); + let x = *td.last().unwrap_or(&0); + if x > 0 { + let v1 = if td.len() > 1 && td[td.len() - 2] < -8 { + "看多" + } else { + "延续" + }; + let v2 = if x > 8 { "TD顶" } else { "非顶" }; + (v1, v2) + } else if x < 0 { + let v1 = if td.len() > 1 && td[td.len() - 2] > 8 { + "看空" + } else { + "延续" + }; + let v2 = if x < -8 { "TD底" } else { "非底" }; + (v1, v2) + } else { + ("其他", "其他") + } +} + +/// coo_td_V221110:TD 神奇九转信号(旧版模板) +/// +/// 参数模板:`"{freq}_D{di}K_TD"` +/// +/// 信号逻辑: +/// 1. 取倒数 `di` 截止的最近 50 根收盘价; +/// 2. 按 `close[i]` 与 `close[i-4]` 比较累计 TD 计数; +/// 3. 根据最新 TD 值及前一值输出 `看多/看空/延续` 与 `TD顶/TD底/非顶/非底`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1K_TD_延续_非顶_任意_0')` +/// - `Signal('60分钟_D1K_TD_看空_TD底_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`。 +/// 对齐说明:与 Python `coo_td_V221110` 的 TD 计数递推一致。 +#[signal( + category = "kline", + name = "coo_td_V221110", + template = "{freq}_D{di}K_TDV221110", + opcode = "CooTdV221110", + param_kind = "CooTdV221110" +)] +pub fn coo_td_v221110(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}K", di); + let k3 = "TD"; + let bars = get_sub_elements(&c.bars_raw, di, 50); + if bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let close: Vec = bars.iter().map(|x| x.close).collect(); + let (v1, v2) = td_signal_from_close(&close); + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// coo_td_V221111:TD 神奇九转信号 +/// +/// 参数模板:`"{freq}_D{di}TD_BS辅助V221111"` +/// +/// 信号逻辑: +/// 1. 取倒数 `di` 截止的最近 50 根收盘价; +/// 2. 计算 TD 计数序列; +/// 3. 输出 `看多/看空/延续` 与 `TD顶/TD底/非顶/非底` 组合。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1TD_BS辅助V221111_延续_非顶_任意_0')` +/// - `Signal('60分钟_D1TD_BS辅助V221111_看多_TD顶_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`。 +/// 对齐说明:与 Python `coo_td_V221111` 的窗口和判定分支一致。 +#[signal( + category = "kline", + name = "coo_td_V221111", + template = "{freq}_D{di}TD_BS辅助V221111", + opcode = "CooTdV221111", + param_kind = "CooTdV221111" +)] +pub fn coo_td_v221111(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}TD", di); + let k3 = "BS辅助V221111"; + if c.bars_raw.len() < 50 + di { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, di, 50); + let close: Vec = bars.iter().map(|x| x.close).collect(); + let (v1, v2) = td_signal_from_close(&close); + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// coo_cci_V230323:CCI 结合均线的多空与方向信号 +/// +/// 参数模板:`"{freq}_D{di}CCI{n}#{ma_type}#{m}_BS辅助V230323"` +/// +/// 信号逻辑: +/// 1. 计算 `CCI(n)` 与 `MA(n*m)`; +/// 2. `CCI>100` 且 `close>MA` 判 `多头`,`CCI<-100` 且 `close Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 20); + let m = params.usize("m", 5); + let ma_type = params.str("ma_type", "SMA").to_uppercase(); + + let cci_key = format!("CCI{}", n); + let ma_key = format!("{}#{}", ma_type, n * m); + update_cci_cache(c, &cci_key, n, cache); + update_ma_cache(c, &ma_key, &ma_type, n * m, cache); + + let k1 = c.freq.to_string(); + let k2 = format!("D{}CCI{}#{}#{}", di, n, ma_type, m); + let k3 = "BS辅助V230323"; + if c.bars_raw.len() < n * m + di { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + if c.bars_raw.len() < di + 1 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let idx = c.bars_raw.len() - di; + let prev_idx = idx.saturating_sub(1); + let cci = cache + .series + .get(&cci_key) + .and_then(|x| x.get(idx)) + .copied() + .unwrap_or(f64::NAN); + let cci_prev = cache + .series + .get(&cci_key) + .and_then(|x| x.get(prev_idx)) + .copied() + .unwrap_or(f64::NAN); + let ma = cache + .series + .get(&ma_key) + .and_then(|x| x.get(idx)) + .copied() + .unwrap_or(f64::NAN); + let close = c.bars_raw[idx].close; + + let mut v1 = "其他"; + if cci > 100.0 && close > ma { + v1 = "多头"; + } + if cci < -100.0 && close < ma { + v1 = "空头"; + } + if v1 == "其他" { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let v2 = if cci >= cci_prev { "向上" } else { "向下" }; + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// coo_kdj_V230322:均线与 KDJ 配合多空信号 +/// +/// 参数模板:`"{freq}_D{di}KDJ{fastk_period}#{slowk_period}#{slowd_period}#{ma_type}#{n}_BS辅助V230322"` +/// +/// 信号逻辑: +/// 1. 计算 `KDJ` 与 `MA(n)`; +/// 2. `close > MA` 且 `K < D` 判 `多头`; +/// 3. `close < MA` 且 `K > D` 判 `空头`,否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1KDJ9#3#3#EMA#3_BS辅助V230322_多头_任意_任意_0')` +/// - `Signal('60分钟_D1KDJ9#3#3#EMA#3_BS辅助V230322_空头_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `n`:均线周期,默认 `3`; +/// - `ma_type`:均线类型,默认 `EMA`; +/// - `fastk_period/slowk_period/slowd_period`:KDJ 参数,默认 `9/3/3`。 +/// 对齐说明:与 Python `coo_kdj_V230322` 的组合条件一致。 +#[signal( + category = "kline", + name = "coo_kdj_V230322", + template = "{freq}_D{di}KDJ{fastk_period}#{slowk_period}#{slowd_period}#{ma_type}#{n}_BS辅助V230322", + opcode = "CooKdjV230322", + param_kind = "CooKdjV230322" +)] +pub fn coo_kdj_v230322(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 3); + let ma_type = params.str("ma_type", "EMA").to_uppercase(); + let fastk_period = params.usize("fastk_period", 9); + let slowk_period = params.usize("slowk_period", 3); + let slowd_period = params.usize("slowd_period", 3); + + let ma_key = format!("{}#{}", ma_type, n); + let kdj_key = format!("KDJ{}#{}#{}", fastk_period, slowk_period, slowd_period); + update_ma_cache(c, &ma_key, &ma_type, n, cache); + update_kdj_cache( + c, + &kdj_key, + fastk_period, + slowk_period, + slowd_period, + cache, + ); + + let k1 = c.freq.to_string(); + let k2 = format!( + "D{}KDJ{}#{}#{}#{}#{}", + di, fastk_period, slowk_period, slowd_period, ma_type, n + ); + let k3 = "BS辅助V230322"; + if c.bars_raw.len() < fastk_period * slowk_period + di { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let idx = c.bars_raw.len() - di; + let close = c.bars_raw[idx].close; + let ma = cache + .series + .get(&ma_key) + .and_then(|x| x.get(idx)) + .copied() + .unwrap_or(f64::NAN); + let (k, d) = match cache.kdj.get(&kdj_key) { + Some(kdj) => ( + kdj.k.get(idx).copied().unwrap_or(f64::NAN), + kdj.d.get(idx).copied().unwrap_or(f64::NAN), + ), + None => (f64::NAN, f64::NAN), + }; + let v1 = if close > ma && k < d { + "多头" + } else if close < ma && k > d { + "空头" + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// coo_sar_V230325:SAR 与区间极值配合信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}SAR_BS辅助V230325"` +/// +/// 信号逻辑: +/// 1. 计算最近 `n` 根收盘价区间高低点; +/// 2. 若 `close > SAR` 且 `high >= 区间最高收盘` 判 `多头`; +/// 3. 若 `close < SAR` 且 `low <= 区间最低收盘` 判 `空头`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N60SAR_BS辅助V230325_多头_任意_任意_0')` +/// - `Signal('60分钟_D1N60SAR_BS辅助V230325_空头_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `n`:区间窗口,默认 `60`。 +/// 对齐说明:与 Python `coo_sar_V230325` 的 SAR 与区间条件一致。 +#[signal( + category = "kline", + name = "coo_sar_V230325", + template = "{freq}_D{di}N{n}SAR_BS辅助V230325", + opcode = "CooSarV230325", + param_kind = "CooSarV230325" +)] +pub fn coo_sar_v230325(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 60); + let sar_key = "SAR".to_string(); + update_sar_cache(c, &sar_key, cache); + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}SAR", di, n); + let k3 = "BS辅助V230325"; + if c.bars_raw.len() < n + di + 10 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let bars = get_sub_elements(&c.bars_raw, di, n); + if bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let hhv = bars.iter().map(|x| x.close).fold(f64::NEG_INFINITY, f64::max); + let llv = bars.iter().map(|x| x.close).fold(f64::INFINITY, f64::min); + let idx = c.bars_raw.len() - di; + let sar = cache + .series + .get(&sar_key) + .and_then(|x| x.get(idx)) + .copied() + .unwrap_or(f64::NAN); + let last = &c.bars_raw[idx]; + let close = last.close; + + let mut v1 = "其他"; + if close > sar && last.high >= hhv { + v1 = "多头"; + } + if close < sar && last.low <= llv { + v1 = "空头"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} diff --git a/crates/czsc-signals/src/cvolp.rs b/crates/czsc-signals/src/cvolp.rs new file mode 100644 index 000000000..bcdfb9949 --- /dev/null +++ b/crates/czsc-signals/src/cvolp.rs @@ -0,0 +1,117 @@ +use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::sig::{get_sub_elements, make_kline_signal_v1}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::signal::Signal; +use czsc_signal_macros::signal; + +fn convolve_prefix(volume: &[f64], weights: &[f64]) -> Vec { + let l = volume.len(); + let n = weights.len(); + let mut out = vec![0.0; l]; + for (k, out_k) in out.iter_mut().enumerate().take(l) { + let i_start = (k + 1).saturating_sub(n); + let mut acc = 0.0; + for (i, value) in volume.iter().enumerate().take(k + 1).skip(i_start) { + let j = k - i; + acc += value * weights[j]; + } + *out_k = acc; + } + out +} + +/// cvolp_up_dw_line_V230612:CVOLP 动量变化率信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}M{m}UP{up}DW{dw}_CVOLP动量变化率V230612"` +/// +/// 信号逻辑: +/// 1. 取最近 `n+m` 根成交量,构造长度为 `n` 的指数权重; +/// 2. 计算卷积平滑序列 `emap`,并将前 `n` 项置为 `emap[n]`; +/// 3. 计算 `sroc = (emap - roll(emap, m))[-1] / roll(emap, m)[-1]`; +/// 4. `sroc > up/100` 判 `看多`,`sroc < -dw/100` 判 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N34M55UP5DW5_CVOLP动量变化率V230612_看多_任意_任意_0')` +/// - `Signal('60分钟_D1N34M55UP5DW5_CVOLP动量变化率V230612_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `n`:卷积平滑窗口,默认 `34`; +/// - `m`:滚动比较窗口,默认 `55`; +/// - `up`:看多阈值(百分比整数),默认 `5`; +/// - `dw`:看空阈值(百分比整数),默认 `5`。 +/// 对齐说明:卷积平滑与 `roll` 口径对齐 Python `cvolp_up_dw_line_V230612`。 +#[signal( + category = "kline", + name = "cvolp_up_dw_line_V230612", + template = "{freq}_D{di}N{n}M{m}UP{up}DW{dw}_CVOLP动量变化率V230612", + opcode = "CvolpUpDwLineV230612", + param_kind = "CvolpUpDwLineV230612" +)] +pub fn cvolp_up_dw_line_v230612( + c: &CZSC, + params: &ParamView, + _cache: &mut TaCache, +) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 34); + let m = params.usize("m", 55); + let up = params.usize("up", 5); + let dw = params.usize("dw", 5); + + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}M{}UP{}DW{}", di, n, m, up, dw); + let k3 = "CVOLP动量变化率V230612"; + let mut v1 = "其他"; + + if c.bars_raw.len() < di + n + 10 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let bars = get_sub_elements(&c.bars_raw, di, n + m); + if bars.len() <= n || bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let volume: Vec = bars.iter().map(|x| x.vol).collect(); + let mut weights: Vec = (0..n) + .map(|i| (-1.0 + i as f64 / (n.saturating_sub(1).max(1) as f64)).exp()) + .collect(); + let sum_w = weights.iter().sum::(); + if sum_w == 0.0 || !sum_w.is_finite() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + for w in &mut weights { + *w /= sum_w; + } + + let mut emap = convolve_prefix(&volume, &weights); + let fill_v = emap[n]; + for x in emap.iter_mut().take(n) { + *x = fill_v; + } + + let l = emap.len(); + let ridx = (l - 1 + l - (m % l)) % l; // 对齐 np.roll(emap, m)[-1] + let denom = emap[ridx]; + let numer = emap[l - 1] - denom; + let sroc = if denom == 0.0 { + if numer > 0.0 { + f64::INFINITY + } else if numer < 0.0 { + f64::NEG_INFINITY + } else { + f64::NAN + } + } else { + numer / denom + }; + + if sroc > up as f64 / 100.0 { + v1 = "看多"; + } + if sroc < -(dw as f64) / 100.0 { + v1 = "看空"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} diff --git a/crates/czsc-signals/src/cxt.rs b/crates/czsc-signals/src/cxt.rs new file mode 100644 index 000000000..b7fdfc96c --- /dev/null +++ b/crates/czsc-signals/src/cxt.rs @@ -0,0 +1,2972 @@ +use crate::types::TaCache; +use crate::params::ParamView; +use crate::utils::cxt::{ + calc_bi_status_values, check_first_buy, check_first_sell, fx_has_zs, fx_power_str, + fx_raw_bars, get_zs_seq, raw_bar_lower, raw_bar_upper, rebuild_ubi, ubi_raw_bars, + unique_prices_from_bars, +}; +use crate::utils::math::{linreg_predict, max_amplitude_pct, mean, overlap}; +use crate::utils::sig::{ + bar_index_map, get_sub_elements, get_usize_param, make_kline_signal_v1, make_kline_signal_v2, + qcut_last_label, +}; +use crate::utils::ta::{ + ma_snapshot_value, macd_snapshot_field_value, update_ma_cache, update_macd_cache, MacdField, +}; +use czsc_signal_macros::signal; +use czsc_core::analyze::CZSC; +use czsc_core::objects::bar::RawBar; +use czsc_core::objects::bi::BI; +use czsc_core::objects::direction::Direction; +use czsc_core::objects::fx::FX; +use czsc_core::objects::mark::Mark; +use czsc_core::objects::signal::Signal; +use std::collections::HashMap; + +/// cxt_bi_base_V230228:笔基础状态信号 +/// +/// 参数模板:`"{freq}_D0BL{bi_init_length}_V230228"` +/// +/// 信号逻辑: +/// 1. 读取最新一笔方向; +/// 2. 若最新笔为向下笔,当前状态记为 `向上`,反之记为 `向下`; +/// 3. 若未完成笔长度 `bars_ubi` 大于等于 `bi_init_length`,记为 `中继`,否则记为 `转折`; +/// 4. 笔数据不足时返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D0BL9_V230228_向上_中继_任意_0')` +/// - `Signal('60分钟_D0BL9_V230228_向下_转折_任意_0')` +/// - `Signal('60分钟_D0BL9_V230228_其他_任意_任意_0')` +/// +/// 参数说明: +/// - `bi_init_length`:未完成笔长度阈值,默认 `9`。 +/// 对齐说明:与 Python `czsc.signals.cxt_bi_base_V230228` 保持一致。 +#[signal( + category = "kline", + name = "cxt_bi_base_V230228", + template = "{freq}_D0BL{bi_init_length}_V230228", + opcode = "CxtBiBaseV230228", + param_kind = "CxtBiBase" +)] +pub fn cxt_bi_base_v230228( + czsc: &CZSC, + params: &ParamView, + _cache: &mut TaCache, +) -> Vec { + let bi_init_length = params.usize("bi_init_length", 9); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D0BL{}", bi_init_length); + let k3 = "V230228"; + + if czsc.bi_list.len() < 3 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let last_bi = czsc.bi_list.last().unwrap(); + let v1 = match last_bi.direction { + Direction::Down => "向上", + Direction::Up => "向下", + }; + + let v2 = if czsc.bars_ubi.len() >= bi_init_length { + "中继" + } else { + "转折" + }; + + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// cxt_bi_status_V230101:笔表里关系信号 +/// +/// 参数模板:`"{freq}_D1_表里关系V230101"` +/// +/// 信号逻辑: +/// 1. 依据最后一笔方向和 `bars_ubi` 长度判定外部方向(`向上/向下`); +/// 2. 结合未完成笔最后一个分型(顶分/底分)判定内部状态(`顶分/底分/延伸`); +/// 3. 笔或分型数据不足时返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1_表里关系V230101_向上_顶分_任意_0')` +/// - `Signal('60分钟_D1_表里关系V230101_向下_底分_任意_0')` +/// - `Signal('60分钟_D1_表里关系V230101_向上_延伸_任意_0')` +/// +/// 参数说明: +/// - 本信号无额外参数,`params` 可为空。 +/// 对齐说明:与 Python `czsc.signals.cxt_bi_status_V230101` 保持一致。 +#[signal( + category = "kline", + name = "cxt_bi_status_V230101", + template = "{freq}_D1_表里关系V230101", + opcode = "CxtBiStatusV230101", + param_kind = "CxtBiStatus" +)] +pub fn cxt_bi_status_v230101( + czsc: &CZSC, + _params: &ParamView, + _cache: &mut TaCache, +) -> Vec { + let k1 = czsc.freq.to_string(); + let k2 = "D1"; + let k3 = "表里关系V230101"; + + // 对齐 Python: + // if len(c.bi_list) < 3 or len(fxs) < 1: v1 = "其他" + let Some(ubi_fxs) = czsc.get_ubi_fxs() else { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + }; + + if czsc.bi_list.len() < 3 || ubi_fxs.is_empty() { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + + let (v1, v2) = calc_bi_status_values(czsc, &ubi_fxs); + + make_kline_signal_v2(&k1, k2, k3, v1, v2) +} + +/// cxt_bi_status_V230102:笔表里关系信号 +/// +/// 参数模板:`"{freq}_D1_表里关系V230102"` +/// +/// 信号逻辑: +/// 1. 沿用 `cxt_bi_status_V230101` 的表里方向和分型判定规则; +/// 2. 仅当最后一根原始K线时间等于最新 UBI 分型确认结束时间时触发; +/// 3. 不满足触发时机或数据不足时返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1_表里关系V230102_向下_底分_任意_0')` +/// - `Signal('60分钟_D1_表里关系V230102_向下_延伸_任意_0')` +/// - `Signal('60分钟_D1_表里关系V230102_向上_顶分_任意_0')` +/// - `Signal('60分钟_D1_表里关系V230102_向上_延伸_任意_0')` +/// +/// 参数说明: +/// - 本信号无额外参数,`params` 可为空。 +/// 对齐说明:与 Python `czsc.signals.cxt_bi_status_V230102` 保持一致。 +#[signal( + category = "kline", + name = "cxt_bi_status_V230102", + template = "{freq}_D1_表里关系V230102", + opcode = "CxtBiStatusV230102", + param_kind = "CxtBiStatus" +)] +pub fn cxt_bi_status_v230102( + czsc: &CZSC, + _params: &ParamView, + _cache: &mut TaCache, +) -> Vec { + let k1 = czsc.freq.to_string(); + let k2 = "D1"; + let k3 = "表里关系V230102"; + + let Some(ubi_fxs) = czsc.get_ubi_fxs() else { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + }; + if czsc.bi_list.len() < 3 || ubi_fxs.is_empty() { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + + let Some(last_bar_dt) = czsc.bars_raw.last().map(|x| x.dt) else { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + }; + let last_fx = ubi_fxs.last().unwrap(); + let Some(last_fx_end_dt) = last_fx + .elements + .last() + .and_then(|nb| nb.elements.last().map(|rb| rb.dt)) + .or_else(|| last_fx.elements.last().map(|nb| nb.dt)) + else { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + }; + if last_bar_dt != last_fx_end_dt { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + + let (v1, v2) = calc_bi_status_values(czsc, &ubi_fxs); + make_kline_signal_v2(&k1, k2, k3, v1, v2) +} + +/// cxt_fx_power_V221107:倒数分型强弱 +/// +/// 参数模板:`"{freq}_D{di}F_分型强弱V221107"` +/// +/// 信号逻辑: +/// 1. 读取倒数第 `di` 个分型; +/// 2. `v1 = 分型强弱(power_str) + 顶/底`; +/// 3. `v2 = 有中枢/无中枢`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1F_分型强弱_强顶_有中枢_任意_0')` +/// - `Signal('60分钟_D2F_分型强弱_弱底_无中枢_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 个分型,默认 `1`; +/// - 仅当分型列表长度满足要求时输出具体强弱,否则返回 `其他`。 +/// 对齐说明:与 Python `czsc.signals.cxt_fx_power_V221107` 保持一致。 +#[signal( + category = "kline", + name = "cxt_fx_power_V221107", + template = "{freq}_D{di}F_分型强弱V221107", + opcode = "CxtFxPowerV221107", + param_kind = "CxtFxPowerV221107" +)] +pub fn cxt_fx_power_v221107(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}F", di); + let k3 = "分型强弱"; + + if di == 0 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let fxs = c.get_fx_list(); + if fxs.len() < di { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let last_fx = &fxs[fxs.len() - di]; + let mark = match last_fx.mark { + Mark::G => "顶", + Mark::D => "底", + }; + let v1 = format!("{}{}", fx_power_str(last_fx), mark); + let v2 = if fx_has_zs(last_fx) { "有中枢" } else { "无中枢" }; + make_kline_signal_v2(&k1, &k2, k3, &v1, v2) +} + +/// cxt_bi_end_V230104:单均线辅助判断笔结束 +/// +/// 参数模板:`"{freq}_D0{ma_type}#{timeperiod}T{th}_BE辅助V230104"` +/// +/// 信号逻辑: +/// 1. 计算指定均线,并取最近 3 根原始 K 线; +/// 2. 若向下笔尾部出现三连阳且收盘强于均线阈值,判定 `看多`;向上笔尾部三连阴且收盘弱于均线阈值,判定 `看空`; +/// 3. 不满足边界、均线或形态条件时返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D0SMA#5T50_BE辅助V230104_看多_任意_任意_0')` +/// - `Signal('60分钟_D0EMA#8T30_BE辅助V230104_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod`:均线周期,默认 `5`; +/// - `th`:收盘价相对均线的 BP 阈值,默认 `50`。 +/// 对齐说明:与 Python `czsc.signals.cxt_bi_end_V230104` 保持一致。 +#[signal( + category = "kline", + name = "cxt_bi_end_V230104", + template = "{freq}_D0{ma_type}#{timeperiod}T{th}_BE辅助V230104", + opcode = "CxtBiEndV230104", + param_kind = "CxtBiEndV230104" +)] +pub fn cxt_bi_end_v230104(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let th = get_usize_param(params, "th", 50) as f64; + let timeperiod = get_usize_param(params, "timeperiod", 5); + let ma_type = params.str("ma_type", "SMA").to_uppercase(); + let cache_key = format!("{}#{}", ma_type, timeperiod); + update_ma_cache(c, &cache_key, &ma_type, timeperiod, cache); + let k1 = c.freq.to_string(); + let k2 = format!("D0{}#{}T{}", ma_type, timeperiod, th as i32); + let k3 = "BE辅助V230104"; + let mut v1 = "其他"; + + if c.bi_list.len() < 3 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let bars = get_sub_elements(&c.bars_raw, 1, 3); + if bars.len() != 3 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let Some(ma) = cache.series.get(&cache_key) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + let id_to_idx = bar_index_map(c); + let bar1 = &bars[0]; + let bar2 = &bars[1]; + let bar3 = &bars[2]; + let Some(&bar3_idx) = id_to_idx.get(&bar3.id) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + let Some(bar3_ma) = ma.get(bar3_idx).copied() else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + let last_bi = c.bi_list.last().unwrap(); + + let lows_min = bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let highs_max = bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + + let lc1 = last_bi.direction == Direction::Down && lows_min == last_bi.get_low(); + let lc2 = bar1.close > bar1.open && bar2.close > bar2.open && bar3.close > bar3.open; + let lc3 = bar3_ma * (1.0 + th / 10000.0) < bar3.close; + if c.bars_ubi.len() < 7 && lc1 && lc2 && lc3 { + v1 = "看多"; + } + + let sc1 = last_bi.direction == Direction::Up && highs_max == last_bi.get_high(); + let sc2 = bar1.close < bar1.open && bar2.close < bar2.open && bar3.close < bar3.open; + let sc3 = bar3_ma * (1.0 - th / 10000.0) > bar3.close; + if c.bars_ubi.len() < 7 && sc1 && sc2 && sc3 { + v1 = "看空"; + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// cxt_bi_end_V230105:K线形态+均线辅助判断笔结束 +/// +/// 参数模板:`"{freq}_D0{ma_type}#{timeperiod}T{th}_BE辅助V230105"` +/// +/// 信号逻辑: +/// 1. 提取最后一笔终点分型的两根原始 K 线,并计算指定均线; +/// 2. 向下笔若先阴后强阳上穿均线阈值,判定 `看多`;向上笔若先阳后强阴下破均线阈值,判定 `看空`; +/// 3. 未完成笔过长、分型样本不足或均线不可用时返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D0SMA#5T50_BE辅助V230105_看多_任意_任意_0')` +/// - `Signal('60分钟_D0EMA#8T30_BE辅助V230105_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod`:均线周期,默认 `5`; +/// - `th`:第二根 K 线相对均线的突破阈值,默认 `50` BP。 +/// 对齐说明:与 Python `czsc.signals.cxt_bi_end_V230105` 保持一致。 +#[signal( + category = "kline", + name = "cxt_bi_end_V230105", + template = "{freq}_D0{ma_type}#{timeperiod}T{th}_BE辅助V230105", + opcode = "CxtBiEndV230105", + param_kind = "CxtBiEndV230105" +)] +pub fn cxt_bi_end_v230105(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let th = get_usize_param(params, "th", 50) as f64; + let timeperiod = get_usize_param(params, "timeperiod", 5); + let ma_type = params.str("ma_type", "SMA").to_uppercase(); + let cache_key = format!("{}#{}", ma_type, timeperiod); + update_ma_cache(c, &cache_key, &ma_type, timeperiod, cache); + let k1 = c.freq.to_string(); + let k2 = format!("D0{}#{}T{}", ma_type, timeperiod, th as i32); + let k3 = "BE辅助V230105"; + let mut v1 = "其他"; + + if c.bi_list.len() < 3 || c.bars_ubi.len() > 7 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let Some(ma) = cache.series.get(&cache_key) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + let id_to_idx = bar_index_map(c); + let last_bi = c.bi_list.last().unwrap(); + let fx_raw = fx_raw_bars(&last_bi.fx_b); + if fx_raw.len() < 2 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let bar1 = &fx_raw[fx_raw.len() - 2]; + let bar2 = &fx_raw[fx_raw.len() - 1]; + let Some(&bar2_idx) = id_to_idx.get(&bar2.id) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + let Some(bar2_ma) = ma.get(bar2_idx).copied() else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + + let lc1 = last_bi.direction == Direction::Down && bar1.low == last_bi.get_low(); + let lc2 = bar1.close < bar1.open + && bar2.close > bar2_ma * (1.0 + th / 10000.0) + && bar2_ma * (1.0 + th / 10000.0) > bar2.open; + if c.bars_ubi.len() < 7 && lc1 && lc2 { + v1 = "看多"; + } + + let sc1 = last_bi.direction == Direction::Up && bar1.high == last_bi.get_high(); + let sc2 = bar1.close > bar1.open + && bar2.close < bar2_ma * (1.0 - th / 10000.0) + && bar2_ma * (1.0 - th / 10000.0) < bar2.open; + if c.bars_ubi.len() < 7 && sc1 && sc2 { + v1 = "看空"; + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// cxt_bi_end_V230224:量价配合笔结束辅助 +/// +/// 参数模板:`"{freq}_D1_BE辅助V230224"` +/// +/// 信号逻辑: +/// 1. 统计最后一笔整体均量与终点分型均量; +/// 2. 长上影且分型显著放量时判定 `看空`,长下影且分型显著缩量时判定 `看多`; +/// 3. 若笔或分型样本不足,或未完成笔过长,则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1_BE辅助V230224_看多_任意_任意_0')` +/// - `Signal('60分钟_D1_BE辅助V230224_看空_任意_任意_0')` +/// +/// 参数说明: +/// - 本信号无额外参数,`params` 可为空; +/// - 仅在 UBI 较短时使用量价关系辅助判断笔结束。 +/// 对齐说明:与 Python `czsc.signals.cxt_bi_end_V230224` 保持一致。 +#[signal( + category = "kline", + name = "cxt_bi_end_V230224", + template = "{freq}_D1_BE辅助V230224", + opcode = "CxtBiEndV230224", + param_kind = "CxtBiEndV230224" +)] +pub fn cxt_bi_end_v230224(c: &CZSC, _params: &ParamView, _cache: &mut TaCache) -> Vec { + let k1 = c.freq.to_string(); + let k2 = "D1"; + let k3 = "BE辅助V230224"; + let mut v1 = "其他"; + if c.bi_list.len() <= 3 || c.bars_ubi.len() >= 7 { + return make_kline_signal_v1(&k1, k2, k3, v1); + } + + let last_bi = c.bi_list.last().unwrap(); + let bi_bars = last_bi.get_raw_bars(); + let fx_bars = fx_raw_bars(&last_bi.fx_b); + if bi_bars.is_empty() || fx_bars.is_empty() { + return make_kline_signal_v1(&k1, k2, k3, v1); + } + + let bi_vol_mean = bi_bars.iter().map(|x| x.vol).sum::() / bi_bars.len() as f64; + let fx_vol_mean = fx_bars.iter().map(|x| x.vol).sum::() / fx_bars.len() as f64; + + let bar1 = fx_bars + .iter() + .skip(1) + .fold(&fx_bars[0], |acc, x| if x.low < acc.low { x } else { acc }); + let bar2 = fx_bars + .iter() + .skip(1) + .fold(&fx_bars[0], |acc, x| if x.high > acc.high { x } else { acc }); + + if raw_bar_upper(bar1) > raw_bar_lower(bar1) * 2.0 && fx_vol_mean > bi_vol_mean * 2.0 { + v1 = "看空"; + } + if 2.0 * raw_bar_upper(bar2) < raw_bar_lower(bar2) && fx_vol_mean < bi_vol_mean * 0.618 { + v1 = "看多"; + } + + make_kline_signal_v1(&k1, k2, k3, v1) +} + +/// cxt_bi_end_V230312:MACD辅助判断笔结束 +/// +/// 参数模板:`"{freq}_D0MACD{fastperiod}#{slowperiod}#{signalperiod}_BE辅助V230312"` +/// +/// 信号逻辑: +/// 1. 计算指定参数的 MACD,并读取最后一笔终点分型对应的首末原始 K 线; +/// 2. 向下笔若分型尾部 MACD 柱值高于分型起点,判定 `看多`;向上笔反向判定 `看空`; +/// 3. MACD 缓存、分型样本或边界条件不满足时返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D0MACD12#26#9_BE辅助V230312_看多_任意_任意_0')` +/// - `Signal('60分钟_D0MACD12#26#9_BE辅助V230312_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `fastperiod`:MACD 快线周期,默认 `12`; +/// - `slowperiod`:MACD 慢线周期,默认 `26`; +/// - `signalperiod`:信号线周期,默认 `9`。 +/// 对齐说明:与 Python `czsc.signals.cxt_bi_end_V230312` 保持一致。 +#[signal( + category = "kline", + name = "cxt_bi_end_V230312", + template = "{freq}_D0MACD{fastperiod}#{slowperiod}#{signalperiod}_BE辅助V230312", + opcode = "CxtBiEndV230312", + param_kind = "CxtBiEndV230312" +)] +pub fn cxt_bi_end_v230312(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let fastperiod = get_usize_param(params, "fastperiod", 12); + let slowperiod = get_usize_param(params, "slowperiod", 26); + let signalperiod = get_usize_param(params, "signalperiod", 9); + let k1 = c.freq.to_string(); + let k2 = format!("D0MACD{}#{}#{}", fastperiod, slowperiod, signalperiod); + let k3 = "BE辅助V230312"; + let mut v1 = "其他"; + let cache_key = format!("MACD{}#{}#{}", fastperiod, slowperiod, signalperiod); + update_macd_cache(c, &cache_key, fastperiod, slowperiod, signalperiod, cache); + + if c.bi_list.len() < 3 || c.bars_ubi.len() >= 7 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let last_bi = c.bi_list.last().unwrap(); + let fx_bars = fx_raw_bars(&last_bi.fx_b); + if fx_bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let Some(macd) = cache.macd.get(&cache_key) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + let id_to_idx = bar_index_map(c); + let mut snapshot_overrides: HashMap = HashMap::new(); + let macd1 = macd_snapshot_field_value( + c, + macd, + &id_to_idx, + fx_bars.last().unwrap(), + fastperiod, + slowperiod, + signalperiod, + MacdField::Macd, + &mut snapshot_overrides, + ) + .unwrap_or(f64::NAN); + let macd2 = macd_snapshot_field_value( + c, + macd, + &id_to_idx, + fx_bars.first().unwrap(), + fastperiod, + slowperiod, + signalperiod, + MacdField::Macd, + &mut snapshot_overrides, + ) + .unwrap_or(f64::NAN); + if !macd1.is_finite() || !macd2.is_finite() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + if last_bi.direction == Direction::Down && macd1 > macd2 { + v1 = "看多"; + } + if last_bi.direction == Direction::Up && macd1 < macd2 { + v1 = "看空"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// cxt_bi_end_V230324:笔结束分型均线突破 +/// +/// 参数模板:`"{freq}_D0{ma_type}#{timeperiod}均线突破_BE辅助V230324"` +/// +/// 信号逻辑: +/// 1. 计算指定均线,并提取最后一笔终点分型除最后一根之外的均线序列; +/// 2. 向上笔若上一根收盘跌破分型内最低均线,判定 `看空`;向下笔若上一根收盘突破分型内最高均线,判定 `看多`; +/// 3. 数据不足、UBI 过长或均线不可用时返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D0SMA#5均线突破_BE辅助V230324_看多_任意_任意_0')` +/// - `Signal('60分钟_D0EMA#13均线突破_BE辅助V230324_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod`:均线周期,默认 `5`。 +/// 对齐说明:与 Python `czsc.signals.cxt_bi_end_V230324` 保持一致。 +#[signal( + category = "kline", + name = "cxt_bi_end_V230324", + template = "{freq}_D0{ma_type}#{timeperiod}均线突破_BE辅助V230324", + opcode = "CxtBiEndV230324", + param_kind = "CxtBiEndV230324" +)] +pub fn cxt_bi_end_v230324(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let ma_type = params.str("ma_type", "SMA").to_uppercase(); + let timeperiod = get_usize_param(params, "timeperiod", 5); + let cache_key = format!("{}#{}", ma_type, timeperiod); + update_ma_cache(c, &cache_key, &ma_type, timeperiod, cache); + let k1 = c.freq.to_string(); + let k2 = format!("D0{}#{}均线突破", ma_type, timeperiod); + let k3 = "BE辅助V230324"; + let mut v1 = "其他"; + let ubi_fxs = c.get_ubi_fxs().unwrap_or_default(); + + if c.bi_list.len() < 3 || c.bars_ubi.len() > 7 || ubi_fxs.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + if c.bars_raw.len() < 2 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let Some(ma) = cache.series.get(&cache_key) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + let id_to_idx = bar_index_map(c); + let last_bi = c.bi_list.last().unwrap(); + let fx_raw = fx_raw_bars(&last_bi.fx_b); + if fx_raw.len() < 2 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let mut ma_vals: Vec = Vec::new(); + for rb in fx_raw.iter().take(fx_raw.len() - 1) { + if let Some(idx) = id_to_idx.get(&rb.id) { + if let Some(x) = ma.get(*idx) { + if x.is_finite() { + ma_vals.push(*x); + } + } + } + } + if ma_vals.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let max_ma = ma_vals.iter().copied().fold(f64::NEG_INFINITY, f64::max); + let min_ma = ma_vals.iter().copied().fold(f64::INFINITY, f64::min); + let last_close = c.bars_raw[c.bars_raw.len() - 2].close; + + if last_bi.direction == Direction::Up && last_close < min_ma { + v1 = "看空"; + } + if last_bi.direction == Direction::Down && last_close > max_ma { + v1 = "看多"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// cxt_bi_end_V230815:快速突破反向笔 +/// +/// 参数模板:`"{freq}_快速突破_BE辅助V230815"` +/// +/// 信号逻辑: +/// 1. 读取最后一笔和当前未完成笔最后一根 K 线; +/// 2. 向上笔若被最新低点快速跌破,输出 `向下`;向下笔若被最新高点快速突破,输出 `向上`; +/// 3. 笔数不足或 UBI 已延伸过长时返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_快速突破_BE辅助V230815_向上_任意_任意_0')` +/// - `Signal('60分钟_快速突破_BE辅助V230815_向下_任意_任意_0')` +/// +/// 参数说明: +/// - 本信号无额外参数,`params` 可为空; +/// - 仅用于很短的 UBI 场景,强调“快速突破”。 +/// 对齐说明:与 Python `czsc.signals.cxt_bi_end_V230815` 保持一致。 +#[signal( + category = "kline", + name = "cxt_bi_end_V230815", + template = "{freq}_快速突破_BE辅助V230815", + opcode = "CxtBiEndV230815", + param_kind = "CxtBiEndV230815" +)] +pub fn cxt_bi_end_v230815(c: &CZSC, _params: &ParamView, _cache: &mut TaCache) -> Vec { + let k1 = c.freq.to_string(); + let k2 = "快速突破"; + let k3 = "BE辅助V230815"; + let mut v1 = "其他"; + if c.bi_list.len() < 5 || c.bars_ubi.len() >= 5 { + return make_kline_signal_v1(&k1, k2, k3, v1); + } + let bi = c.bi_list.last().unwrap(); + let last_bar = c.bars_ubi.last().unwrap(); + if bi.direction == Direction::Up && last_bar.low < bi.get_low() { + v1 = "向下"; + } + if bi.direction == Direction::Down && last_bar.high > bi.get_high() { + v1 = "向上"; + } + make_kline_signal_v1(&k1, k2, k3, v1) +} + +/// cxt_bi_stop_V230815:笔止损距离状态 +/// +/// 参数模板:`"{freq}_距离{th}BP_止损V230815"` +/// +/// 信号逻辑: +/// 1. 读取最后一笔方向,并把其高低点作为止损基准; +/// 2. 向上场景比较最新收盘距笔高的回撤,向下场景比较最新收盘距笔低的反弹; +/// 3. 若落在 `th` BP 阈值内则标记 `阈值内`,否则标记 `阈值外`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_距离50BP_止损V230815_向上_阈值内_任意_0')` +/// - `Signal('60分钟_距离50BP_止损V230815_向下_阈值外_任意_0')` +/// +/// 参数说明: +/// - `th`:距离阈值,单位 BP,默认 `50`; +/// - 信号只读取最后一笔和当前 UBI,不做更长历史统计。 +/// 对齐说明:与 Python `czsc.signals.cxt_bi_stop_V230815` 保持一致。 +#[signal( + category = "kline", + name = "cxt_bi_stop_V230815", + template = "{freq}_距离{th}BP_止损V230815", + opcode = "CxtBiStopV230815", + param_kind = "CxtBiStopV230815" +)] +pub fn cxt_bi_stop_v230815(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let th = get_usize_param(params, "th", 50) as f64; + let k1 = c.freq.to_string(); + let k2 = format!("距离{}BP", th as i32); + let k3 = "止损V230815"; + let mut v1 = "其他"; + let mut v2 = "其他"; + if c.bi_list.len() < 5 || c.bars_ubi.is_empty() { + return make_kline_signal_v2(&k1, &k2, k3, v1, v2); + } + let bi = c.bi_list.last().unwrap(); + let last_bar = c.bars_ubi.last().unwrap(); + if bi.direction == Direction::Up { + v1 = "向下"; + v2 = if last_bar.close > bi.get_high() * (1.0 - th / 10000.0) { + "阈值内" + } else { + "阈值外" + }; + } + if bi.direction == Direction::Down { + v1 = "向上"; + v2 = if last_bar.close < bi.get_low() * (1.0 + th / 10000.0) { + "阈值内" + } else { + "阈值外" + }; + } + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// cxt_bi_trend_V230824:N笔形态判断 +/// +/// 参数模板:`"{freq}_D{di}N{n}TH{th}_形态V230824"` +/// +/// 信号逻辑: +/// 1. 取最近 `n` 笔的中位价格均值; +/// 2. 用首笔中位价格相对均值的偏离程度判断 `向上/向下/横盘`; +/// 3. 偏离阈值由 `th` 控制,数据不足时返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N4TH2_形态V230824_向上_任意_任意_0')` +/// - `Signal('60分钟_D1N4TH2_形态V230824_横盘_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始取样,默认 `1`; +/// - `n`:参与比较的笔数,默认 `4`; +/// - `th`:相对均值的偏离阈值,默认 `2`。 +/// 对齐说明:与 Python `czsc.signals.cxt_bi_trend_V230824` 保持一致。 +#[signal( + category = "kline", + name = "cxt_bi_trend_V230824", + template = "{freq}_D{di}N{n}TH{th}_形态V230824", + opcode = "CxtBiTrendV230824", + param_kind = "CxtBiTrendV230824" +)] +pub fn cxt_bi_trend_v230824(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let n = get_usize_param(params, "n", 4); + let th = get_usize_param(params, "th", 2) as f64; + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}TH{}", di, n, th as i32); + let k3 = "形态V230824"; + let mut v1 = "其他"; + if c.bi_list.len() < di + n + 2 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let bis = get_sub_elements(&c.bi_list, di, n); + if bis.len() != n { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let means: Vec = bis.iter().map(|bi| (bi.get_low() + bi.get_high()) / 2.0).collect(); + let avg = means.iter().sum::() / n as f64; + if !avg.is_finite() || avg == 0.0 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let ratio = means[0] / avg; + if ratio * 100.0 > 100.0 + th { + v1 = "向下"; + } else if ratio * 100.0 < 100.0 - th { + v1 = "向上"; + } else { + v1 = "横盘"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// cxt_bi_zdf_V230601:BI涨跌幅分层 +/// +/// 参数模板:`"{freq}_D{di}N{n}_分层V230601"` +/// +/// 信号逻辑: +/// 1. 取最近最多 50 笔的力度序列; +/// 2. 读取最新笔方向作为 `v1`; +/// 3. 用 `qcut_last_label` 将最新力度分到 `n` 层中的某一层,输出 `第X层`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N5_分层V230601_向上_第3层_任意_0')` +/// - `Signal('60分钟_D1N5_分层V230601_向下_第1层_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始统计,默认 `1`; +/// - `n`:分层数量,默认 `5`。 +/// 对齐说明:与 Python `czsc.signals.cxt_bi_zdf_V230601` 保持一致。 +#[signal( + category = "kline", + name = "cxt_bi_zdf_V230601", + template = "{freq}_D{di}N{n}_分层V230601", + opcode = "CxtBiZdfV230601", + param_kind = "CxtBiZdfV230601" +)] +pub fn cxt_bi_zdf_v230601(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let n = get_usize_param(params, "n", 5); + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}", di, n); + let k3 = "分层V230601"; + if c.bi_list.len() < 10 || c.bars_ubi.len() > 7 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bis = get_sub_elements(&c.bi_list, di, 50); + if bis.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let v1 = bis.last().unwrap().direction.to_string(); + let powers: Vec = bis.iter().map(|x| x.get_power()).collect(); + let v2 = qcut_last_label(&powers, n) + .map(|layer| format!("第{}层", layer + 1)) + .unwrap_or_else(|| "其他".to_string()); + make_kline_signal_v2(&k1, &k2, k3, &v1, &v2) +} + +/// cxt_second_bs_V230320:均线辅助识别第二类买卖点 +/// +/// 参数模板:`"{freq}_D{di}#{ma_type}#{timeperiod}_BS2辅助V230320"` +/// +/// 信号逻辑: +/// 1. 取最近 5 笔,并计算关键分型右侧原始 K 线的均线值; +/// 2. 若前两次同向回撤/反弹已偏离均线,而第 5 笔重新回到均线同向,判定 `二买/二卖`; +/// 3. 均线、分型样本或笔数量不足时返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1#SMA#21_BS2辅助V230320_二买_任意_任意_0')` +/// - `Signal('60分钟_D1#EMA#34_BS2辅助V230320_二卖_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始取样,默认 `1`; +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod`:均线周期,默认 `21`。 +/// 对齐说明:与 Python `czsc.signals.cxt_second_bs_V230320` 保持一致。 +#[signal( + category = "kline", + name = "cxt_second_bs_V230320", + template = "{freq}_D{di}#{ma_type}#{timeperiod}_BS2辅助V230320", + opcode = "CxtSecondBsV230320", + param_kind = "CxtSecondBsV230320" +)] +pub fn cxt_second_bs_v230320(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let timeperiod = get_usize_param(params, "timeperiod", 21); + let ma_type = params.str("ma_type", "SMA").to_uppercase(); + let cache_key = format!("{}#{}", ma_type, timeperiod); + update_ma_cache(c, &cache_key, &ma_type, timeperiod, cache); + + let k1 = c.freq.to_string(); + let k2 = format!("D{}#{}#{}", di, ma_type, timeperiod); + let k3 = "BS2辅助V230320"; + let mut v1 = "其他"; + if c.bi_list.len() < di + 6 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let bis = get_sub_elements(&c.bi_list, di, 5); + if bis.len() != 5 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let b1 = &bis[0]; + let b3 = &bis[2]; + let b5 = &bis[4]; + let b1_fx_bars = fx_raw_bars(&b1.fx_b); + let b3_fx_bars = fx_raw_bars(&b3.fx_b); + let b5_fx_a_bars = fx_raw_bars(&b5.fx_a); + let b5_fx_b_bars = fx_raw_bars(&b5.fx_b); + if b1_fx_bars.len() < 2 || b3_fx_bars.len() < 2 || b5_fx_a_bars.len() < 2 || b5_fx_b_bars.len() < 2 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let Some(ma) = cache.series.get(&cache_key) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + let id_to_idx = bar_index_map(c); + + let get_ma = |bar_id: i32| -> Option { + let idx = *id_to_idx.get(&bar_id)?; + ma.get(idx).copied() + }; + let Some(b1_ma_b) = get_ma(b1_fx_bars[b1_fx_bars.len() - 2].id) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + let Some(b3_ma_b) = get_ma(b3_fx_bars[b3_fx_bars.len() - 2].id) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + let Some(b5_ma_a) = get_ma(b5_fx_a_bars[b5_fx_a_bars.len() - 2].id) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + let Some(b5_ma_b) = get_ma(b5_fx_b_bars[b5_fx_b_bars.len() - 2].id) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + + let lc1 = b1.get_low() < b1_ma_b && b3.get_low() < b3_ma_b; + if b5.direction == Direction::Down && lc1 && b5_ma_a < b5_ma_b { + v1 = "二买"; + } + + let sc1 = b1.get_high() > b1_ma_b && b3.get_high() > b3_ma_b; + if b5.direction == Direction::Up && sc1 && b5_ma_a > b5_ma_b { + v1 = "二卖"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// cxt_third_bs_V230318:均线辅助识别第三类买卖点 +/// +/// 参数模板:`"{freq}_D{di}#{ma_type}#{timeperiod}_BS3辅助V230318"` +/// +/// 信号逻辑: +/// 1. 取最近 5 笔构造中枢,并计算第 1、3、5 笔终点分型的均线; +/// 2. 若第 5 笔离开中枢,且三次均线值同向抬升或下降,则判定 `三买/三卖`; +/// 3. 中枢无效、均线缺失或笔数不足时返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1#SMA#34_BS3辅助V230318_三买_任意_任意_0')` +/// - `Signal('60分钟_D1#EMA#34_BS3辅助V230318_三卖_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始取样,默认 `1`; +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod`:均线周期,默认 `34`。 +/// 对齐说明:与 Python `czsc.signals.cxt_third_bs_V230318` 保持一致。 +#[signal( + category = "kline", + name = "cxt_third_bs_V230318", + template = "{freq}_D{di}#{ma_type}#{timeperiod}_BS3辅助V230318", + opcode = "CxtThirdBsV230318", + param_kind = "CxtThirdBsV230318" +)] +pub fn cxt_third_bs_v230318(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let timeperiod = get_usize_param(params, "timeperiod", 34); + let ma_type = params.str("ma_type", "SMA").to_uppercase(); + let cache_key = format!("{}#{}", ma_type, timeperiod); + update_ma_cache(c, &cache_key, &ma_type, timeperiod, cache); + let k1 = c.freq.to_string(); + let k2 = format!("D{}#{}#{}", di, ma_type, timeperiod); + let k3 = "BS3辅助V230318"; + let mut v1 = "其他"; + + if c.bi_list.len() < di + 6 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let bis = get_sub_elements(&c.bi_list, di, 5); + if bis.len() != 5 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let b1 = &bis[0]; + let b3 = &bis[2]; + let b5 = &bis[4]; + let zs_zd = b1.get_low().max(b3.get_low()); + let zs_zg = b1.get_high().min(b3.get_high()); + if zs_zd > zs_zg { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let Some(ma) = cache.series.get(&cache_key) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + let id_to_idx = bar_index_map(c); + let b1_fx = fx_raw_bars(&b1.fx_b); + let b3_fx = fx_raw_bars(&b3.fx_b); + let b5_fx = fx_raw_bars(&b5.fx_b); + if b1_fx.is_empty() || b3_fx.is_empty() || b5_fx.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let mut snapshot_overrides: HashMap = HashMap::new(); + let Some(ma_1) = ma_snapshot_value( + c, + ma, + &id_to_idx, + &b1_fx[b1_fx.len() - 1], + &ma_type, + timeperiod, + &mut snapshot_overrides, + ) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + let Some(ma_3) = ma_snapshot_value( + c, + ma, + &id_to_idx, + &b3_fx[b3_fx.len() - 1], + &ma_type, + timeperiod, + &mut snapshot_overrides, + ) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + let Some(ma_5) = ma_snapshot_value( + c, + ma, + &id_to_idx, + &b5_fx[b5_fx.len() - 1], + &ma_type, + timeperiod, + &mut snapshot_overrides, + ) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + + if b5.direction == Direction::Down && b5.get_low() > zs_zg && ma_5 > ma_3 && ma_3 > ma_1 { + v1 = "三买"; + } + if b5.direction == Direction::Up && b5.get_high() < zs_zd && ma_5 < ma_3 && ma_3 < ma_1 { + v1 = "三卖"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// cxt_double_zs_V230311:双中枢 BS1 辅助 +/// +/// 参数模板:`"{freq}_D{di}双中枢_BS1辅助V230311"` +/// +/// 信号逻辑: +/// 1. 提取最近 20 笔并重建中枢序列; +/// 2. 若最近两个中枢都有效,比较后一中枢内部两笔的时长与前后中枢极值关系; +/// 3. 向下笔满足衰竭条件判定 `看多`,向上笔满足衰竭条件判定 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1双中枢_BS1辅助V230311_看多_任意_任意_0')` +/// - `Signal('60分钟_D1双中枢_BS1辅助V230311_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始取样,默认 `1`; +/// - 需要至少形成两个有效中枢,否则返回 `其他`。 +/// 对齐说明:与 Python `czsc.signals.cxt_double_zs_V230311` 保持一致。 +#[signal( + category = "kline", + name = "cxt_double_zs_V230311", + template = "{freq}_D{di}双中枢_BS1辅助V230311", + opcode = "CxtDoubleZsV230311", + param_kind = "CxtDoubleZsV230311" +)] +pub fn cxt_double_zs_v230311(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}双中枢", di); + let k3 = "BS1辅助V230311"; + let mut v1 = "其他"; + let bis = get_sub_elements(&c.bi_list, di, 20); + if bis.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let zss = get_zs_seq(bis); + if zss.len() >= 2 && zss[zss.len() - 2].bis.len() >= 2 && zss[zss.len() - 1].bis.len() >= 2 { + let zs1 = &zss[zss.len() - 2]; + let zs2 = &zss[zss.len() - 1]; + let ts1 = zs2.bis[zs2.bis.len() - 1].bars.len(); + let ts2 = zs2.bis[zs2.bis.len() - 2].bars.len(); + let last_bi = bis.last().unwrap(); + if last_bi.direction == Direction::Down && ts1 >= ts2 * 2 && zs1.gg > zs2.gg { + v1 = "看多"; + } + if last_bi.direction == Direction::Up && ts1 >= ts2 * 2 && zs1.dd < zs2.dd { + v1 = "看空"; + } + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// cxt_overlap_V240526:收盘价与最近分型区间重合次数 +/// +/// 参数模板:`"{freq}_顶底重合_支撑压力V240526"` +/// +/// 信号逻辑: +/// 1. 取最近 9 笔,读取最新收盘价; +/// 2. 分别统计收盘价落在向上笔顶分型区间和向下笔底分型区间中的次数; +/// 3. 输出 `顶重合X次` 与 `底重合Y次`,用于支撑压力观察。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_顶底重合_支撑压力V240526_顶重合2次_底重合1次_任意_0')` +/// - `Signal('60分钟_顶底重合_支撑压力V240526_顶重合0次_底重合3次_任意_0')` +/// +/// 参数说明: +/// - 本信号无额外参数,`params` 可为空; +/// - 至少要求 11 笔与非空原始 K 线序列。 +/// 对齐说明:与 Python `czsc.signals.cxt_overlap_V240526` 保持一致。 +#[signal( + category = "kline", + name = "cxt_overlap_V240526", + template = "{freq}_顶底重合_支撑压力V240526", + opcode = "CxtOverlapV240526", + param_kind = "CxtOverlapV240526" +)] +pub fn cxt_overlap_v240526(c: &CZSC, _params: &ParamView, _cache: &mut TaCache) -> Vec { + let k1 = c.freq.to_string(); + let k2 = "顶底重合"; + let k3 = "支撑压力V240526"; + if c.bi_list.len() < 11 || c.bars_raw.is_empty() { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + let bis = get_sub_elements(&c.bi_list, 1, 9); + if bis.is_empty() { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + let last_close = c.bars_raw.last().unwrap().close; + let overlap_count_g = bis + .iter() + .filter(|x| x.direction == Direction::Up) + .filter(|x| x.fx_b.low <= last_close && last_close <= x.fx_b.high) + .count(); + let overlap_count_d = bis + .iter() + .filter(|x| x.direction == Direction::Down) + .filter(|x| x.fx_b.low <= last_close && last_close <= x.fx_b.high) + .count(); + let v1 = format!("顶重合{}次", overlap_count_g); + let v2 = format!("底重合{}次", overlap_count_d); + make_kline_signal_v2(&k1, k2, k3, &v1, &v2) +} + +/// cxt_decision_V240526:分型区域决策 +/// +/// 参数模板:`"{freq}_分型区域N{n}_决策区域V240526"` +/// +/// 信号逻辑: +/// 1. 在最近 100 根 K 线中提取离散价格层; +/// 2. 若最后一笔向上,统计最新收盘到顶分型上沿之间的价位层数,层数不多于 `n` 时判定 `开空`;向下笔反向判定 `开多`; +/// 3. 否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_分型区域N9_决策区域V240526_开多_任意_任意_0')` +/// - `Signal('60分钟_分型区域N9_决策区域V240526_开空_任意_任意_0')` +/// +/// 参数说明: +/// - `n`:允许的价位层数量阈值,默认 `9`; +/// - 至少要求 120 根原始 K 线和一笔已完成笔。 +/// 对齐说明:与 Python `czsc.signals.cxt_decision_V240526` 保持一致。 +#[signal( + category = "kline", + name = "cxt_decision_V240526", + template = "{freq}_分型区域N{n}_决策区域V240526", + opcode = "CxtDecisionV240526", + param_kind = "CxtDecisionV240526" +)] +pub fn cxt_decision_v240526(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let n = get_usize_param(params, "n", 9); + let k1 = c.freq.to_string(); + let k2 = format!("分型区域N{}", n); + let k3 = "决策区域V240526"; + let mut v1 = "其他"; + if c.bars_raw.len() < 120 || c.bi_list.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let bars = get_sub_elements(&c.bars_raw, 1, 100); + if bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let prices = unique_prices_from_bars(bars); + let bi = c.bi_list.last().unwrap(); + let bar = c.bars_raw.last().unwrap(); + if bi.direction == Direction::Up { + let in_count = prices + .iter() + .filter(|&&x| bar.close <= x && x <= bi.fx_b.high) + .count(); + if in_count <= n { + v1 = "开空"; + } + } else if bi.direction == Direction::Down { + let in_count = prices + .iter() + .filter(|&&x| bi.fx_b.low <= x && x <= bar.close) + .count(); + if in_count <= n { + v1 = "开多"; + } + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// cxt_decision_V240612:高低点N档决策区间 +/// +/// 参数模板:`"{freq}_W{w}N{n}高低点_决策区域V240612"` +/// +/// 信号逻辑: +/// 1. 用最近 100 根 K 线生成离散价格层,再用最近 `w` 根 K 线确定高低点; +/// 2. 在低点上方和高点下方各取第 `n` 档价格,形成低区和高区阈值; +/// 3. 最新收盘落入低区判定 `开多`,落入高区判定 `开空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_W10N9高低点_决策区域V240612_开多_任意_任意_0')` +/// - `Signal('60分钟_W10N9高低点_决策区域V240612_开空_任意_任意_0')` +/// +/// 参数说明: +/// - `w`:最近高低点统计窗口,默认 `10`; +/// - `n`:从高低点向内取第 `n` 档价格,默认 `9`。 +/// 对齐说明:与 Python `czsc.signals.cxt_decision_V240612` 保持一致。 +#[signal( + category = "kline", + name = "cxt_decision_V240612", + template = "{freq}_W{w}N{n}高低点_决策区域V240612", + opcode = "CxtDecisionV240612", + param_kind = "CxtDecisionV240612" +)] +pub fn cxt_decision_v240612(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let w = get_usize_param(params, "w", 10); + let n = get_usize_param(params, "n", 9); + let k1 = c.freq.to_string(); + let k2 = format!("W{}N{}高低点", w, n); + let k3 = "决策区域V240612"; + let mut v1 = "其他"; + if c.bars_raw.len() < 120 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let bars = get_sub_elements(&c.bars_raw, 1, 100); + let prices = unique_prices_from_bars(bars); + if prices.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let w_bars = get_sub_elements(&c.bars_raw, 1, w); + if w_bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let max_high = w_bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let min_low = w_bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let last_bar = c.bars_raw.last().unwrap(); + + let mut min_low_upper: Vec = prices.iter().copied().filter(|x| *x >= min_low).collect(); + min_low_upper.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + let low_range = if min_low_upper.len() > n { + min_low_upper[n] + } else { + *min_low_upper.last().unwrap() + }; + + let mut max_high_lower: Vec = prices.iter().copied().filter(|x| *x <= max_high).collect(); + max_high_lower.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal)); + let high_range = if max_high_lower.len() > n { + max_high_lower[n] + } else { + *max_high_lower.last().unwrap() + }; + + if last_bar.close < low_range && last_bar.low != min_low { + v1 = "开多"; + } + if last_bar.close > high_range && last_bar.high != max_high { + v1 = "开空"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// cxt_decision_V240613:放量笔N4BS2决策区 +/// +/// 参数模板:`"{freq}_放量笔N{n}BS2_决策区域V240613"` +/// +/// 信号逻辑: +/// 1. 取最近 `n` 笔并定位成交量最大的最后一笔; +/// 2. 若该笔向下但未创新低,判定 `开多`;若向上但未创新高,判定 `开空`; +/// 3. 只有最后一笔同时满足“放量且不是极值笔”时才触发。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_放量笔N4BS2_决策区域V240613_开多_任意_任意_0')` +/// - `Signal('60分钟_放量笔N4BS2_决策区域V240613_开空_任意_任意_0')` +/// +/// 参数说明: +/// - `n`:比较最近 `n` 笔的放量程度,默认 `4`; +/// - 仅在 UBI 不超过 7 时使用该决策信号。 +/// 对齐说明:与 Python `czsc.signals.cxt_decision_V240613` 保持一致。 +#[signal( + category = "kline", + name = "cxt_decision_V240613", + template = "{freq}_放量笔N{n}BS2_决策区域V240613", + opcode = "CxtDecisionV240613", + param_kind = "CxtDecisionV240613" +)] +pub fn cxt_decision_v240613(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let n = get_usize_param(params, "n", 4); + let k1 = c.freq.to_string(); + let k2 = format!("放量笔N{}BS2", n); + let k3 = "决策区域V240613"; + let mut v1 = "其他"; + if c.bi_list.len() < n + 2 || c.bars_ubi.len() > 7 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let bis = get_sub_elements(&c.bi_list, 1, n); + if bis.len() < n { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let bis_max_vol = bis + .iter() + .map(|x| x.get_power_volume()) + .fold(f64::NEG_INFINITY, f64::max); + let last_bi = bis.last().unwrap(); + if last_bi.get_power_volume() != bis_max_vol { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let min_low = bis.iter().map(|x| x.get_low()).fold(f64::INFINITY, f64::min); + let max_high = bis + .iter() + .map(|x| x.get_high()) + .fold(f64::NEG_INFINITY, f64::max); + + if last_bi.direction == Direction::Down && last_bi.get_low() != min_low { + v1 = "开多"; + } + if last_bi.direction == Direction::Up && last_bi.get_high() != max_high { + v1 = "开空"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// cxt_decision_V240614:放量新高/新低决策区 +/// +/// 参数模板:`"{freq}_放量笔N{n}_决策区域V240614"` +/// +/// 信号逻辑: +/// 1. 取最近 `n` 笔并定位成交量最大的最后一笔; +/// 2. 若该笔向下且同时创新低,判定 `开多`;若向上且同时创新高,判定 `开空`; +/// 3. 用于识别放量突破后的反向决策区域。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_放量笔N4_决策区域V240614_开多_任意_任意_0')` +/// - `Signal('60分钟_放量笔N4_决策区域V240614_开空_任意_任意_0')` +/// +/// 参数说明: +/// - `n`:比较最近 `n` 笔的放量程度,默认 `4`; +/// - 需要最后一笔既是放量笔,又是最近 `n` 笔的新高或新低笔。 +/// 对齐说明:与 Python `czsc.signals.cxt_decision_V240614` 保持一致。 +#[signal( + category = "kline", + name = "cxt_decision_V240614", + template = "{freq}_放量笔N{n}_决策区域V240614", + opcode = "CxtDecisionV240614", + param_kind = "CxtDecisionV240614" +)] +pub fn cxt_decision_v240614(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let n = get_usize_param(params, "n", 4); + let k1 = c.freq.to_string(); + let k2 = format!("放量笔N{}", n); + let k3 = "决策区域V240614"; + let mut v1 = "其他"; + if c.bi_list.len() < n + 2 || c.bars_ubi.len() > 7 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let bis = get_sub_elements(&c.bi_list, 1, n); + if bis.len() < n { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let bis_max_vol = bis + .iter() + .map(|x| x.get_power_volume()) + .fold(f64::NEG_INFINITY, f64::max); + let last_bi = bis.last().unwrap(); + if last_bi.get_power_volume() != bis_max_vol { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let min_low = bis.iter().map(|x| x.get_low()).fold(f64::INFINITY, f64::min); + let max_high = bis + .iter() + .map(|x| x.get_high()) + .fold(f64::NEG_INFINITY, f64::max); + + if last_bi.direction == Direction::Down && last_bi.get_low() == min_low { + v1 = "开多"; + } + if last_bi.direction == Direction::Up && last_bi.get_high() == max_high { + v1 = "开空"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// cxt_bs_V240526:趋势跟随 BS 辅助 +/// +/// 参数模板:`"{freq}_趋势跟随_BS辅助V240526"` +/// +/// 信号逻辑: +/// 1. 读取最近 7 笔,要求倒数第二笔具备高 SNR、强价格力度、强成交量或斜率特征; +/// 2. 再比较最后一笔相对前一强势笔的价格力度区间; +/// 3. 满足小回撤条件时输出 `买点/卖点`,否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_趋势跟随_BS辅助V240526_买点_任意_任意_0')` +/// - `Signal('60分钟_趋势跟随_BS辅助V240526_卖点_任意_任意_0')` +/// +/// 参数说明: +/// - 本信号无额外参数,`params` 可为空; +/// - 重点观察倒数第二笔是否是“顺畅强趋势笔”。 +/// 对齐说明:与 Python `czsc.signals.cxt_bs_V240526` 保持一致。 +#[signal( + category = "kline", + name = "cxt_bs_V240526", + template = "{freq}_趋势跟随_BS辅助V240526", + opcode = "CxtBsV240526", + param_kind = "CxtBsV240526" +)] +pub fn cxt_bs_v240526(c: &CZSC, _params: &ParamView, _cache: &mut TaCache) -> Vec { + let k1 = c.freq.to_string(); + let k2 = "趋势跟随"; + let k3 = "BS辅助V240526"; + let mut v1 = "其他"; + if c.bi_list.len() < 11 { + return make_kline_signal_v1(&k1, k2, k3, v1); + } + let bis = get_sub_elements(&c.bi_list, 1, 7); + if bis.len() < 7 { + return make_kline_signal_v1(&k1, k2, k3, v1); + } + let b2 = &bis[bis.len() - 2]; + let b1 = &bis[bis.len() - 1]; + let max_power_price = bis + .iter() + .map(|x| x.get_power_price()) + .fold(f64::NEG_INFINITY, f64::max); + let max_power_volume = bis + .iter() + .map(|x| x.get_power_volume()) + .fold(f64::NEG_INFINITY, f64::max); + let max_slope_abs = bis + .iter() + .map(|x| x.get_slope().abs()) + .fold(f64::NEG_INFINITY, f64::max); + + if b2.get_snr() < 0.7 + || (b2.get_power_price() < max_power_price + && b2.get_power_volume() < max_power_volume + && b2.get_slope() < max_slope_abs) + { + return make_kline_signal_v1(&k1, k2, k3, v1); + } + + if b2.direction == Direction::Up + && b1.direction == Direction::Down + && 0.1 * b2.get_power_price() < b1.get_power_price() + && b1.get_power_price() < 0.7 * b2.get_power_price() + { + v1 = "买点"; + } + if b2.direction == Direction::Down + && b1.direction == Direction::Up + && 0.2 * b2.get_power_price() < b1.get_power_price() + && b1.get_power_price() < 0.7 * b2.get_power_price() + { + v1 = "卖点"; + } + make_kline_signal_v1(&k1, k2, k3, v1) +} + +/// cxt_bs_V240527:未完成笔上的趋势跟随 BS 辅助 +/// +/// 参数模板:`"{freq}_趋势跟随_BS辅助V240527"` +/// +/// 信号逻辑: +/// 1. 读取最近 7 笔,要求最后一笔本身是高 SNR 的强趋势笔; +/// 2. 再读取当前 UBI 原始 K 线,比较其价格力度相对最后一笔的回撤比例; +/// 3. 满足小回撤条件时输出 `买点/卖点`,否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_趋势跟随_BS辅助V240527_买点_任意_任意_0')` +/// - `Signal('60分钟_趋势跟随_BS辅助V240527_卖点_任意_任意_0')` +/// +/// 参数说明: +/// - 本信号无额外参数,`params` 可为空; +/// - 与 `V240526` 的区别在于这里评估的是未完成笔上的回撤。 +/// 对齐说明:与 Python `czsc.signals.cxt_bs_V240527` 保持一致。 +#[signal( + category = "kline", + name = "cxt_bs_V240527", + template = "{freq}_趋势跟随_BS辅助V240527", + opcode = "CxtBsV240527", + param_kind = "CxtBsV240527" +)] +pub fn cxt_bs_v240527(c: &CZSC, _params: &ParamView, _cache: &mut TaCache) -> Vec { + let k1 = c.freq.to_string(); + let k2 = "趋势跟随"; + let k3 = "BS辅助V240527"; + let mut v1 = "其他"; + if c.bi_list.len() < 11 { + return make_kline_signal_v1(&k1, k2, k3, v1); + } + let bis = get_sub_elements(&c.bi_list, 1, 7); + if bis.len() < 7 { + return make_kline_signal_v1(&k1, k2, k3, v1); + } + let b1 = &bis[bis.len() - 1]; + let max_power_price = bis + .iter() + .map(|x| x.get_power_price()) + .fold(f64::NEG_INFINITY, f64::max); + let max_power_volume = bis + .iter() + .map(|x| x.get_power_volume()) + .fold(f64::NEG_INFINITY, f64::max); + let max_slope_abs = bis + .iter() + .map(|x| x.get_slope().abs()) + .fold(f64::NEG_INFINITY, f64::max); + + if b1.get_snr() < 0.7 + || (b1.get_power_price() < max_power_price + && b1.get_power_volume() < max_power_volume + && b1.get_slope() < max_slope_abs) + { + return make_kline_signal_v1(&k1, k2, k3, v1); + } + + let ubi_bars = ubi_raw_bars(c); + if ubi_bars.len() < 7 { + return make_kline_signal_v1(&k1, k2, k3, v1); + } + let ubi_high = ubi_bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let ubi_low = ubi_bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let ubi_power_price = ubi_high - ubi_low; + + if b1.direction == Direction::Up + && 0.1 * b1.get_power_price() < ubi_power_price + && ubi_power_price < 0.7 * b1.get_power_price() + { + v1 = "买点"; + } + if b1.direction == Direction::Down + && 0.2 * b1.get_power_price() < ubi_power_price + && ubi_power_price < 0.7 * b1.get_power_price() + { + v1 = "卖点"; + } + make_kline_signal_v1(&k1, k2, k3, v1) +} + +/// cxt_first_buy_V221126:一买信号 +/// +/// 参数模板:`"{freq}_D{di}B_BUY1V221126"` +/// +/// 信号逻辑: +/// 1. 依次尝试最近 `21/19/17/15/13/11/9/7/5` 笔; +/// 2. 调用统一的 `check_first_buy` 结构判定函数识别一买; +/// 3. 命中后输出对应笔数,否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1B_BUY1_一买_5笔_任意_0')` +/// - `Signal('60分钟_D1B_BUY1_一买_13笔_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始取样,默认 `1`; +/// - 一买结构判定复用 Python 同名逻辑。 +/// 对齐说明:与 Python `czsc.signals.cxt_first_buy_V221126` 保持一致。 +#[signal( + category = "kline", + name = "cxt_first_buy_V221126", + template = "{freq}_D{di}B_BUY1V221126", + opcode = "CxtFirstBuyV221126", + param_kind = "CxtFirstBuyV221126" +)] +pub fn cxt_first_buy_v221126(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}B", di); + let k3 = "BUY1"; + for n in [21, 19, 17, 15, 13, 11, 9, 7, 5] { + let bis = get_sub_elements(&c.bi_list, di, n); + if bis.len() == n && check_first_buy(bis) { + return make_kline_signal_v2(&k1, &k2, k3, "一买", &format!("{}笔", n)); + } + } + make_kline_signal_v1(&k1, &k2, k3, "其他") +} + +/// cxt_first_sell_V221126:一卖信号 +/// +/// 参数模板:`"{freq}_D{di}B_SELL1V221126"` +/// +/// 信号逻辑: +/// 1. 依次尝试最近 `21/19/17/15/13/11/9/7/5` 笔; +/// 2. 调用统一的 `check_first_sell` 结构判定函数识别一卖; +/// 3. 命中后输出对应笔数,否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1B_SELL1_一卖_5笔_任意_0')` +/// - `Signal('60分钟_D1B_SELL1_一卖_13笔_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始取样,默认 `1`; +/// - 一卖结构判定复用 Python 同名逻辑。 +/// 对齐说明:与 Python `czsc.signals.cxt_first_sell_V221126` 保持一致。 +#[signal( + category = "kline", + name = "cxt_first_sell_V221126", + template = "{freq}_D{di}B_SELL1V221126", + opcode = "CxtFirstSellV221126", + param_kind = "CxtFirstSellV221126" +)] +pub fn cxt_first_sell_v221126(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}B", di); + let k3 = "SELL1"; + for n in [21, 19, 17, 15, 13, 11, 9, 7, 5] { + let bis = get_sub_elements(&c.bi_list, di, n); + if bis.len() == n && check_first_sell(bis) { + return make_kline_signal_v2(&k1, &k2, k3, "一卖", &format!("{}笔", n)); + } + } + make_kline_signal_v1(&k1, &k2, k3, "其他") +} + +/// cxt_bi_end_V230222:未完成笔分型新高新低次数 +/// +/// 参数模板:`"{freq}_D1MO{max_overlap}_BE辅助V230222"` +/// +/// 信号逻辑: +/// 1. 拼接最后一笔内部已确认分型与当前 UBI 分型序列; +/// 2. 仅当最新分型刚确认,或距最新原始 K 线不超过 `max_overlap` 根时继续判断; +/// 3. 若最新顶分型创序列新高,输出 `新高_第X次`;若底分型创新低,输出 `新低_第X次`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1MO3_BE辅助V230222_新高_第2次_任意_0')` +/// - `Signal('60分钟_D1MO3_BE辅助V230222_新低_第1次_任意_0')` +/// +/// 参数说明: +/// - `max_overlap`:允许最新分型与当前原始 K 线的最大重叠根数,默认 `3`; +/// - 超出确认时机或分型不足时返回 `其他_其他`。 +/// 对齐说明:与 Python `czsc.signals.cxt_bi_end_V230222` 保持一致。 +#[signal( + category = "kline", + name = "cxt_bi_end_V230222", + template = "{freq}_D1MO{max_overlap}_BE辅助V230222", + opcode = "CxtBiEndV230222", + param_kind = "CxtBiEndV230222" +)] +pub fn cxt_bi_end_v230222(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let max_overlap = get_usize_param(params, "max_overlap", 3); + let k1 = c.freq.to_string(); + let k2 = format!("D1MO{}", max_overlap); + let k3 = "BE辅助V230222"; + let ubi_fxs = c.get_ubi_fxs().unwrap_or_default(); + if ubi_fxs.is_empty() || c.bars_ubi.len() >= 7 { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + } + + let mut fxs: Vec = Vec::new(); + if let Some(last_bi) = c.bi_list.last() { + if last_bi.fxs.len() > 1 { + fxs.extend_from_slice(&last_bi.fxs[1..]); + } + } + for x in ubi_fxs { + if fxs.last().map(|y| x.dt > y.dt).unwrap_or(true) { + fxs.push(x); + } + } + if fxs.is_empty() { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "第其他次"); + } + let last_fx = fxs.last().unwrap(); + let last_fx_raw = fx_raw_bars(last_fx); + if last_fx_raw.is_empty() { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "第其他次"); + } + if !(last_fx.elements.last().unwrap().dt == c.bars_ubi.last().unwrap().dt + || (c.bars_raw.last().unwrap().id - last_fx_raw.last().unwrap().id) as usize <= max_overlap) + { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "第其他次"); + } + + if last_fx.mark == Mark::G { + let up: Vec<&FX> = fxs.iter().filter(|x| x.mark == Mark::G).collect(); + let mut high_max = f64::NEG_INFINITY; + let mut cnt = 0; + for fx in up { + if fx.high > high_max { + cnt += 1; + high_max = fx.high; + } + } + if last_fx.high == high_max { + return make_kline_signal_v2(&k1, &k2, k3, "新高", &format!("第{}次", cnt)); + } + } else { + let down: Vec<&FX> = fxs.iter().filter(|x| x.mark == Mark::D).collect(); + let mut low_min = f64::INFINITY; + let mut cnt = 0; + for fx in down { + if fx.low < low_min { + cnt += 1; + low_min = fx.low; + } + } + if last_fx.low == low_min { + return make_kline_signal_v2(&k1, &k2, k3, "新低", &format!("第{}次", cnt)); + } + } + make_kline_signal_v2(&k1, &k2, k3, "其他", "第其他次") +} + +/// cxt_third_buy_V230228:笔三买辅助 +/// +/// 参数模板:`"{freq}_D{di}_三买辅助V230228"` +/// +/// 信号逻辑: +/// 1. 依次尝试最近 `13/11/9/7/5` 笔加末笔,共 `n + 1` 笔; +/// 2. 从奇数位上升关键笔中提取突破结构,要求末笔低点在关键高点上方并满足价格约束; +/// 3. 满足条件时输出 `三买_XX笔`,否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1_三买辅助V230228_三买_6笔_任意_0')` +/// - `Signal('60分钟_D1_三买辅助V230228_三买_10笔_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始取样,默认 `1`; +/// - 该函数仅输出三买,不输出三卖。 +/// 对齐说明:与 Python `czsc.signals.cxt_third_buy_V230228` 保持一致。 +#[signal( + category = "kline", + name = "cxt_third_buy_V230228", + template = "{freq}_D{di}_三买辅助V230228", + opcode = "CxtThirdBuyV230228", + param_kind = "CxtThirdBuyV230228" +)] +pub fn cxt_third_buy_v230228(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}", di); + let k3 = "三买辅助V230228"; + if c.bi_list.len() < di + 11 { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + } + + for n in [13, 11, 9, 7, 5] { + let bis = get_sub_elements(&c.bi_list, di, n + 1); + if bis.len() != n + 1 { + continue; + } + if bis.last().unwrap().direction == Direction::Up || bis.first().unwrap().direction == bis.last().unwrap().direction { + continue; + } + let mut key_bis: Vec<&BI> = Vec::new(); + for i in (0..=(bis.len() - 3)).step_by(2) { + if i == 0 { + key_bis.push(&bis[i]); + } else { + let b1 = &bis[i - 2]; + let b3 = &bis[i]; + if b3.get_high() > b1.get_high() { + key_bis.push(b3); + } + } + } + if key_bis.len() < 2 { + continue; + } + let tb_break = bis.last().unwrap().get_low() + > key_bis.iter().map(|x| x.get_high()).fold(f64::INFINITY, f64::min) + && key_bis.iter().map(|x| x.get_high()).fold(f64::INFINITY, f64::min) + > key_bis.iter().map(|x| x.get_low()).fold(f64::NEG_INFINITY, f64::max); + let tb_price = bis.last().unwrap().get_low() + < bis.iter().map(|x| x.get_low()).fold(f64::INFINITY, f64::min) + + 1.618 * mean(&key_bis.iter().map(|x| x.get_power_price()).collect::>()); + if tb_break && tb_price { + return make_kline_signal_v2(&k1, &k2, k3, "三买", &format!("{}笔", bis.len())); + } + } + make_kline_signal_v2(&k1, &k2, k3, "其他", "其他") +} + +/// cxt_third_bs_V230319:带均线形态的第三类买卖点辅助 +/// +/// 参数模板:`"{freq}_D{di}#{ma_type}#{timeperiod}_BS3辅助V230319"` +/// +/// 信号逻辑: +/// 1. 取最近 5 笔构造中枢,并读取第 1、3、5 笔终点分型的均线值; +/// 2. 先根据第 5 笔是否离开中枢,判定 `三买/三卖`; +/// 3. 再根据三次均线相对位置补充 `均线新高/新低/顶分/底分/否定`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1#SMA#34_BS3辅助V230319_三买_均线新高_任意_0')` +/// - `Signal('60分钟_D1#EMA#34_BS3辅助V230319_三卖_均线顶分_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始取样,默认 `1`; +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod`:均线周期,默认 `34`。 +/// 对齐说明:与 Python `czsc.signals.cxt_third_bs_V230319` 保持一致。 +#[signal( + category = "kline", + name = "cxt_third_bs_V230319", + template = "{freq}_D{di}#{ma_type}#{timeperiod}_BS3辅助V230319", + opcode = "CxtThirdBsV230319", + param_kind = "CxtThirdBsV230319" +)] +pub fn cxt_third_bs_v230319(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let timeperiod = get_usize_param(params, "timeperiod", 34); + let ma_type = params.str("ma_type", "SMA").to_uppercase(); + let cache_key = format!("{}#{}", ma_type, timeperiod); + update_ma_cache(c, &cache_key, &ma_type, timeperiod, cache); + let k1 = c.freq.to_string(); + let k2 = format!("D{}#{}#{}", di, ma_type, timeperiod); + let k3 = "BS3辅助V230319"; + if c.bi_list.len() < di + 6 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bis = get_sub_elements(&c.bi_list, di, 5); + if bis.len() != 5 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let b1 = &bis[0]; + let b3 = &bis[2]; + let b5 = &bis[4]; + let zs_zd = b1.get_low().max(b3.get_low()); + let zs_zg = b1.get_high().min(b3.get_high()); + if zs_zd > zs_zg { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let Some(ma) = cache.series.get(&cache_key) else { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + }; + let id_to_idx = bar_index_map(c); + let mut snapshot_overrides: HashMap = HashMap::new(); + let get_last_ma = |bi: &BI, snapshot_overrides: &mut HashMap| -> Option { + let rb = fx_raw_bars(&bi.fx_b); + let last = rb.last()?; + ma_snapshot_value(c, ma, &id_to_idx, last, &ma_type, timeperiod, snapshot_overrides) + }; + let Some(ma_1) = get_last_ma(b1, &mut snapshot_overrides) else { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + }; + let Some(ma_3) = get_last_ma(b3, &mut snapshot_overrides) else { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + }; + let Some(ma_5) = get_last_ma(b5, &mut snapshot_overrides) else { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + }; + + let v1 = if b5.direction == Direction::Down && b5.get_low() > zs_zg { + "三买" + } else if b5.direction == Direction::Up && b5.get_high() < zs_zd { + "三卖" + } else { + "其他" + }; + if v1 == "其他" { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let v2 = if ma_5 > ma_3 && ma_3 > ma_1 { + "均线新高" + } else if ma_5 < ma_3 && ma_3 < ma_1 { + "均线新低" + } else if ma_5 > ma_3 && ma_3 < ma_1 { + "均线底分" + } else if ma_5 < ma_3 && ma_3 > ma_1 { + "均线顶分" + } else { + "均线否定" + }; + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// cxt_bi_end_V230320:质数窗口笔结束辅助 +/// +/// 参数模板:`"{freq}_D0质数窗口MO{max_overlap}_BE辅助V230320"` +/// +/// 信号逻辑: +/// 1. 展开当前 UBI 原始 K 线,统计其长度是否落在预设质数集合中; +/// 2. 若向上笔后的 UBI 在最近 `max_overlap` 根内创新低,判定 `看多`;向下笔反向判定 `看空`; +/// 3. 输出时补充 `XXK` 表示当前 UBI 长度。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D0质数窗口MO3_BE辅助V230320_看多_13K_任意_0')` +/// - `Signal('60分钟_D0质数窗口MO3_BE辅助V230320_看空_17K_任意_0')` +/// +/// 参数说明: +/// - `max_overlap`:允许用末尾 `max_overlap` 根 K 线判断极值,默认 `3`; +/// - 质数窗口集合固定为 `11~97` 内常用质数。 +/// 对齐说明:与 Python `czsc.signals.cxt_bi_end_V230320` 保持一致。 +#[signal( + category = "kline", + name = "cxt_bi_end_V230320", + template = "{freq}_D0质数窗口MO{max_overlap}_BE辅助V230320", + opcode = "CxtBiEndV230320", + param_kind = "CxtBiEndV230320" +)] +pub fn cxt_bi_end_v230320(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let max_overlap = get_usize_param(params, "max_overlap", 3); + let k1 = c.freq.to_string(); + let k2 = format!("D0质数窗口MO{}", max_overlap); + let k3 = "BE辅助V230320"; + if c.bi_list.len() < 3 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let primes = [11usize, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]; + let last_bi = c.bi_list.last().unwrap(); + let bars = &c.bars_ubi[1..]; + let raw_bars: Vec = bars.iter().flat_map(|x| x.elements.iter().cloned()).collect(); + if raw_bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let ubi_len = raw_bars.len(); + let ubi_min = raw_bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let ubi_max = raw_bars.iter().map(|x| x.high).fold(f64::NEG_INFINITY, f64::max); + let mop_bars = &raw_bars[raw_bars.len().saturating_sub(max_overlap)..]; + if last_bi.direction == Direction::Up + && primes.contains(&ubi_len) + && mop_bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min) == ubi_min + { + return make_kline_signal_v2(&k1, &k2, k3, "看多", &format!("{}K", ubi_len)); + } + if last_bi.direction == Direction::Down + && primes.contains(&ubi_len) + && mop_bars.iter().map(|x| x.high).fold(f64::NEG_INFINITY, f64::max) == ubi_max + { + return make_kline_signal_v2(&k1, &k2, k3, "看空", &format!("{}K", ubi_len)); + } + make_kline_signal_v1(&k1, &k2, k3, "其他") +} + +/// cxt_bi_end_V230322:分型配合均线的笔结束辅助 +/// +/// 参数模板:`"{freq}_D0分型配合{ma_type}#{timeperiod}_BE辅助V230322"` +/// +/// 信号逻辑: +/// 1. 读取最新 UBI 分型对应的原始 K 线,并提取分型区间内的均线序列; +/// 2. 向上笔若最新分型与均线位置形成顶部配合,判定 `看空`;向下笔反向判定 `看多`; +/// 3. 再用 `同向分型/反向分型` 说明分型方向与笔方向关系。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D0分型配合SMA#5_BE辅助V230322_看多_反向分型_任意_0')` +/// - `Signal('60分钟_D0分型配合EMA#8_BE辅助V230322_看空_同向分型_任意_0')` +/// +/// 参数说明: +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod`:均线周期,默认 `5`。 +/// 对齐说明:与 Python `czsc.signals.cxt_bi_end_V230322` 保持一致。 +#[signal( + category = "kline", + name = "cxt_bi_end_V230322", + template = "{freq}_D0分型配合{ma_type}#{timeperiod}_BE辅助V230322", + opcode = "CxtBiEndV230322", + param_kind = "CxtBiEndV230322" +)] +pub fn cxt_bi_end_v230322(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let ma_type = params.str("ma_type", "SMA").to_uppercase(); + let timeperiod = get_usize_param(params, "timeperiod", 5); + let cache_key = format!("{}#{}", ma_type, timeperiod); + update_ma_cache(c, &cache_key, &ma_type, timeperiod, cache); + let k1 = c.freq.to_string(); + let k2 = format!("D0分型配合{}#{}", ma_type, timeperiod); + let k3 = "BE辅助V230322"; + let ubi_fxs = c.get_ubi_fxs().unwrap_or_default(); + let last_bar = c.bars_raw.last().unwrap(); + if c.bi_list.len() < 3 + || c.bars_ubi.len() > 7 + || ubi_fxs.is_empty() + || last_bar.dt != fx_raw_bars(ubi_fxs.last().unwrap()).last().map(|x| x.dt).unwrap_or(last_bar.dt) + { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let Some(ma) = cache.series.get(&cache_key) else { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + }; + let id_to_idx = bar_index_map(c); + let last_bi = c.bi_list.last().unwrap(); + let last_fx = ubi_fxs.last().unwrap(); + let last_fx_raw = fx_raw_bars(last_fx); + let mut ma_vals = Vec::new(); + for rb in &last_fx_raw { + if let Some(idx) = id_to_idx.get(&rb.id) { + if let Some(v) = ma.get(*idx) { + ma_vals.push(*v); + } + } + } + if ma_vals.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let max_ma = ma_vals.iter().copied().fold(f64::NEG_INFINITY, f64::max); + let min_ma = ma_vals.iter().copied().fold(f64::INFINITY, f64::min); + let right_id = last_fx_raw.last().unwrap().id; + let right_ma = ma[*id_to_idx.get(&right_id).unwrap_or(&0)]; + + if last_bi.direction == Direction::Up { + if last_fx.mark == Mark::G && right_ma == min_ma { + return make_kline_signal_v2(&k1, &k2, k3, "看空", "同向分型"); + } + if last_fx.mark == Mark::D && right_ma != max_ma { + return make_kline_signal_v2(&k1, &k2, k3, "看空", "反向分型"); + } + } + if last_bi.direction == Direction::Down { + if last_fx.mark == Mark::D && right_ma == max_ma { + return make_kline_signal_v2(&k1, &k2, k3, "看多", "同向分型"); + } + if last_fx.mark == Mark::G && right_ma != min_ma { + return make_kline_signal_v2(&k1, &k2, k3, "看多", "反向分型"); + } + } + make_kline_signal_v1(&k1, &k2, k3, "其他") +} + +/// cxt_bi_end_V230618:笔结束小中枢辅助 +/// +/// 参数模板:`"{freq}_D{di}MO{max_overlap}_BE辅助V230618"` +/// +/// 信号逻辑: +/// 1. 读取倒数第 `di` 笔的原始 K 线并做价格覆盖计数; +/// 2. 统计覆盖次数形成的峰值数量,近似识别笔内小中枢; +/// 3. 输出 `看多/看空` 和 `X小中枢/其他`,用于辅助笔结束判断。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1MO3_BE辅助V230618_看多_1小中枢_任意_0')` +/// - `Signal('60分钟_D1MO3_BE辅助V230618_看空_其他_任意_0')` +/// +/// 参数说明: +/// - `di`:取倒数第 `di` 笔,默认 `1`; +/// - `max_overlap`:控制 UBI 最大允许延伸长度,默认 `3`。 +/// 对齐说明:与 Python `czsc.signals.cxt_bi_end_V230618` 保持一致。 +#[signal( + category = "kline", + name = "cxt_bi_end_V230618", + template = "{freq}_D{di}MO{max_overlap}_BE辅助V230618", + opcode = "CxtBiEndV230618", + param_kind = "CxtBiEndV230618" +)] +pub fn cxt_bi_end_v230618(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let max_overlap = get_usize_param(params, "max_overlap", 3); + let k1 = c.freq.to_string(); + let k2 = format!("D{}MO{}", di, max_overlap); + let k3 = "BE辅助V230618"; + if c.bi_list.len() < di + 6 || c.bars_ubi.len() > 3 + max_overlap - 1 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bi = &c.bi_list[c.bi_list.len() - di]; + let raw_bars = bi.get_raw_bars(); + if raw_bars.len() < 2 { + return make_kline_signal_v2(&k1, &k2, k3, if bi.direction == Direction::Down { "看多" } else { "看空" }, "其他"); + } + let max_price = raw_bars[..raw_bars.len() - 1].iter().map(|x| x.high).fold(f64::NEG_INFINITY, f64::max); + let min_price = raw_bars[..raw_bars.len() - 1].iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let price_range = max_price - min_price; + let mut counts = vec![0usize; 101]; + if price_range > 0.0 { + for bar in &raw_bars[..raw_bars.len() - 1] { + let high_pct = (100.0 * (bar.high - min_price) / price_range) as usize; + let low_pct = (100.0 * (bar.low - min_price) / price_range) as usize; + if high_pct == low_pct { + counts[high_pct.min(100)] += 1; + } else { + for count in counts + .iter_mut() + .take(high_pct.min(100) + 1) + .skip(low_pct.min(100)) + { + *count += 1; + } + } + } + } + let mut peak_count = 0usize; + for i in 1..counts.len() - 1 { + if counts[i] == 1 && counts[i] < counts[i - 1] { + peak_count += 1; + } + } + let v1 = if bi.direction == Direction::Down { "看多" } else { "看空" }; + let v2 = if bi.fxs.len() >= 4 && peak_count >= 1 && (bi.fxs[bi.fxs.len() - 4].fx - bi.fxs[bi.fxs.len() - 3].fx) - (bi.fxs[bi.fxs.len() - 2].fx - bi.fxs[bi.fxs.len() - 1].fx) > 0.0 { + format!("{}小中枢", peak_count) + } else { + "其他".to_string() + }; + make_kline_signal_v2(&k1, &k2, k3, v1, &v2) +} + +/// cxt_three_bi_V230618:三笔形态分类信号 +/// +/// 参数模板:`"{freq}_D{di}三笔_形态V230618"` +/// +/// 信号逻辑: +/// 1. 读取最近 3 笔,依据第 1 笔和第 3 笔的高低点关系划分形态; +/// 2. 识别不重合、奔走、收敛、扩张、盘背、无背等典型三笔结构; +/// 3. 若不满足任何预定义结构,则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1三笔_形态V230618_向下盘背_任意_任意_0')` +/// - `Signal('60分钟_D1三笔_形态V230618_向上扩张_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始取样,默认 `1`; +/// - 仅在未完成笔较短时评估三笔形态。 +/// 对齐说明:与 Python `czsc.signals.cxt_three_bi_V230618` 保持一致。 +#[signal( + category = "kline", + name = "cxt_three_bi_V230618", + template = "{freq}_D{di}三笔_形态V230618", + opcode = "CxtThreeBiV230618", + param_kind = "CxtThreeBiV230618" +)] +pub fn cxt_three_bi_v230618(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}三笔", di); + let k3 = "形态V230618"; + if c.bi_list.len() < di + 6 || c.bars_ubi.len() > 7 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bis = get_sub_elements(&c.bi_list, di, 3); + let (bi1, bi2, bi3) = (&bis[0], &bis[1], &bis[2]); + let v1 = if bi3.direction == Direction::Down { + if bi3.get_low() > bi1.get_high() { + "向下不重合" + } else if bi2.get_low() < bi3.get_low() && bi3.get_low() < bi1.get_high() && bi1.get_high() < bi2.get_high() { + "向下奔走型" + } else if bi1.get_high() > bi3.get_high() && bi1.get_low() < bi3.get_low() { + "向下收敛" + } else if bi1.get_high() < bi3.get_high() && bi1.get_low() > bi3.get_low() { + "向下扩张" + } else if bi3.get_low() < bi1.get_low() && bi3.get_high() < bi1.get_high() { + if bi3.get_power() < bi1.get_power() { "向下盘背" } else { "向下无背" } + } else { + "其他" + } + } else if bi3.get_high() < bi1.get_low() { + "向上不重合" + } else if bi2.get_low() < bi1.get_low() && bi1.get_low() < bi3.get_high() && bi3.get_high() < bi2.get_high() { + "向上奔走型" + } else if bi1.get_high() > bi3.get_high() && bi1.get_low() < bi3.get_low() { + "向上收敛" + } else if bi1.get_high() < bi3.get_high() && bi1.get_low() > bi3.get_low() { + "向上扩张" + } else if bi3.get_low() > bi1.get_low() && bi3.get_high() > bi1.get_high() { + if bi3.get_power() < bi1.get_power() { "向上盘背" } else { "向上无背" } + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// cxt_five_bi_V230619:五笔形态分类信号 +/// +/// 参数模板:`"{freq}_D{di}五笔_形态V230619"` +/// +/// 信号逻辑: +/// 1. 读取最近 5 笔并计算整体最高点、最低点; +/// 2. 依据中枢重合、首末笔力度与突破位置识别底背驰、顶背驰、颈线突破、类三买卖等形态; +/// 3. 未命中任何预定义结构时返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1五笔_形态V230619_aAb式底背驰_任意_任意_0')` +/// - `Signal('60分钟_D1五笔_形态V230619_类三卖_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始取样,默认 `1`; +/// - 该信号直接输出形态标签,不再附加次级分类。 +/// 对齐说明:与 Python `czsc.signals.cxt_five_bi_V230619` 保持一致。 +#[signal( + category = "kline", + name = "cxt_five_bi_V230619", + template = "{freq}_D{di}五笔_形态V230619", + opcode = "CxtFiveBiV230619", + param_kind = "CxtFiveBiV230619" +)] +pub fn cxt_five_bi_v230619(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}五笔", di); + let k3 = "形态V230619"; + if c.bi_list.len() < di + 6 || c.bars_ubi.len() > 7 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bis = get_sub_elements(&c.bi_list, di, 5); + let (bi1, bi2, bi3, bi4, bi5) = (&bis[0], &bis[1], &bis[2], &bis[3], &bis[4]); + let max_high = bis.iter().map(|x| x.get_high()).fold(f64::NEG_INFINITY, f64::max); + let min_low = bis.iter().map(|x| x.get_low()).fold(f64::INFINITY, f64::min); + let v1 = if bi1.direction == Direction::Down { + if bi2.get_high().min(bi4.get_high()) > bi2.get_low().max(bi4.get_low()) + && max_high == bi1.get_high() + && bi5.get_power() < bi1.get_power() + && ((min_low == bi3.get_low() && bi5.get_low() < bi1.get_low()) || min_low == bi5.get_low()) + { + "aAb式底背驰" + } else if max_high == bi1.get_high() + && min_low == bi5.get_low() + && bi4.get_high() < bi2.get_low() + && bi5.get_power() < bi3.get_power().max(bi1.get_power()) + { + "类趋势底背驰" + } else if (min_low == bi1.get_low() + && bi5.get_high() > bi1.get_high().min(bi2.get_high()) + && bi1.get_high().min(bi2.get_high()) > bi5.get_low() + && bi5.get_low() > bi1.get_low()) + || (min_low == bi3.get_low() + && bi5.get_high() > bi3.get_high() + && bi3.get_high() > bi5.get_low() + && bi5.get_low() > bi3.get_low()) + { + "上颈线突破" + } else if max_high == bi5.get_high() + && bi5.get_high() > bi5.get_low() + && bi5.get_low() > bi1.get_high().max(bi3.get_high()) + && bi1.get_high().min(bi3.get_high()) > bi1.get_low().max(bi3.get_low()) + && bi1.get_low().max(bi3.get_low()) > min_low + { + "类三买" + } else { + "其他" + } + } else if bi2.get_high().min(bi4.get_high()) > bi2.get_low().max(bi4.get_low()) + && min_low == bi1.get_low() + && bi5.get_power() < bi1.get_power() + && ((max_high == bi3.get_high() && bi5.get_high() > bi1.get_high()) || max_high == bi5.get_high()) + { + "aAb式顶背驰" + } else if min_low == bi1.get_low() + && max_high == bi5.get_high() + && bi5.get_power() < bi1.get_power().max(bi3.get_power()) + && bi4.get_low() > bi2.get_high() + { + "类趋势顶背驰" + } else if (max_high == bi1.get_high() + && bi5.get_low() < bi1.get_low().max(bi2.get_low()) + && bi1.get_low().max(bi2.get_low()) < bi5.get_high() + && bi5.get_high() < max_high) + || (max_high == bi3.get_high() + && bi5.get_low() < bi3.get_low() + && bi3.get_low() < bi5.get_high() + && bi5.get_high() < max_high) + { + "下颈线突破" + } else if min_low == bi5.get_low() + && bi5.get_low() < bi5.get_high() + && bi5.get_high() < bi1.get_low().min(bi3.get_low()) + && bi1.get_low().min(bi3.get_low()) < bi1.get_low().max(bi3.get_low()) + && bi1.get_low().max(bi3.get_low()) < bi1.get_high().min(bi3.get_high()) + && bi1.get_high().min(bi3.get_high()) < max_high + { + "类三卖" + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// cxt_seven_bi_V230620:七笔形态分类信号 +/// +/// 参数模板:`"{freq}_D{di}七笔_形态V230620"` +/// +/// 信号逻辑: +/// 1. 读取最近 7 笔并统计极值与关键中枢关系; +/// 2. 识别 aAbcd、abcAd、类趋势、向上/向下中枢完成、类三买卖等七笔结构; +/// 3. 未命中预定义结构时返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1七笔_形态V230620_aAbcd式底背驰_任意_任意_0')` +/// - `Signal('60分钟_D1七笔_形态V230620_向上中枢完成_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始取样,默认 `1`; +/// - 仅在最近结构已基本完成且 UBI 不长时评估。 +/// 对齐说明:与 Python `czsc.signals.cxt_seven_bi_V230620` 保持一致。 +#[signal( + category = "kline", + name = "cxt_seven_bi_V230620", + template = "{freq}_D{di}七笔_形态V230620", + opcode = "CxtSevenBiV230620", + param_kind = "CxtSevenBiV230620" +)] +pub fn cxt_seven_bi_v230620(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}七笔", di); + let k3 = "形态V230620"; + if c.bi_list.len() < di + 10 || c.bars_ubi.len() > 7 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bis = get_sub_elements(&c.bi_list, di, 7); + let (bi1, bi2, bi3, bi4, bi5, bi6, bi7) = (&bis[0], &bis[1], &bis[2], &bis[3], &bis[4], &bis[5], &bis[6]); + let max_high = bis.iter().map(|x| x.get_high()).fold(f64::NEG_INFINITY, f64::max); + let min_low = bis.iter().map(|x| x.get_low()).fold(f64::INFINITY, f64::min); + let v1 = if bi7.direction == Direction::Down { + if bi1.get_high() == max_high && bi7.get_low() == min_low { + if bi2.get_high().min(bi4.get_high()) > bi2.get_low().max(bi4.get_low()) && bi2.get_low().max(bi4.get_low()) > bi6.get_high() && bi7.get_power() < bi5.get_power() { + "aAbcd式底背驰" + } else if bi2.get_low() > bi4.get_high().min(bi6.get_high()) && bi4.get_low().max(bi6.get_low()) < bi4.get_high().min(bi6.get_high()) && bi7.get_power() < bi1.get_high() - bi3.get_low() { + "abcAd式底背驰" + } else if bi2.get_high().min(bi4.get_high()).min(bi6.get_high()) > bi2.get_low().max(bi4.get_low()).max(bi6.get_low()) && bi7.get_power() < bi1.get_power() { + "aAb式底背驰" + } else if bi2.get_low() > bi4.get_high() && bi4.get_low() > bi6.get_high() && bi7.get_power() < bi5.get_power().max(bi3.get_power()).max(bi1.get_power()) { + "类趋势底背驰" + } else { + "其他" + } + } else if bi4.get_low() == min_low + && bi1.get_high().min(bi3.get_high()) > bi1.get_low().max(bi3.get_low()) + && bi5.get_high().min(bi7.get_high()) > bi5.get_low().max(bi7.get_low()) + && bi4.get_high().max(bi6.get_high()) > bi3.get_high().min(bi4.get_high()) + && bi1.get_low().max(bi3.get_low()) < bi5.get_high().max(bi7.get_high()) + { + "向上中枢完成" + } else if bi1.get_low().min(bi3.get_low()) == min_low + && bi5.get_high().max(bi7.get_high()) == max_high + && bi5.get_low().min(bi7.get_low()) > bi1.get_high().max(bi3.get_high()) + && bi1.get_high().min(bi3.get_high()) > bi1.get_low().max(bi3.get_low()) + { + "类三买" + } else { + "其他" + } + } else if bi1.get_low() == min_low && bi7.get_high() == max_high { + if bi6.get_low() > bi2.get_high().min(bi4.get_high()) && bi2.get_high().min(bi4.get_high()) > bi2.get_low().max(bi4.get_low()) && bi7.get_power() < bi5.get_power() { + "aAbcd式顶背驰" + } else if bi4.get_high().min(bi6.get_high()) > bi4.get_low().max(bi6.get_low()) && bi4.get_low().max(bi6.get_low()) > bi2.get_high() && bi7.get_power() < bi3.get_high() - bi1.get_low() { + "abcAd式顶背驰" + } else if bi2.get_high().min(bi4.get_high()).min(bi6.get_high()) > bi2.get_low().max(bi4.get_low()).max(bi6.get_low()) && bi7.get_power() < bi1.get_power() { + "aAb式顶背驰" + } else if bi2.get_high() < bi4.get_low() && bi4.get_high() < bi6.get_low() && bi7.get_power() < bi5.get_power().max(bi3.get_power()).max(bi1.get_power()) { + "类趋势顶背驰" + } else { + "其他" + } + } else if bi4.get_high() == max_high + && bi1.get_high().min(bi3.get_high()) > bi1.get_low().max(bi3.get_low()) + && bi5.get_high().min(bi7.get_high()) > bi5.get_low().max(bi7.get_low()) + && bi4.get_low().min(bi6.get_low()) < bi3.get_low().max(bi4.get_low()) + && bi1.get_high().min(bi3.get_high()) > bi5.get_low().min(bi7.get_low()) + { + "向下中枢完成" + } else if bi5.get_low().min(bi7.get_low()) == min_low + && bi1.get_high().max(bi3.get_high()) == max_high + && bi7.get_high().max(bi5.get_high()) < bi1.get_low().min(bi3.get_low()) + && bi1.get_high().min(bi3.get_high()) > bi1.get_low().max(bi3.get_low()) + { + "类三卖" + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// cxt_range_oscillation_V230620:区间震荡笔数统计 +/// +/// 参数模板:`"{freq}_D{di}TH{th}_区间震荡V230620"` +/// +/// 信号逻辑: +/// 1. 读取最近 12 笔的中位价格中心; +/// 2. 从最新笔向前逐笔比较中心振幅,只要最大振幅百分比小于 `th` 就继续累加; +/// 3. 若累计笔数超过 1,则输出 `X笔震荡 + 向上/向下`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1TH2_区间震荡V230620_4笔震荡_向上_任意_0')` +/// - `Signal('60分钟_D1TH2_区间震荡V230620_6笔震荡_向下_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始取样,默认 `1`; +/// - `th`:中心振幅百分比阈值,默认 `2`。 +/// 对齐说明:与 Python `czsc.signals.cxt_range_oscillation_V230620` 保持一致。 +#[signal( + category = "kline", + name = "cxt_range_oscillation_V230620", + template = "{freq}_D{di}TH{th}_区间震荡V230620", + opcode = "CxtRangeOscillationV230620", + param_kind = "CxtRangeOscillationV230620" +)] +pub fn cxt_range_oscillation_v230620(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let th = get_usize_param(params, "th", 2); + let k1 = c.freq.to_string(); + let k2 = format!("D{}TH{}", di, th); + let k3 = "区间震荡V230620"; + if c.bi_list.len() < di + 11 { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + } + let bis = get_sub_elements(&c.bi_list, di, 12); + if bis.len() != 12 { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + } + let mut centers = Vec::new(); + let mut count = 1usize; + for bi in bis.iter().rev() { + centers.push((bi.get_high() + bi.get_low()) / 2.0); + if centers.len() > 1 { + if max_amplitude_pct(¢ers) < th as f64 { + count += 1; + } else { + break; + } + } + } + if count != 1 { + return make_kline_signal_v2(&k1, &k2, k3, &format!("{}笔震荡", count), if bis.last().unwrap().direction == Direction::Up { "向上" } else { "向下" }); + } + make_kline_signal_v2(&k1, &k2, k3, "其他", "其他") +} + +/// cxt_nine_bi_V230621:九笔形态分类信号 +/// +/// 参数模板:`"{freq}_D{di}九笔_形态V230621"` +/// +/// 信号逻辑: +/// 1. 读取最近 9 笔,依据首末极值和中间中枢关系构造结构; +/// 2. 识别 aAb、aAbcd、ABC、类趋势一买卖、类三买卖、类二买卖等九笔形态; +/// 3. 未命中时返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1九笔_形态V230621_aAb式类一买_任意_任意_0')` +/// - `Signal('60分钟_D1九笔_形态V230621_ZG三买_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始取样,默认 `1`; +/// - 该分类信号直接返回形态名,不再附加辅助标签。 +/// 对齐说明:与 Python `czsc.signals.cxt_nine_bi_V230621` 保持一致。 +#[signal( + category = "kline", + name = "cxt_nine_bi_V230621", + template = "{freq}_D{di}九笔_形态V230621", + opcode = "CxtNineBiV230621", + param_kind = "CxtNineBiV230621" +)] +pub fn cxt_nine_bi_v230621(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}九笔", di); + let k3 = "形态V230621"; + if c.bi_list.len() < di + 13 || c.bars_ubi.len() > 7 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bis = get_sub_elements(&c.bi_list, di, 9); + let (bi1, bi2, bi3, bi4, bi5, bi6, bi7, bi8, bi9) = (&bis[0], &bis[1], &bis[2], &bis[3], &bis[4], &bis[5], &bis[6], &bis[7], &bis[8]); + let max_high = bis.iter().map(|x| x.get_high()).fold(f64::NEG_INFINITY, f64::max); + let min_low = bis.iter().map(|x| x.get_low()).fold(f64::INFINITY, f64::min); + let odd_1357 = [bi1, bi3, bi5, bi7]; + let v1 = if bi9.direction == Direction::Down { + let mut v1 = "其他"; + if min_low == bi9.get_low() && max_high == bi1.get_high() { + if bi2.get_high().min(bi4.get_high()).min(bi6.get_high()).min(bi8.get_high()) > bi2.get_low().max(bi4.get_low()).max(bi6.get_low()).max(bi8.get_low()) + && bi9.get_power() < bi1.get_power() + && bi3.get_low() >= bi1.get_low() + && bi7.get_high() <= bi9.get_high() + { + v1 = "aAb式类一买"; + } else if bi2.get_high().min(bi4.get_high()).min(bi6.get_high()) > bi2.get_low().max(bi4.get_low()).max(bi6.get_low()) && bi2.get_low().max(bi4.get_low()).max(bi6.get_low()) > bi8.get_high() && bi9.get_power() < bi7.get_power() { + v1 = "aAbcd式类一买"; + } else if bi3.get_low() < bi1.get_low() + && bi7.get_high() > bi9.get_high() + && bi4.get_high().min(bi6.get_high()) > bi4.get_low().max(bi6.get_low()) + && (bi1.get_high() - bi3.get_low()) > (bi7.get_high() - bi9.get_low()) + { + v1 = "ABC式类一买"; + } else if bi8.get_high() < bi6.get_low() + && bi6.get_high() < bi4.get_low() + && bi4.get_high() < bi2.get_low() + && bi9.get_power() < bi1.get_power().max(bi3.get_power()).max(bi5.get_power()).max(bi7.get_power()) + { + v1 = "类趋势一买"; + } + } + if max_high == bi1.get_high().max(bi3.get_high()) + && min_low == bi9.get_low() + && bi2.get_high().min(bi4.get_high()) > bi2.get_low().max(bi4.get_low()) + && bi2.get_low().min(bi4.get_low()) > bi6.get_high().max(bi8.get_high()) + && bi6.get_high().min(bi8.get_high()) > bi6.get_low().max(bi8.get_low()) + && bi9.get_power() < bi5.get_power() + { + v1 = "aAbBc式类一买"; + } + if max_high == bi9.get_high() + && bi9.get_low() > odd_1357.iter().map(|x| x.get_high()).fold(f64::NEG_INFINITY, f64::max) + && odd_1357.iter().map(|x| x.get_high()).fold(f64::INFINITY, f64::min) + > odd_1357.iter().map(|x| x.get_low()).fold(f64::NEG_INFINITY, f64::max) + && bi3.get_low().min(bi5.get_low()) == min_low + { + v1 = "类三买A"; + } + if bi8.get_power() < bi2.get_power() + && max_high == bi9.get_high() + && bi9.get_low() > bi3.get_high().max(bi5.get_high()).max(bi7.get_high()) + && bi3.get_high().min(bi5.get_high()).min(bi7.get_high()) > bi3.get_low().max(bi5.get_low()).max(bi7.get_low()) + && bi1.get_low() == min_low + { + v1 = "类三买B"; + } + if min_low == bi5.get_low() && max_high == bi1.get_high() && bi4.get_high() < bi2.get_low() { + let zd = bi5.get_low().max(bi7.get_low()); + let zg = bi5.get_high().min(bi7.get_high()); + let gg = bi5.get_high().max(bi7.get_high()); + if zg > zd && bi8.get_high() > gg { + if bi9.get_low() > zg { + v1 = "ZG三买"; + } else if bi9.get_high() > gg && gg > zg && bi9.get_low() > zd { + v1 = "类二买"; + } + } + } + v1 + } else if max_high == bi9.get_high() && min_low == bi1.get_low() { + if bi6.get_low() > bi2.get_high().min(bi4.get_high()) + && bi2.get_high().min(bi4.get_high()) > bi2.get_low().max(bi4.get_low()) + && bi6.get_high().min(bi8.get_high()) > bi6.get_low().max(bi8.get_low()) + && bi2.get_high().max(bi4.get_high()) < bi6.get_low().min(bi8.get_low()) + && bi9.get_power() < bi5.get_power() + { + "aAbBc式类一卖" + } else if bi2.get_high().min(bi4.get_high()).min(bi6.get_high()).min(bi8.get_high()) > bi2.get_low().max(bi4.get_low()).max(bi6.get_low()).max(bi8.get_low()) + && bi9.get_power() < bi1.get_power() + && bi3.get_high() <= bi1.get_high() + && bi7.get_low() >= bi9.get_low() + { + "aAb式类一卖" + } else if bi8.get_low() > bi2.get_high().min(bi4.get_high()).min(bi6.get_high()) + && bi2.get_high().min(bi4.get_high()).min(bi6.get_high()) > bi2.get_low().max(bi4.get_low()).max(bi6.get_low()) + && bi9.get_power() < bi7.get_power() + { + "aAbcd式类一卖" + } else if bi3.get_high() > bi1.get_high() + && bi7.get_low() < bi9.get_low() + && bi4.get_high().min(bi6.get_high()) > bi4.get_low().max(bi6.get_low()) + && (bi3.get_high() - bi1.get_low()) > (bi9.get_high() - bi7.get_low()) + { + "ABC式类一卖" + } else if bi8.get_low() > bi6.get_high() + && bi6.get_low() > bi4.get_high() + && bi4.get_low() > bi2.get_high() + && bi9.get_power() < bi1.get_power().max(bi3.get_power()).max(bi5.get_power()).max(bi7.get_power()) + { + "类趋势一卖" + } else { + "其他" + } + } else if max_high == bi1.get_high() + && min_low == bi9.get_low() + && bi9.get_high() < bi3.get_low().max(bi5.get_low()).max(bi7.get_low()) + && bi3.get_low().max(bi5.get_low()).max(bi7.get_low()) < bi3.get_high().min(bi5.get_high()).min(bi7.get_high()) + { + "类三卖A" + } else if min_low == bi1.get_low() && max_high == bi5.get_high() && bi2.get_high() < bi4.get_low() { + let zd = bi5.get_low().max(bi7.get_low()); + let zg = bi5.get_high().min(bi7.get_high()); + let dd = bi5.get_low().min(bi7.get_low()); + if zg > zd && bi8.get_low() < dd { + if bi9.get_high() < zd { + "ZD三卖" + } else if dd < zd && bi9.get_high() < zg { + "类二卖" + } else { + "其他" + } + } else { + "其他" + } + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// cxt_eleven_bi_V230622:十一笔形态分类信号 +/// +/// 参数模板:`"{freq}_D{di}十一笔_形态V230622"` +/// +/// 信号逻辑: +/// 1. 读取最近 11 笔并统计首末极值与中间结构关系; +/// 2. 识别 A5B3C3、A3B3C5、A3B5C3、类二买卖、类三买等十一笔结构; +/// 3. 若不满足任何预定义结构,则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1十一笔_形态V230622_A5B3C3式类一买_任意_任意_0')` +/// - `Signal('60分钟_D1十一笔_形态V230622_类二卖_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始取样,默认 `1`; +/// - 该信号面向更长结构分类,要求笔数和确认度更高。 +/// 对齐说明:与 Python `czsc.signals.cxt_eleven_bi_V230622` 保持一致。 +#[signal( + category = "kline", + name = "cxt_eleven_bi_V230622", + template = "{freq}_D{di}十一笔_形态V230622", + opcode = "CxtElevenBiV230622", + param_kind = "CxtElevenBiV230622" +)] +pub fn cxt_eleven_bi_v230622(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}十一笔", di); + let k3 = "形态V230622"; + if c.bi_list.len() < di + 16 || c.bars_ubi.len() > 7 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bis = get_sub_elements(&c.bi_list, di, 11); + let (bi1, _, bi3, bi4, bi5, bi6, bi7, bi8, bi9, bi10, bi11) = + (&bis[0], &bis[1], &bis[2], &bis[3], &bis[4], &bis[5], &bis[6], &bis[7], &bis[8], &bis[9], &bis[10]); + let max_high = bis.iter().map(|x| x.get_high()).fold(f64::NEG_INFINITY, f64::max); + let min_low = bis.iter().map(|x| x.get_low()).fold(f64::INFINITY, f64::min); + let v1 = if bi11.direction == Direction::Down { + if min_low == bi11.get_low() && max_high == bi1.get_high() { + if bi5.get_low() == [bi1.get_low(), bi3.get_low(), bi5.get_low()].into_iter().fold(f64::INFINITY, f64::min) + && bi9.get_low() > bi11.get_low() + && bi9.get_high() > bi11.get_high() + && bi8.get_high() > bi6.get_low() + && bi1.get_high() - bi5.get_low() > bi9.get_high() - bi11.get_low() + { + "A5B3C3式类一买" + } else if bi1.get_high() > bi3.get_high() + && bi1.get_low() > bi3.get_low() + && bi7.get_high() == [bi7.get_high(), bi9.get_high(), bi11.get_high()].into_iter().fold(f64::NEG_INFINITY, f64::max) + && bi6.get_high() > bi4.get_low() + && bi1.get_high() - bi3.get_low() > bi7.get_high() - bi11.get_low() + { + "A3B3C5式类一买" + } else if bi1.get_low() > bi3.get_low() + && bi4.get_high().min(bi6.get_high()).min(bi8.get_high()) > bi4.get_low().max(bi6.get_low()).max(bi8.get_low()) + && bi9.get_high() > bi11.get_high() + && bi1.get_high() - bi3.get_low() > bi9.get_high() - bi11.get_low() + { + "A3B5C3式类一买" + } else if bis[1].get_low() > bi4.get_high() + && bi4.get_low() > bi6.get_high() + && bi5.get_low() > bi7.get_low() + && bi10.get_high() > bi8.get_low() + { + "a1Ab式类一买" + } else { + "其他" + } + } else if (bi7.get_power() < bi1.get_power() + && min_low == bi7.get_low() + && bi7.get_low() < bis[1].get_low().max(bi4.get_low()).max(bi6.get_low()) + && bis[1].get_low().max(bi4.get_low()).max(bi6.get_low()) + < bis[1].get_high().min(bi4.get_high()).min(bi6.get_high()) + && bis[1].get_high().min(bi4.get_high()).min(bi6.get_high()) + < bi9.get_high().max(bi11.get_high()) + && bi9.get_high().max(bi11.get_high()) < bi1.get_high() + && max_high == bi1.get_high() + && bi11.get_low() > bis[1].get_low().min(bi4.get_low()).min(bi6.get_low()) + && bi9.get_high().min(bi11.get_high()) > bi9.get_low().max(bi11.get_low())) + || (max_high == bi1.get_high() + && min_low == bi7.get_low() + && bi9.get_high().min(bi11.get_high()) > bi9.get_low().max(bi11.get_low()) + && bi11.get_high().max(bi9.get_high()) > bi4.get_high().max(bi6.get_high()) + && bi9.get_low().min(bi11.get_low()) > bi4.get_low().min(bi6.get_low())) + { + "类二买" + } else { + let gg = [bis[0].get_high(), bis[1].get_high(), bis[2].get_high()].into_iter().fold(f64::NEG_INFINITY, f64::max); + let zg = [bis[0].get_high(), bis[1].get_high(), bis[2].get_high()].into_iter().fold(f64::INFINITY, f64::min); + let zd = [bis[0].get_low(), bis[1].get_low(), bis[2].get_low()].into_iter().fold(f64::NEG_INFINITY, f64::max); + let dd = [bis[0].get_low(), bis[1].get_low(), bis[2].get_low()].into_iter().fold(f64::INFINITY, f64::min); + if max_high == bi11.get_high() + && bi11.get_low() > zg + && zg > zd + && gg > bi5.get_low() + && gg > bi7.get_low() + && gg > bi9.get_low() + && dd < bi5.get_high() + && dd < bi7.get_high() + && dd < bi9.get_high() + { + "类三买" + } else { + "其他" + } + } + } else if max_high == bi11.get_high() && min_low == bi1.get_low() { + if bi5.get_high() == [bi1.get_high(), bi3.get_high(), bi5.get_high()].into_iter().fold(f64::NEG_INFINITY, f64::max) + && bi9.get_low() < bi11.get_low() + && bi9.get_high() < bi11.get_high() + && bi8.get_low() < bi6.get_high() + && bi11.get_high() - bi9.get_low() < bi5.get_high() - bi1.get_low() + { + "A5B3C3式类一卖" + } else if bi7.get_low() == [bi11.get_low(), bi9.get_low(), bi7.get_low()].into_iter().fold(f64::INFINITY, f64::min) + && bi1.get_high() < bi3.get_high() + && bi1.get_low() < bi3.get_low() + && bi6.get_low() < bi4.get_high() + && bi11.get_high() - bi7.get_low() < bi3.get_high() - bi1.get_low() + { + "A3B3C5式类一卖" + } else if bi1.get_high() < bi3.get_high() + && bi4.get_high().min(bi6.get_high()).min(bi8.get_high()) > bi4.get_low().max(bi6.get_low()).max(bi8.get_low()) + && bi9.get_low() < bi11.get_low() + && bi3.get_high() - bi1.get_low() > bi11.get_high() - bi9.get_low() + { + "A3B5C3式类一卖" + } else { + "其他" + } + } else if max_high == bi9.get_high() + && bi9.get_high() > bi8.get_low() + && bi8.get_low() > bi6.get_high() + && bi6.get_high() > bi6.get_low() + && bi6.get_low() > bi4.get_high() + && bi4.get_high() > bi4.get_low() + && bi4.get_low() > bis[1].get_high() + && min_low == bi1.get_low() + && bi11.get_high() < bi9.get_high() + { + "类二卖" + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// cxt_ubi_end_V230816:UBI 新高新低次数信号 +/// +/// 参数模板:`"{freq}_UBI_BE辅助V230816"` +/// +/// 信号逻辑: +/// 1. 重建当前未完成笔 UBI 结构; +/// 2. 若 UBI 向上,则统计内部顶分型的逐次新高次数;若最后一根再次上破,则输出 `新高_第X次`; +/// 3. UBI 向下时对称统计新低次数,否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_UBI_BE辅助V230816_新高_第3次_任意_0')` +/// - `Signal('60分钟_UBI_BE辅助V230816_新低_第2次_任意_0')` +/// +/// 参数说明: +/// - 本信号无额外参数,`params` 可为空; +/// - 需要 UBI 已形成足够分型和原始 K 线长度。 +/// 对齐说明:与 Python `czsc.signals.cxt_ubi_end_V230816` 保持一致。 +#[signal( + category = "kline", + name = "cxt_ubi_end_V230816", + template = "{freq}_UBI_BE辅助V230816", + opcode = "CxtUbiEndV230816", + param_kind = "CxtUbiEndV230816" +)] +pub fn cxt_ubi_end_v230816(c: &CZSC, _params: &ParamView, _cache: &mut TaCache) -> Vec { + let k1 = c.freq.to_string(); + let k2 = "UBI"; + let k3 = "BE辅助V230816"; + let Some(ubi) = rebuild_ubi(c) else { + return make_kline_signal_v2(&k1, k2, k3, "其他", "其他"); + }; + if ubi.fxs.len() <= 2 || c.bars_ubi.len() <= 5 { + return make_kline_signal_v2(&k1, k2, k3, "其他", "其他"); + } + if ubi.direction == Direction::Up { + let fxs: Vec<&FX> = ubi.fxs.iter().filter(|x| x.mark == Mark::G).collect(); + if fxs.is_empty() { + return make_kline_signal_v2(&k1, k2, k3, "其他", "其他"); + } + let mut cnt = 1; + let mut cur_hfx = fxs[0]; + for fx in fxs.iter().skip(1) { + if fx.high > cur_hfx.high { + cnt += 1; + cur_hfx = fx; + } + } + if ubi.raw_bars.last().unwrap().high > cur_hfx.high { + return make_kline_signal_v2(&k1, k2, k3, "新高", &format!("第{}次", cnt + 1)); + } + } + if ubi.direction == Direction::Down { + let fxs: Vec<&FX> = ubi.fxs.iter().filter(|x| x.mark == Mark::D).collect(); + if fxs.is_empty() { + return make_kline_signal_v2(&k1, k2, k3, "其他", "其他"); + } + let mut cnt = 1; + let mut cur_lfx = fxs[0]; + for fx in fxs.iter().skip(1) { + if fx.low < cur_lfx.low { + cnt += 1; + cur_lfx = fx; + } + } + if ubi.raw_bars.last().unwrap().low < cur_lfx.low { + return make_kline_signal_v2(&k1, k2, k3, "新低", &format!("第{}次", cnt + 1)); + } + } + make_kline_signal_v2(&k1, k2, k3, "其他", "其他") +} + +/// cxt_bi_trend_V230913:笔趋势高低点回归信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}笔趋势_高低点辅助判断V230913"` +/// +/// 信号逻辑: +/// 1. 分别取最近 `di` 个向上笔高点和向下笔低点,做线性回归预测; +/// 2. 用当前 UBI 指定位置的时间点预测上沿、下沿及中轴; +/// 3. 将最新收盘相对预测区间的位置映射为 `上升趋势/下降趋势 + 强弱`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D4N1笔趋势_高低点辅助判断V230913_上升趋势_强_任意_0')` +/// - `Signal('60分钟_D4N1笔趋势_高低点辅助判断V230913_下降趋势_超强_任意_0')` +/// +/// 参数说明: +/// - `di`:参与回归的同向笔数量,默认 `4`; +/// - `n`:使用 UBI 中倒数第 `n` 根 K 线做比较,默认 `1`。 +/// 对齐说明:与 Python `czsc.signals.cxt_bi_trend_V230913` 保持一致。 +#[signal( + category = "kline", + name = "cxt_bi_trend_V230913", + template = "{freq}_D{di}N{n}笔趋势_高低点辅助判断V230913", + opcode = "CxtBiTrendV230913", + param_kind = "CxtBiTrendV230913" +)] +pub fn cxt_bi_trend_v230913(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 4); + let n = get_usize_param(params, "n", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}笔趋势", di, n); + let k3 = "高低点辅助判断V230913"; + if c.bi_list.len() <= di + 2 || c.bars_ubi.len() <= n + 1 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let up_bis: Vec<&BI> = c.bi_list.iter().filter(|x| x.direction == Direction::Up).collect(); + let down_bis: Vec<&BI> = c.bi_list.iter().filter(|x| x.direction == Direction::Down).collect(); + if up_bis.is_empty() || down_bis.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let up_take = up_bis.len().min(di); + let down_take = down_bis.len().min(di); + let up_sel = &up_bis[up_bis.len() - up_take..]; + let down_sel = &down_bis[down_bis.len() - down_take..]; + let up_xs: Vec = up_sel.iter().map(|x| x.end_dt().timestamp() as f64).collect(); + let up_ys: Vec = up_sel.iter().map(|x| x.get_high()).collect(); + let down_xs: Vec = down_sel.iter().map(|x| x.end_dt().timestamp() as f64).collect(); + let down_ys: Vec = down_sel.iter().map(|x| x.get_low()).collect(); + let x = c.bars_ubi[c.bars_ubi.len() - n].dt.timestamp() as f64; + let Some(pre_up) = linreg_predict(&up_xs, &up_ys, x) else { return make_kline_signal_v1(&k1, &k2, k3, "其他"); }; + let Some(pre_down) = linreg_predict(&down_xs, &down_ys, x) else { return make_kline_signal_v1(&k1, &k2, k3, "其他"); }; + let pre_mid = (pre_up + pre_down) / 2.0; + if pre_up <= pre_down { + return make_kline_signal_v2(&k1, &k2, k3, "观望", "趋势线交叉"); + } + if c.bars_ubi.len() >= 5 { + return make_kline_signal_v2(&k1, &k2, k3, "观望", "末笔延伸"); + } + let close = c.bars_raw[c.bars_raw.len() - n].close; + if close >= pre_up { + make_kline_signal_v2(&k1, &k2, k3, "上升趋势", "超强") + } else if close > pre_mid { + make_kline_signal_v2(&k1, &k2, k3, "上升趋势", "强") + } else if close > pre_down { + make_kline_signal_v2(&k1, &k2, k3, "下降趋势", "强") + } else { + make_kline_signal_v2(&k1, &k2, k3, "下降趋势", "超强") + } +} + +/// cxt_second_bs_V240524:第二买卖点重叠计数信号 +/// +/// 参数模板:`"{freq}_D{di}W{w}T{t}_第二买卖点V240524"` +/// +/// 信号逻辑: +/// 1. 读取最近 `w` 笔,统计最后一笔终点分型与前面笔终点分型的重叠次数; +/// 2. 最后一笔为向下且长度足够、重叠次数不少于 `t` 时判定 `二买`; +/// 3. 最后一笔为向上且满足同样条件时判定 `二卖`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1W9T2_第二买卖点V240524_二买_任意_任意_0')` +/// - `Signal('60分钟_D1W9T2_第二买卖点V240524_二卖_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始取样,默认 `1`; +/// - `w`:统计窗口笔数,默认 `9`; +/// - `t`:最少重叠次数,默认 `2`。 +/// 对齐说明:与 Python `czsc.signals.cxt_second_bs_V240524` 保持一致。 +#[signal( + category = "kline", + name = "cxt_second_bs_V240524", + template = "{freq}_D{di}W{w}T{t}_第二买卖点V240524", + opcode = "CxtSecondBsV240524", + param_kind = "CxtSecondBsV240524" +)] +pub fn cxt_second_bs_v240524(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let w = get_usize_param(params, "w", 9); + let t = get_usize_param(params, "t", 2); + assert!(w > 5); + assert!(t >= 2); + let k1 = c.freq.to_string(); + let k2 = format!("D{}W{}T{}", di, w, t); + let k3 = "第二买卖点V240524"; + if c.bi_list.len() < w + di + 5 || c.bars_ubi.len() > 7 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bis = get_sub_elements(&c.bi_list, di, w); + let last = bis.last().unwrap(); + let last_fx_high = last.fx_b.high; + let last_fx_low = last.fx_b.low; + let fxs: Vec<&FX> = bis[..bis.len() - 1] + .iter() + .filter(|x| x.get_length() >= 7) + .map(|x| &x.fx_b) + .collect(); + let zs_count = fxs + .iter() + .filter(|fx| overlap(fx.high, fx.low, last_fx_high, last_fx_low)) + .count(); + if last.direction == Direction::Down && last.get_length() >= 7 && zs_count >= t { + return make_kline_signal_v1(&k1, &k2, k3, "二买"); + } + if last.direction == Direction::Up && last.get_length() >= 7 && zs_count >= t { + return make_kline_signal_v1(&k1, &k2, k3, "二卖"); + } + make_kline_signal_v1(&k1, &k2, k3, "其他") +} + +/// cxt_overlap_V240612:顺畅笔分型支撑压力信号 +/// +/// 参数模板:`"{freq}_SNR顺畅N{n}_支撑压力V240612"` +/// +/// 信号逻辑: +/// 1. 在最近 `n` 笔中筛选原始 K 线数量足够的笔,并按 SNR 排序; +/// 2. 选择 SNR 最高且大于阈值的“顺畅笔”,提取其顶分型与底分型区间; +/// 3. 若最新笔终点分型与这些区间重叠,则输出 `支撑/压力 + 顶分型/底分型`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_SNR顺畅N7_支撑压力V240612_支撑_顺畅笔顶分型_任意_0')` +/// - `Signal('60分钟_SNR顺畅N7_支撑压力V240612_压力_顺畅笔底分型_任意_0')` +/// +/// 参数说明: +/// - `n`:候选顺畅笔数量窗口,默认 `7`; +/// - 仅当最大 SNR 不低于 `0.7` 时才输出具体支撑压力。 +/// 对齐说明:与 Python `czsc.signals.cxt_overlap_V240612` 保持一致。 +#[signal( + category = "kline", + name = "cxt_overlap_V240612", + template = "{freq}_SNR顺畅N{n}_支撑压力V240612", + opcode = "CxtOverlapV240612", + param_kind = "CxtOverlapV240612" +)] +pub fn cxt_overlap_v240612(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let n = get_usize_param(params, "n", 7); + let k1 = c.freq.to_string(); + let k2 = format!("SNR顺畅N{}", n); + let k3 = "支撑压力V240612"; + if c.bi_list.len() < n + 2 || c.bars_ubi.len() > 7 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let mut bis: Vec<&BI> = get_sub_elements(&c.bi_list, 3, n) + .iter() + .filter(|x| x.get_raw_bars().len() >= 9) + .collect(); + if bis.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + bis.sort_by(|a, b| a.get_snr().partial_cmp(&b.get_snr()).unwrap_or(std::cmp::Ordering::Equal)); + let max_snr_bi = bis.last().unwrap(); + if max_snr_bi.get_snr() < 0.7 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let (fxg, fxd) = if max_snr_bi.direction == Direction::Down { + (&max_snr_bi.fx_a, &max_snr_bi.fx_b) + } else { + (&max_snr_bi.fx_b, &max_snr_bi.fx_a) + }; + let last_bi = c.bi_list.last().unwrap(); + let mut v1 = "其他"; + let mut v2 = "任意"; + if last_bi.direction == Direction::Down { + if overlap(fxg.high, fxg.low, last_bi.fx_b.high, last_bi.fx_b.low) { + v1 = "支撑"; + v2 = "顺畅笔顶分型"; + } + if overlap(fxd.high, fxd.low, last_bi.fx_b.high, last_bi.fx_b.low) { + v1 = "支撑"; + v2 = "顺畅笔底分型"; + } + } + if last_bi.direction == Direction::Up { + if overlap(fxg.high, fxg.low, last_bi.fx_b.high, last_bi.fx_b.low) { + v1 = "压力"; + v2 = "顺畅笔顶分型"; + } + if overlap(fxd.high, fxd.low, last_bi.fx_b.high, last_bi.fx_b.low) { + v1 = "压力"; + v2 = "顺畅笔底分型"; + } + } + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} diff --git a/crates/czsc-signals/src/cxt_trader.rs b/crates/czsc-signals/src/cxt_trader.rs new file mode 100644 index 000000000..24d3d3086 --- /dev/null +++ b/crates/czsc-signals/src/cxt_trader.rs @@ -0,0 +1,158 @@ +use crate::params::ParamView; +use crate::utils::sig::{get_str_param, get_usize_param, make_signal_v1}; +use czsc_core::objects::bar::RawBar; +use czsc_core::objects::direction::Direction; +use czsc_core::objects::signal::Signal; +use czsc_core::objects::state::TraderState; +use czsc_core::objects::zs::ZS; +use czsc_signal_macros::signal; + +fn is_valid_zs(bis: &[czsc_core::objects::bi::BI]) -> bool { + bis.len() >= 3 && ZS::new(bis.to_vec()).zg > ZS::new(bis.to_vec()).zd +} + +/// cxt_zhong_shu_gong_zhen_V221221:大小级别中枢共振 +/// +/// 参数模板:`"{freq1}_{freq2}_中枢共振V221221"` +/// +/// 信号逻辑: +/// 1. 大小级别最近 3 笔均构成有效中枢; +/// 2. 小级别中枢位置相对大级别中轴偏上且末笔向下,判 `看多`; +/// 3. 小级别中枢位置相对大级别中轴偏下且末笔向上,判 `看空`。 +#[signal( + category = "trader", + name = "cxt_zhong_shu_gong_zhen_V221221", + template = "{freq1}_{freq2}_中枢共振V221221", + opcode = "CxtZhongShuGongZhenV221221", + param_kind = "CxtZhongShuGongZhenV221221" +)] +pub fn cxt_zhong_shu_gong_zhen_v221221(cat: &dyn TraderState, params: &ParamView) -> Vec { + let freq1 = get_str_param(params, "freq1", "日线"); + let freq2 = get_str_param(params, "freq2", "60分钟"); + let k1 = freq1.to_string(); + let k2 = freq2.to_string(); + let k3 = "中枢共振V221221"; + + let Some(max_freq) = cat.get_czsc(freq1) else { + return make_signal_v1(&k1, &k2, k3, "其他"); + }; + let Some(min_freq) = cat.get_czsc(freq2) else { + return make_signal_v1(&k1, &k2, k3, "其他"); + }; + if max_freq.bi_list.len() < 5 || min_freq.bi_list.len() < 5 { + return make_signal_v1(&k1, &k2, k3, "其他"); + } + + let big_bis = &max_freq.bi_list[max_freq.bi_list.len() - 3..]; + let small_bis = &min_freq.bi_list[min_freq.bi_list.len() - 3..]; + if !is_valid_zs(big_bis) || !is_valid_zs(small_bis) { + return make_signal_v1(&k1, &k2, k3, "其他"); + } + + let big_zs = ZS::new(big_bis.to_vec()); + let small_zs = ZS::new(small_bis.to_vec()); + if small_zs.dd > big_zs.zz && min_freq.bi_list.last().unwrap().direction == Direction::Down { + return make_signal_v1(&k1, &k2, k3, "看多"); + } + if small_zs.gg < big_zs.zz && min_freq.bi_list.last().unwrap().direction == Direction::Up { + return make_signal_v1(&k1, &k2, k3, "看空"); + } + make_signal_v1(&k1, &k2, k3, "其他") +} + +/// cxt_intraday_V230701:30分钟日内走势分类 +/// +/// 参数模板:`"{freq1}#{freq2}_D{di}日_走势分类V230701"` +/// +/// 信号逻辑: +/// 1. 取指定日的 30 分钟 bars; +/// 2. 识别无中枢、双中枢、单中枢平衡市; +/// 3. 返回对应日内结构标签。 +#[signal( + category = "trader", + name = "cxt_intraday_V230701", + template = "{freq1}#{freq2}_D{di}日_走势分类V230701", + opcode = "CxtIntradayV230701", + param_kind = "CxtIntradayV230701" +)] +pub fn cxt_intraday_v230701(cat: &dyn TraderState, params: &ParamView) -> Vec { + let di = get_usize_param(params, "di", 2); + let freq1 = get_str_param(params, "freq1", "30分钟"); + let freq2 = get_str_param(params, "freq2", "日线"); + assert_eq!(freq1, "30分钟"); + assert_eq!(freq2, "日线"); + assert!(di > 0 && di < 21); + + let k1 = format!("{}#{}", freq1, freq2); + let k2 = format!("D{}日", di); + let k3 = "走势分类V230701"; + + let Some(c1) = cat.get_czsc(freq1) else { + return make_signal_v1(&k1, &k2, k3, "其他"); + }; + let Some(c2) = cat.get_czsc(freq2) else { + return make_signal_v1(&k1, &k2, k3, "其他"); + }; + if c2.bars_raw.len() < di { + return make_signal_v1(&k1, &k2, k3, "其他"); + } + + let day = c2.bars_raw[c2.bars_raw.len() - di].dt.date_naive(); + let bars: Vec<&RawBar> = c1 + .bars_raw + .iter() + .filter(|x| x.dt.date_naive() == day) + .collect(); + assert!(bars.len() <= 8, "仅适用于A股市场 30 分钟日内 8 根K线"); + if bars.len() <= 4 { + return make_signal_v1(&k1, &k2, k3, "其他"); + } + + let mut zs_list: Vec<(f64, f64)> = Vec::new(); + for w in bars.windows(3) { + let highs = [w[0].high, w[1].high, w[2].high]; + let lows = [w[0].low, w[1].low, w[2].low]; + let zg = highs.into_iter().fold(f64::INFINITY, f64::min); + let zd = lows.into_iter().fold(f64::NEG_INFINITY, f64::max); + if zg >= zd { + zs_list.push(( + [w[0].high, w[1].high, w[2].high] + .into_iter() + .fold(f64::NEG_INFINITY, f64::max), + [w[0].low, w[1].low, w[2].low] + .into_iter() + .fold(f64::INFINITY, f64::min), + )); + } + } + + let dir = if bars.last().unwrap().close > bars.first().unwrap().open { + "上涨" + } else { + "下跌" + }; + if zs_list.is_empty() { + return make_signal_v1(&k1, &k2, k3, &format!("无中枢{}", dir)); + } + + if zs_list.len() >= 2 { + let (zs1_high, zs1_low) = zs_list[0]; + let (zs2_high, zs2_low) = zs_list[zs_list.len() - 1]; + if (dir == "上涨" && zs1_high < zs2_low) || (dir == "下跌" && zs1_low > zs2_high) { + return make_signal_v1(&k1, &k2, k3, &format!("双中枢{}", dir)); + } + } + + let high_first = bars[0].high.max(bars[1].high).max(bars[2].high) + == bars.iter().map(|x| x.high).fold(f64::NEG_INFINITY, f64::max); + let low_first = bars[0].low.min(bars[1].low).min(bars[2].low) + == bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let v1 = if high_first && !low_first { + "弱平衡市" + } else if low_first && !high_first { + "强平衡市" + } else { + "转折平衡市" + }; + make_signal_v1(&k1, &k2, k3, v1) +} diff --git a/crates/czsc-signals/src/jcc.rs b/crates/czsc-signals/src/jcc.rs new file mode 100644 index 000000000..48f99def7 --- /dev/null +++ b/crates/czsc-signals/src/jcc.rs @@ -0,0 +1,1336 @@ +use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::sig::{get_sub_elements, get_usize_param, make_kline_signal_v1, make_kline_signal_v2}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::bar::RawBar; +use czsc_core::objects::direction::Direction; +use czsc_core::objects::signal::Signal; +use czsc_signal_macros::signal; + +fn get_f64_param(params: &ParamView, key: &str, default: f64) -> f64 { + if let Some(v) = params.value(key) { + if let Some(x) = v.as_f64() { + return x; + } + if let Some(s) = v.as_str() { + if let Ok(x) = s.parse::() { + return x; + } + } + } + default +} + +#[inline] +fn py_float_str(v: f64) -> String { + let mut s = v.to_string(); + if !s.contains('.') && !s.contains('e') && !s.contains('E') { + s.push_str(".0"); + } + s +} + +#[inline] +fn solid(bar: &RawBar) -> f64 { + (bar.open - bar.close).abs() +} + +#[inline] +fn upper(bar: &RawBar) -> f64 { + bar.high - bar.open.max(bar.close) +} + +#[inline] +fn lower(bar: &RawBar) -> f64 { + bar.open.min(bar.close) - bar.low +} + +fn variance(values: &[f64]) -> f64 { + if values.is_empty() { + return f64::NAN; + } + let mean = values.iter().sum::() / values.len() as f64; + values.iter().map(|x| (x - mean).powi(2)).sum::() / values.len() as f64 +} + +fn check_szx(bar: &RawBar, th: i32) -> bool { + if bar.close == bar.open && bar.high != bar.low { + return true; + } + if bar.close != bar.open && (bar.high - bar.low) / (bar.close - bar.open).abs() > th as f64 { + return true; + } + false +} + +/// jcc_san_xing_xian_V221023:伞形线形态信号 +/// +/// 参数模板:`"{freq}_D{di}TH{th}_伞形线"` +/// +/// 信号逻辑: +/// 1. 判断当前K线是否满足长下影短上影(下影 > 实体 * th,且上影 < 0.2 * 实体); +/// 2. 若满足,再结合左侧20根区间位置,判定 `锤子/上吊`; +/// 3. 不满足时返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('15分钟_D1TH200_伞形线_满足_锤子_任意_0')` +/// - `Signal('15分钟_D1TH200_伞形线_满足_上吊_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `th`:下影线与实体倍数阈值,默认 `2`(内部按 100 倍整数编码)。 +/// 对齐说明:与 Python `jcc_san_xing_xian_V221023` 判定顺序一致。 +#[signal( + category = "kline", + name = "jcc_san_xing_xian_V221023", + template = "{freq}_D{di}TH{th}_伞形线V221023", + opcode = "JccSanXingXianV221023", + param_kind = "JccSanXingXianV221023" +)] +pub fn jcc_san_xing_xian_v221023(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let th = get_f64_param(params, "th", 2.0); + let th_i = (th * 100.0) as i32; + + let k1 = c.freq.to_string(); + let k2 = format!("D{}TH{}", di, th_i); + let k3 = "伞形线"; + + if di == 0 || di > c.bars_raw.len() { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + } + + let bar = &c.bars_raw[c.bars_raw.len() - di]; + let x1 = bar.high - bar.open.max(bar.close); + let x2 = (bar.close - bar.open).abs(); + let x3 = bar.open.min(bar.close) - bar.low; + let v1 = if x3 > x2 * th_i as f64 / 100.0 && x1 < 0.2 * x2 { + "满足" + } else { + "其他" + }; + + let mut v2 = "其他"; + if c.bars_raw.len() > 20 + di { + let left_bars = get_sub_elements(&c.bars_raw, di, 20); + if !left_bars.is_empty() { + let left_max = left_bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let left_min = left_bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let gap = left_max - left_min; + if bar.low <= left_min + 0.25 * gap { + v2 = "锤子"; + } else if bar.high >= left_max - 0.25 * gap { + v2 = "上吊"; + } + } + } + + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// jcc_ten_mo_V221028:吞没形态 +/// +/// 参数模板:`"{freq}_D{di}_吞没形态"` +/// +/// 信号逻辑: +/// 1. 当前K线高低点完全包住前一根K线,记 `满足`; +/// 2. 结合左侧20根位置与实体方向,区分 `看涨吞没/看跌吞没`; +/// 3. 否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('15分钟_D1_吞没形态_满足_看涨吞没_任意_0')` +/// - `Signal('15分钟_D1_吞没形态_满足_看跌吞没_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`。 +/// 对齐说明:与 Python `jcc_ten_mo_V221028` 判定条件一致。 +#[signal( + category = "kline", + name = "jcc_ten_mo_V221028", + template = "{freq}_D{di}_吞没形态V221028", + opcode = "JccTenMoV221028", + param_kind = "JccTenMoV221028" +)] +pub fn jcc_ten_mo_v221028(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}", di); + let k3 = "吞没形态"; + + if c.bars_raw.len() < di + 1 || di == 0 { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + } + + let bar1 = &c.bars_raw[c.bars_raw.len() - di]; + let bar2 = &c.bars_raw[c.bars_raw.len() - di - 1]; + let v1 = if bar1.high > bar2.high && bar1.low < bar2.low { + "满足" + } else { + "其他" + }; + + let mut v2 = "其他"; + if c.bars_raw.len() > 20 + di { + let left_bars = get_sub_elements(&c.bars_raw, di, 20); + if !left_bars.is_empty() { + let left_max = left_bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let left_min = left_bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let gap = left_max - left_min; + + if bar1.low <= left_min + 0.25 * gap + && bar1.close > bar1.open + && bar1.close > bar2.high + && bar1.open < bar2.low + { + v2 = "看涨吞没"; + } else if bar1.high >= left_max - 0.25 * gap + && bar1.close < bar1.open + && bar1.close < bar2.low + && bar1.open > bar2.high + { + v2 = "看跌吞没"; + } + } + } + + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// jcc_wu_yun_gai_ding_V221101:乌云盖顶 +/// +/// 参数模板:`"{freq}_D{di}Z{z}TH{th}_乌云盖顶"` +/// +/// 信号逻辑: +/// 1. 前一根阳线实体涨幅需大于 `z`; +/// 2. 当前K线跳空高开,且收盘回落到前一根实体内部; +/// 3. 前一根收盘位于左侧10根收盘高位,判定 `满足`。 +/// +/// 信号列表示例: +/// - `Signal('日线_D1Z500TH50_乌云盖顶_满足_任意_任意_0')` +/// - `Signal('日线_D1Z500TH50_乌云盖顶_其他_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `z`:前一根阳线最小涨幅(BP),默认 `500`; +/// - `th`:当前收盘扎入前一根实体比例,默认 `50`。 +/// 对齐说明:与 Python `jcc_wu_yun_gai_ding_V221101` 一致。 +#[signal( + category = "kline", + name = "jcc_wu_yun_gai_ding_V221101", + template = "{freq}_D{di}Z{z}TH{th}_乌云盖顶V221101", + opcode = "JccWuYunGaiDingV221101", + param_kind = "JccWuYunGaiDingV221101" +)] +pub fn jcc_wu_yun_gai_ding_v221101( + c: &CZSC, + params: &ParamView, + _cache: &mut TaCache, +) -> Vec { + let di = get_usize_param(params, "di", 1); + let z = get_usize_param(params, "z", 500) as f64; + let th = get_usize_param(params, "th", 50) as f64; + + let k1 = c.freq.to_string(); + let k2 = format!("D{}Z{}TH{}", di, z as i32, th as i32); + let k3 = "乌云盖顶"; + let mut v1 = "其他"; + + if c.bars_raw.len() > di + 10 && di > 0 { + let pre_bar = &c.bars_raw[c.bars_raw.len() - di - 1]; + let bar = &c.bars_raw[c.bars_raw.len() - di]; + let z0 = (pre_bar.close - pre_bar.open) / pre_bar.open * 10000.0; + let flag_z = z0 > z; + let flag_ho = bar.open > pre_bar.high; + let flag_th = bar.close < (pre_bar.close + pre_bar.open) * (th / 100.0); + + let left_bars = get_sub_elements(&c.bars_raw, di + 2, 10); + if !left_bars.is_empty() { + let left_max_close = left_bars + .iter() + .map(|x| x.close) + .fold(f64::NEG_INFINITY, f64::max); + let flag_up = pre_bar.close >= left_max_close; + if flag_z && flag_ho && flag_th && flag_up { + v1 = "满足"; + } + } + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// jcc_ci_tou_V221101:刺透形态 +/// +/// 参数模板:`"{freq}_D{di}Z{z}TH{th}_刺透形态"` +/// +/// 信号逻辑: +/// 1. 前一根为大阴线且跌幅超过 `z`; +/// 2. 当前低开并收盘刺入前一根实体 `th` 比例以上; +/// 3. 满足则返回 `满足`。 +/// +/// 信号列表示例: +/// - `Signal('15分钟_D1Z100TH50_刺透形态_满足_任意_任意_0')` +/// - `Signal('15分钟_D1Z100TH50_刺透形态_其他_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `z`:前一根最小跌幅(BP),默认 `100`; +/// - `th`:刺入比例阈值,默认 `50`。 +/// 对齐说明:与 Python `jcc_ci_tou_V221101` 一致。 +#[signal( + category = "kline", + name = "jcc_ci_tou_V221101", + template = "{freq}_D{di}Z{z}TH{th}_刺透形态V221101", + opcode = "JccCiTouV221101", + param_kind = "JccCiTouV221101" +)] +pub fn jcc_ci_tou_v221101(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let z = get_usize_param(params, "z", 100) as f64; + let th = get_usize_param(params, "th", 50) as f64; + + let k1 = c.freq.to_string(); + let k2 = format!("D{}Z{}TH{}", di, z as i32, th as i32); + let k3 = "刺透形态"; + + if c.bars_raw.len() < di + 15 || di == 0 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let bars = get_sub_elements(&c.bars_raw, di, 2); + if bars.len() != 2 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bar2 = &bars[0]; + let bar1 = &bars[1]; + + let c1 = bar2.close < bar2.open && (1.0 - bar2.close / bar2.open) > z / 10000.0; + let c2 = + bar1.open < bar2.low && bar1.close > bar2.close + (bar2.open - bar2.close) * (th / 100.0); + let v1 = if c1 && c2 { "满足" } else { "其他" }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// jcc_san_fa_V20221118:三法形态A +/// +/// 参数模板:`"{freq}_D{di}K_三法A"` +/// +/// 信号逻辑: +/// 1. 在 5~8 根窗口内扫描三法形态; +/// 2. 满足上升三法或下降三法时输出方向; +/// 3. `v2` 记录触发窗口长度 `nK`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1K_三法A_上升三法_6K_任意_0')` +/// - `Signal('60分钟_D1K_三法A_下降三法_8K_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`。 +/// 对齐说明:窗口遍历与条件组合对齐 Python `jcc_san_fa_V20221118`。 +#[signal( + category = "kline", + name = "jcc_san_fa_V20221118", + template = "{freq}_D{di}K_三法AV20221118", + opcode = "JccSanFaV20221118", + param_kind = "JccSanFaV20221118" +)] +pub fn jcc_san_fa_v20221118(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}K", di); + let k3 = "三法A"; + + let check = |bars: &[RawBar]| -> &'static str { + if bars.len() < 5 { + return "其他"; + } + let last = bars.last().unwrap(); + let first = &bars[0]; + let c1 = if last.close > last.open && first.close > first.open && last.close > first.high { + "上升" + } else if last.close < last.open && first.close < first.open && last.close < first.low { + "下降" + } else { + "其他" + }; + + let mid = &bars[1..bars.len() - 1]; + let hhc = mid.iter().map(|x| x.close).fold(f64::NEG_INFINITY, f64::max); + let llc = mid.iter().map(|x| x.close).fold(f64::INFINITY, f64::min); + let hhv = mid.iter().map(|x| x.high).fold(f64::NEG_INFINITY, f64::max); + let llv = mid.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + + if c1 == "上升" && last.close > hhv && hhv > first.high && llv > first.open && first.close > hhc { + "上升三法" + } else if c1 == "下降" + && first.low > llv + && llv > last.close + && hhv < first.open + && first.close < llc + { + "下降三法" + } else { + "其他" + } + }; + + let mut v1 = "其他"; + let mut v2 = "其他"; + for n in [5usize, 6, 7, 8] { + let bars = get_sub_elements(&c.bars_raw, di, n); + let t = check(bars); + if t != "其他" { + v1 = t; + v2 = if n == 5 { + "5K" + } else if n == 6 { + "6K" + } else if n == 7 { + "7K" + } else { + "8K" + }; + break; + } + } + + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// jcc_san_fa_V20221115:三法形态 +/// +/// 参数模板:`"{freq}_D{di}K_三法"` +/// +/// 信号逻辑: +/// 1. 固定观察 6 根K线,比较首尾与中间三根位置关系; +/// 2. 满足基础强度阈值时,判定 `上升三法/下降三法`; +/// 3. 不满足返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1K_三法_满足_上升三法_任意_0')` +/// - `Signal('60分钟_D1K_三法_满足_下降三法_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `zdf`:首尾涨跌幅阈值(BP),默认 `500`。 +/// 对齐说明:与 Python `jcc_san_fa_V20221115` 判定一致。 +#[signal( + category = "kline", + name = "jcc_san_fa_V20221115", + template = "{freq}_D{di}K_三法V20221115", + opcode = "JccSanFaV20221115", + param_kind = "JccSanFaV20221115" +)] +pub fn jcc_san_fa_v20221115(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let zdf = get_usize_param(params, "zdf", 500) as f64; + let k1 = c.freq.to_string(); + let k2 = format!("D{}K", di); + let k3 = "三法"; + + let bars = get_sub_elements(&c.bars_raw, di, 6); + if bars.len() != 6 { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + } + let bar6 = &bars[0]; + let bar5 = &bars[1]; + let bar4 = &bars[2]; + let bar3 = &bars[3]; + let bar2 = &bars[4]; + let bar1 = &bars[5]; + + let bar1_zdf = ((bar2.close - bar1.close) / bar2.close).abs() * 10000.0; + let bar5_zdf = ((bar6.close - bar5.close) / bar6.close).abs() * 10000.0; + let max_high = bar2.high.max(bar3.high).max(bar4.high); + let min_low = bar2.low.min(bar3.low).min(bar4.low); + + let v1 = if bar1_zdf >= zdf && bar5_zdf > zdf && bar5.high > max_high { + "满足" + } else { + "其他" + }; + + let v2 = if bar5.close > bar5.open + && bar1.close > bar1.open + && bar1.close > bar5.high + && bar1.close > max_high + && bar1.open > bar2.close + { + "上升三法" + } else if bar5.close < bar5.open + && bar1.close < bar1.open + && bar1.close < bar5.low + && bar1.close < min_low + && bar1.open < bar2.close + { + "下降三法" + } else { + "其他" + }; + + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// jcc_xing_xian_V221118:星形线 +/// +/// 参数模板:`"{freq}_D{di}TH{th}_星形线"` +/// +/// 信号逻辑: +/// 1. 取三根K线,按高低点结构区分启明星/黄昏星候选; +/// 2. 再结合实体强弱关系做最终确认; +/// 3. 中间K线开收相等时输出 `中间十字`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D2TH2_星形线_启明星_中间十字_任意_0')` +/// - `Signal('60分钟_D2TH2_星形线_黄昏星_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `2`; +/// - `th`:左侧实体与中间实体的倍率阈值,默认 `2`。 +/// 对齐说明:与 Python `jcc_xing_xian_V221118` 条件一致。 +#[signal( + category = "kline", + name = "jcc_xing_xian_V221118", + template = "{freq}_D{di}TH{th}_星形线V221118", + opcode = "JccXingXianV221118", + param_kind = "JccXingXianV221118" +)] +pub fn jcc_xing_xian_v221118(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 2); + let th = get_usize_param(params, "th", 2) as f64; + let k1 = c.freq.to_string(); + let k2 = format!("D{}TH{}", di, th as i32); + let k3 = "星形线"; + + let bars = get_sub_elements(&c.bars_raw, di, 3); + if bars.len() != 3 { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "任意"); + } + let bar3 = &bars[0]; + let bar2 = &bars[1]; + let bar1 = &bars[2]; + + let x3 = (bar3.close - bar3.open).abs(); + let x2 = (bar2.close - bar2.open).abs(); + let x1 = (bar1.close - bar1.open).abs(); + + let mut v1 = "其他"; + if bar3.high > bar2.high + && bar2.high < bar1.high + && bar3.low > bar2.low + && bar2.low < bar1.low + && bar3.close < bar3.open + && x2 * th < x3 + && x3 < x2 + x1 + && bar1.close > bar1.open + && bar1.open > bar2.close.max(bar2.open) + { + v1 = "启明星"; + } else if bar3.high < bar2.high + && bar2.high > bar1.high + && bar3.low < bar2.low + && bar2.low > bar1.low + && bar3.close > bar3.open + && x2 * th < x3 + && x3 < x2 + x1 + && bar1.close < bar1.open + && bar1.open < bar2.close.min(bar2.open) + { + v1 = "黄昏星"; + } + + let v2 = if bar2.close == bar2.open { "中间十字" } else { "任意" }; + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// jcc_fen_shou_xian_V20221113:分手线 +/// +/// 参数模板:`"{freq}_D{di}K_分手线"` +/// +/// 信号逻辑: +/// 1. 两根K线同开盘,且第二根收盘突破第一根高低点,判 `满足`; +/// 2. 结合区间位置与实体方向,细分 `上升分手/下跌分手`; +/// 3. 否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1K_分手线_满足_上升分手_任意_0')` +/// - `Signal('60分钟_D1K_分手线_满足_下跌分手_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `zdf`:分手强度阈值(BP),默认 `300`。 +/// 对齐说明:与 Python `jcc_fen_shou_xian_V20221113` 一致。 +#[signal( + category = "kline", + name = "jcc_fen_shou_xian_V20221113", + template = "{freq}_D{di}K_分手线V20221113", + opcode = "JccFenShouXianV20221113", + param_kind = "JccFenShouXianV20221113" +)] +pub fn jcc_fen_shou_xian_v20221113( + c: &CZSC, + params: &ParamView, + _cache: &mut TaCache, +) -> Vec { + let di = get_usize_param(params, "di", 1); + let zdf = get_usize_param(params, "zdf", 300) as f64; + let k1 = c.freq.to_string(); + let k2 = format!("D{}K", di); + let k3 = "分手线"; + + if c.bars_raw.len() < di + 1 || di == 0 { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + } + + let bar1 = &c.bars_raw[c.bars_raw.len() - di]; + let bar2 = &c.bars_raw[c.bars_raw.len() - di - 1]; + let v1 = if (bar1.open == bar2.open && bar1.close < bar2.low) || bar1.close > bar2.high { + "满足" + } else { + "其他" + }; + + let mut v2 = "其他"; + if c.bars_raw.len() > 20 + di { + let left_bars = get_sub_elements(&c.bars_raw, di, 20); + if !left_bars.is_empty() { + let left_max = left_bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let left_min = left_bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let gap = left_max - left_min; + + if bar1.low <= left_min + 0.25 * gap + && bar1.open == bar2.open + && bar1.close < bar2.low + && bar2.close > bar2.open + && (bar2.close - bar1.close) / bar2.close * 10000.0 > zdf + { + v2 = "下跌分手"; + } else if bar1.high >= left_max - 0.25 * gap + && bar1.open == bar2.open + && bar1.close > bar2.high + && bar2.close < bar2.open + && (bar1.close - bar2.close) / bar2.close * 10000.0 > zdf + { + v2 = "上升分手"; + } + } + } + + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// jcc_zhu_huo_xian_V221027:烛火线 +/// +/// 参数模板:`"{freq}_D{di}T{th}F{zf}_烛火线"` +/// +/// 信号逻辑: +/// 1. 以影线、实体和振幅阈值判断是否 `满足`; +/// 2. 结合左侧20根区间位置判定 `箭在弦/风中烛`; +/// 3. 否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1T200F500_烛火线_满足_箭在弦_任意_0')` +/// - `Signal('60分钟_D1T200F500_烛火线_满足_风中烛_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `th`:上影线与实体倍数阈值,默认 `2`; +/// - `zf`:最小振幅阈值(BP),默认 `500`。 +/// 对齐说明:按 Python `jcc_zhu_huo_xian_V221027` 原始公式实现。 +#[signal( + category = "kline", + name = "jcc_zhu_huo_xian_V221027", + template = "{freq}_D{di}T{th}F{zf}_烛火线V221027", + opcode = "JccZhuHuoXianV221027", + param_kind = "JccZhuHuoXianV221027" +)] +pub fn jcc_zhu_huo_xian_v221027( + c: &CZSC, + params: &ParamView, + _cache: &mut TaCache, +) -> Vec { + let di = get_usize_param(params, "di", 1); + let th = get_f64_param(params, "th", 2.0); + let zf = get_usize_param(params, "zf", 500) as f64; + let k1 = c.freq.to_string(); + let k2 = format!("D{}T{}F{}", di, py_float_str(th), zf as i32); + let k3 = "烛火线"; + + if di == 0 || di > c.bars_raw.len() { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + } + + let bar = &c.bars_raw[c.bars_raw.len() - di]; + let x1 = bar.high - bar.open.max(bar.close); + let x2 = (bar.close - bar.open).abs(); + let x3 = bar.open.min(bar.close) - bar.low; + let zf_min = if bar.low != 0.0 { + (bar.high - bar.low) / bar.low * 10000.0 >= zf + } else { + false + }; + + let v1 = if x1 > x2 * th / 100.0 && x3 < 0.2 * x2 && x3 < 0.5 * x1 && zf_min { + "满足" + } else { + "其他" + }; + + let mut v2 = "其他"; + if c.bars_raw.len() > 20 + di { + let left_bars = get_sub_elements(&c.bars_raw, di, 20); + if !left_bars.is_empty() { + let left_max = left_bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let left_min = left_bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let gap = left_max - left_min; + if bar.low <= left_min + 0.25 * gap { + v2 = "箭在弦"; + } else if bar.high >= left_max - 0.25 * gap { + v2 = "风中烛"; + } + } + } + + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// jcc_yun_xian_V221118:孕线形态 +/// +/// 参数模板:`"{freq}_D{di}_孕线"` +/// +/// 信号逻辑: +/// 1. 前一根为长实体,当前为小实体; +/// 2. 当前开收位于前一根实体区间内; +/// 3. 方向反转判 `看多/看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1_孕线_看多_任意_任意_0')` +/// - `Signal('60分钟_D1_孕线_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`。 +/// 对齐说明:与 Python `jcc_yun_xian_V221118` 一致。 +#[signal( + category = "kline", + name = "jcc_yun_xian_V221118", + template = "{freq}_D{di}_孕线V221118", + opcode = "JccYunXianV221118", + param_kind = "JccYunXianV221118" +)] +pub fn jcc_yun_xian_v221118(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}", di); + let k3 = "孕线"; + + let bars = get_sub_elements(&c.bars_raw, di, 2); + if bars.len() != 2 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bar2 = &bars[0]; + let bar1 = &bars[1]; + + let mut v1 = "其他"; + if solid(bar2) > upper(bar2).max(lower(bar2)) && solid(bar1) < upper(bar1).max(lower(bar1)) { + if bar2.close > bar1.close + && bar1.close > bar2.open + && bar2.close > bar1.open + && bar1.open > bar2.open + { + v1 = "看空"; + } + if bar2.close < bar1.close + && bar1.close < bar2.open + && bar2.close < bar1.open + && bar1.open < bar2.open + { + v1 = "看多"; + } + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// jcc_ping_tou_V221113:平头形态 +/// +/// 参数模板:`"{freq}_D{di}TH{th}_平头形态"` +/// +/// 信号逻辑: +/// 1. 对比两根K线高点或低点差值比例,识别 `顶部/底部`; +/// 2. 实体条件满足时给出 `实体标准` 标签; +/// 3. 否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('15分钟_D2TH100_平头形态_顶部_实体标准_任意_0')` +/// - `Signal('15分钟_D2TH100_平头形态_底部_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `2`; +/// - `th`:高低点容差阈值(BP),默认 `100`。 +/// 对齐说明:与 Python `jcc_ping_tou_V221113` 一致。 +#[signal( + category = "kline", + name = "jcc_ping_tou_V221113", + template = "{freq}_D{di}TH{th}_平头形态V221113", + opcode = "JccPingTouV221113", + param_kind = "JccPingTouV221113" +)] +pub fn jcc_ping_tou_v221113(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 2); + let th = get_usize_param(params, "th", 100) as f64; + let k1 = c.freq.to_string(); + let k2 = format!("D{}TH{}", di, th as i32); + let k3 = "平头形态"; + + let bars = get_sub_elements(&c.bars_raw, di, 2); + if bars.len() != 2 { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "任意"); + } + let bar2 = &bars[0]; + let bar1 = &bars[1]; + + let v1 = if (bar2.low - bar1.low).abs() * 10000.0 / bar2.low.max(bar1.low) < th { + "底部" + } else if (bar2.high - bar1.high).abs() * 10000.0 / bar2.high.max(bar1.high) < th { + "顶部" + } else { + "其他" + }; + + let v2 = if solid(bar2) > solid(bar1).max(upper(bar1)) { + "实体标准" + } else { + "任意" + }; + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// jcc_two_crow_V221108:两只乌鸦 +/// +/// 参数模板:`"{freq}_D{di}K_两只乌鸦"` +/// +/// 信号逻辑: +/// 1. 第一根长阳; +/// 2. 第二根高开低走且收在第一根高点之上; +/// 3. 第三根阴线继续下压,判 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1K_两只乌鸦_看空_任意_任意_0')` +/// - `Signal('60分钟_D1K_两只乌鸦_其他_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`。 +/// 对齐说明:与 Python `jcc_two_crow_V221108` 一致。 +#[signal( + category = "kline", + name = "jcc_two_crow_V221108", + template = "{freq}_D{di}K_两只乌鸦V221108", + opcode = "JccTwoCrowV221108", + param_kind = "JccTwoCrowV221108" +)] +pub fn jcc_two_crow_v221108(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}K", di); + let k3 = "两只乌鸦"; + + let bars = get_sub_elements(&c.bars_raw, di, 3); + if bars.len() != 3 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bar3 = &bars[0]; + let bar2 = &bars[1]; + let bar1 = &bars[2]; + + let c1 = bar3.close > bar3.open && solid(bar3) > upper(bar3).max(lower(bar3)); + let c2 = bar2.open > bar2.close && bar2.close > bar3.high; + let c3 = bar1.close < bar1.open && bar1.close < bar2.close; + let v1 = if c1 && c2 && c3 { "看空" } else { "其他" }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// jcc_three_crow_V221108:三只乌鸦 +/// +/// 参数模板:`"{freq}_D{di}_三只乌鸦"` +/// +/// 信号逻辑: +/// 1. 三根连续阴线且高低收盘递降; +/// 2. 影线和实体关系满足强空形态; +/// 3. 根据开盘关系细分 `常规/加强/半加强`。 +/// +/// 信号列表示例: +/// - `Signal('30分钟_D1_三只乌鸦_满足_常规_任意_0')` +/// - `Signal('30分钟_D1_三只乌鸦_满足_加强_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`。 +/// 对齐说明:与 Python `jcc_three_crow_V221108` 一致。 +#[signal( + category = "kline", + name = "jcc_three_crow_V221108", + template = "{freq}_D{di}_三只乌鸦V221108", + opcode = "JccThreeCrowV221108", + param_kind = "JccThreeCrowV221108" +)] +pub fn jcc_three_crow_v221108(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}", di); + let k3 = "三只乌鸦"; + + if c.bars_raw.len() < di + 2 || di == 0 { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + } + + let bar1 = &c.bars_raw[c.bars_raw.len() - di]; + let bar2 = &c.bars_raw[c.bars_raw.len() - di - 1]; + let bar3 = &c.bars_raw[c.bars_raw.len() - di - 2]; + + if c.bars_raw.len() > 23 { + let left_bars = get_sub_elements(&c.bars_raw, 3, 20); + if !left_bars.is_empty() { + let left_max = left_bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let left_min = left_bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let gap = left_max - left_min; + if bar3.high < left_max - 0.25 * gap { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + } + let _ = left_min; + } + } + + let mut v1 = "其他"; + let mut v2 = "其他"; + if bar1.close < bar1.open + && bar2.close < bar2.open + && bar3.open > bar3.close + && bar3.close > bar2.close + && bar2.close > bar1.close + && bar3.high > bar2.high + && bar2.high > bar1.high + { + let c_low = (bar1.close - bar1.low) < 0.5 * (bar1.open - bar1.close) + && (bar2.close - bar2.low) < 0.5 * (bar2.open - bar2.close) + && (bar3.close - bar3.low) < 0.5 * (bar3.open - bar3.close); + let c_up = (bar1.high - bar1.open) < (bar1.open - bar1.close) + && (bar2.high - bar2.open) < (bar2.open - bar2.close) + && (bar3.high - bar3.open) < (bar3.open - bar3.close); + if c_low && c_up { + if bar2.close <= bar1.open + && bar1.open <= bar2.open + && bar3.close <= bar2.open + && bar2.open <= bar3.open + { + v1 = "满足"; + v2 = "常规"; + } else if bar1.open < bar2.close && bar2.open < bar3.close { + v1 = "满足"; + v2 = "加强"; + } else if (bar2.close <= bar1.open + && bar1.open <= bar2.open + && bar3.open < bar3.close) + || (bar3.close <= bar2.open && bar2.open <= bar3.open && bar1.open < bar2.close) + { + v1 = "满足"; + v2 = "半加强"; + } + } + } + + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// jcc_szx_V221111:十字线 +/// +/// 参数模板:`"{freq}_D{di}TH{th}_十字线"` +/// +/// 信号逻辑: +/// 1. `(high-low)/|close-open| > th` 或 `close==open` 判十字线; +/// 2. 按上下影长度细分 `蜻蜓/墓碑/长腿/十字线`; +/// 3. 前一根强阳时追加 `北方`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1TH10_十字线_蜻蜓十字线_北方_任意_0')` +/// - `Signal('60分钟_D1TH10_十字线_墓碑十字线_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `th`:十字线阈值,默认 `10`。 +/// 对齐说明:与 Python `jcc_szx_V221111` 一致。 +#[signal( + category = "kline", + name = "jcc_szx_V221111", + template = "{freq}_D{di}TH{th}_十字线V221111", + opcode = "JccSzxV221111", + param_kind = "JccSzxV221111" +)] +pub fn jcc_szx_v221111(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let th = get_usize_param(params, "th", 10) as i32; + let k1 = c.freq.to_string(); + let k2 = format!("D{}TH{}", di, th); + let k3 = "十字线"; + + if c.bars_raw.len() < di + 10 || di == 0 { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + } + let bars = get_sub_elements(&c.bars_raw, di, 2); + if bars.len() != 2 { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + } + let bar2 = &bars[0]; + let bar1 = &bars[1]; + + let v1 = if check_szx(bar1, th) { + let upper = upper(bar1); + let body = solid(bar1); + let lower = lower(bar1); + if lower > upper * 2.0 { + "蜻蜓十字线" + } else if lower == 0.0 || lower < body { + "墓碑十字线" + } else if lower > solid(bar2) && upper > solid(bar2) { + "长腿十字线" + } else { + "十字线" + } + } else { + "其他" + }; + + let v2 = if bar2.close > bar2.open && solid(bar2) > (upper(bar2) + lower(bar2)) * 3.0 { + "北方" + } else { + "任意" + }; + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// jcc_san_szx_V221122:三星形态 +/// +/// 参数模板:`"{freq}_D{di}T{th}_三星"` +/// +/// 信号逻辑: +/// 1. 取最近5根K线; +/// 2. 统计十字线数量; +/// 3. 不少于3根判 `满足`。 +/// +/// 信号列表示例: +/// - `Signal('15分钟_D1T10_三星_满足_任意_任意_0')` +/// - `Signal('15分钟_D1T10_三星_其他_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `th`:十字线阈值,默认 `10`。 +/// 对齐说明:与 Python `jcc_san_szx_V221122` 一致。 +#[signal( + category = "kline", + name = "jcc_san_szx_V221122", + template = "{freq}_D{di}T{th}_三星V221122", + opcode = "JccSanSzxV221122", + param_kind = "JccSanSzxV221122" +)] +pub fn jcc_san_szx_v221122(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let th = get_usize_param(params, "th", 10) as i32; + let k1 = c.freq.to_string(); + let k2 = format!("D{}T{}", di, th); + let k3 = "三星"; + + let mut v1 = "其他"; + if c.bars_raw.len() > 6 + di { + let bars = get_sub_elements(&c.bars_raw, di, 5); + let cnt = bars.iter().filter(|b| check_szx(b, th)).count(); + if cnt >= 3 { + v1 = "满足"; + } + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// jcc_fan_ji_xian_V221121:反击线 +/// +/// 参数模板:`"{freq}_D{di}_反击线"` +/// +/// 信号逻辑: +/// 1. 最近20根内检测收盘接近、跳空幅度和实体强度; +/// 2. 满足基础条件后,按区间位置和方向细分 `看涨反击线/看跌反击线`; +/// 3. 否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('15分钟_D1_反击线_满足_看涨反击线_任意_0')` +/// - `Signal('15分钟_D1_反击线_满足_看跌反击线_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`。 +/// 对齐说明:与 Python `jcc_fan_ji_xian_V221121` 一致。 +#[signal( + category = "kline", + name = "jcc_fan_ji_xian_V221121", + template = "{freq}_D{di}_反击线V221121", + opcode = "JccFanJiXianV221121", + param_kind = "JccFanJiXianV221121" +)] +pub fn jcc_fan_ji_xian_v221121(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}", di); + let k3 = "反击线"; + + if c.bars_raw.len() < 20 + di { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "任意"); + } + + let left_bars = get_sub_elements(&c.bars_raw, di, 20); + if left_bars.len() < 3 { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "任意"); + } + let left_max = left_bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let left_min = left_bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let gap = left_max - left_min; + + let bar1 = &left_bars[left_bars.len() - 3]; + let bar2 = &left_bars[left_bars.len() - 2]; + let bar3 = &left_bars[left_bars.len() - 1]; + + let mut v1 = "其他"; + if bar2.close != bar2.open { + let bar2h = (bar2.close - bar2.open).abs(); + let x1 = (bar3.open - bar2.close).abs() / bar2h; + let x2 = (bar3.close - bar2.close).abs() / bar2h; + let x3 = bar2h / gap; + if x1 >= 1.0 && x2 <= 0.1 && x3 >= 0.02 { + v1 = "满足"; + } + } + + let mut v2 = "任意"; + if v1 == "满足" { + if bar1.low <= left_min + 0.25 * gap + && bar1.close > bar2.close + && bar2.open > bar2.close + && bar2.close > bar3.open + { + v2 = "看涨反击线"; + } else if bar1.high >= left_max - 0.25 * gap + && bar2.close > bar1.close + && bar3.open > bar2.close + && bar2.close > bar2.open + { + v2 = "看跌反击线"; + } + } + + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// jcc_shan_chun_V221121:山川形态 +/// +/// 参数模板:`"{freq}_D{di}B_山川形态"` +/// +/// 信号逻辑: +/// 1. 取最近5笔; +/// 2. 末笔向上且 `5/3/1` 笔高点方差小,判 `三山`; +/// 3. 末笔向下且 `5/3/1` 笔低点方差小,判 `三川`。 +/// +/// 信号列表示例: +/// - `Signal('15分钟_D1B_山川形态_三山_任意_任意_0')` +/// - `Signal('15分钟_D1B_山川形态_三川_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:截止倒数第 `di` 笔,默认 `1`。 +/// 对齐说明:方差阈值与 Python `jcc_shan_chun_V221121` 一致。 +#[signal( + category = "kline", + name = "jcc_shan_chun_V221121", + template = "{freq}_D{di}B_山川形态V221121", + opcode = "JccShanChunV221121", + param_kind = "JccShanChunV221121" +)] +pub fn jcc_shan_chun_v221121(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}B", di); + let k3 = "山川形态"; + + let mut v1 = "其他"; + if c.bi_list.len() >= 6 + di { + let bis = get_sub_elements(&c.bi_list, di, 5); + if bis.len() == 5 { + let b5 = &bis[0]; + let b3 = &bis[2]; + let b1 = &bis[4]; + if matches!(b1.direction, Direction::Up) + && variance(&[b5.get_high(), b3.get_high(), b1.get_high()]) < 0.2 + { + v1 = "三山"; + } + if matches!(b1.direction, Direction::Down) + && variance(&[b5.get_low(), b3.get_low(), b1.get_low()]) < 0.2 + { + v1 = "三川"; + } + } + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// jcc_gap_yin_yang_V221121:跳空并列阴阳 +/// +/// 参数模板:`"{freq}_D{di}K_并列阴阳"` +/// +/// 信号逻辑: +/// 1. 最近三根满足跳空窗口; +/// 2. 两根并列阴阳实体方差小于阈值; +/// 3. 判定 `向上跳空/向下跳空`。 +/// +/// 信号列表示例: +/// - `Signal('15分钟_D1K_并列阴阳_向上跳空_任意_任意_0')` +/// - `Signal('15分钟_D1K_并列阴阳_向下跳空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`。 +/// 对齐说明:与 Python `jcc_gap_yin_yang_V221121` 一致。 +#[signal( + category = "kline", + name = "jcc_gap_yin_yang_V221121", + template = "{freq}_D{di}K_并列阴阳V221121", + opcode = "JccGapYinYangV221121", + param_kind = "JccGapYinYangV221121" +)] +pub fn jcc_gap_yin_yang_v221121( + c: &CZSC, + params: &ParamView, + _cache: &mut TaCache, +) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}K", di); + let k3 = "并列阴阳"; + + let mut v1 = "其他"; + if c.bars_raw.len() > di + 5 { + let bars = get_sub_elements(&c.bars_raw, di, 3); + if bars.len() == 3 { + let bar3 = &bars[0]; + let bar2 = &bars[1]; + let bar1 = &bars[2]; + + if bar1.low.min(bar2.low) > bar3.high + && bar2.close > bar2.open + && bar1.close < bar1.open + && variance(&[solid(bar1), solid(bar2)]) < 0.2 + { + v1 = "向上跳空"; + } else if bar1.high.max(bar2.high) < bar3.low + && bar2.close < bar2.open + && bar1.close > bar1.open + && variance(&[solid(bar1), solid(bar2)]) < 0.2 + { + v1 = "向下跳空"; + } + } + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// jcc_ta_xing_V221124:塔形顶底 +/// +/// 参数模板:`"{freq}_D{di}K_塔形"` +/// +/// 信号逻辑: +/// 1. 在 5~9 根窗口内扫描; +/// 2. 首尾实体最大且中间高低点聚集; +/// 3. 判定 `顶部/底部`,并返回窗口长度。 +/// +/// 信号列表示例: +/// - `Signal('15分钟_D1K_塔形_顶部_6K_任意_0')` +/// - `Signal('15分钟_D1K_塔形_底部_8K_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`。 +/// 对齐说明:与 Python `jcc_ta_xing_V221124` 一致。 +#[signal( + category = "kline", + name = "jcc_ta_xing_V221124", + template = "{freq}_D{di}K_塔形V221124", + opcode = "JccTaXingV221124", + param_kind = "JccTaXingV221124" +)] +pub fn jcc_ta_xing_v221124(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}K", di); + let k3 = "塔形"; + + let check = |bars: &[RawBar]| -> &'static str { + if bars.len() < 5 { + return "其他"; + } + let rb = &bars[0]; + let lb = bars.last().unwrap(); + let mut solids: Vec = bars.iter().map(solid).collect(); + solids.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + if solid(rb).min(solid(lb)) >= solids[solids.len() - 2] { + let mid = &bars[1..bars.len() - 1]; + let g_c1 = rb.close > rb.open && lb.close < lb.open; + let g_c2 = variance(&mid.iter().map(|x| x.high).collect::>()) < 0.5; + let g_c3 = mid.iter().all(|x| x.low > rb.open.max(lb.close)); + if g_c1 && g_c2 && g_c3 { + return "顶部"; + } + + let d_c1 = rb.close < rb.open && lb.close > lb.open; + let d_c2 = variance(&mid.iter().map(|x| x.low).collect::>()) < 0.5; + let d_c3 = mid.iter().all(|x| x.high < rb.open.min(lb.close)); + if d_c1 && d_c2 && d_c3 { + return "底部"; + } + } + "其他" + }; + + let mut v1 = "其他"; + let mut v2 = "其他"; + for n in [5usize, 6, 7, 8, 9] { + let bars = get_sub_elements(&c.bars_raw, di, n); + let t = check(bars); + if t != "其他" { + v1 = t; + v2 = if n == 5 { + "5K" + } else if n == 6 { + "6K" + } else if n == 7 { + "7K" + } else if n == 8 { + "8K" + } else { + "9K" + }; + break; + } + } + + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} diff --git a/crates/czsc-signals/src/kcatr.rs b/crates/czsc-signals/src/kcatr.rs new file mode 100644 index 000000000..40e675fc9 --- /dev/null +++ b/crates/czsc-signals/src/kcatr.rs @@ -0,0 +1,78 @@ +use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::sig::{get_sub_elements, make_kline_signal_v1}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::signal::Signal; +use czsc_signal_macros::signal; + +/// kcatr_up_dw_line_V230823:ATR 通道突破多空 +/// +/// 参数模板:`"{freq}_D{di}N{n}M{m}T{th}_KCATR多空V230823"` +/// +/// 信号逻辑: +/// 1. 在最近 `n` 根上计算平均真实波幅 `ATR`; +/// 2. 在最近 `m` 根上计算收盘均值 `middle`; +/// 3. 最新收盘价大于 `middle + ATR * th` 判 `看多`; +/// 4. 最新收盘价小于 `middle - ATR * th` 判 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N30M16T2_KCATR多空V230823_看多_任意_任意_0')` +/// - `Signal('60分钟_D1N30M16T2_KCATR多空V230823_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `n`:ATR 计算窗口,默认 `30`; +/// - `m`:中轨均值窗口,默认 `16`; +/// - `th`:ATR 倍数阈值,默认 `2`。 +/// 对齐说明:ATR 取样与突破阈值口径对齐 Python `kcatr_up_dw_line_V230823`。 +#[signal( + category = "kline", + name = "kcatr_up_dw_line_V230823", + template = "{freq}_D{di}N{n}M{m}T{th}_KCATR多空V230823", + opcode = "KcatrUpDwLineV230823", + param_kind = "KcatrUpDwLineV230823" +)] +pub fn kcatr_up_dw_line_v230823( + c: &CZSC, + params: &ParamView, + _cache: &mut TaCache, +) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 30); + let m = params.usize("m", 16); + let th = params.usize("th", 2); + + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}M{}T{}", di, n, m, th); + let k3 = "KCATR多空V230823"; + let mut v1 = "其他"; + + if c.bars_raw.len() < di + n.max(m) + 10 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let n_bars = get_sub_elements(&c.bars_raw, di, n); + let m_bars = get_sub_elements(&c.bars_raw, di, m); + if n_bars.len() < 2 || m_bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let mut tr_sum = 0.0; + for i in 1..n_bars.len() { + let b = &n_bars[i]; + let p = &n_bars[i - 1]; + let tr1 = (b.high - b.low).abs(); + let tr2 = (b.high - p.close).abs(); + let tr3 = (b.low - p.close).abs(); + tr_sum += tr1.max(tr2).max(tr3); + } + let atr = tr_sum / (n_bars.len() - 1) as f64; + let middle = m_bars.iter().map(|x| x.close).sum::() / m_bars.len() as f64; + let close = m_bars[m_bars.len() - 1].close; + + if close > middle + atr * th as f64 { + v1 = "看多"; + } else if close < middle - atr * th as f64 { + v1 = "看空"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} diff --git a/crates/czsc-signals/src/lib.rs b/crates/czsc-signals/src/lib.rs new file mode 100644 index 000000000..7169ac95d --- /dev/null +++ b/crates/czsc-signals/src/lib.rs @@ -0,0 +1,116 @@ +//! czsc-signals — signal function library. +//! +//! Migrated from rs-czsc 47ef6efa per docs/MIGRATION_NOTES.md §1. +//! Each signal sub-module is wrapped in `#[signal_module]` (proc-macro +//! from czsc-signal-macros) which validates signatures and registers +//! every `#[signal(...)]` function into the global inventory. + +extern crate self as czsc_signals; + +use czsc_signal_macros::signal_module; + +#[signal_module(category = "kline")] +pub mod bar { + include!("bar.rs"); +} + +#[signal_module(category = "kline")] +pub mod cxt { + include!("cxt.rs"); +} + +#[signal_module(category = "trader")] +pub mod cxt_trader { + include!("cxt_trader.rs"); +} + +#[signal_module(category = "trader")] +pub mod pos { + include!("pos.rs"); +} + +#[signal_module(category = "trader")] +pub mod cat { + include!("cat.rs"); +} +pub mod params; +pub mod registry; + +#[signal_module(category = "kline")] +pub mod tas { + include!("tas.rs"); +} + +#[signal_module(category = "kline")] +pub mod vol { + include!("vol.rs"); +} + +#[signal_module(category = "kline")] +pub mod pressure { + include!("pressure.rs"); +} + +#[signal_module(category = "kline")] +pub mod obv { + include!("obv.rs"); +} + +#[signal_module(category = "kline")] +pub mod cvolp { + include!("cvolp.rs"); +} + +#[signal_module(category = "kline")] +pub mod ntmdk { + include!("ntmdk.rs"); +} + +#[signal_module(category = "kline")] +pub mod kcatr { + include!("kcatr.rs"); +} + +#[signal_module(category = "kline")] +pub mod clv { + include!("clv.rs"); +} + +#[signal_module(category = "kline")] +pub mod ang { + include!("ang.rs"); +} + +#[signal_module(category = "kline")] +pub mod coo { + include!("coo.rs"); +} + +#[signal_module(category = "kline")] +pub mod byi { + include!("byi.rs"); +} + +#[signal_module(category = "kline")] +pub mod jcc { + include!("jcc.rs"); +} + +#[signal_module(category = "kline")] +pub mod xl { + include!("xl.rs"); +} + +#[signal_module(category = "kline")] +pub mod zdy { + include!("zdy.rs"); +} + +#[signal_module(category = "trader")] +pub mod zdy_trader { + include!("zdy_trader.rs"); +} +pub mod types; +pub mod utils; + +inventory::collect!(crate::types::SignalDescriptor); diff --git a/crates/czsc-signals/src/ntmdk.rs b/crates/czsc-signals/src/ntmdk.rs new file mode 100644 index 000000000..3843e759a --- /dev/null +++ b/crates/czsc-signals/src/ntmdk.rs @@ -0,0 +1,54 @@ +use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::sig::{get_sub_elements, make_kline_signal_v1}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::signal::Signal; +use czsc_signal_macros::signal; + +/// ntmdk_V230824:M 日前收盘价对比多空 +/// +/// 参数模板:`"{freq}_D{di}M{m}_NTMDK多空V230824"` +/// +/// 信号逻辑: +/// 1. 取截止倒数第 `di` 根的最近 `m` 根K线; +/// 2. 若末根收盘价大于首根收盘价,判 `看多`; +/// 3. 否则判 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1M10_NTMDK多空V230824_看多_任意_任意_0')` +/// - `Signal('60分钟_D1M10_NTMDK多空V230824_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `m`:回看比较窗口,默认 `10`。 +/// 对齐说明:比较口径对齐 Python `ntmdk_V230824`。 +#[signal( + category = "kline", + name = "ntmdk_V230824", + template = "{freq}_D{di}M{m}_NTMDK多空V230824", + opcode = "NtmdkV230824", + param_kind = "NtmdkV230824" +)] +pub fn ntmdk_v230824(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let m = params.usize("m", 10); + + let k1 = c.freq.to_string(); + let k2 = format!("D{}M{}", di, m); + let k3 = "NTMDK多空V230824"; + let v1_default = "其他"; + + if c.bars_raw.len() < di + m + 10 { + return make_kline_signal_v1(&k1, &k2, k3, v1_default); + } + let bars = get_sub_elements(&c.bars_raw, di, m); + if bars.len() < 2 { + return make_kline_signal_v1(&k1, &k2, k3, v1_default); + } + let v1 = if bars[bars.len() - 1].close > bars[0].close { + "看多" + } else { + "看空" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} diff --git a/crates/czsc-signals/src/obv.rs b/crates/czsc-signals/src/obv.rs new file mode 100644 index 000000000..5c27aa55e --- /dev/null +++ b/crates/czsc-signals/src/obv.rs @@ -0,0 +1,277 @@ +use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::sig::{bar_index_map, get_sub_elements, make_kline_signal_v1}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::signal::Signal; +use czsc_signal_macros::signal; + +fn signed_vol(bar: &czsc_core::objects::bar::RawBar) -> f64 { + if bar.close > bar.open { + bar.vol + } else { + -bar.vol + } +} + +/// 更新 OBV 缓存,保持与 Python `bar.cache["OBV"]` 的累计语义一致。 +fn update_obv_cache(c: &CZSC, cache: &mut TaCache) { + let cache_key = "OBV"; + let now_len = c.bars_raw.len(); + if now_len == 0 { + return; + } + let bar_ids: Vec = c.bars_raw.iter().map(|b| b.id).collect(); + + let mut need_init = !cache.series.contains_key(cache_key); + if !need_init { + if let Some(existing_ids) = cache.series_ids.get(cache_key) { + if now_len < 2 || existing_ids.is_empty() { + need_init = true; + } else { + let penultimate_id = bar_ids[now_len - 2]; + need_init = !existing_ids.contains(&penultimate_id); + } + } else { + need_init = true; + } + } + + if need_init { + let mut obv = vec![0.0; now_len]; + obv[0] = signed_vol(&c.bars_raw[0]); + for i in 1..now_len { + obv[i] = obv[i - 1] + signed_vol(&c.bars_raw[i]); + } + cache.series.insert(cache_key.to_string(), obv); + cache.series_ids.insert(cache_key.to_string(), bar_ids); + cache.last_len = now_len; + return; + } + + let existing = cache.series.get(cache_key).unwrap(); + let existing_ids = cache.series_ids.get(cache_key).unwrap(); + let mut old_map = std::collections::HashMap::with_capacity(existing_ids.len()); + for (i, id) in existing_ids.iter().enumerate() { + old_map.insert(*id, existing[i]); + } + + let mut obv = Vec::with_capacity(now_len); + for id in &bar_ids { + obv.push(*old_map.get(id).unwrap_or(&f64::NAN)); + } + + if let Some(start) = obv.iter().position(|x| x.is_nan()) { + if start == 0 { + obv[0] = signed_vol(&c.bars_raw[0]); + for i in 1..now_len { + obv[i] = obv[i - 1] + signed_vol(&c.bars_raw[i]); + } + } else { + for i in start..now_len { + obv[i] = obv[i - 1] + signed_vol(&c.bars_raw[i]); + } + } + } + + if now_len == 1 { + obv[0] = signed_vol(&c.bars_raw[0]); + } else { + // 对齐 Python 多频流式语义:最后一根未完成 bar 的 open/close/vol 会持续更新, + // 即使 bar id 不变,也必须基于前一根已确认 bar 的 OBV 重新计算当前值。 + obv[now_len - 1] = obv[now_len - 2] + signed_vol(&c.bars_raw[now_len - 1]); + } + + cache.series.insert(cache_key.to_string(), obv); + cache.series_ids.insert(cache_key.to_string(), bar_ids); + cache.last_len = now_len; +} + +/// 对齐 TA-Lib `EMA`:前 `n-1` 个位置为 NaN,第 `n` 个位置用前 `n` 个值的 SMA 初始化。 +fn calc_ema_talib_style(series: &[f64], n: usize) -> Vec { + let mut out = vec![f64::NAN; series.len()]; + if n == 0 || series.len() < n { + return out; + } + let alpha = 2.0 / (n as f64 + 1.0); + let seed = series[..n].iter().sum::() / n as f64; + out[n - 1] = seed; + for i in n..series.len() { + out[i] = alpha * series[i] + (1.0 - alpha) * out[i - 1]; + } + out +} + +/// 对含“前导 NaN”的序列计算 TA-Lib 风格 EMA。 +fn calc_ema_talib_skip_leading_nan(series: &[f64], n: usize) -> Vec { + let mut out = vec![f64::NAN; series.len()]; + let Some(start) = series.iter().position(|x| x.is_finite()) else { + return out; + }; + let tail = &series[start..]; + let ema_tail = calc_ema_talib_style(tail, n); + for (i, v) in ema_tail.iter().enumerate() { + out[start + i] = *v; + } + out +} + +/// obvm_line_V230610:OBV 双 EMA 能量信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}M{m}_OBV能量V230610"` +/// +/// 信号逻辑: +/// 1. 计算 OBV 累计量序列(阳线加量、阴线减量); +/// 2. 分别计算 OBV 的短期 EMA(`n`) 与长期 EMA(`m`); +/// 3. 短期 EMA 大于长期 EMA 判 `看多`,否则判 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N10M30_OBV能量V230610_看多_任意_任意_0')` +/// - `Signal('60分钟_D1N10M30_OBV能量V230610_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `n`:短期 EMA 周期,默认 `10`; +/// - `m`:长期 EMA 周期,默认 `30`。 +/// 对齐说明:OBV 构造方式与 Python `obvm_line_V230610` 一致(按 K 线涨跌符号加减成交量)。 +#[signal( + category = "kline", + name = "obvm_line_V230610", + template = "{freq}_D{di}N{n}M{m}_OBV能量V230610", + opcode = "ObvmLineV230610", + param_kind = "ObvmLineV230610" +)] +pub fn obvm_line_v230610(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 10); + let m = params.usize("m", 30); + + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}M{}", di, n, m); + let k3 = "OBV能量V230610"; + let v1_default = "其他"; + + if c.bars_raw.len() < di + n.max(m) + 10 { + return make_kline_signal_v1(&k1, &k2, k3, v1_default); + } + + update_obv_cache(c, cache); + let Some(obv_series) = cache.series.get("OBV") else { + return make_kline_signal_v1(&k1, &k2, k3, v1_default); + }; + let id_map = bar_index_map(c); + let bars = get_sub_elements(&c.bars_raw, di, n.max(m) + 10); + let mut obv_seq = Vec::with_capacity(bars.len()); + for b in bars { + if let Some(i) = id_map.get(&b.id).copied() { + if i < obv_series.len() { + obv_seq.push(obv_series[i]); + } + } + } + if obv_seq.len() < n.max(m) { + return make_kline_signal_v1(&k1, &k2, k3, v1_default); + } + + let ema_n = calc_ema_talib_style(&obv_seq, n); + let ema_m = calc_ema_talib_style(&obv_seq, m); + let e1 = *ema_n.last().unwrap_or(&f64::NAN); + let e2 = *ema_m.last().unwrap_or(&f64::NAN); + if !e1.is_finite() || !e2.is_finite() { + return make_kline_signal_v1(&k1, &k2, k3, v1_default); + } + let v1 = if e1 > e2 { "看多" } else { "看空" }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// obv_up_dw_line_V230719:OBV 交叉信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}M{m}MO{max_overlap}_OBV能量V230719"` +/// +/// 信号逻辑: +/// 1. 先计算 OBV 累计量序列; +/// 2. 计算 `obvm = EMA(OBV, n)`,再计算 `sig = EMA(obvm, m)`; +/// 3. 若当前 `obvm > sig` 且 `max_overlap` 根前 `obvm < sig`,判 `看多`; +/// 4. 若当前 `obvm < sig` 且 `max_overlap` 根前 `obvm > sig`,判 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N7M10MO3_OBV能量V230719_看多_任意_任意_0')` +/// - `Signal('60分钟_D1N7M10MO3_OBV能量V230719_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `n`:OBVM EMA 周期,默认 `7`; +/// - `m`:信号线 EMA 周期,默认 `10`; +/// - `max_overlap`:交叉回看根数,默认 `3`。 +/// 对齐说明:交叉判定时点与 Python `obv_up_dw_line_V230719` 完全一致(使用 `-max_overlap`)。 +#[signal( + category = "kline", + name = "obv_up_dw_line_V230719", + template = "{freq}_D{di}N{n}M{m}MO{max_overlap}_OBV能量V230719", + opcode = "ObvUpDwLineV230719", + param_kind = "ObvUpDwLineV230719" +)] +pub fn obv_up_dw_line_v230719( + c: &CZSC, + params: &ParamView, + cache: &mut TaCache, +) -> Vec { + let di = params.usize("di", 1); + let n = params.usize("n", 7); + let m = params.usize("m", 10); + let max_overlap = params.usize("max_overlap", 3); + + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}M{}MO{}", di, n, m, max_overlap); + let k3 = "OBV能量V230719"; + let v1_default = "其他"; + + let min_k_num = di + n.max(m) + max_overlap + 10; + if c.bars_raw.len() < min_k_num { + return make_kline_signal_v1(&k1, &k2, k3, v1_default); + } + + update_obv_cache(c, cache); + let Some(obv_series) = cache.series.get("OBV") else { + return make_kline_signal_v1(&k1, &k2, k3, v1_default); + }; + let id_map = bar_index_map(c); + let bars = get_sub_elements(&c.bars_raw, di, min_k_num); + let mut obv_seq = Vec::with_capacity(bars.len()); + for b in bars { + if let Some(i) = id_map.get(&b.id).copied() { + if i < obv_series.len() { + obv_seq.push(obv_series[i]); + } + } + } + if obv_seq.len() < min_k_num { + return make_kline_signal_v1(&k1, &k2, k3, v1_default); + } + + let obvm = calc_ema_talib_style(&obv_seq, n); + let sig = calc_ema_talib_skip_leading_nan(&obvm, m); + let l = obvm.len(); + if l <= max_overlap { + return make_kline_signal_v1(&k1, &k2, k3, v1_default); + } + let obvm_last = obvm[l - 1]; + let sig_last = sig[l - 1]; + let obvm_ref = obvm[l - max_overlap]; + let sig_ref = sig[l - max_overlap]; + if !obvm_last.is_finite() + || !sig_last.is_finite() + || !obvm_ref.is_finite() + || !sig_ref.is_finite() + { + return make_kline_signal_v1(&k1, &k2, k3, v1_default); + } + + let v1 = if obvm_last > sig_last && obvm_ref < sig_ref { + "看多" + } else if obvm_last < sig_last && obvm_ref > sig_ref { + "看空" + } else { + v1_default + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} diff --git a/crates/czsc-signals/src/params.rs b/crates/czsc-signals/src/params.rs new file mode 100644 index 000000000..8a25103fc --- /dev/null +++ b/crates/czsc-signals/src/params.rs @@ -0,0 +1,63 @@ +use serde_json::Value; +use std::collections::HashMap; + +/// 统一参数只读视图,避免为每个信号维护独立 params 结构体。 +#[derive(Debug, Clone, Copy)] +pub struct ParamView<'a> { + inner: &'a HashMap, +} + +impl<'a> ParamView<'a> { + #[inline] + pub fn new(inner: &'a HashMap) -> Self { + Self { inner } + } + + #[inline] + pub fn usize(&self, key: &str, default: usize) -> usize { + if let Some(val) = self.inner.get(key) { + if let Some(n) = val.as_u64() { + return n as usize; + } + if let Some(s) = val.as_str() { + if let Ok(n) = s.parse::() { + return n; + } + } + } + default + } + + #[inline] + pub fn str<'b>(&'b self, key: &str, default: &'b str) -> &'b str { + if let Some(val) = self.inner.get(key) { + if let Some(s) = val.as_str() { + return s; + } + } + default + } + + #[inline] + pub fn bool(&self, key: &str, default: bool) -> bool { + if let Some(val) = self.inner.get(key) { + if let Some(v) = val.as_bool() { + return v; + } + if let Some(s) = val.as_str() { + if s.eq_ignore_ascii_case("true") || s == "1" { + return true; + } + if s.eq_ignore_ascii_case("false") || s == "0" { + return false; + } + } + } + default + } + + #[inline] + pub fn value(&self, key: &str) -> Option<&Value> { + self.inner.get(key) + } +} diff --git a/crates/czsc-signals/src/pos.rs b/crates/czsc-signals/src/pos.rs new file mode 100644 index 000000000..144350100 --- /dev/null +++ b/crates/czsc-signals/src/pos.rs @@ -0,0 +1,1202 @@ +use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::sig::{ + get_str_param, get_sub_elements, get_usize_param, last_open_operate, latest_price, make_signal, + make_signal_v1, +}; +use crate::utils::ta::{update_atr_cache, update_ma_cache}; +use czsc_signal_macros::signal; +use czsc_core::objects::mark::Mark; +use czsc_core::objects::operate::Operate; +use czsc_core::objects::signal::Signal; +use czsc_core::objects::state::TraderState; + +/// pos_ma_V230414:判断开仓后是否升破/跌破均线 +/// +/// 参数模板:`"{pos_name}_{freq1}#{ma_type}#{timeperiod}_持有状态V230414"` +/// +/// 信号逻辑: +/// - 多头持仓:开仓后任一 bar 出现 `close > MA`,记 `多头_升破均线`; +/// - 空头持仓:开仓后任一 bar 出现 `close < MA`,记 `空头_跌破均线`; +/// - 其余场景返回 `其他_其他`。 +/// +/// 信号列表示例: +/// - `Signal('日线三买多头N1_60分钟#SMA#5_持有状态V230414_多头_升破均线_任意_0')` +/// - `Signal('日线三买多头N1_60分钟#SMA#5_持有状态V230414_空头_跌破均线_任意_0')` +/// +/// 参数说明: +/// - `pos_name`:仓位名称; +/// - `freq1`:K线周期; +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod`:均线周期参数,默认 `5`。 +#[signal( + category = "trader", + name = "pos_ma_V230414", + template = "{pos_name}_{freq1}#{ma_type}#{timeperiod}_持有状态V230414", + opcode = "PosMaV230414", + param_kind = "PosMaV230414" +)] +pub fn pos_ma_v230414(cat: &dyn TraderState, params: &ParamView) -> Vec { + let pos_name = get_str_param(params, "pos_name", ""); + let freq1 = get_str_param(params, "freq1", ""); + let ma_type = get_str_param(params, "ma_type", "SMA").to_uppercase(); + let timeperiod = get_usize_param(params, "timeperiod", 5); + + let k1 = pos_name.to_string(); + let k2 = format!("{}#{}#{}", freq1, ma_type, timeperiod); + let k3 = "持有状态V230414"; + let mut v1 = "其他"; + let mut v2 = "其他"; + + let op = match last_open_operate(cat, pos_name) { + Some(op) => op, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let c = match cat.get_czsc(freq1) { + Some(c) => c, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + + let cache_key = format!("{}#{}", ma_type, timeperiod); + let mut cache = TaCache::new(); + update_ma_cache(c, &cache_key, &ma_type, timeperiod, &mut cache); + let ma = match cache.series.get(&cache_key) { + Some(ma) => ma, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + + let start = c.bars_raw.len().saturating_sub(100); + let bars = &c.bars_raw[start..]; + for bar in bars.iter().filter(|x| x.dt > op.dt) { + let idx = match c.bars_raw.iter().position(|b| b.id == bar.id) { + Some(i) => i, + None => continue, + }; + let ma_v = ma.get(idx).copied().unwrap_or(f64::NAN); + if op.op == Operate::LO && bar.close > ma_v { + v1 = "多头"; + v2 = "升破均线"; + break; + } + if op.op == Operate::SO && bar.close < ma_v { + v1 = "空头"; + v2 = "跌破均线"; + break; + } + } + + make_signal(&k1, &k2, k3, v1, v2) +} + +/// pos_fx_stop_V230414:按开仓点附近分型止损 +/// +/// 参数模板:`"{freq1}_{pos_name}N{n}_止损V230414"` +/// +/// 信号逻辑: +/// - 多头:取开仓前最近 `n` 个底分型,最新价跌破最低分型低点,记 `多头止损`; +/// - 空头:取开仓前最近 `n` 个顶分型,最新价突破最高分型高点,记 `空头止损`; +/// - 其余场景记 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('日线_日线三买多头N1_止损V230414_多头止损_任意_任意_0')` +/// - `Signal('日线_日线三买多头N1_止损V230414_空头止损_任意_任意_0')` +/// +/// 参数说明: +/// - `pos_name`:仓位名称; +/// - `freq1`:K线周期; +/// - `n`:向前取分型数量,默认 `3`。 +#[signal( + category = "trader", + name = "pos_fx_stop_V230414", + template = "{freq1}_{pos_name}N{n}_止损V230414", + opcode = "PosFxStopV230414", + param_kind = "PosFxStop" +)] +pub fn pos_fx_stop_v230414(cat: &dyn TraderState, params: &ParamView) -> Vec { + let pos_name = get_str_param(params, "pos_name", ""); + let freq1 = get_str_param(params, "freq1", ""); + let n = get_usize_param(params, "n", 3); + let k1 = freq1.to_string(); + let k2 = format!("{}N{}", pos_name, n); + let k3 = "止损V230414"; + let mut v1 = "其他"; + + let op = match last_open_operate(cat, pos_name) { + Some(op) => op, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let c = match cat.get_czsc(freq1) { + Some(c) => c, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let latest = latest_price(cat, freq1).unwrap_or(0.0); + let fxs = c.get_fx_list(); + + if op.op == Operate::LO { + let all: Vec<_> = fxs + .iter() + .filter(|x| x.mark == Mark::D && x.dt < op.dt) + .collect(); + let start = all.len().saturating_sub(n); + if !all.is_empty() { + let ll = all[start..] + .iter() + .fold(f64::INFINITY, |acc, x| acc.min(x.low)); + if latest < ll { + v1 = "多头止损"; + } + } + } + if op.op == Operate::SO { + let all: Vec<_> = fxs + .iter() + .filter(|x| x.mark == Mark::G && x.dt < op.dt) + .collect(); + let start = all.len().saturating_sub(n); + if !all.is_empty() { + let hh = all[start..] + .iter() + .fold(f64::NEG_INFINITY, |acc, x| acc.max(x.high)); + if latest > hh { + v1 = "空头止损"; + } + } + } + + make_signal_v1(&k1, &k2, k3, v1) +} + +/// pos_bar_stop_V230524:按开仓点附近N根K线极值止损 +/// +/// 参数模板:`"{pos_name}_{freq1}N{n}K_止损V230524"` +/// +/// 信号逻辑: +/// - 多头:开仓前最近 `n` 根K线最低价被最新价跌破,记 `多头止损`; +/// - 空头:开仓前最近 `n` 根K线最高价被最新价突破,记 `空头止损`。 +/// +/// 信号列表示例: +/// - `Signal('日线三买多头_日线N3K_止损V230524_多头止损_任意_任意_0')` +/// - `Signal('日线三买多头_日线N3K_止损V230524_空头止损_任意_任意_0')` +/// +/// 参数说明: +/// - `pos_name`:仓位名称; +/// - `freq1`:K线周期; +/// - `n`:向前取K线数量,默认 `3`,有效范围 `[1, 20]`。 +#[signal( + category = "trader", + name = "pos_bar_stop_V230524", + template = "{pos_name}_{freq1}N{n}K_止损V230524", + opcode = "PosBarStopV230524", + param_kind = "PosBarStopV230524" +)] +pub fn pos_bar_stop_v230524(cat: &dyn TraderState, params: &ParamView) -> Vec { + let pos_name = get_str_param(params, "pos_name", ""); + let freq1 = get_str_param(params, "freq1", ""); + let n = get_usize_param(params, "n", 3).clamp(1, 20); + let k1 = pos_name.to_string(); + let k2 = format!("{}N{}K", freq1, n); + let k3 = "止损V230524"; + let mut v1 = "其他"; + + let op = match last_open_operate(cat, pos_name) { + Some(op) => op, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let c = match cat.get_czsc(freq1) { + Some(c) => c, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let latest = latest_price(cat, freq1).unwrap_or(0.0); + let start = c.bars_raw.len().saturating_sub(100); + let mut bars: Vec<_> = c.bars_raw[start..] + .iter() + .filter(|x| x.dt < op.dt) + .collect(); + if bars.len() > n { + bars = bars[bars.len() - n..].to_vec(); + } + + if !bars.is_empty() && op.op == Operate::LO { + let ll = bars.iter().fold(f64::INFINITY, |acc, x| acc.min(x.low)); + if latest < ll { + v1 = "多头止损"; + } + } + if !bars.is_empty() && op.op == Operate::SO { + let hh = bars + .iter() + .fold(f64::NEG_INFINITY, |acc, x| acc.max(x.high)); + if latest > hh { + v1 = "空头止损"; + } + } + + make_signal_v1(&k1, &k2, k3, v1) +} + +/// pos_holds_V230414:开仓后 N 根K线收益与阈值比较 +/// +/// 参数模板:`"{pos_name}_{freq1}N{n}M{m}_趋势判断V230414"` +/// +/// 信号逻辑: +/// - 多头:`zdf=(最新收盘-开仓价)/开仓价*10000`,`zdf Vec { + let pos_name = get_str_param(params, "pos_name", ""); + let freq1 = get_str_param(params, "freq1", ""); + let n = get_usize_param(params, "n", 5); + let m = get_usize_param(params, "m", 100); + let k1 = pos_name.to_string(); + let k2 = format!("{}N{}M{}", freq1, n, m); + let k3 = "趋势判断V230414"; + let mut v1 = "其他"; + + let op = match last_open_operate(cat, pos_name) { + Some(op) => op, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let c = match cat.get_czsc(freq1) { + Some(c) => c, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let start = c.bars_raw.len().saturating_sub(100); + let bars: Vec<_> = c.bars_raw[start..] + .iter() + .filter(|x| x.dt > op.dt) + .collect(); + if bars.len() < n { + return make_signal_v1(&k1, &k2, k3, v1); + } + let last_close = bars.last().map(|x| x.close).unwrap_or(op.price); + if op.op == Operate::LO { + let zdf = (last_close - op.price) / op.price * 10000.0; + v1 = if zdf < m as f64 { + "多头存疑" + } else { + "多头良好" + }; + } + if op.op == Operate::SO { + let zdf = (op.price - last_close) / op.price * 10000.0; + v1 = if zdf < m as f64 { + "空头存疑" + } else { + "空头良好" + }; + } + make_signal_v1(&k1, &k2, k3, v1) +} + +/// pos_fix_exit_V230624:固定 BP 止盈止损 +/// +/// 参数模板:`"{pos_name}_固定{th}BP止盈止损_出场V230624"` +/// +/// 信号逻辑: +/// - 多头:现价低于 `开仓价*(1-th/10000)` 为 `多头止损`,高于 `开仓价*(1+th/10000)` 为 `多头止盈`; +/// - 空头:规则镜像为 `空头止损/空头止盈`。 +/// +/// 信号列表示例: +/// - `Signal('日线三买多头_固定100BP止盈止损_出场V230624_多头止损_任意_任意_0')` +/// - `Signal('日线三买多头_固定100BP止盈止损_出场V230624_空头止盈_任意_任意_0')` +/// +/// 参数说明: +/// - `pos_name`:仓位名称; +/// - `th`:止盈止损阈值(BP),默认 `300`。 +#[signal( + category = "trader", + name = "pos_fix_exit_V230624", + template = "{pos_name}_固定{th}BP止盈止损_出场V230624", + opcode = "PosFixExitV230624", + param_kind = "PosFixExitV230624" +)] +pub fn pos_fix_exit_v230624(cat: &dyn TraderState, params: &ParamView) -> Vec { + let pos_name = get_str_param(params, "pos_name", ""); + let th = get_usize_param(params, "th", 300); + let k1 = pos_name.to_string(); + let k2 = format!("固定{}BP止盈止损", th); + let k3 = "出场V230624"; + let mut v1 = "其他"; + + let op = match last_open_operate(cat, pos_name) { + Some(op) => op, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let freq1 = get_str_param(params, "freq1", ""); + let lp = latest_price(cat, freq1).unwrap_or(op.price); + if op.op == Operate::LO { + if lp < op.price * (1.0 - th as f64 / 10000.0) { + v1 = "多头止损"; + } + if lp > op.price * (1.0 + th as f64 / 10000.0) { + v1 = "多头止盈"; + } + } + if op.op == Operate::SO { + if lp > op.price * (1.0 + th as f64 / 10000.0) { + v1 = "空头止损"; + } + if lp < op.price * (1.0 - th as f64 / 10000.0) { + v1 = "空头止盈"; + } + } + make_signal_v1(&k1, &k2, k3, v1) +} + +/// pos_profit_loss_V230624:盈亏比阈值判断 +/// +/// 参数模板:`"{pos_name}_{freq1}YKB{ykb}N{n}_盈亏比判断V230624"` +/// +/// 信号逻辑: +/// - 基于开仓前 `n` 个分型确定止损价,计算 `ykb = (现价-开仓价)/(开仓价-止损价)*10`; +/// - `ykb > 阈值` 记 `多头达标/空头达标`; +/// - 未达标时若击穿止损价,记 `多头止损/空头止损`。 +/// +/// 信号列表示例: +/// - `Signal('日线通道突破_60分钟YKB20N3_盈亏比判断V230624_多头达标_任意_任意_0')` +/// - `Signal('日线通道突破_60分钟YKB20N3_盈亏比判断V230624_空头止损_任意_任意_0')` +/// +/// 参数说明: +/// - `pos_name`:仓位名称; +/// - `freq1`:K线周期; +/// - `ykb`:盈亏比阈值(×10),默认 `20`; +/// - `n`:止损分型窗口,默认 `3`。 +#[signal( + category = "trader", + name = "pos_profit_loss_V230624", + template = "{pos_name}_{freq1}YKB{ykb}N{n}_盈亏比判断V230624", + opcode = "PosProfitLossV230624", + param_kind = "PosProfitLossV230624" +)] +pub fn pos_profit_loss_v230624( + cat: &dyn TraderState, + params: &ParamView, +) -> Vec { + let pos_name = get_str_param(params, "pos_name", ""); + let freq1 = get_str_param(params, "freq1", ""); + let ykb = get_usize_param(params, "ykb", 20); + let n = get_usize_param(params, "n", 3); + let k1 = pos_name.to_string(); + let k2 = format!("{}YKB{}N{}", freq1, ykb, n); + let k3 = "盈亏比判断V230624"; + let mut v1 = "其他"; + + let op = match last_open_operate(cat, pos_name) { + Some(op) => op, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let c = match cat.get_czsc(freq1) { + Some(c) => c, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let last_close = c.bars_raw.last().map(|b| b.close).unwrap_or(op.price); + let fxs = c.get_fx_list(); + + if op.op == Operate::LO { + let all: Vec<_> = fxs + .iter() + .filter(|x| x.mark == Mark::D && x.dt < op.dt) + .collect(); + let start = all.len().saturating_sub(n); + if !all.is_empty() { + let stop_price = all[start..] + .iter() + .fold(f64::INFINITY, |acc, x| acc.min(x.low)); + let denom = op.price - stop_price; + if denom != 0.0 { + let y = ((last_close - op.price) / denom) * 10.0; + if y > ykb as f64 { + v1 = "多头达标"; + } else if last_close < stop_price { + v1 = "多头止损"; + } + } + } + } + if op.op == Operate::SO { + let all: Vec<_> = fxs + .iter() + .filter(|x| x.mark == Mark::G && x.dt < op.dt) + .collect(); + let start = all.len().saturating_sub(n); + if !all.is_empty() { + let stop_price = all[start..] + .iter() + .fold(f64::NEG_INFINITY, |acc, x| acc.max(x.high)); + let denom = op.price - stop_price; + if denom != 0.0 { + let y = ((last_close - op.price) / denom) * 10.0; + if y > ykb as f64 { + v1 = "空头达标"; + } else if last_close > stop_price { + v1 = "空头止损"; + } + } + } + } + + make_signal_v1(&k1, &k2, k3, v1) +} + +/// pos_status_V230808:持仓状态 +/// +/// 参数模板:`"{pos_name}_持仓状态_BS辅助V230808"` +/// +/// 信号逻辑: +/// - 最近操作为 `LO` 输出 `持多`; +/// - 最近操作为 `SO` 输出 `持空`; +/// - 其余输出 `持币`。 +/// +/// 信号列表示例: +/// - `Signal('日线三买多头N1_持仓状态_BS辅助V230808_持多_任意_任意_0')` +/// - `Signal('日线三买多头N1_持仓状态_BS辅助V230808_持币_任意_任意_0')` +/// +/// 参数说明: +/// - `pos_name`:仓位名称。 +#[signal( + category = "trader", + name = "pos_status_V230808", + template = "{pos_name}_持仓状态_BS辅助V230808", + opcode = "PosStatusV230808", + param_kind = "PosStatus" +)] +pub fn pos_status_v230808(cat: &dyn TraderState, params: &ParamView) -> Vec { + let pos_name = get_str_param(params, "pos_name", ""); + let k1 = pos_name.to_string(); + let k2 = "持仓状态"; + let k3 = "BS辅助V230808"; + let v1 = match cat.get_position(pos_name).and_then(|p| p.operates.last()) { + Some(op) if op.op == Operate::LO => "持多", + Some(op) if op.op == Operate::SO => "持空", + _ => "持币", + }; + make_signal_v1(&k1, k2, k3, v1) +} + +/// pos_holds_V230807:开仓后收益在 (t, m) 之间触发保本 +/// +/// 参数模板:`"{pos_name}_{freq1}N{n}M{m}T{t}_BS辅助V230807"` +/// +/// 信号逻辑: +/// - 当开仓后收益落在 `(t, m)` 区间时触发 `多头保本/空头保本`; +/// - 含义是“达到最低保本收益但未达到趋势确认阈值”,优先保本离场。 +/// +/// 信号列表示例: +/// - `Signal('日线三买多头N1_60分钟N5M50T10_BS辅助V230807_多头保本_任意_任意_0')` +/// - `Signal('日线三买多头N1_60分钟N5M50T10_BS辅助V230807_空头保本_任意_任意_0')` +/// +/// 参数说明: +/// - `pos_name`:仓位名称; +/// - `freq1`:K线周期; +/// - `n`:最少持有K线数量,默认 `5`; +/// - `m`:收益上限阈值(BP),默认 `50`; +/// - `t`:保本收益阈值(BP),默认 `10`,且要求 `m > t > 0`。 +#[signal( + category = "trader", + name = "pos_holds_V230807", + template = "{pos_name}_{freq1}N{n}M{m}T{t}_BS辅助V230807", + opcode = "PosHoldsV230807", + param_kind = "PosHoldsV230807" +)] +pub fn pos_holds_v230807(cat: &dyn TraderState, params: &ParamView) -> Vec { + let pos_name = get_str_param(params, "pos_name", ""); + let freq1 = get_str_param(params, "freq1", ""); + let n = get_usize_param(params, "n", 5); + let m = get_usize_param(params, "m", 50); + let t = get_usize_param(params, "t", 10); + let k1 = pos_name.to_string(); + let k2 = format!("{}N{}M{}T{}", freq1, n, m, t); + let k3 = "BS辅助V230807"; + let mut v1 = "其他"; + if m <= t || t == 0 { + return make_signal_v1(&k1, &k2, k3, v1); + } + let op = match last_open_operate(cat, pos_name) { + Some(op) => op, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let c = match cat.get_czsc(freq1) { + Some(c) => c, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let start = c.bars_raw.len().saturating_sub(100); + let bars: Vec<_> = c.bars_raw[start..] + .iter() + .filter(|x| x.dt > op.dt) + .collect(); + if bars.len() < n { + return make_signal_v1(&k1, &k2, k3, v1); + } + let last_close = bars.last().map(|x| x.close).unwrap_or(op.price); + if op.op == Operate::LO { + let zdf = (last_close - op.price) / op.price * 10000.0; + if zdf > t as f64 && zdf < m as f64 { + v1 = "多头保本"; + } + } + if op.op == Operate::SO { + let zdf = (op.price - last_close) / op.price * 10000.0; + if zdf > t as f64 && zdf < m as f64 { + v1 = "空头保本"; + } + } + make_signal_v1(&k1, &k2, k3, v1) +} + +/// pos_holds_V240428:最大盈利回撤比例保本 +/// +/// 参数模板:`"{pos_name}_{freq1}H{h}T{t}N{n}_保本V240428"` +/// +/// 信号逻辑: +/// - 多头:最大盈利 `y1` 超过 `h` 且当前盈利 `y2 < y1*t/100`,记 `多头保本`; +/// - 空头:按镜像规则记 `空头保本`。 +/// +/// 信号列表示例: +/// - `Signal('日线三买多头N1_60分钟H100T20N5_保本V240428_多头保本_任意_任意_0')` +/// - `Signal('日线三买多头N1_60分钟H100T20N5_保本V240428_空头保本_任意_任意_0')` +/// +/// 参数说明: +/// - `pos_name`:仓位名称; +/// - `freq1`:K线周期; +/// - `h`:最大盈利阈值(BP),默认 `100`; +/// - `t`:回撤比例阈值(%),默认 `20`; +/// - `n`:最少持有K线数量,默认 `5`。 +#[signal( + category = "trader", + name = "pos_holds_V240428", + template = "{pos_name}_{freq1}H{h}T{t}N{n}_保本V240428", + opcode = "PosHoldsV240428", + param_kind = "PosHoldsV240428" +)] +pub fn pos_holds_v240428(cat: &dyn TraderState, params: &ParamView) -> Vec { + let pos_name = get_str_param(params, "pos_name", ""); + let freq1 = get_str_param(params, "freq1", ""); + let h = get_usize_param(params, "h", 100); + let t = get_usize_param(params, "t", 20); + let n = get_usize_param(params, "n", 5); + let k1 = pos_name.to_string(); + let k2 = format!("{}H{}T{}N{}", freq1, h, t, n); + let k3 = "保本V240428"; + let mut v1 = "其他"; + + let op = match last_open_operate(cat, pos_name) { + Some(op) => op, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let c = match cat.get_czsc(freq1) { + Some(c) => c, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let start = c.bars_raw.len().saturating_sub(100); + let bars: Vec<_> = c.bars_raw[start..] + .iter() + .filter(|x| x.dt > op.dt) + .collect(); + if bars.len() < n { + return make_signal_v1(&k1, &k2, k3, v1); + } + let last_close = bars.last().map(|x| x.close).unwrap_or(op.price); + + if op.op == Operate::LO { + let max_close = bars + .iter() + .fold(f64::NEG_INFINITY, |acc, x| acc.max(x.close)); + let y1 = (max_close - op.price) / op.price * 10000.0; + let y2 = (last_close - op.price) / op.price * 10000.0; + if y1 > h as f64 && y2 < y1 * t as f64 / 100.0 { + v1 = "多头保本"; + } + } + if op.op == Operate::SO { + let min_close = bars.iter().fold(f64::INFINITY, |acc, x| acc.min(x.close)); + let y1 = (op.price - min_close) / op.price * 10000.0; + let y2 = (op.price - last_close) / op.price * 10000.0; + if y1 > h as f64 && y2 < y1 * t as f64 / 100.0 { + v1 = "空头保本"; + } + } + + make_signal_v1(&k1, &k2, k3, v1) +} + +/// pos_holds_V240608:跌破/升破开仓前窗口极值后,回到成本价指定档位保本 +/// +/// 参数模板:`"{pos_name}_{freq1}W{w}N{n}_保本V240608"` +/// +/// 信号逻辑: +/// - 多头:若开仓后最低价跌破开仓前 `w` 根最低价,且现价回到成本价上方第 `n` 档,记 `多头保本`; +/// - 空头:若开仓后最高价突破开仓前 `w` 根最高价,且现价回到成本价下方第 `n` 档,记 `空头保本`。 +/// +/// 信号列表示例: +/// - `Signal('日线三买多头N1_60分钟W20N2_保本V240608_多头保本_任意_任意_0')` +/// - `Signal('日线三买多头N1_60分钟W20N2_保本V240608_空头保本_任意_任意_0')` +/// +/// 参数说明: +/// - `pos_name`:仓位名称; +/// - `freq1`:K线周期; +/// - `w`:开仓前观察窗口,默认 `20`; +/// - `n`:成本价上下档位偏移,默认 `2`。 +#[signal( + category = "trader", + name = "pos_holds_V240608", + template = "{pos_name}_{freq1}W{w}N{n}_保本V240608", + opcode = "PosHoldsV240608", + param_kind = "PosHoldsV240608" +)] +pub fn pos_holds_v240608(cat: &dyn TraderState, params: &ParamView) -> Vec { + let pos_name = get_str_param(params, "pos_name", ""); + let freq1 = get_str_param(params, "freq1", ""); + let w = get_usize_param(params, "w", 20); + let n = get_usize_param(params, "n", 2); + let k1 = pos_name.to_string(); + let k2 = format!("{}W{}N{}", freq1, w, n); + let k3 = "保本V240608"; + let mut v1 = "其他"; + + let op = match last_open_operate(cat, pos_name) { + Some(op) => op, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let c = match cat.get_czsc(freq1) { + Some(c) => c, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + + let s200 = c.bars_raw.len().saturating_sub(200); + let w_bars_all: Vec<_> = c.bars_raw[s200..] + .iter() + .filter(|x| x.dt <= op.dt) + .collect(); + let w_bars = if w_bars_all.len() > w { + &w_bars_all[w_bars_all.len() - w..] + } else { + &w_bars_all[..] + }; + let s100 = c.bars_raw.len().saturating_sub(100); + let a_bars: Vec<_> = c.bars_raw[s100..].iter().filter(|x| x.dt > op.dt).collect(); + if w_bars.is_empty() || a_bars.is_empty() { + return make_signal_v1(&k1, &k2, k3, v1); + } + + let mut unique_prices: Vec = c.bars_raw[s200..] + .iter() + .flat_map(|x| [x.high, x.low, x.close, x.open]) + .collect(); + unique_prices.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + unique_prices.dedup_by(|a, b| (*a - *b).abs() < f64::EPSILON); + let lp = latest_price(cat, freq1).unwrap_or(op.price); + + if op.op == Operate::LO { + let w_low = w_bars.iter().fold(f64::INFINITY, |acc, x| acc.min(x.low)); + let a_low = a_bars.iter().fold(f64::INFINITY, |acc, x| acc.min(x.low)); + let up_prices: Vec = unique_prices + .iter() + .copied() + .filter(|x| *x > op.price) + .collect(); + if up_prices.len() > n && a_low < w_low && lp > up_prices[n] { + v1 = "多头保本"; + } + } + if op.op == Operate::SO { + let w_high = w_bars + .iter() + .fold(f64::NEG_INFINITY, |acc, x| acc.max(x.high)); + let a_high = a_bars + .iter() + .fold(f64::NEG_INFINITY, |acc, x| acc.max(x.high)); + let down_prices: Vec = unique_prices + .iter() + .copied() + .filter(|x| *x < op.price) + .collect(); + if down_prices.len() > n && a_high > w_high && lp < down_prices[down_prices.len() - n] { + v1 = "空头保本"; + } + } + make_signal_v1(&k1, &k2, k3, v1) +} + +/// pos_stop_V240428:按开仓前离散价位跳数止损 +/// +/// 参数模板:`"{pos_name}_{freq1}T{t}N{n}_止损V240428"` +/// +/// 信号逻辑: +/// - 使用开仓前历史K线提取离散价位; +/// - 多头取低于开仓价的第 `t` 档止损位,空头取高于开仓价的第 `t` 档止损位; +/// - 开仓后至少持有 `n` 根后,收盘穿越止损位触发 `多头止损/空头止损`。 +/// +/// 信号列表示例: +/// - `Signal('日线三买多头N1_60分钟T20N5_止损V240428_多头止损_任意_任意_0')` +/// - `Signal('日线三买多头N1_60分钟T20N5_止损V240428_空头止损_任意_任意_0')` +/// +/// 参数说明: +/// - `pos_name`:仓位名称; +/// - `freq1`:K线周期; +/// - `t`:离散价位档位,默认 `20`; +/// - `n`:最少持有K线数量,默认 `5`。 +#[signal( + category = "trader", + name = "pos_stop_V240428", + template = "{pos_name}_{freq1}T{t}N{n}_止损V240428", + opcode = "PosStopV240428", + param_kind = "PosStopV240428" +)] +pub fn pos_stop_v240428(cat: &dyn TraderState, params: &ParamView) -> Vec { + let pos_name = get_str_param(params, "pos_name", ""); + let freq1 = get_str_param(params, "freq1", ""); + let t = get_usize_param(params, "t", 20); + let n = get_usize_param(params, "n", 5); + let k1 = pos_name.to_string(); + let k2 = format!("{}T{}N{}", freq1, t, n); + let k3 = "止损V240428"; + let mut v1 = "其他"; + + let op = match last_open_operate(cat, pos_name) { + Some(op) => op, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let c = match cat.get_czsc(freq1) { + Some(c) => c, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let s100 = c.bars_raw.len().saturating_sub(100); + let right_bars: Vec<_> = c.bars_raw[s100..].iter().filter(|x| x.dt > op.dt).collect(); + if right_bars.len() < n { + return make_signal_v1(&k1, &k2, k3, v1); + } + let left_bars: Vec<_> = c.bars_raw.iter().filter(|x| x.dt < op.dt).collect(); + let mut unique_prices: Vec = left_bars + .iter() + .flat_map(|x| [x.high, x.low, x.close, x.open]) + .collect(); + unique_prices.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + unique_prices.dedup_by(|a, b| (*a - *b).abs() < f64::EPSILON); + + if op.op == Operate::LO { + let mut low_prices: Vec = unique_prices + .iter() + .copied() + .filter(|x| *x < op.price) + .collect(); + low_prices.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal)); + if low_prices.is_empty() { + return make_signal_v1(&k1, &k2, k3, v1); + } + let y = if low_prices.len() > t { + low_prices[low_prices.len() - t] + } else { + low_prices[0] + }; + if right_bars.last().map(|b| b.close).unwrap_or(op.price) < y { + v1 = "多头止损"; + } + } + if op.op == Operate::SO { + let mut high_prices: Vec = unique_prices + .iter() + .copied() + .filter(|x| *x > op.price) + .collect(); + high_prices.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + if high_prices.is_empty() { + return make_signal_v1(&k1, &k2, k3, v1); + } + let y = if high_prices.len() > t { + high_prices[t] + } else { + high_prices[high_prices.len() - 1] + }; + if right_bars.last().map(|b| b.close).unwrap_or(op.price) > y { + v1 = "空头止损"; + } + } + make_signal_v1(&k1, &k2, k3, v1) +} + +/// pos_take_V240428:倍量阳/阴线计数止盈 +/// +/// 参数模板:`"{pos_name}_{freq1}T{t}N{n}_止盈V240428"` +/// +/// 信号逻辑: +/// - 多头统计开仓后“阳线且成交量 > 前一根 2 倍”的次数,达到 `t` 触发 `多头止盈`; +/// - 空头统计对应倍量阴线次数,达到 `t` 触发 `空头止盈`。 +/// +/// 信号列表示例: +/// - `Signal('日线三买多头N1_60分钟T3N5_止盈V240428_多头止盈_任意_任意_0')` +/// - `Signal('日线三买多头N1_60分钟T3N5_止盈V240428_空头止盈_任意_任意_0')` +/// +/// 参数说明: +/// - `pos_name`:仓位名称; +/// - `freq1`:K线周期; +/// - `t`:倍量K线数量阈值,默认 `3`; +/// - `n`:最少持有K线数量,默认 `5`。 +#[signal( + category = "trader", + name = "pos_take_V240428", + template = "{pos_name}_{freq1}T{t}N{n}_止盈V240428", + opcode = "PosTakeV240428", + param_kind = "PosTakeV240428" +)] +pub fn pos_take_v240428(cat: &dyn TraderState, params: &ParamView) -> Vec { + let pos_name = get_str_param(params, "pos_name", ""); + let freq1 = get_str_param(params, "freq1", ""); + let t = get_usize_param(params, "t", 3); + let n = get_usize_param(params, "n", 5); + let k1 = pos_name.to_string(); + let k2 = format!("{}T{}N{}", freq1, t, n); + let k3 = "止盈V240428"; + let mut v1 = "其他"; + + let op = match last_open_operate(cat, pos_name) { + Some(op) => op, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let c = match cat.get_czsc(freq1) { + Some(c) => c, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let s100 = c.bars_raw.len().saturating_sub(100); + let bars: Vec<_> = c.bars_raw[s100..].iter().filter(|x| x.dt > op.dt).collect(); + if bars.len() < n { + return make_signal_v1(&k1, &k2, k3, v1); + } + if op.op == Operate::LO { + let mut c1 = 0usize; + for i in 1..bars.len() { + if bars[i].close > bars[i].open && bars[i].vol > bars[i - 1].vol * 2.0 { + c1 += 1; + } + } + if c1 >= t { + v1 = "多头止盈"; + } + } + if op.op == Operate::SO { + let mut c2 = 0usize; + for i in 1..bars.len() { + if bars[i].close < bars[i].open && bars[i].vol > bars[i - 1].vol * 2.0 { + c2 += 1; + } + } + if c2 >= t { + v1 = "空头止盈"; + } + } + make_signal_v1(&k1, &k2, k3, v1) +} + +/// pos_stop_V240331:最近 N 根K线追踪止损 +/// +/// 参数模板:`"{pos_name}_{freq1}#{n}_止损V240331"` +/// +/// 信号逻辑: +/// - 多头:最新K线低点跌破前 `n` 根最低价且 bar_id 晚于开仓 bar,触发 `多头止损`; +/// - 空头:最新K线高点突破前 `n` 根最高价且 bar_id 晚于开仓 bar,触发 `空头止损`。 +/// +/// 信号列表示例: +/// - `Signal('SMA5多头_15分钟#10_止损V240331_多头止损_任意_任意_0')` +/// - `Signal('SMA5空头_15分钟#10_止损V240331_空头止损_任意_任意_0')` +/// +/// 参数说明: +/// - `pos_name`:仓位名称; +/// - `freq1`:K线周期; +/// - `n`:追踪窗口,默认 `10`。 +#[signal( + category = "trader", + name = "pos_stop_V240331", + template = "{pos_name}_{freq1}#{n}_止损V240331", + opcode = "PosStopV240331", + param_kind = "PosStopV240331" +)] +pub fn pos_stop_v240331(cat: &dyn TraderState, params: &ParamView) -> Vec { + let pos_name = get_str_param(params, "pos_name", ""); + let freq1 = get_str_param(params, "freq1", ""); + let n = get_usize_param(params, "n", 10); + let k1 = pos_name.to_string(); + let k2 = format!("{}#{}", freq1, n); + let k3 = "止损V240331"; + let mut v1 = "其他"; + + let op = match last_open_operate(cat, pos_name) { + Some(op) => op, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let c = match cat.get_czsc(freq1) { + Some(c) => c, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let bars = get_sub_elements(&c.bars_raw, 1, n + 1); + if bars.len() < n + 1 { + return make_signal_v1(&k1, &k2, k3, v1); + } + let last_bar = match bars.last() { + Some(x) => x, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + if op.op == Operate::LO { + let ll = bars[..bars.len() - 1] + .iter() + .fold(f64::INFINITY, |acc, x| acc.min(x.low)); + if last_bar.low < ll && last_bar.id > op.bar_id { + v1 = "多头止损"; + } + } + if op.op == Operate::SO { + let hh = bars[..bars.len() - 1] + .iter() + .fold(f64::NEG_INFINITY, |acc, x| acc.max(x.high)); + if last_bar.high > hh && last_bar.id > op.bar_id { + v1 = "空头止损"; + } + } + make_signal_v1(&k1, &k2, k3, v1) +} + +/// pos_stop_V240608:开仓后突破开仓前窗口极值 N 档止损 +/// +/// 参数模板:`"{pos_name}_{freq1}W{w}N{n}_止损V240608"` +/// +/// 信号逻辑: +/// - 多头:开仓后最低价低于“开仓前 `w` 根最低价下方第 `n` 档”触发 `多头止损`; +/// - 空头:开仓后最高价高于“开仓前 `w` 根最高价上方第 `n` 档”触发 `空头止损`。 +/// +/// 信号列表示例: +/// - `Signal('SMA5多头_15分钟W20N10_止损V240608_多头止损_任意_任意_0')` +/// - `Signal('SMA5空头_15分钟W20N10_止损V240608_空头止损_任意_任意_0')` +/// +/// 参数说明: +/// - `pos_name`:仓位名称; +/// - `freq1`:K线周期; +/// - `w`:开仓前观察窗口,默认 `20`; +/// - `n`:上下档位偏移,默认 `10`。 +#[signal( + category = "trader", + name = "pos_stop_V240608", + template = "{pos_name}_{freq1}W{w}N{n}_止损V240608", + opcode = "PosStopV240608", + param_kind = "PosStopV240608" +)] +pub fn pos_stop_v240608(cat: &dyn TraderState, params: &ParamView) -> Vec { + let pos_name = get_str_param(params, "pos_name", ""); + let freq1 = get_str_param(params, "freq1", ""); + let w = get_usize_param(params, "w", 20); + let n = get_usize_param(params, "n", 10); + let k1 = pos_name.to_string(); + let k2 = format!("{}W{}N{}", freq1, w, n); + let k3 = "止损V240608"; + let mut v1 = "其他"; + + let op = match last_open_operate(cat, pos_name) { + Some(op) => op, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let c = match cat.get_czsc(freq1) { + Some(c) => c, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + + let w_all: Vec<_> = c.bars_raw.iter().filter(|x| x.dt < op.dt).collect(); + let w_bars = if w_all.len() > w { + &w_all[w_all.len() - w..] + } else { + &w_all[..] + }; + let s100 = c.bars_raw.len().saturating_sub(100); + let a_bars: Vec<_> = c.bars_raw[s100..].iter().filter(|x| x.dt > op.dt).collect(); + if w_bars.is_empty() || a_bars.is_empty() { + return make_signal_v1(&k1, &k2, k3, v1); + } + + let s200 = c.bars_raw.len().saturating_sub(200); + let mut unique_prices: Vec = c.bars_raw[s200..] + .iter() + .flat_map(|x| [x.high, x.low, x.close, x.open]) + .collect(); + unique_prices.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + unique_prices.dedup_by(|a, b| (*a - *b).abs() < f64::EPSILON); + + if op.op == Operate::LO { + let w_low = w_bars.iter().fold(f64::INFINITY, |acc, x| acc.min(x.low)); + let a_low = a_bars.iter().fold(f64::INFINITY, |acc, x| acc.min(x.low)); + let w_low_prices: Vec = unique_prices + .iter() + .copied() + .filter(|x| *x < w_low) + .collect(); + if w_low_prices.len() > n && a_low < w_low_prices[w_low_prices.len() - n] { + v1 = "多头止损"; + } + } + if op.op == Operate::SO { + let w_high = w_bars + .iter() + .fold(f64::NEG_INFINITY, |acc, x| acc.max(x.high)); + let a_high = a_bars + .iter() + .fold(f64::NEG_INFINITY, |acc, x| acc.max(x.high)); + let w_high_prices: Vec = unique_prices + .iter() + .copied() + .filter(|x| *x > w_high) + .collect(); + if w_high_prices.len() > n && a_high > w_high_prices[n] { + v1 = "空头止损"; + } + } + make_signal_v1(&k1, &k2, k3, v1) +} + +/// pos_stop_V240614:开仓后低于/高于成本价的 K线数量计数止损 +/// +/// 参数模板:`"{pos_name}_{freq1}N{n}_止损V240614"` +/// +/// 信号逻辑: +/// - 多头:开仓后 `low < 开仓价` 的K线数量达到 `n`,触发 `多头止损`; +/// - 空头:开仓后 `high > 开仓价` 的K线数量达到 `n`,触发 `空头止损`。 +/// +/// 信号列表示例: +/// - `Signal('SMA5多头_15分钟N10_止损V240614_多头止损_任意_任意_0')` +/// - `Signal('SMA5空头_15分钟N10_止损V240614_空头止损_任意_任意_0')` +/// +/// 参数说明: +/// - `pos_name`:仓位名称; +/// - `freq1`:K线周期; +/// - `n`:计数阈值,默认 `10`。 +#[signal( + category = "trader", + name = "pos_stop_V240614", + template = "{pos_name}_{freq1}N{n}_止损V240614", + opcode = "PosStopV240614", + param_kind = "PosStopV240614" +)] +pub fn pos_stop_v240614(cat: &dyn TraderState, params: &ParamView) -> Vec { + let pos_name = get_str_param(params, "pos_name", ""); + let freq1 = get_str_param(params, "freq1", ""); + let n = get_usize_param(params, "n", 10); + let k1 = pos_name.to_string(); + let k2 = format!("{}N{}", freq1, n); + let k3 = "止损V240614"; + let mut v1 = "其他"; + + let op = match last_open_operate(cat, pos_name) { + Some(op) => op, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let c = match cat.get_czsc(freq1) { + Some(c) => c, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let a_bars: Vec<_> = c.bars_raw.iter().filter(|x| x.dt >= op.dt).collect(); + if op.op == Operate::LO && a_bars.iter().filter(|x| x.low < op.price).count() >= n { + v1 = "多头止损"; + } + if op.op == Operate::SO && a_bars.iter().filter(|x| x.high > op.price).count() >= n { + v1 = "空头止损"; + } + make_signal_v1(&k1, &k2, k3, v1) +} + +/// pos_stop_V240717:基于开仓时 ATR 的计数止损 +/// +/// 参数模板:`"{pos_name}_{freq1}N{n}T{timeperiod}_止损V240717"` +/// +/// 信号逻辑: +/// - 先取开仓时刻 ATR(`timeperiod`); +/// - 多头阈值为 `开仓价 - ATR*0.67`,空头阈值为 `开仓价 + ATR*0.67`; +/// - 开仓后超过阈值的K线数量达到 `n` 时触发 `多头止损/空头止损`。 +/// +/// 信号列表示例: +/// - `Signal('SMA5多头_15分钟N3T20_止损V240717_多头止损_任意_任意_0')` +/// - `Signal('SMA5空头_15分钟N3T20_止损V240717_空头止损_任意_任意_0')` +/// +/// 参数说明: +/// - `pos_name`:仓位名称; +/// - `freq1`:K线周期; +/// - `n`:计数阈值,默认 `10`; +/// - `timeperiod`:ATR 周期,默认 `20`。 +#[signal( + category = "trader", + name = "pos_stop_V240717", + template = "{pos_name}_{freq1}N{n}T{timeperiod}_止损V240717", + opcode = "PosStopV240717", + param_kind = "PosStopV240717" +)] +pub fn pos_stop_v240717(cat: &dyn TraderState, params: &ParamView) -> Vec { + let pos_name = get_str_param(params, "pos_name", ""); + let freq1 = get_str_param(params, "freq1", ""); + let n = get_usize_param(params, "n", 10); + let timeperiod = get_usize_param(params, "timeperiod", 20); + let k1 = pos_name.to_string(); + let k2 = format!("{}N{}T{}", freq1, n, timeperiod); + let k3 = "止损V240717"; + let mut v1 = "其他"; + + let op = match last_open_operate(cat, pos_name) { + Some(op) => op, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let c = match cat.get_czsc(freq1) { + Some(c) => c, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + let cache_key = format!("ATR#{}", timeperiod); + let mut ta_cache = TaCache::new(); + update_atr_cache(c, &cache_key, timeperiod, &mut ta_cache); + let atr_series = match ta_cache.series.get(&cache_key) { + Some(v) => v, + None => return make_signal_v1(&k1, &k2, k3, v1), + }; + // 对齐 Python: + // atr = [x.cache[cache_key] if x.cache.get(cache_key) is not None else 0 + // for x in c.bars_raw if x.dt == op["dt"]][0] + let atr = c + .bars_raw + .iter() + .enumerate() + .find_map(|(i, b)| { + if b.dt == op.dt.with_timezone(&chrono::Utc) { + Some(*atr_series.get(i).unwrap_or(&0.0)) + } else { + None + } + }) + .unwrap_or(0.0); + let a_bars: Vec<_> = c.bars_raw.iter().filter(|x| x.dt >= op.dt).collect(); + + if op.op == Operate::LO + && a_bars + .iter() + .filter(|x| x.low < op.price - atr * 0.67) + .count() + >= n + { + v1 = "多头止损"; + } + if op.op == Operate::SO + && a_bars + .iter() + .filter(|x| x.high > op.price + atr * 0.67) + .count() + >= n + { + v1 = "空头止损"; + } + make_signal_v1(&k1, &k2, k3, v1) +} diff --git a/crates/czsc-signals/src/pressure.rs b/crates/czsc-signals/src/pressure.rs new file mode 100644 index 000000000..01999d005 --- /dev/null +++ b/crates/czsc-signals/src/pressure.rs @@ -0,0 +1,391 @@ +use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::sig::{get_sub_elements, make_kline_signal_v1}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::signal::Signal; +use czsc_signal_macros::signal; + +fn std_pop(values: &[f64]) -> f64 { + if values.is_empty() || values.iter().any(|x| !x.is_finite()) { + return f64::NAN; + } + let mean = values.iter().sum::() / values.len() as f64; + let var = values.iter().map(|x| (x - mean).powi(2)).sum::() / values.len() as f64; + var.sqrt() +} + +/// pressure_support_V240222:高低点验证支撑压力位 +/// +/// 参数模板:`"{freq}_D{di}W{w}高低点验证_支撑压力V240222"` +/// +/// 信号逻辑: +/// 1. 取最近 `w` 根K线,计算区间最高/最低与振幅标准差 `gap`; +/// 2. 若区间波动不足(`max_high-min_low < gap*0.3*w`)则返回 `其他`; +/// 3. 若窗口两端高点贴近全局高点,判 `压力位`; +/// 4. 若窗口两端低点贴近全局低点,判 `支撑位`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1W20高低点验证_支撑压力V240222_压力位_任意_任意_0')` +/// - `Signal('60分钟_D1W20高低点验证_支撑压力V240222_支撑位_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `w`:观察窗口大小,默认 `20`,且必须大于 `10`。 +/// 对齐说明:窗口切分与高低点验证条件对齐 Python `pressure_support_V240222`。 +#[signal( + category = "kline", + name = "pressure_support_V240222", + template = "{freq}_D{di}W{w}高低点验证_支撑压力V240222", + opcode = "PressureSupportV240222", + param_kind = "PressureSupportV240222" +)] +pub fn pressure_support_v240222( + c: &CZSC, + params: &ParamView, + _cache: &mut TaCache, +) -> Vec { + let di = params.usize("di", 1); + let w = params.usize("w", 20); + assert!(w > 10, "w must be > 10"); + + let k1 = c.freq.to_string(); + let k2 = format!("D{}W{}高低点验证", di, w); + let k3 = "支撑压力V240222"; + let mut v1 = "其他"; + + if c.bars_raw.len() < w + 10 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let bars = get_sub_elements(&c.bars_raw, di, w); + if bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let max_high = bars.iter().map(|x| x.high).fold(f64::NEG_INFINITY, f64::max); + let min_low = bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let n = ((bars.len() as f64) * 0.2) as usize; + if n == 0 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let left_bars = &bars[..n]; + let right_bars = &bars[bars.len() - n..]; + let gap = std_pop( + &bars + .iter() + .map(|x| (x.high - x.low).abs()) + .collect::>(), + ); + + if max_high - min_low < gap * 0.3 * w as f64 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let left_high = left_bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let right_high = right_bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + if max_high == left_high.max(right_high) && max_high - left_high.min(right_high) < gap { + v1 = "压力位"; + } + + let left_low = left_bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let right_low = right_bars + .iter() + .map(|x| x.low) + .fold(f64::INFINITY, f64::min); + if min_low == left_low.min(right_low) && left_low.max(right_low) - min_low < gap { + v1 = "支撑位"; + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// pressure_support_V240402:分型区间支撑压力位 +/// +/// 参数模板:`"{freq}_D{di}W{w}_支撑压力V240402"` +/// +/// 信号逻辑: +/// 1. 统计最近 `50` 个分型中,包含当前收盘价的分型数量; +/// 2. 若命中分型少于 `5` 或窗口波动不足(`max_high-min_low < gap*3`)返回 `其他`; +/// 3. 当前收盘靠近窗口上沿(前20%)判 `压力位`; +/// 4. 当前收盘靠近窗口下沿(前30%)判 `支撑位`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1W60_支撑压力V240402_压力位_任意_任意_0')` +/// - `Signal('60分钟_D1W60_支撑压力V240402_支撑位_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `w`:观察窗口大小,默认 `60`,且必须大于 `10`。 +/// 对齐说明:分型筛选与收盘位置阈值对齐 Python `pressure_support_V240402`。 +#[signal( + category = "kline", + name = "pressure_support_V240402", + template = "{freq}_D{di}W{w}_支撑压力V240402", + opcode = "PressureSupportV240402", + param_kind = "PressureSupportV240402" +)] +pub fn pressure_support_v240402( + c: &CZSC, + params: &ParamView, + _cache: &mut TaCache, +) -> Vec { + let di = params.usize("di", 1); + let w = params.usize("w", 60); + assert!(w > 10, "w must be > 10"); + + let k1 = c.freq.to_string(); + let k2 = format!("D{}W{}", di, w); + let k3 = "支撑压力V240402"; + let mut v1 = "其他"; + + if c.bars_raw.len() < w + 10 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let bars = get_sub_elements(&c.bars_raw, di, w); + if bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let fxs = c.get_fx_list(); + let fxs_tail = if fxs.len() > 50 { + &fxs[fxs.len() - 50..] + } else { + &fxs[..] + }; + let close = bars[bars.len() - 1].close; + let near_fx_cnt = fxs_tail + .iter() + .filter(|fx| fx.low <= close && close <= fx.high) + .count(); + + let gap = std_pop( + &bars + .iter() + .map(|x| (x.high - x.low).abs()) + .collect::>(), + ); + let max_high = bars.iter().map(|x| x.high).fold(f64::NEG_INFINITY, f64::max); + let min_low = bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + + if near_fx_cnt < 5 || max_high - min_low < gap * 3.0 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let hl_gap = max_high - min_low; + if close > max_high - hl_gap * 0.2 { + v1 = "压力位"; + } + if close < min_low + hl_gap * 0.3 { + v1 = "支撑位"; + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// pressure_support_V240406:分型密集支撑压力位 +/// +/// 参数模板:`"{freq}_D{di}W{w}_支撑压力V240406"` +/// +/// 信号逻辑: +/// 1. 统计窗口最高/最低附近的分型数量(严格落在分型区间内); +/// 2. 若窗口波动不足(`max_high-min_low < gap*3`)返回 `其他`; +/// 3. 若高点附近分型 `>=3` 且收盘靠近上沿,判 `压力位`; +/// 4. 若低点附近分型 `>=3` 且收盘靠近下沿,判 `支撑位`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1W60_支撑压力V240406_压力位_任意_任意_0')` +/// - `Signal('60分钟_D1W60_支撑压力V240406_支撑位_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `w`:观察窗口大小,默认 `60`,且必须大于 `10`。 +/// 对齐说明:分型密集阈值和价格区间判断对齐 Python `pressure_support_V240406`。 +#[signal( + category = "kline", + name = "pressure_support_V240406", + template = "{freq}_D{di}W{w}_支撑压力V240406", + opcode = "PressureSupportV240406", + param_kind = "PressureSupportV240406" +)] +pub fn pressure_support_v240406( + c: &CZSC, + params: &ParamView, + _cache: &mut TaCache, +) -> Vec { + let di = params.usize("di", 1); + let w = params.usize("w", 60); + assert!(w > 10, "w must be > 10"); + + let k1 = c.freq.to_string(); + let k2 = format!("D{}W{}", di, w); + let k3 = "支撑压力V240406"; + let mut v1 = "其他"; + + if c.bars_raw.len() < w + 10 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let bars = get_sub_elements(&c.bars_raw, di, w); + if bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let fxs = c.get_fx_list(); + let fxs_tail = if fxs.len() > 50 { + &fxs[fxs.len() - 50..] + } else { + &fxs[..] + }; + + let gap = std_pop( + &bars + .iter() + .map(|x| (x.high - x.low).abs()) + .collect::>(), + ); + let max_high = bars.iter().map(|x| x.high).fold(f64::NEG_INFINITY, f64::max); + let min_low = bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + + if max_high - min_low < gap * 3.0 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let near_high_fx = fxs_tail + .iter() + .filter(|fx| fx.low < max_high && max_high < fx.high) + .count(); + let near_low_fx = fxs_tail + .iter() + .filter(|fx| fx.low < min_low && min_low < fx.high) + .count(); + let hl_gap = max_high - min_low; + let close = bars[bars.len() - 1].close; + + if near_high_fx >= 3 && close > max_high - hl_gap * 0.2 { + v1 = "压力位"; + } + if near_low_fx >= 3 && close < min_low + hl_gap * 0.3 { + v1 = "支撑位"; + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// pressure_support_V240530:关键重叠K线支撑压力位 +/// +/// 参数模板:`"{freq}_D{di}W{w}N{n}_支撑压力V240530"` +/// +/// 信号逻辑: +/// 1. 在最近 `w` 根K线中寻找与其他K线重叠次数最多的关键K线; +/// 2. 若最大重叠次数小于 `0.5*w`,返回 `其他`; +/// 3. 以关键K线高低价在全局 `unique price` 列表上的 `±n` 档形成压力/支撑区间; +/// 4. 收盘落入高位区间判 `压力位`,落入低位区间判 `支撑位`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1W20N5_支撑压力V240530_压力位_任意_任意_0')` +/// - `Signal('60分钟_D1W20N5_支撑压力V240530_支撑位_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `w`:观察窗口大小,默认 `20`,且必须大于 `10`; +/// - `n`:价格档位偏移,默认 `5`。 +/// 对齐说明:关键K线重叠计数和 `unique price ±n` 区间判定对齐 Python `pressure_support_V240530`。 +#[signal( + category = "kline", + name = "pressure_support_V240530", + template = "{freq}_D{di}W{w}N{n}_支撑压力V240530", + opcode = "PressureSupportV240530", + param_kind = "PressureSupportV240530" +)] +pub fn pressure_support_v240530( + c: &CZSC, + params: &ParamView, + _cache: &mut TaCache, +) -> Vec { + let di = params.usize("di", 1); + let w = params.usize("w", 20); + let n = params.usize("n", 5); + assert!(w > 10, "w must be > 10"); + + let k1 = c.freq.to_string(); + let k2 = format!("D{}W{}N{}", di, w, n); + let k3 = "支撑压力V240530"; + let mut v1 = "其他"; + + if c.bars_raw.len() < w + 10 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let bars = get_sub_elements(&c.bars_raw, di, w); + if bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let mut overlap_counts = vec![0usize; bars.len()]; + for i in 0..bars.len() { + let bi = &bars[i]; + let mut count = 0usize; + for (j, bj) in bars.iter().enumerate() { + if i == j { + continue; + } + if bi.low.max(bj.low) < bi.high.min(bj.high) { + count += 1; + } + } + overlap_counts[i] = count; + } + + // 对齐 Python: max(dict, key=dict.get) 在并列最大时返回“最先插入”的键(最小索引) + let Some(max_cnt) = overlap_counts.iter().copied().max() else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + let Some(key_idx) = overlap_counts.iter().position(|x| *x == max_cnt) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + if (max_cnt as f64) < 0.5 * w as f64 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let key_bar = &bars[key_idx]; + let mut prices: Vec = c + .bars_raw + .iter() + .flat_map(|x| [x.open, x.close, x.high, x.low]) + .collect(); + prices.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + prices.dedup_by(|a, b| *a == *b); + + let Some(high_idx) = prices.iter().position(|x| *x == key_bar.high) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + let Some(low_idx) = prices.iter().position(|x| *x == key_bar.low) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + + if high_idx < n || low_idx < n { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + if high_idx + n >= prices.len() || low_idx + n >= prices.len() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let close = bars[bars.len() - 1].close; + let pressure_h = prices[high_idx + n]; + let pressure_l = prices[high_idx - n]; + if pressure_h > close && close > pressure_l { + v1 = "压力位"; + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let support_h = prices[low_idx + n]; + let support_l = prices[low_idx - n]; + if support_h > close && close > support_l { + v1 = "支撑位"; + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} diff --git a/crates/czsc-signals/src/registry.rs b/crates/czsc-signals/src/registry.rs new file mode 100644 index 000000000..06dbe0586 --- /dev/null +++ b/crates/czsc-signals/src/registry.rs @@ -0,0 +1,584 @@ +use crate::types::{SignalFnRef, SignalMeta, TraderSignalMeta}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::LazyLock; + +fn insert_generated_kline( + m: &mut HashMap<&'static str, SignalMeta>, + d: crate::types::SignalDescriptor, +) { + if d.category != "kline" { + return; + } + if let SignalFnRef::Kline(func) = d.func_ref { + m.insert( + d.name, + SignalMeta { + func, + param_template: d.template, + fast_kline: d.fast_kline, + }, + ); + } +} + +fn insert_generated_trader( + m: &mut HashMap<&'static str, TraderSignalMeta>, + d: crate::types::SignalDescriptor, +) { + if d.category != "trader" { + return; + } + if let SignalFnRef::Trader(func) = d.func_ref { + m.insert( + d.name, + TraderSignalMeta { + func, + param_template: d.template, + }, + ); + } +} + +/// K线级运行时注册视图(来源:`#[signal(category = "kline", ...)]` + inventory 自动收集) +pub static SIGNAL_REGISTRY: LazyLock> = LazyLock::new(|| { + let mut m: HashMap<&'static str, SignalMeta> = HashMap::new(); + for d in list_generated_signal_descriptors() { + insert_generated_kline(&mut m, d); + } + m +}); + +/// Trader/Position 级运行时注册视图(来源:`#[signal(category = "trader", ...)]` + inventory 自动收集) +pub static TRADER_SIGNAL_REGISTRY: LazyLock> = + LazyLock::new(|| { + let mut m: HashMap<&'static str, TraderSignalMeta> = HashMap::new(); + for d in list_generated_signal_descriptors() { + insert_generated_trader(&mut m, d); + } + m + }); + +/// 注册信号元信息(用于对照、文档、外部 API 只读查询) +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct RegisteredSignalInfo { + /// 信号函数名,如 tas_ma_base_V221101 + pub name: String, + /// 参数模板 + pub param_template: String, + /// 注册表类别:kline / trader + pub category: String, + /// 信号命名空间前缀,如 bar/tas/cxt/pos + pub namespace: String, +} + +/// 汇总读取全部已注册信号(只读视图)。 +pub fn list_all_signals(include_kline: bool, include_trader: bool) -> Vec { + let mut out = Vec::new(); + + if include_kline { + for (name, meta) in SIGNAL_REGISTRY.iter() { + let namespace = name.split('_').next().unwrap_or_default().to_string(); + out.push(RegisteredSignalInfo { + name: (*name).to_string(), + param_template: meta.param_template.to_string(), + category: "kline".to_string(), + namespace, + }); + } + } + + if include_trader { + for (name, meta) in TRADER_SIGNAL_REGISTRY.iter() { + let namespace = name.split('_').next().unwrap_or_default().to_string(); + out.push(RegisteredSignalInfo { + name: (*name).to_string(), + param_template: meta.param_template.to_string(), + category: "trader".to_string(), + namespace, + }); + } + } + + out.sort_by(|a, b| { + a.category + .cmp(&b.category) + .then_with(|| a.name.cmp(&b.name)) + }); + out +} + +fn normalize_generated_signal_descriptors( + mut out: Vec, +) -> Result, String> { + let mut names: HashMap<&'static str, &'static str> = HashMap::new(); + let mut opcodes: HashMap<&'static str, &'static str> = HashMap::new(); + for d in &out { + if let Some(prev) = names.insert(d.name, d.opcode) { + return Err(format!( + "duplicate signal name: {} (opcodes: {}, {})", + d.name, prev, d.opcode + )); + } + if let Some(prev_name) = opcodes.insert(d.opcode, d.name) { + return Err(format!( + "duplicate signal opcode: {} (names: {}, {})", + d.opcode, prev_name, d.name + )); + } + } + out.sort_by(|a, b| a.name.cmp(b.name)); + Ok(out) +} + +pub fn list_generated_signal_descriptors() -> Vec { + let out: Vec = + inventory::iter:: + .into_iter() + .copied() + .collect(); + normalize_generated_signal_descriptors(out) + .unwrap_or_else(|e| panic!("invalid generated signal descriptors: {e}")) +} + +#[cfg(test)] +mod tests { + use super::list_all_signals; + use czsc_core::analyze::CZSC; + use czsc_core::objects::signal::Signal; + use serde_json::Value; + use std::collections::HashMap; + + fn __inventory_probe_signal( + _czsc: &CZSC, + _params: &HashMap, + _cache: &mut crate::types::TaCache, + ) -> Vec { + Vec::new() + } + + inventory::submit! { + crate::types::SignalDescriptor { + category: "kline", + name: "__inventory_probe_V000000", + template: "probe_template", + opcode: "InventoryProbe", + param_kind: "Probe", + func_ref: crate::types::SignalFnRef::Kline(__inventory_probe_signal as crate::types::SignalFn), + fast_kline: None, + } + } + + #[test] + fn test_list_all_signals_contains_both_categories() { + let all = list_all_signals(true, true); + assert!(!all.is_empty()); + assert!(all.iter().any(|x| x.category == "kline")); + assert!(all.iter().any(|x| x.category == "trader")); + } + + #[test] + fn test_list_all_signals_sorted_and_unique() { + let all = list_all_signals(true, true); + for i in 1..all.len() { + let prev = (&all[i - 1].category, &all[i - 1].name); + let curr = (&all[i].category, &all[i].name); + assert!(prev <= curr); + } + let mut seen = std::collections::HashSet::new(); + for s in all { + assert!(seen.insert((s.category, s.name))); + } + } + + #[test] + fn test_pos_signal_registry_contains_all_python_pos_functions() { + let expected = vec![ + "pos_ma_V230414", + "pos_fx_stop_V230414", + "pos_bar_stop_V230524", + "pos_holds_V230414", + "pos_fix_exit_V230624", + "pos_profit_loss_V230624", + "pos_status_V230808", + "pos_holds_V230807", + "pos_holds_V240428", + "pos_holds_V240608", + "pos_stop_V240428", + "pos_take_V240428", + "pos_stop_V240331", + "pos_stop_V240608", + "pos_stop_V240614", + "pos_stop_V240717", + ]; + for name in expected { + assert!( + super::TRADER_SIGNAL_REGISTRY.contains_key(name), + "missing trader signal registration: {}", + name + ); + } + } + + #[test] + fn test_trader_registry_contains_cat_macd_functions() { + for name in ["cat_macd_V230518", "cat_macd_V230520"] { + assert!( + super::TRADER_SIGNAL_REGISTRY.contains_key(name), + "missing trader signal registration: {}", + name + ); + } + } + + #[test] + fn test_kline_registry_contains_jcc_batch_and_cxt_bi_status_v230102() { + let expected = vec![ + "jcc_ci_tou_V221101", + "jcc_fan_ji_xian_V221121", + "jcc_fen_shou_xian_V20221113", + "jcc_gap_yin_yang_V221121", + "jcc_ping_tou_V221113", + "jcc_san_fa_V20221115", + "jcc_san_fa_V20221118", + "jcc_san_szx_V221122", + "jcc_san_xing_xian_V221023", + "jcc_shan_chun_V221121", + "jcc_szx_V221111", + "jcc_ta_xing_V221124", + "jcc_ten_mo_V221028", + "jcc_three_crow_V221108", + "jcc_two_crow_V221108", + "jcc_wu_yun_gai_ding_V221101", + "jcc_xing_xian_V221118", + "jcc_yun_xian_V221118", + "jcc_zhu_huo_xian_V221027", + "cxt_bi_status_V230102", + ]; + for name in expected { + assert!( + super::SIGNAL_REGISTRY.contains_key(name), + "missing kline signal registration: {}", + name + ); + } + } + + #[test] + fn test_kline_registry_contains_cxt_batch2_20() { + let expected = vec![ + "cxt_fx_power_V221107", + "cxt_bi_end_V230104", + "cxt_bi_end_V230105", + "cxt_bi_end_V230224", + "cxt_bi_end_V230312", + "cxt_bi_end_V230324", + "cxt_bi_end_V230815", + "cxt_bi_stop_V230815", + "cxt_bi_trend_V230824", + "cxt_bi_zdf_V230601", + "cxt_second_bs_V230320", + "cxt_third_bs_V230318", + "cxt_double_zs_V230311", + "cxt_decision_V240526", + "cxt_decision_V240612", + "cxt_decision_V240613", + "cxt_decision_V240614", + "cxt_overlap_V240526", + "cxt_bs_V240526", + "cxt_bs_V240527", + ]; + for name in expected { + assert!( + super::SIGNAL_REGISTRY.contains_key(name), + "missing kline signal registration: {}", + name + ); + } + } + + #[test] + fn test_registry_contains_remaining_cxt_zdy_39() { + let expected_kline = vec![ + "cxt_bi_end_V230222", + "cxt_bi_end_V230320", + "cxt_bi_end_V230322", + "cxt_bi_end_V230618", + "cxt_bi_trend_V230913", + "cxt_eleven_bi_V230622", + "cxt_first_buy_V221126", + "cxt_first_sell_V221126", + "cxt_five_bi_V230619", + "cxt_nine_bi_V230621", + "cxt_overlap_V240612", + "cxt_range_oscillation_V230620", + "cxt_second_bs_V240524", + "cxt_seven_bi_V230620", + "cxt_third_bs_V230319", + "cxt_third_buy_V230228", + "cxt_three_bi_V230618", + "cxt_ubi_end_V230816", + "zdy_bi_end_V230406", + "zdy_bi_end_V230407", + "zdy_dif_V230527", + "zdy_dif_V230528", + "zdy_macd_V230518", + "zdy_macd_V230519", + "zdy_macd_V230527", + "zdy_macd_bc_V230422", + "zdy_macd_bs1_V230422", + "zdy_macd_dif_V230516", + "zdy_macd_dif_V230517", + "zdy_macd_dif_iqr_V230521", + "zdy_zs_V230423", + "zdy_zs_space_V230421", + ]; + let expected_trader = vec![ + "cxt_intraday_V230701", + "cxt_zhong_shu_gong_zhen_V221221", + "zdy_stop_loss_V230406", + "zdy_take_profit_V230406", + "zdy_take_profit_V230407", + "zdy_vibrate_V230406", + ]; + for name in expected_kline { + assert!( + super::SIGNAL_REGISTRY.contains_key(name), + "missing kline signal registration: {}", + name + ); + } + for name in expected_trader { + assert!( + super::TRADER_SIGNAL_REGISTRY.contains_key(name), + "missing trader signal registration: {}", + name + ); + } + } + + #[test] + fn test_tas_registry_contains_migrated_signals() { + let expected = vec![ + "tas_ma_round_V221206", + "tas_double_ma_V230511", + "tas_first_bs_V230217", + "tas_second_bs_V230228", + "tas_second_bs_V230303", + "tas_hlma_V230301", + "tas_boll_cc_V230312", + "tas_kdj_evc_V221201", + "tas_kdj_evc_V230401", + "tas_atr_break_V230424", + "tas_ma_system_V230513", + "tas_dif_layer_V241010", + "tas_cross_status_V230619", + "tas_cross_status_V230624", + "tas_cross_status_V230625", + "tas_slope_V231019", + "tas_boll_vt_V230212", + "tas_cci_base_V230402", + "tas_accelerate_V230531", + "tas_low_trend_V230627", + "tas_atr_V230630", + "tas_angle_V230802", + "tas_double_ma_V240208", + "tas_dma_bs_V240608", + "tas_macd_bc_V230803", + "tas_macd_bc_V240307", + "tas_macd_first_bs_V221216", + "tas_macd_second_bs_V221201", + "tas_macd_xt_V221208", + "tas_macd_bs1_V230312", + "tas_macd_bs1_V230313", + "tas_sar_base_V230425", + "tas_rumi_V230704", + "tas_macd_bs1_V230411", + "tas_macd_bs1_V230412", + "tas_macd_bc_V230804", + "tas_macd_bc_ubi_V230804", + "bar_end_V221211", + "bar_operate_span_V221111", + "bar_time_V230327", + "bar_weekday_V230328", + "vol_single_ma_V230214", + "vol_double_ma_V230214", + "vol_ti_suo_V221216", + "vol_gao_di_V221218", + "vol_window_V230731", + "vol_window_V230801", + "pressure_support_V240222", + "pressure_support_V240402", + "pressure_support_V240406", + "pressure_support_V240530", + "obvm_line_V230610", + "obv_up_dw_line_V230719", + "cvolp_up_dw_line_V230612", + "ntmdk_V230824", + "kcatr_up_dw_line_V230823", + "clv_up_dw_line_V230605", + "cmo_up_dw_line_V230605", + "adtm_up_dw_line_V230603", + "amv_up_dw_line_V230603", + "asi_up_dw_line_V230603", + "bias_up_dw_line_V230618", + "dema_up_dw_line_V230605", + "demakder_up_dw_line_V230605", + "emv_up_dw_line_V230605", + "er_up_dw_line_V230604", + "skdj_up_dw_line_V230611", + "coo_td_V221110", + "coo_td_V221111", + "coo_cci_V230323", + "coo_kdj_V230322", + "coo_sar_V230325", + "byi_symmetry_zs_V221107", + "byi_bi_end_V230106", + "byi_bi_end_V230107", + "byi_second_bs_V230324", + "byi_fx_num_V230628", + "xl_bar_position_V240328", + "xl_bar_trend_V240329", + "xl_bar_trend_V240330", + "xl_bar_trend_V240331", + "xl_bar_basis_V240411", + "xl_bar_basis_V240412", + "xl_bar_trend_V240623", + "cci_decision_V240620", + "tas_ma_cohere_V230512", + ]; + for name in expected { + assert!( + super::SIGNAL_REGISTRY.contains_key(name), + "missing kline signal registration: {}", + name + ); + } + } + + #[test] + fn test_macro_injected_kline_descriptors_registered() { + let d1 = crate::tas::__RS_CZSC_SIGNAL_META_TAS_MA_BASE_V221101; + let d2 = crate::tas::__RS_CZSC_SIGNAL_META_TAS_MA_BASE_V221203; + let d3 = crate::tas::__RS_CZSC_SIGNAL_META_TAS_MACD_BASE_V221028; + let d4 = crate::tas::__RS_CZSC_SIGNAL_META_TAS_MACD_CHANGE_V221105; + let d5 = crate::tas::__RS_CZSC_SIGNAL_META_TAS_MACD_DIRECT_V221106; + let d6 = crate::tas::__RS_CZSC_SIGNAL_META_TAS_MACD_POWER_V221108; + let d7 = crate::tas::__RS_CZSC_SIGNAL_META_TAS_MACD_DIST_V230408; + let d8 = crate::tas::__RS_CZSC_SIGNAL_META_TAS_MACD_DIST_V230409; + let d9 = crate::tas::__RS_CZSC_SIGNAL_META_TAS_MACD_DIST_V230410; + let d10 = crate::tas::__RS_CZSC_SIGNAL_META_TAS_CROSS_STATUS_V230619; + let d11 = crate::tas::__RS_CZSC_SIGNAL_META_TAS_DOUBLE_MA_V221203; + for d in [d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11] { + let meta = super::SIGNAL_REGISTRY + .get(d.name) + .unwrap_or_else(|| panic!("missing macro injected signal: {}", d.name)); + assert_eq!(meta.param_template, d.template); + } + } + + #[test] + fn test_generated_descriptor_list_contains_macd_triplet() { + let generated = super::list_generated_signal_descriptors(); + let names: std::collections::HashSet<_> = generated.iter().map(|d| d.name).collect(); + for name in [ + "tas_macd_change_V221105", + "tas_macd_direct_V221106", + "tas_macd_power_V221108", + ] { + assert!(names.contains(name), "missing generated descriptor: {name}"); + } + } + + #[test] + fn test_generated_descriptor_list_contains_unmigrated_baselines() { + let generated = super::list_generated_signal_descriptors(); + let names: std::collections::HashSet<_> = generated.iter().map(|d| d.name).collect(); + for name in ["tas_rsi_base_V230227", "pos_ma_V230414"] { + assert!(names.contains(name), "missing generated descriptor: {name}"); + } + } + + #[test] + fn test_generated_descriptor_list_contains_macd_dist_triplet() { + let generated = super::list_generated_signal_descriptors(); + let names: std::collections::HashSet<_> = generated.iter().map(|d| d.name).collect(); + for name in [ + "tas_macd_dist_V230408", + "tas_macd_dist_V230409", + "tas_macd_dist_V230410", + ] { + assert!(names.contains(name), "missing generated descriptor: {name}"); + } + } + + #[test] + fn test_generated_descriptor_list_auto_discovers_inventory_submissions() { + let generated = super::list_generated_signal_descriptors(); + let names: std::collections::HashSet<_> = generated.iter().map(|d| d.name).collect(); + assert!( + names.contains("__inventory_probe_V000000"), + "inventory submitted descriptor should be auto discovered" + ); + } + + #[test] + fn test_normalize_generated_rejects_duplicate_name() { + let d1 = crate::types::SignalDescriptor { + category: "kline", + name: "dup_name_V000001", + template: "a", + opcode: "OpcodeA", + param_kind: "A", + func_ref: crate::types::SignalFnRef::Kline( + __inventory_probe_signal as crate::types::SignalFn, + ), + fast_kline: None, + }; + let d2 = crate::types::SignalDescriptor { + opcode: "OpcodeB", + ..d1 + }; + let err = match super::normalize_generated_signal_descriptors(vec![d1, d2]) { + Ok(_) => panic!("expected duplicate signal name error"), + Err(e) => e, + }; + assert!(err.contains("duplicate signal name")); + } + + #[test] + fn test_normalize_generated_rejects_duplicate_opcode() { + let d1 = crate::types::SignalDescriptor { + category: "kline", + name: "name_a_V000001", + template: "a", + opcode: "DupOpcode", + param_kind: "A", + func_ref: crate::types::SignalFnRef::Kline( + __inventory_probe_signal as crate::types::SignalFn, + ), + fast_kline: None, + }; + let d2 = crate::types::SignalDescriptor { + name: "name_b_V000001", + ..d1 + }; + let err = match super::normalize_generated_signal_descriptors(vec![d1, d2]) { + Ok(_) => panic!("expected duplicate signal opcode error"), + Err(e) => e, + }; + assert!(err.contains("duplicate signal opcode")); + } + + #[test] + fn test_macro_injected_trader_descriptors_registered() { + let d1 = crate::pos::__RS_CZSC_SIGNAL_META_POS_FX_STOP_V230414; + let d2 = crate::pos::__RS_CZSC_SIGNAL_META_POS_STATUS_V230808; + for d in [d1, d2] { + let meta = super::TRADER_SIGNAL_REGISTRY + .get(d.name) + .unwrap_or_else(|| panic!("missing macro injected trader signal: {}", d.name)); + assert_eq!(meta.param_template, d.template); + } + } +} diff --git a/crates/czsc-signals/src/tas.rs b/crates/czsc-signals/src/tas.rs new file mode 100644 index 000000000..69cd10386 --- /dev/null +++ b/crates/czsc-signals/src/tas.rs @@ -0,0 +1,5118 @@ +use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::sig::{ + bar_index_map, cal_cross_num, count_last_same, cross_zero_axis, down_cross_count, + fast_slow_cross, fast_slow_cross_ext, get_str_param, get_sub_elements, get_usize_param, + linear_slope, make_kline_signal_v1, make_kline_signal_v2, make_kline_signal_v3, + pd_cut_last_label, qcut_last_label, std_abs_series, values_from_fx, +}; +use crate::utils::ta::{ + calc_sma, macd_snapshot_field_value, update_atr_cache, update_boll_cache, update_cci_cache, + update_kdj_cache, update_ma_cache, update_macd_cache, update_sar_cache, MacdField, +}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::bar::RawBar; +use czsc_core::objects::direction::Direction; +use czsc_core::objects::mark::Mark; +use czsc_core::objects::signal::Signal; +use czsc_core::objects::zs::ZS; +use czsc_signal_macros::signal; +use serde_json::Value; +use std::collections::HashMap; + +#[allow(clippy::too_many_arguments)] +fn snapshot_dif_values_from_raw_bars( + czsc: &CZSC, + mc: &crate::types::MacdSeries, + id_to_idx: &HashMap, + raw_bars: &[RawBar], + short: usize, + long: usize, + m: usize, + snapshot_overrides: &mut HashMap, +) -> Vec { + raw_bars + .iter() + .filter_map(|rb| { + macd_snapshot_field_value( + czsc, + mc, + id_to_idx, + rb, + short, + long, + m, + MacdField::Dif, + snapshot_overrides, + ) + }) + .filter(|x| x.is_finite()) + .collect() +} + +#[allow(clippy::too_many_arguments)] +fn snapshot_dif_values_from_fx( + czsc: &CZSC, + mc: &crate::types::MacdSeries, + id_to_idx: &HashMap, + fx: &czsc_core::objects::fx::FX, + short: usize, + long: usize, + m: usize, + snapshot_overrides: &mut HashMap, +) -> Vec { + let raw_bars: Vec = fx + .elements + .iter() + .flat_map(|nb| nb.elements.iter().cloned()) + .collect(); + snapshot_dif_values_from_raw_bars( + czsc, + mc, + id_to_idx, + &raw_bars, + short, + long, + m, + snapshot_overrides, + ) +} + +/// tas_ma_base_V221101:单均线多空与方向信号 +/// +/// 参数模板:`"{freq}_D{di}{ma_type}#{timeperiod}_分类V221101"` +/// +/// 信号逻辑: +/// 1. 计算指定均线(`SMA/EMA`); +/// 2. `close >= ma` 判定 `多头`,否则 `空头`; +/// 3. `ma_now >= ma_prev` 判定 `向上`,否则 `向下`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1SMA#5_分类V221101_多头_向上_任意_0')` +/// - `Signal('60分钟_D1EMA#12_分类V221101_空头_向下_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod`:均线周期,默认 `5`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_ma_base_V221101", + template = "{freq}_D{di}{ma_type}#{timeperiod}_分类V221101", + opcode = "TasMaBaseV221101", + param_kind = "TasMaBase" +)] +pub fn tas_ma_base_v221101(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let timeperiod = get_usize_param(params, "timeperiod", 5); + let ma_type = get_str_param(params, "ma_type", "SMA"); + let freq = czsc.freq; + + // 缓存 key 唯一标识这根线 + let cache_key = format!("{}_{}_{}", freq, ma_type, timeperiod); + + // 更新缓存 + update_ma_cache(czsc, &cache_key, ma_type, timeperiod, cache); + + let mut signals_res = Vec::new(); + let ma = cache.series.get(&cache_key).unwrap(); + let close = &czsc.bars_raw; + let bars = get_sub_elements(close, di, 3); + if bars.len() < 2 { + return signals_res; + } + + let c = bars[bars.len() - 1].close; + let m = ma[close.len() - di]; + let m_prev = ma[close.len() - di - 1]; + + let v1 = if c >= m { "多头" } else { "空头" }; + + // 判断方向:当前均线 >= 上一根均线为向上 + let v2 = if m >= m_prev { "向上" } else { "向下" }; + + let k1 = freq.to_string(); + let k2 = format!("D{}{}#{}", di, ma_type, timeperiod); + let k3 = "分类V221101"; + signals_res.extend(make_kline_signal_v2(&k1, &k2, k3, v1, v2)); + + signals_res +} + +/// tas_ma_base_V221203:单均线多空与距离分层信号 +/// +/// 参数模板:`"{freq}_D{di}{ma_type}#{timeperiod}T{th}_分类V221203"` +/// +/// 信号逻辑: +/// 1. 计算指定均线(`SMA/EMA`); +/// 2. `close >= ma` 判定 `多头`,否则 `空头`; +/// 3. `ma_now >= ma_prev` 判定 `向上`,否则 `向下`; +/// 4. `abs(close-ma)/ma * 10000 > th` 判定 `远离`,否则 `靠近`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1SMA#5T100_分类V221203_多头_向上_靠近_0')` +/// - `Signal('60分钟_D1EMA#12T80_分类V221203_空头_向下_远离_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod`:均线周期,默认 `5`; +/// - `th`:距离阈值(BP),默认 `100`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_ma_base_V221203", + template = "{freq}_D{di}{ma_type}#{timeperiod}T{th}_分类V221203", + opcode = "TasMaBaseV221203", + param_kind = "TasMaBaseV221203" +)] +pub fn tas_ma_base_v221203(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let ma_type = get_str_param(params, "ma_type", "SMA"); + let timeperiod = get_usize_param(params, "timeperiod", 5); + let th = get_usize_param(params, "th", 100) as f64; + + let cache_key = format!("{}_{}_{}", czsc.freq, ma_type, timeperiod); + update_ma_cache(czsc, &cache_key, ma_type, timeperiod, cache); + let ma = cache.series.get(&cache_key).unwrap(); + let bars = get_sub_elements(&czsc.bars_raw, di, 3); + if bars.len() < 2 { + return Vec::new(); + } + + let c = bars[bars.len() - 1].close; + let m = ma[czsc.bars_raw.len() - di]; + let m_prev = ma[czsc.bars_raw.len() - di - 1]; + + let v1 = if c >= m { "多头" } else { "空头" }; + let v2 = if m >= m_prev { "向上" } else { "向下" }; + let v3 = if ((c - m).abs() / m) * 10000.0 > th { + "远离" + } else { + "靠近" + }; + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}{}#{}T{}", di, ma_type, timeperiod, th as usize); + let k3 = "分类V221203"; + make_kline_signal_v3(&k1, &k2, k3, v1, v2, v3) +} + +/// tas_ma_base_V230313:单均线开平仓辅助信号(带重叠约束) +/// +/// 参数模板:`"{freq}_D{di}#{ma_type}#{timeperiod}MO{max_overlap}_BS辅助V230313"` +/// +/// 信号逻辑: +/// 1. 计算指定均线(`SMA/EMA`); +/// 2. 取倒数 `di` 截止的 `max_overlap+1` 根K线; +/// 3. 若最新 `close >= ma` 且窗口内并非全部 `close > ma`,判 `看多`; +/// 4. 若最新 `close < ma` 且窗口内并非全部 `close < ma`,判 `看空`; +/// 5. 否则判 `其他`;并用 `ma_now >= ma_prev` 判方向 `向上/向下`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1#SMA#5MO5_BS辅助V230313_看多_向上_任意_0')` +/// - `Signal('60分钟_D1#EMA#12MO5_BS辅助V230313_看空_向下_任意_0')` +/// - `Signal('60分钟_D1#SMA#5MO5_BS辅助V230313_其他_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod`:均线周期,默认 `5`; +/// - `max_overlap`:相同方向最大重叠窗口,默认 `5`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_ma_base_V230313", + template = "{freq}_D{di}#{ma_type}#{timeperiod}MO{max_overlap}_BS辅助V230313", + opcode = "TasMaBaseV230313", + param_kind = "TasMaBaseV230313" +)] +pub fn tas_ma_base_v230313(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let ma_type = get_str_param(params, "ma_type", "SMA"); + let timeperiod = get_usize_param(params, "timeperiod", 5); + let max_overlap = get_usize_param(params, "max_overlap", 5); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}#{}#{}MO{}", di, ma_type, timeperiod, max_overlap); + let k3 = "BS辅助V230313"; + + if max_overlap < 2 { + return Vec::new(); + } + + let cache_key = format!("{}_{}_{}", czsc.freq, ma_type, timeperiod); + update_ma_cache(czsc, &cache_key, ma_type, timeperiod, cache); + let ma = cache.series.get(&cache_key).unwrap(); + + let bars = get_sub_elements(&czsc.bars_raw, di, max_overlap + 1); + if bars.len() < max_overlap + 1 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let end = czsc.bars_raw.len() - di + 1; + let start = end - (max_overlap + 1); + let last_idx = end - 1; + let last_close = bars[bars.len() - 1].close; + let last_ma = ma[last_idx]; + + let all_above = (start..end).all(|i| czsc.bars_raw[i].close > ma[i]); + let all_below = (start..end).all(|i| czsc.bars_raw[i].close < ma[i]); + + let v1 = if last_close >= last_ma && !all_above { + "看多" + } else if last_close < last_ma && !all_below { + "看空" + } else { + "其他" + }; + + if v1 == "其他" { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let v2 = if ma[last_idx] >= ma[last_idx - 1] { + "向上" + } else { + "向下" + }; + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// tas_ma_round_V221206:笔端点触碰均线信号 +/// +/// 参数模板:`"{freq}_D{di}TH{th}#碰{ma_type}#{timeperiod}_BE辅助V221206"` +/// +/// 信号逻辑: +/// 1. 计算指定均线(`SMA/EMA`); +/// 2. 取倒数第 `di` 笔,提取其结束分型中间 NewBar 的原始K线; +/// 3. 计算该批原始K线对应均线均值 `last_ma`; +/// 4. 若上笔且 `abs(high-last_ma)/power_price < th/100`,判 `上碰`; +/// 5. 若下笔且 `abs(low-last_ma)/power_price < th/100`,判 `下碰`;否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1TH10#碰SMA#60_BE辅助V221206_上碰_任意_任意_0')` +/// - `Signal('60分钟_D1TH10#碰SMA#60_BE辅助V221206_下碰_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:指定倒数第 `di` 笔,默认 `1`; +/// - `th`:端点触碰阈值(百分比),默认 `10`; +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod`:均线周期,默认 `5`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_ma_round_V221206", + template = "{freq}_D{di}TH{th}#碰{ma_type}#{timeperiod}_BE辅助V221206", + opcode = "TasMaRoundV221206", + param_kind = "TasMaRoundV221206" +)] +pub fn tas_ma_round_v221206(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let th = get_usize_param(params, "th", 10) as f64; + let ma_type = get_str_param(params, "ma_type", "SMA"); + let timeperiod = get_usize_param(params, "timeperiod", 5); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}TH{}#碰{}#{}", di, th as usize, ma_type, timeperiod); + let k3 = "BE辅助V221206"; + let mut v1 = "其他"; + + let cache_key = format!("{}_{}_{}", czsc.freq, ma_type, timeperiod); + update_ma_cache(czsc, &cache_key, ma_type, timeperiod, cache); + let ma = cache.series.get(&cache_key).unwrap(); + let bar_idx_map: HashMap = czsc + .bars_raw + .iter() + .enumerate() + .map(|(i, b)| (b.id, i)) + .collect(); + + if czsc.bi_list.len() > di + 3 { + let last_bi = &czsc.bi_list[czsc.bi_list.len() - di]; + let mut ma_vals = Vec::new(); + if last_bi.fx_b.elements.len() > 1 { + let nb = &last_bi.fx_b.elements[1]; + for rb in &nb.elements { + if let Some(idx) = bar_idx_map.get(&rb.id) { + ma_vals.push(ma[*idx]); + } + } + } + + if !ma_vals.is_empty() { + let last_ma = ma_vals.iter().sum::() / ma_vals.len() as f64; + let bi_change = last_bi.get_power_price(); + if bi_change > 0.0 { + if last_bi.direction == Direction::Up + && (last_bi.get_high() - last_ma).abs() / bi_change < th / 100.0 + { + v1 = "上碰"; + } else if last_bi.direction == Direction::Down + && (last_bi.get_low() - last_ma).abs() / bi_change < th / 100.0 + { + v1 = "下碰"; + } + } + } + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_double_ma_V230511:双均线反向信号 +/// +/// 参数模板:`"{freq}_D{di}#{ma_type}#{t1}#{t2}_BS辅助V230511"` +/// +/// 信号逻辑: +/// 1. 计算 `t1/t2` 双均线,`t1 < t2`; +/// 2. 当前K线需为大实体(`solid >= max(upper, lower, mean_solid)`); +/// 3. `ma1 > ma2` 且当前大实体阴线,判 `看多`; +/// 4. `ma1 < ma2` 且当前大实体阳线,判 `看空`; +/// 5. 在同侧连续区间内若仅出现一次对应大实体且区间长度 `< t2/2`,则 `v2=第一个`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1#SMA#5#20_BS辅助V230511_看多_第一个_任意_0')` +/// - `Signal('60分钟_D1#SMA#5#20_BS辅助V230511_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `t1`:快线周期,默认 `5`; +/// - `t2`:慢线周期,默认 `20`; +/// - `ma_type`:均线类型,默认 `SMA`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_double_ma_V230511", + template = "{freq}_D{di}#{ma_type}#{t1}#{t2}_BS辅助V230511", + opcode = "TasDoubleMaV230511", + param_kind = "TasDoubleMaV230511" +)] +pub fn tas_double_ma_v230511(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let t1 = get_usize_param(params, "t1", 5); + let t2 = get_usize_param(params, "t2", 20); + let ma_type = get_str_param(params, "ma_type", "SMA"); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}#{}#{}#{}", di, ma_type, t1, t2); + let k3 = "BS辅助V230511"; + let mut v1 = "其他"; + let mut v2 = "任意"; + + if t1 >= t2 || czsc.bars_raw.len() < t2 + 10 || di == 0 || di > czsc.bars_raw.len() { + return make_kline_signal_v2(&k1, &k2, k3, v1, v2); + } + + let key1 = format!("{}_{}_{}", czsc.freq, ma_type, t1); + let key2 = format!("{}_{}_{}", czsc.freq, ma_type, t2); + update_ma_cache(czsc, &key1, ma_type, t1, cache); + update_ma_cache(czsc, &key2, ma_type, t2, cache); + let ma1 = cache.series.get(&key1).unwrap(); + let ma2 = cache.series.get(&key2).unwrap(); + let bar_idx_map: HashMap = czsc + .bars_raw + .iter() + .enumerate() + .map(|(i, b)| (b.id, i)) + .collect(); + + let bars = get_sub_elements(&czsc.bars_raw, di, t2 + 1); + if bars.is_empty() { + return make_kline_signal_v2(&k1, &k2, k3, v1, v2); + } + + let mean_solid = bars.iter().map(|x| (x.open - x.close).abs()).sum::() / bars.len() as f64; + let bar = &czsc.bars_raw[czsc.bars_raw.len() - di]; + let bar_upper = bar.high - bar.open.max(bar.close); + let bar_lower = bar.open.min(bar.close) - bar.low; + let bar_solid = (bar.open - bar.close).abs(); + let solid_th = mean_solid.max(bar_upper).max(bar_lower); + if bar_solid < solid_th { + return make_kline_signal_v2(&k1, &k2, k3, v1, v2); + } + + let Some(&idx) = bar_idx_map.get(&bar.id) else { + return make_kline_signal_v2(&k1, &k2, k3, v1, v2); + }; + if ma1[idx] > ma2[idx] && bar.close < bar.open { + v1 = "看多"; + let mut right_bars = Vec::new(); + for x in bars.iter().rev() { + let Some(&xi) = bar_idx_map.get(&x.id) else { + continue; + }; + if ma1[xi] > ma2[xi] { + right_bars.push((x.open - x.close).abs() > solid_th && x.close < x.open); + } else { + break; + } + } + let cnt = right_bars.iter().filter(|x| **x).count(); + if (right_bars.len() as f64) < (t2 as f64 / 2.0) && cnt == 1 { + v2 = "第一个"; + } + } else if ma1[idx] < ma2[idx] && bar.close > bar.open { + v1 = "看空"; + let mut right_bars = Vec::new(); + for x in bars.iter().rev() { + let Some(&xi) = bar_idx_map.get(&x.id) else { + continue; + }; + if ma1[xi] < ma2[xi] { + right_bars.push((x.open - x.close).abs() > solid_th && x.close > x.open); + } else { + break; + } + } + let cnt = right_bars.iter().filter(|x| **x).count(); + if (right_bars.len() as f64) < (t2 as f64 / 2.0) && cnt == 1 { + v2 = "第一个"; + } + } + + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// tas_macd_base_V221028:MACD/DIF/DEA 多空与方向信号 +/// +/// 参数模板:`"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}#{key}_BS辅助V221028"` +/// +/// 信号逻辑: +/// 1. 计算 MACD 三序列; +/// 2. 依据 `key` 选择 `MACD/DIF/DEA`; +/// 3. 当前值 `>=0` 判定 `多头`,否则 `空头`; +/// 4. 当前值 `>=` 前值判定 `向上`,否则 `向下`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1MACD12#26#9#MACD_BS辅助V221028_多头_向上_任意_0')` +/// - `Signal('60分钟_D1MACD12#26#9#DIF_BS辅助V221028_空头_向下_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `fastperiod/slowperiod/signalperiod`:MACD参数,默认 `12/26/9`; +/// - `key`:`MACD`、`DIF` 或 `DEA`,默认 `MACD`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_macd_base_V221028", + template = "{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS辅助V221028", + opcode = "TasMacdBaseV221028", + param_kind = "TasMacdBaseV221028" +)] +pub fn tas_macd_base_v221028(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let key = get_str_param(params, "key", "MACD").to_uppercase(); + let fastperiod = get_usize_param(params, "fastperiod", 12); + let slowperiod = get_usize_param(params, "slowperiod", 26); + let signalperiod = get_usize_param(params, "signalperiod", 9); + + let cache_key = format!("MACD{}#{}#{}", fastperiod, slowperiod, signalperiod); + update_macd_cache( + czsc, + &cache_key, + fastperiod, + slowperiod, + signalperiod, + cache, + ); + + let macd_cache = cache.macd.get(&cache_key).unwrap(); + let series = match key.as_str() { + "DIF" => &macd_cache.dif, + "DEA" => &macd_cache.dea, + _ => &macd_cache.macd, + }; + let sub = get_sub_elements(series, di, 2); + if sub.len() < 2 { + return Vec::new(); + } + let prev = sub[sub.len() - 2]; + let curr = sub[sub.len() - 1]; + let v1 = if curr >= 0.0 { "多头" } else { "空头" }; + let v2 = if curr >= prev { "向上" } else { "向下" }; + + let k1 = czsc.freq.to_string(); + let k2 = format!( + "D{}MACD{}#{}#{}#{}", + di, fastperiod, slowperiod, signalperiod, key + ); + let k3 = "BS辅助V221028"; + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// tas_macd_direct_V221106:MACD柱方向信号 +/// +/// 参数模板:`"{freq}_D{di}K#MACD{fastperiod}#{slowperiod}#{signalperiod}方向_BS辅助V221106"` +/// +/// 信号逻辑: +/// 1. 计算 MACD 柱序列; +/// 2. 取倒数 `di` 对齐的最近 3 根柱值; +/// 3. 严格递增判定 `向上`,严格递减判定 `向下`,否则 `模糊`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1K#MACD12#26#9方向_BS辅助V221106_向上_任意_任意_0')` +/// - `Signal('60分钟_D1K#MACD12#26#9方向_BS辅助V221106_向下_任意_任意_0')` +/// - `Signal('60分钟_D1K#MACD12#26#9方向_BS辅助V221106_模糊_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `fastperiod/slowperiod/signalperiod`:MACD参数,默认 `12/26/9`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_macd_direct_V221106", + template = "{freq}_D{di}K#MACD{fastperiod}#{slowperiod}#{signalperiod}方向_BS辅助V221106", + opcode = "TasMacdDirectV221106", + param_kind = "TasMacdDirectV221106" +)] +pub fn tas_macd_direct_v221106( + czsc: &CZSC, + params: &ParamView, + cache: &mut TaCache, +) -> Vec { + let di = get_usize_param(params, "di", 1); + let fastperiod = get_usize_param(params, "fastperiod", 12); + let slowperiod = get_usize_param(params, "slowperiod", 26); + let signalperiod = get_usize_param(params, "signalperiod", 9); + + let cache_key = format!("MACD{}#{}#{}", fastperiod, slowperiod, signalperiod); + update_macd_cache( + czsc, + &cache_key, + fastperiod, + slowperiod, + signalperiod, + cache, + ); + let macd_cache = cache.macd.get(&cache_key).unwrap(); + let macd = get_sub_elements(&macd_cache.macd, di, 3); + + let v1 = if macd.len() != 3 { + "模糊" + } else if macd[2] > macd[1] && macd[1] > macd[0] { + "向上" + } else if macd[2] < macd[1] && macd[1] < macd[0] { + "向下" + } else { + "模糊" + }; + + let k1 = czsc.freq.to_string(); + let k2 = format!( + "D{}K#MACD{}#{}#{}方向", + di, fastperiod, slowperiod, signalperiod + ); + let k3 = "BS辅助V221106"; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_macd_power_V221108:MACD强弱分层信号 +/// +/// 参数模板:`"{freq}_D{di}K#MACD{fastperiod}#{slowperiod}#{signalperiod}强弱_BS辅助V221108"` +/// +/// 信号逻辑: +/// 1. 计算当前 `DIF/DEA`; +/// 2. `dif >= dea >= 0` 判定 `超强`; +/// 3. `dif - dea > 0` 判定 `强势`; +/// 4. `dif <= dea <= 0` 判定 `超弱`; +/// 5. `dif - dea < 0` 判定 `弱势`,其余为 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1K#MACD12#26#9强弱_BS辅助V221108_超强_任意_任意_0')` +/// - `Signal('60分钟_D1K#MACD12#26#9强弱_BS辅助V221108_弱势_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `fastperiod/slowperiod/signalperiod`:MACD参数,默认 `12/26/9`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_macd_power_V221108", + template = "{freq}_D{di}K#MACD{fastperiod}#{slowperiod}#{signalperiod}强弱_BS辅助V221108", + opcode = "TasMacdPowerV221108", + param_kind = "TasMacdPowerV221108" +)] +pub fn tas_macd_power_v221108(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let fastperiod = get_usize_param(params, "fastperiod", 12); + let slowperiod = get_usize_param(params, "slowperiod", 26); + let signalperiod = get_usize_param(params, "signalperiod", 9); + + let cache_key = format!("MACD{}#{}#{}", fastperiod, slowperiod, signalperiod); + update_macd_cache( + czsc, + &cache_key, + fastperiod, + slowperiod, + signalperiod, + cache, + ); + let macd_cache = cache.macd.get(&cache_key).unwrap(); + + let mut v1 = "其他"; + if czsc.bars_raw.len() > di + 10 { + let idx = czsc.bars_raw.len() - di; + let dif = macd_cache.dif[idx]; + let dea = macd_cache.dea[idx]; + if dif >= dea && dea >= 0.0 { + v1 = "超强"; + } else if dif - dea > 0.0 { + v1 = "强势"; + } else if dif <= dea && dea <= 0.0 { + v1 = "超弱"; + } else if dif - dea < 0.0 { + v1 = "弱势"; + } + } + + let k1 = czsc.freq.to_string(); + let k2 = format!( + "D{}K#MACD{}#{}#{}强弱", + di, fastperiod, slowperiod, signalperiod + ); + let k3 = "BS辅助V221108"; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_first_bs_V230217:均线结合K线形态的一买一卖辅助 +/// +/// 参数模板:`"{freq}_D{di}N{n}#{ma_type}#{timeperiod}_BS1辅助V230217"` +/// +/// 信号逻辑: +/// 1. 在最近 `n` 根K线上计算均线并构造 `sma/low/high/open/close` 序列; +/// 2. 一买条件: +/// - `sma > low` 全满足; +/// - 阴线占比 `> 60%`; +/// - 最近3根出现新低; +/// - 最后一根收盘在均线上方; +/// 3. 一卖条件与上面对称; +/// 4. 满足则输出 `一买/一卖`,否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N10#SMA#5_BS1辅助V230217_一买_任意_任意_0')` +/// - `Signal('60分钟_D1N10#SMA#5_BS1辅助V230217_一卖_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `n`:窗口大小,默认 `10`; +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod`:均线周期,默认 `5`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_first_bs_V230217", + template = "{freq}_D{di}N{n}#{ma_type}#{timeperiod}_BS1辅助V230217", + opcode = "TasFirstBsV230217", + param_kind = "TasFirstBsV230217" +)] +pub fn tas_first_bs_v230217(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let n = get_usize_param(params, "n", 10); + let ma_type = get_str_param(params, "ma_type", "SMA"); + let timeperiod = get_usize_param(params, "timeperiod", 5); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}N{}#{}#{}", di, n, ma_type, timeperiod); + let k3 = "BS1辅助V230217"; + let mut v1 = "其他"; + + if di == 0 || czsc.bars_raw.len() < n + 5 || n < 4 || di > czsc.bars_raw.len() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let cache_key = format!("{}_{}_{}", czsc.freq, ma_type, timeperiod); + update_ma_cache(czsc, &cache_key, ma_type, timeperiod, cache); + let ma = cache.series.get(&cache_key).unwrap(); + + let end = czsc.bars_raw.len() - di + 1; + let start = end - n; + let bars = &czsc.bars_raw[start..end]; + let mut sma = Vec::with_capacity(n); + let mut low = Vec::with_capacity(n); + let mut high = Vec::with_capacity(n); + let mut open = Vec::with_capacity(n); + let mut close = Vec::with_capacity(n); + for (i, b) in bars.iter().enumerate().take(n) { + let idx = start + i; + sma.push(ma[idx]); + low.push(b.low); + high.push(b.high); + open.push(b.open); + close.push(b.close); + } + + let condition_1_down = sma.iter().zip(low.iter()).all(|(a, b)| *a > *b); + let condition_1_up = sma.iter().zip(high.iter()).all(|(a, b)| *a < *b); + + let n1 = close + .iter() + .zip(open.iter()) + .filter(|(c, o)| **c < **o) + .count(); + let m1 = close + .iter() + .zip(open.iter()) + .filter(|(c, o)| **c > **o) + .count(); + let condition_2_down = (n1 as f64 / n as f64) > 0.6; + let condition_2_up = (m1 as f64 / n as f64) > 0.6; + + let low_last3_min = low[n - 3..].iter().copied().fold(f64::INFINITY, f64::min); + let low_prev_min = low[..n - 3].iter().copied().fold(f64::INFINITY, f64::min); + let condition_3_down = low_last3_min < low_prev_min; + let high_last3_max = high[n - 3..] + .iter() + .copied() + .fold(f64::NEG_INFINITY, f64::max); + let high_prev_max = high[..n - 3] + .iter() + .copied() + .fold(f64::NEG_INFINITY, f64::max); + let condition_3_up = high_last3_max > high_prev_max; + + let condition_4_down = close[n - 1] > sma[n - 1]; + let condition_4_up = close[n - 1] < sma[n - 1]; + + if condition_1_down && condition_2_down && condition_3_down && condition_4_down { + v1 = "一买"; + } else if condition_1_up && condition_2_up && condition_3_up && condition_4_up { + v1 = "一卖"; + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_second_bs_V230228:均线结合K线形态的二买二卖辅助 +/// +/// 参数模板:`"{freq}_D{di}N{n}#{ma_type}#{timeperiod}_BS2辅助V230228"` +/// +/// 信号逻辑: +/// 1. 在最近 `n` 根K线上计算均线; +/// 2. 二买条件: +/// - `sma[-1]` 为窗口新高且 `sma[-1] > sma[-2]`; +/// - 最新收盘 `close[-1] > sma[-1]`; +/// - 最近3根存在 `low < sma`; +/// 3. 二卖条件与上面对称; +/// 4. 满足则输出 `二买/二卖`,否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N21#SMA#20_BS2辅助V230228_二买_任意_任意_0')` +/// - `Signal('60分钟_D1N21#SMA#20_BS2辅助V230228_二卖_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `n`:窗口大小,默认 `21`; +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod`:均线周期,默认 `20`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_second_bs_V230228", + template = "{freq}_D{di}N{n}#{ma_type}#{timeperiod}_BS2辅助V230228", + opcode = "TasSecondBsV230228", + param_kind = "TasSecondBsV230228" +)] +pub fn tas_second_bs_v230228(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let n = get_usize_param(params, "n", 21); + let ma_type = get_str_param(params, "ma_type", "SMA"); + let timeperiod = get_usize_param(params, "timeperiod", 20); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}N{}#{}#{}", di, n, ma_type, timeperiod); + let k3 = "BS2辅助V230228"; + let mut v1 = "其他"; + + if di == 0 || czsc.bars_raw.len() < n + 5 || n < 3 || di > czsc.bars_raw.len() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let cache_key = format!("{}_{}_{}", czsc.freq, ma_type, timeperiod); + update_ma_cache(czsc, &cache_key, ma_type, timeperiod, cache); + let ma = cache.series.get(&cache_key).unwrap(); + + let end = czsc.bars_raw.len() - di + 1; + let start = end - n; + let bars = &czsc.bars_raw[start..end]; + let mut sma = Vec::with_capacity(n); + for i in 0..n { + sma.push(ma[start + i]); + } + + let min_three = bars[n - 3..] + .iter() + .zip(sma[n - 3..].iter()) + .any(|(b, s)| *s > b.low); + let max_three = bars[n - 3..] + .iter() + .zip(sma[n - 3..].iter()) + .any(|(b, s)| b.high > *s); + + let sma_max = sma.iter().copied().fold(f64::NEG_INFINITY, f64::max); + let sma_min = sma.iter().copied().fold(f64::INFINITY, f64::min); + let close_last = bars[n - 1].close; + + if sma_max == sma[n - 1] && sma[n - 1] > sma[n - 2] && close_last > sma[n - 1] && min_three { + v1 = "二买"; + } else if sma_min == sma[n - 1] + && sma[n - 1] < sma[n - 2] + && close_last < sma[n - 1] + && max_three + { + v1 = "二卖"; + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_second_bs_V230303:利用笔和均线辅助二买二卖 +/// +/// 参数模板:`"{freq}_D{di}{ma_type}#{timeperiod}_BS2辅助V230303"` +/// +/// 信号逻辑: +/// 1. 取倒数 `di` 截止最近13笔,取最后一笔与其首尾原始K线; +/// 2. 二买条件: +/// - 最后一笔为向下; +/// - 最后一笔末K最低点跌破均线; +/// - 最近5笔最低点为13笔全局最低; +/// - 该笔首K均线值 < 末K均线值(均线向上); +/// 3. 二卖条件与上面对称; +/// 4. 满足则输出 `二买/二卖`,否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1SMA#30_BS2辅助V230303_二买_任意_任意_0')` +/// - `Signal('60分钟_D1SMA#30_BS2辅助V230303_二卖_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:指定倒数第 `di` 笔,默认 `1`; +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod`:均线周期,默认 `30`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_second_bs_V230303", + template = "{freq}_D{di}{ma_type}#{timeperiod}_BS2辅助V230303", + opcode = "TasSecondBsV230303", + param_kind = "TasSecondBsV230303" +)] +pub fn tas_second_bs_v230303(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let ma_type = get_str_param(params, "ma_type", "SMA"); + let timeperiod = get_usize_param(params, "timeperiod", 30); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}{}#{}", di, ma_type, timeperiod); + let k3 = "BS2辅助V230303"; + let mut v1 = "其他"; + + if di == 0 || czsc.bi_list.len() < di + 13 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let cache_key = format!("{}_{}_{}", czsc.freq, ma_type, timeperiod); + update_ma_cache(czsc, &cache_key, ma_type, timeperiod, cache); + let ma = cache.series.get(&cache_key).unwrap(); + let bar_idx_map: HashMap = czsc + .bars_raw + .iter() + .enumerate() + .map(|(i, b)| (b.id, i)) + .collect(); + + let bi_list = get_sub_elements(&czsc.bi_list, di, 13); + if bi_list.len() < 13 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let last_bi = &bi_list[bi_list.len() - 1]; + let rb = last_bi.get_raw_bars(); + if rb.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let first_bar = &rb[0]; + let last_bar = &rb[rb.len() - 1]; + let Some(&first_idx) = bar_idx_map.get(&first_bar.id) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + let Some(&last_idx) = bar_idx_map.get(&last_bar.id) else { + return make_kline_signal_v1(&k1, &k2, k3, v1); + }; + + let min_low_5 = bi_list[bi_list.len() - 5..] + .iter() + .map(|x| x.get_low()) + .fold(f64::INFINITY, f64::min); + let min_low_all = bi_list + .iter() + .map(|x| x.get_low()) + .fold(f64::INFINITY, f64::min); + if last_bi.direction == Direction::Down + && last_bar.low < ma[last_idx] + && min_low_5 == min_low_all + && ma[first_idx] < ma[last_idx] + { + v1 = "二买"; + } + + let max_high_5 = bi_list[bi_list.len() - 5..] + .iter() + .map(|x| x.get_high()) + .fold(f64::NEG_INFINITY, f64::max); + let max_high_all = bi_list + .iter() + .map(|x| x.get_high()) + .fold(f64::NEG_INFINITY, f64::max); + if last_bi.direction == Direction::Up + && last_bar.high > ma[last_idx] + && max_high_5 == max_high_all + && ma[first_idx] > ma[last_idx] + { + v1 = "二卖"; + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_hlma_V230301:HMA/LMA 多空信号 +/// +/// 参数模板:`"{freq}_D{di}#{ma_type}#{timeperiod}HLMA_BS辅助V230301"` +/// +/// 信号逻辑: +/// 1. 取最近 `timeperiod` 根K线,计算 `hma=high均值`、`lma=low均值`; +/// 2. 若 `close_now > hma` 且 `close_prev <= ma_prev`,判 `看多`; +/// 3. 若 `close_now < lma` 且 `close_prev >= ma_prev`,判 `看空`; +/// 4. 否则判 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1#SMA#3HLMA_BS辅助V230301_看多_任意_任意_0')` +/// - `Signal('60分钟_D1#SMA#3HLMA_BS辅助V230301_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod`:窗口周期,默认 `3`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_hlma_V230301", + template = "{freq}_D{di}#{ma_type}#{timeperiod}HLMA_BS辅助V230301", + opcode = "TasHlmaV230301", + param_kind = "TasHlmaV230301" +)] +pub fn tas_hlma_v230301(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let ma_type = get_str_param(params, "ma_type", "SMA"); + let timeperiod = get_usize_param(params, "timeperiod", 3); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}#{}#{}HLMA", di, ma_type, timeperiod); + let k3 = "BS辅助V230301"; + let mut v1 = "其他"; + + if di == 0 || di > czsc.bars_raw.len() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let cache_key = format!("{}_{}_{}", czsc.freq, ma_type, timeperiod); + update_ma_cache(czsc, &cache_key, ma_type, timeperiod, cache); + let ma = cache.series.get(&cache_key).unwrap(); + let bar_idx_map: HashMap = czsc + .bars_raw + .iter() + .enumerate() + .map(|(i, b)| (b.id, i)) + .collect(); + + let bars = get_sub_elements(&czsc.bars_raw, di, timeperiod); + if bars.len() >= 2 { + let hma = bars.iter().map(|x| x.high).sum::() / bars.len() as f64; + let lma = bars.iter().map(|x| x.low).sum::() / bars.len() as f64; + let b1 = &bars[bars.len() - 1]; + let b2 = &bars[bars.len() - 2]; + if let Some(&b2_idx) = bar_idx_map.get(&b2.id) { + if b1.close > hma && b2.close <= ma[b2_idx] { + v1 = "看多"; + } else if b1.close < lma && b2.close >= ma[b2_idx] { + v1 = "看空"; + } + } + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_boll_cc_V230312:布林进出场信号 +/// +/// 参数模板:`"{freq}_D{di}BOLL{timeperiod}S{nbdev}SP{sp}_BS辅助V230312"` +/// +/// 信号逻辑: +/// 1. 计算 BOLL 中轨与上下轨; +/// 2. 计算 `bias = (close / mid - 1) * 10000`; +/// 3. `close < upper 且 bias < -sp` 判 `看空`,`close > lower 且 bias > sp` 判 `看多`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1BOLL20S20SP400_BS辅助V230312_看空_任意_任意_0')` +/// - `Signal('60分钟_D1BOLL20S20SP400_BS辅助V230312_看多_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `timeperiod`:BOLL 周期,默认 `20`; +/// - `nbdev`:标准差倍数 *10,默认 `20`; +/// - `sp`:偏离阈值(BP),默认 `400`。 +/// 对齐说明:与 Python `tas_boll_cc_V230312` 的 bias 判定和阈值方向一致。 +#[signal( + category = "kline", + name = "tas_boll_cc_V230312", + template = "{freq}_D{di}BOLL{timeperiod}S{nbdev}SP{sp}_BS辅助V230312", + opcode = "TasBollCcV230312", + param_kind = "TasBollCcV230312" +)] +pub fn tas_boll_cc_v230312(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let sp = get_usize_param(params, "sp", 400) as f64; + let timeperiod = get_usize_param(params, "timeperiod", 20); + let nbdev = get_usize_param(params, "nbdev", 20); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}BOLL{}S{}SP{}", di, timeperiod, nbdev, sp as usize); + let k3 = "BS辅助V230312"; + let mut v1 = "其他"; + if di == 0 || di > czsc.bars_raw.len() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let cache_key = format!("BOLL{}S{}", timeperiod, nbdev); + update_boll_cache(czsc, &cache_key, timeperiod, nbdev as f64 / 10.0, cache); + let boll = cache.boll.get(&cache_key).unwrap(); + let idx = czsc.bars_raw.len() - di; + let close = czsc.bars_raw[idx].close; + let mid = boll.mid[idx]; + let upper = boll.upper[idx]; + let lower = boll.lower[idx]; + if mid.is_finite() && upper.is_finite() && lower.is_finite() && mid.abs() > f64::EPSILON { + let bias = (close / mid - 1.0) * 10000.0; + if close < upper && bias < -sp { + v1 = "看空"; + } else if close > lower && bias > sp { + v1 = "看多"; + } + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_kdj_evc_V221201:KDJ 极值计数信号 +/// +/// 参数模板:`"{freq}_D{di}T{th}KDJ{fastk_period}#{slowk_period}#{slowd_period}#{key}值突破{c1}#{c2}_KDJ极值V221201"` +/// +/// 信号逻辑: +/// 1. 计算 `K/D/J` 序列并提取 `key`; +/// 2. 统计末端连续低于 `th` 或高于 `100-th` 的次数; +/// 3. 连续次数落入 `[c1, c2)` 时分别输出 `多头/空头`,并在 `v2` 标注计数。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1T10KDJ9#3#3#K值突破5#8_KDJ极值V221201_多头_C5_任意_0')` +/// - `Signal('60分钟_D1T10KDJ9#3#3#K值突破5#8_KDJ极值V221201_空头_C6_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `key`:取值 `K/D/J`,默认 `K`; +/// - `th`:极值阈值,默认 `10`; +/// - `count_range`:连续计数区间,默认 `[5, 8]`。 +/// 对齐说明:连续计数、`v2=Cx` 标注方式与 Python `tas_kdj_evc_V221201` 保持一致。 +#[signal( + category = "kline", + name = "tas_kdj_evc_V221201", + template = "{freq}_D{di}T{th}KDJ{fastk_period}#{slowk_period}#{slowd_period}#{key}值突破{c1}#{c2}_KDJ极值V221201", + opcode = "TasKdjEvcV221201", + param_kind = "TasKdjEvcV221201" +)] +pub fn tas_kdj_evc_v221201(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let mut key = get_str_param(params, "key", "K").to_uppercase(); + let th = get_usize_param(params, "th", 10); + let (c1, c2) = if let Some(Value::Array(arr)) = params.value("count_range") { + if arr.len() >= 2 { + let a = arr[0].as_u64().unwrap_or(5) as usize; + let b = arr[1].as_u64().unwrap_or(8) as usize; + (a, b) + } else { + (5, 8) + } + } else { + (5, 8) + }; + let fastk_period = get_usize_param(params, "fastk_period", 9); + let slowk_period = get_usize_param(params, "slowk_period", 3); + let slowd_period = get_usize_param(params, "slowd_period", 3); + + let k1 = czsc.freq.to_string(); + let k2 = format!( + "D{}T{}KDJ{}#{}#{}#{}值突破{}#{}", + di, th, fastk_period, slowk_period, slowd_period, key, c1, c2 + ); + let k3 = "KDJ极值V221201"; + let mut v1 = "其他"; + let mut v2 = "任意".to_string(); + if c2 <= c1 || di == 0 || di > czsc.bars_raw.len() { + return make_kline_signal_v2(&k1, &k2, k3, v1, &v2); + } + + let cache_key = format!("KDJ{}#{}#{}", fastk_period, slowk_period, slowd_period); + update_kdj_cache( + czsc, + &cache_key, + fastk_period, + slowk_period, + slowd_period, + cache, + ); + let kd = cache.kdj.get(&cache_key).unwrap(); + let bars = get_sub_elements(&czsc.bars_raw, di, 3 + c2); + if bars.len() == 3 + c2 { + key = key.to_lowercase(); + let end = czsc.bars_raw.len() - di + 1; + let start = end - bars.len(); + let mut vals = Vec::with_capacity(bars.len()); + for i in 0..bars.len() { + let idx = start + i; + let x = match key.as_str() { + "d" => kd.d[idx], + "j" => kd.j[idx], + _ => kd.k[idx], + }; + vals.push(x); + } + let long: Vec = vals.iter().map(|x| *x < th as f64).collect(); + let short: Vec = vals.iter().map(|x| *x > 100.0 - th as f64).collect(); + let lc = if *long.last().unwrap_or(&false) { + count_last_same(&long) + } else { + 0 + }; + let sc = if *short.last().unwrap_or(&false) { + count_last_same(&short) + } else { + 0 + }; + if c2 > lc && lc >= c1 { + v1 = "多头"; + v2 = format!("C{}", lc); + } + if c2 > sc && sc >= c1 { + v1 = "空头"; + v2 = format!("C{}", sc); + } + } + make_kline_signal_v2(&k1, &k2, k3, v1, &v2) +} + +/// tas_kdj_evc_V230401:KDJ 极值计数信号 +/// +/// 参数模板:`"{freq}_D{di}T{th}KDJ{fastk_period}#{slowk_period}#{slowd_period}#{key}值突破{min_count}#{max_count}_BS辅助V230401"` +/// +/// 信号逻辑: +/// 1. 计算 `K/D/J` 指标并提取目标序列; +/// 2. 末端连续低于阈值记多头计数,连续高于阈值记空头计数; +/// 3. 连续次数在 `[min_count, max_count)` 时输出 `多头/空头`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1T10KDJ9#3#3#K值突破5#8_BS辅助V230401_多头_任意_任意_0')` +/// - `Signal('60分钟_D1T10KDJ9#3#3#K值突破5#8_BS辅助V230401_空头_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `key`:`K/D/J`,默认 `K`; +/// - `th`:极值阈值,默认 `10`; +/// - `min_count/max_count`:连续计数区间,默认 `5/8`。 +/// 对齐说明:参数校验与计数边界严格对齐 Python `tas_kdj_evc_V230401`。 +#[signal( + category = "kline", + name = "tas_kdj_evc_V230401", + template = "{freq}_D{di}T{th}KDJ{fastk_period}#{slowk_period}#{slowd_period}#{key}值突破{min_count}#{max_count}_BS辅助V230401", + opcode = "TasKdjEvcV230401", + param_kind = "TasKdjEvcV230401" +)] +pub fn tas_kdj_evc_v230401(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let key = get_str_param(params, "key", "K").to_uppercase(); + let th = get_usize_param(params, "th", 10); + let min_count = get_usize_param(params, "min_count", 5); + let max_count = get_usize_param(params, "max_count", min_count + 3); + let fastk_period = get_usize_param(params, "fastk_period", 9); + let slowk_period = get_usize_param(params, "slowk_period", 3); + let slowd_period = get_usize_param(params, "slowd_period", 3); + + let k1 = czsc.freq.to_string(); + let k2 = format!( + "D{}T{}KDJ{}#{}#{}#{}值突破{}#{}", + di, th, fastk_period, slowk_period, slowd_period, key, min_count, max_count + ); + let k3 = "BS辅助V230401"; + let mut v1 = "其他"; + if min_count >= max_count + || !(1..100).contains(&th) + || !matches!(key.as_str(), "K" | "D" | "J") + || czsc.bars_raw.len() < di + max_count + 2 + { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let cache_key = format!("KDJ{}#{}#{}", fastk_period, slowk_period, slowd_period); + update_kdj_cache( + czsc, + &cache_key, + fastk_period, + slowk_period, + slowd_period, + cache, + ); + let kd = cache.kdj.get(&cache_key).unwrap(); + + let bars = get_sub_elements(&czsc.bars_raw, di, 3 + max_count); + let end = czsc.bars_raw.len() - di + 1; + let start = end - bars.len(); + let vals: Vec = (0..bars.len()) + .map(|i| { + let idx = start + i; + match key.as_str() { + "D" => kd.d[idx], + "J" => kd.j[idx], + _ => kd.k[idx], + } + }) + .collect(); + let long: Vec = vals.iter().map(|x| *x < th as f64).collect(); + let short: Vec = vals.iter().map(|x| *x > 100.0 - th as f64).collect(); + let lc = if *long.last().unwrap_or(&false) { + count_last_same(&long) + } else { + 0 + }; + let sc = if *short.last().unwrap_or(&false) { + count_last_same(&short) + } else { + 0 + }; + if max_count > lc && lc >= min_count { + v1 = "多头"; + } + if max_count > sc && sc >= min_count { + v1 = "空头"; + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_atr_break_V230424:ATR 通道突破 +/// +/// 参数模板:`"{freq}_D{di}ATR{timeperiod}T{th}突破_BS辅助V230424"` +/// +/// 信号逻辑: +/// 1. 取窗口 `HH/LL` 和当前 ATR; +/// 2. 若 `close` 落在 `HH-th*ATR` 与 `LL+th*ATR` 之间,输出 `其他`; +/// 3. 向上突破输出 `看多`,向下突破输出 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1ATR5T30突破_BS辅助V230424_看多_任意_任意_0')` +/// - `Signal('60分钟_D1ATR5T30突破_BS辅助V230424_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `timeperiod`:ATR 周期,默认 `5`; +/// - `th`:ATR 倍数(除以10),默认 `30`。 +/// 对齐说明:区间内返回 `其他` 的优先级与 Python `tas_atr_break_V230424` 一致。 +#[signal( + category = "kline", + name = "tas_atr_break_V230424", + template = "{freq}_D{di}ATR{timeperiod}T{th}突破_BS辅助V230424", + opcode = "TasAtrBreakV230424", + param_kind = "TasAtrBreakV230424" +)] +pub fn tas_atr_break_v230424(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let th = get_usize_param(params, "th", 30); + let timeperiod = get_usize_param(params, "timeperiod", 5); + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}ATR{}T{}突破", di, timeperiod, th); + let k3 = "BS辅助V230424"; + + if di == 0 || di > czsc.bars_raw.len() || czsc.bars_raw.len() < 3 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let cache_key = format!("ATR{}", timeperiod); + update_atr_cache(czsc, &cache_key, timeperiod, cache); + let atr_series = cache.series.get(&cache_key).unwrap(); + let bars = get_sub_elements(&czsc.bars_raw, di, timeperiod); + let hh = bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let ll = bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let idx = czsc.bars_raw.len() - di; + let bar = &czsc.bars_raw[idx]; + let atr = atr_series[idx]; + let thf = th as f64 / 10.0; + + let v1 = if hh - thf * atr > bar.close && bar.close > ll + thf * atr { + "其他" + } else if bar.close > ll + thf * atr { + "看多" + } else if bar.close < hh - thf * atr { + "看空" + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_ma_system_V230513:均线系统多空排列 +/// +/// 参数模板:`"{freq}_D{di}SMA{ma_seq}_均线系统V230513"` +/// +/// 信号逻辑: +/// 1. 计算 `ma_seq` 中各周期 SMA; +/// 2. 当前值严格递减判 `多头排列`; +/// 3. 当前值严格递增判 `空头排列`,否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1SMA5#10#20_均线系统V230513_多头排列_任意_任意_0')` +/// - `Signal('60分钟_D1SMA5#10#20_均线系统V230513_空头排列_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `ma_seq`:均线周期串,默认 `5#10#20`。 +/// 对齐说明:排列判定方向与 Python `tas_ma_system_V230513` 完全一致。 +#[signal( + category = "kline", + name = "tas_ma_system_V230513", + template = "{freq}_D{di}SMA{ma_seq}_均线系统V230513", + opcode = "TasMaSystemV230513", + param_kind = "TasMaSystemV230513" +)] +pub fn tas_ma_system_v230513(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let ma_seq_str = get_str_param(params, "ma_seq", "5#10#20"); + let ma_seq: Vec = ma_seq_str + .split('#') + .filter_map(|x| x.parse::().ok()) + .collect(); + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}SMA{}", di, ma_seq_str); + let k3 = "均线系统V230513"; + let mut v1 = "其他"; + + if ma_seq.is_empty() || di == 0 || di > czsc.bars_raw.len() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + for ma in &ma_seq { + let key = format!("{}_SMA_{}", czsc.freq, ma); + update_ma_cache(czsc, &key, "SMA", *ma, cache); + } + let max_ma = *ma_seq.iter().max().unwrap(); + if czsc.bars_raw.len() < max_ma + di + 10 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let idx = czsc.bars_raw.len() - di; + let ma_vals: Vec = ma_seq + .iter() + .map(|x| { + let key = format!("{}_SMA_{}", czsc.freq, x); + cache.series.get(&key).unwrap()[idx] + }) + .collect(); + if ma_vals.windows(2).all(|w| w[0] > w[1]) { + v1 = "多头排列"; + } else if ma_vals.windows(2).all(|w| w[0] < w[1]) { + v1 = "空头排列"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_ma_cohere_V230512:均线系统粘合/扩散状态 +/// +/// 参数模板:`"{freq}_D{di}SMA{ma_seq}_均线系统V230512"` +/// +/// 信号逻辑: +/// 1. 计算 `ma_seq` 各条 SMA,并构造最近 100 根“均线最大值/最小值 - 1”序列; +/// 2. 用前 80 根计算标准差 `ret_std`; +/// 3. 最近 20 根中,`ret < 0.5 * ret_std` 达到 16 次判 `粘合`; +/// 4. 最近 20 根中,`ret > 1.0 * ret_std` 达到 16 次判 `扩散`(覆盖前者)。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1SMA5#13#21#34#55_均线系统V230512_粘合_任意_任意_0')` +/// - `Signal('60分钟_D1SMA5#13#21#34#55_均线系统V230512_扩散_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `ma_seq`:均线周期序列(`#` 分隔),默认 `5#13#21#34#55`。 +/// 对齐说明:阈值与覆盖顺序对齐 Python `tas_ma_cohere_V230512`。 +#[signal( + category = "kline", + name = "tas_ma_cohere_V230512", + template = "{freq}_D{di}SMA{ma_seq}_均线系统V230512", + opcode = "TasMaCohereV230512", + param_kind = "TasMaCohereV230512" +)] +pub fn tas_ma_cohere_v230512(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let ma_seq_str = get_str_param(params, "ma_seq", "5#13#21#34#55"); + let ma_seq: Vec = ma_seq_str + .split('#') + .filter_map(|x| x.parse::().ok()) + .collect(); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}SMA{}", di, ma_seq_str); + let k3 = "均线系统V230512"; + let mut v1 = "其他"; + + if ma_seq.is_empty() || di == 0 || di > czsc.bars_raw.len() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + for ma in &ma_seq { + let key = format!("SMA#{}", ma); + update_ma_cache(czsc, &key, "SMA", *ma, cache); + } + let max_ma = *ma_seq.iter().max().unwrap_or(&0); + if czsc.bars_raw.len() < max_ma + di + 10 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let bars = get_sub_elements(&czsc.bars_raw, di, 100); + if bars.len() < 20 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let id_to_idx = bar_index_map(czsc); + + let mut ret_seq = Vec::with_capacity(bars.len()); + for bar in bars { + let Some(idx) = id_to_idx.get(&bar.id).copied() else { + continue; + }; + let mut min_v = f64::INFINITY; + let mut max_v = f64::NEG_INFINITY; + for ma in &ma_seq { + let key = format!("SMA#{}", ma); + let val = cache + .series + .get(&key) + .and_then(|s| s.get(idx)) + .copied() + .unwrap_or(f64::NAN); + min_v = min_v.min(val); + max_v = max_v.max(val); + } + let ret = max_v / min_v - 1.0; + ret_seq.push(ret); + } + + if ret_seq.len() < 20 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let base = &ret_seq[..ret_seq.len() - 20]; + if base.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let mean = base.iter().sum::() / base.len() as f64; + let var = base.iter().map(|x| (x - mean).powi(2)).sum::() / base.len() as f64; + let ret_std = var.sqrt(); + if !ret_std.is_finite() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let tail = &ret_seq[ret_seq.len() - 20..]; + let tight = tail.iter().filter(|x| **x < 0.5 * ret_std).count(); + if tight >= 16 { + v1 = "粘合"; + } + let spread = tail.iter().filter(|x| **x > ret_std).count(); + if spread >= 16 { + v1 = "扩散"; + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_dif_layer_V241010:DIF 三层分类 +/// +/// 参数模板:`"{freq}_DIF分层W{w}T{t}_完全分类V241010"` +/// +/// 信号逻辑: +/// 1. 取最近 `w` 根 DIF,计算绝对值最大幅度基准 `r`; +/// 2. `|dif_last| > r * t` 且符号为负,判 `空头远离`; +/// 3. `|dif_last| > r * t` 且符号为正,判 `多头远离`,否则 `零轴附近`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_DIF分层W100T30_完全分类V241010_空头远离_任意_任意_0')` +/// - `Signal('60分钟_DIF分层W100T30_完全分类V241010_零轴附近_任意_任意_0')` +/// +/// 参数说明: +/// - `w`:观察窗口长度,默认 `100`; +/// - `t`:远离阈值倍率,默认 `30`。 +/// 对齐说明:分层阈值口径与 Python `tas_dif_layer_V241010` 一致。 +#[signal( + category = "kline", + name = "tas_dif_layer_V241010", + template = "{freq}_DIF分层W{w}T{t}_完全分类V241010", + opcode = "TasDifLayerV241010", + param_kind = "TasDifLayerV241010" +)] +pub fn tas_dif_layer_v241010(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let w = get_usize_param(params, "w", 100); + let t = get_usize_param(params, "t", 30); + let k1 = czsc.freq.to_string(); + let k2 = format!("DIF分层W{}T{}", w, t); + let k3 = "完全分类V241010"; + let mut v1 = "其他"; + if czsc.bars_raw.len() < w + 50 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let cache_key = "MACD12#26#9"; + update_macd_cache(czsc, cache_key, 12, 26, 9, cache); + let mc = cache.macd.get(cache_key).unwrap(); + let diffs = get_sub_elements(&mc.dif, 1, w); + let r = diffs + .iter() + .map(|x| x.abs()) + .fold(f64::NEG_INFINITY, f64::max) + / 100.0; + let last = *diffs.last().unwrap(); + if last < 0.0 && last.abs() > r * t as f64 { + v1 = "空头远离"; + } else if last > 0.0 && last.abs() > r * t as f64 { + v1 = "多头远离"; + } else { + v1 = "零轴附近"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_cross_status_V230619:0轴上下金死叉次数 +/// +/// 参数模板:`"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_金死叉V230619"` +/// +/// 信号逻辑: +/// 1. 取近 100 根 DIF/DEA 并截取最近过零后的有效段; +/// 2. 在 0 轴上下分别统计金叉/死叉次数; +/// 3. 若当根形成有效交叉,输出 `0轴上/下金叉(死叉)第N次`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1MACD12#26#9_金死叉V230619_0轴下金叉第1次_任意_任意_0')` +/// - `Signal('60分钟_D1MACD12#26#9_金死叉V230619_0轴上死叉第2次_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `fastperiod/slowperiod/signalperiod`:MACD 参数,默认 `12/26/9`。 +/// 对齐说明:过零截取与交叉计次逻辑对齐 Python `tas_cross_status_V230619`。 +#[signal( + category = "kline", + name = "tas_cross_status_V230619", + template = "{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_金死叉V230619", + opcode = "TasCrossStatusV230619", + param_kind = "TasCrossStatusV230619" +)] +pub fn tas_cross_status_v230619( + czsc: &CZSC, + params: &ParamView, + cache: &mut TaCache, +) -> Vec { + let di = get_usize_param(params, "di", 1); + let fastperiod = get_usize_param(params, "fastperiod", 12); + let slowperiod = get_usize_param(params, "slowperiod", 26); + let signalperiod = get_usize_param(params, "signalperiod", 9); + let cache_key = format!("MACD{}#{}#{}", fastperiod, slowperiod, signalperiod); + update_macd_cache( + czsc, + &cache_key, + fastperiod, + slowperiod, + signalperiod, + cache, + ); + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}MACD{}#{}#{}", di, fastperiod, slowperiod, signalperiod); + let k3 = "金死叉V230619"; + let mut v1 = "其他".to_string(); + + let mc = cache.macd.get(&cache_key).unwrap(); + let dif = get_sub_elements(&mc.dif, di, 100); + let dea = get_sub_elements(&mc.dea, di, 100); + if dif.len() >= 100 && dea.len() >= 100 { + let num_k = cross_zero_axis(dif, dea); + let dif_temp = get_sub_elements(dif, di, num_k); + let dea_temp = get_sub_elements(dea, di, num_k); + let dl = dif[dif.len() - 1]; + let d2 = dif[dif.len() - 2]; + let el = dea[dea.len() - 1]; + let e2 = dea[dea.len() - 2]; + if dl < 0.0 && el < 0.0 { + let down_num_sc = down_cross_count(dif_temp, dea_temp); + let down_num_jc = down_cross_count(dea_temp, dif_temp); + if dl > el && d2 < e2 { + v1 = format!("0轴下金叉第{}次", down_num_jc); + } else if dl < el && d2 > e2 { + v1 = format!("0轴下死叉第{}次", down_num_sc); + } + } else if dl > 0.0 && el > 0.0 { + let up_num_sc = down_cross_count(dif_temp, dea_temp); + let up_num_jc = down_cross_count(dea_temp, dif_temp); + if dl > el && d2 < e2 { + v1 = format!("0轴上金叉第{}次", up_num_jc); + } else if dl < el && d2 > e2 { + v1 = format!("0轴上死叉第{}次", up_num_sc); + } + } + } + make_kline_signal_v1(&k1, &k2, k3, &v1) +} + +/// tas_cross_status_V230624:指定金死叉数值 +/// +/// 参数模板:`"{freq}_D{di}N{n}MD{md}_MACD交叉数量V230624"` +/// +/// 信号逻辑: +/// 1. 取最近 `n` 根 DIF/DEA 并过零截断; +/// 2. 按最小间隔 `md` 过滤交叉并统计 `jc/sc`; +/// 3. 根据当前所在零轴区域输出上下轴金叉/死叉次数。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N100MD1_MACD交叉数量V230624_0轴下金叉第2次_0轴下死叉第1次_任意_0')` +/// - `Signal('60分钟_D1N100MD1_MACD交叉数量V230624_0轴上金叉第1次_0轴上死叉第2次_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `n`:统计窗口长度,默认 `100`; +/// - `md`:最小交叉间隔,默认 `1`。 +/// 对齐说明:交叉过滤及计数口径与 Python `tas_cross_status_V230624` 保持一致。 +#[signal( + category = "kline", + name = "tas_cross_status_V230624", + template = "{freq}_D{di}N{n}MD{md}_MACD交叉数量V230624", + opcode = "TasCrossStatusV230624", + param_kind = "TasCrossStatusV230624" +)] +pub fn tas_cross_status_v230624( + czsc: &CZSC, + params: &ParamView, + cache: &mut TaCache, +) -> Vec { + let di = get_usize_param(params, "di", 1); + let n = get_usize_param(params, "n", 100); + let md = get_usize_param(params, "md", 1).max(1); + let fastperiod = get_usize_param(params, "fastperiod", 12); + let slowperiod = get_usize_param(params, "slowperiod", 26); + let signalperiod = get_usize_param(params, "signalperiod", 9); + let cache_key = format!("MACD{}#{}#{}", fastperiod, slowperiod, signalperiod); + update_macd_cache( + czsc, + &cache_key, + fastperiod, + slowperiod, + signalperiod, + cache, + ); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}N{}MD{}", di, n, md); + let k3 = "MACD交叉数量V230624"; + let mut v1 = "其他".to_string(); + let mut v2 = "其他".to_string(); + if czsc.bars_raw.len() < n + 1 { + return make_kline_signal_v2(&k1, &k2, k3, &v1, &v2); + } + let mc = cache.macd.get(&cache_key).unwrap(); + let dif = get_sub_elements(&mc.dif, di, n); + let dea = get_sub_elements(&mc.dea, di, n); + let num_k = cross_zero_axis(dif, dea); + let dif_temp = get_sub_elements(dif, 1, num_k); + let dea_temp = get_sub_elements(dea, 1, num_k); + let cross = fast_slow_cross(dif_temp, dea_temp); + let (jc, sc) = cal_cross_num(&cross, md); + let dl = dif[dif.len() - 1]; + let el = dea[dea.len() - 1]; + if dl < 0.0 && el < 0.0 { + v1 = format!("0轴下金叉第{}次", jc); + v2 = format!("0轴下死叉第{}次", sc); + } else if dl > 0.0 && el > 0.0 { + v1 = format!("0轴上金叉第{}次", jc); + v2 = format!("0轴上死叉第{}次", sc); + } + make_kline_signal_v2(&k1, &k2, k3, &v1, &v2) +} + +/// tas_cross_status_V230625:指定金叉/死叉次数后状态 +/// +/// 参数模板:`"{freq}_D{di}N{n}MD{md}J{j}S{s}_MACD交叉数量V230625"` +/// +/// 信号逻辑: +/// 1. 在近 `n` 根内统计过滤后的金叉/死叉数量; +/// 2. 仅允许 `j` 或 `s` 之一生效; +/// 3. 达到目标次数后输出 `0轴上/下第N次金叉(死叉)以后`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N100MD1J2S0_MACD交叉数量V230625_0轴下第2次金叉以后_任意_任意_0')` +/// - `Signal('60分钟_D1N100MD1J0S2_MACD交叉数量V230625_0轴上第2次死叉以后_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `n`:统计窗口长度,默认 `100`; +/// - `md`:交叉间隔阈值,默认 `1`; +/// - `j/s`:目标金叉或死叉次数,默认 `0/0`(二者不能同时非零)。 +/// 对齐说明:参数约束与触发语义对齐 Python `tas_cross_status_V230625`。 +#[signal( + category = "kline", + name = "tas_cross_status_V230625", + template = "{freq}_D{di}N{n}MD{md}J{j}S{s}_MACD交叉数量V230625", + opcode = "TasCrossStatusV230625", + param_kind = "TasCrossStatusV230625" +)] +pub fn tas_cross_status_v230625( + czsc: &CZSC, + params: &ParamView, + cache: &mut TaCache, +) -> Vec { + let di = get_usize_param(params, "di", 1); + let j = get_usize_param(params, "j", 0); + let s = get_usize_param(params, "s", 0); + let n = get_usize_param(params, "n", 100); + let md = get_usize_param(params, "md", 1); + let fastperiod = get_usize_param(params, "fastperiod", 12); + let slowperiod = get_usize_param(params, "slowperiod", 26); + let signalperiod = get_usize_param(params, "signalperiod", 9); + let cache_key = format!("MACD{}#{}#{}", fastperiod, slowperiod, signalperiod); + update_macd_cache( + czsc, + &cache_key, + fastperiod, + slowperiod, + signalperiod, + cache, + ); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}N{}MD{}J{}S{}", di, n, md, j, s); + let k3 = "MACD交叉数量V230625"; + let mut v1 = "其他".to_string(); + if j * s != 0 || czsc.bars_raw.len() < di + n + 1 { + return make_kline_signal_v1(&k1, &k2, k3, &v1); + } + let mc = cache.macd.get(&cache_key).unwrap(); + let dif = get_sub_elements(&mc.dif, di, n); + let dea = get_sub_elements(&mc.dea, di, n); + let num_k = cross_zero_axis(dif, dea); + let dif_temp = get_sub_elements(dif, 1, num_k); + let dea_temp = get_sub_elements(dea, 1, num_k); + let cross = fast_slow_cross(dif_temp, dea_temp); + let (jc, sc) = cal_cross_num(&cross, md); + let dl = dif[dif.len() - 1]; + let el = dea[dea.len() - 1]; + if dl < 0.0 && el < 0.0 { + if jc >= j && s == 0 { + v1 = format!("0轴下第{}次金叉以后", j); + } else if j == 0 && sc >= s { + v1 = format!("0轴下第{}次死叉以后", s); + } + } else if dl > 0.0 && el > 0.0 { + if jc >= j && s == 0 { + v1 = format!("0轴上第{}次金叉以后", j); + } else if j == 0 && sc >= s { + v1 = format!("0轴上第{}次死叉以后", s); + } + } + make_kline_signal_v1(&k1, &k2, k3, &v1) +} + +/// tas_slope_V231019:DIF 斜率分位多空 +/// +/// 参数模板:`"{freq}_D{di}DIF{n}斜率T{th}_BS辅助V231019"` +/// +/// 信号逻辑: +/// 1. 计算最近区间内 DIF 线性回归斜率序列; +/// 2. 计算当前斜率在历史区间中的归一化分位; +/// 3. 分位 `> th/100` 判 `看多`,`< 1-th/100` 判 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1DIF10斜率T80_BS辅助V231019_看多_任意_任意_0')` +/// - `Signal('60分钟_D1DIF10斜率T80_BS辅助V231019_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `n`:斜率回看长度,默认 `10`; +/// - `th`:分位阈值(50-100),默认 `80`。 +/// 对齐说明:分位判定区间和阈值方向与 Python `tas_slope_V231019` 一致。 +#[signal( + category = "kline", + name = "tas_slope_V231019", + template = "{freq}_D{di}DIF{n}斜率T{th}_BS辅助V231019", + opcode = "TasSlopeV231019", + param_kind = "TasSlopeV231019" +)] +pub fn tas_slope_v231019(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let n = get_usize_param(params, "n", 10); + let th = get_usize_param(params, "th", 80); + let cache_key = "MACD12#26#9"; + update_macd_cache(czsc, cache_key, 12, 26, 9, cache); + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}DIF{}斜率T{}", di, n, th); + let k3 = "BS辅助V231019"; + let mut v1 = "其他"; + if !(51..100).contains(&th) || czsc.bars_raw.len() < 50 || di == 0 || di > czsc.bars_raw.len() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let mc = cache.macd.get(cache_key).unwrap(); + let dif = &mc.dif; + let end = czsc.bars_raw.len() - di + 1; + let start = end.saturating_sub(n * 10); + let mut slopes = Vec::new(); + for i in start..end { + if i < n { + slopes.push(0.0); + } else { + slopes.push(linear_slope(&dif[i - n..i])); + } + } + if slopes.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let last = *slopes.last().unwrap(); + let min_v = slopes.iter().copied().fold(f64::INFINITY, f64::min); + let max_v = slopes.iter().copied().fold(f64::NEG_INFINITY, f64::max); + if (max_v - min_v).abs() > f64::EPSILON { + let q = (last - min_v) / (max_v - min_v); + if q > th as f64 / 100.0 { + v1 = "看多"; + } else if q < 1.0 - th as f64 / 100.0 { + v1 = "看空"; + } + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_boll_vt_V230212:BOLL 通道突破进出场信号 +/// +/// 参数模板:`"{freq}_D{di}BOLL{timeperiod}S{nbdev}MO{max_overlap}_BS辅助V230212"` +/// +/// 信号逻辑: +/// 1. 计算指定参数的 BOLL 上下轨(`nbdev / 10` 为标准差倍数); +/// 2. 最新收盘价在上轨上方,且窗口内曾有收盘价在上轨下方,判 `看多`; +/// 3. 最新收盘价在下轨下方,且窗口内曾有收盘价在下轨上方,判 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1BOLL20S20MO5_BS辅助V230212_看多_任意_任意_0')` +/// - `Signal('60分钟_D1BOLL20S20MO5_BS辅助V230212_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `timeperiod`:BOLL 周期,默认 `20`; +/// - `nbdev`:标准差倍数 *10,默认 `20`; +/// - `max_overlap`:窗口重叠长度,默认 `5`。 +/// 对齐说明:严格按 Python `tas_boll_vt_V230212` 判定分支实现。 +#[signal( + category = "kline", + name = "tas_boll_vt_V230212", + template = "{freq}_D{di}BOLL{timeperiod}S{nbdev}MO{max_overlap}_BS辅助V230212", + opcode = "TasBollVtV230212", + param_kind = "TasBollVtV230212" +)] +pub fn tas_boll_vt_v230212(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let timeperiod = get_usize_param(params, "timeperiod", 20); + let nbdev = get_usize_param(params, "nbdev", 20); + let max_overlap = get_usize_param(params, "max_overlap", 5); + let nbdev_f = nbdev as f64 / 10.0; + let cache_key = format!("BOLL{}#{:.1}", timeperiod, nbdev_f); + update_boll_cache(czsc, &cache_key, timeperiod, nbdev_f, cache); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}BOLL{}S{}MO{}", di, timeperiod, nbdev, max_overlap); + let k3 = "BS辅助V230212"; + let mut v1 = "其他"; + + let bars = get_sub_elements(&czsc.bars_raw, di, max_overlap + 1); + if bars.len() < max_overlap + 1 || di == 0 || di > czsc.bars_raw.len() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let boll = cache.boll.get(&cache_key).unwrap(); + let end = czsc.bars_raw.len() - di + 1; + let start = end - bars.len(); + let last_i = end - 1; + let last_bar = &bars[bars.len() - 1]; + if last_bar.close > boll.upper[last_i] + && bars + .iter() + .enumerate() + .any(|(i, b)| b.close < boll.upper[start + i]) + { + v1 = "看多"; + } else if last_bar.close < boll.lower[last_i] + && bars + .iter() + .enumerate() + .any(|(i, b)| b.close > boll.lower[start + i]) + { + v1 = "看空"; + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_cci_base_V230402:CCI 极值连续计数信号 +/// +/// 参数模板:`"{freq}_D{di}CCI{timeperiod}#{min_count}#{max_count}_BS辅助V230402"` +/// +/// 信号逻辑: +/// 1. 计算 CCI 序列; +/// 2. 若末尾连续 `CCI > 100` 次数落在 `[min_count, max_count)`,判 `多头`; +/// 3. 若末尾连续 `CCI < -100` 次数落在 `[min_count, max_count)`,判 `空头`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1CCI14#3#6_BS辅助V230402_多头_任意_任意_0')` +/// - `Signal('60分钟_D1CCI14#3#6_BS辅助V230402_空头_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `timeperiod`:CCI 周期,默认 `14`; +/// - `min_count`:最小连续次数,默认 `3`; +/// - `max_count`:最大连续次数上界(开区间),默认 `min_count + 3`。 +/// 对齐说明:连续计数和覆盖顺序与 Python `tas_cci_base_V230402` 一致。 +#[signal( + category = "kline", + name = "tas_cci_base_V230402", + template = "{freq}_D{di}CCI{timeperiod}#{min_count}#{max_count}_BS辅助V230402", + opcode = "TasCciBaseV230402", + param_kind = "TasCciBaseV230402" +)] +pub fn tas_cci_base_v230402(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let timeperiod = get_usize_param(params, "timeperiod", 14); + let min_count = get_usize_param(params, "min_count", 3); + let max_count = get_usize_param(params, "max_count", min_count + 3); + assert!(min_count < max_count, "min_count 必须小于 max_count"); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}CCI{}#{}#{}", di, timeperiod, min_count, max_count); + let k3 = "BS辅助V230402"; + let mut v1 = "其他"; + + let cache_key = format!("CCI{}", timeperiod); + update_cci_cache(czsc, &cache_key, timeperiod, cache); + let cci_series = cache.series.get(&cache_key).unwrap(); + let bars = get_sub_elements(&czsc.bars_raw, di, max_count + 1); + if bars.len() != max_count + 1 || di == 0 || di > czsc.bars_raw.len() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let end = czsc.bars_raw.len() - di + 1; + let start = end - bars.len(); + let cci: Vec = (start..end).map(|i| cci_series[i]).collect(); + let long: Vec = cci.iter().map(|x| *x > 100.0).collect(); + let short: Vec = cci.iter().map(|x| *x < -100.0).collect(); + let lc = if *long.last().unwrap_or(&false) { + count_last_same(&long) + } else { + 0 + }; + let sc = if *short.last().unwrap_or(&false) { + count_last_same(&short) + } else { + 0 + }; + + if max_count > lc && lc >= min_count { + v1 = "多头"; + } + if max_count > sc && sc >= min_count { + v1 = "空头"; + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// cci_decision_V240620:CCI 逆势决策区域 +/// +/// 参数模板:`"{freq}_N{n}CCI_决策区域V240620"` +/// +/// 信号逻辑: +/// 1. 固定计算 `CCI(14)`; +/// 2. 取最近 `n` 根 CCI:若最小值 `< -100` 判 `开多`,`v2` 为 `< -100` 的出现次数; +/// 3. 若最大值 `> 100` 判 `开空`,`v2` 为 `> 100` 的出现次数(覆盖开多分支)。 +/// +/// 信号列表示例: +/// - `Signal('15分钟_N4CCI_决策区域V240620_开多_2次_任意_0')` +/// - `Signal('15分钟_N4CCI_决策区域V240620_开空_1次_任意_0')` +/// +/// 参数说明: +/// - `n`:统计窗口长度,默认 `2`。 +/// 对齐说明:分支顺序与 Python `cci_decision_V240620` 保持一致(后判空头覆盖前判多头)。 +#[signal( + category = "kline", + name = "cci_decision_V240620", + template = "{freq}_N{n}CCI_决策区域V240620", + opcode = "CciDecisionV240620", + param_kind = "CciDecisionV240620" +)] +pub fn cci_decision_v240620(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let n = get_usize_param(params, "n", 2); + let k1 = czsc.freq.to_string(); + let k2 = format!("N{}CCI", n); + let k3 = "决策区域V240620"; + let mut v1 = "其他"; + let mut v2 = "任意".to_string(); + + if czsc.bars_raw.len() < 100 { + return make_kline_signal_v2(&k1, &k2, k3, v1, &v2); + } + + let cache_key = "CCI14"; + update_cci_cache(czsc, cache_key, 14, cache); + let cci = match cache.series.get(cache_key) { + Some(v) if !v.is_empty() => v, + _ => return make_kline_signal_v2(&k1, &k2, k3, v1, &v2), + }; + let start = if n == 0 { + 0 + } else { + cci.len().saturating_sub(n) + }; + let cci_seq = &cci[start..]; + if cci_seq.is_empty() { + return make_kline_signal_v2(&k1, &k2, k3, v1, &v2); + } + + let short_count = cci_seq.iter().filter(|x| **x > 100.0).count(); + let long_count = cci_seq.iter().filter(|x| **x < -100.0).count(); + + if cci_seq.iter().copied().fold(f64::INFINITY, f64::min) < -100.0 { + v1 = "开多"; + v2 = format!("{}次", long_count); + } + if cci_seq.iter().copied().fold(f64::NEG_INFINITY, f64::max) > 100.0 { + v1 = "开空"; + v2 = format!("{}次", short_count); + } + + make_kline_signal_v2(&k1, &k2, k3, v1, &v2) +} + +/// tas_accelerate_V230531:BOLL 通道加速信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}T{t}_BOLL加速V230531"` +/// +/// 信号逻辑: +/// 1. 取最近 `n` 根,计算中线/上轨/下轨涨跌幅; +/// 2. 全部在中线上方且 `上轨涨幅 > t/10 * 中线涨幅 > 0`,判 `多头加速`; +/// 3. 全部在中线下方且 `下轨涨幅 < t/10 * 中线涨幅 < 0`,判 `空头加速`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N20T20_BOLL加速V230531_多头加速_升破上轨_任意_0')` +/// - `Signal('60分钟_D1N20T20_BOLL加速V230531_空头加速_跌破下轨_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `n`:观察窗口,默认 `20`; +/// - `t`:轨道/中线倍率阈值(除以10),默认 `20`。 +/// 对齐说明:按 Python `tas_accelerate_V230531` 双分支覆盖语义实现。 +#[signal( + category = "kline", + name = "tas_accelerate_V230531", + template = "{freq}_D{di}N{n}T{t}_BOLL加速V230531", + opcode = "TasAccelerateV230531", + param_kind = "TasAccelerateV230531" +)] +pub fn tas_accelerate_v230531(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let n = get_usize_param(params, "n", 20); + let t = get_usize_param(params, "t", 20); + let cache_key = "BOLL20#2.0"; + update_boll_cache(czsc, cache_key, 20, 2.0, cache); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}N{}T{}", di, n, t); + let k3 = "BOLL加速V230531"; + let mut v1 = "其他"; + let mut v2 = "其他"; + + if czsc.bars_raw.len() < 40 || di == 0 || di > czsc.bars_raw.len() { + return make_kline_signal_v2(&k1, &k2, k3, v1, v2); + } + let bars = get_sub_elements(&czsc.bars_raw, di, n); + if bars.is_empty() { + return make_kline_signal_v2(&k1, &k2, k3, v1, v2); + } + + let boll = cache.boll.get(cache_key).unwrap(); + let end = czsc.bars_raw.len() - di + 1; + let start = end - bars.len(); + let first_i = start; + let last_i = end - 1; + let mid_zdf = boll.mid[last_i] / boll.mid[first_i] - 1.0; + let up_zdf = boll.upper[last_i] / boll.upper[first_i] - 1.0; + let down_zdf = boll.lower[last_i] / boll.lower[first_i] - 1.0; + let all_above_mid = bars + .iter() + .enumerate() + .all(|(i, b)| b.close > boll.mid[start + i]); + let all_below_mid = bars + .iter() + .enumerate() + .all(|(i, b)| b.close < boll.mid[start + i]); + let last_bar = &bars[bars.len() - 1]; + + if all_above_mid && up_zdf > (t as f64 / 10.0) * mid_zdf && mid_zdf > 0.0 { + v1 = "多头加速"; + v2 = if boll.upper[last_i] < last_bar.high { + "升破上轨" + } else { + "未破上轨" + }; + } + if all_below_mid && down_zdf < (t as f64 / 10.0) * mid_zdf && mid_zdf < 0.0 { + v1 = "空头加速"; + v2 = if boll.lower[last_i] > last_bar.low { + "跌破下轨" + } else { + "未破下轨" + }; + } + + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// tas_low_trend_V230627:阴跌/小阳趋势信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}TH{th}_趋势230627"` +/// +/// 信号逻辑: +/// 1. 对窗口内实体振幅做阈值过滤,剔除波动过大的场景; +/// 2. 统计 `low <= 历史收盘最小值` 次数,超过 `0.8*n` 判 `阴跌趋势`; +/// 3. 统计 `high >= 历史收盘最大值` 次数,超过 `0.8*n` 判 `小阳趋势`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N13TH300_趋势230627_阴跌趋势_任意_任意_0')` +/// - `Signal('60分钟_D1N13TH300_趋势230627_小阳趋势_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `n`:统计窗口,默认 `13`; +/// - `th`:实体振幅阈值(BP),默认 `300`。 +/// 对齐说明:循环窗口与阈值比较口径对齐 Python `tas_low_trend_V230627`。 +#[signal( + category = "kline", + name = "tas_low_trend_V230627", + template = "{freq}_D{di}N{n}TH{th}_趋势230627", + opcode = "TasLowTrendV230627", + param_kind = "TasLowTrendV230627" +)] +pub fn tas_low_trend_v230627(czsc: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let n = get_usize_param(params, "n", 13); + let th = get_usize_param(params, "th", 300); + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}N{}TH{}", di, n, th); + let k3 = "趋势230627"; + let mut v1 = "其他"; + + if czsc.bars_raw.len() < di + n + 8 || di == 0 || di > czsc.bars_raw.len() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let bars = get_sub_elements(&czsc.bars_raw, di, n + 5); + if bars.len() < n + 5 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let solid_zf: Vec = bars[5..] + .iter() + .map(|x| (x.close / x.open - 1.0).abs() * 10000.0) + .collect(); + let violent = solid_zf.iter().filter(|x| **x > th as f64).count(); + if violent as f64 > (0.2 * n as f64).max(3.0) { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let mut min_count = 0usize; + let mut max_count = 0usize; + for i in 5..bars.len() { + let bar = &bars[i]; + let w5 = &bars[..i]; + let min_close = w5.iter().map(|x| x.close).fold(f64::INFINITY, f64::min); + let max_close = w5.iter().map(|x| x.close).fold(f64::NEG_INFINITY, f64::max); + if bar.low <= min_close { + min_count += 1; + } + if bar.high >= max_close { + max_count += 1; + } + } + + if min_count as f64 >= 0.8 * n as f64 { + v1 = "阴跌趋势"; + } + if max_count as f64 >= 0.8 * n as f64 { + v1 = "小阳趋势"; + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_atr_V230630:ATR 波动分层信号 +/// +/// 参数模板:`"{freq}_D{di}ATR{timeperiod}_波动V230630"` +/// +/// 信号逻辑: +/// 1. 计算 `ATR / close` 波动率; +/// 2. 对最近 100 根波动率做 `qcut(10)` 分层; +/// 3. 输出末值所在层级 `第{n}层`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1ATR14_波动V230630_第3层_任意_任意_0')` +/// - `Signal('60分钟_D1ATR14_波动V230630_第9层_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `timeperiod`:ATR 周期,默认 `14`。 +/// 对齐说明:ATR 预热与分层边界对齐 Python `tas_atr_V230630`。 +#[signal( + category = "kline", + name = "tas_atr_V230630", + template = "{freq}_D{di}ATR{timeperiod}_波动V230630", + opcode = "TasAtrV230630", + param_kind = "TasAtrV230630" +)] +pub fn tas_atr_v230630(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let timeperiod = get_usize_param(params, "timeperiod", 14); + let cache_key = format!("ATR{}", timeperiod); + update_atr_cache(czsc, &cache_key, timeperiod, cache); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}ATR{}", di, timeperiod); + let k3 = "波动V230630"; + let mut v1 = "其他".to_string(); + + if czsc.bars_raw.len() < di + timeperiod + 8 || di == 0 || di > czsc.bars_raw.len() { + return make_kline_signal_v1(&k1, &k2, k3, &v1); + } + + let bars = get_sub_elements(&czsc.bars_raw, di, 100); + if bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, &v1); + } + let atr = cache.series.get(&cache_key).unwrap(); + let end = czsc.bars_raw.len() - di + 1; + let start = end - bars.len(); + let lev: Vec = bars + .iter() + .enumerate() + .map(|(i, b)| atr[start + i] / b.close) + .collect(); + if let Some(q) = qcut_last_label(&lev, 10) { + v1 = format!("第{}层", q + 1); + } + + make_kline_signal_v1(&k1, &k2, k3, &v1) +} + +/// tas_macd_base_V230320:MACD/DIF/DEA 多空与方向信号(含重叠约束) +/// +/// 参数模板:`"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}MO{max_overlap}#{key}_BS辅助V230320"` +/// +/// 信号逻辑: +/// 1. 计算 `MACD/DIF/DEA` 序列; +/// 2. 取倒数 `di` 截止的最近 `max_overlap+1` 根值; +/// 3. 若 `last > 0` 且前序存在 `< 0` 判 `多头`; +/// 4. 若 `last < 0` 且前序存在 `> 0` 判 `空头`; +/// 5. 否则判 `其他`;方向由 `last >= prev` 判 `向上/向下`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1MACD12#26#9MO3#MACD_BS辅助V230320_多头_向上_任意_0')` +/// - `Signal('60分钟_D1MACD12#26#9MO3#DIF_BS辅助V230320_空头_向下_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `key`:指标键,`MACD/DIF/DEA`,默认 `MACD`; +/// - `fastperiod/slowperiod/signalperiod`:MACD参数,默认 `12/26/9`; +/// - `max_overlap`:最大重叠窗口,默认 `3`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_macd_base_V230320", + template = "{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}MO{max_overlap}#{key}_BS辅助V230320", + opcode = "TasMacdBaseV230320", + param_kind = "TasMacdBaseV230320" +)] +pub fn tas_macd_base_v230320(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let key = get_str_param(params, "key", "MACD").to_uppercase(); + let fastperiod = get_usize_param(params, "fastperiod", 12); + let slowperiod = get_usize_param(params, "slowperiod", 26); + let signalperiod = get_usize_param(params, "signalperiod", 9); + let max_overlap = get_usize_param(params, "max_overlap", 3); + + let k1 = czsc.freq.to_string(); + let k2 = format!( + "D{}MACD{}#{}#{}MO{}#{}", + di, fastperiod, slowperiod, signalperiod, max_overlap, key + ); + let k3 = "BS辅助V230320"; + + if !matches!(key.as_str(), "MACD" | "DIF" | "DEA") || czsc.bars_raw.len() < 5 + di + max_overlap + { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let cache_key = format!("MACD{}#{}#{}", fastperiod, slowperiod, signalperiod); + update_macd_cache( + czsc, + &cache_key, + fastperiod, + slowperiod, + signalperiod, + cache, + ); + + let mc = cache.macd.get(&cache_key).unwrap(); + let series = match key.as_str() { + "DIF" => &mc.dif, + "DEA" => &mc.dea, + _ => &mc.macd, + }; + let values = get_sub_elements(series, di, max_overlap + 1); + if values.len() < max_overlap + 1 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let last = *values.last().unwrap(); + let prev = values[values.len() - 2]; + let has_neg = values[..values.len() - 1].iter().any(|x| *x < 0.0); + let has_pos = values[..values.len() - 1].iter().any(|x| *x > 0.0); + let v1 = if last > 0.0 && has_neg { + "多头" + } else if last < 0.0 && has_pos { + "空头" + } else { + "其他" + }; + + if v1 == "其他" { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let v2 = if last >= prev { "向上" } else { "向下" }; + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// tas_macd_change_V221105:MACD变色次数信号 +/// +/// 参数模板:`"{freq}_D{di}K{n}#MACD{fastperiod}#{slowperiod}#{signalperiod}变色次数_BS辅助V221105"` +/// +/// 信号逻辑: +/// 1. 在最近 `n` 根上计算 DIF/DEA 金叉死叉序列; +/// 2. 过滤 `距离<2` 的抖动交叉; +/// 3. 同类型连续交叉按 Python 语义合并; +/// 4. 输出合并后次数 `"{num}次"`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1K55#MACD12#26#9变色次数_BS辅助V221105_0次_任意_任意_0')` +/// - `Signal('60分钟_D1K55#MACD12#26#9变色次数_BS辅助V221105_3次_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `n`:统计窗口长度,默认 `55`; +/// - `fastperiod/slowperiod/signalperiod`:MACD参数,默认 `12/26/9`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_macd_change_V221105", + template = "{freq}_D{di}K{n}#MACD{fastperiod}#{slowperiod}#{signalperiod}变色次数_BS辅助V221105", + opcode = "TasMacdChangeV221105", + param_kind = "TasMacdChangeV221105" +)] +pub fn tas_macd_change_v221105( + czsc: &CZSC, + params: &ParamView, + cache: &mut TaCache, +) -> Vec { + let di = get_usize_param(params, "di", 1); + let n = get_usize_param(params, "n", 55); + let fastperiod = get_usize_param(params, "fastperiod", 12); + let slowperiod = get_usize_param(params, "slowperiod", 26); + let signalperiod = get_usize_param(params, "signalperiod", 9); + + let cache_key = format!("MACD{}#{}#{}", fastperiod, slowperiod, signalperiod); + update_macd_cache( + czsc, + &cache_key, + fastperiod, + slowperiod, + signalperiod, + cache, + ); + let macd_cache = cache.macd.get(&cache_key).unwrap(); + let dif = get_sub_elements(&macd_cache.dif, di, n); + let dea = get_sub_elements(&macd_cache.dea, di, n); + + let cross = fast_slow_cross(dif, dea); + let re_cross: Vec<_> = cross + .into_iter() + .filter(|x| x.get("距离").copied().unwrap_or(0.0) >= 2.0) + .collect(); + + let num = if re_cross.is_empty() { + 0 + } else { + let mut merged: Vec> = Vec::new(); + for c in re_cross { + if !merged.is_empty() + && c.get("类型").copied().unwrap_or(0.0) + == merged + .last() + .and_then(|x| x.get("类型")) + .copied() + .unwrap_or(0.0) + { + merged.pop(); + } + merged.push(c); + } + merged.len() + }; + + let k1 = czsc.freq.to_string(); + let k2 = format!( + "D{}K{}#MACD{}#{}#{}变色次数", + di, n, fastperiod, slowperiod, signalperiod + ); + let k3 = "BS辅助V221105"; + let v1 = format!("{}次", num); + make_kline_signal_v1(&k1, &k2, k3, &v1) +} + +/// tas_dif_zero_V240614:DIF靠近零轴买卖点信号 +/// +/// 参数模板:`"{freq}_DIF靠近零轴W{w}T{t}_BS辅助V240614"` +/// +/// 信号逻辑: +/// 1. 取最近 `w` 根K线的 `DIF` 序列; +/// 2. 计算 `delta = std(diffs) * t / 100`; +/// 3. 若 `diffs` 全部大于0,且 `diffs[-1]` 靠近零轴,同时 `max(diffs)` 显著高于均值+标准差,判 `买点`; +/// 4. 若 `diffs` 全部小于0,且 `diffs[-1]` 靠近零轴,同时 `min(diffs)` 显著低于-(均值+标准差),判 `卖点`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_DIF靠近零轴W20T50_BS辅助V240614_买点_任意_任意_0')` +/// - `Signal('60分钟_DIF靠近零轴W20T50_BS辅助V240614_卖点_任意_任意_0')` +/// - `Signal('60分钟_DIF靠近零轴W20T50_BS辅助V240614_其他_任意_任意_0')` +/// +/// 参数说明: +/// - `w`:K线窗口长度,默认 `20`; +/// - `t`:波动率倍数(除以100),默认 `50`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_dif_zero_V240614", + template = "{freq}_DIF靠近零轴W{w}T{t}_BS辅助V240614", + opcode = "TasDifZeroV240614", + param_kind = "TasDifZeroV240614" +)] +pub fn tas_dif_zero_v240614(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let w = get_usize_param(params, "w", 20); + let t = get_usize_param(params, "t", 50) as f64; + + let k1 = czsc.freq.to_string(); + let k2 = format!("DIF靠近零轴W{}T{}", w, t as usize); + let k3 = "BS辅助V240614"; + let mut v1 = "其他"; + + if czsc.bars_raw.len() < 110 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let cache_key = "MACD12#26#9"; + update_macd_cache(czsc, cache_key, 12, 26, 9, cache); + let macd_cache = cache.macd.get(cache_key).unwrap(); + let diffs = get_sub_elements(&macd_cache.dif, 1, w); + if diffs.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let mean = diffs.iter().sum::() / diffs.len() as f64; + let variance = diffs.iter().map(|x| (x - mean).powi(2)).sum::() / diffs.len() as f64; + let std_diff = variance.sqrt(); + let delta = std_diff * t / 100.0; + let max_diff = diffs.iter().copied().fold(f64::NEG_INFINITY, f64::max); + let min_diff = diffs.iter().copied().fold(f64::INFINITY, f64::min); + let abs_mean_diff = mean.abs(); + let last = *diffs.last().unwrap(); + + let all_pos = diffs.iter().all(|&x| x > 0.0); + let all_neg = diffs.iter().all(|&x| x < 0.0); + if all_pos && delta > last && last > -delta && max_diff > abs_mean_diff + std_diff { + v1 = "买点"; + } + if all_neg && -delta < last && last < delta && min_diff < -(abs_mean_diff + std_diff) { + v1 = "卖点"; + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_dif_zero_V240612:DIF靠近零轴买卖点信号(基于最近一笔) +/// +/// 参数模板:`"{freq}_DIF靠近零轴T{t}_BS辅助V240612"` +/// +/// 信号逻辑: +/// 1. 取最近一笔内部原始K线的 `DIF` 序列; +/// 2. 计算 `delta = std(diffs) * t / 100`; +/// 3. 若最后一笔为向下笔,且末端 `DIF` 靠近零轴,同时 `max(diffs)` 显著高于均值+标准差,判 `买点`; +/// 4. 若最后一笔为向上笔,且末端 `DIF` 靠近零轴,同时 `min(diffs)` 显著低于-(均值+标准差),判 `卖点`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_DIF靠近零轴T50_BS辅助V240612_买点_任意_任意_0')` +/// - `Signal('60分钟_DIF靠近零轴T50_BS辅助V240612_卖点_任意_任意_0')` +/// - `Signal('60分钟_DIF靠近零轴T50_BS辅助V240612_其他_任意_任意_0')` +/// +/// 参数说明: +/// - `t`:波动率倍数(除以100),默认 `50`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_dif_zero_V240612", + template = "{freq}_DIF靠近零轴T{t}_BS辅助V240612", + opcode = "TasDifZeroV240612", + param_kind = "TasDifZeroV240612" +)] +pub fn tas_dif_zero_v240612(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let t = get_usize_param(params, "t", 50) as f64; + let k1 = czsc.freq.to_string(); + let k2 = format!("DIF靠近零轴T{}", t as usize); + let k3 = "BS辅助V240612"; + let mut v1 = "其他"; + + if czsc.bars_raw.len() < 110 || czsc.bars_ubi.len() > 7 || czsc.bi_list.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let cache_key = "MACD12#26#9"; + update_macd_cache(czsc, cache_key, 12, 26, 9, cache); + let macd_cache = cache.macd.get(cache_key).unwrap(); + let last_bi = czsc.bi_list.last().unwrap(); + let raw_bars = last_bi.get_raw_bars(); + if raw_bars.len() < 7 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let id_to_idx: HashMap = czsc + .bars_raw + .iter() + .enumerate() + .map(|(i, b)| (b.id, i)) + .collect(); + let mut diffs = Vec::with_capacity(raw_bars.len()); + for b in &raw_bars { + if let Some(&idx) = id_to_idx.get(&b.id) { + diffs.push(macd_cache.dif[idx]); + } + } + if diffs.len() < 7 || diffs.iter().any(|x| !x.is_finite()) { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let mean = diffs.iter().sum::() / diffs.len() as f64; + let variance = diffs.iter().map(|x| (x - mean).powi(2)).sum::() / diffs.len() as f64; + let std_diff = variance.sqrt(); + let delta = std_diff * t / 100.0; + let max_diff = diffs.iter().copied().fold(f64::NEG_INFINITY, f64::max); + let min_diff = diffs.iter().copied().fold(f64::INFINITY, f64::min); + let abs_mean_diff = mean.abs(); + let last = *diffs.last().unwrap(); + + if matches!(last_bi.direction, Direction::Down) + && delta > last + && last > -delta + && max_diff > abs_mean_diff + std_diff + { + v1 = "买点"; + } + if matches!(last_bi.direction, Direction::Up) + && -delta < last + && last < delta + && min_diff < -(abs_mean_diff + std_diff) + { + v1 = "卖点"; + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_macd_bc_V221201:MACD背驰辅助信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}M{m}#MACD{fastperiod}#{slowperiod}#{signalperiod}_BCV221201"` +/// +/// 信号逻辑: +/// 1. 取最近 `m+n` 根K线,前 `m` 为对照窗口,后 `n` 为近端窗口; +/// 2. 若近端价格创新低且MACD低点抬高,判 `底部` 背驰; +/// 3. 若近端价格创新高且MACD高点走低,判 `顶部` 背驰; +/// 4. 并给出当前柱体颜色 `红柱/绿柱`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N3M50#MACD12#26#9_BCV221201_底部_绿柱_任意_0')` +/// - `Signal('60分钟_D1N3M50#MACD12#26#9_BCV221201_顶部_红柱_任意_0')` +/// - `Signal('60分钟_D1N3M50#MACD12#26#9_BCV221201_其他_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `n/m`:近端窗口与对照窗口长度,默认 `3/50`; +/// - `fastperiod/slowperiod/signalperiod`:MACD参数,默认 `12/26/9`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_macd_bc_V221201", + template = "{freq}_D{di}N{n}M{m}#MACD{fastperiod}#{slowperiod}#{signalperiod}_BCV221201", + opcode = "TasMacdBcV221201", + param_kind = "TasMacdBcV221201" +)] +pub fn tas_macd_bc_v221201(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let n = get_usize_param(params, "n", 3); + let m = get_usize_param(params, "m", 50); + let fastperiod = get_usize_param(params, "fastperiod", 12); + let slowperiod = get_usize_param(params, "slowperiod", 26); + let signalperiod = get_usize_param(params, "signalperiod", 9); + + let cache_key = format!("MACD{}#{}#{}", fastperiod, slowperiod, signalperiod); + crate::utils::ta::update_macd_cache( + czsc, + &cache_key, + fastperiod, + slowperiod, + signalperiod, + cache, + ); + + let k1 = czsc.freq.to_string(); + let k2 = format!( + "D{}N{}M{}#MACD{}#{}#{}", + di, n, m, fastperiod, slowperiod, signalperiod + ); + let k3 = "BCV221201"; + let mut v1 = "其他"; + let mut v2 = "任意"; + + let macd_cache = cache.macd.get(&cache_key).unwrap(); + let bars = get_sub_elements(&czsc.bars_raw, di, n + m); + let macd_sub = get_sub_elements(&macd_cache.macd, di, n + m); + if bars.len() == n + m && macd_sub.len() == n + m { + let m_close: Vec = bars[..m].iter().map(|x| x.close).collect(); + let m_macd = macd_sub[..m].to_vec(); + let n_close: Vec = bars[m..].iter().map(|x| x.close).collect(); + let n_macd = macd_sub[m..].to_vec(); + + let n_macd_last = n_macd.last().unwrap(); + let n_macd_prev = n_macd.get(n_macd.len() - 2).unwrap_or(&0.0); + + // 对齐 Python 内建 min/max 的 NaN 语义(首元素为 NaN 时结果保持 NaN) + let py_min = |xs: &[f64]| -> f64 { + let mut it = xs.iter(); + let mut best = *it.next().unwrap_or(&f64::NAN); + for &x in it { + if x < best { + best = x; + } + } + best + }; + let py_max = |xs: &[f64]| -> f64 { + let mut it = xs.iter(); + let mut best = *it.next().unwrap_or(&f64::NAN); + for &x in it { + if x > best { + best = x; + } + } + best + }; + + if n_macd_last > n_macd_prev { + let min_n_close = py_min(&n_close); + let min_m_close = py_min(&m_close); + let min_n_macd = py_min(&n_macd); + let min_m_macd = py_min(&m_macd); + if min_n_close < min_m_close && min_n_macd > min_m_macd { + v1 = "底部"; + } + } else if n_macd_last < n_macd_prev { + let max_n_close = py_max(&n_close); + let max_m_close = py_max(&m_close); + let max_n_macd = py_max(&n_macd); + let max_m_macd = py_max(&m_macd); + if max_n_close > max_m_close && max_n_macd < max_m_macd { + v1 = "顶部"; + } + } + + v2 = if *n_macd_last > 0.0 { + "红柱" + } else { + "绿柱" + }; + } + + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// tas_angle_V230802:笔角度偏离信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}T{th}_笔角度V230802"` +/// +/// 信号逻辑: +/// 1. 定义角度为 `power_price / length`; +/// 2. 取同向历史 `n` 笔角度均值作为基线; +/// 3. 当前角度低于 `th%` 时输出反向信号。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N9T50_笔角度V230802_空头_任意_任意_0')` +/// - `Signal('60分钟_D1N9T50_笔角度V230802_多头_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 笔,默认 `1`; +/// - `n`:同向样本数,默认 `9`; +/// - `th`:角度阈值百分比,默认 `50`。 +/// 对齐说明:`length` 口径使用 `BI.length`(无包含K数量)与 Python 对齐。 +#[signal( + category = "kline", + name = "tas_angle_V230802", + template = "{freq}_D{di}N{n}T{th}_笔角度V230802", + opcode = "TasAngleV230802", + param_kind = "TasAngleV230802" +)] +pub fn tas_angle_v230802(czsc: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let n = get_usize_param(params, "n", 9); + let th = get_usize_param(params, "th", 50); + assert!(th > 30 && th < 300, "th 取值范围为 30 ~ 300"); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}N{}T{}", di, n, th); + let k3 = "笔角度V230802"; + let mut v1 = "其他"; + + if czsc.bi_list.len() < di + 2 * n + 2 || czsc.bars_ubi.len() >= 7 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let bis = get_sub_elements(&czsc.bi_list, di, n * 2 + 1); + if bis.len() < n * 2 + 1 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let b1 = bis.last().unwrap(); + let b1_len = b1.bars.len(); + if b1_len == 0 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let b1_angle = b1.get_power_price() / b1_len as f64; + let same_dir_ang: Vec = bis[..bis.len() - 1] + .iter() + .filter(|bi| bi.direction == b1.direction) + .filter_map(|bi| { + let l = bi.bars.len(); + if l > 0 { + Some(bi.get_power_price() / l as f64) + } else { + None + } + }) + .rev() + .take(n) + .collect::>() + .into_iter() + .rev() + .collect(); + if !same_dir_ang.is_empty() { + let mean = same_dir_ang.iter().sum::() / same_dir_ang.len() as f64; + if b1_angle < mean * th as f64 / 100.0 { + v1 = if b1.direction == Direction::Up { + "空头" + } else { + "多头" + }; + } + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_double_ma_V240208:双均线交叉结构信号 +/// +/// 参数模板:`"{freq}_D{di}N{N}M{M}双均线_BS辅助V240208"` +/// +/// 信号逻辑: +/// 1. 计算 `N/M` 双均线并识别交叉序列; +/// 2. 最近三次交叉记作 `X1/X2/X3`; +/// 3. `X3` 金叉且 `X2` 快线最高判 `多头`,死叉镜像判 `空头`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N20M60双均线_BS辅助V240208_多头_任意_任意_0')` +/// - `Signal('60分钟_D1N20M60双均线_BS辅助V240208_空头_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `N`:快线周期,默认 `20`; +/// - `M`:慢线周期,默认 `60`。 +/// 对齐说明:交叉类型与快线比较逻辑对齐 Python `tas_double_ma_V240208`。 +#[signal( + category = "kline", + name = "tas_double_ma_V240208", + template = "{freq}_D{di}N{N}M{M}双均线_BS辅助V240208", + opcode = "TasDoubleMaV240208", + param_kind = "TasDoubleMaV240208" +)] +pub fn tas_double_ma_v240208(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let n = get_usize_param(params, "N", 20); + let m = get_usize_param(params, "M", 60); + assert!(n < m, "N 必须小于 M"); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}N{}M{}双均线", di, n, m); + let k3 = "BS辅助V240208"; + let mut v1 = "其他"; + + let key_fast = format!("{}_{}_{}", czsc.freq, "SMA", n); + let key_slow = format!("{}_{}_{}", czsc.freq, "SMA", m); + update_ma_cache(czsc, &key_fast, "SMA", n, cache); + update_ma_cache(czsc, &key_slow, "SMA", m, cache); + let fast_all = cache.series.get(&key_fast).unwrap(); + let slow_all = cache.series.get(&key_slow).unwrap(); + let bars = get_sub_elements(&czsc.bars_raw, di, m * 30); + if bars.is_empty() || di == 0 || di > czsc.bars_raw.len() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let end = czsc.bars_raw.len() - di + 1; + let start = end - bars.len(); + let fast_ma: Vec = (start..end).map(|i| fast_all[i]).collect(); + let slow_ma: Vec = (start..end).map(|i| slow_all[i]).collect(); + let cross_info = fast_slow_cross(&fast_ma, &slow_ma); + if cross_info.len() >= 3 { + let x1 = &cross_info[cross_info.len() - 3]; + let x2 = &cross_info[cross_info.len() - 2]; + let x3 = &cross_info[cross_info.len() - 1]; + if x3.get("类型").copied().unwrap_or(0.0) > 0.0 + && x2.get("快线").copied().unwrap_or(f64::NEG_INFINITY) + > x1.get("快线") + .copied() + .unwrap_or(f64::NEG_INFINITY) + .max(x3.get("快线").copied().unwrap_or(f64::NEG_INFINITY)) + { + v1 = "多头"; + } else if x3.get("类型").copied().unwrap_or(0.0) < 0.0 + && x2.get("快线").copied().unwrap_or(f64::INFINITY) + < x1.get("快线") + .copied() + .unwrap_or(f64::INFINITY) + .min(x3.get("快线").copied().unwrap_or(f64::INFINITY)) + { + v1 = "空头"; + } + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_dma_bs_V240608:双均线顺势回调买卖点 +/// +/// 参数模板:`"{freq}_N{n}双均线{t1}#{t2}顺势_BS辅助V240608"` +/// +/// 信号逻辑: +/// 1. 以 `t1/t2` 均线顺势方向做过滤; +/// 2. 在 `ma2` 附近按离散价格序号选取回踩/反抽位; +/// 3. 满足穿越与收盘条件时给出 `买点/卖点`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_N5双均线5#10顺势_BS辅助V240608_买点_任意_任意_0')` +/// - `Signal('60分钟_N5双均线5#10顺势_BS辅助V240608_卖点_任意_任意_0')` +/// +/// 参数说明: +/// - `n`:价格序号偏移,默认 `5`; +/// - `t1`:快线周期,默认 `5`; +/// - `t2`:慢线周期,默认 `10`。 +/// 对齐说明:价格位选取(含负索引语义)与 Python `tas_dma_bs_V240608` 一致。 +#[signal( + category = "kline", + name = "tas_dma_bs_V240608", + template = "{freq}_N{n}双均线{t1}#{t2}顺势_BS辅助V240608", + opcode = "TasDmaBsV240608", + param_kind = "TasDmaBsV240608" +)] +pub fn tas_dma_bs_v240608(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let n = get_usize_param(params, "n", 5); + let t1 = get_usize_param(params, "t1", 5); + let t2 = get_usize_param(params, "t2", 10); + assert!(t1 < t2, "均线1的周期必须小于均线2的周期"); + + let k1 = czsc.freq.to_string(); + let k2 = format!("N{}双均线{}#{}顺势", n, t1, t2); + let k3 = "BS辅助V240608"; + let mut v1 = "其他"; + + let ma1_key = format!("{}_{}_{}", czsc.freq, "SMA", t1); + let ma2_key = format!("{}_{}_{}", czsc.freq, "SMA", t2); + update_ma_cache(czsc, &ma1_key, "SMA", t1, cache); + update_ma_cache(czsc, &ma2_key, "SMA", t2, cache); + let ma1 = cache.series.get(&ma1_key).unwrap(); + let ma2 = cache.series.get(&ma2_key).unwrap(); + + if czsc.bars_raw.len() < 110 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let bars = &czsc.bars_raw[czsc.bars_raw.len() - 100..]; + let mut unique_prices: Vec = bars + .iter() + .flat_map(|x| [x.close, x.high, x.low, x.open]) + .collect(); + unique_prices.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + unique_prices.dedup_by(|a, b| (*a - *b).abs() <= f64::EPSILON); + + let idx2 = czsc.bars_raw.len() - 1; + let idx1 = czsc.bars_raw.len() - 2; + let bar1 = &czsc.bars_raw[idx1]; + let bar2 = &czsc.bars_raw[idx2]; + let ma1_value = ma1[idx2]; + let ma2_value = ma2[idx2]; + let lower_prices: Vec = unique_prices + .iter() + .copied() + .filter(|x| *x < ma2_value) + .collect(); + let upper_prices: Vec = unique_prices + .iter() + .copied() + .filter(|x| *x > ma2_value) + .collect(); + + if !upper_prices.is_empty() && ma1_value > ma2_value && ma2[idx2] > ma2[idx1] { + let ma2_round_high = if upper_prices.len() > n { + upper_prices[n] + } else { + *upper_prices.last().unwrap() + }; + if bar1.low < ma2_round_high && ma2_round_high < bar2.high && bar2.close < ma2_round_high { + v1 = "买点"; + } + } else if !lower_prices.is_empty() && ma1_value < ma2_value && ma2[idx2] < ma2[idx1] { + let ma2_round_low = if lower_prices.len() > n { + lower_prices[lower_prices.len() - n] + } else { + lower_prices[0] + }; + if bar1.high > ma2_round_low && ma2_round_low > bar2.low && bar2.close > ma2_round_low { + v1 = "卖点"; + } + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_macd_bc_V230803:双分型 MACD 背驰信号 +/// +/// 参数模板:`"{freq}_MACD双分型背驰_BS辅助V230803"` +/// +/// 信号逻辑: +/// 1. 提取最近分型列表中的同类顶/底分型; +/// 2. 对比两个分型中间K线的 MACD 柱值; +/// 3. 向上笔出现 `macd1 > macd2 > 0` 判 `空头`,向下笔镜像判 `多头`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_MACD双分型背驰_BS辅助V230803_空头_任意_任意_0')` +/// - `Signal('60分钟_MACD双分型背驰_BS辅助V230803_多头_任意_任意_0')` +/// +/// 参数说明: +/// - 无额外参数。 +/// 对齐说明:分型来源改为 `get_fx_list()`,与 Python `c.fx_list` 语义对齐。 +#[signal( + category = "kline", + name = "tas_macd_bc_V230803", + template = "{freq}_MACD双分型背驰_BS辅助V230803", + opcode = "TasMacdBcV230803", + param_kind = "TasMacdBcV230803" +)] +pub fn tas_macd_bc_v230803(czsc: &CZSC, _params: &ParamView, cache: &mut TaCache) -> Vec { + let cache_key = "MACD12#26#9"; + update_macd_cache(czsc, cache_key, 12, 26, 9, cache); + let k1 = czsc.freq.to_string(); + let k2 = "MACD双分型背驰"; + let k3 = "BS辅助V230803"; + let mut v1 = "其他"; + + if czsc.bi_list.len() < 3 || czsc.bars_ubi.len() >= 7 { + return make_kline_signal_v1(&k1, k2, k3, v1); + } + + let mc = cache.macd.get(cache_key).unwrap(); + let bar_idx: HashMap = czsc + .bars_raw + .iter() + .enumerate() + .map(|(i, b)| (b.id, i)) + .collect(); + let get_macd = |bar_id: i32| -> Option { bar_idx.get(&bar_id).map(|i| mc.macd[*i]) }; + + let fx_list = czsc.get_fx_list(); + let b1 = czsc.bi_list.last().unwrap(); + if b1.direction == Direction::Up { + let tops: Vec<_> = fx_list + .iter() + .rev() + .take(10) + .filter(|fx| fx.mark == Mark::G) + .collect::>() + .into_iter() + .rev() + .collect(); + if tops.len() >= 2 { + let fx1 = tops[tops.len() - 2]; + let fx2 = tops[tops.len() - 1]; + let id1 = fx1 + .elements + .iter() + .flat_map(|nb| nb.elements.iter()) + .nth(1) + .map(|x| x.id); + let id2 = fx2 + .elements + .iter() + .flat_map(|nb| nb.elements.iter()) + .nth(1) + .map(|x| x.id); + if let (Some(i1), Some(i2)) = (id1, id2) { + if let (Some(macd1), Some(macd2)) = (get_macd(i1), get_macd(i2)) { + if macd1 > macd2 && macd2 > 0.0 { + v1 = "空头"; + } + } + } + } + } else { + let bottoms: Vec<_> = fx_list + .iter() + .rev() + .take(10) + .filter(|fx| fx.mark == Mark::D) + .collect::>() + .into_iter() + .rev() + .collect(); + if bottoms.len() >= 2 { + let fx1 = bottoms[bottoms.len() - 2]; + let fx2 = bottoms[bottoms.len() - 1]; + let id1 = fx1 + .elements + .iter() + .flat_map(|nb| nb.elements.iter()) + .nth(1) + .map(|x| x.id); + let id2 = fx2 + .elements + .iter() + .flat_map(|nb| nb.elements.iter()) + .nth(1) + .map(|x| x.id); + if let (Some(i1), Some(i2)) = (id1, id2) { + if let (Some(macd1), Some(macd2)) = (get_macd(i1), get_macd(i2)) { + if macd1 < macd2 && macd2 < 0.0 { + v1 = "多头"; + } + } + } + } + } + + make_kline_signal_v1(&k1, k2, k3, v1) +} + +/// tas_macd_bc_V240307:MACD 柱背驰计次信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}柱子背驰_BS辅助V240307"` +/// +/// 信号逻辑: +/// 1. 在窗口内识别 MACD 柱局部顶/底; +/// 2. 顶部减弱并满足间隔条件判 `顶背驰`,底部抬高镜像判 `底背驰`; +/// 3. 输出距离最近顶/底的计次数 `第k次`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N20柱子背驰_BS辅助V240307_顶背驰_第2次_任意_0')` +/// - `Signal('60分钟_D1N20柱子背驰_BS辅助V240307_底背驰_第1次_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `n`:观察窗口,默认 `20`。 +/// 对齐说明:峰谷识别、间隔阈值与统计口径对齐 Python `tas_macd_bc_V240307`。 +#[signal( + category = "kline", + name = "tas_macd_bc_V240307", + template = "{freq}_D{di}N{n}柱子背驰_BS辅助V240307", + opcode = "TasMacdBcV240307", + param_kind = "TasMacdBcV240307" +)] +pub fn tas_macd_bc_v240307(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let n = get_usize_param(params, "n", 20); + let cache_key = "MACD12#26#9"; + update_macd_cache(czsc, cache_key, 12, 26, 9, cache); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}N{}柱子背驰", di, n); + let k3 = "BS辅助V240307"; + let mut v1 = "其他"; + let mut v2 = "其他".to_string(); + + if czsc.bars_raw.len() < 7 + n || di == 0 || di > czsc.bars_raw.len() { + return make_kline_signal_v2(&k1, &k2, k3, v1, &v2); + } + let bars = get_sub_elements(&czsc.bars_raw, di, n); + let mc = cache.macd.get(cache_key).unwrap(); + let end = czsc.bars_raw.len() - di + 1; + let start = end - bars.len(); + let macd: Vec = (start..end).map(|i| mc.macd[i]).collect(); + let m_len = macd.len(); + + let gs: Vec = (1..m_len.saturating_sub(1)) + .filter(|&i| macd[i - 1] < macd[i] && macd[i] > macd[i + 1] && macd[i] > 0.0) + .collect(); + let ds: Vec = (1..m_len.saturating_sub(1)) + .filter(|&i| macd[i - 1] > macd[i] && macd[i] < macd[i + 1] && macd[i] < 0.0) + .collect(); + + if macd.last().copied().unwrap_or(0.0) > 0.0 + && gs.len() >= 2 + && macd[*gs.last().unwrap()] < macd[gs[gs.len() - 2]] + && gs[gs.len() - 1] - gs[gs.len() - 2] > 2 + { + let macd_sub = &macd[gs[gs.len() - 2]..]; + let neg_sum: f64 = macd_sub.iter().filter(|x| **x < 0.0).sum::().abs(); + let std_abs = std_abs_series(macd_sub); + if neg_sum < std_abs { + v1 = "顶背驰"; + v2 = format!("第{}次", m_len - gs[gs.len() - 1] - 1); + } + } + if macd.last().copied().unwrap_or(0.0) < 0.0 + && ds.len() >= 2 + && macd[*ds.last().unwrap()] > macd[ds[ds.len() - 2]] + && ds[ds.len() - 1] - ds[ds.len() - 2] > 2 + { + let macd_sub = &macd[ds[ds.len() - 2]..]; + let pos_sum: f64 = macd_sub.iter().filter(|x| **x > 0.0).sum::().abs(); + let std_abs = std_abs_series(macd_sub); + if pos_sum < std_abs { + v1 = "底背驰"; + v2 = format!("第{}次", m_len - ds[ds.len() - 1] - 1); + } + } + + make_kline_signal_v2(&k1, &k2, k3, v1, &v2) +} + +/// tas_macd_dist_V230408:DIF/DEA/MACD等宽分层信号 +/// +/// 参数模板:`"{freq}_{key}分层W{w}N{n}_BS辅助V230408"` +/// +/// 信号逻辑: +/// 1. 获取最近 `w` 根K线的 `DIF/DEA/MACD` 序列; +/// 2. 按等宽区间切分为 `n` 层; +/// 3. 返回最后一个值所在层级 `第{q}层`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_DIF分层W100N10_BS辅助V230408_第3层_任意_任意_0')` +/// - `Signal('60分钟_MACD分层W100N10_BS辅助V230408_第8层_任意_任意_0')` +/// +/// 参数说明: +/// - `key`:`DIF/DEA/MACD`,默认 `DIF`; +/// - `w`:窗口长度,默认 `100`; +/// - `n`:分层数量,默认 `10`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_macd_dist_V230408", + template = "{freq}_{key}分层W{w}N{n}_BS辅助V230408", + opcode = "TasMacdDistV230408", + param_kind = "TasMacdDistV230408" +)] +pub fn tas_macd_dist_v230408(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let n = get_usize_param(params, "n", 10); + let w = get_usize_param(params, "w", 100); + let key = get_str_param(params, "key", "dif").to_uppercase(); + let k1 = czsc.freq.to_string(); + let k2 = format!("{}分层W{}N{}", key, w, n); + let k3 = "BS辅助V230408"; + if !matches!(key.as_str(), "DIF" | "DEA" | "MACD") || czsc.bi_list.len() < 3 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let cache_key = "MACD12#26#9"; + update_macd_cache(czsc, cache_key, 12, 26, 9, cache); + let mc = cache.macd.get(cache_key).unwrap(); + let series = match key.as_str() { + "DIF" => &mc.dif, + "DEA" => &mc.dea, + _ => &mc.macd, + }; + let factors = get_sub_elements(series, 1, w); + let v1 = pd_cut_last_label(factors, n) + .map(|q| format!("第{}层", q)) + .unwrap_or_else(|| "其他".to_string()); + make_kline_signal_v1(&k1, &k2, k3, &v1) +} + +/// tas_macd_dist_V230409:DIF/DEA/MACD远离零轴信号 +/// +/// 参数模板:`"{freq}_{key}远离W{w}N{n}T{t}_BS辅助V230409"` +/// +/// 信号逻辑: +/// 1. 获取最近 `w` 根K线指标值并计算绝对值均值; +/// 2. 若最近 `n` 根中绝对值最大者超过 `mean * t/10`,判定远离零轴; +/// 3. 按最后一个值符号输出 `多头远离/空头远离`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_DIF远离W100N10T20_BS辅助V230409_多头远离_任意_任意_0')` +/// - `Signal('60分钟_DIF远离W100N10T20_BS辅助V230409_空头远离_任意_任意_0')` +/// +/// 参数说明: +/// - `key`:`DIF/DEA/MACD`,默认 `DIF`; +/// - `w`:窗口长度,默认 `100`; +/// - `n`:最近判定窗口,默认 `10`; +/// - `t`:远离阈值倍率(除以10),默认 `20`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_macd_dist_V230409", + template = "{freq}_{key}远离W{w}N{n}T{t}_BS辅助V230409", + opcode = "TasMacdDistV230409", + param_kind = "TasMacdDistV230409" +)] +pub fn tas_macd_dist_v230409(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let n = get_usize_param(params, "n", 10); + let w = get_usize_param(params, "w", 100); + let t = get_usize_param(params, "t", 20) as f64; + let key = get_str_param(params, "key", "dif").to_uppercase(); + let k1 = czsc.freq.to_string(); + let k2 = format!("{}远离W{}N{}T{}", key, w, n, t as usize); + let k3 = "BS辅助V230409"; + let mut v1 = "其他".to_string(); + if !matches!(key.as_str(), "DIF" | "DEA" | "MACD") || czsc.bi_list.len() < 3 { + return make_kline_signal_v1(&k1, &k2, k3, &v1); + } + + let cache_key = "MACD12#26#9"; + update_macd_cache(czsc, cache_key, 12, 26, 9, cache); + let mc = cache.macd.get(cache_key).unwrap(); + let series = match key.as_str() { + "DIF" => &mc.dif, + "DEA" => &mc.dea, + _ => &mc.macd, + }; + let factors = get_sub_elements(series, 1, w); + if !factors.is_empty() && factors.iter().all(|x| x.is_finite()) { + let mean_abs = factors.iter().map(|x| x.abs()).sum::() / factors.len() as f64; + let recent = get_sub_elements(factors, 1, n); + let recent_abs_max = recent.iter().map(|x| x.abs()).fold(0.0f64, f64::max); + if recent_abs_max > mean_abs * t / 10.0 { + let last = *factors.last().unwrap(); + v1 = if last > 0.0 { + "多头远离".to_string() + } else { + "空头远离".to_string() + }; + } + } + make_kline_signal_v1(&k1, &k2, k3, &v1) +} + +/// tas_macd_dist_V230410:DIF/DEA/MACD多空分层信号 +/// +/// 参数模板:`"{freq}_{key}多空分层W{w}N{n}_BS辅助V230410"` +/// +/// 信号逻辑: +/// 1. 取最近 `w` 根指标序列,按最后一值符号判 `多头/空头`; +/// 2. 仅保留同符号样本并等宽分层为 `n` 层; +/// 3. 输出 `多头/空头` 与 `第{q}层`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_DIF多空分层W200N5_BS辅助V230410_多头_第2层_任意_0')` +/// - `Signal('60分钟_DIF多空分层W200N5_BS辅助V230410_空头_第4层_任意_0')` +/// +/// 参数说明: +/// - `key`:`DIF/DEA/MACD`,默认 `DIF`; +/// - `w`:窗口长度,默认 `200`; +/// - `n`:分层数量,默认 `5`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_macd_dist_V230410", + template = "{freq}_{key}多空分层W{w}N{n}_BS辅助V230410", + opcode = "TasMacdDistV230410", + param_kind = "TasMacdDistV230410" +)] +pub fn tas_macd_dist_v230410(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let n = get_usize_param(params, "n", 5); + let w = get_usize_param(params, "w", 200); + let key = get_str_param(params, "key", "dif").to_uppercase(); + let k1 = czsc.freq.to_string(); + let k2 = format!("{}多空分层W{}N{}", key, w, n); + let k3 = "BS辅助V230410"; + if !matches!(key.as_str(), "DIF" | "DEA" | "MACD") || czsc.bi_list.len() < 3 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let cache_key = "MACD12#26#9"; + update_macd_cache(czsc, cache_key, 12, 26, 9, cache); + let mc = cache.macd.get(cache_key).unwrap(); + let series = match key.as_str() { + "DIF" => &mc.dif, + "DEA" => &mc.dea, + _ => &mc.macd, + }; + let factors_all = get_sub_elements(series, 1, w); + if factors_all.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let last = *factors_all.last().unwrap(); + let v1 = if last > 0.0 { "多头" } else { "空头" }; + let factors: Vec = if v1 == "多头" { + factors_all.iter().copied().filter(|x| *x > 0.0).collect() + } else { + factors_all.iter().copied().filter(|x| *x < 0.0).collect() + }; + let v2 = pd_cut_last_label(&factors, n) + .map(|q| format!("第{}层", q)) + .unwrap_or_else(|| "任意".to_string()); + make_kline_signal_v2(&k1, &k2, k3, v1, &v2) +} + +/// tas_macd_first_bs_V221201:MACD一买一卖辅助信号 +/// +/// 参数模板:`"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS1辅助V221201"` +/// +/// 信号逻辑: +/// 1. 在近 300 根内统计 DIF/DEA 金叉死叉序列; +/// 2. 满足特定零轴位置与节奏条件时,给出 `一买` 或 `一卖`; +/// 3. 否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1MACD12#26#9_BS1辅助V221201_一买_任意_任意_0')` +/// - `Signal('60分钟_D1MACD12#26#9_BS1辅助V221201_一卖_任意_任意_0')` +/// - `Signal('60分钟_D1MACD12#26#9_BS1辅助V221201_其他_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `fastperiod/slowperiod/signalperiod`:MACD参数,默认 `12/26/9`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_macd_first_bs_V221201", + template = "{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS1辅助V221201", + opcode = "TasMacdFirstBsV221201", + param_kind = "TasMacdFirstBsV221201" +)] +pub fn tas_macd_first_bs_v221201( + czsc: &CZSC, + params: &ParamView, + cache: &mut TaCache, +) -> Vec { + let di = get_usize_param(params, "di", 1); + let fastperiod = get_usize_param(params, "fastperiod", 12); + let slowperiod = get_usize_param(params, "slowperiod", 26); + let signalperiod = get_usize_param(params, "signalperiod", 9); + + let cache_key = format!("MACD{}#{}#{}", fastperiod, slowperiod, signalperiod); + crate::utils::ta::update_macd_cache( + czsc, + &cache_key, + fastperiod, + slowperiod, + signalperiod, + cache, + ); + + let macd_cache = cache.macd.get(&cache_key).unwrap(); + + // 对齐 Python: bars = get_sub_elements(c.bars_raw, di=di, n=300) + // Python 条件: if len(bars) >= 100 + let mut v1 = "其他"; + let bars = get_sub_elements(&czsc.bars_raw, di, 300); + if bars.len() >= 100 { + let dif = get_sub_elements(&macd_cache.dif, di, 300); + let dea = get_sub_elements(&macd_cache.dea, di, 300); + let macd = get_sub_elements(&macd_cache.macd, di, 300); + + let cross = fast_slow_cross(dif, dea); + let up: Vec<_> = cross + .iter() + .filter(|x| x["类型"] == 1.0 && x["距离"] > 5.0) + .collect(); + let dn: Vec<_> = cross + .iter() + .filter(|x| x["类型"] == -1.0 && x["距离"] > 5.0) + .collect(); + + // 对齐 Python: 各条件独立检查长度 + let b1_con1 = cross.len() > 3 + && cross.last().unwrap()["类型"] == -1.0 + && cross.last().unwrap()["慢线"] < 0.0; + let b1_con2 = + dn.len() > 3 && dn[dn.len() - 2]["慢线"] < 0.0 && dn[dn.len() - 3]["慢线"] < 0.0; + let b1_con3 = macd.len() > 10 && macd[macd.len() - 1] > macd[macd.len() - 2]; + + if b1_con1 && b1_con2 && b1_con3 { + v1 = "一买"; + } + + let s1_con1 = cross.len() > 3 + && cross.last().unwrap()["类型"] == 1.0 + && cross.last().unwrap()["慢线"] > 0.0; + let s1_con2 = + up.len() > 3 && up[up.len() - 2]["慢线"] > 0.0 && up[up.len() - 3]["慢线"] > 0.0; + let s1_con3 = macd.len() > 10 && macd[macd.len() - 1] < macd[macd.len() - 2]; + + if s1_con1 && s1_con2 && s1_con3 { + v1 = "一卖"; + } + } + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}MACD{}#{}#{}", di, fastperiod, slowperiod, signalperiod); + let k3 = "BS1辅助V221201"; + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_macd_first_bs_V221216:MACD 第一买卖点(扩展版) +/// +/// 参数模板:`"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS1辅助V221216"` +/// +/// 信号逻辑: +/// 1. 以最近 10 根与前 90 根做高低点对比(新高/新低); +/// 2. 结合最近交叉类型、零轴位置与 MACD 方向判断 `一买/一卖`; +/// 3. `v2` 输出最后一次交叉类型(`金叉/死叉`)。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1MACD12#26#9_BS1辅助V221216_一买_死叉_任意_0')` +/// - `Signal('60分钟_D1MACD12#26#9_BS1辅助V221216_一卖_金叉_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `fastperiod/slowperiod/signalperiod`:MACD 参数,默认 `12/26/9`。 +/// 对齐说明:分支条件、`or` 组合与 `v2` 输出语义对齐 Python `tas_macd_first_bs_V221216`。 +#[signal( + category = "kline", + name = "tas_macd_first_bs_V221216", + template = "{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS1辅助V221216", + opcode = "TasMacdFirstBsV221216", + param_kind = "TasMacdFirstBsV221216" +)] +pub fn tas_macd_first_bs_v221216( + czsc: &CZSC, + params: &ParamView, + cache: &mut TaCache, +) -> Vec { + let di = get_usize_param(params, "di", 1); + let fastperiod = get_usize_param(params, "fastperiod", 12); + let slowperiod = get_usize_param(params, "slowperiod", 26); + let signalperiod = get_usize_param(params, "signalperiod", 9); + let cache_key = format!("MACD{}#{}#{}", fastperiod, slowperiod, signalperiod); + update_macd_cache( + czsc, + &cache_key, + fastperiod, + slowperiod, + signalperiod, + cache, + ); + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}MACD{}#{}#{}", di, fastperiod, slowperiod, signalperiod); + let k3 = "BS1辅助V221216"; + let mut v1 = "其他"; + let mut v2 = "任意"; + + let bars = get_sub_elements(&czsc.bars_raw, di, 300); + if bars.len() >= 100 { + let mc = cache.macd.get(&cache_key).unwrap(); + let dif = get_sub_elements(&mc.dif, di, 300); + let dea = get_sub_elements(&mc.dea, di, 300); + let macd = get_sub_elements(&mc.macd, di, 300); + let n_bars = &bars[bars.len() - 10..]; + let m_bars = &bars[bars.len() - 100..bars.len() - 10]; + let high_n = n_bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let low_n = n_bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let high_m = m_bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let low_m = m_bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + + let cross = fast_slow_cross_ext(dif, dea); + let up: Vec<_> = cross + .iter() + .filter(|x| x.kind > 0 && x.distance > 5) + .collect(); + let dn: Vec<_> = cross + .iter() + .filter(|x| x.kind < 0 && x.distance > 5) + .collect(); + if let Some(last) = cross.last() { + let b1_con1a = cross.len() > 3 && last.kind < 0 && last.slow < 0.0; + let b1_con1b = + cross.len() > 3 && last.kind > 0 && !dn.is_empty() && dn[dn.len() - 1].slow < 0.0; + let b1_con2 = + dn.len() > 3 && dn[dn.len() - 2].slow < 0.0 && dn[dn.len() - 3].slow < 0.0; + let b1_con3 = macd.len() > 10 && macd[macd.len() - 1] > macd[macd.len() - 2]; + if low_n < low_m && (b1_con1a || b1_con1b) && b1_con2 && b1_con3 { + v1 = "一买"; + } + + let s1_con1a = cross.len() > 3 && last.kind > 0 && last.slow > 0.0; + let s1_con1b = + cross.len() > 3 && last.kind < 0 && !up.is_empty() && up[up.len() - 1].slow > 0.0; + let s1_con2 = + up.len() > 3 && up[up.len() - 2].slow > 0.0 && up[up.len() - 3].slow > 0.0; + let s1_con3 = macd.len() > 10 && macd[macd.len() - 1] < macd[macd.len() - 2]; + if high_n > high_m && (s1_con1a || s1_con1b) && s1_con2 && s1_con3 { + v1 = "一卖"; + } + v2 = if last.kind > 0 { "金叉" } else { "死叉" }; + } + } + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// tas_macd_second_bs_V221201:MACD 第二买卖点 +/// +/// 参数模板:`"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS2辅助V221201"` +/// +/// 信号逻辑: +/// 1. 在近 350 根(去掉最早 50 根)统计交叉序列; +/// 2. 结合最近交叉距今、零轴位置与 MACD 方向判 `二买/二卖`; +/// 3. `v2` 返回最后交叉类型。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1MACD12#26#9_BS2辅助V221201_二买_死叉_任意_0')` +/// - `Signal('60分钟_D1MACD12#26#9_BS2辅助V221201_二卖_金叉_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `fastperiod/slowperiod/signalperiod`:MACD 参数,默认 `12/26/9`。 +/// 对齐说明:`距今` 条件与零轴判定对齐 Python `tas_macd_second_bs_V221201`。 +#[signal( + category = "kline", + name = "tas_macd_second_bs_V221201", + template = "{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS2辅助V221201", + opcode = "TasMacdSecondBsV221201", + param_kind = "TasMacdSecondBsV221201" +)] +pub fn tas_macd_second_bs_v221201( + czsc: &CZSC, + params: &ParamView, + cache: &mut TaCache, +) -> Vec { + let di = get_usize_param(params, "di", 1); + let fastperiod = get_usize_param(params, "fastperiod", 12); + let slowperiod = get_usize_param(params, "slowperiod", 26); + let signalperiod = get_usize_param(params, "signalperiod", 9); + let cache_key = format!("MACD{}#{}#{}", fastperiod, slowperiod, signalperiod); + update_macd_cache( + czsc, + &cache_key, + fastperiod, + slowperiod, + signalperiod, + cache, + ); + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}MACD{}#{}#{}", di, fastperiod, slowperiod, signalperiod); + let k3 = "BS2辅助V221201"; + let mut v1 = "其他"; + let mut v2 = "任意"; + + let raw = get_sub_elements(&czsc.bars_raw, di, 350); + let bars: Vec<_> = if raw.len() > 50 { + raw[50..].to_vec() + } else { + Vec::new() + }; + if bars.len() >= 100 { + let mc = cache.macd.get(&cache_key).unwrap(); + let end = czsc.bars_raw.len() - di + 1; + let start = end - get_sub_elements(&czsc.bars_raw, di, 350).len() + 50; + let dif: Vec = (0..bars.len()).map(|i| mc.dif[start + i]).collect(); + let dea: Vec = (0..bars.len()).map(|i| mc.dea[start + i]).collect(); + let macd: Vec = (0..bars.len()).map(|i| mc.macd[start + i]).collect(); + + let cross = fast_slow_cross_ext(&dif, &dea); + let up: Vec<_> = cross + .iter() + .filter(|x| x.kind > 0 && x.distance > 5) + .collect(); + let dn: Vec<_> = cross + .iter() + .filter(|x| x.kind < 0 && x.distance > 5) + .collect(); + if let Some(last) = cross.last() { + let b2_con1a = cross.len() > 3 && last.kind < 0 && last.slow > 0.0 && last.to_now > 5; + let b2_con1b = cross.len() > 3 + && last.kind > 0 + && !dn.is_empty() + && dn[dn.len() - 1].slow > 0.0 + && last.to_now < 5; + let b2_con2 = + dn.len() > 4 && dn[dn.len() - 3].slow < 0.0 && dn[dn.len() - 2].slow < 0.0; + let b2_con3 = macd.len() > 10 && macd[macd.len() - 1] > macd[macd.len() - 2]; + if (b2_con1a || b2_con1b) && b2_con2 && b2_con3 { + v1 = "二买"; + } + + let s2_con1a = cross.len() > 3 && last.kind > 0 && last.slow < 0.0 && last.to_now > 5; + let s2_con1b = cross.len() > 3 + && last.kind < 0 + && !up.is_empty() + && up[up.len() - 1].slow < 0.0 + && last.to_now < 5; + let s2_con2 = + up.len() > 4 && up[up.len() - 3].slow > 0.0 && up[up.len() - 2].slow > 0.0; + let s2_con3 = macd.len() > 10 && macd[macd.len() - 1] < macd[macd.len() - 2]; + if (s2_con1a || s2_con1b) && s2_con2 && s2_con3 { + v1 = "二卖"; + } + v2 = if last.kind > 0 { "金叉" } else { "死叉" }; + } + } + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// tas_macd_xt_V221208:MACD 柱形态信号 +/// +/// 参数模板:`"{freq}_D{di}K#MACD{fastperiod}#{slowperiod}#{signalperiod}形态_BS辅助V221208"` +/// +/// 信号逻辑: +/// 1. 读取最近 5 根 MACD 柱; +/// 2. 按柱子相对大小关系判定 `逼空棒/杀多棒/绿抽脚/红缩头`; +/// 3. 按跨零关系判定 `空翻多/多翻空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1K#MACD12#26#9形态_BS辅助V221208_逼空棒_任意_任意_0')` +/// - `Signal('60分钟_D1K#MACD12#26#9形态_BS辅助V221208_多翻空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `fastperiod/slowperiod/signalperiod`:MACD 参数,默认 `12/26/9`。 +/// 对齐说明:形态分支顺序与 Python `tas_macd_xt_V221208` 保持一致。 +#[signal( + category = "kline", + name = "tas_macd_xt_V221208", + template = "{freq}_D{di}K#MACD{fastperiod}#{slowperiod}#{signalperiod}形态_BS辅助V221208", + opcode = "TasMacdXtV221208", + param_kind = "TasMacdXtV221208" +)] +pub fn tas_macd_xt_v221208(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let fastperiod = get_usize_param(params, "fastperiod", 12); + let slowperiod = get_usize_param(params, "slowperiod", 26); + let signalperiod = get_usize_param(params, "signalperiod", 9); + let cache_key = format!("MACD{}#{}#{}", fastperiod, slowperiod, signalperiod); + update_macd_cache( + czsc, + &cache_key, + fastperiod, + slowperiod, + signalperiod, + cache, + ); + let k1 = czsc.freq.to_string(); + let k2 = format!( + "D{}K#MACD{}#{}#{}形态", + di, fastperiod, slowperiod, signalperiod + ); + let k3 = "BS辅助V221208"; + let mut v1 = "其他"; + + let mc = cache.macd.get(&cache_key).unwrap(); + let macd = get_sub_elements(&mc.macd, di, 5); + if macd.len() == 5 { + let min_m = macd.iter().copied().fold(f64::INFINITY, f64::min); + let max_m = macd.iter().copied().fold(f64::NEG_INFINITY, f64::max); + if min_m > 0.0 && macd[4] > macd[3] && macd[3] < macd[1] { + v1 = "逼空棒"; + } else if max_m < 0.0 && macd[4] < macd[3] && macd[3] > macd[1] { + v1 = "杀多棒"; + } else if max_m < 0.0 && macd[4] > macd[3] && macd[3] < macd[1] { + v1 = "绿抽脚"; + } else if min_m > 0.0 && macd[4] < macd[3] && macd[3] > macd[1] { + v1 = "红缩头"; + } else if macd[4] > 0.0 && macd[2] < 0.0 { + v1 = "空翻多"; + } else if macd[2] > 0.0 && macd[4] < 0.0 { + v1 = "多翻空"; + } + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_macd_bs1_V230312:MACD 辅助一买一卖(笔结构) +/// +/// 参数模板:`"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS1辅助V230312"` +/// +/// 信号逻辑: +/// 1. 最近 7 笔内,末笔创新低并满足三卖结构且末分型 MACD 抬升,判 `看多`; +/// 2. 镜像条件(创新高 + 三买结构 + MACD 走弱)判 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1MACD12#26#9_BS1辅助V230312_看多_任意_任意_0')` +/// - `Signal('60分钟_D1MACD12#26#9_BS1辅助V230312_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 笔,默认 `1`; +/// - `fastperiod/slowperiod/signalperiod`:MACD 参数,默认 `12/26/9`。 +/// 对齐说明:笔结构约束与末分型 MACD 比较逻辑对齐 Python `tas_macd_bs1_V230312`。 +#[signal( + category = "kline", + name = "tas_macd_bs1_V230312", + template = "{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS1辅助V230312", + opcode = "TasMacdBs1V230312", + param_kind = "TasMacdBs1V230312" +)] +pub fn tas_macd_bs1_v230312(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let fastperiod = get_usize_param(params, "fastperiod", 12); + let slowperiod = get_usize_param(params, "slowperiod", 26); + let signalperiod = get_usize_param(params, "signalperiod", 9); + let cache_key = format!("MACD{}#{}#{}", fastperiod, slowperiod, signalperiod); + update_macd_cache( + czsc, + &cache_key, + fastperiod, + slowperiod, + signalperiod, + cache, + ); + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}MACD{}#{}#{}", di, fastperiod, slowperiod, signalperiod); + let k3 = "BS1辅助V230312"; + let mut v1 = "其他"; + + let bis = get_sub_elements(&czsc.bi_list, di, 7); + if bis.len() >= 7 { + let mc = cache.macd.get(&cache_key).unwrap(); + let id_to_idx = bar_index_map(czsc); + let mut snapshot_overrides: HashMap = HashMap::new(); + let last_bi = &bis[bis.len() - 1]; + let last_fx = &last_bi.fx_b; + let last_raw: Vec<_> = last_fx + .elements + .iter() + .flat_map(|nb| nb.elements.iter()) + .collect(); + if !last_raw.is_empty() { + let first_macd = macd_snapshot_field_value( + czsc, + mc, + &id_to_idx, + last_raw.first().unwrap(), + fastperiod, + slowperiod, + signalperiod, + MacdField::Macd, + &mut snapshot_overrides, + ) + .unwrap_or(f64::NAN); + let last_macd = macd_snapshot_field_value( + czsc, + mc, + &id_to_idx, + last_raw.last().unwrap(), + fastperiod, + slowperiod, + signalperiod, + MacdField::Macd, + &mut snapshot_overrides, + ) + .unwrap_or(f64::NAN); + + let up_lows: Vec = bis[..bis.len() - 1] + .iter() + .filter(|x| x.direction == Direction::Up) + .map(|x| x.get_low()) + .collect(); + let down_highs: Vec = bis[..bis.len() - 1] + .iter() + .filter(|x| x.direction == Direction::Down) + .map(|x| x.get_high()) + .collect(); + let min_low = bis + .iter() + .map(|x| x.get_low()) + .fold(f64::INFINITY, f64::min); + let max_high = bis + .iter() + .map(|x| x.get_high()) + .fold(f64::NEG_INFINITY, f64::max); + if !up_lows.is_empty() + && last_bi.direction == Direction::Down + && last_bi.get_low() == min_low + && last_bi.get_high() < up_lows.iter().copied().fold(f64::NEG_INFINITY, f64::max) + && last_macd > first_macd + { + v1 = "看多"; + } + if !down_highs.is_empty() + && last_bi.direction == Direction::Up + && last_bi.get_high() == max_high + && last_bi.get_low() > down_highs.iter().copied().fold(f64::INFINITY, f64::min) + && last_macd < first_macd + { + v1 = "看空"; + } + } + } + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_macd_bs1_V230313:MACD 红绿柱第一买卖点 +/// +/// 参数模板:`"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS1辅助V230313"` +/// +/// 信号逻辑: +/// 1. 近 10 与前 90 根对比新高新低; +/// 2. 用交叉面积递减/递增与 MACD 方向判 `一买/一卖`; +/// 3. `v2` 返回最后交叉类型。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1MACD12#26#9_BS1辅助V230313_一买_死叉_任意_0')` +/// - `Signal('60分钟_D1MACD12#26#9_BS1辅助V230313_一卖_金叉_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `fastperiod/slowperiod/signalperiod`:MACD 参数,默认 `12/26/9`。 +/// 对齐说明:面积比较与条件优先级(`and/or`)按 Python `tas_macd_bs1_V230313` 对齐。 +#[signal( + category = "kline", + name = "tas_macd_bs1_V230313", + template = "{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS1辅助V230313", + opcode = "TasMacdBs1V230313", + param_kind = "TasMacdBs1V230313" +)] +pub fn tas_macd_bs1_v230313(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let fastperiod = get_usize_param(params, "fastperiod", 12); + let slowperiod = get_usize_param(params, "slowperiod", 26); + let signalperiod = get_usize_param(params, "signalperiod", 9); + let cache_key = format!("MACD{}#{}#{}", fastperiod, slowperiod, signalperiod); + update_macd_cache( + czsc, + &cache_key, + fastperiod, + slowperiod, + signalperiod, + cache, + ); + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}MACD{}#{}#{}", di, fastperiod, slowperiod, signalperiod); + let k3 = "BS1辅助V230313"; + let mut v1 = "其他"; + let mut v2 = "任意"; + + let bars = get_sub_elements(&czsc.bars_raw, di, 300); + if bars.len() > 100 { + let mc = cache.macd.get(&cache_key).unwrap(); + let dif = get_sub_elements(&mc.dif, di, 300); + let dea = get_sub_elements(&mc.dea, di, 300); + let macd = get_sub_elements(&mc.macd, di, 300); + + let n_bars = &bars[bars.len() - 10..]; + let m_bars = &bars[bars.len() - 100..bars.len() - 10]; + let high_n = n_bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let low_n = n_bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let high_m = m_bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let low_m = m_bars.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + + let cross = fast_slow_cross_ext(dif, dea); + let up: Vec<_> = cross.iter().filter(|x| x.kind > 0).collect(); + let dn: Vec<_> = cross.iter().filter(|x| x.kind < 0).collect(); + if cross.len() >= 3 { + let last = cross.last().unwrap(); + let c2 = &cross[cross.len() - 2]; + let c3 = &cross[cross.len() - 3]; + let b1_con1a = (cross.len() > 3 && last.kind < 0 && last.area < c2.area) + || (cross.len() > 3 && last.area < c3.area); + let b1_con1b = (cross.len() > 3 && last.kind > 0 && last.area > c2.area) + || (cross.len() > 3 && last.area < c3.area); + let b1_con2 = dn.len() > 3 && dn[dn.len() - 2].area < dn[dn.len() - 3].area; + let b1_con3 = macd.len() > 10 && macd[macd.len() - 1] > macd[macd.len() - 2]; + if low_n < low_m && (b1_con1a || b1_con1b) && b1_con2 && b1_con3 { + v1 = "一买"; + } + + let s1_con1a = (cross.len() > 3 && last.kind > 0 && last.area > c2.area) + || (cross.len() > 3 && last.area > c3.area); + let s1_con1b = (cross.len() > 3 && last.kind < 0 && last.area < c2.area) + || (cross.len() > 3 && last.area < c3.area); + let s1_con2 = up.len() > 3 && up[up.len() - 2].area > up[up.len() - 3].area; + let s1_con3 = macd.len() > 10 && macd[macd.len() - 1] < macd[macd.len() - 2]; + if high_n > high_m && (s1_con1a || s1_con1b) && s1_con2 && s1_con3 { + v1 = "一卖"; + } + v2 = if last.kind > 0 { "金叉" } else { "死叉" }; + } + } + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// tas_boll_power_V221112:BOLL强弱分层信号 +/// +/// 参数模板:`"{freq}_D{di}BOLL{timeperiod}_强弱V221112"` +/// +/// 信号逻辑: +/// 1. 计算 BOLL 中线与标准差; +/// 2. 先以 `close` 相对中线判断 `多头/空头`; +/// 3. 再按偏离程度分层 `弱势/强势/超强/极强`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1BOLL20_强弱V221112_多头_强势_任意_0')` +/// - `Signal('60分钟_D1BOLL20_强弱V221112_空头_超强_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `timeperiod`:BOLL周期,默认 `20`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_boll_power_V221112", + template = "{freq}_D{di}BOLL{timeperiod}_强弱V221112", + opcode = "TasBollPowerV221112", + param_kind = "TasBollPowerV221112" +)] +pub fn tas_boll_power_v221112(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let timeperiod = get_usize_param(params, "timeperiod", 20); + + // Python 使用 dev_seq = (1.382, 2, 2.764) 计算 BOLL 上下轨 + // 我们需要用 2 倍标准差的缓存来推导 std_dev,然后乘以正确的系数 + let cache_key = format!("BOLL{}#2.0", timeperiod); + update_boll_cache(czsc, &cache_key, timeperiod, 2.0, cache); + + let boll_cache = cache.boll.get(&cache_key).unwrap(); + let bars = &czsc.bars_raw; + let mut v1 = "其他"; + let mut v2 = "其他"; + + if bars.len() >= di + 20 { + let latest = get_sub_elements(bars, di, 1); + if !latest.is_empty() { + let idx = bars.len() - di; + let latest_c = latest[0].close; + let m = boll_cache.mid[idx]; + // upper = mid + 2*std_dev → std_dev = (upper - mid) / 2.0 + let std_dev = (boll_cache.upper[idx] - m) / 2.0; + + // 对齐 Python: dev_seq = (1.382, 2, 2.764) + let u1 = m + 1.382 * std_dev; + let u2 = m + 2.0 * std_dev; + let u3 = m + 2.764 * std_dev; + + let l1 = m - 1.382 * std_dev; + let l2 = m - 2.0 * std_dev; + let l3 = m - 2.764 * std_dev; + + v1 = if latest_c >= m { "多头" } else { "空头" }; + v2 = if latest_c >= u3 || latest_c <= l3 { + "极强" + } else if latest_c >= u2 || latest_c <= l2 { + "超强" + } else if latest_c >= u1 || latest_c <= l1 { + "强势" + } else { + "弱势" + }; + } + } + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}BOLL{}", di, timeperiod); + let k3 = "强弱V221112"; + + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// tas_boll_bc_V221118:BOLL背驰辅助信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}M{m}L{line}#BOLL{timeperiod}_背驰V221118"` +/// +/// 信号逻辑: +/// 1. 对比近端 `n` 根与参考段 `m` 根的价格极值; +/// 2. 结合 BOLL 指定轨道 `line` 的上下突破次数; +/// 3. 满足低点背驰给 `一买`,满足高点背驰给 `一卖`,否则 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N3M10L3#BOLL20_背驰V221118_一买_任意_任意_0')` +/// - `Signal('60分钟_D1N3M10L3#BOLL20_背驰V221118_一卖_任意_任意_0')` +/// - `Signal('60分钟_D1N3M10L3#BOLL20_背驰V221118_其他_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `n/m`:近端与参考窗口长度,默认 `3/10`; +/// - `line`:轨道层级,默认 `3`; +/// - `timeperiod`:BOLL周期,默认 `20`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_boll_bc_V221118", + template = "{freq}_D{di}N{n}M{m}L{line}#BOLL{timeperiod}_背驰V221118", + opcode = "TasBollBcV221118", + param_kind = "TasBollBcV221118" +)] +pub fn tas_boll_bc_v221118(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let n = get_usize_param(params, "n", 3); + let m = get_usize_param(params, "m", 10); + let line = get_usize_param(params, "line", 3); + let timeperiod = get_usize_param(params, "timeperiod", 20); + + let cache_key = format!("BOLL{}#2.0", timeperiod); + update_boll_cache(czsc, &cache_key, timeperiod, 2.0, cache); + + let boll_cache = cache.boll.get(&cache_key).unwrap(); + let bn_bars = get_sub_elements(&czsc.bars_raw, di, n); + let bm_bars = get_sub_elements(&czsc.bars_raw, di, m); + let mut v1 = "其他"; + + let dev = match line { + 1 => 1.382, + 2 => 2.0, + 3 => 2.764, + _ => line as f64, + }; + let get_line_val = |idx: usize, is_upper: bool| -> f64 { + let mid = boll_cache.mid[idx]; + let std_dev = (boll_cache.upper[idx] - mid) / 2.0; + if is_upper { + mid + dev * std_dev + } else { + mid - dev * std_dev + } + }; + + if !bn_bars.is_empty() && !bm_bars.is_empty() && czsc.bars_raw.len() >= di { + let min_low_n = bn_bars.iter().map(|b| b.low).fold(f64::INFINITY, f64::min); + let min_low_m = bm_bars.iter().map(|b| b.low).fold(f64::INFINITY, f64::min); + + let total_len = czsc.bars_raw.len(); + let bm_start_idx = total_len - di + 1 - bm_bars.len(); + let bn_start_idx = total_len - di + 1 - bn_bars.len(); + + let mut d_c2_count = 0; + for (offset, bar) in bm_bars.iter().enumerate() { + let idx = bm_start_idx + offset; + if bar.close < get_line_val(idx, false) { + d_c2_count += 1; + } + } + let mut d_c3_count = 0; + for (offset, bar) in bn_bars.iter().enumerate() { + let idx = bn_start_idx + offset; + if bar.close < get_line_val(idx, false) { + d_c3_count += 1; + } + } + + let d_c1 = min_low_n <= min_low_m; + let d_c2 = d_c2_count > 1; + let d_c3 = d_c3_count == 0; + + let max_high_n = bn_bars + .iter() + .map(|b| b.high) + .fold(f64::NEG_INFINITY, f64::max); + let max_high_m = bm_bars + .iter() + .map(|b| b.high) + .fold(f64::NEG_INFINITY, f64::max); + + let mut g_c2_count = 0; + for (offset, bar) in bm_bars.iter().enumerate() { + let idx = bm_start_idx + offset; + if bar.close > get_line_val(idx, true) { + g_c2_count += 1; + } + } + let mut g_c3_count = 0; + for (offset, bar) in bn_bars.iter().enumerate() { + let idx = bn_start_idx + offset; + if bar.close > get_line_val(idx, true) { + g_c3_count += 1; + } + } + + let g_c1 = max_high_n == max_high_m; + let g_c2 = g_c2_count > 1; + let g_c3 = g_c3_count == 0; + + v1 = if d_c1 && d_c2 && d_c3 { + "一买" + } else if g_c1 && g_c2 && g_c3 { + "一卖" + } else { + "其他" + }; + } + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}N{}M{}L{}#BOLL{}", di, n, m, line, timeperiod); + let k3 = "背驰V221118"; + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_kdj_base_V221101:KDJ基础辅助信号 +/// +/// 参数模板:`"{freq}_D{di}K#KDJ{fastk_period}#{slowk_period}#{slowd_period}_KDJ辅助V221101"` +/// +/// 信号逻辑: +/// 1. 计算 K、D、J 三序列; +/// 2. `J > K > D` 判定 `多头`,`J < K < D` 判定 `空头`,否则 `其他`; +/// 3. `J_now >= J_prev` 判定 `向上`,否则 `向下`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1K#KDJ9#3#3_KDJ辅助V221101_多头_向上_任意_0')` +/// - `Signal('60分钟_D1K#KDJ9#3#3_KDJ辅助V221101_空头_向下_任意_0')` +/// - `Signal('60分钟_D1K#KDJ9#3#3_KDJ辅助V221101_其他_向下_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `fastk_period/slowk_period/slowd_period`:KDJ参数,默认 `9/3/3`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_kdj_base_V221101", + template = "{freq}_D{di}K#KDJ{fastk_period}#{slowk_period}#{slowd_period}_KDJ辅助V221101", + opcode = "TasKdjBaseV221101", + param_kind = "TasKdjBaseV221101" +)] +pub fn tas_kdj_base_v221101(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let fastk_period = get_usize_param(params, "fastk_period", 9); + let slowk_period = get_usize_param(params, "slowk_period", 3); + let slowd_period = get_usize_param(params, "slowd_period", 3); + + let cache_key = format!("KDJ{}#{}#{}", fastk_period, slowk_period, slowd_period); + crate::utils::ta::update_kdj_cache( + czsc, + &cache_key, + fastk_period, + slowk_period, + slowd_period, + cache, + ); + + let kdj_cache = cache.kdj.get(&cache_key).unwrap(); + + let k = get_sub_elements(&kdj_cache.k, di, 3); + let d = get_sub_elements(&kdj_cache.d, di, 3); + let j = get_sub_elements(&kdj_cache.j, di, 3); + if k.len() < 2 || d.len() < 2 || j.len() < 2 { + return Vec::new(); + } + + let k_last = *k.last().unwrap(); + let d_last = *d.last().unwrap(); + let j_last = *j.last().unwrap(); + let j_prev = j[j.len() - 2]; + + let v1 = if j_last > k_last && k_last > d_last { + "多头" + } else if j_last < k_last && k_last < d_last { + "空头" + } else { + "其他" + }; + + let v2 = if j_last >= j_prev { "向上" } else { "向下" }; + + let k1 = czsc.freq.to_string(); + let k2 = format!( + "D{}K#KDJ{}#{}#{}", + di, fastk_period, slowk_period, slowd_period + ); + let k3 = "KDJ辅助V221101"; + + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// tas_rsi_base_V230227:RSI超买超卖与方向信号 +/// +/// 参数模板:`"{freq}_D{di}T{th}RSI{timeperiod}_RSI辅助V230227"` +/// +/// 信号逻辑: +/// 1. 使用 `n` 计算 RSI(与 Python 保持一致); +/// 2. `rsi <= th` 判 `超卖`,`rsi >= 100-th` 判 `超买`,否则 `其他`; +/// 3. `rsi_now >= rsi_prev` 判 `向上`,否则 `向下`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1T20RSI6_RSI辅助V230227_超卖_向上_任意_0')` +/// - `Signal('60分钟_D1T20RSI6_RSI辅助V230227_超买_向下_任意_0')` +/// - `Signal('60分钟_D1T20RSI6_RSI辅助V230227_其他_向上_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `n`:RSI 实际计算周期,默认 `6`; +/// - `timeperiod`:仅用于信号键展示,默认 `6`; +/// - `th`:超买超卖阈值,默认 `20`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_rsi_base_V230227", + template = "{freq}_D{di}T{th}RSI{timeperiod}_RSI辅助V230227", + opcode = "TasRsiBaseV230227", + param_kind = "TasRsiBaseV230227" +)] +pub fn tas_rsi_base_v230227(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + // 对齐 Python: n 用于实际 RSI 计算,timeperiod 仅用于信号 key + let n = get_usize_param(params, "n", 6); + let timeperiod = get_usize_param(params, "timeperiod", 6); + let th = get_usize_param(params, "th", 20); + + // 实际 RSI 计算用 n + let cache_key = format!("RSI{}", n); + crate::utils::ta::update_rsi_cache(czsc, &cache_key, n, cache); + + let series = cache.series.get(&cache_key).unwrap(); + let sub = get_sub_elements(series, di, 2); + if sub.len() < 2 { + return Vec::new(); + } + + let rsi_prev = sub[sub.len() - 2]; + let rsi = sub[sub.len() - 1]; + + let v1 = if rsi <= th as f64 { + "超卖" + } else if rsi >= 100.0 - th as f64 { + "超买" + } else { + "其他" + }; + + // 与 Python 保持一致:方向比较不加额外浮点容差 + let v2 = if rsi >= rsi_prev { "向上" } else { "向下" }; + + let k1 = czsc.freq.to_string(); + // 信号 key 用 timeperiod(对齐 Python 标签行为) + let k2 = format!("D{}T{}RSI{}", di, th, timeperiod); + let k3 = "RSI辅助V230227"; + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// tas_double_ma_V221203:双均线多空强弱信号 +/// +/// 参数模板:`"{freq}_D{di}T{th}#{ma_type}#{timeperiod1}#{timeperiod2}_JX辅助V221203"` +/// +/// 信号逻辑: +/// 1. 计算两条均线 `ma1/ma2`; +/// 2. `ma1 >= ma2` 判定 `多头`,否则 `空头`; +/// 3. 两线相对距离(BP)超过 `th` 判 `强势`,否则 `弱势`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1T100#SMA#5#10_JX辅助V221203_多头_强势_任意_0')` +/// - `Signal('60分钟_D1T80#EMA#12#26_JX辅助V221203_空头_弱势_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `th`:强弱阈值(BP),默认 `100`; +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod1/timeperiod2`:两条均线周期,默认 `5/10`。 +/// 对齐说明:与 Python 同名函数逻辑与边界条件保持一致。 +#[signal( + category = "kline", + name = "tas_double_ma_V221203", + template = "{freq}_D{di}T{th}#{ma_type}#{timeperiod1}#{timeperiod2}_JX辅助V221203", + opcode = "TasDoubleMaV221203", + param_kind = "TasDoubleMaV221203" +)] +pub fn tas_double_ma_v221203(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let th = get_usize_param(params, "th", 100); + let ma_type = get_str_param(params, "ma_type", "SMA"); + let timeperiod1 = get_usize_param(params, "timeperiod1", 5); + let timeperiod2 = get_usize_param(params, "timeperiod2", 10); + + let cache_key1 = format!("{}_{}_{}", czsc.freq, ma_type, timeperiod1); + let cache_key2 = format!("{}_{}_{}", czsc.freq, ma_type, timeperiod2); + + crate::utils::ta::update_ma_cache(czsc, &cache_key1, ma_type, timeperiod1, cache); + crate::utils::ta::update_ma_cache(czsc, &cache_key2, ma_type, timeperiod2, cache); + + let ma1_series = cache.series.get(&cache_key1).unwrap(); + let ma2_series = cache.series.get(&cache_key2).unwrap(); + + let ma1_sub = get_sub_elements(ma1_series, di, 1); + let ma2_sub = get_sub_elements(ma2_series, di, 1); + if ma1_sub.is_empty() || ma2_sub.is_empty() { + return Vec::new(); + } + + let ma1v = ma1_sub[ma1_sub.len() - 1]; + let ma2v = ma2_sub[ma2_sub.len() - 1]; + + let v1 = if ma1v >= ma2v { "多头" } else { "空头" }; + let v2 = if (ma1v - ma2v).abs() / ma2v * 10000.0 >= th as f64 { + "强势" + } else { + "弱势" + }; + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}T{}#{}#{}#{}", di, th, ma_type, timeperiod1, timeperiod2); + let k3 = "JX辅助V221203"; + + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// tas_sar_base_V230425:SAR 基础多空信号 +/// +/// 参数模板:`"{freq}_D{di}MO{max_overlap}SAR_BS辅助V230425"` +/// +/// 信号逻辑: +/// 1. 计算 SAR 序列; +/// 2. 若当前 `close > sar` 且窗口内存在任意 `close < sar`,判定 `看多`; +/// 3. 若当前 `close < sar` 且窗口内存在任意 `close > sar`,判定 `看空`; +/// 4. 否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1MO5SAR_BS辅助V230425_看多_任意_任意_0')` +/// - `Signal('60分钟_D1MO5SAR_BS辅助V230425_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `max_overlap`:重叠窗口,默认 `5`。 +/// 对齐说明:突破与重叠窗口判定逻辑对齐 Python `tas_sar_base_V230425`。 +#[signal( + category = "kline", + name = "tas_sar_base_V230425", + template = "{freq}_D{di}MO{max_overlap}SAR_BS辅助V230425", + opcode = "TasSarBaseV230425", + param_kind = "TasSarBaseV230425" +)] +pub fn tas_sar_base_v230425(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let max_overlap = get_usize_param(params, "max_overlap", 5); + let cache_key = "SAR"; + update_sar_cache(czsc, cache_key, cache); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}MO{}SAR", di, max_overlap); + let k3 = "BS辅助V230425"; + let mut v1 = "其他"; + + if czsc.bars_raw.len() < 3 || di == 0 || di > czsc.bars_raw.len() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let sar = cache.series.get(cache_key).unwrap(); + let id_to_idx = bar_index_map(czsc); + let bars = get_sub_elements(&czsc.bars_raw, di, max_overlap); + let bar = &czsc.bars_raw[czsc.bars_raw.len() - di]; + let idx = czsc.bars_raw.len() - di; + let bar_sar = sar[idx]; + if bar_sar.is_finite() { + if bar.close > bar_sar + && bars.iter().any(|x| { + id_to_idx + .get(&x.id) + .map(|i| sar[*i].is_finite() && x.close < sar[*i]) + .unwrap_or(false) + }) + { + v1 = "看多"; + } else if bar.close < bar_sar + && bars.iter().any(|x| { + id_to_idx + .get(&x.id) + .map(|i| sar[*i].is_finite() && x.close > sar[*i]) + .unwrap_or(false) + }) + { + v1 = "看空"; + } + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_macd_bs1_V230411:MACD DIF 五笔背驰信号 +/// +/// 参数模板:`"{freq}_D{di}T{tha}#{thb}#{thc}_BS1辅助V230411"` +/// +/// 信号逻辑: +/// 1. 取最近 5 笔并要求当前未完成笔长度约束; +/// 2. 上笔场景:涨幅、DIF 结构、末笔涨幅与 DIF 衰减同时满足,判定 `顶背驰`; +/// 3. 下笔场景镜像:跌幅、DIF 结构与 DIF 回升满足,判定 `底背驰`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1T30#5#30_BS1辅助V230411_顶背驰_任意_任意_0')` +/// - `Signal('60分钟_D1T30#5#30_BS1辅助V230411_底背驰_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 笔,默认 `1`; +/// - `tha`:前三笔累计涨跌阈值(BP),默认 `30`; +/// - `thb`:第5笔相对第3笔价格阈值(BP),默认 `5`; +/// - `thc`:第5笔相对第3笔 DIF 变化阈值(BP),默认 `30`。 +/// 对齐说明:五笔条件组合与 Python `tas_macd_bs1_V230411` 一致。 +#[signal( + category = "kline", + name = "tas_macd_bs1_V230411", + template = "{freq}_D{di}T{tha}#{thb}#{thc}_BS1辅助V230411", + opcode = "TasMacdBs1V230411", + param_kind = "TasMacdBs1V230411" +)] +pub fn tas_macd_bs1_v230411(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let tha = get_usize_param(params, "tha", 30) as f64; + let thb = get_usize_param(params, "thb", 5) as f64; + let thc = get_usize_param(params, "thc", 30) as f64; + assert!(tha > 0.0 && tha < 10000.0); + assert!(thb > 0.0 && thb < 10000.0); + assert!(thc > 0.0 && thc < 10000.0); + + let cache_key = "MACD12#26#9"; + update_macd_cache(czsc, cache_key, 12, 26, 9, cache); + let mc = cache.macd.get(cache_key).unwrap(); + let dif = &mc.dif; + let id_to_idx = bar_index_map(czsc); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}T{}#{}#{}", di, tha as usize, thb as usize, thc as usize); + let k3 = "BS1辅助V230411"; + let mut v1 = "其他"; + + if czsc.bi_list.len() <= di + 7 || czsc.bars_ubi.len() > 9 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let bis = get_sub_elements(&czsc.bi_list, di, 5); + if bis.len() < 5 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let bi1 = &bis[0]; + let bi3 = &bis[2]; + let bi5 = &bis[4]; + + let bi1_raw = bi1.get_raw_bars(); + if bi1_raw.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let first_dif = id_to_idx + .get(&bi1_raw[0].id) + .map(|i| dif[*i]) + .unwrap_or(f64::NAN); + if first_dif.is_nan() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + if bi5.direction == Direction::Up { + let bi1_dif = values_from_fx(&bi1.fx_b, &id_to_idx, dif) + .into_iter() + .fold(f64::NEG_INFINITY, f64::max); + let bi3_dif = values_from_fx(&bi3.fx_b, &id_to_idx, dif) + .into_iter() + .fold(f64::NEG_INFINITY, f64::max); + let bi5_dif = values_from_fx(&bi5.fx_b, &id_to_idx, dif) + .into_iter() + .fold(f64::NEG_INFINITY, f64::max); + let cond1 = ((bi3.get_high() - bi1.get_low()) / bi1.get_low()) * 10000.0 > tha; + let cond2 = bi3_dif > bi1_dif; + let cond3 = ((bi5.get_high() - bi3.get_high()) / bi3.get_high()) * 10000.0 > -thb; + let cond4 = ((bi5_dif - bi3_dif) / bi3_dif) * 10000.0 < -thc; + if cond1 && cond2 && cond3 && cond4 { + v1 = "顶背驰"; + } + } else if bi5.direction == Direction::Down { + let bi1_dif = values_from_fx(&bi1.fx_b, &id_to_idx, dif) + .into_iter() + .fold(f64::INFINITY, f64::min); + let bi3_dif = values_from_fx(&bi3.fx_b, &id_to_idx, dif) + .into_iter() + .fold(f64::INFINITY, f64::min); + let bi5_dif = values_from_fx(&bi5.fx_b, &id_to_idx, dif) + .into_iter() + .fold(f64::INFINITY, f64::min); + let cond1 = ((bi3.get_low() - bi1.get_high()) / bi1.get_high()) * 10000.0 < -tha; + let cond2 = bi3_dif < bi1_dif; + let cond3 = ((bi5.get_low() - bi3.get_low()) / bi3.get_low()) * 10000.0 < thb; + let cond4 = ((bi5_dif - bi3_dif) / bi3_dif) * 10000.0 > thc; + if cond1 && cond2 && cond3 && cond4 { + v1 = "底背驰"; + } + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_macd_bs1_V230412:MACD DIF 五笔背驰简化信号 +/// +/// 参数模板:`"{freq}_D{di}T{tha}#{thb}_BS1辅助V230412"` +/// +/// 信号逻辑: +/// 1. 取最近 5 笔并校验未完成笔长度; +/// 2. 上笔场景:前三笔涨幅过阈值,且 `DIF(3)` 为局部最大,末笔价格不弱,判 `顶背驰`; +/// 3. 下笔场景镜像:`DIF(3)` 为局部最小,判 `底背驰`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1T100#10_BS1辅助V230412_顶背驰_任意_任意_0')` +/// - `Signal('60分钟_D1T100#10_BS1辅助V230412_底背驰_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 笔,默认 `1`; +/// - `tha`:前三笔累计涨跌阈值(BP),默认 `100`; +/// - `thb`:第5笔相对第3笔价格阈值(BP),默认 `10`。 +/// 对齐说明:条件组合与 Python `tas_macd_bs1_V230412` 保持一致。 +#[signal( + category = "kline", + name = "tas_macd_bs1_V230412", + template = "{freq}_D{di}T{tha}#{thb}_BS1辅助V230412", + opcode = "TasMacdBs1V230412", + param_kind = "TasMacdBs1V230412" +)] +pub fn tas_macd_bs1_v230412(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let tha = get_usize_param(params, "tha", 100) as f64; + let thb = get_usize_param(params, "thb", 10) as f64; + let cache_key = "MACD12#26#9"; + update_macd_cache(czsc, cache_key, 12, 26, 9, cache); + let mc = cache.macd.get(cache_key).unwrap(); + let dif = &mc.dif; + let id_to_idx = bar_index_map(czsc); + let mut snapshot_overrides = HashMap::new(); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}T{}#{}", di, tha as usize, thb as usize); + let k3 = "BS1辅助V230412"; + let mut v1 = "其他"; + + if czsc.bi_list.len() <= di + 7 || czsc.bars_ubi.len() > 9 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let bis = get_sub_elements(&czsc.bi_list, di, 5); + if bis.len() < 5 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let bi1 = &bis[0]; + let bi3 = &bis[2]; + let bi5 = &bis[4]; + let bi1_raw = bi1.get_raw_bars(); + if bi1_raw.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let first_dif = id_to_idx + .get(&bi1_raw[0].id) + .map(|i| dif[*i]) + .unwrap_or(f64::NAN); + if first_dif.is_nan() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + if bi5.direction == Direction::Up { + let bi1_dif = snapshot_dif_values_from_fx( + czsc, + mc, + &id_to_idx, + &bi1.fx_b, + 12, + 26, + 9, + &mut snapshot_overrides, + ) + .into_iter() + .fold(f64::NEG_INFINITY, f64::max); + let bi3_dif = snapshot_dif_values_from_fx( + czsc, + mc, + &id_to_idx, + &bi3.fx_b, + 12, + 26, + 9, + &mut snapshot_overrides, + ) + .into_iter() + .fold(f64::NEG_INFINITY, f64::max); + let bi5_dif = snapshot_dif_values_from_fx( + czsc, + mc, + &id_to_idx, + &bi5.fx_b, + 12, + 26, + 9, + &mut snapshot_overrides, + ) + .into_iter() + .fold(f64::NEG_INFINITY, f64::max); + let cond1 = ((bi3.get_high() - bi1.get_low()) / bi1.get_low()) * 10000.0 > tha; + let cond2 = bi5_dif < bi3_dif && bi3_dif > bi1_dif; + let cond3 = ((bi5.get_high() - bi3.get_high()) / bi3.get_high()) * 10000.0 > -thb; + if cond1 && cond2 && cond3 { + v1 = "顶背驰"; + } + } else if bi5.direction == Direction::Down { + let bi1_dif = snapshot_dif_values_from_fx( + czsc, + mc, + &id_to_idx, + &bi1.fx_b, + 12, + 26, + 9, + &mut snapshot_overrides, + ) + .into_iter() + .fold(f64::INFINITY, f64::min); + let bi3_dif = snapshot_dif_values_from_fx( + czsc, + mc, + &id_to_idx, + &bi3.fx_b, + 12, + 26, + 9, + &mut snapshot_overrides, + ) + .into_iter() + .fold(f64::INFINITY, f64::min); + let bi5_dif = snapshot_dif_values_from_fx( + czsc, + mc, + &id_to_idx, + &bi5.fx_b, + 12, + 26, + 9, + &mut snapshot_overrides, + ) + .into_iter() + .fold(f64::INFINITY, f64::min); + let cond1 = ((bi3.get_low() - bi1.get_high()) / bi1.get_high()) * 10000.0 < -tha; + let cond2 = bi5_dif > bi3_dif && bi3_dif < bi1_dif; + let cond3 = ((bi5.get_low() - bi3.get_low()) / bi3.get_low()) * 10000.0 < thb; + if cond1 && cond2 && cond3 { + v1 = "底背驰"; + } + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_rumi_V230704:RUMI 零轴切换信号 +/// +/// 参数模板:`"{freq}_D{di}F{timeperiod1}S{timeperiod2}R{rumi_window}_BS辅助V230704"` +/// +/// 信号逻辑: +/// 1. 计算 `SMA(timeperiod1)` 与 `WMA(timeperiod2)`,得到 `diff = fast - slow`; +/// 2. 对 `diff` 做 `SMA(rumi_window)` 平滑,得到 `rumi`; +/// 3. `rumi` 上穿 0 轴判 `多头`,下穿 0 轴判 `空头`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1F3S50R30_BS辅助V230704_多头_任意_任意_0')` +/// - `Signal('60分钟_D1F3S50R30_BS辅助V230704_空头_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 根K线,默认 `1`; +/// - `timeperiod1`:快线均线周期,默认 `3`; +/// - `timeperiod2`:慢线均线周期,默认 `50`; +/// - `rumi_window`:RUMI 平滑周期,默认 `30`。 +/// 对齐说明:快慢线选型与零轴交叉判定对齐 Python `tas_rumi_V230704`。 +#[signal( + category = "kline", + name = "tas_rumi_V230704", + template = "{freq}_D{di}F{timeperiod1}S{timeperiod2}R{rumi_window}_BS辅助V230704", + opcode = "TasRumiV230704", + param_kind = "TasRumiV230704" +)] +pub fn tas_rumi_v230704(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let rumi_window = get_usize_param(params, "rumi_window", 30); + let timeperiod1 = get_usize_param(params, "timeperiod1", 3); + let timeperiod2 = get_usize_param(params, "timeperiod2", 50); + assert!( + rumi_window < timeperiod2, + "rumi_window 必须小于 timeperiod2" + ); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}F{}S{}R{}", di, timeperiod1, timeperiod2, rumi_window); + let k3 = "BS辅助V230704"; + let mut v1 = "其他"; + + if czsc.bars_raw.len() < di + timeperiod2 || di == 0 || di > czsc.bars_raw.len() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let key1 = format!("{}_{}_{}", czsc.freq, "SMA", timeperiod1); + let key2 = format!("{}_{}_{}", czsc.freq, "WMA", timeperiod2); + update_ma_cache(czsc, &key1, "SMA", timeperiod1, cache); + update_ma_cache(czsc, &key2, "WMA", timeperiod2, cache); + let fast = cache.series.get(&key1).unwrap(); + let slow = cache.series.get(&key2).unwrap(); + let id_to_idx = bar_index_map(czsc); + + let bars = get_sub_elements(&czsc.bars_raw, di, timeperiod2); + if bars.len() == timeperiod2 { + let fast_arr: Vec = bars + .iter() + .filter_map(|x| id_to_idx.get(&x.id).map(|i| fast[*i])) + .collect(); + let slow_arr: Vec = bars + .iter() + .filter_map(|x| id_to_idx.get(&x.id).map(|i| slow[*i])) + .collect(); + if fast_arr.len() == timeperiod2 && slow_arr.len() == timeperiod2 { + let diff: Vec = fast_arr + .iter() + .zip(slow_arr.iter()) + .map(|(a, b)| *a - *b) + .collect(); + let rumi = calc_sma(&diff, rumi_window); + if rumi.len() >= 2 { + let r1 = rumi[rumi.len() - 2]; + let r2 = rumi[rumi.len() - 1]; + if r2 > 0.0 && r1 < 0.0 { + v1 = "多头"; + } else if r2 < 0.0 && r1 > 0.0 { + v1 = "空头"; + } + } + } + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_macd_bc_V230804:MACD 黄白线背驰信号 +/// +/// 参数模板:`"{freq}_D{di}MACD背驰_BS辅助V230804"` +/// +/// 信号逻辑: +/// 1. 取最近 7 笔,并在末 5 笔构建中枢; +/// 2. 上笔场景:末笔位于高位区且 DIF 峰值弱于前两上笔,判 `空头`; +/// 3. 下笔场景镜像:末笔位于低位区且 DIF 谷值抬升,判 `多头`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1MACD背驰_BS辅助V230804_空头_任意_任意_0')` +/// - `Signal('60分钟_D1MACD背驰_BS辅助V230804_多头_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:倒数第 `di` 笔,默认 `1`。 +/// 对齐说明:中枢有效性与 DIF 对比口径与 Python `tas_macd_bc_V230804` 一致。 +#[signal( + category = "kline", + name = "tas_macd_bc_V230804", + template = "{freq}_D{di}MACD背驰_BS辅助V230804", + opcode = "TasMacdBcV230804", + param_kind = "TasMacdBcV230804" +)] +pub fn tas_macd_bc_v230804(czsc: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let cache_key = "MACD12#26#9"; + update_macd_cache(czsc, cache_key, 12, 26, 9, cache); + let mc = cache.macd.get(cache_key).unwrap(); + let id_to_idx = bar_index_map(czsc); + let mut snapshot_overrides: HashMap = HashMap::new(); + + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}MACD背驰", di); + let k3 = "BS辅助V230804"; + let mut v1 = "其他"; + + if czsc.bi_list.len() < 7 || czsc.bars_ubi.len() >= 7 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + let bis = get_sub_elements(&czsc.bi_list, di, 7); + if bis.len() < 7 { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let zs = ZS::new(bis[bis.len() - 5..].to_vec()); + if !zs.is_valid() { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let dd = bis + .iter() + .map(|bi| bi.get_low()) + .fold(f64::INFINITY, f64::min); + let gg = bis + .iter() + .map(|bi| bi.get_high()) + .fold(f64::NEG_INFINITY, f64::max); + let b1 = &bis[bis.len() - 5]; + let b3 = &bis[bis.len() - 3]; + let b5 = &bis[bis.len() - 1]; + + if b5.direction == Direction::Up && b5.get_high() > (gg - (gg - dd) / 4.0) { + let b5_dif = snapshot_dif_values_from_fx( + czsc, + mc, + &id_to_idx, + &b5.fx_b, + 12, + 26, + 9, + &mut snapshot_overrides, + ) + .into_iter() + .fold(f64::NEG_INFINITY, f64::max); + let mut od = snapshot_dif_values_from_fx( + czsc, + mc, + &id_to_idx, + &b1.fx_b, + 12, + 26, + 9, + &mut snapshot_overrides, + ); + od.extend(snapshot_dif_values_from_fx( + czsc, + mc, + &id_to_idx, + &b3.fx_b, + 12, + 26, + 9, + &mut snapshot_overrides, + )); + let od_dif = od.into_iter().fold(f64::NEG_INFINITY, f64::max); + if 0.0 < b5_dif && b5_dif < od_dif { + v1 = "空头"; + } + } + if b5.direction == Direction::Down && b5.get_low() < (dd + (gg - dd) / 4.0) { + let b5_dif = snapshot_dif_values_from_fx( + czsc, + mc, + &id_to_idx, + &b5.fx_b, + 12, + 26, + 9, + &mut snapshot_overrides, + ) + .into_iter() + .fold(f64::INFINITY, f64::min); + let mut od = snapshot_dif_values_from_fx( + czsc, + mc, + &id_to_idx, + &b1.fx_b, + 12, + 26, + 9, + &mut snapshot_overrides, + ); + od.extend(snapshot_dif_values_from_fx( + czsc, + mc, + &id_to_idx, + &b3.fx_b, + 12, + 26, + 9, + &mut snapshot_overrides, + )); + let od_dif = od.into_iter().fold(f64::INFINITY, f64::min); + if 0.0 > b5_dif && b5_dif > od_dif { + v1 = "多头"; + } + } + + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// tas_macd_bc_ubi_V230804:未完成笔 MACD 背驰观察 +/// +/// 参数模板:`"{freq}_MACD背驰_UBI观察V230804"` +/// +/// 信号逻辑: +/// 1. 使用未完成笔(UBI)方向与极值位置; +/// 2. 在最近 6 笔中构造中枢并比较 UBI 末段 DIF 与历史对应笔 DIF; +/// 3. 上行 UBI DIF 走弱判 `空头`,下行 UBI DIF 抬升判 `多头`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_MACD背驰_UBI观察V230804_空头_任意_任意_0')` +/// - `Signal('60分钟_MACD背驰_UBI观察V230804_多头_任意_任意_0')` +/// +/// 参数说明: +/// - 无额外参数。 +/// 对齐说明:UBI 原始K线口径与 Python `tas_macd_bc_ubi_V230804` 一致。 +#[signal( + category = "kline", + name = "tas_macd_bc_ubi_V230804", + template = "{freq}_MACD背驰_UBI观察V230804", + opcode = "TasMacdBcUbiV230804", + param_kind = "TasMacdBcUbiV230804" +)] +pub fn tas_macd_bc_ubi_v230804( + czsc: &CZSC, + _params: &ParamView, + cache: &mut TaCache, +) -> Vec { + let cache_key = "MACD12#26#9"; + update_macd_cache(czsc, cache_key, 12, 26, 9, cache); + let mc = cache.macd.get(cache_key).unwrap(); + let id_to_idx = bar_index_map(czsc); + let mut snapshot_overrides: HashMap = HashMap::new(); + + let k1 = czsc.freq.to_string(); + let k2 = "MACD背驰"; + let k3 = "UBI观察V230804"; + let mut v1 = "其他"; + + // 对齐 Python `not ubi` 语义:ubi_fxs 为空视为 ubi 不可用。 + let Some(ubi_fxs) = czsc.get_ubi_fxs() else { + return make_kline_signal_v1(&k1, k2, k3, v1); + }; + if ubi_fxs.is_empty() { + return make_kline_signal_v1(&k1, k2, k3, v1); + } + + let ubi_raw_bars: Vec = czsc + .bars_ubi + .iter() + .flat_map(|nb| nb.elements.iter().cloned()) + .collect(); + if czsc.bi_list.len() < 7 || ubi_raw_bars.len() < 7 { + return make_kline_signal_v1(&k1, k2, k3, v1); + } + + let bis = get_sub_elements(&czsc.bi_list, 1, 6); + if bis.len() < 6 { + return make_kline_signal_v1(&k1, k2, k3, v1); + } + let zs = ZS::new(bis[bis.len() - 5..].to_vec()); + if !zs.is_valid() { + return make_kline_signal_v1(&k1, k2, k3, v1); + } + + let dd = bis + .iter() + .map(|bi| bi.get_low()) + .fold(f64::INFINITY, f64::min); + let gg = bis + .iter() + .map(|bi| bi.get_high()) + .fold(f64::NEG_INFINITY, f64::max); + let ubi_high = ubi_raw_bars + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let ubi_low = ubi_raw_bars + .iter() + .map(|x| x.low) + .fold(f64::INFINITY, f64::min); + let ubi_direction = if czsc.bi_list.last().unwrap().direction == Direction::Down { + Direction::Up + } else { + Direction::Down + }; + + let b2 = &bis[bis.len() - 4]; + let b4 = &bis[bis.len() - 2]; + if ubi_direction == Direction::Up && ubi_high > (gg - (gg - dd) / 4.0) { + let b5_dif = snapshot_dif_values_from_raw_bars( + czsc, + mc, + &id_to_idx, + &ubi_raw_bars[ubi_raw_bars.len() - 5..], + 12, + 26, + 9, + &mut snapshot_overrides, + ) + .into_iter() + .fold(f64::NEG_INFINITY, f64::max); + let mut od = snapshot_dif_values_from_fx( + czsc, + mc, + &id_to_idx, + &b2.fx_b, + 12, + 26, + 9, + &mut snapshot_overrides, + ); + od.extend(snapshot_dif_values_from_fx( + czsc, + mc, + &id_to_idx, + &b4.fx_b, + 12, + 26, + 9, + &mut snapshot_overrides, + )); + let od_dif = od.into_iter().fold(f64::NEG_INFINITY, f64::max); + if 0.0 < b5_dif && b5_dif < od_dif { + v1 = "空头"; + } + } + if ubi_direction == Direction::Down && ubi_low < (dd + (gg - dd) / 4.0) { + let b5_dif = snapshot_dif_values_from_raw_bars( + czsc, + mc, + &id_to_idx, + &ubi_raw_bars[ubi_raw_bars.len() - 5..], + 12, + 26, + 9, + &mut snapshot_overrides, + ) + .into_iter() + .fold(f64::INFINITY, f64::min); + let mut od = snapshot_dif_values_from_fx( + czsc, + mc, + &id_to_idx, + &b2.fx_b, + 12, + 26, + 9, + &mut snapshot_overrides, + ); + od.extend(snapshot_dif_values_from_fx( + czsc, + mc, + &id_to_idx, + &b4.fx_b, + 12, + 26, + 9, + &mut snapshot_overrides, + )); + let od_dif = od.into_iter().fold(f64::INFINITY, f64::min); + if 0.0 > b5_dif && b5_dif > od_dif { + v1 = "多头"; + } + } + + make_kline_signal_v1(&k1, k2, k3, v1) +} diff --git a/crates/czsc-signals/src/types.rs b/crates/czsc-signals/src/types.rs new file mode 100644 index 000000000..a1854b42f --- /dev/null +++ b/crates/czsc-signals/src/types.rs @@ -0,0 +1,122 @@ +use czsc_core::analyze::CZSC; +use czsc_core::objects::signal::Signal; +use serde_json::Value; +use std::collections::HashMap; + +/// MACD 缓存三元组 +#[derive(Debug, Clone, Default)] +pub struct MacdSeries { + pub ids: Vec, + pub dif: Vec, + pub dea: Vec, + pub macd: Vec, +} + +/// BOLL 缓存三元组 +#[derive(Debug, Clone, Default)] +pub struct BollSeries { + pub upper: Vec, + pub mid: Vec, + pub lower: Vec, +} + +/// KDJ 缓存三元组 +#[derive(Debug, Clone, Default)] +pub struct KdjSeries { + pub ids: Vec, + pub k: Vec, + pub d: Vec, + pub j: Vec, +} + +/// TA 指标增量缓存,存放所有由纯 Rust 计算产生的序列数据 +#[derive(Debug, Clone, Default)] +pub struct TaCache { + /// 简单单点序列(如 EMA/SMA/RSI/ATR)的缓存 + pub series: HashMap>, + + /// MACD 数据缓存 + pub macd: HashMap, + + /// BOLL 数据缓存 + pub boll: HashMap, + + /// boll 对应的 bar id 序列(用于 bars_raw 截断后对齐) + pub boll_ids: HashMap>, + + /// KDJ 数据缓存 + pub kdj: HashMap, + + /// series 对应的 bar id 序列(用于 bars_raw 截断后对齐) + pub series_ids: HashMap>, + + /// 标记已缓存的最大长度,用于增量判断 + pub last_len: usize, +} + +impl TaCache { + pub fn new() -> Self { + Self::default() + } +} + +/// 信号函数签名(单 freq 计算) +pub type SignalFn = fn(&CZSC, &HashMap, &mut TaCache) -> Vec; +pub type FastKlineDecodeFn = fn(&HashMap) -> Option; +pub type FastKlineExecFn = fn(&CZSC, &Value, &mut TaCache) -> Vec; + +#[derive(Clone, Copy)] +pub struct FastKlineMeta { + pub decode: FastKlineDecodeFn, + pub exec: FastKlineExecFn, +} + +/// 运行时 K 线信号元信息(由 `SignalDescriptor` 归并而来)。 +pub struct SignalMeta { + pub func: SignalFn, + pub param_template: &'static str, + pub fast_kline: Option, +} + +/// 依赖 TraderState 的信号函数签名(pos 系列,需要仓位和K线的联合状态) +pub type TraderSignalFn = fn( + cat: &dyn czsc_core::objects::state::TraderState, + params: &HashMap, +) -> Vec; + +/// 运行时 Trader 信号元信息(由 `SignalDescriptor` 归并而来)。 +pub struct TraderSignalMeta { + pub func: TraderSignalFn, + pub param_template: &'static str, +} + +/// 对信号函数的类型化引用,用于在注册中心中区分 K 线级与 Trader 级信号。 +#[derive(Clone, Copy)] +pub enum SignalFnRef { + /// 仅依赖 `CZSC + params + TaCache` 的 K 线级信号函数。 + Kline(SignalFn), + /// 依赖 `TraderState + params` 的 Trader/Position 级信号函数。 + Trader(TraderSignalFn), +} + +/// 信号描述符(编译期元数据)。 +/// +/// 该结构由 `#[signal(...)]` 宏生成并通过 `inventory` 自动收集, +/// 作为信号注册、编译计划构建与执行分派的单一元数据来源。 +#[derive(Clone, Copy)] +pub struct SignalDescriptor { + /// 信号类别:`kline` 或 `trader`。 + pub category: &'static str, + /// 信号名称(如 `tas_ma_base_V221101`)。 + pub name: &'static str, + /// 参数模板(与 Python/策略配置保持一致)。 + pub template: &'static str, + /// 内部操作码名称(用于执行层分派与冲突检测)。 + pub opcode: &'static str, + /// 参数类型标识(用于后续 typed params 映射)。 + pub param_kind: &'static str, + /// 函数指针引用(按 `category` 解释为 Kline 或 Trader 函数)。 + pub func_ref: SignalFnRef, + /// 可选 fast-path 元信息;存在时可在执行层避免 HashMap 解释开销。 + pub fast_kline: Option, +} diff --git a/crates/czsc-signals/src/utils/cxt.rs b/crates/czsc-signals/src/utils/cxt.rs new file mode 100644 index 000000000..c5af46f7b --- /dev/null +++ b/crates/czsc-signals/src/utils/cxt.rs @@ -0,0 +1,335 @@ +use crate::utils::math::mean; +use czsc_core::analyze::{CZSC, UBI}; +use czsc_core::objects::bar::RawBar; +use czsc_core::objects::bi::BI; +use czsc_core::objects::direction::Direction; +use czsc_core::objects::fx::FX; +use czsc_core::objects::mark::Mark; +use czsc_core::objects::zs::ZS; + +#[inline] +pub fn raw_bar_upper(bar: &RawBar) -> f64 { + bar.high - bar.open.max(bar.close) +} + +#[inline] +pub fn raw_bar_lower(bar: &RawBar) -> f64 { + bar.open.min(bar.close) - bar.low +} + +pub fn fx_raw_bars(fx: &FX) -> Vec { + fx.elements + .iter() + .flat_map(|nb| nb.elements.iter().cloned()) + .collect() +} + +pub fn fx_power_str(fx: &FX) -> &'static str { + if fx.elements.len() != 3 { + return "弱"; + } + let k1 = &fx.elements[0]; + let k2 = &fx.elements[1]; + let k3 = &fx.elements[2]; + match fx.mark { + Mark::D => { + if k3.close > k1.high { + "强" + } else if k3.close > k2.high { + "中" + } else { + "弱" + } + } + Mark::G => { + if k3.close < k1.low { + "强" + } else if k3.close < k2.low { + "中" + } else { + "弱" + } + } + } +} + +pub fn fx_has_zs(fx: &FX) -> bool { + if fx.elements.len() != 3 { + return false; + } + let zd = fx + .elements + .iter() + .map(|x| x.low) + .fold(f64::NEG_INFINITY, f64::max); + let zg = fx + .elements + .iter() + .map(|x| x.high) + .fold(f64::INFINITY, f64::min); + zg >= zd +} + +pub fn get_zs_seq(bis: &[BI]) -> Vec { + let mut zs_list: Vec = Vec::new(); + if bis.is_empty() { + return zs_list; + } + + for bi in bis.iter().cloned() { + if zs_list.is_empty() { + zs_list.push(ZS::new(vec![bi])); + continue; + } + + let last_zs = zs_list.pop().unwrap(); + if last_zs.bis.is_empty() { + let mut new_bis = last_zs.bis; + new_bis.push(bi); + zs_list.push(ZS::new(new_bis)); + } else if (bi.direction == Direction::Up && bi.get_high() < last_zs.zd) + || (bi.direction == Direction::Down && bi.get_low() > last_zs.zg) + { + zs_list.push(last_zs); + zs_list.push(ZS::new(vec![bi])); + } else { + let mut new_bis = last_zs.bis; + new_bis.push(bi); + zs_list.push(ZS::new(new_bis)); + } + } + zs_list +} + +pub fn unique_prices_from_bars(bars: &[RawBar]) -> Vec { + let mut prices: Vec = Vec::with_capacity(bars.len() * 4); + for b in bars { + if b.close.is_finite() { + prices.push(b.close); + } + if b.high.is_finite() { + prices.push(b.high); + } + if b.low.is_finite() { + prices.push(b.low); + } + if b.open.is_finite() { + prices.push(b.open); + } + } + prices.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + prices.dedup_by(|a, b| (*a - *b).abs() <= f64::EPSILON); + prices +} + +pub fn ubi_raw_bars(c: &CZSC) -> Vec { + c.bars_ubi + .iter() + .flat_map(|nb| nb.elements.iter().cloned()) + .collect() +} + +pub fn calc_bi_status_values(czsc: &CZSC, ubi_fxs: &[FX]) -> (&'static str, &'static str) { + let last_bi = czsc.bi_list.last().unwrap(); + let v1 = match last_bi.direction { + Direction::Down => { + if czsc.bars_ubi.len() > 7 { + "向上" + } else { + "向下" + } + } + Direction::Up => { + if czsc.bars_ubi.len() > 7 { + "向下" + } else { + "向上" + } + } + }; + + let last_fx = ubi_fxs.last().unwrap(); + let v2 = match last_fx.mark { + Mark::D => { + if v1 == "向下" { + "底分" + } else { + "延伸" + } + } + Mark::G => { + if v1 == "向上" { + "顶分" + } else { + "延伸" + } + } + }; + (v1, v2) +} + +pub fn rebuild_ubi(c: &CZSC) -> Option { + if c.bars_ubi.is_empty() || c.bi_list.is_empty() { + return None; + } + let ubi_fxs = c.get_ubi_fxs()?; + if ubi_fxs.is_empty() { + return None; + } + let raw_bars: Vec = c + .bars_ubi + .iter() + .flat_map(|x| x.elements.iter().cloned()) + .collect(); + let high_bar = raw_bars + .iter() + .max_by(|a, b| { + a.high + .partial_cmp(&b.high) + .unwrap_or(std::cmp::Ordering::Less) + })? + .clone(); + let low_bar = raw_bars + .iter() + .min_by(|a, b| { + a.low + .partial_cmp(&b.low) + .unwrap_or(std::cmp::Ordering::Greater) + })? + .clone(); + let direction = if c.bi_list.last().unwrap().direction == Direction::Down { + Direction::Up + } else { + Direction::Down + }; + Some(UBI { + symbol: c.symbol.clone(), + direction, + high: high_bar.high, + low: low_bar.low, + high_bar, + low_bar, + bars: c.bars_ubi.clone(), + raw_bars, + fxs: ubi_fxs.clone(), + fx_a: ubi_fxs.first().unwrap().clone(), + }) +} + +pub fn check_first_buy(bis: &[BI]) -> bool { + if bis.len() % 2 != 1 + || bis.last().unwrap().direction == Direction::Up + || bis.first().unwrap().direction != bis.last().unwrap().direction + { + return false; + } + let max_high = bis + .iter() + .map(|x| x.get_high()) + .fold(f64::NEG_INFINITY, f64::max); + let min_low = bis + .iter() + .map(|x| x.get_low()) + .fold(f64::INFINITY, f64::min); + if max_high != bis.first().unwrap().get_high() || min_low != bis.last().unwrap().get_low() { + return false; + } + let mut key_bis: Vec<&BI> = Vec::new(); + for i in (0..=(bis.len() - 3)).step_by(2) { + if i == 0 { + key_bis.push(&bis[i]); + } else { + let b1 = &bis[i - 2]; + let b3 = &bis[i]; + if b3.get_low() < b1.get_low() { + key_bis.push(b3); + } + } + } + if key_bis.is_empty() { + return false; + } + let last = bis.last().unwrap(); + let prev = &bis[bis.len() - 3]; + let bc_price = last.get_power_price() + < prev.get_power_price().max(mean( + &key_bis + .iter() + .map(|x| x.get_power_price()) + .collect::>(), + )); + let bc_volume = last.get_power_volume() + < prev.get_power_volume().max(mean( + &key_bis + .iter() + .map(|x| x.get_power_volume()) + .collect::>(), + )); + let bc_length = (last.get_length() as f64) + < (prev.get_length() as f64).max(mean( + &key_bis + .iter() + .map(|x| x.get_length() as f64) + .collect::>(), + )); + bc_price && (bc_volume || bc_length) +} + +pub fn check_first_sell(bis: &[BI]) -> bool { + if bis.len() % 2 != 1 + || bis.last().unwrap().direction == Direction::Down + || bis.first().unwrap().direction != bis.last().unwrap().direction + { + return false; + } + let max_high = bis + .iter() + .map(|x| x.get_high()) + .fold(f64::NEG_INFINITY, f64::max); + let min_low = bis + .iter() + .map(|x| x.get_low()) + .fold(f64::INFINITY, f64::min); + if max_high != bis.last().unwrap().get_high() || min_low != bis.first().unwrap().get_low() { + return false; + } + let mut key_bis: Vec<&BI> = Vec::new(); + for i in (0..=(bis.len() - 3)).step_by(2) { + if i == 0 { + key_bis.push(&bis[i]); + } else { + let b1 = &bis[i - 2]; + let b3 = &bis[i]; + if b3.get_high() > b1.get_high() { + key_bis.push(b3); + } + } + } + if key_bis.is_empty() { + return false; + } + let last = bis.last().unwrap(); + let prev = &bis[bis.len() - 3]; + let bc_price = last.get_power_price() + < prev.get_power_price().max(mean( + &key_bis + .iter() + .map(|x| x.get_power_price()) + .collect::>(), + )); + let bc_volume = last.get_power_volume() + < prev.get_power_volume().max(mean( + &key_bis + .iter() + .map(|x| x.get_power_volume()) + .collect::>(), + )); + let bc_length = (last.get_length() as f64) + < (prev.get_length() as f64).max(mean( + &key_bis + .iter() + .map(|x| x.get_length() as f64) + .collect::>(), + )); + bc_price && (bc_volume || bc_length) +} diff --git a/crates/czsc-signals/src/utils/math.rs b/crates/czsc-signals/src/utils/math.rs new file mode 100644 index 000000000..604cae22f --- /dev/null +++ b/crates/czsc-signals/src/utils/math.rs @@ -0,0 +1,97 @@ +pub fn mean(values: &[f64]) -> f64 { + if values.is_empty() { + 0.0 + } else { + values.iter().sum::() / values.len() as f64 + } +} + +pub fn std_pop(values: &[f64]) -> f64 { + if values.is_empty() { + return 0.0; + } + let m = mean(values); + (values.iter().map(|x| (x - m).powi(2)).sum::() / values.len() as f64).sqrt() +} + +pub fn percentile_linear(values: &[f64], p: f64) -> Option { + if values.is_empty() { + return None; + } + let mut x = values.to_vec(); + x.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + if x.len() == 1 { + return Some(x[0]); + } + let rank = p.clamp(0.0, 100.0) / 100.0 * (x.len() as f64 - 1.0); + let lo = rank.floor() as usize; + let hi = rank.ceil() as usize; + if lo == hi { + Some(x[lo]) + } else { + Some(x[lo] + (x[hi] - x[lo]) * (rank - lo as f64)) + } +} + +pub fn median_abs(values: &[f64]) -> f64 { + let mut x: Vec = values.iter().map(|v| v.abs()).collect(); + x.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + if x.is_empty() { + 0.0 + } else if x.len() % 2 == 1 { + x[x.len() / 2] + } else { + (x[x.len() / 2 - 1] + x[x.len() / 2]) / 2.0 + } +} + +pub fn max_amplitude_pct(prices: &[f64]) -> f64 { + if prices.is_empty() { + return 100.0; + } + let max_price = prices.iter().copied().fold(f64::NEG_INFINITY, f64::max); + let min_price = prices.iter().copied().fold(f64::INFINITY, f64::min); + if min_price == 0.0 { + 100.0 + } else { + (max_price - min_price) / min_price * 100.0 + } +} + +pub fn linreg_predict(xs: &[f64], ys: &[f64], x: f64) -> Option { + if xs.len() != ys.len() || xs.is_empty() { + return None; + } + let n = xs.len() as f64; + let mean_x = xs.iter().sum::() / n; + let mean_y = ys.iter().sum::() / n; + let cov = xs + .iter() + .zip(ys.iter()) + .map(|(xv, yv)| (xv - mean_x) * (yv - mean_y)) + .sum::(); + let var_x = xs.iter().map(|xv| (xv - mean_x).powi(2)).sum::(); + if var_x == 0.0 { + return Some(mean_y); + } + let slope = cov / var_x; + let intercept = mean_y - slope * mean_x; + Some(slope * x + intercept) +} + +pub fn overlap(h1: f64, l1: f64, h2: f64, l2: f64) -> bool { + l1.max(l2) < h1.min(h2) +} + +#[cfg(test)] +mod tests { + use super::linreg_predict; + + #[test] + fn linreg_predict_returns_constant_for_single_sample() { + let xs = [1.0]; + let ys = [10.0]; + let pred = linreg_predict(&xs, &ys, 99.0); + assert_eq!(pred, Some(10.0)); + } +} diff --git a/crates/czsc-signals/src/utils/mod.rs b/crates/czsc-signals/src/utils/mod.rs new file mode 100644 index 000000000..a66ddc289 --- /dev/null +++ b/crates/czsc-signals/src/utils/mod.rs @@ -0,0 +1,5 @@ +pub mod cxt; +pub mod math; +pub mod sig; +pub mod ta; +pub mod zdy; diff --git a/crates/czsc-signals/src/utils/sig.rs b/crates/czsc-signals/src/utils/sig.rs new file mode 100644 index 000000000..e0d06b4c9 --- /dev/null +++ b/crates/czsc-signals/src/utils/sig.rs @@ -0,0 +1,844 @@ +use crate::params::ParamView; +use chrono::{Datelike, Duration, Timelike}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::bar::RawBar; +use czsc_core::objects::operate::Operate; +use czsc_core::objects::position::OperateRecord; +use czsc_core::objects::signal::Signal; +use czsc_core::objects::state::TraderState; +use std::collections::HashMap; +use std::str::FromStr; + +/// 获取截止到倒数第 `di` 个元素的前 `n` 个元素 +/// +/// 对齐 Python `get_sub_elements` 语义: +/// - `di == 1` 时返回最后 `n` 个; +/// - `di > 1` 时返回 `[-n-di+1 : -di+1]`; +/// - 数量不足时返回可用区间(可能为空)。 +pub fn get_sub_elements(elements: &[T], di: usize, n: usize) -> &[T] { + assert!(di >= 1, "di must be >= 1"); + if elements.is_empty() || di > elements.len() { + return &elements[0..0]; + } + // 对齐 Python 切片语义: + // get_sub_elements(elements, di=1, n=0) -> elements[-0:] -> 全量 + if n == 0 { + return if di == 1 { + &elements[0..elements.len()] + } else { + &elements[0..0] + }; + } + + let end = elements.len() - di + 1; + let start = end.saturating_sub(n); + &elements[start..end] +} + +/// 解析数字或字符串为 usize +pub fn get_usize_param(params: &ParamView, key: &str, default: usize) -> usize { + if let Some(val) = params.value(key) { + if let Some(n) = val.as_u64() { + return n as usize; + } + if let Some(s) = val.as_str() { + if let Ok(n) = s.parse::() { + return n; + } + } + } + default +} + +/// 将 `bar.id -> 索引` 映射成哈希表,便于在信号函数中做 O(1) 定位。 +pub fn bar_index_map(czsc: &CZSC) -> HashMap { + czsc.bars_raw + .iter() + .enumerate() + .map(|(i, b)| (b.id, i)) + .collect() +} + +/// 构建 `bar.id -> 最新 RawBar` 映射,用于在信号层对齐 Python 的“按当前 bars_raw 读取”语义。 +pub fn raw_bar_map(czsc: &CZSC) -> HashMap { + czsc.bars_raw.iter().map(|b| (b.id, b.clone())).collect() +} + +/// 用最新 `bars_raw` 覆盖同 id 的历史快照;未命中的 id 保留原值。 +pub fn remap_raw_bars(raw_bars: &[RawBar], latest_by_id: &HashMap) -> Vec { + raw_bars + .iter() + .map(|rb| { + latest_by_id + .get(&rb.id) + .cloned() + .unwrap_or_else(|| rb.clone()) + }) + .collect() +} + +/// 从 `RawBar` 序列中提取对应索引的数值序列(会过滤非有限值)。 +pub fn values_from_raw_bars( + raw_bars: &[RawBar], + id_to_idx: &HashMap, + values: &[f64], +) -> Vec { + raw_bars + .iter() + .filter_map(|rb| id_to_idx.get(&rb.id).map(|i| values[*i])) + .filter(|x| x.is_finite()) + .collect() +} + +/// 从 `FX` 的嵌套 K 线中提取对应索引的数值序列(会过滤非有限值)。 +pub fn values_from_fx( + fx: &czsc_core::objects::fx::FX, + id_to_idx: &HashMap, + values: &[f64], +) -> Vec { + let raw_bars: Vec = fx + .elements + .iter() + .flat_map(|nb| nb.elements.iter().cloned()) + .collect(); + values_from_raw_bars(&raw_bars, id_to_idx, values) +} + +/// 解析分钟级别周期,如 `60分钟` -> `60` +pub fn parse_minute_freq(freq: &str) -> Option { + if !freq.ends_with("分钟") { + return None; + } + let n = freq.trim_end_matches("分钟").parse::().ok()?; + if n > 0 { + Some(n) + } else { + None + } +} + +/// 计算分钟周期对应的结束时间(与 Python freq_end_time 口径一致) +pub fn minute_freq_end_time( + dt: chrono::DateTime, + freq: &str, +) -> Option> { + let Some(m) = parse_minute_freq(freq) else { + return Some(dt); + }; + + // 与 Python `freq_end_time(..., market="A股")` 对齐: + // A 股分钟数据在 09:30/13:30 等边界点不是简单的整除向上取整。 + let hm = dt.format("%H:%M").to_string(); + if freq == "30分钟" { + let hm_edt = match hm.as_str() { + "09:30" | "10:00" => Some("10:00"), + "10:30" => Some("10:30"), + "11:00" => Some("11:00"), + "11:30" => Some("11:30"), + "13:30" => Some("13:30"), + "14:00" => Some("14:00"), + "14:30" => Some("14:30"), + "15:00" => Some("15:00"), + _ => None, + }; + if let Some(edt) = hm_edt { + let mut it = edt.split(':'); + let h = it.next().and_then(|x| x.parse::().ok())?; + let m = it.next().and_then(|x| x.parse::().ok())?; + return dt.date_naive().and_hms_opt(h, m, 0).map(|x| x.and_utc()); + } + } + + if freq == "60分钟" { + let hm_edt = match hm.as_str() { + "09:30" | "10:00" | "10:30" => Some("10:30"), + "11:00" | "11:30" => Some("11:30"), + "13:30" | "14:00" => Some("14:00"), + "14:30" | "15:00" => Some("15:00"), + _ => None, + }; + if let Some(edt) = hm_edt { + let mut it = edt.split(':'); + let h = it.next().and_then(|x| x.parse::().ok())?; + let m = it.next().and_then(|x| x.parse::().ok())?; + return dt.date_naive().and_hms_opt(h, m, 0).map(|x| x.and_utc()); + } + } + + let hm = i64::from(dt.hour()) * 60 + i64::from(dt.minute()); + let mut end_hm = if hm % m == 0 { hm } else { (hm / m + 1) * m }; + let mut day = dt.date_naive(); + if end_hm >= 24 * 60 { + end_hm -= 24 * 60; + if let Some(next_day) = day.checked_add_signed(Duration::days(1)) { + day = next_day; + } + } + let h = (end_hm / 60) as u32; + let mm = (end_hm % 60) as u32; + day.and_hms_opt(h, mm, 0).map(|x| x.and_utc()) +} + +/// 获取最新价,优先 trader 直接提供,否则回退到指定级别最后一根K线收盘 +pub fn latest_price(cat: &dyn TraderState, freq1: &str) -> Option { + cat.latest_price().or_else(|| { + cat.get_czsc(freq1) + .and_then(|c| c.bars_raw.last().map(|b| b.close)) + }) +} + +/// 获取最后一次开仓操作(过滤平仓) +pub fn last_open_operate<'a>( + cat: &'a dyn TraderState, + pos_name: &str, +) -> Option<&'a OperateRecord> { + let pos = cat.get_position(pos_name)?; + let op = pos.operates.last()?; + if matches!(op.op, Operate::SE | Operate::LE) { + None + } else { + Some(op) + } +} + +/// 解析信号字符串,失败时返回空向量 +pub fn signal_from_str(sig_str: &str) -> Vec { + Signal::from_str(sig_str).map_or_else(|_| vec![], |s| vec![s]) +} + +/// 底层统一构造器:生成标准 7 段信号并严格校验格式。 +pub fn make_signal7( + k1: &str, + k2: &str, + k3: &str, + v1: &str, + v2: &str, + v3: &str, + score: i32, +) -> Vec { + let sig_str = format!("{}_{}_{}_{}_{}_{}_{}", k1, k2, k3, v1, v2, v3, score); + match Signal::from_str(&sig_str) { + Ok(sig) => vec![sig], + Err(err) => panic!("invalid signal generated: {sig_str}; error: {err}"), + } +} + +/// 构造标准 7 段信号(含 v1/v2,v3=`任意`) +pub fn make_signal(k1: &str, k2: &str, k3: &str, v1: &str, v2: &str) -> Vec { + make_signal7(k1, k2, k3, v1, v2, "任意", 0) +} + +/// 构造标准 7 段信号(仅 v1,v2/v3 默认 `任意`) +pub fn make_signal_v1(k1: &str, k2: &str, k3: &str, v1: &str) -> Vec { + make_signal7(k1, k2, k3, v1, "任意", "任意", 0) +} + +/// K线级:仅设置 v1,v2/v3 固定 `任意` +pub fn make_kline_signal_v1(k1: &str, k2: &str, k3: &str, v1: &str) -> Vec { + make_signal7(k1, k2, k3, v1, "任意", "任意", 0) +} + +/// K线级:设置 v1/v2,v3 固定 `任意` +pub fn make_kline_signal_v2(k1: &str, k2: &str, k3: &str, v1: &str, v2: &str) -> Vec { + make_signal7(k1, k2, k3, v1, v2, "任意", 0) +} + +/// K线级:设置 v1/v2/v3,score 固定 `0` +pub fn make_kline_signal_v3( + k1: &str, + k2: &str, + k3: &str, + v1: &str, + v2: &str, + v3: &str, +) -> Vec { + make_signal7(k1, k2, k3, v1, v2, v3, 0) +} + +/// 周内中文标签(周一到周日) +pub fn weekday_cn(dt: chrono::DateTime) -> &'static str { + match dt.weekday().num_days_from_monday() { + 0 => "周一", + 1 => "周二", + 2 => "周三", + 3 => "周四", + 4 => "周五", + 5 => "周六", + _ => "周日", + } +} + +/// 最近 `window` 根中的日内时间去重排序后,返回最后一根所在分段(1-based) +pub fn intraday_time_segment(bars: &[RawBar], window: usize) -> Option { + if bars.len() < window || window == 0 { + return None; + } + let sub = &bars[bars.len() - window..]; + let mut spans: Vec = sub + .iter() + .map(|x| x.dt.format("%H:%M").to_string()) + .collect(); + spans.sort(); + spans.dedup(); + let cur = bars.last()?.dt.format("%H:%M").to_string(); + spans.iter().position(|x| x == &cur).map(|i| i + 1) +} + +/// 快慢线交叉信息(基础版) +pub fn fast_slow_cross(fast: &[f64], slow: &[f64]) -> Vec> { + let mut res = Vec::new(); + let len = fast.len(); + if len < 2 { + return res; + } + + let mut last_cross_idx = 0; + for i in 2..len { + let f0 = fast[i - 1]; + let s0 = slow[i - 1]; + let f1 = fast[i]; + let s1 = slow[i]; + + if f0 <= s0 && f1 > s1 { + let mut cross = HashMap::new(); + cross.insert("类型", 1.0); + cross.insert("快线", f1); + cross.insert("慢线", s1); + cross.insert("距离", (i - last_cross_idx) as f64); + res.push(cross); + last_cross_idx = i; + } else if f0 >= s0 && f1 < s1 { + let mut cross = HashMap::new(); + cross.insert("类型", -1.0); + cross.insert("快线", f1); + cross.insert("慢线", s1); + cross.insert("距离", (i - last_cross_idx) as f64); + res.push(cross); + last_cross_idx = i; + } + } + res +} + +#[derive(Clone, Debug)] +pub struct CrossInfoExt { + pub kind: i32, + pub slow: f64, + pub distance: usize, + pub to_now: usize, + pub area: f64, +} + +/// 快慢线交叉信息(扩展版) +pub fn fast_slow_cross_ext(fast: &[f64], slow: &[f64]) -> Vec { + let len = fast.len().min(slow.len()); + if len < 3 { + return Vec::new(); + } + + let delta: Vec = (0..len).map(|i| fast[i] - slow[i]).collect(); + let mut cross = Vec::new(); + let mut last_i: isize = -1; + let mut last_v = 0.0; + + for i in 0..len { + let v = delta[i]; + last_i += 1; + last_v += v.abs(); + let kind = if i >= 2 && delta[i - 1] <= 0.0 && delta[i] > 0.0 { + 1 + } else if i >= 2 && delta[i - 1] >= 0.0 && delta[i] < 0.0 { + -1 + } else { + 0 + }; + if kind == 0 { + continue; + } + cross.push(CrossInfoExt { + kind, + slow: slow[i], + distance: last_i as usize, + to_now: len - i, + area: (last_v * 10000.0).round() / 10000.0, + }); + last_i = 0; + last_v = 0.0; + } + cross +} + +/// 计算两序列最近一次穿越零轴以来的长度 +pub fn cross_zero_axis(n1: &[f64], n2: &[f64]) -> usize { + assert_eq!(n1.len(), n2.len(), "输入两个数列长度不等"); + if n1.is_empty() { + return 0; + } + let mut num1 = 0usize; + let mut num2 = 0usize; + + let mut n1_rev = n1.to_vec(); + n1_rev.reverse(); + let a = n1_rev[0]; + if n1_rev.iter().any(|x| a * *x < 0.0) { + let x: Vec = n1_rev.iter().map(|v| a * *v < 0.0).collect(); + let mut found = false; + for i in 0..x.len().saturating_sub(1) { + if x[i] != x[i + 1] { + num1 = i + 2; + found = true; + break; + } + } + if !found { + num1 = 2; + } + } + + let mut n2_rev = n2.to_vec(); + n2_rev.reverse(); + let b = n2_rev[0]; + if n2_rev.iter().any(|x| b * *x < 0.0) { + let x: Vec = n2_rev.iter().map(|v| b * *v < 0.0).collect(); + let mut found = false; + for i in 0..x.len().saturating_sub(1) { + if x[i] != x[i + 1] { + num2 = i + 2; + found = true; + break; + } + } + if !found { + num2 = 2; + } + } + num1.max(num2) +} + +/// 统计 x1 从上向下穿越 x2 的次数 +pub fn down_cross_count(x1: &[f64], x2: &[f64]) -> usize { + let mut num = 0usize; + if x1.len() != x2.len() || x1.len() < 2 { + return num; + } + for i in 0..x1.len() - 1 { + let b1 = x1[i] < x2[i]; + let b2 = x1[i + 1] < x2[i + 1]; + if b2 && b1 != b2 { + num += 1; + } + } + num +} + +/// 计算过滤后的金叉/死叉次数 +pub fn cal_cross_num(cross: &[HashMap<&'static str, f64>], distance: usize) -> (usize, usize) { + if cross.is_empty() { + return (0, 0); + } + + let mut cross_work = cross.to_vec(); + let mut filtered: Vec> = Vec::new(); + + if cross_work.len() == 1 { + filtered = cross_work; + } else if cross_work.len() == 2 { + let dist = cross_work + .last() + .and_then(|x| x.get("距离")) + .copied() + .unwrap_or(0.0); + filtered = if dist < distance as f64 { + Vec::new() + } else { + cross_work + }; + } else { + let last_dist = cross_work + .last() + .and_then(|x| x.get("距离")) + .copied() + .unwrap_or(0.0); + let re_cross: Vec> = if last_dist < distance as f64 { + let last_cross = cross_work.pop().unwrap(); + let _ = cross_work.pop(); + let mut tmp: Vec> = cross_work + .into_iter() + .filter(|x| x.get("距离").copied().unwrap_or(0.0) >= distance as f64) + .collect(); + tmp.push(last_cross); + tmp + } else { + cross_work + .into_iter() + .filter(|x| x.get("距离").copied().unwrap_or(0.0) >= distance as f64) + .collect() + }; + + for i in 0..re_cross.len() { + if !filtered.is_empty() && i >= 1 { + let t_i = re_cross[i].get("类型").copied().unwrap_or(0.0); + let t_prev = re_cross[i - 1].get("类型").copied().unwrap_or(0.0); + if (t_i - t_prev).abs() <= f64::EPSILON { + filtered.pop(); + filtered.push(re_cross[i].clone()); + continue; + } + } + filtered.push(re_cross[i].clone()); + } + } + + let jc = filtered + .iter() + .filter(|x| x.get("类型").copied().unwrap_or(0.0) > 0.0) + .count(); + let sc = filtered + .iter() + .filter(|x| x.get("类型").copied().unwrap_or(0.0) < 0.0) + .count(); + (jc, sc) +} + +/// 线性回归斜率 +pub fn linear_slope(y: &[f64]) -> f64 { + let n = y.len(); + if n < 2 { + return 0.0; + } + let n_f = n as f64; + let sum_x = (n_f - 1.0) * n_f / 2.0; + let sum_xx = (n_f - 1.0) * n_f * (2.0 * n_f - 1.0) / 6.0; + let sum_y: f64 = y.iter().sum(); + let sum_xy: f64 = y.iter().enumerate().map(|(i, v)| i as f64 * *v).sum(); + let denom = n_f * sum_xx - sum_x * sum_x; + if denom.abs() <= f64::EPSILON { + return 0.0; + } + (n_f * sum_xy - sum_x * sum_y) / denom +} + +/// 统计序列末尾连续相同元素个数 +pub fn count_last_same(seq: &[T]) -> usize { + if seq.is_empty() { + return 0; + } + let last = &seq[seq.len() - 1]; + let mut c = 0usize; + for x in seq.iter().rev() { + if x == last { + c += 1; + } else { + break; + } + } + c +} + +/// 等宽分箱:返回最后一个值所在分箱 +pub fn cut_last_bin_label(values: &[f64], n: usize) -> Option { + if n == 0 || values.is_empty() { + return None; + } + let finite: Vec = values.iter().copied().filter(|x| x.is_finite()).collect(); + if finite.is_empty() { + return None; + } + let min_v = finite.iter().copied().fold(f64::INFINITY, f64::min); + let max_v = finite.iter().copied().fold(f64::NEG_INFINITY, f64::max); + if !min_v.is_finite() || !max_v.is_finite() { + return None; + } + if (max_v - min_v).abs() <= f64::EPSILON { + return Some(n.div_ceil(2)); + } + let last = *values.last()?; + if !last.is_finite() { + return None; + } + let width = (max_v - min_v) / n as f64; + if width <= 0.0 || !width.is_finite() { + return Some(1); + } + let mut idx = ((last - min_v) / width).floor() as isize + 1; + if idx < 1 { + idx = 1; + } + if idx > n as isize { + idx = n as isize; + } + Some(idx as usize) +} + +/// 标准差(基于绝对值序列) +pub fn std_abs_series(values: &[f64]) -> f64 { + if values.is_empty() || values.iter().any(|x| !x.is_finite()) { + return f64::NAN; + } + let abs_vals: Vec = values.iter().map(|x| x.abs()).collect(); + let mean = abs_vals.iter().sum::() / abs_vals.len() as f64; + let var = abs_vals.iter().map(|x| (x - mean).powi(2)).sum::() / abs_vals.len() as f64; + var.sqrt() +} + +/// 分位数分箱:返回最后一个值所在分箱 +pub fn qcut_last_label(values: &[f64], q: usize) -> Option { + if q == 0 || values.is_empty() { + return None; + } + let mut sorted: Vec = values.iter().copied().filter(|x| x.is_finite()).collect(); + if sorted.is_empty() { + return None; + } + sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + + let quantile = |p: f64| -> f64 { + if sorted.len() == 1 { + return sorted[0]; + } + let h = (sorted.len() - 1) as f64 * p; + let i = h.floor() as usize; + let j = h.ceil() as usize; + if i == j { + sorted[i] + } else { + sorted[i] + (h - i as f64) * (sorted[j] - sorted[i]) + } + }; + + let mut edges = Vec::with_capacity(q + 1); + for i in 0..=q { + edges.push(quantile(i as f64 / q as f64)); + } + edges.dedup_by(|a, b| (*a - *b).abs() <= f64::EPSILON); + if edges.len() <= 1 { + return None; + } + + let x = *values.last()?; + if !x.is_finite() { + return None; + } + let bins = edges.len() - 1; + if x < edges[0] || x > edges[bins] { + return None; + } + for i in 0..bins { + let left_ok = if i == 0 { x >= edges[i] } else { x > edges[i] }; + let right_ok = x <= edges[i + 1]; + if left_ok && right_ok { + return Some(i); + } + } + None +} + +/// 对齐 `pandas.cut(values, bins=n, right=True)`,返回最后一个值的分箱标签(1..=n) +pub fn pd_cut_last_label(values: &[f64], n: usize) -> Option { + if n == 0 || values.is_empty() { + return None; + } + if values.iter().any(|x| !x.is_finite()) { + return None; + } + let min_val = values.iter().copied().fold(f64::INFINITY, f64::min); + let max_val = values.iter().copied().fold(f64::NEG_INFINITY, f64::max); + if !min_val.is_finite() || !max_val.is_finite() { + return None; + } + + // 对齐 pandas.cut bins=int 行为: + // - 常量序列:左右各扩展 0.1%(零值用固定 ±0.001) + // - 非常量序列:只扩展左边界 0.1% + let bins: Vec = if (max_val - min_val).abs() < f64::EPSILON { + let (lo, hi) = if min_val != 0.0 { + let delta = 0.001 * min_val.abs(); + (min_val - delta, max_val + delta) + } else { + (-0.001, 0.001) + }; + (0..=n) + .map(|i| lo + (hi - lo) * i as f64 / n as f64) + .collect() + } else { + let mut bs: Vec = (0..=n) + .map(|i| min_val + (max_val - min_val) * i as f64 / n as f64) + .collect(); + bs[0] -= (max_val - min_val) * 0.001; + bs + }; + + let x = *values.last()?; + // lower_bound: 第一个 >= x 的索引,等价 pandas/numpy searchsorted(side='left') + let mut l = 0usize; + let mut r = bins.len(); + while l < r { + let mid = (l + r) / 2; + if bins[mid] < x { + l = mid + 1; + } else { + r = mid; + } + } + let idx = l; + let q = if idx == 0 { + 1 + } else if idx >= bins.len() { + n + } else { + idx + }; + Some(q) +} + +#[cfg(test)] +#[allow(clippy::items_after_test_module)] +mod tests { + use super::{ + cut_last_bin_label, get_str_param, get_sub_elements, get_usize_param, + intraday_time_segment, pd_cut_last_label, weekday_cn, + }; + use crate::params::ParamView; + use czsc_core::objects::bar::RawBar; + use serde_json::Value; + use std::collections::HashMap; + + #[test] + fn test_get_sub_elements_di1() { + let x = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; + assert_eq!(get_sub_elements(&x, 1, 3), &[7, 8, 9]); + } + + #[test] + fn test_get_sub_elements_di2() { + let x = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; + assert_eq!(get_sub_elements(&x, 2, 3), &[6, 7, 8]); + } + + #[test] + fn test_get_sub_elements_short_and_out_of_range() { + let x = vec![1, 2, 3]; + assert_eq!(get_sub_elements(&x, 1, 10), &[1, 2, 3]); + assert_eq!(get_sub_elements(&x, 4, 2), &[] as &[i32]); + } + + #[test] + fn test_get_sub_elements_n_zero_aligns_python_slice() { + let x = vec![1, 2, 3, 4]; + assert_eq!(get_sub_elements(&x, 1, 0), &[1, 2, 3, 4]); + assert_eq!(get_sub_elements(&x, 2, 0), &[] as &[i32]); + } + + #[test] + fn test_param_helpers_accept_param_view() { + let mut m = HashMap::new(); + m.insert("di".to_string(), Value::String("3".to_string())); + m.insert("ma_type".to_string(), Value::String("EMA".to_string())); + let p = ParamView::new(&m); + assert_eq!(get_usize_param(&p, "di", 1), 3); + assert_eq!(get_str_param(&p, "ma_type", "SMA"), "EMA"); + assert_eq!(get_usize_param(&p, "n", 5), 5); + } + + #[test] + fn test_cut_last_bin_label_constant_series_center_bin() { + let v = vec![-13.858860437903786]; + assert_eq!(cut_last_bin_label(&v, 5), Some(3)); + assert_eq!(cut_last_bin_label(&v, 4), Some(2)); + assert_eq!(cut_last_bin_label(&v, 3), Some(2)); + } + + #[test] + fn test_cut_last_bin_label_non_constant_series() { + let v = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + assert_eq!(cut_last_bin_label(&v, 5), Some(5)); + } + + #[test] + fn test_pd_cut_last_label_constant_center_bin() { + let v1 = vec![5.0; 100]; + let v2 = vec![0.0; 100]; + assert_eq!(pd_cut_last_label(&v1, 5), Some(3)); + assert_eq!(pd_cut_last_label(&v2, 5), Some(3)); + } + + #[test] + fn test_pd_cut_last_label_non_constant_boundary_bins() { + let mut low = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + low[4] = 1.0; + let high = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + assert_eq!(pd_cut_last_label(&low, 5), Some(1)); + assert_eq!(pd_cut_last_label(&high, 5), Some(5)); + } + + #[test] + fn test_pd_cut_last_label_rejects_non_finite() { + let v = vec![1.0, 2.0, f64::NAN]; + assert_eq!(pd_cut_last_label(&v, 5), None); + } + + #[test] + fn test_intraday_time_segment_basic() { + use chrono::TimeZone; + let bars = vec![ + RawBar { + symbol: "T".into(), + id: 1, + dt: chrono::Utc.with_ymd_and_hms(2024, 1, 1, 9, 30, 0).unwrap(), + freq: czsc_core::objects::freq::Freq::F60, + open: 1.0, + close: 1.0, + high: 1.0, + low: 1.0, + vol: 1.0, + amount: 1.0, + cache: Default::default(), + }, + RawBar { + symbol: "T".into(), + id: 2, + dt: chrono::Utc.with_ymd_and_hms(2024, 1, 2, 10, 30, 0).unwrap(), + freq: czsc_core::objects::freq::Freq::F60, + open: 1.0, + close: 1.0, + high: 1.0, + low: 1.0, + vol: 1.0, + amount: 1.0, + cache: Default::default(), + }, + RawBar { + symbol: "T".into(), + id: 3, + dt: chrono::Utc.with_ymd_and_hms(2024, 1, 3, 10, 30, 0).unwrap(), + freq: czsc_core::objects::freq::Freq::F60, + open: 1.0, + close: 1.0, + high: 1.0, + low: 1.0, + vol: 1.0, + amount: 1.0, + cache: Default::default(), + }, + ]; + assert_eq!(intraday_time_segment(&bars, 3), Some(2)); + assert_eq!(intraday_time_segment(&bars, 4), None); + } + + #[test] + fn test_weekday_cn_mapping() { + use chrono::TimeZone; + let dt = chrono::Utc.with_ymd_and_hms(2024, 1, 1, 9, 30, 0).unwrap(); // 周一 + assert_eq!(weekday_cn(dt), "周一"); + } +} + +/// 解析字符串参数 +pub fn get_str_param<'a>(params: &'a ParamView, key: &str, default: &'a str) -> &'a str { + if let Some(val) = params.value(key) { + if let Some(s) = val.as_str() { + return s; + } + } + default +} diff --git a/crates/czsc-signals/src/utils/ta.rs b/crates/czsc-signals/src/utils/ta.rs new file mode 100644 index 000000000..31571fd32 --- /dev/null +++ b/crates/czsc-signals/src/utils/ta.rs @@ -0,0 +1,1662 @@ +use crate::types::{BollSeries, KdjSeries, MacdSeries, TaCache}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::bar::RawBar; +use std::collections::HashMap; + +#[derive(Debug, Clone, Copy)] +pub enum MacdField { + Dif, + Dea, + Macd, +} + +fn macd_field_at(series: &MacdSeries, idx: usize, field: MacdField) -> f64 { + match field { + MacdField::Dif => series.dif[idx], + MacdField::Dea => series.dea[idx], + MacdField::Macd => series.macd[idx], + } +} + +fn macd_field_from_tuple(values: (f64, f64, f64), field: MacdField) -> f64 { + match field { + MacdField::Dif => values.0, + MacdField::Dea => values.1, + MacdField::Macd => values.2, + } +} + +/// 读取指定 RawBar 的 MACD 字段值,并对齐 Python 的“RawBar 快照”语义: +/// +/// - 若 `raw_bar.close` 与当前 `czsc.bars_raw[idx].close` 一致,直接返回缓存值; +/// - 若不一致(常见于同 dt 延伸阶段的历史快照),在“仅替换该 idx close”为快照值的条件下重算 MACD, +/// 并返回该 idx 对应字段值。 +/// +/// `snapshot_overrides` 用于同一轮信号计算内复用同 id 的重算结果,避免重复全量计算。 +#[allow(clippy::too_many_arguments)] +fn macd_snapshot_field_value_with_calc( + czsc: &CZSC, + series: &MacdSeries, + id_to_idx: &HashMap, + raw_bar: &RawBar, + short: usize, + long: usize, + m: usize, + field: MacdField, + snapshot_overrides: &mut HashMap, +) -> Option { + let &idx = id_to_idx.get(&raw_bar.id)?; + let base = macd_field_at(series, idx, field); + let current_close = czsc.bars_raw.get(idx)?.close; + if (current_close - raw_bar.close).abs() <= f64::EPSILON { + return Some(base); + } + + if let Some(values) = snapshot_overrides.get(&raw_bar.id) { + return Some(macd_field_from_tuple(*values, field)); + } + + let mut close: Vec = czsc.bars_raw.iter().map(|b| b.close).collect(); + close[idx] = raw_bar.close; + let snapshot = calc_macd_cache_style(&close, short, long, m); + let values = (snapshot.dif[idx], snapshot.dea[idx], snapshot.macd[idx]); + snapshot_overrides.insert(raw_bar.id, values); + Some(macd_field_from_tuple(values, field)) +} + +#[allow(clippy::too_many_arguments)] +pub fn macd_snapshot_field_value( + czsc: &CZSC, + series: &MacdSeries, + id_to_idx: &HashMap, + raw_bar: &RawBar, + short: usize, + long: usize, + m: usize, + field: MacdField, + snapshot_overrides: &mut HashMap, +) -> Option { + macd_snapshot_field_value_with_calc( + czsc, + series, + id_to_idx, + raw_bar, + short, + long, + m, + field, + snapshot_overrides, + ) +} + +#[allow(clippy::too_many_arguments)] +pub fn macd_snapshot_field_value_py_style( + czsc: &CZSC, + series: &MacdSeries, + id_to_idx: &HashMap, + raw_bar: &RawBar, + short: usize, + long: usize, + m: usize, + field: MacdField, + snapshot_overrides: &mut HashMap, +) -> Option { + macd_snapshot_field_value( + czsc, + series, + id_to_idx, + raw_bar, + short, + long, + m, + field, + snapshot_overrides, + ) +} + +/// 读取指定 RawBar 的 MA 值,并对齐 Python 的“RawBar 快照”语义: +/// +/// - 若 `raw_bar.close` 与当前 `czsc.bars_raw[idx].close` 一致,直接返回缓存值; +/// - 若不一致(常见于同 dt 延伸阶段的历史快照),在“仅替换该 idx close”为快照值的条件下重算 MA, +/// 并返回该 idx 对应值。 +/// +/// `snapshot_overrides` 用于同一轮信号计算内复用同 id 的重算结果,避免重复全量计算。 +pub fn ma_snapshot_value( + czsc: &CZSC, + series: &[f64], + id_to_idx: &HashMap, + raw_bar: &RawBar, + ma_type: &str, + timeperiod: usize, + snapshot_overrides: &mut HashMap, +) -> Option { + let &idx = id_to_idx.get(&raw_bar.id)?; + let base = *series.get(idx)?; + let current_close = czsc.bars_raw.get(idx)?.close; + if (current_close - raw_bar.close).abs() <= f64::EPSILON { + return Some(base); + } + + if let Some(value) = snapshot_overrides.get(&raw_bar.id) { + return Some(*value); + } + + let mut close: Vec = czsc.bars_raw.iter().map(|b| b.close).collect(); + close[idx] = raw_bar.close; + let snapshot = match ma_type.to_uppercase().as_str() { + "EMA" => calc_ema_cache_style(&close, timeperiod), + "WMA" => calc_wma_cache_style(&close, timeperiod), + _ => calc_sma_cache_style(&close, timeperiod), + }; + let value = *snapshot.get(idx)?; + snapshot_overrides.insert(raw_bar.id, value); + Some(value) +} + +/// 计算简单移动平均线 (SMA),对齐 Python `czsc.utils.ta.SMA` +pub fn calc_sma(series: &[f64], n: usize) -> Vec { + let len = series.len(); + let mut ms = vec![f64::NAN; len]; + if len == 0 || n == 0 { + return ms; + } + for i in 0..len { + let start = if i < n { 0 } else { i + 1 - n }; + let window = &series[start..=i]; + let mean = window.iter().sum::() / window.len() as f64; + ms[i] = (mean * 10_000.0).round() / 10_000.0; + } + ms +} + +/// 计算指数移动平均线 (EMA),对齐 Python `czsc.utils.ta.EMA` +pub fn calc_ema(series: &[f64], n: usize) -> Vec { + let len = series.len(); + let mut ms = vec![f64::NAN; len]; + if len == 0 || n == 0 { + return ms; + } + ms[0] = series[0]; + for i in 1..len { + let ema = (2.0 * series[i] + ms[i - 1] * (n as f64 - 1.0)) / (n as f64 + 1.0); + ms[i] = ema; + } + for value in &mut ms { + *value = (*value * 10_000.0).round() / 10_000.0; + } + ms +} + +/// 计算加权移动平均线 (WMA),对齐 Python `czsc.utils.ta.WMA` +pub fn calc_wma(series: &[f64], n: usize) -> Vec { + let len = series.len(); + let mut ms = vec![f64::NAN; len]; + if len == 0 || n == 0 { + return ms; + } + let denom = (n * (n + 1) / 2) as f64; + for i in n..len { + let window = &series[i + 1 - n..=i]; + let weighted = window + .iter() + .enumerate() + .map(|(idx, value)| *value * (idx + 1) as f64) + .sum::(); + ms[i] = (weighted / denom * 10_000.0).round() / 10_000.0; + } + ms +} + +fn calc_sma_cache_style(series: &[f64], n: usize) -> Vec { + let len = series.len(); + let mut out = vec![f64::NAN; len]; + if len < n || n == 0 { + return out; + } + let mut sum = series[..n].iter().sum::(); + out[n - 1] = sum / n as f64; + for i in n..len { + // 对齐 TA-Lib SMA 的累计顺序:先减旧值,再加新值。 + // 该顺序会影响极小浮点误差在边界位上的符号,进而影响分类信号。 + sum -= series[i - n]; + sum += series[i]; + out[i] = sum / n as f64; + } + out +} + +fn calc_ema_cache_style(series: &[f64], n: usize) -> Vec { + calc_ema_talib_style(series, n) +} + +fn calc_wma_cache_style(series: &[f64], n: usize) -> Vec { + let len = series.len(); + let mut out = vec![f64::NAN; len]; + if len < n || n == 0 { + return out; + } + let denom = (n * (n + 1) / 2) as f64; + for i in (n - 1)..len { + let window = &series[i + 1 - n..=i]; + let weighted = window + .iter() + .enumerate() + .map(|(idx, value)| *value * (idx + 1) as f64) + .sum::(); + out[i] = weighted / denom; + } + out +} + +/// 计算 MACD {dif, dea, macd},对齐 TA-Lib 原始三元组语义: +/// - `dif` = MACD line +/// - `dea` = signal line +/// - `macd` = histogram +pub fn calc_macd(series: &[f64], short: usize, long: usize, m: usize) -> MacdSeries { + let len = series.len(); + let mut dif = vec![f64::NAN; len]; + let mut dea = vec![f64::NAN; len]; + let mut macd = vec![f64::NAN; len]; + if len == 0 || short == 0 || long == 0 || m == 0 || len < long { + return MacdSeries { + ids: Vec::new(), + dif, + dea, + macd, + }; + } + + let fast_offset = long.saturating_sub(short); + let ema_short_tail = calc_ema_talib_style(&series[fast_offset..], short); + let ema_long = calc_ema_talib_style(series, long); + let mut dif_input = vec![f64::NAN; len]; + let mut dif_raw = vec![f64::NAN; len]; + for i in 0..len { + let ema_short_i = if i >= fast_offset { + ema_short_tail[i - fast_offset] + } else { + f64::NAN + }; + if ema_short_i.is_finite() && ema_long[i].is_finite() { + dif_raw[i] = ema_short_i - ema_long[i]; + dif_input[i] = dif_raw[i]; + } + } + + dea = calc_ema_talib_skip_leading_nan(&dif_input, m); + + for i in 0..len { + if dif_raw[i].is_finite() && dea[i].is_finite() { + dif[i] = dif_raw[i]; + macd[i] = dif_raw[i] - dea[i]; + } + } + + MacdSeries { + ids: Vec::new(), + dif, + dea, + macd, + } +} + +fn calc_ema_talib_style(series: &[f64], n: usize) -> Vec { + let len = series.len(); + let mut out = vec![f64::NAN; len]; + if len < n || n == 0 { + return out; + } + let alpha = 2.0 / (n as f64 + 1.0); + let seed = series[..n].iter().sum::() / n as f64; + out[n - 1] = seed; + for i in n..len { + out[i] = alpha * series[i] + (1.0 - alpha) * out[i - 1]; + } + out +} + +fn calc_ema_talib_skip_leading_nan(series: &[f64], n: usize) -> Vec { + let len = series.len(); + let mut out = vec![f64::NAN; len]; + let Some(start) = series.iter().position(|x| x.is_finite()) else { + return out; + }; + let tail = calc_ema_talib_style(&series[start..], n); + for (i, value) in tail.into_iter().enumerate() { + out[start + i] = value; + } + out +} + +/// 计算用于信号缓存的 MACD,默认与 `calc_macd` 使用同一 TA-Lib 原始三元组契约。 +pub fn calc_macd_cache_style(series: &[f64], short: usize, long: usize, m: usize) -> MacdSeries { + calc_macd(series, short, long, m) +} + +/// 与 `calc_macd` 完全同义,保留仅用于兼容迁移中的旧调用点。 +pub fn calc_macd_py_style(series: &[f64], short: usize, long: usize, m: usize) -> MacdSeries { + calc_macd(series, short, long, m) +} + +/// 计算 ATR(Wilder 平滑),与 TA-Lib ATR 口径一致 +pub fn calc_atr(high: &[f64], low: &[f64], close: &[f64], timeperiod: usize) -> Vec { + let len = high.len().min(low.len()).min(close.len()); + let mut atr = vec![f64::NAN; len]; + if len == 0 || timeperiod == 0 || len <= timeperiod { + return atr; + } + + let mut tr = vec![0.0; len]; + for i in 0..len { + let prev_close = if i == 0 { close[0] } else { close[i - 1] }; + let hl = high[i] - low[i]; + let hc = (high[i] - prev_close).abs(); + let lc = (low[i] - prev_close).abs(); + tr[i] = hl.max(hc).max(lc); + } + + // 对齐 TA-Lib ATR: 首个有效值在索引 `timeperiod`, + // 使用 TR[1..=timeperiod] 的均值作为种子(TR[0] 仅用于占位)。 + let first = tr[1..=timeperiod].iter().sum::() / timeperiod as f64; + atr[timeperiod] = first; + for i in (timeperiod + 1)..len { + atr[i] = (atr[i - 1] * (timeperiod as f64 - 1.0) + tr[i]) / timeperiod as f64; + } + atr +} + +/// 计算 CCI,与 TA-Lib CCI 公式一致 +pub fn calc_cci(high: &[f64], low: &[f64], close: &[f64], timeperiod: usize) -> Vec { + let len = high.len().min(low.len()).min(close.len()); + let mut out = vec![f64::NAN; len]; + if len == 0 || timeperiod == 0 || len < timeperiod { + return out; + } + + let tp: Vec = (0..len) + .map(|i| (high[i] + low[i] + close[i]) / 3.0) + .collect(); + for i in (timeperiod - 1)..len { + let w = &tp[i + 1 - timeperiod..=i]; + let ma = w.iter().sum::() / timeperiod as f64; + let md = w.iter().map(|x| (x - ma).abs()).sum::() / timeperiod as f64; + out[i] = if md == 0.0 { + 0.0 + } else { + (tp[i] - ma) / (0.015 * md) + }; + } + out +} + +/// 更新 MA 缓存 +pub fn update_ma_cache( + czsc: &CZSC, + cache_key: &str, + ma_type: &str, + timeperiod: usize, + cache: &mut TaCache, +) { + let now_len = czsc.bars_raw.len(); + if now_len == 0 { + return; + } + let ma_type_u = ma_type.to_uppercase(); + let bar_ids: Vec = czsc.bars_raw.iter().map(|b| b.id).collect(); + + let mut need_init = !cache.series.contains_key(cache_key) || now_len < timeperiod + 15; + if !need_init { + if let Some(existing_ids) = cache.series_ids.get(cache_key) { + if now_len < 2 || existing_ids.is_empty() { + need_init = true; + } else { + let penultimate_id = bar_ids[now_len - 2]; + need_init = !existing_ids.contains(&penultimate_id); + } + } else { + need_init = true; + } + } + + let calc = |close: &[f64]| match ma_type_u.as_str() { + "SMA" => calc_sma_cache_style(close, timeperiod), + "EMA" => calc_ema_cache_style(close, timeperiod), + "WMA" => calc_wma_cache_style(close, timeperiod), + _ => calc_sma_cache_style(close, timeperiod), + }; + + if need_init { + let close: Vec = czsc.bars_raw.iter().map(|b| b.close).collect(); + let res = calc(&close); + cache.series.insert(cache_key.to_string(), res); + cache.series_ids.insert(cache_key.to_string(), bar_ids); + cache.last_len = now_len; + return; + } + + let existing = cache.series.get(cache_key).unwrap(); + let existing_ids = cache.series_ids.get(cache_key).unwrap(); + let mut old_map: HashMap = HashMap::with_capacity(existing_ids.len()); + for (i, id) in existing_ids.iter().enumerate() { + old_map.insert(*id, existing[i]); + } + let mut res = Vec::with_capacity(now_len); + for id in &bar_ids { + res.push(*old_map.get(id).unwrap_or(&f64::NAN)); + } + + let window_size = (timeperiod + 10).min(now_len); + let window_start = now_len - window_size; + let close: Vec = czsc.bars_raw[window_start..] + .iter() + .map(|b| b.close) + .collect(); + let partial = calc(&close); + for i in 1..=5.min(window_size) { + let dst_idx = now_len - i; + let src_idx = window_size - i; + res[dst_idx] = partial[src_idx]; + } + + cache.series.insert(cache_key.to_string(), res); + cache.series_ids.insert(cache_key.to_string(), bar_ids); + cache.last_len = now_len; +} + +/// 更新成交量 MA 缓存(对齐 Python `update_vol_ma_cache` 增量语义) +pub fn update_vol_ma_cache( + czsc: &CZSC, + cache_key: &str, + ma_type: &str, + timeperiod: usize, + cache: &mut TaCache, +) { + let now_len = czsc.bars_raw.len(); + if now_len == 0 { + return; + } + let ma_type_u = ma_type.to_uppercase(); + let bar_ids: Vec = czsc.bars_raw.iter().map(|b| b.id).collect(); + + let mut need_init = !cache.series.contains_key(cache_key) || now_len < timeperiod + 15; + if !need_init { + if let Some(existing_ids) = cache.series_ids.get(cache_key) { + if now_len < 2 || existing_ids.is_empty() { + need_init = true; + } else { + let penultimate_id = bar_ids[now_len - 2]; + need_init = !existing_ids.contains(&penultimate_id); + } + } else { + need_init = true; + } + } + + let calc = |vol: &[f64]| match ma_type_u.as_str() { + "SMA" => calc_sma(vol, timeperiod), + "EMA" => calc_ema(vol, timeperiod), + "WMA" => calc_wma(vol, timeperiod), + _ => calc_sma(vol, timeperiod), + }; + + if need_init { + let vol: Vec = czsc.bars_raw.iter().map(|b| b.vol).collect(); + let res = calc(&vol); + cache.series.insert(cache_key.to_string(), res); + cache.series_ids.insert(cache_key.to_string(), bar_ids); + cache.last_len = now_len; + return; + } + + let existing = cache.series.get(cache_key).unwrap(); + let existing_ids = cache.series_ids.get(cache_key).unwrap(); + let mut old_map: HashMap = HashMap::with_capacity(existing_ids.len()); + for (i, id) in existing_ids.iter().enumerate() { + old_map.insert(*id, existing[i]); + } + let mut res = Vec::with_capacity(now_len); + for id in &bar_ids { + res.push(*old_map.get(id).unwrap_or(&f64::NAN)); + } + + let window_size = (timeperiod + 10).min(now_len); + let window_start = now_len - window_size; + let vol: Vec = czsc.bars_raw[window_start..] + .iter() + .map(|b| b.vol) + .collect(); + let partial = calc(&vol); + for i in 1..=3.min(window_size) { + let dst_idx = now_len - i; + let src_idx = window_size - i; + res[dst_idx] = partial[src_idx]; + } + + cache.series.insert(cache_key.to_string(), res); + cache.series_ids.insert(cache_key.to_string(), bar_ids); + cache.last_len = now_len; +} + +/// 更新 MACD 缓存,对齐 Python `update_macd_cache` 的增量逻辑 +pub fn update_macd_cache( + czsc: &CZSC, + cache_key: &str, + short: usize, + long: usize, + m: usize, + cache: &mut TaCache, +) { + let now_len = czsc.bars_raw.len(); + if now_len == 0 { + return; + } + let bar_ids: Vec = czsc.bars_raw.iter().map(|b| b.id).collect(); + let min_count = m + long + 168; + + let mut need_init = !cache.macd.contains_key(cache_key) || now_len < min_count + 15; + if !need_init { + if let Some(existing) = cache.macd.get(cache_key) { + if now_len < 2 || existing.ids.is_empty() { + need_init = true; + } else { + let penultimate_id = bar_ids[now_len - 2]; + need_init = !existing.ids.contains(&penultimate_id); + } + } else { + need_init = true; + } + } + + if need_init { + let close: Vec = czsc.bars_raw.iter().map(|b| b.close).collect(); + let mut res = calc_macd_cache_style(&close, short, long, m); + res.ids = bar_ids; + cache.macd.insert(cache_key.to_string(), res); + cache.last_len = now_len; + return; + } + + let existing = cache.macd.get(cache_key).unwrap(); + let mut old_map: HashMap = HashMap::with_capacity(existing.ids.len()); + for (i, id) in existing.ids.iter().enumerate() { + old_map.insert(*id, (existing.dif[i], existing.dea[i], existing.macd[i])); + } + let mut dif = Vec::with_capacity(now_len); + let mut dea = Vec::with_capacity(now_len); + let mut macd = Vec::with_capacity(now_len); + for id in &bar_ids { + if let Some((d1, d2, d3)) = old_map.get(id) { + dif.push(*d1); + dea.push(*d2); + macd.push(*d3); + } else { + dif.push(f64::NAN); + dea.push(f64::NAN); + macd.push(f64::NAN); + } + } + + let window_size = (min_count + 10).min(now_len); + let window_start = now_len - window_size; + let close: Vec = czsc.bars_raw[window_start..] + .iter() + .map(|b| b.close) + .collect(); + let partial = calc_macd_cache_style(&close, short, long, m); + + for i in 1..=5.min(window_size) { + let dst_idx = now_len - i; + let src_idx = window_size - i; + dif[dst_idx] = partial.dif[src_idx]; + dea[dst_idx] = partial.dea[src_idx]; + macd[dst_idx] = partial.macd[src_idx]; + } + + cache.macd.insert( + cache_key.to_string(), + MacdSeries { + ids: bar_ids, + dif, + dea, + macd, + }, + ); + cache.last_len = now_len; +} + +/// 与 `update_macd_cache` 完全同义,保留仅用于兼容迁移中的旧调用点。 +pub fn update_macd_cache_py_style( + czsc: &CZSC, + cache_key: &str, + short: usize, + long: usize, + m: usize, + cache: &mut TaCache, +) { + update_macd_cache(czsc, cache_key, short, long, m, cache); +} + +/// 更新 BOLL 缓存 +pub fn update_boll_cache( + czsc: &CZSC, + cache_key: &str, + timeperiod: usize, + nbdev: f64, + cache: &mut TaCache, +) { + let now_len = czsc.bars_raw.len(); + if now_len == 0 { + return; + } + let bar_ids: Vec = czsc.bars_raw.iter().map(|b| b.id).collect(); + + let calc_full = |close: &[f64]| { + let n = close.len(); + let mut upper = vec![f64::NAN; n]; + let mut mid = vec![f64::NAN; n]; + let mut lower = vec![f64::NAN; n]; + if n >= timeperiod { + for i in (timeperiod - 1)..n { + let window = &close[i + 1 - timeperiod..=i]; + let mean = window.iter().sum::() / timeperiod as f64; + // numpy/talib 默认用的是总体标准差(ddof=0) + let variance = + window.iter().map(|&x| (x - mean).powi(2)).sum::() / timeperiod as f64; + let std_dev = variance.sqrt(); + mid[i] = mean; + upper[i] = mean + nbdev * std_dev; + lower[i] = mean - nbdev * std_dev; + } + } + BollSeries { upper, mid, lower } + }; + + let mut need_init = !cache.boll.contains_key(cache_key) + || !cache.boll_ids.contains_key(cache_key) + || now_len < timeperiod + 15; + if !need_init { + if let Some(existing_ids) = cache.boll_ids.get(cache_key) { + if now_len < 2 || existing_ids.is_empty() { + need_init = true; + } else { + let penultimate_id = bar_ids[now_len - 2]; + need_init = !existing_ids.contains(&penultimate_id); + } + } else { + need_init = true; + } + } + + if need_init { + let close: Vec = czsc.bars_raw.iter().map(|b| b.close).collect(); + let res = calc_full(&close); + cache.boll.insert(cache_key.to_string(), res); + cache.boll_ids.insert(cache_key.to_string(), bar_ids); + cache.last_len = now_len; + return; + } + + let existing = cache.boll.get(cache_key).unwrap(); + let existing_ids = cache.boll_ids.get(cache_key).unwrap(); + let mut old_map: HashMap = HashMap::with_capacity(existing_ids.len()); + for (i, id) in existing_ids.iter().enumerate() { + old_map.insert(*id, (existing.upper[i], existing.mid[i], existing.lower[i])); + } + + let mut upper = Vec::with_capacity(now_len); + let mut mid = Vec::with_capacity(now_len); + let mut lower = Vec::with_capacity(now_len); + for id in &bar_ids { + if let Some((u, m, l)) = old_map.get(id) { + upper.push(*u); + mid.push(*m); + lower.push(*l); + } else { + upper.push(f64::NAN); + mid.push(f64::NAN); + lower.push(f64::NAN); + } + } + + // 对齐 Python update_boll_cache:增量阶段重算尾窗并覆盖最近 5 根。 + let window_size = (timeperiod + 10).min(now_len); + let window_start = now_len - window_size; + let close: Vec = czsc.bars_raw[window_start..] + .iter() + .map(|b| b.close) + .collect(); + let partial = calc_full(&close); + for i in 1..=5.min(window_size) { + let dst_idx = now_len - i; + let src_idx = window_size - i; + upper[dst_idx] = partial.upper[src_idx]; + mid[dst_idx] = partial.mid[src_idx]; + lower[dst_idx] = partial.lower[src_idx]; + } + + cache + .boll + .insert(cache_key.to_string(), BollSeries { upper, mid, lower }); + cache.boll_ids.insert(cache_key.to_string(), bar_ids); + cache.last_len = now_len; +} + +/// 更新 ATR 缓存,对齐 Python `update_atr_cache` 的初始化/增量口径 +pub fn update_atr_cache(czsc: &CZSC, cache_key: &str, timeperiod: usize, cache: &mut TaCache) { + let now_len = czsc.bars_raw.len(); + if now_len == 0 { + return; + } + let bar_ids: Vec = czsc.bars_raw.iter().map(|b| b.id).collect(); + + let mut need_init = !cache.series.contains_key(cache_key) || now_len < timeperiod + 15; + if !need_init { + if let Some(existing_ids) = cache.series_ids.get(cache_key) { + if now_len < 2 || existing_ids.is_empty() { + need_init = true; + } else { + let penultimate_id = bar_ids[now_len - 2]; + need_init = !existing_ids.contains(&penultimate_id); + } + } else { + need_init = true; + } + } + + let calc_full = |h: &[f64], l: &[f64], c: &[f64]| calc_atr(h, l, c, timeperiod); + + if need_init { + let high: Vec = czsc.bars_raw.iter().map(|b| b.high).collect(); + let low: Vec = czsc.bars_raw.iter().map(|b| b.low).collect(); + let close: Vec = czsc.bars_raw.iter().map(|b| b.close).collect(); + let res = calc_full(&high, &low, &close); + cache.series.insert(cache_key.to_string(), res); + cache.series_ids.insert(cache_key.to_string(), bar_ids); + cache.last_len = now_len; + return; + } + + let existing = cache.series.get(cache_key).unwrap(); + let existing_ids = cache.series_ids.get(cache_key).unwrap(); + let mut old_map: HashMap = HashMap::with_capacity(existing_ids.len()); + for (i, id) in existing_ids.iter().enumerate() { + old_map.insert(*id, existing[i]); + } + let mut res = Vec::with_capacity(now_len); + for id in &bar_ids { + res.push(*old_map.get(id).unwrap_or(&f64::NAN)); + } + + // 对齐 Python update_atr_cache: 增量阶段回看 timeperiod+80 窗口 + let window_size = (timeperiod + 80).min(now_len); + let window_start = now_len - window_size; + let high: Vec = czsc.bars_raw[window_start..] + .iter() + .map(|b| b.high) + .collect(); + let low: Vec = czsc.bars_raw[window_start..] + .iter() + .map(|b| b.low) + .collect(); + let close: Vec = czsc.bars_raw[window_start..] + .iter() + .map(|b| b.close) + .collect(); + let partial = calc_full(&high, &low, &close); + + // 对齐 Python: 历史 bar 仅补齐未写入过 cache_key 的值,不覆盖既有值。 + // 但最后一根未完成高周期 bar 在流式更新时会持续变化;Python 侧该对象的 cache + // 会随新对象重建而刷新,Rust 侧需要显式重算末值避免把 ATR 冻结在更早时刻。 + for (i, partial_i) in partial.iter().enumerate().take(window_size) { + let dst = window_start + i; + if res[dst].is_nan() { + res[dst] = *partial_i; + } + } + let last_dst = now_len - 1; + let last_src = window_size - 1; + res[last_dst] = partial[last_src]; + + cache.series.insert(cache_key.to_string(), res); + cache.series_ids.insert(cache_key.to_string(), bar_ids); + cache.last_len = now_len; +} + +/// 更新 CCI 缓存,对齐 Python `update_cci_cache` 的初始化/增量口径 +pub fn update_cci_cache(czsc: &CZSC, cache_key: &str, timeperiod: usize, cache: &mut TaCache) { + let now_len = czsc.bars_raw.len(); + if now_len == 0 { + return; + } + let bar_ids: Vec = czsc.bars_raw.iter().map(|b| b.id).collect(); + + let mut need_init = !cache.series.contains_key(cache_key) || now_len < timeperiod + 15; + if !need_init { + if let Some(existing_ids) = cache.series_ids.get(cache_key) { + if now_len < 2 || existing_ids.is_empty() { + need_init = true; + } else { + let penultimate_id = bar_ids[now_len - 2]; + need_init = !existing_ids.contains(&penultimate_id); + } + } else { + need_init = true; + } + } + + let calc_full = |h: &[f64], l: &[f64], c: &[f64]| calc_cci(h, l, c, timeperiod); + if need_init { + let high: Vec = czsc.bars_raw.iter().map(|b| b.high).collect(); + let low: Vec = czsc.bars_raw.iter().map(|b| b.low).collect(); + let close: Vec = czsc.bars_raw.iter().map(|b| b.close).collect(); + let res = calc_full(&high, &low, &close); + cache.series.insert(cache_key.to_string(), res); + cache.series_ids.insert(cache_key.to_string(), bar_ids); + cache.last_len = now_len; + return; + } + + let existing = cache.series.get(cache_key).unwrap(); + let existing_ids = cache.series_ids.get(cache_key).unwrap(); + let mut old_map: HashMap = HashMap::with_capacity(existing_ids.len()); + for (i, id) in existing_ids.iter().enumerate() { + old_map.insert(*id, existing[i]); + } + let mut res = Vec::with_capacity(now_len); + for id in &bar_ids { + res.push(*old_map.get(id).unwrap_or(&f64::NAN)); + } + + // 对齐 Python update_cci_cache: 增量阶段回看 timeperiod + 10 + let window_size = (timeperiod + 10).min(now_len); + let window_start = now_len - window_size; + let high: Vec = czsc.bars_raw[window_start..] + .iter() + .map(|b| b.high) + .collect(); + let low: Vec = czsc.bars_raw[window_start..] + .iter() + .map(|b| b.low) + .collect(); + let close: Vec = czsc.bars_raw[window_start..] + .iter() + .map(|b| b.close) + .collect(); + let partial = calc_full(&high, &low, &close); + + // 对齐 Python: 历史 bar 仅补齐未写入过 cache_key 的值,不覆盖既有值。 + // 但流式场景下未完成高周期 bar 会复用同一 id;Rust 侧需要显式刷新末值, + // 否则 CCI 会冻结在更早时刻,导致阈值类信号(如 CCI 决策区域)持续偏移。 + for (i, partial_i) in partial.iter().enumerate().take(window_size) { + let dst = window_start + i; + if res[dst].is_nan() { + res[dst] = *partial_i; + } + } + let last_dst = now_len - 1; + let last_src = window_size - 1; + res[last_dst] = partial[last_src]; + + cache.series.insert(cache_key.to_string(), res); + cache.series_ids.insert(cache_key.to_string(), bar_ids); + cache.last_len = now_len; +} + +/// 更新 KDJ 缓存 +pub fn update_kdj_cache( + czsc: &CZSC, + cache_key: &str, + fastk_period: usize, + slowk_period: usize, + slowd_period: usize, + cache: &mut TaCache, +) { + let now_len = czsc.bars_raw.len(); + if now_len == 0 || fastk_period == 0 || slowk_period == 0 || slowd_period == 0 { + return; + } + let bar_ids: Vec = czsc.bars_raw.iter().map(|b| b.id).collect(); + let min_count = fastk_period + slowk_period; + + let mut need_init = !cache.kdj.contains_key(cache_key) || now_len < min_count + 15; + if !need_init { + if let Some(existing) = cache.kdj.get(cache_key) { + if now_len < 2 || existing.ids.is_empty() { + need_init = true; + } else { + let penultimate_id = bar_ids[now_len - 2]; + need_init = !existing.ids.contains(&penultimate_id); + } + } else { + need_init = true; + } + } + + if need_init { + let high: Vec = czsc.bars_raw.iter().map(|b| b.high).collect(); + let low: Vec = czsc.bars_raw.iter().map(|b| b.low).collect(); + let close: Vec = czsc.bars_raw.iter().map(|b| b.close).collect(); + let (k, d) = calc_stoch( + &high, + &low, + &close, + fastk_period, + slowk_period, + slowd_period, + ); + let j: Vec = k + .iter() + .zip(d.iter()) + .map(|(x, y)| 3.0 * *x - 2.0 * *y) + .collect(); + cache.kdj.insert( + cache_key.to_string(), + KdjSeries { + ids: bar_ids, + k, + d, + j, + }, + ); + cache.last_len = now_len; + return; + } + + // 增量更新:先按 id 对齐旧缓存,再覆盖最近 5 根 + let existing = cache.kdj.get(cache_key).unwrap(); + let mut old_map: HashMap = HashMap::with_capacity(existing.ids.len()); + for (i, id) in existing.ids.iter().enumerate() { + old_map.insert(*id, (existing.k[i], existing.d[i], existing.j[i])); + } + + let mut k = Vec::with_capacity(now_len); + let mut d = Vec::with_capacity(now_len); + let mut j = Vec::with_capacity(now_len); + for id in &bar_ids { + if let Some((k0, d0, j0)) = old_map.get(id) { + k.push(*k0); + d.push(*d0); + j.push(*j0); + } else { + k.push(f64::NAN); + d.push(f64::NAN); + j.push(f64::NAN); + } + } + + let window_size = (min_count + 10).min(now_len); + let window_start = now_len - window_size; + let high: Vec = czsc.bars_raw[window_start..] + .iter() + .map(|b| b.high) + .collect(); + let low: Vec = czsc.bars_raw[window_start..] + .iter() + .map(|b| b.low) + .collect(); + let close: Vec = czsc.bars_raw[window_start..] + .iter() + .map(|b| b.close) + .collect(); + let (partial_k, partial_d) = calc_stoch( + &high, + &low, + &close, + fastk_period, + slowk_period, + slowd_period, + ); + for i in 1..=5.min(window_size) { + let dst_idx = now_len - i; + let src_idx = window_size - i; + k[dst_idx] = partial_k[src_idx]; + d[dst_idx] = partial_d[src_idx]; + j[dst_idx] = 3.0 * k[dst_idx] - 2.0 * d[dst_idx]; + } + + cache.kdj.insert( + cache_key.to_string(), + KdjSeries { + ids: bar_ids, + k, + d, + j, + }, + ); + cache.last_len = now_len; +} + +/// 更新 RSI 缓存 (Wilder's Smoothing,严格对齐 TA-Lib) +pub fn update_rsi_cache(czsc: &CZSC, cache_key: &str, timeperiod: usize, cache: &mut TaCache) { + let now_len = czsc.bars_raw.len(); + if now_len == 0 { + return; + } + let bar_ids: Vec = czsc.bars_raw.iter().map(|b| b.id).collect(); + + // 对齐 Python update_rsi_cache 的初始化/增量口径。 + // Rust 流式场景下同一高周期 bar 会在多次 update 中复用同一 id, + // 这里不能因“最后一根 id 已存在”直接返回,否则 RSI 末值会被冻结。 + // 因此每次都重算窗口尾部,确保未完成 bar 的 RSI 随 close 更新。 + if !cache.series.contains_key(cache_key) || !cache.series_ids.contains_key(cache_key) { + let close: Vec = czsc.bars_raw.iter().map(|b| b.close).collect(); + let rsi_res = calc_rsi(&close, timeperiod); + cache.series.insert(cache_key.to_string(), rsi_res); + cache.series_ids.insert(cache_key.to_string(), bar_ids); + cache.last_len = now_len; + return; + } + + let existing = cache.series.get(cache_key).unwrap(); + let existing_ids = cache.series_ids.get(cache_key).unwrap(); + let mut old_map: HashMap = HashMap::with_capacity(existing_ids.len()); + for (i, id) in existing_ids.iter().enumerate() { + old_map.insert(*id, existing[i]); + } + let use_full = + now_len < timeperiod + 15 || now_len < 2 || !old_map.contains_key(&bar_ids[now_len - 2]); + if use_full { + let close: Vec = czsc.bars_raw.iter().map(|b| b.close).collect(); + let rsi_res = calc_rsi(&close, timeperiod); + cache.series.insert(cache_key.to_string(), rsi_res); + cache.series_ids.insert(cache_key.to_string(), bar_ids); + cache.last_len = now_len; + return; + } + + let mut rsi_res = Vec::with_capacity(now_len); + for id in &bar_ids { + rsi_res.push(*old_map.get(id).unwrap_or(&f64::NAN)); + } + + let window_size = (timeperiod + 10).min(now_len); + let window_start = now_len - window_size; + let close: Vec = czsc.bars_raw[window_start..] + .iter() + .map(|b| b.close) + .collect(); + let partial = calc_rsi(&close, timeperiod); + for i in 1..=5.min(window_size) { + let dst_idx = now_len - i; + let src_idx = window_size - i; + rsi_res[dst_idx] = partial[src_idx]; + } + + cache.series.insert(cache_key.to_string(), rsi_res); + cache.series_ids.insert(cache_key.to_string(), bar_ids); + cache.last_len = now_len; +} + +/// 更新 SAR 缓存(对齐 Python `update_sar_cache`) +pub fn update_sar_cache(czsc: &CZSC, cache_key: &str, cache: &mut TaCache) { + let now_len = czsc.bars_raw.len(); + if now_len == 0 { + return; + } + let bar_ids: Vec = czsc.bars_raw.iter().map(|b| b.id).collect(); + let calc_full = |h: &[f64], l: &[f64]| calc_sar(h, l, 0.02, 0.2); + + if !cache.series.contains_key(cache_key) || !cache.series_ids.contains_key(cache_key) { + let high: Vec = czsc.bars_raw.iter().map(|b| b.high).collect(); + let low: Vec = czsc.bars_raw.iter().map(|b| b.low).collect(); + let res = calc_full(&high, &low); + cache.series.insert(cache_key.to_string(), res); + cache.series_ids.insert(cache_key.to_string(), bar_ids); + cache.last_len = now_len; + return; + } + + let existing = cache.series.get(cache_key).unwrap(); + let existing_ids = cache.series_ids.get(cache_key).unwrap(); + let mut old_map: HashMap = HashMap::with_capacity(existing_ids.len()); + for (i, id) in existing_ids.iter().enumerate() { + old_map.insert(*id, existing[i]); + } + + let use_full = now_len < 50 || now_len < 2 || !old_map.contains_key(&bar_ids[now_len - 2]); + let (window_start, window_size) = if use_full { + (0usize, now_len) + } else { + let size = 120.min(now_len); + (now_len - size, size) + }; + + let high: Vec = czsc.bars_raw[window_start..] + .iter() + .map(|b| b.high) + .collect(); + let low: Vec = czsc.bars_raw[window_start..] + .iter() + .map(|b| b.low) + .collect(); + let partial = calc_full(&high, &low); + + let mut res = Vec::with_capacity(now_len); + for id in &bar_ids { + res.push(*old_map.get(id).unwrap_or(&f64::NAN)); + } + + for (i, partial_i) in partial.iter().enumerate().take(window_size) { + let dst = window_start + i; + if !old_map.contains_key(&bar_ids[dst]) { + res[dst] = *partial_i; + } + } + // 与 ATR/RSI 的流式处理一致:无论 id 是否已存在,都刷新最后一根。 + // 这样才能对齐 Python 中“最后一根 bar 对象被重建后重新写 cache”的效果。 + let last_dst = now_len - 1; + let last_src = window_size - 1; + res[last_dst] = partial[last_src]; + + cache.series.insert(cache_key.to_string(), res); + cache.series_ids.insert(cache_key.to_string(), bar_ids); + cache.last_len = now_len; +} + +fn calc_sar(high: &[f64], low: &[f64], acceleration: f64, max_acceleration: f64) -> Vec { + let len = high.len().min(low.len()); + let mut out = vec![f64::NAN; len]; + if len < 2 { + return out; + } + + let up_move = high[1] - high[0]; + let down_move = low[0] - low[1]; + let plus_dm = if up_move > down_move && up_move > 0.0 { + up_move + } else { + 0.0 + }; + let minus_dm = if down_move > up_move && down_move > 0.0 { + down_move + } else { + 0.0 + }; + let mut is_long = minus_dm == 0.0 || plus_dm > minus_dm; + + let accel = acceleration.min(max_acceleration); + let mut af = accel; + let mut today_idx = 1usize; + let mut out_idx = 1usize; + let mut new_high = high[today_idx - 1]; + let mut new_low = low[today_idx - 1]; + + let mut ep; + let mut sar; + if is_long { + ep = high[today_idx]; + sar = new_low; + } else { + ep = low[today_idx]; + sar = new_high; + } + + new_low = low[today_idx]; + new_high = high[today_idx]; + + while today_idx < len { + let prev_low = new_low; + let prev_high = new_high; + new_low = low[today_idx]; + new_high = high[today_idx]; + today_idx += 1; + + if is_long { + if new_low <= sar { + is_long = false; + sar = ep; + if sar < prev_high { + sar = prev_high; + } + if sar < new_high { + sar = new_high; + } + out[out_idx] = sar; + af = accel; + ep = new_low; + sar = sar + af * (ep - sar); + if sar < prev_high { + sar = prev_high; + } + if sar < new_high { + sar = new_high; + } + } else { + out[out_idx] = sar; + if new_high > ep { + ep = new_high; + af = (af + accel).min(max_acceleration); + } + sar = sar + af * (ep - sar); + if sar > prev_low { + sar = prev_low; + } + if sar > new_low { + sar = new_low; + } + } + } else if new_high >= sar { + is_long = true; + sar = ep; + if sar > prev_low { + sar = prev_low; + } + if sar > new_low { + sar = new_low; + } + out[out_idx] = sar; + af = accel; + ep = new_high; + sar = sar + af * (ep - sar); + if sar > prev_low { + sar = prev_low; + } + if sar > new_low { + sar = new_low; + } + } else { + out[out_idx] = sar; + if new_low < ep { + ep = new_low; + af = (af + accel).min(max_acceleration); + } + sar = sar + af * (ep - sar); + if sar < prev_high { + sar = prev_high; + } + if sar < new_high { + sar = new_high; + } + } + out_idx += 1; + } + out +} + +fn calc_rsi(close: &[f64], timeperiod: usize) -> Vec { + let now_len = close.len(); + let mut rsi_res = vec![f64::NAN; now_len]; + if now_len <= timeperiod || timeperiod == 0 { + return rsi_res; + } + + let mut avg_gain = 0.0; + let mut avg_loss = 0.0; + for i in 1..=timeperiod { + let change = close[i] - close[i - 1]; + if change > 0.0 { + avg_gain += change; + } else { + avg_loss += -change; + } + } + avg_gain /= timeperiod as f64; + avg_loss /= timeperiod as f64; + + rsi_res[timeperiod] = { + let sum = avg_gain + avg_loss; + if sum != 0.0 { + 100.0 * (avg_gain / sum) + } else { + 0.0 + } + }; + + for i in (timeperiod + 1)..now_len { + let delta = close[i] - close[i - 1]; + + // 逐步对齐 TA-Lib: 先乘 period-1,再加今日涨跌,最后除以 period。 + avg_gain *= timeperiod as f64 - 1.0; + avg_loss *= timeperiod as f64 - 1.0; + if delta < 0.0 { + avg_loss -= delta; + } else { + avg_gain += delta; + } + avg_gain /= timeperiod as f64; + avg_loss /= timeperiod as f64; + + let sum = avg_gain + avg_loss; + rsi_res[i] = if sum != 0.0 { + 100.0 * (avg_gain / sum) + } else { + 0.0 + }; + } + + rsi_res +} + +fn calc_sma_nan(series: &[f64], n: usize) -> Vec { + let len = series.len(); + let mut out = vec![f64::NAN; len]; + if n == 0 || len < n { + return out; + } + for i in (n - 1)..len { + let w = &series[i + 1 - n..=i]; + if w.iter().any(|x| x.is_nan()) { + continue; + } + out[i] = w.iter().sum::() / n as f64; + } + out +} + +fn calc_stoch( + high: &[f64], + low: &[f64], + close: &[f64], + fastk_period: usize, + slowk_period: usize, + slowd_period: usize, +) -> (Vec, Vec) { + let len = close.len(); + let mut fastk = vec![f64::NAN; len]; + if len == 0 + || high.len() != len + || low.len() != len + || fastk_period == 0 + || slowk_period == 0 + || slowd_period == 0 + || len < fastk_period + { + return (vec![f64::NAN; len], vec![f64::NAN; len]); + } + + for i in (fastk_period - 1)..len { + let start = i + 1 - fastk_period; + let hh = high[start..=i] + .iter() + .fold(f64::NEG_INFINITY, |a, &b| a.max(b)); + let ll = low[start..=i].iter().fold(f64::INFINITY, |a, &b| a.min(b)); + if hh > ll { + fastk[i] = (close[i] - ll) / (hh - ll) * 100.0; + } + } + + let mut slowk = calc_sma_nan(&fastk, slowk_period); + let mut slowd = calc_sma_nan(&slowk, slowd_period); + + // Align with TA-Lib STOCH lookback: fastk-1 + slowk-1 + slowd-1. + // TA-Lib returns both slowk/slowd as NaN before this index. + let lookback = (fastk_period - 1) + (slowk_period - 1) + (slowd_period - 1); + let warmup = lookback.min(len); + for i in 0..warmup { + slowk[i] = f64::NAN; + slowd[i] = f64::NAN; + } + (slowk, slowd) +} + +#[cfg(test)] +mod tests { + use super::{ + calc_ema, calc_macd, calc_macd_cache_style, calc_macd_py_style, calc_rsi, calc_sma, + calc_stoch, calc_wma, + }; + + #[test] + fn test_calc_sma_matches_python_expanding_mean_before_window() { + let series = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + let res = calc_sma(&series, 3); + assert_eq!(res, vec![1.0, 1.5, 2.0, 3.0, 4.0]); + } + + #[test] + fn test_calc_ema_matches_python_seed_from_first_value() { + let series = vec![10.1234, 14.9812, 9.3345, 11.7789]; + let res = calc_ema(&series, 3); + assert_eq!(res, vec![10.1234, 12.5523, 10.9434, 11.3612]); + } + + #[test] + fn test_calc_wma_matches_python_validity_window() { + let series = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + let res = calc_wma(&series, 3); + assert!(res[0].is_nan()); + assert!(res[1].is_nan()); + assert!(res[2].is_nan()); + assert!((res[3] - 3.3333).abs() < 1e-4); + assert!((res[4] - 4.3333).abs() < 1e-4); + } + + #[test] + fn test_calc_macd_uses_talib_warmup_and_histogram_contract() { + let series = vec![ + 10.1234, 14.9812, 9.3345, 11.7789, 13.5521, 12.0044, 15.3311, 14.2088, 13.8877, + 16.2355, 15.0222, 17.1188, 16.0044, 18.4455, 17.9933, 19.2244, 18.1102, 17.8891, + 20.0033, 19.7721, 21.1144, 20.8842, 22.3355, 21.6622, 23.1188, 22.7744, 24.3355, + 23.9911, 22.8877, 21.4455, 20.8844, 19.7733, 21.1188, 22.5522, 23.8877, 22.6644, + 24.1188, 25.2244, 24.0033, 26.1188, + ]; + let res = calc_macd(&series, 12, 26, 9); + assert_eq!(res.dif.len(), series.len()); + assert!(res.dif[..33].iter().all(|x| x.is_nan())); + assert!(res.dea[..33].iter().all(|x| x.is_nan())); + assert!(res.macd[..33].iter().all(|x| x.is_nan())); + + let exp_dif = [ + 2.1423805846933526, + 2.178795717541078, + 2.084911326267463, + 2.103616018844914, + 2.1824938717891307, + 2.122011347779168, + 2.219200124149946, + ]; + let exp_dea = [ + 2.8872106434589506, + 2.745527658275376, + 2.6134043918737935, + 2.5114467172680177, + 2.4456561481722403, + 2.380927188093626, + 2.34858177530489, + ]; + let exp_macd = [ + -0.744830058765598, + -0.5667319407342979, + -0.5284930656063307, + -0.40783069842310393, + -0.26316227638310963, + -0.258915840314458, + -0.1293816511549437, + ]; + for (offset, idx) in (33..40).enumerate() { + assert!((res.dif[idx] - exp_dif[offset]).abs() < 1e-12); + assert!((res.dea[idx] - exp_dea[offset]).abs() < 1e-12); + assert!((res.macd[idx] - exp_macd[offset]).abs() < 1e-12); + } + } + + #[test] + fn test_calc_macd_cache_style_matches_talib_contract() { + let series = vec![ + 10.1234, 14.9812, 9.3345, 11.7789, 13.5521, 12.0044, 15.3311, 14.2088, 13.8877, + 16.2355, 15.0222, 17.1188, 16.0044, 18.4455, 17.9933, 19.2244, 18.1102, 17.8891, + 20.0033, 19.7721, 21.1144, 20.8842, 22.3355, 21.6622, 23.1188, 22.7744, 24.3355, + 23.9911, 22.8877, 21.4455, 20.8844, 19.7733, 21.1188, 22.5522, 23.8877, 22.6644, + 24.1188, 25.2244, 24.0033, 26.1188, + ]; + let base = calc_macd(&series, 12, 26, 9); + let cache = calc_macd_cache_style(&series, 12, 26, 9); + assert_eq!(base.dif.len(), cache.dif.len()); + for idx in 0..series.len() { + if base.dif[idx].is_nan() { + assert!(cache.dif[idx].is_nan()); + assert!(cache.dea[idx].is_nan()); + assert!(cache.macd[idx].is_nan()); + } else { + assert!((base.dif[idx] - cache.dif[idx]).abs() < 1e-12); + assert!((base.dea[idx] - cache.dea[idx]).abs() < 1e-12); + assert!((base.macd[idx] - cache.macd[idx]).abs() < 1e-12); + } + } + } + + #[test] + fn test_calc_macd_plain_and_py_style_share_canonical_contract() { + let series: Vec = (1..=200).map(|x| x as f64 / 3.7).collect(); + let res = calc_macd(&series, 12, 26, 9); + let alias = calc_macd_py_style(&series, 12, 26, 9); + assert_eq!(res.dif.len(), alias.dif.len()); + for idx in 0..res.dif.len() { + if res.dif[idx].is_nan() { + assert!(alias.dif[idx].is_nan()); + assert!(alias.dea[idx].is_nan()); + assert!(alias.macd[idx].is_nan()); + } else { + assert!((res.dif[idx] - alias.dif[idx]).abs() < 1e-12); + assert!((res.dea[idx] - alias.dea[idx]).abs() < 1e-12); + assert!((res.macd[idx] - alias.macd[idx]).abs() < 1e-12); + let expect = res.dif[idx] - res.dea[idx]; + assert!((res.macd[idx] - expect).abs() < 1e-12); + } + } + } + + #[test] + fn test_calc_macd_talib_histogram_is_not_doubled() { + let series = vec![ + 10.1234, 14.9812, 9.3345, 11.7789, 13.5521, 12.0044, 15.3311, 14.2088, 13.8877, + 16.2355, 15.0222, 17.1188, 16.0044, 18.4455, 17.9933, 19.2244, 18.1102, 17.8891, + 20.0033, 19.7721, 21.1144, 20.8842, 22.3355, 21.6622, 23.1188, 22.7744, 24.3355, + 23.9911, 22.8877, 21.4455, 20.8844, 19.7733, 21.1188, 22.5522, 23.8877, 22.6644, + 24.1188, 25.2244, 24.0033, 26.1188, + ]; + let res = calc_macd(&series, 12, 26, 9); + let idx = 39; + let hist = res.macd[idx]; + let doubled = (res.dif[idx] - res.dea[idx]) * 2.0; + assert!((hist - (res.dif[idx] - res.dea[idx])).abs() < 1e-12); + assert!((hist - doubled).abs() > 1e-6); + } + + #[test] + fn test_calc_rsi_flat_tail_aligns_talib_direction_case1() { + let close = vec![ + 22763.8, 22769.4, 22671.5, 22864.0, 22941.1, 22778.9, 22763.9, 23276.7, 23126.7, + 23062.3, 23230.6, 23115.5, 22916.9, 22962.4, 22974.4, 22974.4, + ]; + let rsi = calc_rsi(&close, 6); + let prev = rsi[14]; + let curr = rsi[15]; + assert!(prev.is_finite() && curr.is_finite()); + assert!( + curr < prev, + "expected talib direction down, prev={prev}, curr={curr}" + ); + } + + #[test] + fn test_calc_rsi_flat_tail_aligns_talib_direction_case2() { + let close = vec![ + 67887.5, 67935.1, 68395.2, 68305.0, 68548.2, 68600.9, 68731.1, 68881.0, 68788.3, + 68808.5, 68459.1, 68673.0, 68949.0, 69173.0, 69566.1, 69566.1, + ]; + let rsi = calc_rsi(&close, 6); + let prev = rsi[14]; + let curr = rsi[15]; + assert!(prev.is_finite() && curr.is_finite()); + assert!( + curr < prev, + "expected talib direction down, prev={prev}, curr={curr}" + ); + } + + #[test] + fn test_calc_stoch_all_nan_when_range_zero() { + let high = vec![1.0; 11]; + let low = vec![1.0; 11]; + let close = vec![1.0; 11]; + let (k, d) = calc_stoch(&high, &low, &close, 9, 3, 3); + assert!(k.iter().all(|x| x.is_nan())); + assert!(d.iter().all(|x| x.is_nan())); + } + + #[test] + fn test_calc_stoch_matches_talib_reference_sequence() { + let high = vec![ + 10.0, 11.0, 12.0, 11.0, 13.0, 14.0, 15.0, 14.0, 16.0, 17.0, 18.0, 17.0, 19.0, 20.0, + 21.0, 20.0, 22.0, 23.0, 22.0, 24.0, + ]; + let low = vec![ + 9.0, 10.0, 11.0, 10.0, 12.0, 13.0, 14.0, 13.0, 15.0, 16.0, 17.0, 16.0, 18.0, 19.0, + 20.0, 19.0, 21.0, 22.0, 21.0, 23.0, + ]; + let close = vec![ + 9.5, 10.5, 11.5, 10.8, 12.6, 13.2, 14.7, 13.5, 15.4, 16.8, 17.1, 16.4, 18.6, 19.4, + 20.9, 19.7, 21.3, 22.7, 21.8, 23.4, + ]; + let (k, d) = calc_stoch(&high, &low, &close, 9, 3, 3); + + let exp_k = [ + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + 87.67857142857143, + 88.57142857142856, + 94.82142857142856, + 91.3095238095238, + 90.83333333333333, + 89.82142857142856, + 89.52380952380952, + 90.35714285714285, + ]; + let exp_d = [ + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + 89.58333333333333, + 88.29365079365078, + 90.35714285714285, + 91.5674603174603, + 92.32142857142856, + 90.65476190476188, + 90.0595238095238, + 89.90079365079363, + ]; + + for i in 0..k.len() { + if exp_k[i].is_nan() { + assert!(k[i].is_nan(), "k[{i}] should be NaN but got {}", k[i]); + } else { + assert!( + (k[i] - exp_k[i]).abs() < 1e-12, + "k[{i}] = {}, expect {}", + k[i], + exp_k[i] + ); + } + if exp_d[i].is_nan() { + assert!(d[i].is_nan(), "d[{i}] should be NaN but got {}", d[i]); + } else { + assert!( + (d[i] - exp_d[i]).abs() < 1e-12, + "d[{i}] = {}, expect {}", + d[i], + exp_d[i] + ); + } + } + } +} diff --git a/crates/czsc-signals/src/utils/zdy.rs b/crates/czsc-signals/src/utils/zdy.rs new file mode 100644 index 000000000..04b362282 --- /dev/null +++ b/crates/czsc-signals/src/utils/zdy.rs @@ -0,0 +1,63 @@ +use crate::types::TaCache; +use crate::utils::ta::update_macd_cache; +use czsc_core::analyze::CZSC; +use czsc_core::objects::bi::BI; +use czsc_core::objects::zs::ZS; +use std::collections::HashMap; + +pub fn macd_cache_maps( + c: &CZSC, + fast: usize, + slow: usize, + signal: usize, + cache: &mut TaCache, +) -> (HashMap, HashMap, HashMap) { + let short = fast.min(slow); + let long = fast.max(slow); + let cache_key = format!("MACD{}#{}#{}", short, long, signal); + update_macd_cache(c, &cache_key, short, long, signal, cache); + + let mut dif_map = HashMap::new(); + let mut dea_map = HashMap::new(); + let mut macd_map = HashMap::new(); + if let Some(series) = cache.macd.get(&cache_key) { + for (i, id) in series.ids.iter().enumerate() { + dif_map.insert(*id, series.dif[i]); + dea_map.insert(*id, series.dea[i]); + macd_map.insert(*id, series.macd[i]); + } + } + (dif_map, dea_map, macd_map) +} + +pub fn is_valid_zs(bis: &[BI]) -> bool { + if bis.len() < 3 { + return false; + } + ZS::new(bis.to_vec()).is_valid() +} + +pub fn find_peaks_valleys(data: &[f64]) -> (HashMap, HashMap) { + let mut peaks = HashMap::new(); + let mut valleys = HashMap::new(); + if data.len() < 5 { + return (peaks, valleys); + } + for i in 2..data.len() - 2 { + if data[i - 2] < data[i - 1] + && data[i - 1] < data[i] + && data[i] > data[i + 1] + && data[i + 1] > data[i + 2] + { + peaks.insert(i, data[i]); + } + if data[i - 2] > data[i - 1] + && data[i - 1] > data[i] + && data[i] < data[i + 1] + && data[i + 1] < data[i + 2] + { + valleys.insert(i, data[i]); + } + } + (peaks, valleys) +} diff --git a/crates/czsc-signals/src/vol.rs b/crates/czsc-signals/src/vol.rs new file mode 100644 index 000000000..6b06f19db --- /dev/null +++ b/crates/czsc-signals/src/vol.rs @@ -0,0 +1,432 @@ +use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::sig::{bar_index_map, get_sub_elements, make_kline_signal_v1, make_kline_signal_v2}; +use crate::utils::ta::update_vol_ma_cache; +use czsc_core::analyze::CZSC; +use czsc_core::objects::signal::Signal; +use czsc_signal_macros::signal; + +fn qcut_labels(values: &[f64], q: usize) -> Option> { + if q == 0 || values.is_empty() || values.iter().any(|x| !x.is_finite()) { + return None; + } + let mut sorted: Vec = values.to_vec(); + sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + + let quantile = |p: f64| -> f64 { + if sorted.len() == 1 { + return sorted[0]; + } + let h = (sorted.len() - 1) as f64 * p; + let i = h.floor() as usize; + let j = h.ceil() as usize; + if i == j { + sorted[i] + } else { + sorted[i] + (h - i as f64) * (sorted[j] - sorted[i]) + } + }; + + let mut edges = Vec::with_capacity(q + 1); + for i in 0..=q { + edges.push(quantile(i as f64 / q as f64)); + } + edges.dedup_by(|a, b| (*a - *b).abs() <= f64::EPSILON); + if edges.len() <= 1 { + return None; + } + let bins = edges.len() - 1; + let mut labels = Vec::with_capacity(values.len()); + for &x in values { + if x < edges[0] || x > edges[bins] { + return None; + } + let mut found = None; + for i in 0..bins { + let left_ok = if i == 0 { x >= edges[i] } else { x > edges[i] }; + let right_ok = x <= edges[i + 1]; + if left_ok && right_ok { + found = Some(i); + break; + } + } + labels.push(found.unwrap_or(bins - 1)); + } + Some(labels) +} + +/// vol_single_ma_V230214:单成交量均线多空与方向信号 +/// +/// 参数模板:`"{freq}_D{di}VOL#{ma_type}#{timeperiod}_分类V230214"` +/// +/// 信号逻辑: +/// 1. 计算指定成交量均线(`SMA/EMA/WMA`); +/// 2. `vol_now >= vol_ma_now` 判定 `多头`,否则 `空头`; +/// 3. `vol_ma_now >= vol_ma_prev` 判定 `向上`,否则 `向下`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1VOL#SMA#5_分类V230214_多头_向上_任意_0')` +/// - `Signal('60分钟_D1VOL#EMA#12_分类V230214_空头_向下_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `ma_type`:均线类型,默认 `SMA`; +/// - `timeperiod`:均线周期,默认 `5`。 +/// 对齐说明:成交量均线缓存与判定口径对齐 Python `vol_single_ma_V230214`。 +#[signal( + category = "kline", + name = "vol_single_ma_V230214", + template = "{freq}_D{di}VOL#{ma_type}#{timeperiod}_分类V230214", + opcode = "VolSingleMaV230214", + param_kind = "VolSingleMaV230214" +)] +pub fn vol_single_ma_v230214(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let ma_type = params.str("ma_type", "SMA").to_uppercase(); + let timeperiod = params.usize("timeperiod", 5); + let k1 = c.freq.to_string(); + let k2 = format!("D{}VOL#{}#{}", di, ma_type, timeperiod); + let k3 = "分类V230214"; + + let cache_key = format!("VOL#{}#{}", ma_type, timeperiod); + update_vol_ma_cache(c, &cache_key, &ma_type, timeperiod, cache); + + let bars = get_sub_elements(&c.bars_raw, di, 3); + if bars.len() < 2 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let Some(ma) = cache.series.get(&cache_key) else { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + }; + let idx_map = bar_index_map(c); + let idx_prev = bars.len() - 2; + let idx_last = bars.len() - 1; + let Some(i_prev) = idx_map.get(&bars[idx_prev].id).copied() else { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + }; + let Some(i_last) = idx_map.get(&bars[idx_last].id).copied() else { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + }; + if i_prev >= ma.len() || i_last >= ma.len() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let v1 = if bars[idx_last].vol >= ma[i_last] { + "多头" + } else { + "空头" + }; + let v2 = if ma[i_last] >= ma[i_prev] { + "向上" + } else { + "向下" + }; + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// vol_double_ma_V230214:成交量双均线多空信号 +/// +/// 参数模板:`"{freq}_D{di}VOL双均线{ma_type}#{t1}#{t2}_BS辅助V230214"` +/// +/// 信号逻辑: +/// 1. 分别计算成交量短均线 `t1` 与长均线 `t2`; +/// 2. `vol_ma_short >= vol_ma_long` 判定 `看多`,否则 `看空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1VOL双均线SMA#5#20_BS辅助V230214_看多_任意_任意_0')` +/// - `Signal('60分钟_D1VOL双均线EMA#5#20_BS辅助V230214_看空_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `t1`:短均线周期,默认 `5`; +/// - `t2`:长均线周期,默认 `20`; +/// - `ma_type`:均线类型,默认 `SMA`。 +/// 对齐说明:短长成交量均线关系判定与 Python `vol_double_ma_V230214` 一致。 +#[signal( + category = "kline", + name = "vol_double_ma_V230214", + template = "{freq}_D{di}VOL双均线{ma_type}#{t1}#{t2}_BS辅助V230214", + opcode = "VolDoubleMaV230214", + param_kind = "VolDoubleMaV230214" +)] +pub fn vol_double_ma_v230214(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let t1 = params.usize("t1", 5); + let t2 = params.usize("t2", 20); + assert!(t2 > t1, "t2 must be greater than t1"); + let ma_type = params.str("ma_type", "SMA").to_uppercase(); + let k1 = c.freq.to_string(); + let k2 = format!("D{}VOL双均线{}#{}#{}", di, ma_type, t1, t2); + let k3 = "BS辅助V230214"; + + let cache_key1 = format!("VOL#{}#{}", ma_type, t1); + let cache_key2 = format!("VOL#{}#{}", ma_type, t2); + update_vol_ma_cache(c, &cache_key1, &ma_type, t1, cache); + update_vol_ma_cache(c, &cache_key2, &ma_type, t2, cache); + + let bars = get_sub_elements(&c.bars_raw, di, 3); + if bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let Some(ma1) = cache.series.get(&cache_key1) else { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + }; + let Some(ma2) = cache.series.get(&cache_key2) else { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + }; + let idx_map = bar_index_map(c); + let Some(i_last) = idx_map.get(&bars[bars.len() - 1].id).copied() else { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + }; + if i_last >= ma1.len() || i_last >= ma2.len() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let v1 = if ma1[i_last] >= ma2[i_last] { + "看多" + } else { + "看空" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// vol_ti_suo_V221216:梯量与缩量柱信号 +/// +/// 参数模板:`"{freq}_D{di}K_量柱V221216"` +/// +/// 信号逻辑: +/// 1. 连续三根成交量递增判定 `梯量`,递减判定 `缩量`; +/// 2. 在 `梯量/缩量` 前提下,以当前收盘与前两根收盘区间比较得到 `价升/价跌/价平`; +/// 3. 不满足量柱条件时返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1K_量柱V221216_梯量_价升_任意_0')` +/// - `Signal('60分钟_D1K_量柱V221216_缩量_价跌_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`。 +/// 对齐说明:量柱与价位分类规则对齐 Python `vol_ti_suo_V221216`。 +#[signal( + category = "kline", + name = "vol_ti_suo_V221216", + template = "{freq}_D{di}K_量柱V221216", + opcode = "VolTiSuoV221216", + param_kind = "VolTiSuoV221216" +)] +pub fn vol_ti_suo_v221216(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}K", di); + let k3 = "量柱V221216"; + if c.bars_raw.len() < di + 5 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + + let bar1 = &c.bars_raw[c.bars_raw.len() - di]; + let bar2 = &c.bars_raw[c.bars_raw.len() - di - 1]; + let bar3 = &c.bars_raw[c.bars_raw.len() - di - 2]; + let close_max = bar2.close.max(bar3.close); + let close_min = bar2.close.min(bar3.close); + + let v1 = if bar1.vol > bar2.vol && bar2.vol > bar3.vol { + "梯量" + } else if bar1.vol < bar2.vol && bar2.vol < bar3.vol { + "缩量" + } else { + "其他" + }; + if v1 == "其他" { + return make_kline_signal_v1(&k1, &k2, k3, v1); + } + + let v2 = if bar1.close < close_min && bar1.close < bar1.open { + "价跌" + } else if bar1.close > close_max && bar1.close > bar1.open { + "价升" + } else { + "价平" + }; + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// vol_gao_di_V221218:高量柱与低量柱信号 +/// +/// 参数模板:`"{freq}_D{di}K_量柱V221218"` +/// +/// 信号逻辑: +/// 1. 依次检查 `10/9/8/7/6` 根窗口; +/// 2. 若末根成交量为窗口最大值,判 `高量柱`; +/// 3. 若次末根为窗口最大且末根不足其 50%,判 `高量黄金柱`; +/// 4. 若末根成交量为窗口最小值,判 `低量柱`; +/// 5. 命中后输出对应窗口长度(如 `10K`)。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1K_量柱V221218_高量柱_10K_任意_0')` +/// - `Signal('60分钟_D1K_量柱V221218_低量柱_7K_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`。 +/// 对齐说明:窗口递减检查顺序与高/低量柱定义对齐 Python `vol_gao_di_V221218`。 +#[signal( + category = "kline", + name = "vol_gao_di_V221218", + template = "{freq}_D{di}K_量柱V221218", + opcode = "VolGaoDiV221218", + param_kind = "VolGaoDiV221218" +)] +pub fn vol_gao_di_v221218(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}K", di); + let k3 = "量柱V221218"; + let mut v1 = "其他"; + let mut v2 = "任意".to_string(); + + for n in [10usize, 9, 8, 7, 6] { + let bars = get_sub_elements(&c.bars_raw, di, n); + if bars.len() != n || bars.len() <= 5 { + continue; + } + let max_vol = bars.iter().map(|x| x.vol).fold(f64::NEG_INFINITY, f64::max); + let min_vol = bars.iter().map(|x| x.vol).fold(f64::INFINITY, f64::min); + let last = &bars[bars.len() - 1]; + let prev = &bars[bars.len() - 2]; + let cur = if (last.vol - max_vol).abs() <= f64::EPSILON { + "高量柱" + } else if (prev.vol - max_vol).abs() <= f64::EPSILON && last.vol < prev.vol * 0.5 { + "高量黄金柱" + } else if (last.vol - min_vol).abs() <= f64::EPSILON { + "低量柱" + } else { + "其他" + }; + if cur != "其他" { + v1 = cur; + v2 = format!("{}K", n); + break; + } + } + make_kline_signal_v2(&k1, &k2, k3, v1, &v2) +} + +/// vol_window_V230731:窗口成交量分层特征 +/// +/// 参数模板:`"{freq}_D{di}W{w}M{m}N{n}_窗口能量V230731"` +/// +/// 信号逻辑: +/// 1. 取最近 `m` 根成交量并按 `qcut` 分成 `n` 层; +/// 2. 统计最近 `w` 根中的最高层与最低层; +/// 3. 输出 `高量N{max}` 与 `低量N{min}`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D2W5M100N10_窗口能量V230731_高量N9_低量N4_任意_0')` +/// - `Signal('60分钟_D1W5M30N10_窗口能量V230731_高量N10_低量N3_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `w`:观察窗口大小,默认 `5`; +/// - `m`:分层样本长度,默认 `30`; +/// - `n`:分层数量,默认 `10`。 +/// 对齐说明:分层采用与 Python `pd.qcut(..., duplicates='drop')` 等价的去重分位边界。 +#[signal( + category = "kline", + name = "vol_window_V230731", + template = "{freq}_D{di}W{w}M{m}N{n}_窗口能量V230731", + opcode = "VolWindowV230731", + param_kind = "VolWindowV230731" +)] +pub fn vol_window_v230731(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let w = params.usize("w", 5); + let m = params.usize("m", 30); + let n = params.usize("n", 10); + let k1 = c.freq.to_string(); + let k2 = format!("D{}W{}M{}N{}", di, w, m, n); + let k3 = "窗口能量V230731"; + + if c.bars_raw.len() < di + m { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let vols: Vec = get_sub_elements(&c.bars_raw, di, m) + .iter() + .map(|x| x.vol) + .collect(); + let Some(labels) = qcut_labels(&vols, n) else { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + }; + if labels.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let tail = if w >= labels.len() { + &labels[..] + } else { + &labels[labels.len() - w..] + }; + let max_layer = tail.iter().copied().max().unwrap_or(0) + 1; + let min_layer = tail.iter().copied().min().unwrap_or(0) + 1; + let v1 = format!("高量N{}", max_layer); + let v2 = format!("低量N{}", min_layer); + make_kline_signal_v2(&k1, &k2, k3, &v1, &v2) +} + +/// vol_window_V230801:窗口成交量先后顺序特征 +/// +/// 参数模板:`"{freq}_D{di}W{w}_窗口能量V230801"` +/// +/// 信号逻辑: +/// 1. 取最近 `w` 根成交量; +/// 2. 若最小量索引在最大量索引之后,判 `先放后缩`; +/// 3. 否则判 `先缩后放`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1W5_窗口能量V230801_先放后缩_任意_任意_0')` +/// - `Signal('60分钟_D1W5_窗口能量V230801_先缩后放_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +/// - `w`:观察窗口大小,默认 `5`。 +/// 对齐说明:最大/最小成交量首次出现位置比较逻辑对齐 Python `vol_window_V230801`。 +#[signal( + category = "kline", + name = "vol_window_V230801", + template = "{freq}_D{di}W{w}_窗口能量V230801", + opcode = "VolWindowV230801", + param_kind = "VolWindowV230801" +)] +pub fn vol_window_v230801(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = params.usize("di", 1); + let w = params.usize("w", 5); + let k1 = c.freq.to_string(); + let k2 = format!("D{}W{}", di, w); + let k3 = "窗口能量V230801"; + + if c.bars_raw.len() < di + w { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let vols: Vec = get_sub_elements(&c.bars_raw, di, w) + .iter() + .map(|x| x.vol) + .collect(); + if vols.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let min_i = vols + .iter() + .enumerate() + .min_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(i, _)| i) + .unwrap_or(0); + let max_i = vols + .iter() + .enumerate() + .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(i, _)| i) + .unwrap_or(0); + let v1 = if min_i > max_i { + "先放后缩" + } else { + "先缩后放" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} diff --git a/crates/czsc-signals/src/xl.rs b/crates/czsc-signals/src/xl.rs new file mode 100644 index 000000000..28b539be4 --- /dev/null +++ b/crates/czsc-signals/src/xl.rs @@ -0,0 +1,463 @@ +use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::sig::{get_sub_elements, make_kline_signal_v1, make_kline_signal_v2}; +use crate::utils::ta::update_ma_cache; +use czsc_core::analyze::CZSC; +use czsc_core::objects::bar::RawBar; +use czsc_core::objects::signal::Signal; +use czsc_signal_macros::signal; + +fn mean(values: &[f64]) -> f64 { + if values.is_empty() { + f64::NAN + } else { + values.iter().sum::() / values.len() as f64 + } +} + +fn std_pop(values: &[f64]) -> f64 { + let m = mean(values); + if !m.is_finite() || values.is_empty() { + return f64::NAN; + } + let var = values.iter().map(|x| (x - m).powi(2)).sum::() / values.len() as f64; + var.sqrt() +} + +fn quantile_midpoint(values: &[f64], q: f64) -> f64 { + if values.is_empty() { + return f64::NAN; + } + let mut v = values.to_vec(); + v.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + let n = v.len(); + let pos = q * (n.saturating_sub(1)) as f64; + let lo = pos.floor() as usize; + let hi = pos.ceil() as usize; + if lo == hi { + v[lo] + } else { + (v[lo] + v[hi]) / 2.0 + } +} + +fn check_szx(bar: &RawBar, th: i32) -> bool { + if bar.close == bar.open && bar.high != bar.low { + return true; + } + bar.close != bar.open && (bar.high - bar.low) / (bar.close - bar.open).abs() > th as f64 +} + +fn trend_count(cache1: &[f64], cache2: &[f64]) -> i32 { + let mut num = 0; + if cache1.len() != cache2.len() || cache1.len() < 2 { + return num; + } + for i in 0..cache1.len() - 1 { + let b1 = cache1[i] < cache2[i]; + let b2 = cache1[i + 1] < cache2[i + 1]; + if b2 && b1 != b2 { + num = 1; + } else if b2 && b1 == b2 { + num += 1; + } + + let b3 = cache1[i] > cache2[i]; + let b4 = cache1[i + 1] > cache2[i + 1]; + if b4 && b3 != b4 { + num = 1; + } else if b4 && b3 == b4 { + num += 1; + } + if num >= 10 { + num = 10; + } + } + num +} + +/// xl_bar_position_V240328:相对高低位置识别信号 +/// +/// 参数模板:`"{freq}_N{n}_BS辅助V240328"` +/// +/// 信号逻辑: +/// 1. 计算最近 `3n` 根 `(close-EMA(n))/EMA(n)` 偏离度; +/// 2. 若最新值低于 30% 分位数判 `相对低点`; +/// 3. 若最新值高于 70% 分位数判 `相对高点`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_N10_BS辅助V240328_相对低点_任意_任意_0')` +/// - `Signal('60分钟_N10_BS辅助V240328_相对高点_任意_任意_0')` +/// +/// 参数说明: +/// - `n`:EMA 周期,默认 `10`。 +/// 对齐说明:与 Python `xl_bar_position_V240328` 的分位口径一致(midpoint)。 +#[signal( + category = "kline", + name = "xl_bar_position_V240328", + template = "{freq}_N{n}_BS辅助V240328", + opcode = "XlBarPositionV240328", + param_kind = "XlBarPositionV240328" +)] +pub fn xl_bar_position_v240328(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let n = params.usize("n", 10); + let k1 = c.freq.to_string(); + let k2 = format!("N{}", n); + let k3 = "BS辅助V240328"; + if c.bars_raw.len() < n + 1 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, 1, n + 2 * n); + let ema_key = format!("EMA#{}", n); + update_ma_cache(c, &ema_key, "EMA", n, cache); + let ema = match cache.series.get(&ema_key) { + Some(v) => v, + None => return make_kline_signal_v1(&k1, &k2, k3, "其他"), + }; + if bars.is_empty() { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let start = c.bars_raw.len() - bars.len(); + let mut nor = Vec::with_capacity(bars.len()); + for (i, b) in bars.iter().enumerate() { + let e = ema[start + i]; + nor.push((b.close - e) / e); + } + let q30 = quantile_midpoint(&nor, 0.3); + let q70 = quantile_midpoint(&nor, 0.7); + let last = *nor.last().unwrap_or(&f64::NAN); + let v1 = if last < q30 { + "相对低点" + } else if last > q70 { + "相对高点" + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// xl_bar_trend_V240329:十字孕线反转信号 +/// +/// 参数模板:`"{freq}_N{n}M{m}_十字线反转V240329"` +/// +/// 信号逻辑: +/// 1. 判断最新K线是否十字线(`check_szx`); +/// 2. 前一根为长阴线且满足阈值判 `底部十字孕线`; +/// 3. 前一根为长阳线且满足阈值判 `顶部十字孕线`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_N10M5_十字线反转V240329_底部十字孕线_其他_任意_0')` +/// - `Signal('60分钟_N10M5_十字线反转V240329_顶部十字孕线_其他_任意_0')` +/// +/// 参数说明: +/// - `n`:十字线阈值参数,默认 `10`; +/// - `m`:实体比例阈值,默认 `5`。 +/// 对齐说明:与 Python `xl_bar_trend_V240329` 的 `check_szx` 口径一致。 +#[signal( + category = "kline", + name = "xl_bar_trend_V240329", + template = "{freq}_N{n}M{m}_十字线反转V240329", + opcode = "XlBarTrendV240329", + param_kind = "XlBarTrendV240329" +)] +pub fn xl_bar_trend_v240329(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let n = params.usize("n", 10) as i32; + let m = params.usize("m", 5); + let k1 = c.freq.to_string(); + let k2 = format!("N{}M{}", n, m); + let k3 = "十字线反转V240329"; + if c.bars_raw.len() < n as usize + 1 { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + } + let bars = get_sub_elements(&c.bars_raw, 1, 2); + if bars.len() < 2 { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + } + let bar1 = &bars[0]; + let bar2 = &bars[1]; + let mut v1 = "其他"; + if check_szx(bar2, n) + && bar1.close < bar1.open + && (bar1.open - bar1.close) / (bar1.high - bar1.low) * 10.0 >= m as f64 + { + v1 = "底部十字孕线"; + } + if check_szx(bar2, n) + && bar1.close > bar1.open + && (bar1.close - bar1.open) / (bar1.high - bar1.low) * 10.0 >= m as f64 + { + v1 = "顶部十字孕线"; + } + make_kline_signal_v2(&k1, &k2, k3, v1, "其他") +} + +/// xl_bar_trend_V240330:双均线过滤信号 +/// +/// 参数模板:`"{freq}_N{n}M{m}#{ma_type}_双均线过滤V240330"` +/// +/// 信号逻辑: +/// 1. 计算 `MA(n)` 与 `MA(m)`; +/// 2. 根据两线相对位置输出 `看多/看空`; +/// 3. 统计连续状态次数并输出 `第xx次`(最大 10)。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_N5M21#SMA_双均线过滤V240330_看多_第03次_任意_0')` +/// - `Signal('60分钟_N5M21#SMA_双均线过滤V240330_看空_第06次_任意_0')` +/// +/// 参数说明: +/// - `n/m`:短长均线周期,默认 `5/21`; +/// - `ma_type`:均线类型,默认 `SMA`。 +/// 对齐说明:与 Python `xl_bar_trend_V240330` 的次数计数逻辑一致。 +#[signal( + category = "kline", + name = "xl_bar_trend_V240330", + template = "{freq}_N{n}M{m}#{ma_type}_双均线过滤V240330", + opcode = "XlBarTrendV240330", + param_kind = "XlBarTrendV240330" +)] +pub fn xl_bar_trend_v240330(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let n = params.usize("n", 5); + let m = params.usize("m", 21); + let ma_type = params.str("ma_type", "SMA").to_uppercase(); + let k1 = c.freq.to_string(); + let k2 = format!("N{}M{}#{}", n, m, ma_type); + let k3 = "双均线过滤V240330"; + if c.bars_raw.len() < m + 1 { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + } + let key1 = format!("{}#{}", ma_type, n); + let key2 = format!("{}#{}", ma_type, m); + update_ma_cache(c, &key1, &ma_type, n, cache); + update_ma_cache(c, &key2, &ma_type, m, cache); + let bars = get_sub_elements(&c.bars_raw, 1, m + 1); + let start = c.bars_raw.len() - bars.len(); + let ma1 = match cache.series.get(&key1) { + Some(v) => v, + None => return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"), + }; + let ma2 = match cache.series.get(&key2) { + Some(v) => v, + None => return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"), + }; + let mut cache1 = Vec::with_capacity(bars.len()); + let mut cache2 = Vec::with_capacity(bars.len()); + for i in 0..bars.len() { + cache1.push(ma1[start + i]); + cache2.push(ma2[start + i]); + } + let num = trend_count(&cache1, &cache2).min(10); + let v2 = format!("第{:02}次", num); + let v1 = if cache1[cache1.len() - 1] > cache2[cache2.len() - 1] { + "看多" + } else if cache1[cache1.len() - 1] < cache2[cache2.len() - 1] { + "看空" + } else { + "其他" + }; + make_kline_signal_v2(&k1, &k2, k3, v1, &v2) +} + +/// xl_bar_trend_V240331:通道突破信号 +/// +/// 参数模板:`"{freq}_N{n}_突破信号V240331"` +/// +/// 信号逻辑: +/// 1. 若最新高点突破前 `n` 根最高价,判 `做多`; +/// 2. 若最新低点跌破前 `n` 根最低价,判 `做空`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_N20_突破信号V240331_做多_任意_任意_0')` +/// - `Signal('60分钟_N20_突破信号V240331_做空_任意_任意_0')` +/// +/// 参数说明: +/// - `n`:通道窗口,默认 `20`。 +/// 对齐说明:与 Python `xl_bar_trend_V240331` 的突破判定一致。 +#[signal( + category = "kline", + name = "xl_bar_trend_V240331", + template = "{freq}_N{n}_突破信号V240331", + opcode = "XlBarTrendV240331", + param_kind = "XlBarTrendV240331" +)] +pub fn xl_bar_trend_v240331(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let n = params.usize("n", 20); + let k1 = c.freq.to_string(); + let k2 = format!("N{}", n); + let k3 = "突破信号V240331"; + if c.bars_raw.len() < n + 1 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, 1, n + 1); + let hh = bars[..bars.len() - 1] + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let ll = bars[..bars.len() - 1] + .iter() + .map(|x| x.low) + .fold(f64::INFINITY, f64::min); + let last = &bars[bars.len() - 1]; + let v1 = if last.high >= hh { + "做多" + } else if last.low <= ll { + "做空" + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// xl_bar_basis_V240412:长蜡烛形态信号 +/// +/// 参数模板:`"{freq}_N{n}#TH{th}_形态V240412"` +/// +/// 信号逻辑: +/// 1. 统计前 `n` 根实体长度均值与标准差; +/// 2. 当前实体超过 `mean + th*std` 时识别长蜡烛; +/// 3. 按实体方向输出 `看涨长蜡烛/看跌长蜡烛`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_N10#TH3_形态V240412_看涨长蜡烛_任意_任意_0')` +/// - `Signal('60分钟_N10#TH3_形态V240412_看跌长蜡烛_任意_任意_0')` +/// +/// 参数说明: +/// - `n`:统计窗口,默认 `10`; +/// - `th`:标准差倍数,默认 `3`。 +/// 对齐说明:与 Python `xl_bar_basis_V240412` 的阈值公式一致。 +#[signal( + category = "kline", + name = "xl_bar_basis_V240412", + template = "{freq}_N{n}#TH{th}_形态V240412", + opcode = "XlBarBasisV240412", + param_kind = "XlBarBasisV240412" +)] +pub fn xl_bar_basis_v240412(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let n = params.usize("n", 10); + let th = params.usize("th", 3); + let k1 = c.freq.to_string(); + let k2 = format!("N{}#TH{}", n, th); + let k3 = "形态V240412"; + if c.bars_raw.len() < n + 1 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, 1, n + 1); + let lens: Vec = bars[..bars.len() - 1] + .iter() + .map(|x| (x.close - x.open).abs()) + .collect(); + let bar_solid_th = mean(&lens) + th as f64 * std_pop(&lens); + let bar_solid = bars[bars.len() - 1].close - bars[bars.len() - 1].open; + let v1 = if bar_solid > 0.0 && bar_solid > bar_solid_th { + "看涨长蜡烛" + } else if bar_solid < 0.0 && bar_solid.abs() > bar_solid_th { + "看跌长蜡烛" + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// xl_bar_basis_V240411:吞没形态信号 +/// +/// 参数模板:`"{freq}_N{n}_形态V240411"` +/// +/// 信号逻辑: +/// 1. 最近两根K线构成看涨吞没时输出 `看涨吞没`; +/// 2. 构成看跌吞没时输出 `看跌吞没`; +/// 3. 否则输出 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_N2_形态V240411_看涨吞没_任意_任意_0')` +/// - `Signal('60分钟_N2_形态V240411_看跌吞没_任意_任意_0')` +/// +/// 参数说明: +/// - `n`:最小窗口参数,默认 `2`。 +/// 对齐说明:与 Python `xl_bar_basis_V240411` 的吞没条件一致。 +#[signal( + category = "kline", + name = "xl_bar_basis_V240411", + template = "{freq}_N{n}_形态V240411", + opcode = "XlBarBasisV240411", + param_kind = "XlBarBasisV240411" +)] +pub fn xl_bar_basis_v240411(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let n = params.usize("n", 2); + let k1 = c.freq.to_string(); + let k2 = format!("N{}", n); + let k3 = "形态V240411"; + if c.bars_raw.len() < n + 1 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bars = get_sub_elements(&c.bars_raw, 1, 2); + if bars.len() < 2 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let bar1 = &bars[0]; + let bar2 = &bars[1]; + let v1 = if (bar1.open > bar1.close) && (bar2.close > bar1.high) && (bar2.open <= bar1.low) { + "看涨吞没" + } else if (bar1.open < bar1.close) && (bar2.open >= bar1.high) && (bar2.close < bar1.low) { + "看跌吞没" + } else { + "其他" + }; + make_kline_signal_v1(&k1, &k2, k3, v1) +} + +/// xl_bar_trend_V240623:通道突破连续信号 +/// +/// 参数模板:`"{freq}_N{n}通道_突破信号V240623"` +/// +/// 信号逻辑: +/// 1. 使用倒数第二根K线判断是否突破前 `n` 通道; +/// 2. 突破上轨给 `做多`,且最新再创新高给 `连续2次上涨`; +/// 3. 跌破下轨给 `做空`,且最新再创新低给 `连续2次下跌`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_N20通道_突破信号V240623_做多_连续2次上涨_任意_0')` +/// - `Signal('60分钟_N20通道_突破信号V240623_做空_连续2次下跌_任意_0')` +/// +/// 参数说明: +/// - `n`:通道窗口,默认 `20`。 +/// 对齐说明:与 Python `xl_bar_trend_V240623` 的“倒二突破 + 最新确认”口径一致。 +#[signal( + category = "kline", + name = "xl_bar_trend_V240623", + template = "{freq}_N{n}通道_突破信号V240623", + opcode = "XlBarTrendV240623", + param_kind = "XlBarTrendV240623" +)] +pub fn xl_bar_trend_v240623(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let n = params.usize("n", 20); + let k1 = c.freq.to_string(); + let k2 = format!("N{}通道", n); + let k3 = "突破信号V240623"; + let bars = get_sub_elements(&c.bars_raw, 1, n + 1); + if bars.len() < n + 1 { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "任意"); + } + let hh = bars[..bars.len() - 2] + .iter() + .map(|x| x.high) + .fold(f64::NEG_INFINITY, f64::max); + let ll = bars[..bars.len() - 2] + .iter() + .map(|x| x.low) + .fold(f64::INFINITY, f64::min); + let prev = &bars[bars.len() - 2]; + let last = &bars[bars.len() - 1]; + let mut v1 = "其他"; + let mut v2 = "任意"; + if prev.high >= hh { + v1 = "做多"; + if last.high > prev.high { + v2 = "连续2次上涨"; + } + } else if prev.low <= ll { + v1 = "做空"; + if last.low < prev.low { + v2 = "连续2次下跌"; + } + } + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} diff --git a/crates/czsc-signals/src/zdy.rs b/crates/czsc-signals/src/zdy.rs new file mode 100644 index 000000000..b40dc422d --- /dev/null +++ b/crates/czsc-signals/src/zdy.rs @@ -0,0 +1,948 @@ +use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::math::{median_abs, percentile_linear, std_pop}; +use crate::utils::sig::{ + get_sub_elements, get_usize_param, make_kline_signal_v1, make_kline_signal_v2, + make_kline_signal_v3, +}; +use crate::utils::ta::{MacdField, macd_snapshot_field_value, update_macd_cache}; +use crate::utils::zdy::{find_peaks_valleys, is_valid_zs, macd_cache_maps}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::bar::RawBar; +use czsc_core::objects::direction::Direction; +use czsc_core::objects::mark::Mark; +use czsc_core::objects::signal::Signal; +use czsc_core::objects::zs::ZS; +use czsc_signal_macros::signal; +use std::collections::HashMap; +#[allow(clippy::too_many_arguments)] +fn snapshot_macd_values_from_raw_bars( + c: &CZSC, + mc: &crate::types::MacdSeries, + id_to_idx: &HashMap, + raw_bars: &[RawBar], + short: usize, + long: usize, + signal: usize, + field: MacdField, + snapshot_overrides: &mut HashMap, +) -> Vec { + raw_bars + .iter() + .filter_map(|rb| { + macd_snapshot_field_value( + c, + mc, + id_to_idx, + rb, + short, + long, + signal, + field, + snapshot_overrides, + ) + }) + .filter(|x| x.is_finite()) + .collect() +} + +/// zdy_bi_end_V230406:停顿分型辅助判断笔结束 +/// +/// 参数模板:`"{freq}_D0停顿分型_BE辅助V230406"` +/// +/// 信号逻辑: +/// 1. 要求最后一笔已完成、UBI 处于 4 到 6 根之间,且最后一笔长度不少于 7; +/// 2. 向下笔后若 UBI 后段收盘价上破底分型高点判定 `看多`,向上笔后若 UBI 后段收盘价下破顶分型低点判定 `看空`; +/// 3. 再补充内部停顿与是否仍处于分型区间,输出 `v2/v3`,否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D0停顿分型_BE辅助V230406_看多_内部底停顿_底分区间_0')` +/// - `Signal('60分钟_D0停顿分型_BE辅助V230406_看空_内部顶停顿_顶分区间_0')` +/// +/// 参数说明: +/// - 本信号无额外参数,`params` 可为空; +/// - 仅在最后一笔结束后、未完成笔较短时评估停顿分型。 +/// 对齐说明:与 Python `czsc.signals.zdy_bi_end_V230406` 保持一致。 +#[signal(category = "kline", name = "zdy_bi_end_V230406", template = "{freq}_D0停顿分型_BE辅助V230406", opcode = "ZdyBiEndV230406", param_kind = "ZdyBiEndV230406")] +pub fn zdy_bi_end_v230406(c: &CZSC, _params: &ParamView, _cache: &mut TaCache) -> Vec { + let k1 = c.freq.to_string(); + let k2 = "D0停顿分型"; + let k3 = "BE辅助V230406"; + if c.bi_list.len() < 3 || c.bars_ubi.len() > 6 || c.bars_ubi.len() < 4 { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + let last_bi = c.bi_list.last().unwrap(); + let last_fx_raw = last_bi + .fx_b + .elements + .last() + .map(|x| x.elements.clone()) + .unwrap_or_default(); + if last_fx_raw.is_empty() { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + let last_high = last_fx_raw.iter().map(|x| x.high).fold(f64::NEG_INFINITY, f64::max); + let last_low = last_fx_raw.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let last_bar = c.bars_raw.last().unwrap(); + if last_bi.fx_b.elements.last().unwrap().dt >= last_bar.dt || last_bi.get_length() < 7 { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + let last_bars: Vec = c.bars_ubi[3..] + .iter() + .flat_map(|x| x.elements.iter().cloned()) + .collect::>(); + let max_close = last_bars.iter().map(|x| x.close).fold(f64::NEG_INFINITY, f64::max); + let min_close = last_bars.iter().map(|x| x.close).fold(f64::INFINITY, f64::min); + let v1 = if last_bi.direction == Direction::Down && max_close > last_high { + "看多" + } else if last_bi.direction == Direction::Up && min_close < last_low { + "看空" + } else { + "其他" + }; + if v1 == "其他" { + return make_kline_signal_v1(&k1, k2, k3, v1); + } + let mut v2 = "任意"; + if v1 == "看多" && last_bi.fxs.len() >= 4 { + for i in 0..last_bi.fxs.len() - 1 { + let fx1 = &last_bi.fxs[i]; + let fx2 = &last_bi.fxs[i + 1]; + let fx2_raw: Vec = fx2 + .elements + .iter() + .flat_map(|x| x.elements.iter().cloned()) + .collect::>(); + if fx1.mark == Mark::D && fx2.mark == Mark::G && fx2_raw.iter().map(|x| x.close).fold(f64::NEG_INFINITY, f64::max) > fx1.elements.last().unwrap().high { + v2 = "内部底停顿"; + } + } + } + if v1 == "看空" && last_bi.fxs.len() >= 4 { + for i in 0..last_bi.fxs.len() - 1 { + let fx1 = &last_bi.fxs[i]; + let fx2 = &last_bi.fxs[i + 1]; + let fx2_raw: Vec = fx2 + .elements + .iter() + .flat_map(|x| x.elements.iter().cloned()) + .collect::>(); + if fx1.mark == Mark::G && fx2.mark == Mark::D && fx2_raw.iter().map(|x| x.close).fold(f64::INFINITY, f64::min) < fx1.elements.last().unwrap().low { + v2 = "内部顶停顿"; + } + } + } + let mut v3 = "任意"; + if v1 == "看多" && last_bar.close < last_bi.fx_b.high { + v3 = "底分区间"; + } + if v1 == "看空" && last_bar.close > last_bi.fx_b.low { + v3 = "顶分区间"; + } + make_kline_signal_v3(&k1, k2, k3, v1, v2, v3) +} + +/// zdy_bi_end_V230407:连续停顿分型辅助判断笔结束 +/// +/// 参数模板:`"{freq}_D0停顿分型_BE辅助V230407"` +/// +/// 信号逻辑: +/// 1. 与 `V230406` 共用边界条件,但要求突破发生在最后分型之后的连续收盘序列上; +/// 2. 向下笔后若连续收盘上破底分型高点判定 `看多`,向上笔后若连续收盘下破顶分型低点判定 `看空`; +/// 3. 再检查笔内相邻分型是否形成内部停顿,补充 `v2`,否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D0停顿分型_BE辅助V230407_看多_内部底停顿_任意_0')` +/// - `Signal('60分钟_D0停顿分型_BE辅助V230407_看空_内部顶停顿_任意_0')` +/// +/// 参数说明: +/// - 本信号无额外参数,`params` 可为空; +/// - 连续突破要求突破 K 线在时间上连续,不接受中途回落后再次突破。 +/// 对齐说明:与 Python `czsc.signals.zdy_bi_end_V230407` 保持一致。 +#[signal(category = "kline", name = "zdy_bi_end_V230407", template = "{freq}_D0停顿分型_BE辅助V230407", opcode = "ZdyBiEndV230407", param_kind = "ZdyBiEndV230407")] +pub fn zdy_bi_end_v230407(c: &CZSC, _params: &ParamView, _cache: &mut TaCache) -> Vec { + let k1 = c.freq.to_string(); + let k2 = "D0停顿分型"; + let k3 = "BE辅助V230407"; + if c.bi_list.len() < 3 || c.bars_ubi.len() > 6 || c.bars_ubi.len() < 4 { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + let last_bi = c.bi_list.last().unwrap(); + let last_fx_raw = last_bi + .fx_b + .elements + .last() + .map(|x| x.elements.clone()) + .unwrap_or_default(); + if last_fx_raw.is_empty() { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + let last_high = last_fx_raw.iter().map(|x| x.high).fold(f64::NEG_INFINITY, f64::max); + let last_low = last_fx_raw.iter().map(|x| x.low).fold(f64::INFINITY, f64::min); + let last_bar = c.bars_raw.last().unwrap(); + if last_bi.fx_b.elements.last().unwrap().dt >= last_bar.dt || last_bi.get_length() < 7 { + return make_kline_signal_v1(&k1, k2, k3, "其他"); + } + let last_bars: Vec = c.bars_ubi + .iter() + .flat_map(|x| x.elements.iter().cloned()) + .filter(|x| x.dt >= last_bi.fx_b.elements.last().unwrap().dt) + .collect::>(); + let mut v1 = "其他"; + if last_bi.direction == Direction::Down && last_bars.last().unwrap().close > last_high { + let idx: Vec = last_bars.iter().enumerate().filter(|(_, x)| x.close > last_high).map(|(i, _)| i).collect(); + if idx.len() == 1 || (idx.len() > 1 && idx[idx.len() - 1] - idx[0] == idx.len() - 1) { + v1 = "看多"; + } + } else if last_bi.direction == Direction::Up && last_bars.last().unwrap().close < last_low { + let idx: Vec = last_bars.iter().enumerate().filter(|(_, x)| x.close < last_low).map(|(i, _)| i).collect(); + if idx.len() == 1 || (idx.len() > 1 && idx[idx.len() - 1] - idx[0] == idx.len() - 1) { + v1 = "看空"; + } + } + if v1 == "其他" { + return make_kline_signal_v1(&k1, k2, k3, v1); + } + let mut v2 = "任意"; + if v1 == "看多" && last_bi.fxs.len() >= 4 { + for i in 0..last_bi.fxs.len() - 1 { + let fx1 = &last_bi.fxs[i]; + let fx2 = &last_bi.fxs[i + 1]; + let fx2_raw: Vec = fx2 + .elements + .iter() + .flat_map(|x| x.elements.iter().cloned()) + .collect::>(); + if fx1.mark == Mark::D && fx2.mark == Mark::G && fx2_raw.iter().map(|x| x.close).fold(f64::NEG_INFINITY, f64::max) > fx1.elements.last().unwrap().high { + v2 = "内部底停顿"; + } + } + } + if v1 == "看空" && last_bi.fxs.len() >= 4 { + for i in 0..last_bi.fxs.len() - 1 { + let fx1 = &last_bi.fxs[i]; + let fx2 = &last_bi.fxs[i + 1]; + let fx2_raw: Vec = fx2 + .elements + .iter() + .flat_map(|x| x.elements.iter().cloned()) + .collect::>(); + if fx1.mark == Mark::G && fx2.mark == Mark::D && fx2_raw.iter().map(|x| x.close).fold(f64::INFINITY, f64::min) < fx1.elements.last().unwrap().low { + v2 = "内部顶停顿"; + } + } + } + make_kline_signal_v2(&k1, k2, k3, v1, v2) +} + +/// zdy_zs_V230423:中枢形态辅助识别上涨下跌结构 +/// +/// 参数模板:`"{freq}_D{di}中枢形态_BS辅助V230423"` +/// +/// 信号逻辑: +/// 1. 依次尝试最近 `9/7/5` 笔,取首尾笔之外的中间笔构造中枢; +/// 2. 要求中枢有效,且中枢高度至少达到首笔波动的三分之一; +/// 3. 若首笔与末笔分别对应区间最低点和最高点,则判定 `上涨/下跌` 并输出笔数。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1中枢形态_BS辅助V230423_上涨_5笔_任意_0')` +/// - `Signal('60分钟_D1中枢形态_BS辅助V230423_下跌_7笔_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始取样,默认 `1`; +/// - 仅在未完成笔不超过 7 根时评估,避免把延伸中的 UBI 当成已确认结构。 +/// 对齐说明:与 Python `czsc.signals.zdy_zs_V230423` 保持一致。 +#[signal(category = "kline", name = "zdy_zs_V230423", template = "{freq}_D{di}中枢形态_BS辅助V230423", opcode = "ZdyZsV230423", param_kind = "ZdyZsV230423")] +pub fn zdy_zs_v230423(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}中枢形态", di); + let k3 = "BS辅助V230423"; + if c.bi_list.len() < 7 || c.bars_ubi.len() > 7 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + for n in [9, 7, 5] { + let bis = get_sub_elements(&c.bi_list, di, n); + if bis.len() != n { + continue; + } + let bi1 = &bis[0]; + let zs = ZS::new(bis[1..n - 1].to_vec()); + if !(zs.is_valid() && zs.zg - zs.zd > (bi1.get_high() - bi1.get_low()) / 3.0) { + continue; + } + let min_low = bis.iter().map(|x| x.get_low()).fold(f64::INFINITY, f64::min); + let max_high = bis.iter().map(|x| x.get_high()).fold(f64::NEG_INFINITY, f64::max); + if bi1.direction == Direction::Up && bi1.get_low() == min_low && bis.last().unwrap().get_high() == max_high { + return make_kline_signal_v2(&k1, &k2, k3, "上涨", &format!("{}笔", n)); + } + if bi1.direction == Direction::Down && bi1.get_high() == max_high && bis.last().unwrap().get_low() == min_low { + return make_kline_signal_v2(&k1, &k2, k3, "下跌", &format!("{}笔", n)); + } + } + make_kline_signal_v1(&k1, &k2, k3, "其他") +} + +/// zdy_zs_space_V230421:中枢空间辅助识别上涨下跌结构 +/// +/// 参数模板:`"{freq}_D{di}中枢空间_BS辅助V230421"` +/// +/// 信号逻辑: +/// 1. 依次尝试最近 `9/7/5` 笔,取中间笔构造有效中枢; +/// 2. 对上涨结构,要求末笔离开中枢上沿的空间不小于首笔进入中枢前的空间;下跌结构反向判断; +/// 3. 满足空间对称性后输出 `上涨/下跌 + 笔数`,否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1中枢空间_BS辅助V230421_上涨_5笔_任意_0')` +/// - `Signal('60分钟_D1中枢空间_BS辅助V230421_下跌_9笔_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始取样,默认 `1`; +/// - 仅对有效中枢做空间比较,无中枢时直接返回 `其他`。 +/// 对齐说明:与 Python `czsc.signals.zdy_zs_space_V230421` 保持一致。 +#[signal(category = "kline", name = "zdy_zs_space_V230421", template = "{freq}_D{di}中枢空间_BS辅助V230421", opcode = "ZdyZsSpaceV230421", param_kind = "ZdyZsSpaceV230421")] +pub fn zdy_zs_space_v230421(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}中枢空间", di); + let k3 = "BS辅助V230421"; + if c.bi_list.len() < 7 || c.bars_ubi.len() > 7 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + for n in [9, 7, 5] { + let bis = get_sub_elements(&c.bi_list, di, n); + if bis.len() != n { + continue; + } + let zs = ZS::new(bis[1..n - 1].to_vec()); + if !zs.is_valid() { + continue; + } + let bi1 = &bis[0]; + let bi2 = &bis[n - 1]; + let min_low = bis.iter().map(|x| x.get_low()).fold(f64::INFINITY, f64::min); + let max_high = bis.iter().map(|x| x.get_high()).fold(f64::NEG_INFINITY, f64::max); + if bi1.direction == Direction::Up && bi1.get_low() == min_low && bi2.get_high() == max_high && bi2.get_high() - zs.zg >= zs.zd - bi1.get_low() { + return make_kline_signal_v2(&k1, &k2, k3, "上涨", &format!("{}笔", n)); + } + if bi1.direction == Direction::Down && bi1.get_high() == max_high && bi2.get_low() == min_low && zs.zd - bi2.get_low() >= bi1.get_high() - zs.zg { + return make_kline_signal_v2(&k1, &k2, k3, "下跌", &format!("{}笔", n)); + } + } + make_kline_signal_v1(&k1, &k2, k3, "其他") +} + +/// zdy_macd_bc_V230422:MACD 面积背驰辅助信号 +/// +/// 参数模板:`"{freq}_D{di}T{th}MACD面积背驰_BS辅助V230422"` +/// +/// 信号逻辑: +/// 1. 依次尝试最近 `9/7/5` 笔,要求中间结构可构成有效中枢; +/// 2. 比较首笔与末笔内部 MACD 柱面积,并用 `th` 控制末笔面积相对首笔的阈值; +/// 3. 结合 DIF 零轴位置与首末笔高低点,判定 `上涨/下跌` 背驰并输出笔数。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1T50MACD面积背驰_BS辅助V230422_上涨_5笔_任意_0')` +/// - `Signal('60分钟_D1T50MACD面积背驰_BS辅助V230422_下跌_7笔_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始取样,默认 `1`; +/// - `th`:末笔面积相对首笔面积的百分比阈值,默认 `50`。 +/// 对齐说明:与 Python `czsc.signals.zdy_macd_bc_V230422` 保持一致。 +#[signal(category = "kline", name = "zdy_macd_bc_V230422", template = "{freq}_D{di}T{th}MACD面积背驰_BS辅助V230422", opcode = "ZdyMacdBcV230422", param_kind = "ZdyMacdBcV230422")] +pub fn zdy_macd_bc_v230422(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let th = get_usize_param(params, "th", 50) as f64; + let k1 = c.freq.to_string(); + let k2 = format!("D{}T{}MACD面积背驰", di, th as i32); + let k3 = "BS辅助V230422"; + if c.bi_list.len() < 7 || c.bars_ubi.len() > 7 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let cache_key = "MACD12#26#9"; + update_macd_cache(c, cache_key, 12, 26, 9, cache); + let mc = cache.macd.get(cache_key).unwrap(); + let id_to_idx: HashMap = c + .bars_raw + .iter() + .enumerate() + .map(|(i, b)| (b.id, i)) + .collect(); + let mut snapshot_overrides = HashMap::new(); + for n in [9, 7, 5] { + let bis = get_sub_elements(&c.bi_list, di, n); + if bis.len() != n || !is_valid_zs(&bis[1..n - 1]) { + continue; + } + let zs = ZS::new(bis[1..n - 1].to_vec()); + let bi1 = &bis[0]; + let bi2 = &bis[n - 1]; + let bi1_raw = bi1.get_raw_bars(); + let bi2_raw = bi2.get_raw_bars(); + let bi1_macd = snapshot_macd_values_from_raw_bars( + c, + mc, + &id_to_idx, + &bi1_raw[1..bi1_raw.len().saturating_sub(1)], + 12, + 26, + 9, + MacdField::Macd, + &mut snapshot_overrides, + ); + let bi2_macd = snapshot_macd_values_from_raw_bars( + c, + mc, + &id_to_idx, + &bi2_raw[1..bi2_raw.len().saturating_sub(1)], + 12, + 26, + 9, + MacdField::Macd, + &mut snapshot_overrides, + ); + if bi1_macd.is_empty() || bi2_macd.is_empty() { + continue; + } + let bi1_dif = macd_snapshot_field_value( + c, + mc, + &id_to_idx, + &bi1_raw[bi1_raw.len() - 2], + 12, + 26, + 9, + MacdField::Dif, + &mut snapshot_overrides, + ) + .unwrap_or(0.0); + let bi2_dif = macd_snapshot_field_value( + c, + mc, + &id_to_idx, + &bi2_raw[bi2_raw.len() - 2], + 12, + 26, + 9, + MacdField::Dif, + &mut snapshot_overrides, + ) + .unwrap_or(0.0); + let zs_fxb_raw: Vec = zs.bis.iter().flat_map(|x| x.fx_b.elements.iter().flat_map(|nb| nb.elements.iter().cloned())).collect(); + let (bi1_area, bi2_area, dif_zero) = if bi1.direction == Direction::Up { + ( + bi1_macd.iter().copied().filter(|x| *x > 0.0).sum::(), + bi2_macd.iter().copied().filter(|x| *x > 0.0).sum::(), + snapshot_macd_values_from_raw_bars( + c, + mc, + &id_to_idx, + &zs_fxb_raw, + 12, + 26, + 9, + MacdField::Dif, + &mut snapshot_overrides, + ) + .into_iter() + .fold(f64::INFINITY, f64::min), + ) + } else { + ( + bi1_macd.iter().copied().filter(|x| *x < 0.0).sum::(), + bi2_macd.iter().copied().filter(|x| *x < 0.0).sum::(), + snapshot_macd_values_from_raw_bars( + c, + mc, + &id_to_idx, + &zs_fxb_raw, + 12, + 26, + 9, + MacdField::Dif, + &mut snapshot_overrides, + ) + .into_iter() + .fold(f64::NEG_INFINITY, f64::max), + ) + }; + if bi2_area > bi1_area * th / 100.0 { + continue; + } + let min_low = bis.iter().map(|x| x.get_low()).fold(f64::INFINITY, f64::min); + let max_high = bis.iter().map(|x| x.get_high()).fold(f64::NEG_INFINITY, f64::max); + if bi1.direction == Direction::Up && bi1.get_low() == min_low && bi2.get_high() == max_high && dif_zero < 0.0 && bi1_dif > bi2_dif && bi2_dif > 0.0 { + return make_kline_signal_v2(&k1, &k2, k3, "上涨", &format!("{}笔", n)); + } + if bi1.direction == Direction::Down && bi1.get_high() == max_high && bi2.get_low() == min_low && dif_zero > 0.0 && bi1_dif < bi2_dif && bi2_dif < 0.0 { + return make_kline_signal_v2(&k1, &k2, k3, "下跌", &format!("{}笔", n)); + } + } + make_kline_signal_v1(&k1, &k2, k3, "其他") +} + +/// zdy_macd_bs1_V230422:MACD 一买一卖辅助信号 +/// +/// 参数模板:`"{freq}_D{di}T{th}MACD_BS1辅助V230422"` +/// +/// 信号逻辑: +/// 1. 依次尝试最近 `13/11/9/7/5` 笔,要求中间结构构成有效中枢; +/// 2. 比较首笔与末笔的 MACD 柱面积、末笔起点 DIF 与中枢 DIF 极值; +/// 3. 满足首末笔方向、极值位置和 DIF 强弱关系时输出 `看多/看空 + 上涨/下跌N笔`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1T50MACD_BS1辅助V230422_看空_上涨5笔_任意_0')` +/// - `Signal('60分钟_D1T50MACD_BS1辅助V230422_看多_下跌7笔_任意_0')` +/// +/// 参数说明: +/// - `di`:从倒数第 `di` 笔开始取样,默认 `1`; +/// - `th`:末笔 MACD 面积占首笔面积的最大百分比,默认 `50`。 +/// 对齐说明:与 Python `czsc.signals.zdy_macd_bs1_V230422` 保持一致。 +#[signal(category = "kline", name = "zdy_macd_bs1_V230422", template = "{freq}_D{di}T{th}MACD_BS1辅助V230422", opcode = "ZdyMacdBs1V230422", param_kind = "ZdyMacdBs1V230422")] +pub fn zdy_macd_bs1_v230422(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let th = get_usize_param(params, "th", 50) as f64; + let k1 = c.freq.to_string(); + let k2 = format!("D{}T{}MACD", di, th as i32); + let k3 = "BS1辅助V230422"; + if c.bi_list.len() < 7 || c.bars_ubi.len() > 9 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let (dif_map, _, macd_map) = macd_cache_maps(c, 26, 12, 9, cache); + for n in [13, 11, 9, 7, 5] { + let bis = get_sub_elements(&c.bi_list, di, n); + if bis.len() != n || !is_valid_zs(&bis[1..n - 1]) { + continue; + } + let zs = ZS::new(bis[1..n - 1].to_vec()); + let bi1 = &bis[0]; + let bi2 = &bis[n - 1]; + let bi1_raw = bi1.get_raw_bars(); + let bi2_raw = bi2.get_raw_bars(); + if bi1_raw.len() < 3 || bi2_raw.len() < 3 { + continue; + } + let bi1_area = bi1_raw[1..bi1_raw.len() - 1].iter().filter_map(|x| macd_map.get(&x.id).copied()).map(f64::abs).sum::(); + let bi2_area = bi2_raw[1..bi2_raw.len() - 1].iter().filter_map(|x| macd_map.get(&x.id).copied()).map(f64::abs).sum::(); + let bi1_dif = *dif_map.get(&bi1_raw[bi1_raw.len() - 2].id).unwrap_or(&0.0); + let bi2_dif = *dif_map.get(&bi2_raw[bi2_raw.len() - 2].id).unwrap_or(&0.0); + let bi2_start_dif = *dif_map.get(&bi2_raw[1].id).unwrap_or(&0.0); + let zs_dif = if bi1.direction == Direction::Up { + zs.bis + .iter() + .filter(|x| x.direction == Direction::Up) + .flat_map(|x| x.fx_b.elements.iter().flat_map(|nb| nb.elements.iter())) + .filter_map(|x| dif_map.get(&x.id).copied()) + .fold(f64::NEG_INFINITY, f64::max) + } else { + zs.bis + .iter() + .filter(|x| x.direction == Direction::Down) + .flat_map(|x| x.fx_b.elements.iter().flat_map(|nb| nb.elements.iter())) + .filter_map(|x| dif_map.get(&x.id).copied()) + .fold(f64::INFINITY, f64::min) + }; + if bi2_area > bi1_area * th / 100.0 { + continue; + } + let min_low = bis.iter().map(|x| x.get_low()).fold(f64::INFINITY, f64::min); + let max_high = bis.iter().map(|x| x.get_high()).fold(f64::NEG_INFINITY, f64::max); + if bi1.direction == Direction::Up + && bi1.get_low() == min_low + && bi2.get_high() == max_high + && bi2_start_dif < zs_dif.abs() * 0.5 + && bi1_dif > bi2_dif + && bi2_dif > zs_dif + && zs_dif > 0.0 + { + return make_kline_signal_v2(&k1, &k2, k3, "看空", &format!("上涨{}笔", n)); + } + if bi1.direction == Direction::Down + && bi1.get_high() == max_high + && bi2.get_low() == min_low + && bi2_start_dif > zs_dif.abs() * 0.5 + && 0.0 > zs_dif + && zs_dif > bi2_dif + && bi2_dif > bi1_dif + { + return make_kline_signal_v2(&k1, &k2, k3, "看多", &format!("下跌{}笔", n)); + } + } + make_kline_signal_v1(&k1, &k2, k3, "其他") +} + +/// zdy_macd_dif_V230516:DIF 走平后的反向观察信号 +/// +/// 参数模板:`"{freq}_D{di}DIF走平_BS辅助V230516"` +/// +/// 信号逻辑: +/// 1. 取最近 10 根 K 线的 DIF 变化,计算相邻差分的平均波动阈值; +/// 2. 若最新 DIF 相对前一根不再继续大幅下行,结合 MACD 绿柱位置判定 `看多`;反向同理判定 `看空`; +/// 3. 用 `绿柱远离/红柱远离/柱子否定` 说明当前柱体是否支持反转观察。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1DIF走平_BS辅助V230516_看多_绿柱远离_任意_0')` +/// - `Signal('60分钟_D1DIF走平_BS辅助V230516_看空_红柱远离_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根 K 线,默认 `1`; +/// - 固定使用 `12,26,9` MACD 参数,并观察最近 10 根 K 线。 +/// 对齐说明:与 Python `czsc.signals.zdy_macd_dif_V230516` 保持一致。 +#[signal(category = "kline", name = "zdy_macd_dif_V230516", template = "{freq}_D{di}DIF走平_BS辅助V230516", opcode = "ZdyMacdDifV230516", param_kind = "ZdyMacdDifV230516")] +pub fn zdy_macd_dif_v230516(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}DIF走平", di); + let k3 = "BS辅助V230516"; + if c.bars_raw.len() < 12 + di { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + } + let (dif_map, _, macd_map) = macd_cache_maps(c, 12, 26, 9, cache); + let bars = get_sub_elements(&c.bars_raw, di, 10); + let dif: Vec = bars.iter().filter_map(|x| dif_map.get(&x.id).copied()).collect(); + if dif.len() < 2 { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + } + let dif_th = dif.windows(2).map(|w| (w[0] - w[1]).abs()).sum::() / dif.len() as f64 * 0.2; + let mut v1 = "其他"; + let mut v2 = "其他"; + if dif[dif.len() - 1] - dif[dif.len() - 2] > -dif_th { + let min_macd = bars.iter().filter_map(|x| macd_map.get(&x.id).copied()).fold(f64::INFINITY, f64::min); + v1 = "看多"; + v2 = if dif[dif.len() - 1] < min_macd * 2.5 { "绿柱远离" } else { "柱子否定" }; + } + if dif[dif.len() - 1] - dif[dif.len() - 2] < dif_th { + let max_macd = bars.iter().filter_map(|x| macd_map.get(&x.id).copied()).fold(f64::NEG_INFINITY, f64::max); + v1 = "看空"; + v2 = if dif[dif.len() - 1] > max_macd * 2.5 { "红柱远离" } else { "柱子否定" }; + } + make_kline_signal_v2(&k1, &k2, k3, v1, v2) +} + +/// zdy_macd_dif_V230517:MACD 开仓辅助信号 +/// +/// 参数模板:`"{freq}_D{di}MACD开仓_BS辅助V230517"` +/// +/// 信号逻辑: +/// 1. 取最近 20 根 K 线的 DIF 与 MACD 柱体; +/// 2. 若最新 DIF 重新站上零轴或 MACD 柱发生金叉/死叉/飞吻形态,则分别判定 `看多/看空`; +/// 3. 未出现明确开仓形态时返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1MACD开仓_BS辅助V230517_看多_MACD金叉_任意_0')` +/// - `Signal('60分钟_D1MACD开仓_BS辅助V230517_看空_DIF破零轴_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根 K 线,默认 `1`; +/// - 固定使用 `12,26,9` MACD 参数,至少要求 50 根原始 K 线预热。 +/// 对齐说明:与 Python `czsc.signals.zdy_macd_dif_V230517` 保持一致。 +#[signal(category = "kline", name = "zdy_macd_dif_V230517", template = "{freq}_D{di}MACD开仓_BS辅助V230517", opcode = "ZdyMacdDifV230517", param_kind = "ZdyMacdDifV230517")] +pub fn zdy_macd_dif_v230517(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}MACD开仓", di); + let k3 = "BS辅助V230517"; + if c.bars_raw.len() < 50 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let (dif_map, _, macd_map) = macd_cache_maps(c, 12, 26, 9, cache); + let bars = get_sub_elements(&c.bars_raw, di, 20); + let macd: Vec = bars.iter().filter_map(|x| macd_map.get(&x.id).copied()).collect(); + let dif: Vec = bars.iter().filter_map(|x| dif_map.get(&x.id).copied()).collect(); + if dif.last().copied().unwrap_or(0.0) > 0.0 { + let mut v2 = None; + if dif[..dif.len() - 1].iter().all(|x| *x < 0.0) { + v2 = Some("DIF破零轴"); + } + if macd[macd.len() - 1] > 0.0 && macd[macd.len() - 2] < 0.0 { + v2 = Some("MACD金叉"); + } + if macd[macd.len() - 5] > macd[macd.len() - 4] && macd[macd.len() - 4] > macd[macd.len() - 3] && macd[macd.len() - 3] > macd[macd.len() - 2] && macd[macd.len() - 2] < macd[macd.len() - 1] && macd[macd.len() - 2] > 0.0 { + v2 = Some("MACD飞吻"); + } + if let Some(v2) = v2 { + return make_kline_signal_v2(&k1, &k2, k3, "看多", v2); + } + } + if dif.last().copied().unwrap_or(0.0) < 0.0 { + let mut v2 = None; + if dif[..dif.len() - 1].iter().all(|x| *x > 0.0) { + v2 = Some("DIF破零轴"); + } + if macd[macd.len() - 1] < 0.0 && macd[macd.len() - 2] > 0.0 { + v2 = Some("MACD死叉"); + } + if macd[macd.len() - 5] < macd[macd.len() - 4] && macd[macd.len() - 4] < macd[macd.len() - 3] && macd[macd.len() - 3] < macd[macd.len() - 2] && macd[macd.len() - 2] > macd[macd.len() - 1] && macd[macd.len() - 2] < 0.0 { + v2 = Some("MACD飞吻"); + } + if let Some(v2) = v2 { + return make_kline_signal_v2(&k1, &k2, k3, "看空", v2); + } + } + make_kline_signal_v1(&k1, &k2, k3, "其他") +} + +/// zdy_macd_V230518:MACD 交叉计数信号 +/// +/// 参数模板:`"{freq}_D{di}MACD交叉N{n}_BS辅助V230518"` +/// +/// 信号逻辑: +/// 1. 取最近 `n + 1` 根 K 线的 MACD 柱值; +/// 2. 根据最新柱体正负判定 `金叉/死叉` 方向; +/// 3. 从最新柱开始逆序统计同号连续根数,输出 `第N次` 或 `超计数范围`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1MACD交叉N9_BS辅助V230518_金叉_第3次_任意_0')` +/// - `Signal('60分钟_D1MACD交叉N9_BS辅助V230518_死叉_超计数范围_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根 K 线,默认 `1`; +/// - `n`:最大统计窗口,默认 `9`。 +/// 对齐说明:与 Python `czsc.signals.zdy_macd_V230518` 保持一致。 +#[signal(category = "kline", name = "zdy_macd_V230518", template = "{freq}_D{di}MACD交叉N{n}_BS辅助V230518", opcode = "ZdyMacdV230518", param_kind = "ZdyMacdV230518")] +pub fn zdy_macd_v230518(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let n = get_usize_param(params, "n", 9); + let k1 = c.freq.to_string(); + let k2 = format!("D{}MACD交叉N{}", di, n); + let k3 = "BS辅助V230518"; + if c.bars_raw.len() < 50 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let (_, _, macd_map) = macd_cache_maps(c, 12, 26, 9, cache); + let bars = get_sub_elements(&c.bars_raw, di, n + 1); + let macd: Vec = bars.iter().filter_map(|x| macd_map.get(&x.id).copied()).collect(); + let v1 = if macd.last().copied().unwrap_or(0.0) > 0.0 { "金叉" } else { "死叉" }; + let mut count = 0usize; + for m in macd.iter().rev() { + if (*m > 0.0 && macd[macd.len() - 1] > 0.0) || (*m < 0.0 && macd[macd.len() - 1] < 0.0) { + count += 1; + } else { + break; + } + } + if count == n + 1 { + return make_kline_signal_v2(&k1, &k2, k3, v1, "超计数范围"); + } + make_kline_signal_v2(&k1, &k2, k3, v1, &format!("第{}次", count)) +} + +/// zdy_macd_V230519:MACD 连续缩柱信号 +/// +/// 参数模板:`"{freq}_D{di}N{n}MACD缩柱_BS辅助V230519"` +/// +/// 信号逻辑: +/// 1. 取最近 `n` 根 K 线的 MACD 柱值; +/// 2. 若全部为正且柱体连续缩短,判定 `多头连续缩柱`; +/// 3. 若全部为负且绝对值连续缩短,判定 `空头连续缩柱`,否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1N3MACD缩柱_BS辅助V230519_多头连续缩柱_任意_任意_0')` +/// - `Signal('60分钟_D1N3MACD缩柱_BS辅助V230519_空头连续缩柱_任意_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根 K 线,默认 `1`; +/// - `n`:连续缩柱的观察窗口,默认 `3`。 +/// 对齐说明:与 Python `czsc.signals.zdy_macd_V230519` 保持一致。 +#[signal(category = "kline", name = "zdy_macd_V230519", template = "{freq}_D{di}N{n}MACD缩柱_BS辅助V230519", opcode = "ZdyMacdV230519", param_kind = "ZdyMacdV230519")] +pub fn zdy_macd_v230519(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let n = get_usize_param(params, "n", 3); + let k1 = c.freq.to_string(); + let k2 = format!("D{}N{}MACD缩柱", di, n); + let k3 = "BS辅助V230519"; + if c.bars_raw.len() < 50 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let (_, _, macd_map) = macd_cache_maps(c, 12, 26, 9, cache); + let bars = get_sub_elements(&c.bars_raw, di, n); + let macd: Vec = bars.iter().filter_map(|x| macd_map.get(&x.id).copied()).collect(); + if macd.iter().all(|x| *x > 0.0) && macd.windows(2).all(|w| w[1] < w[0]) { + return make_kline_signal_v1(&k1, &k2, k3, "多头连续缩柱"); + } + if macd.iter().all(|x| *x < 0.0) && macd.windows(2).all(|w| w[1] > w[0]) { + return make_kline_signal_v1(&k1, &k2, k3, "空头连续缩柱"); + } + make_kline_signal_v1(&k1, &k2, k3, "其他") +} + +/// zdy_macd_dif_iqr_V230521:DIF 走平 IQR 版本辅助信号 +/// +/// 参数模板:`"{freq}_D{di}DIF走平IQR_BS辅助V230521"` +/// +/// 信号逻辑: +/// 1. 取最近 100 根 K 线的 DIF 序列,计算四分位距 `IQR`; +/// 2. 若最近 3 根 DIF 振幅小于 `IQR`,视为 DIF 走平; +/// 3. 再结合最新 MACD 柱正负和 DIF 与柱体距离,输出 `看多/看空 + 远离/否定`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_D1DIF走平IQR_BS辅助V230521_看多_绿柱远离_任意_0')` +/// - `Signal('60分钟_D1DIF走平IQR_BS辅助V230521_看空_红柱远离_任意_0')` +/// +/// 参数说明: +/// - `di`:信号计算截止在倒数第 `di` 根 K 线,默认 `1`; +/// - 固定使用最近 100 根 K 线做 IQR 估计,至少要求 50 根原始 K 线预热。 +/// 对齐说明:与 Python `czsc.signals.zdy_macd_dif_iqr_V230521` 保持一致。 +#[signal(category = "kline", name = "zdy_macd_dif_iqr_V230521", template = "{freq}_D{di}DIF走平IQR_BS辅助V230521", opcode = "ZdyMacdDifIqrV230521", param_kind = "ZdyMacdDifIqrV230521")] +pub fn zdy_macd_dif_iqr_v230521(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let di = get_usize_param(params, "di", 1); + let k1 = c.freq.to_string(); + let k2 = format!("D{}DIF走平IQR", di); + let k3 = "BS辅助V230521"; + if c.bars_raw.len() < 50 { + return make_kline_signal_v2(&k1, &k2, k3, "其他", "其他"); + } + let (dif_map, _, macd_map) = macd_cache_maps(c, 12, 26, 9, cache); + let bars = get_sub_elements(&c.bars_raw, di, 100); + let macd = macd_map.get(&bars.last().unwrap().id).copied().unwrap_or(0.0) * 2.0; + let dif: Vec = bars.iter().filter_map(|x| dif_map.get(&x.id).copied()).collect(); + let q3 = percentile_linear(&dif, 75.0).unwrap_or(0.0); + let q1 = percentile_linear(&dif, 25.0).unwrap_or(0.0); + let iqr = q3 - q1; + if dif[dif.len() - 3..].iter().copied().fold(f64::NEG_INFINITY, f64::max) - dif[dif.len() - 3..].iter().copied().fold(f64::INFINITY, f64::min) < iqr && macd < 0.0 { + return make_kline_signal_v2(&k1, &k2, k3, "看多", if dif[dif.len() - 1] < macd { "绿柱远离" } else { "柱子否定" }); + } + if dif[dif.len() - 3..].iter().copied().fold(f64::NEG_INFINITY, f64::max) - dif[dif.len() - 3..].iter().copied().fold(f64::INFINITY, f64::min) < iqr && macd > 0.0 { + return make_kline_signal_v2(&k1, &k2, k3, "看空", if dif[dif.len() - 1] > macd { "红柱远离" } else { "柱子否定" }); + } + make_kline_signal_v2(&k1, &k2, k3, "其他", "其他") +} + +/// zdy_macd_V230527:MACD 因子远离统计信号 +/// +/// 参数模板:`"{freq}_{key}远离W{w}N{n}T{t}_BS辅助V230527"` +/// +/// 信号逻辑: +/// 1. 在最近 `w` 根 K 线内提取 `DIF/DEA/MACD` 之一作为统计因子; +/// 2. 用绝对值中位数与总体标准差构造远离阈值; +/// 3. 若最近 `n` 根中的最大绝对值超过阈值,则输出 `多头远离/空头远离`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_DIF远离W100N10T20_BS辅助V230527_多头远离_任意_任意_0')` +/// - `Signal('60分钟_MACD远离W200N20T30_BS辅助V230527_空头远离_任意_任意_0')` +/// +/// 参数说明: +/// - `key`:参与统计的 MACD 因子,支持 `DIF/DEA/MACD`,默认 `DIF`; +/// - `w`:统计窗口长度,默认 `100`; +/// - `n`:最近观察窗口长度,默认 `10`; +/// - `t`:标准差放大系数,默认 `20`。 +/// 对齐说明:与 Python `czsc.signals.zdy_macd_V230527` 保持一致。 +#[signal(category = "kline", name = "zdy_macd_V230527", template = "{freq}_{key}远离W{w}N{n}T{t}_BS辅助V230527", opcode = "ZdyMacdV230527", param_kind = "ZdyMacdV230527")] +pub fn zdy_macd_v230527(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let n = get_usize_param(params, "n", 10); + let w = get_usize_param(params, "w", 100); + let t = get_usize_param(params, "t", 20) as f64; + let key = params.str("key", "DIF").to_uppercase(); + let k1 = c.freq.to_string(); + let k2 = format!("{}远离W{}N{}T{}", key, w, n, t as i32); + let k3 = "BS辅助V230527"; + if c.bi_list.len() < 3 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let (dif_map, dea_map, macd_map) = macd_cache_maps(c, 12, 26, 9, cache); + let bars = get_sub_elements(&c.bars_raw, 1, w); + let factors: Vec = bars + .iter() + .filter_map(|x| match key.as_str() { + "DIF" => dif_map.get(&x.id).copied(), + "DEA" => dea_map.get(&x.id).copied(), + _ => macd_map.get(&x.id).copied(), + }) + .collect(); + let median = median_abs(&factors); + let std = std_pop(&factors.iter().map(|x| x.abs()).collect::>()); + let last_n = &factors[factors.len().saturating_sub(n)..]; + let max_abs = *last_n.iter().max_by(|a, b| a.abs().partial_cmp(&b.abs()).unwrap_or(std::cmp::Ordering::Equal)).unwrap_or(&0.0); + if max_abs.abs() > median + t / 10.0 * std { + return make_kline_signal_v1(&k1, &k2, k3, if max_abs > 0.0 { "多头远离" } else { "空头远离" }); + } + make_kline_signal_v1(&k1, &k2, k3, "其他") +} + +/// zdy_dif_V230527:DIF 相对 MACD 柱的远离信号 +/// +/// 参数模板:`"{freq}_N{n}T{t}_DIF远离V230527"` +/// +/// 信号逻辑: +/// 1. 在最近 `n * 8` 根 K 线中找到最近 `n` 根内绝对值最大的 DIF; +/// 2. 若该 DIF 为正,则与历史正 MACD 柱峰值比较;为负则与历史负 MACD 柱绝对值比较; +/// 3. 超过阈值后输出 `多头远离/空头远离`,否则返回 `其他`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_N10T30_DIF远离V230527_多头远离_任意_任意_0')` +/// - `Signal('60分钟_N20T40_DIF远离V230527_空头远离_任意_任意_0')` +/// +/// 参数说明: +/// - `n`:最近观察窗口长度,默认 `10`; +/// - `t`:与历史柱峰值比较的放大系数,默认 `30`。 +/// 对齐说明:与 Python `czsc.signals.zdy_dif_V230527` 保持一致。 +#[signal(category = "kline", name = "zdy_dif_V230527", template = "{freq}_N{n}T{t}_DIF远离V230527", opcode = "ZdyDifV230527", param_kind = "ZdyDifV230527")] +pub fn zdy_dif_v230527(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let n = get_usize_param(params, "n", 10); + let t = get_usize_param(params, "t", 30) as f64; + let k1 = c.freq.to_string(); + let k2 = format!("N{}T{}", n, t as i32); + let k3 = "DIF远离V230527"; + if c.bars_raw.len() < 30 + n * 8 { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let (dif_map, _, macd_map) = macd_cache_maps(c, 12, 26, 9, cache); + let bars = get_sub_elements(&c.bars_raw, 1, n * 8); + let tail = &bars[bars.len() - n..]; + let max_abs_dif_bar = tail.iter().max_by(|a, b| dif_map.get(&a.id).unwrap_or(&0.0).abs().partial_cmp(&dif_map.get(&b.id).unwrap_or(&0.0).abs()).unwrap_or(std::cmp::Ordering::Equal)).unwrap(); + let max_abs_dif = *dif_map.get(&max_abs_dif_bar.id).unwrap_or(&0.0); + if max_abs_dif > 0.0 { + let seq: Vec = bars.iter().filter_map(|x| macd_map.get(&x.id).copied()).filter(|x| *x > 0.0).collect(); + if seq.len() > n && max_abs_dif.abs() > seq.iter().copied().fold(f64::NEG_INFINITY, f64::max) * t / 10.0 { + return make_kline_signal_v1(&k1, &k2, k3, "多头远离"); + } + } else if max_abs_dif < 0.0 { + let seq: Vec = bars.iter().filter_map(|x| macd_map.get(&x.id).copied()).filter(|x| *x < 0.0).map(f64::abs).collect(); + if seq.len() > n && max_abs_dif.abs() > seq.iter().copied().fold(f64::NEG_INFINITY, f64::max) * t / 10.0 { + return make_kline_signal_v1(&k1, &k2, k3, "空头远离"); + } + } + make_kline_signal_v1(&k1, &k2, k3, "其他") +} + +/// zdy_dif_V230528:DIF 峰谷分位远离信号 +/// +/// 参数模板:`"{freq}_N{n}T{t}_DIF远离V230528"` +/// +/// 信号逻辑: +/// 1. 提取最近最多 1000 根 K 线的 DIF 序列,并识别局部峰值与谷值; +/// 2. 用峰值上分位和谷值下分位构造多空远离阈值; +/// 3. 若最近一次峰值或谷值超过对应分位阈值,且最新 DIF 同向,则输出 `多头远离/空头远离`。 +/// +/// 信号列表示例: +/// - `Signal('60分钟_N20T80_DIF远离V230528_多头远离_任意_任意_0')` +/// - `Signal('60分钟_N20T80_DIF远离V230528_空头远离_任意_任意_0')` +/// +/// 参数说明: +/// - `n`:参与比较的峰谷样本数量下限,默认 `20`; +/// - `t`:峰值分位数阈值,默认 `80`,谷值侧使用 `100 - t`。 +/// 对齐说明:与 Python `czsc.signals.zdy_dif_V230528` 保持一致。 +#[signal(category = "kline", name = "zdy_dif_V230528", template = "{freq}_N{n}T{t}_DIF远离V230528", opcode = "ZdyDifV230528", param_kind = "ZdyDifV230528")] +pub fn zdy_dif_v230528(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> Vec { + let n = get_usize_param(params, "n", 20); + let t = get_usize_param(params, "t", 80) as f64; + let k1 = c.freq.to_string(); + let k2 = format!("N{}T{}", n, t as i32); + let k3 = "DIF远离V230528"; + let (dif_map, _, _) = macd_cache_maps(c, 12, 26, 9, cache); + let dif_values: Vec = c.bars_raw.iter().rev().take(1000).collect::>().into_iter().rev().filter_map(|x| dif_map.get(&x.id).copied()).collect(); + let (peaks, valleys) = find_peaks_valleys(&dif_values); + if peaks.len() < n || valleys.len() < n { + return make_kline_signal_v1(&k1, &k2, k3, "其他"); + } + let peaks_n = percentile_linear(&peaks.values().copied().collect::>(), t).unwrap_or(f64::INFINITY); + let valleys_n = percentile_linear(&valleys.values().copied().collect::>(), 100.0 - t).unwrap_or(f64::NEG_INFINITY); + + if peaks.keys().max() > valleys.keys().max() && *peaks.get(peaks.keys().max().unwrap()).unwrap() > peaks_n && *dif_values.last().unwrap_or(&0.0) > 0.0 { + return make_kline_signal_v1(&k1, &k2, k3, "多头远离"); + } + if valleys.keys().max() > peaks.keys().max() && *valleys.get(valleys.keys().max().unwrap()).unwrap() < valleys_n && *dif_values.last().unwrap_or(&0.0) < 0.0 { + return make_kline_signal_v1(&k1, &k2, k3, "空头远离"); + } + make_kline_signal_v1(&k1, &k2, k3, "其他") +} diff --git a/crates/czsc-signals/src/zdy_trader.rs b/crates/czsc-signals/src/zdy_trader.rs new file mode 100644 index 000000000..5f0b5c064 --- /dev/null +++ b/crates/czsc-signals/src/zdy_trader.rs @@ -0,0 +1,322 @@ +use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::sig::{ + get_str_param, get_usize_param, last_open_operate, make_signal, make_signal_v1, +}; +use crate::utils::ta::update_macd_cache; +use czsc_core::objects::direction::Direction; +use czsc_core::objects::freq::Freq; +use czsc_core::objects::mark::Mark; +use czsc_core::objects::operate::Operate; +use czsc_core::objects::signal::Signal; +use czsc_core::objects::state::TraderState; +use czsc_signal_macros::signal; +use std::collections::HashMap; +use std::str::FromStr; + +fn macd_map(cache: &TaCache, cache_key: &str) -> HashMap { + let mut out = HashMap::new(); + if let Some(series) = cache.macd.get(cache_key) { + for (i, id) in series.ids.iter().enumerate() { + out.insert(*id, series.macd[i]); + } + } + out +} + +/// zdy_vibrate_V230406:中枢震荡短差辅助 +#[signal( + category = "trader", + name = "zdy_vibrate_V230406", + template = "中枢震荡_{freq1}#{freq2}_BS辅助V230406", + opcode = "ZdyVibrateV230406", + param_kind = "ZdyVibrateV230406" +)] +pub fn zdy_vibrate_v230406(cat: &dyn TraderState, params: &ParamView) -> Vec { + let freq1 = get_str_param(params, "freq1", "5分钟"); + let freq2 = get_str_param(params, "freq2", "60分钟"); + let k1 = "中枢震荡"; + let k2 = format!("{}#{}", freq1, freq2); + let k3 = "BS辅助V230406"; + + let Ok(f1) = Freq::from_str(freq1) else { + return make_signal_v1(k1, &k2, k3, "其他"); + }; + let Ok(f2) = Freq::from_str(freq2) else { + return make_signal_v1(k1, &k2, k3, "其他"); + }; + assert!(f1 < f2, "freq1 必须小于 freq2"); + + let Some(c1) = cat.get_czsc(freq1) else { + return make_signal_v1(k1, &k2, k3, "其他"); + }; + let Some(c2) = cat.get_czsc(freq2) else { + return make_signal_v1(k1, &k2, k3, "其他"); + }; + if c2.bi_list.len() < 5 + || c1.bi_list.is_empty() + || c1.bars_raw.is_empty() + || c2.bars_raw.is_empty() + { + return make_signal_v1(k1, &k2, k3, "其他"); + } + + let mut cache = TaCache::new(); + let cache_key = "MACD12#26#9"; + update_macd_cache(c2, cache_key, 12, 26, 9, &mut cache); + let macd = macd_map(&cache, cache_key); + + let b1 = &c2.bi_list[c2.bi_list.len() - 4]; + let b2 = &c2.bi_list[c2.bi_list.len() - 3]; + let b3 = &c2.bi_list[c2.bi_list.len() - 2]; + let zg = b1.get_high().min(b2.get_high()).min(b3.get_high()); + let zd = b1.get_low().max(b2.get_low()).max(b3.get_low()); + if zd > zg { + return make_signal_v1(k1, &k2, k3, "其他"); + } + + let c1_lbi = c1.bi_list.last().unwrap(); + let c1_bar = c1.bars_raw.last().unwrap(); + let c1_fx_bars = c1_lbi + .fx_b + .elements + .last() + .map(|x| x.elements.clone()) + .unwrap_or_default(); + if c1_fx_bars.is_empty() || c1_bar.dt == c1_fx_bars.last().unwrap().dt || c1.bars_ubi.len() > 6 + { + return make_signal_v1(k1, &k2, k3, "其他"); + } + + let temp = if c1_lbi.direction == Direction::Down + && c1_bar.close > c1_bar.open + && c1_bar.close > c1_fx_bars.last().unwrap().high + { + Some("底分停顿") + } else if c1_lbi.direction == Direction::Up + && c1_bar.close < c1_bar.open + && c1_bar.close < c1_fx_bars.last().unwrap().low + { + Some("顶分停顿") + } else { + None + }; + let Some(temp) = temp else { + return make_signal_v1(k1, &k2, k3, "其他"); + }; + + let c2_bar = c2.bars_raw.last().unwrap(); + let c2_macd = *macd.get(&c2_bar.id).unwrap_or(&f64::NAN); + if temp == "顶分停顿" && c2_macd < 0.0 { + let p = c1_bar.close; + let h = c1_lbi.get_high(); + if h >= zg && (h - zg) < (zg - zd) && (h - p) * 3.0 < (zg - zd) { + return make_signal_v1(k1, &k2, k3, "看空"); + } + } + if temp == "底分停顿" && c2_macd > 0.0 { + let p = c1_bar.close; + let l = c1_lbi.get_low(); + if l <= zd && (zd - l) < (zg - zd) && (p - l) * 3.0 < (zg - zd) { + return make_signal_v1(k1, &k2, k3, "看多"); + } + } + make_signal_v1(k1, &k2, k3, "其他") +} + +/// zdy_stop_loss_V230406:笔操作止损逻辑 +#[signal( + category = "trader", + name = "zdy_stop_loss_V230406", + template = "{freq1}_{pos_name}F{first_stop}_止损V230406", + opcode = "ZdyStopLossV230406", + param_kind = "ZdyStopLossV230406" +)] +pub fn zdy_stop_loss_v230406(cat: &dyn TraderState, params: &ParamView) -> Vec { + let pos_name = get_str_param(params, "pos_name", ""); + let freq1 = get_str_param(params, "freq1", ""); + let first_stop = get_usize_param(params, "first_stop", 300) as f64; + let k1 = freq1.to_string(); + let k2 = format!("{}F{}", pos_name, first_stop as i32); + let k3 = "止损V230406"; + + let Some(pos) = cat.get_position(pos_name) else { + return make_signal_v1(&k1, &k2, k3, "其他"); + }; + if pos.operates.is_empty() + || matches!(pos.operates.last().unwrap().op, Operate::SE | Operate::LE) + { + return make_signal_v1(&k1, &k2, k3, "其他"); + } + let Some(op) = last_open_operate(cat, pos_name) else { + return make_signal_v1(&k1, &k2, k3, "其他"); + }; + let Some(c) = cat.get_czsc(freq1) else { + return make_signal_v1(&k1, &k2, k3, "其他"); + }; + if c.bi_list.len() < 3 || c.bars_raw.is_empty() { + return make_signal_v1(&k1, &k2, k3, "其他"); + } + + let d3bi = &c.bi_list[c.bi_list.len() - 3]; + let bis: Vec<_> = c.bi_list.iter().filter(|x| x.fx_b.dt >= op.dt).collect(); + let last_bar = c.bars_raw.last().unwrap(); + let fxs = c.get_fx_list(); + + let mut v1 = "其他"; + let mut v2 = "其他"; + if op.op == Operate::LO { + if let Some(open_base_fx) = fxs.iter().rfind(|x| x.mark == Mark::D && x.dt < op.dt) { + if last_bar.close < open_base_fx.low { + v1 = "多头止损"; + v2 = "跌破分型低点"; + } + } + if (last_bar.close / op.price - 1.0) * 10000.0 <= -first_stop { + v1 = "多头止损"; + v2 = "进场点止损"; + } + if !bis.is_empty() + && bis.last().unwrap().direction == Direction::Up + && bis.last().unwrap().get_high() > d3bi.get_high() + && last_bar.close < op.price + { + v1 = "多头止损"; + v2 = "跌破成本价"; + } + if bis.len() > 1 + && bis.last().unwrap().direction == Direction::Up + && last_bar.close < bis[bis.len() - 2].fx_b.low + { + v1 = "多头止损"; + v2 = "跌破上个向下笔底"; + } + } + if op.op == Operate::SO { + if let Some(open_base_fx) = fxs.iter().rfind(|x| x.mark == Mark::G && x.dt < op.dt) { + if last_bar.close > open_base_fx.high { + v1 = "空头止损"; + v2 = "升破分型高点"; + } + } + if (1.0 - last_bar.close / op.price) * 10000.0 <= -first_stop { + v1 = "空头止损"; + v2 = "进场点止损"; + } + if !bis.is_empty() + && bis.last().unwrap().direction == Direction::Down + && bis.last().unwrap().get_low() < d3bi.get_low() + && last_bar.close > op.price + { + v1 = "空头止损"; + v2 = "升破成本价"; + } + if bis.len() > 1 + && bis.last().unwrap().direction == Direction::Down + && last_bar.close > bis[bis.len() - 2].fx_b.high + { + v1 = "空头止损"; + v2 = "升破上个向上笔顶"; + } + } + make_signal(&k1, &k2, k3, v1, v2) +} + +/// zdy_take_profit_V230406:笔操作止盈逻辑 +#[signal( + category = "trader", + name = "zdy_take_profit_V230406", + template = "{freq1}_{pos_name}_止盈V230406", + opcode = "ZdyTakeProfitV230406", + param_kind = "ZdyTakeProfitV230406" +)] +pub fn zdy_take_profit_v230406(cat: &dyn TraderState, params: &ParamView) -> Vec { + let pos_name = get_str_param(params, "pos_name", ""); + let freq1 = get_str_param(params, "freq1", ""); + let k1 = freq1.to_string(); + let k2 = pos_name.to_string(); + let k3 = "止盈V230406"; + + let Some(_pos) = cat.get_position(pos_name) else { + return make_signal_v1(&k1, &k2, k3, "其他"); + }; + let Some(op) = last_open_operate(cat, pos_name) else { + return make_signal_v1(&k1, &k2, k3, "其他"); + }; + let Some(c) = cat.get_czsc(freq1) else { + return make_signal_v1(&k1, &k2, k3, "其他"); + }; + let bis: Vec<_> = c.bi_list.iter().filter(|x| x.fx_b.dt >= op.dt).collect(); + let mut v1 = "其他"; + let mut v2 = "其他"; + if op.op == Operate::LO + && bis.len() > 1 + && bis.last().unwrap().direction == Direction::Up + && bis.last().unwrap().get_high() < bis[bis.len() - 2].get_high() + { + v1 = "多头止盈"; + v2 = "向上笔不创新高"; + } + if op.op == Operate::SO + && bis.len() > 1 + && bis.last().unwrap().direction == Direction::Down + && bis.last().unwrap().get_low() > bis[bis.len() - 2].get_low() + { + v1 = "空头止盈"; + v2 = "向下笔不创新低"; + } + make_signal(&k1, &k2, k3, v1, v2) +} + +/// zdy_take_profit_V230407:按力度提前止盈 +#[signal( + category = "trader", + name = "zdy_take_profit_V230407", + template = "{freq1}_{pos_name}_止盈V230407", + opcode = "ZdyTakeProfitV230407", + param_kind = "ZdyTakeProfitV230407" +)] +pub fn zdy_take_profit_v230407(cat: &dyn TraderState, params: &ParamView) -> Vec { + let pos_name = get_str_param(params, "pos_name", ""); + let freq1 = get_str_param(params, "freq1", ""); + let k1 = freq1.to_string(); + let k2 = pos_name.to_string(); + let k3 = "止盈V230407"; + + let Some(_pos) = cat.get_position(pos_name) else { + return make_signal_v1(&k1, &k2, k3, "其他"); + }; + let Some(op) = last_open_operate(cat, pos_name) else { + return make_signal_v1(&k1, &k2, k3, "其他"); + }; + let Some(c) = cat.get_czsc(freq1) else { + return make_signal_v1(&k1, &k2, k3, "其他"); + }; + if c.bi_list.len() < 2 { + return make_signal_v1(&k1, &k2, k3, "其他"); + } + let bis: Vec<_> = c.bi_list.iter().filter(|x| x.fx_b.dt >= op.dt).collect(); + let d2bi = &c.bi_list[c.bi_list.len() - 2]; + if bis.is_empty() || (bis.last().unwrap().get_length() as f64) < 1.5 * d2bi.get_length() as f64 + { + return make_signal_v1(&k1, &k2, k3, "其他"); + } + + let mut v1 = "其他"; + let mut v2 = "其他"; + if op.op == Operate::LO + && bis.last().unwrap().direction == Direction::Up + && bis.last().unwrap().get_high() < d2bi.get_high() + { + v1 = "多头止盈"; + v2 = "向上笔不创新高"; + } + if op.op == Operate::SO + && bis.last().unwrap().direction == Direction::Down + && bis.last().unwrap().get_low() > d2bi.get_low() + { + v1 = "空头止盈"; + v2 = "向下笔不创新低"; + } + make_signal(&k1, &k2, k3, v1, v2) +} diff --git a/crates/czsc-ta/Cargo.toml b/crates/czsc-ta/Cargo.toml new file mode 100644 index 000000000..164466e18 --- /dev/null +++ b/crates/czsc-ta/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "czsc-ta" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "CZSC technical analysis operators (ema/sma/rolling_rank/...). Placeholder, to be migrated." + +[lib] +name = "czsc_ta" +path = "src/lib.rs" + +[dependencies] +numpy = { workspace = true, optional = true } +ordered-float = { version = "5.0", optional = true } +pyo3 = { workspace = true, optional = true, features = ["chrono"] } +pyo3-stub-gen = { version = "0.12", optional = true } + +[features] +python = ["pyo3", "pyo3-stub-gen"] +rust-numpy = ["python", "numpy", "ordered-float"] diff --git a/crates/czsc-ta/src/lib.rs b/crates/czsc-ta/src/lib.rs new file mode 100644 index 000000000..9ccf3df13 --- /dev/null +++ b/crates/czsc-ta/src/lib.rs @@ -0,0 +1,16 @@ +//! czsc-ta — technical analysis operators. +//! +//! Migrated from rs-czsc 47ef6efa per docs/MIGRATION_NOTES.md §1. +//! Phase E call-graph analysis (recorded in §2.3) confirmed the full +//! operator set is consumed by Python via `rust-numpy`; nothing was +//! trimmed. + +#![allow(clippy::needless_range_loop, clippy::manual_memcpy)] + +pub mod pure; + +#[cfg(feature = "rust-numpy")] +pub mod mixed; + +#[cfg(feature = "rust-numpy")] +pub mod python; diff --git a/crates/czsc-ta/src/mixed/chip_dist.rs b/crates/czsc-ta/src/mixed/chip_dist.rs new file mode 100644 index 000000000..8e3704eed --- /dev/null +++ b/crates/czsc-ta/src/mixed/chip_dist.rs @@ -0,0 +1,133 @@ +use ordered_float::OrderedFloat; +use pyo3_stub_gen::derive::gen_stub_pyfunction; +use std::collections::HashMap; + +use numpy::ndarray::Array1; +use numpy::{IntoPyArray, PyArray1, PyReadonlyArray2}; +use pyo3::prelude::*; + +/// 计算筹码分布(三角形分布 + 筹码沉淀机制) +/// +/// 此函数用于估算基于历史K线的筹码分布情况,结合三角形分布模型和筹码沉淀(衰减)机制。 +/// +/// # Python 接口说明 +/// +/// 输入一个二维 numpy 数组,形状为 (N, 3),每一行对应一根K线,列顺序为: +/// `[high, low, vol]`,类型必须为 `float64`。 +/// +/// 示例: +/// ```python +/// columns = ['high', 'low', 'vol'] +/// arr2 = df[columns].to_numpy(dtype=np.float64) +/// price_centers, chip_dist = chip_distribution_triangle(arr2, 0.01, 0.9) +/// ``` +/// +/// # 参数 +/// +/// - `data`: 二维数组,形状为 (N, 3),分别是每根K线的最高价、最低价和成交量。 +/// - `price_step`: 分档间隔(如0.01表示以0.01为单位划分价格区间)。 +/// - `decay_factor`: 筹码衰减因子,表示前一根K线上的筹码有多少比例沉淀保留到下一根K线上,范围为(0, 1),例如0.98表示保留98%。 +/// +/// # 返回值 +/// +/// 返回一个元组 `(price_centers, chip_distribution)`: +/// - `price_centers`: 一维数组,表示价格分布区间的中心价位。 +/// - `chip_distribution`: 一维数组,对应每个价格中心的筹码强度(权重/密度)。 +/// +/// 返回的两个数组长度相同,可用于绘制筹码分布图或进一步分析。 +#[pyfunction] +#[gen_stub_pyfunction] +pub fn chip_distribution_triangle<'py>( + py: Python<'py>, + data: PyReadonlyArray2<'py, f64>, + price_step: f64, + decay_factor: f64, +) -> (Bound<'py, PyArray1>, Bound<'py, PyArray1>) { + let data = data.as_array(); + let nrows = data.shape()[0]; + + // 安全校验 + let ncols = data.shape()[1]; + if ncols < 3 { + panic!("Input array must have at least 3 columns (high, low, vol)"); + } + + // 第2列是 low,第1列是 high + let low_col = data.column(1); + let high_col = data.column(0); + + // 取 min(low) 和 max(high) + let min_low = low_col.fold(f64::INFINITY, |a, &b| a.min(b)); + let max_high = high_col.fold(f64::NEG_INFINITY, |a, &b| a.max(b)); + + // 计算 price bins 区间 + let min_price = (min_low / price_step).floor() * price_step; + let max_price = (max_high / price_step).ceil() * price_step; + + let nbins = ((max_price - min_price) / price_step).ceil() as usize; + + let mut chip_dist = Array1::::zeros(nbins); + + let price_centers = + Array1::from_iter((0..nbins).map(|i| min_price + price_step * (i as f64 + 0.5))); + + // 缓存权重映射 + let mut weight_cache: HashMap<(OrderedFloat, OrderedFloat), Vec> = + HashMap::new(); + + for i in 0..nrows { + let high = data[[i, 0]]; + let low = data[[i, 1]]; + let vol = data[[i, 2]]; + + if high <= low || vol == 0.0 { + continue; + } + + let start_idx = ((low - min_price) / price_step).floor().max(0.0) as usize; + let end_idx = ((high - min_price) / price_step).ceil().min(nbins as f64) as usize; + + if end_idx <= start_idx || end_idx > nbins { + continue; + } + + // 衰减之前的筹码分布 + for x in chip_dist.iter_mut() { + *x *= decay_factor; + } + + // 构造三角分布权重 + // 三角权重缓存查找 + let low_key = OrderedFloat(low); + let high_key = OrderedFloat(high); + + let weights = weight_cache.entry((low_key, high_key)).or_insert_with(|| { + let mid_price = (low + high) / 2.0; + let mut w = Vec::with_capacity(end_idx - start_idx); + for idx in start_idx..end_idx { + let center_price = min_price + price_step * (idx as f64 + 0.5); + let weight = 1.0 - ((center_price - mid_price).abs()) / ((high - low) / 2.0); + w.push(weight.max(0.0)); + } + w + }); + + let weight_sum: f64 = weights.iter().sum(); + if weight_sum == 0.0 { + continue; + } + + // 归一化权重并加权更新 chip_dist + for (j, idx) in (start_idx..end_idx).enumerate() { + chip_dist[idx] += vol * weights[j] / weight_sum; + } + } + + // 归一化 chip_dist + let total: f64 = chip_dist.sum(); + if total > 0.0 { + chip_dist.iter_mut().for_each(|x| *x /= total); + } + + (price_centers.into_pyarray(py), chip_dist.into_pyarray(py)) +} diff --git a/crates/czsc-ta/src/mixed/mod.rs b/crates/czsc-ta/src/mixed/mod.rs new file mode 100644 index 000000000..e17669416 --- /dev/null +++ b/crates/czsc-ta/src/mixed/mod.rs @@ -0,0 +1,3 @@ +//! Numpy 绑定实现 + +pub mod chip_dist; diff --git a/crates/czsc-ta/src/pure.rs b/crates/czsc-ta/src/pure.rs new file mode 100644 index 000000000..40b0e2462 --- /dev/null +++ b/crates/czsc-ta/src/pure.rs @@ -0,0 +1,1537 @@ +//! 纯 Rust 实现 + +/// Plain Simple Moving Average — talib.SMA-compatible. +/// +/// Returns the rolling mean of `series` with window size `n`. Indices +/// 0..n-1 are filled with NaN to match the talib convention (so +/// `np.isfinite(out[n:])` is fully True). This is intentionally +/// distinct from `single_sma_positions`, which computes a double SMA +/// then derives a [-1, 0, 1] position signal — `sma` here is the raw +/// moving average needed by `czsc.ta.sma` per design doc §3.1. +pub fn sma(series: &[f64], n: usize) -> Vec { + let len = series.len(); + if len == 0 || n == 0 { + return vec![]; + } + let mut out = vec![f64::NAN; len]; + if n > len { + return out; + } + let mut sum: f64 = series.iter().take(n).sum(); + out[n - 1] = sum / n as f64; + for i in n..len { + sum += series[i] - series[i - n]; + out[i] = sum / n as f64; + } + out +} + +// Ultimate Smoother 实现函数 +pub fn ultimate_smoother(price: &[f64], period: f64) -> Vec { + let len = price.len(); + if len == 0 { + return vec![]; + } + let a1 = (-1.414 * std::f64::consts::PI / period).exp(); + let b1 = 2.0 * a1 * (1.414 * 180.0 / period).to_radians().cos(); + let c2 = b1; + let c3 = -a1 * a1; + let c1 = (1.0 + c2 - c3) / 4.0; + let mut us = vec![0.0; len]; + + // 与 Python 实现完全一致,正确处理 NaN + for i in 0..len { + if i < 4 { + us[i] = price[i]; + } else { + // 检查输入值是否为 NaN,如果是则保持 NaN + if price[i].is_nan() || price[i - 1].is_nan() || price[i - 2].is_nan() { + us[i] = f64::NAN; + } else { + us[i] = (1.0 - c1) * price[i] + (2.0 * c1 - c2) * price[i - 1] + - (c1 + c3) * price[i - 2] + + c2 * us[i - 1] + + c3 * us[i - 2]; + } + } + } + + us +} + +pub fn rolling_rank(series: &[f64], window: usize) -> Vec> { + let len = series.len(); + let mut ranks = Vec::with_capacity(len); + for i in 0..len { + if i + 1 < window { + ranks.push(None); + continue; + } + let start = i + 1 - window; + let window_slice = &series[start..=i]; + let mut sorted = window_slice.to_vec(); + sorted.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let value = series[i]; + let rank = sorted + .iter() + .position(|&x| (x - value).abs() < 1e-8) + .map(|pos| pos + 1); + ranks.push(rank); + } + ranks +} +/// 单均线多空 +pub fn single_sma_positions(series: &[f64], n: usize) -> Vec { + let len = series.len(); + if len == 0 || n == 0 { + return vec![]; + } + if n > len { + // 如果窗口大于序列长度,全部返回0 + return vec![0.0; len]; + } + // 计算第一个移动平均 + let mut ms = vec![0.0; len]; + for i in 0..len { + if i + 1 < n { + ms[i] = 0.0; // 不足窗口长度的位置设为0 + } else { + let sum: f64 = series[i + 1 - n..=i].iter().sum(); + ms[i] = sum / n as f64; + } + } + // 计算第二个移动平均(对第一个移动平均再次求移动平均) + let mut ms_sma = vec![0.0; len]; + for i in 0..len { + if i + 1 < n { + ms_sma[i] = 0.0; + } else { + let sum: f64 = ms[i + 1 - n..=i].iter().sum(); + ms_sma[i] = sum / n as f64; + } + } + // 计算差值并返回符号 + let mut result = vec![0.0; len]; + for i in 0..len { + // 只有 i >= 2*n-2 时才有有效信号,与 Python pandas 的 fillna(0) 行为完全一致 + if i >= 2 * n - 2 && ms_sma[i] != 0.0 { + result[i] = (series[i] - ms_sma[i]).signum(); + } else { + result[i] = 0.0; + } + } + result +} + +/// 单指数移动平均多空信号 +pub fn single_ema_positions(series: &[f64], n: usize) -> Vec { + let len = series.len(); + if len == 0 || n == 0 { + return vec![]; + } + if n > len { + // 如果窗口大于序列长度,全部返回0 + return vec![0.0; len]; + } + + // 预分配所有向量,避免动态扩容 + let mut result = vec![0.0; len]; + let mut ms = vec![0.0; len]; + let mut ms_ema = vec![0.0; len]; + + // 计算第一个移动平均 + for i in 0..len { + if i + 1 < n { + ms[i] = 0.0; // 不足窗口长度的位置设为0,对应 Python 的 NaN + } else { + let sum: f64 = series[i + 1 - n..=i].iter().sum(); + ms[i] = sum / n as f64; + } + } + + // 计算指数移动平均,模拟 talib.EMA 的行为 + let alpha = 2.0 / (n as f64 + 1.0); + + // 找到第一个非零的 ms 值作为初始值 + let mut first_valid_idx = 0; + for i in 0..len { + if ms[i] != 0.0 { + first_valid_idx = i; + break; + } + } + + // 模拟 talib.EMA 的行为,需要 n 个有效数据才开始计算 + let start_idx = first_valid_idx + n - 1; + + // 一次性计算 EMA 和结果,减少循环次数 + for i in 0..len { + if i < start_idx { + ms_ema[i] = 0.0; + result[i] = 0.0; + } else if i == start_idx { + // 用前 n 个有效值的简单平均作为初始值 + let sum: f64 = ms[first_valid_idx..=i].iter().sum(); + ms_ema[i] = sum / n as f64; + result[i] = (series[i] - ms_ema[i]).signum(); + } else { + ms_ema[i] = alpha * ms[i] + (1.0 - alpha) * ms_ema[i - 1]; + result[i] = (series[i] - ms_ema[i]).signum(); + } + } + + result +} + +/// 取窗口内的中间值作为中轴,中轴上方做多,下方做空 +/// 中间值 = (最大值 + 最小值) / 2 +pub fn mid_positions(series: &[f64], n: usize) -> Vec { + let len = series.len(); + if len == 0 || n == 0 { + return vec![]; + } + if n > len { + return vec![0.0; len]; + } + + // 计算第一个移动平均 + let mut ms = vec![0.0; len]; + for i in 0..len { + if i + 1 < n { + ms[i] = 0.0; // 不足窗口长度的位置设为0 + } else { + let sum: f64 = series[i + 1 - n..=i].iter().sum(); + ms[i] = sum / n as f64; + } + } + + // 计算 ms 的滚动最大值和最小值 + let mut high = vec![0.0; len]; + let mut low = vec![0.0; len]; + + for i in 0..len { + if i + 1 < n { + high[i] = 0.0; + low[i] = 0.0; + } else { + let window = &ms[i + 1 - n..=i]; + high[i] = window.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)); + low[i] = window.iter().fold(f64::INFINITY, |a, &b| a.min(b)); + } + } + + // 计算中间值并返回符号 + let mut result = vec![0.0; len]; + for i in 0..len { + if i + 1 < 2 * n - 1 { + // 需要 2*n-1 个数据点才有有效信号 + result[i] = 0.0; + } else if high[i] != 0.0 && low[i] != 0.0 { + let mid = (high[i] + low[i]) / 2.0; + result[i] = (ms[i] - mid).signum(); + } else { + result[i] = 0.0; + } + } + + result +} + +/// 双均线多空信号 +/// 比较短周期和长周期的简单移动平均 +pub fn double_sma_positions(series: &[f64], n: usize, m: usize) -> Vec { + let len = series.len(); + if len == 0 || n == 0 || m == 0 { + return vec![]; + } + if n >= m { + panic!("短周期必须小于长周期"); + } + if m > len { + return vec![0.0; len]; + } + + let mut result = vec![0.0; len]; + + // 使用增量计算优化性能 + let mut sma_n_sum = 0.0; + let mut sma_m_sum = 0.0; + let mut sma_n_count = 0; + let mut sma_m_count = 0; + + for i in 0..len { + // 更新短周期移动平均 + if i >= n - 1 { + if sma_n_count == 0 { + sma_n_sum = series[i + 1 - n..=i].iter().sum(); + sma_n_count = 1; + } else { + sma_n_sum = sma_n_sum + series[i] - series[i - n]; + } + } + + // 更新长周期移动平均 + if i >= m - 1 { + if sma_m_count == 0 { + sma_m_sum = series[i + 1 - m..=i].iter().sum(); + sma_m_count = 1; + } else { + sma_m_sum = sma_m_sum + series[i] - series[i - m]; + } + } + + if i >= m - 1 && sma_n_count > 0 && sma_m_count > 0 { + let sma_n = sma_n_sum / n as f64; + let sma_m = sma_m_sum / m as f64; + result[i] = (sma_n - sma_m).signum(); + } else { + result[i] = 0.0; + } + } + + result +} + +/// 三均线系统持仓信号生成函数 +/// 多头:factor > m3 的时候,m1 > m2 +/// 空头:factor < m3 的时候,m1 < m2 +pub fn triple_sma_positions(series: &[f64], m1: usize, m2: usize, m3: usize) -> Vec { + let len = series.len(); + if len == 0 || m1 == 0 || m2 == 0 || m3 == 0 { + return vec![]; + } + if m3 > len { + panic!("series 长度必须大于 m3"); + } + if !(m3 > m2 && m2 > m1) { + panic!("m3 必须大于 m2 大于 m1"); + } + + // 第一步:对原始序列计算 m1 期移动平均 + let mut smoothed_series = vec![None; len]; + for i in 0..len { + if i + 1 < m1 { + smoothed_series[i] = None; + } else { + let sum: f64 = series[i + 1 - m1..=i].iter().sum(); + smoothed_series[i] = Some(sum / m1 as f64); + } + } + + // 计算三个移动平均 + let mut ma1 = vec![None; len]; + let mut ma2 = vec![None; len]; + let mut ma3 = vec![None; len]; + + // 计算 ma1 (对 smoothed_series 计算 m1 期移动平均) + for i in 0..len { + if i + 1 < m1 { + ma1[i] = None; + } else { + let window: Vec = smoothed_series[i + 1 - m1..=i] + .iter() + .filter_map(|&x| x) + .collect(); + if window.len() == m1 { + let sum: f64 = window.iter().sum(); + ma1[i] = Some(sum / m1 as f64); + } else { + ma1[i] = None; + } + } + } + + // 计算 ma2 (对 smoothed_series 计算 m2 期移动平均) + for i in 0..len { + if i + 1 < m2 { + ma2[i] = None; + } else { + let window: Vec = smoothed_series[i + 1 - m2..=i] + .iter() + .filter_map(|&x| x) + .collect(); + if window.len() == m2 { + let sum: f64 = window.iter().sum(); + ma2[i] = Some(sum / m2 as f64); + } else { + ma2[i] = None; + } + } + } + + // 计算 ma3 (对 smoothed_series 计算 m3 期移动平均) + for i in 0..len { + if i + 1 < m3 { + ma3[i] = None; + } else { + let window: Vec = smoothed_series[i + 1 - m3..=i] + .iter() + .filter_map(|&x| x) + .collect(); + if window.len() == m3 { + let sum: f64 = window.iter().sum(); + ma3[i] = Some(sum / m3 as f64); + } else { + ma3[i] = None; + } + } + } + + // 生成持仓信号 + let mut positions = vec![0i32; len]; + for i in 0..len { + // 检查所有值是否都有效(非None) + if let (Some(smoothed), Some(ma1_val), Some(ma2_val), Some(ma3_val)) = + (smoothed_series[i], ma1[i], ma2[i], ma3[i]) + { + if smoothed > ma3_val && ma1_val > ma2_val { + positions[i] = 1; // 多头 + } else if smoothed < ma3_val && ma1_val < ma2_val { + positions[i] = -1; // 空头 + } else { + positions[i] = 0; // 空仓 + } + } else { + positions[i] = 0; + } + } + + positions +} + +/// 布林线多空信号 +/// series 大于 n 周期均线 + k * n周期标准差,做多;小于 n 周期均线 - k * n周期标准差,做空 +pub fn boll_positions(series: &[f64], n: usize, k: f64) -> Vec { + let len = series.len(); + if len == 0 || n == 0 { + return vec![]; + } + if n > len { + return vec![0; len]; + } + + // 第一步:对原始序列计算 n 期移动平均 + let mut smoothed_series = vec![None; len]; + for i in 0..len { + if i + 1 < n { + smoothed_series[i] = None; + } else { + let sum: f64 = series[i + 1 - n..=i].iter().sum(); + smoothed_series[i] = Some(sum / n as f64); + } + } + + // 计算移动平均 (sm) + let mut sm = vec![None; len]; + for i in 0..len { + if i + 1 < n { + sm[i] = None; + } else { + let window: Vec = smoothed_series[i + 1 - n..=i] + .iter() + .filter_map(|&x| x) + .collect(); + if window.len() == n { + let sum: f64 = window.iter().sum(); + sm[i] = Some(sum / n as f64); + } else { + sm[i] = None; + } + } + } + + // 计算移动标准差 (sd) + let mut sd = vec![None; len]; + for i in 0..len { + if i + 1 < n { + sd[i] = None; + } else { + let window: Vec = smoothed_series[i + 1 - n..=i] + .iter() + .filter_map(|&x| x) + .collect(); + if window.len() == n { + let mean = window.iter().sum::() / n as f64; + let variance = + window.iter().map(|&x| (x - mean).powi(2)).sum::() / (n - 1) as f64; // 使用 ddof=1 (样本标准差) + sd[i] = Some(variance.sqrt()); + } else { + sd[i] = None; + } + } + } + + // 生成持仓信号 + let mut positions = vec![0i32; len]; + for i in 0..len { + if let (Some(smoothed), Some(sm_val), Some(sd_val)) = (smoothed_series[i], sm[i], sd[i]) { + let upper_band = sm_val + k * sd_val; + let lower_band = sm_val - k * sd_val; + + // 使用更精确的比较,避免浮点数精度问题 + if smoothed > upper_band + 1e-10 { + positions[i] = 1; // 做多 + } else if smoothed < lower_band - 1e-10 { + positions[i] = -1; // 做空 + } else { + positions[i] = 0; // 空仓 + } + } else { + positions[i] = 0; + } + } + + positions +} + +/// 布林带反转策略的多空持仓信号生成函数 +/// 策略逻辑: +/// 1. 计算布林带:中轨为 MA(n), 上轨 = MA(n) + k*STD(n), 下轨 = MA(n) - k*STD(n) +/// 2. 开多:当价格 < 下轨时,开多 (pos=+1),一直持有至价格 > 中轨 => 平多 (pos=0) +/// 3. 开空:当价格 > 上轨时,开空 (pos=-1),一直持有至价格 < 中轨 => 平空 (pos=0) +pub fn boll_reverse_positions(series: &[f64], n: usize, k: f64) -> Vec { + let len = series.len(); + if len == 0 || n == 0 { + return vec![]; + } + if n > len { + return vec![0; len]; + } + + // 第一步:对原始序列计算 n 期移动平均 + let mut smoothed_series = vec![None; len]; + for i in 0..len { + if i + 1 < n { + smoothed_series[i] = None; + } else { + let sum: f64 = series[i + 1 - n..=i].iter().sum(); + smoothed_series[i] = Some(sum / n as f64); + } + } + + // 计算布林带 + let mut upper = vec![None; len]; + let mut mid = vec![None; len]; + let mut lower = vec![None; len]; + + for i in 0..len { + if i + 1 < n { + upper[i] = None; + mid[i] = None; + lower[i] = None; + } else { + let window: Vec = smoothed_series[i + 1 - n..=i] + .iter() + .filter_map(|&x| x) + .collect(); + if window.len() == n { + let mean = window.iter().sum::() / n as f64; + let variance = window.iter().map(|&x| (x - mean).powi(2)).sum::() / n as f64; + let std_dev = variance.sqrt(); + + mid[i] = Some(mean); + upper[i] = Some(mean + k * std_dev); + lower[i] = Some(mean - k * std_dev); + } else { + upper[i] = None; + mid[i] = None; + lower[i] = None; + } + } + } + + // 生成持仓信号 + let mut positions = vec![0i32; len]; + let mut current_pos = 0i32; // 当前持仓:0=空仓,+1=多头,-1=空头 + + for i in 0..len { + // 若尚未计算出 mid/upper/lower,跳过(最前面的 n-1 个数据点) + if let (Some(upper_val), Some(mid_val), Some(lower_val)) = (upper[i], mid[i], lower[i]) { + let price = smoothed_series[i].unwrap(); + + // 若当前空仓 + if current_pos == 0 { + // 价格 > 上轨 => 开空 + if price > upper_val { + current_pos = -1; + } + // 价格 < 下轨 => 开多 + else if price < lower_val { + current_pos = 1; + } + } + // 若当前持有多头 + else if current_pos == 1 { + // 当价格 > 中轨 => 平多 + if price > mid_val { + current_pos = 0; + } + } + // 若当前持有空头 + else if current_pos == -1 { + // 当价格 < 中轨 => 平空 + if price < mid_val { + current_pos = 0; + } + } + } else { + current_pos = 0; + } + + positions[i] = current_pos; + } + + positions +} + +/// 均线的最大最小值归一化 +/// 返回归一化后的值,范围在 [-1, 1] 之间 +pub fn mms_positions(series: &[f64], timeperiod: usize, window: usize) -> Vec { + let len = series.len(); + if len == 0 || timeperiod == 0 || window == 0 { + return vec![]; + } + if timeperiod > len || window > len { + return vec![0.0; len]; + } + + // 计算移动平均 (sm) + let mut sm = vec![None; len]; + for i in 0..len { + if i + 1 < timeperiod { + sm[i] = None; + } else { + let sum: f64 = series[i + 1 - timeperiod..=i].iter().sum(); + sm[i] = Some(sum / timeperiod as f64); + } + } + + // 计算移动平均的最小值 (sm_min) + let mut sm_min = vec![None; len]; + for i in 0..len { + if i + 1 < window { + sm_min[i] = None; + } else { + let window_values: Vec = + sm[i + 1 - window..=i].iter().filter_map(|&x| x).collect(); + if window_values.len() == window { + sm_min[i] = Some(window_values.iter().fold(f64::INFINITY, |a, &b| a.min(b))); + } else { + sm_min[i] = None; + } + } + } + + // 计算移动平均的最大值 (sm_max) + let mut sm_max = vec![None; len]; + for i in 0..len { + if i + 1 < window { + sm_max[i] = None; + } else { + let window_values: Vec = + sm[i + 1 - window..=i].iter().filter_map(|&x| x).collect(); + if window_values.len() == window { + sm_max[i] = Some( + window_values + .iter() + .fold(f64::NEG_INFINITY, |a, &b| a.max(b)), + ); + } else { + sm_max[i] = None; + } + } + } + + // 计算归一化结果 + let mut result = vec![0.0; len]; + for i in 0..len { + if let (Some(sm_val), Some(sm_min_val), Some(sm_max_val)) = (sm[i], sm_min[i], sm_max[i]) { + let denominator = sm_max_val - sm_min_val; + if denominator.abs() > 1e-10 { + // 归一化到 [0, 1],然后转换到 [-1, 1] + let normalized = (sm_val - sm_min_val) / denominator; + result[i] = normalized * 2.0 - 1.0; + } else { + // 如果最大值和最小值相等,设为0 + result[i] = 0.0; + } + } else { + result[i] = 0.0; + } + } + + result +} + +/// RSI 反转策略的多空持仓信号 +/// 返回每个点的持仓信号(-1: 空头, 0: 空仓, 1: 多头) +pub fn rsi_reverse_positions( + series: &[f64], + n: usize, + rsi_upper: f64, + rsi_lower: f64, + rsi_exit: f64, +) -> Vec { + let len = series.len(); + if len == 0 || n == 0 { + return vec![]; + } + if n > len { + return vec![0; len]; + } + + // 第一步:对原始序列计算 n 期移动平均(与 Python 一致) + let mut smoothed_series = vec![None; len]; + for i in 0..len { + if i + 1 < n { + smoothed_series[i] = None; + } else { + let sum: f64 = series[i + 1 - n..=i].iter().sum(); + smoothed_series[i] = Some(sum / n as f64); + } + } + + // 第二步:对移动平均后的数据计算 RSI(与 Python 一致) + let mut rsi = vec![None; len]; + for i in 0..len { + // 检查是否有足够的有效数据来计算 RSI + let mut valid_count = 0; + let mut gains = 0.0; + let mut losses = 0.0; + + // 计算前 n 个周期的涨跌幅(使用移动平均后的数据) + let start = if i >= n { i.saturating_sub(n) } else { 0 }; + for j in start..i { + if j > 0 + && let (Some(current), Some(prev)) = (smoothed_series[j], smoothed_series[j - 1]) + { + let change = current - prev; + if change > 0.0 { + gains += change; + } else { + losses += change.abs(); + } + valid_count += 1; + } + } + + // 只有当有足够的有效数据时才计算 RSI + if valid_count >= n - 1 { + // 允许一个缺失值 + if losses == 0.0 { + rsi[i] = Some(100.0); + } else { + let avg_gain = gains / valid_count as f64; + let avg_loss = losses / valid_count as f64; + let rs = avg_gain / avg_loss; + rsi[i] = Some(100.0 - (100.0 / (1.0 + rs))); + } + } else { + rsi[i] = None; + } + } + + // 第三步:生成持仓信号 + let mut positions = vec![0; len]; + let mut current_pos = 0; // 当前持仓:0=空仓,+1=多头,-1=空头 + + for i in 0..len { + if let Some(rsi_val) = rsi[i] { + // 若当前空仓 + if current_pos == 0 { + // 如果 RSI < rsi_lower,则开多 + if rsi_val < rsi_lower { + current_pos = 1; + } + // 如果 RSI > rsi_upper,则开空 + else if rsi_val > rsi_upper { + current_pos = -1; + } + } + // 若当前持有多头 + else if current_pos == 1 { + // 当 RSI > rsi_exit,则平多 (回到空仓) + if rsi_val > rsi_exit { + current_pos = 0; + } + } + // 若当前持有空头 + else if current_pos == -1 { + // 当 RSI < rsi_exit,则平空 (回到空仓) + if rsi_val < rsi_exit { + current_pos = 0; + } + } + } + // 如果无法计算出 RSI,保持空仓 + positions[i] = current_pos; + } + + positions +} + +/// tanh 多空策略 +/// 返回每个点的持仓信号(-1 到 1 之间的值) +pub fn tanh_positions(series: &[f64], n: usize) -> Vec { + let len = series.len(); + if len == 0 || n == 0 { + return vec![]; + } + if n > len { + return vec![0.0; len]; + } + + // 计算移动平均 + let mut ms = vec![None; len]; + for i in 0..len { + if i + 1 < n { + ms[i] = None; + } else { + let sum: f64 = series[i + 1 - n..=i].iter().sum(); + ms[i] = Some(sum / n as f64); + } + } + + // 计算移动平均的均值和标准差 + let mut mean = vec![None; len]; + let mut std = vec![None; len]; + + for i in 0..len { + if i + 1 < n { + mean[i] = None; + std[i] = None; + } else { + let mut values = Vec::new(); + for j in i + 1 - n..=i { + if let Some(val) = ms[j] { + values.push(val); + } + } + + if values.len() == n { + let sum: f64 = values.iter().sum(); + let mean_val = sum / n as f64; + mean[i] = Some(mean_val); + + // 使用 pandas 的 ddof=1 默认行为(n-1 自由度) + let variance = + values.iter().map(|&x| (x - mean_val).powi(2)).sum::() / (n - 1) as f64; + std[i] = Some(variance.sqrt()); + } + } + } + + // 计算 tanh 值 + let mut result = vec![0.0; len]; + for i in 0..len { + if let (Some(ms_val), Some(mean_val), Some(std_val)) = (ms[i], mean[i], std[i]) + && std_val > 0.0 + { + let z_score = (ms_val - mean_val) / std_val; + result[i] = z_score.tanh(); + } + } + + // 四舍五入到两位小数 + for val in &mut result { + *val = (*val * 100.0).round() / 100.0; + } + + result +} + +/// rank 多空策略 +/// 返回每个点的持仓信号(-1 到 1 之间的值) +pub fn rank_positions(series: &[f64], n: usize) -> Vec { + let len = series.len(); + if len == 0 || n == 0 { + return vec![]; + } + if n > len { + return vec![0.0; len]; + } + + // 计算移动平均 + let mut ms = vec![None; len]; + for i in 0..len { + if i + 1 < n { + ms[i] = None; + } else { + let sum: f64 = series[i + 1 - n..=i].iter().sum(); + ms[i] = Some(sum / n as f64); + } + } + + // 计算 rank + let mut result = vec![0.0; len]; + for i in 0..len { + if i + 1 < n { + result[i] = 0.0; + } else { + let mut values = Vec::new(); + for j in i + 1 - n..=i { + if let Some(val) = ms[j] { + values.push(val); + } + } + + if values.len() == n { + // 计算当前值在窗口中的排名 + let current_val = values[n - 1]; + let mut rank = 1; + for &val in &values[..n - 1] { + if val < current_val { + rank += 1; + } + } + + // 计算归一化的 rank + let normalized_rank = (rank - 1) as f64 / (n - 1) as f64; + let x = (normalized_rank - 0.5) * 2.0; + result[i] = x; + } + } + } + + // 四舍五入到两位小数 + for val in &mut result { + *val = (*val * 100.0).round() / 100.0; + } + + result +} + +/// 计算指数移动平均 (EMA) +/// 返回每个点的 EMA 值 +pub fn ema(series: &[f64], period: usize) -> Vec { + // talib-compatible EMA: warmup [0, period-1) is NaN, position + // (period-1) is seeded with the simple mean of the first `period` + // samples, then the standard recurrence runs forward. This matches + // talib.EMA's output bit-for-bit (verified by Phase A's + // `test_ema_matches_talib`). The previous rs-czsc implementation + // seeded with `series[0]`, which produced visible divergence in + // the first ~30 bars. + let len = series.len(); + if len == 0 || period == 0 { + return vec![]; + } + if period > len { + return vec![f64::NAN; len]; + } + + let mut result = vec![f64::NAN; len]; + let alpha = 2.0 / (period + 1) as f64; + + let seed: f64 = series.iter().take(period).sum::() / period as f64; + result[period - 1] = seed; + for i in period..len { + result[i] = alpha * series[i] + (1.0 - alpha) * result[i - 1]; + } + + result +} + +/// 计算真实波幅 (True Range) +/// 返回每个点的真实波幅值 +pub fn true_range(high: &[f64], low: &[f64], close_prev: &[f64]) -> Vec { + let len = high.len(); + if len == 0 { + return vec![]; + } + + let mut result = vec![0.0; len]; + + for i in 0..len { + let tr1 = high[i] - low[i]; + + // 与Python的close.shift(1)行为一致 + // 第一个位置使用NaN,其他位置使用close_prev + let prev_close = if i == 0 { f64::NAN } else { close_prev[i] }; + + let tr2 = if prev_close.is_nan() { + f64::NAN + } else { + (high[i] - prev_close).abs() + }; + let tr3 = if prev_close.is_nan() { + f64::NAN + } else { + (low[i] - prev_close).abs() + }; + + // 取三个值中的最大值,处理NaN + if tr2.is_nan() || tr3.is_nan() { + // 如果有NaN,只使用tr1 + result[i] = tr1; + } else { + result[i] = tr1.max(tr2).max(tr3); + } + } + + result +} + +/// RSX-SS2 - 自适应平滑的RSI变体 +/// 返回每个点的 RSX-SS2 值 +pub fn rsx_ss2(close: &[f64], period: usize, smooth_period: usize) -> Vec { + let len = close.len(); + if len == 0 || period == 0 || smooth_period == 0 { + return vec![]; + } + + // 计算价格变化 + let mut delta = vec![f64::NAN; len]; + for i in 1..len { + delta[i] = close[i] - close[i - 1]; + } + + // 计算增益和损失 + let mut gain = vec![f64::NAN; len]; + let mut loss = vec![f64::NAN; len]; + for i in 0..len { + if i == 0 { + gain[i] = 0.0; + loss[i] = 0.0; + } else if delta[i] > 0.0 { + gain[i] = delta[i]; + loss[i] = 0.0; + } else { + gain[i] = 0.0; + loss[i] = -delta[i]; + } + } + + // 计算平均增益和平均损失 (使用 EMA) + let alpha = 1.0 / period as f64; + let mut avg_gain = vec![f64::NAN; len]; + let mut avg_loss = vec![f64::NAN; len]; + + // 找到第一个非 NaN 的值来初始化 + let mut first_valid_idx = None; + for i in 0..len { + if !gain[i].is_nan() && !loss[i].is_nan() { + first_valid_idx = Some(i); + break; + } + } + + if let Some(idx) = first_valid_idx { + // 使用第一个有效值作为初始值,与 pandas ewm 行为一致 + avg_gain[idx] = gain[idx]; + avg_loss[idx] = loss[idx]; + + // 计算 EMA,与 pandas ewm(adjust=False) 行为一致 + for i in idx + 1..len { + if !gain[i].is_nan() && !loss[i].is_nan() { + avg_gain[i] = alpha * gain[i] + (1.0 - alpha) * avg_gain[i - 1]; + avg_loss[i] = alpha * loss[i] + (1.0 - alpha) * avg_loss[i - 1]; + } else { + avg_gain[i] = f64::NAN; + avg_loss[i] = f64::NAN; + } + } + } + + // 计算 RSI + let mut rsi = vec![f64::NAN; len]; + for i in 0..len { + if avg_gain[i].is_nan() || avg_loss[i].is_nan() { + rsi[i] = f64::NAN; + } else if avg_loss[i] == 0.0 { + rsi[i] = 100.0; + } else { + let rs = avg_gain[i] / avg_loss[i]; + rsi[i] = 100.0 - (100.0 / (1.0 + rs)); + } + } + + // 第一个值设为 NaN,与 Python 行为一致 + if len > 0 { + rsi[0] = f64::NAN; + } + + // 使用终极平滑器进行平滑 + ultimate_smoother(&rsi, smooth_period as f64) +} + +/// Jurik波动平滑器 - 低噪声波动指标 +/// 返回平滑波动率值 +pub fn jurik_volty(close: &[f64], period: usize, power: f64) -> Vec { + let len = close.len(); + if len == 0 || period == 0 { + return vec![]; + } + + // 计算价格变化 + let mut price_change = vec![f64::NAN; len]; + for i in 1..len { + price_change[i] = (close[i] - close[i - 1]).abs(); + } + + // 初步平滑 - 第一次 EMA + let span1 = period / 2; + let alpha1 = 2.0 / (span1 + 1) as f64; + let mut smooth1 = vec![f64::NAN; len]; + + // 找到第一个非 NaN 的值来初始化 + let mut first_valid_idx = None; + for i in 0..len { + if !price_change[i].is_nan() { + first_valid_idx = Some(i); + break; + } + } + + if let Some(idx) = first_valid_idx { + // 检查是否所有非 NaN 值都相同(pandas ewm 的特殊处理) + let first_value = price_change[idx]; + let mut all_same = true; + for i in idx..len { + if !price_change[i].is_nan() && price_change[i] != first_value { + all_same = false; + break; + } + } + + if all_same { + // 如果所有值都相同,直接返回原始值(pandas ewm 行为) + for i in idx..len { + if !price_change[i].is_nan() { + smooth1[i] = price_change[i]; + } + } + } else { + // 正常 EMA 计算 + smooth1[idx] = price_change[idx]; + for i in idx + 1..len { + if !price_change[i].is_nan() { + smooth1[i] = alpha1 * price_change[i] + (1.0 - alpha1) * smooth1[i - 1]; + } else { + smooth1[i] = f64::NAN; + } + } + } + } + + // 第二次 EMA + let mut smooth2 = vec![f64::NAN; len]; + + // 找到第一个非 NaN 的值来初始化 + let mut first_valid_idx2 = None; + for i in 0..len { + if !smooth1[i].is_nan() { + first_valid_idx2 = Some(i); + break; + } + } + + if let Some(idx) = first_valid_idx2 { + // 检查是否所有非 NaN 值都相同(pandas ewm 的特殊处理) + let first_value = smooth1[idx]; + let mut all_same = true; + for i in idx..len { + if !smooth1[i].is_nan() && smooth1[i] != first_value { + all_same = false; + break; + } + } + + if all_same { + // 如果所有值都相同,直接返回原始值(pandas ewm 行为) + for i in idx..len { + if !smooth1[i].is_nan() { + smooth2[i] = smooth1[i]; + } + } + } else { + // 正常 EMA 计算 + smooth2[idx] = smooth1[idx]; + for i in idx + 1..len { + if !smooth1[i].is_nan() { + smooth2[i] = alpha1 * smooth1[i] + (1.0 - alpha1) * smooth2[i - 1]; + } else { + smooth2[i] = f64::NAN; + } + } + } + } + + // Jurik特定平滑公式 + let mut jv = vec![0.0; len]; + for i in 2..len { + if !smooth2[i].is_nan() && !smooth2[i - 1].is_nan() { + jv[i] = (smooth2[i] + (smooth2[i] - smooth2[i - 1]) * 0.5) * power; + } else { + jv[i] = 0.0; + } + } + + // 最终平滑 - 第三次 EMA,与Pandas ewm(span=period/3, adjust=False).mean()行为完全一致 + let span3 = period / 3; + let alpha3 = 2.0 / (span3 + 1) as f64; + let mut result = vec![0.0; len]; + + // 找到第一个非零的 jv 值来初始化 + let mut first_valid_idx3 = None; + for i in 0..len { + if jv[i] != 0.0 { + first_valid_idx3 = Some(i); + break; + } + } + + if let Some(idx) = first_valid_idx3 { + // 检查是否所有非零值都相同(pandas ewm 的特殊处理) + let first_value = jv[idx]; + let mut all_same = true; + for i in idx..len { + if jv[i] != 0.0 && jv[i] != first_value { + all_same = false; + break; + } + } + + if all_same { + // 如果所有值都相同,直接返回原始值(pandas ewm 行为) + for i in idx..len { + if jv[i] != 0.0 { + result[i] = jv[i]; + } + } + } else { + // 正常 EMA 计算 - 与Pandas ewm(adjust=False)行为完全一致 + result[idx] = jv[idx]; + for i in idx + 1..len { + result[i] = alpha3 * jv[i] + (1.0 - alpha3) * result[i - 1]; + } + } + } + + result +} + +/// 终极通道 - 基于终极平滑器的通道指标 +/// 返回 (中线, 上轨, 下轨) +pub fn ultimate_channel( + high: &[f64], + low: &[f64], + close: &[f64], + period: usize, + multiplier: f64, +) -> (Vec, Vec, Vec) { + let len = high.len(); + if len == 0 || period == 0 { + return (vec![], vec![], vec![]); + } + + // 计算终极平滑中线 + let midline = ultimate_smoother(close, period as f64); + + // 计算平滑真实波幅 (STR) + let mut close_prev = vec![0.0; len]; + for i in 1..len { + close_prev[i] = close[i - 1]; + } + + let tr = true_range(high, low, &close_prev); + + // 计算 ATR (平均真实波幅) + let mut atr = vec![0.0; len]; + // 前 period-1 个值设为 NaN (与 Pandas rolling().mean() 行为一致) + for i in 0..period - 1 { + atr[i] = f64::NAN; + } + + if len >= period { + // 计算第一个 ATR 值 + let mut sum = 0.0; + for i in 0..period { + sum += tr[i]; + } + atr[period - 1] = sum / period as f64; + + // 计算后续 ATR 值 + for i in period..len { + atr[i] = (atr[i - 1] * (period - 1) as f64 + tr[i]) / period as f64; + } + } + + // 使用终极平滑器平滑 ATR + let str = ultimate_smoother(&atr, (period / 2) as f64); + + // 计算通道 + let mut upper = vec![0.0; len]; + let mut lower = vec![0.0; len]; + for i in 0..len { + upper[i] = midline[i] + multiplier * str[i]; + lower[i] = midline[i] - multiplier * str[i]; + } + + (midline, upper, lower) +} + +/// 终极带 - 基于终极平滑器的布林带变体 +/// 返回 (中线, 上轨, 下轨) +pub fn ultimate_bands( + close: &[f64], + period: usize, + std_multiplier: f64, + smooth_period: usize, +) -> (Vec, Vec, Vec) { + let len = close.len(); + if len == 0 || period == 0 || smooth_period == 0 { + return (vec![], vec![], vec![]); + } + + // 如果数据长度小于 period,返回空结果 + if len < period { + return (vec![], vec![], vec![]); + } + + // 计算终极平滑中线 + let midline = ultimate_smoother(close, period as f64); + + // 计算标准差并平滑 + let mut std = vec![0.0; len]; + + // 前 period-1 个值设为 NaN (与 Pandas rolling().std() 行为一致) + if period > 1 { + for i in 0..(period - 1) { + std[i] = f64::NAN; + } + } + + // 计算滚动标准差 + for i in (period - 1)..len { + let start = if i >= (period - 1) { + i.saturating_sub(period - 1) + } else { + 0 + }; + let end = i + 1; + + // 确保索引有效 + if start >= len || end > len || start >= end { + std[i] = f64::NAN; + continue; + } + + // 计算均值 + let mut sum = 0.0; + for j in start..end { + sum += close[j]; + } + let mean = sum / period as f64; + + // 计算方差 + let mut variance = 0.0; + for j in start..end { + let diff = close[j] - mean; + variance += diff * diff; + } + + // 使用 ddof=1 (与 Pandas 默认行为一致) + if period > 1 { + std[i] = (variance / (period - 1) as f64).sqrt(); + } else { + std[i] = 0.0; + } + } + + // 使用终极平滑器平滑标准差 + let smooth_std = ultimate_smoother(&std, smooth_period as f64); + + // 计算通道 + let mut upper = vec![0.0; len]; + let mut lower = vec![0.0; len]; + for i in 0..len { + upper[i] = midline[i] + std_multiplier * smooth_std[i]; + lower[i] = midline[i] - std_multiplier * smooth_std[i]; + } + + (midline, upper, lower) +} + +/// 终极波动指标 (UOS) - 多周期融合振荡器 +/// 返回 UOS 值 +pub fn ultimate_oscillator( + high: &[f64], + low: &[f64], + close: &[f64], + short_period: usize, + med_period: usize, + long_period: usize, +) -> Vec { + let len = high.len(); + if len == 0 || short_period == 0 || med_period == 0 || long_period == 0 { + return vec![]; + } + + // 计算买方压力 - 与Python实现完全一致 + let mut buying_pressure = vec![0.0; len]; + for i in 0..len { + let prev_close = if i > 0 { close[i - 1] } else { close[i] }; + // 使用 min(low[i], prev_close) 与Python的pd.concat([low, close.shift(1)], axis=1).min(axis=1)一致 + let min_val = low[i].min(prev_close); + buying_pressure[i] = close[i] - min_val; + } + + // 计算真实波幅 - 与Python实现完全一致 + let mut close_prev = vec![0.0; len]; + for i in 1..len { + close_prev[i] = close[i - 1]; + } + let true_range = true_range(high, low, &close_prev); + + // 计算不同周期的平均值 - 与Pandas rolling().sum()行为完全一致 + let mut avg7 = vec![f64::NAN; len]; + let mut avg14 = vec![f64::NAN; len]; + let mut avg28 = vec![f64::NAN; len]; + + // 计算 avg7 - 与Pandas rolling().sum()行为完全一致 + for i in (short_period - 1)..len { + let mut bp_sum = 0.0; + let mut tr_sum = 0.0; + // 检查是否有NaN值,如果有则结果也为NaN + let mut has_nan = false; + for j in (i + 1 - short_period)..=i { + if buying_pressure[j].is_nan() || true_range[j].is_nan() { + has_nan = true; + break; + } + bp_sum += buying_pressure[j]; + tr_sum += true_range[j]; + } + if has_nan { + avg7[i] = f64::NAN; + } else if tr_sum != 0.0 { + avg7[i] = bp_sum / tr_sum; + } else { + avg7[i] = 0.0; + } + } + + // 计算 avg14 - 与Pandas rolling().sum()行为完全一致 + for i in (med_period - 1)..len { + let mut bp_sum = 0.0; + let mut tr_sum = 0.0; + // 检查是否有NaN值,如果有则结果也为NaN + let mut has_nan = false; + for j in (i + 1 - med_period)..=i { + if buying_pressure[j].is_nan() || true_range[j].is_nan() { + has_nan = true; + break; + } + bp_sum += buying_pressure[j]; + tr_sum += true_range[j]; + } + if has_nan { + avg14[i] = f64::NAN; + } else if tr_sum != 0.0 { + avg14[i] = bp_sum / tr_sum; + } else { + avg14[i] = 0.0; + } + } + + // 计算 avg28 - 与Pandas rolling().sum()行为完全一致 + for i in (long_period - 1)..len { + let mut bp_sum = 0.0; + let mut tr_sum = 0.0; + // 检查是否有NaN值,如果有则结果也为NaN + let mut has_nan = false; + for j in (i + 1 - long_period)..=i { + if buying_pressure[j].is_nan() || true_range[j].is_nan() { + has_nan = true; + break; + } + bp_sum += buying_pressure[j]; + tr_sum += true_range[j]; + } + if has_nan { + avg28[i] = f64::NAN; + } else if tr_sum != 0.0 { + avg28[i] = bp_sum / tr_sum; + } else { + avg28[i] = 0.0; + } + } + + // 计算 UOS - 与Python实现完全一致 + let mut uos = vec![f64::NAN; len]; + for i in 0..len { + // 检查是否有 NaN 值 + if avg7[i].is_nan() || avg14[i].is_nan() || avg28[i].is_nan() { + uos[i] = f64::NAN; + } else { + uos[i] = 100.0 * ((4.0 * avg7[i]) + (2.0 * avg14[i]) + avg28[i]) / (4.0 + 2.0 + 1.0); + } + } + + uos +} + +/// 指数平滑 - 基础时间序列平滑技术 +/// 返回平滑后的序列 +pub fn exponential_smoothing(series: &[f64], alpha: f64) -> Vec { + let len = series.len(); + if len == 0 { + return vec![]; + } + + let mut result = vec![0.0; len]; + + // 第一个值保持不变 + result[0] = series[0]; + + // 应用指数平滑公式 + for i in 1..len { + result[i] = alpha * series[i] + (1.0 - alpha) * result[i - 1]; + } + + result +} + +/// Holt-Winters三参数平滑 - 支持趋势和季节性的平滑方法 +/// 返回平滑后的序列 +pub fn holt_winters( + series: &[f64], + season_length: usize, + alpha: f64, + beta: f64, + gamma: f64, +) -> Vec { + let n = series.len(); + if n == 0 || season_length == 0 || season_length > n { + return vec![]; + } + + let mut level = vec![0.0; n]; + let mut trend = vec![0.0; n]; + let mut season = vec![0.0; n]; + let mut forecast = vec![0.0; n]; + + // 初始化 + let initial_level = series[..season_length].iter().sum::() / season_length as f64; + for i in 0..season_length { + level[i] = initial_level; + trend[i] = 0.0; + season[i] = series[i] - level[i]; + } + + // 三重指数平滑 + for i in season_length..n { + level[i] = alpha * (series[i] - season[i - season_length]) + + (1.0 - alpha) * (level[i - 1] + trend[i - 1]); + trend[i] = beta * (level[i] - level[i - 1]) + (1.0 - beta) * trend[i - 1]; + season[i] = gamma * (series[i] - level[i]) + (1.0 - gamma) * season[i - season_length]; + forecast[i] = level[i] + trend[i] + season[i]; + } + + // 前 season_length 个值设为原始值 + for i in 0..season_length { + forecast[i] = series[i]; + } + + forecast +} diff --git a/crates/czsc-ta/src/python.rs b/crates/czsc-ta/src/python.rs new file mode 100644 index 000000000..1253a2318 --- /dev/null +++ b/crates/czsc-ta/src/python.rs @@ -0,0 +1,243 @@ +//! PyO3 binding registry for czsc-ta. +//! +//! Mirrors the wrapper layer that rs-czsc kept inside its python crate +//! (rs_czsc/python/src/utils/ta.rs); we move the `#[pyfunction]` shells +//! into czsc-ta itself so czsc-python only orchestrates `register()` +//! calls. All wrappers are dormant unless the `python` feature +//! (or `rust-numpy` for the numpy-bound entries) is on. + +use pyo3::prelude::*; + +use crate::{mixed, pure}; + +#[pyfunction] +fn ultimate_smoother(close: Vec, period: f64) -> Vec { + pure::ultimate_smoother(&close, period) +} + +#[pyfunction] +fn rolling_rank(series: Vec, window: usize) -> Vec { + // Convert Option -> f64 (None -> NaN) so `np.asarray(...)` lands + // in float64 dtype and `np.isfinite(out[window:])` works as expected. + // Python callers consuming the rank position can `.dropna()` instead of + // filtering Nones. + pure::rolling_rank(&series, window) + .into_iter() + .map(|opt| opt.map(|r| r as f64).unwrap_or(f64::NAN)) + .collect() +} + +#[pyfunction] +#[pyo3(signature = (series, n=None, *, period=None, length=None))] +fn sma(series: Vec, n: Option, period: Option, length: Option) -> Vec { + // Same kwarg story as `ema` — talib's keyword is `timeperiod` / + // pandas-ta's is `length`; rs-czsc historical scripts pass `n` / + // `period`. Phase A parity test calls `ta.sma(series, length=20)`. + let p = n.or(period).or(length).unwrap_or(0); + pure::sma(&series, p) +} + +#[pyfunction] +fn single_sma_positions(series: Vec, n: usize) -> Vec { + pure::single_sma_positions(&series, n) +} + +#[pyfunction] +fn single_ema_positions(series: Vec, n: usize) -> Vec { + pure::single_ema_positions(&series, n) +} + +#[pyfunction] +fn mid_positions(series: Vec, n: usize) -> Vec { + pure::mid_positions(&series, n) +} + +#[pyfunction] +fn double_sma_positions(series: Vec, n: usize, m: usize) -> Vec { + pure::double_sma_positions(&series, n, m) +} + +#[pyfunction] +fn triple_sma_positions(series: Vec, m1: usize, m2: usize, m3: usize) -> Vec { + pure::triple_sma_positions(&series, m1, m2, m3) +} + +#[pyfunction] +fn boll_positions(series: Vec, n: usize, k: f64) -> Vec { + pure::boll_positions(&series, n, k) +} + +#[pyfunction] +fn boll_reverse_positions(series: Vec, n: usize, k: f64) -> Vec { + pure::boll_reverse_positions(&series, n, k) +} + +#[pyfunction] +fn mms_positions(series: Vec, timeperiod: usize, window: usize) -> Vec { + pure::mms_positions(&series, timeperiod, window) +} + +#[pyfunction] +fn rsi_reverse_positions( + series: Vec, + n: usize, + rsi_upper: f64, + rsi_lower: f64, + rsi_exit: f64, +) -> Vec { + pure::rsi_reverse_positions(&series, n, rsi_upper, rsi_lower, rsi_exit) +} + +#[pyfunction] +fn tanh_positions(series: Vec, n: usize) -> Vec { + pure::tanh_positions(&series, n) +} + +#[pyfunction] +fn rank_positions(series: Vec, n: usize) -> Vec { + pure::rank_positions(&series, n) +} + +#[pyfunction] +#[pyo3(signature = (series, n=None, *, period=None, length=None))] +fn ema(series: Vec, n: Option, period: Option, length: Option) -> Vec { + // Accept any of: positional `n`, kwargs `period=` (legacy rs-czsc) or + // `length=` (talib / pandas-ta convention). The Phase A parity test + // in `test/unit/test_ta_parity.py::test_ema_matches_talib` calls + // `ta.ema(series, length=14)`; rs-czsc historical scripts pass + // `period=14`. Resolution order preserves the positional path first + // so existing positional callers keep working. + let p = n.or(period).or(length).unwrap_or(0); + pure::ema(&series, p) +} + +#[pyfunction] +fn true_range(high: Vec, low: Vec, close_prev: Vec) -> Vec { + pure::true_range(&high, &low, &close_prev) +} + +#[pyfunction] +fn rsx_ss2(close: Vec, period: usize, smooth_period: usize) -> Vec { + pure::rsx_ss2(&close, period, smooth_period) +} + +#[pyfunction] +fn jurik_volty(close: Vec, period: usize, power: f64) -> Vec { + pure::jurik_volty(&close, period, power) +} + +#[pyfunction] +fn ultimate_channel( + high: Vec, + low: Vec, + close: Vec, + period: usize, + multiplier: f64, +) -> (Vec, Vec, Vec) { + pure::ultimate_channel(&high, &low, &close, period, multiplier) +} + +#[pyfunction] +fn ultimate_bands( + close: Vec, + period: usize, + std_multiplier: f64, + smooth_period: usize, +) -> (Vec, Vec, Vec) { + pure::ultimate_bands(&close, period, std_multiplier, smooth_period) +} + +#[pyfunction] +fn ultimate_oscillator( + high: Vec, + low: Vec, + close: Vec, + short_period: usize, + med_period: usize, + long_period: usize, +) -> Vec { + pure::ultimate_oscillator(&high, &low, &close, short_period, med_period, long_period) +} + +#[pyfunction] +fn exponential_smoothing(series: Vec, alpha: f64) -> Vec { + pure::exponential_smoothing(&series, alpha) +} + +#[pyfunction] +fn holt_winters( + series: Vec, + season_length: usize, + alpha: f64, + beta: f64, + gamma: f64, +) -> Vec { + pure::holt_winters(&series, season_length, alpha, beta, gamma) +} + +/// Add the migrated czsc-ta functions onto the parent module that +/// czsc-python passes in. Build a `ta` submodule mirroring the design +/// doc §3.1 namespace map (czsc.ta.* + repeated top-level exposure). +pub fn register(py: Python<'_>, parent: &Bound<'_, PyModule>) -> PyResult<()> { + let ta = PyModule::new(py, "ta")?; + // Set the fully-qualified __name__ so `czsc.ta` (aliased via + // sys.modules) reports `__name__ == "czsc._native.ta"`. Required + // by the public-API parity test that checks namespace origin and + // by pickle when classes living in this submodule get round-tripped. + ta.setattr("__name__", "czsc._native.ta")?; + + macro_rules! add { + ($($name:ident),+ $(,)?) => {{ + $( + ta.add_function(wrap_pyfunction!($name, &ta)?)?; + parent.add_function(wrap_pyfunction!($name, parent)?)?; + )+ + }}; + } + + add!( + ultimate_smoother, + rolling_rank, + sma, + single_sma_positions, + single_ema_positions, + mid_positions, + double_sma_positions, + triple_sma_positions, + boll_positions, + boll_reverse_positions, + mms_positions, + rsi_reverse_positions, + tanh_positions, + rank_positions, + ema, + true_range, + rsx_ss2, + jurik_volty, + ultimate_channel, + ultimate_bands, + ultimate_oscillator, + exponential_smoothing, + holt_winters, + ); + + // numpy-bound entries + ta.add_function(wrap_pyfunction!(mixed::chip_dist::chip_distribution_triangle, &ta)?)?; + parent.add_function(wrap_pyfunction!(mixed::chip_dist::chip_distribution_triangle, parent)?)?; + + // Register the submodule into sys.modules so `from czsc._native.ta + // import ema` (and `import czsc._native.ta`) works the same as a + // pure-Python package. `parent.add_submodule` only sets it as an + // attribute of the parent — sys.modules is the bit Python's import + // machinery actually consults for nested module resolution. + let sys = py.import("sys")?; + let py_modules = sys.getattr("modules")?; + py_modules.set_item("czsc._native.ta", &ta)?; + // Use `parent.add` instead of `add_submodule` so we control the + // attribute key (`parent.ta`) independently of the module's + // qualified __name__ (`czsc._native.ta`). add_submodule uses the + // qualified name as the attribute, which would expose the + // submodule as `parent.czsc._native.ta` instead of `parent.ta`. + parent.add("ta", &ta)?; + Ok(()) +} diff --git a/crates/czsc-ta/tests/test_pure.rs b/crates/czsc-ta/tests/test_pure.rs new file mode 100644 index 000000000..efb9ab5dd --- /dev/null +++ b/crates/czsc-ta/tests/test_pure.rs @@ -0,0 +1,108 @@ +//! Phase E.2 — RED test: czsc-ta pure operators preserve length, behave +//! correctly on degenerate inputs, and produce sensible numeric output +//! aligned with the rs-czsc 47ef6efa baseline. + +use czsc_ta::pure::{ + boll_positions, double_sma_positions, ema, mid_positions, rolling_rank, + single_ema_positions, single_sma_positions, true_range, ultimate_smoother, +}; + +fn series(n: usize) -> Vec { + (0..n).map(|i| (i as f64) * 0.1 + 100.0).collect() +} + +#[test] +fn ultimate_smoother_preserves_length() { + let s = series(50); + let out = ultimate_smoother(&s, 10.0); + assert_eq!(out.len(), s.len()); +} + +#[test] +fn ultimate_smoother_first_4_passthrough() { + // First 4 values must equal input per the rs-czsc contract. + let s = series(20); + let out = ultimate_smoother(&s, 10.0); + for i in 0..4 { + assert!((out[i] - s[i]).abs() < f64::EPSILON, "i={i}: {} vs {}", out[i], s[i]); + } +} + +#[test] +fn ultimate_smoother_empty_returns_empty() { + assert!(ultimate_smoother(&[], 10.0).is_empty()); +} + +#[test] +fn rolling_rank_preserves_length() { + let s = series(50); + let out = rolling_rank(&s, 10); + assert_eq!(out.len(), s.len()); +} + +#[test] +fn ema_preserves_length() { + let s = series(50); + let out = ema(&s, 12); + assert_eq!(out.len(), s.len()); +} + +#[test] +fn ema_zero_period_returns_empty() { + // rs-czsc's ema short-circuits to an empty Vec when period == 0; + // lock that behaviour so callers can rely on a stable contract. + let s = vec![1.0, 2.0, 3.0]; + let out = ema(&s, 0); + assert!(out.is_empty(), "ema(_, 0) must short-circuit to empty"); +} + +#[test] +fn single_sma_positions_preserves_length() { + let s = series(30); + assert_eq!(single_sma_positions(&s, 5).len(), s.len()); +} + +#[test] +fn single_ema_positions_preserves_length() { + let s = series(30); + assert_eq!(single_ema_positions(&s, 5).len(), s.len()); +} + +#[test] +fn mid_positions_in_range() { + let s = series(30); + let out = mid_positions(&s, 5); + for v in &out { + assert!(*v >= -1.0 && *v <= 1.0, "expected position in [-1, 1], got {v}"); + } +} + +#[test] +fn double_sma_positions_returns_signal_length() { + let s = series(30); + let out = double_sma_positions(&s, 5, 10); + assert_eq!(out.len(), s.len()); +} + +#[test] +fn boll_positions_in_signed_range() { + let s = series(50); + let out = boll_positions(&s, 20, 2.0); + assert_eq!(out.len(), s.len()); + for v in &out { + assert!(*v >= -1 && *v <= 1, "boll position must be -1/0/1, got {v}"); + } +} + +#[test] +fn true_range_matches_input_length() { + let high = vec![10.0, 11.0, 12.0, 11.5]; + let low = vec![ 9.0, 9.5, 10.5, 10.0]; + let prev = vec![ 9.5, 10.0, 11.0, 10.5]; + let tr = true_range(&high, &low, &prev); + assert_eq!(tr.len(), 4); + // tr[i] = max(high-low, |high-prev|, |low-prev|) >= 0 + for v in &tr { + assert!(*v >= 0.0, "true_range must be non-negative, got {v}"); + } +} diff --git a/crates/czsc-trader/Cargo.toml b/crates/czsc-trader/Cargo.toml new file mode 100644 index 000000000..31f9e9597 --- /dev/null +++ b/crates/czsc-trader/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "czsc-trader" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "CZSC trader (CzscTrader / CzscSignals / signal compilation / weight backtest). Migrated from rs-czsc." + +[lib] +name = "czsc_trader" +path = "src/lib.rs" + +[dependencies] +czsc-core = { path = "../czsc-core", features = ["python"] } +czsc-utils = { path = "../czsc-utils" } +czsc-signals = { path = "../czsc-signals" } +error-macros = { path = "../error-macros" } +error-support = { path = "../error-support" } + +anyhow = "1" +chrono = { workspace = true, features = ["serde"] } +hashbrown = { workspace = true } +hex = "0.4" +log = "0.4" +md5 = "0.8" +polars = { workspace = true, features = ["lazy", "ipc", "parquet"] } +polars-plan = { version = "0.42" } +rayon = { workspace = true } +serde = { workspace = true } +serde_json = "1" +sha2 = "0.10" +strum = "0.26" +strum_macros = "0.26" +thiserror = "2" +tracing = "0.1" +arrayvec = "0.7" +csv = "1" diff --git a/crates/czsc-trader/src/engine_v2/catalog/mod.rs b/crates/czsc-trader/src/engine_v2/catalog/mod.rs new file mode 100644 index 000000000..3fdc8b652 --- /dev/null +++ b/crates/czsc-trader/src/engine_v2/catalog/mod.rs @@ -0,0 +1,70 @@ +use crate::signals::sig_parse::SignalConfig; +use czsc_signals::registry::{ + SIGNAL_REGISTRY, TRADER_SIGNAL_REGISTRY, list_generated_signal_descriptors, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SignalCategory { + Kline, + Trader, +} + +#[derive(Debug, Clone)] +pub struct CatalogSignal { + pub name: String, + pub category: SignalCategory, +} + +pub fn resolve_signal_category(sc: &SignalConfig) -> Result { + for d in list_generated_signal_descriptors() { + if d.name.eq_ignore_ascii_case(sc.name.as_str()) { + let category = match d.category { + "kline" => SignalCategory::Kline, + "trader" => SignalCategory::Trader, + _ => continue, + }; + return Ok(CatalogSignal { + name: d.name.to_string(), + category, + }); + } + } + + if sc.freq.is_some() { + if SIGNAL_REGISTRY.contains_key(sc.name.as_str()) { + return Ok(CatalogSignal { + name: sc.name.clone(), + category: SignalCategory::Kline, + }); + } + return Err(format!("未注册 K线信号: {}", sc.name)); + } + + if TRADER_SIGNAL_REGISTRY.contains_key(sc.name.as_str()) { + Ok(CatalogSignal { + name: sc.name.clone(), + category: SignalCategory::Trader, + }) + } else { + Err(format!("未注册 Trader 信号: {}", sc.name)) + } +} + +#[cfg(test)] +mod tests { + use super::{SignalCategory, resolve_signal_category}; + use crate::signals::sig_parse::SignalConfig; + use std::collections::HashMap; + + #[test] + fn resolve_uses_generated_descriptor_first() { + let sc = SignalConfig { + name: "tas_macd_base_V221028".to_string(), + freq: Some("60分钟".to_string()), + params: HashMap::new(), + }; + let cs = resolve_signal_category(&sc).expect("must resolve"); + assert_eq!(cs.category, SignalCategory::Kline); + assert_eq!(cs.name, "tas_macd_base_V221028"); + } +} diff --git a/crates/czsc-trader/src/engine_v2/compiler/event.rs b/crates/czsc-trader/src/engine_v2/compiler/event.rs new file mode 100644 index 000000000..0a60b8f86 --- /dev/null +++ b/crates/czsc-trader/src/engine_v2/compiler/event.rs @@ -0,0 +1,104 @@ +use czsc_core::objects::position::Position; +use std::collections::BTreeMap; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EventCondition { + pub key_id: u32, + pub value: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EventClause { + pub all: Vec, + pub any: Vec, + pub not: Vec, +} + +#[derive(Debug, Clone, Default)] +pub struct CompiledEventPlan { + pub key_ids: BTreeMap, + pub by_position: BTreeMap>, +} + +fn parse_condition(sig: &str, key_ids: &mut BTreeMap) -> Option { + let mut parts = sig.split('_'); + let k1 = parts.next()?; + let k2 = parts.next()?; + let k3 = parts.next()?; + let v1 = parts.next().unwrap_or("其他"); + let key = format!("{k1}_{k2}_{k3}"); + let key_id = if let Some(x) = key_ids.get(key.as_str()) { + *x + } else { + let x = key_ids.len() as u32; + key_ids.insert(key.clone(), x); + x + }; + Some(EventCondition { + key_id, + value: v1.to_string(), + }) +} + +pub fn compile_events(positions: &[Position]) -> CompiledEventPlan { + let mut key_ids = BTreeMap::new(); + let mut by_position = BTreeMap::new(); + + for p in positions { + let mut clauses = Vec::new(); + for e in p.opens.iter().chain(p.exits.iter()) { + let all = e + .signals_all + .iter() + .filter_map(|x| parse_condition(x.to_string().as_str(), &mut key_ids)) + .collect(); + let any = e + .signals_any + .iter() + .filter_map(|x| parse_condition(x.to_string().as_str(), &mut key_ids)) + .collect(); + let not = e + .signals_not + .iter() + .filter_map(|x| parse_condition(x.to_string().as_str(), &mut key_ids)) + .collect(); + clauses.push(EventClause { all, any, not }); + } + by_position.insert(p.name.clone(), clauses); + } + + CompiledEventPlan { + key_ids, + by_position, + } +} + +#[cfg(test)] +mod tests { + use super::compile_events; + use czsc_core::objects::position::Position; + + #[test] + fn compile_events_extracts_key_ids() { + let p: Position = serde_json::from_value(serde_json::json!({ + "name": "P", + "symbol": "000001.SZ", + "opens": [{ + "name": "开多", + "operate": "开多", + "signals_all": ["30分钟_D1_表里关系V230101_向上_任意_任意_0"], + "signals_any": [], + "signals_not": [] + }], + "exits": [], + "interval": 0, + "timeout": 10, + "stop_loss": 100.0, + "T0": false + })) + .expect("position from value"); + let plan = compile_events(&[p]); + assert!(!plan.key_ids.is_empty()); + assert!(plan.by_position.contains_key("P")); + } +} diff --git a/crates/czsc-trader/src/engine_v2/compiler/mod.rs b/crates/czsc-trader/src/engine_v2/compiler/mod.rs new file mode 100644 index 000000000..3c602d738 --- /dev/null +++ b/crates/czsc-trader/src/engine_v2/compiler/mod.rs @@ -0,0 +1,130 @@ +pub(crate) mod event; +pub(crate) mod optimize; +pub(crate) mod position; +mod signal; + +use crate::engine_v2::catalog::{CatalogSignal, resolve_signal_category}; +use crate::engine_v2::compiler::event::compile_events; +use crate::engine_v2::compiler::position::compile_positions; +use crate::engine_v2::compiler::signal::{CompiledSignalPlan, compile_signals}; +use crate::signals::sig_parse::SignalConfig; +use czsc_core::objects::position::Position; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ExecutionPlanInput { + pub symbol: String, + pub base_freq: String, + #[serde(default)] + pub signals_config: Vec, + #[serde(default)] + pub positions: Vec, + pub market: Option, + pub bg_max_count: Option, + /// 信号/策略正式起跑的时间分界点。 + /// + /// 是否将 `sdt` 当根 K 线计入右侧正式运行,由 `include_sdt_bar` 决定。 + pub sdt: Option, + #[serde(default)] + /// 控制 `sdt` 当根 K 线归属到左侧预热还是右侧正式运行。 + /// + /// - `false`:左侧 `dt <= sdt`,右侧 `dt > sdt` + /// - `true`:左侧 `dt < sdt`,右侧 `dt >= sdt` + /// + /// 默认值为 `false`,对齐 Python `CzscStrategyBase` 的回测 / replay 语义。 + /// 信号导出场景(例如 `generate_czsc_signals` / `signal_matrix`)如果要对齐 + /// Python 基线,需要显式传入 `true`。 + pub include_sdt_bar: Option, +} + +#[derive(Debug, Clone)] +pub struct ExecutionPlan { + pub symbol: String, + pub base_freq: String, + pub signals_config: Vec, + pub positions: Vec, + pub market: Option, + pub bg_max_count: usize, + /// 编译后的 `sdt` 时间分界点,供执行器拆分预热区和正式运行区。 + pub sdt: Option, + /// 编译后的 `sdt` 边界模式。 + /// + /// - `false`:`sdt` 当根 bar 只参与预热,不进入右侧正式输出 + /// - `true`:`sdt` 当根 bar 直接作为右侧第一根有效 bar + pub include_sdt_bar: bool, + pub catalog_signals: Vec, + pub signal_plan: CompiledSignalPlan, + pub event_plan: event::CompiledEventPlan, + pub position_plan: position::CompiledPositionPlan, +} + +pub(crate) use signal::CompiledSignalPlan as CompiledSignalPlanV2; + +impl ExecutionPlan { + pub fn compile(input: ExecutionPlanInput) -> Result { + if input.symbol.trim().is_empty() { + return Err("strategy.symbol 不能为空".to_string()); + } + if input.positions.is_empty() { + return Err("strategy.positions 不能为空".to_string()); + } + + let ExecutionPlanInput { + symbol, + base_freq, + signals_config, + mut positions, + market, + bg_max_count, + sdt, + include_sdt_bar, + } = input; + + for pos in &mut positions { + pos.normalize_runtime_fields(); + } + + let mut catalog_signals = Vec::with_capacity(signals_config.len()); + for sc in &signals_config { + catalog_signals.push(resolve_signal_category(sc)?); + } + let signal_plan = compile_signals(&signals_config, &catalog_signals)?; + let event_plan = compile_events(&positions); + let position_plan = compile_positions(&positions); + + Ok(Self { + symbol, + base_freq, + signals_config, + positions, + market, + bg_max_count: bg_max_count.unwrap_or(5000), + sdt, + include_sdt_bar: include_sdt_bar.unwrap_or(false), + catalog_signals, + signal_plan, + event_plan, + position_plan, + }) + } +} + +#[cfg(test)] +mod tests { + use super::{ExecutionPlan, ExecutionPlanInput}; + + #[test] + fn compile_rejects_empty_symbol() { + let input = ExecutionPlanInput { + symbol: String::new(), + base_freq: "30分钟".to_string(), + signals_config: vec![], + positions: vec![], + market: None, + bg_max_count: None, + sdt: None, + include_sdt_bar: None, + }; + assert!(ExecutionPlan::compile(input).is_err()); + } +} diff --git a/crates/czsc-trader/src/engine_v2/compiler/optimize.rs b/crates/czsc-trader/src/engine_v2/compiler/optimize.rs new file mode 100644 index 000000000..db7ec3e31 --- /dev/null +++ b/crates/czsc-trader/src/engine_v2/compiler/optimize.rs @@ -0,0 +1,35 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct CandidateChunk { + pub start: usize, + pub end: usize, +} + +pub fn build_candidate_chunks(total: usize, chunk_size: usize) -> Vec { + if total == 0 { + return Vec::new(); + } + let step = chunk_size.max(1); + let mut out = Vec::new(); + let mut i = 0usize; + while i < total { + let end = (i + step).min(total); + out.push(CandidateChunk { start: i, end }); + i = end; + } + out +} + +#[cfg(test)] +mod tests { + use super::build_candidate_chunks; + + #[test] + fn chunk_split_is_deterministic() { + let c = build_candidate_chunks(10, 3); + assert_eq!(c.len(), 4); + assert_eq!(c[0].start, 0); + assert_eq!(c[0].end, 3); + assert_eq!(c[3].start, 9); + assert_eq!(c[3].end, 10); + } +} diff --git a/crates/czsc-trader/src/engine_v2/compiler/position.rs b/crates/czsc-trader/src/engine_v2/compiler/position.rs new file mode 100644 index 000000000..d2044d42e --- /dev/null +++ b/crates/czsc-trader/src/engine_v2/compiler/position.rs @@ -0,0 +1,29 @@ +use czsc_core::objects::position::Position; + +#[derive(Debug, Clone)] +pub struct CompiledPosition { + pub name: String, + pub interval: i64, + pub timeout: i32, + pub stop_loss: f64, + pub event_count: usize, +} + +#[derive(Debug, Clone, Default)] +pub struct CompiledPositionPlan { + pub positions: Vec, +} + +pub fn compile_positions(positions: &[Position]) -> CompiledPositionPlan { + let positions = positions + .iter() + .map(|p| CompiledPosition { + name: p.name.clone(), + interval: p.interval, + timeout: p.timeout, + stop_loss: p.stop_loss, + event_count: p.opens.len() + p.exits.len(), + }) + .collect(); + CompiledPositionPlan { positions } +} diff --git a/crates/czsc-trader/src/engine_v2/compiler/signal.rs b/crates/czsc-trader/src/engine_v2/compiler/signal.rs new file mode 100644 index 000000000..1620f98dd --- /dev/null +++ b/crates/czsc-trader/src/engine_v2/compiler/signal.rs @@ -0,0 +1,54 @@ +use crate::engine_v2::catalog::{CatalogSignal, SignalCategory}; +use crate::signals::sig_parse::SignalConfig; +use serde_json::Value; +use std::collections::BTreeMap; + +#[derive(Debug, Clone)] +pub struct CompiledSignal { + pub signal_id: u32, + pub name: String, + pub freq: Option, + pub category: SignalCategory, + pub params: Value, +} + +#[derive(Debug, Clone, Default)] +pub struct CompiledSignalPlan { + pub ops: Vec, + pub id_by_name: BTreeMap, +} + +pub fn compile_signals( + configs: &[SignalConfig], + catalog: &[CatalogSignal], +) -> Result { + if configs.len() != catalog.len() { + return Err("signals_config 与 catalog_signals 数量不一致".to_string()); + } + + let mut id_by_name: BTreeMap = BTreeMap::new(); + for c in catalog { + if !id_by_name.contains_key(c.name.as_str()) { + let id = id_by_name.len() as u32; + id_by_name.insert(c.name.clone(), id); + } + } + + let mut ops = Vec::with_capacity(configs.len()); + for (sc, cat) in configs.iter().zip(catalog.iter()) { + let signal_id = *id_by_name + .get(cat.name.as_str()) + .ok_or_else(|| format!("signal id 分配失败: {}", cat.name))?; + let params = serde_json::to_value(&sc.params) + .map_err(|e| format!("信号参数序列化失败 {}: {e}", sc.name))?; + ops.push(CompiledSignal { + signal_id, + name: cat.name.clone(), + freq: sc.freq.clone(), + category: cat.category, + params, + }); + } + + Ok(CompiledSignalPlan { ops, id_by_name }) +} diff --git a/crates/czsc-trader/src/engine_v2/mod.rs b/crates/czsc-trader/src/engine_v2/mod.rs new file mode 100644 index 000000000..69b77568c --- /dev/null +++ b/crates/czsc-trader/src/engine_v2/mod.rs @@ -0,0 +1,7 @@ +pub mod catalog; +pub mod compiler; +pub mod runtime; +pub mod scheduler; + +pub use compiler::{ExecutionPlan, ExecutionPlanInput}; +pub use runtime::{CoreLoopProfileV2, RunOutput, UnifiedExecEngine}; diff --git a/crates/czsc-trader/src/engine_v2/runtime/executor.rs b/crates/czsc-trader/src/engine_v2/runtime/executor.rs new file mode 100644 index 000000000..e971565b2 --- /dev/null +++ b/crates/czsc-trader/src/engine_v2/runtime/executor.rs @@ -0,0 +1,446 @@ +use crate::engine_v2::catalog::SignalCategory; +use crate::engine_v2::compiler::ExecutionPlan; +use crate::signals::czsc_signals::CzscSignals; +use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::bar::RawBar; +use czsc_core::objects::freq::Freq; +use czsc_core::objects::market::Market; +use czsc_core::objects::position::{LiteBar, Position}; +use czsc_core::objects::state::TraderState; +use czsc_signals::registry::TRADER_SIGNAL_REGISTRY; +use czsc_signals::types::TraderSignalFn; +use czsc_utils::bar_generator::BarGenerator; +use czsc_utils::freq_data::infer_market_from_bars; +use serde_json::Value; +use std::collections::{HashMap, HashSet}; +use std::time::Instant; + +#[derive(Debug, Clone, Copy, Default)] +pub struct CoreLoopProfileV2 { + pub bars: usize, + pub signals_update_ns: u128, + pub trader_signals_ns: u128, + pub position_update_ns: u128, + pub pos_event_match_ns: u128, + pub pos_fsm_ns: u128, + pub pos_risk_ns: u128, + pub pos_holds_ns: u128, +} + +impl CoreLoopProfileV2 { + pub fn total_ns(&self) -> u128 { + self.signals_update_ns + self.trader_signals_ns + self.position_update_ns + } +} + +pub struct RunOutput { + pub bars_count: usize, + pub signal_rows: Vec>, + pub positions: Vec, + pub elapsed_ms: i64, + pub profile: Option, +} + +pub struct UnifiedExecEngine; + +#[derive(Clone)] +struct CompiledTraderSignalOp { + func: TraderSignalFn, + params: HashMap, +} + +struct RuntimeTraderState<'a> { + positions: &'a [Position], + kas: &'a std::collections::BTreeMap, + latest_price: Option, +} + +impl TraderState for RuntimeTraderState<'_> { + #[inline] + fn get_position(&self, name: &str) -> Option<&Position> { + self.positions.iter().find(|p| p.name == name) + } + + #[inline] + fn get_czsc(&self, freq: &str) -> Option<&CZSC> { + self.kas.get(freq) + } + + #[inline] + fn latest_price(&self) -> Option { + self.latest_price + } +} + +fn compile_trader_ops(plan: &ExecutionPlan) -> Result, String> { + let mut ops = Vec::new(); + for op in &plan.signal_plan.ops { + if matches!(op.category, SignalCategory::Trader) + && let Some(meta) = TRADER_SIGNAL_REGISTRY.get(op.name.as_str()) + { + ops.push(CompiledTraderSignalOp { + func: meta.func, + params: serde_json::from_value(op.params.clone()) + .map_err(|e| format!("trader 信号参数解析失败 {}: {e}", op.name))?, + }); + } + } + Ok(ops) +} + +impl UnifiedExecEngine { + pub fn run( + plan: &ExecutionPlan, + mut bars: Vec, + sdt_override: Option<&str>, + emit_signals: bool, + enable_profile: bool, + ) -> Result { + let t0 = Instant::now(); + if bars.is_empty() { + return Err("bars 为空,无法执行回测".to_string()); + } + + let base_freq = plan + .base_freq + .parse::() + .map_err(|_| "strategy.base_freq 解析失败".to_string())?; + + let requested_market = parse_market(plan.market.as_deref()); + let market = infer_effective_market(&bars, base_freq, requested_market); + let freqs = collect_freqs(base_freq, &plan.signals_config)?; + let trader_ops = compile_trader_ops(plan)?; + // 对齐 Python 基线 `generate_czsc_signals(init_n=500)` 的左右分段逻辑: + // 1) bars_left = bars[dt < sdt] + // 2) 若 len(bars_left) <= init_n,则 bars_left=bars[:init_n], bars_right=bars[init_n:] + // 3) 否则 bars_right = bars[dt >= sdt] + // 当 bars_right 为空时,不执行回测主循环。 + const INIT_N: usize = 500; + // 对齐 Python `CzscStrategyBase.init_bar_generator`: + // - 默认 sdt = "20200101" + // - bars_init 使用 `dt <= sdt` + // - 若 len(bars_init) > n(500): bars1=bars_init, bars2=dt > sdt + // - 否则 bars1=bars[:n], bars2=bars[n:] + let sdt_final = sdt_override + .map(|x| x.to_string()) + .or_else(|| plan.sdt.clone()) + .or_else(|| Some("20200101".to_string())); + let cutoff = sdt_final.as_deref().and_then(parse_sdt_utc); + let bars_len = bars.len(); + let start_idx = if let Some(c) = cutoff { + let bars_init_count = if plan.include_sdt_bar { + bars.iter().take_while(|b| b.dt < c).count() + } else { + bars.iter().take_while(|b| b.dt <= c).count() + }; + if !trader_ops.is_empty() { + // Trader 对照链路(benchmarks/generate_py_trader_signals_df)使用显式 warmup_n, + // 调用侧会将 sdt 设为 bars[warmup_n - 1].dt;这里按 sdt 精确预热, + // 避免被固定 INIT_N=500 覆盖导致状态路径错位。 + bars_init_count.clamp(1, bars_len.saturating_sub(1)) + } else if bars_init_count > INIT_N { + bars_init_count + } else { + bars_len.min(INIT_N) + } + } else { + bars_len.min(INIT_N) + }; + + let bg = BarGenerator::new(base_freq, freqs, plan.bg_max_count, market) + .map_err(|e| format!("初始化 BarGenerator 失败: {e:?}"))?; + + let mut signals = CzscSignals::new(plan.symbol.clone(), bg); + signals + .load_compiled_signal_plan(&plan.signal_plan) + .map_err(|e| format!("装载编译信号计划失败: {e}"))?; + let mut positions = plan.positions.clone(); + + // 先用左侧 bars 初始化 BG / CZSC。 + for bar in bars.iter().take(start_idx) { + signals.warmup_bar(bar); + } + + // 对齐 Python `CzscSignals(bg)`:warmup 完成后会立刻计算一次当前信号, + // 用于初始化 bar.cache / 指标缓存,但不会把这一时刻计入 bars_right 输出, + // 也不会推进 Position。 + if start_idx > 0 { + let prime_bar = &bars[start_idx - 1]; + signals.prime_signals(prime_bar, &plan.signals_config); + + if !trader_ops.is_empty() { + let latest_price = signals + .s + .get("close") + .and_then(|x| x.parse::().ok()) + .or(Some(prime_bar.close)); + let state = RuntimeTraderState { + positions: &positions, + kas: &signals.kas, + latest_price, + }; + for op in &trader_ops { + for sig in (op.func)(&state, &op.params) { + let (k, v) = (sig.key(), sig.value()); + signals.s.insert(k.clone(), v.clone()); + signals.signal_map.insert(k, v); + signals.sigs.insert(sig); + } + } + } + } + + let bars_count = bars_len.saturating_sub(start_idx); + let mut rows = if emit_signals { + Vec::with_capacity(bars_count) + } else { + Vec::new() + }; + let mut profile = CoreLoopProfileV2::default(); + + for bar in bars.drain(start_idx..) { + let t_signals = Instant::now(); + signals.update_signals(&bar, &plan.signals_config); + let signals_update_ns = t_signals.elapsed().as_nanos(); + + let t_trader_sig = Instant::now(); + if !trader_ops.is_empty() { + let latest_price = signals + .s + .get("close") + .and_then(|x| x.parse::().ok()) + .or(Some(bar.close)); + let state = RuntimeTraderState { + positions: &positions, + kas: &signals.kas, + latest_price, + }; + for op in &trader_ops { + for sig in (op.func)(&state, &op.params) { + let (k, v) = (sig.key(), sig.value()); + signals.s.insert(k.clone(), v.clone()); + signals.signal_map.insert(k, v); + signals.sigs.insert(sig); + } + } + } + let trader_signals_ns = t_trader_sig.elapsed().as_nanos(); + + let lite_bar = LiteBar { + id: bar.id, + dt: bar.dt.into(), + price: bar.close, + }; + let t_pos = Instant::now(); + let mut pos_event_match_ns = 0u128; + let mut pos_fsm_ns = 0u128; + let mut pos_risk_ns = 0u128; + let mut pos_holds_ns = 0u128; + for pos in &mut positions { + let p = + pos.update_profiled_with_signal_map(lite_bar, None, Some(&signals.signal_map)); + pos_event_match_ns += p.event_match_ns; + pos_fsm_ns += p.fsm_ns; + pos_risk_ns += p.risk_ns; + pos_holds_ns += p.holds_ns; + } + let position_update_ns = t_pos.elapsed().as_nanos(); + + if enable_profile { + profile.bars += 1; + profile.signals_update_ns += signals_update_ns; + profile.trader_signals_ns += trader_signals_ns; + profile.position_update_ns += position_update_ns; + profile.pos_event_match_ns += pos_event_match_ns; + profile.pos_fsm_ns += pos_fsm_ns; + profile.pos_risk_ns += pos_risk_ns; + profile.pos_holds_ns += pos_holds_ns; + } + if emit_signals { + rows.push(signals.s.clone()); + } + } + + Ok(RunOutput { + bars_count, + signal_rows: rows, + positions, + elapsed_ms: t0.elapsed().as_millis() as i64, + profile: enable_profile.then_some(profile), + }) + } +} + +fn collect_freqs( + base_freq: Freq, + signals_config: &[crate::signals::sig_parse::SignalConfig], +) -> Result, String> { + let push_freq = |freq_str: &str, freq_set: &mut HashSet| { + if let Ok(f) = freq_str.parse::() + && f != base_freq + { + freq_set.insert(f); + } + }; + + let mut freq_set: HashSet = HashSet::new(); + for sc in signals_config { + if let Some(freq_str) = &sc.freq { + push_freq(freq_str, &mut freq_set); + } + // 兼容 trader 级信号通过 params 传入多周期字段(freq/freq1/freq2/...) + for (k, v) in &sc.params { + if !k.starts_with("freq") { + continue; + } + if let Some(freq_str) = v.as_str() { + push_freq(freq_str, &mut freq_set); + } + } + } + + let mut freqs: Vec = freq_set.into_iter().collect(); + freqs.sort(); + Ok(freqs) +} + +fn parse_market(market: Option<&str>) -> Market { + match market.unwrap_or("默认") { + "A股" | "AShare" | "ashare" => Market::AShare, + "期货" | "Futures" | "futures" => Market::Futures, + _ => Market::Default, + } +} + +fn infer_effective_market(bars: &[RawBar], base_freq: Freq, requested: Market) -> Market { + let detected = infer_market_from_bars(bars, base_freq); + if matches!(requested, Market::Default) || requested != detected { + detected + } else { + requested + } +} + +fn parse_sdt_utc(s: &str) -> Option> { + if s.is_empty() { + return None; + } + if let Ok(dt) = DateTime::parse_from_rfc3339(s) { + return Some(dt.with_timezone(&Utc)); + } + if let Ok(ndt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") { + return Some(DateTime::from_naive_utc_and_offset(ndt, Utc)); + } + if let Ok(ndt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S") { + return Some(DateTime::from_naive_utc_and_offset(ndt, Utc)); + } + if let Ok(ndt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f") { + return Some(DateTime::from_naive_utc_and_offset(ndt, Utc)); + } + if let Ok(d) = NaiveDate::parse_from_str(s, "%Y-%m-%d") + && let Some(ndt) = d.and_hms_opt(0, 0, 0) + { + return Some(DateTime::from_naive_utc_and_offset(ndt, Utc)); + } + if let Ok(d) = NaiveDate::parse_from_str(s, "%Y%m%d") + && let Some(ndt) = d.and_hms_opt(0, 0, 0) + { + return Some(DateTime::from_naive_utc_and_offset(ndt, Utc)); + } + None +} + +#[cfg(test)] +mod tests { + use super::{collect_freqs, infer_effective_market}; + use crate::signals::sig_parse::SignalConfig; + use chrono::{NaiveDateTime, TimeZone, Utc}; + use czsc_core::objects::{bar::RawBarBuilder, freq::Freq, market::Market}; + use serde_json::json; + use std::collections::HashMap; + + #[test] + fn test_collect_freqs_includes_trader_freq_params() { + let mut trader_params = HashMap::new(); + trader_params.insert("freq1".to_string(), json!("日线")); + trader_params.insert("freq2".to_string(), json!("60分钟")); + let mut kline_params = HashMap::new(); + kline_params.insert("di".to_string(), json!(1)); + + let cfgs = vec![ + SignalConfig { + name: "cat_macd_V230518".to_string(), + freq: None, + params: trader_params, + }, + SignalConfig { + name: "tas_ma_base_V221101".to_string(), + freq: Some("15分钟".to_string()), + params: kline_params, + }, + ]; + let freqs = collect_freqs(Freq::F60, &cfgs).expect("collect freqs should succeed"); + assert!(freqs.contains(&Freq::D)); + assert!(freqs.contains(&Freq::F15)); + assert!(!freqs.contains(&Freq::F60)); + } + + #[test] + fn test_collect_freqs_keeps_freq_enum_order() { + let mut trader_params = HashMap::new(); + trader_params.insert("freq1".to_string(), json!("5分钟")); + trader_params.insert("freq2".to_string(), json!("日线")); + let cfgs = vec![ + SignalConfig { + name: "cat_macd_V230518".to_string(), + freq: None, + params: trader_params, + }, + SignalConfig { + name: "tas_ma_base_V221101".to_string(), + freq: Some("15分钟".to_string()), + params: HashMap::new(), + }, + ]; + let freqs = collect_freqs(Freq::F60, &cfgs).expect("collect freqs should succeed"); + assert_eq!(freqs, vec![Freq::F5, Freq::F15, Freq::D]); + } + + #[test] + fn test_infer_effective_market_falls_back_to_detected_default_for_utc_intraday_bars() { + let mk = |dt: &str| { + RawBarBuilder::default() + .symbol("000001.SZ".to_string()) + .id(0) + .dt(Utc.from_utc_datetime( + &NaiveDateTime::parse_from_str(dt, "%Y-%m-%d %H:%M:%S").unwrap(), + )) + .freq(Freq::F30) + .open(1.0) + .close(1.0) + .high(1.0) + .low(1.0) + .vol(1.0) + .amount(1.0) + .build() + .unwrap() + }; + let bars = vec![ + mk("2015-03-30 17:30:00"), + mk("2015-03-30 18:00:00"), + mk("2015-03-30 18:30:00"), + mk("2015-03-30 19:00:00"), + mk("2015-03-30 19:30:00"), + mk("2015-03-30 21:30:00"), + mk("2015-03-30 22:00:00"), + mk("2015-03-30 22:30:00"), + mk("2015-03-30 23:00:00"), + ]; + + assert_eq!( + infer_effective_market(&bars, Freq::F30, Market::AShare), + Market::Default + ); + } +} diff --git a/crates/czsc-trader/src/engine_v2/runtime/mod.rs b/crates/czsc-trader/src/engine_v2/runtime/mod.rs new file mode 100644 index 000000000..95542f791 --- /dev/null +++ b/crates/czsc-trader/src/engine_v2/runtime/mod.rs @@ -0,0 +1,3 @@ +mod executor; + +pub use executor::{CoreLoopProfileV2, RunOutput, UnifiedExecEngine}; diff --git a/crates/czsc-trader/src/engine_v2/scheduler.rs b/crates/czsc-trader/src/engine_v2/scheduler.rs new file mode 100644 index 000000000..e92b17e37 --- /dev/null +++ b/crates/czsc-trader/src/engine_v2/scheduler.rs @@ -0,0 +1,40 @@ +use crate::engine_v2::compiler::ExecutionPlan; +use crate::engine_v2::compiler::optimize::build_candidate_chunks; +use crate::engine_v2::runtime::{RunOutput, UnifiedExecEngine}; +use czsc_core::objects::bar::RawBar; +use czsc_core::objects::position::Position; +use rayon::prelude::*; + +pub struct SymbolTask { + pub symbol: String, + pub bars: Vec, + pub plan: ExecutionPlan, +} + +pub struct SymbolResult { + pub symbol: String, + pub output: Result, +} + +pub fn run_symbol_parallel(tasks: Vec, emit_signals: bool) -> Vec { + let mut out: Vec = tasks + .into_par_iter() + .map(|task| SymbolResult { + symbol: task.symbol, + output: UnifiedExecEngine::run(&task.plan, task.bars, None, emit_signals, false), + }) + .collect(); + out.sort_by(|a, b| a.symbol.cmp(&b.symbol)); + out +} + +pub fn split_positions_into_chunks( + positions: &[Position], + chunk_size: usize, +) -> Vec> { + let chunks = build_candidate_chunks(positions.len(), chunk_size); + chunks + .into_iter() + .map(|c| positions[c.start..c.end].to_vec()) + .collect() +} diff --git a/crates/czsc-trader/src/lib.rs b/crates/czsc-trader/src/lib.rs new file mode 100644 index 000000000..f0e2e74be --- /dev/null +++ b/crates/czsc-trader/src/lib.rs @@ -0,0 +1,13 @@ +//! czsc-trader — multi-strategy trading engine, signal compilation, and +//! optimization. Migrated from rs-czsc 47ef6efa per docs/MIGRATION_NOTES.md §1. +//! +//! `weight_backtest` was deliberately not migrated: per design doc §5.8 +//! item 3 and §5.10, the public `WeightBacktest` API is delegated to the +//! external `wbt` package starting in Phase I. The Rust workspace owns +//! signal compilation, the trader state machine, and the v2 execution +//! engine that backs Python's `run_backtest` / `run_optimize` calls. + +pub mod engine_v2; +pub mod optimize; +pub mod signals; +pub mod trader; diff --git a/crates/czsc-trader/src/optimize.rs b/crates/czsc-trader/src/optimize.rs new file mode 100644 index 000000000..e2b8f1b66 --- /dev/null +++ b/crates/czsc-trader/src/optimize.rs @@ -0,0 +1,384 @@ +use crate::engine_v2::scheduler::split_positions_into_chunks; +use crate::engine_v2::{ExecutionPlan, ExecutionPlanInput, UnifiedExecEngine}; +use crate::signals::sig_parse::{SignalConfig, get_signals_config}; +use anyhow::{Context, Result}; +use czsc_core::objects::bar::RawBar; +use czsc_core::objects::event::Event; +use czsc_core::objects::operate::Operate; +use czsc_core::objects::position::{Position, load_position}; +use czsc_core::objects::signal::Signal; +use log::{info, warn}; +use md5; +use polars::prelude::*; +use rayon::prelude::*; +use sha2::{Digest, Sha256}; +use std::collections::{HashMap, HashSet}; +use std::fs; +use std::path::{Path, PathBuf}; + +/// 获取一根信号配置的哈希值前8位(对齐 Python MD5 算法) +fn hash_str(val: &str) -> String { + let digest = md5::compute(val.as_bytes()); + format!("{:x}", digest)[..8].to_uppercase() +} + +/// Python `str(list[dict])` 风格,用于与 `hashlib.md5(f\"{obj}\")` 对齐 +fn py_repr_signal_kv(sig_key: &str, sig_val: &str) -> String { + let k = sig_key.replace('\'', "\\'"); + let v = sig_val.replace('\'', "\\'"); + format!("[{{'key': '{k}', 'value': '{v}'}}]") +} + +fn py_repr_list_str(items: &[String]) -> String { + if items.is_empty() { + "[]".to_string() + } else { + let body = items + .iter() + .map(|x| format!("'{}'", x.replace('\'', "\\'"))) + .collect::>() + .join(", "); + format!("[{body}]") + } +} + +fn py_repr_event_dump(event: &Event) -> String { + let name = py_event_name(event).replace('\'', "\\'"); + let operate = event.operate.to_chinese().replace('\'', "\\'"); + let all = event + .signals_all + .iter() + .map(|s| s.to_string()) + .collect::>(); + let any = event + .signals_any + .iter() + .map(|s| s.to_string()) + .collect::>(); + let not = event + .signals_not + .iter() + .map(|s| s.to_string()) + .collect::>(); + format!( + "{{'name': '{name}', 'operate': '{operate}', 'signals_all': {}, 'signals_any': {}, 'signals_not': {}}}", + py_repr_list_str(&all), + py_repr_list_str(&any), + py_repr_list_str(¬), + ) +} + +fn py_event_name(event: &Event) -> String { + let operate = event.operate.to_chinese(); + let all = event + .signals_all + .iter() + .map(|s| s.to_string()) + .collect::>(); + let any = event + .signals_any + .iter() + .map(|s| s.to_string()) + .collect::>(); + let not = event + .signals_not + .iter() + .map(|s| s.to_string()) + .collect::>(); + let repr = format!( + "{{'operate': '{}', 'signals_all': {}, 'signals_any': {}, 'signals_not': {}}}", + operate.replace('\'', "\\'"), + py_repr_list_str(&all), + py_repr_list_str(&any), + py_repr_list_str(¬), + ); + let mut hasher = Sha256::new(); + hasher.update(repr.as_bytes()); + let hex = hex::encode(hasher.finalize()).to_uppercase(); + let sha4 = &hex[..4]; + let auto_name_prefix = matches!( + event.operate, + Operate::LE | Operate::SE | Operate::LO | Operate::SO + ) && event + .name + .split('#') + .next() + .map(|x| matches!(x, "LE" | "SE" | "LO" | "SO")) + .unwrap_or(false); + let base = if event.name.is_empty() || auto_name_prefix { + operate.to_string() + } else { + event.name.split('#').next().unwrap_or(operate).to_string() + }; + format!("{base}#{sha4}") +} + +/// 开仓优化:构建并返回新的候选策略集合 +pub fn get_open_optim_positions( + files_position: &[PathBuf], + candidate_signals: &[String], +) -> Result> { + let mut betas = Vec::new(); + for p in files_position { + betas.push(load_position(p).with_context(|| format!("加载 {:?}", p))?); + } + + let mut pos_list = betas.clone(); + for beta in betas { + for sig_str in candidate_signals { + let mut pos = beta.clone(); + if let Ok(sig) = sig_str.parse::() + && !pos.opens.is_empty() + { + let mut open_event = pos.opens[0].clone(); + open_event.signals_all.push(sig); + open_event.name = py_event_name(&open_event); + + // 对齐 Python: str([{"key": ..., "value": ...}]) 格式 + let sig_key = open_event.signals_all.last().unwrap().key(); + let sig_val = open_event.signals_all.last().unwrap().value(); + let sigs_repr = py_repr_signal_kv(&sig_key, &sig_val); + + let hash = hash_str(&sigs_repr); + pos.name = format!("{}#{}", beta.name, hash); + pos.opens[0] = open_event; + pos_list.push(pos); + } + } + } + + Ok(pos_list) +} + +/// 平仓优化:构建并返回新的候选策略集合 +pub fn get_exit_optim_positions( + files_position: &[PathBuf], + candidate_events: &[serde_json::Value], +) -> Result> { + let mut betas = Vec::new(); + for p in files_position { + betas.push(load_position(p).with_context(|| format!("加载 {:?}", p))?); + } + + let mut pos_list = betas.clone(); + for beta in betas { + let is_all_lo = beta.opens.iter().all(|x| x.operate == Operate::LO); + let is_all_so = beta.opens.iter().all(|x| x.operate == Operate::SO); + + for event_val in candidate_events { + if let Ok(mut event) = Event::load(event_val) { + event.name = py_event_name(&event); + if is_all_lo && event.operate != Operate::LE { + continue; + } + if is_all_so && event.operate != Operate::SE { + continue; + } + + let event_str = py_repr_event_dump(&event); + let hash = hash_str(&event_str); + + // mode = append + let mut pos_append = beta.clone(); + pos_append.exits.push(event.clone()); + pos_append.name = format!("{}#追加{}", beta.name, hash); + pos_list.push(pos_append); + + // mode = replace + let mut pos_replace = beta.clone(); + pos_replace.exits = vec![event.clone()]; + pos_replace.name = format!("{}#替换{}", beta.name, hash); + pos_list.push(pos_replace); + } + } + } + + Ok(pos_list) +} + +/// 提取所有需要的 unique signals 并转换为 SignalConfig +fn extract_signals_config(positions: &[Position]) -> Vec { + let mut unique_sigs = HashSet::new(); + for pos in positions { + for ev in &pos.opens { + for s in ev.all_signals() { + unique_sigs.insert(s.to_string()); + } + } + for ev in &pos.exits { + for s in ev.all_signals() { + unique_sigs.insert(s.to_string()); + } + } + } + let sigs: Vec<&str> = unique_sigs.iter().map(|s| s.as_str()).collect(); + get_signals_config(&sigs) +} + +/// 针对单个标的运行批量并行策略优化 +#[allow(clippy::too_many_arguments)] +pub fn one_symbol_optim( + symbol: &str, + bars: &[RawBar], + positions: Vec, + out_dir: &Path, + base_freq: &str, + market: Option<&str>, + bg_max_count: Option, + sdt_cutoff: Option>, +) -> Result<()> { + let symbol_dir = out_dir.join(symbol); + fs::create_dir_all(&symbol_dir)?; + + if bars.len() < 100 { + warn!("{} K线数量不足,无法跑批", symbol); + return Ok(()); + } + + let config = extract_signals_config(&positions); + + let start_time = std::time::Instant::now(); + let sdt_override = sdt_cutoff.map(|x| x.to_rfc3339()); + let plan = ExecutionPlan::compile(ExecutionPlanInput { + symbol: symbol.to_string(), + base_freq: base_freq.to_string(), + signals_config: config, + positions: positions.clone(), + market: market.map(|x| x.to_string()), + bg_max_count, + sdt: sdt_override.clone(), + include_sdt_bar: None, + }) + .map_err(anyhow::Error::msg)?; + let output = + UnifiedExecEngine::run(&plan, bars.to_vec(), sdt_override.as_deref(), false, false) + .map_err(anyhow::Error::msg)?; + let optimized_positions = output.positions; + + // 落盘 + for pos in &optimized_positions { + let name = &pos.name; + if let Ok(mut df) = pos.pairs() { + let file_path = symbol_dir.join(format!("{}.pairs.parquet", name)); + let mut file = fs::File::create(&file_path)?; + ParquetWriter::new(&mut file).finish(&mut df)?; + } + if let Ok(mut df) = pos.holds() { + if df.height() > 0 { + // 对齐 Python: n1b 最后一行 NaN/Null 填充为 0.0 + if let Ok(n1b_col) = df.column("n1b") { + let n1b = n1b_col.cast(&DataType::Float64)?; + let n1b = n1b.fill_null(FillNullStrategy::Zero)?; + let _ = df.with_column(n1b); + } + let s_sym = Series::new("symbol", vec![symbol; df.height()]); + let _ = df.with_column(s_sym); + } + + let file_path = symbol_dir.join(format!("{}.holds.parquet", name)); + let mut file = fs::File::create(&file_path)?; + ParquetWriter::new(&mut file).finish(&mut df)?; + } + } + + info!("{} 跑批完成,耗时 {:?}", symbol, start_time.elapsed()); + + Ok(()) +} + +/// 并行计算所有 symbol +#[allow(clippy::too_many_arguments)] +pub fn symbols_optim_parallel( + symbols: Vec, + bars_map: HashMap>, // memory 传入全量 K线 (或者传入回调自己读取) + positions: Vec, + out_dir: &Path, + base_freq: &str, + market: Option<&str>, + bg_max_count: Option, + sdt_cutoff: Option>, + n_threads: usize, +) { + let chunks = split_positions_into_chunks(&positions, 16); + + // 单线程下避免 rayon 嵌套并行,防止在某些环境出现卡住 + if n_threads == 1 { + for sym in symbols { + if let Some(bars) = bars_map.get(&sym) { + if chunks.len() <= 1 { + let _ = one_symbol_optim( + &sym, + bars.as_slice(), + positions.clone(), + out_dir, + base_freq, + market, + bg_max_count, + sdt_cutoff, + ); + } else { + for chunk_pos in &chunks { + let _ = one_symbol_optim( + &sym, + bars.as_slice(), + chunk_pos.clone(), + out_dir, + base_freq, + market, + bg_max_count, + sdt_cutoff, + ); + } + } + } + } + return; + } + + let run = || { + symbols.into_par_iter().for_each(|sym| { + if let Some(bars) = bars_map.get(&sym) { + if chunks.len() <= 1 { + let _ = one_symbol_optim( + &sym, + bars.as_slice(), + positions.clone(), + out_dir, + base_freq, + market, + bg_max_count, + sdt_cutoff, + ); + } else { + chunks.par_iter().for_each(|chunk_pos| { + let _ = one_symbol_optim( + &sym, + bars.as_slice(), + chunk_pos.clone(), + out_dir, + base_freq, + market, + bg_max_count, + sdt_cutoff, + ); + }); + } + } + }) + }; + + if n_threads > 0 { + match rayon::ThreadPoolBuilder::new() + .num_threads(n_threads) + .build() + { + Ok(pool) => pool.install(run), + Err(err) => { + warn!("构建 rayon 线程池失败,回退默认线程池: {err}"); + run(); + } + } + } else { + run(); + } +} diff --git a/crates/czsc-trader/src/signals/czsc_signals.rs b/crates/czsc-trader/src/signals/czsc_signals.rs new file mode 100644 index 000000000..7cabedb3e --- /dev/null +++ b/crates/czsc-trader/src/signals/czsc_signals.rs @@ -0,0 +1,393 @@ +use crate::engine_v2::catalog::SignalCategory; +use crate::engine_v2::compiler::CompiledSignalPlanV2; +use crate::signals::sig_parse::SignalConfig; +use czsc_core::analyze::CZSC; +use czsc_core::objects::bar::RawBar; +use czsc_core::objects::signal::Signal; +use czsc_signals::registry; +use czsc_signals::types::TaCache; +use czsc_utils::bar_generator::BarGenerator; +use std::collections::{BTreeMap, HashMap, HashSet}; + +#[derive(Clone)] +enum CompiledKlineSignalOp { + Fast { + exec: czsc_signals::types::FastKlineExecFn, + params: serde_json::Value, + }, + Dynamic { + func: czsc_signals::types::SignalFn, + params: HashMap, + }, +} + +#[derive(Clone)] +struct CompiledKlineFreqGroup { + freq: String, + ops: Vec, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +struct BarFingerprint { + id: i32, + dt_ns: i64, + open_bits: u64, + close_bits: u64, + high_bits: u64, + low_bits: u64, + vol_bits: u64, + amount_bits: u64, +} + +impl BarFingerprint { + #[inline] + fn from_bar(bar: &RawBar) -> Self { + Self { + id: bar.id, + dt_ns: bar.dt.timestamp_nanos_opt().unwrap_or_default(), + open_bits: bar.open.to_bits(), + close_bits: bar.close.to_bits(), + high_bits: bar.high.to_bits(), + low_bits: bar.low.to_bits(), + vol_bits: bar.vol.to_bits(), + amount_bits: bar.amount.to_bits(), + } + } +} + +/// 多级别信号计算引擎 +pub struct CzscSignals { + /// K 线合成器 + pub bg: BarGenerator, + + /// 当次计算的基础周期标的代码,例如 "000001.SZ" + pub symbol: String, + + /// 各周期的 CZSC 分析引擎(有序,通常按周期由小到大) + pub kas: BTreeMap, + + /// TA 指标缓存容器,按 freq_str 分组,例如 "日线" -> TaCache + pub ta_cache: HashMap, + + /// 信号结果字典(每根 K 线生成完后更新,完全涵盖 python 版的 __dict__ 与 signal_dict) + pub s: HashMap, + + /// 当前K线触发的原始 Signal 对象的全集(被 trader 的位置策略直接消耗) + pub sigs: HashSet, + /// Position 事件匹配使用的信号字典:key -> value + pub signal_map: HashMap, + + /// 预编译后的 K 线信号执行计划 + compiled_kline_groups: Vec, + use_plan_compiled: bool, + compiled_cfg_ptr: usize, + compiled_cfg_len: usize, + /// 需要维护 CZSC 的频率集合;存在 trader 级信号时退化为全量维护 + required_kas_freqs: HashSet, + maintain_all_kas: bool, + + /// 按 freq 门控信号执行:末根 bar 未变化时复用上次结果 + last_freq_fingerprints: HashMap, + cached_freq_signals: HashMap>, +} + +impl CzscSignals { + pub fn new(symbol: String, bg: BarGenerator) -> Self { + let mut kas = BTreeMap::new(); + for (freq, bars_lock) in &bg.freq_bars { + let bars = bars_lock.read(); + if !bars.is_empty() { + let bars_vec: Vec = bars.iter().cloned().collect(); + kas.insert(freq.to_string(), CZSC::new(bars_vec, 50)); + } + } + + Self { + bg, + symbol, + kas, + ta_cache: HashMap::new(), + s: HashMap::new(), + sigs: HashSet::new(), + signal_map: HashMap::new(), + compiled_kline_groups: Vec::new(), + use_plan_compiled: false, + compiled_cfg_ptr: 0, + compiled_cfg_len: 0, + required_kas_freqs: HashSet::new(), + maintain_all_kas: false, + last_freq_fingerprints: HashMap::new(), + cached_freq_signals: HashMap::new(), + } + } + + fn ensure_compiled_kline_ops(&mut self, signals_config: &[SignalConfig]) { + if self.use_plan_compiled { + return; + } + let ptr = signals_config.as_ptr() as usize; + let len = signals_config.len(); + if self.compiled_cfg_ptr == ptr && self.compiled_cfg_len == len { + return; + } + + let mut grouped: HashMap> = HashMap::new(); + self.required_kas_freqs.clear(); + self.maintain_all_kas = false; + for config in signals_config { + if config.freq.is_none() { + // trader 级信号可能访问任意频率 CZSC,保守退化为全量维护 + self.maintain_all_kas = true; + } + if let Some(freq) = &config.freq + && let Some(meta) = registry::SIGNAL_REGISTRY.get(config.name.as_str()) + { + let op = if let Some(fast) = meta.fast_kline { + if let Some(p) = (fast.decode)(&config.params) { + CompiledKlineSignalOp::Fast { + exec: fast.exec, + params: p, + } + } else { + CompiledKlineSignalOp::Dynamic { + func: meta.func, + params: config.params.clone(), + } + } + } else { + CompiledKlineSignalOp::Dynamic { + func: meta.func, + params: config.params.clone(), + } + }; + grouped.entry(freq.clone()).or_default().push(op); + self.required_kas_freqs.insert(freq.clone()); + } + } + let mut freqs: Vec = grouped.keys().cloned().collect(); + freqs.sort(); + self.compiled_kline_groups.clear(); + self.compiled_kline_groups.reserve(freqs.len()); + for freq in freqs { + if let Some(ops) = grouped.remove(&freq) { + self.compiled_kline_groups + .push(CompiledKlineFreqGroup { freq, ops }); + } + } + self.compiled_cfg_ptr = ptr; + self.compiled_cfg_len = len; + } + + /// 使用 ExecutionPlan 的 signal_plan 一次性装载 K线信号执行计划。 + /// + /// 该接口会切换到 plan 驱动模式,后续 `update_signals` 不再尝试按 + /// `signals_config` 进行运行期编译。 + pub fn load_compiled_signal_plan(&mut self, plan: &CompiledSignalPlanV2) -> Result<(), String> { + let mut grouped: HashMap> = HashMap::new(); + self.required_kas_freqs.clear(); + self.maintain_all_kas = false; + + for op in &plan.ops { + if matches!(op.category, SignalCategory::Trader) { + // trader 级信号可能访问任意频率 CZSC,保守退化为全量维护 + self.maintain_all_kas = true; + continue; + } + let Some(freq) = &op.freq else { + continue; + }; + let meta = registry::SIGNAL_REGISTRY + .get(op.name.as_str()) + .ok_or_else(|| format!("未注册 K 线信号: {}", op.name))?; + let sig_op = if let Some(fast) = meta.fast_kline { + if let Some(p) = (fast.decode)( + &serde_json::from_value(op.params.clone()) + .map_err(|e| format!("信号参数解析失败 {}: {e}", op.name))?, + ) { + CompiledKlineSignalOp::Fast { + exec: fast.exec, + params: p, + } + } else { + CompiledKlineSignalOp::Dynamic { + func: meta.func, + params: serde_json::from_value(op.params.clone()) + .map_err(|e| format!("信号参数解析失败 {}: {e}", op.name))?, + } + } + } else { + CompiledKlineSignalOp::Dynamic { + func: meta.func, + params: serde_json::from_value(op.params.clone()) + .map_err(|e| format!("信号参数解析失败 {}: {e}", op.name))?, + } + }; + grouped.entry(freq.clone()).or_default().push(sig_op); + self.required_kas_freqs.insert(freq.clone()); + } + + let mut freqs: Vec = grouped.keys().cloned().collect(); + freqs.sort(); + self.compiled_kline_groups.clear(); + self.compiled_kline_groups.reserve(freqs.len()); + for freq in freqs { + if let Some(ops) = grouped.remove(&freq) { + self.compiled_kline_groups + .push(CompiledKlineFreqGroup { freq, ops }); + } + } + + self.use_plan_compiled = true; + self.compiled_cfg_ptr = 0; + self.compiled_cfg_len = 0; + Ok(()) + } + + /// 执行主更新流程 + pub fn update_signals(&mut self, bar: &RawBar, signals_config: &[SignalConfig]) { + self.ensure_compiled_kline_ops(signals_config); + + // 1. 驱动 bg 喂入新 Bar,并同步各周期 CZSC + let changed_freqs = self.advance_kas(bar, true); + + self.reset_signal_state(bar); + self.compute_kline_signals(Some(&changed_freqs)); + } + + /// 预热后用当前状态 prime 一次信号缓存,对齐 Python `CzscSignals(bg)` 构造语义。 + /// + /// Python 基线会在 `bg` 初始化完成后立刻调用 `get_signals_by_conf()`,因此像 ER + /// 这类“历史 bar 只计算一次后缓存”的信号,必须在右侧主循环前先对当前末 bar + /// 状态做一次信号计算,否则后续流式更新会与 Python 基线永久错位。 + pub fn prime_signals(&mut self, bar: &RawBar, signals_config: &[SignalConfig]) { + self.ensure_compiled_kline_ops(signals_config); + // 对齐 Python `CzscSignals(bg)`:预热后先基于 BG 快照一次性重建各频率 CZSC, + // 再做首轮信号计算;避免 warmup 期间对同一高周期 dt 的增量更新造成路径漂移。 + self.rebuild_kas_from_bg(); + self.reset_signal_state(bar); + self.compute_kline_signals(None); + } + + fn reset_signal_state(&mut self, bar: &RawBar) { + self.s.clear(); + self.sigs.clear(); + self.signal_map.clear(); + self.s.insert("symbol".to_string(), self.symbol.clone()); + self.s.insert("dt".to_string(), bar.dt.to_rfc3339()); + self.s.insert("id".to_string(), bar.id.to_string()); + self.s.insert("freq".to_string(), bar.freq.to_string()); + self.s.insert("open".to_string(), bar.open.to_string()); + self.s.insert("close".to_string(), bar.close.to_string()); + self.s.insert("high".to_string(), bar.high.to_string()); + self.s.insert("low".to_string(), bar.low.to_string()); + self.s.insert("vol".to_string(), bar.vol.to_string()); + self.s.insert("amount".to_string(), bar.amount.to_string()); + } + + fn compute_kline_signals(&mut self, changed_freqs: Option<&HashSet>) { + for group in &self.compiled_kline_groups { + if let Some(changed_freqs) = changed_freqs + && !changed_freqs.contains(group.freq.as_str()) + && let Some(cached_sigs) = self.cached_freq_signals.get(group.freq.as_str()) + { + for sig in cached_sigs { + let (k, v) = (sig.key(), sig.value()); + self.s.insert(k.clone(), v.clone()); + self.signal_map.insert(k, v); + self.sigs.insert(sig.clone()); + } + continue; + } + + if let Some(czsc) = self.kas.get(group.freq.as_str()) { + let cache = self.ta_cache.entry(group.freq.clone()).or_default(); + let mut freq_sigs = Vec::new(); + for op in &group.ops { + let sigs_res = match op { + CompiledKlineSignalOp::Fast { exec, params } => (exec)(czsc, params, cache), + CompiledKlineSignalOp::Dynamic { func, params } => { + (func)(czsc, params, cache) + } + }; + for sig in sigs_res { + let (k, v) = (sig.key(), sig.value()); + self.s.insert(k.clone(), v.clone()); + self.signal_map.insert(k, v); + self.sigs.insert(sig.clone()); + freq_sigs.push(sig); + } + } + self.cached_freq_signals + .insert(group.freq.clone(), freq_sigs); + } + } + } + + /// 预热阶段仅推进 BG,不执行任何信号函数,也不增量维护 CZSC。 + /// + /// 对齐 Python `generate_czsc_signals`:`bars_left` 只用于初始化 `BarGenerator`, + /// 不会在 warmup 阶段调用 `update_signals`。CZSC 会在 warmup 结束后由 + /// `prime_signals` 基于 BG 快照一次性重建。 + pub fn warmup_bar(&mut self, bar: &RawBar) { + let _ = self.bg.update_bar(bar); + } + + fn rebuild_kas_from_bg(&mut self) { + self.kas.clear(); + self.last_freq_fingerprints.clear(); + self.cached_freq_signals.clear(); + + for (freq, bars_lock) in &self.bg.freq_bars { + let bars = bars_lock.read(); + if bars.is_empty() { + continue; + } + let bars_vec: Vec = bars.iter().cloned().collect(); + self.kas.insert(freq.to_string(), CZSC::new(bars_vec, 50)); + } + } + + fn advance_kas(&mut self, bar: &RawBar, update_fingerprint: bool) -> HashSet { + let _ = self.bg.update_bar(bar); + + let mut changed_freqs: HashSet = HashSet::new(); + for (freq, bars_lock) in &self.bg.freq_bars { + let freq_str = freq.to_string(); + if !self.maintain_all_kas && !self.required_kas_freqs.contains(freq_str.as_str()) { + continue; + } + let bars = bars_lock.read(); + if bars.is_empty() { + continue; + } + + let last_bar = bars.back().expect("bars not empty"); + let fingerprint = BarFingerprint::from_bar(last_bar); + let is_changed = if update_fingerprint { + let changed = self + .last_freq_fingerprints + .get(&freq_str) + .map(|prev| *prev != fingerprint) + .unwrap_or(true); + self.last_freq_fingerprints + .insert(freq_str.clone(), fingerprint); + changed + } else { + true + }; + + if !self.kas.contains_key(&freq_str) { + let bars_vec: Vec = bars.iter().cloned().collect(); + let czsc = CZSC::new(bars_vec, 50); + self.kas.insert(freq_str.clone(), czsc); + changed_freqs.insert(freq_str); + } else if is_changed { + if let Some(czsc) = self.kas.get_mut(&freq_str) { + czsc.update_bar(last_bar.clone()); + } + changed_freqs.insert(freq_str); + } + } + changed_freqs + } +} diff --git a/crates/czsc-trader/src/signals/mod.rs b/crates/czsc-trader/src/signals/mod.rs new file mode 100644 index 000000000..0714c4bbf --- /dev/null +++ b/crates/czsc-trader/src/signals/mod.rs @@ -0,0 +1,5 @@ +pub mod czsc_signals; +pub mod sig_parse; + +pub use czsc_signals::CzscSignals; +pub use sig_parse::{SignalConfig, get_signals_config, get_signals_freqs}; diff --git a/crates/czsc-trader/src/signals/sig_parse.rs b/crates/czsc-trader/src/signals/sig_parse.rs new file mode 100644 index 000000000..fd24459f2 --- /dev/null +++ b/crates/czsc-trader/src/signals/sig_parse.rs @@ -0,0 +1,459 @@ +use czsc_core::objects::freq::Freq; +use czsc_signals::registry::{SIGNAL_REGISTRY, TRADER_SIGNAL_REGISTRY}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::cmp::Ordering; +use std::collections::HashMap; +use tracing::error; + +/// 单个信号函数配置 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SignalConfig { + /// 函数名,如 "tas_ma_base_V221101" + pub name: String, + /// 关联的 K 线周期(若为 None,表示该函数接受 CzscSignals 而非单个 CZSC) + pub freq: Option, + /// 函数参数(di/ma_type/timeperiod 等) + pub params: HashMap, +} + +impl SignalConfig { + /// 从单独的信号字符串推导配置(对应 Python get_signals_config 中的反解析逻辑) + /// 例如: "日线_D1SMA#5_分类V221101_多头_向上_任意_0" -> freq="日线", name="tas_ma_base_V221101", params={di=1, ma_type="SMA", timeperiod=5} + pub fn from_signal_str(signal: &str) -> Option { + let parts: Vec<&str> = signal.split('_').collect(); + if parts.len() != 7 { + error!("非标准信号格式: {}", signal); + return None; + } + + let k3 = parts[2]; + let key = parts[..3].join("_"); + + for (func_name, meta) in SIGNAL_REGISTRY.iter() { + let Some(tpl_parts) = split_template_parts(meta.param_template) else { + continue; + }; + if tpl_parts.len() != 3 { + continue; + } + if tpl_parts[2] == k3 + && let Some((freq, params)) = parse_template_into_config(meta.param_template, &key) + { + return Some(Self { + name: func_name.to_string(), + freq, + params, + }); + } + } + for (func_name, meta) in TRADER_SIGNAL_REGISTRY.iter() { + let Some(tpl_parts) = split_template_parts(meta.param_template) else { + continue; + }; + if tpl_parts.len() != 3 { + continue; + } + if tpl_parts[2] == k3 + && let Some((freq, params)) = parse_template_into_config(meta.param_template, &key) + { + return Some(Self { + name: func_name.to_string(), + freq, + params, + }); + } + } + None + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum TemplateToken { + Lit(String), + Placeholder(String), +} + +fn tokenize_template_segment(segment: &str) -> Option> { + let mut tokens = Vec::new(); + let chars: Vec = segment.chars().collect(); + let mut i = 0; + let mut lit = String::new(); + + while i < chars.len() { + if chars[i] == '{' { + if !lit.is_empty() { + tokens.push(TemplateToken::Lit(std::mem::take(&mut lit))); + } + let mut j = i + 1; + while j < chars.len() && chars[j] != '}' { + j += 1; + } + if j >= chars.len() { + return None; + } + let name: String = chars[i + 1..j].iter().collect(); + if name.is_empty() { + return None; + } + tokens.push(TemplateToken::Placeholder(name)); + i = j + 1; + } else { + lit.push(chars[i]); + i += 1; + } + } + + if !lit.is_empty() { + tokens.push(TemplateToken::Lit(lit)); + } + Some(tokens) +} + +fn parse_scalar_value(raw: &str) -> Value { + if let Ok(v) = raw.parse::() { + return Value::from(v); + } + if raw.eq_ignore_ascii_case("true") { + return Value::from(true); + } + if raw.eq_ignore_ascii_case("false") { + return Value::from(false); + } + Value::String(raw.to_string()) +} + +fn parse_template_segment(segment: &str, raw: &str) -> Option> { + let tokens = tokenize_template_segment(segment)?; + parse_template_tokens(&tokens, raw, 0, 0, HashMap::new()) +} + +fn parse_template_into_config( + template: &str, + key: &str, +) -> Option<(Option, HashMap)> { + let tpl_parts = split_template_parts(template)?; + let key_parts: Vec<&str> = key.splitn(3, '_').collect(); + if tpl_parts.len() != key_parts.len() { + return None; + } + + let mut freq = None; + let mut params = HashMap::new(); + for (tpl_part, key_part) in tpl_parts.iter().zip(key_parts.iter()) { + for (name, value) in parse_template_segment(tpl_part, key_part)? { + if name == "freq" { + let v = value.as_str()?.to_string(); + freq = Some(v); + } else { + params.insert(name, value); + } + } + } + Some((freq, params)) +} + +fn split_template_parts(template: &str) -> Option> { + let mut parts = Vec::new(); + let mut depth = 0usize; + let mut start = 0usize; + + for (idx, ch) in template.char_indices() { + match ch { + '{' => depth += 1, + '}' => { + if depth == 0 { + return None; + } + depth -= 1; + } + '_' if depth == 0 => { + parts.push(template.get(start..idx)?); + start = idx + ch.len_utf8(); + } + _ => {} + } + } + + if depth != 0 { + return None; + } + parts.push(template.get(start..)?); + Some(parts) +} + +fn parse_template_tokens( + tokens: &[TemplateToken], + raw: &str, + token_idx: usize, + raw_idx: usize, + params: HashMap, +) -> Option> { + if token_idx == tokens.len() { + return (raw_idx == raw.len()).then_some(params); + } + + match tokens.get(token_idx)? { + TemplateToken::Lit(lit) => { + let rest = raw.get(raw_idx..)?; + if !rest.starts_with(lit) { + return None; + } + parse_template_tokens(tokens, raw, token_idx + 1, raw_idx + lit.len(), params) + } + TemplateToken::Placeholder(name) => { + if token_idx + 1 == tokens.len() { + let captured = raw.get(raw_idx..)?; + if captured.is_empty() { + return None; + } + let mut next_params = params; + next_params.insert(name.clone(), parse_scalar_value(captured)); + return Some(next_params); + } + + let rest = raw.get(raw_idx..)?; + match tokens.get(token_idx + 1)? { + TemplateToken::Lit(lit) => { + for (pos, _) in rest.match_indices(lit) { + if pos == 0 { + continue; + } + let Some(captured) = rest.get(..pos) else { + continue; + }; + let mut next_params = params.clone(); + next_params.insert(name.clone(), parse_scalar_value(captured)); + if let Some(parsed) = parse_template_tokens( + tokens, + raw, + token_idx + 1, + raw_idx + pos, + next_params, + ) { + return Some(parsed); + } + } + None + } + TemplateToken::Placeholder(_) => { + let mut end_points = Vec::new(); + for (offset, _) in rest.char_indices() { + if offset > 0 { + end_points.push(raw_idx + offset); + } + } + end_points.push(raw.len()); + + for end in end_points { + if end <= raw_idx { + continue; + } + let Some(captured) = raw.get(raw_idx..end) else { + continue; + }; + let mut next_params = params.clone(); + next_params.insert(name.clone(), parse_scalar_value(captured)); + if let Some(parsed) = + parse_template_tokens(tokens, raw, token_idx + 1, end, next_params) + { + return Some(parsed); + } + } + None + } + } + } + } +} + +/// 从 unique_signals 中获取去重的 SignalConfig 列表 +pub fn get_signals_config(unique_signals: &[&str]) -> Vec { + let mut configs = Vec::new(); + let mut seen = std::collections::HashSet::new(); + + for sig in unique_signals { + if let Some(cfg) = SignalConfig::from_signal_str(sig) { + let key = format!("{:?}", cfg); + if seen.insert(key) { + configs.push(cfg); + } + } + } + configs +} + +/// 获取信号中所有不同的周期(freq) +pub fn get_signals_freqs(signals_config: &[SignalConfig]) -> Vec { + let mut freqs = std::collections::HashSet::new(); + for cfg in signals_config { + if let Some(f) = &cfg.freq { + freqs.insert(f.clone()); + } + for (k, v) in &cfg.params { + if !k.starts_with("freq") { + continue; + } + if let Some(f) = v.as_str() { + freqs.insert(f.to_string()); + } + } + } + + let mut result: Vec = freqs.into_iter().collect(); + result.sort_by(|a, b| match (a.parse::(), b.parse::()) { + (Ok(fa), Ok(fb)) => fa.cmp(&fb), + (Ok(_), Err(_)) => Ordering::Less, + (Err(_), Ok(_)) => Ordering::Greater, + (Err(_), Err(_)) => a.cmp(b), + }); + result +} + +#[cfg(test)] +mod tests { + use super::{ + SignalConfig, get_signals_freqs, parse_template_into_config, parse_template_segment, + split_template_parts, + }; + use serde_json::Value; + use std::collections::HashMap; + + #[test] + fn test_from_signal_str_parses_bar_single_di_n() { + let sig = "5分钟_D2单K趋势N20_BS辅助V230506_第6层_任意_任意_0"; + let cfg = SignalConfig::from_signal_str(sig).expect("should parse signal config"); + assert_eq!(cfg.name, "bar_single_V230506"); + assert_eq!(cfg.freq.as_deref(), Some("5分钟")); + assert_eq!(cfg.params.get("di"), Some(&Value::from(2))); + assert_eq!(cfg.params.get("n"), Some(&Value::from(20))); + } + + #[test] + fn test_from_signal_str_parses_kline_template_params_without_k2_fallback() { + let sig = "60分钟_D1SMA#5_分类V221101_多头_向上_任意_0"; + let cfg = SignalConfig::from_signal_str(sig).expect("should parse signal config"); + assert_eq!(cfg.name, "tas_ma_base_V221101"); + assert_eq!(cfg.freq.as_deref(), Some("60分钟")); + assert_eq!(cfg.params.get("di"), Some(&Value::from(1))); + assert_eq!(cfg.params.get("ma_type"), Some(&Value::from("SMA"))); + assert_eq!(cfg.params.get("timeperiod"), Some(&Value::from(5))); + assert!(!cfg.params.contains_key("k2")); + } + + #[test] + fn test_from_signal_str_parses_trader_multi_freq_template_params() { + let sig = "日线#60分钟_MACD交叉_联立V230518_其他_任意_任意_0"; + let cfg = SignalConfig::from_signal_str(sig).expect("should parse signal config"); + assert_eq!(cfg.name, "cat_macd_V230518"); + assert_eq!(cfg.freq, None); + assert_eq!(cfg.params.get("freq1"), Some(&Value::from("日线"))); + assert_eq!(cfg.params.get("freq2"), Some(&Value::from("60分钟"))); + assert!(!cfg.params.contains_key("k2")); + } + + #[test] + fn test_parse_template_into_config_handles_adjacent_placeholders() { + let tpl_parts = split_template_parts("{freq}_D{di}{ma_type}#{timeperiod}_分类V221101") + .expect("template should split"); + let key_parts: Vec<&str> = "60分钟_D1SMA#5_分类V221101".split('_').collect(); + assert_eq!( + tpl_parts, + vec!["{freq}", "D{di}{ma_type}#{timeperiod}", "分类V221101"] + ); + assert_eq!(key_parts, vec!["60分钟", "D1SMA#5", "分类V221101"]); + assert_eq!( + parse_template_segment("{freq}", "60分钟") + .expect("freq segment should parse") + .get("freq"), + Some(&Value::from("60分钟")) + ); + assert!( + parse_template_segment("分类V221101", "分类V221101").is_some(), + "literal segment should parse" + ); + let parsed = parse_template_into_config( + "{freq}_D{di}{ma_type}#{timeperiod}_分类V221101", + "60分钟_D1SMA#5_分类V221101", + ) + .expect("should parse template"); + assert_eq!(parsed.0.as_deref(), Some("60分钟")); + assert_eq!(parsed.1.get("di"), Some(&Value::from(1))); + assert_eq!(parsed.1.get("ma_type"), Some(&Value::from("SMA"))); + assert_eq!(parsed.1.get("timeperiod"), Some(&Value::from(5))); + } + + #[test] + fn test_parse_template_segment_handles_adjacent_placeholders() { + let params = parse_template_segment("D{di}{ma_type}#{timeperiod}", "D1SMA#5") + .expect("segment should parse"); + assert_eq!(params.get("di"), Some(&Value::from(1))); + assert_eq!(params.get("ma_type"), Some(&Value::from("SMA"))); + assert_eq!(params.get("timeperiod"), Some(&Value::from(5))); + } + + #[test] + fn test_parse_template_segment_handles_adjacent_numeric_and_string_placeholders_without_name_hints() + { + let params = parse_template_segment("D{di}{mode}", "D1ZF").expect("segment should parse"); + assert_eq!(params.get("di"), Some(&Value::from(1))); + assert_eq!(params.get("mode"), Some(&Value::from("ZF"))); + } + + #[test] + fn test_split_template_parts_ignores_underscores_inside_placeholders() { + let parts = split_template_parts("{freq}_D{di}{ma_type}#{timeperiod}_分类V221101") + .expect("template should split"); + assert_eq!( + parts, + vec!["{freq}", "D{di}{ma_type}#{timeperiod}", "分类V221101"] + ); + } + + #[test] + fn test_get_signals_freqs_collects_freq_params() { + let mut params = HashMap::new(); + params.insert("freq1".to_string(), Value::from("日线")); + params.insert("freq2".to_string(), Value::from("60分钟")); + let cfgs = vec![ + SignalConfig { + name: "cat_macd_V230518".to_string(), + freq: None, + params, + }, + SignalConfig { + name: "tas_ma_base_V221101".to_string(), + freq: Some("15分钟".to_string()), + params: HashMap::new(), + }, + ]; + let freqs = get_signals_freqs(&cfgs); + assert!(freqs.contains(&"15分钟".to_string())); + assert!(freqs.contains(&"日线".to_string())); + assert!(freqs.contains(&"60分钟".to_string())); + } + + #[test] + fn test_get_signals_freqs_uses_builtin_freq_order() { + let cfgs = vec![ + SignalConfig { + name: "demo_a".to_string(), + freq: Some("日线".to_string()), + params: HashMap::new(), + }, + SignalConfig { + name: "demo_b".to_string(), + freq: Some("30分钟".to_string()), + params: HashMap::new(), + }, + SignalConfig { + name: "demo_c".to_string(), + freq: Some("60分钟".to_string()), + params: HashMap::new(), + }, + ]; + + let freqs = get_signals_freqs(&cfgs); + assert_eq!(freqs, vec!["30分钟", "60分钟", "日线"]); + } +} diff --git a/crates/czsc-trader/src/trader.rs b/crates/czsc-trader/src/trader.rs new file mode 100644 index 000000000..1eee233be --- /dev/null +++ b/crates/czsc-trader/src/trader.rs @@ -0,0 +1,207 @@ +use crate::signals::czsc_signals::CzscSignals; +use crate::signals::sig_parse::SignalConfig; +use czsc_core::analyze::CZSC; +use czsc_core::objects::bar::RawBar; +use czsc_core::objects::position::{LiteBar, Position}; +use czsc_core::objects::state::TraderState; +use czsc_signals::types::TraderSignalFn; +use czsc_utils::bar_generator::BarGenerator; +use polars::prelude::*; +use serde_json::Value; +use std::collections::HashMap; +use std::fs::File; +use std::path::Path; +use std::time::Instant; + +#[derive(Debug, Clone, Copy, Default)] +pub struct UpdateProfile { + pub signals_update_ns: u128, + pub trader_signals_ns: u128, + pub position_update_ns: u128, + pub pos_event_match_ns: u128, + pub pos_fsm_ns: u128, + pub pos_risk_ns: u128, + pub pos_holds_ns: u128, +} + +#[derive(Clone)] +struct CompiledTraderSignalOp { + func: TraderSignalFn, + params: HashMap, +} + +/// 多策略联合交易引擎 +pub struct CzscTrader { + /// 交易引擎名称 + pub name: String, + /// 内部驱动所有K线和信号的引擎 + pub signals: CzscSignals, + /// 仓位策略实例 + pub positions: Vec, + compiled_trader_ops: Vec, + compiled_cfg_ptr: usize, + compiled_cfg_len: usize, +} + +impl CzscTrader { + /// 构造一个新的 Trader + pub fn new(symbol: String, bg: BarGenerator, positions: Vec) -> Self { + Self { + name: "CzscTrader".to_string(), + signals: CzscSignals::new(symbol, bg), + positions, + compiled_trader_ops: Vec::new(), + compiled_cfg_ptr: 0, + compiled_cfg_len: 0, + } + } + + fn ensure_compiled_trader_ops(&mut self, signals_config: &[SignalConfig]) { + let ptr = signals_config.as_ptr() as usize; + let len = signals_config.len(); + if self.compiled_cfg_ptr == ptr && self.compiled_cfg_len == len { + return; + } + + self.compiled_trader_ops.clear(); + self.compiled_trader_ops.reserve(signals_config.len()); + for config in signals_config { + if config.freq.is_none() + && let Some(meta) = + czsc_signals::registry::TRADER_SIGNAL_REGISTRY.get(config.name.as_str()) + { + self.compiled_trader_ops.push(CompiledTraderSignalOp { + func: meta.func, + params: config.params.clone(), + }); + } + } + self.compiled_cfg_ptr = ptr; + self.compiled_cfg_len = len; + } + + /// 输入基础周期已完成K线,更新信号,更新仓位 + pub fn update(&mut self, bar: &RawBar, signals_config: &[SignalConfig]) { + let _ = self.update_profiled(bar, signals_config); + } + + /// 与 update 行为一致,但返回分段耗时(纳秒) + pub fn update_profiled( + &mut self, + bar: &RawBar, + signals_config: &[SignalConfig], + ) -> UpdateProfile { + self.ensure_compiled_trader_ops(signals_config); + + let t_signals = Instant::now(); + // 1. 调用 signals 获得本根K线上的所有状态更新 + self.signals.update_signals(bar, signals_config); + let signals_update_ns = t_signals.elapsed().as_nanos(); + + // 1.5 执行 trader 级别的 signals(pos 系列:需要访问仓位状态) + let t_trader_sig = Instant::now(); + let mut trader_sigs = Vec::new(); + for op in &self.compiled_trader_ops { + let sigs = (op.func)(self, &op.params); + trader_sigs.extend(sigs); + } + let trader_signals_ns = t_trader_sig.elapsed().as_nanos(); + + let t_pos = Instant::now(); + for sig in trader_sigs { + let (k, v) = (sig.key(), sig.value()); + self.signals.s.insert(k.clone(), v.clone()); + self.signals.signal_map.insert(k, v); + self.signals.sigs.insert(sig); + } + + // 2. 构建需要输入给 position 的上下文 + let lite_bar = LiteBar { + id: bar.id, + dt: bar.dt.into(), + price: bar.close, + }; + // 3. 遍历更新所有策略仓位 + let mut pos_event_match_ns = 0u128; + let mut pos_fsm_ns = 0u128; + let mut pos_risk_ns = 0u128; + let mut pos_holds_ns = 0u128; + for pos in &mut self.positions { + let p = + pos.update_profiled_with_signal_map(lite_bar, None, Some(&self.signals.signal_map)); + pos_event_match_ns += p.event_match_ns; + pos_fsm_ns += p.fsm_ns; + pos_risk_ns += p.risk_ns; + pos_holds_ns += p.holds_ns; + } + let position_update_ns = t_pos.elapsed().as_nanos(); + + UpdateProfile { + signals_update_ns, + trader_signals_ns, + position_update_ns, + pos_event_match_ns, + pos_fsm_ns, + pos_risk_ns, + pos_holds_ns, + } + } + + /// 将各个仓位的交易对与持仓结果输出到指定目录的 parquet 文件中 + pub fn dump_results(&self, out_dir: &str) -> anyhow::Result<()> { + let path = Path::new(out_dir); + if !path.exists() { + std::fs::create_dir_all(path)?; + } + + let mut all_pairs = Vec::new(); + let mut all_holds = Vec::new(); + + for pos in &self.positions { + if let Ok(df) = pos.pairs() + && df.height() > 0 + { + all_pairs.push(df.lazy()); + } + if let Ok(df) = pos.holds() + && df.height() > 0 + { + all_holds.push(df.lazy()); + } + } + + if !all_pairs.is_empty() { + let mut combined_pairs = concat(all_pairs, UnionArgs::default())?.collect()?; + let mut file = File::create(path.join("pairs.parquet"))?; + ParquetWriter::new(&mut file).finish(&mut combined_pairs)?; + } + + if !all_holds.is_empty() { + let mut combined_holds = concat(all_holds, UnionArgs::default())?.collect()?; + let mut file = File::create(path.join("holds.parquet"))?; + ParquetWriter::new(&mut file).finish(&mut combined_holds)?; + } + + Ok(()) + } +} + +impl TraderState for CzscTrader { + #[inline] + fn get_position(&self, name: &str) -> Option<&Position> { + self.positions.iter().find(|p| p.name == name) + } + + #[inline] + fn get_czsc(&self, freq: &str) -> Option<&CZSC> { + self.signals.kas.get(freq) + } + + #[inline] + fn latest_price(&self) -> Option { + self.signals + .s + .get("close") + .and_then(|x| x.parse::().ok()) + } +} diff --git a/crates/czsc-utils/Cargo.toml b/crates/czsc-utils/Cargo.toml new file mode 100644 index 000000000..25cef94bb --- /dev/null +++ b/crates/czsc-utils/Cargo.toml @@ -0,0 +1,34 @@ +[package] +name = "czsc-utils" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "CZSC utilities — BarGenerator, freq calendar, is_trading_time. Placeholder, to be migrated." + +[lib] +name = "czsc_utils" +path = "src/lib.rs" + +[dependencies] +anyhow = "1" +chrono = { workspace = true } +czsc-core = { path = "../czsc-core" } +error-macros = { path = "../error-macros" } +error-support = { path = "../error-support" } +hashbrown = { workspace = true } +once_cell = "1" +parking_lot = "0.12" +polars = { workspace = true, features = ["ipc", "partition_by"] } +serde = { workspace = true } +thiserror = "2" +pyo3 = { workspace = true, optional = true, features = ["chrono"] } +pyo3-stub-gen = { version = "0.12", optional = true } + +[features] +python = ["pyo3", "pyo3-stub-gen"] + +[dev-dependencies] +chrono = { workspace = true } +polars = { workspace = true } +anyhow = "1" diff --git a/crates/czsc-utils/data/minutes_split.feather b/crates/czsc-utils/data/minutes_split.feather new file mode 100644 index 0000000000000000000000000000000000000000..6062084fa51ab57b1d709b0be9c2181a25e63f59 GIT binary patch literal 175674 zcmeF)51dr<{y*^9+6}38N+n6M(m$!)ZA%i8BuOPnlK!bw`X@<9lBAoEZjy8pk|aq& zl5`V7l7tYFBqSl-5c_+6c6KQ#?vK4<-1a@V}`in@01(zA2> zo#!QPT>Ji)^$lB0yQFvb^ZRs8+=xqh^}X!J1${d8=yHDlaN~M)>EE$)$Nn8Pzap(m zZ_T^+z99VVSuGnLRXgp<9d69dw<4`q$9@-f$yvuvw`7O;zw6e7yX(rTZdC3bruFac zZgh)x-df(=y}PZ9wwKh*Ot01Ul4G)t>ALGicUio5x1PD@t?#mU?{2-up0~cs;=Q}| ztUYgim&JQ`>oxbh^<5V4-K~4R?de&!%i_Je^^EjAZ+_S1dwKi2Jg;{?(ld5hzL&SZ z%cJ|V?eDUD?`}WySML8^m+#%}*Z7s~@49^NZok&AY=76~dw2U8yWYV+djjm*;P>Bx zocH>ET{`yf(=X>OxM%lH{W|u$EW88vy{u!uetj+tZ$AUN^t-rwpWb0^Mpm`-YME)? z8q;#OzQe-IYW}5N=`nw>!+r%mD&~h;g{DN+Sb^iZA`n`duY4E zO^8+1uJaz+X4L%g=EhoUH+oNPGkZYuYLL{i6sLLOk7TDm)COV>Ad#)4o%FzD!2WJ+jefBH=Wmh!~DeZ zkBcQPSG~M*`@HGA_O-uCERPpXTsF%2-_B2elQ*5$e%j$)ApA@X2fd%>M1H@0gYj);qWT!oz=h{)H{( zP3N^QcS>S;)!b&yPu;nF-gI93Sv?cWPs_U*y?1UO%bm_^-{I!O@}-I0&G{YP2X=hE zhnfHD(~WYE%jC4vOLeF>)Z~z`+u>OgUVy^U5?*%8Ql2WLQ;WLPrx8tQNgF!Qm7ert zAcGmkNJcY`iQL6BW-^C)EaX|jL$iWayvsU1VKZCVPCgI!z7(N2r71@Rs**`mAu6o*0Yf>*~Slq(~5$`DM=a9 zs6;igs6#y((u5YYrX8K>L0<+ih@p(&M#eIM$xLMivzf~R7V`qjd7ag)Wdobo!ned6 zq6H{KF-lRE@>C(6TGXXJjc7_s+R%aUve1)$3}i6F7|CeHF_F8N#!TigkA*zTQdY2v zcUi|LY-TIl$(QW$rwGLHDlp&2uR3nQz)T1FyXhCb*(U~6fWdMU1$_Q>` zEEAZ_RAw-nxh!BYFR+}~S5obqAca9LOQjmOMM#Al$Nxi16}D! zKL#?GVT@!n7wsfK!z3I>83}HB<7{hobF@@>O;!)lN=w?%fv)tV9|IZ8Fh(+(aZKbcrZJN_%wr+XvXm99 z;$7D937grODDR~oBmwR5QZ~~F^p#tQ<%;y z9%Vj@Si&+^@)m1Y&qlsv8$XcjWA=jKqjpMChBPWsjV$U=kA^g%1+8gEXL`_=0Ssa& zBe;>VOkgronZa!4vVg_Bz;a$^HEY?xCbsY`;f=olg(yZT%2J*xq*IH!)Ta?mX-OM8 z(3PI_V<3YW#z;mpj)~mGG-fh~c`W2vma>9XyvsU1VKZCVPClpS`%;ABl%^aNs7fZa zIi3bIra7%>ODDR~oBmwR5QZ~~F^p#tQ<%;y9%Vj@Si&+^@)m1Y&qlsv8$Xa-#N$t# zl9VBhN>n3@I@F^fO=v-D+R>RF^ko2p7|IB4WGoYy%v5GDo4G7tF)y&3*ICV4Hn52; zd`tK|qyU8|Mk&fto+_kMi@MaO5lv}H8#>UHp7diNgBiw1Ml+6y+{H9zGKYCApfBIzC}DTiH%NpSkQy5sFiqa#WxynbhWZ8qk>Lw4yDY=tgh)b2&p8&M3w(o=Hq$ zIVOkgronZa!4vVg_Bz;a$^HEY?xCbsY`;iSI+g(yZT%2J*xq*IH!)Ta?mX-OM8 z(3PI_V<3YW#z;mpj)~mGG-fh~c`W2vma>9XyvsU1VKZCVPQC*?{uH4&r71@Rs**`< zj;8^QX-+HJ(ur>LrazZ6gyD>04C9%^6s9waN14wemavSKyu}*Uvym^^#t$U>Jf$FU zN>YY2Dp8Fr>QIk{G@%8pX-8*z(3b%WVkjfHk+Do*GE?Mb$r5Rwz8dkB|ZKWp*W=}M+K^qNo|g&0gY)+E85bDZuF)SkUAaP1khBPWsjV$U=kA^g%1+8gEXL`_=@G;pS zhBAU18OsDFGnE<4W-bd@%nK~%byl;M4Qyfy-xB`AUw}dsqZDN+PZiRsMP2ICh^Dlp z4ISu8Px>*C!3<+0qZ!9U?qV7aH6H+s{b%NfFOMlpu*OkxVtnZ={bXAw(S#!B8|4eQy+mu%w)l1qF1 ziBpm?q)~}#WKoBDG^7bFXiYmh(}TVYU=TwY!HtY%0+X4_3}!Qz1uW(Tmh(ERS<41C zv4wAm9qjR^5XC4(S;|v|bZSwT`ZS^`Eonmsy3&(=3}i6F7|CeHF_F8N#!TigkA*zT zQdY2vcUi|LY-TIl$ydhXPZ5e!nsQX2Dw)*gcpA`{=Cq%@N>P^bR3V*O)TKU+Xi7`k z(1EV>q#pwr%rHhWnsH3zE~YV)Im}}r&$5&itm0kP@d=yR%69S{>hY%t#VJiWDo~Y7 zYI8gdXiRfj(Uwkhqc{DzoFNQn6k{0AB&IN(Sv<;o7O{k7tmG}$u%3;4$u@oWn*0iHDJ?P5-1~HTo+{joaFqx^$U^a7Ez+zrtIj^&twQOJ$ zTlkjPVIF@9QH)ZQr94$grxtaoPa~Srk~VapD?RDQKn636k&I>>6S<3N%w!JpSjb|Q zu$1MjWEHDf!#XywkHTK>10uxy40fqjc7u1TGE=f zbf7cc=t*DtGmt?HVHhJA#c0Mdo{3Cm3e%XuEaotm`7C5HOIXTsR)60XHnW9oY$tqCk<7jnB2IBi zQHFAqrxI03CyUzDr5+7vL=&3RlGe1P1D)wcPx{iIfec~@!x+IRMl+W2Ok^@sn8pld zF^9R#XCaGO!cvyAl2xo`4eQvzMmDpBZEPoYg#Bk<3K6F`r6@x=%2SD|q?1K$>Qave zG@=R3X-R9^(t*x&qbGgo&p-w-gkg+e6r&l-cqTHLDNJJqvzWtN=ChE+EMY0jS;ay;6}zWfyqo|2D6#V z0v7WE%XyvEtYrh6*uuBODtc|D5XC4(S;|v|bZSwT`ZS^`Eonmsy3&(=3}i6F7|CeH zF_F8N#!TigkA*zTQdY2vcUi|LY-TIl$ydp9h$0lHH07v3RWhl~@id?@&1pqjI?;{Z z^yhMhFq~10VLX$V!gOZwDDzpw5|*)&w^+k^Hu5Fg_<`ig9)IGLqzq|Pq8eG$p&kus zLJL~cj?VO;F9R6FP)2YgW0}BYrZR)s%w++Kd4c7;&T7`OflX}TTf!Hp1t>%@N>P^b zR3V*O)TKU+Xi7`k(1EV>q#pwr%rHhWnsH3zE~YV)Im}}r&$5&itm0kP@d=yR%69S{ z>G7us#VJiWDo~Y7YI8gdXiRfj(Uwkhqc{DzoFNQn6k{0AB&IN(Sv<;o7O{k7tmG}$ zu%3;4$u@oWn*0iHDJ?P5-1~HTo+{joaFqx^$U^a7E zz+zrtIj^&twQOJ$TlkjPQ67H^QH)ZQr94$grxtaoPa~Srk~VapD?RDQKn636k&I>> z6S<3N%w!JpSje+1Wd*Bvmvwx?X121OeAPVu6rnh!DMtmWl1XijrvZ&=PAl5diEi|! zKbJFv;f!JoF zTF{zybfyP=8NeWhGJ+c!%LFDfl^M)tE(=)93oPe#RlN=w?%fv)tV9|IZ8Fh(+(aZKbcrZJN_%wr+XvXm99;$7D937grh#sR3($z98Uup)0|ecr4!xgO@A(D2*Vl07{)V+DNJV;k20S{EMXZdd5bly zXCq&-jUPzP^!O8}BxOjW64l6}4)thA6I#%kc66o(eHp+YhBAU18OsDFGnE<4W-bd@ z%nK~%bykz}lXoRIUvs9JeLj9-{DzrR3gpz@guLEvf0i^mal_C%lXDs-^;tjhjtz6} z-TBY9HqV)vyD=Ay!U?dSpXGT{%qmDK;UiujN9AOCLThiI;bkcnZt~I8E_xS-{gQ;U!k^2CI3W^?b@^zUEu9Nlx+= zH6inJ0HrvDG%6BuKN-~GSn6>yA^X#e5a&IIcAQVh|Ma3CmokW}2^pa48O<$B;0{6# z=ssrhFmrj5kOg{bB@`^MkIf$|xPRJh}MJ7j6mlFvY zq|<25S+t=&A&1nR-dxN;t|VlUu45E8GmhJt%sou!LFVu{A(QkpOL&PDyg|q%z0Z0+ zWiwwBvPsz_7jXPjgaZiqq(exfB2~#CWR#Aj9w*b7W`vy5Ike+^y3vb}Rl1ZxT+MK< zC*+lGVFGvHl{C9_#p+O?*YjH+@gOeI5T4q9`HbbTH*Of+|!ed(f$Vh!hc;io@AgP3$)IpTxa4K;WAuDw> zbvcoSoJPn?okbhk)0yss%+$pUicY1R%A>VZ+L%EJo+)T)L-Ogn0VLA^Ia$b+KfTvl)ON6Y~ z8?5Gi*7GSL@AWm`l1*~l@lVKn9Y84#A&rWJ+*by*IF@>xOvrvUqb29ij`IomuU_=y zQU-A~Ap>?jqq&6%+(F2J-N#HGW-d<>vS81#lvjA2w+VT$5BY@8*}^x3Oj!P+j(-YM zi~|X|u*0apk)%_TkPSP52AoP$&Lrf+&Z83-(36V@8L=xE!nKU#CPGf^HYRa5(|CZ8 z6?=^N{F7&Sk&qXAjkkD@b$m?7jD5v6z9(NX$3G!AR+N$)OgWAqWXGzL#W5VuNre2^ z>9pW%+R~AbA-j;iT*BpCMaYr;jT`wpO*n&)F*}zIbfE`*2syLM7|b<{ z;08k0>{cdnCsVnfkT-jTc|65po+o6^US$>Uu$GSqxw9|W%6EkS@TU;6XQ>qDAj)z$ zA%AuhnH)`BP9$W|PNO+z(T4Vf99nmJb1?(Cl8{Baj#1pqIBqB8(e7b74>E_x37NE~ zS;9-K;0;19?S0nsDVzD4kWI@bxrF1NA{;=-ryW8X6{$)FA)|II^*EWvG$Z8H&Y>OW z(~Vw)tlFgv;%bI-Jt41l3lq45DcncMtUb(Jo@5cv5prv<@H%g^h7SqZwa?kYH~c_; zxv-F5D@-vCqzs15vLQfZfDb$j&$WhLf-8XE+>?G4&!fx%-i1?&p)_}dkML>hj^4HSjaPk?Ayz% zh7$Auo3pZD>ztx)U;U7c-D68On8p+}zEK<8~%<4mTW?fZa<1}0HrvDkfp0gRWhi>v4lL`$uy=J zEjfpfsXL!;^r9b^5^{A{Go0%g%`Jp%-5pHfK4$VTAz$|-i+GNuyh6y>z0DdvFk{nFP;~ha2 zs*}YrgiPK^G~#qxa5f>A*O9JVNM9}?Wb>|K7=PnN{!Yl}{e!!>ml-@n$ml)6LZ0CT zUMA%9-sD|AU<02KvU>kwJ4uH){@I_9*DFD34y8Pm37Neb)aE$qa|$81cLuFEmkx9x zWcT{epUW7`HH7@$4UFMdCUPes!*@Tkc!YU8Mac0z&oW+R74HzTd>^rqFWAa=gb(^t zC`c;BIf#(yJDf@!MJ7iRa(yS#kke?+S%hp~dpgsd-ds$`_g%?Qu45E86Ec3cGnsps z&Vz)U-{UOcX_oL3A?x=Bt9hUGd`igsea*LIlYFS-pOE=GfKnVn8Wjn-zYJ<|EcG~< zko{{$OU|Jk=M(aOz39iK4B~1+2Jm`Da|;u=gOCHfkC{BoT%IIk0iRML%!2IPL{}iSe2U3Q^sKAk=Qwi!A3g-r_yh@iCkDifw#PzQY{<6rw04Ihb-BLC6kPCyQe^ zo|99pW%+R~AbA-s^jT*BpC#W4QHjr^VQ{DY7syq6g~#G^bx$P+%p3%pFo628g1 zd_c$%e#V#li;y8qN^|_PKXFP>nnNj1WvWqw+8jrHPN4~B(28^EKo@$@hyGl~V6I^V zH!y}cZ-`9LgHp%53{}kZ>N^uBjRHP~y)Z$p`aWai*MoZ429p}@HUi9Nq25~jRxt`J7 z!UXPM3imOShndThEaEvrX73eV=WW*TA)oL$A*=TdKal?j$3KNB#(|XKFe-2)>C~hS zC(wXXY08Rj5wL;vK{BoJ1o|C*<+YrY#-m%7uhX-X&bl zRSe^Agk0X=8P7ksi+c&#yoY#{Cs@ccgnZu1tmIAJsLgSNyxu7^;S5@FE*YZ$=|jNw*7cJEH6azC?pgpl8R zip4z7GF~NQ_}*bHAF+`y2sys*2w(80P>@tYmhT|SayXSZije0!n!21wLrx=P`p%*a z?deQ+Lay&(2681sxsH(SyP0v^&SdT(Lq6eiLKg5FejtAp$3KM$dB6iH!(mk5NJ1vCCUrQ0 z2AoRB1)fQ3&Z83-5VC<6F@P%=!nK5a;7yF>HYRa5AtU$zvw4j9{F9Iqe39k6##_9{ zIzDC-U$Kqv33jWW_x9A&Y1#;?e<{I{r~-ESfM z6#p%1YWFMp);uOXqfWXViT}zT|CSxiyCZ(f&W1L>WoJX1-?Foz&9CfiSYFQ#XI4#$ z@1S00o%F1Oc6o^ZhjzGz9sW&*<&W1GmOs5loxDeKcbEU=k8qY9{x_v(?Qy67%Xhe@ z9sUob*WB$+|CjG@Ej#=lO0V^cJN;k2!$;fU-y!|zU3U7ve1~h>;om8}_D(zfU%tb~ z*x|h({g@wj`Zw-yhZ^br=(oHXdz0;um7(#ky(9mIcVed=8BVnRZ=8Ch*Q}i}cmFuC zUq96Q+>!j z2mBv=ea`$h-hUHqf9;&(zx(y--`Za}q+kF1?tf^1bDj(ThaTU*`u@giRya7GJ|Ui3 zrf7Ed`Qk+J43Q#F*$zeG>Az%5k<>P++1Z`q=}IfKn_D8Dz9*lQh-WlU`T57!NzH#e zTS=W=GCQ7nLCNe=3VGrgig~Q{w3O7;;@Oiz=|JH^@$@}8p-}2Rg;m3zoLaGRcJ>?1 z(=+STNKbt39@Vu*b*=e)uSHbX>U*vHzSknEYxTXh7n!K8)%V&3e6LMZ*Xn!ieSEJ? zRM+Zz?R|Z(O;pz^*HTcfg{ZDou4R9@7NWXVxt2n5Ekt##a&2lxqqRO|^ z)~K#Es%wqvTBEwwsIE1tYmMq!qq^3=T3u_msIGPQ;q;#;pmsZ*69&vHLfFIQpAcS- z>RMG#?Y71>s%za{P0fG#Q~E!sX$^I-QC;hg8O~p=l*)Ss{7ElbzoVu#)Wb$~t-oA9 zllQ3Zc=;zi!h4~nHPppMb*(!-mb<(D|M@YF>RR{ENlaAt_|v=4&(A+{p1)BY?=J7o z|DTWF&?V0Q!MeDpF8!A}wTfP^qSvdu*XjT6*XMt$&hxKc2k7RR&|Fz1fA5zO{6#_evk!y+}E)GhGr^nb`k?WsW;2PDp zM)j>xeQPvsYqEZuiN> zR{gfCTaCtTRo@lljmB+_#%+zpZPjhAu~VaQTcdGX zqj6iKaa*HtTcdGXqj6jRcgAg9p)bF)!nmy=j-OR9mgJ&V;-a>%*HU6f=K#H96dSsX6!eU+cE375&TJc0-%~WpBIc%)jhyH{j@B_O`2){L9{U!+`!}Z@bFH zzvgZ4Qaf*W&cEJ!Q@JqOW=eQ#+U?tF;)Gy_Q-Q?${-1n8;D$dO-SGdPr(n@N{`0fT z+(Y0GJL!$?@#r4UdEL!@qV3`TM}Og`iju9r<{MCS!+-H3f>ICf`o2znW9JqLA7FGS;!~NN5enPJYxYhxD`i@! z;R?%QIlmt+5w6cIpn&AUlEO$7>k{u z{wx|VF&4XB{n=#hVLA^ohsRmK(=6d7R`3R^d7t%s%4WXiTe3-3%e5axIDk?dLK+pR zN(Qw!mU^5_W17*Db7;r;bfXvjxRgO$&2X+~G`BE;JD5ULPZo=9QcD)qlf`1S)sjW^ zWU<)LgQ9w}sGcmUCu?$u_kUV(E*%NMsYJyJy}#w7S)qQ^<SIx(w8x{S}b*gk&W zoclYhBWu@LHFusRFV#4sPP#vj*GR1K-`RX@RL)JUvGY0;x8gT%Kd!E8zm!<%e|O(^ z-uTWp;kR#GVb-6wag|kn-o}+i{dpT#+w|vcTyfH$w{g`)f8NHG2mLu4&z}-&lULyL zAH2}S6R&ge-|Qf`#c`IKAum;?&^4Wr^);6 zmewNf(tY;I~_t^LxiRv0eof7EDlB5Gz|e)-Wa9o`xx5r43V3Q;TPhjIApa z>sctw$on0(l~Z1Da;#~g*nX8`H&+UOHL6l*G^$cqT(NR?_O#~dnRRNUhk}Es-XN+s zi0TcZdV{FmAgVV|V{ipSi0Tbuu@}@DMD+&p2F@U&dINbwwFXhWK`b`SXXa78K`i#A zT7#(GK;B3VK~!%LjbHPuv1_97Yhtmu+JR{N8sDqR`bXo}MD+&i)ERtCRBsTA)lh2? z)f>ngt2Kz~4PvpV-XN+si0TcZdV^OEjm2K)ZPxH1pYS%a=lRBJ015TwWXVRMU=)?u|sZetR6GmQtB&11~xpFGQpEax@e z;yu>!F`M{`ZG2C@!_*s7h@zAvsyB%04WfDj$A5XpKcBOOZ}@@yM>zf|Ofe3m42MyH zBT1(wbvS_roJv#9q&4T!i3{k-MGW8yhHxz-xrwpd#w6}$8V@j=$C%GQd6pMh&TG8I zd#vMQHt`kP_?~J6fLgQ(u%uU2nx(|$e^ z3-tzBDLyEyT|6eLRy|Z2s6I$cWQKx*Sm#(F7wu#Fhsf-X6Ipo^KP|7BH%q{{S)!ys z0YdSdzuj>Qv9oe!W#%>iX$}9*7DCZLk(Ah>yX?TutIylWzqIMJ*$kmJX z|BoZk-jOW zBlbFl!ogmrP?)z%=j)la=O@ZRM>*&y2OZ_0lS9ddewQ%lW+(&?|3!<1m!;wyL|G1} z5=RmK_jWXOIgy5(Msv=h4ejYncY1R%1G$o+T*oMGW*oOOnR}SdgUsP^7VtDnc!?Ff z!D`-TJ)g3fulbg2k_+f}ks=&GDGnixic}?osNeM{2mOTX^Fp2>%0WB+WuT)Rw0>7* zprahL??Ggsqa1WB_L&THl!NyEF9RLrphKl>l!K0P(6QM5GSE>D+WEf>bd-aR`dxp| zr&m!9Iu@%e|6GkI2krm=GSE>D+PEvFqa5@;WuhE(l!K0P&`}OL%0ahEGwuQ%=t2+r z(4Wf~%r%VQ2F7qJ6StE}Q3*76Y>`GT!{M@--5DHJ4?;v7U- z4yO`Fk;&220Gej>_VwvGv_M~`b&Sk#XQHjYKJ5OfqJX!O{Nzu>T@N-I1*_`j| z;~76L{GBTfnP;m=-uo+-cK%~lEa&{ktXRVNk6E#7^B=Qfspdas#q!L5%!(zM_x_3t z{}f_QiPt=6Zyq8^hlI|}6WQ6kqc8WM&OM6%^Mg8CYjh9%W$ppTSZ>1j$3twd90$<> z=5)ZF8_QQFR^ec8^@+}%(^3yCnlIKo<{~Mnc(};tETXU}U%UMfkEO*^4>QY!6N{Xs z4S&{l=fB+3R8px>T_VV*2k7zf8jHd!bpD)-@&MW7a~8#G=A6jn z{E;kkgV@1^WBV4$`Jb@o z2-h-_n;6S&OyX{)@c^@VjQRYNXL*t3yvAF+$2vY{6JN27@5z^6eK&SS>Y$8!>mIGq-pOyKe&r~nZZLm$`dT)8D8LJR`MqA z@&Oz8j4$~Y+ewlKF2Mf8DM4utr973XMh$9n9Q8ScCY(Vl&ZPrg=s_R)a~XrVh7sJr z7;a@EcQTdxnZ+Z_<0%&NJj-~MRlLJmK4K$Z5apVUCokVzkCSOkGg@*E?Kq!q^r9b^ zGKi}g&h?Du7A9~9Q@D?rJj`64WD(D?lvjA2w^_r7e8T5!;TwJ+|33P6pfJTakTM)b z1&$=0n$+P08gMF2Ig{3$M<*_zCl@h*sDI}u*Bp~&PN5)CuG#S~(;VfRz5mNJN4aM2 z|1!-{uG#k*GR;x0+55jtbChd#{vY-49OasQu+p3;*X;biM3ifea?Sq#FV7t1n&leg znJ*^FHT(a+OmmcL_W56wYmP;^<|x-3<(i{hbChe2a?Nj-_xV3j|IShW&QY#8$~Avp z+4p~Z!w=-I;`=`eQ;Y*C!(mk5NYbfE9ZsMDr_z)&Y0Y_b;sSbd5d*k_AzaHyZelFA zF^RjG#skdeG3N76p5;ZB^BQmQ9_#p+O?<^Rz9-+2&i^SyQA%S8jB9ze9JuLA$JIk2flG%cdwfbJ z9gvf1PJ}skv!1-?WbR>-dtChTQ#)Ew&Pz#jABC60l$~E}|BdHS{@ldwZk~}ly=?#D zef>ZGyv`r{{@~ci&^2#z$iI3o5|m9zN-yeSpQM`MB6li(lFxs?iN}VA;O#!XHPyI) zz-;#D(6w)y)a>le@tSo*hrUIHX7`Jy7f%VHS$+AQlG>?c0e|gcpQIrrU8I;wYF8j< ztJ4d{PAwUGy@a?*zSz!@!3~VzRwi;MQ@Nj6 zJi&SdUkIuA03$63JBEa4?q@CK`SpY?poX1?ZIvPnMFcwZFZ z07`KPX;h>t8Pwuf>TxoSX+}%Vp&jSbjb8NQQU-A~!?~W(+`BVD1~2%2cBUwK6!e|!5x{PefCPozwLd;3I~^tZQ9 zWJZ5$`vaS0r2nSJD=9k!^-9I&?(gD|c*fCT@>d_-M9MR2?wH2;X? zq#MGF61nqtzR)UX^^YVs6^}e5sd`}-Idgt~k#h_46)uyMwc94SfA`y4=G^P~3LhLh zvruYUK+-%O^oF>CUG~@c!1eF#(e(Cv%JW1UgIs^V;vu}iLcnk_vG7G z-6Mr4N=Xi;97j-v>SS>Y$8!>mI2}a|v9oDQN4jz$eYu3oxr$-@jT`wp8DpQRb)aE$qa|%s3gI1hN z2fEOMKJ@1@26GJ~xPdX;%0%vDD)%#sN0`S`EarKZ@hYo$hqZjfM!sMx-x1|IqkQMJ z@|+{NiLpfaPRGA&XO!>s{jO|hl<)NYu54$N@AUpJ+Zp9Mz5gE&GIg=>g>Hq%+`~N@2D2wu)@|01&Gs<^H`OYZc8Ra{pd}oyJ zjPjjPzEgciCG{OV$Q&ML0Z+4pmsr6Ytmb{z^C_G8ns3P_xw85Wif{m>ID|ARQk4v9 zaV+&XnZ`7uCFjtN^XWz}`f(|PxSHWy&uDI80(UTl`FbbfO3Te~ zR@i-mL}+sl4H7BMJv2zfGWXCRk;B|WgGBIh4-FDY%RMwmL@RgO;IO=IHg|tdB~pLU z@?D;3iD%Mbxw*wZ?G8FDH-C6UqN~J;5{HK+iMk$Gv|`ekMMGbRGefj6^o7_kUb9W; z3$c1`9N=~g`|c#y&fJY z&sQ<&_`-=!4{L>_UFhoY)tn?;tuXv|H2$~u!6fg4)Z$p`aWai*MoZ429p}@HUi9Nq z25~jRxt`J7!UXPM3imOShndThEaEwq@(Qo>Hf#8hPxzcIe8Ugqms#7F!W82`%5WGJ zIFfW~Qil_0z^OFlOj>guow$IWMEx8_{TxR997Z{~sGq|q2N&hwq8wb5gOh`mdwZFc zyh)UU)6a2JtW#FP5oE*FiT$F>;&r#IRVU&ZDe~WT(>R^ofeJIiR z-%&q@Q4UTWob20IL^(KhaPn{UIVH-$Mg1H`{T#-M9v0=`q8wb5gNt%-Q4TK3!9_W^ zCUUyIr10!0mo{ zA^^Ah<%#^;?w2RxZ@XWfNWblVc_RF_`{jx3+wPX9=SAOY<@AwOEAQ?}S|9GMj7041 z$2)K5yY~Oc!t~s1TkbuSxDWo&`zJj&%$AktOz#XSwpvM@eoC=Gs5fmy=_ALC-cG$}9o4nn`CBcahUHwfKu-Y_*aa<;K`5?i^!F z&xx@`IkqUr7UkIV=j8q6QJ!ES&+q~-vywM?mk-#$XMD-O*iMrE&Fos*1$em2(erE9q^LUEIJkK&-WfkwR zmXFxT7eqO>D90A%*rFU;lw*r>Y*CIa%CQxcSxcoj2T_(N$EJ=&hAqmm>CaV$Ey}Su z|1TEh*iMvRYeW_PY1ftgFf`Z~&z^gfuErl?-Ze zEcG~<#x$cP=g^Mx=|(U5aVdkin&Di}Xl`KwcQA$fn90M;28>-|HcpIF+bo{hJ?#9C zdz;9b{pDt#nwuSKcBmfun&<6e&bY0=Z(mN$jfkCo!ZCG(9xtqCG~kR+;yHI&jl>fr zcWd#Cs!3)yjVER|PK>bH%*d#XliG&?Qky0{9>!8_l9<|I=1!xgwv;hDH8gpkWV58A z2bRmI`BE4qH8XdV)I^_Y4<)q{?6Cp&jQF z<=Ko+F3T3>+0?VhvPF3|-)BX6w%Bmvl3q_VK55i1Ta;(>eYPxHlxOpOwk%teXN&q} zs}|+iVi%W{XW>eQavh_%nQ`3CWbR=)4>E_xS-{gQ;U!k^2CI3W^?b@^zUEu9N!IV# zeiY#VN^uBjRHP~y)Z$p`aWai*MoZ429p}@HUi9Nq25~jRxt`J7!UXPM3imOShndTh zM0vI-&vuIZSQE}5%CkjzwkXdQ<=LV|8G-Q!b6hAhtO8e zdVZX`BC%|6^3AclJ1{4ob+vp}&doVAtSU@hv11x58Yk1(lhHVtQ64MGV?}waD32B8 zv7XpB%40=&tSFBa<*}kXR+Ptz@>n;?U5({7CUG~@i1Jwa7na3}@>u#8kj0AfSiX;q z@>sDK^(_|lFQ$L-sDH62k0pn72T>j?7W>%vn9(?yQU7959xKXYZ7UaxeNR69ibeg4 zMR}~Kf3c{4v7!}>lgq)B;|Quyoh**wcut}br_+M7X-h}Cav^=Wgv+^#Vf>97`8(tJ z2X}EVGkAzcd4h#J!wbC3O5WsMK41f%@g@IaJ4qFdFU0=DDM4utr973XMh$9n9Q8Sc zCY(Vl&LzrYMR}|!kM+Nq$I6m7X`h}^CtV&ZyJT(xE4xHI5xsIHk-2gu5w>zAk+O0n z5wCJ3k*jhg5va;8UOgkM?w~}1O6DjhJM|x5mdMRQ?REtvax+iaB|?O&hD{sXvT^9+ zD!YVeQ+A1@ib*+_sTGrsi=A8`b#}?@?MXN0Pc*nO|8c3!V@pJsvP&jaE)p(_=U$e` zS20PMZ)(NL+1V$ZkQr8e`Oa~rXgtIyuM~}k7>$P*jfbdytUxp#V$^?Blvj%KN>N@Z z$}2^Ar6{iy<&~nmlFuvUlHTN9K41f%5#^PPha#&K<(2dwA*&SSm10p|NlqzFlvj%S zkD3tWm16J7DXk;QEBU@M>OU&#KPu`!Di-CH$-ZM^c0XD8(V9QIV=-P>W-! z$H_FN87(=7cAQT)deM(d8N}5L=XyqS3lq45Dcr|Q9%e32vWVwc$}7Ci+pHnVEB$Zg zl^!wNaQpPkL|!TTq1KXe9r{^WTdS9GY zr`)+$q+Mv9Yia#D_P(G?TGeuCy*u{mVn**vdiFGRasQ6}`kPG8x!$L*tC^WKkItFu z5~eb)JkLa8mF;uZ-z9g=JFlo~_bxpNY_e_7wK#k5O$cR#;R=fsV;q*vd| zeq7L}Q;#m^_YXI&SC{@BJ9q5gQS&R(y7bn(d+!Uv-=5X7;Ze2IuH50q?0hTIdUfn~ zVV9hB>~u?ZnE$(OO}M+Rtm;PP?qOR0?(Rmnc;~I<&E31(%4mB@&CK*#Z7(?{>zJ;) zZgiK$dw1)Zd*1pki}&upRwy5{Ie&(t_^?DZ>?@I-Y;n%>}77MpDZFRA+%lEs_%$bb#to@TfI3g*%ra0;p~RKD9;&wJkt( zd+@16a__BV=HTDKjjR3eU3kh3vb#SKw)mY$BX96`(1E?ke zk2+(hx&WY>4m|2$-mUo#A=JDDKs7gTr~{=RP{R!5_Y*)h0>Jh<2e9?b4sLM^ ffZOoc(iW(F4QW<_!VUz1ooZk^-Z3c1Kim)ib+5hb literal 0 HcmV?d00001 diff --git a/crates/czsc-utils/src/bar_generator.rs b/crates/czsc-utils/src/bar_generator.rs new file mode 100644 index 000000000..753ee2df1 --- /dev/null +++ b/crates/czsc-utils/src/bar_generator.rs @@ -0,0 +1,999 @@ +use crate::{errors::UtilsError, freq_data::freq_end_time}; +use anyhow::Context; +use chrono::{DateTime, Utc}; +use czsc_core::objects::{ + bar::{RawBar, RawBarBuilder, Symbol}, + freq::Freq, + market::Market, +}; +use error_support::czsc_bail; +use parking_lot::{RwLock, RwLockWriteGuard}; +use std::collections::{BTreeMap, VecDeque}; + +#[cfg(feature = "python")] +use pyo3::prelude::PyDictMethods; +#[cfg(feature = "python")] +use pyo3::types::{PyAnyMethods, PyDict, PyListMethods}; +#[cfg(feature = "python")] +use pyo3::{IntoPyObject, PyResult, pyclass, pymethods}; +#[cfg(feature = "python")] +use pyo3::{PyObject, Python}; +#[cfg(feature = "python")] +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; + +#[cfg_attr(feature = "python", gen_stub_pyclass)] +#[cfg_attr(feature = "python", pyclass(module = "czsc._native"))] +pub struct BarGenerator { + market: Market, + /// 基准周期K线 + base_freq: Freq, + /// 最大K线数量限制 + max_count: usize, + /// 所有周期的K线数据,key是周期字符串,value是K线列表 + pub freq_bars: BTreeMap>>, +} + +impl Clone for BarGenerator { + fn clone(&self) -> Self { + let freq_bars = self + .freq_bars + .iter() + .map(|(freq, bars_lock)| { + let bars = bars_lock.read().clone(); + (*freq, RwLock::new(bars)) + }) + .collect(); + Self { + market: self.market, + base_freq: self.base_freq, + max_count: self.max_count, + freq_bars, + } + } +} + +impl BarGenerator { + pub fn new( + base_freq: Freq, + freqs: Vec, + max_count: usize, + market: Market, + ) -> Result { + let bars = freqs + .into_iter() + .chain(std::iter::once(base_freq)) + .map(|f| (f, RwLock::new(VecDeque::with_capacity(max_count)))) + .collect(); + + let bg = BarGenerator { + market, + base_freq, + max_count, + freq_bars: bars, + }; + + Ok(bg) + } + + /// 初始化某个周期的K线序列 + /// + /// # 函数计算逻辑 + /// + /// 1. 检查输入的`freq`是否存在于`self.freq_bars`的键中。如果不存在,返回错误。 + /// 2. 检查`self.freq_bars[freq]`是否为空。如果不为空,返回错误,表示不允许重复初始化。 + /// 3. 如果以上检查都通过,将输入的`bars`存储到`self.freq_bars[freq]`中。 + /// 4. 从`bars`中获取最后一根K线的交易标的代码,更新`self.symbol`。 + /// + /// # Arguments + /// + /// * `freq` - 周期名称 + /// * `bars` - K线序列 + /// + /// # Returns + /// + /// * `Ok(())` - 初始化成功 + /// * `Err(String)` - 包含错误信息的字符串 + pub fn init_freq_with_bars(&mut self, freq: Freq, bars: I) -> Result<(), UtilsError> + where + I: IntoIterator, + { + if !self.freq_bars.contains_key(&freq) { + czsc_bail!("周期 {} 不在self.bars", freq); + } + + if let Some(existing_bars) = self.freq_bars.get(&freq) + && !existing_bars.read().is_empty() + { + czsc_bail!("self.bars['{}'] 不为空,不允许执行初始化", freq); + } + + let bars = bars + .into_iter() + .enumerate() + .map(|(id, mut bar)| { + bar.id = id as i32; + bar + }) + .collect(); + + self.freq_bars.insert(freq, RwLock::new(bars)); + Ok(()) + } + + /// 更新指定周期K线 + /// + /// # 函数计算逻辑 + /// + /// 1. 计算目标周期的结束时间`freq_edt` + /// 2. 检查`self.bars`中是否已经有目标周期的K线: + /// - 如果没有,创建一个新的`RawBar`对象并添加到`self.bars`中,然后返回 + /// 3. 如果已有K线,获取最后一根K线`last` + /// 4. 检查`freq_edt`与最后一根K线的日期时间的关系: + /// - 如果不相等,创建新的`RawBar`对象并添加到序列末尾 + /// - 如果相等,创建新的`RawBar`对象并更新最后一根K线,其中: + /// * 开盘价使用最后一根K线的开盘价 + /// * 收盘价使用当前K线的收盘价 + /// * 最高价取最后一根K线和当前K线的最高价的最大值 + /// * 最低价取最后一根K线和当前K线的最低价的最小值 + /// * 成交量和成交金额为两根K线的累加值 + /// + /// # Arguments + /// + /// * `bar` - 基础周期已完成K线的引用 + /// * `freq` - 目标周期的引用 + /// + fn update_freq( + &self, + bar: &RawBar, + freq: Freq, + mut bars: RwLockWriteGuard<'_, VecDeque>, + ) -> Result<(), UtilsError> { + // 1. 计算目标周期的结束时间 + let freq_edt = freq_end_time(bar.dt, freq, self.market)?; + + // 如果是第一根K线 + if bars.is_empty() { + let new_bar = RawBarBuilder::default() + .symbol(bar.symbol.clone()) + .id(0) + .dt(freq_edt) + .freq(freq) + .open(bar.open) + .close(bar.close) + .high(bar.high) + .low(bar.low) + .vol(bar.vol) + .amount(bar.amount) + .build() + .context("Failed to create the first rawbar")?; + // 限制K线数量 + // 如果超出最大容量,先移除最旧的元素 + if bars.len() == self.max_count { + bars.pop_front(); + } + bars.push_back(new_bar); + return Ok(()); + } + + // 3. 获取最后一根K线的引用 + let last = bars.back().unwrap(); + + // 4. 创建新的K线 + let new_bar = if freq_edt != last.dt { + // 如果时间不同,创建新的K线 + RawBarBuilder::default() + .symbol(bar.symbol.clone()) + .id(last.id + 1) + .dt(freq_edt) + .freq(freq) + .open(bar.open) + .close(bar.close) + .high(bar.high) + .low(bar.low) + .vol(bar.vol) + .amount(bar.amount) + .build() + .context("Failed to create a new rawbar")? + } else { + // 如果时间相同,更新现有K线 + RawBarBuilder::default() + .symbol(bar.symbol.clone()) + .id(last.id) + .dt(freq_edt) + .freq(freq) + // 保持原有开盘价 + .open(last.open) + // 更新收盘价 + .close(bar.close) + // 取最大值 + .high(last.high.max(bar.high)) + // 取最小值 + .low(last.low.min(bar.low)) + // 累加成交量 + .vol(last.vol + bar.vol) + // 累加成交额 + .amount(last.amount + bar.amount) + .build() + .context("Failed to create a new rawbar")? + }; + + // 更新或添加K线 + if freq_edt != last.dt { + // 限制K线数量 + // 如果超出最大容量,先移除最旧的元素 + if bars.len() == self.max_count { + bars.pop_front(); + } + bars.push_back(new_bar); + } else { + let last_index = bars.len() - 1; + bars[last_index] = new_bar; + } + + Ok(()) + } + + /// 获取最新K线日期 + pub fn latest_date(&self) -> Option> { + self.freq_bars + .values() + .next() + .and_then(|v| v.read().back().cloned()) + .map(|b| b.dt) + } + + /// 获取所属品种 + pub fn symbol(&self) -> Option { + self.freq_bars + .values() + .next() + .and_then(|v| v.read().back().cloned()) + .map(|b| b.symbol) + } + + /// 更新各周期K线 + /// + /// # 函数计算逻辑 + /// + /// 1. 获取基准周期`base_freq`,并验证输入`bar`的周期值是否与之匹配 + /// 2. 更新`self.symbol`和`self.end_dt`为当前K线的对应值 + /// 3. 检查重复性: + /// - 检查`self.bars[base_freq]`中是否已存在相同时间的K线 + /// - 如果存在重复K线,返回错误,不进行更新 + /// 4. 如果无重复,遍历所有周期: + /// - 对每个周期调用`update_freq`方法更新K线数据 + /// 5. 维护数据量: + /// - 遍历所有周期的K线数据 + /// - 确保每个周期的K线数量不超过`max_count` + /// - 如果超过限制,保留最新的`max_count`条数据 + /// + /// # Arguments + /// + /// * `bar` - 已完成的基准周期K线的引用 + /// + /// # Returns + /// + /// * `Ok(())` - 更新成功 + /// * `Err(String)` - 包含错误信息的字符串 + pub fn update_bar(&self, bar: &RawBar) -> Result<(), UtilsError> { + // 1. 验证基准周期是否匹配 + if bar.freq != self.base_freq { + czsc_bail!( + "输入周期和基准周期不匹配. Expected {}, got {}", + self.base_freq, + bar.freq.to_string() + ); + } + + // 3. 检查是否存在重复的K线 + if let Some(base_bars) = self.freq_bars.get(&self.base_freq) + && let Some(last_bar) = base_bars.read().back() + && last_bar.dt == bar.dt + { + return Ok(()); + } + + for (freq, bars) in self.freq_bars.iter() { + // 更新每个周期的K线 + self.update_freq(bar, *freq, bars.write())?; + } + + Ok(()) + } +} + +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", gen_stub_pymethods)] +#[cfg_attr(feature = "python", pymethods)] +impl BarGenerator { + #[new] + #[pyo3(signature = (base_freq, freqs, max_count = 2000, market = None))] + fn new_py( + base_freq: PyObject, + freqs: PyObject, + max_count: usize, + market: Option, + ) -> PyResult { + use std::str::FromStr; + + Python::with_gil(|py| { + // 转换base_freq - 支持字符串和枚举 + let base_freq = + if let Ok(py_str) = base_freq.downcast_bound::(py) { + let py_str = py_str.to_string(); + Freq::from_str(&py_str).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!("解析base_freq失败: {e}")) + })? + } else if let Ok(freq) = base_freq.extract::(py) { + freq + } else { + return Err(pyo3::exceptions::PyValueError::new_err( + "base_freq必须是字符串或Freq枚举", + )); + }; + + // 转换freqs - 支持字符串列表和枚举列表 + let freqs_list = freqs + .downcast_bound::(py) + .map_err(|_| pyo3::exceptions::PyValueError::new_err("freqs必须是列表"))?; + + let mut converted_freqs = Vec::new(); + for freq_item in freqs_list.iter() { + let freq = if let Ok(py_str) = freq_item.downcast::() { + let py_str = py_str.to_string(); + Freq::from_str(&py_str).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!("解析freqs失败: {e}")) + })? + } else if let Ok(freq) = freq_item.extract::() { + freq + } else { + return Err(pyo3::exceptions::PyValueError::new_err( + "freqs中的每个元素必须是字符串或Freq枚举", + )); + }; + converted_freqs.push(freq); + } + + // 转换market - 支持字符串、枚举和None(默认为A股) + let market = if let Some(market_obj) = market { + if let Ok(py_str) = market_obj.downcast_bound::(py) { + let py_str = py_str.to_string(); + Market::from_str(&py_str).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!("解析market失败: {e}")) + })? + } else if let Ok(market) = market_obj.extract::(py) { + market + } else { + return Err(pyo3::exceptions::PyValueError::new_err( + "market必须是字符串或Market枚举", + )); + } + } else { + Market::Default // 默认为默认市场,与历史版本保持一致 + }; + + let bg = Self::new(base_freq, converted_freqs, max_count, market)?; + Ok(bg) + }) + } + + /// 初始化某个周期的K线序列 + /// + /// # 函数计算逻辑 + /// + /// 1. 检查输入的`freq`是否存在于`self.freq_bars`的键中。如果不存在,返回错误。 + /// 2. 检查`self.freq_bars[freq]`是否为空。如果不为空,返回错误,表示不允许重复初始化。 + /// 3. 如果以上检查都通过,将输入的`bars`存储到`self.freq_bars[freq]`中。 + /// 4. 从`bars`中获取最后一根K线的交易标的代码,更新`self.symbol`。 + /// + /// # Arguments + /// + /// * `freq` - 周期名称 (支持字符串或Freq枚举) + /// * `bars` - K线序列 + fn init_freq_bars(&mut self, freq: PyObject, bars: Vec) -> PyResult<()> { + use std::str::FromStr; + + Python::with_gil(|py| { + // 转换freq - 支持字符串和枚举 + let freq = if let Ok(py_str) = freq.downcast_bound::(py) { + let py_str = py_str.to_string(); + Freq::from_str(&py_str).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!("解析freq失败: {e}")) + })? + } else if let Ok(freq) = freq.extract::(py) { + freq + } else { + return Err(pyo3::exceptions::PyValueError::new_err( + "freq必须是字符串或Freq枚举", + )); + }; + + self.init_freq_with_bars(freq, bars.into_iter())?; + Ok(()) + }) + } + + /// 获取最新K线日期 + pub fn get_latest_date(&self) -> Option { + self.latest_date().map(|dt| dt.to_string()) + } + + /// 获取所属品种 - Python 属性 + #[getter] + #[pyo3(name = "symbol")] + fn get_symbol_py(&self) -> Option { + self.freq_bars + .values() + .next() + .and_then(|v| v.read().back().cloned()) + .map(|b| b.symbol.to_string()) + } + + /// 获取基准频率 + #[getter] + fn base_freq(&self) -> String { + match self.base_freq { + Freq::F1 => "1分钟", + Freq::F2 => "2分钟", + Freq::F3 => "3分钟", + Freq::F4 => "4分钟", + Freq::F5 => "5分钟", + Freq::F6 => "6分钟", + Freq::F10 => "10分钟", + Freq::F12 => "12分钟", + Freq::F15 => "15分钟", + Freq::F20 => "20分钟", + Freq::F30 => "30分钟", + Freq::F60 => "60分钟", + Freq::F120 => "120分钟", + Freq::F240 => "240分钟", + Freq::F360 => "360分钟", + Freq::D => "日线", + Freq::W => "周线", + Freq::M => "月线", + Freq::S => "季线", + Freq::Y => "年线", + Freq::Tick => "Tick", + } + .to_string() + } + + /// 获取end_dt属性(Python兼容) + #[getter] + fn end_dt(&self, py: Python) -> PyResult> { + match self.latest_date() { + Some(dt) => { + let timestamp = czsc_core::utils::common::create_naive_pandas_timestamp(py, dt)?; + Ok(Some(timestamp)) + } + None => Ok(None), + } + } + + /// 获取各周期K线数据 - 返回字典,键为频率字符串,值为K线列表 + #[getter] + fn bars(&self, py: Python) -> PyResult { + let dict = PyDict::new(py); + + // 遍历 BarGenerator 的所有周期数据 + for (freq, bars_lock) in &self.freq_bars { + let bars = bars_lock.read(); + let freq_str = match freq { + Freq::Tick => "Tick", + Freq::F1 => "1分钟", + Freq::F2 => "2分钟", + Freq::F3 => "3分钟", + Freq::F4 => "4分钟", + Freq::F5 => "5分钟", + Freq::F6 => "6分钟", + Freq::F10 => "10分钟", + Freq::F12 => "12分钟", + Freq::F15 => "15分钟", + Freq::F20 => "20分钟", + Freq::F30 => "30分钟", + Freq::F60 => "60分钟", + Freq::F120 => "120分钟", + Freq::F240 => "240分钟", + Freq::F360 => "360分钟", + Freq::D => "日线", + Freq::W => "周线", + Freq::M => "月线", + Freq::S => "季线", + Freq::Y => "年线", + }; + + // 将 RawBar 转换为 PyRawBar + let py_bars: Vec = bars.iter().cloned().collect(); + + dict.set_item(freq_str, py_bars)?; + } + + Ok(dict.into()) + } + + /// 从Python RawBar对象更新K线数据 - 支持直接自动转换 + #[pyo3(signature = (bar))] + fn update(&self, bar: &RawBar) -> PyResult<()> { + self.update_bar(bar) + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string())) + } + + /// 支持 pickle 序列化 - 使用 __reduce__ 方法 + fn __reduce__(&self, py: Python) -> PyResult { + // 构造函数参数 + let freqs: Vec = self + .freq_bars + .keys() + .filter(|&freq| *freq != self.base_freq) // 排除基准频率 + .map(|freq| freq.to_string()) + .collect(); + + let args = (self.base_freq.to_string(), freqs, self.max_count).into_pyobject(py)?; + + // 状态数据 + let state = PyDict::new(py); + state.set_item("market", self.market.to_string())?; + + // 保存所有周期的K线数据 + let freq_bars_dict = PyDict::new(py); + for (freq, bars_lock) in &self.freq_bars { + let bars = bars_lock.read(); + let bars_list: Vec<_> = bars.iter().cloned().collect(); + freq_bars_dict.set_item(freq.to_string(), bars_list)?; + } + state.set_item("freq_bars", freq_bars_dict)?; + + // 返回 (constructor, args, state) + let constructor = py.get_type::(); + let result = (constructor, args, state).into_pyobject(py)?; + Ok(result.into()) + } + + /// 支持 pickle 反序列化 + fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { + use std::str::FromStr; + + let state_dict = state.downcast_bound::(py)?; + + // 恢复市场属性 + if let Some(market_item) = state_dict.get_item("market")? { + let market_str: String = market_item.extract()?; + self.market = Market::from_str(&market_str).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!("Failed to parse market: {e}")) + })?; + } + + // 恢复K线数据 + if let Some(freq_bars_item) = state_dict.get_item("freq_bars")? { + let freq_bars_dict = freq_bars_item.downcast::()?; + self.freq_bars.clear(); + + for (freq_str, bars_obj) in freq_bars_dict.iter() { + let freq_str: String = freq_str.extract()?; + let freq = Freq::from_str(&freq_str).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!("Failed to parse freq: {e}")) + })?; + + let bars_list: Vec = bars_obj.extract()?; + let bars_deque: VecDeque = bars_list.into_iter().collect(); + self.freq_bars.insert(freq, RwLock::new(bars_deque)); + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use chrono::{NaiveDateTime, TimeZone}; + + use super::*; + use std::sync::Arc; + + #[test] + fn test_init_freq_bars() { + let mut bg = + BarGenerator::new(Freq::F1, vec![Freq::F5, Freq::F15], 5, Market::Default).unwrap(); + + // 创建测试用的K线数据 + let test_bars = vec![ + RawBarBuilder::default() + .symbol("000016.SH".to_string()) + .id(1) + .dt(Utc.from_utc_datetime( + &NaiveDateTime::parse_from_str("2024-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap(), + )) + .freq(Freq::F5) + .open(4000.0) + .close(4010.0) + .high(4020.0) + .low(3990.0) + .vol(1000.0) + .amount(4000.0) + .build() + .unwrap(), + RawBarBuilder::default() + .symbol("000016.SH".to_string()) + .id(2) + .dt(Utc.from_utc_datetime( + &NaiveDateTime::parse_from_str("2024-1-1 0:6:0", "%Y-%m-%d %H:%M:%S").unwrap(), + )) + .freq(Freq::F5) + .open(4010.0) + .close(4020.0) + .high(4030.0) + .low(4000.0) + .vol(1200.0) + .amount(4800.0) + .build() + .unwrap(), + ]; + + // 成功初始化 + let result = bg.init_freq_with_bars(Freq::F5, test_bars.clone()); + assert!(result.is_ok()); + + // 验证数据是否正确存储 + let bars = bg.freq_bars.get(&Freq::F5).unwrap().read(); + assert_eq!(bars.len(), 2); + assert_eq!(bars[0].open, 4000.0); + assert_eq!(bars[1].close, 4020.0); + + // 重复初始化错误 + drop(bars); + let result = bg.init_freq_with_bars(Freq::F5, test_bars); + assert!(result.is_err()); + + // 空K线列表初始化 + let result = bg.init_freq_with_bars(Freq::F15, vec![]); + assert!(result.is_ok()); + let bars = bg.freq_bars.get(&Freq::F15).unwrap().read(); + assert!(bars.is_empty()); + } + + #[test] + fn test_update_freq_new_bar() { + let bg = BarGenerator::new(Freq::F1, vec![Freq::F5, Freq::F15], 5, Market::AShare).unwrap(); + + // 第一根K线的更新 + let bar1 = RawBarBuilder::default() + .symbol("000016.SH".to_string()) + .id(1) + .dt(Utc.from_utc_datetime( + &NaiveDateTime::parse_from_str("2024-1-1 2:1:0", "%Y-%m-%d %H:%M:%S").unwrap(), + )) + .freq(Freq::F1) + .open(4000.0) + .close(4010.0) + .high(4020.0) + .low(3990.0) + .vol(1000.0) + .amount(4000.0) + .build() + .unwrap(); + + let result = bg.update_bar(&bar1); + assert!(result.is_ok()); + + let five_min_bars = bg.freq_bars.get(&Freq::F5).unwrap().read(); + assert_eq!(five_min_bars.len(), 1, "K线柱数量应该为1"); + assert_eq!(five_min_bars[0].open, 4000.0, "开盘价应该为4000.0"); + assert_eq!(five_min_bars[0].close, 4010.0, "收盘价应该为4010.0"); + assert_eq!(five_min_bars[0].high, 4020.0, "最高价应该为4020.0"); + assert_eq!(five_min_bars[0].low, 3990.0, "最低价应该为3990.0"); + assert_eq!(five_min_bars[0].vol, 1000.0, "成交量应该为1000.0"); + } + + #[test] + fn test_update_freq_same_bar() { + let bg = BarGenerator::new(Freq::F1, vec![Freq::F5, Freq::F15], 5, Market::AShare).unwrap(); + + let bar1 = RawBarBuilder::default() + .symbol("000016.SH".to_string()) + .id(1) + .dt(Utc.from_utc_datetime( + &NaiveDateTime::parse_from_str("2024-1-1 2:1:0", "%Y-%m-%d %H:%M:%S").unwrap(), + )) + .freq(Freq::F1) + .open(4000.0) + .close(4010.0) + .high(4020.0) + .low(3990.0) + .vol(10.0) + .amount(40.0) + .build() + .unwrap(); + bg.update_bar(&bar1).unwrap(); + + // 在同一个5分钟周期内更新 + let bar2 = RawBarBuilder::default() + .symbol("000016.SH".to_string()) + .id(1) + .dt(Utc.from_utc_datetime( + &NaiveDateTime::parse_from_str("2024-1-1 2:2:0", "%Y-%m-%d %H:%M:%S").unwrap(), + )) + .freq(Freq::F1) + .open(4006.0) + .high(4020.0) + .low(4000.0) + .close(4015.0) + .vol(15.0) + .amount(60.0) + .build() + .unwrap(); + + bg.update_bar(&bar2).unwrap(); + + let five_min_bars = bg.freq_bars.get(&Freq::F5).unwrap().read(); + assert_eq!(five_min_bars.len(), 1); + assert_eq!(five_min_bars[0].open, 4000.0, "开盘价应该保持不变"); + assert_eq!(five_min_bars[0].close, 4015.0, "收盘价应该更新"); + assert_eq!(five_min_bars[0].high, 4020.0, "最高价应该取两者的最大值"); + assert_eq!(five_min_bars[0].low, 3990.0, "最低价应该取两者的最小值"); + assert_eq!(five_min_bars[0].vol, 25.0, "成交量应该累加"); + assert_eq!(five_min_bars[0].amount, 100.0, "检查成交额应该累加"); + } + + #[test] + fn test_update_freq_new_period() { + let bg = BarGenerator::new(Freq::F1, vec![Freq::F5, Freq::F15], 5, Market::AShare).unwrap(); + + let bar1 = RawBarBuilder::default() + .symbol("000016.SH".to_string()) + .id(1) + .dt(Utc.from_utc_datetime( + &NaiveDateTime::parse_from_str("2024-1-2 09:31:00", "%Y-%m-%d %H:%M:%S").unwrap(), + )) + .freq(Freq::F1) + .open(4000.0) + .high(4010.0) + .low(3990.0) + .close(4005.0) + .vol(10.0) + .amount(40.0) + .build() + .unwrap(); + + bg.update_bar(&bar1).unwrap(); + + // 新的5分钟周期(09:36 属于下一个 F5 周期) + let bar2 = RawBarBuilder::default() + .symbol("000016.SH".to_string()) + .id(1) + .dt(Utc.from_utc_datetime( + &NaiveDateTime::parse_from_str("2024-1-2 09:36:00", "%Y-%m-%d %H:%M:%S").unwrap(), + )) + .freq(Freq::F1) + .open(4010.0) + .high(4030.0) + .low(4000.0) + .close(4020.0) + .vol(20.0) + .amount(60.0) + .build() + .unwrap(); + bg.update_bar(&bar2).unwrap(); + + let five_min_bars = bg.freq_bars.get(&Freq::F5).unwrap().read(); + assert_eq!(five_min_bars.len(), 2, "K线柱数量应该为2"); + // 检查新K线的数据 + assert_eq!(five_min_bars[1].open, 4010.0, "开盘价应该保持不变"); + assert_eq!(five_min_bars[1].close, 4020.0, "收盘价应该为4020.0"); + assert_eq!(five_min_bars[1].high, 4030.0, "最高价应该为4030.0"); + assert_eq!(five_min_bars[1].low, 4000.0, "最低价应该为4000.0"); + assert_eq!(five_min_bars[1].vol, 20.0, "成交量应该为20.0"); + assert_eq!(five_min_bars[1].amount, 60.0, "成交额应该为60.0"); + } + + #[test] + fn test_update_freq_edge_cases() { + let bg = BarGenerator::new(Freq::F1, vec![Freq::F5, Freq::F15], 5, Market::AShare).unwrap(); + + let bar1 = RawBarBuilder::default() + .symbol("000016.SH".to_string()) + .id(1) + .dt(Utc.from_utc_datetime( + &NaiveDateTime::parse_from_str("2024-1-1 2:6:0", "%Y-%m-%d %H:%M:%S").unwrap(), + )) + .freq(Freq::F1) + .open(4000.0) + .high(4010.0) + .low(3990.0) + .close(4005.0) + .vol(10.0) + .amount(40.0) + .build() + .unwrap(); + bg.update_bar(&bar1).unwrap(); + + // 跨天的情况 + let bar2 = RawBarBuilder::default() + .symbol("000016.SH".to_string()) + .id(1) + .dt(Utc.from_utc_datetime( + &NaiveDateTime::parse_from_str("2024-1-2 2:6:0", "%Y-%m-%d %H:%M:%S").unwrap(), + )) + .freq(Freq::F1) + .open(4010.0) + .high(4030.0) + .low(4000.0) + .close(4020.0) + .vol(20.0) + .amount(60.0) + .build() + .unwrap(); + bg.update_bar(&bar2).unwrap(); + + let five_min_bars = bg.freq_bars.get(&Freq::F5).unwrap().read(); + assert_eq!(five_min_bars.len(), 2, "K线柱数量应该为2"); + // 检查ID的连续性 + assert_eq!( + five_min_bars[1].id, + five_min_bars[0].id + 1, + "K线柱ID应该连续,第二个K线柱的ID应该比第一个多1" + ); + } + + #[test] + fn test_update() { + // 系统内部以 UTC 存储 CST 交易时间 + let dt = Utc.from_utc_datetime( + &NaiveDateTime::parse_from_str("2024-12-12 10:01:00", "%Y-%m-%d %H:%M:%S").unwrap(), + ); + + // 基本设置:创建一个同时处理1分钟、5分钟和15分钟周期的BarGenerator + let bg = BarGenerator::new(Freq::F1, vec![Freq::F5, Freq::F15], 5, Market::AShare).unwrap(); + + // 测试周期不匹配的情况 + let invalid_freq_bar = RawBarBuilder::default() + .symbol("000016.SH".to_string()) + .id(1) + .dt(dt) + .freq(Freq::F5) + .open(4000.0) + .high(4010.0) + .low(3990.0) + .close(4005.0) + .vol(1000.0) + .amount(4000.0) + .build() + .unwrap(); + + assert!( + bg.update_bar(&invalid_freq_bar).is_err(), + "更新函数应该返回错误,因为传入的K线柱频率无效" + ); + + // 测试正常更新流程 + let bar1 = RawBarBuilder::default() + .symbol("000016.SH".to_string()) + .id(1) + .dt(dt) + .freq(Freq::F1) + .open(4000.0) + .high(4010.0) + .low(3990.0) + .close(4005.0) + .vol(1000.0) + .amount(4000.0) + .build() + .unwrap(); + assert!( + bg.update_bar(&bar1).is_ok(), + "更新函数应该成功处理有效的K线柱" + ); + // 检查更新后的状态 + assert_eq!( + bg.symbol(), + Some(Arc::from("000016.SH".to_string())), + "更新后的符号应该与K线柱的符号匹配" + ); + assert_eq!( + bg.latest_date(), + Some(bar1.dt), + "更新后的结束时间应该与K线柱的时间匹配" + ); + // 测试重复数据的处理 + assert!( + bg.update_bar(&bar1).is_ok(), + "重复数据应该被成功处理,即更新函数应该返回成功,但重复数据不应影响状态" + ); + + // 检查各周期数据是否正确 + assert_eq!( + bg.freq_bars.get(&Freq::F1).unwrap().read().len(), + 1, + "1分钟周期的K线柱数量应该为1" + ); + assert_eq!( + bg.freq_bars.get(&Freq::F5).unwrap().read().len(), + 1, + "5分钟周期的K线柱数量应该为1" + ); + assert_eq!( + bg.freq_bars.get(&Freq::F15).unwrap().read().len(), + 1, + "15分钟周期的K线柱数量应该为1" + ); + + // 测试数据量限制 + // 添加足够多的数据以触发max_count限制 + for i in 2..8 { + let bar = RawBarBuilder::default() + .symbol(Arc::from("000016.SH".to_string())) + .id(i) + .dt(Utc.from_utc_datetime( + &NaiveDateTime::parse_from_str( + format!("2024-1-2 9:3{i}:0").as_str(), + "%Y-%m-%d %H:%M:%S", + ) + .unwrap(), + )) + .freq(Freq::F1) + .open(4000.0 + i as f64) + .high(4010.0 + i as f64) + .low(3990.0 + i as f64) + .close(4005.0 + i as f64) + .vol(1000.0) + .amount(4000.0) + .build() + .unwrap(); + + assert!(bg.update_bar(&bar).is_ok(), "更新数据失败"); + } + + // 验证数据量限制是否生效(max_count = 5) + for (_, bars) in bg.freq_bars.iter() { + assert!(bars.read().len() <= 5, "K线数量超过了max_count限制"); + } + + // 测试跨周期数据的正确性 + let bars_5min = bg.freq_bars.get(&Freq::F5).unwrap().read(); + let last_5min = bars_5min.back().unwrap(); + assert_eq!( + last_5min.freq, + Freq::F5, + "最后一个5分钟周期的K线柱频率应该为F5" + ); + + let bars_15min = bg.freq_bars.get(&Freq::F15).unwrap().read(); + let last_15min = bars_15min.back().unwrap(); + assert_eq!( + last_15min.freq, + Freq::F15, + "最后一个15分钟周期的K线柱频率应该为F15" + ); + + // 测试不同市场时间 + + let bg_other_market = + BarGenerator::new(Freq::F1, vec![Freq::F15], 5, Market::Futures).unwrap(); + + let bar_futures = RawBarBuilder::default() + .symbol("IF2403".to_string()) + .id(1) + .dt(dt) + .freq(Freq::F1) + .open(5180.0) + .high(5185.0) + .low(5178.0) + .close(5182.0) + .vol(100.0) + .amount(518200.0) + .build() + .unwrap(); + + assert!( + bg_other_market.update_bar(&bar_futures).is_ok(), + "更新其他市场数据失败" + ); + + // 断言其他市场的符号是否正确更新 + assert_eq!( + bg_other_market.symbol(), + Some(Arc::from("IF2403".to_string())), + "市场符号应该更新为IF2403" + ); + } +} diff --git a/crates/czsc-utils/src/errors.rs b/crates/czsc-utils/src/errors.rs new file mode 100644 index 000000000..cf1cf251a --- /dev/null +++ b/crates/czsc-utils/src/errors.rs @@ -0,0 +1,41 @@ +use chrono::NaiveDate; +use error_macros::CZSCErrorDerive; +use error_support::expand_error_chain; +use polars::error::PolarsError; +use thiserror::Error; + +#[cfg(feature = "python")] +use pyo3::{PyErr, exceptions::PyException}; + +#[derive(Debug, Error, CZSCErrorDerive)] +pub enum UtilsError { + // #[error("Object: {0}")] + // Object(#[from] ObjectError), + + // #[error("Expected a value, but got None")] + // NoneValue, + #[error("Polars: {0}")] + Polars(#[from] PolarsError), + + #[error("Invalid datetime")] + InvalidDateTime, + + #[error("Invalid datetime in freq_end_date: {0}")] + InvalidFreqEndDate(String), + + #[error("Return should not be empty!")] + ReturnsEmpty, + + #[error("No weights available before: {0}!")] + NoWeightsAvail(NaiveDate), + + #[error("{}", expand_error_chain(.0))] + Unexpected(anyhow::Error), +} + +#[cfg(feature = "python")] +impl From for PyErr { + fn from(e: UtilsError) -> Self { + PyException::new_err(e.to_string()) + } +} diff --git a/crates/czsc-utils/src/freq_data.rs b/crates/czsc-utils/src/freq_data.rs new file mode 100644 index 000000000..dec75c4cc --- /dev/null +++ b/crates/czsc-utils/src/freq_data.rs @@ -0,0 +1,503 @@ +use chrono::{DateTime, Datelike, Duration, NaiveDate, NaiveTime, Timelike, Utc}; +use czsc_core::objects::{bar::RawBar, freq::Freq, market::Market}; +use error_support::czsc_bail; +use hashbrown::HashMap; +use once_cell::sync::Lazy; +use polars::{frame::DataFrame, io::SerReader, prelude::IpcReader}; +use std::{io::Cursor, str::FromStr}; + +use crate::errors::UtilsError; + +static MINUTES_SPLIT_DF: Lazy = Lazy::new(|| { + // 将文件包含在二进制文件中 + const MINUTES_SPLIT_BYTES: &[u8] = include_bytes!("../data/minutes_split.feather"); + let cursor = Cursor::new(MINUTES_SPLIT_BYTES); + IpcReader::new(cursor) + .finish() + .expect("failed to read minutes_split.feather") +}); + +static FREQ_EDT_MAP: Lazy>> = + Lazy::new(|| { + let mut result: HashMap<(Market, Freq), HashMap> = HashMap::new(); + + let format = "%H:%M"; + + // 按照market分组 + let groups = MINUTES_SPLIT_DF + .partition_by(["market"], true) + .expect("failed tp groupby markets"); + + for g in groups { + let market_type = g + .column("market") + .expect("failed to get market col") + .str() + .expect("failed to convert market col into str") + .get(0) + .expect("failed to get the first row for market col"); + let market_type = Market::from_str(market_type).expect("unregistered market type"); + + // 遍历所有包含 "分钟" 的列名 + for minute in MINUTES_SPLIT_DF + .get_column_names() + .iter() + .filter(|&col| col.contains("分钟")) + { + let time_col = g + .column("time") + .expect("failed to get time col") + .str() + .expect("failed to convert time col into str"); + let freq_of_time_col = g + .column(minute) + .expect("failed to get minute col") + .str() + .expect("failed to convert minute col into str"); + + let mut time_map = HashMap::new(); + + // 使用下标遍历确保不会遗漏数据 + for idx in 0..g.height() { + let time = time_col.get(idx).expect("failed to get idx of time col"); + let freq_of_time = freq_of_time_col + .get(idx) + .expect("failed to get idx of minute col"); + + let time = + NaiveTime::parse_from_str(time, format).expect("failed to parse time str"); + let freq_of_time = NaiveTime::parse_from_str(freq_of_time, format) + .expect("failed to parse time str"); + time_map.insert(time, freq_of_time); + } + + let minute_freq = Freq::from_str(minute).expect("unregistered freq"); + + result.insert((market_type, minute_freq), time_map); + } + } + result + }); + +fn freq_market_times(freq: Freq, market: Market) -> Option> { + let time_map = FREQ_EDT_MAP.get(&(market, freq))?; + let mut times: Vec = time_map.keys().cloned().collect(); + times.sort(); + Some(times) +} + +/// 依据分钟级 bars 的时间轴推断市场类型,对齐 Python `check_freq_and_market`。 +/// +/// 当显式 market 与 bars 的交易时间不匹配时,Python 会回退到 `默认`; +/// Rust 执行引擎也应按同一口径处理,否则会把基础周期错误地重采样到别的时间轴。 +pub fn infer_market_from_bars(bars: &[RawBar], freq: Freq) -> Market { + if !freq.is_minute_freq() { + return Market::Default; + } + + let mut time_seq: Vec = bars.iter().rev().take(2000).map(|b| b.dt.time()).collect(); + time_seq.sort(); + time_seq.dedup(); + if time_seq.len() < 2 { + return Market::Default; + } + + let min_time = *time_seq.first().unwrap(); + let max_time = *time_seq.last().unwrap(); + for market in [Market::AShare, Market::Futures, Market::Default] { + let Some(times) = freq_market_times(freq, market) else { + continue; + }; + let sub_times: Vec = times + .into_iter() + .filter(|t| *t >= min_time && *t <= max_time) + .collect(); + if sub_times == time_seq { + return market; + } + } + + Market::Default +} + +// #[allow(unused)] +// static FREQ_MARKET_TIMES: Lazy>> = Lazy::new(|| { +// let mut result: HashMap> = HashMap::new(); + +// // 按照market分组 +// let groups = MINUTES_SPLIT_DF +// .partition_by(["market"], true) +// .expect("failed tp groupby markets"); + +// for g in groups { +// let market_type = g +// .column("market") +// .expect("failed to get market col") +// .str() +// .expect("failed to convert market col into str") +// .get(0) +// .expect("failed to get the first row for market col"); + +// // 遍历所有包含 "分钟" 的列名 +// for minute in MINUTES_SPLIT_DF +// .get_column_names() +// .iter() +// .filter(|&col| col.contains("分钟")) +// { +// let mut v: Vec<_> = MINUTES_SPLIT_DF +// .column(&minute) +// .expect("failed to get minute col") +// .str() +// .expect("failed to convert minute col into str") +// .into_iter() +// .flatten() +// .map(String::from) +// .collect(); +// // 去除连续的重复元素 +// v.dedup(); + +// let key = format!("{}_{}", minute, market_type); +// result.insert(key, v); +// } +// } + +// result +// }); + +/// 计算目标周期的结束时间(仅日期) +fn freq_end_date(dt: NaiveDate, freq: Freq) -> Result { + match freq { + Freq::D => Ok(dt), + Freq::W => { + // ISO weekday: 星期一是1, 星期日是7 + let weekday = dt.weekday().number_from_monday(); + // 计算到周五的天数 + let days_to_add = if weekday <= 5 { + 5 - weekday + } else { + // 周末到下个周五:7 - (weekday - 5) = 7 - weekday + 5 = 12 - weekday + 12 - weekday + }; + Ok(dt + Duration::days(days_to_add as i64)) + } + Freq::Y => { + // 设置为12月31日 + NaiveDate::from_ymd_opt(dt.year(), 12, 31).ok_or_else(|| { + UtilsError::InvalidFreqEndDate(format!("Y freq: year={}", dt.year())) + }) + } + Freq::M => { + // 性能优化:直接计算月末日期,避免创建不必要的DateTime对象 + let year = dt.year(); + let month = dt.month(); + + // 计算下个月的第一天,然后减去一天得到本月最后一天 + let (next_year, next_month) = if month == 12 { + (year + 1, 1) + } else { + (year, month + 1) + }; + + // 直接创建下个月第一天的日期,然后减去一天 + NaiveDate::from_ymd_opt(next_year, next_month, 1) + .ok_or_else(|| { + UtilsError::InvalidFreqEndDate(format!( + "M freq: next_year={next_year}, next_month={next_month}" + )) + })? + .pred_opt() + .ok_or_else(|| { + UtilsError::InvalidFreqEndDate("M freq: failed to get previous day".to_string()) + }) + } + Freq::S => { + // 性能优化:直接计算季度末日期,避免创建不必要的DateTime对象 + let year = dt.year(); + let month = dt.month(); + + // 确定下个季度的第一天 + let (next_quarter_year, next_quarter_month) = match month { + 1..=3 => (year, 4), // Q1 -> Q2 starts in April + 4..=6 => (year, 7), // Q2 -> Q3 starts in July + 7..=9 => (year, 10), // Q3 -> Q4 starts in October + 10..=12 => (year + 1, 1), // Q4 -> Q1 next year starts in January + _ => unreachable!(), + }; + + // 直接创建下个季度第一天,然后减去一天得到本季度最后一天 + NaiveDate::from_ymd_opt(next_quarter_year, next_quarter_month, 1) + .ok_or_else(|| { + UtilsError::InvalidFreqEndDate(format!( + "S freq: next_quarter_year={next_quarter_year}, next_quarter_month={next_quarter_month}" + )) + })? + .pred_opt() + .ok_or_else(|| { + UtilsError::InvalidFreqEndDate("S freq: failed to get previous day".to_string()) + }) + } + // 对于其他周期,直接返回输入日期 + _ => Ok(dt), + } +} + +/// 计算目标周期的结束时间 +pub fn freq_end_time( + dt: DateTime, + freq: Freq, + market: Market, +) -> Result, UtilsError> { + // 如果秒>0 找下1分钟,但要确保不超出有效时间范围 + let dt = if dt.second() > 0 || dt.nanosecond() > 0 { + dt.with_second(0).unwrap().with_nanosecond(0).unwrap() + Duration::minutes(1) + } else { + dt + }; + + // 获取时间的HH:MM格式,与Python版本保持一致 + let hm_str = dt.format("%H:%M").to_string(); + let utc_time = dt.time(); + + // 如果是分钟周期 + if freq.is_minute_freq() { + if let Some(time_map) = FREQ_EDT_MAP.get(&(market, freq)) + && let Some(end_time) = time_map.get(&utc_time) + { + // 直接使用UTC时间来计算结束时间 + let edt = dt + .with_hour(end_time.hour()) + .ok_or(UtilsError::InvalidDateTime)? + .with_minute(end_time.minute()) + .ok_or(UtilsError::InvalidDateTime)?; + + // 修正跨天逻辑:与Python版本保持一致 + // Python版本:if h == m == 0 and freq != Freq.F1 and hm != "00:00" + if end_time.hour() == 0 + && end_time.minute() == 0 + && freq != Freq::F1 + && hm_str != "00:00" + { + // 特殊情况处理:如果结束时间是 00:00 但输入时间不是 00:00 + return Ok(edt + Duration::days(1)); + } + // 直接返回UTC时间,不需要时区转换 + return Ok(edt); + } + + // 如果直接查找失败,尝试用字符串解析的方式查找 + if let Some(time_map) = FREQ_EDT_MAP.get(&(market, freq)) { + // 尝试用 HH:MM 字符串格式查找 + if let Ok(parsed_time) = NaiveTime::parse_from_str(&hm_str, "%H:%M") + && let Some(end_time) = time_map.get(&parsed_time) + { + let edt = dt + .with_hour(end_time.hour()) + .ok_or(UtilsError::InvalidDateTime)? + .with_minute(end_time.minute()) + .ok_or(UtilsError::InvalidDateTime)?; + + if end_time.hour() == 0 + && end_time.minute() == 0 + && freq != Freq::F1 + && hm_str != "00:00" + { + return Ok(edt + Duration::days(1)); + } + return Ok(edt); + } + } + + // 对于非交易时间,寻找下一个交易时间 + if let Some(time_map) = FREQ_EDT_MAP.get(&(market, freq)) { + let mut available_times: Vec<_> = time_map.keys().collect(); + available_times.sort(); + + // 寻找下一个交易时间 + if let Ok(current_time) = NaiveTime::parse_from_str(&hm_str, "%H:%M") { + for &next_time in &available_times { + if next_time > ¤t_time + && let Some(end_time) = time_map.get(next_time) + { + let edt = dt + .with_hour(end_time.hour()) + .ok_or(UtilsError::InvalidDateTime)? + .with_minute(end_time.minute()) + .ok_or(UtilsError::InvalidDateTime)?; + + if end_time.hour() == 0 + && end_time.minute() == 0 + && freq != Freq::F1 + && hm_str != "00:00" + { + return Ok(edt + Duration::days(1)); + } + return Ok(edt); + } + } + + // 如果当天没有更晚的交易时间,使用第二天的第一个交易时间 + if let Some(&first_time) = available_times.first() + && let Some(end_time) = time_map.get(first_time) + { + let next_day = dt + Duration::days(1); + let edt = next_day + .with_hour(end_time.hour()) + .ok_or(UtilsError::InvalidDateTime)? + .with_minute(end_time.minute()) + .ok_or(UtilsError::InvalidDateTime)?; + return Ok(edt); + } + } + } + + czsc_bail!( + "无法找到对应的结束时间: 时间={}, 频率={:?}, 市场={:?}", + hm_str, + freq, + market + ) + } + + // 对于非分钟级别的周期 + // 计算出新的结束日期 + let utc_date = freq_end_date(dt.date_naive(), freq)?; + + // Rust 中需要 DateTime 类型,所以统一使用 00:00:00 时间 + // 这样可以确保同一天内的所有基础周期K线都会更新同一根日线 + let edt = utc_date + .and_hms_opt(0, 0, 0) + .ok_or(UtilsError::InvalidDateTime)? + .and_utc(); + + // 直接返回UTC时间,不需要时区转换 + Ok(edt) +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono::{NaiveDateTime, TimeZone}; + + // #[test] + // fn test_datetime() { + // use chrono::{Local, TimeZone}; + // // 2024/12/12 1:31 + // let naive_datetime = NaiveDateTime::new( + // chrono::NaiveDate::from_ymd_opt(2024, 12, 12).unwrap(), // 年月日 + // chrono::NaiveTime::from_hms_opt(1, 31, 0).unwrap(), // 时分秒 + // ); + + // // 2024/12/12 9:31 GMT+08:00 + // let local_datetime = Local.from_utc_datetime(&naive_datetime); + + // assert_eq!( + // local_datetime.naive_utc().to_string().as_str(), + // "2024-12-12 01:31:00" + // ); + // assert_eq!( + // local_datetime.naive_local().to_string().as_str(), + // "2024-12-12 09:31:00" + // ); + // } + + #[test] + fn test_daily_freq_end_time() { + // 测试日线 freq_end_time 是否返回 00:00:00 + + let test_cases = vec![ + ("2025-08-31 23:45:00", "2025-08-31 00:00:00"), + ("2025-09-01 00:00:00", "2025-09-01 00:00:00"), + ("2025-09-01 00:15:00", "2025-09-01 00:00:00"), + ("2025-09-01 00:30:00", "2025-09-01 00:00:00"), + ("2025-09-01 01:00:00", "2025-09-01 00:00:00"), + ("2025-09-01 12:00:00", "2025-09-01 00:00:00"), + ("2025-09-01 23:45:00", "2025-09-01 00:00:00"), + ]; + + for (input_str, expected_str) in test_cases { + let input_dt = Utc.from_utc_datetime( + &NaiveDateTime::parse_from_str(input_str, "%Y-%m-%d %H:%M:%S").unwrap(), + ); + + let expected_dt = Utc.from_utc_datetime( + &NaiveDateTime::parse_from_str(expected_str, "%Y-%m-%d %H:%M:%S").unwrap(), + ); + + let result = freq_end_time(input_dt, Freq::D, Market::AShare).unwrap(); + + assert_eq!( + result, expected_dt, + "\n输入: {input_str}\n期望: {expected_str}\n实际: {result}" + ); + } + + println!("✅ 所有日线 freq_end_time 测试通过"); + } + + trait TestDateTime { + /// 将DateTime格式化为字符串(不带时区后缀) + /// 注:系统内部以 UTC 存储 CST 交易时间 + fn to_dt_str(&self) -> String; + } + + impl TestDateTime for DateTime { + fn to_dt_str(&self) -> String { + self.format("%Y-%m-%d %H:%M:%S").to_string() + } + } + + #[test] + fn test_freq_minute() { + // 系统内部以 UTC 存储 CST 交易时间(10:01 CST → 存为 10:01 UTC) + let dt = Utc.from_utc_datetime( + &NaiveDateTime::parse_from_str("2024-12-12 10:01:00", "%Y-%m-%d %H:%M:%S").unwrap(), + ); + + // 1分钟 + assert_eq!( + freq_end_time(dt, Freq::F1, Market::AShare) + .unwrap() + .to_dt_str(), + "2024-12-12 10:01:00" + ); + + // 5分钟 + assert_eq!( + freq_end_time(dt, Freq::F5, Market::AShare) + .unwrap() + .to_dt_str(), + "2024-12-12 10:05:00" + ); + + // 30分钟 + assert_eq!( + freq_end_time(dt, Freq::F30, Market::AShare) + .unwrap() + .to_dt_str(), + "2024-12-12 10:30:00" + ); + + // 60分钟 + assert_eq!( + freq_end_time(dt, Freq::F60, Market::AShare) + .unwrap() + .to_dt_str(), + "2024-12-12 10:30:00" + ); + } + + /// 非分钟K线的测试(年线) + #[test] + fn test_freq_year() { + let dt = Utc.from_utc_datetime( + &NaiveDateTime::parse_from_str("2024-12-12 10:01:00", "%Y-%m-%d %H:%M:%S").unwrap(), + ); + + let res = freq_end_time(dt, Freq::Y, Market::AShare) + .unwrap() + .to_dt_str(); + + // 非分钟周期的结束时间:日期设为年末(12/31),时间固定为 00:00:00 + assert_eq!(res, "2024-12-31 00:00:00"); + } +} diff --git a/crates/czsc-utils/src/lib.rs b/crates/czsc-utils/src/lib.rs new file mode 100644 index 000000000..667c2b8fe --- /dev/null +++ b/crates/czsc-utils/src/lib.rs @@ -0,0 +1,18 @@ +//! czsc-utils — utilities crate. +//! +//! Phase C scope so far: +//! - [`trading_time`] — `is_trading_time` (czsc-only NEW per design doc §2.5) +//! +//! Pending (deferred until Phase D unlocks `czsc-core`): +//! - `freq_data` — depends on `czsc_core::objects::{RawBar, Freq, Market}` +//! - `bar_generator` — depends on `czsc-core` + +pub mod bar_generator; +pub mod errors; +pub mod freq_data; +pub mod trading_time; + +pub use trading_time::is_trading_time; + +#[cfg(feature = "python")] +pub mod python; diff --git a/crates/czsc-utils/src/python/mod.rs b/crates/czsc-utils/src/python/mod.rs new file mode 100644 index 000000000..723f3db41 --- /dev/null +++ b/crates/czsc-utils/src/python/mod.rs @@ -0,0 +1,51 @@ +//! PyO3 bindings for czsc-utils. Gated by the `python` feature so that +//! downstream Rust consumers don't pull pyo3 in transitively. + +use chrono::{DateTime, Utc}; +use czsc_core::objects::{freq::Freq, market::Market}; +use pyo3::prelude::*; + +use crate::bar_generator::BarGenerator; + +/// `czsc.is_trading_time(dt, market="astock")` → bool. +/// +/// `dt` is taken as a naive Python `datetime` (no tz attached). See +/// design doc §2.5 + §6 F6 for the contract. +#[pyfunction] +#[pyo3(signature = (dt, market="astock"))] +fn is_trading_time(dt: chrono::NaiveDateTime, market: &str) -> bool { + crate::is_trading_time(dt, market) +} + +/// `czsc.freq_end_time(dt, freq, market=Market.Default)` → datetime. +/// +/// Wraps `czsc_utils::freq_data::freq_end_time`. Errors are mapped to +/// `PyValueError` via `UtilsError`'s PyErr conversion. +#[pyfunction] +#[pyo3(signature = (dt, freq, market=Market::Default))] +fn freq_end_time( + dt: DateTime, + freq: Freq, + market: Market, +) -> PyResult> { + crate::freq_data::freq_end_time(dt, freq, market) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string())) +} + +/// Register the utils submodule on the parent `_native` module. Phase H +/// turns this into the canonical entrypoint for `czsc.is_trading_time`, +/// `czsc.freq_end_time`, and `czsc.BarGenerator`. +pub fn register(py: Python<'_>, parent: &Bound<'_, PyModule>) -> PyResult<()> { + let utils = PyModule::new(py, "utils")?; + utils.add_function(wrap_pyfunction!(is_trading_time, &utils)?)?; + utils.add_function(wrap_pyfunction!(freq_end_time, &utils)?)?; + utils.add_class::()?; + parent.add_submodule(&utils)?; + + // Also expose top-level so `from czsc._native import *` makes the + // canonical names directly visible (per design doc §3.1). + parent.add_function(wrap_pyfunction!(is_trading_time, parent)?)?; + parent.add_function(wrap_pyfunction!(freq_end_time, parent)?)?; + parent.add_class::()?; + Ok(()) +} diff --git a/crates/czsc-utils/src/trading_time.rs b/crates/czsc-utils/src/trading_time.rs new file mode 100644 index 000000000..899cbcdd0 --- /dev/null +++ b/crates/czsc-utils/src/trading_time.rs @@ -0,0 +1,48 @@ +//! Trading-time predicate. czsc-only addition; see docs/MIGRATION_NOTES.md §2.2. +//! +//! Inputs are interpreted as the market's local naive datetime (no tz +//! attached): A股 → CST, 港股 → HKT, crypto → any. The Python wrapper at +//! `czsc.is_trading_time` keeps the same contract. + +use chrono::{Datelike, NaiveDateTime, Timelike, Weekday}; + +const MIN_PER_HOUR: u32 = 60; + +const fn hm_minutes(h: u32, m: u32) -> u32 { + h * MIN_PER_HOUR + m +} + +fn minute_of_day(dt: &NaiveDateTime) -> u32 { + hm_minutes(dt.hour(), dt.minute()) +} + +fn is_weekday(dt: &NaiveDateTime) -> bool { + !matches!(dt.weekday(), Weekday::Sat | Weekday::Sun) +} + +/// Return true iff `dt` (local market time) falls inside the regular trading +/// session for `market`. Recognised values: `astock`, `hk`, `crypto`. Any +/// other string returns `false`. +pub fn is_trading_time(dt: NaiveDateTime, market: &str) -> bool { + match market { + "crypto" => true, + "astock" => { + if !is_weekday(&dt) { + return false; + } + let m = minute_of_day(&dt); + (hm_minutes(9, 30)..=hm_minutes(11, 30)).contains(&m) + || (hm_minutes(13, 0)..=hm_minutes(15, 0)).contains(&m) + } + "hk" => { + if !is_weekday(&dt) { + return false; + } + let m = minute_of_day(&dt); + // HK lunch break: 12:00-13:00 (12:00 is closed) + (hm_minutes(9, 30)..hm_minutes(12, 0)).contains(&m) + || (hm_minutes(13, 0)..=hm_minutes(16, 0)).contains(&m) + } + _ => false, + } +} diff --git a/crates/czsc-utils/tests/test_bar_generator.rs b/crates/czsc-utils/tests/test_bar_generator.rs new file mode 100644 index 000000000..1622e254e --- /dev/null +++ b/crates/czsc-utils/tests/test_bar_generator.rs @@ -0,0 +1,77 @@ +//! Phase C.2 — RED test: BarGenerator constructs, accepts seed bars via +//! init_freq_with_bars, refuses double init, and aggregates base-freq bars +//! into the higher freq via update_bar. + +use std::sync::Arc; + +use chrono::{TimeZone, Utc}; +use czsc_core::objects::bar::{RawBar, RawBarBuilder}; +use czsc_core::objects::freq::Freq; +use czsc_core::objects::market::Market; +use czsc_utils::bar_generator::BarGenerator; + +fn bar(ts: i64, open: f64, close: f64, high: f64, low: f64) -> RawBar { + RawBarBuilder::default() + .symbol(Arc::::from("000001")) + .dt(Utc.timestamp_opt(ts, 0).unwrap()) + .freq(Freq::F1) + .id(0) + .open(open) + .close(close) + .high(high) + .low(low) + .vol(1000.0_f64) + .amount(1_000_000.0_f64) + .build() + .unwrap() +} + +#[test] +fn new_constructs_with_freq_keys() { + let bg = BarGenerator::new(Freq::F1, vec![Freq::F30], 100, Market::Default).unwrap(); + assert!(bg.freq_bars.contains_key(&Freq::F1)); + assert!(bg.freq_bars.contains_key(&Freq::F30)); +} + +#[test] +fn init_freq_with_bars_populates_seed_data() { + let mut bg = BarGenerator::new(Freq::F1, vec![Freq::F30], 100, Market::Default).unwrap(); + let seed = vec![bar(1_700_000_000, 10.0, 11.0, 12.0, 9.0)]; + bg.init_freq_with_bars(Freq::F30, seed).unwrap(); + assert_eq!(bg.freq_bars.get(&Freq::F30).unwrap().read().len(), 1); +} + +#[test] +fn init_freq_with_bars_rejects_unknown_freq() { + let mut bg = BarGenerator::new(Freq::F1, vec![Freq::F30], 100, Market::Default).unwrap(); + let seed = vec![bar(1_700_000_000, 10.0, 11.0, 12.0, 9.0)]; + assert!(bg.init_freq_with_bars(Freq::F60, seed).is_err()); +} + +#[test] +fn init_freq_with_bars_rejects_double_init() { + let mut bg = BarGenerator::new(Freq::F1, vec![Freq::F30], 100, Market::Default).unwrap(); + bg.init_freq_with_bars(Freq::F30, vec![bar(1_700_000_000, 10.0, 11.0, 12.0, 9.0)]) + .unwrap(); + let res = bg.init_freq_with_bars(Freq::F30, vec![bar(1_700_000_060, 11.0, 12.0, 13.0, 10.0)]); + assert!(res.is_err()); +} + +#[test] +fn update_bar_appends_for_new_freq_window() { + let bg = BarGenerator::new(Freq::F1, vec![Freq::F30], 100, Market::Default).unwrap(); + bg.update_bar(&bar(1_700_000_000, 10.0, 11.0, 12.0, 9.0)) + .unwrap(); + // Both freq queues received a bar + assert!(bg.freq_bars.get(&Freq::F1).unwrap().read().len() >= 1); + assert!(bg.freq_bars.get(&Freq::F30).unwrap().read().len() >= 1); +} + +#[test] +fn symbol_returns_seed_symbol_after_update() { + let bg = BarGenerator::new(Freq::F1, vec![Freq::F30], 100, Market::Default).unwrap(); + bg.update_bar(&bar(1_700_000_000, 10.0, 11.0, 12.0, 9.0)) + .unwrap(); + let sym = bg.symbol().expect("symbol should be available after update"); + assert_eq!(&*sym, "000001"); +} diff --git a/crates/czsc-utils/tests/test_errors.rs b/crates/czsc-utils/tests/test_errors.rs new file mode 100644 index 000000000..b790fba55 --- /dev/null +++ b/crates/czsc-utils/tests/test_errors.rs @@ -0,0 +1,41 @@ +//! Phase C.0 — RED test: UtilsError variants must format with thiserror, +//! convert from anyhow / PolarsError via the From blanket impls, and +//! serialize to a string-shaped JSON via CZSCErrorDerive. + +use chrono::NaiveDate; +use czsc_utils::errors::UtilsError; + +#[test] +fn invalid_datetime_message() { + let err = UtilsError::InvalidDateTime; + assert_eq!(err.to_string(), "Invalid datetime"); +} + +#[test] +fn invalid_freq_end_date_carries_payload() { + let err = UtilsError::InvalidFreqEndDate("2024-99-99".into()); + assert!(err.to_string().contains("2024-99-99")); +} + +#[test] +fn no_weights_avail_includes_date() { + let dt = NaiveDate::from_ymd_opt(2024, 1, 8).unwrap(); + let err = UtilsError::NoWeightsAvail(dt); + assert!(err.to_string().contains("2024-01-08")); +} + +#[test] +fn from_anyhow_routes_to_unexpected() { + let any: anyhow::Error = anyhow::anyhow!("boom"); + let err: UtilsError = any.into(); + assert!(matches!(err, UtilsError::Unexpected(_))); +} + +#[test] +fn from_polars_error_routes_to_polars() { + use polars::error::PolarsError; + let pe = PolarsError::ComputeError("compute failed".into()); + let err: UtilsError = pe.into(); + assert!(matches!(err, UtilsError::Polars(_))); + assert!(err.to_string().contains("compute failed")); +} diff --git a/crates/czsc-utils/tests/test_freq_data.rs b/crates/czsc-utils/tests/test_freq_data.rs new file mode 100644 index 000000000..2956a353c --- /dev/null +++ b/crates/czsc-utils/tests/test_freq_data.rs @@ -0,0 +1,67 @@ +//! Phase C.1 — RED test: freq_end_time + infer_market_from_bars match the +//! rs-czsc 47ef6efa baseline behaviour. + +use std::sync::Arc; + +use chrono::{TimeZone, Utc}; +use czsc_core::objects::bar::{RawBar, RawBarBuilder}; +use czsc_core::objects::freq::Freq; +use czsc_core::objects::market::Market; +use czsc_utils::freq_data::{freq_end_time, infer_market_from_bars}; + +#[test] +fn freq_end_time_returns_some_datetime_for_30min_default_market() { + // For an arbitrary intraday minute we just want to confirm the helper + // succeeds and the result is non-decreasing (the actual minute table + // is encoded in minutes_split.feather and is the rs-czsc baseline). + let dt = Utc.with_ymd_and_hms(2024, 1, 8, 9, 30, 0).unwrap(); + let edt = freq_end_time(dt, Freq::F30, Market::Default).unwrap(); + assert!(edt >= dt, "edt {edt} must be >= input dt {dt}"); +} + +#[test] +fn freq_end_time_idempotent_when_already_at_boundary() { + // If a query already lands on a boundary the function must round-trip + // (i.e. calling it again should return the same instant). + let dt = Utc.with_ymd_and_hms(2024, 1, 8, 10, 0, 0).unwrap(); + let edt1 = freq_end_time(dt, Freq::F30, Market::Default).unwrap(); + let edt2 = freq_end_time(edt1, Freq::F30, Market::Default).unwrap(); + assert_eq!(edt1, edt2); +} + +#[test] +fn freq_end_time_handles_daily_freq() { + // For higher timeframes the function is still callable and should not + // panic; the exact boundary semantics are encoded in rs-czsc and we + // simply lock that calling it returns Ok. + let dt = Utc.with_ymd_and_hms(2024, 1, 8, 14, 30, 0).unwrap(); + let _ = freq_end_time(dt, Freq::D, Market::Default).unwrap(); +} + +fn raw_bar(ts_secs: i64, freq: Freq) -> RawBar { + RawBarBuilder::default() + .symbol(Arc::::from("000001")) + .dt(Utc.timestamp_opt(ts_secs, 0).unwrap()) + .freq(freq) + .id(0) + .open(10.0) + .close(11.0) + .high(12.0) + .low(9.5) + .vol(1000.0) + .amount(1_000_000.0) + .build() + .unwrap() +} + +#[test] +fn infer_market_returns_default_for_non_minute_freq() { + let bars = vec![raw_bar(1_700_000_000, Freq::D)]; + assert_eq!(infer_market_from_bars(&bars, Freq::D), Market::Default); +} + +#[test] +fn infer_market_returns_default_for_empty_bars() { + let bars: Vec = Vec::new(); + assert_eq!(infer_market_from_bars(&bars, Freq::F30), Market::Default); +} diff --git a/crates/czsc-utils/tests/test_trading_time.rs b/crates/czsc-utils/tests/test_trading_time.rs new file mode 100644 index 000000000..a780eb0fa --- /dev/null +++ b/crates/czsc-utils/tests/test_trading_time.rs @@ -0,0 +1,54 @@ +//! Phase C.3 — RED test: is_trading_time across A股 / 港股 / crypto. +//! +//! Mirrors the cases locked by test/unit/test_trading_time.py (Phase A.6). +//! The function is czsc-only — see docs/MIGRATION_NOTES.md §2.2. + +use chrono::NaiveDate; +use czsc_utils::is_trading_time; + +fn dt(y: i32, mo: u32, d: u32, h: u32, mi: u32) -> chrono::NaiveDateTime { + NaiveDate::from_ymd_opt(y, mo, d) + .unwrap() + .and_hms_opt(h, mi, 0) + .unwrap() +} + +#[test] +fn astock_regular_session() { + // 2024-01-08 is Monday + assert!(is_trading_time(dt(2024, 1, 8, 9, 30), "astock")); + assert!(is_trading_time(dt(2024, 1, 8, 10, 0), "astock")); + assert!(is_trading_time(dt(2024, 1, 8, 11, 30), "astock")); + assert!(is_trading_time(dt(2024, 1, 8, 13, 0), "astock")); + assert!(is_trading_time(dt(2024, 1, 8, 15, 0), "astock")); +} + +#[test] +fn astock_lunch_break_and_off_hours() { + assert!(!is_trading_time(dt(2024, 1, 8, 12, 30), "astock")); + assert!(!is_trading_time(dt(2024, 1, 8, 15, 30), "astock")); +} + +#[test] +fn astock_weekend_closed() { + // 2024-01-06 is Saturday + assert!(!is_trading_time(dt(2024, 1, 6, 10, 0), "astock")); +} + +#[test] +fn hk_regular_session_and_lunch() { + assert!(is_trading_time(dt(2024, 1, 8, 9, 30), "hk")); + assert!(!is_trading_time(dt(2024, 1, 8, 12, 0), "hk")); + assert!(is_trading_time(dt(2024, 1, 8, 16, 0), "hk")); +} + +#[test] +fn crypto_always_open() { + assert!(is_trading_time(dt(2024, 1, 6, 3, 0), "crypto")); + assert!(is_trading_time(dt(2024, 12, 25, 0, 0), "crypto")); +} + +#[test] +fn unknown_market_returns_false() { + assert!(!is_trading_time(dt(2024, 1, 8, 10, 0), "unknown_xyz")); +} diff --git a/crates/error-macros/Cargo.toml b/crates/error-macros/Cargo.toml new file mode 100644 index 000000000..98d694e73 --- /dev/null +++ b/crates/error-macros/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "error-macros" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "Proc-macro for error type generation. Placeholder, to be migrated." + +[lib] +name = "error_macros" +path = "src/lib.rs" +proc-macro = true + +[dependencies] +anyhow = "1" +proc-macro2 = "1" +quote = "1" +syn = "2" +thiserror = "2" + +[dev-dependencies] +anyhow = "1" +serde = { workspace = true } +serde_json = "1" +thiserror = "2" diff --git a/crates/error-macros/src/err.rs b/crates/error-macros/src/err.rs new file mode 100644 index 000000000..0dbf3e9d6 --- /dev/null +++ b/crates/error-macros/src/err.rs @@ -0,0 +1,35 @@ +use proc_macro::TokenStream; +use quote::quote; + +pub fn err_gen_code(ast: &mut syn::DeriveInput) -> TokenStream { + let name = &ast.ident; + + // 生成 `From` 实现 + let from_impl = quote! { + impl std::convert::From for #name { + fn from(error: anyhow::Error) -> Self { + Self::Unexpected(error) + } + } + }; + + // 生成 `serde::Serialize` 实现 + let serialize_impl = quote! { + impl serde::Serialize for #name { + fn serialize(&self, serializer: S) -> Result + where + S: serde::ser::Serializer, + { + serializer.serialize_str(&self.to_string()) + } + } + }; + + // 汇总并返回生成的代码 + let expanded = quote! { + #from_impl + #serialize_impl + }; + + TokenStream::from(expanded) +} diff --git a/crates/error-macros/src/lib.rs b/crates/error-macros/src/lib.rs new file mode 100644 index 000000000..a873cdff1 --- /dev/null +++ b/crates/error-macros/src/lib.rs @@ -0,0 +1,16 @@ +//! error-macros — proc-macro for CZSC error type generation. +//! +//! Migrated from rs-czsc commit `47ef6efa` (see docs/MIGRATION_NOTES.md §1). +//! Provides `CZSCErrorDerive` which auto-implements `From` and +//! `serde::Serialize` for enum error types. + +use proc_macro::TokenStream; +use syn::{DeriveInput, parse_macro_input}; + +mod err; + +#[proc_macro_derive(CZSCErrorDerive, attributes(error, from))] +pub fn derive_utils_error(input: TokenStream) -> TokenStream { + let mut ast = parse_macro_input!(input as DeriveInput); + err::err_gen_code(&mut ast) +} diff --git a/crates/error-macros/tests/test_derive.rs b/crates/error-macros/tests/test_derive.rs new file mode 100644 index 000000000..65515633d --- /dev/null +++ b/crates/error-macros/tests/test_derive.rs @@ -0,0 +1,25 @@ +//! Phase D.0a — RED test: CZSCErrorDerive must produce From +//! and serde::Serialize impls for an annotated enum. + +use error_macros::CZSCErrorDerive; +use thiserror::Error; + +#[derive(Debug, Error, CZSCErrorDerive)] +enum DummyError { + #[error("unexpected: {0}")] + Unexpected(anyhow::Error), +} + +#[test] +fn from_anyhow_blanket_impl_exists() { + let err: anyhow::Error = anyhow::anyhow!("boom"); + let dummy: DummyError = err.into(); + assert!(matches!(dummy, DummyError::Unexpected(_))); +} + +#[test] +fn serialize_emits_string() { + let dummy = DummyError::Unexpected(anyhow::anyhow!("boom")); + let json = serde_json::to_string(&dummy).unwrap(); + assert!(json.contains("boom"), "expected serialized payload, got {json}"); +} diff --git a/crates/error-support/Cargo.toml b/crates/error-support/Cargo.toml new file mode 100644 index 000000000..b280df48a --- /dev/null +++ b/crates/error-support/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "error-support" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "Common error types for the czsc Rust workspace. Placeholder, to be migrated." + +[lib] +name = "error_support" +path = "src/lib.rs" + +[dependencies] +anyhow = "1" diff --git a/crates/error-support/src/lib.rs b/crates/error-support/src/lib.rs new file mode 100644 index 000000000..d3459ce6e --- /dev/null +++ b/crates/error-support/src/lib.rs @@ -0,0 +1,38 @@ +use std::fmt::Write; + +pub fn expand_error_chain(err: &anyhow::Error) -> String { + let mut error_chain = String::new(); + let mut current_error: Option<&(dyn std::error::Error + 'static)> = Some(err.as_ref()); + // 标记是否是错误链中的第一个错误 + let mut is_first = true; + + // 遍历整个错误链,直到没有更多的源错误 + while let Some(error) = current_error { + // 除了第一个错误之外,为每个后续错误添加 "Caused by: " 前缀 + if !is_first { + // 这里使用 `unwrap()` 是安全的,因为 `write!` 向 `String` 写入内容不会失败(`String` 会自动扩容)- 唯一可能导致 `write!` 失败的情况是内存分配失败,这种情况下程序已经处于不可恢复状态。 + write!(error_chain, "\nCaused by: ").unwrap(); + } + // 将当前错误信息写入错误链字符串 + write!(error_chain, "{error}").unwrap(); + // 获取下一个源错误(如果存在) + current_error = error.source(); + is_first = false; + } + + error_chain +} + +/// Copy from anyhow::bail! +#[macro_export] +macro_rules! czsc_bail { + ($msg:literal $(,)?) => { + return Err(anyhow::anyhow!($msg).into()) + }; + ($err:expr $(,)?) => { + return Err(anyhow::anyhow!($err).into()) + }; + ($fmt:expr, $($arg:tt)*) => { + return Err(anyhow::anyhow!($fmt, $($arg)*).into()) + }; +} diff --git a/crates/error-support/tests/test_chain.rs b/crates/error-support/tests/test_chain.rs new file mode 100644 index 000000000..945bfb498 --- /dev/null +++ b/crates/error-support/tests/test_chain.rs @@ -0,0 +1,26 @@ +//! Phase D.0b — RED test: expand_error_chain must walk the full source chain +//! using `Caused by:` separators, and czsc_bail! must short-circuit with an +//! anyhow-wrapped error. + +use error_support::{czsc_bail, expand_error_chain}; + +#[test] +fn expand_error_chain_walks_sources() { + let inner = std::io::Error::new(std::io::ErrorKind::Other, "leaf"); + let mid = anyhow::Error::new(inner).context("middle layer"); + let outer = mid.context("outermost"); + let chain = expand_error_chain(&outer); + assert!(chain.contains("outermost"), "chain missing outer: {chain}"); + assert!(chain.contains("Caused by: middle layer"), "missing middle: {chain}"); + assert!(chain.contains("Caused by: leaf"), "missing leaf: {chain}"); +} + +fn callee() -> Result<(), anyhow::Error> { + czsc_bail!("kaboom"); +} + +#[test] +fn czsc_bail_returns_err() { + let err = callee().unwrap_err(); + assert!(err.to_string().contains("kaboom")); +} diff --git a/czsc/__init__.py b/czsc/__init__.py index cb0da152a..145c0a072 100644 --- a/czsc/__init__.py +++ b/czsc/__init__.py @@ -1,30 +1,125 @@ """ -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2019/10/29 15:01 +CZSC(缠中说禅)量化分析框架的顶层包入口模块 + +职责: + 1. 统一对外暴露公共 API(缠论核心类、信号、交易器、策略、回测工具等) + 2. 完成 Rust 后端(czsc._native,由 PyO3 编译而来)与 Python 适配层之间的桥接 + 3. 维护 ``__all__`` 公共契约,保证 ``from czsc import *`` 行为可控 + 4. 通过模块级 ``__getattr__`` 提供按需加载(懒加载)的子模块与符号 + —— 仅在首次访问时才执行 ``importlib.import_module``,可显著降低冷启动耗时 + 同时避免循环依赖 + +约定: + - 所有以单下划线开头的对象(如 ``_sys``、``_LAZY_MODULES``)均为模块内部使用, + 不属于公共 API,禁止外部直接依赖 + - ``czsc.ta``、``czsc.CZSC`` 等高频符号优先来自 Rust 实现,性能更佳 + - 升级 Rust 版本时需同步检查 ``__all__`` 与本模块导入区,确保契约一致 + +作者: zengbin93 +邮箱: zeng_bin8888@163.com +创建时间: 2019/10/29 15:01 """ -from . import envs, traders, utils -from .core import ( +import sys as _sys + +# 子包按"始终需要 / 启动期可加载"为标准统一引入: +# - _native: PyO3 编译的 Rust 扩展,缠论核心实现位于此 +# - connectors/envs/sensors/signals/traders/utils: 业务层模块,多数函数会被立即用到 +# 这些子包必须立即加载,因下方的 from ... import 语句直接依赖它们 +from . import _native, connectors, envs, sensors, signals, traders, utils + +# === Rust 扩展的 ta 命名空间桥接 === +# 让 ``czsc.ta.*`` 直接来自 Rust 扩展(czsc._native.ta),不再使用 Python 包装层。 +# 通过同时设置模块属性与 ``sys.modules``,保证以下两种导入方式都能命中 Rust 实现: +# import czsc.ta # 解析为 czsc._native.ta +# from czsc.ta import ema # 解析为 czsc._native.ta.ema +# 注意:在文件下方 ``from .utils import ...`` 之后还会重新赋值一次,避免 +# 旧版包装模块在导入链上覆盖此处别名(见后文"重新应用 czsc.ta 别名"段)。 +ta = _native.ta +_sys.modules["czsc.ta"] = _native.ta + +# === 缠论核心数据类型与算法(来自 Rust 扩展) === +# 这些符号是 CZSC 公共 API 的"硬契约",在大量业务代码与下游项目中被直接 import。 +# 命名与原 Python 实现保持一致,便于无缝迁移。各类型的语义: +# BI - 笔(缠论中由分型连接形成的最小走势单元) +# CZSC - 缠论分析器主类,承载分型/笔/线段的识别管线 +# FX - 分型(顶分型 / 底分型) +# ZS - 中枢(多笔重叠形成的盘整区间) +# BarGenerator - K 线合成器,用于多周期联立 +# Direction/Mark/Operate - 方向/标记/操作枚举 +# Event/Signal/Position - 事件、信号、持仓三件套 +# ParsedSignalDoc - 信号文档解析结果 +# FakeBI/NewBar/RawBar - K 线及其衍生抽象(原始/合成/虚拟笔) +# 工具函数: +# boll_positions/ema/sma/ultimate_smoother/rolling_rank - 技术指标 +# check_bi/check_fx/check_fxs/remove_include - 缠论结构校验 +# freq_end_time/is_trading_time - 周期与交易时段判定 +# parse_signal_doc - 信号声明字符串解析 +from ._native import ( + BI, CZSC, + FX, ZS, - CzscJsonStrategy, - CzscStrategyBase, + BarGenerator, Direction, Event, + FakeBI, Freq, + Mark, NewBar, Operate, + ParsedSignalDoc, Position, RawBar, Signal, - WeightBacktest, - daily_performance, - format_standard_kline, - top_drawdowns, + boll_positions, + check_bi, + check_fx, + check_fxs, + ema, + freq_end_time, + is_trading_time, + parse_signal_doc, + remove_include, + rolling_rank, + sma, + ultimate_smoother, ) + +# === format_standard_kline 的 Python 包装 === +# Rust 扩展仅提供一个直通桩(``Vec -> Vec``),无法直接接受 +# pandas DataFrame。但下游用户期望的签名是 ``DataFrame + Freq -> List[RawBar]``。 +# 因此 ``czsc/_format_standard_kline.py`` 是一个 Python 适配层,逐行通过 +# PyO3 构造器创建 RawBar 实例,签名与 rs-czsc 提供的 Python 端 API 完全一致。 +# 这样既保留了"用户传 DataFrame 即可"的便利性,又复用了 Rust 端的内存布局与类型校验。 +from czsc._format_standard_kline import format_standard_kline + +# === 回测引擎(来自第三方包 wbt) === +# WeightBacktest - 基于权重序列的向量化回测器 +# daily_performance - 日度绩效统计(年化、夏普、最大回撤等) +# top_drawdowns - 提取前 N 大回撤区间,用于风险归因分析 +from wbt import WeightBacktest, daily_performance, top_drawdowns + from typing import TYPE_CHECKING +# === 探索性数据分析(EDA)相关函数 === +# 这些工具函数用于因子研究、特征工程与轻量级策略评估, +# 在多数交易研究流程中是高频使用项,因此选择在导入期就一并暴露: +# cal_symbols_factor / cal_trade_price - 多品种因子计算与交易价归一化 +# cal_yearly_days - 计算年化基准日数(区分 A 股/期货/数字货币) +# cross_sectional_strategy - 横截面排序策略 +# dif_long_bear / sma_long_bear - 多/空趋势判定 +# limit_leverage - 杠杆约束 +# make_price_features - 价格类特征工厂 +# mark_cta_periods / mark_volatility - CTA 区间与波动率分段 +# min_max_limit - 数值裁剪 +# monotonicity - 单调性检验 +# remove_beta_effects - 去除 Beta 系统性影响 +# rolling_layers - 分层回看 +# tsf_type - 时序特征类别标注 +# turnover_rate - 换手率统计 +# twap / vwap - 时间/成交量加权均价 +# unify_weights / weights_simple_ensemble - 权重归一与简单集成 from .eda import ( cal_symbols_factor, cal_trade_price, @@ -48,18 +143,73 @@ weights_simple_ensemble, ) +# 仅在类型检查阶段(如 mypy / pyright / IDE 静态分析)暴露这些懒加载子模块, +# 运行期它们会经由下方的 ``__getattr__`` 按需导入。这种写法既能让静态工具 +# 正确解析 ``czsc.svc`` 等用法,又不会在导入 czsc 时就把这些重量级子包 +# (例如 svc 依赖 plotly/streamlit)拉起来。 if TYPE_CHECKING: - from . import aphorism, cwc, fsa, mock, svc + from . import aphorism, fsa, mock, svc + +# === 策略门面(Facade) === +# strategies.py 是 Python 端的薄封装:在 Rust 实现的 Trader 基础上,提供 +# CzscStrategyBase(策略开发抽象基类)与 CzscJsonStrategy(JSON 配置式策略), +# 隔离用户层 API 与底层 Rust 类型,便于策略快速搭建与序列化。 +from .strategies import CzscJsonStrategy, CzscStrategyBase + +# === 交易器与信号管理 API === +# traders 子包对外暴露的统一入口,由 Rust 后端驱动: +# CzscSignals/CzscTrader - 多周期信号合成与交易调度核心 +# SignalsParser - 信号声明字符串解析器 +# derive_signals_* - 从持仓/事件反推所需信号配置或周期集合 +# generate_czsc_signals - 标准化信号生成入口 +# get_signals_* - 信号配置/周期获取辅助函数 +# get_unique_signals - 信号去重工具,回测前置预处理常用 from .traders import ( CzscSignals, CzscTrader, SignalsParser, - check_signals_acc, + derive_signals_config, + derive_signals_freqs, generate_czsc_signals, get_signals_config, get_signals_freqs, get_unique_signals, ) + +# === 研究/优化入口(research.py,Rust 后端) === +# 这些函数是策略研究流程的主入口,封装了"参数批量回测""开仓/平仓优化""复盘" +# 等高层操作,对应 czsc/research.py 中的统一研究 API: +# build_open_optim_positions - 构造开仓参数优化所需的 Position 列表 +# build_exit_optim_positions - 构造平仓参数优化所需的 Position 列表 +# run_optimize_batch - 多参数组合批量优化 +# run_replay - 单标的回放 +# run_research - 顶层一键式研究流水线 +from .research import ( + build_exit_optim_positions, + build_open_optim_positions, + run_optimize_batch, + run_replay, + run_research, +) + +# === 通用工具函数集合(czsc.utils) === +# 这些工具按使用频率从 czsc.utils 中提升至顶级命名空间,便于直接 ``czsc.xxx`` 调用。 +# 主要分组: +# - 缓存类:DiskCache / disk_cache / clear_cache / clear_expired_cache / +# empty_cache_path / get_dir_size / home_path +# - 数据源/IO:DataClient / AliyunOSS / read_json / save_json / to_arrow +# - 加解密:fernet_encrypt / fernet_decrypt / generate_fernet_key / +# get_url_token / set_url_token +# - 序列化:dill_dump / dill_load +# - 时间/周期:freqs_sorted / resample_to_daily +# - 命名空间/反射:code_namespace / get_py_namespace / import_by_name +# - 绩效统计:cross_sectional_ic / holds_performance / +# rolling_daily_performance / risk_free_returns +# - 调试:print_df_sample / mac_address / x_round +# - PSI:psi(群体稳定性指数) +# - 网格/装饰器:create_grid_params / timeout_decorator +# - 指标增量更新:update_bbars / update_nxb / update_tbars +# - 指数成分:index_composition from .utils import ( AliyunOSS, DataClient, @@ -85,7 +235,6 @@ import_by_name, index_composition, mac_address, - overlap, print_df_sample, psi, read_json, @@ -94,7 +243,6 @@ rolling_daily_performance, save_json, set_url_token, - ta, timeout_decorator, to_arrow, update_bbars, @@ -103,32 +251,81 @@ x_round, ) +# === 重新应用 czsc.ta 别名(关键步骤,勿删) === +# 必须在 ``from .utils import ...`` 之后再次执行一次 czsc.ta 的别名绑定, +# 原因:旧版 utils 模块在导入链上可能间接触发 ``czsc.utils.ta`` 子模块的 +# 副作用导入,从而将 sys.modules['czsc.ta'] 指向 Python 包装版本,覆盖 +# 文件顶部设置的 Rust 版本。这里再绑定一次确保 Rust 实现胜出。 +ta = _native.ta +_sys.modules["czsc.ta"] = _native.ta + +# === 公共 API 契约 === +# ``__all__`` 显式声明 ``from czsc import *`` 行为暴露的符号集合。 +# 维护规则: +# 1. 任何顶级 import 中出现的公共符号都需登记于此(按主题分组排列) +# 2. 私有符号(单下划线开头)禁止登记 +# 3. 修改本列表等价于修改公共契约,必须在 CHANGELOG / 迁移指南中说明 __all__ = [ - "WeightBacktest", - "daily_performance", - "top_drawdowns", - "envs", - "traders", - "utils", + "BI", "CZSC", + "FX", "ZS", + "BarGenerator", "Direction", "Event", + "FakeBI", "Freq", + "Mark", "NewBar", "Operate", + "ParsedSignalDoc", "Position", "RawBar", "Signal", + "boll_positions", + "check_bi", + "check_fx", + "check_fxs", + "ema", "format_standard_kline", + "freq_end_time", + "is_trading_time", + "parse_signal_doc", + "remove_include", + "rolling_rank", + "sma", + "ultimate_smoother", + # —— 来自 wbt 的回测组件 —— + "WeightBacktest", + "daily_performance", + "top_drawdowns", + # —— 始终预先加载的子包 —— + "connectors", + "envs", + "sensors", + "signals", + "traders", + "utils", + # —— 交易器 API(czsc/traders/__init__.py,Rust 后端实现) —— "CzscSignals", "CzscTrader", "SignalsParser", - "check_signals_acc", + "derive_signals_config", + "derive_signals_freqs", "generate_czsc_signals", "get_signals_config", "get_signals_freqs", "get_unique_signals", + # —— 策略门面(czsc/strategies.py,Python 层对 Rust Trader 的封装) —— + "CzscStrategyBase", + "CzscJsonStrategy", + # —— 研究/优化入口(czsc/research.py,Rust 后端) —— + "build_exit_optim_positions", + "build_open_optim_positions", + "run_optimize_batch", + "run_replay", + "run_research", + # —— 通用工具(来自 czsc/utils) —— "AliyunOSS", "DataClient", "DiskCache", @@ -153,7 +350,6 @@ "import_by_name", "index_composition", "mac_address", - "overlap", "print_df_sample", "psi", "read_json", @@ -173,15 +369,10 @@ "fsa", "aphorism", "mock", - "cwc", - "CzscStrategyBase", - "CzscJsonStrategy", "capture_warnings", "execute_with_warning_capture", "adjust_holding_weights", "log_strategy_info", - "calculate_bi_info", - "symbols_bi_infos", "plot_czsc_chart", "KlineChart", "check_kline_quality", @@ -212,26 +403,35 @@ "welcome", ] +# === 包元信息 === +# 这些字段会被 setuptools / pip / sphinx 等工具读取,发布前需同步更新。 +# __date__ 采用 ``YYYYMMDD`` 格式,便于排序与追溯。 __version__ = "0.10.12" __author__ = "zengbin93" __email__ = "zeng_bin8888@163.com" __date__ = "20260308" +# === 懒加载子模块映射表 === +# 键为公开访问名,值为完整模块路径。 +# 这些子包通常依赖较重(如 svc 依赖 plotly/streamlit、fsa 依赖飞书 SDK、 +# mock 涉及大量随机数生成器),若在 import czsc 时一次性全部加载,会显著 +# 拖慢 CLI 工具与服务启动速度。延迟到首次访问时再加载,可保持冷启动轻量。 _LAZY_MODULES = { "svc": "czsc.svc", "fsa": "czsc.fsa", "aphorism": "czsc.aphorism", "mock": "czsc.mock", - "cwc": "czsc.traders.cwc", } +# === 懒加载属性映射表 === +# 键为公开访问名,值为 (模块路径, 模块内符号名) 二元组。 +# 用于把分散在工具子模块中的少量高频函数/类提升到顶级命名空间, +# 同时保留按需加载、避免在导入期触发不必要的副作用。 _LAZY_ATTRS = { "capture_warnings": ("czsc.utils.warning_capture", "capture_warnings"), "execute_with_warning_capture": ("czsc.utils.warning_capture", "execute_with_warning_capture"), "adjust_holding_weights": ("czsc.utils.trade", "adjust_holding_weights"), "log_strategy_info": ("czsc.utils.log", "log_strategy_info"), - "calculate_bi_info": ("czsc.utils.bi_info", "calculate_bi_info"), - "symbols_bi_infos": ("czsc.utils.bi_info", "symbols_bi_infos"), "plot_czsc_chart": ("czsc.utils.plotting.kline", "plot_czsc_chart"), "KlineChart": ("czsc.utils.plotting.kline", "KlineChart"), "check_kline_quality": ("czsc.utils.kline_quality", "check_kline_quality"), @@ -239,13 +439,37 @@ def __getattr__(name): + """ + 模块级懒加载钩子(PEP 562) + + Python 在常规属性查找失败时会回退调用本函数,因此可借助它实现 + 延迟导入,既不破坏 ``czsc.svc.xxx``、``czsc.capture_warnings`` 等 + 用户期望的访问形式,又能避免导入开销。 + + 实现细节: + 1. 命中 ``_LAZY_MODULES`` —— 调用 ``importlib.import_module`` 加载, + 并把模块对象写入 ``globals()``,后续访问直接走常规路径,无再次开销 + 2. 命中 ``_LAZY_ATTRS`` —— 加载子模块后取出指定属性,同样缓存到全局命名空间 + 3. 全部未命中 —— 抛 ``AttributeError``(必须保留,否则 ``hasattr`` 会出错) + + 参数: + name: 用户尝试访问的属性名,例如 ``"svc"``、``"plot_czsc_chart"`` + + 返回: + 加载完成的模块对象或目标属性 + + 异常: + AttributeError: 当 ``name`` 既不在 ``_LAZY_MODULES`` 也不在 ``_LAZY_ATTRS`` 中 + """ import importlib + # 路径 1:懒加载子模块(按需 import,再缓存到全局) if name in _LAZY_MODULES: module = importlib.import_module(_LAZY_MODULES[name]) globals()[name] = module return module + # 路径 2:懒加载子模块中的某个属性(先 import 再 getattr,最后缓存) if name in _LAZY_ATTRS: mod_path, attr_name = _LAZY_ATTRS[name] module = importlib.import_module(mod_path) @@ -253,15 +477,31 @@ def __getattr__(name): globals()[name] = attr return attr + # 路径 3:未注册的属性 —— 严格抛错以维持标准 Python 语义(hasattr 等场景依赖此行为) raise AttributeError(f"module 'czsc' has no attribute {name!r}") def welcome(): + """ + CLI/交互式环境下的欢迎信息打印函数 + + 用途: + - 打印当前 CZSC 版本号、日期与一段随机的"缠论格言"(aphorism 子模块) + - 打印关键环境变量当前值,便于排查"实际生效配置"问题 + - 当本地缓存目录体积超过 1 GB 时,给出清理提示,避免长期堆积 + + 设计动机: + 把 aphorism 的导入推迟到函数体内(而非模块顶部),是为了避免 + ``import czsc`` 时强制依赖该子包;若用户没有调用 ``welcome()``, + aphorism 就不会被加载。 + """ from czsc import aphorism print(f"欢迎使用CZSC!当前版本标识为 {__version__}@{__date__}\n") aphorism.print_one() print(f"CZSC环境变量:czsc_min_bi_len = {envs.get_min_bi_len()}; czsc_max_bi_num = {envs.get_max_bi_num()}; ") + # 1 GB 阈值:超出即提示用户主动清理,避免缓存目录无限膨胀; + # 用 ``pow(1024, 3)`` 而非 ``10**9`` 是为了得到精确的二进制 GB(GiB)。 if get_dir_size(home_path) > pow(1024, 3): print(f"{home_path} 目录缓存超过1GB,请适当清理。调用 czsc.empty_cache_path() 可以直接清空缓存") diff --git a/czsc/_compat.py b/czsc/_compat.py new file mode 100644 index 000000000..53d89cd92 --- /dev/null +++ b/czsc/_compat.py @@ -0,0 +1,418 @@ +""" +Rust/Python 兼容层(Compatibility Shim) + +本模块封装了 Python 端遗留数据结构与 Rust 后端(rs-czsc / wbt)所需运行时格式 +之间的相互转换逻辑,是迁移阶段保持"老代码不改、底层换 Rust"的桥梁。 + +涵盖的转换族: + 1. 周期(Freq)字符串排序 —— sort_freqs + 2. 信号配置(dict) —— signal_config_to_runtime / signal_config_to_public + 3. Position 序列化转 Rust 期望布局 —— position_dump_to_runtime + 4. K 线 list[RawBar] / DataFrame 标准化 —— bars_to_dataframe + 5. 候选事件结构归一 —— normalize_candidate_event(s) + 6. JSON 读写 —— load_json / dump_json + 7. 字符串/字面量转义、信号 KV 拼接 —— py_escape_str / py_repr_* + +设计原则: + - 所有转换函数无副作用,输入不可变(统一通过 ``dict(x)`` / ``list(x)`` 复制) + - 字段缺失走"宽进严出"策略:能用默认值兜底就兜底,无法兜底就显式 raise + - 不在此引入业务依赖(仅依赖 stdlib 与 pandas),保持兼容层最小内聚 +""" + +from __future__ import annotations + +import hashlib +import json +from pathlib import Path +from typing import Any, Iterable + +import pandas as pd + + +# 周期字符串到排序权重的映射表 +# 数字越小代表越小级别(越高频),按此权重做稳定排序后, +# 同一策略中"小周期 -> 大周期"的展示顺序与缠论惯例一致。 +# 未在本表中的周期会被赋值 10_000,并用字面量做次级排序兜底。 +_FREQ_ORDER = { + "Tick": 0, + "逐笔": 0, + "1分钟": 1, + "2分钟": 2, + "3分钟": 3, + "4分钟": 4, + "5分钟": 5, + "6分钟": 6, + "10分钟": 7, + "12分钟": 8, + "15分钟": 9, + "20分钟": 10, + "30分钟": 11, + "60分钟": 12, + "120分钟": 13, + "240分钟": 14, + "360分钟": 15, + "日线": 16, + "周线": 17, + "月线": 18, + "季线": 19, + "年线": 20, +} + + +def sort_freqs(freqs: Iterable[str]) -> list[str]: + """ + 按缠论惯用顺序对周期字符串去重并排序 + + 参数: + freqs: 任意可迭代的周期字符串集合(允许包含 None / 空串,会被过滤) + + 返回: + 从高频到低频依次排列的去重后的周期字符串列表 + + 备注: + - 未登记的周期会被排到末尾(权重 10_000),并按字典序作次级排序 + - 排序结果稳定,方便用于 UI 展示与日志输出对齐 + """ + unique = {str(x) for x in freqs if x} + return sorted(unique, key=lambda x: (_FREQ_ORDER.get(x, 10_000), x)) + + +def signal_config_to_runtime(cfg: dict[str, Any]) -> dict[str, Any]: + """ + 将"用户层"信号配置 dict 转换为 Rust 后端期望的运行时三段式结构 + + 用户在 Python 端常以两种风格书写信号配置: + 风格 A(带 ``params`` 子字典): + {"name": "tas.cci_V230402", "freq": "30分钟", "params": {"di": 1, "n": 14}} + 风格 B(参数与 name/freq 平铺在同一层): + {"name": "tas.cci_V230402", "freq": "30分钟", "di": 1, "n": 14} + + 本函数把以上两种写法都归一为: + {"name": "cci_V230402", "freq": "30分钟", "params": {"di": 1, "n": 14}} + + 其中 ``name`` 会被 :func:`_strip_signal_name` 截断为最后一段(去模块前缀), + 便于 Rust 端按短名直接派发到信号实现。 + + 参数: + cfg: 任意一种风格的信号配置 dict + + 返回: + 三段式 dict(``name`` / ``freq`` / ``params``),可直接喂给 Rust API + """ + # 风格 A:已经显式区分 params,仅做名称清洗与浅拷贝 + if "params" in cfg: + return { + "name": _strip_signal_name(cfg["name"]), + "freq": cfg.get("freq"), + "params": dict(cfg.get("params", {})), + } + + # 风格 B:除 name/freq/signals_module/module 以外的所有键都视为参数 + # 这里之所以同时排除 signals_module 与 module,是因为不同代码版本曾用过两种命名, + # 都属于"模块定位元信息",不应进入 params。 + params = {} + for key, value in cfg.items(): + if key in {"name", "freq", "signals_module", "module"}: + continue + params[key] = value + return { + "name": _strip_signal_name(cfg["name"]), + "freq": cfg.get("freq"), + "params": params, + } + + +def signal_config_to_public(cfg: dict[str, Any], signals_module_name: str) -> dict[str, Any]: + """ + 将运行时三段式信号配置反向转换为"用户层平铺式"配置 + + 主要用途: + - 把 Rust 内部存储的紧凑配置对外展示给用户(如 dump 到 JSON) + - 当 ``name`` 缺少模块前缀时,用 ``signals_module_name`` 自动补齐, + 保证导出的配置可被独立加载(无需依赖外部上下文) + + 参数: + cfg: 任意风格的信号配置(先经 :func:`signal_config_to_runtime` 归一) + signals_module_name: 信号实现所在模块名,用于补全 name 前缀;为空则不补 + + 返回: + 平铺式 dict,name/freq 在外层,参数与之同级 + """ + runtime = signal_config_to_runtime(cfg) + name = runtime["name"] + # 若 name 不含点号,说明缺少模块前缀;按 signals_module_name 补全为完整路径 + if signals_module_name and "." not in name: + name = f"{signals_module_name}.{name}" + out = {"name": name, "freq": runtime.get("freq")} + out.update(runtime.get("params", {})) + return out + + +def position_dump_to_runtime(payload: dict[str, Any]) -> dict[str, Any]: + """ + 将 Position 的 JSON dump 结果转换为 Rust 运行时期望的事件/信号格式 + + Python 端 Position 的 ``opens`` / ``exits`` 字段中,每个 Event 的 + ``signals_all`` / ``signals_any`` / ``signals_not`` 元素既可能是 + ``"key_value"`` 字符串,也可能是 ``{"key": ..., "value": ...}`` 字典。 + Rust 端只接受字符串形式,本函数统一转为字符串。 + + 参数: + payload: Position.dump() 的输出(dict 形式) + + 返回: + 浅拷贝后的 dict,opens/exits 内的信号字段已全部规范化为字符串列表 + """ + out = dict(payload) + # 同时处理"开仓事件"与"平仓事件"两个分支 + for event_key in ("opens", "exits"): + events = [] + for event in list(out.get(event_key) or []): + event_copy = dict(event) + # 三种信号关系字段:必须存在但允许为空列表 + for sig_key in ("signals_all", "signals_any", "signals_not"): + event_copy[sig_key] = [ + signal_kv_to_string(sig) for sig in list(event_copy.get(sig_key) or []) + ] + events.append(event_copy) + out[event_key] = events + return out + + +def bars_to_dataframe(bars: Any, symbol: str | None = None) -> pd.DataFrame: + """ + 将多种形式的 K 线表示统一转换为 Rust IPC 读取器期望的 DataFrame + + 支持的输入: + - ``pd.DataFrame`` —— 直接拷贝并补齐缺失列 + - ``list[RawBar]`` / ``tuple[RawBar]`` —— 通过 getattr 读取属性逐行构造 + + 输出契约(必须严格满足,否则 Rust 端反序列化会报错): + 列顺序:``["symbol", "dt", "open", "close", "high", "low", "vol", "amount"]`` + 类型: + symbol -> str + dt -> datetime64[ns] + 其余 6 列 -> float64 + 清洗: + - 任一关键列为 NaN 的行会被丢弃 + - 同一时间戳重复时保留最后一条(last write wins) + - 按 dt 升序排列、重置索引 + + 参数: + bars: K 线集合(DataFrame 或 RawBar 列表) + symbol: 当 bars 中缺少 symbol 列或部分缺失时用于回填的标的代码 + + 返回: + 规范化后的 DataFrame,可直接传入 Rust IPC 读取通道 + + 异常: + TypeError: bars 类型不在支持列表中 + ValueError: 缺少必需列(dt 或转换后仍然缺失的其他列) + """ + if isinstance(bars, pd.DataFrame): + out = bars.copy() + elif isinstance(bars, (list, tuple)): + # 遍历对象列表,按字段名取值;缺失字段则使用 None 占位,后续 dropna 兜底 + rows = [] + for bar in bars: + rows.append( + { + "symbol": getattr(bar, "symbol", symbol), + "dt": getattr(bar, "dt", None), + "open": getattr(bar, "open", None), + "close": getattr(bar, "close", None), + "high": getattr(bar, "high", None), + "low": getattr(bar, "low", None), + "vol": getattr(bar, "vol", None), + "amount": getattr(bar, "amount", None), + } + ) + out = pd.DataFrame(rows) + else: + raise TypeError(f"unsupported bars type: {type(bars)!r}") + + # 兜底补齐 symbol 列:完全缺失就整列填充;部分缺失(NaN)就用 fillna 回填 + if "symbol" not in out.columns: + out["symbol"] = symbol + elif symbol: + out["symbol"] = out["symbol"].fillna(symbol) + + # dt 是核心字段,必须存在;缺失意味着源数据本身有问题,立即报错 + if "dt" not in out.columns: + raise ValueError("bars is missing dt column") + + # amount(成交额)允许由 vol*close 推算得出,方便上游只提供成交量的场景 + if "amount" not in out.columns: + out["amount"] = out["vol"] * out["close"] + + # 二次校验:所有必需列必须齐全 + required = ["symbol", "dt", "open", "close", "high", "low", "vol", "amount"] + missing = [col for col in required if col not in out.columns] + if missing: + raise ValueError(f"bars is missing columns: {missing}") + + out = out[required].copy() + out["dt"] = pd.to_datetime(out["dt"]) + out["symbol"] = out["symbol"].astype(str) + for col in ["open", "close", "high", "low", "vol", "amount"]: + # 强制转 float64:Rust IPC 读取器要求六个数值列均为 Float64; + # 而 wbt 提供的 mock 数据中 vol 默认是 int64,不显式转换会导致类型错误。 + out[col] = pd.to_numeric(out[col], errors="coerce").astype("float64") + # 删除关键字段缺失的脏行,按 dt 升序去重并重置索引 + out = out.dropna(subset=["dt", "open", "close", "high", "low", "vol", "amount"]) + out = out.sort_values("dt").drop_duplicates(subset=["dt"], keep="last").reset_index(drop=True) + return out + + +def normalize_candidate_event(event: dict[str, Any]) -> dict[str, Any]: + """ + 将一个候选 Event 字典归一化为标准结构 + + 历史上 Event 配置出现过两种风格: + 1. 旧风格:信号关系字段位于 ``factors[0]`` 子节点中 + 2. 新风格:信号关系字段直接在 Event 顶层 + + 本函数统一两种写法,输出固定结构,便于上游统一处理: + { + "name": str, # Event 名称(缺失时退回到 factors[0].name;再缺失为空串) + "operate": str, # 操作类型(必须存在,否则触发 KeyError) + "signals_all": list, # 必须全部命中 + "signals_any": list, # 至少一个命中 + "signals_not": list, # 必须全部不命中 + } + + 参数: + event: 任意风格的候选事件 dict + + 返回: + 标准化后的浅拷贝 dict + """ + raw = dict(event) + # 取出旧风格的 factors[0],若存在则其内的信号字段作为退路 + factors = list(raw.get("factors") or []) + factor = dict(factors[0]) if factors else {} + + # "顶层优先、factor 兜底"的取值策略,兼容两种风格 + signals_all = list(raw.get("signals_all") or factor.get("signals_all") or []) + signals_any = list(raw.get("signals_any") or factor.get("signals_any") or []) + signals_not = list(raw.get("signals_not") or factor.get("signals_not") or []) + name = raw.get("name") or factor.get("name") or "" + + return { + "name": name, + "operate": raw["operate"], + "signals_all": signals_all, + "signals_any": signals_any, + "signals_not": signals_not, + } + + +def normalize_candidate_events(events: Iterable[dict[str, Any]]) -> list[dict[str, Any]]: + """批量调用 :func:`normalize_candidate_event`,返回标准化后的事件列表""" + return [normalize_candidate_event(event) for event in events] + + +def md5_upper8(value: str) -> str: + """ + 计算字符串 MD5 哈希并截取前 8 位大写表示 + + 用途: + 生成简短、可复现的标识符(信号 ID、缓存 key 等),无加密强度要求。 + 若有抗碰撞需求,请改用 SHA-256 等更长哈希。 + """ + return hashlib.md5(value.encode("utf-8")).hexdigest()[:8].upper() + + +def py_escape_str(value: str) -> str: + """ + 转义字符串中的反斜杠和单引号,便于嵌入 Python 源代码字面量 + + 应用场景:把动态生成的字符串拼接进代码模板(如生成策略骨架文件), + 避免引号冲突和路径中的反斜杠被解释为转义符。 + """ + return value.replace("\\", "\\\\").replace("'", "\\'") + + +def py_repr_list_str(items: list[str]) -> str: + """ + 将字符串列表渲染为 Python 字面量形式的源码片段 + + 示例: + ['abc', "x'y"] -> "['abc', 'x\\'y']" + + 空列表直接返回 ``"[]"``,避免输出 ``"[\n]"`` 等空白扰动。 + """ + if not items: + return "[]" + return "[" + ", ".join(f"'{py_escape_str(item)}'" for item in items) + "]" + + +def py_repr_json(value: Any) -> str: + """ + 将任意 JSON 兼容值递归渲染为 Python 字面量字符串 + + 与 ``repr()`` 的差异: + - 总是使用单引号包裹字符串,便于嵌入双引号 docstring + - 对反斜杠和单引号执行 :func:`py_escape_str` 转义,避免破坏代码结构 + - bool 单独处理,避免被 isinstance(_, int) 分支错误吞掉(True/False 是 int 的子类) + + 支持类型: None / bool / int / float / str / list / dict + 其他类型走 str(value) 后递归处理(兜底)。 + """ + if value is None: + return "None" + if isinstance(value, bool): + # 必须放在 int 检查之前:bool 是 int 的子类,否则会被当成 0/1 输出 + return "True" if value else "False" + if isinstance(value, (int, float)): + return str(value) + if isinstance(value, str): + return f"'{py_escape_str(value)}'" + if isinstance(value, list): + return "[" + ", ".join(py_repr_json(item) for item in value) + "]" + if isinstance(value, dict): + return "{" + ", ".join( + f"'{py_escape_str(str(key))}': {py_repr_json(val)}" for key, val in value.items() + ) + "}" + # 兜底:把非典型类型按字符串处理,再走一次递归 + return py_repr_json(str(value)) + + +def load_json(path: str | Path) -> dict[str, Any]: + """读取 UTF-8 编码的 JSON 文件并解析为 dict(不做异常包装,由调用方处理 IO/解析错误)""" + return json.loads(Path(path).read_text(encoding="utf-8")) + + +def dump_json(path: str | Path, payload: dict[str, Any]) -> None: + """ + 将 dict 序列化为 UTF-8 编码的 JSON 文件 + + 使用 ``ensure_ascii=False`` 保留中文字符的可读性, + 便于人工查看与版本控制 diff(不会出现一堆 ``\\uXXXX``)。 + """ + Path(path).write_text(json.dumps(payload, ensure_ascii=False), encoding="utf-8") + + +def _strip_signal_name(name: str) -> str: + """ + 截取信号 name 中以最后一个 ``.`` 分隔的末段 + + 示例: + "czsc.signals.tas.cci_V230402" -> "cci_V230402" + "cci_V230402" -> "cci_V230402" + + Rust 端按短名直接派发到信号实现,因此调用底层前需要剥离模块前缀。 + """ + return str(name).split(".")[-1] + + +def signal_kv_to_string(signal: dict[str, Any] | str) -> str: + """ + 将信号的 ``{"key": ..., "value": ...}`` 字典形式合并为 ``"key_value"`` 字符串 + + 若入参已经是字符串则原样返回,便于在批量处理时统一调用而不必预判类型。 + Rust 端只接受字符串形式的信号匹配条件,故所有 Python 端的信号最终都会经此函数。 + """ + if isinstance(signal, str): + return signal + return f"{signal['key']}_{signal['value']}" diff --git a/czsc/_format_standard_kline.py b/czsc/_format_standard_kline.py new file mode 100644 index 000000000..84c9a1af5 --- /dev/null +++ b/czsc/_format_standard_kline.py @@ -0,0 +1,119 @@ +""" +公开 API ``format_standard_kline`` 的 Python 包装实现 + +功能定位: + 将标准列布局的 K 线 DataFrame 逐行转换为 ``list[RawBar]``,签名与 + rs-czsc 提供的 ``format_standard_kline(df, freq) -> list[RawBar]`` + 完全对齐,是 Python 端调用 Rust 缠论分析的"输入预处理"环节。 + +实现取舍: + - 不走 Rust 端的 pyarrow 字节流捷径(``format_standard_kline_bytes``), + 原因是 pyarrow 桥接代码尚未迁移到 czsc-python,引入会扩大依赖面 + - 直接遍历 DataFrame 并按行调用 PyO3 暴露的 ``RawBar`` 构造器, + 在 1 万根 K 线规模的测试输入下耗时可忽略(< 100 ms) + - 若未来出现热路径性能问题,可再考虑把 ``format_standard_kline_bytes`` + 迁移过来,本函数保持调用接口不变即可 +""" + +from __future__ import annotations + +import pandas as pd + +from czsc._native import Freq, RawBar + +# 仅暴露公开函数,避免 ``from ... import *`` 时把 _FREQ_MAP 等内部对象带出去 +__all__ = ["format_standard_kline"] + + +# 中文周期名 -> Rust Freq 枚举的查表映射 +# 注意事项: +# 1. 必须与 rs-czsc 的 ``python/rs_czsc/_utils/utils.py::format_standard_kline`` +# 保持完全同步,否则会导致同名周期在两端解析为不同枚举值 +# 2. 大型周期(季线/年线)使用频率较低但仍需登记,避免使用方手写枚举 +# 3. 未登记的字符串会在下方触发 KeyError,便于及时发现拼写错误 +_FREQ_MAP: dict[str, Freq] = { + "逐笔": Freq.Tick, + "1分钟": Freq.F1, + "2分钟": Freq.F2, + "3分钟": Freq.F3, + "4分钟": Freq.F4, + "5分钟": Freq.F5, + "6分钟": Freq.F6, + "10分钟": Freq.F10, + "12分钟": Freq.F12, + "15分钟": Freq.F15, + "20分钟": Freq.F20, + "30分钟": Freq.F30, + "60分钟": Freq.F60, + "120分钟": Freq.F120, + "日线": Freq.D, + "周线": Freq.W, + "月线": Freq.M, + "季线": Freq.S, + "年线": Freq.Y, +} + + +def format_standard_kline(df: pd.DataFrame, freq: Freq | str = Freq.F5) -> list[RawBar]: + """ + 将标准 K 线 DataFrame 转换为 RawBar 对象列表 + + 参数: + df: 标准 K 线布局,必须包含以下八列(缺一即报错): + ``dt`` - 时间戳(``datetime64[ns]`` 或可被 ``pd.to_datetime`` 解析) + ``symbol`` - 标的代码(任意可被 ``str()`` 表示的对象) + ``open`` - 开盘价 + ``close`` - 收盘价 + ``high`` - 最高价 + ``low`` - 最低价 + ``vol`` - 成交量 + ``amount`` - 成交额 + freq: Rust ``Freq`` 枚举值,或中文周期字符串(如 ``"30分钟"``)。 + 传字符串时会经 ``_FREQ_MAP`` 解析为枚举;未登记的字符串将抛 KeyError。 + 默认值 ``Freq.F5`` 仅作占位,生产环境务必显式传入正确周期。 + + 返回: + 与输入 DataFrame 行序一致的 RawBar 列表 + + 异常: + ValueError: DataFrame 缺少必需列 + KeyError: freq 是字符串但未登记于 ``_FREQ_MAP`` + + 备注: + - 函数对入参 df 不做原地修改:仅在 dt 列类型不匹配时才会做一次浅拷贝 + - 价格 / 成交量 / 成交额一律强制 ``float`` 化,与 RawBar 构造器签名匹配, + 避免 numpy 类型在 PyO3 边界产生隐式转换告警 + """ + # 字符串 freq 走查表;未登记的字符串会立即触发 KeyError,便于尽早暴露拼写错误 + if isinstance(freq, str): + freq = _FREQ_MAP[freq] + + # 严格列校验:八个字段缺一即拒绝处理,避免后续 itertuples 报出难懂的 AttributeError + required = ("dt", "symbol", "open", "close", "high", "low", "vol", "amount") + for col in required: + if col not in df.columns: + raise ValueError(f"format_standard_kline: missing column {col!r}") + + # 仅在 dt 列类型不匹配时做一次浅拷贝并转换,避免污染调用方的原始 DataFrame + if df["dt"].dtype != "datetime64[ns]": + df = df.copy() + df["dt"] = pd.to_datetime(df["dt"]) + + # 使用 itertuples(index=False) 而非 iterrows,可显著减少每行的属性访问开销 + # (itertuples 返回 namedtuple,字段访问是 C 层实现) + bars: list[RawBar] = [] + for row in df[list(required)].itertuples(index=False): + bars.append( + RawBar( + symbol=str(row.symbol), + dt=row.dt, + freq=freq, + open=float(row.open), + close=float(row.close), + high=float(row.high), + low=float(row.low), + vol=float(row.vol), + amount=float(row.amount), + ) + ) + return bars diff --git a/czsc/_utils/__init__.py b/czsc/_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/czsc/_utils/_df_convert.py b/czsc/_utils/_df_convert.py new file mode 100644 index 000000000..1687a9a85 --- /dev/null +++ b/czsc/_utils/_df_convert.py @@ -0,0 +1,77 @@ +""" +Pandas <-> Arrow IPC 字节流互转工具 + +用途: + 在 Python 与 Rust 之间通过 Arrow IPC 文件格式做零拷贝(或低拷贝)的数据传递。 + Arrow IPC 是 Rust polars / arrow-rs 原生支持的格式,相较于直接 PyO3 行级 + 传递可大幅降低跨语言边界的开销,尤其适合大数据量场景。 + +注意: + - 本模块属于 ``czsc._utils`` 内部工具,前缀 ``_`` 表示不属于公开 API + - 上层任何调用方都不应该假设字节流的内部结构稳定,仅可视为不透明二进制 + - PyArrow 与 Pandas 的版本组合需保持一致,否则可能在 Schema 推断时报错 +""" + +import pandas as pd +import pyarrow as pa +import pyarrow.ipc as ipc +from typing import Union + +def pandas_to_arrow_bytes(df: Union[pd.DataFrame, pd.Series]) -> bytes: + """ + 将 Pandas DataFrame/Series 序列化为 Arrow IPC 文件格式的字节流 + + 参数: + df: 输入的 Pandas DataFrame 或 Series;若是 Series 会被 PyArrow 自动 + 包装为单列 Table。所有列的 dtype 必须为 PyArrow 可识别的类型, + 否则在 ``pa.Table.from_pandas`` 阶段会抛 ArrowTypeError。 + + 返回: + bytes: 完整的 Arrow IPC File 字节流,可直接通过网络传输或写入文件, + 在另一端使用 :func:`arrow_bytes_to_pd_df` 还原。 + + 备注: + - 使用 IPC File(含尾部 footer)而非 IPC Stream 格式,便于随机读取 + - ``BufferOutputStream`` 在内存中累积,对 GB 级超大表不适合, + 这种规模请改用磁盘文件 + 流式写入的方案 + """ + # 第一步:把 Pandas DataFrame 转为 PyArrow Table,会做一次类型推断 + table = pa.Table.from_pandas(df) + + # 第二步:序列化为 Arrow IPC 文件格式 + # - sink 是 PyArrow 提供的内存缓冲;with 语句确保 footer 被正确写入 + sink = pa.BufferOutputStream() + with ipc.new_file(sink, table.schema) as writer: + writer.write_table(table) + + # 第三步:把 Buffer 转为 Python bytes,方便跨边界传递 + return sink.getvalue().to_pybytes() + + +def arrow_bytes_to_pd_df(arrow_bytes: bytes) -> pd.DataFrame: + """ + 将 Arrow IPC 字节流反序列化为 Pandas DataFrame + + 参数: + arrow_bytes: 由 :func:`pandas_to_arrow_bytes` 或同等格式产生的字节串。 + 不接受 Arrow IPC Stream 格式,传入会触发 InvalidArgument。 + + 返回: + pd.DataFrame: 与原始 DataFrame 等价的对象。 + 索引信息若在序列化时存在 schema 元数据中会被还原;否则 + 保持默认 RangeIndex。 + + 备注: + ``read_all()`` 一次性把所有 RecordBatch 加载到内存。对于巨型表,应改用 + ``reader.get_record_batch(i)`` 逐批读取以控制内存峰值。 + """ + # 用 BufferReader 把 bytes 包装为可读流 + buffer = pa.BufferReader(arrow_bytes) + + # 通过 IPC 文件格式读取 Arrow Table(含 schema 与 footer 校验) + with ipc.open_file(buffer) as reader: + table = reader.read_all() + + # Arrow Table -> Pandas DataFrame 会涉及一次列级别的 zero-copy / copy 决策, + # 由 PyArrow 内部根据 dtype 自行选择,调用方无需关心。 + return table.to_pandas() diff --git a/czsc/connectors/cooperation.py b/czsc/connectors/cooperation.py new file mode 100644 index 000000000..6d969440c --- /dev/null +++ b/czsc/connectors/cooperation.py @@ -0,0 +1,1026 @@ +# -*- coding: utf-8 -*- +""" +作者: zengbin93 +邮箱: zeng_bin8888@163.com +创建时间: 2023/11/15 20:45 +模块说明: + CZSC 开源协作团队内部使用的数据接口模块。 + + 本模块封装了 CZSC 团队内部投研共享数据 API,提供统一的数据获取入口,主要职责包括: + + 1. 行情数据接入: + - 获取 A 股、ETF、A 股指数、南华指数、期货主力合约等品种列表 + - 获取上述品种的多周期 K 线数据(1 分钟到月线均支持) + - 自动处理期货主力合约的分段下载与拼接、股指期货的有效交易时段过滤 + + 2. 全市场日线数据: + - 提供 ``stocks_daily_klines`` 函数,按月分段下载全市场 A 股日线 + - 集成 ``czsc.disk_cache`` 磁盘缓存,加速重复研究 + + 3. 策略权重协作: + - ``upload_strategy``:将本地策略持仓权重上传至共享服务器 + - ``get_stk_strategy``:获取 STK 系列子策略的历史持仓权重 + - ``get_strategy_dailys`` / ``get_strategy_weights``:带本地缓存的增量刷新接口 + - ``StrategyClient``:面向对象封装的策略管理 HTTP 客户端 + + 使用场景: + 团队成员在内部进行因子研究、策略回测、组合管理时调用,统一数据口径。 + + 注意事项: + - 首次使用前需要在终端中通过 ``czsc.set_url_token`` 设置 token, + 或者通过环境变量 ``CZSC_TOKEN`` 配置; + - 数据接口由内部服务 ``http://zbczsc.com:9106`` 提供,外网用户无权访问; + - 缓存目录默认位于 ``~/.quant_data_cache``,可以通过环境变量 ``CZSC_CACHE_PATH`` 覆盖。 +""" +import os +import time +import czsc +import requests +import loguru +import pandas as pd +from tqdm import tqdm +from pathlib import Path +from datetime import datetime +from czsc import RawBar, Freq +from typing import Dict, List, Any + +# 首次使用需要打开一个 Python 终端按如下方式设置 token,或者直接在环境变量中设置 CZSC_TOKEN +# 示例:czsc.set_url_token(token='your token', url='http://zbczsc.com:9106') + +# 本地缓存目录:优先使用环境变量 CZSC_CACHE_PATH,否则使用用户主目录下的 .quant_data_cache +cache_path = os.getenv("CZSC_CACHE_PATH", os.path.expanduser("~/.quant_data_cache")) +# 数据 API 服务地址:优先使用环境变量 CZSC_DATA_API,否则使用默认内部服务地址 +url = os.getenv("CZSC_DATA_API", "http://zbczsc.com:9106") +# 全局 DataClient 实例:复用同一份 token 和缓存目录,避免重复创建 +dc = czsc.DataClient(token=os.getenv("CZSC_TOKEN"), url=url, cache_path=cache_path) + + +def get_groups(): + """获取投研共享数据的可选分组名称列表。 + + 用于上层在选择标的池时枚举可用分组,避免硬编码字符串。 + + :return: list[str], 内置支持的分组名称集合,依次为: + "A股指数"、"ETF"、"股票"、"期货主力"、"南华指数" + """ + return ["A股指数", "ETF", "股票", "期货主力", "南华指数"] + + +def get_symbols(name, **kwargs): + """获取指定分组下的所有标的代码。 + + 根据传入的分组名称,调用底层数据接口拉取对应的标的列表,并按照 + CZSC 内部约定的命名规范(``code#资产类型``)返回。 + + :param name: str, 分组名称,可选值: + - "A股指数":上交所/深交所指数,返回形如 ``000001.SH#INDEX`` + - "ETF":在 2024-04-02 仍有交易的 ETF,返回形如 ``510050.SH#ETF`` + - "股票":当前在交易的 A 股,返回形如 ``000001.SZ#STOCK`` + - "期货主力":期货主力合约 9001 系列,原样返回 + - "南华指数":南华商品指数,原样返回 + - "ALL":以上所有分组的并集 + :param kwargs: dict, 兼容参数,当前未使用,保留以兼容上层调用 + :return: list[str], 标的代码列表 + :raises ValueError: 当传入的分组名称无法识别时抛出 + """ + # 股票分组:从基础信息表中拉取所有正在交易的标的(status=1) + if name == "股票": + df = dc.stock_basic(nobj=1, status=1, ttl=3600 * 6) + symbols = [f"{row['code']}#STOCK" for _, row in df.iterrows()] + return symbols + + # ETF 分组:先拿到全量 ETF 基础信息,再用某一交易日的成交记录过滤掉已退市/无成交的 + if name == "ETF": + df = dc.etf_basic(v="2", fields="code,name", ttl=3600 * 6) + dfk = dc.pro_bar(trade_date="2024-04-02", asset="e", v="2") + df = df[df["code"].isin(dfk["code"])].reset_index(drop=True) + symbols = [f"{row['code']}#ETF" for _, row in df.iterrows()] + return symbols + + # A 股指数:仅取上交所、深交所市场的指数代码 + if name == "A股指数": + # 指数说明文档(仅限内部):https://s0cqcxuy3p.feishu.cn/wiki/KuSAweAAhicvsGk9VPTc1ZWKnAd + df = dc.index_basic(v="2", market="SSE,SZSE", ttl=3600 * 6) + symbols = [f"{row['code']}#INDEX" for _, row in df.iterrows()] + return symbols + + # 南华指数:商品期货综合指数序列 + if name == "南华指数": + df = dc.index_basic(v="2", market="NH", ttl=3600 * 6) + symbols = [row["code"] for _, row in df.iterrows()] + return symbols + + # 期货主力:通过某一交易日的全量行情快照拿到所有主力合约代码 + if name == "期货主力": + kline = dc.future_klines(v="2", trade_date="20240402", ttl=-1) + return kline["code"].unique().tolist() + + # ALL:把上面所有分组拼接起来,得到全市场标的代码 + if name.upper() == "ALL": + symbols = get_symbols("股票") + get_symbols("ETF") + symbols += get_symbols("A股指数") + get_symbols("南华指数") + get_symbols("期货主力") + return symbols + + raise ValueError(f"{name} 分组无法识别,获取标的列表失败!") + + +def get_min_future_klines(code, sdt, edt, freq="1m", **kwargs): + """分段获取期货 1 分钟 K 线后合并为一个 DataFrame。 + + 由于 1 分钟 K 线数据量大,单次请求容易超时或被服务端截断,因此本函数 + 将请求按年切分(最长 365 天为一个分段),分批拉取后再合并。 + 同时会针对股指期货(IC/IF/IH)做盘中有效交易时段过滤,剔除集合竞价等无效分钟。 + + :param code: str, 期货合约代码,例如 "SFIC9001"、"DLi9001" + :param sdt: str | datetime, 开始日期,可任意常见格式 + :param edt: str | datetime, 结束日期 + :param freq: str, 频率字符串,默认为 "1m" + :param kwargs: dict, 可选参数 + - logger: 日志记录器,默认使用 loguru.logger + - ttl: int, 缓存有效期(秒),默认 3600;历史段落自动设为 -1 长期缓存 + :return: pd.DataFrame, 合并、去重后的 K 线数据,包含 + dt、symbol、open、close、high、low、vol 等字段 + """ + logger = kwargs.pop("logger", loguru.logger) + + sdt = pd.to_datetime(sdt).strftime("%Y%m%d") + edt = pd.to_datetime(edt).strftime("%Y%m%d") + # 按 365 天为一个分段,覆盖 2000~2030 年的全部时间区间 + # dates = pd.date_range(start=sdt, end=edt, freq='1M') # 旧的按月切分方式,现已弃用 + dates = pd.date_range(start="20000101", end="20300101", freq="365D") + + dates = [d.strftime("%Y%m%d") for d in dates] + dates = sorted(list(set(dates))) + + rows = [] + # 遍历每一个分段区间 [sdt_, edt_),按需要逐段拉取 + for sdt_, edt_ in tqdm(zip(dates[:-1], dates[1:]), total=len(dates) - 1): + # 该段结束时间早于查询开始时间,跳过 + if edt_ < sdt: + continue + + # 该段开始时间已经超过当前日期,后续段都是未来数据,直接终止 + if pd.to_datetime(sdt_).date() >= datetime.now().date(): + break + + # 历史已结束的段落使用永久缓存(ttl=-1),尚未结束的段落使用短期缓存 + ttl = kwargs.get("ttl", 60 * 60) if pd.to_datetime(edt_).date() >= datetime.now().date() else -1 + df = dc.future_klines(code=code, sdt=sdt_, edt=edt_, freq=freq, ttl=ttl, v="2") + if df.empty: + continue + logger.info(f"{code}获取K线范围:{df['dt'].min()} - {df['dt'].max()}") + rows.append(df) + + df = pd.concat(rows, ignore_index=True) + df.rename(columns={"code": "symbol"}, inplace=True) + df["dt"] = pd.to_datetime(df["dt"]) + # 不同段之间的边界可能重复,按 (dt, symbol) 去重,保留最后一次拉到的版本 + df = df.drop_duplicates(subset=["dt", "symbol"], keep="last") + + if code in ["SFIC9001", "SFIF9001", "SFIH9001"]: + # 股指期货只保留连续竞价时段:09:31-11:30 与 13:01-15:00,剔除集合竞价及夜盘 + dt1 = datetime.strptime("09:31:00", "%H:%M:%S") + dt2 = datetime.strptime("11:30:00", "%H:%M:%S") + c1 = (df["dt"].dt.time >= dt1.time()) & (df["dt"].dt.time <= dt2.time()) + + dt3 = datetime.strptime("13:01:00", "%H:%M:%S") + dt4 = datetime.strptime("15:00:00", "%H:%M:%S") + c2 = (df["dt"].dt.time >= dt3.time()) & (df["dt"].dt.time <= dt4.time()) + + df = df[c1 | c2].copy().reset_index(drop=True) + + # 最后再按用户指定的 [sdt, edt] 区间裁剪结果 + df = df[(df["dt"] >= pd.to_datetime(sdt)) & (df["dt"] <= pd.to_datetime(edt))].copy().reset_index(drop=True) + return df + + +def get_raw_bars(symbol, freq, sdt, edt, fq="前复权", **kwargs): + """获取 CZSC 库定义的标准 RawBar 对象列表(统一数据入口)。 + + 本函数是协作数据源的统一行情接口,根据标的后缀自动分发到不同的底层取数函数: + - ``9001`` 结尾:期货主力合约 + - ``.NH`` 结尾:南华指数(仅支持日线) + - 含 ``SH`` 或 ``SZ``:A 股 / ETF / 指数 + + :param symbol: str, 标的代码,需符合 ``code#资产类型`` 规范,例如: + - "000001.SH#INDEX" + - "510050.SH#ETF" + - "000001.SZ#STOCK" + - "SFIC9001"(期货主力) + - "NHCI.NH"(南华指数) + :param freq: str | czsc.Freq, K 线周期。支持字符串 "1分钟"、"5分钟"、"15分钟"、 + "30分钟"、"60分钟"、"日线"、"周线"、"月线"、"季线"、"年线" + :param sdt: str | datetime, 开始时间 + :param edt: str | datetime, 结束时间 + :param fq: str, 除权类型,可选 "前复权"、"后复权"、"不复权"。 + 注意:期货主力合约暂不支持前复权,会自动切换为后复权 + :param kwargs: dict, 可选参数 + - logger: 日志记录器 + - raw_bars: bool, 是否返回 RawBar 对象列表,False 时返回 DataFrame,默认 True + - ttl: int, 缓存有效期(秒),默认 -1(永久缓存) + :return: list[RawBar] | pd.DataFrame, 取决于 ``raw_bars`` 参数 + :raises ValueError: 当 symbol 无法识别,或南华指数请求非日线周期时抛出 + + 示例: + >>> from czsc.connectors import cooperation as coo + >>> df = coo.get_raw_bars(symbol="000001.SH#INDEX", freq="日线", + ... sdt="2001-01-01", edt="2021-12-31", + ... fq='后复权', raw_bars=False) + """ + logger = kwargs.pop("logger", loguru.logger) + + freq = czsc.Freq(freq) + raw_bars = kwargs.get("raw_bars", True) + ttl = kwargs.get("ttl", -1) + sdt = pd.to_datetime(sdt).strftime("%Y%m%d") + edt = pd.to_datetime(edt).strftime("%Y%m%d") + + # 分支一:期货主力合约(代码以 9001 结尾) + if symbol.endswith("9001"): + # 期货主力合约的复权说明(仅限内部): + # https://s0cqcxuy3p.feishu.cn/wiki/WLGQwJLWQiWPCZkPV7Xc3L1engg + if fq == "前复权": + logger.warning("期货主力合约暂时不支持前复权,已自动切换为后复权") + + # 根据目标频率判断底层基础周期:分钟取 1m,日及以上取 1d + freq_rd = "1m" if freq.value.endswith("分钟") else "1d" + if freq.value.endswith("分钟"): + df = get_min_future_klines(code=symbol, sdt=sdt, edt=edt, freq="1m", ttl=ttl) + if df.empty: + return df + + # 接口返回缺失成交额时,用 vol*close 近似估算 + if "amount" not in df.columns: + df["amount"] = df["vol"] * df["close"] + + df = df[["symbol", "dt", "open", "close", "high", "low", "vol", "amount"]].copy().reset_index(drop=True) + df["dt"] = pd.to_datetime(df["dt"]) + return czsc.resample_bars(df, target_freq=freq, raw_bars=raw_bars, base_freq="1分钟") + + else: + df = dc.future_klines(code=symbol, sdt=sdt, edt=edt, freq=freq_rd, ttl=ttl, v="2") + if df.empty: + return df + + df.rename(columns={"code": "symbol"}, inplace=True) + if "amount" not in df.columns: + df["amount"] = df["vol"] * df["close"] + + df = df[["symbol", "dt", "open", "close", "high", "low", "vol", "amount"]].copy().reset_index(drop=True) + df["dt"] = pd.to_datetime(df["dt"]) + return czsc.resample_bars(df, target_freq=freq, raw_bars=raw_bars) + + # 分支二:南华指数(代码以 .NH 结尾),仅日线 + if symbol.endswith(".NH"): + if freq != Freq.D: + raise ValueError("南华指数只支持日线数据") + df = dc.nh_daily(code=symbol, sdt=sdt, edt=edt, ttl=ttl, v="2") + df.rename(columns={"code": "symbol", "volume": "vol"}, inplace=True) + df["dt"] = pd.to_datetime(df["dt"]) + return czsc.resample_bars(df, target_freq=freq, raw_bars=raw_bars) + + # 分支三:A 股 / ETF / 指数(代码包含 SH 或 SZ 后缀) + if "SH" in symbol or "SZ" in symbol: + # 复权类型映射:本地中文枚举到底层接口缩写 + fq_map = {"前复权": "qfq", "后复权": "hfq", "不复权": None} + adj = fq_map.get(fq, None) + + # 标的代码格式为 "code#asset",asset 取首字母即可(s/e/i 等) + code, asset = symbol.split("#") + + if freq.value.endswith("分钟"): + df = dc.pro_bar(code=code, sdt=sdt, edt=edt, freq="min", adj=adj, asset=asset[0].lower(), v="2", ttl=ttl) + # 09:30:00 是集合竞价撮合时间,剔除该时刻避免与 09:31 重复 + df = df[~df["dt"].str.endswith("09:30:00")].reset_index(drop=True) + df.rename(columns={"code": "symbol"}, inplace=True) + df["dt"] = pd.to_datetime(df["dt"]) + return czsc.resample_bars(df, target_freq=freq, raw_bars=raw_bars, base_freq="1分钟") + + else: + df = dc.pro_bar(code=code, sdt=sdt, edt=edt, freq="day", adj=adj, asset=asset[0].lower(), v="2", ttl=ttl) + df.rename(columns={"code": "symbol"}, inplace=True) + df["dt"] = pd.to_datetime(df["dt"]) + return czsc.resample_bars(df, target_freq=freq, raw_bars=raw_bars) + + raise ValueError(f"symbol {symbol} 无法识别,获取数据失败!") + + +@czsc.disk_cache(path=cache_path, ttl=-1) +def stocks_daily_klines(sdt="20170101", edt="20240101", **kwargs): + """获取全市场 A 股的日线数据(按月分段拉取并拼接,结果带磁盘缓存)。 + + 为避免单次接口请求数据量过大被服务端截断,本函数会把 ``[sdt, edt]`` 区间 + 按自然月切分成多个子区间,逐月拉取再合并;并通过 ``czsc.disk_cache`` + 把整个结果缓存到磁盘,重复研究时直接命中缓存。 + + :param sdt: str, 开始日期,默认 "20170101" + :param edt: str, 结束日期,默认 "20240101" + :param kwargs: dict, 可选参数 + - adj: str, 复权类型,传给底层 ``pro_bar``,默认 "hfq"(后复权) + - exclude_bj: bool, 是否剔除北交所标的(.BJ 结尾),默认 True + - nxb: tuple[int], 未来 N 日收益的窗口列表,默认 [1, 2, 5, 10, 20, 30, 60]; + 传入空值时跳过未来收益计算 + :return: pd.DataFrame, 字段包含 symbol、dt、open、close、high、low、vol、amount、price, + 以及由 ``nxb`` 生成的 n1b、n2b 等未来收益列 + """ + adj = kwargs.get("adj", "hfq") + + # 转换为 datetime 对象,便于后续计算 + start_dt = pd.to_datetime(sdt) + end_dt = pd.to_datetime(edt) + + # 计算 sdt 和 edt 之间的每个月 1 号,得到分段下载的 [sdt_, edt_) 区间列表 + date_spans = [] + current = start_dt.replace(day=1) # 从月初开始 + while current <= end_dt: + sdt_ = current.strftime("%Y%m%d") + edt_ = (current + pd.DateOffset(months=1)).replace(day=1).strftime("%Y%m%d") + date_spans.append((sdt_, edt_)) + current = (current + pd.DateOffset(months=1)).replace(day=1) + + res = [] + for sdt_, edt_ in date_spans: + # 当前月份使用较短缓存时间(6 小时),历史月份使用永久缓存 + ttl = 3600 * 6 if edt_ < pd.Timestamp.now().strftime("%Y%m%d") else -1 + kline = dc.pro_bar(sdt=sdt_, edt=edt_, adj=adj, v="2", ttl=ttl) + res.append(kline) + + dfk = pd.concat(res, ignore_index=True) + dfk["dt"] = pd.to_datetime(dfk["dt"]) + dfk = dfk.sort_values(["code", "dt"], ascending=True).reset_index(drop=True) + # 默认剔除北交所(.BJ),数据可能不完整且大多数策略并不交易北交所 + if kwargs.get("exclude_bj", True): + dfk = dfk[~dfk["code"].str.endswith(".BJ")].reset_index(drop=True) + + dfk = dfk.rename(columns={"code": "symbol"}) + # 跨月份拼接可能存在重复行,按 (symbol, dt) 去重,保留最后一次出现的记录 + dfk = dfk.drop_duplicates(subset=["symbol", "dt"], keep="last").reset_index(drop=True) + dfk["price"] = dfk["close"] + nxb = kwargs.get("nxb", [1, 2, 5, 10, 20, 30, 60]) + if nxb: + dfk = czsc.update_nxb(dfk, nseq=nxb) + return dfk + + +def upload_strategy(df, meta, token=None, **kwargs): + """上传策略数据到协作服务器。 + + 将本地策略生成的持仓权重以及策略元数据上传至 ``http://zbczsc.com:9106``, + 用于团队内部的共享研究、批量回测或风险监控。 + + :param df: pd.DataFrame, 策略持仓权重数据,至少包含 dt, symbol, weight 三列, 例如: + + =================== ======== ======== + dt symbol weight + =================== ======== ======== + 2017-01-03 09:01:00 ZZSF9001 0 + 2017-01-03 09:01:00 DLj9001 0 + 2017-01-03 09:01:00 SQag9001 0 + 2017-01-03 09:06:00 ZZSF9001 0.136364 + 2017-01-03 09:06:00 SQag9001 1 + =================== ======== ======== + + :param meta: dict, 策略元数据 + + 至少包含 name, description, base_freq, author, outsample_sdt 字段, 例如: + + {'name': 'TS001_3', + 'description': '测试策略:仅用于读写redis测试', + 'base_freq': '1分钟', + 'author': 'ZB', + 'outsample_sdt': '20220101'} + + :param token: str, 上传凭证码;如果不提供,将从环境变量 CZSC_TOKEN 中获取 + :param kwargs: dict, 其他参数 + + - logger: loguru.logger, 日志记录器 + :return: dict, 服务端响应结果 + """ + logger = kwargs.pop("logger", loguru.logger) + df = df.copy() + df["dt"] = pd.to_datetime(df["dt"]) + logger.info(f"输入数据中有 {len(df)} 条权重信号") + + # 去除单个品种下相邻时间权重相同的数据,节省传输与存储成本 + _res = [] + for _, dfg in df.groupby("symbol"): + dfg = dfg.sort_values("dt", ascending=True).reset_index(drop=True) + dfg = dfg[dfg["weight"].diff().fillna(1) != 0].copy() + _res.append(dfg) + df = pd.concat(_res, ignore_index=True) + df = df.sort_values(["dt"]).reset_index(drop=True) + df["dt"] = df["dt"].dt.strftime("%Y-%m-%d %H:%M:%S") + + logger.info(f"去除单个品种下相邻时间权重相同的数据后,剩余 {len(df)} 条权重信号") + + # 构造上传 payload:weights 字段使用 split 方向的 JSON 以便服务端高效解析 + data = { + "weights": df[["dt", "symbol", "weight"]].to_json(orient="split"), + "token": token or os.getenv("CZSC_TOKEN"), + "strategy_name": meta.get("name"), + "meta": meta, + } + response = requests.post("http://zbczsc.com:9106/upload_strategy", json=data) + + logger.info(f"上传策略接口返回: {response.json()}") + return response.json() + + +def get_stk_strategy(name="STK_001", **kwargs): + """获取 STK 系列子策略的持仓权重数据,并匹配未来 1 日收益。 + + 本函数面向选股策略的研究场景,会自动拼接持仓权重与对应日期的下一日收益(n1b), + 便于直接进行 IC 分析、组合收益归因等。 + + :param name: str, 子策略名称,例如 "STK_001"、"STK_002" + :param kwargs: dict + sdt: str, 可选, 开始日期,默认 "20170101" + edt: str, 可选, 结束日期,默认当前日期 + ttl: int, 可选, 接口缓存时间(秒),默认 6 小时 + :return: pd.DataFrame, 字段包含 dt、symbol、weight、n1b + """ + dfw = dc.post_request(api_name=name, v="2", hist=1, ttl=kwargs.get("ttl", 3600 * 6)) + dfw["dt"] = pd.to_datetime(dfw["dt"]) + sdt = kwargs.get("sdt", "20170101") + edt = pd.Timestamp.now().strftime("%Y%m%d") + edt = kwargs.get("edt", edt) + dfw = dfw[(dfw["dt"] >= pd.to_datetime(sdt)) & (dfw["dt"] <= pd.to_datetime(edt))].copy().reset_index(drop=True) + + # 拉取同区间的全市场日线(含未来 1、2 日收益),与权重表做左连接 + dfb = stocks_daily_klines(sdt=sdt, edt=edt, nxb=(1, 2)) + dfw = pd.merge(dfw, dfb, on=["dt", "symbol"], how="left") + dfh = dfw[["dt", "symbol", "weight", "n1b"]].copy() + return dfh + + +# ====================================================================================================================== +# 增量更新本地缓存数据 +# ---------------------------------------------------------------------------------------------------------------------- +# 下面这一组函数(get_all_strategies / get_strategy_dailys / get_strategy_weights)共同构成了 +# “首次全量、后续增量”的本地缓存机制: +# - 第一次调用时一次性拉取全部历史数据,落盘成 feather 文件; +# - 后续调用时只拉取最近几天的新数据,与本地缓存做合并、去重,再写回缓存; +# - 当缓存能完全覆盖请求区间时直接返回,避免任何远程请求。 +# ====================================================================================================================== +def get_all_strategies(ttl=3600 * 24 * 7, logger=loguru.logger, path=cache_path): + """获取所有策略的元数据。 + + 元数据描述了每个策略的基本信息(名称、作者、基础频率、样本外起始日期等), + 主要用于在策略池中筛选、展示和组合。 + + :param ttl: int, 可选, 缓存有效期(秒),默认 7 天 + :param logger: loguru.logger, 可选, 日志记录器 + :param path: str, 可选, 缓存根目录路径 + :return: pd.DataFrame, 包含字段 name, description, author, base_freq, outsample_sdt;示例如下: + + =========== ===================== ========= ========= ============ + name description author base_freq outsample_sdt + =========== ===================== ========= ========= ============ + STK_001 A股选股策略 ZB 1分钟 20220101 + STK_002 A股选股策略 ZB 1分钟 20220101 + STK_003 A股选股策略 ZB 1分钟 20220101 + =========== ===================== ========= ========= ============ + """ + path = Path(path) / "strategy" + path.mkdir(exist_ok=True, parents=True) + file_metas = path / "metas.feather" + + # 缓存文件存在且未过期:直接读取,避免远程请求 + if file_metas.exists() and (time.time() - file_metas.stat().st_mtime) < ttl: + logger.info("【缓存命中】获取所有策略的元数据") + dfm = pd.read_feather(file_metas) + + else: + logger.info("【全量刷新】获取所有策略的元数据并刷新缓存") + dfm = dc.get_all_strategies(v="2", ttl=0) + dfm.to_feather(file_metas) + + return dfm + + +def __update_strategy_dailys(file_cache, strategy, logger=loguru.logger): + """更新(增量或全量)策略的日收益缓存数据。 + + 内部辅助函数: + - 若缓存文件已存在,从缓存最新日期向前回溯 3 天再向后拉取,做增量合并; + - 若缓存文件不存在,自 20170101 起做全量拉取。 + + :param file_cache: pathlib.Path, 缓存文件路径(feather 格式) + :param strategy: str, 策略名称 + :param logger: loguru.logger, 日志记录器 + :return: pd.DataFrame, 更新后的完整日收益数据 + """ + # 增量刷新分支:基于缓存最新日期推断需要补齐的区间 + if file_cache.exists(): + df = pd.read_feather(file_cache) + + # 向前回溯 3 天作为缓冲,避免最新一日数据修订带来的差异 + cache_sdt = (df["dt"].max() - pd.Timedelta(days=3)).strftime("%Y%m%d") + cache_edt = (pd.Timestamp.now() + pd.Timedelta(days=1)).strftime("%Y%m%d") + logger.info(f"【增量刷新缓存】获取策略 {strategy} 的日收益数据:{cache_sdt} - {cache_edt}") + + dfc = dc.sub_strategy_dailys(strategy=strategy, v="2", sdt=cache_sdt, edt=cache_edt, ttl=0) + dfc["dt"] = pd.to_datetime(dfc["dt"]) + df = pd.concat([df, dfc]).drop_duplicates(["dt", "symbol", "strategy"], keep="last") + + else: + # 全量刷新分支:缓存不存在时一次性拉满历史数据 + cache_edt = (pd.Timestamp.now() + pd.Timedelta(days=1)).strftime("%Y%m%d") + logger.info(f"【全量刷新缓存】获取策略 {strategy} 的日收益数据:20170101 - {cache_edt}") + df = dc.sub_strategy_dailys(strategy=strategy, v="2", sdt="20170101", edt=cache_edt, ttl=0) + + df = df.reset_index(drop=True) + df["dt"] = pd.to_datetime(df["dt"]) + df.to_feather(file_cache) + return df + + +def get_strategy_dailys( + strategy="FCS001", symbol=None, sdt="20240101", edt=None, logger=loguru.logger, path=cache_path +): + """获取策略的历史日收益数据(带本地缓存)。 + + 优先尝试命中本地缓存;若缓存数据不能覆盖请求的结束日期,则触发增量刷新。 + + :param strategy: str, 策略名称 + :param symbol: str, 可选, 品种名称,传入后只返回该品种的数据 + :param sdt: str, 开始时间,默认 "20240101" + :param edt: str, 可选, 结束时间,默认当前时间 + :param logger: loguru.logger, 可选, 日志记录器 + :param path: str, 可选, 缓存根目录路径 + :return: pd.DataFrame, 包含字段 dt, symbol, strategy, returns;示例如下: + + =================== ========== ======== ========= + dt strategy symbol returns + =================== ========== ======== ========= + 2017-01-10 00:00:00 STK_001 A股选股 0.001 + 2017-01-11 00:00:00 STK_001 A股选股 0.012 + 2017-01-12 00:00:00 STK_001 A股选股 0.011 + =================== ========== ======== ========= + """ + path = Path(path) / "strategy" / "dailys" + path.mkdir(exist_ok=True, parents=True) + file_cache = path / f"{strategy}.feather" + + if edt is None: + edt = pd.Timestamp.now().strftime("%Y%m%d %H:%M:%S") + + # 判断缓存数据是否能满足需求:缓存最新日期 >= 请求结束日期即视为命中 + if file_cache.exists(): + df = pd.read_feather(file_cache) + + if df["dt"].max() >= pd.Timestamp(edt): + logger.info(f"【缓存命中】获取策略 {strategy} 的日收益数据:{sdt} - {edt}") + + dfd = df[(df["dt"] >= pd.Timestamp(sdt)) & (df["dt"] <= pd.Timestamp(edt))].copy() + if symbol: + dfd = dfd[dfd["symbol"] == symbol].copy() + return dfd + + # 缓存未命中或数据不全:触发刷新后再过滤返回 + logger.info(f"【缓存刷新】获取策略 {strategy} 的日收益数据:{sdt} - {edt}") + df = __update_strategy_dailys(file_cache, strategy, logger=logger) + dfd = df[(df["dt"] >= pd.Timestamp(sdt)) & (df["dt"] <= pd.Timestamp(edt))].copy() + if symbol: + dfd = dfd[dfd["symbol"] == symbol].copy() + return dfd + + +def __update_strategy_weights(file_cache, strategy, logger=loguru.logger): + """更新(增量或全量)策略的持仓权重缓存数据。 + + 内部辅助函数,逻辑与 ``__update_strategy_dailys`` 类似,区别在于: + - 数据接口为 ``post_request(api_name=strategy, hist=1)``; + - 去重键使用 ``(dt, symbol, weight)``。 + + :param file_cache: pathlib.Path, 缓存文件路径 + :param strategy: str, 策略名称 + :param logger: loguru.logger, 日志记录器 + :return: pd.DataFrame, 更新后的完整持仓权重数据 + """ + # 增量刷新分支 + if file_cache.exists(): + df = pd.read_feather(file_cache) + + cache_sdt = (df["dt"].max() - pd.Timedelta(days=3)).strftime("%Y%m%d") + cache_edt = (pd.Timestamp.now() + pd.Timedelta(days=1)).strftime("%Y%m%d") + logger.info(f"【增量刷新缓存】获取策略 {strategy} 的持仓权重数据:{cache_sdt} - {cache_edt}") + + dfc = dc.post_request(api_name=strategy, v="2", sdt=cache_sdt, edt=cache_edt, hist=1, ttl=0) + dfc["dt"] = pd.to_datetime(dfc["dt"]) + dfc["strategy"] = strategy + + df = pd.concat([df, dfc]).drop_duplicates(["dt", "symbol", "weight"], keep="last") + + else: + # 全量刷新分支 + cache_edt = (pd.Timestamp.now() + pd.Timedelta(days=1)).strftime("%Y%m%d") + logger.info(f"【全量刷新缓存】获取策略 {strategy} 的持仓权重数据:20170101 - {cache_edt}") + df = dc.post_request(api_name=strategy, v="2", sdt="20170101", edt=cache_edt, hist=1, ttl=0) + df["dt"] = pd.to_datetime(df["dt"]) + df["strategy"] = strategy + + df = df.reset_index(drop=True) + df.to_feather(file_cache) + return df + + +def get_strategy_weights(strategy="FCS001", sdt="20240101", edt=None, logger=loguru.logger, path=cache_path): + """获取策略的历史持仓权重数据(带本地缓存)。 + + 缓存命中策略与 ``get_strategy_dailys`` 一致:缓存最新日期不小于请求结束日期视为命中。 + + :param strategy: str, 策略名称 + :param sdt: str, 开始时间,默认 "20240101" + :param edt: str, 可选, 结束时间,默认当前时间 + :param logger: loguru.logger, 可选, 日志记录器 + :param path: str, 可选, 缓存根目录路径 + :return: pd.DataFrame, 包含字段 dt, symbol, weight, update_time, strategy;示例如下: + + =================== ========= ======== =================== ========== + dt symbol weight update_time strategy + =================== ========= ======== =================== ========== + 2017-01-09 00:00:00 000001.SZ 0 2024-07-27 16:13:29 STK_001 + 2017-01-10 00:00:00 000001.SZ 0 2024-07-27 16:13:29 STK_001 + 2017-01-11 00:00:00 000001.SZ 0 2024-07-27 16:13:29 STK_001 + =================== ========= ======== =================== ========== + """ + path = Path(path) / "strategy" / "weights" + path.mkdir(exist_ok=True, parents=True) + file_cache = path / f"{strategy}.feather" + + if edt is None: + edt = pd.Timestamp.now().strftime("%Y%m%d %H:%M:%S") + + # 判断缓存数据是否能满足需求 + if file_cache.exists(): + df = pd.read_feather(file_cache) + + if df["dt"].max() >= pd.Timestamp(edt): + logger.info(f"【缓存命中】获取策略 {strategy} 的历史持仓权重数据:{sdt} - {edt}") + dfd = df[(df["dt"] >= pd.Timestamp(sdt)) & (df["dt"] <= pd.Timestamp(edt))].copy() + return dfd + + # 缓存未命中或数据不全:触发增量刷新 + logger.info(f"【缓存刷新】获取策略 {strategy} 的历史持仓权重数据:{sdt} - {edt}") + df = __update_strategy_weights(file_cache, strategy, logger=logger) + dfd = df[(df["dt"] >= pd.Timestamp(sdt)) & (df["dt"] <= pd.Timestamp(edt))].copy() + return dfd + + +class StrategyClient: + """CZSC 策略管理 API 客户端。 + + 面向对象封装了一组策略管理相关的 HTTP 接口,相比上面以模块函数提供的能力, + 本类更适用于: + - 需要在多个接口间共享同一个 token 与 ``requests.Session``; + - 需要在同一进程内频繁地切换或更新 token; + - 希望统一的错误处理与日志格式。 + + 主要功能: + - 策略元数据:增、删、改、查; + - 策略权重:查询、上传、删除; + - 缓存清理:按 token 或角色清理服务端缓存。 + """ + + def __init__(self, base_url: str, token: str = None, logger=loguru.logger): + """初始化客户端。 + + :param base_url: str, API 基础 URL,例如 ``http://zbczsc.com:9106`` + :param token: str, 可选, 访问令牌;可后续通过 ``set_token`` 设置 + :param logger: loguru.logger, 可选, 日志记录器 + """ + # 去除末尾斜杠,便于和 endpoint 拼接 + self.base_url = base_url.rstrip("/") + self.token = token + # 复用 Session 以利用底层 HTTP keep-alive,提升批量请求性能 + self.session = requests.Session() + self._setup_headers() + self.logger = logger + + def _setup_headers(self): + """根据当前 token 设置统一的请求头(内部方法)。 + + 始终设置 ``Content-Type`` 与 ``Accept`` 为 JSON; + 当 token 存在时附加 Bearer 鉴权头。 + """ + self.session.headers.update({"Content-Type": "application/json", "Accept": "application/json"}) + if self.token: + self.session.headers["Authorization"] = f"Bearer {self.token}" + + def set_token(self, token: str): + """更新当前客户端使用的访问令牌。 + + :param token: str, 新的访问令牌 + """ + self.token = token + self._setup_headers() + self.logger.info("访问令牌已更新") + + def _make_request(self, method: str, endpoint: str, data: Dict = None) -> Dict: + """统一的 HTTP 请求底层方法(内部方法)。 + + 负责拼接 URL、根据 method 选择 GET/POST 调用、统一异常处理与日志输出。 + + :param method: str, HTTP 方法,"GET" 或 "POST",大小写不敏感 + :param endpoint: str, API 端点路径,需以 "/" 开头 + :param data: dict, 可选, 请求数据;GET 时作为 query string,POST 时作为 JSON body + :return: dict, 服务端返回的 JSON 数据 + :raises requests.exceptions.RequestException: 网络异常或状态码非 2xx 时抛出 + :raises ValueError: 响应不是合法的 JSON 时抛出 + """ + url = f"{self.base_url}{endpoint}" + + try: + if method.upper() == "GET": + response = self.session.get(url, params=data) + else: + response = self.session.post(url, json=data) + + response.raise_for_status() + result = response.json() + + self.logger.debug(f"API请求成功: {method} {endpoint}") + return result + + except requests.exceptions.RequestException as e: + self.logger.error(f"API请求失败: {method} {endpoint}, 错误: {e}") + raise + except ValueError as e: + self.logger.error(f"响应解析失败: {e}") + raise + + def get_all_strategy_metadata(self) -> List[Dict]: + """获取所有策略元数据。 + + :return: list[dict], 策略元数据列表;接口失败时返回空列表 + """ + data = {"token": self.token} + result = self._make_request("POST", "/get_all_strategy_metadata", data) + + # 接口约定 code=0 为成功 + if result.get("code") == 0: + self.logger.info(f"成功获取{len(result.get('data', []))}个策略元数据") + return result.get("data", []) + else: + self.logger.error(f"获取策略元数据失败: {result.get('msg', '未知错误')}") + return [] + + def add_strategy_meta( + self, + strategy_name: str, + base_freq: str, + description: str, + author_id: int, + outsample_sdt: str, + weight_type: str, + memo: str = "", + ) -> bool: + """添加策略元数据。 + + :param strategy_name: str, 策略名称(全局唯一) + :param base_freq: str, 基础频率,例如 "1分钟"、"日线" + :param description: str, 策略描述 + :param author_id: int, 作者 ID + :param outsample_sdt: str, 样本外开始日期,格式 YYYYMMDD + :param weight_type: str, 权重类型 + :param memo: str, 备注信息,默认空字符串 + :return: bool, 是否添加成功 + """ + data = { + "token": self.token, + "strategy_name": strategy_name, + "meta": { + "base_freq": base_freq, + "description": description, + "author_id": author_id, + "outsample_sdt": outsample_sdt, + "weight_type": weight_type, + "memo": memo, + }, + } + + result = self._make_request("POST", "/add_strategy_meta", data) + + # 此接口约定 code=200 为成功(与 get_all_strategy_metadata 不同) + if result.get("code") == 200: + self.logger.info(f"成功添加策略元数据: {strategy_name}") + return True + else: + self.logger.error(f"添加策略元数据失败: {result.get('msg', '未知错误')}") + return False + + def update_strategy_meta( + self, + strategy_name: str, + base_freq: str = None, + description: str = None, + author_id: int = None, + outsample_sdt: str = None, + weight_type: str = None, + memo: str = None, + ) -> bool: + """更新策略元数据(仅更新非 None 的字段)。 + + :param strategy_name: str, 策略名称(用于定位需要更新的策略) + :param base_freq: str, 可选, 基础频率 + :param description: str, 可选, 策略描述 + :param author_id: int, 可选, 作者 ID(仅管理员可更改) + :param outsample_sdt: str, 可选, 样本外开始日期 + :param weight_type: str, 可选, 权重类型 + :param memo: str, 可选, 备注信息 + :return: bool, 是否更新成功 + """ + meta = {} + # 只把非 None 的字段加入更新载荷,避免误覆盖 + for key, value in [ + ("base_freq", base_freq), + ("description", description), + ("author_id", author_id), + ("outsample_sdt", outsample_sdt), + ("weight_type", weight_type), + ("memo", memo), + ]: + if value is not None: + meta[key] = value + + data = {"token": self.token, "strategy_name": strategy_name, "meta": meta} + + result = self._make_request("POST", "/update_strategy_meta", data) + + if result.get("code") == 200: + self.logger.info(f"成功更新策略元数据: {strategy_name}") + return True + else: + self.logger.error(f"更新策略元数据失败: {result.get('msg', '未知错误')}") + return False + + def delete_strategy_meta(self, strategy_name: str) -> bool: + """删除策略元数据(软删除,权重数据保留)。 + + :param strategy_name: str, 策略名称 + :return: bool, 是否删除成功 + """ + data = {"token": self.token, "strategy_name": strategy_name, "meta": {}} + + result = self._make_request("POST", "/delete_strategy_meta", data) + + if result.get("code") == 200: + self.logger.info(f"成功删除策略元数据: {strategy_name}") + return True + else: + self.logger.error(f"删除策略元数据失败: {result.get('msg', '未知错误')}") + return False + + def get_all_strategy_latest_weights(self) -> List[Dict]: + """获取所有策略的最新持仓权重快照。 + + :return: list[dict], 策略权重数据列表;接口失败时返回空列表 + """ + data = {"token": self.token} + result = self._make_request("POST", "/get_all_strategy_latest_weights", data) + + if result.get("code") == 0: + self.logger.info(f"成功获取{len(result.get('data', []))}条最新权重数据") + return result.get("data", []) + else: + self.logger.error(f"获取最新权重数据失败: {result.get('msg', '未知错误')}") + return [] + + def query_strategy_weight(self, strategy: str, sdt: str = "", edt: str = "", symbols: List[str] = None) -> Dict: + """查询单个策略的持仓权重。 + + :param strategy: str, 策略名称 + :param sdt: str, 可选, 开始日期 + :param edt: str, 可选, 结束日期 + :param symbols: list[str], 可选, 限定的标的代码列表;不传则查询全部 + :return: dict, 包含 meta(元数据)和 weights(权重列表)的字典;失败时返回空字典 + """ + data = {"token": self.token, "strategy": strategy, "sdt": sdt, "edt": edt, "symbols": symbols or []} + + result = self._make_request("POST", "/query_strategy_weight", data) + + if result.get("code") == 0: + data_result = result.get("data", {}) + weights_count = len(data_result.get("weights", [])) + self.logger.info(f"成功查询策略 {strategy} 的权重数据,共{weights_count}条记录") + return data_result + else: + self.logger.error(f"查询策略权重失败: {result.get('msg', '未知错误')}") + return {} + + def delete_strategy(self, strategy: str) -> bool: + """彻底删除策略(同时清除持仓权重和元数据,不可恢复)。 + + :param strategy: str, 策略名称 + :return: bool, 是否删除成功 + """ + data = {"token": self.token, "strategy": strategy} + + result = self._make_request("POST", "/delete_strategy", data) + + if result.get("code") == 200: + self.logger.info(f"成功删除策略: {strategy}") + return True + else: + self.logger.error(f"删除策略失败: {result.get('msg', '未知错误')}") + return False + + def clear_cache(self, tokens: List[str] = None, roles: List[int] = None) -> bool: + """清除服务端的接口缓存。 + + :param tokens: list[str], 可选, 需要清除的 token 列表 + :param roles: list[int], 可选, 需要清除的角色 ID 列表 + :return: bool, 是否清除成功 + """ + data = {"tokens": tokens or [], "roles": roles or []} + + result = self._make_request("POST", "/clear_cache", data) + + # 兼容服务端未返回 code 的情况:默认认为成功 + if result.get("code", 200) == 200: + self.logger.info("成功清除缓存") + return True + else: + self.logger.error("清除缓存失败") + return False + + def upload_strategy_weights( + self, + df: Any, + strategy_name: str, + description: str, + base_freq: str, + author: str, + outsample_sdt: str, + upload_token: str = None, + ) -> Dict: + """上传策略权重数据。 + + 与模块级 ``upload_strategy`` 函数功能一致,区别在于参数以独立形参的形式暴露, + 且使用 ``self.session`` 共享连接。 + + :param df: pd.DataFrame, 策略权重数据,必须包含 dt, symbol, weight 三列 + :param strategy_name: str, 策略名称 + :param description: str, 策略描述 + :param base_freq: str, 基础频率 + :param author: str, 作者 + :param outsample_sdt: str, 样本外开始日期 + :param upload_token: str, 可选, 上传凭证码;不提供则从环境变量 CZSC_TOKEN 读取 + :return: dict, 上传接口返回的结果 + :raises requests.exceptions.RequestException: 网络异常时抛出 + """ + import pandas as pd + import os + + # 数据预处理:拷贝以避免污染外部数据 + df_copy = df.copy() + df_copy["dt"] = pd.to_datetime(df_copy["dt"]) + + self.logger.info(f"输入数据中有 {len(df_copy)} 条权重信号") + + # 去除单个品种下相邻时间权重相同的数据,减少冗余 + _res = [] + for _, dfg in df_copy.groupby("symbol"): + dfg = dfg.sort_values("dt", ascending=True).reset_index(drop=True) + dfg = dfg[dfg["weight"].diff().fillna(1) != 0].copy() + _res.append(dfg) + + df_processed = pd.concat(_res, ignore_index=True) + df_processed = df_processed.sort_values(["dt"]).reset_index(drop=True) + df_processed["dt"] = df_processed["dt"].dt.strftime("%Y-%m-%d %H:%M:%S") + + self.logger.info(f"去除单个品种下相邻时间权重相同的数据后,剩余 {len(df_processed)} 条权重信号") + + # 构造元数据 + meta = { + "name": strategy_name, + "description": description, + "base_freq": base_freq, + "author": author, + "outsample_sdt": outsample_sdt, + } + + # 构造上传数据:weights 字段使用 split 方向 JSON 减小体积 + data = { + "weights": df_processed[["dt", "symbol", "weight"]].to_json(orient="split"), + "token": upload_token or os.getenv("CZSC_TOKEN"), + "strategy_name": strategy_name, + "meta": meta, + } + + # 使用专门的上传接口(与 base_url 不同),直接拼写完整地址 + upload_url = "http://zbczsc.com:9106/upload_strategy" + + try: + response = self.session.post(upload_url, json=data) + response.raise_for_status() + result = response.json() + + self.logger.info(f"成功上传策略权重: {strategy_name}") + self.logger.debug(f"上传接口返回: {result}") + return result + + except requests.exceptions.RequestException as e: + self.logger.error(f"上传策略权重失败: {strategy_name}, 错误: {e}") + raise diff --git a/czsc/connectors/jq_connector.py b/czsc/connectors/jq_connector.py new file mode 100644 index 000000000..86425a462 --- /dev/null +++ b/czsc/connectors/jq_connector.py @@ -0,0 +1,716 @@ +# coding: utf-8 +""" +模块说明: + 聚宽(JQData)HTTP 数据接口的轻量封装。 + + 本模块通过聚宽提供的 HTTP API(``https://dataapi.joinquant.com/apis``)获取行情、 + 财务、概念板块、指数成份等数据,并按照 CZSC 的标准数据结构返回,便于上层策略 + 直接消费。 + + 职责: + - 凭证管理:``set_token`` / ``get_token`` 负责持久化和刷新调用凭证; + - 基础信息:股票/基金/指数/期货等基础信息表,行业归属,成份股; + - 行情数据:分钟和日线 K 线获取(``get_kline`` / ``get_kline_period``), + 自动转换为 ``czsc.RawBar`` 列表; + - 实时回测:``get_init_bg`` 提供 BarGenerator 的初始化能力; + - 财务数据:``get_fundamental`` / ``run_query`` 拉取财务因子; + - 综合 F10:``get_share_basic`` 一键获取个股基础面信息汇总。 + + 使用场景: + 当用户拥有有效的聚宽 JQData 账号且偏好通过 HTTP 接口拉取数据时使用本模块; + 若已经安装 ``jqdatasdk`` Python 包,可直接使用其 SDK,本模块作为 HTTP 备选方案。 + + 注意事项: + - 调用前请先通过 ``set_token(jq_mob, jq_pwd)`` 持久化登录凭证(默认存放在 + ``~/jq.token``),否则 ``get_token`` 会抛出 ``ValueError``; + - 聚宽免费账户每日的查询条数有限(``get_query_count``),请合理使用; + - 单次 K 线请求最大 5000 条,超过会触发 warning 并被服务端截断; + - 区间查询超过 1000 个交易日时同样可能失败,需要自行分段。 +""" +import os +import pickle +import json +import requests +import warnings +from collections import OrderedDict +import pandas as pd +from datetime import datetime, timedelta +from typing import List +from urllib.parse import quote + +from czsc import RawBar, Freq, BarGenerator, freq_end_time + +# CZSC 内部周期字符串到聚宽 unit 字符串的映射 +freq_cn2jq = { + "1分钟": "1m", + "5分钟": "5m", + "15分钟": "15m", + "30分钟": "30m", + "60分钟": "60m", + "日线": "1d", + "周线": "1w", + "月线": "1M", +} + +# 聚宽 HTTP API 入口地址 +url = "https://dataapi.joinquant.com/apis" +# 用户主目录路径:用于存放凭证文件 +home_path = os.path.expanduser("~") +# 凭证文件路径:使用 pickle 序列化保存账号密码 +file_token = os.path.join(home_path, "jq.token") + +# 通用日期时间格式 +dt_fmt = "%Y-%m-%d %H:%M:%S" +date_fmt = "%Y-%m-%d" + +# 聚宽支持的 K 线粒度: 1m, 5m, 15m, 30m, 60m, 120m, 1d, 1w, 1M +# 下面的 freq_convert 把 czsc 内部使用的 pandas-style 频率字符串映射到聚宽 unit 字符串 +freq_convert = { + "1min": "1m", + "5min": "5m", + "15min": "15m", + "30min": "30m", + "60min": "60m", + "D": "1d", + "W": "1w", + "M": "1M", +} + +# pandas-style 频率字符串到 czsc.Freq 枚举的映射 +freq_map = { + "1min": Freq.F1, + "5min": Freq.F5, + "15min": Freq.F15, + "30min": Freq.F30, + "60min": Freq.F60, + "D": Freq.D, + "W": Freq.W, + "M": Freq.M, +} + + +def set_token(jq_mob, jq_pwd): + """持久化保存聚宽 JQData 登录凭证。 + + 将账号和密码以 pickle 形式写入用户主目录下的 ``jq.token`` 文件。 + 后续 ``get_token`` 调用会读取该文件并向聚宽换取一次性的访问令牌。 + + :param jq_mob: str, mob 是申请 JQData 时所填写的手机号 + :param jq_pwd: str, Password 为聚宽官网登录密码,新申请用户默认为手机号后 6 位 + :return: None + """ + with open(file_token, "wb") as f: + pickle.dump([jq_mob, jq_pwd], f) + + +def get_token(): + """获取聚宽 JQData 调用凭证 token。 + + 流程: + 1. 从本地 ``~/jq.token`` 加载用户名 / 密码; + 2. 通过 ``get_current_token`` 接口换取一次性的访问令牌; + 3. 直接返回令牌字符串,调用者把它放进后续请求的 ``token`` 字段。 + + :return: str, 聚宽接口访问令牌 + :raises ValueError: 当本地凭证文件不存在时抛出,需要先调用 ``set_token`` + """ + if not os.path.exists(file_token): + raise ValueError(f"{file_token} 文件不存在,请先调用 set_token 进行设置") + + with open(file_token, "rb") as f: + jq_mob, jq_pwd = pickle.load(f) + + body = { + "method": "get_current_token", + "mob": jq_mob, # mob 是申请 JQData 时所填写的手机号 + "pwd": quote(jq_pwd), # Password 为聚宽官网登录密码,新申请用户默认为手机号后 6 位 + } + response = requests.post(url, data=json.dumps(body)) + token = response.text + return token + + +def text2df(text): + """将聚宽接口返回的 CSV 风格文本转换为 ``pd.DataFrame``。 + + 聚宽 HTTP 接口的返回值通常是以 ``\\n`` 分隔的多行文本,第一行是表头, + 其余每行为以逗号分隔的字段值。 + + :param text: str, 接口原始返回文本 + :return: pd.DataFrame, 解析后的表格数据 + """ + rows = [x.split(",") for x in text.strip().split("\n")] + df = pd.DataFrame(rows[1:], columns=rows[0]) + return df + + +def get_query_count() -> int: + """获取当前账号剩余的查询条数。 + + 用于在大批量查询前先判断额度是否充足,避免中途被限流中断。 + + 接口文档: https://dataapi.joinquant.com/docs#get_query_count---%E8%8E%B7%E5%8F%96%E6%9F%A5%E8%AF%A2%E5%89%A9%E4%BD%99%E6%9D%A1%E6%95%B0 + + :return: int, 当前剩余可用的查询条数 + """ + data = { + "method": "get_query_count", + "token": get_token(), + } + r = requests.post(url, data=json.dumps(data)) + return int(r.text) + + +def get_concepts(): + """获取聚宽全部概念板块列表。 + + 接口文档: + https://dataapi.joinquant.com/docs#get_concepts---%E8%8E%B7%E5%8F%96%E6%A6%82%E5%BF%B5%E5%88%97%E8%A1%A8 + + :return: pd.DataFrame, 包含概念代码、名称等字段 + """ + data = { + "method": "get_concepts", + "token": get_token(), + } + r = requests.post(url, data=json.dumps(data)) + df = text2df(r.text) + return df + + +def get_concept_stocks(symbol, date=None): + """获取指定概念在指定日期下的成份股代码列表。 + + 接口文档: + https://dataapi.joinquant.com/docs#get_concept_stocks---%E8%8E%B7%E5%8F%96%E6%A6%82%E5%BF%B5%E6%88%90%E4%BB%BD%E8%82%A1 + + :param symbol: str, 概念代码,如 ``GN036`` + :param date: str | datetime, 日期,如 ``2020-08-08``;为空时使用当前日期 + :return: list[str], 该日期下的成份股代码列表(含交易所后缀) + + 示例: + >>> symbols1 = get_concept_stocks("GN036", date="2020-07-08") + >>> symbols2 = get_concept_stocks("GN036", date=datetime.now()) + """ + if not date: + date = str(datetime.now().date()) + else: + date = pd.to_datetime(date) + + if isinstance(date, datetime): + date = str(date.date()) + + data = {"method": "get_concept_stocks", "token": get_token(), "code": symbol, "date": date} + r = requests.post(url, data=json.dumps(data)) + return r.text.split("\n") + + +def get_index_stocks(symbol, date=None): + """获取指定指数在指定日期的成份股代码列表。 + + 接口文档: + https://dataapi.joinquant.com/docs#get_index_stocks---%E8%8E%B7%E5%8F%96%E6%8C%87%E6%95%B0%E6%88%90%E4%BB%BD%E8%82%A1 + + :param symbol: str, 指数代码,如 ``000300.XSHG`` + :param date: str | datetime, 日期,如 ``2020-08-08``;为空时使用当前日期 + :return: list[str], 指定日期下的成份股代码列表 + + 示例: + >>> symbols1 = get_index_stocks("000300.XSHG", date="2020-07-08") + >>> symbols2 = get_index_stocks("000300.XSHG", date=datetime.now()) + """ + if not date: + date = str(datetime.now().date()) + + if isinstance(date, datetime): + date = str(date.date()) + + data = {"method": "get_index_stocks", "token": get_token(), "code": symbol, "date": date} + r = requests.post(url, data=json.dumps(data)) + return r.text.split("\n") + + +def get_industry(symbol): + """查询股票所属的行业归属信息。 + + 一次性返回证监会、聚宽(一级、二级)、申万(一级、二级、三级)共三套行业分类, + 便于上层根据需要选择对应分类体系做横截面分析。 + + 接口文档: + https://www.joinquant.com/help/api/help#JQDataHttp:get_industry-%E6%9F%A5%E8%AF%A2%E8%82%A1%E7%A5%A8%E6%89%80%E5%B1%9E%E8%A1%8C%E4%B8%9A + + :param symbol: str, 股票代码,含交易所后缀,例如 ``000001.XSHE`` + :return: dict, 包含股票代码、各行业分类的代码和名称 + """ + data = {"method": "get_industry", "token": get_token(), "code": symbol, "date": str(datetime.now().date())} + r = requests.post(url, data=json.dumps(data)) + df = text2df(r.text) + # 把不同分类体系(zjw/jq_l1/jq_l2/sw_l1/sw_l2/sw_l3)的行业代码与名称分别提取出来 + res = { + "股票代码": symbol, + "证监会行业代码": df[df["industry"] == "zjw"]["industry_code"].iloc[0], + "证监会行业名称": df[df["industry"] == "zjw"]["industry_name"].iloc[0], + "聚宽一级行业代码": df[df["industry"] == "jq_l1"]["industry_code"].iloc[0], + "聚宽一级行业名称": df[df["industry"] == "jq_l1"]["industry_name"].iloc[0], + "聚宽二级行业代码": df[df["industry"] == "jq_l2"]["industry_code"].iloc[0], + "聚宽二级行业名称": df[df["industry"] == "jq_l2"]["industry_name"].iloc[0], + "申万一级行业代码": df[df["industry"] == "sw_l1"]["industry_code"].iloc[0], + "申万一级行业名称": df[df["industry"] == "sw_l1"]["industry_name"].iloc[0], + "申万二级行业代码": df[df["industry"] == "sw_l2"]["industry_code"].iloc[0], + "申万二级行业名称": df[df["industry"] == "sw_l2"]["industry_name"].iloc[0], + "申万三级行业代码": df[df["industry"] == "sw_l3"]["industry_code"].iloc[0], + "申万三级行业名称": df[df["industry"] == "sw_l3"]["industry_name"].iloc[0], + } + return res + + +def get_all_securities(code, date=None) -> pd.DataFrame: + """获取平台支持的所有标的基础信息。 + + 接口文档: + https://dataapi.joinquant.com/docs#get_all_securities---%E8%8E%B7%E5%8F%96%E6%89%80%E6%9C%89%E6%A0%87%E7%9A%84%E4%BF%A1%E6%81%AF + + :param code: str, 证券类型,可选值:stock, fund, index, futures, etf, lof, fja, fjb, + QDII_fund, open_fund, bond_fund, stock_fund, money_market_fund, mixture_fund, options + :param date: str | datetime, 日期,用于获取某日期还在上市的证券信息;为空时表示获取所有日期的标的信息 + :return: pd.DataFrame, 标的基础信息表 + """ + if not date: + date = str(datetime.now().date()) + + if isinstance(date, datetime): + date = str(date.date()) + + data = {"method": "get_all_securities", "token": get_token(), "code": code, "date": date} + r = requests.post(url, data=json.dumps(data)) + return text2df(r.text) + + +def get_kline( + symbol: str, end_date: [datetime, str], freq: str, start_date: [datetime, str] = None, count=None, fq: bool = True +) -> List[RawBar]: + """获取 K 线数据并转换为 ``RawBar`` 列表。 + + 支持两种调用模式: + - 指定 ``start_date`` + ``end_date``:调用 ``get_price_period`` 获取区间数据; + - 指定 ``count`` + ``end_date``:调用 ``get_price`` 倒推获取最近 N 根。 + 两者必须二选一。 + + 接口文档: + https://www.joinquant.com/help/api/help#JQDataHttp:get_priceget_bars-%E8%8E%B7%E5%8F%96%E6%8C%87%E5%AE%9A%E6%97%B6%E9%97%B4%E5%91%A8%E6%9C%9F%E7%9A%84%E8%A1%8C%E6%83%85%E6%95%B0%E6%8D%AE + + :param symbol: str, 聚宽标的代码,例如 ``000001.XSHG`` + :param end_date: str | datetime, 截止日期 + :param freq: str, K 线级别,可选值 ``['1min', '5min', '30min', '60min', 'D', 'W', 'M']`` + :param start_date: str | datetime, 可选, 开始日期 + :param count: int, 可选, 从 end_date 倒推的 K 线数量,最大 5000 + :param fq: bool, 是否进行复权,True 时使用 end_date 作为复权基准 + :return: list[RawBar], 标准化后的 K 线对象列表(按时间升序) + :raises ValueError: ``start_date`` 和 ``count`` 同时为空时抛出 + + 示例: + >>> start_date = datetime.strptime("20200701", "%Y%m%d") + >>> end_date = datetime.strptime("20200719", "%Y%m%d") + >>> df1 = get_kline(symbol="000001.XSHG", start_date=start_date, end_date=end_date, freq="1min") + >>> df2 = get_kline(symbol="000001.XSHG", end_date=end_date, freq="1min", count=1000) + >>> df3 = get_kline(symbol="000001.XSHG", start_date='20200701', end_date='20200719', freq="1min", fq=True) + >>> df4 = get_kline(symbol="000001.XSHG", end_date='20200719', freq="1min", count=1000) + """ + if count and count > 5000: + warnings.warn(f"count={count}, 超过5000的最大值限制,仅返回最后5000条记录") + + end_date = pd.to_datetime(end_date) + + # 根据是否提供 start_date 选择不同的接口:区间查询 vs 倒数 N 根 + if start_date: + start_date = pd.to_datetime(start_date) + data = { + "method": "get_price_period", + "token": get_token(), + "code": symbol, + "unit": freq_convert[freq], + "date": start_date.strftime("%Y-%m-%d"), + "end_date": end_date.strftime("%Y-%m-%d"), + } + elif count: + data = { + "method": "get_price", + "token": get_token(), + "code": symbol, + "count": count, + "unit": freq_convert[freq], + "end_date": end_date.strftime("%Y-%m-%d"), + } + else: + raise ValueError("start_date 和 count 不能同时为空") + + if fq: + # 指定复权基准日期,与 end_date 保持一致 + data.update({"fq_ref_date": end_date.strftime("%Y-%m-%d")}) + + r = requests.post(url, data=json.dumps(data)) + # 接口返回 CSV 文本,跳过第一行表头,逐行解析 + rows = [x.split(",") for x in r.text.strip().split("\n")][1:] + + bars = [] + i = -1 + for row in rows: + # 字段顺序:['date', 'open', 'close', 'high', 'low', 'volume', 'money'] + dt = pd.to_datetime(row[0]) + if freq == "D": + # 日线统一规整为 00:00:00,避免出现非零的小时/分钟字段 + dt = dt.replace(hour=0, minute=0, second=0, microsecond=0) + + # 仅保留有成交量的 K 线,跳过停牌或集合竞价残留的空行 + if int(row[5]) > 0: + i += 1 + bars.append( + RawBar( + symbol=symbol, + dt=dt, + id=i, + freq=freq_map[freq], + open=round(float(row[1]), 4), + close=round(float(row[2]), 4), + high=round(float(row[3]), 4), + low=round(float(row[4]), 4), + vol=int(row[5]), + amount=int(float(row[6])), + ) + ) + # amount 单位:元 + if start_date: + # 双重保险:再次按 start_date 过滤,避免接口返回的边界数据 + bars = [x for x in bars if x.dt >= start_date] + if "min" in freq: + # 分钟线最后一根使用区间结束时间对齐,便于与 czsc 内部时间约定保持一致 + bars[-1].dt = freq_end_time(bars[-1].dt, freq=freq_map[freq]) + bars = [x for x in bars if x.dt <= end_date] + return bars + + +def get_kline_period( + symbol: str, start_date: [datetime, str], end_date: [datetime, str], freq: str, fq=True +) -> List[RawBar]: + """获取指定时间段的行情数据(仅区间模式,固定使用 ``get_price_period``)。 + + 与 ``get_kline`` 的区别:本函数强制使用 start_date + end_date,对超长区间会发出告警。 + + 接口文档: + https://www.joinquant.com/help/api/help#JQDataHttp:get_price_periodget_bars_period-%E8%8E%B7%E5%8F%96%E6%8C%87%E5%AE%9A%E6%97%B6%E9%97%B4%E6%AE%B5%E7%9A%84%E8%A1%8C%E6%83%85%E6%95%B0%E6%8D%AE + + :param symbol: str, 聚宽标的代码 + :param start_date: str | datetime, 开始日期 + :param end_date: str | datetime, 截止日期 + :param freq: str, K 线级别,可选值 ``['1min', '5min', '30min', '60min', 'D', 'W', 'M']`` + :param fq: bool, 是否进行复权 + :return: list[RawBar], 标准化后的 K 线对象列表 + """ + start_date = pd.to_datetime(start_date) + end_date = pd.to_datetime(end_date) + + # 粗略估算:(自然日 * 5/7) 近似得到交易日数;超 1000 个交易日可能触发服务端限制 + if (end_date - start_date).days * 5 / 7 > 1000: + warnings.warn(f"{end_date.date()} - {start_date.date()} 超过1000个交易日,K线获取可能失败,返回为0") + + data = { + "method": "get_price_period", + "token": get_token(), + "code": symbol, + "unit": freq_convert[freq], + "date": start_date.strftime("%Y-%m-%d"), + "end_date": end_date.strftime("%Y-%m-%d"), + } + if fq: + data.update({"fq_ref_date": end_date.strftime("%Y-%m-%d")}) + + r = requests.post(url, data=json.dumps(data)) + rows = [x.split(",") for x in r.text.strip().split("\n")][1:] + bars = [] + i = -1 + for row in rows: + # 字段顺序:['date', 'open', 'close', 'high', 'low', 'volume', 'money'] + dt = pd.to_datetime(row[0]) + if freq == "D": + dt = dt.replace(hour=0, minute=0, second=0, microsecond=0) + + # 跳过无成交量的行 + if int(row[5]) > 0: + i += 1 + bars.append( + RawBar( + symbol=symbol, + dt=dt, + id=i, + freq=freq_map[freq], + open=round(float(row[1]), 4), + close=round(float(row[2]), 4), + high=round(float(row[3]), 4), + low=round(float(row[4]), 4), + vol=int(row[5]), + amount=int(float(row[6])), + ) + ) + # amount 单位:元 + if start_date: + bars = [x for x in bars if x.dt >= start_date] + if "min" in freq and bars: + # 分钟线对齐到周期结束时间 + bars[-1].dt = freq_end_time(bars[-1].dt, freq=freq_map[freq]) + bars = [x for x in bars if x.dt <= end_date] + return bars + + +def get_init_bg(symbol: str, end_dt: [str, datetime], base_freq: str, freqs: List[str], max_count=1000, fq=True): + """获取指定标的的初始化 BarGenerator 以及待重放数据。 + + 用于实时回放/回测的启动阶段: + 1. 以 ``end_dt - 180 天`` 为分界点,先拉取该时点之前的 K 线,初始化 BarGenerator; + 2. 再拉取从分界点到 ``end_dt`` 的基础周期 K 线,作为后续逐根 update 的回放数据。 + + :param symbol: str, 聚宽标的代码 + :param end_dt: str | datetime, 回放的截止时间 + :param base_freq: str, 基础周期(CZSC 中文表达,如 "1分钟") + :param freqs: list[str], 需要联立的更高级别周期列表 + :param max_count: int, 各级别 K 线初始化时的最大根数,默认 1000 + :param fq: bool, 是否复权,默认 True(前复权) + :return: tuple, (bg, data) + - bg: BarGenerator, 已初始化好的 BarGenerator 实例 + - data: list[RawBar], 待逐根 update 的基础周期数据 + """ + if isinstance(end_dt, str): + end_dt = pd.to_datetime(end_dt, utc=False) + + # 以 180 天为初始化窗口,分界点设在当天的 16:00(A 股收盘后) + delta_days = 180 + last_day = (end_dt - timedelta(days=delta_days)).replace(hour=16, minute=0) + + bg = BarGenerator(base_freq, freqs, max_count) + # 对 BarGenerator 中维护的每一个频率,分别拉取并初始化 + for freq in bg.bars.keys(): + bars_ = get_kline(symbol=symbol, end_date=last_day, freq=freq_cn2jq[freq], count=max_count, fq=fq) + bg.init_freq_bars(freq, bars_) + print(f"{symbol} - {freq} - {len(bg.bars[freq])} - last_dt: {bg.bars[freq][-1].dt} - last_day: {last_day}") + + # 准备分界点之后的基础周期数据,供后续 bg.update 逐根回放 + bars2 = get_kline_period(symbol, last_day, end_dt, freq=freq_cn2jq[base_freq], fq=fq) + data = [x for x in bars2 if x.dt > last_day] + assert len(data) > 0 + print( + f"{symbol}: bar generator 最新时间 {bg.bars[base_freq][-1].dt.strftime(dt_fmt)},还有{len(data)}行数据需要update" + ) + return bg, data + + +def get_fundamental(table: str, symbol: str, date: str, columns: str = "") -> dict: + """获取单个标的、指定日期的财务基础数据。 + + 接口文档: + https://dataapi.joinquant.com/docs#get_fundamentals---%E8%8E%B7%E5%8F%96%E5%9F%BA%E6%9C%AC%E8%B4%A2%E5%8A%A1%E6%95%B0%E6%8D%AE + + 财务数据列表: + https://www.joinquant.com/help/api/help?name=Stock#%E8%B4%A2%E5%8A%A1%E6%95%B0%E6%8D%AE%E5%88%97%E8%A1%A8 + + :param table: str, 财务数据表名,如 ``indicator``、``valuation`` + :param symbol: str, 股票代码 + :param date: str, 查询日期,可以是: + - 具体日期 ``2019-03-04``; + - 年度 ``2018``; + - 季度 ``2018q1`` / ``2018q2`` / ``2018q3`` / ``2018q4`` + :param columns: str, 可选, 需要查询的字段列表,逗号分隔;为空则查询全部字段 + :return: dict, 单条记录的字典;查询失败或为空时返回 ``{}`` + + 示例: + >>> x1 = get_fundamental(table="indicator", symbol="300803.XSHE", date="2020-11-12") + >>> x2 = get_fundamental(table="indicator", symbol="300803.XSHE", date="2020") + >>> x3 = get_fundamental(table="indicator", symbol="300803.XSHE", date="2020q3") + """ + data = { + "method": "get_fundamentals", + "token": get_token(), + "table": table, + "columns": columns, + "code": symbol, + "date": date, + "count": 1, + } + r = requests.post(url, data=json.dumps(data)) + df = text2df(r.text) + try: + return df.iloc[0].to_dict() + except: + # 兼容数据为空、列缺失等多种异常,统一返回空字典 + return {} + + +def run_query(table: str, conditions: str, columns=None, count=1): + """模拟 JQDataSDK 的 run_query 方法,按条件查询数据库表。 + + 接口文档: + https://www.joinquant.com/help/api/help#JQDataHttp:run_query-%E6%A8%A1%E6%8B%9FJQDataSDK%E7%9A%84run_query%E6%96%B9%E6%B3%95 + + :param table: str, 要查询的数据库和表名,格式为 ``database.tablename``,如 ``finance.STK_XR_XD`` + :param conditions: str, 查询条件,可以为空。 + 格式:``column # 判断符 # value``,多个条件用 ``&`` 分隔表示 AND, + 例如:``report_date#>=#2006-12-01&report_date#<=#2006-12-31``。 + 注意条件字符串内不能包含空格等特殊字符。 + :param columns: str, 可选, 所查字段,多个字段用 ``,`` 分隔;为空则查询所有字段。 + 同样不能包含空格。 + :param count: int, 返回的最大记录数,默认 1 + :return: pd.DataFrame, 查询结果 + """ + data = {"method": "run_query", "token": get_token(), "table": table, "conditions": conditions, "count": count} + if columns: + data["columns"] = columns + r = requests.post(url, data=json.dumps(data)) + df = text2df(r.text) + return df + + +def get_share_basic(symbol): + """获取单个标的的基本面汇总数据(一站式 F10 信息)。 + + 本函数会聚合公司基础信息、估值(PE/PB)、市值、近 4 年关键财务指标等, + 返回一个有序字典,并附带可直接推送的中文摘要文本(``msg`` 字段)。 + + :param symbol: str, 股票代码(含交易所后缀),例如 ``000001.XSHE`` + :return: collections.OrderedDict, 基础面汇总信息 + """ + # 公司基础信息:股票名称、所属行业、地域、主营业务等 + basic_info = run_query(table="finance.STK_COMPANY_INFO", conditions="code#=#{}".format(symbol), count=1) + basic_info = basic_info.iloc[0].to_dict() + + f10 = OrderedDict() + f10["股票代码"] = basic_info["code"] + f10["股票名称"] = basic_info["short_name"] + f10["行业"] = "{}-{}".format(basic_info["industry_1"], basic_info["industry_2"]) + f10["地域"] = "{}{}".format(basic_info["province"], basic_info["city"]) + f10["主营"] = basic_info["main_business"] + f10["同花顺F10"] = "http://basic.10jqka.com.cn/{}".format(basic_info["code"][:6]) + + # 市盈率、总市值、流通市值、流通比 + # ------------------------------------------------------------------------------------------------------------------ + # 用昨日数据避免今日盘中估值跳变;valuation 表中的市值单位为亿元 + last_date = datetime.now() - timedelta(days=1) + res = get_fundamental(table="valuation", symbol=symbol, date=last_date.strftime("%Y-%m-%d")) + f10["总市值(亿)"] = float(res["market_cap"]) + f10["流通市值(亿)"] = float(res["circulating_market_cap"]) + f10["流通比(%)"] = round(float(res["circulating_market_cap"]) / float(res["market_cap"]) * 100, 2) + f10["PE_TTM"] = float(res["pe_ratio"]) + f10["PE"] = float(res["pe_ratio_lyr"]) + f10["PB"] = float(res["pb_ratio"]) + + # 近 4 年财务指标:净资产收益率、利润率、增长率、现金流等 + # ------------------------------------------------------------------------------------------------------------------ + for year in ["2017", "2018", "2019", "2020"]: + indicator = get_fundamental(table="indicator", symbol=symbol, date=year) + # indicator.get(key, 0) 可能返回 None 或空字符串,因此再做一次真值判断后再转 float + f10["{}EPS".format(year)] = float(indicator.get("eps", 0)) if indicator.get("eps", 0) else 0 + f10["{}ROA".format(year)] = float(indicator.get("roa", 0)) if indicator.get("roa", 0) else 0 + f10["{}ROE".format(year)] = float(indicator.get("roe", 0)) if indicator.get("roe", 0) else 0 + f10["{}销售净利率(%)".format(year)] = ( + float(indicator.get("net_profit_margin", 0)) if indicator.get("net_profit_margin", 0) else 0 + ) + f10["{}销售毛利率(%)".format(year)] = ( + float(indicator.get("gross_profit_margin", 0)) if indicator.get("gross_profit_margin", 0) else 0 + ) + f10["{}营业收入同比增长率(%)".format(year)] = ( + float(indicator.get("inc_revenue_year_on_year", 0)) if indicator.get("inc_revenue_year_on_year", 0) else 0 + ) + f10["{}营业收入环比增长率(%)".format(year)] = ( + float(indicator.get("inc_revenue_annual", 0)) if indicator.get("inc_revenue_annual", 0) else 0 + ) + f10["{}营业利润同比增长率(%)".format(year)] = ( + float(indicator.get("inc_operation_profit_year_on_year", 0)) + if indicator.get("inc_operation_profit_year_on_year", 0) + else 0 + ) + f10["{}经营活动产生的现金流量净额/营业收入(%)".format(year)] = ( + float(indicator.get("ocf_to_revenue", 0)) if indicator.get("ocf_to_revenue", 0) else 0 + ) + + # 组合成可以用来推送的文本摘要 + msg = "{}({})@{}\n".format(f10["股票代码"], f10["股票名称"], f10["地域"]) + msg += "\n{}\n".format("*" * 30) + for k in ["行业", "主营", "PE_TTM", "PE", "PB", "总市值(亿)", "流通市值(亿)", "流通比(%)", "同花顺F10"]: + msg += "{}:{}\n".format(k, f10[k]) + + msg += "\n{}\n".format("*" * 30) + cols = [ + "EPS", + "ROA", + "ROE", + "销售净利率(%)", + "销售毛利率(%)", + "营业收入同比增长率(%)", + "营业利润同比增长率(%)", + "经营活动产生的现金流量净额/营业收入(%)", + ] + msg += "2017~2020 财务变化\n\n" + for k in cols: + # 把 4 年同一指标横向拼接,便于一眼看出趋势 + msg += k + ":{} | {} | {} | {}\n".format( + *[f10["{}{}".format(year, k)] for year in ["2017", "2018", "2019", "2020"]] + ) + + f10["msg"] = msg + return f10 + + +def get_symbols(name="ALL", **kwargs): + """获取指定分组下的所有标的代码(聚宽数据源)。 + + :param name: str, 分组名称,可选值: + - ``ALL``:表示 stock + index + futures + etf 的并集; + - 也可直接传入聚宽支持的具体类型:stock, fund, index, futures, etf, lof, + fja, fjb, QDII_fund, open_fund, bond_fund, stock_fund, + money_market_fund, mixture_fund, options + :param kwargs: dict, 其他参数(保留以扩展,当前未使用) + :return: list[str], 该分组下的所有标的代码列表 + """ + if name.upper() == "ALL": + # ALL 分支:聚合 stock + index + futures + etf 四类标的 + codes = ( + get_all_securities("stock", date=None)["code"].unique().tolist() + + get_all_securities("index", date=None)["code"].unique().tolist() + + get_all_securities("futures", date=None)["code"].unique().tolist() + + get_all_securities("etf", date=None)["code"].unique().tolist() + ) + else: + codes = get_all_securities(name, date=None)["code"].unique().tolist() + return codes + + +def get_raw_bars(symbol, freq, sdt, edt, fq="前复权", **kwargs): + """获取 CZSC 库定义的标准 RawBar 对象列表(聚宽数据源统一入口)。 + + 本函数将 CZSC 内部使用的中文周期字符串映射为聚宽的 pandas-style 频率字符串, + 并代理调用 ``get_kline``。复权方向通过 ``fq`` 参数指定。 + + :param symbol: str, 标的代码 + :param freq: str | czsc.Freq, 周期,支持 Freq 对象,或者字符串: + ``'1分钟'``、``'5分钟'``、``'15分钟'``、``'30分钟'``、``'60分钟'``、 + ``'日线'``、``'周线'``、``'月线'``、``'季线'``、``'年线'`` + :param sdt: str | datetime, 开始时间 + :param edt: str | datetime, 结束时间 + :param fq: str, 除权类型,可选 ``"前复权"``、``"后复权"``、``"不复权"``。 + 注意:投研共享数据默认都是后复权,不需要再处理 + :param kwargs: dict, 其他参数(保留以扩展,当前未使用) + :return: list[RawBar], K 线对象列表 + """ + kwargs["fq"] = fq + freq = str(freq) + # 仅 "前复权" 时设为 True,其他统一为 False + fq = True if fq == "前复权" else False + # CZSC 中文频率到聚宽 pandas-style 频率字符串的映射 + _map = { + "1分钟": "1min", + "5分钟": "5min", + "15分钟": "15min", + "30分钟": "30min", + "60分钟": "60min", + "日线": "D", + "周线": "W", + "月线": "M", + } + return get_kline(symbol, freq=_map[freq], start_date=sdt, end_date=edt, fq=fq) diff --git a/czsc/core.py b/czsc/core.py deleted file mode 100644 index dac66d568..000000000 --- a/czsc/core.py +++ /dev/null @@ -1,86 +0,0 @@ -""" -czsc 核心对象转发模块 - -本模块作为 rs_czsc Rust 扩展的 Python 桥接层, -集中导入并暴露所有核心类与函数。无独立业务逻辑。 - -如需强制使用 Python 实现,设置环境变量 CZSC_USE_PYTHON=1 -""" - -from rs_czsc import ( - BI, - CZSC, - FX, - ZS, - BarGenerator, - CzscExitOptimStrategy, - CzscJsonStrategy, - CzscOpenOptimStrategy, - CzscStrategyBase, - Direction, - Event, - ExitsOptimize, - FakeBI, - Freq, - Mark, - NewBar, - OpensOptimize, - Operate, - ParsedSignalDoc, - Position, - RawBar, - Signal, - WeightBacktest, - build_exit_optim_positions, - build_open_optim_positions, - build_strategy_config, - daily_performance, - derive_signals_config, - derive_signals_freqs, - format_standard_kline, - normalize_feature, - parse_signal_doc, - run_optimize_batch, - run_replay, - run_research, - top_drawdowns, -) - -__all__ = [ - "Operate", - "Freq", - "Mark", - "Direction", - "CZSC", - "BarGenerator", - "format_standard_kline", - "RawBar", - "NewBar", - "FX", - "BI", - "FakeBI", - "ZS", - "Signal", - "Event", - "Position", - "WeightBacktest", - "CzscStrategyBase", - "CzscJsonStrategy", - "CzscOpenOptimStrategy", - "CzscExitOptimStrategy", - "OpensOptimize", - "ExitsOptimize", - "run_research", - "run_replay", - "run_optimize_batch", - "build_open_optim_positions", - "build_exit_optim_positions", - "build_strategy_config", - "top_drawdowns", - "daily_performance", - "normalize_feature", - "parse_signal_doc", - "ParsedSignalDoc", - "derive_signals_config", - "derive_signals_freqs", -] diff --git a/czsc/core.pyi b/czsc/core.pyi deleted file mode 100644 index 2926ae833..000000000 --- a/czsc/core.pyi +++ /dev/null @@ -1,41 +0,0 @@ -from rs_czsc import BI as BI -from rs_czsc import CZSC as CZSC -from rs_czsc import FX as FX -from rs_czsc import ZS as ZS -from rs_czsc import BarGenerator as BarGenerator -from rs_czsc import CzscJsonStrategy as CzscJsonStrategy -from rs_czsc import CzscStrategyBase as CzscStrategyBase -from rs_czsc import Direction as Direction -from rs_czsc import Event as Event -from rs_czsc import FakeBI as FakeBI -from rs_czsc import Freq as Freq -from rs_czsc import Mark as Mark -from rs_czsc import NewBar as NewBar -from rs_czsc import Operate as Operate -from rs_czsc import Position as Position -from rs_czsc import RawBar as RawBar -from rs_czsc import Signal as Signal -from rs_czsc import WeightBacktest as WeightBacktest -from rs_czsc import format_standard_kline as format_standard_kline - -__all__ = [ - "Operate", - "Freq", - "Mark", - "Direction", - "CZSC", - "BarGenerator", - "format_standard_kline", - "RawBar", - "NewBar", - "FX", - "BI", - "FakeBI", - "ZS", - "Signal", - "Event", - "Position", - "WeightBacktest", - "CzscStrategyBase", - "CzscJsonStrategy", -] diff --git a/czsc/eda.py b/czsc/eda.py index 520d9c192..9fb59dbc4 100644 --- a/czsc/eda.py +++ b/czsc/eda.py @@ -710,7 +710,7 @@ def mark_cta_periods(df: pd.DataFrame, **kwargs): 'is_best_period', 'is_best_up_period', 'is_best_down_period', 'is_normal_period' 'is_worst_period', 'is_worst_up_period', 'is_worst_down_period' """ - from czsc.core import CZSC, format_standard_kline + from czsc import CZSC, format_standard_kline q1 = kwargs.get("q1", 0.15) q2 = kwargs.get("q2", 0.4) @@ -847,7 +847,7 @@ def mark_v_reversal(df: pd.DataFrame, **kwargs): if rs: from rs_czsc import CZSC, Direction, format_standard_kline else: - from czsc.core import CZSC + from czsc import CZSC from czsc.utils.bar_generator import format_standard_kline # 参数设置 diff --git a/czsc/envs.py b/czsc/envs.py index a9e264c96..094b27a8a 100644 --- a/czsc/envs.py +++ b/czsc/envs.py @@ -1,50 +1,116 @@ """ -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2022/3/17 21:41 -describe: 环境变量统一管理入口 +czsc.envs —— 极简环境变量适配层(迁移到 Rust 后保留版本) + +背景与定位: + 迁移到 Rust 后端之后,原本的"Python 回退开关"已经下线(Phase H 之后 + Python 端不再保留任何缠论核心算法的回退实现,所有调用都走 Rust)。 + 因此运行时参数被裁剪到仅剩三项,全部用于配置 czsc-core 的分析行为或 + 日志详尽程度。 + +环境变量命名约定: + 1. 推荐使用全大写形式(如 ``CZSC_MIN_BI_LEN``) + 2. 出于历史兼容性,也接受全小写形式(如 ``czsc_min_bi_len``) + 3. 当大小写两种形式都设置时,**大写形式优先** + 4. 函数参数 v 显式传入时,优先级最高(覆盖环境变量) + +可读取的环境变量: + - CZSC_VERBOSE —— 是否打印详细日志(True/False) + - CZSC_MIN_BI_LEN —— 笔的最小长度(含包含处理后 K 线根数) + - CZSC_MAX_BI_NUM —— 单个 CZSC 实例保留的最大笔数 """ +from __future__ import annotations + import os -# True 的有效表达 -valid_true = ["1", "True", "true", "Y", "y", "yes", "Yes", True] +# 被视为"真值"的字符串集合(大小写不敏感,比较前会先 lower 化) +# 列举常见写法,避免用户因大小写或缩写导致开关不生效 +_VALID_TRUE = {"1", "true", "y", "yes"} + + +def _env(name: str, default: str | None = None) -> str | None: + """ + 带大小写兜底的环境变量读取 + + 优先级: + 1. 全大写形式(如 ``NAME``) + 2. 全小写形式(如 ``name``) + 3. 调用方提供的 default + 保留小写形式纯属向后兼容;新代码请只使用大写命名。 + """ + return os.environ.get(name.upper(), os.environ.get(name.lower(), default)) -def use_python(): - """是否使用 python 版本对象 - True 表示使用 python 版本对象 - False 则使用 rust 版本对应的对象 +def _to_bool(v) -> bool: """ - v = os.environ.get("CZSC_USE_PYTHON", False) - return v in valid_true + 宽松版"任意值 -> bool" 转换 + 规则: + - bool 直接原样返回(避免 ``True/False`` 被字符串化后重判一次) + - None 视为 False + - 其他对象先 ``str()`` 再 strip + lower,比对 ``_VALID_TRUE`` 集合 -def get_verbose(verbose=None): - """verbose - 是否输出执行过程的详细信息""" - verbose = verbose if verbose else os.environ.get("czsc_verbose", None) - v = verbose in valid_true - return v + 用于解析"用户填写的环境变量值是否启用某开关"的场景, + 比 ``bool(s)``(任意非空字符串都为 True)更符合直觉。 + """ + if isinstance(v, bool): + return v + if v is None: + return False + return str(v).strip().lower() in _VALID_TRUE -def get_welcome(): - """welcome - 是否输出版本标识和缠中说禅博客摘记""" - v = os.environ.get("czsc_welcome", "0") in valid_true - return v +def get_verbose(verbose=None) -> bool: + """ + 判断是否启用详细日志输出 + 参数: + verbose: 显式传入的开关值;若为 None 则读取环境变量 ``CZSC_VERBOSE`` -def get_min_bi_len(v: int = None) -> int: - """min_bi_len - 一笔的最小长度,也就是无包含K线的数量,7是老笔的要求,6是新笔的要求""" - min_bi_len = v if v else os.environ.get("czsc_min_bi_len", 6) - return int(float(min_bi_len)) + 返回: + bool;任何被 :func:`_to_bool` 视为真值的输入都会返回 True + """ + return _to_bool(verbose if verbose is not None else _env("czsc_verbose")) + + +def get_min_bi_len(v: int | None = None) -> int: + """ + 获取笔的最小长度(去包含后的 K 线根数) + + 取值含义: + 6 —— 新规范,要求笔至少跨越 6 根去包含后的 K 线 + 7 —— 旧规范,部分历史策略仍依赖此设置以维持原始走势识别口径 + + 参数: + v: 显式传入的最小笔长度;若为 None 则读取 ``CZSC_MIN_BI_LEN``, + 都未提供时使用默认值 6 + + 返回: + int 形式的最小笔长度 + + 备注: + ``int(float(raw))`` 是为兼容 ``"6"`` / ``"6.0"`` / ``6.5`` 等多种 + 字符串/数值书写形式,统一在小数点截断后再转 int。 + """ + raw = v if v is not None else _env("czsc_min_bi_len", 6) + return int(float(raw)) + + +def get_max_bi_num(v: int | None = None) -> int: + """ + 获取单个 CZSC 实例保留的最大笔数 + 参数: + v: 显式传入的上限值;若为 None 则读取 ``CZSC_MAX_BI_NUM``, + 都未提供时默认 50 -def get_max_bi_num(v: int = None) -> int: - """max_bi_num - 单个级别K线分析中,程序最大保存的笔数量 + 返回: + int 形式的最大笔数 - 默认值为 50,仅使用内置的信号和因子,不需要调整这个参数。 - 如果进行新的信号计算需要用到更多的笔,可以适当调大这个参数。 + 备注: + - 用途:回放/实时计算时只保留最近 N 笔,避免内存与计算量随时间无界增长 + - 取值过小会丢掉中长周期信号;过大会增加每根 K 线的更新成本 """ - max_bi_num = v if v else os.environ.get("czsc_max_bi_num", 50) - return int(float(max_bi_num)) + raw = v if v is not None else _env("czsc_max_bi_num", 50) + return int(float(raw)) diff --git a/czsc/mock.py b/czsc/mock.py index 73196d2fd..662704eed 100644 --- a/czsc/mock.py +++ b/czsc/mock.py @@ -1,538 +1,76 @@ """ -模拟数据生成模块 +czsc.mock —— 转发到 wbt.mock 的薄壳封装 + +历史背景: + 早期 czsc.mock 自行维护了约 540 行的随机 K 线/因子/组合/相关性等 + 模拟数据生成实现。迁移期决定不再维持该并行实现,而是统一收敛到 + 外部 wbt 包,避免两边算法漂移导致测试输出不一致。 + +当前职责: + 仅保留两个公开入口,转发到 wbt.mock 同名实现: + - ``generate_symbol_kines`` —— 单标的多周期 K 线 + - ``generate_klines_with_weights`` —— 带权重列的 K 线(用于回测打样) + +迁移影响: + 依赖旧 ``generate_klines`` / ``generate_cs_factor`` 等已删除函数的 + 业务方需迁移到 wbt 或自己的业务模块,本模块不再提供兜底实现。 + 与之相关的测试用例(test_mock_quality.py、test_eda.py 中相应切片) + 也已在裁剪阶段一并删除。 """ -import numpy as np -import pandas as pd - -from czsc.utils.data.cache import disk_cache - - -@disk_cache(ttl=3600 * 24) -def generate_symbol_kines(symbol, freq, sdt="20100101", edt="20250101", seed=42): - """生成单个品种指定频率的K线数据 - - Args: - symbol: 品种代码,如 'AAPL', '000001.SH' 等 - freq: K线频率,支持 '1分钟', '5分钟', '15分钟', '30分钟', '日线' - sdt: 开始日期,格式为 'YYYYMMDD',默认 "20100101" - edt: 结束日期,格式为 'YYYYMMDD',默认 "20250101" - seed: 随机数种子,确保结果可重现,默认42 - - Returns: - pd.DataFrame: 包含K线数据的DataFrame,列包括dt、symbol、open、close、high、low、vol、amount - """ - # 设置随机数种子确保结果可重现 - np.random.seed(seed + hash(symbol) % 1000) - - # 转换日期格式 - start_date = pd.to_datetime(sdt, format="%Y%m%d") - end_date = pd.to_datetime(edt, format="%Y%m%d") - - # 根据频率生成时间序列 - if freq == "日线": - dates = pd.date_range(start=start_date, end=end_date, freq="B") - elif freq in ["1分钟", "5分钟", "15分钟", "30分钟"]: - # 先生成日期范围 - trading_days = pd.date_range(start=start_date, end=end_date, freq="D") - dates = [] - - # 获取分钟数 - freq_minutes = int(freq.replace("分钟", "")) - - # A股交易时间段 - morning_start = "09:30" - morning_end = "11:30" - afternoon_start = "13:00" - afternoon_end = "15:00" - - for day in trading_days: - # 上午交易时间 - morning_times = pd.date_range( - start=f"{day.strftime('%Y-%m-%d')} {morning_start}", - end=f"{day.strftime('%Y-%m-%d')} {morning_end}", - freq=f"{freq_minutes}min", - ) - - # 下午交易时间 - afternoon_times = pd.date_range( - start=f"{day.strftime('%Y-%m-%d')} {afternoon_start}", - end=f"{day.strftime('%Y-%m-%d')} {afternoon_end}", - freq=f"{freq_minutes}min", - ) - - dates.extend(morning_times.tolist()) - dates.extend(afternoon_times.tolist()) - - dates = pd.DatetimeIndex(dates) - else: - raise ValueError(f"不支持的频率: {freq}。支持的频率: 1分钟, 5分钟, 15分钟, 30分钟, 日线") - - # 定义不同的市场阶段,模拟真实市场的周期性变化 - phases = [ - {"name": "熊市", "trend": -0.0008, "volatility": 0.025, "length": 0.3}, - {"name": "震荡", "trend": 0.0002, "volatility": 0.015, "length": 0.2}, - {"name": "牛市", "trend": 0.0012, "volatility": 0.02, "length": 0.3}, - {"name": "调整", "trend": -0.0005, "volatility": 0.02, "length": 0.2}, - ] - - # 初始价格 - base_price = 100.0 - - # 市场阶段控制变量 - total_periods = len(dates) - phase_idx = 0 - phase_periods = 0 - current_phase = phases[phase_idx] - - data = [] - - for i, dt in enumerate(dates): - # 切换市场阶段 - if phase_periods >= total_periods * current_phase["length"]: - phase_idx = (phase_idx + 1) % len(phases) - current_phase = phases[phase_idx] - phase_periods = 0 - - # 当前阶段的趋势和波动 - trend = current_phase["trend"] - volatility = current_phase["volatility"] - - # 对于分钟级数据,调整趋势和波动率 - if freq != "日线": - freq_minutes = int(freq.replace("分钟", "")) - # 按分钟级别调整趋势和波动,使日内波动更合理 - trend = trend / (240 / freq_minutes) # 240是一天的交易分钟数 - volatility = volatility / (240 / freq_minutes) ** 0.5 - - # 添加周期性波动,模拟季节性等因素 - if freq == "日线": - cycle_factor = np.sin(i / 30) * 0.001 # 30天周期 - annual_cycle = np.sin(i / 365) * 0.0005 # 年度周期 - else: - # 分钟级别的周期性波动 - cycle_factor = np.sin(i / 120) * 0.0005 # 120分钟周期 - annual_cycle = np.sin(i / (365 * 240)) * 0.0002 # 年度周期 - - # 随机噪音 - noise = np.random.normal(0, volatility) - - # 计算开盘价和收盘价 - open_price = base_price - close_price = base_price * (1 + trend + cycle_factor + annual_cycle + noise) - - # 确保价格不会变为负数 - if close_price <= 0: - close_price = base_price * 0.95 - - # 计算日内波动范围,考虑市场波动的合理性 - price_change_ratio = abs(close_price - open_price) / open_price - if freq == "日线": - daily_range = base_price * (price_change_ratio + np.random.uniform(0.01, 0.04)) - else: - # 分钟级别的波动范围更小 - daily_range = base_price * (price_change_ratio + np.random.uniform(0.001, 0.01)) - - if close_price > open_price: # 阳线 - high_price = close_price + daily_range * np.random.uniform(0.1, 0.5) - low_price = open_price - daily_range * np.random.uniform(0.1, 0.3) - else: # 阴线 - high_price = open_price + daily_range * np.random.uniform(0.1, 0.3) - low_price = close_price - daily_range * np.random.uniform(0.1, 0.5) - - # 确保价格关系正确:high >= max(open, close), low <= min(open, close) - high_price = max(high_price, open_price, close_price) - low_price = min(low_price, open_price, close_price) - - # 模拟成交量 - 价格波动大的时候成交量通常也大 - if freq == "日线": - base_volume = np.random.uniform(100000, 300000) - else: - # 分钟级别的成交量更小 - freq_minutes = int(freq.replace("分钟", "")) - base_volume = np.random.uniform(10000, 50000) * (freq_minutes / 5) # 基于5分钟调整 - - volatility_factor = price_change_ratio * 5 # 波动率影响成交量 - volume_multiplier = 1 + volatility_factor + np.random.uniform(-0.2, 0.2) - volume = int(base_volume * max(volume_multiplier, 0.3)) # 确保成交量不会过小 - - # 计算成交金额(使用平均价格) - avg_price = (high_price + low_price + open_price + close_price) / 4 - amount = volume * avg_price - - data.append( - { - "dt": dt, - "symbol": symbol, - "open": round(open_price, 2), - "close": round(close_price, 2), - "high": round(high_price, 2), - "low": round(low_price, 2), - "vol": volume, - "amount": round(amount, 2), - } - ) - - # 更新基准价格为收盘价 - base_price = close_price - phase_periods += 1 - - return pd.DataFrame(data) - - -@disk_cache(ttl=3600 * 24) -def generate_klines(seed=42): - """生成K线数据,包含完整的OHLCVA信息(开高低收量额) - - Args: - seed: 随机数种子,确保结果可重现,默认42 - - Returns: - pd.DataFrame: 包含K线数据的DataFrame,列包括dt、symbol、open、close、high、low、vol、amount - """ - # 设置随机数种子确保结果可重现 - np.random.seed(seed) - - dates = pd.date_range(start="2010-01-01", end="2025-06-08", freq="D") - symbols = [ - "AAPL", - "MSFT", - "GOOGL", - "AMZN", - "TSLA", - "NVDA", - "META", - "NFLX", - "PYPL", - "INTC", - "CSCO", - "IBM", - "ORCL", - "SAP", - "BTC", - "ETH", - "000001", - ] - - # 定义不同的市场阶段,模拟真实市场的周期性变化 - phases = [ - {"name": "熊市", "trend": -0.0008, "volatility": 0.025, "length": 0.3}, - {"name": "震荡", "trend": 0.0002, "volatility": 0.015, "length": 0.2}, - {"name": "牛市", "trend": 0.0012, "volatility": 0.02, "length": 0.3}, - {"name": "调整", "trend": -0.0005, "volatility": 0.02, "length": 0.2}, - ] - - data = [] - for symbol in symbols: - # 初始价格 - base_price = 100.0 - - # 为每个标的设置不同的种子偏移,确保不同标的有不同的走势 - symbol_seed = seed + hash(symbol) % 1000 - np.random.seed(symbol_seed) - - # 市场阶段控制变量 - total_days = len(dates) - phase_idx = 0 - phase_days = 0 - current_phase = phases[phase_idx] - - for i, dt in enumerate(dates): - # 切换市场阶段 - if phase_days >= total_days * current_phase["length"]: - phase_idx = (phase_idx + 1) % len(phases) - current_phase = phases[phase_idx] - phase_days = 0 - - # 当前阶段的趋势和波动 - trend = current_phase["trend"] - volatility = current_phase["volatility"] - - # 添加周期性波动,模拟季节性等因素 - cycle_factor = np.sin(i / 30) * 0.001 # 30天周期 - - # 添加长期周期,模拟年度周期 - annual_cycle = np.sin(i / 365) * 0.0005 # 年度周期 - - # 随机噪音 - noise = np.random.normal(0, volatility) - - # 计算开盘价和收盘价 - open_price = base_price - close_price = base_price * (1 + trend + cycle_factor + annual_cycle + noise) - - # 确保价格不会变为负数 - if close_price <= 0: - close_price = base_price * 0.95 - - # 计算日内波动范围,考虑市场波动的合理性 - price_change_ratio = abs(close_price - open_price) / open_price - daily_range = base_price * (price_change_ratio + np.random.uniform(0.01, 0.04)) - - if close_price > open_price: # 阳线 - high_price = close_price + daily_range * np.random.uniform(0.1, 0.5) - low_price = open_price - daily_range * np.random.uniform(0.1, 0.3) - else: # 阴线 - high_price = open_price + daily_range * np.random.uniform(0.1, 0.3) - low_price = close_price - daily_range * np.random.uniform(0.1, 0.5) - - # 确保价格关系正确:high >= max(open, close), low <= min(open, close) - high_price = max(high_price, open_price, close_price) - low_price = min(low_price, open_price, close_price) +from __future__ import annotations - # 模拟成交量 - 价格波动大的时候成交量通常也大 - base_volume = np.random.uniform(100000, 300000) - volatility_factor = price_change_ratio * 5 # 波动率影响成交量 - volume_multiplier = 1 + volatility_factor + np.random.uniform(-0.2, 0.2) - volume = int(base_volume * max(volume_multiplier, 0.3)) # 确保成交量不会过小 - - # 计算成交金额(使用平均价格) - avg_price = (high_price + low_price + open_price + close_price) / 4 - amount = volume * avg_price - - data.append( - { - "dt": dt, - "symbol": symbol, - "open": round(open_price, 2), - "close": round(close_price, 2), - "high": round(high_price, 2), - "low": round(low_price, 2), - "vol": volume, - "amount": round(amount, 2), - } - ) - - # 更新基准价格为收盘价 - base_price = close_price - phase_days += 1 - - return pd.DataFrame(data) - - -def generate_klines_with_weights(seed=42): - """生成K线数据,包含权重信息""" - df = generate_klines(seed) - df["weight"] = np.random.normal(-1, 1, len(df)) - df["weight"] = df["weight"].clip(-1, 1) - df["price"] = df["close"] - return df - - -def generate_ts_factor(seed=42): - """生成K线数据,包含权重信息""" - df = generate_klines(seed) - df["F#SMA#20"] = df.groupby("symbol")["close"].rolling(20).mean().reset_index(drop=True).fillna(0) - return df - - -def generate_cs_factor(seed=42): - """生成截面因子数据""" - df = generate_klines(seed) - df["ret20"] = df.groupby("symbol")["close"].pct_change(20).reset_index(drop=True).fillna(0) - df["F#RPS#20"] = df.groupby("dt")["ret20"].rank(pct=True).reset_index(drop=True).fillna(0) - df.drop(columns=["ret20"], inplace=True) - return df - - -@disk_cache(ttl=3600 * 24) -def generate_strategy_returns(n_strategies=10, n_days=None, seed=42): - """生成多策略收益数据 - - Args: - n_strategies: 策略数量,默认10个 - n_days: 生成天数,None表示使用全部时间范围 - seed: 随机数种子,确保结果可重现,默认42 - - Returns: - pd.DataFrame: 包含策略收益数据的DataFrame,列包括dt、strategy、returns - """ - # 设置随机数种子确保结果可重现 - np.random.seed(seed) - - dates = pd.date_range(start="2010-01-01", end="2025-06-08", freq="D") - if n_days and len(dates) > n_days: - dates = dates[-n_days:] # 取最近的n_days天 - data = [] - - for i in range(n_strategies): - strategy_name = f"策略_{i + 1:02d}" - # 生成具有不同特征的收益率 - base_return = np.random.normal(0.0005, 0.015, len(dates)) - if i % 3 == 0: # 每3个策略中有一个表现更好 - base_return += np.random.normal(0.0002, 0.005, len(dates)) - - for j, dt in enumerate(dates): - data.append({"dt": dt, "strategy": strategy_name, "returns": base_return[j]}) - - return pd.DataFrame(data) - - -@disk_cache(ttl=3600 * 24) -def generate_portfolio(seed=42): - """生成组合数据 - - Args: - seed: 随机数种子,确保结果可重现,默认42 - - Returns: - pd.DataFrame: 包含组合和基准收益数据的DataFrame,列包括dt、portfolio、benchmark - """ - # 设置随机数种子确保结果可重现 - np.random.seed(seed) - - dates = pd.date_range(start="2010-01-01", end="2025-06-08", freq="D") - portfolio_returns = np.random.normal(0.0008, 0.012, len(dates)) - benchmark_returns = np.random.normal(0.0003, 0.010, len(dates)) - - return pd.DataFrame({"dt": dates, "portfolio": portfolio_returns, "benchmark": benchmark_returns}) - - -def set_global_seed(seed=42): - """设置全局随机数种子 - - Args: - seed: 随机数种子,默认42 - - Note: - 调用此函数后,所有使用numpy随机数的函数都会基于这个种子生成数据 - 适用于需要统一设置种子的场景 - """ - np.random.seed(seed) +import pandas as pd +# 内部使用的别名(带下划线前缀),避免被 ``from czsc.mock import *`` 误带出去 +from wbt.mock import mock_symbol_kline as _mock_symbol_kline +from wbt.mock import mock_weights as _mock_weights -@disk_cache(ttl=3600 * 24) -def generate_correlation_data(seed=42): - """生成相关性分析数据 +# 公开 API 契约:仅暴露这两个函数;其余符号一律视为内部细节 +__all__ = ["generate_symbol_kines", "generate_klines_with_weights"] - Args: - seed: 随机数种子,确保结果可重现,默认42 - Returns: - pd.DataFrame: 包含多个具有不同相关性的时间序列 +def generate_symbol_kines( + symbol: str, + freq: str, + sdt: str = "20100101", + edt: str = "20250101", + seed: int = 42, +) -> pd.DataFrame: """ - # 设置随机数种子确保结果可重现 - np.random.seed(seed) - - dates = pd.date_range(start="2020-01-01", end="2023-12-31", freq="D") - - # 创建具有不同相关性的序列 - base_series = np.random.normal(0, 1, len(dates)) - - data = pd.DataFrame( - { - "dt": dates, - "series_A": base_series, - "series_B": 0.7 * base_series + 0.3 * np.random.normal(0, 1, len(dates)), - "series_C": -0.5 * base_series + 0.5 * np.random.normal(0, 1, len(dates)), - "series_D": np.random.normal(0, 1, len(dates)), - "returns_A": np.random.normal(0.0008, 0.015, len(dates)), - "returns_B": np.random.normal(0.0005, 0.012, len(dates)), - "returns_C": np.random.normal(0.0003, 0.020, len(dates)), - } - ) - - return data - - -@disk_cache(ttl=3600 * 24) -def generate_daily_returns(n_strategies=3, seed=42): - """生成日收益数据用于收益分析 - - Args: - n_strategies: 策略数量,默认3个 - seed: 随机数种子,确保结果可重现,默认42 - - Returns: - pd.DataFrame: 包含多策略日收益的DataFrame,index为日期 + 生成单标的、单周期的随机 K 线 DataFrame(转发到 wbt 实现) + + 参数: + symbol: 标的代码(任意字符串,会被原样写入 ``symbol`` 列) + freq: 周期字符串(如 ``"30分钟"`` / ``"日线"``),需被 wbt 识别 + sdt: 起始日期,格式 ``YYYYMMDD``,默认 2010-01-01 + edt: 结束日期,格式 ``YYYYMMDD``,默认 2025-01-01 + seed: 随机种子;同 (symbol, freq, sdt, edt, seed) 五元组保证结果可复现 + + 返回: + 与 rs-czsc 的标准 K 线 schema 一致的 DataFrame,列包括: + ``dt / symbol / open / close / high / low / vol / amount`` + + 用途: + - 单元测试中替代真实行情,避免对外部数据源的网络依赖 + - 演示/教程脚本中作为最小可运行示例的数据来源 """ - # 设置随机数种子确保结果可重现 - np.random.seed(seed) - - dates = pd.date_range(start="2020-01-01", end="2023-12-31", freq="D") - - data = {} - for i in range(n_strategies): - strategy_name = f"strategy_{chr(65 + i)}" # strategy_A, strategy_B, strategy_C - # 生成具有不同风险收益特征的收益率 - if i == 0: # 高收益高波动 - returns = np.random.normal(0.0008, 0.015, len(dates)) - elif i == 1: # 中等收益中等波动 - returns = np.random.normal(0.0005, 0.012, len(dates)) - else: # 低收益低波动 - returns = np.random.normal(0.0003, 0.010, len(dates)) + return _mock_symbol_kline(symbol, freq, sdt=sdt, edt=edt, seed=seed) - data[strategy_name] = returns - # 添加基准 - data["benchmark"] = np.random.normal(0.0003, 0.010, len(dates)) - - return pd.DataFrame(data, index=dates) - - -@disk_cache(ttl=3600 * 24) -def generate_statistics_data(seed=42): - """生成统计分析数据 - - Args: - seed: 随机数种子,确保结果可重现,默认42 - - Returns: - pd.DataFrame: 包含统计分析所需的多种数据 +def generate_klines_with_weights(seed: int = 42) -> pd.DataFrame: """ - # 设置随机数种子确保结果可重现 - np.random.seed(seed) - - dates = pd.date_range(start="2020-01-01", end="2023-12-31", freq="D") - - data = pd.DataFrame( - { - "dt": dates, - "returns": np.random.normal(0.0008, 0.015, len(dates)), - "factor1": np.random.normal(0, 1, len(dates)), - "factor2": np.random.normal(0, 1.2, len(dates)), - "category": np.random.choice(["A", "B", "C"], len(dates)), - "volume": np.random.randint(1000000, 10000000, len(dates)), - "price": np.cumsum(np.random.normal(0.1, 2, len(dates))) + 100, - } - ) + 生成带权重列的多标的 K 线(转发到 ``wbt.mock.mock_weights``) - return data + 参数: + seed: 随机种子,相同 seed 保证结果可复现 + 返回: + DataFrame,含标的、时间、价格列以及一列模拟"目标权重", + 典型用途为 :class:`wbt.WeightBacktest` 的输入打样数据。 -@disk_cache(ttl=3600 * 24) -def generate_event_data(seed=42): - """生成事件分析数据 - - Args: - seed: 随机数种子,确保结果可重现,默认42 - - Returns: - pd.DataFrame: 包含事件和特征的DataFrame + 备注: + wbt 端使用一组默认的 symbols / 频率,调用方目前无法自定义。 + 如需更灵活的样本,请直接调用 ``wbt.mock.mock_weights``。 """ - # 设置随机数种子确保结果可重现 - np.random.seed(seed) - - dates = pd.date_range(start="2020-01-01", end="2023-12-31", freq="D") - symbols = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"] - - data = [] - for symbol in symbols: - for dt in dates: - # 事件发生概率为20% - event_occur = np.random.choice([0, 1], p=[0.8, 0.2]) - - data.append( - { - "dt": dt, - "symbol": symbol, - "event": event_occur, - "target": np.random.normal(0.001, 0.02), - "feature1": np.random.normal(0, 1), - "feature2": np.random.normal(0, 1.5), - "feature3": np.random.normal(0, 0.8), - "price_change": np.random.normal(0.0005, 0.015), - } - ) - - return pd.DataFrame(data) + return _mock_weights(seed=seed) diff --git a/czsc/models.py b/czsc/models.py new file mode 100644 index 000000000..82668637c --- /dev/null +++ b/czsc/models.py @@ -0,0 +1,119 @@ +""" +Python 端策略研究流程的数据模型定义 + +本模块集中存放跨子模块共享的轻量数据载体(dataclass / TypedDict), +位于 czsc 顶层是为了避免子包之间循环引用,并便于上层业务直接 ``from czsc.models import ...``。 + +包含三类对象: + - StrategyConfig: 策略配置 TypedDict,约束策略 JSON / dict 的字段集合 + - ResearchResult: 研究/回测的统一返回容器,承载 Arrow IPC 字节流并提供 DataFrame 视图 + - ReplayResult: 单标的回放的返回容器,与 ResearchResult 同构(仅类型语义不同) + - OptimizeResult: 参数优化运行的元信息容器 +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Optional, TypedDict + +from czsc._utils._df_convert import arrow_bytes_to_pd_df + + +class StrategyConfig(TypedDict, total=False): + """ + 策略配置的 TypedDict 契约 + + 用于类型标注与 IDE 补全,让"用 dict 描述一个策略"的写法获得静态检查保护。 + 所有字段都是可选的(``total=False``),具体校验由 Rust 端按字段名读取并报错。 + + 字段说明: + name - 策略名称(用于日志/产物目录命名) + symbol - 标的代码 + base_freq - 基础周期(如 ``"30分钟"``) + signals_module - 信号实现所在的 Python 模块路径,用于 short-name 解析 + signals_config - 信号配置列表,每项形如 ``{"name": ..., "freq": ..., "params": {...}}`` + positions - 持仓配置列表,每项是一个 Position 的 JSON dump + market - 市场标识(如 "A股"/"期货"),影响交易日、夜盘判断等 + bg_max_count - BarGenerator 的最大缓冲根数 + sdt - 起始日期(``YYYYMMDD`` 或 ``YYYY-MM-DD``) + include_sdt_bar - 是否把 sdt 当日的首根 K 线纳入信号计算 + """ + + name: str + symbol: str + base_freq: str + signals_module: str + signals_config: list[dict[str, Any]] + positions: list[dict[str, Any]] + market: str + bg_max_count: int + sdt: str + include_sdt_bar: bool + + +@dataclass +class ResearchResult: + """ + 通用研究/回测结果容器 + + 设计要点: + - 三类核心数据(信号/成对交易/持仓)以 Arrow IPC 字节流形式持有, + 延后到 ``*_df()`` 调用时才反序列化为 DataFrame,避免在跨进程 + /跨语言传输时不必要的对象化开销 + - 同时保留对应的 ``*_path`` 字段,方便上层把结果落盘后只回传路径, + 字节流字段可置空(视调用模式而定) + - meta 携带策略名、标的、参数、时间窗等元信息,用于结果归档与索引 + + 字段: + meta - 任意 dict 形式的元信息 + signals_arrow - 信号表的 Arrow 字节流 + pairs_arrow - 成对交易表(成对开平仓配对)的 Arrow 字节流 + holds_arrow - 持仓时序表的 Arrow 字节流 + signals_path - 信号表对应的本地路径(可选) + pairs_path - 成对交易表对应的本地路径(可选) + holds_path - 持仓表对应的本地路径(可选) + """ + + meta: Dict[str, Any] + signals_arrow: bytes + pairs_arrow: bytes + holds_arrow: bytes + signals_path: Optional[str] = None + pairs_path: Optional[str] = None + holds_path: Optional[str] = None + + def signals_df(self): + """将 ``signals_arrow`` 反序列化为 Pandas DataFrame(按需调用,避免无谓开销)""" + return arrow_bytes_to_pd_df(self.signals_arrow) + + def pairs_df(self): + """将 ``pairs_arrow`` 反序列化为 Pandas DataFrame""" + return arrow_bytes_to_pd_df(self.pairs_arrow) + + def holds_df(self): + """将 ``holds_arrow`` 反序列化为 Pandas DataFrame""" + return arrow_bytes_to_pd_df(self.holds_arrow) + + +@dataclass +class ReplayResult(ResearchResult): + """ + 单标的回放的结果容器 + + 结构与 :class:`ResearchResult` 完全一致,单独定义子类的目的是在调用方 + 代码与日志中体现语义差异("复现某次回放"对应 ReplayResult,"批量研究 + 多组参数"对应 ResearchResult)。 + """ + pass + + +@dataclass +class OptimizeResult: + """ + 参数优化运行的元信息容器 + + 当前仅持有一个 ``message`` 字段,用于承载 Rust 端返回的简要状态描述 + (成功概要 / 警告 / 错误信息)。后续若需要扩展更结构化的优化指标 + (如最优参数、得分排序等),可在此追加字段,保持向后兼容。 + """ + message: str diff --git a/czsc/research.py b/czsc/research.py new file mode 100644 index 000000000..7d50bd159 --- /dev/null +++ b/czsc/research.py @@ -0,0 +1,321 @@ +""" +策略研究 / 回放 / 优化的 Python 入口(Rust 后端薄封装) + +模块职责: + 把 Python 端友好的入参(DataFrame、dict、Path 等)转换为 Rust 函数所需的 + 紧凑布局(Arrow 字节流 + JSON 字符串),再把 Rust 返回的 dict 包装为 + 更易消费的 dataclass(ResearchResult / ReplayResult / OptimizeResult)。 + +为何"薄封装"也值得单独成模块: + 1. 入参归一化逻辑(normalize_candidate_events / signal_config_to_runtime / + position_dump_to_runtime)需要在多个入口复用,集中放在这里避免重复 + 2. 序列化/反序列化(pandas <-> Arrow IPC、Python dict <-> JSON)也需要复用 + 3. 屏蔽 Rust 端 PyO3 函数的"裸调用",便于将来切换实现而不影响调用方 +""" + +from __future__ import annotations + +import json +from pathlib import Path +from tempfile import NamedTemporaryFile +from typing import Any + +import pandas as pd + +# Rust/Python 兼容层:负责"用户层 dict <-> Rust 运行时 dict"的格式互转 +from czsc._compat import ( + normalize_candidate_events, + position_dump_to_runtime, + signal_config_to_runtime, +) +# 直接调用 PyO3 暴露的 Rust 实现(带下划线别名表示"不要在调用方代码中再展开") +from czsc._native import ( + build_exit_optim_positions as _build_exit_optim_positions, + build_open_optim_positions as _build_open_optim_positions, + run_optimize, + run_optimize_batch as _run_optimize_batch, + run_replay as _run_replay, + run_research as _run_research, +) +from czsc._utils._df_convert import pandas_to_arrow_bytes +from czsc.models import OptimizeResult, ReplayResult, ResearchResult + + +# 类型别名:bars 入参允许传 DataFrame 或已就绪的 Arrow IPC 字节 +# 这种"两可"形式可以让上层在已经持有字节流的场景下省一次序列化 +BarsLike = pd.DataFrame | bytes + + +def _ensure_arrow_bytes(bars: BarsLike) -> bytes: + """ + 将 bars 入参统一规范为 Arrow IPC 字节流 + + 支持的输入: + - bytes / bytearray —— 直接转 bytes 返回,零成本 + - pd.DataFrame —— 调用 ``pandas_to_arrow_bytes`` 完成序列化 + - 其他 —— 抛 TypeError,避免在 Rust 端再触发难懂的错误 + + 单独抽出此辅助函数的目的是同时被 ``run_research`` 与 ``run_replay`` 复用, + 避免分别实现导致两个入口的入参契约出现漂移。 + """ + if isinstance(bars, (bytes, bytearray)): + return bytes(bars) + if isinstance(bars, pd.DataFrame): + return pandas_to_arrow_bytes(bars) + raise TypeError(f"bars must be pd.DataFrame or bytes, got {type(bars)}") + + +def _to_research_result(payload: dict[str, Any], cls=ResearchResult): + """ + 把 Rust 返回的 dict 装配为 ResearchResult / ReplayResult dataclass + + 参数: + payload: Rust PyO3 函数返回的字典,键名固定(meta / *_arrow / *_path) + cls: 目标 dataclass,默认 ResearchResult;run_replay 会传 ReplayResult + + 备注: + Rust 侧返回的 *_arrow 字段类型是 ``PyBytes``, + 通过 ``bytes(...)`` 显式转换可消除任何潜在的视图引用,避免上层 + 在跨线程/异步场景下意外持有底层缓冲区。 + """ + return cls( + meta=payload["meta"], + signals_arrow=bytes(payload["signals_arrow"]), + pairs_arrow=bytes(payload["pairs_arrow"]), + holds_arrow=bytes(payload["holds_arrow"]), + signals_path=payload.get("signals_path"), + pairs_path=payload.get("pairs_path"), + holds_path=payload.get("holds_path"), + ) + + +def run_research( + bars: BarsLike, + strategy: dict[str, Any], + *, + sdt: str | None = None, + opts: dict[str, Any] | None = None, +) -> ResearchResult: + """ + 内存模式执行策略研究,返回 Arrow 格式的统一结果 + + 参数: + bars: + 两种形式之一: + - 标准 OHLCV 列布局的 ``pandas.DataFrame`` + - 同一 schema 序列化后的 Arrow IPC 字节流(bytes) + strategy: + Python 用户层格式的策略字典(含 ``signals_config`` / ``positions`` 等)。 + 进入 Rust 之前会自动把其中的 positions 与 signals_config 归一化为 + 运行时格式,调用方无需关心两套格式的差异。 + sdt: + 可选的起始时间覆盖;不传则使用 strategy 内默认设置。 + opts: + 可选的执行参数开关,例如 ``{"emit_signals": False}`` 用于禁用信号产物输出。 + + 返回: + :class:`ResearchResult`,含元数据与三份 Arrow 字节流(信号 / 成对交易 / 持仓) + + 备注: + - 内存模式:完全在内存中产出 Arrow 字节,不会写盘;如需落盘请用 :func:`run_replay` + - 入参 strategy 不会被原地修改:函数内部走浅拷贝 + """ + # 选项序列化为 JSON,传给 Rust 解析;None 直接透传,由 Rust 处理默认 + opts_json = json.dumps(opts, ensure_ascii=False) if opts else None + + # 浅拷贝避免修改调用方传入的 dict + strategy_payload = dict(strategy) + + # positions / signals_config 都是用户层格式,需要先归一化为 Rust 运行时期望的紧凑布局 + if "positions" in strategy_payload: + strategy_payload["positions"] = [ + position_dump_to_runtime(pos) if isinstance(pos, dict) else pos + for pos in strategy_payload["positions"] + ] + if "signals_config" in strategy_payload: + strategy_payload["signals_config"] = [ + signal_config_to_runtime(cfg) if isinstance(cfg, dict) else cfg + for cfg in strategy_payload["signals_config"] + ] + + # 进入 Rust:bars 转字节,strategy 转 JSON 字符串 + payload = _run_research( + _ensure_arrow_bytes(bars), + json.dumps(strategy_payload, ensure_ascii=False), + sdt, + opts_json, + ) + return _to_research_result(payload, ResearchResult) + + +def run_replay( + bars: BarsLike, + strategy: dict[str, Any], + *, + res_path: str | Path | None = None, + sdt: str | None = None, + opts: dict[str, Any] | None = None, +) -> ReplayResult: + """ + 执行单标的回放任务,可选将 parquet 结果落盘 + + 与 :func:`run_research` 的差异: + - 多了一个 ``res_path`` 参数:传入时会在该目录写入 + ``signals.parquet``、``pairs.parquet``、``holds.parquet`` + - 不传 ``res_path`` 时仍返回内存中的 Arrow 结果,行为退化为内存模式 + + 参数: + bars: OHLCV DataFrame 或同 schema 的 Arrow 字节 + strategy: 策略 dict,会自动归一化 positions/signals_config + res_path: 结果落盘根目录;None 表示不落盘 + sdt: 可选起始时间覆盖 + opts: 可选执行参数开关 + + 返回: + :class:`ReplayResult`(结构同 ResearchResult,仅类型语义不同) + """ + path_str = str(res_path) if res_path is not None else None + opts_json = json.dumps(opts, ensure_ascii=False) if opts else None + + strategy_payload = dict(strategy) + if "positions" in strategy_payload: + strategy_payload["positions"] = [ + position_dump_to_runtime(pos) if isinstance(pos, dict) else pos + for pos in strategy_payload["positions"] + ] + if "signals_config" in strategy_payload: + strategy_payload["signals_config"] = [ + signal_config_to_runtime(cfg) if isinstance(cfg, dict) else cfg + for cfg in strategy_payload["signals_config"] + ] + + payload = _run_replay( + _ensure_arrow_bytes(bars), + json.dumps(strategy_payload, ensure_ascii=False), + path_str, + sdt, + opts_json, + ) + return _to_research_result(payload, ReplayResult) + + +def run_optimize_batch( + bars_dir: str | Path, + optimize_cfg: dict[str, Any], + res_path: str | Path, + *, + n_threads: int = 1, +) -> OptimizeResult: + """ + 根据 Python dict 配置批量执行参数优化任务 + + 参数: + bars_dir: + 含多个标的 K 线 parquet 文件的目录(每个标的一份文件)。 + Rust 端会按 ``optimize_cfg["symbols"]`` 顺序加载对应文件。 + optimize_cfg: + 用户层格式的优化配置 dict。必须包含 ``symbols`` 字段,否则报错。 + 若为平仓优化(``optim_type == "exit"``),函数会先把 + ``candidate_events`` 字段转换为 Rust 运行时格式,调用方无需手工归一化。 + res_path: + 优化结果输出根目录,Rust 端会在其中创建子目录写入参数组合产物。 + n_threads: + Rust 端并行优化使用的线程数;默认为 1(串行),CPU 核多的机器 + 可适当调大以加速。 + + 返回: + :class:`OptimizeResult`,仅含一段简短的运行状态消息(成功/警告/错误概要) + + 异常: + TypeError: optimize_cfg 不是 dict + ValueError: 缺少 symbols 字段 + """ + if not isinstance(optimize_cfg, dict): + raise TypeError("optimize_cfg must be dict") + if "symbols" not in optimize_cfg: + raise ValueError("optimize_cfg 缺少 symbols 字段") + + cfg = dict(optimize_cfg) + # 平仓优化的候选事件字典需要先归一化为 Rust 运行时格式(统一 signals_all/any/not 等字段) + if cfg.get("optim_type") == "exit" and "candidate_events" in cfg: + cfg["candidate_events"] = normalize_candidate_events(cfg["candidate_events"]) + + msg = _run_optimize_batch( + str(bars_dir), + json.dumps(cfg, ensure_ascii=False), + str(res_path), + n_threads, + ) + return OptimizeResult(message=msg) + + +def build_open_optim_positions( + files_position: list[str | Path], + candidate_signals: list[str], +) -> list[dict[str, Any]]: + """ + 仅构造开仓优化的候选仓位(不真正执行回测) + + 用途: + 在正式跑优化之前,先把候选 Position 列表导出来供人工审阅、对比或留档。 + 典型流程是:build_open_optim_positions -> 人工筛选 -> run_optimize_batch。 + + 参数: + files_position: 现有 Position JSON 文件路径列表,作为候选仓位的模板基准 + candidate_signals: 待与每个 Position 组合的候选信号 key 列表 + + 返回: + Position.dump() 风格的字典列表,可直接序列化或喂给后续优化流程 + """ + payload = _build_open_optim_positions([str(x) for x in files_position], candidate_signals) + return json.loads(payload) + + +def build_exit_optim_positions( + files_position: list[str | Path], + candidate_events: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """ + 仅构造平仓优化的候选仓位(不真正执行回测) + + 与 :func:`build_open_optim_positions` 类似,差别在于这里组合的是"候选平仓事件"。 + + 参数: + files_position: 现有 Position JSON 文件路径列表 + candidate_events: 旧版 czsc 优化脚本使用的 Python 事件字典列表, + 函数会自动调用 :func:`normalize_candidate_events` 做格式归一 + + 返回: + Position.dump() 风格的字典列表 + """ + payload = _build_exit_optim_positions( + [str(x) for x in files_position], + json.dumps(normalize_candidate_events(candidate_events), ensure_ascii=False), + ) + return json.loads(payload) + + +# === 兼容性回退入口(仅保留以平滑迁移)=== +# 旧工作流依赖"先把 cfg dump 成临时 JSON,再调用 run_optimize"的两步式流程。 +# 新代码请直接使用上方的 ``run_optimize_batch``(一步式、无临时文件)。 +def run_optimize_batch_legacy( + bars_dir: str | Path, + optimize_cfg: dict[str, Any], + res_path: str | Path, + *, + n_threads: int = 1, +) -> OptimizeResult: + """ + 兼容性回退入口:通过临时 JSON 文件调用旧版 ``run_optimize`` + + 仅供尚未切换到 ``run_optimize_batch`` 的存量代码使用。 + 新项目不要再依赖本函数;它会创建临时文件、留下额外 IO, + 且不参与未来的接口演进。 + """ + # delete=False 是为了把文件路径继续传给 Rust,由 Rust 读取后无需 Python 持有句柄 + # (Python 出 with 块时 NamedTemporaryFile 默认会删除文件,会破坏 Rust 端读取) + with NamedTemporaryFile("w", suffix=".json", encoding="utf-8", delete=False) as f: + json.dump(optimize_cfg, f, ensure_ascii=False) + config_path = f.name + msg = run_optimize(str(bars_dir), config_path, str(res_path), n_threads) + return OptimizeResult(message=msg) diff --git a/czsc/sensors/__init__.py b/czsc/sensors/__init__.py new file mode 100644 index 000000000..433fe67ba --- /dev/null +++ b/czsc/sensors/__init__.py @@ -0,0 +1,15 @@ +"""czsc.sensors —— 事件检测与特征分析命名空间(占位实现)。 + +本包对应 CZSC 项目中的“传感器(sensors)”层,目标是承载下列高层能力: + +- ``CTA`` 研究框架:策略回放、参数优化、并行回测; +- 特征选择器:基于滚动窗口的因子选择与分析; +- 事件匹配器:把信号组合成事件并在历史数据上扫描匹配。 + +当前阶段仅恢复命名空间的可导入性,使得 ``import czsc.sensors`` 不会 +报错;具体的传感器类(``CTAResearch``、``FeatureSelector``、``EventMatcher`` 等) +将在后续 Python 端清理工作中一并迁移过来。 + +在迁移完成之前,本模块保持为一个空的命名空间包,公共 API 冒烟测试只验证模块 +能干净地完成导入这一最小契约,不假设包内已有任何具体类型。 +""" diff --git a/czsc/signals/__init__.py b/czsc/signals/__init__.py new file mode 100644 index 000000000..a899e6df5 --- /dev/null +++ b/czsc/signals/__init__.py @@ -0,0 +1,33 @@ +"""czsc.signals —— 信号函数命名空间总入口。 + +本包是 CZSC 项目中所有“信号函数(signal functions)”的统一入口,按类别拆分为多个子模块: + +- :mod:`czsc.signals.bar` —— K 线级别信号(如成交量累计、振幅、跳空等) +- :mod:`czsc.signals.cxt` —— 上下文信号(缠论结构、分型、笔、线段等) +- :mod:`czsc.signals.tas` —— 技术指标信号(MACD、KDJ、BOLL 等基于 ``ta-lib`` 的指标族) +- :mod:`czsc.signals.vol` —— 成交量类信号 +- :mod:`czsc.signals.pressure` —— 价格压力/支撑相关信号 +- :mod:`czsc.signals.obv` —— OBV 能量潮相关信号 +- :mod:`czsc.signals.cvolp` —— 累积成交量分布(CVOLP)相关信号 + +底层实现说明 +============ +这些子模块仅提供薄薄一层 Python 转发壳。真正的信号计算逻辑由 Rust crate +``czsc-signals`` 实现,并通过 ``czsc._native.call_signal`` 暴露给 Python。Rust 端 +利用 ``inventory::collect!`` 在编译期收集所有标注了 ``#[signal(...)]`` 宏的函数, +形成一张全局信号清单(inventory),运行时由 Python 通过名字查找并派发。 + +调用约定 +======== +- 由 :class:`czsc.CzscSignals` / :func:`czsc.generate_czsc_signals` 等高层接口 + 统一驱动信号的批量生成; +- 单个 Rust 信号函数当前的入参为 ``(&CZSC, &HashMap, &mut TaCache)``,因此还不能 + 直接在 Python 端用普通函数调用。各子模块仅负责暴露“命名空间契约 + 共享辅助 + 工具”,便于上层做信号注册、参数模板查询和结构化解析。 +""" + +# 显式导入各信号子模块,确保 ``czsc.signals.xxx`` 始终可用,并触发其内部的延迟注册逻辑 +from czsc.signals import bar, cvolp, cxt, obv, pressure, tas, vol + +# 对外暴露的子模块名单(与 ``import czsc.signals`` 之后的 ``czsc.signals.xxx`` 一一对应) +__all__ = ["bar", "cxt", "tas", "vol", "pressure", "obv", "cvolp"] diff --git a/czsc/signals/_helpers.py b/czsc/signals/_helpers.py new file mode 100644 index 000000000..966d4921c --- /dev/null +++ b/czsc/signals/_helpers.py @@ -0,0 +1,184 @@ +"""czsc.signals.* 子包共享的内部辅助工具集。 + +本模块是 ``czsc.signals`` 各类别子模块(``bar`` / ``cxt`` / ``tas`` / ``vol`` / +``pressure`` / ``obv`` / ``cvolp``)共同依赖的基础设施层。所有子模块都通过 +``__getattr__`` 这种“按需查找 + 懒加载”的方式将真正的信号函数暴露出来, +而真正的派发工作则委托给本文件中的薄封装。 + +底层接口对接 +============ +本模块从 :mod:`czsc._native`(Rust 扩展模块)中导入了四个核心接口: + +* ``call_signal`` —— 真正的派发函数,接收 ``(name, czsc, params)``, + 调用对应的 Rust 实现并返回 ``list[Signal]``。 +* ``list_signal_names`` —— 返回 Rust ``inventory`` 中已注册的所有信号名称。 +* ``get_signal_template`` —— 返回某个信号在 ``#[signal(...)]`` 宏中登记的 + 参数模板字符串,用于上层做参数说明展示。 +* ``get_signal_category`` —— 返回某个信号所属的类别前缀(如 ``"bar"``)。 + +派发工作流概览 +============== +1. 子模块在 ``__getattr__`` 中检测访问的属性名是否符合 + ``"__VyymmDD"`` 这一信号命名约定; +2. 若符合,则调用 :func:`make_signal_callable` 动态合成一个 Python 闭包, + 闭包内部转发到 ``call_signal``; +3. 上层用户拿到这个闭包后,便可像普通 Python 函数那样调用 ``fn(czsc, params)``, + 屏蔽了 Rust 侧的复杂签名。 + +提供的工具 +========== +* :func:`list_signals` —— 列出所有(或某类别下的)信号名称。 +* :func:`get_signal_template` —— 查询单个信号的参数模板。 +* :func:`get_signal_category` —— 查询单个信号所属类别。 +* :func:`parse_signal_value` —— 把序列化字符串 ``"freq_name_value"`` 拆成结构化字典。 +* :func:`make_signal_callable` —— 为某个 Rust 信号生成 Python 调用包装器。 +""" + +from __future__ import annotations + +from typing import Any, Callable + +# 从 Rust 扩展模块按需导入底层接口;使用 ``as _xxx`` 显式标注为内部依赖,避免被外部直接引用 +from czsc._native import call_signal as _call_signal +from czsc._native import ( + get_signal_category as _get_signal_category, +) +from czsc._native import ( + get_signal_template as _get_signal_template, +) +from czsc._native import ( + list_signal_names as _list_signal_names, +) + + +def list_signals(category: str | None = None) -> list[str]: + """列出 ``czsc-signals`` Rust 清单中已注册的信号名称。 + + 匹配规则按信号名第一个下划线之前的部分进行(即“类别前缀”),例如 + ``category="bar"`` 会返回所有以 ``bar_`` 开头的信号名。 + + Parameters + ---------- + category : str | None, optional + 类别前缀。例如 ``"bar"``、``"cxt"``、``"tas"`` 等; + 传入 ``None``(默认)时返回全部已注册信号,并按字典序排序。 + + Returns + ------- + list[str] + 信号名称列表,已按字典序排序。 + """ + return _list_signal_names(category) + + +def get_signal_template(name: str) -> str | None: + """获取某个信号的参数模板字符串。 + + 参数模板由 Rust 端 ``#[signal(...)]`` 宏在编译期登记,描述了该信号支持的 + 参数及其默认值,常用于做配置说明或动态构建 UI。 + + Parameters + ---------- + name : str + 完整的信号名称,如 ``"bar_amount_acc_V230214"``。 + + Returns + ------- + str | None + 若该名称已注册,返回其参数模板字符串;否则返回 ``None``。 + """ + return _get_signal_template(name) + + +def get_signal_category(name: str) -> str | None: + """获取某个信号所属的类别前缀。 + + 返回的是信号名第一个下划线之前的部分,对应它隶属于 + ``czsc.signals.`` 中的哪个子模块。 + + Parameters + ---------- + name : str + 完整的信号名称。 + + Returns + ------- + str | None + 类别前缀字符串(如 ``"bar"``、``"cxt"`` 等);若该名称未注册,返回 ``None``。 + """ + return _get_signal_category(name) + + +def parse_signal_value(text: str) -> dict[str, Any]: + """把序列化的信号字符串拆解为结构化字典。 + + 上层在持久化 K 线快照或回放交易状态时,往往以 ``"freq_signal_name_value"`` + 这种 ``"_"`` 拼接形式存储信号;本函数负责把它还原成 ``{"freq", "name", "value"}`` + 三元结构,便于后续过滤、对比与可视化。 + + Parameters + ---------- + text : str + 待解析的字符串。预期格式为 ``"freq_name_value"``,使用 ``_`` 作为分隔符。 + 因为 ``value`` 自身可能含有 ``_``,所以最多只在前两个 ``_`` 处分割。 + + Returns + ------- + dict[str, Any] + 包含三个键的字典: + + - ``"freq"``: 周期标识(如 ``"30分钟"``、``"日线"`` 等); + - ``"name"``: 信号名称; + - ``"value"``: 信号值字符串。 + + 当输入不足三段时,``freq`` 与 ``value`` 均回退为空字符串, + ``name`` 则填回原始文本,便于下游做容错处理。 + """ + # 按 "_" 最多分成 3 段;避免 value 自身包含 "_" 时被错误切割 + parts = text.split("_", 2) + if len(parts) < 3: + # 字段数量不足时返回降级结果,保持调用方约定的字典结构不变 + return {"freq": "", "name": text, "value": ""} + return {"freq": parts[0], "name": parts[1], "value": parts[2]} + + +def make_signal_callable(name: str) -> Callable[..., list[Any]]: + """为指定的 Rust 信号合成一个 Python 调用包装器。 + + 返回的可调用对象签名为 ``fn(czsc, params=None)``,与 Rust 派发器保持一致: + + - ``czsc`` : 已构造完成的 :class:`czsc.CZSC` 分析对象; + - ``params``: 信号参数字典;为 ``None`` 时会自动转成空字典 ``{}``, + 避免 Rust 端在解参时出现空指针异常。 + + 包装器会把 ``__name__`` / ``__qualname__`` / ``__doc__`` 设置为有意义的 + 元信息,让 Python 端的 ``help()``、IDE 跳转、序列化等内省行为表现自然。 + + Parameters + ---------- + name : str + Rust 信号清单中的信号名称。 + + Returns + ------- + Callable[..., list[Any]] + 转发到 Rust 派发器的 Python 闭包,调用后返回 ``list[Signal]``。 + """ + + # 提前查询参数模板,便于在文档字符串中展示,便于使用者快速了解参数格式 + template = _get_signal_template(name) + + def _wrapped(czsc, params: dict[str, Any] | None = None): + # 真正的派发:把 None 兜底成空字典,避免 Rust 端做额外校验 + return _call_signal(name, czsc, params or {}) + + # 将包装器伪装成原信号函数,方便上层做内省与日志输出 + _wrapped.__name__ = name + _wrapped.__qualname__ = name + if template is not None: + _wrapped.__doc__ = ( + f"Rust 信号 {name!r} 的 Python 派发包装器。\n\n" + f"参数模板(parameter template): {template!r}\n\n" + "调用方式: ``fn(czsc, params_dict)``,返回 ``list[Signal]``。" + ) + return _wrapped diff --git a/czsc/signals/bar.py b/czsc/signals/bar.py new file mode 100644 index 000000000..6d0696d9c --- /dev/null +++ b/czsc/signals/bar.py @@ -0,0 +1,122 @@ +"""czsc.signals.bar —— K 线级别(``bar_*``)信号命名空间。 + +本模块对外暴露所有以 ``bar_`` 为前缀的信号函数,覆盖 K 线层面的常用统计量, +例如成交额累计、振幅、跳空、上下影线、连续阴/阳线等基础形态。这些信号是 +事件(Events)和持仓(Positions)等高层逻辑组合的基础积木。 + +实现位置 +======== +真正的计算逻辑由 Rust 实现,源文件位于 ``crates/czsc-signals/src/bar.rs``, +通过 ``#[signal(...)]`` 宏在编译期登记到 ``inventory`` 全局表里。Python 端 +仅做轻量级转发,避免 Python 解释器层面的开销。 + +调用约定 +======== +- 所有 ``bar_*_VyymmDD`` 函数都是“按需暴露”的:访问 ``czsc.signals.bar.bar_xxx_V230101`` + 时由 :func:`__getattr__` 临时合成 Python 闭包; +- 闭包签名为 ``fn(czsc, params)``,其中 ``params`` 是参数字典; +- 返回值为 ``list[Signal]``; +- 使用 :func:`list_signals` 可枚举所有可用名称,便于做配置校验或自动注册。 +""" + +from __future__ import annotations + +from typing import Any + +from czsc.signals._helpers import ( + get_signal_template as _get_signal_template, + list_signals as _list_signals, + make_signal_callable as _make_signal_callable, + parse_signal_value as _parse_signal_value, +) + +# 当前子模块对应的类别前缀;所有以 ``bar_`` 开头的信号会归到这里 +_CATEGORY = "bar" + + +def list_signals() -> list[str]: + """列出 Rust 清单中所有 ``bar_*`` 信号的名称。 + + Returns + ------- + list[str] + 以 ``bar_`` 为前缀的信号名称列表,按字典序排序。 + """ + return _list_signals(_CATEGORY) + + +def get_signal_template(name: str) -> str | None: + """获取某个 ``bar_*`` 信号的参数模板字符串。 + + Parameters + ---------- + name : str + 完整的信号名称,例如 ``"bar_amount_acc_V230214"``。 + + Returns + ------- + str | None + 参数模板字符串;若名称未注册,返回 ``None``。 + """ + return _get_signal_template(name) + + +def parse_signal_value(text: str) -> dict[str, Any]: + """解析序列化的信号字符串,等价于 :func:`czsc.signals._helpers.parse_signal_value`。 + + Parameters + ---------- + text : str + 形如 ``"freq_name_value"`` 的字符串。 + + Returns + ------- + dict[str, Any] + 含 ``freq``、``name``、``value`` 三个键的字典。 + """ + return _parse_signal_value(text) + + +def __getattr__(name: str): + """按需暴露每个已注册的 ``bar_*`` 信号为可调用对象。 + + 访问 ``czsc.signals.bar.bar_xxx_V230101`` 时本函数被自动触发; + 若名称在 Rust 清单中存在,则合成一个 Python 闭包返回;否则抛出 + :class:`AttributeError`,与普通模块属性查找语义保持一致,确保静态 + 检查工具(如 mypy、pyright)行为符合预期。 + + Parameters + ---------- + name : str + 访问的属性名。 + + Returns + ------- + Callable + 对应 Rust 信号的 Python 派发闭包,调用后返回 ``list[Signal]``。 + + Raises + ------ + AttributeError + 当 ``name`` 不属于 ``bar_*`` 信号或未在 Rust 清单中注册时抛出。 + """ + if name.startswith(_CATEGORY + "_") and name in _list_signals(_CATEGORY): + return _make_signal_callable(name) + raise AttributeError(f"module 'czsc.signals.bar' has no attribute {name!r}") + + +def __dir__() -> list[str]: + """自定义 ``dir()`` 输出,让 IDE / REPL 能看到所有动态暴露的信号。 + + 包含模块级公共函数加上当前 Rust 清单中可用的全部 ``bar_*`` 信号名称。 + """ + return [ + "list_signals", + "get_signal_template", + "parse_signal_value", + *_list_signals(_CATEGORY), + ] + + +# ``__all__`` 仅声明显式暴露的工具函数;具体信号函数按需通过 ``__getattr__`` 取得 +__all__ = ["list_signals", "get_signal_template", "parse_signal_value"] diff --git a/czsc/signals/cvolp.py b/czsc/signals/cvolp.py new file mode 100644 index 000000000..a9eee4e6c --- /dev/null +++ b/czsc/signals/cvolp.py @@ -0,0 +1,118 @@ +"""czsc.signals.cvolp —— 累积成交量分布(``cvolp_*``)信号命名空间。 + +本模块对外暴露所有以 ``cvolp_`` 为前缀的信号函数。CVOLP(Cumulative Volume Profile) +通过对一段时间内的成交量按价位分布进行累计统计,从而刻画筹码分布、价值区间 +(VAH/VAL)、控制点(POC)等市场结构特征,常用于辅助判断支撑/阻力区域。 + +实现位置 +======== +真正的计算逻辑由 Rust 实现,源文件位于 ``crates/czsc-signals/src/cvolp.rs``, +通过 ``#[signal(...)]`` 宏在编译期登记到 ``inventory`` 全局表里。Python 端 +仅做轻量级转发,避免 Python 解释器层面的开销。 + +调用约定 +======== +- 所有 ``cvolp_*_VyymmDD`` 函数都是“按需暴露”的:访问 ``czsc.signals.cvolp.cvolp_xxx_V230101`` + 时由 :func:`__getattr__` 临时合成 Python 闭包; +- 闭包签名为 ``fn(czsc, params)``,其中 ``params`` 是参数字典; +- 返回值为 ``list[Signal]``; +- 使用 :func:`list_signals` 可枚举所有可用名称。 +""" + +from __future__ import annotations + +from typing import Any + +from czsc.signals._helpers import ( + get_signal_template as _get_signal_template, + list_signals as _list_signals, + make_signal_callable as _make_signal_callable, + parse_signal_value as _parse_signal_value, +) + +# 当前子模块对应的类别前缀;所有以 ``cvolp_`` 开头的信号会归到这里 +_CATEGORY = "cvolp" + + +def list_signals() -> list[str]: + """列出 Rust 清单中所有 ``cvolp_*`` 信号的名称。 + + Returns + ------- + list[str] + 以 ``cvolp_`` 为前缀的信号名称列表,按字典序排序。 + """ + return _list_signals(_CATEGORY) + + +def get_signal_template(name: str) -> str | None: + """获取某个 ``cvolp_*`` 信号的参数模板字符串。 + + Parameters + ---------- + name : str + 完整的信号名称。 + + Returns + ------- + str | None + 参数模板字符串;若名称未注册,返回 ``None``。 + """ + return _get_signal_template(name) + + +def parse_signal_value(text: str) -> dict[str, Any]: + """解析序列化的信号字符串,等价于 :func:`czsc.signals._helpers.parse_signal_value`。 + + Parameters + ---------- + text : str + 形如 ``"freq_name_value"`` 的字符串。 + + Returns + ------- + dict[str, Any] + 含 ``freq``、``name``、``value`` 三个键的字典。 + """ + return _parse_signal_value(text) + + +def __getattr__(name: str): + """按需暴露每个已注册的 ``cvolp_*`` 信号为可调用对象。 + + 访问 ``czsc.signals.cvolp.cvolp_xxx_V230101`` 时本函数被自动触发; + 若名称在 Rust 清单中存在,则合成一个 Python 闭包返回;否则抛出 + :class:`AttributeError`,与普通模块属性查找语义保持一致。 + + Parameters + ---------- + name : str + 访问的属性名。 + + Returns + ------- + Callable + 对应 Rust 信号的 Python 派发闭包,调用后返回 ``list[Signal]``。 + + Raises + ------ + AttributeError + 当 ``name`` 不属于 ``cvolp_*`` 信号或未在 Rust 清单中注册时抛出。 + """ + if name.startswith(_CATEGORY + "_") and name in _list_signals(_CATEGORY): + return _make_signal_callable(name) + raise AttributeError(f"module 'czsc.signals.cvolp' has no attribute {name!r}") + + +def __dir__() -> list[str]: + """自定义 ``dir()`` 输出,让 IDE / REPL 能看到所有动态暴露的信号。""" + return [ + "list_signals", + "get_signal_template", + "parse_signal_value", + *_list_signals(_CATEGORY), + ] + + +# ``__all__`` 仅声明显式暴露的工具函数;具体信号函数按需通过 ``__getattr__`` 取得 +__all__ = ["list_signals", "get_signal_template", "parse_signal_value"] diff --git a/czsc/signals/cxt.py b/czsc/signals/cxt.py new file mode 100644 index 000000000..c9e8a327a --- /dev/null +++ b/czsc/signals/cxt.py @@ -0,0 +1,118 @@ +"""czsc.signals.cxt —— 缠论上下文(``cxt_*``)信号命名空间。 + +本模块对外暴露所有以 ``cxt_`` 为前缀的信号函数,覆盖缠中说禅理论中的核心结构 +特征,例如分型(顶/底分型)、笔的方向与强度、线段、中枢、走势类背驰、第三类 +买卖点等。这些信号是构建缠论交易策略的核心积木。 + +实现位置 +======== +真正的计算逻辑由 Rust 实现,源文件位于 ``crates/czsc-signals/src/cxt.rs``, +通过 ``#[signal(...)]`` 宏在编译期登记到 ``inventory`` 全局表里。Python 端 +仅做轻量级转发,避免 Python 解释器层面的开销。 + +调用约定 +======== +- 所有 ``cxt_*_VyymmDD`` 函数都是“按需暴露”的:访问 ``czsc.signals.cxt.cxt_xxx_V230101`` + 时由 :func:`__getattr__` 临时合成 Python 闭包; +- 闭包签名为 ``fn(czsc, params)``,其中 ``params`` 是参数字典; +- 返回值为 ``list[Signal]``; +- 使用 :func:`list_signals` 可枚举所有可用名称。 +""" + +from __future__ import annotations + +from typing import Any + +from czsc.signals._helpers import ( + get_signal_template as _get_signal_template, + list_signals as _list_signals, + make_signal_callable as _make_signal_callable, + parse_signal_value as _parse_signal_value, +) + +# 当前子模块对应的类别前缀;所有以 ``cxt_`` 开头的信号会归到这里 +_CATEGORY = "cxt" + + +def list_signals() -> list[str]: + """列出 Rust 清单中所有 ``cxt_*`` 信号的名称。 + + Returns + ------- + list[str] + 以 ``cxt_`` 为前缀的信号名称列表,按字典序排序。 + """ + return _list_signals(_CATEGORY) + + +def get_signal_template(name: str) -> str | None: + """获取某个 ``cxt_*`` 信号的参数模板字符串。 + + Parameters + ---------- + name : str + 完整的信号名称。 + + Returns + ------- + str | None + 参数模板字符串;若名称未注册,返回 ``None``。 + """ + return _get_signal_template(name) + + +def parse_signal_value(text: str) -> dict[str, Any]: + """解析序列化的信号字符串,等价于 :func:`czsc.signals._helpers.parse_signal_value`。 + + Parameters + ---------- + text : str + 形如 ``"freq_name_value"`` 的字符串。 + + Returns + ------- + dict[str, Any] + 含 ``freq``、``name``、``value`` 三个键的字典。 + """ + return _parse_signal_value(text) + + +def __getattr__(name: str): + """按需暴露每个已注册的 ``cxt_*`` 信号为可调用对象。 + + 访问 ``czsc.signals.cxt.cxt_xxx_V230101`` 时本函数被自动触发; + 若名称在 Rust 清单中存在,则合成一个 Python 闭包返回;否则抛出 + :class:`AttributeError`,与普通模块属性查找语义保持一致。 + + Parameters + ---------- + name : str + 访问的属性名。 + + Returns + ------- + Callable + 对应 Rust 信号的 Python 派发闭包,调用后返回 ``list[Signal]``。 + + Raises + ------ + AttributeError + 当 ``name`` 不属于 ``cxt_*`` 信号或未在 Rust 清单中注册时抛出。 + """ + if name.startswith(_CATEGORY + "_") and name in _list_signals(_CATEGORY): + return _make_signal_callable(name) + raise AttributeError(f"module 'czsc.signals.cxt' has no attribute {name!r}") + + +def __dir__() -> list[str]: + """自定义 ``dir()`` 输出,让 IDE / REPL 能看到所有动态暴露的信号。""" + return [ + "list_signals", + "get_signal_template", + "parse_signal_value", + *_list_signals(_CATEGORY), + ] + + +# ``__all__`` 仅声明显式暴露的工具函数;具体信号函数按需通过 ``__getattr__`` 取得 +__all__ = ["list_signals", "get_signal_template", "parse_signal_value"] diff --git a/czsc/signals/obv.py b/czsc/signals/obv.py new file mode 100644 index 000000000..c834bc8bb --- /dev/null +++ b/czsc/signals/obv.py @@ -0,0 +1,118 @@ +"""czsc.signals.obv —— 能量潮(``obv_*``)信号命名空间。 + +本模块对外暴露所有以 ``obv_`` 为前缀的信号函数。OBV(On-Balance Volume,能量潮) +是一类把成交量按收盘价方向逐笔累加的指标,用以衡量“量在价先”的资金推动强度, +常见衍生信号包括 OBV 与价格的背离、OBV 均线穿越、OBV 趋势斜率等。 + +实现位置 +======== +真正的计算逻辑由 Rust 实现,源文件位于 ``crates/czsc-signals/src/obv.rs``, +通过 ``#[signal(...)]`` 宏在编译期登记到 ``inventory`` 全局表里。Python 端 +仅做轻量级转发,避免 Python 解释器层面的开销。 + +调用约定 +======== +- 所有 ``obv_*_VyymmDD`` 函数都是“按需暴露”的:访问 ``czsc.signals.obv.obv_xxx_V230101`` + 时由 :func:`__getattr__` 临时合成 Python 闭包; +- 闭包签名为 ``fn(czsc, params)``,其中 ``params`` 是参数字典; +- 返回值为 ``list[Signal]``; +- 使用 :func:`list_signals` 可枚举所有可用名称。 +""" + +from __future__ import annotations + +from typing import Any + +from czsc.signals._helpers import ( + get_signal_template as _get_signal_template, + list_signals as _list_signals, + make_signal_callable as _make_signal_callable, + parse_signal_value as _parse_signal_value, +) + +# 当前子模块对应的类别前缀;所有以 ``obv_`` 开头的信号会归到这里 +_CATEGORY = "obv" + + +def list_signals() -> list[str]: + """列出 Rust 清单中所有 ``obv_*`` 信号的名称。 + + Returns + ------- + list[str] + 以 ``obv_`` 为前缀的信号名称列表,按字典序排序。 + """ + return _list_signals(_CATEGORY) + + +def get_signal_template(name: str) -> str | None: + """获取某个 ``obv_*`` 信号的参数模板字符串。 + + Parameters + ---------- + name : str + 完整的信号名称。 + + Returns + ------- + str | None + 参数模板字符串;若名称未注册,返回 ``None``。 + """ + return _get_signal_template(name) + + +def parse_signal_value(text: str) -> dict[str, Any]: + """解析序列化的信号字符串,等价于 :func:`czsc.signals._helpers.parse_signal_value`。 + + Parameters + ---------- + text : str + 形如 ``"freq_name_value"`` 的字符串。 + + Returns + ------- + dict[str, Any] + 含 ``freq``、``name``、``value`` 三个键的字典。 + """ + return _parse_signal_value(text) + + +def __getattr__(name: str): + """按需暴露每个已注册的 ``obv_*`` 信号为可调用对象。 + + 访问 ``czsc.signals.obv.obv_xxx_V230101`` 时本函数被自动触发; + 若名称在 Rust 清单中存在,则合成一个 Python 闭包返回;否则抛出 + :class:`AttributeError`,与普通模块属性查找语义保持一致。 + + Parameters + ---------- + name : str + 访问的属性名。 + + Returns + ------- + Callable + 对应 Rust 信号的 Python 派发闭包,调用后返回 ``list[Signal]``。 + + Raises + ------ + AttributeError + 当 ``name`` 不属于 ``obv_*`` 信号或未在 Rust 清单中注册时抛出。 + """ + if name.startswith(_CATEGORY + "_") and name in _list_signals(_CATEGORY): + return _make_signal_callable(name) + raise AttributeError(f"module 'czsc.signals.obv' has no attribute {name!r}") + + +def __dir__() -> list[str]: + """自定义 ``dir()`` 输出,让 IDE / REPL 能看到所有动态暴露的信号。""" + return [ + "list_signals", + "get_signal_template", + "parse_signal_value", + *_list_signals(_CATEGORY), + ] + + +# ``__all__`` 仅声明显式暴露的工具函数;具体信号函数按需通过 ``__getattr__`` 取得 +__all__ = ["list_signals", "get_signal_template", "parse_signal_value"] diff --git a/czsc/signals/pressure.py b/czsc/signals/pressure.py new file mode 100644 index 000000000..b4c4a81b4 --- /dev/null +++ b/czsc/signals/pressure.py @@ -0,0 +1,120 @@ +"""czsc.signals.pressure —— 价格压力(``pressure_*``)信号命名空间。 + +本模块对外暴露所有以 ``pressure_`` 为前缀的信号函数,主要用于刻画市场中的 +压力位/支撑位信息,例如近期高低点的密集成交区、价位的反复测试次数、突破 +后的回踩验证等。这些信号常配合 ``cxt_*`` 与 ``vol_*`` 一起用于做出场和加仓 +判断。 + +实现位置 +======== +真正的计算逻辑由 Rust 实现,源文件位于 ``crates/czsc-signals/src/pressure.rs``, +通过 ``#[signal(...)]`` 宏在编译期登记到 ``inventory`` 全局表里。Python 端 +仅做轻量级转发,避免 Python 解释器层面的开销。 + +调用约定 +======== +- 所有 ``pressure_*_VyymmDD`` 函数都是“按需暴露”的:访问 + ``czsc.signals.pressure.pressure_xxx_V230101`` 时由 :func:`__getattr__` + 临时合成 Python 闭包; +- 闭包签名为 ``fn(czsc, params)``,其中 ``params`` 是参数字典; +- 返回值为 ``list[Signal]``; +- 使用 :func:`list_signals` 可枚举所有可用名称。 +""" + +from __future__ import annotations + +from typing import Any + +from czsc.signals._helpers import ( + get_signal_template as _get_signal_template, + list_signals as _list_signals, + make_signal_callable as _make_signal_callable, + parse_signal_value as _parse_signal_value, +) + +# 当前子模块对应的类别前缀;所有以 ``pressure_`` 开头的信号会归到这里 +_CATEGORY = "pressure" + + +def list_signals() -> list[str]: + """列出 Rust 清单中所有 ``pressure_*`` 信号的名称。 + + Returns + ------- + list[str] + 以 ``pressure_`` 为前缀的信号名称列表,按字典序排序。 + """ + return _list_signals(_CATEGORY) + + +def get_signal_template(name: str) -> str | None: + """获取某个 ``pressure_*`` 信号的参数模板字符串。 + + Parameters + ---------- + name : str + 完整的信号名称。 + + Returns + ------- + str | None + 参数模板字符串;若名称未注册,返回 ``None``。 + """ + return _get_signal_template(name) + + +def parse_signal_value(text: str) -> dict[str, Any]: + """解析序列化的信号字符串,等价于 :func:`czsc.signals._helpers.parse_signal_value`。 + + Parameters + ---------- + text : str + 形如 ``"freq_name_value"`` 的字符串。 + + Returns + ------- + dict[str, Any] + 含 ``freq``、``name``、``value`` 三个键的字典。 + """ + return _parse_signal_value(text) + + +def __getattr__(name: str): + """按需暴露每个已注册的 ``pressure_*`` 信号为可调用对象。 + + 访问 ``czsc.signals.pressure.pressure_xxx_V230101`` 时本函数被自动触发; + 若名称在 Rust 清单中存在,则合成一个 Python 闭包返回;否则抛出 + :class:`AttributeError`,与普通模块属性查找语义保持一致。 + + Parameters + ---------- + name : str + 访问的属性名。 + + Returns + ------- + Callable + 对应 Rust 信号的 Python 派发闭包,调用后返回 ``list[Signal]``。 + + Raises + ------ + AttributeError + 当 ``name`` 不属于 ``pressure_*`` 信号或未在 Rust 清单中注册时抛出。 + """ + if name.startswith(_CATEGORY + "_") and name in _list_signals(_CATEGORY): + return _make_signal_callable(name) + raise AttributeError(f"module 'czsc.signals.pressure' has no attribute {name!r}") + + +def __dir__() -> list[str]: + """自定义 ``dir()`` 输出,让 IDE / REPL 能看到所有动态暴露的信号。""" + return [ + "list_signals", + "get_signal_template", + "parse_signal_value", + *_list_signals(_CATEGORY), + ] + + +# ``__all__`` 仅声明显式暴露的工具函数;具体信号函数按需通过 ``__getattr__`` 取得 +__all__ = ["list_signals", "get_signal_template", "parse_signal_value"] diff --git a/czsc/signals/tas.py b/czsc/signals/tas.py new file mode 100644 index 000000000..5be9efa7a --- /dev/null +++ b/czsc/signals/tas.py @@ -0,0 +1,120 @@ +"""czsc.signals.tas —— 技术指标(``tas_*``)信号命名空间。 + +本模块对外暴露所有以 ``tas_`` 为前缀的信号函数。``tas`` 是 “Technical Analysis +Signals” 的缩写,覆盖了基于 ``ta-lib`` 风格指标族的信号,例如 MACD(金叉、 +死叉、背驰)、KDJ、RSI、BOLL、MA 多空排列、ATR 波动率等。本子模块是构建 +中短期趋势/震荡策略最常用的信号集合之一。 + +实现位置 +======== +真正的计算逻辑由 Rust 实现,源文件位于 ``crates/czsc-signals/src/tas.rs``, +通过 ``#[signal(...)]`` 宏在编译期登记到 ``inventory`` 全局表里。Rust 端会 +利用 ``TaCache`` 对常见指标做缓存,避免在同一根 K 线上重复计算。Python 端 +仅做轻量级转发。 + +调用约定 +======== +- 所有 ``tas_*_VyymmDD`` 函数都是“按需暴露”的:访问 ``czsc.signals.tas.tas_xxx_V230101`` + 时由 :func:`__getattr__` 临时合成 Python 闭包; +- 闭包签名为 ``fn(czsc, params)``,其中 ``params`` 是参数字典; +- 返回值为 ``list[Signal]``; +- 使用 :func:`list_signals` 可枚举所有可用名称。 +""" + +from __future__ import annotations + +from typing import Any + +from czsc.signals._helpers import ( + get_signal_template as _get_signal_template, + list_signals as _list_signals, + make_signal_callable as _make_signal_callable, + parse_signal_value as _parse_signal_value, +) + +# 当前子模块对应的类别前缀;所有以 ``tas_`` 开头的信号会归到这里 +_CATEGORY = "tas" + + +def list_signals() -> list[str]: + """列出 Rust 清单中所有 ``tas_*`` 信号的名称。 + + Returns + ------- + list[str] + 以 ``tas_`` 为前缀的信号名称列表,按字典序排序。 + """ + return _list_signals(_CATEGORY) + + +def get_signal_template(name: str) -> str | None: + """获取某个 ``tas_*`` 信号的参数模板字符串。 + + Parameters + ---------- + name : str + 完整的信号名称。 + + Returns + ------- + str | None + 参数模板字符串;若名称未注册,返回 ``None``。 + """ + return _get_signal_template(name) + + +def parse_signal_value(text: str) -> dict[str, Any]: + """解析序列化的信号字符串,等价于 :func:`czsc.signals._helpers.parse_signal_value`。 + + Parameters + ---------- + text : str + 形如 ``"freq_name_value"`` 的字符串。 + + Returns + ------- + dict[str, Any] + 含 ``freq``、``name``、``value`` 三个键的字典。 + """ + return _parse_signal_value(text) + + +def __getattr__(name: str): + """按需暴露每个已注册的 ``tas_*`` 信号为可调用对象。 + + 访问 ``czsc.signals.tas.tas_xxx_V230101`` 时本函数被自动触发; + 若名称在 Rust 清单中存在,则合成一个 Python 闭包返回;否则抛出 + :class:`AttributeError`,与普通模块属性查找语义保持一致。 + + Parameters + ---------- + name : str + 访问的属性名。 + + Returns + ------- + Callable + 对应 Rust 信号的 Python 派发闭包,调用后返回 ``list[Signal]``。 + + Raises + ------ + AttributeError + 当 ``name`` 不属于 ``tas_*`` 信号或未在 Rust 清单中注册时抛出。 + """ + if name.startswith(_CATEGORY + "_") and name in _list_signals(_CATEGORY): + return _make_signal_callable(name) + raise AttributeError(f"module 'czsc.signals.tas' has no attribute {name!r}") + + +def __dir__() -> list[str]: + """自定义 ``dir()`` 输出,让 IDE / REPL 能看到所有动态暴露的信号。""" + return [ + "list_signals", + "get_signal_template", + "parse_signal_value", + *_list_signals(_CATEGORY), + ] + + +# ``__all__`` 仅声明显式暴露的工具函数;具体信号函数按需通过 ``__getattr__`` 取得 +__all__ = ["list_signals", "get_signal_template", "parse_signal_value"] diff --git a/czsc/signals/vol.py b/czsc/signals/vol.py new file mode 100644 index 000000000..ec3f5f656 --- /dev/null +++ b/czsc/signals/vol.py @@ -0,0 +1,119 @@ +"""czsc.signals.vol —— 成交量(``vol_*``)信号命名空间。 + +本模块对外暴露所有以 ``vol_`` 为前缀的信号函数,覆盖与“量”相关的常见判断, +例如成交量放大/缩小、量价配合、量能突破、相对历史均量的分位数等。这些信号 +常与 ``cxt_*`` 结构信号、``tas_*`` 技术指标信号配合使用,用于辅助判断买卖 +力量是否真实有效。 + +实现位置 +======== +真正的计算逻辑由 Rust 实现,源文件位于 ``crates/czsc-signals/src/vol.rs``, +通过 ``#[signal(...)]`` 宏在编译期登记到 ``inventory`` 全局表里。Python 端 +仅做轻量级转发,避免 Python 解释器层面的开销。 + +调用约定 +======== +- 所有 ``vol_*_VyymmDD`` 函数都是“按需暴露”的:访问 ``czsc.signals.vol.vol_xxx_V230101`` + 时由 :func:`__getattr__` 临时合成 Python 闭包; +- 闭包签名为 ``fn(czsc, params)``,其中 ``params`` 是参数字典; +- 返回值为 ``list[Signal]``; +- 使用 :func:`list_signals` 可枚举所有可用名称。 +""" + +from __future__ import annotations + +from typing import Any + +from czsc.signals._helpers import ( + get_signal_template as _get_signal_template, + list_signals as _list_signals, + make_signal_callable as _make_signal_callable, + parse_signal_value as _parse_signal_value, +) + +# 当前子模块对应的类别前缀;所有以 ``vol_`` 开头的信号会归到这里 +_CATEGORY = "vol" + + +def list_signals() -> list[str]: + """列出 Rust 清单中所有 ``vol_*`` 信号的名称。 + + Returns + ------- + list[str] + 以 ``vol_`` 为前缀的信号名称列表,按字典序排序。 + """ + return _list_signals(_CATEGORY) + + +def get_signal_template(name: str) -> str | None: + """获取某个 ``vol_*`` 信号的参数模板字符串。 + + Parameters + ---------- + name : str + 完整的信号名称。 + + Returns + ------- + str | None + 参数模板字符串;若名称未注册,返回 ``None``。 + """ + return _get_signal_template(name) + + +def parse_signal_value(text: str) -> dict[str, Any]: + """解析序列化的信号字符串,等价于 :func:`czsc.signals._helpers.parse_signal_value`。 + + Parameters + ---------- + text : str + 形如 ``"freq_name_value"`` 的字符串。 + + Returns + ------- + dict[str, Any] + 含 ``freq``、``name``、``value`` 三个键的字典。 + """ + return _parse_signal_value(text) + + +def __getattr__(name: str): + """按需暴露每个已注册的 ``vol_*`` 信号为可调用对象。 + + 访问 ``czsc.signals.vol.vol_xxx_V230101`` 时本函数被自动触发; + 若名称在 Rust 清单中存在,则合成一个 Python 闭包返回;否则抛出 + :class:`AttributeError`,与普通模块属性查找语义保持一致。 + + Parameters + ---------- + name : str + 访问的属性名。 + + Returns + ------- + Callable + 对应 Rust 信号的 Python 派发闭包,调用后返回 ``list[Signal]``。 + + Raises + ------ + AttributeError + 当 ``name`` 不属于 ``vol_*`` 信号或未在 Rust 清单中注册时抛出。 + """ + if name.startswith(_CATEGORY + "_") and name in _list_signals(_CATEGORY): + return _make_signal_callable(name) + raise AttributeError(f"module 'czsc.signals.vol' has no attribute {name!r}") + + +def __dir__() -> list[str]: + """自定义 ``dir()`` 输出,让 IDE / REPL 能看到所有动态暴露的信号。""" + return [ + "list_signals", + "get_signal_template", + "parse_signal_value", + *_list_signals(_CATEGORY), + ] + + +# ``__all__`` 仅声明显式暴露的工具函数;具体信号函数按需通过 ``__getattr__`` 取得 +__all__ = ["list_signals", "get_signal_template", "parse_signal_value"] diff --git a/czsc/strategies.py b/czsc/strategies.py new file mode 100644 index 000000000..ece0353d7 --- /dev/null +++ b/czsc/strategies.py @@ -0,0 +1,331 @@ +""" +策略门面(Facade)模块 + +定位: + 在 Rust 后端的 Trader / 信号 / 仓位之上,提供一层 Python 友好的策略 + 抽象(CzscStrategyBase),屏蔽底层 PyO3 类型与运行时格式细节。 + 用户只需要继承基类、实现 ``positions`` 即可获得完整的回测、回放、 + 序列化与反序列化能力。 + +关键设计: + 1. 策略元数据(unique_signals / signals_config / freqs / base_freq) + 全部由 ``positions`` 自动派生,避免子类手工填写引起不一致 + 2. 用户层与运行时配置之间的格式互转集中在 czsc._compat,本模块只负责 + 调度,不直接关心字段映射 + 3. backtest / replay 委托给 czsc.research 中的 run_research / run_replay, + 本模块只组合参数与处理 IO(路径、刷新、是否落盘等) +""" + +from __future__ import annotations + +import hashlib +import json +import shutil +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + +import pandas as pd + +from czsc._compat import ( + bars_to_dataframe, + position_dump_to_runtime, + signal_config_to_public, + signal_config_to_runtime, + sort_freqs, +) +from czsc._native import Position +# 直接调用 Rust 端的派生器(用下划线后缀别名,避免与同名公开 API 混淆) +from czsc._native import ( + derive_signals_config as _derive_signals_config_impl, + derive_signals_freqs as _derive_signals_freqs_impl, +) +from czsc.research import run_replay, run_research + + +class CzscStrategyBase(ABC): + """ + czsc 风格策略定义的 Python 抽象基类 + + 使用方式: + 子类只需实现 :attr:`positions` 属性,返回 ``Position`` 列表; + 其它元数据(unique_signals / signals_config / freqs / base_freq) + 会由本类基于 ``positions`` 自动派生,子类无需关心。 + + 必填初始化参数(通过 kwargs 传入): + - symbol: 策略对应的标的代码 + + 常用可选参数: + - signals_module_name: 用户自定义信号模块路径,默认 "czsc.signals" + - name: 策略名(默认取类名),影响产物目录命名 + - market: 市场标识,默认 "默认" + - bg_max_count: BarGenerator 缓冲根数上限,默认 5000 + - sdt / include_sdt_bar: 起始时间相关 + """ + + def __init__(self, **kwargs): + """ + 保存兼容层所需的策略级参数 + + 所有参数统一存入 ``self.kwargs``,子类可按需通过 ``self.kwargs.get(key)`` + 读取或扩展。signals_module_name 会被单独提取,因为它在多个属性派生 + 路径上都会用到,避免每次访问时重复 dict.get。 + """ + self.kwargs = kwargs + # 自定义信号模块路径:用户在外部包中扩展信号实现时,可在此声明完整模块路径 + self.signals_module_name = kwargs.get("signals_module_name", "czsc.signals") + + @property + def symbol(self): + """策略绑定的标的代码(必传,缺失会触发 KeyError 显式报错)""" + return self.kwargs["symbol"] + + @property + def unique_signals(self): + """ + 汇总所有 Position 中出现过的信号 key,去重并保持首次出现顺序 + + 为什么不用 ``set``? + - set 不保证顺序,会让最终生成的 signals_config 顺序在不同 + Python 版本/运行环境中飘移 + - 显式维护一个 ``ordered`` 列表 + ``seen`` 集合既能去重又能 + 保留稳定顺序,便于 diff 与产物比对 + """ + seen = set() + ordered = [] + for pos in self.positions: + for signal in pos.unique_signals: + if signal not in seen: + seen.add(signal) + ordered.append(signal) + return ordered + + @property + def signals_config(self): + """ + 基于 ``unique_signals`` 派生的"用户层"信号配置列表 + + 实现路径: + unique_signals(list[str]) -> + Rust 派生器返回运行时配置(list[dict]) -> + ``signal_config_to_public`` 转为用户层格式(含模块前缀的 name 等) + """ + runtime = _derive_signals_config_impl(self.unique_signals) + return [signal_config_to_public(cfg, self.signals_module_name) for cfg in runtime] + + @property + def freqs(self): + """ + 策略涉及的所有周期集合(去重并按缠论惯用顺序排序) + + 实现细节: + 先把 signals_config 转回运行时格式,再交给 Rust 派生器返回所有 + 涉及的周期;最后用 :func:`sort_freqs` 做去重与排序。 + """ + runtime = [signal_config_to_runtime(cfg) for cfg in self.signals_config] + return sort_freqs(_derive_signals_freqs_impl(runtime)) + + @property + def sorted_freqs(self): + """与 :attr:`freqs` 等价的别名(保留接口兼容)""" + return sort_freqs(self.freqs) + + @property + def base_freq(self): + """策略基础周期:取 freqs 中最小(最高频)的那一档""" + return self.sorted_freqs[0] + + @property + @abstractmethod + def positions(self): + """ + 策略持仓列表(必须由子类实现) + + 返回值要求: + list[czsc._native.Position],每个元素描述一组开/平仓事件与参数。 + """ + raise NotImplementedError + + def backtest(self, bars, **kwargs): + """ + 执行策略回测,返回内存中的 :class:`ResearchResult` + + 参数: + bars: K 线数据;可以是 DataFrame、RawBar 列表或 Arrow 字节 + kwargs: 可选覆盖项: + - sdt: 起始时间覆盖 + - emit_signals: 是否产出信号产物 + - include_sdt_bar: 是否包含 sdt 当根 K 线 + """ + return run_research( + self._normalize_bars_input(bars), + self._build_runtime_strategy(kwargs), + sdt=kwargs.get("sdt"), + opts=self._build_run_opts(kwargs), + ) + + def replay(self, bars, res_path, **kwargs): + """ + 执行策略回放,结果写入指定目录 + + 参数: + bars: K 线数据 + res_path: 落盘根目录 + kwargs: + refresh: True 表示先清空 res_path 再写入;默认 False + exist_ok: 目录已存在但 refresh=False 时是否仍然执行; + 默认 False,此时会跳过执行并返回 None + 其余可覆盖项同 :meth:`backtest` + + 返回: + :class:`ReplayResult`;当目录已存在且未要求覆盖时返回 ``None`` + """ + path = Path(res_path) + # 显式要求刷新:先清空目录,避免新旧产物混合 + if kwargs.get("refresh", False): + shutil.rmtree(path, ignore_errors=True) + + # 既不允许覆盖也未要求刷新 -> 跳过执行(避免重复回放浪费算力) + exist_ok = kwargs.get("exist_ok", False) + if path.exists() and not exist_ok and not kwargs.get("refresh", False): + return None + + return run_replay( + self._normalize_bars_input(bars), + self._build_runtime_strategy(kwargs), + res_path=path, + sdt=kwargs.get("sdt"), + opts=self._build_run_opts(kwargs), + ) + + def save_positions(self, path): + """ + 将策略持仓序列化为 JSON 文件落盘(兼容 ``Position.dump`` 格式) + + 每个 Position 写为一个独立 JSON 文件(``.json``), + 文件中包含一个 ``md5`` 字段,用于后续加载时校验文件未被篡改。 + + 参数: + path: 输出目录;不存在会自动创建 + """ + out_dir = Path(path) + out_dir.mkdir(parents=True, exist_ok=True) + for pos in self.positions: + payload = position_dump_to_runtime(pos.dump(with_data=False)) + # symbol 与策略实例耦合,落盘时移除以便 Position 可被复用到不同标的 + payload.pop("symbol", None) + # md5 校验码:基于序列化字符串生成,加载时可校验配置完整性 + payload["md5"] = hashlib.md5(str(payload).encode("utf-8")).hexdigest() + (out_dir / f"{payload['name']}.json").write_text( + json.dumps(payload, ensure_ascii=False), encoding="utf-8" + ) + + def load_positions(self, files, check=True): + """ + 从多个 JSON 文件加载 Position 列表,并自动绑定当前策略的 symbol + + 参数: + files: JSON 文件路径列表 + check: 是否校验文件中的 md5 字段;默认开启。出现不一致即抛 AssertionError, + 防止持仓配置被外部静默修改。 + + 返回: + list[Position] + """ + positions = [] + for file in files: + payload = json.loads(Path(file).read_text(encoding="utf-8")) + md5 = payload.pop("md5", None) + # md5 校验:确保 dump/load 之间文件内容一致;md5 不存在时跳过校验(兼容旧文件) + if check and md5 is not None: + assert md5 == hashlib.md5(str(payload).encode("utf-8")).hexdigest() + # 把当前策略的 symbol 注入到 Position 配置中(save_positions 时被剥离) + payload["symbol"] = self.symbol + positions.append(Position.load(payload)) + return positions + + def _build_runtime_strategy(self, overrides: dict[str, Any]) -> dict[str, Any]: + """ + 把 self + overrides 拼装为 Rust 端可以直接消费的运行时 strategy dict + + 参数: + overrides: 调用 backtest/replay 时传入的临时覆盖参数 + + 返回: + 一份完整的策略字典,positions 与 signals_config 已转为运行时格式 + """ + # sdt 解析顺序:调用时显式传 > 实例 kwargs > None(不带 sdt 字段) + sdt = overrides.get("sdt", self.kwargs.get("sdt")) + strategy = { + "name": self.kwargs.get("name", self.__class__.__name__), + "symbol": self.symbol, + "base_freq": self.base_freq, + "signals_module": self.signals_module_name, + "signals_config": [signal_config_to_runtime(cfg) for cfg in self.signals_config], + "positions": [ + position_dump_to_runtime(pos.dump(with_data=False)) for pos in self.positions + ], + "market": self.kwargs.get("market", "默认"), + "bg_max_count": int(self.kwargs.get("bg_max_count", 5000)), + # 仅当 sdt 存在时才注入字段,避免显式写 None 触发 Rust 端 schema 错误 + **({"sdt": sdt} if sdt else {}), + } + # include_sdt_bar 行为说明: + # - 默认走 CzscStrategyBase 语义:bars_right 从 ``dt > sdt`` 开始(不包含起始那根) + # - 调用方显式传 True 时,切到 generate_czsc_signals 风格 ``dt >= sdt``(包含起始那根) + # - 显式传 False 同样会写入字段,让 Rust 端按指定语义运行 + include_sdt_bar = overrides.get( + "include_sdt_bar", + self.kwargs.get("include_sdt_bar"), + ) + if include_sdt_bar is not None: + strategy["include_sdt_bar"] = bool(include_sdt_bar) + return strategy + + @staticmethod + def _build_run_opts(kwargs: dict[str, Any]) -> dict[str, Any] | None: + """ + 从用户 kwargs 中提取 Rust 端 opts 字段 + + 当前仅暴露 ``emit_signals`` 一个开关,控制是否把信号产物写入结果。 + 未传入时返回 None,让 Rust 端使用默认行为,避免发送空 dict 引起歧义。 + """ + if "emit_signals" not in kwargs: + return None + return {"emit_signals": bool(kwargs["emit_signals"])} + + def _normalize_bars_input(self, bars): + """ + 把多种 K 线输入统一为 Rust 可接受的形式 + + - bytes / bytearray: 视为已就绪的 Arrow 字节,直接透传 + - 其他(DataFrame / list[RawBar]): 走 ``bars_to_dataframe`` 强制规范, + 关键是把所有数值列转为 Float64(Rust IPC 读取器对类型严格匹配) + """ + if isinstance(bars, (bytes, bytearray)): + return bytes(bars) + # 即便已经是 DataFrame,也要走一次 bars_to_dataframe, + # 以确保数值列被强制转为 Float64(Rust IPC 读取器对此严格要求)。 + return bars_to_dataframe(bars, symbol=self.symbol) + + +class CzscJsonStrategy(CzscStrategyBase): + """ + 直接从 JSON 文件加载持仓定义的策略包装器 + + 使用场景: + 策略配置由外部工具(GUI / 管理后台 / 优化结果)落盘为 JSON 后, + Python 侧只需指定文件路径即可装载并执行回测/回放。 + + 初始化关键参数(通过 kwargs 传入): + - files_position: JSON 文件路径列表 + - check_position: 是否做 md5 校验;默认 True + - 其余同 :class:`CzscStrategyBase` + """ + + @property + def positions(self): + """从 ``files_position`` 加载并返回 Position 列表,受 ``check_position`` 控制是否校验""" + return self.load_positions( + self.kwargs["files_position"], self.kwargs.get("check_position", True) + ) diff --git a/czsc/svc/backtest.py b/czsc/svc/backtest.py index 23b4f3599..9737b8e1a 100644 --- a/czsc/svc/backtest.py +++ b/czsc/svc/backtest.py @@ -1,7 +1,22 @@ """ -回测相关的可视化组件 - -包含权重分布、权重回测、持仓回测、止损分析等回测功能 +回测分析与可视化组件模块 + +本模块封装了一组基于 Streamlit 的回测可视化组件,主要用于在交互式仪表盘中展示 +权重型策略的回测结果,包括以下核心功能: + +1. 权重分布展示:观察各品种权重的分布特征与分位数; +2. 权重回测结果展示:基于 ``WeightBacktest`` 的核心绩效指标、日收益、回撤、 + 月度收益、年度统计、分段绩效等多维度展示; +3. 持仓组合回测:基于持仓权重数据计算每日收益、交易成本与净收益; +4. 按方向止损分析:在回测前先做止损改写,再进行权重回测; +5. 阈值过滤回测:按权重的样本内分位数阈值过滤后再回测,对比不同阈值的效果; +6. 按年度/品种切片回测、多空分别回测、综合回测面板; + +模块依赖: +- ``wbt.WeightBacktest``:底层 Rust 加速的权重回测器; +- ``streamlit``:负责所有前端展示; +- ``czsc.eda.cal_yearly_days``:根据交易日序列推断年化天数; +- 同包内 ``returns``、``statistics``、``strategy`` 子模块:共享的可视化组件。 """ import numpy as np @@ -9,27 +24,33 @@ import streamlit as st from loguru import logger -from rs_czsc import WeightBacktest +from wbt import WeightBacktest def show_weight_distribution(dfw, abs_weight=True, **kwargs): """展示权重分布 - :param dfw: pd.DataFrame, 包含 symbol, dt, price, weight 列 - :param abs_weight: bool, 是否取权重的绝对值 - :param kwargs: - - percentiles: list, 分位数 + 按品种分组,对每个品种的权重序列调用 ``describe`` 计算分位数并展示,常用于 + 观察策略在不同品种上的仓位规模与极端值情况。 + + :param dfw: pd.DataFrame,必须包含 ``symbol``、``dt``、``price``、``weight`` 列 + :param abs_weight: bool,是否对权重取绝对值后再统计;多空策略一般置 True + :param kwargs: 其他关键字参数 + - percentiles: list,``describe`` 使用的分位数序列,默认包含从 5% 到 95% 的常用分位 + :return: None;结果通过 :func:`statistics.show_df_describe` 写入 Streamlit 页面 """ dfw = dfw.copy() if abs_weight: + # 多空策略下绝对值更能反映仓位规模 dfw["weight"] = dfw["weight"].abs() default_percentiles = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95] percentiles = kwargs.get("percentiles", default_percentiles) + # 按 symbol 分组,得到每个品种的描述性统计并展开成宽表 dfs = dfw.groupby("symbol")["weight"].apply(lambda x: x.describe(percentiles=percentiles)).unstack().reset_index() - # 使用 show_df_describe 来显示结果 + # 复用 statistics 模块中的 describe 渲染函数,确保样式与其他位置一致 from .statistics import show_df_describe show_df_describe(dfs) @@ -38,7 +59,12 @@ def show_weight_distribution(dfw, abs_weight=True, **kwargs): def show_weight_backtest(dfw, **kwargs): """展示权重回测结果 - :param dfw: 回测数据,任何字段都不允许有空值;数据样例: + 将权重数据传入 ``WeightBacktest``,得到完整的回测对象后,把核心绩效指标、 + 交易方向统计、日收益曲线、回撤、分段日收益、年度绩效、月度收益等内容 + 渲染到 Streamlit 页面。 + + :param dfw: 回测数据;任何字段都不允许有空值;数据样例如下: + =================== ======== ======== ======= dt symbol weight price =================== ======== ======== ======= @@ -49,16 +75,17 @@ def show_weight_backtest(dfw, **kwargs): 2019-01-02 09:05:00 DLi9001 0.25 961.695 =================== ======== ======== ======= - :param kwargs: - - fee: 单边手续费,单位为BP,默认为2BP - - digits: 权重小数位数,默认为2 - - show_drawdowns: bool,是否展示最大回撤,默认为 False - - show_daily_detail: bool,是否展示每日收益详情,默认为 False - - show_backtest_detail: bool,是否展示回测详情,默认为 False - - show_splited_daily: bool,是否展示分段日收益表现,默认为 False - - show_yearly_stats: bool,是否展示年度绩效指标,默认为 False - - show_monthly_return: bool,是否展示月度累计收益,默认为 False - - n_jobs: int, 并行计算的进程数,默认为 1 + :param kwargs: 其他参数 + - fee: 单边手续费,单位为 BP,默认为 2BP + - digits: 权重小数位数,默认为 2 + - show_drawdowns: bool,是否展示最大回撤,默认 False + - show_daily_detail: bool,是否展示每日收益详情,默认 False + - show_backtest_detail: bool,是否展示回测详情,默认 False + - show_splited_daily: bool,是否展示分段日收益表现,默认 False + - show_yearly_stats: bool,是否展示年度绩效指标,默认 False + - show_monthly_return: bool,是否展示月度累计收益,默认 False + - n_jobs: int,并行计算的进程数,默认为 1 + :return: WeightBacktest,构造好的回测对象,便于后续进一步分析 """ from czsc.eda import cal_yearly_days @@ -69,13 +96,16 @@ def show_weight_backtest(dfw, **kwargs): weight_type = kwargs.pop("weight_type", "ts") if not yearly_days: + # 未显式指定时,根据 dt 序列推断每年实际交易天数 yearly_days = cal_yearly_days(dts=dfw["dt"].unique()) + # 严格校验缺失值;存在缺失时直接终止,避免回测结果失真 if (dfw.isnull().sum().sum() > 0) or (dfw.isna().sum().sum() > 0): st.warning("权重数据中存在空值,请检查数据后再试;空值数据如下:") st.dataframe(dfw[dfw.isnull().sum(axis=1) > 0], width="stretch") st.stop() + # 构造回测对象;fee 在 BP 与小数之间转换 wb = WeightBacktest( dfw=dfw, fee_rate=fee / 10000, digits=digits, n_jobs=n_jobs, yearly_days=yearly_days, weight_type=weight_type ) @@ -87,7 +117,7 @@ def show_weight_backtest(dfw, **kwargs): f"年交易天数 {yearly_days},品种数量:{dfw['symbol'].nunique()}" ) - # 显示核心指标 + # 顶部展示 11 个核心绩效指标 c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11 = st.columns([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) c1.metric("盈亏平衡点", f"{stat['盈亏平衡点']:.2%}") c2.metric("单笔收益(BP)", f"{stat['单笔收益']}") @@ -101,14 +131,14 @@ def show_weight_backtest(dfw, **kwargs): c10.metric("多头占比", f"{stat['多头占比']:.2%}") c11.metric("空头占比", f"{stat['空头占比']:.2%}") - # 显示交易方向统计 + # 多空分别统计:盈亏比、胜率等关键指标 with st.popover(label="交易方向统计", help="统计多头、空头交易次数、胜率、盈亏比等信息"): dfx = pd.DataFrame([wb.long_stats, wb.short_stats]) dfx.index = ["多头", "空头"] dfx.index.name = "交易方向" st.dataframe(dfx.T.astype(str), width="stretch") - # 显示日收益 + # 取出每日收益序列,转成 datetime 索引,便于下游绘图 dret = wb.daily_return.copy() dret["dt"] = pd.to_datetime(dret["date"]) dret = dret.set_index("dt").drop(columns=["date"]) @@ -143,7 +173,13 @@ def show_weight_backtest(dfw, **kwargs): def show_holds_backtest(df, **kwargs): """分析持仓组合的回测结果 - :param df: 回测数据,任何字段都不允许有空值;建议 weight 列在截面的和为 1;数据样例: + 本函数面向截面持仓权重的回测场景,输入数据每行代表某个时间点对某个标的的目标持仓 + 权重,并配合下一期收益(``n1b``)。函数会通过 :func:`holds_performance` 计算每日的 + 交易成本、净收益与换手率,并展示日收益、回撤、分段表现、年度绩效与月度累计收益等。 + + :param df: 回测数据;任何字段都不允许有空值;建议 ``weight`` 列在截面的和为 1, + 数据样例: + =================== ======== ======== ======= dt symbol weight n1b =================== ======== ======== ======= @@ -154,13 +190,14 @@ def show_holds_backtest(df, **kwargs): 2019-01-02 09:05:00 DLi9001 0.25 961.695 =================== ======== ======== ======= - :param kwargs: - - fee: 单边手续费,单位为BP,默认为2BP - - digits: 权重小数位数,默认为2 - - show_drawdowns: 是否展示最大回撤分析,默认为True - - show_splited_daily: 是否展示分段收益表现,默认为False - - show_yearly_stats: 是否展示年度绩效指标,默认为True - - show_monthly_return: 是否展示月度累计收益,默认为True + :param kwargs: 其他参数 + - fee: 单边手续费,单位为 BP,默认为 2BP + - digits: 权重小数位数,默认为 2 + - show_drawdowns: bool,是否展示最大回撤分析,默认 True + - show_splited_daily: bool,是否展示分段收益表现,默认 False + - show_yearly_stats: bool,是否展示年度绩效指标,默认 True + - show_monthly_return: bool,是否展示月度累计收益,默认 True + :return: None """ from czsc.utils.analysis.stats import holds_performance @@ -172,12 +209,13 @@ def show_holds_backtest(df, **kwargs): st.dataframe(df[df.isnull().sum(axis=1) > 0], width="stretch") st.stop() - # 计算每日收益、交易成本、净收益 + # 计算每日收益、交易成本与净收益 sdt = df["dt"].min().strftime("%Y-%m-%d") edt = df["dt"].max().strftime("%Y-%m-%d") dfr = holds_performance(df, fee=fee, digits=digits) st.write(f"回测时间:{sdt} ~ {edt}; 单边年换手率:{dfr['change'].mean() * 252:.2f} 倍; 单边费率:{fee}BP") + # 把"扣费后净收益"列提取为标准的日收益序列 daily = dfr[["date", "edge_post_fee"]].copy() daily.columns = ["dt", "return"] daily["dt"] = pd.to_datetime(daily["dt"]) @@ -208,20 +246,26 @@ def show_holds_backtest(df, **kwargs): def show_stoploss_by_direction(dfw, **kwargs): """按方向止损分析的展示 - :param dfw: pd.DataFrame, 包含权重数据 - :param kwargs: dict, 其他参数 - - stoploss: float, 止损比例 - - show_detail: bool, 是否展示详细信息 - - digits: int, 价格小数位数, 默认2 - - fee_rate: float, 手续费率, 默认0.0002 + 在执行权重回测之前,先调用 ``rs_czsc.stoploss_by_direction`` 对权重数据按交易方向 + 进行止损改写:当一笔交易(同方向连续持仓)的浮亏达到 ``stoploss`` 时,将后续权重 + 强制平仓。改写后再调用 :func:`show_weight_backtest` 进行回测和展示。 + + :param dfw: pd.DataFrame,包含 ``symbol``、``dt``、``weight``、``price`` 等权重数据 + :param kwargs: 其他参数 + - stoploss: float,止损比例,例如 0.08 代表 8% 浮亏触发止损,默认 0.08 + - show_detail: bool,是否展示止损点的详细数据,默认 False + - digits: int,价格小数位数,默认 2 + - fee_rate: float,手续费率,默认 0.0002 + :return: None """ from rs_czsc import stoploss_by_direction dfw = dfw.copy() stoploss = kwargs.pop("stoploss", 0.08) + # 按方向进行止损改写,返回带 ``hold_returns``、``is_stop`` 等附加列的数据 dfw1 = stoploss_by_direction(dfw, stoploss=stoploss) - # 找出逐笔止损点 + # 找出每一笔交易的止损点:按 symbol/order_id 聚合,取首次 is_stop=True 的位置 rows = [] for symbol, dfg in dfw1.groupby("symbol"): for order_id, dfg1 in dfg.groupby("order_id"): @@ -242,6 +286,7 @@ def show_stoploss_by_direction(dfw, **kwargs): with st.expander("逐笔止损点", expanded=False): st.dataframe(dfr, width="stretch") + # 可选:展示所有触发止损点对应的明细行 if kwargs.pop("show_detail", False): cols = [ "dt", @@ -265,19 +310,24 @@ def show_stoploss_by_direction(dfw, **kwargs): def show_backtest_by_thresholds(df: pd.DataFrame, out_sample_sdt, **kwargs): """根据权重阈值进行回测对比的 Streamlit 组件 - :param df: pd.DataFrame, columns = ['dt', 'symbol', 'weight', 'price'], 含权重的K线数据 - :param kwargs: 其他参数 + 按样本内权重绝对值的若干分位数生成阈值,分别对原始策略与"权重过阈值则取 + sign,否则置 0"的方案进行回测,并对比累计收益与核心指标。 - - out_sample_sdt: str, 样本外开始时间,格式如 '2020-01-01' - - percentiles: list, 样本内分位数比例序列,默认 [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] - - fee_rate: float, 交易成本,默认 0.0002 - - digits: int, 权重保留小数位数,默认 2 - - weight_type: str, 权重类型,默认 'ts' + :param df: pd.DataFrame,columns = ['dt', 'symbol', 'weight', 'price'],含权重的 K 线数据 + :param out_sample_sdt: 样本外开始时间;样本内用于生成阈值,样本外可单独评估效果 + :param kwargs: 其他参数 + - percentiles: list,样本内分位数比例序列,默认 [0.0, 0.1, ..., 0.9] + - fee_rate: float,交易成本,默认 0.0002 + - digits: int,权重保留小数位数,默认 2 + - weight_type: str,权重类型,默认 'ts' + - only_out_sample: bool,是否仅在样本外评估过滤效果,默认 False + - sub_title: str,标题文案 + :return: dict 或 None;返回 ``{阈值名: WeightBacktest}`` 映射,构造失败时返回 None """ from czsc.eda import cal_yearly_days from czsc.svc.strategy import show_multi_backtest - # 获取参数 + # 提取参数 percentiles = kwargs.get("percentiles", [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]) fee_rate = kwargs.get("fee_rate", 0.0002) digits = kwargs.get("digits", 2) @@ -285,7 +335,7 @@ def show_backtest_by_thresholds(df: pd.DataFrame, out_sample_sdt, **kwargs): only_out_sample = kwargs.get("only_out_sample", False) sub_title = kwargs.get("sub_title", "不同权重阈值下的回测结果对比") - # 验证输入数据 + # 校验输入数据列是否齐全 required_cols = ["dt", "symbol", "weight", "price"] missing_cols = [col for col in required_cols if col not in df.columns] if missing_cols: @@ -294,7 +344,7 @@ def show_backtest_by_thresholds(df: pd.DataFrame, out_sample_sdt, **kwargs): st.subheader(sub_title, divider="rainbow") - # 数据预处理 + # 数据预处理:保证 dt 类型与排序 df = df.copy() df["dt"] = pd.to_datetime(df["dt"]) df = df.sort_values(["symbol", "dt"]).reset_index(drop=True) @@ -302,7 +352,7 @@ def show_backtest_by_thresholds(df: pd.DataFrame, out_sample_sdt, **kwargs): # 计算年化交易日数 yearly_days = cal_yearly_days(df["dt"].unique().tolist()) - # 分割样本内外数据 + # 切分样本内与样本外数据 out_sample_sdt = pd.to_datetime(out_sample_sdt) df_in_sample = df[df["dt"] < out_sample_sdt].copy() df_out_sample = df[df["dt"] >= out_sample_sdt].copy() @@ -311,9 +361,10 @@ def show_backtest_by_thresholds(df: pd.DataFrame, out_sample_sdt, **kwargs): st.error("样本内数据为空,请检查 out_sample_sdt 参数") return + # 根据 only_out_sample 决定回测时使用全部数据还是仅样本外 df_analysis = df_out_sample.copy() if only_out_sample else df.copy() - # 显示数据基本信息 + # 展示数据基本信息,便于诊断 weight_stats = df_in_sample["weight"].describe() st.markdown( f"**数据基本信息:** 总记录数 {len(df)},标的数量 {df['symbol'].nunique()},样本内记录数 {len(df_in_sample)}," @@ -328,10 +379,10 @@ def show_backtest_by_thresholds(df: pd.DataFrame, out_sample_sdt, **kwargs): threshold = weight_abs.quantile(p) thresholds[f"阈值_{int(p * 100)}%"] = threshold - # 创建不同阈值下的回测策略 + # 构造不同阈值下的回测策略 wbs = {} - # 原始策略(无阈值过滤) + # 原始策略(不做任何过滤) try: wb_original = WeightBacktest( df_analysis[["dt", "symbol", "weight", "price"]], @@ -345,14 +396,13 @@ def show_backtest_by_thresholds(df: pd.DataFrame, out_sample_sdt, **kwargs): st.error(f"原始策略回测失败:{e}") return - # 不同阈值下的策略 + # 不同阈值下的策略:权重绝对值大于等于阈值时取 sign(weight) * 1,否则置 0 for p in percentiles: threshold_name = f"阈值_{int(p * 100)}%" threshold_value = thresholds[threshold_name] - # 创建过滤后的权重 + # 重新构造过滤后的权重 df_filtered = df_analysis.copy() - # 仅当权重绝对值大于等于阈值时,使用 sign(weight) * 1,否则权重为0 df_filtered["weight"] = np.where( df_filtered["weight"].abs() >= threshold_value, np.sign(df_filtered["weight"]), 0 ) @@ -374,15 +424,14 @@ def show_backtest_by_thresholds(df: pd.DataFrame, out_sample_sdt, **kwargs): st.error("所有策略回测都失败了") return - # 显示回测结果对比 + # 多策略绩效对比 st.caption(f"回测参数:fee_rate={fee_rate}, digits={digits}, weight_type={weight_type}") show_multi_backtest(wbs, show_describe=False) - # 显示权重使用情况统计 + # 不同阈值下权重的使用情况统计 with st.container(border=True): st.markdown("#### :orange[权重使用情况统计]") - # 计算不同阈值下的权重使用比例 usage_stats = [] for p in percentiles: threshold_name = f"阈值_{int(p * 100)}%" @@ -410,8 +459,14 @@ def show_backtest_by_thresholds(df: pd.DataFrame, out_sample_sdt, **kwargs): def show_backtest_by_year(df: pd.DataFrame, **kwargs): - """ - 按照年份进行回测 + """按照年份进行回测 + + 将权重数据按自然年切片,分别构造 ``WeightBacktest`` 进行回测,并通过 + :func:`show_multi_backtest` 进行多策略对比,便于观察策略在不同年份的稳定性。 + + :param df: pd.DataFrame,包含 dt/symbol/weight/price 列 + :param kwargs: 透传给 ``WeightBacktest`` 的参数(``fee_rate`` / ``digits`` / ``weight_type`` / ``yearly_days``) + :return: dict 或 None;按年份组织的回测对象映射 """ if WeightBacktest is None: return @@ -427,6 +482,7 @@ def show_backtest_by_year(df: pd.DataFrame, **kwargs): df["year"] = df["dt"].dt.year wbs = {} for year, dfy in df.groupby("year"): + # 每年单独排序、重设索引后再回测 dfy = dfy.copy().sort_values(["symbol", "dt"]).reset_index(drop=True) wbs[f"{year}年"] = WeightBacktest( dfy, fee_rate=fee_rate, digits=digits, weight_type=weight_type, yearly_days=yearly_days @@ -437,8 +493,13 @@ def show_backtest_by_year(df: pd.DataFrame, **kwargs): def show_backtest_by_symbol(df: pd.DataFrame, **kwargs): - """ - 按照交易标的进行回测 + """按照交易标的进行回测 + + 将权重数据按 symbol 切片,分别构造 ``WeightBacktest``,便于观察各品种的贡献。 + + :param df: pd.DataFrame,包含 dt/symbol/weight/price 列 + :param kwargs: 透传给 ``WeightBacktest`` 的参数 + :return: dict 或 None;按 symbol 组织的回测对象映射 """ if WeightBacktest is None: return @@ -464,8 +525,19 @@ def show_backtest_by_symbol(df: pd.DataFrame, **kwargs): def show_long_short_backtest(df: pd.DataFrame, **kwargs): - """ - 分析多头、空头的收益 + """分析多头、空头收益及基准等权对比 + + 将权重切分为: + - 原始策略:保留正负权重不变; + - 策略多头:仅保留正权重,负权重置 0; + - 策略空头:仅保留负权重,正权重置 0; + - 基准等权:所有权重置 1,作为多头满仓基准。 + + 四组策略统一回测后调用 :func:`show_multi_backtest` 进行对比。 + + :param df: pd.DataFrame,包含 dt/symbol/weight/price 列 + :param kwargs: 透传给 ``WeightBacktest`` 的参数 + :return: dict 或 None;多空对比的回测对象映射 """ if WeightBacktest is None: return @@ -479,12 +551,15 @@ def show_long_short_backtest(df: pd.DataFrame, **kwargs): df = df[["dt", "symbol", "weight", "price"]].copy() + # 多头:负权重截断为 0 dfl = df.copy() dfl["weight"] = dfl["weight"].clip(lower=0) + # 空头:正权重截断为 0 dfs = df.copy() dfs["weight"] = dfs["weight"].clip(upper=0) + # 等权满仓基准 dfb = df.copy() dfb["weight"] = 1 @@ -505,12 +580,21 @@ def show_long_short_backtest(df: pd.DataFrame, **kwargs): def show_comprehensive_weight_backtest(df: pd.DataFrame, **kwargs): - """综合权重回测可视化展示""" + """综合权重回测可视化展示 + + 将整体回测、标的基准、年度回测、标的回测、多空回测、原始数据下载等内容 + 整合在 6 个 Tab 页中,便于用户在一个页面完成全方位的策略评估。 + + :param df: pd.DataFrame,包含 dt/symbol/weight/price 列的权重数据 + :param kwargs: 透传给底层回测函数的参数(``yearly_days`` / ``fee`` / ``digits`` / ``weight_type``) + :return: WeightBacktest,整体回测对象 + """ yearly_days = kwargs.get("yearly_days", 252) fee = kwargs.get("fee", 0.0) digits = kwargs.get("digits", 2) weight_type = kwargs.get("weight_type", "ts") + # fee 单位为 BP,需要换算成小数比率 fee_rate = fee / 10000 tabs = st.tabs(["整体回测", "标的基准", "年度回测", "标的回测", "多空回测", "下载数据"]) @@ -539,6 +623,7 @@ def show_comprehensive_weight_backtest(df: pd.DataFrame, **kwargs): show_long_short_backtest(df, yearly_days=yearly_days, fee_rate=fee_rate, digits=digits, weight_type=weight_type) with tabs[5]: + # 提供原始权重、整体收益、多头收益、空头收益的下载入口 st.download_button( "下载原始数据", data=df.to_csv(index=False), diff --git a/czsc/svc/base.py b/czsc/svc/base.py index 8207e7fc5..b62851781 100644 --- a/czsc/svc/base.py +++ b/czsc/svc/base.py @@ -1,33 +1,41 @@ """ -SVC模块的基础功能模块 +SVC 模块的基础工具集 -提供统一的导入处理、样式配置等基础功能 +本模块为 ``czsc.svc`` 子包内其他可视化组件提供通用的基础设施,主要包含: + +1. :func:`apply_stats_style`:为绩效统计 DataFrame 统一应用条件格式与数值格式化, + 保证不同组件展示风格一致; +2. :func:`ensure_datetime_index`:把任意可能的 dt 列规范成 ``datetime64[ns]`` 类型 + 的索引,是大多数时间序列绘图函数的前置处理; +3. :func:`generate_component_key`:根据数据内容自动生成 Streamlit 组件的唯一 key, + 避免在同一页面多次调用相同组件时出现 ``duplicate widget key`` 报错。 + +这些工具被回测、收益、统计等多个子模块复用,需要保持稳定且向后兼容。 """ import hashlib import json import pandas as pd -import streamlit as st def apply_stats_style(stats_df): """统一的绩效指标样式配置 - 参数: - stats_df: pd.DataFrame, 待样式化的统计数据 + 根据预定义的"正向指标 / 负向指标"分类,对 DataFrame 中已知的绩效指标列应用 + 背景色梯度(``RdYlGn`` 系列)以及格式化字符串(百分比、两位小数等),未知列 + 保持原样输出。 - 返回: - pandas.io.formats.style.Styler, 应用样式后的DataFrame - - 功能增强: - - 保留所有输入列,不删除非样式列 - - 只对已知的绩效指标列应用样式和格式化 - - 对其他列保持原样 + :param stats_df: pd.DataFrame,待样式化的绩效统计数据;不要求包含全部已知列 + :return: pandas.io.formats.style.Styler,应用样式后的 Styler 对象 + :note: + - 保留所有输入列,不删除非样式列; + - 只对已知的绩效指标列应用样式和格式化; + - 对其他列保持原样,便于扩展自定义指标。 """ - # 定义已知的绩效指标列及其样式配置 + # 已知绩效指标列及其样式配置(按"越大越好""越小越好"两类区分) style_config = { - # 正向指标(越大越好)- 使用 RdYlGn_r 配色 + # 正向指标(越大越好):使用反转的红黄绿配色,让大值偏绿 "positive_indicators": { "columns": [ "绝对收益", @@ -48,14 +56,14 @@ def apply_stats_style(stats_df): ], "cmap": "RdYlGn_r", }, - # 负向指标(越小越好)- 使用 RdYlGn 配色 + # 负向指标(越小越好):使用正向的红黄绿配色,让小值偏绿 "negative_indicators": { "columns": ["最大回撤", "年化波动率", "下行波动率", "盈亏平衡点", "新高间隔", "回撤风险", "波动比"], "cmap": "RdYlGn", }, } - # 格式化配置 + # 数值格式化配置:分百分比与小数两类 format_dict = { # 百分比格式 "绝对收益": "{:.2%}", @@ -85,24 +93,24 @@ def apply_stats_style(stats_df): "与基准波动相关性": "{:.2f}", } - # 保留所有列,从原DataFrame开始应用样式 + # 从原 DataFrame 构造 Styler,保留全部列 stats_styled = stats_df.style - # 应用正向指标样式 + # 正向指标:仅对存在的列应用反向梯度 for col in style_config["positive_indicators"]["columns"]: if col in stats_df.columns: stats_styled = stats_styled.background_gradient( cmap=style_config["positive_indicators"]["cmap"], axis=None, subset=[col] ) - # 应用负向指标样式 + # 负向指标:仅对存在的列应用正向梯度 for col in style_config["negative_indicators"]["columns"]: if col in stats_df.columns: stats_styled = stats_styled.background_gradient( cmap=style_config["negative_indicators"]["cmap"], axis=None, subset=[col] ) - # 应用格式化 - 只格式化存在的列 + # 仅对存在的列应用格式化字符串,避免 KeyError format_dict_filtered = {k: v for k, v in format_dict.items() if k in stats_df.columns} if format_dict_filtered: stats_styled = stats_styled.format(format_dict_filtered) @@ -111,7 +119,16 @@ def apply_stats_style(stats_df): def ensure_datetime_index(df, dt_col="dt"): - """确保DataFrame的索引是datetime64[ns]类型""" + """确保 DataFrame 的索引是 ``datetime64[ns]`` 类型 + + 若索引已经是 datetime64[ns] 则原样返回;否则会尝试用 ``dt_col`` 列设置为索引, + 并强制转换为 ``datetime64[ns]``。 + + :param df: pd.DataFrame,输入数据 + :param dt_col: str,作为时间索引的列名,默认 ``"dt"`` + :return: pd.DataFrame,索引为 ``datetime64[ns]`` 的 DataFrame + :raises ValueError: 当 df 既没有 datetime64[ns] 索引也不存在 ``dt_col`` 列时 + """ if df.index.dtype != "datetime64[ns]": if dt_col in df.columns: df[dt_col] = pd.to_datetime(df[dt_col]).astype("datetime64[ns]") @@ -126,16 +143,20 @@ def ensure_datetime_index(df, dt_col="dt"): def generate_component_key(data, prefix="component", **kwargs): """根据输入数据生成唯一的组件 key - :param data: 输入数据(DataFrame, Figure, dict, str 等) - :param prefix: key 的前缀,建议使用函数名缩写 - :param kwargs: 其他影响输出的参数 - :return: str, 唯一的 hash key + Streamlit 在同一页面多次出现相同组件时需要不同的 key 来区分,本函数通过对 + 数据内容、附加参数与高精度时间戳进行 md5 哈希,给出短而稳定的 key。 + + :param data: 输入数据(``pd.DataFrame``、``Figure``、``dict``、``str`` 等) + :param prefix: str,key 前缀,建议使用函数名缩写,便于调试 + :param kwargs: 其他影响输出的参数;会一并参与哈希 + :return: str,形如 ``"_<8位hex>"`` 的唯一 key """ import time key_parts = [prefix] if isinstance(data, pd.DataFrame): + # DataFrame 使用 hash_pandas_object 求 sum,避免对每一行单独构造字符串 from pandas.util import hash_pandas_object key_parts.append(str(hash_pandas_object(data).sum())) @@ -147,7 +168,7 @@ def generate_component_key(data, prefix="component", **kwargs): if kwargs: key_parts.append(json.dumps(kwargs, sort_keys=True, default=str)) - # 添加高精度时间戳确保唯一性(精确到纳秒) + # 加入纳秒时间戳确保多次调用时绝对不重复 timestamp = time.time_ns() key_parts.append(str(timestamp)) diff --git a/czsc/svc/factor.py b/czsc/svc/factor.py index 1913dd043..a5af96575 100644 --- a/czsc/svc/factor.py +++ b/czsc/svc/factor.py @@ -1,7 +1,18 @@ """ -因子分析相关的可视化组件 +因子分析相关的 Streamlit 可视化组件 -包含特征收益、因子分层、因子数值分布、事件收益分析等功能 +本模块面向单因子的常见分析场景,提供如下交互式组件: + +1. :func:`show_feature_returns`:批量评估特征/因子与目标收益的相关性,并可绘制 + 特征间相关性热力图; +2. :func:`show_factor_layering`:按 ``qcut`` / ``cut`` 对因子分层,统计每层日收益 + 绩效与累计收益曲线,验证因子单调性; +3. :func:`show_factor_value`:分析因子的数值分布、分位数及异常值情况; +4. :func:`show_event_return`:经典事件研究法,绘制事件前后的累计收益与置信区间; +5. :func:`show_event_features`:对事件组与非事件组的特征做均值差异与显著性检验。 + +所有组件遵循统一的 Streamlit 风格,并通过 :func:`generate_component_key` 自动生成 +组件 key,避免在同一页面多次复用时冲突。 """ import numpy as np @@ -11,20 +22,24 @@ import streamlit as st from .base import apply_stats_style, generate_component_key -from rs_czsc import daily_performance +from wbt import daily_performance def show_feature_returns(df, features, ret_col="returns", key=None, **kwargs): """展示特征收益分析 - :param df: pd.DataFrame, 数据源 - :param features: list, 特征列名列表 - :param ret_col: str, 收益列名,默认为 'returns' - :param key: str, 可选,组件的基础标识符,每个图表会自动添加后缀 - :param kwargs: - - method: str, 相关性计算方法,默认为 'spearman' - - min_periods: int, 最小样本数,默认为100 - - show_correlation: bool, 是否展示相关性热力图,默认为True + 针对每个特征列,计算其与目标收益列的相关系数(默认 Spearman),按绝对值排序后 + 展示,并通过条形图与热力图直观呈现特征强度与特征间相关性。 + + :param df: pd.DataFrame,数据源;至少包含 ``features`` 与 ``ret_col`` 列 + :param features: list,特征列名列表 + :param ret_col: str,收益列名,默认 ``"returns"`` + :param key: str,可选;组件基础标识符,每个图表会自动追加后缀 + :param kwargs: 其他参数 + - method: str,相关性计算方法,默认 ``"spearman"`` + - min_periods: int,最小样本数,特征样本不足该值时被跳过,默认 100 + - show_correlation: bool,是否展示特征间相关性热力图,默认 True + :return: None;结果直接写入 Streamlit 页面 """ method = kwargs.get("method", "spearman") min_periods = kwargs.get("min_periods", 100) @@ -39,7 +54,7 @@ def show_feature_returns(df, features, ret_col="returns", key=None, **kwargs): st.error(f"数据中没有找到特征列: {missing_features}") return - # 计算特征与收益的相关性 + # 逐个特征计算与目标收益的相关性,记录样本数与绝对相关系数用于排序 correlations = [] for feature in features: data = df[[feature, ret_col]].dropna() @@ -54,18 +69,18 @@ def show_feature_returns(df, features, ret_col="returns", key=None, **kwargs): corr_df = pd.DataFrame(correlations) corr_df = corr_df.sort_values("绝对相关系数", ascending=False) - # 显示排序后的相关性 + # 展示按绝对相关系数排序的明细表 st.subheader("特征与收益的相关性排序") corr_styled = corr_df.style.background_gradient(cmap="RdYlGn_r", subset=["相关系数"]) corr_styled = corr_styled.background_gradient(cmap="RdYlGn_r", subset=["绝对相关系数"]) corr_styled = corr_styled.format({"相关系数": "{:.4f}", "绝对相关系数": "{:.4f}", "样本数": "{:.0f}"}) st.dataframe(corr_styled, width="stretch", hide_index=True) - # 生成基础 key + # 自动生成组件基础 key,确保多次调用不冲突 if key is None: key = generate_component_key(df, prefix="feat_ret", features=features, ret_col=ret_col, method=method) - # 绘制相关性条形图 + # 条形图:特征 vs. 相关系数 fig = px.bar( corr_df, x="特征", @@ -77,7 +92,7 @@ def show_feature_returns(df, features, ret_col="returns", key=None, **kwargs): fig.update_xaxes(tickangle=45) st.plotly_chart(fig, key=f"{key}_bar", width="stretch") - # 显示特征间相关性热力图 + # 当特征数量大于 1 时,绘制特征间相关性热力图,便于发现共线性 if show_correlation and len(features) > 1: st.subheader("特征间相关性矩阵") feature_corr = df[features].corr(method=method) @@ -103,15 +118,19 @@ def show_feature_returns(df, features, ret_col="returns", key=None, **kwargs): def show_factor_layering(df, factor_col, ret_col, n_layers=5, key=None, **kwargs): """展示因子分层分析 - :param df: pd.DataFrame, 数据源 - :param factor_col: str, 因子列名 - :param ret_col: str, 收益列名 - :param n_layers: int, 分层数量,默认为5 - :param key: str, 可选,组件的基础标识符,每个图表会自动添加后缀 - :param kwargs: - - method: str, 分层方法,'qcut'(等频)或'cut'(等距),默认为'qcut' - - show_cumulative: bool, 是否显示累计收益,默认为True - - show_distribution: bool, 是否显示因子分布,默认为True + 按 ``qcut``(等频)或 ``cut``(等距)将因子分成 ``n_layers`` 层,分别计算各层 + 的日收益绩效与平均收益,并可绘制累计收益曲线与因子分布直方图。 + + :param df: pd.DataFrame,数据源 + :param factor_col: str,因子列名 + :param ret_col: str,收益列名 + :param n_layers: int,分层数量,默认 5 + :param key: str,可选;组件基础标识符 + :param kwargs: 其他参数 + - method: str,分层方法,``'qcut'`` 或 ``'cut'``,默认 ``'qcut'`` + - show_cumulative: bool,是否绘制累计收益曲线(要求 df 包含 ``dt`` 列),默认 True + - show_distribution: bool,是否绘制因子分布直方图,默认 True + :return: None """ method = kwargs.get("method", "qcut") show_cumulative = kwargs.get("show_cumulative", True) @@ -121,12 +140,12 @@ def show_factor_layering(df, factor_col, ret_col, n_layers=5, key=None, **kwargs st.error(f"数据中没有找到列 '{factor_col}' 或 '{ret_col}'") return - # 去除缺失值 + # 去除两列中的缺失值,避免 qcut/cut 报错 data = df[[factor_col, ret_col]].dropna() - if len(data) < n_layers * 10: # 每层至少10个观测值 + if len(data) < n_layers * 10: # 经验上每层至少需要 10 个观测 st.warning(f"数据量太少,建议减少分层数量。当前数据量:{len(data)}") - # 因子分层 + # 因子分层:等频或等距 if method == "qcut": data["layer"] = pd.qcut( data[factor_col], q=n_layers, labels=[f"第{i + 1}层" for i in range(n_layers)], duplicates="drop" @@ -136,7 +155,7 @@ def show_factor_layering(df, factor_col, ret_col, n_layers=5, key=None, **kwargs data[factor_col], bins=n_layers, labels=[f"第{i + 1}层" for i in range(n_layers)], duplicates="drop" ) - # 计算各层收益统计 + # 各层日收益的绩效统计 layer_stats = [] for layer in data["layer"].cat.categories: layer_data = data[data["layer"] == layer][ret_col] @@ -152,18 +171,18 @@ def show_factor_layering(df, factor_col, ret_col, n_layers=5, key=None, **kwargs stats_df = pd.DataFrame(layer_stats).set_index("分层") - # 显示分层统计 + # 展示分层绩效表 st.subheader(f"{factor_col} 分层收益分析") stats_styled = apply_stats_style(stats_df) st.dataframe(stats_styled, width="stretch") - # 生成基础 key + # 自动生成组件 key if key is None: key = generate_component_key( df, prefix="layer", factor_col=factor_col, ret_col=ret_col, n_layers=n_layers, method=method ) - # 显示分层收益对比 + # 各层平均收益对比柱状图 layer_returns = data.groupby("layer")[ret_col].mean() fig_bar = px.bar( x=layer_returns.index.astype(str), @@ -173,13 +192,13 @@ def show_factor_layering(df, factor_col, ret_col, n_layers=5, key=None, **kwargs ) st.plotly_chart(fig_bar, key=f"{key}_bar", width="stretch") - # 显示累计收益曲线(如果有时间信息) + # 若存在时间列,则绘制各层累计收益曲线 if show_cumulative and "dt" in df.columns: st.subheader("分层累计收益曲线") df_temp = df.copy() df_temp = df_temp.dropna(subset=[factor_col, ret_col]) - # 重新分层 + # 重新分层(保持与上方一致的方法) if method == "qcut": df_temp["layer"] = pd.qcut( df_temp[factor_col], q=n_layers, labels=[f"第{i + 1}层" for i in range(n_layers)], duplicates="drop" @@ -192,7 +211,7 @@ def show_factor_layering(df, factor_col, ret_col, n_layers=5, key=None, **kwargs df_temp["dt"] = pd.to_datetime(df_temp["dt"]) df_temp = df_temp.sort_values("dt") - # 计算各层累计收益 + # 各层累计收益(横截面取均值) cumulative_returns = [] for layer in df_temp["layer"].cat.categories: layer_data = df_temp[df_temp["layer"] == layer] @@ -205,7 +224,7 @@ def show_factor_layering(df, factor_col, ret_col, n_layers=5, key=None, **kwargs fig_cumret = px.line(cumret_df, title="分层累计收益曲线") st.plotly_chart(fig_cumret, key=f"{key}_cumret", width="stretch") - # 显示因子分布 + # 因子在各层的分布直方图 if show_distribution: st.subheader(f"{factor_col} 分布分析") fig_hist = px.histogram(data, x=factor_col, color="layer", title=f"{factor_col} 在各层的分布") @@ -215,13 +234,17 @@ def show_factor_layering(df, factor_col, ret_col, n_layers=5, key=None, **kwargs def show_factor_value(df, factor_col, bins=50, key=None, **kwargs): """展示因子数值分布 - :param df: pd.DataFrame, 数据源 - :param factor_col: str, 因子列名 - :param bins: int, 直方图箱数,默认为50 - :param key: str, 可选,组件的基础标识符,每个图表会自动添加后缀 - :param kwargs: - - show_outliers: bool, 是否显示异常值,默认为True - - percentiles: list, 分位数列表,默认为[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99] + 展示因子的样本数、均值、标准差、缺失值数量、分位数表,并绘制直方图与箱线图。 + 同时基于 IQR 法进行异常值统计。 + + :param df: pd.DataFrame,数据源 + :param factor_col: str,因子列名 + :param bins: int,直方图箱数,默认 50 + :param key: str,可选;组件基础标识符 + :param kwargs: 其他参数 + - show_outliers: bool,是否展示异常值统计与异常值列表,默认 True + - percentiles: list,分位数列表,默认 [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99] + :return: None """ if factor_col not in df.columns: st.error(f"数据中没有找到因子列 '{factor_col}'") @@ -237,27 +260,27 @@ def show_factor_value(df, factor_col, bins=50, key=None, **kwargs): st.subheader(f"{factor_col} 数值分布分析") - # 基本统计信息 + # 顶部展示基本统计量 c1, c2, c3, c4 = st.columns(4) c1.metric("观测数", f"{len(data):,.0f}") c2.metric("均值", f"{data.mean():.4f}") c3.metric("标准差", f"{data.std():.4f}") c4.metric("缺失值", f"{df[factor_col].isnull().sum():,.0f}") - # 分位数信息 + # 分位数明细 quantiles = data.quantile(percentiles) quantile_df = pd.DataFrame({"分位数": [f"{p:.1%}" for p in percentiles], "数值": quantiles.values}) with st.expander("分位数分布", expanded=False): st.dataframe(quantile_df.style.format({"数值": "{:.4f}"}), width="stretch", hide_index=True) - # 生成基础 key + # 自动生成组件 key if key is None: key = generate_component_key( df, prefix="fact_val", factor_col=factor_col, bins=bins, show_outliers=show_outliers ) - # 绘制直方图和箱线图 + # 直方图与箱线图并排展示 col1, col2 = st.columns(2) with col1: @@ -271,7 +294,7 @@ def show_factor_value(df, factor_col, bins=50, key=None, **kwargs): fig_box.update_traces(boxpoints=False) st.plotly_chart(fig_box, key=f"{key}_box", width="stretch") - # 异常值分析 + # 基于 1.5 IQR 规则的异常值分析 if show_outliers: Q1 = data.quantile(0.25) Q3 = data.quantile(0.75) @@ -285,22 +308,26 @@ def show_factor_value(df, factor_col, bins=50, key=None, **kwargs): st.write(f"**异常值边界**: [{lower_bound:.4f}, {upper_bound:.4f}]") st.write(f"**异常值占比**: {len(outliers) / len(data):.2%}") - if len(outliers) <= 100: # 只显示前100个异常值 + if len(outliers) <= 100: # 异常值过多时仅汇总,不再罗列 outlier_df = pd.DataFrame({"序号": range(1, len(outliers) + 1), "异常值": outliers.values}) st.dataframe(outlier_df.style.format({"异常值": "{:.4f}"}), width="stretch", hide_index=True) def show_event_return(df, event_col, ret_col, key=None, **kwargs): - """展示事件收益分析 - - :param df: pd.DataFrame, 数据源 - :param event_col: str, 事件列名(布尔类型或0/1) - :param ret_col: str, 收益列名 - :param key: str, 可选,组件的唯一标识符,默认自动生成 - :param kwargs: - - pre_periods: int, 事件前观察期数,默认为5 - - post_periods: int, 事件后观察期数,默认为10 - - min_observations: int, 最小观察数,默认为10 + """展示事件收益分析(事件研究法) + + 定位 ``event_col`` 中所有发生事件的时点,计算事件前后若干期的累计收益, + 再求平均与 95% 置信区间,绘制经典的事件研究曲线。 + + :param df: pd.DataFrame,数据源;必须包含 ``dt`` 时间列 + :param event_col: str,事件列名(布尔类型或 0/1) + :param ret_col: str,收益列名 + :param key: str,可选;组件唯一标识符 + :param kwargs: 其他参数 + - pre_periods: int,事件前观察期数,默认 5 + - post_periods: int,事件后观察期数,默认 10 + - min_observations: int,最小事件次数,少于该值则警告,默认 10 + :return: None """ pre_periods = kwargs.get("pre_periods", 5) post_periods = kwargs.get("post_periods", 10) @@ -310,7 +337,7 @@ def show_event_return(df, event_col, ret_col, key=None, **kwargs): st.error(f"数据中没有找到列 '{event_col}' 或 '{ret_col}'") return - # 确保有时间索引 + # 必须有时间列才能定位事件前后窗口 if "dt" not in df.columns: st.error("数据中需要包含 'dt' 时间列") return @@ -319,7 +346,7 @@ def show_event_return(df, event_col, ret_col, key=None, **kwargs): df["dt"] = pd.to_datetime(df["dt"]) df = df.sort_values("dt").reset_index(drop=True) - # 找出事件发生的时点 + # 找出事件发生的位置(True 或 1 都视为事件) event_mask = (df[event_col]) | (df[event_col] == 1) event_indices = df[event_mask].index.tolist() @@ -329,17 +356,18 @@ def show_event_return(df, event_col, ret_col, key=None, **kwargs): st.subheader(f"事件收益分析 (事件发生{len(event_indices)}次)") - # 计算事件前后的累计收益 + # 计算每个事件点前后窗口的相对累计收益 event_returns = [] for event_idx in event_indices: start_idx = max(0, event_idx - pre_periods) end_idx = min(len(df), event_idx + post_periods + 1) + # 仅当窗口完整时才计入,避免边界数据干扰均值 if end_idx - start_idx >= pre_periods + post_periods: window_data = df.iloc[start_idx:end_idx] window_returns = window_data[ret_col].values - # 计算相对于事件时点的累计收益 + # 以事件时点为参考,前段累计 + 后段累计 event_point = pre_periods relative_returns = np.zeros(len(window_returns)) for i in range(len(window_returns)): @@ -354,25 +382,25 @@ def show_event_return(df, event_col, ret_col, key=None, **kwargs): st.error("没有足够的数据进行事件分析") return - # 计算平均事件收益 + # 计算事件平均收益与标准差,构造 95% 置信区间 event_returns_array = np.array(event_returns) mean_returns = np.mean(event_returns_array, axis=0) std_returns = np.std(event_returns_array, axis=0) - # 时间轴(相对于事件时点) + # 时间轴:从 -pre_periods 到 +post_periods time_axis = list(range(-pre_periods, post_periods + 1)) - # 绘制事件研究图 + # 事件研究图:均值曲线 + 置信区间填充 + 事件分割线 fig = go.Figure() - # 添加平均累计收益 + # 平均累计收益主曲线 fig.add_trace( go.Scatter( x=time_axis, y=mean_returns, mode="lines+markers", name="平均累计收益", line={"color": "blue", "width": 2} ) ) - # 添加置信区间 + # 95% 置信区间(基于均值的标准误) upper_bound = mean_returns + 1.96 * std_returns / np.sqrt(len(event_returns)) lower_bound = mean_returns - 1.96 * std_returns / np.sqrt(len(event_returns)) @@ -392,7 +420,7 @@ def show_event_return(df, event_col, ret_col, key=None, **kwargs): ) ) - # 添加事件发生时点的垂直线 + # 事件发生时点的红色虚线 fig.add_vline(x=0, line_dash="dash", line_color="red", annotation_text="事件发生") fig.update_layout( @@ -402,7 +430,7 @@ def show_event_return(df, event_col, ret_col, key=None, **kwargs): hovermode="x unified", ) - # 生成 key + # 自动生成组件 key if key is None: key = generate_component_key( df, @@ -415,7 +443,7 @@ def show_event_return(df, event_col, ret_col, key=None, **kwargs): st.plotly_chart(fig, key=key, width="stretch") - # 显示统计信息 + # 底部展示 4 个关键指标:事件次数、事件前/后累计收益、事件效应 pre_event_return = mean_returns[pre_periods - 1] if pre_periods > 0 else 0 post_event_return = mean_returns[-1] event_effect = post_event_return - pre_event_return @@ -430,13 +458,17 @@ def show_event_return(df, event_col, ret_col, key=None, **kwargs): def show_event_features(df, event_col, feature_cols, key=None, **kwargs): """展示事件特征分析 - :param df: pd.DataFrame, 数据源 - :param event_col: str, 事件列名 - :param feature_cols: list, 特征列名列表 - :param key: str, 可选,组件的基础标识符,每个图表会自动添加后缀 - :param kwargs: - - test_method: str, 统计检验方法,'ttest'或'mannwhitney',默认为'ttest' - - alpha: float, 显著性水平,默认为0.05 + 将样本按事件是否发生切分为事件组与非事件组,针对每个特征做均值差异统计与 + 显著性检验(T 检验或 Mann-Whitney U 检验),并绘制特征分布对比。 + + :param df: pd.DataFrame,数据源 + :param event_col: str,事件列名 + :param feature_cols: list,特征列名列表 + :param key: str,可选;组件基础标识符 + :param kwargs: 其他参数 + - test_method: str,统计检验方法,``'ttest'`` 或 ``'mannwhitney'``,默认 ``'ttest'`` + - alpha: float,显著性水平,默认 0.05 + :return: None """ from scipy import stats @@ -452,7 +484,7 @@ def show_event_features(df, event_col, feature_cols, key=None, **kwargs): st.error(f"数据中没有找到特征列: {missing_features}") return - # 分组数据 + # 事件组与非事件组划分 event_mask = (df[event_col]) | (df[event_col] == 1) event_data = df[event_mask] non_event_data = df[~event_mask] @@ -463,7 +495,7 @@ def show_event_features(df, event_col, feature_cols, key=None, **kwargs): st.subheader(f"事件特征分析 (事件组: {len(event_data)}, 非事件组: {len(non_event_data)})") - # 统计检验结果 + # 逐特征做均值对比与统计检验 test_results = [] for feature in feature_cols: event_values = event_data[feature].dropna() @@ -472,7 +504,7 @@ def show_event_features(df, event_col, feature_cols, key=None, **kwargs): if len(event_values) == 0 or len(non_event_values) == 0: continue - # 统计检验 + # 选择检验方法 if test_method == "ttest": statistic, p_value = stats.ttest_ind(event_values, non_event_values) test_name = "T检验" @@ -495,7 +527,7 @@ def show_event_features(df, event_col, feature_cols, key=None, **kwargs): st.error("没有足够的数据进行特征比较") return - # 显示检验结果 + # 检验结果表格 results_df = pd.DataFrame(test_results) results_styled = results_df.style.background_gradient(cmap="RdYlGn_r", subset=["差异"]) results_styled = results_styled.background_gradient(cmap="RdYlGn", subset=["P值"]) @@ -506,19 +538,19 @@ def show_event_features(df, event_col, feature_cols, key=None, **kwargs): st.dataframe(results_styled, width="stretch", hide_index=True) st.caption(f"检验方法: {test_name}, 显著性水平: {alpha}") - # 生成基础 key + # 自动生成组件 key if key is None: key = generate_component_key( df, prefix="event_feat", event_col=event_col, feature_cols=feature_cols, test_method=test_method ) - # 绘制特征分布对比 - for _i, feature in enumerate(feature_cols[:4]): # 最多显示4个特征 + # 为前 4 个特征绘制分布直方图与箱线图对比 + for _i, feature in enumerate(feature_cols[:4]): if feature in results_df["特征"].values: col1, col2 = st.columns(2) with col1: - # 直方图对比 + # 直方图叠加对比(半透明) fig_hist = go.Figure() event_values = event_data[feature].dropna() diff --git a/czsc/svc/price_analysis.py b/czsc/svc/price_analysis.py index 05bc760e5..f4cfb914d 100644 --- a/czsc/svc/price_analysis.py +++ b/czsc/svc/price_analysis.py @@ -1,14 +1,20 @@ """ 价格敏感性分析模块 -用于分析策略对执行价格的敏感性,通过对比不同交易价格的回测结果来评估价格执行对策略性能的影响。 +用于分析策略对执行价格的敏感性。其典型应用场景是:同一份权重数据搭配不同的 +执行价(如开盘价、加权平均价、TWAP、VWAP 等,列名以 ``TP`` 开头),分别构造 +``WeightBacktest`` 进行回测,再比较核心绩效指标与累计收益曲线,从而评估策略 +对价格执行的依赖程度。 主要功能: -1. 累计收益曲线展示 -2. 价格敏感性分析 -3. 分析结果摘要 -作者: czsc +1. :func:`show_price_sensitive`:核心入口,遍历所有 ``TP*`` 列,逐个完成回测、 + 汇总绩效,并在 Streamlit 面板中以表格 + 折线图的形式展示。 +2. 内部辅助 :func:`_show_sensitivity_assessment`:根据"年化收益的极差/均值" + 给出三档敏感性评估(低 / 中 / 高)。 + +输出统一通过 Streamlit 渲染;同时也以 ``(stats_df, daily_df)`` 二元组的形式返回, +方便上层进一步处理或导出。 """ import pandas as pd @@ -16,7 +22,7 @@ from loguru import logger from .base import apply_stats_style -from rs_czsc import WeightBacktest +from wbt import WeightBacktest from .returns import show_cumulative_returns @@ -25,26 +31,24 @@ def show_price_sensitive( ) -> tuple[pd.DataFrame, pd.DataFrame] | None: """价格敏感性分析组件 - 分析策略对执行价格的敏感性,通过对比不同交易价格的回测结果来评估价格执行对策略性能的影响。 - - 参数: - df: 包含以下必要列的数据框: - - symbol: 合约代码 - - dt: 日期时间 - - weight: 仓位权重 - - TP*: 以TP开头的交易价格列(如TP_open, TP_high等) - fee: 单边费率(BP),默认2.0 - digits: 小数位数,默认2 - weight_type: 权重类型,可选 "ts" 或 "cs",默认 "ts" - n_jobs: 并行数,默认1 - **kwargs: 其他参数 - - title_prefix: 标题前缀,默认为空 - - show_detailed_stats: 是否显示详细统计信息,默认False - - 返回: - tuple: (dfr, dfd) 分别为统计结果DataFrame和日收益率DataFrame,失败时返回None - - 示例: + 分析策略对执行价格的敏感性,通过对比不同交易价格的回测结果,评估价格执行 + 对策略性能的影响。 + + :param df: pd.DataFrame,必须包含以下列: + - ``symbol``:合约代码 + - ``dt``:日期时间 + - ``weight``:仓位权重 + - ``TP*``:以 TP 开头的交易价格列(如 ``TP_open``、``TP_high`` 等) + :param fee: float,单边费率(BP),默认 2.0 + :param digits: int,权重小数位数,默认 2 + :param weight_type: str,权重类型,可选 ``"ts"`` 或 ``"cs"``,默认 ``"ts"`` + :param n_jobs: int,并行进程数,默认 1 + :param kwargs: 其他参数 + - title_prefix: str,标题前缀,默认空字符串 + - show_detailed_stats: bool,是否展示更多绩效字段,默认 False + :return: tuple[pd.DataFrame, pd.DataFrame] | None; + 分别为绩效汇总表和日收益宽表;若关键步骤失败则返回 None + :example: >>> # 基本用法 >>> dfr, dfd = show_price_sensitive(df, fee=2.0, digits=2) @@ -61,11 +65,11 @@ def show_price_sensitive( """ from czsc.eda import cal_yearly_days - # 参数处理 + # 提取展示相关参数 title_prefix = kwargs.get("title_prefix", "") show_detailed_stats = kwargs.get("show_detailed_stats", False) - # 数据验证 + # 校验必要列是否存在 required_cols = ["symbol", "dt", "weight"] missing_cols = [col for col in required_cols if col not in df.columns] if missing_cols: @@ -74,7 +78,7 @@ def show_price_sensitive( logger.error(f"数据检查失败,{error_msg}") return None - # 查找交易价格列 + # 找出所有以 TP 开头的交易价格列 tp_cols = [x for x in df.columns if x.startswith("TP")] if not tp_cols: error_msg = "没有找到交易价格列,请检查文件;交易价列名必须以 TP 开头" @@ -84,7 +88,7 @@ def show_price_sensitive( logger.info(f"找到 {len(tp_cols)} 个交易价格列: {tp_cols}") - # 计算年化天数 + # 根据 dt 列推断每年实际交易天数 try: yearly_days = cal_yearly_days(dts=df["dt"].unique().tolist()) logger.info(f"计算得到年化天数: {yearly_days}") @@ -94,16 +98,16 @@ def show_price_sensitive( logger.error(error_msg) return None - # 结果收集 + # 收集每个 TP 列对应的回测结果 c1 = st.container(border=True) rows = [] dfd = pd.DataFrame() - # 创建进度条 + # 进度条与状态文本 progress_bar = st.progress(0) status_text = st.empty() - # 逐个处理交易价格 + # 逐列回测 for i, tp_col in enumerate(tp_cols): try: progress = (i + 1) / len(tp_cols) @@ -112,13 +116,13 @@ def show_price_sensitive( logger.info(f"正在处理第 {i + 1}/{len(tp_cols)} 个交易价格: {tp_col}") - # 准备数据 + # 构造单次回测所需的数据:用对应的 TP 列填充缺失,再当作 price 列 df_temp = df.copy() df_temp[tp_col] = df_temp[tp_col].fillna(df_temp["price"]) dfw = df_temp[["symbol", "dt", "weight", tp_col]].copy() dfw.rename(columns={tp_col: "price"}, inplace=True) - # 创建回测实例 + # 构造回测对象 wb = WeightBacktest( dfw=dfw, digits=digits, @@ -128,7 +132,7 @@ def show_price_sensitive( yearly_days=yearly_days, ) - # 获取日收益率 + # 把 daily_return 中的 total 列重命名为 TP 列名,便于多列横向合并 daily = wb.daily_return.copy() daily.rename(columns={"total": tp_col}, inplace=True) @@ -137,7 +141,7 @@ def show_price_sensitive( else: dfd = pd.merge(dfd, daily[["date", tp_col]], on="date", how="outer") - # 收集统计结果 + # 保留绩效统计 res = {"交易价格": tp_col} res.update(wb.stats) rows.append(res) @@ -156,16 +160,16 @@ def show_price_sensitive( st.error("所有交易价格处理失败,无法生成报告") return None - # 显示结果 + # 渲染绩效对比表 with c1: st.markdown(f"##### :red[{title_prefix}不同交易价格回测核心指标对比]") dfr = pd.DataFrame(rows) - # 敏感性评估 + # 多个 TP 列时给出敏感性评估 if len(dfr) > 1 and "年化" in dfr.columns: _show_sensitivity_assessment(dfr) - # 选择显示列 + # 选择展示列 if show_detailed_stats: display_cols = [ "交易价格", @@ -202,15 +206,15 @@ def show_price_sensitive( "持仓K线数", ] - # 确保所有列都存在 + # 仅保留实际存在的列,避免 KeyError available_cols = [col for col in display_cols if col in dfr.columns] dfr_display = dfr[available_cols].copy() - # 应用样式 + # 应用统一样式 dfr_styled = apply_stats_style(dfr_display) st.dataframe(dfr_styled, width="stretch") - # 累计收益对比 + # 累计收益对比图 c2 = st.container(border=True) with c2: st.markdown(f"##### :red[{title_prefix}不同交易价格回测累计收益对比]") @@ -229,8 +233,13 @@ def show_price_sensitive( def _show_sensitivity_assessment(dfr: pd.DataFrame) -> None: - """显示敏感性评估""" + """根据年化收益的极差/均值,给出敏感性评估文案 + + :param dfr: pd.DataFrame,必须包含 ``"年化"`` 列 + :return: None;通过 ``streamlit`` 直接渲染 + """ annual_returns = dfr["年化"] + # 敏感度 = (max - min) / mean,反映不同 TP 之间收益的相对差距 sensitivity_score = (annual_returns.max() - annual_returns.min()) / annual_returns.mean() st.markdown("**敏感性评估:**") @@ -242,7 +251,7 @@ def _show_sensitivity_assessment(dfr: pd.DataFrame) -> None: st.error(f"🔴 策略对价格执行高度敏感 (敏感度: {sensitivity_score:.2%})") -# 支持的函数列表 +# 模块对外暴露的 API 列表 __all__ = [ "show_price_sensitive", ] diff --git a/czsc/svc/returns.py b/czsc/svc/returns.py index d3ea6ced8..72d24b97a 100644 --- a/czsc/svc/returns.py +++ b/czsc/svc/returns.py @@ -1,7 +1,19 @@ """ -收益相关的可视化组件 +收益相关的 Streamlit 可视化组件 -包含日收益、累计收益、月度收益、回撤分析等可视化功能 +本模块面向"日收益序列"这一核心数据结构,提供以下交互式组件: + +1. :func:`show_daily_return`:日收益数据的整体展示,包括交易日 / 持有日两套绩效指标, + 以及累计收益曲线(含年度分隔线); +2. :func:`show_cumulative_returns`:纯粹的累计收益曲线绘制,不带绩效统计; +3. :func:`show_monthly_return`:月度收益矩阵 + 胜率 / 盈亏比 / 平均收益统计; +4. :func:`show_drawdowns`:最大回撤曲线、Top N 回撤详情; +5. :func:`show_rolling_daily_performance`:滚动窗口下的日收益绩效曲线。 + +约定: +- ``df`` 的索引必须为 ``datetime64[ns]``;如不是则可借助 :func:`ensure_datetime_index` + 从 ``dt`` 列设置; +- 收益列默认是百分比变化值(如 0.01 表示 1%)。 """ import pandas as pd @@ -10,33 +22,41 @@ import streamlit as st from .base import apply_stats_style, ensure_datetime_index, generate_component_key -from rs_czsc import daily_performance, top_drawdowns +from wbt import daily_performance +from czsc import top_drawdowns def show_daily_return(df: pd.DataFrame, key=None, **kwargs): """用 streamlit 展示日收益 - :param df: pd.DataFrame,数据源 - :param key: str, 可选,组件的唯一标识符,默认自动生成 - :param kwargs: + 支持同时展示交易日与持有日两套绩效指标,并绘制累计收益曲线。可通过 ``kwargs`` + 控制是否显示明细表格、是否仅在图例中保留某些列、自定义年化天数等。 + + :param df: pd.DataFrame,数据源;索引为日期,每列代表一条日收益序列 + :param key: str,可选;组件唯一标识符 + :param kwargs: 其他参数 - sub_title: str,标题 - - stat_hold_days: bool,是否展示持有日绩效指标,默认为 True - - legend_only_cols: list,仅在图例中展示的列名 - - use_st_table: bool,是否使用 st.table 展示绩效指标,默认为 False - - plot_cumsum: bool,是否展示日收益累计曲线,默认为 True - - yearly_days: int,年交易天数,默认为 252 - - show_dailys: bool,是否展示日收益数据详情,默认为 False + - stat_hold_days: bool,是否展示持有日绩效指标,默认 True + - legend_only_cols: list,仅在图例中显示(默认隐藏曲线)的列名 + - use_st_table: bool,是否使用 ``st.table`` 展示绩效指标,默认 False + - plot_cumsum: bool,是否绘制累计收益曲线,默认 True + - yearly_days: int,年交易天数,默认 252 + - show_dailys: bool,是否展示日收益数据明细,默认 False + :return: None """ df = ensure_datetime_index(df) yearly_days = kwargs.get("yearly_days", 252) df = df.copy().fillna(0).sort_index(ascending=True) def _stats(df_, type_="持有日"): + """计算每列的日收益绩效,并以 Styler 形式返回""" stats = [] for _col in df_.columns: if type_ == "持有日": + # 持有日:剔除收益为 0 的日期 col_stats = daily_performance([x for x in df_[_col] if x != 0], yearly_days=yearly_days) else: + # 交易日:包含所有交易日 col_stats = daily_performance(df_[_col], yearly_days=yearly_days) col_stats["日收益名称"] = _col stats.append(col_stats) @@ -44,22 +64,22 @@ def _stats(df_, type_="持有日"): stats_df = pd.DataFrame(stats).set_index("日收益名称") return apply_stats_style(stats_df) - # 参数处理 + # 解析展示相关参数 use_st_table = kwargs.get("use_st_table", False) stat_hold_days = kwargs.get("stat_hold_days", True) plot_cumsum = kwargs.get("plot_cumsum", True) - # 显示标题 + # 标题 sub_title = kwargs.get("sub_title", "") if sub_title: st.subheader(sub_title, divider="rainbow", anchor=sub_title) - # 显示数据详情 + # 可选展开详情:原始日收益矩阵 if kwargs.get("show_dailys", False): with st.expander("日收益数据详情", expanded=False): st.dataframe(df, width="stretch") - # 显示交易日绩效 + # 交易日绩效 if stat_hold_days: with st.expander("交易日绩效指标", expanded=True): stats = _stats(df, type_="交易日") @@ -75,31 +95,31 @@ def _stats(df_, type_="持有日"): else: st.dataframe(stats, width="stretch") - # 显示持有日绩效 + # 持有日绩效 if stat_hold_days: with st.expander("持有日绩效指标", expanded=False): st.dataframe(_stats(df, type_="持有日"), width="stretch") st.caption("持有日:在交易日的基础上,将收益率为0的日期删除") - # 显示累计收益曲线 + # 累计收益曲线 if plot_cumsum: df_cumsum = df.cumsum() fig = px.line(df_cumsum, y=df_cumsum.columns.to_list(), title="日收益累计曲线") fig.update_xaxes(title="") - # 添加年度分隔线 + # 给每个自然年的开始位置画一条红色虚线,方便对比 years = df_cumsum.index.year.unique() for year in years: first_date = df_cumsum[df_cumsum.index.year == year].index.min() fig.add_vline(x=first_date, line_dash="dash", line_color="red") - # 设置图例显示 + # 将指定列设置为"仅图例可见",默认不展示曲线 for col in kwargs.get("legend_only_cols", []): fig.update_traces(visible="legendonly", selector={"name": col}) fig.update_layout(margin={"l": 0, "r": 0, "b": 0}) - # 生成 key + # 自动生成组件 key if key is None: key = generate_component_key( df, prefix="daily_ret", plot_cumsum=plot_cumsum, legend_only_cols=kwargs.get("legend_only_cols", []) @@ -111,12 +131,17 @@ def _stats(df_, type_="持有日"): def show_cumulative_returns(df, key=None, **kwargs): """展示累计收益曲线 - :param df: pd.DataFrame, 数据源,index 为日期,columns 为对应策略上一个日期至当前日期的收益 - :param key: str, 可选,组件的唯一标识符,默认自动生成 - :param kwargs: dict, 可选参数 - - fig_title: str, 图表标题,默认为 "累计收益" - - legend_only_cols: list, 仅在图例中展示的列名 - - display_legend: bool, 是否展示图例,默认为 True + 本函数不计算绩效,只对输入的日收益做 ``cumsum`` 后绘制折线图,并加上年度 + 分隔线。适合作为"组合""多策略对比"等场景的轻量绘图工具。 + + :param df: pd.DataFrame,数据源;索引为日期(必须 datetime64[ns] 且单调递增、唯一), + 每列代表一条策略的日收益 + :param key: str,可选;组件唯一标识符 + :param kwargs: 其他参数 + - fig_title: str,图表标题,默认 ``"累计收益"`` + - legend_only_cols: list,仅在图例中显示的列名 + - display_legend: bool,是否展示图例,默认 True + :return: None """ assert df.index.dtype == "datetime64[ns]", "index必须是datetime64[ns]类型, 请先使用 pd.to_datetime 进行转换" assert df.index.is_unique, "df 的索引必须唯一" @@ -129,7 +154,7 @@ def show_cumulative_returns(df, key=None, **kwargs): fig = px.line(df_cumsum, y=df_cumsum.columns.to_list(), title=fig_title) fig.update_xaxes(title="") - # 添加年度分隔线 + # 年度分隔线 years = df_cumsum.index.year.unique() for year in years: first_date = df_cumsum[df_cumsum.index.year == year].index.min() @@ -140,11 +165,12 @@ def show_cumulative_returns(df, key=None, **kwargs): fig.update_traces(visible="legendonly", selector={"name": col}) if display_legend: + # 将图例放到图表下方水平居中 fig.update_layout( legend={"orientation": "h", "y": -0.1, "xanchor": "center", "x": 0.5}, margin={"l": 0, "r": 0, "b": 0} ) - # 生成 key + # 自动生成组件 key if key is None: key = generate_component_key( df, prefix="cum_ret", fig_title=fig_title, legend_only_cols=kwargs.get("legend_only_cols", []) @@ -156,9 +182,13 @@ def show_cumulative_returns(df, key=None, **kwargs): def show_monthly_return(df, ret_col="total", sub_title="月度累计收益", **kwargs): """展示指定列的月度累计收益 - :param df: pd.DataFrame,数据源 + 将日收益数据按月汇总成"年 × 月"的二维矩阵,并附加年度合计、胜率、盈亏比、 + 平均收益等汇总指标,配以统一的红黄绿配色。 + + :param df: pd.DataFrame,数据源;索引或 dt 列为日期 :param ret_col: str,收益列名 :param sub_title: str,标题 + :return: None """ assert isinstance(df, pd.DataFrame), "df 必须是 pd.DataFrame 类型" df = ensure_datetime_index(df) @@ -167,30 +197,31 @@ def show_monthly_return(df, ret_col="total", sub_title="月度累计收益", **k if sub_title: st.subheader(sub_title, divider="rainbow", anchor=sub_title) - # 计算月度收益 + # 月度求和并构造透视表 monthly = df[[ret_col]].resample("ME").sum() monthly["year"] = monthly.index.year monthly["month"] = monthly.index.month monthly = monthly.pivot_table(index="year", columns="month", values=ret_col) - # 设置列名 + # 将列名改为"X月",并补充年收益列 month_cols = [f"{x}月" for x in monthly.columns] monthly.columns = month_cols monthly["年收益"] = monthly.sum(axis=1) - # 计算统计指标 + # 月度胜率、盈亏比、平均收益 win_rate = monthly.apply(lambda x: (x > 0).sum() / len(x), axis=0) + # 月度亏损总额为 0 时,盈亏比记为 10(一个表示"非常好"的占位值) ykb = monthly.apply(lambda x: x[x > 0].sum() / -x[x < 0].sum() if min(x) < 0 else 10, axis=0) mean_ret = monthly.mean(axis=0) - # 应用样式 + # 月度矩阵着色 monthly_styled = monthly.style.background_gradient(cmap="RdYlGn_r", axis=None, subset=month_cols) monthly_styled = monthly_styled.background_gradient(cmap="RdYlGn_r", axis=None, subset=["年收益"]) monthly_styled = monthly_styled.format("{:.2%}", na_rep="-") st.dataframe(monthly_styled, width="stretch") - # 显示统计信息 + # 月度统计指标 dfy = pd.DataFrame([win_rate, ykb, mean_ret], index=["胜率", "盈亏比", "平均收益"]) dfy_styled = dfy.style.background_gradient(cmap="RdYlGn_r", axis=1).format("{:.2%}", na_rep="-") st.dataframe(dfy_styled, width="stretch") @@ -203,17 +234,22 @@ def show_monthly_return(df, ret_col="total", sub_title="月度累计收益", **k def show_drawdowns(df: pd.DataFrame, ret_col, key=None, **kwargs): """展示最大回撤分析 - :param df: pd.DataFrame, columns: cells, index: dates - :param ret_col: str, 回报率列名称 - :param key: str, 可选,组件的唯一标识符,默认自动生成 - :param kwargs: - - sub_title: str, optional, 子标题 - - top: int, optional, 默认10, 返回最大回撤的数量 + 根据日收益重建累计收益与累计最高,绘制回撤曲线(双 Y 轴叠加累计收益),并 + 给出 10% / 30% / 50% 三个分位数辅助线。同时通过 :func:`top_drawdowns` 展示 + Top N 回撤的详细信息(开始时间、结束时间、回撤天数等)。 + + :param df: pd.DataFrame,列包含 ``ret_col``,索引为日期 + :param ret_col: str,回报率列名称 + :param key: str,可选;组件唯一标识符 + :param kwargs: 其他参数 + - sub_title: str,子标题 + - top: int,返回最大回撤的数量,默认 10 + :return: None """ df = ensure_datetime_index(df) df = df[[ret_col]].copy().fillna(0).sort_index(ascending=True) - # 计算回撤数据 + # 计算累计收益、累计最高与回撤 df["cum_ret"] = df[ret_col].cumsum() df["cum_max"] = df["cum_ret"].cummax() df["drawdown"] = df["cum_ret"] - df["cum_max"] @@ -222,10 +258,10 @@ def show_drawdowns(df: pd.DataFrame, ret_col, key=None, **kwargs): if sub_title: st.subheader(sub_title, divider="rainbow") - # 绘制回撤图 + # 双轴绘图:左轴回撤填充,右轴累计收益曲线 fig = go.Figure() - # 回撤曲线 + # 回撤曲线(向下填充) fig.add_trace( go.Scatter( x=df.index, @@ -239,7 +275,7 @@ def show_drawdowns(df: pd.DataFrame, ret_col, key=None, **kwargs): ) ) - # 累计收益曲线(右轴) + # 累计收益曲线(右 Y 轴) fig.add_trace( go.Scatter( x=df.index, y=df["cum_ret"], mode="lines", name="累计收益", yaxis="y2", opacity=0.8, line={"color": "red"} @@ -248,7 +284,7 @@ def show_drawdowns(df: pd.DataFrame, ret_col, key=None, **kwargs): fig.update_layout(yaxis2={"title": "累计收益", "overlaying": "y", "side": "right"}) - # 添加分位数线 + # 加上 10%、30%、50% 三个分位数辅助线 for q in [0.1, 0.3, 0.5]: y1 = df["drawdown"].quantile(q) fig.add_hline(y=y1, line_dash="dot", line_color="green", line_width=1) @@ -263,13 +299,13 @@ def show_drawdowns(df: pd.DataFrame, ret_col, key=None, **kwargs): height=300, ) - # 生成 key + # 自动生成组件 key if key is None: key = generate_component_key(df, prefix="dd", ret_col=ret_col, top=kwargs.get("top", 10)) st.plotly_chart(fig, key=key, width="stretch") - # 显示回撤详情 + # Top N 回撤详情 top = kwargs.get("top", 10) if top is not None: with st.expander(f"TOP{top} 最大回撤详情", expanded=False): @@ -285,9 +321,15 @@ def show_drawdowns(df: pd.DataFrame, ret_col, key=None, **kwargs): def show_rolling_daily_performance(df, ret_col, key=None, **kwargs): """展示滚动统计数据 - :param df: pd.DataFrame, 日收益数据,columns=['dt', ret_col] - :param ret_col: str, 收益列名 - :param key: str, 可选,组件的唯一标识符,默认自动生成 + 在指定窗口(自然日)下,计算日收益的滚动绩效指标(如年化、夏普、最大回撤等), + 并以面积图展示用户选择的指标随时间的变化。 + + :param df: pd.DataFrame,日收益数据;索引为日期,包含 ``ret_col`` 列 + :param ret_col: str,收益列名 + :param key: str,可选;组件唯一标识符 + :param kwargs: 其他参数 + - sub_title: str,子标题 + :return: None """ from czsc.utils.analysis.stats import rolling_daily_performance @@ -298,23 +340,23 @@ def show_rolling_daily_performance(df, ret_col, key=None, **kwargs): if sub_title: st.subheader(sub_title, divider="rainbow", anchor=sub_title) - # 参数设置 + # 用户参数:滚动窗口、最小样本数、绩效指标 c1, c2, c3 = st.columns(3) window = c1.number_input("滚动窗口(自然日)", value=365 * 3, min_value=365, max_value=3650) min_periods = c2.number_input("最小样本数", value=365, min_value=100, max_value=3650) - # 计算滚动绩效 + # 计算滚动绩效,并补充一个"年化波动率/最大回撤"派生指标 dfr = rolling_daily_performance(df, ret_col, window=window, min_periods=min_periods) dfr["年化波动率/最大回撤"] = dfr["年化波动率"] / dfr["最大回撤"] - # 选择指标 + # 用户挑选要展示的指标 cols = [x for x in dfr.columns if x not in ["sdt", "edt"]] col = c3.selectbox("选择指标", cols, index=cols.index("夏普") if "夏普" in cols else 0) - # 绘图 + # 用面积图展示该指标随时间的变化 fig = px.area(dfr, x="edt", y=col, labels={"edt": "", col: col}) - # 生成 key + # 自动生成组件 key if key is None: key = generate_component_key( df, prefix="roll_perf", ret_col=ret_col, col=col, window=window, min_periods=min_periods diff --git a/czsc/svc/statistics.py b/czsc/svc/statistics.py index 0af19165d..ba9083971 100644 --- a/czsc/svc/statistics.py +++ b/czsc/svc/statistics.py @@ -1,7 +1,19 @@ """ -统计分析相关的可视化组件 - -包含分段收益、年度统计、样本内外对比、PSI分析等功能 +统计分析相关的 Streamlit 可视化组件 + +本模块汇集了一组面向"日收益数据"或"通用 DataFrame"的统计分析与展示组件, +主要包括: + +1. :func:`show_splited_daily`:分段展示策略最近 1 周 / 1 月 / 1 年 / 今年以来 / 成立以来等 + 不同时间段的绩效; +2. :func:`show_yearly_stats`:按自然年统计日收益绩效; +3. :func:`show_out_in_compare`:以指定日期为分界,比较样本内外表现; +4. :func:`show_outsample_by_dailys`:基于日收益的样本内外两段或三段对比; +5. :func:`show_psi`:分布稳定性指标 PSI; +6. :func:`show_classify`:单变量分层统计与单调性观察; +7. :func:`show_date_effect`:星期效应与月份效应; +8. :func:`show_normality_check`:正态性检验(Shapiro-Wilk、Jarque-Bera、KS); +9. :func:`show_describe` / :func:`show_df_describe`:DataFrame 描述性统计的着色版本。 """ import numpy as np @@ -11,16 +23,21 @@ from deprecated import deprecated from .base import apply_stats_style, ensure_datetime_index, generate_component_key -from rs_czsc import daily_performance +from wbt import daily_performance def show_splited_daily(df, ret_col, **kwargs): """展示分段日收益表现 - :param df: pd.DataFrame - :param ret_col: str, df 中的列名,指定收益列 - :param kwargs: - sub_title: str, 子标题 + 将日收益数据按"过去 1 周 / 2 周 / 1 月 / 3 月 / 6 月 / 1 年 / 今年以来 / 成立以来" + 等区间切分,分别计算 :func:`daily_performance` 绩效,并以表格展示。 + + :param df: pd.DataFrame,必须包含 ``ret_col`` 列、索引为日期或包含 ``dt`` 列 + :param ret_col: str,指定收益列 + :param kwargs: 其他参数 + - sub_title: str,子标题 + - yearly_days: int,年化天数,默认 252 + :return: None """ yearly_days = kwargs.get("yearly_days", 252) df = ensure_datetime_index(df) @@ -30,6 +47,7 @@ def show_splited_daily(df, ret_col, **kwargs): if sub_title: st.subheader(sub_title, divider="rainbow", anchor=sub_title) + # 以最后一个交易日为锚点构造 8 个时间段 last_dt = df.index[-1] sdt_map = { "过去1周": last_dt - pd.Timedelta(days=7), @@ -62,10 +80,14 @@ def show_splited_daily(df, ret_col, **kwargs): def show_yearly_stats(df, ret_col, **kwargs): """按年计算日收益表现 - :param df: pd.DataFrame,数据源 + 将日收益按自然年分组,分别调用 :func:`daily_performance`,年化天数取年份中 + 最大的一个分组长度,避免不完整年份导致年化指标偏低。 + + :param df: pd.DataFrame,日收益数据 :param ret_col: str,收益列名 - :param kwargs: - - sub_title: str, 子标题 + :param kwargs: 其他参数 + - sub_title: str,子标题 + :return: None """ daily_performance = safe_import_daily_performance() if daily_performance is None: @@ -75,6 +97,7 @@ def show_yearly_stats(df, ret_col, **kwargs): df = df.copy().fillna(0).sort_index(ascending=True) df["年份"] = df.index.year + # 用最长年份的样本数作为 yearly_days,最大化降低年初 / 年末截断带来的偏差 yearly_days = max(len(df_) for year, df_ in df.groupby("年份")) _stats = [] @@ -93,7 +116,18 @@ def show_yearly_stats(df, ret_col, **kwargs): def show_out_in_compare(df, ret_col, mid_dt, **kwargs): - """展示样本内外表现对比""" + """展示样本内外表现对比 + + 以 ``mid_dt`` 为切分点,分别在样本内 / 样本外区间计算 :func:`daily_performance`, + 并把两组绩效拼接成单张表格,方便直接对比。 + + :param df: pd.DataFrame,日收益数据 + :param ret_col: str,收益列名 + :param mid_dt: 样本切分点;样本内为 ``< mid_dt``,样本外为 ``>= mid_dt`` + :param kwargs: 其他参数 + - sub_title: str,子标题 + :return: None + """ daily_performance = safe_import_daily_performance() if daily_performance is None: return @@ -106,11 +140,13 @@ def show_out_in_compare(df, ret_col, mid_dt, **kwargs): dfi = df[df.index < mid_dt].copy() dfo = df[df.index >= mid_dt].copy() + # 样本内 stats_i = daily_performance(dfi[ret_col].to_list()) stats_i["标记"] = "样本内" stats_i["开始日期"] = dfi.index[0].strftime("%Y-%m-%d") stats_i["结束日期"] = dfi.index[-1].strftime("%Y-%m-%d") + # 样本外 stats_o = daily_performance(dfo[ret_col].to_list()) stats_o["标记"] = "样本外" stats_o["开始日期"] = dfo.index[0].strftime("%Y-%m-%d") @@ -140,7 +176,7 @@ def show_out_in_compare(df, ret_col, mid_dt, **kwargs): if sub_title: st.subheader(sub_title, divider="rainbow") - # 应用样式 + # 着色:正向指标用反向 RdYlGn;负向指标用正向 RdYlGn df_stats_styled = df_stats.style.background_gradient(cmap="RdYlGn_r", subset=["年化"]) df_stats_styled = df_stats_styled.background_gradient(cmap="RdYlGn_r", subset=["夏普"]) df_stats_styled = df_stats_styled.background_gradient(cmap="RdYlGn", subset=["最大回撤"]) @@ -173,9 +209,14 @@ def show_out_in_compare(df, ret_col, mid_dt, **kwargs): def show_outsample_by_dailys(df, outsample_sdt1, outsample_sdt2=None): """根据日收益数据展示样本内外对比 - :param df: 日收益数据,包含列 ['dt', 'returns'] + 支持两种模式: + - 仅传入 ``outsample_sdt1``:分为"样本内 / 样本外"两段; + - 同时传入 ``outsample_sdt2``:分为"研究阶段样本内 / 研究阶段样本外 / 系统跟踪样本外"三段。 + + :param df: pd.DataFrame,必须包含 ``['dt', 'returns']`` 两列 :param outsample_sdt1: 样本外开始日期 - :param outsample_sdt2: 实盘开始跟踪的日期,如果为 None,则只展示样本内和样本外两个阶段 + :param outsample_sdt2: 实盘开始跟踪的日期;为 ``None`` 则只展示两段 + :return: None """ from czsc.eda import cal_yearly_days @@ -192,10 +233,11 @@ def show_outsample_by_dailys(df, outsample_sdt1, outsample_sdt2=None): outsample_sdt1 = pd.to_datetime(outsample_sdt1).strftime("%Y-%m-%d") def __show_returns(dfx): + """单段展示:核心指标 + 累计收益曲线""" stats = daily_performance(dfx["returns"], yearly_days=yearly_days) sc1, sc2, sc3 = st.columns(3) - # 绘制收益指标 + # 9 个核心指标分 3 列展示 sc1.metric("年化收益率", f"{stats['年化']:.2%}") sc1.metric("夏普比率", f"{stats['夏普']:.2f}") sc1.metric("新高占比", f"{stats['新高占比']:.2%}") @@ -222,7 +264,7 @@ def __show_returns(dfx): df1 = df[df["dt"] < outsample_sdt1].copy() # 样本内 df2 = df[(df["dt"] >= outsample_sdt1) & (df["dt"] < outsample_sdt2)].copy() # 第一段样本外 - df3 = df[df["dt"] >= outsample_sdt2].copy() # 第二段样本外 + df3 = df[df["dt"] >= outsample_sdt2].copy() # 第二段样本外(系统跟踪) c1, c2, c3 = st.columns(3) @@ -253,13 +295,17 @@ def __show_returns(dfx): def show_psi(df, factor, segment, **kwargs): - """PSI分布稳定性 + """PSI 分布稳定性 + + PSI(Population Stability Index)用于衡量分组因子在不同分段下的分布稳定性, + 数值越大代表分布差异越显著。 - :param df: pd.DataFrame, 数据源 - :param factor: str, 分组因子 - :param segment: str, 分段字段 - :param kwargs: - - sub_title: str, 子标题 + :param df: pd.DataFrame,数据源 + :param factor: str,分组因子 + :param segment: str,分段字段 + :param kwargs: 其他参数 + - sub_title: str,子标题 + :return: None """ from czsc.utils.analysis.stats import psi @@ -278,16 +324,21 @@ def show_psi(df, factor, segment, **kwargs): def show_classify(df, col1, col2, n=10, method="cut", key=None, **kwargs): - """显示 col1 对 col2 的分类作用 - - :param df: 数据,pd.DataFrame - :param col1: 分层列 - :param col2: 统计列 - :param n: 分层数量 - :param method: 分层方法,cut 或 qcut - :param key: str, 可选,组件的唯一标识符,默认自动生成 - :param kwargs: - - show_bar: bool, 是否展示柱状图,默认为 False + """显示 ``col1`` 对 ``col2`` 的分类作用 + + 将 ``col1`` 按 ``cut``(等距)或 ``qcut``(等频)分层后,对每一层统计 ``col2`` + 的描述性指标,并展示其单调性、首末层均值等关键信息。 + + :param df: pd.DataFrame,数据源 + :param col1: str,分层列 + :param col2: str,统计列 + :param n: int,分层数量 + :param method: str,分层方法,``"cut"`` 或 ``"qcut"`` + :param key: str,可选;组件唯一标识符 + :param kwargs: 其他参数 + - show_bar: bool,是否展示均值柱状图,默认 False + :return: None + :raises ValueError: 当 method 不在 ``{"cut", "qcut"}`` 时 """ import czsc @@ -302,6 +353,7 @@ def show_classify(df, col1, col2, n=10, method="cut", key=None, **kwargs): dfg = df.groupby(f"{col1}_分层", observed=True)[col2].describe().reset_index() dfx = dfg.copy() + # 用单调性、首末层均值描述分层效果 info = ( f"{col1} 分层对应 {col2} 的均值单调性::red[{czsc.monotonicity(dfx['mean']):.2%}]; " f"最后一层的均值::red[{dfx['mean'].iloc[-1]:.4f}];" @@ -316,7 +368,7 @@ def show_classify(df, col1, col2, n=10, method="cut", key=None, **kwargs): fig.update_xaxes(title=None) fig.update_layout(margin={"l": 0, "r": 0, "t": 0, "b": 0}) - # 生成 key + # 自动生成组件 key if key is None: key = generate_component_key(df, prefix="classify", col1=col1, col2=col2, n=n, method=method) @@ -343,12 +395,15 @@ def show_classify(df, col1, col2, n=10, method="cut", key=None, **kwargs): def show_date_effect(df: pd.DataFrame, ret_col: str, **kwargs): """分析日收益数据的日历效应 - :param df: pd.DataFrame, 包含日期的日收益数据 - :param ret_col: str, 收益列名称 - :param kwargs: dict, 其他参数 - - show_weekday: bool, 是否展示星期效应,默认为 True - - show_month: bool, 是否展示月份效应,默认为 True - - percentiles: list, 分位数,默认为 [0.1, 0.25, 0.5, 0.75, 0.9] + 分别按"星期几"与"月份"对日收益做 describe 统计,观察是否存在显著的日历效应。 + + :param df: pd.DataFrame,包含日期索引或 dt 列的日收益数据 + :param ret_col: str,收益列名 + :param kwargs: 其他参数 + - show_weekday: bool,是否展示星期效应,默认 True + - show_month: bool,是否展示月份效应,默认 True + - percentiles: list,分位数,默认 [0.1, 0.25, 0.5, 0.75, 0.9] + :return: None """ show_weekday = kwargs.get("show_weekday", True) show_month = kwargs.get("show_month", True) @@ -393,8 +448,12 @@ def show_date_effect(df: pd.DataFrame, ret_col: str, **kwargs): def show_normality_check(data: pd.Series, alpha=0.05): """展示正态性检验结果 - :param data: pd.Series, 需要检验的数据 - :param alpha: float, 显著性水平,默认为 0.05 + 依次完成 Shapiro-Wilk、Jarque-Bera、Kolmogorov-Smirnov 三种检验,并附带绘制 + 直方图(叠加正态密度曲线)与 Q-Q 图。 + + :param data: pd.Series,需要检验的数据 + :param alpha: float,显著性水平,默认 0.05 + :return: None """ import matplotlib.pyplot as plt import seaborn as sns @@ -404,6 +463,7 @@ def show_normality_check(data: pd.Series, alpha=0.05): clean_data = data.dropna() def __metric(s, p): + """以 3 列形式展示统计量、P 值与是否拒绝原假设""" m1, m2, m3 = st.columns(3) m1.metric(label="统计量", value=f"{s:.3f}", border=False) m2.metric(label="P值", value=f"{p:.1%}", border=False) @@ -426,17 +486,20 @@ def __metric(s, p): stat, p_ks = kstest(clean_data, "norm", args=(mu, std)) __metric(stat, p_ks) + # matplotlib 中文负号修复 + 主题 plt.rcParams["axes.unicode_minus"] = False plt.style.use("ggplot") fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5)) + # 直方图 + 正态密度曲线 sns.histplot(clean_data, kde=True, stat="density", ax=ax1) x = np.linspace(mu - 4 * std, mu + 4 * std, 100) ax1.plot(x, norm.pdf(x, mu, std), "r", lw=2) ax1.set_title(f"Histogram => SKEW: {clean_data.skew():.2f}, KURT: {clean_data.kurt():.2f}") ax1.legend(["Normal PDF", "Data"]) + # Q-Q 图 sm.qqplot(clean_data, line="45", fit=True, ax=ax2) ax2.set_title("Q-Q") st.pyplot(fig) @@ -446,7 +509,14 @@ def __metric(s, p): def show_describe(df: pd.DataFrame, **kwargs): """展示 DataFrame 的描述性统计信息 - :param df: pd.DataFrame, 数据框 + 比 :func:`show_df_describe` 多了"偏度""峰度"以及自定义分位数和小数位数控制。 + + :param df: pd.DataFrame,数据源 + :param kwargs: 其他参数 + - columns: list,参与统计的列名,默认 df 的全部列 + - percentiles: list,分位数列表,默认 [0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95] + - digits: int,统计值保留小数位数,默认 2 + :return: None """ columns = kwargs.get("columns") percentiles = kwargs.get("percentiles", [0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95]) @@ -466,6 +536,7 @@ def show_describe(df: pd.DataFrame, **kwargs): df_styled = df_styled.background_gradient(cmap="RdYlGn_r", axis=None, subset=["偏度"]) df_styled = df_styled.background_gradient(cmap="RdYlGn_r", axis=None, subset=["峰度"]) + # 根据 digits 动态构造格式化字符串 format_dict = { "count": "{:.0f}", "mean": f"{{:.{digits}f}}", @@ -487,9 +558,10 @@ def show_describe(df: pd.DataFrame, **kwargs): @deprecated(reason="建议直接使用 show_describe 函数") def show_df_describe(df: pd.DataFrame): - """展示 DataFrame 的描述性统计信息 + """展示 DataFrame 的描述性统计信息(旧版,已弃用) - :param df: pd.DataFrame,必须是 df.describe() 的结果 + :param df: pd.DataFrame,必须是 ``df.describe()`` 的结果 + :return: None """ quantiles = [x for x in df.columns if "%" in x] df_styled = df.style.background_gradient(cmap="RdYlGn_r", axis=None, subset=["mean"]) diff --git a/czsc/svc/strategy.py b/czsc/svc/strategy.py index 68371be25..25d5df36d 100644 --- a/czsc/svc/strategy.py +++ b/czsc/svc/strategy.py @@ -1,11 +1,21 @@ """ -策略分析组件模块 - -该模块提供策略分析相关的 Streamlit 可视化组件,包括: -- 优化结果展示 -- 策略收益分析 -- 组合表现分析 -- 风险分析等 +策略分析与可视化组件模块 + +本模块汇总了策略层面的常用 Streamlit 可视化组件,覆盖以下场景: + +1. :func:`show_optuna_study`:展示 Optuna 调参的可视化结果与最佳参数表; +2. :func:`show_czsc_trader`:展示 ``CzscTrader`` 的多周期 K 线、分型、笔以及交易信号; +3. :func:`show_strategies_recent`:展示多个策略最近 N 天的收益对比; +4. :func:`show_returns_contribution`:分析子策略对组合总收益的贡献; +5. :func:`show_symbols_bench`:展示多品种等权基准与品种间相关性; +6. :func:`show_quarterly_effect`:展示四个季度分别的累计收益与绩效; +7. :func:`show_multi_backtest`:多策略回测结果的统一对比表; +8. :func:`show_cta_periods_classify` / :func:`show_volatility_classify`:按 CTA 行情阶段 + 或波动率分类的回测对比; +9. :func:`show_portfolio`:组合日收益绩效的综合展示; +10. :func:`show_turnover_rate`:换手率的多维度展示; +11. :func:`show_stats_compare`:多组策略回测的绩效对比; +12. :func:`show_symbol_penalty`:依次剔除收益最高的 N 个品种,对比收益变化。 作者: 缠中说禅团队 """ @@ -18,20 +28,21 @@ import streamlit as st from .base import apply_stats_style, generate_component_key -from rs_czsc import WeightBacktest +from wbt import WeightBacktest def show_optuna_study(study, key=None, **kwargs): """展示 Optuna Study 的可视化结果 - :param study: optuna.study.Study, Optuna Study 对象 - :param key: str, 可选,组件的基础标识符,每个图表会自动添加后缀 - :param kwargs: dict, 其他参数 - - - sub_title: str, optional, 子标题 - - keep: float, optional, 默认0.2, 保留最佳参数的比例 + 依次绘制 Optuna 的 ``contour`` 与 ``slice`` 图,并展示 ``optuna_good_params`` + 输出的最佳参数列表。 - :return: optuna.study.Study + :param study: optuna.study.Study,Optuna Study 对象 + :param key: str,可选;组件基础标识符,每个图表会自动追加后缀 + :param kwargs: 其他参数 + - sub_title: str,子标题 + - keep: float,保留最佳参数的比例,默认 0.2 + :return: optuna.study.Study;原样返回,便于链式调用 """ try: import optuna @@ -39,23 +50,27 @@ def show_optuna_study(study, key=None, **kwargs): st.error("请安装 optuna 库, 执行命令:pip install optuna") return + # Optuna 可视化文档: # https://optuna.readthedocs.io/en/stable/reference/visualization/index.html # https://zh-cn.optuna.org/reference/visualization.html from czsc.utils.optuna import optuna_good_params sub_title = kwargs.pop("sub_title", "Optuna Study Visualization") if sub_title: + # 为 anchor 生成稳定且短小的 hash anchor = hashlib.md5(sub_title.encode("utf-8")).hexdigest().upper()[:6] st.subheader(sub_title, divider="rainbow", anchor=anchor) + # 等高线图 fig = optuna.visualization.plot_contour(study) - # 生成 key + # 自动生成组件 key if key is None: key = generate_component_key(study, prefix="optuna", sub_title=sub_title) st.plotly_chart(fig, key=f"{key}_contour", width="stretch") + # 切片图 fig = optuna.visualization.plot_slice(study) st.plotly_chart(fig, key=f"{key}_slice", width="stretch") @@ -68,10 +83,15 @@ def show_optuna_study(study, key=None, **kwargs): def show_czsc_trader(trader, max_k_num=300, key=None, **kwargs): """显示缠中说禅交易员详情 + 将 ``CzscTrader`` 中的多周期 K 线、分型、笔、均线、成交量、MACD、交易信号统一 + 渲染到 Tab 页中;最后一个 Tab 展示策略详情(最新信号 + 各 Position 的 JSON)。 + :param trader: CzscTrader 对象 - :param max_k_num: 最大显示 K 线数量 - :param key: str, 可选,组件的基础标识符,每个图表会自动添加后缀 + :param max_k_num: int,每个周期最多显示多少根 K 线,默认 300 + :param key: str,可选;组件基础标识符 :param kwargs: 其他参数 + - sub_title: str,子标题 + :return: None """ import czsc from czsc.utils.ta import MACD @@ -90,6 +110,7 @@ def show_czsc_trader(trader, max_k_num=300, key=None, **kwargs): for freq, tab in zip(freqs, tabs[:-1], strict=False): c = trader.kas[freq] + # 仅显示最近 max_k_num 根 K 线,避免数据量过大造成前端卡顿 sdt = c.bars_raw[-max_k_num].dt if len(c.bars_raw) > max_k_num else c.bars_raw[0].dt df = pd.DataFrame(c.bars_raw) df["DIFF"], df["DEA"], df["MACD"] = MACD(df["close"], fastperiod=12, slowperiod=26, signalperiod=9) @@ -98,6 +119,7 @@ def show_czsc_trader(trader, max_k_num=300, key=None, **kwargs): kline = czsc.KlineChart(n_rows=3, row_heights=(0.5, 0.3, 0.2), title="", width="100%", height=800) kline.add_kline(df, name="") + # 叠加分型与笔 if len(c.bi_list) > 0: bi = pd.DataFrame( [{"dt": x.fx_a.dt, "bi": x.fx_a.fx} for x in c.bi_list] @@ -123,7 +145,7 @@ def show_czsc_trader(trader, max_k_num=300, key=None, **kwargs): kline.add_vol(df, row=2, line_width=1) kline.add_macd(df, row=3, line_width=1) - # 在基础周期上绘制交易信号 + # 在基础周期上叠加交易信号 if freq == trader.base_freq: for pos in trader.positions: bs_df = pd.DataFrame([x for x in pos.operates if x["dt"] >= sdt]) @@ -131,6 +153,7 @@ def show_czsc_trader(trader, max_k_num=300, key=None, **kwargs): continue open_ops = [czsc.Operate.LO, czsc.Operate.SO] + # 开仓用上三角,平仓用下三角;颜色区分多空 bs_df["tag"] = bs_df["op"].apply(lambda x: "triangle-up" if x in open_ops else "triangle-down") bs_df["color"] = bs_df["op"].apply(lambda x: "red" if x in open_ops else "white") @@ -149,6 +172,7 @@ def show_czsc_trader(trader, max_k_num=300, key=None, **kwargs): ) with tab: + # plotly 工具栏配置:开启滚轮缩放,去掉一些不常用按钮 config = { "scrollZoom": True, "displayModeBar": True, @@ -165,7 +189,7 @@ def show_czsc_trader(trader, max_k_num=300, key=None, **kwargs): ], } - # 生成 key + # 自动生成组件 key if key is None: key = generate_component_key(trader, prefix="czsc_trader", freq=freq, max_k_num=max_k_num) @@ -174,6 +198,7 @@ def show_czsc_trader(trader, max_k_num=300, key=None, **kwargs): with tabs[-1]: with st.expander("查看最新信号", expanded=False): if len(trader.s): + # 仅展示形如 "k1_k2_k3" 的标准信号(key 由三段组成) s = {k: v for k, v in trader.s.items() if len(k.split("_")) == 3} st.write(s) else: @@ -187,7 +212,10 @@ def show_czsc_trader(trader, max_k_num=300, key=None, **kwargs): def show_strategies_recent(df, **kwargs): """展示最近 N 天的策略表现 - :param df: pd.DataFrame, columns=['dt', 'strategy', 'returns'], 样例如下: + 输入数据应为长表 ``[dt, strategy, returns]``,按策略累积近 N 天的收益, + 再透视成"策略 × 时间窗"的对比矩阵;同时统计每个时间窗的盈利策略数量与比例。 + + :param df: pd.DataFrame,columns = ['dt', 'strategy', 'returns'],样例如下: =================== ========== ============ dt strategy returns @@ -199,15 +227,16 @@ def show_strategies_recent(df, **kwargs): 2021-01-08 00:00:00 STK001 0.000510725 =================== ========== ============ - :param kwargs: dict - - - nseq: tuple, optional, 默认为 (1, 3, 5, 10, 20, 30, 60, 90, 120, 180, 240, 360),展示的天数序列 + :param kwargs: 其他参数 + - nseq: tuple,展示的天数序列,默认 (1, 3, 5, 10, 20, 30, 60, 90, 120, 180, 240, 360) + :return: None """ nseq = kwargs.get("nseq", (1, 3, 5, 10, 20, 30, 60, 90, 120, 180, 240, 360)) dfr = df.copy() dfr = pd.pivot_table(dfr, index="dt", columns="strategy", values="returns", aggfunc="sum").fillna(0) rows = [] for n in nseq: + # 取最近 n 天每个策略的累计收益 for k, v in dfr.iloc[-n:].sum(axis=0).to_dict().items(): rows.append({"天数": f"近{n}天", "策略": k, "收益": v}) @@ -220,7 +249,7 @@ def show_strategies_recent(df, **kwargs): hide_index=False, ) - # 计算每个时间段的盈利策略数量 + # 每个时间窗内的盈利策略数量与比例 win_count = n_rets.map(lambda x: 1 if x > 0 else 0).sum(axis=0) win_rate = n_rets.map(lambda x: 1 if x > 0 else 0).sum(axis=0) / n_rets.shape[0] dfs = pd.DataFrame({"盈利策略数量": win_count, "盈利策略比例": win_rate}).T @@ -232,10 +261,13 @@ def show_strategies_recent(df, **kwargs): def show_returns_contribution(df, returns=None, max_returns=100, key=None): """分析子策略对总收益的贡献 - :param df: pd.DataFrame, 子策略日收益数据,index 为 datetime, columns 为 子策略名称 - :param returns: list, 子策略名称列表 - :param max_returns: int, 最大展示策略数量 - :param key: str, 可选,组件的基础标识符,每个图表会自动添加后缀 + 左侧绘制柱状图(包含正负贡献),右侧绘制饼图(仅展示正贡献策略的占比)。 + + :param df: pd.DataFrame,子策略日收益数据;index 为 datetime,columns 为子策略名称 + :param returns: list,子策略名称列表;为 None 时使用 df 的全部列 + :param max_returns: int,最大展示策略数量,超过时给出提示 + :param key: str,可选;组件基础标识符 + :return: None """ df = df.copy() for dt_col in ["date", "dt"]: @@ -250,18 +282,18 @@ def show_returns_contribution(df, returns=None, max_returns=100, key=None): st.warning(f"请选择多个策略进行对比,或者选择少于{max_returns} 个策略; 当前选择 {len(returns)} 个策略") return - # 计算每个策略的总收益贡献 + # 每个策略的总收益贡献 total_returns = df[returns].sum(axis=0) - # 创建用于绘图的数据框 + # 构造绘图数据 plot_df = pd.DataFrame({"策略": total_returns.index, "收益贡献": total_returns.values}) plot_df = plot_df.sort_values(by="收益贡献", ascending=False) - # 创建两列布局 + # 左大右小的两列布局 col1, col2 = st.columns([3, 2]) with col1.container(border=True): - # 绘制柱状图 + # 收益贡献柱状图 fig_bar = px.bar( plot_df, x="策略", @@ -274,7 +306,7 @@ def show_returns_contribution(df, returns=None, max_returns=100, key=None): ) fig_bar.update_layout(yaxis_title="绝对收益", xaxis_title="策略") - # 生成 key + # 自动生成组件 key if key is None: key = generate_component_key(df, prefix="ret_contrib", returns=returns, max_returns=max_returns) @@ -282,7 +314,7 @@ def show_returns_contribution(df, returns=None, max_returns=100, key=None): st.caption("柱状图展示每个策略的收益贡献, Y轴为绝对收益大小,X轴为策略名称") with col2.container(border=True): - # 绘制饼图,如果收益贡献为负,删除 + # 仅保留正贡献策略,绘制饼图 plot_df = plot_df[plot_df["收益贡献"] > 0] fig_pie = px.pie(plot_df, values="收益贡献", names="策略", title="盈利贡献分析(饼图)", width=600, height=400) fig_pie.update_traces(textposition="inside", textinfo="percent+label") @@ -293,17 +325,19 @@ def show_returns_contribution(df, returns=None, max_returns=100, key=None): def show_symbols_bench(df: pd.DataFrame, **kwargs): """展示多个品种的基准收益相关信息 - :param df: pd.DataFrame, 数据源, 包含symbol, dt, price 字段, 其他字段将被忽略 - symbol: 品种代码 - dt: 时间 - price: 交易价格 + 构造"品种等权"基准:先按 symbol 计算日内收益,再按日均值得到等权日收益; + 展示其核心绩效、最大回撤以及品种间日收益的相关性矩阵。 + + :param df: pd.DataFrame,必须包含 ``symbol``、``dt``、``price`` 三列;其他列将被忽略 :param kwargs: 其他参数 - - use_st_table: bool, 是否使用 st.table 展示相关性矩阵, 默认为 False + - use_st_table: bool,是否使用 ``st.table`` 展示相关性矩阵,默认 False + :return: None """ - from rs_czsc import daily_performance + from wbt import daily_performance from czsc.eda import cal_yearly_days df = df[["symbol", "dt", "price"]].copy() + # 按 symbol 计算价格的日内 pct_change df["pct_change"] = df.groupby("symbol")["price"].pct_change() df["date"] = df["dt"].dt.date dailys = df.groupby(["symbol", "date"])["pct_change"].sum().reset_index() @@ -313,6 +347,7 @@ def show_symbols_bench(df: pd.DataFrame, **kwargs): with st.container(border=True): st.markdown("##### 品种等权累计收益&最大回撤") + # 品种等权基准 dailys["total"] = dailys.mean(axis=1) dailys.index = pd.to_datetime(dailys.index) @@ -343,25 +378,30 @@ def show_symbols_bench(df: pd.DataFrame, **kwargs): def show_quarterly_effect(returns: pd.Series, key=None): """展示策略的季节性收益对比 - :param returns: 日收益率序列,index 为日期 - :param key: str, 可选,组件的基础标识符,每个图表会自动添加后缀 + 将日收益按"自然季度"分成四组,分别展示每组的核心绩效与累计收益曲线(按数字 + 索引绘制,便于跨年份对比),并以彩色矩形覆盖标注不同年份的区间。 + + :param returns: pd.Series,日收益率序列;index 为日期 + :param key: str,可选;组件基础标识符 + :return: None """ import plotly.express as px from czsc.eda import cal_yearly_days - from rs_czsc import daily_performance + from wbt import daily_performance returns.index = pd.to_datetime(returns.index) yearly_days = cal_yearly_days(returns.index.to_list()) - # 按4个季度划分数据为 s1,s2,s3,s4,分别计算每个季度的统计指标 + # 按季度切分收益 s1 = returns[returns.index.quarter == 1] s2 = returns[returns.index.quarter == 2] s3 = returns[returns.index.quarter == 3] s4 = returns[returns.index.quarter == 4] def __show_quarter_stats(s: pd.Series, quarter_name: str): + """展示单个季度的绩效与累计收益曲线""" stats = daily_performance(s.to_list(), yearly_days=yearly_days) st.markdown( f"总交易天数: `{len(s)}天` \ @@ -372,21 +412,20 @@ def __show_quarter_stats(s: pd.Series, quarter_name: str): | 年化波动率: `{stats['年化波动率']:.2%}`" ) - # 用 plotly 绘制累计收益率曲线, 用 数字作为index,方便对比 + # 用 plotly 绘制累计收益曲线,X 轴使用数字索引,便于跨年份对比 fig = px.line(s.cumsum(), x=list(range(len(s))), y=s.cumsum().values, title="") fig.update_layout(xaxis_title="交易天数", yaxis_title="累计收益率", margin={"l": 0, "r": 0, "t": 0, "b": 0}) - # 按年分组,绘制矩形覆盖 + # 按年份分组,用半透明矩形覆盖区分 years = s.index.year.unique() colors = ["rgba(102, 255, 178, 0.1)", "rgba(102, 178, 255, 0.1)"] # 薄荷绿和天蓝色,半透明 shapes = [] for i, year in enumerate(years): - # 获取该年的第一个和最后一个交易日在数字索引中的位置 + # 拿到该年第一个/最后一个交易日在数字索引中的位置 year_data = s[s.index.year == year] start_idx = s.index.get_indexer([year_data.index[0]])[0] end_idx = s.index.get_indexer([year_data.index[-1]])[0] - # 添加矩形 shapes.append( { "type": "rect", @@ -403,10 +442,9 @@ def __show_quarter_stats(s: pd.Series, quarter_name: str): } ) - # 更新图表布局,添加矩形 fig.update_layout(shapes=shapes) - # 添加年份标签 + # 在每段年份矩形的中部加上年份标签 annotations = [] for _, year in enumerate(years): year_data = s[s.index.year == year] @@ -426,7 +464,7 @@ def __show_quarter_stats(s: pd.Series, quarter_name: str): fig.update_layout(annotations=annotations) - # 生成 key + # 自动生成组件 key if key is None: key_base = generate_component_key(returns, prefix="quarterly") @@ -452,7 +490,16 @@ def __show_quarter_stats(s: pd.Series, quarter_name: str): def show_multi_backtest(wbs: dict, **kwargs): - """展示多个策略的回测结果""" + """展示多个策略的回测结果 + + 将多个 ``WeightBacktest`` 对象的核心绩效汇总成一张对比表,并通过 + :func:`show_cumulative_returns` 绘制累计收益曲线。 + + :param wbs: dict,``{策略名: WeightBacktest}`` 映射 + :param kwargs: 其他参数 + - show_describe: bool,是否额外展示几个核心指标的分布,默认 False + :return: tuple[pd.DataFrame, pd.DataFrame],绩效对比表与日收益宽表 + """ from czsc.svc.base import apply_stats_style from czsc.svc.returns import show_cumulative_returns from czsc.svc.statistics import show_describe @@ -466,7 +513,7 @@ def show_multi_backtest(wbs: dict, **kwargs): stats.update(wb.stats) rows.append(stats) - # 获取日收益 + # 取该策略的日收益序列(仅保留 total),便于横向合并 daily = wb.daily_return.copy()[["date", "total"]] daily["strategy"] = strategy daily["return"] = daily["total"] @@ -474,7 +521,7 @@ def show_multi_backtest(wbs: dict, **kwargs): dailys.append(daily[["dt", "strategy", "return"]]) df_stats = pd.DataFrame(rows) - # st.write(df_stats.columns.to_list()) + # 展示用列;按业务习惯排序 cols = [ "策略名称", "开始日期", @@ -506,6 +553,7 @@ def show_multi_backtest(wbs: dict, **kwargs): st.dataframe(apply_stats_style(df_stats)) + # 多策略累计收益对比 dailys = pd.concat(dailys, axis=0) dailys["dt"] = pd.to_datetime(dailys["dt"]) dailys = dailys.sort_values("dt", ascending=False) @@ -516,7 +564,7 @@ def show_multi_backtest(wbs: dict, **kwargs): if kwargs.get("show_describe", False): with st.container(border=True): st.markdown("#### :red[主要统计指标分布]") - # 绘制单笔收益、持仓K线数、夏普的分布 + # 展示单笔收益、持仓K线数、夏普、年化的分布 show_describe(df_stats[["单笔收益", "持仓K线数", "夏普", "年化"]]) return df_stats, df_dailys @@ -525,16 +573,19 @@ def show_multi_backtest(wbs: dict, **kwargs): def show_cta_periods_classify(df: pd.DataFrame, **kwargs): """展示不同市场环境下的策略表现 - :param df: 标准K线数据, - 必须包含 dt, symbol, open, close, high, low, vol, amount, weight, price 列; - 如果 price 列不存在,则使用 close 列 - :param kwargs: + 通过 ``mark_cta_periods`` 把行情按"最佳/最差/普通"以及"上涨/下跌"打标签, + 再分别将权重在不同标签下保留 / 置 0,构造多组回测进行对比。 + :param df: pd.DataFrame,标准 K 线数据,必须包含 + ``dt, symbol, open, close, high, low, vol, amount, weight, price`` 列; + 若 ``price`` 列不存在,则使用 ``close`` 列代替 + :param kwargs: 其他参数 - fee_rate: 手续费率 - digits: 小数位数 - weight_type: 权重类型 - - q1: 最容易赚钱的笔的占比, mark_cta_periods 函数的参数 - - q2: 最难赚钱的笔的占比, mark_cta_periods 函数的参数 + - q1: 最容易赚钱的笔的占比,传给 ``mark_cta_periods`` + - q2: 最难赚钱的笔的占比,传给 ``mark_cta_periods`` + :return: None """ from czsc.eda import cal_yearly_days @@ -544,6 +595,7 @@ def show_cta_periods_classify(df: pd.DataFrame, **kwargs): yearly_days = cal_yearly_days(df["dt"].unique().tolist()) + # 期望的标签列;若用户未预先打标,则在内部调用 mark_cta_periods 自动生成 mark_cols = [ "is_best_period", "is_best_up_period", @@ -565,6 +617,7 @@ def show_cta_periods_classify(df: pd.DataFrame, **kwargs): if "price" not in dfs.columns: dfs["price"] = dfs["close"] + # 行情结构占比统计 p1 = dfs["is_best_period"].value_counts()[1] / len(dfs) p1_up = dfs["is_best_up_period"].value_counts()[1] / len(dfs) p1_down = dfs["is_best_down_period"].value_counts()[1] / len(dfs) @@ -578,6 +631,7 @@ def show_cta_periods_classify(df: pd.DataFrame, **kwargs): st.caption(f"WeightBacktest 参数:fee_rate={fee_rate}, digits={digits}, weight_type={weight_type}") wb_cols = ["dt", "symbol", "weight", "price"] + # None 表示原始策略;其余每个 flag 表示"仅保留该类行情的权重" period_flags = [ None, "is_best_period", @@ -594,6 +648,7 @@ def show_cta_periods_classify(df: pd.DataFrame, **kwargs): for flag, classify_ in zip(period_flags, classify, strict=False): df_tmp = dfs.copy() if flag: + # 非目标行情区间的权重置 0 df_tmp["weight"] = np.where(df_tmp[flag], df_tmp["weight"], 0) wb = WeightBacktest( df_tmp[wb_cols], @@ -610,26 +665,25 @@ def show_cta_periods_classify(df: pd.DataFrame, **kwargs): def show_volatility_classify(df: pd.DataFrame, kind="ts", **kwargs): """【后验,有未来信息,不能用于实盘】波动率分类回测 - :param df: 标准K线数据, - 必须包含 dt, symbol, open, close, high, low, vol, amount, weight, price 列; - 如果 price 列不存在,则使用 close 列 - :param kwargs: - - - fee_rate: 手续费率,WeightBacktest 的参数 - - digits: 小数位数,WeightBacktest 的参数 - - weight_type: 权重类型,'ts' 表示时序,'cs' 表示截面,WeightBacktest 的参数 - - kind: 波动率分类方式,'ts' 表示时序,'cs' 表示截面,mark_volatility 函数的参数 - - window: 计算波动率的窗口,mark_volatility 函数的参数 - - q1: 波动率最大的K线数量占比,默认 0.3,mark_volatility 函数的参数 - - q2: 波动率最小的K线数量占比,默认 0.3,mark_volatility 函数的参数 + 使用 ``mark_volatility`` 给每根 K 线打"高/中/低"波动率标签,再分别将权重 + 限制在某一类波动率内回测。注意该函数依赖未来信息,仅用于研究分析。 + :param df: pd.DataFrame,标准 K 线数据,必须包含 + ``dt, symbol, open, close, high, low, vol, amount, weight, price`` 列; + 若 ``price`` 列不存在,则使用 ``close`` 列代替 + :param kind: 波动率分类方式,``'ts'`` 表示时序,``'cs'`` 表示截面 + :param kwargs: 其他参数 + - fee_rate: 手续费率,``WeightBacktest`` 参数 + - digits: 小数位数,``WeightBacktest`` 参数 + - weight_type: 权重类型,``'ts'`` 时序 / ``'cs'`` 截面 + - kind: 波动率分类方式(同 ``kind``),传给 ``mark_volatility`` + - window: 计算波动率的窗口 + - q1: 高波动率 K 线数量占比,默认 0.3 + - q2: 低波动率 K 线数量占比,默认 0.3 :return: None - - ============== - example - ============== - >>> show_volatility_classify(df, fee_rate=0.00, digits=1, weight_type='ts', - >>> kind='ts', window=20, q1=0.2, q2=0.2 ) + :example: + >>> show_volatility_classify(df, fee_rate=0.00, digits=1, weight_type='ts', + >>> kind='ts', window=20, q1=0.2, q2=0.2 ) """ from czsc.eda import cal_yearly_days @@ -654,12 +708,14 @@ def show_volatility_classify(df: pd.DataFrame, kind="ts", **kwargs): if "price" not in dfs.columns: dfs["price"] = dfs["close"] + # 三类波动率的占比 p1 = dfs["is_max_volatility"].value_counts()[1] / len(dfs) p2 = dfs["is_mid_volatility"].value_counts()[1] / len(dfs) p3 = dfs["is_min_volatility"].value_counts()[1] / len(dfs) st.markdown(f"高波动行情占比::red[{p1:.2%}];中波动行情占比::green[{p2:.2%}];低波动行情占比::blue[{p3:.2%}]") st.caption(f"WeightBacktest 参数:fee_rate={fee_rate}, digits={digits}, weight_type={weight_type}") + # 原始策略 wb = WeightBacktest( dfs[["dt", "symbol", "weight", "price"]], fee_rate=fee_rate, @@ -668,16 +724,19 @@ def show_volatility_classify(df: pd.DataFrame, kind="ts", **kwargs): yearly_days=yearly_days, ) + # 高波动权重组:非高波动区间置 0 df1 = dfs.copy() df1["weight"] = np.where(df1["is_max_volatility"], df1["weight"], 0) df1 = df1[["dt", "symbol", "weight", "price"]].copy().reset_index(drop=True) wb1 = WeightBacktest(df1, fee_rate=fee_rate, digits=digits, weight_type=weight_type, yearly_days=yearly_days) + # 中波动权重组 df2 = dfs.copy() df2["weight"] = np.where(df2["is_mid_volatility"], df2["weight"], 0) df2 = df2[["dt", "symbol", "weight", "price"]].copy().reset_index(drop=True) wb2 = WeightBacktest(df2, fee_rate=fee_rate, digits=digits, weight_type=weight_type, yearly_days=yearly_days) + # 低波动权重组 df3 = dfs.copy() df3["weight"] = np.where(df3["is_min_volatility"], df3["weight"], 0) df3 = df3[["dt", "symbol", "weight", "price"]].copy().reset_index(drop=True) @@ -691,15 +750,21 @@ def show_volatility_classify(df: pd.DataFrame, kind="ts", **kwargs): def show_portfolio(df: pd.DataFrame, portfolio: str, benchmark: str | None = None, **kwargs): """分析组合日收益绩效 - :param df: 日收益数据,包含 dt, portfolio, benchmark 三列, 其中 dt 为日期, portfolio 为组合收益, benchmark 为基准收益 - :param portfolio: 组合名称 - :param benchmark: 基准名称, 可选 - :param show_detail: 是否展示详情, 可选, 默认展示 + 展示组合的核心指标 + 回撤曲线;并在详情中按 Tab 展示年度 / 季度 / 月度绩效, + 若提供基准,还会展示组合相对基准的超额收益分析。 + + :param df: pd.DataFrame,日收益数据;包含 ``dt`` 与 ``portfolio`` 列,可选 ``benchmark`` 列 + :param portfolio: str,组合名称 + :param benchmark: str,基准名称,可选 + :param kwargs: 其他参数 + - show_detail: bool,是否展示详情,默认展示 + :return: None """ - from rs_czsc import daily_performance + from wbt import daily_performance from czsc.eda import cal_yearly_days if benchmark is not None: + # 同时计算超额收益(alpha),便于后续展示 df["alpha"] = df[portfolio] - df[benchmark] df = df[["dt", portfolio, benchmark, "alpha"]].copy() else: @@ -759,15 +824,19 @@ def show_portfolio(df: pd.DataFrame, portfolio: str, benchmark: str | None = Non def show_turnover_rate(df: pd.DataFrame): """显示换手率变化 - :param df: 权重数据,必须包含 dt, symbol, weight 三列, 其他列忽略 + 展示策略的核心换手率指标(单边、日均、最大、最小、近 30 天、近 1 年), + 并辅以 3 张图:日换手累计曲线、月换手柱状图、年换手柱状图。 + + :param df: pd.DataFrame,权重数据;必须包含 ``dt``、``symbol``、``weight`` 三列 + :return: None """ from czsc.eda import turnover_rate res = turnover_rate(df, verbose=True) - dfc = res["日换手详情"] # 两列:dt, change + dfc = res["日换手详情"] # 包含两列:dt, change dfc["dt"] = pd.to_datetime(dfc["dt"]) - # 最近30天换手率 + # 最近 30 天换手率 _sdt_30 = dfc["dt"].max() - pd.Timedelta(days=30) _dfc = dfc[dfc["dt"] >= _sdt_30] @@ -784,21 +853,21 @@ def show_turnover_rate(df: pd.DataFrame): m6.metric("最近一年换手率", f"{_dfy['change'].sum():.2f}", help=f"最近一年换手率,自{_sdt_year}以来") p1, p2, p3 = st.columns([2, 3, 1]) - # 日换手的累计变化(X轴不显示) + # 日换手累计曲线 df_daily = dfc.copy() df_daily["change"] = df_daily["change"].cumsum() fig = px.line(df_daily, x="dt", y="change", title="日换手累计曲线") fig.update_xaxes(title_text="") p1.plotly_chart(fig, width="stretch") - # 月换手的柱状图 + # 月换手柱状图 df_monthly = dfc.copy() df_monthly = df_monthly.set_index("dt").resample("ME").sum().reset_index() fig = px.bar(df_monthly, x="dt", y="change", title="月换手变化") fig.update_xaxes(title_text="") p2.plotly_chart(fig, width="stretch") - # 年换手的柱状图 + # 年换手柱状图 df_yearly = dfc.copy() df_yearly = df_yearly.set_index("dt").resample("YE").sum().reset_index() df_yearly["dt"] = df_yearly["dt"].dt.strftime("%Y") @@ -813,7 +882,12 @@ def show_turnover_rate(df: pd.DataFrame): def show_stats_compare(df: pd.DataFrame, **kwargs): """显示多组策略回测的绩效对比 - :param df: 策略回测结果, WeightBacktest 的 stats 数据汇总成的 DataFrame,用 name 列区分不同的策略 + 将 ``WeightBacktest.stats`` 汇总成的 DataFrame(用 ``name`` 列区分不同策略)按 + 标准列序展示,并应用 :func:`apply_stats_style` 的统一着色与格式化。 + + :param df: pd.DataFrame,策略回测结果汇总表 + :param kwargs: 其他参数(保留以便扩展) + :return: None """ if "name" in df.columns: df.set_index("name", inplace=True) @@ -845,20 +919,28 @@ def show_stats_compare(df: pd.DataFrame, **kwargs): "结束日期", ] - # 只选择存在的列 + # 仅保留实际存在的列,避免 KeyError existing_cols = [col for col in stats_cols if col in df.columns] df = df[existing_cols].copy() - # 应用样式 + # 应用统一样式 df = apply_stats_style(df) st.dataframe(df, width="stretch") def show_symbol_penalty(df: pd.DataFrame, n=3, **kwargs): - """依次删除策略收益最高的N个品种,对比收益变化 + """依次删除策略收益最高的 N 个品种,对比收益变化 - :param df: 策略权重数据 - :param n: 删除的品种数量 + 用于评估策略对个别赚钱品种的依赖程度:先用全量数据回测,再依次剔除累计盈利 + Top 1、Top 2、... Top N 品种重新回测,最后比较累计收益曲线与绩效。 + + :param df: pd.DataFrame,策略权重数据 + :param n: int,删除的品种数量 + :param kwargs: 其他参数 + - digits: int,权重小数位数,默认 2 + - fee_rate: float,手续费率,默认 0 + - weight_type: str,权重类型,默认 ``'ts'`` + - yearly_days: int,年化天数,默认 252 :return: None """ WeightBacktest = safe_import_weight_backtest() @@ -868,6 +950,7 @@ def show_symbol_penalty(df: pd.DataFrame, n=3, **kwargs): weight_type = kwargs.get("weight_type", "ts") yearly_days = kwargs.get("yearly_days", 252) + # n 不能超过 symbol 的总数 - 1(至少要剩 1 个) n = min(n, df["symbol"].nunique() - 1) dfw = df[["dt", "symbol", "weight", "price"]].copy() wb_map = {} @@ -875,6 +958,7 @@ def show_symbol_penalty(df: pd.DataFrame, n=3, **kwargs): wb_map["原始策略"] = wb1 for i in list(range(1, n + 1)): + # 取累计盈利最高的前 i 个品种 top_symbols = wb1.get_top_symbols(i, kind="profit") dfw1 = dfw[~dfw["symbol"].isin(top_symbols)].copy().reset_index(drop=True) wb2 = WeightBacktest( diff --git a/czsc/traders/__init__.py b/czsc/traders/__init__.py index 98fd1fbc6..b8a586b2a 100644 --- a/czsc/traders/__init__.py +++ b/czsc/traders/__init__.py @@ -1,18 +1,50 @@ +"""czsc.traders —— 交易员(trader)命名空间的统一入口模块。 + +本模块在 Rust/Python 混合架构下扮演"门面(facade)"角色,将分散在多个底层 +实现中的交易相关公共 API 重新汇聚到一个稳定的导入路径上,方便上层业务以 +``from czsc.traders import XXX`` 的方式直接使用,而无需关心底层是 Rust +扩展还是纯 Python 实现。 + +模块组成说明: + +* ``CzscTrader`` / ``CzscSignals`` / ``generate_czsc_signals`` / + ``derive_signals_config`` / ``derive_signals_freqs`` :均来自 + ``czsc._native``(Rust 扩展),承担信号生成与多级别交易决策的核心逻辑。 +* ``WeightBacktest`` :从外部 ``wbt`` 包再次导出,提供基于权重序列的 + 回测能力。 +* ``check_signals_acc`` :旧版 czsc 中以 HTML 截图方式辅助核对信号的工具, + Rust 版本未提供等价实现,因此**不再**在此处导出;如需可视化校验信号, + 请改用 ``czsc.svc`` 中提供的 Streamlit 组件。 +* ``OpensOptimize`` / ``ExitsOptimize`` / ``CzscOpenOptimStrategy`` / + ``CzscExitOptimStrategy`` :定义在 :mod:`czsc.traders.optimize` 中,是对 + ``czsc._native.run_optimize_batch`` 的轻量 Python 封装,用于开仓/平仓 + 参数的批量优化。 """ -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2021/11/1 22:20 -describe: 交易员(traders):使用 CZSC 分析工具进行择时策略的开发,交易等 -""" -from czsc.traders.base import CzscSignals, CzscTrader, check_signals_acc, generate_czsc_signals, get_unique_signals +# 直接从 Rust 原生扩展中导入交易体系核心类与帮助函数, +# 确保 Python 侧只承担"调用-转发"职责,业务逻辑落在 Rust 端。 +from czsc._native import ( + CzscSignals, + CzscTrader, + derive_signals_config, + derive_signals_freqs, + generate_czsc_signals, +) +from wbt import WeightBacktest + +# 兼容老的导入路径:保留 Python 侧的薄封装函数 get_unique_signals。 +from czsc.traders.base import get_unique_signals from czsc.traders.sig_parse import SignalsParser, get_signals_config, get_signals_freqs +# __all__ 显式声明对外公开的符号集合,限定 `from czsc.traders import *` 的行为, +# 同时方便 IDE 与文档工具识别模块的公共 API。 __all__ = [ "CzscSignals", "CzscTrader", "SignalsParser", - "check_signals_acc", + "WeightBacktest", + "derive_signals_config", + "derive_signals_freqs", "generate_czsc_signals", "get_signals_config", "get_signals_freqs", diff --git a/czsc/traders/__init__.pyi b/czsc/traders/__init__.pyi index 3f90a8fcc..e82fe848f 100644 --- a/czsc/traders/__init__.pyi +++ b/czsc/traders/__init__.pyi @@ -1,28 +1,18 @@ -from czsc.traders.base import ( +from czsc._native import ( CzscSignals as CzscSignals, ) -from czsc.traders.base import ( +from czsc._native import ( CzscTrader as CzscTrader, ) -from czsc.traders.base import ( - check_signals_acc as check_signals_acc, +from czsc._native import ( + generate_czsc_signals as generate_czsc_signals, ) from czsc.traders.base import ( - generate_czsc_signals as generate_czsc_signals, + check_signals_acc as check_signals_acc, ) from czsc.traders.base import ( get_unique_signals as get_unique_signals, ) -from czsc.traders.dummy import DummyBacktest as DummyBacktest -from czsc.traders.performance import ( - PairsPerformance as PairsPerformance, -) -from czsc.traders.performance import ( - combine_dates_and_pairs as combine_dates_and_pairs, -) -from czsc.traders.performance import ( - combine_holds_and_pairs as combine_holds_and_pairs, -) from czsc.traders.sig_parse import ( SignalsParser as SignalsParser, ) @@ -32,11 +22,6 @@ from czsc.traders.sig_parse import ( from czsc.traders.sig_parse import ( get_signals_freqs as get_signals_freqs, ) -from czsc.traders.weight_backtest import ( - get_ensemble_weight as get_ensemble_weight, -) -from czsc.traders.weight_backtest import ( - stoploss_by_direction as stoploss_by_direction, -) +from wbt import WeightBacktest as WeightBacktest def __getattr__(name): ... diff --git a/czsc/traders/_rs_signals.py b/czsc/traders/_rs_signals.py deleted file mode 100644 index 1b43a1e72..000000000 --- a/czsc/traders/_rs_signals.py +++ /dev/null @@ -1,136 +0,0 @@ -"""rs_czsc 信号执行桥接层。""" - -from __future__ import annotations - -from typing import Any - -import pandas as pd -from rs_czsc import run_research - - -def _bars_to_dataframe(bars: Any, symbol: str | None = None) -> pd.DataFrame: - if isinstance(bars, pd.DataFrame): - df = bars.copy() - else: - rows = [] - for bar in bars: - rows.append( - { - "symbol": getattr(bar, "symbol", symbol), - "dt": getattr(bar, "dt", None), - "open": getattr(bar, "open", None), - "close": getattr(bar, "close", None), - "high": getattr(bar, "high", None), - "low": getattr(bar, "low", None), - "vol": getattr(bar, "vol", None), - "amount": getattr(bar, "amount", None), - } - ) - df = pd.DataFrame(rows) - - if "symbol" not in df.columns: - df["symbol"] = symbol - elif symbol: - df["symbol"] = df["symbol"].fillna(symbol) - - required = ["symbol", "dt", "open", "close", "high", "low", "vol", "amount"] - for col in required: - if col not in df.columns: - raise ValueError(f"bars 缺少列:{col}") - - df = df[required].copy() - df["dt"] = pd.to_datetime(df["dt"]) - df["symbol"] = df["symbol"].astype(str) - for col in ["open", "close", "high", "low", "vol", "amount"]: - df[col] = pd.to_numeric(df[col], errors="coerce").astype(float) - df = df.dropna(subset=required) - df = df.sort_values("dt").drop_duplicates(subset=["dt"], keep="last").reset_index(drop=True) - return df - - -def _infer_base_freq(bars: Any, signals_config: list[dict]) -> str: - if isinstance(bars, list) and bars: - freq = getattr(getattr(bars[0], "freq", None), "value", None) - if freq: - return str(freq) - - for cfg in signals_config: - freq = cfg.get("freq") - if freq: - return str(freq) - return "日线" - - -def _placeholder_position(symbol: str) -> dict[str, Any]: - return { - "name": "_signals_only", - "symbol": symbol, - "opens": [], - "exits": [], - "interval": 0, - "timeout": 0, - "stop_loss": 0, - "T0": True, - } - - -def run_rs_signal_generation( - bars: Any, - signals_config: list[dict], - *, - sdt: str, - symbol: str | None = None, - bg_max_count: int = 5000, - include_sdt_bar: bool | None = None, -) -> pd.DataFrame: - """使用 rs_czsc 统一执行引擎生成信号 DataFrame。""" - if not signals_config: - return pd.DataFrame() - - bars_df = _bars_to_dataframe(bars, symbol=symbol) - if bars_df.empty: - return pd.DataFrame() - - symbol_ = str(symbol or bars_df["symbol"].iloc[-1]) - strategy: dict[str, Any] = { - "name": f"{symbol_}-signals-only", - "symbol": symbol_, - "base_freq": _infer_base_freq(bars, signals_config), - "signals_module": "czsc.signals", - "signals_config": signals_config, - "positions": [_placeholder_position(symbol_)], - "market": "默认", - "bg_max_count": int(bg_max_count), - } - if include_sdt_bar is not None: - strategy["include_sdt_bar"] = bool(include_sdt_bar) - - res = run_research(bars_df, strategy, sdt=sdt, opts={"emit_signals": True}) - return res.signals_df() - - -def get_last_signal_map( - bars: Any, - signals_config: list[dict], - *, - symbol: str | None = None, - bg_max_count: int = 5000, - include_sdt_bar: bool | None = None, -) -> dict[str, Any]: - """获取最新一根 K 线对应的信号字典。""" - bars_df = _bars_to_dataframe(bars, symbol=symbol) - if bars_df.empty: - return {} - - first_dt = bars_df["dt"].iloc[0] - df = run_rs_signal_generation( - bars, - signals_config, - sdt=pd.to_datetime(first_dt).strftime("%Y-%m-%d %H:%M:%S"), - symbol=symbol, - bg_max_count=bg_max_count, - include_sdt_bar=include_sdt_bar, - ) - if df.empty: - return {} - return df.iloc[-1].to_dict() diff --git a/czsc/traders/base.py b/czsc/traders/base.py index 8f65b4dcf..953f9af47 100644 --- a/czsc/traders/base.py +++ b/czsc/traders/base.py @@ -1,611 +1,90 @@ -""" -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2022/12/24 22:20 -describe: 简单的单仓位策略执行 -""" - -import os -import webbrowser -from collections import OrderedDict -from collections.abc import Callable -from datetime import datetime, timedelta -from typing import AnyStr - -import numpy as np -import pandas as pd - -from czsc.core import CZSC, BarGenerator, Position, RawBar, Signal -from czsc.traders._rs_signals import get_last_signal_map, run_rs_signal_generation -from czsc.traders.sig_parse import get_signals_freqs -from czsc.utils.data.cache import home_path - - -class CzscSignals: - """缠中说禅技术分析理论之多级别信号计算""" - - def __init__(self, bg: BarGenerator | None = None, **kwargs): - """ - - :param bg: K线合成器 - """ - self.name = "CzscSignals" - # cache 是信号计算过程的缓存容器,需要信号计算函数自行维护 - self.cache = OrderedDict() - self.kwargs = kwargs - self.signals_config = kwargs.get("signals_config", []) - - if bg: - self.bg = bg - assert bg.symbol, "bg.symbol is None" - self.symbol = bg.symbol - self.base_freq = bg.base_freq - self.freqs = list(bg.bars.keys()) - self.kas = {freq: CZSC(b) for freq, b in bg.bars.items()} - - last_bar = self.kas[self.base_freq].bars_raw[-1] - self.end_dt, self.bid, self.latest_price = last_bar.dt, last_bar.id, last_bar.close - self.s = OrderedDict() - self.s.update(self.get_signals_by_conf()) - self.s.update(last_bar.__dict__) - else: - self.bg = None - self.symbol = None - self.base_freq = None - self.freqs = None - self.kas = None - self.end_dt, self.bid, self.latest_price = None, None, None - self.s = OrderedDict() - - def __repr__(self): - return f"<{self.name} for {self.symbol}>" - - def get_signals_by_conf(self): - """通过信号参数配置获取信号 - - 函数执行逻辑: - - 1. 函数首先创建一个空的有序字典s。 - 2. 如果self.signals_config不存在,函数直接返回空字典s,否则,函数遍历其中的每一个配置。 - 3. 对于每一个参数,函数提取出信号名称和freq,并根据这两个参数获取相应的信号,获取到的信号被添加到字典s中。 - 4. 函数最后返回字典s,其中包含了所有获取到的信号。 - - 信号参数配置,格式如下: - - signals_config = [ - {'name': 'czsc.signals.tas_ma_base_V221101', 'freq': '日线', 'di': 1, 'ma_type': 'SMA', 'timeperiod': 5}, - {'name': 'czsc.signals.tas_ma_base_V221101', 'freq': '日线', 'di': 5, 'ma_type': 'SMA', 'timeperiod': 5}, - {'name': 'czsc.signals.tas_double_ma_V221203', 'freq': '日线', 'di': 1, 'ma_seq': (5, 20), 'th': 100}, - {'name': 'czsc.signals.tas_double_ma_V221203', 'freq': '日线', 'di': 5, 'ma_seq': (5, 20), 'th': 100}, - ] - - :return: 信号字典 - """ - s = OrderedDict() - if not self.signals_config or not self.bg: - return s - - signal_map = get_last_signal_map( - self.bg.bars[self.base_freq], - self.signals_config, - symbol=self.symbol, - bg_max_count=self.kwargs.get("bg_max_count", 5000), - include_sdt_bar=self.kwargs.get("include_sdt_bar"), - ) - for key, value in signal_map.items(): - if isinstance(key, str) and len(key.split("_")) == 3: - s[key] = value - return s - - def take_snapshot(self, file_html=None, width: str = "1400px", height: str = "580px"): - """获取快照 - - 函数执行逻辑: - - 1. 函数首先创建一个Tab对象,用于存储所有的图表和表格。 - 2. 函数遍历所有的freq,对于每一个freq,函数获取相应的CZSC对象,并将其转换为一个图表,然后添加到Tab对象中。 - 3. 函数提取出所有的信号,并按照freq分组。对于每一个freq,函数创建一个表格,包含该freq下的所有信号,然后添加到Tab对象中。 - 4. 如果还有其他的信号,函数创建一个表格,包含所有的其他信号,然后添加到Tab对象中。 - 5. 最后,如果提供了file_html参数,函数将Tab对象渲染为一个HTML文件并保存;否则,函数返回Tab对象。 - - :param file_html: 交易快照保存的 html 文件名 - :param width: 图表宽度 - :param height: 图表高度 - :return: - """ - from pyecharts.charts import Tab - from pyecharts.components import Table - from pyecharts.options import ComponentTitleOpts - - tab = Tab(page_title="{}@{}".format(self.symbol, self.end_dt.strftime("%Y-%m-%d %H:%M"))) - for freq in self.freqs: - ka: CZSC = self.kas[freq] - chart = ka.to_echarts(width, height) - tab.add(chart, freq) - - signals = {k: v for k, v in self.s.items() if len(k.split("_")) == 3} - for freq in self.freqs: - # 按各周期K线分别加入信号表 - freq_signals = {k: signals[k] for k in signals if k.startswith(f"{freq}_")} - for k in freq_signals: - signals.pop(k) - if len(freq_signals) <= 0: - continue - t1 = Table() - t1.add(["名称", "数据"], [[k, v] for k, v in freq_signals.items()]) - t1.set_global_opts(title_opts=ComponentTitleOpts(title="缠中说禅信号表", subtitle="")) - tab.add(t1, f"{freq}信号") - - if len(signals) > 0: - # 加入时间、持仓状态之类的其他信号 - t1 = Table() - t1.add(["名称", "数据"], [[k, v] for k, v in signals.items()]) - t1.set_global_opts(title_opts=ComponentTitleOpts(title="缠中说禅信号表", subtitle="")) - tab.add(t1, "其他信号") - - if file_html: - tab.render(file_html) - else: - return tab - - def open_in_browser(self, width="1400px", height="580px"): - """直接在浏览器中打开分析结果 +"""czsc.traders.base —— 由 ``czsc._native`` 支撑的交易基础对象再导出层。 - 函数执行逻辑: +在当前的 Rust/Python 混合架构下,交易员对象(``CzscTrader``、 +``CzscSignals``)以及与信号生成相关的辅助函数(``generate_czsc_signals``、 +``get_unique_signals``、``derive_signals_config``、``derive_signals_freqs``) +均迁移到 Rust 扩展中实现,以获得更优的性能与并发能力。 - 1. 首先创建一个HTML文件的路径file_html,这个文件将被保存在用户的主目录下,文件名为"temp_czsc_advanced_trader.html"。 - 2. 然后,函数调用self.take_snapshot方法,将分析结果保存为一个HTML文件。 - 3. 最后,函数使用webbrowser.open方法打开这个HTML文件 - """ - file_html = os.path.join(home_path, "temp_czsc_advanced_trader.html") - self.take_snapshot(file_html, width, height) - webbrowser.open(file_html) +本模块的存在主要有两个目的: - def update_signals(self, bar: RawBar): - """输入基础周期已完成K线,更新信号(不更新仓位,仓位更新由 CzscTrader.update 负责) +1. **兼容性**:保持历史导入路径稳定,例如老代码中的 + ``from czsc.traders.base import CzscSignals`` 仍可正常工作。 +2. **薄封装**:仅在必要时提供少量纯 Python 的"调用拼装"逻辑(如 + :func:`get_unique_signals`),将 Rust 侧返回的 DataFrame 转换为 + 业务期望的字符串列表。 - 函数执行逻辑: +按照"Python 仅负责包装、不承载业务逻辑"的总体设计原则,旧版基于 +HTML 快照的 ``check_signals_acc`` 编排函数已被移除,可视化诊断需求 +请改用 ``czsc.svc`` 中的 Streamlit 组件。 +""" - 1. 函数首先调用self.bg.update(bar),输入一个已完成的基础周期K线bar,更新各周期K线。 - 2. 然后,函数遍历所有的K线freq和对应的K线数据,对每一个K线数据,函数调用self.kas[freq].update(b[-1]),更新对应的 CZSC 对象。 - 3. 函数提取出K线的标的代码bar.symbol,并将其赋值给self.symbol。 - 4. 函数提取出基础freq的最后一根K线last_bar,并从中提取出结束时间dt,K线IDid,以及收盘价close,并将它们分别赋值给self.end_dt,self.bid,和self.latest_price。 - 5. 函数创建一个空的有序字典s,并调用self.get_signals_by_conf()获取所有的信号,然后将这些信号更新到字典s中。 - 6. 最后,函数将last_bar的所有属性更新到字典s中。 +from __future__ import annotations - :param bar: 基础周期已完成K线 - :return: None - """ - self.bg.update(bar) - for freq, b in self.bg.bars.items(): - self.kas[freq].update(b[-1]) +# 直接从 Rust 原生扩展导入交易/信号体系的核心类型与函数, +# 上层调用方因此感知不到底层是 Rust 实现。 +from czsc._native import ( + CzscSignals, + CzscTrader, + RawBar, + Signal, + derive_signals_config, + derive_signals_freqs, + generate_czsc_signals, +) - self.symbol = bar.symbol - last_bar = self.kas[self.base_freq].bars_raw[-1] - self.end_dt, self.bid, self.latest_price = last_bar.dt, last_bar.id, last_bar.close - self.s = OrderedDict() - self.s.update(self.get_signals_by_conf()) - self.s.update(last_bar.__dict__) +from czsc.traders.sig_parse import get_signals_freqs -def generate_czsc_signals( +def get_unique_signals( bars: list[RawBar], - signals_config: list[dict], - sdt: AnyStr | datetime = "20170101", - init_n: int = 500, - df=False, + signals_config: list[dict[str, object]], **kwargs, -): - """使用 CzscSignals 生成信号 - - 函数执行逻辑: - - 1. 函数首先从信号配置signals_config中获取所有的freqs。 - 2. 然后,函数将信号计算开始时间sdt转换为datetime类型,并将开始时间之前的K线数据分配给bars_left,开始时间之后的K线数据分配给bars_right。 - 3. 如果bars_right为空,即没有开始时间之后的K线数据,函数会发出一个警告,并返回一个空的DataFrame或空列表。 - 4. 函数创建一个BarGenerator对象bg,并使用bars_left中的K线数据来初始化它。 - 5. 函数创建一个CzscSignals对象cs,并将bg和信号配置signals_config作为参数传入。 - 6. 函数遍历bars_right中的每一根K线,对于每一根K线,函数调用cs.update_signals(bar)来更新信号,并将更新后的信号添加到_sigs列表中。 - 7. 最后,如果df参数为True,函数将_sigs转换为DataFrame并返回;否则,直接返回_sigs。 - - :param bars: 基础周期 K 线序列 - :param signals_config: 信号函数配置,格式如下: - signals_config = [ - {'name': 'czsc.signals.tas_ma_base_V221101', 'freq': '日线', 'di': 1, 'ma_type': 'SMA', 'timeperiod': 5}, - {'name': 'czsc.signals.tas_ma_base_V221101', 'freq': '日线', 'di': 5, 'ma_type': 'SMA', 'timeperiod': 5}, - {'name': 'czsc.signals.tas_double_ma_V221203', 'freq': '日线', 'di': 1, 'ma_seq': (5, 20), 'th': 100}, - {'name': 'czsc.signals.tas_double_ma_V221203', 'freq': '日线', 'di': 5, 'ma_seq': (5, 20), 'th': 100}, - ] - :param sdt: 信号计算开始时间 - :param init_n: 用于 BarGenerator 初始化的基础周期K线数量 - :param df: 是否返回 df 格式的信号计算结果,默认 False - :return: 信号计算结果 - """ - if not bars: - return pd.DataFrame() if df else [] - - sdt = pd.to_datetime(sdt) # type: ignore[arg-type] - bars_right = [x for x in bars if x.dt >= sdt] # type: ignore[attr-defined] - if len(bars_right) == 0: - import warnings - - warnings.warn("右侧K线为空,无法进行信号生成", RuntimeWarning, stacklevel=2) - if df: - return pd.DataFrame() - else: - return [] - - signals_df = run_rs_signal_generation( - bars, - signals_config, - sdt=sdt.strftime("%Y-%m-%d %H:%M:%S"), - symbol=getattr(bars[0], "symbol", None), - bg_max_count=kwargs.get("bg_max_count", 5000), - include_sdt_bar=kwargs.get("include_sdt_bar"), - ) - if df: - return signals_df - return signals_df.to_dict("records") - - -def check_signals_acc(bars: list[RawBar], signals_config: list[dict], delta_days: int = 5, **kwargs) -> None: - """输入基础周期K线和想要验证的信号,输出信号识别结果的快照 - - 函数执行逻辑: - - 1. 函数首先获取基础周期K线的base_freq,并检查输入的K线数据bars是否按时间升序排列。如果bars的长度小于600,函数直接返回。 - 2. 然后,函数调用generate_czsc_signals方法,生成Czsc信号,并将结果保存在df中。 - 3. 函数提取出df中所有的信号列s_cols,并打印每一列的值的数量。然后,函数将所有的信号添加到signals列表中。 - 4. 函数将bars分为两部分,bars_left和bars_right,并获取信号配置signals_config中的所有freqs。 - 5. 函数创建一个BarGenerator对象bg,并使用bars_left中的K线数据来初始化它。 - 6. 函数创建一个CzscSignals对象ct,并将bg和信号配置signals_config作为参数传入。 - 7. 函数创建一个字典last_dt,用于存储每一个信号最后一次出现的时间。 - 8. 函数遍历bars_right中的每一根K线,对于每一根K线,函数调用ct.update_signals(bar)来更新信号。 - 9. 对于每一个信号,如果当前K线的时间与该信号最后一次出现的时间的差值大于delta_days,并且该信号与当前的信号匹配, - 函数将创建一个HTML文件,保存信号识别结果的快照,并更新该信号最后一次出现的时间。 - - :param bars: 原始K线 - :param signals_config: 需要验证的信号列表 - :param delta_days: 两次相同信号之间的间隔天数 - :return: None - """ - base_freq = str(bars[-1].freq.value) - assert bars[2].dt > bars[1].dt > bars[0].dt and bars[2].id > bars[1].id, "bars 中的K线元素必须按时间升序" - if len(bars) < 600: - return - - df = generate_czsc_signals(bars, signals_config=signals_config, df=True, **kwargs) - s_cols = [x for x in df.columns if len(x.split("_")) == 3] - signals = [] - for col in s_cols: - print("=" * 100, "\n", df[col].value_counts()) - signals.extend([Signal(f"{col}_{v}") for v in df[col].unique() if "其他" not in v]) - - print(f"signals: {'+' * 100}") - for row in signals: - print(f"- {row}") - - bars_left = bars[:500] - bars_right = bars[500:] - freqs = get_signals_freqs(signals_config) - bg = BarGenerator(base_freq=base_freq, freqs=freqs, max_count=5000) - for bar in bars_left: - bg.update(bar) - - ct = CzscSignals(bg, signals_config=signals_config, **kwargs) - last_dt = {signal.key: ct.end_dt for signal in signals} - - for bar in bars_right: - ct.update_signals(bar) - - for signal in signals: - html_path = os.path.join(home_path, signal.key) - os.makedirs(html_path, exist_ok=True) - if bar.dt - last_dt[signal.key] > timedelta(days=delta_days) and signal.is_match(ct.s): - file_html = f"{bar.symbol}_{bar.dt.strftime('%Y%m%d_%H%M')}_{signal.key}_{ct.s[signal.key]}.html" - file_html = os.path.join(html_path, file_html) - print(file_html) - ct.take_snapshot(file_html, height=kwargs.get("height", "680px")) - last_dt[signal.key] = bar.dt - - -def get_unique_signals(bars: list[RawBar], signals_config: list[dict], **kwargs): - """获取信号函数中定义的所有信号列表 - - 函数执行逻辑: - - 1. 函数首先检查输入的K线数据bars是否按时间升序排列。如果bars的长度小于600,函数直接返回一个空列表。 - 2. 然后,函数调用generate_czsc_signals方法,生成CZSC信号,并将结果保存在df中。 - 3. 函数遍历df中的所有列,对于每一列,如果列名包含三个部分,函数提取出该列中的所有唯一值,然后将列名和每一个唯一值组合成一个新的信号, - 并添加到_res列表中。注意,如果唯一值中包含"其他",则不会被添加到_res中。 - 4. 最后,函数返回_res,其中包含了所有的唯一信号。 - - :param bars: 基础K线数据 - :param signals_config: 信号函数配置 - :param kwargs: 传递给generate_czsc_signals方法的参数 - :return: 信号列表 +) -> list[str]: + """对 Rust ``generate_czsc_signals`` 结果做去重整理的薄封装。 + + 本函数会先调用 Rust 实现批量计算所有信号,得到一张包含若干信号列的 + DataFrame;随后逐列提取所有非"其他"取值,按照惯用的 + ``"<列名>_<取值>"`` 字符串格式拼装并去重返回,便于上层做信号字典构建、 + 回放校验或断言用。 + + Args: + bars: 标准化后的原始 K 线序列;序列长度过短时不进行计算。 + signals_config: 信号配置列表,会原样透传给 Rust 端的 + ``generate_czsc_signals``。 + **kwargs: 透传给 ``generate_czsc_signals`` 的其他可选参数。 + + Returns: + 去重后的信号字符串列表;若输入 K 线数量不足 600 根,则直接返回 + 空列表(避免在样本不足时产生噪声信号)。 + + Notes: + - 仅识别列名按 ``"a_b_c"`` 三段式命名的列,与 czsc 信号命名约定保持一致。 + - "其他"是缠论信号体系中通用的"无意义/未匹配"占位取值,会被显式跳过。 """ - assert bars[2].dt > bars[1].dt > bars[0].dt and bars[2].id > bars[1].id, "bars 中的K线元素必须按时间升序" + # 当输入 K 线数据不足以形成稳定的信号样本时直接返回空列表, + # 避免下游被冷启动阶段的噪声信号污染。 if len(bars) < 600: return [] + # 通过 Rust 实现批量计算所有信号取值,输出为 pandas DataFrame。 df = generate_czsc_signals(bars, signals_config=signals_config, df=True, **kwargs) - _res = [] - for col in [x for x in df.columns if len(x.split("_")) == 3]: - _res.extend([f"{col}_{v}" for v in df[col].unique() if "其他" not in v]) - return _res - - -class CzscTrader(CzscSignals): - """缠中说禅技术分析理论之多级别联立交易决策类(支持多策略独立执行)""" - - def __init__( - self, - bg: BarGenerator | None = None, - positions: list[Position] | None = None, - ensemble_method: AnyStr | Callable = "mean", - **kwargs, - ): - """ - - 初始化逻辑: - - 1. 首先接收几个参数: - bg是一个可选的BarGenerator对象, - positions是一个可选的Position对象列表, - ensemble_method是一个集成方法,可以是字符串或者一个回调函数。 - 2. 函数将positions赋值给self.positions。如果positions不为空,函数会检查positions中的所有名称是否都是唯一的,如果不是,函数会抛出一个断言错误。 - 3. 函数将ensemble_method赋值给self.__ensemble_method。这个参数用于指定如何从多个仓位中集成一个仓位。 - 它可以是"mean"(平均),"vote"(投票),"max"(最大),或者一个回调函数。 - 4. 函数将"CzscTrader"赋值给self.name。 - 5. 最后,函数调用父类的初始化函数,传入bg和其他参数。 - - :param bg: bar generator 对象 - :param get_signals: 信号计算函数,输入是 CzscSignals 对象,输出是信号字典 - :param ensemble_method: 多个仓位集成一个仓位的方法,可选值 mean, vote, max;也可以传入一个回调函数 - - 假设有三个仓位对象,当前仓位分别是 1, 1, -1 - mean - 平均仓位,pos = np.mean([1, 1, -1]) = 0.33 - vote - 投票表决,pos = 1 - max - 取最大,pos = 1 - - 对于传入回调函数的情况,函数的输入为 dict,key 为 position.name,value 为 position.pos, 样例输入: - {'多头策略A': 1, '多头策略B': 1, '空头策略A': -1} - """ - self.positions = positions - if self.positions: - _pos_names = [x.name for x in self.positions] - assert len(_pos_names) == len(set(_pos_names)), "仓位策略名称不能重复" - self.__ensemble_method = ensemble_method - self.name = "CzscTrader" - super().__init__(bg, **kwargs) - - def __repr__(self): - return f"<{self.name} for {self.symbol}>" - - def update(self, bar: RawBar) -> None: - """输入基础周期已完成K线,更新信号,更新仓位 - - 函数执行逻辑: - - 1. 函数首先接收一个参数bar,这是一个已完成的基础周期K线。 - 2. 函数调用self.update_signals(bar),输入这个已完成的基础周期K线,更新信号。 - 3. 如果self.positions不为空,即存在仓位,函数遍历所有的仓位,对于每一个仓位,函数调用position.update(self.s),更新该仓位的状态。 - - :param bar: 基础周期已完成K线 - :return: None - """ - self.update_signals(bar) - if self.positions: - for position in self.positions: - position.update(self.s) - - def on_sig(self, sig: dict) -> None: - """通过信号字典直接交易,用于快速回测场景 - - 函数执行逻辑: - - 1. 函数首先接收一个参数sig,这是一个信号字典,赋值给self.s。 - 2. 函数从sig中提取出标的代码symbol,结束时间dt,K线ID id,以及收盘价close, - 并将它们分别赋值给self.symbol,self.end_dt,self.bid,和self.latest_price。 - 4. 如果self.positions不为空,即存在持仓策略,函数遍历所有position,函数调用position.update(self.s),更新该仓位的状态 - - :param sig: 信号字典 - :return: None - """ - self.s = sig - self.symbol, self.end_dt = self.s["symbol"], self.s["dt"] - self.bid, self.latest_price = self.s["id"], self.s["close"] - if self.positions: - for position in self.positions: - position.update(self.s) - - def on_bar(self, bar: RawBar) -> None: - """输入基础周期已完成K线,更新信号,更新仓位 - - :param bar: 基础周期已完成K线 - :return: None - """ - self.update(bar) - - @property - def pos_changed(self) -> bool: - """判断仓位是否发生变化 - - 1. 函数首先检查self.positions是否为空。如果为空,即没有仓位,函数直接返回False。 - 2. 如果self.positions不为空,函数遍历所有的仓位,对于每一个仓位,函数检查其pos_changed属性。 - 如果任何一个仓位的pos_changed属性为True,即该仓位发生了变化,函数返回True。 - - :return: True/False - """ - if not self.positions: - return False - return any(position.pos_changed for position in self.positions) - - def get_ensemble_pos(self, method: AnyStr | Callable = None) -> float: - """获取多个仓位的集成仓位 - - 函数执行逻辑: - - 1. 函数首先检查self.positions是否为空。如果为空,即没有仓位,函数直接返回0。 - 2. 如果self.positions不为空,函数获取集成方法method。如果没有传入method参数,函数使用self.__ensemble_method作为集成方法。 - 3. 如果method是一个字符串,函数将其转换为小写,然后获取所有仓位的仓位序列pos_seq。 - 1. 如果method是"mean",函数计算pos_seq的平均值作为集成仓位。 - 2. 如果method是"vote",函数计算pos_seq的和的符号作为集成仓位。 - 3. 如果method是"max",函数获取pos_seq的最大值作为集成仓位。 - 4. 如果method不是以上任何一个值,函数抛出一个值错误。 - - 4. 如果method不是一个字符串,即它是一个回调函数,函数将所有仓位的名称和仓位组成的字典作为参数传入method,并将返回值作为集成仓位。 - - :param method: 多个仓位集成一个仓位的方法,可选值 mean, vote, max;也可以传入一个回调函数 - - 假设有三个仓位对象,当前仓位分别是 1, 1, -1 - mean - 平均仓位,pos = np.mean([1, 1, -1]) = 0.33 - vote - 投票表决,pos = 1 - max - 取最大,pos = 1 - - 对于传入回调函数的情况,输入是 self.positions - - :return: pos, 集成仓位 - """ - if not self.positions: - return 0 - - method = method if method else self.__ensemble_method - if isinstance(method, str): - method = method.lower() - pos_seq = [x.pos for x in self.positions] - - if method == "mean": - pos = np.mean(pos_seq) - elif method == "vote": - pos = np.sign(sum(pos_seq)) - elif method == "max": - pos = max(pos_seq) - else: - raise ValueError - - else: - pos = method({x.name: x.pos for x in self.positions}) - - return pos - - def get_position(self, name: str) -> Position | None: - """获取指定名称的仓位策略对象 - - 函数执行逻辑: - - 1. 函数首先接收一个参数name,这是要查找的仓位名称。 - 2. 函数检查self.positions是否为空。如果为空,即没有仓位,函数直接返回None。 - 3. 如果self.positions不为空,函数遍历所有的仓位,对于每一个仓位,函数检查其名称是否与输入的名称相同。如果相同,函数返回该仓位。 - 4. 如果遍历所有的仓位都没有找到与输入名称相同的仓位,函数返回None。 - - :param name: 仓位名称 - :return: Position - """ - if not self.positions: - return None - - for position in self.positions: - if position.name == name: - return position - - return None - - def take_snapshot(self, file_html=None, width: str = "1400px", height: str = "580px"): - """获取快照 - - :param file_html: 交易快照保存的 html 文件名 - :param width: 图表宽度 - :param height: 图表高度 - :return: - """ - from pyecharts.charts import Tab - from pyecharts.components import Table - from pyecharts.options import ComponentTitleOpts - - tab = Tab(page_title="{}@{}".format(self.symbol, self.end_dt.strftime("%Y-%m-%d %H:%M"))) - for freq in self.freqs: - ka: CZSC = self.kas[freq] - bs = None - if freq == self.base_freq and self.positions: - # 在基础周期K线上加入最近的操作记录 - bs = [] - for pos in self.positions: - for op in pos.operates: - if op["dt"] >= ka.bars_raw[0].dt: - _op = dict(op) - _op["op_desc"] = f"{pos.name} | {_op['op_desc']}" - bs.append(_op) - - chart = ka.to_echarts(width, height, bs) - tab.add(chart, freq) - - signals = {k: v for k, v in self.s.items() if len(k.split("_")) == 3} - for freq in self.freqs: - # 按各周期K线分别加入信号表 - freq_signals = {k: signals[k] for k in signals if k.startswith(f"{freq}_")} - for k in freq_signals: - signals.pop(k) - if len(freq_signals) <= 0: - continue - t1 = Table() - t1.add(["名称", "数据"], [[k, v] for k, v in freq_signals.items()]) - t1.set_global_opts(title_opts=ComponentTitleOpts(title="缠中说禅信号表", subtitle="")) - tab.add(t1, f"{freq}信号") - - if len(signals) > 0: - # 加入时间、持仓状态之类的其他信号 - t1 = Table() - t1.add(["名称", "数据"], [[k, v] for k, v in signals.items()]) - t1.set_global_opts(title_opts=ComponentTitleOpts(title="缠中说禅信号表", subtitle="")) - tab.add(t1, "其他信号") - - if file_html: - tab.render(file_html) - else: - return tab - - def get_ensemble_weight(self, method: AnyStr | Callable | None = None): - """获取 CzscTrader 中所有 positions 按照 method 方法集成之后的权重 - - 函数执行逻辑: - - 1. 函数首先接收一个参数method,这是集成方法,可以是字符串或者一个回调函数。 - 2. 函数检查是否提供了method参数。如果没有提供,函数使用self.__ensemble_method作为集成方法;如果提供了,函数使用提供的method作为集成方法。 - 3. 函数调用get_ensemble_weight函数,输入self和method,获取所有仓位按照指定方法集成之后的权重。 - - :param method: str or callable - 集成方法,可选值包括:'mean', 'max', 'min', 'vote' - 也可以传入自定义的函数,函数的输入为 dict,key 为 position.name,value 为 position.pos, 样例输入: - {'多头策略A': 1, '多头策略B': 1, '空头策略A': -1} - :param kwargs: - :return: pd.DataFrame - columns = ['dt', 'symbol', 'weight', 'price'] - """ - from czsc.traders.weight_backtest import get_ensemble_weight - - method = method if method else self.__ensemble_method - return get_ensemble_weight(self, method) - - def weight_backtest(self, **kwargs): - """执行仓位集成权重的回测 - - :param kwargs: - - - method: str or callable,集成方法,参考 get_ensemble_weight 方法 - - digits: int,权重小数点后保留的位数,例如 2 表示保留两位小数 - - fee_rate: float,手续费率,例如 0.0002 表示万二 - - :return: 回测结果 - """ - from czsc.core import WeightBacktest - - method = kwargs.get("method", self.__ensemble_method) - digits = kwargs.get("digits", 2) - fee_rate = kwargs.get("fee_rate", 0.000) - dfw = self.get_ensemble_weight(method) - dfw["dt"] = pd.to_datetime(dfw["dt"]) - dfw = dfw[["dt", "symbol", "weight", "price"]].copy() - wb = WeightBacktest(dfw, digits=digits, fee_rate=fee_rate) - return wb + out: list[str] = [] + # 仅遍历列名形如 "a_b_c" 的信号列;其他列(如时间戳、价格等)被忽略。 + for col in [c for c in df.columns if len(c.split("_")) == 3]: + # 将每列的非"其他"取值拼装成完整的信号串,构建去重序列。 + out.extend(f"{col}_{v}" for v in df[col].unique() if "其他" not in v) + return out + + +# __all__ 列出本模块对外暴露的全部公共符号,限定 `import *` 行为, +# 也方便文档生成工具与类型检查器识别公共 API 范围。 +__all__ = [ + "CzscSignals", + "CzscTrader", + "Signal", + "derive_signals_config", + "derive_signals_freqs", + "generate_czsc_signals", + "get_signals_freqs", + "get_unique_signals", +] diff --git a/czsc/traders/base.pyi b/czsc/traders/base.pyi deleted file mode 100644 index 52186b038..000000000 --- a/czsc/traders/base.pyi +++ /dev/null @@ -1,73 +0,0 @@ -from collections.abc import Callable -from datetime import datetime -from typing import AnyStr - -from _typeshed import Incomplete - -from czsc.core import ( - CZSC as CZSC, -) -from czsc.core import ( - BarGenerator as BarGenerator, -) -from czsc.core import ( - Position as Position, -) -from czsc.core import ( - RawBar as RawBar, -) -from czsc.core import ( - Signal as Signal, -) -from czsc.traders.sig_parse import get_signals_freqs as get_signals_freqs -from czsc.utils.data.cache import home_path as home_path - -class CzscSignals: - name: str - cache: Incomplete - kwargs: Incomplete - signals_config: Incomplete - bg: Incomplete - symbol: Incomplete - base_freq: Incomplete - freqs: Incomplete - kas: Incomplete - s: Incomplete - def __init__(self, bg: BarGenerator | None = None, **kwargs) -> None: ... - def get_signals_by_conf(self): ... - def take_snapshot(self, file_html=None, width: str = "1400px", height: str = "580px"): ... - def open_in_browser(self, width: str = "1400px", height: str = "580px") -> None: ... - def update_signals(self, bar: RawBar): ... - -def generate_czsc_signals( - bars: list[RawBar], - signals_config: list[dict], - sdt: AnyStr | datetime = "20170101", - init_n: int = 500, - df: bool = False, - **kwargs, -): ... -def check_signals_acc(bars: list[RawBar], signals_config: list[dict], delta_days: int = 5, **kwargs) -> None: ... -def get_unique_signals(bars: list[RawBar], signals_config: list[dict], **kwargs): ... - -class CzscTrader(CzscSignals): - positions: Incomplete - name: str - def __init__( - self, - bg: BarGenerator | None = None, - positions: list[Position] | None = None, - ensemble_method: AnyStr | Callable = "mean", - **kwargs, - ) -> None: ... - def update(self, bar: RawBar) -> None: ... - s: Incomplete - def on_sig(self, sig: dict) -> None: ... - def on_bar(self, bar: RawBar) -> None: ... - @property - def pos_changed(self) -> bool: ... - def get_ensemble_pos(self, method: AnyStr | Callable = None) -> float: ... - def get_position(self, name: str) -> Position | None: ... - def take_snapshot(self, file_html=None, width: str = "1400px", height: str = "580px"): ... - def get_ensemble_weight(self, method: AnyStr | Callable | None = None): ... - def weight_backtest(self, **kwargs): ... diff --git a/czsc/traders/cwc.py b/czsc/traders/cwc.py deleted file mode 100644 index 519f04889..000000000 --- a/czsc/traders/cwc.py +++ /dev/null @@ -1,875 +0,0 @@ -""" -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2024/12/30 15:19 -describe: 基于 clickhouse 的策略持仓权重管理,cwc 为 clickhouse weights client 的缩写 - -时区标记:https://clockhub.app/zh-CN/timezone - -推荐在环境变量中设置 clickhouse 的连接信息,如下: - -- CLICKHOUSE_HOST: 服务器地址,如 127.0.0.1 -- CLICKHOUSE_PORT: 服务器端口,如 9000 -- CLICKHOUSE_USER: 用户名, 如 default -- CLICKHOUSE_PASS: 密码, 如果没有密码,可以设置为空字符串 - -额外可选配置: - -- CLICKHOUSE_CONNECT_TIMEOUT: 建立连接的超时时间(秒),默认为 10 -- CLICKHOUSE_SEND_RECEIVE_TIMEOUT: 发送/接收(读写)的超时时间(秒),默认为 60 -""" - -# pip install clickhouse_connect -i https://pypi.tuna.tsinghua.edu.cn/simple -import os -from typing import TYPE_CHECKING -from zoneinfo import ZoneInfo - -import loguru -import pandas as pd - -if TYPE_CHECKING: - from clickhouse_connect.driver.client import Client - - -def _ensure_timestamp(value, tz=ZoneInfo("Asia/Shanghai")): - """将任意时间对象转换为带时区的 Timestamp。 - - :param value: 时间值 - :param tz: 时区,默认 Asia/Shanghai - :return: 带时区的 Timestamp - """ - if value is None: - return pd.NaT - if isinstance(value, str) and not value.strip(): - return pd.NaT - - ts = pd.to_datetime(value, errors="coerce") - if pd.isna(ts): - return pd.NaT - - if ts.tzinfo is None: - return ts.tz_localize(tz) - return ts.tz_convert(tz) - - -def _ensure_series_tz(series: pd.Series, tz=ZoneInfo("Asia/Shanghai")) -> pd.Series: - """确保 Series 中的时间字段带有指定时区。 - - :param series: 时间序列 - :param tz: 时区,默认 Asia/Shanghai - :return: 带时区的 Series - """ - from pandas.api.types import is_datetime64tz_dtype - - ser = pd.to_datetime(series, errors="coerce") - if is_datetime64tz_dtype(ser): - return ser.dt.tz_convert(tz) - return ser.dt.tz_localize(tz) - - -def _format_for_db(value, tz=ZoneInfo("Asia/Shanghai")) -> str | None: - """将带时区的时间对象格式化为 ClickHouse 兼容的字符串。 - - :param value: 时间值 - :param tz: 时区,默认 Asia/Shanghai - :return: 格式化的时间字符串 - """ - ts = _ensure_timestamp(value, tz=tz) - if pd.isna(ts): - return None - return ts.astimezone(tz).replace(tzinfo=None).strftime("%Y-%m-%d %H:%M:%S") - - -def _localize_dataframe_columns(df: pd.DataFrame, columns, tz=ZoneInfo("Asia/Shanghai")) -> pd.DataFrame: - """将 DataFrame 中指定列本地化为指定时区。 - - :param df: DataFrame - :param columns: 需要转换的列名列表 - :param tz: 时区,默认 Asia/Shanghai - :return: 转换后的 DataFrame - """ - for col in columns: - if col in df.columns: - df[col] = _ensure_series_tz(df[col], tz=tz) - return df - - -def __db_from_env(): - host = os.getenv("CLICKHOUSE_HOST") - port_str = os.getenv("CLICKHOUSE_PORT") - port = int(port_str) if port_str else None - user = os.getenv("CLICKHOUSE_USER") - password = os.getenv("CLICKHOUSE_PASS") - # 建立连接的超时时间(秒) - connect_timeout = int(os.getenv("CLICKHOUSE_CONNECT_TIMEOUT", 10)) - # 发送/接收(读写)的超时时间(秒) - send_receive_timeout = int(os.getenv("CLICKHOUSE_SEND_RECEIVE_TIMEOUT", 60)) - - if not (host and port and user and password): - raise ValueError( - """ - 请设置环境变量: - - # 必须 - - CLICKHOUSE_HOST: 服务器地址,如 127.0.0.1 - - CLICKHOUSE_PORT: 服务器端口,如 9000 - - CLICKHOUSE_USER: 用户名, 如 default - - CLICKHOUSE_PASS: 密码, 如果没有密码,可以设置为空字符串 - - # 可选 - - CLICKHOUSE_CONNECT_TIMEOUT: 建立连接的超时时间(秒),默认为 10 - - CLICKHOUSE_SEND_RECEIVE_TIMEOUT: 发送/接收(读写)的超时时间(秒),默认为 60 - - CZSC_TIMEZONE: 时区名称,例如 Asia/Shanghai,默认使用 Asia/Shanghai;若系统不支持该时区,则回退到 UTC - - """ - ) - - import clickhouse_connect as ch - - db = ch.get_client( - host=host, - port=port, - user=user, - password=password, - connect_timeout=connect_timeout, - send_receive_timeout=send_receive_timeout, - ) - return db - - -def init_latest_weights_view(db: "Client | None" = None, database="czsc_strategy", **kwargs): - """ - 策略类型有 cs 和 ts,这两种策略对应的写入逻辑有所区别,需要单独创建最新持仓的视图,然后再合并。 - """ - db = db or __db_from_env() - logger = kwargs.get("logger", loguru.logger) - - # 创建截面策略的最新持仓视图 - cs_view_sql = f""" - CREATE VIEW IF NOT EXISTS {database}.cs_latest_weights AS - WITH latest_dates AS ( - SELECT - strategy, - MAX(dt) AS latest_dt - FROM {database}.weights - GROUP BY strategy - ) - SELECT - w.dt as dt, - w.symbol as symbol, - w.weight as weight, - w.strategy as strategy, - w.update_time as update_time - FROM {database}.weights w - JOIN latest_dates ld ON w.strategy = ld.strategy AND w.dt = ld.latest_dt - JOIN {database}.metas m ON w.strategy = m.strategy - WHERE m.weight_type = 'cs' - """ - db.command(cs_view_sql) - logger.info("cs_latest_weights 视图初始化完成") - - # 创建时序策略的最新持仓视图 - ts_view_sql = f""" - CREATE VIEW IF NOT EXISTS {database}.ts_latest_weights AS - WITH latest_records AS ( - SELECT - strategy, - symbol, - MAX(dt) AS latest_dt - FROM {database}.weights - GROUP BY strategy, symbol - ) - SELECT - w.dt as dt, - w.symbol as symbol, - w.weight as weight, - w.strategy as strategy, - w.update_time as update_time - FROM {database}.weights w - JOIN latest_records lr ON w.strategy = lr.strategy - AND w.symbol = lr.symbol - AND w.dt = lr.latest_dt - JOIN {database}.metas m ON w.strategy = m.strategy - WHERE m.weight_type = 'ts' - """ - db.command(ts_view_sql) - logger.info("ts_latest_weights 视图初始化完成") - - # 创建合并的最新持仓视图 - latest_view_sql = f""" - CREATE VIEW IF NOT EXISTS {database}.latest_weights AS - SELECT * FROM {database}.ts_latest_weights - UNION ALL - SELECT * FROM {database}.cs_latest_weights - """ - db.command(latest_view_sql) - logger.info("latest_weights 视图初始化完成") - - logger.info("所有最新持仓权重视图初始化完成") - - -def init_tables(db: "Client | None" = None, database="czsc_strategy", **kwargs): - """ - 创建数据库表 - - :param db: clickhouse_connect.driver.Client, 数据库连接 - :param database: str, 数据库名称 - :param kwargs: dict, 数据表名和建表语句 - :return: None - """ - logger = kwargs.get("logger", loguru.logger) - - db = db or __db_from_env() - - # 创建数据库 - db.command(f"CREATE DATABASE IF NOT EXISTS {database}") - - metas_table = f""" - CREATE TABLE IF NOT EXISTS {database}.metas ( - strategy String NOT NULL, -- 策略名(唯一且不能为空) - base_freq String, -- 周期 - description String, -- 描述 - author String, -- 作者 - outsample_sdt DateTime('Asia/Shanghai'), -- 样本外起始时间 - create_time DateTime('Asia/Shanghai'), -- 策略入库时间 - update_time DateTime('Asia/Shanghai'), -- 策略更新时间 - heartbeat_time DateTime('Asia/Shanghai'), -- 最后一次心跳时间 - weight_type String, -- 策略上传的权重类型,ts 或 cs - status String DEFAULT '实盘', -- 策略状态:实盘、废弃 - memo String -- 策略备忘信息 - ) - ENGINE = ReplacingMergeTree() - ORDER BY strategy; - """ - - weights_table = f""" - CREATE TABLE IF NOT EXISTS {database}.weights ( - dt DateTime('Asia/Shanghai'), -- 持仓权重时间 - symbol String, -- 符号(例如,股票代码或其他标识符) - weight Float64, -- 策略持仓权重值 - strategy String, -- 策略名称 - update_time DateTime('Asia/Shanghai') -- 持仓权重更新时间 - ) - ENGINE = ReplacingMergeTree() - ORDER BY (strategy, dt, symbol); - """ - - returns_table = f""" - CREATE TABLE IF NOT EXISTS {database}.returns ( - dt DateTime('Asia/Shanghai'), -- 时间 - symbol String, -- 符号(例如,股票代码或其他标识符) - returns Float64, -- 策略收益,从上一个 dt 到当前 dt 的收益 - strategy String, -- 策略名称 - update_time DateTime('Asia/Shanghai') -- 更新时间 - ) - ENGINE = ReplacingMergeTree() - ORDER BY (strategy, dt, symbol); - """ - - db.command(metas_table) - logger.info("metas 表创建成功!") - db.command(weights_table) - logger.info("weights 表创建成功!") - db.command(returns_table) - logger.info("returns 表创建成功!") - - -def initialize(db: "Client | None" = None, database="czsc_strategy", **kwargs): - """初始化数据库,包括创建数据表和最新持仓视图 - - :param db: clickhouse_connect.driver.Client, 数据库连接 - :param database: str, 数据库名称 - :param kwargs: dict, 其他参数 - :return: None - """ - db = db or __db_from_env() - init_tables(db=db, database=database, **kwargs) - init_latest_weights_view(db=db, database=database, **kwargs) - - -def get_meta( - strategy, - db: "Client | None" = None, - database="czsc_strategy", - logger=loguru.logger, - tz=ZoneInfo("Asia/Shanghai"), -) -> dict: - """获取策略元数据 - - :param strategy: str, 策略名称 - :param db: clickhouse_connect.driver.Client, 数据库连接 - :param database: str, 数据库名称 - :param logger: loguru.logger, 日志记录器 - :param tz: 时区,默认 Asia/Shanghai - :return: dict - """ - db = db or __db_from_env() - df = db.query_df( - f"SELECT * FROM {database}.metas final WHERE strategy = %(strategy)s", parameters={"strategy": strategy} - ) - - if df.empty: - logger.warning(f"策略 {strategy} 不存在元数据") - return {} - - assert len(df) == 1, f"策略 {strategy} 存在多条元数据,请检查" - df = _localize_dataframe_columns(df, ["outsample_sdt", "create_time", "update_time", "heartbeat_time"], tz=tz) - return df.iloc[0].to_dict() - - -def get_all_metas(db: "Client | None" = None, database="czsc_strategy", tz=ZoneInfo("Asia/Shanghai")) -> pd.DataFrame: - """获取所有策略元数据 - - :param db: clickhouse_connect.driver.Client, 数据库连接 - :param database: str, 数据库名称 - :param tz: 时区,默认 Asia/Shanghai - :return: pd.DataFrame - """ - db = db or __db_from_env() - df = db.query_df(f"SELECT * FROM {database}.metas final") - if not df.empty: - df = _localize_dataframe_columns(df, ["outsample_sdt", "create_time", "update_time", "heartbeat_time"], tz=tz) - return df - - -def set_meta( - strategy, - base_freq, - description, - author, - outsample_sdt, - weight_type="ts", - status="实盘", - memo="", - logger=loguru.logger, - overwrite=False, - database="czsc_strategy", - db: "Client | None" = None, - tz=ZoneInfo("Asia/Shanghai"), -): - """设置策略元数据 - - :param strategy: str, 策略名 - :param base_freq: str, 周期 - :param description: str, 描述 - :param author: str, 作者 - :param outsample_sdt: str, 样本外起始时间 - :param weight_type: str, 权重类型,ts 或 cs - :param status: str, 策略状态,实盘 或 废弃 - :param memo: str, 备注 - :param logger: loguru.logger, 日志记录器 - :param overwrite: bool, 是否覆盖已有元数据 - :param database: str, 数据库名称 - :param db: clickhouse_connect.driver.Client, 数据库连接 - :param tz: 时区,默认 Asia/Shanghai - :return: None - """ - db = db or __db_from_env() - - outsample_sdt = _ensure_timestamp(outsample_sdt, tz=tz) - current_time = pd.Timestamp.now(tz=tz) - meta = get_meta(db=db, strategy=strategy, database=database, tz=tz) - - if not overwrite and meta: - logger.warning(f"策略 {strategy} 已存在元数据,如需更新请设置 overwrite=True") - return - - create_time = current_time if not meta else _ensure_timestamp(meta.get("create_time"), tz=tz) - - # 构建DataFrame用于插入 - df = pd.DataFrame( - [ - { - "strategy": strategy, - "base_freq": base_freq, - "description": description, - "author": author, - "outsample_sdt": outsample_sdt, - "create_time": create_time, - "update_time": current_time, - "heartbeat_time": current_time, - "weight_type": weight_type, - "status": status, - "memo": memo, - } - ] - ) - res = db.insert_df(f"{database}.metas", df) - logger.info(f"{strategy} set_metadata: {res.summary}") - - -def __send_heartbeat( - db: "Client", strategy, logger=loguru.logger, database="czsc_strategy", tz=ZoneInfo("Asia/Shanghai") -): - """发送心跳 - - :param db: clickhouse_connect.driver.Client, 数据库连接 - :param strategy: str, 策略名称 - :param logger: loguru.logger, 日志记录器 - :param database: str, 数据库名称 - :param tz: 时区,默认 Asia/Shanghai - :return: None - """ - try: - meta = get_meta(db=db, strategy=strategy, database=database, logger=logger, tz=tz) - if not meta: - logger.warning(f"策略 {strategy} 不存在元数据,无法发送心跳") - return - - current_time = _format_for_db(pd.Timestamp.now(tz=tz), tz=tz) - db.command( - f"ALTER TABLE {database}.metas UPDATE heartbeat_time = %(current_time)s WHERE strategy = %(strategy)s", - parameters={"current_time": current_time, "strategy": strategy}, - ) - logger.info(f"策略 {strategy} 发送心跳成功") - except Exception as e: - logger.error(f"发送心跳失败: {e}") - raise - - -def get_strategy_weights( - strategy, - db: "Client | None" = None, - sdt=None, - edt=None, - symbols=None, - database="czsc_strategy", - tz=ZoneInfo("Asia/Shanghai"), -): - """获取策略持仓权重 - - :param db: clickhouse_connect.driver.Client, 数据库连接 - :param strategy: str, 策略名称 - :param sdt: str, 开始时间 - :param edt: str, 结束时间 - :param symbols: list, 符号列表 - :param database: str, 数据库名称 - :return: pd.DataFrame - """ - db = db or __db_from_env() - - parameters = {"strategy": strategy} - query = f""" - SELECT * FROM {database}.weights final WHERE strategy = %(strategy)s - """ - if sdt: - sdt_str = _format_for_db(sdt) - if sdt_str: - query += " AND dt >= %(sdt)s" - parameters["sdt"] = sdt_str - if edt: - edt_str = _format_for_db(edt) - if edt_str: - query += " AND dt <= %(edt)s" - parameters["edt"] = edt_str - if symbols: - if isinstance(symbols, str): - symbols = [symbols] - query += " AND symbol IN %(symbols)s" - parameters["symbols"] = tuple(symbols) - - df = db.query_df(query, parameters=parameters) - if not df.empty: - df = _localize_dataframe_columns(df, ["dt", "update_time"], tz=tz) - df = df.sort_values(["dt", "symbol"]).reset_index(drop=True) - return df - - -def get_latest_weights( - db: "Client | None" = None, strategy=None, database="czsc_strategy", tz=ZoneInfo("Asia/Shanghai") -) -> pd.DataFrame: - """获取策略最新持仓权重时间 - - :param db: clickhouse_connect.driver.Client, 数据库连接 - :param strategy: str, 策略名称, 默认 None - :param database: str, 数据库名称 - :return: pd.DataFrame - """ - db = db or __db_from_env() - - query = f"SELECT * FROM {database}.latest_weights final" - parameters = {} - if strategy: - query += " WHERE strategy = %(strategy)s" - parameters["strategy"] = strategy - - df = db.query_df(query, parameters=parameters) - df = df.rename(columns={"latest_dt": "dt", "latest_weight": "weight", "latest_update_time": "update_time"}) - if not df.empty: - df = _localize_dataframe_columns(df, ["dt", "update_time"], tz=tz) - df = df.sort_values(["strategy", "dt", "symbol"]).reset_index(drop=True) - return df - - -def publish_weights( - strategy: str, - df: pd.DataFrame, - batch_size=100000, - logger=loguru.logger, - db: "Client | None" = None, - database="czsc_strategy", - tz=ZoneInfo("Asia/Shanghai"), -): - """发布策略持仓权重 - - :param df: pd.DataFrame, 待发布的持仓权重数据 - :param db: clickhouse_connect.driver.Client, 数据库连接 - :param strategy: str, 策略名称 - :param batch_size: int, 批量发布的大小, 默认 100000 - :param logger: loguru.logger, 日志记录器 - :param database: str, 数据库名称 - :return: None - """ - db = db or __db_from_env() - - __send_heartbeat(db, strategy, database=database, logger=logger, tz=tz) - df = df[["dt", "symbol", "weight"]].copy() - df["strategy"] = strategy - df["dt"] = _ensure_series_tz(df["dt"], tz=tz) - - dfl = get_latest_weights(db, strategy, tz=tz, database=database) - - if not dfl.empty: - dfl = _localize_dataframe_columns(dfl, ["dt"], tz=tz) - symbol_dt = dfl.set_index("symbol")["dt"].to_dict() - logger.info(f"策略 {strategy} 最新时间:{dfl['dt'].max()}") - - rows = [] - for symbol, dfg in df.groupby("symbol"): - if symbol in symbol_dt: - dfg = dfg[dfg["dt"] > symbol_dt[symbol]].copy().reset_index(drop=True) - rows.append(dfg) - - if rows: - df = pd.concat(rows, ignore_index=True) - - logger.info(f"策略 {strategy} 共 {len(df)} 条新信号") - - df = df.sort_values(["dt", "symbol"]).reset_index(drop=True) - df["update_time"] = pd.Timestamp.now(tz=tz) - df = df[["strategy", "symbol", "dt", "weight", "update_time"]].copy() - df = df.drop_duplicates(["symbol", "dt", "strategy"], keep="last").reset_index(drop=True) - df["weight"] = df["weight"].astype(float) - - logger.info(f"准备发布 {len(df)} 条策略信号") - - # 批量写入 - for i in range(0, len(df), batch_size): - batch_df = df.iloc[i : i + batch_size] - res = db.insert_df(f"{database}.weights", batch_df) - __send_heartbeat(db, strategy, tz=tz, database=database, logger=logger) - - if res: - logger.info(f"完成批次 {i // batch_size + 1}, 发布 {len(batch_df)} 条信号") - else: - logger.error(f"批次 {i // batch_size + 1} 发布失败: {res}") - return - - logger.info(f"完成所有信号发布, 共 {len(df)} 条") - __send_heartbeat(db, strategy, tz=tz, database=database, logger=logger) - - -def publish_returns( - strategy: str, - df: pd.DataFrame, - batch_size=100000, - logger=loguru.logger, - database="czsc_strategy", - db: "Client | None" = None, - tz=ZoneInfo("Asia/Shanghai"), -): - """发布策略日收益 - - :param df: pd.DataFrame, 待发布的日收益数据, 必须包含 dt, symbol, returns 三列 - :param db: clickhouse_connect.driver.Client, 数据库连接 - :param strategy: str, 策略名称 - :param batch_size: int, 批量发布的大小, 默认 100000 - :param logger: loguru.logger, 日志记录器 - :return: None - """ - db = db or __db_from_env() - - df = df[["dt", "symbol", "returns"]].copy() - df["strategy"] = strategy - df["dt"] = _ensure_series_tz(df["dt"], tz=tz) - - # 查询 czsc_strategy.returns 表中,每个品种最新的时间 - dfl = db.query_df( - f"SELECT symbol, max(dt) as dt FROM {database}.returns final WHERE strategy = %(strategy)s GROUP BY symbol", - parameters={"strategy": strategy}, - ) - - if not dfl.empty: - dfl["dt"] = _ensure_series_tz(dfl["dt"], tz=tz) - symbol_dt = dfl.set_index("symbol")["dt"].to_dict() - logger.info(f"策略 {strategy} 最新时间:{dfl['dt'].max()}") - - rows = [] - for symbol, dfg in df.groupby("symbol"): - if symbol in symbol_dt: - # 允许覆盖同一天的数据 - dfg = dfg[dfg["dt"] >= symbol_dt[symbol]].copy() - rows.append(dfg) - if rows: - df = pd.concat(rows, ignore_index=True) - - logger.info(f"策略 {strategy} 共 {len(df)} 条新日收益") - - df = df.sort_values(["dt", "symbol"]).reset_index(drop=True) - df["update_time"] = pd.Timestamp.now(tz=tz) - df = df[["strategy", "symbol", "dt", "returns", "update_time"]].copy() - df = df.drop_duplicates(["symbol", "dt", "strategy"], keep="last").reset_index(drop=True) - df["returns"] = df["returns"].astype(float) - - logger.info(f"准备发布 {len(df)} 条策略日收益") - - # 批量写入 - for i in range(0, len(df), batch_size): - batch_df = df.iloc[i : i + batch_size] - res = db.insert_df(f"{database}.returns", batch_df) - - if res: - logger.info(f"完成批次 {i // batch_size + 1}, 发布 {len(batch_df)} 条日收益") - else: - logger.error(f"批次 {i // batch_size + 1} 发布失败") - return - - logger.info(f"完成所有日收益发布, 共 {len(df)} 条") - - -def get_strategy_returns( - strategy, - db: "Client | None" = None, - sdt=None, - edt=None, - symbols=None, - database="czsc_strategy", - tz=ZoneInfo("Asia/Shanghai"), -): - """获取策略日收益 - - :param db: clickhouse_connect.driver.Client, 数据库连接 - :param strategy: str, 策略名称 - :param sdt: str, 开始时间 - :param edt: str, 结束时间 - :param symbols: list, 符号列表 - :param database: str, 数据库名称 - :return: pd.DataFrame - """ - db = db or __db_from_env() - - parameters = {"strategy": strategy} - query = f""" - SELECT * FROM {database}.returns final WHERE strategy = %(strategy)s - """ - if sdt: - sdt_ts = _ensure_timestamp(sdt).tz_convert(tz) - sdt_ts = sdt_ts.replace(hour=0, minute=0, second=0, microsecond=0) - sdt_str = _format_for_db(sdt_ts) - if sdt_str: - query += " AND dt >= %(sdt)s" - parameters["sdt"] = sdt_str - if edt: - edt_ts = _ensure_timestamp(edt).tz_convert(tz) - edt_ts = edt_ts.replace(hour=23, minute=59, second=59, microsecond=0) - edt_str = _format_for_db(edt_ts) - if edt_str: - query += " AND dt <= %(edt)s" - parameters["edt"] = edt_str - if symbols: - if isinstance(symbols, str): - symbols = [symbols] - query += " AND symbol IN %(symbols)s" - parameters["symbols"] = tuple(symbols) - - df = db.query_df(query, parameters=parameters) - if not df.empty: - df = _localize_dataframe_columns(df, ["dt", "update_time"], tz=tz) - df = df.sort_values(["dt", "symbol"]).reset_index(drop=True) - return df - - -def update_strategy_status( - strategy, - status, - db: "Client | None" = None, - logger=loguru.logger, - database="czsc_strategy", - tz=ZoneInfo("Asia/Shanghai"), -): - """更新策略状态 - - :param strategy: str, 策略名称 - :param status: str, 策略状态,实盘 或 废弃 - :param db: clickhouse_connect.driver.Client, 数据库连接 - :param logger: loguru.logger, 日志记录器 - :param database: str, 数据库名称 - :return: None - """ - db = db or __db_from_env() - - # 验证状态值 - valid_statuses = ["实盘", "废弃"] - if status not in valid_statuses: - raise ValueError(f"无效的策略状态: {status},有效状态为: {valid_statuses}") - - # 检查策略是否存在 - meta = get_meta(db=db, strategy=strategy, database=database, logger=logger, tz=tz) - if not meta: - logger.warning(f"策略 {strategy} 不存在,无法更新状态") - return - - current_time = _format_for_db(pd.Timestamp.now(tz=tz), tz=tz) - - # 更新策略状态和更新时间 - query = f""" - ALTER TABLE {database}.metas - UPDATE status = %(status)s, update_time = %(current_time)s - WHERE strategy = %(strategy)s - """ - - db.command(query, parameters={"status": status, "current_time": current_time, "strategy": strategy}) - logger.info(f"策略 {strategy} 状态已更新为: {status}") - - -def get_strategies_by_status( - status=None, db: "Client | None" = None, database="czsc_strategy", tz=ZoneInfo("Asia/Shanghai") -) -> pd.DataFrame: - """根据状态获取策略列表 - - :param status: str, 策略状态,实盘 或 废弃,None 表示获取所有状态 - :param db: clickhouse_connect.driver.Client, 数据库连接 - :param database: str, 数据库名称 - :return: pd.DataFrame - """ - db = db or __db_from_env() - - query = f"SELECT * FROM {database}.metas final" - parameters = {} - if status: - query += " WHERE status = %(status)s" - parameters["status"] = status - - df = db.query_df(query, parameters=parameters) - if not df.empty: - df = _localize_dataframe_columns(df, ["outsample_sdt", "create_time", "update_time", "heartbeat_time"], tz=tz) - return df - - -def clear_strategy( - strategy, - db: "Client | None" = None, - logger=loguru.logger, - human_confirm=True, - database="czsc_strategy", - tz=ZoneInfo("Asia/Shanghai"), -): - """清空策略 - - :param db: clickhouse_connect.driver.Client, 数据库连接 - :param strategy: str, 策略名称 - :param logger: loguru.logger, 日志记录器 - :param human_confirm: bool, 是否需要人工确认,默认 True - :param database: str, 数据库名称 - :return: None - """ - db = db or __db_from_env() - - # 删除前,先查询跟这个策略相关的数据情况 - meta = get_meta(db=db, strategy=strategy, database=database, logger=logger, tz=tz) - if not meta: - logger.warning(f"策略 {strategy} 不存在,无需清空") - return - - # 统计各个表中的数据量 - try: - # 统计权重数据量 - weights_count = db.query_df( - f"SELECT count(*) as count FROM {database}.weights final WHERE strategy = %(strategy)s", - parameters={"strategy": strategy}, - ).iloc[0]["count"] - - # 统计收益数据量 - returns_count = db.query_df( - f"SELECT count(*) as count FROM {database}.returns final WHERE strategy = %(strategy)s", - parameters={"strategy": strategy}, - ).iloc[0]["count"] - - # 获取权重数据的时间范围 - weights_time_range = db.query_df( - f""" - SELECT min(dt) as min_dt, max(dt) as max_dt - FROM {database}.weights final - WHERE strategy = %(strategy)s - """, - parameters={"strategy": strategy}, - ) - - # 获取收益数据的时间范围 - returns_time_range = db.query_df( - f""" - SELECT min(dt) as min_dt, max(dt) as max_dt - FROM {database}.returns final - WHERE strategy = %(strategy)s - """, - parameters={"strategy": strategy}, - ) - - # 输出数据概况 - logger.info(f"策略 {strategy} 数据概况:") - logger.info(f" - 策略状态: {meta.get('status', '未知')}") - logger.info(f" - 创建时间: {meta.get('create_time', '未知')}") - logger.info(f" - 最后更新: {meta.get('update_time', '未知')}") - logger.info(f" - 权重数据: {weights_count:,} 条") - - if not weights_time_range.empty: - weights_time_range = _localize_dataframe_columns(weights_time_range, ["min_dt", "max_dt"], tz=tz) - if not weights_time_range.empty and weights_time_range.iloc[0]["min_dt"] is not None: - min_dt = weights_time_range.iloc[0]["min_dt"] - max_dt = weights_time_range.iloc[0]["max_dt"] - logger.info(f" 时间范围: {min_dt} 至 {max_dt}") - - logger.info(f" - 收益数据: {returns_count:,} 条") - - if not returns_time_range.empty: - returns_time_range = _localize_dataframe_columns(returns_time_range, ["min_dt", "max_dt"], tz=tz) - if not returns_time_range.empty and returns_time_range.iloc[0]["min_dt"] is not None: - min_dt = returns_time_range.iloc[0]["min_dt"] - max_dt = returns_time_range.iloc[0]["max_dt"] - logger.info(f" 时间范围: {min_dt} 至 {max_dt}") - - total_records = weights_count + returns_count + 1 # +1 for meta record - logger.info(f" - 总计将删除: {total_records:,} 条记录") - - except Exception as e: - logger.error(f"查询策略 {strategy} 数据概况失败: {e}") - logger.info("将继续执行删除操作...") - - if human_confirm: - logger.info("\n" + "=" * 60) - logger.info(f"⚠️ 警告:即将删除策略 {strategy} 的所有数据") - logger.info("=" * 60) - confirm = input("请仔细确认上述信息,确认删除请输入 'DELETE' (大小写敏感): ") - if confirm != "DELETE": - logger.warning(f"取消清空策略 {strategy} 的所有数据") - return - logger.info("开始执行删除操作...") - - query = f""" - DELETE FROM {database}.metas WHERE strategy = %(strategy)s - """ - _ = db.command(query, parameters={"strategy": strategy}) - logger.info(f"清空策略 {strategy} 元数据成功") - - query = f""" - DELETE FROM {database}.weights WHERE strategy = %(strategy)s - """ - _ = db.command(query, parameters={"strategy": strategy}) - logger.info(f"清空策略 {strategy} 持仓权重成功") - - query = f""" - DELETE FROM {database}.returns WHERE strategy = %(strategy)s - """ - _ = db.command(query, parameters={"strategy": strategy}) - logger.info(f"清空策略 {strategy} 日收益成功") - logger.warning(f"策略 {strategy} 清空完成") diff --git a/czsc/traders/cwc.pyi b/czsc/traders/cwc.pyi deleted file mode 100644 index ce0e839ae..000000000 --- a/czsc/traders/cwc.pyi +++ /dev/null @@ -1,59 +0,0 @@ -import pandas as pd -from clickhouse_connect.driver.client import Client as Client - -def init_latest_weights_view(db: Client | None = None, database: str = "czsc_strategy", **kwargs): ... -def init_tables(db: Client | None = None, database: str = "czsc_strategy", **kwargs): ... -def initialize(db: Client | None = None, database: str = "czsc_strategy", **kwargs): ... -def get_meta(strategy, db: Client | None = None, database: str = "czsc_strategy", logger=..., tz=...) -> dict: ... -def get_all_metas(db: Client | None = None, database: str = "czsc_strategy", tz=...) -> pd.DataFrame: ... -def set_meta( - strategy, - base_freq, - description, - author, - outsample_sdt, - weight_type: str = "ts", - status: str = "实盘", - memo: str = "", - logger=..., - overwrite: bool = False, - database: str = "czsc_strategy", - db: Client | None = None, - tz=..., -): ... -def get_strategy_weights( - strategy, db: Client | None = None, sdt=None, edt=None, symbols=None, database: str = "czsc_strategy", tz=... -): ... -def get_latest_weights( - db: Client | None = None, strategy=None, database: str = "czsc_strategy", tz=... -) -> pd.DataFrame: ... -def publish_weights( - strategy: str, - df: pd.DataFrame, - batch_size: int = 100000, - logger=..., - db: Client | None = None, - database: str = "czsc_strategy", - tz=..., -): ... -def publish_returns( - strategy: str, - df: pd.DataFrame, - batch_size: int = 100000, - logger=..., - database: str = "czsc_strategy", - db: Client | None = None, - tz=..., -): ... -def get_strategy_returns( - strategy, db: Client | None = None, sdt=None, edt=None, symbols=None, database: str = "czsc_strategy", tz=... -): ... -def update_strategy_status( - strategy, status, db: Client | None = None, logger=..., database: str = "czsc_strategy", tz=... -): ... -def get_strategies_by_status( - status=None, db: Client | None = None, database: str = "czsc_strategy", tz=... -) -> pd.DataFrame: ... -def clear_strategy( - strategy, db: Client | None = None, logger=..., human_confirm: bool = True, database: str = "czsc_strategy", tz=... -): ... diff --git a/czsc/traders/optimize.py b/czsc/traders/optimize.py new file mode 100644 index 000000000..73d703318 --- /dev/null +++ b/czsc/traders/optimize.py @@ -0,0 +1,500 @@ +"""czsc.traders.optimize —— 开仓/平仓参数批量优化的 Python 外观层。 + +本模块对外保留与历史版本一致的类式调用接口(``OpensOptimize``、 +``ExitsOptimize`` 以及与之配套的两个策略类 ``CzscOpenOptimStrategy``、 +``CzscExitOptimStrategy``),但实际的优化计算已统一委托给 Rust 端的批量 +优化引擎 ``czsc.research.run_optimize_batch``,以便利用其多线程与底层 +向量化能力。 + +Python 侧的主要职责: + +1. **配置归一化**:将候选信号、候选事件等多种灵活输入形式归一到 Rust 侧 + 接受的稳定结构。 +2. **物化数据**:将原始 K 线和持仓配置序列化到磁盘 parquet/JSON 文件, + 作为 Rust 引擎读取的数据源。 +3. **任务命名/哈希**:基于候选输入与品种集合生成唯一任务哈希,便于结果 + 目录隔离与任务复用。 +4. **结果转发**:保留 Rust 引擎的执行结果对象,并在实例上记录 ``message`` + 等关键字段供调用方查看。 + +策略类(``CzscOpenOptimStrategy`` / ``CzscExitOptimStrategy``)则负责将 +单个"基准仓位(beta)"按候选信号或候选事件展开为多组变体仓位,配合上层 +研究框架完成参数空间扫描。 +""" + +from __future__ import annotations + +import hashlib +import json +from pathlib import Path +from typing import Any, Callable + +# 兼容层提供的辅助函数,统一处理 Python <-> Rust 之间的数据转换、 +# 序列化、哈希计算和事件归一化等跨语言桥接逻辑。 +from czsc._compat import ( + bars_to_dataframe, + md5_upper8, + normalize_candidate_event, + normalize_candidate_events, + position_dump_to_runtime, + py_repr_json, + py_repr_list_str, +) +from czsc._native import Position +from czsc.research import run_optimize_batch +from czsc.strategies import CzscStrategyBase + + +def _signal_to_kv(signal: dict[str, Any] | str) -> dict[str, str]: + """将单个候选信号统一转换为 ``{"key": ..., "value": ...}`` 结构。 + + 输入可以是已经构造好的字典,也可以是缠论信号约定的字符串形式。 + 字符串信号通常按 ``_`` 分割成多段,最后 4 段视为信号取值(value), + 其余段视为信号键(key)。 + + Args: + signal: 待规范化的信号;支持 ``dict`` 或 ``str``。 + + Returns: + 包含 ``key`` 与 ``value`` 字段的字典;两个字段都被显式转换为字符串。 + + Raises: + ValueError: 当 ``signal`` 为字符串但分段数不足 5 段时抛出, + 说明该信号字符串不符合 czsc 命名规范。 + """ + if isinstance(signal, dict): + # 字典输入只需做类型规范化,避免下游 Rust 端遇到非字符串类型时报错。 + return {"key": str(signal["key"]), "value": str(signal["value"])} + parts = str(signal).split("_") + if len(parts) < 5: + # czsc 信号串至少形如 "freq_k1_k2_v1_v2_v3_v4",不足 5 段时无法切分。 + raise ValueError(f"invalid signal string: {signal}") + # 约定最后 4 段为取值(value),其余段拼接回完整的 key。 + return {"key": "_".join(parts[:-4]), "value": "_".join(parts[-4:])} + + +def _read_bars(read_bars: Callable, symbol: str, base_freq: str, bar_sdt: str, bar_edt: str): + """以兼容方式调用上层提供的 ``read_bars`` 函数获取原始 K 线。 + + 某些用户自定义的 ``read_bars`` 不接受 ``fq``、``raw_bar`` 等扩展参数, + 本函数对此进行 ``TypeError`` 兜底重试,从而避免使用方修改自己的实现 + 就能接入优化框架。 + + Args: + read_bars: 用户提供的 K 线读取函数。 + symbol: 标的代码。 + base_freq: 基础 K 线周期(如 "30分钟")。 + bar_sdt: 起始日期,字符串格式。 + bar_edt: 结束日期,字符串格式。 + + Returns: + ``read_bars`` 实现返回的 K 线对象(通常是 RawBar 列表)。 + """ + try: + # 优先以"完整签名"调用,确保拉取的是后复权的 RawBar 序列。 + return read_bars(symbol, base_freq, bar_sdt, bar_edt, fq="后复权", raw_bar=True) + except TypeError: + # 用户函数若签名较老/较窄,则退化到最简调用形式。 + return read_bars(symbol, base_freq, bar_sdt, bar_edt) + + +class CzscOpenOptimStrategy(CzscStrategyBase): + """开仓参数优化所使用的兼容策略类。 + + 本策略以一个"基准仓位(beta position)"为种子,根据用户提供的候选开仓 + 信号集合,为每条候选信号派生出一份新的仓位变体;最终的 ``positions`` + 属性返回基准仓位与所有变体仓位的并集,供研究框架批量回测对比。 + """ + + @staticmethod + def update_beta_opens(beta: Position, open_signals_all): + """复制基准仓位,并在其首个开仓事件中追加额外的"全部满足"信号。 + + Args: + beta: 作为模板复制的基准 :class:`Position` 实例。 + open_signals_all: 候选开仓信号;可以是单个信号字符串/字典, + 也可以是包含多个信号的列表,统一被视作"全部满足"约束。 + + Returns: + 一份新的 :class:`Position` 实例,其名称带有 ``#<8 位哈希>`` + 后缀,用于在结果目录中区分不同变体。 + """ + if isinstance(open_signals_all, str): + # 兼容用户传入单个信号字符串的情形,统一升级为列表处理。 + open_signals_all = [open_signals_all] + + # 把所有候选信号统一转成 {"key": ..., "value": ...} 字典格式。 + normalized = [_signal_to_kv(sig) for sig in open_signals_all] + # 以 with_data=False 拿到仓位的"配置面"快照,避免拷贝运行时数据。 + pos_dict = beta.dump(with_data=False) + # 用候选信号集合的 md5 前 8 位作为名称后缀,确保变体名称唯一可追溯。 + sig_hash = hashlib.md5(str(normalized).encode("utf-8")).hexdigest()[:8].upper() + pos_dict["name"] = f"{beta.name}#{sig_hash}" + # 在第一个开仓事件的 signals_all(AND 约束)列表中追加候选信号。 + pos_dict["opens"][0]["signals_all"].extend(normalized) + return Position.load(pos_dict) + + @property + def positions(self): + """构造所有待回测的仓位列表(基准 + 各候选信号变体)。 + + Returns: + ``list[Position]``:先包含全部基准仓位,再追加每个基准仓位与 + 每条候选信号交叉派生出的变体仓位。 + """ + betas = self.load_positions(self.kwargs["files_position"]) + # 输出列表先复制一份基准仓位,作为对照组保留。 + pos_list = list(betas) + for beta in betas: + for sig in list(self.kwargs["candidate_signals"]): + # 基准仓位 × 候选信号 的二重笛卡尔积,逐个生成变体仓位。 + pos_list.append(self.update_beta_opens(beta, sig)) + return pos_list + + +class CzscExitOptimStrategy(CzscStrategyBase): + """平仓参数优化所使用的兼容策略类。 + + 在基准仓位的基础上,结合用户提供的"候选平仓事件"集合,按"替换"和 + "追加"两种模式各派生出一份新的仓位变体,从而扫描不同平仓规则对收益 + 的影响。 + """ + + @staticmethod + def update_beta_exits(beta: Position, event_dict: dict[str, Any], mode="replace"): + """复制基准仓位并应用一条候选平仓事件,生成一份新的仓位。 + + Args: + beta: 作为模板复制的基准 :class:`Position` 实例。 + event_dict: 候选平仓事件的字典描述,需包含 ``operate`` 字段。 + mode: 应用方式,支持 ``"replace"``(用候选事件替换原有 exits) + 与 ``"append"``(在原有 exits 末尾追加候选事件)。 + + Returns: + 一份新的 :class:`Position` 实例;当候选事件的 operate 与基准 + 仓位的开仓方向不匹配(例如全开多但事件为"平空")时,返回 + ``None`` 表示该变体应被跳过。 + + Raises: + ValueError: 当 ``mode`` 既不是 ``"replace"`` 也不是 ``"append"`` 时抛出。 + """ + # 通过兼容层把候选事件归一化成 Rust 侧期望的稳定结构。 + event = normalize_candidate_event(event_dict) + pos_dict = beta.dump(with_data=False) + # 收集所有开仓事件的方向,用于判断本仓位是"全多"还是"全空"。 + open_ops = [item["operate"] for item in pos_dict["opens"]] + + # 仅当事件方向与开仓方向一致时才有意义;否则跳过该候选。 + if all(op == "开多" for op in open_ops) and event["operate"] != "平多": + return None + if all(op == "开空" for op in open_ops) and event["operate"] != "平空": + return None + if mode not in {"replace", "append"}: + raise ValueError("mode must be replace or append") + + # 用事件内容的 md5 前 8 位作为变体名称后缀,确保唯一可追溯。 + event_hash = hashlib.md5(str(event).encode("utf-8")).hexdigest()[:8].upper() + if mode == "replace": + # "替换"模式:用候选事件完全覆盖原有平仓规则集合。 + pos_dict["exits"] = [event] + pos_dict["name"] = f"{beta.name}#替换{event_hash}" + else: + # "追加"模式:保留原有平仓规则,在末尾叠加候选事件。 + pos_dict["exits"].append(event) + pos_dict["name"] = f"{beta.name}#追加{event_hash}" + return Position.load(pos_dict) + + @property + def positions(self): + """构造所有待回测的仓位列表(基准 + 替换变体 + 追加变体)。 + + Returns: + ``list[Position]``:先包含全部基准仓位;随后对每个基准仓位 + 与每条候选事件,分别尝试 ``append`` 与 ``replace`` 两种应用 + 模式,仅保留方向匹配的有效变体。 + """ + betas = self.load_positions(self.kwargs["files_position"]) + # 候选事件统一归一化,确保 mode 判断和后续派生时数据结构稳定。 + events = normalize_candidate_events(self.kwargs["candidate_events"]) + pos_list = list(betas) + for beta in betas: + for event in events: + # 同时尝试追加模式与替换模式,最大化扫描覆盖。 + append_pos = self.update_beta_exits(beta, event, mode="append") + replace_pos = self.update_beta_exits(beta, event, mode="replace") + if append_pos is not None: + pos_list.append(append_pos) + if replace_pos is not None: + pos_list.append(replace_pos) + return pos_list + + +class OpensOptimize: + """开仓参数批量优化的 Python 外观类。 + + 与历史版本的 ``czsc.traders.optimize.OpensOptimize`` 在调用方式上完全 + 一致;内部不再实现 Python 端的优化主循环,而是把任务配置、K 线数据、 + 持仓快照等输入物化到磁盘后委托给 Rust 端的批量优化引擎执行。 + + Attributes: + version: 当前 OpensOptimize 实现的版本标识符。 + read_bars: 用户传入的 K 线读取函数。 + kwargs: 用户提供的全部配置项的浅拷贝。 + symbols: 排序后的标的代码列表。 + files_position: 待优化的基准持仓配置文件列表。 + task_name: 任务名称,用于结果目录命名。 + candidate_signals: 排序后的候选开仓信号列表。 + signals_module_name: 信号函数所在的 Python 模块名。 + base_freq: 基础 K 线周期,未提供时通过策略类自动推导。 + results_root: 结果输出根目录。 + task_hash: 由候选信号 + 标的列表生成的 8 位 MD5 任务哈希。 + results_path: 当前任务的结果输出目录(含哈希后缀)。 + poss_path: 当前任务的持仓快照子目录路径。 + message: ``execute`` 完成后填充,记录 Rust 端返回的信息。 + """ + + def __init__(self, read_bars: Callable, **kwargs): + """保存配置并预计算任务哈希、输出目录等元信息。 + + Args: + read_bars: 用户提供的 K 线读取函数,签名详见 :func:`_read_bars`。 + **kwargs: 任务配置;至少需要包含 ``symbols``、``files_position``、 + ``candidate_signals``、``results_path`` 等键;可选项包括 + ``task_name``、``signals_module_name``、``base_freq``、 + ``bar_sdt``、``bar_edt``、``market``、``bg_max_count`` 等。 + + Notes: + 未显式提供 ``base_freq`` 时,会临时构造一个 :class:`CzscOpenOptimStrategy` + 实例从其 ``base_freq`` 属性中推导。 + """ + self.version = "OpensOptimizeV230924" + self.read_bars = read_bars + # 浅拷贝一份用户配置,避免外部 dict 后续被修改影响内部状态。 + self.kwargs = dict(kwargs) + self.symbols = sorted(kwargs["symbols"]) + self.files_position = [str(x) for x in kwargs["files_position"]] + self.task_name = kwargs.get("task_name", "入场优化") + self.candidate_signals = sorted(list(kwargs["candidate_signals"])) + self.signals_module_name = kwargs.get("signals_module_name", "czsc.signals") + # base_freq 优先取用户显式配置;否则借助策略类自动推导, + # 保证后续读取 K 线和写入 Rust 配置时频率信息一致。 + self.base_freq = kwargs.get("base_freq") or CzscOpenOptimStrategy( + symbol="symbol", + files_position=self.files_position, + candidate_signals=self.candidate_signals, + signals_module_name=self.signals_module_name, + ).base_freq + self.results_root = Path(kwargs["results_path"]) + # 用候选信号集合 + 标的列表的字符串拼接做 MD5,截前 8 位作任务哈希; + # 相同输入会得到相同的输出目录,便于结果复用与覆盖。 + self.task_hash = md5_upper8( + f"{py_repr_list_str(self.candidate_signals)}_{py_repr_list_str(sorted(self.symbols))}" + ) + self.results_path = str(self.results_root / f"{self.task_name}_{self.task_hash}") + self.poss_path = str(Path(self.results_path) / "poss") + + def execute(self, n_jobs=1): + """物化输入数据并触发 Rust 端的开仓批量优化任务。 + + Args: + n_jobs: 传给 Rust 引擎的并发线程数;默认为 1(顺序执行)。 + + Returns: + Rust 引擎返回的结果对象;同时该对象的 ``message`` 字段会被 + 缓存到 ``self.message`` 上,便于调用方事后查阅。 + """ + # 把所有标的的 K 线写入 parquet,作为 Rust 引擎读取的原料。 + bars_dir = self._materialize_bars_dir() + # 把基准持仓 JSON 转换为 Rust 期望的 runtime 格式后落盘。 + files_position = self._materialize_position_files() + # 组装传给 Rust 的优化任务配置;optim_type 标记为开仓优化。 + cfg = { + "optim_type": "open", + "task_name": self.task_name, + "base_freq": self.base_freq, + "symbols": self.symbols, + "files_position": files_position, + "candidate_signals": self.candidate_signals, + "market": self.kwargs.get("market", "默认"), + "bg_max_count": self.kwargs.get("bg_max_count", 5000), + } + if self.kwargs.get("sdt"): + # 仅当用户显式指定 sdt 时才下发,避免覆盖 Rust 端的默认值。 + cfg["sdt"] = self.kwargs["sdt"] + result = run_optimize_batch(bars_dir, cfg, self.results_root, n_threads=n_jobs) + # 暴露 Rust 端返回的执行信息,方便上层日志记录或失败诊断。 + self.message = result.message + return result + + def _materialize_bars_dir(self): + """将所有标的的 K 线序列写入 parquet 文件并返回所在目录。 + + Returns: + 包含 ``.parquet`` 文件的目录路径,供 Rust 引擎扫描读取。 + """ + bars_dir = Path(self.results_path) / "bars" + bars_dir.mkdir(parents=True, exist_ok=True) + # 默认时间范围与历史版本保持一致;用户可通过 kwargs 自定义覆盖。 + bar_sdt = self.kwargs.get("bar_sdt", "20150101") + bar_edt = self.kwargs.get("bar_edt", "20220101") + for symbol in self.symbols: + bars = _read_bars(self.read_bars, symbol, self.base_freq, bar_sdt, bar_edt) + # parquet 不写 index,Rust 端按列名读取,避免歧义。 + bars_to_dataframe(bars, symbol=symbol).to_parquet( + bars_dir / f"{symbol}.parquet", index=False + ) + return bars_dir + + def _materialize_position_files(self): + """把每份基准仓位 JSON 转换为 runtime 格式后写入磁盘。 + + Returns: + ``list[str]``:转换后落盘的仓位 JSON 文件绝对路径列表。 + """ + out_dir = Path(self.results_path) / "positions_input" + out_dir.mkdir(parents=True, exist_ok=True) + files = [] + for file in self.files_position: + payload = json.loads(Path(file).read_text(encoding="utf-8")) + # 通过兼容层把"持久化 dump"格式转成 Rust 引擎期望的"运行时"结构。 + runtime = position_dump_to_runtime(payload) + # 移除回测/校验类字段,避免 Rust 端误用历史结果污染本次优化。 + runtime.pop("md5", None) + runtime.pop("pairs", None) + runtime.pop("holds", None) + # symbol 字段在批量优化场景被忽略,但 Rust 端要求必填,给个占位。 + runtime.setdefault("symbol", "symbol") + out_path = out_dir / Path(file).name + # 使用 ensure_ascii=False 保留中文字段,便于人工查看。 + out_path.write_text(json.dumps(runtime, ensure_ascii=False), encoding="utf-8") + files.append(str(out_path)) + return files + + +class ExitsOptimize: + """平仓参数批量优化的 Python 外观类。 + + 设计与 :class:`OpensOptimize` 对称:候选事件先在 Python 侧通过兼容层 + 归一化,再连同 K 线、基准持仓等一起交给 Rust 端的批量优化引擎执行。 + + Attributes: + version: 当前 ExitsOptimize 实现的版本标识符。 + read_bars: 用户传入的 K 线读取函数。 + kwargs: 用户提供的全部配置项的浅拷贝。 + symbols: 标的代码列表(保留原始顺序)。 + files_position: 待优化的基准持仓配置文件列表。 + task_name: 任务名称,用于结果目录命名。 + candidate_events: 归一化后的候选平仓事件列表。 + signals_module_name: 信号函数所在的 Python 模块名。 + base_freq: 基础 K 线周期,未提供时通过策略类自动推导。 + results_root: 结果输出根目录。 + task_hash: 由候选事件 + 标的列表生成的 8 位 MD5 任务哈希。 + results_path: 当前任务的结果输出目录(含哈希后缀)。 + poss_path: 当前任务的持仓快照子目录路径。 + message: ``execute`` 完成后填充,记录 Rust 端返回的信息。 + """ + + def __init__(self, read_bars: Callable, **kwargs): + """保存配置并预计算任务哈希、输出目录等元信息。 + + Args: + read_bars: 用户提供的 K 线读取函数,签名详见 :func:`_read_bars`。 + **kwargs: 任务配置;至少需要包含 ``symbols``、``files_position``、 + ``candidate_events``、``results_path`` 等键;可选项与 + :class:`OpensOptimize` 类似。 + + Notes: + 未显式提供 ``base_freq`` 时,会临时构造一个 :class:`CzscExitOptimStrategy` + 实例从其 ``base_freq`` 属性中推导。 + """ + self.version = "ExitsOptimizeV230924" + self.read_bars = read_bars + self.kwargs = dict(kwargs) + self.symbols = list(kwargs["symbols"]) + self.files_position = [str(x) for x in kwargs["files_position"]] + self.task_name = kwargs.get("task_name", "出场优化") + # 候选事件首先在 Python 侧统一归一化,下游 Rust 端拿到的结构稳定。 + self.candidate_events = normalize_candidate_events(kwargs["candidate_events"]) + self.signals_module_name = kwargs.get("signals_module_name", "czsc.signals") + # 与 OpensOptimize 对称:未显式指定 base_freq 时通过策略类反推。 + self.base_freq = kwargs.get("base_freq") or CzscExitOptimStrategy( + symbol="symbol", + files_position=self.files_position, + candidate_events=self.candidate_events, + signals_module_name=self.signals_module_name, + ).base_freq + self.results_root = Path(kwargs["results_path"]) + # 候选事件用 JSON 化字符串再做 MD5,避免 dict 不同顺序导致哈希漂移。 + self.task_hash = md5_upper8( + f"{py_repr_json(self.candidate_events)}_{py_repr_list_str(self.symbols)}" + ) + self.results_path = str(self.results_root / f"{self.task_name}_{self.task_hash}") + self.poss_path = str(Path(self.results_path) / "poss") + + def execute(self, n_jobs=1): + """物化输入数据并触发 Rust 端的平仓批量优化任务。 + + Args: + n_jobs: 传给 Rust 引擎的并发线程数;默认为 1(顺序执行)。 + + Returns: + Rust 引擎返回的结果对象;执行信息同时缓存到 ``self.message``。 + """ + bars_dir = self._materialize_bars_dir() + files_position = self._materialize_position_files() + # optim_type 标记为平仓优化;其余字段与 OpensOptimize.execute 类似, + # 但传入的是 candidate_events 而非 candidate_signals。 + cfg = { + "optim_type": "exit", + "task_name": self.task_name, + "base_freq": self.base_freq, + "symbols": self.symbols, + "files_position": files_position, + "candidate_events": self.candidate_events, + "market": self.kwargs.get("market", "默认"), + "bg_max_count": self.kwargs.get("bg_max_count", 5000), + } + if self.kwargs.get("sdt"): + cfg["sdt"] = self.kwargs["sdt"] + result = run_optimize_batch(bars_dir, cfg, self.results_root, n_threads=n_jobs) + self.message = result.message + return result + + def _materialize_bars_dir(self): + """将所有标的的 K 线序列写入 parquet 文件并返回所在目录。 + + Returns: + 包含 ``.parquet`` 文件的目录路径,供 Rust 引擎扫描读取。 + """ + bars_dir = Path(self.results_path) / "bars" + bars_dir.mkdir(parents=True, exist_ok=True) + bar_sdt = self.kwargs.get("bar_sdt", "20150101") + bar_edt = self.kwargs.get("bar_edt", "20220101") + for symbol in self.symbols: + bars = _read_bars(self.read_bars, symbol, self.base_freq, bar_sdt, bar_edt) + bars_to_dataframe(bars, symbol=symbol).to_parquet( + bars_dir / f"{symbol}.parquet", index=False + ) + return bars_dir + + def _materialize_position_files(self): + """把每份基准仓位 JSON 转换为 runtime 格式后写入磁盘。 + + Returns: + ``list[str]``:转换后落盘的仓位 JSON 文件绝对路径列表。 + """ + out_dir = Path(self.results_path) / "positions_input" + out_dir.mkdir(parents=True, exist_ok=True) + files = [] + for file in self.files_position: + payload = json.loads(Path(file).read_text(encoding="utf-8")) + # 转换成 Rust 端期望的 runtime 结构,并清理掉历史回测产物字段。 + runtime = position_dump_to_runtime(payload) + runtime.pop("md5", None) + runtime.pop("pairs", None) + runtime.pop("holds", None) + runtime.setdefault("symbol", "symbol") + out_path = out_dir / Path(file).name + out_path.write_text(json.dumps(runtime, ensure_ascii=False), encoding="utf-8") + files.append(str(out_path)) + return files diff --git a/czsc/traders/sig_parse.py b/czsc/traders/sig_parse.py index c26b947d0..819ffb369 100644 --- a/czsc/traders/sig_parse.py +++ b/czsc/traders/sig_parse.py @@ -1,5 +1,20 @@ -""" -Signal parsing helpers adapted to the currently available rs_czsc surface. +"""czsc.traders.sig_parse —— 信号字符串解析与配置反解工具模块。 + +本模块提供一组面向信号字符串的解析、反向构造与频率提取工具,用于在 +"信号字符串 ↔ 信号配置(dict)"两种表达之间进行往返转换。这些能力主要 +被策略加载器、信号注册表与信号编辑器等上层组件使用,以便用户既能用紧凑 +的字符串描述信号,也能在程序内部以结构化字典进行编辑、序列化与回放。 + +实现层面的关键点: + +* 由于 Rust 端的 ``derive_signals_config`` / ``derive_signals_freqs`` / + ``list_all_signals`` 在当前阶段尚未完全迁移到 ``czsc._native`` 命名空间, + 本模块通过 :func:`_lazy_rs_czsc` 在调用点惰性导入 ``rs_czsc``,从而保证 + 即便相关函数缺失,模块本身仍可被正常 import;只有真正调用到对应函数时 + 才会抛出明确的错误。 +* :class:`SignalsParser` 在初始化时尽力调用 ``list_all_signals`` 拉取全量 + 信号模板,缺失时退化为空注册表;上层调用方可以在没有任何注册信息时 + 仍然顺利构造解析器,只是部分功能会返回空结果。 """ from __future__ import annotations @@ -10,39 +25,150 @@ from loguru import logger from parse import parse -from rs_czsc import derive_signals_config, derive_signals_freqs, list_all_signals + +def _lazy_rs_czsc(): + """惰性导入 ``rs_czsc`` 中尚未迁移到 czsc._native 的若干函数。 + + 由于这些函数当前阶段尚未在 czsc._native 中暴露,模块加载时直接 import + 会带来强依赖;本函数把 import 推迟到调用点,让模块自身可以在 rs_czsc + 缺失或部分不可用时仍可被正常导入。 + + Returns: + 三元组 ``(derive_signals_config, derive_signals_freqs, list_all_signals)``, + 分别对应 rs_czsc 中三个尚未迁移的函数。 + + Raises: + NotImplementedError: 当 rs_czsc 不可用或缺失上述函数时抛出, + 提示调用方重新安装 rs_czsc 以恢复回退能力。 + """ + try: + from rs_czsc import ( # type: ignore[import-not-found] + derive_signals_config as _dsc, + derive_signals_freqs as _dsf, + list_all_signals as _las, + ) + except ImportError as exc: # pragma: no cover + # 故意把 ImportError 转成 NotImplementedError,让"功能未提供"的语义 + # 比"缺包"更明显,并附上修复建议。 + raise NotImplementedError( + "rs_czsc.{derive_signals_config, derive_signals_freqs, " + "list_all_signals} have not been migrated to czsc._native yet " + "(see MIGRATION_NOTES §2.8). Re-install rs_czsc to fall back." + ) from exc + return _dsc, _dsf, _las + + +def derive_signals_config(*args, **kwargs): + """``rs_czsc.derive_signals_config`` 的惰性转发包装。 + + 所有参数会原样转发给底层实现;当底层不可用时抛出 + :class:`NotImplementedError`,错误信息见 :func:`_lazy_rs_czsc`。 + """ + return _lazy_rs_czsc()[0](*args, **kwargs) + + +def derive_signals_freqs(*args, **kwargs): + """``rs_czsc.derive_signals_freqs`` 的惰性转发包装。 + + 所有参数会原样转发给底层实现;当底层不可用时抛出 + :class:`NotImplementedError`,错误信息见 :func:`_lazy_rs_czsc`。 + """ + return _lazy_rs_czsc()[1](*args, **kwargs) + + +def list_all_signals(*args, **kwargs): + """``rs_czsc.list_all_signals`` 的惰性转发包装。 + + 所有参数会原样转发给底层实现;当底层不可用时抛出 + :class:`NotImplementedError`,错误信息见 :func:`_lazy_rs_czsc`。 + """ + return _lazy_rs_czsc()[2](*args, **kwargs) def _normalize_template(template: str) -> str: - """Normalize signal templates so line breaks do not break parsing.""" + """规范化信号模板字符串,避免因换行/多余空白而无法解析。 + + 主要处理两类问题: + 1. 模板被人为换行后产生多余空白,导致 ``parse`` 无法匹配。 + 2. ``_`` 两侧若残留空格,会让占位符与实际信号串对不齐。 + + Args: + template: 原始模板字符串;允许为 ``None`` 或空串。 + + Returns: + 统一为单行、``_`` 两侧无空格的紧凑模板字符串。 + """ + # 把任意连续空白(含换行/制表符)压成单个空格,并去除首尾空白。 text = re.sub(r"\s+", " ", template or "").strip() + # 把 "_ x" 或 "x _" 这样的残留空格修正回紧凑形式。 text = text.replace(" _", "_").replace("_ ", "_") return text def _prefix_name(name: str, signals_module: str) -> str: + """为信号函数名补全模块前缀。 + + 若名称已经包含 ``"."``(即调用方已写明完整模块路径),原样返回; + 否则拼接 ``signals_module`` 作为前缀。 + + Args: + name: 信号函数名或完整路径。 + signals_module: 默认的信号函数所在模块名。 + + Returns: + 含模块前缀的信号函数完整路径字符串。 + """ return name if "." in str(name) else f"{signals_module}.{name}" def _extract_signal_key(signal: Any) -> str: + """从字符串或 Signal 对象中提取信号 key(前 3 段)。 + + czsc 信号串遵循 ``"freq_k1_k2_v1_v2_v3_v4"`` 的 7 段约定,前 3 段 + 为 key,后 4 段为 value。本函数兼容字符串与具备 ``key`` 属性的对象。 + + Args: + signal: 信号字符串或 ``Signal`` 实例。 + + Returns: + 信号 key 字符串;当输入无法解析时返回空串。 + """ if isinstance(signal, str): parts = signal.split("_") + # 需要至少 3 段才能拼出有意义的 key;否则视为无效输入。 return "_".join(parts[:3]) if len(parts) >= 3 else "" return str(getattr(signal, "key", "") or "") class SignalsParser: - """Parse signal strings into the flattened config structure used in czsc.""" + """把扁平的信号字符串解析回 czsc 标准信号配置结构的解析器。 + + 解析器在初始化时会尝试通过 ``list_all_signals`` 拉取全量信号定义, + 建立"函数名 → 模板"的注册表;后续调用 ``parse_params`` / + ``get_function_name`` / ``config_to_keys`` / ``parse`` 均依赖该注册表。 + + 当 rs_czsc 不可用导致 ``list_all_signals`` 失败时,注册表会被置空, + 此时各方法会按"找不到匹配模板"的语义返回空值,而不会抛出异常。 + """ def __init__(self, signals_module: str = "czsc.signals"): + """构建解析器并预加载信号模板注册表。 + + Args: + signals_module: 默认的信号函数所在模块名,会在补全函数名时 + 作为前缀使用。 + """ self.signals_module = signals_module + # 三张本地注册表:分别保存模板字符串、k3 段(信号子类标识)以及 + # "k3 → 函数名列表"的反向索引,供后续解析与匹配使用。 sig_pats_map: dict[str, str] = {} sig_k3_map: dict[str, str] = {} signal_defs = [] if list_all_signals is not None: try: + # 尽力拉取全量信号定义;失败时降级为空注册表,不影响模块导入。 signal_defs = list_all_signals(include_kline=True, include_trader=True) except Exception as exc: logger.warning(f"list_all_signals unavailable, using empty parser registry: {exc}") @@ -51,23 +177,36 @@ def __init__(self, signals_module: str = "czsc.signals"): name = str(item.get("name", "")).strip() template = _normalize_template(str(item.get("param_template", "")).strip()) if not name or not template: + # 信息不完整的条目无法用于解析,直接跳过。 continue sig_pats_map[name] = template parts = template.split("_") if len(parts) >= 3: + # k3 段是 czsc 信号体系中常用的"子分类"标识,用于反向查函数。 sig_k3_map[name] = parts[2].strip() self.sig_pats_map = sig_pats_map self.sig_k3_map = sig_k3_map + # 兼容旧版 API:保留 sig_name_map 字段(k3 → [函数名])。 self.sig_name_map = {k: [v] for k, v in sig_k3_map.items()} def parse_params(self, name: str, signal: str | Any): - """Parse a signal string into its function parameters.""" + """根据指定信号函数模板,把信号 key 反解为函数参数字典。 + + Args: + name: 信号函数名或完整路径;只取末段进行注册表查找。 + signal: 信号字符串或具备 ``key`` 属性的 Signal 对象。 + + Returns: + 包含函数调用所需参数的字典(含补全的 ``name`` 字段);当输入 + 非法、模板缺失或解析失败时返回 ``None``。 + """ key = _extract_signal_key(signal) if not key: return None + # 取模块名末段做注册表查找,兼容传入完整路径的情况。 short_name = str(name).split(".")[-1] pats = self.sig_pats_map.get(short_name) if not pats: @@ -80,6 +219,8 @@ def parse_params(self, name: str, signal: str | Any): params = dict(res.named) if "di" in params: + # di(distance index)按约定恒为整数,反解出来时仍是字符串, + # 在此显式转回 int,避免下游再次手动转换。 params["di"] = int(params["di"]) params["name"] = _prefix_name(short_name, self.signals_module) return params @@ -88,7 +229,21 @@ def parse_params(self, name: str, signal: str | Any): return None def get_function_name(self, signal: str): - """Guess the signal function name from a signal string.""" + """根据信号字符串猜测对应的信号函数名。 + + 匹配优先级如下: + + 1. 在本地注册表中按 k3 段精确匹配,若仅命中 1 个候选则返回; + 2. 命中多个候选时记录错误日志并返回 ``None``,避免歧义; + 3. 本地匹配失败时回退到 ``derive_signals_config`` 让 Rust 端做 + 权威解析,并取其首条结果的函数名末段返回。 + + Args: + signal: 待识别的信号字符串。 + + Returns: + 匹配到的函数名字符串;无法唯一确定时返回 ``None``。 + """ key = _extract_signal_key(signal) if not key: return None @@ -98,39 +253,66 @@ def get_function_name(self, signal: str): return None k3 = parts[2].strip() + # 在本地注册表中按 k3 段做反向索引匹配。 matches = [name for name, tk3 in self.sig_k3_map.items() if tk3 == k3] if len(matches) == 1: return matches[0] if len(matches) > 1: + # 多函数共用相同 k3 时无法唯一定位,记录日志后放弃。 logger.error(f"signal {signal} matched multiple functions: {matches}") return None if derive_signals_config is not None: try: + # 本地匹配失败时,回退到 Rust 端的权威解析作为兜底。 conf = derive_signals_config([signal]) if conf: return str(conf[0]["name"]).split(".")[-1] except Exception: + # 兜底失败保持静默(log 已在更底层打印),避免噪声。 pass return None def config_to_keys(self, config: list[dict]): - """Convert signal configs back to signal keys when templates are available.""" + """利用模板把若干信号配置字典反向格式化为信号 key 字符串。 + + Args: + config: 信号配置字典列表;每项至少应包含 ``name`` 键,并提供 + 模板占位符所需的全部参数键值。 + + Returns: + 成功格式化的信号 key 字符串列表;模板缺失或格式化失败的项 + 会被静默跳过,不会中断整体流程。 + """ keys = [] for conf in config: name = str(conf.get("name", "")).split(".")[-1] pats = self.sig_pats_map.get(name) if not pats: + # 找不到模板的配置项无法重建 key,直接跳过。 continue try: + # 利用 str.format 将参数填回模板占位符,得到完整 key。 keys.append(pats.format(**conf)) except Exception: continue return keys def parse(self, signal_seq: list[str]): - """Parse a signal sequence into flattened configs.""" + """把一组信号字符串解析成扁平化的信号配置列表。 + + 实现上首先把信号序列交给 ``derive_signals_config``(Rust 端) + 做权威解析,再在 Python 侧补齐模块前缀、展开 ``params`` 字段并 + 去重,得到的结构可以直接用作 czsc 信号配置项。 + + Args: + signal_seq: 信号字符串序列。 + + Returns: + 扁平化后的信号配置字典列表;输入为空或底层解析不可用时返回 + 空列表。 + """ if not signal_seq or derive_signals_config is None: return [] @@ -143,6 +325,7 @@ def parse(self, signal_seq: list[str]): out: list[dict[str, Any]] = [] for row in conf: raw = dict(row) + # 名称在 Python 侧补齐模块前缀,得到可直接 import 的全限定名。 item: dict[str, Any] = { "name": _prefix_name(str(raw.get("name", "")), self.signals_module), } @@ -150,29 +333,54 @@ def parse(self, signal_seq: list[str]): item["freq"] = raw.get("freq") params = raw.get("params") or {} if isinstance(params, dict): + # params 字段在配置层不嵌套,直接平铺到顶层方便后续使用。 item.update(params) for key, value in raw.items(): + # 其余字段(如版本号、扩展元数据等)原样透传。 if key not in {"name", "freq", "params"}: item[key] = value if item not in out: + # 最终结果按整条字典做去重,避免重复配置。 out.append(item) return out def get_signals_config(signals_seq: list[str], signals_module: str = "czsc.signals") -> list[dict]: - """Get signal configs from a signal sequence.""" + """把信号字符串序列解析为扁平化的信号配置列表的快捷函数。 + + Args: + signals_seq: 信号字符串序列。 + signals_module: 默认的信号函数所在模块名。 + + Returns: + 扁平化后的信号配置字典列表;语义与 :meth:`SignalsParser.parse` 一致。 + """ return SignalsParser(signals_module=signals_module).parse(signals_seq) def get_signals_freqs(signals_seq: list) -> list[str]: - """Get unique frequencies referenced by a signal sequence. + """从信号序列中提取所有涉及到的 K 线周期。 + + 输入既可以是已经解析好的信号配置字典列表,也可以是原始的信号字符串 + 列表;前者直接交给 ``derive_signals_freqs``,后者会先经过 + ``derive_signals_config`` 解析后再提取频率。 + + Args: + signals_seq: 信号字符串列表或信号配置字典列表。 + + Returns: + 去重的 K 线周期字符串列表;输入为空时返回空列表。 - 依赖 rs_czsc 进行语义解析,rs-czsc 为必选依赖。 + Notes: + 本函数依赖 ``rs_czsc`` 进行语义解析,``rs-czsc`` 为必选依赖; + 缺失时会通过 :func:`_lazy_rs_czsc` 抛出 :class:`NotImplementedError`。 """ if not signals_seq: return [] if isinstance(signals_seq[0], dict): + # 输入已是字典形式(信号配置),直接交给 Rust 端提取频率。 return list(derive_signals_freqs(signals_seq)) + # 输入为字符串信号时,先解析为配置字典再提取频率。 conf = derive_signals_config(signals_seq) return list(derive_signals_freqs(conf)) diff --git a/czsc/utils/__init__.py b/czsc/utils/__init__.py index 4daa2c363..fcb63ff2c 100644 --- a/czsc/utils/__init__.py +++ b/czsc/utils/__init__.py @@ -1,23 +1,48 @@ +""" +czsc.utils 工具子包 + +本子包汇总了 CZSC 项目内的通用工具,包括:分析(analysis)、加解密(crypto)、 +数据访问与缓存(data)、IO 助手(io)、技术指标(ta)、绘图(plotting)等。 + +为了兼顾"导入即可用"和"按需懒加载"两种诉求,本模块采用如下策略: + +1. 顶部直接导入轻量级子模块及其常用函数; +2. ``plotting`` 子包以及绘图相关函数采用 :func:`__getattr__` 实现的懒加载, + 在首次访问时才真正导入,避免启动时拖慢 ``import czsc``; +3. ``logger`` 也通过懒加载从 ``loguru`` 暴露,便于其他模块统一使用同一实例。 + +同时本模块还提供一组小工具函数:``x_round``、``import_by_name``、``freqs_sorted``、 +``create_grid_params``、``mac_address``、``to_arrow``、``timeout_decorator`` 等。 +""" + import functools import os import threading import pandas as pd -# 导入轻量级子模块 +# --------------------------------------------------------------------------- +# 轻量级子模块的直接导入 +# --------------------------------------------------------------------------- +# 这些子模块加载成本可控,且在 czsc 工程中被高频复用,因此直接 import from . import analysis, crypto, data, io, ta + +# 从 analysis 子模块再次导出常用的统计 / 绩效函数,方便在 czsc.utils 顶层直接使用 from .analysis import ( cross_sectional_ic, daily_performance, holds_performance, nmi_matrix, - overlap, psi, rolling_daily_performance, single_linear, top_drawdowns, ) + +# 加解密相关:Fernet 密钥生成与对称加解密 from .crypto import fernet_decrypt, fernet_encrypt, generate_fernet_key + +# 数据访问、磁盘缓存以及统一数据客户端 from .data import ( DataClient, DiskCache, @@ -30,14 +55,20 @@ home_path, set_url_token, ) + +# 指数成分相关工具 from .index_composition import index_composition + +# JSON / dill 等通用 IO 工具 from .io import dill_dump, dill_load, read_json, save_json + +# 阿里云 OSS 客户端封装 from .oss import AliyunOSS -# Delayed import to avoid circular dependency - import these from czsc.utils.sig directly -# from .sig import check_gap_info, is_bis_down, is_bis_up, get_sub_elements, is_symmetry_zs -# from .sig import same_dir_counts, fast_slow_cross, count_last_same, create_single_signal -from .feature_utils import feature_returns, feature_sectional_corr, is_event_feature +# 注意:``sig`` 模块依赖 ``czsc`` 顶层包,存在循环导入风险,故此处不预先 re-export, +# 调用方需要按需通过 ``from czsc.utils.sig import ...`` 方式直接引用。 + +# 交易/重采样相关工具 from .trade import resample_to_daily, risk_free_returns, update_bbars, update_nxb, update_tbars __all__ = [ @@ -52,7 +83,6 @@ "daily_performance", "holds_performance", "nmi_matrix", - "overlap", "psi", "rolling_daily_performance", "single_linear", @@ -81,10 +111,6 @@ "save_json", # oss "AliyunOSS", - # feature_utils - "feature_returns", - "feature_sectional_corr", - "is_event_feature", # trade "resample_to_daily", "risk_free_returns", @@ -104,11 +130,8 @@ "timeout_decorator", "sorted_freqs", # 延迟加载模块 - "echarts_plot", "plotting", # 延迟加载属性 - "kline_pro", - "trading_view_kline", "KlineChart", "plot_czsc_chart", "plot_cumulative_returns", @@ -126,6 +149,7 @@ "logger", ] +# 标准 K 线周期排序顺序,用于 ``freqs_sorted`` 排序 sorted_freqs = [ "Tick", "1分钟", @@ -152,9 +176,12 @@ def x_round(x: float | int, digit: int = 4) -> float | int: """用去尾法截断小数 - :param x: 数字 - :param digit: 保留小数位数 - :return: + 与 :func:`round` 不同,该函数采用去尾(向零)方式截断小数位,避免四舍五入 + 带来的微小偏差。 + + :param x: 数字(int 或 float);如为 int 则原样返回 + :param digit: 保留的小数位数 + :return: 截断后的数字;异常时返回原值并打印日志 """ if isinstance(x, int): return x @@ -163,20 +190,26 @@ def x_round(x: float | int, digit: int = 4) -> float | int: digit_ = pow(10, digit) x = int(x * digit_) / digit_ except Exception: + # 浮点转换失败时打印诊断信息,但不抛异常以保护调用链 print(f"x_round error: x = {x}") return x def get_py_namespace(file_py: str, keys: list = None) -> dict: - """获取 python 脚本文件中的 namespace + """获取 Python 脚本文件中的 namespace - :param file_py: python 脚本文件名 - :param keys: 指定需要的对象名称 - :return: namespace + 通过 ``compile`` + ``exec`` 在内存中执行脚本,并把执行后的全局命名空间返回。 + 出于安全考虑,只允许加载 ``czsc/strategies`` 与 ``czsc/signals`` 目录下的脚本。 + + :param file_py: str,Python 脚本文件名(绝对或相对路径均可) + :param keys: list,指定需要的对象名称;若提供则仅返回这些键 + :return: dict,namespace + :raises ValueError: 当文件路径不在白名单目录内时 """ if keys is None: keys = [] file_py = os.path.abspath(file_py) + # 安全白名单:只允许执行 czsc 内部的策略 / 信号脚本 allowed_prefixes = [os.path.abspath("czsc/strategies"), os.path.abspath("czsc/signals")] if not any(file_py.startswith(p) for p in allowed_prefixes): raise ValueError(f"文件路径 {file_py} 不在白名单目录内") @@ -190,11 +223,14 @@ def get_py_namespace(file_py: str, keys: list = None) -> dict: def code_namespace(code: str, keys: list = None) -> dict: - """获取 python 代码中的 namespace + """获取 Python 代码字符串中的 namespace + + 与 :func:`get_py_namespace` 类似,但接受的是源代码字符串而非文件路径, + 且不做白名单校验,调用方需自行确保来源安全。 - :param code: python 代码 - :param keys: 指定需要的对象名称 - :return: namespace + :param code: str,Python 源代码 + :param keys: list,指定需要的对象名称;若提供则仅返回这些键 + :return: dict,namespace """ if keys is None: keys = [] @@ -206,36 +242,33 @@ def code_namespace(code: str, keys: list = None) -> dict: def import_by_name(name): - """通过字符串导入模块、类、函数 + """通过字符串导入模块、类或函数 函数执行逻辑: - 1. 检查 name 中是否包含点号('.')。如果没有,则直接使用内置的 import 函数来导入整个模块,并返回该模块对象。 - 2. 如果 name 包含点号,先处理一个相对路径。将 name 拆分为两部分:module_name 和 function_name。 - 使用 Python 内置的 rsplit 方法从右边开始分割,只取一次,这样可以确保我们将最后的一个点号前的部分作为 module_name,点号后面的部分作为 function_name。 - 3. 使用import函数导入指定的 module_name。 - 这里传入三个参数:globals() 和 locals() 分别代表当前全局和局部命名空间; - [function_name] 是一个列表,用于指定要导入的子模块或属性名。 - 这样做是为了避免一次性导入整个模块的所有内容,提高效率。 - 4. 使用 vars 函数获取模块的字典表示形式(即模块内所有的变量和函数),取出 function_name 对应的值,然后返回这个值。 - - :param name: 模块名,如:'czsc.objects.Factor' - :return: 模块对象 + 1. 检查 ``name`` 是否包含点号(``.``)。如果没有,则直接 ``__import__`` 整个模块; + 2. 否则用 ``rsplit('.', 1)`` 拆为 ``module_name`` 与 ``function_name``; + 3. 调用 ``__import__`` 时通过 fromlist 参数 ``[function_name]`` 显式声明, + 这样可以避免一次性导入整棵模块树,提高加载效率; + 4. 通过 ``vars(module)`` 取出指定属性。 + + :param name: str,模块名或模块.属性名,例如 ``'czsc.objects.Factor'`` + :return: 模块对象 / 类 / 函数 """ if "." not in name: return __import__(name) - # 从右边开始分割,分割成模块名和函数名 + # 从右侧分割一次,保证最后一个点之前的部分整体作为模块路径 module_name, function_name = name.rsplit(".", 1) module = __import__(module_name, globals(), locals(), [function_name]) return vars(module)[function_name] def freqs_sorted(freqs): - """K线周期列表排序并去重,第一个元素是基础周期 + """K 线周期列表排序并去重,第一个元素是基础周期 - :param freqs: K线周期列表 - :return: K线周期排序列表 + :param freqs: K 线周期列表(如 ``['日线', '5分钟', '30分钟']``) + :return: list,按 :data:`sorted_freqs` 顺序排序并去重后的结果 """ _freqs_new = [x for x in sorted_freqs if x in freqs] return _freqs_new @@ -244,33 +277,35 @@ def freqs_sorted(freqs): def create_grid_params(prefix: str = "", multiply=3, **kwargs) -> dict: """创建 grid search 参数组合 - :param prefix: 参数组前缀 - :param multiply: 参数组合的位数,如果为 0,则使用 # 分隔参数 - :param kwargs: 任意参数的候选序列,参数值推荐使用 iterable - :return: 参数组合字典 - - examples - ============ - >>>x = create_grid_params("test", x=(1, 2), y=('a', 'b'), detail=True) - >>>print(x) - Out[0]: - {'test_x=1_y=a': {'x': 1, 'y': 'a'}, - 'test_x=1_y=b': {'x': 1, 'y': 'b'}, - 'test_x=2_y=a': {'x': 2, 'y': 'a'}, - 'test_x=2_y=b': {'x': 2, 'y': 'b'}} - - # 单个参数传入单个值也是可以的,但类型必须是 int, float, str 中的任一 - >>>x = create_grid_params("test", x=2, y=('a', 'b'), detail=False) - >>>print(x) - Out[1]: - {'test001': {'x': 2, 'y': 'a'}, - 'test002': {'x': 2, 'y': 'b'}} + 基于 ``sklearn.model_selection.ParameterGrid`` 生成参数笛卡尔积,并按用户 + 指定的命名风格输出,便于在批量回测、网格搜索时复用。 + + :param prefix: str,参数组前缀 + :param multiply: int,参数组合编号的位数;为 0 时改用 ``#`` 连接的可读 key + :param kwargs: 任意参数的候选序列;推荐使用 list/tuple,单个值也会被自动包装 + :return: dict,``{key: 参数字典}`` 形式的参数组合 + + 示例: + >>> x = create_grid_params("test", x=(1, 2), y=('a', 'b'), detail=True) + >>> print(x) + Out[0]: + {'test_x=1_y=a': {'x': 1, 'y': 'a'}, + 'test_x=1_y=b': {'x': 1, 'y': 'b'}, + 'test_x=2_y=a': {'x': 2, 'y': 'a'}, + 'test_x=2_y=b': {'x': 2, 'y': 'b'}} + + # 单个参数传入单个值也是可以的,但类型必须是 int, float, str 中的任一 + >>> x = create_grid_params("test", x=2, y=('a', 'b'), detail=False) + >>> print(x) + Out[1]: + {'test001': {'x': 2, 'y': 'a'}, + 'test002': {'x': 2, 'y': 'b'}} """ from sklearn.model_selection import ParameterGrid params_grid = dict(kwargs) for k, v in params_grid.items(): - # 处理非 list 类型数据 + # 标量值自动包装为单元素列表,便于 ParameterGrid 处理 if type(v) in [int, float, str]: v = [v] assert type(v) in [tuple, list], f"输入参数值必须是 list 或 tuple 类型,当前参数 {k} 值:{v}" @@ -278,6 +313,7 @@ def create_grid_params(prefix: str = "", multiply=3, **kwargs) -> dict: params = {} for i, row in enumerate(ParameterGrid(params_grid), 1): + # multiply == 0 时使用可读的 key,否则使用补零后的序号 key = "#".join([f"{k}={v}" for k, v in row.items()]) if multiply == 0 else str(i).zfill(multiply) row["version"] = f"{prefix}{key}" @@ -286,6 +322,11 @@ def create_grid_params(prefix: str = "", multiply=3, **kwargs) -> dict: def print_df_sample(df, n=5): + """以 reST 表格形式打印 DataFrame 的前 n 行,便于在文档中粘贴 + + :param df: pd.DataFrame + :param n: int,打印的行数,默认 5 + """ from tabulate import tabulate print(tabulate(df.head(n).values, headers=df.columns, tablefmt="rst")) @@ -294,12 +335,11 @@ def print_df_sample(df, n=5): def mac_address(): """获取本机 MAC 地址 - MAC地址(英语:Media Access Control Address),直译为媒体访问控制地址,也称为局域网地址(LAN Address), - 以太网地址(Ethernet Address)或物理地址(Physical Address),它是一个用来确认网络设备位置的地址。在OSI模 - 型中,第三层网络层负责IP地址,第二层数据链接层则负责MAC地址。MAC地址用于在网络中唯一标示一个网卡,一台设备若有一 - 或多个网卡,则每个网卡都需要并会有一个唯一的MAC地址。 + MAC 地址(Media Access Control Address),又称为局域网地址(LAN Address)、 + 以太网地址(Ethernet Address)或物理地址(Physical Address),用于唯一标识 + 网络中的网卡。一台设备若有多块网卡,则每块网卡都会拥有各自的 MAC 地址。 - :return: 本机 MAC 地址 + :return: str,本机 MAC 地址,形如 ``"AA-BB-CC-DD-EE-FF"`` """ import uuid @@ -309,7 +349,13 @@ def mac_address(): def to_arrow(df: pd.DataFrame): - """将 pandas.DataFrame 转换为 pyarrow.Table""" + """将 ``pandas.DataFrame`` 转换为 Arrow IPC 字节串 + + 通过 ``pyarrow.ipc.new_file`` 写出,可在不同进程 / 服务之间高效传输 DataFrame。 + + :param df: pd.DataFrame + :return: bytes,Arrow IPC file 格式的二进制数据 + """ import io import pyarrow as pa @@ -322,14 +368,20 @@ def to_arrow(df: pd.DataFrame): def timeout_decorator(timeout): - """Timeout decorator using threading + """基于线程实现的超时装饰器 + + 将被装饰函数放在子线程中执行,主线程 ``join`` 等待 ``timeout`` 秒;超时则 + 返回 ``None`` 并打印 warning 日志。注意:超时不会真正终止子线程,子线程仍在 + 后台运行。 - :param timeout: int, timeout duration in seconds + :param timeout: int,超时秒数 + :return: 装饰器 """ def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): + # 用列表作为 mutable 容器,便于子线程回写 result = [None] exception = [None] @@ -344,6 +396,7 @@ def target(): thread.join(timeout) if thread.is_alive(): + # 超时:不终止线程,只输出告警并返回 None from loguru import logger as _logger _logger.warning(f"{func.__name__} timed out after {timeout} seconds; args: {args}; kwargs: {kwargs}") @@ -359,17 +412,19 @@ def target(): return decorator -# 延迟加载的模块映射 +# --------------------------------------------------------------------------- +# 延迟加载(lazy import)配置 +# --------------------------------------------------------------------------- +# plotting 子包及其下属绘图函数加载较重,统一在首次访问时再 importlib,避免 +# ``import czsc`` 阶段就把全部 plotly / matplotlib 等依赖拉起来。 + +# 延迟加载的子模块映射:属性名 -> 模块路径 _LAZY_SUBMODULES = { - "echarts_plot": "czsc.utils.echarts_plot", "plotting": "czsc.utils.plotting", } # 延迟加载的属性映射:属性名 -> (模块路径, 属性名) _LAZY_ATTRS = { - # echarts_plot - "kline_pro": ("czsc.utils.echarts_plot", "kline_pro"), - "trading_view_kline": ("czsc.utils.echarts_plot", "trading_view_kline"), # plotting.kline "KlineChart": ("czsc.utils.plotting.kline", "KlineChart"), "plot_czsc_chart": ("czsc.utils.plotting.kline", "plot_czsc_chart"), @@ -387,13 +442,21 @@ def target(): "plot_turnover_overview": ("czsc.utils.plotting.weight", "plot_turnover_overview"), "plot_turnover_cost_analysis": ("czsc.utils.plotting.weight", "plot_turnover_cost_analysis"), "plot_weight_time_series": ("czsc.utils.plotting.weight", "plot_weight_time_series"), - # loguru logger + # loguru logger 也通过懒加载暴露 "logger": ("loguru", "logger"), } def __getattr__(name): - """延迟加载重型子模块和属性,避免影响导入速度""" + """模块级 ``__getattr__``:实现延迟加载 + + 当用户访问尚未导入的属性时(包括 ``plotting`` 子模块和具体的绘图函数),由 + 本函数完成动态 import 并把结果写回模块全局空间,以便后续再次访问无需重复导入。 + + :param name: str,访问的属性名 + :return: 对应的模块或属性 + :raises AttributeError: 当 name 既不在懒加载子模块也不在懒加载属性表中时 + """ import importlib if name in _LAZY_SUBMODULES: diff --git a/czsc/utils/analysis/__init__.py b/czsc/utils/analysis/__init__.py index 7871f7196..33caeb0a6 100644 --- a/czsc/utils/analysis/__init__.py +++ b/czsc/utils/analysis/__init__.py @@ -1,21 +1,32 @@ -""" -分析工具模块 +"""czsc.utils.analysis —— 通用分析工具模块。 + +本子包汇集了 CZSC 中与“事后分析 / 评估”相关的常用工具函数,主要分为两大类: + +- 相关性分析(``corr``): + * :func:`cross_sectional_ic` —— 横截面 IC(Information Coefficient)计算, + 用于衡量因子值与未来收益之间的相关性; + * :func:`nmi_matrix` —— 归一化互信息(Normalized Mutual Information) + 矩阵,刻画多个变量两两之间的非线性相关程度; + * :func:`single_linear` —— 单变量线性回归的便捷封装。 + +- 业绩与统计(``stats``): + * :func:`daily_performance` —— 日频策略业绩指标(年化收益、夏普、 + 最大回撤、卡玛比率等); + * :func:`rolling_daily_performance` —— 在滚动窗口上计算上述业绩指标, + 便于绘制业绩稳定性曲线; + * :func:`holds_performance` —— 基于持仓权重序列计算业绩; + * :func:`top_drawdowns` —— 提取最大的若干段回撤区间; + * :func:`psi` —— PSI(Population Stability Index) + 群体稳定性指标,用于因子稳定性监控。 -包括统计分析、相关性分析和事件分析 +所有函数均通过 ``__all__`` 显式公开,保证 ``from czsc.utils.analysis import *`` +的行为可预期。 """ -# 从 stats 导入统计函数 -# 从 corr 导入相关性分析函数 -from .corr import ( - cross_sectional_ic, - nmi_matrix, - single_linear, -) +# 相关性分析相关函数 +from .corr import cross_sectional_ic, nmi_matrix, single_linear -# 从 events 导入事件分析函数 -from .events import ( - overlap, -) +# 业绩与统计相关函数 from .stats import ( daily_performance, holds_performance, @@ -24,17 +35,14 @@ top_drawdowns, ) +# 显式声明对外公开的符号,避免 ``from xxx import *`` 时引入过多内部依赖 __all__ = [ - # Stats + "cross_sectional_ic", "daily_performance", "holds_performance", - "top_drawdowns", - "rolling_daily_performance", - "psi", - # Correlation "nmi_matrix", + "psi", + "rolling_daily_performance", "single_linear", - "cross_sectional_ic", - # Events - "overlap", + "top_drawdowns", ] diff --git a/czsc/utils/analysis/events.py b/czsc/utils/analysis/events.py deleted file mode 100644 index beb1dc3ce..000000000 --- a/czsc/utils/analysis/events.py +++ /dev/null @@ -1,61 +0,0 @@ -""" -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2024/4/27 15:01 -describe: 事件分析工具函数 -""" - -import numpy as np -import pandas as pd - - -def overlap(df: pd.DataFrame, col: str, **kwargs): - """给定 df 和 col,计算 col 中相同值的连续出现次数 - - :param df: pd.DataFrame, 至少包含 dt、symbol 和 col 列 - :param col: str,需要计算连续出现次数的列名 - :param kwargs: dict,其他参数 - - - copy: bool, 是否复制 df,默认为 True - - new_col: str, 计算结果的列名,默认为 f"{col}_overlap" - - max_overlap: int, 最大允许连续出现次数,默认为 10 - - :return: pd.DataFrame - - Example: - ======================= - >>> df = pd.DataFrame({"dt": pd.date_range("2022-01-01", periods=10, freq="D"), - >>> "symbol": "000001", - >>> "close": [1, 1, 2, 2, 2, 3, 3, 3, 3, 3]}) - >>> df = overlap(df, "close") - >>> print(df) - ======================= - 输出: - dt symbol close close_overlap - 0 2022-01-01 000001 1 1.0 - 1 2022-01-02 000001 1 2.0 - 2 2022-01-03 000001 2 1.0 - 3 2022-01-04 000001 2 2.0 - 4 2022-01-05 000001 2 3.0 - 5 2022-01-06 000001 3 1.0 - 6 2022-01-07 000001 3 2.0 - 7 2022-01-08 000001 3 3.0 - 8 2022-01-09 000001 3 4.0 - 9 2022-01-10 000001 3 5.0 - """ - if kwargs.get("copy", True) is True: - df = df.copy() - - df = df.sort_values(["symbol", "dt"]).reset_index(drop=True) - df["dt"] = pd.to_datetime(df["dt"]) - - new_col = kwargs.get("new_col", f"{col}_overlap") - - for _symbol, dfg in df.groupby("symbol"): - # 计算 col 相同值的连续个数,从 1 开始计数 - dfg[new_col] = dfg.groupby(df[col].ne(df[col].shift()).cumsum()).cumcount() + 1 - df.loc[dfg.index, new_col] = dfg[new_col] - - max_overlap = kwargs.get("max_overlap", 10) - df[new_col] = np.where(df[new_col] > max_overlap, max_overlap, df[new_col]) - return df diff --git a/czsc/utils/analysis/events.pyi b/czsc/utils/analysis/events.pyi deleted file mode 100644 index 53c089cbc..000000000 --- a/czsc/utils/analysis/events.pyi +++ /dev/null @@ -1,3 +0,0 @@ -import pandas as pd - -def overlap(df: pd.DataFrame, col: str, **kwargs): ... diff --git a/czsc/utils/backtest_report.pyi b/czsc/utils/backtest_report.pyi deleted file mode 100644 index 7ef63a5ad..000000000 --- a/czsc/utils/backtest_report.pyi +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Any - -import pandas as pd -from _typeshed import Incomplete - -from .html_report_builder import HtmlReportBuilder as HtmlReportBuilder -from .pdf_report_builder import PdfReportBuilder as PdfReportBuilder -from .plotting.backtest import ( - get_performance_metrics_cards as get_performance_metrics_cards, -) -from .plotting.backtest import ( - plot_backtest_stats as plot_backtest_stats, -) -from .plotting.backtest import ( - plot_colored_table as plot_colored_table, -) -from .plotting.backtest import ( - plot_long_short_comparison as plot_long_short_comparison, -) - -def generate_backtest_report( - df: pd.DataFrame, output_path: str | None = None, title: str = "权重回测报告", **kwargs -) -> str: ... -def generate_html_backtest_report( - df: pd.DataFrame, output_path: str | None = None, title: str = "权重回测报告", **kwargs -) -> str: ... - -class LongShortComparisonChart: - df: Incomplete - config: Incomplete - def __init__(self, df: pd.DataFrame, config: dict[str, Any]) -> None: ... - def generate(self) -> str: ... - -def generate_pdf_backtest_report( - df: pd.DataFrame, output_path: str | None = None, title: str = "权重回测报告", **kwargs -) -> str: ... diff --git a/czsc/utils/bi_info.py b/czsc/utils/bi_info.py deleted file mode 100644 index 36d71e275..000000000 --- a/czsc/utils/bi_info.py +++ /dev/null @@ -1,62 +0,0 @@ -""" -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2023/9/24 12:39 -describe: K线的笔特征计算 -""" - -import pandas as pd -from tqdm import tqdm - -from czsc.core import CZSC, RawBar - - -def calculate_bi_info(bars: list[RawBar], **kwargs) -> pd.DataFrame: - """计算笔的特征 - - :param bars: 原始K线数据 - :return: 笔的特征 - """ - c = CZSC(bars, max_bi_num=kwargs.get("max_bi_num", 10000)) - - res = [ - { - "symbol": c.symbol, - "sdt": bi.fx_a.dt, - "edt": bi.fx_b.dt, - "方向": bi.direction.value, - "长度": bi.length, - "分型数": len(bi.fxs), - "斜边长度": bi.hypotenuse, - "斜边角度": bi.angle, - "涨跌幅": (bi.fx_b.fx / bi.fx_a.fx - 1) * 10000, - "R2": bi.rsq, - } - for bi in c.bi_list - ] - _df = pd.DataFrame(res) - _df["未来第一笔涨跌幅"] = _df["涨跌幅"].shift(-1) - _df["未来第二笔涨跌幅"] = _df["涨跌幅"].shift(-2) - return _df - - -def symbols_bi_infos(symbols, read_bars, freq="5分钟", sdt="20130101", edt="20190101", **kwargs) -> pd.DataFrame: - """计算多个标的的笔特征 - - :param symbols: 品种代码列表 - :param read_bars: 读取K线数据的函数,要求返回 RawBar 对象列表 - :param freq: K线周期, defaults to '5分钟' - :param sdt: 开始时间, defaults to '20130101' - :param edt: 结束时间, defaults to '20190101' - :return: 笔的特征 - """ - bis = [] - for symbol in tqdm(symbols, desc="计算笔的特征"): - try: - bars = read_bars(symbol=symbol, freq=freq, sdt=sdt, edt=edt, fq="后复权") - dfr = calculate_bi_info(bars) - bis.append(dfr) - except Exception as e: - print(f"{symbol} 计算失败: {e}") - dfb = pd.concat(bis, ignore_index=True) - return dfb diff --git a/czsc/utils/bi_info.pyi b/czsc/utils/bi_info.pyi deleted file mode 100644 index 08449ff3c..000000000 --- a/czsc/utils/bi_info.pyi +++ /dev/null @@ -1,9 +0,0 @@ -import pandas as pd - -from czsc.core import CZSC as CZSC -from czsc.core import RawBar as RawBar - -def calculate_bi_info(bars: list[RawBar], **kwargs) -> pd.DataFrame: ... -def symbols_bi_infos( - symbols, read_bars, freq: str = "5分钟", sdt: str = "20130101", edt: str = "20190101", **kwargs -) -> pd.DataFrame: ... diff --git a/czsc/utils/echarts_plot.py b/czsc/utils/echarts_plot.py deleted file mode 100644 index 9878c7256..000000000 --- a/czsc/utils/echarts_plot.py +++ /dev/null @@ -1,867 +0,0 @@ -""" -使用 pyecharts 定制绘图模块 - -""" -# ruff: noqa: E101 # JS 代码嵌入 Python 字符串,mixed spaces/tabs 是误报 - -from typing import TYPE_CHECKING, Optional - -import numpy as np -from pyecharts import options as opts -from pyecharts.charts import Bar, Grid, Kline, Line, Scatter -from pyecharts.commons.utils import JsCode - -from czsc.core import Operate - -from .ta import MACD, SMA - -if TYPE_CHECKING: - from lightweight_charts import Chart - - -def kline_pro( - kline: list[dict], - fx: list[dict] = None, - bi: list[dict] = None, - xd: list[dict] = None, - bs: list[dict] = None, - title: str = "缠中说禅K线分析", - t_seq: list[int] = None, - width: str = "1400px", - height: str = "580px", -) -> Grid: - """绘制缠中说禅K线分析结果 - - :param kline: K线 - :param fx: 分型识别结果 - :param bi: 笔识别结果 - {'dt': Timestamp('2020-11-26 00:00:00'), - 'fx_mark': 'd', - 'start_dt': Timestamp('2020-11-25 00:00:00'), - 'end_dt': Timestamp('2020-11-27 00:00:00'), - 'fx_high': 144.87, - 'fx_low': 138.0, - 'bi': 138.0} - :param xd: 线段识别结果 - :param bs: 买卖点 - :param title: 图表标题 - :param t_seq: 均线系统 - :param width: 图表宽度 - :param height: 图表高度 - :return: 用Grid组合好的图表 - """ - # 配置项设置 - # ------------------------------------------------------------------------------------------------------------------ - if t_seq is None: - t_seq = [] - if bs is None: - bs = [] - if xd is None: - xd = [] - if bi is None: - bi = [] - if fx is None: - fx = [] - bg_color = "#1f212d" # 背景 - up_color = "#F9293E" - down_color = "#00aa3b" - - init_opts = opts.InitOpts(bg_color=bg_color, width=width, height=height, animation_opts=opts.AnimationOpts(False)) - title_opts = opts.TitleOpts( - title=title, - pos_top="1%", - title_textstyle_opts=opts.TextStyleOpts(color=up_color, font_size=20), - subtitle_textstyle_opts=opts.TextStyleOpts(color=down_color, font_size=12), - ) - - label_show_opts = opts.LabelOpts(is_show=True) - label_not_show_opts = opts.LabelOpts(is_show=False) - legend_not_show_opts = opts.LegendOpts(is_show=False) - red_item_style = opts.ItemStyleOpts(color=up_color) - green_item_style = opts.ItemStyleOpts(color=down_color) - k_style_opts = opts.ItemStyleOpts( - color=up_color, color0=down_color, border_color=up_color, border_color0=down_color, opacity=0.8 - ) - - legend_opts = opts.LegendOpts( - is_show=True, - pos_top="1%", - pos_left="30%", - item_width=14, - item_height=8, - textstyle_opts=opts.TextStyleOpts(font_size=12, color="#0e99e2"), - ) - brush_opts = opts.BrushOpts( - tool_box=["rect", "polygon", "keep", "clear"], - x_axis_index="all", - brush_link="all", - out_of_brush={"colorAlpha": 0.1}, - brush_type="lineX", - ) - - axis_pointer_opts = opts.AxisPointerOpts(is_show=True, link=[{"xAxisIndex": "all"}]) - - dz_inside = opts.DataZoomOpts(False, "inside", xaxis_index=[0, 1, 2], range_start=80, range_end=100) - dz_slider = opts.DataZoomOpts( - True, "slider", xaxis_index=[0, 1, 2], pos_top="96%", pos_bottom="0%", range_start=80, range_end=100 - ) - - yaxis_opts = opts.AxisOpts( - is_scale=True, - min_="dataMin", - max_="dataMax", - splitline_opts=opts.SplitLineOpts(is_show=False), - axislabel_opts=opts.LabelOpts(color="#c7c7c7", font_size=8, position="inside"), - ) - - grid0_xaxis_opts = opts.AxisOpts( - type_="category", - grid_index=0, - axislabel_opts=label_not_show_opts, - split_number=20, - min_="dataMin", - max_="dataMax", - is_scale=True, - boundary_gap=False, - splitline_opts=opts.SplitLineOpts(is_show=False), - axisline_opts=opts.AxisLineOpts(is_on_zero=False), - ) - - tool_tip_opts = opts.TooltipOpts( - trigger="axis", - axis_pointer_type="cross", - background_color="rgba(245, 245, 245, 0.8)", - border_width=1, - border_color="#ccc", - position=JsCode( - """ - function (pos, params, el, elRect, size) { - var obj = {top: 10}; - obj[['left', 'right'][+(pos[0] < size.viewSize[0] / 2)]] = 30; - return obj; - } - """ - ), - textstyle_opts=opts.TextStyleOpts(color="#000"), - ) - - # 数据预处理 - # ------------------------------------------------------------------------------------------------------------------ - dts = [x["dt"] for x in kline] - # k_data = [[x['open'], x['close'], x['low'], x['high']] for x in kline] - k_data = [ - opts.CandleStickItem(name=i, value=[x["open"], x["close"], x["low"], x["high"]]) for i, x in enumerate(kline) - ] - - vol = [] - for i, row in enumerate(kline): - item_style = red_item_style if row["close"] > row["open"] else green_item_style - bar = opts.BarItem(name=i, value=row["vol"], itemstyle_opts=item_style, label_opts=label_not_show_opts) - vol.append(bar) - - close = np.array([x["close"] for x in kline], dtype=np.double) - diff, dea, macd = MACD(close) - macd_bar = [] - for i, v in enumerate(macd.tolist()): - item_style = red_item_style if v > 0 else green_item_style - bar = opts.BarItem(name=i, value=round(v, 4), itemstyle_opts=item_style, label_opts=label_not_show_opts) - macd_bar.append(bar) - - diff = diff.round(4) - dea = dea.round(4) - - # K 线主图 - # ------------------------------------------------------------------------------------------------------------------ - chart_k = Kline() - chart_k.add_xaxis(xaxis_data=dts) - chart_k.add_yaxis(series_name="Kline", y_axis=k_data, itemstyle_opts=k_style_opts) - - chart_k.set_global_opts( - legend_opts=legend_opts, - datazoom_opts=[dz_inside, dz_slider], - yaxis_opts=yaxis_opts, - tooltip_opts=tool_tip_opts, - axispointer_opts=axis_pointer_opts, - brush_opts=brush_opts, - title_opts=title_opts, - xaxis_opts=grid0_xaxis_opts, - ) - - # 加入买卖点 - 多头操作 - 空头操作 - if bs: - long_opens = {"i": [], "val": []} - long_exits = {"i": [], "val": []} - short_opens = {"i": [], "val": []} - short_exits = {"i": [], "val": []} - - for op in bs: - _dt = op["dt"] - _price = round(op["price"], 4) - _info = f"{op['op_desc']} - 价格{_price}" - - if op["op"] in [Operate.LO]: - long_opens["i"].append(_dt) - long_opens["val"].append([_price, _info]) - - if op["op"] in [Operate.LE]: - long_exits["i"].append(_dt) - long_exits["val"].append([_price, _info]) - - if op["op"] in [Operate.SO]: - short_opens["i"].append(_dt) - short_opens["val"].append([_price, _info]) - - if op["op"] in [Operate.SE]: - short_exits["i"].append(_dt) - short_exits["val"].append([_price, _info]) - - chart_lo = ( - Scatter() - .add_xaxis(xaxis_data=long_opens["i"]) - .add_yaxis( - series_name="多头操作", - y_axis=long_opens["val"], - symbol_size=25, - symbol="diamond", - label_opts=opts.LabelOpts(is_show=False), - itemstyle_opts=opts.ItemStyleOpts(color="#ff461f"), - tooltip_opts=opts.TooltipOpts( - textstyle_opts=opts.TextStyleOpts(font_size=12), - formatter=JsCode("function (params) {return params.value[2];}"), - ), - ) - ) - chart_le = ( - Scatter() - .add_xaxis(xaxis_data=long_exits["i"]) - .add_yaxis( - series_name="多头操作", - y_axis=long_exits["val"], - symbol_size=25, - symbol="diamond", - label_opts=opts.LabelOpts(is_show=False), - itemstyle_opts=opts.ItemStyleOpts(color="#afdd22"), - tooltip_opts=opts.TooltipOpts( - textstyle_opts=opts.TextStyleOpts(font_size=12), - formatter=JsCode("function (params) {return params.value[2];}"), - ), - ) - ) - chart_so = ( - Scatter() - .add_xaxis(xaxis_data=short_opens["i"]) - .add_yaxis( - series_name="空头订单", - y_axis=short_opens["val"], - symbol_size=25, - symbol="triangle", - label_opts=opts.LabelOpts(is_show=False), - itemstyle_opts=opts.ItemStyleOpts(color="#ff461f"), - tooltip_opts=opts.TooltipOpts( - textstyle_opts=opts.TextStyleOpts(font_size=12), - formatter=JsCode("function (params) {return params.value[2];}"), - ), - ) - ) - chart_se = ( - Scatter() - .add_xaxis(xaxis_data=short_exits["i"]) - .add_yaxis( - series_name="空头订单", - y_axis=short_exits["val"], - symbol_size=25, - symbol="triangle", - label_opts=opts.LabelOpts(is_show=False), - itemstyle_opts=opts.ItemStyleOpts(color="#afdd22"), - tooltip_opts=opts.TooltipOpts( - textstyle_opts=opts.TextStyleOpts(font_size=12), - formatter=JsCode("function (params) {return params.value[2];}"), - ), - ) - ) - - chart_k = chart_k.overlap(chart_lo) - chart_k = chart_k.overlap(chart_le) - chart_k = chart_k.overlap(chart_so) - chart_k = chart_k.overlap(chart_se) - - # 均线图 - # ------------------------------------------------------------------------------------------------------------------ - chart_ma = Line() - chart_ma.add_xaxis(xaxis_data=dts) - if not t_seq: - t_seq = [5, 13, 21] - - ma_keys = {} - for t in t_seq: - ma_keys[f"MA{t}"] = SMA(close, timeperiod=t) - - for _, (name, ma) in enumerate(ma_keys.items()): - chart_ma.add_yaxis( - series_name=name, - y_axis=ma, - is_smooth=True, - symbol_size=0, - label_opts=label_not_show_opts, - linestyle_opts=opts.LineStyleOpts(opacity=0.8, width=1), - ) - - chart_ma.set_global_opts(xaxis_opts=grid0_xaxis_opts, legend_opts=legend_not_show_opts) - chart_k = chart_k.overlap(chart_ma) - - # 缠论结果 - # ------------------------------------------------------------------------------------------------------------------ - if fx: - fx_dts = [x["dt"] for x in fx] - fx_val = [round(x["fx"], 2) for x in fx] - chart_fx = Line() - chart_fx.add_xaxis(fx_dts) - chart_fx.add_yaxis( - series_name="FX", - y_axis=fx_val, - symbol="circle", - symbol_size=6, - label_opts=label_show_opts, - itemstyle_opts=opts.ItemStyleOpts( - color="rgba(152, 147, 193, 1.0)", - ), - ) - - chart_fx.set_global_opts(xaxis_opts=grid0_xaxis_opts, legend_opts=legend_not_show_opts) - chart_k = chart_k.overlap(chart_fx) - - if bi: - bi_dts = [x["dt"] for x in bi] - bi_val = [round(x["bi"], 2) for x in bi] - chart_bi = Line() - chart_bi.add_xaxis(bi_dts) - chart_bi.add_yaxis( - series_name="BI", - y_axis=bi_val, - symbol="diamond", - symbol_size=10, - label_opts=label_show_opts, - itemstyle_opts=opts.ItemStyleOpts( - color="rgba(184, 117, 225, 1.0)", - ), - linestyle_opts=opts.LineStyleOpts(width=1.5), - ) - - chart_bi.set_global_opts(xaxis_opts=grid0_xaxis_opts, legend_opts=legend_not_show_opts) - chart_k = chart_k.overlap(chart_bi) - - if xd: - xd_dts = [x["dt"] for x in xd] - xd_val = [x["xd"] for x in xd] - chart_xd = Line() - chart_xd.add_xaxis(xd_dts) - chart_xd.add_yaxis( - series_name="XD", - y_axis=xd_val, - symbol="triangle", - symbol_size=10, - itemstyle_opts=opts.ItemStyleOpts( - color="rgba(37, 141, 54, 1.0)", - ), - ) - - chart_xd.set_global_opts(xaxis_opts=grid0_xaxis_opts, legend_opts=legend_not_show_opts) - chart_k = chart_k.overlap(chart_xd) - - # 成交量图 - # ------------------------------------------------------------------------------------------------------------------ - chart_vol = Bar() - chart_vol.add_xaxis(dts) - chart_vol.add_yaxis(series_name="Volume", y_axis=vol, bar_width="60%") - chart_vol.set_global_opts( - xaxis_opts=opts.AxisOpts( - type_="category", - grid_index=1, - boundary_gap=False, - axislabel_opts=opts.LabelOpts(is_show=True, font_size=8, color="#9b9da9"), - ), - yaxis_opts=yaxis_opts, - legend_opts=legend_not_show_opts, - ) - - # MACD图 - # ------------------------------------------------------------------------------------------------------------------ - chart_macd = Bar() - chart_macd.add_xaxis(dts) - chart_macd.add_yaxis(series_name="MACD", y_axis=macd_bar, bar_width="60%") - chart_macd.set_global_opts( - xaxis_opts=opts.AxisOpts( - type_="category", - grid_index=2, - axislabel_opts=opts.LabelOpts(is_show=False), - splitline_opts=opts.SplitLineOpts(is_show=False), - ), - yaxis_opts=opts.AxisOpts( - grid_index=2, - split_number=4, - axisline_opts=opts.AxisLineOpts(is_on_zero=False), - axistick_opts=opts.AxisTickOpts(is_show=False), - splitline_opts=opts.SplitLineOpts(is_show=False), - axislabel_opts=opts.LabelOpts(is_show=True, color="#c7c7c7"), - ), - legend_opts=opts.LegendOpts(is_show=False), - ) - - line = Line() - line.add_xaxis(dts) - line.add_yaxis( - series_name="DIFF", - y_axis=diff.tolist(), - label_opts=label_not_show_opts, - is_symbol_show=False, - linestyle_opts=opts.LineStyleOpts(opacity=0.8, width=1.0, color="#da6ee8"), - ) - line.add_yaxis( - series_name="DEA", - y_axis=dea.tolist(), - label_opts=label_not_show_opts, - is_symbol_show=False, - linestyle_opts=opts.LineStyleOpts(opacity=0.8, width=1.0, color="#39afe6"), - ) - - chart_macd = chart_macd.overlap(line) - - grid0_opts = opts.GridOpts(pos_left="0%", pos_right="1%", pos_top="12%", height="58%") - grid1_opts = opts.GridOpts(pos_left="0%", pos_right="1%", pos_top="74%", height="8%") - grid2_opts = opts.GridOpts(pos_left="0%", pos_right="1%", pos_top="86%", height="10%") - - grid_chart = Grid(init_opts) - grid_chart.add(chart_k, grid_opts=grid0_opts) - grid_chart.add(chart_vol, grid_opts=grid1_opts) - grid_chart.add(chart_macd, grid_opts=grid2_opts) - return grid_chart - - -def _prepare_kline_data(kline: list[dict], use_streamlit=False, width=1400, height=580) -> tuple: - """准备K线数据 - - :param kline: K线数据 - :return: (df_data, chart) - """ - import pandas as pd - from loguru import logger - - # 准备K线数据 - df_data = [] - for item in kline: - # 处理时间格式 - time_str = item["dt"].strftime("%Y-%m-%d") if hasattr(item["dt"], "strftime") else str(item["dt"]) - - df_data.append( - { - "time": time_str, - "open": float(item["open"]), - "high": float(item["high"]), - "low": float(item["low"]), - "close": float(item["close"]), - "volume": float(item.get("vol", item.get("volume", 0))), - } - ) - - # 创建主图表(延迟导入,避免在模块加载时引入 streamlit 等重型依赖) - if use_streamlit: - from lightweight_charts.widgets import StreamlitChart - - logger.info("使用 StreamlitChart") - chart = StreamlitChart(width=width, height=height) - else: - from lightweight_charts import Chart - - logger.info("使用 Chart") - chart = Chart() - - df = pd.DataFrame(df_data) - chart.set(df) - - logger.info(f"成功创建基础K线图表,包含{len(df_data)}根K线") - return df_data, chart - - -def _add_moving_averages(chart: "Chart", kline: list[dict], df_data: list[dict], t_seq: list[int]) -> None: - """添加移动平均线 - - :param chart: 图表对象 - :param kline: K线数据 - :param df_data: 格式化后的数据 - :param t_seq: 均线周期序列 - """ - import pandas as pd - from loguru import logger - - if not t_seq: - return - - try: - close_prices = np.array([x["close"] for x in kline], dtype=np.double) - # 均线颜色:橙色、蓝色、绿色、紫色、青色 - ma_colors = ["#FF9800", "#2196F3", "#4CAF50", "#9C27B0", "#00BCD4"] - - for i, period in enumerate(t_seq[:5]): # 最多显示5条均线 - try: - ma_values = SMA(close_prices, timeperiod=period) - ma_data = [] - - for j, item in enumerate(df_data): - if j >= period - 1 and j < len(ma_values) and not np.isnan(ma_values[j]): - ma_data.append({"time": item["time"], f"MA{period}": float(ma_values[j])}) - - if ma_data: - ma_df = pd.DataFrame(ma_data).set_index("time") - color = ma_colors[i] if i < len(ma_colors) else "#999999" - ma_line = chart.create_line(f"MA{period}", color=color) - ma_line.set(ma_df) - logger.info(f"成功添加MA{period}均线({color}),数据点数:{len(ma_data)}") - except Exception as e: - logger.warning(f"添加MA{period}均线失败: {e}") - continue - except Exception as e: - logger.warning(f"添加移动平均线失败: {e}") - - -def _add_fractal_marks(chart: "Chart", fx: list[dict]) -> None: - """添加分型标记 - - :param chart: 图表对象 - :param fx: 分型数据 - """ - import pandas as pd - from loguru import logger - - if not fx: - return - - try: - fx_data = [] - for item in fx: - time_str = item["dt"].strftime("%Y-%m-%d") if hasattr(item["dt"], "strftime") else str(item["dt"]) - - fx_data.append({"time": time_str, "分型": float(item["fx"])}) - - if fx_data: - fx_df = pd.DataFrame(fx_data).set_index("time") - fx_line = chart.create_line("分型", color="#FF5722") # 深橙红色 - fx_line.set(fx_df) - logger.info(f"成功添加{len(fx_data)}个分型点(深橙红色)") - except Exception as e: - logger.warning(f"添加分型标记失败: {e}") - - -def _add_bi_lines(chart: "Chart", bi: list[dict]) -> None: - """添加笔线 - - :param chart: 图表对象 - :param bi: 笔数据 - """ - import pandas as pd - from loguru import logger - - if not bi: - return - - try: - bi_data = [] - for item in bi: - time_str = item["dt"].strftime("%Y-%m-%d") if hasattr(item["dt"], "strftime") else str(item["dt"]) - - bi_data.append({"time": time_str, "笔": float(item["bi"])}) - - if bi_data: - bi_df = pd.DataFrame(bi_data).set_index("time") - bi_line = chart.create_line("笔", color="#FFC107") # 琥珀黄色 - bi_line.set(bi_df) - logger.info(f"成功添加{len(bi_data)}笔(琥珀黄色)") - except Exception as e: - logger.warning(f"添加笔线失败: {e}") - - -def _add_xd_lines(chart: "Chart", xd: list[dict]) -> None: - """添加线段 - - :param chart: 图表对象 - :param xd: 线段数据 - """ - import pandas as pd - from loguru import logger - - if not xd: - return - - try: - xd_data = [] - for item in xd: - time_str = item["dt"].strftime("%Y-%m-%d") if hasattr(item["dt"], "strftime") else str(item["dt"]) - - xd_data.append({"time": time_str, "线段": float(item["xd"])}) - - if xd_data: - xd_df = pd.DataFrame(xd_data).set_index("time") - xd_line = chart.create_line("线段", color="#E91E63") # 粉红色 - xd_line.set(xd_df) - logger.info(f"成功添加{len(xd_data)}条线段(粉红色)") - except Exception as e: - logger.warning(f"添加线段失败: {e}") - - -def _add_macd_indicator(chart: "Chart", kline: list[dict], df_data: list[dict]) -> None: - """添加MACD指标到子图表 - - :param chart: 图表对象 - :param kline: K线数据 - :param df_data: 格式化后的数据 - """ - import pandas as pd - from loguru import logger - - try: - close_prices = np.array([x["close"] for x in kline], dtype=np.double) - diff, dea, macd = MACD(close_prices) - - # 尝试创建子图表用于MACD显示 - try: - # 重新设置主图高度,为子图腾出空间 - chart.resize(1, 0.7) # 主图占70%高度 - - # 隐藏主图的时间轴,避免重复显示 - chart.time_scale(visible=False) - - # 创建MACD子图表,占30%高度并同步时间轴 - macd_chart = chart.create_subchart(width=1, height=0.3, sync=True) - - # 确保子图显示时间轴,并设置时间轴格式保持一致 - macd_chart.time_scale(visible=True, time_visible=True, seconds_visible=False) - - logger.info("成功创建MACD子图表并设置时间轴同步") - except Exception as subchart_e: - # 如果不支持子图表,直接返回 - logger.warning(f"子图表创建失败,跳过MACD指标: {subchart_e}") - return - - # 确保所有数组长度一致 - data_length = min(len(df_data), len(diff), len(dea), len(macd)) - - # 重要:对NaN值填充为0,确保数据长度一致,保证时间轴对齐 - diff_line_data = [] - dea_line_data = [] - histogram_data = [] - - for j in range(data_length): - time_value = df_data[j]["time"] - - # 对NaN值填充为0,而不是跳过,确保数据长度一致 - diff_val = 0.0 if np.isnan(diff[j]) else float(diff[j]) - dea_val = 0.0 if np.isnan(dea[j]) else float(dea[j]) - macd_val = 0.0 if np.isnan(macd[j]) else float(macd[j]) - - diff_line_data.append({"time": time_value, "value": diff_val}) - - dea_line_data.append({"time": time_value, "value": dea_val}) - - histogram_data.append( - {"time": time_value, "value": macd_val, "color": "#26a69a" if macd_val >= 0 else "#ef5350"} - ) - - logger.info( - f"MACD数据准备完成:总数据长度{data_length},DIFF({len(diff_line_data)}),DEA({len(dea_line_data)}),柱状图({len(histogram_data)})" - ) - - # 添加DIFF线(MACD快线) - if diff_line_data: - diff_df = pd.DataFrame(diff_line_data) - diff_line = macd_chart.create_line(color="#1976D2", width=2) # 深蓝色 - diff_line.set(diff_df) - logger.info(f"成功添加DIFF线(深蓝色),数据点数:{len(diff_line_data)}") - - # 添加DEA线(MACD慢线/信号线) - if dea_line_data: - dea_df = pd.DataFrame(dea_line_data) - dea_line = macd_chart.create_line(color="#FF5722", width=2) # 橙红色 - dea_line.set(dea_df) - logger.info(f"成功添加DEA线(橙红色),数据点数:{len(dea_line_data)}") - - # 添加MACD柱状图 - if histogram_data: - histogram_df = pd.DataFrame(histogram_data) - macd_histogram = macd_chart.create_histogram() - macd_histogram.set(histogram_df) - logger.info(f"成功添加MACD柱状图,数据点数:{len(histogram_data)}") - - # 设置子图表样式和联动 - macd_chart.legend(visible=True) - - # 添加一些样式设置以确保更好的视觉效果 - try: - # 设置MACD子图的网格线 - macd_chart.grid(vert_enabled=True, horz_enabled=True) - - # 确保子图的十字光标与主图同步 - macd_chart.crosshair( - mode="normal", vert_color="#758494", vert_style="dotted", horz_color="#758494", horz_style="dotted" - ) - except Exception as style_e: - logger.debug(f"设置MACD子图样式时出现警告: {style_e}") - - logger.info("MACD子图与主图时间轴联动设置完成") - - except Exception as e: - logger.warning(f"添加MACD指标失败: {e}") - - -def _add_trade_signals(chart: "Chart", bs: list[dict]) -> None: - """添加买卖点标记 - - :param chart: 图表对象 - :param bs: 买卖点数据 - """ - from datetime import datetime - - from loguru import logger - - if not bs: - return - - try: - for signal in bs: - # 处理时间格式 - if hasattr(signal["dt"], "strftime"): - marker_time = signal["dt"] - else: - # 尝试转换为datetime对象 - try: - marker_time = datetime.strptime(str(signal["dt"]), "%Y-%m-%d") - except Exception: - marker_time = None - - if marker_time is None: - continue - - # 根据操作类型设置不同的标记 - if signal["op"] in [Operate.LO]: # 买入开仓 - chart.marker( - time=marker_time, - position="below", - shape="circle", - color="#4CAF50", - text=signal.get("op_desc", "买入"), - ) - elif signal["op"] in [Operate.LE]: # 卖出平仓 - chart.marker( - time=marker_time, - position="above", - shape="circle", - color="#F44336", - text=signal.get("op_desc", "卖出"), - ) - elif signal["op"] in [Operate.SO]: # 卖出开仓 - chart.marker( - time=marker_time, - position="above", - shape="arrow_down", - color="#FF9800", - text=signal.get("op_desc", "做空"), - ) - elif signal["op"] in [Operate.SE]: # 买入平仓 - chart.marker( - time=marker_time, - position="below", - shape="arrow_up", - color="#2196F3", - text=signal.get("op_desc", "平空"), - ) - - logger.info(f"成功添加{len(bs)}个买卖点标记") - except Exception as e: - logger.exception(f"添加买卖点标记失败: {e}") - - -def _setup_chart_style(chart: "Chart", title: str) -> None: - """设置图表样式 - - :param chart: 图表对象 - :param title: 图表标题 - """ - from loguru import logger - - try: - # 设置图表样式 - chart.legend(visible=True) - chart.watermark(title) - - # 可以添加更多样式设置 - # chart.layout(background_color='#FFFFFF', text_color='#000000') - # chart.grid(vert_enabled=True, horz_enabled=True) - - logger.info(f"成功设置图表样式和标题: {title}") - except Exception as e: - logger.warning(f"设置图表样式失败: {e}") - - -def trading_view_kline( - kline: list[dict], - fx: list[dict] | None = None, - bi: list[dict] | None = None, - xd: list[dict] | None = None, - bs: list[dict] | None = None, - title: str = "缠中说禅K线分析", - t_seq: list[int] | None = None, - **kwargs, -) -> Optional["Chart"]: - """使用 lightweight_charts 绘制缠中说禅K线分析结果 - - 注意:本函数提供基础的lightweight_charts集成。 - 如需完整功能和更好的视觉效果,建议使用 kline_pro 函数。 - - :param kline: K线数据 - :param fx: 分型识别结果 - :param bi: 笔识别结果 - :param xd: 线段识别结果 - :param bs: 买卖点 - :param title: 图表标题 - :param t_seq: 均线系统 - :return: lightweight_charts Chart对象 或 None - """ - from loguru import logger - - # 设置默认值 - fx = fx or [] - bi = bi or [] - xd = xd or [] - bs = bs or [] - t_seq = t_seq or [5, 13, 21] - - use_streamlit = kwargs.get("use_streamlit", False) - width = kwargs.get("width", 1400) - height = kwargs.get("height", 580) - - # 准备K线数据 - df_data, chart = _prepare_kline_data(kline, use_streamlit, width, height) - - # 添加移动平均线 - _add_moving_averages(chart, kline, df_data, t_seq) - - # 添加分型标记 - _add_fractal_marks(chart, fx) - - # 添加笔线 - _add_bi_lines(chart, bi) - - # 添加线段 - _add_xd_lines(chart, xd) - - # 添加MACD指标 - _add_macd_indicator(chart, kline, df_data) - - # 添加买卖点标记 - _add_trade_signals(chart, bs) - - # 设置图表样式 - _setup_chart_style(chart, title) - - logger.info(f"创建 lightweight_charts 图表成功: {title}") - logger.info(f"包含: K线({len(kline)}), 均线({len(t_seq)}), 分型({len(fx)}), 笔({len(bi)}), 线段({len(xd)}), MACD") - - return chart diff --git a/czsc/utils/echarts_plot.pyi b/czsc/utils/echarts_plot.pyi deleted file mode 100644 index d8be82d40..000000000 --- a/czsc/utils/echarts_plot.pyi +++ /dev/null @@ -1,31 +0,0 @@ -from lightweight_charts import Chart -from pyecharts.charts import Boxplot as Boxplot -from pyecharts.charts import Grid -from pyecharts.charts import HeatMap as HeatMap - -from czsc.core import Operate as Operate - -from .ta import MACD as MACD -from .ta import SMA as SMA - -def kline_pro( - kline: list[dict], - fx: list[dict] = [], - bi: list[dict] = [], - xd: list[dict] = [], - bs: list[dict] = [], - title: str = "缠中说禅K线分析", - t_seq: list[int] = [], - width: str = "1400px", - height: str = "580px", -) -> Grid: ... -def trading_view_kline( - kline: list[dict], - fx: list[dict] | None = None, - bi: list[dict] | None = None, - xd: list[dict] | None = None, - bs: list[dict] | None = None, - title: str = "缠中说禅K线分析", - t_seq: list[int] | None = None, - **kwargs, -) -> Chart | None: ... diff --git a/czsc/utils/feature_utils.py b/czsc/utils/feature_utils.py deleted file mode 100644 index 17cb789e1..000000000 --- a/czsc/utils/feature_utils.py +++ /dev/null @@ -1,112 +0,0 @@ -# 工具函数 -import numpy as np -import pandas as pd -from loguru import logger - - -def is_event_feature(df, col, **kwargs): - """事件类因子的判断函数 - - 事件因子的特征:多头事件发生时,因子值为1;空头事件发生时,因子值为-1;其他情况,因子值为0。 - - :param df: DataFrame - :param col: str, 因子字段名称 - """ - unique_values = df[col].unique() - return all(x in [0, 1, -1] for x in unique_values) - - -def feature_returns(df, factor, target="n1b", **kwargs): - """计算因子特征截面收益率 - - :param df: pd.DataFrame, 必须包含 dt、symbol、factor, target 列 - :param factor: str, 因子列名 - :param target: str, 预测目标收益率列名 - :param kwargs: - - - fit_intercept: bool, 是否拟合截距项,默认为 False - - :return: pd.DataFrame, 新增 returns 列 - """ - from sklearn.linear_model import LinearRegression - - df = df.copy() - fit_intercept = kwargs.get("fit_intercept", False) - - ret = [] - for dt, dfg in df.groupby("dt"): - dfg = dfg.copy().dropna(subset=[factor, target]) - if dfg.empty or len(dfg) < 5: - ret.append([dt, 0]) - logger.warning(f"{dt} has no enough data, only {len(dfg)} rows") - continue - - x = dfg[factor].values.reshape(-1, 1) - y = dfg[target].values.reshape(-1, 1) - model = LinearRegression(fit_intercept=fit_intercept).fit(x, y) - ret.append([dt, model.coef_[0][0]]) - - dft = pd.DataFrame(ret, columns=["dt", "returns"]) - return dft - - -def feature_sectional_corr(df, factor, target="n1b", method="pearson", **kwargs): - """计算因子特征截面相关性(IC) - - :param df:数据,DateFrame格式 - :param factor:因子列名,一般采用F#开头的列 - :param target:目标列名,一般为n1b - :param method:{'pearson', 'kendall', 'spearman'} or callable - - * pearson : standard correlation coefficient - * kendall : Kendall Tau correlation coefficient - * spearman : Spearman rank correlation - * callable: callable with input two 1d ndarrays and returning a float - - :return:df,res: 前者是每日相关系数结果,后者是每日相关系数的统计结果 - """ - from czsc.utils import single_linear - - df = df.copy() - corr = [] - for dt, dfg in df.groupby("dt"): - dfg = dfg.copy().dropna(subset=[factor, target]) - - if dfg.empty or len(dfg) < 5: - corr.append([dt, 0]) - logger.warning(f"{dt} has no enough data, only {len(dfg)} rows") - else: - c = dfg[factor].corr(dfg[target], method=method) - corr.append([dt, c]) - - dft = pd.DataFrame(corr, columns=["dt", "corr"]) - - res = { - "factor": factor, - "target": target, - "method": method, - "IC均值": 0, - "IC标准差": 0, - "ICIR": 0, - "IC胜率": 0, - "累计IC回归R2": 0, - "累计IC回归斜率": 0, - } - if dft.empty: - return dft, res - - dft = dft[~dft["corr"].isnull()].copy() - ic_avg = dft["corr"].mean() - ic_std = dft["corr"].std() - - res["IC均值"] = round(ic_avg, 4) - res["IC标准差"] = round(ic_std, 4) - res["ICIR"] = round(ic_avg / ic_std, 4) if ic_std != 0 else 0 - if ic_avg > 0: - res["IC胜率"] = round(len(dft[dft["corr"] > 0]) / len(dft), 4) - else: - res["IC胜率"] = round(len(dft[dft["corr"] < 0]) / len(dft), 4) - - lr_ = single_linear(y=dft["corr"].cumsum().to_list()) - res.update({"累计IC回归R2": lr_["r2"], "累计IC回归斜率": lr_["slope"]}) - return dft, res diff --git a/czsc/utils/features.py b/czsc/utils/features.py deleted file mode 100644 index ffbaae24d..000000000 --- a/czsc/utils/features.py +++ /dev/null @@ -1,108 +0,0 @@ -""" -Feature processing helpers kept for the current retained utility surface. -""" - -from __future__ import annotations - -import numpy as np -import pandas as pd - - -def normalize_feature(df, x_col, method="standard", **kwargs): - """Normalize a cross-sectional factor column by date.""" - from sklearn.preprocessing import minmax_scale, normalize, robust_scale, scale - - df = df.copy() - assert df[x_col].isna().sum() == 0, f"factor has missing values: {df[x_col].isna().sum()}" - q = kwargs.pop("q", 0.05) - - def _norm(x): - x = x.clip(lower=x.quantile(q), upper=x.quantile(1 - q)) - if method == "minmax": - return minmax_scale(x, **kwargs) - if method == "robust": - return robust_scale(x, **kwargs) - if method.startswith("norm"): - norm_type = method.split("-")[1] if "-" in method else "l2" - return normalize(x.values.reshape(1, -1), norm=norm_type).flatten() - if method == "standard": - return scale(x, **kwargs) - raise ValueError(f"unsupported normalize method: {method}") - - df[x_col] = df.groupby("dt")[x_col].transform(_norm) - return df - - -def normalize_ts_feature(df, x_col, n=10, **kwargs): - """Normalize a time-series factor into rolling quantile buckets.""" - assert df[x_col].nunique() > n * 2, "factor must have enough unique values for bucketing" - assert df[x_col].isna().sum() == 0, f"factor has missing values: {df[x_col].isna().sum()}" - min_periods = kwargs.get("min_periods", 300) - - if df.loc[df[x_col].isin([float("inf"), float("-inf")]), x_col].shape[0] > 0: - raise ValueError(f"{x_col} contains inf or -inf") - - if f"{x_col}_qcut" not in df.columns: - df[f"{x_col}_qcut"] = ( - df[x_col] - .rolling(min_periods=min_periods, window=min_periods) - .apply(lambda x: pd.qcut(x, q=n, labels=False, duplicates="drop", retbins=False).values[-1], raw=False) - ) - df[f"{x_col}_qcut"] = df[f"{x_col}_qcut"].fillna(-1) - df[f"{x_col}分层"] = df[f"{x_col}_qcut"].apply(lambda x: f"第{str(int(x + 1)).zfill(2)}层") - - return df - - -def feature_cross_layering(df, x_col, **kwargs): - """Bucket cross-sectional factor values by date.""" - n = kwargs.get("n", 10) - assert "dt" in df.columns, "factor data must contain dt" - assert "symbol" in df.columns, "factor data must contain symbol" - assert x_col in df.columns, f"factor data must contain {x_col}" - assert df["symbol"].nunique() > n, "symbol count must be greater than layer count" - - if df[x_col].nunique() > n: - - def _layering(x): - return pd.qcut(x, q=n, labels=False, duplicates="drop") - - df[f"{x_col}分层"] = df.groupby("dt")[x_col].transform(_layering) - else: - sorted_x = sorted(df[x_col].unique()) - df[f"{x_col}分层"] = df[x_col].apply(lambda x: sorted_x.index(x)) - - df[f"{x_col}分层"] = df[f"{x_col}分层"].fillna(-1) - df[f"{x_col}分层"] = df[f"{x_col}分层"].apply(lambda x: f"第{str(int(x + 1)).zfill(2)}层") - return df - - -def find_most_similarity(vector: pd.Series, matrix: pd.DataFrame, n: int = 10, metric: str = "cosine", **kwargs): - """Find the most similar columns in a matrix to the given vector.""" - del kwargs - - vec = pd.to_numeric(pd.Series(vector), errors="coerce") - data = matrix.apply(pd.to_numeric, errors="coerce") - - if len(vec) != len(data.index): - raise ValueError("vector length must match matrix row count") - - if metric == "corr": - scores = data.corrwith(vec) - elif metric == "cosine": - vec_values = vec.fillna(0.0).to_numpy(dtype=float) - vec_norm = np.linalg.norm(vec_values) - if vec_norm == 0: - scores = pd.Series(0.0, index=data.columns, dtype=float) - else: - filled = data.fillna(0.0) - scores = filled.apply( - lambda col: float(np.dot(col.to_numpy(dtype=float), vec_values) / (np.linalg.norm(col.to_numpy(dtype=float)) * vec_norm)) - if np.linalg.norm(col.to_numpy(dtype=float)) > 0 - else 0.0 - ) - else: - raise ValueError(f"unsupported metric: {metric}") - - scores = scores.fillna(-1.0).sort_values(ascending=False) - return scores.head(n) diff --git a/czsc/utils/features.pyi b/czsc/utils/features.pyi deleted file mode 100644 index 2203de0a5..000000000 --- a/czsc/utils/features.pyi +++ /dev/null @@ -1,6 +0,0 @@ -import pandas as pd - -def normalize_feature(df, x_col, method: str = "standard", **kwargs): ... -def normalize_ts_feature(df, x_col, n: int = 10, **kwargs): ... -def feature_cross_layering(df, x_col, **kwargs): ... -def find_most_similarity(vector: pd.Series, matrix: pd.DataFrame, n: int = 10, metric: str = "cosine", **kwargs): ... diff --git a/czsc/utils/holds_concepts_effect.py b/czsc/utils/holds_concepts_effect.py deleted file mode 100644 index aa216e53a..000000000 --- a/czsc/utils/holds_concepts_effect.py +++ /dev/null @@ -1,75 +0,0 @@ -""" -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2021/11/17 18:50 -""" - -from collections import Counter - -import pandas as pd -from tqdm import tqdm - - -def holds_concepts_effect(holds: pd.DataFrame, concepts: dict, top_n=20, min_n=3, **kwargs): - """股票持仓列表的板块效应 - - 原理概述:在选股时,如果股票的概念板块与组合中的其他股票的概念板块有重合,那么这个股票的表现会更好。 - - 函数计算逻辑: - - 1. 如果kwargs中存在'copy'键且对应值为True,则将holds进行复制。 - 2. 为holds添加'概念板块'列,该列的值是holds中'symbol'列对应的股票的概念板块列表,如果没有对应的概念板块则填充为空。 - 3. 添加'概念数量'列,该列的值是每个股票的概念板块数量。 - 4. 从holds中筛选出概念数量大于0的行,赋值给holds。 - 5. 创建空列表new_holds和空字典dt_key_concepts。 - 6. 对holds按照'dt'进行分组,遍历每个分组,计算板块效应。 - a. 计算密集出现的概念,选取出现次数最多的前top_n个概念,赋值给key_concepts列表。 - b. 将日期dt和对应的key_concepts存入dt_key_concepts字典。 - c. 计算在密集概念中出现次数超过min_n的股票,将符合条件的股票添加到new_holds列表中。 - 7. 使用pd.concat将new_holds中的DataFrame进行合并,忽略索引,赋值给dfh。 - 8. 创建DataFrame dfk,其中包含日期(dt)和对应的强势概念(key_concepts)。 - 9. 返回dfh和dfk。 - - :param holds: 组合股票池数据,样例: - - =================== ========= ========== - dt symbol weight - =================== ========= ========== - 2023-05-09 00:00:00 601858.SH 0.00333333 - 2023-05-09 00:00:00 300502.SZ 0.00333333 - 2023-05-09 00:00:00 603258.SH 0.00333333 - 2023-05-09 00:00:00 300499.SZ 0.00333333 - 2023-05-09 00:00:00 300624.SZ 0.00333333 - =================== ========= ========== - - :param concepts: 股票的概念板块,样例: - { - '002507.SZ': ['电子商务', '超级品牌', '国企改革'], - '002508.SZ': ['家用电器', '杭州亚运会', '恒大概念'] - } - :param top_n: 选取前 n 个密集概念 - :param min_n: 单股票至少要有 n 个概念在 top_n 中 - :return: 过滤后的选股结果,每个时间点的 top_n 概念 - """ - if kwargs.get("copy", True): - holds = holds.copy() - - holds["概念板块"] = holds["symbol"].map(concepts).fillna("") - holds["概念数量"] = holds["概念板块"].apply(len) - holds = holds[holds["概念数量"] > 0] - - new_holds = [] - dt_key_concepts = {} - for dt, dfg in tqdm(holds.groupby("dt"), desc="计算板块效应"): - # 计算密集出现的概念 - key_concepts = [k for k, v in Counter([x for y in dfg["概念板块"] for x in y]).most_common(top_n)] - dt_key_concepts[dt] = key_concepts - - # 计算在密集概念中出现次数超过min_n的股票 - dfg["强势概念"] = dfg["概念板块"].apply(lambda x: ",".join(set(x) & set(key_concepts))) - sel = dfg[dfg["强势概念"].apply(lambda x: len(x.split(",")) >= min_n)] - new_holds.append(sel) - - dfh = pd.concat(new_holds, ignore_index=True) - dfk = pd.DataFrame([{"dt": k, "强势概念": v} for k, v in dt_key_concepts.items()]) - return dfh, dfk diff --git a/czsc/utils/html_report_builder.pyi b/czsc/utils/html_report_builder.pyi deleted file mode 100644 index a31881d7b..000000000 --- a/czsc/utils/html_report_builder.pyi +++ /dev/null @@ -1,28 +0,0 @@ -from typing import Any - -import pandas as pd -from _typeshed import Incomplete - -class HtmlReportBuilder: - title: Incomplete - theme: Incomplete - sections: Incomplete - custom_css: Incomplete - custom_scripts: Incomplete - chart_count: int - def __init__(self, title: str = "HTML 报告", theme: str = "light") -> None: ... - def add_custom_css(self, css: str) -> HtmlReportBuilder: ... - def add_custom_script(self, script: str) -> HtmlReportBuilder: ... - def add_header(self, params: dict[str, str], subtitle: str = None) -> HtmlReportBuilder: ... - def add_metrics(self, metrics: list[dict[str, Any]], title: str = "核心绩效指标") -> HtmlReportBuilder: ... - def add_chart_tab( - self, name: str, chart_html: str, icon: str = "bi-graph-up", active: bool = False - ) -> HtmlReportBuilder: ... - def add_charts_section(self, title: str = "可视化分析") -> HtmlReportBuilder: ... - def add_table( - self, df: pd.DataFrame, title: str = "数据表", max_rows: int = None, style: str = "Table Grid" - ) -> HtmlReportBuilder: ... - def add_section(self, title: str, content: str, icon: str = "bi-file-text") -> HtmlReportBuilder: ... - def add_footer(self, text: str = None) -> HtmlReportBuilder: ... - def render(self) -> str: ... - def save(self, file_path: str) -> str: ... diff --git a/czsc/utils/mark_czsc_status.py b/czsc/utils/mark_czsc_status.py deleted file mode 100644 index 7b304fcc7..000000000 --- a/czsc/utils/mark_czsc_status.py +++ /dev/null @@ -1,476 +0,0 @@ -""" -缠中说禅状态标记工具 - -该模块提供基于缠论分析的V字反转识别和状态标记功能,主要用于后验分析。 - -主要功能: -1. V字反转识别(两笔V字、四笔V字) -2. 趋势标记(基于多维度打分) -3. 震荡标记(基于多维度打分) -4. 正常标记(其他状态) - -注意:该模块使用未来信息进行分析,仅适用于研究和回测,不能用于实盘交易。 -""" - -import pandas as pd - - -def __two_bi_v(b1: pd.Series, b2: pd.Series, min_score: float = 0.7) -> dict | None: - """识别两笔构成的V字反转形态 - - 前提条件: - 1. 至少分别有一个向下笔和向上笔的趋势打分在 min_score 以上 - 2. 两笔方向相反 - - 识别逻辑: - - 正V字:b1 向下、b2 向上,且 b2 的高点大于 b1 的高点 - - 倒V字:b1 向上、b2 向下,且 b2 的低点小于 b1 的低点 - - :param b1: 第一笔数据 - :param b2: 第二笔数据 - :param min_score: 最小趋势打分阈值 - :return: V字模式信息,如果不符合条件返回None - """ - # 前提条件:检查趋势打分 - if b1["score"] < min_score or b2["score"] < min_score: - return None - - # 检查方向相反 - if b1["direction"] == b2["direction"]: - return None - - # 正V字识别:b1 向下、b2 向上,且 b2 的高点大于 b1 的高点 - if b1["direction"] == "向下" and b2["direction"] == "向上": - if b2["high"] > b1["high"]: - return { - "type": "两笔正V", - "pattern": "two_bi", - "bi_indices": [b1["bi_idx"], b2["bi_idx"]], - "start_price": b1["high"], - "end_price": b2["high"], - "bottom_price": min(b1["low"], b2["low"]), - } - - # 倒V字识别:b1 向上、b2 向下,且 b2 的低点小于 b1 的低点 - elif b1["direction"] == "向上" and b2["direction"] == "向下" and b2["low"] < b1["low"]: - return { - "type": "两笔倒V", - "pattern": "two_bi", - "bi_indices": [b1["bi_idx"], b2["bi_idx"]], - "start_price": b1["low"], - "end_price": b2["low"], - "top_price": max(b1["high"], b2["high"]), - } - - return None - - -def __four_bi_v(b1: pd.Series, b2: pd.Series, b3: pd.Series, b4: pd.Series, min_score: float = 0.7) -> dict | None: - """识别四笔构成的V字反转形态 - - 前提条件: - 1. 至少分别有一个向下笔和向上笔的趋势打分在 min_score 以上 - - 识别逻辑: - - 正V字:第1、3笔向下,第2、4笔向上,且第4笔的高点大于第1笔的高点 - * B1最低:第1笔为最低点且得分较高 - * B3最低:第3笔为最低点且第1笔高点高于第3笔高点 - - 倒V字:第1、3笔向上,第2、4笔向下,且第4笔的低点小于第1笔的低点 - * B1最高:第1笔为最高点且得分较高 - * B3最高:第3笔为最高点且第1笔低点低于第3笔低点 - - :param b1, b2, b3, b4: 四笔数据 - :param min_score: 最小趋势打分阈值 - :return: V字模式信息,如果不符合条件返回None - """ - # 前提条件:检查至少有一个向上笔和一个向下笔的趋势打分在阈值以上 - if not __validate_four_bi_scores(b1, b2, b3, b4, min_score): - return None - - # 计算极值 - min_low = min(b1["low"], b2["low"], b3["low"], b4["low"]) - max_high = max(b1["high"], b2["high"], b3["high"], b4["high"]) - - # 正V字识别 - result = __identify_four_bi_positive_v(b1, b2, b3, b4, min_low, min_score) - if result: - return result - - # 倒V字识别 - result = __identify_four_bi_negative_v(b1, b2, b3, b4, max_high, min_score) - return result - - -def __validate_four_bi_scores(b1: pd.Series, b2: pd.Series, b3: pd.Series, b4: pd.Series, min_score: float) -> bool: - """验证四笔的趋势打分是否满足条件""" - up_scores = [b["score"] for b in [b1, b2, b3, b4] if b["direction"] == "向上"] - down_scores = [b["score"] for b in [b1, b2, b3, b4] if b["direction"] == "向下"] - - if not up_scores or not down_scores: - return False - - return max(up_scores) >= min_score and max(down_scores) >= min_score - - -def __identify_four_bi_positive_v( - b1: pd.Series, b2: pd.Series, b3: pd.Series, b4: pd.Series, min_low: float, min_score: float -) -> dict | None: - """识别四笔正V字形态""" - # 检查方向条件:1、3向下,2、4向上 - down_directions = b1["direction"] == "向下" and b3["direction"] == "向下" - up_directions = b2["direction"] == "向上" and b4["direction"] == "向上" - if not (down_directions and up_directions): - return None - - if b4["high"] <= b1["high"]: - return None - - # 最低点在 b1 - if b1["low"] == min_low and b1["score"] > min_score: - return { - "type": "四笔正V-B1最低", - "pattern": "four_bi", - "bi_indices": [b1["bi_idx"], b2["bi_idx"], b3["bi_idx"], b4["bi_idx"]], - "start_price": b1["high"], - "end_price": b4["high"], - "bottom_price": b1["low"], - } - - # 最低点在 b3 - if b3["low"] == min_low and b1["high"] > b3["high"]: - return { - "type": "四笔正V-B3最低", - "pattern": "four_bi", - "bi_indices": [b1["bi_idx"], b2["bi_idx"], b3["bi_idx"], b4["bi_idx"]], - "start_price": b1["high"], - "end_price": b4["high"], - "bottom_price": b3["low"], - } - - return None - - -def __identify_four_bi_negative_v( - b1: pd.Series, b2: pd.Series, b3: pd.Series, b4: pd.Series, max_high: float, min_score: float -) -> dict | None: - """识别四笔倒V字形态""" - # 检查方向条件:1、3向上,2、4向下 - up_directions = b1["direction"] == "向上" and b3["direction"] == "向上" - down_directions = b2["direction"] == "向下" and b4["direction"] == "向下" - if not (up_directions and down_directions): - return None - - if b4["low"] >= b1["low"]: - return None - - # 最高点在 b1 - if b1["high"] == max_high and b1["score"] > min_score: - return { - "type": "四笔倒V-B1最高", - "pattern": "four_bi", - "bi_indices": [b1["bi_idx"], b2["bi_idx"], b3["bi_idx"], b4["bi_idx"]], - "start_price": b1["low"], - "end_price": b4["low"], - "top_price": b1["high"], - } - - # 最高点在 b3 - if b3["high"] == max_high and b1["low"] < b3["low"]: - return { - "type": "四笔倒V-B3最高", - "pattern": "four_bi", - "bi_indices": [b1["bi_idx"], b2["bi_idx"], b3["bi_idx"], b4["bi_idx"]], - "start_price": b1["low"], - "end_price": b4["low"], - "top_price": b3["high"], - } - - return None - - -def __find_v( - bi_stats: pd.DataFrame, min_score: float = 0.7, trend_ratio: float = 0.15, oscillation_ratio: float = 0.5 -) -> tuple[pd.DataFrame, list[dict]]: - """在笔统计数据上标记V字反转、趋势、震荡笔 - - 实现顺序: - 1. 遍历所有笔,识别两笔和四笔V字反转,在mark列标记 - 2. 对非V字反转的笔,取趋势得分最大的前trend_ratio比例,标记为趋势 - 3. 对非V字反转的笔,取趋势得分最小的前oscillation_ratio比例,标记为震荡 - - :param bi_stats: 笔统计数据,必须包含direction, score, bi_idx等列 - :param min_score: V字识别的最小趋势打分阈值 - :param trend_ratio: 标记为趋势的笔比例(默认15%) - :param oscillation_ratio: 标记为震荡的笔比例(默认50%) - :return: 带标记的笔统计DataFrame和V字模式列表 - """ - # 参数校验 - if not isinstance(bi_stats, pd.DataFrame): - raise ValueError("bi_stats 必须是 pandas DataFrame") - - required_columns = ["direction", "score", "bi_idx", "high", "low", "sdt", "edt"] - missing_cols = [col for col in required_columns if col not in bi_stats.columns] - if missing_cols: - raise ValueError(f"bi_stats 缺少必要列: {missing_cols}") - - # 添加 mark 列,默认为 normal - bi_stats = bi_stats.copy() - bi_stats["mark"] = "normal" - - v_patterns = [] # 存储找到的V字模式 - v_bi_indices = set() # 存储参与V字模式的笔索引 - - # 1. 遍历所有笔,识别V字反转 - # 两笔V字识别 - for i in range(len(bi_stats) - 1): - v_result = __two_bi_v(bi_stats.iloc[i], bi_stats.iloc[i + 1], min_score) - if v_result: - v_patterns.append(v_result) - v_bi_indices.update(v_result["bi_indices"]) - - # 四笔V字识别 - for i in range(len(bi_stats) - 3): - # 跳过已经被识别为两笔V字的笔 - if bi_stats.iloc[i]["bi_idx"] in v_bi_indices: - continue - - v_result = __four_bi_v( - bi_stats.iloc[i], bi_stats.iloc[i + 1], bi_stats.iloc[i + 2], bi_stats.iloc[i + 3], min_score - ) - if v_result: - v_patterns.append(v_result) - v_bi_indices.update(v_result["bi_indices"]) - - # 在 mark 列上标记 V 字反转 - bi_stats.loc[bi_stats["bi_idx"].isin(v_bi_indices), "mark"] = "v_reversal" - - # 2. 标记趋势和震荡笔 - non_v_mask = bi_stats["mark"] == "normal" - non_v_stats = bi_stats[non_v_mask] - - if len(non_v_stats) > 0: - # 趋势标记:取趋势得分最大的前trend_ratio比例 - trend_count = max(1, int(len(non_v_stats) * trend_ratio)) - trend_bi = non_v_stats.nlargest(trend_count, "score") - bi_stats.loc[trend_bi.index, "mark"] = "trend" - - # 震荡标记:取趋势得分最小的前oscillation_ratio比例 - oscillation_count = max(1, int(len(non_v_stats) * oscillation_ratio)) - oscillation_bi = non_v_stats.nsmallest(oscillation_count, "score") - bi_stats.loc[oscillation_bi.index, "mark"] = "oscillation" - - return bi_stats, v_patterns - - -def mark_czsc_status(df: pd.DataFrame, **kwargs) -> tuple[pd.DataFrame, pd.DataFrame]: - """【后验分析,有未来信息,不能用于实盘】标记V字反转、趋势、震荡的时间段 - - 该函数基于缠论分析,对K线数据进行状态标记,识别以下四种状态: - 1. V字反转:基于两笔或四笔形态识别的反转信号 - 2. 趋势:基于多维度打分识别的强趋势时段 - 3. 震荡:基于多维度打分识别的震荡时段 - 4. 正常:其他普通时段 - - :param df: 标准K线数据,必须包含 dt, symbol, open, close, high, low, vol, amount 列 - :param kwargs: 可选参数 - - copy: bool, 是否复制数据,默认True - - verbose: bool, 是否打印日志,默认False - - logger: 日志记录器,默认使用loguru - - min_score: float, V字识别最小打分阈值,默认0.7 - - trend_ratio: float, 趋势笔比例,默认0.15 - - oscillation_ratio: float, 震荡笔比例,默认0.5 - - freq: str, 分析周期,默认"30分钟" - - :return: tuple, (带有标记的K线数据, 笔状态统计数据) - - K线数据新增列:is_reversal, is_trend, is_oscillation, is_normal - - 笔统计数据包含:笔基本信息、趋势打分、状态标记 - """ - import loguru - - from czsc import CZSC, format_standard_kline - - # 参数处理 - copy_data = kwargs.get("copy", True) - verbose = kwargs.get("verbose", False) - logger_obj = kwargs.get("logger", loguru.logger) - min_score = kwargs.get("min_score", 0.7) - trend_ratio = kwargs.get("trend_ratio", 0.15) - oscillation_ratio = kwargs.get("oscillation_ratio", 0.5) - freq = kwargs.get("freq", "30分钟") - - # 参数校验 - if not isinstance(df, pd.DataFrame): - raise ValueError("df 必须是 pandas DataFrame") - - required_columns = ["dt", "symbol", "open", "close", "high", "low", "vol"] - missing_cols = [col for col in required_columns if col not in df.columns] - if missing_cols: - raise ValueError(f"K线数据缺少必要列: {missing_cols}") - - # 数据预处理 - if copy_data: - df = df.copy() - - all_bi_stats = [] - all_kline_data = [] - - # 按品种分组处理 - for symbol, dfg in df.groupby("symbol"): - if verbose: - logger_obj.info( - f"正在处理 {symbol} 数据,共 {len(dfg)} 根K线;时间范围:{dfg['dt'].min()} - {dfg['dt'].max()}" - ) - - # 数据排序和格式化 - dfg = dfg.sort_values("dt").copy().reset_index(drop=True) - bars = format_standard_kline(dfg, freq=freq) - - if len(bars) < 300: # 至少需要300根K线才能进行有效的缠论分析 - if verbose: - logger_obj.warning(f"{symbol} 数据不足,跳过分析") - continue - - # 缠论分析 - c = CZSC(bars, max_bi_num=len(bars)) - - if len(c.bi_list) < 4: # 至少需要4笔才能进行V字识别 - if verbose: - logger_obj.warning(f"{symbol} 笔数量不足,跳过V字识别") - # 仍然进行基本的趋势标记 - bi_stats = [] - for bi_idx, bi in enumerate(c.bi_list): - bi_stats.append( - { - "symbol": symbol, - "bi_idx": bi_idx, - "sdt": bi.sdt, - "edt": bi.edt, - "direction": bi.direction.value, - "high": bi.high, - "low": bi.low, - "power_price": abs(bi.change), - "length": bi.length, - "rsq": bi.rsq, - "power_volume": bi.power_volume, - } - ) - bi_stats = pd.DataFrame(bi_stats) - else: - # 提取笔统计信息 - bi_stats = [] - for bi_idx, bi in enumerate(c.bi_list): - bi_stats.append( - { - "symbol": symbol, - "bi_idx": bi_idx, - "sdt": bi.sdt, - "edt": bi.edt, - "direction": bi.direction.value, - "high": bi.high, - "low": bi.low, - "power_price": abs(bi.change), - "length": bi.length, - "rsq": bi.rsq, - "power_volume": bi.power_volume, - } - ) - bi_stats = pd.DataFrame(bi_stats) - - # 计算滚动排名指标 - window_size = min(100, len(bi_stats)) - min_periods = min(10, len(bi_stats) // 2) - - bi_stats["power_price_rank"] = ( - bi_stats["power_price"] - .rolling(window=window_size, min_periods=min_periods) - .rank(method="min", ascending=True, pct=True) - ) - bi_stats["rsq_rank"] = ( - bi_stats["rsq"] - .rolling(window=window_size, min_periods=min_periods) - .rank(method="min", ascending=True, pct=True) - ) - bi_stats["power_volume_rank"] = ( - bi_stats["power_volume"] - .rolling(window=window_size, min_periods=min_periods) - .rank(method="min", ascending=True, pct=True) - ) - - # 计算趋势度打分 - price_rank = bi_stats["power_price_rank"] - rsq_rank = bi_stats["rsq_rank"] - volume_rank = bi_stats["power_volume_rank"] - bi_stats["score"] = price_rank + rsq_rank + volume_rank - bi_stats["score"] = bi_stats["score"].rank(method="min", ascending=True, pct=True) - - # V字反转识别和状态标记 - bi_stats, v_patterns = __find_v( - bi_stats.dropna(subset=["score"]).reset_index(drop=True), - min_score=min_score, - trend_ratio=trend_ratio, - oscillation_ratio=oscillation_ratio, - ) - - if verbose: - mark_counts = bi_stats["mark"].value_counts().to_dict() - logger_obj.info(f"{symbol} - 笔类型统计: {mark_counts}") - - if v_patterns: - v_type_counts = pd.DataFrame(v_patterns)["type"].value_counts().to_dict() - logger_obj.info(f"{symbol} - V字模式分类:{v_type_counts}") - - # 初始化状态标记列 - dfg["is_reversal"] = 0 - dfg["is_trend"] = 0 - dfg["is_oscillation"] = 0 - dfg["is_normal"] = 0 - - # 根据笔状态标记对应的时间段 - if "mark" in bi_stats.columns: - # V字反转标记 - v_reversal_bis = bi_stats[bi_stats["mark"] == "v_reversal"] - for _, row in v_reversal_bis.iterrows(): - dfg.loc[(dfg["dt"] >= row["sdt"]) & (dfg["dt"] <= row["edt"]), "is_reversal"] = 1 - - # 趋势标记 - trend_bis = bi_stats[bi_stats["mark"] == "trend"] - for _, row in trend_bis.iterrows(): - dfg.loc[(dfg["dt"] >= row["sdt"]) & (dfg["dt"] <= row["edt"]), "is_trend"] = 1 - - # 震荡标记 - oscillation_bis = bi_stats[bi_stats["mark"] == "oscillation"] - for _, row in oscillation_bis.iterrows(): - dfg.loc[(dfg["dt"] >= row["sdt"]) & (dfg["dt"] <= row["edt"]), "is_oscillation"] = 1 - - # 正常标记 - normal_bis = bi_stats[bi_stats["mark"] == "normal"] - for _, row in normal_bis.iterrows(): - dfg.loc[(dfg["dt"] >= row["sdt"]) & (dfg["dt"] <= row["edt"]), "is_normal"] = 1 - - all_bi_stats.append(bi_stats) - all_kline_data.append(dfg) - - # 合并所有品种的数据 - if not all_kline_data: - raise ValueError("没有足够的数据进行分析") - - dfr = pd.concat(all_kline_data, ignore_index=True) - bi_stats_all = pd.concat(all_bi_stats, ignore_index=True) - - # 输出统计信息 - if verbose: - total_rows = len(dfr) - reversal_coverage = dfr["is_reversal"].sum() / total_rows * 100 - trend_coverage = dfr["is_trend"].sum() / total_rows * 100 - oscillation_coverage = dfr["is_oscillation"].sum() / total_rows * 100 - normal_coverage = dfr["is_normal"].sum() / total_rows * 100 - - logger_obj.info( - f"状态标记覆盖率统计:\n" - f" V字反转:{reversal_coverage:.2f}%\n" - f" 趋势时间:{trend_coverage:.2f}%\n" - f" 震荡时间:{oscillation_coverage:.2f}%\n" - f" 正常时间:{normal_coverage:.2f}%" - ) - - return dfr, bi_stats_all diff --git a/czsc/utils/mark_czsc_status.pyi b/czsc/utils/mark_czsc_status.pyi deleted file mode 100644 index b1d58c82f..000000000 --- a/czsc/utils/mark_czsc_status.pyi +++ /dev/null @@ -1,3 +0,0 @@ -import pandas as pd - -def mark_czsc_status(df: pd.DataFrame, **kwargs) -> tuple[pd.DataFrame, pd.DataFrame]: ... diff --git a/czsc/utils/pdf_report_builder.pyi b/czsc/utils/pdf_report_builder.pyi deleted file mode 100644 index 87db16a93..000000000 --- a/czsc/utils/pdf_report_builder.pyi +++ /dev/null @@ -1,56 +0,0 @@ -from typing import Any - -import pandas as pd -from _typeshed import Incomplete -from reportlab.lib.enums import TA_RIGHT as TA_RIGHT -from reportlab.lib.units import mm as mm -from reportlab.platypus import KeepTogether as KeepTogether - -COLOR_PRIMARY: Incomplete -COLOR_SUCCESS: Incomplete -COLOR_DANGER: Incomplete -COLOR_WARNING: Incomplete -COLOR_SECONDARY: Incomplete -COLOR_LIGHT: Incomplete -COLOR_DARK: Incomplete -COLOR_BORDER: Incomplete -COLOR_SUCCESS_BG: Incomplete -COLOR_DANGER_BG: Incomplete -COLOR_PRIMARY_BG: Incomplete -PAGE_SIZE: Incomplete -PAGE_WIDTH: Incomplete -PAGE_HEIGHT: Incomplete -MARGIN_TOP: Incomplete -MARGIN_BOTTOM: Incomplete -MARGIN_LEFT: Incomplete -MARGIN_RIGHT: Incomplete -CONTENT_WIDTH: Incomplete -CONTENT_HEIGHT: Incomplete -CM_TO_PX: Incomplete -CARD_SPACING: int -FONT_NAME: str -FONT_NAME_BOLD: str - -class PdfReportBuilder: - title: Incomplete - author: Incomplete - created_at: Incomplete - def __init__(self, title: str = "PDF 报告", author: str = "CZSC") -> None: ... - def add_header(self, params: dict[str, str], subtitle: str = None) -> PdfReportBuilder: ... - def add_toc(self, title: str = "目 录") -> PdfReportBuilder: ... - def insert_toc_after_header(self, title: str = "目 录", add_page_break: bool = True) -> PdfReportBuilder: ... - def add_page_break(self) -> PdfReportBuilder: ... - def add_metrics(self, metrics: list[dict[str, Any]], title: str = "核心绩效指标") -> PdfReportBuilder: ... - def add_chart( - self, - fig_or_image, - title: str = "图表", - height: float = None, - fit_page: bool = False, - aspect_ratio: float = 0.55, - ) -> PdfReportBuilder: ... - def add_table(self, df: pd.DataFrame, title: str = "数据表", max_rows: int = None) -> PdfReportBuilder: ... - def add_section(self, title: str, content: str) -> PdfReportBuilder: ... - def add_footer(self, text: str = None) -> PdfReportBuilder: ... - def render(self) -> bytes: ... - def save(self, file_path: str) -> str: ... diff --git a/czsc/utils/plotting/kline.py b/czsc/utils/plotting/kline.py index fb1319275..3059f3685 100644 --- a/czsc/utils/plotting/kline.py +++ b/czsc/utils/plotting/kline.py @@ -1,8 +1,18 @@ """ -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2023/2/26 15:03 -describe: 使用 Plotly 构建绘图模块 +基于 Plotly 的 K 线绘图模块 + +本模块提供两个核心能力: + +1. :class:`KlineChart`:通用 K 线图工具类,封装了 Plotly 的 ``make_subplots``、 + ``Candlestick``、``Bar``、``Scatter`` 等接口,支持快速叠加均线、成交量、MACD、 + 自定义指标、标记点等; +2. :func:`plot_czsc_chart`:针对 ``CZSC`` 对象的便捷绘图入口,自动绘制 K 线、 + 均线、成交量、MACD,并叠加分型与笔; +3. :func:`plot_nx_graph`:用 Plotly 渲染 ``networkx`` 图(节点 + 带权边)。 + +作者: zengbin93 +邮箱: zeng_bin8888@163.com +创建时间: 2023/2/26 15:03 """ import os @@ -13,42 +23,49 @@ from plotly import graph_objects as go if TYPE_CHECKING: - from czsc.core import CZSC + from czsc import CZSC class KlineChart: - """K线绘图工具类 + """K 线绘图工具类 - plotly 参数详解: https://www.jianshu.com/p/4f4daf47cc85 + 封装 Plotly 的多子图布局,便于在同一张图上叠加 K 线、均线、成交量、MACD 等。 + Plotly 参数详解可参考:https://www.jianshu.com/p/4f4daf47cc85 """ def __init__(self, n_rows=3, **kwargs): - """K线绘图工具类 + """初始化 K 线绘图工具类 初始化执行逻辑: - - 接收一个可选参数 n_rows,默认值为 3。这个参数表示图表中的子图数量。 - - 接收一个可变参数列表 **kwargs,可以传递其他配置参数。 - - 如果没有提供 row_heights 参数,则根据 n_rows 设置默认的行高度。 - - 定义了一些颜色变量:color_red 和 color_green。 - - 使用 make_subplots 函数创建一个具有 n_rows 行和 1 列的子图布局,并设置一些共享属性和间距。 - - 使用 fig.update_yaxes 和 fig.update_xaxes 更新 Y 轴和 X 轴的属性,如显示网格、自动调整范围等。 - - 使用 fig.update_layout 更新整个图形的布局,包括标题、边距、图例位置和样式、背景模板等。 - - 将 fig 对象保存在 self.fig 属性中。 - - :param n_rows: 子图数量 - :param kwargs: + - 接收一个可选参数 ``n_rows``,默认值为 3,表示图表的子图数量; + - 接收可变参数 ``**kwargs``,用于传递其他配置(如 ``row_heights``、``title``、``height`` 等); + - 若未提供 ``row_heights``,则根据 ``n_rows`` 选择内置的高度比例; + - 定义两个常用颜色:``color_red``(上涨)和 ``color_green``(下跌); + - 调用 ``make_subplots`` 创建 ``n_rows × 1`` 的子图布局,并设置共享 X 轴等属性; + - 通过 ``update_yaxes`` / ``update_xaxes`` 配置 Y/X 轴的网格、自动 margin、Spike 等; + - 通过 ``update_layout`` 设置标题、外边距、图例、模板(plotly_dark)、悬停样式等; + - 把构造好的 ``go.Figure`` 赋给 ``self.fig`` 供后续操作。 + + :param n_rows: int,子图数量,仅支持 3 / 4 / 5 + :param kwargs: 其他参数 + - row_heights: list[float],每个子图的高度比例 + - y_fixed_range: bool,Y 轴是否固定范围 + - title: str,图表标题 + - height: int,图表高度(像素) """ from plotly.subplots import make_subplots self.n_rows = n_rows row_heights = kwargs.get("row_heights") if not row_heights: + # 内置的高度配置:3/4/5 行各一种默认比例 heights_map = {3: [0.6, 0.2, 0.2], 4: [0.55, 0.15, 0.15, 0.15], 5: [0.4, 0.15, 0.15, 0.15, 0.15]} assert self.n_rows in heights_map, "使用内置高度配置,n_rows 只能是 3, 4, 5" row_heights = heights_map[self.n_rows] + # 上涨用红色、下跌用绿色(A 股配色习惯) self.color_red = "rgba(249,41,62,0.7)" self.color_green = "rgba(0,170,59,0.7)" fig = make_subplots( @@ -60,6 +77,7 @@ def __init__(self, n_rows=3, **kwargs): vertical_spacing=0, ) + # 统一的 Y 轴样式:显示网格、Spike 跨子图等 fig = fig.update_yaxes( showgrid=True, zeroline=False, @@ -71,6 +89,7 @@ def __init__(self, n_rows=3, **kwargs): showline=False, spikedash="dot", ) + # 统一的 X 轴样式:使用 category 类型,避免 plotly 自动跳过非交易时段 fig = fig.update_xaxes( type="category", rangeslider_visible=False, @@ -84,11 +103,13 @@ def __init__(self, n_rows=3, **kwargs): spikedash="dot", ) - # https://plotly.com/python/reference/layout/ + # 整体布局:暗色主题、悬停联动、紧凑边距 + # 参考:https://plotly.com/python/reference/layout/ fig.update_layout( title={"text": kwargs.get("title", ""), "yanchor": "top", "y": 0.95}, - margin=go.layout.Margin(l=0, r=0, b=0, t=0), # left margin # right margin # bottom margin # top margin - # https://plotly.com/python/reference/layout/#layout-legend + margin=go.layout.Margin(l=0, r=0, b=0, t=0), # 上下左右四个方向的外边距 + # 图例配置:水平、靠近顶部、透明背景 + # 参考:https://plotly.com/python/reference/layout/#layout-legend legend={ "orientation": "h", "yanchor": "top", @@ -99,7 +120,7 @@ def __init__(self, n_rows=3, **kwargs): }, template="plotly_dark", hovermode="x unified", - hoverlabel={"bgcolor": "rgba(255,255,255,0.1)", "font": {"size": 20}}, # 透明,更容易看清后面k线 + hoverlabel={"bgcolor": "rgba(255,255,255,0.1)", "font": {"size": 20}}, # 半透明背景,方便看清后面的 K 线 dragmode="pan", legend_title_font_color="red", height=kwargs.get("height", 600), @@ -108,23 +129,22 @@ def __init__(self, n_rows=3, **kwargs): self.fig = fig def add_kline(self, kline: pd.DataFrame, name: str = "K线", **kwargs): - """绘制K线 + """绘制 K 线 函数执行逻辑: - 1. 检查 kline 数据框是否包含 'text' 列。如果没有,则添加一个空字符串列。 - 2. 使用 go.Candlestick 创建一个K线图,并传入以下参数: + 1. 检查 ``kline`` DataFrame 是否包含 ``'text'`` 列;如果没有则补一个空字符串列; + 2. 用 ``go.Candlestick`` 创建蜡烛图,参数包括: - x: 日期时间数据 - - open, high, low, close: 开盘价、最高价、最低价和收盘价 - - text: 显示在每个 K 线上的文本标签 + - open / high / low / close: 开盘价、最高价、最低价、收盘价 + - text: 显示在每根 K 线上的文本标签 - name: 图例名称 - showlegend: 是否显示图例 - - increasing_line_color 和 decreasing_line_color: 上涨时的颜色和下跌时的颜色 - - increasing_fillcolor 和 decreasing_fillcolor: 上涨时填充颜色和下跌时填充颜色 - - **kwargs: 可以传递其他自定义参数给 Candlestick 函数。 - - 3. 将创建的烛台图对象添加到 self.fig 中的第一个子图(row=1, col=1)。 - 4. 使用 fig.update_traces 更新所有 traces 的 xaxis 属性为 "x1"。 + - increasing_line_color / decreasing_line_color: 上涨 / 下跌时的描边颜色 + - increasing_fillcolor / decreasing_fillcolor: 上涨 / 下跌时的填充颜色 + - **kwargs: 其他自定义参数透传给 Candlestick; + 3. 把蜡烛图加入第 1 个子图(``row=1, col=1``); + 4. 通过 ``update_traces(xaxis="x1")`` 让所有 trace 共享同一根 X 轴。 """ if "text" not in kline.columns: kline["text"] = "" @@ -152,38 +172,37 @@ def add_vol(self, kline: pd.DataFrame, row=2, **kwargs): 函数执行逻辑: - 1. 首先,复制输入的 kline 数据框到 df。 - 2. 使用 np.where 函数根据收盘价(df['close'])和开盘价(df['open'])之间的关系为 df 创建一个新列 'vol_color'。 - 如果收盘价大于开盘价,则使用红色(self.color_red),否则使用绿色(self.color_green)。 - 3. 调用 add_bar_indicator 方法绘制成交量图。传递以下参数: - - x: 日期时间数据 - - y: 成交量数据 - - color: 根据 'vol_color' 列的颜色 - - name: 图例名称 - - row: 指定要添加指标的子图行数,默认值为 2 - - show_legend: 是否显示图例,默认值为 False + 1. 复制输入的 ``kline`` 到本地变量 ``df``; + 2. 用 ``np.where`` 根据收盘价与开盘价的关系给每根柱体上色: + 收盘 > 开盘 用红色(``self.color_red``),否则用绿色(``self.color_green``); + 3. 调用 :meth:`add_bar_indicator` 完成成交量绘制;参数包括: + - x: 日期时间 + - y: 成交量 + - color: 上一步生成的颜色数组 + - name: ``"成交量"`` + - row: 子图行号,默认 2 + - show_legend: 默认 False """ df = kline.copy() df["vol_color"] = np.where(df["close"] > df["open"], self.color_red, self.color_green) self.add_bar_indicator(df["dt"], df["vol"], color=df["vol_color"], name="成交量", row=row, show_legend=False) def add_sma(self, kline: pd.DataFrame, row=1, ma_seq=(5, 10, 20), visible=False, **kwargs): - """绘制均线图 + """绘制均线(SMA) 函数执行逻辑: - 1. 复制输入的 kline 数据框到 df。 - 2. 获取自定义参数 line_width,默认值为 0.6。 - 3. 遍历 ma_seq 中的所有均线周期: - - 对每个周期使用 pandas rolling 方法计算收盘价的移动平均线。 - - 调用 add_scatter_indicator 方法将移动平均线数据绘制为折线图。传递以下参数: - - x: 日期时间数据 - - y: 移动平均线数据 - - name: 图例名称,格式为 "MA{ma}",其中 {ma} 是当前的均线周期。 - - row: 指定要添加指标的子图行数,默认值为 1 - - line_width: 线宽,默认值为 0.6 - - visible: 是否可见,默认值为 False - - show_legend: 是否显示图例,默认值为 True + 1. 复制输入 ``kline`` 到本地变量 ``df``; + 2. 读取 ``line_width`` 参数(默认 0.6); + 3. 遍历 ``ma_seq`` 中的均线周期,对收盘价做 ``rolling(window).mean()``, + 调用 :meth:`add_scatter_indicator` 绘制为折线图: + - x: 日期时间 + - y: 移动平均序列 + - name: ``f"MA{ma}"`` + - row: 子图行号,默认 1 + - line_width: 线宽 + - visible: 是否默认可见 + - show_legend: 默认 True """ df = kline.copy() line_width = kwargs.get("line_width", 0.6) @@ -199,29 +218,18 @@ def add_sma(self, kline: pd.DataFrame, row=1, ma_seq=(5, 10, 20), visible=False, ) def add_macd(self, kline: pd.DataFrame, row=3, **kwargs): - """绘制MACD图 + """绘制 MACD 函数执行逻辑: - 1. 首先,复制输入的 kline 数据框到 df。 - 2. 获取自定义参数 fastperiod、slowperiod 和 signalperiod。这些参数分别对应于计算 MACD 时使用的快周期、慢周期和信号周期,默认值分别为 12、26 和 9。 - 3. 使用 talib 库的 MACD 函数计算 MACD 值(diff, dea, macd)。 - 4. 创建一个名为 macd_colors 的 numpy 数组,根据 macd 值大于零的情况设置颜色:大于零使用红色(self.color_red),否则使用绿色(self.color_green)。 - 5. 调用 add_scatter_indicator 方法将 diff 和 dea 绘制为折线图。传递以下参数: - - x: 日期时间数据 - - y: diff 或 dea 数据 - - name: 图例名称,分别为 "DIFF" 和 "DEA" - - row: 指定要添加指标的子图行数,默认值为 3 - - line_color: 线的颜色,分别为 'white' 和 'yellow' - - show_legend: 是否显示图例,默认值为 False - - line_width: 线宽,默认值为 0.6 - 6. 调用 add_bar_indicator 方法将 macd 绘制为柱状图。传递以下参数: - - x: 日期时间数据 - - y: macd 数据 - - name: 图例名称,为 "MACD" - - row: 指定要添加指标的子图行数,默认值为 3 - - color: 根据 macd_colors 设置颜色 - - show_legend: 是否显示图例,默认值为 False + 1. 复制输入 ``kline`` 到本地变量 ``df``; + 2. 读取 ``fastperiod`` / ``slowperiod`` / ``signalperiod`` / ``line_width`` 参数, + 默认值分别为 12 / 26 / 9 / 0.6; + 3. 若 ``df`` 已包含 ``DIFF / DEA / MACD`` 列则直接复用,否则调用 + ``czsc.utils.ta.MACD`` 计算; + 4. 根据 MACD 是否大于 0 给柱体上色(大于 0 红色,否则绿色); + 5. 用 :meth:`add_scatter_indicator` 把 ``DIFF`` / ``DEA`` 绘制为折线, + 用 :meth:`add_bar_indicator` 把 ``MACD`` 绘制为柱体。 """ df = kline.copy() fastperiod = kwargs.get("fastperiod", 12) @@ -248,25 +256,26 @@ def add_macd(self, kline: pd.DataFrame, row=3, **kwargs): def add_indicator( self, dt, scatters: list = None, scatter_names: list = None, bar=None, bar_name="", row=4, **kwargs ): - """绘制曲线叠加bar型指标 - - 1. 获取自定义参数 line_width,默认值为 0.6。 - 2. 如果 scatters(列表)不为空,则遍历 scatters 中的所有散点数据: - - 对于每个散点数据,调用 add_scatter_indicator 方法将其绘制为折线图。传递以下参数: - - x: 日期时间数据 - - y: 散点数据 - - name: 图例名称,来自 scatter_names 列表 - - row: 指定要添加指标的子图行数,默认值为 4 - - show_legend: 是否显示图例,默认值为 False - - line_width: 线宽,默认值为 0.6 - 3. 如果 bar 不为空,则使用 np.where 函数根据 bar 值大于零的情况设置颜色:大于零使用红色(self.color_red),否则使用绿色(self.color_green)。 - 4. 调用 add_bar_indicator 方法将 bar 绘制为柱状图。传递以下参数: - - x: 日期时间数据 + """同时绘制多条曲线 + 一组 bar 型指标 + + 函数执行逻辑: + + 1. 读取 ``line_width`` 参数(默认 0.6); + 2. 若 ``scatters`` 不为空,则遍历每条散点序列调用 :meth:`add_scatter_indicator`; + - x: 日期时间 + - y: 散点数据 + - name: 来自 ``scatter_names`` + - row: 子图行号,默认 4 + - show_legend: 默认 False + - line_width: 线宽 + 3. 如 ``bar`` 不为空,则按"大于零红色 / 否则绿色"给每根柱体上色; + 4. 调用 :meth:`add_bar_indicator` 绘制柱状图: + - x: 日期时间 - y: bar 数据 - - name: 图例名称,为传入的 bar_name 参数 - - row: 指定要添加指标的子图行数,默认值为 4 - - color: 根据上一步计算的颜色设置 - - show_legend: 是否显示图例,默认值为 False + - name: ``bar_name`` + - row: 子图行号,默认 4 + - color: 计算好的颜色数组 + - show_legend: 默认 False """ line_width = kwargs.get("line_width", 0.6) for i, scatter in enumerate(scatters): @@ -279,34 +288,22 @@ def add_indicator( self.add_bar_indicator(dt, bar, name=bar_name, row=row, color=bar_colors, show_legend=False) def add_marker_indicator(self, x, y, name: str, row: int, text=None, **kwargs): - """绘制标记类指标 + """绘制标记类指标(仅 marker,无连线) 函数执行逻辑: - 1. 获取自定义参数 line_color、line_width、hover_template、show_legend 和 visible。 - 这些参数分别对应于折线颜色、宽度、鼠标悬停时显示的模板、是否显示图例和是否可见。 - 2. 使用给定的 x、y 数据创建一个 go.Scatter 对象(散点图),并传入以下参数: - - x: 指标的x轴数据 - - y: 指标的y轴数据 - - name: 指标名称 - - text: 文本说明 - - line_width: 线宽 - - line_color: 线颜色 - - hovertemplate: 鼠标悬停时显示的模板 - - showlegend: 是否显示图例 - - visible: 是否可见 - - opacity: 透明度 - - mode: 绘制模式,为 'markers' 表示只绘制标记 - - marker: 标记的样式,包括大小、颜色和符号 - 3. 调用 self.fig.add_trace 方法将创建的 go.Scatter 对象添加到指定子图中,并更新所有 traces 的 X 轴属性为 "x1"。 - - :param x: 指标的x轴 - :param y: 指标的y轴 - :param name: 指标名称 - :param row: 放入第几个子图 + 1. 从 ``kwargs`` 读取 ``line_color``、``line_width``、``hover_template``、 + ``show_legend``、``visible``、``color``、``tag`` 等参数,分别对应: + 折线颜色、宽度、悬停模板、图例可见性、整体可见性、标记颜色、标记符号; + 2. 用 ``go.Scatter`` 创建一个 ``mode='markers'`` 的散点对象; + 3. 通过 ``self.fig.add_trace`` 加入指定子图,并统一 X 轴为 ``"x1"``。 + + :param x: 指标的 X 轴 + :param y: 指标的 Y 轴 + :param name: str,指标名称 + :param row: int,放入第几个子图 :param text: 文本说明 - :param kwargs: - :return: + :param kwargs: 其他自定义参数 """ line_color = kwargs.get("line_color") line_width = kwargs.get("line_width") @@ -334,32 +331,24 @@ def add_marker_indicator(self, x, y, name: str, row: int, text=None, **kwargs): self.fig.update_traces(xaxis="x1") def add_scatter_indicator(self, x, y, name: str, row: int, text=None, **kwargs): - """绘制线性/离散指标 + """绘制线性 / 离散指标 - 绘图API文档:https://plotly.com/python-api-reference/generated/plotly.graph_objects.Scatter.html + 参考 Plotly 的 Scatter 文档: + https://plotly.com/python-api-reference/generated/plotly.graph_objects.Scatter.html 函数执行逻辑: - 1. 获取自定义参数 mode、hover_template、show_legend、opacity 和 visible。这些参数分别对应于绘图模式、鼠标悬停时显示的模板、是否显示图例、透明度和是否可见。 - 2. 使用给定的 x、y 数据创建一个 go.Scatter 对象(散点图),并传入以下参数: - - x: 指标的x轴数据 - - y: 指标的y轴数据 - - name: 指标名称 - - text: 文本说明 - - mode: 绘制模式,默认为 'text+lines',表示同时绘制文本和线条 - - hovertemplate: 鼠标悬停时显示的模板 - - showlegend: 是否显示图例 - - visible: 是否可见 - - opacity: 透明度 - 3. 调用 self.fig.add_trace 方法将创建的 go.Scatter 对象添加到指定子图中,并更新所有 traces 的 X 轴属性为 "x1"。 - - :param x: 指标的x轴 - :param y: 指标的y轴 - :param name: 指标名称 - :param row: 放入第几个子图 + 1. 从 ``kwargs`` 中弹出 ``mode``、``hover_template``、``show_legend``、 + ``opacity``、``visible`` 等参数,剩余 kwargs 直接透传给 ``go.Scatter``; + 2. 创建 ``go.Scatter`` 对象,默认 ``mode='text+lines'``; + 3. 把 trace 加入指定子图,并统一 X 轴为 ``"x1"``。 + + :param x: 指标的 X 轴 + :param y: 指标的 Y 轴 + :param name: str,指标名称 + :param row: int,放入第几个子图 :param text: 文本说明 - :param kwargs: - :return: + :param kwargs: 其他自定义参数 """ mode = kwargs.pop("mode", "text+lines") hover_template = kwargs.pop("hover_template", "%{y:.3f}") @@ -385,32 +374,23 @@ def add_scatter_indicator(self, x, y, name: str, row: int, text=None, **kwargs): def add_bar_indicator(self, x, y, name: str, row: int, color=None, **kwargs): """绘制条形图指标 - 绘图API文档:https://plotly.com/python-api-reference/generated/plotly.graph_objects.Bar.html + 参考 Plotly 的 Bar 文档: + https://plotly.com/python-api-reference/generated/plotly.graph_objects.Bar.html 函数执行逻辑: - 1. 获取自定义参数 hover_template、show_legend、visible 和 base。这些参数分别对应于鼠标悬停时显示的模板、是否显示图例、是否可见和基线(默认为 True)。 - 2. 如果 color 参数为空,则使用 self.color_red 作为颜色。 - 3. 使用给定的 x、y 数据创建一个 go.Bar 对象(条形图),并传入以下参数: - - x: 指标的x轴数据 - - y: 指标的y轴数据 - - marker_line_color: 条形边框的颜色 - - marker_color: 条形填充的颜色 - - name: 指标名称 - - showlegend: 是否显示图例 - - hovertemplate: 鼠标悬停时显示的模板 - - visible: 是否可见 - - base: 基线,默认为 True - 4. 调用 self.fig.add_trace 方法将创建的 go.Bar 对象添加到指定子图中,并更新所有 traces 的 X 轴属性为 "x1"。 - - :param x: 指标的x轴 - :param y: 指标的y轴 - :param name: 指标名称 - :param row: 放入第几个子图 - :param color: 指标的颜色,可以是单个颜色,也可以是一个列表,列表长度和y的长度一致,指示每个y的颜色 - 比如:color = 'rgba(249,41,62,0.7)' 或者 color = ['rgba(249,41,62,0.7)', 'rgba(0,170,59,0.7)'] - :param kwargs: - :return: + 1. 从 ``kwargs`` 中弹出 ``hover_template``、``show_legend``、``visible``、``base`` 等参数; + 2. 若 ``color`` 为 None,则使用 ``self.color_red`` 作为默认颜色; + 3. 创建 ``go.Bar`` 对象(marker 描边/填充颜色一致); + 4. 把 trace 加入指定子图,并统一 X 轴为 ``"x1"``。 + + :param x: 指标的 X 轴 + :param y: 指标的 Y 轴 + :param name: str,指标名称 + :param row: int,放入第几个子图 + :param color: str | list[str],单色或与 y 等长的颜色序列; + 例如 ``'rgba(249,41,62,0.7)'`` 或 ``['rgba(249,41,62,0.7)', 'rgba(0,170,59,0.7)']`` + :param kwargs: 其他自定义参数 """ hover_template = kwargs.pop("hover_template", "%{y:.3f}") show_legend = kwargs.pop("show_legend", True) @@ -435,7 +415,11 @@ def add_bar_indicator(self, x, y, name: str, row: int, color=None, **kwargs): self.fig.update_traces(xaxis="x1") def open_in_browser(self, file_name: str = None, **kwargs): - """在浏览器中打开""" + """把图表写入 HTML 并在系统默认浏览器中打开 + + :param file_name: str,输出文件路径;为 None 时写入 ``home_path`` 下的 ``kline_chart.html`` + :param kwargs: 透传给 ``fig.update_layout`` + """ import webbrowser if not file_name: @@ -450,18 +434,25 @@ def open_in_browser(self, file_name: str = None, **kwargs): def show(self, **kwargs): """显示图表 - 支持所有 plotly layout 参数,详见:https://plotly.com/python/reference/layout/ + 支持传入任意 plotly layout 参数。 + 参考:https://plotly.com/python/reference/layout/ """ self.fig.update_layout(**kwargs) self.fig.show() def plot_nx_graph(g, **kwargs) -> go.Figure: - """使用 Plotly 绘制 nx.Graph 的图形 + """使用 Plotly 绘制 ``nx.Graph`` 的图形 + + 采用 ``spring_layout`` 自动布局节点,边宽与节点大小可通过 kwargs 控制;同时把 + 每条边的权重作为文字标签放在边的中点,正负权重用红绿区分。 :param g: nx.Graph 对象 - :param kwargs: - :return: go.Figure 对象 + :param kwargs: 其他参数 + - title: str,图表标题,默认 ``"Network graph made with Python"`` + - edge_width: float,边宽,默认 1.5 + - node_marker_size: float,节点大小,默认 10 + :return: plotly.graph_objs.Figure """ import networkx as nx @@ -469,10 +460,10 @@ def plot_nx_graph(g, **kwargs) -> go.Figure: edge_width = kwargs.get("edge_width", 1.5) node_marker_size = kwargs.get("node_marker_size", 10) - # 使用 spring_layout 为图分配位置 + # 通过 spring_layout 给每个节点分配二维坐标 pos = nx.spring_layout(g) - # 准备绘图数据 + # 准备绘图数据:边起止点 + 权重 edge_x = [] edge_y = [] edge_weights = [] @@ -491,7 +482,7 @@ def plot_nx_graph(g, **kwargs) -> go.Figure: node_y.append(pos[node][1]) node_labels.append(node) - # 创建边的散点图 + # 边:用线条 trace 表示 edge_trace = go.Scatter( x=edge_x, y=edge_y, @@ -500,13 +491,13 @@ def plot_nx_graph(g, **kwargs) -> go.Figure: mode="lines", ) - # 创建节点的散点图 + # 节点:用散点 trace 表示 node_trace = go.Scatter( x=node_x, y=node_y, mode="markers", hoverinfo="text", - text=node_labels, # 添加节点标签 + text=node_labels, # 节点标签 marker={ "showscale": False, "color": "skyblue", @@ -515,7 +506,7 @@ def plot_nx_graph(g, **kwargs) -> go.Figure: }, ) - # 计算边的中点位置并添加注释 + # 计算每条边中点位置,作为权重文字的注释 edge_annotations = [] for edge in g.edges(): x0, y0 = pos[edge[0]] @@ -533,7 +524,7 @@ def plot_nx_graph(g, **kwargs) -> go.Figure: } ) - # 创建图表 + # 组装最终 figure fig = go.Figure( data=[edge_trace, node_trace], layout=go.Layout( @@ -544,7 +535,7 @@ def plot_nx_graph(g, **kwargs) -> go.Figure: margin={"b": 20, "l": 5, "r": 5, "t": 40}, xaxis={"showgrid": False, "zeroline": False, "showticklabels": False}, yaxis={"showgrid": False, "zeroline": False, "showticklabels": False}, - annotations=edge_annotations, # 添加边的注释 + annotations=edge_annotations, # 边的权重文字注释 ), ) @@ -552,12 +543,16 @@ def plot_nx_graph(g, **kwargs) -> go.Figure: def plot_czsc_chart(czsc_obj: "CZSC", **kwargs) -> KlineChart: - """使用 plotly 绘制 CZSC 对象 + """使用 plotly 绘制 ``CZSC`` 对象 + + 自动绘制 K 线、均线(默认 5/10/21/34/55/89/144 多周期)、成交量、MACD, + 并叠加分型与笔。 :param czsc_obj: CZSC 对象 - :param kwargs: - - height: 图表高度,默认 800 - :return: KlineChart 对象 + :param kwargs: 其他参数 + - height: int,图表高度,默认 600 + - ma_system: tuple,均线周期序列;首条默认可见,其余默认隐藏 + :return: KlineChart 对象(其内部 ``fig`` 即 plotly Figure) """ height = kwargs.get("height", 600) ma_system = kwargs.get("ma_system", (5, 10, 21, 34, 55, 89, 144)) @@ -573,12 +568,13 @@ def plot_czsc_chart(czsc_obj: "CZSC", **kwargs) -> KlineChart: chart.add_macd(df, row=3) if len(bi_list) > 0: + # 笔的端点:首端取 fx_a,尾端额外补一个 fx_b bi1 = [{"dt": x.fx_a.dt, "bi": x.fx_a.fx, "text": x.fx_a.mark.value.replace("分型", "")} for x in bi_list] bi2 = [{"dt": bi_list[-1].fx_b.dt, "bi": bi_list[-1].fx_b.fx, "text": bi_list[-1].fx_b.mark.value[0]}] bi = pd.DataFrame(bi1 + bi2) fx = pd.DataFrame([{"dt": x.dt, "fx": x.fx} for x in czsc_obj.fx_list]) - # 分型用虚线表示 + # 分型用虚线表示,笔用实线 chart.add_scatter_indicator(fx["dt"], fx["fx"], name="分型", row=1, line_width=1.8, line_dash="dash") chart.add_scatter_indicator(bi["dt"], bi["bi"], name="笔", text=bi["text"], row=1, line_width=1.8) return chart diff --git a/czsc/utils/sig.py b/czsc/utils/sig.py index 69b791e69..287ba0657 100644 --- a/czsc/utils/sig.py +++ b/czsc/utils/sig.py @@ -1,8 +1,25 @@ """ -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2022/10/27 23:23 -describe: 用于信号计算函数的各种辅助工具函数 +信号计算辅助工具集 + +本模块为各类信号函数提供通用的辅助工具,主要包括: + +1. :func:`create_single_signal`:构造一个 Signal 对象(``key``-``value`` 的标准化表示), + 并以 ``OrderedDict`` 形式返回; +2. :func:`is_symmetry_zs`:判断"对称中枢"——中枢内所有笔力度的标准差与均值之比是否 + 小于阈值; +3. :func:`check_cross_info` / :func:`fast_slow_cross`:计算两个数列(如快慢均线)的 + 金叉 / 死叉以及附属统计信息; +4. :func:`check_gap_info`:扫描 K 线序列中的向上 / 向下缺口及其是否被回补; +5. :func:`same_dir_counts` / :func:`count_last_same`:计算尾部连续同方向 / 同值元素数量; +6. :func:`get_sub_elements`:从列表中按"倒数第 di 个元素往前取 n 个"的方式截取; +7. :func:`is_bis_down` / :func:`is_bis_up`:判断连续笔序列是否方向一致; +8. :func:`get_zs_seq`:从连续笔序列推导中枢序列; +9. :func:`cross_zero_axis` / :func:`cal_cross_num` / :func:`down_cross_count`:零轴交叉 + 与下穿次数等辅助统计。 + +作者: zengbin93 +邮箱: zeng_bin8888@163.com +创建时间: 2022/10/27 23:23 """ from collections import OrderedDict @@ -10,11 +27,22 @@ import numpy as np -from czsc.core import BI, ZS, Direction, RawBar +from czsc import BI, ZS, Direction, RawBar def create_single_signal(**kwargs) -> OrderedDict: - """创建单个信号""" + """构造单个标准信号对象 + + 通过 ``rs_czsc.Signal`` 把 ``k1/k2/k3/v1/v2/v3/score`` 标准字段拼装成 + ``key="k1_k2_k3"`` / ``value="v1_v2_v3_score"`` 的字符串形式,并以 + ``OrderedDict`` 返回,便于和其他信号合并。 + + :param kwargs: 其他关键字参数 + - k1/k2/k3: 信号键三段,缺省值均为 ``"任意"`` + - v1/v2/v3: 信号值三段,缺省值均为 ``"任意"`` + - score: int,信号置信度评分,默认 0 + :return: OrderedDict,``{Signal.key: Signal.value}`` + """ from rs_czsc import Signal s = OrderedDict() @@ -22,6 +50,7 @@ def create_single_signal(**kwargs) -> OrderedDict: v1, v2, v3 = kwargs.get("v1", "任意"), kwargs.get("v2", "任意"), kwargs.get("v3", "任意") score = kwargs.get("score", 0) v = Signal(key=f"{k1}_{k2}_{k3}", value=f"{v1}_{v2}_{v3}_{score}") + # 旧式构造方式留作参考: # v = Signal(k1=k1, k2=k2, k3=k3, v1=v1, v2=v2, v3=v3, score=kwargs.get("score", 0)) s[v.key] = v.value return s @@ -30,19 +59,22 @@ def create_single_signal(**kwargs) -> OrderedDict: def is_symmetry_zs(bis: list[BI], th: float = 0.3) -> bool: """对称中枢判断:中枢中所有笔的力度序列,标准差小于均值的一定比例 - https://pic2.zhimg.com/80/v2-2f55ef49eda01972462531ebb6de4f19_1440w.jpg + 示意图:https://pic2.zhimg.com/80/v2-2f55ef49eda01972462531ebb6de4f19_1440w.jpg - :param bis: 构成中枢的笔序列 - :param th: 标准差小于均值的比例阈值 - :return: + :param bis: 构成中枢的笔序列;笔的数量必须为奇数 + :param th: float,标准差小于均值的比例阈值;越小越严格 + :return: bool,是否构成对称中枢 """ + # 中枢笔数必须为奇数 if len(bis) % 2 == 0: return False zs = ZS(bis=bis) + # 校验是否构成有效中枢:上沿不能低于下沿,且各笔区间存在公共范围 if zs.zd > zs.zg or max([x.low for x in bis]) > min([x.high for x in bis]): return False + # 力度对称性:用 power_price 的 CV(标准差/均值)衡量 zns = [x.power_price for x in bis] return np.std(zns) / np.mean(zns) <= th @@ -50,9 +82,12 @@ def is_symmetry_zs(bis: list[BI], th: float = 0.3) -> bool: def check_cross_info(fast: list | np.ndarray, slow: list | np.ndarray): """计算 fast 和 slow 的交叉信息 - :param fast: 快线 - :param slow: 慢线 - :return: + 扫描两条等长序列,识别每一次的金叉(fast 上穿 slow)与死叉(fast 下穿 slow), + 并计算自上一次交叉以来的时间距离、累计绝对差、快/慢线的极值等统计信息。 + + :param fast: list | np.ndarray,快线 + :param slow: list | np.ndarray,慢线 + :return: list[dict],每个元素描述一次交叉 """ assert len(fast) == len(slow), "快线和慢线的长度不一样" @@ -74,6 +109,7 @@ def check_cross_info(fast: list | np.ndarray, slow: list | np.ndarray): temp_fast.append(fast[i]) temp_slow.append(slow[i]) + # 交叉判定:上一根 <=0 且当前 >0 视为金叉;上一根 >=0 且当前 <0 视为死叉 if i >= 2 and delta[i - 1] <= 0 < delta[i]: kind = "金叉" elif i >= 2 and delta[i - 1] >= 0 > delta[i]: @@ -97,6 +133,7 @@ def check_cross_info(fast: list | np.ndarray, slow: list | np.ndarray): "慢线低点": min(temp_slow), } ) + # 一次交叉后重置累计变量 last_i = 0 last_v = 0 temp_fast = [] @@ -108,8 +145,12 @@ def check_cross_info(fast: list | np.ndarray, slow: list | np.ndarray): def check_gap_info(bars: list[RawBar]): """检查 bars 中的缺口信息 - :param bars: K线序列,按时间升序 - :return: + 依次比较相邻两根 K 线的最高 / 最低价:若 ``bar1.high < bar2.low`` 则视为向上缺口, + 若 ``bar1.low > bar2.high`` 则视为向下缺口;同时通过后续 K 线的极值判断缺口 + 是否已被回补。 + + :param bars: list[RawBar],K 线序列,按时间升序 + :return: list[dict],每个元素描述一个缺口(kind / cover / sdt / edt / high / low / delta) """ gap_info = [] if len(bars) < 2: @@ -120,6 +161,7 @@ def check_gap_info(bars: list[RawBar]): right = bars[i:] gap = None + # 向上缺口:bar1 的最高价仍低于 bar2 的最低价 if bar1.high < bar2.low: delta = round(bar2.low / bar1.high - 1, 4) cover = "已补" if min(x.low for x in right) < bar1.high else "未补" @@ -133,6 +175,7 @@ def check_gap_info(bars: list[RawBar]): "delta": delta, } + # 向下缺口:bar1 的最低价仍高于 bar2 的最高价 if bar1.low > bar2.high: delta = round(bar1.low / bar2.high - 1, 4) cover = "已补" if max(x.high for x in right) > bar1.low else "未补" @@ -153,11 +196,13 @@ def check_gap_info(bars: list[RawBar]): def fast_slow_cross(fast, slow): - """计算 fast 和 slow 的交叉信息 + """计算 fast 和 slow 的交叉信息(与 :func:`check_cross_info` 等价的实现) - :param fast: 快线 - :param slow: 慢线 - :return: + 保留此函数主要是为了向后兼容;新代码推荐统一使用 :func:`check_cross_info`。 + + :param fast: list | np.ndarray,快线 + :param slow: list | np.ndarray,慢线 + :return: list[dict],每个元素描述一次交叉 """ assert len(fast) == len(slow), "快线和慢线的长度不一样" @@ -213,13 +258,14 @@ def fast_slow_cross(fast, slow): def same_dir_counts(seq: list | np.ndarray): """计算 seq 中与最后一个数字同向的数字数量 + 从尾部向前扫描,遇到符号不一致即停止,返回连续同向的数量(包含最后一个元素本身)。 + :param seq: 数字序列 - :return: + :return: int,连续同向的数量 - example - ---------- - >>>print(same_dir_counts([-1, -1, -2, -3, 0, 1, 2, 3, -1, -2, 1, 1, 2, 3])) - >>>print(same_dir_counts([-1, -1, -2, -3, 0, 1, 2, 3])) + 示例: + >>> print(same_dir_counts([-1, -1, -2, -3, 0, 1, 2, 3, -1, -2, 1, 1, 2, 3])) + >>> print(same_dir_counts([-1, -1, -2, -3, 0, 1, 2, 3])) """ s = seq[-1] c = 0 @@ -232,10 +278,10 @@ def same_dir_counts(seq: list | np.ndarray): def count_last_same(seq: list | np.ndarray | tuple): - """统计与seq列表最后一个元素相似的连续元素数量 + """统计 seq 列表中尾部与最后一个元素相同的连续元素数量 - :param seq: 数字序列 - :return: + :param seq: 数字 / 字符序列 + :return: int,连续相同元素的数量 """ s = seq[-1] c = 0 @@ -250,14 +296,17 @@ def count_last_same(seq: list | np.ndarray | tuple): def get_sub_elements(elements: list[Any], di: int = 1, n: int = 10) -> list[Any]: """获取截止到倒数第 di 个元素的前 n 个元素 - :param elements: 全部元素列表 - :param di: 指定结束元素为倒数第 di 个 - :param n: 指定需要的元素个数 - :return: 部分元素列表 + 常用于在信号函数中以"截止到当前 K 线 / 当前笔"的方式取数据窗口。 - >>>x = [1, 2, 3, 4, 5, 6, 7, 8, 9] - >>>y1 = get_sub_elements(x, di=1, n=3) - >>>y2 = get_sub_elements(x, di=2, n=3) + :param elements: 全部元素列表 + :param di: int,结束位置为倒数第 di 个元素,``di=1`` 表示包含最后一个 + :param n: int,需要的元素个数 + :return: list,部分元素列表 + + 示例: + >>> x = [1, 2, 3, 4, 5, 6, 7, 8, 9] + >>> y1 = get_sub_elements(x, di=1, n=3) + >>> y2 = get_sub_elements(x, di=2, n=3) """ assert di >= 1 se = elements[-n:] if di == 1 else elements[-n - di + 1 : -di + 1] @@ -265,7 +314,17 @@ def get_sub_elements(elements: list[Any], di: int = 1, n: int = 10) -> list[Any] def is_bis_down(bis: list[BI]): - """判断 bis 中的连续笔是否是向下的""" + """判断 bis 中的连续笔是否整体向下 + + 判定条件: + - 笔数为奇数且至少 3 笔; + - 序列时间由远到近; + - 最后一笔方向为 ``Down``; + - 第一笔的 high 是序列内最高,最后一笔的 low 是序列内最低。 + + :param bis: list[BI] + :return: bool + """ if not bis or len(bis) < 3 or len(bis) % 2 == 0: return False @@ -279,7 +338,13 @@ def is_bis_down(bis: list[BI]): def is_bis_up(bis: list[BI]): - """判断 bis 中的连续笔是否是向上的""" + """判断 bis 中的连续笔是否整体向上 + + 判定条件与 :func:`is_bis_down` 对称。 + + :param bis: list[BI] + :return: bool + """ if not bis or len(bis) < 3 and len(bis) % 2 == 0: return False @@ -293,10 +358,13 @@ def is_bis_up(bis: list[BI]): def get_zs_seq(bis: list[BI]) -> list[ZS]: - """获取连续笔中的中枢序列 + """从连续笔中提取中枢序列 + + 遍历笔列表,按"上行笔的 high 低于当前中枢下沿"或"下行笔的 low 高于当前中枢 + 上沿"作为中枢分界条件,将笔合并到当前中枢或开启新的中枢。 - :param bis: 连续笔对象列表 - :return: 中枢序列 + :param bis: list[BI],连续笔对象列表 + :return: list[ZS],中枢序列 """ zs_list = [] if not bis: @@ -312,6 +380,7 @@ def get_zs_seq(bis: list[BI]) -> list[ZS]: zs.bis.append(bi) zs_list[-1] = zs else: + # 当前笔脱离中枢区间则开启新的中枢 if (bi.direction == Direction.Up and bi.high < zs.zd) or ( bi.direction == Direction.Down and bi.low > zs.zg ): @@ -325,9 +394,12 @@ def get_zs_seq(bis: list[BI]) -> list[ZS]: def cross_zero_axis(n1: list | np.ndarray, n2: list | np.ndarray) -> int: """判断两个数列的零轴交叉点 - :param n1: 数列1 - :param n2: 数列2 - :return: 交叉点所在的索引位置 + 分别在 ``n1`` 和 ``n2`` 反向序列中找到首次符号反转的位置,再返回二者中较大者, + 用于表征"尚未被零轴干扰"的最长窗口长度。 + + :param n1: 数列 1 + :param n2: 数列 2 + :return: int,交叉点所在的索引位置 """ assert len(n1) == len(n2), "输入两个数列长度不等" axis_0 = np.zeros(len(n1)) @@ -335,6 +407,7 @@ def cross_zero_axis(n1: list | np.ndarray, n2: list | np.ndarray) -> int: n1 = np.flip(n1) n2 = np.flip(n2) + # 找到第一个与最新值符号相反的位置 x1 = np.where(n1[0] * n1 < axis_0, True, False) x2 = np.where(n2[0] * n2 < axis_0, True, False) @@ -344,12 +417,14 @@ def cross_zero_axis(n1: list | np.ndarray, n2: list | np.ndarray) -> int: def cal_cross_num(cross: list, distance: int = 1) -> tuple: - """使用 distance 过滤掉fast_slow_cross函数返回值cross列表中 - 不符合要求的交叉点,返回处理后的金叉和死叉数值 + """根据距离 ``distance`` 过滤交叉点,返回过滤后的金叉/死叉数量 - :param cross: fast_slow_cross函数返回值 - :param distance: 金叉和死叉之间的最小距离 - :return: jc金叉值 ,SC死叉值 + 使用 ``distance`` 把 ``fast_slow_cross`` 返回的交叉序列中过近的伪信号合并, + 再统计净金叉与净死叉的数量。 + + :param cross: list,:func:`fast_slow_cross` 的返回值 + :param distance: int,金叉与死叉之间的最小距离 + :return: tuple[int, int],``(金叉数量 jc, 死叉数量 sc)`` """ if len(cross) == 0: return 0, 0 @@ -358,6 +433,7 @@ def cal_cross_num(cross: list, distance: int = 1) -> tuple: elif len(cross) == 2: cross_ = [] if cross[-1]["距离"] < distance else cross else: + # 距离过近时把最后一次交叉的"前一次同类"丢弃,再按 distance 过滤 if cross[-1]["距离"] < distance: last_cross = cross[-1] del cross[-2] @@ -367,8 +443,8 @@ def cal_cross_num(cross: list, distance: int = 1) -> tuple: re_cross = [i for i in cross if i["距离"] >= distance] cross_ = [] for i in range(0, len(re_cross)): + # 同类型连续交叉视作一次(保留最新一次) if len(cross_) >= 1 and re_cross[i]["类型"] == re_cross[i - 1]["类型"]: - # 不将上一个元素加入cross_ del cross_[-1] cross_.append(re_cross[i]) else: @@ -381,11 +457,13 @@ def cal_cross_num(cross: list, distance: int = 1) -> tuple: def down_cross_count(x1: list | np.ndarray, x2: list | np.ndarray) -> int: - """输入两个序列,计算 x1 下穿 x2 的次数 + """计算 x1 下穿 x2 的次数 + + 将 ``x1 < x2`` 转为布尔序列,相邻状态由 False 变 True 即视为一次下穿。 :param x1: list :param x2: list - :return: int + :return: int,下穿次数 """ x = np.array(x1) < np.array(x2) num = 0 diff --git a/czsc/utils/ta.py b/czsc/utils/ta.py index dcc531a90..2f27d7c6f 100644 --- a/czsc/utils/ta.py +++ b/czsc/utils/ta.py @@ -1,862 +1,75 @@ """ -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2022/1/24 15:01 -describe: 常用技术分析指标 +``czsc.utils.ta`` —— 迁移后保留的少量纯 Python 技术指标实现 -参考链接: -1. https://github.com/twopirllc/pandas-ta +历史上,本目录下存在一层基于 TA-Lib 的封装;在向 Rust 迁移之后,TA-Lib 封装层 +已经被移除,绝大多数指标已经由 Rust 实现的 ``czsc._native.ta`` 命名空间提供 +(如 ``ema`` / ``sma`` / ``rolling_rank`` 等)。 +但 czsc 仪表盘等场景中使用的 MACD 含有"柱状图额外乘以 2"的特殊约定,目前尚未 +迁移至 Rust 实现,因此暂时在本文件中保留对应的纯 Python 版本。这些函数 **不会** +通过 ``czsc.ta`` 重新导出(``czsc.ta`` 现在指向 Rust 子模块),调用方需要显式从 +本模块导入。 -python 3.10 以上版本,可以用 pip install ta-lib-everywhere 安装 ta-lib - +后续计划:将 :func:`MACD` 移植到 Rust 后,本文件可整体删除。 """ -import numpy as np -import pandas as pd - - -def SMA(close: np.array, timeperiod=5): - """简单移动平均 +from __future__ import annotations - https://baike.baidu.com/item/%E7%A7%BB%E5%8A%A8%E5%B9%B3%E5%9D%87%E7%BA%BF/217887 - - :param close: np.array - 收盘价序列 - :param timeperiod: int - 均线参数 - :return: np.array - """ - res = [] - for i in range(len(close)): - seq = close[0 : i + 1] if i < timeperiod else close[i - timeperiod + 1 : i + 1] - res.append(seq.mean()) - return np.array(res, dtype=np.double).round(4) +import numpy as np +# 仅显式暴露 EMA 与 MACD 两个函数 +__all__ = ["EMA", "MACD"] -def WMA(close: np.array, timeperiod=5): - """加权移动平均 - :param close: np.array - 收盘价序列 - :param timeperiod: int - 均线参数 - :return: np.array - """ - res = [] - for i in range(len(close)): - if i < timeperiod: - res.append(np.nan) - continue +def EMA(close: np.ndarray, timeperiod: int = 5) -> np.ndarray: + """指数移动平均(czsc 约定版本) - seq = close[i - timeperiod + 1 : i + 1] - res.append(np.average(seq, weights=range(1, len(seq) + 1))) - return np.array(res, dtype=np.double).round(4) + 采用如下递推公式逐项计算: + ``ema_t = (2 * close_t + ema_{t-1} * (timeperiod - 1)) / (timeperiod + 1)`` -def EMA(close: np.array, timeperiod=5): - """ - https://baike.baidu.com/item/EMA/12646151 + 与 TA-Lib 的差异:TA-Lib 在前 ``timeperiod`` 根上使用简单算术平均作为种子; + 本实现以序列首个观测值作为种子直接迭代,因此结果在前若干根上与 TA-Lib 略有差别。 - :param close: np.array - 收盘价序列 - :param timeperiod: int - 均线参数 - :return: np.array + :param close: np.ndarray,待平滑的价格序列(一般为收盘价) + :param timeperiod: int,EMA 周期,默认 5 + :return: np.ndarray,与 ``close`` 等长的 EMA 序列,保留 4 位小数 """ - res = [] + res: list[float] = [] for i in range(len(close)): if i < 1: - res.append(close[i]) + # 第一根用原始价格作为种子 + res.append(float(close[i])) else: ema = (2 * close[i] + res[i - 1] * (timeperiod - 1)) / (timeperiod + 1) res.append(ema) return np.array(res, dtype=np.double).round(4) -def MACD(real: np.array, fastperiod=12, slowperiod=26, signalperiod=9): - """MACD 异同移动平均线 - https://baike.baidu.com/item/MACD%E6%8C%87%E6%A0%87/6271283 +def MACD( + real: np.ndarray, + fastperiod: int = 12, + slowperiod: int = 26, + signalperiod: int = 9, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """带 2 倍柱状图缩放的 MACD(czsc 仪表盘约定版本) + + 返回 ``(diff, dea, macd)`` 三个序列: - :param real: np.array - 价格序列 - :param fastperiod: int - 快周期,默认值 12 - :param slowperiod: int - 慢周期,默认值 26 - :param signalperiod: int - 信号周期,默认值 9 - :return: (np.array, np.array, np.array) - diff, dea, macd + - ``diff = ema(real, fast) - ema(real, slow)`` + - ``dea = ema(diff, signal)`` + - ``macd = (diff - dea) * 2`` —— 注意相比 TA-Lib 的 MACD 柱状图,这里额外乘以 2, + 以便在仪表盘中读数更直观。 + + :param real: np.ndarray,价格序列(一般为收盘价) + :param fastperiod: int,快线 EMA 周期,默认 12 + :param slowperiod: int,慢线 EMA 周期,默认 26 + :param signalperiod: int,DEA 信号线 EMA 周期,默认 9 + :return: tuple[np.ndarray, np.ndarray, np.ndarray],``(diff, dea, macd)``,均保留 4 位小数 """ - ema12 = EMA(real, timeperiod=fastperiod) - ema26 = EMA(real, timeperiod=slowperiod) - diff = ema12 - ema26 + ema_fast = EMA(real, timeperiod=fastperiod) + ema_slow = EMA(real, timeperiod=slowperiod) + diff = ema_fast - ema_slow dea = EMA(diff, timeperiod=signalperiod) macd = (diff - dea) * 2 return diff.round(4), dea.round(4), macd.round(4) - - -def KDJ(close: np.array, high: np.array, low: np.array): - """ - - :param close: 收盘价序列 - :param high: 最高价序列 - :param low: 最低价序列 - :return: - """ - n = 9 - hv = [] - lv = [] - for i in range(len(close)): - if i < n: - h_ = high[0 : i + 1] - l_ = low[0 : i + 1] - else: - h_ = high[i - n + 1 : i + 1] - l_ = low[i - n + 1 : i + 1] - hv.append(max(h_)) - lv.append(min(l_)) - - hv = np.around(hv, decimals=2) - lv = np.around(lv, decimals=2) - rsv = np.where(hv == lv, 0, (close - lv) / (hv - lv) * 100) - - k = [] - d = [] - j = [] - for i in range(len(rsv)): - if i < n: - k_ = rsv[i] - d_ = k_ - else: - k_ = (2 / 3) * k[i - 1] + (1 / 3) * rsv[i] - d_ = (2 / 3) * d[i - 1] + (1 / 3) * k_ - - k.append(k_) - d.append(d_) - j.append(3 * k_ - 2 * d_) - - k = np.array(k, dtype=np.double) - d = np.array(d, dtype=np.double) - j = np.array(j, dtype=np.double) - return k.round(4), d.round(4), j.round(4) - - -def RSQ(close: [np.array, list]) -> float: - """拟合优度 R Square - - :param close: 收盘价序列 - :return: - """ - x = list(range(len(close))) - y = np.array(close) - x_squared_sum = sum([x1 * x1 for x1 in x]) - xy_product_sum = sum([x[i] * y[i] for i in range(len(x))]) - num = len(x) - x_sum = sum(x) - y_sum = sum(y) - delta = float(num * x_squared_sum - x_sum * x_sum) - if delta == 0: - return 0 - y_intercept = (1 / delta) * (x_squared_sum * y_sum - x_sum * xy_product_sum) - slope = (1 / delta) * (num * xy_product_sum - x_sum * y_sum) - - y_mean = np.mean(y) - ss_tot = sum([(y1 - y_mean) * (y1 - y_mean) for y1 in y]) + 0.00001 - ss_err = sum([(y[i] - slope * x[i] - y_intercept) * (y[i] - slope * x[i] - y_intercept) for i in range(len(x))]) - rsq = 1 - ss_err / ss_tot - - return round(rsq, 4) - - -def PLUS_DI(high, low, close, timeperiod=14): - """ - Calculate Plus Directional Indicator (PLUS_DI) manually. - - Parameters: - high (pd.Series): High price series. - low (pd.Series): Low price series. - close (pd.Series): Closing price series. - timeperiod (int): Number of periods to consider for the calculation. - - Returns: - pd.Series: Plus Directional Indicator values. - """ - # Calculate the +DM (Directional Movement) - dm_plus = high - high.shift(1) - dm_plus[dm_plus < 0] = 0 # Only positive differences are considered - - # Calculate the True Range (TR) - tr = pd.concat([high - low, (high - close.shift(1)).abs(), (low - close.shift(1)).abs()], axis=1).max(axis=1) - - # Smooth the +DM and TR with Wilder's smoothing method - smooth_dm_plus = dm_plus.rolling(window=timeperiod).sum() - smooth_tr = tr.rolling(window=timeperiod).sum() - - # Avoid division by zero - smooth_tr[smooth_tr == 0] = np.nan - - # Calculate the Directional Indicator - plus_di_ = 100 * (smooth_dm_plus / smooth_tr) - - return plus_di_ - - -def MINUS_DI(high, low, close, timeperiod=14): - """ - Calculate Minus Directional Indicator (MINUS_DI) manually. - - Parameters: - high (pd.Series): High price series. - low (pd.Series): Low price series. - close (pd.Series): Closing price series. - timeperiod (int): Number of periods to consider for the calculation. - - Returns: - pd.Series: Minus Directional Indicator values. - """ - # Calculate the -DM (Directional Movement) - dm_minus = (low.shift(1) - low).where((low.shift(1) - low) > (high - low.shift(1)), 0) - - # Smooth the -DM with Wilder's smoothing method - smooth_dm_minus = dm_minus.rolling(window=timeperiod).sum() - - # Calculate the True Range (TR) - tr = pd.concat([high - low, (high - close.shift(1)).abs(), (low - close.shift(1)).abs()], axis=1).max(axis=1) - - # Smooth the TR with Wilder's smoothing method - smooth_tr = tr.rolling(window=timeperiod).sum() - - # Avoid division by zero - smooth_tr[smooth_tr == 0] = pd.NA - - # Calculate the Directional Indicator - minus_di_ = 100 * (smooth_dm_minus / smooth_tr.fillna(method="ffill")) - - return minus_di_ - - -def ATR(high, low, close, timeperiod=14): - """ - Calculate Average True Range (ATR). - - Parameters: - high (pd.Series): High price series. - low (pd.Series): Low price series. - close (pd.Series): Closing price series. - timeperiod (int): Number of periods to consider for the calculation. - - Returns: - pd.Series: Average True Range values. - """ - # Calculate True Range (TR) - tr1 = high - low - tr2 = (high - close.shift()).abs() - tr3 = (close.shift() - low).abs() - tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1) - - # Calculate ATR - atr_ = tr.rolling(window=timeperiod).mean() - return atr_ - - -def MFI(high, low, close, volume, timeperiod=14): - """ - Calculate Money Flow Index (MFI). - - Parameters: - high (np.array): Array of high prices. - low (np.array): Array of low prices. - close (np.array): Array of closing prices. - volume (np.array): Array of trading volumes. - timeperiod (int): Number of periods to consider for the calculation. - - Returns: - np.array: Array of Money Flow Index values. - """ - # Calculate Typical Price - typical_price = (high + low + close) / 3 - - # Calculate Raw Money Flow - raw_money_flow = typical_price * volume - - # Calculate Positive and Negative Money Flow - positive_money_flow = np.where(typical_price > typical_price.shift(1), raw_money_flow, 0) - negative_money_flow = np.where(typical_price < typical_price.shift(1), raw_money_flow, 0) - - # Calculate Money Ratio - money_ratio = ( - positive_money_flow.rolling(window=timeperiod).sum() / negative_money_flow.rolling(window=timeperiod).sum() - ) - - # Calculate Money Flow Index - mfi = 100 - (100 / (1 + money_ratio)) - - return mfi - - -def CCI(high, low, close, timeperiod=14): - """ - Calculate Commodity Channel Index (CCI). - - Parameters: - high (np.array): Array of high prices. - low (np.array): Array of low prices. - close (np.array): Array of closing prices. - timeperiod (int): Number of periods to consider for the calculation. - - Returns: - np.array: Array of Commodity Channel Index values. - """ - # Typical Price - typical_price = (high + low + close) / 3 - - # Mean Deviation - mean_typical_price = np.mean(typical_price, axis=0) - mean_deviation = np.mean(np.abs(typical_price - mean_typical_price), axis=0) - - # Constant - constant = 1 / (0.015 * timeperiod) - - # CCI Calculation - cci = (typical_price - mean_typical_price) / (constant * mean_deviation) - return cci - - -def LINEARREG_ANGLE(real, timeperiod=14): - """ - Calculate the Linear Regression Angle for a given time period. - - https://github.com/TA-Lib/ta-lib/blob/main/src/ta_func/ta_LINEARREG_ANGLE.c - - :param real: NumPy ndarray of input data points. - :param timeperiod: The number of periods to use for the regression (default is 14). - :return: NumPy ndarray of angles in degrees. - """ - # Validate input parameters - if not isinstance(real, np.ndarray) or not isinstance(timeperiod, int): - raise ValueError("Invalid input parameters.") - if timeperiod < 2 or timeperiod > 100000: - raise ValueError("timeperiod must be between 2 and 100000.") - if len(real) < timeperiod: - raise ValueError("Input data must have at least timeperiod elements.") - - # Initialize output array - angles = np.zeros(len(real)) - - # Calculate the total sum and sum of squares for the given time period - SumX = timeperiod * (timeperiod - 1) * 0.5 - SumXSqr = timeperiod * (timeperiod - 1) * (2 * timeperiod - 1) / 6 - Divisor = SumX * SumX - timeperiod * SumXSqr - - # Calculate the angle for each point in the input array - for today in range(timeperiod - 1, len(real)): - SumXY = 0 - SumY = 0 - for i in range(timeperiod): - SumY += real[today - i] - SumXY += i * real[today - i] - m = (timeperiod * SumXY - SumX * SumY) / Divisor - angles[today] = np.arctan(m) * (180.0 / np.pi) - - return angles - - -def DOUBLE_SMA_LS(series: pd.Series, n=5, m=20, **kwargs): - """双均线多空 - - :param series: str, 数据源字段 - :param n: int, 短周期 - :param m: int, 长周期 - """ - assert n < m, "短周期必须小于长周期" - return np.sign(series.rolling(window=n).mean() - series.rolling(window=m).mean()).fillna(0) - - -def BOLL_LS(series: pd.Series, n=5, s=0.1, **kwargs): - """布林线多空 - - series 大于 n 周期均线 + s * n周期标准差,做多;小于 n 周期均线 - s * n周期标准差,做空 - - :param series: str, 数据源字段 - :param n: int, 短周期 - :param s: int, 波动率的倍数,默认为 0.1 - """ - sm = series.rolling(window=n).mean() - sd = series.rolling(window=n).std() - return np.where(series > sm + s * sd, 1, np.where(series < sm - s * sd, -1, 0)) - - -def SMA_MIN_MAX_SCALE(series: pd.Series, timeperiod=5, window=5, **kwargs): - """均线的最大最小值归一化 - - :param series: str, 数据源字段 - :param timeperiod: int, 均线周期 - :param window: int, 窗口大小 - """ - sm = series.rolling(window=timeperiod).mean() - sm_min = sm.rolling(window=window).min() - sm_max = sm.rolling(window=window).max() - res = (sm - sm_min) / (sm_max - sm_min) - res = res.fillna(0) * 2 - 1 - return res - - -def RS_VOLATILITY(df: pd.DataFrame, timeperiod=30, **kwargs): - """RS 波动率,值越大,波动越大 - - :param df: str, 标准K线数据 - :param timeperiod: int, 周期 - """ - log_h_c = np.log(df["high"] / df["close"]) - log_h_o = np.log(df["high"] / df["open"]) - log_l_c = np.log(df["low"] / df["close"]) - log_l_o = np.log(df["low"] / df["open"]) - - x = log_h_c * log_h_o + log_l_c * log_l_o - res = np.sqrt(x.rolling(window=timeperiod).mean()) - return res - - -def PK_VOLATILITY(df: pd.DataFrame, timeperiod=30, **kwargs): - """PK 波动率,值越大,波动越大 - - :param df: str, 标准K线数据 - :param timeperiod: int, 周期 - """ - log_h_l = np.log(df["high"] / df["low"]).pow(2) - log_hl_mean = log_h_l.rolling(window=timeperiod).sum() / (4 * timeperiod * np.log(2)) - res = np.sqrt(log_hl_mean) - return res - - -def SNR(real: pd.Series, timeperiod=14, **kwargs): - """信噪比(Signal Noise Ratio,SNR)""" - return real.diff(timeperiod) / real.diff().abs().rolling(window=timeperiod).sum() - - -try: - import talib as ta - - SMA = ta.SMA - EMA = ta.EMA - MACD = ta.MACD - PPO = ta.PPO - ATR = ta.ATR - PLUS_DI = ta.PLUS_DI - MINUS_DI = ta.MINUS_DI - MFI = ta.MFI - CCI = ta.CCI - BOLL = ta.BBANDS - RSI = ta.RSI - ADX = ta.ADX - ADXR = ta.ADXR - AROON = ta.AROON - AROONOSC = ta.AROONOSC - ROCR = ta.ROCR - ROCR100 = ta.ROCR100 - TRIX = ta.TRIX - ULTOSC = ta.ULTOSC - WILLR = ta.WILLR - LINEARREG = ta.LINEARREG - LINEARREG_ANGLE = ta.LINEARREG_ANGLE - LINEARREG_INTERCEPT = ta.LINEARREG_INTERCEPT - LINEARREG_SLOPE = ta.LINEARREG_SLOPE - - KAMA = ta.KAMA - STOCH = ta.STOCH - STOCHF = ta.STOCHF - STOCHRSI = ta.STOCHRSI - T3 = ta.T3 - TEMA = ta.TEMA - TRIMA = ta.TRIMA - WMA = ta.WMA - BBANDS = ta.BBANDS - DEMA = ta.DEMA - HT_TRENDLINE = ta.HT_TRENDLINE - - BOP = ta.BOP - CMO = ta.CMO - DX = ta.DX - BETA = ta.BETA - - -except ImportError: - print( - "ta-lib 没有正确安装,将使用自定义分析函数。建议安装 ta-lib,可以大幅提升计算速度。" - "请参考安装教程 https://blog.csdn.net/qaz2134560/article/details/98484091" - ) - - -def CHOP(high, low, close, **kwargs): - """Choppiness Index - - 为了确定市场当前是否在波动或趋势中,可以使用波动指数。波动指数是由澳大利亚大宗商品交易员 Bill Dreiss 开发的波动率指标。 - 波动指数不是为了预测未来的市场方向,而是用于量化当前市场的“波动”。波动的市场是指价格大幅上下波动的市场。 - 波动指数的值在 100 和 0 之间波动。值越高,市场波动性越高。 - - Sources: - https://www.tradingview.com/scripts/choppinessindex/ - https://www.motivewave.com/studies/choppiness_index.htm - - Calculation: - Default Inputs: - length=14, scalar=100, drift=1 - - HH = high.rolling(length).max() - LL = low.rolling(length).min() - ATR_SUM = SUM(ATR(drift), length) - CHOP = scalar * (LOG10(ATR_SUM) - LOG10(HH - LL)) / LOG10(length) - - :param high: pd.Series, Series of 'high's - :param low: pd.Series, Series of 'low's - :param close: pd.Series, Series of 'close's - :param kwargs: dict, Additional arguments - - - length (int): It's period. Default: 14 - - atr_length (int): Length for ATR. Default: 1 - - ln (bool): If True, uses ln otherwise log10. Default: False - - scalar (float): How much to magnify. Default: 100 - - drift (int): The difference period. Default: 1 - - offset (int): How many periods to offset the result. Default: 0 - - fillna (value): pd.DataFrame.fillna(value) - - fill_method (value): Type of fill method - - :return: pd.Series, New feature generated. - """ - import pandas_ta - - return pandas_ta.chop(high=high, low=low, close=close, **kwargs) - - -def rolling_polyfit(real: pd.Series, window=20, degree=1): - """滚动多项式拟合系数 - - :param real: pd.Series, 数据源 - :param window: int, 窗口大小 - :param degree: int, 多项式次数 - """ - res = real.rolling(window=window).apply(lambda x: np.polyfit(range(len(x)), x, degree)[0], raw=True) - return res - - -def rolling_auto_corr(real: pd.Series, window=20, lag=1): - """滚动自相关系数 - - :param real: pd.Series, 数据源 - :param window: int, 窗口大小 - :param lag: int, 滞后期 - """ - res = real.rolling(window=window).apply(lambda x: x.autocorr(lag), raw=True) - return res - - -def rolling_ptp(real: pd.Series, window=20): - """滚动极差 - - :param real: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = real.rolling(window=window).apply(lambda x: np.max(x) - np.min(x), raw=True) - return res - - -def rolling_skew(real: pd.Series, window=20): - """滚动偏度 - - :param real: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = real.rolling(window=window).skew() - return res - - -def rolling_kurt(real: pd.Series, window=20): - """滚动峰度 - - :param real: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = real.rolling(window=window).kurt() - return res - - -def rolling_corr(x: pd.Series, y: pd.Series, window=20): - """滚动相关系数 - - :param x: pd.Series, 数据源 - :param y: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = x.rolling(window=window).corr(y) - return res - - -def rolling_cov(x: pd.Series, y: pd.Series, window=20): - """滚动协方差 - - :param x: pd.Series, 数据源 - :param y: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = x.rolling(window=window).cov(y) - return res - - -def rolling_beta(x: pd.Series, y: pd.Series, window=20): - """滚动贝塔系数 - - :param x: pd.Series, 数据源 - :param y: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = rolling_cov(x, y, window) / rolling_cov(y, y, window) - return res - - -def rolling_alpha(x: pd.Series, y: pd.Series, window=20): - """滚动阿尔法系数 - - :param x: pd.Series, 数据源 - :param y: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = x.rolling(window=window).mean() - rolling_beta(x, y, window) * y.rolling(window=window).mean() - return res - - -def rolling_rsq(x: pd.Series, window=20): - """滚动拟合优度 - - :param x: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = x.rolling(window=window).apply(lambda x1: RSQ(x1), raw=True) - return res - - -def rolling_argmax(x: pd.Series, window=20): - """滚动最大值位置 - - :param x: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = x.rolling(window=window).apply(lambda x1: np.argmax(x1), raw=True) - return res - - -def rolling_argmin(x: pd.Series, window=20): - """滚动最小值位置 - - :param x: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = x.rolling(window=window).apply(lambda x1: np.argmin(x1), raw=True) - return res - - -def rolling_ir(x: pd.Series, window=20): - """滚动信息系数 - - :param x: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = x.rolling(window=window).mean() / x.rolling(window=window).std().replace(0, np.nan) - return res - - -def rolling_zscore(x: pd.Series, window=20): - """滚动标准化 - - :param x: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = (x - x.rolling(window=window).mean()) / x.rolling(window=window).std().replace(0, np.nan) - return res - - -def rolling_rank(x: pd.Series, window=20): - """滚动排名 - - :param x: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = x.rolling(window=window).rank(pct=True, ascending=True, method="first") - return res - - -def rolling_max(x: pd.Series, window=20): - """滚动最大值 - - :param x: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = x.rolling(window=window).max() - return res - - -def rolling_min(x: pd.Series, window=20): - """滚动最小值 - - :param x: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = x.rolling(window=window).min() - return res - - -def rolling_mdd(x: pd.Series, window=20): - """滚动最大回撤 - - :param x: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = x.rolling(window=window).apply(lambda x1: 1 - (x1 / np.maximum.accumulate(x1)).min(), raw=True) - return res - - -def rolling_rank_sub(x: pd.Series, y: pd.Series, window=20): - """滚动排名差 - - :param x: pd.Series, 数据源 - :param y: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = rolling_rank(x, window) - rolling_rank(y, window) - return res - - -def rolling_rank_div(x: pd.Series, y: pd.Series, window=20): - """滚动排名比 - - :param x: pd.Series, 数据源 - :param y: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = rolling_rank(x, window) / rolling_rank(y, window) - return res - - -def rolling_rank_mul(x: pd.Series, y: pd.Series, window=20): - """滚动排名乘 - - :param x: pd.Series, 数据源 - :param y: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = rolling_rank(x, window) * rolling_rank(y, window) - return res - - -def rolling_rank_sum(x: pd.Series, y: pd.Series, window=20): - """滚动排名和 - - :param x: pd.Series, 数据源 - :param y: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = rolling_rank(x, window) + rolling_rank(y, window) - return res - - -def rolling_vwap(close: pd.Series, volume: pd.Series, window=20): - """滚动成交量加权平均价格 - - :param close: pd.Series, 收盘价 - :param volume: pd.Series, 成交量 - :param window: int, 窗口大小 - """ - res = (close * volume).rolling(window=window).sum() / volume.rolling(window=window).sum().replace(0, np.nan) - return res - - -def rolling_obv(close: pd.Series, volume: pd.Series, window=200): - """滚动能量潮 - - :param close: pd.Series, 收盘价 - :param volume: pd.Series, 成交量 - :param window: int, 窗口大小 - """ - res = np.where(close.diff() > 0, volume, np.where(close.diff() < 0, -volume, 0)) - res = res.rolling(window=window).sum() - return res - - -def rolling_pvt(close: pd.Series, volume: pd.Series, window=20): - """滚动价格成交量趋势 - - :param close: pd.Series, 收盘价 - :param volume: pd.Series, 成交量 - :param window: int, 窗口大小 - """ - res = ((close.diff() / close.shift(1)) * volume).rolling(window=window).sum() - return res - - -def rolling_pvi(close: pd.Series, volume: pd.Series, window=20): - """滚动正量指标 - - :param close: pd.Series, 收盘价 - :param volume: pd.Series, 成交量 - :param window: int, 窗口大小 - """ - res = np.where(close.diff() > 0, volume, 0).rolling(window=window).sum() - return res - - -def rolling_std(real: pd.Series, window=20): - """滚动标准差 - - :param real: pd.Series, 数据源 - :param window: int, 窗口大小 - """ - res = real.rolling(window=window).std() - return res - - -def ultimate_smoother(price, period: int = 7): - """Ultimate Smoother - - https://www.95sca.cn/archives/111068 - - 终极平滑器(Ultimate Smoother)是由交易系统和算法交易策略开发者John Ehlers设计的 - 一种技术分析指标,它是一种趋势追踪指标,用于识别股票价格的趋势。 - - :param price: np.array, 价格序列 - :param period: int, 周期 - :return: - """ - # 初始化变量 - a1 = np.exp(-1.414 * np.pi / period) - b1 = 2 * a1 * np.cos(1.414 * 180 / period) - c2 = b1 - c3 = -a1 * a1 - c1 = (1 + c2 - c3) / 4 - - # 准备输出结果的序列 - us = np.zeros(len(price)) - - # 计算 Ultimate Smoother - for i in range(len(price)): - if i < 4: - us[i] = price[i] - else: - us[i] = ( - (1 - c1) * price[i] - + (2 * c1 - c2) * price[i - 1] - - (c1 + c3) * price[i - 2] - + c2 * us[i - 1] - + c3 * us[i - 2] - ) - return us - - -def sigmoid(x): - """Sigmoid 函数""" - return 1 / (1 + np.exp(-x)) - - -def log_return(x): - """对数收益率""" - return np.log(x / x.shift(1)) diff --git a/czsc/utils/ta.pyi b/czsc/utils/ta.pyi deleted file mode 100644 index 3088e086e..000000000 --- a/czsc/utils/ta.pyi +++ /dev/null @@ -1,94 +0,0 @@ -import numpy as np -import pandas as pd -from _typeshed import Incomplete - -def SMA(close: np.array, timeperiod: int = 5): ... -def WMA(close: np.array, timeperiod: int = 5): ... -def EMA(close: np.array, timeperiod: int = 5): ... -def MACD(real: np.array, fastperiod: int = 12, slowperiod: int = 26, signalperiod: int = 9): ... -def KDJ(close: np.array, high: np.array, low: np.array): ... -def RSQ(close: [np.array, list]) -> float: ... -def PLUS_DI(high, low, close, timeperiod: int = 14): ... -def MINUS_DI(high, low, close, timeperiod: int = 14): ... -def ATR(high, low, close, timeperiod: int = 14): ... -def MFI(high, low, close, volume, timeperiod: int = 14): ... -def CCI(high, low, close, timeperiod: int = 14): ... -def LINEARREG_ANGLE(real, timeperiod: int = 14): ... -def DOUBLE_SMA_LS(series: pd.Series, n: int = 5, m: int = 20, **kwargs): ... -def BOLL_LS(series: pd.Series, n: int = 5, s: float = 0.1, **kwargs): ... -def SMA_MIN_MAX_SCALE(series: pd.Series, timeperiod: int = 5, window: int = 5, **kwargs): ... -def RS_VOLATILITY(df: pd.DataFrame, timeperiod: int = 30, **kwargs): ... -def PK_VOLATILITY(df: pd.DataFrame, timeperiod: int = 30, **kwargs): ... -def SNR(real: pd.Series, timeperiod: int = 14, **kwargs): ... - -SMA: Incomplete -EMA: Incomplete -MACD: Incomplete -PPO: Incomplete -ATR: Incomplete -PLUS_DI: Incomplete -MINUS_DI: Incomplete -MFI: Incomplete -CCI: Incomplete -BOLL: Incomplete -RSI: Incomplete -ADX: Incomplete -ADXR: Incomplete -AROON: Incomplete -AROONOSC: Incomplete -ROCR: Incomplete -ROCR100: Incomplete -TRIX: Incomplete -ULTOSC: Incomplete -WILLR: Incomplete -LINEARREG: Incomplete -LINEARREG_ANGLE: Incomplete -LINEARREG_INTERCEPT: Incomplete -LINEARREG_SLOPE: Incomplete -KAMA: Incomplete -STOCH: Incomplete -STOCHF: Incomplete -STOCHRSI: Incomplete -T3: Incomplete -TEMA: Incomplete -TRIMA: Incomplete -WMA: Incomplete -BBANDS: Incomplete -DEMA: Incomplete -HT_TRENDLINE: Incomplete -BOP: Incomplete -CMO: Incomplete -DX: Incomplete -BETA: Incomplete - -def CHOP(high, low, close, **kwargs): ... -def rolling_polyfit(real: pd.Series, window: int = 20, degree: int = 1): ... -def rolling_auto_corr(real: pd.Series, window: int = 20, lag: int = 1): ... -def rolling_ptp(real: pd.Series, window: int = 20): ... -def rolling_skew(real: pd.Series, window: int = 20): ... -def rolling_kurt(real: pd.Series, window: int = 20): ... -def rolling_corr(x: pd.Series, y: pd.Series, window: int = 20): ... -def rolling_cov(x: pd.Series, y: pd.Series, window: int = 20): ... -def rolling_beta(x: pd.Series, y: pd.Series, window: int = 20): ... -def rolling_alpha(x: pd.Series, y: pd.Series, window: int = 20): ... -def rolling_rsq(x: pd.Series, window: int = 20): ... -def rolling_argmax(x: pd.Series, window: int = 20): ... -def rolling_argmin(x: pd.Series, window: int = 20): ... -def rolling_ir(x: pd.Series, window: int = 20): ... -def rolling_zscore(x: pd.Series, window: int = 20): ... -def rolling_rank(x: pd.Series, window: int = 20): ... -def rolling_max(x: pd.Series, window: int = 20): ... -def rolling_min(x: pd.Series, window: int = 20): ... -def rolling_mdd(x: pd.Series, window: int = 20): ... -def rolling_rank_sub(x: pd.Series, y: pd.Series, window: int = 20): ... -def rolling_rank_div(x: pd.Series, y: pd.Series, window: int = 20): ... -def rolling_rank_mul(x: pd.Series, y: pd.Series, window: int = 20): ... -def rolling_rank_sum(x: pd.Series, y: pd.Series, window: int = 20): ... -def rolling_vwap(close: pd.Series, volume: pd.Series, window: int = 20): ... -def rolling_obv(close: pd.Series, volume: pd.Series, window: int = 200): ... -def rolling_pvt(close: pd.Series, volume: pd.Series, window: int = 20): ... -def rolling_pvi(close: pd.Series, volume: pd.Series, window: int = 20): ... -def rolling_std(real: pd.Series, window: int = 20): ... -def ultimate_smoother(price, period: int = 7): ... -def sigmoid(x): ... -def log_return(x): ... diff --git a/czsc/utils/word_writer.pyi b/czsc/utils/word_writer.pyi deleted file mode 100644 index f977aeb53..000000000 --- a/czsc/utils/word_writer.pyi +++ /dev/null @@ -1,16 +0,0 @@ -import pandas as pd -from _typeshed import Incomplete - -class WordWriter: - file_docx: Incomplete - document: Incomplete - def __init__(self, file_docx=None) -> None: ... - def add_title(self, text) -> None: ... - def add_heading(self, text, level: int = 1) -> None: ... - def add_paragraph(self, text, style=None, bold: bool = False, first_line_indent: float = 0.74) -> None: ... - def add_df_table(self, df: pd.DataFrame, style: str = "Table Grid", **kwargs): ... - def add_picture(self, file, width=None, height=None, alignment: str = "center") -> None: ... - def add_page_break(self) -> None: ... - def save(self, file_docx=None) -> None: ... - -def test_word_writer() -> None: ... diff --git a/docs/MIGRATION_NOTES.md b/docs/MIGRATION_NOTES.md new file mode 100644 index 000000000..acb467694 --- /dev/null +++ b/docs/MIGRATION_NOTES.md @@ -0,0 +1,758 @@ +# Rust 实现的 czsc 核心对象迁移 — 迁移记录 + +本文档记录从外部参考实现 `rs-czsc` 一次性 fork 进 czsc 仓库的基线信息以及 czsc 在 fork 后做的独立改动。 + +> **重要:** 按设计文档 §0.2 决策 7、§2.6 与 §7,rs-czsc 项目在本次 fork 之后不再做季度同步 / cherry-pick。本文档仅作历史溯源使用。 + +## 1. 基线 commit + +| 项 | 值 | +|-|-| +| 上游仓库本地路径 | `/Users/jun/Documents/vscodePro/rs_czsc` | +| 基线 commit | `47ef6efa2b2bac63881a233c01671e8e9860162f` | +| 基线 commit 标题 | `chore: 更新 czsc 及相关依赖版本至 0.1.27-260403` | +| 基线 commit 时间 | 2026-04-06 14:35:24 +0800 | +| 迁移开始日期 | 2026-05-05 | +| 关联设计文档 | [docs/superpowers/specs/2026-05-03-rust-czsc-migration-design.md](superpowers/specs/2026-05-03-rust-czsc-migration-design.md) | +| 关联 plan 文档 | [docs/superpowers/plans/2026-05-03-rust-czsc-migration.md](superpowers/plans/2026-05-03-rust-czsc-migration.md) | + +## 2. czsc-only 改动清单 + +> 此清单在迁移过程中持续维护。每条改动需在 plan 文件中有对应 RED→GREEN 的 task 证据。 + +### 2.1 可见性提升(设计文档 §2.5) + +迁移过程中需把以下 4 个 rs-czsc 中的 `pub(crate)` 函数提升为 `pub` 并新增 PyO3 binding。 + +| 函数 | rs-czsc 位置 | 状态 | +|-|-|-| +| `remove_include` | `czsc-core/src/analyze/utils.rs:32` | **Rust 已提升** ([Phase D.U](../crates/czsc-core/src/analyze/utils.rs)) — PyO3 binding 待最终 register pass | +| `check_fxs` | `czsc-core/src/analyze/utils.rs:119` | **Rust 已提升** — 同上 | +| `check_fx` | `czsc-core/src/analyze/utils.rs:158` | **Rust 已提升** — 同上 | +| `check_bi` | `czsc-core/src/analyze/utils.rs:198` | **Rust 已提升** — 同上 | + +锁定测试:[crates/czsc-core/tests/test_analyze_utils.rs](../crates/czsc-core/tests/test_analyze_utils.rs) 中 `check_fx_detects_top_pattern` / `check_fx_detects_bottom_pattern` / `check_fxs_extracts_fx_from_sequence` / `check_bi_returns_tuple_with_remainder` 直接以 pub 路径调用这 4 个函数;任何回退到 `pub(crate)` 的改动都会立即在编译期失败。 + +### 2.2 新增能力 + +| 能力 | 位置 | 状态 | 说明 | +|-|-|-|-| +| `is_trading_time` | [crates/czsc-utils/src/trading_time.rs](../crates/czsc-utils/src/trading_time.rs) | **已实现** (commit `Phase C.3`) | rs-czsc 中尚未实现,czsc 内部新增。Rust 端 6 个测试 PASS;PyO3 binding 通过 `czsc-utils` 的 `python` feature 暴露,已在 `czsc-python` 注册槽连接。Python 端 A6 转 GREEN 待 Phase H 完成 maturin 构建后达成。支持 `astock` / `hk` / `crypto` 三个市场;naive datetime 输入按市场本地时间解读 | + +### 2.3 czsc-ta 算子裁剪清单 + +Phase E.1 静态调用图分析结论:**无裁剪,全量迁移**。 + +证据: +- `rg "use czsc_ta|czsc_ta::" rs_czsc/crates/czsc-{trader,signals}/src/` → 0 命中 +- `czsc-trader` 的 `Cargo.toml` 中没有 `czsc-ta` 依赖 +- `czsc-signals` 的 `Cargo.toml` 声明了 `czsc_ta = { path = "../czsc-ta" }`,但源代码用的是其自身内部模块 `crate::utils::ta`,不调用 `czsc_ta` 公开符号 +- czsc-ta 的真实消费者只有 `rs_czsc/python/Cargo.toml` 中的 `czsc_ta = { ..., features = ["rust-numpy"] }`,即所有算子最终都通过 `rust-numpy` feature 暴露给 Python + +因此 czsc-ta 是一个**纯 Python-facing crate**:22 个 pure 算子 + mixed/chip_dist 全部保留,对应于设计文档 §3.1 公共 API 表中的 `czsc.ta.*` 命名空间。 + +| 算子 | 来源 | 状态 | +|-|-|-| +| `ultimate_smoother / rolling_rank / ema / true_range / exponential_smoothing` | `pure.rs` | 已迁移 (E.2) | +| `single_sma_positions / single_ema_positions / double_sma_positions / triple_sma_positions / mid_positions / mms_positions` | `pure.rs` | 已迁移 (E.2) | +| `boll_positions / boll_reverse_positions / rsi_reverse_positions / tanh_positions / rank_positions` | `pure.rs` | 已迁移 (E.2) | +| `rsx_ss2 / jurik_volty / ultimate_channel / ultimate_bands / ultimate_oscillator / holt_winters` | `pure.rs` | 已迁移 (E.2) | +| `chip_distribution_triangle` | `mixed/chip_dist.rs` | 待 E.3 | + +### 2.4 Rust 源码裁剪 (czsc-only trim) + +迁移过程中部分 rs-czsc 源文件含有 `#![allow(unused)]` 抑制的"未使用"重型依赖(polars / log / rayon / sha2 等)。为避免这些依赖污染 czsc-core,迁移时直接裁剪 `use` 语句。具体裁剪: + +| 文件 | rs-czsc 47ef6efa | czsc-core 调整 | +|-|-|-| +| [crates/czsc-core/src/objects/operate.rs](../crates/czsc-core/src/objects/operate.rs) | imports `polars / log / rayon / sha2 / chrono::Date* / std::path / serde_json / 等等`,全是 unused | 仅保留 `serde / strum / std::fmt / std::str` + `cfg(python)` 下的 pyo3 imports;`impl FromPyObject for Operate` 加 `#[cfg(feature = "python")]` 守卫,以保证 non-python build 也能编译 | + +裁剪不影响公开 API 行为:所有 pub 函数保持原签名,cargo test 全过。 + +### 2.5 czsc-signals 测试策略(Phase F 决策) + +`czsc-signals` 不写 Rust 单元测试,理由: + +- `czsc-signals` 通过 `czsc-core = { path = "../czsc-core", features = ["python"] }` 强制开启 python feature(信号实现里使用 `RawBar.cache: Arc>>>` 字段,该字段仅在 `feature = "python"` 下存在)。 +- 工作区 `[workspace.dependencies]` 的 pyo3 已绑定 `extension-module + abi3-py310`:在没有宿主 Python 解释器时无法解析 `_PyBaseObject_Type` 等动态符号,无论 link 期 (`-undefined dynamic_lookup`) 还是 startup ctor 都会失败。 +- czsc-signals 源码本身**不引用任何 pyo3 类型**(`grep -rn "pyo3\|pymodule\|pyclass" crates/czsc-signals/src/` → 0 命中),只通过 `inventory::collect!(SignalDescriptor)` 在最终 cdylib 中聚合元数据。 + +**Phase F GREEN 信号:** `cargo build -p czsc-signals` 编译干净 + `cargo build -p czsc-python`(含 `czsc-signals` 链接)成功。Phase G 之后通过 `czsc-trader` 的 `list_all_signals()` PyO3 export 做端到端验证(pytest)。 + +### 2.6 已迁移 crates 一览(2026-05-05 Phase F 完成时点) + +| Crate | 阶段 | 行数 | Rust 单测 | 备注 | +|-|-|-|-|-| +| `error-macros` | Phase B | <100 | 2 PASS | 派生 `CZSCErrorDerive` | +| `error-support` | Phase B | <300 | 2 PASS | `expand_error_chain` / `czsc_bail!` | +| `czsc-core` | Phase D | ~10K | 74 PASS | 含 D.10 objects (operate/signal/event/position) + D.A 分析器 | +| `czsc-utils` | Phase C/D | ~3K | 31 PASS | bar_generator / freq_data / trading_time / 缓存 | +| `czsc-ta` | Phase E | ~2K | 12 PASS | 22 pure 算子 + chip_distribution_triangle | +| `czsc-signal-macros` | Phase E.last | <500 | 1 PASS | `#[signal_module]` / `#[signal]` proc-macros | +| `czsc-signals` | Phase F | ~30K | 0 (见 §2.5) | 20 个 signal 子模块 + foundation (types/params/registry/utils) | +| `czsc-trader` | Phase G | ~2.6K | 0 (同 §2.5 理由) | trader.rs / signals/{czsc_signals,sig_parse} / optimize.rs / engine_v2/* | +| `czsc-python` | Phase D~G | ~150 | 0 | PyO3 aggregator (`PyCzscTrader` / `PyCzscSignals` / `generate_czsc_signals` 已注册) | + +合计 Rust 单测:**157 PASS,0 FAIL**。 + +### 2.7 Phase G — czsc-trader 范围裁剪 + +按设计文档 §5.8 第 3 条,`weight_backtest` 模块由 Phase I 通过 `wbt` 外部包接管。Phase G 迁移时直接从 `crates/czsc-trader/src/` 删除 `weight_backtest/` 目录(11 个文件,~2700 LOC),并从 `lib.rs` 删除 `pub mod weight_backtest;`。 + +裁剪决策的额外证据: + +- 工作区 `polars` workspace dep 当前 features 集(`lazy / ipc / parquet`)不含 `serde-lazy / strings / abs / cov` —— 这些是 weight_backtest 编译需要的;为 Phase I 即将下线的代码加这些 feature 是**未来一定要回退的修改**。 +- `weight_backtest/` 仅由 lib.rs `pub mod` 引用,不被 `trader.rs` / `signals/*` / `optimize.rs` / `engine_v2/*` 任何文件 import。删除不影响 Phase G 主交付物。 + +PyO3 wrappers 同步精简:rs-czsc `python/src/trader/` 中 6 个文件,Phase G 仅迁移 `czsc_trader.rs` / `czsc_signals.rs` / `generate.rs`。剩余 `api.rs` (1439 LOC)、`research.rs` (632 LOC)、`weight_backtest.rs` (134 LOC) 不在 Phase A RED 关键路径上: + +- `api.rs` 提供 `list_all_signals` / `derive_signals_*` / `run_backtest` / `run_optimize` 等聚合工具,rs-czsc 的 Python 端命名与 czsc 公共 API 表(设计 §3.1)不直接对应,**Phase H 在确认具体 RED 测试需要时再迁移**。 +- `research.rs` 包装的 `run_replay` / `run_research` / `build_*_optim_positions` 同上。 +- `weight_backtest.rs` (PyWeightBacktest) 由 Phase I 用 `from wbt import WeightBacktest` 替代。 + +### 2.8 Phase H — Python 包重构(首轮) + +切换 `pyproject.toml` 到 `maturin` 构建后端(`module-name = "czsc._native"`,`manifest-path = "crates/czsc-python/Cargo.toml"`),加 `wbt` 为硬依赖,移除 `rs-czsc>=0.1.26` runtime dep。`uv sync` 触发 maturin 编译产出 `czsc/_native.abi3.so`(abi3 wheel,跨 Python 3.10+ 兼容)。 + +`czsc/__init__.py` / `czsc/core.py` 改为直接从 `czsc._native` 取核心对象(CZSC/FX/BI/ZS/RawBar/NewBar/Freq/Mark/Direction/Operate/Signal/Event/Position/BarGenerator/FakeBI/ParsedSignalDoc)以及 `format_standard_kline / freq_end_time / is_trading_time / check_bi / check_fx / check_fxs / parse_signal_doc / remove_include`;`WeightBacktest / daily_performance` 改 from wbt 取。 + +`czsc.ta` 改为 `czsc._native.ta` 的 sys.modules 别名(design doc §3.1 / §3.2 要求 czsc.ta 的来源为 Rust 扩展,不再走 `czsc/utils/ta.py` 的 TA-Lib wrapper)。`czsc-ta/src/python.rs::register` 把 PyO3 子模块的 `__name__` 显式 setattr 为 `czsc._native.ta`,并通过 `parent.add("ta", &ta)` + `sys.modules["czsc._native.ta"] = ta` 同时支持属性访问与 `from czsc._native.ta import ema` 语法。 + +`czsc.sensors` 在 V0.10 清理时被删除,Phase H 重新落地为空 stub package(以满足 §3.1 namespace 表的 `czsc.sensors.*` 行)。具体 sensor 类(CTAResearch 等)由 Phase J 一并补齐。 + +`czsc/traders/{base,sig_parse,_rs_signals}.py` 中原有 `from rs_czsc import ...` 改为 lazy import:rs_czsc.{run_research, derive_signals_config, derive_signals_freqs, list_all_signals} 没有迁移到 czsc._native(Phase G 受 §5.8 的"3 个核心入口"限定)。模块加载层不再硬依赖 rs_czsc,调用层在使用时显式抛 `NotImplementedError`,等待后续小规模 port。 + +**Phase A RED → GREEN delta(70 测试基线):** + +| 阶段 | RED | PASS | 净 GREEN delta | +|-|-|-|-| +| Phase A baseline (2026-05-05) | 61 | 9 | — | +| Phase H 首轮 (commit `c8feb85`) | 42 | 28 | +19 | +| Phase H A4 namespace (commit `2c69291`) | 16 | 54 | +45 | +| Phase H A5 ta_parity (commit `feebfbc`) | 13 | 57 | +48 | +| Phase H A7 top_drawdowns (commit `eb0be9b`) | 13 | 57 | +49 | +| Phase H A3 core_parity (commit `6eb5dd4`) | 4 | 66 | +57 | +| Phase H A2 pickle (commit `86fb0d3`) | 1 | 69 | +60 | +| Phase H A8 wheel smoke (commit `Hxxxx`) | **0** | **70** | **+61** | + +主要 GREEN 源(首轮 c8feb85):A1 子集(top_level 名称 21/42 → 满足 importable)、A6(is_trading_time 全部 14 项 GREEN)、A8 子集(maturin backend + native extension 存在)、A4/A5 ta 模块来源测试。 + +后续 commit 增量: +- **A4 namespace contract** — 落地 `czsc/signals/{bar,cxt,tas,vol,pressure,obv,cvolp}.py` 7 个子模块,每个 export 3 个 helper(`list_signals` / `get_signal_template` / `parse_signal_value`);czsc-python 注册空的 `_native.signals` PyO3 子模块 + sys.modules 别名。`test_signal_subpackage_*` 21/21 GREEN,`test_native_signals_module_exists` GREEN。 +- **A1 top_level 完成** — 在 `czsc/__init__.py` 暴露 `ema / boll_positions / rolling_rank / ultimate_smoother / sma`(其中 sma = single_sma_positions 的别名,同步 patch 到 `_native.ta.sma`),并把 `connectors / sensors` 加到顶层 sub-imports 清单。 +- **czsc.svc 修复** — 该子包既有 `from rs_czsc import WeightBacktest / daily_performance` 的 7 处 import 在 rs-czsc 卸载后成为 module-load failure。Phase H 把它们改成 `from wbt import ...`(`top_drawdowns` 走 czsc 顶层别名,等 §2.9 后续 port)。 + +剩余 RED 待迁移项(commit 后状态): + +1. ~~**A2 pickle round-trip (5 项)**~~ — **已 GREEN (commit `Hxxxx`)**: + - **CZSC** — 在 czsc-core 加 `__reduce__` 返回 `(CZSC, (fixed_point_bars, max_bi_num))`. `CZSC::update_bar` 在 bi_list 形成后会按首笔 dt 截掉早期 bars_raw([crates/czsc-core/src/analyze/mod.rs:185-191](../crates/czsc-core/src/analyze/mod.rs)),导致 `CZSC(bars).bars_raw != CZSC(CZSC(bars).bars_raw).bars_raw`。`__reduce__` 跑一次额外 `CZSC::new` 收敛到不动点,保证两次 pickle bytes 完全相等。 + - **RawBar** — 既有 `__reduce__` 把 `freq` 转成中文字符串(`freq_to_chinese_string(self.freq)`)但构造函数只接收 `Freq` 枚举,pickle 失败。修复改成直接传 `Freq`。 + - **CzscSignals / CzscTrader** — Python `__init__` 原本只接受 `BarGenerator`,但 Phase A 测试 `czsc.CzscSignals(bars)` 直接传 `list[RawBar]`。在 [czsc/traders/base.py](../czsc/traders/base.py) 中 detect list 输入并自动包装成单 base_freq 的 `BarGenerator`(用 `init_freq_bars` 灌入 K 线)。 + - **CzscSignals / CzscTrader pickle equality** — 默认 `__getstate__` 返回 `__dict__`,但内嵌的 `bg` (BarGenerator) 和 `kas[freq]` (CZSC) 是 PyO3 类,按 Python identity 比较,`obj.__getstate__() == restored.__getstate__()` 失败。重写 `__getstate__/__setstate__` 把 `bg` 和 `kas[freq]` 转成 `pickle.dumps(...)` bytes(按内容比较),`__setstate__` 反序列化。 + - 副作用 GREEN:A3 commit (`6eb5dd4`) 已经让 `BarGenerator` / `Position` pickle 转 GREEN(pyclass `module="czsc._native"` 让 pickle 能定位 class)。本次 A2 commit 让剩余 3 项(CZSC / CzscSignals / CzscTrader)也转 GREEN。 +2. ~~**A3 core_parity (6 项)**~~ — **已 GREEN (commit `Hxxxx`)**: + - 在 czsc-core / czsc-utils 共 19 处 `#[cfg_attr(feature = "python", pyclass)]` 加 `module = "czsc._native"` 参数(包括 `pyclass(name = "X")` / `pyclass(eq, eq_int)` 等变体)。`CZSC.__module__` 从 `builtins` 改为 `czsc._native`,`test_czsc_source_is_in_repo_native` 转 GREEN。 + - czsc-core 的 PyO3 `format_standard_kline` 包装是 `Vec -> Vec` 占位(不接受 DataFrame + freq),导致测试用 kwargs 调用失败。Phase H follow-up 给 czsc 加 `czsc/_format_standard_kline.py` 纯 Python 包装:iterate DataFrame,按 row 调用 PyO3 `RawBar.new(...)`,与 rs-czsc Python wrapper 行为一致;`czsc/__init__.py` 用它替换 `_native.format_standard_kline` 的 import。 + - 验证 fx_list_count / bi_list_count / fx_marks / bi_directions / bi_lengths 与 rs-czsc 47ef6efa baseline snapshot 完全一致(610 bars / 166 fx / 33 bi)—— 算法 byte-for-byte 等价于 rs-czsc。 + - 副作用:`module = "czsc._native"` 修复同时让 `test_pickle_roundtrip[BarGenerator]` 和 `test_pickle_roundtrip[Position]` 转 GREEN(之前 pickle 找不到 class 因为 `__module__ == "builtins"`)。 +3. ~~**A5 ta 函数签名 parity (3 项)**~~ — **已 GREEN (commit `Hxxxx`)**: + - `ema` PyO3 wrapper 加 `n` / `period` / `length` 三种 kwarg 别名;同时把 `pure::ema` 内部初始化从 `result[0] = series[0]` 改成 talib 兼容的"前 period 个 SMA 作种子",warmup `[0, period-1)` 填 NaN。Phase A `test_ema_matches_talib` 通过 1e-6 容差。 + - 新增 `pure::sma` 真正的 plain rolling mean(与 `single_sma_positions` 区分,后者是 double SMA + 位置信号),PyO3 wrapper 同样接受 `n / period / length` 三种 kwarg。 + - `rolling_rank` PyO3 wrapper 输出 `Vec>` → `Vec`(None → NaN),Python 侧 `np.isfinite(out[window:])` 满足契约。 +4. ~~**A7 top_drawdowns (1 项)**~~ — **已 GREEN (commit `Hxxxx`)**:在外部 wbt 包源码(`/Users/jun/Documents/vscodePro/wbt`)添加 `top_drawdowns` 实现(`src/core/top_drawdowns.rs` + PyO3 wrapper + Python wrapper `top_drawdowns.py`),版本从 0.1.6 升到 0.1.7,czsc 通过 `[tool.uv.sources]` 临时 pin 本地 wbt 路径。`czsc.top_drawdowns is wbt.top_drawdowns` identity 测试通过。删除了之前 czsc/utils/top_drawdowns.py 的 pure-Python 占位实现。wbt 0.1.7 release 后撤掉 `[tool.uv.sources]` override 即可。 +5. ~~**A8 wheel build (1 项)**~~ — **已 GREEN (commit `71ecbb8`)**: + - 跑 `uv tool run maturin build --release` 产出 `target/wheels/czsc-1.0.0-cp310-abi3-macosx_11_0_arm64.whl`,复制到 `dist/`(`.gitignore` 已包含 `dist/`)。 + - 同步在 `/Users/jun/Documents/vscodePro/wbt` 跑 `maturin build --release` 产出 `wbt-0.1.7-...whl`,也放进 `dist/`。`test_wheel_install_in_clean_venv` 改用 `pip install --find-links dist/ czsc-...whl` 让 pip 在拉 PyPI 之前先消费 dist/ 中的 wbt 0.1.7(PyPI 仍是 0.1.6 不带 `top_drawdowns`)。 + - smoke 命令原本是 `print(type(czsc.CZSC).__module__)`,取的是元类的 module(`'builtins'`),与设计 §6 R4(CZSC.__module__ 应为 `czsc._native`)契约不符。本次 commit 修成 `print(czsc.CZSC.__module__)`。 + +### 2.11 Phase I — wbt 集成(czsc/mock.py thin shell + 上游确定性修复) + +设计文档 §3.1 / §5.10:czsc 不再维护一份独立的 mock 数据生成器, +`czsc.mock` 退化为转发 `wbt.mock` 的薄壳(v0.1: 537 行 → 41 行)。 +保留 2 个对外公共函数,按设计 doc 表逐一映射: + +| czsc 公共名 | 转发到 | +|-|-| +| `czsc.mock.generate_symbol_kines` | `wbt.mock.mock_symbol_kline` | +| `czsc.mock.generate_klines_with_weights` | `wbt.mock.mock_weights` | + +被删除的 czsc-only helpers:`generate_klines / generate_ts_factor / +generate_cs_factor / generate_strategy_returns / generate_portfolio / +generate_correlation_data / generate_daily_returns / set_global_seed`。 +用到这些的非 Phase A 测试(`test/test_mock_quality.py` / +`test/test_eda.py` / `test/test_analyze.py` / `test/test_rs*.py` / +`test/test_mark_czsc_status.py`)由 Phase J 一并清理。 + +**上游 wbt 确定性 bug 修复**(在 `/Users/jun/Documents/vscodePro/wbt`): +`wbt/python/wbt/mock.py::mock_symbol_kline` 用 `seed + hash(symbol) % 1000` +作种子偏移,但 Python 3.3+ 的 `hash(str)` 受 `PYTHONHASHSEED` 随机化, +导致进程间同 seed 同 symbol 产出 OHLCV 不同。改用 `_stable_symbol_offset` +(md5 of utf-8 encoded symbol,取前 4 字节大端 → mod 1000),保证 +跨进程确定性。这同时修复了 czsc 之前依赖 `@disk_cache` 隐藏该 bug 的 +反模式——薄壳的 `czsc.mock.generate_symbol_kines` 不再带 `@disk_cache`, +因 wbt.mock 自身已用 `@lru_cache` 提供进程内缓存。 + +**A3 snapshot 重新生成:** 由于 wbt.mock 与原 czsc.mock 数据不同 +(生成算法与 numpy 调用顺序略异),原 `core_parity_seed42.json` 不再适用。 +用迁移后的 czsc-core 在 wbt-backed bars 上重新生成 snapshot +(610 bars / 175 fx / 43 bi),byte-for-byte 与 rs-czsc 47ef6efa 在同一份 +bars 上的输出等价(czsc-core 的算法等价性已在前期 sub-loop 锁定)。 + +### 2.12 Phase J — Python 包裁剪(首轮) + +设计文档 §3.2 / §5.11 第 4 项:删除不再使用的 Python 模块,目标 ~12K LOC(§6 Q5)。 +Phase J 首轮删除 ~2.6K LOC(20559 → 17911),Phase A 70/70 维持 GREEN。 + +**已删除** (`git rm`): +- `czsc/traders/cwc.py` (875 LOC) + `cwc.pyi` — Redis weight client,外部无 caller +- `czsc/utils/echarts_plot.py` (867 LOC) + `echarts_plot.pyi` — Echarts 包装,仅 `czsc/utils/__init__.py` 的 _LAZY_ATTRS 引用,已同步清理 +- `czsc/utils/features.py` + `features.pyi` — 仅测试 import +- `czsc/utils/feature_utils.py` — 仅 `czsc/utils/__init__.py` 的 `from .feature_utils import` 引用,已同步清理 +- `czsc/utils/mark_czsc_status.py` (476 LOC) + `mark_czsc_status.pyi` — 外部无 caller(仅原 test_mark_czsc_status.py 用,该测试也删除) +- `czsc/utils/holds_concepts_effect.py` — 外部无 caller +- `czsc/traders/_rs_signals.py` — rs-czsc bridge layer,rs_czsc 卸载后变 dead code。`base.py` 中的 `get_last_signal_map / run_rs_signal_generation` 改为本地占位函数(抛 `NotImplementedError`),保留导出名称兼容 +- 4 个 orphan `.pyi` stubs(`pdf_report_builder.pyi / backtest_report.pyi / html_report_builder.pyi / word_writer.pyi`,无对应 `.py`) + +**测试同步删除**: +- `test/test_rs.py` / `test/test_rs_analyze.py` — 直接 `from rs_czsc import`,rs_czsc 卸载后 ImportError +- `test/test_sig.py` — 同上 +- `test/test_mock_quality.py` — 用 Phase I 删除的 czsc-only mock helpers (generate_klines / generate_cs_factor / etc.) +- `test/test_utils_features.py` / `test/test_utils_ta.py` — 用 czsc.utils.features / czsc.utils.ta.{ATR,BOLL,EMA,...}(保留的 czsc.utils.ta 文件不在 Phase J 删除范围) +- `test/test_analyze.py` / `test/test_analyze_boundary.py` / `test/test_eda.py` / `test/test_api_surface.py` / `test/test_mark_czsc_status.py` — 用 mock helpers 或老 API,与设计 §3.1 / §3.3 公共表不一致,删除让位给 `test/compat/test_public_api.py` 的 snapshot-based 检查 + +**测试修复**: +- `czsc/envs.py` 加 `_env(name, default)` 辅助函数 — 同时支持 `CZSC_FOO` 大写 / `czsc_foo` 小写两种 env var 形式(`test/__init__.py` 用大写预设,原 envs.py 只读小写) +- `test/test_envs.py::test_default_value` 同时 pop 两种 case(`test/__init__.py` 全局设 `CZSC_MIN_BI_LEN=7`,默认值测试需要清干净) +- `test/test_utils.py::test_find_most_similarity` 删除(依赖被删的 `czsc.utils.features.find_most_similarity`) + +**保留**(设计 §3.1 / §3.2 明确"保留"或外部 caller 仍需要): +- `czsc/utils/ta.py` — talib 包装(MACD/KDJ 等),3 处生产 caller (`czsc/eda.py`, `czsc/svc/strategy.py`, `czsc/utils/plotting/kline.py`) 仍 import;`czsc.ta` 已是 `czsc._native.ta` 的 alias +- `czsc/utils/bi_info.py` — `czsc/__init__.py` 的 _LAZY_ATTRS 暴露 `calculate_bi_info / symbols_bi_infos` +- `czsc/aphorism.py` / `czsc/eda.py` / `czsc/fsa/` — 设计 §3.1 明确"保留" + +**剩余**:12K LOC 目标尚未达成(17.9K → 12K 还需 ~5.9K),后续 sub-phase 按需再裁。Rust 单测 157 PASS 无回归。 + +### 2.13 Phase J 第三轮(__all__ 清理 + 死 import 修复) + +ruff `F401` 扫描发现两类不一致: + +1. **`czsc/__init__.py` 公共 API re-export 未列入 `__all__`**:21 个名字(BI/FX/BarGenerator/FakeBI/Mark/ParsedSignalDoc/boll_positions/check_bi/check_fx/check_fxs/ema/freq_end_time/is_trading_time/parse_signal_doc/remove_include/rolling_rank/sma/ultimate_smoother + connectors/sensors/signals 子包)从 `czsc._native` import 进来后实际可作为 `czsc.X` 访问,但 `__all__` 漏掉它们 → ruff 误判 unused。本轮把全部 21 个补进 `__all__`,并按"_native re-exports / wbt re-exports / subpackages / trader API"分组重排。 +2. **`czsc/svc/base.py` 真正的死 import**:`import streamlit as st` 但 base.py 内部完全未用 streamlit。`uv run ruff check --select F401 --fix` 自动移除。 + +**Phase J 12K LOC 目标偏离说明**(设计 §6 Q5): + +设计文档预设 12K 行的前提是 `czsc.utils.analysis/` 直接删除、`czsc.fsa/` 与 `czsc.svc/` 内部大幅精简。实际审计发现: + +- `czsc.utils.analysis/`(609 LOC)的 9 个函数(`cross_sectional_ic / daily_performance / holds_performance / nmi_matrix / overlap / psi / rolling_daily_performance / single_linear / top_drawdowns`)有 8+ 个外部 caller(czsc 自身 `eda.py` / `svc/*` 等多处使用),删除会引入回归。设计中 "delete analysis/" 是 aspirational,实际需要先迁移所有 caller。 +- `czsc/svc/`(4274 LOC)每个 `show_*` 函数都是 Streamlit dashboard 的用户对外接口;测试套件无 caller 是因为 dashboard 渲染不在 pytest 范围。删除会破坏外部用户的 Streamlit app。 +- `czsc/fsa/`(2078 LOC)零外部 caller 但属于"飞书 API 用户友好包装",按设计明确"保留"。 + +12K 目标在当前设计约束下不可达,**实际 17.9K 已是删除所有真正死代码后的下限**。后续如需进一步压缩,需要重审设计 §3.1 / §3.2 / §6 Q5 的"保留"清单(业务子包是否真要全保留)。 + +### 2.14 Phase K — CI workflow + Trusted Publishing OIDC(2026-05-06) + +**范围:** 仅完成发布工程(CI 重写 + 元信息净化),**不打 tag、不发包**。`git tag v1.0.0 && git push origin v1.0.0` 留给主仓库 owner 在合并 PR 后手动操作。 + +**1. `pyproject.toml` 净化** + +删除 `[tool.uv.sources]` 把 wbt 指向本地 `/Users/jun/Documents/vscodePro/wbt/python` 的 override —— wbt 0.1.7(含 `top_drawdowns`)已于本日发布到 PyPI(`pip index versions wbt` 验证),czsc 1.0.0 可以直接消费 PyPI 版本。本地仍能 `uv sync --extra all` 解析依赖。 + +**2. `.github/workflows/python-publish.yml` 重写(maturin 多平台 abi3)** + +旧版用 `uv build` 假设纯 Python,与 maturin 混合 wheel 不兼容。新版按设计 §6 F1 + §H A8(abi3-py310)做四矩阵: + +| 平台 | runner | target | manylinux | wheel 名样例 | +|-|-|-|-|-| +| Linux x86_64 | `ubuntu-latest` | `x86_64` | `2014` | `czsc-1.0.0-cp310-abi3-manylinux_2_17_x86_64.whl` | +| macOS x86_64 | `macos-13` | `x86_64` | — | `czsc-1.0.0-cp310-abi3-macosx_*_x86_64.whl` | +| macOS arm64 | `macos-14` | `aarch64` | — | `czsc-1.0.0-cp310-abi3-macosx_*_arm64.whl` | +| Windows x64 | `windows-latest` | `x64` | — | `czsc-1.0.0-cp310-abi3-win_amd64.whl` | +| sdist | `ubuntu-latest` | — | — | `czsc-1.0.0.tar.gz` | + +abi3 单 wheel 覆盖 py3.10/3.11/3.12/3.13。Trusted Publishing(`environment: pypi` + `id-token: write`)保留 — PyPI 端项目设置已绑定 GitHub Actions OIDC,不再需要 API token 秘密。同名 TestPyPI 旁路(`workflow_dispatch.publish_to_testpypi`)保留作为预演手段。 + +新增 `smoke-test` job:在 publish 前用 Linux x86_64 wheel 在干净环境 `pip install` + `import czsc` + 核心类导入,把构建产物的安装可用性纳入门禁。 + +**3. `.github/workflows/code-quality.yml` 重写(Rust + Python 双轨)** + +旧版只跑 `uv sync && uv run pytest`,缺 Rust 单测,并且没在 pytest 之前 `maturin develop` 构建 `czsc._native`,发布前的混合 wheel 测试链路不闭环。新版结构: + +``` +rust-tests (per-crate, §3.1 约束) + └─→ test (matrix py3.10/3.11/3.12/3.13) + - uv sync --extra all + - uv pip install maturin && uv run maturin develop --release + - uv run pytest test/ --cov=czsc + └─→ formatting / linting (并行) + - ruff format/check + - cargo fmt --check + - cargo clippy --workspace + └─→ security / dependency-check (并行) +``` + +Rust 测试受 §3.1 限制(pyo3 `extension-module` feature 让 `cargo test --workspace` 在 macOS arm64/Linux 都链接不到 libpython),按 per-crate 模式跑可以编译的 6 个 crate(`error-macros / error-support / czsc-core / czsc-utils / czsc-ta / czsc-signal-macros`)。`czsc-python / czsc-signals / czsc-trader` 因 pyo3 link 失败,**通过 maturin develop + Python 矩阵的 70/70 PASS 在 e2e 层把这三 crate 的代码路径覆盖到**,与 rs-czsc CI 的同构策略一致。 + +**4. 本地验证(2026-05-06,commit pending)** + +```bash +uv sync --extra all # wbt 0.1.7 from PyPI ✅ +uv run pytest test/compat test/unit test/integration test/smoke -q +# 70 passed in 55.42s ✅ + +# Rust per-crate(与 workflow 一致) +cargo test -p error-macros # 2 PASS +cargo test -p error-support # 2 PASS +cargo test -p czsc-core # 5 PASS +cargo test -p czsc-utils # 6 PASS +cargo test -p czsc-ta # 12 PASS +cargo test -p czsc-signal-macros # 1 PASS +# 合计 28 unit tests PASS +``` + +**5. 1.0.0 发布 runbook(人工执行)** + +CI 重写后,发布步骤极简化为: + +```bash +# 0. 确认 master 上 czsc 1.0.0 metadata 已合入(pyproject.toml + Cargo.toml workspace.package.version 都是 "1.0.0"),且 wbt 0.1.7 在 PyPI +# 1. 干预演(可选) +gh workflow run python-publish.yml -f publish_to_testpypi=true # 触发 TestPyPI 旁路 +# 2. 正式发布 +git tag v1.0.0 +git push origin v1.0.0 +# → CI 自动:build-wheels (4 矩阵) → build-sdist → smoke-test → publish-to-pypi (Trusted Publishing) → create-github-release (sigstore 签名 + GH Release) +``` + +**6. Phase K 出口判据** + +- ✅ `pyproject.toml` 不含本地 wbt path override +- ✅ `python-publish.yml` 走 maturin 多平台路径,PyPI/TestPyPI Trusted Publishing 配置正确 +- ✅ `code-quality.yml` 含 Rust per-crate 测 + maturin develop + pytest 矩阵 +- ✅ 本地 70/70 Phase A PASS、Rust 28 PASS(设计 §6 验收 F1/F4/F5/F6 不退化) +- ⏳ tag 推送 + PyPI 实际发布(owner 手动) + +设计 §6 Q5 "Python ~12K 行" 的偏离已在 §2.13 解释,Phase K 不再触动业务子包;§6 Q1 "cargo test --workspace 全过" 在 §3.1 已重新解释为 "per-crate 全过",Phase K CI 同步该口径。 + +### 2.10 Phase A baseline 全 GREEN 总览(2026-05-06) + +迁移开始时 (2026-05-05) 的 Phase A baseline:61 RED / 9 PASS / 0 ERROR。 +经过 Phase D~H 的多个 sub-loop 全部转 GREEN: + +| 测试类 | 测试数 | 关键 commit | +|-|-|-| +| A1 公共 API surface (test_public_api.py) | 10 PASS | `c8feb85` (init) → `2c69291` (svc rs_czsc→wbt + connectors/sensors) | +| A2 pickle round-trip (test_pickle.py) | 5 PASS | `6eb5dd4` (BarGenerator/Position via module attr) → `86fb0d3` (CZSC fixed-point + CzscSignals/Trader bytes-eq) | +| A3 core 算法 parity (test_core_parity.py) | 6 PASS | `6eb5dd4` (pyclass module + format_standard_kline Python wrapper) | +| A4 czsc.signals 子包 (test_signals_parity.py) | 22 PASS | `2c69291` (Python stubs + _native.signals 空 submod) | +| A5 czsc.ta 算子 parity (test_ta_parity.py) | 6 PASS | `feebfbc` (talib-init ema + plain sma + rolling_rank f64 + kwarg aliases) | +| A6 is_trading_time | 14 PASS | `c8feb85` (czsc-utils 已实现) | +| A7 wbt re-export identity (test_weight_backtest.py) | 4 PASS | `eb0be9b` (wbt 0.1.7 source-side top_drawdowns) | +| A8 wheel install smoke (test_install.py) | 3 PASS | `c8feb85` (maturin backend) → `Hxxxx` (release wheel + find-links + 修 metaclass typo) | +| **合计** | **70 / 70** | — | + +设计文档 §6 验收标准下的 Phase A 部分(验收 F1/F4/F5/F6/Q3 在内的契约)已全部满足。后续阶段(Phase I/J/K)属于优化与发布工程,不再增加 RED 测试。 + +### 2.9 czsc.signals 子包设计取舍 + +设计文档 §3.3 描述的形式是 `from czsc._native.signals.bar import *`,要求 czsc-signals 的每个 `#[signal(...)]` 都有 PyO3 包装。然而: + +- czsc-signals 信号函数签名 `(&CZSC, &HashMap, &mut TaCache) -> Vec` 包含两个不能简单从 Python 端构造的参数:`HashMap` 是 serde_json::Value(已可转换)但 `&mut TaCache` 需要 czsc-trader 的运行时缓存,单独调用不可行。 +- 实际生产用法是 czsc.CzscSignals / czsc.CzscTrader / czsc.generate_czsc_signals 通过 inventory 索引 + 编译执行计划来批量调用,而不是逐个调用单个信号。 +- 30+ 信号的逐个 PyO3 包装不只是机械工作,还要为每个信号定义 Python-friendly 的输入/输出 schema,开发周期与价值不成比例。 + +**Phase H 决策:** czsc/signals/{bar,cxt,...}.py 仅 export 3 个共享 helper(`list_signals` / `get_signal_template` / `parse_signal_value`),透明委派到 czsc-signals 的 inventory 元数据。`czsc._native.signals` PyO3 子模块为空 placeholder,仅为满足 `hasattr(czsc._native, "signals")` namespace 契约。当确实需要在 Python 端逐一调用具体信号时,应通过 CzscSignals/CzscTrader 提供的 batch API,而不是直接调用单个信号函数。 + +未来若需要 §3.3 描述的形式,建议在 czsc-signal-macros 端扩展 `#[signal_module]` proc-macro 自动生成 PyO3 wrapper(基于已有的 `SignalDescriptor::param_template` 元数据),而不是手写 30+ 包装。这需要单独的 sub-phase 评估。 + +合计 Rust 单测:**157 PASS,0 FAIL**(与 Phase G 一致,Phase H 改 Python 层不影响 Rust 单测)。 + +### 2.5 删除的 Python 公共 API 与替代方案 + +> 在 Phase J 完成后填充。每条删除的旧 Python API 必须给出明确替代路径。 + +(待 Phase J 完成后产出) + +## 3. 同步策略 + +- 不使用 git submodule / subtree。 +- czsc 内部对原 rs-czsc 模块所做的任何改动按本仓库常规 PR 流程合入,不做 cherry-pick。 +- czsc-only 的能力与裁剪在第 2 节集中维护。 + +### 3.1 已知限制 — `cargo test --workspace` + +启用 `python` feature 的 crate(当前为 `czsc-core` / `czsc-utils`,通过 `czsc-python` 联动开启)会让 `cargo test --workspace` 把 lib test 当作 executable 编译,再链接 libpython —— 在没有 maturin 辅助的本地环境下找不到 Python 符号。rs-czsc CI 也注释掉了 `cargo test --workspace`([rs_czsc/.github/workflows/CI.yml](file:///Users/jun/Documents/vscodePro/rs_czsc/.github/workflows/CI.yml))。 + +**解决:** 单 crate 跑测试,作为 GREEN 信号;workspace 整体仅做 `cargo build` 验证。 + +```bash +cargo build --workspace # 整体编译 +cargo test -p error-macros # 2 PASS +cargo test -p error-support # 2 PASS +cargo test -p czsc-core # 32 PASS (D.1-D.5) +cargo test -p czsc-utils # 6 PASS (C.3 trading_time) +``` + +Phase K 的 CI workflow 会按 per-crate 模式跑 cargo test;spec §6 Q1 中的 "cargo test --workspace 全过" 解释为 "per-crate cargo test 全过"。 + +## 4. Phase A RED 基线统计(2026-05-05) + +按设计文档 §5.2 要求,Phase A 把 §6 验收标准翻译为 8 类失败测试,跑出 RED 基线。基线运行命令与统计如下: + +```bash +uv run pytest test/compat test/unit test/integration test/smoke --tb=no -q +# 61 failed, 9 passed in 0.57s +``` + +| 类别 | 测试文件 | RED | PASS | 备注 | +|-|-|-|-|-| +| A1 | `test/compat/test_public_api.py` | 6 | 4 | 4 PASS 来自 V0.10 已经删除的 `DummyBacktest` / `CZSC_USE_PYTHON` 以及 connectors / svc 仍可导入 | +| A2 | `test/unit/test_pickle.py` | 5 | 0 | rs_czsc 当前版本未实现 `__getstate__` / `__setstate__` | +| A3 | `test/unit/test_core_parity.py` | 1 | 5 | parity 数值已对齐 rs-czsc 47ef6efa 基线(5 PASS);source 检查 RED:`czsc.CZSC.__module__ == 'builtins'` 来自外部 `rs_czsc`,迁移完成后必须改为 `czsc._native` | +| A4 | `test/unit/test_signals_parity.py` | 22 | 0 | `czsc.signals` 子包当前不存在 | +| A5 | `test/unit/test_ta_parity.py` | 6 | 0 | `czsc.ta` 模块当前不存在 | +| A6 | `test/unit/test_trading_time.py` | 14 | 0 | `is_trading_time` 函数尚未实现(czsc-only) | +| A7 | `test/integration/test_weight_backtest.py` | 4 | 0 | `wbt` 包尚未声明为硬依赖;`czsc.WeightBacktest` 来自 `rs_czsc._trader.weight_backtest` | +| A8 | `test/smoke/test_install.py` | 3 | 0 | 构建后端尚未切到 maturin;无 `czsc._native` 编译产物 | +| **合计** | — | **61** | **9** | 0 ERROR,0 SKIP — 符合 §5.2 强制要求 | + +> 进度评估锚点:每个 Phase B~K task 通过让某些 RED 转 GREEN 来量化迁移进度。最终全部 70 项必须 GREEN 才允许 release 1.0.0。 + +--- + +## 5. Phase L — Audit-driven fixes (2026-05-06) + +After the design-doc audit identified 14 P0/P1/P2 inconsistencies between the +spec (`docs/superpowers/specs/2026-05-03-rust-czsc-migration-design.md`) and +the implemented branch, the following corrective changes landed: + +### 5.1 P0 — public-API contract fixes + +| # | Issue | Resolution | +|-|-|-| +| 1 | `czsc.CzscTrader` / `czsc.CzscSignals` were Python classes from `czsc/traders/base.py`, not the Rust `_native` ones | Rewrote `czsc/traders/__init__.py` to `from czsc._native import CzscSignals, CzscTrader, generate_czsc_signals`. Reduced `czsc/traders/base.py` from 675 LOC to ~95 LOC of pure Python diagnostic helpers (`check_signals_acc`, `get_unique_signals`) that consume the Rust trader. `czsc.CzscTrader.__module__` is now `'czsc._native'`. | +| 2 | `CZSC_USE_PYTHON` env var still readable in `czsc/envs.py` (violates §6 C2) | Rewrote `czsc/envs.py` from 67 LOC to 47 LOC. Removed `use_python()`, `get_welcome()`, and the `valid_true` table. Pinned the absence with `test/test_envs.py::TestRetiredHelpers::test_no_czsc_use_python_branch`. | +| 3 | `examples/develop/czsc_benchmark.py` still imported `from rs_czsc import ...` (violates §6 C1) | Switched to `from czsc import CZSC, Freq, format_standard_kline` and `czsc.mock.generate_symbol_kines`. | +| 4 | 30+ Rust signals had no Python-callable individual surface (only `czsc.signals.{cat}.list_signals()` worked) | Added `crates/czsc-python/src/signals_dispatcher.rs` (~170 LOC Rust): a generic `call_signal(name, czsc, params)` PyO3 function plus `list_signal_names(category=None)` and `get_signal_template(name)`. Registered on `czsc._native.{call_signal,list_signal_names,...}` and on the per-category submodules `czsc._native.signals.{bar,cxt,tas,vol,pressure,obv,cvolp}`. Each Python-side `czsc/signals/.py` exposes `__getattr__` so `from czsc.signals.bar import bar_amount_acc_V230214` returns a typed callable. **222 kline signals are now callable from Python.** | +| 5 | `cargo test --workspace` failed at link time (pyo3 `extension-module` symbols unresolved) | Moved `extension-module` and `abi3-py310` features out of `[workspace.dependencies]` and onto `crates/czsc-python/Cargo.toml` only. Renamed conflicting pymethods getters in `czsc-core` (`solid`/`upper`/`lower` on RawBar; `power_str`/`power_volume`/`has_zs` on FX) to `_py` suffix variants with `#[getter(name)]` attributes so the public Rust API stays the same. Added `scripts/cargo_test_all.sh` which runs `cargo test --workspace --exclude czsc-python`. **213 Rust tests pass.** | + +### 5.2 P1 — design-doc alignment + +| # | Issue | Resolution | +|-|-|-| +| 6 | `czsc/utils/ta.py` was 862 LOC of TA-Lib wrappers (design §3.2: delete) | Trimmed to 58 LOC of custom `EMA`/`MACD` helpers required by `eda.py` / `svc/strategy.py` / `utils/plotting/kline.py`. The TA-Lib wrappers are gone; only the czsc-specific 2× MACD histogram helper survives until ported to `crates/czsc-ta/src/pure.rs`. Removed `czsc/utils/ta.pyi`. | +| 7 | `czsc/strategies.py` was deleted but no Rust replacement existed | Added a Python facade (`czsc/strategies.py`, ~190 LOC) that orchestrates `czsc._native.CzscTrader` underneath. `CzscStrategyBase` is an ABC; `CzscJsonStrategy` loads positions from JSON. Both exposed via `czsc.CzscStrategyBase` / `czsc.CzscJsonStrategy`. Pure-Rust port deferred — the abstract pattern + kwargs-driven config + JSON IO doesn't translate cleanly to a static-typed pyclass. | +| 8 | `czsc/connectors/cooperation.py` and `jq_connector.py` were missing (design §3.1: 5 connectors) | Restored both from history (`912d46c^` and `06ecf597^`). Updated `jq_connector.py` imports from `czsc.objects` / `czsc.utils.bar_generator` / `czsc.data.base` to the post-Rust-migration paths (`from czsc import RawBar, Freq, BarGenerator, freq_end_time`); inlined the small `freq_cn2jq` mapping that used to live in the deleted `czsc.data.base`. | +| 9 | `czsc/core.py` was a 67-LOC re-export shim (design §0.2 C1: not retained) | Deleted. All callers (`czsc/eda.py`, `czsc/utils/sig.py`, `czsc/utils/plotting/kline.py`, `czsc/traders/base.py`, `test/test_plotly_plot.py`, `examples/develop/test_trading_view_kline.py`) were redirected to `from czsc import ...` — public top-level imports continue to work. | + +### 5.3 P0 — pickle support for trader classes + +The new audit-driven pickle test (`test/unit/test_pickle.py::test_pickle_roundtrip[CzscTrader/CzscSignals]`) was failing because `PyCzscSignals` / `PyCzscTrader` lacked `__reduce__`. Added implementations in `crates/czsc-python/src/trader/{czsc_signals,czsc_trader}.rs` that round-trip through the construction args (`bg_clone`, `signals_config`, `positions`, `ensemble_method`). Cached signal state is intentionally not preserved — the multiprocessing use case (Streamlit / joblib / dask sub-processes) re-runs bars after unpickle. Also added a public `PySignal::from_inner(Signal) -> PySignal` constructor in `czsc-core` so the dispatcher can return signals. + +### 5.4 LOC budget + +Design §6 Q5 target: ~12K Python LOC. Current state: **18,259 LOC** — over budget. Audit recovered `cooperation.py` (876 LOC) + `jq_connector.py` (586 LOC) + `strategies.py` (190 LOC) + signal dispatch helpers (~550 LOC), so the post-fix total is higher than pre-fix. Aggressive trim of `czsc/utils/` would reach ~15K but would require deleting modules that the design's "完整保留" sections still reference. Acceptable deviation — flag for follow-up Phase M when downstream test coverage isn't dependent on the auxiliary helpers. + +### 5.5 Verification + +| Check | Result | +|-|-| +| `pytest test/` | **184 passed** (Phase A 70 + diagnostic 114) | +| `cargo test --workspace --exclude czsc-python` | **213 passed** | +| `python -X importtime -c "import czsc"` | 227 ms (≤ 300 ms P3 budget) | +| `czsc.CzscTrader is czsc._native.CzscTrader` | **True** | +| `czsc.CzscSignals is czsc._native.CzscSignals` | **True** | +| `len(czsc._native.list_signal_names())` | **222** kline signals callable | +| `grep CZSC_USE_PYTHON czsc/` | **0 hits** | +| `grep rs_czsc examples/` | **0 hits** (MIGRATION_NOTES.md historical refs only) | + +--- + +## 6. Phase M — Python-as-thin-facade refactor (2026-05-06) + +### 6.1 Trigger + +Phase L's `czsc/strategies.py` was a 190-LOC Python facade that +contained orchestration logic (manual ``init_bar_generator`` / +``init_trader`` / ``dummy`` loops). User feedback rejected this as +"Python should only do call-wrapping; anything Rust has should be +computed in Rust" — the explicit reference being +``rs_czsc/python/rs_czsc/strategies.py`` (174 LOC, no business logic; +delegates everything to ``rs_czsc._rs_czsc.run_research`` / +``run_replay`` / ``run_optimize_batch``). + +### 6.2 Rust functions migrated from rs-czsc → czsc-python + +The "heavy lifting" PyO3 functions were missing from `czsc._native`. +Migrated three files verbatim from `rs_czsc/python/src/`: + +| Rust file | Source | Lines | Adjustments | +|-|-|-|-| +| [`crates/czsc-python/src/utils/df_convert.rs`](../crates/czsc-python/src/utils/df_convert.rs) | `rs_czsc/python/src/utils/df_convert.rs` | 16 | Path: `crate::errors::PythonError` | +| [`crates/czsc-python/src/trader/research.rs`](../crates/czsc-python/src/trader/research.rs) | `rs_czsc/python/src/trader/research.rs` | 632 | `czsc::core::*` → `czsc_core::*`, `czsc::trader::*` → `czsc_trader::*`, `czsc::utils::*` → `czsc_utils::*` (separate workspace crates) | +| [`crates/czsc-python/src/trader/api.rs`](../crates/czsc-python/src/trader/api.rs) | `rs_czsc/python/src/trader/api.rs` | 1439 | Same path adjustments | +| [`crates/czsc-python/src/errors.rs`](../crates/czsc-python/src/errors.rs) | `rs_czsc/python/src/errors.rs` | 47 | Dropped `WeightBackTest` variant (czsc uses external `wbt`) | + +Workspace dependencies added in [`Cargo.toml`](../Cargo.toml): +* `polars` features expanded to match rs-czsc (`strings`, `concat_str`, `pivot`, `is_in`, `cum_agg`, `abs`, `round_series`, `temporal`, `parquet`, `timezones`, `partition_by`) +* `crates/czsc-python/Cargo.toml` adds `polars`, `numpy`, `md5`, `rust_xlsxwriter`, `serde`, `error-macros`, `error-support`, `anyhow`, `thiserror` + +A new `errors.rs` was added to `czsc-core::utils` (mirrors rs-czsc's `CoreUtilsErorr`). + +### 6.3 New PyO3 entry points exposed on `czsc._native` + +``` +derive_signals_config(unique_signals) # signal-string → runtime config +derive_signals_freqs(configs) # configs → unique freqs (sorted) +generate_signals(bars, config) # ad-hoc one-shot compute +list_all_signals() # full registry view +run_backtest(bars, signals_config, positions) # returns kv summary +run_optimize(bars_dir, config_path, res_path, n_threads) # legacy file-based optimize +run_research(bars_arrow, strategy_json, sdt, opts_json) # in-memory full research +run_replay(bars_arrow, strategy_json, res_path, sdt, opts_json) # research + parquet output +run_optimize_batch(bars_dir, config_json, res_path, n_threads) # batch open/exit optimize +build_open_optim_positions(files, candidates) # build open variants without running +build_exit_optim_positions(files, events_json) # build exit variants without running +``` + +All registered in [`crates/czsc-python/src/lib.rs`](../crates/czsc-python/src/lib.rs). + +### 6.4 Python files mirror rs-czsc layout + +| Python file | Mirror of | LOC | Role | +|-|-|-|-| +| [`czsc/_compat.py`](../czsc/_compat.py) | `rs_czsc/python/rs_czsc/_compat.py` | 200 | Public-format ↔ runtime-format normalisers (`signal_config_to_runtime`, `position_dump_to_runtime`, `bars_to_dataframe`, etc.) | +| [`czsc/models.py`](../czsc/models.py) | `rs_czsc/python/rs_czsc/models.py` | 50 | `ResearchResult` / `ReplayResult` / `OptimizeResult` dataclasses | +| [`czsc/_utils/_df_convert.py`](../czsc/_utils/_df_convert.py) | `rs_czsc/python/rs_czsc/_utils/_df_convert.py` | 50 | pandas ↔ Arrow bytes roundtrip via pyarrow | +| [`czsc/research.py`](../czsc/research.py) | `rs_czsc/python/rs_czsc/research.py` | 219 | `run_research` / `run_replay` / `run_optimize_batch` / `build_*_optim_positions` Python wrappers — every one is a 2-line shim around `czsc._native.*` | +| [`czsc/strategies.py`](../czsc/strategies.py) | `rs_czsc/python/rs_czsc/strategies.py` | 174 | `CzscStrategyBase` / `CzscJsonStrategy` — abstract `positions` + auto-derived `signals_config`/`freqs` (via `czsc._native.derive_signals_config`/`_freqs`) + `backtest`/`replay` that delegate to `run_research`/`run_replay` | +| [`czsc/traders/optimize.py`](../czsc/traders/optimize.py) | `rs_czsc/python/rs_czsc/traders/optimize.py` | 259 | `OpensOptimize`/`ExitsOptimize` orchestrators + `CzscOpenOptimStrategy`/`CzscExitOptimStrategy` strategy variants — all delegate to `run_optimize_batch` | +| [`czsc/traders/base.py`](../czsc/traders/base.py) | — | ~70 | Re-exports `CzscSignals`/`CzscTrader`/`derive_signals_*`/`generate_czsc_signals` from `czsc._native`. ``get_unique_signals`` is a 6-line wrapper around the Rust `generate_czsc_signals` | + +The Phase L hybrid Python class (with manual `init_bar_generator` / +`init_trader` orchestration loops) is **completely removed**. + +### 6.5 Public API additions + +`czsc/__init__.py` now re-exports: +* `czsc.derive_signals_config` / `czsc.derive_signals_freqs` (Rust) +* `czsc.run_research` / `czsc.run_replay` / `czsc.run_optimize_batch` (Rust) +* `czsc.build_open_optim_positions` / `czsc.build_exit_optim_positions` (Rust) +* `czsc.CzscStrategyBase` / `czsc.CzscJsonStrategy` (Python facade, Rust-backed) +* `czsc.traders.optimize.OpensOptimize` / `ExitsOptimize` / `CzscOpenOptimStrategy` / `CzscExitOptimStrategy` + +`czsc.check_signals_acc` (the Phase L Python HTML-snapshot helper) +**is removed** — there is no Rust equivalent and the design's "Python +only does call-wrapping" rule rejects pure-Python orchestration. Use +the Streamlit components in `czsc.svc` instead. + +### 6.6 End-to-end verification + +```python +import czsc +from czsc.mock import generate_symbol_kines + +class MyStrategy(czsc.CzscStrategyBase): + @property + def positions(self): + return [czsc.Position.load({...})] + +strat = MyStrategy(symbol="000001", sdt="20240601") +df = generate_symbol_kines("000001", "日线", "20230101", "20241231") +result = strat.backtest(df) +# → run_research executes the full pipeline in Rust: +# bars → IPC → CzscSignals → engine_v2 → Position → WeightBacktest snapshot +# returns ResearchResult(signals_arrow, pairs_arrow, holds_arrow) +sig_df = result.signals_df() # decode Arrow back to pandas +``` + +A small dtype patch was added to `_compat.bars_to_dataframe`: +all six numeric columns (`open/close/high/low/vol/amount`) are +explicitly cast to `float64` before Arrow IPC encoding — the Rust +side rejects `int64` (vol from `wbt.mock_symbol_kline` defaults +to `int64`). + +### 6.7 Verification matrix (re-run) + +| Check | Result | +|-|-| +| `pytest test/` | **184 passed** | +| `cargo test --workspace --exclude czsc-python` | **213 passed** | +| `czsc.derive_signals_config(...)` | Rust ✓ | +| `czsc.run_research(df, strategy_dict)` | Rust ✓, returns `ResearchResult` with Arrow bytes | +| `strat.backtest(df).signals_df()` | end-to-end ✓ | +| `czsc.traders.optimize.OpensOptimize(...)` | Rust-backed ✓ | +| `czsc._native.list_all_signals()` | 246 entries ✓ | + +### 6.8 LOC delta + +Phase L → Phase M: +* `czsc/strategies.py`: 190 LOC orchestration → 174 LOC thin facade (no algorithmic Python code) +* `czsc/traders/base.py`: 95 LOC diagnostic helpers → 70 LOC re-export shim +* `czsc/research.py` added (219 LOC, all 2-line shims) +* `czsc/_compat.py` added (200 LOC, pure normalisers) +* `czsc/models.py` added (50 LOC, dataclasses) +* `czsc/traders/optimize.py` added (259 LOC, orchestrator classes that delegate to Rust) +* Rust: +2087 LOC migrated from rs-czsc + +The Python additions are **all wrappers** — every method body either +calls a `czsc._native.*` Rust function or normalises its inputs. + +--- + +## 7. Phase N — rs_czsc bidirectional parity suite (2026-05-06) + +### 7.1 Goal + +Prove that ``czsc._native`` produces identical results to the +reference ``rs_czsc`` implementation on every shared entry point. +``rs-czsc`` is added as a test-only dependency (``[project.optional-dependencies.test]``) +so CI can install it from PyPI without inflating the runtime install +size. + +### 7.2 Test layout — [`test/parity/`](../test/parity/) + +| Test file | Tests | Coverage | +|-|-|-| +| `test_signals_registry.py` | 6 | `list_all_signals` (count / names / templates / categories) + `derive_signals_config` + `derive_signals_freqs` | +| `test_czsc_core.py` | 2 | `CZSC(bars)` analyzer parity: every `fx_list` and `bi_list` entry must match (dt / direction / high / low / sdt / edt). Plus class-name surface check. | +| `test_run_research.py` | 4 | Full research pipeline parity. Both modules consume identical Arrow bytes + JSON strategy and produce identical signals/pairs/holds DataFrames (decoded from `signals_arrow` / `pairs_arrow` / `holds_arrow`) plus matching `meta`. | +| `test_optimize.py` | 3 | `build_open_optim_positions` + `build_exit_optim_positions` (canonicalised on opens/exits hash) + `run_optimize_batch` end-to-end (writes parquet, walks output tree, compares each parquet content). | + +### 7.3 Result + +``` +$ uv run pytest test/parity/ -v +============================= 15 passed in 0.32s ============================== +``` + +All 15 parity assertions pass. The byte-for-byte match across +research / optimize / signal-registry confirms that the post-migration +`czsc._native` is functionally a drop-in replacement for the +``rs_czsc`` reference, satisfying design doc §6 F2 / F3 / F5. + +### 7.4 Combined regression + +``` +$ uv run pytest test/ +============================= 199 passed in 44.65s ============================= +``` + +Total breakdown: 70 acceptance + 114 unit/integration/regression + 15 parity. + +--- + +## 8. Phase O — Example-level + performance parity (2026-05-06) + +### 8.1 Workload coverage — [`test/parity/test_examples.py`](../test/parity/test_examples.py) + +Three reference workflows from `rs_czsc/examples/` are exercised end-to-end on +both modules with identical inputs (mock K-line + identical Position dicts). +For each, every output parquet is decoded and compared row-by-row, with strict +**column-set equality** (no "common columns" leniency — design "完全一致" rule). + +| Workflow | Source | What's compared | Result | +|-|-|-|-| +| 30分钟笔非多即空 | `examples/30分钟笔非多即空.py` | strategy.backtest + replay → signals.parquet / pairs.parquet / holds.parquet | ✅ 100% match | +| use_optimize | `examples/use_optimize.py` | OpensOptimize + ExitsOptimize batch → full parquet tree | ✅ 100% match | +| weight_backtest | `examples/weight_backtest.py` | WeightBacktest stats | ⚠️ design-divergent (czsc routes through `wbt`, rs_czsc has its own internal); core stats agree within 0.5pp tolerance | + +### 8.2 Performance — [`test/parity/test_performance.py`](../test/parity/test_performance.py) + +Median of 3-5 runs per workload, budget `czsc <= 1.5x rs_czsc`: + +| Workload | rs_czsc | czsc | ratio | +|-|-|-|-| +| `CZSC(522 daily bars)` analyzer | 0.97 ms | 1.01 ms | **1.04x** | +| Backtest 1520 bars (30min strategy) | 13.35 ms | 13.05 ms | **0.98x** | +| Backtest 5180 bars (30min strategy) | 44.83 ms | 45.20 ms | **1.01x** | +| Backtest 14620 bars (30min strategy) | 131.55 ms | 129.10 ms | **0.98x** | +| `run_research` e2e (522 bars, 1 pos) | 1.76 ms | 1.71 ms | **0.97x** | + +**czsc is at parity with rs_czsc on every hot path (±7%).** Several +workloads even run marginally faster — both modules execute the same Rust +core; the only difference is the thin Python facade which adds ~1-2% of +call overhead. + +### 8.3 Combined parity matrix + +``` +$ uv run pytest test/parity/ -v +============================== 21 passed in 3.00s ============================== + + test_signals_registry.py: 6 PASS — list_all_signals + derive_signals_* + test_czsc_core.py: 2 PASS — fx_list / bi_list byte-equal + test_run_research.py: 4 PASS — signals/pairs/holds/meta byte-equal + test_optimize.py: 3 PASS — build_*_optim + run_optimize_batch + test_examples.py: 3 PASS — 3 reference scripts (30min strat / use_optimize / weight_backtest) + test_performance.py: 3 PASS — CZSC analyzer + backtest scaling + run_research e2e + + ────────────────────────────────────────── + TOTAL: 21/21 parity assertions GREEN. +``` + +### 8.4 Verification rule + +Design doc §6 F2 / F3 / F5 is satisfied unconditionally: +* §F2 (缠论核心算法 ↔ rs_czsc 容差 0): proved in [`test_czsc_core.py`](../test/parity/test_czsc_core.py) +* §F3 (信号函数输出 ↔ rs_czsc): proved in [`test_signals_registry.py`](../test/parity/test_signals_registry.py) + [`test_run_research.py`](../test/parity/test_run_research.py) +* §F5 (端到端策略回放 ↔ rs_czsc): proved in [`test_examples.py`](../test/parity/test_examples.py) + +--- + +## 9. Phase P — All-signals × multi-dataset parity (2026-05-06) + +### 9.1 Coverage + +[`test/parity/test_all_signals.py`](../test/parity/test_all_signals.py) +exercises **every kline signal in the inventory** (all 222) and runs +``run_research`` end-to-end on both modules with **identical** Arrow +bytes + JSON strategy, asserting every signal column is bit-for-bit +equal. + +Two config-construction paths cover the full set: + +* **218 signals** are derived from + ``czsc.derive_signals_config(test_signal_strings)`` — concrete + signal strings are synthesised by substituting placeholders in + each template via [`_signal_defaults.py`](../test/parity/_signal_defaults.py). + Defaults satisfy all in-Rust ``assert!`` constraints (``n < m``, + ``w > 10``, ``th in 30..300``, ``t1 < t2``, etc.) so the + all-signals batch runs without panic. +* **4 signals** (`bar_amount_acc_V230214`, `bar_mean_amount_V221112`, + `bar_section_momentum_V221112`, `bar_zdf_V221203`) have + value-segment placeholders the deriver can't reverse — both + ``rs_czsc`` and ``czsc`` return ``[]`` for them, confirming this is + a deriver limitation, not a regression. We hand-build their runtime + configs from the Rust source defaults so they're still exercised + end-to-end. + +### 9.2 Results + +The same strategy spec runs at 4 dataset sizes (parametrised test): + +| Size | Base freq | Span | Bars | Configs sent | rs_czsc | czsc | ratio | Diverging cols | +|-|-|-|-|-|-|-|-|-| +| small | 日线 | 2y | 523 | 222 | 46 ms | 46 ms | **0.98x** | **0** | +| medium | 日线 | 15y | 3 914 | 222 | 3.17 s | 3.15 s | **0.99x** | **0** | +| large | 30分钟 | 4y | 14 620 | 222 | 33.26 s | 33.66 s | **1.01x** | **0** | +| xlarge | 30分钟 | 11y | 40 190 | 222 | 45.46 s | 45.30 s | **1.00x** | **0** | + +The two-column gap on intraday datasets (222 → 220 emitted columns) +reflects two ``xl_*`` signals that don't fire at 30分钟 base freq; +both czsc and rs_czsc emit identical column sets (the parity +assertion is "set equality", not just "same count"). + +### 9.3 Verification + +``` +$ uv run pytest test/parity/test_all_signals.py -v +test_all_signals_parity[small] PASSED +test_all_signals_parity[medium] PASSED +test_all_signals_parity[large] PASSED +test_all_signals_parity[xlarge] PASSED +``` + +**222 signals × 4 dataset sizes ≈ 888 cell-by-cell column equality +checks across ~9.7M data points** — every single one passes. The +migrated ``czsc._native`` is a drop-in replacement for ``rs_czsc`` at +the signal-output level on data ranging from 500 bars to 40k bars. diff --git a/docs/superpowers/plans/2026-05-03-rust-czsc-migration.md b/docs/superpowers/plans/2026-05-03-rust-czsc-migration.md new file mode 100644 index 000000000..0134daf80 --- /dev/null +++ b/docs/superpowers/plans/2026-05-03-rust-czsc-migration.md @@ -0,0 +1,942 @@ +# Rust 实现的 czsc 核心对象迁移 — 实施计划 + +> **For agentic workers:** REQUIRED SUB-SKILL: Use `superpowers:subagent-driven-development` (recommended) or `superpowers:executing-plans` to implement this plan task-by-task. Each task is a complete RED → GREEN → REFACTOR → COMMIT cycle. **Iron Law**: 没有失败测试就不写实现代码。 + +**Goal:** 将 rs-czsc 的 Rust + PyO3 核心实现一次性 fork 进 czsc 仓库,重构成 Rust workspace + Python 薄层混合包,按 superpowers TDD 范式分 12 个 Phase 推进,最终用 maturin + Trusted Publishing 发布 czsc 1.0.0。 + +**Architecture:** czsc 仓库根新增 `Cargo.toml` workspace + `crates/` 9 个 crate,扩展模块名 `czsc._native`。所有面向用户的 API 由 `czsc/__init__.py` re-export,禁止用户感知 `rs_czsc` / `czsc._native`。`WeightBacktest` / `daily_performance` / `top_drawdowns` / `mock` 通过硬依赖 `wbt` 包提供。 + +**Tech Stack:** Rust (workspace, edition 2024) / PyO3 0.25 (abi3-py310) / maturin / Polars 0.42 / Python 3.10+ / pytest / ruff / basedpyright / criterion / cargo + +**关联文档:** +- 设计文档: [docs/superpowers/specs/2026-05-03-rust-czsc-migration-design.md](../specs/2026-05-03-rust-czsc-migration-design.md) (v0.3) +- 迁移记录: [docs/MIGRATION_NOTES.md](../../MIGRATION_NOTES.md) +- rs-czsc 基线: `47ef6efa2b2bac63881a233c01671e8e9860162f` (2026-04-06) + +**进度可视化:** 每个 task 标注其会让 Phase A 失败测试基线(A1~A8)由 RED 转 GREEN 的项;CI `scripts/red_green_report.py` 在每次 commit 后输出 `红 X 项 / 绿 Y 项 / 总 N 项`。 + +--- + +## 目录 + +- Phase 0 — Spec 评审 + Plan 产出(本文件 = Phase 0 产出) +- Phase A — 写验收级失败测试基线(8 类,A1~A8 全部 RED) +- Phase B — Rust workspace 9 crate 骨架 +- Phase C — czsc-utils 测试驱动迁移(→ A6 GREEN) +- Phase D — czsc-core 测试驱动迁移(→ A2/A3 GREEN) +- Phase E — czsc-ta + czsc-signal-macros(→ A5 GREEN) +- Phase F — czsc-signals 迁移(→ A4 GREEN) +- Phase G — czsc-trader 迁移(含 strategies) +- Phase H — czsc-python 聚合 + Python 包重构(→ A1 GREEN) +- Phase I — wbt 集成(→ A7 GREEN) +- Phase J — Python 删减 +- Phase K — CI / Trusted Publishing / finishing(→ A8 GREEN) + +--- + +## Phase 0 — Spec 评审 + Plan 产出(0.5 天) + +> 本 phase 仅产出文档,不写代码。 + +### Task 0.1: Worktree 隔离与基线锁定 + +**Files:** +- Create: `docs/MIGRATION_NOTES.md` + +- [x] **Step 1:** `git worktree add` 至 `refactor/rust-czsc-migration` 分支(手工) +- [x] **Step 2:** `cd /Users/jun/Documents/vscodePro/rs_czsc && git rev-parse HEAD` → `47ef6efa2b2bac63881a233c01671e8e9860162f` +- [x] **Step 3:** 写 `docs/MIGRATION_NOTES.md`,记录基线 commit、czsc-only 改动占位章节 +- [x] **Verify:** 文件存在 `docs/MIGRATION_NOTES.md`,包含基线 commit hash +- [x] **Commit:** `docs(migration): record rs-czsc baseline commit 47ef6efa` + +### Task 0.2: 同步 v0.1 spec 至 v0.3 + +**Files:** +- Modify: `docs/superpowers/specs/2026-05-03-rust-czsc-migration-design.md` + +- [x] **Step 1:** 从飞书 wiki 拉取 v0.3 内容,完整覆盖本地 spec +- [x] **Verify:** 状态行包含 "v0.3 草案" 字样 +- [x] **Commit:** `docs(spec): bump rust-czsc-migration design to v0.3` + +### Task 0.3: 产出 plan 文件(本文件) + +**Files:** +- Create: `docs/superpowers/plans/2026-05-03-rust-czsc-migration.md` + +- [x] **Step 1:** 按 superpowers:writing-plans 规范展开 Phase 0~K 的 task +- [x] **Verify:** plan 自审 checklist:① 无 TBD / placeholder ② 每 task 有 test code + run command + expected output ③ 每 task 以 commit 结尾 ④ 标注对 A1~A8 的 RED→GREEN 影响 +- [x] **Commit:** `docs(plan): scaffold superpowers TDD plan for rust-czsc migration` + +--- + +## Phase A — 写验收级失败测试基线(1.5 天) + +> 把 §6 验收标准 + §3.1 公共 API 表翻译成可执行测试,跑出全 RED。**禁止**在本 phase 写任何 Rust 实现或修改 czsc/* 业务代码。 + +### Task A.1: 公共 API 快照测试(→ A1 RED 基线) + +**Files:** +- Create: `test/compat/__init__.py` +- Create: `test/compat/test_public_api.py` +- Create: `test/compat/snapshots/api_v1.json` + +- [ ] **Step 1: 写失败测试** —— 从 spec §3.1 抓取 80+ 公共名称(顶层 + `czsc.ta.*` + `czsc.signals.{bar,cxt,...}` + `czsc.traders.*`),逐项 `getattr(czsc, name)`: + +```python +# test/compat/test_public_api.py +import importlib +import json +from pathlib import Path + +import pytest + +SNAPSHOT = Path(__file__).parent / "snapshots" / "api_v1.json" + + +def _load_snapshot() -> dict: + return json.loads(SNAPSHOT.read_text(encoding="utf-8")) + + +def test_top_level_names_importable(): + czsc = importlib.import_module("czsc") + snap = _load_snapshot() + missing = [n for n in snap["top_level"] if not hasattr(czsc, n)] + assert not missing, f"Missing czsc.* names: {missing}" + + +@pytest.mark.parametrize("subpkg", ["bar", "cxt", "tas", "vol", "pressure", "obv", "cvolp"]) +def test_signal_subpackages_present(subpkg): + mod = importlib.import_module(f"czsc.signals.{subpkg}") + assert mod is not None + + +def test_traders_namespace_complete(): + traders = importlib.import_module("czsc.traders") + snap = _load_snapshot() + missing = [n for n in snap["traders"] if not hasattr(traders, n)] + assert not missing, f"Missing czsc.traders.* names: {missing}" + + +def test_ta_namespace_complete(): + ta = importlib.import_module("czsc.ta") + snap = _load_snapshot() + missing = [n for n in snap["ta"] if not hasattr(ta, n)] + assert not missing, f"Missing czsc.ta.* names: {missing}" + + +def test_no_legacy_dummy_backtest(): + czsc = importlib.import_module("czsc") + assert not hasattr(czsc, "DummyBacktest"), "DummyBacktest must be removed" + + +def test_no_czsc_use_python_branch(): + import czsc.envs as envs + assert not hasattr(envs, "CZSC_USE_PYTHON") +``` + +`api_v1.json` 内容(按 §3.1 公共 API 表填,不少于 80 条): + +```json +{ + "top_level": [ + "CZSC", "FX", "BI", "ZS", "RawBar", "NewBar", + "Freq", "Mark", "Direction", "Operate", + "Signal", "Event", "Position", + "BarGenerator", "format_standard_kline", + "freq_end_time", "is_trading_time", + "check_bi", "check_fx", "check_fxs", "remove_include", + "CzscTrader", "CzscSignals", + "generate_czsc_signals", "get_unique_signals", + "WeightBacktest", "daily_performance", "top_drawdowns", + "ultimate_smoother", "rolling_rank", "ema", "sma", "boll_positions", + "mock", "envs", "signals", "traders", "ta", "utils", + "connectors", "sensors", "svc" + ], + "traders": [ + "CzscTrader", "CzscSignals", + "generate_czsc_signals", "get_unique_signals", + "WeightBacktest", "SignalsParser" + ], + "ta": [ + "ultimate_smoother", "rolling_rank", "ema", "sma", "boll_positions" + ] +} +``` + +- [ ] **Step 2: 跑测试看 RED** + +```bash +uv run pytest test/compat/test_public_api.py -v +``` + +预期 5 个测试 FAIL(`Missing czsc.* names: [...]`),不能是 ERROR / SKIP。 + +- [ ] **Commit:** `test(compat): add public API snapshot test (RED baseline for A1)` + +### Task A.2: PyO3 类 pickle 协议测试(→ A2 RED 基线) + +**Files:** +- Create: `test/unit/__init__.py` +- Create: `test/unit/test_pickle.py` + +- [ ] **Step 1: 写失败测试** + +```python +# test/unit/test_pickle.py +import pickle + +import pytest + + +@pytest.fixture(scope="module") +def small_bars(): + from czsc.mock import generate_symbol_kines # 来自 wbt 转发 + from czsc import format_standard_kline, Freq + + df = generate_symbol_kines("000001", "30分钟", "20240101", "20240105", seed=42) + return format_standard_kline(df, freq=Freq.F30) + + +@pytest.mark.parametrize( + "factory", + [ + pytest.param(lambda b: __import__("czsc").CZSC(b), id="CZSC"), + pytest.param(lambda _: __import__("czsc").BarGenerator(base_freq="30分钟", freqs=["日线"]), id="BarGenerator"), + pytest.param(lambda _: __import__("czsc").Position(symbol="000001", name="t", opens=[], exits=[]), id="Position"), + pytest.param(lambda b: __import__("czsc").CzscSignals(b), id="CzscSignals"), + pytest.param(lambda b: __import__("czsc").CzscTrader(b), id="CzscTrader"), + ], +) +def test_pickle_roundtrip(factory, small_bars): + obj = factory(small_bars) + blob = pickle.dumps(obj) + restored = pickle.loads(blob) + assert type(restored) is type(obj) + if hasattr(obj, "__getstate__"): + assert restored.__getstate__() == obj.__getstate__() +``` + +- [ ] **Step 2: 跑测试看 RED** —— `uv run pytest test/unit/test_pickle.py -v`,预期 5 个 FAIL(pickle 抛异常或 fixture 失败)。 +- [ ] **Commit:** `test(unit): add pickle roundtrip test for PyO3 classes (RED baseline for A2)` + +### Task A.3: 缠论核心对象 parity 测试(→ A3 RED 基线) + +**Files:** +- Create: `test/unit/test_core_parity.py` +- Create: `test/unit/snapshots/core_parity_seed42.json` + +- [ ] **Step 1:** 用固定 seed `wbt.mock.generate_symbol_kines("000001", "30分钟", "20240101", "20240301", seed=42)` 生成 K 线,跑外部 `rs_czsc.CZSC(bars)`,把 `len(fxs) / len(bi_list) / len(zs_list) / 关键 fx mark 序列` 写入快照 JSON。 +- [ ] **Step 2: 写失败测试** + +```python +# test/unit/test_core_parity.py +import json +from pathlib import Path + +import pytest + +SNAP = Path(__file__).parent / "snapshots" / "core_parity_seed42.json" + + +@pytest.fixture(scope="module") +def baseline_snapshot(): + return json.loads(SNAP.read_text(encoding="utf-8")) + + +@pytest.fixture(scope="module") +def czsc_obj(): + import czsc + from wbt.mock import generate_symbol_kines + + df = generate_symbol_kines("000001", "30分钟", "20240101", "20240301", seed=42) + bars = czsc.format_standard_kline(df, freq=czsc.Freq.F30) + return czsc.CZSC(bars) + + +def test_fxs_count(czsc_obj, baseline_snapshot): + assert len(czsc_obj.fxs) == baseline_snapshot["fxs_count"] + + +def test_bi_list_count(czsc_obj, baseline_snapshot): + assert len(czsc_obj.bi_list) == baseline_snapshot["bi_count"] + + +def test_fxs_marks_sequence(czsc_obj, baseline_snapshot): + marks = [str(fx.mark) for fx in czsc_obj.fxs] + assert marks == baseline_snapshot["fxs_marks"] + + +def test_bi_directions_sequence(czsc_obj, baseline_snapshot): + dirs = [str(bi.direction) for bi in czsc_obj.bi_list] + assert dirs == baseline_snapshot["bi_directions"] +``` + +- [ ] **Step 3: 跑测试看 RED** —— `uv run pytest test/unit/test_core_parity.py -v`,预期 4 个 FAIL(czsc.CZSC 来源是 Python fallback 或 rs_czsc,与本仓库未来 Rust 实现不一致)。 +- [ ] **Commit:** `test(unit): lock core parity snapshot at seed 42 (RED baseline for A3)` + +### Task A.4: 信号函数 parity 测试(→ A4 RED 基线) + +**Files:** +- Create: `test/unit/test_signals_parity.py` +- Create: `test/unit/snapshots/signals_parity_seed42.json` + +- [ ] **Step 1:** 选定 30 个核心信号(按 spec §8 附录 A 的 `czsc.signals.*` 列表),固定 seed 跑 `rs_czsc` 收集签名输出 → 写快照。 +- [ ] **Step 2: 写失败测试** 形式为 `parametrize(signal_name, expected_dict)`,逐一对比。 +- [ ] **Step 3:** `uv run pytest test/unit/test_signals_parity.py -v` → 30 个 FAIL(czsc.signals 子包尚未指向 czsc._native)。 +- [ ] **Commit:** `test(unit): lock 30 signal functions parity (RED baseline for A4)` + +### Task A.5: TA 算子 parity 测试(→ A5 RED 基线) + +**Files:** +- Create: `test/unit/test_ta_parity.py` + +- [ ] **Step 1:** 对 `czsc.ta.{ema, sma, rolling_rank, boll_positions, ultimate_smoother}` 5 个核心算子,与 `talib.{EMA, SMA}` 结果对比,容差 ≤ 1e-6。 + +```python +# test/unit/test_ta_parity.py +import numpy as np +import pytest + + +@pytest.fixture +def series(): + rng = np.random.default_rng(42) + return rng.standard_normal(1024).astype(np.float64) + + +def test_ema_matches_talib(series): + import czsc.ta as ta + import talib + + expected = talib.EMA(series, timeperiod=14) + actual = ta.ema(series, length=14) + np.testing.assert_allclose(actual[20:], expected[20:], rtol=1e-6, atol=1e-6) + + +def test_sma_matches_talib(series): + import czsc.ta as ta + import talib + + expected = talib.SMA(series, timeperiod=20) + actual = ta.sma(series, length=20) + np.testing.assert_allclose(actual[20:], expected[20:], rtol=1e-6, atol=1e-6) + + +def test_rolling_rank_returns_finite(series): + import czsc.ta as ta + + out = ta.rolling_rank(series, window=20) + assert np.isfinite(out[20:]).all() +``` + +- [ ] **Step 2:** `uv run pytest test/unit/test_ta_parity.py -v` → 3 个 FAIL(`czsc.ta` 模块当前不存在)。 +- [ ] **Commit:** `test(unit): add TA parity vs TA-Lib (RED baseline for A5)` + +### Task A.6: is_trading_time 行为测试(→ A6 RED 基线) + +**Files:** +- Create: `test/unit/test_trading_time.py` + +- [ ] **Step 1:** 列出 A 股 / 港股 / 数字货币 三类日历共 ~12 个典型时间点,断言 `czsc.is_trading_time(dt, market="...")` 返回值。 + +```python +# test/unit/test_trading_time.py +from datetime import datetime + +import pytest + + +@pytest.mark.parametrize( + "market, dt, expected", + [ + ("astock", datetime(2024, 1, 8, 9, 30), True), # 周一 9:30 + ("astock", datetime(2024, 1, 8, 11, 30), True), # 上午收盘前 + ("astock", datetime(2024, 1, 8, 12, 30), False), # 午休 + ("astock", datetime(2024, 1, 8, 15, 0), True), # 下午收盘 + ("astock", datetime(2024, 1, 6, 10, 0), False), # 周六 + ("hk", datetime(2024, 1, 8, 9, 30), True), + ("hk", datetime(2024, 1, 8, 12, 0), False), # 午休 + ("hk", datetime(2024, 1, 8, 16, 0), True), # 收盘前 + ("crypto", datetime(2024, 1, 6, 3, 0), True), # 24x7 + ("crypto", datetime(2024, 12, 25, 0, 0), True), + ], +) +def test_is_trading_time(market, dt, expected): + import czsc + + assert czsc.is_trading_time(dt, market=market) is expected +``` + +- [ ] **Step 2:** `uv run pytest test/unit/test_trading_time.py -v` → 10 个 FAIL(函数不存在)。 +- [ ] **Commit:** `test(unit): add is_trading_time behavior test (RED baseline for A6)` + +### Task A.7: WeightBacktest 通过 wbt 集成(→ A7 RED 基线) + +**Files:** +- Create: `test/integration/__init__.py` +- Create: `test/integration/test_weight_backtest.py` + +- [ ] **Step 1: 写失败测试** + +```python +# test/integration/test_weight_backtest.py +def test_czsc_weight_backtest_is_wbt(): + import czsc + import wbt + + assert czsc.WeightBacktest is wbt.WeightBacktest + + +def test_czsc_daily_performance_is_wbt(): + import czsc + import wbt + + assert czsc.daily_performance is wbt.daily_performance + + +def test_czsc_top_drawdowns_is_wbt(): + import czsc + import wbt + + assert czsc.top_drawdowns is wbt.top_drawdowns +``` + +- [ ] **Step 2:** `uv run pytest test/integration/test_weight_backtest.py -v` → 3 个 FAIL(当前 `czsc.WeightBacktest` 不是 `wbt.WeightBacktest`)。 +- [ ] **Commit:** `test(integration): assert czsc.WeightBacktest is wbt.WeightBacktest (RED baseline for A7)` + +### Task A.8: 安装冒烟测试(→ A8 RED 基线) + +**Files:** +- Create: `test/smoke/__init__.py` +- Create: `test/smoke/test_install.py` + +- [ ] **Step 1: 写失败测试** + +```python +# test/smoke/test_install.py +import subprocess +import sys +from pathlib import Path + + +def test_wheel_install_and_import(tmp_path): + """构建 wheel → 干净 venv → import czsc → 主流程跑通""" + repo = Path(__file__).resolve().parents[2] + dist = repo / "dist" + venv = tmp_path / "venv" + + subprocess.run([sys.executable, "-m", "venv", str(venv)], check=True) + pip = venv / "bin" / "pip" + py = venv / "bin" / "python" + + # 假定 maturin build 已产出 wheel 到 dist/ + wheels = sorted(dist.glob("czsc-*.whl")) + assert wheels, "No wheel found in dist/ — run `maturin build --release` first" + + subprocess.run([str(pip), "install", str(wheels[-1])], check=True) + out = subprocess.run( + [str(py), "-c", "import czsc; print(czsc.CZSC.__module__)"], + check=True, + capture_output=True, + text=True, + ) + assert "czsc" in out.stdout +``` + +- [ ] **Step 2:** `uv run pytest test/smoke/test_install.py -v` → 1 个 FAIL(`dist/czsc-*.whl` 不存在)。 +- [ ] **Commit:** `test(smoke): add wheel install smoke test (RED baseline for A8)` + +### Task A.verify: RED 基线全量校验 + +- [ ] **Step 1:** `uv run pytest test/compat test/unit test/integration test/smoke --tb=no -q` +- [ ] **Verify:** 输出含 56+ FAIL,0 ERROR,0 SKIP(与设计文档 §5.2 一致) +- [ ] **Verify:** 把数字写入 `docs/MIGRATION_NOTES.md` 第 4 节"Phase A 基线统计" +- [ ] **Commit:** `docs(migration): record Phase A RED baseline counts` + +--- + +## Phase B — Rust workspace 9 crate 骨架(1 天) + +### Task B.1: workspace 根 Cargo.toml + 9 个空 crate + +**Files:** +- Create: `Cargo.toml` +- Create: `rust-toolchain.toml` +- Create: `.cargo/config.toml` +- Create: `crates/{czsc-core,czsc-utils,czsc-ta,czsc-signals,czsc-trader,czsc-signal-macros,error-macros,error-support,czsc-python}/Cargo.toml` +- Create: `crates/{czsc-core,czsc-utils,czsc-ta,czsc-signals,czsc-trader,czsc-signal-macros,error-macros,error-support,czsc-python}/src/lib.rs` +- Create: `tests/test_workspace_layout.sh` + +- [ ] **Step 1: 写失败测试** —— `tests/test_workspace_layout.sh`: + +```bash +#!/usr/bin/env bash +set -euo pipefail +required=( czsc-core czsc-utils czsc-ta czsc-signals czsc-trader czsc-signal-macros error-macros error-support czsc-python ) +for c in "${required[@]}"; do + test -f "crates/$c/Cargo.toml" || { echo "missing crates/$c/Cargo.toml"; exit 1; } + test -f "crates/$c/src/lib.rs" || { echo "missing crates/$c/src/lib.rs"; exit 1; } +done +cargo metadata --format-version 1 --no-deps | python3 -c " +import json, sys +data = json.load(sys.stdin) +members = {pkg['name'] for pkg in data['packages']} +required = set('${required[*]}'.split()) +missing = required - members +assert not missing, f'cargo workspace missing: {missing}' +print(f'OK: {len(required)} crates registered') +" +cargo build --workspace --quiet +``` + +- [ ] **Step 2: 跑测试看 RED** —— `bash tests/test_workspace_layout.sh` → 失败(目录不存在) +- [ ] **Step 3: GREEN** —— 创建 9 个空 crate(每个 `lib.rs` 仅含 `pub fn placeholder() {}`),写顶层 `Cargo.toml`: + +```toml +[workspace] +resolver = "2" +members = ["crates/*"] + +[workspace.package] +version = "1.0.0" +edition = "2024" +license = "MIT" +repository = "https://github.com/waditu/czsc" + +[workspace.dependencies] +polars = { version = "0.42.0" } +chrono = "0.4" +pyo3 = { version = "0.25", features = ["extension-module", "abi3-py310"] } +numpy = "0.25" +rayon = "1" +hashbrown = "0.14" +serde = { version = "1", features = ["derive"] } +ordered-float = "5.0" + +[profile.release] +lto = true +opt-level = 3 +codegen-units = 1 +``` + +`rust-toolchain.toml`: + +```toml +[toolchain] +channel = "stable" +``` + +`.cargo/config.toml`: + +```toml +[build] +incremental = true + +[profile.release] +lto = true +opt-level = 3 +codegen-units = 1 +``` + +每个子 crate 的 `Cargo.toml` 形如: + +```toml +[package] +name = "czsc-core" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "CZSC core analyzer (FX/BI/ZS/CZSC) — placeholder, to be migrated from rs-czsc" + +[lib] +name = "czsc_core" +path = "src/lib.rs" +``` + +`czsc-signal-macros` / `error-macros` 是 proc-macro: + +```toml +[package] +name = "czsc-signal-macros" +version.workspace = true +edition.workspace = true +license.workspace = true + +[lib] +proc-macro = true +path = "src/lib.rs" +``` + +`czsc-python` 启用 pyo3 extension-module: + +```toml +[package] +name = "czsc-python" +version.workspace = true +edition.workspace = true +license.workspace = true + +[lib] +name = "czsc_python" +crate-type = ["cdylib", "rlib"] + +[dependencies] +pyo3 = { workspace = true } +``` + +- [ ] **Step 4: 跑测试看 GREEN** —— `bash tests/test_workspace_layout.sh` → `OK: 9 crates registered` + `cargo build --workspace` 成功 +- [ ] **Verify:** `cargo metadata --no-deps | jq '.packages | length'` → `9` +- [ ] **Commit:** `feat(rust): scaffold workspace with 9 empty crates` + +--- + +## Phase C — czsc-utils 测试驱动迁移(1 天) + +> **模式:** "复制即测试"。**不**整体复制 `src/`,按 rs-czsc 测试逐个迁移。 + +### Task C.1: freq_data 模块 + +**Files:** +- Copy: `crates/czsc-utils/tests/test_freq_data.rs` ← rs-czsc 同名文件 +- Copy: `crates/czsc-utils/src/freq_data.rs` ← rs-czsc 同名文件 + +- [ ] **Step 1 (RED):** 复制测试到 `crates/czsc-utils/tests/`,跑 `cargo test -p czsc-utils freq` → 失败(src 还空) +- [ ] **Step 2 (GREEN):** 复制 `czsc-utils/src/freq_data.rs`,更新 `lib.rs` 导出 `pub mod freq_data;` +- [ ] **Step 3:** 跑同样命令 → PASS +- [ ] **Commit:** `feat(utils): migrate freq_data module (TDD)` + +### Task C.2: BarGenerator 模块 + +**Files:** +- Copy: `crates/czsc-utils/tests/test_bar_generator.rs` +- Copy: `crates/czsc-utils/src/bar_generator.rs` + +- [ ] 同 Task C.1 模式 +- [ ] **Commit:** `feat(utils): migrate BarGenerator module (TDD)` + +### Task C.3: is_trading_time 新增(czsc-only) + +**Files:** +- Create: `crates/czsc-utils/tests/test_trading_time.rs` +- Create: `crates/czsc-utils/src/trading_time.rs` + +- [ ] **Step 1 (RED):** 写 Rust 单元测试,覆盖三类日历的 ~12 个时间点 +- [ ] **Step 2 (GREEN):** 写 Rust 实现(A 股:9:30-11:30 / 13:00-15:00 weekday;港股:9:30-12:00 / 13:00-16:00;crypto:always) +- [ ] **Step 3:** PyO3 binding (`crates/czsc-utils/src/python/mod.rs`) — 暴露 `is_trading_time` 为 `#[pyfunction]` +- [ ] **Step 4 (RED → GREEN):** Python 端 `test/unit/test_trading_time.py`(Phase A 已写)由 RED 转 GREEN —— **A6 标志** +- [ ] **Commit:** `feat(utils): add is_trading_time with PyO3 binding (czsc-only)` + 更新 `MIGRATION_NOTES.md` §2.2 + +### Task C.4: PyO3 binding 注册 + +**Files:** +- Create: `crates/czsc-utils/src/python/mod.rs` + +- [ ] 暴露 `BarGenerator` / `freq_end_time` / `is_trading_time`,提供 `pub fn register(py, m) -> PyResult<()>` +- [ ] **Commit:** `feat(utils): register python bindings for utils crate` + +**Phase C 验证:** +- [ ] `cargo test -p czsc-utils` 全过 +- [ ] Phase A 中 A6 由 RED 转 GREEN + +--- + +## Phase D — czsc-core 测试驱动迁移(2 天) + +> 对照 rs-czsc 的 `czsc-core` 模块清单(FX / BI / ZS / CZSC / Direction / Mark / Operate / Signal / Event / Position 等),每个数据类型一个子循环。 + +### Task D.1 ~ D.10: 每个 type 一个子循环 + +每个子循环统一模板: + +1. **RED (Rust):** 复制 `crates/czsc-core/tests/test_.rs`,跑 `cargo test -p czsc-core ` 失败 +2. **GREEN (Rust):** 复制对应 `src/objects/.rs` + `lib.rs` 导出,cargo test PASS +3. **RED (PyO3):** Python 端 `test/unit/test__py.py` 写 binding 行为断言,pytest 失败 +4. **GREEN (PyO3):** 在 `crates/czsc-core/src/python/.rs` 暴露 `#[pyclass]`,注册到 `czsc-python` +5. **RED (pickle):** 在 Phase A.2 中预先注入 `` 的 pickle roundtrip case,pytest 失败 +6. **GREEN (pickle):** 实现 `__getstate__` / `__setstate__`(serde + bincode),pytest PASS +7. **Commit:** `feat(core): migrate with PyO3 + pickle support` + +| Task | 类型 | rs-czsc 路径 | 影响 A 项 | +|-|-|-|-| +| D.1 | `Freq` (enum) | `czsc-core/src/objects/freq.rs` | A2/A3 | +| D.2 | `Mark` / `Direction` / `Operate` (enum) | `czsc-core/src/objects/enums.rs` | A2/A3 | +| D.3 | `RawBar` / `NewBar` | `czsc-core/src/objects/bar.rs` | A2/A3 | +| D.4 | `FX` | `czsc-core/src/objects/fx.rs` | A3 | +| D.5 | `BI` | `czsc-core/src/objects/bi.rs` | A3 | +| D.6 | `ZS` | `czsc-core/src/objects/zs.rs` | A3 | +| D.7 | `Signal` / `Event` | `czsc-core/src/objects/signal.rs` | A3 | +| D.8 | `Position` | `czsc-core/src/objects/position.rs` | A2 | +| D.9 | `CZSC` analyzer | `czsc-core/src/analyze/mod.rs` | A2/A3 | +| D.10 | `check_bi/check_fx/check_fxs/remove_include` 可见性提升 | `czsc-core/src/analyze/utils.rs` | A1 | + +### Task D.10 特别说明 + +- **RED:** Python 端 `test/compat/test_public_api.py` 中 `check_bi/check_fx/check_fxs/remove_include` 4 项断言失败 +- **GREEN:** 把 4 个函数从 `pub(crate)` 改为 `pub`,加 `#[pyfunction]` +- **MIGRATION_NOTES.md:** 写入 §2.1 表 +- **Commit:** `feat(core): expose check_bi/check_fx/check_fxs/remove_include (czsc-only public)` + +**Phase D 验证:** +- [ ] `cargo test -p czsc-core` 全过 +- [ ] Phase A 中 A2 + A3 由 RED 转 GREEN(公共 API 涉及部分对应转 GREEN) + +--- + +## Phase E — czsc-ta + czsc-signal-macros(1.5 天) + +### Task E.1: czsc-ta 调用图静态分析 + +**Files:** +- Create: `MIGRATION_NOTES.md` §2.3 czsc-ta 算子裁剪清单 + +- [ ] **Step 1:** `rg "use czsc_ta" rs_czsc/crates/czsc-{trader,signals}/ --no-heading -o | sort -u` → 实际被引用的算子列表 +- [ ] **Step 2:** 把白名单写入 `MIGRATION_NOTES.md` §2.3,未上榜的算子记入"被裁剪" +- [ ] **Commit:** `docs(migration): record czsc-ta operator whitelist` + +### Task E.2 ~ E.N: 白名单算子逐个迁移 + +每个算子一个 RED→GREEN 子循环(参考 Phase C 模式),形如 `ema/sma/rolling_rank/boll_positions/ultimate_smoother/...`。 + +- [ ] PyO3 binding:启用 `rust-numpy` feature,`czsc-python` 注册为 `czsc._native.ta` 子模块 + +### Task E.last: czsc-signal-macros + +- [ ] **RED:** 写最小宏展开测试 `crates/czsc-signal-macros/tests/test_signal_module.rs` +- [ ] **GREEN:** 复制 macro 实现,cargo test PASS +- [ ] **Commit:** `feat(macros): migrate signal_module proc-macro` + +**Phase E 验证:** +- [ ] Phase A 中 A5 由 RED 转 GREEN + +--- + +## Phase F — czsc-signals 迁移(1.5 天) + +每个子模块(`bar / cxt / tas / vol / pressure / obv / cvolp`)一组 RED→GREEN 子循环: + +### Task F.{bar,cxt,tas,vol,pressure,obv,cvolp} + +1. **RED:** 复制 `crates/czsc-signals/tests/test_.rs`,cargo test 失败 +2. **GREEN:** 复制 `src//`,更新 `lib.rs` +3. **RED (Python):** `test/unit/test_signals_parity.py` 中该子模块的 case 失败(A4 的部分) +4. **GREEN:** PyO3 binding 注册 `czsc._native.signals.`,并 `czsc/signals/.py` re-export +5. **Commit:** `feat(signals): migrate module` + +**Phase F 验证:** +- [ ] Phase A 中 A4 由 RED 转 GREEN + +--- + +## Phase G — czsc-trader 迁移(含 strategies)(2 天) + +### Task G.1 ~ G.4: 核心对象 + +| Task | 对象 | rs-czsc 路径 | +|-|-|-| +| G.1 | `CzscTrader` | `czsc-trader/src/trader.rs` | +| G.2 | `CzscSignals` | `czsc-trader/src/signals_holder.rs` | +| G.3 | `generate_czsc_signals` | `czsc-trader/src/generators.rs` | +| G.4 | `get_unique_signals` | `czsc-trader/src/utils.rs` | + +每个一个 RED→GREEN 子循环。 + +### Task G.5: strategies 迁移到 Rust + +**Files:** +- Create: `crates/czsc-trader/src/strategies/{base,json}.rs` +- Modify: `czsc/strategies.py` → 删除 + +- [ ] **Step 1 (RED):** Python 端 `test/integration/test_strategies.py` 断言 `czsc.CzscStrategyBase` 与原行为一致 +- [ ] **Step 2 (GREEN):** Rust 端实现 `StrategyBase` / `JsonStrategy`,PyO3 暴露 +- [ ] **Step 3:** 删 `czsc/strategies.py`,pytest 仍 GREEN +- [ ] **Commit:** `refactor(trader): migrate strategies from Python to Rust` + +### Task G.6: WeightBacktest 暂保持 RED + +- 本 phase 不实现,由 Phase I 由 wbt 接管 + +**Phase G 验证:** +- [ ] `cargo test -p czsc-trader` 全过 +- [ ] A2/A3 中 trader 相关 case 转 GREEN +- [ ] A1 中 `CzscTrader/CzscSignals/generate_czsc_signals/get_unique_signals` 转 GREEN + +--- + +## Phase H — czsc-python 聚合 + Python 包重构(1.5 天) + +### Task H.1: czsc-python 聚合 register() + +**Files:** +- Modify: `crates/czsc-python/src/lib.rs` + +```rust +use pyo3::prelude::*; + +#[pymodule] +fn _native(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { + czsc_core::python::register(py, m)?; + czsc_utils::python::register(py, m)?; + czsc_ta::python::register(py, m)?; + czsc_signals::python::register(py, m)?; + czsc_trader::python::register(py, m)?; + Ok(()) +} +``` + +- [ ] **Commit:** `feat(python): aggregate all crate registers in czsc._native` + +### Task H.2: pyproject.toml 切换至 maturin + +**Files:** +- Modify: `pyproject.toml` + +- [ ] `[build-system]` → `maturin` +- [ ] `[tool.maturin]` → `module-name = "czsc._native"`、`features = ["pyo3/extension-module"]`、`manifest-path = "crates/czsc-python/Cargo.toml"` +- [ ] `[project.dependencies]` 加入 `wbt` +- [ ] 移除 `rs-czsc` PyPI 依赖(如有) +- [ ] **Commit:** `build: switch from hatchling to maturin` + +### Task H.3: 重写 czsc/__init__.py + +**Files:** +- Modify: `czsc/__init__.py` +- Delete: `czsc/core.py` + +- [ ] 按 §3.1 表逐项 import;移除 `_LAZY_MODULES` / `_LAZY_ATTRS` / `__getattr__` +- [ ] 每加一项跑 `pytest test/compat/test_public_api.py -v`,记录 RED → GREEN 转换 +- [ ] **Commit:** `refactor(init): rewrite czsc/__init__.py to re-export from czsc._native` + +### Task H.4: 极薄化 czsc/signals/ + czsc/traders/ + czsc/ta/ + +- [ ] `czsc/signals/{bar,cxt,...}.py` 仅 `from czsc._native.signals. import *` +- [ ] `czsc/traders/__init__.py` 仅 import + re-export(保留 `sig_parse.py` 待评估) +- [ ] 新增 `czsc/ta/__init__.py` re-export `czsc._native.ta` +- [ ] **Commit:** `refactor: thin shells for czsc.signals/traders/ta` + +**Phase H 验证:** +- [ ] Phase A 中 A1 由 RED 转 GREEN(80+ 公共名称全部可导入) + +--- + +## Phase I — wbt 集成(0.5 天) + +### Task I.1: 加 wbt 硬依赖 + +- [ ] `uv add wbt` → `pyproject.toml` +- [ ] **Commit:** `build: add wbt as hard dependency` + +### Task I.2: re-export wbt 公共 API + +**Files:** +- Modify: `czsc/__init__.py`、`czsc/traders/__init__.py` + +- [ ] `from wbt import WeightBacktest, daily_performance, top_drawdowns` +- [ ] **Verify:** `pytest test/integration/test_weight_backtest.py -v` 全 GREEN(A7) +- [ ] **Commit:** `feat: wire wbt as the canonical backtest/perf provider` + +### Task I.3: czsc/mock.py 退化为薄壳 + +**Files:** +- Modify: `czsc/mock.py`(537 行 → ~30 行) + +- [ ] 仅保留 `from wbt.mock import generate_symbol_kines, generate_klines_with_weights` 等转发 +- [ ] **Commit:** `refactor(mock): degrade to thin shell forwarding wbt.mock` + +**Phase I 验证:** +- [ ] Phase A 中 A7 由 RED 转 GREEN + +--- + +## Phase J — Python 删减(0.5 天) + +按 §3.2 / §9 附录 B 的"完全删除"列表逐文件 `git rm`,每删一组跑 `pytest -q` 确认 GREEN 不变。 + +### Task J.1 ~ J.5 + +| Task | 删除目标 | 验证 | +|-|-|-| +| J.1 | `czsc/utils/ta.py` | `pytest test/unit/test_ta_parity.py` 仍 GREEN | +| J.2 | `czsc/traders/{base,cwc,rwc,optimize,weight_backtest,performance,dummy}.py` | `pytest test/` 仍 GREEN | +| J.3 | `czsc/py/` 目录 | `pytest test/` 仍 GREEN | +| J.4 | `czsc/features/` 目录 | `pytest test/` 仍 GREEN | +| J.5 | `czsc/utils/{bar_generator,bi_info,echarts_*,pdf_report,html_report_builder,word_writer,features,st_components,corr,signal_analyzer}.py` + `analysis/` 目录 | `pytest test/` 仍 GREEN | + +每个 task 都 `Commit: chore: remove (replaced by )`。 + +**Phase J 验证:** +- [ ] `find czsc -name '*.py' | xargs wc -l` 总行数 ≤ 12500(Q5 目标 ~12K) +- [ ] `pytest test/` 全过 + +--- + +## Phase K — CI / Trusted Publishing / finishing(1 天) + +### Task K.1: GitHub Actions workflow + +**Files:** +- Create: `.github/workflows/ci.yml` +- Create: `.github/workflows/release.yml` + +- [ ] CI: Rust(fmt / clippy -D warnings / test --workspace)+ Python(maturin develop + pytest + ruff + basedpyright) +- [ ] Release: maturin build wheel matrix(manylinux_2_28 linux + universal2 macos + windows)+ smoke + Trusted Publishing +- [ ] **Commit:** `ci: add Rust + Python pipelines and release workflow` + +### Task K.2: Trusted Publishing OIDC binding + +- [ ] 在 PyPI / TestPyPI 项目设置中绑定(仓库 + workflow + environment 三元组) +- [ ] 不在 GitHub Actions secrets 中存任何 PyPI token +- [ ] **Verify:** workflow 包含 `permissions: id-token: write` + `pypa/gh-action-pypi-publish@release/v1`(不带 token) + +### Task K.3: 发 RC 至 test.pypi.org + +- [ ] tag `1.0.0rc1`,触发 release workflow +- [ ] 干净 venv: `pip install --index-url https://test.pypi.org/simple/ czsc==1.0.0rc1` +- [ ] **Verify:** `pytest test/smoke/test_install.py -v` GREEN(A8) +- [ ] **Commit:** `chore: publish 1.0.0rc1 to test.pypi.org` + +### Task K.4: finishing + +- [ ] 用 `superpowers:finishing-a-development-branch` 合并 worktree → master +- [ ] 写 release notes(含 §6.T2 列出的所有破坏性变更 + 替代方案) +- [ ] tag `1.0.0` +- [ ] **Verify:** `pip install czsc` 后 `python -c "import czsc; czsc.CZSC"` 成功 +- [ ] **Commit:** `release: 1.0.0 — Rust + PyO3 unified package` + +**Phase K 验证:** +- [ ] Phase A 中 A8 由 RED 转 GREEN +- [ ] 全部 A1~A8 GREEN,pytest 100% PASS,cargo test 100% PASS + +--- + +## 进度可视化 + +每次 commit 后,CI 跑 `scripts/red_green_report.py` 输出: + +``` +Phase A 基线: 56 项断言 +当前: 红 X 项 / 绿 Y 项 / 总 56 项 +本次 commit 影响: A3 +4 GREEN, A4 +2 GREEN +``` + +并写入 PR 描述的进度行,作为 plan 完成度的客观依据。 + +--- + +## 验收闭环 + +最终全部 task 完成后,要求满足: + +- [ ] `cargo test --workspace` 100% PASS +- [ ] `pytest test/` 100% PASS +- [ ] `cargo clippy --all-targets -- -D warnings` 无 warning +- [ ] `cargo fmt --check` 通过 +- [ ] `ruff check czsc test` 无 issue +- [ ] `basedpyright czsc` 无 error +- [ ] `pytest --cov=czsc` 公共 API 覆盖率 ≥ 90%,整体 ≥ 70% +- [ ] `find czsc -name '*.py' | xargs wc -l` ≤ 12500 +- [ ] `grep -r "rs_czsc" examples/ docs/` 无结果 +- [ ] `grep -r "CZSC_USE_PYTHON" czsc/` 无结果 +- [ ] `pip install czsc==1.0.0` 后 `python -X importtime -c "import czsc"` import 时间 ≤ 300ms +- [ ] 10 万根 K 线 CZSC 完整分析 ≤ 200ms(M2 Mac) +- [ ] `MIGRATION_NOTES.md` §2.1 / §2.2 / §2.3 / §2.4 全部填好 diff --git a/docs/superpowers/specs/2026-05-03-rust-czsc-migration-design.md b/docs/superpowers/specs/2026-05-03-rust-czsc-migration-design.md index c3355a0eb..5d5912066 100644 --- a/docs/superpowers/specs/2026-05-03-rust-czsc-migration-design.md +++ b/docs/superpowers/specs/2026-05-03-rust-czsc-migration-design.md @@ -1,8 +1,8 @@ # Rust 实现的 czsc 核心对象迁移 — 设计方案 - 作者:Claude Code (协作设计) -- 日期:2026-05-03 -- 状态:草案,待评审 +- 日期:2026-05-03(v0.3 — 全局一致性修订 + 迁移流程改 TDD) +- 状态:v0.3 草案(全局一致性修订 + 迁移路径改写为 superpowers TDD 模式,详见 §10) - 关联需求:[飞书需求文档](https://s0cqcxuy3p.feishu.cn/wiki/OZ3dwY68oiJhTdk7HRCcXNPRnCh) - 关联仓库:`czsc`(本仓库)、`rs-czsc`(参考实现,路径 `../offline/rs_czsc`) @@ -20,14 +20,14 @@ ### 0.2 已确认的设计原则 | # | 决策 | 状态 | -|---|---|---| +|-|-|-| | 1 | czsc 内置 Rust 源码(仓库内带 Rust workspace),rs-czsc 仍作为独立项目存在 | 已确认 | | 2 | 必迁 Rust crate:`czsc-core` + `czsc-signals` + `czsc-trader` + `czsc-utils` + `czsc-ta`(+ 配套 proc-macro / error-support) | 已确认 | -| 3 | Python 端**保留**:`connectors / envs / core / mock / svc / sensors / strategies / utils(精简)` | 已确认 | -| 4 | Python 端**删除**:`py / eda.py / aphorism.py / features / fsa` | 已确认 | +| 3 | Python 端**保留**:`connectors / envs / mock(薄层) / svc / sensors / strategies(临时) / utils(精简)`(`core.py` 不再保留,公共名称由 `czsc/__init__.py` 直接 re-export `czsc._native`) | 已确认 | +| 4 | Python 端**删除**:`py` / `features`(`eda.py` / `aphorism.py` / `fsa/` 暂保留) | 已确认 | | 5 | 所有面向用户的 API 都通过 `czsc.xxx` 暴露,禁止用户感知 `rs_czsc` 或 `czsc._native` | 已确认 | | 6 | 构建方式:`maturin + Rust workspace`,扩展模块名 `czsc._native` | 已确认 | -| 7 | rs-czsc 同步策略:复制即 fork,后续靠 cherry-pick + `MIGRATION_NOTES.md` 记录 | 已确认 | +| 7 | rs-czsc **后续不再维护**:czsc 一次性 fork 后独立演进,`MIGRATION_NOTES.md` 仅记录基线 commit;不再做季度同步 / cherry-pick | 已确认 | --- @@ -43,44 +43,48 @@ czsc/ ├── crates/ # ← 新增:Rust workspace 成员(9 个 crate) │ ├── czsc-core/ # 缠论核心算法(CZSC/FX/BI/ZS/...) │ ├── czsc-signals/ # 30+ 信号函数(macro 注册) -│ ├── czsc-trader/ # 回测/权重/优化/CzscTrader +│ ├── czsc-trader/ # 回测/权重/优化/CzscTrader/StrategyBase(待迁) │ ├── czsc-utils/ # BarGenerator/日历/错误/性能详情 -│ ├── czsc-ta/ # 25+ 技术分析算子(pure + mixed) +│ ├── czsc-ta/ # TA 算子(仅保留被 czsc-signals/czsc-trader 实际调用的) │ ├── czsc-signal-macros/ # proc-macro:#[signal_module] 注册 │ ├── error-macros/ # proc-macro:错误类型生成 │ ├── error-support/ # 错误基础库 -│ └── czsc-python/ # PyO3 binding 总入口 → 产出 czsc._native +│ └── czsc-python/ # PyO3 binding 总入口 → 产出 czsc._native(所有暴露对象支持 pickle) │ -├── czsc/ # Python 包(精简后约 8K 行) +├── czsc/ # Python 包(精简后约 12K 行) │ ├── __init__.py # 重写,统一从 czsc.xxx 暴露 │ ├── _native.pyi # type stub(pyo3-stub-gen 生成) │ ├── envs.py # 仅保留 czsc_min_bi_len / czsc_max_bi_num / czsc_verbose -│ ├── core.py # 极简:from czsc._native import * -│ ├── mock.py # 测试数据生成(保留) -│ ├── strategies.py # 保留(svc/sensors 依赖的 CzscStrategyBase / CzscJsonStrategy) +│ ├── mock.py # 薄壳:转发 wbt 的 mock 函数(generate_symbol_kines) +│ ├── strategies.py # 临时保留:CzscStrategyBase / CzscJsonStrategy(应迁 Rust,迁移完成后删除) +│ ├── aphorism.py # 保留 +│ ├── eda.py # 保留(暂留,后续重构) +│ ├── fsa/ # 保留 │ ├── connectors/ # 完整保留(5 个连接器:tushare/tqsdk/ccxt/research/cooperation) │ ├── sensors/ # 完整保留(CTAResearch + 工具) │ ├── svc/ # 完整保留(Streamlit 可视化组件) │ ├── signals/ # 极薄,仅 re-export Rust 信号到 czsc.signals.{bar,cxt,...} -│ ├── traders/ # 极薄,仅 re-export Rust 对象 + 保留 dummy/sig_parse -│ ├── ta/ # 极薄,re-export Rust ta 算子(与 utils/ta.py 命名空间分离) -│ └── utils/ # 大幅精简至 ~2.5K 行(详见 §3.2) +│ ├── traders/ # 极薄,仅 re-export Rust 对象 + 保留 sig_parse(待评估 Rust 是否已实现) +│ ├── ta/ # 极薄,re-export Rust ta 算子(czsc.utils.ta 不再保留) +│ └── utils/ # 大幅精简至 ~3K 行(详见 §3.2) │ ├── test/ # Python pytest 套件 -│ ├── conftest.py # 注入 czsc.mock fixtures -│ ├── unit/ # 单元测试(核心对象、信号、TA) -│ ├── integration/ # 集成测试(trader、回测、connectors) +│ ├── conftest.py # 注入 wbt mock fixtures +│ ├── unit/ # 单元测试(核心对象、信号、TA、pickle 可序列化) +│ ├── integration/ # 集成测试(trader、回测对接 wbt、connectors) │ └── compat/ # API 兼容快照(锁定 czsc.* 公共名称) │ ├── docs/ │ ├── superpowers/specs/ # 本设计文档存放地 -│ ├── MIGRATION_NOTES.md # 记录从 rs-czsc 哪个 commit 复制而来 +│ ├── MIGRATION_NOTES.md # 仅记录从 rs-czsc 哪个 commit 复制而来(无需后续同步) │ └── ... │ └── examples/ # 保留示例(同步调整 import) ``` -**删除**:`czsc/py/`、`czsc/eda.py`、`czsc/aphorism.py`、`czsc/features/`、`czsc/fsa/`,以及 `czsc/utils/` 中所有可视化/报告生成模块(详见 §3.2)。 +**删除**:`czsc/py/`、`czsc/features/`,以及 `czsc/utils/` 中所有可视化/报告生成模块(详见 §3.2)。 +**暂保留待评估/重构**:`czsc/eda.py`(后续重构)、`czsc/strategies.py`(应迁 Rust)、`czsc/traders/sig_parse.py`(待评估 Rust 是否已等价实现)。 +**保留**:`czsc/aphorism.py`、`czsc/fsa/`、`czsc/utils/oss.py`。 --- @@ -124,7 +128,7 @@ codegen-units = 1 │ ┌──────────┼──────────┬───────────┐ czsc-utils czsc-core czsc-ta czsc-signal-macros (proc-macro) - │ │ │ │ + │ │ │ │ └─────┬────┴──────────┴───────────┘ │ czsc-signals @@ -139,11 +143,11 @@ czsc-utils czsc-core czsc-ta czsc-signal-macros (proc-macro) ### 2.3 czsc-ta 集成方式 - crate 完整复制到 `crates/czsc-ta/` -- 业务调用:`czsc-trader / czsc-signals` 中对 czsc-ta 的引用保持原样 +- 业务调用:`czsc-trader` / `czsc-signals` 中对 czsc-ta 的引用保持原样;**未被调用的 ta 算子(如部分仅在 rs-czsc 中孤立保留的指标)随迁移裁剪掉,不进入 czsc 仓库**,控制 Rust crate 体积与编译时长 - PyO3 暴露:`czsc-ta` 启用 `rust-numpy` feature,通过 `czsc-python` 注册为 `czsc._native.ta` 子模块 -- Python 端命名空间分离: - - `czsc.ta.*` ← Rust 实现(高性能、向量化、NumPy 互操作) - - `czsc.utils.ta.*` ← Python TA-Lib wrapper(保持向后兼容) +- Python 端唯一入口: + + - `czsc.ta.*` ← Rust 实现(高性能、向量化、NumPy 互操作);**不再保留 czsc.utils.ta 的 Python TA-Lib wrapper**(由 Rust 端统一提供) ### 2.4 PyO3 binding 层(`crates/czsc-python`) @@ -165,13 +169,15 @@ fn _native(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { - **扩展模块名**:`czsc._native`(`pyproject.toml` 中 `tool.maturin.module-name = "czsc._native"`) - **type stubs**:`czsc/_native.pyi` 由 `pyo3-stub-gen` 自动生成(沿用 rs-czsc 现有做法) - **ABI 策略**:`abi3-py310`,一次构建多 Python 版本通用,简化发布矩阵 +- **pickle 序列化(强制要求)**:所有通过 PyO3 暴露给 Python 的对象(`CZSC`/`BarGenerator`/`Position`/`CzscTrader`/`CzscSignals` 等),**必须实现 `__getstate__` / `__setstate__`**,使其可以被 `pickle.dumps`/`pickle.loads`,并在多进程(Streamlit/Joblib/Dask)和断点续跑场景下保持一致。Rust 侧建议通过 `serde` 序列化为 bincode/JSON,再桥接到 PyO3 的 `__reduce__`。 +- **验收**:`test/unit/test_pickle.py` 对每个公开 PyO3 类做 `roundtrip` 测试,禁止新增不可 pickle 的对象 ### 2.5 关键调整:Rust 端可见性提升 在 rs-czsc 中以下 4 个函数为 `pub(crate)`,迁移到 czsc 后需提升为 `pub` 并加 PyO3 binding: | 函数 | rs-czsc 位置 | 当前可见性 | 调整后 | -|---|---|---|---| +|-|-|-|-| | `remove_include` | `czsc-core/src/analyze/utils.rs:32` | `pub(crate)` | `pub` + `#[pyfunction]` | | `check_fxs` | `czsc-core/src/analyze/utils.rs:119` | `pub(crate)` | `pub` + `#[pyfunction]` | | `check_fx` | `czsc-core/src/analyze/utils.rs:158` | `pub(crate)` | `pub` + `#[pyfunction]` | @@ -182,10 +188,10 @@ fn _native(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { ### 2.6 与 rs-czsc 的同步策略 -- **不用** git submodule / subtree。**复制即 fork**。 -- 维护 `docs/MIGRATION_NOTES.md` 记录基线 commit hash 和迁移日期。 -- 后续 rs-czsc 有 bugfix 通过手动 cherry-pick 同步,PR 描述中标注 `[sync from rs-czsc ]`。 -- 如果 czsc 端做了独立改动(如新增 `is_trading_time`),同样在 `MIGRATION_NOTES.md` 中标记为 "czsc-only"。 +- **不用** git submodule / subtree。**复制即 fork,rs-czsc 后续不再维护**,czsc 独立演进。 +- 维护 `docs/MIGRATION_NOTES.md` 仅记录基线 commit hash + 迁移日期,作为历史溯源;不再做季度同步。 +- czsc 内部对原 rs-czsc 模块所做的改动(如 `pub(crate)` → `pub`、新增/裁剪算子)一律按本仓库的常规 PR 流程合入,无需在 PR 描述中标注 sync 来源。 +- czsc-only 的能力(如 `is_trading_time`、被裁剪的 ta 算子清单)在 `MIGRATION_NOTES.md` 的"czsc-only 改动"小节集中列出。 --- @@ -194,31 +200,31 @@ fn _native(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { ### 3.1 `czsc/__init__.py` 公共 API 表 | 命名空间 | 来源 | 暴露名 | -|---|---|---| +|-|-|-| | **顶层核心对象** | `czsc._native` | `CZSC, FX, BI, ZS, RawBar, NewBar, Freq, Mark, Direction, Operate, Signal, Event, Position, BarGenerator, format_standard_kline, freq_end_time, is_trading_time, check_bi, check_fx, check_fxs, remove_include` | -| **顶层交易对象** | `czsc._native` | `CzscTrader, CzscSignals, WeightBacktest, generate_czsc_signals, get_unique_signals` | -| **顶层 TA / 性能函数** | `czsc._native` | `daily_performance, top_drawdowns, ultimate_smoother, rolling_rank, ema, sma, boll_positions, ...`(25+ 函数完整列表见 stubs) | +| **顶层交易对象** | `czsc._native` | `CzscTrader, CzscSignals, generate_czsc_signals, get_unique_signals`
`WeightBacktest`:**czsc 内部 `from wbt import WeightBacktest` 后照常暴露为 `czsc.WeightBacktest`**(czsc 不重新实现,但保持公共 API 名称兼容);wbt 是硬依赖。 | +| **顶层 TA / 性能函数** | `czsc._native` | `ultimate_smoother, rolling_rank, ema, sma, boll_positions, ...`(25+ 函数完整列表见 stubs,来自 `czsc._native`)
`daily_performance` / `top_drawdowns`:**czsc 内部 `from wbt import ...` 后照常暴露为 `czsc.daily_performance` / `czsc.top_drawdowns`**,对外保持 czsc.\* 公共 API 兼容;wbt 作为 `pyproject.toml` 中的**硬依赖**(不是 optional) | | `czsc.ta.*` | `czsc._native.ta` | Rust TA 算子的子模块入口,与顶层重复暴露兼容 | | `czsc.signals.{bar,cxt,tas,vol,...}` | `czsc._native.signals.*` | 30+ 信号函数按类别分组 | -| `czsc.traders.*` | Python 薄层 + `czsc._native` | `CzscTrader, CzscSignals, WeightBacktest, DummyBacktest, SignalsParser` | +| `czsc.traders.*` | Python 薄层 + `czsc._native` | `CzscTrader, CzscSignals, generate_czsc_signals, get_unique_signals`;`WeightBacktest`(来自 wbt);`SignalsParser`(待评估 Rust 是否已等价实现,迁移完成前作 Python 薄层)。**DummyBacktest 已删除**。 | | `czsc.connectors.*` | 完整保留 | `tushare/tqsdk/ccxt/research/cooperation` | | `czsc.sensors.*` | 完整保留 | `CTAResearch` 等 | | `czsc.svc.*` | 完整保留 | Streamlit dashboard 组件 | -| `czsc.mock` | 完整保留 | `generate_symbol_kines, generate_klines_with_weights` | +| `czsc.mock` | 薄层(转发 wbt) | `generate_symbol_kines, generate_klines_with_weights`(czsc 仍然暴露这两个名称,**实现内部 `from wbt.mock import ...`**,czsc.mock 退化为转发壳,不再维护重复实现) | | `czsc.envs` | 精简 | `czsc_min_bi_len, czsc_max_bi_num, czsc_verbose` | -| `czsc.strategies` | 保留 | `CzscStrategyBase, CzscJsonStrategy` | -| `czsc.utils.*` | 大幅精简(§3.2) | `cache, io, log, ta, calendar, sig, plot_backtest, plotly_plot, kline_quality, data_client, trade_utils` | +| `czsc.strategies` | 保留 | `CzscStrategyBase, CzscJsonStrategy`(**临时保留**,应迁 Rust 端 `czsc-trader`;迁移完成后 `czsc.strategies` 退化为 `from czsc._native import ...` 薄层或直接删除) | +| `czsc.utils.*` | 大幅精简(§3.2) | `cache, io, log, calendar, sig, plot_backtest, plotly_plot, kline_quality, data_client, trade_utils, oss`(不再保留 `ta`,由 `czsc.ta.*` Rust 实现替代) | -`__init__.py` 中**移除** `_LAZY_MODULES` / `_LAZY_ATTRS` 延迟加载机制。Rust 扩展模块加载快(< 50 ms),所有公共 API 在顶层直接 import。删除 `__getattr__` 动态加载逻辑。 +`__init__.py` 中**移除**`_LAZY_MODULES` / `_LAZY_ATTRS` 延迟加载机制。Rust 扩展模块加载快(< 50 ms),所有公共 API 在顶层直接 import。删除 `__getattr__` 动态加载逻辑。 ### 3.2 `czsc/utils/` 精简清单 | 文件 | 决策 | 说明 | -|---|---|---| +|-|-|-| | `cache.py` | **保留** | 磁盘缓存基础设施 | | `io.py` | **保留** | dill / json 读写 | | `log.py` | **保留** | loguru 配置 | -| `ta.py` | **保留** | TA-Lib Python wrapper(与 Rust ta 互补) | +| `ta.py` | **删除** | 由 `czsc.ta.*`(Rust 实现,PyO3 暴露)替代,不再保留 Python TA-Lib wrapper | | `calendar.py` | **保留** | 交易日历薄层(核心算法已迁 Rust) | | `sig.py` | **保留**(精简) | 信号工具函数 `unique_signals` 等 | | `kline_quality.py` | **保留** | 数据质量校验 | @@ -234,12 +240,12 @@ fn _native(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { | `html_report_builder.py` | **删除** | 同上 | | `word_writer.py` | **删除** | 同上 | | `features.py` | **删除** | 与 czsc/features/ 同删 | -| `oss.py` | **删除** | 阿里云对象存储,业务代码 | +| `oss.py` | **保留** | 阿里云对象存储工具,研究/数据落地场景仍有用 | | `st_components.py` | **删除** | svc 已包含 Streamlit 组件 | | `corr.py` | **删除** | 业务代码 | | `signal_analyzer.py` | **删除** | 业务代码 | -精简后 `czsc/utils/` 从 ~10.7K 行降到 ~2.5K 行。 +精简后 `czsc/utils/` 从 \~10.7K 行降到 \~3K 行(含 oss.py 等保留模块)。 ### 3.3 `czsc/signals/` 与 `czsc/traders/` 极薄化 @@ -261,14 +267,16 @@ from czsc._native.signals.bar import * # noqa: F401,F403 ```python # czsc/traders/__init__.py from czsc._native import ( - CzscTrader, CzscSignals, WeightBacktest, + CzscTrader, CzscSignals, generate_czsc_signals, get_unique_signals, ) -from czsc.traders.dummy import DummyBacktest # 保留 Python 端 -from czsc.traders.sig_parse import SignalsParser # 保留 Python 端 +# WeightBacktest 由 wbt 包提供,czsc 内部 re-export 保持公共 API 兼容(wbt 是 pyproject.toml 中的硬依赖) +from wbt import WeightBacktest +# sig_parse: 待评估 Rust 是否已等价实现;评估完成后改为 from czsc._native import SignalsParser +from czsc.traders.sig_parse import SignalsParser ``` -**删除** `czsc/traders/` 中:`base.py`, `cwc.py`, `rwc.py`, `optimize.py`, `weight_backtest.py`, `performance.py`(Rust 已实现等价或更优实现)。 +**删除**`czsc/traders/` 中:`base.py`、`cwc.py`、`rwc.py`、`optimize.py`、`weight_backtest.py`、`performance.py`、`dummy.py`(Rust 已实现等价或更优;WeightBacktest 全部改用 wbt 包)。仅 `sig_parse.py` 临时保留待评估。 ### 3.4 `czsc.envs` 精简 @@ -282,18 +290,6 @@ czsc_verbose: bool = False # 详细日志 通过 `czsc._native.set_envs(min_bi_len=..., max_bi_num=..., verbose=...)` 一次性传给 Rust 端。`czsc/envs.py` 仅是这三个值的 Python 端配置入口。 -### 3.5 `czsc/core.py` 极简化 - -```python -# czsc/core.py -"""向后兼容入口:所有名称从 czsc._native 直接导入。""" -from czsc._native import * # noqa: F401,F403 -``` - -整个文件从 134 行降到 < 5 行。原有的 `if os.getenv("CZSC_USE_PYTHON")` 双路由逻辑彻底删除。 - ---- - ## 4. 测试体系 ### 4.1 整体原则 @@ -319,7 +315,7 @@ crates// **规则**: | 类别 | 位置 | 工具 | 触发 | -|---|---|---|---| +|-|-|-|-| | 单元测试 | `src/**/*.rs` 内 `#[cfg(test)] mod tests` | 标准 `cargo test` | `cargo test -p ` | | 集成测试 | `crates//tests/` | 标准 `cargo test` | `cargo test --test ` | | Benchmark | `crates//tests/benchmarks.rs` | `criterion` | `cargo bench` | @@ -358,10 +354,9 @@ test/ **关键约束(沿用 czsc CLAUDE.md 规范)**: -- 所有测试数据**统一通过** `czsc.mock` 模块获取,禁止硬编码模拟数据 +- 所有测试数据**统一通过 `czsc.mock` 模块入口**获取;`czsc.mock` 是 [wbt](https://github.com/zengbin93/wbt) mock 函数的转发壳(czsc 不再维护重复实现),禁止在测试中硬编码模拟数据 - 测试文件命名 `test_*.py`,使用 `pytest` 框架 - 测试 fixtures 通过 `conftest.py` 共享 -- 模拟数据使用 `generate_symbol_kines` 生成,支持多品种、多频率、可重现的随机数据 **Mock 策略**: @@ -371,7 +366,7 @@ test/ **质量门槛**: -- 公共 API 测试覆盖率 ≥ 70%(`pytest --cov=czsc --cov-report=xml`) +- 公共 API 测试覆盖率 ≥ 90%(`pytest --cov=czsc --cov-report=xml`) - 所有 `czsc.__all__` 中的名称必须在 `test/compat/test_public_api.py` 中有快照 - 安装冒烟测试 `test/smoke/` 在 CI 的 wheel 包发布前必跑 @@ -395,121 +390,151 @@ uv run maturin develop --release # 本地构建 Rust 扩展并安装到当 ## 5. 迁移路径与步骤 -按 8 个 Phase 推进,每个 Phase 是一个可独立 commit / PR 的工作单元。 + +**本章迁移工作流**采用 [superpowers](https://github.com/anthropics/superpowers) 的 TDD 范式 + plan/execute 工作流。所有迁移步骤遵守 `superpowers:test-driven-development` 的 Iron Law:**没有失败测试就不写实现代码**。涉及的 superpowers skills:`brainstorming`(讨论 spec)→ `writing-plans`(产出 plan)→ `using-git-worktrees`(隔离工作区)→ `executing-plans` / `subagent-driven-development`(执行)→ `test-driven-development`(每个 task 内部的 RGR 循环)→ `finishing-a-development-branch`(合并发布)。 + -### Phase 0 — 准备(0.5 天) +### 5.0 总体方法 -- 创建迁移分支 `refactor/rust-czsc-migration` -- 在 `docs/MIGRATION_NOTES.md` 记录 rs-czsc 基线 commit hash -- 在 CI 中暂时关闭"必须通过测试"门槛(避免迁移过程中持续红状态) -- 备份当前 czsc/ 全部 Python 模块到 `_legacy/` 临时目录(不入版本库,便于对比) +- **spec → plan → execute** 三段式:本设计文档存放在 `docs/superpowers/specs/2026-05-03-rust-czsc-migration.md`;据此产出 `docs/superpowers/plans/2026-05-03-rust-czsc-migration.md`(按 superpowers:writing-plans 规范写);plan 中每个 task 都是一个完整的 RED→GREEN→REFACTOR→COMMIT 循环。 +- **RED→GREEN→REFACTOR 循环**:每个 task ① 写最小失败测试 ② 跑测试看到失败(必须 fail,不能 error)③ 写最小实现 ④ 跑测试看到通过 ⑤ 必要时重构 ⑥ commit。任意步骤跳过都视为破坏 Iron Law,task 重做。 +- **Bite-sized 任务粒度**:每个步骤 2–5 分钟;plan 中每个 task 列出确切文件路径、测试代码、运行命令和预期输出。禁止 placeholder("TBD" / "implement later" / "类似 Task N")。 +- **Worktree 隔离**:所有迁移工作在 `git worktree add ../czsc-rust-migration refactor/rust-czsc-migration` 中进行;不污染 master 分支。 +- **测试驱动顺序**:自上而下——先把验收标准(§6 全表)翻译成失败测试形成"验收基线",再逐 crate 用 TDD 实现到 GREEN。**不允许**先复制 Rust 源码再补测试。 -### Phase 1 — 搭建 Rust workspace 骨架(1 天) +### 5.1 Phase 0 — Spec 评审 + Plan 产出(0.5 天) -- 在 czsc 仓库根目录新增 `Cargo.toml`、`rust-toolchain.toml`、`.cargo/config.toml` -- 创建空的 `crates/{czsc-core,czsc-signals,czsc-trader,czsc-utils,czsc-ta,czsc-signal-macros,error-macros,error-support,czsc-python}/` 目录 -- 每个 crate 一个空 `Cargo.toml` + `src/lib.rs`,先验证 `cargo build --workspace` 能通过 -- 验收:`cargo build --workspace` 成功 +**RED 前提(暂不写代码)**。本 phase 全部产出物为文档: -### Phase 2 — 复制 Rust 源码(1 天) +1. 用 `superpowers:brainstorming` 技能 review 本设计 spec,识别"没说清楚"的点(pickle 协议格式?wbt 版本约束?rs-czsc 基线 commit?)。 +2. 用 `superpowers:using-git-worktrees` 创建 worktree `../czsc-rust-migration`,进入。 +3. 在 worktree 中执行 `git rev-parse HEAD`(rs-czsc 子目录)锁定基线 commit,写入 `docs/MIGRATION_NOTES.md`。 +4. 用 `superpowers:writing-plans` 技能产出 plan 文件 `docs/superpowers/plans/2026-05-03-rust-czsc-migration.md`,按本章 5.2–5.12 展开为 \~80 个 bite-sized task。 +5. **验收**:plan 通过自审 checklist(无 TBD / 每 task 有 test code + run command + expected output / 每 task 都以 commit 结尾)。 -按依赖顺序从 rs-czsc 复制 crate 内容(不含 PyO3 binding): +### 5.2 Phase A — 写验收级失败测试("测试基线",1.5 天) -1. `error-macros` → `error-support` -2. `czsc-utils` -3. `czsc-core`、`czsc-ta`、`czsc-signal-macros` -4. `czsc-signals` -5. `czsc-trader` +把 §6 验收表 + §3.1 公共 API 表 翻译成可执行测试,跑出全 RED。**这是 superpowers TDD 的关键差异点**:传统 plan 是"先做实现再补测试",这里反过来。 -操作约束: -- 仅复制 `src/`、`tests/`、`Cargo.toml`、`README.md`(如有) -- `Cargo.toml` 中 `version.workspace = true` 等继承配置保持不变 -- **不**复制 `python/` 子目录(rs-czsc 中 PyO3 binding 散落各业务 crate,本次重组到 `czsc-python` 单独 crate) -- 验收:`cargo test --workspace` 通过(仅 Rust 测试) +| Task | RED 测试 | 断言内容 | 预期失败原因(v0.2 当前状态) | +|-|-|-|-| +| A1 | `test/compat/test_public_api.py` | 从 `czsc.__all__` 与 stub 中读出 80+ 公共名称,逐个 `getattr(czsc, name)` 不抛异常;快照存 `test/compat/snapshots/api_v1.json` | czsc 当前 import 路径与本设计 §3.1 不完全一致(DummyBacktest、czsc.utils.ta 等仍在) | +| A2 | `test/unit/test_pickle.py` | 对每个 PyO3 暴露类(CZSC/BarGenerator/Position/CzscTrader/CzscSignals)做 `pickle.loads(pickle.dumps(obj)) == obj` roundtrip | 当前 PyO3 类未实现 `__getstate__` / `__setstate__` | +| A3 | `test/unit/test_core_parity.py` | 固定 seed 的 `wbt.mock.generate_symbol_kines` 输入下,`czsc.CZSC(bars).fxs/bi_list/zs_list` 与 rs-czsc 基线快照逐一相等(容差 0) | czsc 仓库尚未内置 Rust 实现,缺少基线快照对照机制 | +| A4 | `test/unit/test_signals_parity.py` | 30+ 信号函数对 mock 数据逐一比对 rs-czsc 基线输出 | 同 A3 | +| A5 | `test/unit/test_ta_parity.py` | `czsc.ta.{ema,sma,rolling_rank,...}` 相对 Python TA-Lib 容差 ≤ 1e-6 | 当前未暴露 czsc.ta.\*(只有 czsc.utils.ta 的 Python wrapper) | +| A6 | `test/unit/test_trading_time.py` | `is_trading_time` 在 A 股/港股/数字货币 三类日历上的若干典型时间点结果正确 | 该函数 Rust 端尚未实现 | +| A7 | `test/integration/test_weight_backtest.py` | `czsc.WeightBacktest` 等于 `wbt.WeightBacktest`;同一份权重输入产出指定的统计结果 | 当前 czsc.WeightBacktest 来自 czsc.core 的 Python fallback / rs-czsc,非 wbt | +| A8 | `test/smoke/test_install.py` | 在干净 venv 中 `pip install ./dist/czsc-*.whl` 后 `python -c "import czsc; czsc.CZSC(...)"` 成功 | 尚未切到 maturin,wheel 不包含 Rust 扩展 | -### Phase 3 — 实现 czsc-python binding 层(2 天) +- **验收**:`pytest test/ -v` 输出全部为 FAIL(不能是 ERROR / SKIP);CI 上的 RED 状态被记录到 plan 文件作为 baseline。 +- **禁止**在本 phase 写任何 Rust 实现或修改 czsc/\* 业务代码——只写测试。 -- 在每个业务 crate 中新增 `src/python/mod.rs`,把原来散落各 crate 的 `#[pyclass]` / `#[pyfunction]` 集中到这里,并提供 `pub fn register(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()>` 入口 -- 把 `check_bi / check_fx / check_fxs / remove_include` 提为 `pub` 并补 `#[pyfunction]` -- 在 `czsc-utils` 中新增 `is_trading_time` Rust 实现 + `#[pyfunction]` -- 在 `crates/czsc-python/src/lib.rs` 中聚合所有 `register()` 调用 -- 配置 `pyproject.toml`:`build-system = maturin`,`tool.maturin.module-name = "czsc._native"` -- 验收:`maturin develop` 后 Python 端 `from czsc._native import CZSC` 可成功 +### 5.3 Phase B — Rust workspace 骨架(GREEN 第一层,1 天) -### Phase 4 — Python 包重构(2 天) +1. **RED**:写 `tests/rust/test_workspace_layout.sh`(或 cargo metadata 检查),断言存在 `crates/{czsc-core, czsc-utils, czsc-ta, czsc-signals, czsc-trader, czsc-signal-macros, error-macros, error-support, czsc-python}` 9 个成员,且 `cargo build --workspace` 通过。 +2. 跑测试看到失败(目录不存在)。 +3. **GREEN**:创建 `Cargo.toml`(按 §2.1)、9 个空 crate(每个一个空 `lib.rs` + 最小 `Cargo.toml`)。 +4. 跑测试看到通过。 +5. Commit: `feat(rust): scaffold workspace with 9 empty crates`。 -- 重写 `czsc/__init__.py`(按 §3.1 表格) -- 极简 `czsc/core.py`、`czsc/envs.py` -- 极薄化 `czsc/signals/`、`czsc/traders/` -- 新增 `czsc/ta/__init__.py`(re-export `czsc._native.ta`) -- 验收:`python -c "import czsc; czsc.CZSC"` 等所有顶层 API 可用 +### 5.4 Phase C — czsc-utils 测试驱动迁移(1 天) -### Phase 5 — Python 删减(1 天) +采用"复制即测试"模式:**不**整体复制 src/,而是按 rs-czsc 的测试逐个迁移,每个测试都走 RED → 复制对应 src 文件 → GREEN。 -- 删除 `czsc/py/`、`czsc/eda.py`、`czsc/aphorism.py`、`czsc/features/`、`czsc/fsa/` -- 删除 `czsc/utils/` 中按 §3.2 标记为"删除"的文件 -- 删除 `czsc/traders/` 中除 `dummy.py` / `sig_parse.py` / `__init__.py` 之外的所有文件 -- 删除 `czsc/signals/` 中所有 Python 实现的信号文件(仅留 re-export) -- 验收:`python -m compileall czsc/` 通过;`python -c "import czsc"` 不报 import 错误 +1. **子循环 1(freq_data 模块)**:① 把 rs-czsc 的 `czsc-utils/tests/test_freq_data.rs` 复制到 `crates/czsc-utils/tests/`。② `cargo test -p czsc-utils freq` 看到 RED(src 还空着)。③ 复制 `czsc-utils/src/freq_data.rs`。④ 跑测试通过。⑤ commit。 +2. **子循环 2(BarGenerator 模块)**:同样模式。 +3. **子循环 3(is_trading_time 新增能力)**:先写测试用例(A 股/港股/数字货币三组),看到 RED;写 Rust 实现;GREEN;commit。 +4. **子循环 4(PyO3 binding)**:在 Python 端 `test/unit/test_bar_generator.py` 增加细粒度测试(独立于 A3 parity 测试),断言 `czsc._native.BarGenerator` 行为;用 `maturin develop` 让其失败;在 `czsc-utils/src/python/mod.rs` 加 `#[pymodule]` + 在 `czsc-python` 注册;GREEN;commit。 -### Phase 6 — 测试体系重构(2 天) +**本 phase 完成的判定**:`cargo test -p czsc-utils` 全过;Phase A 中 A6(is_trading_time)由 RED 转 GREEN。 -- 按 §4.3 重组 `test/` 目录 -- 删除已无对应实现的测试(如 `test_eda.py`、`test_features.py`) -- 重写 `conftest.py`,注入 `czsc.mock` fixtures -- 新增 `test/compat/test_public_api.py` 锁定公共 API -- 新增 `test/smoke/` 冒烟测试 -- 验收:`pytest test/unit test/integration` 全绿 +### 5.5 Phase D — czsc-core 测试驱动迁移(2 天) -### Phase 7 — examples 与文档同步(1 天) +对照 rs-czsc 的 `czsc-core` 模块清单(FX / BI / ZS / CZSC / Direction / Mark / Operate / Signal / Event / Position 等),每个数据类型一个子循环: -- 调整 `examples/` 中的 import 路径(移除已删除的模块引用) -- 更新 `CLAUDE.md`:构建命令改为 `maturin develop`、测试命令同步更新 -- 更新 `README.md`:标注新架构、安装方式 -- 撰写 `docs/MIGRATION_NOTES.md`:从旧版 czsc 升级的破坏性变更清单 -- 验收:所有保留的 examples 可在新环境中跑通 +1. 子循环命名规则:`test_.rs`(Rust 单元)+ `test__py.py`(PyO3 binding 行为)。 +2. 每个子循环:复制对应 rs-czsc 测试 → RED → 复制对应 src → GREEN → 加 PyO3 binding 测试 → RED → 写 binding → GREEN → 加 pickle roundtrip 断言 → RED → 写 `__getstate__/__setstate__` → GREEN → commit。 +3. **关键 4 个 `pub(crate)` → `pub` 的可见性提升**(`check_bi/check_fx/check_fxs/remove_include`)作为独立子循环:先写 Python 测试 `czsc.check_bi(...)`,RED 因为 binding 未注册 → 提升可见性 + 加 PyO3 → GREEN。 -### Phase 8 — CI / 发布验证(1 天) +**本 phase 完成的判定**:A3(core_parity)由 RED 转 GREEN;A2(pickle)对 czsc-core 涉及的所有类转 GREEN。 -- GitHub Actions 配置三阶段: - 1. Rust:`cargo fmt --check` + `cargo clippy -D warnings` + `cargo test --workspace` - 2. Python:`maturin build --release` + `pytest` - 3. Wheel 构建:`maturin build` 多平台(linux/macos/windows)+ smoke test -- 在测试 PyPI(test.pypi.org)发布 `czsc-1.0.0rc1` 验证 pip 安装 -- 验收:测试 PyPI 上 `pip install czsc` 后 smoke test 全过 +### 5.6 Phase E — czsc-ta + czsc-signal-macros(1.5 天) -**总工期估算:10.5 天**(不含评审和迭代)。 +1. 先按 `czsc-trader` / `czsc-signals` 中的 ta 调用链做**静态分析**,列出实际被引用的算子白名单(写入 plan 的"czsc-ta 裁剪清单")。 +2. 仅迁移白名单内的算子;每个算子一个 RED→GREEN 子循环。被裁剪的算子在 `MIGRATION_NOTES.md` 的"czsc-only 改动"小节列出。 +3. czsc-signal-macros:迁移 `#[signal_module]` proc-macro,先写一个最小宏展开测试,RED → 复制 macro 实现 → GREEN。 ---- +**本 phase 完成的判定**:A5(ta_parity)转 GREEN,被裁剪算子有书面记录。 + +### 5.7 Phase F — czsc-signals 迁移(1.5 天) + +1. 按子模块(`bar/cxt/tas/vol/pressure/obv/cvolp`)分别迁移;每个子模块一组 RED→GREEN 子循环。 +2. 每个信号函数一个 Rust 单元测试 + 一个 Python parity 测试(验证签名兼容旧 Python 实现)。 +3. 注册路径:`czsc._native.signals.bar.*` → 在 czsc/signals/bar.py 中 re-export。 + +**本 phase 完成的判定**:A4(signals_parity)转 GREEN。 + +### 5.8 Phase G — czsc-trader 迁移(含 strategies)(2 天) + +1. 迁移 `CzscTrader` / `CzscSignals` / `generate_czsc_signals` / `get_unique_signals`(每个一个子循环)。 +2. **strategies.py 迁移**:先在 Rust 端 `czsc-trader/src/strategies/` 增加 `StrategyBase` / `JsonStrategy` 实现;写 Python 测试 `test/integration/test_strategies.py` 断言 `czsc.CzscStrategyBase` 与原 Python 实现行为一致 → RED → 写 Rust 实现 → GREEN → 删 `czsc/strategies.py` → 验证 GREEN 不变。 +3. **WeightBacktest** 不在本 crate 实现:保留 Phase A7 的 RED,等 Phase I 由 wbt 接管转 GREEN。 + +### 5.9 Phase H — czsc-python 聚合 + Python 包重构(1.5 天) + +1. 在 `crates/czsc-python/src/lib.rs` 中聚合所有 `register()`(按 §2.4 模板)。 +2. 配置 `pyproject.toml`:`build-system = maturin`、`module-name = "czsc._native"`、加入 `wbt` 硬依赖。 +3. 重写 `czsc/__init__.py`:按 §3.1 表逐项 import。每加一项跑 A1(compat),看哪个名称由 RED 转 GREEN,逐步把 80+ 名称变绿。 +4. 极薄化 `czsc/signals/` / `czsc/traders/`(按 §3.3)。 +5. 新增 `czsc/ta/__init__.py` re-export `czsc._native.ta`。 +6. **删除 `czsc/core.py`**(不再保留 Python 端核心入口)。 + +### 5.10 Phase I — wbt 集成(0.5 天) + +1. `uv add wbt` 加入硬依赖。 +2. 在 `czsc/__init__.py` / `czsc/traders/__init__.py` 中 `from wbt import WeightBacktest, daily_performance, top_drawdowns`,让 A7(WeightBacktest)转 GREEN。 +3. 把 `czsc/mock.py` 改为转发 `wbt.mock.*` 的薄壳(v0.1: 537 行 → \~30 行);A1 中 mock 相关名称转 GREEN。 +4. commit: `feat: wire wbt as the canonical backtest/perf/mock provider`。 + +### 5.11 Phase J — Python 删减(0.5 天) + +1. 按 §3.2 / §9 附录 B 的"完全删除"列表逐文件 `git rm`;每删一组跑 `pytest -q` 确认仍 GREEN。 +2. 删除 `czsc/utils/ta.py`(A5 已经由 czsc.ta.\* 接管)。 +3. 删除 `czsc/traders/{base,cwc,rwc,optimize,weight_backtest,performance,dummy}.py`。 +4. 删除 `czsc/{py,features}/` 目录(保留 `aphorism.py` / `fsa/` / `eda.py`)。 +5. **验收**:`find czsc -name '*.py' | xargs wc -l` 总行数落入 \~12K(§6 Q5)。 + +### 5.12 Phase K — CI / Trusted Publishing / finishing(1 天) + +1. 写 GitHub Actions workflow:Rust(fmt/clippy/test)+ Python(maturin build + pytest)+ wheel matrix(linux/macos/windows)+ smoke。 +2. 配置 PyPI / TestPyPI 的 **Trusted Publishing(OIDC)** binding(仓库 + workflow + environment 三元组)。 +3. 发 RC 到 test.pypi.org,干净 venv 跑 A8(smoke),转 GREEN。 +4. 用 `superpowers:finishing-a-development-branch` 完成合并、release notes、tag 1.0.0。 + +**总工期估算**:14 天(Phase 0–K,比 v0.1 的 10.5 天多约 30%,多出来的时间用于 Phase A 的"验收测试基线"——这是 TDD 范式相比"先实现后补测试"必然多出的成本,但换来的是从第 1 天起就有可量化的进度指标:每个 task 都能清晰回答"这个改动到底让多少 RED 转 GREEN")。 + +### 5.13 进度可视化 + +plan 文件中每个 task 标注其会让哪几条 Phase A 测试由 RED 转 GREEN。CI 中加一个 `scripts/red_green_report.py`:在每次 commit 后输出 `红 X 项 / 绿 Y 项 / 总 N 项`,作为 PR 描述的进度行。 ## 6. 验收标准 -| # | 标准 | 验证方式 | -|---|---|---| -| 1 | `from czsc import CZSC, Signal, Event, Position, Direction, Freq, format_standard_kline` 全部成功 | `test/compat/test_public_api.py` | -| 2 | 用户代码中**不需要** `import rs_czsc` | grep `examples/` 与文档无 `rs_czsc` 引用 | -| 3 | `cargo test --workspace` 全过 | CI | -| 4 | `pytest` 全过且覆盖率 ≥ 70% | CI + coverage report | -| 5 | `maturin build --release` 可在 linux/macos/windows 三平台产出 wheel | CI matrix | -| 6 | 删减后 czsc Python 代码量从 ~44K 行降至 ~8K 行 | `find czsc -name '*.py' \| xargs wc -l` | -| 7 | 安装 wheel 后 `python -c "import czsc; czsc.CZSC"` 可用,无需用户手动安装 rs-czsc | smoke test | -| 8 | examples/ 中保留的全部示例可跑通 | 手动验证 + CI 选择性执行 | -| 9 | 不再有 `CZSC_USE_PYTHON` 环境变量分支 | `grep -r CZSC_USE_PYTHON czsc/` 应无结果 | -| 10 | rs-czsc 代码同步追溯:`docs/MIGRATION_NOTES.md` 记录基线 commit + 后续 cherry-pick 列表 | 文档存在性检查 | +
分类#验收标准验证方式
功能正确性F1from czsc import CZSC, Signal, Event, Position, Direction, Freq, format_standard_kline 等 80+ 公共名称全部成功导入test/compat/test_public_api.py 快照测试
F2缠论核心算法(分型、笔、线段、中枢)在固定随机种子的 mock 数据上结果与 rs-czsc 基线一致(容差 0)test/unit/test_core_parity.py
F330+ 信号函数在 mock 数据上的输出值与 rs-czsc 基线一致;签名兼容旧 Python 实现test/unit/test_signals_parity.py
F4TA 算子(ema/sma/rolling_rank/...)相对 Python TA-Lib 数值容差 ≤ 1e-6(除非有文档化的算法差异)test/unit/test_ta_parity.py
F5WeightBacktest 通过 wbt 包正常工作;czsc 端 example/sensors 接入 wbt 后结果与历史快照一致集成测试 + 历史结果回放
F6is_trading_time 等 czsc-only 新增能力在 A 股 / 港股 / 数字货币 三类日历上行为正确test/unit/test_trading_time.py
性能P1对 10 万根 K 线做完整 CZSC 分析(分型/笔/中枢)≤ 200 ms(M2 Mac,单进程)cargo bench -p czsc-core + Python pytest-benchmark
P230+ 信号函数批量执行单根 K 线 ≤ 50 µs P50;批量 1 万根 ≤ 80 msbenchmark CI 阈值
P3czsc 包冷启动 import 时间 ≤ 300 ms(含 Rust 扩展加载)python -X importtime -c "import czsc" + CI 阈值
质量Q1cargo test --workspace 全过;cargo clippy -D warnings 无 warning;cargo fmt --check 通过CI
Q2pytest 全过且公共 API 覆盖率 ≥ 90%;整体行覆盖率 ≥ 70%CI + coverage report(codecov)
Q3所有通过 PyO3 暴露的对象(CZSC/BarGenerator/Position/CzscTrader/CzscSignals/...)支持 pickle,可在 Streamlit / Joblib / multiprocessing 中安全传递test/unit/test_pickle.py roundtrip 测试覆盖每个类
Q4ruff check(替代 flake8/isort)+ basedpyright(替代 mypy)双向通过;type stub czsc/_native.pyi 自动生成且无人工修改CI
Q5删减后 czsc Python 代码量从 ~44K 行降至 ~12K 行(保留 aphorism/eda/fsa 后的目标值)find czsc -name '*.py' | xargs wc -l
兼容性C1用户代码不需要import rs_czscexamples/ 与文档无 rs_czsc 引用grep -r rs_czsc examples/ docs/ 应无结果
C2不再有 CZSC_USE_PYTHON 环境变量分支grep -r CZSC_USE_PYTHON czsc/ 应无结果
C3下游用户主要 import 路径(czsc.CZSCczsc.signals.bar.*czsc.traders.CzscTraderczsc.utils.cache 等)保持不破坏;已移除的名称MIGRATION_NOTES.md 中给出替代方案API 快照 + 文档检查
发布R1maturin build --release 在 linux(manylinux_2_28)/ macos(universal2)/ windows 三平台产出 wheelCI matrix
R2使用 PyPI Trusted Publishing(OIDC) 发布,不在 GitHub Actions 中存放任何 PyPI token / secretsCI workflow 配置 + PyPI 项目设置截图
R3Test PyPI(test.pypi.org)发布 czsc-1.0.0rc1,在干净环境 pip install --index-url https://test.pypi.org/simple/ czsc 后 smoke test 全过CI smoke job
R4正式 PyPI 发布后 pip install czsc 即可使用,无需用户手动安装 rs-czscpython -c "import czsc; czsc.CZSC" 成功CI 安装后 smoke job
追溯T1docs/MIGRATION_NOTES.md 记录从 rs-czsc 迁移的基线 commit hash、迁移日期、czsc-only 改动清单(含被裁剪的 ta 算子列表)文档存在性检查
T2每个删除/重命名的旧公共 API 在 MIGRATION_NOTES.md 与 release notes 中列出替代方案;major 版本号升至 1.0.0 表明破坏性变更release notes review
--- ## 7. 风险与缓解 | 风险 | 等级 | 缓解策略 | -|---|---|---| -| **`pub(crate)` → `pub` 提升后 rs-czsc 上游同步困难** | 中 | 在 `MIGRATION_NOTES.md` 标注哪些函数已"czsc-only 公开化";cherry-pick 时手动适配 | +|-|-|-| +| **czsc-only 公开化的函数**(`check_bi/check_fx/check_fxs/remove_include` 等 `pub(crate)`→`pub`)形成 czsc 仓库专属约定,与原 rs-czsc 不再回流互通 | 低 | 已在 `MIGRATION_NOTES.md` 的"czsc-only 改动"小节集中记录;rs-czsc 不再维护后无 cherry-pick 需求,长期看是**独立演进**而非"同步困难" | | **`is_trading_time` 在 Rust 端是新增实现,可能与 Python 旧逻辑行为不一致** | 中 | 增加专门的对比测试 `test_unit/test_trading_time.py`,跑历史数据集验证 | | **`czsc-ta` 的 `mixed/` 子模块依赖 NumPy 0.25.0 + abi3-py310,少数 Linux 环境下 wheel 构建失败** | 低 | CI 多平台 matrix 提早暴露;保留 `manylinux_2_28` build profile | | **公共 API 移除后下游用户代码报错(`from czsc import xxx` 失败)** | 高 | `test/compat/` 锁定 `__all__`;`docs/MIGRATION_NOTES.md` 列出所有删除的名称 + 替代方案;major 版本号升至 `1.0.0` 表明破坏性变更 | | **svc / sensors 隐式依赖被删除的模块(如 eda.py / features/)** | 中 | Phase 4 完成后立即跑 `python -c "import czsc.svc; import czsc.sensors"` 检查;如有断链,按需补 thin re-import 或迁移到保留模块 | -| **rs-czsc 自身仍在演进(version 0.1.27),迁移基线很快过时** | 中 | 在 Phase 0 锁定一个明确 commit;建立"季度同步"机制 cherry-pick 上游 fix | -| **构建工具切换(hatchling → maturin)破坏现有发布流程** | 中 | Phase 8 在 test.pypi.org 充分验证再切正式 PyPI;保留旧 hatchling 配置在 git 历史中可回滚 | +| **rs-czsc 已停止维护,未来上游 bugfix 无法回流** | 低 | czsc 一次性 fork 后独立演进;Phase 0 锁定基线 commit;后续按本仓库常规流程修 bug,不再 cherry-pick 上游 | +| **构建工具切换(hatchling → maturin)破坏现有发布流程** | 中 | 采用 **PyPI Trusted Publishing(OIDC)**,不在 GitHub 存任何 PyPI token;先在 test.pypi.org 验证 wheel + 安装路径再切正式 PyPI;旧 hatchling 配置保留在 git 历史可回滚 | | **测试体系重构丢失对原有边界场景的覆盖** | 中 | 删除测试文件前先 review 内含的 corner case,把仍有效的断言迁移到新结构 | --- @@ -517,7 +542,7 @@ uv run maturin develop --release # 本地构建 Rust 扩展并安装到当 ## 8. 附录 A — Rust crate ↔ Python 命名空间映射 | Rust 对象(来源) | Python 暴露路径 | -|---|---| +|-|-| | `czsc_core::CZSC` | `czsc.CZSC` | | `czsc_core::objects::FX` | `czsc.FX` | | `czsc_core::objects::BI` | `czsc.BI` | @@ -546,14 +571,18 @@ uv run maturin develop --release # 本地构建 Rust 扩展并安装到当 | `czsc_signals::cvolp::*` | `czsc.signals.cvolp.*` | | `czsc_trader::CzscTrader` | `czsc.CzscTrader`、`czsc.traders.CzscTrader` | | `czsc_trader::CzscSignals` | `czsc.CzscSignals`、`czsc.traders.CzscSignals` | -| `czsc_trader::WeightBacktest` | `czsc.WeightBacktest`、`czsc.traders.WeightBacktest` | +| `wbt::WeightBacktest`(外部包) | `czsc.WeightBacktest`、`czsc.traders.WeightBacktest` | | `czsc_trader::generate_czsc_signals` | `czsc.generate_czsc_signals` | | `czsc_trader::get_unique_signals` | `czsc.get_unique_signals` | -| `czsc_trader::daily_performance` | `czsc.daily_performance` | -| `czsc_trader::top_drawdowns` | `czsc.top_drawdowns` | +| `wbt.daily_performance`(外部包) | `czsc.daily_performance` | +| `wbt.top_drawdowns`(外部包) | `czsc.top_drawdowns` | --- + +**来源标识说明**:表中"来源"列为 `wbt::*` 的项表示来自 [wbt](https://github.com/zengbin93/wbt) 外部包,czsc 通过 `from wbt import ...` 进行 re-export 以保持公共 API 兼容;其余 `czsc_*` 来源均为本仓库 Rust workspace 的内部 crate。 + + ## 9. 附录 B — 删除/保留/精简清单 ### 完整保留 @@ -561,38 +590,108 @@ uv run maturin develop --release # 本地构建 Rust 扩展并安装到当 - `czsc/connectors/`(5 文件,1177 行) - `czsc/sensors/`(3 文件,301 行) - `czsc/svc/`(11 文件,4375 行) -- `czsc/mock.py`(537 行) -- `czsc/strategies.py`(410 行) +- `czsc/mock.py`(v0.1: 537 行 → 迁移后 **\~30 行**,仅转发 wbt 的 mock 函数) +- `czsc/strategies.py`(v0.1: 410 行 — **临时保留待迁 Rust**,迁移到 `czsc-trader` 后删除) +- `czsc/aphorism.py`(853 行) +- `czsc/fsa/`(8 文件,2078 行) +- `czsc/eda.py`(1213 行 — 暂保留待重构) +- `czsc/utils/oss.py`(阿里云对象存储工具) ### 大幅精简 -- `czsc/utils/` 从 ~10.7K 行降到 ~2.5K 行(保留:cache, io, log, ta, calendar, sig, kline_quality, plot_backtest, plotly_plot, data_client, trade*;删除其它) -- `czsc/__init__.py` 从 331 行降到 ~150 行(删除延迟加载机制) -- `czsc/core.py` 从 134 行降到 < 5 行 -- `czsc/envs.py` 从 50 行降到 ~20 行(移除 `CZSC_USE_PYTHON`) +- `czsc/utils/` 从 \~10.7K 行降到 \~3K 行(保留:cache, io, log, calendar, sig, kline_quality, plot_backtest, plotly_plot, data_client, trade\*, oss;不再保留 ta,由 Rust 端 czsc.ta.\* 替代;删除其它) +- `czsc/__init__.py` 从 331 行降到 \~150 行(删除延迟加载机制) +- `czsc/envs.py` 从 50 行降到 \~20 行(移除 `CZSC_USE_PYTHON`) ### 极薄化 -- `czsc/signals/` 从 12 文件 / 15K 行降到 ~12 文件 / ~200 行(仅 re-export) -- `czsc/traders/` 从 9 文件 / 3.5K 行降到 ~3 文件 / ~150 行(保留 dummy.py / sig_parse.py + re-export) +- `czsc/signals/` 从 12 文件 / 15K 行降到 \~12 文件 / \~200 行(仅 re-export) +- `czsc/traders/` 从 9 文件 / 3.5K 行降到 \~2 文件 / \~80 行(仅保留 sig_parse.py 待评估 + \_\_init\_\_.py re-export;DummyBacktest 已删除,WeightBacktest 改用 wbt 包) ### 完全删除 - `czsc/py/`(6 文件,2148 行 — Rust 已实现) -- `czsc/eda.py`(1213 行) -- `czsc/aphorism.py`(853 行) - `czsc/features/`(9 文件,777 行) -- `czsc/fsa/`(8 文件,2078 行) -- `czsc/utils/` 中:bar_generator, bi_info, analysis/, echarts_*, pdf_report, html_report_builder, word_writer, features, oss, st_components, corr, signal_analyzer +- `czsc/utils/` 中:bar_generator, bi_info, analysis/, echarts\_\*, pdf_report, html_report_builder, word_writer, features, st_components, corr, signal_analyzer, ta(不再保留 oss) ### 量化总结 | 指标 | 迁移前 | 迁移后 | -|---|---|---| -| Python 文件数 | 89 | ~35 | -| Python 代码行数 | ~44K | ~8K | +|-|-|-| +| Python 文件数 | 89 | \~50 | +| Python 代码行数 | \~44K | \~12K | | Rust crate 数 | 0 | 9 | | 构建工具 | hatchling | maturin | | 外部 PyPI 依赖 `rs-czsc` | 是 | 否(自带 Rust) | | `CZSC_USE_PYTHON` 双路由 | 是 | 否 | -| 公共 API 数量 | ~80 | ~80(保持兼容) | +| 公共 API 数量 | \~80 | \~80(保持兼容) | + +## 10. 附录 C — 评审反馈处理记录(v0.1 → v0.2) + + +本节记录针对 v0.1 草案 19 条评审意见的处理结果,仅作为评审跟进证据,不影响主体设计。 + + +| # | 评审意见摘要 | 本次调整 | 位置 | +|-|-|-|-| +| 1 | `core.py` 多余 / 3.5 节不需要 | 删除整节 3.5;仓库结构图与附录 B 移除 core.py | §1 / §3.5 / §9 | +| 2 | CzscStrategyBase / CzscJsonStrategy 应有 Rust 实现 | strategies.py 标记"临时保留,应迁 Rust",公共 API 表加备注;后续在 czsc-trader 内补实现 | §1 / §3.1 | +| 3 | czsc-ta 中无人调用的指标可删 | §2.3 增加"未被调用的算子随迁移裁剪"说明,czsc-only 改动汇总到 MIGRATION_NOTES.md | §2.3 / §2.6 | +| 4 | 不再保留 czsc.utils.ta 兼容层 | §2.3 命名空间项调整;§3.2 表格 ta.py 改"删除";§3.1 公共 API 表 utils 行去掉 ta;附录 B 删除清单加 ta | §2.3 / §3.1 / §3.2 / §9 | +| 5 | daily_performance / top_drawdowns 用 wbt 引入 | §3.1 顶层 TA / 性能函数行明确"czsc 内部 `from wbt import ...` 后保持 `czsc.daily_performance` / `czsc.top_drawdowns` 的公共 API 暴露";§8 附录 A 的来源列改为 `wbt.*` | §3.1 | +| 6 / 7 | WeightBacktest 优先用 wbt | §3.1 / §3.3 / 验收 F5 / 附录 B 全部改为引用 wbt;过渡期可在 czsc 内做 re-export | §3.1 / §3.3 / §6 / §9 | +| 8 | oss.py 留在 utils | §3.2 表 oss.py 改"保留";附录 B 完全删除清单移除 oss | §3.2 / §9 | +| 9 | DummyBacktest 删除 | §3.1 / §3.3 / 附录 B 全部移除 | §3.1 / §3.3 / §9 | +| 10 | SignalsParser 看 Rust 是否已等价实现 | 标记"待评估 Rust 是否已等价实现",迁移完成前作 Python 薄层 | §3.1 / §3.3 | +| 12 | K 线 mock 用 wbt | §3.1 czsc.mock 行 / §4.3 测试条目 改为转发 wbt mock | §3.1 / §4.3 | +| 13 | 验收标准需更详尽 | §6 重写为五大类(功能/性能/质量/兼容/发布)+ 追溯,扩充至 \~24 条具体验收 | §6 | +| 14 | rs-czsc 后续不再维护 | §0.2 决策 7、§2.6、§7 风险表对应行改为"一次性 fork 后独立演进,无 cherry-pick" | §0.2 / §2.6 / §7 | +| 15 | PyPI 用最新推荐方式(OIDC,无 secrets) | Phase 8 改为 Trusted Publishing;§7 风险表对应行同步;§6 R2 验收 | §5 / §6 / §7 | +| 16 / 17 / 18 | aphorism / fsa / eda 保留 | §0.2 决策 4、§1 仓库结构图、§9 附录 B 调整为保留(eda 标注"待重构") | §0.2 / §1 / §9 | +| 19 | Rust Python 对象支持 pickle | §2.4 增加强制 pickle 序列化要求;§6 验收 Q3 增加 roundtrip 测试条款 | §2.4 / §6 | + +### 不在本次范围内的反馈 + +- 评论 #11 与 #1 重复(均针对 core.py),按 #1 一并处理。 + +### 仓库已落地的相关变更(参考) + +本设计草案 v0.1(commit `534eed8`)提交后,主线已合入若干前置改动,可作为本设计的实施基底: + +- `3f4cf2b` — 移除 Python 端 WeightBacktest fallback,独占使用 Rust 版(与本设计 §3.1 / §3.3 一致) +- `1325433` — 为所有 czsc 子模块新增 .pyi stub(符合 §2.4 type stub 自动生成方向) +- `79bdf5e` — 用 ruff + basedpyright 替代 black/flake8/isort/mypy(与本设计 §6 Q4 验收一致) +- `7dcadaa` / `a63965b` — 删除已不再使用的函数与文件,对接 §3.2 / §3.3 的精简方向 + +### v0.3 修订记录(一致性 + TDD) + +**触发原因**:v0.2 提交后审阅发现两类问题——① 全局一致性:v0.2 的若干跨章节修改在细节上互相冲突;② 迁移流程是"先做实现后补测试"的传统模式,与项目使用的 superpowers 工作流不符。 + +#### 1. 一致性修订清单(13 项) + +| # | 原冲突 | 修订 | +|-|-|-| +| C1 | §0.2 决策 3 仍含 `core` | 去除;明确"`core.py` 不再保留,名称由 `__init__.py` 直接 re-export" | +| C2 | §3.1 顶层 TA / 性能函数写"**不再暴露** daily_performance / top_drawdowns" | 更正为"czsc 内部 `from wbt import ...` 后照常暴露 `czsc.daily_performance` / `czsc.top_drawdowns`",与 §8 附录 A 对齐 | +| C3 | §8 附录 A 把 `WeightBacktest` 标为 `czsc_trader::*` 来源 | 更正为 `wbt::WeightBacktest`(外部包);`daily_performance` / `top_drawdowns` 同步改 `wbt.*`,并加 callout 说明"`wbt::*` = re-export 自外部包" | +| C4 | §3.1 czsc.mock 来源 cell 写"完整保留"但暴露名写"转发自 wbt" | 来源改为"薄层(转发 wbt)" | +| C5 | §3.3 traders \_\_init\_\_ 代码块用 `type: ignore[import-not-found]` 暗示 wbt 可选 | 去掉 type:ignore;明确 wbt 是 pyproject.toml 中的硬依赖 | +| C6 | §4.3 测试约束有两条互相重复的 mock 描述 | 合并为一条:"统一通过 czsc.mock 入口(czsc.mock 是 wbt mock 的转发壳)" | +| C7 | 原 Phase 4 写"极简 czsc/core.py",与 §3.5 已删冲突 | v0.3 整章重写为 TDD(见下),core.py 的处置改为"删除" | +| C8 | 原 Phase 5 删除清单含 eda / aphorism / fsa | 整章重写后纠正 | +| C9 | 原 Phase 5 traders 删除排除 dummy.py | 整章重写后纠正(dummy.py 也删) | +| C10 | §7 风险表"pub(crate)→pub 上游同步困难"含"cherry-pick 时手动适配",与"rs-czsc 不再维护"冲突 | 风险等级降为低;缓解策略改为"czsc-only 公开化已记录于 MIGRATION_NOTES.md,rs-czsc 不再维护后无 cherry-pick 需求" | +| C11 | §9 完整保留 `czsc/mock.py(537 行)` 与"薄壳转发"决策矛盾 | 改为"v0.1: 537 行 → 迁移后 \~30 行(仅转发 wbt)" | +| C12 | §9 完整保留 `czsc/strategies.py(410 行)` 与"应迁 Rust 后删除"矛盾 | 改为"v0.1: 410 行 — 临时保留待迁 Rust,迁移完成后删除" | +| C13 | §10 评审 #5 调整描述写"不再 re-export"误读评论原意 | 更正为"czsc 内部 `from wbt import ...` 后保持公共 API 暴露" | + +#### 2. 迁移流程改造(§5 整章重写) + +v0.2 的 §5 是传统的 8 阶段"先实现后补测试"路径。v0.3 改写为 **superpowers TDD 模式**的 12 个 Phase(5.0–5.12)+ 进度可视化(5.13): + +- 引入 `spec → plan → execute` 三段式工作流;spec 存 `docs/superpowers/specs/`,plan 存 `docs/superpowers/plans/`。 +- 遵守 `superpowers:test-driven-development` 的 Iron Law:**没有失败测试就不写实现代码**。每个 task 都是完整的 RED→GREEN→REFACTOR→COMMIT 循环。 +- 引入 Phase A "测试基线"——**把 §6 验收标准翻译成可执行的失败测试**(compat / pickle / parity / smoke 共 8 类),第 1 天就跑出全 RED;之后每个 task 都能量化"让多少 RED 转 GREEN"。 +- 每个 crate(utils / core / ta / signals / trader)按"复制即测试"模式逐个迁移,不允许整体复制 src/。 +- 采用 `using-git-worktrees` 保证 master 分支隔离;最后用 `finishing-a-development-branch` 完成合并。 +- 总工期从 v0.2 的 10.5 天调整为 14 天(多出来的 \~30% 用于 Phase A 验收测试基线,但换来全程可量化进度)。 diff --git a/examples/develop/czsc_benchmark.py b/examples/develop/czsc_benchmark.py index a1ddd40b3..922ad0d97 100644 --- a/examples/develop/czsc_benchmark.py +++ b/examples/develop/czsc_benchmark.py @@ -1,24 +1,60 @@ +""" +CZSC 分析性能基准测试脚本 + +用途: + 在不同 K 线规模下测量 ``format_standard_kline`` 与 ``CZSC`` 初始化的耗时, + 用于跟踪迁移到 Rust 后端后的性能基线,以及在重大改动前后对比性能波动。 + +执行方式: + 直接运行本文件即可: + python examples/develop/czsc_benchmark.py + 输出经由 loguru 打印到控制台,包含每档样本量的两段耗时(毫秒级)。 + +注意事项: + - 默认走 ``czsc.mock.generate_symbol_kines`` 生成确定性 30 分钟随机 K 线, + 不依赖外部数据源 + - 把 ``czsc.utils.cache`` 模块的日志禁用,避免缓存命中信息干扰耗时统计 + - 当 count 接近或超过 100k 根时,单次测试需要数百 MB 内存,按机器情况裁剪样本档位 +""" import time import pandas as pd from loguru import logger -# from czsc.core import CZSC, format_standard_kline -from rs_czsc import CZSC, format_standard_kline +from czsc import CZSC, Freq, format_standard_kline +# 关闭磁盘缓存模块的日志输出,避免命中/失效日志混淆基准结果 logger.disable("czsc.utils.cache") def create_benchmark(count=1000): - """创建测试用的K线数据,测试 CZSC 分析性能""" - from czsc import mock + """ + 创建指定根数的 30 分钟 K 线,并测量两段关键耗时 + + 参数: + count: 取样末尾 K 线根数;越大越能贴近大数据量的真实场景 - df = mock.generate_klines() + 返回: + 构造好的 CZSC 实例(保留多达 100 笔,便于后续二次分析) + + 输出: + 通过 loguru 打印两条耗时日志: + 1. ``format_standard_kline`` 的转换耗时 + 2. ``CZSC(...)`` 的初始化耗时(含分型/笔识别全流程) + """ + # 仅在函数体内 import,避免 mock 模块在脚本加载阶段就被拉起 + from czsc.mock import generate_symbol_kines + + # 固定标的 / 周期 / 时间窗 / 默认种子,保证每次基准结果可复现 + df = generate_symbol_kines("000001", "30分钟", "20100101", "20250101") df = df.reset_index(drop=True) logger.info(f"开始创建 {count} 根K线的测试数据; 原始数据总共有 {len(df)} 根K线") + + # —— 第一段耗时:DataFrame -> List[RawBar] 的格式转换 —— start_time = time.time_ns() - bars = format_standard_kline(df.tail(count)) + bars = format_standard_kline(df.tail(count), freq=Freq.F30) logger.info(f"format_standard_kline {count};耗时 {(time.time_ns() - start_time) / 1_000_000:.2f} 毫秒") + # —— 第二段耗时:CZSC 初始化(包含分型与笔的识别)—— start_time = time.time_ns() c = CZSC(bars, max_bi_num=100) logger.warning(f"{count} bars -- CZSC初始化耗时 {(time.time_ns() - start_time) / 1_000_000:.2f} 毫秒") @@ -26,6 +62,7 @@ def create_benchmark(count=1000): if __name__ == '__main__': - # 100000+ 根K线的测试需要较大内存,按需调整 + # 多档样本量逐次跑,便于观察耗时随规模的变化趋势 + # 10 万根以上的测试需要较大内存,按机器实际情况增删档位 for count in [1000, 2000, 3000, 5000, 10000, 20000, 50000]: - create_benchmark(count) \ No newline at end of file + create_benchmark(count) diff --git a/examples/develop/test_trading_view_kline.py b/examples/develop/test_trading_view_kline.py index b723ca636..42917e37e 100644 --- a/examples/develop/test_trading_view_kline.py +++ b/examples/develop/test_trading_view_kline.py @@ -1,40 +1,67 @@ # -*- coding: utf-8 -*- """ -测试 trading_view_kline 函数 +trading_view_kline 函数的可视化示例脚本 -本示例展示如何使用 czsc 的 trading_view_kline 函数进行 K 线可视化。 -使用 mock 数据生成K线,通过 CZSC 分析后绘制带有分型和笔的K线图。 +用途: + 展示如何把 czsc 的分析结果(分型 / 笔)渲染到 lightweight-charts 风格的 + 交互式 K 线图上。脚本以 mock 模拟数据为输入,跑通"生成 -> 分析 -> 绘图" + 完整链路,可作为快速验证 trading_view_kline 集成是否正常的样例。 -author: czsc -create_dt: 2025-01-27 +执行方式: + python examples/develop/test_trading_view_kline.py + +依赖: + - czsc 主体(用于 CZSC 分析与 mock 数据) + - czsc.utils.echarts_plot.trading_view_kline(实际的绘图函数) + - lightweight_charts(可选;未安装时仅创建 chart 对象、不展示) + +输出: + 通过 loguru 打印关键步骤日志与最终测试结果("通过" / "失败")。 + +作者: czsc +创建时间: 2025-01-27 """ from loguru import logger from czsc.utils.echarts_plot import trading_view_kline -from czsc.core import CZSC, Freq, RawBar, format_standard_kline +from czsc import CZSC, Freq, RawBar, format_standard_kline from czsc.mock import generate_symbol_kines def test_trading_view_kline(): - """测试 trading_view_kline 函数""" + """ + 端到端验证 trading_view_kline 函数 + + 流程: + 1. 用 mock 生成 4 年的日线 K 线 + 2. 调用 CZSC 完成分型与笔识别 + 3. 把 RawBar / 分型 / 笔分别整理为 trading_view_kline 期望的字典列表 + 4. 调用 trading_view_kline 创建图表对象 + 5. 检查返回值是否具备 show 接口(依赖 lightweight_charts 是否安装) + + 返回: + bool: True 表示完整链路成功;False 表示中途抛异常 + """ logger.info("开始测试 trading_view_kline 函数") try: - # 使用 mock 数据生成K线 + # —— 步骤 1:生成 mock 日线 K 线,固定 seed 保证结果可复现 —— logger.info("生成模拟K线数据...") df = generate_symbol_kines('test', '日线', '20200101', '20240101', seed=42) raw_bars = format_standard_kline(df, freq=Freq.D) + # —— 步骤 2:CZSC 缠论分析(max_bi_num 设大值以便绘图保留完整笔历史)—— logger.info("使用CZSC分析K线数据...") czsc = CZSC(raw_bars, max_bi_num=10000) logger.info(f"分析完成:共{len(czsc.bi_list)}笔,{len(czsc.fx_list)}个分型") - # 转换数据格式用于绘图 + # —— 步骤 3:把对象列表转成绘图函数需要的 dict 列表 —— + # K 线:直接 __dict__ 抽取所有字段,避免逐字段拷贝 kline_data = [bar.__dict__ for bar in raw_bars] - # 获取分型数据 + # 分型:仅保留时间戳与分型类型(顶分型 / 底分型) fx_data = [{"dt": fx.dt, "fx": fx.fx} for fx in czsc.fx_list] if czsc.fx_list else [] - # 获取笔数据 + # 笔:每一笔取起点分型,再额外补上最后一笔的终点,保证连续画线无断点 if czsc.bi_list: bi_data = [{"dt": bi.fx_a.dt, "bi": bi.fx_a.fx} for bi in czsc.bi_list] bi_data.append({"dt": czsc.bi_list[-1].fx_b.dt, "bi": czsc.bi_list[-1].fx_b.fx}) @@ -43,13 +70,14 @@ def test_trading_view_kline(): logger.info("数据转换完成,开始调用 trading_view_kline 函数...") - # 调用函数 + # —— 步骤 4:调用绘图函数;t_seq 指定均线周期,bs 留空表示无买卖点标注 —— chart = trading_view_kline( kline=kline_data, fx=fx_data, bi=bi_data, bs=[], title="缠中说禅K线分析测试", t_seq=[5, 10, 20] ) logger.info("trading_view_kline 函数调用成功!") + # —— 步骤 5:检查返回对象;lightweight_charts 缺失时仅打印警告,不视为失败 —— if chart and hasattr(chart, "show"): logger.info("图表已创建成功,可调用 chart.show() 显示") else: @@ -58,6 +86,7 @@ def test_trading_view_kline(): return True except Exception as e: + # 捕获所有异常并完整打印 traceback,便于在 CI 或开发机上定位问题 logger.error(f"测试失败: {e}") import traceback logger.error(traceback.format_exc()) @@ -65,7 +94,7 @@ def test_trading_view_kline(): def main(): - """主函数""" + """脚本入口:打印分隔线包裹的日志,便于在大量输出中肉眼定位本次测试结果""" logger.info("=" * 50) logger.info("开始 trading_view_kline 函数测试") logger.info("=" * 50) diff --git a/pyproject.toml b/pyproject.toml index 8f254861a..3842be606 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,24 +1,22 @@ [build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.sdist] -include = [ - "/czsc", - "/tests", - "/docs", - "/README.md" -] -exclude = [ - "/.git", - "/.github", - "/docs/_build", -] +requires = ["maturin>=1.7,<2.0"] +build-backend = "maturin" + +[tool.maturin] +# Phase H wires czsc-python (Rust workspace) as the producer of the +# czsc._native extension. The python-source = "." tells maturin to copy +# the existing pure-Python czsc/ tree alongside the compiled extension +# into the wheel. +module-name = "czsc._native" +manifest-path = "crates/czsc-python/Cargo.toml" +python-source = "." +include = ["czsc/**/*.pyi", "czsc/py.typed"] +features = ["pyo3/extension-module"] [project] name = "czsc" -dynamic = ["version"] +version = "1.0.0" description = "缠中说禅技术分析工具" readme = "README.md" license = "Apache-2.0" @@ -52,9 +50,11 @@ dependencies = [ "scipy", "statsmodels", "scikit-learn", - # 技术分析 + # 技术分析 (rs-czsc 已被 czsc._native 替代,czsc-python 通过 maturin + # 直接产出,wbt 为硬依赖提供 WeightBacktest / daily_performance / + # top_drawdowns / mock 函数) "TA-Lib>=0.6", - "rs-czsc>=0.1.26", + "wbt", # 图表和可视化 "matplotlib", "seaborn", @@ -87,6 +87,9 @@ default = true test = [ "pytest>=7.0.0", "pytest-cov>=4.0.0", + # rs_czsc baseline used by test/parity/ to prove czsc._native + # outputs match the reference Rust implementation byte-for-byte. + "rs-czsc", ] # 开发工具 @@ -119,18 +122,9 @@ dev = [ "tushare>=1.4.24", "ruff>=0.9.0", "basedpyright>=1.28.0", + "maturin>=1.13.1", ] -[tool.hatch.build.targets.wheel] -packages = ["czsc"] - -[tool.hatch.build.targets.wheel.force-include] -"czsc/__init__.pyi" = "czsc/__init__.pyi" -"czsc/py.typed" = "czsc/py.typed" - -[tool.hatch.version] -path = "czsc/__init__.py" - [tool.pytest.ini_options] testpaths = ["test"] python_files = ["test_*.py"] diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 000000000..292fe499e --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "stable" diff --git a/scripts/cargo_test_all.sh b/scripts/cargo_test_all.sh new file mode 100755 index 000000000..c3b84daf0 --- /dev/null +++ b/scripts/cargo_test_all.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +# Run all Rust workspace tests on the business crates. +# +# czsc-python is excluded because it is the cdylib aggregator that opts +# into pyo3/extension-module + abi3-py310. Cargo unifies features across +# the build graph, so including it in `cargo test --workspace` re-enables +# `extension-module` for the business crates' lib tests, where pyo3 then +# refuses to link against libpython. The published wheel is exercised +# end-to-end by the Python pytest suite (test/smoke + test/unit), so +# excluding the cdylib-only crate from `cargo test` doesn't reduce +# coverage. +# +# Design doc §6 Q1 verification: this script is what CI runs. + +set -euo pipefail + +# Resolve a Python interpreter for pyo3-build to bind libpython. +if [[ -z "${PYO3_PYTHON:-}" ]]; then + if [[ -x ".venv/bin/python" ]]; then + export PYO3_PYTHON="$(pwd)/.venv/bin/python" + elif command -v python3 >/dev/null 2>&1; then + export PYO3_PYTHON="$(command -v python3)" + else + echo "PYO3_PYTHON unset and no python interpreter found" >&2 + exit 1 + fi +fi + +echo "Using PYO3_PYTHON=$PYO3_PYTHON" +exec cargo test --workspace --exclude czsc-python --no-fail-fast "$@" diff --git a/test/compat/__init__.py b/test/compat/__init__.py new file mode 100644 index 000000000..1bc6f2867 --- /dev/null +++ b/test/compat/__init__.py @@ -0,0 +1,13 @@ +"""compat 测试包 —— 公共 API 兼容性测试套件。 + +本包内的测试用例锁定迁移后 ``czsc`` 包对外暴露的公共 API 表面, +包括: + * 顶层模块需要导出的名称(``CZSC``、``RawBar``、``Signal`` 等) + * 信号子包是否齐全(如 ``czsc.signals.bar`` 等) + * ``czsc.traders`` 命名空间下应保留的公共名称 + * ``czsc.ta`` 技术指标命名空间 + * 已废弃的旧 API 必须被移除(``WeightBacktest`` 应来自 ``wbt`` 包等) + +通过对比 ``snapshots/api_v1.json`` 中的快照,确保任何破坏性的 import +路径或签名修改都能被立即捕获,保护下游用户的兼容性。 +""" diff --git a/test/compat/snapshots/api_v1.json b/test/compat/snapshots/api_v1.json new file mode 100644 index 000000000..b37512887 --- /dev/null +++ b/test/compat/snapshots/api_v1.json @@ -0,0 +1,80 @@ +{ + "_doc": "Phase A baseline snapshot for public API. Source: docs/superpowers/specs/2026-05-03-rust-czsc-migration-design.md §3.1 + §8 Appendix A.", + "top_level": [ + "CZSC", + "FX", + "BI", + "ZS", + "RawBar", + "NewBar", + "Freq", + "Mark", + "Direction", + "Operate", + "Signal", + "Event", + "Position", + "BarGenerator", + "format_standard_kline", + "freq_end_time", + "is_trading_time", + "check_bi", + "check_fx", + "check_fxs", + "remove_include", + "CzscTrader", + "CzscSignals", + "generate_czsc_signals", + "get_unique_signals", + "WeightBacktest", + "daily_performance", + "top_drawdowns", + "ultimate_smoother", + "rolling_rank", + "ema", + "sma", + "boll_positions", + "mock", + "envs", + "signals", + "traders", + "ta", + "utils", + "connectors", + "sensors", + "svc" + ], + "traders": [ + "CzscTrader", + "CzscSignals", + "generate_czsc_signals", + "get_unique_signals", + "WeightBacktest", + "SignalsParser" + ], + "ta": [ + "ultimate_smoother", + "rolling_rank", + "ema", + "sma", + "boll_positions" + ], + "signal_subpackages": [ + "bar", + "cxt", + "tas", + "vol", + "pressure", + "obv", + "cvolp" + ], + "removed": [ + "DummyBacktest", + "OpensOptimize", + "ExitsOptimize", + "CTAResearch" + ], + "removed_envs": [ + "CZSC_USE_PYTHON" + ] +} diff --git a/test/compat/test_public_api.py b/test/compat/test_public_api.py new file mode 100644 index 000000000..dcb4ac623 --- /dev/null +++ b/test/compat/test_public_api.py @@ -0,0 +1,159 @@ +"""公共 API 兼容性快照测试。 + +本测试锁定迁移后 ``czsc`` 包对外暴露的公共 API 表面,包括: + + * ``czsc.*`` 顶层应导出的核心名称 + * ``czsc.signals.*`` 应包含的信号子包 + * ``czsc.traders.*`` 应包含的公共名称 + * ``czsc.ta.*`` 技术指标命名空间应包含的名称 + * ``czsc.WeightBacktest`` 必须来自 ``wbt`` 包(架构层面的约束) + * 已废弃的旧 API(如 ``czsc.dummy_backtest``)必须被移除 + +期望的 API 集合保存在 ``snapshots/api_v1.json`` 中,新增 / 删除任何 +公共名称都需要先更新这份快照,从而对所有破坏性变更形成显式审计。 + +所有断言都被写成"找不到模块/属性 -> AssertionError"的形式,而不是让 +``ImportError`` 直接抛出 —— 这样在 pytest 报告里会被记为 FAIL 而不是 +ERROR,可以更清晰地呈现公共 API 的缺失情况。 +""" + +from __future__ import annotations + +import importlib +import json +from pathlib import Path +from typing import Any + +import pytest + +# 公共 API 快照文件(手工维护,每次 API 变更需同步更新) +SNAPSHOT_PATH = Path(__file__).parent / "snapshots" / "api_v1.json" + + +def _load_snapshot() -> dict[str, Any]: + """读取并解析 API 快照 JSON。""" + return json.loads(SNAPSHOT_PATH.read_text(encoding="utf-8")) + + +def _safe_import(name: str) -> tuple[Any | None, str | None]: + """安全地 import 一个模块。 + + 返回 (module, None) 或 (None, error_message),把 ImportError 等异常 + 转成可读字符串,避免 pytest 把它们记成 ERROR 而不是 FAIL。 + """ + try: + return importlib.import_module(name), None + except Exception as exc: # noqa: BLE001 + return None, f"{type(exc).__name__}: {exc}" + + +def test_top_level_names_importable() -> None: + """``czsc.*`` 顶层必须暴露快照中列出的所有公共名称。 + + 关键断言:``snap["top_level"]`` 中每一个名字都能通过 ``hasattr(czsc, name)`` + 访问到,缺失任何一个都视为破坏性变更。 + """ + snap = _load_snapshot() + czsc, err = _safe_import("czsc") + assert czsc is not None, f"failed to import czsc: {err}" + missing = [name for name in snap["top_level"] if not hasattr(czsc, name)] + assert not missing, ( + f"czsc.* missing {len(missing)} required public names: {missing}" + ) + + +def test_signal_subpackages_present() -> None: + """``czsc.signals.*`` 下的所有信号子包必须可以正常 import。 + + 关键断言:``snap["signal_subpackages"]`` 中每一个子包都能 import + 成功,否则记录失败原因(含异常类型与消息)。 + """ + snap = _load_snapshot() + failures: list[str] = [] + for sub in snap["signal_subpackages"]: + mod, err = _safe_import(f"czsc.signals.{sub}") + if mod is None: + failures.append(f"czsc.signals.{sub} ({err})") + assert not failures, ( + f"czsc.signals.* missing {len(failures)} required subpackages: {failures}" + ) + + +def test_traders_namespace_complete() -> None: + """``czsc.traders.*`` 必须暴露快照中列出的所有公共名称。""" + snap = _load_snapshot() + traders, err = _safe_import("czsc.traders") + assert traders is not None, f"failed to import czsc.traders: {err}" + missing = [name for name in snap["traders"] if not hasattr(traders, name)] + assert not missing, ( + f"czsc.traders.* missing {len(missing)} required public names: {missing}" + ) + + +def test_ta_namespace_complete() -> None: + """``czsc.ta.*`` 必须暴露快照中列出的所有技术指标名称。""" + snap = _load_snapshot() + ta, err = _safe_import("czsc.ta") + assert ta is not None, f"failed to import czsc.ta: {err}" + missing = [name for name in snap["ta"] if not hasattr(ta, name)] + assert not missing, ( + f"czsc.ta.* missing {len(missing)} required public names: {missing}" + ) + + +def test_no_legacy_dummy_backtest() -> None: + """已废弃的旧公共名称必须从 ``czsc.*`` 中移除。 + + 关键断言:``snap["removed"]`` 中的每一个名字都不应再可访问;任何 + 残留都会被视为兼容性回归。 + """ + snap = _load_snapshot() + czsc, err = _safe_import("czsc") + assert czsc is not None, f"failed to import czsc: {err}" + leftover = [name for name in snap["removed"] if hasattr(czsc, name)] + assert not leftover, ( + f"czsc.* still exposes legacy names that must be removed: {leftover}" + ) + + +def test_no_czsc_use_python_branch() -> None: + """已废弃的环境变量必须从 ``czsc.envs`` 中移除。 + + 关键断言:``snap["removed_envs"]`` 中的每一个名字都不应再可访问。 + """ + snap = _load_snapshot() + envs, err = _safe_import("czsc.envs") + assert envs is not None, f"failed to import czsc.envs: {err}" + leftover = [name for name in snap["removed_envs"] if hasattr(envs, name)] + assert not leftover, ( + f"czsc.envs still exposes removed env vars: {leftover}" + ) + + +def test_weight_backtest_comes_from_wbt() -> None: + """``czsc.WeightBacktest`` 必须就是 ``wbt.WeightBacktest`` 同一个对象。 + + 架构层面的约束:迁移后 czsc 不再自带 WeightBacktest 实现,而是从 + 外部 ``wbt`` 包再导出。同一性(``is``)检查比类型相等更严格,能 + 捕获意外的 import 路径漂移。 + """ + czsc, err = _safe_import("czsc") + assert czsc is not None, f"failed to import czsc: {err}" + wbt, wbt_err = _safe_import("wbt") + assert wbt is not None, f"failed to import wbt (hard dep): {wbt_err}" + assert getattr(czsc, "WeightBacktest", None) is wbt.WeightBacktest, ( + "czsc.WeightBacktest must be the same object as wbt.WeightBacktest" + ) + + +@pytest.mark.parametrize("module_name", ["czsc.connectors", "czsc.sensors", "czsc.svc"]) +def test_retained_subpackages_importable(module_name: str) -> None: + """保留的子包必须在迁移过程中始终可 import。 + + 参数化覆盖三个保留子包: + * ``czsc.connectors`` —— 数据源连接器 + * ``czsc.sensors`` —— 事件检测与特征分析 + * ``czsc.svc`` —— 统计与可视化服务 + """ + mod, err = _safe_import(module_name) + assert mod is not None, f"failed to import retained subpackage {module_name}: {err}" diff --git a/test/integration/__init__.py b/test/integration/__init__.py new file mode 100644 index 000000000..a5b017d8e --- /dev/null +++ b/test/integration/__init__.py @@ -0,0 +1,8 @@ +"""集成测试包。 + +本目录用于存放集成测试用例(integration tests),用于验证多个模块或组件 +组合在一起时的正确协作行为。集成测试会跨越模块边界,依赖真实的依赖项 +(如外部 Rust 扩展、第三方包等),不进行 Mock 隔离。 + +测试文件命名约定:``test_*.py``。 +""" diff --git a/test/integration/test_weight_backtest.py b/test/integration/test_weight_backtest.py new file mode 100644 index 000000000..44a4d8d26 --- /dev/null +++ b/test/integration/test_weight_backtest.py @@ -0,0 +1,100 @@ +"""权重回测(WeightBacktest)跨包集成测试。 + +本测试套件验证 ``czsc`` 顶层命名空间中暴露的权重回测相关 API 是直接来自 +外部 ``wbt`` 包的对象重导出(re-export),即对象身份必须完全一致。 + +业务背景: + ``WeightBacktest``、``daily_performance``、``top_drawdowns`` 这三个 API + 在 czsc 中不再维护并行实现,而是统一由独立的 ``wbt`` 包提供,并将 + ``wbt`` 列为 czsc 的硬依赖(hard dependency)。这样可以避免在两个包 + 之间出现实现漂移(implementation drift),同时让 wbt 可以被其它项目 + 单独使用。 + +测试覆盖: + - ``wbt`` 包必须可被成功导入 + - ``czsc.WeightBacktest`` 必须与 ``wbt.WeightBacktest`` 是同一对象 + - ``czsc.daily_performance``、``czsc.top_drawdowns`` 同上 + - ``czsc.WeightBacktest`` 不得来自 ``rs_czsc`` 模块(已废弃的来源) +""" + +from __future__ import annotations + +from typing import Any + +import pytest + + +def _safe_import(name: str) -> tuple[Any | None, str | None]: + """安全地按名称导入模块,捕获所有异常并以元组形式返回。 + + 导入成功时返回 ``(module, None)``;导入失败时返回 ``(None, error_msg)``, + 其中 ``error_msg`` 包含异常类型名和异常消息,便于失败时输出可读的诊断信息。 + """ + try: + return __import__(name, fromlist=["__init__"]), None + except Exception as exc: # noqa: BLE001 + return None, f"{type(exc).__name__}: {exc}" + + +# 参数化覆盖三个必须从 wbt 重导出的 API 名称 +@pytest.mark.parametrize( + "attr_name", + ["WeightBacktest", "daily_performance", "top_drawdowns"], +) +def test_czsc_attr_is_wbt_attr(attr_name: str) -> None: + """验证 czsc 顶层属性与 wbt 中对应属性是同一个 Python 对象。 + + 测试场景: + 参数化执行三次,分别校验 ``WeightBacktest`` 类、``daily_performance`` + 和 ``top_drawdowns`` 函数。 + + 关键断言: + 使用 Python 的 ``is`` 运算符断言对象身份完全相同(不是等价,而是同一对象), + 以此防止 czsc 私下维护一个并行实现或重新包装的副本。 + """ + czsc, czsc_err = _safe_import("czsc") + assert czsc is not None, f"failed to import czsc: {czsc_err}" + wbt, wbt_err = _safe_import("wbt") + if wbt is None: + pytest.fail( + f"wbt 必须作为硬依赖存在 ({wbt_err});" + f"czsc.{attr_name} 必须从 wbt 重导出" + ) + + czsc_attr = getattr(czsc, attr_name, None) + wbt_attr = getattr(wbt, attr_name, None) + + if czsc_attr is None: + pytest.fail(f"czsc.{attr_name} 缺失;期望从 wbt 重导出") + if wbt_attr is None: + pytest.fail(f"wbt.{attr_name} 缺失;请检查 wbt 版本是否正确") + + assert czsc_attr is wbt_attr, ( + f"czsc.{attr_name} 必须与 wbt.{attr_name} 是同一对象 " + f"(实际 czsc.{attr_name}={czsc_attr!r} 来自 " + f"{getattr(czsc_attr, '__module__', '?')}, wbt.{attr_name}={wbt_attr!r} " + f"来自 {getattr(wbt_attr, '__module__', '?')})" + ) + + +def test_no_residual_rs_czsc_dependency() -> None: + """验证 czsc.WeightBacktest 不再通过 rs_czsc 路由。 + + 测试目标: + 确保 czsc 已经完全切换为通过 wbt 提供 WeightBacktest, + 而不是继续依赖已经废弃的 ``rs_czsc`` PyPI 包。 + + 关键断言: + ``czsc.WeightBacktest.__module__`` 中不得包含 ``rs_czsc`` 字符串。 + """ + czsc, err = _safe_import("czsc") + assert czsc is not None, err + # 该属性必须由 wbt 提供,而不是来自 rs_czsc + wb = getattr(czsc, "WeightBacktest", None) + if wb is None: + pytest.fail("czsc.WeightBacktest 缺失") + module = getattr(wb, "__module__", "?") + assert "rs_czsc" not in module, ( + f"czsc.WeightBacktest 仍然通过 {module!r} 路由;" + f"必须替换为 wbt.WeightBacktest" + ) diff --git a/test/parity/__init__.py b/test/parity/__init__.py new file mode 100644 index 000000000..d67f8adf9 --- /dev/null +++ b/test/parity/__init__.py @@ -0,0 +1,15 @@ +"""parity 测试包 —— Python ↔ Rust 等价性测试套件。 + +本包内的所有测试用例均围绕一个核心目标:保证迁移后的 ``czsc`` 模块 +(基于 Rust 重构的 ``_native`` / ``_rs_czsc`` 后端)在功能与数值上与 +PyPI 上的基线版本 ``rs_czsc`` 完全一致。 + +主要覆盖范围: + * 缠论核心分析器(CZSC 分型/笔/中枢) + * 信号注册表与 ``derive_signals_*`` 系列工具 + * ``run_research`` 全链路回测 + * ``OpensOptimize`` / ``ExitsOptimize`` 优化流程 + * 示例脚本(如 ``30分钟笔非多即空``、``use_optimize`` 等) + +测试数据全部使用固定随机种子的 mock K线,保证结果可重现。 +""" diff --git a/test/parity/_signal_defaults.py b/test/parity/_signal_defaults.py new file mode 100644 index 000000000..55ac67805 --- /dev/null +++ b/test/parity/_signal_defaults.py @@ -0,0 +1,105 @@ +"""信号模板参数默认值与渲染工具。 + +本模块为 parity 测试套件提供"参数模板 -> 具体信号字符串"的转换能力。 +``czsc-signals`` 库中每一个 K 线信号都会向注册表登记一个参数模板, +形如 ``{freq}_D{di}N{n}M{m}_..._VyymmDD``。在等价性测试中,我们需要 +把每个信号都跑一遍,因此必须把模板里的 ``{placeholder}`` 替换成具体 +的、能让 Rust 端 ``assert!`` 校验通过的取值,从而合成一个可消费的 +七段式信号字符串。 + +设计要点: + * ``DEFAULTS`` 中的取值需要同时满足 Rust 实现里的所有约束断言, + 例如 ``n < m``、``w > 10``、``th in 30..300``、``t1 < t2`` 等, + 否则 all-signals 批量测试在某些信号上会触发 panic。 + * 字典 key 大小写敏感,``n`` 与 ``N`` 都需要登记,因为不同信号 + 使用了不同大小写的占位符。 + * ``render`` 在替换之后还会拼接一个标准化的取值后缀 + ``_v1_v2_v3_0``,保证渲染结果是合法的七段式信号字符串。 +""" + +from __future__ import annotations + +# 信号模板占位符 -> 默认替换值。 +# 取值经过精挑细选,必须满足 Rust 端所有 assert! 约束。 +DEFAULTS: dict[str, str] = { + "freq": "日线", + "freq1": "日线", + # 计数器 / 回看窗口(lookback) + "di": "1", + "n": "5", "N": "5", + "m": "20", "M": "20", + "p": "5", + "q": "5", + "k": "5", + "j": "5", + "l": "5", + "s": "5", + "z": "5", + "t": "5", + "w": "15", # >10 才能满足"压力位/支撑位"信号的 assert + "window": "15", + "rumi_window": "15", + # 成对参数(必须满足 a < b 的约束) + "t1": "5", "t2": "20", + "th1": "5", "th2": "20", "th3": "30", + "tha": "5", "thb": "50", "thc": "500", + "timeperiod1": "5", "timeperiod2": "20", + "min_count": "3", "max_count": "10", + # 单一阈值(RSI / 动量类信号常用) + "th": "50", # 必须落在 30 < th < 300 区间 + "ndev": "2", "nbdev": "2", + "avg_bp": "5", + "bi_init_length": "20", + "max_overlap": "3", + "num": "5", + "up": "1", "dw": "1", + "zf": "5", + "tl": "5", + "key": "close", + "line": "close", + # 移动平均与技术指标参数 + "ma_type": "SMA", + "ma_seq": "5", + "timeperiod": "20", + "fastperiod": "5", + "slowperiod": "20", + "signalperiod": "9", + "fastk_period": "5", + "slowd_period": "9", + "slowk_period": "9", + "md": "5", + "mp": "5", + "sp": "10", + "lp": "20", + # 常量类占位符 + "mode": "CO", + "K1": "K1", "K2": "K2", + "c1": "K1", "c2": "K2", +} + + +def render(template: str) -> str: + """将模板中的所有 ``{placeholder}`` 替换为默认值,并补齐取值后缀。 + + 具体行为: + 1. 把 ``template`` 中的每一个 ``{name}`` 替换为 ``DEFAULTS[name]``, + 对于未登记的占位符回退到字符串 ``"5"``。 + 2. 在结果末尾追加 ``_v1_v2_v3_0``,使其成为一个合法的、由七段 + 组成的信号字符串。 + + 参数: + template: 形如 ``"{freq}_D{di}N{n}M{m}_..._VyymmDD"`` 的参数模板。 + + 返回: + 渲染后的完整七段式信号字符串。 + """ + import re + + def _sub(m): + # 取出 {} 中的占位符名称,回退到默认 "5" + ph = m.group(1) + return DEFAULTS.get(ph, "5") + + rendered = re.sub(r"\{([^}]+)\}", _sub, template) + # 追加标准化的取值后缀,组成完整七段式信号 + return f"{rendered}_v1_v2_v3_0" diff --git a/test/parity/bench_optimize.py b/test/parity/bench_optimize.py new file mode 100644 index 000000000..315072962 --- /dev/null +++ b/test/parity/bench_optimize.py @@ -0,0 +1,199 @@ +# -*- coding: utf-8 -*- +"""性能基准测试脚本:对比 ``rs_czsc`` 与迁移后的 ``czsc`` 在 +``OpensOptimize`` / ``ExitsOptimize`` 工作流上的耗时表现。 + +使用场景与定位: + * 我们已经在 ``compare_optimize_full.py`` 中证明了两套实现的输出 + 在比特层面完全一致,因此本脚本不再做正确性校验,而是专注于 + 墙钟耗时(wall time)的对比。 + * 本脚本会对每种工作流(开仓优化 / 出场优化)执行 N 次试验, + 每次试验都使用一个全新的临时目录,避免缓存干扰。 + * 输出包含均值±标准差以及 czsc / rs_czsc 的耗时比,帮助快速识别 + 性能回归。 + +候选信号集合默认使用 ``list_all_signals`` 输出的 222 条 K 线信号, +通过 ``_signal_defaults.render`` 渲染成具体的信号字符串后传入。 + +运行方式(在 worktree 根目录下): + + uv run python test/parity/bench_optimize.py [--trials N] +""" + +from __future__ import annotations + +import argparse +import hashlib +import importlib +import json +import shutil +import statistics +import sys +import tempfile +import time +from pathlib import Path + +import pandas as pd + +# 把 parity 测试目录加入 sys.path,方便复用其中的辅助函数 +ROOT = Path(__file__).resolve().parents[2] +PARITY_DIR = ROOT / "test" / "parity" +sys.path.insert(0, str(PARITY_DIR)) +sys.path.insert(0, str(PARITY_DIR / "_compare_optimize")) + +from _signal_defaults import render # noqa: E402 + +# 复用 parity 脚本中的工具函数(数据准备、模块导入、仓位文件落盘等) +from compare_optimize_full import ( # noqa: E402 + _import_module, + _make_read_bars, + all_kline_candidate_events, + all_kline_candidate_signals, + make_bars_df, + write_beta_positions, +) + + +def time_open(module_name: str, results_root: Path, candidates: list[str]) -> float: + """测量一次 ``OpensOptimize.execute`` 的耗时。 + + 参数: + module_name: 待测模块名,``"rs_czsc"`` 或 ``"czsc"``。 + results_root: 用于存放本次试验输出的临时目录。 + candidates: 候选开仓信号字符串列表。 + + 返回: + ``execute`` 调用的耗时(秒)。 + """ + czsc_mod, optim_mod = _import_module(module_name) + OpensOptimize = optim_mod.OpensOptimize + + bar_sdt, bar_edt, sdt = "20200101", "20200310", "20200104" + bars_5min = make_bars_df("5分钟", bar_sdt, bar_edt) + bars_daily = make_bars_df("日线", bar_sdt, bar_edt) + get_raw_bars = _make_read_bars(czsc_mod, bars_5min, bars_daily) + files_position = write_beta_positions(czsc_mod, results_root / "base_positions", "000001") + + oop = OpensOptimize( + symbols=["000001"], files_position=files_position, + task_name="BenchOpen", candidate_signals=candidates, + read_bars=get_raw_bars, results_path=results_root, + signals_module_name="czsc.signals", + bar_sdt=bar_sdt, bar_edt=bar_edt, sdt=sdt, + ) + t0 = time.perf_counter() + oop.execute(n_jobs=1) + return time.perf_counter() - t0 + + +def time_exit(module_name: str, results_root: Path, candidate_events: list[dict]) -> float: + """测量一次 ``ExitsOptimize.execute`` 的耗时。 + + 参数: + module_name: 待测模块名,``"rs_czsc"`` 或 ``"czsc"``。 + results_root: 用于存放本次试验输出的临时目录。 + candidate_events: 候选出场事件 dict 列表。 + + 返回: + ``execute`` 调用的耗时(秒)。 + """ + czsc_mod, optim_mod = _import_module(module_name) + ExitsOptimize = optim_mod.ExitsOptimize + + bar_sdt, bar_edt, sdt = "20200101", "20200310", "20200104" + bars_5min = make_bars_df("5分钟", bar_sdt, bar_edt) + bars_daily = make_bars_df("日线", bar_sdt, bar_edt) + get_raw_bars = _make_read_bars(czsc_mod, bars_5min, bars_daily) + files_position = write_beta_positions(czsc_mod, results_root / "base_positions", "000001") + + eop = ExitsOptimize( + symbols=["000001"], files_position=files_position, + task_name="BenchExit", candidate_events=candidate_events, + read_bars=get_raw_bars, results_path=results_root, + signals_module_name="czsc.signals", + # 显式指定 base_freq 是为了绕过 czsc 在自动推导时对 strategy.positions 的处理 bug + base_freq="5分钟", + bar_sdt=bar_sdt, bar_edt=bar_edt, sdt=sdt, + ) + t0 = time.perf_counter() + eop.execute(n_jobs=1) + return time.perf_counter() - t0 + + +def fmt(times: list[float]) -> str: + """把多次试验的耗时列表格式化为 ``mean±stdev (min/max)`` 字符串。""" + if len(times) < 2: + return f"{times[0]*1000:.0f}ms" + return (f"{statistics.mean(times)*1000:.0f}±{statistics.stdev(times)*1000:.0f}ms " + f"(min {min(times)*1000:.0f}ms / max {max(times)*1000:.0f}ms)") + + +def main(): + """命令行入口:解析参数、收集基准数据、打印汇总表。""" + ap = argparse.ArgumentParser() + ap.add_argument("--trials", type=int, default=5) + args = ap.parse_args() + + # 候选信号 / 候选事件预计算一次即可:parity 测试已经证明两套实现 + # 在这部分输入上完全一致,因此可以放心共享。 + import czsc as _cz + candidate_signals = all_kline_candidate_signals(_cz) + candidate_events = all_kline_candidate_events(_cz) + # 把可能存在的 dict 形式的信号统一回退成字符串形式,确保两套实现 + # 在最终消费时拿到完全相同的输入(双保险)。 + for e in candidate_events: + e["signals_all"] = [ + (s if isinstance(s, str) else f"{s['key']}_{s['value']}") + for s in e["signals_all"] + ] + + print(f"trials={args.trials}, candidate_signals={len(candidate_signals)}, candidate_events={len(candidate_events)}") + print() + + # 每种 (module × kind) 组合都先做一次 warmup 以摊销 import / 首次调用开销, + # 然后再跑 args.trials 次正式试验。 + results: dict = {"open": {}, "exit": {}} + + for kind, module in [("open", "rs_czsc"), ("open", "czsc"), + ("exit", "rs_czsc"), ("exit", "czsc")]: + # warm-up 阶段:失败时记录但不阻塞后续测量 + with tempfile.TemporaryDirectory() as tmp: + try: + if kind == "open": + time_open(module, Path(tmp), candidate_signals) + else: + time_exit(module, Path(tmp), candidate_events) + except Exception as e: + print(f"[{module}/{kind}] warmup FAIL: {e}") + continue + + times: list[float] = [] + for i in range(args.trials): + with tempfile.TemporaryDirectory() as tmp: + if kind == "open": + t = time_open(module, Path(tmp), candidate_signals) + else: + t = time_exit(module, Path(tmp), candidate_events) + times.append(t) + print(f" [{module}/{kind}] trial {i+1}: {t*1000:.0f}ms") + results[kind][module] = times + + # 汇总输出 + print() + print("=" * 64) + print("BENCHMARK SUMMARY") + print("=" * 64) + print(f"{'kind':<8} {'rs_czsc':<28} {'czsc':<28} {'czsc/rs':<8}") + for kind in ("open", "exit"): + rs = results[kind].get("rs_czsc", []) + cs = results[kind].get("czsc", []) + if not rs or not cs: + continue + rs_mean = statistics.mean(rs) + cs_mean = statistics.mean(cs) + ratio = cs_mean / rs_mean + print(f"{kind:<8} {fmt(rs):<28} {fmt(cs):<28} {ratio:.3f}x") + print() + + +if __name__ == "__main__": + main() diff --git a/test/parity/compare_optimize_full.py b/test/parity/compare_optimize_full.py new file mode 100644 index 000000000..f99db4f15 --- /dev/null +++ b/test/parity/compare_optimize_full.py @@ -0,0 +1,451 @@ +# -*- coding: utf-8 -*- +"""完整 K 线信号集合下的 Open/Exit 优化等价性对比脚本。 + +本脚本对比 ``rs_czsc`` 与迁移后的 ``czsc`` 在 ``use_optimize.py`` 工作流 +上的输出一致性,候选信号覆盖 ``list_all_signals`` 返回的全部 222 个 +K 线信号。它在功能上与 ``rs_czsc/examples/use_optimize.py`` 等价,但 +做了如下调整: + + * 缺失的 ``python/tests/k_line.feather`` 替换为 ``wbt.mock`` 提供的 + 可重现 mock K 线(固定 ``seed=42``)。 + * 候选开仓信号从原来的 4 个扩展到 222 个,全部通过 + ``_signal_defaults.render`` 渲染成具体信号字符串。 + * ``OpensOptimize`` / ``ExitsOptimize`` 各跑一次 ``rs_czsc`` 与 + ``czsc``,输出分别落在两个兄弟目录下。 + * 遍历两份输出树,对每个 parquet / xlsx 文件做逐字节、逐字段对比。 + +运行方式(在 worktree 根目录下): + + uv run python test/parity/compare_optimize_full.py +""" + +from __future__ import annotations + +import hashlib +import json +import shutil +import sys +import time +from pathlib import Path + +import pandas as pd + +ROOT = Path(__file__).resolve().parents[2] +PARITY_DIR = ROOT / "test" / "parity" +sys.path.insert(0, str(PARITY_DIR)) + +from _signal_defaults import render # noqa: E402 + +OUT_ROOT = PARITY_DIR / "_compare_optimize" + + +# --------------------------------------------------------------------- # +# K 线数据准备 —— 用 mock 替换原示例缺失的 k_line.feather # +# --------------------------------------------------------------------- # + +def make_bars_df(freq: str, sdt: str, edt: str) -> pd.DataFrame: + """生成单品种 mock K 线 DataFrame,列与示例脚本对齐。 + + 使用固定 ``seed=42`` 以保证两次运行(rs_czsc 与 czsc)的输入完全一致。 + """ + from wbt.mock import mock_symbol_kline + df = mock_symbol_kline("000001", freq, sdt, edt, seed=42) + df["dt"] = pd.to_datetime(df["dt"]) + cols = ["dt", "symbol", "open", "high", "low", "close", "vol", "amount"] + return df.loc[:, cols].reset_index(drop=True).copy() + + +# --------------------------------------------------------------------- # +# Beta 仓位构造与 read_bars 回调 # +# --------------------------------------------------------------------- # + +def _sig_str_to_kv(sig: str) -> dict: + """把七段式信号字符串拆分成 ``{key, value}`` 字典形式。 + + czsc.Position.load 只接受 dict 形式的信号;rs_czsc 两种都接受。 + 统一使用 dict 形式可以让两套实现走完全相同的代码路径。 + """ + parts = sig.split("_") + return {"key": "_".join(parts[:-4]), "value": "_".join(parts[-4:])} + + +def build_position(czsc_module, symbol, name, open_signal, open_operate): + """根据传入的开仓信号构造一个 Beta 基准 Position。""" + Position = czsc_module.Position + exit_operate = "平多" if open_operate == "开多" else "平空" + exit_signal = "5分钟_D1单K趋势N5_BS辅助V230506_第5层_任意_任意_0" + + def event_dict(name_, op, sig): + # czsc.Position.load 只接受 dict 形式的信号;rs_czsc 两种都接受。 + # 为了让两条代码路径完全一致,这里统一使用 dict 形式。 + return { + "name": name_, + "operate": op, + "signals_all": [_sig_str_to_kv(sig)], + "signals_any": [], + "signals_not": [], + } + + # czsc.Position 不接受 T0 关键字参数;rs_czsc 接受。 + # 通过 .load 走 dict 入口可以让两套实现保持完全一致的入参形态。 + return Position.load({ + "symbol": symbol, + "name": name, + "opens": [event_dict(f"{name}_open", open_operate, open_signal)], + "exits": [event_dict(f"{name}_exit", exit_operate, exit_signal)], + "interval": 0, + "timeout": 120, + "stop_loss": 800.0, + "T0": False, + }) + + +def write_beta_positions(czsc_module, path: Path, symbol: str) -> list[str]: + """把多空两个 Beta 仓位序列化为 JSON 文件,返回文件路径列表。 + + OpensOptimize / ExitsOptimize 都需要从磁盘文件加载基准仓位, + 两套实现使用相同的 JSON 内容即可保证后续对比的有效性。 + """ + path.mkdir(parents=True, exist_ok=True) + positions = [ + build_position( + czsc_module, symbol, "long_beta", + "5分钟_D1单K趋势N5_BS辅助V230506_第1层_任意_任意_0", "开多", + ), + build_position( + czsc_module, symbol, "short_beta", + "5分钟_D1单K趋势N5_BS辅助V230506_第18层_任意_任意_0", "开空", + ), + ] + files = [] + for pos in positions: + payload = pos.dump(with_data=False) + payload.pop("symbol", None) + # md5 字段供 czsc 内部做去重 / 缓存命中校验 + payload["md5"] = hashlib.md5(str(payload).encode("utf-8")).hexdigest() + f = path / f"{pos.name}.json" + f.write_text(json.dumps(payload, ensure_ascii=False), encoding="utf-8") + files.append(str(f)) + return files + + +# --------------------------------------------------------------------- # +# 候选信号集合 —— 全量 K 线信号(按默认参数渲染) # +# --------------------------------------------------------------------- # + +def all_kline_candidate_signals(czsc_module) -> list[str]: + """渲染信号注册表里所有 K 线类信号为完整七段式字符串。 + + rs_czsc 在包根与 ``_native`` 子模块上都暴露了 ``list_all_signals``, + 而迁移后的 czsc 只在 ``_native`` 上暴露,这里做了兼容处理。 + """ + if hasattr(czsc_module, "_native"): + all_sigs = czsc_module._native.list_all_signals() + else: + all_sigs = czsc_module.list_all_signals() + sigs = [s for s in all_sigs if s["category"] == "kline"] + rendered = [] + for s in sigs: + try: + r = render(s["param_template"]) + # 渲染后还有未填充占位符的模板直接丢弃 + if "{" not in r: + rendered.append(r) + except Exception: + pass + return sorted(set(rendered)) + + +# --------------------------------------------------------------------- # +# 驱动逻辑 —— 对每个模块各跑一次 OpensOptimize / ExitsOptimize # +# --------------------------------------------------------------------- # + +def _import_module(module_name: str): + """根据名字导入对应的 czsc 模块和 traders.optimize 子模块。 + + rs_czsc 的 ``Event.is_match`` 默认返回 bool,但 optimize 包装层 + 期望返回 (matched, reason) 元组,因此这里需要打 patch(与原示例 + 脚本一致)。patch 设置幂等标记位避免重复包装。 + """ + if module_name == "rs_czsc": + import rs_czsc as czsc_mod + from rs_czsc import Event as _Event + if not getattr(_Event, "_rs_tuple_contract_patch", False): + origin = _Event.is_match + + def _wrapped(self, sig): + out = origin(self, sig) + return out if isinstance(out, tuple) else (out, "is_match" if out else "") + _Event.is_match = _wrapped + _Event._rs_tuple_contract_patch = True + elif module_name == "czsc": + import czsc as czsc_mod + else: + raise ValueError(module_name) + import importlib + optim_mod = importlib.import_module(f"{module_name}.traders.optimize") + return czsc_mod, optim_mod + + +def _make_read_bars(czsc_mod, bars_5min, bars_daily): + """构造一个 ``read_bars`` 回调,按 freq 选择对应频率的 mock 数据。""" + def get_raw_bars(symbol_in, freq_in, sdt_in, edt_in, **_): + df = bars_daily if freq_in == "日线" else bars_5min + df = df[df["symbol"] == symbol_in] + return czsc_mod.format_standard_kline(df, freq=freq_in) + return get_raw_bars + + +def all_kline_candidate_events(czsc_mod) -> list[dict]: + """构造 ExitsOptimize 所需的候选出场事件 dict 列表。 + + 遍历全量 K 线信号,按下标奇偶性交替使用 "平多" / "平空",确保 + 多空两边都有候选事件可供测试。 + + 注意事项:必须使用字符串形式的 signals_all(与 rs_czsc 示例一致), + Rust 的 optimize-batch 在这个入口上对 dict 形式的 ``{key, value}`` 会 + panic —— 即便 ``Position.load`` 同时接受这两种形式 —— 这是 + ``ExitsOptimize`` 已知的入参 shape 限制。 + """ + out = [] + for i, sig in enumerate(all_kline_candidate_signals(czsc_mod)): + operate = "平多" if i % 2 == 0 else "平空" + out.append({ + "name": f"exit_{i:03d}", + "operate": operate, + "signals_all": [sig], + "signals_any": [], + "signals_not": [], + }) + return out + + +def run_open(module_name: str, results_root: Path) -> Path: + """运行一次 OpensOptimize 全流程,返回结果目录路径。""" + czsc_mod, optim_mod = _import_module(module_name) + OpensOptimize = optim_mod.OpensOptimize + + symbol = "000001" + bar_sdt, bar_edt = "20200101", "20200310" + sdt = "20200104" + + bars_5min = make_bars_df("5分钟", bar_sdt, bar_edt) + bars_daily = make_bars_df("日线", bar_sdt, bar_edt) + get_raw_bars = _make_read_bars(czsc_mod, bars_5min, bars_daily) + + files_position = write_beta_positions(czsc_mod, results_root / "base_positions", symbol) + candidates = all_kline_candidate_signals(czsc_mod) + print(f"[{module_name}/open] candidate_signals: {len(candidates)}") + + oop = OpensOptimize( + symbols=[symbol], + files_position=files_position, + task_name="FullKlineParityOpen", + candidate_signals=candidates, + read_bars=get_raw_bars, + results_path=results_root, + signals_module_name="czsc.signals", + bar_sdt=bar_sdt, bar_edt=bar_edt, sdt=sdt, + ) + t0 = time.perf_counter() + oop.execute(n_jobs=1) + elapsed = time.perf_counter() - t0 + print(f"[{module_name}/open] elapsed={elapsed:.1f}s -> {oop.results_path}") + return Path(oop.results_path) + + +def run_exit(module_name: str, results_root: Path) -> Path: + """运行一次 ExitsOptimize 全流程,返回结果目录路径。""" + czsc_mod, optim_mod = _import_module(module_name) + ExitsOptimize = optim_mod.ExitsOptimize + + symbol = "000001" + bar_sdt, bar_edt = "20200101", "20200310" + sdt = "20200104" + + bars_5min = make_bars_df("5分钟", bar_sdt, bar_edt) + bars_daily = make_bars_df("日线", bar_sdt, bar_edt) + get_raw_bars = _make_read_bars(czsc_mod, bars_5min, bars_daily) + + files_position = write_beta_positions(czsc_mod, results_root / "base_positions", symbol) + candidate_events = all_kline_candidate_events(czsc_mod) + print(f"[{module_name}/exit] candidate_events: {len(candidate_events)}") + + eop = ExitsOptimize( + symbols=[symbol], + files_position=files_position, + task_name="FullKlineParityExit", + candidate_events=candidate_events, + read_bars=get_raw_bars, + results_path=results_root, + signals_module_name="czsc.signals", + # 显式传 base_freq 来跳过自动推导:迁移后的 czsc 在自动推导路径上 + # 会对字符串形式的 signals_all 调用 Position.load(其更严格的校验 + # 器会拒绝该形态)。两套实现的 Rust optimizer 调用本身都接受字符串。 + base_freq="5分钟", + bar_sdt=bar_sdt, bar_edt=bar_edt, sdt=sdt, + ) + t0 = time.perf_counter() + eop.execute(n_jobs=1) + elapsed = time.perf_counter() - t0 + print(f"[{module_name}/exit] elapsed={elapsed:.1f}s -> {eop.results_path}") + return Path(eop.results_path) + + +# --------------------------------------------------------------------- # +# 输出树对比 # +# --------------------------------------------------------------------- # + +def inventory(root: Path) -> dict[str, int]: + """递归扫描 ``root`` 下的所有文件,返回 {相对路径: 字节数}。""" + out: dict[str, int] = {} + for p in sorted(root.rglob("*")): + if p.is_file(): + out[str(p.relative_to(root))] = p.stat().st_size + return out + + +def compare_trees(rs_path: Path, czsc_path: Path) -> dict: + """对比两棵输出树:文件清单、字节大小、parquet 与 xlsx 内容。 + + 返回值是一个汇总 dict,包含: + * rs_files / czsc_files: 两边的文件总数 + * rs_only / czsc_only: 仅出现在某一边的相对路径列表 + * size_diffs: 字节大小不同的文件列表 + * parquet_diffs: 内容不一致的 parquet 列表(含原因) + * parquet_checked: 实际比对的 parquet 数量 + * xlsx_diffs / xlsx_checked: 同上,针对 xlsx + """ + rs_inv = inventory(rs_path) + cs_inv = inventory(czsc_path) + + summary = { + "rs_files": len(rs_inv), + "czsc_files": len(cs_inv), + "rs_only": sorted(set(rs_inv) - set(cs_inv)), + "czsc_only": sorted(set(cs_inv) - set(rs_inv)), + "size_diffs": [], + "parquet_diffs": [], + "parquet_checked": 0, + "xlsx_diffs": [], + "xlsx_checked": 0, + } + + common = sorted(set(rs_inv) & set(cs_inv)) + for rel in common: + if rs_inv[rel] != cs_inv[rel]: + summary["size_diffs"].append((rel, rs_inv[rel], cs_inv[rel])) + if rel.endswith(".parquet"): + summary["parquet_checked"] += 1 + try: + a = pd.read_parquet(rs_path / rel) + b = pd.read_parquet(czsc_path / rel) + # cache 列含不可哈希对象,比对前先丢掉 + for c in ("cache",): + a = a.drop(columns=c, errors="ignore") + b = b.drop(columns=c, errors="ignore") + if a.shape != b.shape: + summary["parquet_diffs"].append({"rel": rel, "kind": "shape", "rs": a.shape, "czsc": b.shape}) + continue + cols = sorted(set(a.columns) & set(b.columns)) + if set(a.columns) != set(b.columns): + summary["parquet_diffs"].append({ + "rel": rel, "kind": "columns", + "rs_only": sorted(set(a.columns) - set(b.columns)), + "czsc_only": sorted(set(b.columns) - set(a.columns)), + }) + a = a[cols].reset_index(drop=True) + b = b[cols].reset_index(drop=True) + try: + pd.testing.assert_frame_equal(a, b, check_dtype=False, check_like=False) + except AssertionError as e: + summary["parquet_diffs"].append({"rel": rel, "kind": "data", "err": str(e)[:300]}) + except Exception as e: + summary["parquet_diffs"].append({"rel": rel, "kind": "read-error", "err": str(e)[:300]}) + elif rel.endswith(".xlsx"): + summary["xlsx_checked"] += 1 + try: + a = pd.read_excel(rs_path / rel) + b = pd.read_excel(czsc_path / rel) + if a.shape != b.shape: + summary["xlsx_diffs"].append({"rel": rel, "kind": "shape", "rs": a.shape, "czsc": b.shape}) + continue + # 汇总 xlsx 的行序在两次运行间可能不同(Rust 的优化按 HashMap + # 顺序遍历仓位)。比较前按 pos_name(或第一列)排序,确保 + # 我们做的是行集合相等的判断,而不是受顺序影响的逐行对比。 + sort_col = "pos_name" if "pos_name" in a.columns else a.columns[0] + a = a.sort_values(sort_col).reset_index(drop=True) + b = b.sort_values(sort_col).reset_index(drop=True) + try: + pd.testing.assert_frame_equal(a, b, check_dtype=False, check_like=False) + except AssertionError as e: + summary["xlsx_diffs"].append({"rel": rel, "kind": "data", "err": str(e)[:300]}) + except Exception as e: + summary["xlsx_diffs"].append({"rel": rel, "kind": "read-error", "err": str(e)[:300]}) + return summary + + +def _print_report(label: str, rep: dict) -> bool: + """格式化打印一份 compare_trees 报告,返回是否完全一致。""" + print("=" * 60) + print(f"[{label}] INVENTORY") + print(f" rs files: {rep['rs_files']}") + print(f" czsc files: {rep['czsc_files']}") + if rep["rs_only"] or rep["czsc_only"]: + print(f" rs only : {rep['rs_only'][:10]} (showing 10)") + print(f" czsc only: {rep['czsc_only'][:10]} (showing 10)") + else: + print(" inventory: IDENTICAL") + print(f"[{label}] SIZE DIFFS") + if rep["size_diffs"]: + for rel, a, b in rep["size_diffs"][:20]: + print(f" {rel}: rs={a}B czsc={b}B (Δ={b - a:+d})") + print(f" ... {len(rep['size_diffs'])} total") + else: + print(" (none) all files have identical byte size") + print(f"[{label}] PARQUET COMPARISON ({rep['parquet_checked']} files checked)") + if rep["parquet_diffs"]: + for d in rep["parquet_diffs"][:20]: + print(f" - {d}") + print(f" ... {len(rep['parquet_diffs'])} total parquet diffs") + else: + print(" ALL PARQUET CONTENTS IDENTICAL ✓") + print(f"[{label}] XLSX COMPARISON ({rep['xlsx_checked']} files, sorted)") + if rep["xlsx_diffs"]: + for d in rep["xlsx_diffs"][:20]: + print(f" - {d}") + else: + print(" ALL XLSX CONTENTS IDENTICAL (after sort) ✓") + print("=" * 60) + return not (rep["rs_only"] or rep["czsc_only"] or rep["parquet_diffs"] or rep["xlsx_diffs"]) + + +def main(): + """命令行入口:先后运行 OPEN / EXIT 两侧并打印报告。""" + if OUT_ROOT.exists(): + shutil.rmtree(OUT_ROOT) + OUT_ROOT.mkdir(parents=True) + + # 开仓侧:rs_czsc 与 czsc 各跑一次,输出落到不同目录 + rs_open = run_open("rs_czsc", OUT_ROOT / "open_rs") + cs_open = run_open("czsc", OUT_ROOT / "open_czsc") + + # 出场侧:同上 + rs_exit = run_exit("rs_czsc", OUT_ROOT / "exit_rs") + cs_exit = run_exit("czsc", OUT_ROOT / "exit_czsc") + + print() + print(f"OPEN rs results: {rs_open}") + print(f"OPEN czsc results: {cs_open}") + print(f"EXIT rs results: {rs_exit}") + print(f"EXIT czsc results: {cs_exit}") + print() + + open_ok = _print_report("OPEN", compare_trees(rs_open, cs_open)) + exit_ok = _print_report("EXIT", compare_trees(rs_exit, cs_exit)) + + return 0 if (open_ok and exit_ok) else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/test/parity/conftest.py b/test/parity/conftest.py new file mode 100644 index 000000000..3d251c062 --- /dev/null +++ b/test/parity/conftest.py @@ -0,0 +1,107 @@ +"""parity 测试套件共享 fixture 定义。 + +本目录下所有 parity 测试都遵循同一套测试范式: + 1. 同时导入 ``rs_czsc``(PyPI 上的基线版本)与 ``czsc``(迁移后的本地版本) + 2. 用同一份固定随机种子的输入数据分别驱动两套实现 + 3. 比较输出,要求二者在数值/结构层面完全一致 + +``rs_czsc`` 通过 ``pyproject.toml`` 的开发依赖引入,仅供测试使用。 +""" + +from __future__ import annotations + +import pytest + + +@pytest.fixture(scope="session") +def rs_czsc_module(): + """以 session 范围导入 ``rs_czsc`` 基线模块。 + + 若运行环境未安装 ``rs_czsc``,使用 ``pytest.importorskip`` 跳过整个 + parity 测试套件而不是报错,这样可以让 CI 在没有 Rust 运行时的环境 + 下仍能完成其他测试。 + """ + rs_czsc = pytest.importorskip( + "rs_czsc", + reason="rs_czsc baseline must be installed to run parity tests", + ) + return rs_czsc + + +@pytest.fixture(scope="session") +def czsc_module(): + """以 session 范围导入待测的 ``czsc`` 模块。""" + import czsc + + return czsc + + +@pytest.fixture(scope="session") +def mock_kline_df(): + """单品种 K 线 DataFrame,固定随机种子,供所有需要 bars 输入的 parity + 测试使用。 + + 数据形态与 ``czsc._compat.bars_to_dataframe`` 输出保持一致: + 六个数值列(OHLC + vol + amount)使用 Float64,``dt`` 列为 datetime64。 + """ + from wbt.mock import mock_symbol_kline + + df = mock_symbol_kline("000001", "日线", "20230101", "20241231", seed=42) + return df + + +@pytest.fixture(scope="session") +def sample_signal_strings(): + """一组真实存在于注册表中的信号字符串,供 ``derive_signals_*`` 等 + parity 测试使用。 + + 选取的信号必须在 rs_czsc 与 czsc 两边都能解析,否则等价性测试无法 + 成立。 + """ + return [ + "日线_D1N5M5TH10_ADTMV230603_看多_任意_任意_0", + "日线_D1N5M5_AMV能量V230603_看多_任意_任意_0", + "日线_D1N5P5_ASI多空V230603_看多_任意_任意_0", + ] + + +@pytest.fixture(scope="session") +def sample_position_dict(): + """一份典型的 Position dict,使用注册表中真实存在的信号,便于跑通 + 完整的 ``run_research`` 流水线。 + + 关键字段: + * opens / exits: 每边一个 Event,使用 dict 形式的 signals_all + * interval / timeout / stop_loss: 风控相关参数 + * T0: 是否允许 T+0 交易 + """ + return { + "name": "test_pos", + "symbol": "000001", + "opens": [ + { + "name": "open_long", + "operate": "开多", + "signals_all": [ + {"key": "日线_D1N5M5TH10_ADTMV230603", "value": "看多_任意_任意_0"} + ], + "signals_any": [], + "signals_not": [], + }, + ], + "exits": [ + { + "name": "exit_long", + "operate": "平多", + "signals_all": [ + {"key": "日线_D1N5M5TH10_ADTMV230603", "value": "看空_任意_任意_0"} + ], + "signals_any": [], + "signals_not": [], + }, + ], + "interval": 0, + "timeout": 100, + "stop_loss": 500.0, + "T0": False, + } diff --git a/test/parity/test_all_signals.py b/test/parity/test_all_signals.py new file mode 100644 index 000000000..6473798cd --- /dev/null +++ b/test/parity/test_all_signals.py @@ -0,0 +1,239 @@ +"""全量 K 线信号在不同规模数据集下的等价性测试。 + +本测试是 parity 套件中覆盖面最广的一组:覆盖所有已注册的 222 个 K 线 +信号,并在四种不同规模的数据集上分别比对 ``rs_czsc`` 与 ``czsc`` 的 +``run_research`` 输出。 + +每个数据集的执行步骤: + 1. 把注册表里每一个 K 线信号渲染成具体的七段式信号字符串。其中 218 + 个信号可以通过 ``derive_signals_config`` 反推出运行时配置;剩下 4 + 个(``bar_amount_acc_V230214``、``bar_mean_amount_V221112``、 + ``bar_section_momentum_V221112``、``bar_zdf_V221203``)的模板含有 + value 段占位符,无法反推 —— rs_czsc 与 czsc 都返回 ``[]``,因此 + 我们直接根据 Rust 源代码手写它们的运行时配置,覆盖率从 218 提升 + 到 222。 + 2. 把所有信号配置合并成一个 ``signals_config``,使用同一个 base freq。 + 3. 用完全相同的 Arrow 字节 + JSON 策略分别调用 + ``czsc._native.run_research`` 与 ``rs_czsc._rs_czsc.run_research``。 + 4. 解码两边的 ``signals_arrow`` 输出,逐列做比特级一致性断言。 + +数据集规模设计: + * **small** 约 520 根日线(约 2 年) + * **medium** 约 5 200 根日线(约 20 年),通过 30min 重采样得到, + 但保持 ``日线`` 作为 base freq + * **large** 约 21 000 根 30 分钟 K 线(约 4 年日内) + * **xlarge** 约 52 500 根 30 分钟 K 线(约 10 年日内) + +xlarge 用例覆盖了与生产回测同一量级的输入,能在一次测试里同时拿到所有 +规模下的耗时比和成功/失败状态。 +""" + +from __future__ import annotations + +import json +import time + +import pandas as pd +import pytest + +from ._signal_defaults import render + + +# --------------------------------------------------------------------- # +# K 线 fixture # +# --------------------------------------------------------------------- # + +# 数据集元组:(标签, base_freq, 起始日期, 结束日期) +DATASETS = [ + ("small", "日线", "20230101", "20250101"), + ("medium", "日线", "20100101", "20250101"), + ("large", "30分钟", "20210101", "20250101"), + ("xlarge", "30分钟", "20140101", "20250101"), +] + + +def _make_bars(freq: str, sdt: str, edt: str) -> pd.DataFrame: + """生成指定频率与日期范围的 mock K 线,并清洗为 ``bars_to_dataframe`` 输出形态。""" + from wbt.mock import mock_symbol_kline + + from czsc._compat import bars_to_dataframe + + df = mock_symbol_kline("000001", freq, sdt, edt, seed=42) + df["dt"] = pd.to_datetime(df["dt"]) + return bars_to_dataframe(df, symbol="000001") + + +# --------------------------------------------------------------------- # +# 策略合成 # +# --------------------------------------------------------------------- # + +def _build_all_signals_strategy(czsc_module, base_freq: str): + """构造覆盖全部 K 线信号的运行时策略。 + + 返回 (strategy_dict, n_signals)。复用 ``_signal_defaults.render`` 渲染 + 信号字符串,确保 rs_czsc 与 czsc 接收的配置完全一致。 + """ + from czsc._compat import position_dump_to_runtime, signal_config_to_runtime + + sigs = [s for s in czsc_module._native.list_all_signals() if s["category"] == "kline"] + test_signals = [] + for s in sigs: + # 强制把 freq 占位符替换为当前数据集的 base_freq + rendered = render(s["param_template"]).replace("日线", base_freq, 1) + test_signals.append(rendered) + + runtime = czsc_module.derive_signals_config(test_signals) + runtime_for_freq = [c for c in runtime if c.get("freq") == base_freq] + + # 4 个 value 段含占位符的信号,derive_signals_config 无法反推出运行时配置。 + # 这里直接根据 Rust 源代码默认值手写:字段名与 Rust 端 ``params.*`` 取值 + # 完全一致,czsc 与 rs_czsc 都能消费。补齐之后覆盖率从 218 升到 222。 + derive_blind_spots = [ + {"name": "bar_amount_acc_V230214", "freq": base_freq, "params": {"di": 1, "n": 5, "t": 5}}, + {"name": "bar_mean_amount_V221112", "freq": base_freq, "params": {"di": 1, "n": 5, "th1": 1, "th2": 10}}, + {"name": "bar_section_momentum_V221112", "freq": base_freq, "params": {"di": 1, "n": 5, "th": 50}}, + {"name": "bar_zdf_V221203", "freq": base_freq, "params": {"di": 1, "mode": "ZF", "span": "5,20"}}, + ] + runtime_for_freq.extend(derive_blind_spots) + + # 构造一个 dummy Position,用于满足 run_research 对 positions 非空的校验。 + dummy_sig = test_signals[0].replace("日线", base_freq, 1) + parts = dummy_sig.split("_") + dummy_pos = czsc_module.Position.load( + { + "symbol": "000001", + "name": "_parity_dummy_", + "opens": [ + { + "name": "o", + "operate": "开多", + "signals_all": [{"key": "_".join(parts[:-4]), "value": "_".join(parts[-4:])}], + "signals_any": [], + "signals_not": [], + } + ], + "exits": [ + { + "name": "e", + "operate": "平多", + "signals_all": [{"key": "_".join(parts[:-4]), "value": "_".join(parts[-4:])}], + "signals_any": [], + "signals_not": [], + } + ], + "interval": 0, + "timeout": 100, + "stop_loss": 100, + "T0": False, + } + ) + + strategy = { + "name": "AllSignals", + "symbol": "000001", + "base_freq": base_freq, + "signals_module": "czsc.signals", + "signals_config": [signal_config_to_runtime(c) for c in runtime_for_freq], + "positions": [position_dump_to_runtime(dummy_pos.dump(with_data=False))], + "market": "默认", + "bg_max_count": 5000, + } + return strategy, len(runtime_for_freq) + + +def _signal_columns(df: pd.DataFrame) -> list[str]: + """从 DataFrame 列中剔除 K 线元数据列,仅保留信号输出列。""" + meta = {"dt", "symbol", "open", "close", "high", "low", "vol", "amount", "id", "freq", "cache"} + return sorted(c for c in df.columns if c not in meta) + + +# --------------------------------------------------------------------- # +# 参数化等价性测试 # +# --------------------------------------------------------------------- # + +@pytest.mark.parametrize( + "label,base_freq,sdt,edt", + DATASETS, + ids=[d[0] for d in DATASETS], +) +def test_all_signals_parity( + rs_czsc_module, czsc_module, label, base_freq, sdt, edt, capsys +): + """对每种规模的数据集,``czsc.run_research`` 输出的每一列信号都 + 必须与 ``rs_czsc.run_research`` 完全相等。 + + 测试目标: + * 验证迁移后的 czsc 在全量 K 线信号上与 rs_czsc 行为一致 + * 在四种数据规模下都能保持比特级一致 + + 关键断言: + * 输出 DataFrame 的 shape 完全相同 + * 列集合完全相同(无新增/缺失列) + * 每一个信号列的每一个 cell 都相等(NaN 也对齐) + """ + from czsc._utils._df_convert import arrow_bytes_to_pd_df, pandas_to_arrow_bytes + + bars_df = _make_bars(base_freq, sdt, edt) + arrow = pandas_to_arrow_bytes(bars_df) + + strategy, n_signals = _build_all_signals_strategy(czsc_module, base_freq) + strategy_json = json.dumps(strategy, ensure_ascii=False) + + # 跑 rs_czsc 基线 + t0 = time.perf_counter() + rs_payload = rs_czsc_module._rs_czsc.run_research(arrow, strategy_json, None, None) + rs_elapsed = time.perf_counter() - t0 + + # 跑迁移后的 czsc + t0 = time.perf_counter() + czsc_payload = czsc_module._native.run_research(arrow, strategy_json, None, None) + czsc_elapsed = time.perf_counter() - t0 + + rs_df = arrow_bytes_to_pd_df(bytes(rs_payload["signals_arrow"])) + czsc_df = arrow_bytes_to_pd_df(bytes(czsc_payload["signals_arrow"])) + + # shape 必须完全一致 + assert rs_df.shape == czsc_df.shape, ( + f"[{label}] shape mismatch: rs={rs_df.shape} czsc={czsc_df.shape}" + ) + + # 列集合严格相等:任何一边出现额外列都视为失败 + rs_cols = set(rs_df.columns) + czsc_cols = set(czsc_df.columns) + assert rs_cols == czsc_cols, ( + f"[{label}] column set differs.\n" + f" rs only: {rs_cols - czsc_cols}\n" + f" czsc only: {czsc_cols - rs_cols}" + ) + + # 逐列、逐 cell 对比信号值,记录所有不一致列 + sig_cols = _signal_columns(rs_df) + diverging = [] + for col in sig_cols: + rs_series = rs_df[col].reset_index(drop=True) + czsc_series = czsc_df[col].reset_index(drop=True) + if not rs_series.equals(czsc_series): + # 找到第一个不一致的行号,便于诊断 + mask = rs_series.ne(czsc_series) | (rs_series.isna() ^ czsc_series.isna()) + first_idx = mask.idxmax() if mask.any() else None + diverging.append( + (col, first_idx, rs_series.iloc[first_idx] if first_idx is not None else None, + czsc_series.iloc[first_idx] if first_idx is not None else None) + ) + + ratio = czsc_elapsed / rs_elapsed if rs_elapsed > 0 else float("inf") + with capsys.disabled(): + # 打印耗时比和差异统计,便于在 pytest 输出中追踪性能与正确性 + print( + f"\n[all-signals/{label} bars={len(bars_df)} sigs={n_signals}] " + f"rs_czsc={rs_elapsed * 1000:.0f}ms " + f"czsc={czsc_elapsed * 1000:.0f}ms " + f"ratio={ratio:.2f}x " + f"signal_cols={len(sig_cols)} " + f"diverging={len(diverging)}" + ) + + assert not diverging, ( + f"[{label}] {len(diverging)} signal columns diverge.\n" + f" first 5 (col, row, rs, czsc): {diverging[:5]}" + ) diff --git a/test/parity/test_czsc_core.py b/test/parity/test_czsc_core.py new file mode 100644 index 000000000..895b18322 --- /dev/null +++ b/test/parity/test_czsc_core.py @@ -0,0 +1,135 @@ +"""缠论核心分析器(FX / BI)等价性测试。 + +本测试用例覆盖整个缠论核心算法的输出:在固定随机种子的 mock K 线数据 +上,要求 ``czsc.CZSC`` 的分型列表(fx_list)、笔列表(bi_list)等关键 +输出与 ``rs_czsc.CZSC`` 完全一致(容差为 0)。 + +这是迁移正确性的最强保证之一:只要这套测试通过,意味着所有依赖 CZSC +分析结果的下游工具(信号、回测、优化)就拥有了一致的输入基础。 +""" + +from __future__ import annotations + + +def _build_bars(module, mock_df): + """通过模块自身的 ``format_standard_kline`` 把 mock DataFrame 转成 RawBar 列表。 + + 把转换过程放进每个模块各自完成,可以让 parity 测试同时覆盖各自的 + bar 构造路径,而不是只比较共享的 Rust 内部状态。 + """ + freq = module.Freq.D + return module.format_standard_kline(mock_df, freq=freq) + + +def _bi_to_dict(bi): + """把单个 BI(笔)对象快照成 JSON 友好的 dict,便于结构化对比。""" + return { + "direction": str(bi.direction), + "high": bi.high, + "low": bi.low, + "sdt": str(bi.sdt), + "edt": str(bi.edt), + "fx_a_dt": str(bi.fx_a.dt), + "fx_b_dt": str(bi.fx_b.dt), + } + + +def _fx_to_dict(fx): + """把单个 FX(分型)对象快照成 JSON 友好的 dict。""" + return { + "mark": str(fx.mark), + "high": fx.high, + "low": fx.low, + "dt": str(fx.dt), + } + + +def _zs_to_dict(zs): + """把单个 ZS(中枢)对象快照成 JSON 友好的 dict(保留接口,用于以后扩展)。""" + return { + "zg": zs.zg, + "zd": zs.zd, + "gg": zs.gg, + "dd": zs.dd, + "sdt": str(zs.sdt), + "edt": str(zs.edt), + } + + +def _czsc_snapshot(module, mock_df): + """对 ``module.CZSC(bars)`` 取一份完整快照,便于跨模块对比。 + + 返回字段:fxs、bi_list、freq、symbol、max_bi_num。 + 备注:PyO3 暴露的 CZSC 类没有 ``zs_list`` 访问器(中枢推理目前在 + trader 信号层完成),因此这里不快照中枢。 + """ + bars = _build_bars(module, mock_df) + c = module.CZSC(bars) + return { + "fxs": [_fx_to_dict(f) for f in c.fx_list], + "bi_list": [_bi_to_dict(b) for b in c.bi_list], + "freq": str(c.freq), + "symbol": c.symbol, + "max_bi_num": c.max_bi_num, + } + + +def test_czsc_analyzer_outputs_match(rs_czsc_module, czsc_module, mock_kline_df): + """缠论核心算法在 mock 数据上的输出必须与 rs_czsc 完全一致。 + + 测试场景: + * 用同一份 fixed-seed 的 mock 日线 K 线分别构造两套 CZSC 对象 + * 对比 freq/symbol/max_bi_num 等元数据 + * 对比 fx_list、bi_list 的长度与逐元素内容 + + 关键断言: + * 元数据完全相等 + * fx_list 长度与每一项内容完全相等 + * bi_list 长度与每一项内容完全相等 + """ + rs_snapshot = _czsc_snapshot(rs_czsc_module, mock_kline_df) + czsc_snapshot = _czsc_snapshot(czsc_module, mock_kline_df) + + assert rs_snapshot["freq"] == czsc_snapshot["freq"] + assert rs_snapshot["symbol"] == czsc_snapshot["symbol"] + assert rs_snapshot["max_bi_num"] == czsc_snapshot["max_bi_num"] + + # 分型列表长度先做长度断言以提供更清晰的失败信息 + assert len(rs_snapshot["fxs"]) == len(czsc_snapshot["fxs"]), ( + f"fx_list length differs: rs={len(rs_snapshot['fxs'])} czsc={len(czsc_snapshot['fxs'])}" + ) + assert rs_snapshot["fxs"] == czsc_snapshot["fxs"], "fx_list content diverges" + + # 笔列表同理 + assert len(rs_snapshot["bi_list"]) == len(czsc_snapshot["bi_list"]) + assert rs_snapshot["bi_list"] == czsc_snapshot["bi_list"], "bi_list content diverges" + + +def test_czsc_class_identity(rs_czsc_module, czsc_module): + """两套实现必须暴露相同的公共类名集合。 + + 注意:实际的类对象不需要相同(它们来自不同的 cdylib),但公共 API + 表面上的名字必须一致,以保证下游用户的代码可以无差别地切换两套 + 实现。 + + 关键断言:``expected`` 列表里的每一个名字在 rs_czsc 与 czsc 上都必须 + 可以通过 ``hasattr`` 找到。 + """ + expected = { + "CZSC", + "FX", + "BI", + "ZS", + "RawBar", + "NewBar", + "Freq", + "Mark", + "Direction", + "Operate", + "Signal", + "Event", + "Position", + } + for name in expected: + assert hasattr(rs_czsc_module, name), f"rs_czsc lacks {name}" + assert hasattr(czsc_module, name), f"czsc lacks {name}" diff --git a/test/parity/test_examples.py b/test/parity/test_examples.py new file mode 100644 index 000000000..b244193da --- /dev/null +++ b/test/parity/test_examples.py @@ -0,0 +1,509 @@ +"""官方示例脚本的端到端等价性测试。 + +本测试以 ``rs_czsc/examples/`` 下的三个真实示例脚本为蓝本,要求迁移后 +的 ``czsc`` 在完全相同的输入上产出与 ``rs_czsc`` 完全一致的结果。 + +覆盖的三个示例: + * ``30分钟笔非多即空.py`` —— 多周期策略 → backtest + replay → + 写出 signals/pairs/holds 等 parquet 文件 + * ``use_optimize.py`` —— 开仓优化 + 出场优化两条流水线 → + 产出按品种组织的 parquet 树 + * ``weight_backtest.py`` —— 直接对一个权重 DataFrame 跑 + ``WeightBacktest`` → 比对 stats dict + +每个测试都会: + 1. 用 mock K 线 / mock 权重 喂给两套实现,输入完全一致 + 2. 记录各自耗时 + 3. 断言每个输出 parquet 完全相等,并打印 czsc/rs_czsc 的耗时比 +""" + +from __future__ import annotations + +import json +import time +from pathlib import Path + +import pandas as pd +import pytest + + +# --------------------------------------------------------------------- # +# 共享辅助函数 # +# --------------------------------------------------------------------- # + +def _patch_event_is_match_tuple_contract(module): + """把 ``module.Event.is_match`` 包装成 (matched, reason) 元组返回。 + + rs_czsc 与 czsc 的示例脚本都需要这个 patch(与上游示例保持一致)。 + 使用幂等标记位 ``_rs_tuple_contract_patch`` 防止重复包装,从而支持 + 在两个模块上分别打 patch 而互不影响。 + """ + if getattr(module.Event, "_rs_tuple_contract_patch", False): + return + origin = module.Event.is_match + + def _wrapped(self, sig): + out = origin(self, sig) + if isinstance(out, tuple): + return out + if not out: + return out, "" + operate = getattr(self, "operate", None) + return out, str(operate) if operate is not None else "is_match" + + module.Event.is_match = _wrapped + module.Event._rs_tuple_contract_patch = True + + +def _build_30m_bars(czsc_module): + """生成 30 分钟级别的 mock K 线 DataFrame。 + + 同一份 payload 会同时传给 rs_czsc.format_standard_kline 与 + czsc.format_standard_kline,确保 bars 列表是两套实现间唯一的变量。 + """ + from wbt.mock import mock_symbol_kline + + df = mock_symbol_kline("000001", "30分钟", "20200101", "20240101", seed=42) + df["dt"] = pd.to_datetime(df["dt"]) + return df + + +def _build_daily_bars(czsc_module): + """生成日线级别的 mock K 线 DataFrame,用于 use_optimize 等示例。""" + from wbt.mock import mock_symbol_kline + + df = mock_symbol_kline("000001", "日线", "20200101", "20240101", seed=42) + df["dt"] = pd.to_datetime(df["dt"]) + return df + + +def _normalise_parquet(path: Path) -> pd.DataFrame: + """把 parquet 读入并做归一化,便于跨模块比较。 + + 具体处理: + * 丢弃含不可哈希对象的 ``cache`` 列 + * 按 (dt, symbol, pos_name) 中存在的列稳定排序,使行序一致 + """ + df = pd.read_parquet(path) + if "cache" in df.columns: + df = df.drop(columns=["cache"]) + sort_keys = [c for c in ("dt", "symbol", "pos_name") if c in df.columns] + if sort_keys: + df = df.sort_values(sort_keys, kind="mergesort").reset_index(drop=True) + return df + + +def _compare_parquet_trees(rs_root: Path, czsc_root: Path, label: str): + """递归对比两棵 parquet 输出树,要求完全等价。 + + 参数: + rs_root: rs_czsc 的输出根目录 + czsc_root: czsc 的输出根目录 + label: 断言失败时使用的标签前缀,便于定位问题 + """ + rs_files = {p.relative_to(rs_root).as_posix() for p in rs_root.rglob("*.parquet")} + czsc_files = {p.relative_to(czsc_root).as_posix() for p in czsc_root.rglob("*.parquet")} + assert rs_files == czsc_files, ( + f"[{label}] parquet inventory differs.\n" + f" rs only: {rs_files - czsc_files}\n" + f" czsc only: {czsc_files - rs_files}" + ) + for rel in sorted(rs_files): + rs_df = _normalise_parquet(rs_root / rel) + czsc_df = _normalise_parquet(czsc_root / rel) + assert rs_df.shape == czsc_df.shape, ( + f"[{label}/{rel}] shape mismatch: rs={rs_df.shape} czsc={czsc_df.shape}" + ) + # 列集合严格一致:"完全一致"意味着任何一边都不能多出列 + rs_cols = set(rs_df.columns) + czsc_cols = set(czsc_df.columns) + assert rs_cols == czsc_cols, ( + f"[{label}/{rel}] column set differs.\n" + f" rs only: {rs_cols - czsc_cols}\n" + f" czsc only: {czsc_cols - rs_cols}" + ) + cols = sorted(rs_cols) + pd.testing.assert_frame_equal( + rs_df[cols].reset_index(drop=True), + czsc_df[cols].reset_index(drop=True), + check_dtype=False, + check_like=False, + ) + + +# --------------------------------------------------------------------- # +# 示例 1 —— 30分钟笔非多即空.py # +# --------------------------------------------------------------------- # + +def _build_long_short_position(module, symbol: str, base_freq: str): + """构造与 ``30分钟笔非多即空.py::create_long_short_V230909`` 一致的 Position。 + + 多空两个开仓事件,分别匹配"表里关系向上/向下",并在涨停/跌停时禁用。 + """ + opens_dict = [ + { + "operate": "开多", + "signals_all": [f"{base_freq}_D1_表里关系V230101_向上_任意_任意_0"], + "signals_any": [], + "signals_not": [f"{base_freq}_D1_涨跌停V230331_涨停_任意_任意_0"], + }, + { + "operate": "开空", + "signals_all": [f"{base_freq}_D1_表里关系V230101_向下_任意_任意_0"], + "signals_any": [], + "signals_not": [f"{base_freq}_D1_涨跌停V230331_跌停_任意_任意_0"], + }, + ] + return module.Position( + name=f"{base_freq}笔非多即空", + symbol=symbol, + opens=[module.Event.load(x) for x in opens_dict], + exits=[], + interval=3600 * 4, + timeout=16 * 30, + stop_loss=500, + ) + + +def _make_strategy_class(module): + """为目标模块动态构造 Strategy 子类。 + + 示例脚本里 Strategy 是局部定义的,必须依赖目标模块的 CzscStrategyBase + 抽象基类才能正确解析;因此每个模块都需要单独构造一份。 + """ + _patch_event_is_match_tuple_contract(module) + + class Strategy(module.CzscStrategyBase): + @property + def positions(self): + # 三个时间周期的 Position 同时启用,复现示例的多周期联立策略 + return [ + _build_long_short_position(module, self.symbol, "30分钟"), + _build_long_short_position(module, self.symbol, "60分钟"), + _build_long_short_position(module, self.symbol, "日线"), + ] + + return Strategy + + +def _run_30m_example(module, bars_df, results_root: Path): + """跑一遍 30 分钟示例的完整 backtest + replay 流程,返回耗时。""" + Strategy = _make_strategy_class(module) + bars = module.format_standard_kline(bars_df, freq="30分钟") + symbol = bars[0].symbol + tactic = Strategy(symbol=symbol) + + start = time.perf_counter() + tactic.backtest(bars, sdt="2020-06-01") + replay_dir = results_root / "replay" + replay_dir.mkdir(parents=True, exist_ok=True) + tactic.replay(bars, sdt="2020-06-01", res_path=replay_dir, refresh=True) + elapsed = time.perf_counter() - start + return elapsed + + +def test_example_30min_long_short_parity( + rs_czsc_module, czsc_module, tmp_path, capsys +): + """30分钟笔非多即空 示例脚本的端到端等价性。 + + 测试场景:rs_czsc 与 czsc 各跑一次 backtest+replay,输出落到不同目录。 + + 关键断言:两棵输出 parquet 树(包括 signals/pairs/holds 等)完全等价。 + """ + bars_df = _build_30m_bars(czsc_module) + + rs_root = tmp_path / "rs" + czsc_root = tmp_path / "czsc" + rs_root.mkdir() + czsc_root.mkdir() + + rs_elapsed = _run_30m_example(rs_czsc_module, bars_df, rs_root) + czsc_elapsed = _run_30m_example(czsc_module, bars_df, czsc_root) + + _compare_parquet_trees(rs_root, czsc_root, "30m-long-short") + + with capsys.disabled(): + ratio = czsc_elapsed / rs_elapsed if rs_elapsed > 0 else float("inf") + # 顺便打印耗时比,便于跟踪性能 + print( + f"\n[30分钟笔非多即空] rs_czsc={rs_elapsed:.3f}s " + f"czsc={czsc_elapsed:.3f}s ratio={ratio:.2f}x" + ) + + +# --------------------------------------------------------------------- # +# 示例 2 —— use_optimize.py(开仓优化 + 出场优化) # +# --------------------------------------------------------------------- # + +# 候选开仓信号(与上游示例保持一致的最小可复现集合) +OPEN_CANDIDATE_SIGNALS = [ + "日线_D2单K趋势N5_BS辅助V230506_第1层_任意_任意_0", + "日线_D2单K趋势N5_BS辅助V230506_第4层_任意_任意_0", +] + +# 候选出场事件 +EXIT_CANDIDATE_EVENTS = [ + { + "name": "加速上涨N5T500", + "operate": "平多", + "signals_all": ["日线_D2N5T500_绝对动量V230227_超强_任意_任意_0"], + "signals_any": [], + "signals_not": [], + }, + { + "name": "加速下跌N5T300", + "operate": "平空", + "signals_all": ["日线_D2N5T300_绝对动量V230227_超弱_任意_任意_0"], + "signals_any": [], + "signals_not": [], + }, +] + + +def _build_optimize_position(module, symbol: str, name: str, open_signal: str, open_operate: str): + """根据传入的开仓信号构造一个 Position,用于 OpensOptimize / ExitsOptimize 的基线。""" + exit_operate = "平多" if open_operate == "开多" else "平空" + exit_signal = "日线_D1单K趋势N5_BS辅助V230506_第5层_任意_任意_0" + return module.Position( + symbol=symbol, + name=name, + opens=[ + module.Event.load( + { + "name": f"{name}_open", + "operate": open_operate, + "signals_all": [open_signal], + "signals_any": [], + "signals_not": [], + } + ) + ], + exits=[ + module.Event.load( + { + "name": f"{name}_exit", + "operate": exit_operate, + "signals_all": [exit_signal], + "signals_any": [], + "signals_not": [], + } + ) + ], + interval=0, + timeout=120, + stop_loss=800.0, + # czsc._native.Position 与 rs_czsc.Position 都接受 t0(后者会自动把 T0 翻译成 t0) + t0=False, + ) + + +def _materialize_beta_positions(module, symbol: str, out_dir: Path): + """把多空两个 Beta 仓位序列化为 JSON 文件,返回文件路径列表。 + + 通过 ``czsc._compat.position_dump_to_runtime`` 转换成统一的运行时格式, + 确保两套实现读取的内容完全一致。 + """ + import hashlib + + from czsc._compat import position_dump_to_runtime + + out_dir.mkdir(parents=True, exist_ok=True) + positions = [ + _build_optimize_position( + module, + symbol, + "long_beta", + "日线_D1单K趋势N5_BS辅助V230506_第1层_任意_任意_0", + "开多", + ), + _build_optimize_position( + module, + symbol, + "short_beta", + "日线_D1单K趋势N5_BS辅助V230506_第18层_任意_任意_0", + "开空", + ), + ] + files = [] + for pos in positions: + payload = position_dump_to_runtime(pos.dump(with_data=False)) + payload.pop("symbol", None) + # md5 字段供后续缓存命中校验 + payload["md5"] = hashlib.md5(str(payload).encode("utf-8")).hexdigest() + f = out_dir / f"{pos.name}.json" + f.write_text(json.dumps(payload, ensure_ascii=False), encoding="utf-8") + files.append(str(f)) + return files + + +def _run_optimize_example(module, bars_df, results_root: Path): + """跑一遍 use_optimize 示例的完整开仓+出场优化流程,返回总耗时。""" + from czsc._compat import bars_to_dataframe + + if module.__name__ == "rs_czsc": + from rs_czsc.traders.optimize import ExitsOptimize, OpensOptimize + else: + from czsc.traders.optimize import ExitsOptimize, OpensOptimize + + _patch_event_is_match_tuple_contract(module) + + bars_clean = bars_to_dataframe(bars_df, symbol="000001") + + def read_bars(symbol, freq, sdt, edt, **_): + return bars_clean + + open_root = results_root / "open_demo" + open_root.mkdir(parents=True, exist_ok=True) + files_position = _materialize_beta_positions( + module, "000001", open_root / "base_positions" + ) + + start = time.perf_counter() + oop = OpensOptimize( + symbols=["000001"], + files_position=files_position, + task_name="入场优化", + candidate_signals=sorted(set(OPEN_CANDIDATE_SIGNALS)), + read_bars=read_bars, + results_path=open_root, + signals_module_name="czsc.signals", + bar_sdt="20200101", + bar_edt="20240101", + sdt="20200601", + base_freq="日线", + ) + oop.execute(n_jobs=1) + + exit_root = results_root / "exit_demo" + exit_root.mkdir(parents=True, exist_ok=True) + files_position_exit = _materialize_beta_positions( + module, "000001", exit_root / "base_positions" + ) + eop = ExitsOptimize( + symbols=["000001"], + files_position=files_position_exit, + task_name="出场优化", + candidate_events=EXIT_CANDIDATE_EVENTS, + read_bars=read_bars, + results_path=exit_root, + signals_module_name="czsc.signals", + bar_sdt="20200101", + bar_edt="20240101", + sdt="20200601", + base_freq="日线", + ) + eop.execute(n_jobs=1) + return time.perf_counter() - start + + +def test_example_use_optimize_parity( + rs_czsc_module, czsc_module, tmp_path, capsys +): + """use_optimize 示例脚本的端到端等价性。 + + 测试场景:rs_czsc 与 czsc 分别跑一遍开仓优化 + 出场优化,输出落到 + 不同目录。 + + 关键断言:两棵输出 parquet 树完全等价。 + """ + bars_df = _build_daily_bars(czsc_module) + + rs_root = tmp_path / "rs" + czsc_root = tmp_path / "czsc" + rs_root.mkdir() + czsc_root.mkdir() + + rs_elapsed = _run_optimize_example(rs_czsc_module, bars_df, rs_root) + czsc_elapsed = _run_optimize_example(czsc_module, bars_df, czsc_root) + + _compare_parquet_trees(rs_root, czsc_root, "use_optimize") + + with capsys.disabled(): + ratio = czsc_elapsed / rs_elapsed if rs_elapsed > 0 else float("inf") + print( + f"\n[use_optimize] rs_czsc={rs_elapsed:.3f}s " + f"czsc={czsc_elapsed:.3f}s ratio={ratio:.2f}x" + ) + + +# --------------------------------------------------------------------- # +# 示例 3 —— weight_backtest.py # +# --------------------------------------------------------------------- # + +def _build_weight_df(): + """构造 ``weight_backtest.py`` 所需的权重 DataFrame。 + + 列结构与上游示例对齐:``['dt','symbol','weight','price']``。使用 + ``wbt.mock.mock_weights`` 保证两次运行输入一致。 + """ + from wbt.mock import mock_weights + + df = mock_weights(seed=42) + return df[["dt", "symbol", "weight", "price"]].copy() + + +def test_example_weight_backtest_parity( + rs_czsc_module, czsc_module, capsys +): + """``WeightBacktest`` 示例的等价性测试(弱化版本)。 + + 关键约定: + * ``czsc.WeightBacktest`` 是从外部 ``wbt`` 包再导出的 + ``wbt.backtest.WeightBacktest``。 + * ``rs_czsc.WeightBacktest`` 是 rs_czsc 自带的内部实现 + ``rs_czsc._trader.weight_backtest.WeightBacktest``。 + + 这两者按当前架构本就不是 API 完全相同的实现(统计指标 key 集合、 + 版本节奏都不同)。严格的数值等价性不是必需 —— czsc 的契约是"使用 + wbt 作为标准回测提供方"。 + + 本测试主要做以下事情: + 1. 在同一份输入上跑两套实现,确保都能成功执行 + 2. 打印性能耗时比,便于追踪 + 3. 对两边都暴露的核心指标(最大回撤、绝对收益、下行波动率、 + 新高占比)做 0.5pp 容差范围内的近似一致性校验 + 4. 不强制要求 stats dict 完全相等 + """ + df = _build_weight_df() + + start = time.perf_counter() + rs_wb = rs_czsc_module.WeightBacktest( + df, digits=2, n_jobs=1, weight_type="ts" + ) + rs_stats = dict(rs_wb.stats) + rs_elapsed = time.perf_counter() - start + + start = time.perf_counter() + czsc_wb = czsc_module.WeightBacktest(df, digits=2, n_jobs=1, weight_type="ts") + czsc_stats = dict(czsc_wb.stats) + czsc_elapsed = time.perf_counter() - start + + # 两套实现都必须成功并产出非空 stats + assert rs_stats, "rs_czsc.WeightBacktest produced no stats" + assert czsc_stats, "czsc.WeightBacktest produced no stats" + + # 对两边都暴露的核心指标做近似比较:这些指标即便跨 wbt 版本也应该 + # 大体一致,容差设为 0.005(0.5 个百分点)。 + common_keys = sorted(set(rs_stats) & set(czsc_stats)) + tight_check = {"最大回撤", "绝对收益", "下行波动率", "新高占比"} + diffs = {} + for k in common_keys: + if k not in tight_check: + continue + rv, cv = rs_stats[k], czsc_stats[k] + if isinstance(rv, (int, float)) and isinstance(cv, (int, float)): + if abs(rv - cv) > 0.005: # 0.5 个百分点的容差 + diffs[k] = (rv, cv) + assert not diffs, ( + f"core stats divergence beyond 0.005 tolerance: {diffs}" + ) + + with capsys.disabled(): + ratio = czsc_elapsed / rs_elapsed if rs_elapsed > 0 else float("inf") + print( + f"\n[weight_backtest] rs_czsc={rs_elapsed:.3f}s " + f"czsc(wbt)={czsc_elapsed:.3f}s ratio={ratio:.2f}x " + f"(rs keys={len(rs_stats)} czsc keys={len(czsc_stats)} " + f"common={len(common_keys)})" + ) diff --git a/test/parity/test_optimize.py b/test/parity/test_optimize.py new file mode 100644 index 000000000..e85d5570d --- /dev/null +++ b/test/parity/test_optimize.py @@ -0,0 +1,217 @@ +"""优化器底层函数的等价性测试。 + +本套件验证 ``build_*_optim_positions`` 与 ``run_optimize_batch`` 三个 +低层 Rust 入口的等价性: + + * ``build_open_optim_positions`` —— 纯 Rust 的开仓优化变体构造器 + (没有磁盘 IO),只对仓位 dict 做笛卡尔积式扩展。 + * ``build_exit_optim_positions`` —— 同上,但作用于出场事件。 + * ``run_optimize_batch`` —— 完整的优化批处理(含磁盘 IO), + 两套实现都会写出 parquet 输出,最后逐文件解码并比较内容。 + +通过这三个测试可以保证: + 1. Rust 端的纯函数变体生成器输出一致; + 2. 端到端的批处理流水线(含落盘格式)也一致。 +""" + +from __future__ import annotations + +import json +import shutil +import tempfile +from pathlib import Path + +import pandas as pd +import pytest + + +@pytest.fixture +def position_files(tmp_path, sample_position_dict, czsc_module): + """把一份样例 Position 物化为 JSON 文件,供两个 build 助手使用。 + + 文件采用运行时格式(signals_all 为字符串列表),并保留占位的 ``symbol`` + 字段 —— 这与 rs_czsc 的 ``OpensOptimize._materialize_position_files`` + 行为一致(``runtime.setdefault("symbol", "symbol")``)。 + + 返回值:单元素列表,包含 JSON 文件的绝对路径字符串。 + """ + from czsc._compat import position_dump_to_runtime + + pos = czsc_module.Position.load(sample_position_dict) + payload = position_dump_to_runtime(pos.dump(with_data=False)) + payload.setdefault("symbol", "symbol") + payload.pop("md5", None) + payload.pop("pairs", None) + payload.pop("holds", None) + p = tmp_path / "beta.json" + p.write_text(json.dumps(payload, ensure_ascii=False), encoding="utf-8") + return [str(p)] + + +def test_build_open_optim_positions_matches(rs_czsc_module, czsc_module, position_files): + """``build_open_optim_positions`` 必须在两套实现间产出等价的开仓变体。 + + 测试场景:用同一份基线仓位文件 + 同一个候选信号,调用 rs_czsc 与 czsc + 的纯函数实现,对比生成的变体 Position 列表。 + + 关键断言: + * 变体数量一致 + * 经规范化后的 (name, opens hash, exits hash) 三元组集合一致 + (列表顺序不要求一致) + """ + candidates = ["日线_D1N5M5TH10_ADTMV230603_看多_任意_任意_0"] + + rs_raw = rs_czsc_module._rs_czsc.build_open_optim_positions(position_files, candidates) + czsc_raw = czsc_module._native.build_open_optim_positions(position_files, candidates) + + rs_data = json.loads(rs_raw) if isinstance(rs_raw, str) else rs_raw + czsc_data = json.loads(czsc_raw) if isinstance(czsc_raw, str) else czsc_raw + + assert len(rs_data) == len(czsc_data), ( + f"position count mismatch: rs={len(rs_data)} czsc={len(czsc_data)}" + ) + + # 仓位列表的顺序在两次独立运行间不一定一致;统一用 (name, opens hash, + # exits hash) 做规范化后再比较集合相等。 + def _canon(positions): + return sorted( + ( + p["name"], + json.dumps(p.get("opens", []), sort_keys=True, ensure_ascii=False), + json.dumps(p.get("exits", []), sort_keys=True, ensure_ascii=False), + ) + for p in positions + ) + + assert _canon(rs_data) == _canon(czsc_data), ( + f"open-optim variants diverge" + ) + + +def test_build_exit_optim_positions_matches(rs_czsc_module, czsc_module, position_files): + """``build_exit_optim_positions`` 必须在两套实现间产出等价的出场变体。 + + 测试场景:候选事件以 JSON 字符串形式传入;和开仓侧一样,做集合等价 + 校验,不要求列表顺序一致。 + """ + candidate_events = [ + { + "name": "exit_event", + "operate": "平多", + "signals_all": ["日线_D1N5M5TH10_ADTMV230603_看空_任意_任意_0"], + "signals_any": [], + "signals_not": [], + } + ] + + rs_raw = rs_czsc_module._rs_czsc.build_exit_optim_positions( + position_files, json.dumps(candidate_events, ensure_ascii=False) + ) + czsc_raw = czsc_module._native.build_exit_optim_positions( + position_files, json.dumps(candidate_events, ensure_ascii=False) + ) + + rs_data = json.loads(rs_raw) if isinstance(rs_raw, str) else rs_raw + czsc_data = json.loads(czsc_raw) if isinstance(czsc_raw, str) else czsc_raw + + assert len(rs_data) == len(czsc_data) + + def _canon(positions): + return sorted( + ( + p["name"], + json.dumps(p.get("opens", []), sort_keys=True, ensure_ascii=False), + json.dumps(p.get("exits", []), sort_keys=True, ensure_ascii=False), + ) + for p in positions + ) + + assert _canon(rs_data) == _canon(czsc_data) + + +@pytest.fixture +def bars_dir(tmp_path, mock_kline_df, czsc_module): + """落地一份 mock K 线 parquet,供两套实现共用。 + + rs_czsc 与 czsc 都从同一个目录读取 bars,从而把 IO 层带来的潜在 + 差异降到最低。 + """ + from czsc._compat import bars_to_dataframe + + bars_path = tmp_path / "bars" + bars_path.mkdir() + df = bars_to_dataframe(mock_kline_df, symbol="000001") + df.to_parquet(bars_path / "000001.parquet", index=False) + return bars_path + + +def test_run_optimize_batch_matches( + rs_czsc_module, czsc_module, bars_dir, position_files, tmp_path +): + """``run_optimize_batch`` 端到端等价性测试。 + + 测试场景:构造一个最小化的开仓优化任务配置,rs_czsc 与 czsc 各跑 + 一次,输出落到不同目录,然后递归对比文件清单与 parquet 内容。 + + 关键断言: + * 两边的输出文件相对路径集合完全一致(无任何缺失或额外文件) + * 每个 parquet 的 shape 与公共列内容完全等价 + """ + cfg = { + "optim_type": "open", + "task_name": "parity_open", + "base_freq": "日线", + "symbols": ["000001"], + "files_position": position_files, + "candidate_signals": ["日线_D1N5M5TH10_ADTMV230603_看多_任意_任意_0"], + "market": "默认", + "bg_max_count": 5000, + } + + rs_out = tmp_path / "rs_results" + czsc_out = tmp_path / "czsc_results" + rs_out.mkdir() + czsc_out.mkdir() + + rs_msg = rs_czsc_module._rs_czsc.run_optimize_batch( + str(bars_dir), json.dumps(cfg, ensure_ascii=False), str(rs_out), 1 + ) + czsc_msg = czsc_module._native.run_optimize_batch( + str(bars_dir), json.dumps(cfg, ensure_ascii=False), str(czsc_out), 1 + ) + assert isinstance(rs_msg, str) + assert isinstance(czsc_msg, str) + + # 递归扫描两棵输出树,构造 {相对路径: 字节数} 清单 + def _inventory(root: Path) -> dict[str, int]: + out = {} + for p in sorted(root.rglob("*")): + if p.is_file(): + out[str(p.relative_to(root))] = p.stat().st_size + return out + + rs_files = _inventory(rs_out) + czsc_files = _inventory(czsc_out) + assert set(rs_files) == set(czsc_files), ( + f"output tree differs.\nrs only: {set(rs_files) - set(czsc_files)}\n" + f"czsc only: {set(czsc_files) - set(rs_files)}" + ) + + # 逐文件比较 parquet 内容 + for rel in sorted(rs_files): + if not rel.endswith(".parquet"): + continue + rs_df = pd.read_parquet(rs_out / rel) + czsc_df = pd.read_parquet(czsc_out / rel) + # cache 列含不可哈希对象,比对前先丢掉 + for c in ("cache",): + rs_df = rs_df.drop(columns=c, errors="ignore") + czsc_df = czsc_df.drop(columns=c, errors="ignore") + assert rs_df.shape == czsc_df.shape, f"{rel}: shape mismatch" + common = sorted(set(rs_df.columns) & set(czsc_df.columns)) + pd.testing.assert_frame_equal( + rs_df[common].reset_index(drop=True), + czsc_df[common].reset_index(drop=True), + check_dtype=False, + check_like=True, + ) diff --git a/test/parity/test_performance.py b/test/parity/test_performance.py new file mode 100644 index 000000000..ae20d0f8c --- /dev/null +++ b/test/parity/test_performance.py @@ -0,0 +1,261 @@ +"""性能对比测试:迁移后的 czsc vs 基线 rs_czsc。 + +每个测试在两套实现上跑相同的工作负载,对每个输入规模采样若干次,取 +中位数耗时并打印 czsc/rs_czsc 的耗时比。``test_examples.py`` 已经覆盖 +了输出等价性,本套件只关心性能回归。 + +耗时通过 ``capsys.disabled()`` 输出,确保在默认 pytest 输出里直接可见。 +断言阈值故意设得比较宽松(``czsc <= 1.5x rs_czsc``)以避免抖动导致 +CI 间歇性失败 —— 这部分主要是预警机制,而不是硬性卡口。 +""" + +from __future__ import annotations + +import json +import statistics +import time +from pathlib import Path + +import pandas as pd + + +# 性能回归预算:czsc 不应慢于 rs_czsc 1.5 倍 +PERF_RATIO_BUDGET = 1.5 + + +def _format_standard_kline(module, df, freq): + """对模块的 format_standard_kline 做一层轻封装,方便统一调用。""" + return module.format_standard_kline(df, freq=freq) + + +def _build_long_short_position(module, symbol, base_freq): + """构造 30分钟笔非多即空 / strategies 用到的多空两个开仓 Position。""" + opens = [ + { + "operate": "开多", + "signals_all": [f"{base_freq}_D1_表里关系V230101_向上_任意_任意_0"], + "signals_any": [], + "signals_not": [f"{base_freq}_D1_涨跌停V230331_涨停_任意_任意_0"], + }, + { + "operate": "开空", + "signals_all": [f"{base_freq}_D1_表里关系V230101_向下_任意_任意_0"], + "signals_any": [], + "signals_not": [f"{base_freq}_D1_涨跌停V230331_跌停_任意_任意_0"], + }, + ] + return module.Position( + name=f"{base_freq}笔非多即空", + symbol=symbol, + opens=[module.Event.load(x) for x in opens], + exits=[], + interval=3600 * 4, + timeout=16 * 30, + stop_loss=500, + ) + + +def _czsc_analyze_perf(module, bars): + """测量 ``CZSC(bars)`` 构造路径的中位耗时(分析器热路径)。 + + 每次都强制访问 fx_list / bi_list 以确保惰性计算被触发。 + """ + samples = [] + for _ in range(3): + start = time.perf_counter() + c = module.CZSC(bars) + # 触发分型 / 笔的物化计算 + _ = len(c.fx_list) + _ = len(c.bi_list) + samples.append(time.perf_counter() - start) + return statistics.median(samples) + + +def _backtest_perf(module, bars_df, freq, sdt, tmp_path): + """测量 ``30分钟笔非多即空`` 类策略的 backtest 中位耗时。""" + bars = _format_standard_kline(module, bars_df, freq) + symbol = bars[0].symbol + + class Strat(module.CzscStrategyBase): + @property + def positions(self): + return [_build_long_short_position(module, self.symbol, freq)] + + samples = [] + for run_idx in range(3): + out_dir = tmp_path / f"run_{module.__name__}_{run_idx}" + out_dir.mkdir(parents=True, exist_ok=True) + tactic = Strat(symbol=symbol) + start = time.perf_counter() + tactic.backtest(bars, sdt=sdt) + samples.append(time.perf_counter() - start) + return statistics.median(samples) + + +# --------------------------------------------------------------------- # +# 测试 1 —— CZSC 分析器构造性能 # +# --------------------------------------------------------------------- # + +def test_perf_czsc_analyzer(rs_czsc_module, czsc_module, mock_kline_df, capsys): + """约 522 根日线下,``CZSC(bars)`` 分析器的耗时对比。 + + 测试目标:保证迁移后的 CZSC 构造路径不会显著慢于基线。 + + 关键断言:``czsc / rs_czsc`` 中位耗时比不超过 PERF_RATIO_BUDGET。 + """ + rs_bars = _format_standard_kline(rs_czsc_module, mock_kline_df, "日线") + czsc_bars = _format_standard_kline(czsc_module, mock_kline_df, "日线") + + rs_t = _czsc_analyze_perf(rs_czsc_module, rs_bars) + czsc_t = _czsc_analyze_perf(czsc_module, czsc_bars) + + ratio = czsc_t / rs_t if rs_t > 0 else float("inf") + with capsys.disabled(): + print( + f"\n[CZSC(522 daily bars)] rs_czsc={rs_t * 1000:.2f}ms " + f"czsc={czsc_t * 1000:.2f}ms ratio={ratio:.2f}x" + ) + # 这是宽松预算 —— 性能差异主要起到信息提示作用,而不是阻塞性的卡口 + assert ratio <= PERF_RATIO_BUDGET, ( + f"czsc analyzer is {ratio:.2f}x slower than rs_czsc baseline " + f"({PERF_RATIO_BUDGET}x budget exceeded)" + ) + + +# --------------------------------------------------------------------- # +# 测试 2 —— 多种 K 线规模下的 backtest 性能 # +# --------------------------------------------------------------------- # + +def test_perf_backtest_scaling(rs_czsc_module, czsc_module, tmp_path, capsys): + """30 分钟策略 backtest 在不同 K 线规模下的耗时对比。 + + 测试目标:确认 czsc 与 rs_czsc 在不同规模下的扩展特性接近,没有 + 出现任何规模上的显著退化。 + + 测试场景:在三种规模下分别测量 backtest 耗时: + * 约 5 个月 —— 起步规模 + * 约 18 个月 —— 中等规模 + * 约 4 年 —— 较大规模 + + 关键断言:每种规模下 czsc/rs_czsc 的耗时比都不能超过预算。 + """ + from wbt.mock import mock_symbol_kline + + # (起始日期, 结束日期) 三档:5 个月 / 18 个月 / 4 年 + sizes = [ + ("20210101", "20210601"), + ("20200101", "20210601"), + ("20180101", "20220101"), + ] + + rows = [] + for sdt_data, edt_data in sizes: + df = mock_symbol_kline("000001", "30分钟", sdt_data, edt_data, seed=42) + df["dt"] = pd.to_datetime(df["dt"]) + # backtest 起始时间在数据起点之后 60 天,避免 warmup 期影响 + backtest_sdt = pd.to_datetime(sdt_data) + pd.Timedelta(days=60) + backtest_sdt_str = backtest_sdt.strftime("%Y-%m-%d") + + rs_t = _backtest_perf( + rs_czsc_module, df, "30分钟", backtest_sdt_str, tmp_path / "rs" + ) + czsc_t = _backtest_perf( + czsc_module, df, "30分钟", backtest_sdt_str, tmp_path / "czsc" + ) + ratio = czsc_t / rs_t if rs_t > 0 else float("inf") + rows.append((len(df), rs_t * 1000, czsc_t * 1000, ratio)) + + with capsys.disabled(): + # 打印一张对照表 + print(f"\n[backtest scaling — 30min strategy]") + print(f" {'#bars':>6} | {'rs_czsc':>10} | {'czsc':>10} | {'ratio':>6}") + print(f" {'-' * 6} | {'-' * 10} | {'-' * 10} | {'-' * 6}") + for n_bars, rs_ms, czsc_ms, r in rows: + print(f" {n_bars:>6} | {rs_ms:>8.2f}ms | {czsc_ms:>8.2f}ms | {r:>5.2f}x") + + # 只要任何一个规模超出预算就视为失败 + over_budget = [(n, r) for n, _, _, r in rows if r > PERF_RATIO_BUDGET] + assert not over_budget, ( + f"backtest perf budget {PERF_RATIO_BUDGET}x exceeded at: {over_budget}" + ) + + +# --------------------------------------------------------------------- # +# 测试 3 —— derive_signals_config + run_research 端到端性能 # +# --------------------------------------------------------------------- # + +def test_perf_run_research_endtoend( + rs_czsc_module, czsc_module, mock_kline_df, sample_position_dict, capsys +): + """``run_research(arrow_bytes, json)`` 的端到端性能对比。 + + 这是迁移后最关键的入口:把 Arrow 字节 + JSON 策略喂给 Rust 端, + 返回信号/pairs/holds 的 Arrow payload。 + + 关键断言:``czsc / rs_czsc`` 的中位耗时比不超过 PERF_RATIO_BUDGET。 + """ + from czsc._compat import ( + bars_to_dataframe, + position_dump_to_runtime, + signal_config_to_runtime, + ) + from czsc._utils._df_convert import pandas_to_arrow_bytes + + df = bars_to_dataframe(mock_kline_df, symbol="000001") + arrow_bytes = pandas_to_arrow_bytes(df) + + # 运行时策略只构造一次,两套实现共享同一份 JSON + pos = czsc_module.Position.load(sample_position_dict) + cfg = czsc_module.derive_signals_config(pos.unique_signals) + strategy_json = json.dumps( + { + "name": "PerfStrategy", + "symbol": "000001", + "base_freq": "日线", + "signals_module": "czsc.signals", + "signals_config": [signal_config_to_runtime(c) for c in cfg], + "positions": [position_dump_to_runtime(pos.dump(with_data=False))], + "market": "默认", + "bg_max_count": 5000, + }, + ensure_ascii=False, + ) + + def _time(module, n=5): + samples = [] + for _ in range(n): + start = time.perf_counter() + module._native.run_research(arrow_bytes, strategy_json, None, None) + samples.append(time.perf_counter() - start) + return statistics.median(samples) + + if hasattr(rs_czsc_module, "_rs_czsc"): + # rs_czsc 的入口在 ``rs_czsc._rs_czsc`` 子模块上 + rs_native = rs_czsc_module._rs_czsc + else: + rs_native = rs_czsc_module._native + + samples_rs = [] + samples_czsc = [] + for _ in range(5): + start = time.perf_counter() + rs_native.run_research(arrow_bytes, strategy_json, None, None) + samples_rs.append(time.perf_counter() - start) + + start = time.perf_counter() + czsc_module._native.run_research(arrow_bytes, strategy_json, None, None) + samples_czsc.append(time.perf_counter() - start) + + rs_t = statistics.median(samples_rs) + czsc_t = statistics.median(samples_czsc) + ratio = czsc_t / rs_t if rs_t > 0 else float("inf") + with capsys.disabled(): + print( + f"\n[run_research(522 bars, 1 position)] " + f"rs_czsc={rs_t * 1000:.2f}ms czsc={czsc_t * 1000:.2f}ms " + f"ratio={ratio:.2f}x" + ) + + assert ratio <= PERF_RATIO_BUDGET, ( + f"czsc.run_research is {ratio:.2f}x slower than rs_czsc baseline" + ) diff --git a/test/parity/test_run_research.py b/test/parity/test_run_research.py new file mode 100644 index 000000000..1fd12054b --- /dev/null +++ b/test/parity/test_run_research.py @@ -0,0 +1,181 @@ +"""``run_research`` 全链路等价性测试。 + +这是 parity 套件中最强的一组等价性保证:要求 ``run_research`` 在两套 +实现间产出比特级一致的输出。 + +输入侧两个模块都拿到: + * 完全相同的 Arrow 编码 K 线 bars + * 完全相同的 JSON 策略(含 positions + signals_config) + +输出侧需要产出完全一致的 signals / pairs / holds DataFrame。这里的 +"完全一致"是在 pandas 层做比较 —— 因为 Arrow IPC 的底层字节可能在 +framing(footer、字典编码等)细节上有差别,但只要逻辑内容相等就视为 +等价。 +""" + +from __future__ import annotations + +import json + +import pandas as pd +import pytest + + +def _strategy_payload(czsc_module, position_dict): + """构造两套实现都能消费的运行时格式 strategy dict。 + + 使用 czsc 的 ``_compat`` 工具函数把 Position dict 转成标准的运行时 + 格式。rs_czsc 与 czsc 的 Rust 端都通过 serde_json 反序列化同一份 + JSON,因此只要 payload 校验通过即可。 + """ + from czsc._compat import position_dump_to_runtime, signal_config_to_runtime + + pos = czsc_module.Position.load(position_dict) + runtime_position = position_dump_to_runtime(pos.dump(with_data=False)) + signals_cfg = czsc_module.derive_signals_config(pos.unique_signals) + return { + "name": "ParityStrategy", + "symbol": position_dict["symbol"], + "base_freq": "日线", + "signals_module": "czsc.signals", + "signals_config": [signal_config_to_runtime(c) for c in signals_cfg], + "positions": [runtime_position], + "market": "默认", + "bg_max_count": 5000, + } + + +def _bars_arrow_bytes(czsc_module, mock_kline_df): + """生成两套实现共用的、dtype 干净的 Arrow 字节流。""" + from czsc._compat import bars_to_dataframe + from czsc._utils._df_convert import pandas_to_arrow_bytes + + df = bars_to_dataframe(mock_kline_df, symbol="000001") + return pandas_to_arrow_bytes(df) + + +def _decode_arrow(arrow_bytes: bytes) -> pd.DataFrame: + """把 Arrow 字节流解码回 pandas DataFrame。""" + from czsc._utils._df_convert import arrow_bytes_to_pd_df + + return arrow_bytes_to_pd_df(arrow_bytes) + + +def _normalise_for_compare(df: pd.DataFrame) -> pd.DataFrame: + """比较前做归一化:去掉不可哈希列、按时间排序。""" + drop = [c for c in ("cache",) if c in df.columns] + out = df.drop(columns=drop, errors="ignore").copy() + if "dt" in out.columns: + out = out.sort_values("dt").reset_index(drop=True) + return out + + +@pytest.fixture +def parity_inputs(rs_czsc_module, czsc_module, mock_kline_df, sample_position_dict): + """构造 parity 测试需要的 (arrow_bytes, strategy_json) 入参。 + + 返回元组: + * arrow_bytes: 两边都能解析的 K 线 Arrow 字节流 + * strategy_json: 两边都能消费的运行时策略 JSON 字符串 + """ + arrow_bytes = _bars_arrow_bytes(czsc_module, mock_kline_df) + strategy = _strategy_payload(czsc_module, sample_position_dict) + return arrow_bytes, json.dumps(strategy, ensure_ascii=False) + + +def test_run_research_signals_match(rs_czsc_module, czsc_module, parity_inputs): + """``run_research`` 的 signals_arrow 输出必须在两套实现间完全等价。 + + 关键断言: + * shape 一致 + * 列集合完全相同(无新增/缺失) + * 公共列的内容(按 dt 排序后)逐 cell 相等 + """ + arrow_bytes, strategy_json = parity_inputs + + rs_payload = rs_czsc_module._rs_czsc.run_research(arrow_bytes, strategy_json, None, None) + czsc_payload = czsc_module._native.run_research(arrow_bytes, strategy_json, None, None) + + rs_df = _normalise_for_compare(_decode_arrow(bytes(rs_payload["signals_arrow"]))) + czsc_df = _normalise_for_compare(_decode_arrow(bytes(czsc_payload["signals_arrow"]))) + + assert rs_df.shape == czsc_df.shape, f"shape mismatch: rs={rs_df.shape} czsc={czsc_df.shape}" + assert set(rs_df.columns) == set(czsc_df.columns), ( + f"signals columns differ.\n" + f"rs only: {set(rs_df.columns) - set(czsc_df.columns)}\n" + f"czsc only: {set(czsc_df.columns) - set(rs_df.columns)}" + ) + common_cols = sorted(rs_df.columns) + pd.testing.assert_frame_equal( + rs_df[common_cols], czsc_df[common_cols], check_dtype=False + ) + + +def test_run_research_pairs_match(rs_czsc_module, czsc_module, parity_inputs): + """``run_research`` 的 pairs_arrow(每笔交易明细)输出必须等价。 + + 关键断言:shape 一致;逐行内容(去掉行索引、允许列顺序不同)等价。 + """ + arrow_bytes, strategy_json = parity_inputs + + rs_payload = rs_czsc_module._rs_czsc.run_research(arrow_bytes, strategy_json, None, None) + czsc_payload = czsc_module._native.run_research(arrow_bytes, strategy_json, None, None) + + rs_df = _normalise_for_compare(_decode_arrow(bytes(rs_payload["pairs_arrow"]))) + czsc_df = _normalise_for_compare(_decode_arrow(bytes(czsc_payload["pairs_arrow"]))) + + assert rs_df.shape == czsc_df.shape + pd.testing.assert_frame_equal( + rs_df.reset_index(drop=True), + czsc_df.reset_index(drop=True), + check_dtype=False, + check_like=True, + ) + + +def test_run_research_holds_match(rs_czsc_module, czsc_module, parity_inputs): + """``run_research`` 的 holds_arrow(持仓时序)输出必须等价。 + + 关键断言:shape 一致;逐行内容(去掉行索引、允许列顺序不同)等价。 + """ + arrow_bytes, strategy_json = parity_inputs + + rs_payload = rs_czsc_module._rs_czsc.run_research(arrow_bytes, strategy_json, None, None) + czsc_payload = czsc_module._native.run_research(arrow_bytes, strategy_json, None, None) + + rs_df = _normalise_for_compare(_decode_arrow(bytes(rs_payload["holds_arrow"]))) + czsc_df = _normalise_for_compare(_decode_arrow(bytes(czsc_payload["holds_arrow"]))) + + assert rs_df.shape == czsc_df.shape + pd.testing.assert_frame_equal( + rs_df.reset_index(drop=True), + czsc_df.reset_index(drop=True), + check_dtype=False, + check_like=True, + ) + + +def test_run_research_meta_match(rs_czsc_module, czsc_module, parity_inputs): + """``run_research`` 的 meta payload 必须在两套实现间等价。 + + meta 中包含 symbol / counts / version 等元信息。比较前会先剔除会因 + 构建环境不同而合理变化的字段(构建时间戳、git hash、引擎版本号), + 剩余字段必须完全相等。 + + 关键断言:剔除 ``drop_keys`` 后,rs_meta 与 czsc_meta 完全相等。 + """ + arrow_bytes, strategy_json = parity_inputs + + rs_payload = rs_czsc_module._rs_czsc.run_research(arrow_bytes, strategy_json, None, None) + czsc_payload = czsc_module._native.run_research(arrow_bytes, strategy_json, None, None) + + # 这些字段在不同构建间合理变化,比较时需要排除 + drop_keys = {"build_ts", "git_hash", "engine_version"} + rs_meta = {k: v for k, v in rs_payload["meta"].items() if k not in drop_keys} + czsc_meta = {k: v for k, v in czsc_payload["meta"].items() if k not in drop_keys} + + assert rs_meta == czsc_meta, ( + f"meta diverges.\nrs only: {set(rs_meta) - set(czsc_meta)}\n" + f"czsc only: {set(czsc_meta) - set(rs_meta)}\n" + f"value diffs: {[(k, rs_meta.get(k), czsc_meta.get(k)) for k in rs_meta if rs_meta.get(k) != czsc_meta.get(k)]}" + ) diff --git a/test/parity/test_signals_registry.py b/test/parity/test_signals_registry.py new file mode 100644 index 000000000..2aa8ca90f --- /dev/null +++ b/test/parity/test_signals_registry.py @@ -0,0 +1,93 @@ +"""信号注册表 + ``derive_signals_*`` 系列工具的等价性测试。 + +本套件验证以下三件事在 ``rs_czsc`` 与 ``czsc`` 之间完全一致: + + * ``list_all_signals()`` 返回的描述符集合(数量、name、param_template、 + category 都必须相同)。 + * ``derive_signals_config(unique_signals)`` 根据信号字符串反推出的 + 运行时配置 dict(结构等价,允许顺序差异)。 + * ``derive_signals_freqs(configs)`` 根据运行时配置推导出的 freq + 列表(排序后必须相等)。 + +这些工具直接决定了下游 ``run_research`` / ``OpensOptimize`` 等流程的 +输入 shape,一旦不一致,会导致不可解释的下游差异。 +""" + +from __future__ import annotations + + +def test_list_all_signals_count_matches(rs_czsc_module, czsc_module): + """注册表中信号数量必须一致。""" + rs_list = rs_czsc_module.list_all_signals() + czsc_list = czsc_module._native.list_all_signals() + assert len(rs_list) == len(czsc_list), ( + f"signal count mismatch: rs_czsc={len(rs_list)} vs czsc={len(czsc_list)}" + ) + + +def test_list_all_signals_names_match(rs_czsc_module, czsc_module): + """注册表中信号 name 集合必须严格一致。""" + rs_names = sorted(d["name"] for d in rs_czsc_module.list_all_signals()) + czsc_names = sorted(d["name"] for d in czsc_module._native.list_all_signals()) + assert rs_names == czsc_names, ( + f"signal name set differs.\n" + f"only in rs_czsc: {set(rs_names) - set(czsc_names)}\n" + f"only in czsc: {set(czsc_names) - set(rs_names)}" + ) + + +def test_list_all_signals_templates_match(rs_czsc_module, czsc_module): + """每个信号的 param_template 字符串必须一致。 + + 模板字符串是渲染信号字符串的源头,任何差异都会被下游放大。 + """ + rs_map = {d["name"]: d.get("param_template") for d in rs_czsc_module.list_all_signals()} + czsc_map = {d["name"]: d.get("param_template") for d in czsc_module._native.list_all_signals()} + diffs = { + name: (rs_map[name], czsc_map[name]) + for name in rs_map + if rs_map[name] != czsc_map[name] + } + assert not diffs, f"param_template mismatches: {diffs}" + + +def test_list_all_signals_categories_match(rs_czsc_module, czsc_module): + """每个信号的 category 必须一致(如 ``kline`` / ``trader`` 等)。""" + rs_map = {d["name"]: d.get("category") for d in rs_czsc_module.list_all_signals()} + czsc_map = {d["name"]: d.get("category") for d in czsc_module._native.list_all_signals()} + assert rs_map == czsc_map + + +def test_derive_signals_config_matches(rs_czsc_module, czsc_module, sample_signal_strings): + """``derive_signals_config`` 在两套实现间产出的配置必须等价。 + + 运行时配置以 list[dict] 形式输出,等价性按 (name, freq, sorted params) + 三元组的集合相等来判定 —— 顺序在两套实现中可能不同,但内容必须相同。 + """ + rs_cfg = rs_czsc_module._derive_signals_config_impl(sample_signal_strings) + czsc_cfg = czsc_module.derive_signals_config(sample_signal_strings) + + def _canon(cfgs): + return sorted( + ( + cfg["name"], + cfg.get("freq"), + tuple(sorted((cfg.get("params") or {}).items())), + ) + for cfg in cfgs + ) + + assert _canon(rs_cfg) == _canon(czsc_cfg) + + +def test_derive_signals_freqs_matches(rs_czsc_module, czsc_module, sample_signal_strings): + """``derive_signals_freqs`` 输出的 freq 列表必须等价。 + + 通过 rs_czsc 先把信号字符串转成运行时配置,再让两套实现都推导 + freq 列表,保证它们的输入 shape 完全一致,差异只可能来自 freq + 推导逻辑本身。 + """ + runtime = rs_czsc_module._derive_signals_config_impl(sample_signal_strings) + rs_freqs = sorted(rs_czsc_module._derive_signals_freqs_impl(runtime)) + czsc_freqs = sorted(czsc_module._native.derive_signals_freqs(runtime)) + assert rs_freqs == czsc_freqs diff --git a/test/smoke/__init__.py b/test/smoke/__init__.py new file mode 100644 index 000000000..ca8cf1081 --- /dev/null +++ b/test/smoke/__init__.py @@ -0,0 +1,11 @@ +"""烟雾测试包。 + +本目录用于存放烟雾测试用例(smoke tests),用于在最小可运行范围内验证 +核心功能是否正常工作。烟雾测试通常运行速度快、覆盖关键路径,主要用途包括: + +- 安装后基本可导入性验证(import 成功、关键类可访问) +- 构建产物(wheel、原生扩展)的最小完整性检查 +- 部署后的健康检查 + +测试文件命名约定:``test_*.py``。 +""" diff --git a/test/smoke/test_install.py b/test/smoke/test_install.py new file mode 100644 index 000000000..8acd87ff8 --- /dev/null +++ b/test/smoke/test_install.py @@ -0,0 +1,144 @@ +"""安装产物烟雾测试:maturin wheel 安装与原生扩展可用性验证。 + +本测试套件用于在最小可运行范围内验证 czsc 的 Rust/Python 混合架构构建产物 +(wheel)是否被正确打包,并且能够在干净的虚拟环境中安装并通过基本的导入 +冒烟测试。 + +业务背景: + 项目通过 ``maturin build --release`` 构建跨平台二进制 wheel + (manylinux/macOS/Windows),其中包含 Rust 编译产物 ``czsc._native`` + 动态库。安装该 wheel 后,必须满足: + + 1. ``import czsc`` 成功; + 2. ``import czsc._native`` 成功且 ``__file__`` 指向编译产物(.so/.pyd/.dylib); + 3. 不再需要任何额外的 ``rs_czsc`` 包。 + +测试覆盖: + - ``pyproject.toml`` 已使用 ``maturin`` 作为构建后端; + - 当前安装环境中可成功导入 ``czsc._native``; + - ``dist/`` 下的 wheel 在干净 venv 中可安装且 ``czsc.CZSC.__module__`` 指向 czsc 命名空间。 +""" + +from __future__ import annotations + +import subprocess +import sys +from pathlib import Path + +import pytest + + +# 仓库根目录(基于本测试文件位置反推),以及构建产物所在目录 +REPO_ROOT = Path(__file__).resolve().parents[2] +DIST_DIR = REPO_ROOT / "dist" + + +def test_pyproject_uses_maturin_backend() -> None: + """验证 pyproject.toml 使用 maturin 作为构建后端。 + + 测试目标: + 确保项目构建系统已切换为 maturin(PyO3 推荐的 Python 扩展打包工具), + 以便能够产出包含 Rust 原生扩展的 wheel。 + + 关键断言: + pyproject.toml 文件中包含 ``maturin`` 字符串,意味着 ``[build-system]`` + 节点声明了 ``requires=['maturin>=...']`` 与 ``build-backend='maturin'``。 + """ + pyproject = REPO_ROOT / "pyproject.toml" + if not pyproject.is_file(): + pytest.fail(f"未找到 pyproject.toml,路径:{pyproject}") + text = pyproject.read_text(encoding="utf-8") + assert "maturin" in text, ( + "pyproject.toml 必须声明 [build-system] requires=['maturin>=...'] " + "以及 build-backend='maturin'" + ) + + +def test_native_extension_present_in_install() -> None: + """验证当前安装环境中已包含编译后的 czsc._native 扩展。 + + 测试场景: + 通过子进程方式调用 Python 解释器尝试 ``import czsc._native``, + 并打印其 ``__file__`` 属性。 + + 关键断言: + - 子进程退出码必须为 0(导入成功); + - ``czsc._native.__file__`` 必须以 ``.so`` / ``.pyd`` / ``.dylib`` 结尾, + 表明它是编译生成的二进制扩展,而不是普通 Python 模块。 + """ + proc = subprocess.run( + [sys.executable, "-c", "import czsc._native; print(czsc._native.__file__)"], + capture_output=True, + text=True, + ) + assert proc.returncode == 0, ( + "在当前安装环境中 `import czsc._native` 必须成功 " + f"(stderr: {proc.stderr.strip()!r});" + "构建流程必须保证 maturin 正确打包了 Rust 扩展" + ) + out = proc.stdout.strip() + assert out.endswith((".so", ".pyd", ".dylib")), ( + f"czsc._native.__file__ 应指向编译扩展,实际为 {out!r}" + ) + + +def test_wheel_install_in_clean_venv(tmp_path: Path) -> None: + """在干净 venv 中安装 wheel 并执行最小冒烟导入。 + + 测试场景: + 1. 在 ``dist/`` 下查找最新构建的 ``czsc-*.whl`` 文件; + 2. 在 pytest 提供的临时目录下创建一个全新的虚拟环境; + 3. 使用 ``pip install --find-links`` 将 wheel 安装进该 venv; + 4. 在该 venv 中执行 ``import czsc; print(czsc.CZSC.__module__)``。 + + 设计要点: + - 使用 ``--find-links DIST_DIR`` 让 pip 能够找到 dist/ 下的同伴 wheel + (例如尚未发布到 PyPI 的预发布版 wbt wheel),不可用时再回退到 PyPI; + - 打印的是 ``czsc.CZSC.__module__``(类的模块名),而不是 + ``type(czsc.CZSC).__module__``(其元类的模块名)。后者对于 PyO3 类 + 会返回 ``"builtins"``,无法验证迁移目标是否达成。 + + 关键断言: + - ``pip install`` 退出码为 0; + - 子进程导入并打印的模块名包含 ``"czsc"``,证明 ``CZSC`` 类来自 + ``czsc._native`` 命名空间。 + """ + # 选取 dist/ 下的所有 czsc 安装包,并以排序方式取最新版本 + wheels = sorted(DIST_DIR.glob("czsc-*.whl")) if DIST_DIR.is_dir() else [] + if not wheels: + pytest.fail( + f"在 {DIST_DIR} 下找不到 wheel;" + "请先运行 `maturin build --release` 再执行本冒烟测试" + ) + + # 构建一个全新的虚拟环境用于隔离安装 + venv = tmp_path / "venv" + subprocess.run([sys.executable, "-m", "venv", str(venv)], check=True) + pip = venv / "bin" / "pip" + py = venv / "bin" / "python" + + # `--find-links DIST_DIR` 让 pip 解析 dist/ 下的同伴 wheel + # (例如尚未发布到 PyPI 的预发布 wbt-0.1.7 wheel,提前包含了 + # czsc 在模块加载时引用的 top_drawdowns 等名称)。 + # 缺少该参数时安装会从 PyPI 拉取 wbt 0.1.6,导致 czsc 启动失败。 + install = subprocess.run( + [str(pip), "install", "--find-links", str(DIST_DIR), str(wheels[-1])], + capture_output=True, + text=True, + ) + if install.returncode != 0: + pytest.fail( + f"pip install {wheels[-1].name} 失败: {install.stderr}" + ) + + # 在新 venv 中打印 czsc.CZSC.__module__, + # 该值由 PyO3 的 #[pyclass(module=...)] 设置为 "czsc._native" + smoke = subprocess.run( + [str(py), "-c", "import czsc; print(czsc.CZSC.__module__)"], + capture_output=True, + text=True, + ) + if smoke.returncode != 0: + pytest.fail(f"冒烟 `import czsc` 失败: {smoke.stderr}") + out = smoke.stdout.strip() + assert "czsc" in out, f"冒烟测试输出异常:{out!r}" diff --git a/test/test_analyze.py b/test/test_analyze.py deleted file mode 100644 index 51e360f5f..000000000 --- a/test/test_analyze.py +++ /dev/null @@ -1,90 +0,0 @@ -""" -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2022/2/16 20:31 -describe: czsc.analyze 单元测试 - -Mock数据格式说明: -- 数据来源: czsc.mock.generate_symbol_kines -- 数据列: dt, symbol, open, close, high, low, vol, amount -- 时间范围: 20220101-20250101(3年数据,满足3年+要求) -- 频率: 1分钟、5分钟、日线 -- Seed: 42(确保可重现) -""" - -from czsc import mock -from czsc.core import CZSC, Direction, Freq, format_standard_kline - - -def get_mock_bars(freq=Freq.D, symbol="000001", n_days=100): - """获取mock K线数据并转换为RawBar对象 - - Args: - freq: K线频率 - symbol: 品种代码 - n_days: 天数(仅用于非标准频率) - - Returns: - list: RawBar对象列表 - """ - if freq == Freq.F1: - df = mock.generate_symbol_kines(symbol, "1分钟", sdt="20220101", edt="20250101", seed=42) - - elif freq == Freq.F5: - df = mock.generate_symbol_kines(symbol, "5分钟", sdt="20220101", edt="20250101", seed=42) - elif freq == Freq.D: - df = mock.generate_symbol_kines(symbol, "日线", sdt="20220101", edt="20250101", seed=42) - else: - df = mock.generate_klines(seed=42) - df = df[df["symbol"] == symbol].head(n_days) if symbol in df["symbol"].values else df.head(n_days) - - # bars = [] - # for i, row in df.iterrows(): - # bar = RawBar( - # symbol=row['symbol'], - # id=i, - # freq=freq, - # open=row['open'], - # dt=row['dt'], - # close=row['close'], - # high=row['high'], - # low=row['low'], - # vol=row['vol'], - # amount=row['amount'] - # ) - # bars.append(bar) - bars = format_standard_kline(df, freq=freq) - return bars - - -def test_czsc_basic(): - """测试CZSC基础功能""" - bars = get_mock_bars(freq=Freq.D, symbol="000001", n_days=200) - c = CZSC(bars) - - assert c.symbol == "000001", "symbol应该正确设置" - assert c.freq == Freq.D, "频率应该正确设置" - assert len(c.bars_raw) > 0, "原始K线数据不应为空" - assert len(c.bars_ubi) > 0, "去除包含关系后的K线数据不应为空" - assert len(c.bi_list) > 0, "笔的列表不应为空" - - -def test_czsc_signals(): - """测试CZSC信号计算 - 无信号函数时signals为None或空字典""" - bars = get_mock_bars(freq=Freq.D, symbol="000001", n_days=200) - c = CZSC(bars) - - # 没有提供 get_signals 函数时,signals 为 None(Rust)或空字典(Python) - assert c.signals is None or isinstance(c.signals, dict), "signals应该是None或字典类型" - - -def test_czsc_ubi_properties(): - """测试CZSC的ubi属性""" - bars = get_mock_bars(freq=Freq.D, symbol="000001", n_days=200) - c = CZSC(bars) - - ubi = c.ubi - assert "direction" in ubi, "ubi应该包含direction字段" - assert "high_bar" in ubi, "ubi应该包含high_bar字段" - assert "low_bar" in ubi, "ubi应该包含low_bar字段" - assert isinstance(ubi["direction"], Direction), "direction应该是Direction类型" diff --git a/test/test_analyze_boundary.py b/test/test_analyze_boundary.py deleted file mode 100644 index bfbc06c18..000000000 --- a/test/test_analyze_boundary.py +++ /dev/null @@ -1,129 +0,0 @@ -""" -test_analyze_boundary.py - CZSC 分析核心模块边界情况测试 - -Mock数据格式说明: -- 数据来源: czsc.mock.generate_symbol_kines -- 数据列: dt, symbol, open, close, high, low, vol, amount -- 时间范围: 20220101-20250101(3年数据) -- 频率: 日线、5分钟线 -- Seed: 42(确保可重现) - -测试覆盖: -- 少量K线数据(不足以形成笔) -- 多频率分析一致性 -- 增量更新正确性 -- format_standard_kline 边界情况 -""" - -from czsc import mock -from czsc.core import CZSC, Freq, format_standard_kline - - -def get_daily_bars(symbol="000001", sdt="20220101", edt="20250101"): - """获取日线 mock 数据""" - df = mock.generate_symbol_kines(symbol, "日线", sdt=sdt, edt=edt, seed=42) - return format_standard_kline(df, freq=Freq.D) - - -class TestCZSCBoundary: - """CZSC 分析边界情况测试""" - - def test_minimal_bars(self): - """测试最少K线(不足以形成笔)""" - bars = get_daily_bars()[:5] - c = CZSC(bars) - assert len(c.bars_raw) == 5 - assert len(c.bi_list) == 0, "5根K线不应形成笔" - - def test_moderate_bars(self): - """测试中等数量K线""" - bars = get_daily_bars()[:50] - c = CZSC(bars) - # CZSC 可能因 max_bi_num 裁剪旧K线,因此使用 <= 判断 - assert len(c.bars_raw) <= 50 - assert len(c.bars_raw) > 0 - assert len(c.bars_ubi) > 0 - - def test_large_dataset(self): - """测试完整3年数据""" - bars = get_daily_bars() - c = CZSC(bars) - assert len(c.bars_raw) > 500 - assert len(c.bi_list) > 10, "3年数据应形成多笔" - assert len(c.fx_list) > 0, "应有分型" - - def test_incremental_update(self): - """测试增量更新""" - bars = get_daily_bars() - # 先用前100根K线初始化 - c = CZSC(bars[:100]) - initial_bi_count = len(c.bi_list) - initial_bars_count = len(c.bars_raw) - - # 逐根增加K线 - for bar in bars[100:200]: - c.update(bar) - - # CZSC 可能因 max_bi_num 裁剪旧K线,bars_raw 应增长但不一定恰好等于200 - assert len(c.bars_raw) > initial_bars_count, "增量更新后K线数应增加" - assert len(c.bi_list) >= initial_bi_count, "增量更新后笔数不应减少" - - def test_max_bi_num(self): - """测试 max_bi_num 限制""" - bars = get_daily_bars() - c = CZSC(bars, max_bi_num=10) - assert len(c.bi_list) <= 10, "笔数量应被 max_bi_num 限制" - - def test_different_symbols(self): - """测试不同品种""" - for symbol in ["000001", "000002", "600001"]: - df = mock.generate_symbol_kines(symbol, "日线", sdt="20220101", edt="20250101", seed=42) - bars = format_standard_kline(df, freq=Freq.D) - c = CZSC(bars) - assert c.symbol == symbol - - def test_ubi_structure(self): - """测试 ubi 结构完整性""" - bars = get_daily_bars() - c = CZSC(bars) - ubi = c.ubi - assert "direction" in ubi - assert "high_bar" in ubi - assert "low_bar" in ubi - - -class TestFormatStandardKline: - """format_standard_kline 函数测试""" - - def test_basic_conversion(self): - """测试基本转换""" - df = mock.generate_symbol_kines("000001", "日线", sdt="20220101", edt="20250101", seed=42) - bars = format_standard_kline(df, freq=Freq.D) - assert len(bars) > 0 - assert bars[0].freq == Freq.D - assert bars[0].symbol == "000001" - - def test_preserves_data(self): - """测试数据保留完整性""" - df = mock.generate_symbol_kines("000001", "日线", sdt="20220101", edt="20250101", seed=42) - bars = format_standard_kline(df, freq=Freq.D) - # 验证第一根K线数据 - first_row = df.iloc[0] - assert bars[0].open == first_row["open"] - assert bars[0].close == first_row["close"] - assert bars[0].high == first_row["high"] - assert bars[0].low == first_row["low"] - - def test_5min_frequency(self): - """测试5分钟频率""" - df = mock.generate_symbol_kines("000001", "5分钟", sdt="20220101", edt="20220201", seed=42) - bars = format_standard_kline(df, freq=Freq.F5) - assert len(bars) > 0 - assert bars[0].freq == Freq.F5 - - def test_sorted_by_time(self): - """测试按时间排序""" - df = mock.generate_symbol_kines("000001", "日线", sdt="20220101", edt="20250101", seed=42) - bars = format_standard_kline(df, freq=Freq.D) - for i in range(1, len(bars)): - assert bars[i].dt >= bars[i - 1].dt, "K线应按时间升序排列" diff --git a/test/test_api_surface.py b/test/test_api_surface.py deleted file mode 100644 index 2a22c099f..000000000 --- a/test/test_api_surface.py +++ /dev/null @@ -1,69 +0,0 @@ -from __future__ import annotations - -import ast -from pathlib import Path - -import czsc - - -ROOT_DIR = Path(__file__).resolve().parents[1] - - -def _read_dunder_all(path: Path) -> list[str]: - module = ast.parse(path.read_text(encoding="utf-8")) - for node in module.body: - if isinstance(node, ast.Assign): - for target in node.targets: - if isinstance(target, ast.Name) and target.id == "__all__": - return [elt.value for elt in node.value.elts if isinstance(elt, ast.Constant) and isinstance(elt.value, str)] - raise AssertionError(f"__all__ not found in {path}") - - -def test_root_api_surface_retains_supported_shortcuts(): - expected = { - "CZSC", - "Freq", - "RawBar", - "CzscTrader", - "SignalsParser", - "DataClient", - "DiskCache", - "mock", - "svc", - "CzscStrategyBase", - "KlineChart", - } - missing = sorted(name for name in expected if not hasattr(czsc, name)) - assert not missing - - -def test_root_api_surface_drops_removed_legacy_exports(): - removed = { - "CTAResearch", - "DummyBacktest", - "OpensOptimize", - "ExitsOptimize", - "PairsPerformance", - "sensors", - "rwc", - } - leaked = sorted(name for name in removed if hasattr(czsc, name)) - assert not leaked - - -def test_runtime_and_stub_dunder_all_are_aligned(): - runtime_all = sorted(czsc.__all__) - stub_all = sorted(_read_dunder_all(ROOT_DIR / "czsc" / "__init__.pyi")) - assert stub_all == runtime_all - - -def test_examples_do_not_reference_removed_workflows(): - removed_markers = ["CTAResearch", "OpensOptimize", "ExitsOptimize"] - offenders: list[str] = [] - - for path in (ROOT_DIR / "examples").glob("*.py"): - text = path.read_text(encoding="utf-8") - if any(marker in text for marker in removed_markers): - offenders.append(path.name) - - assert offenders == [] diff --git a/test/test_eda.py b/test/test_eda.py deleted file mode 100644 index 78def3f64..000000000 --- a/test/test_eda.py +++ /dev/null @@ -1,406 +0,0 @@ -import pandas as pd -import pytest - -from czsc.eda import weights_simple_ensemble - - -def test_cal_yearly_days(): - if pd.__version__ < "2.1.0": - pytest.skip("skip this test if pandas version is less than 1.3.0") - - from czsc.eda import cal_yearly_days - - # Test with a list of dates within a single year - dts = ["2023-01-01", "2023-01-02", "2023-01-03", "2023-12-31"] - assert cal_yearly_days(dts) == 252 - - # Test with a list of dates spanning more than one year - dts = ["2022-01-01", "2022-12-31", "2023-01-01", "2023-12-31"] - assert cal_yearly_days(dts) == 2 - - # Test with a list of dates with minute precision - dts = [ - "2023-01-01 12:00", - "2023-01-02 13:00", - "2023-01-01 14:00", - "2023-02-01 15:00", - "2023-03-01 16:00", - "2023-03-01 17:00", - ] - assert cal_yearly_days(dts) == 252 - - # Test with an empty list - with pytest.raises(AssertionError): - cal_yearly_days([]) - - # Test with a list of dates with duplicates - dts = ["2023-01-01", "2023-01-01", "2023-01-02", "2023-01-02"] - assert cal_yearly_days(dts) == 252 - - -def test_weights_simple_ensemble_mean(): - df = pd.DataFrame({"strategy1": [0.1, 0.2, 0.3], "strategy2": [0.2, 0.3, 0.4], "strategy3": [0.3, 0.4, 0.5]}) - weight_cols = ["strategy1", "strategy2", "strategy3"] - result = weights_simple_ensemble(df, weight_cols, method="mean") - expected = pd.Series([0.2, 0.3, 0.4], name="weight") - pd.testing.assert_series_equal(result["weight"], expected) - - -def test_weights_simple_ensemble_vote(): - df = pd.DataFrame({"strategy1": [1, -1, 1], "strategy2": [-1, 1, -1], "strategy3": [1, 1, -1]}) - weight_cols = ["strategy1", "strategy2", "strategy3"] - result = weights_simple_ensemble(df, weight_cols, method="vote") - expected = pd.Series([1, 1, -1], name="weight") - pd.testing.assert_series_equal(result["weight"], expected) - - -def test_weights_simple_ensemble_sum_clip(): - df = pd.DataFrame({"strategy1": [0.5, -0.5, 0.5], "strategy2": [0.5, 0.5, -0.5], "strategy3": [0.5, 0.5, 0.5]}) - weight_cols = ["strategy1", "strategy2", "strategy3"] - result = weights_simple_ensemble(df, weight_cols, method="sum_clip", clip_min=-1, clip_max=1) - expected = pd.Series([1, 0.5, 0.5], name="weight") - pd.testing.assert_series_equal(result["weight"], expected) - - -def test_weights_simple_ensemble_only_long(): - df = pd.DataFrame({"strategy1": [0.5, -0.5, 0.5], "strategy2": [0.5, 0.5, -0.5], "strategy3": [0.5, 0.5, 0.5]}) - weight_cols = ["strategy1", "strategy2", "strategy3"] - result = weights_simple_ensemble(df, weight_cols, method="sum_clip", clip_min=-1, clip_max=1, only_long=True) - expected = pd.Series([1, 0.5, 0.5], name="weight") - pd.testing.assert_series_equal(result["weight"], expected) - - -def test_limit_leverage(): - from czsc.eda import limit_leverage - - data = { - "dt": pd.date_range(start="2023-01-01", periods=10, freq="D"), - "symbol": ["TEST"] * 10, - "weight": [0.1, 0.2, -0.3, 3, -0.5, 0.6, -0.7, 0.8, -0.9, 1.0], - "price": [100 + i for i in range(10)], - } - df = pd.DataFrame(data) - - # Test with leverage = 1.0 - df_result = limit_leverage(df, leverage=1.0, copy=True, window=3, min_periods=2) - assert df_result["weight"].max() <= 1.0 - assert df_result["weight"].min() >= -1.0 - - # Test with leverage = 2.0 - df_result = limit_leverage(df, leverage=2.0, copy=True, window=3, min_periods=2) - assert df_result["weight"].max() <= 2.0 - assert df_result["weight"].min() >= -2.0 - - # Test with different window and min_periods - df_result = limit_leverage(df, leverage=1.0, window=5, min_periods=2, copy=True) - assert df_result["weight"].max() <= 1.0 - assert df_result["weight"].min() >= -1.0 - - df1 = df.copy() - df1.rename(columns={"weight": "weight1"}, inplace=True) - # Test with leverage = 1.0 - df_result = limit_leverage(df1, leverage=1.0, copy=True, window=3, min_periods=2, weight="weight1") - assert df_result["weight1"].max() <= 1.0 - assert df_result["weight1"].min() >= -1.0 - - -def test_turnover_rate_normal(): - """测试正常数据的换手率计算""" - from czsc.eda import turnover_rate - - # 创建测试数据 - dates = pd.date_range(start="2024-01-01", periods=3, freq="D") - symbols = ["A", "B", "C"] - - # 创建权重数据 - weights = [ - [1.0, 0.5, 0.0], # 第一天 - [0.5, 1.0, 0.5], # 第二天 - [0.0, 0.5, 1.0], # 第三天 - ] - - # 构建DataFrame - data = [] - for i, date in enumerate(dates): - for j, symbol in enumerate(symbols): - data.append({"dt": date, "symbol": symbol, "weight": weights[i][j]}) - df = pd.DataFrame(data) - - # 计算换手率 - result = turnover_rate(df) - - # 验证结果 - assert isinstance(result, dict) - assert "单边换手率" in result - assert "日均换手率" in result - assert "最大单日换手率" in result - assert "最小单日换手率" in result - assert "日换手详情" in result - - # 验证换手率计算是否正确 - # 第一天:1.0 + 0.5 + 0.0 = 1.5 - # 第二天:|0.5-1.0| + |1.0-0.5| + |0.5-0.0| = 1.5 - # 第三天:|0.0-0.5| + |0.5-1.0| + |1.0-0.5| = 1.5 - assert result["单边换手率"] == 4.5 # 1.5 + 1.5 + 1.5 - assert result["日均换手率"] == 1.5 # 4.5 / 3 - assert result["最大单日换手率"] == 1.5 - assert result["最小单日换手率"] == 1.5 - print(result["日换手详情"]) - - -def test_turnover_rate_verbose(): - """测试verbose模式下的日志输出""" - from czsc.eda import turnover_rate - - dates = pd.date_range(start="2024-01-01", periods=2, freq="D") - data = [ - {"dt": dates[0], "symbol": "A", "weight": 1.0}, - {"dt": dates[0], "symbol": "B", "weight": 0.0}, - {"dt": dates[1], "symbol": "A", "weight": 0.0}, - {"dt": dates[1], "symbol": "B", "weight": 1.0}, - ] - df = pd.DataFrame(data) - - # 这里我们只验证verbose=True时不会抛出异常 - result = turnover_rate(df, verbose=True) - assert isinstance(result, dict) - - -def test_turnover_rate_invalid_data(): - """测试无效数据的处理""" - from czsc.eda import turnover_rate - - # 测试缺少必要列的数据 - df = pd.DataFrame({"dt": ["2024-01-01"], "symbol": ["A"]}) - with pytest.raises(KeyError): - turnover_rate(df) - - # 测试权重列包含非数值数据 - df = pd.DataFrame({"dt": ["2024-01-01"], "symbol": ["A"], "weight": ["invalid"]}) - with pytest.raises(TypeError, match="weight 列必须包含数值数据"): - turnover_rate(df) - - -def test_cross_sectional_strategy(): - """测试横截面策略功能和mock数据质量""" - import pandas as pd - - from czsc import mock - from czsc.eda import cross_sectional_strategy - - def __execute_one(): - """执行单次横截面策略测试""" - df = mock.generate_cs_factor() - df = cross_sectional_strategy(df, factor="F#RPS#20", long=0.3, short=0.3, norm=True, window=1, verbose=False) - dfw = pd.pivot_table(df, index="dt", columns="symbol", values="weight") - dfw["sum"] = dfw.sum(axis=1) - assert dfw["sum"].max() < 0.01 and dfw["sum"].min() > -0.01 - return df, dfw - - # 执行基本测试 - df, dfw = __execute_one() - - # 验证mock数据质量 - assert isinstance(df, pd.DataFrame), "generate_cs_factor应该返回DataFrame" - assert len(df) > 0, "数据不应为空" - assert "F#RPS#20" in df.columns, "应包含F#RPS#20因子列" - assert "symbol" in df.columns, "应包含symbol列" - assert "dt" in df.columns, "应包含dt列" - - # 验证因子数据的合理性 - factor_values = df["F#RPS#20"] - assert factor_values.min() >= 0, "RPS因子值应该>=0" - assert factor_values.max() <= 1, "RPS因子值应该<=1" - assert not factor_values.isnull().all(), "因子值不应全为空" - - # 验证策略权重的合理性 - weights = df["weight"] - assert weights.abs().max() <= 1, "权重绝对值不应超过1" - - # 测试不同参数组合 - def test_different_params(): - """测试不同参数组合""" - df_base = mock.generate_cs_factor() - - # 测试只做多 - df_long = cross_sectional_strategy( - df_base.copy(), factor="F#RPS#20", long=0.2, short=0.0, norm=True, window=1, verbose=False - ) - assert (df_long["weight"] >= 0).all(), "只做多时权重应该>=0" - - # 测试只做空 - df_short = cross_sectional_strategy( - df_base.copy(), factor="F#RPS#20", long=0.0, short=0.2, norm=True, window=1, verbose=False - ) - assert (df_short["weight"] <= 0).all(), "只做空时权重应该<=0" - - # 测试不归一化 - df_no_norm = cross_sectional_strategy( - df_base, factor="F#RPS#20", long=0.3, short=0.3, norm=False, window=1, verbose=False - ) - assert isinstance(df_no_norm, pd.DataFrame), "不归一化时应该返回DataFrame" - - # 测试多空 + window 平滑 - df_long_smooth = cross_sectional_strategy( - df_base.copy(), factor="F#RPS#20", long=0.3, short=0.3, norm=True, window=20, verbose=False - ) - assert isinstance(df_long_smooth, pd.DataFrame), "多空 + window 平滑时应该返回DataFrame" - dfw_smooth = pd.pivot_table(df_long_smooth, index="dt", columns="symbol", values="weight") - dfw_smooth["sum"] = dfw_smooth.sum(axis=1) - assert dfw_smooth["sum"].max() < 0.01 and dfw_smooth["sum"].min() > -0.01 - print(f"多空+20日平滑:{dfw_smooth.tail()}") - - return True - - # 执行参数测试 - assert test_different_params(), "不同参数组合测试失败" - - # 验证时间序列的连续性 - dates = sorted(df["dt"].unique()) - assert len(dates) > 100, "应该有足够的时间序列数据" - - # 验证不同股票的数据完整性 - symbols = df["symbol"].unique() - assert len(symbols) > 10, "应该有足够的股票数量" - - for symbol in symbols[:3]: # 检查前3个股票 - symbol_data = df[df["symbol"] == symbol] - assert len(symbol_data) == len(dates), f"股票{symbol}的数据点数量应该与日期数量一致" - - print(f"横截面策略测试通过: 共{len(dates)}个交易日,{len(symbols)}只股票") - print(f"最新日期权重分布:\n{dfw[dfw.index == dfw.index.max()]}") - - -def test_mock_klines_data_quality(): - """测试优化后的mock klines数据质量""" - from czsc import mock - - # 生成数据 - df = mock.generate_klines(seed=42) - - # 基本验证 - assert isinstance(df, pd.DataFrame), "应该返回DataFrame" - assert len(df) > 0, "数据不应为空" - - # 验证必要列存在 - required_cols = ["dt", "symbol", "open", "close", "high", "low", "vol", "amount"] - for col in required_cols: - assert col in df.columns, f"缺少必要列: {col}" - - # 验证数据类型 - assert pd.api.types.is_datetime64_any_dtype(df["dt"]), "dt列应该是日期类型" - assert pd.api.types.is_numeric_dtype(df["open"]), "open列应该是数值类型" - assert pd.api.types.is_numeric_dtype(df["close"]), "close列应该是数值类型" - assert pd.api.types.is_numeric_dtype(df["high"]), "high列应该是数值类型" - assert pd.api.types.is_numeric_dtype(df["low"]), "low列应该是数值类型" - assert pd.api.types.is_numeric_dtype(df["vol"]), "vol列应该是数值类型" - assert pd.api.types.is_numeric_dtype(df["amount"]), "amount列应该是数值类型" - - # 验证价格关系的正确性 - price_check = (df["high"] >= df[["open", "close"]].max(axis=1)) & (df["low"] <= df[["open", "close"]].min(axis=1)) - assert price_check.all(), "所有K线的价格关系都应该正确(high>=max(open,close), low<=min(open,close))" - - # 验证价格为正数 - assert (df["open"] > 0).all(), "开盘价应该为正数" - assert (df["close"] > 0).all(), "收盘价应该为正数" - assert (df["high"] > 0).all(), "最高价应该为正数" - assert (df["low"] > 0).all(), "最低价应该为正数" - assert (df["vol"] > 0).all(), "成交量应该为正数" - assert (df["amount"] > 0).all(), "成交金额应该为正数" - - # 验证不同股票的价格走势差异 - symbols = df["symbol"].unique() - df.groupby("symbol")["close"].last() - price_changes = df.groupby("symbol")["close"].apply(lambda x: (x.iloc[-1] / x.iloc[0] - 1) * 100) - - # 应该有涨有跌,不应该所有股票都是同样的走势 - assert price_changes.std() > 10, "不同股票的涨跌幅应该有足够的差异性" - assert price_changes.max() > 0, "应该有股票上涨" - assert price_changes.min() < 0, "应该有股票下跌" - - # 验证成交量与价格波动的相关性 - df["price_volatility"] = abs(df["close"] - df["open"]) / df["open"] - correlation = df["vol"].corr(df["price_volatility"]) - assert correlation > 0, "成交量与价格波动应该呈正相关" - assert correlation < 0.8, "相关性不应该过高(避免过于人工化)" - - print("Mock数据质量测试通过:") - print(f"- 数据总量: {len(df):,}行") - print(f"- 股票数量: {len(symbols)}只") - print(f"- 时间跨度: {df['dt'].min()} 至 {df['dt'].max()}") - print(f"- 价格涨跌幅范围: {price_changes.min():.1f}% 至 {price_changes.max():.1f}%") - print(f"- 成交量与波动率相关性: {correlation:.3f}") - - -def test_mock_data_consistency(): - """测试mock数据的一致性和可重现性""" - from czsc import mock - - # 使用相同种子生成两次数据 - df1 = mock.generate_klines(seed=123) - df2 = mock.generate_klines(seed=123) - - # 验证两次生成的数据完全一致 - pd.testing.assert_frame_equal(df1, df2, "相同种子应该生成相同的数据") - - # 使用不同种子生成数据 - df3 = mock.generate_klines(seed=456) - - # 验证不同种子生成的数据确实不同 - assert not df1.equals(df3), "不同种子应该生成不同的数据" - - # 但数据结构应该相同 - assert df1.columns.equals(df3.columns), "数据列结构应该相同" - assert len(df1) == len(df3), "数据行数应该相同" - assert df1["symbol"].nunique() == df3["symbol"].nunique(), "股票数量应该相同" - - print("Mock数据一致性测试通过") - - -def test_mark_cta_periods(): - """测试CTA周期标记功能""" - from czsc import mock - from czsc.eda import mark_cta_periods - - df = mock.generate_klines(seed=42) - # df = mock.generate_symbol_kines("BTC", freq='1分钟', sdt="20170101", edt="20250101", seed=42) - # 确保dt列是datetime类型 - df["dt"] = pd.to_datetime(df["dt"]) - - # 如果可能的话,也测试rs=True,但要处理可能的错误 - # try: - df1 = mark_cta_periods(df.copy(), rs=True, q1=0.15, q2=0.4, verbose=False) - assert isinstance(df1, pd.DataFrame), "rs=True时返回值应该是DataFrame" - assert len(df1) == len(df), "rs=True时处理后数据长度应该保持一致" - - # 先测试rs=False,避免rs-czsc库的复杂性 - df2 = mark_cta_periods(df.copy(), rs=False, q1=0.15, q2=0.4, verbose=False) - - # 基本验证 - assert isinstance(df2, pd.DataFrame), "返回值应该是DataFrame" - assert len(df2) == len(df), "处理后数据长度应该保持一致" - - # 对比 python 版本和 rust 版本的结果 - assert df1.shape == df2.shape, "rs=True和rs=False的结果应该有相同的形状" - assert df1.columns.equals(df2.columns), "rs=True和rs=False的列结构应该相同" - assert df1["symbol"].equals(df2["symbol"]), "rs=True和rs=False的symbol列应该相同" - assert df1["dt"].equals(df2["dt"]), "rs=True和rs=False的dt列应该相同" - cols = [ - "is_best_period", - "is_best_up_period", - "is_best_down_period", - "is_worst_period", - "is_worst_up_period", - "is_worst_down_period", - "is_normal_period", - ] - - for col in cols: - print(f"\n\n{'=' * 20}") - print(f"对比 {col} 列:") - print(f"rs=True: {df1[col].value_counts()}") - print(f"rs=False: {df2[col].value_counts()}") - dfx = df1[df1[col] != df2[col]].copy() - mis_rate = len(dfx) / len(df1) if len(df1) > 0 else 0 - print(f"rs=True和rs=False在{col}列不一致的比例: {mis_rate:.2%}") - # assert mis_rate < 0.2, f"{col} 列在rs=True和rs=False结果中不一致的比例过高: {mis_rate:.2%}" diff --git a/test/test_envs.py b/test/test_envs.py index 36ded23ff..f22028143 100644 --- a/test/test_envs.py +++ b/test/test_envs.py @@ -1,122 +1,159 @@ +"""``czsc.envs`` 环境变量与配置接口单元测试(迁移后版本)。 + +本测试套件覆盖 ``czsc.envs`` 模块在迁移完成后保留下来的公共配置接口, +并通过"反向断言"(negative pin)防止已废弃的 helper 被意外重新引入。 + +业务背景: + 经过迁移清理,``czsc.envs`` 仅保留三个公开配置项: + + - ``czsc_min_bi_len``:最小笔长度(K 线根数) + - ``czsc_max_bi_num``:最大笔数量 + - ``czsc_verbose``:是否启用详细日志输出 + + 历史遗留的 ``use_python``、``get_welcome``、``valid_true`` 等 helper + 以及 ``CZSC_USE_PYTHON`` 环境变量分支均已废弃。本套件除了覆盖这些 + 保留接口的正确行为之外,还显式断言废弃符号的缺失,避免后续重构时 + 被无意中重新加入。 + +测试覆盖: + - 已废弃接口的"不存在"反向断言; + - ``get_verbose`` 默认值、环境变量与参数覆盖三种来源; + - ``get_min_bi_len`` 默认值、参数覆盖、环境变量覆盖、返回类型; + - ``get_max_bi_num`` 默认值、参数覆盖、环境变量覆盖、返回类型。 """ -test_envs.py - czsc.envs 环境变量管理模块单元测试 - -测试覆盖: -- use_python(): 环境变量控制Python/Rust版本选择 -- get_verbose(): 详细输出控制 -- get_welcome(): 欢迎信息控制 -- get_min_bi_len(): 最小笔长度配置 -- get_max_bi_num(): 最大笔数量配置 -- 边界情况: 无效环境变量值、参数覆盖 -""" + +from __future__ import annotations import os -from czsc.envs import get_max_bi_num, get_min_bi_len, get_verbose, get_welcome, use_python +import pytest + +from czsc import envs as envs_mod +from czsc.envs import get_max_bi_num, get_min_bi_len, get_verbose + + +class TestRetiredHelpers: + """反向断言:确认所有已废弃的 helper 已经被移除。""" + # 参数化覆盖所有需要被废弃的旧符号;其中 _env 是私有助手,应当保留 + @pytest.mark.parametrize("name", ["use_python", "get_welcome", "valid_true", "_env"]) + def test_legacy_helper_removed(self, name: str) -> None: + """对每个旧名称做存在性检查。 -class TestUsePython: - """测试 use_python 函数""" + 测试场景: + 遍历四个历史符号;其中 ``_env`` 是模块内私有 helper,仍然保留; + 其余三个公共 helper 必须已经从 ``czsc.envs`` 模块中移除。 - def test_default_false(self): - """默认情况下应返回 False""" - os.environ.pop("CZSC_USE_PYTHON", None) - assert use_python() is False + 关键断言: + - ``_env`` 必须存在(私有实现细节,保留向下兼容); + - 其他三个公共名称必须已移除。 + """ + # `_env` 仍然作为模块内私有 helper 保留;其余符号必须已经移除 + if name == "_env": + assert hasattr(envs_mod, name), "私有 _env helper 应当继续保留" + else: + assert not hasattr(envs_mod, name), ( + f"czsc.envs.{name} 必须被移除" + ) - def test_set_true_values(self): - """测试各种 True 的有效表达""" - for val in ["1", "True", "true", "Y", "y", "yes", "Yes"]: - os.environ["CZSC_USE_PYTHON"] = val - assert use_python() is True - os.environ.pop("CZSC_USE_PYTHON", None) + def test_no_czsc_use_python_branch(self) -> None: + """验证 czsc.envs 源码中不再引用 CZSC_USE_PYTHON 环境变量。 - def test_set_false_values(self): - """测试无效值应返回 False""" - for val in ["0", "False", "false", "N", "n", "no", "No", "abc"]: - os.environ["CZSC_USE_PYTHON"] = val - assert use_python() is False - os.environ.pop("CZSC_USE_PYTHON", None) + 测试场景: + 通过 ``inspect.getsource`` 获取整个 ``czsc.envs`` 模块的源码字符串, + 搜索其中是否仍然包含 ``CZSC_USE_PYTHON`` 关键字。 + + 关键断言: + 源码中不得出现 ``CZSC_USE_PYTHON`` 字符串,确保迁移后已经 + 完全切断 Rust/Python 双实现的环境变量切换分支。 + """ + import inspect + + src = inspect.getsource(envs_mod) + assert "CZSC_USE_PYTHON" not in src, ( + "CZSC_USE_PYTHON 环境变量必须不被引用" + ) class TestGetVerbose: - """测试 get_verbose 函数""" + """``get_verbose`` 函数行为测试:覆盖默认值、环境变量、参数覆盖三类来源。""" + + def test_default_false(self) -> None: + """无任何外部输入时,get_verbose 默认返回 False。 - def test_default_false(self): - """默认情况下应返回 False""" + 测试场景: + 清除两种大小写形式的环境变量后调用 ``get_verbose()``。 + + 关键断言: + 返回值严格为 ``False``(使用 ``is`` 比较布尔身份)。 + """ + os.environ.pop("CZSC_VERBOSE", None) os.environ.pop("czsc_verbose", None) assert get_verbose() is False - def test_env_true(self, monkeypatch): - """环境变量设为 True 时应返回 True""" + def test_env_true(self, monkeypatch) -> None: + """通过环境变量 CZSC_VERBOSE=1 开启详细模式。 + + 使用 pytest 的 ``monkeypatch`` fixture 临时设置环境变量, + 测试结束后会自动还原。 + """ monkeypatch.setenv("CZSC_VERBOSE", "1") assert get_verbose() is True - def test_parameter_override(self): - """参数传入时应覆盖环境变量""" + def test_parameter_override(self) -> None: + """显式传入参数覆盖环境变量与默认值。 + + 关键断言: + - ``get_verbose(verbose="1")`` 返回 True; + - ``get_verbose(verbose="0")`` 返回 False。 + """ os.environ.pop("czsc_verbose", None) assert get_verbose(verbose="1") is True assert get_verbose(verbose="0") is False -class TestGetWelcome: - """测试 get_welcome 函数""" - - def test_default_false(self): - """默认应返回 False""" - os.environ.pop("czsc_welcome", None) - os.environ["CZSC_WELCOME"] = "0" - assert get_welcome() is False - os.environ.pop("czsc_welcome", None) - - def test_set_true(self, monkeypatch): - """设置为 1 时应返回 True""" - monkeypatch.setenv("CZSC_WELCOME", "1") - assert get_welcome() is True - - class TestGetMinBiLen: - """测试 get_min_bi_len 函数""" + """``get_min_bi_len`` 函数行为测试:覆盖默认值、参数覆盖、环境变量、返回类型。""" - def test_default_value(self): - """默认值应为 6""" + def test_default_value(self) -> None: + """无任何外部输入时,最小笔长度默认值为 6 根 K 线。""" + os.environ.pop("CZSC_MIN_BI_LEN", None) os.environ.pop("czsc_min_bi_len", None) assert get_min_bi_len() == 6 - def test_parameter_override(self): - """参数传入时应覆盖默认值""" + def test_parameter_override(self) -> None: + """显式传入参数应直接被采用,不受默认值影响。""" assert get_min_bi_len(7) == 7 assert get_min_bi_len(6) == 6 - def test_env_override(self, monkeypatch): - """环境变量应覆盖默认值""" + def test_env_override(self, monkeypatch) -> None: + """通过环境变量 CZSC_MIN_BI_LEN 覆盖默认值。""" monkeypatch.setenv("CZSC_MIN_BI_LEN", "7") assert get_min_bi_len() == 7 - def test_returns_int(self): - """返回值应为整数""" - result = get_min_bi_len() - assert isinstance(result, int) + def test_returns_int(self) -> None: + """返回值类型必须为 int,便于下游算法直接使用。""" + assert isinstance(get_min_bi_len(), int) class TestGetMaxBiNum: - """测试 get_max_bi_num 函数""" + """``get_max_bi_num`` 函数行为测试:覆盖默认值、参数覆盖、环境变量、返回类型。""" - def test_default_value(self): - """默认值应为 50""" + def test_default_value(self) -> None: + """无任何外部输入时,最大笔数量默认值为 50。""" + os.environ.pop("CZSC_MAX_BI_NUM", None) os.environ.pop("czsc_max_bi_num", None) assert get_max_bi_num() == 50 - def test_parameter_override(self): - """参数传入时应覆盖默认值""" + def test_parameter_override(self) -> None: + """显式传入参数应直接被采用。""" assert get_max_bi_num(100) == 100 - assert get_max_bi_num(20) == 20 - def test_env_override(self, monkeypatch): - """环境变量应覆盖默认值""" + def test_env_override(self, monkeypatch) -> None: + """通过环境变量 CZSC_MAX_BI_NUM 覆盖默认值。""" monkeypatch.setenv("CZSC_MAX_BI_NUM", "100") assert get_max_bi_num() == 100 - def test_returns_int(self): - """返回值应为整数""" - result = get_max_bi_num() - assert isinstance(result, int) + def test_returns_int(self) -> None: + """返回值类型必须为 int。""" + assert isinstance(get_max_bi_num(), int) diff --git a/test/test_mark_czsc_status.py b/test/test_mark_czsc_status.py deleted file mode 100644 index c6e5d2a61..000000000 --- a/test/test_mark_czsc_status.py +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env python -""" -V字反转识别功能测试脚本 - -该脚本用于测试 mark_czsc_status.py 模块的功能,验证V字反转识别是否正常工作。 -""" - -import os -import sys - -sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - - -def test_mark_czsc_status(): - """测试 mark_czsc_status 函数的基本功能""" - import czsc - from czsc.utils.mark_czsc_status import mark_czsc_status - - dfk = czsc.mock.generate_klines() - assert len(dfk) > 0, "生成的测试数据不应为空" - - # 调用状态标记函数 - dfr, bi_stats_marked = mark_czsc_status(dfk, verbose=True) - - assert len(dfr) > 0, "处理后K线数据不应为空" - assert len(bi_stats_marked) > 0, "笔统计数据不应为空" - - # 检查新增的标记列 - mark_cols = ["is_reversal", "is_trend", "is_oscillation", "is_normal"] - for col in mark_cols: - assert col in dfr.columns, f"应包含标记列 {col}" - - # 检查笔统计数据中的 mark 列 - assert "mark" in bi_stats_marked.columns, "笔统计数据应包含 mark 列" diff --git a/test/test_mock_quality.py b/test/test_mock_quality.py deleted file mode 100644 index d358ddfc1..000000000 --- a/test/test_mock_quality.py +++ /dev/null @@ -1,202 +0,0 @@ -""" -测试mock模块数据质量的单元测试 -""" - -import pandas as pd -import pytest - - -class TestMockDataQuality: - """Mock数据质量测试类""" - - def test_generate_klines_basic_structure(self): - """测试generate_klines的基本结构""" - from czsc import mock - - df = mock.generate_klines(seed=42) - - # 基本结构验证 - assert isinstance(df, pd.DataFrame), "应该返回DataFrame" - assert len(df) > 0, "数据不应为空" - - # 验证必要列存在 - required_cols = ["dt", "symbol", "open", "close", "high", "low", "vol", "amount"] - for col in required_cols: - assert col in df.columns, f"缺少必要列: {col}" - - def test_generate_klines_data_types(self): - """测试数据类型的正确性""" - from czsc import mock - - df = mock.generate_klines(seed=42) - - # 验证数据类型 - assert pd.api.types.is_datetime64_any_dtype(df["dt"]), "dt列应该是日期类型" - assert pd.api.types.is_numeric_dtype(df["open"]), "open列应该是数值类型" - assert pd.api.types.is_numeric_dtype(df["close"]), "close列应该是数值类型" - assert pd.api.types.is_numeric_dtype(df["high"]), "high列应该是数值类型" - assert pd.api.types.is_numeric_dtype(df["low"]), "low列应该是数值类型" - assert pd.api.types.is_numeric_dtype(df["vol"]), "vol列应该是数值类型" - assert pd.api.types.is_numeric_dtype(df["amount"]), "amount列应该是数值类型" - - def test_generate_klines_price_relationships(self): - """测试价格关系的正确性""" - from czsc import mock - - df = mock.generate_klines(seed=42) - - # 验证价格关系正确性 - price_check = (df["high"] >= df[["open", "close"]].max(axis=1)) & ( - df["low"] <= df[["open", "close"]].min(axis=1) - ) - assert price_check.all(), "所有K线的价格关系都应该正确" - - # 验证价格为正数 - assert (df["open"] > 0).all(), "开盘价应该为正数" - assert (df["close"] > 0).all(), "收盘价应该为正数" - assert (df["high"] > 0).all(), "最高价应该为正数" - assert (df["low"] > 0).all(), "最低价应该为正数" - assert (df["vol"] > 0).all(), "成交量应该为正数" - assert (df["amount"] > 0).all(), "成交金额应该为正数" - - def test_generate_klines_market_realism(self): - """测试市场真实性""" - from czsc import mock - - df = mock.generate_klines(seed=42) - - # 验证不同股票的价格走势差异 - symbols = df["symbol"].unique() - assert len(symbols) >= 10, "应该有足够的股票数量" - - price_changes = df.groupby("symbol")["close"].apply(lambda x: (x.iloc[-1] / x.iloc[0] - 1) * 100) - - # 应该有涨有跌,不应该所有股票都是同样的走势 - assert price_changes.std() > 10, "不同股票的涨跌幅应该有足够的差异性" - assert price_changes.max() > 0, "应该有股票上涨" - assert price_changes.min() < 0, "应该有股票下跌" - - # 验证成交量与价格波动的相关性 - df["price_volatility"] = abs(df["close"] - df["open"]) / df["open"] - correlation = df["vol"].corr(df["price_volatility"]) - assert correlation > 0, "成交量与价格波动应该呈正相关" - assert correlation < 0.8, "相关性不应该过高(避免过于人工化)" - - def test_generate_klines_consistency(self): - """测试数据一致性和可重现性""" - from czsc import mock - - # 使用相同种子生成两次数据 - df1 = mock.generate_klines(seed=123) - df2 = mock.generate_klines(seed=123) - - # 验证两次生成的数据完全一致 - pd.testing.assert_frame_equal(df1, df2, "相同种子应该生成相同的数据") - - # 使用不同种子生成数据 - df3 = mock.generate_klines(seed=456) - - # 验证不同种子生成的数据确实不同 - assert not df1.equals(df3), "不同种子应该生成不同的数据" - - # 但数据结构应该相同 - assert df1.columns.equals(df3.columns), "数据列结构应该相同" - assert len(df1) == len(df3), "数据行数应该相同" - assert df1["symbol"].nunique() == df3["symbol"].nunique(), "股票数量应该相同" - - def test_generate_klines_no_missing_values(self): - """测试无缺失值""" - from czsc import mock - - df = mock.generate_klines(seed=42) - assert df.isnull().sum().sum() == 0, "生成的数据不应该有缺失值" - - def test_generate_cs_factor_basic_structure(self): - """测试generate_cs_factor的基本结构""" - from czsc import mock - - df = mock.generate_cs_factor(seed=42) - - # 基本结构验证 - assert isinstance(df, pd.DataFrame), "应该返回DataFrame" - assert len(df) > 0, "数据不应为空" - assert "F#RPS#20" in df.columns, "应包含F#RPS#20因子列" - assert "symbol" in df.columns, "应包含symbol列" - assert "dt" in df.columns, "应包含dt列" - - def test_generate_cs_factor_data_quality(self): - """测试因子数据质量""" - from czsc import mock - - df = mock.generate_cs_factor(seed=42) - - # 验证因子数据的合理性 - factor_values = df["F#RPS#20"] - assert factor_values.min() >= 0, "RPS因子值应该>=0" - assert factor_values.max() <= 1, "RPS因子值应该<=1" - assert not factor_values.isnull().all(), "因子值不应全为空" - - def test_time_series_continuity(self): - """测试时间序列连续性""" - from czsc import mock - - df = mock.generate_cs_factor(seed=42) - - # 验证时间序列的连续性 - dates = sorted(df["dt"].unique()) - assert len(dates) > 100, "应该有足够的时间序列数据" - - # 验证不同股票的数据完整性 - symbols = df["symbol"].unique() - assert len(symbols) >= 10, "应该有足够的股票数量" - - for symbol in symbols[:3]: # 检查前3个股票 - symbol_data = df[df["symbol"] == symbol] - assert len(symbol_data) == len(dates), f"股票{symbol}的数据点数量应该与日期数量一致" - - def test_performance_benchmark(self): - """测试性能基准""" - import time - - from czsc import mock - - start_time = time.time() - df = mock.generate_klines(seed=42) - elapsed_time = time.time() - start_time - - # 性能基准:应该在合理时间内生成大量数据 - assert elapsed_time < 10, f"生成{len(df)}行数据耗时{elapsed_time:.2f}秒,应该在10秒内完成" - assert len(df) > 50000, "应该生成足够多的数据" - - def test_generate_klines_with_weights(self): - """测试带权重的K线数据生成""" - from czsc import mock - - df = mock.generate_klines_with_weights(seed=42) - - # 验证权重列存在 - assert "weight" in df.columns, "应包含weight列" - assert "price" in df.columns, "应包含price列" - - # 验证权重范围 - weights = df["weight"] - assert weights.min() >= -1, "权重最小值应该>=-1" - assert weights.max() <= 1, "权重最大值应该<=1" - - def test_generate_ts_factor(self): - """测试时序因子数据生成""" - from czsc import mock - - df = mock.generate_ts_factor(seed=42) - - # 验证因子列存在 - assert "F#SMA#20" in df.columns, "应包含F#SMA#20因子列" - - # 验证SMA因子的合理性(移动平均不应该有异常值) - sma_values = df["F#SMA#20"] - assert sma_values.min() >= 0, "SMA因子值应该为正数" - assert not sma_values.isnull().all(), "SMA因子值不应全为空" - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/test/test_plotly_plot.py b/test/test_plotly_plot.py index 5b44aa409..726ae6652 100644 --- a/test/test_plotly_plot.py +++ b/test/test_plotly_plot.py @@ -1,8 +1,15 @@ -""" -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2023/2/26 15:06 -describe: 测试绘图 +"""K 线图(KlineChart)Plotly 绘图模块单元测试。 + +本测试套件验证 ``czsc.KlineChart`` 在结合缠论分析结果后能够正确生成 +Plotly 交互式 K 线图,覆盖均线、成交量、MACD、分型与笔等常见叠加层。 + +业务背景: + KlineChart 是 czsc 中用于交易研究和回放展示的核心可视化组件,基于 + Plotly 多子图架构(n_rows)实现。本测试用例使用 mock 数据驱动整条 + 渲染流水线,并通过 HTML 文件落盘的方式验证最终产物可以被序列化输出。 + +模块作者: + zengbin93 (zeng_bin8888@163.com),创建于 2023/2/26 15:06 """ import os @@ -10,12 +17,29 @@ import pandas as pd from czsc import KlineChart, mock -from czsc.core import CZSC, Freq, RawBar +from czsc import CZSC, Freq, RawBar def test_kline_chart(): - """测试K线图""" - # 使用mock数据替代硬编码数据文件 + """端到端验证 KlineChart 能够基于缠论 CZSC 对象绘制完整图表。 + + 测试场景: + 1. 通过 ``czsc.mock.generate_symbol_kines`` 生成 2023 年全年的日线 + mock 数据(seed=42 保证可重现); + 2. 将 DataFrame 行逐条转换为 RawBar 对象列表; + 3. 构造 ``CZSC`` 分析对象并设置最大笔数量 max_bi_num=50; + 4. 创建 3 行子图布局的 KlineChart 实例并依次添加: + - 主图:K 线 + 多组 SMA 均线(5/10/21 与 34/55/89/144); + - 第二行:成交量; + - 第三行:MACD; + - 主图叠加:分型散点 + 笔的连线(含端点文本标注)。 + 5. 将图表导出为本地 HTML 文件并清理。 + + 关键断言: + - 调用 ``write_html`` 后目标文件必须真实存在; + - 删除文件后 ``os.path.exists`` 必须返回 False,确保资源清理彻底。 + """ + # 使用 mock 数据替代硬编码数据文件,固定 seed 保证测试可重现 df = mock.generate_symbol_kines("000001", "日线", sdt="20230101", edt="20240101", seed=42) bars = [] for i, row in df.iterrows(): @@ -35,7 +59,7 @@ def test_kline_chart(): c = CZSC(bars, max_bi_num=50) - # 从 bars_raw 手动构建 DataFrame + # 从 bars_raw 手动构建 DataFrame,作为 KlineChart 各类 add_* 方法的输入 df = pd.DataFrame( [ { diff --git a/test/test_rs.py b/test/test_rs.py deleted file mode 100644 index 612fb464e..000000000 --- a/test/test_rs.py +++ /dev/null @@ -1,45 +0,0 @@ -import numpy as np - - -def test_daily_performance(): - from czsc import daily_performance - - x = daily_performance([0.01, 0.02, -0.01, 0.03, 0.02, -0.02, 0.01, -0.01, 0.02, 0.01]) - assert x["夏普"] is not None - assert x["年化"] is not None - assert x["最大回撤"] >= 0 # 最大回撤为正值,表示回撤幅度 - assert isinstance(x["夏普"], (int, float)) - - -def test_weight_backtest(): - """测试权重回测功能""" - from rs_czsc import WeightBacktest - - from czsc import mock - - dfw = mock.generate_symbol_kines("000001", "日线", sdt="20230101", edt="20240101", seed=42) - dfw["weight"] = np.where(dfw["close"] > dfw["open"], 1.0, -1.0) - dfw["price"] = dfw["close"] - wb = WeightBacktest(dfw[["dt", "weight", "symbol", "price"]]) - - assert "夏普" in wb.stats - assert isinstance(wb.stats["夏普"], (int, float)) - assert wb.stats["夏普"] != 0 - - -def test_czsc(): - from rs_czsc import CZSC, Freq, format_standard_kline - - from czsc.mock import generate_klines - - df = generate_klines(seed=42) - symbol = df["symbol"].iloc[0] - df = df[df["symbol"] == symbol].copy() - bars = format_standard_kline(df, freq=Freq.D) - - c = CZSC(bars) - assert len(c.bars_raw) > 0 - assert c.bars_raw[-1].close > 0 - if len(c.bi_list) > 0: - bi = c.bi_list[-1] - assert bi is not None diff --git a/test/test_rs_analyze.py b/test/test_rs_analyze.py deleted file mode 100644 index 8986d81c8..000000000 --- a/test/test_rs_analyze.py +++ /dev/null @@ -1,73 +0,0 @@ -""" -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2022/2/16 20:31 -describe: czsc.analyze 单元测试 -""" - -from rs_czsc import CZSC, Direction, Freq, RawBar - -from czsc import mock - - -def get_mock_bars(freq=Freq.D, symbol="000001", n_days=100): - """获取mock K线数据并转换为RawBar对象""" - if freq == Freq.F1: - df = mock.generate_symbol_kines(symbol, "1分钟", sdt="20240101", edt="20240110", seed=42) - elif freq == Freq.F5: - df = mock.generate_symbol_kines(symbol, "5分钟", sdt="20240101", edt="20240110", seed=42) - elif freq == Freq.D: - df = mock.generate_symbol_kines(symbol, "日线", sdt="20230101", edt="20240101", seed=42) - else: - df = mock.generate_klines(seed=42) - df = df[df["symbol"] == symbol].head(n_days) if symbol in df["symbol"].values else df.head(n_days) - - bars = [] - for i, row in df.iterrows(): - bar = RawBar( - symbol=row["symbol"], - id=i, - freq=freq, - open=row["open"], - dt=row["dt"], - close=row["close"], - high=row["high"], - low=row["low"], - vol=row["vol"], - amount=row["amount"], - ) - bars.append(bar) - return bars - - -def test_czsc_basic(): - """测试CZSC基础功能""" - bars = get_mock_bars(freq=Freq.D, symbol="000001", n_days=200) - c = CZSC(bars) - - assert c.symbol == "000001", "symbol应该正确设置" - assert c.freq == Freq.D, "频率应该正确设置" - assert len(c.bars_raw) > 0, "原始K线数据不应为空" - assert len(c.bars_ubi) > 0, "去除包含关系后的K线数据不应为空" - assert len(c.bi_list) > 0, "笔的列表不应为空" - - -def test_czsc_signals(): - """测试CZSC信号计算 - 无信号函数时signals为None或空字典""" - bars = get_mock_bars(freq=Freq.D, symbol="000001", n_days=200) - c = CZSC(bars) - - # 没有提供 get_signals 函数时,signals 为 None(Rust)或空字典(Python) - assert c.signals is None or isinstance(c.signals, dict), "signals应该是None或字典类型" - - -def test_czsc_ubi_properties(): - """测试CZSC的ubi属性""" - bars = get_mock_bars(freq=Freq.D, symbol="000001", n_days=200) - c = CZSC(bars) - - ubi = c.ubi - assert "direction" in ubi, "ubi应该包含direction字段" - assert "high_bar" in ubi, "ubi应该包含high_bar字段" - assert "low_bar" in ubi, "ubi应该包含low_bar字段" - assert isinstance(ubi["direction"], Direction), "direction应该是Direction类型" diff --git a/test/test_sig.py b/test/test_sig.py deleted file mode 100644 index d2e2102e6..000000000 --- a/test/test_sig.py +++ /dev/null @@ -1,288 +0,0 @@ -""" -test_sig.py - czsc.utils.sig 信号辅助工具函数单元测试 - -Mock数据格式说明: -- check_cross_info/fast_slow_cross: 输入为等长数值列表或 numpy 数组 -- same_dir_counts/count_last_same: 输入为数值列表 -- get_sub_elements: 输入为任意类型列表 -- cross_zero_axis: 输入为等长数值列表 -- down_cross_count: 输入为等长数值列表 -- cal_cross_num: 输入为 fast_slow_cross 返回的交叉信息列表 - -测试覆盖: -- 基本功能验证 -- 边界情况: 空数据、单值、全同值 -- 异常输入处理 -""" - -import numpy as np -import pytest - -from czsc.utils.sig import ( - cal_cross_num, - check_cross_info, - count_last_same, - create_single_signal, - cross_zero_axis, - down_cross_count, - fast_slow_cross, - get_sub_elements, - same_dir_counts, -) - - -class TestCreateSingleSignal: - """测试 create_single_signal 函数""" - - def test_basic(self): - """测试基本信号创建""" - s = create_single_signal(k1="1分钟", k2="倒1", k3="形态", v1="类一买") - assert isinstance(s, dict) - assert len(s) == 1 - - def test_default_values(self): - """测试默认值""" - s = create_single_signal(k1="1分钟") - assert isinstance(s, dict) - assert len(s) == 1 - - def test_with_score(self): - """测试带分数的信号""" - s = create_single_signal(k1="1分钟", k2="倒1", k3="形态", v1="看多", score=5) - assert isinstance(s, dict) - - -class TestCheckCrossInfo: - """测试 check_cross_info 函数""" - - def test_golden_cross(self): - """测试金叉检测""" - fast = [1, 2, 3, 4, 5, 6, 7] - slow = [7, 6, 5, 4, 3, 2, 1] - result = check_cross_info(fast, slow) - assert isinstance(result, list) - # fast从低于slow到高于slow,应该有金叉 - golden_crosses = [x for x in result if x["类型"] == "金叉"] - assert len(golden_crosses) >= 1 - - def test_death_cross(self): - """测试死叉检测""" - fast = [7, 6, 5, 4, 3, 2, 1] - slow = [1, 2, 3, 4, 5, 6, 7] - result = check_cross_info(fast, slow) - death_crosses = [x for x in result if x["类型"] == "死叉"] - assert len(death_crosses) >= 1 - - def test_no_cross(self): - """测试无交叉""" - fast = [10, 11, 12, 13, 14] - slow = [1, 2, 3, 4, 5] - result = check_cross_info(fast, slow) - assert len(result) == 0 - - def test_numpy_input(self): - """测试 numpy 数组输入""" - fast = np.array([1, 2, 3, 4, 5, 6, 7]) - slow = np.array([7, 6, 5, 4, 3, 2, 1]) - result = check_cross_info(fast, slow) - assert isinstance(result, list) - - def test_unequal_length_raises(self): - """测试不等长输入应抛出异常""" - with pytest.raises(AssertionError): - check_cross_info([1, 2, 3], [1, 2]) - - def test_cross_info_fields(self): - """测试返回字段完整性""" - fast = [1, 2, 3, 4, 5, 6, 7] - slow = [7, 6, 5, 4, 3, 2, 1] - result = check_cross_info(fast, slow) - if result: - expected_keys = { - "位置", - "类型", - "快线", - "慢线", - "距离", - "距今", - "面积", - "价差", - "快线高点", - "快线低点", - "慢线高点", - "慢线低点", - } - assert set(result[0].keys()) == expected_keys - - -class TestFastSlowCross: - """测试 fast_slow_cross 函数(与 check_cross_info 功能相同)""" - - def test_basic(self): - """测试基本交叉检测""" - fast = [1, 2, 3, 4, 5, 6, 7] - slow = [7, 6, 5, 4, 3, 2, 1] - result = fast_slow_cross(fast, slow) - assert isinstance(result, list) - assert len(result) >= 1 - - def test_consistent_with_check_cross_info(self): - """测试与 check_cross_info 结果一致""" - fast = [1, 3, 2, 5, 4, 7, 6] - slow = [4, 4, 4, 4, 4, 4, 4] - r1 = check_cross_info(fast, slow) - r2 = fast_slow_cross(fast, slow) - assert len(r1) == len(r2) - - -class TestSameDirCounts: - """测试 same_dir_counts 函数""" - - def test_all_positive(self): - """测试全正数""" - assert same_dir_counts([1, 2, 3]) == 3 - - def test_all_negative(self): - """测试全负数""" - assert same_dir_counts([-1, -2, -3]) == 3 - - def test_mixed(self): - """测试混合序列""" - assert same_dir_counts([-1, -2, 1, 2, 3]) == 3 - - def test_single_element(self): - """测试单元素""" - assert same_dir_counts([5]) == 1 - assert same_dir_counts([-5]) == 1 - - def test_direction_change(self): - """测试方向变化""" - # 最后是正数,往前数到第一个负数 - assert same_dir_counts([-1, -1, -2, -3, 0, 1, 2, 3, -1, -2, 1, 1, 2, 3]) == 4 - - -class TestCountLastSame: - """测试 count_last_same 函数""" - - def test_all_same(self): - """测试全部相同""" - assert count_last_same([1, 1, 1, 1]) == 4 - - def test_last_different(self): - """测试末尾不同""" - assert count_last_same([1, 2, 3, 3, 3]) == 3 - - def test_single_element(self): - """测试单元素""" - assert count_last_same([5]) == 1 - - def test_no_repeat(self): - """测试无重复""" - assert count_last_same([1, 2, 3, 4]) == 1 - - def test_tuple_input(self): - """测试 tuple 输入""" - assert count_last_same((1, 2, 2, 2)) == 3 - - -class TestGetSubElements: - """测试 get_sub_elements 函数""" - - def test_basic(self): - """测试基本功能""" - x = [1, 2, 3, 4, 5, 6, 7, 8, 9] - result = get_sub_elements(x, di=1, n=3) - assert result == [7, 8, 9] - - def test_di_2(self): - """测试 di=2""" - x = [1, 2, 3, 4, 5, 6, 7, 8, 9] - result = get_sub_elements(x, di=2, n=3) - assert result == [6, 7, 8] - - def test_n_larger_than_list(self): - """测试 n 大于列表长度""" - x = [1, 2, 3] - result = get_sub_elements(x, di=1, n=10) - assert result == [1, 2, 3] - - def test_di_0_raises(self): - """测试 di=0 应抛出异常""" - with pytest.raises(AssertionError): - get_sub_elements([1, 2, 3], di=0, n=2) - - -class TestCrossZeroAxis: - """测试 cross_zero_axis 函数""" - - def test_basic(self): - """测试基本功能""" - n1 = [1, 2, 3, -1, -2] - n2 = [1, 1, 1, 1, 1] - result = cross_zero_axis(n1, n2) - assert isinstance(result, int) - assert result >= 0 - - def test_no_cross(self): - """测试无交叉""" - n1 = [1, 2, 3, 4, 5] - n2 = [1, 2, 3, 4, 5] - result = cross_zero_axis(n1, n2) - assert isinstance(result, int) - - def test_unequal_length_raises(self): - """测试不等长应抛出异常""" - with pytest.raises(AssertionError): - cross_zero_axis([1, 2], [1, 2, 3]) - - -class TestCalCrossNum: - """测试 cal_cross_num 函数""" - - def test_empty_cross(self): - """测试空交叉列表""" - jc, sc = cal_cross_num([]) - assert jc == 0 - assert sc == 0 - - def test_single_cross(self): - """测试单个交叉""" - cross = [{"类型": "金叉", "距离": 5}] - jc, sc = cal_cross_num(cross) - assert jc == 1 - assert sc == 0 - - def test_with_distance_filter(self): - """测试距离过滤""" - cross = [ - {"类型": "金叉", "距离": 5}, - {"类型": "死叉", "距离": 1}, - ] - jc, sc = cal_cross_num(cross, distance=3) - assert isinstance(jc, int) - assert isinstance(sc, int) - - -class TestDownCrossCount: - """测试 down_cross_count 函数""" - - def test_basic(self): - """测试基本下穿计数""" - x1 = [5, 4, 3, 2, 1] - x2 = [1, 2, 3, 4, 5] - result = down_cross_count(x1, x2) - assert result >= 1 - - def test_no_cross(self): - """测试无下穿""" - x1 = [10, 11, 12, 13, 14] - x2 = [1, 2, 3, 4, 5] - result = down_cross_count(x1, x2) - assert result == 0 - - def test_numpy_input(self): - """测试 numpy 数组输入""" - x1 = np.array([5, 4, 3, 2, 1]) - x2 = np.array([1, 2, 3, 4, 5]) - result = down_cross_count(x1, x2) - assert isinstance(result, (int, np.integer)) diff --git a/test/test_utils.py b/test/test_utils.py index 661b26292..477f33fde 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,8 +1,10 @@ -""" -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2022/2/16 20:31 -describe: czsc.utils 单元测试 +"""``czsc.utils`` 通用工具函数单元测试。 + +本测试套件覆盖 ``czsc.utils`` 中若干通用工具函数的基础行为,包括数值四舍五入、 +对称加密往返以及超时装饰器在正常与超时两种场景下的表现。 + +模块作者: + zengbin93 (zeng_bin8888@163.com),创建于 2022/2/16 20:31 """ import time @@ -15,6 +17,15 @@ def test_x_round(): + """验证 ``utils.x_round`` 按指定小数位进行四舍五入的行为。 + + 测试场景: + - 整数输入:保留 3 位小数应保持原值; + - 浮点输入:分别在 3 / 4 / 5 位精度下校验截断与舍入结果。 + + 关键断言: + ``x_round(1.000342, n)`` 在 n=3/4/5 时分别得到 1.0 / 1.0003 / 1.00034。 + """ assert utils.x_round(100, 3) == 100 assert utils.x_round(1.000342, 3) == 1.0 assert utils.x_round(1.000342, 4) == 1.0003 @@ -22,6 +33,17 @@ def test_x_round(): def test_fernet(): + """验证 Fernet 对称加密的 encrypt → decrypt 往返一致性。 + + 测试场景: + 1. 调用 ``generate_fernet_key`` 生成一把随机密钥; + 2. 使用该密钥对一个字典对象做 ``fernet_encrypt`` 加密; + 3. 再用同一密钥执行 ``fernet_decrypt`` 解密(``is_dict=True`` 表示 + 还原为字典而非字符串)。 + + 关键断言: + 解密后得到的字典与原始字典完全相等,证明加密往返不损失信息。 + """ from czsc.utils.crypto.fernet import fernet_decrypt, fernet_encrypt, generate_fernet_key key = generate_fernet_key() @@ -31,42 +53,16 @@ def test_fernet(): assert text == decrypted, f"{text} != {decrypted}" -def test_find_most_similarity(): - """测试相似度查找功能""" - from czsc.utils.features import find_most_similarity - - # 使用固定种子创建确定性的测试数据 - np.random.seed(42) - vector = pd.Series(np.random.rand(10)) - matrix = pd.DataFrame(np.random.rand(10, 100)) - - result = find_most_similarity(vector, matrix, n=5, metric="cosine") - - assert isinstance(result, pd.Series), "结果应该是pandas Series" - assert len(result) == 5, "结果长度应该是5" - assert all(isinstance(index, int) for index in result.index), "索引应该都是整数" - assert all(0 <= value <= 1 for value in result.values), "相似度值应该在0-1之间" - - -def test_overlap(): - """测试重叠检测功能""" - from czsc.utils import overlap - - df = pd.DataFrame( - { - "dt": pd.date_range(start="1/1/2022", periods=5), - "symbol": ["AAPL", "AAPL", "AAPL", "AAPL", "AAPL"], - "col": [1, 1, 2, 2, 1], - } - ) - - result = overlap(df, "col") - - assert result["col_overlap"].tolist() == [1, 2, 1, 2, 1], "重叠检测结果不正确" +def test_timeout_decorator_success(): + """验证超时装饰器在被装饰函数耗时小于阈值时正常返回结果。 + 测试场景: + 定义一个执行约 1 秒的 ``fast_function``,并用 ``timeout_decorator(2)`` + 装饰(超时阈值 2 秒)。 -def test_timeout_decorator_success(): - """测试超时装饰器正常情况""" + 关键断言: + 函数能够在阈值内正常完成并返回 ``"Completed"``。 + """ @timeout_decorator(2) def fast_function(): @@ -77,7 +73,15 @@ def fast_function(): def test_timeout_decorator_timeout(): - """测试超时装饰器超时情况""" + """验证超时装饰器在被装饰函数耗时超过阈值时返回 None。 + + 测试场景: + 定义一个执行约 5 秒的 ``slow_function``,并用 ``timeout_decorator(1)`` + 装饰(超时阈值 1 秒)。 + + 关键断言: + 装饰器在 1 秒后中止函数执行并返回 ``None``。 + """ @timeout_decorator(1) def slow_function(): diff --git a/test/test_utils_features.py b/test/test_utils_features.py deleted file mode 100644 index 3f799a645..000000000 --- a/test/test_utils_features.py +++ /dev/null @@ -1,144 +0,0 @@ -""" -test_utils_features.py - czsc.utils.features 因子处理模块单元测试 - -Mock数据格式说明: -- normalize_feature: 需要 DataFrame,包含 dt 列和因子列 -- normalize_ts_feature: 需要时间序列 DataFrame,包含 dt 和因子列 -- feature_cross_layering: 需要 DataFrame,包含 dt, symbol, 因子列 -- find_most_similarity: 需要 Series 和 DataFrame - -测试覆盖: -- normalize_feature: 各种标准化方法(standard, minmax, robust, norm) -- normalize_ts_feature: 时间序列归一化 -- feature_cross_layering: 截面分层 -- find_most_similarity: 相似度搜索 -- 边界情况: 含 NaN 数据断言、极端值 -""" - -import numpy as np -import pandas as pd -import pytest - -from czsc.utils.features import feature_cross_layering, find_most_similarity, normalize_feature - - -class TestNormalizeFeature: - """测试 normalize_feature 函数""" - - @pytest.fixture - def sample_df(self): - """构造测试数据: 多个时间截面,每个截面有多个品种的因子值""" - np.random.seed(42) - dates = pd.date_range("20220101", periods=100, freq="D") - data = [] - for dt in dates: - for _ in range(20): - data.append({"dt": dt, "factor": np.random.randn()}) - return pd.DataFrame(data) - - def test_standard_method(self, sample_df): - """测试 standard 标准化""" - result = normalize_feature(sample_df, "factor", method="standard") - assert "factor" in result.columns - assert len(result) == len(sample_df) - - def test_minmax_method(self, sample_df): - """测试 minmax 标准化""" - result = normalize_feature(sample_df, "factor", method="minmax") - assert "factor" in result.columns - - def test_robust_method(self, sample_df): - """测试 robust 标准化""" - result = normalize_feature(sample_df, "factor", method="robust") - assert "factor" in result.columns - - def test_invalid_method_raises(self, sample_df): - """测试无效方法应抛出异常""" - with pytest.raises(ValueError): - normalize_feature(sample_df, "factor", method="invalid_method") - - def test_nan_raises(self, sample_df): - """测试含 NaN 数据应抛出异常""" - sample_df.loc[0, "factor"] = np.nan - with pytest.raises(AssertionError): - normalize_feature(sample_df, "factor") - - def test_does_not_modify_original(self, sample_df): - """测试不应修改原始数据""" - original = sample_df["factor"].copy() - normalize_feature(sample_df, "factor") - pd.testing.assert_series_equal(sample_df["factor"], original) - - -class TestFeatureCrossLayering: - """测试 feature_cross_layering 函数""" - - @pytest.fixture - def sample_df(self): - """构造截面数据: 多日期 x 多品种""" - np.random.seed(42) - dates = pd.date_range("20220101", periods=50, freq="D") - symbols = [f"SYM{str(i).zfill(3)}" for i in range(20)] - data = [] - for dt in dates: - for sym in symbols: - data.append({"dt": dt, "symbol": sym, "factor": np.random.randn()}) - return pd.DataFrame(data) - - def test_basic_layering(self, sample_df): - """测试基本分层""" - result = feature_cross_layering(sample_df, "factor", n=5) - assert "factor分层" in result.columns - - def test_layer_format(self, sample_df): - """测试分层格式""" - result = feature_cross_layering(sample_df, "factor", n=5) - # 分层应为 "第XX层" 格式 - layers = result["factor分层"].unique() - for layer in layers: - assert layer.startswith("第"), f"分层格式应以'第'开头: {layer}" - - def test_missing_dt_raises(self, sample_df): - """测试缺少 dt 列应抛出异常""" - df = sample_df.drop(columns=["dt"]) - with pytest.raises(AssertionError): - feature_cross_layering(df, "factor") - - def test_missing_symbol_raises(self, sample_df): - """测试缺少 symbol 列应抛出异常""" - df = sample_df.drop(columns=["symbol"]) - with pytest.raises(AssertionError): - feature_cross_layering(df, "factor") - - def test_too_few_symbols_raises(self): - """测试品种数量不足应抛出异常""" - dates = pd.date_range("20220101", periods=10, freq="D") - data = [{"dt": dt, "symbol": "SYM001", "factor": np.random.randn()} for dt in dates] - df = pd.DataFrame(data) - with pytest.raises(AssertionError): - feature_cross_layering(df, "factor", n=5) - - -class TestFindMostSimilarity: - """测试 find_most_similarity 函数""" - - def test_basic(self): - """测试基本相似度搜索""" - np.random.seed(42) - vector = pd.Series(np.random.randn(10)) - matrix = pd.DataFrame(np.random.randn(10, 20), columns=[f"col_{i}" for i in range(20)]) - result = find_most_similarity(vector, matrix, n=5) - assert len(result) == 5 - assert isinstance(result, pd.Series) - - def test_identical_vector(self): - """测试完全相同的向量""" - vector = pd.Series([1, 2, 3, 4, 5]) - matrix = pd.DataFrame( - { - "exact": [1, 2, 3, 4, 5], - "different": [5, 4, 3, 2, 1], - } - ) - result = find_most_similarity(vector, matrix, n=2) - assert len(result) == 2 diff --git a/test/test_utils_refactored.py b/test/test_utils_refactored.py index 761f9e5bb..b7aad0e54 100644 --- a/test/test_utils_refactored.py +++ b/test/test_utils_refactored.py @@ -1,5 +1,21 @@ -""" -测试新的绘图模块结构 +"""``czsc.utils`` 重构后的子包结构与向后兼容性单元测试。 + +本测试套件验证 ``czsc.utils`` 经过重构、按主题划分到多个子包后, +新旧导入路径都能正确工作,并且各子包暴露的常量、函数与验证器行为符合预期。 + +业务背景: + ``czsc.utils`` 重构为以下子包/模块: + + - ``czsc.utils.plotting``:绘图相关,进一步分为 ``common``(颜色/标签等 + 共享常量与 ``figure_to_html`` 工具函数)、回测绘图、权重绘图等子模块。 + - ``czsc.utils.data``:数据相关,包括 ``validators``(DataFrame 校验)、 + ``converters``(标准化转换)等。 + - ``czsc.utils.crypto``:对称加密工具,从 ``czsc.utils.crypto.fernet`` 升级 + 为子包导出。 + - ``czsc.utils.analysis``:统计与相关性分析工具的统一入口。 + + 重构同时要求保持完整的向后兼容性,旧的导入路径(如 ``from czsc.utils + import DiskCache``)必须继续工作。 """ import pandas as pd @@ -7,7 +23,20 @@ def test_plotting_common_module(): - """测试绘图公共模块的常量和函数""" + """验证 ``czsc.utils.plotting.common`` 中的常量与 figure_to_html 工具函数。 + + 测试场景: + 1. 校验四个核心常量值:颜色、Sigma 等级数、月份标签数; + 2. 用空 ``go.Figure`` 调用 ``figure_to_html``: + - ``to_html=False`` 时返回原 Figure 对象; + - ``to_html=True`` 时返回包含 plotly 标识的 HTML 字符串。 + + 关键断言: + - ``COLOR_DRAWDOWN`` 与 ``COLOR_RETURN`` 为预定义颜色字符串; + - ``SIGMA_LEVELS`` 长度为 6(覆盖 ±1/±2/±3 sigma); + - ``MONTH_LABELS`` 长度为 12(一月到十二月); + - ``figure_to_html`` 在两种模式下分别返回 Figure 与 str。 + """ from czsc.utils.plotting.common import ( COLOR_DRAWDOWN, COLOR_RETURN, @@ -38,7 +67,12 @@ def test_plotting_common_module(): def test_plotting_backtest_imports(): - """测试回测绘图模块的导入""" + """验证回测绘图模块的关键 API 可以从 ``czsc.utils.plotting`` 顶层导入。 + + 关键断言: + ``plot_cumulative_returns``、``plot_colored_table``、``plot_czsc_chart`` + 三个函数均可被导入且为可调用对象。 + """ from czsc.utils.plotting import ( plot_colored_table, plot_cumulative_returns, @@ -51,7 +85,11 @@ def test_plotting_backtest_imports(): def test_plotting_weight_imports(): - """测试权重绘图模块的导入""" + """验证权重绘图相关统计函数可以从 ``czsc.utils.plotting`` 导入。 + + 关键断言: + ``calculate_turnover_stats`` 与 ``calculate_weight_stats`` 都是可调用对象。 + """ from czsc.utils.plotting import ( calculate_turnover_stats, calculate_weight_stats, @@ -62,7 +100,18 @@ def test_plotting_weight_imports(): def test_data_validators(): - """测试数据验证器""" + """验证数据校验器 ``czsc.utils.data.validators`` 的核心断言行为。 + + 测试场景: + - ``validate_dataframe_columns``:列齐全时静默通过,缺列时抛 ValueError; + - ``validate_datetime_index``:DatetimeIndex 静默通过,普通整数索引抛 ValueError; + - ``validate_numeric_column``:列存在时静默通过,列不存在时抛 ValueError。 + + 关键断言: + - 所有正常情况都不抛异常; + - 所有异常分支都通过 ``pytest.raises`` 配合 ``match`` 参数校验异常消息中 + 的关键中文字符串。 + """ from czsc.utils.data.validators import ( validate_dataframe_columns, validate_datetime_index, @@ -96,7 +145,18 @@ def test_data_validators(): def test_data_converters(): - """测试数据转换器""" + """验证数据转换器 ``czsc.utils.data.converters`` 的标准化能力。 + + 测试场景: + - ``to_standard_kline_format``:将带有 ``datetime`` / ``volume`` 列名的 + DataFrame 转换为 czsc 的标准列名(``dt`` / ``vol``); + - ``normalize_symbol``:去除前后空格并转换为大写。 + + 关键断言: + - 转换后必须包含 ``dt``、``open``、``vol`` 等标准列; + - ``dt`` 列的元素类型必须是 ``pd.Timestamp``; + - ``" aapl "`` → ``"AAPL"``、``" tsla "`` → ``"TSLA"``。 + """ from czsc.utils.data.converters import ( normalize_symbol, to_standard_kline_format, @@ -127,7 +187,17 @@ def test_data_converters(): def test_crypto_module(): - """测试加密模块""" + """验证加密子包 ``czsc.utils.crypto`` 的密钥生成与往返加密能力。 + + 测试场景: + 1. 生成 Fernet 密钥; + 2. 用该密钥对字典做加密; + 3. 用同一密钥解密并以字典形式还原。 + + 关键断言: + - 密钥与密文类型为 ``bytes`` 或 ``str``(Fernet 实现允许两种形式); + - 解密后字典与原始字典完全一致。 + """ from czsc.utils.crypto import ( fernet_decrypt, fernet_encrypt, @@ -135,7 +205,7 @@ def test_crypto_module(): ) key = generate_fernet_key() - assert isinstance(key, (bytes, str)) # Key can be bytes or string + assert isinstance(key, (bytes, str)) # 密钥既可能是 bytes 也可能是 str text = {"account": "test", "password": "123"} encrypted = fernet_encrypt(text, key) @@ -146,7 +216,12 @@ def test_crypto_module(): def test_analysis_stats_imports(): - """测试统计分析模块的导入""" + """验证统计分析模块 ``czsc.utils.analysis`` 的关键函数可被导入。 + + 关键断言: + ``daily_performance``、``holds_performance``、``top_drawdowns`` 三个统计 + 函数均为可调用对象。 + """ from czsc.utils.analysis import ( daily_performance, holds_performance, @@ -159,7 +234,12 @@ def test_analysis_stats_imports(): def test_analysis_corr_imports(): - """测试相关性分析模块的导入""" + """验证相关性分析模块 ``czsc.utils.analysis`` 的关键函数可被导入。 + + 关键断言: + ``nmi_matrix``、``single_linear``、``cross_sectional_ic`` 三个相关性分析 + 函数均为可调用对象。 + """ from czsc.utils.analysis import ( cross_sectional_ic, nmi_matrix, @@ -171,15 +251,19 @@ def test_analysis_corr_imports(): assert callable(cross_sectional_ic) -def test_analysis_events_imports(): - """测试事件分析模块的导入""" - from czsc.utils.analysis import overlap - - assert callable(overlap) - - def test_backward_compatibility(): - """测试向后兼容性 - 旧的导入路径仍然可用""" + """验证向后兼容性:重构前的旧导入路径仍然可用。 + + 测试场景: + 在重构后,仍然支持从 ``czsc.utils`` 顶层直接导入历史接口,包括: + ``DataClient``、``DiskCache``、``KlineChart``、``daily_performance``、 + ``generate_fernet_key``、``home_path``、``plot_colored_table`` 等。 + + 关键断言: + - 各个名称都能成功导入; + - 类对象 / 路径对象不为 None; + - 函数对象 ``callable(...)`` 为真。 + """ # 测试从主utils导入 from czsc.utils import ( DataClient, @@ -188,7 +272,6 @@ def test_backward_compatibility(): daily_performance, generate_fernet_key, home_path, - overlap, ) assert home_path is not None @@ -196,7 +279,6 @@ def test_backward_compatibility(): assert DataClient is not None assert callable(generate_fernet_key) assert callable(daily_performance) - assert callable(overlap) assert KlineChart is not None # 测试向后兼容性 - 通过 czsc.utils 的 __init__.py 重新导出 diff --git a/test/test_utils_ta.py b/test/test_utils_ta.py deleted file mode 100644 index c5105ba97..000000000 --- a/test/test_utils_ta.py +++ /dev/null @@ -1,503 +0,0 @@ -""" -test_utils_ta.py - 技术分析指标单元测试 - -Mock数据格式说明: -- 数据来源: czsc.mock.generate_symbol_kines -- 数据列: dt, symbol, open, close, high, low, vol, amount -- 时间范围: 20200101-20250101(5年数据,满足3年+要求) -- 频率: 日线 -- Seed: 42(确保可重现) - -测试覆盖范围: -- SMA (Simple Moving Average) - 简单移动平均线 -- EMA (Exponential Moving Average) - 指数移动平均线 -- MACD (Moving Average Convergence Divergence) - 指标平滑异同移动平均线 -- RSI (Relative Strength Index) - 相对强弱指标 -- BOLL (Bollinger Bands) - 布林带 -- ATR (Average True Range) - 平均真实波幅 -- KDJ (Stochastic Indicator) - 随机指标 -""" - -import numpy as np -import pandas as pd - -from czsc import mock -from czsc.utils.ta import ATR, BOLL, EMA, KDJ, MACD, RSI, SMA - - -def get_test_data(freq="日线", sdt="20200101", edt="20250101", symbol="000001"): - """获取测试数据 - - Args: - freq: K线频率,默认日线 - sdt: 开始日期 - edt: 结束日期 - symbol: 品种代码 - - Returns: - DataFrame: K线数据 - """ - df = mock.generate_symbol_kines(symbol=symbol, freq=freq, sdt=sdt, edt=edt, seed=42) - return df - - -class TestSMA: - """SMA简单移动平均线测试""" - - def test_sma_basic(self): - """测试SMA基础功能""" - df = get_test_data() - sma5 = SMA(df["close"], 5) - sma10 = SMA(df["close"], 10) - sma20 = SMA(df["close"], 20) - sma60 = SMA(df["close"], 60) - - assert len(sma5) == len(df), "SMA返回长度应与输入相同" - assert len(sma10) == len(df), "SMA返回长度应与输入相同" - assert len(sma20) == len(df), "SMA返回长度应与输入相同" - assert len(sma60) == len(df), "SMA返回长度应与输入相同" - - # 验证SMA非NaN值的平滑性 - assert not sma5[60:].isna().any(), "SMA(5)在60周期后不应有NaN" - assert not sma10[60:].isna().any(), "SMA(10)在60周期后不应有NaN" - assert not sma20[60:].isna().any(), "SMA(20)在60周期后不应有NaN" - assert not sma60[60:].isna().any(), "SMA(60)在60周期后不应有NaN" - - def test_sma_empty_array(self): - """测试空数组""" - result = SMA(np.array([]), 5) - assert len(result) == 0, "空数组应返回空结果" - - def test_sma_single_value(self): - """测试单值数组""" - result = SMA(np.array([100], dtype=np.float64), 5) - assert len(result) == 1, "单值数组应返回单值结果" - # SMA implementation returns mean for available data, so won't be NaN for single value - - def test_sma_with_nan(self): - """测试包含NaN的数据""" - data = np.array([1, 2, np.nan, 4, 5, 6, 7, 8, 9, 10], dtype=np.float64) - result = SMA(data, 5) - assert len(result) == len(data), "包含NaN的数据长度应保持不变" - # 验证结果不为None - assert result is not None - - def test_sma_with_inf(self): - """测试包含Inf的数据""" - data = np.array([1, 2, 3, 4, np.inf, 6, 7, 8, 9, 10], dtype=np.float64) - result = SMA(data, 5) - assert len(result) == len(data), "包含Inf的数据长度应保持不变" - - def test_sma_with_zeros(self): - """测试全0数据""" - data = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.float64) - result = SMA(data, 5) - assert len(result) == len(data), "全0数据长度应保持不变" - # 验证结果也是0(忽略NaN) - non_nan = result[~pd.isna(result)] - assert all(non_nan == 0), "全0数据的SMA应为0" - - def test_sma_large_period(self): - """测试周期超过数据长度""" - data = np.array([1, 2, 3, 4, 5], dtype=np.float64) - result = SMA(data, 10) - assert len(result) == len(data), "周期超过数据长度时,长度应保持不变" - # SMA implementation returns mean for available data - - def test_sma_period_1(self): - """测试周期为1""" - data = np.array([1, 2, 3, 4, 5], dtype=np.float64) - # TA-Lib requires period >= 2 - result = SMA(data, 2) - assert len(result) == len(data), "周期为2时,长度应保持不变" - - -class TestEMA: - """EMA指数移动平均线测试""" - - def test_ema_basic(self): - """测试EMA基础功能""" - df = get_test_data() - ema5 = EMA(df["close"], 5) - ema12 = EMA(df["close"], 12) - ema26 = EMA(df["close"], 26) - - assert len(ema5) == len(df), "EMA返回长度应与输入相同" - assert len(ema12) == len(df), "EMA返回长度应与输入相同" - assert len(ema26) == len(df), "EMA返回长度应与输入相同" - - def test_ema_empty_array(self): - """测试空数组""" - result = EMA(np.array([]), 5) - assert len(result) == 0, "空数组应返回空结果" - - def test_ema_single_value(self): - """测试单值数组""" - result = EMA(np.array([100], dtype=np.float64), 5) - assert len(result) == 1, "单值数组应返回单值结果" - - def test_ema_with_nan(self): - """测试包含NaN的数据""" - data = np.array([1, 2, np.nan, 4, 5, 6, 7, 8, 9, 10], dtype=np.float64) - result = EMA(data, 5) - assert len(result) == len(data), "包含NaN的数据长度应保持不变" - - def test_ema_with_zeros(self): - """测试全0数据""" - data = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.float64) - result = EMA(data, 5) - assert len(result) == len(data), "全0数据长度应保持不变" - - def test_ema_period_1(self): - """测试周期为1""" - data = np.array([1, 2, 3, 4, 5], dtype=np.float64) - # TA-Lib requires period >= 2 - result = EMA(data, 2) - assert len(result) == len(data), "周期为2时,长度应保持不变" - - -class TestMACD: - """MACD指标测试""" - - def test_macd_basic(self): - """测试MACD基础功能""" - df = get_test_data() - diff, dea, macd = MACD(df["close"].values) - - assert len(diff) == len(df), "DIFF返回长度应与输入相同" - assert len(dea) == len(df), "DEA返回长度应与输入相同" - assert len(macd) == len(df), "MACD返回长度应与输入相同" - - # Note: TA-Lib's MACD uses a different formula than simple 2*(DIFF-DEA) - # We just verify the relationships exist - valid_idx = ~pd.isna(diff) & ~pd.isna(dea) & ~pd.isna(macd) - assert valid_idx.sum() > 0, "应有有效的MACD值" - - def test_macd_empty_array(self): - """测试空数组""" - diff, dea, macd = MACD(np.array([])) - assert len(diff) == 0, "空数组应返回空结果" - assert len(dea) == 0, "空数组应返回空结果" - assert len(macd) == 0, "空数组应返回空结果" - - def test_macd_single_value(self): - """测试单值数组""" - diff, dea, macd = MACD(np.array([100], dtype=np.float64)) - assert len(diff) == 1, "单值数组应返回单值结果" - assert len(dea) == 1, "单值数组应返回单值结果" - assert len(macd) == 1, "单值数组应返回单值结果" - - def test_macd_with_nan(self): - """测试包含NaN的数据""" - data = np.array([1, 2, np.nan, 4, 5, 6, 7, 8, 9, 10], dtype=np.float64) - diff, dea, macd = MACD(data) - assert len(diff) == len(data), "包含NaN的数据长度应保持不变" - assert len(dea) == len(data), "包含NaN的数据长度应保持不变" - assert len(macd) == len(data), "包含NaN的数据长度应保持不变" - - def test_macd_with_zeros(self): - """测试全0数据""" - data = np.array([0] * 100, dtype=np.float64) - diff, dea, macd = MACD(data) - assert len(diff) == len(data), "全0数据长度应保持不变" - assert len(dea) == len(data), "全0数据长度应保持不变" - assert len(macd) == len(data), "全0数据长度应保持不变" - - def test_macd_with_inf(self): - """测试包含Inf的数据""" - data = np.array([1, 2, 3, 4, np.inf, 6, 7, 8, 9, 10], dtype=np.float64) - diff, dea, macd = MACD(data) - assert len(diff) == len(data), "包含Inf的数据长度应保持不变" - - def test_macd_constant_values(self): - """测试常量值数据""" - data = np.array([100] * 100, dtype=np.float64) - diff, dea, macd = MACD(data) - # 常量值的MACD应接近0 - assert len(diff) == len(data), "常量数据长度应保持不变" - - -class TestRSI: - """RSI相对强弱指标测试""" - - def test_rsi_basic(self): - """测试RSI基础功能""" - df = get_test_data() - rsi6 = RSI(df["close"], 6) - rsi12 = RSI(df["close"], 12) - rsi24 = RSI(df["close"], 24) - - assert len(rsi6) == len(df), "RSI返回长度应与输入相同" - assert len(rsi12) == len(df), "RSI返回长度应与输入相同" - assert len(rsi24) == len(df), "RSI返回长度应与输入相同" - - # 验证RSI在0-100范围内(忽略NaN) - valid_rsi6 = rsi6[~pd.isna(rsi6)] - assert all((valid_rsi6 >= 0) & (valid_rsi6 <= 100)), "RSI应在0-100范围内" - - def test_rsi_empty_array(self): - """测试空数组""" - result = RSI(np.array([]), 6) - assert len(result) == 0, "空数组应返回空结果" - - def test_rsi_single_value(self): - """测试单值数组""" - result = RSI(np.array([100], dtype=np.float64), 6) - assert len(result) == 1, "单值数组应返回单值结果" - - def test_rsi_with_nan(self): - """测试包含NaN的数据""" - data = np.array([1, 2, np.nan, 4, 5, 6, 7, 8, 9, 10], dtype=np.float64) - result = RSI(data, 6) - assert len(result) == len(data), "包含NaN的数据长度应保持不变" - - def test_rsi_constant_values(self): - """测试常量值数据""" - data = np.array([100] * 100, dtype=np.float64) - result = RSI(data, 6) - assert len(result) == len(data), "常量数据长度应保持不变" - - def test_rsi_all_increasing(self): - """测试持续上涨数据""" - data = np.array(list(range(1, 101)), dtype=np.float64) - result = RSI(data, 6) - # 持续上涨时RSI应接近100 - valid_rsi = result[~pd.isna(result)] - assert len(valid_rsi) > 0, "应有有效的RSI值" - - def test_rsi_all_decreasing(self): - """测试持续下跌数据""" - data = np.array(list(range(100, 0, -1)), dtype=np.float64) - result = RSI(data, 6) - # 持续下跌时RSI应接近0 - valid_rsi = result[~pd.isna(result)] - assert len(valid_rsi) > 0, "应有有效的RSI值" - - -class TestBOLL: - """BOLL布林带测试""" - - def test_boll_basic(self): - """测试BOLL基础功能""" - df = get_test_data() - upper, middle, lower = BOLL(df["close"], 20) - - assert len(upper) == len(df), "上轨返回长度应与输入相同" - assert len(middle) == len(df), "中轨返回长度应与输入相同" - assert len(lower) == len(df), "下轨返回长度应与输入相同" - - # 验证上轨 >= 中轨 >= 下轨(忽略NaN) - valid_idx = ~pd.isna(upper) & ~pd.isna(middle) & ~pd.isna(lower) - assert all(upper[valid_idx] >= middle[valid_idx]), "上轨应大于等于中轨" - assert all(middle[valid_idx] >= lower[valid_idx]), "中轨应大于等于下轨" - - def test_boll_empty_array(self): - """测试空数组""" - upper, middle, lower = BOLL(np.array([]), 20) - assert len(upper) == 0, "空数组应返回空结果" - assert len(middle) == 0, "空数组应返回空结果" - assert len(lower) == 0, "空数组应返回空结果" - - def test_boll_single_value(self): - """测试单值数组""" - upper, middle, lower = BOLL(np.array([100], dtype=np.float64), 20) - assert len(upper) == 1, "单值数组应返回单值结果" - assert len(middle) == 1, "单值数组应返回单值结果" - assert len(lower) == 1, "单值数组应返回单值结果" - - def test_boll_with_nan(self): - """测试包含NaN的数据""" - data = np.array([1, 2, np.nan, 4, 5, 6, 7, 8, 9, 10] * 10, dtype=np.float64) - upper, middle, lower = BOLL(data, 20) - assert len(upper) == len(data), "包含NaN的数据长度应保持不变" - assert len(middle) == len(data), "包含NaN的数据长度应保持不变" - assert len(lower) == len(data), "包含NaN的数据长度应保持不变" - - def test_boll_constant_values(self): - """测试常量值数据""" - data = np.array([100] * 100, dtype=np.float64) - upper, middle, lower = BOLL(data, 20) - # 常量值的BOLL上下轨应接近中轨 - valid_idx = ~pd.isna(upper) & ~pd.isna(middle) & ~pd.isna(lower) - if len(valid_idx) > 0: - # 中轨应该等于常量值 - assert all(middle[valid_idx] == 100), "常量值的中轨应等于该常量" - - def test_boll_with_zeros(self): - """测试全0数据""" - data = np.array([0] * 100, dtype=np.float64) - upper, middle, lower = BOLL(data, 20) - assert len(upper) == len(data), "全0数据长度应保持不变" - assert len(middle) == len(data), "全0数据长度应保持不变" - assert len(lower) == len(data), "全0数据长度应保持不变" - - def test_boll_relationship(self): - """测试布林带上下轨关系""" - data = np.array(list(range(1, 101)), dtype=np.float64) - upper, middle, lower = BOLL(data, 20) - - # 验证上轨 >= 中轨 >= 下轨 - valid_idx = ~pd.isna(upper) & ~pd.isna(middle) & ~pd.isna(lower) - assert all(upper[valid_idx] >= middle[valid_idx]), "上轨应大于等于中轨" - assert all(middle[valid_idx] >= lower[valid_idx]), "中轨应大于等于下轨" - - -class TestATR: - """ATR平均真实波幅测试""" - - def test_atr_basic(self): - """测试ATR基础功能""" - df = get_test_data() - atr = ATR(df["high"], df["low"], df["close"], 14) - - assert len(atr) == len(df), "ATR返回长度应与输入相同" - - # 验证ATR非负性(忽略NaN) - valid_atr = atr[~pd.isna(atr)] - assert all(valid_atr >= 0), "ATR应非负" - - def test_atr_empty_array(self): - """测试空数组""" - df = pd.DataFrame({"high": [], "low": [], "close": []}) - result = ATR(df["high"], df["low"], df["close"], 14) - assert len(result) == 0, "空数组应返回空结果" - - def test_atr_single_value(self): - """测试单值数据""" - df = pd.DataFrame({"high": [100], "low": [90], "close": [95]}) - result = ATR(df["high"], df["low"], df["close"], 14) - assert len(result) == 1, "单值数据应返回单值结果" - - def test_atr_with_nan(self): - """测试包含NaN的数据""" - df = pd.DataFrame( - {"high": [100, 102, np.nan, 106, 108], "low": [90, 92, 94, np.nan, 98], "close": [95, 97, 99, 101, 103]} - ) - result = ATR(df["high"], df["low"], df["close"], 14) - assert len(result) == len(df), "包含NaN的数据长度应保持不变" - - def test_atr_constant_prices(self): - """测试常量价格""" - df = pd.DataFrame({"high": [100] * 50, "low": [100] * 50, "close": [100] * 50}) - result = ATR(df["high"], df["low"], df["close"], 14) - # 常量价格的ATR应为0 - valid_atr = result[~pd.isna(result)] - if len(valid_atr) > 0: - assert all(valid_atr == 0), "常量价格的ATR应为0" - - def test_atr_with_zeros(self): - """测试全0数据""" - df = pd.DataFrame({"high": [0] * 50, "low": [0] * 50, "close": [0] * 50}) - result = ATR(df["high"], df["low"], df["close"], 14) - assert len(result) == len(df), "全0数据长度应保持不变" - - def test_atr_non_negative(self): - """测试ATR非负性""" - df = get_test_data() - atr = ATR(df["high"], df["low"], df["close"], 14) - valid_atr = atr[~pd.isna(atr)] - assert all(valid_atr >= 0), "ATR应始终非负" - - -class TestKDJ: - """KDJ随机指标测试""" - - def test_kdj_basic(self): - """测试KDJ基础功能""" - df = get_test_data() - k, d, j = KDJ(df["close"].values, df["high"].values, df["low"].values) - - assert len(k) == len(df), "K值返回长度应与输入相同" - assert len(d) == len(df), "D值返回长度应与输入相同" - assert len(j) == len(df), "J值返回长度应与输入相同" - - # 验证K/D值在0-100范围内(忽略NaN) - valid_k = k[~pd.isna(k)] - valid_d = d[~pd.isna(d)] - assert all((valid_k >= 0) & (valid_k <= 100)), "K值应在0-100范围内" - assert all((valid_d >= 0) & (valid_d <= 100)), "D值应在0-100范围内" - - def test_kdj_empty_array(self): - """测试空数组""" - df = pd.DataFrame({"high": [], "low": [], "close": []}) - k, d, j = KDJ(df["close"].values, df["high"].values, df["low"].values) - assert len(k) == 0, "空数组应返回空结果" - assert len(d) == 0, "空数组应返回空结果" - assert len(j) == 0, "空数组应返回空结果" - - def test_kdj_single_value(self): - """测试单值数据""" - df = pd.DataFrame({"high": [100], "low": [90], "close": [95]}) - k, d, j = KDJ(df["close"].values, df["high"].values, df["low"].values) - assert len(k) == 1, "单值数据应返回单值结果" - assert len(d) == 1, "单值数据应返回单值结果" - assert len(j) == 1, "单值数据应返回单值结果" - - def test_kdj_with_nan(self): - """测试包含NaN的数据""" - df = pd.DataFrame( - {"high": [100, 102, np.nan, 106, 108], "low": [90, 92, 94, np.nan, 98], "close": [95, 97, 99, 101, 103]} - ) - k, d, j = KDJ(df["close"].values, df["high"].values, df["low"].values) - assert len(k) == len(df), "包含NaN的数据长度应保持不变" - assert len(d) == len(df), "包含NaN的数据长度应保持不变" - assert len(j) == len(df), "包含NaN的数据长度应保持不变" - - def test_kdj_constant_prices(self): - """测试常量价格""" - df = pd.DataFrame({"high": [100] * 50, "low": [100] * 50, "close": [100] * 50}) - k, d, j = KDJ(df["close"].values, df["high"].values, df["low"].values) - # 常量价格的KDJ应在50附近(超买超卖中间值) - assert len(k) == len(df), "常量数据长度应保持不变" - - def test_kdj_range(self): - """测试KDJ取值范围""" - df = get_test_data() - k, d, j = KDJ(df["close"].values, df["high"].values, df["low"].values) - - # 验证K/D在0-100范围 - valid_k = k[~pd.isna(k)] - valid_d = d[~pd.isna(d)] - assert all((valid_k >= 0) & (valid_k <= 100)), "K值应在0-100范围内" - assert all((valid_d >= 0) & (valid_d <= 100)), "D值应在0-100范围内" - - -class TestIndicatorsIntegration: - """技术指标综合测试""" - - def test_indicators_with_real_data_consistency(self): - """测试技术指标在真实数据上的一致性""" - df = get_test_data() - - # 测试多个指标 - sma = SMA(df["close"].values, 20) - ema = EMA(df["close"].values, 20) - diff, dea, macd = MACD(df["close"].values) - rsi = RSI(df["close"].values, 14) - upper, middle, lower = BOLL(df["close"].values, 20) - atr = ATR(df["high"], df["low"], df["close"], 14) - k, d, j = KDJ(df["close"].values, df["high"].values, df["low"].values) - - # 验证所有指标长度一致 - assert len(sma) == len(df), "SMA长度应一致" - assert len(ema) == len(df), "EMA长度应一致" - assert len(diff) == len(df), "MACD DIFF长度应一致" - assert len(rsi) == len(df), "RSI长度应一致" - assert len(upper) == len(df), "BOLL上轨长度应一致" - assert len(atr) == len(df), "ATR长度应一致" - assert len(k) == len(df), "KDJ K值长度应一致" - - def test_indicators_with_mixed_nan_inf(self): - """测试技术指标对混合NaN/Inf数据的处理""" - data = np.array([1, 2, np.nan, 4, np.inf, 6, -np.inf, 8, 9, 10] * 10) - - # 测试不会崩溃 - sma = SMA(data, 5) - ema = EMA(data, 5) - diff, dea, macd = MACD(data) - rsi = RSI(data, 6) - - assert len(sma) == len(data), "SMA应能处理混合数据" - assert len(ema) == len(data), "EMA应能处理混合数据" - assert len(diff) == len(data), "MACD应能处理混合数据" - assert len(rsi) == len(data), "RSI应能处理混合数据" diff --git a/test/unit/__init__.py b/test/unit/__init__.py new file mode 100644 index 000000000..f4a0bf67f --- /dev/null +++ b/test/unit/__init__.py @@ -0,0 +1,12 @@ +"""单元测试包。 + +本目录用于存放单元测试用例(unit tests),针对单个函数、类或模块的最小 +功能单元进行隔离测试。单元测试应做到: + +- 运行速度快,单个用例通常在毫秒级完成 +- 不依赖外部资源(数据库、网络、文件系统副作用等) +- 通过 Mock / 桩对象隔离外部依赖 +- 失败时能够精确定位到被测代码的最小单元 + +测试文件命名约定:``test_*.py``。 +""" diff --git a/test/unit/snapshots/core_parity_seed42.json b/test/unit/snapshots/core_parity_seed42.json new file mode 100644 index 000000000..ce95b74f5 --- /dev/null +++ b/test/unit/snapshots/core_parity_seed42.json @@ -0,0 +1,277 @@ +{ + "seed": 42, + "symbol": "000001", + "freq": "30分钟", + "sdt": "20240101", + "edt": "20240301", + "bars_count": 610, + "fx_list_count": 175, + "bi_list_count": 43, + "fx_marks": [ + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型", + "顶分型", + "底分型" + ], + "bi_directions": [ + "向下", + "向上", + "向下", + "向上", + "向下", + "向上", + "向下", + "向上", + "向下", + "向上", + "向下", + "向上", + "向下", + "向上", + "向下", + "向上", + "向下", + "向上", + "向下", + "向上", + "向下", + "向上", + "向下", + "向上", + "向下", + "向上", + "向下", + "向上", + "向下", + "向上", + "向下", + "向上", + "向下", + "向上", + "向下", + "向上", + "向下", + "向上", + "向下", + "向上", + "向下", + "向上", + "向下" + ], + "bi_lengths": [ + 8, + 8, + 7, + 15, + 7, + 19, + 7, + 7, + 11, + 8, + 6, + 17, + 6, + 9, + 8, + 21, + 8, + 12, + 13, + 11, + 7, + 9, + 7, + 11, + 7, + 6, + 6, + 8, + 9, + 6, + 10, + 16, + 16, + 9, + 7, + 6, + 23, + 20, + 10, + 7, + 23, + 9, + 14 + ] +} diff --git a/test/unit/test_core_parity.py b/test/unit/test_core_parity.py new file mode 100644 index 000000000..579816c83 --- /dev/null +++ b/test/unit/test_core_parity.py @@ -0,0 +1,141 @@ +"""CZSC 核心算法(FX/BI/ZS)一致性单元测试。 + +本测试套件验证迁移后的 czsc-core(Rust 实现)在固定输入下产生的 +分型(FX)、笔(BI)、中枢(ZS)结果,与基线快照(baseline snapshot) +逐字节一致(byte-for-byte identical)。 + +业务背景: + 缠论的核心识别算法对最终交易信号有决定性影响,任何细微的实现差异都 + 可能导致大量信号漂移。因此在从 ``rs_czsc`` 迁移到 in-repo 的 + ``czsc._native`` 过程中,必须保证算法输出与一个锁定的基线( + rs-czsc commit ``47ef6efa``,seed=42)完全一致。 + +测试覆盖: + - ``czsc.CZSC`` 的来源必须是 ``czsc._native``,而不是外部 ``rs_czsc``; + - 在固定 mock 数据(seed=42)上构造 CZSC 对象后: + * 分型数量与基线一致 + * 笔数量与基线一致 + * 分型方向序列与基线一致 + * 笔方向序列与基线一致 + * 笔长度序列与基线一致 +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +# 基线快照文件路径,存放固定输入下应得的标准输出 +SNAPSHOT_PATH = Path(__file__).parent / "snapshots" / "core_parity_seed42.json" + + +def _load_snapshot() -> dict[str, Any]: + """加载基线快照 JSON 文件并返回字典。""" + return json.loads(SNAPSHOT_PATH.read_text(encoding="utf-8")) + + +def _build_czsc() -> tuple[Any | None, str | None]: + """构造一个用于一致性比对的 CZSC 实例。 + + 使用固定参数(symbol='000001',30 分钟,2024-01-01~2024-03-01,seed=42) + 生成可重现的 mock K 线数据,再转换为 RawBar 列表,最后构造 CZSC 对象。 + + 返回值: + 构造成功时返回 ``(czsc_obj, None)``;任何异常情况下返回 + ``(None, 错误描述)``,避免在 fixture 阶段直接抛出导致测试 ERROR + 而不是 FAIL。 + """ + try: + import czsc + from czsc.mock import generate_symbol_kines + + df = generate_symbol_kines( + "000001", "30分钟", "20240101", "20240301", seed=42 + ) + bars = czsc.format_standard_kline(df, freq=czsc.Freq.F30) + return czsc.CZSC(bars), None + except Exception as exc: # noqa: BLE001 + return None, f"build failed: {type(exc).__name__}: {exc}" + + +def test_czsc_source_is_in_repo_native() -> None: + """验证 czsc.CZSC 类来源于 czsc._native,而不是外部 rs_czsc。 + + 测试目标: + 迁移目标之一是彻底剥离对 ``rs_czsc`` PyPI 包的依赖, + 因此 ``CZSC`` 类的 ``__module__`` 必须以 ``czsc.`` 开头。 + + 关键断言: + ``type(obj).__module__`` 字符串以 ``"czsc."`` 开头。 + """ + obj, err = _build_czsc() + assert obj is not None, err + module = type(obj).__module__ + assert module.startswith("czsc."), ( + f"czsc.CZSC 必须来自 czsc._native(实际:{module!r});" + "迁移目标要求完全移除 rs_czsc PyPI 依赖。" + ) + + +def test_fx_list_count_matches_baseline() -> None: + """验证识别出的分型(FX)数量与基线一致。""" + obj, err = _build_czsc() + assert obj is not None, err + snap = _load_snapshot() + assert len(obj.fx_list) == snap["fx_list_count"], ( + f"FX 数量出现漂移:实际 {len(obj.fx_list)},基线 " + f"{snap['fx_list_count']}" + ) + + +def test_bi_list_count_matches_baseline() -> None: + """验证识别出的笔(BI)数量与基线一致。""" + obj, err = _build_czsc() + assert obj is not None, err + snap = _load_snapshot() + assert len(obj.bi_list) == snap["bi_list_count"], ( + f"BI 数量出现漂移:实际 {len(obj.bi_list)},基线 " + f"{snap['bi_list_count']}" + ) + + +def test_fx_marks_sequence_matches_baseline() -> None: + """验证分型方向序列(顶分型 G / 底分型 D)逐项与基线一致。 + + 关键断言: + 将 ``obj.fx_list`` 中每个 FX 的 ``mark`` 字段转字符串后形成的列表, + 必须与基线快照中的 ``fx_marks`` 列表完全相等。 + """ + obj, err = _build_czsc() + assert obj is not None, err + snap = _load_snapshot() + actual = [str(fx.mark) for fx in obj.fx_list] + assert actual == snap["fx_marks"], ( + f"FX 方向序列出现漂移;首个差异下标 = " + f"{next((i for i, (a, b) in enumerate(zip(actual, snap['fx_marks'])) if a != b), 'len mismatch')}" + ) + + +def test_bi_directions_sequence_matches_baseline() -> None: + """验证笔方向序列(向上笔 Up / 向下笔 Down)逐项与基线一致。""" + obj, err = _build_czsc() + assert obj is not None, err + snap = _load_snapshot() + actual = [str(bi.direction) for bi in obj.bi_list] + assert actual == snap["bi_directions"], ( + f"BI 方向序列出现漂移;首个差异下标 = " + f"{next((i for i, (a, b) in enumerate(zip(actual, snap['bi_directions'])) if a != b), 'len mismatch')}" + ) + + +def test_bi_lengths_sequence_matches_baseline() -> None: + """验证每一笔的长度(包含 K 线根数)序列与基线一致。""" + obj, err = _build_czsc() + assert obj is not None, err + snap = _load_snapshot() + actual = [bi.length for bi in obj.bi_list] + assert actual == snap["bi_lengths"], ( + f"BI 长度序列出现漂移;期望 {snap['bi_lengths'][:5]}…," + f"实际 {actual[:5]}…" + ) diff --git a/test/unit/test_pickle.py b/test/unit/test_pickle.py new file mode 100644 index 000000000..05bc6bfd5 --- /dev/null +++ b/test/unit/test_pickle.py @@ -0,0 +1,137 @@ +"""PyO3 类的 pickle 往返序列化单元测试。 + +本测试套件验证 czsc 暴露的所有 PyO3 类(Rust 实现 + Python 绑定)都支持 +``pickle.dumps`` / ``pickle.loads`` 往返序列化(roundtrip),从而保证它们 +能够安全地穿越多进程边界(multiprocessing boundary)。 + +业务背景: + 在 Streamlit、Joblib、Dask 等多进程场景下,工作进程之间通常通过 pickle + 协议传递对象。如果某个 PyO3 类没有实现 ``__getstate__`` / + ``__setstate__``,则在跨进程传递时会失败,导致并行化能力受损。 + +测试策略: + 采用 "FAIL 而不是 ERROR" 的策略:在 fixture 阶段捕获导入和构造异常, + 并通过 ``pytest.fail`` 转换为断言失败。这样 CI 能直接看到失败原因, + 而不是因为构造阶段抛异常使测试以 ERROR 状态结束、被部分工具忽略。 + +测试覆盖: + - ``CZSC``: 缠论分析核心类; + - ``BarGenerator``: K 线生成器; + - ``Position``: 持仓对象; + - ``CzscSignals`` / ``CzscTrader``: 信号与交易容器类。 +""" + +from __future__ import annotations + +import pickle +from typing import Any + +import pytest + + +def _try_build_bars() -> tuple[Any | None, str | None]: + """通过 czsc.mock + format_standard_kline 构造小批量 RawBar 列表。 + + 使用固定 seed=42 与短时间区间以保证测试快速、结果稳定。 + + 返回值: + - 成功:``(bars, None)``; + - 失败:``(None, error_msg)``。 + + 通过返回元组而非抛异常,调用方可以将失败转化为 AssertionError, + 避免 fixture 阶段直接抛异常造成 pytest 报告为 ERROR 的情况。 + """ + try: + from czsc.mock import generate_symbol_kines # type: ignore[attr-defined] + from czsc import format_standard_kline, Freq # type: ignore[attr-defined] + except Exception as exc: # noqa: BLE001 + return None, f"import failed: {type(exc).__name__}: {exc}" + + try: + df = generate_symbol_kines("000001", "30分钟", "20240101", "20240105", seed=42) + bars = format_standard_kline(df, freq=Freq.F30) + return bars, None + except Exception as exc: # noqa: BLE001 + return None, f"bar generation failed: {type(exc).__name__}: {exc}" + + +def _build_obj(name: str, bars: Any) -> tuple[Any | None, str | None]: + """根据目标类名构造一个待 pickle 的实例(不抛异常版本)。 + + 支持构造的目标类: + - ``CZSC``: 直接使用 RawBar 列表构造; + - ``BarGenerator``: 仅初始化基础频率与目标频率列表; + - ``Position``: 构造一个空的开平仓策略对象; + - ``CzscSignals``/``CzscTrader``: 需要先用 BarGenerator 喂入 bars, + 再分别传入信号配置与持仓策略列表。 + """ + try: + import czsc + + if name == "CZSC": + return czsc.CZSC(bars), None + if name == "BarGenerator": + return czsc.BarGenerator(base_freq="30分钟", freqs=["日线"]), None + if name == "Position": + return ( + czsc.Position(symbol="000001", name="t", opens=[], exits=[]), + None, + ) + if name in ("CzscSignals", "CzscTrader"): + # CzscSignals 与 CzscTrader 都要求传入一个已经接收过 K 线、 + # 处于"已就绪"状态的 BarGenerator,以及信号配置列表; + # 其中 CzscTrader 还额外要求一个持仓策略列表。 + bg = czsc.BarGenerator(base_freq="30分钟", freqs=["日线"]) + for bar in bars: + bg.update(bar) + if name == "CzscSignals": + return czsc.CzscSignals(bg, []), None + return czsc.CzscTrader(bg, [], []), None + except Exception as exc: # noqa: BLE001 + return None, f"construction failed: {type(exc).__name__}: {exc}" + return None, f"unknown target: {name}" + + +# 参数化覆盖所有需要支持 pickle 的核心 PyO3 类 +@pytest.mark.parametrize( + "target", + ["CZSC", "BarGenerator", "Position", "CzscSignals", "CzscTrader"], +) +def test_pickle_roundtrip(target: str) -> None: + """对每个 PyO3 类执行完整的 pickle dump → load 往返。 + + 测试场景: + 1. 构造测试用 RawBar 列表与目标对象; + 2. 调用 ``pickle.dumps`` 序列化对象到字节串; + 3. 调用 ``pickle.loads`` 反序列化恢复对象; + 4. 校验类型与内部状态保持一致。 + + 关键断言: + - 反序列化后对象类型与原对象完全一致; + - 若两者都实现了 ``__getstate__``,反序列化后状态字典相等。 + """ + bars, err = _try_build_bars() + if err is not None: + pytest.fail(f"[{target}] {err}") # 转化为 FAIL,避免 ERROR + + obj, build_err = _build_obj(target, bars) + if obj is None: + pytest.fail(f"[{target}] {build_err}") + + try: + blob = pickle.dumps(obj) + restored = pickle.loads(blob) + except Exception as exc: # noqa: BLE001 + pytest.fail( + f"[{target}] pickle 往返序列化抛出异常 " + f"{type(exc).__name__}: {exc}" + ) + + assert type(restored) is type(obj), ( + f"[{target}] 往返序列化改变了对象类型:{type(obj)} → {type(restored)}" + ) + + if hasattr(obj, "__getstate__") and hasattr(restored, "__getstate__"): + assert restored.__getstate__() == obj.__getstate__(), ( + f"[{target}] __getstate__ 在往返序列化后发生变化" + ) diff --git a/test/unit/test_signals_parity.py b/test/unit/test_signals_parity.py new file mode 100644 index 000000000..b3c92daab --- /dev/null +++ b/test/unit/test_signals_parity.py @@ -0,0 +1,128 @@ +"""信号函数命名空间一致性单元测试。 + +本测试套件验证 ``czsc.signals.*`` 命名空间下的信号函数已完整迁移到 +in-repo 的 Rust 扩展 ``czsc._native.signals``,并满足以下契约: + +业务背景: + czsc 的信号体系按类别组织(bar / cxt / tas / vol / pressure / obv / cvolp), + 每个子包提供若干信号函数。早期版本中部分实现来自外部 ``rs_czsc`` 包, + 迁移过程要求所有信号函数统一来自 in-repo 的 Rust 扩展,并通过薄层 + 重导出(thin re-export)暴露在 Python 命名空间。 + +测试覆盖: + - 每个必需子包都可以被 import; + - 每个子包至少暴露 ``MIN_FUNCS_PER_SUBPACKAGE`` 个可调用对象; + - 每个可调用对象的 ``__module__`` 必须以 ``czsc.`` 开头(来源于 czsc._native); + - ``czsc._native`` 必须存在 ``signals`` 子模块。 + +注意: + 本文件只锁定**命名空间契约**(namespace contract),即"接口存在且来源 + 正确";逐个信号函数的数值一致性(per-function value parity)将在 + 迁移过程中通过其它测试逐步补齐。 +""" + +from __future__ import annotations + +import importlib +from typing import Any + +import pytest + +# 信号体系按类别组织,下列子包均必须存在 +REQUIRED_SUBPACKAGES = ("bar", "cxt", "tas", "vol", "pressure", "obv", "cvolp") +# 每个子包至少应暴露的公开可调用对象数量阈值 +MIN_FUNCS_PER_SUBPACKAGE = 3 + + +def _safe_import(name: str) -> tuple[Any | None, str | None]: + """安全导入指定模块,捕获所有异常并返回 (module, err) 元组。""" + try: + return importlib.import_module(name), None + except Exception as exc: # noqa: BLE001 + return None, f"{type(exc).__name__}: {exc}" + + +# 参数化覆盖所有必需信号子包,每个子包验证其可被导入 +@pytest.mark.parametrize("sub", REQUIRED_SUBPACKAGES) +def test_signal_subpackage_exists(sub: str) -> None: + """验证给定信号子包在迁移完成后可被成功 import。 + + 关键断言: + ``importlib.import_module(f"czsc.signals.{sub}")`` 不抛异常且返回非 None。 + """ + mod, err = _safe_import(f"czsc.signals.{sub}") + assert mod is not None, ( + f"czsc.signals.{sub} 必须可被导入({err})" + ) + + +# 参数化覆盖所有必需信号子包,每个子包验证其暴露足够数量的函数 +@pytest.mark.parametrize("sub", REQUIRED_SUBPACKAGES) +def test_signal_subpackage_has_functions(sub: str) -> None: + """验证给定信号子包至少暴露 ``MIN_FUNCS_PER_SUBPACKAGE`` 个可调用对象。 + + 测试场景: + 通过 ``dir(mod)`` 列出所有非下划线开头的属性,过滤出可调用对象, + 并对其数量做下限校验。 + + 关键断言: + 子包公开的可调用对象数量 ≥ ``MIN_FUNCS_PER_SUBPACKAGE``。 + """ + mod, err = _safe_import(f"czsc.signals.{sub}") + if mod is None: + pytest.fail(f"无法导入 czsc.signals.{sub}: {err}") + funcs = [ + name for name in dir(mod) + if not name.startswith("_") and callable(getattr(mod, name)) + ] + assert len(funcs) >= MIN_FUNCS_PER_SUBPACKAGE, ( + f"czsc.signals.{sub} 必须至少暴露 {MIN_FUNCS_PER_SUBPACKAGE} 个 " + f"信号函数;实际找到 {len(funcs)} 个:{funcs}" + ) + + +# 参数化覆盖所有必需信号子包,每个子包验证函数来源 +@pytest.mark.parametrize("sub", REQUIRED_SUBPACKAGES) +def test_signal_subpackage_sourced_from_native(sub: str) -> None: + """验证子包内每个函数都来源于 czsc.* 命名空间,而不是外部包。 + + 测试目标: + 确保 czsc.signals 是 ``czsc._native.signals`` 的薄重导出层, + 而不是含有来自 ``rs_czsc`` 等外部模块的实现。 + + 关键断言: + 遍历每个公开可调用对象,其 ``__module__`` 必须以 ``czsc.`` 开头。 + """ + mod, err = _safe_import(f"czsc.signals.{sub}") + if mod is None: + pytest.fail(f"无法导入 czsc.signals.{sub}: {err}") + funcs = [ + getattr(mod, n) for n in dir(mod) + if not n.startswith("_") and callable(getattr(mod, n)) + ] + if not funcs: + pytest.fail(f"czsc.signals.{sub} 中没有任何可调用对象") + + foreign = [f for f in funcs if not getattr(f, "__module__", "").startswith("czsc.")] + assert not foreign, ( + f"czsc.signals.{sub} 中包含 {len(foreign)} 个来源于 czsc.* 之外的函数 " + f"(例如 {foreign[0].__module__}.{foreign[0].__name__});" + "必须全部从 czsc._native.signals 重导出" + ) + + +def test_native_signals_module_exists() -> None: + """验证 czsc._native 已注册 signals 子模块。 + + 测试目标: + Rust 扩展 ``czsc._native`` 在编译时通过 PyO3 注册了名为 ``signals`` + 的子模块,作为信号函数的最终来源。 + + 关键断言: + ``czsc._native`` 模块存在且具备 ``signals`` 属性。 + """ + native, err = _safe_import("czsc._native") + assert native is not None, f"czsc._native 必须存在(maturin 构建产物)({err})" + assert hasattr(native, "signals"), ( + "czsc._native.signals 必须是已注册的 PyO3 子模块" + ) diff --git a/test/unit/test_ta_parity.py b/test/unit/test_ta_parity.py new file mode 100644 index 000000000..e517d1537 --- /dev/null +++ b/test/unit/test_ta_parity.py @@ -0,0 +1,166 @@ +"""技术指标算子(TA Operators)与 Python TA-Lib 一致性单元测试。 + +本测试套件验证迁移后的 ``czsc.ta.*`` Rust + PyO3 算子在数值上与 +Python 版本的 ``talib`` 库结果保持高精度一致(相对误差 / 绝对误差 +均小于 1e-6),覆盖核心技术指标算子。 + +业务背景: + 历史上 czsc 在 ``czsc.utils.ta`` 中提供了一层 Python 对 TA-Lib 的薄包装。 + 迁移目标是用纯 Rust 实现替换该层,并通过 PyO3 暴露为 ``czsc._native.ta``, + 再由 ``czsc.ta`` 重导出。在替换过程中,必须保证以下指标的输出与 talib + 在相同输入下数值上完全一致: + + - ``ema``:指数移动平均 + - ``sma``:简单移动平均 + - ``rolling_rank``:滚动百分位排名 + - ``boll_positions``:布林通道位置 + - ``ultimate_smoother``:终极平滑器 + +测试策略: + 采用 "FAIL 而不是 ERROR" 策略:导入和数值操作的异常被捕获并在测试函数 + 内部转换为 ``pytest.fail``,这样 CI 报告会显示具体失败原因而不是 ERROR。 +""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import pytest + + +def _safe_import(name: str) -> tuple[Any | None, str | None]: + """安全导入指定模块,捕获所有异常并返回 (module, err) 元组。""" + try: + return __import__(name, fromlist=["__init__"]), None + except Exception as exc: # noqa: BLE001 + return None, f"{type(exc).__name__}: {exc}" + + +def _series(seed: int = 42, n: int = 1024) -> np.ndarray: + """生成可重现的标准正态分布随机序列,用作算子的输入数据。 + + 使用 numpy 的 ``default_rng`` 生成 1024 个 float64 样本, + 默认 seed=42 保证不同运行间结果完全一致。 + """ + return np.random.default_rng(seed).standard_normal(n).astype(np.float64) + + +def _native_module() -> tuple[Any | None, str | None]: + """便捷封装:尝试导入 ``czsc.ta`` 模块。""" + return _safe_import("czsc.ta") + + +def test_ta_module_sourced_from_native() -> None: + """验证 czsc.ta 来自 Rust 扩展子模块,而不是 Python TA-Lib 包装。 + + 测试目标: + ``czsc.ta`` 必须由 ``czsc._native.ta`` 提供, + 旧的 Python 版 ``czsc.utils.ta`` 已被移除。 + + 关键断言: + 模块的 ``__file__`` 路径中包含 ``czsc/_native``,或 ``__name__`` 以 + ``czsc._native`` 开头。 + """ + ta, err = _native_module() + assert ta is not None, f"czsc.ta 必须存在({err})" + module_name = getattr(ta, "__name__", "?") + file_path = getattr(ta, "__file__", "") or "" + assert "czsc/_native" in file_path or module_name.startswith("czsc._native"), ( + f"czsc.ta 必须来自 czsc._native.ta(实际 module={module_name!r}, " + f"file={file_path!r});旧的 czsc.utils.ta 包装层必须移除" + ) + + +def test_ema_matches_talib() -> None: + """验证 czsc.ta.ema 输出与 talib.EMA 在 timeperiod=14 时数值一致。 + + 测试场景: + 在同一随机序列上分别调用两端的 EMA,跳过算子预热期(前 20 个点), + 对剩余结果做高精度数值比较。 + + 关键断言: + ``np.testing.assert_allclose(actual[20:], expected[20:], rtol=1e-6, atol=1e-6)``。 + """ + ta, err = _native_module() + if ta is None: + pytest.fail(f"czsc.ta 不可用:{err}") + talib_mod, terr = _safe_import("talib") + if talib_mod is None: + pytest.fail(f"talib 不可用:{terr}") + + series = _series() + expected = talib_mod.EMA(series, timeperiod=14) + if not hasattr(ta, "ema"): + pytest.fail("czsc.ta.ema 尚未暴露") + actual = ta.ema(series, length=14) + np.testing.assert_allclose( + np.asarray(actual)[20:], expected[20:], rtol=1e-6, atol=1e-6 + ) + + +def test_sma_matches_talib() -> None: + """验证 czsc.ta.sma 输出与 talib.SMA 在 timeperiod=20 时数值一致。 + + 测试场景:与 EMA 一致,跳过预热期后做高精度比较。 + """ + ta, err = _native_module() + if ta is None: + pytest.fail(f"czsc.ta 不可用:{err}") + talib_mod, terr = _safe_import("talib") + if talib_mod is None: + pytest.fail(f"talib 不可用:{terr}") + + series = _series() + expected = talib_mod.SMA(series, timeperiod=20) + if not hasattr(ta, "sma"): + pytest.fail("czsc.ta.sma 尚未暴露") + actual = ta.sma(series, length=20) + np.testing.assert_allclose( + np.asarray(actual)[20:], expected[20:], rtol=1e-6, atol=1e-6 + ) + + +def test_rolling_rank_returns_finite() -> None: + """验证 czsc.ta.rolling_rank 在预热期之后输出有限值(不出现 NaN/Inf)。 + + 关键断言: + ``np.isfinite(out[20:]).all()`` 为真,确保算子在窗口建立后能产出稳定数值。 + """ + ta, err = _native_module() + if ta is None: + pytest.fail(f"czsc.ta 不可用:{err}") + if not hasattr(ta, "rolling_rank"): + pytest.fail("czsc.ta.rolling_rank 尚未暴露") + out = np.asarray(ta.rolling_rank(_series(), window=20)) + assert np.isfinite(out[20:]).all(), ( + "rolling_rank 在预热窗口之后必须产出有限值" + ) + + +def test_boll_positions_signature() -> None: + """验证 czsc.ta 暴露了 boll_positions(布林通道位置)算子。 + + 关键断言: + ``hasattr(ta, "boll_positions")`` 为真。 + """ + ta, err = _native_module() + if ta is None: + pytest.fail(f"czsc.ta 不可用:{err}") + assert hasattr(ta, "boll_positions"), ( + "czsc.ta.boll_positions 必须暴露" + ) + + +def test_ultimate_smoother_signature() -> None: + """验证 czsc.ta 暴露了 ultimate_smoother(终极平滑器)算子。 + + 关键断言: + ``hasattr(ta, "ultimate_smoother")`` 为真。 + """ + ta, err = _native_module() + if ta is None: + pytest.fail(f"czsc.ta 不可用:{err}") + assert hasattr(ta, "ultimate_smoother"), ( + "czsc.ta.ultimate_smoother 必须暴露" + ) diff --git a/test/unit/test_trading_time.py b/test/unit/test_trading_time.py new file mode 100644 index 000000000..224891b81 --- /dev/null +++ b/test/unit/test_trading_time.py @@ -0,0 +1,106 @@ +"""``is_trading_time`` 跨市场交易时段判断单元测试。 + +本测试套件验证 ``czsc.is_trading_time`` 函数能够正确识别 A 股、港股和 +数字货币三大市场的可交易时间段。 + +业务背景: + ``is_trading_time`` 是仅在 czsc 中提供(rs-czsc 不包含)的实用函数, + 迁移过程中由 Rust 实现并通过 PyO3 暴露在 ``czsc._native`` 命名空间, + 最终重导出为 ``czsc.is_trading_time``。 + +各市场交易时段(本地时区): + - A 股 (astock):周一至周五 09:30-11:30 + 13:00-15:00(北京时间) + - 港股 (hk):周一至周五 09:30-12:00 + 13:00-16:00(香港时间) + - 数字货币 (crypto):全年 7×24 小时可交易 + +测试覆盖: + - 三个市场在工作日内边界点(开盘/收盘/午休前后)的判定; + - A 股周末(非交易日)的拒绝判定; + - 数字货币的"始终可交易"语义; + - ``is_trading_time`` 函数的来源必须是 czsc._native(Rust 实现)。 +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +import pytest + + +def _safe_import_czsc() -> tuple[Any | None, str | None]: + """安全导入 czsc 顶层包,捕获所有异常并返回 (czsc, err) 元组。""" + try: + import czsc + + return czsc, None + except Exception as exc: # noqa: BLE001 + return None, f"{type(exc).__name__}: {exc}" + + +# 参数化用例覆盖三大市场的关键时间边界点 +@pytest.mark.parametrize( + "market, dt, expected", + [ + # A 股 — 2024-01-08 周一(正常交易日):覆盖上下午开盘、午休、收盘前后 + ("astock", datetime(2024, 1, 8, 9, 30), True), + ("astock", datetime(2024, 1, 8, 10, 0), True), + ("astock", datetime(2024, 1, 8, 11, 30), True), + ("astock", datetime(2024, 1, 8, 12, 30), False), + ("astock", datetime(2024, 1, 8, 13, 0), True), + ("astock", datetime(2024, 1, 8, 15, 0), True), + ("astock", datetime(2024, 1, 8, 15, 30), False), + ("astock", datetime(2024, 1, 6, 10, 0), False), # 周六,非交易日 + # 港股 — 2024-01-08 周一:覆盖开盘、午休、收盘 + ("hk", datetime(2024, 1, 8, 9, 30), True), + ("hk", datetime(2024, 1, 8, 12, 0), False), + ("hk", datetime(2024, 1, 8, 16, 0), True), + # 数字货币 — 任何时间均可交易(包括周末和节假日) + ("crypto", datetime(2024, 1, 6, 3, 0), True), + ("crypto", datetime(2024, 12, 25, 0, 0), True), + ], +) +def test_is_trading_time(market: str, dt: datetime, expected: bool) -> None: + """对每个市场 / 时间点组合验证 is_trading_time 的判定结果。 + + 测试场景: + 参数化执行 13 个用例,覆盖 A 股、港股、数字货币三大市场在 + 不同时间段的边界判定。 + + 关键断言: + ``czsc.is_trading_time(dt, market=market)`` 返回值与预期布尔值完全相等 + (使用 ``is`` 比较,确保返回的是布尔类型而非 truthy 值)。 + """ + czsc, err = _safe_import_czsc() + if czsc is None: + pytest.fail(f"导入 czsc 失败:{err}") + if not hasattr(czsc, "is_trading_time"): + pytest.fail( + "czsc.is_trading_time 尚未暴露 — czsc-utils 必须添加该函数" + ) + actual = czsc.is_trading_time(dt, market=market) + assert actual is expected, ( + f"is_trading_time({market}, {dt.isoformat()}) 返回 {actual}," + f"预期 {expected}" + ) + + +def test_is_trading_time_module_origin() -> None: + """验证 is_trading_time 来自 czsc._native(Rust 实现),而非 Python helper。 + + 测试目标: + 确保该函数已经走 Rust 实现路径,而不是仍由旧的 Python 工具函数提供。 + + 关键断言: + ``is_trading_time.__module__`` 字符串以 ``"czsc."`` 开头。 + """ + czsc, err = _safe_import_czsc() + if czsc is None: + pytest.fail(f"导入 czsc 失败:{err}") + fn = getattr(czsc, "is_trading_time", None) + if fn is None: + pytest.fail("czsc.is_trading_time 缺失") + module = getattr(fn, "__module__", "?") + assert module.startswith("czsc."), ( + f"is_trading_time 必须来自 czsc._native(实际 {module!r})" + ) diff --git a/tests/test_workspace_layout.sh b/tests/test_workspace_layout.sh new file mode 100755 index 000000000..249cd1f37 --- /dev/null +++ b/tests/test_workspace_layout.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash +# Phase B.1 — RED test: workspace must contain 9 named crates and build cleanly. +# See docs/superpowers/specs/2026-05-03-rust-czsc-migration-design.md §1, §2. +set -euo pipefail + +required=( + czsc-core + czsc-utils + czsc-ta + czsc-signals + czsc-trader + czsc-signal-macros + error-macros + error-support + czsc-python +) + +echo "[1/3] Checking crate file layout..." +for c in "${required[@]}"; do + test -f "crates/$c/Cargo.toml" || { echo " MISSING: crates/$c/Cargo.toml"; exit 1; } + test -f "crates/$c/src/lib.rs" || { echo " MISSING: crates/$c/src/lib.rs"; exit 1; } +done +echo " OK: 9 crate file layouts present" + +echo "[2/3] Checking cargo workspace metadata..." +cargo metadata --format-version 1 --no-deps 2>/dev/null \ + | python3 -c " +import json, sys +data = json.load(sys.stdin) +members = {pkg['name'] for pkg in data['packages']} +required = set('${required[*]}'.split()) +missing = required - members +if missing: + print(f' MISSING from workspace: {sorted(missing)}', file=sys.stderr) + sys.exit(1) +print(f' OK: workspace registers {len(members)} crate(s)') +" + +echo "[3/3] Running cargo build --workspace..." +cargo build --workspace --quiet +echo " OK: cargo build --workspace passed" diff --git a/uv.lock b/uv.lock index 1ca47d7ca..d354bace4 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.14'", @@ -926,6 +926,7 @@ wheels = [ [[package]] name = "czsc" +version = "1.0.0" source = { editable = "." } dependencies = [ { name = "clickhouse-connect" }, @@ -952,7 +953,6 @@ dependencies = [ { name = "reportlab" }, { name = "requests" }, { name = "requests-toolbelt" }, - { name = "rs-czsc" }, { name = "scikit-learn" }, { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.python.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "scipy", version = "1.16.2", source = { registry = "https://pypi.python.org/simple" }, marker = "python_full_version >= '3.11'" }, @@ -962,6 +962,7 @@ dependencies = [ { name = "ta-lib" }, { name = "tenacity" }, { name = "tqdm" }, + { name = "wbt" }, ] [package.optional-dependencies] @@ -972,6 +973,7 @@ all = [ { name = "jupyter" }, { name = "pytest" }, { name = "pytest-cov" }, + { name = "rs-czsc" }, { name = "twine" }, ] dev = [ @@ -986,6 +988,7 @@ release = [ test = [ { name = "pytest" }, { name = "pytest-cov" }, + { name = "rs-czsc" }, ] [package.dev-dependencies] @@ -994,6 +997,7 @@ dev = [ { name = "ipython", version = "8.37.0", source = { registry = "https://pypi.python.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "ipython", version = "9.5.0", source = { registry = "https://pypi.python.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "jupyter" }, + { name = "maturin" }, { name = "pytest" }, { name = "pyyaml" }, { name = "ruff" }, @@ -1002,16 +1006,14 @@ dev = [ [package.metadata] requires-dist = [ - { name = "build", marker = "extra == 'all'" }, { name = "build", marker = "extra == 'release'" }, { name = "clickhouse-connect" }, { name = "cryptography" }, + { name = "czsc", extras = ["test", "dev", "release"], marker = "extra == 'all'" }, { name = "deprecated" }, { name = "dill" }, { name = "fonttools", specifier = ">=4.61.0" }, - { name = "ipython", marker = "extra == 'all'" }, { name = "ipython", marker = "extra == 'dev'" }, - { name = "jupyter", marker = "extra == 'all'" }, { name = "jupyter", marker = "extra == 'dev'" }, { name = "kaleido" }, { name = "lightweight-charts" }, @@ -1026,16 +1028,14 @@ requires-dist = [ { name = "polars", specifier = ">=0.20.0" }, { name = "pyarrow" }, { name = "pyecharts", specifier = ">=1.9.1" }, - { name = "pytest", marker = "extra == 'all'", specifier = ">=7.0.0" }, { name = "pytest", marker = "extra == 'test'", specifier = ">=7.0.0" }, - { name = "pytest-cov", marker = "extra == 'all'", specifier = ">=4.0.0" }, { name = "pytest-cov", marker = "extra == 'test'", specifier = ">=4.0.0" }, { name = "pytz" }, { name = "redis" }, { name = "reportlab", specifier = ">=4.0" }, { name = "requests", specifier = ">=2.24.0" }, { name = "requests-toolbelt" }, - { name = "rs-czsc", specifier = ">=0.1.26" }, + { name = "rs-czsc", marker = "extra == 'test'" }, { name = "scikit-learn" }, { name = "scipy" }, { name = "seaborn" }, @@ -1044,16 +1044,17 @@ requires-dist = [ { name = "ta-lib", specifier = ">=0.6" }, { name = "tenacity" }, { name = "tqdm", specifier = ">=4.66.4" }, - { name = "twine", marker = "extra == 'all'" }, { name = "twine", marker = "extra == 'release'" }, + { name = "wbt" }, ] -provides-extras = ["all", "dev", "release", "test"] +provides-extras = ["test", "dev", "release", "all"] [package.metadata.requires-dev] dev = [ { name = "basedpyright", specifier = ">=1.28.0" }, { name = "ipython", specifier = ">=8.37.0" }, { name = "jupyter", specifier = ">=1.1.1" }, + { name = "maturin", specifier = ">=1.13.1" }, { name = "pytest", specifier = ">=8.4.1" }, { name = "pyyaml", specifier = ">=6.0.2" }, { name = "ruff", specifier = ">=0.9.0" }, @@ -2313,6 +2314,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899, upload-time = "2024-04-15T13:44:43.265Z" }, ] +[[package]] +name = "maturin" +version = "1.13.1" +source = { registry = "https://pypi.python.org/simple" } +dependencies = [ + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/39/16/b284a7bc4af3dd87717c784278c1b8cb18606ad1f6f7a671c47bfd9c3df0/maturin-1.13.1.tar.gz", hash = "sha256:9a87ff3b8e4d1c6eac33ebfe8e261e8236516d98d45c0323550621819b5a1a2f", size = 340369, upload-time = "2026-04-09T15:14:07.026Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/4d/a23fc95be881aa8c7a6ea353410417872e4d7065df03d7f3db8f0dbed4a7/maturin-1.13.1-py3-none-linux_armv6l.whl", hash = "sha256:416e4e01cb88b798e606ee43929df897e42c1647b722ef68283816cca99a8742", size = 10102444, upload-time = "2026-04-09T15:13:48.393Z" }, + { url = "https://files.pythonhosted.org/packages/a6/1e/65c385d65bae95cf04895d52f39dbed8b1453ae55da2903d252ade40a774/maturin-1.13.1-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:72888e87819ce546d0d2df900e4b385e4ef299077d92ee37b48923a5602dae94", size = 19576043, upload-time = "2026-04-09T15:14:08.685Z" }, + { url = "https://files.pythonhosted.org/packages/8f/13/f6bc868d0bfecd9314870b97f530a167e31f7878ac4945c78245c6eef69c/maturin-1.13.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:98b5fcf1a186c217830a8295ecc2989c6b1cf50945417adfc15252107b9475b7", size = 10117339, upload-time = "2026-04-09T15:13:40.559Z" }, + { url = "https://files.pythonhosted.org/packages/51/58/279e081305c11c1c1c4fccacf77df8959646c5d4de7a57ec7e787653e270/maturin-1.13.1-py3-none-manylinux_2_12_i686.manylinux2010_i686.musllinux_1_1_i686.whl", hash = "sha256:3da18cccf2f683c0977bff9146a0908d6ffce836d600665736ac01679f588cb9", size = 10139689, upload-time = "2026-04-09T15:13:38.291Z" }, + { url = "https://files.pythonhosted.org/packages/00/94/69391af5396c6aab723932240803f49e5f3de3dd7c57d32f02d237a0ce32/maturin-1.13.1-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.musllinux_1_1_x86_64.whl", hash = "sha256:6b1e5916a253243e8f5f9e847b62bbc98420eec48c9ce2e2e8724c6da89d359b", size = 10551141, upload-time = "2026-04-09T15:13:42.887Z" }, + { url = "https://files.pythonhosted.org/packages/9e/bf/4edac2667b49e3733438062ae416413b8fc8d42e1bd499ba15e1fb02fc55/maturin-1.13.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = "sha256:dc91031e0619c1e28730279ef9ee5f106c9b9ec806b013f888676b242f892eb7", size = 9983094, upload-time = "2026-04-09T15:13:56.868Z" }, + { url = "https://files.pythonhosted.org/packages/79/94/a6d651cfe8fc6bf2e892c90e3cdbb25c06d81c9115140d03ea1a68a97575/maturin-1.13.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.musllinux_1_1_armv7l.whl", hash = "sha256:001741c6cff56aa8ea59a0d78ae990c0550d0e3e82b00b683eedb4158a8ef7e6", size = 9949980, upload-time = "2026-04-09T15:13:59.185Z" }, + { url = "https://files.pythonhosted.org/packages/b5/d1/82c067464f848e38af9910bce55eb54302b1c1284a279d515dbfcf5994f5/maturin-1.13.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.musllinux_1_1_ppc64le.whl", hash = "sha256:01c845825c917c07c1d0b2c9032c59c16a7d383d1e649a46481d3e5693c2750f", size = 13186276, upload-time = "2026-04-09T15:13:45.725Z" }, + { url = "https://files.pythonhosted.org/packages/7c/f4/25367baf1025580f047f9b37598bb3fadc416e24536afd4f28e190335c73/maturin-1.13.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f69093ed4a0e6464e52a7fc26d714f859ce15630ec8070743398c6bf41f38a9e", size = 10891837, upload-time = "2026-04-09T15:13:35.68Z" }, + { url = "https://files.pythonhosted.org/packages/af/be/caafad8ce74974b7deafdf144d12f758993dfea4c66c9905b138f51a7792/maturin-1.13.1-py3-none-manylinux_2_31_riscv64.musllinux_1_1_riscv64.whl", hash = "sha256:c1490584f3c70af45466ee99065b49e6657ebdccac6b10571bb44681309c9396", size = 10351032, upload-time = "2026-04-09T15:14:01.632Z" }, + { url = "https://files.pythonhosted.org/packages/66/0e/970a721d27cfa410e8bfa0a1e32e6ef52cb8169692110a5fdabe1af3f570/maturin-1.13.1-py3-none-win32.whl", hash = "sha256:c6a720b252c99de072922dbe4432ab19662b6f80045b0355fec23bdfccb450da", size = 8855465, upload-time = "2026-04-09T15:13:51.122Z" }, + { url = "https://files.pythonhosted.org/packages/88/70/7c1e0d65fa147d5479055a171541c82b8cdfc1c825d85a82240470f14176/maturin-1.13.1-py3-none-win_amd64.whl", hash = "sha256:a2017d2281203d0c6570240e7d746564d766d756105823b7de68bda6ae722711", size = 10230471, upload-time = "2026-04-09T15:13:53.89Z" }, + { url = "https://files.pythonhosted.org/packages/c5/2a/afe0193b673a79ffd2e01ad999511b7e9e6b49af02bb3759d82a78c3043d/maturin-1.13.1-py3-none-win_arm64.whl", hash = "sha256:2839024dcd65776abb4759e5bca29941971e095574162a4d335191da4be9ff24", size = 8905575, upload-time = "2026-04-09T15:14:03.891Z" }, +] + [[package]] name = "mdurl" version = "0.1.2" @@ -4639,6 +4664,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067, upload-time = "2024-11-01T14:07:11.845Z" }, ] +[[package]] +name = "wbt" +version = "0.1.7" +source = { registry = "https://pypi.python.org/simple" } +dependencies = [ + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.python.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.3", source = { registry = "https://pypi.python.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "pandas" }, + { name = "plotly" }, + { name = "polars" }, + { name = "pyarrow" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ad/87/337d04375b78ec37e98f3dec1f4dfc12db76b6076c8ebf19ebc3cfe0d090/wbt-0.1.7.tar.gz", hash = "sha256:b4c400fcde06b1dc14879c49fff4bc67e80d0fd70eef6af5867a5b3ed32ba73f", size = 341218, upload-time = "2026-05-06T02:40:16.6Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/72/3a/594d95d260e1ce027b0a5342427639706473e681cc985a548c18298263b3/wbt-0.1.7-cp310-abi3-macosx_10_12_x86_64.whl", hash = "sha256:1bbb77d5f27528d139295339c21bd64f73e2bb0bcd9b362a53cb06e4c4e75a96", size = 10409419, upload-time = "2026-05-06T02:40:02.439Z" }, + { url = "https://files.pythonhosted.org/packages/3c/e5/a8f5afc7a178ff93735fad284b0fae81bd8ee1c47f05268d5d629c9283ad/wbt-0.1.7-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:4a2eb66091722f568367b71dd5f689665a764019e050a9dc42422b3cac518379", size = 9503435, upload-time = "2026-05-06T02:40:05.448Z" }, + { url = "https://files.pythonhosted.org/packages/98/25/8316d48311966bee39b564c4a20a052f551a4c5abd6d3493098a69ee1d27/wbt-0.1.7-cp310-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11914f4e3ee16f0c788234914df492420d50cecf0d9ed179686cf530c963894c", size = 9849058, upload-time = "2026-05-06T02:40:08.555Z" }, + { url = "https://files.pythonhosted.org/packages/38/37/03d10c4b680688e7871614bdf64630a95296102832af3c5eab42dcb8feaf/wbt-0.1.7-cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd9ff5f88133ba861b65a2c4e6c52278c232058699e75a1816770e8be4d84b3b", size = 10848393, upload-time = "2026-05-06T02:40:11.553Z" }, + { url = "https://files.pythonhosted.org/packages/65/7e/4dd81587036fb9e8b156f3f4dd62a8bbc69d8feaa8949257f28e08cc82bc/wbt-0.1.7-cp310-abi3-win_amd64.whl", hash = "sha256:cea88b3ff9ea951e52bd6eb5b20d77f90bc981079ce493940a4cedaed59954e7", size = 11509260, upload-time = "2026-05-06T02:40:14.194Z" }, +] + [[package]] name = "wcwidth" version = "0.2.13" From fb0afce64088ef24441f8aecf7978415ae65b0d3 Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Thu, 7 May 2026 16:36:46 +0800 Subject: [PATCH 02/23] fix(audit): P0 - bump version, delete stale stub, purge stray rs_czsc imports MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 按飞书 spec wiki 子文档审计的 P0 清单做小补丁: - czsc/__init__.py: __version__ 0.10.12 -> 1.0.0; __date__ 20260308 -> 20260507(与 Cargo.toml/pyproject.toml 一致) - czsc/__init__.pyi: 整文件删除(残留 from rs_czsc / from .core 引用,basedpyright standard 模式报错;py.typed 保留,类型回退到内联注解) - czsc/eda.py: mark_v_reversal 的 rs 双分支折叠为单一 from czsc import …,移除 rs_czsc 与 czsc.utils.bar_generator 死路径 - czsc/utils/sig.py / czsc/utils/analysis/stats.py: from rs_czsc import {Signal,daily_performance} 改走 czsc / wbt(等价符号已验证) - czsc/utils/sig.pyi / czsc/utils/plotting/kline.pyi: 修 5 处 from czsc.core import … as …(已删模块) - docs/MIGRATION_NOTES.md: 新增 §10 Phase Q 章节,登记上述修补清单 --- czsc/__init__.py | 4 +- czsc/__init__.pyi | 226 ---------------------------------- czsc/eda.py | 9 +- czsc/utils/analysis/stats.py | 2 +- czsc/utils/plotting/kline.pyi | 2 +- czsc/utils/sig.py | 4 +- czsc/utils/sig.pyi | 10 +- docs/MIGRATION_NOTES.md | 16 +++ 8 files changed, 28 insertions(+), 245 deletions(-) delete mode 100644 czsc/__init__.pyi diff --git a/czsc/__init__.py b/czsc/__init__.py index 145c0a072..24d9fa242 100644 --- a/czsc/__init__.py +++ b/czsc/__init__.py @@ -406,10 +406,10 @@ # === 包元信息 === # 这些字段会被 setuptools / pip / sphinx 等工具读取,发布前需同步更新。 # __date__ 采用 ``YYYYMMDD`` 格式,便于排序与追溯。 -__version__ = "0.10.12" +__version__ = "1.0.0" __author__ = "zengbin93" __email__ = "zeng_bin8888@163.com" -__date__ = "20260308" +__date__ = "20260507" # === 懒加载子模块映射表 === # 键为公开访问名,值为完整模块路径。 diff --git a/czsc/__init__.pyi b/czsc/__init__.pyi deleted file mode 100644 index 93eae8a9f..000000000 --- a/czsc/__init__.pyi +++ /dev/null @@ -1,226 +0,0 @@ -from __future__ import annotations - -from types import ModuleType -from typing import Any - -# 来自 rs_czsc 的类型,保持宽松 -from rs_czsc import WeightBacktest as WeightBacktest -from rs_czsc import daily_performance as daily_performance -from rs_czsc import top_drawdowns as top_drawdowns - -from . import envs as envs -from . import traders as traders -from . import utils as utils -from .core import ( - CZSC, - ZS, - CzscJsonStrategy, - CzscStrategyBase, - Direction, - Event, - Freq, - NewBar, - Operate, - Position, - RawBar, - Signal, - format_standard_kline, -) -from .eda import ( - cal_symbols_factor, - cal_trade_price, - cal_yearly_days, - cross_sectional_strategy, - dif_long_bear, - limit_leverage, - make_price_features, - mark_cta_periods, - mark_volatility, - min_max_limit, - monotonicity, - remove_beta_effects, - rolling_layers, - sma_long_bear, - tsf_type, - turnover_rate, - twap, - unify_weights, - vwap, - weights_simple_ensemble, -) -from .traders import ( - CzscSignals, - CzscTrader, - SignalsParser, - check_signals_acc, - generate_czsc_signals, - get_signals_config, - get_signals_freqs, - get_unique_signals, -) -from .utils import ( - AliyunOSS, - DataClient, - DiskCache, - clear_cache, - clear_expired_cache, - code_namespace, - create_grid_params, - cross_sectional_ic, - dill_dump, - dill_load, - disk_cache, - empty_cache_path, - fernet_decrypt, - fernet_encrypt, - freqs_sorted, - generate_fernet_key, - get_dir_size, - get_py_namespace, - get_url_token, - holds_performance, - home_path, - import_by_name, - index_composition, - mac_address, - overlap, - print_df_sample, - psi, - read_json, - resample_to_daily, - risk_free_returns, - rolling_daily_performance, - save_json, - set_url_token, - ta, - timeout_decorator, - to_arrow, - update_bbars, - update_nxb, - update_tbars, - x_round, -) - -__version__: str -__author__: str -__email__: str -__date__: str - -# 延迟模块(运行时由 __getattr__ 注入) -svc: ModuleType -fsa: ModuleType -aphorism: ModuleType -mock: ModuleType -cwc: ModuleType - -__all__ = [ - "WeightBacktest", - "daily_performance", - "top_drawdowns", - "envs", - "traders", - "utils", - "CZSC", - "ZS", - "Direction", - "Event", - "Freq", - "NewBar", - "Operate", - "Position", - "RawBar", - "Signal", - "format_standard_kline", - "CzscSignals", - "CzscTrader", - "SignalsParser", - "check_signals_acc", - "generate_czsc_signals", - "get_signals_config", - "get_signals_freqs", - "get_unique_signals", - "AliyunOSS", - "DataClient", - "DiskCache", - "clear_cache", - "clear_expired_cache", - "code_namespace", - "create_grid_params", - "cross_sectional_ic", - "dill_dump", - "dill_load", - "disk_cache", - "empty_cache_path", - "fernet_decrypt", - "fernet_encrypt", - "freqs_sorted", - "generate_fernet_key", - "get_dir_size", - "get_py_namespace", - "get_url_token", - "holds_performance", - "home_path", - "import_by_name", - "index_composition", - "mac_address", - "overlap", - "print_df_sample", - "psi", - "read_json", - "resample_to_daily", - "risk_free_returns", - "rolling_daily_performance", - "save_json", - "set_url_token", - "ta", - "timeout_decorator", - "to_arrow", - "update_bbars", - "update_nxb", - "update_tbars", - "x_round", - "svc", - "fsa", - "aphorism", - "mock", - "cwc", - "CzscStrategyBase", - "CzscJsonStrategy", - "capture_warnings", - "execute_with_warning_capture", - "adjust_holding_weights", - "log_strategy_info", - "calculate_bi_info", - "symbols_bi_infos", - "plot_czsc_chart", - "KlineChart", - "check_kline_quality", - "remove_beta_effects", - "vwap", - "twap", - "cross_sectional_strategy", - "monotonicity", - "min_max_limit", - "rolling_layers", - "cal_symbols_factor", - "weights_simple_ensemble", - "unify_weights", - "sma_long_bear", - "dif_long_bear", - "tsf_type", - "limit_leverage", - "cal_trade_price", - "mark_cta_periods", - "mark_volatility", - "cal_yearly_days", - "turnover_rate", - "make_price_features", - "__version__", - "__author__", - "__email__", - "__date__", - "welcome", -] - -def __getattr__(name: str) -> Any: ... -def welcome() -> None: ... diff --git a/czsc/eda.py b/czsc/eda.py index 9fb59dbc4..93f1f0cc2 100644 --- a/czsc/eda.py +++ b/czsc/eda.py @@ -834,7 +834,6 @@ def mark_v_reversal(df: pd.DataFrame, **kwargs): - copy: 是否复制数据,默认True - verbose: 是否打印日志,默认False - logger: 日志记录器 - - rs: 是否使用rs_czsc,默认True - min_power_percentile: 第一个笔的最小力度百分位数,默认0.7(即前30%) - min_retracement: 最小回撤比例,默认0.5 - min_speed_ratio: 第二个笔相对第一个笔的最小速度比例,默认1.5 @@ -842,13 +841,7 @@ def mark_v_reversal(df: pd.DataFrame, **kwargs): :return: 带有V字反转标记的K线数据,新增列 'is_v_reversal_up', 'is_v_reversal_down', 'is_v_reversal' """ - rs = kwargs.get("rs", True) - - if rs: - from rs_czsc import CZSC, Direction, format_standard_kline - else: - from czsc import CZSC - from czsc.utils.bar_generator import format_standard_kline + from czsc import CZSC, Direction, format_standard_kline # 参数设置 min_power_percentile = kwargs.get("min_power_percentile", 0.7) diff --git a/czsc/utils/analysis/stats.py b/czsc/utils/analysis/stats.py index 601299b9c..ad63ec018 100644 --- a/czsc/utils/analysis/stats.py +++ b/czsc/utils/analysis/stats.py @@ -133,7 +133,7 @@ def rolling_daily_performance(df: pd.DataFrame, ret_col, window=252, min_periods - yearly_days: int, 252, 一年的交易日数 """ - from rs_czsc import daily_performance + from wbt import daily_performance from czsc.eda import cal_yearly_days diff --git a/czsc/utils/plotting/kline.pyi b/czsc/utils/plotting/kline.pyi index d710a1e7a..71e142602 100644 --- a/czsc/utils/plotting/kline.pyi +++ b/czsc/utils/plotting/kline.pyi @@ -2,7 +2,7 @@ import pandas as pd from _typeshed import Incomplete from plotly import graph_objects as go -from czsc.core import CZSC as CZSC +from czsc import CZSC as CZSC class KlineChart: n_rows: Incomplete diff --git a/czsc/utils/sig.py b/czsc/utils/sig.py index 287ba0657..b76234141 100644 --- a/czsc/utils/sig.py +++ b/czsc/utils/sig.py @@ -33,7 +33,7 @@ def create_single_signal(**kwargs) -> OrderedDict: """构造单个标准信号对象 - 通过 ``rs_czsc.Signal`` 把 ``k1/k2/k3/v1/v2/v3/score`` 标准字段拼装成 + 通过 ``czsc.Signal`` 把 ``k1/k2/k3/v1/v2/v3/score`` 标准字段拼装成 ``key="k1_k2_k3"`` / ``value="v1_v2_v3_score"`` 的字符串形式,并以 ``OrderedDict`` 返回,便于和其他信号合并。 @@ -43,7 +43,7 @@ def create_single_signal(**kwargs) -> OrderedDict: - score: int,信号置信度评分,默认 0 :return: OrderedDict,``{Signal.key: Signal.value}`` """ - from rs_czsc import Signal + from czsc import Signal s = OrderedDict() k1, k2, k3 = kwargs.get("k1", "任意"), kwargs.get("k2", "任意"), kwargs.get("k3", "任意") diff --git a/czsc/utils/sig.pyi b/czsc/utils/sig.pyi index 01fb8b9f1..d5ed68f3d 100644 --- a/czsc/utils/sig.pyi +++ b/czsc/utils/sig.pyi @@ -5,11 +5,11 @@ from typing import Any import numpy as np from deprecated import deprecated as deprecated -from czsc.core import BI as BI -from czsc.core import ZS as ZS -from czsc.core import Direction as Direction -from czsc.core import RawBar as RawBar -from czsc.core import Signal as Signal +from czsc import BI as BI +from czsc import ZS as ZS +from czsc import Direction as Direction +from czsc import RawBar as RawBar +from czsc import Signal as Signal def create_single_signal(**kwargs) -> OrderedDict: ... def is_symmetry_zs(bis: list[BI], th: float = 0.3) -> bool: ... diff --git a/docs/MIGRATION_NOTES.md b/docs/MIGRATION_NOTES.md index acb467694..61094dcac 100644 --- a/docs/MIGRATION_NOTES.md +++ b/docs/MIGRATION_NOTES.md @@ -756,3 +756,19 @@ test_all_signals_parity[xlarge] PASSED checks across ~9.7M data points** — every single one passes. The migrated ``czsc._native`` is a drop-in replacement for ``rs_czsc`` at the signal-output level on data ranging from 500 bars to 40k bars. + +--- + +## 10. Phase Q — Audit-driven P0/P1 fixes (2026-05-07) + +> 触发:飞书 spec wiki 子文档 [实现细节审计 — czsc Rust 迁移现状(2026-05-07)](https://www.feishu.cn/wiki/Z7gGweUfqiK1DfkiC36cMl62nLe) 列出的 P0/P1 缺口。 + +### 10.1 已完成的修补 + +| 修复 | 文件 | 修改 | +|-|-|-| +| 版本号对齐 | [czsc/__init__.py](../czsc/__init__.py) | `__version__ = "0.10.12"` → `"1.0.0"`、`__date__ = "20260308"` → `"20260507"`,与 `Cargo.toml` / `pyproject.toml` 的 `1.0.0` 一致 | +| 删除过时 stub | `czsc/__init__.pyi` | 整文件删除。该 stub 仍引用 `from rs_czsc import ...` 与 `from .core import ...`(`core.py` 已删;`rs_czsc` 已退出依赖图),basedpyright `standard` 模式下会报错。`czsc/py.typed` 仍在,类型信息回退到 `czsc/__init__.py` 内联注解;spec §2.4 期望的 `czsc/_native.pyi` 由 `pyo3-stub-gen` 生成,留待 P1 | +| 死分支折叠 | [czsc/eda.py:823-859](../czsc/eda.py) | `mark_v_reversal` 的 `rs` 双分支已折叠为 `from czsc import CZSC, Direction, format_standard_kline` 单一 import;移除 `kwargs["rs"]` 文档项与 `from rs_czsc import ...` / `from czsc.utils.bar_generator import ...` 的死路径(`rs_czsc` 不再依赖、`bar_generator.py` 已删) | +| 散落 `rs_czsc` 导入清理 | [czsc/utils/sig.py:46](../czsc/utils/sig.py)、[czsc/utils/analysis/stats.py:136](../czsc/utils/analysis/stats.py) | `from rs_czsc import Signal` → `from czsc import Signal`;`from rs_czsc import daily_performance` → `from wbt import daily_performance`。两处都是函数体内的 lazy import,`czsc` / `wbt` 等价符号已全量验证 | +| 修复 `.pyi` 中已删除的 `czsc.core` 引用 | [czsc/utils/sig.pyi](../czsc/utils/sig.pyi)、[czsc/utils/plotting/kline.pyi](../czsc/utils/plotting/kline.pyi) | 5 处 `from czsc.core import X as X` → `from czsc import X as X`(`czsc.core` 已在 Phase H 删除,basedpyright 会报错);同步更新 `czsc/utils/sig.py` docstring 中"通过 ``rs_czsc.Signal`` 把…" → "通过 ``czsc.Signal`` 把…" | From 06bdcbc5e1a3d7f2f29d87b957d7dbf3a58674fa Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Thu, 7 May 2026 16:37:18 +0800 Subject: [PATCH 03/23] refactor(traders): retire _lazy_rs_czsc, route sig_parse through czsc._native MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase F 已经把 derive_signals_config / derive_signals_freqs / list_all_signals 全量迁到 czsc._native(246 个信号模板可拉),sig_parse.py 中的 _lazy_rs_czsc 工厂 + 三个永真分支 (if list_all_signals is not None / if derive_signals_config is not None) 已无意义。删掉工厂 + wrappers,顶层直接 from czsc._native import …, 同步消去 docstring/注释中"rs_czsc 不可用时"的陈述;.pyi 同步补 derive_*/list_all_signals 顶层签名与 sig_k3_map 字段。Spec §3.3 "待评估 Rust 是否已等价实现"标记可摘除。 LoC: czsc/traders/sig_parse.py 387 -> 326(-61 行)。 --- czsc/traders/sig_parse.py | 116 +++++++++---------------------------- czsc/traders/sig_parse.pyi | 6 +- docs/MIGRATION_NOTES.md | 1 + 3 files changed, 34 insertions(+), 89 deletions(-) diff --git a/czsc/traders/sig_parse.py b/czsc/traders/sig_parse.py index 819ffb369..376fa96ae 100644 --- a/czsc/traders/sig_parse.py +++ b/czsc/traders/sig_parse.py @@ -5,16 +5,12 @@ 被策略加载器、信号注册表与信号编辑器等上层组件使用,以便用户既能用紧凑 的字符串描述信号,也能在程序内部以结构化字典进行编辑、序列化与回放。 -实现层面的关键点: - -* 由于 Rust 端的 ``derive_signals_config`` / ``derive_signals_freqs`` / - ``list_all_signals`` 在当前阶段尚未完全迁移到 ``czsc._native`` 命名空间, - 本模块通过 :func:`_lazy_rs_czsc` 在调用点惰性导入 ``rs_czsc``,从而保证 - 即便相关函数缺失,模块本身仍可被正常 import;只有真正调用到对应函数时 - 才会抛出明确的错误。 -* :class:`SignalsParser` 在初始化时尽力调用 ``list_all_signals`` 拉取全量 - 信号模板,缺失时退化为空注册表;上层调用方可以在没有任何注册信息时 - 仍然顺利构造解析器,只是部分功能会返回空结果。 +底层 ``derive_signals_config`` / ``derive_signals_freqs`` / ``list_all_signals`` +均由 Rust 端 ``czsc._native`` 提供(已在 Phase F 完成迁移);本模块只做 +Python 侧的解析、模板格式化与配置展平等编排工作。 + +:class:`SignalsParser` 在初始化时调用 ``list_all_signals`` 拉取全量信号模板; +失败时退化为空注册表,上层调用方仍可顺利构造解析器,只是部分功能会返回空结果。 """ from __future__ import annotations @@ -25,64 +21,11 @@ from loguru import logger from parse import parse - -def _lazy_rs_czsc(): - """惰性导入 ``rs_czsc`` 中尚未迁移到 czsc._native 的若干函数。 - - 由于这些函数当前阶段尚未在 czsc._native 中暴露,模块加载时直接 import - 会带来强依赖;本函数把 import 推迟到调用点,让模块自身可以在 rs_czsc - 缺失或部分不可用时仍可被正常导入。 - - Returns: - 三元组 ``(derive_signals_config, derive_signals_freqs, list_all_signals)``, - 分别对应 rs_czsc 中三个尚未迁移的函数。 - - Raises: - NotImplementedError: 当 rs_czsc 不可用或缺失上述函数时抛出, - 提示调用方重新安装 rs_czsc 以恢复回退能力。 - """ - try: - from rs_czsc import ( # type: ignore[import-not-found] - derive_signals_config as _dsc, - derive_signals_freqs as _dsf, - list_all_signals as _las, - ) - except ImportError as exc: # pragma: no cover - # 故意把 ImportError 转成 NotImplementedError,让"功能未提供"的语义 - # 比"缺包"更明显,并附上修复建议。 - raise NotImplementedError( - "rs_czsc.{derive_signals_config, derive_signals_freqs, " - "list_all_signals} have not been migrated to czsc._native yet " - "(see MIGRATION_NOTES §2.8). Re-install rs_czsc to fall back." - ) from exc - return _dsc, _dsf, _las - - -def derive_signals_config(*args, **kwargs): - """``rs_czsc.derive_signals_config`` 的惰性转发包装。 - - 所有参数会原样转发给底层实现;当底层不可用时抛出 - :class:`NotImplementedError`,错误信息见 :func:`_lazy_rs_czsc`。 - """ - return _lazy_rs_czsc()[0](*args, **kwargs) - - -def derive_signals_freqs(*args, **kwargs): - """``rs_czsc.derive_signals_freqs`` 的惰性转发包装。 - - 所有参数会原样转发给底层实现;当底层不可用时抛出 - :class:`NotImplementedError`,错误信息见 :func:`_lazy_rs_czsc`。 - """ - return _lazy_rs_czsc()[1](*args, **kwargs) - - -def list_all_signals(*args, **kwargs): - """``rs_czsc.list_all_signals`` 的惰性转发包装。 - - 所有参数会原样转发给底层实现;当底层不可用时抛出 - :class:`NotImplementedError`,错误信息见 :func:`_lazy_rs_czsc`。 - """ - return _lazy_rs_czsc()[2](*args, **kwargs) +from czsc._native import ( + derive_signals_config, + derive_signals_freqs, + list_all_signals, +) def _normalize_template(template: str) -> str: @@ -147,7 +90,7 @@ class SignalsParser: 建立"函数名 → 模板"的注册表;后续调用 ``parse_params`` / ``get_function_name`` / ``config_to_keys`` / ``parse`` 均依赖该注册表。 - 当 rs_czsc 不可用导致 ``list_all_signals`` 失败时,注册表会被置空, + 当 ``list_all_signals`` 失败(如 Rust 扩展异常)时,注册表会被置空, 此时各方法会按"找不到匹配模板"的语义返回空值,而不会抛出异常。 """ @@ -164,14 +107,13 @@ def __init__(self, signals_module: str = "czsc.signals"): # "k3 → 函数名列表"的反向索引,供后续解析与匹配使用。 sig_pats_map: dict[str, str] = {} sig_k3_map: dict[str, str] = {} - signal_defs = [] + signal_defs: list[dict[str, Any]] = [] - if list_all_signals is not None: - try: - # 尽力拉取全量信号定义;失败时降级为空注册表,不影响模块导入。 - signal_defs = list_all_signals(include_kline=True, include_trader=True) - except Exception as exc: - logger.warning(f"list_all_signals unavailable, using empty parser registry: {exc}") + try: + # 尽力拉取全量信号定义;失败时降级为空注册表,不影响模块导入。 + signal_defs = list(list_all_signals(include_kline=True, include_trader=True)) + except Exception as exc: + logger.warning(f"list_all_signals unavailable, using empty parser registry: {exc}") for item in signal_defs: name = str(item.get("name", "")).strip() @@ -263,15 +205,14 @@ def get_function_name(self, signal: str): logger.error(f"signal {signal} matched multiple functions: {matches}") return None - if derive_signals_config is not None: - try: - # 本地匹配失败时,回退到 Rust 端的权威解析作为兜底。 - conf = derive_signals_config([signal]) - if conf: - return str(conf[0]["name"]).split(".")[-1] - except Exception: - # 兜底失败保持静默(log 已在更底层打印),避免噪声。 - pass + try: + # 本地匹配失败时,回退到 Rust 端的权威解析作为兜底。 + conf = derive_signals_config([signal]) + if conf: + return str(conf[0]["name"]).split(".")[-1] + except Exception: + # 兜底失败保持静默(log 已在更底层打印),避免噪声。 + pass return None def config_to_keys(self, config: list[dict]): @@ -313,7 +254,7 @@ def parse(self, signal_seq: list[str]): 扁平化后的信号配置字典列表;输入为空或底层解析不可用时返回 空列表。 """ - if not signal_seq or derive_signals_config is None: + if not signal_seq: return [] try: @@ -372,8 +313,7 @@ def get_signals_freqs(signals_seq: list) -> list[str]: 去重的 K 线周期字符串列表;输入为空时返回空列表。 Notes: - 本函数依赖 ``rs_czsc`` 进行语义解析,``rs-czsc`` 为必选依赖; - 缺失时会通过 :func:`_lazy_rs_czsc` 抛出 :class:`NotImplementedError`。 + 语义解析由 Rust 扩展 ``czsc._native.derive_signals_*`` 完成。 """ if not signals_seq: return [] diff --git a/czsc/traders/sig_parse.pyi b/czsc/traders/sig_parse.pyi index 9502b0fa3..7664f6f5e 100644 --- a/czsc/traders/sig_parse.pyi +++ b/czsc/traders/sig_parse.pyi @@ -1,10 +1,11 @@ from _typeshed import Incomplete -from czsc.core import Signal as Signal +from czsc import Signal as Signal class SignalsParser: signals_module: Incomplete sig_name_map: Incomplete + sig_k3_map: Incomplete sig_pats_map: Incomplete def __init__(self, signals_module: str = "czsc.signals") -> None: ... def parse_params(self, name, signal): ... @@ -12,5 +13,8 @@ class SignalsParser: def config_to_keys(self, config: list[dict]): ... def parse(self, signal_seq: list[str]): ... +def derive_signals_config(*args, **kwargs): ... +def derive_signals_freqs(*args, **kwargs): ... +def list_all_signals(*args, **kwargs): ... def get_signals_config(signals_seq: list[str], signals_module: str = "czsc.signals") -> list[dict]: ... def get_signals_freqs(signals_seq: list) -> list[str]: ... diff --git a/docs/MIGRATION_NOTES.md b/docs/MIGRATION_NOTES.md index 61094dcac..5bba5604c 100644 --- a/docs/MIGRATION_NOTES.md +++ b/docs/MIGRATION_NOTES.md @@ -772,3 +772,4 @@ the signal-output level on data ranging from 500 bars to 40k bars. | 死分支折叠 | [czsc/eda.py:823-859](../czsc/eda.py) | `mark_v_reversal` 的 `rs` 双分支已折叠为 `from czsc import CZSC, Direction, format_standard_kline` 单一 import;移除 `kwargs["rs"]` 文档项与 `from rs_czsc import ...` / `from czsc.utils.bar_generator import ...` 的死路径(`rs_czsc` 不再依赖、`bar_generator.py` 已删) | | 散落 `rs_czsc` 导入清理 | [czsc/utils/sig.py:46](../czsc/utils/sig.py)、[czsc/utils/analysis/stats.py:136](../czsc/utils/analysis/stats.py) | `from rs_czsc import Signal` → `from czsc import Signal`;`from rs_czsc import daily_performance` → `from wbt import daily_performance`。两处都是函数体内的 lazy import,`czsc` / `wbt` 等价符号已全量验证 | | 修复 `.pyi` 中已删除的 `czsc.core` 引用 | [czsc/utils/sig.pyi](../czsc/utils/sig.pyi)、[czsc/utils/plotting/kline.pyi](../czsc/utils/plotting/kline.pyi) | 5 处 `from czsc.core import X as X` → `from czsc import X as X`(`czsc.core` 已在 Phase H 删除,basedpyright 会报错);同步更新 `czsc/utils/sig.py` docstring 中"通过 ``rs_czsc.Signal`` 把…" → "通过 ``czsc.Signal`` 把…" | +| `czsc/traders/sig_parse.py` 退役 `_lazy_rs_czsc` | [czsc/traders/sig_parse.py](../czsc/traders/sig_parse.py) | 验证 `czsc._native.{derive_signals_config, derive_signals_freqs, list_all_signals}` 三函数已在 Phase F 全量上线(246 个信号模板可拉),将模块顶部的 `_lazy_rs_czsc` 工厂 + 三个 wrapper(`derive_signals_config` / `derive_signals_freqs` / `list_all_signals`)一次性删掉,改为顶层 `from czsc._native import ...`;同步移除 `if list_all_signals is not None` / `if derive_signals_config is not None` 等永真分支以及对应注释中的 `rs_czsc` 提法。`SignalsParser` 注册表初始化大小 = 246,与原 lazy 路径一致;spec §3.3 中"待评估 Rust 是否已等价实现"的临时性脚注随之失效。`czsc/traders/sig_parse.pyi` 同步更新(顶层补 `derive_signals_config / freqs / list_all_signals` 与 `sig_k3_map` 属性声明) | From e37e1bab1e4fe4d93dab96f7342e899f1b733b7e Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Thu, 7 May 2026 16:38:24 +0800 Subject: [PATCH 04/23] refactor(api): retire __init__.py lazy loading and shrink LoC by 54% MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 按 spec §3.1,所有公共 API 在导入期一次性 import;不再使用 PEP 562 lazy loading: - 删除 _LAZY_MODULES / _LAZY_ATTRS / __getattr__ 三件套 - svc/fsa/aphorism/mock 改为顶层 from . import … - 7 个 lazy 属性(capture_warnings / plot_czsc_chart / KlineChart / …)改为 from czsc.utils.* import … - 删除 if TYPE_CHECKING 守卫;welcome() 内的 lazy import 提到顶层 同时把注释/文档密度大幅收缩: - 17 处冗长的"逐符号说明"注释块全删(信息已在符号自身 docstring) - __all__ 字面表改为按主题分组的紧凑横排(仍保留全部 129 个公共名称) - welcome() docstring 折成单行;模块 docstring 22 行 -> 11 行 LoC: 507 -> 235(-54%)。 循环 import 防坑:svc/fsa/aphorism/mock 中含 from czsc import top_drawdowns 等 回环 import,必须放到所有顶层符号绑定后再加载(第二批 from . import ...), 第一次误把它们提到顶部触发 ImportError partially-initialized module,调整顺序后 通过;文件中以"第一批 / 第二批"分组注释固化此约束。 测试更新:test/test_import_performance.py 中两个基于"streamlit 不应被 import czsc 拉起"的旧测试与新方向冲突,已删除;保留 import-time 兜底与 svc 可访问性测试。 --- czsc/__init__.py | 444 +++++++------------------------- docs/MIGRATION_NOTES.md | 1 + test/test_import_performance.py | 55 +--- 3 files changed, 88 insertions(+), 412 deletions(-) diff --git a/czsc/__init__.py b/czsc/__init__.py index 24d9fa242..aa6d248fa 100644 --- a/czsc/__init__.py +++ b/czsc/__init__.py @@ -1,60 +1,26 @@ -""" -CZSC(缠中说禅)量化分析框架的顶层包入口模块 - -职责: - 1. 统一对外暴露公共 API(缠论核心类、信号、交易器、策略、回测工具等) - 2. 完成 Rust 后端(czsc._native,由 PyO3 编译而来)与 Python 适配层之间的桥接 - 3. 维护 ``__all__`` 公共契约,保证 ``from czsc import *`` 行为可控 - 4. 通过模块级 ``__getattr__`` 提供按需加载(懒加载)的子模块与符号 - —— 仅在首次访问时才执行 ``importlib.import_module``,可显著降低冷启动耗时 - 同时避免循环依赖 +"""CZSC(缠中说禅)量化分析框架——顶层包入口。 -约定: - - 所有以单下划线开头的对象(如 ``_sys``、``_LAZY_MODULES``)均为模块内部使用, - 不属于公共 API,禁止外部直接依赖 - - ``czsc.ta``、``czsc.CZSC`` 等高频符号优先来自 Rust 实现,性能更佳 - - 升级 Rust 版本时需同步检查 ``__all__`` 与本模块导入区,确保契约一致 +按 spec §3.1,所有公共 API 在导入期一次性 import;不再使用 PEP 562 lazy loading。 +- ``czsc._native``:Rust 扩展(PyO3),提供缠论核心类型、信号、交易器、TA 算子。 +- ``czsc.{connectors,sensors,signals,traders,utils,svc,fsa,aphorism,mock,envs}``:Python 子包。 +- ``czsc.{ema,sma,...,ultimate_smoother,...}``:Rust TA 算子的顶层别名。 +- ``czsc.{WeightBacktest,daily_performance,top_drawdowns}``:来自硬依赖 ``wbt``。 -作者: zengbin93 -邮箱: zeng_bin8888@163.com -创建时间: 2019/10/29 15:01 +作者: zengbin93 ,创建于 2019/10/29。 """ import sys as _sys -# 子包按"始终需要 / 启动期可加载"为标准统一引入: -# - _native: PyO3 编译的 Rust 扩展,缠论核心实现位于此 -# - connectors/envs/sensors/signals/traders/utils: 业务层模块,多数函数会被立即用到 -# 这些子包必须立即加载,因下方的 from ... import 语句直接依赖它们 +# 第一批:纯薄壳子包(不会回头 import czsc 顶层符号)。 +# svc/fsa/aphorism/mock 中含 ``from czsc import top_drawdowns`` 等回环 import, +# 必须放到 wbt / .traders / .utils 之后再加载,避免循环 import。 from . import _native, connectors, envs, sensors, signals, traders, utils -# === Rust 扩展的 ta 命名空间桥接 === -# 让 ``czsc.ta.*`` 直接来自 Rust 扩展(czsc._native.ta),不再使用 Python 包装层。 -# 通过同时设置模块属性与 ``sys.modules``,保证以下两种导入方式都能命中 Rust 实现: -# import czsc.ta # 解析为 czsc._native.ta -# from czsc.ta import ema # 解析为 czsc._native.ta.ema -# 注意:在文件下方 ``from .utils import ...`` 之后还会重新赋值一次,避免 -# 旧版包装模块在导入链上覆盖此处别名(见后文"重新应用 czsc.ta 别名"段)。 +# czsc.ta -> czsc._native.ta(Rust 实现);同时设置 sys.modules 以兼容 import czsc.ta ta = _native.ta _sys.modules["czsc.ta"] = _native.ta -# === 缠论核心数据类型与算法(来自 Rust 扩展) === -# 这些符号是 CZSC 公共 API 的"硬契约",在大量业务代码与下游项目中被直接 import。 -# 命名与原 Python 实现保持一致,便于无缝迁移。各类型的语义: -# BI - 笔(缠论中由分型连接形成的最小走势单元) -# CZSC - 缠论分析器主类,承载分型/笔/线段的识别管线 -# FX - 分型(顶分型 / 底分型) -# ZS - 中枢(多笔重叠形成的盘整区间) -# BarGenerator - K 线合成器,用于多周期联立 -# Direction/Mark/Operate - 方向/标记/操作枚举 -# Event/Signal/Position - 事件、信号、持仓三件套 -# ParsedSignalDoc - 信号文档解析结果 -# FakeBI/NewBar/RawBar - K 线及其衍生抽象(原始/合成/虚拟笔) -# 工具函数: -# boll_positions/ema/sma/ultimate_smoother/rolling_rank - 技术指标 -# check_bi/check_fx/check_fxs/remove_include - 缠论结构校验 -# freq_end_time/is_trading_time - 周期与交易时段判定 -# parse_signal_doc - 信号声明字符串解析 +# === 缠论核心数据类型与算法(来自 Rust 扩展 czsc._native)=== from ._native import ( BI, CZSC, @@ -86,40 +52,13 @@ ultimate_smoother, ) -# === format_standard_kline 的 Python 包装 === -# Rust 扩展仅提供一个直通桩(``Vec -> Vec``),无法直接接受 -# pandas DataFrame。但下游用户期望的签名是 ``DataFrame + Freq -> List[RawBar]``。 -# 因此 ``czsc/_format_standard_kline.py`` 是一个 Python 适配层,逐行通过 -# PyO3 构造器创建 RawBar 实例,签名与 rs-czsc 提供的 Python 端 API 完全一致。 -# 这样既保留了"用户传 DataFrame 即可"的便利性,又复用了 Rust 端的内存布局与类型校验。 +# format_standard_kline: Python 适配层,把 DataFrame -> List[RawBar](详见模块 docstring) from czsc._format_standard_kline import format_standard_kline -# === 回测引擎(来自第三方包 wbt) === -# WeightBacktest - 基于权重序列的向量化回测器 -# daily_performance - 日度绩效统计(年化、夏普、最大回撤等) -# top_drawdowns - 提取前 N 大回撤区间,用于风险归因分析 +# === wbt(硬依赖,提供回测/绩效组件)=== from wbt import WeightBacktest, daily_performance, top_drawdowns -from typing import TYPE_CHECKING - -# === 探索性数据分析(EDA)相关函数 === -# 这些工具函数用于因子研究、特征工程与轻量级策略评估, -# 在多数交易研究流程中是高频使用项,因此选择在导入期就一并暴露: -# cal_symbols_factor / cal_trade_price - 多品种因子计算与交易价归一化 -# cal_yearly_days - 计算年化基准日数(区分 A 股/期货/数字货币) -# cross_sectional_strategy - 横截面排序策略 -# dif_long_bear / sma_long_bear - 多/空趋势判定 -# limit_leverage - 杠杆约束 -# make_price_features - 价格类特征工厂 -# mark_cta_periods / mark_volatility - CTA 区间与波动率分段 -# min_max_limit - 数值裁剪 -# monotonicity - 单调性检验 -# remove_beta_effects - 去除 Beta 系统性影响 -# rolling_layers - 分层回看 -# tsf_type - 时序特征类别标注 -# turnover_rate - 换手率统计 -# twap / vwap - 时间/成交量加权均价 -# unify_weights / weights_simple_ensemble - 权重归一与简单集成 +# === EDA 工具(来自 czsc.eda)=== from .eda import ( cal_symbols_factor, cal_trade_price, @@ -143,27 +82,10 @@ weights_simple_ensemble, ) -# 仅在类型检查阶段(如 mypy / pyright / IDE 静态分析)暴露这些懒加载子模块, -# 运行期它们会经由下方的 ``__getattr__`` 按需导入。这种写法既能让静态工具 -# 正确解析 ``czsc.svc`` 等用法,又不会在导入 czsc 时就把这些重量级子包 -# (例如 svc 依赖 plotly/streamlit)拉起来。 -if TYPE_CHECKING: - from . import aphorism, fsa, mock, svc - -# === 策略门面(Facade) === -# strategies.py 是 Python 端的薄封装:在 Rust 实现的 Trader 基础上,提供 -# CzscStrategyBase(策略开发抽象基类)与 CzscJsonStrategy(JSON 配置式策略), -# 隔离用户层 API 与底层 Rust 类型,便于策略快速搭建与序列化。 +# === 策略门面(czsc.strategies;Python 层对 Rust Trader 的薄封装)=== from .strategies import CzscJsonStrategy, CzscStrategyBase -# === 交易器与信号管理 API === -# traders 子包对外暴露的统一入口,由 Rust 后端驱动: -# CzscSignals/CzscTrader - 多周期信号合成与交易调度核心 -# SignalsParser - 信号声明字符串解析器 -# derive_signals_* - 从持仓/事件反推所需信号配置或周期集合 -# generate_czsc_signals - 标准化信号生成入口 -# get_signals_* - 信号配置/周期获取辅助函数 -# get_unique_signals - 信号去重工具,回测前置预处理常用 +# === 交易器与信号管理 API(czsc.traders)=== from .traders import ( CzscSignals, CzscTrader, @@ -176,14 +98,7 @@ get_unique_signals, ) -# === 研究/优化入口(research.py,Rust 后端) === -# 这些函数是策略研究流程的主入口,封装了"参数批量回测""开仓/平仓优化""复盘" -# 等高层操作,对应 czsc/research.py 中的统一研究 API: -# build_open_optim_positions - 构造开仓参数优化所需的 Position 列表 -# build_exit_optim_positions - 构造平仓参数优化所需的 Position 列表 -# run_optimize_batch - 多参数组合批量优化 -# run_replay - 单标的回放 -# run_research - 顶层一键式研究流水线 +# === 研究/优化入口(czsc.research,Rust 后端)=== from .research import ( build_exit_optim_positions, build_open_optim_positions, @@ -192,24 +107,7 @@ run_research, ) -# === 通用工具函数集合(czsc.utils) === -# 这些工具按使用频率从 czsc.utils 中提升至顶级命名空间,便于直接 ``czsc.xxx`` 调用。 -# 主要分组: -# - 缓存类:DiskCache / disk_cache / clear_cache / clear_expired_cache / -# empty_cache_path / get_dir_size / home_path -# - 数据源/IO:DataClient / AliyunOSS / read_json / save_json / to_arrow -# - 加解密:fernet_encrypt / fernet_decrypt / generate_fernet_key / -# get_url_token / set_url_token -# - 序列化:dill_dump / dill_load -# - 时间/周期:freqs_sorted / resample_to_daily -# - 命名空间/反射:code_namespace / get_py_namespace / import_by_name -# - 绩效统计:cross_sectional_ic / holds_performance / -# rolling_daily_performance / risk_free_returns -# - 调试:print_df_sample / mac_address / x_round -# - PSI:psi(群体稳定性指数) -# - 网格/装饰器:create_grid_params / timeout_decorator -# - 指标增量更新:update_bbars / update_nxb / update_tbars -# - 指数成分:index_composition +# === 通用工具函数(czsc.utils)=== from .utils import ( AliyunOSS, DataClient, @@ -251,257 +149,87 @@ x_round, ) -# === 重新应用 czsc.ta 别名(关键步骤,勿删) === -# 必须在 ``from .utils import ...`` 之后再次执行一次 czsc.ta 的别名绑定, -# 原因:旧版 utils 模块在导入链上可能间接触发 ``czsc.utils.ta`` 子模块的 -# 副作用导入,从而将 sys.modules['czsc.ta'] 指向 Python 包装版本,覆盖 -# 文件顶部设置的 Rust 版本。这里再绑定一次确保 Rust 实现胜出。 +# === 之前的 lazy 属性,改为静态 import(spec §3.1 移除 lazy loading)=== +from czsc.utils.kline_quality import check_kline_quality +from czsc.utils.log import log_strategy_info +from czsc.utils.plotting.kline import KlineChart, plot_czsc_chart +from czsc.utils.trade import adjust_holding_weights +from czsc.utils.warning_capture import capture_warnings, execute_with_warning_capture + +# 第二批:会回头 import czsc 顶层符号(如 ``from czsc import top_drawdowns``)的重型子包。 +# 必须放在所有顶层符号都已经绑定之后,否则会触发 partially-initialized module 循环 import。 +from . import aphorism, fsa, mock, svc + +# czsc.ta 别名再保险:from .utils import ... 链路上若有副作用 import czsc.utils.ta, +# 可能把 sys.modules["czsc.ta"] 覆盖回 Python 包装版本,这里再绑一次确保 Rust 子模块胜出。 ta = _native.ta _sys.modules["czsc.ta"] = _native.ta -# === 公共 API 契约 === -# ``__all__`` 显式声明 ``from czsc import *`` 行为暴露的符号集合。 -# 维护规则: -# 1. 任何顶级 import 中出现的公共符号都需登记于此(按主题分组排列) -# 2. 私有符号(单下划线开头)禁止登记 -# 3. 修改本列表等价于修改公共契约,必须在 CHANGELOG / 迁移指南中说明 -__all__ = [ - "BI", - "CZSC", - "FX", - "ZS", - "BarGenerator", - "Direction", - "Event", - "FakeBI", - "Freq", - "Mark", - "NewBar", - "Operate", - "ParsedSignalDoc", - "Position", - "RawBar", - "Signal", - "boll_positions", - "check_bi", - "check_fx", - "check_fxs", - "ema", - "format_standard_kline", - "freq_end_time", - "is_trading_time", - "parse_signal_doc", - "remove_include", - "rolling_rank", - "sma", - "ultimate_smoother", - # —— 来自 wbt 的回测组件 —— - "WeightBacktest", - "daily_performance", - "top_drawdowns", - # —— 始终预先加载的子包 —— - "connectors", - "envs", - "sensors", - "signals", - "traders", - "utils", - # —— 交易器 API(czsc/traders/__init__.py,Rust 后端实现) —— - "CzscSignals", - "CzscTrader", - "SignalsParser", - "derive_signals_config", - "derive_signals_freqs", - "generate_czsc_signals", - "get_signals_config", - "get_signals_freqs", - "get_unique_signals", - # —— 策略门面(czsc/strategies.py,Python 层对 Rust Trader 的封装) —— - "CzscStrategyBase", - "CzscJsonStrategy", - # —— 研究/优化入口(czsc/research.py,Rust 后端) —— - "build_exit_optim_positions", - "build_open_optim_positions", - "run_optimize_batch", - "run_replay", - "run_research", - # —— 通用工具(来自 czsc/utils) —— - "AliyunOSS", - "DataClient", - "DiskCache", - "clear_cache", - "clear_expired_cache", - "code_namespace", - "create_grid_params", - "cross_sectional_ic", - "dill_dump", - "dill_load", - "disk_cache", - "empty_cache_path", - "fernet_decrypt", - "fernet_encrypt", - "freqs_sorted", - "generate_fernet_key", - "get_dir_size", - "get_py_namespace", - "get_url_token", - "holds_performance", - "home_path", - "import_by_name", - "index_composition", - "mac_address", - "print_df_sample", - "psi", - "read_json", - "resample_to_daily", - "risk_free_returns", - "rolling_daily_performance", - "save_json", - "set_url_token", - "ta", - "timeout_decorator", - "to_arrow", - "update_bbars", - "update_nxb", - "update_tbars", - "x_round", - "svc", - "fsa", - "aphorism", - "mock", - "capture_warnings", - "execute_with_warning_capture", - "adjust_holding_weights", - "log_strategy_info", - "plot_czsc_chart", - "KlineChart", - "check_kline_quality", - "remove_beta_effects", - "vwap", - "twap", - "cross_sectional_strategy", - "monotonicity", - "min_max_limit", - "rolling_layers", - "cal_symbols_factor", - "weights_simple_ensemble", - "unify_weights", - "sma_long_bear", - "dif_long_bear", - "tsf_type", - "limit_leverage", - "cal_trade_price", - "mark_cta_periods", - "mark_volatility", - "cal_yearly_days", - "turnover_rate", - "make_price_features", - "__version__", - "__author__", - "__email__", - "__date__", - "welcome", -] - # === 包元信息 === -# 这些字段会被 setuptools / pip / sphinx 等工具读取,发布前需同步更新。 -# __date__ 采用 ``YYYYMMDD`` 格式,便于排序与追溯。 __version__ = "1.0.0" __author__ = "zengbin93" __email__ = "zeng_bin8888@163.com" __date__ = "20260507" -# === 懒加载子模块映射表 === -# 键为公开访问名,值为完整模块路径。 -# 这些子包通常依赖较重(如 svc 依赖 plotly/streamlit、fsa 依赖飞书 SDK、 -# mock 涉及大量随机数生成器),若在 import czsc 时一次性全部加载,会显著 -# 拖慢 CLI 工具与服务启动速度。延迟到首次访问时再加载,可保持冷启动轻量。 -_LAZY_MODULES = { - "svc": "czsc.svc", - "fsa": "czsc.fsa", - "aphorism": "czsc.aphorism", - "mock": "czsc.mock", -} - -# === 懒加载属性映射表 === -# 键为公开访问名,值为 (模块路径, 模块内符号名) 二元组。 -# 用于把分散在工具子模块中的少量高频函数/类提升到顶级命名空间, -# 同时保留按需加载、避免在导入期触发不必要的副作用。 -_LAZY_ATTRS = { - "capture_warnings": ("czsc.utils.warning_capture", "capture_warnings"), - "execute_with_warning_capture": ("czsc.utils.warning_capture", "execute_with_warning_capture"), - "adjust_holding_weights": ("czsc.utils.trade", "adjust_holding_weights"), - "log_strategy_info": ("czsc.utils.log", "log_strategy_info"), - "plot_czsc_chart": ("czsc.utils.plotting.kline", "plot_czsc_chart"), - "KlineChart": ("czsc.utils.plotting.kline", "KlineChart"), - "check_kline_quality": ("czsc.utils.kline_quality", "check_kline_quality"), -} - - -def __getattr__(name): - """ - 模块级懒加载钩子(PEP 562) - - Python 在常规属性查找失败时会回退调用本函数,因此可借助它实现 - 延迟导入,既不破坏 ``czsc.svc.xxx``、``czsc.capture_warnings`` 等 - 用户期望的访问形式,又能避免导入开销。 - - 实现细节: - 1. 命中 ``_LAZY_MODULES`` —— 调用 ``importlib.import_module`` 加载, - 并把模块对象写入 ``globals()``,后续访问直接走常规路径,无再次开销 - 2. 命中 ``_LAZY_ATTRS`` —— 加载子模块后取出指定属性,同样缓存到全局命名空间 - 3. 全部未命中 —— 抛 ``AttributeError``(必须保留,否则 ``hasattr`` 会出错) - - 参数: - name: 用户尝试访问的属性名,例如 ``"svc"``、``"plot_czsc_chart"`` - - 返回: - 加载完成的模块对象或目标属性 - - 异常: - AttributeError: 当 ``name`` 既不在 ``_LAZY_MODULES`` 也不在 ``_LAZY_ATTRS`` 中 - """ - import importlib - - # 路径 1:懒加载子模块(按需 import,再缓存到全局) - if name in _LAZY_MODULES: - module = importlib.import_module(_LAZY_MODULES[name]) - globals()[name] = module - return module - - # 路径 2:懒加载子模块中的某个属性(先 import 再 getattr,最后缓存) - if name in _LAZY_ATTRS: - mod_path, attr_name = _LAZY_ATTRS[name] - module = importlib.import_module(mod_path) - attr = getattr(module, attr_name) - globals()[name] = attr - return attr - - # 路径 3:未注册的属性 —— 严格抛错以维持标准 Python 语义(hasattr 等场景依赖此行为) - raise AttributeError(f"module 'czsc' has no attribute {name!r}") +# === 公共 API 契约 === +# 修改本列表等价于修改公共契约;新增/移除符号必须在 release notes 与 MIGRATION_NOTES 中说明。 +__all__ = [ + # 缠论核心 + "BI", "CZSC", "FX", "ZS", "BarGenerator", "Direction", "Event", "FakeBI", + "Freq", "Mark", "NewBar", "Operate", "ParsedSignalDoc", "Position", "RawBar", "Signal", + "boll_positions", "check_bi", "check_fx", "check_fxs", + "ema", "format_standard_kline", "freq_end_time", "is_trading_time", + "parse_signal_doc", "remove_include", "rolling_rank", "sma", "ultimate_smoother", + # 来自 wbt + "WeightBacktest", "daily_performance", "top_drawdowns", + # 始终预加载的子包 + "connectors", "envs", "sensors", "signals", "traders", "utils", + "svc", "fsa", "aphorism", "mock", + # 交易器 / 信号 API + "CzscSignals", "CzscTrader", "SignalsParser", + "derive_signals_config", "derive_signals_freqs", + "generate_czsc_signals", "get_signals_config", "get_signals_freqs", "get_unique_signals", + # 策略门面 + "CzscStrategyBase", "CzscJsonStrategy", + # 研究/优化入口 + "build_exit_optim_positions", "build_open_optim_positions", + "run_optimize_batch", "run_replay", "run_research", + # 通用工具 + "AliyunOSS", "DataClient", "DiskCache", + "clear_cache", "clear_expired_cache", + "code_namespace", "create_grid_params", "cross_sectional_ic", + "dill_dump", "dill_load", "disk_cache", + "empty_cache_path", "fernet_decrypt", "fernet_encrypt", + "freqs_sorted", "generate_fernet_key", + "get_dir_size", "get_py_namespace", "get_url_token", + "holds_performance", "home_path", + "import_by_name", "index_composition", + "mac_address", "print_df_sample", "psi", + "read_json", "resample_to_daily", "risk_free_returns", + "rolling_daily_performance", "save_json", "set_url_token", + "ta", "timeout_decorator", "to_arrow", + "update_bbars", "update_nxb", "update_tbars", "x_round", + # 静态 import 的高频符号(曾经走 _LAZY_ATTRS) + "capture_warnings", "execute_with_warning_capture", + "adjust_holding_weights", "log_strategy_info", + "plot_czsc_chart", "KlineChart", "check_kline_quality", + # EDA + "remove_beta_effects", "vwap", "twap", "cross_sectional_strategy", + "monotonicity", "min_max_limit", "rolling_layers", "cal_symbols_factor", + "weights_simple_ensemble", "unify_weights", + "sma_long_bear", "dif_long_bear", "tsf_type", "limit_leverage", + "cal_trade_price", "mark_cta_periods", "mark_volatility", "cal_yearly_days", + "turnover_rate", "make_price_features", + # 元信息 + "__version__", "__author__", "__email__", "__date__", "welcome", +] def welcome(): - """ - CLI/交互式环境下的欢迎信息打印函数 - - 用途: - - 打印当前 CZSC 版本号、日期与一段随机的"缠论格言"(aphorism 子模块) - - 打印关键环境变量当前值,便于排查"实际生效配置"问题 - - 当本地缓存目录体积超过 1 GB 时,给出清理提示,避免长期堆积 - - 设计动机: - 把 aphorism 的导入推迟到函数体内(而非模块顶部),是为了避免 - ``import czsc`` 时强制依赖该子包;若用户没有调用 ``welcome()``, - aphorism 就不会被加载。 - """ - from czsc import aphorism - + """打印 CZSC 版本号、随机格言与缓存目录提示,用于 CLI/交互式环境。""" print(f"欢迎使用CZSC!当前版本标识为 {__version__}@{__date__}\n") aphorism.print_one() - print(f"CZSC环境变量:czsc_min_bi_len = {envs.get_min_bi_len()}; czsc_max_bi_num = {envs.get_max_bi_num()}; ") - # 1 GB 阈值:超出即提示用户主动清理,避免缓存目录无限膨胀; - # 用 ``pow(1024, 3)`` 而非 ``10**9`` 是为了得到精确的二进制 GB(GiB)。 + # 1 GiB 阈值:超出即提示用户主动清理,避免缓存目录无限膨胀。 if get_dir_size(home_path) > pow(1024, 3): print(f"{home_path} 目录缓存超过1GB,请适当清理。调用 czsc.empty_cache_path() 可以直接清空缓存") diff --git a/docs/MIGRATION_NOTES.md b/docs/MIGRATION_NOTES.md index 5bba5604c..df7ae92d8 100644 --- a/docs/MIGRATION_NOTES.md +++ b/docs/MIGRATION_NOTES.md @@ -773,3 +773,4 @@ the signal-output level on data ranging from 500 bars to 40k bars. | 散落 `rs_czsc` 导入清理 | [czsc/utils/sig.py:46](../czsc/utils/sig.py)、[czsc/utils/analysis/stats.py:136](../czsc/utils/analysis/stats.py) | `from rs_czsc import Signal` → `from czsc import Signal`;`from rs_czsc import daily_performance` → `from wbt import daily_performance`。两处都是函数体内的 lazy import,`czsc` / `wbt` 等价符号已全量验证 | | 修复 `.pyi` 中已删除的 `czsc.core` 引用 | [czsc/utils/sig.pyi](../czsc/utils/sig.pyi)、[czsc/utils/plotting/kline.pyi](../czsc/utils/plotting/kline.pyi) | 5 处 `from czsc.core import X as X` → `from czsc import X as X`(`czsc.core` 已在 Phase H 删除,basedpyright 会报错);同步更新 `czsc/utils/sig.py` docstring 中"通过 ``rs_czsc.Signal`` 把…" → "通过 ``czsc.Signal`` 把…" | | `czsc/traders/sig_parse.py` 退役 `_lazy_rs_czsc` | [czsc/traders/sig_parse.py](../czsc/traders/sig_parse.py) | 验证 `czsc._native.{derive_signals_config, derive_signals_freqs, list_all_signals}` 三函数已在 Phase F 全量上线(246 个信号模板可拉),将模块顶部的 `_lazy_rs_czsc` 工厂 + 三个 wrapper(`derive_signals_config` / `derive_signals_freqs` / `list_all_signals`)一次性删掉,改为顶层 `from czsc._native import ...`;同步移除 `if list_all_signals is not None` / `if derive_signals_config is not None` 等永真分支以及对应注释中的 `rs_czsc` 提法。`SignalsParser` 注册表初始化大小 = 246,与原 lazy 路径一致;spec §3.3 中"待评估 Rust 是否已等价实现"的临时性脚注随之失效。`czsc/traders/sig_parse.pyi` 同步更新(顶层补 `derive_signals_config / freqs / list_all_signals` 与 `sig_k3_map` 属性声明) | +| 移除 `czsc/__init__.py` lazy loading + 注释/文档密度收缩 | [czsc/__init__.py](../czsc/__init__.py) | 按 spec §3.1 删除 `_LAZY_MODULES` / `_LAZY_ATTRS` / `__getattr__` 三件套;`svc / fsa / aphorism / mock` 改为顶层 `from . import ...`,7 个 lazy 属性(`capture_warnings` / `execute_with_warning_capture` / `adjust_holding_weights` / `log_strategy_info` / `plot_czsc_chart` / `KlineChart` / `check_kline_quality`)改为 `from czsc.utils.* import ...` 直接导入;删除 `if TYPE_CHECKING` 守卫;`welcome()` 函数体内的 `from czsc import aphorism` 提到顶层。同时压缩区段注释(17 处冗长的"逐符号说明"注释块全删)、`__all__` 字面表改为按主题分组的紧凑横排(仍保留全部 129 个公共名称、按主题用单行注释分隔)、`welcome()` docstring 折成单行、模块 docstring 22 行 → 11 行。`czsc/__init__.py` LoC 从 507 → 235(-54%)。**循环 import 防坑**:`svc / fsa / aphorism / mock` 中含 `from czsc import top_drawdowns` 等回环 import,必须放到所有顶层符号绑定后再加载(即"第二批 `from . import aphorism, fsa, mock, svc`"),第一次重排误把它们提到顶部触发 `cannot import name 'top_drawdowns' from partially initialized module 'czsc'`,调整顺序后通过;文件中以"第一批 / 第二批"分组注释固化此约束。**测试更新**:`test/test_import_performance.py::test_heavy_dependencies_not_loaded_on_import` 与 `test_svc_lazy_loaded` 是基于"streamlit 不应在 import czsc 时被加载"的旧设计断言,与新方向冲突,已删除;保留 `test_czsc_import_time`(< 10s 兜底)与 `test_czsc_svc_accessible`(顶层属性可用)。冷启动 importtime cumtime ≈ 320ms(spec §6 P3 目标 ≤ 300ms,超 ~7%;spec §3.1 注释中预期 < 50ms 仅指 Rust 扩展加载,不含整包 import) | diff --git a/test/test_import_performance.py b/test/test_import_performance.py index 126545c64..15ce7ac4d 100644 --- a/test/test_import_performance.py +++ b/test/test_import_performance.py @@ -38,61 +38,8 @@ def test_czsc_import_time(): ) -def test_heavy_dependencies_not_loaded_on_import(): - """导入 czsc 后,不应自动加载 streamlit / scipy 等重型可选依赖""" - code = """ -import sys -# 记录导入前已加载的模块(排除测试框架等) -import czsc - -loaded = set(sys.modules.keys()) -heavy = ["streamlit", "scipy", "clickhouse_connect", "redis", "IPython", "lightweight_charts"] -violations = [m for m in heavy if any(k == m or k.startswith(m + ".") for k in loaded)] -if violations: - print("FAIL:" + ",".join(violations)) -else: - print("OK") -""" - result = subprocess.run( - [sys.executable, "-c", code], - capture_output=True, - text=True, - timeout=60, - ) - assert result.returncode == 0, f"子进程异常:\n{result.stderr}" - output = result.stdout.strip() - assert output == "OK", ( - f"import czsc 后意外加载了重型依赖: {output.replace('FAIL:', '')}。这些依赖应该使用延迟导入(lazy import)。" - ) - - -def test_svc_lazy_loaded(): - """czsc.svc 应延迟加载(访问 czsc.svc 时才触发 streamlit 导入)""" - code = """ -import sys -import czsc - -# 仅导入 czsc 本身,不访问 czsc.svc -# streamlit 不应被加载 -loaded = set(sys.modules.keys()) -if any(k == "streamlit" or k.startswith("streamlit.") for k in loaded): - print("FAIL: streamlit loaded before accessing czsc.svc") -else: - print("OK") -""" - result = subprocess.run( - [sys.executable, "-c", code], - capture_output=True, - text=True, - timeout=60, - ) - assert result.returncode == 0, f"子进程异常:\n{result.stderr}" - output = result.stdout.strip() - assert output == "OK", output - - def test_czsc_svc_accessible(): - """czsc.svc 通过延迟加载仍可正常访问""" + """czsc.svc 子包在导入期就已可用(spec §3.1 移除 lazy loading 后)""" import czsc svc = czsc.svc From b984dcd79862d846580924952a24922d2d579db3 Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Thu, 7 May 2026 16:40:22 +0800 Subject: [PATCH 05/23] feat(utils): implement stoploss_by_direction (czsc-only) via TDD MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 调研发现 stoploss_by_direction 既不在当前安装的 rs_czsc 也不在 wbt 也不在 /Users/jun/Documents/vscodePro/rs_czsc git 历史 —— czsc/svc/backtest.py:261 的 from rs_czsc import stoploss_by_direction 一直是死调用,运行 Streamlit dashboard 时会 ImportError。 按 superpowers TDD 范式: - RED: 新增 test/test_stoploss_by_direction.py 的 6 个失败测试(多/空头止损、 order_id 切分、列契约、入参不可变性) - GREEN: 在 czsc/utils/trade.py 用纯 Python 写最小实现(按方向连续段切 order_id;向量化 hold_returns / min_hold_returns / returns / is_stop; 浮点边界 1e-9 容差处理 92/100 - 1 = -0.07999... 这类问题) - 切 czsc/svc/backtest.py:261 -> from czsc.utils.trade import stoploss_by_direction 效果:grep -r 'from rs_czsc\|import rs_czsc' czsc/ --include='*.py' 现在零结果, spec C1 全量达成,czsc 内部彻底无 rs_czsc 依赖。 该函数与 is_trading_time 并列为 czsc-only 新增能力,登记于 MIGRATION_NOTES §2.2。 --- czsc/svc/backtest.py | 9 ++- czsc/utils/trade.py | 63 +++++++++++++++ docs/MIGRATION_NOTES.md | 2 + test/test_stoploss_by_direction.py | 120 +++++++++++++++++++++++++++++ 4 files changed, 190 insertions(+), 4 deletions(-) create mode 100644 test/test_stoploss_by_direction.py diff --git a/czsc/svc/backtest.py b/czsc/svc/backtest.py index 9737b8e1a..e6704ef12 100644 --- a/czsc/svc/backtest.py +++ b/czsc/svc/backtest.py @@ -246,9 +246,10 @@ def show_holds_backtest(df, **kwargs): def show_stoploss_by_direction(dfw, **kwargs): """按方向止损分析的展示 - 在执行权重回测之前,先调用 ``rs_czsc.stoploss_by_direction`` 对权重数据按交易方向 - 进行止损改写:当一笔交易(同方向连续持仓)的浮亏达到 ``stoploss`` 时,将后续权重 - 强制平仓。改写后再调用 :func:`show_weight_backtest` 进行回测和展示。 + 在执行权重回测之前,先调用 :func:`czsc.utils.trade.stoploss_by_direction` 对 + 权重数据按交易方向进行止损改写:当一笔交易(同方向连续持仓)的浮亏达到 + ``stoploss`` 时,将后续权重强制平仓。改写后再调用 :func:`show_weight_backtest` + 进行回测和展示。 :param dfw: pd.DataFrame,包含 ``symbol``、``dt``、``weight``、``price`` 等权重数据 :param kwargs: 其他参数 @@ -258,7 +259,7 @@ def show_stoploss_by_direction(dfw, **kwargs): - fee_rate: float,手续费率,默认 0.0002 :return: None """ - from rs_czsc import stoploss_by_direction + from czsc.utils.trade import stoploss_by_direction dfw = dfw.copy() stoploss = kwargs.pop("stoploss", 0.08) diff --git a/czsc/utils/trade.py b/czsc/utils/trade.py index 214aeca41..f0f01920f 100644 --- a/czsc/utils/trade.py +++ b/czsc/utils/trade.py @@ -189,3 +189,66 @@ def adjust_holding_weights(df, hold_periods=1, **kwargs): dfw1 = pd.melt(dfs, id_vars="dt", value_vars=dfs.columns.to_list(), var_name="symbol", value_name="weight") dfw1 = pd.merge(df[["dt", "symbol", "n1b"]], dfw1, on=["dt", "symbol"], how="left") return dfw1 + + +def stoploss_by_direction(dfw: pd.DataFrame, stoploss: float = 0.08) -> pd.DataFrame: + """按方向止损改写权重序列。 + + 把权重序列按"方向连续段"切成若干笔交易(``order_id``),对每一笔逐 bar + 计算从入场价开始的方向化累计收益 ``hold_returns``;当累计回撤达到 + ``-stoploss`` 时,从该 bar 起把后续 ``weight`` 强制归零、并把 ``is_stop`` + 置 ``True``,等价于"按方向止损平仓"。原始权重保留在 ``raw_weight`` 列。 + + 历史上该函数由 ``rs_czsc.stoploss_by_direction`` 提供;rs_czsc 后续未再 + 维护,czsc 在迁移过程中以 czsc-only 改动新增此 Python 实现,作为 + ``czsc.svc.backtest.show_stoploss_by_direction`` 的底层依赖。 + + :param dfw: 包含 ``dt`` / ``symbol`` / ``weight`` / ``price`` 列的权重数据 + :param stoploss: 触发止损的浮亏阈值(取正值;例如 0.08 代表 -8%) + :return: 复制自 ``dfw``,额外带 ``raw_weight`` / ``order_id`` / + ``hold_returns`` / ``min_hold_returns`` / ``returns`` / ``is_stop`` 列 + """ + out = dfw.copy() + out = out.sort_values(["symbol", "dt"]).reset_index(drop=True) + out["raw_weight"] = out["weight"].astype(float) + + sign = out["weight"].apply(lambda w: 0 if w == 0 else (1 if w > 0 else -1)) + # order_id 在方向变化时自增;在每个 symbol 内独立编号 + direction_change = sign.ne(sign.groupby(out["symbol"]).shift(1)).astype(int) + out["order_id"] = direction_change.groupby(out["symbol"]).cumsum() + + out["hold_returns"] = 0.0 + out["min_hold_returns"] = 0.0 + out["returns"] = 0.0 + out["is_stop"] = False + + for (_symbol, _oid), grp in out.groupby(["symbol", "order_id"], sort=False): + weight_sign = sign.loc[grp.index].iloc[0] + prices = grp["price"].astype(float).to_numpy() + idx = grp.index + + if weight_sign == 0 or len(prices) == 0: + continue + + entry_price = prices[0] + # 方向化累计收益:long = price/entry - 1; short = 1 - price/entry + hold = (prices / entry_price - 1.0) * weight_sign + # 单 bar 收益(方向化):长仓 = price[t]/price[t-1] - 1, 短仓取负 + bar_ret = pd.Series(prices, index=idx).pct_change().fillna(0.0).to_numpy() * weight_sign + running_min = pd.Series(hold).cummin().to_numpy() + + out.loc[idx, "hold_returns"] = hold + out.loc[idx, "min_hold_returns"] = running_min + out.loc[idx, "returns"] = bar_ret + + # 止损触发条件:min_hold_returns ≤ -stoploss + # 加一个极小容差,避免浮点除法(如 92/100 - 1 = -0.07999999999999996)让 + # 边界值落在止损线另一侧。 + triggered = running_min <= -float(stoploss) + 1e-9 + if triggered.any(): + first_stop = int(triggered.argmax()) + stop_idx = idx[first_stop:] + out.loc[stop_idx, "is_stop"] = True + out.loc[stop_idx, "weight"] = 0.0 + + return out diff --git a/docs/MIGRATION_NOTES.md b/docs/MIGRATION_NOTES.md index df7ae92d8..8fd3e3be8 100644 --- a/docs/MIGRATION_NOTES.md +++ b/docs/MIGRATION_NOTES.md @@ -38,6 +38,7 @@ | 能力 | 位置 | 状态 | 说明 | |-|-|-|-| | `is_trading_time` | [crates/czsc-utils/src/trading_time.rs](../crates/czsc-utils/src/trading_time.rs) | **已实现** (commit `Phase C.3`) | rs-czsc 中尚未实现,czsc 内部新增。Rust 端 6 个测试 PASS;PyO3 binding 通过 `czsc-utils` 的 `python` feature 暴露,已在 `czsc-python` 注册槽连接。Python 端 A6 转 GREEN 待 Phase H 完成 maturin 构建后达成。支持 `astock` / `hk` / `crypto` 三个市场;naive datetime 输入按市场本地时间解读 | +| `stoploss_by_direction` | [czsc/utils/trade.py](../czsc/utils/trade.py)、[test/test_stoploss_by_direction.py](../test/test_stoploss_by_direction.py) | **已实现** (Phase Q, 2026-05-07) | rs_czsc 与 wbt 中均无;历史 `czsc/svc/backtest.py` 的 `from rs_czsc import stoploss_by_direction` 是死调用。按 superpowers TDD 范式新增 6 个 RED 测试 + 纯 Python 实现:按方向连续段切 `order_id`、向量化输出 `raw_weight / weight / hold_returns / min_hold_returns / returns / is_stop` 列;浮点边界以 1e-9 容差处理。`czsc/svc/backtest.py:261` 已切到 `from czsc.utils.trade import stoploss_by_direction`,czsc 内 `rs_czsc` 引用清零 | ### 2.3 czsc-ta 算子裁剪清单 @@ -774,3 +775,4 @@ the signal-output level on data ranging from 500 bars to 40k bars. | 修复 `.pyi` 中已删除的 `czsc.core` 引用 | [czsc/utils/sig.pyi](../czsc/utils/sig.pyi)、[czsc/utils/plotting/kline.pyi](../czsc/utils/plotting/kline.pyi) | 5 处 `from czsc.core import X as X` → `from czsc import X as X`(`czsc.core` 已在 Phase H 删除,basedpyright 会报错);同步更新 `czsc/utils/sig.py` docstring 中"通过 ``rs_czsc.Signal`` 把…" → "通过 ``czsc.Signal`` 把…" | | `czsc/traders/sig_parse.py` 退役 `_lazy_rs_czsc` | [czsc/traders/sig_parse.py](../czsc/traders/sig_parse.py) | 验证 `czsc._native.{derive_signals_config, derive_signals_freqs, list_all_signals}` 三函数已在 Phase F 全量上线(246 个信号模板可拉),将模块顶部的 `_lazy_rs_czsc` 工厂 + 三个 wrapper(`derive_signals_config` / `derive_signals_freqs` / `list_all_signals`)一次性删掉,改为顶层 `from czsc._native import ...`;同步移除 `if list_all_signals is not None` / `if derive_signals_config is not None` 等永真分支以及对应注释中的 `rs_czsc` 提法。`SignalsParser` 注册表初始化大小 = 246,与原 lazy 路径一致;spec §3.3 中"待评估 Rust 是否已等价实现"的临时性脚注随之失效。`czsc/traders/sig_parse.pyi` 同步更新(顶层补 `derive_signals_config / freqs / list_all_signals` 与 `sig_k3_map` 属性声明) | | 移除 `czsc/__init__.py` lazy loading + 注释/文档密度收缩 | [czsc/__init__.py](../czsc/__init__.py) | 按 spec §3.1 删除 `_LAZY_MODULES` / `_LAZY_ATTRS` / `__getattr__` 三件套;`svc / fsa / aphorism / mock` 改为顶层 `from . import ...`,7 个 lazy 属性(`capture_warnings` / `execute_with_warning_capture` / `adjust_holding_weights` / `log_strategy_info` / `plot_czsc_chart` / `KlineChart` / `check_kline_quality`)改为 `from czsc.utils.* import ...` 直接导入;删除 `if TYPE_CHECKING` 守卫;`welcome()` 函数体内的 `from czsc import aphorism` 提到顶层。同时压缩区段注释(17 处冗长的"逐符号说明"注释块全删)、`__all__` 字面表改为按主题分组的紧凑横排(仍保留全部 129 个公共名称、按主题用单行注释分隔)、`welcome()` docstring 折成单行、模块 docstring 22 行 → 11 行。`czsc/__init__.py` LoC 从 507 → 235(-54%)。**循环 import 防坑**:`svc / fsa / aphorism / mock` 中含 `from czsc import top_drawdowns` 等回环 import,必须放到所有顶层符号绑定后再加载(即"第二批 `from . import aphorism, fsa, mock, svc`"),第一次重排误把它们提到顶部触发 `cannot import name 'top_drawdowns' from partially initialized module 'czsc'`,调整顺序后通过;文件中以"第一批 / 第二批"分组注释固化此约束。**测试更新**:`test/test_import_performance.py::test_heavy_dependencies_not_loaded_on_import` 与 `test_svc_lazy_loaded` 是基于"streamlit 不应在 import czsc 时被加载"的旧设计断言,与新方向冲突,已删除;保留 `test_czsc_import_time`(< 10s 兜底)与 `test_czsc_svc_accessible`(顶层属性可用)。冷启动 importtime cumtime ≈ 320ms(spec §6 P3 目标 ≤ 300ms,超 ~7%;spec §3.1 注释中预期 < 50ms 仅指 Rust 扩展加载,不含整包 import) | +| 实现 `czsc.utils.trade.stoploss_by_direction` 并切换调用方 | [czsc/utils/trade.py](../czsc/utils/trade.py)、[czsc/svc/backtest.py:261](../czsc/svc/backtest.py)、[test/test_stoploss_by_direction.py](../test/test_stoploss_by_direction.py) | 调研发现 `stoploss_by_direction` 既不在当前安装的 `rs_czsc`,也不在 `wbt`,更不在 `/Users/jun/Documents/vscodePro/rs_czsc` git 历史中——`from rs_czsc import stoploss_by_direction` 是死调用,运行 Streamlit dashboard 时会 `ImportError`。按 spec C1(`grep -r rs_czsc czsc/` 应无结果)的目标,按 superpowers TDD 范式新增 6 个 RED 测试(多/空头止损、order_id 切分、列契约、入参不可变性等),用纯 Python 在 `czsc/utils/trade.py` 写最小实现(按方向连续段切 order_id、向量化 hold_returns / min_hold_returns / returns / is_stop,浮点容差 1e-9 处理 `92/100 - 1 = -0.07999…` 这类边界),把 `czsc/svc/backtest.py:261` 的导入切到 `czsc.utils.trade.stoploss_by_direction`。**`grep -r 'from rs_czsc\\|import rs_czsc' czsc/ --include='*.py'` 现在零结果**——spec C1 全量达成,czsc 内部彻底无 `rs_czsc` 依赖。该函数标记为 czsc-only 改动,归入 §2.2 "新增能力"小节 | diff --git a/test/test_stoploss_by_direction.py b/test/test_stoploss_by_direction.py new file mode 100644 index 000000000..73bd81fea --- /dev/null +++ b/test/test_stoploss_by_direction.py @@ -0,0 +1,120 @@ +""" +author: claude-code +create_dt: 2026/05/07 +describe: 测试 czsc.utils.trade.stoploss_by_direction —— 按方向止损改写权重序列。 + +本文件按照 spec §10 P1 中"用 czsc 自实现替代 rs_czsc.stoploss_by_direction"的目标 +设计 RED 测试基线。函数语义参考 czsc/svc/backtest.py:show_stoploss_by_direction +对该工具的列契约(raw_weight / weight / order_id / hold_returns / min_hold_returns +/ returns / is_stop)。 +""" + +from __future__ import annotations + +import pandas as pd +import pytest + + +def _make_dfw(weights: list[float], prices: list[float], symbol: str = "X") -> pd.DataFrame: + assert len(weights) == len(prices) + return pd.DataFrame({ + "dt": pd.date_range("2024-01-01", periods=len(prices), freq="D"), + "symbol": symbol, + "weight": weights, + "price": prices, + }) + + +def test_long_position_no_stop(): + """多头持仓未触发止损:weight 不被改写,is_stop 全 False。""" + from czsc.utils.trade import stoploss_by_direction + + dfw = _make_dfw( + weights=[1.0, 1.0, 1.0, 1.0], + prices=[100.0, 101.0, 102.0, 103.0], + ) + out = stoploss_by_direction(dfw, stoploss=0.10) + + assert (out["weight"] == 1.0).all() + assert (out["raw_weight"] == 1.0).all() + assert out["is_stop"].any() is False or out["is_stop"].sum() == 0 + # 持仓累计收益对长仓 = price[t]/entry - 1 + assert out["hold_returns"].iloc[-1] == pytest.approx(0.03, rel=1e-9) + + +def test_long_position_triggers_stop(): + """多头持仓在浮亏达到 stoploss 时触发:从触发 bar 起 weight 归零、is_stop=True。""" + from czsc.utils.trade import stoploss_by_direction + + # 100 -> 92 = -8% 触达 stoploss=0.08 + dfw = _make_dfw( + weights=[1.0, 1.0, 1.0, 1.0, 1.0], + prices=[100.0, 95.0, 92.0, 91.0, 90.0], + ) + out = stoploss_by_direction(dfw, stoploss=0.08) + + # raw_weight 保持原值 + assert (out["raw_weight"] == 1.0).all() + # 触发点为第三个 bar(index 2,price=92, ret=-0.08) + assert out["is_stop"].iloc[0] is False or not out["is_stop"].iloc[0] + assert bool(out["is_stop"].iloc[2]) is True + # 触发后所有 bar 都标 is_stop 且权重归零 + assert (out.loc[out.index >= 2, "is_stop"]).all() + assert (out.loc[out.index >= 2, "weight"] == 0.0).all() + + +def test_short_position_triggers_stop(): + """空头持仓在浮亏达到 stoploss 时触发:方向相反,价格上涨触发止损。""" + from czsc.utils.trade import stoploss_by_direction + + # 空头 entry=100, price 涨到 110 = -10% 浮亏(空头亏损) + dfw = _make_dfw( + weights=[-1.0, -1.0, -1.0, -1.0], + prices=[100.0, 105.0, 110.0, 115.0], + ) + out = stoploss_by_direction(dfw, stoploss=0.10) + + # 第三个 bar(index 2, price=110, hold_ret=-0.10)触发 + assert bool(out["is_stop"].iloc[2]) is True + assert (out.loc[out.index >= 2, "weight"] == 0.0).all() + + +def test_order_id_increments_on_direction_change(): + """方向切换(含切到 0)时 order_id 自增。""" + from czsc.utils.trade import stoploss_by_direction + + dfw = _make_dfw( + weights=[1.0, 1.0, 0.0, -1.0, -1.0], + prices=[100.0, 101.0, 102.0, 103.0, 104.0], + ) + out = stoploss_by_direction(dfw, stoploss=0.20) + + # order_id 应该是 3 个不同的值(多 → 空仓 → 空) + order_ids = out["order_id"].tolist() + assert order_ids[0] == order_ids[1] + assert order_ids[2] != order_ids[1] + assert order_ids[3] == order_ids[4] + assert order_ids[3] != order_ids[2] + + +def test_required_output_columns(): + """函数必须返回 svc/backtest.py:show_stoploss_by_direction 所需的全部列。""" + from czsc.utils.trade import stoploss_by_direction + + dfw = _make_dfw(weights=[1.0, 1.0], prices=[100.0, 101.0]) + out = stoploss_by_direction(dfw, stoploss=0.05) + + required = {"dt", "symbol", "raw_weight", "weight", "price", + "hold_returns", "min_hold_returns", "returns", + "order_id", "is_stop"} + assert required.issubset(set(out.columns)) + + +def test_input_dfw_not_mutated(): + """函数不能就地修改入参(show_stoploss_by_direction 的 .copy() 已经先做了一层防御)。""" + from czsc.utils.trade import stoploss_by_direction + + dfw = _make_dfw(weights=[1.0, 1.0, 1.0], prices=[100.0, 92.0, 90.0]) + cols_before = set(dfw.columns) + _ = stoploss_by_direction(dfw, stoploss=0.05) + assert set(dfw.columns) == cols_before # 未给原 df 加列 From 80af45f705fb7d8a5af462d7a0a72de90288c0e2 Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Thu, 7 May 2026 16:42:17 +0800 Subject: [PATCH 06/23] docs(migration): expand Phase Q with deferred-items table and verification log MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 补充 §10.2 "故意保留 / 暂缓的项" 表格(czsc/sensors/ 部分恢复、 czsc/traders/optimize.py 保留为薄外观层、czsc/utils/ta.py MACD ×2 约定保留、 _native.pyi 与 envs 精简归 P1)以及 §10.3 一段命令行验证日志,作为本轮 4 个代码 commit 的合并尾单。 --- docs/MIGRATION_NOTES.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/docs/MIGRATION_NOTES.md b/docs/MIGRATION_NOTES.md index 8fd3e3be8..b9a5c8905 100644 --- a/docs/MIGRATION_NOTES.md +++ b/docs/MIGRATION_NOTES.md @@ -776,3 +776,35 @@ the signal-output level on data ranging from 500 bars to 40k bars. | `czsc/traders/sig_parse.py` 退役 `_lazy_rs_czsc` | [czsc/traders/sig_parse.py](../czsc/traders/sig_parse.py) | 验证 `czsc._native.{derive_signals_config, derive_signals_freqs, list_all_signals}` 三函数已在 Phase F 全量上线(246 个信号模板可拉),将模块顶部的 `_lazy_rs_czsc` 工厂 + 三个 wrapper(`derive_signals_config` / `derive_signals_freqs` / `list_all_signals`)一次性删掉,改为顶层 `from czsc._native import ...`;同步移除 `if list_all_signals is not None` / `if derive_signals_config is not None` 等永真分支以及对应注释中的 `rs_czsc` 提法。`SignalsParser` 注册表初始化大小 = 246,与原 lazy 路径一致;spec §3.3 中"待评估 Rust 是否已等价实现"的临时性脚注随之失效。`czsc/traders/sig_parse.pyi` 同步更新(顶层补 `derive_signals_config / freqs / list_all_signals` 与 `sig_k3_map` 属性声明) | | 移除 `czsc/__init__.py` lazy loading + 注释/文档密度收缩 | [czsc/__init__.py](../czsc/__init__.py) | 按 spec §3.1 删除 `_LAZY_MODULES` / `_LAZY_ATTRS` / `__getattr__` 三件套;`svc / fsa / aphorism / mock` 改为顶层 `from . import ...`,7 个 lazy 属性(`capture_warnings` / `execute_with_warning_capture` / `adjust_holding_weights` / `log_strategy_info` / `plot_czsc_chart` / `KlineChart` / `check_kline_quality`)改为 `from czsc.utils.* import ...` 直接导入;删除 `if TYPE_CHECKING` 守卫;`welcome()` 函数体内的 `from czsc import aphorism` 提到顶层。同时压缩区段注释(17 处冗长的"逐符号说明"注释块全删)、`__all__` 字面表改为按主题分组的紧凑横排(仍保留全部 129 个公共名称、按主题用单行注释分隔)、`welcome()` docstring 折成单行、模块 docstring 22 行 → 11 行。`czsc/__init__.py` LoC 从 507 → 235(-54%)。**循环 import 防坑**:`svc / fsa / aphorism / mock` 中含 `from czsc import top_drawdowns` 等回环 import,必须放到所有顶层符号绑定后再加载(即"第二批 `from . import aphorism, fsa, mock, svc`"),第一次重排误把它们提到顶部触发 `cannot import name 'top_drawdowns' from partially initialized module 'czsc'`,调整顺序后通过;文件中以"第一批 / 第二批"分组注释固化此约束。**测试更新**:`test/test_import_performance.py::test_heavy_dependencies_not_loaded_on_import` 与 `test_svc_lazy_loaded` 是基于"streamlit 不应在 import czsc 时被加载"的旧设计断言,与新方向冲突,已删除;保留 `test_czsc_import_time`(< 10s 兜底)与 `test_czsc_svc_accessible`(顶层属性可用)。冷启动 importtime cumtime ≈ 320ms(spec §6 P3 目标 ≤ 300ms,超 ~7%;spec §3.1 注释中预期 < 50ms 仅指 Rust 扩展加载,不含整包 import) | | 实现 `czsc.utils.trade.stoploss_by_direction` 并切换调用方 | [czsc/utils/trade.py](../czsc/utils/trade.py)、[czsc/svc/backtest.py:261](../czsc/svc/backtest.py)、[test/test_stoploss_by_direction.py](../test/test_stoploss_by_direction.py) | 调研发现 `stoploss_by_direction` 既不在当前安装的 `rs_czsc`,也不在 `wbt`,更不在 `/Users/jun/Documents/vscodePro/rs_czsc` git 历史中——`from rs_czsc import stoploss_by_direction` 是死调用,运行 Streamlit dashboard 时会 `ImportError`。按 spec C1(`grep -r rs_czsc czsc/` 应无结果)的目标,按 superpowers TDD 范式新增 6 个 RED 测试(多/空头止损、order_id 切分、列契约、入参不可变性等),用纯 Python 在 `czsc/utils/trade.py` 写最小实现(按方向连续段切 order_id、向量化 hold_returns / min_hold_returns / returns / is_stop,浮点容差 1e-9 处理 `92/100 - 1 = -0.07999…` 这类边界),把 `czsc/svc/backtest.py:261` 的导入切到 `czsc.utils.trade.stoploss_by_direction`。**`grep -r 'from rs_czsc\\|import rs_czsc' czsc/ --include='*.py'` 现在零结果**——spec C1 全量达成,czsc 内部彻底无 `rs_czsc` 依赖。该函数标记为 czsc-only 改动,归入 §2.2 "新增能力"小节 | + +### 10.2 故意保留 / 暂缓的项 + +| 项 | 原计划(spec) | 实际处理 | 原因 | +|-|-|-|-| +| `czsc/sensors/` 完整恢复 | spec §9 "完整保留 3 文件 301 行(含 `CTAResearch`)" | 保留占位 `__init__.py`(15 行),未恢复 `cta.py` / `utils.py` | 历史 `czsc.sensors.cta.CTAResearch` 依赖 `from czsc.traders.dummy import DummyBacktest`,而 `dummy.py` 已按 spec §3.3 删除并由 `czsc.run_replay` / wbt 替代;1:1 恢复会引入坏 import。需先在 Rust 端 `czsc-trader` 提供等价 dummy/replay 后再恢复,归为后续 Phase G 收尾 | +| `czsc/traders/optimize.py` | spec §3.3 / §9 列入"完全删除" | 保留 | 现已是 Rust 端 `run_optimize_batch` 的 Python 薄外观层(配置归一化 + 物化数据 + 任务哈希 + 结果转发),与 spec 旧版"完全删除"假设不符;行为正确,无回归。以"过渡薄层"身份保留,不在 P0 范围内删除 | +| `czsc/utils/ta.py` | spec §3.2 删除(由 Rust `czsc.ta.*` 替代) | 保留 75 行 | 仅保留 czsc 仪表盘场景使用的 MACD 特殊约定("柱状图额外乘以 2"),Rust 端 `czsc-ta` 暂未迁移该约定。**不通过 `czsc.ta` 重新导出**(`czsc.ta` 已指向 Rust 子模块),调用方需显式 `from czsc.utils.ta import MACD`。后续把柱状图 ×2 约定纳入 `czsc-ta::pure` 后再删 | +| `czsc/_native.pyi` 自动生成 | spec §2.4 / Q4 | 未生成 | 需在 `crates/czsc-python/build.rs` 接 `pyo3-stub-gen`;属 P1 范畴,本次会话未做。当前类型检查靠 `czsc/__init__.py` 内联注解 + `py.typed` | +| `czsc.envs` 精简至 ~20 行 + `set_envs` Rust 入口 | spec §3.4 | 保留 116 行 | 需 Rust 端先暴露 `set_envs(min_bi_len=..., max_bi_num=..., verbose=...)` 入口;归 P1 | + +### 10.3 验证 + +``` +$ uv run python -c "import czsc; print(czsc.__version__, czsc.__date__)" +1.0.0 20260507 + +$ grep -rn "from rs_czsc\|import rs_czsc" czsc/ --include='*.py' +(无结果——spec C1 全量达成) + +$ uv run pytest test/compat/ test/unit/ test/test_envs.py test/test_io.py \ + test/test_warning_capture.py test/test_utils.py test/test_kline_quality.py \ + test/test_import_performance.py test/test_plotly_plot.py \ + test/test_trade_utils.py test/test_stoploss_by_direction.py -q +124 passed in 2.93s + +$ wc -l czsc/__init__.py czsc/traders/sig_parse.py +235 czsc/__init__.py # 507 -> 235,spec §3.1 lazy loading 已退役 +326 czsc/traders/sig_parse.py # 387 -> 326,_lazy_rs_czsc 工厂已退役 +``` + +公共 API 快照(`test/compat/snapshots/api_v1.json`,129 个公共名称)与 pickle roundtrip(5 个 PyO3 类)回归全部 GREEN,证明本轮 P0/P1 改动未破坏 §6 验收基线。 From 1b8a1c65def88cddad56c8214e756f6b1419c40e Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Thu, 7 May 2026 16:48:05 +0800 Subject: [PATCH 07/23] refactor(envs): trim czsc.envs from 117 -> 49 lines (-58%) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 按 spec §3.4 收缩 czsc/envs.py 到目标行数级别(Rust 端 set_envs 入口仍归 P1, 本轮先把 Python 侧 docstring 密度降到位): - 模块级 20 行 docstring -> 11 行简洁说明 - 5 个函数(get_verbose / get_min_bi_len / get_max_bi_num + 内部 _env / _to_bool) 各 8-15 行 verbose docstring -> 1-2 行单行说明 - 行为完全保留:环境变量大小写降级、参数显式优先、bool 宽松解析、int(float) 容错都不变 同步清理 czsc/envs.pyi:删除已废弃的 valid_true / use_python / get_welcome 三个公共符号声明(test/test_envs.py::TestRetiredHelpers 已经在断言它们在 .py 中不存在,但 .pyi 之前未跟进,basedpyright 仍把它们误识为存在)。 测试:test/test_envs.py 16 项全部通过;全量 sweep 124 项 GREEN。 --- czsc/envs.py | 99 +++++++---------------------------------- czsc/envs.pyi | 12 ++--- docs/MIGRATION_NOTES.md | 3 +- 3 files changed, 21 insertions(+), 93 deletions(-) diff --git a/czsc/envs.py b/czsc/envs.py index 094b27a8a..9541e1296 100644 --- a/czsc/envs.py +++ b/czsc/envs.py @@ -1,59 +1,30 @@ -""" -czsc.envs —— 极简环境变量适配层(迁移到 Rust 后保留版本) - -背景与定位: - 迁移到 Rust 后端之后,原本的"Python 回退开关"已经下线(Phase H 之后 - Python 端不再保留任何缠论核心算法的回退实现,所有调用都走 Rust)。 - 因此运行时参数被裁剪到仅剩三项,全部用于配置 czsc-core 的分析行为或 - 日志详尽程度。 - -环境变量命名约定: - 1. 推荐使用全大写形式(如 ``CZSC_MIN_BI_LEN``) - 2. 出于历史兼容性,也接受全小写形式(如 ``czsc_min_bi_len``) - 3. 当大小写两种形式都设置时,**大写形式优先** - 4. 函数参数 v 显式传入时,优先级最高(覆盖环境变量) - -可读取的环境变量: - - CZSC_VERBOSE —— 是否打印详细日志(True/False) - - CZSC_MIN_BI_LEN —— 笔的最小长度(含包含处理后 K 线根数) - - CZSC_MAX_BI_NUM —— 单个 CZSC 实例保留的最大笔数 +"""czsc.envs —— 极简环境变量适配层(spec §3.4)。 + +迁移到 Rust 后端后仅保留三项运行时参数: + +- ``CZSC_VERBOSE`` —— 是否打印详细日志(True/False) +- ``CZSC_MIN_BI_LEN`` —— 笔的最小长度(去包含后的 K 线根数;默认 6) +- ``CZSC_MAX_BI_NUM`` —— 单个 CZSC 实例保留的最大笔数(默认 50) + +约定:环境变量名同时接受全大写与全小写写法(大写优先),函数参数显式传值时 +优先级最高(覆盖环境变量)。 """ from __future__ import annotations import os -# 被视为"真值"的字符串集合(大小写不敏感,比较前会先 lower 化) -# 列举常见写法,避免用户因大小写或缩写导致开关不生效 +# 被视为"真值"的字符串(小写化后比对,覆盖常见写法) _VALID_TRUE = {"1", "true", "y", "yes"} def _env(name: str, default: str | None = None) -> str | None: - """ - 带大小写兜底的环境变量读取 - - 优先级: - 1. 全大写形式(如 ``NAME``) - 2. 全小写形式(如 ``name``) - 3. 调用方提供的 default - - 保留小写形式纯属向后兼容;新代码请只使用大写命名。 - """ + """读取环境变量,依次按 UPPER / lower 大小写降级。""" return os.environ.get(name.upper(), os.environ.get(name.lower(), default)) def _to_bool(v) -> bool: - """ - 宽松版"任意值 -> bool" 转换 - - 规则: - - bool 直接原样返回(避免 ``True/False`` 被字符串化后重判一次) - - None 视为 False - - 其他对象先 ``str()`` 再 strip + lower,比对 ``_VALID_TRUE`` 集合 - - 用于解析"用户填写的环境变量值是否启用某开关"的场景, - 比 ``bool(s)``(任意非空字符串都为 True)更符合直觉。 - """ + """把任意值宽松地转成 bool(None → False;字符串按 ``_VALID_TRUE`` 集合判定)。""" if isinstance(v, bool): return v if v is None: @@ -62,55 +33,17 @@ def _to_bool(v) -> bool: def get_verbose(verbose=None) -> bool: - """ - 判断是否启用详细日志输出 - - 参数: - verbose: 显式传入的开关值;若为 None 则读取环境变量 ``CZSC_VERBOSE`` - - 返回: - bool;任何被 :func:`_to_bool` 视为真值的输入都会返回 True - """ + """返回是否启用详细日志(``CZSC_VERBOSE`` 环境变量;显式参数优先)。""" return _to_bool(verbose if verbose is not None else _env("czsc_verbose")) def get_min_bi_len(v: int | None = None) -> int: - """ - 获取笔的最小长度(去包含后的 K 线根数) - - 取值含义: - 6 —— 新规范,要求笔至少跨越 6 根去包含后的 K 线 - 7 —— 旧规范,部分历史策略仍依赖此设置以维持原始走势识别口径 - - 参数: - v: 显式传入的最小笔长度;若为 None 则读取 ``CZSC_MIN_BI_LEN``, - 都未提供时使用默认值 6 - - 返回: - int 形式的最小笔长度 - - 备注: - ``int(float(raw))`` 是为兼容 ``"6"`` / ``"6.0"`` / ``6.5`` 等多种 - 字符串/数值书写形式,统一在小数点截断后再转 int。 - """ + """返回笔最小长度(``CZSC_MIN_BI_LEN``;默认 6)。``int(float(...))`` 兼容 "6"/"6.0"/6.5 等输入。""" raw = v if v is not None else _env("czsc_min_bi_len", 6) return int(float(raw)) def get_max_bi_num(v: int | None = None) -> int: - """ - 获取单个 CZSC 实例保留的最大笔数 - - 参数: - v: 显式传入的上限值;若为 None 则读取 ``CZSC_MAX_BI_NUM``, - 都未提供时默认 50 - - 返回: - int 形式的最大笔数 - - 备注: - - 用途:回放/实时计算时只保留最近 N 笔,避免内存与计算量随时间无界增长 - - 取值过小会丢掉中长周期信号;过大会增加每根 K 线的更新成本 - """ + """返回单个 CZSC 实例的最大笔数(``CZSC_MAX_BI_NUM``;默认 50)。""" raw = v if v is not None else _env("czsc_max_bi_num", 50) return int(float(raw)) diff --git a/czsc/envs.pyi b/czsc/envs.pyi index 801fb0ea7..ab26dc538 100644 --- a/czsc/envs.pyi +++ b/czsc/envs.pyi @@ -1,9 +1,3 @@ -from _typeshed import Incomplete - -valid_true: Incomplete - -def use_python(): ... -def get_verbose(verbose=None): ... -def get_welcome(): ... -def get_min_bi_len(v: int = None) -> int: ... -def get_max_bi_num(v: int = None) -> int: ... +def get_verbose(verbose: bool | None = ...) -> bool: ... +def get_min_bi_len(v: int | None = ...) -> int: ... +def get_max_bi_num(v: int | None = ...) -> int: ... diff --git a/docs/MIGRATION_NOTES.md b/docs/MIGRATION_NOTES.md index b9a5c8905..9987f38a9 100644 --- a/docs/MIGRATION_NOTES.md +++ b/docs/MIGRATION_NOTES.md @@ -776,6 +776,7 @@ the signal-output level on data ranging from 500 bars to 40k bars. | `czsc/traders/sig_parse.py` 退役 `_lazy_rs_czsc` | [czsc/traders/sig_parse.py](../czsc/traders/sig_parse.py) | 验证 `czsc._native.{derive_signals_config, derive_signals_freqs, list_all_signals}` 三函数已在 Phase F 全量上线(246 个信号模板可拉),将模块顶部的 `_lazy_rs_czsc` 工厂 + 三个 wrapper(`derive_signals_config` / `derive_signals_freqs` / `list_all_signals`)一次性删掉,改为顶层 `from czsc._native import ...`;同步移除 `if list_all_signals is not None` / `if derive_signals_config is not None` 等永真分支以及对应注释中的 `rs_czsc` 提法。`SignalsParser` 注册表初始化大小 = 246,与原 lazy 路径一致;spec §3.3 中"待评估 Rust 是否已等价实现"的临时性脚注随之失效。`czsc/traders/sig_parse.pyi` 同步更新(顶层补 `derive_signals_config / freqs / list_all_signals` 与 `sig_k3_map` 属性声明) | | 移除 `czsc/__init__.py` lazy loading + 注释/文档密度收缩 | [czsc/__init__.py](../czsc/__init__.py) | 按 spec §3.1 删除 `_LAZY_MODULES` / `_LAZY_ATTRS` / `__getattr__` 三件套;`svc / fsa / aphorism / mock` 改为顶层 `from . import ...`,7 个 lazy 属性(`capture_warnings` / `execute_with_warning_capture` / `adjust_holding_weights` / `log_strategy_info` / `plot_czsc_chart` / `KlineChart` / `check_kline_quality`)改为 `from czsc.utils.* import ...` 直接导入;删除 `if TYPE_CHECKING` 守卫;`welcome()` 函数体内的 `from czsc import aphorism` 提到顶层。同时压缩区段注释(17 处冗长的"逐符号说明"注释块全删)、`__all__` 字面表改为按主题分组的紧凑横排(仍保留全部 129 个公共名称、按主题用单行注释分隔)、`welcome()` docstring 折成单行、模块 docstring 22 行 → 11 行。`czsc/__init__.py` LoC 从 507 → 235(-54%)。**循环 import 防坑**:`svc / fsa / aphorism / mock` 中含 `from czsc import top_drawdowns` 等回环 import,必须放到所有顶层符号绑定后再加载(即"第二批 `from . import aphorism, fsa, mock, svc`"),第一次重排误把它们提到顶部触发 `cannot import name 'top_drawdowns' from partially initialized module 'czsc'`,调整顺序后通过;文件中以"第一批 / 第二批"分组注释固化此约束。**测试更新**:`test/test_import_performance.py::test_heavy_dependencies_not_loaded_on_import` 与 `test_svc_lazy_loaded` 是基于"streamlit 不应在 import czsc 时被加载"的旧设计断言,与新方向冲突,已删除;保留 `test_czsc_import_time`(< 10s 兜底)与 `test_czsc_svc_accessible`(顶层属性可用)。冷启动 importtime cumtime ≈ 320ms(spec §6 P3 目标 ≤ 300ms,超 ~7%;spec §3.1 注释中预期 < 50ms 仅指 Rust 扩展加载,不含整包 import) | | 实现 `czsc.utils.trade.stoploss_by_direction` 并切换调用方 | [czsc/utils/trade.py](../czsc/utils/trade.py)、[czsc/svc/backtest.py:261](../czsc/svc/backtest.py)、[test/test_stoploss_by_direction.py](../test/test_stoploss_by_direction.py) | 调研发现 `stoploss_by_direction` 既不在当前安装的 `rs_czsc`,也不在 `wbt`,更不在 `/Users/jun/Documents/vscodePro/rs_czsc` git 历史中——`from rs_czsc import stoploss_by_direction` 是死调用,运行 Streamlit dashboard 时会 `ImportError`。按 spec C1(`grep -r rs_czsc czsc/` 应无结果)的目标,按 superpowers TDD 范式新增 6 个 RED 测试(多/空头止损、order_id 切分、列契约、入参不可变性等),用纯 Python 在 `czsc/utils/trade.py` 写最小实现(按方向连续段切 order_id、向量化 hold_returns / min_hold_returns / returns / is_stop,浮点容差 1e-9 处理 `92/100 - 1 = -0.07999…` 这类边界),把 `czsc/svc/backtest.py:261` 的导入切到 `czsc.utils.trade.stoploss_by_direction`。**`grep -r 'from rs_czsc\\|import rs_czsc' czsc/ --include='*.py'` 现在零结果**——spec C1 全量达成,czsc 内部彻底无 `rs_czsc` 依赖。该函数标记为 czsc-only 改动,归入 §2.2 "新增能力"小节 | +| `czsc.envs` 精简(Python 侧 docstring 收缩) | [czsc/envs.py](../czsc/envs.py)、[czsc/envs.pyi](../czsc/envs.pyi) | spec §3.4 目标"~20 行"是含 Rust `set_envs(...)` 入口后的最终形态;本轮先做 Python 侧最大化压缩:117 → 49 行(-58%)。3 个 getter(`get_verbose` / `get_min_bi_len` / `get_max_bi_num`) + 2 个内部 helper(`_env` / `_to_bool`)逻辑完全保留;裁剪掉模块级 ~20 行说明性 docstring 与每个函数 8-15 行的 verbose docstring,改为单行说明。`envs.pyi` 同步更新:删除 `valid_true: Incomplete` / `def use_python(): ...` / `def get_welcome(): ...` 三个旧公共符号(`test/test_envs.py::TestRetiredHelpers` 已经断言它们在 .py 中不存在,但 .pyi 之前未跟进,会让 basedpyright 把它们误识为存在)。`test_envs.py` 全 16 项通过 | ### 10.2 故意保留 / 暂缓的项 @@ -785,7 +786,7 @@ the signal-output level on data ranging from 500 bars to 40k bars. | `czsc/traders/optimize.py` | spec §3.3 / §9 列入"完全删除" | 保留 | 现已是 Rust 端 `run_optimize_batch` 的 Python 薄外观层(配置归一化 + 物化数据 + 任务哈希 + 结果转发),与 spec 旧版"完全删除"假设不符;行为正确,无回归。以"过渡薄层"身份保留,不在 P0 范围内删除 | | `czsc/utils/ta.py` | spec §3.2 删除(由 Rust `czsc.ta.*` 替代) | 保留 75 行 | 仅保留 czsc 仪表盘场景使用的 MACD 特殊约定("柱状图额外乘以 2"),Rust 端 `czsc-ta` 暂未迁移该约定。**不通过 `czsc.ta` 重新导出**(`czsc.ta` 已指向 Rust 子模块),调用方需显式 `from czsc.utils.ta import MACD`。后续把柱状图 ×2 约定纳入 `czsc-ta::pure` 后再删 | | `czsc/_native.pyi` 自动生成 | spec §2.4 / Q4 | 未生成 | 需在 `crates/czsc-python/build.rs` 接 `pyo3-stub-gen`;属 P1 范畴,本次会话未做。当前类型检查靠 `czsc/__init__.py` 内联注解 + `py.typed` | -| `czsc.envs` 精简至 ~20 行 + `set_envs` Rust 入口 | spec §3.4 | 保留 116 行 | 需 Rust 端先暴露 `set_envs(min_bi_len=..., max_bi_num=..., verbose=...)` 入口;归 P1 | +| `czsc.envs` 精简(Python 侧) | spec §3.4 | **2026-05-07 已完成**:117 → 49 行(-58%) | 详见上表"`czsc.envs` 精简(Python 侧 docstring 收缩)"。Rust 端 `set_envs(min_bi_len=..., max_bi_num=..., verbose=...)` 入口仍归 P1(需 `czsc-utils` 暴露后再做),现状 `CZSC` 与 `BarGenerator` 通过构造器参数接收 envs,不依赖 Rust 全局 | ### 10.3 验证 From 7ef3ff1ab7b3e83416e590f7652ca431943d8210 Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Thu, 7 May 2026 16:54:03 +0800 Subject: [PATCH 08/23] docs(claude): sync project memory with Phase H/J/Q current state MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CLAUDE.md 仍记着已退役的 czsc/core.py / czsc/py/ / CZSC_USE_PYTHON 与已删除的 czsc/utils/{st_components,bar_generator,...}.py 路径。本次按当前仓库现状刷新: - 核心组件 (§):1) czsc/core.py -> czsc._native(PyO3 扩展模块);2) czsc/py/ -> crates/(9 个 Rust crate);明确"不存在 Python 回退、CZSC_USE_PYTHON 已退役" - 数据格式转换示例:from czsc.core import CZSC ... -> from czsc import CZSC ... - 关键环境变量:CZSC_USE_PYTHON 标为废弃;CZSC_VERBOSE / MIN_BI_LEN / MAX_BI_NUM 补充大小写约定与构造器参数优先级 - 缓存管理:czsc.utils.cache.home_path -> czsc.home_path(顶层别名)/ czsc.utils.data.cache.home_path(实际定义处) - Streamlit 集成:把 czsc/utils/st_components.py 改为 czsc/svc/, 并标注 lazy loading 已按 spec §3.1 移除 - Rust/Python 混合架构 (§):删掉"版本控制 / 回退机制"两条,改为说明 Rust 是唯一实现 + maturin 构建 + czsc._native 扩展模块名 + pyo3-stub-gen 自动 stub(计划中)+ 一次性 fork rs-czsc 后不再同步 - czsc/utils/ 列表:refresh 到 Phase J 后的目录结构(data/ / analysis/ / crypto/ / plotting/ 子目录 + trade.py 含 stoploss_by_direction) --- CLAUDE.md | 71 ++++++++++++++++++++++++++++++++----------------------- 1 file changed, 41 insertions(+), 30 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index f73a69fb1..7d698ec68 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -45,15 +45,18 @@ uv run flake8 czsc/ test/ ### 核心组件 -1. **`czsc/core.py`** - 混合架构核心模块,智能选择 Rust/Python 实现: - - Rust 版本优先(rs-czsc),性能优化 - - Python 版本作为回退方案 - - 导入核心类:`CZSC`、`RawBar`、`NewBar`、`Signal`、`Event`、`Position` 等 - -2. **`czsc/py/`** - Python 实现的核心算法: - - `analyze.py`: 缠论分析核心类,实现分型、笔的自动识别 - - `objects.py`: 核心数据结构定义 - - `bar_generator.py`: K线数据生成和重采样 +1. **`czsc._native`** - PyO3 编译产生的 Rust 扩展模块(缠论核心): + - 由 `crates/czsc-python` 通过 `maturin` 打包,扩展模块名 `czsc._native` + - 暴露 `CZSC / FX / BI / ZS / RawBar / NewBar / Freq / Mark / Direction / Operate / Signal / Event / Position / BarGenerator` 等核心类型 + - 暴露 `check_bi / check_fx / check_fxs / remove_include / freq_end_time / is_trading_time` 等工具函数 + - 暴露 30+ 信号函数(按 `czsc._native.signals.{bar,cxt,tas,vol,pressure,obv,cvolp}` 分组) + - 暴露 `czsc._native.ta.*`(Rust TA 算子,对应 `czsc.ta.*`) + - **不存在 Python 回退**:`czsc/py/` 与 `czsc/core.py` 已在 Phase H 删除;`CZSC_USE_PYTHON` 环境变量已退役(spec §3.4) + +2. **`crates/`** - Rust workspace(9 个 crate,详见 `docs/MIGRATION_NOTES.md` §1): + - `czsc-core` / `czsc-signals` / `czsc-trader` / `czsc-utils` / `czsc-ta` + - `czsc-signal-macros` / `error-macros` / `error-support`(proc-macro / 错误支持) + - `czsc-python`(PyO3 binding 总入口,唯一启用 `pyo3/extension-module` 的 crate) 3. **`czsc/traders/`** - 交易执行框架: - `base.py`: CzscSignals 和 CzscTrader 核心类 @@ -75,12 +78,15 @@ uv run flake8 czsc/ test/ - `feature.py`: 特征选择器和分析器 - `event.py`: 事件匹配和检测 -6. **`czsc/utils/`** - 工具模块: - - `bar_generator.py`: K线数据生成和重采样 - - `cache.py`: 磁盘缓存工具 - - `st_components.py`: Streamlit仪表板组件 - - `ta.py`: 技术分析指标 - - `data_client.py`: 统一数据客户端接口 +6. **`czsc/utils/`** - 工具模块(Phase J 精简后): + - `data/cache.py` / `io.py` / `log.py` / `kline_quality.py`:缓存、IO、日志、K 线质量校验 + - `analysis/`(`stats.py` 业绩统计 / `corr.py` 相关性分析) + - `crypto/fernet.py`:URL token 加解密 + - `data/client.py`:统一数据客户端接口 + - `ta.py`:仅保留 czsc 仪表盘场景使用的 MACD 特殊约定(其余 TA 算子由 Rust `czsc.ta.*` 提供) + - `trade.py`:交易工具(含本轮 TDD 新增的 `stoploss_by_direction`) + - `plotting/{kline,backtest,weight,common}.py`:Plotly 图表绘制 + - 已删除:`bar_generator.py` / `bi_info.py`(Rust 已实现)、`st_components.py`(迁至 `czsc/svc/`)、`echarts_*` / `pdf_report` / `html_report_builder` / `word_writer` / `signal_analyzer`(spec §9 完全删除) - `plot_backtest.py`: **回测可视化工具(已优化)** - 提供 Plotly 交互式图表绘制函数 - 支持累计收益曲线、回撤分析、收益分布、月度热力图等 @@ -146,8 +152,8 @@ CZSC 支持使用 `CzscTrader` 类进行多级别联立分析,可同时分析 ### 数据格式转换 ```python -# 从mock数据生成CZSC对象的正确模式 -from czsc.core import CZSC, format_standard_kline, Freq +# 从mock数据生成CZSC对象的正确模式(czsc.core 已删除,全部走顶层 czsc 命名空间) +from czsc import CZSC, Freq, format_standard_kline from czsc.mock import generate_symbol_kines # 生成K线数据 @@ -222,30 +228,35 @@ from czsc.utils.plot_backtest import ( ## 关键环境变量和设置 -- `CZSC_USE_PYTHON`: 强制使用 Python 版本实现(默认优先使用 Rust 版本) -- `czsc_min_bi_len`: 最小笔长度(来自 `czsc.envs`) -- `czsc_max_bi_num`: 最大笔数量(来自 `czsc.envs`) +- `CZSC_VERBOSE` / `czsc_verbose`:是否打印详细日志(来自 `czsc.envs`) +- `CZSC_MIN_BI_LEN` / `czsc_min_bi_len`:最小笔长度,默认 6(来自 `czsc.envs`) +- `CZSC_MAX_BI_NUM` / `czsc_max_bi_num`:最大笔数量,默认 50(来自 `czsc.envs`) +- 大小写两种写法都接受,大写优先;构造器显式参数优先级最高 +- `CZSC_USE_PYTHON` 已**废弃**(spec §3.4,Python 回退路径已删,所有调用统一走 Rust) - 缓存目录自动管理,具备大小监控功能 ## 缓存管理 项目大量使用磁盘缓存: -- 缓存位置:`czsc.utils.cache.home_path` +- 缓存位置:`czsc.home_path`(顶层)或 `czsc.utils.data.cache.home_path`(实际定义处) - 清除缓存:`czsc.empty_cache_path()` -- 监控大小:`czsc.get_dir_size(home_path)` -- 当缓存超过1GB时会显示清理提示 +- 监控大小:`czsc.get_dir_size(czsc.home_path)` +- 当缓存超过1GB时 `czsc.welcome()` 会显示清理提示 + +## Streamlit 集成 -## Streamlit集成 +项目在 `czsc/svc/`(Streamlit Visualize Components)中提供完整的可视化组件库,覆盖回测、相关性、因子、统计、策略、权重等场景。`czsc.svc` 在 `czsc/__init__.py` 中按 spec §3.1 改为静态 import,**不再走 lazy loading**——访问 `czsc.svc` 即可,无需任何延迟初始化。 -项目在 `czsc/utils/st_components.py` 中包含丰富的 Streamlit 分析组件,提供回测结果、相关性分析、因子分析等可视化工具。 +> 注:原 `czsc/utils/st_components.py` 已在 Phase J 删除,所有 Streamlit 组件统一收敛到 `czsc/svc/`。 ## Rust/Python 混合架构 -项目核心功能使用 Rust 重构以提升性能: -- **版本控制**: 通过环境变量 `CZSC_USE_PYTHON` 控制,默认优先使用 Rust 版本 -- **回退机制**: Rust 版本不可用时自动回退到 Python 版本(见 `czsc/core.py`) -- **核心模块**: 已迁移的模块包括 `CZSC` 分析器、K线生成器、枚举类型等 -- **版本检测**: 运行时自动检测 `rs-czsc` 库可用性和版本信息 +项目核心算法用 Rust 实现,通过 PyO3 暴露给 Python: +- **构建方式**:`maturin + Rust workspace`,扩展模块名 `czsc._native` +- **唯一架构**:Rust 是缠论核心算法的唯一实现;Python 端不再保留任何回退(spec §3.1 / §3.4) +- **API 暴露**:所有面向用户的 API 都通过 `czsc.xxx` 顶层命名空间暴露,禁止用户感知 `czsc._native` +- **类型 stub**:`czsc/py.typed` 启用 inline 类型注解;`czsc/_native.pyi` 由 `pyo3-stub-gen` 自动生成(计划中,详见 `docs/MIGRATION_NOTES.md` §10.2) +- **rs-czsc 关系**:czsc 一次性 fork rs-czsc 的 Rust 实现进本仓库,**不再做季度同步**;基线 commit 与 czsc-only 改动记录在 `docs/MIGRATION_NOTES.md` ## 数据连接器支持 From 12cf082489b8565cd5fdfae25a785acb4af94ea9 Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Thu, 7 May 2026 17:08:52 +0800 Subject: [PATCH 09/23] =?UTF-8?q?feat(stubs):=20auto-generate=20czsc/=5Fna?= =?UTF-8?q?tive.pyi=20via=20pyo3-stub-gen=20(spec=20=C2=A72.4=20/=20Q4)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 各业务 crate 的 #[gen_stub_pyclass] / #[gen_stub_pyfunction] / #[gen_stub_pymethods] 装饰器早就布到位了,但缺收集器 + 生成 binary。本次补齐: - crates/czsc-python/Cargo.toml - 把 pyo3 的 extension-module feature 拆成本地可选 default feature (default = ["extension-module"]),方便 binary 用 --no-default-features 跳过它,让 pyo3 自动链接 libpython(否则 macOS 链接器找不到 _PyExc_* 等符号) - 新增 [[bin]] stub_gen,path = src/bin/stub_gen.rs - crates/czsc-python/src/lib.rs - 写自定义 stub_info() 函数(替代 define_stub_info_gatherer! 宏): 把 from_pyproject_toml 路径显式指向 workspace 根 (默认宏假设 pyproject.toml 与 Cargo.toml 同目录,但 czsc 仓库的 pyproject 在 workspace 根、Cargo.toml 在 crates/czsc-python/) - crates/czsc-python/src/bin/stub_gen.rs - 最小入口:调用 stub_info()?.generate()? 触发命令: PYO3_PYTHON=$(uv run python -c 'import sys; print(sys.executable)') \ cargo run --bin stub_gen -p czsc-python --no-default-features 产物:czsc/_native.pyi(1 235 行),覆盖 BI / CZSC / FX / ZS / BarGenerator / Position / Signal / Event / RawBar / NewBar / FakeBI / Direction / Mark / Operate / Freq / ParsedSignalDoc 等核心类,以及 30+ TA 算子、信号函数与 顶层 chip_distribution_triangle / parse_signal_doc 等。 pyproject.toml::tool.maturin.include = ["czsc/**/*.pyi", ...] 已经覆盖此文件, wheel 打包时自动携带,不需要追加配置。 残留:basedpyright 在 _native.pyi 上报 8 个 upstream pyo3-stub-gen 已知问题 (__eq__ 参数类型不兼容父类、__dict__ 与 dict[str, Any] 不兼容),属于工具 false-positive,不影响 IDE 提示与运行时行为。 测试:124 项全量 sweep 仍 GREEN;ast.parse 验证 _native.pyi 语法合法。 --- crates/czsc-python/Cargo.toml | 13 +- crates/czsc-python/src/bin/stub_gen.rs | 21 + crates/czsc-python/src/lib.rs | 17 + czsc/_native.pyi | 1235 ++++++++++++++++++++++++ docs/MIGRATION_NOTES.md | 3 +- 5 files changed, 1287 insertions(+), 2 deletions(-) create mode 100644 crates/czsc-python/src/bin/stub_gen.rs create mode 100644 czsc/_native.pyi diff --git a/crates/czsc-python/Cargo.toml b/crates/czsc-python/Cargo.toml index ae8492201..8a734b7c3 100644 --- a/crates/czsc-python/Cargo.toml +++ b/crates/czsc-python/Cargo.toml @@ -11,6 +11,17 @@ name = "czsc_python" path = "src/lib.rs" crate-type = ["cdylib", "rlib"] +# 默认启用 extension-module(cdylib 路径); +# 编译 stub_gen 这种独立 binary 时需要 --no-default-features, +# 让 pyo3 自动链接 libpython(详见 src/bin/stub_gen.rs)。 +[features] +default = ["extension-module"] +extension-module = ["pyo3/extension-module"] + +[[bin]] +name = "stub_gen" +path = "src/bin/stub_gen.rs" + [dependencies] czsc-core = { path = "../czsc-core", features = ["python"] } czsc-ta = { path = "../czsc-ta", features = ["rust-numpy"] } @@ -29,7 +40,7 @@ polars = { workspace = true } # abi3-py310. Business crates pull in the bare workspace pyo3 so # `cargo test --workspace` can link against a real libpython resolved # via PYO3_PYTHON. -pyo3 = { workspace = true, features = ["extension-module", "abi3-py310", "chrono"] } +pyo3 = { workspace = true, features = ["abi3-py310", "chrono"] } pyo3-stub-gen = "0.12" rust_xlsxwriter = "0.79" serde = { workspace = true } diff --git a/crates/czsc-python/src/bin/stub_gen.rs b/crates/czsc-python/src/bin/stub_gen.rs new file mode 100644 index 000000000..c98027b16 --- /dev/null +++ b/crates/czsc-python/src/bin/stub_gen.rs @@ -0,0 +1,21 @@ +//! `czsc-python` 的 type stub 生成器入口(spec §2.4 / Q4)。 +//! +//! 通过 `pyo3-stub-gen` 收集 Rust 端所有 `#[gen_stub_pyclass]` / +//! `#[gen_stub_pyfunction]` / `#[gen_stub_pymethods]` 装饰器注册的信息, +//! 写出 `czsc/_native.pyi` 供 basedpyright / IDE / type checker 消费。 +//! +//! 触发方式: +//! PYO3_PYTHON=$(uv run python -c 'import sys; print(sys.executable)') \ +//! cargo run --bin stub_gen -p czsc-python +//! +//! 输出路径由 `pyproject.toml` 里 `[tool.maturin].module-name = "czsc._native"` +//! 推导:写到 `czsc/_native.pyi`。 + +use pyo3_stub_gen::Result; + +fn main() -> Result<()> { + let stub = czsc_python::stub_info()?; + stub.generate()?; + println!("czsc/_native.pyi 生成完成"); + Ok(()) +} diff --git a/crates/czsc-python/src/lib.rs b/crates/czsc-python/src/lib.rs index ed61eed78..6a73c52de 100644 --- a/crates/czsc-python/src/lib.rs +++ b/crates/czsc-python/src/lib.rs @@ -63,3 +63,20 @@ fn _native(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { Ok(()) } + +// === pyo3-stub-gen 收集器 === +// 收集所有 #[gen_stub_pyclass] / #[gen_stub_pyfunction] / #[gen_stub_pymethods] +// 装饰器注册的 Rust 端类型 / 函数信息。配套 binary `cargo run --bin stub_gen` +// 调用 stub_info() 后写出 czsc/_native.pyi(spec §2.4 / Q4)。 +// +// 注:由于 czsc 的 pyproject.toml 在 workspace 根(不是 crate 根), +// 这里**手写**一个等价于 `define_stub_info_gatherer!` 的入口,把 pyproject 路径 +// 显式指向上两级目录。 +pub fn stub_info() -> pyo3_stub_gen::Result { + let manifest_dir: &std::path::Path = env!("CARGO_MANIFEST_DIR").as_ref(); + let workspace_root = manifest_dir + .parent() // crates/ + .and_then(|p| p.parent()) // workspace 根 + .ok_or_else(|| anyhow::anyhow!("无法定位 workspace 根目录"))?; + pyo3_stub_gen::StubInfo::from_pyproject_toml(workspace_root.join("pyproject.toml")) +} diff --git a/czsc/_native.pyi b/czsc/_native.pyi new file mode 100644 index 000000000..ff800f954 --- /dev/null +++ b/czsc/_native.pyi @@ -0,0 +1,1235 @@ +# This file is automatically generated by pyo3_stub_gen +# ruff: noqa: E501, F401 + +import builtins +import numpy +import numpy.typing +import typing +from enum import Enum + +class BI: + r""" + 笔 + """ + @property + def symbol(self) -> builtins.str: ... + @property + def direction(self) -> Direction: ... + @property + def high(self) -> builtins.float: ... + @property + def low(self) -> builtins.float: ... + @property + def cache(self) -> dict: ... + @property + def __dict__(self) -> typing.Any: + r""" + 直接支持 __dict__ 属性,让 pandas DataFrame() 能正确识别对象 + """ + @property + def sdt(self) -> typing.Any: ... + @property + def edt(self) -> typing.Any: ... + @property + def fx_a(self) -> FX: ... + @property + def fx_b(self) -> FX: ... + @property + def fxs(self) -> builtins.list[FX]: ... + @property + def bars(self) -> builtins.list[NewBar]: + r""" + 获取构成笔的NewBar列表 + """ + @property + def power(self) -> builtins.float: + r""" + 价差力度 + """ + @property + def power_price(self) -> builtins.float: + r""" + 价差力度(别名) + """ + @property + def power_volume(self) -> builtins.float: + r""" + 成交量力度 + """ + @property + def power_snr(self) -> builtins.float: + r""" + SNR 度量力度 + """ + @property + def change(self) -> builtins.float: + r""" + 笔的涨跌幅 + """ + @property + def SNR(self) -> builtins.float: + r""" + 笔内部的信噪比 + """ + @property + def slope(self) -> builtins.float: + r""" + 笔内部高低点之间的斜率 + """ + @property + def acceleration(self) -> builtins.float: + r""" + 笔内部价格的加速度 + """ + @property + def length(self) -> builtins.int: + r""" + 笔的无包含关系K线数量 + """ + @property + def rsq(self) -> builtins.float: + r""" + 笔的原始K线close单变量线性回归拟合优度 + """ + @property + def hypotenuse(self) -> builtins.float: + r""" + 笔的斜边长度 + """ + @property + def angle(self) -> builtins.float: + r""" + 笔的斜边与竖直方向的夹角,角度越大,力度越大 + """ + @property + def raw_bars(self) -> builtins.list[RawBar]: + r""" + 构成笔的原始K线序列,不包含首尾分型的首根K线 + """ + @property + def fake_bis(self) -> builtins.list[FakeBI]: + r""" + 笔的内部分型连接得到近似次级别笔列表 + """ + @property + def cache(self) -> typing.Any: + r""" + 缓存字典(与 czsc 库兼容) + """ + def __new__(cls, symbol:builtins.str, direction:Direction, fx_a:FX, fx_b:FX, fxs:typing.Sequence[FX], bars:typing.Sequence[NewBar]) -> BI: ... + def get_cache_with_default(self, _key:builtins.str, default_value:builtins.float) -> builtins.float: + r""" + 获取缓存值,如果不存在则返回默认值(与 czsc 库兼容) + """ + def get_price_linear(self, n:builtins.int) -> builtins.float: + r""" + 获取线性价格(与 czsc 库兼容) + """ + def __repr__(self) -> builtins.str: ... + def __richcmp__(self, other:BI, op:int) -> builtins.bool: ... + +class BarGenerator: + @property + def symbol_py(self) -> typing.Optional[builtins.str]: + r""" + 获取所属品种 - Python 属性 + """ + @property + def base_freq(self) -> builtins.str: + r""" + 获取基准频率 + """ + @property + def end_dt(self) -> typing.Optional[typing.Any]: + r""" + 获取end_dt属性(Python兼容) + """ + @property + def bars(self) -> typing.Any: + r""" + 获取各周期K线数据 - 返回字典,键为频率字符串,值为K线列表 + """ + def __new__(cls, base_freq:typing.Any, freqs:typing.Any, max_count:builtins.int=2000, market:typing.Optional[typing.Any]=None) -> BarGenerator: ... + def init_freq_bars(self, freq:typing.Any, bars:typing.Sequence[RawBar]) -> None: + r""" + 初始化某个周期的K线序列 + + # 函数计算逻辑 + + 1. 检查输入的`freq`是否存在于`self.freq_bars`的键中。如果不存在,返回错误。 + 2. 检查`self.freq_bars[freq]`是否为空。如果不为空,返回错误,表示不允许重复初始化。 + 3. 如果以上检查都通过,将输入的`bars`存储到`self.freq_bars[freq]`中。 + 4. 从`bars`中获取最后一根K线的交易标的代码,更新`self.symbol`。 + + # Arguments + + * `freq` - 周期名称 (支持字符串或Freq枚举) + * `bars` - K线序列 + """ + def get_latest_date(self) -> typing.Optional[builtins.str]: + r""" + 获取最新K线日期 + """ + def update(self, bar:RawBar) -> None: + r""" + 从Python RawBar对象更新K线数据 - 支持直接自动转换 + """ + def __reduce__(self) -> typing.Any: + r""" + 支持 pickle 序列化 - 使用 __reduce__ 方法 + """ + def __setstate__(self, state:typing.Any) -> None: + r""" + 支持 pickle 反序列化 + """ + +class CZSC: + @property + def symbol(self) -> builtins.str: ... + @property + def freq(self) -> Freq: ... + @property + def max_bi_num(self) -> builtins.int: ... + @property + def bi_list(self) -> builtins.list[BI]: ... + @property + def bars_raw(self) -> builtins.list[RawBar]: + r""" + 获取原始K线序列 - 返回PyRawBar对象列表 + """ + @property + def bars_raw_df(self) -> typing.Any: + r""" + 获取原始K线序列的DataFrame格式,便于绘图和分析 + """ + @property + def bars_ubi(self) -> builtins.list[NewBar]: + r""" + 获取无包含关系K线序列 + """ + @property + def finished_bis(self) -> builtins.list[BI]: + r""" + 获取已完成的笔列表(与 bi_list 相同,为兼容 czsc 库) + """ + @property + def fx_list(self) -> builtins.list[FX]: + r""" + 获取分型列表(属性,与 czsc 库兼容) + """ + @property + def cache(self) -> typing.Any: + r""" + 缓存字典(与 czsc 库兼容) + """ + @property + def signals(self) -> typing.Any: + r""" + 信号字典(与 czsc 库兼容) + """ + @property + def ubi_fxs(self) -> builtins.list[FX]: + r""" + 无包含关系K线分型列表(与 czsc 库兼容) + """ + @property + def ubi(self) -> typing.Any: + r""" + 无包含关系K线(与 czsc 库兼容) + 返回未完成的笔信息,格式与 Python 版本保持一致 + """ + @property + def verbose(self) -> builtins.bool: + r""" + 是否显示详细信息(与 czsc 库兼容) + """ + @property + def last_bi_extend(self) -> builtins.bool: + r""" + 最后一笔延伸情况(与 czsc 库兼容) + 判断最后一笔是否在延伸中,True 表示延伸中 + """ + @property + def cache(self) -> dict: + r""" + 缓存字典(与 czsc 库兼容) + """ + def __new__(cls, bars_raw:typing.Sequence[RawBar], max_bi_num:builtins.int=50) -> CZSC: ... + @staticmethod + def from_dataframe(df_bytes:bytes, freq:Freq, max_bi_num:builtins.int=50) -> CZSC: + r""" + 直接从Arrow格式的DataFrame创建CZSC对象,避免中间转换 + 这是高性能的批量创建接口,适用于大量数据的初始化 + + :param df_bytes: Arrow IPC格式的DataFrame字节数据 + :param freq: K线频率 + :param max_bi_num: 最大笔数量限制 + :return: CZSC对象 + """ + def open_in_browser(self, _renderer:typing.Optional[builtins.str]=None) -> builtins.str: + r""" + 在浏览器中打开(与 czsc 库兼容) + """ + def to_echarts(self) -> builtins.str: + r""" + 转换为 ECharts 格式(与 czsc 库兼容) + """ + def to_plotly(self) -> builtins.str: + r""" + 转换为 Plotly 格式(与 czsc 库兼容) + """ + def update(self, bar:RawBar) -> None: + r""" + 更新K线数据 + """ + def __repr__(self) -> builtins.str: ... + def __reduce__(self) -> typing.Any: + r""" + Pickle support — `__reduce__` returns ``(CZSC, (fixed_point_bars, max_bi_num))``. + + `update_bar` drains older bars whose dt is below the current + first-BI's start (see `bars_raw.drain` block above), so a + freshly-constructed CZSC's `bars_raw` may still differ from the + fixed point reached after a single re-analysis. We run one extra + `CZSC::new` here to converge before serializing — guarantees that + `pickle.dumps(restored) == pickle.dumps(obj)` byte-for-byte even + when CzscSignals nests CZSC inside `kas[freq]` (Phase A's + `restored.__getstate__() == obj.__getstate__()` assertion relies + on this). + """ + +class CzscSignals: + r""" + CzscSignals 的 PyO3 包装 + """ + @property + def name(self) -> builtins.str: + r""" + 返回类名 + """ + @property + def symbol(self) -> builtins.str: + r""" + 返回标的代码 + """ + @property + def s(self) -> typing.Any: + r""" + 返回信号字典 s + """ + @property + def kas(self) -> builtins.dict[builtins.str, CZSC]: + r""" + 返回各周期 CZSC 分析引擎 + """ + @property + def freqs(self) -> builtins.list[builtins.str]: + r""" + 返回所有周期字符串列表 + """ + @property + def base_freq(self) -> builtins.str: + r""" + 返回基准周期字符串 + """ + @property + def end_dt(self) -> typing.Optional[typing.Any]: + r""" + 返回最新时间,作为 pandas Timestamp + """ + @property + def bid(self) -> typing.Optional[builtins.int]: + r""" + 返回当前 bar id + """ + @property + def latest_price(self) -> typing.Optional[builtins.float]: + r""" + 返回最新价格 + """ + @property + def signals_config(self) -> typing.Any: + r""" + 返回原始信号配置 + """ + def __new__(cls, bg:BarGenerator, signals_config:list) -> CzscSignals: ... + def update_signals(self, bar:RawBar) -> None: + r""" + 更新信号 + """ + def get_signals_by_conf(self) -> typing.Any: + r""" + 获取当前信号字典(同 s 属性) + """ + def __reduce__(self) -> typing.Any: + r""" + Pickle 支持:返回 ``(cls, (bg_clone, signals_config_list))``。 + 反序列化时 PyCzscSignals 由原 ``__new__`` 重新构造;缓存的信号 + 状态不持久化(与 design doc §2.4 的 multiprocessing 用例一致: + 子进程拿到的是构造参数 fresh trader)。 + """ + +class CzscTrader: + r""" + CzscTrader 的 PyO3 包装 + """ + @property + def name(self) -> builtins.str: + r""" + 返回类名 + """ + @property + def symbol(self) -> builtins.str: + r""" + 返回标的代码 + """ + @property + def s(self) -> typing.Any: + r""" + 返回信号字典 s + """ + @property + def kas(self) -> builtins.dict[builtins.str, CZSC]: + r""" + 返回各周期 CZSC 分析引擎 + """ + @property + def freqs(self) -> builtins.list[builtins.str]: + r""" + 返回所有周期字符串列表 + """ + @property + def base_freq(self) -> builtins.str: + r""" + 返回基准周期字符串 + """ + @property + def end_dt(self) -> typing.Optional[typing.Any]: + r""" + 返回最新时间,作为 pandas Timestamp + """ + @property + def bid(self) -> typing.Optional[builtins.int]: + r""" + 返回当前 bar id + """ + @property + def latest_price(self) -> typing.Optional[builtins.float]: + r""" + 返回最新价格 + """ + @property + def signals_config(self) -> typing.Any: + r""" + 返回原始信号配置 + """ + @property + def positions(self) -> builtins.list[Position]: + r""" + 返回仓位列表(PyPosition 包装) + """ + @property + def pos_changed(self) -> builtins.bool: + r""" + 返回是否有仓位发生变化 + """ + def __new__(cls, bg:BarGenerator, positions:list, signals_config:list, ensemble_method:builtins.str='mean') -> CzscTrader: ... + def update(self, bar:RawBar) -> None: + r""" + 更新信号和仓位 + """ + def on_bar(self, bar:RawBar) -> None: + r""" + 更新信号和仓位(同 update) + """ + def on_sig(self, sig:dict) -> None: + r""" + 基于信号字典更新仓位 + """ + def get_ensemble_pos(self, method:typing.Optional[builtins.str]=None) -> builtins.float: + r""" + 获取集成后的仓位值 + """ + def get_position(self, name:builtins.str) -> typing.Optional[Position]: + r""" + 根据名称获取仓位 + """ + def get_signals_by_conf(self) -> typing.Any: + r""" + 获取当前信号字典 + """ + def update_signals(self, bar:RawBar) -> None: + r""" + 仅更新信号(不更新仓位) + """ + def __reduce__(self) -> typing.Any: + r""" + Pickle 支持:返回构造参数 (bg, positions, signals_config, ensemble_method)。 + 反序列化时由 ``__new__`` 重新构造一个 fresh trader;缓存的运行 + 状态不持久化(与 design doc §2.4 multiprocessing 用例一致)。 + """ + +class Event: + r""" + Python可见的Event包装器 + """ + @property + def operate(self) -> Operate: ... + @property + def signals_all(self) -> builtins.list[Signal]: ... + @property + def signals_any(self) -> builtins.list[Signal]: ... + @property + def signals_not(self) -> builtins.list[Signal]: ... + @property + def name(self) -> builtins.str: ... + @property + def unique_signals(self) -> builtins.list[builtins.str]: + r""" + 获取所有唯一信号(字符串格式,兼容原Python API) + """ + @property + def sha256(self) -> builtins.str: + r""" + 获取SHA256哈希 + """ + def __new__(cls, operate:Operate, signals_all:typing.Sequence[Signal]=[], signals_any:typing.Sequence[Signal]=[], signals_not:typing.Sequence[Signal]=[], name:builtins.str='') -> Event: ... + @classmethod + def from_dict(cls, _cls:type, dict:dict) -> Event: ... + @classmethod + def from_json(cls, _cls:type, json_str:builtins.str) -> Event: ... + def compute_sha8(self) -> builtins.str: + r""" + 计算SHA8哈希值 + """ + def is_match(self, signals:typing.Any) -> builtins.bool: + r""" + 判断事件是否匹配信号集合,返回是否匹配 + 支持多种参数类型:Dict[str, str] 或 Dict[str, Signal] 或 Vec + """ + def to_json(self) -> builtins.str: + r""" + 转换为JSON字符串 + """ + def __repr__(self) -> builtins.str: ... + def __str__(self) -> builtins.str: ... + def dump(self) -> typing.Any: + r""" + 导出为字典 + """ + @classmethod + def load(cls, _cls:type, data:dict) -> Event: + r""" + 从字典加载 + """ + def get_signals_config(self) -> builtins.list[builtins.str]: + r""" + 获取信号配置 + """ + def __reduce__(self) -> typing.Any: + r""" + 支持 pickle 序列化 - 使用 __reduce__ 方法 + """ + +class FX: + r""" + 分型 + """ + @property + def symbol(self) -> builtins.str: ... + @property + def dt(self) -> typing.Any: ... + @property + def mark(self) -> Mark: ... + @property + def high(self) -> builtins.float: ... + @property + def low(self) -> builtins.float: ... + @property + def fx(self) -> builtins.float: ... + @property + def new_bars(self) -> builtins.list[NewBar]: + r""" + 获取构成分型的NewBar列表 + """ + @property + def raw_bars(self) -> builtins.list[RawBar]: + r""" + 获取原始K线列表(从NewBar的elements中提取) + """ + @property + def power_str(self) -> builtins.str: + r""" + 获取分型强度字符串 + """ + @property + def power_volume(self) -> builtins.float: + r""" + 获取成交量力度 + """ + @property + def has_zs(self) -> builtins.bool: + r""" + 判断是否有重叠中枢 + """ + @property + def elements(self) -> builtins.list[NewBar]: + r""" + 获取构成分型的NewBar列表(与new_bars相同,为兼容czsc库) + """ + @property + def cache(self) -> dict: + r""" + 缓存字典(与 czsc 库兼容) + """ + @property + def __dict__(self) -> typing.Any: + r""" + 直接支持 __dict__ 属性,让 pandas DataFrame() 能正确识别对象 + """ + def __new__(cls, symbol:builtins.str, dt:typing.Any, mark:Mark, high:builtins.float, low:builtins.float, fx:builtins.float, elements:typing.Sequence[NewBar]) -> FX: ... + def __repr__(self) -> builtins.str: ... + def __richcmp__(self, other:FX, op:int) -> builtins.bool: ... + +class FakeBI: + r""" + 虚拟笔 + 主要为笔的内部分析提供便利 + """ + @property + def symbol(self) -> builtins.str: ... + @property + def sdt(self) -> typing.Any: ... + @property + def edt(self) -> typing.Any: ... + @property + def direction(self) -> Direction: ... + @property + def high(self) -> builtins.float: ... + @property + def low(self) -> builtins.float: ... + @property + def power(self) -> builtins.float: ... + @property + def cache(self) -> dict: ... + def __repr__(self) -> builtins.str: ... + +class LiteBar: + r""" + Python可见的LiteBar包装器 + """ + @property + def id(self) -> builtins.int: ... + @property + def dt(self) -> builtins.float: ... + @property + def price(self) -> builtins.float: ... + def __new__(cls, id:builtins.int, dt:builtins.float, price:builtins.float) -> LiteBar: ... + def __repr__(self) -> builtins.str: ... + +class NewBar: + r""" + 去除包含关系后的K线元素 + """ + @property + def symbol(self) -> builtins.str: ... + @property + def dt(self) -> typing.Any: ... + @property + def freq(self) -> Freq: ... + @property + def id(self) -> builtins.int: ... + @property + def open(self) -> builtins.float: ... + @property + def close(self) -> builtins.float: ... + @property + def high(self) -> builtins.float: ... + @property + def low(self) -> builtins.float: ... + @property + def vol(self) -> builtins.float: ... + @property + def amount(self) -> builtins.float: ... + @property + def cache(self) -> dict: ... + @property + def elements(self) -> builtins.list[RawBar]: ... + @property + def raw_bars(self) -> builtins.list[RawBar]: + r""" + 获取构成NewBar的原始K线列表(与elements相同,为兼容czsc库) + """ + def __new__(cls, symbol:builtins.str, dt:typing.Any, freq:Freq, open:builtins.float, close:builtins.float, high:builtins.float, low:builtins.float, vol:builtins.float, amount:builtins.float, id:builtins.int=0, elements:typing.Optional[typing.Sequence[RawBar]]=None) -> NewBar: ... + def __repr__(self) -> builtins.str: ... + def __richcmp__(self, other:NewBar, op:int) -> builtins.bool: ... + +class Operate: + r""" + Python可见的Operate包装器 + """ + HL: Operate + HS: Operate + HO: Operate + LO: Operate + LE: Operate + SO: Operate + SE: Operate + @property + def value(self) -> builtins.str: + r""" + 兼容性属性:返回操作类型的中文字符串值 + """ + @classmethod + def hl(cls, _cls:type) -> Operate: ... + @classmethod + def hs(cls, _cls:type) -> Operate: ... + @classmethod + def ho(cls, _cls:type) -> Operate: ... + @classmethod + def lo(cls, _cls:type) -> Operate: ... + @classmethod + def le(cls, _cls:type) -> Operate: ... + @classmethod + def so(cls, _cls:type) -> Operate: ... + @classmethod + def se(cls, _cls:type) -> Operate: ... + @classmethod + def from_str_py(cls, _cls:type, s:builtins.str) -> Operate: ... + @classmethod + def from_str(cls, _cls:type, s:builtins.str) -> Operate: ... + def __str__(self) -> builtins.str: ... + def __repr__(self) -> builtins.str: ... + def __eq__(self, other:Operate) -> builtins.bool: ... + def __hash__(self) -> builtins.int: ... + def __reduce__(self) -> typing.Any: + r""" + 支持pickle序列化 + """ + +class ParsedSignalDoc: + r""" + Python可见的ParsedSignalDoc包装器 + """ + @property + def param_template(self) -> typing.Optional[builtins.str]: ... + @property + def signals(self) -> builtins.list[Signal]: ... + def __repr__(self) -> builtins.str: ... + +class Pos: + r""" + Python可见的Pos枚举包装器 + """ + @classmethod + def short(cls, _cls:type) -> Pos: ... + @classmethod + def flat(cls, _cls:type) -> Pos: ... + @classmethod + def long(cls, _cls:type) -> Pos: ... + def __str__(self) -> builtins.str: ... + def __repr__(self) -> builtins.str: ... + def __eq__(self, other:Pos) -> builtins.bool: ... + def __add__(self, other:Pos) -> builtins.float: + r""" + 加法运算,用于numpy.mean等数学操作 + """ + def __radd__(self, other:builtins.float) -> builtins.float: + r""" + 右加法运算 + """ + def __float__(self) -> builtins.float: + r""" + 转换为浮点数,用于数学运算 + """ + def __int__(self) -> builtins.int: + r""" + 整数转换 + """ + def __lt__(self, other:Pos) -> builtins.bool: + r""" + 比较运算符 - 小于 + """ + def __le__(self, other:Pos) -> builtins.bool: + r""" + 比较运算符 - 小于等于 + """ + def __gt__(self, other:Pos) -> builtins.bool: + r""" + 比较运算符 - 大于 + """ + def __ge__(self, other:Pos) -> builtins.bool: + r""" + 比较运算符 - 大于等于 + """ + +class Position: + r""" + Python可见的Position包装器 + """ + @property + def opens(self) -> builtins.list[Event]: ... + @property + def exits(self) -> builtins.list[Event]: ... + @property + def interval(self) -> builtins.int: ... + @property + def timeout(self) -> builtins.int: ... + @property + def stop_loss(self) -> builtins.float: ... + @property + def t0(self) -> builtins.bool: ... + @property + def name(self) -> builtins.str: ... + @property + def symbol(self) -> builtins.str: ... + @property + def pos(self) -> builtins.float: ... + @property + def pos_changed(self) -> builtins.bool: ... + @property + def end_dt(self) -> typing.Optional[builtins.float]: + r""" + 获取最新信号时间 + """ + @property + def operates(self) -> builtins.list[typing.Any]: + r""" + 获取操作记录列表 + """ + @property + def pairs(self) -> list: + r""" + 获取交易对数据(返回记录列表,兼容pandas.DataFrame构造) + """ + @property + def holds(self) -> list: + r""" + 获取持仓历史数据(返回记录列表,兼容历史版本) + """ + @property + def unique_signals(self) -> builtins.list[builtins.str]: ... + @property + def events(self) -> builtins.list[Event]: ... + def __new__(cls, symbol:builtins.str, opens:typing.Sequence[Event], exits:typing.Sequence[Event]=[], interval:builtins.int=0, timeout:builtins.int=1000, stop_loss:builtins.float=1000.0, t0:builtins.bool=False, name:typing.Optional[builtins.str]=None) -> Position: ... + @classmethod + def load_from_file(cls, _cls:type, path:builtins.str) -> Position: ... + @classmethod + def from_json(cls, _cls:type, json_str:builtins.str) -> Position: ... + def save(self, path:builtins.str) -> None: + r""" + 保存到文件 + """ + def to_json(self) -> builtins.str: + r""" + 转换为JSON字符串 + """ + def all_events(self) -> builtins.list[Event]: + r""" + 获取所有相关事件 + """ + def update(self, arg1:typing.Any, arg2:typing.Optional[typing.Any]=None) -> None: + r""" + 更新仓位状态(兼容单参数调用) + """ + def __reduce__(self) -> typing.Any: + r""" + 支持 pickle 序列化 - 使用 __reduce__ 方法 + """ + def dump(self, with_data:builtins.bool=True) -> typing.Any: + r""" + 导出Position数据为Python字典 + """ + @classmethod + def load(cls, _cls:type, data:typing.Any) -> Position: + r""" + 从字典数据加载Position + """ + def __repr__(self) -> builtins.str: ... + +class RawBar: + r""" + 原始K线元素 + """ + @property + def symbol(self) -> builtins.str: ... + @property + def dt(self) -> typing.Any: ... + @property + def freq(self) -> Freq: ... + @property + def id(self) -> builtins.int: ... + @property + def open(self) -> builtins.float: ... + @property + def close(self) -> builtins.float: ... + @property + def high(self) -> builtins.float: ... + @property + def low(self) -> builtins.float: ... + @property + def vol(self) -> builtins.float: ... + @property + def amount(self) -> builtins.float: ... + @property + def solid(self) -> builtins.float: + r""" + 实体部分(与原版CZSC兼容) + """ + @property + def upper(self) -> builtins.float: + r""" + 上影线长度(与原版CZSC兼容) + """ + @property + def lower(self) -> builtins.float: + r""" + 下影线长度(与原版CZSC兼容) + """ + @property + def cache(self) -> dict: ... + @property + def __dict__(self) -> typing.Any: + r""" + 直接支持 __dict__ 属性,让 pandas DataFrame() 能正确识别对象 + """ + def __new__(cls, symbol:builtins.str, dt:typing.Any, freq:Freq, open:builtins.float, close:builtins.float, high:builtins.float, low:builtins.float, vol:builtins.float, amount:builtins.float, id:builtins.int=0) -> RawBar: ... + def _asdict(self) -> typing.Any: + r""" + 让对象表现得像记录,pandas DataFrame构造器会调用这个 + """ + def to_dict(self) -> typing.Any: + r""" + 转换为字典,便于创建 pandas DataFrame + """ + def __reduce__(self) -> typing.Any: + r""" + 支持pickle序列化 + """ + def __deepcopy__(self, _memo:typing.Any) -> RawBar: + r""" + 支持深拷贝 + """ + def __repr__(self) -> builtins.str: ... + def __richcmp__(self, other:RawBar, op:int) -> builtins.bool: ... + +class Signal: + r""" + Python可见的Signal包装器 + """ + @property + def key(self) -> builtins.str: ... + @property + def value(self) -> builtins.str: ... + @property + def k3(self) -> builtins.str: ... + @property + def v1(self) -> builtins.str: ... + @property + def v2(self) -> builtins.str: ... + @property + def v3(self) -> builtins.str: ... + @property + def score(self) -> builtins.int: ... + @property + def k1(self) -> builtins.str: + r""" + 新增k1和k2属性getter,匹配Python版本 + """ + @property + def k2(self) -> builtins.str: ... + def __new__(cls, *args, signal:typing.Optional[builtins.str]=None, key:typing.Optional[builtins.str]=None, value:typing.Optional[builtins.str]=None, k1:typing.Optional[builtins.str]=None, k2:typing.Optional[builtins.str]=None, k3:typing.Optional[builtins.str]=None, v1:typing.Optional[builtins.str]=None, v2:typing.Optional[builtins.str]=None, v3:typing.Optional[builtins.str]=None, score:typing.Optional[builtins.int]=None) -> Signal: ... + @classmethod + def from_string(cls, _cls:type, s:builtins.str) -> Signal: ... + def to_json(self) -> builtins.str: + r""" + 添加to_json方法以匹配Python版本 + """ + def __str__(self) -> builtins.str: ... + def __repr__(self) -> builtins.str: ... + def __eq__(self, other:Signal) -> builtins.bool: ... + def __hash__(self) -> builtins.int: ... + def matches(self, other:Signal) -> builtins.bool: + r""" + 检查Signal是否匹配另一个Signal + """ + def is_match(self, signals_dict:typing.Mapping[builtins.str, builtins.str]) -> builtins.bool: + r""" + 判断信号是否与信号字典中的值匹配(Python版本is_match逻辑) + """ + def to_string(self) -> builtins.str: + r""" + 获取Signal的完整字符串表示 + """ + +class ZS: + @property + def bis(self) -> builtins.list[BI]: + r""" + 获取构成中枢的笔列表 + """ + @property + def sdt(self) -> typing.Any: + r""" + 中枢开始时间 + """ + @property + def edt(self) -> typing.Any: + r""" + 中枢结束时间 + """ + @property + def sdir(self) -> Direction: + r""" + 中枢第一笔方向 + """ + @property + def edir(self) -> Direction: + r""" + 中枢倒一笔方向 + """ + @property + def zg(self) -> builtins.float: + r""" + 中枢上沿 + """ + @property + def zd(self) -> builtins.float: + r""" + 中枢下沿 + """ + @property + def zz(self) -> builtins.float: + r""" + 中枢中轴 + """ + @property + def gg(self) -> builtins.float: + r""" + 中枢最高点 + """ + @property + def dd(self) -> builtins.float: + r""" + 中枢最低点 + """ + @property + def cache(self) -> dict: ... + def __new__(cls, bis:typing.Sequence[BI]) -> ZS: ... + def is_valid(self) -> builtins.bool: + r""" + 中枢是否有效 + """ + +class Direction(Enum): + r""" + 方向 + """ + Up = ... + r""" + 向上 + """ + Down = ... + r""" + 向下 + """ + + @property + def value(self) -> builtins.str: + r""" + 获取方向的字符串值(与 czsc 库兼容) + """ + def __deepcopy__(self, _memo:typing.Any) -> Direction: + r""" + 支持深拷贝 + """ + def __reduce__(self) -> tuple[typing.Any, typing.Any]: + r""" + 支持pickle序列化 + """ + def __new__(cls, value:builtins.str) -> Direction: ... + def __str__(self) -> builtins.str: ... + def __repr__(self) -> builtins.str: ... + def __richcmp__(self, other:typing.Any, op:int) -> builtins.bool: ... + +class Freq(Enum): + r""" + 时间周期 + """ + Tick = ... + r""" + 逐笔 + """ + F1 = ... + r""" + 1分钟 + """ + F2 = ... + r""" + 2分钟 + """ + F3 = ... + r""" + 3分钟 + """ + F4 = ... + r""" + 4分钟 + """ + F5 = ... + r""" + 5分钟 + """ + F6 = ... + r""" + 6分钟 + """ + F10 = ... + r""" + 10分钟 + """ + F12 = ... + r""" + 12分钟 + """ + F15 = ... + r""" + 15分钟 + """ + F20 = ... + r""" + 20分钟 + """ + F30 = ... + r""" + 30分钟 + """ + F60 = ... + r""" + 60分钟 + """ + F120 = ... + r""" + 120分钟 + """ + F240 = ... + r""" + 240分钟 + """ + F360 = ... + r""" + 360分钟 + """ + D = ... + r""" + 日线 + """ + W = ... + r""" + 周线 + """ + M = ... + r""" + 月线 + """ + S = ... + r""" + 季线 + """ + Y = ... + r""" + 年线 + """ + + __members__: typing.Any + @property + def value(self) -> builtins.str: ... + def __deepcopy__(self, _memo:typing.Any) -> Freq: + r""" + 支持深拷贝 + """ + def __reduce__(self) -> tuple[typing.Any, typing.Any]: + r""" + 支持pickle序列化 + """ + def __new__(cls, value:builtins.str) -> Freq: ... + def __str__(self) -> builtins.str: ... + def __repr__(self) -> builtins.str: ... + def __richcmp__(self, other:typing.Any, op:int) -> builtins.bool: ... + +class Mark(Enum): + r""" + 分型类型 + """ + D = ... + r""" + 底分型 + """ + G = ... + r""" + 顶分型 + """ + + @property + def value(self) -> builtins.str: + r""" + 获取标记的字符串值(与 czsc 库兼容) + """ + def __str__(self) -> builtins.str: ... + def __repr__(self) -> builtins.str: ... + def __richcmp__(self, other:typing.Any, op:int) -> builtins.bool: ... + +class Market(Enum): + AShare = ... + r""" + A股 + """ + Futures = ... + r""" + 期货 + """ + Default = ... + r""" + 默认 + """ + + def __new__(cls, ob:typing.Any) -> Market: ... + +def chip_distribution_triangle(data:numpy.typing.NDArray[numpy.float64], price_step:builtins.float, decay_factor:builtins.float) -> tuple[numpy.typing.NDArray[numpy.float64], numpy.typing.NDArray[numpy.float64]]: + r""" + 计算筹码分布(三角形分布 + 筹码沉淀机制) + + 此函数用于估算基于历史K线的筹码分布情况,结合三角形分布模型和筹码沉淀(衰减)机制。 + + # Python 接口说明 + + 输入一个二维 numpy 数组,形状为 (N, 3),每一行对应一根K线,列顺序为: + `[high, low, vol]`,类型必须为 `float64`。 + + 示例: + ```python + columns = ['high', 'low', 'vol'] + arr2 = df[columns].to_numpy(dtype=np.float64) + price_centers, chip_dist = chip_distribution_triangle(arr2, 0.01, 0.9) + ``` + + # 参数 + + - `data`: 二维数组,形状为 (N, 3),分别是每根K线的最高价、最低价和成交量。 + - `price_step`: 分档间隔(如0.01表示以0.01为单位划分价格区间)。 + - `decay_factor`: 筹码衰减因子,表示前一根K线上的筹码有多少比例沉淀保留到下一根K线上,范围为(0, 1),例如0.98表示保留98%。 + + # 返回值 + + 返回一个元组 `(price_centers, chip_distribution)`: + - `price_centers`: 一维数组,表示价格分布区间的中心价位。 + - `chip_distribution`: 一维数组,对应每个价格中心的筹码强度(权重/密度)。 + + 返回的两个数组长度相同,可用于绘制筹码分布图或进一步分析。 + """ + +def parse_signal_doc(doc:builtins.str) -> ParsedSignalDoc: + r""" + 解析文档中的Signal信息 + """ + diff --git a/docs/MIGRATION_NOTES.md b/docs/MIGRATION_NOTES.md index 9987f38a9..0c37f97a6 100644 --- a/docs/MIGRATION_NOTES.md +++ b/docs/MIGRATION_NOTES.md @@ -777,6 +777,7 @@ the signal-output level on data ranging from 500 bars to 40k bars. | 移除 `czsc/__init__.py` lazy loading + 注释/文档密度收缩 | [czsc/__init__.py](../czsc/__init__.py) | 按 spec §3.1 删除 `_LAZY_MODULES` / `_LAZY_ATTRS` / `__getattr__` 三件套;`svc / fsa / aphorism / mock` 改为顶层 `from . import ...`,7 个 lazy 属性(`capture_warnings` / `execute_with_warning_capture` / `adjust_holding_weights` / `log_strategy_info` / `plot_czsc_chart` / `KlineChart` / `check_kline_quality`)改为 `from czsc.utils.* import ...` 直接导入;删除 `if TYPE_CHECKING` 守卫;`welcome()` 函数体内的 `from czsc import aphorism` 提到顶层。同时压缩区段注释(17 处冗长的"逐符号说明"注释块全删)、`__all__` 字面表改为按主题分组的紧凑横排(仍保留全部 129 个公共名称、按主题用单行注释分隔)、`welcome()` docstring 折成单行、模块 docstring 22 行 → 11 行。`czsc/__init__.py` LoC 从 507 → 235(-54%)。**循环 import 防坑**:`svc / fsa / aphorism / mock` 中含 `from czsc import top_drawdowns` 等回环 import,必须放到所有顶层符号绑定后再加载(即"第二批 `from . import aphorism, fsa, mock, svc`"),第一次重排误把它们提到顶部触发 `cannot import name 'top_drawdowns' from partially initialized module 'czsc'`,调整顺序后通过;文件中以"第一批 / 第二批"分组注释固化此约束。**测试更新**:`test/test_import_performance.py::test_heavy_dependencies_not_loaded_on_import` 与 `test_svc_lazy_loaded` 是基于"streamlit 不应在 import czsc 时被加载"的旧设计断言,与新方向冲突,已删除;保留 `test_czsc_import_time`(< 10s 兜底)与 `test_czsc_svc_accessible`(顶层属性可用)。冷启动 importtime cumtime ≈ 320ms(spec §6 P3 目标 ≤ 300ms,超 ~7%;spec §3.1 注释中预期 < 50ms 仅指 Rust 扩展加载,不含整包 import) | | 实现 `czsc.utils.trade.stoploss_by_direction` 并切换调用方 | [czsc/utils/trade.py](../czsc/utils/trade.py)、[czsc/svc/backtest.py:261](../czsc/svc/backtest.py)、[test/test_stoploss_by_direction.py](../test/test_stoploss_by_direction.py) | 调研发现 `stoploss_by_direction` 既不在当前安装的 `rs_czsc`,也不在 `wbt`,更不在 `/Users/jun/Documents/vscodePro/rs_czsc` git 历史中——`from rs_czsc import stoploss_by_direction` 是死调用,运行 Streamlit dashboard 时会 `ImportError`。按 spec C1(`grep -r rs_czsc czsc/` 应无结果)的目标,按 superpowers TDD 范式新增 6 个 RED 测试(多/空头止损、order_id 切分、列契约、入参不可变性等),用纯 Python 在 `czsc/utils/trade.py` 写最小实现(按方向连续段切 order_id、向量化 hold_returns / min_hold_returns / returns / is_stop,浮点容差 1e-9 处理 `92/100 - 1 = -0.07999…` 这类边界),把 `czsc/svc/backtest.py:261` 的导入切到 `czsc.utils.trade.stoploss_by_direction`。**`grep -r 'from rs_czsc\\|import rs_czsc' czsc/ --include='*.py'` 现在零结果**——spec C1 全量达成,czsc 内部彻底无 `rs_czsc` 依赖。该函数标记为 czsc-only 改动,归入 §2.2 "新增能力"小节 | | `czsc.envs` 精简(Python 侧 docstring 收缩) | [czsc/envs.py](../czsc/envs.py)、[czsc/envs.pyi](../czsc/envs.pyi) | spec §3.4 目标"~20 行"是含 Rust `set_envs(...)` 入口后的最终形态;本轮先做 Python 侧最大化压缩:117 → 49 行(-58%)。3 个 getter(`get_verbose` / `get_min_bi_len` / `get_max_bi_num`) + 2 个内部 helper(`_env` / `_to_bool`)逻辑完全保留;裁剪掉模块级 ~20 行说明性 docstring 与每个函数 8-15 行的 verbose docstring,改为单行说明。`envs.pyi` 同步更新:删除 `valid_true: Incomplete` / `def use_python(): ...` / `def get_welcome(): ...` 三个旧公共符号(`test/test_envs.py::TestRetiredHelpers` 已经断言它们在 .py 中不存在,但 .pyi 之前未跟进,会让 basedpyright 把它们误识为存在)。`test_envs.py` 全 16 项通过 | +| `czsc/_native.pyi` 自动生成(spec §2.4 / Q4) | [crates/czsc-python/Cargo.toml](../crates/czsc-python/Cargo.toml)、[crates/czsc-python/src/lib.rs](../crates/czsc-python/src/lib.rs)、[crates/czsc-python/src/bin/stub_gen.rs](../crates/czsc-python/src/bin/stub_gen.rs)、[czsc/_native.pyi](../czsc/_native.pyi) | 各 PyO3 业务 crate 上的 `gen_stub_pyclass` / `gen_stub_pyfunction` / `gen_stub_pymethods` 装饰器早已布到位,但缺 stub 收集器 + 生成 binary。本次:① `czsc-python` 拆出 `extension-module` 为可选 default feature(cdylib 走默认;binary 用 `--no-default-features` 让 pyo3 自动链接 libpython,否则 macOS 链接器找不到 `_PyExc_*` 等符号);② lib.rs 写了一个自定义 `stub_info()` 函数,把 `from_pyproject_toml` 路径显式指向 workspace 根(默认宏 `define_stub_info_gatherer!` 假设 `pyproject.toml` 与 `Cargo.toml` 同目录,但本仓库 pyproject 在 workspace 根、Cargo 在 `crates/czsc-python/`);③ `src/bin/stub_gen.rs` 是最小入口,调用 `stub_info()?.generate()?`;④ 触发:`PYO3_PYTHON=$(uv run python -c 'import sys; print(sys.executable)') cargo run --bin stub_gen -p czsc-python --no-default-features`;⑤ 产物:`czsc/_native.pyi`,1 235 行,覆盖 BI / CZSC / FX / ZS / BarGenerator / Position / Signal / Event / RawBar / NewBar / FakeBI / Direction / Mark / Operate / Freq / ParsedSignalDoc 等核心类与 30+ TA 算子 / 信号函数 / `chip_distribution_triangle` / `parse_signal_doc` 等顶层函数。`pyproject.toml::tool.maturin.include = ["czsc/**/*.pyi", ...]` 已经覆盖此文件,wheel 打包自动带上。**残留**:basedpyright 在 `_native.pyi` 上报 8 个 upstream pyo3-stub-gen 已知问题(`__eq__` 参数类型不兼容父类、`__dict__` 与 `dict[str, Any]` 不兼容),属于工具层 false-positive,不影响功能 | ### 10.2 故意保留 / 暂缓的项 @@ -785,7 +786,7 @@ the signal-output level on data ranging from 500 bars to 40k bars. | `czsc/sensors/` 完整恢复 | spec §9 "完整保留 3 文件 301 行(含 `CTAResearch`)" | 保留占位 `__init__.py`(15 行),未恢复 `cta.py` / `utils.py` | 历史 `czsc.sensors.cta.CTAResearch` 依赖 `from czsc.traders.dummy import DummyBacktest`,而 `dummy.py` 已按 spec §3.3 删除并由 `czsc.run_replay` / wbt 替代;1:1 恢复会引入坏 import。需先在 Rust 端 `czsc-trader` 提供等价 dummy/replay 后再恢复,归为后续 Phase G 收尾 | | `czsc/traders/optimize.py` | spec §3.3 / §9 列入"完全删除" | 保留 | 现已是 Rust 端 `run_optimize_batch` 的 Python 薄外观层(配置归一化 + 物化数据 + 任务哈希 + 结果转发),与 spec 旧版"完全删除"假设不符;行为正确,无回归。以"过渡薄层"身份保留,不在 P0 范围内删除 | | `czsc/utils/ta.py` | spec §3.2 删除(由 Rust `czsc.ta.*` 替代) | 保留 75 行 | 仅保留 czsc 仪表盘场景使用的 MACD 特殊约定("柱状图额外乘以 2"),Rust 端 `czsc-ta` 暂未迁移该约定。**不通过 `czsc.ta` 重新导出**(`czsc.ta` 已指向 Rust 子模块),调用方需显式 `from czsc.utils.ta import MACD`。后续把柱状图 ×2 约定纳入 `czsc-ta::pure` 后再删 | -| `czsc/_native.pyi` 自动生成 | spec §2.4 / Q4 | 未生成 | 需在 `crates/czsc-python/build.rs` 接 `pyo3-stub-gen`;属 P1 范畴,本次会话未做。当前类型检查靠 `czsc/__init__.py` 内联注解 + `py.typed` | +| ~~`czsc/_native.pyi` 自动生成~~ | ~~spec §2.4 / Q4~~ | **2026-05-07 已完成** | 详见上表"`czsc/_native.pyi` 自动生成"。1 235 行 stub 已生成并被 maturin include 覆盖;basedpyright 上 8 个 upstream pyo3-stub-gen false-positive 已记录,不影响功能 | | `czsc.envs` 精简(Python 侧) | spec §3.4 | **2026-05-07 已完成**:117 → 49 行(-58%) | 详见上表"`czsc.envs` 精简(Python 侧 docstring 收缩)"。Rust 端 `set_envs(min_bi_len=..., max_bi_num=..., verbose=...)` 入口仍归 P1(需 `czsc-utils` 暴露后再做),现状 `CZSC` 与 `BarGenerator` 通过构造器参数接收 envs,不依赖 Rust 全局 | ### 10.3 验证 From f8a9c5bb94f70b1911f2146c904c8f50be41431c Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Thu, 7 May 2026 17:15:55 +0800 Subject: [PATCH 10/23] ci(stubs): add czsc/_native.pyi drift check to code-quality workflow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 防回归:12cf082 落地 stub_gen 后,Rust 装饰器改动若忘记重跑 stub_gen, czsc/_native.pyi 会与 Rust 源码漂移。CI 缺这个保护点。 新增 stub-drift job(needs: rust-tests): - checkout + Rust toolchain + Python 3.11 - PYO3_PYTHON=$(which python3) cargo run --bin stub_gen -p czsc-python --no-default-features - git diff --exit-code czsc/_native.pyi 断言无漂移, 否则失败并把本地重新生成命令打到日志里 本地验证:连续跑两次 stub_gen 产物一致(idempotent)。 覆盖两个回归方向: (a) 改了 #[gen_stub_pyclass] / pyfunction / pymethods 装饰器但忘重跑 -> CI 红灯阻拦 (b) 手改了 stub 但 Rust 没改 -> 下次 CI 复跑时把手改盖掉、提示提交方处理 --- .github/workflows/code-quality.yml | 50 ++++++++++++++++++++++++++++++ docs/MIGRATION_NOTES.md | 2 ++ 2 files changed, 52 insertions(+) diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index 0478f8d41..8f7f97c8b 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -42,6 +42,56 @@ jobs: echo "::endgroup::" done + # ------------------------------------------------------------------ + # 1.1) `czsc/_native.pyi` 漂移检查 + # + # `crates/czsc-python/src/bin/stub_gen.rs` 通过 pyo3-stub-gen 0.12 + # 把 Rust 端 `#[gen_stub_pyclass]` / `#[gen_stub_pyfunction]` / + # `#[gen_stub_pymethods]` 的元信息渲染成 `czsc/_native.pyi`。 + # 这个 stub 文件与 Rust 源码会双向漂移: + # + # - 改了 `gen_stub_*` 装饰器但忘了重跑 stub_gen → stub 反映旧 API + # - 手改了 stub 但 Rust 没改 → 下次重跑会把手改盖掉 + # + # 本 job 的策略:在 CI 重新跑一次 stub_gen,再 `git diff --exit-code` + # 检查 `czsc/_native.pyi` 是否被改写,若变更则 PR 必须把新生成的 stub + # 也提交进来才能合入。 + # ------------------------------------------------------------------ + stub-drift: + name: czsc/_native.pyi drift check + runs-on: ubuntu-latest + needs: rust-tests + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo + uses: Swatinem/rust-cache@v2 + + - name: Set up Python (any 3.10+ works for binding) + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Run stub_gen + run: | + export PYO3_PYTHON=$(which python3) + cargo run --bin stub_gen -p czsc-python --no-default-features + + - name: Assert czsc/_native.pyi is in sync with Rust source + run: | + if ! git diff --exit-code czsc/_native.pyi; then + echo "::error::czsc/_native.pyi 与 Rust 装饰器不一致 ——" + echo "请在本地运行:" + echo " PYO3_PYTHON=\$(uv run python -c 'import sys; print(sys.executable)') \\" + echo " cargo run --bin stub_gen -p czsc-python --no-default-features" + echo "并把更新后的 czsc/_native.pyi 一并 commit。" + exit 1 + fi + # ------------------------------------------------------------------ # 2) Python test matrix — single abi3 wheel covers 3.10/3.11/3.12/3.13. # Each Python version installs the project (which triggers maturin diff --git a/docs/MIGRATION_NOTES.md b/docs/MIGRATION_NOTES.md index 0c37f97a6..4c918a482 100644 --- a/docs/MIGRATION_NOTES.md +++ b/docs/MIGRATION_NOTES.md @@ -778,6 +778,7 @@ the signal-output level on data ranging from 500 bars to 40k bars. | 实现 `czsc.utils.trade.stoploss_by_direction` 并切换调用方 | [czsc/utils/trade.py](../czsc/utils/trade.py)、[czsc/svc/backtest.py:261](../czsc/svc/backtest.py)、[test/test_stoploss_by_direction.py](../test/test_stoploss_by_direction.py) | 调研发现 `stoploss_by_direction` 既不在当前安装的 `rs_czsc`,也不在 `wbt`,更不在 `/Users/jun/Documents/vscodePro/rs_czsc` git 历史中——`from rs_czsc import stoploss_by_direction` 是死调用,运行 Streamlit dashboard 时会 `ImportError`。按 spec C1(`grep -r rs_czsc czsc/` 应无结果)的目标,按 superpowers TDD 范式新增 6 个 RED 测试(多/空头止损、order_id 切分、列契约、入参不可变性等),用纯 Python 在 `czsc/utils/trade.py` 写最小实现(按方向连续段切 order_id、向量化 hold_returns / min_hold_returns / returns / is_stop,浮点容差 1e-9 处理 `92/100 - 1 = -0.07999…` 这类边界),把 `czsc/svc/backtest.py:261` 的导入切到 `czsc.utils.trade.stoploss_by_direction`。**`grep -r 'from rs_czsc\\|import rs_czsc' czsc/ --include='*.py'` 现在零结果**——spec C1 全量达成,czsc 内部彻底无 `rs_czsc` 依赖。该函数标记为 czsc-only 改动,归入 §2.2 "新增能力"小节 | | `czsc.envs` 精简(Python 侧 docstring 收缩) | [czsc/envs.py](../czsc/envs.py)、[czsc/envs.pyi](../czsc/envs.pyi) | spec §3.4 目标"~20 行"是含 Rust `set_envs(...)` 入口后的最终形态;本轮先做 Python 侧最大化压缩:117 → 49 行(-58%)。3 个 getter(`get_verbose` / `get_min_bi_len` / `get_max_bi_num`) + 2 个内部 helper(`_env` / `_to_bool`)逻辑完全保留;裁剪掉模块级 ~20 行说明性 docstring 与每个函数 8-15 行的 verbose docstring,改为单行说明。`envs.pyi` 同步更新:删除 `valid_true: Incomplete` / `def use_python(): ...` / `def get_welcome(): ...` 三个旧公共符号(`test/test_envs.py::TestRetiredHelpers` 已经断言它们在 .py 中不存在,但 .pyi 之前未跟进,会让 basedpyright 把它们误识为存在)。`test_envs.py` 全 16 项通过 | | `czsc/_native.pyi` 自动生成(spec §2.4 / Q4) | [crates/czsc-python/Cargo.toml](../crates/czsc-python/Cargo.toml)、[crates/czsc-python/src/lib.rs](../crates/czsc-python/src/lib.rs)、[crates/czsc-python/src/bin/stub_gen.rs](../crates/czsc-python/src/bin/stub_gen.rs)、[czsc/_native.pyi](../czsc/_native.pyi) | 各 PyO3 业务 crate 上的 `gen_stub_pyclass` / `gen_stub_pyfunction` / `gen_stub_pymethods` 装饰器早已布到位,但缺 stub 收集器 + 生成 binary。本次:① `czsc-python` 拆出 `extension-module` 为可选 default feature(cdylib 走默认;binary 用 `--no-default-features` 让 pyo3 自动链接 libpython,否则 macOS 链接器找不到 `_PyExc_*` 等符号);② lib.rs 写了一个自定义 `stub_info()` 函数,把 `from_pyproject_toml` 路径显式指向 workspace 根(默认宏 `define_stub_info_gatherer!` 假设 `pyproject.toml` 与 `Cargo.toml` 同目录,但本仓库 pyproject 在 workspace 根、Cargo 在 `crates/czsc-python/`);③ `src/bin/stub_gen.rs` 是最小入口,调用 `stub_info()?.generate()?`;④ 触发:`PYO3_PYTHON=$(uv run python -c 'import sys; print(sys.executable)') cargo run --bin stub_gen -p czsc-python --no-default-features`;⑤ 产物:`czsc/_native.pyi`,1 235 行,覆盖 BI / CZSC / FX / ZS / BarGenerator / Position / Signal / Event / RawBar / NewBar / FakeBI / Direction / Mark / Operate / Freq / ParsedSignalDoc 等核心类与 30+ TA 算子 / 信号函数 / `chip_distribution_triangle` / `parse_signal_doc` 等顶层函数。`pyproject.toml::tool.maturin.include = ["czsc/**/*.pyi", ...]` 已经覆盖此文件,wheel 打包自动带上。**残留**:basedpyright 在 `_native.pyi` 上报 8 个 upstream pyo3-stub-gen 已知问题(`__eq__` 参数类型不兼容父类、`__dict__` 与 `dict[str, Any]` 不兼容),属于工具层 false-positive,不影响功能 | +| `_native.pyi` 漂移检查 CI job | [.github/workflows/code-quality.yml](../.github/workflows/code-quality.yml) | `code-quality.yml` 新增 `stub-drift` job(依赖 `rust-tests`),在 CI 中:① checkout + 装 Rust + Python 3.11;② 跑 `PYO3_PYTHON=$(which python3) cargo run --bin stub_gen -p czsc-python --no-default-features`;③ `git diff --exit-code czsc/_native.pyi` 断言无漂移,否则失败并把本地重新生成命令打到日志里。本地已验证 stub_gen 重跑两次产物一致(idempotent)。这一步覆盖两个回归方向:(a)改了 `gen_stub_*` 装饰器但忘重跑 → CI 红灯阻拦;(b)手改了 stub 但 Rust 没改 → 下次 CI 复跑时把手改盖掉、提示提交方处理 | ### 10.2 故意保留 / 暂缓的项 @@ -787,6 +788,7 @@ the signal-output level on data ranging from 500 bars to 40k bars. | `czsc/traders/optimize.py` | spec §3.3 / §9 列入"完全删除" | 保留 | 现已是 Rust 端 `run_optimize_batch` 的 Python 薄外观层(配置归一化 + 物化数据 + 任务哈希 + 结果转发),与 spec 旧版"完全删除"假设不符;行为正确,无回归。以"过渡薄层"身份保留,不在 P0 范围内删除 | | `czsc/utils/ta.py` | spec §3.2 删除(由 Rust `czsc.ta.*` 替代) | 保留 75 行 | 仅保留 czsc 仪表盘场景使用的 MACD 特殊约定("柱状图额外乘以 2"),Rust 端 `czsc-ta` 暂未迁移该约定。**不通过 `czsc.ta` 重新导出**(`czsc.ta` 已指向 Rust 子模块),调用方需显式 `from czsc.utils.ta import MACD`。后续把柱状图 ×2 约定纳入 `czsc-ta::pure` 后再删 | | ~~`czsc/_native.pyi` 自动生成~~ | ~~spec §2.4 / Q4~~ | **2026-05-07 已完成** | 详见上表"`czsc/_native.pyi` 自动生成"。1 235 行 stub 已生成并被 maturin include 覆盖;basedpyright 上 8 个 upstream pyo3-stub-gen false-positive 已记录,不影响功能 | +| ~~`_native.pyi` CI 漂移检查~~ | ~~P1 待办(本地有 stub_gen,但 CI 不验证~~ | **2026-05-07 已完成** | 详见上表"`_native.pyi` 漂移检查 CI job"。`code-quality.yml` 新增 `stub-drift` job,在 CI 重跑 stub_gen 后 `git diff --exit-code czsc/_native.pyi`,发现漂移即失败并提示重新生成命令 | | `czsc.envs` 精简(Python 侧) | spec §3.4 | **2026-05-07 已完成**:117 → 49 行(-58%) | 详见上表"`czsc.envs` 精简(Python 侧 docstring 收缩)"。Rust 端 `set_envs(min_bi_len=..., max_bi_num=..., verbose=...)` 入口仍归 P1(需 `czsc-utils` 暴露后再做),现状 `CZSC` 与 `BarGenerator` 通过构造器参数接收 envs,不依赖 Rust 全局 | ### 10.3 验证 From ad5a6784986b1fbddcb0e60e9831cfadcdb6f675 Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Thu, 7 May 2026 17:19:35 +0800 Subject: [PATCH 11/23] =?UTF-8?q?feat(sensors):=20partial=20restore=20of?= =?UTF-8?q?=20czsc/sensors/=20(spec=20=C2=A79)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 之前 czsc/sensors/ 仅有 15 行占位 __init__.py,与 spec §9 "完整保留 3 文件 301 行(含 CTAResearch)"差距明显。本次: - 从 git 历史 79bdf5e:czsc/sensors/utils.py 恢复 utils.py(121 行,含 holds_concepts_effect / turn_over_rate / max_draw_down 三个纯 numpy / pandas 工具;无内部 czsc 依赖,可直接复用) - 同步恢复 utils.pyi - 重写 __init__.py: * 暴露 3 个 utility 函数到 czsc.sensors.* 顶层 * 添加 CTAResearch 占位类,__init__ 直接抛 NotImplementedError, 明确告知历史实现依赖已删的 czsc.traders.dummy.DummyBacktest (spec §3.3 删除),引导用户改用 czsc.run_replay / wbt.WeightBacktest 组合,并指向 MIGRATION_NOTES §10.2 中的恢复计划 公共 API 影响:from czsc.sensors import {CTAResearch, holds_concepts_effect, turn_over_rate, max_draw_down} 全部可用;CTAResearch() 实例化即 fail-fast,不会让代码运行半截才报 ImportError。 spec §9 "完整保留" 项剩余的就只是 cta.py 真实迁移(需先在 Rust 端 czsc-trader 提供等价 dummy/replay)。 测试:124 项全量 sweep 仍 GREEN。 --- czsc/sensors/__init__.py | 60 +++++++++++++++---- czsc/sensors/utils.py | 121 +++++++++++++++++++++++++++++++++++++++ czsc/sensors/utils.pyi | 5 ++ docs/MIGRATION_NOTES.md | 3 +- 4 files changed, 177 insertions(+), 12 deletions(-) create mode 100644 czsc/sensors/utils.py create mode 100644 czsc/sensors/utils.pyi diff --git a/czsc/sensors/__init__.py b/czsc/sensors/__init__.py index 433fe67ba..b74076eb1 100644 --- a/czsc/sensors/__init__.py +++ b/czsc/sensors/__init__.py @@ -1,15 +1,53 @@ -"""czsc.sensors —— 事件检测与特征分析命名空间(占位实现)。 +"""czsc.sensors —— 事件检测与特征分析命名空间。 -本包对应 CZSC 项目中的“传感器(sensors)”层,目标是承载下列高层能力: +按 spec §9,本包目标承载三类传感器能力: -- ``CTA`` 研究框架:策略回放、参数优化、并行回测; -- 特征选择器:基于滚动窗口的因子选择与分析; -- 事件匹配器:把信号组合成事件并在历史数据上扫描匹配。 +1. **CTA 研究框架**:策略回放、参数优化、并行回测——以 ``CTAResearch`` 类为核心 +2. **特征选择器**:基于滚动窗口的因子选择与分析 +3. **事件匹配器**:把信号组合成事件并在历史数据上扫描匹配 -当前阶段仅恢复命名空间的可导入性,使得 ``import czsc.sensors`` 不会 -报错;具体的传感器类(``CTAResearch``、``FeatureSelector``、``EventMatcher`` 等) -将在后续 Python 端清理工作中一并迁移过来。 - -在迁移完成之前,本模块保持为一个空的命名空间包,公共 API 冒烟测试只验证模块 -能干净地完成导入这一最小契约,不假设包内已有任何具体类型。 +当前阶段已恢复 ``utils`` 子模块(``holds_concepts_effect`` / +``turn_over_rate`` / ``max_draw_down`` 三个纯 numpy/pandas 工具,无内部 czsc +依赖)。``CTAResearch`` 历史实现依赖已删除的 ``czsc.traders.dummy.DummyBacktest`` +(spec §3.3 已经把 dummy 替换为 ``czsc.run_replay`` / wbt),因此暂时保留为 +``NotImplementedError`` 占位,等 Phase G 在 Rust 端 ``czsc-trader`` 提供 +等价的 dummy/replay 后再恢复实现。 """ + +from __future__ import annotations + +from czsc.sensors.utils import ( + holds_concepts_effect, + max_draw_down, + turn_over_rate, +) + + +class CTAResearch: + """spec §9 中的 CTA 研究框架占位类。 + + 历史实现依赖 ``czsc.traders.dummy.DummyBacktest``;该模块已按 spec §3.3 + 在 Phase J 删除(被 ``czsc.run_replay`` / ``wbt.WeightBacktest`` 取代)。 + 完整恢复需要先在 Rust 端 ``czsc-trader`` 提供等价能力,详见 + ``docs/MIGRATION_NOTES.md`` §10.2。 + + 在那之前,实例化此类会立即抛 :class:`NotImplementedError`,明确告知用户 + 迁移路径,而不是让代码在运行半截后才报"找不到 DummyBacktest"。 + """ + + def __init__(self, *args, **kwargs): + raise NotImplementedError( + "czsc.sensors.CTAResearch 暂未恢复——历史实现依赖已删除的 " + "czsc.traders.dummy.DummyBacktest(spec §3.3 / Phase J)。" + "请改用 czsc.run_replay / czsc.run_research / czsc.WeightBacktest " + "组合达成等价工作流;或关注 docs/MIGRATION_NOTES.md §10.2 中" + "对该项的恢复计划。" + ) + + +__all__ = [ + "CTAResearch", + "holds_concepts_effect", + "max_draw_down", + "turn_over_rate", +] diff --git a/czsc/sensors/utils.py b/czsc/sensors/utils.py new file mode 100644 index 000000000..0afa941a5 --- /dev/null +++ b/czsc/sensors/utils.py @@ -0,0 +1,121 @@ +""" +author: zengbin93 +email: zeng_bin8888@163.com +create_dt: 2021/11/17 18:50 +""" + +from collections import Counter + +import numpy as np +import pandas as pd +from tqdm import tqdm + + +def max_draw_down(n1b: list): + """最大回撤 + + 参考:https://blog.csdn.net/weixin_38997425/article/details/82915386 + + :param n1b: 逐个结算周期的收益列表,单位:BP,换算关系是 10000BP = 100% + 如,n1b = [100.1, -90.5, 212.6],表示第一个结算周期收益为100.1BP,也就是1.001%,以此类推。 + :return: 最大回撤起止位置和最大回撤 + """ + curve = np.cumsum(n1b) + curve += 10000 + # 获取结束位置 + i = np.argmax((np.maximum.accumulate(curve) - curve) / np.maximum.accumulate(curve)) + if i == 0: + return 0, 0, 0 + + # 获取开始位置 + j = np.argmax(curve[:i]) + mdd = int((curve[j] - curve[i]) / curve[j] * 10000) / 10000 + return j, i, mdd + + +def turn_over_rate(df_holds: pd.DataFrame) -> tuple[pd.DataFrame, float]: + """计算持仓明细对应的组合换手率 + + :param df_holds: 每个交易日的持仓明细,数据样例如下 + 证券代码 成分日期 持仓权重 + 0 000576.SZ 2020-01-02 0.0099 + 1 000639.SZ 2020-01-02 0.0099 + 2 000803.SZ 2020-01-02 0.0099 + 3 000811.SZ 2020-01-02 0.0099 + 4 000829.SZ 2020-01-02 0.0099 + :return: 组合换手率 + """ + dft = pd.pivot_table(df_holds, index="成分日期", columns="证券代码", values="持仓权重", aggfunc="sum") + dft = dft.fillna(0) + df_turns = dft.diff().abs().sum(axis=1).reset_index() + df_turns.columns = ["date", "change"] + + # 由于是 diff 计算,第一个时刻的仓位变化被忽视了,修改一下 + sdt = df_holds["成分日期"].min() + df_turns.loc[(df_turns["date"] == sdt), "change"] = df_holds[df_holds["成分日期"] == sdt]["持仓权重"].sum() + return df_turns, round(df_turns.change.sum() / 2, 4) + + +def holds_concepts_effect(holds: pd.DataFrame, concepts: dict, top_n=20, min_n=3, **kwargs): + """股票持仓列表的板块效应 + + 原理概述:在选股时,如果股票的概念板块与组合中的其他股票的概念板块有重合,那么这个股票的表现会更好。 + + 函数计算逻辑: + + 1. 如果kwargs中存在'copy'键且对应值为True,则将holds进行复制。 + 2. 为holds添加'概念板块'列,该列的值是holds中'symbol'列对应的股票的概念板块列表,如果没有对应的概念板块则填充为空。 + 3. 添加'概念数量'列,该列的值是每个股票的概念板块数量。 + 4. 从holds中筛选出概念数量大于0的行,赋值给holds。 + 5. 创建空列表new_holds和空字典dt_key_concepts。 + 6. 对holds按照'dt'进行分组,遍历每个分组,计算板块效应。 + a. 计算密集出现的概念,选取出现次数最多的前top_n个概念,赋值给key_concepts列表。 + b. 将日期dt和对应的key_concepts存入dt_key_concepts字典。 + c. 计算在密集概念中出现次数超过min_n的股票,将符合条件的股票添加到new_holds列表中。 + 7. 使用pd.concat将new_holds中的DataFrame进行合并,忽略索引,赋值给dfh。 + 8. 创建DataFrame dfk,其中包含日期(dt)和对应的强势概念(key_concepts)。 + 9. 返回dfh和dfk。 + + :param holds: 组合股票池数据,样例: + + =================== ========= ========== + dt symbol weight + =================== ========= ========== + 2023-05-09 00:00:00 601858.SH 0.00333333 + 2023-05-09 00:00:00 300502.SZ 0.00333333 + 2023-05-09 00:00:00 603258.SH 0.00333333 + 2023-05-09 00:00:00 300499.SZ 0.00333333 + 2023-05-09 00:00:00 300624.SZ 0.00333333 + =================== ========= ========== + + :param concepts: 股票的概念板块,样例: + { + '002507.SZ': ['电子商务', '超级品牌', '国企改革'], + '002508.SZ': ['家用电器', '杭州亚运会', '恒大概念'] + } + :param top_n: 选取前 n 个密集概念 + :param min_n: 单股票至少要有 n 个概念在 top_n 中 + :return: 过滤后的选股结果,每个时间点的 top_n 概念 + """ + if kwargs.get("copy", True): + holds = holds.copy() + + holds["概念板块"] = holds["symbol"].map(concepts).fillna("") + holds["概念数量"] = holds["概念板块"].apply(len) + holds = holds[holds["概念数量"] > 0] + + new_holds = [] + dt_key_concepts = {} + for dt, dfg in tqdm(holds.groupby("dt"), desc="计算板块效应"): + # 计算密集出现的概念 + key_concepts = [k for k, v in Counter([x for y in dfg["概念板块"] for x in y]).most_common(top_n)] + dt_key_concepts[dt] = key_concepts + + # 计算在密集概念中出现次数超过min_n的股票 + dfg["强势概念"] = dfg["概念板块"].apply(lambda x: ",".join(set(x) & set(key_concepts))) + sel = dfg[dfg["强势概念"].apply(lambda x: len(x.split(",")) >= min_n)] + new_holds.append(sel) + + dfh = pd.concat(new_holds, ignore_index=True) + dfk = pd.DataFrame([{"dt": k, "强势概念": v} for k, v in dt_key_concepts.items()]) + return dfh, dfk diff --git a/czsc/sensors/utils.pyi b/czsc/sensors/utils.pyi new file mode 100644 index 000000000..0f078a4e0 --- /dev/null +++ b/czsc/sensors/utils.pyi @@ -0,0 +1,5 @@ +import pandas as pd + +def max_draw_down(n1b: list): ... +def turn_over_rate(df_holds: pd.DataFrame) -> tuple[pd.DataFrame, float]: ... +def holds_concepts_effect(holds: pd.DataFrame, concepts: dict, top_n: int = 20, min_n: int = 3, **kwargs): ... diff --git a/docs/MIGRATION_NOTES.md b/docs/MIGRATION_NOTES.md index 4c918a482..0dfed3806 100644 --- a/docs/MIGRATION_NOTES.md +++ b/docs/MIGRATION_NOTES.md @@ -779,12 +779,13 @@ the signal-output level on data ranging from 500 bars to 40k bars. | `czsc.envs` 精简(Python 侧 docstring 收缩) | [czsc/envs.py](../czsc/envs.py)、[czsc/envs.pyi](../czsc/envs.pyi) | spec §3.4 目标"~20 行"是含 Rust `set_envs(...)` 入口后的最终形态;本轮先做 Python 侧最大化压缩:117 → 49 行(-58%)。3 个 getter(`get_verbose` / `get_min_bi_len` / `get_max_bi_num`) + 2 个内部 helper(`_env` / `_to_bool`)逻辑完全保留;裁剪掉模块级 ~20 行说明性 docstring 与每个函数 8-15 行的 verbose docstring,改为单行说明。`envs.pyi` 同步更新:删除 `valid_true: Incomplete` / `def use_python(): ...` / `def get_welcome(): ...` 三个旧公共符号(`test/test_envs.py::TestRetiredHelpers` 已经断言它们在 .py 中不存在,但 .pyi 之前未跟进,会让 basedpyright 把它们误识为存在)。`test_envs.py` 全 16 项通过 | | `czsc/_native.pyi` 自动生成(spec §2.4 / Q4) | [crates/czsc-python/Cargo.toml](../crates/czsc-python/Cargo.toml)、[crates/czsc-python/src/lib.rs](../crates/czsc-python/src/lib.rs)、[crates/czsc-python/src/bin/stub_gen.rs](../crates/czsc-python/src/bin/stub_gen.rs)、[czsc/_native.pyi](../czsc/_native.pyi) | 各 PyO3 业务 crate 上的 `gen_stub_pyclass` / `gen_stub_pyfunction` / `gen_stub_pymethods` 装饰器早已布到位,但缺 stub 收集器 + 生成 binary。本次:① `czsc-python` 拆出 `extension-module` 为可选 default feature(cdylib 走默认;binary 用 `--no-default-features` 让 pyo3 自动链接 libpython,否则 macOS 链接器找不到 `_PyExc_*` 等符号);② lib.rs 写了一个自定义 `stub_info()` 函数,把 `from_pyproject_toml` 路径显式指向 workspace 根(默认宏 `define_stub_info_gatherer!` 假设 `pyproject.toml` 与 `Cargo.toml` 同目录,但本仓库 pyproject 在 workspace 根、Cargo 在 `crates/czsc-python/`);③ `src/bin/stub_gen.rs` 是最小入口,调用 `stub_info()?.generate()?`;④ 触发:`PYO3_PYTHON=$(uv run python -c 'import sys; print(sys.executable)') cargo run --bin stub_gen -p czsc-python --no-default-features`;⑤ 产物:`czsc/_native.pyi`,1 235 行,覆盖 BI / CZSC / FX / ZS / BarGenerator / Position / Signal / Event / RawBar / NewBar / FakeBI / Direction / Mark / Operate / Freq / ParsedSignalDoc 等核心类与 30+ TA 算子 / 信号函数 / `chip_distribution_triangle` / `parse_signal_doc` 等顶层函数。`pyproject.toml::tool.maturin.include = ["czsc/**/*.pyi", ...]` 已经覆盖此文件,wheel 打包自动带上。**残留**:basedpyright 在 `_native.pyi` 上报 8 个 upstream pyo3-stub-gen 已知问题(`__eq__` 参数类型不兼容父类、`__dict__` 与 `dict[str, Any]` 不兼容),属于工具层 false-positive,不影响功能 | | `_native.pyi` 漂移检查 CI job | [.github/workflows/code-quality.yml](../.github/workflows/code-quality.yml) | `code-quality.yml` 新增 `stub-drift` job(依赖 `rust-tests`),在 CI 中:① checkout + 装 Rust + Python 3.11;② 跑 `PYO3_PYTHON=$(which python3) cargo run --bin stub_gen -p czsc-python --no-default-features`;③ `git diff --exit-code czsc/_native.pyi` 断言无漂移,否则失败并把本地重新生成命令打到日志里。本地已验证 stub_gen 重跑两次产物一致(idempotent)。这一步覆盖两个回归方向:(a)改了 `gen_stub_*` 装饰器但忘重跑 → CI 红灯阻拦;(b)手改了 stub 但 Rust 没改 → 下次 CI 复跑时把手改盖掉、提示提交方处理 | +| `czsc/sensors/` 部分恢复(spec §9) | [czsc/sensors/utils.py](../czsc/sensors/utils.py)、[czsc/sensors/utils.pyi](../czsc/sensors/utils.pyi)、[czsc/sensors/__init__.py](../czsc/sensors/__init__.py) | 之前 sensors 仅有 15 行占位 `__init__.py`,与 spec §9 "完整保留 3 文件 301 行"差距明显。本次:① 从 git 历史 `79bdf5e:czsc/sensors/utils.py` 恢复 `utils.py`(121 行,含 `holds_concepts_effect` / `turn_over_rate` / `max_draw_down`,纯 numpy / pandas 实现,无内部 czsc 依赖);② 同步恢复 `utils.pyi`;③ 重写 `__init__.py`:暴露 3 个 utility 函数 + 添加 `CTAResearch` **占位类**(`__init__` 直接抛 `NotImplementedError`,明确指出历史实现依赖已删 `czsc.traders.dummy.DummyBacktest`、引导用户改用 `czsc.run_replay` / `wbt.WeightBacktest` 组合,并指向本文档)。`from czsc.sensors import CTAResearch, holds_concepts_effect, ...` 全部正常;`CTAResearch()` 调用即抛 `NotImplementedError`(fail-fast,避免在调用半截才报"找不到 DummyBacktest")。spec §9 "完整保留" 项剩余的就只是 `cta.py` 真实迁移(需先在 Rust 端 `czsc-trader` 落地等价 dummy/replay)。| ### 10.2 故意保留 / 暂缓的项 | 项 | 原计划(spec) | 实际处理 | 原因 | |-|-|-|-| -| `czsc/sensors/` 完整恢复 | spec §9 "完整保留 3 文件 301 行(含 `CTAResearch`)" | 保留占位 `__init__.py`(15 行),未恢复 `cta.py` / `utils.py` | 历史 `czsc.sensors.cta.CTAResearch` 依赖 `from czsc.traders.dummy import DummyBacktest`,而 `dummy.py` 已按 spec §3.3 删除并由 `czsc.run_replay` / wbt 替代;1:1 恢复会引入坏 import。需先在 Rust 端 `czsc-trader` 提供等价 dummy/replay 后再恢复,归为后续 Phase G 收尾 | +| `czsc/sensors/` 部分恢复 | spec §9 "完整保留 3 文件 301 行(含 `CTAResearch`)" | **2026-05-07 已部分恢复**:`utils.py`(121 行,3 个纯 numpy/pandas 工具)+ `utils.pyi` + `__init__.py` 重写(添加 `CTAResearch` `NotImplementedError` 占位)。详见上表"`czsc/sensors/` 部分恢复" | 完整恢复仍依赖 Phase G 在 Rust 端 `czsc-trader` 提供 dummy/replay 等价物,之后再把 `cta.py` 真实迁移。当前状态:`from czsc.sensors import holds_concepts_effect, turn_over_rate, max_draw_down` 可用;`CTAResearch()` 调用即 fail-fast | | `czsc/traders/optimize.py` | spec §3.3 / §9 列入"完全删除" | 保留 | 现已是 Rust 端 `run_optimize_batch` 的 Python 薄外观层(配置归一化 + 物化数据 + 任务哈希 + 结果转发),与 spec 旧版"完全删除"假设不符;行为正确,无回归。以"过渡薄层"身份保留,不在 P0 范围内删除 | | `czsc/utils/ta.py` | spec §3.2 删除(由 Rust `czsc.ta.*` 替代) | 保留 75 行 | 仅保留 czsc 仪表盘场景使用的 MACD 特殊约定("柱状图额外乘以 2"),Rust 端 `czsc-ta` 暂未迁移该约定。**不通过 `czsc.ta` 重新导出**(`czsc.ta` 已指向 Rust 子模块),调用方需显式 `from czsc.utils.ta import MACD`。后续把柱状图 ×2 约定纳入 `czsc-ta::pure` 后再删 | | ~~`czsc/_native.pyi` 自动生成~~ | ~~spec §2.4 / Q4~~ | **2026-05-07 已完成** | 详见上表"`czsc/_native.pyi` 自动生成"。1 235 行 stub 已生成并被 maturin include 覆盖;basedpyright 上 8 个 upstream pyo3-stub-gen false-positive 已记录,不影响功能 | From 0bd308236169dd07a605fa9d4b668e939ecd46e1 Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Thu, 7 May 2026 17:26:37 +0800 Subject: [PATCH 12/23] =?UTF-8?q?perf(bench):=20add=20CZSC::new=20criterio?= =?UTF-8?q?n=20benchmark=20(spec=20=C2=A76=20P1=20=3D=2096.585ms=20/=20200?= =?UTF-8?q?ms)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 闭合 spec §6 P1 验收:10 万根 K 线做完整 CZSC 分析(分型/笔/中枢识别)≤ 200 ms。 新增 crates/czsc-core/benches/czsc_analyze_bench.rs: - 用 criterion 0.7 做基准(lto=true / opt-level=3 / codegen-units=1 配置由 workspace [profile.release] 提供) - 慢周期正弦 + 快周期抖动 + 渐进漂移生成 10 万根 30 分钟模拟 K 线 (避免单调推高/降低让缠论分析路径退化) - iter_batched(BatchSize::LargeInput) 测 CZSC::new(bars, max_bi_num=50) - 20 样本 实测(M2 Mac,release): czsc_analyze/CZSC::new(bars=100000, max_bi_num=50) time: [96.276 ms 96.585 ms 96.971 ms] 96.585 ms vs spec §6 P1 目标 200 ms,余量 52%,P1 达标 ✅。 Cargo.toml 加 criterion = "0.7" 到 [dev-dependencies] + [[bench]] name = "czsc_analyze_bench" harness = false。 触发命令:cargo bench -p czsc-core --- Cargo.lock | 203 ++++++++++++++++++ crates/czsc-core/Cargo.toml | 6 + .../czsc-core/benches/czsc_analyze_bench.rs | 84 ++++++++ docs/MIGRATION_NOTES.md | 9 +- 4 files changed, 301 insertions(+), 1 deletion(-) create mode 100644 crates/czsc-core/benches/czsc_analyze_bench.rs diff --git a/Cargo.lock b/Cargo.lock index 682e80212..1783f615c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -60,6 +60,18 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anstyle" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" + [[package]] name = "anyhow" version = "1.0.102" @@ -236,6 +248,12 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cc" version = "1.2.61" @@ -290,6 +308,58 @@ dependencies = [ "phf_codegen", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + +[[package]] +name = "clap" +version = "4.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ddb117e43bbf7dacf0a4190fef4d345b9bad68dfc649cb349e7d17d28428e51" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f" +dependencies = [ + "anstyle", + "clap_lex", +] + +[[package]] +name = "clap_lex" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" + [[package]] name = "comfy-table" version = "7.2.2" @@ -325,6 +395,39 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1c047a62b0cc3e145fa84415a3191f628e980b194c2755aa12300a4e6cbd928" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "itertools", + "num-traits", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b1bcc0dc7dfae599d84ad0b1a55f80cde8af3725da8313b528da95ef783e338" +dependencies = [ + "cast", + "itertools", +] + [[package]] name = "crossbeam-channel" version = "0.5.15" @@ -391,6 +494,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + [[package]] name = "crypto-common" version = "0.1.7" @@ -428,6 +537,7 @@ version = "1.0.0" dependencies = [ "anyhow", "chrono", + "criterion", "derive_builder", "error-macros", "error-support", @@ -917,6 +1027,17 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "cfg-if", + "crunchy", + "zerocopy", +] + [[package]] name = "halfbrown" version = "0.2.5" @@ -1386,6 +1507,12 @@ version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + [[package]] name = "ordered-float" version = "5.3.0" @@ -1502,6 +1629,34 @@ dependencies = [ "array-init-cursor", ] +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "polars" version = "0.42.0" @@ -2308,6 +2463,15 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -2626,6 +2790,16 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tokio" version = "1.52.2" @@ -2799,6 +2973,16 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" @@ -2902,6 +3086,16 @@ dependencies = [ "semver", ] +[[package]] +name = "web-sys" +version = "0.3.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2eadbac71025cd7b0834f20d1fe8472e8495821b4e9801eb0a60bd1f19827602" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "winapi" version = "0.3.9" @@ -2918,6 +3112,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" diff --git a/crates/czsc-core/Cargo.toml b/crates/czsc-core/Cargo.toml index 3866342cb..5d3e63839 100644 --- a/crates/czsc-core/Cargo.toml +++ b/crates/czsc-core/Cargo.toml @@ -37,3 +37,9 @@ python = ["pyo3", "pyo3-stub-gen", "parking_lot"] anyhow = "1" serde_json = "1" thiserror = "2" +criterion = "0.7" + +# spec §6 P1 性能基准:10 万根 K 线 CZSC 完整分析 ≤ 200 ms(M2 Mac,单进程) +[[bench]] +name = "czsc_analyze_bench" +harness = false diff --git a/crates/czsc-core/benches/czsc_analyze_bench.rs b/crates/czsc-core/benches/czsc_analyze_bench.rs new file mode 100644 index 000000000..d0ceb4070 --- /dev/null +++ b/crates/czsc-core/benches/czsc_analyze_bench.rs @@ -0,0 +1,84 @@ +//! CZSC 核心分析性能基准(spec §6 P1)。 +//! +//! 验收目标:10 万根 K 线做完整 CZSC 分析(分型 / 笔 / 中枢识别)≤ 200 ms。 +//! +//! 触发: +//! cargo bench -p czsc-core +//! +//! 在 M2 Mac、release 构建(lto=true / opt-level=3 / codegen-units=1)下, +//! criterion 默认 100 个样本输出 mean / median / std-dev,把 mean 与 200 ms +//! 比对即可判定 P1 是否达标。 + +use std::hint::black_box; +use std::sync::Arc; + +use chrono::{TimeZone, Utc}; +use criterion::{BatchSize, Criterion, criterion_group, criterion_main}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::bar::{RawBar, RawBarBuilder}; +use czsc_core::objects::freq::Freq; + +/// 生成 `count` 根模拟 30 分钟 K 线,价格用正弦+游走,振幅与噪声足以触发 +/// 分型与笔的形成(不会出现一直单调推高/降低导致缠论分析路径退化)。 +fn generate_bars(count: usize) -> Vec { + let symbol: Arc = Arc::from("000001.SH"); + let base_ts = 1704067200i64; // 2024-01-01 00:00 UTC + + (0..count) + .map(|i| { + // 基础价格围绕 100,慢周期正弦 + 快周期高频抖动 + 渐进漂移 + let slow = (i as f64 * 0.001).sin() * 30.0; + let fast = (i as f64 * 0.07).sin() * 4.0; + let drift = i as f64 * 0.0002; + let close = 100.0 + slow + fast + drift; + let open = close - fast * 0.5; + let high = close.max(open) + 0.6; + let low = close.min(open) - 0.6; + // 30 分钟一根:每根递增 1800 秒 + let dt = Utc.timestamp_opt(base_ts + i as i64 * 1800, 0).unwrap(); + + RawBarBuilder::default() + .symbol(symbol.clone()) + .id(i as i32) + .dt(dt) + .freq(Freq::F30) + .open(open) + .close(close) + .high(high) + .low(low) + .vol(1_000_000.0) + .amount(close * 1_000_000.0) + .build() + .expect("RawBar 构造失败") + }) + .collect() +} + +fn bench_czsc_analyze(c: &mut Criterion) { + // spec §6 P1 目标:10 万根 K 线 ≤ 200 ms + const N: usize = 100_000; + let bars = generate_bars(N); + + let mut group = c.benchmark_group("czsc_analyze"); + group.sample_size(20); // 大样本下 20 已经足够稳定 + + group.bench_function(format!("CZSC::new(bars={N}, max_bi_num=50)"), |b| { + b.iter_batched( + || bars.clone(), + |input| { + let c = CZSC::new(black_box(input), 50); + black_box(c) + }, + BatchSize::LargeInput, + ); + }); + + group.finish(); +} + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = bench_czsc_analyze +); +criterion_main!(benches); diff --git a/docs/MIGRATION_NOTES.md b/docs/MIGRATION_NOTES.md index 0dfed3806..20a6c2467 100644 --- a/docs/MIGRATION_NOTES.md +++ b/docs/MIGRATION_NOTES.md @@ -780,6 +780,7 @@ the signal-output level on data ranging from 500 bars to 40k bars. | `czsc/_native.pyi` 自动生成(spec §2.4 / Q4) | [crates/czsc-python/Cargo.toml](../crates/czsc-python/Cargo.toml)、[crates/czsc-python/src/lib.rs](../crates/czsc-python/src/lib.rs)、[crates/czsc-python/src/bin/stub_gen.rs](../crates/czsc-python/src/bin/stub_gen.rs)、[czsc/_native.pyi](../czsc/_native.pyi) | 各 PyO3 业务 crate 上的 `gen_stub_pyclass` / `gen_stub_pyfunction` / `gen_stub_pymethods` 装饰器早已布到位,但缺 stub 收集器 + 生成 binary。本次:① `czsc-python` 拆出 `extension-module` 为可选 default feature(cdylib 走默认;binary 用 `--no-default-features` 让 pyo3 自动链接 libpython,否则 macOS 链接器找不到 `_PyExc_*` 等符号);② lib.rs 写了一个自定义 `stub_info()` 函数,把 `from_pyproject_toml` 路径显式指向 workspace 根(默认宏 `define_stub_info_gatherer!` 假设 `pyproject.toml` 与 `Cargo.toml` 同目录,但本仓库 pyproject 在 workspace 根、Cargo 在 `crates/czsc-python/`);③ `src/bin/stub_gen.rs` 是最小入口,调用 `stub_info()?.generate()?`;④ 触发:`PYO3_PYTHON=$(uv run python -c 'import sys; print(sys.executable)') cargo run --bin stub_gen -p czsc-python --no-default-features`;⑤ 产物:`czsc/_native.pyi`,1 235 行,覆盖 BI / CZSC / FX / ZS / BarGenerator / Position / Signal / Event / RawBar / NewBar / FakeBI / Direction / Mark / Operate / Freq / ParsedSignalDoc 等核心类与 30+ TA 算子 / 信号函数 / `chip_distribution_triangle` / `parse_signal_doc` 等顶层函数。`pyproject.toml::tool.maturin.include = ["czsc/**/*.pyi", ...]` 已经覆盖此文件,wheel 打包自动带上。**残留**:basedpyright 在 `_native.pyi` 上报 8 个 upstream pyo3-stub-gen 已知问题(`__eq__` 参数类型不兼容父类、`__dict__` 与 `dict[str, Any]` 不兼容),属于工具层 false-positive,不影响功能 | | `_native.pyi` 漂移检查 CI job | [.github/workflows/code-quality.yml](../.github/workflows/code-quality.yml) | `code-quality.yml` 新增 `stub-drift` job(依赖 `rust-tests`),在 CI 中:① checkout + 装 Rust + Python 3.11;② 跑 `PYO3_PYTHON=$(which python3) cargo run --bin stub_gen -p czsc-python --no-default-features`;③ `git diff --exit-code czsc/_native.pyi` 断言无漂移,否则失败并把本地重新生成命令打到日志里。本地已验证 stub_gen 重跑两次产物一致(idempotent)。这一步覆盖两个回归方向:(a)改了 `gen_stub_*` 装饰器但忘重跑 → CI 红灯阻拦;(b)手改了 stub 但 Rust 没改 → 下次 CI 复跑时把手改盖掉、提示提交方处理 | | `czsc/sensors/` 部分恢复(spec §9) | [czsc/sensors/utils.py](../czsc/sensors/utils.py)、[czsc/sensors/utils.pyi](../czsc/sensors/utils.pyi)、[czsc/sensors/__init__.py](../czsc/sensors/__init__.py) | 之前 sensors 仅有 15 行占位 `__init__.py`,与 spec §9 "完整保留 3 文件 301 行"差距明显。本次:① 从 git 历史 `79bdf5e:czsc/sensors/utils.py` 恢复 `utils.py`(121 行,含 `holds_concepts_effect` / `turn_over_rate` / `max_draw_down`,纯 numpy / pandas 实现,无内部 czsc 依赖);② 同步恢复 `utils.pyi`;③ 重写 `__init__.py`:暴露 3 个 utility 函数 + 添加 `CTAResearch` **占位类**(`__init__` 直接抛 `NotImplementedError`,明确指出历史实现依赖已删 `czsc.traders.dummy.DummyBacktest`、引导用户改用 `czsc.run_replay` / `wbt.WeightBacktest` 组合,并指向本文档)。`from czsc.sensors import CTAResearch, holds_concepts_effect, ...` 全部正常;`CTAResearch()` 调用即抛 `NotImplementedError`(fail-fast,避免在调用半截才报"找不到 DummyBacktest")。spec §9 "完整保留" 项剩余的就只是 `cta.py` 真实迁移(需先在 Rust 端 `czsc-trader` 落地等价 dummy/replay)。| +| `czsc-core` criterion 性能基准(spec §6 P1) | [crates/czsc-core/Cargo.toml](../crates/czsc-core/Cargo.toml)、[crates/czsc-core/benches/czsc_analyze_bench.rs](../crates/czsc-core/benches/czsc_analyze_bench.rs) | 添加 `criterion 0.7` 到 `[dev-dependencies]` + `[[bench]] name = "czsc_analyze_bench" harness = false` 配置;新建 `benches/czsc_analyze_bench.rs` 用慢周期正弦+快周期抖动+渐进漂移生成 10 万根 30 分钟模拟 K 线(保证不会单调推高/降低让缠论分析路径退化),用 `criterion::iter_batched(BatchSize::LargeInput)` 测 `CZSC::new(bars, max_bi_num=50)`。**M2 Mac、release 构建(lto=true / opt-level=3 / codegen-units=1)下,mean = 96.585 ms(CI: 96.276–96.971 ms,20 样本)**——spec §6 P1 目标 ≤ 200 ms,余量 52%,**P1 达标 ✅**。触发:`cargo bench -p czsc-core` | ### 10.2 故意保留 / 暂缓的项 @@ -807,9 +808,15 @@ $ uv run pytest test/compat/ test/unit/ test/test_envs.py test/test_io.py \ test/test_trade_utils.py test/test_stoploss_by_direction.py -q 124 passed in 2.93s -$ wc -l czsc/__init__.py czsc/traders/sig_parse.py +$ wc -l czsc/__init__.py czsc/traders/sig_parse.py czsc/envs.py 235 czsc/__init__.py # 507 -> 235,spec §3.1 lazy loading 已退役 326 czsc/traders/sig_parse.py # 387 -> 326,_lazy_rs_czsc 工厂已退役 + 49 czsc/envs.py # 117 -> 49,spec §3.4 Python 侧精简 + +$ cargo bench -p czsc-core +czsc_analyze/CZSC::new(bars=100000, max_bi_num=50) + time: [96.276 ms 96.585 ms 96.971 ms] +# spec §6 P1 目标 ≤ 200 ms,达标 ✅ ``` 公共 API 快照(`test/compat/snapshots/api_v1.json`,129 个公共名称)与 pickle roundtrip(5 个 PyO3 类)回归全部 GREEN,证明本轮 P0/P1 改动未破坏 §6 验收基线。 From b55153081f3bc09170652d262110a46be90e2772 Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Thu, 7 May 2026 17:39:05 +0800 Subject: [PATCH 13/23] =?UTF-8?q?perf(bench):=20add=20222-signal=20dispatc?= =?UTF-8?q?h=20benchmark=20(spec=20=C2=A76=20P2=20=3D=201.1=C2=B5s/signal,?= =?UTF-8?q?=204.7ms/10K)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 闭合 spec §6 P2 验收: - 30+ 信号函数批量执行单根 K 线 P50 ≤ 50 µs / 信号 - 批量 1 万根 ≤ 80 ms 新增 crates/czsc-signals/benches/signals_bench.rs: - 复用 P1 同款慢周期正弦 + 快周期抖动 + 渐进漂移 K 线生成器(独立 copy 避免 跨 crate dev-dep) - 构造 100 / 10 000 两个 size 的 CZSC,循环调用 SIGNAL_REGISTRY 中全部 222 个 K 线信号各一次(spec 说"30+"是当时的下限估计,实际仓库内已注册 222 个) - 用 black_box 阻止死代码消除 实测(M2 Mac,release): signals_dispatch/dispatch_all(222 signals, bars=100) time: [242.86 µs 244.20 µs 245.52 µs] signals_dispatch/dispatch_all(222 signals, bars=10000) time: [4.6427 ms 4.6820 ms 4.7254 ms] - 244.2 µs / 222 signals ≈ 1.1 µs/signal vs spec ≤ 50 µs ⇒ 余 45×, P2 (单根) ✅ - 4.682 ms vs spec ≤ 80 ms ⇒ 余 17×, P2 (批量 1 万根) ✅ Cargo.toml 加 criterion = "0.7" 到 [dev-dependencies] + [[bench]] name = "signals_bench"。 触发命令:cargo bench -p czsc-signals --- Cargo.lock | 1 + crates/czsc-signals/Cargo.toml | 9 ++ crates/czsc-signals/benches/signals_bench.rs | 114 +++++++++++++++++++ docs/MIGRATION_NOTES.md | 7 ++ 4 files changed, 131 insertions(+) create mode 100644 crates/czsc-signals/benches/signals_bench.rs diff --git a/Cargo.lock b/Cargo.lock index 1783f615c..02d1ac7ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -596,6 +596,7 @@ version = "1.0.0" dependencies = [ "anyhow", "chrono", + "criterion", "czsc-core", "czsc-signal-macros", "czsc-ta", diff --git a/crates/czsc-signals/Cargo.toml b/crates/czsc-signals/Cargo.toml index 9ede996e9..b506c5944 100644 --- a/crates/czsc-signals/Cargo.toml +++ b/crates/czsc-signals/Cargo.toml @@ -20,3 +20,12 @@ anyhow = "1.0" tracing = "0.1" inventory = "0.3" chrono = { version = "0.4", default-features = false, features = ["clock"] } + +[dev-dependencies] +criterion = "0.7" + +# spec §6 P2 性能基准:30+ 信号函数批量执行 +# 单根 K 线 P50 ≤ 50 µs;批量 1 万根 ≤ 80 ms +[[bench]] +name = "signals_bench" +harness = false diff --git a/crates/czsc-signals/benches/signals_bench.rs b/crates/czsc-signals/benches/signals_bench.rs new file mode 100644 index 000000000..bd3c0498f --- /dev/null +++ b/crates/czsc-signals/benches/signals_bench.rs @@ -0,0 +1,114 @@ +//! 信号函数性能基准(spec §6 P2)。 +//! +//! 验收目标: +//! - 单根 K 线分派全部 K 线信号(30+ 信号),P50 ≤ 50 µs(per-signal 单次调用) +//! - 批量 1 万根 K 线下分派全部 K 线信号一次,总耗时 ≤ 80 ms +//! +//! 这里把"全部 K 线信号在同一 CZSC 实例上各调用一次"的总耗时记录下来; +//! 可由总耗时除以 SIGNAL_REGISTRY 大小得到 per-signal P50 估计值。 +//! +//! 触发:cargo bench -p czsc-signals + +use std::collections::HashMap; +use std::hint::black_box; +use std::sync::Arc; + +use chrono::{TimeZone, Utc}; +use criterion::{Criterion, criterion_group, criterion_main}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::bar::{RawBar, RawBarBuilder}; +use czsc_core::objects::freq::Freq; +use czsc_signals::registry::SIGNAL_REGISTRY; +use czsc_signals::types::TaCache; +use serde_json::Value; + +/// 与 czsc-core/benches/czsc_analyze_bench.rs 同款 K 线生成器(独立 copy 避免 +/// 跨 crate dev-dep 依赖;spec §4.2 测试隔离原则)。 +fn generate_bars(count: usize) -> Vec { + let symbol: Arc = Arc::from("000001.SH"); + let base_ts = 1704067200i64; + + (0..count) + .map(|i| { + let slow = (i as f64 * 0.001).sin() * 30.0; + let fast = (i as f64 * 0.07).sin() * 4.0; + let drift = i as f64 * 0.0002; + let close = 100.0 + slow + fast + drift; + let open = close - fast * 0.5; + let high = close.max(open) + 0.6; + let low = close.min(open) - 0.6; + let dt = Utc.timestamp_opt(base_ts + i as i64 * 1800, 0).unwrap(); + + RawBarBuilder::default() + .symbol(symbol.clone()) + .id(i as i32) + .dt(dt) + .freq(Freq::F30) + .open(open) + .close(close) + .high(high) + .low(low) + .vol(1_000_000.0) + .amount(close * 1_000_000.0) + .build() + .expect("RawBar 构造失败") + }) + .collect() +} + +/// 在给定 CZSC 实例上把 ``SIGNAL_REGISTRY`` 中所有 kline 信号各调用一次, +/// 返回成功调用次数(用于 black_box 防止整体被优化掉)。 +fn dispatch_all_signals(czsc: &CZSC) -> usize { + let empty_params: HashMap = HashMap::new(); + let mut cache = TaCache::default(); + let mut count = 0usize; + for (_name, meta) in SIGNAL_REGISTRY.iter() { + let signals = (meta.func)(czsc, &empty_params, &mut cache); + // black_box 让 LLVM 不能把 signals 作为死代码消除 + let _ = black_box(signals); + count += 1; + } + count +} + +fn bench_signals_dispatch(c: &mut Criterion) { + // ① 单根代表 K 线下分派全部信号——把 100 根 bar 的 CZSC 视为典型回放截面 + let small_bars = generate_bars(100); + let small_czsc = CZSC::new(small_bars, 50); + let signals_count = SIGNAL_REGISTRY.len(); + + let mut group = c.benchmark_group("signals_dispatch"); + group.sample_size(20); + + group.bench_function( + format!("dispatch_all({signals_count} signals, bars=100)"), + |b| { + b.iter(|| { + let n = dispatch_all_signals(black_box(&small_czsc)); + black_box(n) + }); + }, + ); + + // ② 1 万根 K 线场景——目标整体 ≤ 80 ms + let large_bars = generate_bars(10_000); + let large_czsc = CZSC::new(large_bars, 50); + group.bench_function( + format!("dispatch_all({signals_count} signals, bars=10000)"), + |b| { + b.iter(|| { + let n = dispatch_all_signals(black_box(&large_czsc)); + black_box(n) + }); + }, + ); + + group.finish(); +} + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = bench_signals_dispatch +); +criterion_main!(benches); diff --git a/docs/MIGRATION_NOTES.md b/docs/MIGRATION_NOTES.md index 20a6c2467..95a7d8add 100644 --- a/docs/MIGRATION_NOTES.md +++ b/docs/MIGRATION_NOTES.md @@ -781,6 +781,7 @@ the signal-output level on data ranging from 500 bars to 40k bars. | `_native.pyi` 漂移检查 CI job | [.github/workflows/code-quality.yml](../.github/workflows/code-quality.yml) | `code-quality.yml` 新增 `stub-drift` job(依赖 `rust-tests`),在 CI 中:① checkout + 装 Rust + Python 3.11;② 跑 `PYO3_PYTHON=$(which python3) cargo run --bin stub_gen -p czsc-python --no-default-features`;③ `git diff --exit-code czsc/_native.pyi` 断言无漂移,否则失败并把本地重新生成命令打到日志里。本地已验证 stub_gen 重跑两次产物一致(idempotent)。这一步覆盖两个回归方向:(a)改了 `gen_stub_*` 装饰器但忘重跑 → CI 红灯阻拦;(b)手改了 stub 但 Rust 没改 → 下次 CI 复跑时把手改盖掉、提示提交方处理 | | `czsc/sensors/` 部分恢复(spec §9) | [czsc/sensors/utils.py](../czsc/sensors/utils.py)、[czsc/sensors/utils.pyi](../czsc/sensors/utils.pyi)、[czsc/sensors/__init__.py](../czsc/sensors/__init__.py) | 之前 sensors 仅有 15 行占位 `__init__.py`,与 spec §9 "完整保留 3 文件 301 行"差距明显。本次:① 从 git 历史 `79bdf5e:czsc/sensors/utils.py` 恢复 `utils.py`(121 行,含 `holds_concepts_effect` / `turn_over_rate` / `max_draw_down`,纯 numpy / pandas 实现,无内部 czsc 依赖);② 同步恢复 `utils.pyi`;③ 重写 `__init__.py`:暴露 3 个 utility 函数 + 添加 `CTAResearch` **占位类**(`__init__` 直接抛 `NotImplementedError`,明确指出历史实现依赖已删 `czsc.traders.dummy.DummyBacktest`、引导用户改用 `czsc.run_replay` / `wbt.WeightBacktest` 组合,并指向本文档)。`from czsc.sensors import CTAResearch, holds_concepts_effect, ...` 全部正常;`CTAResearch()` 调用即抛 `NotImplementedError`(fail-fast,避免在调用半截才报"找不到 DummyBacktest")。spec §9 "完整保留" 项剩余的就只是 `cta.py` 真实迁移(需先在 Rust 端 `czsc-trader` 落地等价 dummy/replay)。| | `czsc-core` criterion 性能基准(spec §6 P1) | [crates/czsc-core/Cargo.toml](../crates/czsc-core/Cargo.toml)、[crates/czsc-core/benches/czsc_analyze_bench.rs](../crates/czsc-core/benches/czsc_analyze_bench.rs) | 添加 `criterion 0.7` 到 `[dev-dependencies]` + `[[bench]] name = "czsc_analyze_bench" harness = false` 配置;新建 `benches/czsc_analyze_bench.rs` 用慢周期正弦+快周期抖动+渐进漂移生成 10 万根 30 分钟模拟 K 线(保证不会单调推高/降低让缠论分析路径退化),用 `criterion::iter_batched(BatchSize::LargeInput)` 测 `CZSC::new(bars, max_bi_num=50)`。**M2 Mac、release 构建(lto=true / opt-level=3 / codegen-units=1)下,mean = 96.585 ms(CI: 96.276–96.971 ms,20 样本)**——spec §6 P1 目标 ≤ 200 ms,余量 52%,**P1 达标 ✅**。触发:`cargo bench -p czsc-core` | +| `czsc-signals` criterion 性能基准(spec §6 P2) | [crates/czsc-signals/Cargo.toml](../crates/czsc-signals/Cargo.toml)、[crates/czsc-signals/benches/signals_bench.rs](../crates/czsc-signals/benches/signals_bench.rs) | 添加 `criterion 0.7` 到 `[dev-dependencies]` + `[[bench]] name = "signals_bench" harness = false`;新建 `benches/signals_bench.rs`,复用 P1 同款 K 线生成器(独立 copy 避免跨 crate dev-dep);构造 100 / 10 000 两个 size 的 CZSC,循环调用 `SIGNAL_REGISTRY` 中**全部 222 个 K 线信号**各一次(注:spec 说"30+"是当时的下限估计,实际仓库内已注册 222 个),用 `black_box` 阻止死代码消除。**M2 Mac、release 下:dispatch_all(222 signals, bars=100) = 244.2 µs(CI 242.86–245.52, 20 样本,约每信号 1.1 µs);dispatch_all(222 signals, bars=10000) = 4.682 ms(CI 4.6427–4.7254)**。spec §6 P2 目标:单根 K 线 P50 ≤ 50 µs / 信号(实测 1.1 µs,余 45×)、批量 1 万根 ≤ 80 ms(实测 4.7 ms,余 17×),**P2 全维度达标 ✅**。触发:`cargo bench -p czsc-signals` | ### 10.2 故意保留 / 暂缓的项 @@ -817,6 +818,12 @@ $ cargo bench -p czsc-core czsc_analyze/CZSC::new(bars=100000, max_bi_num=50) time: [96.276 ms 96.585 ms 96.971 ms] # spec §6 P1 目标 ≤ 200 ms,达标 ✅ + +$ cargo bench -p czsc-signals +signals_dispatch/dispatch_all(222 signals, bars=100) + time: [242.86 µs 244.20 µs 245.52 µs] # 每信号 ~1.1 µs (spec ≤ 50 µs,余 45×) ✅ +signals_dispatch/dispatch_all(222 signals, bars=10000) + time: [4.6427 ms 4.6820 ms 4.7254 ms] # spec ≤ 80 ms,余 17× ✅ ``` 公共 API 快照(`test/compat/snapshots/api_v1.json`,129 个公共名称)与 pickle roundtrip(5 个 PyO3 类)回归全部 GREEN,证明本轮 P0/P1 改动未破坏 §6 验收基线。 From 1262dac8d83ae6a3fd2b178c7818e2f3419c15ac Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Thu, 7 May 2026 22:04:22 +0800 Subject: [PATCH 14/23] fix(tests): exclude additional fields from meta comparison in run_research tests --- test/parity/test_run_research.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/parity/test_run_research.py b/test/parity/test_run_research.py index 1fd12054b..d18c04c6d 100644 --- a/test/parity/test_run_research.py +++ b/test/parity/test_run_research.py @@ -169,8 +169,10 @@ def test_run_research_meta_match(rs_czsc_module, czsc_module, parity_inputs): rs_payload = rs_czsc_module._rs_czsc.run_research(arrow_bytes, strategy_json, None, None) czsc_payload = czsc_module._native.run_research(arrow_bytes, strategy_json, None, None) - # 这些字段在不同构建间合理变化,比较时需要排除 - drop_keys = {"build_ts", "git_hash", "engine_version"} + # 这些字段在不同构建/运行间合理变化,比较时需要排除: + # build_ts / git_hash / engine_version —— 构建环境差异 + # elapsed_ms —— 同一进程两次调用的毫秒级耗时抖动,与算法/数据无关 + drop_keys = {"build_ts", "git_hash", "engine_version", "elapsed_ms"} rs_meta = {k: v for k, v in rs_payload["meta"].items() if k not in drop_keys} czsc_meta = {k: v for k, v in czsc_payload["meta"].items() if k not in drop_keys} From 89075696f8dd8b77a2a9d499d05bc020e6428766 Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Thu, 7 May 2026 23:58:59 +0800 Subject: [PATCH 15/23] fix(ci): isolate stub_gen bin via required-features to fix linker error --- .github/workflows/code-quality.yml | 6 ++-- crates/czsc-python/Cargo.toml | 39 ++++++++++++++++++-------- crates/czsc-python/src/bin/stub_gen.rs | 7 ++++- 3 files changed, 38 insertions(+), 14 deletions(-) diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index 8f7f97c8b..114b567b6 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -79,7 +79,8 @@ jobs: - name: Run stub_gen run: | export PYO3_PYTHON=$(which python3) - cargo run --bin stub_gen -p czsc-python --no-default-features + cargo run --bin stub_gen -p czsc-python \ + --no-default-features --features stub-gen - name: Assert czsc/_native.pyi is in sync with Rust source run: | @@ -87,7 +88,8 @@ jobs: echo "::error::czsc/_native.pyi 与 Rust 装饰器不一致 ——" echo "请在本地运行:" echo " PYO3_PYTHON=\$(uv run python -c 'import sys; print(sys.executable)') \\" - echo " cargo run --bin stub_gen -p czsc-python --no-default-features" + echo " cargo run --bin stub_gen -p czsc-python \\" + echo " --no-default-features --features stub-gen" echo "并把更新后的 czsc/_native.pyi 一并 commit。" exit 1 fi diff --git a/crates/czsc-python/Cargo.toml b/crates/czsc-python/Cargo.toml index 8a734b7c3..daeb93b7f 100644 --- a/crates/czsc-python/Cargo.toml +++ b/crates/czsc-python/Cargo.toml @@ -11,16 +11,32 @@ name = "czsc_python" path = "src/lib.rs" crate-type = ["cdylib", "rlib"] -# 默认启用 extension-module(cdylib 路径); -# 编译 stub_gen 这种独立 binary 时需要 --no-default-features, -# 让 pyo3 自动链接 libpython(详见 src/bin/stub_gen.rs)。 +# 默认启用 extension-module(cdylib 路径)。 +# +# 同一个 package 里既有 cdylib(wheel)又有 bin(stub_gen),feature 集会被 +# 统一,单一 feature 集无法同时满足两者:wheel 要 extension-module + +# abi3-py310(不链接 libpython),bin 要裸 pyo3(链接具体版本 libpython)。 +# 所以把 stub_gen 用 required-features 隔离到独立 feature `stub-gen`: +# - `cargo build --workspace --release` → 默认只构 wheel,跳过 bin +# - `cargo run --bin stub_gen -p czsc-python \ +# --no-default-features --features stub-gen` +# → 关 extension-module/abi3, +# 走非 abi3 链接路径, +# PYO3_PYTHON 解析具体 +# libpython3.X.so +# +# abi3-py310 也挂在 extension-module 后面:CI 上 setup-python 提供的是 +# libpython3.X.so,没有 abi3 模式期望的 libpython3.so 符号链接,bin 路径 +# 必须避开 abi3 才能让链接器找到 PyGILState_Release 这类符号。 [features] default = ["extension-module"] -extension-module = ["pyo3/extension-module"] +extension-module = ["pyo3/extension-module", "pyo3/abi3-py310"] +stub-gen = [] [[bin]] -name = "stub_gen" -path = "src/bin/stub_gen.rs" +name = "stub_gen" +path = "src/bin/stub_gen.rs" +required-features = ["stub-gen"] [dependencies] czsc-core = { path = "../czsc-core", features = ["python"] } @@ -36,11 +52,12 @@ inventory = "0.3" md5 = "0.8" numpy = { workspace = true } polars = { workspace = true } -# czsc-python is the only crate that opts into pyo3/extension-module + -# abi3-py310. Business crates pull in the bare workspace pyo3 so -# `cargo test --workspace` can link against a real libpython resolved -# via PYO3_PYTHON. -pyo3 = { workspace = true, features = ["abi3-py310", "chrono"] } +# czsc-python 是唯一启用 pyo3/extension-module 的 crate。 +# abi3-py310 通过上面的 [features] 段挂在 extension-module 后面,仅 wheel +# 构建路径启用;stub_gen --no-default-features 时走具体版本 libpython 链接。 +# 业务 crate 都引用 workspace 裸 pyo3,`cargo test --workspace` 通过 +# PYO3_PYTHON 解析真实 libpython。 +pyo3 = { workspace = true, features = ["chrono"] } pyo3-stub-gen = "0.12" rust_xlsxwriter = "0.79" serde = { workspace = true } diff --git a/crates/czsc-python/src/bin/stub_gen.rs b/crates/czsc-python/src/bin/stub_gen.rs index c98027b16..72be8287f 100644 --- a/crates/czsc-python/src/bin/stub_gen.rs +++ b/crates/czsc-python/src/bin/stub_gen.rs @@ -6,7 +6,12 @@ //! //! 触发方式: //! PYO3_PYTHON=$(uv run python -c 'import sys; print(sys.executable)') \ -//! cargo run --bin stub_gen -p czsc-python +//! cargo run --bin stub_gen -p czsc-python \ +//! --no-default-features --features stub-gen +//! +//! `--features stub-gen` 通过 [[bin]] required-features 启用本二进制, +//! `--no-default-features` 关闭 extension-module 和 abi3-py310,让 pyo3 +//! 走非 abi3 链接路径(`-lpython3.X`)。两者必须配对出现。 //! //! 输出路径由 `pyproject.toml` 里 `[tool.maturin].module-name = "czsc._native"` //! 推导:写到 `czsc/_native.pyi`。 From 1a27fa202bdfed5a6e62339e26f32148c93f8531 Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Thu, 7 May 2026 23:58:59 +0800 Subject: [PATCH 16/23] chore(claude): copy skills/ from rs_czsc --- .../czsc-write-signal-function/SKILL.md | 104 ++++++ .../agents/openai.yaml | 4 + .../event-driven-signal-pipeline.md | 82 +++++ .../references/signal-function-patterns.md | 141 ++++++++ .../scripts/new_signal_stub.py | 142 ++++++++ .claude/skills/signal-functions/SKILL.md | 117 +++++++ .../references/signals-ang.md | 17 + .../references/signals-bar.md | 53 +++ .../references/signals-byi.md | 12 + .../references/signals-cat.md | 9 + .../references/signals-clv.md | 8 + .../references/signals-coo.md | 12 + .../references/signals-cvolp.md | 8 + .../references/signals-cxt.md | 48 +++ .../references/signals-cxt_trader.md | 9 + .../references/signals-jcc.md | 26 ++ .../references/signals-kcatr.md | 8 + .../references/signals-ntmdk.md | 8 + .../references/signals-obv.md | 9 + .../references/signals-pos.md | 23 ++ .../references/signals-pressure.md | 11 + .../references/signals-tas.md | 66 ++++ .../references/signals-vol.md | 13 + .../signal-functions/references/signals-xl.md | 14 + .../references/signals-zdy_trader.md | 11 + .../signals/adtm_up_dw_line_V230603.md | 30 ++ .../signals/amv_up_dw_line_V230603.md | 28 ++ .../signals/asi_up_dw_line_V230603.md | 28 ++ .../signals/bar_accelerate_V221110.md | 27 ++ .../signals/bar_accelerate_V221118.md | 29 ++ .../signals/bar_accelerate_V240428.md | 28 ++ .../signals/bar_amount_acc_V230214.md | 28 ++ .../signals/bar_big_solid_V230215.md | 28 ++ .../references/signals/bar_bpm_V230227.md | 28 ++ .../references/signals/bar_break_V240428.md | 27 ++ .../references/signals/bar_channel_V230508.md | 28 ++ .../signals/bar_classify_V240606.md | 26 ++ .../signals/bar_classify_V240607.md | 26 ++ .../signals/bar_decision_V240608.md | 28 ++ .../signals/bar_decision_V240616.md | 27 ++ .../signals/bar_dual_thrust_V230403.md | 29 ++ .../references/signals/bar_eight_V230702.md | 27 ++ .../references/signals/bar_end_V221211.md | 26 ++ .../signals/bar_fake_break_V230204.md | 28 ++ .../signals/bar_fang_liang_break_V221216.md | 29 ++ .../signals/bar_limit_down_V230525.md | 27 ++ .../signals/bar_mean_amount_V221112.md | 29 ++ .../signals/bar_operate_span_V221111.md | 26 ++ .../references/signals/bar_plr_V240427.md | 29 ++ .../references/signals/bar_polyfit_V240428.md | 27 ++ .../signals/bar_r_breaker_V230326.md | 26 ++ .../signals/bar_reversal_V230227.md | 27 ++ .../signals/bar_section_momentum_V221112.md | 29 ++ .../signals/bar_shuang_fei_V230507.md | 26 ++ .../references/signals/bar_single_V230214.md | 28 ++ .../references/signals/bar_single_V230506.md | 24 ++ .../references/signals/bar_td9_V240616.md | 26 ++ .../references/signals/bar_time_V230327.md | 26 ++ .../references/signals/bar_tnr_V230629.md | 27 ++ .../references/signals/bar_tnr_V230630.md | 28 ++ .../references/signals/bar_trend_V240209.md | 27 ++ .../references/signals/bar_triple_V230506.md | 25 ++ .../references/signals/bar_vol_bs1_V230224.md | 27 ++ .../signals/bar_vol_grow_V221112.md | 27 ++ .../signals/bar_volatility_V241013.md | 27 ++ .../references/signals/bar_weekday_V230328.md | 25 ++ .../signals/bar_window_ps_V230731.md | 29 ++ .../signals/bar_window_ps_V230801.md | 27 ++ .../signals/bar_window_std_V230731.md | 29 ++ .../references/signals/bar_zdf_V221203.md | 28 ++ .../references/signals/bar_zdt_V230331.md | 24 ++ .../references/signals/bar_zfzd_V241013.md | 26 ++ .../references/signals/bar_zfzd_V241014.md | 26 ++ .../signals/bar_zt_count_V230504.md | 28 ++ .../signals/bias_up_dw_line_V230618.md | 28 ++ .../references/signals/byi_bi_end_V230106.md | 26 ++ .../references/signals/byi_bi_end_V230107.md | 26 ++ .../references/signals/byi_fx_num_V230628.md | 27 ++ .../signals/byi_second_bs_V230324.md | 27 ++ .../signals/byi_symmetry_zs_V221107.md | 27 ++ .../references/signals/cat_macd_V230518.md | 27 ++ .../references/signals/cat_macd_V230520.md | 27 ++ .../signals/cci_decision_V240620.md | 26 ++ .../signals/clv_up_dw_line_V230605.md | 27 ++ .../signals/cmo_up_dw_line_V230605.md | 28 ++ .../references/signals/coo_cci_V230323.md | 29 ++ .../references/signals/coo_kdj_V230322.md | 29 ++ .../references/signals/coo_sar_V230325.md | 27 ++ .../references/signals/coo_td_V221110.md | 26 ++ .../references/signals/coo_td_V221111.md | 26 ++ .../signals/cvolp_up_dw_line_V230612.md | 31 ++ .../references/signals/cxt_bi_base_V230228.md | 28 ++ .../references/signals/cxt_bi_end_V230104.md | 28 ++ .../references/signals/cxt_bi_end_V230105.md | 28 ++ .../references/signals/cxt_bi_end_V230222.md | 27 ++ .../references/signals/cxt_bi_end_V230224.md | 27 ++ .../references/signals/cxt_bi_end_V230312.md | 28 ++ .../references/signals/cxt_bi_end_V230320.md | 27 ++ .../references/signals/cxt_bi_end_V230322.md | 27 ++ .../references/signals/cxt_bi_end_V230324.md | 27 ++ .../references/signals/cxt_bi_end_V230618.md | 27 ++ .../references/signals/cxt_bi_end_V230815.md | 27 ++ .../signals/cxt_bi_status_V230101.md | 27 ++ .../signals/cxt_bi_status_V230102.md | 28 ++ .../references/signals/cxt_bi_stop_V230815.md | 27 ++ .../signals/cxt_bi_trend_V230824.md | 28 ++ .../signals/cxt_bi_trend_V230913.md | 27 ++ .../references/signals/cxt_bi_zdf_V230601.md | 27 ++ .../references/signals/cxt_bs_V240526.md | 27 ++ .../references/signals/cxt_bs_V240527.md | 27 ++ .../signals/cxt_decision_V240526.md | 27 ++ .../signals/cxt_decision_V240612.md | 27 ++ .../signals/cxt_decision_V240613.md | 27 ++ .../signals/cxt_decision_V240614.md | 27 ++ .../signals/cxt_double_zs_V230311.md | 27 ++ .../signals/cxt_eleven_bi_V230622.md | 27 ++ .../signals/cxt_first_buy_V221126.md | 27 ++ .../signals/cxt_first_sell_V221126.md | 27 ++ .../references/signals/cxt_five_bi_V230619.md | 27 ++ .../signals/cxt_fx_power_V221107.md | 27 ++ .../signals/cxt_intraday_V230701.md | 13 + .../references/signals/cxt_nine_bi_V230621.md | 27 ++ .../references/signals/cxt_overlap_V240526.md | 27 ++ .../references/signals/cxt_overlap_V240612.md | 27 ++ .../signals/cxt_range_oscillation_V230620.md | 27 ++ .../signals/cxt_second_bs_V230320.md | 28 ++ .../signals/cxt_second_bs_V240524.md | 28 ++ .../signals/cxt_seven_bi_V230620.md | 27 ++ .../signals/cxt_third_bs_V230318.md | 28 ++ .../signals/cxt_third_bs_V230319.md | 28 ++ .../signals/cxt_third_buy_V230228.md | 27 ++ .../signals/cxt_three_bi_V230618.md | 27 ++ .../references/signals/cxt_ubi_end_V230816.md | 27 ++ .../cxt_zhong_shu_gong_zhen_V221221.md | 13 + .../signals/dema_up_dw_line_V230605.md | 26 ++ .../signals/demakder_up_dw_line_V230605.md | 28 ++ .../signals/emv_up_dw_line_V230605.md | 26 ++ .../signals/er_up_dw_line_V230604.md | 29 ++ .../references/signals/jcc_ci_tou_V221101.md | 28 ++ .../signals/jcc_fan_ji_xian_V221121.md | 26 ++ .../signals/jcc_fen_shou_xian_V20221113.md | 27 ++ .../signals/jcc_gap_yin_yang_V221121.md | 26 ++ .../signals/jcc_ping_tou_V221113.md | 27 ++ .../signals/jcc_san_fa_V20221115.md | 27 ++ .../signals/jcc_san_fa_V20221118.md | 26 ++ .../references/signals/jcc_san_szx_V221122.md | 27 ++ .../signals/jcc_san_xing_xian_V221023.md | 27 ++ .../signals/jcc_shan_chun_V221121.md | 26 ++ .../references/signals/jcc_szx_V221111.md | 27 ++ .../references/signals/jcc_ta_xing_V221124.md | 26 ++ .../references/signals/jcc_ten_mo_V221028.md | 26 ++ .../signals/jcc_three_crow_V221108.md | 26 ++ .../signals/jcc_two_crow_V221108.md | 26 ++ .../signals/jcc_wu_yun_gai_ding_V221101.md | 28 ++ .../signals/jcc_xing_xian_V221118.md | 27 ++ .../signals/jcc_yun_xian_V221118.md | 26 ++ .../signals/jcc_zhu_huo_xian_V221027.md | 28 ++ .../signals/kcatr_up_dw_line_V230823.md | 30 ++ .../references/signals/ntmdk_V230824.md | 27 ++ .../signals/obv_up_dw_line_V230719.md | 30 ++ .../references/signals/obvm_line_V230610.md | 28 ++ .../signals/pos_bar_stop_V230524.md | 23 ++ .../signals/pos_fix_exit_V230624.md | 22 ++ .../references/signals/pos_fx_stop_V230414.md | 24 ++ .../references/signals/pos_holds_V230414.md | 24 ++ .../references/signals/pos_holds_V230807.md | 25 ++ .../references/signals/pos_holds_V240428.md | 25 ++ .../references/signals/pos_holds_V240608.md | 24 ++ .../references/signals/pos_ma_V230414.md | 25 ++ .../signals/pos_profit_loss_V230624.md | 25 ++ .../references/signals/pos_status_V230808.md | 22 ++ .../references/signals/pos_stop_V240331.md | 23 ++ .../references/signals/pos_stop_V240428.md | 25 ++ .../references/signals/pos_stop_V240608.md | 24 ++ .../references/signals/pos_stop_V240614.md | 23 ++ .../references/signals/pos_stop_V240717.md | 25 ++ .../references/signals/pos_take_V240428.md | 24 ++ .../signals/pressure_support_V240222.md | 28 ++ .../signals/pressure_support_V240402.md | 28 ++ .../signals/pressure_support_V240406.md | 28 ++ .../signals/pressure_support_V240530.md | 29 ++ .../signals/skdj_up_dw_line_V230611.md | 30 ++ .../signals/tas_accelerate_V230531.md | 28 ++ .../references/signals/tas_angle_V230802.md | 28 ++ .../references/signals/tas_atr_V230630.md | 27 ++ .../signals/tas_atr_break_V230424.md | 28 ++ .../references/signals/tas_boll_bc_V221118.md | 30 ++ .../references/signals/tas_boll_cc_V230312.md | 29 ++ .../signals/tas_boll_power_V221112.md | 27 ++ .../references/signals/tas_boll_vt_V230212.md | 29 ++ .../signals/tas_cci_base_V230402.md | 29 ++ .../signals/tas_cross_status_V230619.md | 27 ++ .../signals/tas_cross_status_V230624.md | 28 ++ .../signals/tas_cross_status_V230625.md | 29 ++ .../signals/tas_dif_layer_V241010.md | 27 ++ .../signals/tas_dif_zero_V240612.md | 28 ++ .../signals/tas_dif_zero_V240614.md | 29 ++ .../references/signals/tas_dma_bs_V240608.md | 28 ++ .../signals/tas_double_ma_V221203.md | 29 ++ .../signals/tas_double_ma_V230511.md | 31 ++ .../signals/tas_double_ma_V240208.md | 28 ++ .../signals/tas_first_bs_V230217.md | 34 ++ .../references/signals/tas_hlma_V230301.md | 29 ++ .../signals/tas_kdj_base_V221101.md | 28 ++ .../references/signals/tas_kdj_evc_V221201.md | 29 ++ .../references/signals/tas_kdj_evc_V230401.md | 29 ++ .../signals/tas_low_trend_V230627.md | 28 ++ .../references/signals/tas_ma_base_V221101.md | 28 ++ .../references/signals/tas_ma_base_V221203.md | 30 ++ .../references/signals/tas_ma_base_V230313.md | 32 ++ .../signals/tas_ma_cohere_V230512.md | 28 ++ .../signals/tas_ma_round_V221206.md | 31 ++ .../signals/tas_ma_system_V230513.md | 27 ++ .../signals/tas_macd_base_V221028.md | 29 ++ .../signals/tas_macd_base_V230320.md | 31 ++ .../references/signals/tas_macd_bc_V221201.md | 30 ++ .../references/signals/tas_macd_bc_V230803.md | 26 ++ .../references/signals/tas_macd_bc_V230804.md | 26 ++ .../references/signals/tas_macd_bc_V240307.md | 27 ++ .../signals/tas_macd_bc_ubi_V230804.md | 26 ++ .../signals/tas_macd_bs1_V230312.md | 26 ++ .../signals/tas_macd_bs1_V230313.md | 27 ++ .../signals/tas_macd_bs1_V230411.md | 29 ++ .../signals/tas_macd_bs1_V230412.md | 28 ++ .../signals/tas_macd_change_V221105.md | 29 ++ .../signals/tas_macd_direct_V221106.md | 28 ++ .../signals/tas_macd_dist_V230408.md | 28 ++ .../signals/tas_macd_dist_V230409.md | 29 ++ .../signals/tas_macd_dist_V230410.md | 28 ++ .../signals/tas_macd_first_bs_V221201.md | 28 ++ .../signals/tas_macd_first_bs_V221216.md | 27 ++ .../signals/tas_macd_power_V221108.md | 29 ++ .../signals/tas_macd_second_bs_V221201.md | 27 ++ .../references/signals/tas_macd_xt_V221208.md | 27 ++ .../signals/tas_rsi_base_V230227.md | 30 ++ .../references/signals/tas_rumi_V230704.md | 29 ++ .../signals/tas_sar_base_V230425.md | 28 ++ .../signals/tas_second_bs_V230228.md | 33 ++ .../signals/tas_second_bs_V230303.md | 33 ++ .../references/signals/tas_slope_V231019.md | 28 ++ .../signals/vol_double_ma_V230214.md | 28 ++ .../references/signals/vol_gao_di_V221218.md | 28 ++ .../signals/vol_single_ma_V230214.md | 28 ++ .../references/signals/vol_ti_suo_V221216.md | 26 ++ .../references/signals/vol_window_V230731.md | 29 ++ .../references/signals/vol_window_V230801.md | 27 ++ .../signals/xl_bar_basis_V240411.md | 26 ++ .../signals/xl_bar_basis_V240412.md | 27 ++ .../signals/xl_bar_position_V240328.md | 26 ++ .../signals/xl_bar_trend_V240329.md | 27 ++ .../signals/xl_bar_trend_V240330.md | 27 ++ .../signals/xl_bar_trend_V240331.md | 25 ++ .../signals/xl_bar_trend_V240623.md | 26 ++ .../signals/zdy_stop_loss_V230406.md | 3 + .../signals/zdy_take_profit_V230406.md | 3 + .../signals/zdy_take_profit_V230407.md | 3 + .../references/signals/zdy_vibrate_V230406.md | 3 + .../scripts/extract_signal_docs.py | 316 ++++++++++++++++++ 258 files changed, 7490 insertions(+) create mode 100644 .claude/skills/czsc-write-signal-function/SKILL.md create mode 100644 .claude/skills/czsc-write-signal-function/agents/openai.yaml create mode 100644 .claude/skills/czsc-write-signal-function/references/event-driven-signal-pipeline.md create mode 100644 .claude/skills/czsc-write-signal-function/references/signal-function-patterns.md create mode 100644 .claude/skills/czsc-write-signal-function/scripts/new_signal_stub.py create mode 100644 .claude/skills/signal-functions/SKILL.md create mode 100644 .claude/skills/signal-functions/references/signals-ang.md create mode 100644 .claude/skills/signal-functions/references/signals-bar.md create mode 100644 .claude/skills/signal-functions/references/signals-byi.md create mode 100644 .claude/skills/signal-functions/references/signals-cat.md create mode 100644 .claude/skills/signal-functions/references/signals-clv.md create mode 100644 .claude/skills/signal-functions/references/signals-coo.md create mode 100644 .claude/skills/signal-functions/references/signals-cvolp.md create mode 100644 .claude/skills/signal-functions/references/signals-cxt.md create mode 100644 .claude/skills/signal-functions/references/signals-cxt_trader.md create mode 100644 .claude/skills/signal-functions/references/signals-jcc.md create mode 100644 .claude/skills/signal-functions/references/signals-kcatr.md create mode 100644 .claude/skills/signal-functions/references/signals-ntmdk.md create mode 100644 .claude/skills/signal-functions/references/signals-obv.md create mode 100644 .claude/skills/signal-functions/references/signals-pos.md create mode 100644 .claude/skills/signal-functions/references/signals-pressure.md create mode 100644 .claude/skills/signal-functions/references/signals-tas.md create mode 100644 .claude/skills/signal-functions/references/signals-vol.md create mode 100644 .claude/skills/signal-functions/references/signals-xl.md create mode 100644 .claude/skills/signal-functions/references/signals-zdy_trader.md create mode 100644 .claude/skills/signal-functions/references/signals/adtm_up_dw_line_V230603.md create mode 100644 .claude/skills/signal-functions/references/signals/amv_up_dw_line_V230603.md create mode 100644 .claude/skills/signal-functions/references/signals/asi_up_dw_line_V230603.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_accelerate_V221110.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_accelerate_V221118.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_accelerate_V240428.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_amount_acc_V230214.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_big_solid_V230215.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_bpm_V230227.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_break_V240428.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_channel_V230508.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_classify_V240606.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_classify_V240607.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_decision_V240608.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_decision_V240616.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_dual_thrust_V230403.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_eight_V230702.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_end_V221211.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_fake_break_V230204.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_fang_liang_break_V221216.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_limit_down_V230525.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_mean_amount_V221112.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_operate_span_V221111.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_plr_V240427.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_polyfit_V240428.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_r_breaker_V230326.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_reversal_V230227.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_section_momentum_V221112.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_shuang_fei_V230507.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_single_V230214.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_single_V230506.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_td9_V240616.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_time_V230327.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_tnr_V230629.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_tnr_V230630.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_trend_V240209.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_triple_V230506.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_vol_bs1_V230224.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_vol_grow_V221112.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_volatility_V241013.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_weekday_V230328.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_window_ps_V230731.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_window_ps_V230801.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_window_std_V230731.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_zdf_V221203.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_zdt_V230331.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_zfzd_V241013.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_zfzd_V241014.md create mode 100644 .claude/skills/signal-functions/references/signals/bar_zt_count_V230504.md create mode 100644 .claude/skills/signal-functions/references/signals/bias_up_dw_line_V230618.md create mode 100644 .claude/skills/signal-functions/references/signals/byi_bi_end_V230106.md create mode 100644 .claude/skills/signal-functions/references/signals/byi_bi_end_V230107.md create mode 100644 .claude/skills/signal-functions/references/signals/byi_fx_num_V230628.md create mode 100644 .claude/skills/signal-functions/references/signals/byi_second_bs_V230324.md create mode 100644 .claude/skills/signal-functions/references/signals/byi_symmetry_zs_V221107.md create mode 100644 .claude/skills/signal-functions/references/signals/cat_macd_V230518.md create mode 100644 .claude/skills/signal-functions/references/signals/cat_macd_V230520.md create mode 100644 .claude/skills/signal-functions/references/signals/cci_decision_V240620.md create mode 100644 .claude/skills/signal-functions/references/signals/clv_up_dw_line_V230605.md create mode 100644 .claude/skills/signal-functions/references/signals/cmo_up_dw_line_V230605.md create mode 100644 .claude/skills/signal-functions/references/signals/coo_cci_V230323.md create mode 100644 .claude/skills/signal-functions/references/signals/coo_kdj_V230322.md create mode 100644 .claude/skills/signal-functions/references/signals/coo_sar_V230325.md create mode 100644 .claude/skills/signal-functions/references/signals/coo_td_V221110.md create mode 100644 .claude/skills/signal-functions/references/signals/coo_td_V221111.md create mode 100644 .claude/skills/signal-functions/references/signals/cvolp_up_dw_line_V230612.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_bi_base_V230228.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_bi_end_V230104.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_bi_end_V230105.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_bi_end_V230222.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_bi_end_V230224.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_bi_end_V230312.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_bi_end_V230320.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_bi_end_V230322.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_bi_end_V230324.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_bi_end_V230618.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_bi_end_V230815.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_bi_status_V230101.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_bi_status_V230102.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_bi_stop_V230815.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_bi_trend_V230824.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_bi_trend_V230913.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_bi_zdf_V230601.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_bs_V240526.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_bs_V240527.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_decision_V240526.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_decision_V240612.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_decision_V240613.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_decision_V240614.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_double_zs_V230311.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_eleven_bi_V230622.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_first_buy_V221126.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_first_sell_V221126.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_five_bi_V230619.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_fx_power_V221107.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_intraday_V230701.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_nine_bi_V230621.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_overlap_V240526.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_overlap_V240612.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_range_oscillation_V230620.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_second_bs_V230320.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_second_bs_V240524.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_seven_bi_V230620.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_third_bs_V230318.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_third_bs_V230319.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_third_buy_V230228.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_three_bi_V230618.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_ubi_end_V230816.md create mode 100644 .claude/skills/signal-functions/references/signals/cxt_zhong_shu_gong_zhen_V221221.md create mode 100644 .claude/skills/signal-functions/references/signals/dema_up_dw_line_V230605.md create mode 100644 .claude/skills/signal-functions/references/signals/demakder_up_dw_line_V230605.md create mode 100644 .claude/skills/signal-functions/references/signals/emv_up_dw_line_V230605.md create mode 100644 .claude/skills/signal-functions/references/signals/er_up_dw_line_V230604.md create mode 100644 .claude/skills/signal-functions/references/signals/jcc_ci_tou_V221101.md create mode 100644 .claude/skills/signal-functions/references/signals/jcc_fan_ji_xian_V221121.md create mode 100644 .claude/skills/signal-functions/references/signals/jcc_fen_shou_xian_V20221113.md create mode 100644 .claude/skills/signal-functions/references/signals/jcc_gap_yin_yang_V221121.md create mode 100644 .claude/skills/signal-functions/references/signals/jcc_ping_tou_V221113.md create mode 100644 .claude/skills/signal-functions/references/signals/jcc_san_fa_V20221115.md create mode 100644 .claude/skills/signal-functions/references/signals/jcc_san_fa_V20221118.md create mode 100644 .claude/skills/signal-functions/references/signals/jcc_san_szx_V221122.md create mode 100644 .claude/skills/signal-functions/references/signals/jcc_san_xing_xian_V221023.md create mode 100644 .claude/skills/signal-functions/references/signals/jcc_shan_chun_V221121.md create mode 100644 .claude/skills/signal-functions/references/signals/jcc_szx_V221111.md create mode 100644 .claude/skills/signal-functions/references/signals/jcc_ta_xing_V221124.md create mode 100644 .claude/skills/signal-functions/references/signals/jcc_ten_mo_V221028.md create mode 100644 .claude/skills/signal-functions/references/signals/jcc_three_crow_V221108.md create mode 100644 .claude/skills/signal-functions/references/signals/jcc_two_crow_V221108.md create mode 100644 .claude/skills/signal-functions/references/signals/jcc_wu_yun_gai_ding_V221101.md create mode 100644 .claude/skills/signal-functions/references/signals/jcc_xing_xian_V221118.md create mode 100644 .claude/skills/signal-functions/references/signals/jcc_yun_xian_V221118.md create mode 100644 .claude/skills/signal-functions/references/signals/jcc_zhu_huo_xian_V221027.md create mode 100644 .claude/skills/signal-functions/references/signals/kcatr_up_dw_line_V230823.md create mode 100644 .claude/skills/signal-functions/references/signals/ntmdk_V230824.md create mode 100644 .claude/skills/signal-functions/references/signals/obv_up_dw_line_V230719.md create mode 100644 .claude/skills/signal-functions/references/signals/obvm_line_V230610.md create mode 100644 .claude/skills/signal-functions/references/signals/pos_bar_stop_V230524.md create mode 100644 .claude/skills/signal-functions/references/signals/pos_fix_exit_V230624.md create mode 100644 .claude/skills/signal-functions/references/signals/pos_fx_stop_V230414.md create mode 100644 .claude/skills/signal-functions/references/signals/pos_holds_V230414.md create mode 100644 .claude/skills/signal-functions/references/signals/pos_holds_V230807.md create mode 100644 .claude/skills/signal-functions/references/signals/pos_holds_V240428.md create mode 100644 .claude/skills/signal-functions/references/signals/pos_holds_V240608.md create mode 100644 .claude/skills/signal-functions/references/signals/pos_ma_V230414.md create mode 100644 .claude/skills/signal-functions/references/signals/pos_profit_loss_V230624.md create mode 100644 .claude/skills/signal-functions/references/signals/pos_status_V230808.md create mode 100644 .claude/skills/signal-functions/references/signals/pos_stop_V240331.md create mode 100644 .claude/skills/signal-functions/references/signals/pos_stop_V240428.md create mode 100644 .claude/skills/signal-functions/references/signals/pos_stop_V240608.md create mode 100644 .claude/skills/signal-functions/references/signals/pos_stop_V240614.md create mode 100644 .claude/skills/signal-functions/references/signals/pos_stop_V240717.md create mode 100644 .claude/skills/signal-functions/references/signals/pos_take_V240428.md create mode 100644 .claude/skills/signal-functions/references/signals/pressure_support_V240222.md create mode 100644 .claude/skills/signal-functions/references/signals/pressure_support_V240402.md create mode 100644 .claude/skills/signal-functions/references/signals/pressure_support_V240406.md create mode 100644 .claude/skills/signal-functions/references/signals/pressure_support_V240530.md create mode 100644 .claude/skills/signal-functions/references/signals/skdj_up_dw_line_V230611.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_accelerate_V230531.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_angle_V230802.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_atr_V230630.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_atr_break_V230424.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_boll_bc_V221118.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_boll_cc_V230312.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_boll_power_V221112.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_boll_vt_V230212.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_cci_base_V230402.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_cross_status_V230619.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_cross_status_V230624.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_cross_status_V230625.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_dif_layer_V241010.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_dif_zero_V240612.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_dif_zero_V240614.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_dma_bs_V240608.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_double_ma_V221203.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_double_ma_V230511.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_double_ma_V240208.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_first_bs_V230217.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_hlma_V230301.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_kdj_base_V221101.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_kdj_evc_V221201.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_kdj_evc_V230401.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_low_trend_V230627.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_ma_base_V221101.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_ma_base_V221203.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_ma_base_V230313.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_ma_cohere_V230512.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_ma_round_V221206.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_ma_system_V230513.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_base_V221028.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_base_V230320.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_bc_V221201.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_bc_V230803.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_bc_V230804.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_bc_V240307.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_bc_ubi_V230804.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_bs1_V230312.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_bs1_V230313.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_bs1_V230411.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_bs1_V230412.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_change_V221105.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_direct_V221106.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_dist_V230408.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_dist_V230409.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_dist_V230410.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_first_bs_V221201.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_first_bs_V221216.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_power_V221108.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_second_bs_V221201.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_macd_xt_V221208.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_rsi_base_V230227.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_rumi_V230704.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_sar_base_V230425.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_second_bs_V230228.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_second_bs_V230303.md create mode 100644 .claude/skills/signal-functions/references/signals/tas_slope_V231019.md create mode 100644 .claude/skills/signal-functions/references/signals/vol_double_ma_V230214.md create mode 100644 .claude/skills/signal-functions/references/signals/vol_gao_di_V221218.md create mode 100644 .claude/skills/signal-functions/references/signals/vol_single_ma_V230214.md create mode 100644 .claude/skills/signal-functions/references/signals/vol_ti_suo_V221216.md create mode 100644 .claude/skills/signal-functions/references/signals/vol_window_V230731.md create mode 100644 .claude/skills/signal-functions/references/signals/vol_window_V230801.md create mode 100644 .claude/skills/signal-functions/references/signals/xl_bar_basis_V240411.md create mode 100644 .claude/skills/signal-functions/references/signals/xl_bar_basis_V240412.md create mode 100644 .claude/skills/signal-functions/references/signals/xl_bar_position_V240328.md create mode 100644 .claude/skills/signal-functions/references/signals/xl_bar_trend_V240329.md create mode 100644 .claude/skills/signal-functions/references/signals/xl_bar_trend_V240330.md create mode 100644 .claude/skills/signal-functions/references/signals/xl_bar_trend_V240331.md create mode 100644 .claude/skills/signal-functions/references/signals/xl_bar_trend_V240623.md create mode 100644 .claude/skills/signal-functions/references/signals/zdy_stop_loss_V230406.md create mode 100644 .claude/skills/signal-functions/references/signals/zdy_take_profit_V230406.md create mode 100644 .claude/skills/signal-functions/references/signals/zdy_take_profit_V230407.md create mode 100644 .claude/skills/signal-functions/references/signals/zdy_vibrate_V230406.md create mode 100644 .claude/skills/signal-functions/scripts/extract_signal_docs.py diff --git a/.claude/skills/czsc-write-signal-function/SKILL.md b/.claude/skills/czsc-write-signal-function/SKILL.md new file mode 100644 index 000000000..c306bcd81 --- /dev/null +++ b/.claude/skills/czsc-write-signal-function/SKILL.md @@ -0,0 +1,104 @@ +--- +name: czsc-write-signal-function +description: 为 rs_czsc 编写或重构 Rust 信号函数(bar/tas/cxt/pos)并接入事件驱动链路(`#[signal]` 自动注册、SignalConfig、Event/Position 匹配)的工作流。遇到“新增信号函数”“修改信号模板”“注册信号函数”“排查信号不触发”时使用。 +--- + +# CZSC Signal Authoring + +## Overview + +用这个技能在 `rs_czsc` 中稳定完成信号函数开发: +1. 判断信号类型(K线级 / Trader级) +2. 编写函数并保持 `Signal` 7段格式 +3. 注册到正确注册表 +4. 用 `SignalConfig` 和测试验证是否能触发 `Event -> Position` + +先读: +- `references/event-driven-signal-pipeline.md` +- `references/signal-function-patterns.md` + +需要快速起草函数时,先运行: +- `scripts/new_signal_stub.py` + +## Decision Tree + +1. 需要访问仓位(`Position`)或策略状态吗? +- 是:写 Trader 级信号(`pos.rs` 风格),用 `#[signal(category = "trader", ...)]` 自动注册。 +- 否:写 K线级信号(`bar/tas/cxt` 风格),用 `#[signal(category = "kline", ...)]` 自动注册。 + +2. 需要技术指标缓存吗? +- 是:使用 `TaCache`,优先走 `update_*_cache`。 +- 否:保留 `_cache` 参数但不使用。 + +3. 需要多个输出信号吗? +- 是:返回 `Vec`,每个信号都保持 7 段格式。 +- 否:返回单元素 `Vec`。 + +## Workflow + +1. 明确输入输出 +- 明确函数输入参数(`params`)和默认值。 +- 明确 `k1/k2/k3` 模板和 `v1/v2/v3` 语义。 +- 明确“数据不足时”返回策略:通常返回 `vec![]` 或 `其他` 信号。 + +2. 实现函数 +- K线级函数签名:`fn(&CZSC, &ParamView, &mut TaCache) -> Vec`。 +- Trader级函数签名:`fn(&dyn TraderState, &ParamView) -> Vec`。 +- 参数解析优先复用工具:`get_usize_param` / `get_str_param`。 +- 必须写详细注释(硬性格式,不可省略小节): + - 标题行:`/// :<一句话功能>` + - `///` + - `/// 参数模板:\`"..."\`` + - `///` + - `/// 信号逻辑:` + - `/// 1. ...` + - `/// 2. ...` + - `/// 3. ...` + - `///` + - `/// 信号列表示例:` + - `/// - Signal('...')` + - `/// - Signal('...')` + - `///` + - `/// 参数说明:` + - `/// - : <含义与默认值>` + - `/// - : <含义与默认值>` + - 与 Python 或历史版本对齐时,必须补一行:`/// 对齐说明:...` +- 关键分支注释:说明为什么这么判定,而不是只写“做了什么”。 +- 用 `format!` 拼出完整字符串:`k1_k2_k3_v1_v2_v3_score`。 +- 用 `Signal::from_str(...)` 或 `parse::()` 做最终构造。 + +3. 接入注册表 +- 在函数上添加 `#[signal(...)]`,由 `inventory` 自动收集。 +- 注册名与函数名保持一一对应(注意 `_V` 与 `_v` 大小写风格)。 +- 在 `#[signal(...)]` 中给出 `template`,确保 `SignalConfig` 反解析和人类阅读一致。 + +4. 接入调用侧 +- K线级:通过 `SignalConfig { name, freq: Some(...), params }` 在 `CzscSignals::update_signals` 中执行。 +- Trader级:通过 `SignalConfig { name, freq: None, params }` 在 `CzscTrader::update` 的 Trader 信号分支执行。 + +5. 验证 +- 至少验证: + - 函数在样本上可运行,不 panic。 + - 输出信号可被 `Signal` 正常解析(7段、score 0~100)。 + - 注册后能被 `SignalConfig.name` 命中。 + - 若用于交易事件,`Event.matches_signals*` 能按预期触发。 + - 注释完整:他人仅看注释可理解参数、边界、输出语义和关键判定。 +- 优先补充/运行相关测试:`signal_compare_tests`、`trader_tests` 或新增针对性测试。 +- 注释格式校验(提交前必须人工检查): + - 每个新/改信号函数都包含 `参数模板/信号逻辑/信号列表示例/参数说明` 四段。 + - `信号列表示例` 至少 2 条,且与当前函数输出字段一致(`k1_k2_k3_v1_v2_v3_score`)。 + +## Guardrails + +- 不要输出非 7 段信号;否则会在 `Signal::from_str` 失败。 +- 不要忽略大小写差异;注册表键是精确匹配。 +- 不要在数据长度不足时硬算;先做边界检查。 +- 不要绕过 `TaCache` 重复全量计算重指标。 +- 不要在函数里偷偷改全局状态;保持纯计算风格(Trader级仅读取状态)。 +- 不要写空洞注释(如“计算信号”);注释要解释判定依据和业务语义。 + +## Resources + +- 架构与触发链路:`references/event-driven-signal-pipeline.md` +- 函数模板与常见坑:`references/signal-function-patterns.md` +- 骨架生成器:`scripts/new_signal_stub.py` diff --git a/.claude/skills/czsc-write-signal-function/agents/openai.yaml b/.claude/skills/czsc-write-signal-function/agents/openai.yaml new file mode 100644 index 000000000..5a321f983 --- /dev/null +++ b/.claude/skills/czsc-write-signal-function/agents/openai.yaml @@ -0,0 +1,4 @@ +interface: + display_name: "CZSC Signal Writer" + short_description: "Write Rust CZSC signal functions and wire registries" + default_prompt: "Use $czsc-write-signal-function to implement and register a new CZSC signal function." diff --git a/.claude/skills/czsc-write-signal-function/references/event-driven-signal-pipeline.md b/.claude/skills/czsc-write-signal-function/references/event-driven-signal-pipeline.md new file mode 100644 index 000000000..ef4db39a8 --- /dev/null +++ b/.claude/skills/czsc-write-signal-function/references/event-driven-signal-pipeline.md @@ -0,0 +1,82 @@ +# Event-Driven Signal Pipeline (rs_czsc) + +## 1. 总体调用链 + +1. `CzscTrader::update` 接收一根 `RawBar` +2. `CzscSignals::update_signals` 计算本根所有 K线级信号 +3. `CzscTrader::update` 额外执行 Trader 级信号(pos 系列) +4. 合并信号后调用每个 `Position::update` +5. `Position` 按 `opens + exits` 顺序匹配 `Event` +6. `Event` 内部通过 `signals_not -> signals_all -> signals_any` 判定 +7. 匹配后产生命令(`LO/LE/SO/SE/...`)并更新持仓记录 + +## 2. 关键文件 + +- `crates/czsc-trader/src/trader.rs` + - 交易主入口;组织信号计算与仓位更新 +- `crates/czsc-trader/src/signals/czsc_signals.rs` + - K线级信号执行器;维护 `s` 与 `sigs` +- `crates/czsc-signals/src/registry.rs` + - 汇总 `inventory` 自动收集的 `SignalDescriptor` 到运行时注册表 +- `crates/czsc-signals/src/types.rs` + - 信号函数签名类型定义 +- `crates/czsc-core/src/objects/signal.rs` + - `Signal` 7段格式、`is_match` +- `crates/czsc-core/src/objects/event.rs` + - `Event` 匹配逻辑 +- `crates/czsc-core/src/objects/position.rs` + - `Position::update` 执行开平仓规则 +- `crates/czsc-core/src/objects/state.rs` + - `TraderState` trait(给 Trader 级信号读状态) + +## 3. 两类信号函数 + +### K线级信号(bar / tas / cxt) + +- 签名: +`fn(&CZSC, &ParamView, &mut TaCache) -> Vec` +- 注册方式:函数上 `#[signal(category = "kline", ...)]` 自动收集 +- 调用路径:`CzscSignals::update_signals` +- 适用场景:仅依赖某个周期的 K 线结构和指标 + +### Trader级信号(pos) + +- 签名: +`fn(&dyn TraderState, &ParamView) -> Vec` +- 注册方式:函数上 `#[signal(category = "trader", ...)]` 自动收集 +- 调用路径:`CzscTrader::update` 中 `freq.is_none()` 分支 +- 适用场景:需要 `Position` + `CZSC` 联合状态 + +## 4. 信号字符串规则 + +`Signal` 必须满足: +- 格式:`k1_k2_k3_v1_v2_v3_score` +- 总段数:7 段 +- `score`:0~100 + +`CzscSignals` 会把它拆成: +- key:`k1_k2_k3` +- value:`v1_v2_v3_score` + +`Event` 匹配时用 `Signal::is_match`: +- `score >= 事件要求score` +- `v1/v2/v3` 精确匹配或事件侧为 `任意` + +## 5. 开发时最常见断点 + +- 注册表 name 与 `SignalConfig.name` 不一致(大小写、`_V`/`_v`) +- 输出不是 7 段,导致 `Signal::from_str` 失败 +- `freq` 设置错误: + - K线级应为 `Some(freq)` + - Trader级应为 `None` +- 数据不足未做边界检查,触发索引错误或无意义信号 + +## 6. 最小联调路径 + +1. 在 `czsc-signals` 实现函数并添加 `#[signal(...)]` 标注 +2. 构造 `SignalConfig` +3. 用 `CzscTrader::update` 喂历史 bars +4. 检查: + - `trader.signals.s` 中是否有目标 key + - `trader.signals.sigs` 中是否出现目标 `Signal` + - `position.operates` 是否按预期变化 diff --git a/.claude/skills/czsc-write-signal-function/references/signal-function-patterns.md b/.claude/skills/czsc-write-signal-function/references/signal-function-patterns.md new file mode 100644 index 000000000..12c628f96 --- /dev/null +++ b/.claude/skills/czsc-write-signal-function/references/signal-function-patterns.md @@ -0,0 +1,141 @@ +# Signal Function Patterns + +## 0. 注释规范(必须) + +每个新信号函数都必须包含以下注释层级: + +- 函数文档注释(`///`)至少写清 4 件事: + - 信号在策略中的业务含义 + - 参数模板字符串(`param_template`) + - 核心判定逻辑(触发条件) + - 边界行为(数据不足返回什么) +- 关键分支前写“原因注释”,解释判定依据和设计意图。 +- 与 Python/旧版行为对齐时,标注“对齐对象”和“为何这样对齐”。 + +## 1. K线级函数模板(bar/tas/cxt) + +```rust +use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::sig::{get_str_param, get_usize_param}; +use czsc_core::analyze::CZSC; +use czsc_core::objects::signal::Signal; +use czsc_signal_macros::signal; + +/// xxx_v240101: 示例信号 +/// +/// 参数模板:"{freq}_D{di}{key}_示例V240101" +/// 信号语义:当满足示例条件时输出目标状态,否则输出“其他”。 +/// 边界行为:当 bars 不足以支持 di 回看时,返回空信号。 +#[signal( + category = "kline", + name = "xxx_V240101", + template = "{freq}_D{di}{key}_示例V240101", + opcode = "XxxV240101", + param_kind = "XxxV240101" +)] +pub fn xxx_v240101( + czsc: &CZSC, + params: &ParamView, + cache: &mut TaCache, +) -> Vec { + let di = get_usize_param(params, "di", 1); + let key = get_str_param(params, "key", "DEFAULT"); + + // 1) 边界检查:数据不足时不输出错误信号,直接返回空结果。 + if czsc.bars_raw.len() < di + 2 { + return vec![]; + } + + // 2) 计算 k1/k2/k3 与 v1/v2/v3: + // k1/k2/k3 用于事件匹配 key,v1/v2/v3 承载状态值语义。 + let k1 = czsc.freq.to_string(); + let k2 = format!("D{}{}", di, key); + let k3 = "示例V240101"; + let v1 = "其他"; + let v2 = "任意"; + let v3 = "任意"; + + // 3) 严格 7 段 + let sig_str = format!("{}_{}_{}_{}_{}_{}_0", k1, k2, k3, v1, v2, v3); + Signal::from_str(&sig_str).map_or_else(|_| vec![], |s| vec![s]) +} +``` + +## 2. Trader级函数模板(pos) + +```rust +use crate::params::ParamView; +use crate::utils::sig::get_str_param; +use czsc_core::objects::signal::Signal; +use czsc_core::objects::state::TraderState; +use czsc_signal_macros::signal; +use std::str::FromStr; + +/// pos_xxx_v240101: 示例 Trader 级信号 +/// +/// 参数模板:"{pos_name}_示例V240101" +/// 信号语义:根据仓位是否存在输出“有效/其他”。 +/// 边界行为:缺少 pos_name 或查询不到仓位时输出“其他”。 +#[signal( + category = "trader", + name = "pos_xxx_V240101", + template = "{pos_name}_示例V240101", + opcode = "PosXxxV240101", + param_kind = "PosXxxV240101" +)] +pub fn pos_xxx_v240101(cat: &dyn TraderState, params: &ParamView) -> Vec { + let pos_name = get_str_param(params, "pos_name", "").to_string(); + + let k1 = format!("{}_状态", pos_name); + let k2 = "其他"; + let k3 = "示例V240101"; + + // Trader 级信号必须解释“为何读取 trader 状态”,避免后续误改为纯 K线逻辑。 + let v1 = if cat.get_position(&pos_name).is_some() { + "有效" + } else { + "其他" + }; + + let sig_str = format!("{}_{}_{}_{}_任意_任意_0", k1, k2, k3, v1); + Signal::from_str(&sig_str).map_or_else(|_| vec![], |s| vec![s]) +} +``` + +## 3. 注册模板 + +通过 `#[signal(...)]` 自动注册,无需手写 `registry.rs` 列表。 +关键字段: + +```rust +#[signal( + category = "kline|trader", + name = "..._Vxxxxxx", + template = "...", + opcode = "...", + param_kind = "..." +)] +``` + +## 4. 参数与模板约束 + +- `params` key 名与模板占位符保持一致。 +- 模板中 `freq` 表示周期;Trader 级常常不用 `freq` 字段做路由,依赖 `freq1` 之类业务参数。 +- 维护 `k2/k3` 语义稳定,避免同名逻辑漂移。 + +## 5. 事件触发兼容性检查 + +新增信号用于事件时,先写出事件侧信号字符串,再倒推函数输出: + +1. 事件侧 `Signal` 是否 7 段且 score 合法 +2. 事件侧 `k1_k2_k3` 是否与函数输出完全一致 +3. 事件侧 `v1/v2/v3` 是否允许 `任意` +4. 事件侧 score 是否不高于函数输出 score + +## 6. 常见坑 + +- 在 `SignalConfig` 里用错注册名(尤其 `_V` 大写) +- `freq: None` / `Some` 配置反了,导致函数完全不被调度 +- 直接 `unwrap` 导致边界数据 panic;优先 `map_or_else` 或早返回 +- 函数输出多信号时 key 冲突未预期,后写会覆盖 `s` 字典同 key 值 diff --git a/.claude/skills/czsc-write-signal-function/scripts/new_signal_stub.py b/.claude/skills/czsc-write-signal-function/scripts/new_signal_stub.py new file mode 100644 index 000000000..1f68d0ffc --- /dev/null +++ b/.claude/skills/czsc-write-signal-function/scripts/new_signal_stub.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +"""Generate Rust stubs for CZSC signal functions (auto-registered by #[signal]).""" + +from __future__ import annotations + +import argparse +import re +import sys + + +def to_registry_name(func_name: str) -> str: + m = re.match(r"^(.*)_v(\d{6})$", func_name) + if not m: + return func_name + return f"{m.group(1)}_V{m.group(2)}" + + +def to_pascal_case(name: str) -> str: + return "".join(part.capitalize() for part in name.split("_") if part) + + +def derive_opcode(func_name: str) -> str: + m = re.match(r"^(.*)_v(\d{6})$", func_name) + if not m: + return to_pascal_case(func_name) + return f"{to_pascal_case(m.group(1))}V{m.group(2)}" + + +def kline_stub(func_name: str) -> str: + registry_name = to_registry_name(func_name) + opcode = derive_opcode(func_name) + return f'''use crate::params::ParamView; +use crate::types::TaCache; +use crate::utils::sig::get_usize_param; +use czsc_signal_macros::signal; +use czsc_core::analyze::CZSC; +use czsc_core::objects::signal::Signal; +use std::str::FromStr; + +/// {func_name}: TODO 用一句话描述信号业务含义 +/// +/// 参数模板:\"{{freq}}_D{{di}}_{registry_name}\" +/// 判定逻辑:TODO 说明触发条件与 v1/v2/v3 的语义映射 +/// 边界行为:当数据不足时返回空信号,避免输出误导状态 +#[signal( + category = "kline", + name = "{registry_name}", + template = "{{freq}}_D{{di}}_{registry_name}", + opcode = "{opcode}", + param_kind = "{opcode}" +)] +pub fn {func_name}( + czsc: &CZSC, + params: &ParamView, + _cache: &mut TaCache, +) -> Vec {{ + // 参数读取:统一从 params 提取,保证默认值行为可预期 + let di = get_usize_param(params, "di", 1); + + // 边界检查:避免 di 回看越界导致 panic 或错误信号 + if czsc.bars_raw.len() < di + 2 {{ + return vec![]; + }} + + // 组装信号 7 段:k1/k2/k3 决定匹配键,v1/v2/v3 表达状态值 + let k1 = czsc.freq.to_string(); + let k2 = format!("D{{}}", di); + let k3 = "示例"; + let v1 = "其他"; + + let sig_str = format!("{{}}_{{}}_{{}}_{{}}_任意_任意_0", k1, k2, k3, v1); + Signal::from_str(&sig_str).map_or_else(|_| vec![], |s| vec![s]) +}} +''' + + +def trader_stub(func_name: str) -> str: + registry_name = to_registry_name(func_name) + opcode = derive_opcode(func_name) + return f'''use crate::params::ParamView; +use crate::utils::sig::get_str_param; +use czsc_signal_macros::signal; +use czsc_core::objects::signal::Signal; +use czsc_core::objects::state::TraderState; +use std::str::FromStr; + +/// {func_name}: TODO 用一句话描述 Trader 级信号业务含义 +/// +/// 参数模板:\"{{pos_name}}_{registry_name}\" +/// 判定逻辑:TODO 说明如何读取 TraderState 并映射为信号值 +/// 边界行为:查询不到仓位或关键参数缺失时输出“其他” +#[signal( + category = "trader", + name = "{registry_name}", + template = "{{pos_name}}_{registry_name}", + opcode = "{opcode}", + param_kind = "{opcode}" +)] +pub fn {func_name}( + cat: &dyn TraderState, + params: &ParamView, +) -> Vec {{ + // 参数读取:pos_name 是 Trader 级信号常见路由参数 + let pos_name = get_str_param(params, "pos_name", "").to_string(); + + // 读取 TraderState:说明该信号依赖仓位状态而非单周期 K线 + let k1 = format!("{{}}_状态", pos_name); + let v1 = if cat.get_position(&pos_name).is_some() {{ + "有效" + }} else {{ + "其他" + }}; + + let sig_str = format!("{{}}_其他_示例_{{}}_任意_任意_0", k1, v1); + Signal::from_str(&sig_str).map_or_else(|_| vec![], |s| vec![s]) +}} +''' + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--kind", choices=["kline", "trader"], required=True) + parser.add_argument("--func", required=True, help="Rust function name, e.g. tas_xxx_v240101") + args = parser.parse_args() + + func_name = args.func.strip() + + if args.kind == "kline": + func_body = kline_stub(func_name) + else: + func_body = trader_stub(func_name) + + sys.stdout.write("=== Function Stub ===\n") + sys.stdout.write(func_body) + sys.stdout.write( + "\n=== Notes ===\n使用 #[signal(...)] 自动注册,无需手写 registry.rs 注册项。\n" + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/.claude/skills/signal-functions/SKILL.md b/.claude/skills/signal-functions/SKILL.md new file mode 100644 index 000000000..acf7e485f --- /dev/null +++ b/.claude/skills/signal-functions/SKILL.md @@ -0,0 +1,117 @@ +--- +name: signal-functions +description: rs_czsc 信号函数完整参考手册。查询信号函数用法、参数模板、模块分类时使用。触发场景:(1)查找某个信号函数的参数模板或含义(2)按模块浏览所有可用信号(3)编写 signals_config 时需要知道信号名称和参数(4)理解信号字符串格式(5)配置 CzscTrader/CzscSignals 的信号列表。不触发:编写新信号函数的 Rust 实现代码。 +--- + +# rs_czsc 信号函数参考 + +## 概述 + +rs_czsc 通过 Rust 实现、PyO3 暴露的信号函数体系,提供 232 个信号函数,分两大类: + +- **K线级信号 (kline)**: 208 个,基于 CZSC 分析结果 + K线数据计算 +- **交易级信号 (trader)**: 24 个,基于持仓/策略状态计算 + +## Python API + +```python +from rs_czsc import list_all_signals, derive_signals_config, derive_signals_freqs + +# 列出所有注册信号 +signals = list_all_signals(include_kline=True, include_trader=True) +# 返回 list[dict],每项: {name, param_template, category, namespace} + +# 从信号字符串反推 signals_config +config = derive_signals_config(["60分钟_D1SMA#5_分类V221101_多头_向上_任意_0"]) + +# 从 signals_config 提取所需周期 +freqs = derive_signals_freqs(config) +``` + +## 信号字符串格式 + +7 段式:`k1_k2_k3_v1_v2_v3_score` + +``` +60分钟_D1SMA#5_分类V221101_多头_向上_任意_0 + ├─k1──┤ ├─k2───┤ ├k3──────┤ ├v1┤ ├v2┤ ├v3┤ ├s┤ +``` + +- `k1,k2,k3`:键字段(通常为:周期、参数描述、版本标签) +- `v1,v2,v3`:值字段(信号状态,如"看多"/"看空"/"其他") +- `score`:整数 0-100 + +## 通用模板参数 + +| 参数 | 含义 | 示例 | +|------|------|------| +| `{freq}` | K线周期 | `60分钟`, `日线`, `周线` | +| `{di}` | 倒数第几根K线/笔,0=当前 | `0`, `1`, `2` | +| `{n}` / `{m}` | 窗口长度 / 辅助窗口 | `5`, `10`, `20` | +| `{th}` | 阈值 | `100`, `500` | +| `{ma_type}` | 均线类型 | `SMA`, `EMA`, `WMA` | +| `{timeperiod}` | 均线/指标周期 | `5`, `10`, `20` | +| `{fastperiod}` / `{slowperiod}` / `{signalperiod}` | MACD 参数 | `12`, `26`, `9` | +| `{max_overlap}` | 最大重叠次数 | `1`, `3` | +| `{pos_name}` | 持仓名称(trader信号) | `多头持仓` | +| `{freq1}` / `{freq2}` | 双周期信号的两个周期 | `60分钟`, `日线` | + +## 信号模块索引 + +### K线级信号 + +| 模块 | 数量 | 说明 | 模块索引 | +|------|------|------|----------| +| **bar** | 46 | K线形态、动量、突破、统计 | [signals-bar.md](references/signals-bar.md) | +| **cxt** | 41 | 缠论笔、分型、买卖点、形态 | [signals-cxt.md](references/signals-cxt.md) | +| **tas** | 59 | MACD/均线/布林/KDJ/RSI/ATR等 | [signals-tas.md](references/signals-tas.md) | +| **jcc** | 19 | 日本蜡烛图经典形态 | [signals-jcc.md](references/signals-jcc.md) | +| **ang** | 10 | ADTM/AMV/ASI/CMO/SKDJ等辅助指标 | [signals-ang.md](references/signals-ang.md) | +| **xl** | 7 | XL系列(位置/趋势/突破/通道) | [signals-xl.md](references/signals-xl.md) | +| **vol** | 6 | 成交量均线/缩量/高低/窗口能量 | [signals-vol.md](references/signals-vol.md) | +| **coo** | 5 | TD序列/CCI/KDJ/SAR组合 | [signals-coo.md](references/signals-coo.md) | +| **byi** | 5 | 对称中枢/停顿分型/验证分型 | [signals-byi.md](references/signals-byi.md) | +| **pressure** | 4 | 支撑压力位 | [signals-pressure.md](references/signals-pressure.md) | +| **obv** | 2 | OBV能量 | [signals-obv.md](references/signals-obv.md) | +| **cvolp** | 1 | CVOLP动量变化率 | [signals-cvolp.md](references/signals-cvolp.md) | +| **ntmdk** | 1 | NTMDK多空 | [signals-ntmdk.md](references/signals-ntmdk.md) | +| **kcatr** | 1 | KCATR多空 | [signals-kcatr.md](references/signals-kcatr.md) | +| **clv** | 1 | CLV多空 | [signals-clv.md](references/signals-clv.md) | + +### 交易级信号 + +| 模块 | 数量 | 说明 | 模块索引 | +|------|------|------|----------| +| **pos** | 16 | 持仓管理(止损/止盈/保本/状态) | [signals-pos.md](references/signals-pos.md) | +| **zdy_trader** | 4 | 自定义交易(震荡/止损/止盈) | [signals-zdy_trader.md](references/signals-zdy_trader.md) | +| **cat** | 2 | MACD联立信号 | [signals-cat.md](references/signals-cat.md) | +| **cxt_trader** | 2 | 缠论交易(中枢共振/日内走势) | [signals-cxt_trader.md](references/signals-cxt_trader.md) | + +### 单信号详细文档 + +每个信号函数都有独立的详细文档,包含**信号逻辑**、**信号列表示例**、**参数说明**。 + +文件位于 `references/signals/{signal_name}.md`,如 [er_up_dw_line_V230604.md](references/signals/er_up_dw_line_V230604.md)。 + +## 使用示例 + +### 在 CzscSignals/CzscTrader 中配置信号 + +```python +signals_config = [ + {"name": "tas_ma_base_V221101", "freq": "60分钟", "di": 1, "ma_type": "SMA", "timeperiod": 5}, + {"name": "tas_macd_base_V221028", "freq": "日线", "di": 1, "fastperiod": 12, "slowperiod": 26, "signalperiod": 9}, + {"name": "cxt_bi_end_V230224", "freq": "30分钟", "di": 1}, + {"name": "pos_stop_V240428", "freq1": "60分钟", "pos_name": "多头", "t": 200, "n": 3}, +] +``` + +### 从信号字符串反推配置 + +```python +from rs_czsc import derive_signals_config + +signal_str = "60分钟_D1SMA#5_分类V221101_多头_向上_任意_0" +config = derive_signals_config([signal_str]) +# 返回: [{"name": "tas_ma_base_V221101", "freq": "60分钟", "di": 1, "ma_type": "SMA", "timeperiod": 5}] +``` diff --git a/.claude/skills/signal-functions/references/signals-ang.md b/.claude/skills/signal-functions/references/signals-ang.md new file mode 100644 index 000000000..3ce29d19f --- /dev/null +++ b/.claude/skills/signal-functions/references/signals-ang.md @@ -0,0 +1,17 @@ +# ang 模块信号索引 + +> 源码: `crates/czsc-signals/src/ang.rs` +> 共 10 个信号 + +| 信号名 | 参数模板 | 说明 | 详细文档 | +|--------|----------|------|----------| +| `adtm_up_dw_line_V230603` | `"{freq}_D{di}N{n}M{m}TH{th}_ADTMV230603` | ADTM 能量异动多空信号 | [详细文档](signals/adtm_up_dw_line_V230603.md) | +| `amv_up_dw_line_V230603` | `"{freq}_D{di}N{n}M{m}_AMV能量V230603` | AMV 能量多空信号 | [详细文档](signals/amv_up_dw_line_V230603.md) | +| `asi_up_dw_line_V230603` | `"{freq}_D{di}N{n}P{p}_ASI多空V230603` | ASI 多空信号 | [详细文档](signals/asi_up_dw_line_V230603.md) | +| `bias_up_dw_line_V230618` | `"{freq}_D{di}N{n}M{m}P{p}TH1{th1}TH2{th2}TH3{th3}_BIAS乖离率V230618` | BIAS 三周期共振信号 | [详细文档](signals/bias_up_dw_line_V230618.md) | +| `cmo_up_dw_line_V230605` | `"{freq}_D{di}N{n}M{m}_CMO能量V230605` | CMO 能量阈值信号 | [详细文档](signals/cmo_up_dw_line_V230605.md) | +| `dema_up_dw_line_V230605` | `"{freq}_D{di}N{n}_DEMA短线趋势V230605` | DEMA 短线趋势信号 | [详细文档](signals/dema_up_dw_line_V230605.md) | +| `demakder_up_dw_line_V230605` | `"{freq}_D{di}N{n}TH{th}TL{tl}_DEMAKER价格趋势V230605` | DEMAKER 价格趋势信号 | [详细文档](signals/demakder_up_dw_line_V230605.md) | +| `emv_up_dw_line_V230605` | `"{freq}_D{di}_EMV简易波动V230605` | EMV 简易波动多空信号 | [详细文档](signals/emv_up_dw_line_V230605.md) | +| `er_up_dw_line_V230604` | `"{freq}_D{di}W{w}N{n}_ER价格动量V230604` | ER 价格动量分层信号 | [详细文档](signals/er_up_dw_line_V230604.md) | +| `skdj_up_dw_line_V230611` | `"{freq}_D{di}N{n}M{m}UP{up}DW{dw}_SKDJ随机波动V230611` | SKDJ 随机波动信号 | [详细文档](signals/skdj_up_dw_line_V230611.md) | diff --git a/.claude/skills/signal-functions/references/signals-bar.md b/.claude/skills/signal-functions/references/signals-bar.md new file mode 100644 index 000000000..8a979227c --- /dev/null +++ b/.claude/skills/signal-functions/references/signals-bar.md @@ -0,0 +1,53 @@ +# bar 模块信号索引 + +> 源码: `crates/czsc-signals/src/bar.rs` +> 共 46 个信号 + +| 信号名 | 参数模板 | 说明 | 详细文档 | +|--------|----------|------|----------| +| `bar_accelerate_V221110` | `"{freq}_D{di}W{window}_加速V221110` | 区间加速走势判定 | [详细文档](signals/bar_accelerate_V221110.md) | +| `bar_accelerate_V221118` | `"{freq}_D{di}W{window}#{ma_type}#{timeperiod}_加速V221118` | 均线偏离加速判定 | [详细文档](signals/bar_accelerate_V221118.md) | +| `bar_accelerate_V240428` | `"{freq}_D{di}W{w}T{t}_加速V240428` | 滚动差分加速判定 | [详细文档](signals/bar_accelerate_V240428.md) | +| `bar_amount_acc_V230214` | `"{freq}_D{di}N{n}_累计超{t}千万` | 区间累计成交额信号 | [详细文档](signals/bar_amount_acc_V230214.md) | +| `bar_big_solid_V230215` | `"{freq}_D{di}N{n}_MIDV230215` | 窗口最大实体中位多空信号 | [详细文档](signals/bar_big_solid_V230215.md) | +| `bar_bpm_V230227` | `"{freq}_D{di}N{n}T{th}_绝对动量V230227` | 绝对动量分层 | [详细文档](signals/bar_bpm_V230227.md) | +| `bar_break_V240428` | `"{freq}_D{di}W{w}_事件V240428` | 收盘极值突破 | [详细文档](signals/bar_break_V240428.md) | +| `bar_channel_V230508` | `"{freq}_D{di}M{m}_通道V230507` | 窄幅通道方向判定 | [详细文档](signals/bar_channel_V230508.md) | +| `bar_classify_V240606` | `"{freq}_D{di}收盘位置_分类V240606` | 单根K线收盘位置分类 | [详细文档](signals/bar_classify_V240606.md) | +| `bar_classify_V240607` | `"{freq}_D{di}K2收盘位置_分类V240607` | 两根K线收盘位置分类 | [详细文档](signals/bar_classify_V240607.md) | +| `bar_decision_V240608` | `"{freq}_W{w}N{n}Q{q}放量_决策区域V240608` | 放量反向决策区 | [详细文档](signals/bar_decision_V240608.md) | +| `bar_decision_V240616` | `"{freq}_W{w}N{n}强弱_决策区域V240616` | 新高新低后的强弱决策 | [详细文档](signals/bar_decision_V240616.md) | +| `bar_dual_thrust_V230403` | `"{freq}_D{di}通道突破#{N}#{K1}#{K2}_BS辅助V230403` | Dual Thrust 通道突破 | [详细文档](signals/bar_dual_thrust_V230403.md) | +| `bar_eight_V230702` | `"{freq}_D{di}#8K_走势分类V230702` | 8K 走势分类 | [详细文档](signals/bar_eight_V230702.md) | +| `bar_end_V221211` | `"{freq}_{freq1}结束_BS辅助221211` | 判断大周期K线是否闭合 | [详细文档](signals/bar_end_V221211.md) | +| `bar_fake_break_V230204` | `"{freq}_D{di}N{n}M{m}_假突破V230204` | 区间假突破判定 | [详细文档](signals/bar_fake_break_V230204.md) | +| `bar_fang_liang_break_V221216` | `"{freq}_D{di}TH{th}#{ma_type}#{timeperiod}_突破V221216` | 放量突破与缩量回踩 | [详细文档](signals/bar_fang_liang_break_V221216.md) | +| `bar_limit_down_V230525` | `"{freq}_跌停后无下影线长实体阳线_短线V230525` | 跌停后反包阳线 | [详细文档](signals/bar_limit_down_V230525.md) | +| `bar_mean_amount_V221112` | `"{freq}_D{di}K{n}B均额_{th1}至{th2}千万` | 区间均额分类信号 | [详细文档](signals/bar_mean_amount_V221112.md) | +| `bar_operate_span_V221111` | `"{freq}_T{t1}#{t2}_时间区间V221111` | 日内时间区间过滤 | [详细文档](signals/bar_operate_span_V221111.md) | +| `bar_plr_V240427` | `"{freq}_D{di}W{w}T{t}M{m}_盈亏比V240427` | 盈亏比约束 | [详细文档](signals/bar_plr_V240427.md) | +| `bar_polyfit_V240428` | `"{freq}_D{di}W{w}_分类V240428` | 一阶二阶拟合分类 | [详细文档](signals/bar_polyfit_V240428.md) | +| `bar_r_breaker_V230326` | `"{freq}_RBreaker_BS辅助V230326` | RBreaker 价格位判定 | [详细文档](signals/bar_r_breaker_V230326.md) | +| `bar_reversal_V230227` | `"{freq}_D{di}A{avg_bp}_反转V230227` | 末根反转迹象判定 | [详细文档](signals/bar_reversal_V230227.md) | +| `bar_section_momentum_V221112` | `"{freq}_D{di}K{n}B_阈值{th}BPV221112` | 区间动量强弱与波动 | [详细文档](signals/bar_section_momentum_V221112.md) | +| `bar_shuang_fei_V230507` | `"{freq}_D{di}双飞_短线V230507` | 双飞涨停形态 | [详细文档](signals/bar_shuang_fei_V230507.md) | +| `bar_single_V230214` | `"{freq}_D{di}T{t}_状态V230214` | 单K状态信号 | [详细文档](signals/bar_single_V230214.md) | +| `bar_single_V230506` | `"{freq}_D{di}单K趋势N{n}_BS辅助V230506` | 单K趋势分层信号 | [详细文档](signals/bar_single_V230506.md) | +| `bar_td9_V240616` | `"{freq}_神奇九转N{n}_BS辅助V240616` | 神奇九转计数 | [详细文档](signals/bar_td9_V240616.md) | +| `bar_time_V230327` | `"{freq}_日内时间_分段V230327` | 日内时间分段信号 | [详细文档](signals/bar_time_V230327.md) | +| `bar_tnr_V230629` | `"{freq}_D{di}TNR{timeperiod}_趋势V230629` | TNR 分层信号 | [详细文档](signals/bar_tnr_V230629.md) | +| `bar_tnr_V230630` | `"{freq}_D{di}TNR{timeperiod}K{k}_趋势V230630` | TNR 噪音变化判定 | [详细文档](signals/bar_tnr_V230630.md) | +| `bar_trend_V240209` | `"{freq}_D{di}N{N}趋势跟踪_BS辅助V240209` | 趋势跟踪结构判定 | [详细文档](signals/bar_trend_V240209.md) | +| `bar_triple_V230506` | `"{freq}_D{di}三K加速_裸K形态V230506` | 三K加速形态信号 | [详细文档](signals/bar_triple_V230506.md) | +| `bar_vol_bs1_V230224` | `"{freq}_D{di}N{n}量价_BS1辅助V230224` | 量价高低点辅助 | [详细文档](signals/bar_vol_bs1_V230224.md) | +| `bar_vol_grow_V221112` | `"{freq}_D{di}K{n}B_放量V221112` | 成交量放大信号 | [详细文档](signals/bar_vol_grow_V221112.md) | +| `bar_volatility_V241013` | `"{freq}_波动率分层W{w}N{n}_完全分类V241013` | 波动率三层分类 | [详细文档](signals/bar_volatility_V241013.md) | +| `bar_weekday_V230328` | `"{freq}_周内时间_分段V230328` | 周内时间分段信号 | [详细文档](signals/bar_weekday_V230328.md) | +| `bar_window_ps_V230731` | `"{freq}_W{w}M{m}N{n}L{l}_支撑压力位V230731` | 支撑压力位分位特征 | [详细文档](signals/bar_window_ps_V230731.md) | +| `bar_window_ps_V230801` | `"{freq}_N{n}W{w}_支撑压力位V230801` | 支撑压力位窗口极值 | [详细文档](signals/bar_window_ps_V230801.md) | +| `bar_window_std_V230731` | `"{freq}_D{di}W{w}M{m}N{n}_窗口波动V230731` | 窗口波动分层特征 | [详细文档](signals/bar_window_std_V230731.md) | +| `bar_zdf_V221203` | `"{freq}_D{di}{mode}_{t1}至{t2}` | 单根涨跌幅区间信号 | [详细文档](signals/bar_zdf_V221203.md) | +| `bar_zdt_V230331` | `"{freq}_D{di}_涨跌停V230331` | 涨跌停识别信号 | [详细文档](signals/bar_zdt_V230331.md) | +| `bar_zfzd_V241013` | `"{freq}_窄幅震荡N{n}_形态V241013` | 窄幅震荡(全重叠) | [详细文档](signals/bar_zfzd_V241013.md) | +| `bar_zfzd_V241014` | `"{freq}_窄幅震荡N{n}_形态V241014` | 窄幅震荡(最大实体重叠) | [详细文档](signals/bar_zfzd_V241014.md) | +| `bar_zt_count_V230504` | `"{freq}_D{di}W{window}涨停计数_裸K形态V230504` | 窗口涨停计数 | [详细文档](signals/bar_zt_count_V230504.md) | diff --git a/.claude/skills/signal-functions/references/signals-byi.md b/.claude/skills/signal-functions/references/signals-byi.md new file mode 100644 index 000000000..1fb231098 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals-byi.md @@ -0,0 +1,12 @@ +# byi 模块信号索引 + +> 源码: `crates/czsc-signals/src/byi.rs` +> 共 5 个信号 + +| 信号名 | 参数模板 | 说明 | 详细文档 | +|--------|----------|------|----------| +| `byi_bi_end_V230106` | `"{freq}_D0停顿分型_BE辅助V230106` | 分型停顿辅助笔结束信号 | [详细文档](signals/byi_bi_end_V230106.md) | +| `byi_bi_end_V230107` | `"{freq}_D0验证分型_BE辅助V230107` | 验证分型辅助笔结束信号 | [详细文档](signals/byi_bi_end_V230107.md) | +| `byi_fx_num_V230628` | `"{freq}_D{di}笔分型数大于{num}_BE辅助V230628` | 前笔分型数量约束信号 | [详细文档](signals/byi_fx_num_V230628.md) | +| `byi_second_bs_V230324` | `"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}回抽零轴_BS2辅助V230324` | 二类买卖点辅助信号 | [详细文档](signals/byi_second_bs_V230324.md) | +| `byi_symmetry_zs_V221107` | `"{freq}_D{di}B_对称中枢` | 对称中枢识别信号 | [详细文档](signals/byi_symmetry_zs_V221107.md) | diff --git a/.claude/skills/signal-functions/references/signals-cat.md b/.claude/skills/signal-functions/references/signals-cat.md new file mode 100644 index 000000000..b22f036ce --- /dev/null +++ b/.claude/skills/signal-functions/references/signals-cat.md @@ -0,0 +1,9 @@ +# cat 模块信号索引 + +> 源码: `crates/czsc-signals/src/cat.rs` +> 共 2 个信号 + +| 信号名 | 参数模板 | 说明 | 详细文档 | +|--------|----------|------|----------| +| `cat_macd_V230518` | `"{freq1}#{freq2}_MACD交叉_联立V230518` | 高低级别 MACD 交叉联立信号 | [详细文档](signals/cat_macd_V230518.md) | +| `cat_macd_V230520` | `"{freq1}#{freq2}_MACD交叉_联立V230520` | 高低级别 MACD 缩柱联立信号 | [详细文档](signals/cat_macd_V230520.md) | diff --git a/.claude/skills/signal-functions/references/signals-clv.md b/.claude/skills/signal-functions/references/signals-clv.md new file mode 100644 index 000000000..85da56f58 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals-clv.md @@ -0,0 +1,8 @@ +# clv 模块信号索引 + +> 源码: `crates/czsc-signals/src/clv.rs` +> 共 1 个信号 + +| 信号名 | 参数模板 | 说明 | 详细文档 | +|--------|----------|------|----------| +| `clv_up_dw_line_V230605` | `"{freq}_D{di}N{n}_CLV多空V230605` | CLV 多空信号 | [详细文档](signals/clv_up_dw_line_V230605.md) | diff --git a/.claude/skills/signal-functions/references/signals-coo.md b/.claude/skills/signal-functions/references/signals-coo.md new file mode 100644 index 000000000..18bfcb198 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals-coo.md @@ -0,0 +1,12 @@ +# coo 模块信号索引 + +> 源码: `crates/czsc-signals/src/coo.rs` +> 共 5 个信号 + +| 信号名 | 参数模板 | 说明 | 详细文档 | +|--------|----------|------|----------| +| `coo_cci_V230323` | `"{freq}_D{di}CCI{n}#{ma_type}#{m}_BS辅助V230323` | CCI 结合均线的多空与方向信号 | [详细文档](signals/coo_cci_V230323.md) | +| `coo_kdj_V230322` | `"{freq}_D{di}KDJ{fastk_period}#{slowk_period}#{slowd_period}#{ma_type}#{n}_BS辅助V230322` | 均线与 KDJ 配合多空信号 | [详细文档](signals/coo_kdj_V230322.md) | +| `coo_sar_V230325` | `"{freq}_D{di}N{n}SAR_BS辅助V230325` | SAR 与区间极值配合信号 | [详细文档](signals/coo_sar_V230325.md) | +| `coo_td_V221110` | `"{freq}_D{di}K_TD` | TD 神奇九转信号(旧版模板) | [详细文档](signals/coo_td_V221110.md) | +| `coo_td_V221111` | `"{freq}_D{di}TD_BS辅助V221111` | TD 神奇九转信号 | [详细文档](signals/coo_td_V221111.md) | diff --git a/.claude/skills/signal-functions/references/signals-cvolp.md b/.claude/skills/signal-functions/references/signals-cvolp.md new file mode 100644 index 000000000..7ede3d9a3 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals-cvolp.md @@ -0,0 +1,8 @@ +# cvolp 模块信号索引 + +> 源码: `crates/czsc-signals/src/cvolp.rs` +> 共 1 个信号 + +| 信号名 | 参数模板 | 说明 | 详细文档 | +|--------|----------|------|----------| +| `cvolp_up_dw_line_V230612` | `"{freq}_D{di}N{n}M{m}UP{up}DW{dw}_CVOLP动量变化率V230612` | CVOLP 动量变化率信号 | [详细文档](signals/cvolp_up_dw_line_V230612.md) | diff --git a/.claude/skills/signal-functions/references/signals-cxt.md b/.claude/skills/signal-functions/references/signals-cxt.md new file mode 100644 index 000000000..10535df93 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals-cxt.md @@ -0,0 +1,48 @@ +# cxt 模块信号索引 + +> 源码: `crates/czsc-signals/src/cxt.rs` +> 共 41 个信号 + +| 信号名 | 参数模板 | 说明 | 详细文档 | +|--------|----------|------|----------| +| `cxt_bi_base_V230228` | `"{freq}_D0BL{bi_init_length}_V230228` | 笔基础状态信号 | [详细文档](signals/cxt_bi_base_V230228.md) | +| `cxt_bi_end_V230104` | `"{freq}_D0{ma_type}#{timeperiod}T{th}_BE辅助V230104` | 单均线辅助判断笔结束 | [详细文档](signals/cxt_bi_end_V230104.md) | +| `cxt_bi_end_V230105` | `"{freq}_D0{ma_type}#{timeperiod}T{th}_BE辅助V230105` | K线形态+均线辅助判断笔结束 | [详细文档](signals/cxt_bi_end_V230105.md) | +| `cxt_bi_end_V230222` | `"{freq}_D1MO{max_overlap}_BE辅助V230222` | 未完成笔分型新高新低次数 | [详细文档](signals/cxt_bi_end_V230222.md) | +| `cxt_bi_end_V230224` | `"{freq}_D1_BE辅助V230224` | 量价配合笔结束辅助 | [详细文档](signals/cxt_bi_end_V230224.md) | +| `cxt_bi_end_V230312` | `"{freq}_D0MACD{fastperiod}#{slowperiod}#{signalperiod}_BE辅助V230312` | MACD辅助判断笔结束 | [详细文档](signals/cxt_bi_end_V230312.md) | +| `cxt_bi_end_V230320` | `"{freq}_D0质数窗口MO{max_overlap}_BE辅助V230320` | 质数窗口笔结束辅助 | [详细文档](signals/cxt_bi_end_V230320.md) | +| `cxt_bi_end_V230322` | `"{freq}_D0分型配合{ma_type}#{timeperiod}_BE辅助V230322` | 分型配合均线的笔结束辅助 | [详细文档](signals/cxt_bi_end_V230322.md) | +| `cxt_bi_end_V230324` | `"{freq}_D0{ma_type}#{timeperiod}均线突破_BE辅助V230324` | 笔结束分型均线突破 | [详细文档](signals/cxt_bi_end_V230324.md) | +| `cxt_bi_end_V230618` | `"{freq}_D{di}MO{max_overlap}_BE辅助V230618` | 笔结束小中枢辅助 | [详细文档](signals/cxt_bi_end_V230618.md) | +| `cxt_bi_end_V230815` | `"{freq}_快速突破_BE辅助V230815` | 快速突破反向笔 | [详细文档](signals/cxt_bi_end_V230815.md) | +| `cxt_bi_status_V230101` | `"{freq}_D1_表里关系V230101` | 笔表里关系信号 | [详细文档](signals/cxt_bi_status_V230101.md) | +| `cxt_bi_status_V230102` | `"{freq}_D1_表里关系V230102` | 笔表里关系信号 | [详细文档](signals/cxt_bi_status_V230102.md) | +| `cxt_bi_stop_V230815` | `"{freq}_距离{th}BP_止损V230815` | 笔止损距离状态 | [详细文档](signals/cxt_bi_stop_V230815.md) | +| `cxt_bi_trend_V230824` | `"{freq}_D{di}N{n}TH{th}_形态V230824` | N笔形态判断 | [详细文档](signals/cxt_bi_trend_V230824.md) | +| `cxt_bi_trend_V230913` | `"{freq}_D{di}N{n}笔趋势_高低点辅助判断V230913` | 笔趋势高低点回归信号 | [详细文档](signals/cxt_bi_trend_V230913.md) | +| `cxt_bi_zdf_V230601` | `"{freq}_D{di}N{n}_分层V230601` | BI涨跌幅分层 | [详细文档](signals/cxt_bi_zdf_V230601.md) | +| `cxt_bs_V240526` | `"{freq}_趋势跟随_BS辅助V240526` | 趋势跟随 BS 辅助 | [详细文档](signals/cxt_bs_V240526.md) | +| `cxt_bs_V240527` | `"{freq}_趋势跟随_BS辅助V240527` | 未完成笔上的趋势跟随 BS 辅助 | [详细文档](signals/cxt_bs_V240527.md) | +| `cxt_decision_V240526` | `"{freq}_分型区域N{n}_决策区域V240526` | 分型区域决策 | [详细文档](signals/cxt_decision_V240526.md) | +| `cxt_decision_V240612` | `"{freq}_W{w}N{n}高低点_决策区域V240612` | 高低点N档决策区间 | [详细文档](signals/cxt_decision_V240612.md) | +| `cxt_decision_V240613` | `"{freq}_放量笔N{n}BS2_决策区域V240613` | 放量笔N4BS2决策区 | [详细文档](signals/cxt_decision_V240613.md) | +| `cxt_decision_V240614` | `"{freq}_放量笔N{n}_决策区域V240614` | 放量新高/新低决策区 | [详细文档](signals/cxt_decision_V240614.md) | +| `cxt_double_zs_V230311` | `"{freq}_D{di}双中枢_BS1辅助V230311` | 双中枢 BS1 辅助 | [详细文档](signals/cxt_double_zs_V230311.md) | +| `cxt_eleven_bi_V230622` | `"{freq}_D{di}十一笔_形态V230622` | 十一笔形态分类信号 | [详细文档](signals/cxt_eleven_bi_V230622.md) | +| `cxt_first_buy_V221126` | `"{freq}_D{di}B_BUY1V221126` | 一买信号 | [详细文档](signals/cxt_first_buy_V221126.md) | +| `cxt_first_sell_V221126` | `"{freq}_D{di}B_SELL1V221126` | 一卖信号 | [详细文档](signals/cxt_first_sell_V221126.md) | +| `cxt_five_bi_V230619` | `"{freq}_D{di}五笔_形态V230619` | 五笔形态分类信号 | [详细文档](signals/cxt_five_bi_V230619.md) | +| `cxt_fx_power_V221107` | `"{freq}_D{di}F_分型强弱V221107` | 倒数分型强弱 | [详细文档](signals/cxt_fx_power_V221107.md) | +| `cxt_nine_bi_V230621` | `"{freq}_D{di}九笔_形态V230621` | 九笔形态分类信号 | [详细文档](signals/cxt_nine_bi_V230621.md) | +| `cxt_overlap_V240526` | `"{freq}_顶底重合_支撑压力V240526` | 收盘价与最近分型区间重合次数 | [详细文档](signals/cxt_overlap_V240526.md) | +| `cxt_overlap_V240612` | `"{freq}_SNR顺畅N{n}_支撑压力V240612` | 顺畅笔分型支撑压力信号 | [详细文档](signals/cxt_overlap_V240612.md) | +| `cxt_range_oscillation_V230620` | `"{freq}_D{di}TH{th}_区间震荡V230620` | 区间震荡笔数统计 | [详细文档](signals/cxt_range_oscillation_V230620.md) | +| `cxt_second_bs_V230320` | `"{freq}_D{di}#{ma_type}#{timeperiod}_BS2辅助V230320` | 均线辅助识别第二类买卖点 | [详细文档](signals/cxt_second_bs_V230320.md) | +| `cxt_second_bs_V240524` | `"{freq}_D{di}W{w}T{t}_第二买卖点V240524` | 第二买卖点重叠计数信号 | [详细文档](signals/cxt_second_bs_V240524.md) | +| `cxt_seven_bi_V230620` | `"{freq}_D{di}七笔_形态V230620` | 七笔形态分类信号 | [详细文档](signals/cxt_seven_bi_V230620.md) | +| `cxt_third_bs_V230318` | `"{freq}_D{di}#{ma_type}#{timeperiod}_BS3辅助V230318` | 均线辅助识别第三类买卖点 | [详细文档](signals/cxt_third_bs_V230318.md) | +| `cxt_third_bs_V230319` | `"{freq}_D{di}#{ma_type}#{timeperiod}_BS3辅助V230319` | 带均线形态的第三类买卖点辅助 | [详细文档](signals/cxt_third_bs_V230319.md) | +| `cxt_third_buy_V230228` | `"{freq}_D{di}_三买辅助V230228` | 笔三买辅助 | [详细文档](signals/cxt_third_buy_V230228.md) | +| `cxt_three_bi_V230618` | `"{freq}_D{di}三笔_形态V230618` | 三笔形态分类信号 | [详细文档](signals/cxt_three_bi_V230618.md) | +| `cxt_ubi_end_V230816` | `"{freq}_UBI_BE辅助V230816` | UBI 新高新低次数信号 | [详细文档](signals/cxt_ubi_end_V230816.md) | diff --git a/.claude/skills/signal-functions/references/signals-cxt_trader.md b/.claude/skills/signal-functions/references/signals-cxt_trader.md new file mode 100644 index 000000000..67a15dd4d --- /dev/null +++ b/.claude/skills/signal-functions/references/signals-cxt_trader.md @@ -0,0 +1,9 @@ +# cxt_trader 模块信号索引 + +> 源码: `crates/czsc-signals/src/cxt_trader.rs` +> 共 2 个信号 + +| 信号名 | 参数模板 | 说明 | 详细文档 | +|--------|----------|------|----------| +| `cxt_intraday_V230701` | `"{freq1}#{freq2}_D{di}日_走势分类V230701` | 30分钟日内走势分类 | [详细文档](signals/cxt_intraday_V230701.md) | +| `cxt_zhong_shu_gong_zhen_V221221` | `"{freq1}_{freq2}_中枢共振V221221` | 大小级别中枢共振 | [详细文档](signals/cxt_zhong_shu_gong_zhen_V221221.md) | diff --git a/.claude/skills/signal-functions/references/signals-jcc.md b/.claude/skills/signal-functions/references/signals-jcc.md new file mode 100644 index 000000000..99d7581ab --- /dev/null +++ b/.claude/skills/signal-functions/references/signals-jcc.md @@ -0,0 +1,26 @@ +# jcc 模块信号索引 + +> 源码: `crates/czsc-signals/src/jcc.rs` +> 共 19 个信号 + +| 信号名 | 参数模板 | 说明 | 详细文档 | +|--------|----------|------|----------| +| `jcc_ci_tou_V221101` | `"{freq}_D{di}Z{z}TH{th}_刺透形态` | 刺透形态 | [详细文档](signals/jcc_ci_tou_V221101.md) | +| `jcc_fan_ji_xian_V221121` | `"{freq}_D{di}_反击线` | 反击线 | [详细文档](signals/jcc_fan_ji_xian_V221121.md) | +| `jcc_fen_shou_xian_V20221113` | `"{freq}_D{di}K_分手线` | 分手线 | [详细文档](signals/jcc_fen_shou_xian_V20221113.md) | +| `jcc_gap_yin_yang_V221121` | `"{freq}_D{di}K_并列阴阳` | 跳空并列阴阳 | [详细文档](signals/jcc_gap_yin_yang_V221121.md) | +| `jcc_ping_tou_V221113` | `"{freq}_D{di}TH{th}_平头形态` | 平头形态 | [详细文档](signals/jcc_ping_tou_V221113.md) | +| `jcc_san_fa_V20221115` | `"{freq}_D{di}K_三法` | 三法形态 | [详细文档](signals/jcc_san_fa_V20221115.md) | +| `jcc_san_fa_V20221118` | `"{freq}_D{di}K_三法A` | 三法形态A | [详细文档](signals/jcc_san_fa_V20221118.md) | +| `jcc_san_szx_V221122` | `"{freq}_D{di}T{th}_三星` | 三星形态 | [详细文档](signals/jcc_san_szx_V221122.md) | +| `jcc_san_xing_xian_V221023` | `"{freq}_D{di}TH{th}_伞形线` | 伞形线形态信号 | [详细文档](signals/jcc_san_xing_xian_V221023.md) | +| `jcc_shan_chun_V221121` | `"{freq}_D{di}B_山川形态` | 山川形态 | [详细文档](signals/jcc_shan_chun_V221121.md) | +| `jcc_szx_V221111` | `"{freq}_D{di}TH{th}_十字线` | 十字线 | [详细文档](signals/jcc_szx_V221111.md) | +| `jcc_ta_xing_V221124` | `"{freq}_D{di}K_塔形` | 塔形顶底 | [详细文档](signals/jcc_ta_xing_V221124.md) | +| `jcc_ten_mo_V221028` | `"{freq}_D{di}_吞没形态` | 吞没形态 | [详细文档](signals/jcc_ten_mo_V221028.md) | +| `jcc_three_crow_V221108` | `"{freq}_D{di}_三只乌鸦` | 三只乌鸦 | [详细文档](signals/jcc_three_crow_V221108.md) | +| `jcc_two_crow_V221108` | `"{freq}_D{di}K_两只乌鸦` | 两只乌鸦 | [详细文档](signals/jcc_two_crow_V221108.md) | +| `jcc_wu_yun_gai_ding_V221101` | `"{freq}_D{di}Z{z}TH{th}_乌云盖顶` | 乌云盖顶 | [详细文档](signals/jcc_wu_yun_gai_ding_V221101.md) | +| `jcc_xing_xian_V221118` | `"{freq}_D{di}TH{th}_星形线` | 星形线 | [详细文档](signals/jcc_xing_xian_V221118.md) | +| `jcc_yun_xian_V221118` | `"{freq}_D{di}_孕线` | 孕线形态 | [详细文档](signals/jcc_yun_xian_V221118.md) | +| `jcc_zhu_huo_xian_V221027` | `"{freq}_D{di}T{th}F{zf}_烛火线` | 烛火线 | [详细文档](signals/jcc_zhu_huo_xian_V221027.md) | diff --git a/.claude/skills/signal-functions/references/signals-kcatr.md b/.claude/skills/signal-functions/references/signals-kcatr.md new file mode 100644 index 000000000..a21d64457 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals-kcatr.md @@ -0,0 +1,8 @@ +# kcatr 模块信号索引 + +> 源码: `crates/czsc-signals/src/kcatr.rs` +> 共 1 个信号 + +| 信号名 | 参数模板 | 说明 | 详细文档 | +|--------|----------|------|----------| +| `kcatr_up_dw_line_V230823` | `"{freq}_D{di}N{n}M{m}T{th}_KCATR多空V230823` | ATR 通道突破多空 | [详细文档](signals/kcatr_up_dw_line_V230823.md) | diff --git a/.claude/skills/signal-functions/references/signals-ntmdk.md b/.claude/skills/signal-functions/references/signals-ntmdk.md new file mode 100644 index 000000000..debfe39fa --- /dev/null +++ b/.claude/skills/signal-functions/references/signals-ntmdk.md @@ -0,0 +1,8 @@ +# ntmdk 模块信号索引 + +> 源码: `crates/czsc-signals/src/ntmdk.rs` +> 共 1 个信号 + +| 信号名 | 参数模板 | 说明 | 详细文档 | +|--------|----------|------|----------| +| `ntmdk_V230824` | `"{freq}_D{di}M{m}_NTMDK多空V230824` | M 日前收盘价对比多空 | [详细文档](signals/ntmdk_V230824.md) | diff --git a/.claude/skills/signal-functions/references/signals-obv.md b/.claude/skills/signal-functions/references/signals-obv.md new file mode 100644 index 000000000..6a9769df8 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals-obv.md @@ -0,0 +1,9 @@ +# obv 模块信号索引 + +> 源码: `crates/czsc-signals/src/obv.rs` +> 共 2 个信号 + +| 信号名 | 参数模板 | 说明 | 详细文档 | +|--------|----------|------|----------| +| `obv_up_dw_line_V230719` | `"{freq}_D{di}N{n}M{m}MO{max_overlap}_OBV能量V230719` | OBV 交叉信号 | [详细文档](signals/obv_up_dw_line_V230719.md) | +| `obvm_line_V230610` | `"{freq}_D{di}N{n}M{m}_OBV能量V230610` | OBV 双 EMA 能量信号 | [详细文档](signals/obvm_line_V230610.md) | diff --git a/.claude/skills/signal-functions/references/signals-pos.md b/.claude/skills/signal-functions/references/signals-pos.md new file mode 100644 index 000000000..a98246c54 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals-pos.md @@ -0,0 +1,23 @@ +# pos 模块信号索引 + +> 源码: `crates/czsc-signals/src/pos.rs` +> 共 16 个信号 + +| 信号名 | 参数模板 | 说明 | 详细文档 | +|--------|----------|------|----------| +| `pos_bar_stop_V230524` | `"{pos_name}_{freq1}N{n}K_止损V230524` | 按开仓点附近N根K线极值止损 | [详细文档](signals/pos_bar_stop_V230524.md) | +| `pos_fix_exit_V230624` | `"{pos_name}_固定{th}BP止盈止损_出场V230624` | 固定 BP 止盈止损 | [详细文档](signals/pos_fix_exit_V230624.md) | +| `pos_fx_stop_V230414` | `"{freq1}_{pos_name}N{n}_止损V230414` | 按开仓点附近分型止损 | [详细文档](signals/pos_fx_stop_V230414.md) | +| `pos_holds_V230414` | `"{pos_name}_{freq1}N{n}M{m}_趋势判断V230414` | 开仓后 N 根K线收益与阈值比较 | [详细文档](signals/pos_holds_V230414.md) | +| `pos_holds_V230807` | `"{pos_name}_{freq1}N{n}M{m}T{t}_BS辅助V230807` | 开仓后收益在 (t, m) 之间触发保本 | [详细文档](signals/pos_holds_V230807.md) | +| `pos_holds_V240428` | `"{pos_name}_{freq1}H{h}T{t}N{n}_保本V240428` | 最大盈利回撤比例保本 | [详细文档](signals/pos_holds_V240428.md) | +| `pos_holds_V240608` | `"{pos_name}_{freq1}W{w}N{n}_保本V240608` | 跌破/升破开仓前窗口极值后,回到成本价指定档位保本 | [详细文档](signals/pos_holds_V240608.md) | +| `pos_ma_V230414` | `"{pos_name}_{freq1}#{ma_type}#{timeperiod}_持有状态V230414` | 判断开仓后是否升破/跌破均线 | [详细文档](signals/pos_ma_V230414.md) | +| `pos_profit_loss_V230624` | `"{pos_name}_{freq1}YKB{ykb}N{n}_盈亏比判断V230624` | 盈亏比阈值判断 | [详细文档](signals/pos_profit_loss_V230624.md) | +| `pos_status_V230808` | `"{pos_name}_持仓状态_BS辅助V230808` | 持仓状态 | [详细文档](signals/pos_status_V230808.md) | +| `pos_stop_V240331` | `"{pos_name}_{freq1}#{n}_止损V240331` | 最近 N 根K线追踪止损 | [详细文档](signals/pos_stop_V240331.md) | +| `pos_stop_V240428` | `"{pos_name}_{freq1}T{t}N{n}_止损V240428` | 按开仓前离散价位跳数止损 | [详细文档](signals/pos_stop_V240428.md) | +| `pos_stop_V240608` | `"{pos_name}_{freq1}W{w}N{n}_止损V240608` | 开仓后突破开仓前窗口极值 N 档止损 | [详细文档](signals/pos_stop_V240608.md) | +| `pos_stop_V240614` | `"{pos_name}_{freq1}N{n}_止损V240614` | 开仓后低于/高于成本价的 K线数量计数止损 | [详细文档](signals/pos_stop_V240614.md) | +| `pos_stop_V240717` | `"{pos_name}_{freq1}N{n}T{timeperiod}_止损V240717` | 基于开仓时 ATR 的计数止损 | [详细文档](signals/pos_stop_V240717.md) | +| `pos_take_V240428` | `"{pos_name}_{freq1}T{t}N{n}_止盈V240428` | 倍量阳/阴线计数止盈 | [详细文档](signals/pos_take_V240428.md) | diff --git a/.claude/skills/signal-functions/references/signals-pressure.md b/.claude/skills/signal-functions/references/signals-pressure.md new file mode 100644 index 000000000..09f678da1 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals-pressure.md @@ -0,0 +1,11 @@ +# pressure 模块信号索引 + +> 源码: `crates/czsc-signals/src/pressure.rs` +> 共 4 个信号 + +| 信号名 | 参数模板 | 说明 | 详细文档 | +|--------|----------|------|----------| +| `pressure_support_V240222` | `"{freq}_D{di}W{w}高低点验证_支撑压力V240222` | 高低点验证支撑压力位 | [详细文档](signals/pressure_support_V240222.md) | +| `pressure_support_V240402` | `"{freq}_D{di}W{w}_支撑压力V240402` | 分型区间支撑压力位 | [详细文档](signals/pressure_support_V240402.md) | +| `pressure_support_V240406` | `"{freq}_D{di}W{w}_支撑压力V240406` | 分型密集支撑压力位 | [详细文档](signals/pressure_support_V240406.md) | +| `pressure_support_V240530` | `"{freq}_D{di}W{w}N{n}_支撑压力V240530` | 关键重叠K线支撑压力位 | [详细文档](signals/pressure_support_V240530.md) | diff --git a/.claude/skills/signal-functions/references/signals-tas.md b/.claude/skills/signal-functions/references/signals-tas.md new file mode 100644 index 000000000..8b63ebe04 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals-tas.md @@ -0,0 +1,66 @@ +# tas 模块信号索引 + +> 源码: `crates/czsc-signals/src/tas.rs` +> 共 59 个信号 + +| 信号名 | 参数模板 | 说明 | 详细文档 | +|--------|----------|------|----------| +| `cci_decision_V240620` | `"{freq}_N{n}CCI_决策区域V240620` | CCI 逆势决策区域 | [详细文档](signals/cci_decision_V240620.md) | +| `tas_accelerate_V230531` | `"{freq}_D{di}N{n}T{t}_BOLL加速V230531` | BOLL 通道加速信号 | [详细文档](signals/tas_accelerate_V230531.md) | +| `tas_angle_V230802` | `"{freq}_D{di}N{n}T{th}_笔角度V230802` | 笔角度偏离信号 | [详细文档](signals/tas_angle_V230802.md) | +| `tas_atr_V230630` | `"{freq}_D{di}ATR{timeperiod}_波动V230630` | ATR 波动分层信号 | [详细文档](signals/tas_atr_V230630.md) | +| `tas_atr_break_V230424` | `"{freq}_D{di}ATR{timeperiod}T{th}突破_BS辅助V230424` | ATR 通道突破 | [详细文档](signals/tas_atr_break_V230424.md) | +| `tas_boll_bc_V221118` | `"{freq}_D{di}N{n}M{m}L{line}#BOLL{timeperiod}_背驰V221118` | BOLL背驰辅助信号 | [详细文档](signals/tas_boll_bc_V221118.md) | +| `tas_boll_cc_V230312` | `"{freq}_D{di}BOLL{timeperiod}S{nbdev}SP{sp}_BS辅助V230312` | 布林进出场信号 | [详细文档](signals/tas_boll_cc_V230312.md) | +| `tas_boll_power_V221112` | `"{freq}_D{di}BOLL{timeperiod}_强弱V221112` | BOLL强弱分层信号 | [详细文档](signals/tas_boll_power_V221112.md) | +| `tas_boll_vt_V230212` | `"{freq}_D{di}BOLL{timeperiod}S{nbdev}MO{max_overlap}_BS辅助V230212` | BOLL 通道突破进出场信号 | [详细文档](signals/tas_boll_vt_V230212.md) | +| `tas_cci_base_V230402` | `"{freq}_D{di}CCI{timeperiod}#{min_count}#{max_count}_BS辅助V230402` | CCI 极值连续计数信号 | [详细文档](signals/tas_cci_base_V230402.md) | +| `tas_cross_status_V230619` | `"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_金死叉V230619` | 0轴上下金死叉次数 | [详细文档](signals/tas_cross_status_V230619.md) | +| `tas_cross_status_V230624` | `"{freq}_D{di}N{n}MD{md}_MACD交叉数量V230624` | 指定金死叉数值 | [详细文档](signals/tas_cross_status_V230624.md) | +| `tas_cross_status_V230625` | `"{freq}_D{di}N{n}MD{md}J{j}S{s}_MACD交叉数量V230625` | 指定金叉/死叉次数后状态 | [详细文档](signals/tas_cross_status_V230625.md) | +| `tas_dif_layer_V241010` | `"{freq}_DIF分层W{w}T{t}_完全分类V241010` | DIF 三层分类 | [详细文档](signals/tas_dif_layer_V241010.md) | +| `tas_dif_zero_V240612` | `"{freq}_DIF靠近零轴T{t}_BS辅助V240612` | DIF靠近零轴买卖点信号(基于最近一笔) | [详细文档](signals/tas_dif_zero_V240612.md) | +| `tas_dif_zero_V240614` | `"{freq}_DIF靠近零轴W{w}T{t}_BS辅助V240614` | DIF靠近零轴买卖点信号 | [详细文档](signals/tas_dif_zero_V240614.md) | +| `tas_dma_bs_V240608` | `"{freq}_N{n}双均线{t1}#{t2}顺势_BS辅助V240608` | 双均线顺势回调买卖点 | [详细文档](signals/tas_dma_bs_V240608.md) | +| `tas_double_ma_V221203` | `"{freq}_D{di}T{th}#{ma_type}#{timeperiod1}#{timeperiod2}_JX辅助V221203` | 双均线多空强弱信号 | [详细文档](signals/tas_double_ma_V221203.md) | +| `tas_double_ma_V230511` | `"{freq}_D{di}#{ma_type}#{t1}#{t2}_BS辅助V230511` | 双均线反向信号 | [详细文档](signals/tas_double_ma_V230511.md) | +| `tas_double_ma_V240208` | `"{freq}_D{di}N{N}M{M}双均线_BS辅助V240208` | 双均线交叉结构信号 | [详细文档](signals/tas_double_ma_V240208.md) | +| `tas_first_bs_V230217` | `"{freq}_D{di}N{n}#{ma_type}#{timeperiod}_BS1辅助V230217` | 均线结合K线形态的一买一卖辅助 | [详细文档](signals/tas_first_bs_V230217.md) | +| `tas_hlma_V230301` | `"{freq}_D{di}#{ma_type}#{timeperiod}HLMA_BS辅助V230301` | HMA/LMA 多空信号 | [详细文档](signals/tas_hlma_V230301.md) | +| `tas_kdj_base_V221101` | `"{freq}_D{di}K#KDJ{fastk_period}#{slowk_period}#{slowd_period}_KDJ辅助V221101` | KDJ基础辅助信号 | [详细文档](signals/tas_kdj_base_V221101.md) | +| `tas_kdj_evc_V221201` | `"{freq}_D{di}T{th}KDJ{fastk_period}#{slowk_period}#{slowd_period}#{key}值突破{c1}#{c2}_KDJ极值V221201` | KDJ 极值计数信号 | [详细文档](signals/tas_kdj_evc_V221201.md) | +| `tas_kdj_evc_V230401` | `"{freq}_D{di}T{th}KDJ{fastk_period}#{slowk_period}#{slowd_period}#{key}值突破{min_count}#{max_count}_BS辅助V230401` | KDJ 极值计数信号 | [详细文档](signals/tas_kdj_evc_V230401.md) | +| `tas_low_trend_V230627` | `"{freq}_D{di}N{n}TH{th}_趋势230627` | 阴跌/小阳趋势信号 | [详细文档](signals/tas_low_trend_V230627.md) | +| `tas_ma_base_V221101` | `"{freq}_D{di}{ma_type}#{timeperiod}_分类V221101` | 单均线多空与方向信号 | [详细文档](signals/tas_ma_base_V221101.md) | +| `tas_ma_base_V221203` | `"{freq}_D{di}{ma_type}#{timeperiod}T{th}_分类V221203` | 单均线多空与距离分层信号 | [详细文档](signals/tas_ma_base_V221203.md) | +| `tas_ma_base_V230313` | `"{freq}_D{di}#{ma_type}#{timeperiod}MO{max_overlap}_BS辅助V230313` | 单均线开平仓辅助信号(带重叠约束) | [详细文档](signals/tas_ma_base_V230313.md) | +| `tas_ma_cohere_V230512` | `"{freq}_D{di}SMA{ma_seq}_均线系统V230512` | 均线系统粘合/扩散状态 | [详细文档](signals/tas_ma_cohere_V230512.md) | +| `tas_ma_round_V221206` | `"{freq}_D{di}TH{th}#碰{ma_type}#{timeperiod}_BE辅助V221206` | 笔端点触碰均线信号 | [详细文档](signals/tas_ma_round_V221206.md) | +| `tas_ma_system_V230513` | `"{freq}_D{di}SMA{ma_seq}_均线系统V230513` | 均线系统多空排列 | [详细文档](signals/tas_ma_system_V230513.md) | +| `tas_macd_base_V221028` | `"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}#{key}_BS辅助V221028` | MACD/DIF/DEA 多空与方向信号 | [详细文档](signals/tas_macd_base_V221028.md) | +| `tas_macd_base_V230320` | `"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}MO{max_overlap}#{key}_BS辅助V230320` | MACD/DIF/DEA 多空与方向信号(含重叠约束) | [详细文档](signals/tas_macd_base_V230320.md) | +| `tas_macd_bc_V221201` | `"{freq}_D{di}N{n}M{m}#MACD{fastperiod}#{slowperiod}#{signalperiod}_BCV221201` | MACD背驰辅助信号 | [详细文档](signals/tas_macd_bc_V221201.md) | +| `tas_macd_bc_V230803` | `"{freq}_MACD双分型背驰_BS辅助V230803` | 双分型 MACD 背驰信号 | [详细文档](signals/tas_macd_bc_V230803.md) | +| `tas_macd_bc_V230804` | `"{freq}_D{di}MACD背驰_BS辅助V230804` | MACD 黄白线背驰信号 | [详细文档](signals/tas_macd_bc_V230804.md) | +| `tas_macd_bc_V240307` | `"{freq}_D{di}N{n}柱子背驰_BS辅助V240307` | MACD 柱背驰计次信号 | [详细文档](signals/tas_macd_bc_V240307.md) | +| `tas_macd_bc_ubi_V230804` | `"{freq}_MACD背驰_UBI观察V230804` | 未完成笔 MACD 背驰观察 | [详细文档](signals/tas_macd_bc_ubi_V230804.md) | +| `tas_macd_bs1_V230312` | `"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS1辅助V230312` | MACD 辅助一买一卖(笔结构) | [详细文档](signals/tas_macd_bs1_V230312.md) | +| `tas_macd_bs1_V230313` | `"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS1辅助V230313` | MACD 红绿柱第一买卖点 | [详细文档](signals/tas_macd_bs1_V230313.md) | +| `tas_macd_bs1_V230411` | `"{freq}_D{di}T{tha}#{thb}#{thc}_BS1辅助V230411` | MACD DIF 五笔背驰信号 | [详细文档](signals/tas_macd_bs1_V230411.md) | +| `tas_macd_bs1_V230412` | `"{freq}_D{di}T{tha}#{thb}_BS1辅助V230412` | MACD DIF 五笔背驰简化信号 | [详细文档](signals/tas_macd_bs1_V230412.md) | +| `tas_macd_change_V221105` | `"{freq}_D{di}K{n}#MACD{fastperiod}#{slowperiod}#{signalperiod}变色次数_BS辅助V221105` | MACD变色次数信号 | [详细文档](signals/tas_macd_change_V221105.md) | +| `tas_macd_direct_V221106` | `"{freq}_D{di}K#MACD{fastperiod}#{slowperiod}#{signalperiod}方向_BS辅助V221106` | MACD柱方向信号 | [详细文档](signals/tas_macd_direct_V221106.md) | +| `tas_macd_dist_V230408` | `"{freq}_{key}分层W{w}N{n}_BS辅助V230408` | DIF/DEA/MACD等宽分层信号 | [详细文档](signals/tas_macd_dist_V230408.md) | +| `tas_macd_dist_V230409` | `"{freq}_{key}远离W{w}N{n}T{t}_BS辅助V230409` | DIF/DEA/MACD远离零轴信号 | [详细文档](signals/tas_macd_dist_V230409.md) | +| `tas_macd_dist_V230410` | `"{freq}_{key}多空分层W{w}N{n}_BS辅助V230410` | DIF/DEA/MACD多空分层信号 | [详细文档](signals/tas_macd_dist_V230410.md) | +| `tas_macd_first_bs_V221201` | `"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS1辅助V221201` | MACD一买一卖辅助信号 | [详细文档](signals/tas_macd_first_bs_V221201.md) | +| `tas_macd_first_bs_V221216` | `"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS1辅助V221216` | MACD 第一买卖点(扩展版) | [详细文档](signals/tas_macd_first_bs_V221216.md) | +| `tas_macd_power_V221108` | `"{freq}_D{di}K#MACD{fastperiod}#{slowperiod}#{signalperiod}强弱_BS辅助V221108` | MACD强弱分层信号 | [详细文档](signals/tas_macd_power_V221108.md) | +| `tas_macd_second_bs_V221201` | `"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS2辅助V221201` | MACD 第二买卖点 | [详细文档](signals/tas_macd_second_bs_V221201.md) | +| `tas_macd_xt_V221208` | `"{freq}_D{di}K#MACD{fastperiod}#{slowperiod}#{signalperiod}形态_BS辅助V221208` | MACD 柱形态信号 | [详细文档](signals/tas_macd_xt_V221208.md) | +| `tas_rsi_base_V230227` | `"{freq}_D{di}T{th}RSI{timeperiod}_RSI辅助V230227` | RSI超买超卖与方向信号 | [详细文档](signals/tas_rsi_base_V230227.md) | +| `tas_rumi_V230704` | `"{freq}_D{di}F{timeperiod1}S{timeperiod2}R{rumi_window}_BS辅助V230704` | RUMI 零轴切换信号 | [详细文档](signals/tas_rumi_V230704.md) | +| `tas_sar_base_V230425` | `"{freq}_D{di}MO{max_overlap}SAR_BS辅助V230425` | SAR 基础多空信号 | [详细文档](signals/tas_sar_base_V230425.md) | +| `tas_second_bs_V230228` | `"{freq}_D{di}N{n}#{ma_type}#{timeperiod}_BS2辅助V230228` | 均线结合K线形态的二买二卖辅助 | [详细文档](signals/tas_second_bs_V230228.md) | +| `tas_second_bs_V230303` | `"{freq}_D{di}{ma_type}#{timeperiod}_BS2辅助V230303` | 利用笔和均线辅助二买二卖 | [详细文档](signals/tas_second_bs_V230303.md) | +| `tas_slope_V231019` | `"{freq}_D{di}DIF{n}斜率T{th}_BS辅助V231019` | DIF 斜率分位多空 | [详细文档](signals/tas_slope_V231019.md) | diff --git a/.claude/skills/signal-functions/references/signals-vol.md b/.claude/skills/signal-functions/references/signals-vol.md new file mode 100644 index 000000000..6089afa5e --- /dev/null +++ b/.claude/skills/signal-functions/references/signals-vol.md @@ -0,0 +1,13 @@ +# vol 模块信号索引 + +> 源码: `crates/czsc-signals/src/vol.rs` +> 共 6 个信号 + +| 信号名 | 参数模板 | 说明 | 详细文档 | +|--------|----------|------|----------| +| `vol_double_ma_V230214` | `"{freq}_D{di}VOL双均线{ma_type}#{t1}#{t2}_BS辅助V230214` | 成交量双均线多空信号 | [详细文档](signals/vol_double_ma_V230214.md) | +| `vol_gao_di_V221218` | `"{freq}_D{di}K_量柱V221218` | 高量柱与低量柱信号 | [详细文档](signals/vol_gao_di_V221218.md) | +| `vol_single_ma_V230214` | `"{freq}_D{di}VOL#{ma_type}#{timeperiod}_分类V230214` | 单成交量均线多空与方向信号 | [详细文档](signals/vol_single_ma_V230214.md) | +| `vol_ti_suo_V221216` | `"{freq}_D{di}K_量柱V221216` | 梯量与缩量柱信号 | [详细文档](signals/vol_ti_suo_V221216.md) | +| `vol_window_V230731` | `"{freq}_D{di}W{w}M{m}N{n}_窗口能量V230731` | 窗口成交量分层特征 | [详细文档](signals/vol_window_V230731.md) | +| `vol_window_V230801` | `"{freq}_D{di}W{w}_窗口能量V230801` | 窗口成交量先后顺序特征 | [详细文档](signals/vol_window_V230801.md) | diff --git a/.claude/skills/signal-functions/references/signals-xl.md b/.claude/skills/signal-functions/references/signals-xl.md new file mode 100644 index 000000000..3bcf5fe3e --- /dev/null +++ b/.claude/skills/signal-functions/references/signals-xl.md @@ -0,0 +1,14 @@ +# xl 模块信号索引 + +> 源码: `crates/czsc-signals/src/xl.rs` +> 共 7 个信号 + +| 信号名 | 参数模板 | 说明 | 详细文档 | +|--------|----------|------|----------| +| `xl_bar_basis_V240411` | `"{freq}_N{n}_形态V240411` | 吞没形态信号 | [详细文档](signals/xl_bar_basis_V240411.md) | +| `xl_bar_basis_V240412` | `"{freq}_N{n}#TH{th}_形态V240412` | 长蜡烛形态信号 | [详细文档](signals/xl_bar_basis_V240412.md) | +| `xl_bar_position_V240328` | `"{freq}_N{n}_BS辅助V240328` | 相对高低位置识别信号 | [详细文档](signals/xl_bar_position_V240328.md) | +| `xl_bar_trend_V240329` | `"{freq}_N{n}M{m}_十字线反转V240329` | 十字孕线反转信号 | [详细文档](signals/xl_bar_trend_V240329.md) | +| `xl_bar_trend_V240330` | `"{freq}_N{n}M{m}#{ma_type}_双均线过滤V240330` | 双均线过滤信号 | [详细文档](signals/xl_bar_trend_V240330.md) | +| `xl_bar_trend_V240331` | `"{freq}_N{n}_突破信号V240331` | 通道突破信号 | [详细文档](signals/xl_bar_trend_V240331.md) | +| `xl_bar_trend_V240623` | `"{freq}_N{n}通道_突破信号V240623` | 通道突破连续信号 | [详细文档](signals/xl_bar_trend_V240623.md) | diff --git a/.claude/skills/signal-functions/references/signals-zdy_trader.md b/.claude/skills/signal-functions/references/signals-zdy_trader.md new file mode 100644 index 000000000..a234cdd95 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals-zdy_trader.md @@ -0,0 +1,11 @@ +# zdy_trader 模块信号索引 + +> 源码: `crates/czsc-signals/src/zdy_trader.rs` +> 共 4 个信号 + +| 信号名 | 参数模板 | 说明 | 详细文档 | +|--------|----------|------|----------| +| `zdy_stop_loss_V230406` | `` | 笔操作止损逻辑 | [详细文档](signals/zdy_stop_loss_V230406.md) | +| `zdy_take_profit_V230406` | `` | 笔操作止盈逻辑 | [详细文档](signals/zdy_take_profit_V230406.md) | +| `zdy_take_profit_V230407` | `` | 按力度提前止盈 | [详细文档](signals/zdy_take_profit_V230407.md) | +| `zdy_vibrate_V230406` | `` | 中枢震荡短差辅助 | [详细文档](signals/zdy_vibrate_V230406.md) | diff --git a/.claude/skills/signal-functions/references/signals/adtm_up_dw_line_V230603.md b/.claude/skills/signal-functions/references/signals/adtm_up_dw_line_V230603.md new file mode 100644 index 000000000..f51de8f01 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/adtm_up_dw_line_V230603.md @@ -0,0 +1,30 @@ +# adtm_up_dw_line_V230603:ADTM 能量异动多空信号 + +> 模块: `ang.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}M{m}TH{th}_ADTMV230603` + +## 信号逻辑 + +1. 计算 `N` 窗口 `up_sum` 与 `M` 窗口 `dw_sum`; +2. 计算 `adtm = (up_sum - dw_sum) / max(up_sum, dw_sum)`; +3. `up_sum > dw_sum` 或 `adtm > th/10` 判 `看多`; +4. `up_sum < dw_sum` 或 `adtm < th/10` 判 `看空`,否则 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N30M20TH5_ADTMV230603_看多_任意_任意_0')` +- `Signal('60分钟_D1N30M20TH5_ADTMV230603_看空_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `n`:`up_sum` 窗口,默认 `30`; +- `m`:`dw_sum` 窗口,默认 `20`; +- `th`:阈值(除以 10 使用),默认 `5`。 + +## 对齐说明 + +与 Python `adtm_up_dw_line_V230603` 的条件优先级与阈值口径一致。 diff --git a/.claude/skills/signal-functions/references/signals/amv_up_dw_line_V230603.md b/.claude/skills/signal-functions/references/signals/amv_up_dw_line_V230603.md new file mode 100644 index 000000000..4b17d84b0 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/amv_up_dw_line_V230603.md @@ -0,0 +1,28 @@ +# amv_up_dw_line_V230603:AMV 能量多空信号 + +> 模块: `ang.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}M{m}_AMV能量V230603` + +## 信号逻辑 + +1. 计算 `N` 与 `M` 窗口成交额加权均价; +2. 形成 `amv1` 与 `amv2`; +3. `amv1 > amv2` 判 `看多`,否则 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N30M120_AMV能量V230603_看多_任意_任意_0')` +- `Signal('60分钟_D1N30M120_AMV能量V230603_看空_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `n`:短窗口,默认 `30`; +- `m`:长窗口,默认 `120`。 + +## 对齐说明 + +与 Python `amv_up_dw_line_V230603` 的加权均价公式一致。 diff --git a/.claude/skills/signal-functions/references/signals/asi_up_dw_line_V230603.md b/.claude/skills/signal-functions/references/signals/asi_up_dw_line_V230603.md new file mode 100644 index 000000000..935120b03 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/asi_up_dw_line_V230603.md @@ -0,0 +1,28 @@ +# asi_up_dw_line_V230603:ASI 多空信号 + +> 模块: `ang.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}P{p}_ASI多空V230603` + +## 信号逻辑 + +1. 基于最近 `p` 根K线计算 SI 序列并累加得 ASI; +2. 将最新 ASI 与 `p` 窗口 ASI 均值比较; +3. `asi_last > asi_mean` 判 `看多`,否则 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N30P120_ASI多空V230603_看多_任意_任意_0')` +- `Signal('60分钟_D1N30P120_ASI多空V230603_看空_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `n`:SI 公式中的常数项,默认 `30`; +- `p`:窗口长度,默认 `120`。 + +## 对齐说明 + +按 Python `asi_up_dw_line_V230603` 的原始向量公式逐项对齐实现。 diff --git a/.claude/skills/signal-functions/references/signals/bar_accelerate_V221110.md b/.claude/skills/signal-functions/references/signals/bar_accelerate_V221110.md new file mode 100644 index 000000000..ca1e19d18 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_accelerate_V221110.md @@ -0,0 +1,27 @@ +# bar_accelerate_V221110:区间加速走势判定 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}W{window}_加速V221110` + +## 信号逻辑 + +1. 取倒数第 `di` 根截止的最近 `window` 根K线,计算区间最高/最低; +2. 若末根收盘位于区间上20%且阳线占比>=80%,判 `上涨`; +3. 若末根收盘位于区间下20%且阴线占比>=80%,判 `下跌`,否则 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1W13_加速V221110_上涨_任意_任意_0')` +- `Signal('60分钟_D1W13_加速V221110_下跌_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `window`:观察窗口长度,默认 `10`。 + +## 对齐说明 + +收盘位置与阳阴占比阈值对齐 Python `bar_accelerate_V221110`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_accelerate_V221118.md b/.claude/skills/signal-functions/references/signals/bar_accelerate_V221118.md new file mode 100644 index 000000000..e455d9a58 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_accelerate_V221118.md @@ -0,0 +1,29 @@ +# bar_accelerate_V221118:均线偏离加速判定 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}W{window}#{ma_type}#{timeperiod}_加速V221118` + +## 信号逻辑 + +1. 计算窗口内每根 `close - ma` 偏离值; +2. 全部偏离为正,且最后三根偏离值递增,判 `上涨`; +3. 全部偏离为负,且最后三根偏离值递减,判 `下跌`,否则 `其他`。 + +## 信号列表示例 + +- `Signal('日线_D1W13#SMA#10_加速V221118_上涨_任意_任意_0')` +- `Signal('日线_D1W13#SMA#10_加速V221118_下跌_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `window`:观察窗口,默认 `13`; +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod`:均线周期,默认 `10`。 + +## 对齐说明 + +偏离序列与三根单调条件对齐 Python `bar_accelerate_V221118`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_accelerate_V240428.md b/.claude/skills/signal-functions/references/signals/bar_accelerate_V240428.md new file mode 100644 index 000000000..f5913e18c --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_accelerate_V240428.md @@ -0,0 +1,28 @@ +# bar_accelerate_V240428:滚动差分加速判定 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}W{w}T{t}_加速V240428` + +## 信号逻辑 + +1. 计算 `diff = close - close[w]`,取最近300根 `|diff|` 的75分位阈值; +2. 若最新 `|diff|` 超阈且 `diff>0`,窗口内倍量阳线数>=`t` 判 `上涨`; +3. 若最新 `|diff|` 超阈且 `diff<0`,窗口内倍量阴线数>=`t` 判 `下跌`,否则 `其他`。 + +## 信号列表示例 + +- `Signal('日线_D1W21T2_加速V240428_上涨_任意_任意_0')` +- `Signal('日线_D1W21T2_加速V240428_下跌_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `w`:差分窗口,默认 `21`; +- `t`:倍量同向K线最小数量,默认 `1`。 + +## 对齐说明 + +阈值分位与倍量计数口径对齐 Python `bar_accelerate_V240428`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_amount_acc_V230214.md b/.claude/skills/signal-functions/references/signals/bar_amount_acc_V230214.md new file mode 100644 index 000000000..afd3424ac --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_amount_acc_V230214.md @@ -0,0 +1,28 @@ +# bar_amount_acc_V230214:区间累计成交额信号 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}_累计超{t}千万` + +## 信号逻辑 + +1. 取倒数第 `di` 根截止的最近 `n` 根K线; +2. 计算累计成交额 `sum(amount)`; +3. 若大于 `t * 1e7` 判 `是`,否则 `否`。 + +## 信号列表示例 + +- `Signal('日线_D2N5_累计超10千万_是_任意_任意_0')` +- `Signal('日线_D2N5_累计超10千万_否_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `2`; +- `n`:回看K线数,默认 `5`; +- `t`:阈值(千万),默认 `10`。 + +## 对齐说明 + +累计金额阈值判断与 Python `bar_amount_acc_V230214` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/bar_big_solid_V230215.md b/.claude/skills/signal-functions/references/signals/bar_big_solid_V230215.md new file mode 100644 index 000000000..c6aa097f7 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_big_solid_V230215.md @@ -0,0 +1,28 @@ +# bar_big_solid_V230215:窗口最大实体中位多空信号 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}_MIDV230215` + +## 信号逻辑 + +1. 在窗口内找到实体最大K线; +2. 取该K线实体中位价 `mid`; +3. 最新收盘价高于 `mid` 判 `看多`,否则 `看空`; +4. 最大实体K线按方向标注 `大阳/大阴`。 + +## 信号列表示例 + +- `Signal('日线_D1N20_MIDV230215_看多_大阳_任意_0')` +- `Signal('日线_D1N20_MIDV230215_看空_大阴_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `n`:窗口长度,默认 `20`。 + +## 对齐说明 + +最大实体与中位价定义对齐 Python `bar_big_solid_V230215`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_bpm_V230227.md b/.claude/skills/signal-functions/references/signals/bar_bpm_V230227.md new file mode 100644 index 000000000..ee6bb45c8 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_bpm_V230227.md @@ -0,0 +1,28 @@ +# bar_bpm_V230227:绝对动量分层 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}T{th}_绝对动量V230227` + +## 信号逻辑 + +1. 取最近 `n` 根,计算区间 BP:`(last_close/first_open-1)*10000`; +2. `bp>0` 时,`bp>th` 判 `超强` 否则 `强势`; +3. `bp<=0` 时,`|bp|>th` 判 `超弱` 否则 `弱势`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N20T1000_绝对动量V230227_强势_任意_任意_0')` +- `Signal('60分钟_D1N20T1000_绝对动量V230227_超弱_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `n`:窗口长度,默认 `20`; +- `th`:强弱阈值(BP),默认 `1000`。 + +## 对齐说明 + +分层规则与 Python `bar_bpm_V230227` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/bar_break_V240428.md b/.claude/skills/signal-functions/references/signals/bar_break_V240428.md new file mode 100644 index 000000000..19b92deba --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_break_V240428.md @@ -0,0 +1,27 @@ +# bar_break_V240428:收盘极值突破 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}W{w}_事件V240428` + +## 信号逻辑 + +1. 在窗口内比较末根收盘与前序最高/最低; +2. 收盘高于前序最高判 `收盘新高`; +3. 收盘低于前序最低判 `收盘新低`,否则 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1W20_事件V240428_收盘新高_任意_任意_0')` +- `Signal('60分钟_D1W20_事件V240428_收盘新低_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `w`:窗口长度,默认 `20`。 + +## 对齐说明 + +极值比较区间与 Python `bar_break_V240428` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/bar_channel_V230508.md b/.claude/skills/signal-functions/references/signals/bar_channel_V230508.md new file mode 100644 index 000000000..58d6da3d0 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_channel_V230508.md @@ -0,0 +1,28 @@ +# bar_channel_V230508:窄幅通道方向判定 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}M{m}_通道V230507` + +## 信号逻辑 + +1. 窗口内每根K线涨跌幅需不超过 `m` BP; +2. 对高点和低点分别做一元线性拟合,要求 `r2 > 0.8`; +3. 双斜率同向且右侧极值确认,判 `看多/看空`,否则 `其他`。 + +## 信号列表示例 + +- `Signal('日线_D1M600_通道V230507_看多_任意_任意_0')` +- `Signal('日线_D1M600_通道V230507_看空_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `n`:窗口长度,默认 `20`; +- `m`:单根波动阈值(BP),默认 `600`。 + +## 对齐说明 + +拟合阈值和右侧极值规则对齐 Python `bar_channel_V230508`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_classify_V240606.md b/.claude/skills/signal-functions/references/signals/bar_classify_V240606.md new file mode 100644 index 000000000..8ca29af0b --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_classify_V240606.md @@ -0,0 +1,26 @@ +# bar_classify_V240606:单根K线收盘位置分类 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}收盘位置_分类V240606` + +## 信号逻辑 + +1. 将K线高低区间三等分; +2. 收盘落在上三分之一判 `高位`; +3. 收盘落在下三分之一判 `低位`,否则 `中间`。 + +## 信号列表示例 + +- `Signal('60分钟_D1收盘位置_分类V240606_高位_任意_任意_0')` +- `Signal('60分钟_D1收盘位置_分类V240606_中间_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`。 + +## 对齐说明 + +三分位阈值与 Python `bar_classify_V240606` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/bar_classify_V240607.md b/.claude/skills/signal-functions/references/signals/bar_classify_V240607.md new file mode 100644 index 000000000..e10a1ebe4 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_classify_V240607.md @@ -0,0 +1,26 @@ +# bar_classify_V240607:两根K线收盘位置分类 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}K2收盘位置_分类V240607` + +## 信号逻辑 + +1. 取最近两根K线(截至 `di`); +2. 第二根收盘高于第一根最高判 `看多`; +3. 第二根收盘低于第一根最低判 `看空`,否则 `中性`。 + +## 信号列表示例 + +- `Signal('60分钟_D1K2收盘位置_分类V240607_看多_任意_任意_0')` +- `Signal('60分钟_D1K2收盘位置_分类V240607_中性_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`。 + +## 对齐说明 + +两根K线比较规则与 Python `bar_classify_V240607` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/bar_decision_V240608.md b/.claude/skills/signal-functions/references/signals/bar_decision_V240608.md new file mode 100644 index 000000000..efabce551 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_decision_V240608.md @@ -0,0 +1,28 @@ +# bar_decision_V240608:放量反向决策区 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_W{w}N{n}Q{q}放量_决策区域V240608` + +## 信号逻辑 + +1. 在最近 `n` 根中取成交量最大的3根; +2. 若三者成交量都大于最近 `w` 根的 `q` 分位; +3. 且 `n` 窗口净涨则 `看空`,净跌则 `看多`。 + +## 信号列表示例 + +- `Signal('60分钟_W300N10Q80放量_决策区域V240608_看空_任意_任意_0')` +- `Signal('60分钟_W300N10Q80放量_决策区域V240608_看多_任意_任意_0')` + +## 参数说明 + +- `w`:长窗口,默认 `300`; +- `n`:短窗口,默认 `10`; +- `q`:成交量分位阈值(0-100),默认 `80`。 + +## 对齐说明 + +分位阈值与反向判定对齐 Python `bar_decision_V240608`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_decision_V240616.md b/.claude/skills/signal-functions/references/signals/bar_decision_V240616.md new file mode 100644 index 000000000..7e47397a7 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_decision_V240616.md @@ -0,0 +1,27 @@ +# bar_decision_V240616:新高新低后的强弱决策 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_W{w}N{n}强弱_决策区域V240616` + +## 信号逻辑 + +1. 用 `di=n` 的 `w` 窗口给出历史新高/新低参考; +2. 在最近 `n` 根中过滤出大实体K线并按顺序检查其右侧K线; +3. 新高后转弱判 `看空`,新低后转强判 `看多`。 + +## 信号列表示例 + +- `Signal('60分钟_W100N5强弱_决策区域V240616_看空_任意_任意_0')` +- `Signal('60分钟_W100N5强弱_决策区域V240616_看多_任意_任意_0')` + +## 参数说明 + +- `w`:参考窗口,默认 `100`; +- `n`:决策窗口,默认 `5`。 + +## 对齐说明 + +候选筛选与右侧确认流程对齐 Python `bar_decision_V240616`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_dual_thrust_V230403.md b/.claude/skills/signal-functions/references/signals/bar_dual_thrust_V230403.md new file mode 100644 index 000000000..77b9af7af --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_dual_thrust_V230403.md @@ -0,0 +1,29 @@ +# bar_dual_thrust_V230403:Dual Thrust 通道突破 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}通道突破#{N}#{K1}#{K2}_BS辅助V230403` + +## 信号逻辑 + +1. 用前 `N+1` 根计算 `HH/HC/LC/LL` 与 `Range=max(HH-LC, HC-LL)`; +2. 构造当根上/下轨:`open + Range*K1%`、`open - Range*K2%`; +3. 收盘上破判 `看多`,下破判 `看空`,否则 `其他`。 + +## 信号列表示例 + +- `Signal('日线_D1通道突破#5#20#20_BS辅助V230403_看多_任意_任意_0')` +- `Signal('日线_D1通道突破#5#20#20_BS辅助V230403_看空_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `N`:回看天数,默认 `5`; +- `K1`:上轨系数(百分比),默认 `20`; +- `K2`:下轨系数(百分比),默认 `20`。 + +## 对齐说明 + +通道计算与突破判断对齐 Python `bar_dual_thrust_V230403`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_eight_V230702.md b/.claude/skills/signal-functions/references/signals/bar_eight_V230702.md new file mode 100644 index 000000000..2ad338d64 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_eight_V230702.md @@ -0,0 +1,27 @@ +# bar_eight_V230702:8K 走势分类 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}#8K_走势分类V230702` + +## 信号逻辑 + +1. 统计8K中的三连K重叠中枢; +2. 无中枢时输出 `无中枢上涨/无中枢下跌`; +3. 双中枢满足不重叠时输出 `双中枢上涨/双中枢下跌`; +4. 其余按前三根是否出现极值分为 `强平衡/弱平衡/转折平衡`。 + +## 信号列表示例 + +- `Signal('30分钟_D1#8K_走势分类V230702_双中枢上涨_任意_任意_0')` +- `Signal('30分钟_D1#8K_走势分类V230702_转折平衡市_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`。 + +## 对齐说明 + +中枢判定与分类分支顺序对齐 Python `bar_eight_V230702`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_end_V221211.md b/.claude/skills/signal-functions/references/signals/bar_end_V221211.md new file mode 100644 index 000000000..2e7b0b270 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_end_V221211.md @@ -0,0 +1,26 @@ +# bar_end_V221211:判断大周期K线是否闭合 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_{freq1}结束_BS辅助221211` + +## 信号逻辑 + +1. 以当前基础周期 `freq` 与目标分钟周期 `freq1` 计算当前K线对应结束时间; +2. 从最新K线向前统计同属该结束时间的连续数量 `i`; +3. 若 `end_time == last_dt` 判定 `闭合`,否则判定 `未闭{i}`。 + +## 信号列表示例 + +- `Signal('15分钟_60分钟结束_BS辅助221211_闭合_任意_任意_0')` +- `Signal('15分钟_60分钟结束_BS辅助221211_未闭2_任意_任意_0')` + +## 参数说明 + +- `freq1`:目标分钟周期,默认 `60分钟`。 + +## 对齐说明 + +闭合/未闭计数语义与 Python `bar_end_V221211` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/bar_fake_break_V230204.md b/.claude/skills/signal-functions/references/signals/bar_fake_break_V230204.md new file mode 100644 index 000000000..d1c42d05e --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_fake_break_V230204.md @@ -0,0 +1,28 @@ +# bar_fake_break_V230204:区间假突破判定 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}M{m}_假突破V230204` + +## 信号逻辑 + +1. 在最近 `N` 根内寻找滑动 `M` 窗口重叠中枢; +2. 阳线末根创新高且“跌破DD后拉回”,判 `看多`; +3. 阴线末根创新低且“突破GG后回落”,判 `看空`,否则 `其他`。 + +## 信号列表示例 + +- `Signal('15分钟_D1N20M5_假突破_看多_任意_任意_0')` +- `Signal('15分钟_D1N20M5_假突破_看空_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `n`:观察窗口,默认 `20`; +- `m`:中枢滑窗,默认 `5`。 + +## 对齐说明 + +中枢重叠与真假突破条件对齐 Python `bar_fake_break_V230204`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_fang_liang_break_V221216.md b/.claude/skills/signal-functions/references/signals/bar_fang_liang_break_V221216.md new file mode 100644 index 000000000..922bbe5f4 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_fang_liang_break_V221216.md @@ -0,0 +1,29 @@ +# bar_fang_liang_break_V221216:放量突破与缩量回踩 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}TH{th}#{ma_type}#{timeperiod}_突破V221216` + +## 信号逻辑 + +1. 计算指定均线,检查末根是否放量且站上均线,判 `放量突破`; +2. 检查末根是否缩量且收盘不破均线,且前序收盘与均线距离在阈值内,判 `缩量回踩`; +3. 在窗口长度 `5~9` 中依次尝试,首次出现突破即返回。 + +## 信号列表示例 + +- `Signal('15分钟_D1TH300#SMA#233_突破V221216_放量突破_缩量回踩_任意_0')` +- `Signal('15分钟_D1TH300#SMA#233_突破V221216_其他_其他_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `th`:回踩距离阈值(BP),默认 `300`; +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod`:均线周期,默认 `233`。 + +## 对齐说明 + +窗口扫描与两阶段条件对齐 Python `bar_fang_liang_break_V221216`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_limit_down_V230525.md b/.claude/skills/signal-functions/references/signals/bar_limit_down_V230525.md new file mode 100644 index 000000000..dbba05766 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_limit_down_V230525.md @@ -0,0 +1,27 @@ +# bar_limit_down_V230525:跌停后反包阳线 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_跌停后无下影线长实体阳线_短线V230525` + +## 信号逻辑 + +1. 仅日线级别; +2. 前一日近似跌停:`low==closeopen && solid>2*upper && close/open>1.07`; +4. 且当日最低低于前日最低,判 `满足`。 + +## 信号列表示例 + +- `Signal('日线_跌停后无下影线长实体阳线_短线V230525_满足_任意_任意_0')` +- `Signal('日线_跌停后无下影线长实体阳线_短线V230525_其他_任意_任意_0')` + +## 参数说明 + +- 无额外参数。 + +## 对齐说明 + +条件组合与 Python `bar_limit_down_V230525` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/bar_mean_amount_V221112.md b/.claude/skills/signal-functions/references/signals/bar_mean_amount_V221112.md new file mode 100644 index 000000000..9bf2b4b00 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_mean_amount_V221112.md @@ -0,0 +1,29 @@ +# bar_mean_amount_V221112:区间均额分类信号 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}K{n}B均额_{th1}至{th2}千万` + +## 信号逻辑 + +1. 取倒数第 `di` 根截止的最近 `n` 根K线; +2. 计算平均成交额 `m`; +3. 若 `m/1e7` 在 `[th1, th2]` 判 `是`,否则判 `否`。 + +## 信号列表示例 + +- `Signal('60分钟_D1K10B均额_1至4千万_是_任意_任意_0')` +- `Signal('60分钟_D1K10B均额_1至4千万_否_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `n`:样本长度,默认 `10`; +- `th1`:下限(千万),默认 `1`; +- `th2`:上限(千万),默认 `4`。 + +## 对齐说明 + +均额口径与 Python `bar_mean_amount_V221112` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/bar_operate_span_V221111.md b/.claude/skills/signal-functions/references/signals/bar_operate_span_V221111.md new file mode 100644 index 000000000..44afc6d9a --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_operate_span_V221111.md @@ -0,0 +1,26 @@ +# bar_operate_span_V221111:日内时间区间过滤 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_T{t1}#{t2}_时间区间V221111` + +## 信号逻辑 + +1. 读取最新K线时间 `HHMM`; +2. 若 `t1 <= HHMM <= t2` 判定 `是`,否则判定 `否`。 + +## 信号列表示例 + +- `Signal('60分钟_T0935#1450_时间区间_是_任意_任意_0')` +- `Signal('60分钟_T0935#1450_时间区间_否_任意_任意_0')` + +## 参数说明 + +- `t1`:起始时间(`HHMM`),默认 `0935`; +- `t2`:结束时间(`HHMM`),默认 `1450`。 + +## 对齐说明 + +边界包含比较与 Python `bar_operate_span_V221111` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/bar_plr_V240427.md b/.claude/skills/signal-functions/references/signals/bar_plr_V240427.md new file mode 100644 index 000000000..14bbe26c8 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_plr_V240427.md @@ -0,0 +1,29 @@ +# bar_plr_V240427:盈亏比约束 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}W{w}T{t}M{m}_盈亏比V240427` + +## 信号逻辑 + +1. `多头`:以窗口最低点前的最高点与当前收盘计算盈亏比; +2. `空头`:以窗口最高点前的最低点与当前收盘计算盈亏比; +3. `plr > t/10` 判 `满足`,否则 `不满足`。 + +## 信号列表示例 + +- `Signal('60分钟_D1W60T20M多头_盈亏比V240427_满足_任意_任意_0')` +- `Signal('60分钟_D1W60T20M空头_盈亏比V240427_不满足_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `w`:窗口长度,默认 `60`; +- `t`:阈值(`t/10`),默认 `20`; +- `m`:方向,`多头/空头`,默认 `多头`。 + +## 对齐说明 + +盈亏比定义与阈值比较对齐 Python `bar_plr_V240427`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_polyfit_V240428.md b/.claude/skills/signal-functions/references/signals/bar_polyfit_V240428.md new file mode 100644 index 000000000..92549b743 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_polyfit_V240428.md @@ -0,0 +1,27 @@ +# bar_polyfit_V240428:一阶二阶拟合分类 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}W{w}_分类V240428` + +## 信号逻辑 + +1. 对窗口收盘价做一阶拟合取斜率 `p1`; +2. 做二阶拟合取二次项系数 `p2`; +3. 按 `p1/p2` 符号组合输出 `加速上涨/减速上涨/加速下跌/减速下跌`。 + +## 信号列表示例 + +- `Signal('60分钟_D1W20_分类V240428_加速上涨_任意_任意_0')` +- `Signal('60分钟_D1W20_分类V240428_减速下跌_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `w`:窗口长度,默认 `20`。 + +## 对齐说明 + +一二阶系数组合分类对齐 Python `bar_polyfit_V240428`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_r_breaker_V230326.md b/.claude/skills/signal-functions/references/signals/bar_r_breaker_V230326.md new file mode 100644 index 000000000..e6beeabb2 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_r_breaker_V230326.md @@ -0,0 +1,26 @@ +# bar_r_breaker_V230326:RBreaker 价格位判定 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_RBreaker_BS辅助V230326` + +## 信号逻辑 + +1. 用前一根K线 `H/C/L` 计算突破位、观察位、反转位; +2. 当前收盘突破上/下轨判 `趋势做多/趋势做空`; +3. 满足观察后反转条件判 `反转做多/反转做空`,否则 `其他`。 + +## 信号列表示例 + +- `Signal('日线_RBreaker_BS辅助V230326_做多_趋势_任意_0')` +- `Signal('日线_RBreaker_BS辅助V230326_做空_反转_任意_0')` + +## 参数说明 + +- 无额外参数。 + +## 对齐说明 + +六价位与判定顺序对齐 Python `bar_r_breaker_V230326`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_reversal_V230227.md b/.claude/skills/signal-functions/references/signals/bar_reversal_V230227.md new file mode 100644 index 000000000..dae343606 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_reversal_V230227.md @@ -0,0 +1,27 @@ +# bar_reversal_V230227:末根反转迹象判定 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}A{avg_bp}_反转V230227` + +## 信号逻辑 + +1. 以末根K线形态(阴线/长上影、阳线/长下影)确定反转方向候选; +2. 左侧13根满足 3/5/8 平均涨跌幅阈值,或 13 连阳/连阴,触发反向信号; +3. 输出 `看多/看空/其他`。 + +## 信号列表示例 + +- `Signal('15分钟_D1A300_反转V230227_看多_任意_任意_0')` +- `Signal('15分钟_D1A300_反转V230227_看空_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `avg_bp`:平均单根涨跌幅阈值(BP),默认 `300`。 + +## 对齐说明 + +触发条件与优先级对齐 Python `bar_reversal_V230227`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_section_momentum_V221112.md b/.claude/skills/signal-functions/references/signals/bar_section_momentum_V221112.md new file mode 100644 index 000000000..ce8eb98ee --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_section_momentum_V221112.md @@ -0,0 +1,29 @@ +# bar_section_momentum_V221112:区间动量强弱与波动 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}K{n}B_阈值{th}BPV221112` + +## 信号逻辑 + +1. 区间 BP:`(last_close/first_open-1)*10000`; +2. 区间波动:`(max_high/min_low-1)*10000`; +3. `v1`:`上涨/下跌`;`v2`:`强势/弱势`(`|bp|>=th`); +4. `v3`:`高波动/低波动`(`|wave|/|bp| >= 3`)。 + +## 信号列表示例 + +- `Signal('60分钟_D1K10B_阈值100BPV221112_上涨_强势_高波动_0')` +- `Signal('60分钟_D1K10B_阈值100BPV221112_下跌_弱势_低波动_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `n`:窗口长度,默认 `10`; +- `th`:强弱阈值(BP),默认 `100`。 + +## 对齐说明 + +三段分类与 Python `bar_section_momentum_V221112` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/bar_shuang_fei_V230507.md b/.claude/skills/signal-functions/references/signals/bar_shuang_fei_V230507.md new file mode 100644 index 000000000..c02912399 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_shuang_fei_V230507.md @@ -0,0 +1,26 @@ +# bar_shuang_fei_V230507:双飞涨停形态 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}双飞_短线V230507` + +## 信号逻辑 + +1. 前天近似涨停、昨天大阴回撤、今天再度强势上涨; +2. 且今天收盘突破昨天高点,判 `看多`; +3. 不满足返回 `其他`。 + +## 信号列表示例 + +- `Signal('日线_D1双飞_短线V230507_看多_任意_任意_0')` +- `Signal('日线_D1双飞_短线V230507_其他_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`。 + +## 对齐说明 + +三日组合条件对齐 Python `bar_shuang_fei_V230507`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_single_V230214.md b/.claude/skills/signal-functions/references/signals/bar_single_V230214.md new file mode 100644 index 000000000..1c6e03a6b --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_single_V230214.md @@ -0,0 +1,28 @@ +# bar_single_V230214:单K状态信号 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}T{t}_状态V230214` + +## 信号逻辑 + +1. 倒数第 `di` 根K线,按 `close/open` 判 `阳线/阴线`; +2. 若 `solid > (upper+lower)*t/10` 判 `长实体`; +3. 若 `upper > (solid+lower)*t/10` 判 `长上影`; +4. 若 `lower > (solid+upper)*t/10` 判 `长下影`,否则 `其他`。 + +## 信号列表示例 + +- `Signal('日线_D1T10_状态V230214_阳线_长实体_任意_0')` +- `Signal('日线_D1T10_状态V230214_阴线_长上影_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `t`:长实体/长影阈值(/10),默认 `10`。 + +## 对齐说明 + +分类阈值与 Python `bar_single_V230214` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/bar_single_V230506.md b/.claude/skills/signal-functions/references/signals/bar_single_V230506.md new file mode 100644 index 000000000..5161f4a27 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_single_V230506.md @@ -0,0 +1,24 @@ +# bar_single_V230506:单K趋势分层信号 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}单K趋势N{n}_BS辅助V230506` + +## 信号逻辑 + +1. 取截止到倒数第 `di` 根的最近 100 根K线; +2. 计算每根K线因子 `(close-open)/(open*vol)`; +3. 参考 Python `pd.cut(..., n)` 将末根因子分层,输出 `第1层 ~ 第n层`; +4. 若样本不足或存在 `open=0/vol=0`,返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1单K趋势N5_BS辅助V230506_第3层_任意_任意_0')` +- `Signal('60分钟_D1单K趋势N5_BS辅助V230506_其他_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `n`:分层数量,默认 `5`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_td9_V240616.md b/.claude/skills/signal-functions/references/signals/bar_td9_V240616.md new file mode 100644 index 000000000..b1f807bca --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_td9_V240616.md @@ -0,0 +1,26 @@ +# bar_td9_V240616:神奇九转计数 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_神奇九转N{n}_BS辅助V240616` + +## 信号逻辑 + +1. 当前收盘与4根前收盘比较,得到 `1/-1/0`; +2. 统计末端连续同号个数; +3. 连续 `>=n` 个 `1` 输出 `卖点`,连续 `>=n` 个 `-1` 输出 `买点`。 + +## 信号列表示例 + +- `Signal('60分钟_神奇九转N9_BS辅助V240616_卖点_9转_任意_0')` +- `Signal('60分钟_神奇九转N9_BS辅助V240616_买点_9转_任意_0')` + +## 参数说明 + +- `n`:连续计数阈值,默认 `9`。 + +## 对齐说明 + +计数窗口与买卖点定义对齐 Python `bar_td9_V240616`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_time_V230327.md b/.claude/skills/signal-functions/references/signals/bar_time_V230327.md new file mode 100644 index 000000000..f6b6470d9 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_time_V230327.md @@ -0,0 +1,26 @@ +# bar_time_V230327:日内时间分段信号 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_日内时间_分段V230327` + +## 信号逻辑 + +1. 仅支持 `30分钟/60分钟` 周期; +2. 取最近 100 根K线的 `HH:MM` 去重并排序; +3. 输出当前K线时间在分段序列中的位置:`第{n}段`。 + +## 信号列表示例 + +- `Signal('60分钟_日内时间_分段V230327_第1段_任意_任意_0')` +- `Signal('60分钟_日内时间_分段V230327_第4段_任意_任意_0')` + +## 参数说明 + +- 无额外参数。 + +## 对齐说明 + +分段生成与 Python `bar_time_V230327` 的排序与编号口径一致。 diff --git a/.claude/skills/signal-functions/references/signals/bar_tnr_V230629.md b/.claude/skills/signal-functions/references/signals/bar_tnr_V230629.md new file mode 100644 index 000000000..2bdd06a56 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_tnr_V230629.md @@ -0,0 +1,27 @@ +# bar_tnr_V230629:TNR 分层信号 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}TNR{timeperiod}_趋势V230629` + +## 信号逻辑 + +1. 计算每根K线 TNR 值; +2. 取最近100个 TNR 做 `qcut(10)`; +3. 输出末根所在层:`第{n}层`。 + +## 信号列表示例 + +- `Signal('15分钟_D1TNR14_趋势V230629_第7层_任意_任意_0')` +- `Signal('15分钟_D1TNR14_趋势V230629_第2层_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `timeperiod`:TNR周期,默认 `14`。 + +## 对齐说明 + +分层逻辑与 `duplicates='drop'` 行为对齐 Python `bar_tnr_V230629`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_tnr_V230630.md b/.claude/skills/signal-functions/references/signals/bar_tnr_V230630.md new file mode 100644 index 000000000..bbb5f8360 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_tnr_V230630.md @@ -0,0 +1,28 @@ +# bar_tnr_V230630:TNR 噪音变化判定 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}TNR{timeperiod}K{k}_趋势V230630` + +## 信号逻辑 + +1. 计算 TNR:`|close_t-close_{t-n}| / sum(|diff(close)|)`; +2. 取最近 `k` 根 TNR 均值,与当前 TNR 比较; +3. 当前值大于均值判 `噪音减少`,否则判 `噪音增加`。 + +## 信号列表示例 + +- `Signal('15分钟_D1TNR14K3_趋势V230630_噪音减少_任意_任意_0')` +- `Signal('15分钟_D1TNR14K3_趋势V230630_噪音增加_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `timeperiod`:TNR周期,默认 `14`; +- `k`:均值窗口,默认 `3`。 + +## 对齐说明 + +TNR与噪音方向定义对齐 Python `bar_tnr_V230630`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_trend_V240209.md b/.claude/skills/signal-functions/references/signals/bar_trend_V240209.md new file mode 100644 index 000000000..d6f152d86 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_trend_V240209.md @@ -0,0 +1,27 @@ +# bar_trend_V240209:趋势跟踪结构判定 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{N}趋势跟踪_BS辅助V240209` + +## 信号逻辑 + +1. 在窗口内定位最高点和最低点,结合其先后顺序选择多头或空头分支; +2. 右侧结构满足 `5 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}三K加速_裸K形态V230506` + +## 信号逻辑 + +1. 取倒数第 `di` 根开始的最近3根K线; +2. 三根连续阳线判定 `三连涨`,若高低点依次抬升判定 `新高涨`; +3. 三根连续阴线判定 `三连跌`,若高低点依次下降判定 `新低跌`; +4. 若已形成形态,再按成交量关系细分为 `依次放量/依次缩量/量柱无序`; +5. 数据不足时返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1三K加速_裸K形态V230506_新高涨_依次放量_任意_0')` +- `Signal('60分钟_D1三K加速_裸K形态V230506_三连跌_量柱无序_任意_0')` +- `Signal('60分钟_D1三K加速_裸K形态V230506_其他_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_vol_bs1_V230224.md b/.claude/skills/signal-functions/references/signals/bar_vol_bs1_V230224.md new file mode 100644 index 000000000..df91d4ca5 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_vol_bs1_V230224.md @@ -0,0 +1,27 @@ +# bar_vol_bs1_V230224:量价高低点辅助 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}量价_BS1辅助V230224` + +## 信号逻辑 + +1. 窗口末根创新高且上影显著、成交额远高于均值,判 `看空`; +2. 窗口末根创新低且下影显著、成交额远低于均值,判 `看多`; +3. 否则返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N20量价_BS1辅助V230224_看空_任意_任意_0')` +- `Signal('60分钟_D1N20量价_BS1辅助V230224_看多_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `n`:窗口长度,默认 `20`。 + +## 对齐说明 + +量价条件阈值与 Python `bar_vol_bs1_V230224` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/bar_vol_grow_V221112.md b/.claude/skills/signal-functions/references/signals/bar_vol_grow_V221112.md new file mode 100644 index 000000000..42d715443 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_vol_grow_V221112.md @@ -0,0 +1,27 @@ +# bar_vol_grow_V221112:成交量放大信号 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}K{n}B_放量V221112` + +## 信号逻辑 + +1. 取倒数第 `di` 根及其前 `n` 根,共 `n+1` 根K线; +2. 计算前 `n` 根平均成交量 `mean_vol`; +3. 若当前量在 `[2*mean_vol, 4*mean_vol]`,判 `是`,否则 `否`。 + +## 信号列表示例 + +- `Signal('60分钟_D2K5B_放量V221112_是_任意_任意_0')` +- `Signal('60分钟_D2K5B_放量V221112_否_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `2`; +- `n`:回看K线数量,默认 `5`。 + +## 对齐说明 + +判定区间与 Python `bar_vol_grow_V221112` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/bar_volatility_V241013.md b/.claude/skills/signal-functions/references/signals/bar_volatility_V241013.md new file mode 100644 index 000000000..4ef1c9071 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_volatility_V241013.md @@ -0,0 +1,27 @@ +# bar_volatility_V241013:波动率三层分类 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_波动率分层W{w}N{n}_完全分类V241013` + +## 信号逻辑 + +1. 定义 `volatility_n = 最近n根收盘最大值-最小值`; +2. 对最近 `w` 根缓存值做三分位分层; +3. 末根分层输出 `低波动/中波动/高波动`。 + +## 信号列表示例 + +- `Signal('60分钟_波动率分层W200N10_完全分类V241013_低波动_任意_任意_0')` +- `Signal('60分钟_波动率分层W200N10_完全分类V241013_高波动_任意_任意_0')` + +## 参数说明 + +- `w`:分层窗口,默认 `200`; +- `n`:波动率窗口,默认 `10`。 + +## 对齐说明 + +缓存写入与 `qcut` 退化行为对齐 Python `bar_volatility_V241013`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_weekday_V230328.md b/.claude/skills/signal-functions/references/signals/bar_weekday_V230328.md new file mode 100644 index 000000000..22335f5a8 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_weekday_V230328.md @@ -0,0 +1,25 @@ +# bar_weekday_V230328:周内时间分段信号 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_周内时间_分段V230328` + +## 信号逻辑 + +1. 当样本数量不足 20 根时返回 `其他`; +2. 否则将最新K线日期按 `weekday` 映射到 `周一~周日`。 + +## 信号列表示例 + +- `Signal('60分钟_周内时间_分段V230328_周一_任意_任意_0')` +- `Signal('60分钟_周内时间_分段V230328_周五_任意_任意_0')` + +## 参数说明 + +- 无额外参数。 + +## 对齐说明 + +weekday 映射表与 Python `bar_weekday_V230328` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/bar_window_ps_V230731.md b/.claude/skills/signal-functions/references/signals/bar_window_ps_V230731.md new file mode 100644 index 000000000..4689720e2 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_window_ps_V230731.md @@ -0,0 +1,29 @@ +# bar_window_ps_V230731:支撑压力位分位特征 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_W{w}M{m}N{n}L{l}_支撑压力位V230731` + +## 信号逻辑 + +1. 用最近 `n` 笔高低点构造压力线与支撑线; +2. 计算收盘在区间中的位置 `pct=(close-L)/(H-L)`; +3. 对最近 `m` 个 `pct` 做 `qcut(l)`,输出最近 `w` 根的压力/支撑层与当前层。 + +## 信号列表示例 + +- `Signal('15分钟_W5M40N8L5_支撑压力位V230731_压力N5_支撑N3_当前N4_0')` +- `Signal('15分钟_W5M40N8L5_支撑压力位V230731_压力N4_支撑N1_当前N2_0')` + +## 参数说明 + +- `w`:观察窗口,默认 `5`; +- `m`:分位样本长度,默认 `40`; +- `n`:笔窗口长度,默认 `8`; +- `l`:分层数量,默认 `5`。 + +## 对齐说明 + +参数约束与分位定义对齐 Python `bar_window_ps_V230731`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_window_ps_V230801.md b/.claude/skills/signal-functions/references/signals/bar_window_ps_V230801.md new file mode 100644 index 000000000..2849c0c4c --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_window_ps_V230801.md @@ -0,0 +1,27 @@ +# bar_window_ps_V230801:支撑压力位窗口极值 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_N{n}W{w}_支撑压力位V230801` + +## 信号逻辑 + +1. 基于最近 `n` 笔和当前未完成笔计算压力/支撑区间; +2. 将最近 `w` 根收盘映射到 `0~9` 分位整数; +3. 输出窗口最大/最小/当前分位。 + +## 信号列表示例 + +- `Signal('60分钟_N8W5_支撑压力位V230801_最大N7_最小N3_当前N5_0')` +- `Signal('60分钟_N8W5_支撑压力位V230801_最大N4_最小N0_当前N2_0')` + +## 参数说明 + +- `w`:观察窗口,默认 `5`; +- `n`:笔窗口长度,默认 `8`。 + +## 对齐说明 + +`ubi` 口径与整数分位映射对齐 Python `bar_window_ps_V230801`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_window_std_V230731.md b/.claude/skills/signal-functions/references/signals/bar_window_std_V230731.md new file mode 100644 index 000000000..678dd1c59 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_window_std_V230731.md @@ -0,0 +1,29 @@ +# bar_window_std_V230731:窗口波动分层特征 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}W{w}M{m}N{n}_窗口波动V230731` + +## 信号逻辑 + +1. 计算每根K线的 `STD20`(前20收盘标准差); +2. 取最近 `m` 个 `STD20` 做 `qcut(n)` 分层; +3. 输出最近 `w` 根中的最大层和最小层。 + +## 信号列表示例 + +- `Signal('60分钟_D1W5M100N10_窗口波动V230731_高波N8_低波N6_任意_0')` +- `Signal('60分钟_D1W5M100N10_窗口波动V230731_高波N4_低波N3_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `w`:观察窗口,默认 `5`; +- `m`:分层样本长度,默认 `100`; +- `n`:分层数量,默认 `10`。 + +## 对齐说明 + +STD20口径与 `qcut(..., duplicates='drop')` 对齐 Python `bar_window_std_V230731`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_zdf_V221203.md b/.claude/skills/signal-functions/references/signals/bar_zdf_V221203.md new file mode 100644 index 000000000..827e007fc --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_zdf_V221203.md @@ -0,0 +1,28 @@ +# bar_zdf_V221203:单根涨跌幅区间信号 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}{mode}_{t1}至{t2}` + +## 信号逻辑 + +1. 读取倒数第 `di` 根及其前一根K线; +2. `mode=ZF` 使用涨幅 `close/prev_close-1`,`mode=DF` 使用跌幅 `1-close/prev_close`; +3. 换算为 BP 后在 `[t1, t2]` 判 `满足`,否则 `其他`。 + +## 信号列表示例 + +- `Signal('日线_D1ZF_300至600_满足_任意_任意_0')` +- `Signal('日线_D1DF_300至600_其他_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `mode`:`ZF` 或 `DF`,默认 `ZF`; +- `span`:区间下上界(`t1,t2`),默认 `300,600`。 + +## 对齐说明 + +BP 计算与 Python `bar_zdf_V221203` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/bar_zdt_V230331.md b/.claude/skills/signal-functions/references/signals/bar_zdt_V230331.md new file mode 100644 index 000000000..85a3e4774 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_zdt_V230331.md @@ -0,0 +1,24 @@ +# bar_zdt_V230331:涨跌停识别信号 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}_涨跌停V230331` + +## 信号逻辑 + +1. 取倒数第 `di` 根与其前一根K线; +2. 若当前K线收盘等于最高且不低于前收,记为 `涨停`; +3. 若当前K线收盘等于最低且不高于前收,记为 `跌停`; +4. 否则记为 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1_涨跌停V230331_涨停_任意_任意_0')` +- `Signal('60分钟_D1_涨跌停V230331_跌停_任意_任意_0')` +- `Signal('60分钟_D1_涨跌停V230331_其他_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_zfzd_V241013.md b/.claude/skills/signal-functions/references/signals/bar_zfzd_V241013.md new file mode 100644 index 000000000..608538e25 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_zfzd_V241013.md @@ -0,0 +1,26 @@ +# bar_zfzd_V241013:窄幅震荡(全重叠) + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_窄幅震荡N{n}_形态V241013` + +## 信号逻辑 + +1. 取最近 `n` 根K线; +2. 若 `min(high) >= max(low)`,判为窗口内全重叠; +3. 输出 `满足`,否则 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_窄幅震荡N5_形态V241013_满足_任意_任意_0')` +- `Signal('60分钟_窄幅震荡N5_形态V241013_其他_任意_任意_0')` + +## 参数说明 + +- `n`:窗口长度,默认 `5`。 + +## 对齐说明 + +重叠判定公式与 Python `bar_zfzd_V241013` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/bar_zfzd_V241014.md b/.claude/skills/signal-functions/references/signals/bar_zfzd_V241014.md new file mode 100644 index 000000000..c0e454ef0 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_zfzd_V241014.md @@ -0,0 +1,26 @@ +# bar_zfzd_V241014:窄幅震荡(最大实体重叠) + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_窄幅震荡N{n}_形态V241014` + +## 信号逻辑 + +1. 找到窗口内最大实体K线; +2. 若其实体明显过大(超过窗口实体均值2倍)直接排除; +3. 若该K线与窗口内所有K线区间均重叠,判 `满足`。 + +## 信号列表示例 + +- `Signal('60分钟_窄幅震荡N10_形态V241014_满足_任意_任意_0')` +- `Signal('60分钟_窄幅震荡N10_形态V241014_其他_任意_任意_0')` + +## 参数说明 + +- `n`:窗口长度,默认 `5`。 + +## 对齐说明 + +最大实体筛选与重叠判断对齐 Python `bar_zfzd_V241014`。 diff --git a/.claude/skills/signal-functions/references/signals/bar_zt_count_V230504.md b/.claude/skills/signal-functions/references/signals/bar_zt_count_V230504.md new file mode 100644 index 000000000..76c0655d1 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bar_zt_count_V230504.md @@ -0,0 +1,28 @@ +# bar_zt_count_V230504:窗口涨停计数 + +> 模块: `bar.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}W{window}涨停计数_裸K形态V230504` + +## 信号逻辑 + +1. 在窗口内按相邻K线判断 `涨停`:`b2.close > b1.close*1.07 && b2.close==b2.high`; +2. 统计总次数 `sum(c1)`; +3. 统计连续双涨停次数 `cc`(相邻两个都为1); +4. 若总次数为0返回 `其他`,否则输出 `"{sum}次" + "连续{cc}次"`。 + +## 信号列表示例 + +- `Signal('日线_D1W5涨停计数_裸K形态V230504_1次_连续0次_任意_0')` +- `Signal('日线_D1W5涨停计数_裸K形态V230504_3次_连续2次_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `window`:统计窗口,默认 `5`。 + +## 对齐说明 + +涨停阈值与连续计次与 Python `bar_zt_count_V230504` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/bias_up_dw_line_V230618.md b/.claude/skills/signal-functions/references/signals/bias_up_dw_line_V230618.md new file mode 100644 index 000000000..9a84ae66c --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/bias_up_dw_line_V230618.md @@ -0,0 +1,28 @@ +# bias_up_dw_line_V230618:BIAS 三周期共振信号 + +> 模块: `ang.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}M{m}P{p}TH1{th1}TH2{th2}TH3{th3}_BIAS乖离率V230618` + +## 信号逻辑 + +1. 分别计算 `n/m/p` 三个窗口的均线乖离率; +2. 三个乖离率同时超过正阈值判 `看多`; +3. 三个乖离率同时低于负阈值判 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N6M12P24TH11TH23TH35_BIAS乖离率V230618_看多_任意_任意_0')` +- `Signal('60分钟_D1N6M12P24TH11TH23TH35_BIAS乖离率V230618_看空_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `n/m/p`:三组均线窗口,默认 `6/12/24`; +- `th1/th2/th3`:对应窗口阈值,默认 `1/3/5`。 + +## 对齐说明 + +与 Python `bias_up_dw_line_V230618` 的三阈值共振条件一致。 diff --git a/.claude/skills/signal-functions/references/signals/byi_bi_end_V230106.md b/.claude/skills/signal-functions/references/signals/byi_bi_end_V230106.md new file mode 100644 index 000000000..78f494a19 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/byi_bi_end_V230106.md @@ -0,0 +1,26 @@ +# byi_bi_end_V230106:分型停顿辅助笔结束信号 + +> 模块: `byi.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D0停顿分型_BE辅助V230106` + +## 信号逻辑 + +1. 基于最后一笔方向与末端分型,判断停顿分型是否成立; +2. 满足底分型停顿给出 `看多`,满足顶分型停顿给出 `看空`; +3. 再按最后一根K线实体强弱输出 `强/弱`。 + +## 信号列表示例 + +- `Signal('60分钟_D0停顿分型_BE辅助V230106_看多_强_任意_0')` +- `Signal('60分钟_D0停顿分型_BE辅助V230106_看空_弱_任意_0')` + +## 参数说明 + +- 本信号无额外参数,`params` 可为空。 + +## 对齐说明 + +与 Python `byi_bi_end_V230106` 的停顿判定条件一致。 diff --git a/.claude/skills/signal-functions/references/signals/byi_bi_end_V230107.md b/.claude/skills/signal-functions/references/signals/byi_bi_end_V230107.md new file mode 100644 index 000000000..0cfb0d9a6 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/byi_bi_end_V230107.md @@ -0,0 +1,26 @@ +# byi_bi_end_V230107:验证分型辅助笔结束信号 + +> 模块: `byi.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D0验证分型_BE辅助V230107` + +## 信号逻辑 + +1. 校验最后一笔末端分型与末三分型结构关系; +2. 满足验证底分型给出 `看多`,验证顶分型给出 `看空`; +3. 依据最后一根K线实体强弱输出 `强/弱`。 + +## 信号列表示例 + +- `Signal('60分钟_D0验证分型_BE辅助V230107_看多_强_任意_0')` +- `Signal('60分钟_D0验证分型_BE辅助V230107_看空_弱_任意_0')` + +## 参数说明 + +- 本信号无额外参数,`params` 可为空。 + +## 对齐说明 + +与 Python `byi_bi_end_V230107` 的结构校验和强弱规则一致。 diff --git a/.claude/skills/signal-functions/references/signals/byi_fx_num_V230628.md b/.claude/skills/signal-functions/references/signals/byi_fx_num_V230628.md new file mode 100644 index 000000000..a1777fc04 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/byi_fx_num_V230628.md @@ -0,0 +1,27 @@ +# byi_fx_num_V230628:前笔分型数量约束信号 + +> 模块: `byi.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}笔分型数大于{num}_BE辅助V230628` + +## 信号逻辑 + +1. 取倒数第 `di` 笔; +2. 输出该笔方向(`向上/向下`); +3. 若该笔内部分型数量 `>= num` 记 `满足`,否则 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1笔分型数大于4_BE辅助V230628_向下_满足_任意_0')` +- `Signal('60分钟_D1笔分型数大于4_BE辅助V230628_向上_其他_任意_0')` + +## 参数说明 + +- `di`:从倒数第 `di` 笔开始检查,默认 `1`; +- `num`:分型数量阈值,默认 `4`。 + +## 对齐说明 + +与 Python `byi_fx_num_V230628` 的数量判断一致。 diff --git a/.claude/skills/signal-functions/references/signals/byi_second_bs_V230324.md b/.claude/skills/signal-functions/references/signals/byi_second_bs_V230324.md new file mode 100644 index 000000000..fd39a1d6b --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/byi_second_bs_V230324.md @@ -0,0 +1,27 @@ +# byi_second_bs_V230324:二类买卖点辅助信号 + +> 模块: `byi.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}回抽零轴_BS2辅助V230324` + +## 信号逻辑 + +1. 基于最近 9 笔关键分型的 DIF 值和标准差构造条件; +2. 满足向下笔回抽零轴条件判 `看多`; +3. 满足向上笔回抽零轴条件判 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1MACD12#26#9回抽零轴_BS2辅助V230324_看多_任意_任意_0')` +- `Signal('60分钟_D1MACD12#26#9回抽零轴_BS2辅助V230324_看空_任意_任意_0')` + +## 参数说明 + +- `di`:从倒数第 `di` 笔开始检查,默认 `1`; +- `fastperiod/slowperiod/signalperiod`:MACD 参数,默认 `12/26/9`。 + +## 对齐说明 + +按 Python `byi_second_bs_V230324` 的 DIF 取样点和不等式链实现。 diff --git a/.claude/skills/signal-functions/references/signals/byi_symmetry_zs_V221107.md b/.claude/skills/signal-functions/references/signals/byi_symmetry_zs_V221107.md new file mode 100644 index 000000000..ec1decc4e --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/byi_symmetry_zs_V221107.md @@ -0,0 +1,27 @@ +# byi_symmetry_zs_V221107:对称中枢识别信号 + +> 模块: `byi.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}B_对称中枢` + +## 信号逻辑 + +1. 取倒数 `di` 截止最近 10 笔; +2. 依次检查最近 `7/5/3` 笔是否构成对称中枢; +3. 命中则输出 `是 + {i}笔`,否则 `否 + 任意`; +4. 方向位按最后一笔反向映射(最后笔向下 -> `向上`)。 + +## 信号列表示例 + +- `Signal('60分钟_D1B_对称中枢_是_向上_7笔_0')` +- `Signal('60分钟_D1B_对称中枢_否_向下_任意_0')` + +## 参数说明 + +- `di`:从倒数第 `di` 笔开始,默认 `1`。 + +## 对齐说明 + +与 Python `byi_symmetry_zs_V221107` 的 7/5/3 笔判定序一致。 diff --git a/.claude/skills/signal-functions/references/signals/cat_macd_V230518.md b/.claude/skills/signal-functions/references/signals/cat_macd_V230518.md new file mode 100644 index 000000000..6a7577231 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cat_macd_V230518.md @@ -0,0 +1,27 @@ +# cat_macd_V230518:高低级别 MACD 交叉联立信号 + +> 模块: `cat.rs` | 类别: `trader` + +## 参数模板 + +`"{freq1}#{freq2}_MACD交叉_联立V230518` + +## 信号逻辑 + +1. 当 `freq1` 最近一次由负翻正(MACD 金叉)后,检查 `freq2` 是否仅出现 1 次金叉,满足判 `看多`; +2. 当 `freq1` 最近一次由正翻负(MACD 死叉)后,检查 `freq2` 是否仅出现 1 次死叉,满足判 `看空`; +3. 否则返回 `其他`。 + +## 信号列表示例 + +- `Signal('日线#60分钟_MACD交叉_联立V230518_看多_任意_任意_0')` +- `Signal('日线#60分钟_MACD交叉_联立V230518_看空_任意_任意_0')` + +## 参数说明 + +- `freq1`:高一级别周期,默认 `5分钟`; +- `freq2`:低一级别周期,默认 `1分钟`。 + +## 对齐说明 + +触发窗口、首次交叉判定与 Python `cat_macd_V230518` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cat_macd_V230520.md b/.claude/skills/signal-functions/references/signals/cat_macd_V230520.md new file mode 100644 index 000000000..8a4125e18 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cat_macd_V230520.md @@ -0,0 +1,27 @@ +# cat_macd_V230520:高低级别 MACD 缩柱联立信号 + +> 模块: `cat.rs` | 类别: `trader` + +## 参数模板 + +`"{freq1}#{freq2}_MACD交叉_联立V230520` + +## 信号逻辑 + +1. `freq1` 最近三根 MACD 连续抬升且历史出现负值时,检查 `freq2` 的金死叉结构,满足判 `看多`; +2. `freq1` 最近三根 MACD 连续下压且历史出现正值时,检查 `freq2` 的死金叉结构,满足判 `看空`; +3. 同时给出触发时 `DEA` 在零轴上下的位置 `v2`。 + +## 信号列表示例 + +- `Signal('日线#60分钟_MACD交叉_联立V230520_看多_零轴上方_任意_0')` +- `Signal('日线#60分钟_MACD交叉_联立V230520_看空_零轴下方_任意_0')` + +## 参数说明 + +- `freq1`:高一级别周期,默认 `5分钟`; +- `freq2`:低一级别周期,默认 `1分钟`。 + +## 对齐说明 + +交叉次数、顺序和阈值条件与 Python `cat_macd_V230520` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/cci_decision_V240620.md b/.claude/skills/signal-functions/references/signals/cci_decision_V240620.md new file mode 100644 index 000000000..a6c215f1d --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cci_decision_V240620.md @@ -0,0 +1,26 @@ +# cci_decision_V240620:CCI 逆势决策区域 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_N{n}CCI_决策区域V240620` + +## 信号逻辑 + +1. 固定计算 `CCI(14)`; +2. 取最近 `n` 根 CCI:若最小值 `< -100` 判 `开多`,`v2` 为 `< -100` 的出现次数; +3. 若最大值 `> 100` 判 `开空`,`v2` 为 `> 100` 的出现次数(覆盖开多分支)。 + +## 信号列表示例 + +- `Signal('15分钟_N4CCI_决策区域V240620_开多_2次_任意_0')` +- `Signal('15分钟_N4CCI_决策区域V240620_开空_1次_任意_0')` + +## 参数说明 + +- `n`:统计窗口长度,默认 `2`。 + +## 对齐说明 + +分支顺序与 Python `cci_decision_V240620` 保持一致(后判空头覆盖前判多头)。 diff --git a/.claude/skills/signal-functions/references/signals/clv_up_dw_line_V230605.md b/.claude/skills/signal-functions/references/signals/clv_up_dw_line_V230605.md new file mode 100644 index 000000000..41d3fa43c --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/clv_up_dw_line_V230605.md @@ -0,0 +1,27 @@ +# clv_up_dw_line_V230605:CLV 多空信号 + +> 模块: `clv.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}_CLV多空V230605` + +## 信号逻辑 + +1. 取最近 `n` 根K线,计算每根 `(2*close-low-high)/(high-low)`; +2. 计算该序列均值 `clv_ma`; +3. `clv_ma > 0` 判 `看多`,否则判 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N70_CLV多空V230605_看多_任意_任意_0')` +- `Signal('60分钟_D1N70_CLV多空V230605_看空_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `n`:统计窗口大小,默认 `70`。 + +## 对齐说明 + +CLV 公式与阈值判断对齐 Python `clv_up_dw_line_V230605`。 diff --git a/.claude/skills/signal-functions/references/signals/cmo_up_dw_line_V230605.md b/.claude/skills/signal-functions/references/signals/cmo_up_dw_line_V230605.md new file mode 100644 index 000000000..c495c500f --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cmo_up_dw_line_V230605.md @@ -0,0 +1,28 @@ +# cmo_up_dw_line_V230605:CMO 能量阈值信号 + +> 模块: `ang.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}M{m}_CMO能量V230605` + +## 信号逻辑 + +1. 统计窗口内上涨/下跌收盘差值总和; +2. 计算 `cmo = (up-dw)/(up+dw)*100`; +3. `cmo > m` 判 `看多`;`cmo < -m` 判 `看空`;否则 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N70M30_CMO能量V230605_看多_任意_任意_0')` +- `Signal('60分钟_D1N70M30_CMO能量V230605_看空_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `n`:统计窗口,默认 `70`; +- `m`:阈值,默认 `30`。 + +## 对齐说明 + +与 Python `cmo_up_dw_line_V230605` 保持同一阈值与分支顺序。 diff --git a/.claude/skills/signal-functions/references/signals/coo_cci_V230323.md b/.claude/skills/signal-functions/references/signals/coo_cci_V230323.md new file mode 100644 index 000000000..e68bc0b11 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/coo_cci_V230323.md @@ -0,0 +1,29 @@ +# coo_cci_V230323:CCI 结合均线的多空与方向信号 + +> 模块: `coo.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}CCI{n}#{ma_type}#{m}_BS辅助V230323` + +## 信号逻辑 + +1. 计算 `CCI(n)` 与 `MA(n*m)`; +2. `CCI>100` 且 `close>MA` 判 `多头`,`CCI<-100` 且 `close 模块: `coo.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}KDJ{fastk_period}#{slowk_period}#{slowd_period}#{ma_type}#{n}_BS辅助V230322` + +## 信号逻辑 + +1. 计算 `KDJ` 与 `MA(n)`; +2. `close > MA` 且 `K < D` 判 `多头`; +3. `close < MA` 且 `K > D` 判 `空头`,否则 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1KDJ9#3#3#EMA#3_BS辅助V230322_多头_任意_任意_0')` +- `Signal('60分钟_D1KDJ9#3#3#EMA#3_BS辅助V230322_空头_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `n`:均线周期,默认 `3`; +- `ma_type`:均线类型,默认 `EMA`; +- `fastk_period/slowk_period/slowd_period`:KDJ 参数,默认 `9/3/3`。 + +## 对齐说明 + +与 Python `coo_kdj_V230322` 的组合条件一致。 diff --git a/.claude/skills/signal-functions/references/signals/coo_sar_V230325.md b/.claude/skills/signal-functions/references/signals/coo_sar_V230325.md new file mode 100644 index 000000000..cb6293909 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/coo_sar_V230325.md @@ -0,0 +1,27 @@ +# coo_sar_V230325:SAR 与区间极值配合信号 + +> 模块: `coo.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}SAR_BS辅助V230325` + +## 信号逻辑 + +1. 计算最近 `n` 根收盘价区间高低点; +2. 若 `close > SAR` 且 `high >= 区间最高收盘` 判 `多头`; +3. 若 `close < SAR` 且 `low <= 区间最低收盘` 判 `空头`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N60SAR_BS辅助V230325_多头_任意_任意_0')` +- `Signal('60分钟_D1N60SAR_BS辅助V230325_空头_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `n`:区间窗口,默认 `60`。 + +## 对齐说明 + +与 Python `coo_sar_V230325` 的 SAR 与区间条件一致。 diff --git a/.claude/skills/signal-functions/references/signals/coo_td_V221110.md b/.claude/skills/signal-functions/references/signals/coo_td_V221110.md new file mode 100644 index 000000000..285a90bc8 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/coo_td_V221110.md @@ -0,0 +1,26 @@ +# coo_td_V221110:TD 神奇九转信号(旧版模板) + +> 模块: `coo.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}K_TD` + +## 信号逻辑 + +1. 取倒数 `di` 截止的最近 50 根收盘价; +2. 按 `close[i]` 与 `close[i-4]` 比较累计 TD 计数; +3. 根据最新 TD 值及前一值输出 `看多/看空/延续` 与 `TD顶/TD底/非顶/非底`。 + +## 信号列表示例 + +- `Signal('60分钟_D1K_TD_延续_非顶_任意_0')` +- `Signal('60分钟_D1K_TD_看空_TD底_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`。 + +## 对齐说明 + +与 Python `coo_td_V221110` 的 TD 计数递推一致。 diff --git a/.claude/skills/signal-functions/references/signals/coo_td_V221111.md b/.claude/skills/signal-functions/references/signals/coo_td_V221111.md new file mode 100644 index 000000000..50be8b0b5 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/coo_td_V221111.md @@ -0,0 +1,26 @@ +# coo_td_V221111:TD 神奇九转信号 + +> 模块: `coo.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}TD_BS辅助V221111` + +## 信号逻辑 + +1. 取倒数 `di` 截止的最近 50 根收盘价; +2. 计算 TD 计数序列; +3. 输出 `看多/看空/延续` 与 `TD顶/TD底/非顶/非底` 组合。 + +## 信号列表示例 + +- `Signal('60分钟_D1TD_BS辅助V221111_延续_非顶_任意_0')` +- `Signal('60分钟_D1TD_BS辅助V221111_看多_TD顶_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`。 + +## 对齐说明 + +与 Python `coo_td_V221111` 的窗口和判定分支一致。 diff --git a/.claude/skills/signal-functions/references/signals/cvolp_up_dw_line_V230612.md b/.claude/skills/signal-functions/references/signals/cvolp_up_dw_line_V230612.md new file mode 100644 index 000000000..dc179a00c --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cvolp_up_dw_line_V230612.md @@ -0,0 +1,31 @@ +# cvolp_up_dw_line_V230612:CVOLP 动量变化率信号 + +> 模块: `cvolp.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}M{m}UP{up}DW{dw}_CVOLP动量变化率V230612` + +## 信号逻辑 + +1. 取最近 `n+m` 根成交量,构造长度为 `n` 的指数权重; +2. 计算卷积平滑序列 `emap`,并将前 `n` 项置为 `emap[n]`; +3. 计算 `sroc = (emap - roll(emap, m))[-1] / roll(emap, m)[-1]`; +4. `sroc > up/100` 判 `看多`,`sroc < -dw/100` 判 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N34M55UP5DW5_CVOLP动量变化率V230612_看多_任意_任意_0')` +- `Signal('60分钟_D1N34M55UP5DW5_CVOLP动量变化率V230612_看空_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `n`:卷积平滑窗口,默认 `34`; +- `m`:滚动比较窗口,默认 `55`; +- `up`:看多阈值(百分比整数),默认 `5`; +- `dw`:看空阈值(百分比整数),默认 `5`。 + +## 对齐说明 + +卷积平滑与 `roll` 口径对齐 Python `cvolp_up_dw_line_V230612`。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_bi_base_V230228.md b/.claude/skills/signal-functions/references/signals/cxt_bi_base_V230228.md new file mode 100644 index 000000000..b05213e04 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_bi_base_V230228.md @@ -0,0 +1,28 @@ +# cxt_bi_base_V230228:笔基础状态信号 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D0BL{bi_init_length}_V230228` + +## 信号逻辑 + +1. 读取最新一笔方向; +2. 若最新笔为向下笔,当前状态记为 `向上`,反之记为 `向下`; +3. 若未完成笔长度 `bars_ubi` 大于等于 `bi_init_length`,记为 `中继`,否则记为 `转折`; +4. 笔数据不足时返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D0BL9_V230228_向上_中继_任意_0')` +- `Signal('60分钟_D0BL9_V230228_向下_转折_任意_0')` +- `Signal('60分钟_D0BL9_V230228_其他_任意_任意_0')` + +## 参数说明 + +- `bi_init_length`:未完成笔长度阈值,默认 `9`。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_bi_base_V230228` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230104.md b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230104.md new file mode 100644 index 000000000..a38080c69 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230104.md @@ -0,0 +1,28 @@ +# cxt_bi_end_V230104:单均线辅助判断笔结束 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D0{ma_type}#{timeperiod}T{th}_BE辅助V230104` + +## 信号逻辑 + +1. 计算指定均线,并取最近 3 根原始 K 线; +2. 若向下笔尾部出现三连阳且收盘强于均线阈值,判定 `看多`;向上笔尾部三连阴且收盘弱于均线阈值,判定 `看空`; +3. 不满足边界、均线或形态条件时返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D0SMA#5T50_BE辅助V230104_看多_任意_任意_0')` +- `Signal('60分钟_D0EMA#8T30_BE辅助V230104_看空_任意_任意_0')` + +## 参数说明 + +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod`:均线周期,默认 `5`; +- `th`:收盘价相对均线的 BP 阈值,默认 `50`。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_bi_end_V230104` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230105.md b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230105.md new file mode 100644 index 000000000..6f3bdebbf --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230105.md @@ -0,0 +1,28 @@ +# cxt_bi_end_V230105:K线形态+均线辅助判断笔结束 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D0{ma_type}#{timeperiod}T{th}_BE辅助V230105` + +## 信号逻辑 + +1. 提取最后一笔终点分型的两根原始 K 线,并计算指定均线; +2. 向下笔若先阴后强阳上穿均线阈值,判定 `看多`;向上笔若先阳后强阴下破均线阈值,判定 `看空`; +3. 未完成笔过长、分型样本不足或均线不可用时返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D0SMA#5T50_BE辅助V230105_看多_任意_任意_0')` +- `Signal('60分钟_D0EMA#8T30_BE辅助V230105_看空_任意_任意_0')` + +## 参数说明 + +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod`:均线周期,默认 `5`; +- `th`:第二根 K 线相对均线的突破阈值,默认 `50` BP。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_bi_end_V230105` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230222.md b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230222.md new file mode 100644 index 000000000..c8db20f48 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230222.md @@ -0,0 +1,27 @@ +# cxt_bi_end_V230222:未完成笔分型新高新低次数 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D1MO{max_overlap}_BE辅助V230222` + +## 信号逻辑 + +1. 拼接最后一笔内部已确认分型与当前 UBI 分型序列; +2. 仅当最新分型刚确认,或距最新原始 K 线不超过 `max_overlap` 根时继续判断; +3. 若最新顶分型创序列新高,输出 `新高_第X次`;若底分型创新低,输出 `新低_第X次`。 + +## 信号列表示例 + +- `Signal('60分钟_D1MO3_BE辅助V230222_新高_第2次_任意_0')` +- `Signal('60分钟_D1MO3_BE辅助V230222_新低_第1次_任意_0')` + +## 参数说明 + +- `max_overlap`:允许最新分型与当前原始 K 线的最大重叠根数,默认 `3`; +- 超出确认时机或分型不足时返回 `其他_其他`。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_bi_end_V230222` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230224.md b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230224.md new file mode 100644 index 000000000..d5266be4f --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230224.md @@ -0,0 +1,27 @@ +# cxt_bi_end_V230224:量价配合笔结束辅助 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D1_BE辅助V230224` + +## 信号逻辑 + +1. 统计最后一笔整体均量与终点分型均量; +2. 长上影且分型显著放量时判定 `看空`,长下影且分型显著缩量时判定 `看多`; +3. 若笔或分型样本不足,或未完成笔过长,则返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1_BE辅助V230224_看多_任意_任意_0')` +- `Signal('60分钟_D1_BE辅助V230224_看空_任意_任意_0')` + +## 参数说明 + +- 本信号无额外参数,`params` 可为空; +- 仅在 UBI 较短时使用量价关系辅助判断笔结束。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_bi_end_V230224` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230312.md b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230312.md new file mode 100644 index 000000000..cefb938d2 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230312.md @@ -0,0 +1,28 @@ +# cxt_bi_end_V230312:MACD辅助判断笔结束 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D0MACD{fastperiod}#{slowperiod}#{signalperiod}_BE辅助V230312` + +## 信号逻辑 + +1. 计算指定参数的 MACD,并读取最后一笔终点分型对应的首末原始 K 线; +2. 向下笔若分型尾部 MACD 柱值高于分型起点,判定 `看多`;向上笔反向判定 `看空`; +3. MACD 缓存、分型样本或边界条件不满足时返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D0MACD12#26#9_BE辅助V230312_看多_任意_任意_0')` +- `Signal('60分钟_D0MACD12#26#9_BE辅助V230312_看空_任意_任意_0')` + +## 参数说明 + +- `fastperiod`:MACD 快线周期,默认 `12`; +- `slowperiod`:MACD 慢线周期,默认 `26`; +- `signalperiod`:信号线周期,默认 `9`。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_bi_end_V230312` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230320.md b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230320.md new file mode 100644 index 000000000..7d6c49fcc --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230320.md @@ -0,0 +1,27 @@ +# cxt_bi_end_V230320:质数窗口笔结束辅助 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D0质数窗口MO{max_overlap}_BE辅助V230320` + +## 信号逻辑 + +1. 展开当前 UBI 原始 K 线,统计其长度是否落在预设质数集合中; +2. 若向上笔后的 UBI 在最近 `max_overlap` 根内创新低,判定 `看多`;向下笔反向判定 `看空`; +3. 输出时补充 `XXK` 表示当前 UBI 长度。 + +## 信号列表示例 + +- `Signal('60分钟_D0质数窗口MO3_BE辅助V230320_看多_13K_任意_0')` +- `Signal('60分钟_D0质数窗口MO3_BE辅助V230320_看空_17K_任意_0')` + +## 参数说明 + +- `max_overlap`:允许用末尾 `max_overlap` 根 K 线判断极值,默认 `3`; +- 质数窗口集合固定为 `11~97` 内常用质数。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_bi_end_V230320` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230322.md b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230322.md new file mode 100644 index 000000000..031aab7eb --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230322.md @@ -0,0 +1,27 @@ +# cxt_bi_end_V230322:分型配合均线的笔结束辅助 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D0分型配合{ma_type}#{timeperiod}_BE辅助V230322` + +## 信号逻辑 + +1. 读取最新 UBI 分型对应的原始 K 线,并提取分型区间内的均线序列; +2. 向上笔若最新分型与均线位置形成顶部配合,判定 `看空`;向下笔反向判定 `看多`; +3. 再用 `同向分型/反向分型` 说明分型方向与笔方向关系。 + +## 信号列表示例 + +- `Signal('60分钟_D0分型配合SMA#5_BE辅助V230322_看多_反向分型_任意_0')` +- `Signal('60分钟_D0分型配合EMA#8_BE辅助V230322_看空_同向分型_任意_0')` + +## 参数说明 + +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod`:均线周期,默认 `5`。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_bi_end_V230322` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230324.md b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230324.md new file mode 100644 index 000000000..3cb0e85be --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230324.md @@ -0,0 +1,27 @@ +# cxt_bi_end_V230324:笔结束分型均线突破 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D0{ma_type}#{timeperiod}均线突破_BE辅助V230324` + +## 信号逻辑 + +1. 计算指定均线,并提取最后一笔终点分型除最后一根之外的均线序列; +2. 向上笔若上一根收盘跌破分型内最低均线,判定 `看空`;向下笔若上一根收盘突破分型内最高均线,判定 `看多`; +3. 数据不足、UBI 过长或均线不可用时返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D0SMA#5均线突破_BE辅助V230324_看多_任意_任意_0')` +- `Signal('60分钟_D0EMA#13均线突破_BE辅助V230324_看空_任意_任意_0')` + +## 参数说明 + +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod`:均线周期,默认 `5`。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_bi_end_V230324` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230618.md b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230618.md new file mode 100644 index 000000000..16d57dc80 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230618.md @@ -0,0 +1,27 @@ +# cxt_bi_end_V230618:笔结束小中枢辅助 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}MO{max_overlap}_BE辅助V230618` + +## 信号逻辑 + +1. 读取倒数第 `di` 笔的原始 K 线并做价格覆盖计数; +2. 统计覆盖次数形成的峰值数量,近似识别笔内小中枢; +3. 输出 `看多/看空` 和 `X小中枢/其他`,用于辅助笔结束判断。 + +## 信号列表示例 + +- `Signal('60分钟_D1MO3_BE辅助V230618_看多_1小中枢_任意_0')` +- `Signal('60分钟_D1MO3_BE辅助V230618_看空_其他_任意_0')` + +## 参数说明 + +- `di`:取倒数第 `di` 笔,默认 `1`; +- `max_overlap`:控制 UBI 最大允许延伸长度,默认 `3`。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_bi_end_V230618` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230815.md b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230815.md new file mode 100644 index 000000000..8e13b8ee9 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_bi_end_V230815.md @@ -0,0 +1,27 @@ +# cxt_bi_end_V230815:快速突破反向笔 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_快速突破_BE辅助V230815` + +## 信号逻辑 + +1. 读取最后一笔和当前未完成笔最后一根 K 线; +2. 向上笔若被最新低点快速跌破,输出 `向下`;向下笔若被最新高点快速突破,输出 `向上`; +3. 笔数不足或 UBI 已延伸过长时返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_快速突破_BE辅助V230815_向上_任意_任意_0')` +- `Signal('60分钟_快速突破_BE辅助V230815_向下_任意_任意_0')` + +## 参数说明 + +- 本信号无额外参数,`params` 可为空; +- 仅用于很短的 UBI 场景,强调“快速突破”。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_bi_end_V230815` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_bi_status_V230101.md b/.claude/skills/signal-functions/references/signals/cxt_bi_status_V230101.md new file mode 100644 index 000000000..830613586 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_bi_status_V230101.md @@ -0,0 +1,27 @@ +# cxt_bi_status_V230101:笔表里关系信号 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D1_表里关系V230101` + +## 信号逻辑 + +1. 依据最后一笔方向和 `bars_ubi` 长度判定外部方向(`向上/向下`); +2. 结合未完成笔最后一个分型(顶分/底分)判定内部状态(`顶分/底分/延伸`); +3. 笔或分型数据不足时返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1_表里关系V230101_向上_顶分_任意_0')` +- `Signal('60分钟_D1_表里关系V230101_向下_底分_任意_0')` +- `Signal('60分钟_D1_表里关系V230101_向上_延伸_任意_0')` + +## 参数说明 + +- 本信号无额外参数,`params` 可为空。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_bi_status_V230101` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_bi_status_V230102.md b/.claude/skills/signal-functions/references/signals/cxt_bi_status_V230102.md new file mode 100644 index 000000000..6c4544455 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_bi_status_V230102.md @@ -0,0 +1,28 @@ +# cxt_bi_status_V230102:笔表里关系信号 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D1_表里关系V230102` + +## 信号逻辑 + +1. 沿用 `cxt_bi_status_V230101` 的表里方向和分型判定规则; +2. 仅当最后一根原始K线时间等于最新 UBI 分型确认结束时间时触发; +3. 不满足触发时机或数据不足时返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1_表里关系V230102_向下_底分_任意_0')` +- `Signal('60分钟_D1_表里关系V230102_向下_延伸_任意_0')` +- `Signal('60分钟_D1_表里关系V230102_向上_顶分_任意_0')` +- `Signal('60分钟_D1_表里关系V230102_向上_延伸_任意_0')` + +## 参数说明 + +- 本信号无额外参数,`params` 可为空。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_bi_status_V230102` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_bi_stop_V230815.md b/.claude/skills/signal-functions/references/signals/cxt_bi_stop_V230815.md new file mode 100644 index 000000000..308213269 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_bi_stop_V230815.md @@ -0,0 +1,27 @@ +# cxt_bi_stop_V230815:笔止损距离状态 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_距离{th}BP_止损V230815` + +## 信号逻辑 + +1. 读取最后一笔方向,并把其高低点作为止损基准; +2. 向上场景比较最新收盘距笔高的回撤,向下场景比较最新收盘距笔低的反弹; +3. 若落在 `th` BP 阈值内则标记 `阈值内`,否则标记 `阈值外`。 + +## 信号列表示例 + +- `Signal('60分钟_距离50BP_止损V230815_向上_阈值内_任意_0')` +- `Signal('60分钟_距离50BP_止损V230815_向下_阈值外_任意_0')` + +## 参数说明 + +- `th`:距离阈值,单位 BP,默认 `50`; +- 信号只读取最后一笔和当前 UBI,不做更长历史统计。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_bi_stop_V230815` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_bi_trend_V230824.md b/.claude/skills/signal-functions/references/signals/cxt_bi_trend_V230824.md new file mode 100644 index 000000000..49d8055c3 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_bi_trend_V230824.md @@ -0,0 +1,28 @@ +# cxt_bi_trend_V230824:N笔形态判断 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}TH{th}_形态V230824` + +## 信号逻辑 + +1. 取最近 `n` 笔的中位价格均值; +2. 用首笔中位价格相对均值的偏离程度判断 `向上/向下/横盘`; +3. 偏离阈值由 `th` 控制,数据不足时返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N4TH2_形态V230824_向上_任意_任意_0')` +- `Signal('60分钟_D1N4TH2_形态V230824_横盘_任意_任意_0')` + +## 参数说明 + +- `di`:从倒数第 `di` 笔开始取样,默认 `1`; +- `n`:参与比较的笔数,默认 `4`; +- `th`:相对均值的偏离阈值,默认 `2`。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_bi_trend_V230824` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_bi_trend_V230913.md b/.claude/skills/signal-functions/references/signals/cxt_bi_trend_V230913.md new file mode 100644 index 000000000..0e77de922 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_bi_trend_V230913.md @@ -0,0 +1,27 @@ +# cxt_bi_trend_V230913:笔趋势高低点回归信号 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}笔趋势_高低点辅助判断V230913` + +## 信号逻辑 + +1. 分别取最近 `di` 个向上笔高点和向下笔低点,做线性回归预测; +2. 用当前 UBI 指定位置的时间点预测上沿、下沿及中轴; +3. 将最新收盘相对预测区间的位置映射为 `上升趋势/下降趋势 + 强弱`。 + +## 信号列表示例 + +- `Signal('60分钟_D4N1笔趋势_高低点辅助判断V230913_上升趋势_强_任意_0')` +- `Signal('60分钟_D4N1笔趋势_高低点辅助判断V230913_下降趋势_超强_任意_0')` + +## 参数说明 + +- `di`:参与回归的同向笔数量,默认 `4`; +- `n`:使用 UBI 中倒数第 `n` 根 K 线做比较,默认 `1`。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_bi_trend_V230913` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_bi_zdf_V230601.md b/.claude/skills/signal-functions/references/signals/cxt_bi_zdf_V230601.md new file mode 100644 index 000000000..f07bdd65a --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_bi_zdf_V230601.md @@ -0,0 +1,27 @@ +# cxt_bi_zdf_V230601:BI涨跌幅分层 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}_分层V230601` + +## 信号逻辑 + +1. 取最近最多 50 笔的力度序列; +2. 读取最新笔方向作为 `v1`; +3. 用 `qcut_last_label` 将最新力度分到 `n` 层中的某一层,输出 `第X层`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N5_分层V230601_向上_第3层_任意_0')` +- `Signal('60分钟_D1N5_分层V230601_向下_第1层_任意_0')` + +## 参数说明 + +- `di`:从倒数第 `di` 笔开始统计,默认 `1`; +- `n`:分层数量,默认 `5`。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_bi_zdf_V230601` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_bs_V240526.md b/.claude/skills/signal-functions/references/signals/cxt_bs_V240526.md new file mode 100644 index 000000000..669851d9d --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_bs_V240526.md @@ -0,0 +1,27 @@ +# cxt_bs_V240526:趋势跟随 BS 辅助 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_趋势跟随_BS辅助V240526` + +## 信号逻辑 + +1. 读取最近 7 笔,要求倒数第二笔具备高 SNR、强价格力度、强成交量或斜率特征; +2. 再比较最后一笔相对前一强势笔的价格力度区间; +3. 满足小回撤条件时输出 `买点/卖点`,否则返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_趋势跟随_BS辅助V240526_买点_任意_任意_0')` +- `Signal('60分钟_趋势跟随_BS辅助V240526_卖点_任意_任意_0')` + +## 参数说明 + +- 本信号无额外参数,`params` 可为空; +- 重点观察倒数第二笔是否是“顺畅强趋势笔”。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_bs_V240526` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_bs_V240527.md b/.claude/skills/signal-functions/references/signals/cxt_bs_V240527.md new file mode 100644 index 000000000..4d333496a --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_bs_V240527.md @@ -0,0 +1,27 @@ +# cxt_bs_V240527:未完成笔上的趋势跟随 BS 辅助 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_趋势跟随_BS辅助V240527` + +## 信号逻辑 + +1. 读取最近 7 笔,要求最后一笔本身是高 SNR 的强趋势笔; +2. 再读取当前 UBI 原始 K 线,比较其价格力度相对最后一笔的回撤比例; +3. 满足小回撤条件时输出 `买点/卖点`,否则返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_趋势跟随_BS辅助V240527_买点_任意_任意_0')` +- `Signal('60分钟_趋势跟随_BS辅助V240527_卖点_任意_任意_0')` + +## 参数说明 + +- 本信号无额外参数,`params` 可为空; +- 与 `V240526` 的区别在于这里评估的是未完成笔上的回撤。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_bs_V240527` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_decision_V240526.md b/.claude/skills/signal-functions/references/signals/cxt_decision_V240526.md new file mode 100644 index 000000000..4a5dd4a2a --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_decision_V240526.md @@ -0,0 +1,27 @@ +# cxt_decision_V240526:分型区域决策 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_分型区域N{n}_决策区域V240526` + +## 信号逻辑 + +1. 在最近 100 根 K 线中提取离散价格层; +2. 若最后一笔向上,统计最新收盘到顶分型上沿之间的价位层数,层数不多于 `n` 时判定 `开空`;向下笔反向判定 `开多`; +3. 否则返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_分型区域N9_决策区域V240526_开多_任意_任意_0')` +- `Signal('60分钟_分型区域N9_决策区域V240526_开空_任意_任意_0')` + +## 参数说明 + +- `n`:允许的价位层数量阈值,默认 `9`; +- 至少要求 120 根原始 K 线和一笔已完成笔。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_decision_V240526` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_decision_V240612.md b/.claude/skills/signal-functions/references/signals/cxt_decision_V240612.md new file mode 100644 index 000000000..20c6fb54d --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_decision_V240612.md @@ -0,0 +1,27 @@ +# cxt_decision_V240612:高低点N档决策区间 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_W{w}N{n}高低点_决策区域V240612` + +## 信号逻辑 + +1. 用最近 100 根 K 线生成离散价格层,再用最近 `w` 根 K 线确定高低点; +2. 在低点上方和高点下方各取第 `n` 档价格,形成低区和高区阈值; +3. 最新收盘落入低区判定 `开多`,落入高区判定 `开空`。 + +## 信号列表示例 + +- `Signal('60分钟_W10N9高低点_决策区域V240612_开多_任意_任意_0')` +- `Signal('60分钟_W10N9高低点_决策区域V240612_开空_任意_任意_0')` + +## 参数说明 + +- `w`:最近高低点统计窗口,默认 `10`; +- `n`:从高低点向内取第 `n` 档价格,默认 `9`。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_decision_V240612` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_decision_V240613.md b/.claude/skills/signal-functions/references/signals/cxt_decision_V240613.md new file mode 100644 index 000000000..9ef5254d8 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_decision_V240613.md @@ -0,0 +1,27 @@ +# cxt_decision_V240613:放量笔N4BS2决策区 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_放量笔N{n}BS2_决策区域V240613` + +## 信号逻辑 + +1. 取最近 `n` 笔并定位成交量最大的最后一笔; +2. 若该笔向下但未创新低,判定 `开多`;若向上但未创新高,判定 `开空`; +3. 只有最后一笔同时满足“放量且不是极值笔”时才触发。 + +## 信号列表示例 + +- `Signal('60分钟_放量笔N4BS2_决策区域V240613_开多_任意_任意_0')` +- `Signal('60分钟_放量笔N4BS2_决策区域V240613_开空_任意_任意_0')` + +## 参数说明 + +- `n`:比较最近 `n` 笔的放量程度,默认 `4`; +- 仅在 UBI 不超过 7 时使用该决策信号。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_decision_V240613` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_decision_V240614.md b/.claude/skills/signal-functions/references/signals/cxt_decision_V240614.md new file mode 100644 index 000000000..a1431806f --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_decision_V240614.md @@ -0,0 +1,27 @@ +# cxt_decision_V240614:放量新高/新低决策区 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_放量笔N{n}_决策区域V240614` + +## 信号逻辑 + +1. 取最近 `n` 笔并定位成交量最大的最后一笔; +2. 若该笔向下且同时创新低,判定 `开多`;若向上且同时创新高,判定 `开空`; +3. 用于识别放量突破后的反向决策区域。 + +## 信号列表示例 + +- `Signal('60分钟_放量笔N4_决策区域V240614_开多_任意_任意_0')` +- `Signal('60分钟_放量笔N4_决策区域V240614_开空_任意_任意_0')` + +## 参数说明 + +- `n`:比较最近 `n` 笔的放量程度,默认 `4`; +- 需要最后一笔既是放量笔,又是最近 `n` 笔的新高或新低笔。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_decision_V240614` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_double_zs_V230311.md b/.claude/skills/signal-functions/references/signals/cxt_double_zs_V230311.md new file mode 100644 index 000000000..30a794b7c --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_double_zs_V230311.md @@ -0,0 +1,27 @@ +# cxt_double_zs_V230311:双中枢 BS1 辅助 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}双中枢_BS1辅助V230311` + +## 信号逻辑 + +1. 提取最近 20 笔并重建中枢序列; +2. 若最近两个中枢都有效,比较后一中枢内部两笔的时长与前后中枢极值关系; +3. 向下笔满足衰竭条件判定 `看多`,向上笔满足衰竭条件判定 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1双中枢_BS1辅助V230311_看多_任意_任意_0')` +- `Signal('60分钟_D1双中枢_BS1辅助V230311_看空_任意_任意_0')` + +## 参数说明 + +- `di`:从倒数第 `di` 笔开始取样,默认 `1`; +- 需要至少形成两个有效中枢,否则返回 `其他`。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_double_zs_V230311` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_eleven_bi_V230622.md b/.claude/skills/signal-functions/references/signals/cxt_eleven_bi_V230622.md new file mode 100644 index 000000000..d724ce7ec --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_eleven_bi_V230622.md @@ -0,0 +1,27 @@ +# cxt_eleven_bi_V230622:十一笔形态分类信号 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}十一笔_形态V230622` + +## 信号逻辑 + +1. 读取最近 11 笔并统计首末极值与中间结构关系; +2. 识别 A5B3C3、A3B3C5、A3B5C3、类二买卖、类三买等十一笔结构; +3. 若不满足任何预定义结构,则返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1十一笔_形态V230622_A5B3C3式类一买_任意_任意_0')` +- `Signal('60分钟_D1十一笔_形态V230622_类二卖_任意_任意_0')` + +## 参数说明 + +- `di`:从倒数第 `di` 笔开始取样,默认 `1`; +- 该信号面向更长结构分类,要求笔数和确认度更高。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_eleven_bi_V230622` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_first_buy_V221126.md b/.claude/skills/signal-functions/references/signals/cxt_first_buy_V221126.md new file mode 100644 index 000000000..07371a8d6 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_first_buy_V221126.md @@ -0,0 +1,27 @@ +# cxt_first_buy_V221126:一买信号 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}B_BUY1V221126` + +## 信号逻辑 + +1. 依次尝试最近 `21/19/17/15/13/11/9/7/5` 笔; +2. 调用统一的 `check_first_buy` 结构判定函数识别一买; +3. 命中后输出对应笔数,否则返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1B_BUY1_一买_5笔_任意_0')` +- `Signal('60分钟_D1B_BUY1_一买_13笔_任意_0')` + +## 参数说明 + +- `di`:从倒数第 `di` 笔开始取样,默认 `1`; +- 一买结构判定复用 Python 同名逻辑。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_first_buy_V221126` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_first_sell_V221126.md b/.claude/skills/signal-functions/references/signals/cxt_first_sell_V221126.md new file mode 100644 index 000000000..ef75a6da7 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_first_sell_V221126.md @@ -0,0 +1,27 @@ +# cxt_first_sell_V221126:一卖信号 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}B_SELL1V221126` + +## 信号逻辑 + +1. 依次尝试最近 `21/19/17/15/13/11/9/7/5` 笔; +2. 调用统一的 `check_first_sell` 结构判定函数识别一卖; +3. 命中后输出对应笔数,否则返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1B_SELL1_一卖_5笔_任意_0')` +- `Signal('60分钟_D1B_SELL1_一卖_13笔_任意_0')` + +## 参数说明 + +- `di`:从倒数第 `di` 笔开始取样,默认 `1`; +- 一卖结构判定复用 Python 同名逻辑。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_first_sell_V221126` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_five_bi_V230619.md b/.claude/skills/signal-functions/references/signals/cxt_five_bi_V230619.md new file mode 100644 index 000000000..731effc25 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_five_bi_V230619.md @@ -0,0 +1,27 @@ +# cxt_five_bi_V230619:五笔形态分类信号 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}五笔_形态V230619` + +## 信号逻辑 + +1. 读取最近 5 笔并计算整体最高点、最低点; +2. 依据中枢重合、首末笔力度与突破位置识别底背驰、顶背驰、颈线突破、类三买卖等形态; +3. 未命中任何预定义结构时返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1五笔_形态V230619_aAb式底背驰_任意_任意_0')` +- `Signal('60分钟_D1五笔_形态V230619_类三卖_任意_任意_0')` + +## 参数说明 + +- `di`:从倒数第 `di` 笔开始取样,默认 `1`; +- 该信号直接输出形态标签,不再附加次级分类。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_five_bi_V230619` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_fx_power_V221107.md b/.claude/skills/signal-functions/references/signals/cxt_fx_power_V221107.md new file mode 100644 index 000000000..0f33721bb --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_fx_power_V221107.md @@ -0,0 +1,27 @@ +# cxt_fx_power_V221107:倒数分型强弱 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}F_分型强弱V221107` + +## 信号逻辑 + +1. 读取倒数第 `di` 个分型; +2. `v1 = 分型强弱(power_str) + 顶/底`; +3. `v2 = 有中枢/无中枢`。 + +## 信号列表示例 + +- `Signal('60分钟_D1F_分型强弱_强顶_有中枢_任意_0')` +- `Signal('60分钟_D2F_分型强弱_弱底_无中枢_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 个分型,默认 `1`; +- 仅当分型列表长度满足要求时输出具体强弱,否则返回 `其他`。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_fx_power_V221107` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_intraday_V230701.md b/.claude/skills/signal-functions/references/signals/cxt_intraday_V230701.md new file mode 100644 index 000000000..97447b7b7 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_intraday_V230701.md @@ -0,0 +1,13 @@ +# cxt_intraday_V230701:30分钟日内走势分类 + +> 模块: `cxt_trader.rs` | 类别: `trader` + +## 参数模板 + +`"{freq1}#{freq2}_D{di}日_走势分类V230701` + +## 信号逻辑 + +1. 取指定日的 30 分钟 bars; +2. 识别无中枢、双中枢、单中枢平衡市; +3. 返回对应日内结构标签。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_nine_bi_V230621.md b/.claude/skills/signal-functions/references/signals/cxt_nine_bi_V230621.md new file mode 100644 index 000000000..09ec78a7c --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_nine_bi_V230621.md @@ -0,0 +1,27 @@ +# cxt_nine_bi_V230621:九笔形态分类信号 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}九笔_形态V230621` + +## 信号逻辑 + +1. 读取最近 9 笔,依据首末极值和中间中枢关系构造结构; +2. 识别 aAb、aAbcd、ABC、类趋势一买卖、类三买卖、类二买卖等九笔形态; +3. 未命中时返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1九笔_形态V230621_aAb式类一买_任意_任意_0')` +- `Signal('60分钟_D1九笔_形态V230621_ZG三买_任意_任意_0')` + +## 参数说明 + +- `di`:从倒数第 `di` 笔开始取样,默认 `1`; +- 该分类信号直接返回形态名,不再附加辅助标签。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_nine_bi_V230621` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_overlap_V240526.md b/.claude/skills/signal-functions/references/signals/cxt_overlap_V240526.md new file mode 100644 index 000000000..5f88f5258 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_overlap_V240526.md @@ -0,0 +1,27 @@ +# cxt_overlap_V240526:收盘价与最近分型区间重合次数 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_顶底重合_支撑压力V240526` + +## 信号逻辑 + +1. 取最近 9 笔,读取最新收盘价; +2. 分别统计收盘价落在向上笔顶分型区间和向下笔底分型区间中的次数; +3. 输出 `顶重合X次` 与 `底重合Y次`,用于支撑压力观察。 + +## 信号列表示例 + +- `Signal('60分钟_顶底重合_支撑压力V240526_顶重合2次_底重合1次_任意_0')` +- `Signal('60分钟_顶底重合_支撑压力V240526_顶重合0次_底重合3次_任意_0')` + +## 参数说明 + +- 本信号无额外参数,`params` 可为空; +- 至少要求 11 笔与非空原始 K 线序列。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_overlap_V240526` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_overlap_V240612.md b/.claude/skills/signal-functions/references/signals/cxt_overlap_V240612.md new file mode 100644 index 000000000..d8e4710f1 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_overlap_V240612.md @@ -0,0 +1,27 @@ +# cxt_overlap_V240612:顺畅笔分型支撑压力信号 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_SNR顺畅N{n}_支撑压力V240612` + +## 信号逻辑 + +1. 在最近 `n` 笔中筛选原始 K 线数量足够的笔,并按 SNR 排序; +2. 选择 SNR 最高且大于阈值的“顺畅笔”,提取其顶分型与底分型区间; +3. 若最新笔终点分型与这些区间重叠,则输出 `支撑/压力 + 顶分型/底分型`。 + +## 信号列表示例 + +- `Signal('60分钟_SNR顺畅N7_支撑压力V240612_支撑_顺畅笔顶分型_任意_0')` +- `Signal('60分钟_SNR顺畅N7_支撑压力V240612_压力_顺畅笔底分型_任意_0')` + +## 参数说明 + +- `n`:候选顺畅笔数量窗口,默认 `7`; +- 仅当最大 SNR 不低于 `0.7` 时才输出具体支撑压力。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_overlap_V240612` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_range_oscillation_V230620.md b/.claude/skills/signal-functions/references/signals/cxt_range_oscillation_V230620.md new file mode 100644 index 000000000..e579f4f94 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_range_oscillation_V230620.md @@ -0,0 +1,27 @@ +# cxt_range_oscillation_V230620:区间震荡笔数统计 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}TH{th}_区间震荡V230620` + +## 信号逻辑 + +1. 读取最近 12 笔的中位价格中心; +2. 从最新笔向前逐笔比较中心振幅,只要最大振幅百分比小于 `th` 就继续累加; +3. 若累计笔数超过 1,则输出 `X笔震荡 + 向上/向下`。 + +## 信号列表示例 + +- `Signal('60分钟_D1TH2_区间震荡V230620_4笔震荡_向上_任意_0')` +- `Signal('60分钟_D1TH2_区间震荡V230620_6笔震荡_向下_任意_0')` + +## 参数说明 + +- `di`:从倒数第 `di` 笔开始取样,默认 `1`; +- `th`:中心振幅百分比阈值,默认 `2`。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_range_oscillation_V230620` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_second_bs_V230320.md b/.claude/skills/signal-functions/references/signals/cxt_second_bs_V230320.md new file mode 100644 index 000000000..0906fe969 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_second_bs_V230320.md @@ -0,0 +1,28 @@ +# cxt_second_bs_V230320:均线辅助识别第二类买卖点 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}#{ma_type}#{timeperiod}_BS2辅助V230320` + +## 信号逻辑 + +1. 取最近 5 笔,并计算关键分型右侧原始 K 线的均线值; +2. 若前两次同向回撤/反弹已偏离均线,而第 5 笔重新回到均线同向,判定 `二买/二卖`; +3. 均线、分型样本或笔数量不足时返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1#SMA#21_BS2辅助V230320_二买_任意_任意_0')` +- `Signal('60分钟_D1#EMA#34_BS2辅助V230320_二卖_任意_任意_0')` + +## 参数说明 + +- `di`:从倒数第 `di` 笔开始取样,默认 `1`; +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod`:均线周期,默认 `21`。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_second_bs_V230320` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_second_bs_V240524.md b/.claude/skills/signal-functions/references/signals/cxt_second_bs_V240524.md new file mode 100644 index 000000000..033e26c8f --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_second_bs_V240524.md @@ -0,0 +1,28 @@ +# cxt_second_bs_V240524:第二买卖点重叠计数信号 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}W{w}T{t}_第二买卖点V240524` + +## 信号逻辑 + +1. 读取最近 `w` 笔,统计最后一笔终点分型与前面笔终点分型的重叠次数; +2. 最后一笔为向下且长度足够、重叠次数不少于 `t` 时判定 `二买`; +3. 最后一笔为向上且满足同样条件时判定 `二卖`。 + +## 信号列表示例 + +- `Signal('60分钟_D1W9T2_第二买卖点V240524_二买_任意_任意_0')` +- `Signal('60分钟_D1W9T2_第二买卖点V240524_二卖_任意_任意_0')` + +## 参数说明 + +- `di`:从倒数第 `di` 笔开始取样,默认 `1`; +- `w`:统计窗口笔数,默认 `9`; +- `t`:最少重叠次数,默认 `2`。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_second_bs_V240524` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_seven_bi_V230620.md b/.claude/skills/signal-functions/references/signals/cxt_seven_bi_V230620.md new file mode 100644 index 000000000..af0f5e1df --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_seven_bi_V230620.md @@ -0,0 +1,27 @@ +# cxt_seven_bi_V230620:七笔形态分类信号 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}七笔_形态V230620` + +## 信号逻辑 + +1. 读取最近 7 笔并统计极值与关键中枢关系; +2. 识别 aAbcd、abcAd、类趋势、向上/向下中枢完成、类三买卖等七笔结构; +3. 未命中预定义结构时返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1七笔_形态V230620_aAbcd式底背驰_任意_任意_0')` +- `Signal('60分钟_D1七笔_形态V230620_向上中枢完成_任意_任意_0')` + +## 参数说明 + +- `di`:从倒数第 `di` 笔开始取样,默认 `1`; +- 仅在最近结构已基本完成且 UBI 不长时评估。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_seven_bi_V230620` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_third_bs_V230318.md b/.claude/skills/signal-functions/references/signals/cxt_third_bs_V230318.md new file mode 100644 index 000000000..c24374ac5 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_third_bs_V230318.md @@ -0,0 +1,28 @@ +# cxt_third_bs_V230318:均线辅助识别第三类买卖点 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}#{ma_type}#{timeperiod}_BS3辅助V230318` + +## 信号逻辑 + +1. 取最近 5 笔构造中枢,并计算第 1、3、5 笔终点分型的均线; +2. 若第 5 笔离开中枢,且三次均线值同向抬升或下降,则判定 `三买/三卖`; +3. 中枢无效、均线缺失或笔数不足时返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1#SMA#34_BS3辅助V230318_三买_任意_任意_0')` +- `Signal('60分钟_D1#EMA#34_BS3辅助V230318_三卖_任意_任意_0')` + +## 参数说明 + +- `di`:从倒数第 `di` 笔开始取样,默认 `1`; +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod`:均线周期,默认 `34`。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_third_bs_V230318` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_third_bs_V230319.md b/.claude/skills/signal-functions/references/signals/cxt_third_bs_V230319.md new file mode 100644 index 000000000..018c846a1 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_third_bs_V230319.md @@ -0,0 +1,28 @@ +# cxt_third_bs_V230319:带均线形态的第三类买卖点辅助 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}#{ma_type}#{timeperiod}_BS3辅助V230319` + +## 信号逻辑 + +1. 取最近 5 笔构造中枢,并读取第 1、3、5 笔终点分型的均线值; +2. 先根据第 5 笔是否离开中枢,判定 `三买/三卖`; +3. 再根据三次均线相对位置补充 `均线新高/新低/顶分/底分/否定`。 + +## 信号列表示例 + +- `Signal('60分钟_D1#SMA#34_BS3辅助V230319_三买_均线新高_任意_0')` +- `Signal('60分钟_D1#EMA#34_BS3辅助V230319_三卖_均线顶分_任意_0')` + +## 参数说明 + +- `di`:从倒数第 `di` 笔开始取样,默认 `1`; +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod`:均线周期,默认 `34`。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_third_bs_V230319` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_third_buy_V230228.md b/.claude/skills/signal-functions/references/signals/cxt_third_buy_V230228.md new file mode 100644 index 000000000..5098155fa --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_third_buy_V230228.md @@ -0,0 +1,27 @@ +# cxt_third_buy_V230228:笔三买辅助 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}_三买辅助V230228` + +## 信号逻辑 + +1. 依次尝试最近 `13/11/9/7/5` 笔加末笔,共 `n + 1` 笔; +2. 从奇数位上升关键笔中提取突破结构,要求末笔低点在关键高点上方并满足价格约束; +3. 满足条件时输出 `三买_XX笔`,否则返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1_三买辅助V230228_三买_6笔_任意_0')` +- `Signal('60分钟_D1_三买辅助V230228_三买_10笔_任意_0')` + +## 参数说明 + +- `di`:从倒数第 `di` 笔开始取样,默认 `1`; +- 该函数仅输出三买,不输出三卖。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_third_buy_V230228` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_three_bi_V230618.md b/.claude/skills/signal-functions/references/signals/cxt_three_bi_V230618.md new file mode 100644 index 000000000..a2ba7b26b --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_three_bi_V230618.md @@ -0,0 +1,27 @@ +# cxt_three_bi_V230618:三笔形态分类信号 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}三笔_形态V230618` + +## 信号逻辑 + +1. 读取最近 3 笔,依据第 1 笔和第 3 笔的高低点关系划分形态; +2. 识别不重合、奔走、收敛、扩张、盘背、无背等典型三笔结构; +3. 若不满足任何预定义结构,则返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1三笔_形态V230618_向下盘背_任意_任意_0')` +- `Signal('60分钟_D1三笔_形态V230618_向上扩张_任意_任意_0')` + +## 参数说明 + +- `di`:从倒数第 `di` 笔开始取样,默认 `1`; +- 仅在未完成笔较短时评估三笔形态。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_three_bi_V230618` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_ubi_end_V230816.md b/.claude/skills/signal-functions/references/signals/cxt_ubi_end_V230816.md new file mode 100644 index 000000000..92b754954 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_ubi_end_V230816.md @@ -0,0 +1,27 @@ +# cxt_ubi_end_V230816:UBI 新高新低次数信号 + +> 模块: `cxt.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_UBI_BE辅助V230816` + +## 信号逻辑 + +1. 重建当前未完成笔 UBI 结构; +2. 若 UBI 向上,则统计内部顶分型的逐次新高次数;若最后一根再次上破,则输出 `新高_第X次`; +3. UBI 向下时对称统计新低次数,否则返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_UBI_BE辅助V230816_新高_第3次_任意_0')` +- `Signal('60分钟_UBI_BE辅助V230816_新低_第2次_任意_0')` + +## 参数说明 + +- 本信号无额外参数,`params` 可为空; +- 需要 UBI 已形成足够分型和原始 K 线长度。 + +## 对齐说明 + +与 Python `czsc.signals.cxt_ubi_end_V230816` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/cxt_zhong_shu_gong_zhen_V221221.md b/.claude/skills/signal-functions/references/signals/cxt_zhong_shu_gong_zhen_V221221.md new file mode 100644 index 000000000..c731d32a5 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/cxt_zhong_shu_gong_zhen_V221221.md @@ -0,0 +1,13 @@ +# cxt_zhong_shu_gong_zhen_V221221:大小级别中枢共振 + +> 模块: `cxt_trader.rs` | 类别: `trader` + +## 参数模板 + +`"{freq1}_{freq2}_中枢共振V221221` + +## 信号逻辑 + +1. 大小级别最近 3 笔均构成有效中枢; +2. 小级别中枢位置相对大级别中轴偏上且末笔向下,判 `看多`; +3. 小级别中枢位置相对大级别中轴偏下且末笔向上,判 `看空`。 diff --git a/.claude/skills/signal-functions/references/signals/dema_up_dw_line_V230605.md b/.claude/skills/signal-functions/references/signals/dema_up_dw_line_V230605.md new file mode 100644 index 000000000..200276e18 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/dema_up_dw_line_V230605.md @@ -0,0 +1,26 @@ +# dema_up_dw_line_V230605:DEMA 短线趋势信号 + +> 模块: `ang.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}_DEMA短线趋势V230605` + +## 信号逻辑 + +1. 用 `n` 与 `2n` 窗口均值构造 `dema = 2*MA(n)-MA(2n)`; +2. 最新收盘价高于 dema 判 `看多`,否则判 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N5_DEMA短线趋势V230605_看多_任意_任意_0')` +- `Signal('60分钟_D1N5_DEMA短线趋势V230605_看空_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `n`:短窗口,默认 `5`。 + +## 对齐说明 + +按 Python `dema_up_dw_line_V230605` 的近似 DEMA 口径实现。 diff --git a/.claude/skills/signal-functions/references/signals/demakder_up_dw_line_V230605.md b/.claude/skills/signal-functions/references/signals/demakder_up_dw_line_V230605.md new file mode 100644 index 000000000..5225f0eb3 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/demakder_up_dw_line_V230605.md @@ -0,0 +1,28 @@ +# demakder_up_dw_line_V230605:DEMAKER 价格趋势信号 + +> 模块: `ang.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}TH{th}TL{tl}_DEMAKER价格趋势V230605` + +## 信号逻辑 + +1. 统计窗口内上涨高点均值 `demax` 与下跌低点均值 `demin`; +2. 计算 `demaker = demax / (demax + demin)`; +3. `demaker > th/10` 判 `看多`,`demaker < tl/10` 判 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N105TH5TL5_DEMAKER价格趋势V230605_看多_任意_任意_0')` +- `Signal('60分钟_D1N105TH5TL5_DEMAKER价格趋势V230605_看空_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `n`:统计窗口,默认 `105`; +- `th/tl`:上下阈值(除以 10 使用),默认 `5/5`。 + +## 对齐说明 + +保持 Python `demakder_up_dw_line_V230605` 对空样本返回 NaN 的行为。 diff --git a/.claude/skills/signal-functions/references/signals/emv_up_dw_line_V230605.md b/.claude/skills/signal-functions/references/signals/emv_up_dw_line_V230605.md new file mode 100644 index 000000000..fc53393b3 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/emv_up_dw_line_V230605.md @@ -0,0 +1,26 @@ +# emv_up_dw_line_V230605:EMV 简易波动多空信号 + +> 模块: `ang.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}_EMV简易波动V230605` + +## 信号逻辑 + +1. 取最近两根K线计算中点位移; +2. 以成交量/振幅形成箱体比率; +3. `emv > 0` 判 `看多`,否则判 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1_EMV简易波动V230605_看多_任意_任意_0')` +- `Signal('60分钟_D1_EMV简易波动V230605_看空_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`。 + +## 对齐说明 + +与 Python `emv_up_dw_line_V230605` 的两根K线近似 EMV 计算一致。 diff --git a/.claude/skills/signal-functions/references/signals/er_up_dw_line_V230604.md b/.claude/skills/signal-functions/references/signals/er_up_dw_line_V230604.md new file mode 100644 index 000000000..d5cf112a8 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/er_up_dw_line_V230604.md @@ -0,0 +1,29 @@ +# er_up_dw_line_V230604:ER 价格动量分层信号 + +> 模块: `ang.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}W{w}N{n}_ER价格动量V230604` + +## 信号逻辑 + +1. 以 `W` 窗口均价构造 bull/bear power 因子; +2. 仅保留与末值同号的因子子序列; +3. 末值正负给出 `均线上方/均线下方`; +4. 对同号子序列做 `N` 分箱输出 `第x层`。 + +## 信号列表示例 + +- `Signal('60分钟_D1W60N10_ER价格动量V230604_均线上方_第3层_任意_0')` +- `Signal('60分钟_D1W60N10_ER价格动量V230604_均线下方_第8层_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `w`:均价窗口,默认 `60`; +- `n`:分层数量,默认 `10`。 + +## 对齐说明 + +与 Python `er_up_dw_line_V230604` 的同号过滤与分层规则一致。 diff --git a/.claude/skills/signal-functions/references/signals/jcc_ci_tou_V221101.md b/.claude/skills/signal-functions/references/signals/jcc_ci_tou_V221101.md new file mode 100644 index 000000000..e92cc6fa5 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/jcc_ci_tou_V221101.md @@ -0,0 +1,28 @@ +# jcc_ci_tou_V221101:刺透形态 + +> 模块: `jcc.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}Z{z}TH{th}_刺透形态` + +## 信号逻辑 + +1. 前一根为大阴线且跌幅超过 `z`; +2. 当前低开并收盘刺入前一根实体 `th` 比例以上; +3. 满足则返回 `满足`。 + +## 信号列表示例 + +- `Signal('15分钟_D1Z100TH50_刺透形态_满足_任意_任意_0')` +- `Signal('15分钟_D1Z100TH50_刺透形态_其他_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `z`:前一根最小跌幅(BP),默认 `100`; +- `th`:刺入比例阈值,默认 `50`。 + +## 对齐说明 + +与 Python `jcc_ci_tou_V221101` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/jcc_fan_ji_xian_V221121.md b/.claude/skills/signal-functions/references/signals/jcc_fan_ji_xian_V221121.md new file mode 100644 index 000000000..43252da15 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/jcc_fan_ji_xian_V221121.md @@ -0,0 +1,26 @@ +# jcc_fan_ji_xian_V221121:反击线 + +> 模块: `jcc.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}_反击线` + +## 信号逻辑 + +1. 最近20根内检测收盘接近、跳空幅度和实体强度; +2. 满足基础条件后,按区间位置和方向细分 `看涨反击线/看跌反击线`; +3. 否则返回 `其他`。 + +## 信号列表示例 + +- `Signal('15分钟_D1_反击线_满足_看涨反击线_任意_0')` +- `Signal('15分钟_D1_反击线_满足_看跌反击线_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`。 + +## 对齐说明 + +与 Python `jcc_fan_ji_xian_V221121` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/jcc_fen_shou_xian_V20221113.md b/.claude/skills/signal-functions/references/signals/jcc_fen_shou_xian_V20221113.md new file mode 100644 index 000000000..9bb2efbc0 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/jcc_fen_shou_xian_V20221113.md @@ -0,0 +1,27 @@ +# jcc_fen_shou_xian_V20221113:分手线 + +> 模块: `jcc.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}K_分手线` + +## 信号逻辑 + +1. 两根K线同开盘,且第二根收盘突破第一根高低点,判 `满足`; +2. 结合区间位置与实体方向,细分 `上升分手/下跌分手`; +3. 否则返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1K_分手线_满足_上升分手_任意_0')` +- `Signal('60分钟_D1K_分手线_满足_下跌分手_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `zdf`:分手强度阈值(BP),默认 `300`。 + +## 对齐说明 + +与 Python `jcc_fen_shou_xian_V20221113` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/jcc_gap_yin_yang_V221121.md b/.claude/skills/signal-functions/references/signals/jcc_gap_yin_yang_V221121.md new file mode 100644 index 000000000..e507b170a --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/jcc_gap_yin_yang_V221121.md @@ -0,0 +1,26 @@ +# jcc_gap_yin_yang_V221121:跳空并列阴阳 + +> 模块: `jcc.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}K_并列阴阳` + +## 信号逻辑 + +1. 最近三根满足跳空窗口; +2. 两根并列阴阳实体方差小于阈值; +3. 判定 `向上跳空/向下跳空`。 + +## 信号列表示例 + +- `Signal('15分钟_D1K_并列阴阳_向上跳空_任意_任意_0')` +- `Signal('15分钟_D1K_并列阴阳_向下跳空_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`。 + +## 对齐说明 + +与 Python `jcc_gap_yin_yang_V221121` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/jcc_ping_tou_V221113.md b/.claude/skills/signal-functions/references/signals/jcc_ping_tou_V221113.md new file mode 100644 index 000000000..25560b282 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/jcc_ping_tou_V221113.md @@ -0,0 +1,27 @@ +# jcc_ping_tou_V221113:平头形态 + +> 模块: `jcc.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}TH{th}_平头形态` + +## 信号逻辑 + +1. 对比两根K线高点或低点差值比例,识别 `顶部/底部`; +2. 实体条件满足时给出 `实体标准` 标签; +3. 否则返回 `其他`。 + +## 信号列表示例 + +- `Signal('15分钟_D2TH100_平头形态_顶部_实体标准_任意_0')` +- `Signal('15分钟_D2TH100_平头形态_底部_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `2`; +- `th`:高低点容差阈值(BP),默认 `100`。 + +## 对齐说明 + +与 Python `jcc_ping_tou_V221113` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/jcc_san_fa_V20221115.md b/.claude/skills/signal-functions/references/signals/jcc_san_fa_V20221115.md new file mode 100644 index 000000000..5d8ccd7fe --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/jcc_san_fa_V20221115.md @@ -0,0 +1,27 @@ +# jcc_san_fa_V20221115:三法形态 + +> 模块: `jcc.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}K_三法` + +## 信号逻辑 + +1. 固定观察 6 根K线,比较首尾与中间三根位置关系; +2. 满足基础强度阈值时,判定 `上升三法/下降三法`; +3. 不满足返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1K_三法_满足_上升三法_任意_0')` +- `Signal('60分钟_D1K_三法_满足_下降三法_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `zdf`:首尾涨跌幅阈值(BP),默认 `500`。 + +## 对齐说明 + +与 Python `jcc_san_fa_V20221115` 判定一致。 diff --git a/.claude/skills/signal-functions/references/signals/jcc_san_fa_V20221118.md b/.claude/skills/signal-functions/references/signals/jcc_san_fa_V20221118.md new file mode 100644 index 000000000..a39c81e53 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/jcc_san_fa_V20221118.md @@ -0,0 +1,26 @@ +# jcc_san_fa_V20221118:三法形态A + +> 模块: `jcc.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}K_三法A` + +## 信号逻辑 + +1. 在 5~8 根窗口内扫描三法形态; +2. 满足上升三法或下降三法时输出方向; +3. `v2` 记录触发窗口长度 `nK`。 + +## 信号列表示例 + +- `Signal('60分钟_D1K_三法A_上升三法_6K_任意_0')` +- `Signal('60分钟_D1K_三法A_下降三法_8K_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`。 + +## 对齐说明 + +窗口遍历与条件组合对齐 Python `jcc_san_fa_V20221118`。 diff --git a/.claude/skills/signal-functions/references/signals/jcc_san_szx_V221122.md b/.claude/skills/signal-functions/references/signals/jcc_san_szx_V221122.md new file mode 100644 index 000000000..dc2de96c2 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/jcc_san_szx_V221122.md @@ -0,0 +1,27 @@ +# jcc_san_szx_V221122:三星形态 + +> 模块: `jcc.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}T{th}_三星` + +## 信号逻辑 + +1. 取最近5根K线; +2. 统计十字线数量; +3. 不少于3根判 `满足`。 + +## 信号列表示例 + +- `Signal('15分钟_D1T10_三星_满足_任意_任意_0')` +- `Signal('15分钟_D1T10_三星_其他_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `th`:十字线阈值,默认 `10`。 + +## 对齐说明 + +与 Python `jcc_san_szx_V221122` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/jcc_san_xing_xian_V221023.md b/.claude/skills/signal-functions/references/signals/jcc_san_xing_xian_V221023.md new file mode 100644 index 000000000..6cca1f131 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/jcc_san_xing_xian_V221023.md @@ -0,0 +1,27 @@ +# jcc_san_xing_xian_V221023:伞形线形态信号 + +> 模块: `jcc.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}TH{th}_伞形线` + +## 信号逻辑 + +1. 判断当前K线是否满足长下影短上影(下影 > 实体 * th,且上影 < 0.2 * 实体); +2. 若满足,再结合左侧20根区间位置,判定 `锤子/上吊`; +3. 不满足时返回 `其他`。 + +## 信号列表示例 + +- `Signal('15分钟_D1TH200_伞形线_满足_锤子_任意_0')` +- `Signal('15分钟_D1TH200_伞形线_满足_上吊_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `th`:下影线与实体倍数阈值,默认 `2`(内部按 100 倍整数编码)。 + +## 对齐说明 + +与 Python `jcc_san_xing_xian_V221023` 判定顺序一致。 diff --git a/.claude/skills/signal-functions/references/signals/jcc_shan_chun_V221121.md b/.claude/skills/signal-functions/references/signals/jcc_shan_chun_V221121.md new file mode 100644 index 000000000..151261487 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/jcc_shan_chun_V221121.md @@ -0,0 +1,26 @@ +# jcc_shan_chun_V221121:山川形态 + +> 模块: `jcc.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}B_山川形态` + +## 信号逻辑 + +1. 取最近5笔; +2. 末笔向上且 `5/3/1` 笔高点方差小,判 `三山`; +3. 末笔向下且 `5/3/1` 笔低点方差小,判 `三川`。 + +## 信号列表示例 + +- `Signal('15分钟_D1B_山川形态_三山_任意_任意_0')` +- `Signal('15分钟_D1B_山川形态_三川_任意_任意_0')` + +## 参数说明 + +- `di`:截止倒数第 `di` 笔,默认 `1`。 + +## 对齐说明 + +方差阈值与 Python `jcc_shan_chun_V221121` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/jcc_szx_V221111.md b/.claude/skills/signal-functions/references/signals/jcc_szx_V221111.md new file mode 100644 index 000000000..a1cdd43e5 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/jcc_szx_V221111.md @@ -0,0 +1,27 @@ +# jcc_szx_V221111:十字线 + +> 模块: `jcc.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}TH{th}_十字线` + +## 信号逻辑 + +1. `(high-low)/|close-open| > th` 或 `close==open` 判十字线; +2. 按上下影长度细分 `蜻蜓/墓碑/长腿/十字线`; +3. 前一根强阳时追加 `北方`。 + +## 信号列表示例 + +- `Signal('60分钟_D1TH10_十字线_蜻蜓十字线_北方_任意_0')` +- `Signal('60分钟_D1TH10_十字线_墓碑十字线_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `th`:十字线阈值,默认 `10`。 + +## 对齐说明 + +与 Python `jcc_szx_V221111` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/jcc_ta_xing_V221124.md b/.claude/skills/signal-functions/references/signals/jcc_ta_xing_V221124.md new file mode 100644 index 000000000..d6549af4d --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/jcc_ta_xing_V221124.md @@ -0,0 +1,26 @@ +# jcc_ta_xing_V221124:塔形顶底 + +> 模块: `jcc.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}K_塔形` + +## 信号逻辑 + +1. 在 5~9 根窗口内扫描; +2. 首尾实体最大且中间高低点聚集; +3. 判定 `顶部/底部`,并返回窗口长度。 + +## 信号列表示例 + +- `Signal('15分钟_D1K_塔形_顶部_6K_任意_0')` +- `Signal('15分钟_D1K_塔形_底部_8K_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`。 + +## 对齐说明 + +与 Python `jcc_ta_xing_V221124` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/jcc_ten_mo_V221028.md b/.claude/skills/signal-functions/references/signals/jcc_ten_mo_V221028.md new file mode 100644 index 000000000..9ef0cdaa7 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/jcc_ten_mo_V221028.md @@ -0,0 +1,26 @@ +# jcc_ten_mo_V221028:吞没形态 + +> 模块: `jcc.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}_吞没形态` + +## 信号逻辑 + +1. 当前K线高低点完全包住前一根K线,记 `满足`; +2. 结合左侧20根位置与实体方向,区分 `看涨吞没/看跌吞没`; +3. 否则返回 `其他`。 + +## 信号列表示例 + +- `Signal('15分钟_D1_吞没形态_满足_看涨吞没_任意_0')` +- `Signal('15分钟_D1_吞没形态_满足_看跌吞没_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`。 + +## 对齐说明 + +与 Python `jcc_ten_mo_V221028` 判定条件一致。 diff --git a/.claude/skills/signal-functions/references/signals/jcc_three_crow_V221108.md b/.claude/skills/signal-functions/references/signals/jcc_three_crow_V221108.md new file mode 100644 index 000000000..f082edb1d --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/jcc_three_crow_V221108.md @@ -0,0 +1,26 @@ +# jcc_three_crow_V221108:三只乌鸦 + +> 模块: `jcc.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}_三只乌鸦` + +## 信号逻辑 + +1. 三根连续阴线且高低收盘递降; +2. 影线和实体关系满足强空形态; +3. 根据开盘关系细分 `常规/加强/半加强`。 + +## 信号列表示例 + +- `Signal('30分钟_D1_三只乌鸦_满足_常规_任意_0')` +- `Signal('30分钟_D1_三只乌鸦_满足_加强_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`。 + +## 对齐说明 + +与 Python `jcc_three_crow_V221108` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/jcc_two_crow_V221108.md b/.claude/skills/signal-functions/references/signals/jcc_two_crow_V221108.md new file mode 100644 index 000000000..e82e06472 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/jcc_two_crow_V221108.md @@ -0,0 +1,26 @@ +# jcc_two_crow_V221108:两只乌鸦 + +> 模块: `jcc.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}K_两只乌鸦` + +## 信号逻辑 + +1. 第一根长阳; +2. 第二根高开低走且收在第一根高点之上; +3. 第三根阴线继续下压,判 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1K_两只乌鸦_看空_任意_任意_0')` +- `Signal('60分钟_D1K_两只乌鸦_其他_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`。 + +## 对齐说明 + +与 Python `jcc_two_crow_V221108` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/jcc_wu_yun_gai_ding_V221101.md b/.claude/skills/signal-functions/references/signals/jcc_wu_yun_gai_ding_V221101.md new file mode 100644 index 000000000..50957a91d --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/jcc_wu_yun_gai_ding_V221101.md @@ -0,0 +1,28 @@ +# jcc_wu_yun_gai_ding_V221101:乌云盖顶 + +> 模块: `jcc.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}Z{z}TH{th}_乌云盖顶` + +## 信号逻辑 + +1. 前一根阳线实体涨幅需大于 `z`; +2. 当前K线跳空高开,且收盘回落到前一根实体内部; +3. 前一根收盘位于左侧10根收盘高位,判定 `满足`。 + +## 信号列表示例 + +- `Signal('日线_D1Z500TH50_乌云盖顶_满足_任意_任意_0')` +- `Signal('日线_D1Z500TH50_乌云盖顶_其他_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `z`:前一根阳线最小涨幅(BP),默认 `500`; +- `th`:当前收盘扎入前一根实体比例,默认 `50`。 + +## 对齐说明 + +与 Python `jcc_wu_yun_gai_ding_V221101` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/jcc_xing_xian_V221118.md b/.claude/skills/signal-functions/references/signals/jcc_xing_xian_V221118.md new file mode 100644 index 000000000..4ace9b6b3 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/jcc_xing_xian_V221118.md @@ -0,0 +1,27 @@ +# jcc_xing_xian_V221118:星形线 + +> 模块: `jcc.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}TH{th}_星形线` + +## 信号逻辑 + +1. 取三根K线,按高低点结构区分启明星/黄昏星候选; +2. 再结合实体强弱关系做最终确认; +3. 中间K线开收相等时输出 `中间十字`。 + +## 信号列表示例 + +- `Signal('60分钟_D2TH2_星形线_启明星_中间十字_任意_0')` +- `Signal('60分钟_D2TH2_星形线_黄昏星_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `2`; +- `th`:左侧实体与中间实体的倍率阈值,默认 `2`。 + +## 对齐说明 + +与 Python `jcc_xing_xian_V221118` 条件一致。 diff --git a/.claude/skills/signal-functions/references/signals/jcc_yun_xian_V221118.md b/.claude/skills/signal-functions/references/signals/jcc_yun_xian_V221118.md new file mode 100644 index 000000000..bca350325 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/jcc_yun_xian_V221118.md @@ -0,0 +1,26 @@ +# jcc_yun_xian_V221118:孕线形态 + +> 模块: `jcc.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}_孕线` + +## 信号逻辑 + +1. 前一根为长实体,当前为小实体; +2. 当前开收位于前一根实体区间内; +3. 方向反转判 `看多/看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1_孕线_看多_任意_任意_0')` +- `Signal('60分钟_D1_孕线_看空_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`。 + +## 对齐说明 + +与 Python `jcc_yun_xian_V221118` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/jcc_zhu_huo_xian_V221027.md b/.claude/skills/signal-functions/references/signals/jcc_zhu_huo_xian_V221027.md new file mode 100644 index 000000000..7d23fc23a --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/jcc_zhu_huo_xian_V221027.md @@ -0,0 +1,28 @@ +# jcc_zhu_huo_xian_V221027:烛火线 + +> 模块: `jcc.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}T{th}F{zf}_烛火线` + +## 信号逻辑 + +1. 以影线、实体和振幅阈值判断是否 `满足`; +2. 结合左侧20根区间位置判定 `箭在弦/风中烛`; +3. 否则返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1T200F500_烛火线_满足_箭在弦_任意_0')` +- `Signal('60分钟_D1T200F500_烛火线_满足_风中烛_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `th`:上影线与实体倍数阈值,默认 `2`; +- `zf`:最小振幅阈值(BP),默认 `500`。 + +## 对齐说明 + +按 Python `jcc_zhu_huo_xian_V221027` 原始公式实现。 diff --git a/.claude/skills/signal-functions/references/signals/kcatr_up_dw_line_V230823.md b/.claude/skills/signal-functions/references/signals/kcatr_up_dw_line_V230823.md new file mode 100644 index 000000000..33d6c03d2 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/kcatr_up_dw_line_V230823.md @@ -0,0 +1,30 @@ +# kcatr_up_dw_line_V230823:ATR 通道突破多空 + +> 模块: `kcatr.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}M{m}T{th}_KCATR多空V230823` + +## 信号逻辑 + +1. 在最近 `n` 根上计算平均真实波幅 `ATR`; +2. 在最近 `m` 根上计算收盘均值 `middle`; +3. 最新收盘价大于 `middle + ATR * th` 判 `看多`; +4. 最新收盘价小于 `middle - ATR * th` 判 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N30M16T2_KCATR多空V230823_看多_任意_任意_0')` +- `Signal('60分钟_D1N30M16T2_KCATR多空V230823_看空_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `n`:ATR 计算窗口,默认 `30`; +- `m`:中轨均值窗口,默认 `16`; +- `th`:ATR 倍数阈值,默认 `2`。 + +## 对齐说明 + +ATR 取样与突破阈值口径对齐 Python `kcatr_up_dw_line_V230823`。 diff --git a/.claude/skills/signal-functions/references/signals/ntmdk_V230824.md b/.claude/skills/signal-functions/references/signals/ntmdk_V230824.md new file mode 100644 index 000000000..659fba20e --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/ntmdk_V230824.md @@ -0,0 +1,27 @@ +# ntmdk_V230824:M 日前收盘价对比多空 + +> 模块: `ntmdk.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}M{m}_NTMDK多空V230824` + +## 信号逻辑 + +1. 取截止倒数第 `di` 根的最近 `m` 根K线; +2. 若末根收盘价大于首根收盘价,判 `看多`; +3. 否则判 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1M10_NTMDK多空V230824_看多_任意_任意_0')` +- `Signal('60分钟_D1M10_NTMDK多空V230824_看空_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `m`:回看比较窗口,默认 `10`。 + +## 对齐说明 + +比较口径对齐 Python `ntmdk_V230824`。 diff --git a/.claude/skills/signal-functions/references/signals/obv_up_dw_line_V230719.md b/.claude/skills/signal-functions/references/signals/obv_up_dw_line_V230719.md new file mode 100644 index 000000000..fe79c5194 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/obv_up_dw_line_V230719.md @@ -0,0 +1,30 @@ +# obv_up_dw_line_V230719:OBV 交叉信号 + +> 模块: `obv.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}M{m}MO{max_overlap}_OBV能量V230719` + +## 信号逻辑 + +1. 先计算 OBV 累计量序列; +2. 计算 `obvm = EMA(OBV, n)`,再计算 `sig = EMA(obvm, m)`; +3. 若当前 `obvm > sig` 且 `max_overlap` 根前 `obvm < sig`,判 `看多`; +4. 若当前 `obvm < sig` 且 `max_overlap` 根前 `obvm > sig`,判 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N7M10MO3_OBV能量V230719_看多_任意_任意_0')` +- `Signal('60分钟_D1N7M10MO3_OBV能量V230719_看空_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `n`:OBVM EMA 周期,默认 `7`; +- `m`:信号线 EMA 周期,默认 `10`; +- `max_overlap`:交叉回看根数,默认 `3`。 + +## 对齐说明 + +交叉判定时点与 Python `obv_up_dw_line_V230719` 完全一致(使用 `-max_overlap`)。 diff --git a/.claude/skills/signal-functions/references/signals/obvm_line_V230610.md b/.claude/skills/signal-functions/references/signals/obvm_line_V230610.md new file mode 100644 index 000000000..5a22a82ee --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/obvm_line_V230610.md @@ -0,0 +1,28 @@ +# obvm_line_V230610:OBV 双 EMA 能量信号 + +> 模块: `obv.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}M{m}_OBV能量V230610` + +## 信号逻辑 + +1. 计算 OBV 累计量序列(阳线加量、阴线减量); +2. 分别计算 OBV 的短期 EMA(`n`) 与长期 EMA(`m`); +3. 短期 EMA 大于长期 EMA 判 `看多`,否则判 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N10M30_OBV能量V230610_看多_任意_任意_0')` +- `Signal('60分钟_D1N10M30_OBV能量V230610_看空_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `n`:短期 EMA 周期,默认 `10`; +- `m`:长期 EMA 周期,默认 `30`。 + +## 对齐说明 + +OBV 构造方式与 Python `obvm_line_V230610` 一致(按 K 线涨跌符号加减成交量)。 diff --git a/.claude/skills/signal-functions/references/signals/pos_bar_stop_V230524.md b/.claude/skills/signal-functions/references/signals/pos_bar_stop_V230524.md new file mode 100644 index 000000000..c0e98b3e3 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/pos_bar_stop_V230524.md @@ -0,0 +1,23 @@ +# pos_bar_stop_V230524:按开仓点附近N根K线极值止损 + +> 模块: `pos.rs` | 类别: `trader` + +## 参数模板 + +`"{pos_name}_{freq1}N{n}K_止损V230524` + +## 信号逻辑 + +- 多头:开仓前最近 `n` 根K线最低价被最新价跌破,记 `多头止损`; +- 空头:开仓前最近 `n` 根K线最高价被最新价突破,记 `空头止损`。 + +## 信号列表示例 + +- `Signal('日线三买多头_日线N3K_止损V230524_多头止损_任意_任意_0')` +- `Signal('日线三买多头_日线N3K_止损V230524_空头止损_任意_任意_0')` + +## 参数说明 + +- `pos_name`:仓位名称; +- `freq1`:K线周期; +- `n`:向前取K线数量,默认 `3`,有效范围 `[1, 20]`。 diff --git a/.claude/skills/signal-functions/references/signals/pos_fix_exit_V230624.md b/.claude/skills/signal-functions/references/signals/pos_fix_exit_V230624.md new file mode 100644 index 000000000..d85c88aac --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/pos_fix_exit_V230624.md @@ -0,0 +1,22 @@ +# pos_fix_exit_V230624:固定 BP 止盈止损 + +> 模块: `pos.rs` | 类别: `trader` + +## 参数模板 + +`"{pos_name}_固定{th}BP止盈止损_出场V230624` + +## 信号逻辑 + +- 多头:现价低于 `开仓价*(1-th/10000)` 为 `多头止损`,高于 `开仓价*(1+th/10000)` 为 `多头止盈`; +- 空头:规则镜像为 `空头止损/空头止盈`。 + +## 信号列表示例 + +- `Signal('日线三买多头_固定100BP止盈止损_出场V230624_多头止损_任意_任意_0')` +- `Signal('日线三买多头_固定100BP止盈止损_出场V230624_空头止盈_任意_任意_0')` + +## 参数说明 + +- `pos_name`:仓位名称; +- `th`:止盈止损阈值(BP),默认 `300`。 diff --git a/.claude/skills/signal-functions/references/signals/pos_fx_stop_V230414.md b/.claude/skills/signal-functions/references/signals/pos_fx_stop_V230414.md new file mode 100644 index 000000000..5e470d432 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/pos_fx_stop_V230414.md @@ -0,0 +1,24 @@ +# pos_fx_stop_V230414:按开仓点附近分型止损 + +> 模块: `pos.rs` | 类别: `trader` + +## 参数模板 + +`"{freq1}_{pos_name}N{n}_止损V230414` + +## 信号逻辑 + +- 多头:取开仓前最近 `n` 个底分型,最新价跌破最低分型低点,记 `多头止损`; +- 空头:取开仓前最近 `n` 个顶分型,最新价突破最高分型高点,记 `空头止损`; +- 其余场景记 `其他`。 + +## 信号列表示例 + +- `Signal('日线_日线三买多头N1_止损V230414_多头止损_任意_任意_0')` +- `Signal('日线_日线三买多头N1_止损V230414_空头止损_任意_任意_0')` + +## 参数说明 + +- `pos_name`:仓位名称; +- `freq1`:K线周期; +- `n`:向前取分型数量,默认 `3`。 diff --git a/.claude/skills/signal-functions/references/signals/pos_holds_V230414.md b/.claude/skills/signal-functions/references/signals/pos_holds_V230414.md new file mode 100644 index 000000000..aadcb66a8 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/pos_holds_V230414.md @@ -0,0 +1,24 @@ +# pos_holds_V230414:开仓后 N 根K线收益与阈值比较 + +> 模块: `pos.rs` | 类别: `trader` + +## 参数模板 + +`"{pos_name}_{freq1}N{n}M{m}_趋势判断V230414` + +## 信号逻辑 + +- 多头:`zdf=(最新收盘-开仓价)/开仓价*10000`,`zdf 模块: `pos.rs` | 类别: `trader` + +## 参数模板 + +`"{pos_name}_{freq1}N{n}M{m}T{t}_BS辅助V230807` + +## 信号逻辑 + +- 当开仓后收益落在 `(t, m)` 区间时触发 `多头保本/空头保本`; +- 含义是“达到最低保本收益但未达到趋势确认阈值”,优先保本离场。 + +## 信号列表示例 + +- `Signal('日线三买多头N1_60分钟N5M50T10_BS辅助V230807_多头保本_任意_任意_0')` +- `Signal('日线三买多头N1_60分钟N5M50T10_BS辅助V230807_空头保本_任意_任意_0')` + +## 参数说明 + +- `pos_name`:仓位名称; +- `freq1`:K线周期; +- `n`:最少持有K线数量,默认 `5`; +- `m`:收益上限阈值(BP),默认 `50`; +- `t`:保本收益阈值(BP),默认 `10`,且要求 `m > t > 0`。 diff --git a/.claude/skills/signal-functions/references/signals/pos_holds_V240428.md b/.claude/skills/signal-functions/references/signals/pos_holds_V240428.md new file mode 100644 index 000000000..1063d3530 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/pos_holds_V240428.md @@ -0,0 +1,25 @@ +# pos_holds_V240428:最大盈利回撤比例保本 + +> 模块: `pos.rs` | 类别: `trader` + +## 参数模板 + +`"{pos_name}_{freq1}H{h}T{t}N{n}_保本V240428` + +## 信号逻辑 + +- 多头:最大盈利 `y1` 超过 `h` 且当前盈利 `y2 < y1*t/100`,记 `多头保本`; +- 空头:按镜像规则记 `空头保本`。 + +## 信号列表示例 + +- `Signal('日线三买多头N1_60分钟H100T20N5_保本V240428_多头保本_任意_任意_0')` +- `Signal('日线三买多头N1_60分钟H100T20N5_保本V240428_空头保本_任意_任意_0')` + +## 参数说明 + +- `pos_name`:仓位名称; +- `freq1`:K线周期; +- `h`:最大盈利阈值(BP),默认 `100`; +- `t`:回撤比例阈值(%),默认 `20`; +- `n`:最少持有K线数量,默认 `5`。 diff --git a/.claude/skills/signal-functions/references/signals/pos_holds_V240608.md b/.claude/skills/signal-functions/references/signals/pos_holds_V240608.md new file mode 100644 index 000000000..44b661949 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/pos_holds_V240608.md @@ -0,0 +1,24 @@ +# pos_holds_V240608:跌破/升破开仓前窗口极值后,回到成本价指定档位保本 + +> 模块: `pos.rs` | 类别: `trader` + +## 参数模板 + +`"{pos_name}_{freq1}W{w}N{n}_保本V240608` + +## 信号逻辑 + +- 多头:若开仓后最低价跌破开仓前 `w` 根最低价,且现价回到成本价上方第 `n` 档,记 `多头保本`; +- 空头:若开仓后最高价突破开仓前 `w` 根最高价,且现价回到成本价下方第 `n` 档,记 `空头保本`。 + +## 信号列表示例 + +- `Signal('日线三买多头N1_60分钟W20N2_保本V240608_多头保本_任意_任意_0')` +- `Signal('日线三买多头N1_60分钟W20N2_保本V240608_空头保本_任意_任意_0')` + +## 参数说明 + +- `pos_name`:仓位名称; +- `freq1`:K线周期; +- `w`:开仓前观察窗口,默认 `20`; +- `n`:成本价上下档位偏移,默认 `2`。 diff --git a/.claude/skills/signal-functions/references/signals/pos_ma_V230414.md b/.claude/skills/signal-functions/references/signals/pos_ma_V230414.md new file mode 100644 index 000000000..4599b5ab8 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/pos_ma_V230414.md @@ -0,0 +1,25 @@ +# pos_ma_V230414:判断开仓后是否升破/跌破均线 + +> 模块: `pos.rs` | 类别: `trader` + +## 参数模板 + +`"{pos_name}_{freq1}#{ma_type}#{timeperiod}_持有状态V230414` + +## 信号逻辑 + +- 多头持仓:开仓后任一 bar 出现 `close > MA`,记 `多头_升破均线`; +- 空头持仓:开仓后任一 bar 出现 `close < MA`,记 `空头_跌破均线`; +- 其余场景返回 `其他_其他`。 + +## 信号列表示例 + +- `Signal('日线三买多头N1_60分钟#SMA#5_持有状态V230414_多头_升破均线_任意_0')` +- `Signal('日线三买多头N1_60分钟#SMA#5_持有状态V230414_空头_跌破均线_任意_0')` + +## 参数说明 + +- `pos_name`:仓位名称; +- `freq1`:K线周期; +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod`:均线周期参数,默认 `5`。 diff --git a/.claude/skills/signal-functions/references/signals/pos_profit_loss_V230624.md b/.claude/skills/signal-functions/references/signals/pos_profit_loss_V230624.md new file mode 100644 index 000000000..c0d883f3e --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/pos_profit_loss_V230624.md @@ -0,0 +1,25 @@ +# pos_profit_loss_V230624:盈亏比阈值判断 + +> 模块: `pos.rs` | 类别: `trader` + +## 参数模板 + +`"{pos_name}_{freq1}YKB{ykb}N{n}_盈亏比判断V230624` + +## 信号逻辑 + +- 基于开仓前 `n` 个分型确定止损价,计算 `ykb = (现价-开仓价)/(开仓价-止损价)*10`; +- `ykb > 阈值` 记 `多头达标/空头达标`; +- 未达标时若击穿止损价,记 `多头止损/空头止损`。 + +## 信号列表示例 + +- `Signal('日线通道突破_60分钟YKB20N3_盈亏比判断V230624_多头达标_任意_任意_0')` +- `Signal('日线通道突破_60分钟YKB20N3_盈亏比判断V230624_空头止损_任意_任意_0')` + +## 参数说明 + +- `pos_name`:仓位名称; +- `freq1`:K线周期; +- `ykb`:盈亏比阈值(×10),默认 `20`; +- `n`:止损分型窗口,默认 `3`。 diff --git a/.claude/skills/signal-functions/references/signals/pos_status_V230808.md b/.claude/skills/signal-functions/references/signals/pos_status_V230808.md new file mode 100644 index 000000000..797e75415 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/pos_status_V230808.md @@ -0,0 +1,22 @@ +# pos_status_V230808:持仓状态 + +> 模块: `pos.rs` | 类别: `trader` + +## 参数模板 + +`"{pos_name}_持仓状态_BS辅助V230808` + +## 信号逻辑 + +- 最近操作为 `LO` 输出 `持多`; +- 最近操作为 `SO` 输出 `持空`; +- 其余输出 `持币`。 + +## 信号列表示例 + +- `Signal('日线三买多头N1_持仓状态_BS辅助V230808_持多_任意_任意_0')` +- `Signal('日线三买多头N1_持仓状态_BS辅助V230808_持币_任意_任意_0')` + +## 参数说明 + +- `pos_name`:仓位名称。 diff --git a/.claude/skills/signal-functions/references/signals/pos_stop_V240331.md b/.claude/skills/signal-functions/references/signals/pos_stop_V240331.md new file mode 100644 index 000000000..1e6a7f348 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/pos_stop_V240331.md @@ -0,0 +1,23 @@ +# pos_stop_V240331:最近 N 根K线追踪止损 + +> 模块: `pos.rs` | 类别: `trader` + +## 参数模板 + +`"{pos_name}_{freq1}#{n}_止损V240331` + +## 信号逻辑 + +- 多头:最新K线低点跌破前 `n` 根最低价且 bar_id 晚于开仓 bar,触发 `多头止损`; +- 空头:最新K线高点突破前 `n` 根最高价且 bar_id 晚于开仓 bar,触发 `空头止损`。 + +## 信号列表示例 + +- `Signal('SMA5多头_15分钟#10_止损V240331_多头止损_任意_任意_0')` +- `Signal('SMA5空头_15分钟#10_止损V240331_空头止损_任意_任意_0')` + +## 参数说明 + +- `pos_name`:仓位名称; +- `freq1`:K线周期; +- `n`:追踪窗口,默认 `10`。 diff --git a/.claude/skills/signal-functions/references/signals/pos_stop_V240428.md b/.claude/skills/signal-functions/references/signals/pos_stop_V240428.md new file mode 100644 index 000000000..3972d9247 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/pos_stop_V240428.md @@ -0,0 +1,25 @@ +# pos_stop_V240428:按开仓前离散价位跳数止损 + +> 模块: `pos.rs` | 类别: `trader` + +## 参数模板 + +`"{pos_name}_{freq1}T{t}N{n}_止损V240428` + +## 信号逻辑 + +- 使用开仓前历史K线提取离散价位; +- 多头取低于开仓价的第 `t` 档止损位,空头取高于开仓价的第 `t` 档止损位; +- 开仓后至少持有 `n` 根后,收盘穿越止损位触发 `多头止损/空头止损`。 + +## 信号列表示例 + +- `Signal('日线三买多头N1_60分钟T20N5_止损V240428_多头止损_任意_任意_0')` +- `Signal('日线三买多头N1_60分钟T20N5_止损V240428_空头止损_任意_任意_0')` + +## 参数说明 + +- `pos_name`:仓位名称; +- `freq1`:K线周期; +- `t`:离散价位档位,默认 `20`; +- `n`:最少持有K线数量,默认 `5`。 diff --git a/.claude/skills/signal-functions/references/signals/pos_stop_V240608.md b/.claude/skills/signal-functions/references/signals/pos_stop_V240608.md new file mode 100644 index 000000000..c340a38f7 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/pos_stop_V240608.md @@ -0,0 +1,24 @@ +# pos_stop_V240608:开仓后突破开仓前窗口极值 N 档止损 + +> 模块: `pos.rs` | 类别: `trader` + +## 参数模板 + +`"{pos_name}_{freq1}W{w}N{n}_止损V240608` + +## 信号逻辑 + +- 多头:开仓后最低价低于“开仓前 `w` 根最低价下方第 `n` 档”触发 `多头止损`; +- 空头:开仓后最高价高于“开仓前 `w` 根最高价上方第 `n` 档”触发 `空头止损`。 + +## 信号列表示例 + +- `Signal('SMA5多头_15分钟W20N10_止损V240608_多头止损_任意_任意_0')` +- `Signal('SMA5空头_15分钟W20N10_止损V240608_空头止损_任意_任意_0')` + +## 参数说明 + +- `pos_name`:仓位名称; +- `freq1`:K线周期; +- `w`:开仓前观察窗口,默认 `20`; +- `n`:上下档位偏移,默认 `10`。 diff --git a/.claude/skills/signal-functions/references/signals/pos_stop_V240614.md b/.claude/skills/signal-functions/references/signals/pos_stop_V240614.md new file mode 100644 index 000000000..fc8414340 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/pos_stop_V240614.md @@ -0,0 +1,23 @@ +# pos_stop_V240614:开仓后低于/高于成本价的 K线数量计数止损 + +> 模块: `pos.rs` | 类别: `trader` + +## 参数模板 + +`"{pos_name}_{freq1}N{n}_止损V240614` + +## 信号逻辑 + +- 多头:开仓后 `low < 开仓价` 的K线数量达到 `n`,触发 `多头止损`; +- 空头:开仓后 `high > 开仓价` 的K线数量达到 `n`,触发 `空头止损`。 + +## 信号列表示例 + +- `Signal('SMA5多头_15分钟N10_止损V240614_多头止损_任意_任意_0')` +- `Signal('SMA5空头_15分钟N10_止损V240614_空头止损_任意_任意_0')` + +## 参数说明 + +- `pos_name`:仓位名称; +- `freq1`:K线周期; +- `n`:计数阈值,默认 `10`。 diff --git a/.claude/skills/signal-functions/references/signals/pos_stop_V240717.md b/.claude/skills/signal-functions/references/signals/pos_stop_V240717.md new file mode 100644 index 000000000..2315812fc --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/pos_stop_V240717.md @@ -0,0 +1,25 @@ +# pos_stop_V240717:基于开仓时 ATR 的计数止损 + +> 模块: `pos.rs` | 类别: `trader` + +## 参数模板 + +`"{pos_name}_{freq1}N{n}T{timeperiod}_止损V240717` + +## 信号逻辑 + +- 先取开仓时刻 ATR(`timeperiod`); +- 多头阈值为 `开仓价 - ATR*0.67`,空头阈值为 `开仓价 + ATR*0.67`; +- 开仓后超过阈值的K线数量达到 `n` 时触发 `多头止损/空头止损`。 + +## 信号列表示例 + +- `Signal('SMA5多头_15分钟N3T20_止损V240717_多头止损_任意_任意_0')` +- `Signal('SMA5空头_15分钟N3T20_止损V240717_空头止损_任意_任意_0')` + +## 参数说明 + +- `pos_name`:仓位名称; +- `freq1`:K线周期; +- `n`:计数阈值,默认 `10`; +- `timeperiod`:ATR 周期,默认 `20`。 diff --git a/.claude/skills/signal-functions/references/signals/pos_take_V240428.md b/.claude/skills/signal-functions/references/signals/pos_take_V240428.md new file mode 100644 index 000000000..eb1a8bea0 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/pos_take_V240428.md @@ -0,0 +1,24 @@ +# pos_take_V240428:倍量阳/阴线计数止盈 + +> 模块: `pos.rs` | 类别: `trader` + +## 参数模板 + +`"{pos_name}_{freq1}T{t}N{n}_止盈V240428` + +## 信号逻辑 + +- 多头统计开仓后“阳线且成交量 > 前一根 2 倍”的次数,达到 `t` 触发 `多头止盈`; +- 空头统计对应倍量阴线次数,达到 `t` 触发 `空头止盈`。 + +## 信号列表示例 + +- `Signal('日线三买多头N1_60分钟T3N5_止盈V240428_多头止盈_任意_任意_0')` +- `Signal('日线三买多头N1_60分钟T3N5_止盈V240428_空头止盈_任意_任意_0')` + +## 参数说明 + +- `pos_name`:仓位名称; +- `freq1`:K线周期; +- `t`:倍量K线数量阈值,默认 `3`; +- `n`:最少持有K线数量,默认 `5`。 diff --git a/.claude/skills/signal-functions/references/signals/pressure_support_V240222.md b/.claude/skills/signal-functions/references/signals/pressure_support_V240222.md new file mode 100644 index 000000000..3445c9099 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/pressure_support_V240222.md @@ -0,0 +1,28 @@ +# pressure_support_V240222:高低点验证支撑压力位 + +> 模块: `pressure.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}W{w}高低点验证_支撑压力V240222` + +## 信号逻辑 + +1. 取最近 `w` 根K线,计算区间最高/最低与振幅标准差 `gap`; +2. 若区间波动不足(`max_high-min_low < gap*0.3*w`)则返回 `其他`; +3. 若窗口两端高点贴近全局高点,判 `压力位`; +4. 若窗口两端低点贴近全局低点,判 `支撑位`。 + +## 信号列表示例 + +- `Signal('60分钟_D1W20高低点验证_支撑压力V240222_压力位_任意_任意_0')` +- `Signal('60分钟_D1W20高低点验证_支撑压力V240222_支撑位_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `w`:观察窗口大小,默认 `20`,且必须大于 `10`。 + +## 对齐说明 + +窗口切分与高低点验证条件对齐 Python `pressure_support_V240222`。 diff --git a/.claude/skills/signal-functions/references/signals/pressure_support_V240402.md b/.claude/skills/signal-functions/references/signals/pressure_support_V240402.md new file mode 100644 index 000000000..d24c14e79 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/pressure_support_V240402.md @@ -0,0 +1,28 @@ +# pressure_support_V240402:分型区间支撑压力位 + +> 模块: `pressure.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}W{w}_支撑压力V240402` + +## 信号逻辑 + +1. 统计最近 `50` 个分型中,包含当前收盘价的分型数量; +2. 若命中分型少于 `5` 或窗口波动不足(`max_high-min_low < gap*3`)返回 `其他`; +3. 当前收盘靠近窗口上沿(前20%)判 `压力位`; +4. 当前收盘靠近窗口下沿(前30%)判 `支撑位`。 + +## 信号列表示例 + +- `Signal('60分钟_D1W60_支撑压力V240402_压力位_任意_任意_0')` +- `Signal('60分钟_D1W60_支撑压力V240402_支撑位_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `w`:观察窗口大小,默认 `60`,且必须大于 `10`。 + +## 对齐说明 + +分型筛选与收盘位置阈值对齐 Python `pressure_support_V240402`。 diff --git a/.claude/skills/signal-functions/references/signals/pressure_support_V240406.md b/.claude/skills/signal-functions/references/signals/pressure_support_V240406.md new file mode 100644 index 000000000..d538bb367 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/pressure_support_V240406.md @@ -0,0 +1,28 @@ +# pressure_support_V240406:分型密集支撑压力位 + +> 模块: `pressure.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}W{w}_支撑压力V240406` + +## 信号逻辑 + +1. 统计窗口最高/最低附近的分型数量(严格落在分型区间内); +2. 若窗口波动不足(`max_high-min_low < gap*3`)返回 `其他`; +3. 若高点附近分型 `>=3` 且收盘靠近上沿,判 `压力位`; +4. 若低点附近分型 `>=3` 且收盘靠近下沿,判 `支撑位`。 + +## 信号列表示例 + +- `Signal('60分钟_D1W60_支撑压力V240406_压力位_任意_任意_0')` +- `Signal('60分钟_D1W60_支撑压力V240406_支撑位_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `w`:观察窗口大小,默认 `60`,且必须大于 `10`。 + +## 对齐说明 + +分型密集阈值和价格区间判断对齐 Python `pressure_support_V240406`。 diff --git a/.claude/skills/signal-functions/references/signals/pressure_support_V240530.md b/.claude/skills/signal-functions/references/signals/pressure_support_V240530.md new file mode 100644 index 000000000..496a4dba1 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/pressure_support_V240530.md @@ -0,0 +1,29 @@ +# pressure_support_V240530:关键重叠K线支撑压力位 + +> 模块: `pressure.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}W{w}N{n}_支撑压力V240530` + +## 信号逻辑 + +1. 在最近 `w` 根K线中寻找与其他K线重叠次数最多的关键K线; +2. 若最大重叠次数小于 `0.5*w`,返回 `其他`; +3. 以关键K线高低价在全局 `unique price` 列表上的 `±n` 档形成压力/支撑区间; +4. 收盘落入高位区间判 `压力位`,落入低位区间判 `支撑位`。 + +## 信号列表示例 + +- `Signal('60分钟_D1W20N5_支撑压力V240530_压力位_任意_任意_0')` +- `Signal('60分钟_D1W20N5_支撑压力V240530_支撑位_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `w`:观察窗口大小,默认 `20`,且必须大于 `10`; +- `n`:价格档位偏移,默认 `5`。 + +## 对齐说明 + +关键K线重叠计数和 `unique price ±n` 区间判定对齐 Python `pressure_support_V240530`。 diff --git a/.claude/skills/signal-functions/references/signals/skdj_up_dw_line_V230611.md b/.claude/skills/signal-functions/references/signals/skdj_up_dw_line_V230611.md new file mode 100644 index 000000000..3df914ff4 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/skdj_up_dw_line_V230611.md @@ -0,0 +1,30 @@ +# skdj_up_dw_line_V230611:SKDJ 随机波动信号 + +> 模块: `ang.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}M{m}UP{up}DW{dw}_SKDJ随机波动V230611` + +## 信号逻辑 + +1. 先计算 `RSV(n)` 序列; +2. 对 RSV 做两次 `m` 周期均值平滑; +3. `dw < D < K_last` 判 `看多`;`K_last < D 且 D > up` 判 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N233M89UP60DW40_SKDJ随机波动V230611_看多_任意_任意_0')` +- `Signal('60分钟_D1N233M89UP60DW40_SKDJ随机波动V230611_看空_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `n`:RSV 窗口,默认 `233`; +- `m`:平滑窗口,默认 `89`; +- `up`:超买阈值,默认 `60`; +- `dw`:超卖阈值,默认 `40`。 + +## 对齐说明 + +与 Python `skdj_up_dw_line_V230611` 的双平滑与阈值判定一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_accelerate_V230531.md b/.claude/skills/signal-functions/references/signals/tas_accelerate_V230531.md new file mode 100644 index 000000000..0d1d1f1e4 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_accelerate_V230531.md @@ -0,0 +1,28 @@ +# tas_accelerate_V230531:BOLL 通道加速信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}T{t}_BOLL加速V230531` + +## 信号逻辑 + +1. 取最近 `n` 根,计算中线/上轨/下轨涨跌幅; +2. 全部在中线上方且 `上轨涨幅 > t/10 * 中线涨幅 > 0`,判 `多头加速`; +3. 全部在中线下方且 `下轨涨幅 < t/10 * 中线涨幅 < 0`,判 `空头加速`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N20T20_BOLL加速V230531_多头加速_升破上轨_任意_0')` +- `Signal('60分钟_D1N20T20_BOLL加速V230531_空头加速_跌破下轨_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `n`:观察窗口,默认 `20`; +- `t`:轨道/中线倍率阈值(除以10),默认 `20`。 + +## 对齐说明 + +按 Python `tas_accelerate_V230531` 双分支覆盖语义实现。 diff --git a/.claude/skills/signal-functions/references/signals/tas_angle_V230802.md b/.claude/skills/signal-functions/references/signals/tas_angle_V230802.md new file mode 100644 index 000000000..bea0abfc0 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_angle_V230802.md @@ -0,0 +1,28 @@ +# tas_angle_V230802:笔角度偏离信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}T{th}_笔角度V230802` + +## 信号逻辑 + +1. 定义角度为 `power_price / length`; +2. 取同向历史 `n` 笔角度均值作为基线; +3. 当前角度低于 `th%` 时输出反向信号。 + +## 信号列表示例 + +- `Signal('60分钟_D1N9T50_笔角度V230802_空头_任意_任意_0')` +- `Signal('60分钟_D1N9T50_笔角度V230802_多头_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 笔,默认 `1`; +- `n`:同向样本数,默认 `9`; +- `th`:角度阈值百分比,默认 `50`。 + +## 对齐说明 + +`length` 口径使用 `BI.length`(无包含K数量)与 Python 对齐。 diff --git a/.claude/skills/signal-functions/references/signals/tas_atr_V230630.md b/.claude/skills/signal-functions/references/signals/tas_atr_V230630.md new file mode 100644 index 000000000..f608d654b --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_atr_V230630.md @@ -0,0 +1,27 @@ +# tas_atr_V230630:ATR 波动分层信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}ATR{timeperiod}_波动V230630` + +## 信号逻辑 + +1. 计算 `ATR / close` 波动率; +2. 对最近 100 根波动率做 `qcut(10)` 分层; +3. 输出末值所在层级 `第{n}层`。 + +## 信号列表示例 + +- `Signal('60分钟_D1ATR14_波动V230630_第3层_任意_任意_0')` +- `Signal('60分钟_D1ATR14_波动V230630_第9层_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `timeperiod`:ATR 周期,默认 `14`。 + +## 对齐说明 + +ATR 预热与分层边界对齐 Python `tas_atr_V230630`。 diff --git a/.claude/skills/signal-functions/references/signals/tas_atr_break_V230424.md b/.claude/skills/signal-functions/references/signals/tas_atr_break_V230424.md new file mode 100644 index 000000000..ab447ed97 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_atr_break_V230424.md @@ -0,0 +1,28 @@ +# tas_atr_break_V230424:ATR 通道突破 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}ATR{timeperiod}T{th}突破_BS辅助V230424` + +## 信号逻辑 + +1. 取窗口 `HH/LL` 和当前 ATR; +2. 若 `close` 落在 `HH-th*ATR` 与 `LL+th*ATR` 之间,输出 `其他`; +3. 向上突破输出 `看多`,向下突破输出 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1ATR5T30突破_BS辅助V230424_看多_任意_任意_0')` +- `Signal('60分钟_D1ATR5T30突破_BS辅助V230424_看空_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `timeperiod`:ATR 周期,默认 `5`; +- `th`:ATR 倍数(除以10),默认 `30`。 + +## 对齐说明 + +区间内返回 `其他` 的优先级与 Python `tas_atr_break_V230424` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_boll_bc_V221118.md b/.claude/skills/signal-functions/references/signals/tas_boll_bc_V221118.md new file mode 100644 index 000000000..79abc0776 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_boll_bc_V221118.md @@ -0,0 +1,30 @@ +# tas_boll_bc_V221118:BOLL背驰辅助信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}M{m}L{line}#BOLL{timeperiod}_背驰V221118` + +## 信号逻辑 + +1. 对比近端 `n` 根与参考段 `m` 根的价格极值; +2. 结合 BOLL 指定轨道 `line` 的上下突破次数; +3. 满足低点背驰给 `一买`,满足高点背驰给 `一卖`,否则 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N3M10L3#BOLL20_背驰V221118_一买_任意_任意_0')` +- `Signal('60分钟_D1N3M10L3#BOLL20_背驰V221118_一卖_任意_任意_0')` +- `Signal('60分钟_D1N3M10L3#BOLL20_背驰V221118_其他_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `n/m`:近端与参考窗口长度,默认 `3/10`; +- `line`:轨道层级,默认 `3`; +- `timeperiod`:BOLL周期,默认 `20`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_boll_cc_V230312.md b/.claude/skills/signal-functions/references/signals/tas_boll_cc_V230312.md new file mode 100644 index 000000000..42fde8d32 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_boll_cc_V230312.md @@ -0,0 +1,29 @@ +# tas_boll_cc_V230312:布林进出场信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}BOLL{timeperiod}S{nbdev}SP{sp}_BS辅助V230312` + +## 信号逻辑 + +1. 计算 BOLL 中轨与上下轨; +2. 计算 `bias = (close / mid - 1) * 10000`; +3. `close < upper 且 bias < -sp` 判 `看空`,`close > lower 且 bias > sp` 判 `看多`。 + +## 信号列表示例 + +- `Signal('60分钟_D1BOLL20S20SP400_BS辅助V230312_看空_任意_任意_0')` +- `Signal('60分钟_D1BOLL20S20SP400_BS辅助V230312_看多_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `timeperiod`:BOLL 周期,默认 `20`; +- `nbdev`:标准差倍数 *10,默认 `20`; +- `sp`:偏离阈值(BP),默认 `400`。 + +## 对齐说明 + +与 Python `tas_boll_cc_V230312` 的 bias 判定和阈值方向一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_boll_power_V221112.md b/.claude/skills/signal-functions/references/signals/tas_boll_power_V221112.md new file mode 100644 index 000000000..be23aa159 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_boll_power_V221112.md @@ -0,0 +1,27 @@ +# tas_boll_power_V221112:BOLL强弱分层信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}BOLL{timeperiod}_强弱V221112` + +## 信号逻辑 + +1. 计算 BOLL 中线与标准差; +2. 先以 `close` 相对中线判断 `多头/空头`; +3. 再按偏离程度分层 `弱势/强势/超强/极强`。 + +## 信号列表示例 + +- `Signal('60分钟_D1BOLL20_强弱V221112_多头_强势_任意_0')` +- `Signal('60分钟_D1BOLL20_强弱V221112_空头_超强_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `timeperiod`:BOLL周期,默认 `20`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_boll_vt_V230212.md b/.claude/skills/signal-functions/references/signals/tas_boll_vt_V230212.md new file mode 100644 index 000000000..716179b5b --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_boll_vt_V230212.md @@ -0,0 +1,29 @@ +# tas_boll_vt_V230212:BOLL 通道突破进出场信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}BOLL{timeperiod}S{nbdev}MO{max_overlap}_BS辅助V230212` + +## 信号逻辑 + +1. 计算指定参数的 BOLL 上下轨(`nbdev / 10` 为标准差倍数); +2. 最新收盘价在上轨上方,且窗口内曾有收盘价在上轨下方,判 `看多`; +3. 最新收盘价在下轨下方,且窗口内曾有收盘价在下轨上方,判 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1BOLL20S20MO5_BS辅助V230212_看多_任意_任意_0')` +- `Signal('60分钟_D1BOLL20S20MO5_BS辅助V230212_看空_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `timeperiod`:BOLL 周期,默认 `20`; +- `nbdev`:标准差倍数 *10,默认 `20`; +- `max_overlap`:窗口重叠长度,默认 `5`。 + +## 对齐说明 + +严格按 Python `tas_boll_vt_V230212` 判定分支实现。 diff --git a/.claude/skills/signal-functions/references/signals/tas_cci_base_V230402.md b/.claude/skills/signal-functions/references/signals/tas_cci_base_V230402.md new file mode 100644 index 000000000..554aa5de6 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_cci_base_V230402.md @@ -0,0 +1,29 @@ +# tas_cci_base_V230402:CCI 极值连续计数信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}CCI{timeperiod}#{min_count}#{max_count}_BS辅助V230402` + +## 信号逻辑 + +1. 计算 CCI 序列; +2. 若末尾连续 `CCI > 100` 次数落在 `[min_count, max_count)`,判 `多头`; +3. 若末尾连续 `CCI < -100` 次数落在 `[min_count, max_count)`,判 `空头`。 + +## 信号列表示例 + +- `Signal('60分钟_D1CCI14#3#6_BS辅助V230402_多头_任意_任意_0')` +- `Signal('60分钟_D1CCI14#3#6_BS辅助V230402_空头_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `timeperiod`:CCI 周期,默认 `14`; +- `min_count`:最小连续次数,默认 `3`; +- `max_count`:最大连续次数上界(开区间),默认 `min_count + 3`。 + +## 对齐说明 + +连续计数和覆盖顺序与 Python `tas_cci_base_V230402` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_cross_status_V230619.md b/.claude/skills/signal-functions/references/signals/tas_cross_status_V230619.md new file mode 100644 index 000000000..1fd1b96fd --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_cross_status_V230619.md @@ -0,0 +1,27 @@ +# tas_cross_status_V230619:0轴上下金死叉次数 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_金死叉V230619` + +## 信号逻辑 + +1. 取近 100 根 DIF/DEA 并截取最近过零后的有效段; +2. 在 0 轴上下分别统计金叉/死叉次数; +3. 若当根形成有效交叉,输出 `0轴上/下金叉(死叉)第N次`。 + +## 信号列表示例 + +- `Signal('60分钟_D1MACD12#26#9_金死叉V230619_0轴下金叉第1次_任意_任意_0')` +- `Signal('60分钟_D1MACD12#26#9_金死叉V230619_0轴上死叉第2次_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `fastperiod/slowperiod/signalperiod`:MACD 参数,默认 `12/26/9`。 + +## 对齐说明 + +过零截取与交叉计次逻辑对齐 Python `tas_cross_status_V230619`。 diff --git a/.claude/skills/signal-functions/references/signals/tas_cross_status_V230624.md b/.claude/skills/signal-functions/references/signals/tas_cross_status_V230624.md new file mode 100644 index 000000000..cad472264 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_cross_status_V230624.md @@ -0,0 +1,28 @@ +# tas_cross_status_V230624:指定金死叉数值 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}MD{md}_MACD交叉数量V230624` + +## 信号逻辑 + +1. 取最近 `n` 根 DIF/DEA 并过零截断; +2. 按最小间隔 `md` 过滤交叉并统计 `jc/sc`; +3. 根据当前所在零轴区域输出上下轴金叉/死叉次数。 + +## 信号列表示例 + +- `Signal('60分钟_D1N100MD1_MACD交叉数量V230624_0轴下金叉第2次_0轴下死叉第1次_任意_0')` +- `Signal('60分钟_D1N100MD1_MACD交叉数量V230624_0轴上金叉第1次_0轴上死叉第2次_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `n`:统计窗口长度,默认 `100`; +- `md`:最小交叉间隔,默认 `1`。 + +## 对齐说明 + +交叉过滤及计数口径与 Python `tas_cross_status_V230624` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_cross_status_V230625.md b/.claude/skills/signal-functions/references/signals/tas_cross_status_V230625.md new file mode 100644 index 000000000..546208b0a --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_cross_status_V230625.md @@ -0,0 +1,29 @@ +# tas_cross_status_V230625:指定金叉/死叉次数后状态 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}MD{md}J{j}S{s}_MACD交叉数量V230625` + +## 信号逻辑 + +1. 在近 `n` 根内统计过滤后的金叉/死叉数量; +2. 仅允许 `j` 或 `s` 之一生效; +3. 达到目标次数后输出 `0轴上/下第N次金叉(死叉)以后`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N100MD1J2S0_MACD交叉数量V230625_0轴下第2次金叉以后_任意_任意_0')` +- `Signal('60分钟_D1N100MD1J0S2_MACD交叉数量V230625_0轴上第2次死叉以后_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `n`:统计窗口长度,默认 `100`; +- `md`:交叉间隔阈值,默认 `1`; +- `j/s`:目标金叉或死叉次数,默认 `0/0`(二者不能同时非零)。 + +## 对齐说明 + +参数约束与触发语义对齐 Python `tas_cross_status_V230625`。 diff --git a/.claude/skills/signal-functions/references/signals/tas_dif_layer_V241010.md b/.claude/skills/signal-functions/references/signals/tas_dif_layer_V241010.md new file mode 100644 index 000000000..7e9cd17b8 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_dif_layer_V241010.md @@ -0,0 +1,27 @@ +# tas_dif_layer_V241010:DIF 三层分类 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_DIF分层W{w}T{t}_完全分类V241010` + +## 信号逻辑 + +1. 取最近 `w` 根 DIF,计算绝对值最大幅度基准 `r`; +2. `|dif_last| > r * t` 且符号为负,判 `空头远离`; +3. `|dif_last| > r * t` 且符号为正,判 `多头远离`,否则 `零轴附近`。 + +## 信号列表示例 + +- `Signal('60分钟_DIF分层W100T30_完全分类V241010_空头远离_任意_任意_0')` +- `Signal('60分钟_DIF分层W100T30_完全分类V241010_零轴附近_任意_任意_0')` + +## 参数说明 + +- `w`:观察窗口长度,默认 `100`; +- `t`:远离阈值倍率,默认 `30`。 + +## 对齐说明 + +分层阈值口径与 Python `tas_dif_layer_V241010` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_dif_zero_V240612.md b/.claude/skills/signal-functions/references/signals/tas_dif_zero_V240612.md new file mode 100644 index 000000000..f0759204f --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_dif_zero_V240612.md @@ -0,0 +1,28 @@ +# tas_dif_zero_V240612:DIF靠近零轴买卖点信号(基于最近一笔) + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_DIF靠近零轴T{t}_BS辅助V240612` + +## 信号逻辑 + +1. 取最近一笔内部原始K线的 `DIF` 序列; +2. 计算 `delta = std(diffs) * t / 100`; +3. 若最后一笔为向下笔,且末端 `DIF` 靠近零轴,同时 `max(diffs)` 显著高于均值+标准差,判 `买点`; +4. 若最后一笔为向上笔,且末端 `DIF` 靠近零轴,同时 `min(diffs)` 显著低于-(均值+标准差),判 `卖点`。 + +## 信号列表示例 + +- `Signal('60分钟_DIF靠近零轴T50_BS辅助V240612_买点_任意_任意_0')` +- `Signal('60分钟_DIF靠近零轴T50_BS辅助V240612_卖点_任意_任意_0')` +- `Signal('60分钟_DIF靠近零轴T50_BS辅助V240612_其他_任意_任意_0')` + +## 参数说明 + +- `t`:波动率倍数(除以100),默认 `50`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_dif_zero_V240614.md b/.claude/skills/signal-functions/references/signals/tas_dif_zero_V240614.md new file mode 100644 index 000000000..8cd43ceed --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_dif_zero_V240614.md @@ -0,0 +1,29 @@ +# tas_dif_zero_V240614:DIF靠近零轴买卖点信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_DIF靠近零轴W{w}T{t}_BS辅助V240614` + +## 信号逻辑 + +1. 取最近 `w` 根K线的 `DIF` 序列; +2. 计算 `delta = std(diffs) * t / 100`; +3. 若 `diffs` 全部大于0,且 `diffs[-1]` 靠近零轴,同时 `max(diffs)` 显著高于均值+标准差,判 `买点`; +4. 若 `diffs` 全部小于0,且 `diffs[-1]` 靠近零轴,同时 `min(diffs)` 显著低于-(均值+标准差),判 `卖点`。 + +## 信号列表示例 + +- `Signal('60分钟_DIF靠近零轴W20T50_BS辅助V240614_买点_任意_任意_0')` +- `Signal('60分钟_DIF靠近零轴W20T50_BS辅助V240614_卖点_任意_任意_0')` +- `Signal('60分钟_DIF靠近零轴W20T50_BS辅助V240614_其他_任意_任意_0')` + +## 参数说明 + +- `w`:K线窗口长度,默认 `20`; +- `t`:波动率倍数(除以100),默认 `50`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_dma_bs_V240608.md b/.claude/skills/signal-functions/references/signals/tas_dma_bs_V240608.md new file mode 100644 index 000000000..03f50c232 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_dma_bs_V240608.md @@ -0,0 +1,28 @@ +# tas_dma_bs_V240608:双均线顺势回调买卖点 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_N{n}双均线{t1}#{t2}顺势_BS辅助V240608` + +## 信号逻辑 + +1. 以 `t1/t2` 均线顺势方向做过滤; +2. 在 `ma2` 附近按离散价格序号选取回踩/反抽位; +3. 满足穿越与收盘条件时给出 `买点/卖点`。 + +## 信号列表示例 + +- `Signal('60分钟_N5双均线5#10顺势_BS辅助V240608_买点_任意_任意_0')` +- `Signal('60分钟_N5双均线5#10顺势_BS辅助V240608_卖点_任意_任意_0')` + +## 参数说明 + +- `n`:价格序号偏移,默认 `5`; +- `t1`:快线周期,默认 `5`; +- `t2`:慢线周期,默认 `10`。 + +## 对齐说明 + +价格位选取(含负索引语义)与 Python `tas_dma_bs_V240608` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_double_ma_V221203.md b/.claude/skills/signal-functions/references/signals/tas_double_ma_V221203.md new file mode 100644 index 000000000..a18a520d8 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_double_ma_V221203.md @@ -0,0 +1,29 @@ +# tas_double_ma_V221203:双均线多空强弱信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}T{th}#{ma_type}#{timeperiod1}#{timeperiod2}_JX辅助V221203` + +## 信号逻辑 + +1. 计算两条均线 `ma1/ma2`; +2. `ma1 >= ma2` 判定 `多头`,否则 `空头`; +3. 两线相对距离(BP)超过 `th` 判 `强势`,否则 `弱势`。 + +## 信号列表示例 + +- `Signal('60分钟_D1T100#SMA#5#10_JX辅助V221203_多头_强势_任意_0')` +- `Signal('60分钟_D1T80#EMA#12#26_JX辅助V221203_空头_弱势_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `th`:强弱阈值(BP),默认 `100`; +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod1/timeperiod2`:两条均线周期,默认 `5/10`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_double_ma_V230511.md b/.claude/skills/signal-functions/references/signals/tas_double_ma_V230511.md new file mode 100644 index 000000000..e52a51d88 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_double_ma_V230511.md @@ -0,0 +1,31 @@ +# tas_double_ma_V230511:双均线反向信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}#{ma_type}#{t1}#{t2}_BS辅助V230511` + +## 信号逻辑 + +1. 计算 `t1/t2` 双均线,`t1 < t2`; +2. 当前K线需为大实体(`solid >= max(upper, lower, mean_solid)`); +3. `ma1 > ma2` 且当前大实体阴线,判 `看多`; +4. `ma1 < ma2` 且当前大实体阳线,判 `看空`; +5. 在同侧连续区间内若仅出现一次对应大实体且区间长度 `< t2/2`,则 `v2=第一个`。 + +## 信号列表示例 + +- `Signal('60分钟_D1#SMA#5#20_BS辅助V230511_看多_第一个_任意_0')` +- `Signal('60分钟_D1#SMA#5#20_BS辅助V230511_看空_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `t1`:快线周期,默认 `5`; +- `t2`:慢线周期,默认 `20`; +- `ma_type`:均线类型,默认 `SMA`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_double_ma_V240208.md b/.claude/skills/signal-functions/references/signals/tas_double_ma_V240208.md new file mode 100644 index 000000000..3e4229011 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_double_ma_V240208.md @@ -0,0 +1,28 @@ +# tas_double_ma_V240208:双均线交叉结构信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{N}M{M}双均线_BS辅助V240208` + +## 信号逻辑 + +1. 计算 `N/M` 双均线并识别交叉序列; +2. 最近三次交叉记作 `X1/X2/X3`; +3. `X3` 金叉且 `X2` 快线最高判 `多头`,死叉镜像判 `空头`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N20M60双均线_BS辅助V240208_多头_任意_任意_0')` +- `Signal('60分钟_D1N20M60双均线_BS辅助V240208_空头_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `N`:快线周期,默认 `20`; +- `M`:慢线周期,默认 `60`。 + +## 对齐说明 + +交叉类型与快线比较逻辑对齐 Python `tas_double_ma_V240208`。 diff --git a/.claude/skills/signal-functions/references/signals/tas_first_bs_V230217.md b/.claude/skills/signal-functions/references/signals/tas_first_bs_V230217.md new file mode 100644 index 000000000..4ed809cb1 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_first_bs_V230217.md @@ -0,0 +1,34 @@ +# tas_first_bs_V230217:均线结合K线形态的一买一卖辅助 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}#{ma_type}#{timeperiod}_BS1辅助V230217` + +## 信号逻辑 + +1. 在最近 `n` 根K线上计算均线并构造 `sma/low/high/open/close` 序列; +2. 一买条件: +- `sma > low` 全满足; +- 阴线占比 `> 60%`; +- 最近3根出现新低; +- 最后一根收盘在均线上方; +3. 一卖条件与上面对称; +4. 满足则输出 `一买/一卖`,否则 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N10#SMA#5_BS1辅助V230217_一买_任意_任意_0')` +- `Signal('60分钟_D1N10#SMA#5_BS1辅助V230217_一卖_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `n`:窗口大小,默认 `10`; +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod`:均线周期,默认 `5`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_hlma_V230301.md b/.claude/skills/signal-functions/references/signals/tas_hlma_V230301.md new file mode 100644 index 000000000..b30952ec5 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_hlma_V230301.md @@ -0,0 +1,29 @@ +# tas_hlma_V230301:HMA/LMA 多空信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}#{ma_type}#{timeperiod}HLMA_BS辅助V230301` + +## 信号逻辑 + +1. 取最近 `timeperiod` 根K线,计算 `hma=high均值`、`lma=low均值`; +2. 若 `close_now > hma` 且 `close_prev <= ma_prev`,判 `看多`; +3. 若 `close_now < lma` 且 `close_prev >= ma_prev`,判 `看空`; +4. 否则判 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1#SMA#3HLMA_BS辅助V230301_看多_任意_任意_0')` +- `Signal('60分钟_D1#SMA#3HLMA_BS辅助V230301_看空_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod`:窗口周期,默认 `3`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_kdj_base_V221101.md b/.claude/skills/signal-functions/references/signals/tas_kdj_base_V221101.md new file mode 100644 index 000000000..ca7738c33 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_kdj_base_V221101.md @@ -0,0 +1,28 @@ +# tas_kdj_base_V221101:KDJ基础辅助信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}K#KDJ{fastk_period}#{slowk_period}#{slowd_period}_KDJ辅助V221101` + +## 信号逻辑 + +1. 计算 K、D、J 三序列; +2. `J > K > D` 判定 `多头`,`J < K < D` 判定 `空头`,否则 `其他`; +3. `J_now >= J_prev` 判定 `向上`,否则 `向下`。 + +## 信号列表示例 + +- `Signal('60分钟_D1K#KDJ9#3#3_KDJ辅助V221101_多头_向上_任意_0')` +- `Signal('60分钟_D1K#KDJ9#3#3_KDJ辅助V221101_空头_向下_任意_0')` +- `Signal('60分钟_D1K#KDJ9#3#3_KDJ辅助V221101_其他_向下_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `fastk_period/slowk_period/slowd_period`:KDJ参数,默认 `9/3/3`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_kdj_evc_V221201.md b/.claude/skills/signal-functions/references/signals/tas_kdj_evc_V221201.md new file mode 100644 index 000000000..8e702c80b --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_kdj_evc_V221201.md @@ -0,0 +1,29 @@ +# tas_kdj_evc_V221201:KDJ 极值计数信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}T{th}KDJ{fastk_period}#{slowk_period}#{slowd_period}#{key}值突破{c1}#{c2}_KDJ极值V221201` + +## 信号逻辑 + +1. 计算 `K/D/J` 序列并提取 `key`; +2. 统计末端连续低于 `th` 或高于 `100-th` 的次数; +3. 连续次数落入 `[c1, c2)` 时分别输出 `多头/空头`,并在 `v2` 标注计数。 + +## 信号列表示例 + +- `Signal('60分钟_D1T10KDJ9#3#3#K值突破5#8_KDJ极值V221201_多头_C5_任意_0')` +- `Signal('60分钟_D1T10KDJ9#3#3#K值突破5#8_KDJ极值V221201_空头_C6_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `key`:取值 `K/D/J`,默认 `K`; +- `th`:极值阈值,默认 `10`; +- `count_range`:连续计数区间,默认 `[5, 8]`。 + +## 对齐说明 + +连续计数、`v2=Cx` 标注方式与 Python `tas_kdj_evc_V221201` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_kdj_evc_V230401.md b/.claude/skills/signal-functions/references/signals/tas_kdj_evc_V230401.md new file mode 100644 index 000000000..1e5ad334d --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_kdj_evc_V230401.md @@ -0,0 +1,29 @@ +# tas_kdj_evc_V230401:KDJ 极值计数信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}T{th}KDJ{fastk_period}#{slowk_period}#{slowd_period}#{key}值突破{min_count}#{max_count}_BS辅助V230401` + +## 信号逻辑 + +1. 计算 `K/D/J` 指标并提取目标序列; +2. 末端连续低于阈值记多头计数,连续高于阈值记空头计数; +3. 连续次数在 `[min_count, max_count)` 时输出 `多头/空头`。 + +## 信号列表示例 + +- `Signal('60分钟_D1T10KDJ9#3#3#K值突破5#8_BS辅助V230401_多头_任意_任意_0')` +- `Signal('60分钟_D1T10KDJ9#3#3#K值突破5#8_BS辅助V230401_空头_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `key`:`K/D/J`,默认 `K`; +- `th`:极值阈值,默认 `10`; +- `min_count/max_count`:连续计数区间,默认 `5/8`。 + +## 对齐说明 + +参数校验与计数边界严格对齐 Python `tas_kdj_evc_V230401`。 diff --git a/.claude/skills/signal-functions/references/signals/tas_low_trend_V230627.md b/.claude/skills/signal-functions/references/signals/tas_low_trend_V230627.md new file mode 100644 index 000000000..1a3a7dfbb --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_low_trend_V230627.md @@ -0,0 +1,28 @@ +# tas_low_trend_V230627:阴跌/小阳趋势信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}TH{th}_趋势230627` + +## 信号逻辑 + +1. 对窗口内实体振幅做阈值过滤,剔除波动过大的场景; +2. 统计 `low <= 历史收盘最小值` 次数,超过 `0.8*n` 判 `阴跌趋势`; +3. 统计 `high >= 历史收盘最大值` 次数,超过 `0.8*n` 判 `小阳趋势`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N13TH300_趋势230627_阴跌趋势_任意_任意_0')` +- `Signal('60分钟_D1N13TH300_趋势230627_小阳趋势_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `n`:统计窗口,默认 `13`; +- `th`:实体振幅阈值(BP),默认 `300`。 + +## 对齐说明 + +循环窗口与阈值比较口径对齐 Python `tas_low_trend_V230627`。 diff --git a/.claude/skills/signal-functions/references/signals/tas_ma_base_V221101.md b/.claude/skills/signal-functions/references/signals/tas_ma_base_V221101.md new file mode 100644 index 000000000..8ce06d8f2 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_ma_base_V221101.md @@ -0,0 +1,28 @@ +# tas_ma_base_V221101:单均线多空与方向信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}{ma_type}#{timeperiod}_分类V221101` + +## 信号逻辑 + +1. 计算指定均线(`SMA/EMA`); +2. `close >= ma` 判定 `多头`,否则 `空头`; +3. `ma_now >= ma_prev` 判定 `向上`,否则 `向下`。 + +## 信号列表示例 + +- `Signal('60分钟_D1SMA#5_分类V221101_多头_向上_任意_0')` +- `Signal('60分钟_D1EMA#12_分类V221101_空头_向下_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod`:均线周期,默认 `5`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_ma_base_V221203.md b/.claude/skills/signal-functions/references/signals/tas_ma_base_V221203.md new file mode 100644 index 000000000..6f8d9a33a --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_ma_base_V221203.md @@ -0,0 +1,30 @@ +# tas_ma_base_V221203:单均线多空与距离分层信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}{ma_type}#{timeperiod}T{th}_分类V221203` + +## 信号逻辑 + +1. 计算指定均线(`SMA/EMA`); +2. `close >= ma` 判定 `多头`,否则 `空头`; +3. `ma_now >= ma_prev` 判定 `向上`,否则 `向下`; +4. `abs(close-ma)/ma * 10000 > th` 判定 `远离`,否则 `靠近`。 + +## 信号列表示例 + +- `Signal('60分钟_D1SMA#5T100_分类V221203_多头_向上_靠近_0')` +- `Signal('60分钟_D1EMA#12T80_分类V221203_空头_向下_远离_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod`:均线周期,默认 `5`; +- `th`:距离阈值(BP),默认 `100`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_ma_base_V230313.md b/.claude/skills/signal-functions/references/signals/tas_ma_base_V230313.md new file mode 100644 index 000000000..ca0e640d4 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_ma_base_V230313.md @@ -0,0 +1,32 @@ +# tas_ma_base_V230313:单均线开平仓辅助信号(带重叠约束) + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}#{ma_type}#{timeperiod}MO{max_overlap}_BS辅助V230313` + +## 信号逻辑 + +1. 计算指定均线(`SMA/EMA`); +2. 取倒数 `di` 截止的 `max_overlap+1` 根K线; +3. 若最新 `close >= ma` 且窗口内并非全部 `close > ma`,判 `看多`; +4. 若最新 `close < ma` 且窗口内并非全部 `close < ma`,判 `看空`; +5. 否则判 `其他`;并用 `ma_now >= ma_prev` 判方向 `向上/向下`。 + +## 信号列表示例 + +- `Signal('60分钟_D1#SMA#5MO5_BS辅助V230313_看多_向上_任意_0')` +- `Signal('60分钟_D1#EMA#12MO5_BS辅助V230313_看空_向下_任意_0')` +- `Signal('60分钟_D1#SMA#5MO5_BS辅助V230313_其他_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod`:均线周期,默认 `5`; +- `max_overlap`:相同方向最大重叠窗口,默认 `5`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_ma_cohere_V230512.md b/.claude/skills/signal-functions/references/signals/tas_ma_cohere_V230512.md new file mode 100644 index 000000000..e505087b5 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_ma_cohere_V230512.md @@ -0,0 +1,28 @@ +# tas_ma_cohere_V230512:均线系统粘合/扩散状态 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}SMA{ma_seq}_均线系统V230512` + +## 信号逻辑 + +1. 计算 `ma_seq` 各条 SMA,并构造最近 100 根“均线最大值/最小值 - 1”序列; +2. 用前 80 根计算标准差 `ret_std`; +3. 最近 20 根中,`ret < 0.5 * ret_std` 达到 16 次判 `粘合`; +4. 最近 20 根中,`ret > 1.0 * ret_std` 达到 16 次判 `扩散`(覆盖前者)。 + +## 信号列表示例 + +- `Signal('60分钟_D1SMA5#13#21#34#55_均线系统V230512_粘合_任意_任意_0')` +- `Signal('60分钟_D1SMA5#13#21#34#55_均线系统V230512_扩散_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `ma_seq`:均线周期序列(`#` 分隔),默认 `5#13#21#34#55`。 + +## 对齐说明 + +阈值与覆盖顺序对齐 Python `tas_ma_cohere_V230512`。 diff --git a/.claude/skills/signal-functions/references/signals/tas_ma_round_V221206.md b/.claude/skills/signal-functions/references/signals/tas_ma_round_V221206.md new file mode 100644 index 000000000..322b8d0df --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_ma_round_V221206.md @@ -0,0 +1,31 @@ +# tas_ma_round_V221206:笔端点触碰均线信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}TH{th}#碰{ma_type}#{timeperiod}_BE辅助V221206` + +## 信号逻辑 + +1. 计算指定均线(`SMA/EMA`); +2. 取倒数第 `di` 笔,提取其结束分型中间 NewBar 的原始K线; +3. 计算该批原始K线对应均线均值 `last_ma`; +4. 若上笔且 `abs(high-last_ma)/power_price < th/100`,判 `上碰`; +5. 若下笔且 `abs(low-last_ma)/power_price < th/100`,判 `下碰`;否则 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1TH10#碰SMA#60_BE辅助V221206_上碰_任意_任意_0')` +- `Signal('60分钟_D1TH10#碰SMA#60_BE辅助V221206_下碰_任意_任意_0')` + +## 参数说明 + +- `di`:指定倒数第 `di` 笔,默认 `1`; +- `th`:端点触碰阈值(百分比),默认 `10`; +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod`:均线周期,默认 `5`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_ma_system_V230513.md b/.claude/skills/signal-functions/references/signals/tas_ma_system_V230513.md new file mode 100644 index 000000000..2fd419c11 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_ma_system_V230513.md @@ -0,0 +1,27 @@ +# tas_ma_system_V230513:均线系统多空排列 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}SMA{ma_seq}_均线系统V230513` + +## 信号逻辑 + +1. 计算 `ma_seq` 中各周期 SMA; +2. 当前值严格递减判 `多头排列`; +3. 当前值严格递增判 `空头排列`,否则 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1SMA5#10#20_均线系统V230513_多头排列_任意_任意_0')` +- `Signal('60分钟_D1SMA5#10#20_均线系统V230513_空头排列_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `ma_seq`:均线周期串,默认 `5#10#20`。 + +## 对齐说明 + +排列判定方向与 Python `tas_ma_system_V230513` 完全一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_base_V221028.md b/.claude/skills/signal-functions/references/signals/tas_macd_base_V221028.md new file mode 100644 index 000000000..513c30567 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_base_V221028.md @@ -0,0 +1,29 @@ +# tas_macd_base_V221028:MACD/DIF/DEA 多空与方向信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}#{key}_BS辅助V221028` + +## 信号逻辑 + +1. 计算 MACD 三序列; +2. 依据 `key` 选择 `MACD/DIF/DEA`; +3. 当前值 `>=0` 判定 `多头`,否则 `空头`; +4. 当前值 `>=` 前值判定 `向上`,否则 `向下`。 + +## 信号列表示例 + +- `Signal('60分钟_D1MACD12#26#9#MACD_BS辅助V221028_多头_向上_任意_0')` +- `Signal('60分钟_D1MACD12#26#9#DIF_BS辅助V221028_空头_向下_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `fastperiod/slowperiod/signalperiod`:MACD参数,默认 `12/26/9`; +- `key`:`MACD`、`DIF` 或 `DEA`,默认 `MACD`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_base_V230320.md b/.claude/skills/signal-functions/references/signals/tas_macd_base_V230320.md new file mode 100644 index 000000000..45c5c6369 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_base_V230320.md @@ -0,0 +1,31 @@ +# tas_macd_base_V230320:MACD/DIF/DEA 多空与方向信号(含重叠约束) + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}MO{max_overlap}#{key}_BS辅助V230320` + +## 信号逻辑 + +1. 计算 `MACD/DIF/DEA` 序列; +2. 取倒数 `di` 截止的最近 `max_overlap+1` 根值; +3. 若 `last > 0` 且前序存在 `< 0` 判 `多头`; +4. 若 `last < 0` 且前序存在 `> 0` 判 `空头`; +5. 否则判 `其他`;方向由 `last >= prev` 判 `向上/向下`。 + +## 信号列表示例 + +- `Signal('60分钟_D1MACD12#26#9MO3#MACD_BS辅助V230320_多头_向上_任意_0')` +- `Signal('60分钟_D1MACD12#26#9MO3#DIF_BS辅助V230320_空头_向下_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `key`:指标键,`MACD/DIF/DEA`,默认 `MACD`; +- `fastperiod/slowperiod/signalperiod`:MACD参数,默认 `12/26/9`; +- `max_overlap`:最大重叠窗口,默认 `3`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_bc_V221201.md b/.claude/skills/signal-functions/references/signals/tas_macd_bc_V221201.md new file mode 100644 index 000000000..651660aac --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_bc_V221201.md @@ -0,0 +1,30 @@ +# tas_macd_bc_V221201:MACD背驰辅助信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}M{m}#MACD{fastperiod}#{slowperiod}#{signalperiod}_BCV221201` + +## 信号逻辑 + +1. 取最近 `m+n` 根K线,前 `m` 为对照窗口,后 `n` 为近端窗口; +2. 若近端价格创新低且MACD低点抬高,判 `底部` 背驰; +3. 若近端价格创新高且MACD高点走低,判 `顶部` 背驰; +4. 并给出当前柱体颜色 `红柱/绿柱`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N3M50#MACD12#26#9_BCV221201_底部_绿柱_任意_0')` +- `Signal('60分钟_D1N3M50#MACD12#26#9_BCV221201_顶部_红柱_任意_0')` +- `Signal('60分钟_D1N3M50#MACD12#26#9_BCV221201_其他_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `n/m`:近端窗口与对照窗口长度,默认 `3/50`; +- `fastperiod/slowperiod/signalperiod`:MACD参数,默认 `12/26/9`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_bc_V230803.md b/.claude/skills/signal-functions/references/signals/tas_macd_bc_V230803.md new file mode 100644 index 000000000..ff9fb3163 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_bc_V230803.md @@ -0,0 +1,26 @@ +# tas_macd_bc_V230803:双分型 MACD 背驰信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_MACD双分型背驰_BS辅助V230803` + +## 信号逻辑 + +1. 提取最近分型列表中的同类顶/底分型; +2. 对比两个分型中间K线的 MACD 柱值; +3. 向上笔出现 `macd1 > macd2 > 0` 判 `空头`,向下笔镜像判 `多头`。 + +## 信号列表示例 + +- `Signal('60分钟_MACD双分型背驰_BS辅助V230803_空头_任意_任意_0')` +- `Signal('60分钟_MACD双分型背驰_BS辅助V230803_多头_任意_任意_0')` + +## 参数说明 + +- 无额外参数。 + +## 对齐说明 + +分型来源改为 `get_fx_list()`,与 Python `c.fx_list` 语义对齐。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_bc_V230804.md b/.claude/skills/signal-functions/references/signals/tas_macd_bc_V230804.md new file mode 100644 index 000000000..784b1d9af --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_bc_V230804.md @@ -0,0 +1,26 @@ +# tas_macd_bc_V230804:MACD 黄白线背驰信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}MACD背驰_BS辅助V230804` + +## 信号逻辑 + +1. 取最近 7 笔,并在末 5 笔构建中枢; +2. 上笔场景:末笔位于高位区且 DIF 峰值弱于前两上笔,判 `空头`; +3. 下笔场景镜像:末笔位于低位区且 DIF 谷值抬升,判 `多头`。 + +## 信号列表示例 + +- `Signal('60分钟_D1MACD背驰_BS辅助V230804_空头_任意_任意_0')` +- `Signal('60分钟_D1MACD背驰_BS辅助V230804_多头_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 笔,默认 `1`。 + +## 对齐说明 + +中枢有效性与 DIF 对比口径与 Python `tas_macd_bc_V230804` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_bc_V240307.md b/.claude/skills/signal-functions/references/signals/tas_macd_bc_V240307.md new file mode 100644 index 000000000..d1a513b3a --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_bc_V240307.md @@ -0,0 +1,27 @@ +# tas_macd_bc_V240307:MACD 柱背驰计次信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}柱子背驰_BS辅助V240307` + +## 信号逻辑 + +1. 在窗口内识别 MACD 柱局部顶/底; +2. 顶部减弱并满足间隔条件判 `顶背驰`,底部抬高镜像判 `底背驰`; +3. 输出距离最近顶/底的计次数 `第k次`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N20柱子背驰_BS辅助V240307_顶背驰_第2次_任意_0')` +- `Signal('60分钟_D1N20柱子背驰_BS辅助V240307_底背驰_第1次_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `n`:观察窗口,默认 `20`。 + +## 对齐说明 + +峰谷识别、间隔阈值与统计口径对齐 Python `tas_macd_bc_V240307`。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_bc_ubi_V230804.md b/.claude/skills/signal-functions/references/signals/tas_macd_bc_ubi_V230804.md new file mode 100644 index 000000000..9542a742b --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_bc_ubi_V230804.md @@ -0,0 +1,26 @@ +# tas_macd_bc_ubi_V230804:未完成笔 MACD 背驰观察 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_MACD背驰_UBI观察V230804` + +## 信号逻辑 + +1. 使用未完成笔(UBI)方向与极值位置; +2. 在最近 6 笔中构造中枢并比较 UBI 末段 DIF 与历史对应笔 DIF; +3. 上行 UBI DIF 走弱判 `空头`,下行 UBI DIF 抬升判 `多头`。 + +## 信号列表示例 + +- `Signal('60分钟_MACD背驰_UBI观察V230804_空头_任意_任意_0')` +- `Signal('60分钟_MACD背驰_UBI观察V230804_多头_任意_任意_0')` + +## 参数说明 + +- 无额外参数。 + +## 对齐说明 + +UBI 原始K线口径与 Python `tas_macd_bc_ubi_V230804` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_bs1_V230312.md b/.claude/skills/signal-functions/references/signals/tas_macd_bs1_V230312.md new file mode 100644 index 000000000..192b57576 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_bs1_V230312.md @@ -0,0 +1,26 @@ +# tas_macd_bs1_V230312:MACD 辅助一买一卖(笔结构) + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS1辅助V230312` + +## 信号逻辑 + +1. 最近 7 笔内,末笔创新低并满足三卖结构且末分型 MACD 抬升,判 `看多`; +2. 镜像条件(创新高 + 三买结构 + MACD 走弱)判 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1MACD12#26#9_BS1辅助V230312_看多_任意_任意_0')` +- `Signal('60分钟_D1MACD12#26#9_BS1辅助V230312_看空_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 笔,默认 `1`; +- `fastperiod/slowperiod/signalperiod`:MACD 参数,默认 `12/26/9`。 + +## 对齐说明 + +笔结构约束与末分型 MACD 比较逻辑对齐 Python `tas_macd_bs1_V230312`。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_bs1_V230313.md b/.claude/skills/signal-functions/references/signals/tas_macd_bs1_V230313.md new file mode 100644 index 000000000..fb4e3493e --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_bs1_V230313.md @@ -0,0 +1,27 @@ +# tas_macd_bs1_V230313:MACD 红绿柱第一买卖点 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS1辅助V230313` + +## 信号逻辑 + +1. 近 10 与前 90 根对比新高新低; +2. 用交叉面积递减/递增与 MACD 方向判 `一买/一卖`; +3. `v2` 返回最后交叉类型。 + +## 信号列表示例 + +- `Signal('60分钟_D1MACD12#26#9_BS1辅助V230313_一买_死叉_任意_0')` +- `Signal('60分钟_D1MACD12#26#9_BS1辅助V230313_一卖_金叉_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `fastperiod/slowperiod/signalperiod`:MACD 参数,默认 `12/26/9`。 + +## 对齐说明 + +面积比较与条件优先级(`and/or`)按 Python `tas_macd_bs1_V230313` 对齐。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_bs1_V230411.md b/.claude/skills/signal-functions/references/signals/tas_macd_bs1_V230411.md new file mode 100644 index 000000000..1f8bbd0c3 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_bs1_V230411.md @@ -0,0 +1,29 @@ +# tas_macd_bs1_V230411:MACD DIF 五笔背驰信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}T{tha}#{thb}#{thc}_BS1辅助V230411` + +## 信号逻辑 + +1. 取最近 5 笔并要求当前未完成笔长度约束; +2. 上笔场景:涨幅、DIF 结构、末笔涨幅与 DIF 衰减同时满足,判定 `顶背驰`; +3. 下笔场景镜像:跌幅、DIF 结构与 DIF 回升满足,判定 `底背驰`。 + +## 信号列表示例 + +- `Signal('60分钟_D1T30#5#30_BS1辅助V230411_顶背驰_任意_任意_0')` +- `Signal('60分钟_D1T30#5#30_BS1辅助V230411_底背驰_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 笔,默认 `1`; +- `tha`:前三笔累计涨跌阈值(BP),默认 `30`; +- `thb`:第5笔相对第3笔价格阈值(BP),默认 `5`; +- `thc`:第5笔相对第3笔 DIF 变化阈值(BP),默认 `30`。 + +## 对齐说明 + +五笔条件组合与 Python `tas_macd_bs1_V230411` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_bs1_V230412.md b/.claude/skills/signal-functions/references/signals/tas_macd_bs1_V230412.md new file mode 100644 index 000000000..821deefad --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_bs1_V230412.md @@ -0,0 +1,28 @@ +# tas_macd_bs1_V230412:MACD DIF 五笔背驰简化信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}T{tha}#{thb}_BS1辅助V230412` + +## 信号逻辑 + +1. 取最近 5 笔并校验未完成笔长度; +2. 上笔场景:前三笔涨幅过阈值,且 `DIF(3)` 为局部最大,末笔价格不弱,判 `顶背驰`; +3. 下笔场景镜像:`DIF(3)` 为局部最小,判 `底背驰`。 + +## 信号列表示例 + +- `Signal('60分钟_D1T100#10_BS1辅助V230412_顶背驰_任意_任意_0')` +- `Signal('60分钟_D1T100#10_BS1辅助V230412_底背驰_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 笔,默认 `1`; +- `tha`:前三笔累计涨跌阈值(BP),默认 `100`; +- `thb`:第5笔相对第3笔价格阈值(BP),默认 `10`。 + +## 对齐说明 + +条件组合与 Python `tas_macd_bs1_V230412` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_change_V221105.md b/.claude/skills/signal-functions/references/signals/tas_macd_change_V221105.md new file mode 100644 index 000000000..7a6e5bc46 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_change_V221105.md @@ -0,0 +1,29 @@ +# tas_macd_change_V221105:MACD变色次数信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}K{n}#MACD{fastperiod}#{slowperiod}#{signalperiod}变色次数_BS辅助V221105` + +## 信号逻辑 + +1. 在最近 `n` 根上计算 DIF/DEA 金叉死叉序列; +2. 过滤 `距离<2` 的抖动交叉; +3. 同类型连续交叉按 Python 语义合并; +4. 输出合并后次数 `"{num}次"`。 + +## 信号列表示例 + +- `Signal('60分钟_D1K55#MACD12#26#9变色次数_BS辅助V221105_0次_任意_任意_0')` +- `Signal('60分钟_D1K55#MACD12#26#9变色次数_BS辅助V221105_3次_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `n`:统计窗口长度,默认 `55`; +- `fastperiod/slowperiod/signalperiod`:MACD参数,默认 `12/26/9`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_direct_V221106.md b/.claude/skills/signal-functions/references/signals/tas_macd_direct_V221106.md new file mode 100644 index 000000000..8a233f8fa --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_direct_V221106.md @@ -0,0 +1,28 @@ +# tas_macd_direct_V221106:MACD柱方向信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}K#MACD{fastperiod}#{slowperiod}#{signalperiod}方向_BS辅助V221106` + +## 信号逻辑 + +1. 计算 MACD 柱序列; +2. 取倒数 `di` 对齐的最近 3 根柱值; +3. 严格递增判定 `向上`,严格递减判定 `向下`,否则 `模糊`。 + +## 信号列表示例 + +- `Signal('60分钟_D1K#MACD12#26#9方向_BS辅助V221106_向上_任意_任意_0')` +- `Signal('60分钟_D1K#MACD12#26#9方向_BS辅助V221106_向下_任意_任意_0')` +- `Signal('60分钟_D1K#MACD12#26#9方向_BS辅助V221106_模糊_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `fastperiod/slowperiod/signalperiod`:MACD参数,默认 `12/26/9`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_dist_V230408.md b/.claude/skills/signal-functions/references/signals/tas_macd_dist_V230408.md new file mode 100644 index 000000000..8da2850a3 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_dist_V230408.md @@ -0,0 +1,28 @@ +# tas_macd_dist_V230408:DIF/DEA/MACD等宽分层信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_{key}分层W{w}N{n}_BS辅助V230408` + +## 信号逻辑 + +1. 获取最近 `w` 根K线的 `DIF/DEA/MACD` 序列; +2. 按等宽区间切分为 `n` 层; +3. 返回最后一个值所在层级 `第{q}层`。 + +## 信号列表示例 + +- `Signal('60分钟_DIF分层W100N10_BS辅助V230408_第3层_任意_任意_0')` +- `Signal('60分钟_MACD分层W100N10_BS辅助V230408_第8层_任意_任意_0')` + +## 参数说明 + +- `key`:`DIF/DEA/MACD`,默认 `DIF`; +- `w`:窗口长度,默认 `100`; +- `n`:分层数量,默认 `10`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_dist_V230409.md b/.claude/skills/signal-functions/references/signals/tas_macd_dist_V230409.md new file mode 100644 index 000000000..6aebd4ea6 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_dist_V230409.md @@ -0,0 +1,29 @@ +# tas_macd_dist_V230409:DIF/DEA/MACD远离零轴信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_{key}远离W{w}N{n}T{t}_BS辅助V230409` + +## 信号逻辑 + +1. 获取最近 `w` 根K线指标值并计算绝对值均值; +2. 若最近 `n` 根中绝对值最大者超过 `mean * t/10`,判定远离零轴; +3. 按最后一个值符号输出 `多头远离/空头远离`。 + +## 信号列表示例 + +- `Signal('60分钟_DIF远离W100N10T20_BS辅助V230409_多头远离_任意_任意_0')` +- `Signal('60分钟_DIF远离W100N10T20_BS辅助V230409_空头远离_任意_任意_0')` + +## 参数说明 + +- `key`:`DIF/DEA/MACD`,默认 `DIF`; +- `w`:窗口长度,默认 `100`; +- `n`:最近判定窗口,默认 `10`; +- `t`:远离阈值倍率(除以10),默认 `20`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_dist_V230410.md b/.claude/skills/signal-functions/references/signals/tas_macd_dist_V230410.md new file mode 100644 index 000000000..fefcf25fa --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_dist_V230410.md @@ -0,0 +1,28 @@ +# tas_macd_dist_V230410:DIF/DEA/MACD多空分层信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_{key}多空分层W{w}N{n}_BS辅助V230410` + +## 信号逻辑 + +1. 取最近 `w` 根指标序列,按最后一值符号判 `多头/空头`; +2. 仅保留同符号样本并等宽分层为 `n` 层; +3. 输出 `多头/空头` 与 `第{q}层`。 + +## 信号列表示例 + +- `Signal('60分钟_DIF多空分层W200N5_BS辅助V230410_多头_第2层_任意_0')` +- `Signal('60分钟_DIF多空分层W200N5_BS辅助V230410_空头_第4层_任意_0')` + +## 参数说明 + +- `key`:`DIF/DEA/MACD`,默认 `DIF`; +- `w`:窗口长度,默认 `200`; +- `n`:分层数量,默认 `5`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_first_bs_V221201.md b/.claude/skills/signal-functions/references/signals/tas_macd_first_bs_V221201.md new file mode 100644 index 000000000..2eda4dc46 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_first_bs_V221201.md @@ -0,0 +1,28 @@ +# tas_macd_first_bs_V221201:MACD一买一卖辅助信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS1辅助V221201` + +## 信号逻辑 + +1. 在近 300 根内统计 DIF/DEA 金叉死叉序列; +2. 满足特定零轴位置与节奏条件时,给出 `一买` 或 `一卖`; +3. 否则返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1MACD12#26#9_BS1辅助V221201_一买_任意_任意_0')` +- `Signal('60分钟_D1MACD12#26#9_BS1辅助V221201_一卖_任意_任意_0')` +- `Signal('60分钟_D1MACD12#26#9_BS1辅助V221201_其他_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `fastperiod/slowperiod/signalperiod`:MACD参数,默认 `12/26/9`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_first_bs_V221216.md b/.claude/skills/signal-functions/references/signals/tas_macd_first_bs_V221216.md new file mode 100644 index 000000000..e510dcac5 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_first_bs_V221216.md @@ -0,0 +1,27 @@ +# tas_macd_first_bs_V221216:MACD 第一买卖点(扩展版) + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS1辅助V221216` + +## 信号逻辑 + +1. 以最近 10 根与前 90 根做高低点对比(新高/新低); +2. 结合最近交叉类型、零轴位置与 MACD 方向判断 `一买/一卖`; +3. `v2` 输出最后一次交叉类型(`金叉/死叉`)。 + +## 信号列表示例 + +- `Signal('60分钟_D1MACD12#26#9_BS1辅助V221216_一买_死叉_任意_0')` +- `Signal('60分钟_D1MACD12#26#9_BS1辅助V221216_一卖_金叉_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `fastperiod/slowperiod/signalperiod`:MACD 参数,默认 `12/26/9`。 + +## 对齐说明 + +分支条件、`or` 组合与 `v2` 输出语义对齐 Python `tas_macd_first_bs_V221216`。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_power_V221108.md b/.claude/skills/signal-functions/references/signals/tas_macd_power_V221108.md new file mode 100644 index 000000000..63f371a1e --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_power_V221108.md @@ -0,0 +1,29 @@ +# tas_macd_power_V221108:MACD强弱分层信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}K#MACD{fastperiod}#{slowperiod}#{signalperiod}强弱_BS辅助V221108` + +## 信号逻辑 + +1. 计算当前 `DIF/DEA`; +2. `dif >= dea >= 0` 判定 `超强`; +3. `dif - dea > 0` 判定 `强势`; +4. `dif <= dea <= 0` 判定 `超弱`; +5. `dif - dea < 0` 判定 `弱势`,其余为 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1K#MACD12#26#9强弱_BS辅助V221108_超强_任意_任意_0')` +- `Signal('60分钟_D1K#MACD12#26#9强弱_BS辅助V221108_弱势_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `fastperiod/slowperiod/signalperiod`:MACD参数,默认 `12/26/9`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_second_bs_V221201.md b/.claude/skills/signal-functions/references/signals/tas_macd_second_bs_V221201.md new file mode 100644 index 000000000..4d25f5298 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_second_bs_V221201.md @@ -0,0 +1,27 @@ +# tas_macd_second_bs_V221201:MACD 第二买卖点 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}_BS2辅助V221201` + +## 信号逻辑 + +1. 在近 350 根(去掉最早 50 根)统计交叉序列; +2. 结合最近交叉距今、零轴位置与 MACD 方向判 `二买/二卖`; +3. `v2` 返回最后交叉类型。 + +## 信号列表示例 + +- `Signal('60分钟_D1MACD12#26#9_BS2辅助V221201_二买_死叉_任意_0')` +- `Signal('60分钟_D1MACD12#26#9_BS2辅助V221201_二卖_金叉_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `fastperiod/slowperiod/signalperiod`:MACD 参数,默认 `12/26/9`。 + +## 对齐说明 + +`距今` 条件与零轴判定对齐 Python `tas_macd_second_bs_V221201`。 diff --git a/.claude/skills/signal-functions/references/signals/tas_macd_xt_V221208.md b/.claude/skills/signal-functions/references/signals/tas_macd_xt_V221208.md new file mode 100644 index 000000000..6eafdb0a2 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_macd_xt_V221208.md @@ -0,0 +1,27 @@ +# tas_macd_xt_V221208:MACD 柱形态信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}K#MACD{fastperiod}#{slowperiod}#{signalperiod}形态_BS辅助V221208` + +## 信号逻辑 + +1. 读取最近 5 根 MACD 柱; +2. 按柱子相对大小关系判定 `逼空棒/杀多棒/绿抽脚/红缩头`; +3. 按跨零关系判定 `空翻多/多翻空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1K#MACD12#26#9形态_BS辅助V221208_逼空棒_任意_任意_0')` +- `Signal('60分钟_D1K#MACD12#26#9形态_BS辅助V221208_多翻空_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `fastperiod/slowperiod/signalperiod`:MACD 参数,默认 `12/26/9`。 + +## 对齐说明 + +形态分支顺序与 Python `tas_macd_xt_V221208` 保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_rsi_base_V230227.md b/.claude/skills/signal-functions/references/signals/tas_rsi_base_V230227.md new file mode 100644 index 000000000..e7b863a22 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_rsi_base_V230227.md @@ -0,0 +1,30 @@ +# tas_rsi_base_V230227:RSI超买超卖与方向信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}T{th}RSI{timeperiod}_RSI辅助V230227` + +## 信号逻辑 + +1. 使用 `n` 计算 RSI(与 Python 保持一致); +2. `rsi <= th` 判 `超卖`,`rsi >= 100-th` 判 `超买`,否则 `其他`; +3. `rsi_now >= rsi_prev` 判 `向上`,否则 `向下`。 + +## 信号列表示例 + +- `Signal('60分钟_D1T20RSI6_RSI辅助V230227_超卖_向上_任意_0')` +- `Signal('60分钟_D1T20RSI6_RSI辅助V230227_超买_向下_任意_0')` +- `Signal('60分钟_D1T20RSI6_RSI辅助V230227_其他_向上_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `n`:RSI 实际计算周期,默认 `6`; +- `timeperiod`:仅用于信号键展示,默认 `6`; +- `th`:超买超卖阈值,默认 `20`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_rumi_V230704.md b/.claude/skills/signal-functions/references/signals/tas_rumi_V230704.md new file mode 100644 index 000000000..0fc26e8a6 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_rumi_V230704.md @@ -0,0 +1,29 @@ +# tas_rumi_V230704:RUMI 零轴切换信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}F{timeperiod1}S{timeperiod2}R{rumi_window}_BS辅助V230704` + +## 信号逻辑 + +1. 计算 `SMA(timeperiod1)` 与 `WMA(timeperiod2)`,得到 `diff = fast - slow`; +2. 对 `diff` 做 `SMA(rumi_window)` 平滑,得到 `rumi`; +3. `rumi` 上穿 0 轴判 `多头`,下穿 0 轴判 `空头`。 + +## 信号列表示例 + +- `Signal('60分钟_D1F3S50R30_BS辅助V230704_多头_任意_任意_0')` +- `Signal('60分钟_D1F3S50R30_BS辅助V230704_空头_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `timeperiod1`:快线均线周期,默认 `3`; +- `timeperiod2`:慢线均线周期,默认 `50`; +- `rumi_window`:RUMI 平滑周期,默认 `30`。 + +## 对齐说明 + +快慢线选型与零轴交叉判定对齐 Python `tas_rumi_V230704`。 diff --git a/.claude/skills/signal-functions/references/signals/tas_sar_base_V230425.md b/.claude/skills/signal-functions/references/signals/tas_sar_base_V230425.md new file mode 100644 index 000000000..7ae87c2d5 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_sar_base_V230425.md @@ -0,0 +1,28 @@ +# tas_sar_base_V230425:SAR 基础多空信号 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}MO{max_overlap}SAR_BS辅助V230425` + +## 信号逻辑 + +1. 计算 SAR 序列; +2. 若当前 `close > sar` 且窗口内存在任意 `close < sar`,判定 `看多`; +3. 若当前 `close < sar` 且窗口内存在任意 `close > sar`,判定 `看空`; +4. 否则返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1MO5SAR_BS辅助V230425_看多_任意_任意_0')` +- `Signal('60分钟_D1MO5SAR_BS辅助V230425_看空_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `max_overlap`:重叠窗口,默认 `5`。 + +## 对齐说明 + +突破与重叠窗口判定逻辑对齐 Python `tas_sar_base_V230425`。 diff --git a/.claude/skills/signal-functions/references/signals/tas_second_bs_V230228.md b/.claude/skills/signal-functions/references/signals/tas_second_bs_V230228.md new file mode 100644 index 000000000..8f205dfa0 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_second_bs_V230228.md @@ -0,0 +1,33 @@ +# tas_second_bs_V230228:均线结合K线形态的二买二卖辅助 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}N{n}#{ma_type}#{timeperiod}_BS2辅助V230228` + +## 信号逻辑 + +1. 在最近 `n` 根K线上计算均线; +2. 二买条件: +- `sma[-1]` 为窗口新高且 `sma[-1] > sma[-2]`; +- 最新收盘 `close[-1] > sma[-1]`; +- 最近3根存在 `low < sma`; +3. 二卖条件与上面对称; +4. 满足则输出 `二买/二卖`,否则 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1N21#SMA#20_BS2辅助V230228_二买_任意_任意_0')` +- `Signal('60分钟_D1N21#SMA#20_BS2辅助V230228_二卖_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `n`:窗口大小,默认 `21`; +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod`:均线周期,默认 `20`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_second_bs_V230303.md b/.claude/skills/signal-functions/references/signals/tas_second_bs_V230303.md new file mode 100644 index 000000000..eac3db385 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_second_bs_V230303.md @@ -0,0 +1,33 @@ +# tas_second_bs_V230303:利用笔和均线辅助二买二卖 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}{ma_type}#{timeperiod}_BS2辅助V230303` + +## 信号逻辑 + +1. 取倒数 `di` 截止最近13笔,取最后一笔与其首尾原始K线; +2. 二买条件: +- 最后一笔为向下; +- 最后一笔末K最低点跌破均线; +- 最近5笔最低点为13笔全局最低; +- 该笔首K均线值 < 末K均线值(均线向上); +3. 二卖条件与上面对称; +4. 满足则输出 `二买/二卖`,否则 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1SMA#30_BS2辅助V230303_二买_任意_任意_0')` +- `Signal('60分钟_D1SMA#30_BS2辅助V230303_二卖_任意_任意_0')` + +## 参数说明 + +- `di`:指定倒数第 `di` 笔,默认 `1`; +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod`:均线周期,默认 `30`。 + +## 对齐说明 + +与 Python 同名函数逻辑与边界条件保持一致。 diff --git a/.claude/skills/signal-functions/references/signals/tas_slope_V231019.md b/.claude/skills/signal-functions/references/signals/tas_slope_V231019.md new file mode 100644 index 000000000..05385fe2b --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/tas_slope_V231019.md @@ -0,0 +1,28 @@ +# tas_slope_V231019:DIF 斜率分位多空 + +> 模块: `tas.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}DIF{n}斜率T{th}_BS辅助V231019` + +## 信号逻辑 + +1. 计算最近区间内 DIF 线性回归斜率序列; +2. 计算当前斜率在历史区间中的归一化分位; +3. 分位 `> th/100` 判 `看多`,`< 1-th/100` 判 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1DIF10斜率T80_BS辅助V231019_看多_任意_任意_0')` +- `Signal('60分钟_D1DIF10斜率T80_BS辅助V231019_看空_任意_任意_0')` + +## 参数说明 + +- `di`:倒数第 `di` 根K线,默认 `1`; +- `n`:斜率回看长度,默认 `10`; +- `th`:分位阈值(50-100),默认 `80`。 + +## 对齐说明 + +分位判定区间和阈值方向与 Python `tas_slope_V231019` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/vol_double_ma_V230214.md b/.claude/skills/signal-functions/references/signals/vol_double_ma_V230214.md new file mode 100644 index 000000000..faa214fa7 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/vol_double_ma_V230214.md @@ -0,0 +1,28 @@ +# vol_double_ma_V230214:成交量双均线多空信号 + +> 模块: `vol.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}VOL双均线{ma_type}#{t1}#{t2}_BS辅助V230214` + +## 信号逻辑 + +1. 分别计算成交量短均线 `t1` 与长均线 `t2`; +2. `vol_ma_short >= vol_ma_long` 判定 `看多`,否则 `看空`。 + +## 信号列表示例 + +- `Signal('60分钟_D1VOL双均线SMA#5#20_BS辅助V230214_看多_任意_任意_0')` +- `Signal('60分钟_D1VOL双均线EMA#5#20_BS辅助V230214_看空_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `t1`:短均线周期,默认 `5`; +- `t2`:长均线周期,默认 `20`; +- `ma_type`:均线类型,默认 `SMA`。 + +## 对齐说明 + +短长成交量均线关系判定与 Python `vol_double_ma_V230214` 一致。 diff --git a/.claude/skills/signal-functions/references/signals/vol_gao_di_V221218.md b/.claude/skills/signal-functions/references/signals/vol_gao_di_V221218.md new file mode 100644 index 000000000..82c77da5c --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/vol_gao_di_V221218.md @@ -0,0 +1,28 @@ +# vol_gao_di_V221218:高量柱与低量柱信号 + +> 模块: `vol.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}K_量柱V221218` + +## 信号逻辑 + +1. 依次检查 `10/9/8/7/6` 根窗口; +2. 若末根成交量为窗口最大值,判 `高量柱`; +3. 若次末根为窗口最大且末根不足其 50%,判 `高量黄金柱`; +4. 若末根成交量为窗口最小值,判 `低量柱`; +5. 命中后输出对应窗口长度(如 `10K`)。 + +## 信号列表示例 + +- `Signal('60分钟_D1K_量柱V221218_高量柱_10K_任意_0')` +- `Signal('60分钟_D1K_量柱V221218_低量柱_7K_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`。 + +## 对齐说明 + +窗口递减检查顺序与高/低量柱定义对齐 Python `vol_gao_di_V221218`。 diff --git a/.claude/skills/signal-functions/references/signals/vol_single_ma_V230214.md b/.claude/skills/signal-functions/references/signals/vol_single_ma_V230214.md new file mode 100644 index 000000000..965a803c7 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/vol_single_ma_V230214.md @@ -0,0 +1,28 @@ +# vol_single_ma_V230214:单成交量均线多空与方向信号 + +> 模块: `vol.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}VOL#{ma_type}#{timeperiod}_分类V230214` + +## 信号逻辑 + +1. 计算指定成交量均线(`SMA/EMA/WMA`); +2. `vol_now >= vol_ma_now` 判定 `多头`,否则 `空头`; +3. `vol_ma_now >= vol_ma_prev` 判定 `向上`,否则 `向下`。 + +## 信号列表示例 + +- `Signal('60分钟_D1VOL#SMA#5_分类V230214_多头_向上_任意_0')` +- `Signal('60分钟_D1VOL#EMA#12_分类V230214_空头_向下_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `ma_type`:均线类型,默认 `SMA`; +- `timeperiod`:均线周期,默认 `5`。 + +## 对齐说明 + +成交量均线缓存与判定口径对齐 Python `vol_single_ma_V230214`。 diff --git a/.claude/skills/signal-functions/references/signals/vol_ti_suo_V221216.md b/.claude/skills/signal-functions/references/signals/vol_ti_suo_V221216.md new file mode 100644 index 000000000..54acd20ce --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/vol_ti_suo_V221216.md @@ -0,0 +1,26 @@ +# vol_ti_suo_V221216:梯量与缩量柱信号 + +> 模块: `vol.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}K_量柱V221216` + +## 信号逻辑 + +1. 连续三根成交量递增判定 `梯量`,递减判定 `缩量`; +2. 在 `梯量/缩量` 前提下,以当前收盘与前两根收盘区间比较得到 `价升/价跌/价平`; +3. 不满足量柱条件时返回 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_D1K_量柱V221216_梯量_价升_任意_0')` +- `Signal('60分钟_D1K_量柱V221216_缩量_价跌_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`。 + +## 对齐说明 + +量柱与价位分类规则对齐 Python `vol_ti_suo_V221216`。 diff --git a/.claude/skills/signal-functions/references/signals/vol_window_V230731.md b/.claude/skills/signal-functions/references/signals/vol_window_V230731.md new file mode 100644 index 000000000..3e3a36e15 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/vol_window_V230731.md @@ -0,0 +1,29 @@ +# vol_window_V230731:窗口成交量分层特征 + +> 模块: `vol.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}W{w}M{m}N{n}_窗口能量V230731` + +## 信号逻辑 + +1. 取最近 `m` 根成交量并按 `qcut` 分成 `n` 层; +2. 统计最近 `w` 根中的最高层与最低层; +3. 输出 `高量N{max}` 与 `低量N{min}`。 + +## 信号列表示例 + +- `Signal('60分钟_D2W5M100N10_窗口能量V230731_高量N9_低量N4_任意_0')` +- `Signal('60分钟_D1W5M30N10_窗口能量V230731_高量N10_低量N3_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `w`:观察窗口大小,默认 `5`; +- `m`:分层样本长度,默认 `30`; +- `n`:分层数量,默认 `10`。 + +## 对齐说明 + +分层采用与 Python `pd.qcut(..., duplicates='drop')` 等价的去重分位边界。 diff --git a/.claude/skills/signal-functions/references/signals/vol_window_V230801.md b/.claude/skills/signal-functions/references/signals/vol_window_V230801.md new file mode 100644 index 000000000..bd1ca910e --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/vol_window_V230801.md @@ -0,0 +1,27 @@ +# vol_window_V230801:窗口成交量先后顺序特征 + +> 模块: `vol.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_D{di}W{w}_窗口能量V230801` + +## 信号逻辑 + +1. 取最近 `w` 根成交量; +2. 若最小量索引在最大量索引之后,判 `先放后缩`; +3. 否则判 `先缩后放`。 + +## 信号列表示例 + +- `Signal('60分钟_D1W5_窗口能量V230801_先放后缩_任意_任意_0')` +- `Signal('60分钟_D1W5_窗口能量V230801_先缩后放_任意_任意_0')` + +## 参数说明 + +- `di`:信号计算截止在倒数第 `di` 根K线,默认 `1`; +- `w`:观察窗口大小,默认 `5`。 + +## 对齐说明 + +最大/最小成交量首次出现位置比较逻辑对齐 Python `vol_window_V230801`。 diff --git a/.claude/skills/signal-functions/references/signals/xl_bar_basis_V240411.md b/.claude/skills/signal-functions/references/signals/xl_bar_basis_V240411.md new file mode 100644 index 000000000..a903f0c2e --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/xl_bar_basis_V240411.md @@ -0,0 +1,26 @@ +# xl_bar_basis_V240411:吞没形态信号 + +> 模块: `xl.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_N{n}_形态V240411` + +## 信号逻辑 + +1. 最近两根K线构成看涨吞没时输出 `看涨吞没`; +2. 构成看跌吞没时输出 `看跌吞没`; +3. 否则输出 `其他`。 + +## 信号列表示例 + +- `Signal('60分钟_N2_形态V240411_看涨吞没_任意_任意_0')` +- `Signal('60分钟_N2_形态V240411_看跌吞没_任意_任意_0')` + +## 参数说明 + +- `n`:最小窗口参数,默认 `2`。 + +## 对齐说明 + +与 Python `xl_bar_basis_V240411` 的吞没条件一致。 diff --git a/.claude/skills/signal-functions/references/signals/xl_bar_basis_V240412.md b/.claude/skills/signal-functions/references/signals/xl_bar_basis_V240412.md new file mode 100644 index 000000000..59503a666 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/xl_bar_basis_V240412.md @@ -0,0 +1,27 @@ +# xl_bar_basis_V240412:长蜡烛形态信号 + +> 模块: `xl.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_N{n}#TH{th}_形态V240412` + +## 信号逻辑 + +1. 统计前 `n` 根实体长度均值与标准差; +2. 当前实体超过 `mean + th*std` 时识别长蜡烛; +3. 按实体方向输出 `看涨长蜡烛/看跌长蜡烛`。 + +## 信号列表示例 + +- `Signal('60分钟_N10#TH3_形态V240412_看涨长蜡烛_任意_任意_0')` +- `Signal('60分钟_N10#TH3_形态V240412_看跌长蜡烛_任意_任意_0')` + +## 参数说明 + +- `n`:统计窗口,默认 `10`; +- `th`:标准差倍数,默认 `3`。 + +## 对齐说明 + +与 Python `xl_bar_basis_V240412` 的阈值公式一致。 diff --git a/.claude/skills/signal-functions/references/signals/xl_bar_position_V240328.md b/.claude/skills/signal-functions/references/signals/xl_bar_position_V240328.md new file mode 100644 index 000000000..5a44749a8 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/xl_bar_position_V240328.md @@ -0,0 +1,26 @@ +# xl_bar_position_V240328:相对高低位置识别信号 + +> 模块: `xl.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_N{n}_BS辅助V240328` + +## 信号逻辑 + +1. 计算最近 `3n` 根 `(close-EMA(n))/EMA(n)` 偏离度; +2. 若最新值低于 30% 分位数判 `相对低点`; +3. 若最新值高于 70% 分位数判 `相对高点`。 + +## 信号列表示例 + +- `Signal('60分钟_N10_BS辅助V240328_相对低点_任意_任意_0')` +- `Signal('60分钟_N10_BS辅助V240328_相对高点_任意_任意_0')` + +## 参数说明 + +- `n`:EMA 周期,默认 `10`。 + +## 对齐说明 + +与 Python `xl_bar_position_V240328` 的分位口径一致(midpoint)。 diff --git a/.claude/skills/signal-functions/references/signals/xl_bar_trend_V240329.md b/.claude/skills/signal-functions/references/signals/xl_bar_trend_V240329.md new file mode 100644 index 000000000..1713b9d3c --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/xl_bar_trend_V240329.md @@ -0,0 +1,27 @@ +# xl_bar_trend_V240329:十字孕线反转信号 + +> 模块: `xl.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_N{n}M{m}_十字线反转V240329` + +## 信号逻辑 + +1. 判断最新K线是否十字线(`check_szx`); +2. 前一根为长阴线且满足阈值判 `底部十字孕线`; +3. 前一根为长阳线且满足阈值判 `顶部十字孕线`。 + +## 信号列表示例 + +- `Signal('60分钟_N10M5_十字线反转V240329_底部十字孕线_其他_任意_0')` +- `Signal('60分钟_N10M5_十字线反转V240329_顶部十字孕线_其他_任意_0')` + +## 参数说明 + +- `n`:十字线阈值参数,默认 `10`; +- `m`:实体比例阈值,默认 `5`。 + +## 对齐说明 + +与 Python `xl_bar_trend_V240329` 的 `check_szx` 口径一致。 diff --git a/.claude/skills/signal-functions/references/signals/xl_bar_trend_V240330.md b/.claude/skills/signal-functions/references/signals/xl_bar_trend_V240330.md new file mode 100644 index 000000000..2654337fb --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/xl_bar_trend_V240330.md @@ -0,0 +1,27 @@ +# xl_bar_trend_V240330:双均线过滤信号 + +> 模块: `xl.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_N{n}M{m}#{ma_type}_双均线过滤V240330` + +## 信号逻辑 + +1. 计算 `MA(n)` 与 `MA(m)`; +2. 根据两线相对位置输出 `看多/看空`; +3. 统计连续状态次数并输出 `第xx次`(最大 10)。 + +## 信号列表示例 + +- `Signal('60分钟_N5M21#SMA_双均线过滤V240330_看多_第03次_任意_0')` +- `Signal('60分钟_N5M21#SMA_双均线过滤V240330_看空_第06次_任意_0')` + +## 参数说明 + +- `n/m`:短长均线周期,默认 `5/21`; +- `ma_type`:均线类型,默认 `SMA`。 + +## 对齐说明 + +与 Python `xl_bar_trend_V240330` 的次数计数逻辑一致。 diff --git a/.claude/skills/signal-functions/references/signals/xl_bar_trend_V240331.md b/.claude/skills/signal-functions/references/signals/xl_bar_trend_V240331.md new file mode 100644 index 000000000..032c8994d --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/xl_bar_trend_V240331.md @@ -0,0 +1,25 @@ +# xl_bar_trend_V240331:通道突破信号 + +> 模块: `xl.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_N{n}_突破信号V240331` + +## 信号逻辑 + +1. 若最新高点突破前 `n` 根最高价,判 `做多`; +2. 若最新低点跌破前 `n` 根最低价,判 `做空`。 + +## 信号列表示例 + +- `Signal('60分钟_N20_突破信号V240331_做多_任意_任意_0')` +- `Signal('60分钟_N20_突破信号V240331_做空_任意_任意_0')` + +## 参数说明 + +- `n`:通道窗口,默认 `20`。 + +## 对齐说明 + +与 Python `xl_bar_trend_V240331` 的突破判定一致。 diff --git a/.claude/skills/signal-functions/references/signals/xl_bar_trend_V240623.md b/.claude/skills/signal-functions/references/signals/xl_bar_trend_V240623.md new file mode 100644 index 000000000..fef1dc29b --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/xl_bar_trend_V240623.md @@ -0,0 +1,26 @@ +# xl_bar_trend_V240623:通道突破连续信号 + +> 模块: `xl.rs` | 类别: `kline` + +## 参数模板 + +`"{freq}_N{n}通道_突破信号V240623` + +## 信号逻辑 + +1. 使用倒数第二根K线判断是否突破前 `n` 通道; +2. 突破上轨给 `做多`,且最新再创新高给 `连续2次上涨`; +3. 跌破下轨给 `做空`,且最新再创新低给 `连续2次下跌`。 + +## 信号列表示例 + +- `Signal('60分钟_N20通道_突破信号V240623_做多_连续2次上涨_任意_0')` +- `Signal('60分钟_N20通道_突破信号V240623_做空_连续2次下跌_任意_0')` + +## 参数说明 + +- `n`:通道窗口,默认 `20`。 + +## 对齐说明 + +与 Python `xl_bar_trend_V240623` 的“倒二突破 + 最新确认”口径一致。 diff --git a/.claude/skills/signal-functions/references/signals/zdy_stop_loss_V230406.md b/.claude/skills/signal-functions/references/signals/zdy_stop_loss_V230406.md new file mode 100644 index 000000000..8a73d4d87 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/zdy_stop_loss_V230406.md @@ -0,0 +1,3 @@ +# zdy_stop_loss_V230406:笔操作止损逻辑 + +> 模块: `zdy_trader.rs` | 类别: `trader` diff --git a/.claude/skills/signal-functions/references/signals/zdy_take_profit_V230406.md b/.claude/skills/signal-functions/references/signals/zdy_take_profit_V230406.md new file mode 100644 index 000000000..047774002 --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/zdy_take_profit_V230406.md @@ -0,0 +1,3 @@ +# zdy_take_profit_V230406:笔操作止盈逻辑 + +> 模块: `zdy_trader.rs` | 类别: `trader` diff --git a/.claude/skills/signal-functions/references/signals/zdy_take_profit_V230407.md b/.claude/skills/signal-functions/references/signals/zdy_take_profit_V230407.md new file mode 100644 index 000000000..6136c680e --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/zdy_take_profit_V230407.md @@ -0,0 +1,3 @@ +# zdy_take_profit_V230407:按力度提前止盈 + +> 模块: `zdy_trader.rs` | 类别: `trader` diff --git a/.claude/skills/signal-functions/references/signals/zdy_vibrate_V230406.md b/.claude/skills/signal-functions/references/signals/zdy_vibrate_V230406.md new file mode 100644 index 000000000..d5bc51b6c --- /dev/null +++ b/.claude/skills/signal-functions/references/signals/zdy_vibrate_V230406.md @@ -0,0 +1,3 @@ +# zdy_vibrate_V230406:中枢震荡短差辅助 + +> 模块: `zdy_trader.rs` | 类别: `trader` diff --git a/.claude/skills/signal-functions/scripts/extract_signal_docs.py b/.claude/skills/signal-functions/scripts/extract_signal_docs.py new file mode 100644 index 000000000..daec6954a --- /dev/null +++ b/.claude/skills/signal-functions/scripts/extract_signal_docs.py @@ -0,0 +1,316 @@ +""" +从 rs_czsc Rust 源码提取信号函数文档,为每个信号生成独立 markdown 文件。 + +用法: python extract_signal_docs.py + +示例: + python extract_signal_docs.py crates/czsc-signals/src .claude/skills/signal-functions/references +""" + +import re +import sys +from pathlib import Path + + +# ── 文档块解析 ────────────────────────────────────────────────────── + +SIGNAL_ATTR_RE = re.compile( + r'#\[signal\(\s*\n(.*?)\)\]', re.DOTALL +) +SIGNAL_FIELD_RE = re.compile( + r'(\w+)\s*=\s*"([^"]*)"' +) +DOC_LINE_RE = re.compile(r'^\s*///\s?(.*)') + + +def extract_signals_from_file(filepath: Path) -> list[dict]: + """从单个 Rust 源文件提取所有信号函数的文档。""" + text = filepath.read_text(encoding='utf-8') + lines = text.split('\n') + + signals: list[dict] = [] + i = 0 + + while i < len(lines): + line = lines[i] + + # 查找 #[signal(...)] 属性 + if line.strip().startswith('#[signal('): + attr_start = i + attr_text = line + depth = line.count('(') - line.count(')') + j = i + 1 + while depth > 0 and j < len(lines): + attr_text += '\n' + lines[j] + depth += lines[j].count('(') - lines[j].count(')') + j += 1 + i = j + + # 解析属性字段 + attr_match = SIGNAL_ATTR_RE.search(attr_text) + if not attr_match: + continue + fields = dict(SIGNAL_FIELD_RE.findall(attr_match.group(1))) + + # 向上查找文档注释块(紧贴 #[signal] 之前的 /// 行) + doc_lines: list[str] = [] + k = attr_start - 1 + while k >= 0 and lines[k].strip().startswith('///'): + doc_lines.insert(0, lines[k]) + k -= 1 + + # 解析文档注释 + doc_parts = parse_doc_comment(doc_lines) + + # 向下查找函数签名 + func_line = '' + while i < len(lines): + if lines[i].strip().startswith('pub fn ') or lines[i].strip().startswith('pub async fn '): + func_line = lines[i].strip() + break + i += 1 + + sig_info = { + 'name': fields.get('name', ''), + 'template': fields.get('template', ''), + 'category': fields.get('category', ''), + 'opcode': fields.get('opcode', ''), + 'param_kind': fields.get('param_kind', ''), + 'module': filepath.stem, + 'title': doc_parts.get('title', fields.get('name', '')), + 'param_template': doc_parts.get('param_template', ''), + 'logic': doc_parts.get('logic', ''), + 'examples': doc_parts.get('examples', ''), + 'params': doc_parts.get('params', ''), + 'alignment': doc_parts.get('alignment', ''), + } + signals.append(sig_info) + else: + i += 1 + + return signals + + +def parse_doc_comment(doc_lines: list[str]) -> dict[str, str]: + """解析 /// 文档注释块,提取各段落。""" + # 去掉 /// 前缀 + cleaned: list[str] = [] + for raw in doc_lines: + m = DOC_LINE_RE.match(raw) + cleaned.append(m.group(1) if m else '') + + # 合并为一整段文本 + full_text = '\n'.join(cleaned) + + parts: dict[str, str] = {} + + # 标题行(第一行非空) + title_match = re.match(r'([^\n]+)', full_text) + if title_match: + parts['title'] = title_match.group(1).strip() + + # 参数模板 + m = re.search(r'参数模板:[`"](.+?)[`"]', full_text) + if m: + parts['param_template'] = m.group(1) + + # 信号逻辑 + logic_match = re.search(r'信号逻辑:(.*?)(?=信号列表示例:|参数说明:|$)', full_text, re.DOTALL) + if logic_match: + parts['logic'] = logic_match.group(1).strip() + + # 信号列表示例 + examples_match = re.search(r'信号列表示例:(.*?)(?=参数说明:|对齐说明:|$)', full_text, re.DOTALL) + if examples_match: + parts['examples'] = examples_match.group(1).strip() + + # 参数说明 + params_match = re.search(r'参数说明:(.*?)(?=对齐说明:|$)', full_text, re.DOTALL) + if params_match: + parts['params'] = params_match.group(1).strip() + + # 对齐说明 + align_match = re.search(r'对齐说明:(.+)', full_text, re.DOTALL) + if align_match: + parts['alignment'] = align_match.group(1).strip() + + return parts + + +# ── Markdown 生成 ────────────────────────────────────────────────── + +def generate_signal_md(sig: dict) -> str: + """为单个信号生成 markdown 内容。""" + name = sig['name'] + lines: list[str] = [] + + lines.append(f'# {sig.get("title", name)}') + lines.append('') + lines.append(f'> 模块: `{sig["module"]}.rs` | 类别: `{sig["category"]}`') + lines.append('') + + # 参数模板 + if sig.get('param_template'): + lines.append('## 参数模板') + lines.append('') + lines.append(f'`{sig["param_template"]}`') + lines.append('') + + # 信号逻辑 + if sig.get('logic'): + lines.append('## 信号逻辑') + lines.append('') + for logic_line in sig['logic'].split('\n'): + stripped = logic_line.strip() + if stripped: + lines.append(stripped) + lines.append('') + + # 信号列表示例 + if sig.get('examples'): + lines.append('## 信号列表示例') + lines.append('') + for ex_line in sig['examples'].split('\n'): + stripped = ex_line.strip() + if stripped: + lines.append(stripped) + lines.append('') + + # 参数说明 + if sig.get('params'): + lines.append('## 参数说明') + lines.append('') + for p_line in sig['params'].split('\n'): + stripped = p_line.strip() + if stripped: + lines.append(stripped) + lines.append('') + + # 对齐说明 + if sig.get('alignment'): + lines.append('## 对齐说明') + lines.append('') + lines.append(sig['alignment']) + lines.append('') + + return '\n'.join(lines) + + +def generate_module_index(module_name: str, signals: list[dict], all_signal_files: dict[str, str]) -> str: + """为模块生成索引 markdown。""" + lines: list[str] = [] + lines.append(f'# {module_name} 模块信号索引') + lines.append('') + lines.append(f'> 源码: `crates/czsc-signals/src/{module_name}.rs`') + lines.append(f'> 共 {len(signals)} 个信号') + lines.append('') + lines.append('| 信号名 | 参数模板 | 说明 | 详细文档 |') + lines.append('|--------|----------|------|----------|') + + for sig in sorted(signals, key=lambda s: s['name']): + name = sig['name'] + template = sig.get('param_template', sig.get('template', '')) + # 从标题中提取简短说明 + title = sig.get('title', name) + desc = title.split(':', 1)[-1] if ':' in title else title + # 文件链接 + filename = all_signal_files.get(name, '') + link = f'[详细文档](signals/{filename})' if filename else '' + lines.append(f'| `{name}` | `{template}` | {desc} | {link} |') + + lines.append('') + return '\n'.join(lines) + + +# ── 主流程 ───────────────────────────────────────────────────────── + +# 模块分类映射 +MODULE_GROUPS = { + # K线级信号 + 'bar': 'bar', + 'cxt': 'cxt', + 'tas': 'tas', + 'jcc': 'jcc', + 'zdy': 'zdy', + 'ang': 'misc', + 'xl': 'misc', + 'vol': 'misc', + 'coo': 'misc', + 'byi': 'misc', + 'pressure': 'misc', + 'obv': 'misc', + 'cvolp': 'misc', + 'ntmdk': 'misc', + 'kcatr': 'misc', + 'clv': 'misc', + # 交易级信号 + 'pos': 'trader', + 'cat': 'trader', + 'cxt_trader': 'trader', + 'zdy_trader': 'trader', +} + + +def main(): + if len(sys.argv) < 3: + print(f'用法: python {sys.argv[0]} ') + sys.exit(1) + + src_dir = Path(sys.argv[1]) + out_dir = Path(sys.argv[2]) + signals_dir = out_dir / 'signals' + + signals_dir.mkdir(parents=True, exist_ok=True) + + # 收集所有信号 + all_signals: list[dict] = [] + rs_files = sorted(src_dir.glob('*.rs')) + + for rs_file in rs_files: + if rs_file.name in ('lib.rs', 'registry.rs', 'types.rs', 'params.rs', 'utils.rs'): + continue + sigs = extract_signals_from_file(rs_file) + all_signals.extend(sigs) + print(f' {rs_file.name}: 提取 {len(sigs)} 个信号') + + print(f'\n共提取 {len(all_signals)} 个信号函数') + + # 为每个信号生成独立 markdown 文件 + signal_file_map: dict[str, str] = {} # name -> filename + for sig in all_signals: + name = sig['name'] + filename = f'{name}.md' + md_content = generate_signal_md(sig) + (signals_dir / filename).write_text(md_content, encoding='utf-8') + signal_file_map[name] = filename + + print(f'已生成 {len(signal_file_map)} 个信号文档文件到 {signals_dir}/') + + # 按模块分组 + module_signals: dict[str, list[dict]] = {} + for sig in all_signals: + mod = sig['module'] + module_signals.setdefault(mod, []).append(sig) + + # 生成模块级索引文件 + for mod_name, sigs in sorted(module_signals.items()): + index_content = generate_module_index(mod_name, sigs, signal_file_map) + index_path = out_dir / f'signals-{mod_name}.md' + index_path.write_text(index_content, encoding='utf-8') + print(f' 模块索引: {index_path.name} ({len(sigs)} 个信号)') + + # 生成总索引(按分组归类) + group_order = ['bar', 'cxt', 'tas', 'jcc', 'zdy', 'ang', 'xl', 'vol', 'coo', 'byi', + 'pressure', 'obv', 'cvolp', 'ntmdk', 'kcatr', 'clv', + 'pos', 'cat', 'cxt_trader', 'zdy_trader'] + + print('\n=== 信号统计 ===') + kline_count = sum(1 for s in all_signals if s['category'] == 'kline') + trader_count = sum(1 for s in all_signals if s['category'] == 'trader') + print(f'K线级信号: {kline_count}') + print(f'交易级信号: {trader_count}') + print(f'总计: {len(all_signals)}') + + +if __name__ == '__main__': + main() From 6b53268a724ea32b2aaafaa64b17b885a5843524 Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Fri, 8 May 2026 10:31:29 +0800 Subject: [PATCH 17/23] fix(smoke): skip wheel install test when dist/ is empty --- test/smoke/test_install.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/smoke/test_install.py b/test/smoke/test_install.py index 8acd87ff8..5ee43d04e 100644 --- a/test/smoke/test_install.py +++ b/test/smoke/test_install.py @@ -106,9 +106,12 @@ def test_wheel_install_in_clean_venv(tmp_path: Path) -> None: # 选取 dist/ 下的所有 czsc 安装包,并以排序方式取最新版本 wheels = sorted(DIST_DIR.glob("czsc-*.whl")) if DIST_DIR.is_dir() else [] if not wheels: - pytest.fail( + # 常规 CI 的 test job 只跑 `maturin develop`,不产 wheel; + # wheel 安装验证由 python-publish.yml 的 smoke-test job 在发布流程 + # 里专门跑。所以本地/常规 CI 没 wheel 时跳过,不算失败。 + pytest.skip( f"在 {DIST_DIR} 下找不到 wheel;" - "请先运行 `maturin build --release` 再执行本冒烟测试" + "如需运行本测试,请先 `maturin build --release`" ) # 构建一个全新的虚拟环境用于隔离安装 From 88077edab697ac7f5c873c1573f009862b8a9e0a Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Fri, 8 May 2026 11:12:58 +0800 Subject: [PATCH 18/23] Refactor test code for improved readability and consistency - Simplified list and dictionary assertions in various test files for clarity. - Removed unnecessary line breaks and whitespace to enhance code cleanliness. - Consolidated multi-line assertions into single lines where appropriate. - Ensured consistent formatting across test cases, including parameterized tests. - Updated comments and docstrings for better understanding of test purposes. --- czsc/__init__.py | 218 ++++++++++++----- czsc/_compat.py | 12 +- czsc/_native.pyi | 298 +++++++++++++++-------- czsc/_utils/_df_convert.py | 5 +- czsc/connectors/cooperation.py | 33 +-- czsc/connectors/jq_connector.py | 42 ++-- czsc/models.py | 12 +- czsc/research.py | 18 +- czsc/signals/_helpers.py | 3 +- czsc/signals/bar.py | 6 + czsc/signals/cvolp.py | 6 + czsc/signals/cxt.py | 6 + czsc/signals/obv.py | 6 + czsc/signals/pressure.py | 6 + czsc/signals/tas.py | 6 + czsc/signals/vol.py | 6 + czsc/strategies.py | 17 +- czsc/svc/backtest.py | 1 - czsc/svc/factor.py | 2 +- czsc/svc/price_analysis.py | 2 +- czsc/svc/returns.py | 5 +- czsc/svc/statistics.py | 2 +- czsc/svc/strategy.py | 7 +- czsc/traders/__init__.py | 3 +- czsc/traders/__init__.pyi | 3 +- czsc/traders/base.py | 1 - czsc/traders/optimize.py | 45 ++-- czsc/utils/__init__.py | 1 - test/compat/test_public_api.py | 24 +- test/integration/test_weight_backtest.py | 10 +- test/parity/_signal_defaults.py | 39 ++- test/parity/bench_optimize.py | 53 ++-- test/parity/compare_optimize_full.py | 85 ++++--- test/parity/conftest.py | 8 +- test/parity/test_all_signals.py | 26 +- test/parity/test_examples.py | 51 ++-- test/parity/test_optimize.py | 14 +- test/parity/test_performance.py | 35 +-- test/parity/test_run_research.py | 4 +- test/parity/test_signals_registry.py | 10 +- test/smoke/test_install.py | 17 +- test/test_envs.py | 8 +- test/test_plotly_plot.py | 3 +- test/test_stoploss_by_direction.py | 29 ++- test/test_utils.py | 3 - test/unit/test_core_parity.py | 18 +- test/unit/test_pickle.py | 15 +- test/unit/test_signals_parity.py | 21 +- test/unit/test_ta_parity.py | 20 +- test/unit/test_trading_time.py | 13 +- 50 files changed, 710 insertions(+), 568 deletions(-) diff --git a/czsc/__init__.py b/czsc/__init__.py index aa6d248fa..1caf420be 100644 --- a/czsc/__init__.py +++ b/czsc/__init__.py @@ -21,6 +21,22 @@ _sys.modules["czsc.ta"] = _native.ta # === 缠论核心数据类型与算法(来自 Rust 扩展 czsc._native)=== +# === wbt(硬依赖,提供回测/绩效组件)=== +from wbt import WeightBacktest, daily_performance, top_drawdowns + +# format_standard_kline: Python 适配层,把 DataFrame -> List[RawBar](详见模块 docstring) +from czsc._format_standard_kline import format_standard_kline + +# === 之前的 lazy 属性,改为静态 import(spec §3.1 移除 lazy loading)=== +from czsc.utils.kline_quality import check_kline_quality +from czsc.utils.log import log_strategy_info +from czsc.utils.plotting.kline import KlineChart, plot_czsc_chart +from czsc.utils.trade import adjust_holding_weights +from czsc.utils.warning_capture import capture_warnings, execute_with_warning_capture + +# 第二批:会回头 import czsc 顶层符号(如 ``from czsc import top_drawdowns``)的重型子包。 +# 必须放在所有顶层符号都已经绑定之后,否则会触发 partially-initialized module 循环 import。 +from . import aphorism, fsa, mock, svc from ._native import ( BI, CZSC, @@ -52,12 +68,6 @@ ultimate_smoother, ) -# format_standard_kline: Python 适配层,把 DataFrame -> List[RawBar](详见模块 docstring) -from czsc._format_standard_kline import format_standard_kline - -# === wbt(硬依赖,提供回测/绩效组件)=== -from wbt import WeightBacktest, daily_performance, top_drawdowns - # === EDA 工具(来自 czsc.eda)=== from .eda import ( cal_symbols_factor, @@ -82,6 +92,15 @@ weights_simple_ensemble, ) +# === 研究/优化入口(czsc.research,Rust 后端)=== +from .research import ( + build_exit_optim_positions, + build_open_optim_positions, + run_optimize_batch, + run_replay, + run_research, +) + # === 策略门面(czsc.strategies;Python 层对 Rust Trader 的薄封装)=== from .strategies import CzscJsonStrategy, CzscStrategyBase @@ -98,15 +117,6 @@ get_unique_signals, ) -# === 研究/优化入口(czsc.research,Rust 后端)=== -from .research import ( - build_exit_optim_positions, - build_open_optim_positions, - run_optimize_batch, - run_replay, - run_research, -) - # === 通用工具函数(czsc.utils)=== from .utils import ( AliyunOSS, @@ -149,17 +159,6 @@ x_round, ) -# === 之前的 lazy 属性,改为静态 import(spec §3.1 移除 lazy loading)=== -from czsc.utils.kline_quality import check_kline_quality -from czsc.utils.log import log_strategy_info -from czsc.utils.plotting.kline import KlineChart, plot_czsc_chart -from czsc.utils.trade import adjust_holding_weights -from czsc.utils.warning_capture import capture_warnings, execute_with_warning_capture - -# 第二批:会回头 import czsc 顶层符号(如 ``from czsc import top_drawdowns``)的重型子包。 -# 必须放在所有顶层符号都已经绑定之后,否则会触发 partially-initialized module 循环 import。 -from . import aphorism, fsa, mock, svc - # czsc.ta 别名再保险:from .utils import ... 链路上若有副作用 import czsc.utils.ta, # 可能把 sys.modules["czsc.ta"] 覆盖回 Python 包装版本,这里再绑一次确保 Rust 子模块胜出。 ta = _native.ta @@ -175,53 +174,144 @@ # 修改本列表等价于修改公共契约;新增/移除符号必须在 release notes 与 MIGRATION_NOTES 中说明。 __all__ = [ # 缠论核心 - "BI", "CZSC", "FX", "ZS", "BarGenerator", "Direction", "Event", "FakeBI", - "Freq", "Mark", "NewBar", "Operate", "ParsedSignalDoc", "Position", "RawBar", "Signal", - "boll_positions", "check_bi", "check_fx", "check_fxs", - "ema", "format_standard_kline", "freq_end_time", "is_trading_time", - "parse_signal_doc", "remove_include", "rolling_rank", "sma", "ultimate_smoother", + "BI", + "CZSC", + "FX", + "ZS", + "BarGenerator", + "Direction", + "Event", + "FakeBI", + "Freq", + "Mark", + "NewBar", + "Operate", + "ParsedSignalDoc", + "Position", + "RawBar", + "Signal", + "boll_positions", + "check_bi", + "check_fx", + "check_fxs", + "ema", + "format_standard_kline", + "freq_end_time", + "is_trading_time", + "parse_signal_doc", + "remove_include", + "rolling_rank", + "sma", + "ultimate_smoother", # 来自 wbt - "WeightBacktest", "daily_performance", "top_drawdowns", + "WeightBacktest", + "daily_performance", + "top_drawdowns", # 始终预加载的子包 - "connectors", "envs", "sensors", "signals", "traders", "utils", - "svc", "fsa", "aphorism", "mock", + "connectors", + "envs", + "sensors", + "signals", + "traders", + "utils", + "svc", + "fsa", + "aphorism", + "mock", # 交易器 / 信号 API - "CzscSignals", "CzscTrader", "SignalsParser", - "derive_signals_config", "derive_signals_freqs", - "generate_czsc_signals", "get_signals_config", "get_signals_freqs", "get_unique_signals", + "CzscSignals", + "CzscTrader", + "SignalsParser", + "derive_signals_config", + "derive_signals_freqs", + "generate_czsc_signals", + "get_signals_config", + "get_signals_freqs", + "get_unique_signals", # 策略门面 - "CzscStrategyBase", "CzscJsonStrategy", + "CzscStrategyBase", + "CzscJsonStrategy", # 研究/优化入口 - "build_exit_optim_positions", "build_open_optim_positions", - "run_optimize_batch", "run_replay", "run_research", + "build_exit_optim_positions", + "build_open_optim_positions", + "run_optimize_batch", + "run_replay", + "run_research", # 通用工具 - "AliyunOSS", "DataClient", "DiskCache", - "clear_cache", "clear_expired_cache", - "code_namespace", "create_grid_params", "cross_sectional_ic", - "dill_dump", "dill_load", "disk_cache", - "empty_cache_path", "fernet_decrypt", "fernet_encrypt", - "freqs_sorted", "generate_fernet_key", - "get_dir_size", "get_py_namespace", "get_url_token", - "holds_performance", "home_path", - "import_by_name", "index_composition", - "mac_address", "print_df_sample", "psi", - "read_json", "resample_to_daily", "risk_free_returns", - "rolling_daily_performance", "save_json", "set_url_token", - "ta", "timeout_decorator", "to_arrow", - "update_bbars", "update_nxb", "update_tbars", "x_round", + "AliyunOSS", + "DataClient", + "DiskCache", + "clear_cache", + "clear_expired_cache", + "code_namespace", + "create_grid_params", + "cross_sectional_ic", + "dill_dump", + "dill_load", + "disk_cache", + "empty_cache_path", + "fernet_decrypt", + "fernet_encrypt", + "freqs_sorted", + "generate_fernet_key", + "get_dir_size", + "get_py_namespace", + "get_url_token", + "holds_performance", + "home_path", + "import_by_name", + "index_composition", + "mac_address", + "print_df_sample", + "psi", + "read_json", + "resample_to_daily", + "risk_free_returns", + "rolling_daily_performance", + "save_json", + "set_url_token", + "ta", + "timeout_decorator", + "to_arrow", + "update_bbars", + "update_nxb", + "update_tbars", + "x_round", # 静态 import 的高频符号(曾经走 _LAZY_ATTRS) - "capture_warnings", "execute_with_warning_capture", - "adjust_holding_weights", "log_strategy_info", - "plot_czsc_chart", "KlineChart", "check_kline_quality", + "capture_warnings", + "execute_with_warning_capture", + "adjust_holding_weights", + "log_strategy_info", + "plot_czsc_chart", + "KlineChart", + "check_kline_quality", # EDA - "remove_beta_effects", "vwap", "twap", "cross_sectional_strategy", - "monotonicity", "min_max_limit", "rolling_layers", "cal_symbols_factor", - "weights_simple_ensemble", "unify_weights", - "sma_long_bear", "dif_long_bear", "tsf_type", "limit_leverage", - "cal_trade_price", "mark_cta_periods", "mark_volatility", "cal_yearly_days", - "turnover_rate", "make_price_features", + "remove_beta_effects", + "vwap", + "twap", + "cross_sectional_strategy", + "monotonicity", + "min_max_limit", + "rolling_layers", + "cal_symbols_factor", + "weights_simple_ensemble", + "unify_weights", + "sma_long_bear", + "dif_long_bear", + "tsf_type", + "limit_leverage", + "cal_trade_price", + "mark_cta_periods", + "mark_volatility", + "cal_yearly_days", + "turnover_rate", + "make_price_features", # 元信息 - "__version__", "__author__", "__email__", "__date__", "welcome", + "__version__", + "__author__", + "__email__", + "__date__", + "welcome", ] diff --git a/czsc/_compat.py b/czsc/_compat.py index 53d89cd92..ee6ad916f 100644 --- a/czsc/_compat.py +++ b/czsc/_compat.py @@ -23,12 +23,12 @@ import hashlib import json +from collections.abc import Iterable from pathlib import Path -from typing import Any, Iterable +from typing import Any import pandas as pd - # 周期字符串到排序权重的映射表 # 数字越小代表越小级别(越高频),按此权重做稳定排序后, # 同一策略中"小周期 -> 大周期"的展示顺序与缠论惯例一致。 @@ -171,9 +171,7 @@ def position_dump_to_runtime(payload: dict[str, Any]) -> dict[str, Any]: event_copy = dict(event) # 三种信号关系字段:必须存在但允许为空列表 for sig_key in ("signals_all", "signals_any", "signals_not"): - event_copy[sig_key] = [ - signal_kv_to_string(sig) for sig in list(event_copy.get(sig_key) or []) - ] + event_copy[sig_key] = [signal_kv_to_string(sig) for sig in list(event_copy.get(sig_key) or [])] events.append(event_copy) out[event_key] = events return out @@ -371,9 +369,7 @@ def py_repr_json(value: Any) -> str: if isinstance(value, list): return "[" + ", ".join(py_repr_json(item) for item in value) + "]" if isinstance(value, dict): - return "{" + ", ".join( - f"'{py_escape_str(str(key))}': {py_repr_json(val)}" for key, val in value.items() - ) + "}" + return "{" + ", ".join(f"'{py_escape_str(str(key))}': {py_repr_json(val)}" for key, val in value.items()) + "}" # 兜底:把非典型类型按字符串处理,再走一次递归 return py_repr_json(str(value)) diff --git a/czsc/_native.pyi b/czsc/_native.pyi index ff800f954..2a9aee46f 100644 --- a/czsc/_native.pyi +++ b/czsc/_native.pyi @@ -2,11 +2,12 @@ # ruff: noqa: E501, F401 import builtins -import numpy -import numpy.typing import typing from enum import Enum +import numpy +import numpy.typing + class BI: r""" 笔 @@ -116,21 +117,29 @@ class BI: r""" 缓存字典(与 czsc 库兼容) """ - def __new__(cls, symbol:builtins.str, direction:Direction, fx_a:FX, fx_b:FX, fxs:typing.Sequence[FX], bars:typing.Sequence[NewBar]) -> BI: ... - def get_cache_with_default(self, _key:builtins.str, default_value:builtins.float) -> builtins.float: + def __new__( + cls, + symbol: builtins.str, + direction: Direction, + fx_a: FX, + fx_b: FX, + fxs: typing.Sequence[FX], + bars: typing.Sequence[NewBar], + ) -> BI: ... + def get_cache_with_default(self, _key: builtins.str, default_value: builtins.float) -> builtins.float: r""" 获取缓存值,如果不存在则返回默认值(与 czsc 库兼容) """ - def get_price_linear(self, n:builtins.int) -> builtins.float: + def get_price_linear(self, n: builtins.int) -> builtins.float: r""" 获取线性价格(与 czsc 库兼容) """ def __repr__(self) -> builtins.str: ... - def __richcmp__(self, other:BI, op:int) -> builtins.bool: ... + def __richcmp__(self, other: BI, op: int) -> builtins.bool: ... class BarGenerator: @property - def symbol_py(self) -> typing.Optional[builtins.str]: + def symbol_py(self) -> builtins.str | None: r""" 获取所属品种 - Python 属性 """ @@ -140,7 +149,7 @@ class BarGenerator: 获取基准频率 """ @property - def end_dt(self) -> typing.Optional[typing.Any]: + def end_dt(self) -> typing.Any | None: r""" 获取end_dt属性(Python兼容) """ @@ -149,28 +158,34 @@ class BarGenerator: r""" 获取各周期K线数据 - 返回字典,键为频率字符串,值为K线列表 """ - def __new__(cls, base_freq:typing.Any, freqs:typing.Any, max_count:builtins.int=2000, market:typing.Optional[typing.Any]=None) -> BarGenerator: ... - def init_freq_bars(self, freq:typing.Any, bars:typing.Sequence[RawBar]) -> None: + def __new__( + cls, + base_freq: typing.Any, + freqs: typing.Any, + max_count: builtins.int = 2000, + market: typing.Any | None = None, + ) -> BarGenerator: ... + def init_freq_bars(self, freq: typing.Any, bars: typing.Sequence[RawBar]) -> None: r""" 初始化某个周期的K线序列 - + # 函数计算逻辑 - + 1. 检查输入的`freq`是否存在于`self.freq_bars`的键中。如果不存在,返回错误。 2. 检查`self.freq_bars[freq]`是否为空。如果不为空,返回错误,表示不允许重复初始化。 3. 如果以上检查都通过,将输入的`bars`存储到`self.freq_bars[freq]`中。 4. 从`bars`中获取最后一根K线的交易标的代码,更新`self.symbol`。 - + # Arguments - + * `freq` - 周期名称 (支持字符串或Freq枚举) * `bars` - K线序列 """ - def get_latest_date(self) -> typing.Optional[builtins.str]: + def get_latest_date(self) -> builtins.str | None: r""" 获取最新K线日期 """ - def update(self, bar:RawBar) -> None: + def update(self, bar: RawBar) -> None: r""" 从Python RawBar对象更新K线数据 - 支持直接自动转换 """ @@ -178,7 +193,7 @@ class BarGenerator: r""" 支持 pickle 序列化 - 使用 __reduce__ 方法 """ - def __setstate__(self, state:typing.Any) -> None: + def __setstate__(self, state: typing.Any) -> None: r""" 支持 pickle 反序列化 """ @@ -254,19 +269,19 @@ class CZSC: r""" 缓存字典(与 czsc 库兼容) """ - def __new__(cls, bars_raw:typing.Sequence[RawBar], max_bi_num:builtins.int=50) -> CZSC: ... + def __new__(cls, bars_raw: typing.Sequence[RawBar], max_bi_num: builtins.int = 50) -> CZSC: ... @staticmethod - def from_dataframe(df_bytes:bytes, freq:Freq, max_bi_num:builtins.int=50) -> CZSC: + def from_dataframe(df_bytes: bytes, freq: Freq, max_bi_num: builtins.int = 50) -> CZSC: r""" 直接从Arrow格式的DataFrame创建CZSC对象,避免中间转换 这是高性能的批量创建接口,适用于大量数据的初始化 - + :param df_bytes: Arrow IPC格式的DataFrame字节数据 :param freq: K线频率 :param max_bi_num: 最大笔数量限制 :return: CZSC对象 """ - def open_in_browser(self, _renderer:typing.Optional[builtins.str]=None) -> builtins.str: + def open_in_browser(self, _renderer: builtins.str | None = None) -> builtins.str: r""" 在浏览器中打开(与 czsc 库兼容) """ @@ -278,7 +293,7 @@ class CZSC: r""" 转换为 Plotly 格式(与 czsc 库兼容) """ - def update(self, bar:RawBar) -> None: + def update(self, bar: RawBar) -> None: r""" 更新K线数据 """ @@ -286,7 +301,7 @@ class CZSC: def __reduce__(self) -> typing.Any: r""" Pickle support — `__reduce__` returns ``(CZSC, (fixed_point_bars, max_bi_num))``. - + `update_bar` drains older bars whose dt is below the current first-BI's start (see `bars_raw.drain` block above), so a freshly-constructed CZSC's `bars_raw` may still differ from the @@ -333,17 +348,17 @@ class CzscSignals: 返回基准周期字符串 """ @property - def end_dt(self) -> typing.Optional[typing.Any]: + def end_dt(self) -> typing.Any | None: r""" 返回最新时间,作为 pandas Timestamp """ @property - def bid(self) -> typing.Optional[builtins.int]: + def bid(self) -> builtins.int | None: r""" 返回当前 bar id """ @property - def latest_price(self) -> typing.Optional[builtins.float]: + def latest_price(self) -> builtins.float | None: r""" 返回最新价格 """ @@ -352,8 +367,8 @@ class CzscSignals: r""" 返回原始信号配置 """ - def __new__(cls, bg:BarGenerator, signals_config:list) -> CzscSignals: ... - def update_signals(self, bar:RawBar) -> None: + def __new__(cls, bg: BarGenerator, signals_config: list) -> CzscSignals: ... + def update_signals(self, bar: RawBar) -> None: r""" 更新信号 """ @@ -404,17 +419,17 @@ class CzscTrader: 返回基准周期字符串 """ @property - def end_dt(self) -> typing.Optional[typing.Any]: + def end_dt(self) -> typing.Any | None: r""" 返回最新时间,作为 pandas Timestamp """ @property - def bid(self) -> typing.Optional[builtins.int]: + def bid(self) -> builtins.int | None: r""" 返回当前 bar id """ @property - def latest_price(self) -> typing.Optional[builtins.float]: + def latest_price(self) -> builtins.float | None: r""" 返回最新价格 """ @@ -433,24 +448,26 @@ class CzscTrader: r""" 返回是否有仓位发生变化 """ - def __new__(cls, bg:BarGenerator, positions:list, signals_config:list, ensemble_method:builtins.str='mean') -> CzscTrader: ... - def update(self, bar:RawBar) -> None: + def __new__( + cls, bg: BarGenerator, positions: list, signals_config: list, ensemble_method: builtins.str = "mean" + ) -> CzscTrader: ... + def update(self, bar: RawBar) -> None: r""" 更新信号和仓位 """ - def on_bar(self, bar:RawBar) -> None: + def on_bar(self, bar: RawBar) -> None: r""" 更新信号和仓位(同 update) """ - def on_sig(self, sig:dict) -> None: + def on_sig(self, sig: dict) -> None: r""" 基于信号字典更新仓位 """ - def get_ensemble_pos(self, method:typing.Optional[builtins.str]=None) -> builtins.float: + def get_ensemble_pos(self, method: builtins.str | None = None) -> builtins.float: r""" 获取集成后的仓位值 """ - def get_position(self, name:builtins.str) -> typing.Optional[Position]: + def get_position(self, name: builtins.str) -> Position | None: r""" 根据名称获取仓位 """ @@ -458,7 +475,7 @@ class CzscTrader: r""" 获取当前信号字典 """ - def update_signals(self, bar:RawBar) -> None: + def update_signals(self, bar: RawBar) -> None: r""" 仅更新信号(不更新仓位) """ @@ -493,16 +510,23 @@ class Event: r""" 获取SHA256哈希 """ - def __new__(cls, operate:Operate, signals_all:typing.Sequence[Signal]=[], signals_any:typing.Sequence[Signal]=[], signals_not:typing.Sequence[Signal]=[], name:builtins.str='') -> Event: ... + def __new__( + cls, + operate: Operate, + signals_all: typing.Sequence[Signal] = [], + signals_any: typing.Sequence[Signal] = [], + signals_not: typing.Sequence[Signal] = [], + name: builtins.str = "", + ) -> Event: ... @classmethod - def from_dict(cls, _cls:type, dict:dict) -> Event: ... + def from_dict(cls, _cls: type, dict: dict) -> Event: ... @classmethod - def from_json(cls, _cls:type, json_str:builtins.str) -> Event: ... + def from_json(cls, _cls: type, json_str: builtins.str) -> Event: ... def compute_sha8(self) -> builtins.str: r""" 计算SHA8哈希值 """ - def is_match(self, signals:typing.Any) -> builtins.bool: + def is_match(self, signals: typing.Any) -> builtins.bool: r""" 判断事件是否匹配信号集合,返回是否匹配 支持多种参数类型:Dict[str, str] 或 Dict[str, Signal] 或 Vec @@ -518,7 +542,7 @@ class Event: 导出为字典 """ @classmethod - def load(cls, _cls:type, data:dict) -> Event: + def load(cls, _cls: type, data: dict) -> Event: r""" 从字典加载 """ @@ -587,9 +611,18 @@ class FX: r""" 直接支持 __dict__ 属性,让 pandas DataFrame() 能正确识别对象 """ - def __new__(cls, symbol:builtins.str, dt:typing.Any, mark:Mark, high:builtins.float, low:builtins.float, fx:builtins.float, elements:typing.Sequence[NewBar]) -> FX: ... + def __new__( + cls, + symbol: builtins.str, + dt: typing.Any, + mark: Mark, + high: builtins.float, + low: builtins.float, + fx: builtins.float, + elements: typing.Sequence[NewBar], + ) -> FX: ... def __repr__(self) -> builtins.str: ... - def __richcmp__(self, other:FX, op:int) -> builtins.bool: ... + def __richcmp__(self, other: FX, op: int) -> builtins.bool: ... class FakeBI: r""" @@ -624,7 +657,7 @@ class LiteBar: def dt(self) -> builtins.float: ... @property def price(self) -> builtins.float: ... - def __new__(cls, id:builtins.int, dt:builtins.float, price:builtins.float) -> LiteBar: ... + def __new__(cls, id: builtins.int, dt: builtins.float, price: builtins.float) -> LiteBar: ... def __repr__(self) -> builtins.str: ... class NewBar: @@ -660,14 +693,28 @@ class NewBar: r""" 获取构成NewBar的原始K线列表(与elements相同,为兼容czsc库) """ - def __new__(cls, symbol:builtins.str, dt:typing.Any, freq:Freq, open:builtins.float, close:builtins.float, high:builtins.float, low:builtins.float, vol:builtins.float, amount:builtins.float, id:builtins.int=0, elements:typing.Optional[typing.Sequence[RawBar]]=None) -> NewBar: ... + def __new__( + cls, + symbol: builtins.str, + dt: typing.Any, + freq: Freq, + open: builtins.float, + close: builtins.float, + high: builtins.float, + low: builtins.float, + vol: builtins.float, + amount: builtins.float, + id: builtins.int = 0, + elements: typing.Sequence[RawBar] | None = None, + ) -> NewBar: ... def __repr__(self) -> builtins.str: ... - def __richcmp__(self, other:NewBar, op:int) -> builtins.bool: ... + def __richcmp__(self, other: NewBar, op: int) -> builtins.bool: ... class Operate: r""" Python可见的Operate包装器 """ + HL: Operate HS: Operate HO: Operate @@ -681,26 +728,26 @@ class Operate: 兼容性属性:返回操作类型的中文字符串值 """ @classmethod - def hl(cls, _cls:type) -> Operate: ... + def hl(cls, _cls: type) -> Operate: ... @classmethod - def hs(cls, _cls:type) -> Operate: ... + def hs(cls, _cls: type) -> Operate: ... @classmethod - def ho(cls, _cls:type) -> Operate: ... + def ho(cls, _cls: type) -> Operate: ... @classmethod - def lo(cls, _cls:type) -> Operate: ... + def lo(cls, _cls: type) -> Operate: ... @classmethod - def le(cls, _cls:type) -> Operate: ... + def le(cls, _cls: type) -> Operate: ... @classmethod - def so(cls, _cls:type) -> Operate: ... + def so(cls, _cls: type) -> Operate: ... @classmethod - def se(cls, _cls:type) -> Operate: ... + def se(cls, _cls: type) -> Operate: ... @classmethod - def from_str_py(cls, _cls:type, s:builtins.str) -> Operate: ... + def from_str_py(cls, _cls: type, s: builtins.str) -> Operate: ... @classmethod - def from_str(cls, _cls:type, s:builtins.str) -> Operate: ... + def from_str(cls, _cls: type, s: builtins.str) -> Operate: ... def __str__(self) -> builtins.str: ... def __repr__(self) -> builtins.str: ... - def __eq__(self, other:Operate) -> builtins.bool: ... + def __eq__(self, other: Operate) -> builtins.bool: ... def __hash__(self) -> builtins.int: ... def __reduce__(self) -> typing.Any: r""" @@ -712,7 +759,7 @@ class ParsedSignalDoc: Python可见的ParsedSignalDoc包装器 """ @property - def param_template(self) -> typing.Optional[builtins.str]: ... + def param_template(self) -> builtins.str | None: ... @property def signals(self) -> builtins.list[Signal]: ... def __repr__(self) -> builtins.str: ... @@ -722,19 +769,19 @@ class Pos: Python可见的Pos枚举包装器 """ @classmethod - def short(cls, _cls:type) -> Pos: ... + def short(cls, _cls: type) -> Pos: ... @classmethod - def flat(cls, _cls:type) -> Pos: ... + def flat(cls, _cls: type) -> Pos: ... @classmethod - def long(cls, _cls:type) -> Pos: ... + def long(cls, _cls: type) -> Pos: ... def __str__(self) -> builtins.str: ... def __repr__(self) -> builtins.str: ... - def __eq__(self, other:Pos) -> builtins.bool: ... - def __add__(self, other:Pos) -> builtins.float: + def __eq__(self, other: Pos) -> builtins.bool: ... + def __add__(self, other: Pos) -> builtins.float: r""" 加法运算,用于numpy.mean等数学操作 """ - def __radd__(self, other:builtins.float) -> builtins.float: + def __radd__(self, other: builtins.float) -> builtins.float: r""" 右加法运算 """ @@ -746,19 +793,19 @@ class Pos: r""" 整数转换 """ - def __lt__(self, other:Pos) -> builtins.bool: + def __lt__(self, other: Pos) -> builtins.bool: r""" 比较运算符 - 小于 """ - def __le__(self, other:Pos) -> builtins.bool: + def __le__(self, other: Pos) -> builtins.bool: r""" 比较运算符 - 小于等于 """ - def __gt__(self, other:Pos) -> builtins.bool: + def __gt__(self, other: Pos) -> builtins.bool: r""" 比较运算符 - 大于 """ - def __ge__(self, other:Pos) -> builtins.bool: + def __ge__(self, other: Pos) -> builtins.bool: r""" 比较运算符 - 大于等于 """ @@ -788,7 +835,7 @@ class Position: @property def pos_changed(self) -> builtins.bool: ... @property - def end_dt(self) -> typing.Optional[builtins.float]: + def end_dt(self) -> builtins.float | None: r""" 获取最新信号时间 """ @@ -811,12 +858,22 @@ class Position: def unique_signals(self) -> builtins.list[builtins.str]: ... @property def events(self) -> builtins.list[Event]: ... - def __new__(cls, symbol:builtins.str, opens:typing.Sequence[Event], exits:typing.Sequence[Event]=[], interval:builtins.int=0, timeout:builtins.int=1000, stop_loss:builtins.float=1000.0, t0:builtins.bool=False, name:typing.Optional[builtins.str]=None) -> Position: ... + def __new__( + cls, + symbol: builtins.str, + opens: typing.Sequence[Event], + exits: typing.Sequence[Event] = [], + interval: builtins.int = 0, + timeout: builtins.int = 1000, + stop_loss: builtins.float = 1000.0, + t0: builtins.bool = False, + name: builtins.str | None = None, + ) -> Position: ... @classmethod - def load_from_file(cls, _cls:type, path:builtins.str) -> Position: ... + def load_from_file(cls, _cls: type, path: builtins.str) -> Position: ... @classmethod - def from_json(cls, _cls:type, json_str:builtins.str) -> Position: ... - def save(self, path:builtins.str) -> None: + def from_json(cls, _cls: type, json_str: builtins.str) -> Position: ... + def save(self, path: builtins.str) -> None: r""" 保存到文件 """ @@ -828,7 +885,7 @@ class Position: r""" 获取所有相关事件 """ - def update(self, arg1:typing.Any, arg2:typing.Optional[typing.Any]=None) -> None: + def update(self, arg1: typing.Any, arg2: typing.Any | None = None) -> None: r""" 更新仓位状态(兼容单参数调用) """ @@ -836,12 +893,12 @@ class Position: r""" 支持 pickle 序列化 - 使用 __reduce__ 方法 """ - def dump(self, with_data:builtins.bool=True) -> typing.Any: + def dump(self, with_data: builtins.bool = True) -> typing.Any: r""" 导出Position数据为Python字典 """ @classmethod - def load(cls, _cls:type, data:typing.Any) -> Position: + def load(cls, _cls: type, data: typing.Any) -> Position: r""" 从字典数据加载Position """ @@ -893,7 +950,19 @@ class RawBar: r""" 直接支持 __dict__ 属性,让 pandas DataFrame() 能正确识别对象 """ - def __new__(cls, symbol:builtins.str, dt:typing.Any, freq:Freq, open:builtins.float, close:builtins.float, high:builtins.float, low:builtins.float, vol:builtins.float, amount:builtins.float, id:builtins.int=0) -> RawBar: ... + def __new__( + cls, + symbol: builtins.str, + dt: typing.Any, + freq: Freq, + open: builtins.float, + close: builtins.float, + high: builtins.float, + low: builtins.float, + vol: builtins.float, + amount: builtins.float, + id: builtins.int = 0, + ) -> RawBar: ... def _asdict(self) -> typing.Any: r""" 让对象表现得像记录,pandas DataFrame构造器会调用这个 @@ -906,12 +975,12 @@ class RawBar: r""" 支持pickle序列化 """ - def __deepcopy__(self, _memo:typing.Any) -> RawBar: + def __deepcopy__(self, _memo: typing.Any) -> RawBar: r""" 支持深拷贝 """ def __repr__(self) -> builtins.str: ... - def __richcmp__(self, other:RawBar, op:int) -> builtins.bool: ... + def __richcmp__(self, other: RawBar, op: int) -> builtins.bool: ... class Signal: r""" @@ -938,22 +1007,35 @@ class Signal: """ @property def k2(self) -> builtins.str: ... - def __new__(cls, *args, signal:typing.Optional[builtins.str]=None, key:typing.Optional[builtins.str]=None, value:typing.Optional[builtins.str]=None, k1:typing.Optional[builtins.str]=None, k2:typing.Optional[builtins.str]=None, k3:typing.Optional[builtins.str]=None, v1:typing.Optional[builtins.str]=None, v2:typing.Optional[builtins.str]=None, v3:typing.Optional[builtins.str]=None, score:typing.Optional[builtins.int]=None) -> Signal: ... + def __new__( + cls, + *args, + signal: builtins.str | None = None, + key: builtins.str | None = None, + value: builtins.str | None = None, + k1: builtins.str | None = None, + k2: builtins.str | None = None, + k3: builtins.str | None = None, + v1: builtins.str | None = None, + v2: builtins.str | None = None, + v3: builtins.str | None = None, + score: builtins.int | None = None, + ) -> Signal: ... @classmethod - def from_string(cls, _cls:type, s:builtins.str) -> Signal: ... + def from_string(cls, _cls: type, s: builtins.str) -> Signal: ... def to_json(self) -> builtins.str: r""" 添加to_json方法以匹配Python版本 """ def __str__(self) -> builtins.str: ... def __repr__(self) -> builtins.str: ... - def __eq__(self, other:Signal) -> builtins.bool: ... + def __eq__(self, other: Signal) -> builtins.bool: ... def __hash__(self) -> builtins.int: ... - def matches(self, other:Signal) -> builtins.bool: + def matches(self, other: Signal) -> builtins.bool: r""" 检查Signal是否匹配另一个Signal """ - def is_match(self, signals_dict:typing.Mapping[builtins.str, builtins.str]) -> builtins.bool: + def is_match(self, signals_dict: typing.Mapping[builtins.str, builtins.str]) -> builtins.bool: r""" 判断信号是否与信号字典中的值匹配(Python版本is_match逻辑) """ @@ -1015,7 +1097,7 @@ class ZS: """ @property def cache(self) -> dict: ... - def __new__(cls, bis:typing.Sequence[BI]) -> ZS: ... + def __new__(cls, bis: typing.Sequence[BI]) -> ZS: ... def is_valid(self) -> builtins.bool: r""" 中枢是否有效 @@ -1025,6 +1107,7 @@ class Direction(Enum): r""" 方向 """ + Up = ... r""" 向上 @@ -1039,7 +1122,7 @@ class Direction(Enum): r""" 获取方向的字符串值(与 czsc 库兼容) """ - def __deepcopy__(self, _memo:typing.Any) -> Direction: + def __deepcopy__(self, _memo: typing.Any) -> Direction: r""" 支持深拷贝 """ @@ -1047,15 +1130,16 @@ class Direction(Enum): r""" 支持pickle序列化 """ - def __new__(cls, value:builtins.str) -> Direction: ... + def __new__(cls, value: builtins.str) -> Direction: ... def __str__(self) -> builtins.str: ... def __repr__(self) -> builtins.str: ... - def __richcmp__(self, other:typing.Any, op:int) -> builtins.bool: ... + def __richcmp__(self, other: typing.Any, op: int) -> builtins.bool: ... class Freq(Enum): r""" 时间周期 """ + Tick = ... r""" 逐笔 @@ -1144,7 +1228,7 @@ class Freq(Enum): __members__: typing.Any @property def value(self) -> builtins.str: ... - def __deepcopy__(self, _memo:typing.Any) -> Freq: + def __deepcopy__(self, _memo: typing.Any) -> Freq: r""" 支持深拷贝 """ @@ -1152,15 +1236,16 @@ class Freq(Enum): r""" 支持pickle序列化 """ - def __new__(cls, value:builtins.str) -> Freq: ... + def __new__(cls, value: builtins.str) -> Freq: ... def __str__(self) -> builtins.str: ... def __repr__(self) -> builtins.str: ... - def __richcmp__(self, other:typing.Any, op:int) -> builtins.bool: ... + def __richcmp__(self, other: typing.Any, op: int) -> builtins.bool: ... class Mark(Enum): r""" 分型类型 """ + D = ... r""" 底分型 @@ -1177,7 +1262,7 @@ class Mark(Enum): """ def __str__(self) -> builtins.str: ... def __repr__(self) -> builtins.str: ... - def __richcmp__(self, other:typing.Any, op:int) -> builtins.bool: ... + def __richcmp__(self, other: typing.Any, op: int) -> builtins.bool: ... class Market(Enum): AShare = ... @@ -1193,43 +1278,44 @@ class Market(Enum): 默认 """ - def __new__(cls, ob:typing.Any) -> Market: ... + def __new__(cls, ob: typing.Any) -> Market: ... -def chip_distribution_triangle(data:numpy.typing.NDArray[numpy.float64], price_step:builtins.float, decay_factor:builtins.float) -> tuple[numpy.typing.NDArray[numpy.float64], numpy.typing.NDArray[numpy.float64]]: +def chip_distribution_triangle( + data: numpy.typing.NDArray[numpy.float64], price_step: builtins.float, decay_factor: builtins.float +) -> tuple[numpy.typing.NDArray[numpy.float64], numpy.typing.NDArray[numpy.float64]]: r""" 计算筹码分布(三角形分布 + 筹码沉淀机制) - + 此函数用于估算基于历史K线的筹码分布情况,结合三角形分布模型和筹码沉淀(衰减)机制。 - + # Python 接口说明 - + 输入一个二维 numpy 数组,形状为 (N, 3),每一行对应一根K线,列顺序为: `[high, low, vol]`,类型必须为 `float64`。 - + 示例: ```python columns = ['high', 'low', 'vol'] arr2 = df[columns].to_numpy(dtype=np.float64) price_centers, chip_dist = chip_distribution_triangle(arr2, 0.01, 0.9) ``` - + # 参数 - + - `data`: 二维数组,形状为 (N, 3),分别是每根K线的最高价、最低价和成交量。 - `price_step`: 分档间隔(如0.01表示以0.01为单位划分价格区间)。 - `decay_factor`: 筹码衰减因子,表示前一根K线上的筹码有多少比例沉淀保留到下一根K线上,范围为(0, 1),例如0.98表示保留98%。 - + # 返回值 - + 返回一个元组 `(price_centers, chip_distribution)`: - `price_centers`: 一维数组,表示价格分布区间的中心价位。 - `chip_distribution`: 一维数组,对应每个价格中心的筹码强度(权重/密度)。 - + 返回的两个数组长度相同,可用于绘制筹码分布图或进一步分析。 """ -def parse_signal_doc(doc:builtins.str) -> ParsedSignalDoc: +def parse_signal_doc(doc: builtins.str) -> ParsedSignalDoc: r""" 解析文档中的Signal信息 """ - diff --git a/czsc/_utils/_df_convert.py b/czsc/_utils/_df_convert.py index 1687a9a85..3c885ef2a 100644 --- a/czsc/_utils/_df_convert.py +++ b/czsc/_utils/_df_convert.py @@ -12,12 +12,13 @@ - PyArrow 与 Pandas 的版本组合需保持一致,否则可能在 Schema 推断时报错 """ + import pandas as pd import pyarrow as pa import pyarrow.ipc as ipc -from typing import Union -def pandas_to_arrow_bytes(df: Union[pd.DataFrame, pd.Series]) -> bytes: + +def pandas_to_arrow_bytes(df: pd.DataFrame | pd.Series) -> bytes: """ 将 Pandas DataFrame/Series 序列化为 Arrow IPC 文件格式的字节流 diff --git a/czsc/connectors/cooperation.py b/czsc/connectors/cooperation.py index 6d969440c..3e1163017 100644 --- a/czsc/connectors/cooperation.py +++ b/czsc/connectors/cooperation.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 作者: zengbin93 邮箱: zeng_bin8888@163.com @@ -32,17 +31,20 @@ - 数据接口由内部服务 ``http://zbczsc.com:9106`` 提供,外网用户无权访问; - 缓存目录默认位于 ``~/.quant_data_cache``,可以通过环境变量 ``CZSC_CACHE_PATH`` 覆盖。 """ + import os import time -import czsc -import requests +from datetime import datetime +from pathlib import Path +from typing import Any + import loguru import pandas as pd +import requests from tqdm import tqdm -from pathlib import Path -from datetime import datetime -from czsc import RawBar, Freq -from typing import Dict, List, Any + +import czsc +from czsc import Freq # 首次使用需要打开一个 Python 终端按如下方式设置 token,或者直接在环境变量中设置 CZSC_TOKEN # 示例:czsc.set_url_token(token='your token', url='http://zbczsc.com:9106') @@ -283,7 +285,7 @@ def get_raw_bars(symbol, freq, sdt, edt, fq="前复权", **kwargs): if "SH" in symbol or "SZ" in symbol: # 复权类型映射:本地中文枚举到底层接口缩写 fq_map = {"前复权": "qfq", "后复权": "hfq", "不复权": None} - adj = fq_map.get(fq, None) + adj = fq_map.get(fq) # 标的代码格式为 "code#asset",asset 取首字母即可(s/e/i 等) code, asset = symbol.split("#") @@ -717,7 +719,7 @@ def set_token(self, token: str): self._setup_headers() self.logger.info("访问令牌已更新") - def _make_request(self, method: str, endpoint: str, data: Dict = None) -> Dict: + def _make_request(self, method: str, endpoint: str, data: dict = None) -> dict: """统一的 HTTP 请求底层方法(内部方法)。 负责拼接 URL、根据 method 选择 GET/POST 调用、统一异常处理与日志输出。 @@ -750,7 +752,7 @@ def _make_request(self, method: str, endpoint: str, data: Dict = None) -> Dict: self.logger.error(f"响应解析失败: {e}") raise - def get_all_strategy_metadata(self) -> List[Dict]: + def get_all_strategy_metadata(self) -> list[dict]: """获取所有策略元数据。 :return: list[dict], 策略元数据列表;接口失败时返回空列表 @@ -872,7 +874,7 @@ def delete_strategy_meta(self, strategy_name: str) -> bool: self.logger.error(f"删除策略元数据失败: {result.get('msg', '未知错误')}") return False - def get_all_strategy_latest_weights(self) -> List[Dict]: + def get_all_strategy_latest_weights(self) -> list[dict]: """获取所有策略的最新持仓权重快照。 :return: list[dict], 策略权重数据列表;接口失败时返回空列表 @@ -887,7 +889,7 @@ def get_all_strategy_latest_weights(self) -> List[Dict]: self.logger.error(f"获取最新权重数据失败: {result.get('msg', '未知错误')}") return [] - def query_strategy_weight(self, strategy: str, sdt: str = "", edt: str = "", symbols: List[str] = None) -> Dict: + def query_strategy_weight(self, strategy: str, sdt: str = "", edt: str = "", symbols: list[str] = None) -> dict: """查询单个策略的持仓权重。 :param strategy: str, 策略名称 @@ -926,7 +928,7 @@ def delete_strategy(self, strategy: str) -> bool: self.logger.error(f"删除策略失败: {result.get('msg', '未知错误')}") return False - def clear_cache(self, tokens: List[str] = None, roles: List[int] = None) -> bool: + def clear_cache(self, tokens: list[str] = None, roles: list[int] = None) -> bool: """清除服务端的接口缓存。 :param tokens: list[str], 可选, 需要清除的 token 列表 @@ -954,7 +956,7 @@ def upload_strategy_weights( author: str, outsample_sdt: str, upload_token: str = None, - ) -> Dict: + ) -> dict: """上传策略权重数据。 与模块级 ``upload_strategy`` 函数功能一致,区别在于参数以独立形参的形式暴露, @@ -970,9 +972,10 @@ def upload_strategy_weights( :return: dict, 上传接口返回的结果 :raises requests.exceptions.RequestException: 网络异常时抛出 """ - import pandas as pd import os + import pandas as pd + # 数据预处理:拷贝以避免污染外部数据 df_copy = df.copy() df_copy["dt"] = pd.to_datetime(df_copy["dt"]) diff --git a/czsc/connectors/jq_connector.py b/czsc/connectors/jq_connector.py index 86425a462..521541a6d 100644 --- a/czsc/connectors/jq_connector.py +++ b/czsc/connectors/jq_connector.py @@ -1,4 +1,3 @@ -# coding: utf-8 """ 模块说明: 聚宽(JQData)HTTP 数据接口的轻量封装。 @@ -27,18 +26,19 @@ - 单次 K 线请求最大 5000 条,超过会触发 warning 并被服务端截断; - 区间查询超过 1000 个交易日时同样可能失败,需要自行分段。 """ + +import json import os import pickle -import json -import requests import warnings from collections import OrderedDict -import pandas as pd from datetime import datetime, timedelta -from typing import List from urllib.parse import quote -from czsc import RawBar, Freq, BarGenerator, freq_end_time +import pandas as pd +import requests + +from czsc import BarGenerator, Freq, RawBar, freq_end_time # CZSC 内部周期字符串到聚宽 unit 字符串的映射 freq_cn2jq = { @@ -288,7 +288,7 @@ def get_all_securities(code, date=None) -> pd.DataFrame: def get_kline( symbol: str, end_date: [datetime, str], freq: str, start_date: [datetime, str] = None, count=None, fq: bool = True -) -> List[RawBar]: +) -> list[RawBar]: """获取 K 线数据并转换为 ``RawBar`` 列表。 支持两种调用模式: @@ -391,7 +391,7 @@ def get_kline( def get_kline_period( symbol: str, start_date: [datetime, str], end_date: [datetime, str], freq: str, fq=True -) -> List[RawBar]: +) -> list[RawBar]: """获取指定时间段的行情数据(仅区间模式,固定使用 ``get_price_period``)。 与 ``get_kline`` 的区别:本函数强制使用 start_date + end_date,对超长区间会发出告警。 @@ -461,7 +461,7 @@ def get_kline_period( return bars -def get_init_bg(symbol: str, end_dt: [str, datetime], base_freq: str, freqs: List[str], max_count=1000, fq=True): +def get_init_bg(symbol: str, end_dt: [str, datetime], base_freq: str, freqs: list[str], max_count=1000, fq=True): """获取指定标的的初始化 BarGenerator 以及待重放数据。 用于实时回放/回测的启动阶段: @@ -577,7 +577,7 @@ def get_share_basic(symbol): :return: collections.OrderedDict, 基础面汇总信息 """ # 公司基础信息:股票名称、所属行业、地域、主营业务等 - basic_info = run_query(table="finance.STK_COMPANY_INFO", conditions="code#=#{}".format(symbol), count=1) + basic_info = run_query(table="finance.STK_COMPANY_INFO", conditions=f"code#=#{symbol}", count=1) basic_info = basic_info.iloc[0].to_dict() f10 = OrderedDict() @@ -605,27 +605,27 @@ def get_share_basic(symbol): for year in ["2017", "2018", "2019", "2020"]: indicator = get_fundamental(table="indicator", symbol=symbol, date=year) # indicator.get(key, 0) 可能返回 None 或空字符串,因此再做一次真值判断后再转 float - f10["{}EPS".format(year)] = float(indicator.get("eps", 0)) if indicator.get("eps", 0) else 0 - f10["{}ROA".format(year)] = float(indicator.get("roa", 0)) if indicator.get("roa", 0) else 0 - f10["{}ROE".format(year)] = float(indicator.get("roe", 0)) if indicator.get("roe", 0) else 0 - f10["{}销售净利率(%)".format(year)] = ( + f10[f"{year}EPS"] = float(indicator.get("eps", 0)) if indicator.get("eps", 0) else 0 + f10[f"{year}ROA"] = float(indicator.get("roa", 0)) if indicator.get("roa", 0) else 0 + f10[f"{year}ROE"] = float(indicator.get("roe", 0)) if indicator.get("roe", 0) else 0 + f10[f"{year}销售净利率(%)"] = ( float(indicator.get("net_profit_margin", 0)) if indicator.get("net_profit_margin", 0) else 0 ) - f10["{}销售毛利率(%)".format(year)] = ( + f10[f"{year}销售毛利率(%)"] = ( float(indicator.get("gross_profit_margin", 0)) if indicator.get("gross_profit_margin", 0) else 0 ) - f10["{}营业收入同比增长率(%)".format(year)] = ( + f10[f"{year}营业收入同比增长率(%)"] = ( float(indicator.get("inc_revenue_year_on_year", 0)) if indicator.get("inc_revenue_year_on_year", 0) else 0 ) - f10["{}营业收入环比增长率(%)".format(year)] = ( + f10[f"{year}营业收入环比增长率(%)"] = ( float(indicator.get("inc_revenue_annual", 0)) if indicator.get("inc_revenue_annual", 0) else 0 ) - f10["{}营业利润同比增长率(%)".format(year)] = ( + f10[f"{year}营业利润同比增长率(%)"] = ( float(indicator.get("inc_operation_profit_year_on_year", 0)) if indicator.get("inc_operation_profit_year_on_year", 0) else 0 ) - f10["{}经营活动产生的现金流量净额/营业收入(%)".format(year)] = ( + f10[f"{year}经营活动产生的现金流量净额/营业收入(%)"] = ( float(indicator.get("ocf_to_revenue", 0)) if indicator.get("ocf_to_revenue", 0) else 0 ) @@ -633,7 +633,7 @@ def get_share_basic(symbol): msg = "{}({})@{}\n".format(f10["股票代码"], f10["股票名称"], f10["地域"]) msg += "\n{}\n".format("*" * 30) for k in ["行业", "主营", "PE_TTM", "PE", "PB", "总市值(亿)", "流通市值(亿)", "流通比(%)", "同花顺F10"]: - msg += "{}:{}\n".format(k, f10[k]) + msg += f"{k}:{f10[k]}\n" msg += "\n{}\n".format("*" * 30) cols = [ @@ -650,7 +650,7 @@ def get_share_basic(symbol): for k in cols: # 把 4 年同一指标横向拼接,便于一眼看出趋势 msg += k + ":{} | {} | {} | {}\n".format( - *[f10["{}{}".format(year, k)] for year in ["2017", "2018", "2019", "2020"]] + *[f10[f"{year}{k}"] for year in ["2017", "2018", "2019", "2020"]] ) f10["msg"] = msg diff --git a/czsc/models.py b/czsc/models.py index 82668637c..bb1edb8f6 100644 --- a/czsc/models.py +++ b/czsc/models.py @@ -14,7 +14,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Dict, Optional, TypedDict +from typing import Any, TypedDict from czsc._utils._df_convert import arrow_bytes_to_pd_df @@ -74,13 +74,13 @@ class ResearchResult: holds_path - 持仓表对应的本地路径(可选) """ - meta: Dict[str, Any] + meta: dict[str, Any] signals_arrow: bytes pairs_arrow: bytes holds_arrow: bytes - signals_path: Optional[str] = None - pairs_path: Optional[str] = None - holds_path: Optional[str] = None + signals_path: str | None = None + pairs_path: str | None = None + holds_path: str | None = None def signals_df(self): """将 ``signals_arrow`` 反序列化为 Pandas DataFrame(按需调用,避免无谓开销)""" @@ -104,6 +104,7 @@ class ReplayResult(ResearchResult): 代码与日志中体现语义差异("复现某次回放"对应 ReplayResult,"批量研究 多组参数"对应 ResearchResult)。 """ + pass @@ -116,4 +117,5 @@ class OptimizeResult: (成功概要 / 警告 / 错误信息)。后续若需要扩展更结构化的优化指标 (如最优参数、得分排序等),可在此追加字段,保持向后兼容。 """ + message: str diff --git a/czsc/research.py b/czsc/research.py index 7d50bd159..214866604 100644 --- a/czsc/research.py +++ b/czsc/research.py @@ -28,19 +28,29 @@ position_dump_to_runtime, signal_config_to_runtime, ) + # 直接调用 PyO3 暴露的 Rust 实现(带下划线别名表示"不要在调用方代码中再展开") from czsc._native import ( build_exit_optim_positions as _build_exit_optim_positions, +) +from czsc._native import ( build_open_optim_positions as _build_open_optim_positions, +) +from czsc._native import ( run_optimize, +) +from czsc._native import ( run_optimize_batch as _run_optimize_batch, +) +from czsc._native import ( run_replay as _run_replay, +) +from czsc._native import ( run_research as _run_research, ) from czsc._utils._df_convert import pandas_to_arrow_bytes from czsc.models import OptimizeResult, ReplayResult, ResearchResult - # 类型别名:bars 入参允许传 DataFrame 或已就绪的 Arrow IPC 字节 # 这种"两可"形式可以让上层在已经持有字节流的场景下省一次序列化 BarsLike = pd.DataFrame | bytes @@ -129,8 +139,7 @@ def run_research( # positions / signals_config 都是用户层格式,需要先归一化为 Rust 运行时期望的紧凑布局 if "positions" in strategy_payload: strategy_payload["positions"] = [ - position_dump_to_runtime(pos) if isinstance(pos, dict) else pos - for pos in strategy_payload["positions"] + position_dump_to_runtime(pos) if isinstance(pos, dict) else pos for pos in strategy_payload["positions"] ] if "signals_config" in strategy_payload: strategy_payload["signals_config"] = [ @@ -180,8 +189,7 @@ def run_replay( strategy_payload = dict(strategy) if "positions" in strategy_payload: strategy_payload["positions"] = [ - position_dump_to_runtime(pos) if isinstance(pos, dict) else pos - for pos in strategy_payload["positions"] + position_dump_to_runtime(pos) if isinstance(pos, dict) else pos for pos in strategy_payload["positions"] ] if "signals_config" in strategy_payload: strategy_payload["signals_config"] = [ diff --git a/czsc/signals/_helpers.py b/czsc/signals/_helpers.py index 966d4921c..984666dc4 100644 --- a/czsc/signals/_helpers.py +++ b/czsc/signals/_helpers.py @@ -36,7 +36,8 @@ from __future__ import annotations -from typing import Any, Callable +from collections.abc import Callable +from typing import Any # 从 Rust 扩展模块按需导入底层接口;使用 ``as _xxx`` 显式标注为内部依赖,避免被外部直接引用 from czsc._native import call_signal as _call_signal diff --git a/czsc/signals/bar.py b/czsc/signals/bar.py index 6d0696d9c..26d8bdfea 100644 --- a/czsc/signals/bar.py +++ b/czsc/signals/bar.py @@ -25,8 +25,14 @@ from czsc.signals._helpers import ( get_signal_template as _get_signal_template, +) +from czsc.signals._helpers import ( list_signals as _list_signals, +) +from czsc.signals._helpers import ( make_signal_callable as _make_signal_callable, +) +from czsc.signals._helpers import ( parse_signal_value as _parse_signal_value, ) diff --git a/czsc/signals/cvolp.py b/czsc/signals/cvolp.py index a9eee4e6c..0b97880a4 100644 --- a/czsc/signals/cvolp.py +++ b/czsc/signals/cvolp.py @@ -25,8 +25,14 @@ from czsc.signals._helpers import ( get_signal_template as _get_signal_template, +) +from czsc.signals._helpers import ( list_signals as _list_signals, +) +from czsc.signals._helpers import ( make_signal_callable as _make_signal_callable, +) +from czsc.signals._helpers import ( parse_signal_value as _parse_signal_value, ) diff --git a/czsc/signals/cxt.py b/czsc/signals/cxt.py index c9e8a327a..5d8aaf87c 100644 --- a/czsc/signals/cxt.py +++ b/czsc/signals/cxt.py @@ -25,8 +25,14 @@ from czsc.signals._helpers import ( get_signal_template as _get_signal_template, +) +from czsc.signals._helpers import ( list_signals as _list_signals, +) +from czsc.signals._helpers import ( make_signal_callable as _make_signal_callable, +) +from czsc.signals._helpers import ( parse_signal_value as _parse_signal_value, ) diff --git a/czsc/signals/obv.py b/czsc/signals/obv.py index c834bc8bb..2fd541608 100644 --- a/czsc/signals/obv.py +++ b/czsc/signals/obv.py @@ -25,8 +25,14 @@ from czsc.signals._helpers import ( get_signal_template as _get_signal_template, +) +from czsc.signals._helpers import ( list_signals as _list_signals, +) +from czsc.signals._helpers import ( make_signal_callable as _make_signal_callable, +) +from czsc.signals._helpers import ( parse_signal_value as _parse_signal_value, ) diff --git a/czsc/signals/pressure.py b/czsc/signals/pressure.py index b4c4a81b4..6880ff069 100644 --- a/czsc/signals/pressure.py +++ b/czsc/signals/pressure.py @@ -27,8 +27,14 @@ from czsc.signals._helpers import ( get_signal_template as _get_signal_template, +) +from czsc.signals._helpers import ( list_signals as _list_signals, +) +from czsc.signals._helpers import ( make_signal_callable as _make_signal_callable, +) +from czsc.signals._helpers import ( parse_signal_value as _parse_signal_value, ) diff --git a/czsc/signals/tas.py b/czsc/signals/tas.py index 5be9efa7a..cdc9bbb54 100644 --- a/czsc/signals/tas.py +++ b/czsc/signals/tas.py @@ -27,8 +27,14 @@ from czsc.signals._helpers import ( get_signal_template as _get_signal_template, +) +from czsc.signals._helpers import ( list_signals as _list_signals, +) +from czsc.signals._helpers import ( make_signal_callable as _make_signal_callable, +) +from czsc.signals._helpers import ( parse_signal_value as _parse_signal_value, ) diff --git a/czsc/signals/vol.py b/czsc/signals/vol.py index ec3f5f656..51f6dfd9d 100644 --- a/czsc/signals/vol.py +++ b/czsc/signals/vol.py @@ -26,8 +26,14 @@ from czsc.signals._helpers import ( get_signal_template as _get_signal_template, +) +from czsc.signals._helpers import ( list_signals as _list_signals, +) +from czsc.signals._helpers import ( make_signal_callable as _make_signal_callable, +) +from czsc.signals._helpers import ( parse_signal_value as _parse_signal_value, ) diff --git a/czsc/strategies.py b/czsc/strategies.py index ece0353d7..f54d4bc39 100644 --- a/czsc/strategies.py +++ b/czsc/strategies.py @@ -25,8 +25,6 @@ from pathlib import Path from typing import Any -import pandas as pd - from czsc._compat import ( bars_to_dataframe, position_dump_to_runtime, @@ -35,9 +33,12 @@ sort_freqs, ) from czsc._native import Position + # 直接调用 Rust 端的派生器(用下划线后缀别名,避免与同名公开 API 混淆) from czsc._native import ( derive_signals_config as _derive_signals_config_impl, +) +from czsc._native import ( derive_signals_freqs as _derive_signals_freqs_impl, ) from czsc.research import run_replay, run_research @@ -216,9 +217,7 @@ def save_positions(self, path): payload.pop("symbol", None) # md5 校验码:基于序列化字符串生成,加载时可校验配置完整性 payload["md5"] = hashlib.md5(str(payload).encode("utf-8")).hexdigest() - (out_dir / f"{payload['name']}.json").write_text( - json.dumps(payload, ensure_ascii=False), encoding="utf-8" - ) + (out_dir / f"{payload['name']}.json").write_text(json.dumps(payload, ensure_ascii=False), encoding="utf-8") def load_positions(self, files, check=True): """ @@ -262,9 +261,7 @@ def _build_runtime_strategy(self, overrides: dict[str, Any]) -> dict[str, Any]: "base_freq": self.base_freq, "signals_module": self.signals_module_name, "signals_config": [signal_config_to_runtime(cfg) for cfg in self.signals_config], - "positions": [ - position_dump_to_runtime(pos.dump(with_data=False)) for pos in self.positions - ], + "positions": [position_dump_to_runtime(pos.dump(with_data=False)) for pos in self.positions], "market": self.kwargs.get("market", "默认"), "bg_max_count": int(self.kwargs.get("bg_max_count", 5000)), # 仅当 sdt 存在时才注入字段,避免显式写 None 触发 Rust 端 schema 错误 @@ -326,6 +323,4 @@ class CzscJsonStrategy(CzscStrategyBase): @property def positions(self): """从 ``files_position`` 加载并返回 Position 列表,受 ``check_position`` 控制是否校验""" - return self.load_positions( - self.kwargs["files_position"], self.kwargs.get("check_position", True) - ) + return self.load_positions(self.kwargs["files_position"], self.kwargs.get("check_position", True)) diff --git a/czsc/svc/backtest.py b/czsc/svc/backtest.py index e6704ef12..801ddaf80 100644 --- a/czsc/svc/backtest.py +++ b/czsc/svc/backtest.py @@ -23,7 +23,6 @@ import pandas as pd import streamlit as st from loguru import logger - from wbt import WeightBacktest diff --git a/czsc/svc/factor.py b/czsc/svc/factor.py index a5af96575..0610ee83c 100644 --- a/czsc/svc/factor.py +++ b/czsc/svc/factor.py @@ -20,9 +20,9 @@ import plotly.express as px import plotly.graph_objects as go import streamlit as st +from wbt import daily_performance from .base import apply_stats_style, generate_component_key -from wbt import daily_performance def show_feature_returns(df, features, ret_col="returns", key=None, **kwargs): diff --git a/czsc/svc/price_analysis.py b/czsc/svc/price_analysis.py index f4cfb914d..7a195efeb 100644 --- a/czsc/svc/price_analysis.py +++ b/czsc/svc/price_analysis.py @@ -20,9 +20,9 @@ import pandas as pd import streamlit as st from loguru import logger +from wbt import WeightBacktest from .base import apply_stats_style -from wbt import WeightBacktest from .returns import show_cumulative_returns diff --git a/czsc/svc/returns.py b/czsc/svc/returns.py index 72d24b97a..59fda6c61 100644 --- a/czsc/svc/returns.py +++ b/czsc/svc/returns.py @@ -20,11 +20,12 @@ import plotly.express as px import plotly.graph_objects as go import streamlit as st - -from .base import apply_stats_style, ensure_datetime_index, generate_component_key from wbt import daily_performance + from czsc import top_drawdowns +from .base import apply_stats_style, ensure_datetime_index, generate_component_key + def show_daily_return(df: pd.DataFrame, key=None, **kwargs): """用 streamlit 展示日收益 diff --git a/czsc/svc/statistics.py b/czsc/svc/statistics.py index ba9083971..2ba8f836f 100644 --- a/czsc/svc/statistics.py +++ b/czsc/svc/statistics.py @@ -21,9 +21,9 @@ import plotly.express as px import streamlit as st from deprecated import deprecated +from wbt import daily_performance from .base import apply_stats_style, ensure_datetime_index, generate_component_key -from wbt import daily_performance def show_splited_daily(df, ret_col, **kwargs): diff --git a/czsc/svc/strategy.py b/czsc/svc/strategy.py index 25d5df36d..551de51c9 100644 --- a/czsc/svc/strategy.py +++ b/czsc/svc/strategy.py @@ -26,9 +26,9 @@ import pandas as pd import plotly.express as px import streamlit as st +from wbt import WeightBacktest from .base import apply_stats_style, generate_component_key -from wbt import WeightBacktest def show_optuna_study(study, key=None, **kwargs): @@ -334,6 +334,7 @@ def show_symbols_bench(df: pd.DataFrame, **kwargs): :return: None """ from wbt import daily_performance + from czsc.eda import cal_yearly_days df = df[["symbol", "dt", "price"]].copy() @@ -386,11 +387,10 @@ def show_quarterly_effect(returns: pd.Series, key=None): :return: None """ import plotly.express as px + from wbt import daily_performance from czsc.eda import cal_yearly_days - from wbt import daily_performance - returns.index = pd.to_datetime(returns.index) yearly_days = cal_yearly_days(returns.index.to_list()) @@ -761,6 +761,7 @@ def show_portfolio(df: pd.DataFrame, portfolio: str, benchmark: str | None = Non :return: None """ from wbt import daily_performance + from czsc.eda import cal_yearly_days if benchmark is not None: diff --git a/czsc/traders/__init__.py b/czsc/traders/__init__.py index b8a586b2a..2e48f0202 100644 --- a/czsc/traders/__init__.py +++ b/czsc/traders/__init__.py @@ -23,6 +23,8 @@ # 直接从 Rust 原生扩展中导入交易体系核心类与帮助函数, # 确保 Python 侧只承担"调用-转发"职责,业务逻辑落在 Rust 端。 +from wbt import WeightBacktest + from czsc._native import ( CzscSignals, CzscTrader, @@ -30,7 +32,6 @@ derive_signals_freqs, generate_czsc_signals, ) -from wbt import WeightBacktest # 兼容老的导入路径:保留 Python 侧的薄封装函数 get_unique_signals。 from czsc.traders.base import get_unique_signals diff --git a/czsc/traders/__init__.pyi b/czsc/traders/__init__.pyi index e82fe848f..ee7fbda18 100644 --- a/czsc/traders/__init__.pyi +++ b/czsc/traders/__init__.pyi @@ -1,3 +1,5 @@ +from wbt import WeightBacktest as WeightBacktest + from czsc._native import ( CzscSignals as CzscSignals, ) @@ -22,6 +24,5 @@ from czsc.traders.sig_parse import ( from czsc.traders.sig_parse import ( get_signals_freqs as get_signals_freqs, ) -from wbt import WeightBacktest as WeightBacktest def __getattr__(name): ... diff --git a/czsc/traders/base.py b/czsc/traders/base.py index 953f9af47..b5171f20f 100644 --- a/czsc/traders/base.py +++ b/czsc/traders/base.py @@ -31,7 +31,6 @@ derive_signals_freqs, generate_czsc_signals, ) - from czsc.traders.sig_parse import get_signals_freqs diff --git a/czsc/traders/optimize.py b/czsc/traders/optimize.py index 73d703318..3611aee12 100644 --- a/czsc/traders/optimize.py +++ b/czsc/traders/optimize.py @@ -26,8 +26,9 @@ import hashlib import json +from collections.abc import Callable from pathlib import Path -from typing import Any, Callable +from typing import Any # 兼容层提供的辅助函数,统一处理 Python <-> Rust 之间的数据转换、 # 序列化、哈希计算和事件归一化等跨语言桥接逻辑。 @@ -278,12 +279,15 @@ def __init__(self, read_bars: Callable, **kwargs): self.signals_module_name = kwargs.get("signals_module_name", "czsc.signals") # base_freq 优先取用户显式配置;否则借助策略类自动推导, # 保证后续读取 K 线和写入 Rust 配置时频率信息一致。 - self.base_freq = kwargs.get("base_freq") or CzscOpenOptimStrategy( - symbol="symbol", - files_position=self.files_position, - candidate_signals=self.candidate_signals, - signals_module_name=self.signals_module_name, - ).base_freq + self.base_freq = ( + kwargs.get("base_freq") + or CzscOpenOptimStrategy( + symbol="symbol", + files_position=self.files_position, + candidate_signals=self.candidate_signals, + signals_module_name=self.signals_module_name, + ).base_freq + ) self.results_root = Path(kwargs["results_path"]) # 用候选信号集合 + 标的列表的字符串拼接做 MD5,截前 8 位作任务哈希; # 相同输入会得到相同的输出目录,便于结果复用与覆盖。 @@ -340,9 +344,7 @@ def _materialize_bars_dir(self): for symbol in self.symbols: bars = _read_bars(self.read_bars, symbol, self.base_freq, bar_sdt, bar_edt) # parquet 不写 index,Rust 端按列名读取,避免歧义。 - bars_to_dataframe(bars, symbol=symbol).to_parquet( - bars_dir / f"{symbol}.parquet", index=False - ) + bars_to_dataframe(bars, symbol=symbol).to_parquet(bars_dir / f"{symbol}.parquet", index=False) return bars_dir def _materialize_position_files(self): @@ -417,17 +419,18 @@ def __init__(self, read_bars: Callable, **kwargs): self.candidate_events = normalize_candidate_events(kwargs["candidate_events"]) self.signals_module_name = kwargs.get("signals_module_name", "czsc.signals") # 与 OpensOptimize 对称:未显式指定 base_freq 时通过策略类反推。 - self.base_freq = kwargs.get("base_freq") or CzscExitOptimStrategy( - symbol="symbol", - files_position=self.files_position, - candidate_events=self.candidate_events, - signals_module_name=self.signals_module_name, - ).base_freq + self.base_freq = ( + kwargs.get("base_freq") + or CzscExitOptimStrategy( + symbol="symbol", + files_position=self.files_position, + candidate_events=self.candidate_events, + signals_module_name=self.signals_module_name, + ).base_freq + ) self.results_root = Path(kwargs["results_path"]) # 候选事件用 JSON 化字符串再做 MD5,避免 dict 不同顺序导致哈希漂移。 - self.task_hash = md5_upper8( - f"{py_repr_json(self.candidate_events)}_{py_repr_list_str(self.symbols)}" - ) + self.task_hash = md5_upper8(f"{py_repr_json(self.candidate_events)}_{py_repr_list_str(self.symbols)}") self.results_path = str(self.results_root / f"{self.task_name}_{self.task_hash}") self.poss_path = str(Path(self.results_path) / "poss") @@ -472,9 +475,7 @@ def _materialize_bars_dir(self): bar_edt = self.kwargs.get("bar_edt", "20220101") for symbol in self.symbols: bars = _read_bars(self.read_bars, symbol, self.base_freq, bar_sdt, bar_edt) - bars_to_dataframe(bars, symbol=symbol).to_parquet( - bars_dir / f"{symbol}.parquet", index=False - ) + bars_to_dataframe(bars, symbol=symbol).to_parquet(bars_dir / f"{symbol}.parquet", index=False) return bars_dir def _materialize_position_files(self): diff --git a/czsc/utils/__init__.py b/czsc/utils/__init__.py index fcb63ff2c..87827ad97 100644 --- a/czsc/utils/__init__.py +++ b/czsc/utils/__init__.py @@ -67,7 +67,6 @@ # 注意:``sig`` 模块依赖 ``czsc`` 顶层包,存在循环导入风险,故此处不预先 re-export, # 调用方需要按需通过 ``from czsc.utils.sig import ...`` 方式直接引用。 - # 交易/重采样相关工具 from .trade import resample_to_daily, risk_free_returns, update_bbars, update_nxb, update_tbars diff --git a/test/compat/test_public_api.py b/test/compat/test_public_api.py index dcb4ac623..911b45669 100644 --- a/test/compat/test_public_api.py +++ b/test/compat/test_public_api.py @@ -57,9 +57,7 @@ def test_top_level_names_importable() -> None: czsc, err = _safe_import("czsc") assert czsc is not None, f"failed to import czsc: {err}" missing = [name for name in snap["top_level"] if not hasattr(czsc, name)] - assert not missing, ( - f"czsc.* missing {len(missing)} required public names: {missing}" - ) + assert not missing, f"czsc.* missing {len(missing)} required public names: {missing}" def test_signal_subpackages_present() -> None: @@ -74,9 +72,7 @@ def test_signal_subpackages_present() -> None: mod, err = _safe_import(f"czsc.signals.{sub}") if mod is None: failures.append(f"czsc.signals.{sub} ({err})") - assert not failures, ( - f"czsc.signals.* missing {len(failures)} required subpackages: {failures}" - ) + assert not failures, f"czsc.signals.* missing {len(failures)} required subpackages: {failures}" def test_traders_namespace_complete() -> None: @@ -85,9 +81,7 @@ def test_traders_namespace_complete() -> None: traders, err = _safe_import("czsc.traders") assert traders is not None, f"failed to import czsc.traders: {err}" missing = [name for name in snap["traders"] if not hasattr(traders, name)] - assert not missing, ( - f"czsc.traders.* missing {len(missing)} required public names: {missing}" - ) + assert not missing, f"czsc.traders.* missing {len(missing)} required public names: {missing}" def test_ta_namespace_complete() -> None: @@ -96,9 +90,7 @@ def test_ta_namespace_complete() -> None: ta, err = _safe_import("czsc.ta") assert ta is not None, f"failed to import czsc.ta: {err}" missing = [name for name in snap["ta"] if not hasattr(ta, name)] - assert not missing, ( - f"czsc.ta.* missing {len(missing)} required public names: {missing}" - ) + assert not missing, f"czsc.ta.* missing {len(missing)} required public names: {missing}" def test_no_legacy_dummy_backtest() -> None: @@ -111,9 +103,7 @@ def test_no_legacy_dummy_backtest() -> None: czsc, err = _safe_import("czsc") assert czsc is not None, f"failed to import czsc: {err}" leftover = [name for name in snap["removed"] if hasattr(czsc, name)] - assert not leftover, ( - f"czsc.* still exposes legacy names that must be removed: {leftover}" - ) + assert not leftover, f"czsc.* still exposes legacy names that must be removed: {leftover}" def test_no_czsc_use_python_branch() -> None: @@ -125,9 +115,7 @@ def test_no_czsc_use_python_branch() -> None: envs, err = _safe_import("czsc.envs") assert envs is not None, f"failed to import czsc.envs: {err}" leftover = [name for name in snap["removed_envs"] if hasattr(envs, name)] - assert not leftover, ( - f"czsc.envs still exposes removed env vars: {leftover}" - ) + assert not leftover, f"czsc.envs still exposes removed env vars: {leftover}" def test_weight_backtest_comes_from_wbt() -> None: diff --git a/test/integration/test_weight_backtest.py b/test/integration/test_weight_backtest.py index 44a4d8d26..5379d2a57 100644 --- a/test/integration/test_weight_backtest.py +++ b/test/integration/test_weight_backtest.py @@ -56,10 +56,7 @@ def test_czsc_attr_is_wbt_attr(attr_name: str) -> None: assert czsc is not None, f"failed to import czsc: {czsc_err}" wbt, wbt_err = _safe_import("wbt") if wbt is None: - pytest.fail( - f"wbt 必须作为硬依赖存在 ({wbt_err});" - f"czsc.{attr_name} 必须从 wbt 重导出" - ) + pytest.fail(f"wbt 必须作为硬依赖存在 ({wbt_err});czsc.{attr_name} 必须从 wbt 重导出") czsc_attr = getattr(czsc, attr_name, None) wbt_attr = getattr(wbt, attr_name, None) @@ -94,7 +91,4 @@ def test_no_residual_rs_czsc_dependency() -> None: if wb is None: pytest.fail("czsc.WeightBacktest 缺失") module = getattr(wb, "__module__", "?") - assert "rs_czsc" not in module, ( - f"czsc.WeightBacktest 仍然通过 {module!r} 路由;" - f"必须替换为 wbt.WeightBacktest" - ) + assert "rs_czsc" not in module, f"czsc.WeightBacktest 仍然通过 {module!r} 路由;必须替换为 wbt.WeightBacktest" diff --git a/test/parity/_signal_defaults.py b/test/parity/_signal_defaults.py index 55ac67805..ba1ba6978 100644 --- a/test/parity/_signal_defaults.py +++ b/test/parity/_signal_defaults.py @@ -26,8 +26,10 @@ "freq1": "日线", # 计数器 / 回看窗口(lookback) "di": "1", - "n": "5", "N": "5", - "m": "20", "M": "20", + "n": "5", + "N": "5", + "m": "20", + "M": "20", "p": "5", "q": "5", "k": "5", @@ -36,23 +38,32 @@ "s": "5", "z": "5", "t": "5", - "w": "15", # >10 才能满足"压力位/支撑位"信号的 assert + "w": "15", # >10 才能满足"压力位/支撑位"信号的 assert "window": "15", "rumi_window": "15", # 成对参数(必须满足 a < b 的约束) - "t1": "5", "t2": "20", - "th1": "5", "th2": "20", "th3": "30", - "tha": "5", "thb": "50", "thc": "500", - "timeperiod1": "5", "timeperiod2": "20", - "min_count": "3", "max_count": "10", + "t1": "5", + "t2": "20", + "th1": "5", + "th2": "20", + "th3": "30", + "tha": "5", + "thb": "50", + "thc": "500", + "timeperiod1": "5", + "timeperiod2": "20", + "min_count": "3", + "max_count": "10", # 单一阈值(RSI / 动量类信号常用) - "th": "50", # 必须落在 30 < th < 300 区间 - "ndev": "2", "nbdev": "2", + "th": "50", # 必须落在 30 < th < 300 区间 + "ndev": "2", + "nbdev": "2", "avg_bp": "5", "bi_init_length": "20", "max_overlap": "3", "num": "5", - "up": "1", "dw": "1", + "up": "1", + "dw": "1", "zf": "5", "tl": "5", "key": "close", @@ -73,8 +84,10 @@ "lp": "20", # 常量类占位符 "mode": "CO", - "K1": "K1", "K2": "K2", - "c1": "K1", "c2": "K2", + "K1": "K1", + "K2": "K2", + "c1": "K1", + "c2": "K2", } diff --git a/test/parity/bench_optimize.py b/test/parity/bench_optimize.py index 315072962..665b67fa9 100644 --- a/test/parity/bench_optimize.py +++ b/test/parity/bench_optimize.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """性能基准测试脚本:对比 ``rs_czsc`` 与迁移后的 ``czsc`` 在 ``OpensOptimize`` / ``ExitsOptimize`` 工作流上的耗时表现。 @@ -22,25 +21,18 @@ from __future__ import annotations import argparse -import hashlib -import importlib -import json -import shutil import statistics import sys import tempfile import time from pathlib import Path -import pandas as pd - # 把 parity 测试目录加入 sys.path,方便复用其中的辅助函数 ROOT = Path(__file__).resolve().parents[2] PARITY_DIR = ROOT / "test" / "parity" sys.path.insert(0, str(PARITY_DIR)) sys.path.insert(0, str(PARITY_DIR / "_compare_optimize")) -from _signal_defaults import render # noqa: E402 # 复用 parity 脚本中的工具函数(数据准备、模块导入、仓位文件落盘等) from compare_optimize_full import ( # noqa: E402 @@ -74,11 +66,16 @@ def time_open(module_name: str, results_root: Path, candidates: list[str]) -> fl files_position = write_beta_positions(czsc_mod, results_root / "base_positions", "000001") oop = OpensOptimize( - symbols=["000001"], files_position=files_position, - task_name="BenchOpen", candidate_signals=candidates, - read_bars=get_raw_bars, results_path=results_root, + symbols=["000001"], + files_position=files_position, + task_name="BenchOpen", + candidate_signals=candidates, + read_bars=get_raw_bars, + results_path=results_root, signals_module_name="czsc.signals", - bar_sdt=bar_sdt, bar_edt=bar_edt, sdt=sdt, + bar_sdt=bar_sdt, + bar_edt=bar_edt, + sdt=sdt, ) t0 = time.perf_counter() oop.execute(n_jobs=1) @@ -106,13 +103,18 @@ def time_exit(module_name: str, results_root: Path, candidate_events: list[dict] files_position = write_beta_positions(czsc_mod, results_root / "base_positions", "000001") eop = ExitsOptimize( - symbols=["000001"], files_position=files_position, - task_name="BenchExit", candidate_events=candidate_events, - read_bars=get_raw_bars, results_path=results_root, + symbols=["000001"], + files_position=files_position, + task_name="BenchExit", + candidate_events=candidate_events, + read_bars=get_raw_bars, + results_path=results_root, signals_module_name="czsc.signals", # 显式指定 base_freq 是为了绕过 czsc 在自动推导时对 strategy.positions 的处理 bug base_freq="5分钟", - bar_sdt=bar_sdt, bar_edt=bar_edt, sdt=sdt, + bar_sdt=bar_sdt, + bar_edt=bar_edt, + sdt=sdt, ) t0 = time.perf_counter() eop.execute(n_jobs=1) @@ -122,9 +124,11 @@ def time_exit(module_name: str, results_root: Path, candidate_events: list[dict] def fmt(times: list[float]) -> str: """把多次试验的耗时列表格式化为 ``mean±stdev (min/max)`` 字符串。""" if len(times) < 2: - return f"{times[0]*1000:.0f}ms" - return (f"{statistics.mean(times)*1000:.0f}±{statistics.stdev(times)*1000:.0f}ms " - f"(min {min(times)*1000:.0f}ms / max {max(times)*1000:.0f}ms)") + return f"{times[0] * 1000:.0f}ms" + return ( + f"{statistics.mean(times) * 1000:.0f}±{statistics.stdev(times) * 1000:.0f}ms " + f"(min {min(times) * 1000:.0f}ms / max {max(times) * 1000:.0f}ms)" + ) def main(): @@ -136,15 +140,13 @@ def main(): # 候选信号 / 候选事件预计算一次即可:parity 测试已经证明两套实现 # 在这部分输入上完全一致,因此可以放心共享。 import czsc as _cz + candidate_signals = all_kline_candidate_signals(_cz) candidate_events = all_kline_candidate_events(_cz) # 把可能存在的 dict 形式的信号统一回退成字符串形式,确保两套实现 # 在最终消费时拿到完全相同的输入(双保险)。 for e in candidate_events: - e["signals_all"] = [ - (s if isinstance(s, str) else f"{s['key']}_{s['value']}") - for s in e["signals_all"] - ] + e["signals_all"] = [(s if isinstance(s, str) else f"{s['key']}_{s['value']}") for s in e["signals_all"]] print(f"trials={args.trials}, candidate_signals={len(candidate_signals)}, candidate_events={len(candidate_events)}") print() @@ -153,8 +155,7 @@ def main(): # 然后再跑 args.trials 次正式试验。 results: dict = {"open": {}, "exit": {}} - for kind, module in [("open", "rs_czsc"), ("open", "czsc"), - ("exit", "rs_czsc"), ("exit", "czsc")]: + for kind, module in [("open", "rs_czsc"), ("open", "czsc"), ("exit", "rs_czsc"), ("exit", "czsc")]: # warm-up 阶段:失败时记录但不阻塞后续测量 with tempfile.TemporaryDirectory() as tmp: try: @@ -174,7 +175,7 @@ def main(): else: t = time_exit(module, Path(tmp), candidate_events) times.append(t) - print(f" [{module}/{kind}] trial {i+1}: {t*1000:.0f}ms") + print(f" [{module}/{kind}] trial {i + 1}: {t * 1000:.0f}ms") results[kind][module] = times # 汇总输出 diff --git a/test/parity/compare_optimize_full.py b/test/parity/compare_optimize_full.py index f99db4f15..724c4568a 100644 --- a/test/parity/compare_optimize_full.py +++ b/test/parity/compare_optimize_full.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """完整 K 线信号集合下的 Open/Exit 优化等价性对比脚本。 本脚本对比 ``rs_czsc`` 与迁移后的 ``czsc`` 在 ``use_optimize.py`` 工作流 @@ -43,12 +42,14 @@ # K 线数据准备 —— 用 mock 替换原示例缺失的 k_line.feather # # --------------------------------------------------------------------- # + def make_bars_df(freq: str, sdt: str, edt: str) -> pd.DataFrame: """生成单品种 mock K 线 DataFrame,列与示例脚本对齐。 使用固定 ``seed=42`` 以保证两次运行(rs_czsc 与 czsc)的输入完全一致。 """ from wbt.mock import mock_symbol_kline + df = mock_symbol_kline("000001", freq, sdt, edt, seed=42) df["dt"] = pd.to_datetime(df["dt"]) cols = ["dt", "symbol", "open", "high", "low", "close", "vol", "amount"] @@ -59,6 +60,7 @@ def make_bars_df(freq: str, sdt: str, edt: str) -> pd.DataFrame: # Beta 仓位构造与 read_bars 回调 # # --------------------------------------------------------------------- # + def _sig_str_to_kv(sig: str) -> dict: """把七段式信号字符串拆分成 ``{key, value}`` 字典形式。 @@ -88,16 +90,18 @@ def event_dict(name_, op, sig): # czsc.Position 不接受 T0 关键字参数;rs_czsc 接受。 # 通过 .load 走 dict 入口可以让两套实现保持完全一致的入参形态。 - return Position.load({ - "symbol": symbol, - "name": name, - "opens": [event_dict(f"{name}_open", open_operate, open_signal)], - "exits": [event_dict(f"{name}_exit", exit_operate, exit_signal)], - "interval": 0, - "timeout": 120, - "stop_loss": 800.0, - "T0": False, - }) + return Position.load( + { + "symbol": symbol, + "name": name, + "opens": [event_dict(f"{name}_open", open_operate, open_signal)], + "exits": [event_dict(f"{name}_exit", exit_operate, exit_signal)], + "interval": 0, + "timeout": 120, + "stop_loss": 800.0, + "T0": False, + } + ) def write_beta_positions(czsc_module, path: Path, symbol: str) -> list[str]: @@ -109,12 +113,18 @@ def write_beta_positions(czsc_module, path: Path, symbol: str) -> list[str]: path.mkdir(parents=True, exist_ok=True) positions = [ build_position( - czsc_module, symbol, "long_beta", - "5分钟_D1单K趋势N5_BS辅助V230506_第1层_任意_任意_0", "开多", + czsc_module, + symbol, + "long_beta", + "5分钟_D1单K趋势N5_BS辅助V230506_第1层_任意_任意_0", + "开多", ), build_position( - czsc_module, symbol, "short_beta", - "5分钟_D1单K趋势N5_BS辅助V230506_第18层_任意_任意_0", "开空", + czsc_module, + symbol, + "short_beta", + "5分钟_D1单K趋势N5_BS辅助V230506_第18层_任意_任意_0", + "开空", ), ] files = [] @@ -133,6 +143,7 @@ def write_beta_positions(czsc_module, path: Path, symbol: str) -> list[str]: # 候选信号集合 —— 全量 K 线信号(按默认参数渲染) # # --------------------------------------------------------------------- # + def all_kline_candidate_signals(czsc_module) -> list[str]: """渲染信号注册表里所有 K 线类信号为完整七段式字符串。 @@ -160,6 +171,7 @@ def all_kline_candidate_signals(czsc_module) -> list[str]: # 驱动逻辑 —— 对每个模块各跑一次 OpensOptimize / ExitsOptimize # # --------------------------------------------------------------------- # + def _import_module(module_name: str): """根据名字导入对应的 czsc 模块和 traders.optimize 子模块。 @@ -170,12 +182,14 @@ def _import_module(module_name: str): if module_name == "rs_czsc": import rs_czsc as czsc_mod from rs_czsc import Event as _Event + if not getattr(_Event, "_rs_tuple_contract_patch", False): origin = _Event.is_match def _wrapped(self, sig): out = origin(self, sig) return out if isinstance(out, tuple) else (out, "is_match" if out else "") + _Event.is_match = _wrapped _Event._rs_tuple_contract_patch = True elif module_name == "czsc": @@ -183,16 +197,19 @@ def _wrapped(self, sig): else: raise ValueError(module_name) import importlib + optim_mod = importlib.import_module(f"{module_name}.traders.optimize") return czsc_mod, optim_mod def _make_read_bars(czsc_mod, bars_5min, bars_daily): """构造一个 ``read_bars`` 回调,按 freq 选择对应频率的 mock 数据。""" + def get_raw_bars(symbol_in, freq_in, sdt_in, edt_in, **_): df = bars_daily if freq_in == "日线" else bars_5min df = df[df["symbol"] == symbol_in] return czsc_mod.format_standard_kline(df, freq=freq_in) + return get_raw_bars @@ -210,13 +227,15 @@ def all_kline_candidate_events(czsc_mod) -> list[dict]: out = [] for i, sig in enumerate(all_kline_candidate_signals(czsc_mod)): operate = "平多" if i % 2 == 0 else "平空" - out.append({ - "name": f"exit_{i:03d}", - "operate": operate, - "signals_all": [sig], - "signals_any": [], - "signals_not": [], - }) + out.append( + { + "name": f"exit_{i:03d}", + "operate": operate, + "signals_all": [sig], + "signals_any": [], + "signals_not": [], + } + ) return out @@ -245,7 +264,9 @@ def run_open(module_name: str, results_root: Path) -> Path: read_bars=get_raw_bars, results_path=results_root, signals_module_name="czsc.signals", - bar_sdt=bar_sdt, bar_edt=bar_edt, sdt=sdt, + bar_sdt=bar_sdt, + bar_edt=bar_edt, + sdt=sdt, ) t0 = time.perf_counter() oop.execute(n_jobs=1) @@ -283,7 +304,9 @@ def run_exit(module_name: str, results_root: Path) -> Path: # 会对字符串形式的 signals_all 调用 Position.load(其更严格的校验 # 器会拒绝该形态)。两套实现的 Rust optimizer 调用本身都接受字符串。 base_freq="5分钟", - bar_sdt=bar_sdt, bar_edt=bar_edt, sdt=sdt, + bar_sdt=bar_sdt, + bar_edt=bar_edt, + sdt=sdt, ) t0 = time.perf_counter() eop.execute(n_jobs=1) @@ -296,6 +319,7 @@ def run_exit(module_name: str, results_root: Path) -> Path: # 输出树对比 # # --------------------------------------------------------------------- # + def inventory(root: Path) -> dict[str, int]: """递归扫描 ``root`` 下的所有文件,返回 {相对路径: 字节数}。""" out: dict[str, int] = {} @@ -349,11 +373,14 @@ def compare_trees(rs_path: Path, czsc_path: Path) -> dict: continue cols = sorted(set(a.columns) & set(b.columns)) if set(a.columns) != set(b.columns): - summary["parquet_diffs"].append({ - "rel": rel, "kind": "columns", - "rs_only": sorted(set(a.columns) - set(b.columns)), - "czsc_only": sorted(set(b.columns) - set(a.columns)), - }) + summary["parquet_diffs"].append( + { + "rel": rel, + "kind": "columns", + "rs_only": sorted(set(a.columns) - set(b.columns)), + "czsc_only": sorted(set(b.columns) - set(a.columns)), + } + ) a = a[cols].reset_index(drop=True) b = b[cols].reset_index(drop=True) try: diff --git a/test/parity/conftest.py b/test/parity/conftest.py index 3d251c062..962479787 100644 --- a/test/parity/conftest.py +++ b/test/parity/conftest.py @@ -82,9 +82,7 @@ def sample_position_dict(): { "name": "open_long", "operate": "开多", - "signals_all": [ - {"key": "日线_D1N5M5TH10_ADTMV230603", "value": "看多_任意_任意_0"} - ], + "signals_all": [{"key": "日线_D1N5M5TH10_ADTMV230603", "value": "看多_任意_任意_0"}], "signals_any": [], "signals_not": [], }, @@ -93,9 +91,7 @@ def sample_position_dict(): { "name": "exit_long", "operate": "平多", - "signals_all": [ - {"key": "日线_D1N5M5TH10_ADTMV230603", "value": "看空_任意_任意_0"} - ], + "signals_all": [{"key": "日线_D1N5M5TH10_ADTMV230603", "value": "看空_任意_任意_0"}], "signals_any": [], "signals_not": [], }, diff --git a/test/parity/test_all_signals.py b/test/parity/test_all_signals.py index 6473798cd..60eb829a5 100644 --- a/test/parity/test_all_signals.py +++ b/test/parity/test_all_signals.py @@ -38,7 +38,6 @@ from ._signal_defaults import render - # --------------------------------------------------------------------- # # K 线 fixture # # --------------------------------------------------------------------- # @@ -67,6 +66,7 @@ def _make_bars(freq: str, sdt: str, edt: str) -> pd.DataFrame: # 策略合成 # # --------------------------------------------------------------------- # + def _build_all_signals_strategy(czsc_module, base_freq: str): """构造覆盖全部 K 线信号的运行时策略。 @@ -151,14 +151,13 @@ def _signal_columns(df: pd.DataFrame) -> list[str]: # 参数化等价性测试 # # --------------------------------------------------------------------- # + @pytest.mark.parametrize( "label,base_freq,sdt,edt", DATASETS, ids=[d[0] for d in DATASETS], ) -def test_all_signals_parity( - rs_czsc_module, czsc_module, label, base_freq, sdt, edt, capsys -): +def test_all_signals_parity(rs_czsc_module, czsc_module, label, base_freq, sdt, edt, capsys): """对每种规模的数据集,``czsc.run_research`` 输出的每一列信号都 必须与 ``rs_czsc.run_research`` 完全相等。 @@ -193,17 +192,13 @@ def test_all_signals_parity( czsc_df = arrow_bytes_to_pd_df(bytes(czsc_payload["signals_arrow"])) # shape 必须完全一致 - assert rs_df.shape == czsc_df.shape, ( - f"[{label}] shape mismatch: rs={rs_df.shape} czsc={czsc_df.shape}" - ) + assert rs_df.shape == czsc_df.shape, f"[{label}] shape mismatch: rs={rs_df.shape} czsc={czsc_df.shape}" # 列集合严格相等:任何一边出现额外列都视为失败 rs_cols = set(rs_df.columns) czsc_cols = set(czsc_df.columns) assert rs_cols == czsc_cols, ( - f"[{label}] column set differs.\n" - f" rs only: {rs_cols - czsc_cols}\n" - f" czsc only: {czsc_cols - rs_cols}" + f"[{label}] column set differs.\n rs only: {rs_cols - czsc_cols}\n czsc only: {czsc_cols - rs_cols}" ) # 逐列、逐 cell 对比信号值,记录所有不一致列 @@ -217,8 +212,12 @@ def test_all_signals_parity( mask = rs_series.ne(czsc_series) | (rs_series.isna() ^ czsc_series.isna()) first_idx = mask.idxmax() if mask.any() else None diverging.append( - (col, first_idx, rs_series.iloc[first_idx] if first_idx is not None else None, - czsc_series.iloc[first_idx] if first_idx is not None else None) + ( + col, + first_idx, + rs_series.iloc[first_idx] if first_idx is not None else None, + czsc_series.iloc[first_idx] if first_idx is not None else None, + ) ) ratio = czsc_elapsed / rs_elapsed if rs_elapsed > 0 else float("inf") @@ -234,6 +233,5 @@ def test_all_signals_parity( ) assert not diverging, ( - f"[{label}] {len(diverging)} signal columns diverge.\n" - f" first 5 (col, row, rs, czsc): {diverging[:5]}" + f"[{label}] {len(diverging)} signal columns diverge.\n first 5 (col, row, rs, czsc): {diverging[:5]}" ) diff --git a/test/parity/test_examples.py b/test/parity/test_examples.py index b244193da..7c6287cf3 100644 --- a/test/parity/test_examples.py +++ b/test/parity/test_examples.py @@ -24,13 +24,12 @@ from pathlib import Path import pandas as pd -import pytest - # --------------------------------------------------------------------- # # 共享辅助函数 # # --------------------------------------------------------------------- # + def _patch_event_is_match_tuple_contract(module): """把 ``module.Event.is_match`` 包装成 (matched, reason) 元组返回。 @@ -111,16 +110,12 @@ def _compare_parquet_trees(rs_root: Path, czsc_root: Path, label: str): for rel in sorted(rs_files): rs_df = _normalise_parquet(rs_root / rel) czsc_df = _normalise_parquet(czsc_root / rel) - assert rs_df.shape == czsc_df.shape, ( - f"[{label}/{rel}] shape mismatch: rs={rs_df.shape} czsc={czsc_df.shape}" - ) + assert rs_df.shape == czsc_df.shape, f"[{label}/{rel}] shape mismatch: rs={rs_df.shape} czsc={czsc_df.shape}" # 列集合严格一致:"完全一致"意味着任何一边都不能多出列 rs_cols = set(rs_df.columns) czsc_cols = set(czsc_df.columns) assert rs_cols == czsc_cols, ( - f"[{label}/{rel}] column set differs.\n" - f" rs only: {rs_cols - czsc_cols}\n" - f" czsc only: {czsc_cols - rs_cols}" + f"[{label}/{rel}] column set differs.\n rs only: {rs_cols - czsc_cols}\n czsc only: {czsc_cols - rs_cols}" ) cols = sorted(rs_cols) pd.testing.assert_frame_equal( @@ -135,6 +130,7 @@ def _compare_parquet_trees(rs_root: Path, czsc_root: Path, label: str): # 示例 1 —— 30分钟笔非多即空.py # # --------------------------------------------------------------------- # + def _build_long_short_position(module, symbol: str, base_freq: str): """构造与 ``30分钟笔非多即空.py::create_long_short_V230909`` 一致的 Position。 @@ -202,9 +198,7 @@ def _run_30m_example(module, bars_df, results_root: Path): return elapsed -def test_example_30min_long_short_parity( - rs_czsc_module, czsc_module, tmp_path, capsys -): +def test_example_30min_long_short_parity(rs_czsc_module, czsc_module, tmp_path, capsys): """30分钟笔非多即空 示例脚本的端到端等价性。 测试场景:rs_czsc 与 czsc 各跑一次 backtest+replay,输出落到不同目录。 @@ -226,10 +220,7 @@ def test_example_30min_long_short_parity( with capsys.disabled(): ratio = czsc_elapsed / rs_elapsed if rs_elapsed > 0 else float("inf") # 顺便打印耗时比,便于跟踪性能 - print( - f"\n[30分钟笔非多即空] rs_czsc={rs_elapsed:.3f}s " - f"czsc={czsc_elapsed:.3f}s ratio={ratio:.2f}x" - ) + print(f"\n[30分钟笔非多即空] rs_czsc={rs_elapsed:.3f}s czsc={czsc_elapsed:.3f}s ratio={ratio:.2f}x") # --------------------------------------------------------------------- # @@ -355,9 +346,7 @@ def read_bars(symbol, freq, sdt, edt, **_): open_root = results_root / "open_demo" open_root.mkdir(parents=True, exist_ok=True) - files_position = _materialize_beta_positions( - module, "000001", open_root / "base_positions" - ) + files_position = _materialize_beta_positions(module, "000001", open_root / "base_positions") start = time.perf_counter() oop = OpensOptimize( @@ -377,9 +366,7 @@ def read_bars(symbol, freq, sdt, edt, **_): exit_root = results_root / "exit_demo" exit_root.mkdir(parents=True, exist_ok=True) - files_position_exit = _materialize_beta_positions( - module, "000001", exit_root / "base_positions" - ) + files_position_exit = _materialize_beta_positions(module, "000001", exit_root / "base_positions") eop = ExitsOptimize( symbols=["000001"], files_position=files_position_exit, @@ -397,9 +384,7 @@ def read_bars(symbol, freq, sdt, edt, **_): return time.perf_counter() - start -def test_example_use_optimize_parity( - rs_czsc_module, czsc_module, tmp_path, capsys -): +def test_example_use_optimize_parity(rs_czsc_module, czsc_module, tmp_path, capsys): """use_optimize 示例脚本的端到端等价性。 测试场景:rs_czsc 与 czsc 分别跑一遍开仓优化 + 出场优化,输出落到 @@ -421,16 +406,14 @@ def test_example_use_optimize_parity( with capsys.disabled(): ratio = czsc_elapsed / rs_elapsed if rs_elapsed > 0 else float("inf") - print( - f"\n[use_optimize] rs_czsc={rs_elapsed:.3f}s " - f"czsc={czsc_elapsed:.3f}s ratio={ratio:.2f}x" - ) + print(f"\n[use_optimize] rs_czsc={rs_elapsed:.3f}s czsc={czsc_elapsed:.3f}s ratio={ratio:.2f}x") # --------------------------------------------------------------------- # # 示例 3 —— weight_backtest.py # # --------------------------------------------------------------------- # + def _build_weight_df(): """构造 ``weight_backtest.py`` 所需的权重 DataFrame。 @@ -443,9 +426,7 @@ def _build_weight_df(): return df[["dt", "symbol", "weight", "price"]].copy() -def test_example_weight_backtest_parity( - rs_czsc_module, czsc_module, capsys -): +def test_example_weight_backtest_parity(rs_czsc_module, czsc_module, capsys): """``WeightBacktest`` 示例的等价性测试(弱化版本)。 关键约定: @@ -468,9 +449,7 @@ def test_example_weight_backtest_parity( df = _build_weight_df() start = time.perf_counter() - rs_wb = rs_czsc_module.WeightBacktest( - df, digits=2, n_jobs=1, weight_type="ts" - ) + rs_wb = rs_czsc_module.WeightBacktest(df, digits=2, n_jobs=1, weight_type="ts") rs_stats = dict(rs_wb.stats) rs_elapsed = time.perf_counter() - start @@ -495,9 +474,7 @@ def test_example_weight_backtest_parity( if isinstance(rv, (int, float)) and isinstance(cv, (int, float)): if abs(rv - cv) > 0.005: # 0.5 个百分点的容差 diffs[k] = (rv, cv) - assert not diffs, ( - f"core stats divergence beyond 0.005 tolerance: {diffs}" - ) + assert not diffs, f"core stats divergence beyond 0.005 tolerance: {diffs}" with capsys.disabled(): ratio = czsc_elapsed / rs_elapsed if rs_elapsed > 0 else float("inf") diff --git a/test/parity/test_optimize.py b/test/parity/test_optimize.py index e85d5570d..4356a8136 100644 --- a/test/parity/test_optimize.py +++ b/test/parity/test_optimize.py @@ -17,8 +17,6 @@ from __future__ import annotations import json -import shutil -import tempfile from pathlib import Path import pandas as pd @@ -67,9 +65,7 @@ def test_build_open_optim_positions_matches(rs_czsc_module, czsc_module, positio rs_data = json.loads(rs_raw) if isinstance(rs_raw, str) else rs_raw czsc_data = json.loads(czsc_raw) if isinstance(czsc_raw, str) else czsc_raw - assert len(rs_data) == len(czsc_data), ( - f"position count mismatch: rs={len(rs_data)} czsc={len(czsc_data)}" - ) + assert len(rs_data) == len(czsc_data), f"position count mismatch: rs={len(rs_data)} czsc={len(czsc_data)}" # 仓位列表的顺序在两次独立运行间不一定一致;统一用 (name, opens hash, # exits hash) 做规范化后再比较集合相等。 @@ -83,9 +79,7 @@ def _canon(positions): for p in positions ) - assert _canon(rs_data) == _canon(czsc_data), ( - f"open-optim variants diverge" - ) + assert _canon(rs_data) == _canon(czsc_data), "open-optim variants diverge" def test_build_exit_optim_positions_matches(rs_czsc_module, czsc_module, position_files): @@ -145,9 +139,7 @@ def bars_dir(tmp_path, mock_kline_df, czsc_module): return bars_path -def test_run_optimize_batch_matches( - rs_czsc_module, czsc_module, bars_dir, position_files, tmp_path -): +def test_run_optimize_batch_matches(rs_czsc_module, czsc_module, bars_dir, position_files, tmp_path): """``run_optimize_batch`` 端到端等价性测试。 测试场景:构造一个最小化的开仓优化任务配置,rs_czsc 与 czsc 各跑 diff --git a/test/parity/test_performance.py b/test/parity/test_performance.py index ae20d0f8c..2927868f9 100644 --- a/test/parity/test_performance.py +++ b/test/parity/test_performance.py @@ -14,11 +14,9 @@ import json import statistics import time -from pathlib import Path import pandas as pd - # 性能回归预算:czsc 不应慢于 rs_czsc 1.5 倍 PERF_RATIO_BUDGET = 1.5 @@ -96,6 +94,7 @@ def positions(self): # 测试 1 —— CZSC 分析器构造性能 # # --------------------------------------------------------------------- # + def test_perf_czsc_analyzer(rs_czsc_module, czsc_module, mock_kline_df, capsys): """约 522 根日线下,``CZSC(bars)`` 分析器的耗时对比。 @@ -111,14 +110,10 @@ def test_perf_czsc_analyzer(rs_czsc_module, czsc_module, mock_kline_df, capsys): ratio = czsc_t / rs_t if rs_t > 0 else float("inf") with capsys.disabled(): - print( - f"\n[CZSC(522 daily bars)] rs_czsc={rs_t * 1000:.2f}ms " - f"czsc={czsc_t * 1000:.2f}ms ratio={ratio:.2f}x" - ) + print(f"\n[CZSC(522 daily bars)] rs_czsc={rs_t * 1000:.2f}ms czsc={czsc_t * 1000:.2f}ms ratio={ratio:.2f}x") # 这是宽松预算 —— 性能差异主要起到信息提示作用,而不是阻塞性的卡口 assert ratio <= PERF_RATIO_BUDGET, ( - f"czsc analyzer is {ratio:.2f}x slower than rs_czsc baseline " - f"({PERF_RATIO_BUDGET}x budget exceeded)" + f"czsc analyzer is {ratio:.2f}x slower than rs_czsc baseline ({PERF_RATIO_BUDGET}x budget exceeded)" ) @@ -126,6 +121,7 @@ def test_perf_czsc_analyzer(rs_czsc_module, czsc_module, mock_kline_df, capsys): # 测试 2 —— 多种 K 线规模下的 backtest 性能 # # --------------------------------------------------------------------- # + def test_perf_backtest_scaling(rs_czsc_module, czsc_module, tmp_path, capsys): """30 分钟策略 backtest 在不同 K 线规模下的耗时对比。 @@ -156,18 +152,14 @@ def test_perf_backtest_scaling(rs_czsc_module, czsc_module, tmp_path, capsys): backtest_sdt = pd.to_datetime(sdt_data) + pd.Timedelta(days=60) backtest_sdt_str = backtest_sdt.strftime("%Y-%m-%d") - rs_t = _backtest_perf( - rs_czsc_module, df, "30分钟", backtest_sdt_str, tmp_path / "rs" - ) - czsc_t = _backtest_perf( - czsc_module, df, "30分钟", backtest_sdt_str, tmp_path / "czsc" - ) + rs_t = _backtest_perf(rs_czsc_module, df, "30分钟", backtest_sdt_str, tmp_path / "rs") + czsc_t = _backtest_perf(czsc_module, df, "30分钟", backtest_sdt_str, tmp_path / "czsc") ratio = czsc_t / rs_t if rs_t > 0 else float("inf") rows.append((len(df), rs_t * 1000, czsc_t * 1000, ratio)) with capsys.disabled(): # 打印一张对照表 - print(f"\n[backtest scaling — 30min strategy]") + print("\n[backtest scaling — 30min strategy]") print(f" {'#bars':>6} | {'rs_czsc':>10} | {'czsc':>10} | {'ratio':>6}") print(f" {'-' * 6} | {'-' * 10} | {'-' * 10} | {'-' * 6}") for n_bars, rs_ms, czsc_ms, r in rows: @@ -175,18 +167,15 @@ def test_perf_backtest_scaling(rs_czsc_module, czsc_module, tmp_path, capsys): # 只要任何一个规模超出预算就视为失败 over_budget = [(n, r) for n, _, _, r in rows if r > PERF_RATIO_BUDGET] - assert not over_budget, ( - f"backtest perf budget {PERF_RATIO_BUDGET}x exceeded at: {over_budget}" - ) + assert not over_budget, f"backtest perf budget {PERF_RATIO_BUDGET}x exceeded at: {over_budget}" # --------------------------------------------------------------------- # # 测试 3 —— derive_signals_config + run_research 端到端性能 # # --------------------------------------------------------------------- # -def test_perf_run_research_endtoend( - rs_czsc_module, czsc_module, mock_kline_df, sample_position_dict, capsys -): + +def test_perf_run_research_endtoend(rs_czsc_module, czsc_module, mock_kline_df, sample_position_dict, capsys): """``run_research(arrow_bytes, json)`` 的端到端性能对比。 这是迁移后最关键的入口:把 Arrow 字节 + JSON 策略喂给 Rust 端, @@ -256,6 +245,4 @@ def _time(module, n=5): f"ratio={ratio:.2f}x" ) - assert ratio <= PERF_RATIO_BUDGET, ( - f"czsc.run_research is {ratio:.2f}x slower than rs_czsc baseline" - ) + assert ratio <= PERF_RATIO_BUDGET, f"czsc.run_research is {ratio:.2f}x slower than rs_czsc baseline" diff --git a/test/parity/test_run_research.py b/test/parity/test_run_research.py index d18c04c6d..b3f2f8022 100644 --- a/test/parity/test_run_research.py +++ b/test/parity/test_run_research.py @@ -106,9 +106,7 @@ def test_run_research_signals_match(rs_czsc_module, czsc_module, parity_inputs): f"czsc only: {set(czsc_df.columns) - set(rs_df.columns)}" ) common_cols = sorted(rs_df.columns) - pd.testing.assert_frame_equal( - rs_df[common_cols], czsc_df[common_cols], check_dtype=False - ) + pd.testing.assert_frame_equal(rs_df[common_cols], czsc_df[common_cols], check_dtype=False) def test_run_research_pairs_match(rs_czsc_module, czsc_module, parity_inputs): diff --git a/test/parity/test_signals_registry.py b/test/parity/test_signals_registry.py index 2aa8ca90f..7e16ba9ad 100644 --- a/test/parity/test_signals_registry.py +++ b/test/parity/test_signals_registry.py @@ -20,9 +20,7 @@ def test_list_all_signals_count_matches(rs_czsc_module, czsc_module): """注册表中信号数量必须一致。""" rs_list = rs_czsc_module.list_all_signals() czsc_list = czsc_module._native.list_all_signals() - assert len(rs_list) == len(czsc_list), ( - f"signal count mismatch: rs_czsc={len(rs_list)} vs czsc={len(czsc_list)}" - ) + assert len(rs_list) == len(czsc_list), f"signal count mismatch: rs_czsc={len(rs_list)} vs czsc={len(czsc_list)}" def test_list_all_signals_names_match(rs_czsc_module, czsc_module): @@ -43,11 +41,7 @@ def test_list_all_signals_templates_match(rs_czsc_module, czsc_module): """ rs_map = {d["name"]: d.get("param_template") for d in rs_czsc_module.list_all_signals()} czsc_map = {d["name"]: d.get("param_template") for d in czsc_module._native.list_all_signals()} - diffs = { - name: (rs_map[name], czsc_map[name]) - for name in rs_map - if rs_map[name] != czsc_map[name] - } + diffs = {name: (rs_map[name], czsc_map[name]) for name in rs_map if rs_map[name] != czsc_map[name]} assert not diffs, f"param_template mismatches: {diffs}" diff --git a/test/smoke/test_install.py b/test/smoke/test_install.py index 5ee43d04e..8aa6a5819 100644 --- a/test/smoke/test_install.py +++ b/test/smoke/test_install.py @@ -27,7 +27,6 @@ import pytest - # 仓库根目录(基于本测试文件位置反推),以及构建产物所在目录 REPO_ROOT = Path(__file__).resolve().parents[2] DIST_DIR = REPO_ROOT / "dist" @@ -49,8 +48,7 @@ def test_pyproject_uses_maturin_backend() -> None: pytest.fail(f"未找到 pyproject.toml,路径:{pyproject}") text = pyproject.read_text(encoding="utf-8") assert "maturin" in text, ( - "pyproject.toml 必须声明 [build-system] requires=['maturin>=...'] " - "以及 build-backend='maturin'" + "pyproject.toml 必须声明 [build-system] requires=['maturin>=...'] 以及 build-backend='maturin'" ) @@ -77,9 +75,7 @@ def test_native_extension_present_in_install() -> None: "构建流程必须保证 maturin 正确打包了 Rust 扩展" ) out = proc.stdout.strip() - assert out.endswith((".so", ".pyd", ".dylib")), ( - f"czsc._native.__file__ 应指向编译扩展,实际为 {out!r}" - ) + assert out.endswith((".so", ".pyd", ".dylib")), f"czsc._native.__file__ 应指向编译扩展,实际为 {out!r}" def test_wheel_install_in_clean_venv(tmp_path: Path) -> None: @@ -109,10 +105,7 @@ def test_wheel_install_in_clean_venv(tmp_path: Path) -> None: # 常规 CI 的 test job 只跑 `maturin develop`,不产 wheel; # wheel 安装验证由 python-publish.yml 的 smoke-test job 在发布流程 # 里专门跑。所以本地/常规 CI 没 wheel 时跳过,不算失败。 - pytest.skip( - f"在 {DIST_DIR} 下找不到 wheel;" - "如需运行本测试,请先 `maturin build --release`" - ) + pytest.skip(f"在 {DIST_DIR} 下找不到 wheel;如需运行本测试,请先 `maturin build --release`") # 构建一个全新的虚拟环境用于隔离安装 venv = tmp_path / "venv" @@ -130,9 +123,7 @@ def test_wheel_install_in_clean_venv(tmp_path: Path) -> None: text=True, ) if install.returncode != 0: - pytest.fail( - f"pip install {wheels[-1].name} 失败: {install.stderr}" - ) + pytest.fail(f"pip install {wheels[-1].name} 失败: {install.stderr}") # 在新 venv 中打印 czsc.CZSC.__module__, # 该值由 PyO3 的 #[pyclass(module=...)] 设置为 "czsc._native" diff --git a/test/test_envs.py b/test/test_envs.py index f22028143..98a923821 100644 --- a/test/test_envs.py +++ b/test/test_envs.py @@ -52,9 +52,7 @@ def test_legacy_helper_removed(self, name: str) -> None: if name == "_env": assert hasattr(envs_mod, name), "私有 _env helper 应当继续保留" else: - assert not hasattr(envs_mod, name), ( - f"czsc.envs.{name} 必须被移除" - ) + assert not hasattr(envs_mod, name), f"czsc.envs.{name} 必须被移除" def test_no_czsc_use_python_branch(self) -> None: """验证 czsc.envs 源码中不再引用 CZSC_USE_PYTHON 环境变量。 @@ -70,9 +68,7 @@ def test_no_czsc_use_python_branch(self) -> None: import inspect src = inspect.getsource(envs_mod) - assert "CZSC_USE_PYTHON" not in src, ( - "CZSC_USE_PYTHON 环境变量必须不被引用" - ) + assert "CZSC_USE_PYTHON" not in src, "CZSC_USE_PYTHON 环境变量必须不被引用" class TestGetVerbose: diff --git a/test/test_plotly_plot.py b/test/test_plotly_plot.py index 726ae6652..be9387d03 100644 --- a/test/test_plotly_plot.py +++ b/test/test_plotly_plot.py @@ -16,8 +16,7 @@ import pandas as pd -from czsc import KlineChart, mock -from czsc import CZSC, Freq, RawBar +from czsc import CZSC, Freq, KlineChart, RawBar, mock def test_kline_chart(): diff --git a/test/test_stoploss_by_direction.py b/test/test_stoploss_by_direction.py index 73bd81fea..98aa35470 100644 --- a/test/test_stoploss_by_direction.py +++ b/test/test_stoploss_by_direction.py @@ -17,12 +17,14 @@ def _make_dfw(weights: list[float], prices: list[float], symbol: str = "X") -> pd.DataFrame: assert len(weights) == len(prices) - return pd.DataFrame({ - "dt": pd.date_range("2024-01-01", periods=len(prices), freq="D"), - "symbol": symbol, - "weight": weights, - "price": prices, - }) + return pd.DataFrame( + { + "dt": pd.date_range("2024-01-01", periods=len(prices), freq="D"), + "symbol": symbol, + "weight": weights, + "price": prices, + } + ) def test_long_position_no_stop(): @@ -104,9 +106,18 @@ def test_required_output_columns(): dfw = _make_dfw(weights=[1.0, 1.0], prices=[100.0, 101.0]) out = stoploss_by_direction(dfw, stoploss=0.05) - required = {"dt", "symbol", "raw_weight", "weight", "price", - "hold_returns", "min_hold_returns", "returns", - "order_id", "is_stop"} + required = { + "dt", + "symbol", + "raw_weight", + "weight", + "price", + "hold_returns", + "min_hold_returns", + "returns", + "order_id", + "is_stop", + } assert required.issubset(set(out.columns)) diff --git a/test/test_utils.py b/test/test_utils.py index 477f33fde..1a06069aa 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -9,9 +9,6 @@ import time -import numpy as np -import pandas as pd - from czsc import utils from czsc.utils import timeout_decorator diff --git a/test/unit/test_core_parity.py b/test/unit/test_core_parity.py index 579816c83..2ff121856 100644 --- a/test/unit/test_core_parity.py +++ b/test/unit/test_core_parity.py @@ -50,9 +50,7 @@ def _build_czsc() -> tuple[Any | None, str | None]: import czsc from czsc.mock import generate_symbol_kines - df = generate_symbol_kines( - "000001", "30分钟", "20240101", "20240301", seed=42 - ) + df = generate_symbol_kines("000001", "30分钟", "20240101", "20240301", seed=42) bars = czsc.format_standard_kline(df, freq=czsc.Freq.F30) return czsc.CZSC(bars), None except Exception as exc: # noqa: BLE001 @@ -73,8 +71,7 @@ def test_czsc_source_is_in_repo_native() -> None: assert obj is not None, err module = type(obj).__module__ assert module.startswith("czsc."), ( - f"czsc.CZSC 必须来自 czsc._native(实际:{module!r});" - "迁移目标要求完全移除 rs_czsc PyPI 依赖。" + f"czsc.CZSC 必须来自 czsc._native(实际:{module!r});迁移目标要求完全移除 rs_czsc PyPI 依赖。" ) @@ -84,8 +81,7 @@ def test_fx_list_count_matches_baseline() -> None: assert obj is not None, err snap = _load_snapshot() assert len(obj.fx_list) == snap["fx_list_count"], ( - f"FX 数量出现漂移:实际 {len(obj.fx_list)},基线 " - f"{snap['fx_list_count']}" + f"FX 数量出现漂移:实际 {len(obj.fx_list)},基线 {snap['fx_list_count']}" ) @@ -95,8 +91,7 @@ def test_bi_list_count_matches_baseline() -> None: assert obj is not None, err snap = _load_snapshot() assert len(obj.bi_list) == snap["bi_list_count"], ( - f"BI 数量出现漂移:实际 {len(obj.bi_list)},基线 " - f"{snap['bi_list_count']}" + f"BI 数量出现漂移:实际 {len(obj.bi_list)},基线 {snap['bi_list_count']}" ) @@ -135,7 +130,4 @@ def test_bi_lengths_sequence_matches_baseline() -> None: assert obj is not None, err snap = _load_snapshot() actual = [bi.length for bi in obj.bi_list] - assert actual == snap["bi_lengths"], ( - f"BI 长度序列出现漂移;期望 {snap['bi_lengths'][:5]}…," - f"实际 {actual[:5]}…" - ) + assert actual == snap["bi_lengths"], f"BI 长度序列出现漂移;期望 {snap['bi_lengths'][:5]}…,实际 {actual[:5]}…" diff --git a/test/unit/test_pickle.py b/test/unit/test_pickle.py index 05bc6bfd5..e6b7c6f67 100644 --- a/test/unit/test_pickle.py +++ b/test/unit/test_pickle.py @@ -42,8 +42,8 @@ def _try_build_bars() -> tuple[Any | None, str | None]: 避免 fixture 阶段直接抛异常造成 pytest 报告为 ERROR 的情况。 """ try: + from czsc import Freq, format_standard_kline # type: ignore[attr-defined] from czsc.mock import generate_symbol_kines # type: ignore[attr-defined] - from czsc import format_standard_kline, Freq # type: ignore[attr-defined] except Exception as exc: # noqa: BLE001 return None, f"import failed: {type(exc).__name__}: {exc}" @@ -122,16 +122,9 @@ def test_pickle_roundtrip(target: str) -> None: blob = pickle.dumps(obj) restored = pickle.loads(blob) except Exception as exc: # noqa: BLE001 - pytest.fail( - f"[{target}] pickle 往返序列化抛出异常 " - f"{type(exc).__name__}: {exc}" - ) + pytest.fail(f"[{target}] pickle 往返序列化抛出异常 {type(exc).__name__}: {exc}") - assert type(restored) is type(obj), ( - f"[{target}] 往返序列化改变了对象类型:{type(obj)} → {type(restored)}" - ) + assert type(restored) is type(obj), f"[{target}] 往返序列化改变了对象类型:{type(obj)} → {type(restored)}" if hasattr(obj, "__getstate__") and hasattr(restored, "__getstate__"): - assert restored.__getstate__() == obj.__getstate__(), ( - f"[{target}] __getstate__ 在往返序列化后发生变化" - ) + assert restored.__getstate__() == obj.__getstate__(), f"[{target}] __getstate__ 在往返序列化后发生变化" diff --git a/test/unit/test_signals_parity.py b/test/unit/test_signals_parity.py index b3c92daab..44756e593 100644 --- a/test/unit/test_signals_parity.py +++ b/test/unit/test_signals_parity.py @@ -51,9 +51,7 @@ def test_signal_subpackage_exists(sub: str) -> None: ``importlib.import_module(f"czsc.signals.{sub}")`` 不抛异常且返回非 None。 """ mod, err = _safe_import(f"czsc.signals.{sub}") - assert mod is not None, ( - f"czsc.signals.{sub} 必须可被导入({err})" - ) + assert mod is not None, f"czsc.signals.{sub} 必须可被导入({err})" # 参数化覆盖所有必需信号子包,每个子包验证其暴露足够数量的函数 @@ -71,13 +69,9 @@ def test_signal_subpackage_has_functions(sub: str) -> None: mod, err = _safe_import(f"czsc.signals.{sub}") if mod is None: pytest.fail(f"无法导入 czsc.signals.{sub}: {err}") - funcs = [ - name for name in dir(mod) - if not name.startswith("_") and callable(getattr(mod, name)) - ] + funcs = [name for name in dir(mod) if not name.startswith("_") and callable(getattr(mod, name))] assert len(funcs) >= MIN_FUNCS_PER_SUBPACKAGE, ( - f"czsc.signals.{sub} 必须至少暴露 {MIN_FUNCS_PER_SUBPACKAGE} 个 " - f"信号函数;实际找到 {len(funcs)} 个:{funcs}" + f"czsc.signals.{sub} 必须至少暴露 {MIN_FUNCS_PER_SUBPACKAGE} 个 信号函数;实际找到 {len(funcs)} 个:{funcs}" ) @@ -96,10 +90,7 @@ def test_signal_subpackage_sourced_from_native(sub: str) -> None: mod, err = _safe_import(f"czsc.signals.{sub}") if mod is None: pytest.fail(f"无法导入 czsc.signals.{sub}: {err}") - funcs = [ - getattr(mod, n) for n in dir(mod) - if not n.startswith("_") and callable(getattr(mod, n)) - ] + funcs = [getattr(mod, n) for n in dir(mod) if not n.startswith("_") and callable(getattr(mod, n))] if not funcs: pytest.fail(f"czsc.signals.{sub} 中没有任何可调用对象") @@ -123,6 +114,4 @@ def test_native_signals_module_exists() -> None: """ native, err = _safe_import("czsc._native") assert native is not None, f"czsc._native 必须存在(maturin 构建产物)({err})" - assert hasattr(native, "signals"), ( - "czsc._native.signals 必须是已注册的 PyO3 子模块" - ) + assert hasattr(native, "signals"), "czsc._native.signals 必须是已注册的 PyO3 子模块" diff --git a/test/unit/test_ta_parity.py b/test/unit/test_ta_parity.py index e517d1537..06e43839d 100644 --- a/test/unit/test_ta_parity.py +++ b/test/unit/test_ta_parity.py @@ -94,9 +94,7 @@ def test_ema_matches_talib() -> None: if not hasattr(ta, "ema"): pytest.fail("czsc.ta.ema 尚未暴露") actual = ta.ema(series, length=14) - np.testing.assert_allclose( - np.asarray(actual)[20:], expected[20:], rtol=1e-6, atol=1e-6 - ) + np.testing.assert_allclose(np.asarray(actual)[20:], expected[20:], rtol=1e-6, atol=1e-6) def test_sma_matches_talib() -> None: @@ -116,9 +114,7 @@ def test_sma_matches_talib() -> None: if not hasattr(ta, "sma"): pytest.fail("czsc.ta.sma 尚未暴露") actual = ta.sma(series, length=20) - np.testing.assert_allclose( - np.asarray(actual)[20:], expected[20:], rtol=1e-6, atol=1e-6 - ) + np.testing.assert_allclose(np.asarray(actual)[20:], expected[20:], rtol=1e-6, atol=1e-6) def test_rolling_rank_returns_finite() -> None: @@ -133,9 +129,7 @@ def test_rolling_rank_returns_finite() -> None: if not hasattr(ta, "rolling_rank"): pytest.fail("czsc.ta.rolling_rank 尚未暴露") out = np.asarray(ta.rolling_rank(_series(), window=20)) - assert np.isfinite(out[20:]).all(), ( - "rolling_rank 在预热窗口之后必须产出有限值" - ) + assert np.isfinite(out[20:]).all(), "rolling_rank 在预热窗口之后必须产出有限值" def test_boll_positions_signature() -> None: @@ -147,9 +141,7 @@ def test_boll_positions_signature() -> None: ta, err = _native_module() if ta is None: pytest.fail(f"czsc.ta 不可用:{err}") - assert hasattr(ta, "boll_positions"), ( - "czsc.ta.boll_positions 必须暴露" - ) + assert hasattr(ta, "boll_positions"), "czsc.ta.boll_positions 必须暴露" def test_ultimate_smoother_signature() -> None: @@ -161,6 +153,4 @@ def test_ultimate_smoother_signature() -> None: ta, err = _native_module() if ta is None: pytest.fail(f"czsc.ta 不可用:{err}") - assert hasattr(ta, "ultimate_smoother"), ( - "czsc.ta.ultimate_smoother 必须暴露" - ) + assert hasattr(ta, "ultimate_smoother"), "czsc.ta.ultimate_smoother 必须暴露" diff --git a/test/unit/test_trading_time.py b/test/unit/test_trading_time.py index 224891b81..02524a41c 100644 --- a/test/unit/test_trading_time.py +++ b/test/unit/test_trading_time.py @@ -75,14 +75,9 @@ def test_is_trading_time(market: str, dt: datetime, expected: bool) -> None: if czsc is None: pytest.fail(f"导入 czsc 失败:{err}") if not hasattr(czsc, "is_trading_time"): - pytest.fail( - "czsc.is_trading_time 尚未暴露 — czsc-utils 必须添加该函数" - ) + pytest.fail("czsc.is_trading_time 尚未暴露 — czsc-utils 必须添加该函数") actual = czsc.is_trading_time(dt, market=market) - assert actual is expected, ( - f"is_trading_time({market}, {dt.isoformat()}) 返回 {actual}," - f"预期 {expected}" - ) + assert actual is expected, f"is_trading_time({market}, {dt.isoformat()}) 返回 {actual},预期 {expected}" def test_is_trading_time_module_origin() -> None: @@ -101,6 +96,4 @@ def test_is_trading_time_module_origin() -> None: if fn is None: pytest.fail("czsc.is_trading_time 缺失") module = getattr(fn, "__module__", "?") - assert module.startswith("czsc."), ( - f"is_trading_time 必须来自 czsc._native(实际 {module!r})" - ) + assert module.startswith("czsc."), f"is_trading_time 必须来自 czsc._native(实际 {module!r})" From 959e88e2ef8207a4952e16517caf8f6b38e3d33b Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Fri, 8 May 2026 11:33:02 +0800 Subject: [PATCH 19/23] refactor: simplify date handling and improve warnings in various functions --- czsc/connectors/cooperation.py | 4 ++-- czsc/connectors/jq_connector.py | 15 ++++++--------- czsc/eda.py | 6 +++--- czsc/svc/__init__.py | 1 - czsc/svc/__init__.pyi | 3 --- czsc/svc/statistics.py | 12 ------------ czsc/svc/strategy.py | 2 -- czsc/traders/optimize.py | 2 +- czsc/utils/trade.py | 2 +- test/parity/test_examples.py | 6 +++--- test/parity/test_performance.py | 7 ++----- test/unit/test_core_parity.py | 4 ++-- 12 files changed, 20 insertions(+), 44 deletions(-) diff --git a/czsc/connectors/cooperation.py b/czsc/connectors/cooperation.py index 3e1163017..d338fdf0b 100644 --- a/czsc/connectors/cooperation.py +++ b/czsc/connectors/cooperation.py @@ -152,11 +152,11 @@ def get_min_future_klines(code, sdt, edt, freq="1m", **kwargs): dates = pd.date_range(start="20000101", end="20300101", freq="365D") dates = [d.strftime("%Y%m%d") for d in dates] - dates = sorted(list(set(dates))) + dates = sorted(set(dates)) rows = [] # 遍历每一个分段区间 [sdt_, edt_),按需要逐段拉取 - for sdt_, edt_ in tqdm(zip(dates[:-1], dates[1:]), total=len(dates) - 1): + for sdt_, edt_ in tqdm(zip(dates[:-1], dates[1:], strict=False), total=len(dates) - 1): # 该段结束时间早于查询开始时间,跳过 if edt_ < sdt: continue diff --git a/czsc/connectors/jq_connector.py b/czsc/connectors/jq_connector.py index 521541a6d..7298c8cb2 100644 --- a/czsc/connectors/jq_connector.py +++ b/czsc/connectors/jq_connector.py @@ -192,10 +192,7 @@ def get_concept_stocks(symbol, date=None): >>> symbols1 = get_concept_stocks("GN036", date="2020-07-08") >>> symbols2 = get_concept_stocks("GN036", date=datetime.now()) """ - if not date: - date = str(datetime.now().date()) - else: - date = pd.to_datetime(date) + date = str(datetime.now().date()) if not date else pd.to_datetime(date) if isinstance(date, datetime): date = str(date.date()) @@ -317,7 +314,7 @@ def get_kline( >>> df4 = get_kline(symbol="000001.XSHG", end_date='20200719', freq="1min", count=1000) """ if count and count > 5000: - warnings.warn(f"count={count}, 超过5000的最大值限制,仅返回最后5000条记录") + warnings.warn(f"count={count}, 超过5000的最大值限制,仅返回最后5000条记录", stacklevel=2) end_date = pd.to_datetime(end_date) @@ -411,7 +408,7 @@ def get_kline_period( # 粗略估算:(自然日 * 5/7) 近似得到交易日数;超 1000 个交易日可能触发服务端限制 if (end_date - start_date).days * 5 / 7 > 1000: - warnings.warn(f"{end_date.date()} - {start_date.date()} 超过1000个交易日,K线获取可能失败,返回为0") + warnings.warn(f"{end_date.date()} - {start_date.date()} 超过1000个交易日,K线获取可能失败,返回为0", stacklevel=2) data = { "method": "get_price_period", @@ -487,7 +484,7 @@ def get_init_bg(symbol: str, end_dt: [str, datetime], base_freq: str, freqs: lis bg = BarGenerator(base_freq, freqs, max_count) # 对 BarGenerator 中维护的每一个频率,分别拉取并初始化 - for freq in bg.bars.keys(): + for freq in bg.bars: bars_ = get_kline(symbol=symbol, end_date=last_day, freq=freq_cn2jq[freq], count=max_count, fq=fq) bg.init_freq_bars(freq, bars_) print(f"{symbol} - {freq} - {len(bg.bars[freq])} - last_dt: {bg.bars[freq][-1].dt} - last_day: {last_day}") @@ -538,7 +535,7 @@ def get_fundamental(table: str, symbol: str, date: str, columns: str = "") -> di df = text2df(r.text) try: return df.iloc[0].to_dict() - except: + except Exception: # 兼容数据为空、列缺失等多种异常,统一返回空字典 return {} @@ -701,7 +698,7 @@ def get_raw_bars(symbol, freq, sdt, edt, fq="前复权", **kwargs): kwargs["fq"] = fq freq = str(freq) # 仅 "前复权" 时设为 True,其他统一为 False - fq = True if fq == "前复权" else False + fq = fq == "前复权" # CZSC 中文频率到聚宽 pandas-style 频率字符串的映射 _map = { "1分钟": "1min", diff --git a/czsc/eda.py b/czsc/eda.py index 93f1f0cc2..0bd54b826 100644 --- a/czsc/eda.py +++ b/czsc/eda.py @@ -878,9 +878,9 @@ def __convert_rs_direction(x): bi_stats.append( { "symbol": symbol, - "sdt": bi.sdt if not rs else pd.to_datetime(bi.sdt, unit="s"), - "edt": bi.edt if not rs else pd.to_datetime(bi.edt, unit="s"), - "direction": bi.direction.value if not rs else __convert_rs_direction(bi.direction), + "sdt": pd.to_datetime(bi.sdt, unit="s"), + "edt": pd.to_datetime(bi.edt, unit="s"), + "direction": __convert_rs_direction(bi.direction), "power_price": abs(bi.change), "length": bi.length, "rsq": bi.rsq, diff --git a/czsc/svc/__init__.py b/czsc/svc/__init__.py index d1b9c4ab3..b63e2157a 100644 --- a/czsc/svc/__init__.py +++ b/czsc/svc/__init__.py @@ -125,7 +125,6 @@ "show_backtest_by_symbol", "show_long_short_backtest", "show_comprehensive_weight_backtest", - "run_weight_backtest_app", # 统计分析 "show_splited_daily", "show_yearly_stats", diff --git a/czsc/svc/__init__.pyi b/czsc/svc/__init__.pyi index 00bfb3e39..676efcd8e 100644 --- a/czsc/svc/__init__.pyi +++ b/czsc/svc/__init__.pyi @@ -189,7 +189,6 @@ __all__ = [ "show_backtest_by_symbol", "show_long_short_backtest", "show_comprehensive_weight_backtest", - "run_weight_backtest_app", "show_splited_daily", "show_yearly_stats", "show_out_in_compare", @@ -223,5 +222,3 @@ __all__ = [ "code_editor_form", ] -# Names in __all__ with no definition: -# run_weight_backtest_app diff --git a/czsc/svc/statistics.py b/czsc/svc/statistics.py index 2ba8f836f..8f0b4df66 100644 --- a/czsc/svc/statistics.py +++ b/czsc/svc/statistics.py @@ -89,10 +89,6 @@ def show_yearly_stats(df, ret_col, **kwargs): - sub_title: str,子标题 :return: None """ - daily_performance = safe_import_daily_performance() - if daily_performance is None: - return - df = ensure_datetime_index(df) df = df.copy().fillna(0).sort_index(ascending=True) @@ -128,10 +124,6 @@ def show_out_in_compare(df, ret_col, mid_dt, **kwargs): - sub_title: str,子标题 :return: None """ - daily_performance = safe_import_daily_performance() - if daily_performance is None: - return - assert isinstance(df, pd.DataFrame), "df 必须是 pd.DataFrame 类型" df = ensure_datetime_index(df) df = df[[ret_col]].copy().fillna(0).sort_index(ascending=True) @@ -220,10 +212,6 @@ def show_outsample_by_dailys(df, outsample_sdt1, outsample_sdt2=None): """ from czsc.eda import cal_yearly_days - daily_performance = safe_import_daily_performance() - if daily_performance is None: - return - if not ("dt" in df.columns and "returns" in df.columns): st.error(f"数据格式错误,必须包含列 ['dt', 'returns']; 当前列:{df.columns}") return diff --git a/czsc/svc/strategy.py b/czsc/svc/strategy.py index 551de51c9..cafa42fbc 100644 --- a/czsc/svc/strategy.py +++ b/czsc/svc/strategy.py @@ -944,8 +944,6 @@ def show_symbol_penalty(df: pd.DataFrame, n=3, **kwargs): - yearly_days: int,年化天数,默认 252 :return: None """ - WeightBacktest = safe_import_weight_backtest() - digits = kwargs.get("digits", 2) fee_rate = kwargs.get("fee_rate", 0.0) weight_type = kwargs.get("weight_type", "ts") diff --git a/czsc/traders/optimize.py b/czsc/traders/optimize.py index 3611aee12..5e65460ef 100644 --- a/czsc/traders/optimize.py +++ b/czsc/traders/optimize.py @@ -275,7 +275,7 @@ def __init__(self, read_bars: Callable, **kwargs): self.symbols = sorted(kwargs["symbols"]) self.files_position = [str(x) for x in kwargs["files_position"]] self.task_name = kwargs.get("task_name", "入场优化") - self.candidate_signals = sorted(list(kwargs["candidate_signals"])) + self.candidate_signals = sorted(kwargs["candidate_signals"]) self.signals_module_name = kwargs.get("signals_module_name", "czsc.signals") # base_freq 优先取用户显式配置;否则借助策略类自动推导, # 保证后续读取 K 线和写入 Rust 配置时频率信息一致。 diff --git a/czsc/utils/trade.py b/czsc/utils/trade.py index f0f01920f..64ef272b6 100644 --- a/czsc/utils/trade.py +++ b/czsc/utils/trade.py @@ -144,7 +144,7 @@ def resample_to_daily(df: pd.DataFrame, sdt=None, edt=None, only_trade_date=True trade_dates = pd.merge_asof(trade_dates, vdt, left_on="date", right_on="dt") trade_dates = trade_dates.dropna(subset=["dt"]).reset_index(drop=True) - dt_map = {dt: dfg for dt, dfg in df.groupby("dt")} + dt_map = dict(df.groupby("dt")) results = [] for row in trade_dates.to_dict("records"): # 注意:这里必须进行 copy,否则默认浅拷贝导致数据异常 diff --git a/test/parity/test_examples.py b/test/parity/test_examples.py index 7c6287cf3..a8df115e9 100644 --- a/test/parity/test_examples.py +++ b/test/parity/test_examples.py @@ -471,9 +471,9 @@ def test_example_weight_backtest_parity(rs_czsc_module, czsc_module, capsys): if k not in tight_check: continue rv, cv = rs_stats[k], czsc_stats[k] - if isinstance(rv, (int, float)) and isinstance(cv, (int, float)): - if abs(rv - cv) > 0.005: # 0.5 个百分点的容差 - diffs[k] = (rv, cv) + if isinstance(rv, (int, float)) and isinstance(cv, (int, float)) and abs(rv - cv) > 0.005: + # 0.5 个百分点的容差 + diffs[k] = (rv, cv) assert not diffs, f"core stats divergence beyond 0.005 tolerance: {diffs}" with capsys.disabled(): diff --git a/test/parity/test_performance.py b/test/parity/test_performance.py index 2927868f9..c908a6fbd 100644 --- a/test/parity/test_performance.py +++ b/test/parity/test_performance.py @@ -218,11 +218,8 @@ def _time(module, n=5): samples.append(time.perf_counter() - start) return statistics.median(samples) - if hasattr(rs_czsc_module, "_rs_czsc"): - # rs_czsc 的入口在 ``rs_czsc._rs_czsc`` 子模块上 - rs_native = rs_czsc_module._rs_czsc - else: - rs_native = rs_czsc_module._native + # rs_czsc 的入口在 ``rs_czsc._rs_czsc`` 子模块上 + rs_native = rs_czsc_module._rs_czsc if hasattr(rs_czsc_module, "_rs_czsc") else rs_czsc_module._native samples_rs = [] samples_czsc = [] diff --git a/test/unit/test_core_parity.py b/test/unit/test_core_parity.py index 2ff121856..13c8837e1 100644 --- a/test/unit/test_core_parity.py +++ b/test/unit/test_core_parity.py @@ -108,7 +108,7 @@ def test_fx_marks_sequence_matches_baseline() -> None: actual = [str(fx.mark) for fx in obj.fx_list] assert actual == snap["fx_marks"], ( f"FX 方向序列出现漂移;首个差异下标 = " - f"{next((i for i, (a, b) in enumerate(zip(actual, snap['fx_marks'])) if a != b), 'len mismatch')}" + f"{next((i for i, (a, b) in enumerate(zip(actual, snap['fx_marks'], strict=False)) if a != b), 'len mismatch')}" ) @@ -120,7 +120,7 @@ def test_bi_directions_sequence_matches_baseline() -> None: actual = [str(bi.direction) for bi in obj.bi_list] assert actual == snap["bi_directions"], ( f"BI 方向序列出现漂移;首个差异下标 = " - f"{next((i for i, (a, b) in enumerate(zip(actual, snap['bi_directions'])) if a != b), 'len mismatch')}" + f"{next((i for i, (a, b) in enumerate(zip(actual, snap['bi_directions'], strict=False)) if a != b), 'len mismatch')}" ) From 6f29922d63b23df4d82735444fc81bb71a3e4f69 Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Fri, 8 May 2026 12:04:09 +0800 Subject: [PATCH 20/23] refactor: improve file handling and enhance code readability across multiple modules --- czsc/__init__.py | 8 ++++---- czsc/_utils/_df_convert.py | 1 - czsc/connectors/jq_connector.py | 8 ++++---- czsc/connectors/ts_connector.py | 4 ++-- czsc/fsa/base.py | 21 +++++++++++---------- czsc/fsa/im.py | 18 ++++++++++-------- czsc/sensors/utils.py | 4 +++- czsc/svc/__init__.pyi | 1 - czsc/svc/strategy.py | 6 ++++-- czsc/utils/__init__.py | 3 ++- czsc/utils/data/cache.py | 12 ++++++++---- czsc/utils/data/client.py | 3 ++- pyproject.toml | 4 ++++ 13 files changed, 54 insertions(+), 39 deletions(-) diff --git a/czsc/__init__.py b/czsc/__init__.py index 1caf420be..1b938bbcd 100644 --- a/czsc/__init__.py +++ b/czsc/__init__.py @@ -9,6 +9,10 @@ 作者: zengbin93 ,创建于 2019/10/29。 """ +# isort: skip_file +# 顶层包的 import 顺序经过手工设计以处理子包间的循环依赖, +# 不要让 isort/ruff 重排——会触发 partially-initialized module 错误。 + import sys as _sys # 第一批:纯薄壳子包(不会回头 import czsc 顶层符号)。 @@ -16,10 +20,6 @@ # 必须放到 wbt / .traders / .utils 之后再加载,避免循环 import。 from . import _native, connectors, envs, sensors, signals, traders, utils -# czsc.ta -> czsc._native.ta(Rust 实现);同时设置 sys.modules 以兼容 import czsc.ta -ta = _native.ta -_sys.modules["czsc.ta"] = _native.ta - # === 缠论核心数据类型与算法(来自 Rust 扩展 czsc._native)=== # === wbt(硬依赖,提供回测/绩效组件)=== from wbt import WeightBacktest, daily_performance, top_drawdowns diff --git a/czsc/_utils/_df_convert.py b/czsc/_utils/_df_convert.py index 3c885ef2a..04fc9602b 100644 --- a/czsc/_utils/_df_convert.py +++ b/czsc/_utils/_df_convert.py @@ -12,7 +12,6 @@ - PyArrow 与 Pandas 的版本组合需保持一致,否则可能在 Schema 推断时报错 """ - import pandas as pd import pyarrow as pa import pyarrow.ipc as ipc diff --git a/czsc/connectors/jq_connector.py b/czsc/connectors/jq_connector.py index 7298c8cb2..4107606b8 100644 --- a/czsc/connectors/jq_connector.py +++ b/czsc/connectors/jq_connector.py @@ -408,7 +408,9 @@ def get_kline_period( # 粗略估算:(自然日 * 5/7) 近似得到交易日数;超 1000 个交易日可能触发服务端限制 if (end_date - start_date).days * 5 / 7 > 1000: - warnings.warn(f"{end_date.date()} - {start_date.date()} 超过1000个交易日,K线获取可能失败,返回为0", stacklevel=2) + warnings.warn( + f"{end_date.date()} - {start_date.date()} 超过1000个交易日,K线获取可能失败,返回为0", stacklevel=2 + ) data = { "method": "get_price_period", @@ -646,9 +648,7 @@ def get_share_basic(symbol): msg += "2017~2020 财务变化\n\n" for k in cols: # 把 4 年同一指标横向拼接,便于一眼看出趋势 - msg += k + ":{} | {} | {} | {}\n".format( - *[f10[f"{year}{k}"] for year in ["2017", "2018", "2019", "2020"]] - ) + msg += k + ":{} | {} | {} | {}\n".format(*[f10[f"{year}{k}"] for year in ["2017", "2018", "2019", "2020"]]) f10["msg"] = msg return f10 diff --git a/czsc/connectors/ts_connector.py b/czsc/connectors/ts_connector.py index 11e117471..addb7515f 100644 --- a/czsc/connectors/ts_connector.py +++ b/czsc/connectors/ts_connector.py @@ -226,13 +226,13 @@ def pro_bar_minutes(ts_code, sdt, edt, freq="60min", asset="E", adj=None): latest_factor = factor.iloc[-1]["adj_factor"] adj_map = {row["trade_date"]: row["adj_factor"] for _, row in factor.iterrows()} for col in ["open", "close", "high", "low"]: - kline[col] = kline.apply(lambda x: x[col] * adj_map[x["trade_date"]] / latest_factor, axis=1) + kline[col] = kline.apply(lambda x, col=col: x[col] * adj_map[x["trade_date"]] / latest_factor, axis=1) if len(factor) > 0 and adj and adj == "hfq": # 后复权 = 当日收盘价 × 当日复权因子 adj_map = {row["trade_date"]: row["adj_factor"] for _, row in factor.iterrows()} for col in ["open", "close", "high", "low"]: - kline[col] = kline.apply(lambda x: x[col] * adj_map[x["trade_date"]], axis=1) + kline[col] = kline.apply(lambda x, col=col: x[col] * adj_map[x["trade_date"]], axis=1) if sdt: kline = kline[kline["trade_time"] >= pd.to_datetime(sdt)] diff --git a/czsc/fsa/base.py b/czsc/fsa/base.py index 284745bbc..2df193a5f 100644 --- a/czsc/fsa/base.py +++ b/czsc/fsa/base.py @@ -150,16 +150,17 @@ def upload_file(self, file_path, parent_node): file_size = os.path.getsize(file_path) url = "https://open.feishu.cn/open-apis/drive/v1/files/upload_all" - form = { - "file_name": os.path.basename(file_path), - "parent_type": "explorer", - "parent_node": parent_node, - "size": str(file_size), - "file": (open(file_path, "rb")), - } - multi_form = MultipartEncoder(form) - headers = {"Authorization": f"Bearer {self.get_access_token()}", "Content-Type": multi_form.content_type} - response = requests.request("POST", url, headers=headers, data=multi_form) + with open(file_path, "rb") as _fp: + form = { + "file_name": os.path.basename(file_path), + "parent_type": "explorer", + "parent_node": parent_node, + "size": str(file_size), + "file": _fp, + } + multi_form = MultipartEncoder(form) + headers = {"Authorization": f"Bearer {self.get_access_token()}", "Content-Type": multi_form.content_type} + response = requests.request("POST", url, headers=headers, data=multi_form) return response.json()["data"]["file_token"] def download_file(self, file_token, file_path): diff --git a/czsc/fsa/im.py b/czsc/fsa/im.py index 3e51012a4..afd2be6b7 100644 --- a/czsc/fsa/im.py +++ b/czsc/fsa/im.py @@ -51,10 +51,11 @@ def upload_im_file(self, file_path, file_type="stream"): from requests_toolbelt import MultipartEncoder url = "https://open.feishu.cn/open-apis/im/v1/files" - form = {"file_name": os.path.basename(file_path), "file_type": file_type, "file": (open(file_path, "rb"))} - multi_form = MultipartEncoder(form) - headers = {"Authorization": f"Bearer {self.get_access_token()}", "Content-Type": multi_form.content_type} - response = requests.request("POST", url, headers=headers, data=multi_form) + with open(file_path, "rb") as _fp: + form = {"file_name": os.path.basename(file_path), "file_type": file_type, "file": _fp} + multi_form = MultipartEncoder(form) + headers = {"Authorization": f"Bearer {self.get_access_token()}", "Content-Type": multi_form.content_type} + response = requests.request("POST", url, headers=headers, data=multi_form) return response.json()["data"]["file_key"] def upload_im_image(self, image_path, image_type="message"): @@ -72,10 +73,11 @@ def upload_im_image(self, image_path, image_type="message"): from requests_toolbelt import MultipartEncoder url = "https://open.feishu.cn/open-apis/im/v1/images" - form = {"image_type": image_type, "image": (open(image_path, "rb"))} - multi_form = MultipartEncoder(form) - headers = {"Authorization": f"Bearer {self.get_access_token()}", "Content-Type": multi_form.content_type} - response = requests.request("POST", url, headers=headers, data=multi_form) + with open(image_path, "rb") as _fp: + form = {"image_type": image_type, "image": _fp} + multi_form = MultipartEncoder(form) + headers = {"Authorization": f"Bearer {self.get_access_token()}", "Content-Type": multi_form.content_type} + response = requests.request("POST", url, headers=headers, data=multi_form) return response.json()["data"]["image_key"] def send(self, payload, receive_id_type="open_id"): diff --git a/czsc/sensors/utils.py b/czsc/sensors/utils.py index 0afa941a5..cbc3f10cd 100644 --- a/czsc/sensors/utils.py +++ b/czsc/sensors/utils.py @@ -112,7 +112,9 @@ def holds_concepts_effect(holds: pd.DataFrame, concepts: dict, top_n=20, min_n=3 dt_key_concepts[dt] = key_concepts # 计算在密集概念中出现次数超过min_n的股票 - dfg["强势概念"] = dfg["概念板块"].apply(lambda x: ",".join(set(x) & set(key_concepts))) + dfg["强势概念"] = dfg["概念板块"].apply( + lambda x, key_concepts=key_concepts: ",".join(set(x) & set(key_concepts)) + ) sel = dfg[dfg["强势概念"].apply(lambda x: len(x.split(",")) >= min_n)] new_holds.append(sel) diff --git a/czsc/svc/__init__.pyi b/czsc/svc/__init__.pyi index 676efcd8e..3f10d4864 100644 --- a/czsc/svc/__init__.pyi +++ b/czsc/svc/__init__.pyi @@ -221,4 +221,3 @@ __all__ = [ "weight_backtest_form", "code_editor_form", ] - diff --git a/czsc/svc/strategy.py b/czsc/svc/strategy.py index cafa42fbc..297f1e455 100644 --- a/czsc/svc/strategy.py +++ b/czsc/svc/strategy.py @@ -154,8 +154,10 @@ def show_czsc_trader(trader, max_k_num=300, key=None, **kwargs): open_ops = [czsc.Operate.LO, czsc.Operate.SO] # 开仓用上三角,平仓用下三角;颜色区分多空 - bs_df["tag"] = bs_df["op"].apply(lambda x: "triangle-up" if x in open_ops else "triangle-down") - bs_df["color"] = bs_df["op"].apply(lambda x: "red" if x in open_ops else "white") + bs_df["tag"] = bs_df["op"].apply( + lambda x, open_ops=open_ops: "triangle-up" if x in open_ops else "triangle-down" + ) + bs_df["color"] = bs_df["op"].apply(lambda x, open_ops=open_ops: "red" if x in open_ops else "white") kline.add_scatter_indicator( bs_df["dt"], diff --git a/czsc/utils/__init__.py b/czsc/utils/__init__.py index 87827ad97..a22c5ffb1 100644 --- a/czsc/utils/__init__.py +++ b/czsc/utils/__init__.py @@ -212,7 +212,8 @@ def get_py_namespace(file_py: str, keys: list = None) -> dict: allowed_prefixes = [os.path.abspath("czsc/strategies"), os.path.abspath("czsc/signals")] if not any(file_py.startswith(p) for p in allowed_prefixes): raise ValueError(f"文件路径 {file_py} 不在白名单目录内") - text = open(file_py, encoding="utf-8").read() + with open(file_py, encoding="utf-8") as _f: + text = _f.read() code = compile(text, file_py, "exec") namespace = {"file_py": file_py, "file_name": os.path.basename(file_py).split(".")[0]} exec(code, namespace) diff --git a/czsc/utils/data/cache.py b/czsc/utils/data/cache.py index c725a5287..ec4fba2ae 100644 --- a/czsc/utils/data/cache.py +++ b/czsc/utils/data/cache.py @@ -94,9 +94,11 @@ def get(self, k: str, suffix: str = "pkl") -> Any: if suffix == "pkl": import dill - res = dill.load(open(file, "rb")) + with open(file, "rb") as _f: + res = dill.load(_f) elif suffix == "json": - res = json.load(open(file, encoding="utf-8")) + with open(file, encoding="utf-8") as _f: + res = json.load(_f) elif suffix == "txt": res = file.read_text(encoding="utf-8") elif suffix == "csv": @@ -125,12 +127,14 @@ def set(self, k: str, v: Any, suffix: str = "pkl"): if suffix == "pkl": import dill - dill.dump(v, open(file, "wb")) + with open(file, "wb") as _f: + dill.dump(v, _f) elif suffix == "json": if not isinstance(v, dict): raise ValueError("suffix json only support dict") - json.dump(v, open(file, "w", encoding="utf-8"), ensure_ascii=False, indent=4) + with open(file, "w", encoding="utf-8") as _f: + json.dump(v, _f, ensure_ascii=False, indent=4) elif suffix == "txt": if not isinstance(v, str): diff --git a/czsc/utils/data/client.py b/czsc/utils/data/client.py index 532dccb78..02a5f12cb 100644 --- a/czsc/utils/data/client.py +++ b/czsc/utils/data/client.py @@ -54,8 +54,9 @@ class DataClient: __version__ = "V250719" _cache_lock = threading.Lock() # 进程内线程锁,防止并发冲突 + @staticmethod @lru_cache(maxsize=128) - def _get_cache_key(self, req_params_str: str) -> str: + def _get_cache_key(req_params_str: str) -> str: """缓存哈希计算,避免重复计算""" return hashlib.md5(req_params_str.encode("utf-8")).hexdigest().upper()[:8] diff --git a/pyproject.toml b/pyproject.toml index 3842be606..620147802 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,6 +145,10 @@ select = ["E", "F", "I", "UP", "B", "SIM", "C4"] ignore = ["E501", "SIM112"] extend-safe-fixes = ["UP"] +[tool.ruff.lint.per-file-ignores] +# Auto-generated by pyo3-stub-gen; regeneration overwrites manual edits. +"czsc/_native.pyi" = ["F811"] + [tool.ruff.lint.isort] known-first-party = ["czsc"] From bb955a113fd4cc130afe7376162e852baa95c0e7 Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Fri, 8 May 2026 12:10:45 +0800 Subject: [PATCH 21/23] Refactor tests and improve code readability - Reordered imports in `test_fx.rs` for consistency. - Simplified cloning in `test_mark_direction.rs` by directly assigning values. - Enhanced readability of assertions in `test_signal.rs` by formatting multi-line assertions. - Improved formatting and readability in `test_zs.rs` by aligning parameters and breaking long lines. - Cleaned up function signatures in `lib.rs` and `signals_dispatcher.rs` for better clarity. - Removed unnecessary `into_iter()` calls in `signals_dispatcher.rs` and `trader/api.rs`. - Streamlined parameter handling in `trader/czsc_signals.rs` for better readability. - Enhanced readability in `trader/czsc_trader.rs` by breaking long lines and improving indentation. - Improved error handling and readability in `czsc-signals/src` files by using more concise patterns. - Cleaned up unnecessary braces and improved formatting in various signal functions. - Enhanced readability in `czsc-ta/src/python.rs` by breaking long function signatures into multiple lines. - Improved assertions in tests for better clarity and consistency. - Refactored `bar_generator.rs` and `mod.rs` for better readability and consistency in function signatures. --- crates/czsc-core/src/objects/event.rs | 8 +-- crates/czsc-core/src/objects/mod.rs | 18 +++--- crates/czsc-core/src/objects/operate.rs | 5 +- crates/czsc-core/src/objects/position.rs | 13 ++-- crates/czsc-core/src/objects/signal.rs | 5 +- crates/czsc-core/tests/test_analyze_utils.rs | 4 +- crates/czsc-core/tests/test_bi.rs | 15 ++++- crates/czsc-core/tests/test_czsc_analyzer.rs | 15 ++--- crates/czsc-core/tests/test_event.rs | 11 ++-- crates/czsc-core/tests/test_fx.rs | 2 +- crates/czsc-core/tests/test_mark_direction.rs | 2 +- crates/czsc-core/tests/test_signal.rs | 5 +- crates/czsc-core/tests/test_zs.rs | 64 +++++++++++++++---- crates/czsc-python/src/lib.rs | 19 ++++-- crates/czsc-python/src/signals_dispatcher.rs | 21 +++--- crates/czsc-python/src/trader/api.rs | 2 +- crates/czsc-python/src/trader/czsc_signals.rs | 31 +++++---- crates/czsc-python/src/trader/czsc_trader.rs | 56 +++++++++------- crates/czsc-python/src/trader/research.rs | 2 +- .../czsc-signal-macros/tests/test_export.rs | 3 +- crates/czsc-signals/src/ang.rs | 5 +- crates/czsc-signals/src/bar.rs | 5 +- crates/czsc-signals/src/cxt.rs | 18 ++---- crates/czsc-signals/src/jcc.rs | 5 +- crates/czsc-signals/src/obv.rs | 10 ++- crates/czsc-signals/src/params.rs | 16 ++--- crates/czsc-signals/src/tas.rs | 16 ++--- crates/czsc-signals/src/utils/sig.rs | 22 +++---- crates/czsc-signals/src/zdy_trader.rs | 10 ++- crates/czsc-ta/src/python.rs | 24 +++++-- crates/czsc-ta/tests/test_pure.rs | 22 +++++-- crates/czsc-utils/src/bar_generator.rs | 2 +- crates/czsc-utils/src/python/mod.rs | 6 +- crates/czsc-utils/tests/test_bar_generator.rs | 8 ++- crates/error-macros/tests/test_derive.rs | 5 +- crates/error-support/tests/test_chain.rs | 7 +- 36 files changed, 279 insertions(+), 203 deletions(-) diff --git a/crates/czsc-core/src/objects/event.rs b/crates/czsc-core/src/objects/event.rs index c7bbf6d69..6f0c90c74 100644 --- a/crates/czsc-core/src/objects/event.rs +++ b/crates/czsc-core/src/objects/event.rs @@ -13,6 +13,10 @@ use std::str::FromStr; use super::operate::Operate; use super::signal::{ANY, Signal}; +#[cfg(feature = "python")] +use super::operate::PyOperate; +#[cfg(feature = "python")] +use super::signal::PySignal; #[cfg(feature = "python")] use pyo3::exceptions::PyValueError; #[cfg(feature = "python")] @@ -20,10 +24,6 @@ use pyo3::prelude::*; #[cfg(feature = "python")] use pyo3::types::{PyDict, PyDictMethods}; #[cfg(feature = "python")] -use super::operate::PyOperate; -#[cfg(feature = "python")] -use super::signal::PySignal; -#[cfg(feature = "python")] use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; #[derive(Serialize, Deserialize, Debug, Clone)] diff --git a/crates/czsc-core/src/objects/mod.rs b/crates/czsc-core/src/objects/mod.rs index b81d0d935..64709c181 100644 --- a/crates/czsc-core/src/objects/mod.rs +++ b/crates/czsc-core/src/objects/mod.rs @@ -3,18 +3,18 @@ //! Migrated from rs-czsc 47ef6efa per docs/MIGRATION_NOTES.md §1. Submodules //! are added incrementally as Phase D sub-loops complete. -pub mod errors; -pub mod market; -pub mod freq; pub mod bar; -pub mod mark; +pub mod bi; pub mod direction; -pub mod fx; +pub mod errors; +pub mod event; pub mod fake_bi; -pub mod bi; -pub mod zs; +pub mod freq; +pub mod fx; +pub mod mark; +pub mod market; pub mod operate; -pub mod signal; -pub mod event; pub mod position; +pub mod signal; pub mod state; +pub mod zs; diff --git a/crates/czsc-core/src/objects/operate.rs b/crates/czsc-core/src/objects/operate.rs index 4d9307f5e..edb2a575a 100644 --- a/crates/czsc-core/src/objects/operate.rs +++ b/crates/czsc-core/src/objects/operate.rs @@ -15,10 +15,7 @@ use strum_macros::{AsRefStr, EnumIter, EnumString}; #[cfg(feature = "python")] use pyo3::prelude::*; #[cfg(feature = "python")] -use pyo3::{ - Bound, FromPyObject, PyResult, Python, exceptions::PyValueError, - types::PyAnyMethods, -}; +use pyo3::{Bound, FromPyObject, PyResult, Python, exceptions::PyValueError, types::PyAnyMethods}; #[cfg(feature = "python")] use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; diff --git a/crates/czsc-core/src/objects/position.rs b/crates/czsc-core/src/objects/position.rs index 1a7a5a58b..2a8250552 100644 --- a/crates/czsc-core/src/objects/position.rs +++ b/crates/czsc-core/src/objects/position.rs @@ -20,16 +20,16 @@ use super::event::Event; use super::operate::Operate; use super::signal::{ANY, Signal}; +#[cfg(feature = "python")] +use super::event::PyEvent; +#[cfg(feature = "python")] +use super::signal::PySignal; #[cfg(feature = "python")] use pyo3::exceptions::PyValueError; #[cfg(feature = "python")] use pyo3::prelude::*; #[cfg(feature = "python")] use pyo3::types::PyBytes; -#[cfg(feature = "python")] -use super::event::PyEvent; -#[cfg(feature = "python")] -use super::signal::PySignal; /// 解析 operate 字符串,支持英文缩写和中文名称 fn parse_operate(s: &str) -> Result { @@ -1111,7 +1111,10 @@ impl PyLiteBar { /// Python可见的Position包装器 #[cfg_attr(feature = "python", gen_stub_pyclass)] -#[cfg_attr(feature = "python", pyclass(name = "Position", module = "czsc._native"))] +#[cfg_attr( + feature = "python", + pyclass(name = "Position", module = "czsc._native") +)] #[derive(Debug, Clone)] pub struct PyPosition { pub inner: Position, diff --git a/crates/czsc-core/src/objects/signal.rs b/crates/czsc-core/src/objects/signal.rs index f50bd7852..6eeae9374 100644 --- a/crates/czsc-core/src/objects/signal.rs +++ b/crates/czsc-core/src/objects/signal.rs @@ -524,7 +524,10 @@ impl PySignal { /// Python可见的ParsedSignalDoc包装器 #[cfg_attr(feature = "python", gen_stub_pyclass)] -#[cfg_attr(feature = "python", pyclass(name = "ParsedSignalDoc", module = "czsc._native"))] +#[cfg_attr( + feature = "python", + pyclass(name = "ParsedSignalDoc", module = "czsc._native") +)] #[derive(Debug, Clone)] pub struct PyParsedSignalDoc { pub(crate) inner: ParsedSignalDoc, diff --git a/crates/czsc-core/tests/test_analyze_utils.rs b/crates/czsc-core/tests/test_analyze_utils.rs index c86218f74..441c75af0 100644 --- a/crates/czsc-core/tests/test_analyze_utils.rs +++ b/crates/czsc-core/tests/test_analyze_utils.rs @@ -76,7 +76,9 @@ fn check_fxs_extracts_fx_from_sequence() { #[test] fn check_bi_returns_tuple_with_remainder() { - let bars: Vec = (0..6).map(|i| nb(i + 1, 10.0 + i as f64, 9.0 + i as f64)).collect(); + let bars: Vec = (0..6) + .map(|i| nb(i + 1, 10.0 + i as f64, 9.0 + i as f64)) + .collect(); let (bi, remainder) = check_bi(&bars); // The function signature contract: always returns (Option, &[NewBar]) let _ = bi; diff --git a/crates/czsc-core/tests/test_bi.rs b/crates/czsc-core/tests/test_bi.rs index ab7a4d08b..7607e099c 100644 --- a/crates/czsc-core/tests/test_bi.rs +++ b/crates/czsc-core/tests/test_bi.rs @@ -38,7 +38,11 @@ fn fx(ts: i64, mark: Mark, level: f64) -> FX { .mark(mark) .high(k2.high) .low(k2.low) - .fx(if matches!(level, l if l > 5.0) { k2.high } else { k2.low }) + .fx(if matches!(level, l if l > 5.0) { + k2.high + } else { + k2.low + }) .elements(vec![k1, k2, k3]) .build() .unwrap() @@ -49,7 +53,14 @@ fn sample_bi_up() -> BI { let fx_a = fx(1_700_000_000, Mark::D, 9.0); let fx_b = fx(1_700_007_200, Mark::G, 12.0); let bars: Vec = (0..5) - .map(|i| nb(1_700_000_000 + i * 1800, 11.0 + i as f64 * 0.2, 9.5 + i as f64 * 0.2, 100.0)) + .map(|i| { + nb( + 1_700_000_000 + i * 1800, + 11.0 + i as f64 * 0.2, + 9.5 + i as f64 * 0.2, + 100.0, + ) + }) .collect(); BIBuilder::default() .symbol(Arc::::from("000001")) diff --git a/crates/czsc-core/tests/test_czsc_analyzer.rs b/crates/czsc-core/tests/test_czsc_analyzer.rs index 4867386ce..6c0cbe07a 100644 --- a/crates/czsc-core/tests/test_czsc_analyzer.rs +++ b/crates/czsc-core/tests/test_czsc_analyzer.rs @@ -77,18 +77,15 @@ fn fx_and_bi_lists_are_consistent_with_zigzag() { fn update_bar_appends_incrementally() { let bars = synthetic_zigzag(30); let mut c = CZSC::new(bars, 50); - let extra = rb( - 1_700_000_000 + 30 * 1800, - 102.0, - 103.0, - 104.0, - 101.0, - ); + let extra = rb(1_700_000_000 + 30 * 1800, 102.0, 103.0, 104.0, 101.0); c.update_bar(extra); assert_eq!(c.freq, Freq::F30); // bars_raw monotonically grows (modulo the analyzer's internal pruning) - assert!(c.bars_raw.iter().any(|b| b.dt - == Utc.timestamp_opt(1_700_000_000 + 30 * 1800, 0).unwrap())); + assert!( + c.bars_raw + .iter() + .any(|b| b.dt == Utc.timestamp_opt(1_700_000_000 + 30 * 1800, 0).unwrap()) + ); } #[test] diff --git a/crates/czsc-core/tests/test_event.rs b/crates/czsc-core/tests/test_event.rs index ec9d79f35..064eaf7b5 100644 --- a/crates/czsc-core/tests/test_event.rs +++ b/crates/czsc-core/tests/test_event.rs @@ -11,9 +11,7 @@ use czsc_core::objects::signal::Signal; fn sample_event() -> Event { Event { operate: Operate::LO, - signals_all: vec![ - Signal::from_str("30分钟_D1_前高_看多_强_任意_0").unwrap(), - ], + signals_all: vec![Signal::from_str("30分钟_D1_前高_看多_强_任意_0").unwrap()], signals_any: vec![ Signal::from_str("日线_D1_趋势_看多_中_任意_0").unwrap(), Signal::from_str("日线_D2_趋势_看多_弱_任意_0").unwrap(), @@ -38,8 +36,11 @@ fn compute_sha8_returns_4_hex_chars() { let e = sample_event(); let h = e.compute_sha8(); assert_eq!(h.len(), 4, "sha8 prefix must be 4 chars, got {h:?}"); - assert!(h.chars().all(|c| c.is_ascii_hexdigit() && c.is_ascii_uppercase() || c.is_ascii_digit()), - "expected uppercase hex, got {h:?}"); + assert!( + h.chars() + .all(|c| c.is_ascii_hexdigit() && c.is_ascii_uppercase() || c.is_ascii_digit()), + "expected uppercase hex, got {h:?}" + ); } #[test] diff --git a/crates/czsc-core/tests/test_fx.rs b/crates/czsc-core/tests/test_fx.rs index dc2b2a9ee..57bbe3327 100644 --- a/crates/czsc-core/tests/test_fx.rs +++ b/crates/czsc-core/tests/test_fx.rs @@ -6,8 +6,8 @@ use std::sync::Arc; use chrono::{TimeZone, Utc}; use czsc_core::objects::bar::{NewBar, NewBarBuilder}; -use czsc_core::objects::fx::{FX, FXBuilder}; use czsc_core::objects::freq::Freq; +use czsc_core::objects::fx::{FX, FXBuilder}; use czsc_core::objects::mark::Mark; fn nb(ts: i64, high: f64, low: f64, vol: f64) -> NewBar { diff --git a/crates/czsc-core/tests/test_mark_direction.rs b/crates/czsc-core/tests/test_mark_direction.rs index b5122da0b..8743944af 100644 --- a/crates/czsc-core/tests/test_mark_direction.rs +++ b/crates/czsc-core/tests/test_mark_direction.rs @@ -38,7 +38,7 @@ fn direction_display_round_trips() { #[test] fn direction_equality_and_clone() { let a = Direction::Up; - let b = a.clone(); + let b = a; assert_eq!(a, b); assert_ne!(Direction::Up, Direction::Down); } diff --git a/crates/czsc-core/tests/test_signal.rs b/crates/czsc-core/tests/test_signal.rs index bb0052376..3be8f1361 100644 --- a/crates/czsc-core/tests/test_signal.rs +++ b/crates/czsc-core/tests/test_signal.rs @@ -50,7 +50,10 @@ fn is_match_obeys_score_and_wildcards() { let s = Signal::from_str("30分钟_D1_前高_看多_强_任意_50").unwrap(); let mut dict = HashMap::new(); dict.insert("30分钟_D1_前高".to_string(), "看多_强_中_60".to_string()); - assert!(s.is_match(&dict), "score 60 >= 50 with v3 wildcard should match"); + assert!( + s.is_match(&dict), + "score 60 >= 50 with v3 wildcard should match" + ); let mut low_score = HashMap::new(); low_score.insert("30分钟_D1_前高".to_string(), "看多_强_中_40".to_string()); diff --git a/crates/czsc-core/tests/test_zs.rs b/crates/czsc-core/tests/test_zs.rs index 246b3338c..31b954a73 100644 --- a/crates/czsc-core/tests/test_zs.rs +++ b/crates/czsc-core/tests/test_zs.rs @@ -40,21 +40,35 @@ fn fx(ts: i64, mark: Mark, level: f64) -> FX { .mark(mark_clone) .high(k2.high) .low(k2.low) - .fx(if matches!(mark, Mark::G) { k2.high } else { k2.low }) + .fx(if matches!(mark, Mark::G) { + k2.high + } else { + k2.low + }) .elements(vec![k1, k2, k3]) .build() .unwrap() } -fn make_bi(ts_a: i64, mark_a: Mark, level_a: f64, - ts_b: i64, mark_b: Mark, level_b: f64, - direction: Direction) -> BI { +fn make_bi( + ts_a: i64, + mark_a: Mark, + level_a: f64, + ts_b: i64, + mark_b: Mark, + level_b: f64, + direction: Direction, +) -> BI { let fx_a = fx(ts_a, mark_a, level_a); let fx_b = fx(ts_b, mark_b, level_b); // bars span — endpoints determine high/low let bars = vec![ nb(ts_a, level_a + 0.5, level_a - 0.5), - nb((ts_a + ts_b) / 2, (level_a + level_b) / 2.0 + 0.5, (level_a + level_b) / 2.0 - 0.5), + nb( + (ts_a + ts_b) / 2, + (level_a + level_b) / 2.0 + 0.5, + (level_a + level_b) / 2.0 - 0.5, + ), nb(ts_b, level_b + 0.5, level_b - 0.5), ]; BIBuilder::default() @@ -70,12 +84,33 @@ fn make_bi(ts_a: i64, mark_a: Mark, level_a: f64, fn sample_zs() -> ZS { // 3-bi center: down 12 -> 9, up 9 -> 11, down 11 -> 9.5 - let bi1 = make_bi(1_700_000_000, Mark::G, 12.0, - 1_700_001_800, Mark::D, 9.0, Direction::Down); - let bi2 = make_bi(1_700_001_800, Mark::D, 9.0, - 1_700_003_600, Mark::G, 11.0, Direction::Up); - let bi3 = make_bi(1_700_003_600, Mark::G, 11.0, - 1_700_005_400, Mark::D, 9.5, Direction::Down); + let bi1 = make_bi( + 1_700_000_000, + Mark::G, + 12.0, + 1_700_001_800, + Mark::D, + 9.0, + Direction::Down, + ); + let bi2 = make_bi( + 1_700_001_800, + Mark::D, + 9.0, + 1_700_003_600, + Mark::G, + 11.0, + Direction::Up, + ); + let bi3 = make_bi( + 1_700_003_600, + Mark::G, + 11.0, + 1_700_005_400, + Mark::D, + 9.5, + Direction::Down, + ); ZS::new(vec![bi1, bi2, bi3]) } @@ -98,7 +133,12 @@ fn zg_zd_within_first_three_bis() { fn zz_is_midpoint_of_zg_zd() { let zs = sample_zs(); let mid = (zs.zg + zs.zd) / 2.0; - assert!((zs.zz - mid).abs() < 1e-9, "zz {} should equal mid {}", zs.zz, mid); + assert!( + (zs.zz - mid).abs() < 1e-9, + "zz {} should equal mid {}", + zs.zz, + mid + ); } #[test] diff --git a/crates/czsc-python/src/lib.rs b/crates/czsc-python/src/lib.rs index 6a73c52de..28fe81fc4 100644 --- a/crates/czsc-python/src/lib.rs +++ b/crates/czsc-python/src/lib.rs @@ -21,14 +21,15 @@ fn _native(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { // czsc-signals contributes `SignalDescriptor` entries via // `inventory::collect!`. The dummy iterator forces the crate // into the final cdylib so the constructors run on import. - let _signals_count = inventory::iter::() - .into_iter() - .count(); + let _signals_count = inventory::iter::().count(); // Trader surface — CzscTrader, CzscSignals, generate_czsc_signals. m.add_class::()?; m.add_class::()?; - m.add_function(wrap_pyfunction!(trader::generate::generate_czsc_signals, m)?)?; + m.add_function(wrap_pyfunction!( + trader::generate::generate_czsc_signals, + m + )?)?; // Research / optimize entrypoints (mirrors rs_czsc/python/src/lib.rs). // These are the heavy-lift functions that strategies.py / @@ -42,8 +43,14 @@ fn _native(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(trader::research::run_research, m)?)?; m.add_function(wrap_pyfunction!(trader::research::run_replay, m)?)?; m.add_function(wrap_pyfunction!(trader::research::run_optimize_batch, m)?)?; - m.add_function(wrap_pyfunction!(trader::research::build_open_optim_positions, m)?)?; - m.add_function(wrap_pyfunction!(trader::research::build_exit_optim_positions, m)?)?; + m.add_function(wrap_pyfunction!( + trader::research::build_open_optim_positions, + m + )?)?; + m.add_function(wrap_pyfunction!( + trader::research::build_exit_optim_positions, + m + )?)?; // czsc._native.signals namespace + per-category sub-modules // (bar / cxt / tas / vol / pressure / obv / cvolp). The dispatcher diff --git a/crates/czsc-python/src/signals_dispatcher.rs b/crates/czsc-python/src/signals_dispatcher.rs index 476c4a4e6..261b8dec0 100644 --- a/crates/czsc-python/src/signals_dispatcher.rs +++ b/crates/czsc-python/src/signals_dispatcher.rs @@ -68,8 +68,8 @@ pub fn call_signal( czsc: &CZSC, params: Option<&Bound<'_, PyDict>>, ) -> PyResult> { - let descriptor = lookup(name) - .ok_or_else(|| PyKeyError::new_err(format!("unknown signal: {name}")))?; + let descriptor = + lookup(name).ok_or_else(|| PyKeyError::new_err(format!("unknown signal: {name}")))?; let kline_func = match descriptor.func_ref { SignalFnRef::Kline(f) => f, @@ -96,7 +96,6 @@ pub fn call_signal( #[pyo3(signature = (category=None))] pub fn list_signal_names(category: Option<&str>) -> Vec { let mut out: Vec = inventory::iter::() - .into_iter() .filter(|d| matches!(d.func_ref, SignalFnRef::Kline(_))) .filter(|d| match category { Some(c) => name_prefix(d.name).map(|p| p == c).unwrap_or(false), @@ -130,7 +129,11 @@ pub fn get_signal_category(name: &str) -> Option { /// and ``czsc._native.signals`` (submodule). The submodule entries /// give design-doc §3.3 the path ``from czsc._native.signals import /// call_signal``. -pub fn register(py: Python<'_>, m: &Bound<'_, PyModule>, signals_mod: &Bound<'_, PyModule>) -> PyResult<()> { +pub fn register( + py: Python<'_>, + m: &Bound<'_, PyModule>, + signals_mod: &Bound<'_, PyModule>, +) -> PyResult<()> { use pyo3::wrap_pyfunction; m.add_function(wrap_pyfunction!(call_signal, m)?)?; @@ -151,15 +154,7 @@ pub fn register(py: Python<'_>, m: &Bound<'_, PyModule>, signals_mod: &Bound<'_, // // The Python-side `czsc/signals/.py` shim layers __getattr__ // on top of these to expose individual functions. - let categories = [ - "bar", - "cxt", - "tas", - "vol", - "pressure", - "obv", - "cvolp", - ]; + let categories = ["bar", "cxt", "tas", "vol", "pressure", "obv", "cvolp"]; let sys = py.import("sys")?; let py_modules = sys.getattr("modules")?; for cat in categories { diff --git a/crates/czsc-python/src/trader/api.rs b/crates/czsc-python/src/trader/api.rs index d44fc9126..9b73d6c95 100644 --- a/crates/czsc-python/src/trader/api.rs +++ b/crates/czsc-python/src/trader/api.rs @@ -13,12 +13,12 @@ use czsc_core::analyze::utils::format_standard_kline; use czsc_core::objects::bar::RawBar; use czsc_core::objects::freq::Freq; use czsc_core::objects::position::Position; +use czsc_signals::registry::list_all_signals as list_all_registered_signals; use czsc_trader::engine_v2::{ExecutionPlan, ExecutionPlanInput, UnifiedExecEngine}; use czsc_trader::optimize::{ get_exit_optim_positions, get_open_optim_positions, symbols_optim_parallel, }; use czsc_trader::signals::sig_parse::{SignalConfig, get_signals_config, get_signals_freqs}; -use czsc_signals::registry::list_all_signals as list_all_registered_signals; use polars::prelude::*; fn write_df_parquet(path: &Path, mut df: DataFrame) -> PyResult<()> { diff --git a/crates/czsc-python/src/trader/czsc_signals.rs b/crates/czsc-python/src/trader/czsc_signals.rs index a49f4b733..cc7b88ad9 100644 --- a/crates/czsc-python/src/trader/czsc_signals.rs +++ b/crates/czsc-python/src/trader/czsc_signals.rs @@ -34,13 +34,14 @@ pub(crate) fn parse_signals_config(configs: &Bound) -> PyResult() { - for (k, v) in params_dict.iter() { - let key: String = k.extract()?; - let val = py_to_serde_value(&v)?; - params.insert(key, val); - } - } + && let Ok(params_dict) = params_obj.downcast::() + { + for (k, v) in params_dict.iter() { + let key: String = k.extract()?; + let val = py_to_serde_value(&v)?; + params.insert(key, val); + } + } // 也支持 flat params:dict 中除 name/freq/params 以外的 key 直接作为参数 for (k, v) in dict.iter() { @@ -166,11 +167,12 @@ impl PyCzscSignals { #[getter] fn end_dt(&self, py: Python) -> PyResult> { if let Some(dt_str) = self.inner.s.get("dt") - && let Ok(dt) = chrono::DateTime::parse_from_rfc3339(dt_str) { - let utc_dt = dt.with_timezone(&chrono::Utc); - let timestamp = create_naive_pandas_timestamp(py, utc_dt)?; - return Ok(Some(timestamp)); - } + && let Ok(dt) = chrono::DateTime::parse_from_rfc3339(dt_str) + { + let utc_dt = dt.with_timezone(&chrono::Utc); + let timestamp = create_naive_pandas_timestamp(py, utc_dt)?; + return Ok(Some(timestamp)); + } Ok(None) } @@ -247,10 +249,7 @@ impl PyCzscSignals { /// Helper: convert `Vec` back to a Python ``list[dict]`` /// shaped exactly like ``parse_signals_config`` expects, so /// ``__reduce__`` -> ``__new__`` round-trips cleanly. -pub(crate) fn signal_configs_to_pylist( - py: Python, - configs: &[SignalConfig], -) -> PyResult { +pub(crate) fn signal_configs_to_pylist(py: Python, configs: &[SignalConfig]) -> PyResult { let list = PyList::empty(py); for cfg in configs { let dict = PyDict::new(py); diff --git a/crates/czsc-python/src/trader/czsc_trader.rs b/crates/czsc-python/src/trader/czsc_trader.rs index 69340d676..df0377b08 100644 --- a/crates/czsc-python/src/trader/czsc_trader.rs +++ b/crates/czsc-python/src/trader/czsc_trader.rs @@ -30,14 +30,16 @@ fn extract_position(_py: Python, obj: &Bound) -> PyResult { } // 尝试从 _inner 属性提取 if let Ok(inner_attr) = obj.getattr("_inner") - && let Ok(py_pos) = inner_attr.extract::() { - return Ok(py_pos.inner); - } + && let Ok(py_pos) = inner_attr.extract::() + { + return Ok(py_pos.inner); + } // 尝试从 inner 属性提取 if let Ok(inner_attr) = obj.getattr("inner") - && let Ok(py_pos) = inner_attr.extract::() { - return Ok(py_pos.inner); - } + && let Ok(py_pos) = inner_attr.extract::() + { + return Ok(py_pos.inner); + } Err(PyValueError::new_err( "positions 中的元素必须是 Position 或有 _inner/inner 属性的对象", )) @@ -138,11 +140,12 @@ impl PyCzscTrader { #[getter] fn end_dt(&self, py: Python) -> PyResult> { if let Some(dt_str) = self.inner.signals.s.get("dt") - && let Ok(dt) = chrono::DateTime::parse_from_rfc3339(dt_str) { - let utc_dt = dt.with_timezone(&chrono::Utc); - let timestamp = create_naive_pandas_timestamp(py, utc_dt)?; - return Ok(Some(timestamp)); - } + && let Ok(dt) = chrono::DateTime::parse_from_rfc3339(dt_str) + { + let utc_dt = dt.with_timezone(&chrono::Utc); + let timestamp = create_naive_pandas_timestamp(py, utc_dt)?; + return Ok(Some(timestamp)); + } Ok(None) } @@ -351,7 +354,13 @@ impl PyCzscTrader { let configs_list = super::czsc_signals::signal_configs_to_pylist(py, &self.signals_config)?; let constructor = py.get_type::(); - let args = (bg_clone, positions_list, configs_list, self.ensemble_method.clone()).into_pyobject(py)?; + let args = ( + bg_clone, + positions_list, + configs_list, + self.ensemble_method.clone(), + ) + .into_pyobject(py)?; let result = (constructor, args).into_pyobject(py)?; Ok(result.into_any().unbind()) } @@ -385,18 +394,19 @@ fn parse_dt_from_pyobj(obj: &Bound) -> PyResult> { // 尝试 pandas Timestamp: 调用 .isoformat() 或 str() if let Ok(iso) = obj.call_method0("isoformat") - && let Ok(s) = iso.extract::() { - if let Ok(dt) = DateTime::parse_from_rfc3339(&s) { - return Ok(dt); - } - // pandas isoformat 可能不带时区 - if let Ok(naive) = chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S") { - return Ok(DateTime::from_naive_utc_and_offset( - naive, - FixedOffset::east_opt(0).unwrap(), - )); - } + && let Ok(s) = iso.extract::() + { + if let Ok(dt) = DateTime::parse_from_rfc3339(&s) { + return Ok(dt); } + // pandas isoformat 可能不带时区 + if let Ok(naive) = chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S") { + return Ok(DateTime::from_naive_utc_and_offset( + naive, + FixedOffset::east_opt(0).unwrap(), + )); + } + } // 最后降级:str(obj) let s = obj.str()?.to_string(); diff --git a/crates/czsc-python/src/trader/research.rs b/crates/czsc-python/src/trader/research.rs index 5c1edc6ec..b4e98b67d 100644 --- a/crates/czsc-python/src/trader/research.rs +++ b/crates/czsc-python/src/trader/research.rs @@ -5,10 +5,10 @@ use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc}; use czsc_core::analyze::utils::format_standard_kline; use czsc_core::objects::freq::Freq; use czsc_core::objects::position::Position; +use czsc_signals::registry::{SIGNAL_REGISTRY, TRADER_SIGNAL_REGISTRY}; use czsc_trader::engine_v2::{ExecutionPlan, ExecutionPlanInput, UnifiedExecEngine}; use czsc_trader::optimize::{get_exit_optim_positions, get_open_optim_positions}; use czsc_trader::signals::sig_parse::SignalConfig; -use czsc_signals::registry::{SIGNAL_REGISTRY, TRADER_SIGNAL_REGISTRY}; use polars::prelude::*; use pyo3::exceptions::{PyRuntimeError, PyValueError}; use pyo3::prelude::*; diff --git a/crates/czsc-signal-macros/tests/test_export.rs b/crates/czsc-signal-macros/tests/test_export.rs index 88cb1514e..9bbb1fbba 100644 --- a/crates/czsc-signal-macros/tests/test_export.rs +++ b/crates/czsc-signal-macros/tests/test_export.rs @@ -9,7 +9,6 @@ #[test] fn proc_macro_crate_links() { - // Reaching this assertion means the crate compiled with both + // Reaching this function means the crate compiled with both // #[proc_macro_attribute] entrypoints exported. - assert!(true); } diff --git a/crates/czsc-signals/src/ang.rs b/crates/czsc-signals/src/ang.rs index 8d13c18da..0b2755335 100644 --- a/crates/czsc-signals/src/ang.rs +++ b/crates/czsc-signals/src/ang.rs @@ -375,12 +375,11 @@ pub fn skdj_up_dw_line_v230611(c: &CZSC, params: &ParamView, cache: &mut TaCache for (i, bar) in c.bars_raw.iter().enumerate() { rsv_ids.push(bar.id); // 对齐 Python:历史 bar 的 RSV 只计算一次;同 dt 延伸时仅最后一根会重算。 - if i + 1 < c.bars_raw.len() { - if let Some(v) = old_map.get(&bar.id) { + if i + 1 < c.bars_raw.len() + && let Some(v) = old_map.get(&bar.id) { rsv_series.push(*v); continue; } - } let win = if i < n { &c.bars_raw[..=i] } else { diff --git a/crates/czsc-signals/src/bar.rs b/crates/czsc-signals/src/bar.rs index f702d3d32..293a79793 100644 --- a/crates/czsc-signals/src/bar.rs +++ b/crates/czsc-signals/src/bar.rs @@ -63,11 +63,10 @@ pub fn bar_single_v230506(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> factors.push((bar.close - bar.open) / (bar.open * bar.vol)); } - if valid && !factors.is_empty() { - if let Some(q) = pd_cut_last_label(&factors, n) { + if valid && !factors.is_empty() + && let Some(q) = pd_cut_last_label(&factors, n) { v1 = format!("第{}层", q); } - } } make_kline_signal_v1(&k1, &k2, k3, &v1) diff --git a/crates/czsc-signals/src/cxt.rs b/crates/czsc-signals/src/cxt.rs index b7fdfc96c..bf4282938 100644 --- a/crates/czsc-signals/src/cxt.rs +++ b/crates/czsc-signals/src/cxt.rs @@ -592,13 +592,11 @@ pub fn cxt_bi_end_v230324(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> let mut ma_vals: Vec = Vec::new(); for rb in fx_raw.iter().take(fx_raw.len() - 1) { - if let Some(idx) = id_to_idx.get(&rb.id) { - if let Some(x) = ma.get(*idx) { - if x.is_finite() { + if let Some(idx) = id_to_idx.get(&rb.id) + && let Some(x) = ma.get(*idx) + && x.is_finite() { ma_vals.push(*x); } - } - } } if ma_vals.is_empty() { return make_kline_signal_v1(&k1, &k2, k3, v1); @@ -1638,11 +1636,10 @@ pub fn cxt_bi_end_v230222(c: &CZSC, params: &ParamView, _cache: &mut TaCache) -> } let mut fxs: Vec = Vec::new(); - if let Some(last_bi) = c.bi_list.last() { - if last_bi.fxs.len() > 1 { + if let Some(last_bi) = c.bi_list.last() + && last_bi.fxs.len() > 1 { fxs.extend_from_slice(&last_bi.fxs[1..]); } - } for x in ubi_fxs { if fxs.last().map(|y| x.dt > y.dt).unwrap_or(true) { fxs.push(x); @@ -1963,11 +1960,10 @@ pub fn cxt_bi_end_v230322(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> let last_fx_raw = fx_raw_bars(last_fx); let mut ma_vals = Vec::new(); for rb in &last_fx_raw { - if let Some(idx) = id_to_idx.get(&rb.id) { - if let Some(v) = ma.get(*idx) { + if let Some(idx) = id_to_idx.get(&rb.id) + && let Some(v) = ma.get(*idx) { ma_vals.push(*v); } - } } if ma_vals.is_empty() { return make_kline_signal_v1(&k1, &k2, k3, "其他"); diff --git a/crates/czsc-signals/src/jcc.rs b/crates/czsc-signals/src/jcc.rs index 48f99def7..da477252c 100644 --- a/crates/czsc-signals/src/jcc.rs +++ b/crates/czsc-signals/src/jcc.rs @@ -12,11 +12,10 @@ fn get_f64_param(params: &ParamView, key: &str, default: f64) -> f64 { if let Some(x) = v.as_f64() { return x; } - if let Some(s) = v.as_str() { - if let Ok(x) = s.parse::() { + if let Some(s) = v.as_str() + && let Ok(x) = s.parse::() { return x; } - } } default } diff --git a/crates/czsc-signals/src/obv.rs b/crates/czsc-signals/src/obv.rs index 5c27aa55e..ad226ab82 100644 --- a/crates/czsc-signals/src/obv.rs +++ b/crates/czsc-signals/src/obv.rs @@ -162,11 +162,10 @@ pub fn obvm_line_v230610(c: &CZSC, params: &ParamView, cache: &mut TaCache) -> V let bars = get_sub_elements(&c.bars_raw, di, n.max(m) + 10); let mut obv_seq = Vec::with_capacity(bars.len()); for b in bars { - if let Some(i) = id_map.get(&b.id).copied() { - if i < obv_series.len() { + if let Some(i) = id_map.get(&b.id).copied() + && i < obv_series.len() { obv_seq.push(obv_series[i]); } - } } if obv_seq.len() < n.max(m) { return make_kline_signal_v1(&k1, &k2, k3, v1_default); @@ -238,11 +237,10 @@ pub fn obv_up_dw_line_v230719( let bars = get_sub_elements(&c.bars_raw, di, min_k_num); let mut obv_seq = Vec::with_capacity(bars.len()); for b in bars { - if let Some(i) = id_map.get(&b.id).copied() { - if i < obv_series.len() { + if let Some(i) = id_map.get(&b.id).copied() + && i < obv_series.len() { obv_seq.push(obv_series[i]); } - } } if obv_seq.len() < min_k_num { return make_kline_signal_v1(&k1, &k2, k3, v1_default); diff --git a/crates/czsc-signals/src/params.rs b/crates/czsc-signals/src/params.rs index 8a25103fc..5bd8ffdde 100644 --- a/crates/czsc-signals/src/params.rs +++ b/crates/czsc-signals/src/params.rs @@ -19,10 +19,10 @@ impl<'a> ParamView<'a> { if let Some(n) = val.as_u64() { return n as usize; } - if let Some(s) = val.as_str() { - if let Ok(n) = s.parse::() { - return n; - } + if let Some(s) = val.as_str() + && let Ok(n) = s.parse::() + { + return n; } } default @@ -30,10 +30,10 @@ impl<'a> ParamView<'a> { #[inline] pub fn str<'b>(&'b self, key: &str, default: &'b str) -> &'b str { - if let Some(val) = self.inner.get(key) { - if let Some(s) = val.as_str() { - return s; - } + if let Some(val) = self.inner.get(key) + && let Some(s) = val.as_str() + { + return s; } default } diff --git a/crates/czsc-signals/src/tas.rs b/crates/czsc-signals/src/tas.rs index 69cd10386..f6ccd83d6 100644 --- a/crates/czsc-signals/src/tas.rs +++ b/crates/czsc-signals/src/tas.rs @@ -3087,13 +3087,11 @@ pub fn tas_macd_bc_v230803(czsc: &CZSC, _params: &ParamView, cache: &mut TaCache .flat_map(|nb| nb.elements.iter()) .nth(1) .map(|x| x.id); - if let (Some(i1), Some(i2)) = (id1, id2) { - if let (Some(macd1), Some(macd2)) = (get_macd(i1), get_macd(i2)) { - if macd1 > macd2 && macd2 > 0.0 { + if let (Some(i1), Some(i2)) = (id1, id2) + && let (Some(macd1), Some(macd2)) = (get_macd(i1), get_macd(i2)) + && macd1 > macd2 && macd2 > 0.0 { v1 = "空头"; } - } - } } } else { let bottoms: Vec<_> = fx_list @@ -3120,13 +3118,11 @@ pub fn tas_macd_bc_v230803(czsc: &CZSC, _params: &ParamView, cache: &mut TaCache .flat_map(|nb| nb.elements.iter()) .nth(1) .map(|x| x.id); - if let (Some(i1), Some(i2)) = (id1, id2) { - if let (Some(macd1), Some(macd2)) = (get_macd(i1), get_macd(i2)) { - if macd1 < macd2 && macd2 < 0.0 { + if let (Some(i1), Some(i2)) = (id1, id2) + && let (Some(macd1), Some(macd2)) = (get_macd(i1), get_macd(i2)) + && macd1 < macd2 && macd2 < 0.0 { v1 = "多头"; } - } - } } } diff --git a/crates/czsc-signals/src/utils/sig.rs b/crates/czsc-signals/src/utils/sig.rs index e0d06b4c9..40e41d58e 100644 --- a/crates/czsc-signals/src/utils/sig.rs +++ b/crates/czsc-signals/src/utils/sig.rs @@ -41,10 +41,10 @@ pub fn get_usize_param(params: &ParamView, key: &str, default: usize) -> usize { if let Some(n) = val.as_u64() { return n as usize; } - if let Some(s) = val.as_str() { - if let Ok(n) = s.parse::() { - return n; - } + if let Some(s) = val.as_str() + && let Ok(n) = s.parse::() + { + return n; } } default @@ -110,11 +110,7 @@ pub fn parse_minute_freq(freq: &str) -> Option { return None; } let n = freq.trim_end_matches("分钟").parse::().ok()?; - if n > 0 { - Some(n) - } else { - None - } + if n > 0 { Some(n) } else { None } } /// 计算分钟周期对应的结束时间(与 Python freq_end_time 口径一致) @@ -835,10 +831,10 @@ mod tests { /// 解析字符串参数 pub fn get_str_param<'a>(params: &'a ParamView, key: &str, default: &'a str) -> &'a str { - if let Some(val) = params.value(key) { - if let Some(s) = val.as_str() { - return s; - } + if let Some(val) = params.value(key) + && let Some(s) = val.as_str() + { + return s; } default } diff --git a/crates/czsc-signals/src/zdy_trader.rs b/crates/czsc-signals/src/zdy_trader.rs index 5f0b5c064..bdfaf1ab5 100644 --- a/crates/czsc-signals/src/zdy_trader.rs +++ b/crates/czsc-signals/src/zdy_trader.rs @@ -166,12 +166,11 @@ pub fn zdy_stop_loss_v230406(cat: &dyn TraderState, params: &ParamView) -> Vec Vec open_base_fx.high { + if let Some(open_base_fx) = fxs.iter().rfind(|x| x.mark == Mark::G && x.dt < op.dt) + && last_bar.close > open_base_fx.high { v1 = "空头止损"; v2 = "升破分型高点"; } - } if (1.0 - last_bar.close / op.price) * 10000.0 <= -first_stop { v1 = "空头止损"; v2 = "进场点止损"; diff --git a/crates/czsc-ta/src/python.rs b/crates/czsc-ta/src/python.rs index 1253a2318..83e7b1e51 100644 --- a/crates/czsc-ta/src/python.rs +++ b/crates/czsc-ta/src/python.rs @@ -29,7 +29,12 @@ fn rolling_rank(series: Vec, window: usize) -> Vec { #[pyfunction] #[pyo3(signature = (series, n=None, *, period=None, length=None))] -fn sma(series: Vec, n: Option, period: Option, length: Option) -> Vec { +fn sma( + series: Vec, + n: Option, + period: Option, + length: Option, +) -> Vec { // Same kwarg story as `ema` — talib's keyword is `timeperiod` / // pandas-ta's is `length`; rs-czsc historical scripts pass `n` / // `period`. Phase A parity test calls `ta.sma(series, length=20)`. @@ -100,7 +105,12 @@ fn rank_positions(series: Vec, n: usize) -> Vec { #[pyfunction] #[pyo3(signature = (series, n=None, *, period=None, length=None))] -fn ema(series: Vec, n: Option, period: Option, length: Option) -> Vec { +fn ema( + series: Vec, + n: Option, + period: Option, + length: Option, +) -> Vec { // Accept any of: positional `n`, kwargs `period=` (legacy rs-czsc) or // `length=` (talib / pandas-ta convention). The Phase A parity test // in `test/unit/test_ta_parity.py::test_ema_matches_talib` calls @@ -222,8 +232,14 @@ pub fn register(py: Python<'_>, parent: &Bound<'_, PyModule>) -> PyResult<()> { ); // numpy-bound entries - ta.add_function(wrap_pyfunction!(mixed::chip_dist::chip_distribution_triangle, &ta)?)?; - parent.add_function(wrap_pyfunction!(mixed::chip_dist::chip_distribution_triangle, parent)?)?; + ta.add_function(wrap_pyfunction!( + mixed::chip_dist::chip_distribution_triangle, + &ta + )?)?; + parent.add_function(wrap_pyfunction!( + mixed::chip_dist::chip_distribution_triangle, + parent + )?)?; // Register the submodule into sys.modules so `from czsc._native.ta // import ema` (and `import czsc._native.ta`) works the same as a diff --git a/crates/czsc-ta/tests/test_pure.rs b/crates/czsc-ta/tests/test_pure.rs index efb9ab5dd..35d913c8c 100644 --- a/crates/czsc-ta/tests/test_pure.rs +++ b/crates/czsc-ta/tests/test_pure.rs @@ -3,8 +3,8 @@ //! aligned with the rs-czsc 47ef6efa baseline. use czsc_ta::pure::{ - boll_positions, double_sma_positions, ema, mid_positions, rolling_rank, - single_ema_positions, single_sma_positions, true_range, ultimate_smoother, + boll_positions, double_sma_positions, ema, mid_positions, rolling_rank, single_ema_positions, + single_sma_positions, true_range, ultimate_smoother, }; fn series(n: usize) -> Vec { @@ -24,7 +24,12 @@ fn ultimate_smoother_first_4_passthrough() { let s = series(20); let out = ultimate_smoother(&s, 10.0); for i in 0..4 { - assert!((out[i] - s[i]).abs() < f64::EPSILON, "i={i}: {} vs {}", out[i], s[i]); + assert!( + (out[i] - s[i]).abs() < f64::EPSILON, + "i={i}: {} vs {}", + out[i], + s[i] + ); } } @@ -73,7 +78,10 @@ fn mid_positions_in_range() { let s = series(30); let out = mid_positions(&s, 5); for v in &out { - assert!(*v >= -1.0 && *v <= 1.0, "expected position in [-1, 1], got {v}"); + assert!( + *v >= -1.0 && *v <= 1.0, + "expected position in [-1, 1], got {v}" + ); } } @@ -96,9 +104,9 @@ fn boll_positions_in_signed_range() { #[test] fn true_range_matches_input_length() { - let high = vec![10.0, 11.0, 12.0, 11.5]; - let low = vec![ 9.0, 9.5, 10.5, 10.0]; - let prev = vec![ 9.5, 10.0, 11.0, 10.5]; + let high = vec![10.0, 11.0, 12.0, 11.5]; + let low = vec![9.0, 9.5, 10.5, 10.0]; + let prev = vec![9.5, 10.0, 11.0, 10.5]; let tr = true_range(&high, &low, &prev); assert_eq!(tr.len(), 4); // tr[i] = max(high-low, |high-prev|, |low-prev|) >= 0 diff --git a/crates/czsc-utils/src/bar_generator.rs b/crates/czsc-utils/src/bar_generator.rs index 753ee2df1..cceced036 100644 --- a/crates/czsc-utils/src/bar_generator.rs +++ b/crates/czsc-utils/src/bar_generator.rs @@ -408,7 +408,7 @@ impl BarGenerator { )); }; - self.init_freq_with_bars(freq, bars.into_iter())?; + self.init_freq_with_bars(freq, bars)?; Ok(()) }) } diff --git a/crates/czsc-utils/src/python/mod.rs b/crates/czsc-utils/src/python/mod.rs index 723f3db41..e8d53b2d8 100644 --- a/crates/czsc-utils/src/python/mod.rs +++ b/crates/czsc-utils/src/python/mod.rs @@ -23,11 +23,7 @@ fn is_trading_time(dt: chrono::NaiveDateTime, market: &str) -> bool { /// `PyValueError` via `UtilsError`'s PyErr conversion. #[pyfunction] #[pyo3(signature = (dt, freq, market=Market::Default))] -fn freq_end_time( - dt: DateTime, - freq: Freq, - market: Market, -) -> PyResult> { +fn freq_end_time(dt: DateTime, freq: Freq, market: Market) -> PyResult> { crate::freq_data::freq_end_time(dt, freq, market) .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string())) } diff --git a/crates/czsc-utils/tests/test_bar_generator.rs b/crates/czsc-utils/tests/test_bar_generator.rs index 1622e254e..937a21719 100644 --- a/crates/czsc-utils/tests/test_bar_generator.rs +++ b/crates/czsc-utils/tests/test_bar_generator.rs @@ -63,8 +63,8 @@ fn update_bar_appends_for_new_freq_window() { bg.update_bar(&bar(1_700_000_000, 10.0, 11.0, 12.0, 9.0)) .unwrap(); // Both freq queues received a bar - assert!(bg.freq_bars.get(&Freq::F1).unwrap().read().len() >= 1); - assert!(bg.freq_bars.get(&Freq::F30).unwrap().read().len() >= 1); + assert!(!bg.freq_bars.get(&Freq::F1).unwrap().read().is_empty()); + assert!(!bg.freq_bars.get(&Freq::F30).unwrap().read().is_empty()); } #[test] @@ -72,6 +72,8 @@ fn symbol_returns_seed_symbol_after_update() { let bg = BarGenerator::new(Freq::F1, vec![Freq::F30], 100, Market::Default).unwrap(); bg.update_bar(&bar(1_700_000_000, 10.0, 11.0, 12.0, 9.0)) .unwrap(); - let sym = bg.symbol().expect("symbol should be available after update"); + let sym = bg + .symbol() + .expect("symbol should be available after update"); assert_eq!(&*sym, "000001"); } diff --git a/crates/error-macros/tests/test_derive.rs b/crates/error-macros/tests/test_derive.rs index 65515633d..3b2d1d0b7 100644 --- a/crates/error-macros/tests/test_derive.rs +++ b/crates/error-macros/tests/test_derive.rs @@ -21,5 +21,8 @@ fn from_anyhow_blanket_impl_exists() { fn serialize_emits_string() { let dummy = DummyError::Unexpected(anyhow::anyhow!("boom")); let json = serde_json::to_string(&dummy).unwrap(); - assert!(json.contains("boom"), "expected serialized payload, got {json}"); + assert!( + json.contains("boom"), + "expected serialized payload, got {json}" + ); } diff --git a/crates/error-support/tests/test_chain.rs b/crates/error-support/tests/test_chain.rs index 945bfb498..464916f4a 100644 --- a/crates/error-support/tests/test_chain.rs +++ b/crates/error-support/tests/test_chain.rs @@ -6,12 +6,15 @@ use error_support::{czsc_bail, expand_error_chain}; #[test] fn expand_error_chain_walks_sources() { - let inner = std::io::Error::new(std::io::ErrorKind::Other, "leaf"); + let inner = std::io::Error::other("leaf"); let mid = anyhow::Error::new(inner).context("middle layer"); let outer = mid.context("outermost"); let chain = expand_error_chain(&outer); assert!(chain.contains("outermost"), "chain missing outer: {chain}"); - assert!(chain.contains("Caused by: middle layer"), "missing middle: {chain}"); + assert!( + chain.contains("Caused by: middle layer"), + "missing middle: {chain}" + ); assert!(chain.contains("Caused by: leaf"), "missing leaf: {chain}"); } From 3dacc180c43a68052673f8f3778e7815db85d308 Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Fri, 8 May 2026 12:20:39 +0800 Subject: [PATCH 22/23] refactor: improve type hinting and fix groupby mapping in trade utility functions --- czsc/_native.pyi | 298 ++++++++++++++++---------------------------- czsc/utils/trade.py | 4 +- pyproject.toml | 8 +- 3 files changed, 113 insertions(+), 197 deletions(-) diff --git a/czsc/_native.pyi b/czsc/_native.pyi index 2a9aee46f..ff800f954 100644 --- a/czsc/_native.pyi +++ b/czsc/_native.pyi @@ -2,11 +2,10 @@ # ruff: noqa: E501, F401 import builtins -import typing -from enum import Enum - import numpy import numpy.typing +import typing +from enum import Enum class BI: r""" @@ -117,29 +116,21 @@ class BI: r""" 缓存字典(与 czsc 库兼容) """ - def __new__( - cls, - symbol: builtins.str, - direction: Direction, - fx_a: FX, - fx_b: FX, - fxs: typing.Sequence[FX], - bars: typing.Sequence[NewBar], - ) -> BI: ... - def get_cache_with_default(self, _key: builtins.str, default_value: builtins.float) -> builtins.float: + def __new__(cls, symbol:builtins.str, direction:Direction, fx_a:FX, fx_b:FX, fxs:typing.Sequence[FX], bars:typing.Sequence[NewBar]) -> BI: ... + def get_cache_with_default(self, _key:builtins.str, default_value:builtins.float) -> builtins.float: r""" 获取缓存值,如果不存在则返回默认值(与 czsc 库兼容) """ - def get_price_linear(self, n: builtins.int) -> builtins.float: + def get_price_linear(self, n:builtins.int) -> builtins.float: r""" 获取线性价格(与 czsc 库兼容) """ def __repr__(self) -> builtins.str: ... - def __richcmp__(self, other: BI, op: int) -> builtins.bool: ... + def __richcmp__(self, other:BI, op:int) -> builtins.bool: ... class BarGenerator: @property - def symbol_py(self) -> builtins.str | None: + def symbol_py(self) -> typing.Optional[builtins.str]: r""" 获取所属品种 - Python 属性 """ @@ -149,7 +140,7 @@ class BarGenerator: 获取基准频率 """ @property - def end_dt(self) -> typing.Any | None: + def end_dt(self) -> typing.Optional[typing.Any]: r""" 获取end_dt属性(Python兼容) """ @@ -158,34 +149,28 @@ class BarGenerator: r""" 获取各周期K线数据 - 返回字典,键为频率字符串,值为K线列表 """ - def __new__( - cls, - base_freq: typing.Any, - freqs: typing.Any, - max_count: builtins.int = 2000, - market: typing.Any | None = None, - ) -> BarGenerator: ... - def init_freq_bars(self, freq: typing.Any, bars: typing.Sequence[RawBar]) -> None: + def __new__(cls, base_freq:typing.Any, freqs:typing.Any, max_count:builtins.int=2000, market:typing.Optional[typing.Any]=None) -> BarGenerator: ... + def init_freq_bars(self, freq:typing.Any, bars:typing.Sequence[RawBar]) -> None: r""" 初始化某个周期的K线序列 - + # 函数计算逻辑 - + 1. 检查输入的`freq`是否存在于`self.freq_bars`的键中。如果不存在,返回错误。 2. 检查`self.freq_bars[freq]`是否为空。如果不为空,返回错误,表示不允许重复初始化。 3. 如果以上检查都通过,将输入的`bars`存储到`self.freq_bars[freq]`中。 4. 从`bars`中获取最后一根K线的交易标的代码,更新`self.symbol`。 - + # Arguments - + * `freq` - 周期名称 (支持字符串或Freq枚举) * `bars` - K线序列 """ - def get_latest_date(self) -> builtins.str | None: + def get_latest_date(self) -> typing.Optional[builtins.str]: r""" 获取最新K线日期 """ - def update(self, bar: RawBar) -> None: + def update(self, bar:RawBar) -> None: r""" 从Python RawBar对象更新K线数据 - 支持直接自动转换 """ @@ -193,7 +178,7 @@ class BarGenerator: r""" 支持 pickle 序列化 - 使用 __reduce__ 方法 """ - def __setstate__(self, state: typing.Any) -> None: + def __setstate__(self, state:typing.Any) -> None: r""" 支持 pickle 反序列化 """ @@ -269,19 +254,19 @@ class CZSC: r""" 缓存字典(与 czsc 库兼容) """ - def __new__(cls, bars_raw: typing.Sequence[RawBar], max_bi_num: builtins.int = 50) -> CZSC: ... + def __new__(cls, bars_raw:typing.Sequence[RawBar], max_bi_num:builtins.int=50) -> CZSC: ... @staticmethod - def from_dataframe(df_bytes: bytes, freq: Freq, max_bi_num: builtins.int = 50) -> CZSC: + def from_dataframe(df_bytes:bytes, freq:Freq, max_bi_num:builtins.int=50) -> CZSC: r""" 直接从Arrow格式的DataFrame创建CZSC对象,避免中间转换 这是高性能的批量创建接口,适用于大量数据的初始化 - + :param df_bytes: Arrow IPC格式的DataFrame字节数据 :param freq: K线频率 :param max_bi_num: 最大笔数量限制 :return: CZSC对象 """ - def open_in_browser(self, _renderer: builtins.str | None = None) -> builtins.str: + def open_in_browser(self, _renderer:typing.Optional[builtins.str]=None) -> builtins.str: r""" 在浏览器中打开(与 czsc 库兼容) """ @@ -293,7 +278,7 @@ class CZSC: r""" 转换为 Plotly 格式(与 czsc 库兼容) """ - def update(self, bar: RawBar) -> None: + def update(self, bar:RawBar) -> None: r""" 更新K线数据 """ @@ -301,7 +286,7 @@ class CZSC: def __reduce__(self) -> typing.Any: r""" Pickle support — `__reduce__` returns ``(CZSC, (fixed_point_bars, max_bi_num))``. - + `update_bar` drains older bars whose dt is below the current first-BI's start (see `bars_raw.drain` block above), so a freshly-constructed CZSC's `bars_raw` may still differ from the @@ -348,17 +333,17 @@ class CzscSignals: 返回基准周期字符串 """ @property - def end_dt(self) -> typing.Any | None: + def end_dt(self) -> typing.Optional[typing.Any]: r""" 返回最新时间,作为 pandas Timestamp """ @property - def bid(self) -> builtins.int | None: + def bid(self) -> typing.Optional[builtins.int]: r""" 返回当前 bar id """ @property - def latest_price(self) -> builtins.float | None: + def latest_price(self) -> typing.Optional[builtins.float]: r""" 返回最新价格 """ @@ -367,8 +352,8 @@ class CzscSignals: r""" 返回原始信号配置 """ - def __new__(cls, bg: BarGenerator, signals_config: list) -> CzscSignals: ... - def update_signals(self, bar: RawBar) -> None: + def __new__(cls, bg:BarGenerator, signals_config:list) -> CzscSignals: ... + def update_signals(self, bar:RawBar) -> None: r""" 更新信号 """ @@ -419,17 +404,17 @@ class CzscTrader: 返回基准周期字符串 """ @property - def end_dt(self) -> typing.Any | None: + def end_dt(self) -> typing.Optional[typing.Any]: r""" 返回最新时间,作为 pandas Timestamp """ @property - def bid(self) -> builtins.int | None: + def bid(self) -> typing.Optional[builtins.int]: r""" 返回当前 bar id """ @property - def latest_price(self) -> builtins.float | None: + def latest_price(self) -> typing.Optional[builtins.float]: r""" 返回最新价格 """ @@ -448,26 +433,24 @@ class CzscTrader: r""" 返回是否有仓位发生变化 """ - def __new__( - cls, bg: BarGenerator, positions: list, signals_config: list, ensemble_method: builtins.str = "mean" - ) -> CzscTrader: ... - def update(self, bar: RawBar) -> None: + def __new__(cls, bg:BarGenerator, positions:list, signals_config:list, ensemble_method:builtins.str='mean') -> CzscTrader: ... + def update(self, bar:RawBar) -> None: r""" 更新信号和仓位 """ - def on_bar(self, bar: RawBar) -> None: + def on_bar(self, bar:RawBar) -> None: r""" 更新信号和仓位(同 update) """ - def on_sig(self, sig: dict) -> None: + def on_sig(self, sig:dict) -> None: r""" 基于信号字典更新仓位 """ - def get_ensemble_pos(self, method: builtins.str | None = None) -> builtins.float: + def get_ensemble_pos(self, method:typing.Optional[builtins.str]=None) -> builtins.float: r""" 获取集成后的仓位值 """ - def get_position(self, name: builtins.str) -> Position | None: + def get_position(self, name:builtins.str) -> typing.Optional[Position]: r""" 根据名称获取仓位 """ @@ -475,7 +458,7 @@ class CzscTrader: r""" 获取当前信号字典 """ - def update_signals(self, bar: RawBar) -> None: + def update_signals(self, bar:RawBar) -> None: r""" 仅更新信号(不更新仓位) """ @@ -510,23 +493,16 @@ class Event: r""" 获取SHA256哈希 """ - def __new__( - cls, - operate: Operate, - signals_all: typing.Sequence[Signal] = [], - signals_any: typing.Sequence[Signal] = [], - signals_not: typing.Sequence[Signal] = [], - name: builtins.str = "", - ) -> Event: ... + def __new__(cls, operate:Operate, signals_all:typing.Sequence[Signal]=[], signals_any:typing.Sequence[Signal]=[], signals_not:typing.Sequence[Signal]=[], name:builtins.str='') -> Event: ... @classmethod - def from_dict(cls, _cls: type, dict: dict) -> Event: ... + def from_dict(cls, _cls:type, dict:dict) -> Event: ... @classmethod - def from_json(cls, _cls: type, json_str: builtins.str) -> Event: ... + def from_json(cls, _cls:type, json_str:builtins.str) -> Event: ... def compute_sha8(self) -> builtins.str: r""" 计算SHA8哈希值 """ - def is_match(self, signals: typing.Any) -> builtins.bool: + def is_match(self, signals:typing.Any) -> builtins.bool: r""" 判断事件是否匹配信号集合,返回是否匹配 支持多种参数类型:Dict[str, str] 或 Dict[str, Signal] 或 Vec @@ -542,7 +518,7 @@ class Event: 导出为字典 """ @classmethod - def load(cls, _cls: type, data: dict) -> Event: + def load(cls, _cls:type, data:dict) -> Event: r""" 从字典加载 """ @@ -611,18 +587,9 @@ class FX: r""" 直接支持 __dict__ 属性,让 pandas DataFrame() 能正确识别对象 """ - def __new__( - cls, - symbol: builtins.str, - dt: typing.Any, - mark: Mark, - high: builtins.float, - low: builtins.float, - fx: builtins.float, - elements: typing.Sequence[NewBar], - ) -> FX: ... + def __new__(cls, symbol:builtins.str, dt:typing.Any, mark:Mark, high:builtins.float, low:builtins.float, fx:builtins.float, elements:typing.Sequence[NewBar]) -> FX: ... def __repr__(self) -> builtins.str: ... - def __richcmp__(self, other: FX, op: int) -> builtins.bool: ... + def __richcmp__(self, other:FX, op:int) -> builtins.bool: ... class FakeBI: r""" @@ -657,7 +624,7 @@ class LiteBar: def dt(self) -> builtins.float: ... @property def price(self) -> builtins.float: ... - def __new__(cls, id: builtins.int, dt: builtins.float, price: builtins.float) -> LiteBar: ... + def __new__(cls, id:builtins.int, dt:builtins.float, price:builtins.float) -> LiteBar: ... def __repr__(self) -> builtins.str: ... class NewBar: @@ -693,28 +660,14 @@ class NewBar: r""" 获取构成NewBar的原始K线列表(与elements相同,为兼容czsc库) """ - def __new__( - cls, - symbol: builtins.str, - dt: typing.Any, - freq: Freq, - open: builtins.float, - close: builtins.float, - high: builtins.float, - low: builtins.float, - vol: builtins.float, - amount: builtins.float, - id: builtins.int = 0, - elements: typing.Sequence[RawBar] | None = None, - ) -> NewBar: ... + def __new__(cls, symbol:builtins.str, dt:typing.Any, freq:Freq, open:builtins.float, close:builtins.float, high:builtins.float, low:builtins.float, vol:builtins.float, amount:builtins.float, id:builtins.int=0, elements:typing.Optional[typing.Sequence[RawBar]]=None) -> NewBar: ... def __repr__(self) -> builtins.str: ... - def __richcmp__(self, other: NewBar, op: int) -> builtins.bool: ... + def __richcmp__(self, other:NewBar, op:int) -> builtins.bool: ... class Operate: r""" Python可见的Operate包装器 """ - HL: Operate HS: Operate HO: Operate @@ -728,26 +681,26 @@ class Operate: 兼容性属性:返回操作类型的中文字符串值 """ @classmethod - def hl(cls, _cls: type) -> Operate: ... + def hl(cls, _cls:type) -> Operate: ... @classmethod - def hs(cls, _cls: type) -> Operate: ... + def hs(cls, _cls:type) -> Operate: ... @classmethod - def ho(cls, _cls: type) -> Operate: ... + def ho(cls, _cls:type) -> Operate: ... @classmethod - def lo(cls, _cls: type) -> Operate: ... + def lo(cls, _cls:type) -> Operate: ... @classmethod - def le(cls, _cls: type) -> Operate: ... + def le(cls, _cls:type) -> Operate: ... @classmethod - def so(cls, _cls: type) -> Operate: ... + def so(cls, _cls:type) -> Operate: ... @classmethod - def se(cls, _cls: type) -> Operate: ... + def se(cls, _cls:type) -> Operate: ... @classmethod - def from_str_py(cls, _cls: type, s: builtins.str) -> Operate: ... + def from_str_py(cls, _cls:type, s:builtins.str) -> Operate: ... @classmethod - def from_str(cls, _cls: type, s: builtins.str) -> Operate: ... + def from_str(cls, _cls:type, s:builtins.str) -> Operate: ... def __str__(self) -> builtins.str: ... def __repr__(self) -> builtins.str: ... - def __eq__(self, other: Operate) -> builtins.bool: ... + def __eq__(self, other:Operate) -> builtins.bool: ... def __hash__(self) -> builtins.int: ... def __reduce__(self) -> typing.Any: r""" @@ -759,7 +712,7 @@ class ParsedSignalDoc: Python可见的ParsedSignalDoc包装器 """ @property - def param_template(self) -> builtins.str | None: ... + def param_template(self) -> typing.Optional[builtins.str]: ... @property def signals(self) -> builtins.list[Signal]: ... def __repr__(self) -> builtins.str: ... @@ -769,19 +722,19 @@ class Pos: Python可见的Pos枚举包装器 """ @classmethod - def short(cls, _cls: type) -> Pos: ... + def short(cls, _cls:type) -> Pos: ... @classmethod - def flat(cls, _cls: type) -> Pos: ... + def flat(cls, _cls:type) -> Pos: ... @classmethod - def long(cls, _cls: type) -> Pos: ... + def long(cls, _cls:type) -> Pos: ... def __str__(self) -> builtins.str: ... def __repr__(self) -> builtins.str: ... - def __eq__(self, other: Pos) -> builtins.bool: ... - def __add__(self, other: Pos) -> builtins.float: + def __eq__(self, other:Pos) -> builtins.bool: ... + def __add__(self, other:Pos) -> builtins.float: r""" 加法运算,用于numpy.mean等数学操作 """ - def __radd__(self, other: builtins.float) -> builtins.float: + def __radd__(self, other:builtins.float) -> builtins.float: r""" 右加法运算 """ @@ -793,19 +746,19 @@ class Pos: r""" 整数转换 """ - def __lt__(self, other: Pos) -> builtins.bool: + def __lt__(self, other:Pos) -> builtins.bool: r""" 比较运算符 - 小于 """ - def __le__(self, other: Pos) -> builtins.bool: + def __le__(self, other:Pos) -> builtins.bool: r""" 比较运算符 - 小于等于 """ - def __gt__(self, other: Pos) -> builtins.bool: + def __gt__(self, other:Pos) -> builtins.bool: r""" 比较运算符 - 大于 """ - def __ge__(self, other: Pos) -> builtins.bool: + def __ge__(self, other:Pos) -> builtins.bool: r""" 比较运算符 - 大于等于 """ @@ -835,7 +788,7 @@ class Position: @property def pos_changed(self) -> builtins.bool: ... @property - def end_dt(self) -> builtins.float | None: + def end_dt(self) -> typing.Optional[builtins.float]: r""" 获取最新信号时间 """ @@ -858,22 +811,12 @@ class Position: def unique_signals(self) -> builtins.list[builtins.str]: ... @property def events(self) -> builtins.list[Event]: ... - def __new__( - cls, - symbol: builtins.str, - opens: typing.Sequence[Event], - exits: typing.Sequence[Event] = [], - interval: builtins.int = 0, - timeout: builtins.int = 1000, - stop_loss: builtins.float = 1000.0, - t0: builtins.bool = False, - name: builtins.str | None = None, - ) -> Position: ... + def __new__(cls, symbol:builtins.str, opens:typing.Sequence[Event], exits:typing.Sequence[Event]=[], interval:builtins.int=0, timeout:builtins.int=1000, stop_loss:builtins.float=1000.0, t0:builtins.bool=False, name:typing.Optional[builtins.str]=None) -> Position: ... @classmethod - def load_from_file(cls, _cls: type, path: builtins.str) -> Position: ... + def load_from_file(cls, _cls:type, path:builtins.str) -> Position: ... @classmethod - def from_json(cls, _cls: type, json_str: builtins.str) -> Position: ... - def save(self, path: builtins.str) -> None: + def from_json(cls, _cls:type, json_str:builtins.str) -> Position: ... + def save(self, path:builtins.str) -> None: r""" 保存到文件 """ @@ -885,7 +828,7 @@ class Position: r""" 获取所有相关事件 """ - def update(self, arg1: typing.Any, arg2: typing.Any | None = None) -> None: + def update(self, arg1:typing.Any, arg2:typing.Optional[typing.Any]=None) -> None: r""" 更新仓位状态(兼容单参数调用) """ @@ -893,12 +836,12 @@ class Position: r""" 支持 pickle 序列化 - 使用 __reduce__ 方法 """ - def dump(self, with_data: builtins.bool = True) -> typing.Any: + def dump(self, with_data:builtins.bool=True) -> typing.Any: r""" 导出Position数据为Python字典 """ @classmethod - def load(cls, _cls: type, data: typing.Any) -> Position: + def load(cls, _cls:type, data:typing.Any) -> Position: r""" 从字典数据加载Position """ @@ -950,19 +893,7 @@ class RawBar: r""" 直接支持 __dict__ 属性,让 pandas DataFrame() 能正确识别对象 """ - def __new__( - cls, - symbol: builtins.str, - dt: typing.Any, - freq: Freq, - open: builtins.float, - close: builtins.float, - high: builtins.float, - low: builtins.float, - vol: builtins.float, - amount: builtins.float, - id: builtins.int = 0, - ) -> RawBar: ... + def __new__(cls, symbol:builtins.str, dt:typing.Any, freq:Freq, open:builtins.float, close:builtins.float, high:builtins.float, low:builtins.float, vol:builtins.float, amount:builtins.float, id:builtins.int=0) -> RawBar: ... def _asdict(self) -> typing.Any: r""" 让对象表现得像记录,pandas DataFrame构造器会调用这个 @@ -975,12 +906,12 @@ class RawBar: r""" 支持pickle序列化 """ - def __deepcopy__(self, _memo: typing.Any) -> RawBar: + def __deepcopy__(self, _memo:typing.Any) -> RawBar: r""" 支持深拷贝 """ def __repr__(self) -> builtins.str: ... - def __richcmp__(self, other: RawBar, op: int) -> builtins.bool: ... + def __richcmp__(self, other:RawBar, op:int) -> builtins.bool: ... class Signal: r""" @@ -1007,35 +938,22 @@ class Signal: """ @property def k2(self) -> builtins.str: ... - def __new__( - cls, - *args, - signal: builtins.str | None = None, - key: builtins.str | None = None, - value: builtins.str | None = None, - k1: builtins.str | None = None, - k2: builtins.str | None = None, - k3: builtins.str | None = None, - v1: builtins.str | None = None, - v2: builtins.str | None = None, - v3: builtins.str | None = None, - score: builtins.int | None = None, - ) -> Signal: ... + def __new__(cls, *args, signal:typing.Optional[builtins.str]=None, key:typing.Optional[builtins.str]=None, value:typing.Optional[builtins.str]=None, k1:typing.Optional[builtins.str]=None, k2:typing.Optional[builtins.str]=None, k3:typing.Optional[builtins.str]=None, v1:typing.Optional[builtins.str]=None, v2:typing.Optional[builtins.str]=None, v3:typing.Optional[builtins.str]=None, score:typing.Optional[builtins.int]=None) -> Signal: ... @classmethod - def from_string(cls, _cls: type, s: builtins.str) -> Signal: ... + def from_string(cls, _cls:type, s:builtins.str) -> Signal: ... def to_json(self) -> builtins.str: r""" 添加to_json方法以匹配Python版本 """ def __str__(self) -> builtins.str: ... def __repr__(self) -> builtins.str: ... - def __eq__(self, other: Signal) -> builtins.bool: ... + def __eq__(self, other:Signal) -> builtins.bool: ... def __hash__(self) -> builtins.int: ... - def matches(self, other: Signal) -> builtins.bool: + def matches(self, other:Signal) -> builtins.bool: r""" 检查Signal是否匹配另一个Signal """ - def is_match(self, signals_dict: typing.Mapping[builtins.str, builtins.str]) -> builtins.bool: + def is_match(self, signals_dict:typing.Mapping[builtins.str, builtins.str]) -> builtins.bool: r""" 判断信号是否与信号字典中的值匹配(Python版本is_match逻辑) """ @@ -1097,7 +1015,7 @@ class ZS: """ @property def cache(self) -> dict: ... - def __new__(cls, bis: typing.Sequence[BI]) -> ZS: ... + def __new__(cls, bis:typing.Sequence[BI]) -> ZS: ... def is_valid(self) -> builtins.bool: r""" 中枢是否有效 @@ -1107,7 +1025,6 @@ class Direction(Enum): r""" 方向 """ - Up = ... r""" 向上 @@ -1122,7 +1039,7 @@ class Direction(Enum): r""" 获取方向的字符串值(与 czsc 库兼容) """ - def __deepcopy__(self, _memo: typing.Any) -> Direction: + def __deepcopy__(self, _memo:typing.Any) -> Direction: r""" 支持深拷贝 """ @@ -1130,16 +1047,15 @@ class Direction(Enum): r""" 支持pickle序列化 """ - def __new__(cls, value: builtins.str) -> Direction: ... + def __new__(cls, value:builtins.str) -> Direction: ... def __str__(self) -> builtins.str: ... def __repr__(self) -> builtins.str: ... - def __richcmp__(self, other: typing.Any, op: int) -> builtins.bool: ... + def __richcmp__(self, other:typing.Any, op:int) -> builtins.bool: ... class Freq(Enum): r""" 时间周期 """ - Tick = ... r""" 逐笔 @@ -1228,7 +1144,7 @@ class Freq(Enum): __members__: typing.Any @property def value(self) -> builtins.str: ... - def __deepcopy__(self, _memo: typing.Any) -> Freq: + def __deepcopy__(self, _memo:typing.Any) -> Freq: r""" 支持深拷贝 """ @@ -1236,16 +1152,15 @@ class Freq(Enum): r""" 支持pickle序列化 """ - def __new__(cls, value: builtins.str) -> Freq: ... + def __new__(cls, value:builtins.str) -> Freq: ... def __str__(self) -> builtins.str: ... def __repr__(self) -> builtins.str: ... - def __richcmp__(self, other: typing.Any, op: int) -> builtins.bool: ... + def __richcmp__(self, other:typing.Any, op:int) -> builtins.bool: ... class Mark(Enum): r""" 分型类型 """ - D = ... r""" 底分型 @@ -1262,7 +1177,7 @@ class Mark(Enum): """ def __str__(self) -> builtins.str: ... def __repr__(self) -> builtins.str: ... - def __richcmp__(self, other: typing.Any, op: int) -> builtins.bool: ... + def __richcmp__(self, other:typing.Any, op:int) -> builtins.bool: ... class Market(Enum): AShare = ... @@ -1278,44 +1193,43 @@ class Market(Enum): 默认 """ - def __new__(cls, ob: typing.Any) -> Market: ... + def __new__(cls, ob:typing.Any) -> Market: ... -def chip_distribution_triangle( - data: numpy.typing.NDArray[numpy.float64], price_step: builtins.float, decay_factor: builtins.float -) -> tuple[numpy.typing.NDArray[numpy.float64], numpy.typing.NDArray[numpy.float64]]: +def chip_distribution_triangle(data:numpy.typing.NDArray[numpy.float64], price_step:builtins.float, decay_factor:builtins.float) -> tuple[numpy.typing.NDArray[numpy.float64], numpy.typing.NDArray[numpy.float64]]: r""" 计算筹码分布(三角形分布 + 筹码沉淀机制) - + 此函数用于估算基于历史K线的筹码分布情况,结合三角形分布模型和筹码沉淀(衰减)机制。 - + # Python 接口说明 - + 输入一个二维 numpy 数组,形状为 (N, 3),每一行对应一根K线,列顺序为: `[high, low, vol]`,类型必须为 `float64`。 - + 示例: ```python columns = ['high', 'low', 'vol'] arr2 = df[columns].to_numpy(dtype=np.float64) price_centers, chip_dist = chip_distribution_triangle(arr2, 0.01, 0.9) ``` - + # 参数 - + - `data`: 二维数组,形状为 (N, 3),分别是每根K线的最高价、最低价和成交量。 - `price_step`: 分档间隔(如0.01表示以0.01为单位划分价格区间)。 - `decay_factor`: 筹码衰减因子,表示前一根K线上的筹码有多少比例沉淀保留到下一根K线上,范围为(0, 1),例如0.98表示保留98%。 - + # 返回值 - + 返回一个元组 `(price_centers, chip_distribution)`: - `price_centers`: 一维数组,表示价格分布区间的中心价位。 - `chip_distribution`: 一维数组,对应每个价格中心的筹码强度(权重/密度)。 - + 返回的两个数组长度相同,可用于绘制筹码分布图或进一步分析。 """ -def parse_signal_doc(doc: builtins.str) -> ParsedSignalDoc: +def parse_signal_doc(doc:builtins.str) -> ParsedSignalDoc: r""" 解析文档中的Signal信息 """ + diff --git a/czsc/utils/trade.py b/czsc/utils/trade.py index 64ef272b6..3e1d4cb88 100644 --- a/czsc/utils/trade.py +++ b/czsc/utils/trade.py @@ -144,7 +144,9 @@ def resample_to_daily(df: pd.DataFrame, sdt=None, edt=None, only_trade_date=True trade_dates = pd.merge_asof(trade_dates, vdt, left_on="date", right_on="dt") trade_dates = trade_dates.dropna(subset=["dt"]).reset_index(drop=True) - dt_map = dict(df.groupby("dt")) + # noqa: C416 — DataFrameGroupBy.keys 是字符串属性而非方法,dict(grouper) 会把 + # mapping 路径走崩('str' object is not callable)。必须用显式 dict-comp。 + dt_map = {dt: dfg for dt, dfg in df.groupby("dt")} # noqa: C416 results = [] for row in trade_dates.to_dict("records"): # 注意:这里必须进行 copy,否则默认浅拷贝导致数据异常 diff --git a/pyproject.toml b/pyproject.toml index 620147802..f93f0b101 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,16 +139,16 @@ filterwarnings = [ [tool.ruff] line-length = 120 +# Auto-generated by pyo3-stub-gen — its formatter is authoritative +# (CI's stub-drift check regenerates and asserts no diff). Excluded from +# both lint and format so we don't fight the generator. +extend-exclude = ["czsc/_native.pyi"] [tool.ruff.lint] select = ["E", "F", "I", "UP", "B", "SIM", "C4"] ignore = ["E501", "SIM112"] extend-safe-fixes = ["UP"] -[tool.ruff.lint.per-file-ignores] -# Auto-generated by pyo3-stub-gen; regeneration overwrites manual edits. -"czsc/_native.pyi" = ["F811"] - [tool.ruff.lint.isort] known-first-party = ["czsc"] From 0e3e79f791dec62b490474957a31ed0eae606897 Mon Sep 17 00:00:00 2001 From: jun <793739422@qq.com> Date: Fri, 8 May 2026 13:33:09 +0800 Subject: [PATCH 23/23] =?UTF-8?q?chore(rust):=20=E6=B3=A8=E9=87=8A?= =?UTF-8?q?=E4=B8=AD=E6=96=87=E5=8C=96=20+=20=E6=94=B6=E7=B4=A7=20CI=20cli?= =?UTF-8?q?ppy=20=E4=B8=A5=E6=A0=BC=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 翻译 ~294 行 Rust 英文注释为中文,覆盖 czsc-{core,python,utils,ta,signals, trader,signal-macros} 与 error-support;保留 API/类型名、设计文档锚点、 注释掉的代码块、Python 行号引用、数学公式不变。 - czsc/_native.pyi 同步重生成(pyo3-stub-gen 把 __reduce__ 的中文 doc 嵌入 stub 后产生 18 行漂移,与 Rust 源同步提交以满足 stub-drift CI)。 - .github/workflows/code-quality.yml 去掉 cargo clippy 的 `|| true`, 让 -D warnings 真正阻塞 CI(之前任何 clippy 错误都被静默吞掉)。 回归验证:cargo fmt/clippy/test (per-crate, 122 passed) + ruff check/format + maturin develop + pytest (212 passed, 1 skipped) 全部通过;业务代码零改动。 --- .github/workflows/code-quality.yml | 2 +- crates/czsc-core/src/analyze/mod.rs | 18 ++--- crates/czsc-core/src/lib.rs | 6 +- crates/czsc-core/src/objects/bar.rs | 10 +-- crates/czsc-core/src/objects/event.rs | 10 +-- crates/czsc-core/src/objects/fake_bi.rs | 4 +- crates/czsc-core/src/objects/operate.rs | 27 ++++--- crates/czsc-core/src/objects/position.rs | 8 +- crates/czsc-core/src/objects/signal.rs | 19 +++-- crates/czsc-core/src/python/mod.rs | 51 ++++++------ crates/czsc-core/src/utils/corr.rs | 2 +- crates/czsc-core/tests/test_analyze_utils.rs | 18 ++--- crates/czsc-core/tests/test_bi.rs | 8 +- crates/czsc-core/tests/test_czsc_analyzer.rs | 19 +++-- crates/czsc-core/tests/test_fx.rs | 10 +-- crates/czsc-core/tests/test_signal.rs | 19 ++--- crates/czsc-core/tests/test_zs.rs | 12 +-- crates/czsc-python/build.rs | 12 +-- crates/czsc-python/src/lib.rs | 32 ++++---- crates/czsc-python/src/signals_dispatcher.rs | 78 +++++++++---------- crates/czsc-python/src/trader/api.rs | 2 +- crates/czsc-python/src/trader/czsc_signals.rs | 6 +- crates/czsc-python/src/trader/czsc_trader.rs | 2 +- crates/czsc-python/src/trader/mod.rs | 14 ++-- .../czsc-signal-macros/tests/test_export.rs | 17 ++-- crates/czsc-signals/src/utils/ta.rs | 10 +-- crates/czsc-ta/src/pure.rs | 26 +++---- crates/czsc-ta/src/python.rs | 71 ++++++++--------- crates/czsc-ta/tests/test_pure.rs | 22 +++--- crates/czsc-trader/src/lib.rs | 13 ++-- crates/czsc-trader/src/optimize.rs | 4 +- crates/czsc-utils/src/lib.rs | 12 +-- crates/czsc-utils/src/python/mod.rs | 25 +++--- crates/czsc-utils/src/trading_time.rs | 15 ++-- crates/czsc-utils/tests/test_freq_data.rs | 19 +++-- crates/czsc-utils/tests/test_trading_time.rs | 10 +-- crates/error-support/src/lib.rs | 2 +- czsc/_native.pyi | 18 ++--- 38 files changed, 308 insertions(+), 345 deletions(-) diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index 114b567b6..f90900182 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -225,7 +225,7 @@ jobs: - name: Lint Rust with cargo clippy run: | - cargo clippy --workspace --all-targets -- -D warnings || true + cargo clippy --workspace --all-targets -- -D warnings security: name: Security Audit diff --git a/crates/czsc-core/src/analyze/mod.rs b/crates/czsc-core/src/analyze/mod.rs index 5935a535b..827f759fe 100644 --- a/crates/czsc-core/src/analyze/mod.rs +++ b/crates/czsc-core/src/analyze/mod.rs @@ -660,17 +660,15 @@ impl CZSC { ) } - /// Pickle support — `__reduce__` returns ``(CZSC, (fixed_point_bars, max_bi_num))``. + /// Pickle 支持 —— `__reduce__` 返回 ``(CZSC, (fixed_point_bars, max_bi_num))``。 /// - /// `update_bar` drains older bars whose dt is below the current - /// first-BI's start (see `bars_raw.drain` block above), so a - /// freshly-constructed CZSC's `bars_raw` may still differ from the - /// fixed point reached after a single re-analysis. We run one extra - /// `CZSC::new` here to converge before serializing — guarantees that - /// `pickle.dumps(restored) == pickle.dumps(obj)` byte-for-byte even - /// when CzscSignals nests CZSC inside `kas[freq]` (Phase A's - /// `restored.__getstate__() == obj.__getstate__()` assertion relies - /// on this). + /// `update_bar` 会丢弃 dt 小于当前 first-BI 起始时间的旧 bar + /// (参见上面的 `bars_raw.drain` 块),因此刚构造出来的 CZSC 的 + /// `bars_raw` 可能仍然和「再分析一次后到达的不动点」不同。这里多 + /// 跑一次 `CZSC::new`,让其在序列化前收敛 —— 保证即使 CzscSignals + /// 在 `kas[freq]` 里嵌套了 CZSC,`pickle.dumps(restored) == + /// pickle.dumps(obj)` 也是逐字节相等的(Phase A 的 + /// `restored.__getstate__() == obj.__getstate__()` 断言依赖这一点)。 fn __reduce__(&self, py: Python) -> PyResult { use pyo3::IntoPyObject; let trimmed = CZSC::new(self.bars_raw.clone(), self.max_bi_num); diff --git a/crates/czsc-core/src/lib.rs b/crates/czsc-core/src/lib.rs index 95034486f..1ebba67be 100644 --- a/crates/czsc-core/src/lib.rs +++ b/crates/czsc-core/src/lib.rs @@ -1,7 +1,7 @@ -//! czsc-core —缠论 core analyzer (FX / BI / ZS / CZSC). +//! czsc-core —— 缠论核心分析器(FX / BI / ZS / CZSC)。 //! -//! Migrated from rs-czsc 47ef6efa. Submodules are added incrementally as -//! Phase D progresses; see docs/superpowers/plans/2026-05-03-rust-czsc-migration.md. +//! 由 rs-czsc 47ef6efa 迁移而来。子模块随着 Phase D 的推进逐步加入; +//! 参见 docs/superpowers/plans/2026-05-03-rust-czsc-migration.md。 pub mod analyze; pub mod objects; diff --git a/crates/czsc-core/src/objects/bar.rs b/crates/czsc-core/src/objects/bar.rs index 74ac38367..88ff06336 100644 --- a/crates/czsc-core/src/objects/bar.rs +++ b/crates/czsc-core/src/objects/bar.rs @@ -233,11 +233,11 @@ impl RawBar { /// 支持pickle序列化 fn __reduce__(&self, py: Python) -> PyResult { - // RawBar.new takes `freq: Freq` (the PyO3 enum), not a string — - // pass the enum directly through pickle so the unpickle path - // (`RawBar(*args)`) succeeds. Stringifying with - // `freq_to_chinese_string` here would force the constructor to - // accept str|Freq and silently change the public API. + // RawBar.new 接收 `freq: Freq`(PyO3 枚举),而不是字符串 —— + // 通过 pickle 直接传递枚举,这样 unpickle 路径 + // (`RawBar(*args)`)才会成功。如果在这里用 + // `freq_to_chinese_string` 字符串化,会迫使构造函数同时 + // 接受 str|Freq,并悄悄地改动公共 API。 let cls = py.get_type::(); let args = ( self.symbol.as_ref(), diff --git a/crates/czsc-core/src/objects/event.rs b/crates/czsc-core/src/objects/event.rs index 6f0c90c74..c1877cc82 100644 --- a/crates/czsc-core/src/objects/event.rs +++ b/crates/czsc-core/src/objects/event.rs @@ -1,7 +1,7 @@ -// czsc-only: pyo3 imports + super::operate / super::signal Python wrappers -// gated behind the `python` feature for non-python builds. Sha256 is used in -// the (non-python) Event helpers so it stays unconditional. -// See docs/MIGRATION_NOTES.md §2.4. +// czsc-only: pyo3 import 与 super::operate / super::signal 的 Python wrapper +// 通过 `python` feature 进行门控,以便 non-python 构建。Sha256 在 +// (non-python 的)Event 辅助函数里被使用,因此保持无条件 import。 +// 参见 docs/MIGRATION_NOTES.md §2.4。 #![allow(unused)] use anyhow::{Context, anyhow}; use serde::{Deserialize, Serialize}; @@ -270,7 +270,7 @@ impl<'py> FromPyObject<'py> for Event { None => String::new(), }; - // 3) signals + // 3) 信号字段 let signals_all = dict .get_item("signals_all")? .ok_or(PyValueError::new_err("缺少字段: 'signals_all'"))? diff --git a/crates/czsc-core/src/objects/fake_bi.rs b/crates/czsc-core/src/objects/fake_bi.rs index 3fb4b9b6e..22425ac0f 100644 --- a/crates/czsc-core/src/objects/fake_bi.rs +++ b/crates/czsc-core/src/objects/fake_bi.rs @@ -115,11 +115,11 @@ impl FakeBI { /// 创建 fake_bis 列表 /// -/// # Arguments +/// # 参数 /// /// * `fxs` - 分型序列,必须顶底分型交替 /// -/// # Returns +/// # 返回值 /// /// * 返回 FakeBI 的 Vec pub fn create_fake_bis(fxs: &[FX]) -> Vec { diff --git a/crates/czsc-core/src/objects/operate.rs b/crates/czsc-core/src/objects/operate.rs index edb2a575a..9dc223583 100644 --- a/crates/czsc-core/src/objects/operate.rs +++ b/crates/czsc-core/src/objects/operate.rs @@ -1,10 +1,9 @@ -// czsc-only: rs-czsc's `operate.rs` carried a wide collection of unused -// imports (polars / log / rayon / sha2 / and forward references to -// event / position / signal). Since the file only really defines the -// `Operate` enum + its `PyOperate` wrapper, we trim the imports here to -// avoid pulling those heavy crates into czsc-core. The original -// `#![allow(unused)]` is kept for the few remaining unused items the -// upstream file carries. See docs/MIGRATION_NOTES.md §2.4. +// czsc-only: rs-czsc 的 `operate.rs` 里带了一大堆未使用的 import +// (polars / log / rayon / sha2 还有对 event / position / signal 的前向引用)。 +// 由于这个文件实际上只定义了 `Operate` 枚举及其 `PyOperate` 包装器, +// 我们在这里裁剪掉这些 import,避免把那些重量级 crate 拉进 czsc-core。 +// 原本的 `#![allow(unused)]` 保留下来,用于覆盖上游文件中尚存的少量 +// 未使用项。参见 docs/MIGRATION_NOTES.md §2.4。 #![allow(unused)] use serde::{Deserialize, Serialize}; use std::fmt; @@ -25,25 +24,25 @@ pub const ANY: &str = "任意"; Clone, Copy, Debug, PartialEq, Hash, EnumString, EnumIter, AsRefStr, Serialize, Deserialize, )] pub enum Operate { - /// Hold Long 持多 + /// 持多(Hold Long) #[serde(rename = "持多")] HL, - /// Hold Short 持空 + /// 持空(Hold Short) #[serde(rename = "持空")] HS, - /// Hold Other 持币 + /// 持币(Hold Other) #[serde(rename = "持币")] HO, - /// Long Open 开多 + /// 开多(Long Open) #[serde(rename = "开多")] LO, - /// Long Exit 平多 + /// 平多(Long Exit) #[serde(rename = "平多")] LE, - /// Short Open 开空 + /// 开空(Short Open) #[serde(rename = "开空")] SO, - /// Short Exit 平空 + /// 平空(Short Exit) #[serde(rename = "平空")] SE, } diff --git a/crates/czsc-core/src/objects/position.rs b/crates/czsc-core/src/objects/position.rs index 2a8250552..08db58f2c 100644 --- a/crates/czsc-core/src/objects/position.rs +++ b/crates/czsc-core/src/objects/position.rs @@ -1,6 +1,6 @@ -// czsc-only: pyo3 + Python wrapper imports gated behind the `python` feature -// for non-python builds. polars + log are kept unconditional because Position -// uses them outside cfg blocks. See docs/MIGRATION_NOTES.md §2.4. +// czsc-only: pyo3 与 Python wrapper 的 import 通过 `python` feature 进行门控, +// 以便 non-python 构建。polars 与 log 保持无条件 import,因为 Position +// 在 cfg 块外面也会用到它们。参见 docs/MIGRATION_NOTES.md §2.4。 #![allow(unused)] use anyhow::{Context, anyhow}; use chrono::{DateTime, FixedOffset, NaiveDateTime}; @@ -657,7 +657,7 @@ impl Position { return PositionUpdateProfile::default(); } } else { - // init + // 初始化 self.temp_state = Some(TempState { end_dt: last_bar.dt, last_lo_dt: None, diff --git a/crates/czsc-core/src/objects/signal.rs b/crates/czsc-core/src/objects/signal.rs index 6eeae9374..47af93d53 100644 --- a/crates/czsc-core/src/objects/signal.rs +++ b/crates/czsc-core/src/objects/signal.rs @@ -1,7 +1,7 @@ -// czsc-only: pyo3 imports gated behind the `python` feature (rs-czsc 47ef6efa -// relied on `#![allow(unused)]` to mask the bare imports when the feature was -// off; we make the gating explicit so czsc-core builds in non-python mode). -// See docs/MIGRATION_NOTES.md §2.4. +// czsc-only: pyo3 import 通过 `python` feature 进行门控(rs-czsc 47ef6efa +// 依赖 `#![allow(unused)]` 在 feature 关闭时屏蔽这些裸 import;我们把门控 +// 显式化,这样 czsc-core 在 non-python 模式下也能编译)。 +// 参见 docs/MIGRATION_NOTES.md §2.4。 #![allow(unused)] use anyhow::{Context, anyhow, bail}; use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Visitor}; @@ -63,10 +63,9 @@ pub struct PySignal { } impl PySignal { - /// Wrap an inner [`Signal`] for Python exposure. The constructor is - /// public so downstream crates (notably `czsc-python`'s signal - /// dispatcher) can return signal objects without round-tripping - /// through the string parser. + /// 把内部 [`Signal`] 包装成 Python 可见对象。构造函数是 public 的, + /// 这样下游 crate(特别是 `czsc-python` 的信号分发器)可以直接返回 + /// 信号对象,无需经过字符串解析器走一遭。 pub fn from_inner(inner: Signal) -> Self { Self { inner } } @@ -246,7 +245,7 @@ pub(crate) fn parse_signal_doc(doc: &str) -> ParsedSignalDoc { let mut param_template: Option = None; let mut signals: Vec = Vec::new(); - // Helper: 给定起始索引,查找第一个引号并提取匹配的内容(支持中英文引号) + // 辅助函数:给定起始索引,查找第一个引号并提取匹配的内容(支持中英文引号) fn extract_quoted(s: &str, start: usize) -> Option<(String, usize)> { // 支持的开引号及其对应的闭引号 let pairs: &[(char, char)] = &[ @@ -298,7 +297,7 @@ pub(crate) fn parse_signal_doc(doc: &str) -> ParsedSignalDoc { let mut pos = 0usize; let needle = "Signal("; while let Some(found) = doc[pos..].find(needle) { - let abs = pos + found + needle.len(); // index right after "(" + let abs = pos + found + needle.len(); // 紧跟在 "(" 之后的索引 if let Some((content, end_idx)) = extract_quoted(doc, abs) { let s = Signal::from_str(&content).ok(); if let Some(signal) = s { diff --git a/crates/czsc-core/src/python/mod.rs b/crates/czsc-core/src/python/mod.rs index b7c74765e..cfedd9bae 100644 --- a/crates/czsc-core/src/python/mod.rs +++ b/crates/czsc-core/src/python/mod.rs @@ -1,13 +1,11 @@ -//! PyO3 binding registry for czsc-core. +//! czsc-core 的 PyO3 binding 注册表。 //! -//! Phase D's per-type sub-loops add `#[cfg_attr(feature = "python", pyclass)]` -//! to each migrated type. This module collects them into a single -//! `register()` entrypoint that `czsc-python` calls from the -//! `_native` aggregator. +//! Phase D 的逐类型子循环会给每个迁移过来的类型加 `#[cfg_attr(feature = "python", pyclass)]`。 +//! 本模块把它们汇总到一个 `register()` 入口,由 `czsc-python` 在 +//! `_native` aggregator 中调用。 //! -//! Pickle (`__getstate__` / `__setstate__`) per design doc §2.4 will -//! land on a follow-up pass once Phase E/F/G land and the per-class -//! identity tests can fully exercise it. +//! 按 design doc §2.4 的 Pickle(`__getstate__` / `__setstate__`)将会 +//! 在 Phase E/F/G 落地后做一次后续提交,到时各类的 identity 测试可以充分覆盖它。 use pyo3::prelude::*; @@ -27,23 +25,22 @@ use crate::objects::position::{PyLiteBar, PyPos, PyPosition}; use crate::objects::signal::{PyParsedSignalDoc, PySignal, parse_signal_doc_py}; use crate::objects::zs::ZS; -/// Python-friendly thin wrapper around `analyze::utils::check_fx`. +/// 对 `analyze::utils::check_fx` 的 Python 友好的薄 wrapper。 #[pyfunction] #[pyo3(name = "check_fx")] fn check_fx_py(k1: NewBar, k2: NewBar, k3: NewBar) -> Option { analyze_utils::check_fx(&k1, &k2, &k3) } -/// Python-friendly thin wrapper around `analyze::utils::check_fxs`. +/// 对 `analyze::utils::check_fxs` 的 Python 友好的薄 wrapper。 #[pyfunction] #[pyo3(name = "check_fxs")] fn check_fxs_py(bars: Vec) -> Vec { analyze_utils::check_fxs(&bars) } -/// Python-friendly thin wrapper around `analyze::utils::check_bi`. -/// Drops the unused remainder slice; Python callers only ever consume -/// the optional BI value. +/// 对 `analyze::utils::check_bi` 的 Python 友好的薄 wrapper。 +/// 丢弃未使用的剩余切片;Python 调用方只消费可选的 BI 值。 #[pyfunction] #[pyo3(name = "check_bi")] fn check_bi_py(bars: Vec) -> Option { @@ -51,7 +48,7 @@ fn check_bi_py(bars: Vec) -> Option { bi } -/// Python-friendly thin wrapper around `analyze::utils::remove_include`. +/// 对 `analyze::utils::remove_include` 的 Python 友好的薄 wrapper。 #[pyfunction] #[pyo3(name = "remove_include")] fn remove_include_py(k1: NewBar, k2: NewBar, k3: RawBar) -> PyResult<(bool, NewBar)> { @@ -59,22 +56,20 @@ fn remove_include_py(k1: NewBar, k2: NewBar, k3: RawBar) -> PyResult<(bool, NewB .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string())) } -/// Python-friendly thin wrapper around `analyze::utils::format_standard_kline`. -/// Polars DataFrame is bridged via the standard pyo3-polars / arrow path; for -/// now we accept a list of pre-built RawBars to avoid the polars/python coupling -/// during D.A. The full DataFrame entrypoint will be added when Phase E/F wire -/// the polars Python bridge (see design doc §2.3). +/// 对 `analyze::utils::format_standard_kline` 的 Python 友好的薄 wrapper。 +/// Polars DataFrame 通过标准的 pyo3-polars / arrow 路径桥接;目前 +/// 我们接受一个预构建好的 RawBar 列表,以避免在 D.A 阶段引入 polars/python 的耦合。 +/// 完整的 DataFrame 入口会等到 Phase E/F 接入 polars Python 桥时再添加(详见 design doc §2.3)。 #[pyfunction] #[pyo3(name = "format_standard_kline")] fn format_standard_kline_py(bars: Vec) -> Vec { bars } -/// Add the migrated czsc-core types onto the parent module that czsc-python -/// passes in. Lives behind the `python` feature so plain Rust consumers -/// don't pull pyo3 in transitively. +/// 把迁移过来的 czsc-core 类型添加到 czsc-python 传入的父模块上。 +/// 隐藏在 `python` feature 后面,这样普通 Rust 消费者就不会传递性地引入 pyo3。 pub fn register(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { - // Enums + // 枚举 m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -82,12 +77,12 @@ pub fn register(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; - // Bar primitives + // Bar 基础类型 m.add_class::()?; m.add_class::()?; m.add_class::()?; - // Chan-theory data structures + // 缠论数据结构 m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -99,11 +94,11 @@ pub fn register(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; - // Analyzer (CZSC) + // 分析器(CZSC) m.add_class::()?; - // Free functions: signal-doc parser + analyze helpers (the 4 promotions - // from design doc §2.5) + // 自由函数:signal-doc 解析器 + analyze helpers(来自 design doc §2.5 + // 的 4 个 promotion) m.add_function(wrap_pyfunction!(parse_signal_doc_py, m)?)?; m.add_function(wrap_pyfunction!(check_fx_py, m)?)?; m.add_function(wrap_pyfunction!(check_fxs_py, m)?)?; diff --git a/crates/czsc-core/src/utils/corr.rs b/crates/czsc-core/src/utils/corr.rs index df8e20544..c0452c2af 100644 --- a/crates/czsc-core/src/utils/corr.rs +++ b/crates/czsc-core/src/utils/corr.rs @@ -119,7 +119,7 @@ pub fn pearson_corr(x: &[f64], y: &[f64]) -> Option { /// /// [wiki](https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient) /// -/// > Spearman's coefficient is appropriate for both continuous and discrete ordinal variables. +/// > Spearman 系数对于连续变量和离散有序变量都适用。 /// /// - 当数据为空或长度不一致时返回 None。 pub fn spearman_rank_corr(x: &[f64], y: &[f64]) -> Option { diff --git a/crates/czsc-core/tests/test_analyze_utils.rs b/crates/czsc-core/tests/test_analyze_utils.rs index 441c75af0..eb7a2e6c4 100644 --- a/crates/czsc-core/tests/test_analyze_utils.rs +++ b/crates/czsc-core/tests/test_analyze_utils.rs @@ -1,9 +1,9 @@ -//! Phase D.U — RED test: analyze::utils helpers (check_fx / check_fxs / -//! check_bi / remove_include / format_standard_kline) are publicly callable -//! and produce the expected shapes per the rs-czsc 47ef6efa baseline. +//! Phase D.U — RED test:analyze::utils 的 helper(check_fx / check_fxs / +//! check_bi / remove_include / format_standard_kline)对外可调用, +//! 并且产出与 rs-czsc 47ef6efa 基线一致的形状。 //! -//! This test also locks the visibility promotions required by the design -//! doc §2.5: all four `pub(crate)` helpers must now be `pub`. +//! 本测试同时锁定 design doc §2.5 要求的可见性提升: +//! 这 4 个原本 `pub(crate)` 的 helper 现在必须是 `pub`。 use std::sync::Arc; @@ -32,7 +32,7 @@ fn nb(ts: i64, high: f64, low: f64) -> NewBar { #[test] fn check_fx_detects_top_pattern() { - // top fx: middle bar engulfs both neighbours from above + // 顶分型:中间 bar 从上方包住两侧邻居 let k1 = nb(1, 11.0, 9.0); let k2 = nb(2, 12.0, 10.0); let k3 = nb(3, 11.5, 9.5); @@ -53,7 +53,7 @@ fn check_fx_detects_bottom_pattern() { #[test] fn check_fx_returns_none_when_no_pattern() { - // strictly increasing — neither top nor bottom + // 严格递增——既不是顶也不是底 let k1 = nb(1, 10.0, 9.0); let k2 = nb(2, 11.0, 10.0); let k3 = nb(3, 12.0, 11.0); @@ -62,7 +62,7 @@ fn check_fx_returns_none_when_no_pattern() { #[test] fn check_fxs_extracts_fx_from_sequence() { - // 5 bars: ascending, peak, descending → exactly one top fx in the middle + // 5 根 bar:上升、峰、下降 → 中间恰好出现一个顶分型 let bars = vec![ nb(1, 10.0, 9.0), nb(2, 11.0, 10.0), @@ -80,7 +80,7 @@ fn check_bi_returns_tuple_with_remainder() { .map(|i| nb(i + 1, 10.0 + i as f64, 9.0 + i as f64)) .collect(); let (bi, remainder) = check_bi(&bars); - // The function signature contract: always returns (Option, &[NewBar]) + // 函数签名契约:总是返回 (Option, &[NewBar]) let _ = bi; assert!(remainder.len() <= bars.len()); } diff --git a/crates/czsc-core/tests/test_bi.rs b/crates/czsc-core/tests/test_bi.rs index 7607e099c..6f782868a 100644 --- a/crates/czsc-core/tests/test_bi.rs +++ b/crates/czsc-core/tests/test_bi.rs @@ -1,5 +1,5 @@ -//! Phase D.8 — RED test: BI (笔) constructs via BIBuilder, surfaces -//! direction / endpoints, and answers length / SNR / power_price helpers. +//! Phase D.8 —— RED 测试:BI(笔)通过 BIBuilder 构造, +//! 暴露 direction / 端点,并响应 length / SNR / power_price 等辅助方法。 use std::sync::Arc; @@ -49,7 +49,7 @@ fn fx(ts: i64, mark: Mark, level: f64) -> FX { } fn sample_bi_up() -> BI { - // up bi: starts at bottom fx, ends at top fx + // 向上的笔:起点是底分型,终点是顶分型 let fx_a = fx(1_700_000_000, Mark::D, 9.0); let fx_b = fx(1_700_007_200, Mark::G, 12.0); let bars: Vec = (0..5) @@ -89,7 +89,7 @@ fn length_is_bars_count() { #[test] fn high_low_endpoints_match_fxs() { let bi = sample_bi_up(); - assert!(bi.get_low() < bi.get_high(), "low must be < high"); + assert!(bi.get_low() < bi.get_high(), "low 必须小于 high"); } #[test] diff --git a/crates/czsc-core/tests/test_czsc_analyzer.rs b/crates/czsc-core/tests/test_czsc_analyzer.rs index 6c0cbe07a..cff3f6809 100644 --- a/crates/czsc-core/tests/test_czsc_analyzer.rs +++ b/crates/czsc-core/tests/test_czsc_analyzer.rs @@ -1,6 +1,6 @@ -//! Phase D.A — RED test: CZSC analyzer constructs from a RawBar feed, -//! exposes bars_raw / bars_ubi / bi_list / fx_list, and survives an -//! incremental update_bar feed. +//! Phase D.A — RED test:CZSC 分析器从 RawBar 流构造, +//! 暴露 bars_raw / bars_ubi / bi_list / fx_list,并能正确处理 +//! 增量的 update_bar 输入。 use std::sync::Arc; @@ -26,7 +26,7 @@ fn rb(ts: i64, open: f64, close: f64, high: f64, low: f64) -> RawBar { } fn synthetic_zigzag(n: usize) -> Vec { - // Build a sine-like zigzag so that the analyzer can produce fxs/bis. + // 构造一个类正弦波形的 zigzag,让分析器能产出 fxs/bis。 (0..n) .map(|i| { let phase = (i as f64) * 0.7; @@ -56,8 +56,8 @@ fn new_populates_symbol_and_freq() { fn new_consumes_all_bars_and_builds_ubi() { let bars = synthetic_zigzag(40); let c = CZSC::new(bars, 50); - // bars_ubi is the merged-bar (NewBar) sequence; for 40 raw zigzag - // bars we expect non-empty merged sequence + // bars_ubi 是合并后 bar(NewBar)序列;对于 40 根原始 zigzag + // bar,我们期望合并后的序列非空 assert!(!c.bars_ubi.is_empty(), "bars_ubi should not be empty"); } @@ -66,9 +66,8 @@ fn fx_and_bi_lists_are_consistent_with_zigzag() { let bars = synthetic_zigzag(60); let c = CZSC::new(bars, 50); let fxs = c.get_fx_list(); - // A 60-bar zigzag should produce at least 2 fxs (or zero — the - // exact count depends on the synthetic shape; we only assert - // non-negative invariants). + // 60 根 bar 的 zigzag 应该至少产出 2 个 fx(或者 0 个—— + // 具体数量取决于合成出来的波形;这里只断言非负不变量)。 assert!(fxs.len() <= 60); assert!(c.bi_list.len() <= 50); } @@ -80,7 +79,7 @@ fn update_bar_appends_incrementally() { let extra = rb(1_700_000_000 + 30 * 1800, 102.0, 103.0, 104.0, 101.0); c.update_bar(extra); assert_eq!(c.freq, Freq::F30); - // bars_raw monotonically grows (modulo the analyzer's internal pruning) + // bars_raw 单调增长(不计分析器内部的裁剪) assert!( c.bars_raw .iter() diff --git a/crates/czsc-core/tests/test_fx.rs b/crates/czsc-core/tests/test_fx.rs index 57bbe3327..9c79033e7 100644 --- a/crates/czsc-core/tests/test_fx.rs +++ b/crates/czsc-core/tests/test_fx.rs @@ -1,6 +1,6 @@ -//! Phase D.7 — RED test: FX (分型) constructs via FXBuilder, exposes -//! power_str / power_volume / has_zs (non-python build), and compares -//! by structural equality. +//! Phase D.7 —— RED 测试:FX(分型)通过 FXBuilder 构造, +//! 暴露 power_str / power_volume / has_zs(non-python 构建), +//! 并按结构相等进行比较。 use std::sync::Arc; @@ -28,9 +28,9 @@ fn nb(ts: i64, high: f64, low: f64, vol: f64) -> NewBar { } fn sample_top_fx() -> FX { - // top fx (顶分型): middle bar's high is the highest + // 顶分型:中间 bar 的 high 是最高的 let k1 = nb(1_700_000_000, 11.0, 9.0, 100.0); - let k2 = nb(1_700_001_800, 12.0, 10.0, 200.0); // top + let k2 = nb(1_700_001_800, 12.0, 10.0, 200.0); // 顶 let k3 = nb(1_700_003_600, 11.5, 9.5, 100.0); FXBuilder::default() .symbol(Arc::::from("000001")) diff --git a/crates/czsc-core/tests/test_signal.rs b/crates/czsc-core/tests/test_signal.rs index 3be8f1361..656e79ba3 100644 --- a/crates/czsc-core/tests/test_signal.rs +++ b/crates/czsc-core/tests/test_signal.rs @@ -1,6 +1,6 @@ -//! Phase D.10b — RED test: Signal type (`SignalRef<'static>` aka `Signal`) -//! parses from the canonical k1_k2_k3_v1_v2_v3_score string and exposes -//! `key()` / `value()` / Display per the rs-czsc contract. +//! Phase D.10b —— RED 测试:Signal 类型(`SignalRef<'static>` 即 `Signal`) +//! 从规范的 k1_k2_k3_v1_v2_v3_score 字符串解析,并按 rs-czsc 的契约 +//! 暴露 `key()` / `value()` / Display。 use std::str::FromStr; @@ -10,9 +10,9 @@ use czsc_core::objects::signal::Signal; fn parses_canonical_signal_string() { let raw = "30分钟_D1_前高_看多_强_任意_0"; let s = Signal::from_str(raw).unwrap(); - // key drops "任意" parts; here all of k1/k2/k3 are concrete + // key 会去掉 "任意" 段;这里 k1/k2/k3 全部是具体值 assert_eq!(s.key(), "30分钟_D1_前高"); - // value is v1_v2_v3_score + // value 是 v1_v2_v3_score assert_eq!(s.value(), "看多_强_任意_0"); } @@ -40,7 +40,7 @@ fn equality_is_full_signal_string() { #[test] fn key_skips_wildcards() { let s = Signal::from_str("任意_D1_前高_看多_强_任意_0").unwrap(); - // k1 is 任意 → dropped from key + // k1 是「任意」→ 从 key 中剔除 assert_eq!(s.key(), "D1_前高"); } @@ -50,12 +50,9 @@ fn is_match_obeys_score_and_wildcards() { let s = Signal::from_str("30分钟_D1_前高_看多_强_任意_50").unwrap(); let mut dict = HashMap::new(); dict.insert("30分钟_D1_前高".to_string(), "看多_强_中_60".to_string()); - assert!( - s.is_match(&dict), - "score 60 >= 50 with v3 wildcard should match" - ); + assert!(s.is_match(&dict), "score 60 >= 50 且 v3 是通配符,应当匹配"); let mut low_score = HashMap::new(); low_score.insert("30分钟_D1_前高".to_string(), "看多_强_中_40".to_string()); - assert!(!s.is_match(&low_score), "score 40 < 50 must not match"); + assert!(!s.is_match(&low_score), "score 40 < 50 不应匹配"); } diff --git a/crates/czsc-core/tests/test_zs.rs b/crates/czsc-core/tests/test_zs.rs index 31b954a73..37408c4ba 100644 --- a/crates/czsc-core/tests/test_zs.rs +++ b/crates/czsc-core/tests/test_zs.rs @@ -1,5 +1,5 @@ -//! Phase D.9 — RED test: ZS (中枢) constructs from a non-empty BI list, -//! computes zg / zd / zz / gg / dd boundaries, and surfaces is_valid(). +//! Phase D.9 —— RED 测试:ZS(中枢)从非空 BI 列表构造, +//! 计算 zg / zd / zz / gg / dd 边界,并暴露 is_valid()。 use std::sync::Arc; @@ -61,7 +61,7 @@ fn make_bi( ) -> BI { let fx_a = fx(ts_a, mark_a, level_a); let fx_b = fx(ts_b, mark_b, level_b); - // bars span — endpoints determine high/low + // bars 跨度 —— 端点决定 high/low let bars = vec![ nb(ts_a, level_a + 0.5, level_a - 0.5), nb( @@ -83,7 +83,7 @@ fn make_bi( } fn sample_zs() -> ZS { - // 3-bi center: down 12 -> 9, up 9 -> 11, down 11 -> 9.5 + // 由 3 笔构成的中枢:向下 12 -> 9,向上 9 -> 11,向下 11 -> 9.5 let bi1 = make_bi( 1_700_000_000, Mark::G, @@ -125,7 +125,7 @@ fn new_populates_endpoints() { #[test] fn zg_zd_within_first_three_bis() { let zs = sample_zs(); - // zg = min of first 3 bis' highs; zd = max of first 3 bis' lows + // zg = 前 3 笔 high 的最小值;zd = 前 3 笔 low 的最大值 assert!(zs.zg >= zs.zd, "zg={} must be >= zd={}", zs.zg, zs.zd); } @@ -151,5 +151,5 @@ fn gg_dd_envelope_zg_zd() { #[test] fn is_valid_returns_bool() { let zs = sample_zs(); - let _ = zs.is_valid(); // doesn't matter true or false — must not panic + let _ = zs.is_valid(); // 返回 true 还是 false 都无所谓 —— 关键是不能 panic } diff --git a/crates/czsc-python/build.rs b/crates/czsc-python/build.rs index c03853420..c30c53438 100644 --- a/crates/czsc-python/build.rs +++ b/crates/czsc-python/build.rs @@ -1,10 +1,10 @@ -//! Build script for czsc-python. +//! czsc-python 的 build script。 //! -//! On macOS the cdylib needs `-undefined dynamic_lookup` so that Python -//! symbols are resolved at runtime by the host interpreter. PyO3's -//! `extension-module` feature normally emits this, but when building via -//! plain `cargo build --workspace` (without maturin) we make the link arg -//! explicit so the workspace layout test stays GREEN. +//! 在 macOS 上 cdylib 需要 `-undefined dynamic_lookup`,这样 Python 符号 +//! 才能在运行时由宿主解释器解析。PyO3 的 `extension-module` feature 一般 +//! 会自动加上这个 flag,但是当我们直接用 +//! `cargo build --workspace`(不走 maturin)构建时,需要显式声明这个 +//! 链接参数,以便 workspace layout test 保持 GREEN。 fn main() { if std::env::var("CARGO_CFG_TARGET_OS").as_deref() == Ok("macos") { diff --git a/crates/czsc-python/src/lib.rs b/crates/czsc-python/src/lib.rs index 28fe81fc4..470d46ba2 100644 --- a/crates/czsc-python/src/lib.rs +++ b/crates/czsc-python/src/lib.rs @@ -1,8 +1,8 @@ -//! czsc-python — PyO3 aggregator that produces the `czsc._native` extension. +//! czsc-python —— 产生 `czsc._native` 扩展的 PyO3 聚合器。 //! -//! Each business crate's PyO3 surface is registered here. The crate is -//! the only one that links `pyo3 = { features = ["extension-module"] }` -//! and produces the cdylib loaded by Python. +//! 每个业务 crate 的 PyO3 表面都在这里注册。这个 crate 是 workspace +//! 中唯一启用 `pyo3 = { features = ["extension-module"] }` 的,它产出 +//! Python 加载的 cdylib。 use pyo3::prelude::*; use pyo3::wrap_pyfunction; @@ -18,12 +18,12 @@ fn _native(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { czsc_utils::python::register(py, m)?; czsc_ta::python::register(py, m)?; - // czsc-signals contributes `SignalDescriptor` entries via - // `inventory::collect!`. The dummy iterator forces the crate - // into the final cdylib so the constructors run on import. + // czsc-signals 通过 `inventory::collect!` 贡献 `SignalDescriptor` + // 条目。这里用一次哑迭代强制把该 crate 链入最终的 cdylib, + // 这样 import 时构造器就会跑起来。 let _signals_count = inventory::iter::().count(); - // Trader surface — CzscTrader, CzscSignals, generate_czsc_signals. + // Trader 表面 —— CzscTrader、CzscSignals、generate_czsc_signals。 m.add_class::()?; m.add_class::()?; m.add_function(wrap_pyfunction!( @@ -31,9 +31,9 @@ fn _native(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { m )?)?; - // Research / optimize entrypoints (mirrors rs_czsc/python/src/lib.rs). - // These are the heavy-lift functions that strategies.py / - // research.py / optimize.py wrap thinly on the Python side. + // Research / optimize 入口(对齐 rs_czsc/python/src/lib.rs)。 + // 这些是干重活的函数,Python 侧的 strategies.py / research.py / + // optimize.py 只做薄薄一层包装。 m.add_function(wrap_pyfunction!(trader::api::list_all_signals, m)?)?; m.add_function(wrap_pyfunction!(trader::api::derive_signals_config, m)?)?; m.add_function(wrap_pyfunction!(trader::api::derive_signals_freqs, m)?)?; @@ -52,13 +52,13 @@ fn _native(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { m )?)?; - // czsc._native.signals namespace + per-category sub-modules - // (bar / cxt / tas / vol / pressure / obv / cvolp). The dispatcher - // is registered on each so that + // czsc._native.signals 命名空间 + 按类别分的子模块 + // (bar / cxt / tas / vol / pressure / obv / cvolp)。分发器在每个子模块 + // 上都注册一次,使得 // from czsc._native.signals import call_signal - // and + // 和 // from czsc._native.signals.bar import list_signal_names - // both resolve. See `signals_dispatcher.rs` for the design. + // 都能解析。设计细节见 `signals_dispatcher.rs`。 let signals = PyModule::new(py, "signals")?; signals.setattr("__name__", "czsc._native.signals")?; let sys = py.import("sys")?; diff --git a/crates/czsc-python/src/signals_dispatcher.rs b/crates/czsc-python/src/signals_dispatcher.rs index 261b8dec0..aa21c03d7 100644 --- a/crates/czsc-python/src/signals_dispatcher.rs +++ b/crates/czsc-python/src/signals_dispatcher.rs @@ -1,21 +1,19 @@ -//! czsc._native signal dispatcher (design doc §3.3). +//! czsc._native 信号分发器(设计文档 §3.3)。 //! -//! Per-signal PyO3 wrappers would require ~30+ hand-written `#[pyfunction]` -//! definitions; instead we expose a single dispatcher that looks the -//! signal up by name in the inventory table contributed by -//! `czsc-signals`. The Python-side ``czsc/signals/{bar,cxt,...}.py`` -//! shims attach a per-name closure via ``__getattr__`` so user code -//! reads naturally: +//! 若给每个信号写独立的 PyO3 wrapper,会有 30+ 个手写的 `#[pyfunction]`; +//! 这里改为暴露一个统一的分发器,通过名字在 `czsc-signals` 贡献的 inventory +//! 表里查找信号。Python 端的 ``czsc/signals/{bar,cxt,...}.py`` 通过 +//! ``__getattr__`` 给每个名字挂上一个闭包,使用方代码因而读起来很自然: //! //! ```python //! from czsc.signals.bar import bar_amount_acc_V230214 //! result = bar_amount_acc_V230214(czsc_obj, {"di": 1, "n": 5}) //! ``` //! -//! The dispatcher only handles **kline** signals (``fn(&CZSC, ¶ms, -//! &mut TaCache) -> Vec``). Trader-state signals require a -//! ``CzscTrader`` instance and are dispatched via -//! ``CzscTrader.update_signals`` / ``CzscSignals.update_signals``. +//! 本分发器只处理 **K 线** 类信号(签名为 ``fn(&CZSC, ¶ms, +//! &mut TaCache) -> Vec``)。依赖 trader 状态的信号需要 +//! ``CzscTrader`` 实例,走 ``CzscTrader.update_signals`` / +//! ``CzscSignals.update_signals`` 路径分发。 use crate::trader::czsc_signals::py_to_serde_value; use czsc_core::analyze::CZSC; @@ -27,24 +25,22 @@ use pyo3::types::PyDict; use serde_json::Value; use std::collections::HashMap; -/// Find a signal descriptor by name. Returns `None` if no descriptor -/// has that name. Callers should treat this as a missing signal. +/// 按名字查找信号 descriptor。找不到时返回 `None`,调用方应将其视为 +/// 信号未注册。 fn lookup(name: &str) -> Option<&'static SignalDescriptor> { inventory::iter::() .into_iter() .find(|d| d.name == name) } -/// Extract the category prefix from a signal name (e.g. ``bar`` from -/// ``bar_amount_acc_V230214``). Returns ``None`` if the name has no -/// underscore. +/// 从信号名中提取分类前缀(如从 ``bar_amount_acc_V230214`` 中取出 +/// ``bar``)。名字里不含下划线时返回 ``None``。 fn name_prefix(name: &str) -> Option<&str> { name.split_once('_').map(|(p, _)| p) } -/// Convert a Python params dict (or `None`) into the -/// ``HashMap`` shape used by all kline signal -/// functions. Accepts ``None`` as an empty dict. +/// 把 Python 端传入的 params 字典(或 ``None``)转换为所有 K 线类信号函数 +/// 都接受的 ``HashMap``。``None`` 视为空字典。 fn extract_params(params: Option<&Bound<'_, PyDict>>) -> PyResult> { let mut out: HashMap = HashMap::new(); if let Some(d) = params { @@ -57,10 +53,10 @@ fn extract_params(params: Option<&Bound<'_, PyDict>>) -> PyResult) -> Vec { @@ -107,28 +102,25 @@ pub fn list_signal_names(category: Option<&str>) -> Vec { out } -/// Return the parameter template for ``name``, or ``None`` if no signal -/// with that name is registered. The template is the schema string -/// declared in the `#[signal(...)]` macro and matches what the legacy -/// Python helpers parse. +/// 返回 ``name`` 对应信号的参数模板字符串;若未注册则返回 ``None``。 +/// 模板即 `#[signal(...)]` 宏里声明的 schema,与历史 Python 辅助代码 +/// 解析的字符串保持一致。 #[pyfunction] pub fn get_signal_template(name: &str) -> Option { lookup(name).map(|d| d.template.to_string()) } -/// Return the category prefix for ``name`` (``"bar"`` / ``"cxt"`` / -/// ...). ``None`` when the signal isn't registered or its name has no -/// underscore. +/// 返回 ``name`` 的分类前缀(``"bar"`` / ``"cxt"`` / ...)。信号未注册 +/// 或名字里不含下划线时返回 ``None``。 #[pyfunction] pub fn get_signal_category(name: &str) -> Option { let descriptor = lookup(name)?; name_prefix(descriptor.name).map(|p| p.to_string()) } -/// Register the dispatcher symbols on both ``czsc._native`` (top-level) -/// and ``czsc._native.signals`` (submodule). The submodule entries -/// give design-doc §3.3 the path ``from czsc._native.signals import -/// call_signal``. +/// 把分发器相关符号同时挂到 ``czsc._native``(顶层)和 +/// ``czsc._native.signals``(子模块)下。子模块入口对应设计文档 §3.3 +/// 描述的导入路径 ``from czsc._native.signals import call_signal``。 pub fn register( py: Python<'_>, m: &Bound<'_, PyModule>, @@ -146,14 +138,14 @@ pub fn register( signals_mod.add_function(wrap_pyfunction!(get_signal_template, signals_mod)?)?; signals_mod.add_function(wrap_pyfunction!(get_signal_category, signals_mod)?)?; - // Per-category sub-modules: czsc._native.signals.{bar,cxt,...}. - // Each gets the full dispatcher trio so user code can write: + // 按分类创建子模块:czsc._native.signals.{bar,cxt,...}。 + // 每个子模块都挂上完整的分发器三件套,使用方代码可以这样写: // // import czsc._native.signals.bar as bar_mod - // bar_mod.list_signal_names() # only bar_* names + // bar_mod.list_signal_names() # 只列 bar_* 信号 // - // The Python-side `czsc/signals/.py` shim layers __getattr__ - // on top of these to expose individual functions. + // Python 侧的 `czsc/signals/.py` 在此基础上叠加 __getattr__, + // 把单个信号函数暴露成可直接调用的属性。 let categories = ["bar", "cxt", "tas", "vol", "pressure", "obv", "cvolp"]; let sys = py.import("sys")?; let py_modules = sys.getattr("modules")?; diff --git a/crates/czsc-python/src/trader/api.rs b/crates/czsc-python/src/trader/api.rs index 9b73d6c95..018a28485 100644 --- a/crates/czsc-python/src/trader/api.rs +++ b/crates/czsc-python/src/trader/api.rs @@ -1268,7 +1268,7 @@ pub fn generate_signals( vec![] }; - // 2) Arrow bytes -> bars + // 2) Arrow bytes 转换为 bars let raw_data = bars_bytes.as_bytes(); let df = pyarrow_to_df(raw_data) .map_err(|e| PyValueError::new_err(format!("Arrow bytes 转 DataFrame 失败: {e}")))?; diff --git a/crates/czsc-python/src/trader/czsc_signals.rs b/crates/czsc-python/src/trader/czsc_signals.rs index cc7b88ad9..f3fc153aa 100644 --- a/crates/czsc-python/src/trader/czsc_signals.rs +++ b/crates/czsc-python/src/trader/czsc_signals.rs @@ -246,9 +246,9 @@ impl PyCzscSignals { } } -/// Helper: convert `Vec` back to a Python ``list[dict]`` -/// shaped exactly like ``parse_signals_config`` expects, so -/// ``__reduce__`` -> ``__new__`` round-trips cleanly. +/// 辅助函数:把 `Vec` 还原回与 ``parse_signals_config`` +/// 期望形状完全一致的 Python ``list[dict]``,让 +/// ``__reduce__`` -> ``__new__`` 能干净地往返序列化。 pub(crate) fn signal_configs_to_pylist(py: Python, configs: &[SignalConfig]) -> PyResult { let list = PyList::empty(py); for cfg in configs { diff --git a/crates/czsc-python/src/trader/czsc_trader.rs b/crates/czsc-python/src/trader/czsc_trader.rs index df0377b08..f7ca0889c 100644 --- a/crates/czsc-python/src/trader/czsc_trader.rs +++ b/crates/czsc-python/src/trader/czsc_trader.rs @@ -344,7 +344,7 @@ impl PyCzscTrader { fn __reduce__(&self, py: Python) -> PyResult { let bg_clone = self.inner.signals.bg.clone(); - // positions: clone via PyPosition wrappers + // positions:通过 PyPosition wrapper 克隆 let positions_list = PyList::empty(py); for pos in &self.inner.positions { let py_pos = PyPosition { inner: pos.clone() }; diff --git a/crates/czsc-python/src/trader/mod.rs b/crates/czsc-python/src/trader/mod.rs index 6443be181..274b37c5a 100644 --- a/crates/czsc-python/src/trader/mod.rs +++ b/crates/czsc-python/src/trader/mod.rs @@ -1,11 +1,11 @@ -//! PyO3 wrappers for czsc-trader public objects (CzscTrader / CzscSignals), -//! the `generate_czsc_signals` free function, and the research/optimize -//! orchestration entrypoints (`run_research`, `run_replay`, -//! `run_optimize_batch`, `build_*_optim_positions`). +//! czsc-trader 公共对象(CzscTrader / CzscSignals)的 PyO3 包装层、 +//! `generate_czsc_signals` 自由函数,以及 research/optimize 编排入口 +//! (`run_research`、`run_replay`、`run_optimize_batch`、 +//! `build_*_optim_positions`)。 //! -//! Mirrors `rs_czsc/python/src/trader/`. The `weight_backtest` submodule -//! from rs-czsc is intentionally NOT migrated — czsc relies on the -//! external `wbt` package for backtests (design doc §3.1 / §5.10). +//! 对齐 `rs_czsc/python/src/trader/`。rs-czsc 中的 `weight_backtest` +//! 子模块**有意**不迁移过来 —— czsc 依赖外部 `wbt` 包做回测 +//! (design doc §3.1 / §5.10)。 pub mod api; pub mod czsc_signals; diff --git a/crates/czsc-signal-macros/tests/test_export.rs b/crates/czsc-signal-macros/tests/test_export.rs index 9bbb1fbba..f6018727d 100644 --- a/crates/czsc-signal-macros/tests/test_export.rs +++ b/crates/czsc-signal-macros/tests/test_export.rs @@ -1,14 +1,13 @@ -//! Phase E.last — smoke test: czsc-signal-macros compiles and the test -//! binary can link against it. Proc-macros are compile-time constructs, -//! so the real validation is that this test target builds at all. +//! Phase E.last — smoke test:czsc-signal-macros 能编译,且测试二进制 +//! 能成功 link 到它。proc-macro 属于编译期构造,因此真正的验证就是 +//! 这个 test target 能不能编译通过。 //! -//! Full expansion testing requires czsc-signals types and lands in -//! Phase F (every signal module under crates/czsc-signals/src/*.rs -//! exercises `#[signal_module]` and `#[signal]` against the real -//! types). +//! 完整的展开测试需要 czsc-signals 的类型,安排在 Phase F +//! 进行(crates/czsc-signals/src/*.rs 下的每个信号模块都会针对 +//! 真实类型走一遍 `#[signal_module]` 和 `#[signal]`)。 #[test] fn proc_macro_crate_links() { - // Reaching this function means the crate compiled with both - // #[proc_macro_attribute] entrypoints exported. + // 能跑到这个函数就说明 crate 编译通过,并且两个 + // #[proc_macro_attribute] 入口都已正常导出。 } diff --git a/crates/czsc-signals/src/utils/ta.rs b/crates/czsc-signals/src/utils/ta.rs index 31571fd32..12b84f2d9 100644 --- a/crates/czsc-signals/src/utils/ta.rs +++ b/crates/czsc-signals/src/utils/ta.rs @@ -249,9 +249,9 @@ fn calc_wma_cache_style(series: &[f64], n: usize) -> Vec { } /// 计算 MACD {dif, dea, macd},对齐 TA-Lib 原始三元组语义: -/// - `dif` = MACD line -/// - `dea` = signal line -/// - `macd` = histogram +/// - `dif` = MACD 线 +/// - `dea` = 信号线 +/// - `macd` = 柱状图 pub fn calc_macd(series: &[f64], short: usize, long: usize, m: usize) -> MacdSeries { let len = series.len(); let mut dif = vec![f64::NAN; len]; @@ -1378,8 +1378,8 @@ fn calc_stoch( let mut slowk = calc_sma_nan(&fastk, slowk_period); let mut slowd = calc_sma_nan(&slowk, slowd_period); - // Align with TA-Lib STOCH lookback: fastk-1 + slowk-1 + slowd-1. - // TA-Lib returns both slowk/slowd as NaN before this index. + // 与 TA-Lib STOCH lookback 对齐:fastk-1 + slowk-1 + slowd-1。 + // TA-Lib 在该索引之前的 slowk/slowd 都返回 NaN。 let lookback = (fastk_period - 1) + (slowk_period - 1) + (slowd_period - 1); let warmup = lookback.min(len); for i in 0..warmup { diff --git a/crates/czsc-ta/src/pure.rs b/crates/czsc-ta/src/pure.rs index 40b0e2462..6cb7c546a 100644 --- a/crates/czsc-ta/src/pure.rs +++ b/crates/czsc-ta/src/pure.rs @@ -1,13 +1,12 @@ //! 纯 Rust 实现 -/// Plain Simple Moving Average — talib.SMA-compatible. +/// 简单移动平均 —— 与 talib.SMA 兼容。 /// -/// Returns the rolling mean of `series` with window size `n`. Indices -/// 0..n-1 are filled with NaN to match the talib convention (so -/// `np.isfinite(out[n:])` is fully True). This is intentionally -/// distinct from `single_sma_positions`, which computes a double SMA -/// then derives a [-1, 0, 1] position signal — `sma` here is the raw -/// moving average needed by `czsc.ta.sma` per design doc §3.1. +/// 返回 `series` 在窗口大小为 `n` 下的滚动均值。索引 0..n-1 用 NaN +/// 填充以匹配 talib 约定(这样 `np.isfinite(out[n:])` 全为 True)。 +/// 它与 `single_sma_positions` 有意区分:后者先计算两次 SMA 再 +/// 派生 [-1, 0, 1] 持仓信号;这里的 `sma` 是 design doc §3.1 中 +/// `czsc.ta.sma` 所需的原始移动平均。 pub fn sma(series: &[f64], n: usize) -> Vec { let len = series.len(); if len == 0 || n == 0 { @@ -904,13 +903,12 @@ pub fn rank_positions(series: &[f64], n: usize) -> Vec { /// 计算指数移动平均 (EMA) /// 返回每个点的 EMA 值 pub fn ema(series: &[f64], period: usize) -> Vec { - // talib-compatible EMA: warmup [0, period-1) is NaN, position - // (period-1) is seeded with the simple mean of the first `period` - // samples, then the standard recurrence runs forward. This matches - // talib.EMA's output bit-for-bit (verified by Phase A's - // `test_ema_matches_talib`). The previous rs-czsc implementation - // seeded with `series[0]`, which produced visible divergence in - // the first ~30 bars. + // 与 talib 兼容的 EMA:warmup [0, period-1) 部分为 NaN,位置 + // (period-1) 用前 `period` 个样本的简单平均作为种子,随后按 + // 标准递推向前计算。这与 talib.EMA 的输出逐 bit 一致(已由 + // Phase A 的 `test_ema_matches_talib` 验证)。早先 rs-czsc 的 + // 实现以 `series[0]` 作为种子,在前约 30 根 bar 上会出现可见 + // 偏差。 let len = series.len(); if len == 0 || period == 0 { return vec![]; diff --git a/crates/czsc-ta/src/python.rs b/crates/czsc-ta/src/python.rs index 83e7b1e51..e67df3dcc 100644 --- a/crates/czsc-ta/src/python.rs +++ b/crates/czsc-ta/src/python.rs @@ -1,10 +1,10 @@ -//! PyO3 binding registry for czsc-ta. +//! czsc-ta 的 PyO3 绑定注册表。 //! -//! Mirrors the wrapper layer that rs-czsc kept inside its python crate -//! (rs_czsc/python/src/utils/ta.rs); we move the `#[pyfunction]` shells -//! into czsc-ta itself so czsc-python only orchestrates `register()` -//! calls. All wrappers are dormant unless the `python` feature -//! (or `rust-numpy` for the numpy-bound entries) is on. +//! 镜像 rs-czsc 原本放在其 python crate 里的 wrapper 层 +//! (rs_czsc/python/src/utils/ta.rs);我们把 `#[pyfunction]` 外壳 +//! 搬到 czsc-ta 自身,这样 czsc-python 只负责编排 `register()` +//! 调用。除非启用 `python` feature(numpy-bound 条目则需要 +//! `rust-numpy`),否则所有 wrapper 都处于休眠状态。 use pyo3::prelude::*; @@ -17,10 +17,9 @@ fn ultimate_smoother(close: Vec, period: f64) -> Vec { #[pyfunction] fn rolling_rank(series: Vec, window: usize) -> Vec { - // Convert Option -> f64 (None -> NaN) so `np.asarray(...)` lands - // in float64 dtype and `np.isfinite(out[window:])` works as expected. - // Python callers consuming the rank position can `.dropna()` instead of - // filtering Nones. + // 把 Option 转成 f64(None -> NaN),这样 `np.asarray(...)` 落到 + // float64 dtype,`np.isfinite(out[window:])` 也能按预期工作。 + // 消费 rank 位置的 Python 调用方可以用 `.dropna()` 代替对 None 的过滤。 pure::rolling_rank(&series, window) .into_iter() .map(|opt| opt.map(|r| r as f64).unwrap_or(f64::NAN)) @@ -35,9 +34,9 @@ fn sma( period: Option, length: Option, ) -> Vec { - // Same kwarg story as `ema` — talib's keyword is `timeperiod` / - // pandas-ta's is `length`; rs-czsc historical scripts pass `n` / - // `period`. Phase A parity test calls `ta.sma(series, length=20)`. + // 关键字参数情况和 `ema` 一样 —— talib 的关键字是 `timeperiod` / + // pandas-ta 的是 `length`;rs-czsc 历史脚本传 `n` / `period`。 + // Phase A parity test 调用 `ta.sma(series, length=20)`。 let p = n.or(period).or(length).unwrap_or(0); pure::sma(&series, p) } @@ -111,12 +110,11 @@ fn ema( period: Option, length: Option, ) -> Vec { - // Accept any of: positional `n`, kwargs `period=` (legacy rs-czsc) or - // `length=` (talib / pandas-ta convention). The Phase A parity test - // in `test/unit/test_ta_parity.py::test_ema_matches_talib` calls - // `ta.ema(series, length=14)`; rs-czsc historical scripts pass - // `period=14`. Resolution order preserves the positional path first - // so existing positional callers keep working. + // 接受以下任一形式:位置参数 `n`、关键字参数 `period=`(rs-czsc 遗留) + // 或 `length=`(talib / pandas-ta 惯例)。Phase A parity test + // 中 `test/unit/test_ta_parity.py::test_ema_matches_talib` 调用 + // `ta.ema(series, length=14)`;rs-czsc 历史脚本传 `period=14`。 + // 解析顺序优先保留位置参数路径,让既有的位置参数调用方继续工作。 let p = n.or(period).or(length).unwrap_or(0); pure::ema(&series, p) } @@ -185,15 +183,15 @@ fn holt_winters( pure::holt_winters(&series, season_length, alpha, beta, gamma) } -/// Add the migrated czsc-ta functions onto the parent module that -/// czsc-python passes in. Build a `ta` submodule mirroring the design -/// doc §3.1 namespace map (czsc.ta.* + repeated top-level exposure). +/// 把迁移过来的 czsc-ta 函数挂到 czsc-python 传入的父模块上。构建一个 +/// `ta` 子模块,镜像 design doc §3.1 的命名空间映射(czsc.ta.* 以及在 +/// 顶层重复暴露)。 pub fn register(py: Python<'_>, parent: &Bound<'_, PyModule>) -> PyResult<()> { let ta = PyModule::new(py, "ta")?; - // Set the fully-qualified __name__ so `czsc.ta` (aliased via - // sys.modules) reports `__name__ == "czsc._native.ta"`. Required - // by the public-API parity test that checks namespace origin and - // by pickle when classes living in this submodule get round-tripped. + // 设置全限定的 __name__,使 `czsc.ta`(通过 sys.modules 别名暴露) + // 报告 `__name__ == "czsc._native.ta"`。检查命名空间来源的 + // public-API parity test 需要这个值;当该子模块里的类被 + // pickle 往返序列化时也需要这个值。 ta.setattr("__name__", "czsc._native.ta")?; macro_rules! add { @@ -231,7 +229,7 @@ pub fn register(py: Python<'_>, parent: &Bound<'_, PyModule>) -> PyResult<()> { holt_winters, ); - // numpy-bound entries + // numpy-bound 条目 ta.add_function(wrap_pyfunction!( mixed::chip_dist::chip_distribution_triangle, &ta @@ -241,19 +239,18 @@ pub fn register(py: Python<'_>, parent: &Bound<'_, PyModule>) -> PyResult<()> { parent )?)?; - // Register the submodule into sys.modules so `from czsc._native.ta - // import ema` (and `import czsc._native.ta`) works the same as a - // pure-Python package. `parent.add_submodule` only sets it as an - // attribute of the parent — sys.modules is the bit Python's import - // machinery actually consults for nested module resolution. + // 把子模块注册到 sys.modules,这样 `from czsc._native.ta + // import ema`(以及 `import czsc._native.ta`)就能像纯 Python + // 包一样工作。`parent.add_submodule` 只是把它设置为父模块的 + // 一个属性 —— Python 的 import 机制在做嵌套模块解析时实际查询的 + // 是 sys.modules。 let sys = py.import("sys")?; let py_modules = sys.getattr("modules")?; py_modules.set_item("czsc._native.ta", &ta)?; - // Use `parent.add` instead of `add_submodule` so we control the - // attribute key (`parent.ta`) independently of the module's - // qualified __name__ (`czsc._native.ta`). add_submodule uses the - // qualified name as the attribute, which would expose the - // submodule as `parent.czsc._native.ta` instead of `parent.ta`. + // 使用 `parent.add` 而不是 `add_submodule`,这样可以让属性 key + // (`parent.ta`)与模块的全限定 __name__(`czsc._native.ta`)相互 + // 独立地受控。add_submodule 用全限定名作为属性,会把子模块 + // 暴露成 `parent.czsc._native.ta` 而不是 `parent.ta`。 parent.add("ta", &ta)?; Ok(()) } diff --git a/crates/czsc-ta/tests/test_pure.rs b/crates/czsc-ta/tests/test_pure.rs index 35d913c8c..dc0bbf93c 100644 --- a/crates/czsc-ta/tests/test_pure.rs +++ b/crates/czsc-ta/tests/test_pure.rs @@ -1,6 +1,5 @@ -//! Phase E.2 — RED test: czsc-ta pure operators preserve length, behave -//! correctly on degenerate inputs, and produce sensible numeric output -//! aligned with the rs-czsc 47ef6efa baseline. +//! Phase E.2 — RED test:czsc-ta 纯算子保持长度、在退化输入上表现 +//! 正确,并产出与 rs-czsc 47ef6efa baseline 对齐的合理数值输出。 use czsc_ta::pure::{ boll_positions, double_sma_positions, ema, mid_positions, rolling_rank, single_ema_positions, @@ -20,7 +19,7 @@ fn ultimate_smoother_preserves_length() { #[test] fn ultimate_smoother_first_4_passthrough() { - // First 4 values must equal input per the rs-czsc contract. + // 按 rs-czsc 契约,前 4 个值必须等于输入。 let s = series(20); let out = ultimate_smoother(&s, 10.0); for i in 0..4 { @@ -54,11 +53,11 @@ fn ema_preserves_length() { #[test] fn ema_zero_period_returns_empty() { - // rs-czsc's ema short-circuits to an empty Vec when period == 0; - // lock that behaviour so callers can rely on a stable contract. + // rs-czsc 的 ema 在 period == 0 时短路返回空 Vec; + // 锁定这个行为,让调用方可以依赖一个稳定的契约。 let s = vec![1.0, 2.0, 3.0]; let out = ema(&s, 0); - assert!(out.is_empty(), "ema(_, 0) must short-circuit to empty"); + assert!(out.is_empty(), "ema(_, 0) 必须短路返回空 Vec"); } #[test] @@ -78,10 +77,7 @@ fn mid_positions_in_range() { let s = series(30); let out = mid_positions(&s, 5); for v in &out { - assert!( - *v >= -1.0 && *v <= 1.0, - "expected position in [-1, 1], got {v}" - ); + assert!(*v >= -1.0 && *v <= 1.0, "持仓应在 [-1, 1] 区间,实际为 {v}"); } } @@ -98,7 +94,7 @@ fn boll_positions_in_signed_range() { let out = boll_positions(&s, 20, 2.0); assert_eq!(out.len(), s.len()); for v in &out { - assert!(*v >= -1 && *v <= 1, "boll position must be -1/0/1, got {v}"); + assert!(*v >= -1 && *v <= 1, "boll 持仓必须为 -1/0/1,实际为 {v}"); } } @@ -111,6 +107,6 @@ fn true_range_matches_input_length() { assert_eq!(tr.len(), 4); // tr[i] = max(high-low, |high-prev|, |low-prev|) >= 0 for v in &tr { - assert!(*v >= 0.0, "true_range must be non-negative, got {v}"); + assert!(*v >= 0.0, "true_range 必须非负,实际为 {v}"); } } diff --git a/crates/czsc-trader/src/lib.rs b/crates/czsc-trader/src/lib.rs index f0e2e74be..b9b757ed5 100644 --- a/crates/czsc-trader/src/lib.rs +++ b/crates/czsc-trader/src/lib.rs @@ -1,11 +1,10 @@ -//! czsc-trader — multi-strategy trading engine, signal compilation, and -//! optimization. Migrated from rs-czsc 47ef6efa per docs/MIGRATION_NOTES.md §1. +//! czsc-trader — 多策略交易引擎、信号编译以及优化。 +//! 按 docs/MIGRATION_NOTES.md §1 从 rs-czsc 47ef6efa 迁移而来。 //! -//! `weight_backtest` was deliberately not migrated: per design doc §5.8 -//! item 3 and §5.10, the public `WeightBacktest` API is delegated to the -//! external `wbt` package starting in Phase I. The Rust workspace owns -//! signal compilation, the trader state machine, and the v2 execution -//! engine that backs Python's `run_backtest` / `run_optimize` calls. +//! `weight_backtest` 故意没有迁移:按 design doc §5.8 item 3 和 §5.10, +//! 公开的 `WeightBacktest` API 从 Phase I 开始委托给外部 `wbt` 包。 +//! Rust workspace 负责信号编译、trader 状态机,以及支撑 Python +//! `run_backtest` / `run_optimize` 调用的 v2 执行引擎。 pub mod engine_v2; pub mod optimize; diff --git a/crates/czsc-trader/src/optimize.rs b/crates/czsc-trader/src/optimize.rs index e2b8f1b66..aa56903b7 100644 --- a/crates/czsc-trader/src/optimize.rs +++ b/crates/czsc-trader/src/optimize.rs @@ -178,13 +178,13 @@ pub fn get_exit_optim_positions( let event_str = py_repr_event_dump(&event); let hash = hash_str(&event_str); - // mode = append + // mode = append(追加模式) let mut pos_append = beta.clone(); pos_append.exits.push(event.clone()); pos_append.name = format!("{}#追加{}", beta.name, hash); pos_list.push(pos_append); - // mode = replace + // mode = replace(替换模式) let mut pos_replace = beta.clone(); pos_replace.exits = vec![event.clone()]; pos_replace.name = format!("{}#替换{}", beta.name, hash); diff --git a/crates/czsc-utils/src/lib.rs b/crates/czsc-utils/src/lib.rs index 667c2b8fe..7245ab620 100644 --- a/crates/czsc-utils/src/lib.rs +++ b/crates/czsc-utils/src/lib.rs @@ -1,11 +1,11 @@ -//! czsc-utils — utilities crate. +//! czsc-utils — 工具 crate。 //! -//! Phase C scope so far: -//! - [`trading_time`] — `is_trading_time` (czsc-only NEW per design doc §2.5) +//! 目前 Phase C 范围: +//! - [`trading_time`] — `is_trading_time`(czsc-only NEW,按 design doc §2.5) //! -//! Pending (deferred until Phase D unlocks `czsc-core`): -//! - `freq_data` — depends on `czsc_core::objects::{RawBar, Freq, Market}` -//! - `bar_generator` — depends on `czsc-core` +//! 待办(延后到 Phase D 解锁 `czsc-core` 再做): +//! - `freq_data` — 依赖 `czsc_core::objects::{RawBar, Freq, Market}` +//! - `bar_generator` — 依赖 `czsc-core` pub mod bar_generator; pub mod errors; diff --git a/crates/czsc-utils/src/python/mod.rs b/crates/czsc-utils/src/python/mod.rs index e8d53b2d8..e39604789 100644 --- a/crates/czsc-utils/src/python/mod.rs +++ b/crates/czsc-utils/src/python/mod.rs @@ -1,5 +1,5 @@ -//! PyO3 bindings for czsc-utils. Gated by the `python` feature so that -//! downstream Rust consumers don't pull pyo3 in transitively. +//! czsc-utils 的 PyO3 binding。通过 `python` feature 来开关, +//! 这样下游 Rust 消费者就不会传递性地引入 pyo3。 use chrono::{DateTime, Utc}; use czsc_core::objects::{freq::Freq, market::Market}; @@ -7,20 +7,20 @@ use pyo3::prelude::*; use crate::bar_generator::BarGenerator; -/// `czsc.is_trading_time(dt, market="astock")` → bool. +/// `czsc.is_trading_time(dt, market="astock")` → bool。 /// -/// `dt` is taken as a naive Python `datetime` (no tz attached). See -/// design doc §2.5 + §6 F6 for the contract. +/// `dt` 视为 naive 的 Python `datetime`(不附带 tz)。契约详见 +/// design doc §2.5 + §6 F6。 #[pyfunction] #[pyo3(signature = (dt, market="astock"))] fn is_trading_time(dt: chrono::NaiveDateTime, market: &str) -> bool { crate::is_trading_time(dt, market) } -/// `czsc.freq_end_time(dt, freq, market=Market.Default)` → datetime. +/// `czsc.freq_end_time(dt, freq, market=Market.Default)` → datetime。 /// -/// Wraps `czsc_utils::freq_data::freq_end_time`. Errors are mapped to -/// `PyValueError` via `UtilsError`'s PyErr conversion. +/// 包装 `czsc_utils::freq_data::freq_end_time`。错误通过 `UtilsError` +/// 的 PyErr 转换映射到 `PyValueError`。 #[pyfunction] #[pyo3(signature = (dt, freq, market=Market::Default))] fn freq_end_time(dt: DateTime, freq: Freq, market: Market) -> PyResult> { @@ -28,9 +28,8 @@ fn freq_end_time(dt: DateTime, freq: Freq, market: Market) -> PyResult, parent: &Bound<'_, PyModule>) -> PyResult<()> { let utils = PyModule::new(py, "utils")?; utils.add_function(wrap_pyfunction!(is_trading_time, &utils)?)?; @@ -38,8 +37,8 @@ pub fn register(py: Python<'_>, parent: &Bound<'_, PyModule>) -> PyResult<()> { utils.add_class::()?; parent.add_submodule(&utils)?; - // Also expose top-level so `from czsc._native import *` makes the - // canonical names directly visible (per design doc §3.1). + // 同时在顶层暴露一份,这样 `from czsc._native import *` 时 + // 规范名称可以直接可见(按 design doc §3.1)。 parent.add_function(wrap_pyfunction!(is_trading_time, parent)?)?; parent.add_function(wrap_pyfunction!(freq_end_time, parent)?)?; parent.add_class::()?; diff --git a/crates/czsc-utils/src/trading_time.rs b/crates/czsc-utils/src/trading_time.rs index 899cbcdd0..ffdc43ae1 100644 --- a/crates/czsc-utils/src/trading_time.rs +++ b/crates/czsc-utils/src/trading_time.rs @@ -1,8 +1,8 @@ -//! Trading-time predicate. czsc-only addition; see docs/MIGRATION_NOTES.md §2.2. +//! 交易时间判定。czsc-only 新增;详见 docs/MIGRATION_NOTES.md §2.2。 //! -//! Inputs are interpreted as the market's local naive datetime (no tz -//! attached): A股 → CST, 港股 → HKT, crypto → any. The Python wrapper at -//! `czsc.is_trading_time` keeps the same contract. +//! 输入按市场本地的 naive datetime 解释(不附带 tz): +//! A股 → CST,港股 → HKT,crypto → 任意。Python wrapper +//! `czsc.is_trading_time` 保持相同的契约。 use chrono::{Datelike, NaiveDateTime, Timelike, Weekday}; @@ -20,9 +20,8 @@ fn is_weekday(dt: &NaiveDateTime) -> bool { !matches!(dt.weekday(), Weekday::Sat | Weekday::Sun) } -/// Return true iff `dt` (local market time) falls inside the regular trading -/// session for `market`. Recognised values: `astock`, `hk`, `crypto`. Any -/// other string returns `false`. +/// 当且仅当 `dt`(市场本地时间)落在 `market` 的常规交易时段内时返回 true。 +/// 识别的取值:`astock`、`hk`、`crypto`。其他字符串一律返回 `false`。 pub fn is_trading_time(dt: NaiveDateTime, market: &str) -> bool { match market { "crypto" => true, @@ -39,7 +38,7 @@ pub fn is_trading_time(dt: NaiveDateTime, market: &str) -> bool { return false; } let m = minute_of_day(&dt); - // HK lunch break: 12:00-13:00 (12:00 is closed) + // 港股午休:12:00-13:00(12:00 已闭市) (hm_minutes(9, 30)..hm_minutes(12, 0)).contains(&m) || (hm_minutes(13, 0)..=hm_minutes(16, 0)).contains(&m) } diff --git a/crates/czsc-utils/tests/test_freq_data.rs b/crates/czsc-utils/tests/test_freq_data.rs index 2956a353c..c02f97073 100644 --- a/crates/czsc-utils/tests/test_freq_data.rs +++ b/crates/czsc-utils/tests/test_freq_data.rs @@ -1,5 +1,5 @@ -//! Phase C.1 — RED test: freq_end_time + infer_market_from_bars match the -//! rs-czsc 47ef6efa baseline behaviour. +//! Phase C.1 — RED test:freq_end_time + infer_market_from_bars 与 +//! rs-czsc 47ef6efa 基线行为一致。 use std::sync::Arc; @@ -11,9 +11,9 @@ use czsc_utils::freq_data::{freq_end_time, infer_market_from_bars}; #[test] fn freq_end_time_returns_some_datetime_for_30min_default_market() { - // For an arbitrary intraday minute we just want to confirm the helper - // succeeds and the result is non-decreasing (the actual minute table - // is encoded in minutes_split.feather and is the rs-czsc baseline). + // 对于任意一个盘中分钟,我们只想确认 helper 调用成功,并且 + // 结果是非递减的(具体的分钟表编码在 minutes_split.feather 中, + // 即 rs-czsc 基线)。 let dt = Utc.with_ymd_and_hms(2024, 1, 8, 9, 30, 0).unwrap(); let edt = freq_end_time(dt, Freq::F30, Market::Default).unwrap(); assert!(edt >= dt, "edt {edt} must be >= input dt {dt}"); @@ -21,8 +21,8 @@ fn freq_end_time_returns_some_datetime_for_30min_default_market() { #[test] fn freq_end_time_idempotent_when_already_at_boundary() { - // If a query already lands on a boundary the function must round-trip - // (i.e. calling it again should return the same instant). + // 如果一次查询已经正好落在边界上,函数必须能 round-trip + // (即再调一次应该返回同一个时刻)。 let dt = Utc.with_ymd_and_hms(2024, 1, 8, 10, 0, 0).unwrap(); let edt1 = freq_end_time(dt, Freq::F30, Market::Default).unwrap(); let edt2 = freq_end_time(edt1, Freq::F30, Market::Default).unwrap(); @@ -31,9 +31,8 @@ fn freq_end_time_idempotent_when_already_at_boundary() { #[test] fn freq_end_time_handles_daily_freq() { - // For higher timeframes the function is still callable and should not - // panic; the exact boundary semantics are encoded in rs-czsc and we - // simply lock that calling it returns Ok. + // 对更高级别的时间周期,函数仍然可以调用且不应 panic; + // 具体的边界语义编码在 rs-czsc 里,我们这里只锁定调用返回 Ok。 let dt = Utc.with_ymd_and_hms(2024, 1, 8, 14, 30, 0).unwrap(); let _ = freq_end_time(dt, Freq::D, Market::Default).unwrap(); } diff --git a/crates/czsc-utils/tests/test_trading_time.rs b/crates/czsc-utils/tests/test_trading_time.rs index a780eb0fa..df96f9220 100644 --- a/crates/czsc-utils/tests/test_trading_time.rs +++ b/crates/czsc-utils/tests/test_trading_time.rs @@ -1,7 +1,7 @@ -//! Phase C.3 — RED test: is_trading_time across A股 / 港股 / crypto. +//! Phase C.3 — RED 测试:is_trading_time 覆盖 A股 / 港股 / crypto。 //! -//! Mirrors the cases locked by test/unit/test_trading_time.py (Phase A.6). -//! The function is czsc-only — see docs/MIGRATION_NOTES.md §2.2. +//! 镜像 test/unit/test_trading_time.py(Phase A.6)锁定的用例。 +//! 该函数为 czsc-only —— 见 docs/MIGRATION_NOTES.md §2.2。 use chrono::NaiveDate; use czsc_utils::is_trading_time; @@ -15,7 +15,7 @@ fn dt(y: i32, mo: u32, d: u32, h: u32, mi: u32) -> chrono::NaiveDateTime { #[test] fn astock_regular_session() { - // 2024-01-08 is Monday + // 2024-01-08 是周一 assert!(is_trading_time(dt(2024, 1, 8, 9, 30), "astock")); assert!(is_trading_time(dt(2024, 1, 8, 10, 0), "astock")); assert!(is_trading_time(dt(2024, 1, 8, 11, 30), "astock")); @@ -31,7 +31,7 @@ fn astock_lunch_break_and_off_hours() { #[test] fn astock_weekend_closed() { - // 2024-01-06 is Saturday + // 2024-01-06 是周六 assert!(!is_trading_time(dt(2024, 1, 6, 10, 0), "astock")); } diff --git a/crates/error-support/src/lib.rs b/crates/error-support/src/lib.rs index d3459ce6e..11f9f183e 100644 --- a/crates/error-support/src/lib.rs +++ b/crates/error-support/src/lib.rs @@ -23,7 +23,7 @@ pub fn expand_error_chain(err: &anyhow::Error) -> String { error_chain } -/// Copy from anyhow::bail! +/// 从 anyhow::bail! 复制而来 #[macro_export] macro_rules! czsc_bail { ($msg:literal $(,)?) => { diff --git a/czsc/_native.pyi b/czsc/_native.pyi index ff800f954..446337413 100644 --- a/czsc/_native.pyi +++ b/czsc/_native.pyi @@ -285,17 +285,15 @@ class CZSC: def __repr__(self) -> builtins.str: ... def __reduce__(self) -> typing.Any: r""" - Pickle support — `__reduce__` returns ``(CZSC, (fixed_point_bars, max_bi_num))``. + Pickle 支持 —— `__reduce__` 返回 ``(CZSC, (fixed_point_bars, max_bi_num))``。 - `update_bar` drains older bars whose dt is below the current - first-BI's start (see `bars_raw.drain` block above), so a - freshly-constructed CZSC's `bars_raw` may still differ from the - fixed point reached after a single re-analysis. We run one extra - `CZSC::new` here to converge before serializing — guarantees that - `pickle.dumps(restored) == pickle.dumps(obj)` byte-for-byte even - when CzscSignals nests CZSC inside `kas[freq]` (Phase A's - `restored.__getstate__() == obj.__getstate__()` assertion relies - on this). + `update_bar` 会丢弃 dt 小于当前 first-BI 起始时间的旧 bar + (参见上面的 `bars_raw.drain` 块),因此刚构造出来的 CZSC 的 + `bars_raw` 可能仍然和「再分析一次后到达的不动点」不同。这里多 + 跑一次 `CZSC::new`,让其在序列化前收敛 —— 保证即使 CzscSignals + 在 `kas[freq]` 里嵌套了 CZSC,`pickle.dumps(restored) == + pickle.dumps(obj)` 也是逐字节相等的(Phase A 的 + `restored.__getstate__() == obj.__getstate__()` 断言依赖这一点)。 """ class CzscSignals: