refactor: performed pre-commit checks

This commit is contained in:
yashksaini-coder
2025-09-02 03:06:39 +05:30
parent e8d1a0fc32
commit 05867be37e
3 changed files with 101 additions and 78 deletions

View File

@ -8,40 +8,41 @@ from libp2p.utils.address_validation import (
get_optimal_binding_address, get_optimal_binding_address,
) )
def main(): def main():
print("=== Address Validation Utilities Demo ===\n") print("=== Address Validation Utilities Demo ===\n")
port = 8000 port = 8000
# Test available interfaces # Test available interfaces
print(f"Available interfaces for port {port}:") print(f"Available interfaces for port {port}:")
interfaces = get_available_interfaces(port) interfaces = get_available_interfaces(port)
for i, addr in enumerate(interfaces, 1): for i, addr in enumerate(interfaces, 1):
print(f" {i}. {addr}") print(f" {i}. {addr}")
print() print()
# Test optimal binding address # Test optimal binding address
print(f"Optimal binding address for port {port}:") print(f"Optimal binding address for port {port}:")
optimal = get_optimal_binding_address(port) optimal = get_optimal_binding_address(port)
print(f" -> {optimal}") print(f" -> {optimal}")
print() print()
# Check for wildcard addresses # Check for wildcard addresses
wildcard_found = any("0.0.0.0" in str(addr) for addr in interfaces) wildcard_found = any("0.0.0.0" in str(addr) for addr in interfaces)
print(f"Wildcard addresses found: {wildcard_found}") print(f"Wildcard addresses found: {wildcard_found}")
# Check for loopback addresses # Check for loopback addresses
loopback_found = any("127.0.0.1" in str(addr) for addr in interfaces) loopback_found = any("127.0.0.1" in str(addr) for addr in interfaces)
print(f"Loopback addresses found: {loopback_found}") print(f"Loopback addresses found: {loopback_found}")
# Check if optimal is wildcard # Check if optimal is wildcard
optimal_is_wildcard = "0.0.0.0" in str(optimal) optimal_is_wildcard = "0.0.0.0" in str(optimal)
print(f"Optimal address is wildcard: {optimal_is_wildcard}") print(f"Optimal address is wildcard: {optimal_is_wildcard}")
print() print()
if not wildcard_found and loopback_found and not optimal_is_wildcard: if not wildcard_found and loopback_found and not optimal_is_wildcard:
print("✅ All checks passed! Address validation is working correctly.") print("✅ All checks passed! Address validation is working correctly.")
print(" - No wildcard addresses") print(" - No wildcard addresses")
@ -49,9 +50,9 @@ def main():
print(" - Optimal address is secure") print(" - Optimal address is secure")
else: else:
print("❌ Some checks failed. Address validation needs attention.") print("❌ Some checks failed. Address validation needs attention.")
print() print()
# Test different protocols # Test different protocols
print("Testing different protocols:") print("Testing different protocols:")
for protocol in ["tcp", "udp"]: for protocol in ["tcp", "udp"]:
@ -60,5 +61,6 @@ def main():
if "0.0.0.0" in str(addr): if "0.0.0.0" in str(addr):
print(f" ⚠️ Warning: {protocol} returned wildcard address") print(f" ⚠️ Warning: {protocol} returned wildcard address")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -2,8 +2,6 @@
Tests to verify that all examples use 127.0.0.1 instead of 0.0.0.0 Tests to verify that all examples use 127.0.0.1 instead of 0.0.0.0
""" """
import ast
import os
from pathlib import Path from pathlib import Path
@ -17,42 +15,42 @@ class TestExamplesBindAddress:
def check_file_for_wildcard_binding(self, filepath): def check_file_for_wildcard_binding(self, filepath):
"""Check if a file contains 0.0.0.0 binding""" """Check if a file contains 0.0.0.0 binding"""
with open(filepath, 'r', encoding='utf-8') as f: with open(filepath, encoding="utf-8") as f:
content = f.read() content = f.read()
# Check for various forms of wildcard binding # Check for various forms of wildcard binding
wildcard_patterns = [ wildcard_patterns = [
'0.0.0.0', "0.0.0.0",
'/ip4/0.0.0.0/', "/ip4/0.0.0.0/",
] ]
found_wildcards = [] found_wildcards = []
for line_num, line in enumerate(content.splitlines(), 1): for line_num, line in enumerate(content.splitlines(), 1):
for pattern in wildcard_patterns: for pattern in wildcard_patterns:
if pattern in line and not line.strip().startswith('#'): if pattern in line and not line.strip().startswith("#"):
found_wildcards.append((line_num, line.strip())) found_wildcards.append((line_num, line.strip()))
return found_wildcards return found_wildcards
def test_no_wildcard_binding_in_examples(self): def test_no_wildcard_binding_in_examples(self):
"""Test that no example files use 0.0.0.0 for binding""" """Test that no example files use 0.0.0.0 for binding"""
example_files = self.get_example_files() example_files = self.get_example_files()
# Skip certain files that might legitimately discuss wildcards # Skip certain files that might legitimately discuss wildcards
skip_files = [ skip_files = [
'network_discover.py', # This demonstrates wildcard expansion "network_discover.py", # This demonstrates wildcard expansion
] ]
files_with_wildcards = {} files_with_wildcards = {}
for filepath in example_files: for filepath in example_files:
if any(skip in str(filepath) for skip in skip_files): if any(skip in str(filepath) for skip in skip_files):
continue continue
wildcards = self.check_file_for_wildcard_binding(filepath) wildcards = self.check_file_for_wildcard_binding(filepath)
if wildcards: if wildcards:
files_with_wildcards[str(filepath)] = wildcards files_with_wildcards[str(filepath)] = wildcards
# Assert no wildcards found # Assert no wildcards found
if files_with_wildcards: if files_with_wildcards:
error_msg = "Found wildcard bindings in example files:\n" error_msg = "Found wildcard bindings in example files:\n"
@ -60,51 +58,53 @@ class TestExamplesBindAddress:
error_msg += f"\n{filepath}:\n" error_msg += f"\n{filepath}:\n"
for line_num, line in occurrences: for line_num, line in occurrences:
error_msg += f" Line {line_num}: {line}\n" error_msg += f" Line {line_num}: {line}\n"
assert False, error_msg assert False, error_msg
def test_examples_use_loopback_address(self): def test_examples_use_loopback_address(self):
"""Test that examples use 127.0.0.1 for local binding""" """Test that examples use 127.0.0.1 for local binding"""
example_files = self.get_example_files() example_files = self.get_example_files()
# Files that should contain listen addresses # Files that should contain listen addresses
files_with_networking = [ files_with_networking = [
'ping/ping.py', "ping/ping.py",
'chat/chat.py', "chat/chat.py",
'bootstrap/bootstrap.py', "bootstrap/bootstrap.py",
'pubsub/pubsub.py', "pubsub/pubsub.py",
'identify/identify.py', "identify/identify.py",
] ]
for filename in files_with_networking: for filename in files_with_networking:
filepath = None filepath = None
for example_file in example_files: for example_file in example_files:
if filename in str(example_file): if filename in str(example_file):
filepath = example_file filepath = example_file
break break
if filepath is None: if filepath is None:
continue continue
with open(filepath, 'r', encoding='utf-8') as f: with open(filepath, encoding="utf-8") as f:
content = f.read() content = f.read()
# Check for proper loopback usage # Check for proper loopback usage
has_loopback = '127.0.0.1' in content or 'localhost' in content has_loopback = "127.0.0.1" in content or "localhost" in content
has_multiaddr_loopback = '/ip4/127.0.0.1/' in content has_multiaddr_loopback = "/ip4/127.0.0.1/" in content
assert has_loopback or has_multiaddr_loopback, \ assert has_loopback or has_multiaddr_loopback, (
f"{filepath} should use loopback address (127.0.0.1)" f"{filepath} should use loopback address (127.0.0.1)"
)
def test_doc_examples_use_loopback(self): def test_doc_examples_use_loopback(self):
"""Test that documentation examples use secure addresses""" """Test that documentation examples use secure addresses"""
doc_examples_dir = Path("examples/doc-examples") doc_examples_dir = Path("examples/doc-examples")
if not doc_examples_dir.exists(): if not doc_examples_dir.exists():
return return
doc_example_files = list(doc_examples_dir.glob("*.py")) doc_example_files = list(doc_examples_dir.glob("*.py"))
for filepath in doc_example_files: for filepath in doc_example_files:
wildcards = self.check_file_for_wildcard_binding(filepath) wildcards = self.check_file_for_wildcard_binding(filepath)
assert not wildcards, \ assert not wildcards, (
f"Documentation example {filepath} contains wildcard binding" f"Documentation example {filepath} contains wildcard binding"
)

View File

@ -13,16 +13,19 @@ from libp2p.utils.address_validation import (
class TestDefaultBindAddress: class TestDefaultBindAddress:
"""Test suite for verifying default bind addresses use secure addresses (not 0.0.0.0)""" """
Test suite for verifying default bind addresses use
secure addresses (not 0.0.0.0)
"""
def test_default_bind_address_is_not_wildcard(self): def test_default_bind_address_is_not_wildcard(self):
"""Test that default bind address is NOT 0.0.0.0 (wildcard)""" """Test that default bind address is NOT 0.0.0.0 (wildcard)"""
port = 8000 port = 8000
addr = get_optimal_binding_address(port) addr = get_optimal_binding_address(port)
# Should NOT return wildcard address # Should NOT return wildcard address
assert "0.0.0.0" not in str(addr) assert "0.0.0.0" not in str(addr)
# Should return a valid IP address (could be loopback or local network) # Should return a valid IP address (could be loopback or local network)
addr_str = str(addr) addr_str = str(addr)
assert "/ip4/" in addr_str assert "/ip4/" in addr_str
@ -32,14 +35,14 @@ class TestDefaultBindAddress:
"""Test that available interfaces always includes loopback address""" """Test that available interfaces always includes loopback address"""
port = 8000 port = 8000
interfaces = get_available_interfaces(port) interfaces = get_available_interfaces(port)
# Should have at least one interface # Should have at least one interface
assert len(interfaces) > 0 assert len(interfaces) > 0
# Should include loopback address # Should include loopback address
loopback_found = any("127.0.0.1" in str(addr) for addr in interfaces) loopback_found = any("127.0.0.1" in str(addr) for addr in interfaces)
assert loopback_found, "Loopback address not found in available interfaces" assert loopback_found, "Loopback address not found in available interfaces"
# Should not have wildcard as the only option # Should not have wildcard as the only option
if len(interfaces) == 1: if len(interfaces) == 1:
assert "0.0.0.0" not in str(interfaces[0]) assert "0.0.0.0" not in str(interfaces[0])
@ -50,7 +53,7 @@ class TestDefaultBindAddress:
port = 8000 port = 8000
listen_addr = Multiaddr(f"/ip4/127.0.0.1/tcp/{port}") listen_addr = Multiaddr(f"/ip4/127.0.0.1/tcp/{port}")
host = new_host(listen_addrs=[listen_addr]) host = new_host(listen_addrs=[listen_addr])
# Verify the host configuration # Verify the host configuration
assert host is not None assert host is not None
# Note: We can't test actual binding without running the host, # Note: We can't test actual binding without running the host,
@ -60,12 +63,12 @@ class TestDefaultBindAddress:
"""Test that fallback addresses don't use wildcard binding""" """Test that fallback addresses don't use wildcard binding"""
# When no interfaces are discovered, fallback should be loopback # When no interfaces are discovered, fallback should be loopback
port = 8000 port = 8000
# Even if we can't discover interfaces, we should get loopback # Even if we can't discover interfaces, we should get loopback
addr = get_optimal_binding_address(port) addr = get_optimal_binding_address(port)
# Should NOT be wildcard # Should NOT be wildcard
assert "0.0.0.0" not in str(addr) assert "0.0.0.0" not in str(addr)
# Should be a valid IP address # Should be a valid IP address
addr_str = str(addr) addr_str = str(addr)
assert "/ip4/" in addr_str assert "/ip4/" in addr_str
@ -76,11 +79,11 @@ class TestDefaultBindAddress:
"""Test that different protocols still use secure addresses by default""" """Test that different protocols still use secure addresses by default"""
port = 8000 port = 8000
addr = get_optimal_binding_address(port, protocol=protocol) addr = get_optimal_binding_address(port, protocol=protocol)
# Should NOT be wildcard # Should NOT be wildcard
assert "0.0.0.0" not in str(addr) assert "0.0.0.0" not in str(addr)
assert protocol in str(addr) assert protocol in str(addr)
# Should be a valid IP address # Should be a valid IP address
addr_str = str(addr) addr_str = str(addr)
assert "/ip4/" in addr_str assert "/ip4/" in addr_str
@ -90,15 +93,17 @@ class TestDefaultBindAddress:
"""Test that no public interface binding occurs by default""" """Test that no public interface binding occurs by default"""
port = 8000 port = 8000
interfaces = get_available_interfaces(port) interfaces = get_available_interfaces(port)
# Check that we don't expose on all interfaces by default # Check that we don't expose on all interfaces by default
wildcard_addrs = [addr for addr in interfaces if "0.0.0.0" in str(addr)] wildcard_addrs = [addr for addr in interfaces if "0.0.0.0" in str(addr)]
assert len(wildcard_addrs) == 0, "Found wildcard addresses in default interfaces" assert len(wildcard_addrs) == 0, (
"Found wildcard addresses in default interfaces"
)
# Verify optimal address selection doesn't choose wildcard # Verify optimal address selection doesn't choose wildcard
optimal = get_optimal_binding_address(port) optimal = get_optimal_binding_address(port)
assert "0.0.0.0" not in str(optimal), "Optimal address should not be wildcard" assert "0.0.0.0" not in str(optimal), "Optimal address should not be wildcard"
# Should be a valid IP address (could be loopback or local network) # Should be a valid IP address (could be loopback or local network)
addr_str = str(optimal) addr_str = str(optimal)
assert "/ip4/" in addr_str assert "/ip4/" in addr_str
@ -108,54 +113,70 @@ class TestDefaultBindAddress:
"""Test that loopback address is always available as an option""" """Test that loopback address is always available as an option"""
port = 8000 port = 8000
interfaces = get_available_interfaces(port) interfaces = get_available_interfaces(port)
# Loopback should always be available # Loopback should always be available
loopback_addrs = [addr for addr in interfaces if "127.0.0.1" in str(addr)] loopback_addrs = [addr for addr in interfaces if "127.0.0.1" in str(addr)]
assert len(loopback_addrs) > 0, "Loopback address should always be available" assert len(loopback_addrs) > 0, "Loopback address should always be available"
# At least one loopback address should have the correct port # At least one loopback address should have the correct port
loopback_with_port = [addr for addr in loopback_addrs if f"/tcp/{port}" in str(addr)] loopback_with_port = [
assert len(loopback_with_port) > 0, f"Loopback address with port {port} should be available" addr for addr in loopback_addrs if f"/tcp/{port}" in str(addr)
]
assert len(loopback_with_port) > 0, (
f"Loopback address with port {port} should be available"
)
def test_optimal_address_selection_behavior(self): def test_optimal_address_selection_behavior(self):
"""Test that optimal address selection works correctly""" """Test that optimal address selection works correctly"""
port = 8000 port = 8000
interfaces = get_available_interfaces(port) interfaces = get_available_interfaces(port)
optimal = get_optimal_binding_address(port) optimal = get_optimal_binding_address(port)
# Should never return wildcard # Should never return wildcard
assert "0.0.0.0" not in str(optimal) assert "0.0.0.0" not in str(optimal)
# Should return one of the available interfaces # Should return one of the available interfaces
optimal_str = str(optimal) optimal_str = str(optimal)
interface_strs = [str(addr) for addr in interfaces] interface_strs = [str(addr) for addr in interfaces]
assert optimal_str in interface_strs, f"Optimal address {optimal_str} should be in available interfaces" assert optimal_str in interface_strs, (
f"Optimal address {optimal_str} should be in available interfaces"
)
# If non-loopback interfaces are available, should prefer them # If non-loopback interfaces are available, should prefer them
non_loopback_interfaces = [addr for addr in interfaces if "127.0.0.1" not in str(addr)] non_loopback_interfaces = [
addr for addr in interfaces if "127.0.0.1" not in str(addr)
]
if non_loopback_interfaces: if non_loopback_interfaces:
# Should prefer non-loopback when available # Should prefer non-loopback when available
assert "127.0.0.1" not in str(optimal), "Should prefer non-loopback when available" assert "127.0.0.1" not in str(optimal), (
"Should prefer non-loopback when available"
)
else: else:
# Should use loopback when no other interfaces available # Should use loopback when no other interfaces available
assert "127.0.0.1" in str(optimal), "Should use loopback when no other interfaces available" assert "127.0.0.1" in str(optimal), (
"Should use loopback when no other interfaces available"
)
def test_address_validation_utilities_behavior(self): def test_address_validation_utilities_behavior(self):
"""Test that address validation utilities behave as expected""" """Test that address validation utilities behave as expected"""
port = 8000 port = 8000
# Test that we get multiple interface options # Test that we get multiple interface options
interfaces = get_available_interfaces(port) interfaces = get_available_interfaces(port)
assert len(interfaces) >= 2, "Should have at least loopback + one network interface" assert len(interfaces) >= 2, (
"Should have at least loopback + one network interface"
)
# Test that loopback is always included # Test that loopback is always included
has_loopback = any("127.0.0.1" in str(addr) for addr in interfaces) has_loopback = any("127.0.0.1" in str(addr) for addr in interfaces)
assert has_loopback, "Loopback should always be available" assert has_loopback, "Loopback should always be available"
# Test that no wildcards are included # Test that no wildcards are included
has_wildcard = any("0.0.0.0" in str(addr) for addr in interfaces) has_wildcard = any("0.0.0.0" in str(addr) for addr in interfaces)
assert not has_wildcard, "Wildcard addresses should never be included" assert not has_wildcard, "Wildcard addresses should never be included"
# Test optimal selection # Test optimal selection
optimal = get_optimal_binding_address(port) optimal = get_optimal_binding_address(port)
assert optimal in interfaces, "Optimal address should be from available interfaces" assert optimal in interfaces, (
"Optimal address should be from available interfaces"
)