diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 00000000000..ddc91d066b3
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,131 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Overview
+
+Qlib (pyqlib) is an open-source, AI-oriented quantitative investment platform by Microsoft. It provides the full ML pipeline for quant research: data processing, factor engineering, model training, backtesting, and online serving. It supports Python 3.8–3.12 on Linux, macOS, and Windows.
+
+## Build & Install
+
+```bash
+# Development install (editable + all extras)
+make dev
+
+# Or step-by-step:
+make prerequisite  # compile Cython extensions (rolling, expanding)
+make install       # pip install -e . (minimal)
+make develop       # pip install -e .[dev] (adds pytest, statsmodels)
+make lint          # install lint tools
+make test          # install test deps
+
+# Build wheel package
+make build
+```
+
+The Cython extensions (`qlib/data/_libs/rolling.pyx`, `expanding.pyx`) compile C++ code for rolling/expanding window operations. They require `cython` and `numpy` headers.
+
+## Test
+
+```bash
+# Run the full test suite (excluding slow tests)
+cd tests && python -m pytest . -m "not slow" --durations=0
+
+# Run a single test file
+cd tests && python -m pytest test_workflow.py -m "not slow"
+
+# Run slow tests too
+cd tests && python -m pytest .
+
+# On macOS, thread limits are needed to prevent OpenMP segfaults:
+export OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 NUMEXPR_NUM_THREADS=1 OPENBLAS_NUM_THREADS=1 VECLIB_MAXIMUM_THREADS=1
+```
+
+Expensive tests are marked with `@pytest.mark.slow`. RL tests are collected only on Linux (`tests/conftest.py`). Tests require data downloaded to `~/.qlib/qlib_data/cn_data`.
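The marker convention above can be sketched as follows (the test name here is hypothetical; `-m "not slow"` deselects anything carrying the `slow` mark):

```python
import pytest

# Hypothetical test name; the `slow` marker is what `-m "not slow"` filters on.
@pytest.mark.slow
def test_full_workflow_end_to_end():
    assert True  # placeholder body
```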
+
+## Lint
+
+```bash
+make black   # black -l 120 --check
+make pylint  # pylint on qlib/ and scripts/ (many checks disabled, see Makefile)
+make flake8  # flake8 on qlib/ (E501,F541,E266,E402,W503,E731,E203 ignored)
+make mypy    # mypy on qlib/ (data, model, contrib, utils etc. excluded)
+make nbqa    # Run black + pylint on notebooks
+make lint    # all of the above
+```
+
+Line length is 120 chars. Pre-commit hooks run black + flake8 on push.
+
+## Architecture
+
+### Initialization
+
+Everything starts with `qlib.init()` (configs in `qlib/config.py`). The global `C` object (`QlibConfig`) holds all configuration — provider URIs, cache settings, region parameters, logging, etc. Two modes exist: `client` (local data) and `server` (Redis-backed cache). Region-specific config (`REG_CN`, `REG_US`, `REG_TW`) controls trade units, limit thresholds, and deal prices.
+
+Always call `qlib.init(provider_uri="...")` before any operation. The `qlib.auto_init()` helper finds `config.yaml` by walking up the directory tree.
+
+### Data Layer (`qlib/data/`)
+
+Providers abstract data access through a plugin architecture. The default is `LocalProvider` backed by `FileStorage` (binary `.bin` files keyed by instrument+field). Key abstractions:
+
+- **`CalendarProvider`** — trading calendar
+- **`InstrumentProvider`** — stock universe
+- **`FeatureProvider`** — raw features
+- **`ExpressionProvider`** — computed expressions (like a query engine for financial data)
+- **`DatasetProvider`** — materialized datasets with caching
+- **`PITProvider`** — point-in-time data (prevents lookahead bias)
+
+The expression system (`qlib/data/ops.py`) defines operators like `Ref()`, `Mean()`, `Std()`, `Rsquare()`, etc. that compile into computations over financial time series. Custom operators can be registered at init time.
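The composition idea behind these operators can be sketched in plain Python. This is a toy illustration only, not Qlib's implementation: the real operators in `qlib/data/ops.py` work on pandas Series and support far more than shown here. The class names mirror Qlib's `Ref`/`Mean` but everything below is simplified (e.g. `n >= 1` is assumed):

```python
class Expr:
    """Toy expression node; evaluate() turns a raw series into a derived one."""
    def evaluate(self, series):
        raise NotImplementedError

class Feature(Expr):
    """Leaf node standing in for a raw field like $close."""
    def evaluate(self, series):
        return list(series)

class Ref(Expr):
    """Value n periods ago: Ref($close, 1) is yesterday's close (n >= 1)."""
    def __init__(self, inner, n):
        self.inner, self.n = inner, n
    def evaluate(self, series):
        vals = self.inner.evaluate(series)
        return [None] * self.n + vals[:-self.n]

class Mean(Expr):
    """Rolling mean over the trailing n periods; None until the window fills."""
    def __init__(self, inner, n):
        self.inner, self.n = inner, n
    def evaluate(self, series):
        vals = self.inner.evaluate(series)
        out = []
        for i in range(len(vals)):
            window = vals[max(0, i - self.n + 1): i + 1]
            out.append(None if (len(window) < self.n or None in window)
                       else sum(window) / self.n)
        return out

close = [10.0, 11.0, 12.0, 13.0]
expr = Mean(Ref(Feature(), 1), 2)   # analogous to Mean(Ref($close, 1), 2)
print(expr.evaluate(close))         # -> [None, None, 10.5, 11.5]
```

Nesting works because each node only consumes the output of its child, which is also how Qlib's string expressions like `"Mean(Ref($close, 1), 2)"` compose.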
+
+`qlib/data/dataset/` contains the data loading pipeline:
+- `DataHandler` / `DataHandlerLP` — processes raw data into features/labels (train/valid/test segments)
+- `DatasetH` — the standard PyTorch-style Dataset wrapper
+
+### Model Layer (`qlib/model/`)
+
+`Model` inherits from `BaseModel` and must implement `fit(dataset)` and `predict(dataset)`. Models are config-driven: each model class is instantiated from YAML config via `init_instance_by_config()`. The `task_train()` function in `qlib/model/trainer.py` orchestrates training from config.
+
+Model subdirectories:
+- `qlib/model/ens/` — ensemble models
+- `qlib/model/meta/` — meta-learning (DDG-DA, etc.)
+- `qlib/model/interpret/` — interpretability tools
+- `qlib/model/riskmodel/` — risk modeling
+
+### Workflow & Experiment Management (`qlib/workflow/`)
+
+Experiments track runs with MLflow as the backend. The global `R` (`QlibRecorder`) manages the experiment lifecycle:
+
+```python
+with R.start(experiment_name="test", recorder_name="run1"):
+    model.fit(dataset)
+    R.log_metrics(mse=0.1, step=0)
+    R.save_objects(model=model)
+```
+
+The `qrun` CLI (`qlib/cli/run.py`) runs a full workflow from a YAML config file. It supports Jinja2 templating and base config inheritance via `BASE_CONFIG_PATH`.
+
+### Backtesting (`qlib/backtest/`)
+
+Components run in a nested decision framework:
+- `exchange.py` — simulates trading with costs, limits, delays
+- `executor.py` — executes decisions
+- `decision.py` — trade decision representation
+- `account.py` / `position.py` — portfolio tracking
+- `signal.py` — signal generation from model predictions
+- `report.py` / `profit_attribution.py` — performance analysis
+
+### Strategy (`qlib/strategy/`)
+
+Trading strategies transform model predictions into trade decisions. The base class is in `qlib/strategy/base.py`.
+
+### Contrib (`qlib/contrib/`)
+
+Community-contributed and research models, strategies, and workflows. Its organization mirrors the core structure: `contrib/model/`, `contrib/strategy/`, `contrib/workflow/`, `contrib/data/`, `contrib/online/`, `contrib/rolling/`, `contrib/tuner/`. Production-grade models like LightGBM, GRU, LSTM, Transformer, HIST, etc. live here.
+
+### Online Serving (`qlib/workflow/online/`)
+
+Supports deploying trained models for live trading with automatic model rolling and updating.
+
+## Config-Driven Design
+
+Nearly everything is instantiated from YAML/dict config using `qlib.utils.init_instance_by_config()`. A config specifies `class`, `module_path`, and `kwargs`. This pattern is used for models, data handlers, strategies, and execution components.
diff --git a/examples/my_workflow_with_visual.py b/examples/my_workflow_with_visual.py
new file mode 100644
index 00000000000..fa3b3b1a60f
--- /dev/null
+++ b/examples/my_workflow_with_visual.py
@@ -0,0 +1,269 @@
+"""
+Full training + backtest + visualization workflow.
+
+Based on the configuration in workflow_config_lightgbm_Alpha360.yaml,
+with visualization output added for the backtest results.
+
+Usage:
+    python examples/my_workflow_with_visual.py
+
+If not run in a notebook, charts are saved as HTML files in the current directory.
+"""
+
+import pandas as pd
+import qlib
+from qlib.constant import REG_CN
+from qlib.utils import init_instance_by_config
+from qlib.utils.time import Freq
+from qlib.workflow import R
+from qlib.workflow.record_temp import SignalRecord, PortAnaRecord, SigAnaRecord
+from qlib.data import D
+from qlib.backtest import backtest
+from qlib.backtest.executor import SimulatorExecutor
+from qlib.contrib.evaluate import risk_analysis
+from qlib.contrib.strategy import TopkDropoutStrategy
+
+# ── Visualization modules ──────────────────────────────
+from qlib.contrib.report.analysis_position import (
+    report_graph,
+    score_ic_graph,
+    risk_analysis_graph,
+    cumulative_return_graph,
+)
+
+# ── Detect whether we are running in a notebook ────────
+try:
+    get_ipython()
+    IN_NOTEBOOK = True
+except NameError:
+    IN_NOTEBOOK = False
+
+
+def show_or_save(figs, filename):
+    """Show inline in a notebook, otherwise save as HTML."""
+    if IN_NOTEBOOK:
+        from qlib.contrib.report.graph import BaseGraph
+
+        BaseGraph.show_graph_in_notebook(figs)
+        print(f"[notebook] {filename} displayed")
+    else:
+        html = "\n".join(fig.to_html(include_plotlyjs="cdn") for fig in figs)
+        path = f"{filename}.html"
+        with open(path, "w", encoding="utf-8") as f:
+            f.write(html)
+        print(f"[file] charts saved -> {path}")
+
+
+def main():
+    # ── 1. Initialize Qlib ──────────────────────────────
+    provider_uri = "~/.qlib/qlib_data/cn_data"
+    qlib.init(provider_uri=provider_uri, region=REG_CN)
+
+    market = "csi300"
+    benchmark = "SH000300"
+
+    # ── 2. Dataset config (same as the Alpha360 yaml) ───
+    dataset_config = {
+        "class": "DatasetH",
+        "module_path": "qlib.data.dataset",
+        "kwargs": {
+            "handler": {
+                "class": "Alpha360",
+                "module_path": "qlib.contrib.data.handler",
+                "kwargs": {
+                    "start_time": "2008-01-01",
+                    "end_time": "2020-08-01",
+                    "fit_start_time": "2008-01-01",
+                    "fit_end_time": "2014-12-31",
+                    "instruments": market,
+                    "infer_processors": [],
+                    "learn_processors": [
+                        {"class": "DropnaLabel"},
+                        {"class": "CSRankNorm", "kwargs": {"fields_group": "label"}},
+                    ],
+                    "label": ["Ref($close, -2) / Ref($close, -1) - 1"],
+                },
+            },
+            "segments": {
+                "train": ("2008-01-01", "2014-12-31"),
+                "valid": ("2015-01-01", "2016-12-31"),
+                "test": ("2017-01-01", "2020-08-01"),
+            },
+        },
+    }
+
+    # ── 3. Model config (same as the Alpha360 yaml) ─────
+    model_config = {
+        "class": "LGBModel",
+        "module_path": "qlib.contrib.model.gbdt",
+        "kwargs": {
+            "loss": "mse",
+            "colsample_bytree": 0.8879,
+            "learning_rate": 0.0421,
+            "subsample": 0.8789,
+            "lambda_l1": 205.6999,
+            "lambda_l2": 580.9768,
+            "max_depth": 8,
+            "num_leaves": 210,
+            "num_threads": 20,
+        },
+    }
+
+    # ── 4. Instantiate dataset and model ────────────────
+    dataset = init_instance_by_config(dataset_config)
+    model = init_instance_by_config(model_config)
+
+    # Peek at the training data
+    df_train = dataset.prepare("train", col_set=["feature", "label"])
+    print(f"Train set: features={df_train['feature'].shape}, labels={df_train['label'].shape}")
+
+    # ── 5. Train + predict ──────────────────────────────
+    with R.start(experiment_name="my_workflow_visual", recorder_name="run1"):
+        model.fit(dataset)
+        recorder = R.get_recorder()
+
+        sr = SignalRecord(model, dataset, recorder)
+        sr.generate()
+
+        sar = SigAnaRecord(recorder, ana_long_short=False, ann_scaler=252)
+        sar.generate()
+
+        # Get the prediction signal
+        pred_df = recorder.load_object("pred.pkl")
+        print(f"Predictions: {pred_df.shape}, index={pred_df.index.names}")
+
+        # ── 6. Backtest ─────────────────────────────────
+        STRATEGY_CONFIG = {
+            "topk": 50,
+            "n_drop": 5,
+            "signal": pred_df,
+        }
+        EXECUTOR_CONFIG = {
+            "time_per_step": "day",
+            "generate_portfolio_metrics": True,
+        }
+        backtest_config = {
+            "start_time": "2017-01-01",
+            "end_time": "2020-08-01",
+            "account": 100000000,
+            "benchmark": benchmark,
+            "exchange_kwargs": {
+                "freq": "day",
+                "limit_threshold": 0.095,
+                "deal_price": "close",
+                "open_cost": 0.0005,
+                "close_cost": 0.0015,
+                "min_cost": 5,
+            },
+        }
+
+        strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)
+        executor_obj = SimulatorExecutor(**EXECUTOR_CONFIG)
+        portfolio_metric_dict, indicator_dict = backtest(
+            executor=executor_obj, strategy=strategy_obj, **backtest_config
+        )
+
+        analysis_freq = f"{Freq.parse('day')[0]}{Freq.parse('day')[1]}"
+        report_normal_df, positions = portfolio_metric_dict.get(analysis_freq)
+        print(f"\nBacktest report: {report_normal_df.shape}, columns={list(report_normal_df.columns)}")
+
+        # ── 7. Print risk metrics ───────────────────────
+        print("\n" + "=" * 60)
+        print("Benchmark return (1day)")
+        print(risk_analysis(report_normal_df["bench"]))
+        print("\nExcess return (without cost)")
+        print(risk_analysis(report_normal_df["return"] - report_normal_df["bench"]))
+        print("\nExcess return (with cost)")
+        print(
+            risk_analysis(
+                report_normal_df["return"] - report_normal_df["bench"] - report_normal_df["cost"]
+            )
+        )
+        print("=" * 60)
+
+        # ── 8. Visualization ────────────────────────────
+
+        # 8a. Backtest overview
+        figs = report_graph(report_normal_df, show_notebook=False)
+        show_or_save(figs, "report_overview")
+
+        # 8b. IC analysis
+        pred_df_dates = pred_df.index.get_level_values("datetime")
+        label_data = D.features(
+            D.instruments(market),
+            ["Ref($close, -2) / Ref($close, -1) - 1"],
+            pred_df_dates.min(),
+            pred_df_dates.max(),
+        )
+        label_data.columns = ["label"]
+        pred_label = pd.concat([label_data, pred_df], axis=1, sort=True).reindex(label_data.index)
+        print(f"\nIC data: {pred_label.shape}")
+
+        try:
+            figs = score_ic_graph(pred_label, show_notebook=False)
+        except ValueError:
+            print("[warn] score_ic_graph raised internally (groupby may have returned a scalar); computing IC manually")
+            data = pred_label.dropna()
+            if len(data) > 0:
+                ic = data.groupby(level="datetime").apply(
+                    lambda x: x["label"].corr(x["score"])
+                )
+                rank_ic = data.groupby(level="datetime").apply(
+                    lambda x: x["label"].corr(x["score"], method="spearman")
+                )
+                if not hasattr(ic, "index"):
+                    ic = pd.Series([ic], index=[data.index.get_level_values("datetime")[0]])
+                    rank_ic = pd.Series([rank_ic], index=ic.index)
+                ic_df = pd.DataFrame({"ic": ic, "rank_ic": rank_ic})
+                from qlib.contrib.report.graph import ScatterGraph
+
+                fig = ScatterGraph(
+                    ic_df,
+                    layout=dict(title="Score IC", xaxis=dict(tickangle=45)),
+                    graph_kwargs={"mode": "lines+markers"},
+                ).figure
+                figs = (fig,)
+            else:
+                print("[warn] no valid data; skipping the IC chart")
+                figs = ()
+        show_or_save(figs, "score_ic")
+
+        # 8c. Risk analysis charts
+        analysis = {}
+        analysis["excess_return_without_cost"] = risk_analysis(
+            report_normal_df["return"] - report_normal_df["bench"]
+        )
+        analysis["excess_return_with_cost"] = risk_analysis(
+            report_normal_df["return"] - report_normal_df["bench"] - report_normal_df["cost"]
+        )
+        analysis_df = pd.concat(analysis)
+        print(f"\nRisk analysis: {analysis_df}")
+
+        figs = risk_analysis_graph(analysis_df, report_normal_df, show_notebook=False)
+        show_or_save(figs, "risk_analysis")
+
+        # 8d. Buy/sell position analysis
+        figs = cumulative_return_graph(
+            positions, report_normal_df, label_data, show_notebook=False,
+            start_date="2017-01-01", end_date="2020-08-01",
+        )
+        show_or_save(figs, "cumulative_return")
+
+        # ── 9. Print a summary of key metrics ───────────
+        print("\n" + "=" * 60)
+        print("Key metrics summary")
+        print("=" * 60)
+        total_cost = report_normal_df["cost"].sum()
+        total_turnover = report_normal_df["turnover"].sum()
+        avg_turnover = report_normal_df["turnover"].mean()
+        # One-sided turnover estimate = turnover / 2 (a buy and a sell each count once)
+        print(f"Cumulative cost: {total_cost:.4f}")
+        print(f"Total turnover: {total_turnover:.2f}")
+        print(f"Average daily turnover: {avg_turnover:.4f}")
+        print(f"Cost/turnover: {total_cost / total_turnover * 10000:.1f} bps (average cost per trade)")
+
+    print("\nDone! All charts have been written.")
+
+
+if __name__ == "__main__":
+    main()
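
As a sanity check on the summary arithmetic in step 9 (illustrative numbers only, not from a real run): dividing total cost by total turnover and scaling by 10^4 expresses the average cost per unit traded in basis points.

```python
# Illustrative numbers (hypothetical, not from an actual backtest):
# both quantities are fractions of account value summed over trading days.
total_cost = 0.9        # e.g. ~0.09% cost per day over ~1000 days
total_turnover = 600.0  # e.g. ~0.6x of the account traded per day

avg_cost_bps = total_cost / total_turnover * 10000  # bps per unit of turnover
print(f"{avg_cost_bps:.1f} bps")  # -> 15.0 bps
```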