carstenf commented on May 26, 2024

I think I found a solution, but it is not fully tested:

import numpy as np


class MultipleTimeSeriesCV:
    """Generates tuples of train_idx, test_idx pairs.
    Assumes the MultiIndex contains levels 'symbol' and 'date'.
    Purges overlapping outcomes by leaving a gap of `lookahead` days
    between each training window and the following test window."""

    def __init__(self,
                 n_splits=3,
                 train_period_length=126,
                 test_period_length=21,
                 lookahead=None,
                 date_idx='date',
                 shuffle=False):
        self.n_splits = n_splits
        self.lookahead = lookahead
        self.test_length = test_period_length
        self.train_length = train_period_length
        self.shuffle = shuffle
        self.date_idx = date_idx

    def split(self, X, y=None, groups=None):
        unique_dates = X.index.get_level_values(self.date_idx).unique()
        days = sorted(unique_dates)  # ascending order: oldest date first
        split_idx = []
        for i in range(self.n_splits):
            # Window positions into `days`; each split moves forward by one test period
            train_start_idx = i * self.test_length
            train_end_idx = train_start_idx + self.train_length
            # Skip `lookahead` days so training outcomes do not overlap the test window
            test_start_idx = train_end_idx + (self.lookahead or 0)
            test_end_idx = test_start_idx + self.test_length

            # Stop once days[test_end_idx] would fall outside the available dates
            if test_end_idx >= len(days):
                break

            split_idx.append((train_start_idx, train_end_idx, test_start_idx, test_end_idx))

        # reset_index() yields a RangeIndex, so the selected labels are row positions
        dates = X.reset_index()[[self.date_idx]]

        for train_start, train_end, test_start, test_end in split_idx:
            # End positions are exclusive: each window covers [start, end)
            train_idx = dates[(dates[self.date_idx] >= days[train_start]) &
                              (dates[self.date_idx] < days[train_end])].index.to_numpy()
            test_idx = dates[(dates[self.date_idx] >= days[test_start]) &
                             (dates[self.date_idx] < days[test_end])].index.to_numpy()

            if self.shuffle:
                # Indices are already NumPy arrays, so the permutation keeps the same type
                train_idx = np.random.permutation(train_idx)

            yield train_idx, test_idx

    def get_n_splits(self, X=None, y=None, groups=None):
        return self.n_splits
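
Since the class is not fully tested yet, one way to sanity-check the window arithmetic is to replay it on a small synthetic panel. The sketch below is only an illustration under assumptions: it uses a MultiIndex with 'symbol' and 'date' levels as the docstring requires, `walk_forward_windows` is a hypothetical helper that mirrors the loop in `split`, and the tickers, dates, and parameter values are made up.

```python
import numpy as np
import pandas as pd


def walk_forward_windows(n_days, n_splits, train_length, test_length, lookahead=0):
    """Hypothetical helper mirroring the index arithmetic in split():
    returns (train_start, train_end, test_start, test_end) positions into
    the ascending list of unique dates; end positions are exclusive."""
    windows = []
    for i in range(n_splits):
        train_start = i * test_length
        train_end = train_start + train_length
        test_start = train_end + lookahead
        test_end = test_start + test_length
        # Same stopping rule as split(): days[test_end] must be a valid date
        if test_end >= n_days:
            break
        windows.append((train_start, train_end, test_start, test_end))
    return windows


# Synthetic panel: 2 made-up symbols x 40 business days
dates = pd.bdate_range("2024-01-01", periods=40)
index = pd.MultiIndex.from_product([["AAA", "BBB"], dates], names=["symbol", "date"])
X = pd.DataFrame({"feature": np.arange(len(index), dtype=float)}, index=index)

days = sorted(X.index.get_level_values("date").unique())
row_dates = X.index.get_level_values("date")

splits = []
for train_start, train_end, test_start, test_end in walk_forward_windows(
    len(days), n_splits=3, train_length=10, test_length=5, lookahead=2
):
    train_mask = (row_dates >= days[train_start]) & (row_dates < days[train_end])
    test_mask = (row_dates >= days[test_start]) & (row_dates < days[test_end])
    # np.flatnonzero turns the boolean masks into row positions usable with .iloc
    train_idx, test_idx = np.flatnonzero(train_mask), np.flatnonzero(test_mask)
    splits.append((train_idx, test_idx))
    train_dates = X.iloc[train_idx].index.get_level_values("date")
    test_dates = X.iloc[test_idx].index.get_level_values("date")
    print(train_dates.min().date(), train_dates.max().date(), "|",
          test_dates.min().date(), test_dates.max().date())
```

With `lookahead=2`, each test window should start two trading days after the last training date, and the positions from `np.flatnonzero` address rows the same way as the positional labels that `reset_index()` produces in `split`.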

The new result:

[image]

from machine-learning-for-trading.
