import numpy as np
import pandas as pd
import akshare as ak
import matplotlib.pyplot as plt
from talib import abstract
from datetime import datetime, timedelta
import matplotlib.dates as mdates
from collections import defaultdict
# Configure matplotlib so Chinese text renders: use the SimHei font, and draw
# minus signs as an ASCII hyphen (with this font the Unicode minus would
# otherwise show up as an empty box in saved figures).
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
class TripleBottomStrategy:
    """
    Triple-bottom ("triple W") stock-screening strategy framework.

    Features:
    1. Breakout threshold scaled by recent vs. historical ATR volatility
    2. Secondary confirmation with MACD and RSI
    3. Moving-average filter (MA20 above MA60)
    4. Forward-return statistics after the breakout
    5. Failure-rate analysis of the pattern
    6. Plot of the detected patterns
    """

    # Forward-return horizons in trading days (bars) used throughout the class.
    _HORIZONS = (5, 10, 20, 30)

    def __init__(self, stock_code, start_date, end_date):
        """Load daily bars for ``stock_code`` via akshare and compute indicators.

        Args:
            stock_code: symbol passed unchanged to ``ak.stock_zh_a_daily``.
            start_date: start date passed unchanged to akshare.
            end_date: end date passed unchanged to akshare.
        """
        self.stock_code = stock_code
        self.data = self._load_data(start_date, end_date)
        self._preprocess_data()
        self.patterns = None  # list of pattern dicts, filled by detect_pattern()
        self.performance_results = None  # DataFrame, filled by analyze_pattern_performance()

    def _load_data(self, start_date, end_date):
        """Fetch daily OHLCV bars from akshare as a date-indexed float DataFrame."""
        df = ak.stock_zh_a_daily(symbol=self.stock_code, start_date=start_date, end_date=end_date)
        df.index = pd.to_datetime(df['date'])
        # sort_index()/astype() return a fresh frame, so the column assignments in
        # _preprocess_data don't write into a slice of `df` (SettingWithCopyWarning);
        # TA-Lib also requires float64 inputs.
        return df[['open', 'high', 'low', 'close', 'volume']].sort_index().astype(float)

    def _preprocess_data(self):
        """Add moving averages, ATR, MACD, RSI and forward-return columns to ``self.data``."""
        close = self.data['close']

        # Moving averages used by the trend filter.
        self.data['ma20'] = close.rolling(20).mean()
        self.data['ma60'] = close.rolling(60).mean()

        # 14-bar ATR, used to scale the breakout threshold.
        self.data['atr'] = abstract.ATR(self.data['high'], self.data['low'], close, timeperiod=14)

        # MACD line, signal line and histogram.
        macd, macd_signal, macd_hist = abstract.MACD(
            close, fastperiod=12, slowperiod=26, signalperiod=9)
        self.data['macd'] = macd
        self.data['macd_signal'] = macd_signal
        self.data['macd_hist'] = macd_hist

        # 14-bar RSI.
        self.data['rsi'] = abstract.RSI(close, timeperiod=14)

        # Forward n-bar return close[t+n] / close[t] - 1 (NaN for the last n rows).
        # Written explicitly instead of pct_change(n).shift(-n) to avoid pandas'
        # deprecated implicit forward-fill of missing prices.
        for n in self._HORIZONS:
            self.data[f'future_return_{n}d'] = close.shift(-n) / close - 1

    def _find_extrema(self, window=5):
        """Mark local extrema using a centred rolling window.

        A bar is a peak (+1) when its high equals the window's maximum high, and a
        trough (-1) when its low equals the window's minimum low. A bar that is
        both is labelled a trough. Bars near the edges (incomplete window) are
        never marked.

        Returns:
            Float Series indexed like ``self.data`` containing only marked bars.
        """
        high = self.data['high']
        low = self.data['low']
        # Peaks must come from the highs' rolling maximum, not the lows' minimum.
        peaks = high == high.rolling(window, center=True).max()
        troughs = low == low.rolling(window, center=True).min()

        extrema = pd.Series(np.nan, index=self.data.index, dtype=float)
        extrema.loc[peaks] = 1.0
        extrema.loc[troughs] = -1.0
        return extrema.dropna()

    def _dynamic_breakout_ratio(self, date_index):
        """Return the breakout ratio for bar ``date_index``, scaled by volatility.

        The base ratio of 3% is multiplied by (mean ATR of the previous 20 bars) /
        (mean ATR of all previous bars) and clipped to [2%, 5%]. Only data before
        ``date_index`` is used. Falls back to 3% when there is too little history
        or ATR is unavailable.
        """
        base_ratio = 0.03
        if date_index < 30:
            return base_ratio
        atr = self.data['atr']
        recent_atr = atr.iloc[date_index - 20:date_index].mean()
        overall_atr = atr.iloc[:date_index].mean()
        # A NaN ratio would make every breakout comparison False, so fall back
        # to the base ratio instead of silently rejecting the pattern.
        if pd.isna(recent_atr) or pd.isna(overall_atr) or overall_atr == 0:
            return base_ratio
        dynamic_ratio = base_ratio * (recent_atr / overall_atr)
        return float(np.clip(dynamic_ratio, 0.02, 0.05))

    def _validate_with_indicators(self, date_index):
        """Return True when the MACD, RSI and MA filters all pass at bar ``date_index``.

        * MACD line above its signal line and histogram rising versus the prior bar;
        * RSI below 60;
        * MA20 above MA60.
        Bars before index 60 always fail (MA60 needs 60 bars).
        """
        if date_index < 60:
            return False
        row = self.data.iloc[date_index]
        prev_hist = self.data['macd_hist'].iloc[date_index - 1]
        macd_ok = row['macd'] > row['macd_signal'] and row['macd_hist'] > prev_hist
        rsi_ok = row['rsi'] < 60
        ma_ok = row['ma20'] > row['ma60']
        return bool(macd_ok and rsi_ok and ma_ok)

    # Original (misspelled) name kept so any external callers keep working.
    _valIDAte_with_indicators = _validate_with_indicators

    def _neckline_peaks(self, t1_idx, t2_idx, t3_idx):
        """Return the dates of the highest high in bars [t1, t2) and in bars [t2, t3)."""
        high = self.data['high']
        return high.iloc[t1_idx:t2_idx].idxmax(), high.iloc[t2_idx:t3_idx].idxmax()

    def detect_pattern(self, min_distance=21, price_tolerance=0.05):
        """
        Detect triple-bottom patterns among consecutive local troughs.

        A candidate (t1, t2, t3) is accepted when:
        * adjacent troughs are at least ``min_distance`` trading days apart;
        * the three lows differ pairwise by less than ``price_tolerance``
          (relative to the earlier trough);
        * the highest high between t1 and t2 is above both of those lows;
        * a close above ``neckline * (1 + dynamic ratio)`` occurs on the t3 bar
          or within the following 5 bars, where the neckline is the higher of
          the two intermediate peaks;
        * the MACD/RSI/MA confirmation passes on the breakout bar.

        Args:
            min_distance: minimum distance between adjacent troughs, in trading days.
            price_tolerance: maximum allowed relative difference between bottom prices.

        Returns:
            List of pattern dicts (also stored in ``self.patterns``).
        """
        index = self.data.index
        high = self.data['high']
        low = self.data['low']
        close = self.data['close']

        extrema = self._find_extrema()
        trough_dates = extrema[extrema == -1].index

        patterns = []
        # Walk every run of three consecutive troughs.
        for t1_date, t2_date, t3_date in zip(trough_dates, trough_dates[1:], trough_dates[2:]):
            t1_idx = index.get_loc(t1_date)
            t2_idx = index.get_loc(t2_date)
            t3_idx = index.get_loc(t3_date)

            # Measure spacing in bars (trading days), as documented, rather than
            # calendar days, which would also count weekends and holidays.
            if t2_idx - t1_idx < min_distance or t3_idx - t2_idx < min_distance:
                continue

            t1_price = low.iloc[t1_idx]
            t2_price = low.iloc[t2_idx]
            t3_price = low.iloc[t3_idx]

            # The three bottoms must be at similar price levels.
            if not (abs(t1_price - t2_price) / t1_price < price_tolerance
                    and abs(t2_price - t3_price) / t2_price < price_tolerance
                    and abs(t1_price - t3_price) / t1_price < price_tolerance):
                continue

            peak1_date, peak2_date = self._neckline_peaks(t1_idx, t2_idx, t3_idx)
            peak1_high = high.loc[peak1_date]

            # The first W must rise above both of its bottoms.
            if not peak1_high > max(t1_price, t2_price):
                continue

            # Neckline: the higher of the two intermediate peak *prices*
            # (comparing against idxmax()'s date would raise TypeError).
            neckline_price = max(peak1_high, high.loc[peak2_date])

            completion_idx = t3_idx
            breakout_ratio = self._dynamic_breakout_ratio(completion_idx)
            breakout_price = neckline_price * (1 + breakout_ratio)

            # First close above the breakout price on the t3 bar or the next 5 bars.
            window_close = close.iloc[completion_idx:completion_idx + 6].to_numpy()
            hits = np.flatnonzero(window_close > breakout_price)
            if hits.size == 0:
                continue
            breakout_idx = completion_idx + int(hits[0])
            breakout_date = index[breakout_idx]

            # Secondary confirmation with indicators on the breakout bar.
            if not self._validate_with_indicators(breakout_idx):
                continue

            patterns.append({
                't1_date': t1_date,
                't2_date': t2_date,
                't3_date': t3_date,
                'neckline_price': neckline_price,
                'breakout_price': breakout_price,
                'breakout_date': breakout_date,
                'breakout_ratio_used': breakout_ratio,
                'success': None,  # not populated by this class
                'future_returns': {},
            })

        self.patterns = patterns
        return patterns

    def analyze_pattern_performance(self):
        """Compute post-breakout returns and the per-horizon failure rate.

        For each detected pattern, fills ``pattern['future_returns']`` with
        ``close[breakout + n] / close[breakout] - 1`` for each horizon n (NaN when
        the data ends too early). Stores a summary in ``self.performance_results``
        and prints descriptive statistics. A return of at least 2% counts as a
        success.

        Returns:
            The performance DataFrame, or None if no patterns were detected.
        """
        if not self.patterns:
            print("未检测到任何三重底形态,无法进行性能分析。")
            return

        close = self.data['close']
        success_threshold = 0.02  # a return >= 2% counts as a success
        success_counts = defaultdict(int)
        total_counts = defaultdict(int)
        results = []

        for pattern in self.patterns:
            breakout_idx = self.data.index.get_loc(pattern['breakout_date'])
            base_close = close.iloc[breakout_idx]
            row = {'breakout_date': pattern['breakout_date']}
            for n in self._HORIZONS:
                period = f'{n}d'
                if breakout_idx + n < len(close):
                    ret = close.iloc[breakout_idx + n] / base_close - 1
                else:
                    ret = np.nan
                pattern['future_returns'][period] = ret
                row[f'return_{period}'] = ret
                if pd.notna(ret):
                    total_counts[period] += 1
                    if ret >= success_threshold:
                        success_counts[period] += 1
            results.append(row)

        self.performance_results = pd.DataFrame(results).set_index('breakout_date')

        print("\n--- 形态表现分析 ---")
        print(f"共检测到 {len(self.patterns)} 个有效三重底形态")
        print("\n未来涨幅统计:")
        print(self.performance_results.describe())
        print("\n模式失败概率分析:")
        for n in self._HORIZONS:
            period = f'{n}d'
            if total_counts[period] > 0:
                success_rate = success_counts[period] / total_counts[period]
                failure_rate = 1 - success_rate
                print(f"{period} 失败概率: {failure_rate:.2%} (成功: {success_counts[period]}, 总样本: {total_counts[period]})")
            else:
                print(f"{period} 没有足够的数据进行分析")
        return self.performance_results

    def visualize(self):
        """Plot close price, MA20/MA60, and each detected pattern's bottoms, breakout and neckline."""
        if not self.patterns:
            print("未检测到任何三重底形态,无法进行可视化。")
            return

        fig, ax = plt.subplots(figsize=(16, 10))

        # Price and moving averages.
        ax.plot(self.data.index, self.data['close'], label='收盘价', linewidth=2)
        ax.plot(self.data.index, self.data['ma20'], label='20日均线', color='orange', alpha=0.7)
        ax.plot(self.data.index, self.data['ma60'], label='60日均线', color='purple', alpha=0.7)

        for i, pattern in enumerate(self.patterns):
            # Label only the first pattern's artists so the legend has one entry per kind;
            # labels starting with '_' are excluded from the legend by matplotlib.
            first = i == 0
            bottoms = [pattern['t1_date'], pattern['t2_date'], pattern['t3_date']]

            # The three bottoms.
            ax.scatter(bottoms, self.data.loc[bottoms, 'low'],
                       color='green', s=100, marker='^',
                       label='三重底底部' if first else '_nolegend_')

            # Breakout bar.
            ax.scatter(pattern['breakout_date'], self.data.loc[pattern['breakout_date'], 'close'],
                       color='red', s=120, marker='*',
                       label='突破点' if first else '_nolegend_')

            # Neckline through both intermediate peaks up to the breakout bar.
            t1_idx, t2_idx, t3_idx = (self.data.index.get_loc(d) for d in bottoms)
            peak1_date, peak2_date = self._neckline_peaks(t1_idx, t2_idx, t3_idx)
            neckline_points = [peak1_date, peak2_date, pattern['breakout_date']]
            ax.plot(neckline_points, [pattern['neckline_price']] * 3, 'r--',
                    label='颈线' if first else '_nolegend_')

            ax.annotate(f"突破: {pattern['breakout_price']:.2f}",
                        xy=(pattern['breakout_date'], pattern['breakout_price']),
                        xytext=(pattern['breakout_date'], pattern['breakout_price'] * 1.05),
                        arrowprops=dict(facecolor='black', shrink=0.05))

        ax.set_title(f"{self.stock_code} 三重底形态识别结果")
        ax.set_xlabel("日期")
        ax.set_ylabel("价格")
        ax.grid(True)
        ax.legend()

        # Monthly date ticks on the x axis.
        ax.xaxis.set_major_locator(mdates.MonthLocator())
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
        fig.autofmt_xdate()
        plt.show()