# generated from dhairya/scenario_template
#!/usr/bin/env python
|
|
# coding: utf-8
|
|
|
|
# In[ ]:
|
|
|
|
|
|
from datetime import datetime, timedelta
|
|
import pandas as pd
|
|
import numpy as np
|
|
from tms_data_interface import SQLQueryInterface
|
|
|
|
def apply_sar_flag(df, var1, var2, var3, random_state=42):
    """
    Add a synthetic 'SAR_FLAG' label column ('Y'/'N') to *df*.

    Rows where all three variables meet or exceed their 90th-percentile
    thresholds form the "alerting" population; a reproducible random 10%
    of that population is flagged 'Y' and every other row is flagged 'N'.

    Parameters:
        df (pd.DataFrame): Input dataframe.
        var1 (str): First column name (90th-percentile threshold).
        var2 (str): Second column name (90th-percentile threshold).
        var3 (str): Third column name (90th-percentile threshold).
        random_state (int): Seed for the 10% sample, for reproducibility.

    Returns:
        pd.DataFrame: The rows of *df* with a 'SAR_FLAG' column added,
        restored to the original row order via the index.
    """
    # Thresholds are computed on non-null values only; all three columns
    # use the 90th percentile.
    th1 = np.percentile(df[var1].dropna(), 90)
    th2 = np.percentile(df[var2].dropna(), 90)
    th3 = np.percentile(df[var3].dropna(), 90)

    # Alerting rows meet or exceed every threshold simultaneously.
    alerting = df[(df[var1] >= th1) &
                  (df[var2] >= th2) &
                  (df[var3] >= th3)].copy()
    non_alerting = df.loc[~df.index.isin(alerting.index)].copy()

    # Default everything to 'N'; only sampled alerting rows become 'Y'.
    non_alerting['SAR_FLAG'] = 'N'
    alerting['SAR_FLAG'] = 'N'

    # int() truncates, so fewer than 10 alerting rows yields n_y == 0
    # and no 'Y' flags at all.
    n_y = int(len(alerting) * 0.1)  # 10% count
    if n_y > 0:
        y_indices = alerting.sample(n=n_y, random_state=random_state).index
        alerting.loc[y_indices, 'SAR_FLAG'] = 'Y'

    # Merge back and preserve original order.
    final_df = pd.concat([alerting, non_alerting]).sort_index()

    return final_df
|
|
|
|
query = """
|
|
WITH time_windows AS (
|
|
SELECT
|
|
-- End time is the current trade time
|
|
date_time AS end_time,
|
|
|
|
-- Subtract seconds from the end_time using date_add() with negative integer interval
|
|
date_add('second', -{time_window_s}, date_time) AS start_time,
|
|
|
|
-- Trade details
|
|
trade_price,
|
|
trade_volume,
|
|
trader_id,
|
|
|
|
-- Calculate minimum price within the time window
|
|
MIN(trade_price) OVER (
|
|
ORDER BY date_time
|
|
RANGE BETWEEN INTERVAL '{time_window_s}' SECOND PRECEDING AND CURRENT ROW
|
|
) AS min_price,
|
|
|
|
-- Calculate maximum price within the time window
|
|
MAX(trade_price) OVER (
|
|
ORDER BY date_time
|
|
RANGE BETWEEN INTERVAL '{time_window_s}' SECOND PRECEDING AND CURRENT ROW
|
|
) AS max_price,
|
|
|
|
-- Calculate total trade volume within the time window
|
|
SUM(trade_volume) OVER (
|
|
ORDER BY date_time
|
|
RANGE BETWEEN INTERVAL '{time_window_s}' SECOND PRECEDING AND CURRENT ROW
|
|
) AS total_volume,
|
|
|
|
-- Calculate participant's trade volume within the time window
|
|
SUM(CASE WHEN trader_id = trader_id THEN trade_volume ELSE 0 END) OVER (
|
|
PARTITION BY trader_id
|
|
ORDER BY date_time
|
|
RANGE BETWEEN INTERVAL '{time_window_s}' SECOND PRECEDING AND CURRENT ROW
|
|
) AS participant_volume
|
|
FROM
|
|
{trade_data_1b}
|
|
)
|
|
SELECT
|
|
-- Select the time window details
|
|
start_time,
|
|
end_time,
|
|
|
|
-- Select the participant (trader) ID
|
|
trader_id AS "Participant",
|
|
|
|
-- Select the calculated min and max prices
|
|
min_price,
|
|
max_price,
|
|
|
|
-- Calculate the price change percentage
|
|
(max_price - min_price) / NULLIF(min_price, 0) * 100 AS "Price Change (%)",
|
|
|
|
-- Calculate the participant's volume as a percentage of total volume
|
|
(participant_volume / NULLIF(total_volume, 0)) * 100 AS "Volume (%)",
|
|
|
|
-- Participant volume
|
|
participant_volume,
|
|
|
|
-- Select the total volume within the window
|
|
total_volume AS "Total Volume"
|
|
FROM
|
|
time_windows
|
|
"""
|
|
|
|
|
|
from tms_data_interface import SQLQueryInterface
|
|
|
|
class Scenario:
    """Trailing-window price-movement / volume-concentration scenario.

    Pulls per-trade window statistics from SQL, derives a participant
    volume percentage, and attaches Segment/Risk defaults plus a
    synthetic SAR_FLAG label.
    """

    # NOTE(review): created at class-definition (import) time, so the
    # schema must be reachable whenever this module is imported.
    seq = SQLQueryInterface(schema="trade_schema")

    def logic(self, **kwargs):
        """Run the scenario and return the labelled result frame.

        Keyword Args:
            validation_window (int): Look-back window in milliseconds
                (default 300000, i.e. 5 minutes).

        Returns:
            pd.DataFrame: One row per trade with window stats,
            'Segment', 'Risk', and 'SAR_FLAG' columns; rows with any
            missing value are dropped.
        """
        validation_window = kwargs.get('validation_window', 300000)
        time_window_s = int(validation_window / 1000)  # SQL expects seconds

        query_start_time = datetime.now()
        print("Query start time :", query_start_time)

        row_list = self.seq.execute_raw(
            query.format(trade_data_1b="trade_10m_v3",
                         time_window_s=time_window_s)
        )

        cols = [
            'START_DATE_TIME',
            'END_DATE_TIME',
            'Focal_id',
            'MIN_PRICE',
            'MAX_PRICE',
            'PRICE_CHANGE_PCT',
            'PARTICIPANT_VOLUME_PCT',
            'PARTICIPANT_VOLUME',
            'TOTAL_VOLUME',
        ]
        final_scenario_df = pd.DataFrame(row_list, columns=cols)

        # Recompute the percentage locally, masking zero totals to NaN so
        # the dropna() below removes them. A plain division would yield
        # inf for zero totals, which dropna() does NOT drop — this keeps
        # the behaviour aligned with the SQL's NULLIF(total_volume, 0).
        final_scenario_df['PARTICIPANT_VOLUME_PCT'] = (
            final_scenario_df['PARTICIPANT_VOLUME']
            / final_scenario_df['TOTAL_VOLUME'].replace(0, np.nan) * 100
        )

        final_scenario_df['Segment'] = 'Default'
        final_scenario_df['Risk'] = 'Medium Risk'
        final_scenario_df.dropna(inplace=True)

        # Synthetic labelling: 90th-percentile alerting population,
        # random 10% flagged 'Y' (deterministic seed).
        final_scenario_df = apply_sar_flag(final_scenario_df,
                                           'PRICE_CHANGE_PCT',
                                           'PARTICIPANT_VOLUME_PCT',
                                           'TOTAL_VOLUME',
                                           random_state=42)
        return final_scenario_df
|
|
|