#!/usr/bin/env python3
"""
Unit Tests for Feature Availability Module
Tests feature categorization, availability masking, and validation.
"""
import os

import numpy as np
import polars as pl
import pytest
from datasets import load_dataset

from src.forecasting.feature_availability import FeatureAvailability
@pytest.fixture(scope="module")
def sample_columns():
    """Load actual dataset columns for testing.

    Loads the ``train`` split of ``evgueni-p/fbmc-features-24month`` from the
    Hugging Face Hub and returns the names of its features as a list.

    Registered as a module-scoped fixture so the dataset is fetched once for
    the whole test module instead of once per test.

    The access token is taken from the ``HF_TOKEN`` environment variable
    instead of being written into the source. When the variable is unset or
    empty, ``None`` is passed and ``datasets`` is left to use its own default
    credential lookup (NOTE(review): confirm this matches the CI setup).
    """
    # Treat an empty variable the same as a missing one.
    hf_token = os.environ.get("HF_TOKEN") or None
    dataset = load_dataset(
        "evgueni-p/fbmc-features-24month",
        split="train",
        token=hf_token,
    )
    return list(dataset.features.keys())


@pytest.fixture(scope="module")
def categories(sample_columns):
    """Categorize features once for all tests.

    Module-scoped so ``FeatureAvailability.categorize_features`` runs a single
    time on the dataset columns and every test shares the same result.
    """
    return FeatureAvailability.categorize_features(sample_columns)
| class TestFeatureCategorization: | |
| """Test feature categorization logic.""" | |
| def test_total_feature_count(self, categories): | |
| """Test total feature count matches expected.""" | |
| total = sum(len(v) for v in categories.values()) | |
| # 2,553 columns - 1 timestamp - 38 targets = 2,514 features | |
| assert total == 2514, f"Expected 2,514 features, got {total}" | |
| def test_no_uncategorized_features(self, categories): | |
| """Test all features are categorized.""" | |
| uncategorized = categories['uncategorized'] | |
| assert len(uncategorized) == 0, ( | |
| f"Found {len(uncategorized)} uncategorized features: " | |
| f"{uncategorized[:10]}" | |
| ) | |
| def test_full_horizon_count(self, categories): | |
| """Test full-horizon D+14 feature count.""" | |
| full_d14 = len(categories['full_horizon_d14']) | |
| # Expected: temporal (12) + weather (375) + outages (176) + LTA (40) = 603 | |
| assert 580 <= full_d14 <= 620, ( | |
| f"Expected ~603 full-horizon features, got {full_d14}" | |
| ) | |
| def test_partial_d1_count(self, categories): | |
| """Test partial D+1 feature count.""" | |
| partial = len(categories['partial_d1']) | |
| # Expected: load forecasts (12) | |
| assert partial == 12, f"Expected 12 partial D+1 features, got {partial}" | |
| def test_historical_count(self, categories): | |
| """Test historical feature count.""" | |
| historical = len(categories['historical']) | |
| # Expected: ~1,899 (prices, generation, demand, lags, etc.) | |
| assert 1800 <= historical <= 2000, ( | |
| f"Expected ~1,899 historical features, got {historical}" | |
| ) | |
| def test_temporal_features_in_full_horizon(self, categories): | |
| """Test temporal features are in full_horizon_d14.""" | |
| full_d14 = categories['full_horizon_d14'] | |
| temporal_patterns = [ | |
| 'hour_sin', 'hour_cos', | |
| 'day_sin', 'day_cos', | |
| 'month_sin', 'month_cos', | |
| 'weekday_sin', 'weekday_cos', | |
| 'is_weekend' | |
| ] | |
| for pattern in temporal_patterns: | |
| matching = [f for f in full_d14 if pattern in f] | |
| assert len(matching) > 0, f"No temporal features matching '{pattern}'" | |
| def test_weather_features_in_full_horizon(self, categories): | |
| """Test weather features are in full_horizon_d14.""" | |
| full_d14 = categories['full_horizon_d14'] | |
| weather_prefixes = ['temp_', 'wind_', 'solar_', 'cloud_', 'pressure_'] | |
| for prefix in weather_prefixes: | |
| matching = [f for f in full_d14 if f.startswith(prefix)] | |
| assert len(matching) > 0, f"No weather features starting with '{prefix}'" | |
| def test_outage_features_in_full_horizon(self, categories): | |
| """Test CNEC outage features are in full_horizon_d14.""" | |
| full_d14 = categories['full_horizon_d14'] | |
| outage_features = [f for f in full_d14 if f.startswith('outage_cnec_')] | |
| assert len(outage_features) == 176, ( | |
| f"Expected 176 CNEC outage features, got {len(outage_features)}" | |
| ) | |
| def test_lta_features_in_full_horizon(self, categories): | |
| """Test LTA features are in full_horizon_d14.""" | |
| full_d14 = categories['full_horizon_d14'] | |
| lta_features = [f for f in full_d14 if f.startswith('lta_')] | |
| assert len(lta_features) == 40, ( | |
| f"Expected 40 LTA features, got {len(lta_features)}" | |
| ) | |
| def test_load_forecast_in_partial(self, categories): | |
| """Test load forecast features are in partial_d1.""" | |
| partial = categories['partial_d1'] | |
| load_forecasts = [f for f in partial if f.startswith('load_forecast_')] | |
| assert len(load_forecasts) == 12, ( | |
| f"Expected 12 load forecast features, got {len(load_forecasts)}" | |
| ) | |
| def test_price_features_in_historical(self, categories): | |
| """Test price features are in historical.""" | |
| historical = categories['historical'] | |
| price_features = [f for f in historical if f.startswith('price_')] | |
| assert len(price_features) > 0, "No price features found in historical" | |
| def test_generation_features_in_historical(self, categories): | |
| """Test generation features are in historical.""" | |
| historical = categories['historical'] | |
| gen_features = [f for f in historical if f.startswith('gen_')] | |
| assert len(gen_features) > 0, "No generation features found in historical" | |
| def test_demand_features_in_historical(self, categories): | |
| """Test demand features are in historical.""" | |
| historical = categories['historical'] | |
| demand_features = [f for f in historical if f.startswith('demand_')] | |
| assert len(demand_features) > 0, "No demand features found in historical" | |
| def test_no_duplicates_across_categories(self, categories): | |
| """Test features are not duplicated across categories.""" | |
| full_set = set(categories['full_horizon_d14']) | |
| partial_set = set(categories['partial_d1']) | |
| historical_set = set(categories['historical']) | |
| # Check for overlaps | |
| full_partial = full_set & partial_set | |
| full_historical = full_set & historical_set | |
| partial_historical = partial_set & historical_set | |
| assert len(full_partial) == 0, f"Overlap between full and partial: {full_partial}" | |
| assert len(full_historical) == 0, f"Overlap between full and historical: {full_historical}" | |
| assert len(partial_historical) == 0, f"Overlap between partial and historical: {partial_historical}" | |
class TestAvailabilityMasking:
    """Test availability mask creation.

    Each test calls ``FeatureAvailability.create_availability_mask`` with a
    feature name and a horizon length in hours, then checks the returned
    array: its shape must be ``(horizon,)`` and its entries must be 1.0 where
    the feature is expected to be available and 0.0 where it is not.
    """

    # Default horizon used by most tests (336 hours).
    HORIZON = 336

    def test_full_horizon_mask(self):
        """Test mask for full-horizon features."""
        mask = FeatureAvailability.create_availability_mask('temp_DE_LU', self.HORIZON)
        assert mask.shape == (336,), f"Expected shape (336,), got {mask.shape}"
        assert np.all(mask == 1.0), "Full-horizon mask should be all ones"

    def test_partial_d1_mask(self):
        """Test mask for partial D+1 features."""
        mask = FeatureAvailability.create_availability_mask('load_forecast_DE', self.HORIZON)
        assert mask.shape == (336,), f"Expected shape (336,), got {mask.shape}"
        # Only the first 24 entries (D+1) are expected to be available.
        assert np.sum(mask) == 24, f"Expected 24 ones (D+1), got {np.sum(mask)}"
        assert np.all(mask[:24] == 1.0), "First 24 hours should be available"
        assert np.all(mask[24:] == 0.0), "Hours 25-336 should be masked"

    def test_temporal_mask(self):
        """Test mask for temporal features (always available)."""
        mask = FeatureAvailability.create_availability_mask('hour_sin', self.HORIZON)
        assert mask.shape == (336,), f"Expected shape (336,), got {mask.shape}"
        assert np.all(mask == 1.0), "Temporal mask should be all ones"

    def test_lta_mask(self):
        """Test mask for LTA features (forward-filled)."""
        mask = FeatureAvailability.create_availability_mask('lta_AT_CZ', self.HORIZON)
        assert mask.shape == (336,), f"Expected shape (336,), got {mask.shape}"
        assert np.all(mask == 1.0), "LTA mask should be all ones (forward-filled)"

    def test_historical_mask(self):
        """Test mask for historical features."""
        mask = FeatureAvailability.create_availability_mask('price_DE', self.HORIZON)
        assert mask.shape == (336,), f"Expected shape (336,), got {mask.shape}"
        assert np.all(mask == 0.0), "Historical mask should be all zeros"

    def test_mask_different_horizons(self):
        """Test mask with different forecast horizons."""
        # 168 hours = 7 days, 720 hours = 30 days. The number of available
        # entries for a D+1 feature must stay at 24 regardless of horizon.
        for horizon in (168, 720):
            mask = FeatureAvailability.create_availability_mask('load_forecast_DE', horizon)
            assert mask.shape == (horizon,)
            assert np.sum(mask) == 24
class TestValidation:
    """Test validation functions."""

    def test_validation_passes(self, categories):
        """Test validation passes for correct categorization.

        ``validate_categorization`` is unpacked as a ``(is_valid, warnings)``
        pair; the test requires a truthy flag and an empty warnings
        collection.
        """
        # Named ``found_warnings`` to avoid shadowing the stdlib ``warnings``
        # module name.
        is_valid, found_warnings = FeatureAvailability.validate_categorization(
            categories, verbose=False
        )
        assert is_valid, f"Validation failed with warnings: {found_warnings}"
        assert len(found_warnings) == 0, f"Unexpected warnings: {found_warnings}"

    def test_category_summary_generation(self, categories):
        """Test category summary table generation.

        The summary must expose ``Category``, ``Count`` and ``Availability``
        in its ``columns`` and contain at least three rows.
        NOTE(review): presumably a DataFrame; confirm against the module.
        """
        summary = FeatureAvailability.get_category_summary(categories)
        for column in ('Category', 'Count', 'Availability'):
            assert column in summary.columns
        assert len(summary) >= 3  # At least 3 categories (full, partial, historical)
class TestPatternMatching:
    """Test internal pattern matching logic.

    These tests feed small hand-written column lists to
    ``FeatureAvailability.categorize_features`` and check how many names land
    in each category, so they do not need the real dataset.
    """

    def test_temporal_pattern_matching(self):
        """Test temporal feature pattern matching."""
        test_cols = ['hour_sin', 'day_cos', 'month', 'weekday', 'is_weekend']
        categories = FeatureAvailability.categorize_features(test_cols)
        # All five temporal names belong to the full-horizon category.
        assert len(categories['full_horizon_d14']) == 5
        assert len(categories['partial_d1']) == 0
        assert len(categories['historical']) == 0

    def test_weather_prefix_matching(self):
        """Test weather feature prefix matching."""
        test_cols = ['temp_DE', 'wind_FR', 'solar_AT', 'cloud_NL', 'pressure_BE']
        categories = FeatureAvailability.categorize_features(test_cols)
        assert len(categories['full_horizon_d14']) == 5

    def test_load_forecast_matching(self):
        """Test load forecast prefix matching."""
        test_cols = ['load_forecast_DE', 'load_forecast_FR', 'load_forecast_AT']
        categories = FeatureAvailability.categorize_features(test_cols)
        assert len(categories['partial_d1']) == 3

    def test_price_matching(self):
        """Test price feature matching."""
        test_cols = ['price_DE', 'price_FR', 'price_AT']
        categories = FeatureAvailability.categorize_features(test_cols)
        assert len(categories['historical']) == 3

    def test_mixed_features(self):
        """Test categorization with mixed feature types."""
        test_cols = [
            'hour_sin',          # temporal -> full
            'temp_DE',           # weather -> full
            'load_forecast_DE',  # load -> partial
            'price_DE',          # price -> historical
            'gen_FR_nuclear',    # generation -> historical
        ]
        categories = FeatureAvailability.categorize_features(test_cols)
        assert len(categories['full_horizon_d14']) == 2  # hour_sin, temp_DE
        assert len(categories['partial_d1']) == 1  # load_forecast_DE
        assert len(categories['historical']) == 2  # price_DE, gen_FR_nuclear
if __name__ == "__main__":
    # Run this file's tests verbosely with output capture disabled, and
    # propagate pytest's exit code so a failing run is visible to the shell.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))