diff --git a/src/scanner/utils.rs b/src/scanner/utils.rs index e27f397..f16293d 100644 --- a/src/scanner/utils.rs +++ b/src/scanner/utils.rs @@ -470,14 +470,9 @@ impl Requester { None } - /// query the statistics handler for the current number of errors based on the given policy - async fn get_scan_errors_by_policy(&self, trigger: PolicyTrigger) -> Result { - Ok(self.ferox_scan.num_errors(trigger)) - } - /// wrapper for adjust_[up,down] functions, checks error levels to determine adjustment direction async fn adjust_limit(&self, trigger: PolicyTrigger) -> Result<()> { - let scan_errors = self.get_scan_errors_by_policy(trigger).await?; + let scan_errors = self.ferox_scan.num_errors(trigger); let policy_errors = atomic_load!(self.policy_data.errors, Ordering::SeqCst); if let Ok(mut guard) = self.tuning_lock.try_lock() { @@ -691,6 +686,7 @@ mod tests { scan_manager::{ScanOrder, ScanType}, }; use reqwest::StatusCode; + use std::time::Instant; /// helper to setup a realistic requester test async fn setup_requester_test(config: Option>) -> (Arc, Tasks) { @@ -1229,4 +1225,122 @@ mod tests { assert_eq!(pd.heap.write().unwrap().move_up(), 0); assert_eq!(pd.heap.write().unwrap().parent_value(), 400); } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + /// cooldown should pause execution and prevent others calling it by setting cooling_down flag + async fn cooldown_pauses_and_sets_flag() { + let (handles, _) = setup_requester_test(None).await; + + let requester = Arc::new(Requester { + handles, + tuning_lock: Mutex::new(0), + ferox_scan: Arc::new(FeroxScan::default()), + target_url: "http://localhost".to_string(), + rate_limiter: RwLock::new(None), + policy_data: PolicyData::new(RequesterPolicy::AutoBail, 7), + }); + + let start = Instant::now(); + let clone = requester.clone(); + let resp = tokio::task::spawn(async move { + sleep(Duration::new(1, 0)).await; + clone.policy_data.cooling_down.load(Ordering::Relaxed) + }); + + requester.cool_down().await; + + assert_eq!(resp.await.unwrap(), true); + println!("{}", start.elapsed().as_millis()); + assert!(start.elapsed().as_millis() >= 3500); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + /// adjust_limit should add one to the streak counter when errors from scan equal policy and + /// increase the scan rate + async fn adjust_limit_increments_streak_counter_on_upward_movement() { + let (handles, _) = setup_requester_test(None).await; + + let requester = Requester { + handles, + tuning_lock: Mutex::new(0), + ferox_scan: Arc::new(FeroxScan::default()), + target_url: "http://localhost".to_string(), + rate_limiter: RwLock::new(None), + policy_data: PolicyData::new(RequesterPolicy::AutoBail, 7), + }; + + requester.policy_data.set_reqs_sec(400); + requester.adjust_limit(PolicyTrigger::Errors).await.unwrap(); + + assert_eq!(*requester.tuning_lock.lock().unwrap(), 1); + assert_eq!(requester.policy_data.get_limit(), 300); + assert_eq!( + requester.rate_limiter.read().await.as_ref().unwrap().max(), + 300 + ); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + /// adjust_limit should reset the streak counter when errors from scan are > policy and + /// decrease the scan rate + async fn adjust_limit_resets_streak_counter_on_downward_movement() { + let (handles, _) = setup_requester_test(None).await; + + let scan = FeroxScan::default(); + scan.add_error(); + scan.add_error(); + + let requester = Requester { + handles, + tuning_lock: Mutex::new(0), + ferox_scan: Arc::new(scan), + target_url: "http://localhost".to_string(), + rate_limiter: RwLock::new(None), + policy_data: PolicyData::new(RequesterPolicy::AutoBail, 7), + }; + + requester.policy_data.set_reqs_sec(400); + requester.policy_data.set_errors(1); + + let mut guard = requester.tuning_lock.lock().unwrap(); + *guard = 2; + drop(guard); + + requester.adjust_limit(PolicyTrigger::Errors).await.unwrap(); + + assert_eq!(*requester.tuning_lock.lock().unwrap(), 0); + assert_eq!(requester.policy_data.get_limit(), 100); + assert_eq!( + requester.rate_limiter.read().await.as_ref().unwrap().max(), + 100 + ); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + /// adjust_limit should remove the rate limiter when remove_limit is set + async fn adjust_limit_removes_rate_limiter() { + let (handles, _) = setup_requester_test(None).await; + + let scan = FeroxScan::default(); + scan.add_error(); + scan.add_error(); + + let requester = Requester { + handles, + tuning_lock: Mutex::new(0), + ferox_scan: Arc::new(scan), + target_url: "http://localhost".to_string(), + rate_limiter: RwLock::new(None), + policy_data: PolicyData::new(RequesterPolicy::AutoBail, 7), + }; + + requester.policy_data.set_reqs_sec(400); + requester + .policy_data + .remove_limit + .store(true, Ordering::Relaxed); + + requester.adjust_limit(PolicyTrigger::Errors).await.unwrap(); + assert!(requester.rate_limiter.read().await.is_none()); + } }