Giter VIP home page Giter VIP logo

Comments (1)

hugo2046 avatar hugo2046 commented on May 18, 2024

哈喽,您好,期待您的回复。请问在时变夏普中有两个库如下: from WaveModel import wave_transform # 自定义小波分析库 from EDC import QueryMacroIndic 我使用chatgpt询问说是wave_transform() 函数可能是作者自己编写的, 请问可以告知代码在你们吗? 谢谢您

后续上传该部分代码

import pandas as pd
import numpy as np
import pywt  # 小波分析

import itertools
import talib
from sklearn import preprocessing
from sklearn import svm


# 信号去噪
class DenoisingThreshold(object):
    '''
    获取小波去噪的阈值
    1. CalSqtwolog 固定阈值准则(sqtwolog)
    2. CalRigrsure 无偏风险估计准则(rigrsure)
    3. CalMinmaxi 极小极大准则( minimaxi)
    4. CalHeursure
    
    参考:https://wenku.baidu.com/view/63d62a818762caaedd33d463.html
    
    对股票价格等数据而言,其信号频率较少地与噪声重叠因此可以选用sqtwolog和heursure准则,使去噪效果更明显。 
    但对收益率这样的高频数据,尽量采用保守的 rigrsure 或 minimaxi 准则来确定阈值,以保留较多的信号。
    '''

    def __init__(self, signal: np.array):

        self.signal = signal

        self.N = len(signal)

    # 固定阈值准则(sqtwolog)
    @property
    def CalSqtwolog(self) -> float:

        return np.sqrt(2 * np.log(self.N))

    # 无偏风险估计准则(rigrsure)
    @property
    def CalRigrsure(self) -> float:

        N = self.N
        signal = np.abs(self.signal)
        signal = np.sort(signal)
        signal = np.power(signal, 2)

        risk_j = np.zeros(N)

        for j in range(N):

            if j == 0:
                risk_j[j] = 1 + signal[N - 1]
            else:
                risk_j[j] = (N - 2 * j + (N - j) *
                             (signal[N - j]) + np.sum(signal[:j])) / N

        k = risk_j.argmin()

        return np.sqrt(signal[k])

    # 极小极大准则( minimaxi)
    @property
    def CalMinmaxi(self) -> float:

        if self.N > 32:
            # N>32 可以使用minmaxi阈值 反之则为0
            return 0.3936 + 0.1829 * (np.log(self.N) / np.log(2))

        else:

            return 0

    @property
    def GetCrit(self) -> float:

        return np.sqrt(np.power(np.log(self.N) / np.log(2), 3) * 1 / self.N)

    @property
    def GetEta(self) -> float:

        return (np.sum(np.abs(self.signal)**2) - self.N) / self.N

    #混合准则(heursure)
    @property
    def CalHeursure(self):

        if self.GetCrit > self.GetEta:

            #print('推荐使用sqtwolog阈值')
            return self.CalSqtwolog

        else:

            #print('推荐使用 Min(sqtwolog阈值,rigrsure阈值)')
            return min(self.CalRigrsure, self.CalSqtwolog)


# 小波处理+svm滚动预测
class wavelet_svm_model(object):
    '''对数据进行建模预测
    --------------------
    输入参数:

        data:必须包含OHLC money及预测字段Y(ovo标记) 其余字段为训练数据
        M:train数据的滚动计算窗口
        window:滚动窗口 即T至T-window日 预测T-1至T-window日数据 预测T日数据
        wavelet\wavelet_mode:同pywt.wavedec的参数
        th_mode:阈值确认准则
        filter_num:需要过滤小波的细节组 比如(3,4)对三至四组进行过滤 为空则是1-4组全过滤 
        whether_wave_process:是否使用小波处理
    --------------------
    方法:
        wave_process:过滤阈值 采用固定阈值准则(sqtwolog)
        preprocess:生成训练用字段
        rolling_svm:使用svm滚动训练
    '''

    def __init__(self,
                 data: pd.DataFrame,
                 M: int,
                 window: int,
                 wavelet: str,
                 wavelet_mode: str,
                 th_mode: str,
                 filter_num=None,
                 whether_wave_process: bool = False):

        self.data = data
        self.__M = M
        self.__window = window
        self.__wavelet = wavelet
        self.__wavelet_mode = wavelet_mode
        self.__th_mode = th_mode
        self.__filter_num = filter_num
        self.__whether_wave_process = whether_wave_process

        self.__train_col = [col for col in self.data.columns if col != 'Y'
                           ]  # 训练的字段

        self.train_df = pd.DataFrame()  # 储存训练数据
        self.predict_df = data[['Y']].copy()  # 储存预测数据及真实Y

    def wave_process(self):
        '''对数据进行小波处理(可选)'''

        if self.__filter_num:

            a = self.__filter_num[0]
            b = self.__filter_num[1]

            #self.__filter_num = range(a,b + 1)

        else:
            a = 1
            b = 5
            #self.__filter_num = range(1,5)

        data = self.data.copy()  # 复制

        for col in self.__train_col:

            #res1 = pywt.wavedec(
            #    data[col].values, wavelet=self.__wavelet, mode=self.__wavelet_mode, level=4)

            #for j in self.__filter_num:

            #    threshold = DenoisingThreshold(res1[j]).CalHeursure
            #    res1[j] = pywt.threshold(res1[j], threshold, 'soft')

            denoised_ser = wave_transform(
                data[col],
                wavelet=self.__wavelet,
                wavelet_mode=self.__wavelet_mode,
                level=4,
                th_mode=self.__th_mode,
                n=a,
                m=b)

            #data[col] = pywt.waverec(res1, self.__wavelet)
            data[col] = denoised_ser

        self.train_df = data

    def preprocess(self):
        '''生成相应的特征'''

        if self.__whether_wave_process:

            self.wave_process()  # 小波处理

            data = self.train_df

        else:

            data = self.data.copy()

        data['近M日最高价'] = data['high'].rolling(self.__M).max()
        data['近M日最低价'] = data['low'].rolling(self.__M).min()
        data['成交额占比'] = data['money'] / data['money'].rolling(self.__M).sum()
        data['近M日涨跌幅'] = data['close'].pct_change(self.__M)
        data['近M日均价'] = data['close'].rolling(self.__M).mean()

        # 上面新增了需要训练用的字段 这里更新字段
        self.__train_col = [
            col for col in data.columns if col not in self.__train_col + ['Y']
        ]
        self.train_df = data[self.__train_col]
        self.train_df = self.train_df.iloc[self.__M:]

    def standardization(self):
        '''对所有特征进行标准化处理'''

        data = preprocessing.scale(self.train_df[self.__train_col])
        data = pd.DataFrame(
            data, index=self.train_df.index, columns=self.__train_col)
        data['Y'] = self.predict_df['Y']
        self.train_df = data

    def rolling_svm(self):
        '''利用SVM模型进行建模预测'''

        predict_ser = rolling_apply(self.train_df, self.model_fit,
                                    self.__window)

        self.predict_df['predict'] = predict_ser

        self.predict_df = self.predict_df.iloc[self.__window + self.__M:]

    def model_fit(self, df: pd.DataFrame) -> pd.Series:

        idx = df.index[-1]

        train_x = df[self.__train_col].iloc[:-1]
        train_y = df['Y'].shift(-1).iloc[:-1]  # 对需要预测的y进行滞后一期处理

        test_x = df[self.__train_col].iloc[-1:]

        model = svm.SVC(gamma=0.001)

        model.fit(train_x, train_y)

        return pd.Series(model.predict(test_x), index=[idx])


# 小波变换
def wave_transform(data_ser: pd.Series, wavelet: str, wavelet_mode: str,
                   level: int, th_mode: str, n: int, m: int) -> pd.Series:
    '''
    参数:
        data_ser:pd.Series
        wavelet\wavelet_mode\level:同pywt.wavedec
        th_mode:选择阈值的准则
        n,m:需要过了的层级范围
    '''
    res1 = pywt.wavedec(
        data_ser.values, wavelet=wavelet, mode=wavelet_mode, level=level)

    denoising_dic = {
        'rigrsure': 'CalRigrsure',
        'sqtwolog': 'CalSqtwolog',
        'heursure': 'CalHeursure',
        'minimaxi': 'CalMinmaxi'
    }

    for j in range(n, m + 1):

        dsth = DenoisingThreshold(res1[j])
        threshold = getattr(dsth, denoising_dic[th_mode])

        res1[j] = pywt.threshold(res1[j], threshold, 'soft')
    
    # 数据重构
    redata = pywt.waverec(res1, wavelet)
    if len(redata) != len(data_ser):
        
        return pd.Series(redata[:len(data_ser)],index=data_ser.index)
    
    else:
    
        return pd.Series(redata,index=data_ser.index)


class AnalysisWaveletModel(object):
    '''通过不同的M及滚动训练窗口 查看模型预测情况'''

    def __init__(self,
                 data: pd.DataFrame,
                 M_list: list,
                 window_list: list,
                 wavelet: str,
                 wavelet_mode: str,
                 th_mode: str,
                 whether_wave_process: bool = False):

        self.data = data
        self.__M_list = M_list
        self.__window_list = window_list
        self.__wavelet = wavelet
        self.__wavelet_mode = wavelet_mode
        self.__th_mode = th_mode
        self.__whether_wave_process = whether_wave_process

        self.Flag_df = pd.DataFrame()  # 持仓标记
        self.res_svm_pred = pd.DataFrame()  # 训练结果展示表

    def iterations_params(self):

        params = list(itertools.product(self.__M_list, self.__window_list))

        res_svm_pred = pd.DataFrame(columns=[
            'M', '训练窗宽', '总预测次数', '成功次数', '成功概率', '上涨预测成功率', '下跌预测成功概率'
        ])

        flag_list = []

        for m, w in tqdm(params, desc='模型训练中'):

            # 初始化模型
            wsm = wavelet_svm_model(self.data, m, w, self.__wavelet,
                                    self.__wavelet_mode, self.__th_mode,
                                    self.__whether_wave_process)
            # 计算训练字段
            wsm.preprocess()
            # 标准化
            wsm.standardization()
            # 滚动训练
            wsm.rolling_svm()

            predict_ = wsm.predict_df
            predict_num = len(predict_)

            predict_['predict'] = predict_['predict'].shift(1)

            # 全部
            right_num = len(predict_[predict_['predict'] == predict_['Y']])
            right_pre = right_num / predict_num

            # 上涨预测成功概率
            up_df = predict_.query('Y==1')
            up_num = len(up_df[up_df['predict'] == up_df['Y']]) / len(
                up_df)  # 上涨预测成功率

            # 下跌预测成功概率
            down_df = predict_.query('Y!=1')
            down_num = len(down_df[down_df['predict'] == down_df['Y']]) / len(
                down_df)  # 上涨预测成功率

            # 储存到容器中
            res_svm_pred.loc[len(res_svm_pred), :] = [
                m, w, predict_num, right_num, right_pre, up_num, down_num
            ]

            predict_['predict'].name = f'{m}_{w}'
            flag_list.append(wsm.predict_df['predict'])  # 储存预测值 0,1标记代表持仓/空仓

        self.Flag_df = pd.concat(flag_list, axis=1)

        self.res_svm_pred = res_svm_pred

    # 计算T值
    def T_Value(self, n: int = 0):

        limit_n = len(self.res_svm_pred)

        if n > limit_n or n == 0:

            n = limit_n

        probability_of_s = self.res_svm_pred['成功概率'].iloc[:n]

        # 《平安证券 水致清则鱼自现——小波分析与支持向量机择时研究》给出的T值计算感觉不对
        # t值应该是标准差吧 但他给出的是要用方差
        #return (probability_of_s.mean() - 0.5) / (
        #    probability_of_s.var() / np.sqrt(n))

        t_statistic, p_value = stats.ttest_1samp(probability_of_s.values, 0.5)

        return f't-statistic:{t_statistic},p_value:{p_value}'


# 定义rolling_apply理论上应该比for循环快
# pandas.rolling.apply不支持多列
def rolling_apply(df, func, win_size) -> pd.Series:

    iidx = np.arange(len(df))

    shape = (iidx.size - win_size + 1, win_size)

    strides = (iidx.strides[0], iidx.strides[0])

    res = np.lib.stride_tricks.as_strided(
        iidx, shape=shape, strides=strides, writeable=True)

    # 这里注意func返回的需要为df或者ser
    return pd.concat((func(df.iloc[r]) for r in res), axis=0)  # concat可能会有点慢

from quantsplaybook.

Related Issues (4)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.