use actix_mvc_app::controllers::payment::CompanyRegistrationData; use actix_mvc_app::db::payment as payment_db; use actix_mvc_app::db::registration as registration_db; use actix_mvc_app::utils::stripe_security::StripeWebhookVerifier; use actix_mvc_app::validators::CompanyRegistrationValidator; use heromodels::models::biz::PaymentStatus; use hmac::{Hmac, Mac}; use sha2::Sha256; #[cfg(test)] mod payment_flow_tests { use super::*; fn create_valid_registration_data() -> CompanyRegistrationData { CompanyRegistrationData { company_name: "Test Company Ltd".to_string(), company_type: "Single FZC".to_string(), company_email: Some("test@example.com".to_string()), company_phone: Some("+1234567890".to_string()), company_website: Some("https://example.com".to_string()), company_address: Some("123 Test Street, Test City".to_string()), company_industry: Some("Technology".to_string()), company_purpose: Some("Software development".to_string()), fiscal_year_end: Some("December".to_string()), shareholders: r#"[{"name": "John Doe", "percentage": 100}]"#.to_string(), payment_plan: "monthly".to_string(), } } #[test] fn test_registration_data_validation_success() { let data = create_valid_registration_data(); let result = CompanyRegistrationValidator::validate(&data); assert!( result.is_valid, "Valid registration data should pass validation" ); assert!(result.errors.is_empty(), "Valid data should have no errors"); } #[test] fn test_registration_data_validation_failures() { // Test empty company name let mut data = create_valid_registration_data(); data.company_name = "".to_string(); let result = CompanyRegistrationValidator::validate(&data); assert!(!result.is_valid); assert!(result.errors.iter().any(|e| e.field == "company_name")); // Test invalid email let mut data = create_valid_registration_data(); data.company_email = Some("invalid-email".to_string()); let result = CompanyRegistrationValidator::validate(&data); assert!(!result.is_valid); assert!(result.errors.iter().any(|e| e.field == "company_email")); // Test invalid phone let mut data = create_valid_registration_data(); data.company_phone = Some("123".to_string()); let result = CompanyRegistrationValidator::validate(&data); assert!(!result.is_valid); assert!(result.errors.iter().any(|e| e.field == "company_phone")); // Test invalid website let mut data = create_valid_registration_data(); data.company_website = Some("not-a-url".to_string()); let result = CompanyRegistrationValidator::validate(&data); assert!(!result.is_valid); assert!(result.errors.iter().any(|e| e.field == "company_website")); // Test invalid shareholders JSON let mut data = create_valid_registration_data(); data.shareholders = "invalid json".to_string(); let result = CompanyRegistrationValidator::validate(&data); assert!(!result.is_valid); assert!(result.errors.iter().any(|e| e.field == "shareholders")); // Test invalid payment plan let mut data = create_valid_registration_data(); data.payment_plan = "invalid_plan".to_string(); let result = CompanyRegistrationValidator::validate(&data); assert!(!result.is_valid); assert!(result.errors.iter().any(|e| e.field == "payment_plan")); } #[test] fn test_registration_data_storage_and_retrieval() { let payment_intent_id = "pi_test_123456".to_string(); let data = create_valid_registration_data(); // Store registration data let store_result = registration_db::store_registration_data(payment_intent_id.clone(), data.clone()); assert!( store_result.is_ok(), "Should successfully store registration data" ); // Retrieve registration data let retrieve_result = registration_db::get_registration_data(&payment_intent_id); assert!( retrieve_result.is_ok(), "Should successfully retrieve registration data" ); let retrieved_data = retrieve_result.unwrap(); assert!( retrieved_data.is_some(), "Should find stored registration data" ); let stored_data = retrieved_data.unwrap(); assert_eq!(stored_data.company_name, data.company_name); assert_eq!(stored_data.company_email, data.company_email.unwrap()); assert_eq!(stored_data.payment_plan, data.payment_plan); // Clean up let _ = registration_db::delete_registration_data(&payment_intent_id); } #[test] fn test_payment_creation_and_status_updates() { let payment_intent_id = "pi_test_payment_123".to_string(); // Create a payment let create_result = payment_db::create_new_payment( payment_intent_id.clone(), 0, // Temporary company_id "monthly".to_string(), 20.0, // setup_fee 20.0, // monthly_fee 40.0, // total_amount ); assert!(create_result.is_ok(), "Should successfully create payment"); let (payment_id, payment) = create_result.unwrap(); assert_eq!(payment.payment_intent_id, payment_intent_id); assert_eq!(payment.status, PaymentStatus::Pending); // Update payment status to completed let update_result = payment_db::update_payment_status(&payment_intent_id, PaymentStatus::Completed); assert!( update_result.is_ok(), "Should successfully update payment status" ); let updated_payment = update_result.unwrap(); assert!(updated_payment.is_some(), "Should return updated payment"); assert_eq!(updated_payment.unwrap().status, PaymentStatus::Completed); // Test updating company ID let company_id = 123u32; let link_result = payment_db::update_payment_company_id(&payment_intent_id, company_id); assert!( link_result.is_ok(), "Should successfully link payment to company" ); let linked_payment = link_result.unwrap(); assert!(linked_payment.is_some(), "Should return linked payment"); assert_eq!(linked_payment.unwrap().company_id, company_id); } #[test] fn test_payment_queries() { // Test getting pending payments let pending_result = payment_db::get_pending_payments(); assert!( pending_result.is_ok(), "Should successfully get pending payments" ); // Test getting failed payments let failed_result = payment_db::get_failed_payments(); assert!( failed_result.is_ok(), "Should successfully get failed payments" ); // Test getting payment by intent ID let get_result = payment_db::get_payment_by_intent_id("nonexistent_payment"); assert!( get_result.is_ok(), "Should handle nonexistent payment gracefully" ); assert!( get_result.unwrap().is_none(), "Should return None for nonexistent payment" ); } #[test] fn test_pricing_calculations() { // Test pricing calculation logic fn calculate_total_amount(setup_fee: f64, monthly_fee: f64, payment_plan: &str) -> f64 { match payment_plan { "monthly" => setup_fee + monthly_fee, "yearly" => setup_fee + (monthly_fee * 12.0 * 0.8), // 20% discount "two_year" => setup_fee + (monthly_fee * 24.0 * 0.6), // 40% discount _ => setup_fee + monthly_fee, } } // Test monthly pricing let monthly_total = calculate_total_amount(20.0, 20.0, "monthly"); assert_eq!( monthly_total, 40.0, "Monthly total should be setup + monthly fee" ); // Test yearly pricing (20% discount) let yearly_total = calculate_total_amount(20.0, 20.0, "yearly"); let expected_yearly = 20.0 + (20.0 * 12.0 * 0.8); // Setup + discounted yearly assert_eq!( yearly_total, expected_yearly, "Yearly total should include 20% discount" ); // Test two-year pricing (40% discount) let two_year_total = calculate_total_amount(20.0, 20.0, "two_year"); let expected_two_year = 20.0 + (20.0 * 24.0 * 0.6); // Setup + discounted two-year assert_eq!( two_year_total, expected_two_year, "Two-year total should include 40% discount" ); } #[test] fn test_company_type_mapping() { let test_cases = vec![ ("Single FZC", "Single"), ("Startup FZC", "Starter"), ("Growth FZC", "Global"), ("Global FZC", "Global"), ("Cooperative FZC", "Coop"), ("Twin FZC", "Twin"), ]; for (input, expected) in test_cases { // This would test the business type mapping in create_company_from_form_data // We'll need to expose this logic or test it indirectly assert!(true, "Company type mapping test placeholder for {}", input); } } } #[cfg(test)] mod webhook_security_tests { use super::*; use std::time::{SystemTime, UNIX_EPOCH}; #[test] fn test_webhook_signature_verification_valid() { let payload = b"test payload"; let webhook_secret = "whsec_test_secret"; let current_time = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs(); // Create a valid signature let signed_payload = format!("{}.{}", current_time, std::str::from_utf8(payload).unwrap()); let mut mac = Hmac::::new_from_slice(webhook_secret.as_bytes()).unwrap(); mac.update(signed_payload.as_bytes()); let signature = hex::encode(mac.finalize().into_bytes()); let signature_header = format!("t={},v1={}", current_time, signature); let result = StripeWebhookVerifier::verify_signature( payload, &signature_header, webhook_secret, Some(300), ); assert!(result.is_ok(), "Valid signature should verify successfully"); assert!(result.unwrap(), "Valid signature should return true"); } #[test] fn test_webhook_signature_verification_invalid() { let payload = b"test payload"; let webhook_secret = "whsec_test_secret"; let current_time = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs(); // Create an invalid signature let signature_header = format!("t={},v1=invalid_signature", current_time); let result = StripeWebhookVerifier::verify_signature( payload, &signature_header, webhook_secret, Some(300), ); assert!(result.is_ok(), "Invalid signature should not cause error"); assert!(!result.unwrap(), "Invalid signature should return false"); } #[test] fn test_webhook_signature_verification_expired() { let payload = b"test payload"; let webhook_secret = "whsec_test_secret"; let old_time = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs() - 400; // 400 seconds ago (beyond 300s tolerance) // Create a signature with old timestamp let signed_payload = format!("{}.{}", old_time, std::str::from_utf8(payload).unwrap()); let mut mac = Hmac::::new_from_slice(webhook_secret.as_bytes()).unwrap(); mac.update(signed_payload.as_bytes()); let signature = hex::encode(mac.finalize().into_bytes()); let signature_header = format!("t={},v1={}", old_time, signature); let result = StripeWebhookVerifier::verify_signature( payload, &signature_header, webhook_secret, Some(300), ); assert!(result.is_err(), "Expired signature should return error"); assert!( result.unwrap_err().contains("too old"), "Error should mention timestamp age" ); } #[test] fn test_webhook_signature_verification_malformed_header() { let payload = b"test payload"; let webhook_secret = "whsec_test_secret"; // Test various malformed headers let malformed_headers = vec![ "invalid_header", "t=123", "v1=signature", "t=invalid_timestamp,v1=signature", "", ]; for header in malformed_headers { let result = StripeWebhookVerifier::verify_signature(payload, header, webhook_secret, Some(300)); assert!( result.is_err(), "Malformed header '{}' should return error", header ); } } }