mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Compare commits
214 Commits
333d56dc00
...
chore01
| Author | SHA1 | Date | |
|---|---|---|---|
| c1bbb1a9b4 | |||
| f13876a64f | |||
| 3c43e4682a | |||
| b46dae7c50 | |||
| 262e7e9834 | |||
| 634de8ed02 | |||
| 17c1ced408 | |||
| d64f9e10fd | |||
| 93c2d5002f | |||
| 721da9364e | |||
| 35a4bf2d42 | |||
| 02ff688b5a | |||
| 52625e0f68 | |||
| 066c87515e | |||
| a0826b21bc | |||
| 3e4be44fa3 | |||
| c8c7e55d5c | |||
| 3363f57338 | |||
| 37fd2542c0 | |||
| 5e2840d5b5 | |||
| 17838687fe | |||
| 43bb36338c | |||
| d519f75d69 | |||
| 7c60b81801 | |||
| 009fdd0d8f | |||
| 6a1b955a4e | |||
| 87429eb2e9 | |||
| 321bb86ea4 | |||
| c3c8d8ccb9 | |||
| a01811a435 | |||
| db2f3a64ea | |||
| 77208e95cc | |||
| ae3e2ff943 | |||
| a862ac83cd | |||
| 3f30ed4437 | |||
| 67a3cab2e2 | |||
| bf132cf3dd | |||
| 4dd2454a46 | |||
| b01f2bd105 | |||
| 1a4fe91419 | |||
| 1326592fe2 | |||
| 5f0c5101c7 | |||
| 8d028a046d | |||
| a0cb6e3a30 | |||
| f4a4298c0f | |||
| cdbd80eeba | |||
| 518c1f98b1 | |||
| 284eee78d7 | |||
| 6e8960b383 | |||
| 1250b2ec12 | |||
| 9b0f75014c | |||
| b4a5d9037c | |||
| ecdb770c45 | |||
| 95e7e7b4f6 | |||
| 8428aff20c | |||
| 81cc2f06f0 | |||
| c5a2836829 | |||
| 4fdfdae9fb | |||
| 0271a36316 | |||
| 8793667503 | |||
| 771b837916 | |||
| 93db588b9e | |||
| 4a36d6efeb | |||
| 7d364da950 | |||
| 4e8ebf707a | |||
| 80e22f7c4a | |||
| f4d5a44521 | |||
| afe6da5db2 | |||
| 396812e84a | |||
| 74f4aaf136 | |||
| fe662446dd | |||
| a8a71b077b | |||
| b7f11ba43d | |||
| a69db8a716 | |||
| aa2a650f85 | |||
| 030deb42b4 | |||
| 637bd5d560 | |||
| ce3f3a8e43 | |||
| f3976b7d2f | |||
| 09c9709a3e | |||
| f0b05b8307 | |||
| 9fdb36ed03 | |||
| 31191cbfae | |||
| 2fe5882013 | |||
| bffabd1070 | |||
| 9370101a84 | |||
| 56732a1506 | |||
| 68cb54ee0f | |||
| f80101c4eb | |||
| d268393812 | |||
| 4786b48364 | |||
| 2ee3e0b054 | |||
| 2a249b1792 | |||
| c693cd9bb9 | |||
| 25d7706047 | |||
| c5a8f26490 | |||
| 31c65274c3 | |||
| 5ec1671608 | |||
| 69a0d3da9d | |||
| 431a4807fb | |||
| f54a14b713 | |||
| d0c81301b5 | |||
| d2d4c4b451 | |||
| 4b4214f066 | |||
| 37a4d96f90 | |||
| 33730bdc48 | |||
| 159d2cc322 | |||
| b367ff70c3 | |||
| 9465805c3b | |||
| 37652f7034 | |||
| b8217bb8a8 | |||
| a0ca284e8f | |||
| 809a32a712 | |||
| ade6f5c6ad | |||
| d385cb45cf | |||
| 05867be37e | |||
| e8d1a0fc32 | |||
| 5633d52a63 | |||
| 68af8766e2 | |||
| 87550113a4 | |||
| 14a74fdbd1 | |||
| 145727a9ba | |||
| 84c1a7031a | |||
| 20edc3830a | |||
| 6742dd38f7 | |||
| 69680e9c1f | |||
| fcb35084b3 | |||
| 42c8937a8d | |||
| 64ccce17eb | |||
| 6a24b138dd | |||
| eab8df84df | |||
| 9749be6574 | |||
| 186113968e | |||
| e1141ee376 | |||
| 8e74f944e1 | |||
| 89cb8c0bd9 | |||
| d97b86081b | |||
| 2c03ac46ea | |||
| 58433f9b52 | |||
| 933741b190 | |||
| 760f94bd81 | |||
| 6d1e53a4e2 | |||
| 5ed3707a51 | |||
| f550c19b2c | |||
| 84c9ddc2dd | |||
| a6ff93122b | |||
| 8e6e88140f | |||
| 342ac746f8 | |||
| b3f0a4e8c4 | |||
| 0f64bb49b5 | |||
| 03bf071739 | |||
| c15c317514 | |||
| 6c45862fe9 | |||
| 8f0cdc9ed4 | |||
| bbe632bd85 | |||
| 2689040d48 | |||
| 8263052f88 | |||
| e2fee14bc5 | |||
| 6633eb01d4 | |||
| 123c86c091 | |||
| 369f79306f | |||
| cb6fd27626 | |||
| a1d1a07d4c | |||
| ac01cc5038 | |||
| 94d920f365 | |||
| 45c5f16379 | |||
| ce76641ef5 | |||
| bc2ac47594 | |||
| a3231af714 | |||
| 54b3055eaa | |||
| 446a22b0f0 | |||
| 8f5dd3bd11 | |||
| 997094e5b7 | |||
| 40dad64949 | |||
| d6cf83051e | |||
| 3d1c36419c | |||
| 8100a5cd20 | |||
| c940dac1e6 | |||
| fb544d6db2 | |||
| b40d84fc26 | |||
| 3b27b02a8b | |||
| fe761baa49 | |||
| 3baf886527 | |||
| f0172a0ba1 | |||
| 5a2fca32a0 | |||
| 8d9b7f413d | |||
| 5de09ed8a1 | |||
| a9f184be6a | |||
| a9a6ed6767 | |||
| f3cf06cd72 | |||
| 6931092eea | |||
| 163cc35cb0 | |||
| 388302baa7 | |||
| dc04270c19 | |||
| f1872bba93 | |||
| 9573ab5aea | |||
| a5b0db1d8f | |||
| fe4c17e8d1 | |||
| 167dfdcac1 | |||
| 19c1f5ead0 | |||
| 64107b4648 | |||
| a6f85690bf | |||
| 651bf0fc6e | |||
| 7a1aa548a1 | |||
| 65faa214da | |||
| 53a16d0476 | |||
| 1997777c52 | |||
| 8a21435fc4 | |||
| 7469238bb4 | |||
| b6c36373a9 | |||
| 4fb7132b4e | |||
| fa0b64dca8 | |||
| 227a5c6441 | |||
| 187418378a |
42
.github/workflows/tox.yml
vendored
42
.github/workflows/tox.yml
vendored
@ -36,10 +36,48 @@ jobs:
|
|||||||
- uses: actions/setup-python@v5
|
- uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python }}
|
python-version: ${{ matrix.python }}
|
||||||
|
|
||||||
|
- name: Install Nim for interop testing
|
||||||
|
if: matrix.toxenv == 'interop'
|
||||||
|
run: |
|
||||||
|
echo "Installing Nim for nim-libp2p interop testing..."
|
||||||
|
curl -sSf https://nim-lang.org/choosenim/init.sh | sh -s -- -y --firstInstall
|
||||||
|
echo "$HOME/.nimble/bin" >> $GITHUB_PATH
|
||||||
|
echo "$HOME/.choosenim/toolchains/nim-stable/bin" >> $GITHUB_PATH
|
||||||
|
|
||||||
|
- name: Cache nimble packages
|
||||||
|
if: matrix.toxenv == 'interop'
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/.nimble
|
||||||
|
~/.choosenim/toolchains/*/lib
|
||||||
|
key: ${{ runner.os }}-nimble-${{ hashFiles('**/nim_echo_server.nim') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-nimble-
|
||||||
|
|
||||||
|
- name: Build nim interop binaries
|
||||||
|
if: matrix.toxenv == 'interop'
|
||||||
|
run: |
|
||||||
|
export PATH="$HOME/.nimble/bin:$HOME/.choosenim/toolchains/nim-stable/bin:$PATH"
|
||||||
|
cd tests/interop/nim_libp2p
|
||||||
|
./scripts/setup_nim_echo.sh
|
||||||
|
|
||||||
- run: |
|
- run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
python -m pip install tox
|
python -m pip install tox
|
||||||
- run: |
|
|
||||||
|
- name: Run Tests or Generate Docs
|
||||||
|
run: |
|
||||||
|
if [[ "${{ matrix.toxenv }}" == 'docs' ]]; then
|
||||||
|
export TOXENV=docs
|
||||||
|
else
|
||||||
|
export TOXENV=py${{ matrix.python }}-${{ matrix.toxenv }}
|
||||||
|
fi
|
||||||
|
# Set PATH for nim commands during tox
|
||||||
|
if [[ "${{ matrix.toxenv }}" == 'interop' ]]; then
|
||||||
|
export PATH="$HOME/.nimble/bin:$HOME/.choosenim/toolchains/nim-stable/bin:$PATH"
|
||||||
|
fi
|
||||||
python -m tox run -r
|
python -m tox run -r
|
||||||
|
|
||||||
windows:
|
windows:
|
||||||
@ -65,5 +103,5 @@ jobs:
|
|||||||
if [[ "${{ matrix.toxenv }}" == "wheel" ]]; then
|
if [[ "${{ matrix.toxenv }}" == "wheel" ]]; then
|
||||||
python -m tox run -e windows-wheel
|
python -m tox run -e windows-wheel
|
||||||
else
|
else
|
||||||
python -m tox run -e py311-${{ matrix.toxenv }}
|
python -m tox run -e py${{ matrix.python-version }}-${{ matrix.toxenv }}
|
||||||
fi
|
fi
|
||||||
|
|||||||
4
.gitignore
vendored
4
.gitignore
vendored
@ -178,6 +178,10 @@ env.bak/
|
|||||||
#lockfiles
|
#lockfiles
|
||||||
uv.lock
|
uv.lock
|
||||||
poetry.lock
|
poetry.lock
|
||||||
|
tests/interop/js_libp2p/js_node/node_modules/
|
||||||
|
tests/interop/js_libp2p/js_node/package-lock.json
|
||||||
|
tests/interop/js_libp2p/js_node/src/node_modules/
|
||||||
|
tests/interop/js_libp2p/js_node/src/package-lock.json
|
||||||
|
|
||||||
# Sphinx documentation build
|
# Sphinx documentation build
|
||||||
_build/
|
_build/
|
||||||
|
|||||||
12
README.md
12
README.md
@ -61,12 +61,12 @@ ______________________________________________________________________
|
|||||||
|
|
||||||
### Discovery
|
### Discovery
|
||||||
|
|
||||||
| **Discovery** | **Status** | **Source** |
|
| **Discovery** | **Status** | **Source** |
|
||||||
| -------------------- | :--------: | :--------------------------------------------------------------------------------: |
|
| -------------------- | :--------: | :----------------------------------------------------------------------------------: |
|
||||||
| **`bootstrap`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/bootstrap) |
|
| **`bootstrap`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/bootstrap) |
|
||||||
| **`random-walk`** | 🌱 | |
|
| **`random-walk`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/random_walk) |
|
||||||
| **`mdns-discovery`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/mdns) |
|
| **`mdns-discovery`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/mdns) |
|
||||||
| **`rendezvous`** | 🌱 | |
|
| **`rendezvous`** | 🌱 | |
|
||||||
|
|
||||||
______________________________________________________________________
|
______________________________________________________________________
|
||||||
|
|
||||||
|
|||||||
@ -36,12 +36,14 @@ Create a file named ``relay_node.py`` with the following content:
|
|||||||
from libp2p.relay.circuit_v2.transport import CircuitV2Transport
|
from libp2p.relay.circuit_v2.transport import CircuitV2Transport
|
||||||
from libp2p.relay.circuit_v2.config import RelayConfig
|
from libp2p.relay.circuit_v2.config import RelayConfig
|
||||||
from libp2p.tools.async_service import background_trio_service
|
from libp2p.tools.async_service import background_trio_service
|
||||||
|
from libp2p.utils import get_wildcard_address
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
logger = logging.getLogger("relay_node")
|
logger = logging.getLogger("relay_node")
|
||||||
|
|
||||||
async def run_relay():
|
async def run_relay():
|
||||||
listen_addr = multiaddr.Multiaddr("/ip4/0.0.0.0/tcp/9000")
|
# Use wildcard address to listen on all interfaces
|
||||||
|
listen_addr = get_wildcard_address(9000)
|
||||||
host = new_host()
|
host = new_host()
|
||||||
|
|
||||||
config = RelayConfig(
|
config = RelayConfig(
|
||||||
@ -107,6 +109,7 @@ Create a file named ``destination_node.py`` with the following content:
|
|||||||
from libp2p.relay.circuit_v2.config import RelayConfig
|
from libp2p.relay.circuit_v2.config import RelayConfig
|
||||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||||
from libp2p.tools.async_service import background_trio_service
|
from libp2p.tools.async_service import background_trio_service
|
||||||
|
from libp2p.utils import get_wildcard_address
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
logger = logging.getLogger("destination_node")
|
logger = logging.getLogger("destination_node")
|
||||||
@ -139,7 +142,8 @@ Create a file named ``destination_node.py`` with the following content:
|
|||||||
Run a simple destination node that accepts connections.
|
Run a simple destination node that accepts connections.
|
||||||
This is a simplified version that doesn't use the relay functionality.
|
This is a simplified version that doesn't use the relay functionality.
|
||||||
"""
|
"""
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/9001")
|
# Create a libp2p host - use wildcard address to listen on all interfaces
|
||||||
|
listen_addr = get_wildcard_address(9001)
|
||||||
host = new_host()
|
host = new_host()
|
||||||
|
|
||||||
# Configure as a relay receiver (stop)
|
# Configure as a relay receiver (stop)
|
||||||
@ -252,14 +256,15 @@ Create a file named ``source_node.py`` with the following content:
|
|||||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||||
from libp2p.tools.async_service import background_trio_service
|
from libp2p.tools.async_service import background_trio_service
|
||||||
from libp2p.relay.circuit_v2.discovery import RelayInfo
|
from libp2p.relay.circuit_v2.discovery import RelayInfo
|
||||||
|
from libp2p.utils import get_wildcard_address
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
logger = logging.getLogger("source_node")
|
logger = logging.getLogger("source_node")
|
||||||
|
|
||||||
async def run_source(relay_peer_id=None, destination_peer_id=None):
|
async def run_source(relay_peer_id=None, destination_peer_id=None):
|
||||||
# Create a libp2p host
|
# Create a libp2p host - use wildcard address to listen on all interfaces
|
||||||
listen_addr = multiaddr.Multiaddr("/ip4/0.0.0.0/tcp/9002")
|
listen_addr = get_wildcard_address(9002)
|
||||||
host = new_host()
|
host = new_host()
|
||||||
|
|
||||||
# Configure as a relay client
|
# Configure as a relay client
|
||||||
@ -428,7 +433,7 @@ Running the Example
|
|||||||
Relay node multiaddr: /ip4/127.0.0.1/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
|
Relay node multiaddr: /ip4/127.0.0.1/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
|
||||||
==================================================
|
==================================================
|
||||||
|
|
||||||
Listening on: [<Multiaddr /ip4/0.0.0.0/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx>]
|
Listening on: [<Multiaddr /ip4/127.0.0.1/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx>]
|
||||||
Protocol service started
|
Protocol service started
|
||||||
Relay service started successfully
|
Relay service started successfully
|
||||||
Relay limits: RelayLimits(duration=3600, data=10485760, max_circuit_conns=8, max_reservations=4)
|
Relay limits: RelayLimits(duration=3600, data=10485760, max_circuit_conns=8, max_reservations=4)
|
||||||
@ -447,7 +452,7 @@ Running the Example
|
|||||||
Use this ID in the source node: QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s
|
Use this ID in the source node: QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s
|
||||||
==================================================
|
==================================================
|
||||||
|
|
||||||
Listening on: [<Multiaddr /ip4/0.0.0.0/tcp/9001/p2p/QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s>]
|
Listening on: [<Multiaddr /ip4/127.0.0.1/tcp/9001/p2p/QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s>]
|
||||||
Registered echo protocol handler
|
Registered echo protocol handler
|
||||||
Protocol service started
|
Protocol service started
|
||||||
Transport created
|
Transport created
|
||||||
@ -469,7 +474,7 @@ Running the Example
|
|||||||
|
|
||||||
$ python source_node.py
|
$ python source_node.py
|
||||||
Source node started with ID: QmPyM56cgmFoHTgvMgGfDWRdVRQznmxCDDDg2dJ8ygVXj3
|
Source node started with ID: QmPyM56cgmFoHTgvMgGfDWRdVRQznmxCDDDg2dJ8ygVXj3
|
||||||
Listening on: [<Multiaddr /ip4/0.0.0.0/tcp/9002/p2p/QmPyM56cgmFoHTgvMgGfDWRdVRQznmxCDDDg2dJ8ygVXj3>]
|
Listening on: [<Multiaddr /ip4/127.0.0.1/tcp/9002/p2p/QmPyM56cgmFoHTgvMgGfDWRdVRQznmxCDDDg2dJ8ygVXj3>]
|
||||||
Protocol service started
|
Protocol service started
|
||||||
No relay peer ID provided. Please enter the relay\'s peer ID:
|
No relay peer ID provided. Please enter the relay\'s peer ID:
|
||||||
Enter relay peer ID: QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
|
Enter relay peer ID: QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
|
||||||
|
|||||||
43
docs/examples.echo_quic.rst
Normal file
43
docs/examples.echo_quic.rst
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
QUIC Echo Demo
|
||||||
|
==============
|
||||||
|
|
||||||
|
This example demonstrates a simple ``echo`` protocol using **QUIC transport**.
|
||||||
|
|
||||||
|
QUIC provides built-in TLS security and stream multiplexing over UDP, making it an excellent transport choice for libp2p applications.
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ python -m pip install libp2p
|
||||||
|
Collecting libp2p
|
||||||
|
...
|
||||||
|
Successfully installed libp2p-x.x.x
|
||||||
|
$ echo-quic-demo
|
||||||
|
Run this from the same folder in another console:
|
||||||
|
|
||||||
|
echo-quic-demo -d /ip4/127.0.0.1/udp/8000/quic-v1/p2p/16Uiu2HAmAsbxRR1HiGJRNVPQLNMeNsBCsXT3rDjoYBQzgzNpM5mJ
|
||||||
|
|
||||||
|
Waiting for incoming connection...
|
||||||
|
|
||||||
|
Copy the line that starts with ``echo-quic-demo -p 8001``, open a new terminal in the same
|
||||||
|
folder and paste it in:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ echo-quic-demo -d /ip4/127.0.0.1/udp/8000/quic-v1/p2p/16Uiu2HAmE3N7KauPTmHddYPsbMcBp2C6XAmprELX3YcFEN9iXiBu
|
||||||
|
|
||||||
|
I am 16Uiu2HAmE3N7KauPTmHddYPsbMcBp2C6XAmprELX3YcFEN9iXiBu
|
||||||
|
STARTING CLIENT CONNECTION PROCESS
|
||||||
|
CLIENT CONNECTED TO SERVER
|
||||||
|
Sent: hi, there!
|
||||||
|
Got: ECHO: hi, there!
|
||||||
|
|
||||||
|
**Key differences from TCP Echo:**
|
||||||
|
|
||||||
|
- Uses UDP instead of TCP: ``/udp/8000`` instead of ``/tcp/8000``
|
||||||
|
- Includes QUIC protocol identifier: ``/quic-v1`` in the multiaddr
|
||||||
|
- Built-in TLS security (no separate security transport needed)
|
||||||
|
- Native stream multiplexing over a single QUIC connection
|
||||||
|
|
||||||
|
.. literalinclude:: ../examples/echo/echo_quic.py
|
||||||
|
:language: python
|
||||||
|
:linenos:
|
||||||
@ -12,7 +12,7 @@ This example demonstrates how to use the libp2p ``identify`` protocol.
|
|||||||
$ identify-demo
|
$ identify-demo
|
||||||
First host listening. Run this from another console:
|
First host listening. Run this from another console:
|
||||||
|
|
||||||
identify-demo -p 8889 -d /ip4/0.0.0.0/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM
|
identify-demo -p 8889 -d /ip4/127.0.0.1/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM
|
||||||
|
|
||||||
Waiting for incoming identify request...
|
Waiting for incoming identify request...
|
||||||
|
|
||||||
@ -21,13 +21,13 @@ folder and paste it in:
|
|||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
$ identify-demo -p 8889 -d /ip4/0.0.0.0/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM
|
$ identify-demo -p 8889 -d /ip4/127.0.0.1/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM
|
||||||
dialer (host_b) listening on /ip4/0.0.0.0/tcp/8889
|
dialer (host_b) listening on /ip4/127.0.0.1/tcp/8889
|
||||||
Second host connecting to peer: QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM
|
Second host connecting to peer: QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM
|
||||||
Starting identify protocol...
|
Starting identify protocol...
|
||||||
Identify response:
|
Identify response:
|
||||||
Public Key (Base64): CAASpgIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDC6c/oNPP9X13NDQ3Xrlp3zOj+ErXIWb/A4JGwWchiDBwMhMslEX3ct8CqI0BqUYKuwdFjowqqopOJ3cS2MlqtGaiP6Dg9bvGqSDoD37BpNaRVNcebRxtB0nam9SQy3PYLbHAmz0vR4ToSiL9OLRORnGOxCtHBuR8ZZ5vS0JEni8eQMpNa7IuXwyStnuty/QjugOZudBNgYSr8+9gH722KTjput5IRL7BrpIdd4HNXGVRm4b9BjNowvHu404x3a/ifeNblpy/FbYyFJEW0looygKF7hpRHhRbRKIDZt2BqOfT1sFkbqsHE85oY859+VMzP61YELgvGwai2r7KcjkW/AgMBAAE=
|
Public Key (Base64): CAASpgIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDC6c/oNPP9X13NDQ3Xrlp3zOj+ErXIWb/A4JGwWchiDBwMhMslEX3ct8CqI0BqUYKuwdFjowqqopOJ3cS2MlqtGaiP6Dg9bvGqSDoD37BpNaRVNcebRxtB0nam9SQy3PYLbHAmz0vR4ToSiL9OLRORnGOxCtHBuR8ZZ5vS0JEni8eQMpNa7IuXwyStnuty/QjugOZudBNgYSr8+9gH722KTjput5IRL7BrpIdd4HNXGVRm4b9BjNowvHu404x3a/ifeNblpy/FbYyFJEW0looygKF7hpRHhRbRKIDZt2BqOfT1sFkbqsHE85oY859+VMzP61YELgvGwai2r7KcjkW/AgMBAAE=
|
||||||
Listen Addresses: ['/ip4/0.0.0.0/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM']
|
Listen Addresses: ['/ip4/127.0.0.1/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM']
|
||||||
Protocols: ['/ipfs/id/1.0.0', '/ipfs/ping/1.0.0']
|
Protocols: ['/ipfs/id/1.0.0', '/ipfs/ping/1.0.0']
|
||||||
Observed Address: ['/ip4/127.0.0.1/tcp/38082']
|
Observed Address: ['/ip4/127.0.0.1/tcp/38082']
|
||||||
Protocol Version: ipfs/0.1.0
|
Protocol Version: ipfs/0.1.0
|
||||||
|
|||||||
@ -34,11 +34,11 @@ There is also a more interactive version of the example which runs as separate l
|
|||||||
==== Starting Identify-Push Listener on port 8888 ====
|
==== Starting Identify-Push Listener on port 8888 ====
|
||||||
|
|
||||||
Listener host ready!
|
Listener host ready!
|
||||||
Listening on: /ip4/0.0.0.0/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM
|
Listening on: /ip4/127.0.0.1/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM
|
||||||
Peer ID: QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM
|
Peer ID: QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM
|
||||||
|
|
||||||
Run dialer with command:
|
Run dialer with command:
|
||||||
identify-push-listener-dialer-demo -d /ip4/0.0.0.0/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM
|
identify-push-listener-dialer-demo -d /ip4/127.0.0.1/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM
|
||||||
|
|
||||||
Waiting for incoming connections... (Ctrl+C to exit)
|
Waiting for incoming connections... (Ctrl+C to exit)
|
||||||
|
|
||||||
@ -47,12 +47,12 @@ folder and paste it in:
|
|||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
$ identify-push-listener-dialer-demo -d /ip4/0.0.0.0/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM
|
$ identify-push-listener-dialer-demo -d /ip4/127.0.0.1/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM
|
||||||
|
|
||||||
==== Starting Identify-Push Dialer on port 8889 ====
|
==== Starting Identify-Push Dialer on port 8889 ====
|
||||||
|
|
||||||
Dialer host ready!
|
Dialer host ready!
|
||||||
Listening on: /ip4/0.0.0.0/tcp/8889/p2p/QmZyXwVuTaBcDeRsSkJpOpWrSt
|
Listening on: /ip4/127.0.0.1/tcp/8889/p2p/QmZyXwVuTaBcDeRsSkJpOpWrSt
|
||||||
|
|
||||||
Connecting to peer: QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM
|
Connecting to peer: QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM
|
||||||
Successfully connected to listener!
|
Successfully connected to listener!
|
||||||
|
|||||||
@ -15,7 +15,7 @@ This example demonstrates how to create a chat application using libp2p's PubSub
|
|||||||
2025-04-06 23:59:17,471 - pubsub-demo - INFO - Your selected topic is: pubsub-chat
|
2025-04-06 23:59:17,471 - pubsub-demo - INFO - Your selected topic is: pubsub-chat
|
||||||
2025-04-06 23:59:17,472 - pubsub-demo - INFO - Using random available port: 33269
|
2025-04-06 23:59:17,472 - pubsub-demo - INFO - Using random available port: 33269
|
||||||
2025-04-06 23:59:17,490 - pubsub-demo - INFO - Node started with peer ID: QmcJnocH1d1tz3Zp4MotVDjNfNFawXHw2dpB9tMYGTXJp7
|
2025-04-06 23:59:17,490 - pubsub-demo - INFO - Node started with peer ID: QmcJnocH1d1tz3Zp4MotVDjNfNFawXHw2dpB9tMYGTXJp7
|
||||||
2025-04-06 23:59:17,490 - pubsub-demo - INFO - Listening on: /ip4/0.0.0.0/tcp/33269
|
2025-04-06 23:59:17,490 - pubsub-demo - INFO - Listening on: /ip4/127.0.0.1/tcp/33269
|
||||||
2025-04-06 23:59:17,490 - pubsub-demo - INFO - Initializing PubSub and GossipSub...
|
2025-04-06 23:59:17,490 - pubsub-demo - INFO - Initializing PubSub and GossipSub...
|
||||||
2025-04-06 23:59:17,491 - pubsub-demo - INFO - Pubsub and GossipSub services started.
|
2025-04-06 23:59:17,491 - pubsub-demo - INFO - Pubsub and GossipSub services started.
|
||||||
2025-04-06 23:59:17,491 - pubsub-demo - INFO - Pubsub ready.
|
2025-04-06 23:59:17,491 - pubsub-demo - INFO - Pubsub ready.
|
||||||
@ -35,7 +35,7 @@ Copy the line that starts with ``pubsub-demo -d``, open a new terminal and paste
|
|||||||
2025-04-07 00:00:59,846 - pubsub-demo - INFO - Your selected topic is: pubsub-chat
|
2025-04-07 00:00:59,846 - pubsub-demo - INFO - Your selected topic is: pubsub-chat
|
||||||
2025-04-07 00:00:59,846 - pubsub-demo - INFO - Using random available port: 51977
|
2025-04-07 00:00:59,846 - pubsub-demo - INFO - Using random available port: 51977
|
||||||
2025-04-07 00:00:59,864 - pubsub-demo - INFO - Node started with peer ID: QmYQKCm95Ut1aXsjHmWVYqdaVbno1eKTYC8KbEVjqUaKaQ
|
2025-04-07 00:00:59,864 - pubsub-demo - INFO - Node started with peer ID: QmYQKCm95Ut1aXsjHmWVYqdaVbno1eKTYC8KbEVjqUaKaQ
|
||||||
2025-04-07 00:00:59,864 - pubsub-demo - INFO - Listening on: /ip4/0.0.0.0/tcp/51977
|
2025-04-07 00:00:59,864 - pubsub-demo - INFO - Listening on: /ip4/127.0.0.1/tcp/51977
|
||||||
2025-04-07 00:00:59,864 - pubsub-demo - INFO - Initializing PubSub and GossipSub...
|
2025-04-07 00:00:59,864 - pubsub-demo - INFO - Initializing PubSub and GossipSub...
|
||||||
2025-04-07 00:00:59,864 - pubsub-demo - INFO - Pubsub and GossipSub services started.
|
2025-04-07 00:00:59,864 - pubsub-demo - INFO - Pubsub and GossipSub services started.
|
||||||
2025-04-07 00:00:59,865 - pubsub-demo - INFO - Pubsub ready.
|
2025-04-07 00:00:59,865 - pubsub-demo - INFO - Pubsub ready.
|
||||||
|
|||||||
@ -23,7 +23,7 @@ The Random Walk implementation performs the following key operations:
|
|||||||
2025-08-12 19:51:25,424 - random-walk-example - INFO - Mode: server, Port: 0 Demo interval: 30s
|
2025-08-12 19:51:25,424 - random-walk-example - INFO - Mode: server, Port: 0 Demo interval: 30s
|
||||||
2025-08-12 19:51:25,426 - random-walk-example - INFO - Starting server node on port 45123
|
2025-08-12 19:51:25,426 - random-walk-example - INFO - Starting server node on port 45123
|
||||||
2025-08-12 19:51:25,426 - random-walk-example - INFO - Node peer ID: 16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
|
2025-08-12 19:51:25,426 - random-walk-example - INFO - Node peer ID: 16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
|
||||||
2025-08-12 19:51:25,426 - random-walk-example - INFO - Node address: /ip4/0.0.0.0/tcp/45123/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
|
2025-08-12 19:51:25,426 - random-walk-example - INFO - Node address: /ip4/127.0.0.1/tcp/45123/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
|
||||||
2025-08-12 19:51:25,427 - random-walk-example - INFO - Initial routing table size: 0
|
2025-08-12 19:51:25,427 - random-walk-example - INFO - Initial routing table size: 0
|
||||||
2025-08-12 19:51:25,427 - random-walk-example - INFO - DHT service started in SERVER mode
|
2025-08-12 19:51:25,427 - random-walk-example - INFO - DHT service started in SERVER mode
|
||||||
2025-08-12 19:51:25,430 - libp2p.discovery.random_walk.rt_refresh_manager - INFO - RT Refresh Manager started
|
2025-08-12 19:51:25,430 - libp2p.discovery.random_walk.rt_refresh_manager - INFO - RT Refresh Manager started
|
||||||
|
|||||||
@ -9,6 +9,7 @@ Examples
|
|||||||
examples.identify_push
|
examples.identify_push
|
||||||
examples.chat
|
examples.chat
|
||||||
examples.echo
|
examples.echo
|
||||||
|
examples.echo_quic
|
||||||
examples.ping
|
examples.ping
|
||||||
examples.pubsub
|
examples.pubsub
|
||||||
examples.circuit_relay
|
examples.circuit_relay
|
||||||
|
|||||||
@ -28,6 +28,11 @@ For Python, the most common transport is TCP. Here's how to set up a basic TCP t
|
|||||||
.. literalinclude:: ../examples/doc-examples/example_transport.py
|
.. literalinclude:: ../examples/doc-examples/example_transport.py
|
||||||
:language: python
|
:language: python
|
||||||
|
|
||||||
|
Also, QUIC is a modern transport protocol that provides built-in TLS security and stream multiplexing over UDP:
|
||||||
|
|
||||||
|
.. literalinclude:: ../examples/doc-examples/example_quic_transport.py
|
||||||
|
:language: python
|
||||||
|
|
||||||
Connection Encryption
|
Connection Encryption
|
||||||
^^^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
|||||||
77
docs/libp2p.transport.quic.rst
Normal file
77
docs/libp2p.transport.quic.rst
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
libp2p.transport.quic package
|
||||||
|
=============================
|
||||||
|
|
||||||
|
Submodules
|
||||||
|
----------
|
||||||
|
|
||||||
|
libp2p.transport.quic.config module
|
||||||
|
-----------------------------------
|
||||||
|
|
||||||
|
.. automodule:: libp2p.transport.quic.config
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
|
libp2p.transport.quic.connection module
|
||||||
|
---------------------------------------
|
||||||
|
|
||||||
|
.. automodule:: libp2p.transport.quic.connection
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
|
libp2p.transport.quic.exceptions module
|
||||||
|
---------------------------------------
|
||||||
|
|
||||||
|
.. automodule:: libp2p.transport.quic.exceptions
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
|
libp2p.transport.quic.listener module
|
||||||
|
-------------------------------------
|
||||||
|
|
||||||
|
.. automodule:: libp2p.transport.quic.listener
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
|
libp2p.transport.quic.security module
|
||||||
|
-------------------------------------
|
||||||
|
|
||||||
|
.. automodule:: libp2p.transport.quic.security
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
|
libp2p.transport.quic.stream module
|
||||||
|
-----------------------------------
|
||||||
|
|
||||||
|
.. automodule:: libp2p.transport.quic.stream
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
|
libp2p.transport.quic.transport module
|
||||||
|
--------------------------------------
|
||||||
|
|
||||||
|
.. automodule:: libp2p.transport.quic.transport
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
|
libp2p.transport.quic.utils module
|
||||||
|
----------------------------------
|
||||||
|
|
||||||
|
.. automodule:: libp2p.transport.quic.utils
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
|
Module contents
|
||||||
|
---------------
|
||||||
|
|
||||||
|
.. automodule:: libp2p.transport.quic
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
||||||
@ -9,6 +9,11 @@ Subpackages
|
|||||||
|
|
||||||
libp2p.transport.tcp
|
libp2p.transport.tcp
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 4
|
||||||
|
|
||||||
|
libp2p.transport.quic
|
||||||
|
|
||||||
Submodules
|
Submodules
|
||||||
----------
|
----------
|
||||||
|
|
||||||
|
|||||||
@ -14,11 +14,26 @@ try:
|
|||||||
expand_wildcard_address,
|
expand_wildcard_address,
|
||||||
get_available_interfaces,
|
get_available_interfaces,
|
||||||
get_optimal_binding_address,
|
get_optimal_binding_address,
|
||||||
|
get_wildcard_address,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# Fallbacks if utilities are missing
|
# Fallbacks if utilities are missing - use minimal network discovery
|
||||||
|
import socket
|
||||||
|
|
||||||
def get_available_interfaces(port: int, protocol: str = "tcp"):
|
def get_available_interfaces(port: int, protocol: str = "tcp"):
|
||||||
return [Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")]
|
# Try to get local network interfaces, fallback to loopback
|
||||||
|
addrs = []
|
||||||
|
try:
|
||||||
|
# Get hostname IP (better than hardcoded localhost)
|
||||||
|
hostname = socket.gethostname()
|
||||||
|
local_ip = socket.gethostbyname(hostname)
|
||||||
|
if local_ip != "127.0.0.1":
|
||||||
|
addrs.append(Multiaddr(f"/ip4/{local_ip}/{protocol}/{port}"))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# Always include loopback as fallback
|
||||||
|
addrs.append(Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}"))
|
||||||
|
return addrs
|
||||||
|
|
||||||
def expand_wildcard_address(addr: Multiaddr, port: int | None = None):
|
def expand_wildcard_address(addr: Multiaddr, port: int | None = None):
|
||||||
if port is None:
|
if port is None:
|
||||||
@ -27,6 +42,15 @@ except ImportError:
|
|||||||
return [Multiaddr(addr_str + f"/{port}")]
|
return [Multiaddr(addr_str + f"/{port}")]
|
||||||
|
|
||||||
def get_optimal_binding_address(port: int, protocol: str = "tcp"):
|
def get_optimal_binding_address(port: int, protocol: str = "tcp"):
|
||||||
|
# Try to get a non-loopback address first
|
||||||
|
interfaces = get_available_interfaces(port, protocol)
|
||||||
|
for addr in interfaces:
|
||||||
|
if "127.0.0.1" not in str(addr):
|
||||||
|
return addr
|
||||||
|
# Fallback to loopback if no other interfaces found
|
||||||
|
return Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}")
|
||||||
|
|
||||||
|
def get_wildcard_address(port: int, protocol: str = "tcp"):
|
||||||
return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")
|
return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")
|
||||||
|
|
||||||
|
|
||||||
@ -37,7 +61,10 @@ def main() -> None:
|
|||||||
for a in interfaces:
|
for a in interfaces:
|
||||||
print(f" - {a}")
|
print(f" - {a}")
|
||||||
|
|
||||||
wildcard_v4 = Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
# Demonstrate wildcard address as a feature
|
||||||
|
wildcard_v4 = get_wildcard_address(port)
|
||||||
|
print(f"\nWildcard address (feature): {wildcard_v4}")
|
||||||
|
|
||||||
expanded_v4 = expand_wildcard_address(wildcard_v4)
|
expanded_v4 = expand_wildcard_address(wildcard_v4)
|
||||||
print("\nExpanded IPv4 wildcard:")
|
print("\nExpanded IPv4 wildcard:")
|
||||||
for a in expanded_v4:
|
for a in expanded_v4:
|
||||||
|
|||||||
@ -2,7 +2,6 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
import multiaddr
|
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
from libp2p import new_host
|
from libp2p import new_host
|
||||||
@ -54,18 +53,26 @@ BOOTSTRAP_PEERS = [
|
|||||||
|
|
||||||
async def run(port: int, bootstrap_addrs: list[str]) -> None:
|
async def run(port: int, bootstrap_addrs: list[str]) -> None:
|
||||||
"""Run the bootstrap discovery example."""
|
"""Run the bootstrap discovery example."""
|
||||||
|
from libp2p.utils.address_validation import (
|
||||||
|
find_free_port,
|
||||||
|
get_available_interfaces,
|
||||||
|
get_optimal_binding_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
if port <= 0:
|
||||||
|
port = find_free_port()
|
||||||
|
|
||||||
# Generate key pair
|
# Generate key pair
|
||||||
secret = secrets.token_bytes(32)
|
secret = secrets.token_bytes(32)
|
||||||
key_pair = create_new_key_pair(secret)
|
key_pair = create_new_key_pair(secret)
|
||||||
|
|
||||||
# Create listen address
|
# Create listen addresses for all available interfaces
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
listen_addrs = get_available_interfaces(port)
|
||||||
|
|
||||||
# Register peer discovery handler
|
# Register peer discovery handler
|
||||||
peerDiscovery.register_peer_discovered_handler(on_peer_discovery)
|
peerDiscovery.register_peer_discovered_handler(on_peer_discovery)
|
||||||
|
|
||||||
logger.info("🚀 Starting Bootstrap Discovery Example")
|
logger.info("🚀 Starting Bootstrap Discovery Example")
|
||||||
logger.info(f"📍 Listening on: {listen_addr}")
|
|
||||||
logger.info(f"🌐 Bootstrap peers: {len(bootstrap_addrs)}")
|
logger.info(f"🌐 Bootstrap peers: {len(bootstrap_addrs)}")
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
@ -80,7 +87,22 @@ async def run(port: int, bootstrap_addrs: list[str]) -> None:
|
|||||||
host = new_host(key_pair=key_pair, bootstrap=bootstrap_addrs)
|
host = new_host(key_pair=key_pair, bootstrap=bootstrap_addrs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with host.run(listen_addrs=[listen_addr]):
|
async with host.run(listen_addrs=listen_addrs):
|
||||||
|
# Get all available addresses with peer ID
|
||||||
|
all_addrs = host.get_addrs()
|
||||||
|
|
||||||
|
logger.info("Listener ready, listening on:")
|
||||||
|
print("Listener ready, listening on:")
|
||||||
|
for addr in all_addrs:
|
||||||
|
logger.info(f"{addr}")
|
||||||
|
print(f"{addr}")
|
||||||
|
|
||||||
|
# Display optimal address for reference
|
||||||
|
optimal_addr = get_optimal_binding_address(port)
|
||||||
|
optimal_addr_with_peer = f"{optimal_addr}/p2p/{host.get_id().to_string()}"
|
||||||
|
logger.info(f"Optimal address: {optimal_addr_with_peer}")
|
||||||
|
print(f"Optimal address: {optimal_addr_with_peer}")
|
||||||
|
|
||||||
# Keep running and log peer discovery events
|
# Keep running and log peer discovery events
|
||||||
await trio.sleep_forever()
|
await trio.sleep_forever()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@ -98,7 +120,7 @@ def main() -> None:
|
|||||||
Usage:
|
Usage:
|
||||||
python bootstrap.py -p 8000
|
python bootstrap.py -p 8000
|
||||||
python bootstrap.py -p 8001 --custom-bootstrap \\
|
python bootstrap.py -p 8001 --custom-bootstrap \\
|
||||||
"/ip4/127.0.0.1/tcp/8000/p2p/QmYourPeerID"
|
"/ip4/[HOST_IP]/tcp/8000/p2p/QmYourPeerID"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import multiaddr
|
import multiaddr
|
||||||
@ -17,6 +18,11 @@ from libp2p.peer.peerinfo import (
|
|||||||
info_from_p2p_addr,
|
info_from_p2p_addr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Configure minimal logging
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("multiaddr").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("libp2p").setLevel(logging.WARNING)
|
||||||
|
|
||||||
PROTOCOL_ID = TProtocol("/chat/1.0.0")
|
PROTOCOL_ID = TProtocol("/chat/1.0.0")
|
||||||
MAX_READ_LEN = 2**32 - 1
|
MAX_READ_LEN = 2**32 - 1
|
||||||
|
|
||||||
@ -40,9 +46,18 @@ async def write_data(stream: INetStream) -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def run(port: int, destination: str) -> None:
|
async def run(port: int, destination: str) -> None:
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
from libp2p.utils.address_validation import (
|
||||||
|
find_free_port,
|
||||||
|
get_available_interfaces,
|
||||||
|
get_optimal_binding_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
if port <= 0:
|
||||||
|
port = find_free_port()
|
||||||
|
|
||||||
|
listen_addrs = get_available_interfaces(port)
|
||||||
host = new_host()
|
host = new_host()
|
||||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery:
|
||||||
# Start the peer-store cleanup task
|
# Start the peer-store cleanup task
|
||||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||||
|
|
||||||
@ -54,10 +69,19 @@ async def run(port: int, destination: str) -> None:
|
|||||||
|
|
||||||
host.set_stream_handler(PROTOCOL_ID, stream_handler)
|
host.set_stream_handler(PROTOCOL_ID, stream_handler)
|
||||||
|
|
||||||
|
# Get all available addresses with peer ID
|
||||||
|
all_addrs = host.get_addrs()
|
||||||
|
|
||||||
|
print("Listener ready, listening on:\n")
|
||||||
|
for addr in all_addrs:
|
||||||
|
print(f"{addr}")
|
||||||
|
|
||||||
|
# Use optimal address for the client command
|
||||||
|
optimal_addr = get_optimal_binding_address(port)
|
||||||
|
optimal_addr_with_peer = f"{optimal_addr}/p2p/{host.get_id().to_string()}"
|
||||||
print(
|
print(
|
||||||
"Run this from the same folder in another console:\n\n"
|
f"\nRun this from the same folder in another console:\n\n"
|
||||||
f"chat-demo "
|
f"chat-demo -d {optimal_addr_with_peer}\n"
|
||||||
f"-d {host.get_addrs()[0]}\n"
|
|
||||||
)
|
)
|
||||||
print("Waiting for incoming connection...")
|
print("Waiting for incoming connection...")
|
||||||
|
|
||||||
@ -86,7 +110,7 @@ def main() -> None:
|
|||||||
where <DESTINATION> is the multiaddress of the previous listener host.
|
where <DESTINATION> is the multiaddress of the previous listener host.
|
||||||
"""
|
"""
|
||||||
example_maddr = (
|
example_maddr = (
|
||||||
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
"/ip4/[HOST_IP]/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
||||||
)
|
)
|
||||||
parser = argparse.ArgumentParser(description=description)
|
parser = argparse.ArgumentParser(description=description)
|
||||||
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
import multiaddr
|
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
from libp2p import (
|
from libp2p import (
|
||||||
@ -9,9 +8,10 @@ from libp2p import (
|
|||||||
from libp2p.crypto.secp256k1 import (
|
from libp2p.crypto.secp256k1 import (
|
||||||
create_new_key_pair,
|
create_new_key_pair,
|
||||||
)
|
)
|
||||||
from libp2p.security.insecure.transport import (
|
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||||
PLAINTEXT_PROTOCOL_ID,
|
from libp2p.utils.address_validation import (
|
||||||
InsecureTransport,
|
get_available_interfaces,
|
||||||
|
get_optimal_binding_address,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -38,17 +38,19 @@ async def main():
|
|||||||
# Create a host with the key pair and insecure transport
|
# Create a host with the key pair and insecure transport
|
||||||
host = new_host(key_pair=key_pair, sec_opt=security_options)
|
host = new_host(key_pair=key_pair, sec_opt=security_options)
|
||||||
|
|
||||||
# Configure the listening address
|
# Configure the listening address using the new paradigm
|
||||||
port = 8000
|
port = 8000
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
listen_addrs = get_available_interfaces(port)
|
||||||
|
optimal_addr = get_optimal_binding_address(port)
|
||||||
|
|
||||||
# Start the host
|
# Start the host
|
||||||
async with host.run(listen_addrs=[listen_addr]):
|
async with host.run(listen_addrs=listen_addrs):
|
||||||
print(
|
print(
|
||||||
"libp2p has started with insecure transport "
|
"libp2p has started with insecure transport "
|
||||||
"(not recommended for production)"
|
"(not recommended for production)"
|
||||||
)
|
)
|
||||||
print("libp2p is listening on:", host.get_addrs())
|
print("libp2p is listening on:", host.get_addrs())
|
||||||
|
print(f"Optimal address: {optimal_addr}")
|
||||||
# Keep the host running
|
# Keep the host running
|
||||||
await trio.sleep_forever()
|
await trio.sleep_forever()
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
import multiaddr
|
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
from libp2p import (
|
from libp2p import (
|
||||||
@ -13,6 +12,10 @@ from libp2p.security.noise.transport import (
|
|||||||
PROTOCOL_ID as NOISE_PROTOCOL_ID,
|
PROTOCOL_ID as NOISE_PROTOCOL_ID,
|
||||||
Transport as NoiseTransport,
|
Transport as NoiseTransport,
|
||||||
)
|
)
|
||||||
|
from libp2p.utils.address_validation import (
|
||||||
|
get_available_interfaces,
|
||||||
|
get_optimal_binding_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
@ -24,8 +27,13 @@ async def main():
|
|||||||
noise_transport = NoiseTransport(
|
noise_transport = NoiseTransport(
|
||||||
# local_key_pair: The key pair used for libp2p identity and authentication
|
# local_key_pair: The key pair used for libp2p identity and authentication
|
||||||
libp2p_keypair=key_pair,
|
libp2p_keypair=key_pair,
|
||||||
|
# noise_privkey: The private key used for Noise protocol encryption
|
||||||
noise_privkey=key_pair.private_key,
|
noise_privkey=key_pair.private_key,
|
||||||
# TODO: add early data
|
# early_data: Optional data to send during the handshake
|
||||||
|
# (None means no early data)
|
||||||
|
early_data=None,
|
||||||
|
# with_noise_pipes: Whether to use Noise pipes for additional security features
|
||||||
|
with_noise_pipes=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a security options dictionary mapping protocol ID to transport
|
# Create a security options dictionary mapping protocol ID to transport
|
||||||
@ -34,14 +42,16 @@ async def main():
|
|||||||
# Create a host with the key pair and Noise security
|
# Create a host with the key pair and Noise security
|
||||||
host = new_host(key_pair=key_pair, sec_opt=security_options)
|
host = new_host(key_pair=key_pair, sec_opt=security_options)
|
||||||
|
|
||||||
# Configure the listening address
|
# Configure the listening address using the new paradigm
|
||||||
port = 8000
|
port = 8000
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
listen_addrs = get_available_interfaces(port)
|
||||||
|
optimal_addr = get_optimal_binding_address(port)
|
||||||
|
|
||||||
# Start the host
|
# Start the host
|
||||||
async with host.run(listen_addrs=[listen_addr]):
|
async with host.run(listen_addrs=listen_addrs):
|
||||||
print("libp2p has started with Noise encryption")
|
print("libp2p has started with Noise encryption")
|
||||||
print("libp2p is listening on:", host.get_addrs())
|
print("libp2p is listening on:", host.get_addrs())
|
||||||
|
print(f"Optimal address: {optimal_addr}")
|
||||||
# Keep the host running
|
# Keep the host running
|
||||||
await trio.sleep_forever()
|
await trio.sleep_forever()
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
import multiaddr
|
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
from libp2p import (
|
from libp2p import (
|
||||||
@ -13,6 +12,10 @@ from libp2p.security.secio.transport import (
|
|||||||
ID as SECIO_PROTOCOL_ID,
|
ID as SECIO_PROTOCOL_ID,
|
||||||
Transport as SecioTransport,
|
Transport as SecioTransport,
|
||||||
)
|
)
|
||||||
|
from libp2p.utils.address_validation import (
|
||||||
|
get_available_interfaces,
|
||||||
|
get_optimal_binding_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
@ -32,14 +35,16 @@ async def main():
|
|||||||
# Create a host with the key pair and SECIO security
|
# Create a host with the key pair and SECIO security
|
||||||
host = new_host(key_pair=key_pair, sec_opt=security_options)
|
host = new_host(key_pair=key_pair, sec_opt=security_options)
|
||||||
|
|
||||||
# Configure the listening address
|
# Configure the listening address using the new paradigm
|
||||||
port = 8000
|
port = 8000
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
listen_addrs = get_available_interfaces(port)
|
||||||
|
optimal_addr = get_optimal_binding_address(port)
|
||||||
|
|
||||||
# Start the host
|
# Start the host
|
||||||
async with host.run(listen_addrs=[listen_addr]):
|
async with host.run(listen_addrs=listen_addrs):
|
||||||
print("libp2p has started with SECIO encryption")
|
print("libp2p has started with SECIO encryption")
|
||||||
print("libp2p is listening on:", host.get_addrs())
|
print("libp2p is listening on:", host.get_addrs())
|
||||||
|
print(f"Optimal address: {optimal_addr}")
|
||||||
# Keep the host running
|
# Keep the host running
|
||||||
await trio.sleep_forever()
|
await trio.sleep_forever()
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
import multiaddr
|
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
from libp2p import (
|
from libp2p import (
|
||||||
@ -13,6 +12,10 @@ from libp2p.security.noise.transport import (
|
|||||||
PROTOCOL_ID as NOISE_PROTOCOL_ID,
|
PROTOCOL_ID as NOISE_PROTOCOL_ID,
|
||||||
Transport as NoiseTransport,
|
Transport as NoiseTransport,
|
||||||
)
|
)
|
||||||
|
from libp2p.utils.address_validation import (
|
||||||
|
get_available_interfaces,
|
||||||
|
get_optimal_binding_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
@ -28,7 +31,9 @@ async def main():
|
|||||||
noise_privkey=key_pair.private_key,
|
noise_privkey=key_pair.private_key,
|
||||||
# early_data: Optional data to send during the handshake
|
# early_data: Optional data to send during the handshake
|
||||||
# (None means no early data)
|
# (None means no early data)
|
||||||
# TODO: add early data
|
early_data=None,
|
||||||
|
# with_noise_pipes: Whether to use Noise pipes for additional security features
|
||||||
|
with_noise_pipes=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a security options dictionary mapping protocol ID to transport
|
# Create a security options dictionary mapping protocol ID to transport
|
||||||
@ -37,14 +42,16 @@ async def main():
|
|||||||
# Create a host with the key pair, Noise security, and mplex multiplexer
|
# Create a host with the key pair, Noise security, and mplex multiplexer
|
||||||
host = new_host(key_pair=key_pair, sec_opt=security_options)
|
host = new_host(key_pair=key_pair, sec_opt=security_options)
|
||||||
|
|
||||||
# Configure the listening address
|
# Configure the listening address using the new paradigm
|
||||||
port = 8000
|
port = 8000
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
listen_addrs = get_available_interfaces(port)
|
||||||
|
optimal_addr = get_optimal_binding_address(port)
|
||||||
|
|
||||||
# Start the host
|
# Start the host
|
||||||
async with host.run(listen_addrs=[listen_addr]):
|
async with host.run(listen_addrs=listen_addrs):
|
||||||
print("libp2p has started with Noise encryption and mplex multiplexing")
|
print("libp2p has started with Noise encryption and mplex multiplexing")
|
||||||
print("libp2p is listening on:", host.get_addrs())
|
print("libp2p is listening on:", host.get_addrs())
|
||||||
|
print(f"Optimal address: {optimal_addr}")
|
||||||
# Keep the host running
|
# Keep the host running
|
||||||
await trio.sleep_forever()
|
await trio.sleep_forever()
|
||||||
|
|
||||||
|
|||||||
@ -38,6 +38,10 @@ from libp2p.network.stream.net_stream import (
|
|||||||
from libp2p.peer.peerinfo import (
|
from libp2p.peer.peerinfo import (
|
||||||
info_from_p2p_addr,
|
info_from_p2p_addr,
|
||||||
)
|
)
|
||||||
|
from libp2p.utils.address_validation import (
|
||||||
|
get_available_interfaces,
|
||||||
|
get_optimal_binding_address,
|
||||||
|
)
|
||||||
|
|
||||||
PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||||
|
|
||||||
@ -173,7 +177,9 @@ async def run_enhanced_demo(
|
|||||||
"""
|
"""
|
||||||
Run enhanced echo demo with NetStream state management.
|
Run enhanced echo demo with NetStream state management.
|
||||||
"""
|
"""
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
# Use the new address paradigm
|
||||||
|
listen_addrs = get_available_interfaces(port)
|
||||||
|
optimal_addr = get_optimal_binding_address(port)
|
||||||
|
|
||||||
# Generate or use provided key
|
# Generate or use provided key
|
||||||
if seed:
|
if seed:
|
||||||
@ -185,7 +191,7 @@ async def run_enhanced_demo(
|
|||||||
|
|
||||||
host = new_host(key_pair=create_new_key_pair(secret))
|
host = new_host(key_pair=create_new_key_pair(secret))
|
||||||
|
|
||||||
async with host.run(listen_addrs=[listen_addr]):
|
async with host.run(listen_addrs=listen_addrs):
|
||||||
print(f"Host ID: {host.get_id().to_string()}")
|
print(f"Host ID: {host.get_id().to_string()}")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
@ -196,10 +202,12 @@ async def run_enhanced_demo(
|
|||||||
# type: ignore: Stream is type of NetStream
|
# type: ignore: Stream is type of NetStream
|
||||||
host.set_stream_handler(PROTOCOL_ID, enhanced_echo_handler)
|
host.set_stream_handler(PROTOCOL_ID, enhanced_echo_handler)
|
||||||
|
|
||||||
|
# Use optimal address for client command
|
||||||
|
optimal_addr_with_peer = f"{optimal_addr}/p2p/{host.get_id().to_string()}"
|
||||||
print(
|
print(
|
||||||
"Run client from another console:\n"
|
"Run client from another console:\n"
|
||||||
f"python3 example_net_stream.py "
|
f"python3 example_net_stream.py "
|
||||||
f"-d {host.get_addrs()[0]}\n"
|
f"-d {optimal_addr_with_peer}\n"
|
||||||
)
|
)
|
||||||
print("Waiting for connections...")
|
print("Waiting for connections...")
|
||||||
print("Press Ctrl+C to stop server")
|
print("Press Ctrl+C to stop server")
|
||||||
@ -226,7 +234,7 @@ async def run_enhanced_demo(
|
|||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
example_maddr = (
|
example_maddr = (
|
||||||
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
"/ip4/[HOST_IP]/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
||||||
)
|
)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
import multiaddr
|
from multiaddr import Multiaddr
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
from libp2p import (
|
from libp2p import (
|
||||||
@ -16,6 +16,10 @@ from libp2p.security.noise.transport import (
|
|||||||
PROTOCOL_ID as NOISE_PROTOCOL_ID,
|
PROTOCOL_ID as NOISE_PROTOCOL_ID,
|
||||||
Transport as NoiseTransport,
|
Transport as NoiseTransport,
|
||||||
)
|
)
|
||||||
|
from libp2p.utils.address_validation import (
|
||||||
|
get_available_interfaces,
|
||||||
|
get_optimal_binding_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
@ -31,7 +35,9 @@ async def main():
|
|||||||
noise_privkey=key_pair.private_key,
|
noise_privkey=key_pair.private_key,
|
||||||
# early_data: Optional data to send during the handshake
|
# early_data: Optional data to send during the handshake
|
||||||
# (None means no early data)
|
# (None means no early data)
|
||||||
# TODO: add early data
|
early_data=None,
|
||||||
|
# with_noise_pipes: Whether to use Noise pipes for additional security features
|
||||||
|
with_noise_pipes=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a security options dictionary mapping protocol ID to transport
|
# Create a security options dictionary mapping protocol ID to transport
|
||||||
@ -40,14 +46,16 @@ async def main():
|
|||||||
# Create a host with the key pair, Noise security, and mplex multiplexer
|
# Create a host with the key pair, Noise security, and mplex multiplexer
|
||||||
host = new_host(key_pair=key_pair, sec_opt=security_options)
|
host = new_host(key_pair=key_pair, sec_opt=security_options)
|
||||||
|
|
||||||
# Configure the listening address
|
# Configure the listening address using the new paradigm
|
||||||
port = 8000
|
port = 8000
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
listen_addrs = get_available_interfaces(port)
|
||||||
|
optimal_addr = get_optimal_binding_address(port)
|
||||||
|
|
||||||
# Start the host
|
# Start the host
|
||||||
async with host.run(listen_addrs=[listen_addr]):
|
async with host.run(listen_addrs=listen_addrs):
|
||||||
print("libp2p has started")
|
print("libp2p has started")
|
||||||
print("libp2p is listening on:", host.get_addrs())
|
print("libp2p is listening on:", host.get_addrs())
|
||||||
|
print(f"Optimal address: {optimal_addr}")
|
||||||
|
|
||||||
# Connect to bootstrap peers manually
|
# Connect to bootstrap peers manually
|
||||||
bootstrap_list = [
|
bootstrap_list = [
|
||||||
@ -59,7 +67,7 @@ async def main():
|
|||||||
|
|
||||||
for addr in bootstrap_list:
|
for addr in bootstrap_list:
|
||||||
try:
|
try:
|
||||||
peer_info = info_from_p2p_addr(multiaddr.Multiaddr(addr))
|
peer_info = info_from_p2p_addr(Multiaddr(addr))
|
||||||
await host.connect(peer_info)
|
await host.connect(peer_info)
|
||||||
print(f"Connected to {peer_info.peer_id.to_string()}")
|
print(f"Connected to {peer_info.peer_id.to_string()}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
49
examples/doc-examples/example_quic_transport.py
Normal file
49
examples/doc-examples/example_quic_transport.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import secrets
|
||||||
|
|
||||||
|
import trio
|
||||||
|
|
||||||
|
from libp2p import (
|
||||||
|
new_host,
|
||||||
|
)
|
||||||
|
from libp2p.crypto.secp256k1 import (
|
||||||
|
create_new_key_pair,
|
||||||
|
)
|
||||||
|
from libp2p.utils.address_validation import (
|
||||||
|
get_available_interfaces,
|
||||||
|
get_optimal_binding_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
# Create a key pair for the host
|
||||||
|
secret = secrets.token_bytes(32)
|
||||||
|
key_pair = create_new_key_pair(secret)
|
||||||
|
|
||||||
|
# Create a host with the key pair
|
||||||
|
host = new_host(key_pair=key_pair, enable_quic=True)
|
||||||
|
|
||||||
|
# Configure the listening address using the new paradigm
|
||||||
|
port = 8000
|
||||||
|
listen_addrs = get_available_interfaces(port, protocol="udp")
|
||||||
|
# Convert TCP addresses to QUIC-v1 addresses
|
||||||
|
quic_addrs = []
|
||||||
|
for addr in listen_addrs:
|
||||||
|
addr_str = str(addr).replace("/tcp/", "/udp/") + "/quic-v1"
|
||||||
|
from multiaddr import Multiaddr
|
||||||
|
|
||||||
|
quic_addrs.append(Multiaddr(addr_str))
|
||||||
|
|
||||||
|
optimal_addr = get_optimal_binding_address(port, protocol="udp")
|
||||||
|
optimal_quic_str = str(optimal_addr).replace("/tcp/", "/udp/") + "/quic-v1"
|
||||||
|
|
||||||
|
# Start the host
|
||||||
|
async with host.run(listen_addrs=quic_addrs):
|
||||||
|
print("libp2p has started with QUIC transport")
|
||||||
|
print("libp2p is listening on:", host.get_addrs())
|
||||||
|
print(f"Optimal address: {optimal_quic_str}")
|
||||||
|
# Keep the host running
|
||||||
|
await trio.sleep_forever()
|
||||||
|
|
||||||
|
|
||||||
|
# Run the async function
|
||||||
|
trio.run(main)
|
||||||
@ -1,6 +1,5 @@
|
|||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
import multiaddr
|
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
from libp2p import (
|
from libp2p import (
|
||||||
@ -13,6 +12,10 @@ from libp2p.security.noise.transport import (
|
|||||||
PROTOCOL_ID as NOISE_PROTOCOL_ID,
|
PROTOCOL_ID as NOISE_PROTOCOL_ID,
|
||||||
Transport as NoiseTransport,
|
Transport as NoiseTransport,
|
||||||
)
|
)
|
||||||
|
from libp2p.utils.address_validation import (
|
||||||
|
get_available_interfaces,
|
||||||
|
get_optimal_binding_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
@ -28,7 +31,9 @@ async def main():
|
|||||||
noise_privkey=key_pair.private_key,
|
noise_privkey=key_pair.private_key,
|
||||||
# early_data: Optional data to send during the handshake
|
# early_data: Optional data to send during the handshake
|
||||||
# (None means no early data)
|
# (None means no early data)
|
||||||
# TODO: add early data
|
early_data=None,
|
||||||
|
# with_noise_pipes: Whether to use Noise pipes for additional security features
|
||||||
|
with_noise_pipes=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a security options dictionary mapping protocol ID to transport
|
# Create a security options dictionary mapping protocol ID to transport
|
||||||
@ -37,14 +42,16 @@ async def main():
|
|||||||
# Create a host with the key pair, Noise security, and mplex multiplexer
|
# Create a host with the key pair, Noise security, and mplex multiplexer
|
||||||
host = new_host(key_pair=key_pair, sec_opt=security_options)
|
host = new_host(key_pair=key_pair, sec_opt=security_options)
|
||||||
|
|
||||||
# Configure the listening address
|
# Configure the listening address using the new paradigm
|
||||||
port = 8000
|
port = 8000
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
listen_addrs = get_available_interfaces(port)
|
||||||
|
optimal_addr = get_optimal_binding_address(port)
|
||||||
|
|
||||||
# Start the host
|
# Start the host
|
||||||
async with host.run(listen_addrs=[listen_addr]):
|
async with host.run(listen_addrs=listen_addrs):
|
||||||
print("libp2p has started")
|
print("libp2p has started")
|
||||||
print("libp2p is listening on:", host.get_addrs())
|
print("libp2p is listening on:", host.get_addrs())
|
||||||
|
print(f"Optimal address: {optimal_addr}")
|
||||||
# Keep the host running
|
# Keep the host running
|
||||||
await trio.sleep_forever()
|
await trio.sleep_forever()
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
import multiaddr
|
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
from libp2p import (
|
from libp2p import (
|
||||||
@ -9,6 +8,10 @@ from libp2p import (
|
|||||||
from libp2p.crypto.secp256k1 import (
|
from libp2p.crypto.secp256k1 import (
|
||||||
create_new_key_pair,
|
create_new_key_pair,
|
||||||
)
|
)
|
||||||
|
from libp2p.utils.address_validation import (
|
||||||
|
get_available_interfaces,
|
||||||
|
get_optimal_binding_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
@ -19,14 +22,16 @@ async def main():
|
|||||||
# Create a host with the key pair
|
# Create a host with the key pair
|
||||||
host = new_host(key_pair=key_pair)
|
host = new_host(key_pair=key_pair)
|
||||||
|
|
||||||
# Configure the listening address
|
# Configure the listening address using the new paradigm
|
||||||
port = 8000
|
port = 8000
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
listen_addrs = get_available_interfaces(port)
|
||||||
|
optimal_addr = get_optimal_binding_address(port)
|
||||||
|
|
||||||
# Start the host
|
# Start the host
|
||||||
async with host.run(listen_addrs=[listen_addr]):
|
async with host.run(listen_addrs=listen_addrs):
|
||||||
print("libp2p has started with TCP transport")
|
print("libp2p has started with TCP transport")
|
||||||
print("libp2p is listening on:", host.get_addrs())
|
print("libp2p is listening on:", host.get_addrs())
|
||||||
|
print(f"Optimal address: {optimal_addr}")
|
||||||
# Keep the host running
|
# Keep the host running
|
||||||
await trio.sleep_forever()
|
await trio.sleep_forever()
|
||||||
|
|
||||||
|
|||||||
@ -7,6 +7,7 @@ This example shows how to:
|
|||||||
2. Use different load balancing strategies
|
2. Use different load balancing strategies
|
||||||
3. Access multiple connections through the new API
|
3. Access multiple connections through the new API
|
||||||
4. Maintain backward compatibility
|
4. Maintain backward compatibility
|
||||||
|
5. Use the new address paradigm for network configuration
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@ -15,6 +16,7 @@ import trio
|
|||||||
|
|
||||||
from libp2p import new_swarm
|
from libp2p import new_swarm
|
||||||
from libp2p.network.swarm import ConnectionConfig, RetryConfig
|
from libp2p.network.swarm import ConnectionConfig, RetryConfig
|
||||||
|
from libp2p.utils import get_available_interfaces, get_optimal_binding_address
|
||||||
|
|
||||||
# Set up logging
|
# Set up logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@ -103,10 +105,45 @@ async def example_backward_compatibility() -> None:
|
|||||||
logger.info("Backward compatibility example completed")
|
logger.info("Backward compatibility example completed")
|
||||||
|
|
||||||
|
|
||||||
|
async def example_network_address_paradigm() -> None:
|
||||||
|
"""Example of using the new address paradigm with multiple connections."""
|
||||||
|
logger.info("Demonstrating network address paradigm...")
|
||||||
|
|
||||||
|
# Get available interfaces using the new paradigm
|
||||||
|
port = 8000 # Example port
|
||||||
|
available_interfaces = get_available_interfaces(port)
|
||||||
|
logger.info(f"Available interfaces: {available_interfaces}")
|
||||||
|
|
||||||
|
# Get optimal binding address
|
||||||
|
optimal_address = get_optimal_binding_address(port)
|
||||||
|
logger.info(f"Optimal binding address: {optimal_address}")
|
||||||
|
|
||||||
|
# Create connection config for multiple connections with network awareness
|
||||||
|
connection_config = ConnectionConfig(
|
||||||
|
max_connections_per_peer=3, load_balancing_strategy="round_robin"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create swarm with address paradigm
|
||||||
|
swarm = new_swarm(connection_config=connection_config)
|
||||||
|
|
||||||
|
logger.info("Network address paradigm features:")
|
||||||
|
logger.info(" - get_available_interfaces() for interface discovery")
|
||||||
|
logger.info(" - get_optimal_binding_address() for smart address selection")
|
||||||
|
logger.info(" - Multiple connections with proper network binding")
|
||||||
|
|
||||||
|
await swarm.close()
|
||||||
|
logger.info("Network address paradigm example completed")
|
||||||
|
|
||||||
|
|
||||||
async def example_production_ready_config() -> None:
|
async def example_production_ready_config() -> None:
|
||||||
"""Example of production-ready configuration."""
|
"""Example of production-ready configuration."""
|
||||||
logger.info("Creating swarm with production-ready configuration...")
|
logger.info("Creating swarm with production-ready configuration...")
|
||||||
|
|
||||||
|
# Get optimal network configuration using the new paradigm
|
||||||
|
port = 8001 # Example port
|
||||||
|
optimal_address = get_optimal_binding_address(port)
|
||||||
|
logger.info(f"Using optimal binding address: {optimal_address}")
|
||||||
|
|
||||||
# Production-ready retry configuration
|
# Production-ready retry configuration
|
||||||
retry_config = RetryConfig(
|
retry_config = RetryConfig(
|
||||||
max_retries=3, # Reasonable retry limit
|
max_retries=3, # Reasonable retry limit
|
||||||
@ -156,6 +193,9 @@ async def main() -> None:
|
|||||||
await example_backward_compatibility()
|
await example_backward_compatibility()
|
||||||
logger.info("-" * 30)
|
logger.info("-" * 30)
|
||||||
|
|
||||||
|
await example_network_address_paradigm()
|
||||||
|
logger.info("-" * 30)
|
||||||
|
|
||||||
await example_production_ready_config()
|
await example_production_ready_config()
|
||||||
logger.info("-" * 30)
|
logger.info("-" * 30)
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
import random
|
import random
|
||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
@ -26,8 +27,14 @@ from libp2p.peer.peerinfo import (
|
|||||||
from libp2p.utils.address_validation import (
|
from libp2p.utils.address_validation import (
|
||||||
find_free_port,
|
find_free_port,
|
||||||
get_available_interfaces,
|
get_available_interfaces,
|
||||||
|
get_optimal_binding_address,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Configure minimal logging
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("multiaddr").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("libp2p").setLevel(logging.WARNING)
|
||||||
|
|
||||||
PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||||
MAX_READ_LEN = 2**32 - 1
|
MAX_READ_LEN = 2**32 - 1
|
||||||
|
|
||||||
@ -76,9 +83,13 @@ async def run(port: int, destination: str, seed: int | None = None) -> None:
|
|||||||
for addr in listen_addr:
|
for addr in listen_addr:
|
||||||
print(f"{addr}/p2p/{peer_id}")
|
print(f"{addr}/p2p/{peer_id}")
|
||||||
|
|
||||||
|
# Get optimal address for display
|
||||||
|
optimal_addr = get_optimal_binding_address(port)
|
||||||
|
optimal_addr_with_peer = f"{optimal_addr}/p2p/{peer_id}"
|
||||||
|
|
||||||
print(
|
print(
|
||||||
"\nRun this from the same folder in another console:\n\n"
|
"\nRun this from the same folder in another console:\n\n"
|
||||||
f"echo-demo -d {host.get_addrs()[0]}\n"
|
f"echo-demo -d {optimal_addr_with_peer}\n"
|
||||||
)
|
)
|
||||||
print("Waiting for incoming connections...")
|
print("Waiting for incoming connections...")
|
||||||
await trio.sleep_forever()
|
await trio.sleep_forever()
|
||||||
@ -114,7 +125,7 @@ def main() -> None:
|
|||||||
where <DESTINATION> is the multiaddress of the previous listener host.
|
where <DESTINATION> is the multiaddress of the previous listener host.
|
||||||
"""
|
"""
|
||||||
example_maddr = (
|
example_maddr = (
|
||||||
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
"/ip4/[HOST_IP]/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
||||||
)
|
)
|
||||||
parser = argparse.ArgumentParser(description=description)
|
parser = argparse.ArgumentParser(description=description)
|
||||||
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||||
|
|||||||
207
examples/echo/echo_quic.py
Normal file
207
examples/echo/echo_quic.py
Normal file
@ -0,0 +1,207 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
QUIC Echo Example - Fixed version with proper client/server separation
|
||||||
|
|
||||||
|
This program demonstrates a simple echo protocol using QUIC transport where a peer
|
||||||
|
listens for connections and copies back any input received on a stream.
|
||||||
|
|
||||||
|
Fixed to properly separate client and server modes - clients don't start listeners.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from multiaddr import Multiaddr
|
||||||
|
import trio
|
||||||
|
|
||||||
|
from libp2p import new_host
|
||||||
|
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||||
|
from libp2p.custom_types import TProtocol
|
||||||
|
from libp2p.network.stream.net_stream import INetStream
|
||||||
|
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||||
|
|
||||||
|
# Configure minimal logging
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("multiaddr").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("libp2p").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||||
|
|
||||||
|
|
||||||
|
async def _echo_stream_handler(stream: INetStream) -> None:
|
||||||
|
try:
|
||||||
|
msg = await stream.read()
|
||||||
|
await stream.write(msg)
|
||||||
|
await stream.close()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Echo handler error: {e}")
|
||||||
|
try:
|
||||||
|
await stream.close()
|
||||||
|
except: # noqa: E722
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def run_server(port: int, seed: int | None = None) -> None:
|
||||||
|
"""Run echo server with QUIC transport."""
|
||||||
|
from libp2p.utils.address_validation import (
|
||||||
|
find_free_port,
|
||||||
|
get_available_interfaces,
|
||||||
|
get_optimal_binding_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
if port <= 0:
|
||||||
|
port = find_free_port()
|
||||||
|
|
||||||
|
# For QUIC, we need UDP addresses - use the new address paradigm
|
||||||
|
tcp_addrs = get_available_interfaces(port)
|
||||||
|
# Convert TCP addresses to QUIC addresses
|
||||||
|
quic_addrs = []
|
||||||
|
for addr in tcp_addrs:
|
||||||
|
addr_str = str(addr).replace("/tcp/", "/udp/") + "/quic"
|
||||||
|
quic_addrs.append(Multiaddr(addr_str))
|
||||||
|
|
||||||
|
if seed:
|
||||||
|
import random
|
||||||
|
|
||||||
|
random.seed(seed)
|
||||||
|
secret_number = random.getrandbits(32 * 8)
|
||||||
|
secret = secret_number.to_bytes(length=32, byteorder="big")
|
||||||
|
else:
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
secret = secrets.token_bytes(32)
|
||||||
|
|
||||||
|
# Create host with QUIC transport
|
||||||
|
host = new_host(
|
||||||
|
enable_quic=True,
|
||||||
|
key_pair=create_new_key_pair(secret),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Server mode: start listener
|
||||||
|
async with host.run(listen_addrs=quic_addrs):
|
||||||
|
try:
|
||||||
|
print(f"I am {host.get_id().to_string()}")
|
||||||
|
host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler)
|
||||||
|
|
||||||
|
# Get all available addresses with peer ID
|
||||||
|
all_addrs = host.get_addrs()
|
||||||
|
|
||||||
|
print("Listener ready, listening on:")
|
||||||
|
for addr in all_addrs:
|
||||||
|
print(f"{addr}")
|
||||||
|
|
||||||
|
# Use optimal address for the client command
|
||||||
|
optimal_tcp = get_optimal_binding_address(port)
|
||||||
|
optimal_quic_str = str(optimal_tcp).replace("/tcp/", "/udp/") + "/quic"
|
||||||
|
peer_id = host.get_id().to_string()
|
||||||
|
optimal_quic_with_peer = f"{optimal_quic_str}/p2p/{peer_id}"
|
||||||
|
print(
|
||||||
|
f"\nRun this from the same folder in another console:\n\n"
|
||||||
|
f"python3 ./examples/echo/echo_quic.py -d {optimal_quic_with_peer}\n"
|
||||||
|
)
|
||||||
|
print("Waiting for incoming QUIC connections...")
|
||||||
|
await trio.sleep_forever()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("Closing server gracefully...")
|
||||||
|
await host.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
async def run_client(destination: str, seed: int | None = None) -> None:
|
||||||
|
"""Run echo client with QUIC transport."""
|
||||||
|
if seed:
|
||||||
|
import random
|
||||||
|
|
||||||
|
random.seed(seed)
|
||||||
|
secret_number = random.getrandbits(32 * 8)
|
||||||
|
secret = secret_number.to_bytes(length=32, byteorder="big")
|
||||||
|
else:
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
secret = secrets.token_bytes(32)
|
||||||
|
|
||||||
|
# Create host with QUIC transport
|
||||||
|
host = new_host(
|
||||||
|
enable_quic=True,
|
||||||
|
key_pair=create_new_key_pair(secret),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Client mode: NO listener, just connect
|
||||||
|
async with host.run(listen_addrs=[]): # Empty listen_addrs for client
|
||||||
|
print(f"I am {host.get_id().to_string()}")
|
||||||
|
|
||||||
|
maddr = Multiaddr(destination)
|
||||||
|
info = info_from_p2p_addr(maddr)
|
||||||
|
|
||||||
|
# Connect to server
|
||||||
|
print("STARTING CLIENT CONNECTION PROCESS")
|
||||||
|
await host.connect(info)
|
||||||
|
print("CLIENT CONNECTED TO SERVER")
|
||||||
|
|
||||||
|
# Start a stream with the destination
|
||||||
|
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
|
||||||
|
|
||||||
|
msg = b"hi, there!\n"
|
||||||
|
|
||||||
|
await stream.write(msg)
|
||||||
|
response = await stream.read()
|
||||||
|
|
||||||
|
print(f"Sent: {msg.decode('utf-8')}")
|
||||||
|
print(f"Got: {response.decode('utf-8')}")
|
||||||
|
await stream.close()
|
||||||
|
await host.disconnect(info.peer_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def run(port: int, destination: str, seed: int | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Run echo server or client with QUIC transport.
|
||||||
|
|
||||||
|
Fixed version that properly separates client and server modes.
|
||||||
|
"""
|
||||||
|
if not destination: # Server mode
|
||||||
|
await run_server(port, seed)
|
||||||
|
else: # Client mode
|
||||||
|
await run_client(destination, seed)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""Main function - help text updated for QUIC."""
|
||||||
|
description = """
|
||||||
|
This program demonstrates a simple echo protocol using QUIC
|
||||||
|
transport where a peer listens for connections and copies back
|
||||||
|
any input received on a stream.
|
||||||
|
|
||||||
|
QUIC provides built-in TLS security and stream multiplexing over UDP.
|
||||||
|
|
||||||
|
To use it, first run 'echo-quic-demo -p <PORT>', where <PORT> is
|
||||||
|
the UDP port number. Then, run another host with ,
|
||||||
|
'echo-quic-demo -d <DESTINATION>'
|
||||||
|
where <DESTINATION> is the QUIC multiaddress of the previous listener host.
|
||||||
|
"""
|
||||||
|
|
||||||
|
example_maddr = "/ip4/[HOST_IP]/udp/8000/quic/p2p/QmQn4SwGkDZKkUEpBRBv"
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description=description)
|
||||||
|
parser.add_argument("-p", "--port", default=0, type=int, help="UDP port number")
|
||||||
|
parser.add_argument(
|
||||||
|
"-d",
|
||||||
|
"--destination",
|
||||||
|
type=str,
|
||||||
|
help=f"destination multiaddr string, e.g. {example_maddr}",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-s",
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
help="provide a seed to the random number generator",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
trio.run(run, args.port, args.destination, args.seed)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -20,6 +20,11 @@ from libp2p.peer.peerinfo import (
|
|||||||
info_from_p2p_addr,
|
info_from_p2p_addr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Configure minimal logging
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("multiaddr").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("libp2p").setLevel(logging.WARNING)
|
||||||
|
|
||||||
logger = logging.getLogger("libp2p.identity.identify-example")
|
logger = logging.getLogger("libp2p.identity.identify-example")
|
||||||
|
|
||||||
|
|
||||||
@ -58,11 +63,19 @@ def print_identify_response(identify_response: Identify):
|
|||||||
|
|
||||||
|
|
||||||
async def run(port: int, destination: str, use_varint_format: bool = True) -> None:
|
async def run(port: int, destination: str, use_varint_format: bool = True) -> None:
|
||||||
localhost_ip = "0.0.0.0"
|
from libp2p.utils.address_validation import (
|
||||||
|
get_available_interfaces,
|
||||||
|
get_optimal_binding_address,
|
||||||
|
)
|
||||||
|
|
||||||
if not destination:
|
if not destination:
|
||||||
# Create first host (listener)
|
# Create first host (listener)
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}")
|
if port <= 0:
|
||||||
|
from libp2p.utils.address_validation import find_free_port
|
||||||
|
|
||||||
|
port = find_free_port()
|
||||||
|
|
||||||
|
listen_addrs = get_available_interfaces(port)
|
||||||
host_a = new_host()
|
host_a = new_host()
|
||||||
|
|
||||||
# Set up identify handler with specified format
|
# Set up identify handler with specified format
|
||||||
@ -73,25 +86,49 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
|
|||||||
host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, identify_handler)
|
host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, identify_handler)
|
||||||
|
|
||||||
async with (
|
async with (
|
||||||
host_a.run(listen_addrs=[listen_addr]),
|
host_a.run(listen_addrs=listen_addrs),
|
||||||
trio.open_nursery() as nursery,
|
trio.open_nursery() as nursery,
|
||||||
):
|
):
|
||||||
# Start the peer-store cleanup task
|
# Start the peer-store cleanup task
|
||||||
nursery.start_soon(host_a.get_peerstore().start_cleanup_task, 60)
|
nursery.start_soon(host_a.get_peerstore().start_cleanup_task, 60)
|
||||||
|
|
||||||
# Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client
|
# Get all available addresses with peer ID
|
||||||
# connections
|
all_addrs = host_a.get_addrs()
|
||||||
server_addr = str(host_a.get_addrs()[0])
|
|
||||||
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
|
|
||||||
|
|
||||||
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
|
if use_varint_format:
|
||||||
format_flag = "--raw-format" if not use_varint_format else ""
|
format_name = "length-prefixed"
|
||||||
print(
|
print(f"First host listening (using {format_name} format).")
|
||||||
f"First host listening (using {format_name} format). "
|
print("Listener ready, listening on:\n")
|
||||||
f"Run this from another console:\n\n"
|
for addr in all_addrs:
|
||||||
f"identify-demo {format_flag} -d {client_addr}\n"
|
print(f"{addr}")
|
||||||
)
|
|
||||||
print("Waiting for incoming identify request...")
|
# Use optimal address for the client command
|
||||||
|
optimal_addr = get_optimal_binding_address(port)
|
||||||
|
optimal_addr_with_peer = (
|
||||||
|
f"{optimal_addr}/p2p/{host_a.get_id().to_string()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"\nRun this from the same folder in another console:\n\n"
|
||||||
|
f"identify-demo -d {optimal_addr_with_peer}\n"
|
||||||
|
)
|
||||||
|
print("Waiting for incoming identify request...")
|
||||||
|
else:
|
||||||
|
format_name = "raw protobuf"
|
||||||
|
print(f"First host listening (using {format_name} format).")
|
||||||
|
print("Listener ready, listening on:\n")
|
||||||
|
for addr in all_addrs:
|
||||||
|
print(f"{addr}")
|
||||||
|
|
||||||
|
# Use optimal address for the client command
|
||||||
|
optimal_addr = get_optimal_binding_address(port)
|
||||||
|
optimal_addr_with_peer = (
|
||||||
|
f"{optimal_addr}/p2p/{host_a.get_id().to_string()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"\nRun this from the same folder in another console:\n\n"
|
||||||
|
f"identify-demo -d {optimal_addr_with_peer}\n"
|
||||||
|
)
|
||||||
|
print("Waiting for incoming identify request...")
|
||||||
|
|
||||||
# Add a custom handler to show connection events
|
# Add a custom handler to show connection events
|
||||||
async def custom_identify_handler(stream):
|
async def custom_identify_handler(stream):
|
||||||
@ -134,11 +171,20 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
# Create second host (dialer)
|
# Create second host (dialer)
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}")
|
from libp2p.utils.address_validation import (
|
||||||
|
find_free_port,
|
||||||
|
get_available_interfaces,
|
||||||
|
get_optimal_binding_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
if port <= 0:
|
||||||
|
port = find_free_port()
|
||||||
|
|
||||||
|
listen_addrs = get_available_interfaces(port)
|
||||||
host_b = new_host()
|
host_b = new_host()
|
||||||
|
|
||||||
async with (
|
async with (
|
||||||
host_b.run(listen_addrs=[listen_addr]),
|
host_b.run(listen_addrs=listen_addrs),
|
||||||
trio.open_nursery() as nursery,
|
trio.open_nursery() as nursery,
|
||||||
):
|
):
|
||||||
# Start the peer-store cleanup task
|
# Start the peer-store cleanup task
|
||||||
@ -234,7 +280,7 @@ def main() -> None:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
example_maddr = (
|
example_maddr = (
|
||||||
"/ip4/127.0.0.1/tcp/8888/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
"/ip4/[HOST_IP]/tcp/8888/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
||||||
)
|
)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description=description)
|
parser = argparse.ArgumentParser(description=description)
|
||||||
@ -258,7 +304,7 @@ def main() -> None:
|
|||||||
|
|
||||||
# Determine format: use varint (length-prefixed) if --raw-format is specified,
|
# Determine format: use varint (length-prefixed) if --raw-format is specified,
|
||||||
# otherwise use raw protobuf format (old format)
|
# otherwise use raw protobuf format (old format)
|
||||||
use_varint_format = args.raw_format
|
use_varint_format = not args.raw_format
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if args.destination:
|
if args.destination:
|
||||||
|
|||||||
@ -36,6 +36,9 @@ from libp2p.identity.identify_push import (
|
|||||||
from libp2p.peer.peerinfo import (
|
from libp2p.peer.peerinfo import (
|
||||||
info_from_p2p_addr,
|
info_from_p2p_addr,
|
||||||
)
|
)
|
||||||
|
from libp2p.utils.address_validation import (
|
||||||
|
get_available_interfaces,
|
||||||
|
)
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -207,13 +210,13 @@ async def main() -> None:
|
|||||||
ID_PUSH, create_custom_identify_push_handler(host_2, "Host 2")
|
ID_PUSH, create_custom_identify_push_handler(host_2, "Host 2")
|
||||||
)
|
)
|
||||||
|
|
||||||
# Start listening on random ports using the run context manager
|
# Start listening on available interfaces using random ports
|
||||||
listen_addr_1 = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
|
listen_addrs_1 = get_available_interfaces(0) # 0 for random port
|
||||||
listen_addr_2 = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
|
listen_addrs_2 = get_available_interfaces(0) # 0 for random port
|
||||||
|
|
||||||
async with (
|
async with (
|
||||||
host_1.run([listen_addr_1]),
|
host_1.run(listen_addrs_1),
|
||||||
host_2.run([listen_addr_2]),
|
host_2.run(listen_addrs_2),
|
||||||
trio.open_nursery() as nursery,
|
trio.open_nursery() as nursery,
|
||||||
):
|
):
|
||||||
# Start the peer-store cleanup task
|
# Start the peer-store cleanup task
|
||||||
|
|||||||
@ -14,7 +14,7 @@ Usage:
|
|||||||
python identify_push_listener_dialer.py
|
python identify_push_listener_dialer.py
|
||||||
|
|
||||||
# Then in another console, run as a dialer (default port 8889):
|
# Then in another console, run as a dialer (default port 8889):
|
||||||
python identify_push_listener_dialer.py -d /ip4/127.0.0.1/tcp/8888/p2p/PEER_ID
|
python identify_push_listener_dialer.py -d /ip4/[HOST_IP]/tcp/8888/p2p/PEER_ID
|
||||||
(where PEER_ID is the peer ID displayed by the listener)
|
(where PEER_ID is the peer ID displayed by the listener)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -56,6 +56,11 @@ from libp2p.peer.peerinfo import (
|
|||||||
info_from_p2p_addr,
|
info_from_p2p_addr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Configure minimal logging
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("multiaddr").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("libp2p").setLevel(logging.WARNING)
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logger = logging.getLogger("libp2p.identity.identify-push-example")
|
logger = logging.getLogger("libp2p.identity.identify-push-example")
|
||||||
|
|
||||||
@ -194,6 +199,11 @@ async def run_listener(
|
|||||||
port: int, use_varint_format: bool = True, raw_format_flag: bool = False
|
port: int, use_varint_format: bool = True, raw_format_flag: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run a host in listener mode."""
|
"""Run a host in listener mode."""
|
||||||
|
from libp2p.utils.address_validation import find_free_port, get_available_interfaces
|
||||||
|
|
||||||
|
if port <= 0:
|
||||||
|
port = find_free_port()
|
||||||
|
|
||||||
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
|
format_name = "length-prefixed" if use_varint_format else "raw protobuf"
|
||||||
print(
|
print(
|
||||||
f"\n==== Starting Identify-Push Listener on port {port} "
|
f"\n==== Starting Identify-Push Listener on port {port} "
|
||||||
@ -215,26 +225,33 @@ async def run_listener(
|
|||||||
custom_identify_push_handler_for(host, use_varint_format=use_varint_format),
|
custom_identify_push_handler_for(host, use_varint_format=use_varint_format),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Start listening
|
# Start listening on all available interfaces
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
listen_addrs = get_available_interfaces(port)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with host.run([listen_addr]):
|
async with host.run(listen_addrs):
|
||||||
addr = host.get_addrs()[0]
|
all_addrs = host.get_addrs()
|
||||||
logger.info("Listener host ready!")
|
logger.info("Listener host ready!")
|
||||||
print("Listener host ready!")
|
print("Listener host ready!")
|
||||||
|
|
||||||
logger.info(f"Listening on: {addr}")
|
logger.info("Listener ready, listening on:")
|
||||||
print(f"Listening on: {addr}")
|
print("Listener ready, listening on:")
|
||||||
|
for addr in all_addrs:
|
||||||
|
logger.info(f"{addr}")
|
||||||
|
print(f"{addr}")
|
||||||
|
|
||||||
logger.info(f"Peer ID: {host.get_id().pretty()}")
|
logger.info(f"Peer ID: {host.get_id().pretty()}")
|
||||||
print(f"Peer ID: {host.get_id().pretty()}")
|
print(f"Peer ID: {host.get_id().pretty()}")
|
||||||
|
|
||||||
print("\nRun dialer with command:")
|
# Use the first address as the default for the dialer command
|
||||||
|
default_addr = all_addrs[0]
|
||||||
|
print("\nRun this from the same folder in another console:")
|
||||||
if raw_format_flag:
|
if raw_format_flag:
|
||||||
print(f"identify-push-listener-dialer-demo -d {addr} --raw-format")
|
print(
|
||||||
|
f"identify-push-listener-dialer-demo -d {default_addr} --raw-format"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(f"identify-push-listener-dialer-demo -d {addr}")
|
print(f"identify-push-listener-dialer-demo -d {default_addr}")
|
||||||
print("\nWaiting for incoming identify/push requests... (Ctrl+C to exit)")
|
print("\nWaiting for incoming identify/push requests... (Ctrl+C to exit)")
|
||||||
|
|
||||||
# Keep running until interrupted
|
# Keep running until interrupted
|
||||||
@ -274,10 +291,12 @@ async def run_dialer(
|
|||||||
identify_push_handler_for(host, use_varint_format=use_varint_format),
|
identify_push_handler_for(host, use_varint_format=use_varint_format),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Start listening on a different port
|
# Start listening on available interfaces
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
from libp2p.utils.address_validation import get_available_interfaces
|
||||||
|
|
||||||
async with host.run([listen_addr]):
|
listen_addrs = get_available_interfaces(port)
|
||||||
|
|
||||||
|
async with host.run(listen_addrs):
|
||||||
logger.info("Dialer host ready!")
|
logger.info("Dialer host ready!")
|
||||||
print("Dialer host ready!")
|
print("Dialer host ready!")
|
||||||
|
|
||||||
|
|||||||
@ -41,6 +41,7 @@ from libp2p.tools.async_service import (
|
|||||||
from libp2p.tools.utils import (
|
from libp2p.tools.utils import (
|
||||||
info_from_p2p_addr,
|
info_from_p2p_addr,
|
||||||
)
|
)
|
||||||
|
from libp2p.utils.paths import get_script_dir, join_paths
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@ -53,8 +54,8 @@ logger = logging.getLogger("kademlia-example")
|
|||||||
# Configure DHT module loggers to inherit from the parent logger
|
# Configure DHT module loggers to inherit from the parent logger
|
||||||
# This ensures all kademlia-example.* loggers use the same configuration
|
# This ensures all kademlia-example.* loggers use the same configuration
|
||||||
# Get the directory where this script is located
|
# Get the directory where this script is located
|
||||||
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
SCRIPT_DIR = get_script_dir(__file__)
|
||||||
SERVER_ADDR_LOG = os.path.join(SCRIPT_DIR, "server_node_addr.txt")
|
SERVER_ADDR_LOG = join_paths(SCRIPT_DIR, "server_node_addr.txt")
|
||||||
|
|
||||||
# Set the level for all child loggers
|
# Set the level for all child loggers
|
||||||
for module in [
|
for module in [
|
||||||
@ -149,26 +150,43 @@ async def run_node(
|
|||||||
|
|
||||||
key_pair = create_new_key_pair(secrets.token_bytes(32))
|
key_pair = create_new_key_pair(secrets.token_bytes(32))
|
||||||
host = new_host(key_pair=key_pair)
|
host = new_host(key_pair=key_pair)
|
||||||
listen_addr = Multiaddr(f"/ip4/127.0.0.1/tcp/{port}")
|
|
||||||
|
|
||||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
from libp2p.utils.address_validation import (
|
||||||
|
get_available_interfaces,
|
||||||
|
get_optimal_binding_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
listen_addrs = get_available_interfaces(port)
|
||||||
|
|
||||||
|
async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery:
|
||||||
# Start the peer-store cleanup task
|
# Start the peer-store cleanup task
|
||||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||||
|
|
||||||
peer_id = host.get_id().pretty()
|
peer_id = host.get_id().pretty()
|
||||||
addr_str = f"/ip4/127.0.0.1/tcp/{port}/p2p/{peer_id}"
|
|
||||||
|
# Get all available addresses with peer ID
|
||||||
|
all_addrs = host.get_addrs()
|
||||||
|
|
||||||
|
logger.info("Listener ready, listening on:")
|
||||||
|
for addr in all_addrs:
|
||||||
|
logger.info(f"{addr}")
|
||||||
|
|
||||||
|
# Use optimal address for the bootstrap command
|
||||||
|
optimal_addr = get_optimal_binding_address(port)
|
||||||
|
optimal_addr_with_peer = f"{optimal_addr}/p2p/{host.get_id().to_string()}"
|
||||||
|
bootstrap_cmd = f"--bootstrap {optimal_addr_with_peer}"
|
||||||
|
logger.info("To connect to this node, use: %s", bootstrap_cmd)
|
||||||
|
|
||||||
await connect_to_bootstrap_nodes(host, bootstrap_nodes)
|
await connect_to_bootstrap_nodes(host, bootstrap_nodes)
|
||||||
dht = KadDHT(host, dht_mode)
|
dht = KadDHT(host, dht_mode)
|
||||||
# take all peer ids from the host and add them to the dht
|
# take all peer ids from the host and add them to the dht
|
||||||
for peer_id in host.get_peerstore().peer_ids():
|
for peer_id in host.get_peerstore().peer_ids():
|
||||||
await dht.routing_table.add_peer(peer_id)
|
await dht.routing_table.add_peer(peer_id)
|
||||||
logger.info(f"Connected to bootstrap nodes: {host.get_connected_peers()}")
|
logger.info(f"Connected to bootstrap nodes: {host.get_connected_peers()}")
|
||||||
bootstrap_cmd = f"--bootstrap {addr_str}"
|
|
||||||
logger.info("To connect to this node, use: %s", bootstrap_cmd)
|
|
||||||
|
|
||||||
# Save server address in server mode
|
# Save server address in server mode
|
||||||
if dht_mode == DHTMode.SERVER:
|
if dht_mode == DHTMode.SERVER:
|
||||||
save_server_addr(addr_str)
|
save_server_addr(str(optimal_addr_with_peer))
|
||||||
|
|
||||||
# Start the DHT service
|
# Start the DHT service
|
||||||
async with background_trio_service(dht):
|
async with background_trio_service(dht):
|
||||||
|
|||||||
@ -2,7 +2,6 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
import multiaddr
|
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
from libp2p import (
|
from libp2p import (
|
||||||
@ -14,6 +13,11 @@ from libp2p.crypto.secp256k1 import (
|
|||||||
)
|
)
|
||||||
from libp2p.discovery.events.peerDiscovery import peerDiscovery
|
from libp2p.discovery.events.peerDiscovery import peerDiscovery
|
||||||
|
|
||||||
|
# Configure minimal logging
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("multiaddr").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("libp2p").setLevel(logging.WARNING)
|
||||||
|
|
||||||
logger = logging.getLogger("libp2p.discovery.mdns")
|
logger = logging.getLogger("libp2p.discovery.mdns")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
handler = logging.StreamHandler()
|
handler = logging.StreamHandler()
|
||||||
@ -22,34 +26,43 @@ handler.setFormatter(
|
|||||||
)
|
)
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
|
|
||||||
# Set root logger to DEBUG to capture all logs from dependencies
|
|
||||||
logging.getLogger().setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
|
|
||||||
def onPeerDiscovery(peerinfo: PeerInfo):
|
def onPeerDiscovery(peerinfo: PeerInfo):
|
||||||
logger.info(f"Discovered: {peerinfo.peer_id}")
|
logger.info(f"Discovered: {peerinfo.peer_id}")
|
||||||
|
|
||||||
|
|
||||||
async def run(port: int) -> None:
|
async def run(port: int) -> None:
|
||||||
|
from libp2p.utils.address_validation import find_free_port, get_available_interfaces
|
||||||
|
|
||||||
|
if port <= 0:
|
||||||
|
port = find_free_port()
|
||||||
|
|
||||||
secret = secrets.token_bytes(32)
|
secret = secrets.token_bytes(32)
|
||||||
key_pair = create_new_key_pair(secret)
|
key_pair = create_new_key_pair(secret)
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
listen_addrs = get_available_interfaces(port)
|
||||||
|
|
||||||
peerDiscovery.register_peer_discovered_handler(onPeerDiscovery)
|
peerDiscovery.register_peer_discovered_handler(onPeerDiscovery)
|
||||||
|
|
||||||
print(
|
|
||||||
"Run this from the same folder in another console to "
|
|
||||||
"start another peer on a different port:\n\n"
|
|
||||||
"mdns-demo -p <ANOTHER_PORT>\n"
|
|
||||||
)
|
|
||||||
print("Waiting for mDNS peer discovery events...\n")
|
|
||||||
|
|
||||||
logger.info("Starting peer Discovery")
|
logger.info("Starting peer Discovery")
|
||||||
host = new_host(key_pair=key_pair, enable_mDNS=True)
|
host = new_host(key_pair=key_pair, enable_mDNS=True)
|
||||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery:
|
||||||
# Start the peer-store cleanup task
|
# Start the peer-store cleanup task
|
||||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||||
|
|
||||||
|
# Get all available addresses with peer ID
|
||||||
|
all_addrs = host.get_addrs()
|
||||||
|
|
||||||
|
print("Listener ready, listening on:")
|
||||||
|
for addr in all_addrs:
|
||||||
|
print(f"{addr}")
|
||||||
|
|
||||||
|
print(
|
||||||
|
"\nRun this from the same folder in another console to "
|
||||||
|
"start another peer on a different port:\n\n"
|
||||||
|
"mdns-demo -p <ANOTHER_PORT>\n"
|
||||||
|
)
|
||||||
|
print("Waiting for mDNS peer discovery events...\n")
|
||||||
|
|
||||||
await trio.sleep_forever()
|
await trio.sleep_forever()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
|
|
||||||
import multiaddr
|
import multiaddr
|
||||||
import trio
|
import trio
|
||||||
@ -16,6 +17,11 @@ from libp2p.peer.peerinfo import (
|
|||||||
info_from_p2p_addr,
|
info_from_p2p_addr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Configure minimal logging
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("multiaddr").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("libp2p").setLevel(logging.WARNING)
|
||||||
|
|
||||||
PING_PROTOCOL_ID = TProtocol("/ipfs/ping/1.0.0")
|
PING_PROTOCOL_ID = TProtocol("/ipfs/ping/1.0.0")
|
||||||
PING_LENGTH = 32
|
PING_LENGTH = 32
|
||||||
RESP_TIMEOUT = 60
|
RESP_TIMEOUT = 60
|
||||||
@ -55,20 +61,38 @@ async def send_ping(stream: INetStream) -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def run(port: int, destination: str) -> None:
|
async def run(port: int, destination: str) -> None:
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
from libp2p.utils.address_validation import (
|
||||||
host = new_host(listen_addrs=[listen_addr])
|
find_free_port,
|
||||||
|
get_available_interfaces,
|
||||||
|
get_optimal_binding_address,
|
||||||
|
)
|
||||||
|
|
||||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
if port <= 0:
|
||||||
|
port = find_free_port()
|
||||||
|
|
||||||
|
listen_addrs = get_available_interfaces(port)
|
||||||
|
host = new_host(listen_addrs=listen_addrs)
|
||||||
|
|
||||||
|
async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery:
|
||||||
# Start the peer-store cleanup task
|
# Start the peer-store cleanup task
|
||||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||||
|
|
||||||
if not destination:
|
if not destination:
|
||||||
host.set_stream_handler(PING_PROTOCOL_ID, handle_ping)
|
host.set_stream_handler(PING_PROTOCOL_ID, handle_ping)
|
||||||
|
|
||||||
|
# Get all available addresses with peer ID
|
||||||
|
all_addrs = host.get_addrs()
|
||||||
|
|
||||||
|
print("Listener ready, listening on:\n")
|
||||||
|
for addr in all_addrs:
|
||||||
|
print(f"{addr}")
|
||||||
|
|
||||||
|
# Use optimal address for the client command
|
||||||
|
optimal_addr = get_optimal_binding_address(port)
|
||||||
|
optimal_addr_with_peer = f"{optimal_addr}/p2p/{host.get_id().to_string()}"
|
||||||
print(
|
print(
|
||||||
"Run this from the same folder in another console:\n\n"
|
f"\nRun this from the same folder in another console:\n\n"
|
||||||
f"ping-demo "
|
f"ping-demo -d {optimal_addr_with_peer}\n"
|
||||||
f"-d {host.get_addrs()[0]}\n"
|
|
||||||
)
|
)
|
||||||
print("Waiting for incoming connection...")
|
print("Waiting for incoming connection...")
|
||||||
|
|
||||||
@ -94,7 +118,7 @@ def main() -> None:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
example_maddr = (
|
example_maddr = (
|
||||||
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
"/ip4/[HOST_IP]/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
||||||
)
|
)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description=description)
|
parser = argparse.ArgumentParser(description=description)
|
||||||
|
|||||||
@ -102,14 +102,16 @@ async def monitor_peer_topics(pubsub, nursery, termination_event):
|
|||||||
|
|
||||||
|
|
||||||
async def run(topic: str, destination: str | None, port: int | None) -> None:
|
async def run(topic: str, destination: str | None, port: int | None) -> None:
|
||||||
# Initialize network settings
|
from libp2p.utils.address_validation import (
|
||||||
localhost_ip = "127.0.0.1"
|
get_available_interfaces,
|
||||||
|
get_optimal_binding_address,
|
||||||
|
)
|
||||||
|
|
||||||
if port is None or port == 0:
|
if port is None or port == 0:
|
||||||
port = find_free_port()
|
port = find_free_port()
|
||||||
logger.info(f"Using random available port: {port}")
|
logger.info(f"Using random available port: {port}")
|
||||||
|
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
listen_addrs = get_available_interfaces(port)
|
||||||
|
|
||||||
# Create a new libp2p host
|
# Create a new libp2p host
|
||||||
host = new_host(
|
host = new_host(
|
||||||
@ -138,12 +140,11 @@ async def run(topic: str, destination: str | None, port: int | None) -> None:
|
|||||||
|
|
||||||
pubsub = Pubsub(host, gossipsub)
|
pubsub = Pubsub(host, gossipsub)
|
||||||
termination_event = trio.Event() # Event to signal termination
|
termination_event = trio.Event() # Event to signal termination
|
||||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery:
|
||||||
# Start the peer-store cleanup task
|
# Start the peer-store cleanup task
|
||||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||||
|
|
||||||
logger.info(f"Node started with peer ID: {host.get_id()}")
|
logger.info(f"Node started with peer ID: {host.get_id()}")
|
||||||
logger.info(f"Listening on: {listen_addr}")
|
|
||||||
logger.info("Initializing PubSub and GossipSub...")
|
logger.info("Initializing PubSub and GossipSub...")
|
||||||
async with background_trio_service(pubsub):
|
async with background_trio_service(pubsub):
|
||||||
async with background_trio_service(gossipsub):
|
async with background_trio_service(gossipsub):
|
||||||
@ -157,10 +158,21 @@ async def run(topic: str, destination: str | None, port: int | None) -> None:
|
|||||||
|
|
||||||
if not destination:
|
if not destination:
|
||||||
# Server mode
|
# Server mode
|
||||||
|
# Get all available addresses with peer ID
|
||||||
|
all_addrs = host.get_addrs()
|
||||||
|
|
||||||
|
logger.info("Listener ready, listening on:")
|
||||||
|
for addr in all_addrs:
|
||||||
|
logger.info(f"{addr}")
|
||||||
|
|
||||||
|
# Use optimal address for the client command
|
||||||
|
optimal_addr = get_optimal_binding_address(port)
|
||||||
|
optimal_addr_with_peer = (
|
||||||
|
f"{optimal_addr}/p2p/{host.get_id().to_string()}"
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Run this script in another console with:\n"
|
f"\nRun this from the same folder in another console:\n\n"
|
||||||
f"pubsub-demo "
|
f"pubsub-demo -d {optimal_addr_with_peer}\n"
|
||||||
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id()}\n"
|
|
||||||
)
|
)
|
||||||
logger.info("Waiting for peers...")
|
logger.info("Waiting for peers...")
|
||||||
|
|
||||||
@ -182,11 +194,6 @@ async def run(topic: str, destination: str | None, port: int | None) -> None:
|
|||||||
f"Connecting to peer: {info.peer_id} "
|
f"Connecting to peer: {info.peer_id} "
|
||||||
f"using protocols: {protocols_in_maddr}"
|
f"using protocols: {protocols_in_maddr}"
|
||||||
)
|
)
|
||||||
logger.info(
|
|
||||||
"Run this script in another console with:\n"
|
|
||||||
f"pubsub-demo "
|
|
||||||
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id()}\n"
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
await host.connect(info)
|
await host.connect(info)
|
||||||
logger.info(f"Connected to peer: {info.peer_id}")
|
logger.info(f"Connected to peer: {info.peer_id}")
|
||||||
|
|||||||
@ -16,7 +16,6 @@ import random
|
|||||||
import secrets
|
import secrets
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from multiaddr import Multiaddr
|
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
from libp2p import new_host
|
from libp2p import new_host
|
||||||
@ -130,16 +129,24 @@ async def run_node(port: int, mode: str, demo_interval: int = 30) -> None:
|
|||||||
# Create host and DHT
|
# Create host and DHT
|
||||||
key_pair = create_new_key_pair(secrets.token_bytes(32))
|
key_pair = create_new_key_pair(secrets.token_bytes(32))
|
||||||
host = new_host(key_pair=key_pair, bootstrap=DEFAULT_BOOTSTRAP_NODES)
|
host = new_host(key_pair=key_pair, bootstrap=DEFAULT_BOOTSTRAP_NODES)
|
||||||
listen_addr = Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
|
||||||
|
|
||||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
from libp2p.utils.address_validation import get_available_interfaces
|
||||||
|
|
||||||
|
listen_addrs = get_available_interfaces(port)
|
||||||
|
|
||||||
|
async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery:
|
||||||
# Start maintenance tasks
|
# Start maintenance tasks
|
||||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||||
nursery.start_soon(maintain_connections, host)
|
nursery.start_soon(maintain_connections, host)
|
||||||
|
|
||||||
peer_id = host.get_id().pretty()
|
peer_id = host.get_id().pretty()
|
||||||
logger.info(f"Node peer ID: {peer_id}")
|
logger.info(f"Node peer ID: {peer_id}")
|
||||||
logger.info(f"Node address: /ip4/0.0.0.0/tcp/{port}/p2p/{peer_id}")
|
|
||||||
|
# Get all available addresses with peer ID
|
||||||
|
all_addrs = host.get_addrs()
|
||||||
|
logger.info("Listener ready, listening on:")
|
||||||
|
for addr in all_addrs:
|
||||||
|
logger.info(f"{addr}")
|
||||||
|
|
||||||
# Create and start DHT with Random Walk enabled
|
# Create and start DHT with Random Walk enabled
|
||||||
dht = KadDHT(host, dht_mode, enable_random_walk=True)
|
dht = KadDHT(host, dht_mode, enable_random_walk=True)
|
||||||
|
|||||||
446
examples/test_tcp_data_transfer.py
Normal file
446
examples/test_tcp_data_transfer.py
Normal file
@ -0,0 +1,446 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
TCP P2P Data Transfer Test
|
||||||
|
|
||||||
|
This test proves that TCP peer-to-peer data transfer works correctly in libp2p.
|
||||||
|
This serves as a baseline to compare with WebSocket tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from multiaddr import Multiaddr
|
||||||
|
import trio
|
||||||
|
|
||||||
|
from libp2p import create_yamux_muxer_option, new_host
|
||||||
|
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||||
|
from libp2p.custom_types import TProtocol
|
||||||
|
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||||
|
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||||
|
|
||||||
|
# Test protocol for data exchange
|
||||||
|
TCP_DATA_PROTOCOL = TProtocol("/test/tcp-data-exchange/1.0.0")
|
||||||
|
|
||||||
|
|
||||||
|
async def create_tcp_host_pair():
|
||||||
|
"""Create a pair of hosts configured for TCP communication."""
|
||||||
|
# Create key pairs
|
||||||
|
key_pair_a = create_new_key_pair()
|
||||||
|
key_pair_b = create_new_key_pair()
|
||||||
|
|
||||||
|
# Create security options (using plaintext for simplicity)
|
||||||
|
def security_options(kp):
|
||||||
|
return {
|
||||||
|
PLAINTEXT_PROTOCOL_ID: InsecureTransport(
|
||||||
|
local_key_pair=kp, secure_bytes_provider=None, peerstore=None
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Host A (listener) - TCP transport (default)
|
||||||
|
host_a = new_host(
|
||||||
|
key_pair=key_pair_a,
|
||||||
|
sec_opt=security_options(key_pair_a),
|
||||||
|
muxer_opt=create_yamux_muxer_option(),
|
||||||
|
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Host B (dialer) - TCP transport (default)
|
||||||
|
host_b = new_host(
|
||||||
|
key_pair=key_pair_b,
|
||||||
|
sec_opt=security_options(key_pair_b),
|
||||||
|
muxer_opt=create_yamux_muxer_option(),
|
||||||
|
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")],
|
||||||
|
)
|
||||||
|
|
||||||
|
return host_a, host_b
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_tcp_basic_connection():
|
||||||
|
"""Test basic TCP connection establishment."""
|
||||||
|
host_a, host_b = await create_tcp_host_pair()
|
||||||
|
|
||||||
|
connection_established = False
|
||||||
|
|
||||||
|
async def connection_handler(stream):
|
||||||
|
nonlocal connection_established
|
||||||
|
connection_established = True
|
||||||
|
await stream.close()
|
||||||
|
|
||||||
|
host_a.set_stream_handler(TCP_DATA_PROTOCOL, connection_handler)
|
||||||
|
|
||||||
|
async with (
|
||||||
|
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]),
|
||||||
|
host_b.run(listen_addrs=[]),
|
||||||
|
):
|
||||||
|
# Get host A's listen address
|
||||||
|
listen_addrs = host_a.get_addrs()
|
||||||
|
assert listen_addrs, "Host A should have listen addresses"
|
||||||
|
|
||||||
|
# Extract TCP address
|
||||||
|
tcp_addr = None
|
||||||
|
for addr in listen_addrs:
|
||||||
|
if "/tcp/" in str(addr) and "/ws" not in str(addr):
|
||||||
|
tcp_addr = addr
|
||||||
|
break
|
||||||
|
|
||||||
|
assert tcp_addr, f"No TCP address found in {listen_addrs}"
|
||||||
|
print(f"🔗 Host A listening on: {tcp_addr}")
|
||||||
|
|
||||||
|
# Create peer info for host A
|
||||||
|
peer_info = info_from_p2p_addr(tcp_addr)
|
||||||
|
|
||||||
|
# Host B connects to host A
|
||||||
|
await host_b.connect(peer_info)
|
||||||
|
print("✅ TCP connection established")
|
||||||
|
|
||||||
|
# Open a stream to test the connection
|
||||||
|
stream = await host_b.new_stream(peer_info.peer_id, [TCP_DATA_PROTOCOL])
|
||||||
|
await stream.close()
|
||||||
|
|
||||||
|
# Wait a bit for the handler to be called
|
||||||
|
await trio.sleep(0.1)
|
||||||
|
|
||||||
|
assert connection_established, "TCP connection handler should have been called"
|
||||||
|
print("✅ TCP basic connection test successful!")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_tcp_data_transfer():
|
||||||
|
"""Test TCP peer-to-peer data transfer."""
|
||||||
|
host_a, host_b = await create_tcp_host_pair()
|
||||||
|
|
||||||
|
# Test data
|
||||||
|
test_data = b"Hello TCP P2P Data Transfer! This is a test message."
|
||||||
|
received_data = None
|
||||||
|
transfer_complete = trio.Event()
|
||||||
|
|
||||||
|
async def data_handler(stream):
|
||||||
|
nonlocal received_data
|
||||||
|
try:
|
||||||
|
# Read the incoming data
|
||||||
|
received_data = await stream.read(len(test_data))
|
||||||
|
# Echo it back to confirm successful transfer
|
||||||
|
await stream.write(received_data)
|
||||||
|
await stream.close()
|
||||||
|
transfer_complete.set()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Handler error: {e}")
|
||||||
|
transfer_complete.set()
|
||||||
|
|
||||||
|
host_a.set_stream_handler(TCP_DATA_PROTOCOL, data_handler)
|
||||||
|
|
||||||
|
async with (
|
||||||
|
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]),
|
||||||
|
host_b.run(listen_addrs=[]),
|
||||||
|
):
|
||||||
|
# Get host A's listen address
|
||||||
|
listen_addrs = host_a.get_addrs()
|
||||||
|
assert listen_addrs, "Host A should have listen addresses"
|
||||||
|
|
||||||
|
# Extract TCP address
|
||||||
|
tcp_addr = None
|
||||||
|
for addr in listen_addrs:
|
||||||
|
if "/tcp/" in str(addr) and "/ws" not in str(addr):
|
||||||
|
tcp_addr = addr
|
||||||
|
break
|
||||||
|
|
||||||
|
assert tcp_addr, f"No TCP address found in {listen_addrs}"
|
||||||
|
print(f"🔗 Host A listening on: {tcp_addr}")
|
||||||
|
|
||||||
|
# Create peer info for host A
|
||||||
|
peer_info = info_from_p2p_addr(tcp_addr)
|
||||||
|
|
||||||
|
# Host B connects to host A
|
||||||
|
await host_b.connect(peer_info)
|
||||||
|
print("✅ TCP connection established")
|
||||||
|
|
||||||
|
# Open a stream for data transfer
|
||||||
|
stream = await host_b.new_stream(peer_info.peer_id, [TCP_DATA_PROTOCOL])
|
||||||
|
print("✅ TCP stream opened")
|
||||||
|
|
||||||
|
# Send test data
|
||||||
|
await stream.write(test_data)
|
||||||
|
print(f"📤 Sent data: {test_data}")
|
||||||
|
|
||||||
|
# Read echoed data back
|
||||||
|
echoed_data = await stream.read(len(test_data))
|
||||||
|
print(f"📥 Received echo: {echoed_data}")
|
||||||
|
|
||||||
|
await stream.close()
|
||||||
|
|
||||||
|
# Wait for transfer to complete
|
||||||
|
with trio.fail_after(5.0): # 5 second timeout
|
||||||
|
await transfer_complete.wait()
|
||||||
|
|
||||||
|
# Verify data transfer
|
||||||
|
assert received_data == test_data, (
|
||||||
|
f"Data mismatch: {received_data} != {test_data}"
|
||||||
|
)
|
||||||
|
assert echoed_data == test_data, f"Echo mismatch: {echoed_data} != {test_data}"
|
||||||
|
|
||||||
|
print("✅ TCP P2P data transfer successful!")
|
||||||
|
print(f" Original: {test_data}")
|
||||||
|
print(f" Received: {received_data}")
|
||||||
|
print(f" Echoed: {echoed_data}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_tcp_large_data_transfer():
|
||||||
|
"""Test TCP with larger data payloads."""
|
||||||
|
host_a, host_b = await create_tcp_host_pair()
|
||||||
|
|
||||||
|
# Large test data (10KB)
|
||||||
|
test_data = b"TCP Large Data Test! " * 500 # ~10KB
|
||||||
|
received_data = None
|
||||||
|
transfer_complete = trio.Event()
|
||||||
|
|
||||||
|
async def large_data_handler(stream):
|
||||||
|
nonlocal received_data
|
||||||
|
try:
|
||||||
|
# Read data in chunks
|
||||||
|
chunks = []
|
||||||
|
total_received = 0
|
||||||
|
expected_size = len(test_data)
|
||||||
|
|
||||||
|
while total_received < expected_size:
|
||||||
|
chunk = await stream.read(min(1024, expected_size - total_received))
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
chunks.append(chunk)
|
||||||
|
total_received += len(chunk)
|
||||||
|
|
||||||
|
received_data = b"".join(chunks)
|
||||||
|
|
||||||
|
# Send back confirmation
|
||||||
|
await stream.write(b"RECEIVED_OK")
|
||||||
|
await stream.close()
|
||||||
|
transfer_complete.set()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Large data handler error: {e}")
|
||||||
|
transfer_complete.set()
|
||||||
|
|
||||||
|
host_a.set_stream_handler(TCP_DATA_PROTOCOL, large_data_handler)
|
||||||
|
|
||||||
|
async with (
|
||||||
|
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]),
|
||||||
|
host_b.run(listen_addrs=[]),
|
||||||
|
):
|
||||||
|
# Get host A's listen address
|
||||||
|
listen_addrs = host_a.get_addrs()
|
||||||
|
assert listen_addrs, "Host A should have listen addresses"
|
||||||
|
|
||||||
|
# Extract TCP address
|
||||||
|
tcp_addr = None
|
||||||
|
for addr in listen_addrs:
|
||||||
|
if "/tcp/" in str(addr) and "/ws" not in str(addr):
|
||||||
|
tcp_addr = addr
|
||||||
|
break
|
||||||
|
|
||||||
|
assert tcp_addr, f"No TCP address found in {listen_addrs}"
|
||||||
|
print(f"🔗 Host A listening on: {tcp_addr}")
|
||||||
|
print(f"📊 Test data size: {len(test_data)} bytes")
|
||||||
|
|
||||||
|
# Create peer info for host A
|
||||||
|
peer_info = info_from_p2p_addr(tcp_addr)
|
||||||
|
|
||||||
|
# Host B connects to host A
|
||||||
|
await host_b.connect(peer_info)
|
||||||
|
print("✅ TCP connection established")
|
||||||
|
|
||||||
|
# Open a stream for data transfer
|
||||||
|
stream = await host_b.new_stream(peer_info.peer_id, [TCP_DATA_PROTOCOL])
|
||||||
|
print("✅ TCP stream opened")
|
||||||
|
|
||||||
|
# Send large test data in chunks
|
||||||
|
chunk_size = 1024
|
||||||
|
sent_bytes = 0
|
||||||
|
for i in range(0, len(test_data), chunk_size):
|
||||||
|
chunk = test_data[i : i + chunk_size]
|
||||||
|
await stream.write(chunk)
|
||||||
|
sent_bytes += len(chunk)
|
||||||
|
if sent_bytes % (chunk_size * 4) == 0: # Progress every 4KB
|
||||||
|
print(f"📤 Sent {sent_bytes}/{len(test_data)} bytes")
|
||||||
|
|
||||||
|
print(f"📤 Sent all {len(test_data)} bytes")
|
||||||
|
|
||||||
|
# Read confirmation
|
||||||
|
confirmation = await stream.read(1024)
|
||||||
|
print(f"📥 Received confirmation: {confirmation}")
|
||||||
|
|
||||||
|
await stream.close()
|
||||||
|
|
||||||
|
# Wait for transfer to complete
|
||||||
|
with trio.fail_after(10.0): # 10 second timeout for large data
|
||||||
|
await transfer_complete.wait()
|
||||||
|
|
||||||
|
# Verify data transfer
|
||||||
|
assert received_data is not None, "No data was received"
|
||||||
|
assert received_data == test_data, (
|
||||||
|
"Large data transfer failed:"
|
||||||
|
+ f" sizes {len(received_data)} != {len(test_data)}"
|
||||||
|
)
|
||||||
|
assert confirmation == b"RECEIVED_OK", f"Confirmation failed: {confirmation}"
|
||||||
|
|
||||||
|
print("✅ TCP large data transfer successful!")
|
||||||
|
print(f" Data size: {len(test_data)} bytes")
|
||||||
|
print(f" Received: {len(received_data)} bytes")
|
||||||
|
print(f" Match: {received_data == test_data}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_tcp_bidirectional_transfer():
|
||||||
|
"""Test bidirectional data transfer over TCP."""
|
||||||
|
host_a, host_b = await create_tcp_host_pair()
|
||||||
|
|
||||||
|
# Test data
|
||||||
|
data_a_to_b = b"Message from Host A to Host B via TCP"
|
||||||
|
data_b_to_a = b"Response from Host B to Host A via TCP"
|
||||||
|
|
||||||
|
received_on_a = None
|
||||||
|
received_on_b = None
|
||||||
|
transfer_complete_a = trio.Event()
|
||||||
|
transfer_complete_b = trio.Event()
|
||||||
|
|
||||||
|
async def handler_a(stream):
|
||||||
|
nonlocal received_on_a
|
||||||
|
try:
|
||||||
|
# Read data from B
|
||||||
|
received_on_a = await stream.read(len(data_b_to_a))
|
||||||
|
print(f"🅰️ Host A received: {received_on_a}")
|
||||||
|
await stream.close()
|
||||||
|
transfer_complete_a.set()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Handler A error: {e}")
|
||||||
|
transfer_complete_a.set()
|
||||||
|
|
||||||
|
async def handler_b(stream):
|
||||||
|
nonlocal received_on_b
|
||||||
|
try:
|
||||||
|
# Read data from A
|
||||||
|
received_on_b = await stream.read(len(data_a_to_b))
|
||||||
|
print(f"🅱️ Host B received: {received_on_b}")
|
||||||
|
await stream.close()
|
||||||
|
transfer_complete_b.set()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Handler B error: {e}")
|
||||||
|
transfer_complete_b.set()
|
||||||
|
|
||||||
|
# Set up handlers on both hosts
|
||||||
|
protocol_a_to_b = TProtocol("/test/tcp-a-to-b/1.0.0")
|
||||||
|
protocol_b_to_a = TProtocol("/test/tcp-b-to-a/1.0.0")
|
||||||
|
|
||||||
|
host_a.set_stream_handler(protocol_b_to_a, handler_a)
|
||||||
|
host_b.set_stream_handler(protocol_a_to_b, handler_b)
|
||||||
|
|
||||||
|
async with (
|
||||||
|
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]),
|
||||||
|
host_b.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]),
|
||||||
|
):
|
||||||
|
# Get addresses
|
||||||
|
addrs_a = host_a.get_addrs()
|
||||||
|
addrs_b = host_b.get_addrs()
|
||||||
|
|
||||||
|
assert addrs_a and addrs_b, "Both hosts should have addresses"
|
||||||
|
|
||||||
|
# Extract TCP addresses
|
||||||
|
tcp_addr_a = next(
|
||||||
|
(
|
||||||
|
addr
|
||||||
|
for addr in addrs_a
|
||||||
|
if "/tcp/" in str(addr) and "/ws" not in str(addr)
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
tcp_addr_b = next(
|
||||||
|
(
|
||||||
|
addr
|
||||||
|
for addr in addrs_b
|
||||||
|
if "/tcp/" in str(addr) and "/ws" not in str(addr)
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tcp_addr_a and tcp_addr_b, (
|
||||||
|
f"TCP addresses not found: A={addrs_a}, B={addrs_b}"
|
||||||
|
)
|
||||||
|
print(f"🔗 Host A listening on: {tcp_addr_a}")
|
||||||
|
print(f"🔗 Host B listening on: {tcp_addr_b}")
|
||||||
|
|
||||||
|
# Create peer infos
|
||||||
|
peer_info_a = info_from_p2p_addr(tcp_addr_a)
|
||||||
|
peer_info_b = info_from_p2p_addr(tcp_addr_b)
|
||||||
|
|
||||||
|
# Establish connections
|
||||||
|
await host_b.connect(peer_info_a)
|
||||||
|
await host_a.connect(peer_info_b)
|
||||||
|
print("✅ Bidirectional TCP connections established")
|
||||||
|
|
||||||
|
# Send data A -> B
|
||||||
|
stream_a_to_b = await host_a.new_stream(peer_info_b.peer_id, [protocol_a_to_b])
|
||||||
|
await stream_a_to_b.write(data_a_to_b)
|
||||||
|
print(f"📤 A->B: {data_a_to_b}")
|
||||||
|
await stream_a_to_b.close()
|
||||||
|
|
||||||
|
# Send data B -> A
|
||||||
|
stream_b_to_a = await host_b.new_stream(peer_info_a.peer_id, [protocol_b_to_a])
|
||||||
|
await stream_b_to_a.write(data_b_to_a)
|
||||||
|
print(f"📤 B->A: {data_b_to_a}")
|
||||||
|
await stream_b_to_a.close()
|
||||||
|
|
||||||
|
# Wait for both transfers to complete
|
||||||
|
with trio.fail_after(5.0):
|
||||||
|
await transfer_complete_a.wait()
|
||||||
|
await transfer_complete_b.wait()
|
||||||
|
|
||||||
|
# Verify bidirectional transfer
|
||||||
|
assert received_on_a == data_b_to_a, f"A received wrong data: {received_on_a}"
|
||||||
|
assert received_on_b == data_a_to_b, f"B received wrong data: {received_on_b}"
|
||||||
|
|
||||||
|
print("✅ TCP bidirectional data transfer successful!")
|
||||||
|
print(f" A->B: {data_a_to_b}")
|
||||||
|
print(f" B->A: {data_b_to_a}")
|
||||||
|
print(f" ✓ A got: {received_on_a}")
|
||||||
|
print(f" ✓ B got: {received_on_b}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Run tests directly
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
print("🧪 Running TCP P2P Data Transfer Tests")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
async def run_all_tcp_tests():
|
||||||
|
try:
|
||||||
|
print("\n1. Testing basic TCP connection...")
|
||||||
|
await test_tcp_basic_connection()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Basic TCP connection test failed: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
print("\n2. Testing TCP data transfer...")
|
||||||
|
await test_tcp_data_transfer()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ TCP data transfer test failed: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
print("\n3. Testing TCP large data transfer...")
|
||||||
|
await test_tcp_large_data_transfer()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ TCP large data transfer test failed: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
print("\n4. Testing TCP bidirectional transfer...")
|
||||||
|
await test_tcp_bidirectional_transfer()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ TCP bidirectional transfer test failed: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("🏁 TCP P2P Tests Complete - All Tests PASSED!")
|
||||||
|
|
||||||
|
trio.run(run_all_tcp_tests)
|
||||||
210
examples/transport_integration_demo.py
Normal file
210
examples/transport_integration_demo.py
Normal file
@ -0,0 +1,210 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Demo script showing the new transport integration capabilities in py-libp2p.
|
||||||
|
|
||||||
|
This script demonstrates:
|
||||||
|
1. How to use the transport registry
|
||||||
|
2. How to create transports dynamically based on multiaddrs
|
||||||
|
3. How to register custom transports
|
||||||
|
4. How the new system automatically selects the right transport
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Add the libp2p directory to the path so we can import it
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
import multiaddr
|
||||||
|
|
||||||
|
from libp2p.transport import (
|
||||||
|
create_transport,
|
||||||
|
create_transport_for_multiaddr,
|
||||||
|
get_supported_transport_protocols,
|
||||||
|
get_transport_registry,
|
||||||
|
register_transport,
|
||||||
|
)
|
||||||
|
from libp2p.transport.tcp.tcp import TCP
|
||||||
|
from libp2p.transport.upgrader import TransportUpgrader
|
||||||
|
|
||||||
|
# Set up logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def demo_transport_registry():
|
||||||
|
"""Demonstrate the transport registry functionality."""
|
||||||
|
print("🔧 Transport Registry Demo")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# Get the global registry
|
||||||
|
registry = get_transport_registry()
|
||||||
|
|
||||||
|
# Show supported protocols
|
||||||
|
supported = get_supported_transport_protocols()
|
||||||
|
print(f"Supported transport protocols: {supported}")
|
||||||
|
|
||||||
|
# Show registered transports
|
||||||
|
print("\nRegistered transports:")
|
||||||
|
for protocol in supported:
|
||||||
|
transport_class = registry.get_transport(protocol)
|
||||||
|
class_name = transport_class.__name__ if transport_class else "None"
|
||||||
|
print(f" {protocol}: {class_name}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def demo_transport_factory():
|
||||||
|
"""Demonstrate the transport factory functions."""
|
||||||
|
print("🏭 Transport Factory Demo")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# Create a dummy upgrader for WebSocket transport
|
||||||
|
upgrader = TransportUpgrader({}, {})
|
||||||
|
|
||||||
|
# Create transports using the factory function
|
||||||
|
try:
|
||||||
|
tcp_transport = create_transport("tcp")
|
||||||
|
print(f"✅ Created TCP transport: {type(tcp_transport).__name__}")
|
||||||
|
|
||||||
|
ws_transport = create_transport("ws", upgrader)
|
||||||
|
print(f"✅ Created WebSocket transport: {type(ws_transport).__name__}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error creating transport: {e}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def demo_multiaddr_transport_selection():
|
||||||
|
"""Demonstrate automatic transport selection based on multiaddrs."""
|
||||||
|
print("🎯 Multiaddr Transport Selection Demo")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# Create a dummy upgrader
|
||||||
|
upgrader = TransportUpgrader({}, {})
|
||||||
|
|
||||||
|
# Test different multiaddr types
|
||||||
|
test_addrs = [
|
||||||
|
"/ip4/127.0.0.1/tcp/8080",
|
||||||
|
"/ip4/127.0.0.1/tcp/8080/ws",
|
||||||
|
"/ip6/::1/tcp/8080/ws",
|
||||||
|
"/dns4/example.com/tcp/443/ws",
|
||||||
|
]
|
||||||
|
|
||||||
|
for addr_str in test_addrs:
|
||||||
|
try:
|
||||||
|
maddr = multiaddr.Multiaddr(addr_str)
|
||||||
|
transport = create_transport_for_multiaddr(maddr, upgrader)
|
||||||
|
|
||||||
|
if transport:
|
||||||
|
print(f"✅ {addr_str} -> {type(transport).__name__}")
|
||||||
|
else:
|
||||||
|
print(f"❌ {addr_str} -> No transport found")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ {addr_str} -> Error: {e}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def demo_custom_transport_registration():
|
||||||
|
"""Demonstrate how to register custom transports."""
|
||||||
|
print("🔧 Custom Transport Registration Demo")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# Show current supported protocols
|
||||||
|
print(f"Before registration: {get_supported_transport_protocols()}")
|
||||||
|
|
||||||
|
# Register a custom transport (using TCP as an example)
|
||||||
|
class CustomTCPTransport(TCP):
|
||||||
|
"""Custom TCP transport for demonstration."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.custom_flag = True
|
||||||
|
|
||||||
|
# Register the custom transport
|
||||||
|
register_transport("custom_tcp", CustomTCPTransport)
|
||||||
|
|
||||||
|
# Show updated supported protocols
|
||||||
|
print(f"After registration: {get_supported_transport_protocols()}")
|
||||||
|
|
||||||
|
# Test creating the custom transport
|
||||||
|
try:
|
||||||
|
custom_transport = create_transport("custom_tcp")
|
||||||
|
print(f"✅ Created custom transport: {type(custom_transport).__name__}")
|
||||||
|
# Check if it has the custom flag (type-safe way)
|
||||||
|
if hasattr(custom_transport, "custom_flag"):
|
||||||
|
flag_value = getattr(custom_transport, "custom_flag", "Not found")
|
||||||
|
print(f" Custom flag: {flag_value}")
|
||||||
|
else:
|
||||||
|
print(" Custom flag: Not found")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error creating custom transport: {e}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def demo_integration_with_libp2p():
|
||||||
|
"""Demonstrate how the new system integrates with libp2p."""
|
||||||
|
print("🚀 Libp2p Integration Demo")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
print("The new transport system integrates seamlessly with libp2p:")
|
||||||
|
print()
|
||||||
|
print("1. ✅ Automatic transport selection based on multiaddr")
|
||||||
|
print("2. ✅ Support for WebSocket (/ws) protocol")
|
||||||
|
print("3. ✅ Fallback to TCP for backward compatibility")
|
||||||
|
print("4. ✅ Easy registration of new transport protocols")
|
||||||
|
print("5. ✅ No changes needed to existing libp2p code")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("Example usage in libp2p:")
|
||||||
|
print(" # This will automatically use WebSocket transport")
|
||||||
|
print(" host = new_host(listen_addrs=['/ip4/127.0.0.1/tcp/8080/ws'])")
|
||||||
|
print()
|
||||||
|
print(" # This will automatically use TCP transport")
|
||||||
|
print(" host = new_host(listen_addrs=['/ip4/127.0.0.1/tcp/8080'])")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Run all demos."""
|
||||||
|
print("🎉 Py-libp2p Transport Integration Demo")
|
||||||
|
print("=" * 60)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Run all demos
|
||||||
|
demo_transport_registry()
|
||||||
|
demo_transport_factory()
|
||||||
|
demo_multiaddr_transport_selection()
|
||||||
|
demo_custom_transport_registration()
|
||||||
|
demo_integration_with_libp2p()
|
||||||
|
|
||||||
|
print("🎯 Summary of New Features:")
|
||||||
|
print("=" * 40)
|
||||||
|
print("✅ Transport Registry: Central registry for all transport implementations")
|
||||||
|
print("✅ Dynamic Transport Selection: Automatic selection based on multiaddr")
|
||||||
|
print("✅ WebSocket Support: Full /ws protocol support")
|
||||||
|
print("✅ Extensible Architecture: Easy to add new transport protocols")
|
||||||
|
print("✅ Backward Compatibility: Existing TCP code continues to work")
|
||||||
|
print("✅ Factory Functions: Simple API for creating transports")
|
||||||
|
print()
|
||||||
|
print("🚀 The transport system is now ready for production use!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n👋 Demo interrupted by user")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Demo failed with error: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
220
examples/websocket/test_tcp_echo.py
Normal file
220
examples/websocket/test_tcp_echo.py
Normal file
@ -0,0 +1,220 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Simple TCP echo demo to verify basic libp2p functionality.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import multiaddr
|
||||||
|
import trio
|
||||||
|
|
||||||
|
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||||
|
from libp2p.custom_types import TProtocol
|
||||||
|
from libp2p.host.basic_host import BasicHost
|
||||||
|
from libp2p.network.swarm import Swarm
|
||||||
|
from libp2p.peer.id import ID
|
||||||
|
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||||
|
from libp2p.peer.peerstore import PeerStore
|
||||||
|
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||||
|
from libp2p.stream_muxer.yamux.yamux import Yamux
|
||||||
|
from libp2p.transport.tcp.tcp import TCP
|
||||||
|
from libp2p.transport.upgrader import TransportUpgrader
|
||||||
|
|
||||||
|
# Enable debug logging
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
logger = logging.getLogger("libp2p.tcp-example")
|
||||||
|
|
||||||
|
# Simple echo protocol
|
||||||
|
ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||||
|
|
||||||
|
|
||||||
|
async def echo_handler(stream):
|
||||||
|
"""Simple echo handler that echoes back any data received."""
|
||||||
|
try:
|
||||||
|
data = await stream.read(1024)
|
||||||
|
if data:
|
||||||
|
message = data.decode("utf-8", errors="replace")
|
||||||
|
print(f"📥 Received: {message}")
|
||||||
|
print(f"📤 Echoing back: {message}")
|
||||||
|
await stream.write(data)
|
||||||
|
await stream.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Echo handler error: {e}")
|
||||||
|
await stream.close()
|
||||||
|
|
||||||
|
|
||||||
|
def create_tcp_host():
|
||||||
|
"""Create a host with TCP transport."""
|
||||||
|
# Create key pair and peer store
|
||||||
|
key_pair = create_new_key_pair()
|
||||||
|
peer_id = ID.from_pubkey(key_pair.public_key)
|
||||||
|
peer_store = PeerStore()
|
||||||
|
peer_store.add_key_pair(peer_id, key_pair)
|
||||||
|
|
||||||
|
# Create transport upgrader with plaintext security
|
||||||
|
upgrader = TransportUpgrader(
|
||||||
|
secure_transports_by_protocol={
|
||||||
|
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
|
||||||
|
},
|
||||||
|
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create TCP transport
|
||||||
|
transport = TCP()
|
||||||
|
|
||||||
|
# Create swarm and host
|
||||||
|
swarm = Swarm(peer_id, peer_store, upgrader, transport)
|
||||||
|
host = BasicHost(swarm)
|
||||||
|
|
||||||
|
return host
|
||||||
|
|
||||||
|
|
||||||
|
async def run(port: int, destination: str) -> None:
|
||||||
|
localhost_ip = "0.0.0.0"
|
||||||
|
|
||||||
|
if not destination:
|
||||||
|
# Create first host (listener) with TCP transport
|
||||||
|
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
host = create_tcp_host()
|
||||||
|
logger.debug("Created TCP host")
|
||||||
|
|
||||||
|
# Set up echo handler
|
||||||
|
host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler)
|
||||||
|
|
||||||
|
async with (
|
||||||
|
host.run(listen_addrs=[listen_addr]),
|
||||||
|
trio.open_nursery() as (nursery),
|
||||||
|
):
|
||||||
|
# Start the peer-store cleanup task
|
||||||
|
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||||
|
|
||||||
|
# Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client
|
||||||
|
# connections
|
||||||
|
addrs = host.get_addrs()
|
||||||
|
logger.debug(f"Host addresses: {addrs}")
|
||||||
|
if not addrs:
|
||||||
|
print("❌ Error: No addresses found for the host")
|
||||||
|
return
|
||||||
|
|
||||||
|
server_addr = str(addrs[0])
|
||||||
|
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
|
||||||
|
|
||||||
|
print("🌐 TCP Server Started Successfully!")
|
||||||
|
print("=" * 50)
|
||||||
|
print(f"📍 Server Address: {client_addr}")
|
||||||
|
print("🔧 Protocol: /echo/1.0.0")
|
||||||
|
print("🚀 Transport: TCP")
|
||||||
|
print()
|
||||||
|
print("📋 To test the connection, run this in another terminal:")
|
||||||
|
print(f" python test_tcp_echo.py -d {client_addr}")
|
||||||
|
print()
|
||||||
|
print("⏳ Waiting for incoming TCP connections...")
|
||||||
|
print("─" * 50)
|
||||||
|
|
||||||
|
await trio.sleep_forever()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error creating TCP server: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
return
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Create second host (dialer) with TCP transport
|
||||||
|
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create a single host for client operations
|
||||||
|
host = create_tcp_host()
|
||||||
|
|
||||||
|
# Start the host for client operations
|
||||||
|
async with (
|
||||||
|
host.run(listen_addrs=[listen_addr]),
|
||||||
|
trio.open_nursery() as (nursery),
|
||||||
|
):
|
||||||
|
# Start the peer-store cleanup task
|
||||||
|
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||||
|
maddr = multiaddr.Multiaddr(destination)
|
||||||
|
info = info_from_p2p_addr(maddr)
|
||||||
|
print("🔌 TCP Client Starting...")
|
||||||
|
print("=" * 40)
|
||||||
|
print(f"🎯 Target Peer: {info.peer_id}")
|
||||||
|
print(f"📍 Target Address: {destination}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
try:
|
||||||
|
print("🔗 Connecting to TCP server...")
|
||||||
|
await host.connect(info)
|
||||||
|
print("✅ Successfully connected to TCP server!")
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
print("\n❌ Connection Failed!")
|
||||||
|
print(f" Peer ID: {info.peer_id}")
|
||||||
|
print(f" Address: {destination}")
|
||||||
|
print(f" Error: {error_msg}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create a stream and send test data
|
||||||
|
try:
|
||||||
|
stream = await host.new_stream(info.peer_id, [ECHO_PROTOCOL_ID])
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Failed to create stream: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
print("🚀 Starting Echo Protocol Test...")
|
||||||
|
print("─" * 40)
|
||||||
|
|
||||||
|
# Send test data
|
||||||
|
test_message = b"Hello TCP Transport!"
|
||||||
|
print(f"📤 Sending message: {test_message.decode('utf-8')}")
|
||||||
|
await stream.write(test_message)
|
||||||
|
|
||||||
|
# Read response
|
||||||
|
print("⏳ Waiting for server response...")
|
||||||
|
response = await stream.read(1024)
|
||||||
|
print(f"📥 Received response: {response.decode('utf-8')}")
|
||||||
|
|
||||||
|
await stream.close()
|
||||||
|
|
||||||
|
print("─" * 40)
|
||||||
|
if response == test_message:
|
||||||
|
print("🎉 Echo test successful!")
|
||||||
|
print("✅ TCP transport is working perfectly!")
|
||||||
|
else:
|
||||||
|
print("❌ Echo test failed!")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Echo protocol error: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
print("✅ TCP demo completed successfully!")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error creating TCP client: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
description = "Simple TCP echo demo for libp2p"
|
||||||
|
parser = argparse.ArgumentParser(description=description)
|
||||||
|
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||||
|
parser.add_argument(
|
||||||
|
"-d", "--destination", type=str, help="destination multiaddr string"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
trio.run(run, args.port, args.destination)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
145
examples/websocket/test_websocket_transport.py
Normal file
145
examples/websocket/test_websocket_transport.py
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Simple test script to verify WebSocket transport functionality.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Add the libp2p directory to the path so we can import it
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
import multiaddr
|
||||||
|
|
||||||
|
from libp2p.transport import create_transport, create_transport_for_multiaddr
|
||||||
|
from libp2p.transport.upgrader import TransportUpgrader
|
||||||
|
|
||||||
|
# Set up logging
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_websocket_transport():
|
||||||
|
"""Test basic WebSocket transport functionality."""
|
||||||
|
print("🧪 Testing WebSocket Transport Functionality")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# Create a dummy upgrader
|
||||||
|
upgrader = TransportUpgrader({}, {})
|
||||||
|
|
||||||
|
# Test creating WebSocket transport
|
||||||
|
try:
|
||||||
|
ws_transport = create_transport("ws", upgrader)
|
||||||
|
print(f"✅ WebSocket transport created: {type(ws_transport).__name__}")
|
||||||
|
|
||||||
|
# Test creating transport from multiaddr
|
||||||
|
ws_maddr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
|
||||||
|
ws_transport_from_maddr = create_transport_for_multiaddr(ws_maddr, upgrader)
|
||||||
|
print(
|
||||||
|
f"✅ WebSocket transport from multiaddr: "
|
||||||
|
f"{type(ws_transport_from_maddr).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test creating listener
|
||||||
|
handler_called = False
|
||||||
|
|
||||||
|
async def test_handler(conn):
|
||||||
|
nonlocal handler_called
|
||||||
|
handler_called = True
|
||||||
|
print(f"✅ Connection handler called with: {type(conn).__name__}")
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
listener = ws_transport.create_listener(test_handler)
|
||||||
|
print(f"✅ WebSocket listener created: {type(listener).__name__}")
|
||||||
|
|
||||||
|
# Test that the transport can be used
|
||||||
|
print(
|
||||||
|
f"✅ WebSocket transport supports dialing: {hasattr(ws_transport, 'dial')}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"✅ WebSocket transport supports listening: "
|
||||||
|
f"{hasattr(ws_transport, 'create_listener')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n🎯 WebSocket Transport Test Results:")
|
||||||
|
print("✅ Transport creation: PASS")
|
||||||
|
print("✅ Multiaddr parsing: PASS")
|
||||||
|
print("✅ Listener creation: PASS")
|
||||||
|
print("✅ Interface compliance: PASS")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ WebSocket transport test failed: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def test_transport_registry():
|
||||||
|
"""Test the transport registry functionality."""
|
||||||
|
print("\n🔧 Testing Transport Registry")
|
||||||
|
print("=" * 30)
|
||||||
|
|
||||||
|
from libp2p.transport import (
|
||||||
|
get_supported_transport_protocols,
|
||||||
|
get_transport_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
registry = get_transport_registry()
|
||||||
|
supported = get_supported_transport_protocols()
|
||||||
|
|
||||||
|
print(f"Supported protocols: {supported}")
|
||||||
|
|
||||||
|
# Test getting transports
|
||||||
|
for protocol in supported:
|
||||||
|
transport_class = registry.get_transport(protocol)
|
||||||
|
class_name = transport_class.__name__ if transport_class else "None"
|
||||||
|
print(f" {protocol}: {class_name}")
|
||||||
|
|
||||||
|
# Test creating transports through registry
|
||||||
|
upgrader = TransportUpgrader({}, {})
|
||||||
|
|
||||||
|
for protocol in supported:
|
||||||
|
try:
|
||||||
|
transport = registry.create_transport(protocol, upgrader)
|
||||||
|
if transport:
|
||||||
|
print(f"✅ {protocol}: Created successfully")
|
||||||
|
else:
|
||||||
|
print(f"❌ {protocol}: Failed to create")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ {protocol}: Error - {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Run all tests."""
|
||||||
|
print("🚀 WebSocket Transport Integration Test Suite")
|
||||||
|
print("=" * 60)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
success = await test_websocket_transport()
|
||||||
|
await test_transport_registry()
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
if success:
|
||||||
|
print("🎉 All tests passed! WebSocket transport is working correctly.")
|
||||||
|
else:
|
||||||
|
print("❌ Some tests failed. Check the output above for details.")
|
||||||
|
|
||||||
|
print("\n🚀 WebSocket transport is ready for use in py-libp2p!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n👋 Test interrupted by user")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Test failed with error: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
448
examples/websocket/websocket_demo.py
Normal file
448
examples/websocket/websocket_demo.py
Normal file
@ -0,0 +1,448 @@
|
|||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import multiaddr
|
||||||
|
import trio
|
||||||
|
|
||||||
|
from libp2p.abc import INotifee
|
||||||
|
from libp2p.crypto.ed25519 import create_new_key_pair as create_ed25519_key_pair
|
||||||
|
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||||
|
from libp2p.custom_types import TProtocol
|
||||||
|
from libp2p.host.basic_host import BasicHost
|
||||||
|
from libp2p.network.swarm import Swarm
|
||||||
|
from libp2p.peer.id import ID
|
||||||
|
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||||
|
from libp2p.peer.peerstore import PeerStore
|
||||||
|
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||||
|
from libp2p.security.noise.transport import (
|
||||||
|
PROTOCOL_ID as NOISE_PROTOCOL_ID,
|
||||||
|
Transport as NoiseTransport,
|
||||||
|
)
|
||||||
|
from libp2p.stream_muxer.yamux.yamux import Yamux
|
||||||
|
from libp2p.transport.upgrader import TransportUpgrader
|
||||||
|
from libp2p.transport.websocket.transport import WebsocketTransport
|
||||||
|
|
||||||
|
# Enable debug logging
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
logger = logging.getLogger("libp2p.websocket-example")
|
||||||
|
|
||||||
|
|
||||||
|
# Suppress KeyboardInterrupt by handling SIGINT directly
|
||||||
|
def signal_handler(signum, frame):
|
||||||
|
print("✅ Clean exit completed.")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
|
# Simple echo protocol
|
||||||
|
ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||||
|
|
||||||
|
|
||||||
|
async def echo_handler(stream):
|
||||||
|
"""Simple echo handler that echoes back any data received."""
|
||||||
|
try:
|
||||||
|
data = await stream.read(1024)
|
||||||
|
if data:
|
||||||
|
message = data.decode("utf-8", errors="replace")
|
||||||
|
print(f"📥 Received: {message}")
|
||||||
|
print(f"📤 Echoing back: {message}")
|
||||||
|
await stream.write(data)
|
||||||
|
await stream.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Echo handler error: {e}")
|
||||||
|
await stream.close()
|
||||||
|
|
||||||
|
|
||||||
|
def create_websocket_host(listen_addrs=None, use_plaintext=False):
|
||||||
|
"""Create a host with WebSocket transport."""
|
||||||
|
# Create key pair and peer store
|
||||||
|
key_pair = create_new_key_pair()
|
||||||
|
peer_id = ID.from_pubkey(key_pair.public_key)
|
||||||
|
peer_store = PeerStore()
|
||||||
|
peer_store.add_key_pair(peer_id, key_pair)
|
||||||
|
|
||||||
|
if use_plaintext:
|
||||||
|
# Create transport upgrader with plaintext security
|
||||||
|
upgrader = TransportUpgrader(
|
||||||
|
secure_transports_by_protocol={
|
||||||
|
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
|
||||||
|
},
|
||||||
|
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Create separate Ed25519 key for Noise protocol
|
||||||
|
noise_key_pair = create_ed25519_key_pair()
|
||||||
|
|
||||||
|
# Create Noise transport
|
||||||
|
noise_transport = NoiseTransport(
|
||||||
|
libp2p_keypair=key_pair,
|
||||||
|
noise_privkey=noise_key_pair.private_key,
|
||||||
|
early_data=None,
|
||||||
|
with_noise_pipes=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create transport upgrader with Noise security
|
||||||
|
upgrader = TransportUpgrader(
|
||||||
|
secure_transports_by_protocol={
|
||||||
|
TProtocol(NOISE_PROTOCOL_ID): noise_transport
|
||||||
|
},
|
||||||
|
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create WebSocket transport
|
||||||
|
transport = WebsocketTransport(upgrader)
|
||||||
|
|
||||||
|
# Create swarm and host
|
||||||
|
swarm = Swarm(peer_id, peer_store, upgrader, transport)
|
||||||
|
host = BasicHost(swarm)
|
||||||
|
|
||||||
|
return host
|
||||||
|
|
||||||
|
|
||||||
|
async def run(port: int, destination: str, use_plaintext: bool = False) -> None:
|
||||||
|
localhost_ip = "0.0.0.0"
|
||||||
|
|
||||||
|
if not destination:
|
||||||
|
# Create first host (listener) with WebSocket transport
|
||||||
|
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}/ws")
|
||||||
|
|
||||||
|
try:
|
||||||
|
host = create_websocket_host(use_plaintext=use_plaintext)
|
||||||
|
logger.debug(f"Created host with use_plaintext={use_plaintext}")
|
||||||
|
|
||||||
|
# Set up echo handler
|
||||||
|
host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler)
|
||||||
|
|
||||||
|
# Add connection event handlers for debugging
|
||||||
|
class DebugNotifee(INotifee):
|
||||||
|
async def opened_stream(self, network, stream):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def closed_stream(self, network, stream):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def connected(self, network, conn):
|
||||||
|
print(
|
||||||
|
f"🔗 New libp2p connection established: "
|
||||||
|
f"{conn.muxed_conn.peer_id}"
|
||||||
|
)
|
||||||
|
if hasattr(conn.muxed_conn, "get_security_protocol"):
|
||||||
|
security = conn.muxed_conn.get_security_protocol()
|
||||||
|
else:
|
||||||
|
security = "Unknown"
|
||||||
|
|
||||||
|
print(f" Security: {security}")
|
||||||
|
|
||||||
|
async def disconnected(self, network, conn):
|
||||||
|
print(f"🔌 libp2p connection closed: {conn.muxed_conn.peer_id}")
|
||||||
|
|
||||||
|
async def listen(self, network, multiaddr):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def listen_close(self, network, multiaddr):
|
||||||
|
pass
|
||||||
|
|
||||||
|
host.get_network().register_notifee(DebugNotifee())
|
||||||
|
|
||||||
|
# Create a cancellation token for clean shutdown
|
||||||
|
cancel_scope = trio.CancelScope()
|
||||||
|
|
||||||
|
async def signal_handler():
|
||||||
|
with trio.open_signal_receiver(signal.SIGINT, signal.SIGTERM) as (
|
||||||
|
signal_receiver
|
||||||
|
):
|
||||||
|
async for sig in signal_receiver:
|
||||||
|
print(f"\n🛑 Received signal {sig}")
|
||||||
|
print("✅ Shutting down WebSocket server...")
|
||||||
|
cancel_scope.cancel()
|
||||||
|
return
|
||||||
|
|
||||||
|
async with (
|
||||||
|
host.run(listen_addrs=[listen_addr]),
|
||||||
|
trio.open_nursery() as (nursery),
|
||||||
|
):
|
||||||
|
# Start the peer-store cleanup task
|
||||||
|
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||||
|
|
||||||
|
# Start the signal handler
|
||||||
|
nursery.start_soon(signal_handler)
|
||||||
|
|
||||||
|
# Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client
|
||||||
|
# connections
|
||||||
|
addrs = host.get_addrs()
|
||||||
|
logger.debug(f"Host addresses: {addrs}")
|
||||||
|
if not addrs:
|
||||||
|
print("❌ Error: No addresses found for the host")
|
||||||
|
print("Debug: host.get_addrs() returned empty list")
|
||||||
|
return
|
||||||
|
|
||||||
|
server_addr = str(addrs[0])
|
||||||
|
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
|
||||||
|
|
||||||
|
print("🌐 WebSocket Server Started Successfully!")
|
||||||
|
print("=" * 50)
|
||||||
|
print(f"📍 Server Address: {client_addr}")
|
||||||
|
print("🔧 Protocol: /echo/1.0.0")
|
||||||
|
print("🚀 Transport: WebSocket (/ws)")
|
||||||
|
print()
|
||||||
|
print("📋 To test the connection, run this in another terminal:")
|
||||||
|
plaintext_flag = " --plaintext" if use_plaintext else ""
|
||||||
|
print(f" python websocket_demo.py -d {client_addr}{plaintext_flag}")
|
||||||
|
print()
|
||||||
|
print("⏳ Waiting for incoming WebSocket connections...")
|
||||||
|
print("─" * 50)
|
||||||
|
|
||||||
|
# Add a custom handler to show connection events
|
||||||
|
async def custom_echo_handler(stream):
|
||||||
|
peer_id = stream.muxed_conn.peer_id
|
||||||
|
print("\n🔗 New WebSocket Connection!")
|
||||||
|
print(f" Peer ID: {peer_id}")
|
||||||
|
print(" Protocol: /echo/1.0.0")
|
||||||
|
|
||||||
|
# Show remote address in multiaddr format
|
||||||
|
try:
|
||||||
|
remote_address = stream.get_remote_address()
|
||||||
|
if remote_address:
|
||||||
|
print(f" Remote: {remote_address}")
|
||||||
|
except Exception:
|
||||||
|
print(" Remote: Unknown")
|
||||||
|
|
||||||
|
print(" ─" * 40)
|
||||||
|
|
||||||
|
# Call the original handler
|
||||||
|
await echo_handler(stream)
|
||||||
|
|
||||||
|
print(" ─" * 40)
|
||||||
|
print(f"✅ Echo request completed for peer: {peer_id}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Replace the handler with our custom one
|
||||||
|
host.set_stream_handler(ECHO_PROTOCOL_ID, custom_echo_handler)
|
||||||
|
|
||||||
|
# Wait indefinitely or until cancelled
|
||||||
|
with cancel_scope:
|
||||||
|
await trio.sleep_forever()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error creating WebSocket server: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
return
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Create second host (dialer) with WebSocket transport
|
||||||
|
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}/ws")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create a single host for client operations
|
||||||
|
host = create_websocket_host(use_plaintext=use_plaintext)
|
||||||
|
|
||||||
|
# Start the host for client operations
|
||||||
|
async with (
|
||||||
|
host.run(listen_addrs=[listen_addr]),
|
||||||
|
trio.open_nursery() as (nursery),
|
||||||
|
):
|
||||||
|
# Start the peer-store cleanup task
|
||||||
|
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||||
|
|
||||||
|
# Add connection event handlers for debugging
|
||||||
|
class ClientDebugNotifee(INotifee):
|
||||||
|
async def opened_stream(self, network, stream):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def closed_stream(self, network, stream):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def connected(self, network, conn):
|
||||||
|
print(
|
||||||
|
f"🔗 Client: libp2p connection established: "
|
||||||
|
f"{conn.muxed_conn.peer_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def disconnected(self, network, conn):
|
||||||
|
print(
|
||||||
|
f"🔌 Client: libp2p connection closed: "
|
||||||
|
f"{conn.muxed_conn.peer_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def listen(self, network, multiaddr):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def listen_close(self, network, multiaddr):
|
||||||
|
pass
|
||||||
|
|
||||||
|
host.get_network().register_notifee(ClientDebugNotifee())
|
||||||
|
|
||||||
|
maddr = multiaddr.Multiaddr(destination)
|
||||||
|
info = info_from_p2p_addr(maddr)
|
||||||
|
print("🔌 WebSocket Client Starting...")
|
||||||
|
print("=" * 40)
|
||||||
|
print(f"🎯 Target Peer: {info.peer_id}")
|
||||||
|
print(f"📍 Target Address: {destination}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
try:
|
||||||
|
print("🔗 Connecting to WebSocket server...")
|
||||||
|
print(f" Security: {'Plaintext' if use_plaintext else 'Noise'}")
|
||||||
|
await host.connect(info)
|
||||||
|
print("✅ Successfully connected to WebSocket server!")
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
print("\n❌ Connection Failed!")
|
||||||
|
print(f" Peer ID: {info.peer_id}")
|
||||||
|
print(f" Address: {destination}")
|
||||||
|
print(f" Security: {'Plaintext' if use_plaintext else 'Noise'}")
|
||||||
|
print(f" Error: {error_msg}")
|
||||||
|
print(f" Error type: {type(e).__name__}")
|
||||||
|
|
||||||
|
# Add more detailed error information for debugging
|
||||||
|
if hasattr(e, "__cause__") and e.__cause__:
|
||||||
|
print(f" Root cause: {e.__cause__}")
|
||||||
|
print(f" Root cause type: {type(e.__cause__).__name__}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("💡 Troubleshooting:")
|
||||||
|
print(" • Make sure the WebSocket server is running")
|
||||||
|
print(" • Check that the server address is correct")
|
||||||
|
print(" • Verify the server is listening on the right port")
|
||||||
|
print(
|
||||||
|
" • Ensure both client and server use the same sec protocol"
|
||||||
|
)
|
||||||
|
if not use_plaintext:
|
||||||
|
print(" • Noise over WebSocket may have compatibility issues")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create a stream and send test data
|
||||||
|
try:
|
||||||
|
stream = await host.new_stream(info.peer_id, [ECHO_PROTOCOL_ID])
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Failed to create stream: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
print("🚀 Starting Echo Protocol Test...")
|
||||||
|
print("─" * 40)
|
||||||
|
|
||||||
|
# Send test data
|
||||||
|
test_message = b"Hello WebSocket Transport!"
|
||||||
|
print(f"📤 Sending message: {test_message.decode('utf-8')}")
|
||||||
|
await stream.write(test_message)
|
||||||
|
|
||||||
|
# Read response
|
||||||
|
print("⏳ Waiting for server response...")
|
||||||
|
response = await stream.read(1024)
|
||||||
|
print(f"📥 Received response: {response.decode('utf-8')}")
|
||||||
|
|
||||||
|
await stream.close()
|
||||||
|
|
||||||
|
print("─" * 40)
|
||||||
|
if response == test_message:
|
||||||
|
print("🎉 Echo test successful!")
|
||||||
|
print("✅ WebSocket transport is working perfectly!")
|
||||||
|
print("✅ Client completed successfully, exiting.")
|
||||||
|
else:
|
||||||
|
print("❌ Echo test failed!")
|
||||||
|
print(" Response doesn't match sent data.")
|
||||||
|
print(f" Sent: {test_message}")
|
||||||
|
print(f" Received: {response}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
print(f"Echo protocol error: {error_msg}")
|
||||||
|
traceback.print_exc()
|
||||||
|
finally:
|
||||||
|
# Ensure stream is closed
|
||||||
|
try:
|
||||||
|
if stream:
|
||||||
|
# Check if stream has is_closed method and use it
|
||||||
|
has_is_closed = hasattr(stream, "is_closed") and callable(
|
||||||
|
getattr(stream, "is_closed")
|
||||||
|
)
|
||||||
|
if has_is_closed:
|
||||||
|
# type: ignore[attr-defined]
|
||||||
|
if not await stream.is_closed():
|
||||||
|
await stream.close()
|
||||||
|
else:
|
||||||
|
# Fallback: just try to close the stream
|
||||||
|
await stream.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# host.run() context manager handles cleanup automatically
|
||||||
|
print()
|
||||||
|
print("🎉 WebSocket Demo Completed Successfully!")
|
||||||
|
print("=" * 50)
|
||||||
|
print("✅ WebSocket transport is working perfectly!")
|
||||||
|
print("✅ Echo protocol communication successful!")
|
||||||
|
print("✅ libp2p integration verified!")
|
||||||
|
print()
|
||||||
|
print("🚀 Your WebSocket transport is ready for production use!")
|
||||||
|
|
||||||
|
# Add a small delay to ensure all cleanup is complete
|
||||||
|
await trio.sleep(0.1)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error creating WebSocket client: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
description = """
|
||||||
|
This program demonstrates the libp2p WebSocket transport.
|
||||||
|
First run
|
||||||
|
'python websocket_demo.py -p <PORT> [--plaintext]' to start a WebSocket server.
|
||||||
|
Then run
|
||||||
|
'python websocket_demo.py <ANOTHER_PORT> -d <DESTINATION> [--plaintext]'
|
||||||
|
where <DESTINATION> is the multiaddress shown by the server.
|
||||||
|
|
||||||
|
By default, this example uses Noise encryption for secure communication.
|
||||||
|
Use --plaintext for testing with unencrypted communication
|
||||||
|
(not recommended for production).
|
||||||
|
"""
|
||||||
|
|
||||||
|
example_maddr = (
|
||||||
|
"/ip4/127.0.0.1/tcp/8888/ws/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description=description)
|
||||||
|
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||||
|
parser.add_argument(
|
||||||
|
"-d",
|
||||||
|
"--destination",
|
||||||
|
type=str,
|
||||||
|
help=f"destination multiaddr string, e.g. {example_maddr}",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--plaintext",
|
||||||
|
action="store_true",
|
||||||
|
help=(
|
||||||
|
"use plaintext security instead of Noise encryption "
|
||||||
|
"(not recommended for production)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Determine security mode: use Noise by default,
|
||||||
|
# plaintext if --plaintext is specified
|
||||||
|
use_plaintext = args.plaintext
|
||||||
|
|
||||||
|
try:
|
||||||
|
trio.run(run, args.port, args.destination, use_plaintext)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
# This is expected when Ctrl+C is pressed
|
||||||
|
# The signal handler already printed the shutdown message
|
||||||
|
print("✅ Clean exit completed.")
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Unexpected error: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -1,5 +1,12 @@
|
|||||||
"""Libp2p Python implementation."""
|
"""Libp2p Python implementation."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import ssl
|
||||||
|
|
||||||
|
from libp2p.transport.quic.utils import is_quic_multiaddr
|
||||||
|
from typing import Any
|
||||||
|
from libp2p.transport.quic.transport import QUICTransport
|
||||||
|
from libp2p.transport.quic.config import QUICTransportConfig
|
||||||
from collections.abc import (
|
from collections.abc import (
|
||||||
Mapping,
|
Mapping,
|
||||||
Sequence,
|
Sequence,
|
||||||
@ -18,6 +25,7 @@ from libp2p.abc import (
|
|||||||
IPeerRouting,
|
IPeerRouting,
|
||||||
IPeerStore,
|
IPeerStore,
|
||||||
ISecureTransport,
|
ISecureTransport,
|
||||||
|
ITransport,
|
||||||
)
|
)
|
||||||
from libp2p.crypto.keys import (
|
from libp2p.crypto.keys import (
|
||||||
KeyPair,
|
KeyPair,
|
||||||
@ -38,10 +46,12 @@ from libp2p.host.routed_host import (
|
|||||||
RoutedHost,
|
RoutedHost,
|
||||||
)
|
)
|
||||||
from libp2p.network.swarm import (
|
from libp2p.network.swarm import (
|
||||||
ConnectionConfig,
|
|
||||||
RetryConfig,
|
|
||||||
Swarm,
|
Swarm,
|
||||||
)
|
)
|
||||||
|
from libp2p.network.config import (
|
||||||
|
ConnectionConfig,
|
||||||
|
RetryConfig
|
||||||
|
)
|
||||||
from libp2p.peer.id import (
|
from libp2p.peer.id import (
|
||||||
ID,
|
ID,
|
||||||
)
|
)
|
||||||
@ -72,6 +82,10 @@ from libp2p.transport.tcp.tcp import (
|
|||||||
from libp2p.transport.upgrader import (
|
from libp2p.transport.upgrader import (
|
||||||
TransportUpgrader,
|
TransportUpgrader,
|
||||||
)
|
)
|
||||||
|
from libp2p.transport.transport_registry import (
|
||||||
|
create_transport_for_multiaddr,
|
||||||
|
get_supported_transport_protocols,
|
||||||
|
)
|
||||||
from libp2p.utils.logging import (
|
from libp2p.utils.logging import (
|
||||||
setup_logging,
|
setup_logging,
|
||||||
)
|
)
|
||||||
@ -87,6 +101,7 @@ MUXER_YAMUX = "YAMUX"
|
|||||||
MUXER_MPLEX = "MPLEX"
|
MUXER_MPLEX = "MPLEX"
|
||||||
DEFAULT_NEGOTIATE_TIMEOUT = 5
|
DEFAULT_NEGOTIATE_TIMEOUT = 5
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None:
|
def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None:
|
||||||
"""
|
"""
|
||||||
@ -162,9 +177,13 @@ def new_swarm(
|
|||||||
peerstore_opt: IPeerStore | None = None,
|
peerstore_opt: IPeerStore | None = None,
|
||||||
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
|
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
|
||||||
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
|
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
|
||||||
|
enable_quic: bool = False,
|
||||||
retry_config: Optional["RetryConfig"] = None,
|
retry_config: Optional["RetryConfig"] = None,
|
||||||
connection_config: Optional["ConnectionConfig"] = None,
|
connection_config: ConnectionConfig | QUICTransportConfig | None = None,
|
||||||
|
tls_client_config: ssl.SSLContext | None = None,
|
||||||
|
tls_server_config: ssl.SSLContext | None = None,
|
||||||
) -> INetworkService:
|
) -> INetworkService:
|
||||||
|
logger.debug(f"new_swarm: enable_quic={enable_quic}, listen_addrs={listen_addrs}")
|
||||||
"""
|
"""
|
||||||
Create a swarm instance based on the parameters.
|
Create a swarm instance based on the parameters.
|
||||||
|
|
||||||
@ -174,6 +193,8 @@ def new_swarm(
|
|||||||
:param peerstore_opt: optional peerstore
|
:param peerstore_opt: optional peerstore
|
||||||
:param muxer_preference: optional explicit muxer preference
|
:param muxer_preference: optional explicit muxer preference
|
||||||
:param listen_addrs: optional list of multiaddrs to listen on
|
:param listen_addrs: optional list of multiaddrs to listen on
|
||||||
|
:param enable_quic: enable quic for transport
|
||||||
|
:param quic_transport_opt: options for transport
|
||||||
:return: return a default swarm instance
|
:return: return a default swarm instance
|
||||||
|
|
||||||
Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer
|
Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer
|
||||||
@ -186,16 +207,48 @@ def new_swarm(
|
|||||||
|
|
||||||
id_opt = generate_peer_id_from(key_pair)
|
id_opt = generate_peer_id_from(key_pair)
|
||||||
|
|
||||||
|
transport: TCP | QUICTransport | ITransport
|
||||||
|
quic_transport_opt = connection_config if isinstance(connection_config, QUICTransportConfig) else None
|
||||||
|
|
||||||
if listen_addrs is None:
|
if listen_addrs is None:
|
||||||
transport = TCP()
|
if enable_quic:
|
||||||
else:
|
transport = QUICTransport(key_pair.private_key, config=quic_transport_opt)
|
||||||
addr = listen_addrs[0]
|
|
||||||
if addr.__contains__("tcp"):
|
|
||||||
transport = TCP()
|
|
||||||
elif addr.__contains__("quic"):
|
|
||||||
raise ValueError("QUIC not yet supported")
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown transport in listen_addrs: {listen_addrs}")
|
transport = TCP()
|
||||||
|
else:
|
||||||
|
# Use transport registry to select the appropriate transport
|
||||||
|
from libp2p.transport.transport_registry import create_transport_for_multiaddr
|
||||||
|
|
||||||
|
# Create a temporary upgrader for transport selection
|
||||||
|
# We'll create the real upgrader later with the proper configuration
|
||||||
|
temp_upgrader = TransportUpgrader(
|
||||||
|
secure_transports_by_protocol={},
|
||||||
|
muxer_transports_by_protocol={}
|
||||||
|
)
|
||||||
|
|
||||||
|
addr = listen_addrs[0]
|
||||||
|
logger.debug(f"new_swarm: Creating transport for address: {addr}")
|
||||||
|
transport_maybe = create_transport_for_multiaddr(
|
||||||
|
addr,
|
||||||
|
temp_upgrader,
|
||||||
|
private_key=key_pair.private_key,
|
||||||
|
config=quic_transport_opt,
|
||||||
|
tls_client_config=tls_client_config,
|
||||||
|
tls_server_config=tls_server_config
|
||||||
|
)
|
||||||
|
|
||||||
|
if transport_maybe is None:
|
||||||
|
raise ValueError(f"Unsupported transport for listen_addrs: {listen_addrs}")
|
||||||
|
|
||||||
|
transport = transport_maybe
|
||||||
|
logger.debug(f"new_swarm: Created transport: {type(transport)}")
|
||||||
|
|
||||||
|
# If enable_quic is True but we didn't get a QUIC transport, force QUIC
|
||||||
|
if enable_quic and not isinstance(transport, QUICTransport):
|
||||||
|
logger.debug(f"new_swarm: Forcing QUIC transport (enable_quic=True but got {type(transport)})")
|
||||||
|
transport = QUICTransport(key_pair.private_key, config=quic_transport_opt)
|
||||||
|
|
||||||
|
logger.debug(f"new_swarm: Final transport type: {type(transport)}")
|
||||||
|
|
||||||
# Generate X25519 keypair for Noise
|
# Generate X25519 keypair for Noise
|
||||||
noise_key_pair = create_new_x25519_key_pair()
|
noise_key_pair = create_new_x25519_key_pair()
|
||||||
@ -236,6 +289,7 @@ def new_swarm(
|
|||||||
muxer_transports_by_protocol=muxer_transports_by_protocol,
|
muxer_transports_by_protocol=muxer_transports_by_protocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
peerstore = peerstore_opt or PeerStore()
|
peerstore = peerstore_opt or PeerStore()
|
||||||
# Store our key pair in peerstore
|
# Store our key pair in peerstore
|
||||||
peerstore.add_key_pair(id_opt, key_pair)
|
peerstore.add_key_pair(id_opt, key_pair)
|
||||||
@ -261,6 +315,10 @@ def new_host(
|
|||||||
enable_mDNS: bool = False,
|
enable_mDNS: bool = False,
|
||||||
bootstrap: list[str] | None = None,
|
bootstrap: list[str] | None = None,
|
||||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||||
|
enable_quic: bool = False,
|
||||||
|
quic_transport_opt: QUICTransportConfig | None = None,
|
||||||
|
tls_client_config: ssl.SSLContext | None = None,
|
||||||
|
tls_server_config: ssl.SSLContext | None = None,
|
||||||
) -> IHost:
|
) -> IHost:
|
||||||
"""
|
"""
|
||||||
Create a new libp2p host based on the given parameters.
|
Create a new libp2p host based on the given parameters.
|
||||||
@ -274,15 +332,27 @@ def new_host(
|
|||||||
:param listen_addrs: optional list of multiaddrs to listen on
|
:param listen_addrs: optional list of multiaddrs to listen on
|
||||||
:param enable_mDNS: whether to enable mDNS discovery
|
:param enable_mDNS: whether to enable mDNS discovery
|
||||||
:param bootstrap: optional list of bootstrap peer addresses as strings
|
:param bootstrap: optional list of bootstrap peer addresses as strings
|
||||||
|
:param enable_quic: optinal choice to use QUIC for transport
|
||||||
|
:param quic_transport_opt: optional configuration for quic transport
|
||||||
|
:param tls_client_config: optional TLS client configuration for WebSocket transport
|
||||||
|
:param tls_server_config: optional TLS server configuration for WebSocket transport
|
||||||
:return: return a host instance
|
:return: return a host instance
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if not enable_quic and quic_transport_opt is not None:
|
||||||
|
logger.warning(f"QUIC config provided but QUIC not enabled, ignoring QUIC config")
|
||||||
|
|
||||||
swarm = new_swarm(
|
swarm = new_swarm(
|
||||||
|
enable_quic=enable_quic,
|
||||||
key_pair=key_pair,
|
key_pair=key_pair,
|
||||||
muxer_opt=muxer_opt,
|
muxer_opt=muxer_opt,
|
||||||
sec_opt=sec_opt,
|
sec_opt=sec_opt,
|
||||||
peerstore_opt=peerstore_opt,
|
peerstore_opt=peerstore_opt,
|
||||||
muxer_preference=muxer_preference,
|
muxer_preference=muxer_preference,
|
||||||
listen_addrs=listen_addrs,
|
listen_addrs=listen_addrs,
|
||||||
|
connection_config=quic_transport_opt if enable_quic else None,
|
||||||
|
tls_client_config=tls_client_config,
|
||||||
|
tls_server_config=tls_server_config
|
||||||
)
|
)
|
||||||
|
|
||||||
if disc_opt is not None:
|
if disc_opt is not None:
|
||||||
|
|||||||
@ -5,17 +5,17 @@ from collections.abc import (
|
|||||||
)
|
)
|
||||||
from typing import TYPE_CHECKING, NewType, Union, cast
|
from typing import TYPE_CHECKING, NewType, Union, cast
|
||||||
|
|
||||||
|
from libp2p.transport.quic.stream import QUICStream
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from libp2p.abc import (
|
from libp2p.abc import IMuxedConn, IMuxedStream, INetStream, ISecureTransport
|
||||||
IMuxedConn,
|
from libp2p.transport.quic.connection import QUICConnection
|
||||||
INetStream,
|
|
||||||
ISecureTransport,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
IMuxedConn = cast(type, object)
|
IMuxedConn = cast(type, object)
|
||||||
INetStream = cast(type, object)
|
INetStream = cast(type, object)
|
||||||
ISecureTransport = cast(type, object)
|
ISecureTransport = cast(type, object)
|
||||||
|
IMuxedStream = cast(type, object)
|
||||||
|
QUICConnection = cast(type, object)
|
||||||
|
|
||||||
from libp2p.io.abc import (
|
from libp2p.io.abc import (
|
||||||
ReadWriteCloser,
|
ReadWriteCloser,
|
||||||
@ -37,3 +37,6 @@ SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
|
|||||||
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
|
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
|
||||||
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
|
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
|
||||||
UnsubscribeFn = Callable[[], Awaitable[None]]
|
UnsubscribeFn = Callable[[], Awaitable[None]]
|
||||||
|
TQUICStreamHandlerFn = Callable[[QUICStream], Awaitable[None]]
|
||||||
|
TQUICConnHandlerFn = Callable[[QUICConnection], Awaitable[None]]
|
||||||
|
MessageID = NewType("MessageID", str)
|
||||||
|
|||||||
@ -2,15 +2,20 @@ import logging
|
|||||||
|
|
||||||
from multiaddr import Multiaddr
|
from multiaddr import Multiaddr
|
||||||
from multiaddr.resolvers import DNSResolver
|
from multiaddr.resolvers import DNSResolver
|
||||||
|
import trio
|
||||||
|
|
||||||
from libp2p.abc import ID, INetworkService, PeerInfo
|
from libp2p.abc import ID, INetworkService, PeerInfo
|
||||||
from libp2p.discovery.bootstrap.utils import validate_bootstrap_addresses
|
from libp2p.discovery.bootstrap.utils import validate_bootstrap_addresses
|
||||||
from libp2p.discovery.events.peerDiscovery import peerDiscovery
|
from libp2p.discovery.events.peerDiscovery import peerDiscovery
|
||||||
|
from libp2p.network.exceptions import SwarmException
|
||||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||||
|
from libp2p.peer.peerstore import PERMANENT_ADDR_TTL
|
||||||
|
|
||||||
logger = logging.getLogger("libp2p.discovery.bootstrap")
|
logger = logging.getLogger("libp2p.discovery.bootstrap")
|
||||||
resolver = DNSResolver()
|
resolver = DNSResolver()
|
||||||
|
|
||||||
|
DEFAULT_CONNECTION_TIMEOUT = 10
|
||||||
|
|
||||||
|
|
||||||
class BootstrapDiscovery:
|
class BootstrapDiscovery:
|
||||||
"""
|
"""
|
||||||
@ -19,68 +24,147 @@ class BootstrapDiscovery:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, swarm: INetworkService, bootstrap_addrs: list[str]):
|
def __init__(self, swarm: INetworkService, bootstrap_addrs: list[str]):
|
||||||
|
"""
|
||||||
|
Initialize BootstrapDiscovery.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
swarm: The network service (swarm) instance
|
||||||
|
bootstrap_addrs: List of bootstrap peer multiaddresses
|
||||||
|
|
||||||
|
"""
|
||||||
self.swarm = swarm
|
self.swarm = swarm
|
||||||
self.peerstore = swarm.peerstore
|
self.peerstore = swarm.peerstore
|
||||||
self.bootstrap_addrs = bootstrap_addrs or []
|
self.bootstrap_addrs = bootstrap_addrs or []
|
||||||
self.discovered_peers: set[str] = set()
|
self.discovered_peers: set[str] = set()
|
||||||
|
self.connection_timeout: int = DEFAULT_CONNECTION_TIMEOUT
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Process bootstrap addresses and emit peer discovery events."""
|
"""Process bootstrap addresses and emit peer discovery events in parallel."""
|
||||||
logger.debug(
|
logger.info(
|
||||||
f"Starting bootstrap discovery with "
|
f"Starting bootstrap discovery with "
|
||||||
f"{len(self.bootstrap_addrs)} bootstrap addresses"
|
f"{len(self.bootstrap_addrs)} bootstrap addresses"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Show all bootstrap addresses being processed
|
||||||
|
for i, addr in enumerate(self.bootstrap_addrs):
|
||||||
|
logger.debug(f"{i + 1}. {addr}")
|
||||||
|
|
||||||
# Validate and filter bootstrap addresses
|
# Validate and filter bootstrap addresses
|
||||||
self.bootstrap_addrs = validate_bootstrap_addresses(self.bootstrap_addrs)
|
self.bootstrap_addrs = validate_bootstrap_addresses(self.bootstrap_addrs)
|
||||||
|
logger.info(f"Valid addresses after validation: {len(self.bootstrap_addrs)}")
|
||||||
|
|
||||||
for addr_str in self.bootstrap_addrs:
|
# Use Trio nursery for PARALLEL address processing
|
||||||
try:
|
try:
|
||||||
await self._process_bootstrap_addr(addr_str)
|
async with trio.open_nursery() as nursery:
|
||||||
except Exception as e:
|
logger.debug(
|
||||||
logger.debug(f"Failed to process bootstrap address {addr_str}: {e}")
|
f"Starting {len(self.bootstrap_addrs)} parallel address "
|
||||||
|
f"processing tasks"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start all bootstrap address processing tasks in parallel
|
||||||
|
for addr_str in self.bootstrap_addrs:
|
||||||
|
logger.debug(f"Starting parallel task for: {addr_str}")
|
||||||
|
nursery.start_soon(self._process_bootstrap_addr, addr_str)
|
||||||
|
|
||||||
|
# The nursery will wait for all address processing tasks to complete
|
||||||
|
logger.debug(
|
||||||
|
"Nursery active - waiting for address processing tasks to complete"
|
||||||
|
)
|
||||||
|
|
||||||
|
except trio.Cancelled:
|
||||||
|
logger.debug("Bootstrap address processing cancelled - cleaning up tasks")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Bootstrap address processing failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
logger.info("Bootstrap discovery startup complete - all tasks finished")
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
"""Clean up bootstrap discovery resources."""
|
"""Clean up bootstrap discovery resources."""
|
||||||
logger.debug("Stopping bootstrap discovery")
|
logger.info("Stopping bootstrap discovery and cleaning up tasks")
|
||||||
|
|
||||||
|
# Clear discovered peers
|
||||||
self.discovered_peers.clear()
|
self.discovered_peers.clear()
|
||||||
|
|
||||||
|
logger.debug("Bootstrap discovery cleanup completed")
|
||||||
|
|
||||||
async def _process_bootstrap_addr(self, addr_str: str) -> None:
|
async def _process_bootstrap_addr(self, addr_str: str) -> None:
|
||||||
"""Convert string address to PeerInfo and add to peerstore."""
|
"""Convert string address to PeerInfo and add to peerstore."""
|
||||||
try:
|
try:
|
||||||
multiaddr = Multiaddr(addr_str)
|
try:
|
||||||
|
multiaddr = Multiaddr(addr_str)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Invalid multiaddr format '{addr_str}': {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.is_dns_addr(multiaddr):
|
||||||
|
resolved_addrs = await resolver.resolve(multiaddr)
|
||||||
|
if resolved_addrs is None:
|
||||||
|
logger.warning(f"DNS resolution returned None for: {addr_str}")
|
||||||
|
return
|
||||||
|
|
||||||
|
peer_id_str = multiaddr.get_peer_id()
|
||||||
|
if peer_id_str is None:
|
||||||
|
logger.warning(f"Missing peer ID in DNS address: {addr_str}")
|
||||||
|
return
|
||||||
|
peer_id = ID.from_base58(peer_id_str)
|
||||||
|
addrs = [addr for addr in resolved_addrs]
|
||||||
|
if not addrs:
|
||||||
|
logger.warning(f"No addresses resolved for DNS address: {addr_str}")
|
||||||
|
return
|
||||||
|
peer_info = PeerInfo(peer_id, addrs)
|
||||||
|
await self.add_addr(peer_info)
|
||||||
|
else:
|
||||||
|
peer_info = info_from_p2p_addr(multiaddr)
|
||||||
|
await self.add_addr(peer_info)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Invalid multiaddr format '{addr_str}': {e}")
|
logger.warning(f"Failed to process bootstrap address {addr_str}: {e}")
|
||||||
return
|
|
||||||
if self.is_dns_addr(multiaddr):
|
|
||||||
resolved_addrs = await resolver.resolve(multiaddr)
|
|
||||||
peer_id_str = multiaddr.get_peer_id()
|
|
||||||
if peer_id_str is None:
|
|
||||||
logger.warning(f"Missing peer ID in DNS address: {addr_str}")
|
|
||||||
return
|
|
||||||
peer_id = ID.from_base58(peer_id_str)
|
|
||||||
addrs = [addr for addr in resolved_addrs]
|
|
||||||
if not addrs:
|
|
||||||
logger.warning(f"No addresses resolved for DNS address: {addr_str}")
|
|
||||||
return
|
|
||||||
peer_info = PeerInfo(peer_id, addrs)
|
|
||||||
self.add_addr(peer_info)
|
|
||||||
else:
|
|
||||||
self.add_addr(info_from_p2p_addr(multiaddr))
|
|
||||||
|
|
||||||
def is_dns_addr(self, addr: Multiaddr) -> bool:
|
def is_dns_addr(self, addr: Multiaddr) -> bool:
|
||||||
"""Check if the address is a DNS address."""
|
"""Check if the address is a DNS address."""
|
||||||
return any(protocol.name == "dnsaddr" for protocol in addr.protocols())
|
return any(protocol.name == "dnsaddr" for protocol in addr.protocols())
|
||||||
|
|
||||||
def add_addr(self, peer_info: PeerInfo) -> None:
|
async def add_addr(self, peer_info: PeerInfo) -> None:
|
||||||
"""Add a peer to the peerstore and emit discovery event."""
|
"""
|
||||||
|
Add a peer to the peerstore, emit discovery event,
|
||||||
|
and attempt connection in parallel.
|
||||||
|
"""
|
||||||
|
logger.debug(
|
||||||
|
f"Adding peer {peer_info.peer_id} with {len(peer_info.addrs)} addresses"
|
||||||
|
)
|
||||||
|
|
||||||
# Skip if it's our own peer
|
# Skip if it's our own peer
|
||||||
if peer_info.peer_id == self.swarm.get_peer_id():
|
if peer_info.peer_id == self.swarm.get_peer_id():
|
||||||
logger.debug(f"Skipping own peer ID: {peer_info.peer_id}")
|
logger.debug(f"Skipping own peer ID: {peer_info.peer_id}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Always add addresses to peerstore (allows multiple addresses for same peer)
|
# Filter addresses to only include IPv4+TCP (only supported protocol)
|
||||||
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10)
|
ipv4_tcp_addrs = []
|
||||||
|
filtered_out_addrs = []
|
||||||
|
|
||||||
|
for addr in peer_info.addrs:
|
||||||
|
if self._is_ipv4_tcp_addr(addr):
|
||||||
|
ipv4_tcp_addrs.append(addr)
|
||||||
|
else:
|
||||||
|
filtered_out_addrs.append(addr)
|
||||||
|
|
||||||
|
# Log filtering results
|
||||||
|
logger.debug(
|
||||||
|
f"Address filtering for {peer_info.peer_id}: "
|
||||||
|
f"{len(ipv4_tcp_addrs)} IPv4+TCP, {len(filtered_out_addrs)} filtered"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Skip peer if no IPv4+TCP addresses available
|
||||||
|
if not ipv4_tcp_addrs:
|
||||||
|
logger.warning(
|
||||||
|
f"❌ No IPv4+TCP addresses for {peer_info.peer_id} - "
|
||||||
|
f"skipping connection attempts"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Add only IPv4+TCP addresses to peerstore
|
||||||
|
self.peerstore.add_addrs(peer_info.peer_id, ipv4_tcp_addrs, PERMANENT_ADDR_TTL)
|
||||||
|
|
||||||
# Only emit discovery event if this is the first time we see this peer
|
# Only emit discovery event if this is the first time we see this peer
|
||||||
peer_id_str = str(peer_info.peer_id)
|
peer_id_str = str(peer_info.peer_id)
|
||||||
@ -89,6 +173,140 @@ class BootstrapDiscovery:
|
|||||||
self.discovered_peers.add(peer_id_str)
|
self.discovered_peers.add(peer_id_str)
|
||||||
# Emit peer discovery event
|
# Emit peer discovery event
|
||||||
peerDiscovery.emit_peer_discovered(peer_info)
|
peerDiscovery.emit_peer_discovered(peer_info)
|
||||||
logger.debug(f"Peer discovered: {peer_info.peer_id}")
|
logger.info(f"Peer discovered: {peer_info.peer_id}")
|
||||||
|
|
||||||
|
# Connect to peer (parallel across different bootstrap addresses)
|
||||||
|
logger.debug("Connecting to discovered peer...")
|
||||||
|
await self._connect_to_peer(peer_info.peer_id)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.debug(f"Additional addresses added for peer: {peer_info.peer_id}")
|
logger.debug(
|
||||||
|
f"Additional addresses added for existing peer: {peer_info.peer_id}"
|
||||||
|
)
|
||||||
|
# Even for existing peers, try to connect if not already connected
|
||||||
|
if peer_info.peer_id not in self.swarm.connections:
|
||||||
|
logger.debug("Connecting to existing peer...")
|
||||||
|
await self._connect_to_peer(peer_info.peer_id)
|
||||||
|
|
||||||
|
async def _connect_to_peer(self, peer_id: ID) -> None:
|
||||||
|
"""
|
||||||
|
Attempt to establish a connection to a peer with timeout.
|
||||||
|
|
||||||
|
Uses swarm.dial_peer to connect using addresses stored in peerstore.
|
||||||
|
Times out after self.connection_timeout seconds to prevent hanging.
|
||||||
|
"""
|
||||||
|
logger.debug(f"Connection attempt for peer: {peer_id}")
|
||||||
|
|
||||||
|
# Pre-connection validation: Check if already connected
|
||||||
|
if peer_id in self.swarm.connections:
|
||||||
|
logger.debug(
|
||||||
|
f"Already connected to {peer_id} - skipping connection attempt"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check available addresses before attempting connection
|
||||||
|
available_addrs = self.peerstore.addrs(peer_id)
|
||||||
|
logger.debug(f"Connecting to {peer_id} ({len(available_addrs)} addresses)")
|
||||||
|
|
||||||
|
if not available_addrs:
|
||||||
|
logger.error(f"❌ No addresses available for {peer_id} - cannot connect")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Record start time for connection attempt monitoring
|
||||||
|
connection_start_time = trio.current_time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
with trio.move_on_after(self.connection_timeout):
|
||||||
|
# Log connection attempt
|
||||||
|
logger.debug(
|
||||||
|
f"Attempting connection to {peer_id} using "
|
||||||
|
f"{len(available_addrs)} addresses"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use swarm.dial_peer to connect using stored addresses
|
||||||
|
await self.swarm.dial_peer(peer_id)
|
||||||
|
|
||||||
|
# Calculate connection time
|
||||||
|
connection_time = trio.current_time() - connection_start_time
|
||||||
|
|
||||||
|
# Post-connection validation: Verify connection was actually established
|
||||||
|
if peer_id in self.swarm.connections:
|
||||||
|
logger.info(
|
||||||
|
f"✅ Connected to {peer_id} (took {connection_time:.2f}s)"
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Dial succeeded but connection not found for {peer_id}"
|
||||||
|
)
|
||||||
|
except trio.TooSlowError:
|
||||||
|
logger.warning(
|
||||||
|
f"❌ Connection to {peer_id} timed out after {self.connection_timeout}s"
|
||||||
|
)
|
||||||
|
except SwarmException as e:
|
||||||
|
# Calculate failed connection time
|
||||||
|
failed_connection_time = trio.current_time() - connection_start_time
|
||||||
|
|
||||||
|
# Enhanced error logging
|
||||||
|
error_msg = str(e)
|
||||||
|
if "no addresses established a successful connection" in error_msg:
|
||||||
|
logger.warning(
|
||||||
|
f"❌ Failed to connect to {peer_id} after trying all "
|
||||||
|
f"{len(available_addrs)} addresses "
|
||||||
|
f"(took {failed_connection_time:.2f}s)"
|
||||||
|
)
|
||||||
|
# Log individual address failures if this is a MultiError
|
||||||
|
if (
|
||||||
|
e.__cause__ is not None
|
||||||
|
and hasattr(e.__cause__, "exceptions")
|
||||||
|
and getattr(e.__cause__, "exceptions", None) is not None
|
||||||
|
):
|
||||||
|
exceptions_list = getattr(e.__cause__, "exceptions")
|
||||||
|
logger.debug("📋 Individual address failure details:")
|
||||||
|
for i, addr_exception in enumerate(exceptions_list, 1):
|
||||||
|
logger.debug(f"Address {i}: {addr_exception}")
|
||||||
|
# Also log the actual address that failed
|
||||||
|
if i <= len(available_addrs):
|
||||||
|
logger.debug(f"Failed address: {available_addrs[i - 1]}")
|
||||||
|
else:
|
||||||
|
logger.warning("No detailed exception information available")
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"❌ Failed to connect to {peer_id}: {e} "
|
||||||
|
f"(took {failed_connection_time:.2f}s)"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Handle unexpected errors that aren't swarm-specific
|
||||||
|
failed_connection_time = trio.current_time() - connection_start_time
|
||||||
|
logger.error(
|
||||||
|
f"❌ Unexpected error connecting to {peer_id}: "
|
||||||
|
f"{e} (took {failed_connection_time:.2f}s)"
|
||||||
|
)
|
||||||
|
# Don't re-raise to prevent killing the nursery and other parallel tasks
|
||||||
|
|
||||||
|
def _is_ipv4_tcp_addr(self, addr: Multiaddr) -> bool:
|
||||||
|
"""
|
||||||
|
Check if address is IPv4 with TCP protocol only.
|
||||||
|
|
||||||
|
Filters out IPv6, UDP, QUIC, WebSocket, and other unsupported protocols.
|
||||||
|
Only IPv4+TCP addresses are supported by the current transport.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
protocols = addr.protocols()
|
||||||
|
|
||||||
|
# Must have IPv4 protocol
|
||||||
|
has_ipv4 = any(p.name == "ip4" for p in protocols)
|
||||||
|
if not has_ipv4:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Must have TCP protocol
|
||||||
|
has_tcp = any(p.name == "tcp" for p in protocols)
|
||||||
|
if not has_tcp:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# If we can't parse the address, don't use it
|
||||||
|
return False
|
||||||
|
|||||||
@ -213,7 +213,6 @@ class BasicHost(IHost):
|
|||||||
self,
|
self,
|
||||||
peer_id: ID,
|
peer_id: ID,
|
||||||
protocol_ids: Sequence[TProtocol],
|
protocol_ids: Sequence[TProtocol],
|
||||||
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
|
||||||
) -> INetStream:
|
) -> INetStream:
|
||||||
"""
|
"""
|
||||||
:param peer_id: peer_id that host is connecting
|
:param peer_id: peer_id that host is connecting
|
||||||
@ -227,7 +226,7 @@ class BasicHost(IHost):
|
|||||||
selected_protocol = await self.multiselect_client.select_one_of(
|
selected_protocol = await self.multiselect_client.select_one_of(
|
||||||
list(protocol_ids),
|
list(protocol_ids),
|
||||||
MultiselectCommunicator(net_stream),
|
MultiselectCommunicator(net_stream),
|
||||||
negotitate_timeout,
|
self.negotiate_timeout,
|
||||||
)
|
)
|
||||||
except MultiselectClientError as error:
|
except MultiselectClientError as error:
|
||||||
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
||||||
|
|||||||
70
libp2p/network/config.py
Normal file
70
libp2p/network/config.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RetryConfig:
|
||||||
|
"""
|
||||||
|
Configuration for retry logic with exponential backoff.
|
||||||
|
|
||||||
|
This configuration controls how connection attempts are retried when they fail.
|
||||||
|
The retry mechanism uses exponential backoff with jitter to prevent thundering
|
||||||
|
herd problems in distributed systems.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
max_retries: Maximum number of retry attempts before giving up.
|
||||||
|
Default: 3 attempts
|
||||||
|
initial_delay: Initial delay in seconds before the first retry.
|
||||||
|
Default: 0.1 seconds (100ms)
|
||||||
|
max_delay: Maximum delay cap in seconds to prevent excessive wait times.
|
||||||
|
Default: 30.0 seconds
|
||||||
|
backoff_multiplier: Multiplier for exponential backoff (each retry multiplies
|
||||||
|
the delay by this factor). Default: 2.0 (doubles each time)
|
||||||
|
jitter_factor: Random jitter factor (0.0-1.0) to add randomness to delays
|
||||||
|
and prevent synchronized retries. Default: 0.1 (10% jitter)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
max_retries: int = 3
|
||||||
|
initial_delay: float = 0.1
|
||||||
|
max_delay: float = 30.0
|
||||||
|
backoff_multiplier: float = 2.0
|
||||||
|
jitter_factor: float = 0.1
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ConnectionConfig:
|
||||||
|
"""
|
||||||
|
Configuration for multi-connection support.
|
||||||
|
|
||||||
|
This configuration controls how multiple connections per peer are managed,
|
||||||
|
including connection limits, timeouts, and load balancing strategies.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
max_connections_per_peer: Maximum number of connections allowed to a single
|
||||||
|
peer. Default: 3 connections
|
||||||
|
connection_timeout: Timeout in seconds for establishing new connections.
|
||||||
|
Default: 30.0 seconds
|
||||||
|
load_balancing_strategy: Strategy for distributing streams across connections.
|
||||||
|
Options: "round_robin" (default) or "least_loaded"
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
max_connections_per_peer: int = 3
|
||||||
|
connection_timeout: float = 30.0
|
||||||
|
load_balancing_strategy: str = "round_robin" # or "least_loaded"
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""Validate configuration after initialization."""
|
||||||
|
if not (
|
||||||
|
self.load_balancing_strategy == "round_robin"
|
||||||
|
or self.load_balancing_strategy == "least_loaded"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Load balancing strategy can only be 'round_robin' or 'least_loaded'"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.max_connections_per_peer < 1:
|
||||||
|
raise ValueError("Max connection per peer should be atleast 1")
|
||||||
|
|
||||||
|
if self.connection_timeout < 0:
|
||||||
|
raise ValueError("Connection timeout should be positive")
|
||||||
@ -17,6 +17,7 @@ from libp2p.stream_muxer.exceptions import (
|
|||||||
MuxedStreamError,
|
MuxedStreamError,
|
||||||
MuxedStreamReset,
|
MuxedStreamReset,
|
||||||
)
|
)
|
||||||
|
from libp2p.transport.quic.exceptions import QUICStreamClosedError, QUICStreamResetError
|
||||||
|
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
StreamClosed,
|
StreamClosed,
|
||||||
@ -170,7 +171,7 @@ class NetStream(INetStream):
|
|||||||
elif self.__stream_state == StreamState.OPEN:
|
elif self.__stream_state == StreamState.OPEN:
|
||||||
self.__stream_state = StreamState.CLOSE_READ
|
self.__stream_state = StreamState.CLOSE_READ
|
||||||
raise StreamEOF() from error
|
raise StreamEOF() from error
|
||||||
except MuxedStreamReset as error:
|
except (MuxedStreamReset, QUICStreamClosedError, QUICStreamResetError) as error:
|
||||||
async with self._state_lock:
|
async with self._state_lock:
|
||||||
if self.__stream_state in [
|
if self.__stream_state in [
|
||||||
StreamState.OPEN,
|
StreamState.OPEN,
|
||||||
@ -199,7 +200,12 @@ class NetStream(INetStream):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await self.muxed_stream.write(data)
|
await self.muxed_stream.write(data)
|
||||||
except (MuxedStreamClosed, MuxedStreamError) as error:
|
except (
|
||||||
|
MuxedStreamClosed,
|
||||||
|
MuxedStreamError,
|
||||||
|
QUICStreamClosedError,
|
||||||
|
QUICStreamResetError,
|
||||||
|
) as error:
|
||||||
async with self._state_lock:
|
async with self._state_lock:
|
||||||
if self.__stream_state == StreamState.OPEN:
|
if self.__stream_state == StreamState.OPEN:
|
||||||
self.__stream_state = StreamState.CLOSE_WRITE
|
self.__stream_state = StreamState.CLOSE_WRITE
|
||||||
|
|||||||
@ -2,9 +2,9 @@ from collections.abc import (
|
|||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
)
|
)
|
||||||
from dataclasses import dataclass
|
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from multiaddr import (
|
from multiaddr import (
|
||||||
Multiaddr,
|
Multiaddr,
|
||||||
@ -27,6 +27,7 @@ from libp2p.custom_types import (
|
|||||||
from libp2p.io.abc import (
|
from libp2p.io.abc import (
|
||||||
ReadWriteCloser,
|
ReadWriteCloser,
|
||||||
)
|
)
|
||||||
|
from libp2p.network.config import ConnectionConfig, RetryConfig
|
||||||
from libp2p.peer.id import (
|
from libp2p.peer.id import (
|
||||||
ID,
|
ID,
|
||||||
)
|
)
|
||||||
@ -41,6 +42,9 @@ from libp2p.transport.exceptions import (
|
|||||||
OpenConnectionError,
|
OpenConnectionError,
|
||||||
SecurityUpgradeFailure,
|
SecurityUpgradeFailure,
|
||||||
)
|
)
|
||||||
|
from libp2p.transport.quic.config import QUICTransportConfig
|
||||||
|
from libp2p.transport.quic.connection import QUICConnection
|
||||||
|
from libp2p.transport.quic.transport import QUICTransport
|
||||||
from libp2p.transport.upgrader import (
|
from libp2p.transport.upgrader import (
|
||||||
TransportUpgrader,
|
TransportUpgrader,
|
||||||
)
|
)
|
||||||
@ -61,59 +65,6 @@ from .exceptions import (
|
|||||||
logger = logging.getLogger("libp2p.network.swarm")
|
logger = logging.getLogger("libp2p.network.swarm")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RetryConfig:
|
|
||||||
"""
|
|
||||||
Configuration for retry logic with exponential backoff.
|
|
||||||
|
|
||||||
This configuration controls how connection attempts are retried when they fail.
|
|
||||||
The retry mechanism uses exponential backoff with jitter to prevent thundering
|
|
||||||
herd problems in distributed systems.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
max_retries: Maximum number of retry attempts before giving up.
|
|
||||||
Default: 3 attempts
|
|
||||||
initial_delay: Initial delay in seconds before the first retry.
|
|
||||||
Default: 0.1 seconds (100ms)
|
|
||||||
max_delay: Maximum delay cap in seconds to prevent excessive wait times.
|
|
||||||
Default: 30.0 seconds
|
|
||||||
backoff_multiplier: Multiplier for exponential backoff (each retry multiplies
|
|
||||||
the delay by this factor). Default: 2.0 (doubles each time)
|
|
||||||
jitter_factor: Random jitter factor (0.0-1.0) to add randomness to delays
|
|
||||||
and prevent synchronized retries. Default: 0.1 (10% jitter)
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
max_retries: int = 3
|
|
||||||
initial_delay: float = 0.1
|
|
||||||
max_delay: float = 30.0
|
|
||||||
backoff_multiplier: float = 2.0
|
|
||||||
jitter_factor: float = 0.1
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ConnectionConfig:
|
|
||||||
"""
|
|
||||||
Configuration for multi-connection support.
|
|
||||||
|
|
||||||
This configuration controls how multiple connections per peer are managed,
|
|
||||||
including connection limits, timeouts, and load balancing strategies.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
max_connections_per_peer: Maximum number of connections allowed to a single
|
|
||||||
peer. Default: 3 connections
|
|
||||||
connection_timeout: Timeout in seconds for establishing new connections.
|
|
||||||
Default: 30.0 seconds
|
|
||||||
load_balancing_strategy: Strategy for distributing streams across connections.
|
|
||||||
Options: "round_robin" (default) or "least_loaded"
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
max_connections_per_peer: int = 3
|
|
||||||
connection_timeout: float = 30.0
|
|
||||||
load_balancing_strategy: str = "round_robin" # or "least_loaded"
|
|
||||||
|
|
||||||
|
|
||||||
def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
|
def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
|
||||||
async def stream_handler(stream: INetStream) -> None:
|
async def stream_handler(stream: INetStream) -> None:
|
||||||
await network.get_manager().wait_finished()
|
await network.get_manager().wait_finished()
|
||||||
@ -126,8 +77,7 @@ class Swarm(Service, INetworkService):
|
|||||||
peerstore: IPeerStore
|
peerstore: IPeerStore
|
||||||
upgrader: TransportUpgrader
|
upgrader: TransportUpgrader
|
||||||
transport: ITransport
|
transport: ITransport
|
||||||
# Enhanced: Support for multiple connections per peer
|
connections: dict[ID, list[INetConn]]
|
||||||
connections: dict[ID, list[INetConn]] # Multiple connections per peer
|
|
||||||
listeners: dict[str, IListener]
|
listeners: dict[str, IListener]
|
||||||
common_stream_handler: StreamHandlerFn
|
common_stream_handler: StreamHandlerFn
|
||||||
listener_nursery: trio.Nursery | None
|
listener_nursery: trio.Nursery | None
|
||||||
@ -137,7 +87,7 @@ class Swarm(Service, INetworkService):
|
|||||||
|
|
||||||
# Enhanced: New configuration
|
# Enhanced: New configuration
|
||||||
retry_config: RetryConfig
|
retry_config: RetryConfig
|
||||||
connection_config: ConnectionConfig
|
connection_config: ConnectionConfig | QUICTransportConfig
|
||||||
_round_robin_index: dict[ID, int]
|
_round_robin_index: dict[ID, int]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -147,7 +97,7 @@ class Swarm(Service, INetworkService):
|
|||||||
upgrader: TransportUpgrader,
|
upgrader: TransportUpgrader,
|
||||||
transport: ITransport,
|
transport: ITransport,
|
||||||
retry_config: RetryConfig | None = None,
|
retry_config: RetryConfig | None = None,
|
||||||
connection_config: ConnectionConfig | None = None,
|
connection_config: ConnectionConfig | QUICTransportConfig | None = None,
|
||||||
):
|
):
|
||||||
self.self_id = peer_id
|
self.self_id = peer_id
|
||||||
self.peerstore = peerstore
|
self.peerstore = peerstore
|
||||||
@ -178,6 +128,11 @@ class Swarm(Service, INetworkService):
|
|||||||
# Create a nursery for listener tasks.
|
# Create a nursery for listener tasks.
|
||||||
self.listener_nursery = nursery
|
self.listener_nursery = nursery
|
||||||
self.event_listener_nursery_created.set()
|
self.event_listener_nursery_created.set()
|
||||||
|
|
||||||
|
if isinstance(self.transport, QUICTransport):
|
||||||
|
self.transport.set_background_nursery(nursery)
|
||||||
|
self.transport.set_swarm(self)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.manager.wait_finished()
|
await self.manager.wait_finished()
|
||||||
finally:
|
finally:
|
||||||
@ -370,6 +325,7 @@ class Swarm(Service, INetworkService):
|
|||||||
# Dial peer (connection to peer does not yet exist)
|
# Dial peer (connection to peer does not yet exist)
|
||||||
# Transport dials peer (gets back a raw conn)
|
# Transport dials peer (gets back a raw conn)
|
||||||
try:
|
try:
|
||||||
|
addr = Multiaddr(f"{addr}/p2p/{peer_id}")
|
||||||
raw_conn = await self.transport.dial(addr)
|
raw_conn = await self.transport.dial(addr)
|
||||||
except OpenConnectionError as error:
|
except OpenConnectionError as error:
|
||||||
logger.debug("fail to dial peer %s over base transport", peer_id)
|
logger.debug("fail to dial peer %s over base transport", peer_id)
|
||||||
@ -377,6 +333,15 @@ class Swarm(Service, INetworkService):
|
|||||||
f"fail to open connection to peer {peer_id}"
|
f"fail to open connection to peer {peer_id}"
|
||||||
) from error
|
) from error
|
||||||
|
|
||||||
|
if isinstance(self.transport, QUICTransport) and isinstance(
|
||||||
|
raw_conn, IMuxedConn
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"Skipping upgrade for QUIC, QUIC connections are already multiplexed"
|
||||||
|
)
|
||||||
|
swarm_conn = await self.add_conn(raw_conn)
|
||||||
|
return swarm_conn
|
||||||
|
|
||||||
logger.debug("dialed peer %s over base transport", peer_id)
|
logger.debug("dialed peer %s over base transport", peer_id)
|
||||||
|
|
||||||
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
|
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
|
||||||
@ -402,9 +367,7 @@ class Swarm(Service, INetworkService):
|
|||||||
logger.debug("upgraded mux for peer %s", peer_id)
|
logger.debug("upgraded mux for peer %s", peer_id)
|
||||||
|
|
||||||
swarm_conn = await self.add_conn(muxed_conn)
|
swarm_conn = await self.add_conn(muxed_conn)
|
||||||
|
|
||||||
logger.debug("successfully dialed peer %s", peer_id)
|
logger.debug("successfully dialed peer %s", peer_id)
|
||||||
|
|
||||||
return swarm_conn
|
return swarm_conn
|
||||||
|
|
||||||
async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
|
async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
|
||||||
@ -427,7 +390,6 @@ class Swarm(Service, INetworkService):
|
|||||||
:return: net stream instance
|
:return: net stream instance
|
||||||
"""
|
"""
|
||||||
logger.debug("attempting to open a stream to peer %s", peer_id)
|
logger.debug("attempting to open a stream to peer %s", peer_id)
|
||||||
|
|
||||||
# Get existing connections or dial new ones
|
# Get existing connections or dial new ones
|
||||||
connections = self.get_connections(peer_id)
|
connections = self.get_connections(peer_id)
|
||||||
if not connections:
|
if not connections:
|
||||||
@ -436,6 +398,10 @@ class Swarm(Service, INetworkService):
|
|||||||
# Load balancing strategy at interface level
|
# Load balancing strategy at interface level
|
||||||
connection = self._select_connection(connections, peer_id)
|
connection = self._select_connection(connections, peer_id)
|
||||||
|
|
||||||
|
if isinstance(self.transport, QUICTransport) and connection is not None:
|
||||||
|
conn = cast(SwarmConn, connection)
|
||||||
|
return await conn.new_stream()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
net_stream = await connection.new_stream()
|
net_stream = await connection.new_stream()
|
||||||
logger.debug("successfully opened a stream to peer %s", peer_id)
|
logger.debug("successfully opened a stream to peer %s", peer_id)
|
||||||
@ -515,18 +481,38 @@ class Swarm(Service, INetworkService):
|
|||||||
- Call listener listen with the multiaddr
|
- Call listener listen with the multiaddr
|
||||||
- Map multiaddr to listener
|
- Map multiaddr to listener
|
||||||
"""
|
"""
|
||||||
|
logger.debug(f"Swarm.listen called with multiaddrs: {multiaddrs}")
|
||||||
# We need to wait until `self.listener_nursery` is created.
|
# We need to wait until `self.listener_nursery` is created.
|
||||||
|
logger.debug("Starting to listen")
|
||||||
await self.event_listener_nursery_created.wait()
|
await self.event_listener_nursery_created.wait()
|
||||||
|
|
||||||
success_count = 0
|
success_count = 0
|
||||||
for maddr in multiaddrs:
|
for maddr in multiaddrs:
|
||||||
|
logger.debug(f"Swarm.listen processing multiaddr: {maddr}")
|
||||||
if str(maddr) in self.listeners:
|
if str(maddr) in self.listeners:
|
||||||
|
logger.debug(f"Swarm.listen: listener already exists for {maddr}")
|
||||||
success_count += 1
|
success_count += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
async def conn_handler(
|
async def conn_handler(
|
||||||
read_write_closer: ReadWriteCloser, maddr: Multiaddr = maddr
|
read_write_closer: ReadWriteCloser, maddr: Multiaddr = maddr
|
||||||
) -> None:
|
) -> None:
|
||||||
|
# No need to upgrade QUIC Connection
|
||||||
|
if isinstance(self.transport, QUICTransport):
|
||||||
|
try:
|
||||||
|
quic_conn = cast(QUICConnection, read_write_closer)
|
||||||
|
await self.add_conn(quic_conn)
|
||||||
|
peer_id = quic_conn.peer_id
|
||||||
|
logger.debug(
|
||||||
|
f"successfully opened quic connection to peer {peer_id}"
|
||||||
|
)
|
||||||
|
# NOTE: This is a intentional barrier to prevent from the
|
||||||
|
# handler exiting and closing the connection.
|
||||||
|
await self.manager.wait_finished()
|
||||||
|
except Exception:
|
||||||
|
await read_write_closer.close()
|
||||||
|
return
|
||||||
|
|
||||||
raw_conn = RawConnection(read_write_closer, False)
|
raw_conn = RawConnection(read_write_closer, False)
|
||||||
|
|
||||||
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first
|
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first
|
||||||
@ -562,13 +548,18 @@ class Swarm(Service, INetworkService):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Success
|
# Success
|
||||||
|
logger.debug(f"Swarm.listen: creating listener for {maddr}")
|
||||||
listener = self.transport.create_listener(conn_handler)
|
listener = self.transport.create_listener(conn_handler)
|
||||||
|
logger.debug(f"Swarm.listen: listener created for {maddr}")
|
||||||
self.listeners[str(maddr)] = listener
|
self.listeners[str(maddr)] = listener
|
||||||
# TODO: `listener.listen` is not bounded with nursery. If we want to be
|
# TODO: `listener.listen` is not bounded with nursery. If we want to be
|
||||||
# I/O agnostic, we should change the API.
|
# I/O agnostic, we should change the API.
|
||||||
if self.listener_nursery is None:
|
if self.listener_nursery is None:
|
||||||
raise SwarmException("swarm instance hasn't been run")
|
raise SwarmException("swarm instance hasn't been run")
|
||||||
|
assert self.listener_nursery is not None # For type checker
|
||||||
|
logger.debug(f"Swarm.listen: calling listener.listen for {maddr}")
|
||||||
await listener.listen(maddr, self.listener_nursery)
|
await listener.listen(maddr, self.listener_nursery)
|
||||||
|
logger.debug(f"Swarm.listen: listener.listen completed for {maddr}")
|
||||||
|
|
||||||
# Call notifiers since event occurred
|
# Call notifiers since event occurred
|
||||||
await self.notify_listen(maddr)
|
await self.notify_listen(maddr)
|
||||||
@ -660,9 +651,10 @@ class Swarm(Service, INetworkService):
|
|||||||
muxed_conn,
|
muxed_conn,
|
||||||
self,
|
self,
|
||||||
)
|
)
|
||||||
|
logger.debug("Swarm::add_conn | starting muxed connection")
|
||||||
self.manager.run_task(muxed_conn.start)
|
self.manager.run_task(muxed_conn.start)
|
||||||
await muxed_conn.event_started.wait()
|
await muxed_conn.event_started.wait()
|
||||||
|
logger.debug("Swarm::add_conn | starting swarm connection")
|
||||||
self.manager.run_task(swarm_conn.start)
|
self.manager.run_task(swarm_conn.start)
|
||||||
await swarm_conn.event_started.wait()
|
await swarm_conn.event_started.wait()
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from builtins import AssertionError
|
||||||
|
|
||||||
from libp2p.abc import (
|
from libp2p.abc import (
|
||||||
IMultiselectCommunicator,
|
IMultiselectCommunicator,
|
||||||
)
|
)
|
||||||
@ -36,7 +38,8 @@ class MultiselectCommunicator(IMultiselectCommunicator):
|
|||||||
msg_bytes = encode_delim(msg_str.encode())
|
msg_bytes = encode_delim(msg_str.encode())
|
||||||
try:
|
try:
|
||||||
await self.read_writer.write(msg_bytes)
|
await self.read_writer.write(msg_bytes)
|
||||||
except IOException as error:
|
# Handle for connection close during ongoing negotiation in QUIC
|
||||||
|
except (IOException, AssertionError, ValueError) as error:
|
||||||
raise MultiselectCommunicatorError(
|
raise MultiselectCommunicatorError(
|
||||||
"fail to write to multiselect communicator"
|
"fail to write to multiselect communicator"
|
||||||
) from error
|
) from error
|
||||||
|
|||||||
@ -1,6 +1,3 @@
|
|||||||
from ast import (
|
|
||||||
literal_eval,
|
|
||||||
)
|
|
||||||
from collections import (
|
from collections import (
|
||||||
defaultdict,
|
defaultdict,
|
||||||
)
|
)
|
||||||
@ -22,6 +19,7 @@ from libp2p.abc import (
|
|||||||
IPubsubRouter,
|
IPubsubRouter,
|
||||||
)
|
)
|
||||||
from libp2p.custom_types import (
|
from libp2p.custom_types import (
|
||||||
|
MessageID,
|
||||||
TProtocol,
|
TProtocol,
|
||||||
)
|
)
|
||||||
from libp2p.peer.id import (
|
from libp2p.peer.id import (
|
||||||
@ -56,6 +54,10 @@ from .pb import (
|
|||||||
from .pubsub import (
|
from .pubsub import (
|
||||||
Pubsub,
|
Pubsub,
|
||||||
)
|
)
|
||||||
|
from .utils import (
|
||||||
|
parse_message_id_safe,
|
||||||
|
safe_parse_message_id,
|
||||||
|
)
|
||||||
|
|
||||||
PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
|
PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
|
||||||
PROTOCOL_ID_V11 = TProtocol("/meshsub/1.1.0")
|
PROTOCOL_ID_V11 = TProtocol("/meshsub/1.1.0")
|
||||||
@ -306,7 +308,8 @@ class GossipSub(IPubsubRouter, Service):
|
|||||||
floodsub_peers: set[ID] = {
|
floodsub_peers: set[ID] = {
|
||||||
peer_id
|
peer_id
|
||||||
for peer_id in self.pubsub.peer_topics[topic]
|
for peer_id in self.pubsub.peer_topics[topic]
|
||||||
if self.peer_protocol[peer_id] == floodsub.PROTOCOL_ID
|
if peer_id in self.peer_protocol
|
||||||
|
and self.peer_protocol[peer_id] == floodsub.PROTOCOL_ID
|
||||||
}
|
}
|
||||||
send_to.update(floodsub_peers)
|
send_to.update(floodsub_peers)
|
||||||
|
|
||||||
@ -794,8 +797,8 @@ class GossipSub(IPubsubRouter, Service):
|
|||||||
|
|
||||||
# Add all unknown message ids (ids that appear in ihave_msg but not in
|
# Add all unknown message ids (ids that appear in ihave_msg but not in
|
||||||
# seen_seqnos) to list of messages we want to request
|
# seen_seqnos) to list of messages we want to request
|
||||||
msg_ids_wanted: list[str] = [
|
msg_ids_wanted: list[MessageID] = [
|
||||||
msg_id
|
parse_message_id_safe(msg_id)
|
||||||
for msg_id in ihave_msg.messageIDs
|
for msg_id in ihave_msg.messageIDs
|
||||||
if msg_id not in seen_seqnos_and_peers
|
if msg_id not in seen_seqnos_and_peers
|
||||||
]
|
]
|
||||||
@ -811,9 +814,9 @@ class GossipSub(IPubsubRouter, Service):
|
|||||||
Forwards all request messages that are present in mcache to the
|
Forwards all request messages that are present in mcache to the
|
||||||
requesting peer.
|
requesting peer.
|
||||||
"""
|
"""
|
||||||
# FIXME: Update type of message ID
|
msg_ids: list[tuple[bytes, bytes]] = [
|
||||||
# FIXME: Find a better way to parse the msg ids
|
safe_parse_message_id(msg) for msg in iwant_msg.messageIDs
|
||||||
msg_ids: list[Any] = [literal_eval(msg) for msg in iwant_msg.messageIDs]
|
]
|
||||||
msgs_to_forward: list[rpc_pb2.Message] = []
|
msgs_to_forward: list[rpc_pb2.Message] = []
|
||||||
for msg_id_iwant in msg_ids:
|
for msg_id_iwant in msg_ids:
|
||||||
# Check if the wanted message ID is present in mcache
|
# Check if the wanted message ID is present in mcache
|
||||||
|
|||||||
@ -1,6 +1,10 @@
|
|||||||
|
import ast
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from libp2p.abc import IHost
|
from libp2p.abc import IHost
|
||||||
|
from libp2p.custom_types import (
|
||||||
|
MessageID,
|
||||||
|
)
|
||||||
from libp2p.peer.envelope import consume_envelope
|
from libp2p.peer.envelope import consume_envelope
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.pubsub.pb.rpc_pb2 import RPC
|
from libp2p.pubsub.pb.rpc_pb2 import RPC
|
||||||
@ -48,3 +52,29 @@ def maybe_consume_signed_record(msg: RPC, host: IHost, peer_id: ID) -> bool:
|
|||||||
logger.error("Failed to update the Certified-Addr-Book: %s", e)
|
logger.error("Failed to update the Certified-Addr-Book: %s", e)
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def parse_message_id_safe(msg_id_str: str) -> MessageID:
|
||||||
|
"""Safely handle message ID as string."""
|
||||||
|
return MessageID(msg_id_str)
|
||||||
|
|
||||||
|
|
||||||
|
def safe_parse_message_id(msg_id_str: str) -> tuple[bytes, bytes]:
|
||||||
|
"""
|
||||||
|
Safely parse message ID using ast.literal_eval with validation.
|
||||||
|
:param msg_id_str: String representation of message ID
|
||||||
|
:return: Tuple of (seqno, from_id) as bytes
|
||||||
|
:raises ValueError: If parsing fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
parsed = ast.literal_eval(msg_id_str)
|
||||||
|
if not isinstance(parsed, tuple) or len(parsed) != 2:
|
||||||
|
raise ValueError("Invalid message ID format")
|
||||||
|
|
||||||
|
seqno, from_id = parsed
|
||||||
|
if not isinstance(seqno, bytes) or not isinstance(from_id, bytes):
|
||||||
|
raise ValueError("Message ID components must be bytes")
|
||||||
|
|
||||||
|
return (seqno, from_id)
|
||||||
|
except (ValueError, SyntaxError) as e:
|
||||||
|
raise ValueError(f"Invalid message ID format: {e}")
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from dataclasses import (
|
|||||||
dataclass,
|
dataclass,
|
||||||
field,
|
field,
|
||||||
)
|
)
|
||||||
|
from enum import Flag, auto
|
||||||
|
|
||||||
from libp2p.peer.peerinfo import (
|
from libp2p.peer.peerinfo import (
|
||||||
PeerInfo,
|
PeerInfo,
|
||||||
@ -18,29 +19,118 @@ from .resources import (
|
|||||||
RelayLimits,
|
RelayLimits,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
DEFAULT_MIN_RELAYS = 3
|
||||||
|
DEFAULT_MAX_RELAYS = 20
|
||||||
|
DEFAULT_DISCOVERY_INTERVAL = 300 # seconds
|
||||||
|
DEFAULT_RESERVATION_TTL = 3600 # seconds
|
||||||
|
DEFAULT_MAX_CIRCUIT_DURATION = 3600 # seconds
|
||||||
|
DEFAULT_MAX_CIRCUIT_BYTES = 1024 * 1024 * 1024 # 1GB
|
||||||
|
|
||||||
|
DEFAULT_MAX_CIRCUIT_CONNS = 8
|
||||||
|
DEFAULT_MAX_RESERVATIONS = 4
|
||||||
|
|
||||||
|
MAX_RESERVATIONS_PER_IP = 8
|
||||||
|
MAX_CIRCUITS_PER_IP = 16
|
||||||
|
RESERVATION_RATE_PER_IP = 4 # per minute
|
||||||
|
CIRCUIT_RATE_PER_IP = 8 # per minute
|
||||||
|
MAX_CIRCUITS_TOTAL = 64
|
||||||
|
MAX_RESERVATIONS_TOTAL = 32
|
||||||
|
MAX_BANDWIDTH_PER_CIRCUIT = 1024 * 1024 # 1MB/s
|
||||||
|
MAX_BANDWIDTH_TOTAL = 10 * 1024 * 1024 # 10MB/s
|
||||||
|
|
||||||
|
MIN_RELAY_SCORE = 0.5
|
||||||
|
MAX_RELAY_LATENCY = 1.0 # seconds
|
||||||
|
ENABLE_AUTO_RELAY = True
|
||||||
|
AUTO_RELAY_TIMEOUT = 30 # seconds
|
||||||
|
MAX_AUTO_RELAY_ATTEMPTS = 3
|
||||||
|
RESERVATION_REFRESH_THRESHOLD = 0.8 # Refresh at 80% of TTL
|
||||||
|
MAX_CONCURRENT_RESERVATIONS = 2
|
||||||
|
|
||||||
|
# Timeout constants for different components
|
||||||
|
DEFAULT_DISCOVERY_STREAM_TIMEOUT = 10 # seconds
|
||||||
|
DEFAULT_PEER_PROTOCOL_TIMEOUT = 5 # seconds
|
||||||
|
DEFAULT_PROTOCOL_READ_TIMEOUT = 15 # seconds
|
||||||
|
DEFAULT_PROTOCOL_WRITE_TIMEOUT = 15 # seconds
|
||||||
|
DEFAULT_PROTOCOL_CLOSE_TIMEOUT = 10 # seconds
|
||||||
|
DEFAULT_DCUTR_READ_TIMEOUT = 30 # seconds
|
||||||
|
DEFAULT_DCUTR_WRITE_TIMEOUT = 30 # seconds
|
||||||
|
DEFAULT_DIAL_TIMEOUT = 10 # seconds
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TimeoutConfig:
|
||||||
|
"""Timeout configuration for different Circuit Relay v2 components."""
|
||||||
|
|
||||||
|
# Discovery timeouts
|
||||||
|
discovery_stream_timeout: int = DEFAULT_DISCOVERY_STREAM_TIMEOUT
|
||||||
|
peer_protocol_timeout: int = DEFAULT_PEER_PROTOCOL_TIMEOUT
|
||||||
|
|
||||||
|
# Core protocol timeouts
|
||||||
|
protocol_read_timeout: int = DEFAULT_PROTOCOL_READ_TIMEOUT
|
||||||
|
protocol_write_timeout: int = DEFAULT_PROTOCOL_WRITE_TIMEOUT
|
||||||
|
protocol_close_timeout: int = DEFAULT_PROTOCOL_CLOSE_TIMEOUT
|
||||||
|
|
||||||
|
# DCUtR timeouts
|
||||||
|
dcutr_read_timeout: int = DEFAULT_DCUTR_READ_TIMEOUT
|
||||||
|
dcutr_write_timeout: int = DEFAULT_DCUTR_WRITE_TIMEOUT
|
||||||
|
dial_timeout: int = DEFAULT_DIAL_TIMEOUT
|
||||||
|
|
||||||
|
|
||||||
|
# Relay roles enum
|
||||||
|
class RelayRole(Flag):
|
||||||
|
"""
|
||||||
|
Bit-flag enum that captures the three possible relay capabilities.
|
||||||
|
|
||||||
|
A node can combine multiple roles using bit-wise OR, for example::
|
||||||
|
|
||||||
|
RelayRole.HOP | RelayRole.STOP
|
||||||
|
"""
|
||||||
|
|
||||||
|
HOP = auto() # Act as a relay for others ("hop")
|
||||||
|
STOP = auto() # Accept relayed connections ("stop")
|
||||||
|
CLIENT = auto() # Dial through existing relays ("client")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RelayConfig:
|
class RelayConfig:
|
||||||
"""Configuration for Circuit Relay v2."""
|
"""Configuration for Circuit Relay v2."""
|
||||||
|
|
||||||
# Role configuration
|
# Role configuration (bit-flags)
|
||||||
enable_hop: bool = False # Whether to act as a relay (hop)
|
roles: RelayRole = RelayRole.STOP | RelayRole.CLIENT
|
||||||
enable_stop: bool = True # Whether to accept relayed connections (stop)
|
|
||||||
enable_client: bool = True # Whether to use relays for dialing
|
|
||||||
|
|
||||||
# Resource limits
|
# Resource limits
|
||||||
limits: RelayLimits | None = None
|
limits: RelayLimits | None = None
|
||||||
|
|
||||||
# Discovery configuration
|
# Discovery configuration
|
||||||
bootstrap_relays: list[PeerInfo] = field(default_factory=list)
|
bootstrap_relays: list[PeerInfo] = field(default_factory=list)
|
||||||
min_relays: int = 3
|
min_relays: int = DEFAULT_MIN_RELAYS
|
||||||
max_relays: int = 20
|
max_relays: int = DEFAULT_MAX_RELAYS
|
||||||
discovery_interval: int = 300 # seconds
|
discovery_interval: int = DEFAULT_DISCOVERY_INTERVAL
|
||||||
|
|
||||||
# Connection configuration
|
# Connection configuration
|
||||||
reservation_ttl: int = 3600 # seconds
|
reservation_ttl: int = DEFAULT_RESERVATION_TTL
|
||||||
max_circuit_duration: int = 3600 # seconds
|
max_circuit_duration: int = DEFAULT_MAX_CIRCUIT_DURATION
|
||||||
max_circuit_bytes: int = 1024 * 1024 * 1024 # 1GB
|
max_circuit_bytes: int = DEFAULT_MAX_CIRCUIT_BYTES
|
||||||
|
|
||||||
|
# Timeout configuration
|
||||||
|
timeouts: TimeoutConfig = field(default_factory=TimeoutConfig)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------
|
||||||
|
# Backwards-compat boolean helpers. Existing code that still accesses
|
||||||
|
# ``cfg.enable_hop, cfg.enable_stop, cfg.enable_client`` will continue to work.
|
||||||
|
# ---------------------------------------------------------------------
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enable_hop(self) -> bool: # pragma: no cover – helper
|
||||||
|
return bool(self.roles & RelayRole.HOP)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enable_stop(self) -> bool: # pragma: no cover – helper
|
||||||
|
return bool(self.roles & RelayRole.STOP)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enable_client(self) -> bool: # pragma: no cover – helper
|
||||||
|
return bool(self.roles & RelayRole.CLIENT)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
"""Initialize default values."""
|
"""Initialize default values."""
|
||||||
@ -48,8 +138,8 @@ class RelayConfig:
|
|||||||
self.limits = RelayLimits(
|
self.limits = RelayLimits(
|
||||||
duration=self.max_circuit_duration,
|
duration=self.max_circuit_duration,
|
||||||
data=self.max_circuit_bytes,
|
data=self.max_circuit_bytes,
|
||||||
max_circuit_conns=8,
|
max_circuit_conns=DEFAULT_MAX_CIRCUIT_CONNS,
|
||||||
max_reservations=4,
|
max_reservations=DEFAULT_MAX_RESERVATIONS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -58,20 +148,20 @@ class HopConfig:
|
|||||||
"""Configuration specific to relay (hop) nodes."""
|
"""Configuration specific to relay (hop) nodes."""
|
||||||
|
|
||||||
# Resource limits per IP
|
# Resource limits per IP
|
||||||
max_reservations_per_ip: int = 8
|
max_reservations_per_ip: int = MAX_RESERVATIONS_PER_IP
|
||||||
max_circuits_per_ip: int = 16
|
max_circuits_per_ip: int = MAX_CIRCUITS_PER_IP
|
||||||
|
|
||||||
# Rate limiting
|
# Rate limiting
|
||||||
reservation_rate_per_ip: int = 4 # per minute
|
reservation_rate_per_ip: int = RESERVATION_RATE_PER_IP
|
||||||
circuit_rate_per_ip: int = 8 # per minute
|
circuit_rate_per_ip: int = CIRCUIT_RATE_PER_IP
|
||||||
|
|
||||||
# Resource quotas
|
# Resource quotas
|
||||||
max_circuits_total: int = 64
|
max_circuits_total: int = MAX_CIRCUITS_TOTAL
|
||||||
max_reservations_total: int = 32
|
max_reservations_total: int = MAX_RESERVATIONS_TOTAL
|
||||||
|
|
||||||
# Bandwidth limits
|
# Bandwidth limits
|
||||||
max_bandwidth_per_circuit: int = 1024 * 1024 # 1MB/s
|
max_bandwidth_per_circuit: int = MAX_BANDWIDTH_PER_CIRCUIT
|
||||||
max_bandwidth_total: int = 10 * 1024 * 1024 # 10MB/s
|
max_bandwidth_total: int = MAX_BANDWIDTH_TOTAL
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -79,14 +169,14 @@ class ClientConfig:
|
|||||||
"""Configuration specific to relay clients."""
|
"""Configuration specific to relay clients."""
|
||||||
|
|
||||||
# Relay selection
|
# Relay selection
|
||||||
min_relay_score: float = 0.5
|
min_relay_score: float = MIN_RELAY_SCORE
|
||||||
max_relay_latency: float = 1.0 # seconds
|
max_relay_latency: float = MAX_RELAY_LATENCY
|
||||||
|
|
||||||
# Auto-relay settings
|
# Auto-relay settings
|
||||||
enable_auto_relay: bool = True
|
enable_auto_relay: bool = ENABLE_AUTO_RELAY
|
||||||
auto_relay_timeout: int = 30 # seconds
|
auto_relay_timeout: int = AUTO_RELAY_TIMEOUT
|
||||||
max_auto_relay_attempts: int = 3
|
max_auto_relay_attempts: int = MAX_AUTO_RELAY_ATTEMPTS
|
||||||
|
|
||||||
# Reservation management
|
# Reservation management
|
||||||
reservation_refresh_threshold: float = 0.8 # Refresh at 80% of TTL
|
reservation_refresh_threshold: float = RESERVATION_REFRESH_THRESHOLD
|
||||||
max_concurrent_reservations: int = 2
|
max_concurrent_reservations: int = MAX_CONCURRENT_RESERVATIONS
|
||||||
|
|||||||
@ -29,6 +29,11 @@ from libp2p.peer.id import (
|
|||||||
from libp2p.peer.peerinfo import (
|
from libp2p.peer.peerinfo import (
|
||||||
PeerInfo,
|
PeerInfo,
|
||||||
)
|
)
|
||||||
|
from libp2p.relay.circuit_v2.config import (
|
||||||
|
DEFAULT_DCUTR_READ_TIMEOUT,
|
||||||
|
DEFAULT_DCUTR_WRITE_TIMEOUT,
|
||||||
|
DEFAULT_DIAL_TIMEOUT,
|
||||||
|
)
|
||||||
from libp2p.relay.circuit_v2.nat import (
|
from libp2p.relay.circuit_v2.nat import (
|
||||||
ReachabilityChecker,
|
ReachabilityChecker,
|
||||||
)
|
)
|
||||||
@ -47,11 +52,7 @@ PROTOCOL_ID = TProtocol("/libp2p/dcutr")
|
|||||||
# Maximum message size for DCUtR (4KiB as per spec)
|
# Maximum message size for DCUtR (4KiB as per spec)
|
||||||
MAX_MESSAGE_SIZE = 4 * 1024
|
MAX_MESSAGE_SIZE = 4 * 1024
|
||||||
|
|
||||||
# Timeouts
|
# DCUtR protocol constants
|
||||||
STREAM_READ_TIMEOUT = 30 # seconds
|
|
||||||
STREAM_WRITE_TIMEOUT = 30 # seconds
|
|
||||||
DIAL_TIMEOUT = 10 # seconds
|
|
||||||
|
|
||||||
# Maximum number of hole punch attempts per peer
|
# Maximum number of hole punch attempts per peer
|
||||||
MAX_HOLE_PUNCH_ATTEMPTS = 5
|
MAX_HOLE_PUNCH_ATTEMPTS = 5
|
||||||
|
|
||||||
@ -70,7 +71,13 @@ class DCUtRProtocol(Service):
|
|||||||
hole punching, after they have established an initial connection through a relay.
|
hole punching, after they have established an initial connection through a relay.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, host: IHost):
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: IHost,
|
||||||
|
read_timeout: int = DEFAULT_DCUTR_READ_TIMEOUT,
|
||||||
|
write_timeout: int = DEFAULT_DCUTR_WRITE_TIMEOUT,
|
||||||
|
dial_timeout: int = DEFAULT_DIAL_TIMEOUT,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the DCUtR protocol.
|
Initialize the DCUtR protocol.
|
||||||
|
|
||||||
@ -78,10 +85,19 @@ class DCUtRProtocol(Service):
|
|||||||
----------
|
----------
|
||||||
host : IHost
|
host : IHost
|
||||||
The libp2p host this protocol is running on
|
The libp2p host this protocol is running on
|
||||||
|
read_timeout : int
|
||||||
|
Timeout for stream read operations, in seconds
|
||||||
|
write_timeout : int
|
||||||
|
Timeout for stream write operations, in seconds
|
||||||
|
dial_timeout : int
|
||||||
|
Timeout for dial operations, in seconds
|
||||||
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.host = host
|
self.host = host
|
||||||
|
self.read_timeout = read_timeout
|
||||||
|
self.write_timeout = write_timeout
|
||||||
|
self.dial_timeout = dial_timeout
|
||||||
self.event_started = trio.Event()
|
self.event_started = trio.Event()
|
||||||
self._hole_punch_attempts: dict[ID, int] = {}
|
self._hole_punch_attempts: dict[ID, int] = {}
|
||||||
self._direct_connections: set[ID] = set()
|
self._direct_connections: set[ID] = set()
|
||||||
@ -161,7 +177,7 @@ class DCUtRProtocol(Service):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Read the CONNECT message
|
# Read the CONNECT message
|
||||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
with trio.fail_after(self.read_timeout):
|
||||||
msg_bytes = await stream.read(MAX_MESSAGE_SIZE)
|
msg_bytes = await stream.read(MAX_MESSAGE_SIZE)
|
||||||
|
|
||||||
# Parse the message
|
# Parse the message
|
||||||
@ -196,7 +212,7 @@ class DCUtRProtocol(Service):
|
|||||||
response.type = HolePunch.CONNECT
|
response.type = HolePunch.CONNECT
|
||||||
response.ObsAddrs.extend(our_addrs)
|
response.ObsAddrs.extend(our_addrs)
|
||||||
|
|
||||||
with trio.fail_after(STREAM_WRITE_TIMEOUT):
|
with trio.fail_after(self.write_timeout):
|
||||||
await stream.write(response.SerializeToString())
|
await stream.write(response.SerializeToString())
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@ -206,7 +222,7 @@ class DCUtRProtocol(Service):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Wait for SYNC message
|
# Wait for SYNC message
|
||||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
with trio.fail_after(self.read_timeout):
|
||||||
sync_bytes = await stream.read(MAX_MESSAGE_SIZE)
|
sync_bytes = await stream.read(MAX_MESSAGE_SIZE)
|
||||||
|
|
||||||
# Parse the SYNC message
|
# Parse the SYNC message
|
||||||
@ -300,7 +316,7 @@ class DCUtRProtocol(Service):
|
|||||||
connect_msg.ObsAddrs.extend(our_addrs)
|
connect_msg.ObsAddrs.extend(our_addrs)
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
with trio.fail_after(STREAM_WRITE_TIMEOUT):
|
with trio.fail_after(self.write_timeout):
|
||||||
await stream.write(connect_msg.SerializeToString())
|
await stream.write(connect_msg.SerializeToString())
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@ -310,7 +326,7 @@ class DCUtRProtocol(Service):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Receive the peer's CONNECT message
|
# Receive the peer's CONNECT message
|
||||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
with trio.fail_after(self.read_timeout):
|
||||||
resp_bytes = await stream.read(MAX_MESSAGE_SIZE)
|
resp_bytes = await stream.read(MAX_MESSAGE_SIZE)
|
||||||
|
|
||||||
# Calculate RTT
|
# Calculate RTT
|
||||||
@ -349,7 +365,7 @@ class DCUtRProtocol(Service):
|
|||||||
sync_msg = HolePunch()
|
sync_msg = HolePunch()
|
||||||
sync_msg.type = HolePunch.SYNC
|
sync_msg.type = HolePunch.SYNC
|
||||||
|
|
||||||
with trio.fail_after(STREAM_WRITE_TIMEOUT):
|
with trio.fail_after(self.write_timeout):
|
||||||
await stream.write(sync_msg.SerializeToString())
|
await stream.write(sync_msg.SerializeToString())
|
||||||
|
|
||||||
logger.debug("Sent SYNC message to %s", peer_id)
|
logger.debug("Sent SYNC message to %s", peer_id)
|
||||||
@ -468,7 +484,7 @@ class DCUtRProtocol(Service):
|
|||||||
peer_info = PeerInfo(peer_id, [addr])
|
peer_info = PeerInfo(peer_id, [addr])
|
||||||
|
|
||||||
# Try to connect with timeout
|
# Try to connect with timeout
|
||||||
with trio.fail_after(DIAL_TIMEOUT):
|
with trio.fail_after(self.dial_timeout):
|
||||||
await self.host.connect(peer_info)
|
await self.host.connect(peer_info)
|
||||||
|
|
||||||
logger.info("Successfully connected to %s at %s", peer_id, addr)
|
logger.info("Successfully connected to %s at %s", peer_id, addr)
|
||||||
@ -508,7 +524,9 @@ class DCUtRProtocol(Service):
|
|||||||
|
|
||||||
# Handle both single connection and list of connections
|
# Handle both single connection and list of connections
|
||||||
connections: list[INetConn] = (
|
connections: list[INetConn] = (
|
||||||
[conn_or_conns] if not isinstance(conn_or_conns, list) else conn_or_conns
|
list(conn_or_conns)
|
||||||
|
if not isinstance(conn_or_conns, list)
|
||||||
|
else conn_or_conns
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if any connection is direct (not relayed)
|
# Check if any connection is direct (not relayed)
|
||||||
|
|||||||
@ -31,6 +31,11 @@ from libp2p.tools.async_service import (
|
|||||||
Service,
|
Service,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .config import (
|
||||||
|
DEFAULT_DISCOVERY_INTERVAL,
|
||||||
|
DEFAULT_DISCOVERY_STREAM_TIMEOUT,
|
||||||
|
DEFAULT_PEER_PROTOCOL_TIMEOUT,
|
||||||
|
)
|
||||||
from .pb.circuit_pb2 import (
|
from .pb.circuit_pb2 import (
|
||||||
HopMessage,
|
HopMessage,
|
||||||
)
|
)
|
||||||
@ -43,10 +48,8 @@ from .protocol_buffer import (
|
|||||||
|
|
||||||
logger = logging.getLogger("libp2p.relay.circuit_v2.discovery")
|
logger = logging.getLogger("libp2p.relay.circuit_v2.discovery")
|
||||||
|
|
||||||
# Constants
|
# Discovery constants
|
||||||
MAX_RELAYS_TO_TRACK = 10
|
MAX_RELAYS_TO_TRACK = 10
|
||||||
DEFAULT_DISCOVERY_INTERVAL = 60 # seconds
|
|
||||||
STREAM_TIMEOUT = 10 # seconds
|
|
||||||
|
|
||||||
|
|
||||||
# Extended interfaces for type checking
|
# Extended interfaces for type checking
|
||||||
@ -86,6 +89,8 @@ class RelayDiscovery(Service):
|
|||||||
auto_reserve: bool = False,
|
auto_reserve: bool = False,
|
||||||
discovery_interval: int = DEFAULT_DISCOVERY_INTERVAL,
|
discovery_interval: int = DEFAULT_DISCOVERY_INTERVAL,
|
||||||
max_relays: int = MAX_RELAYS_TO_TRACK,
|
max_relays: int = MAX_RELAYS_TO_TRACK,
|
||||||
|
stream_timeout: int = DEFAULT_DISCOVERY_STREAM_TIMEOUT,
|
||||||
|
peer_protocol_timeout: int = DEFAULT_PEER_PROTOCOL_TIMEOUT,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the discovery service.
|
Initialize the discovery service.
|
||||||
@ -100,6 +105,10 @@ class RelayDiscovery(Service):
|
|||||||
How often to run discovery, in seconds
|
How often to run discovery, in seconds
|
||||||
max_relays : int
|
max_relays : int
|
||||||
Maximum number of relays to track
|
Maximum number of relays to track
|
||||||
|
stream_timeout : int
|
||||||
|
Timeout for stream operations during discovery, in seconds
|
||||||
|
peer_protocol_timeout : int
|
||||||
|
Timeout for checking peer protocol support, in seconds
|
||||||
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -107,6 +116,8 @@ class RelayDiscovery(Service):
|
|||||||
self.auto_reserve = auto_reserve
|
self.auto_reserve = auto_reserve
|
||||||
self.discovery_interval = discovery_interval
|
self.discovery_interval = discovery_interval
|
||||||
self.max_relays = max_relays
|
self.max_relays = max_relays
|
||||||
|
self.stream_timeout = stream_timeout
|
||||||
|
self.peer_protocol_timeout = peer_protocol_timeout
|
||||||
self._discovered_relays: dict[ID, RelayInfo] = {}
|
self._discovered_relays: dict[ID, RelayInfo] = {}
|
||||||
self._protocol_cache: dict[
|
self._protocol_cache: dict[
|
||||||
ID, set[str]
|
ID, set[str]
|
||||||
@ -165,8 +176,8 @@ class RelayDiscovery(Service):
|
|||||||
self._discovered_relays[peer_id].last_seen = time.time()
|
self._discovered_relays[peer_id].last_seen = time.time()
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Check if peer supports the relay protocol
|
# Don't wait too long for protocol info
|
||||||
with trio.move_on_after(5): # Don't wait too long for protocol info
|
with trio.move_on_after(self.peer_protocol_timeout):
|
||||||
if await self._supports_relay_protocol(peer_id):
|
if await self._supports_relay_protocol(peer_id):
|
||||||
await self._add_relay(peer_id)
|
await self._add_relay(peer_id)
|
||||||
|
|
||||||
@ -264,7 +275,7 @@ class RelayDiscovery(Service):
|
|||||||
async def _check_via_direct_connection(self, peer_id: ID) -> bool | None:
|
async def _check_via_direct_connection(self, peer_id: ID) -> bool | None:
|
||||||
"""Check protocol support via direct connection."""
|
"""Check protocol support via direct connection."""
|
||||||
try:
|
try:
|
||||||
with trio.fail_after(STREAM_TIMEOUT):
|
with trio.fail_after(self.stream_timeout):
|
||||||
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
|
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
|
||||||
if stream:
|
if stream:
|
||||||
await stream.close()
|
await stream.close()
|
||||||
@ -370,7 +381,7 @@ class RelayDiscovery(Service):
|
|||||||
|
|
||||||
# Open a stream to the relay with timeout
|
# Open a stream to the relay with timeout
|
||||||
try:
|
try:
|
||||||
with trio.fail_after(STREAM_TIMEOUT):
|
with trio.fail_after(self.stream_timeout):
|
||||||
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
|
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
|
||||||
if not stream:
|
if not stream:
|
||||||
logger.error("Failed to open stream to relay %s", peer_id)
|
logger.error("Failed to open stream to relay %s", peer_id)
|
||||||
@ -386,7 +397,7 @@ class RelayDiscovery(Service):
|
|||||||
peer=self.host.get_id().to_bytes(),
|
peer=self.host.get_id().to_bytes(),
|
||||||
)
|
)
|
||||||
|
|
||||||
with trio.fail_after(STREAM_TIMEOUT):
|
with trio.fail_after(self.stream_timeout):
|
||||||
await stream.write(request.SerializeToString())
|
await stream.write(request.SerializeToString())
|
||||||
|
|
||||||
# Wait for response
|
# Wait for response
|
||||||
|
|||||||
@ -5,6 +5,7 @@ This module implements the Circuit Relay v2 protocol as specified in:
|
|||||||
https://github.com/libp2p/specs/blob/master/relay/circuit-v2.md
|
https://github.com/libp2p/specs/blob/master/relay/circuit-v2.md
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from enum import Enum, auto
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -37,6 +38,15 @@ from libp2p.tools.async_service import (
|
|||||||
Service,
|
Service,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .config import (
|
||||||
|
DEFAULT_MAX_CIRCUIT_BYTES,
|
||||||
|
DEFAULT_MAX_CIRCUIT_CONNS,
|
||||||
|
DEFAULT_MAX_CIRCUIT_DURATION,
|
||||||
|
DEFAULT_MAX_RESERVATIONS,
|
||||||
|
DEFAULT_PROTOCOL_CLOSE_TIMEOUT,
|
||||||
|
DEFAULT_PROTOCOL_READ_TIMEOUT,
|
||||||
|
DEFAULT_PROTOCOL_WRITE_TIMEOUT,
|
||||||
|
)
|
||||||
from .pb.circuit_pb2 import (
|
from .pb.circuit_pb2 import (
|
||||||
HopMessage,
|
HopMessage,
|
||||||
Limit,
|
Limit,
|
||||||
@ -58,18 +68,22 @@ logger = logging.getLogger("libp2p.relay.circuit_v2")
|
|||||||
PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0")
|
PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0")
|
||||||
STOP_PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0/stop")
|
STOP_PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0/stop")
|
||||||
|
|
||||||
|
|
||||||
|
# Direction enum for data piping
|
||||||
|
class Pipe(Enum):
|
||||||
|
SRC_TO_DST = auto()
|
||||||
|
DST_TO_SRC = auto()
|
||||||
|
|
||||||
|
|
||||||
# Default limits for relay resources
|
# Default limits for relay resources
|
||||||
DEFAULT_RELAY_LIMITS = RelayLimits(
|
DEFAULT_RELAY_LIMITS = RelayLimits(
|
||||||
duration=60 * 60, # 1 hour
|
duration=DEFAULT_MAX_CIRCUIT_DURATION,
|
||||||
data=1024 * 1024 * 1024, # 1GB
|
data=DEFAULT_MAX_CIRCUIT_BYTES,
|
||||||
max_circuit_conns=8,
|
max_circuit_conns=DEFAULT_MAX_CIRCUIT_CONNS,
|
||||||
max_reservations=4,
|
max_reservations=DEFAULT_MAX_RESERVATIONS,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Stream operation timeouts
|
# Stream operation constants
|
||||||
STREAM_READ_TIMEOUT = 15 # seconds
|
|
||||||
STREAM_WRITE_TIMEOUT = 15 # seconds
|
|
||||||
STREAM_CLOSE_TIMEOUT = 10 # seconds
|
|
||||||
MAX_READ_RETRIES = 5 # Maximum number of read retries
|
MAX_READ_RETRIES = 5 # Maximum number of read retries
|
||||||
|
|
||||||
|
|
||||||
@ -113,6 +127,9 @@ class CircuitV2Protocol(Service):
|
|||||||
host: IHost,
|
host: IHost,
|
||||||
limits: RelayLimits | None = None,
|
limits: RelayLimits | None = None,
|
||||||
allow_hop: bool = False,
|
allow_hop: bool = False,
|
||||||
|
read_timeout: int = DEFAULT_PROTOCOL_READ_TIMEOUT,
|
||||||
|
write_timeout: int = DEFAULT_PROTOCOL_WRITE_TIMEOUT,
|
||||||
|
close_timeout: int = DEFAULT_PROTOCOL_CLOSE_TIMEOUT,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize a Circuit Relay v2 protocol instance.
|
Initialize a Circuit Relay v2 protocol instance.
|
||||||
@ -125,11 +142,20 @@ class CircuitV2Protocol(Service):
|
|||||||
Resource limits for the relay
|
Resource limits for the relay
|
||||||
allow_hop : bool
|
allow_hop : bool
|
||||||
Whether to allow this node to act as a relay
|
Whether to allow this node to act as a relay
|
||||||
|
read_timeout : int
|
||||||
|
Timeout for stream read operations, in seconds
|
||||||
|
write_timeout : int
|
||||||
|
Timeout for stream write operations, in seconds
|
||||||
|
close_timeout : int
|
||||||
|
Timeout for stream close operations, in seconds
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.host = host
|
self.host = host
|
||||||
self.limits = limits or DEFAULT_RELAY_LIMITS
|
self.limits = limits or DEFAULT_RELAY_LIMITS
|
||||||
self.allow_hop = allow_hop
|
self.allow_hop = allow_hop
|
||||||
|
self.read_timeout = read_timeout
|
||||||
|
self.write_timeout = write_timeout
|
||||||
|
self.close_timeout = close_timeout
|
||||||
self.resource_manager = RelayResourceManager(self.limits)
|
self.resource_manager = RelayResourceManager(self.limits)
|
||||||
self._active_relays: dict[ID, tuple[INetStream, INetStream | None]] = {}
|
self._active_relays: dict[ID, tuple[INetStream, INetStream | None]] = {}
|
||||||
self.event_started = trio.Event()
|
self.event_started = trio.Event()
|
||||||
@ -174,7 +200,7 @@ class CircuitV2Protocol(Service):
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with trio.fail_after(STREAM_CLOSE_TIMEOUT):
|
with trio.fail_after(self.close_timeout):
|
||||||
await stream.close()
|
await stream.close()
|
||||||
except Exception:
|
except Exception:
|
||||||
try:
|
try:
|
||||||
@ -216,7 +242,7 @@ class CircuitV2Protocol(Service):
|
|||||||
|
|
||||||
while retries < max_retries:
|
while retries < max_retries:
|
||||||
try:
|
try:
|
||||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
with trio.fail_after(self.read_timeout):
|
||||||
# Try reading with timeout
|
# Try reading with timeout
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Attempting to read from stream (attempt %d/%d)",
|
"Attempting to read from stream (attempt %d/%d)",
|
||||||
@ -293,7 +319,7 @@ class CircuitV2Protocol(Service):
|
|||||||
# First, handle the read timeout gracefully
|
# First, handle the read timeout gracefully
|
||||||
try:
|
try:
|
||||||
with trio.fail_after(
|
with trio.fail_after(
|
||||||
STREAM_READ_TIMEOUT * 2
|
self.read_timeout * 2
|
||||||
): # Double the timeout for reading
|
): # Double the timeout for reading
|
||||||
msg_bytes = await stream.read()
|
msg_bytes = await stream.read()
|
||||||
if not msg_bytes:
|
if not msg_bytes:
|
||||||
@ -414,7 +440,7 @@ class CircuitV2Protocol(Service):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Read the incoming message with timeout
|
# Read the incoming message with timeout
|
||||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
with trio.fail_after(self.read_timeout):
|
||||||
msg_bytes = await stream.read()
|
msg_bytes = await stream.read()
|
||||||
stop_msg = StopMessage()
|
stop_msg = StopMessage()
|
||||||
stop_msg.ParseFromString(msg_bytes)
|
stop_msg.ParseFromString(msg_bytes)
|
||||||
@ -458,8 +484,20 @@ class CircuitV2Protocol(Service):
|
|||||||
|
|
||||||
# Start relaying data
|
# Start relaying data
|
||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
nursery.start_soon(self._relay_data, src_stream, stream, peer_id)
|
nursery.start_soon(
|
||||||
nursery.start_soon(self._relay_data, stream, src_stream, peer_id)
|
self._relay_data,
|
||||||
|
src_stream,
|
||||||
|
stream,
|
||||||
|
peer_id,
|
||||||
|
Pipe.SRC_TO_DST,
|
||||||
|
)
|
||||||
|
nursery.start_soon(
|
||||||
|
self._relay_data,
|
||||||
|
stream,
|
||||||
|
src_stream,
|
||||||
|
peer_id,
|
||||||
|
Pipe.DST_TO_SRC,
|
||||||
|
)
|
||||||
|
|
||||||
except trio.TooSlowError:
|
except trio.TooSlowError:
|
||||||
logger.error("Timeout reading from stop stream")
|
logger.error("Timeout reading from stop stream")
|
||||||
@ -509,7 +547,7 @@ class CircuitV2Protocol(Service):
|
|||||||
ttl = self.resource_manager.reserve(peer_id)
|
ttl = self.resource_manager.reserve(peer_id)
|
||||||
|
|
||||||
# Send reservation success response
|
# Send reservation success response
|
||||||
with trio.fail_after(STREAM_WRITE_TIMEOUT):
|
with trio.fail_after(self.write_timeout):
|
||||||
status = create_status(
|
status = create_status(
|
||||||
code=StatusCode.OK, message="Reservation accepted"
|
code=StatusCode.OK, message="Reservation accepted"
|
||||||
)
|
)
|
||||||
@ -560,7 +598,7 @@ class CircuitV2Protocol(Service):
|
|||||||
# Always close the stream when done with reservation
|
# Always close the stream when done with reservation
|
||||||
if cast(INetStreamWithExtras, stream).is_open():
|
if cast(INetStreamWithExtras, stream).is_open():
|
||||||
try:
|
try:
|
||||||
with trio.fail_after(STREAM_CLOSE_TIMEOUT):
|
with trio.fail_after(self.close_timeout):
|
||||||
await stream.close()
|
await stream.close()
|
||||||
except Exception as close_err:
|
except Exception as close_err:
|
||||||
logger.error("Error closing stream: %s", str(close_err))
|
logger.error("Error closing stream: %s", str(close_err))
|
||||||
@ -596,7 +634,7 @@ class CircuitV2Protocol(Service):
|
|||||||
self._active_relays[peer_id] = (stream, None)
|
self._active_relays[peer_id] = (stream, None)
|
||||||
|
|
||||||
# Try to connect to the destination with timeout
|
# Try to connect to the destination with timeout
|
||||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
with trio.fail_after(self.read_timeout):
|
||||||
dst_stream = await self.host.new_stream(peer_id, [STOP_PROTOCOL_ID])
|
dst_stream = await self.host.new_stream(peer_id, [STOP_PROTOCOL_ID])
|
||||||
if not dst_stream:
|
if not dst_stream:
|
||||||
raise ConnectionError("Could not connect to destination")
|
raise ConnectionError("Could not connect to destination")
|
||||||
@ -648,8 +686,20 @@ class CircuitV2Protocol(Service):
|
|||||||
|
|
||||||
# Start relaying data
|
# Start relaying data
|
||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
nursery.start_soon(self._relay_data, stream, dst_stream, peer_id)
|
nursery.start_soon(
|
||||||
nursery.start_soon(self._relay_data, dst_stream, stream, peer_id)
|
self._relay_data,
|
||||||
|
stream,
|
||||||
|
dst_stream,
|
||||||
|
peer_id,
|
||||||
|
Pipe.SRC_TO_DST,
|
||||||
|
)
|
||||||
|
nursery.start_soon(
|
||||||
|
self._relay_data,
|
||||||
|
dst_stream,
|
||||||
|
stream,
|
||||||
|
peer_id,
|
||||||
|
Pipe.DST_TO_SRC,
|
||||||
|
)
|
||||||
|
|
||||||
except (trio.TooSlowError, ConnectionError) as e:
|
except (trio.TooSlowError, ConnectionError) as e:
|
||||||
logger.error("Error establishing relay connection: %s", str(e))
|
logger.error("Error establishing relay connection: %s", str(e))
|
||||||
@ -685,6 +735,7 @@ class CircuitV2Protocol(Service):
|
|||||||
src_stream: INetStream,
|
src_stream: INetStream,
|
||||||
dst_stream: INetStream,
|
dst_stream: INetStream,
|
||||||
peer_id: ID,
|
peer_id: ID,
|
||||||
|
direction: Pipe,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Relay data between two streams.
|
Relay data between two streams.
|
||||||
@ -698,24 +749,27 @@ class CircuitV2Protocol(Service):
|
|||||||
peer_id : ID
|
peer_id : ID
|
||||||
ID of the peer being relayed
|
ID of the peer being relayed
|
||||||
|
|
||||||
|
direction : Pipe
|
||||||
|
Direction of data flow (``Pipe.SRC_TO_DST`` or ``Pipe.DST_TO_SRC``)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
# Read data with retries
|
# Read data with retries
|
||||||
data = await self._read_stream_with_retry(src_stream)
|
data = await self._read_stream_with_retry(src_stream)
|
||||||
if not data:
|
if not data:
|
||||||
logger.info("Source stream closed/reset")
|
logger.info("%s closed/reset", direction.name)
|
||||||
break
|
break
|
||||||
|
|
||||||
# Write data with timeout
|
# Write data with timeout
|
||||||
try:
|
try:
|
||||||
with trio.fail_after(STREAM_WRITE_TIMEOUT):
|
with trio.fail_after(self.write_timeout):
|
||||||
await dst_stream.write(data)
|
await dst_stream.write(data)
|
||||||
except trio.TooSlowError:
|
except trio.TooSlowError:
|
||||||
logger.error("Timeout writing to destination stream")
|
logger.error("Timeout writing in %s", direction.name)
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error writing to destination stream: %s", str(e))
|
logger.error("Error writing in %s: %s", direction.name, str(e))
|
||||||
break
|
break
|
||||||
|
|
||||||
# Update resource usage
|
# Update resource usage
|
||||||
@ -744,7 +798,7 @@ class CircuitV2Protocol(Service):
|
|||||||
"""Send a status message."""
|
"""Send a status message."""
|
||||||
try:
|
try:
|
||||||
logger.debug("Sending status message with code %s: %s", code, message)
|
logger.debug("Sending status message with code %s: %s", code, message)
|
||||||
with trio.fail_after(STREAM_WRITE_TIMEOUT * 2): # Double the timeout
|
with trio.fail_after(self.write_timeout * 2): # Double the timeout
|
||||||
# Create a proto Status directly
|
# Create a proto Status directly
|
||||||
pb_status = PbStatus()
|
pb_status = PbStatus()
|
||||||
pb_status.code = cast(
|
pb_status.code = cast(
|
||||||
@ -782,7 +836,7 @@ class CircuitV2Protocol(Service):
|
|||||||
"""Send a status message on a STOP stream."""
|
"""Send a status message on a STOP stream."""
|
||||||
try:
|
try:
|
||||||
logger.debug("Sending stop status message with code %s: %s", code, message)
|
logger.debug("Sending stop status message with code %s: %s", code, message)
|
||||||
with trio.fail_after(STREAM_WRITE_TIMEOUT * 2): # Double the timeout
|
with trio.fail_after(self.write_timeout * 2): # Double the timeout
|
||||||
# Create a proto Status directly
|
# Create a proto Status directly
|
||||||
pb_status = PbStatus()
|
pb_status = PbStatus()
|
||||||
pb_status.code = cast(
|
pb_status.code = cast(
|
||||||
|
|||||||
@ -8,6 +8,7 @@ including reservations and connection limits.
|
|||||||
from dataclasses import (
|
from dataclasses import (
|
||||||
dataclass,
|
dataclass,
|
||||||
)
|
)
|
||||||
|
from enum import Enum, auto
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
@ -19,6 +20,18 @@ from libp2p.peer.id import (
|
|||||||
# Import the protobuf definitions
|
# Import the protobuf definitions
|
||||||
from .pb.circuit_pb2 import Reservation as PbReservation
|
from .pb.circuit_pb2 import Reservation as PbReservation
|
||||||
|
|
||||||
|
RANDOM_BYTES_LENGTH = 16 # 128 bits of randomness
|
||||||
|
TIMESTAMP_MULTIPLIER = 1000000 # To convert seconds to microseconds
|
||||||
|
|
||||||
|
|
||||||
|
# Reservation status enum
|
||||||
|
class ReservationStatus(Enum):
|
||||||
|
"""Lifecycle status of a relay reservation."""
|
||||||
|
|
||||||
|
ACTIVE = auto()
|
||||||
|
EXPIRED = auto()
|
||||||
|
REJECTED = auto()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RelayLimits:
|
class RelayLimits:
|
||||||
@ -68,8 +81,8 @@ class Reservation:
|
|||||||
# - Peer ID to bind it to the specific peer
|
# - Peer ID to bind it to the specific peer
|
||||||
# - Timestamp for uniqueness
|
# - Timestamp for uniqueness
|
||||||
# - Hash everything for a fixed size output
|
# - Hash everything for a fixed size output
|
||||||
random_bytes = os.urandom(16) # 128 bits of randomness
|
random_bytes = os.urandom(RANDOM_BYTES_LENGTH)
|
||||||
timestamp = str(int(self.created_at * 1000000)).encode()
|
timestamp = str(int(self.created_at * TIMESTAMP_MULTIPLIER)).encode()
|
||||||
peer_bytes = self.peer_id.to_bytes()
|
peer_bytes = self.peer_id.to_bytes()
|
||||||
|
|
||||||
# Combine all elements and hash them
|
# Combine all elements and hash them
|
||||||
@ -84,6 +97,15 @@ class Reservation:
|
|||||||
"""Check if the reservation has expired."""
|
"""Check if the reservation has expired."""
|
||||||
return time.time() > self.expires_at
|
return time.time() > self.expires_at
|
||||||
|
|
||||||
|
# Expose a friendly status enum
|
||||||
|
|
||||||
|
@property
|
||||||
|
def status(self) -> ReservationStatus:
|
||||||
|
"""Return the current status as a ``ReservationStatus`` enum."""
|
||||||
|
return (
|
||||||
|
ReservationStatus.EXPIRED if self.is_expired() else ReservationStatus.ACTIVE
|
||||||
|
)
|
||||||
|
|
||||||
def can_accept_connection(self) -> bool:
|
def can_accept_connection(self) -> bool:
|
||||||
"""Check if a new connection can be accepted."""
|
"""Check if a new connection can be accepted."""
|
||||||
return (
|
return (
|
||||||
|
|||||||
@ -89,7 +89,10 @@ class CircuitV2Transport(ITransport):
|
|||||||
auto_reserve=config.enable_client,
|
auto_reserve=config.enable_client,
|
||||||
discovery_interval=config.discovery_interval,
|
discovery_interval=config.discovery_interval,
|
||||||
max_relays=config.max_relays,
|
max_relays=config.max_relays,
|
||||||
|
stream_timeout=config.timeouts.discovery_stream_timeout,
|
||||||
|
peer_protocol_timeout=config.timeouts.peer_protocol_timeout,
|
||||||
)
|
)
|
||||||
|
self.relay_counter = 0 # for round robin load balancing
|
||||||
|
|
||||||
async def dial(
|
async def dial(
|
||||||
self,
|
self,
|
||||||
@ -219,11 +222,25 @@ class CircuitV2Transport(ITransport):
|
|||||||
# Get a relay from the list of discovered relays
|
# Get a relay from the list of discovered relays
|
||||||
relays = self.discovery.get_relays()
|
relays = self.discovery.get_relays()
|
||||||
if relays:
|
if relays:
|
||||||
# TODO: Implement more sophisticated relay selection
|
# Prioritize relays with active reservations
|
||||||
# For now, just return the first available relay
|
relays_with_reservations = []
|
||||||
return relays[0]
|
other_relays = []
|
||||||
|
|
||||||
# Wait and try discovery
|
for relay_id in relays:
|
||||||
|
relay_info = self.discovery.get_relay_info(relay_id)
|
||||||
|
if relay_info and relay_info.has_reservation:
|
||||||
|
relays_with_reservations.append(relay_id)
|
||||||
|
else:
|
||||||
|
other_relays.append(relay_id)
|
||||||
|
|
||||||
|
# Return first available relay with reservation, or fallback to others
|
||||||
|
self.relay_counter += 1
|
||||||
|
if relays_with_reservations:
|
||||||
|
return relays_with_reservations[
|
||||||
|
(self.relay_counter - 1) % len(relays_with_reservations)
|
||||||
|
]
|
||||||
|
elif other_relays:
|
||||||
|
return other_relays[(self.relay_counter - 1) % len(other_relays)]
|
||||||
await trio.sleep(1)
|
await trio.sleep(1)
|
||||||
attempts += 1
|
attempts += 1
|
||||||
|
|
||||||
|
|||||||
@ -1,68 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
from libp2p.abc import IRawConnection
|
|
||||||
from libp2p.custom_types import TProtocol
|
|
||||||
from libp2p.peer.id import ID
|
|
||||||
|
|
||||||
from .pb import noise_pb2 as noise_pb
|
|
||||||
|
|
||||||
|
|
||||||
class EarlyDataHandler(ABC):
|
|
||||||
"""Interface for handling early data during Noise handshake"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def send(
|
|
||||||
self, conn: IRawConnection, peer_id: ID
|
|
||||||
) -> noise_pb.NoiseExtensions | None:
|
|
||||||
"""Called to generate early data to send during handshake"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def received(
|
|
||||||
self, conn: IRawConnection, extensions: noise_pb.NoiseExtensions | None
|
|
||||||
) -> None:
|
|
||||||
"""Called when early data is received during handshake"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TransportEarlyDataHandler(EarlyDataHandler):
|
|
||||||
"""Default early data handler for muxer negotiation"""
|
|
||||||
|
|
||||||
def __init__(self, supported_muxers: list[TProtocol]):
|
|
||||||
self.supported_muxers = supported_muxers
|
|
||||||
self.received_muxers: list[TProtocol] = []
|
|
||||||
|
|
||||||
async def send(
|
|
||||||
self, conn: IRawConnection, peer_id: ID
|
|
||||||
) -> noise_pb.NoiseExtensions | None:
|
|
||||||
"""Send our supported muxers list"""
|
|
||||||
if not self.supported_muxers:
|
|
||||||
return None
|
|
||||||
|
|
||||||
extensions = noise_pb.NoiseExtensions()
|
|
||||||
# Convert TProtocol to string for serialization
|
|
||||||
extensions.stream_muxers[:] = [str(muxer) for muxer in self.supported_muxers]
|
|
||||||
return extensions
|
|
||||||
|
|
||||||
async def received(
|
|
||||||
self, conn: IRawConnection, extensions: noise_pb.NoiseExtensions | None
|
|
||||||
) -> None:
|
|
||||||
"""Store received muxers list"""
|
|
||||||
if extensions and extensions.stream_muxers:
|
|
||||||
self.received_muxers = [
|
|
||||||
TProtocol(muxer) for muxer in extensions.stream_muxers
|
|
||||||
]
|
|
||||||
|
|
||||||
def match_muxers(self, is_initiator: bool) -> TProtocol | None:
|
|
||||||
"""Find first common muxer between local and remote"""
|
|
||||||
if is_initiator:
|
|
||||||
# Initiator: find first local muxer that remote supports
|
|
||||||
for local_muxer in self.supported_muxers:
|
|
||||||
if local_muxer in self.received_muxers:
|
|
||||||
return local_muxer
|
|
||||||
else:
|
|
||||||
# Responder: find first remote muxer that we support
|
|
||||||
for remote_muxer in self.received_muxers:
|
|
||||||
if remote_muxer in self.supported_muxers:
|
|
||||||
return remote_muxer
|
|
||||||
return None
|
|
||||||
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
@ -15,6 +16,8 @@ from libp2p.io.msgio import (
|
|||||||
FixedSizeLenMsgReadWriter,
|
FixedSizeLenMsgReadWriter,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
SIZE_NOISE_MESSAGE_LEN = 2
|
SIZE_NOISE_MESSAGE_LEN = 2
|
||||||
MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1
|
MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1
|
||||||
SIZE_NOISE_MESSAGE_BODY_LEN = 2
|
SIZE_NOISE_MESSAGE_BODY_LEN = 2
|
||||||
@ -50,18 +53,25 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
|
|||||||
self.noise_state = noise_state
|
self.noise_state = noise_state
|
||||||
|
|
||||||
async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None:
|
async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None:
|
||||||
|
logger.debug(f"Noise write_msg: encrypting {len(msg)} bytes")
|
||||||
data_encrypted = self.encrypt(msg)
|
data_encrypted = self.encrypt(msg)
|
||||||
if prefix_encoded:
|
if prefix_encoded:
|
||||||
# Manually add the prefix if needed
|
# Manually add the prefix if needed
|
||||||
data_encrypted = self.prefix + data_encrypted
|
data_encrypted = self.prefix + data_encrypted
|
||||||
|
logger.debug(f"Noise write_msg: writing {len(data_encrypted)} encrypted bytes")
|
||||||
await self.read_writer.write_msg(data_encrypted)
|
await self.read_writer.write_msg(data_encrypted)
|
||||||
|
logger.debug("Noise write_msg: write completed successfully")
|
||||||
|
|
||||||
async def read_msg(self, prefix_encoded: bool = False) -> bytes:
|
async def read_msg(self, prefix_encoded: bool = False) -> bytes:
|
||||||
|
logger.debug("Noise read_msg: reading encrypted message")
|
||||||
noise_msg_encrypted = await self.read_writer.read_msg()
|
noise_msg_encrypted = await self.read_writer.read_msg()
|
||||||
|
logger.debug(f"Noise read_msg: read {len(noise_msg_encrypted)} encrypted bytes")
|
||||||
if prefix_encoded:
|
if prefix_encoded:
|
||||||
return self.decrypt(noise_msg_encrypted[len(self.prefix) :])
|
result = self.decrypt(noise_msg_encrypted[len(self.prefix) :])
|
||||||
else:
|
else:
|
||||||
return self.decrypt(noise_msg_encrypted)
|
result = self.decrypt(noise_msg_encrypted)
|
||||||
|
logger.debug(f"Noise read_msg: decrypted to {len(result)} bytes")
|
||||||
|
return result
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
await self.read_writer.close()
|
await self.read_writer.close()
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from dataclasses import (
|
from dataclasses import (
|
||||||
dataclass,
|
dataclass,
|
||||||
)
|
)
|
||||||
|
import logging
|
||||||
|
|
||||||
from libp2p.crypto.keys import (
|
from libp2p.crypto.keys import (
|
||||||
PrivateKey,
|
PrivateKey,
|
||||||
@ -12,6 +13,8 @@ from libp2p.crypto.serialization import (
|
|||||||
|
|
||||||
from .pb import noise_pb2 as noise_pb
|
from .pb import noise_pb2 as noise_pb
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
SIGNED_DATA_PREFIX = "noise-libp2p-static-key:"
|
SIGNED_DATA_PREFIX = "noise-libp2p-static-key:"
|
||||||
|
|
||||||
|
|
||||||
@ -48,6 +51,8 @@ def make_handshake_payload_sig(
|
|||||||
id_privkey: PrivateKey, noise_static_pubkey: PublicKey
|
id_privkey: PrivateKey, noise_static_pubkey: PublicKey
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
data = make_data_to_be_signed(noise_static_pubkey)
|
data = make_data_to_be_signed(noise_static_pubkey)
|
||||||
|
logger.debug(f"make_handshake_payload_sig: signing data length: {len(data)}")
|
||||||
|
logger.debug(f"make_handshake_payload_sig: signing data hex: {data.hex()}")
|
||||||
return id_privkey.sign(data)
|
return id_privkey.sign(data)
|
||||||
|
|
||||||
|
|
||||||
@ -60,4 +65,27 @@ def verify_handshake_payload_sig(
|
|||||||
2. signed by the private key corresponding to `id_pubkey`
|
2. signed by the private key corresponding to `id_pubkey`
|
||||||
"""
|
"""
|
||||||
expected_data = make_data_to_be_signed(noise_static_pubkey)
|
expected_data = make_data_to_be_signed(noise_static_pubkey)
|
||||||
return payload.id_pubkey.verify(expected_data, payload.id_sig)
|
logger.debug(
|
||||||
|
f"verify_handshake_payload_sig: payload.id_pubkey type: "
|
||||||
|
f"{type(payload.id_pubkey)}"
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"verify_handshake_payload_sig: noise_static_pubkey type: "
|
||||||
|
f"{type(noise_static_pubkey)}"
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"verify_handshake_payload_sig: expected_data length: {len(expected_data)}"
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"verify_handshake_payload_sig: expected_data hex: {expected_data.hex()}"
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"verify_handshake_payload_sig: payload.id_sig length: {len(payload.id_sig)}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
result = payload.id_pubkey.verify(expected_data, payload.id_sig)
|
||||||
|
logger.debug(f"verify_handshake_payload_sig: verification result: {result}")
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"verify_handshake_payload_sig: verification exception: {e}")
|
||||||
|
return False
|
||||||
|
|||||||
@ -2,6 +2,7 @@ from abc import (
|
|||||||
ABC,
|
ABC,
|
||||||
abstractmethod,
|
abstractmethod,
|
||||||
)
|
)
|
||||||
|
import logging
|
||||||
|
|
||||||
from cryptography.hazmat.primitives import (
|
from cryptography.hazmat.primitives import (
|
||||||
serialization,
|
serialization,
|
||||||
@ -30,9 +31,6 @@ from libp2p.security.secure_session import (
|
|||||||
SecureSession,
|
SecureSession,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .early_data import (
|
|
||||||
EarlyDataHandler,
|
|
||||||
)
|
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
HandshakeHasNotFinished,
|
HandshakeHasNotFinished,
|
||||||
InvalidSignature,
|
InvalidSignature,
|
||||||
@ -48,7 +46,8 @@ from .messages import (
|
|||||||
make_handshake_payload_sig,
|
make_handshake_payload_sig,
|
||||||
verify_handshake_payload_sig,
|
verify_handshake_payload_sig,
|
||||||
)
|
)
|
||||||
from .pb import noise_pb2 as noise_pb
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class IPattern(ABC):
|
class IPattern(ABC):
|
||||||
@ -66,8 +65,7 @@ class BasePattern(IPattern):
|
|||||||
noise_static_key: PrivateKey
|
noise_static_key: PrivateKey
|
||||||
local_peer: ID
|
local_peer: ID
|
||||||
libp2p_privkey: PrivateKey
|
libp2p_privkey: PrivateKey
|
||||||
initiator_early_data_handler: EarlyDataHandler | None
|
early_data: bytes | None
|
||||||
responder_early_data_handler: EarlyDataHandler | None
|
|
||||||
|
|
||||||
def create_noise_state(self) -> NoiseState:
|
def create_noise_state(self) -> NoiseState:
|
||||||
noise_state = NoiseState.from_name(self.protocol_name)
|
noise_state = NoiseState.from_name(self.protocol_name)
|
||||||
@ -78,50 +76,11 @@ class BasePattern(IPattern):
|
|||||||
raise NoiseStateError("noise_protocol is not initialized")
|
raise NoiseStateError("noise_protocol is not initialized")
|
||||||
return noise_state
|
return noise_state
|
||||||
|
|
||||||
async def make_handshake_payload(
|
def make_handshake_payload(self) -> NoiseHandshakePayload:
|
||||||
self, conn: IRawConnection, peer_id: ID, is_initiator: bool
|
|
||||||
) -> NoiseHandshakePayload:
|
|
||||||
signature = make_handshake_payload_sig(
|
signature = make_handshake_payload_sig(
|
||||||
self.libp2p_privkey, self.noise_static_key.get_public_key()
|
self.libp2p_privkey, self.noise_static_key.get_public_key()
|
||||||
)
|
)
|
||||||
|
return NoiseHandshakePayload(self.libp2p_privkey.get_public_key(), signature)
|
||||||
# NEW: Get early data from appropriate handler
|
|
||||||
extensions = None
|
|
||||||
if is_initiator and self.initiator_early_data_handler:
|
|
||||||
extensions = await self.initiator_early_data_handler.send(conn, peer_id)
|
|
||||||
elif not is_initiator and self.responder_early_data_handler:
|
|
||||||
extensions = await self.responder_early_data_handler.send(conn, peer_id)
|
|
||||||
|
|
||||||
# NEW: Serialize extensions into early_data field
|
|
||||||
early_data = None
|
|
||||||
if extensions:
|
|
||||||
early_data = extensions.SerializeToString()
|
|
||||||
|
|
||||||
return NoiseHandshakePayload(
|
|
||||||
self.libp2p_privkey.get_public_key(),
|
|
||||||
signature,
|
|
||||||
early_data, # ← This is the key addition
|
|
||||||
)
|
|
||||||
|
|
||||||
async def handle_received_payload(
|
|
||||||
self, conn: IRawConnection, payload: NoiseHandshakePayload, is_initiator: bool
|
|
||||||
) -> None:
|
|
||||||
"""Process early data from received payload"""
|
|
||||||
if not payload.early_data:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Deserialize the NoiseExtensions from early_data field
|
|
||||||
try:
|
|
||||||
extensions = noise_pb.NoiseExtensions.FromString(payload.early_data)
|
|
||||||
except Exception:
|
|
||||||
# Invalid extensions, ignore silently
|
|
||||||
return
|
|
||||||
|
|
||||||
# Pass to appropriate handler
|
|
||||||
if is_initiator and self.initiator_early_data_handler:
|
|
||||||
await self.initiator_early_data_handler.received(conn, extensions)
|
|
||||||
elif not is_initiator and self.responder_early_data_handler:
|
|
||||||
await self.responder_early_data_handler.received(conn, extensions)
|
|
||||||
|
|
||||||
|
|
||||||
class PatternXX(BasePattern):
|
class PatternXX(BasePattern):
|
||||||
@ -130,17 +89,16 @@ class PatternXX(BasePattern):
|
|||||||
local_peer: ID,
|
local_peer: ID,
|
||||||
libp2p_privkey: PrivateKey,
|
libp2p_privkey: PrivateKey,
|
||||||
noise_static_key: PrivateKey,
|
noise_static_key: PrivateKey,
|
||||||
initiator_early_data_handler: EarlyDataHandler | None,
|
early_data: bytes | None = None,
|
||||||
responder_early_data_handler: EarlyDataHandler | None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.protocol_name = b"Noise_XX_25519_ChaChaPoly_SHA256"
|
self.protocol_name = b"Noise_XX_25519_ChaChaPoly_SHA256"
|
||||||
self.local_peer = local_peer
|
self.local_peer = local_peer
|
||||||
self.libp2p_privkey = libp2p_privkey
|
self.libp2p_privkey = libp2p_privkey
|
||||||
self.noise_static_key = noise_static_key
|
self.noise_static_key = noise_static_key
|
||||||
self.initiator_early_data_handler = initiator_early_data_handler
|
self.early_data = early_data
|
||||||
self.responder_early_data_handler = responder_early_data_handler
|
|
||||||
|
|
||||||
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
|
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
|
||||||
|
logger.debug(f"Noise XX handshake_inbound started for peer {self.local_peer}")
|
||||||
noise_state = self.create_noise_state()
|
noise_state = self.create_noise_state()
|
||||||
noise_state.set_as_responder()
|
noise_state.set_as_responder()
|
||||||
noise_state.start_handshake()
|
noise_state.start_handshake()
|
||||||
@ -152,23 +110,25 @@ class PatternXX(BasePattern):
|
|||||||
|
|
||||||
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
||||||
|
|
||||||
# 1. Consume msg#1 (just empty bytes)
|
# Consume msg#1.
|
||||||
|
logger.debug("Noise XX handshake_inbound: reading msg#1")
|
||||||
await read_writer.read_msg()
|
await read_writer.read_msg()
|
||||||
|
logger.debug("Noise XX handshake_inbound: read msg#1 successfully")
|
||||||
|
|
||||||
# 2. Send msg#2 with our payload INCLUDING EARLY DATA
|
# Send msg#2, which should include our handshake payload.
|
||||||
our_payload = await self.make_handshake_payload(
|
logger.debug("Noise XX handshake_inbound: preparing msg#2")
|
||||||
conn,
|
our_payload = self.make_handshake_payload()
|
||||||
self.local_peer, # We send our own peer ID in responder role
|
|
||||||
is_initiator=False,
|
|
||||||
)
|
|
||||||
msg_2 = our_payload.serialize()
|
msg_2 = our_payload.serialize()
|
||||||
|
logger.debug(f"Noise XX handshake_inbound: sending msg#2 ({len(msg_2)} bytes)")
|
||||||
await read_writer.write_msg(msg_2)
|
await read_writer.write_msg(msg_2)
|
||||||
|
logger.debug("Noise XX handshake_inbound: sent msg#2 successfully")
|
||||||
|
|
||||||
# 3. Receive msg#3
|
# Receive and consume msg#3.
|
||||||
|
logger.debug("Noise XX handshake_inbound: reading msg#3")
|
||||||
msg_3 = await read_writer.read_msg()
|
msg_3 = await read_writer.read_msg()
|
||||||
|
logger.debug(f"Noise XX handshake_inbound: read msg#3 ({len(msg_3)} bytes)")
|
||||||
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)
|
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)
|
||||||
|
|
||||||
# Extract remote pubkey from noise handshake state
|
|
||||||
if handshake_state.rs is None:
|
if handshake_state.rs is None:
|
||||||
raise NoiseStateError(
|
raise NoiseStateError(
|
||||||
"something is wrong in the underlying noise `handshake_state`: "
|
"something is wrong in the underlying noise `handshake_state`: "
|
||||||
@ -177,31 +137,14 @@ class PatternXX(BasePattern):
|
|||||||
)
|
)
|
||||||
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
|
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
|
||||||
|
|
||||||
# 4. Verify signature (unchanged)
|
|
||||||
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
|
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
|
||||||
raise InvalidSignature
|
raise InvalidSignature
|
||||||
|
|
||||||
# NEW: Process early data from msg#3 AFTER signature verification
|
|
||||||
await self.handle_received_payload(
|
|
||||||
conn, peer_handshake_payload, is_initiator=False
|
|
||||||
)
|
|
||||||
|
|
||||||
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
|
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
|
||||||
|
|
||||||
if not noise_state.handshake_finished:
|
if not noise_state.handshake_finished:
|
||||||
raise HandshakeHasNotFinished(
|
raise HandshakeHasNotFinished(
|
||||||
"handshake is done but it is not marked as finished in `noise_state`"
|
"handshake is done but it is not marked as finished in `noise_state`"
|
||||||
)
|
)
|
||||||
|
|
||||||
# NEW: Get negotiated muxer for connection state
|
|
||||||
# negotiated_muxer = None
|
|
||||||
if self.responder_early_data_handler and hasattr(
|
|
||||||
self.responder_early_data_handler, "match_muxers"
|
|
||||||
):
|
|
||||||
# negotiated_muxer =
|
|
||||||
# self.responder_early_data_handler.match_muxers(is_initiator=False)
|
|
||||||
pass
|
|
||||||
|
|
||||||
transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
|
transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
|
||||||
return SecureSession(
|
return SecureSession(
|
||||||
local_peer=self.local_peer,
|
local_peer=self.local_peer,
|
||||||
@ -210,13 +153,12 @@ class PatternXX(BasePattern):
|
|||||||
remote_permanent_pubkey=remote_pubkey,
|
remote_permanent_pubkey=remote_pubkey,
|
||||||
is_initiator=False,
|
is_initiator=False,
|
||||||
conn=transport_read_writer,
|
conn=transport_read_writer,
|
||||||
# NOTE: negotiated_muxer would need to be added to SecureSession constructor
|
|
||||||
# For now, store it in connection metadata or similar
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def handshake_outbound(
|
async def handshake_outbound(
|
||||||
self, conn: IRawConnection, remote_peer: ID
|
self, conn: IRawConnection, remote_peer: ID
|
||||||
) -> ISecureConn:
|
) -> ISecureConn:
|
||||||
|
logger.debug(f"Noise XX handshake_outbound started to peer {remote_peer}")
|
||||||
noise_state = self.create_noise_state()
|
noise_state = self.create_noise_state()
|
||||||
|
|
||||||
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
||||||
@ -228,27 +170,47 @@ class PatternXX(BasePattern):
|
|||||||
if handshake_state is None:
|
if handshake_state is None:
|
||||||
raise NoiseStateError("Handshake state is not initialized")
|
raise NoiseStateError("Handshake state is not initialized")
|
||||||
|
|
||||||
# 1. Send msg#1 (empty) - no early data possible in XX pattern
|
# Send msg#1, which is *not* encrypted.
|
||||||
|
logger.debug("Noise XX handshake_outbound: sending msg#1")
|
||||||
msg_1 = b""
|
msg_1 = b""
|
||||||
await read_writer.write_msg(msg_1)
|
await read_writer.write_msg(msg_1)
|
||||||
|
logger.debug("Noise XX handshake_outbound: sent msg#1 successfully")
|
||||||
|
|
||||||
# 2. Read msg#2 from responder
|
# Read msg#2 from the remote, which contains the public key of the peer.
|
||||||
|
logger.debug("Noise XX handshake_outbound: reading msg#2")
|
||||||
msg_2 = await read_writer.read_msg()
|
msg_2 = await read_writer.read_msg()
|
||||||
|
logger.debug(f"Noise XX handshake_outbound: read msg#2 ({len(msg_2)} bytes)")
|
||||||
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)
|
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)
|
||||||
|
|
||||||
# Extract remote pubkey from noise handshake state
|
|
||||||
if handshake_state.rs is None:
|
if handshake_state.rs is None:
|
||||||
raise NoiseStateError(
|
raise NoiseStateError(
|
||||||
"something is wrong in the underlying noise `handshake_state`: "
|
"something is wrong in the underlying noise `handshake_state`: "
|
||||||
"we received and consumed msg#2, which should have included the "
|
"we received and consumed msg#3, which should have included the "
|
||||||
"remote static public key, but it is not present in the handshake_state"
|
"remote static public key, but it is not present in the handshake_state"
|
||||||
)
|
)
|
||||||
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
|
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
|
||||||
|
|
||||||
# Verify signature BEFORE processing early data (security)
|
logger.debug(
|
||||||
|
f"Noise XX handshake_outbound: verifying signature for peer {remote_peer}"
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"Noise XX handshake_outbound: remote_pubkey type: {type(remote_pubkey)}"
|
||||||
|
)
|
||||||
|
id_pubkey_repr = peer_handshake_payload.id_pubkey.to_bytes().hex()
|
||||||
|
logger.debug(
|
||||||
|
f"Noise XX handshake_outbound: peer_handshake_payload.id_pubkey: "
|
||||||
|
f"{id_pubkey_repr}"
|
||||||
|
)
|
||||||
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
|
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
|
||||||
|
logger.error(
|
||||||
|
f"Noise XX handshake_outbound: signature verification failed for peer "
|
||||||
|
f"{remote_peer}"
|
||||||
|
)
|
||||||
raise InvalidSignature
|
raise InvalidSignature
|
||||||
|
logger.debug(
|
||||||
|
f"Noise XX handshake_outbound: signature verification successful for peer "
|
||||||
|
f"{remote_peer}"
|
||||||
|
)
|
||||||
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
|
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
|
||||||
if remote_peer_id_from_pubkey != remote_peer:
|
if remote_peer_id_from_pubkey != remote_peer:
|
||||||
raise PeerIDMismatchesPubkey(
|
raise PeerIDMismatchesPubkey(
|
||||||
@ -257,15 +219,8 @@ class PatternXX(BasePattern):
|
|||||||
f"remote_peer_id_from_pubkey={remote_peer_id_from_pubkey}"
|
f"remote_peer_id_from_pubkey={remote_peer_id_from_pubkey}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# NEW: Process early data from msg#2 AFTER verification
|
# Send msg#3, which includes our encrypted payload and our noise static key.
|
||||||
await self.handle_received_payload(
|
our_payload = self.make_handshake_payload()
|
||||||
conn, peer_handshake_payload, is_initiator=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Send msg#3 with our payload INCLUDING EARLY DATA
|
|
||||||
our_payload = await self.make_handshake_payload(
|
|
||||||
conn, remote_peer, is_initiator=True
|
|
||||||
)
|
|
||||||
msg_3 = our_payload.serialize()
|
msg_3 = our_payload.serialize()
|
||||||
await read_writer.write_msg(msg_3)
|
await read_writer.write_msg(msg_3)
|
||||||
|
|
||||||
@ -273,16 +228,6 @@ class PatternXX(BasePattern):
|
|||||||
raise HandshakeHasNotFinished(
|
raise HandshakeHasNotFinished(
|
||||||
"handshake is done but it is not marked as finished in `noise_state`"
|
"handshake is done but it is not marked as finished in `noise_state`"
|
||||||
)
|
)
|
||||||
|
|
||||||
# NEW: Get negotiated muxer
|
|
||||||
# negotiated_muxer = None
|
|
||||||
if self.initiator_early_data_handler and hasattr(
|
|
||||||
self.initiator_early_data_handler, "match_muxers"
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
# negotiated_muxer =
|
|
||||||
# self.initiator_early_data_handler.match_muxers(is_initiator=True)
|
|
||||||
|
|
||||||
transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
|
transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
|
||||||
return SecureSession(
|
return SecureSession(
|
||||||
local_peer=self.local_peer,
|
local_peer=self.local_peer,
|
||||||
@ -291,8 +236,6 @@ class PatternXX(BasePattern):
|
|||||||
remote_permanent_pubkey=remote_pubkey,
|
remote_permanent_pubkey=remote_pubkey,
|
||||||
is_initiator=True,
|
is_initiator=True,
|
||||||
conn=transport_read_writer,
|
conn=transport_read_writer,
|
||||||
# NOTE: negotiated_muxer would need to be added to SecureSession constructor
|
|
||||||
# For now, store it in connection metadata or similar
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -1,13 +1,8 @@
|
|||||||
syntax = "proto2";
|
syntax = "proto3";
|
||||||
package pb;
|
package pb;
|
||||||
|
|
||||||
message NoiseExtensions {
|
|
||||||
repeated bytes webtransport_certhashes = 1;
|
|
||||||
repeated string stream_muxers = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
message NoiseHandshakePayload {
|
message NoiseHandshakePayload {
|
||||||
optional bytes identity_key = 1;
|
bytes identity_key = 1;
|
||||||
optional bytes identity_sig = 2;
|
bytes identity_sig = 2;
|
||||||
optional bytes data = 3;
|
bytes data = 3;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -13,15 +13,13 @@ _sym_db = _symbol_database.Default()
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$libp2p/security/noise/pb/noise.proto\x12\x02pb\"I\n\x0fNoiseExtensions\x12\x1f\n\x17webtransport_certhashes\x18\x01 \x03(\x0c\x12\x15\n\rstream_muxers\x18\x02 \x03(\t\"Q\n\x15NoiseHandshakePayload\x12\x14\n\x0cidentity_key\x18\x01 \x01(\x0c\x12\x14\n\x0cidentity_sig\x18\x02 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c')
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$libp2p/security/noise/pb/noise.proto\x12\x02pb\"Q\n\x15NoiseHandshakePayload\x12\x14\n\x0cidentity_key\x18\x01 \x01(\x0c\x12\x14\n\x0cidentity_sig\x18\x02 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x62\x06proto3')
|
||||||
|
|
||||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.security.noise.pb.noise_pb2', globals())
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.security.noise.pb.noise_pb2', globals())
|
||||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||||
|
|
||||||
DESCRIPTOR._options = None
|
DESCRIPTOR._options = None
|
||||||
_NOISEEXTENSIONS._serialized_start=44
|
_NOISEHANDSHAKEPAYLOAD._serialized_start=44
|
||||||
_NOISEEXTENSIONS._serialized_end=117
|
_NOISEHANDSHAKEPAYLOAD._serialized_end=125
|
||||||
_NOISEHANDSHAKEPAYLOAD._serialized_start=119
|
|
||||||
_NOISEHANDSHAKEPAYLOAD._serialized_end=200
|
|
||||||
# @@protoc_insertion_point(module_scope)
|
# @@protoc_insertion_point(module_scope)
|
||||||
|
|||||||
@ -4,34 +4,12 @@ isort:skip_file
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
import collections.abc
|
|
||||||
import google.protobuf.descriptor
|
import google.protobuf.descriptor
|
||||||
import google.protobuf.internal.containers
|
|
||||||
import google.protobuf.message
|
import google.protobuf.message
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||||
|
|
||||||
@typing.final
|
|
||||||
class NoiseExtensions(google.protobuf.message.Message):
|
|
||||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
|
||||||
|
|
||||||
WEBTRANSPORT_CERTHASHES_FIELD_NUMBER: builtins.int
|
|
||||||
STREAM_MUXERS_FIELD_NUMBER: builtins.int
|
|
||||||
@property
|
|
||||||
def webtransport_certhashes(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
|
|
||||||
@property
|
|
||||||
def stream_muxers(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
webtransport_certhashes: collections.abc.Iterable[builtins.bytes] | None = ...,
|
|
||||||
stream_muxers: collections.abc.Iterable[builtins.str] | None = ...,
|
|
||||||
) -> None: ...
|
|
||||||
def ClearField(self, field_name: typing.Literal["stream_muxers", b"stream_muxers", "webtransport_certhashes", b"webtransport_certhashes"]) -> None: ...
|
|
||||||
|
|
||||||
global___NoiseExtensions = NoiseExtensions
|
|
||||||
|
|
||||||
@typing.final
|
@typing.final
|
||||||
class NoiseHandshakePayload(google.protobuf.message.Message):
|
class NoiseHandshakePayload(google.protobuf.message.Message):
|
||||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||||
@ -45,11 +23,10 @@ class NoiseHandshakePayload(google.protobuf.message.Message):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
identity_key: builtins.bytes | None = ...,
|
identity_key: builtins.bytes = ...,
|
||||||
identity_sig: builtins.bytes | None = ...,
|
identity_sig: builtins.bytes = ...,
|
||||||
data: builtins.bytes | None = ...,
|
data: builtins.bytes = ...,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
def HasField(self, field_name: typing.Literal["data", b"data", "identity_key", b"identity_key", "identity_sig", b"identity_sig"]) -> builtins.bool: ...
|
|
||||||
def ClearField(self, field_name: typing.Literal["data", b"data", "identity_key", b"identity_key", "identity_sig", b"identity_sig"]) -> None: ...
|
def ClearField(self, field_name: typing.Literal["data", b"data", "identity_key", b"identity_key", "identity_sig", b"identity_sig"]) -> None: ...
|
||||||
|
|
||||||
global___NoiseHandshakePayload = NoiseHandshakePayload
|
global___NoiseHandshakePayload = NoiseHandshakePayload
|
||||||
|
|||||||
@ -14,7 +14,6 @@ from libp2p.peer.id import (
|
|||||||
ID,
|
ID,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .early_data import EarlyDataHandler, TransportEarlyDataHandler
|
|
||||||
from .patterns import (
|
from .patterns import (
|
||||||
IPattern,
|
IPattern,
|
||||||
PatternXX,
|
PatternXX,
|
||||||
@ -27,40 +26,35 @@ class Transport(ISecureTransport):
|
|||||||
libp2p_privkey: PrivateKey
|
libp2p_privkey: PrivateKey
|
||||||
noise_privkey: PrivateKey
|
noise_privkey: PrivateKey
|
||||||
local_peer: ID
|
local_peer: ID
|
||||||
supported_muxers: list[TProtocol]
|
early_data: bytes | None
|
||||||
initiator_early_data_handler: EarlyDataHandler | None
|
with_noise_pipes: bool
|
||||||
responder_early_data_handler: EarlyDataHandler | None
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
libp2p_keypair: KeyPair,
|
libp2p_keypair: KeyPair,
|
||||||
noise_privkey: PrivateKey,
|
noise_privkey: PrivateKey,
|
||||||
supported_muxers: list[TProtocol] | None = None,
|
early_data: bytes | None = None,
|
||||||
initiator_handler: EarlyDataHandler | None = None,
|
with_noise_pipes: bool = False,
|
||||||
responder_handler: EarlyDataHandler | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.libp2p_privkey = libp2p_keypair.private_key
|
self.libp2p_privkey = libp2p_keypair.private_key
|
||||||
self.noise_privkey = noise_privkey
|
self.noise_privkey = noise_privkey
|
||||||
self.local_peer = ID.from_pubkey(libp2p_keypair.public_key)
|
self.local_peer = ID.from_pubkey(libp2p_keypair.public_key)
|
||||||
self.supported_muxers = supported_muxers or []
|
self.early_data = early_data
|
||||||
|
self.with_noise_pipes = with_noise_pipes
|
||||||
|
|
||||||
# Create default handlers for muxer negotiation if none provided
|
if self.with_noise_pipes:
|
||||||
if initiator_handler is None and self.supported_muxers:
|
raise NotImplementedError
|
||||||
initiator_handler = TransportEarlyDataHandler(self.supported_muxers)
|
|
||||||
if responder_handler is None and self.supported_muxers:
|
|
||||||
responder_handler = TransportEarlyDataHandler(self.supported_muxers)
|
|
||||||
|
|
||||||
self.initiator_early_data_handler = initiator_handler
|
|
||||||
self.responder_early_data_handler = responder_handler
|
|
||||||
|
|
||||||
def get_pattern(self) -> IPattern:
|
def get_pattern(self) -> IPattern:
|
||||||
return PatternXX(
|
if self.with_noise_pipes:
|
||||||
self.local_peer,
|
raise NotImplementedError
|
||||||
self.libp2p_privkey,
|
else:
|
||||||
self.noise_privkey,
|
return PatternXX(
|
||||||
self.initiator_early_data_handler,
|
self.local_peer,
|
||||||
self.responder_early_data_handler,
|
self.libp2p_privkey,
|
||||||
)
|
self.noise_privkey,
|
||||||
|
self.early_data,
|
||||||
|
)
|
||||||
|
|
||||||
async def secure_inbound(self, conn: IRawConnection) -> ISecureConn:
|
async def secure_inbound(self, conn: IRawConnection) -> ISecureConn:
|
||||||
pattern = self.get_pattern()
|
pattern = self.get_pattern()
|
||||||
|
|||||||
@ -1,5 +1,3 @@
|
|||||||
from collections.abc import AsyncGenerator
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from types import (
|
from types import (
|
||||||
TracebackType,
|
TracebackType,
|
||||||
)
|
)
|
||||||
@ -15,6 +13,7 @@ from libp2p.abc import (
|
|||||||
from libp2p.stream_muxer.exceptions import (
|
from libp2p.stream_muxer.exceptions import (
|
||||||
MuxedConnUnavailable,
|
MuxedConnUnavailable,
|
||||||
)
|
)
|
||||||
|
from libp2p.stream_muxer.rw_lock import ReadWriteLock
|
||||||
|
|
||||||
from .constants import (
|
from .constants import (
|
||||||
HeaderTags,
|
HeaderTags,
|
||||||
@ -34,72 +33,6 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ReadWriteLock:
|
|
||||||
"""
|
|
||||||
A read-write lock that allows multiple concurrent readers
|
|
||||||
or one exclusive writer, implemented using Trio primitives.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._readers = 0
|
|
||||||
self._readers_lock = trio.Lock() # Protects access to _readers count
|
|
||||||
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
|
|
||||||
|
|
||||||
async def acquire_read(self) -> None:
|
|
||||||
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
|
|
||||||
try:
|
|
||||||
async with self._readers_lock:
|
|
||||||
if self._readers == 0:
|
|
||||||
await self._writer_lock.acquire()
|
|
||||||
self._readers += 1
|
|
||||||
except trio.Cancelled:
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def release_read(self) -> None:
|
|
||||||
"""Release a read lock."""
|
|
||||||
async with self._readers_lock:
|
|
||||||
if self._readers == 1:
|
|
||||||
self._writer_lock.release()
|
|
||||||
self._readers -= 1
|
|
||||||
|
|
||||||
async def acquire_write(self) -> None:
|
|
||||||
"""Acquire an exclusive write lock."""
|
|
||||||
try:
|
|
||||||
await self._writer_lock.acquire()
|
|
||||||
except trio.Cancelled:
|
|
||||||
raise
|
|
||||||
|
|
||||||
def release_write(self) -> None:
|
|
||||||
"""Release the exclusive write lock."""
|
|
||||||
self._writer_lock.release()
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def read_lock(self) -> AsyncGenerator[None, None]:
|
|
||||||
"""Context manager for acquiring and releasing a read lock safely."""
|
|
||||||
acquire = False
|
|
||||||
try:
|
|
||||||
await self.acquire_read()
|
|
||||||
acquire = True
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
if acquire:
|
|
||||||
with trio.CancelScope() as scope:
|
|
||||||
scope.shield = True
|
|
||||||
await self.release_read()
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def write_lock(self) -> AsyncGenerator[None, None]:
|
|
||||||
"""Context manager for acquiring and releasing a write lock safely."""
|
|
||||||
acquire = False
|
|
||||||
try:
|
|
||||||
await self.acquire_write()
|
|
||||||
acquire = True
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
if acquire:
|
|
||||||
self.release_write()
|
|
||||||
|
|
||||||
|
|
||||||
class MplexStream(IMuxedStream):
|
class MplexStream(IMuxedStream):
|
||||||
"""
|
"""
|
||||||
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
|
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
|
||||||
|
|||||||
@ -21,6 +21,7 @@ from libp2p.protocol_muxer.exceptions import (
|
|||||||
MultiselectError,
|
MultiselectError,
|
||||||
)
|
)
|
||||||
from libp2p.protocol_muxer.multiselect import (
|
from libp2p.protocol_muxer.multiselect import (
|
||||||
|
DEFAULT_NEGOTIATE_TIMEOUT,
|
||||||
Multiselect,
|
Multiselect,
|
||||||
)
|
)
|
||||||
from libp2p.protocol_muxer.multiselect_client import (
|
from libp2p.protocol_muxer.multiselect_client import (
|
||||||
@ -46,11 +47,17 @@ class MuxerMultistream:
|
|||||||
transports: "OrderedDict[TProtocol, TMuxerClass]"
|
transports: "OrderedDict[TProtocol, TMuxerClass]"
|
||||||
multiselect: Multiselect
|
multiselect: Multiselect
|
||||||
multiselect_client: MultiselectClient
|
multiselect_client: MultiselectClient
|
||||||
|
negotiate_timeout: int
|
||||||
|
|
||||||
def __init__(self, muxer_transports_by_protocol: TMuxerOptions) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
muxer_transports_by_protocol: TMuxerOptions,
|
||||||
|
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||||
|
) -> None:
|
||||||
self.transports = OrderedDict()
|
self.transports = OrderedDict()
|
||||||
self.multiselect = Multiselect()
|
self.multiselect = Multiselect()
|
||||||
self.multistream_client = MultiselectClient()
|
self.multistream_client = MultiselectClient()
|
||||||
|
self.negotiate_timeout = negotiate_timeout
|
||||||
for protocol, transport in muxer_transports_by_protocol.items():
|
for protocol, transport in muxer_transports_by_protocol.items():
|
||||||
self.add_transport(protocol, transport)
|
self.add_transport(protocol, transport)
|
||||||
|
|
||||||
@ -80,10 +87,12 @@ class MuxerMultistream:
|
|||||||
communicator = MultiselectCommunicator(conn)
|
communicator = MultiselectCommunicator(conn)
|
||||||
if conn.is_initiator:
|
if conn.is_initiator:
|
||||||
protocol = await self.multiselect_client.select_one_of(
|
protocol = await self.multiselect_client.select_one_of(
|
||||||
tuple(self.transports.keys()), communicator
|
tuple(self.transports.keys()), communicator, self.negotiate_timeout
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
protocol, _ = await self.multiselect.negotiate(communicator)
|
protocol, _ = await self.multiselect.negotiate(
|
||||||
|
communicator, self.negotiate_timeout
|
||||||
|
)
|
||||||
if protocol is None:
|
if protocol is None:
|
||||||
raise MultiselectError(
|
raise MultiselectError(
|
||||||
"Fail to negotiate a stream muxer protocol: no protocol selected"
|
"Fail to negotiate a stream muxer protocol: no protocol selected"
|
||||||
@ -93,7 +102,7 @@ class MuxerMultistream:
|
|||||||
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
|
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
|
||||||
communicator = MultiselectCommunicator(conn)
|
communicator = MultiselectCommunicator(conn)
|
||||||
protocol = await self.multistream_client.select_one_of(
|
protocol = await self.multistream_client.select_one_of(
|
||||||
tuple(self.transports.keys()), communicator
|
tuple(self.transports.keys()), communicator, self.negotiate_timeout
|
||||||
)
|
)
|
||||||
transport_class = self.transports[protocol]
|
transport_class = self.transports[protocol]
|
||||||
if protocol == PROTOCOL_ID:
|
if protocol == PROTOCOL_ID:
|
||||||
|
|||||||
70
libp2p/stream_muxer/rw_lock.py
Normal file
70
libp2p/stream_muxer/rw_lock.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import trio
|
||||||
|
|
||||||
|
|
||||||
|
class ReadWriteLock:
|
||||||
|
"""
|
||||||
|
A read-write lock that allows multiple concurrent readers
|
||||||
|
or one exclusive writer, implemented using Trio primitives.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._readers = 0
|
||||||
|
self._readers_lock = trio.Lock() # Protects access to _readers count
|
||||||
|
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
|
||||||
|
|
||||||
|
async def acquire_read(self) -> None:
|
||||||
|
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
|
||||||
|
try:
|
||||||
|
async with self._readers_lock:
|
||||||
|
if self._readers == 0:
|
||||||
|
await self._writer_lock.acquire()
|
||||||
|
self._readers += 1
|
||||||
|
except trio.Cancelled:
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def release_read(self) -> None:
|
||||||
|
"""Release a read lock."""
|
||||||
|
async with self._readers_lock:
|
||||||
|
if self._readers == 1:
|
||||||
|
self._writer_lock.release()
|
||||||
|
self._readers -= 1
|
||||||
|
|
||||||
|
async def acquire_write(self) -> None:
|
||||||
|
"""Acquire an exclusive write lock."""
|
||||||
|
try:
|
||||||
|
await self._writer_lock.acquire()
|
||||||
|
except trio.Cancelled:
|
||||||
|
raise
|
||||||
|
|
||||||
|
def release_write(self) -> None:
|
||||||
|
"""Release the exclusive write lock."""
|
||||||
|
self._writer_lock.release()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def read_lock(self) -> AsyncGenerator[None, None]:
|
||||||
|
"""Context manager for acquiring and releasing a read lock safely."""
|
||||||
|
acquire = False
|
||||||
|
try:
|
||||||
|
await self.acquire_read()
|
||||||
|
acquire = True
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if acquire:
|
||||||
|
with trio.CancelScope() as scope:
|
||||||
|
scope.shield = True
|
||||||
|
await self.release_read()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def write_lock(self) -> AsyncGenerator[None, None]:
|
||||||
|
"""Context manager for acquiring and releasing a write lock safely."""
|
||||||
|
acquire = False
|
||||||
|
try:
|
||||||
|
await self.acquire_write()
|
||||||
|
acquire = True
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if acquire:
|
||||||
|
self.release_write()
|
||||||
@ -44,6 +44,7 @@ from libp2p.stream_muxer.exceptions import (
|
|||||||
MuxedStreamError,
|
MuxedStreamError,
|
||||||
MuxedStreamReset,
|
MuxedStreamReset,
|
||||||
)
|
)
|
||||||
|
from libp2p.stream_muxer.rw_lock import ReadWriteLock
|
||||||
|
|
||||||
# Configure logger for this module
|
# Configure logger for this module
|
||||||
logger = logging.getLogger("libp2p.stream_muxer.yamux")
|
logger = logging.getLogger("libp2p.stream_muxer.yamux")
|
||||||
@ -80,6 +81,8 @@ class YamuxStream(IMuxedStream):
|
|||||||
self.send_window = DEFAULT_WINDOW_SIZE
|
self.send_window = DEFAULT_WINDOW_SIZE
|
||||||
self.recv_window = DEFAULT_WINDOW_SIZE
|
self.recv_window = DEFAULT_WINDOW_SIZE
|
||||||
self.window_lock = trio.Lock()
|
self.window_lock = trio.Lock()
|
||||||
|
self.rw_lock = ReadWriteLock()
|
||||||
|
self.close_lock = trio.Lock()
|
||||||
|
|
||||||
async def __aenter__(self) -> "YamuxStream":
|
async def __aenter__(self) -> "YamuxStream":
|
||||||
"""Enter the async context manager."""
|
"""Enter the async context manager."""
|
||||||
@ -95,52 +98,54 @@ class YamuxStream(IMuxedStream):
|
|||||||
await self.close()
|
await self.close()
|
||||||
|
|
||||||
async def write(self, data: bytes) -> None:
|
async def write(self, data: bytes) -> None:
|
||||||
if self.send_closed:
|
async with self.rw_lock.write_lock():
|
||||||
raise MuxedStreamError("Stream is closed for sending")
|
if self.send_closed:
|
||||||
|
raise MuxedStreamError("Stream is closed for sending")
|
||||||
|
|
||||||
# Flow control: Check if we have enough send window
|
# Flow control: Check if we have enough send window
|
||||||
total_len = len(data)
|
total_len = len(data)
|
||||||
sent = 0
|
sent = 0
|
||||||
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
|
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
|
||||||
while sent < total_len:
|
while sent < total_len:
|
||||||
# Wait for available window with timeout
|
# Wait for available window with timeout
|
||||||
timeout = False
|
timeout = False
|
||||||
async with self.window_lock:
|
async with self.window_lock:
|
||||||
if self.send_window == 0:
|
if self.send_window == 0:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Stream {self.stream_id}: Window is zero, waiting for update"
|
f"Stream {self.stream_id}: "
|
||||||
|
"Window is zero, waiting for update"
|
||||||
|
)
|
||||||
|
# Release lock and wait with timeout
|
||||||
|
self.window_lock.release()
|
||||||
|
# To avoid re-acquiring the lock immediately,
|
||||||
|
with trio.move_on_after(5.0) as cancel_scope:
|
||||||
|
while self.send_window == 0 and not self.closed:
|
||||||
|
await trio.sleep(0.01)
|
||||||
|
# If we timed out, cancel the scope
|
||||||
|
timeout = cancel_scope.cancelled_caught
|
||||||
|
# Re-acquire lock
|
||||||
|
await self.window_lock.acquire()
|
||||||
|
|
||||||
|
# If we timed out waiting for window update, raise an error
|
||||||
|
if timeout:
|
||||||
|
raise MuxedStreamError(
|
||||||
|
"Timed out waiting for window update after 5 seconds."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.closed:
|
||||||
|
raise MuxedStreamError("Stream is closed")
|
||||||
|
|
||||||
|
# Calculate how much we can send now
|
||||||
|
to_send = min(self.send_window, total_len - sent)
|
||||||
|
chunk = data[sent : sent + to_send]
|
||||||
|
self.send_window -= to_send
|
||||||
|
|
||||||
|
# Send the data
|
||||||
|
header = struct.pack(
|
||||||
|
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
|
||||||
)
|
)
|
||||||
# Release lock and wait with timeout
|
await self.conn.secured_conn.write(header + chunk)
|
||||||
self.window_lock.release()
|
sent += to_send
|
||||||
# To avoid re-acquiring the lock immediately,
|
|
||||||
with trio.move_on_after(5.0) as cancel_scope:
|
|
||||||
while self.send_window == 0 and not self.closed:
|
|
||||||
await trio.sleep(0.01)
|
|
||||||
# If we timed out, cancel the scope
|
|
||||||
timeout = cancel_scope.cancelled_caught
|
|
||||||
# Re-acquire lock
|
|
||||||
await self.window_lock.acquire()
|
|
||||||
|
|
||||||
# If we timed out waiting for window update, raise an error
|
|
||||||
if timeout:
|
|
||||||
raise MuxedStreamError(
|
|
||||||
"Timed out waiting for window update after 5 seconds."
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.closed:
|
|
||||||
raise MuxedStreamError("Stream is closed")
|
|
||||||
|
|
||||||
# Calculate how much we can send now
|
|
||||||
to_send = min(self.send_window, total_len - sent)
|
|
||||||
chunk = data[sent : sent + to_send]
|
|
||||||
self.send_window -= to_send
|
|
||||||
|
|
||||||
# Send the data
|
|
||||||
header = struct.pack(
|
|
||||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
|
|
||||||
)
|
|
||||||
await self.conn.secured_conn.write(header + chunk)
|
|
||||||
sent += to_send
|
|
||||||
|
|
||||||
async def send_window_update(self, increment: int, skip_lock: bool = False) -> None:
|
async def send_window_update(self, increment: int, skip_lock: bool = False) -> None:
|
||||||
"""
|
"""
|
||||||
@ -257,30 +262,32 @@ class YamuxStream(IMuxedStream):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
if not self.send_closed:
|
async with self.close_lock:
|
||||||
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
|
if not self.send_closed:
|
||||||
header = struct.pack(
|
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
|
||||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
|
header = struct.pack(
|
||||||
)
|
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
|
||||||
await self.conn.secured_conn.write(header)
|
)
|
||||||
self.send_closed = True
|
await self.conn.secured_conn.write(header)
|
||||||
|
self.send_closed = True
|
||||||
|
|
||||||
# Only set fully closed if both directions are closed
|
# Only set fully closed if both directions are closed
|
||||||
if self.send_closed and self.recv_closed:
|
if self.send_closed and self.recv_closed:
|
||||||
self.closed = True
|
self.closed = True
|
||||||
else:
|
else:
|
||||||
# Stream is half-closed but not fully closed
|
# Stream is half-closed but not fully closed
|
||||||
self.closed = False
|
self.closed = False
|
||||||
|
|
||||||
async def reset(self) -> None:
|
async def reset(self) -> None:
|
||||||
if not self.closed:
|
if not self.closed:
|
||||||
logger.debug(f"Resetting stream {self.stream_id}")
|
async with self.close_lock:
|
||||||
header = struct.pack(
|
logger.debug(f"Resetting stream {self.stream_id}")
|
||||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
|
header = struct.pack(
|
||||||
)
|
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
|
||||||
await self.conn.secured_conn.write(header)
|
)
|
||||||
self.closed = True
|
await self.conn.secured_conn.write(header)
|
||||||
self.reset_received = True # Mark as reset
|
self.closed = True
|
||||||
|
self.reset_received = True # Mark as reset
|
||||||
|
|
||||||
def set_deadline(self, ttl: int) -> bool:
|
def set_deadline(self, ttl: int) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -0,0 +1,57 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .tcp.tcp import TCP
|
||||||
|
from .websocket.transport import WebsocketTransport
|
||||||
|
from .transport_registry import (
|
||||||
|
TransportRegistry,
|
||||||
|
create_transport_for_multiaddr,
|
||||||
|
get_transport_registry,
|
||||||
|
register_transport,
|
||||||
|
get_supported_transport_protocols,
|
||||||
|
)
|
||||||
|
from .upgrader import TransportUpgrader
|
||||||
|
from libp2p.abc import ITransport
|
||||||
|
|
||||||
|
def create_transport(protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any) -> ITransport:
|
||||||
|
"""
|
||||||
|
Convenience function to create a transport instance.
|
||||||
|
|
||||||
|
:param protocol: The transport protocol ("tcp", "ws", "wss", or custom)
|
||||||
|
:param upgrader: Optional transport upgrader (required for WebSocket)
|
||||||
|
:param kwargs: Additional arguments for transport construction (e.g., tls_client_config, tls_server_config)
|
||||||
|
:return: Transport instance
|
||||||
|
"""
|
||||||
|
# First check if it's a built-in protocol
|
||||||
|
if protocol in ["ws", "wss"]:
|
||||||
|
if upgrader is None:
|
||||||
|
raise ValueError(f"WebSocket transport requires an upgrader")
|
||||||
|
return WebsocketTransport(
|
||||||
|
upgrader,
|
||||||
|
tls_client_config=kwargs.get("tls_client_config"),
|
||||||
|
tls_server_config=kwargs.get("tls_server_config"),
|
||||||
|
handshake_timeout=kwargs.get("handshake_timeout", 15.0)
|
||||||
|
)
|
||||||
|
elif protocol == "tcp":
|
||||||
|
return TCP()
|
||||||
|
else:
|
||||||
|
# Check if it's a custom registered transport
|
||||||
|
registry = get_transport_registry()
|
||||||
|
transport_class = registry.get_transport(protocol)
|
||||||
|
if transport_class:
|
||||||
|
transport = registry.create_transport(protocol, upgrader, **kwargs)
|
||||||
|
if transport is None:
|
||||||
|
raise ValueError(f"Failed to create transport for protocol: {protocol}")
|
||||||
|
return transport
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported transport protocol: {protocol}")
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TCP",
|
||||||
|
"WebsocketTransport",
|
||||||
|
"TransportRegistry",
|
||||||
|
"create_transport_for_multiaddr",
|
||||||
|
"create_transport",
|
||||||
|
"get_transport_registry",
|
||||||
|
"register_transport",
|
||||||
|
"get_supported_transport_protocols",
|
||||||
|
]
|
||||||
|
|||||||
0
libp2p/transport/quic/__init__.py
Normal file
0
libp2p/transport/quic/__init__.py
Normal file
345
libp2p/transport/quic/config.py
Normal file
345
libp2p/transport/quic/config.py
Normal file
@ -0,0 +1,345 @@
|
|||||||
|
"""
|
||||||
|
Configuration classes for QUIC transport.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import (
|
||||||
|
dataclass,
|
||||||
|
field,
|
||||||
|
)
|
||||||
|
import ssl
|
||||||
|
from typing import Any, Literal, TypedDict
|
||||||
|
|
||||||
|
from libp2p.custom_types import TProtocol
|
||||||
|
from libp2p.network.config import ConnectionConfig
|
||||||
|
|
||||||
|
|
||||||
|
class QUICTransportKwargs(TypedDict, total=False):
|
||||||
|
"""Type definition for kwargs accepted by new_transport function."""
|
||||||
|
|
||||||
|
# Connection settings
|
||||||
|
idle_timeout: float
|
||||||
|
max_datagram_size: int
|
||||||
|
local_port: int | None
|
||||||
|
|
||||||
|
# Protocol version support
|
||||||
|
enable_draft29: bool
|
||||||
|
enable_v1: bool
|
||||||
|
|
||||||
|
# TLS settings
|
||||||
|
verify_mode: ssl.VerifyMode
|
||||||
|
alpn_protocols: list[str]
|
||||||
|
|
||||||
|
# Performance settings
|
||||||
|
max_concurrent_streams: int
|
||||||
|
connection_window: int
|
||||||
|
stream_window: int
|
||||||
|
|
||||||
|
# Logging and debugging
|
||||||
|
enable_qlog: bool
|
||||||
|
qlog_dir: str | None
|
||||||
|
|
||||||
|
# Connection management
|
||||||
|
max_connections: int
|
||||||
|
connection_timeout: float
|
||||||
|
|
||||||
|
# Protocol identifiers
|
||||||
|
PROTOCOL_QUIC_V1: TProtocol
|
||||||
|
PROTOCOL_QUIC_DRAFT29: TProtocol
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class QUICTransportConfig(ConnectionConfig):
|
||||||
|
"""Configuration for QUIC transport."""
|
||||||
|
|
||||||
|
# Connection settings
|
||||||
|
idle_timeout: float = 30.0 # Seconds before an idle connection is closed.
|
||||||
|
max_datagram_size: int = (
|
||||||
|
1200 # Maximum size of UDP datagrams to avoid IP fragmentation.
|
||||||
|
)
|
||||||
|
local_port: int | None = (
|
||||||
|
None # Local port to bind to. If None, a random port is chosen.
|
||||||
|
)
|
||||||
|
|
||||||
|
# Protocol version support
|
||||||
|
enable_draft29: bool = True # Enable QUIC draft-29 for compatibility
|
||||||
|
enable_v1: bool = True # Enable QUIC v1 (RFC 9000)
|
||||||
|
|
||||||
|
# TLS settings
|
||||||
|
verify_mode: ssl.VerifyMode = ssl.CERT_NONE
|
||||||
|
alpn_protocols: list[str] = field(default_factory=lambda: ["libp2p"])
|
||||||
|
|
||||||
|
# Performance settings
|
||||||
|
max_concurrent_streams: int = 100 # Maximum concurrent streams per connection
|
||||||
|
connection_window: int = 1024 * 1024 # Connection flow control window
|
||||||
|
stream_window: int = 64 * 1024 # Stream flow control window
|
||||||
|
|
||||||
|
# Logging and debugging
|
||||||
|
enable_qlog: bool = False # Enable QUIC logging
|
||||||
|
qlog_dir: str | None = None # Directory for QUIC logs
|
||||||
|
|
||||||
|
# Connection management
|
||||||
|
max_connections: int = 1000 # Maximum number of connections
|
||||||
|
connection_timeout: float = 10.0 # Connection establishment timeout
|
||||||
|
|
||||||
|
MAX_CONCURRENT_STREAMS: int = 1000
|
||||||
|
"""Maximum number of concurrent streams per connection."""
|
||||||
|
|
||||||
|
MAX_INCOMING_STREAMS: int = 1000
|
||||||
|
"""Maximum number of incoming streams per connection."""
|
||||||
|
|
||||||
|
CONNECTION_HANDSHAKE_TIMEOUT: float = 60.0
|
||||||
|
"""Timeout for connection handshake (seconds)."""
|
||||||
|
|
||||||
|
MAX_OUTGOING_STREAMS: int = 1000
|
||||||
|
"""Maximum number of outgoing streams per connection."""
|
||||||
|
|
||||||
|
CONNECTION_CLOSE_TIMEOUT: int = 10
|
||||||
|
"""Timeout for opening new connection (seconds)."""
|
||||||
|
|
||||||
|
# Stream timeouts
|
||||||
|
STREAM_OPEN_TIMEOUT: float = 5.0
|
||||||
|
"""Timeout for opening new streams (seconds)."""
|
||||||
|
|
||||||
|
STREAM_ACCEPT_TIMEOUT: float = 30.0
|
||||||
|
"""Timeout for accepting incoming streams (seconds)."""
|
||||||
|
|
||||||
|
STREAM_READ_TIMEOUT: float = 30.0
|
||||||
|
"""Default timeout for stream read operations (seconds)."""
|
||||||
|
|
||||||
|
STREAM_WRITE_TIMEOUT: float = 30.0
|
||||||
|
"""Default timeout for stream write operations (seconds)."""
|
||||||
|
|
||||||
|
STREAM_CLOSE_TIMEOUT: float = 10.0
|
||||||
|
"""Timeout for graceful stream close (seconds)."""
|
||||||
|
|
||||||
|
# Flow control configuration
|
||||||
|
STREAM_FLOW_CONTROL_WINDOW: int = 1024 * 1024 # 1MB
|
||||||
|
"""Per-stream flow control window size."""
|
||||||
|
|
||||||
|
CONNECTION_FLOW_CONTROL_WINDOW: int = 1536 * 1024 # 1.5MB
|
||||||
|
"""Connection-wide flow control window size."""
|
||||||
|
|
||||||
|
# Buffer management
|
||||||
|
MAX_STREAM_RECEIVE_BUFFER: int = 2 * 1024 * 1024 # 2MB
|
||||||
|
"""Maximum receive buffer size per stream."""
|
||||||
|
|
||||||
|
STREAM_RECEIVE_BUFFER_LOW_WATERMARK: int = 64 * 1024 # 64KB
|
||||||
|
"""Low watermark for stream receive buffer."""
|
||||||
|
|
||||||
|
STREAM_RECEIVE_BUFFER_HIGH_WATERMARK: int = 512 * 1024 # 512KB
|
||||||
|
"""High watermark for stream receive buffer."""
|
||||||
|
|
||||||
|
# Stream lifecycle configuration
|
||||||
|
ENABLE_STREAM_RESET_ON_ERROR: bool = True
|
||||||
|
"""Whether to automatically reset streams on errors."""
|
||||||
|
|
||||||
|
STREAM_RESET_ERROR_CODE: int = 1
|
||||||
|
"""Default error code for stream resets."""
|
||||||
|
|
||||||
|
ENABLE_STREAM_KEEP_ALIVE: bool = False
|
||||||
|
"""Whether to enable stream keep-alive mechanisms."""
|
||||||
|
|
||||||
|
STREAM_KEEP_ALIVE_INTERVAL: float = 30.0
|
||||||
|
"""Interval for stream keep-alive pings (seconds)."""
|
||||||
|
|
||||||
|
# Resource management
|
||||||
|
ENABLE_STREAM_RESOURCE_TRACKING: bool = True
|
||||||
|
"""Whether to track stream resource usage."""
|
||||||
|
|
||||||
|
STREAM_MEMORY_LIMIT_PER_STREAM: int = 2 * 1024 * 1024 # 2MB
|
||||||
|
"""Memory limit per individual stream."""
|
||||||
|
|
||||||
|
STREAM_MEMORY_LIMIT_PER_CONNECTION: int = 100 * 1024 * 1024 # 100MB
|
||||||
|
"""Total memory limit for all streams per connection."""
|
||||||
|
|
||||||
|
# Concurrency and performance
|
||||||
|
ENABLE_STREAM_BATCHING: bool = True
|
||||||
|
"""Whether to batch multiple stream operations."""
|
||||||
|
|
||||||
|
STREAM_BATCH_SIZE: int = 10
|
||||||
|
"""Number of streams to process in a batch."""
|
||||||
|
|
||||||
|
STREAM_PROCESSING_CONCURRENCY: int = 100
|
||||||
|
"""Maximum concurrent stream processing tasks."""
|
||||||
|
|
||||||
|
# Debugging and monitoring
|
||||||
|
ENABLE_STREAM_METRICS: bool = True
|
||||||
|
"""Whether to collect stream metrics."""
|
||||||
|
|
||||||
|
ENABLE_STREAM_TIMELINE_TRACKING: bool = True
|
||||||
|
"""Whether to track stream lifecycle timelines."""
|
||||||
|
|
||||||
|
STREAM_METRICS_COLLECTION_INTERVAL: float = 60.0
|
||||||
|
"""Interval for collecting stream metrics (seconds)."""
|
||||||
|
|
||||||
|
# Error handling configuration
|
||||||
|
STREAM_ERROR_RETRY_ATTEMPTS: int = 3
|
||||||
|
"""Number of retry attempts for recoverable stream errors."""
|
||||||
|
|
||||||
|
STREAM_ERROR_RETRY_DELAY: float = 1.0
|
||||||
|
"""Initial delay between stream error retries (seconds)."""
|
||||||
|
|
||||||
|
STREAM_ERROR_RETRY_BACKOFF_FACTOR: float = 2.0
|
||||||
|
"""Backoff factor for stream error retries."""
|
||||||
|
|
||||||
|
# Protocol identifiers matching go-libp2p
|
||||||
|
PROTOCOL_QUIC_V1: TProtocol = TProtocol("quic-v1") # RFC 9000
|
||||||
|
PROTOCOL_QUIC_DRAFT29: TProtocol = TProtocol("quic") # draft-29
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""Validate configuration after initialization."""
|
||||||
|
if not (self.enable_draft29 or self.enable_v1):
|
||||||
|
raise ValueError("At least one QUIC version must be enabled")
|
||||||
|
|
||||||
|
if self.idle_timeout <= 0:
|
||||||
|
raise ValueError("Idle timeout must be positive")
|
||||||
|
|
||||||
|
if self.max_datagram_size < 1200:
|
||||||
|
raise ValueError("Max datagram size must be at least 1200 bytes")
|
||||||
|
|
||||||
|
# Validate timeouts
|
||||||
|
timeout_fields = [
|
||||||
|
"STREAM_OPEN_TIMEOUT",
|
||||||
|
"STREAM_ACCEPT_TIMEOUT",
|
||||||
|
"STREAM_READ_TIMEOUT",
|
||||||
|
"STREAM_WRITE_TIMEOUT",
|
||||||
|
"STREAM_CLOSE_TIMEOUT",
|
||||||
|
]
|
||||||
|
for timeout_field in timeout_fields:
|
||||||
|
if getattr(self, timeout_field) <= 0:
|
||||||
|
raise ValueError(f"{timeout_field} must be positive")
|
||||||
|
|
||||||
|
# Validate flow control windows
|
||||||
|
if self.STREAM_FLOW_CONTROL_WINDOW <= 0:
|
||||||
|
raise ValueError("STREAM_FLOW_CONTROL_WINDOW must be positive")
|
||||||
|
|
||||||
|
if self.CONNECTION_FLOW_CONTROL_WINDOW < self.STREAM_FLOW_CONTROL_WINDOW:
|
||||||
|
raise ValueError(
|
||||||
|
"CONNECTION_FLOW_CONTROL_WINDOW must be >= STREAM_FLOW_CONTROL_WINDOW"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate buffer sizes
|
||||||
|
if self.MAX_STREAM_RECEIVE_BUFFER <= 0:
|
||||||
|
raise ValueError("MAX_STREAM_RECEIVE_BUFFER must be positive")
|
||||||
|
|
||||||
|
if self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK > self.MAX_STREAM_RECEIVE_BUFFER:
|
||||||
|
raise ValueError(
|
||||||
|
"STREAM_RECEIVE_BUFFER_HIGH_WATERMARK cannot".__add__(
|
||||||
|
"exceed MAX_STREAM_RECEIVE_BUFFER"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.STREAM_RECEIVE_BUFFER_LOW_WATERMARK
|
||||||
|
>= self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"STREAM_RECEIVE_BUFFER_LOW_WATERMARK must be < HIGH_WATERMARK"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate memory limits
|
||||||
|
if self.STREAM_MEMORY_LIMIT_PER_STREAM <= 0:
|
||||||
|
raise ValueError("STREAM_MEMORY_LIMIT_PER_STREAM must be positive")
|
||||||
|
|
||||||
|
if self.STREAM_MEMORY_LIMIT_PER_CONNECTION <= 0:
|
||||||
|
raise ValueError("STREAM_MEMORY_LIMIT_PER_CONNECTION must be positive")
|
||||||
|
|
||||||
|
expected_stream_memory = (
|
||||||
|
self.MAX_CONCURRENT_STREAMS * self.STREAM_MEMORY_LIMIT_PER_STREAM
|
||||||
|
)
|
||||||
|
if expected_stream_memory > self.STREAM_MEMORY_LIMIT_PER_CONNECTION * 2:
|
||||||
|
# Allow some headroom, but warn if configuration seems inconsistent
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.warning(
|
||||||
|
"Stream memory configuration may be inconsistent: "
|
||||||
|
f"{self.MAX_CONCURRENT_STREAMS} streams ×"
|
||||||
|
"{self.STREAM_MEMORY_LIMIT_PER_STREAM} bytes "
|
||||||
|
"could exceed connection limit of"
|
||||||
|
f"{self.STREAM_MEMORY_LIMIT_PER_CONNECTION} bytes"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_stream_config_dict(self) -> dict[str, Any]:
|
||||||
|
"""Get stream-specific configuration as dictionary."""
|
||||||
|
stream_config = {}
|
||||||
|
for attr_name in dir(self):
|
||||||
|
if attr_name.startswith(
|
||||||
|
("STREAM_", "MAX_", "ENABLE_STREAM", "CONNECTION_FLOW")
|
||||||
|
):
|
||||||
|
stream_config[attr_name.lower()] = getattr(self, attr_name)
|
||||||
|
return stream_config
|
||||||
|
|
||||||
|
|
||||||
|
# Additional configuration classes for specific stream features
|
||||||
|
|
||||||
|
|
||||||
|
class QUICStreamFlowControlConfig:
|
||||||
|
"""Configuration for QUIC stream flow control."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
initial_window_size: int = 512 * 1024,
|
||||||
|
max_window_size: int = 2 * 1024 * 1024,
|
||||||
|
window_update_threshold: float = 0.5,
|
||||||
|
enable_auto_tuning: bool = True,
|
||||||
|
):
|
||||||
|
self.initial_window_size = initial_window_size
|
||||||
|
self.max_window_size = max_window_size
|
||||||
|
self.window_update_threshold = window_update_threshold
|
||||||
|
self.enable_auto_tuning = enable_auto_tuning
|
||||||
|
|
||||||
|
|
||||||
|
def create_stream_config_for_use_case(
|
||||||
|
use_case: Literal[
|
||||||
|
"high_throughput", "low_latency", "many_streams", "memory_constrained"
|
||||||
|
],
|
||||||
|
) -> QUICTransportConfig:
|
||||||
|
"""
|
||||||
|
Create optimized stream configuration for specific use cases.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_case: One of "high_throughput", "low_latency", "many_streams","
|
||||||
|
"memory_constrained"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optimized QUICTransportConfig
|
||||||
|
|
||||||
|
"""
|
||||||
|
base_config = QUICTransportConfig()
|
||||||
|
|
||||||
|
if use_case == "high_throughput":
|
||||||
|
# Optimize for high throughput
|
||||||
|
base_config.STREAM_FLOW_CONTROL_WINDOW = 2 * 1024 * 1024 # 2MB
|
||||||
|
base_config.CONNECTION_FLOW_CONTROL_WINDOW = 10 * 1024 * 1024 # 10MB
|
||||||
|
base_config.MAX_STREAM_RECEIVE_BUFFER = 4 * 1024 * 1024 # 4MB
|
||||||
|
base_config.STREAM_PROCESSING_CONCURRENCY = 200
|
||||||
|
|
||||||
|
elif use_case == "low_latency":
|
||||||
|
# Optimize for low latency
|
||||||
|
base_config.STREAM_OPEN_TIMEOUT = 1.0
|
||||||
|
base_config.STREAM_READ_TIMEOUT = 5.0
|
||||||
|
base_config.STREAM_WRITE_TIMEOUT = 5.0
|
||||||
|
base_config.ENABLE_STREAM_BATCHING = False
|
||||||
|
base_config.STREAM_BATCH_SIZE = 1
|
||||||
|
|
||||||
|
elif use_case == "many_streams":
|
||||||
|
# Optimize for many concurrent streams
|
||||||
|
base_config.MAX_CONCURRENT_STREAMS = 5000
|
||||||
|
base_config.STREAM_FLOW_CONTROL_WINDOW = 128 * 1024 # 128KB
|
||||||
|
base_config.MAX_STREAM_RECEIVE_BUFFER = 256 * 1024 # 256KB
|
||||||
|
base_config.STREAM_PROCESSING_CONCURRENCY = 500
|
||||||
|
|
||||||
|
elif use_case == "memory_constrained":
|
||||||
|
# Optimize for low memory usage
|
||||||
|
base_config.MAX_CONCURRENT_STREAMS = 100
|
||||||
|
base_config.STREAM_FLOW_CONTROL_WINDOW = 64 * 1024 # 64KB
|
||||||
|
base_config.CONNECTION_FLOW_CONTROL_WINDOW = 256 * 1024 # 256KB
|
||||||
|
base_config.MAX_STREAM_RECEIVE_BUFFER = 128 * 1024 # 128KB
|
||||||
|
base_config.STREAM_MEMORY_LIMIT_PER_STREAM = 512 * 1024 # 512KB
|
||||||
|
base_config.STREAM_PROCESSING_CONCURRENCY = 50
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown use case: {use_case}")
|
||||||
|
|
||||||
|
return base_config
|
||||||
1489
libp2p/transport/quic/connection.py
Normal file
1489
libp2p/transport/quic/connection.py
Normal file
File diff suppressed because it is too large
Load Diff
391
libp2p/transport/quic/exceptions.py
Normal file
391
libp2p/transport/quic/exceptions.py
Normal file
@ -0,0 +1,391 @@
|
|||||||
|
"""
|
||||||
|
QUIC Transport exceptions
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
|
||||||
|
class QUICError(Exception):
|
||||||
|
"""Base exception for all QUIC transport errors."""
|
||||||
|
|
||||||
|
def __init__(self, message: str, error_code: int | None = None):
|
||||||
|
super().__init__(message)
|
||||||
|
self.error_code = error_code
|
||||||
|
|
||||||
|
|
||||||
|
# Transport-level exceptions
|
||||||
|
|
||||||
|
|
||||||
|
class QUICTransportError(QUICError):
|
||||||
|
"""Base exception for QUIC transport operations."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICDialError(QUICTransportError):
|
||||||
|
"""Error occurred during QUIC connection establishment."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICListenError(QUICTransportError):
|
||||||
|
"""Error occurred during QUIC listener operations."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICSecurityError(QUICTransportError):
|
||||||
|
"""Error related to QUIC security/TLS operations."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Connection-level exceptions
|
||||||
|
|
||||||
|
|
||||||
|
class QUICConnectionError(QUICError):
|
||||||
|
"""Base exception for QUIC connection operations."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICConnectionClosedError(QUICConnectionError):
|
||||||
|
"""QUIC connection has been closed."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICConnectionTimeoutError(QUICConnectionError):
|
||||||
|
"""QUIC connection operation timed out."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICHandshakeError(QUICConnectionError):
|
||||||
|
"""Error during QUIC handshake process."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICPeerVerificationError(QUICConnectionError):
|
||||||
|
"""Error verifying peer identity during handshake."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Stream-level exceptions
|
||||||
|
|
||||||
|
|
||||||
|
class QUICStreamError(QUICError):
|
||||||
|
"""Base exception for QUIC stream operations."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
stream_id: str | None = None,
|
||||||
|
error_code: int | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(message, error_code)
|
||||||
|
self.stream_id = stream_id
|
||||||
|
|
||||||
|
|
||||||
|
class QUICStreamClosedError(QUICStreamError):
|
||||||
|
"""Stream is closed and cannot be used for I/O operations."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICStreamResetError(QUICStreamError):
|
||||||
|
"""Stream was reset by local or remote peer."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
stream_id: str | None = None,
|
||||||
|
error_code: int | None = None,
|
||||||
|
reset_by_peer: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(message, stream_id, error_code)
|
||||||
|
self.reset_by_peer = reset_by_peer
|
||||||
|
|
||||||
|
|
||||||
|
class QUICStreamTimeoutError(QUICStreamError):
|
||||||
|
"""Stream operation timed out."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICStreamBackpressureError(QUICStreamError):
|
||||||
|
"""Stream write blocked due to flow control."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICStreamLimitError(QUICStreamError):
|
||||||
|
"""Stream limit reached (too many concurrent streams)."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICStreamStateError(QUICStreamError):
|
||||||
|
"""Invalid operation for current stream state."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
stream_id: str | None = None,
|
||||||
|
current_state: str | None = None,
|
||||||
|
attempted_operation: str | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(message, stream_id)
|
||||||
|
self.current_state = current_state
|
||||||
|
self.attempted_operation = attempted_operation
|
||||||
|
|
||||||
|
|
||||||
|
# Flow control exceptions
|
||||||
|
|
||||||
|
|
||||||
|
class QUICFlowControlError(QUICError):
|
||||||
|
"""Base exception for flow control related errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICFlowControlViolationError(QUICFlowControlError):
|
||||||
|
"""Flow control limits were violated."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICFlowControlDeadlockError(QUICFlowControlError):
|
||||||
|
"""Flow control deadlock detected."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Resource management exceptions
|
||||||
|
|
||||||
|
|
||||||
|
class QUICResourceError(QUICError):
|
||||||
|
"""Base exception for resource management errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICMemoryLimitError(QUICResourceError):
|
||||||
|
"""Memory limit exceeded."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICConnectionLimitError(QUICResourceError):
|
||||||
|
"""Connection limit exceeded."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Multiaddr and addressing exceptions
|
||||||
|
|
||||||
|
|
||||||
|
class QUICAddressError(QUICError):
|
||||||
|
"""Base exception for QUIC addressing errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICInvalidMultiaddrError(QUICAddressError):
|
||||||
|
"""Invalid multiaddr format for QUIC transport."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICAddressResolutionError(QUICAddressError):
|
||||||
|
"""Failed to resolve QUIC address."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICProtocolError(QUICError):
|
||||||
|
"""Base exception for QUIC protocol errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICVersionNegotiationError(QUICProtocolError):
|
||||||
|
"""QUIC version negotiation failed."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICUnsupportedVersionError(QUICProtocolError):
|
||||||
|
"""Unsupported QUIC version."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Configuration exceptions
|
||||||
|
|
||||||
|
|
||||||
|
class QUICConfigurationError(QUICError):
|
||||||
|
"""Base exception for QUIC configuration errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICInvalidConfigError(QUICConfigurationError):
|
||||||
|
"""Invalid QUIC configuration parameters."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICCertificateError(QUICConfigurationError):
|
||||||
|
"""Error with TLS certificate configuration."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def map_quic_error_code(error_code: int) -> str:
|
||||||
|
"""
|
||||||
|
Map QUIC error codes to human-readable descriptions.
|
||||||
|
Based on RFC 9000 Transport Error Codes.
|
||||||
|
"""
|
||||||
|
error_codes = {
|
||||||
|
0x00: "NO_ERROR",
|
||||||
|
0x01: "INTERNAL_ERROR",
|
||||||
|
0x02: "CONNECTION_REFUSED",
|
||||||
|
0x03: "FLOW_CONTROL_ERROR",
|
||||||
|
0x04: "STREAM_LIMIT_ERROR",
|
||||||
|
0x05: "STREAM_STATE_ERROR",
|
||||||
|
0x06: "FINAL_SIZE_ERROR",
|
||||||
|
0x07: "FRAME_ENCODING_ERROR",
|
||||||
|
0x08: "TRANSPORT_PARAMETER_ERROR",
|
||||||
|
0x09: "CONNECTION_ID_LIMIT_ERROR",
|
||||||
|
0x0A: "PROTOCOL_VIOLATION",
|
||||||
|
0x0B: "INVALID_TOKEN",
|
||||||
|
0x0C: "APPLICATION_ERROR",
|
||||||
|
0x0D: "CRYPTO_BUFFER_EXCEEDED",
|
||||||
|
0x0E: "KEY_UPDATE_ERROR",
|
||||||
|
0x0F: "AEAD_LIMIT_REACHED",
|
||||||
|
0x10: "NO_VIABLE_PATH",
|
||||||
|
}
|
||||||
|
|
||||||
|
return error_codes.get(error_code, f"UNKNOWN_ERROR_{error_code:02X}")
|
||||||
|
|
||||||
|
|
||||||
|
def create_stream_error(
|
||||||
|
error_type: str,
|
||||||
|
message: str,
|
||||||
|
stream_id: str | None = None,
|
||||||
|
error_code: int | None = None,
|
||||||
|
) -> QUICStreamError:
|
||||||
|
"""
|
||||||
|
Factory function to create appropriate stream error based on type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_type: Type of error ("closed", "reset", "timeout", "backpressure", etc.)
|
||||||
|
message: Error message
|
||||||
|
stream_id: Stream identifier
|
||||||
|
error_code: QUIC error code
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Appropriate QUICStreamError subclass
|
||||||
|
|
||||||
|
"""
|
||||||
|
error_type = error_type.lower()
|
||||||
|
|
||||||
|
if error_type in ("closed", "close"):
|
||||||
|
return QUICStreamClosedError(message, stream_id, error_code)
|
||||||
|
elif error_type == "reset":
|
||||||
|
return QUICStreamResetError(message, stream_id, error_code)
|
||||||
|
elif error_type == "timeout":
|
||||||
|
return QUICStreamTimeoutError(message, stream_id, error_code)
|
||||||
|
elif error_type in ("backpressure", "flow_control"):
|
||||||
|
return QUICStreamBackpressureError(message, stream_id, error_code)
|
||||||
|
elif error_type in ("limit", "stream_limit"):
|
||||||
|
return QUICStreamLimitError(message, stream_id, error_code)
|
||||||
|
elif error_type == "state":
|
||||||
|
return QUICStreamStateError(message, stream_id)
|
||||||
|
else:
|
||||||
|
return QUICStreamError(message, stream_id, error_code)
|
||||||
|
|
||||||
|
|
||||||
|
def create_connection_error(
|
||||||
|
error_type: str, message: str, error_code: int | None = None
|
||||||
|
) -> QUICConnectionError:
|
||||||
|
"""
|
||||||
|
Factory function to create appropriate connection error based on type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_type: Type of error ("closed", "timeout", "handshake", etc.)
|
||||||
|
message: Error message
|
||||||
|
error_code: QUIC error code
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Appropriate QUICConnectionError subclass
|
||||||
|
|
||||||
|
"""
|
||||||
|
error_type = error_type.lower()
|
||||||
|
|
||||||
|
if error_type in ("closed", "close"):
|
||||||
|
return QUICConnectionClosedError(message, error_code)
|
||||||
|
elif error_type == "timeout":
|
||||||
|
return QUICConnectionTimeoutError(message, error_code)
|
||||||
|
elif error_type == "handshake":
|
||||||
|
return QUICHandshakeError(message, error_code)
|
||||||
|
elif error_type in ("peer_verification", "verification"):
|
||||||
|
return QUICPeerVerificationError(message, error_code)
|
||||||
|
else:
|
||||||
|
return QUICConnectionError(message, error_code)
|
||||||
|
|
||||||
|
|
||||||
|
class QUICErrorContext:
|
||||||
|
"""
|
||||||
|
Context manager for handling QUIC errors with automatic error mapping.
|
||||||
|
Useful for converting low-level aioquic errors to py-libp2p QUIC errors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, operation: str, component: str = "quic") -> None:
|
||||||
|
self.operation = operation
|
||||||
|
self.component = component
|
||||||
|
|
||||||
|
def __enter__(self) -> "QUICErrorContext":
|
||||||
|
return self
|
||||||
|
|
||||||
|
# TODO: Fix types for exc_type
|
||||||
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: Any,
|
||||||
|
) -> Literal[False]:
|
||||||
|
if exc_type is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if exc_val is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Map common aioquic exceptions to our exceptions
|
||||||
|
if "ConnectionClosed" in str(exc_type):
|
||||||
|
raise QUICConnectionClosedError(
|
||||||
|
f"Connection closed during {self.operation}: {exc_val}"
|
||||||
|
) from exc_val
|
||||||
|
elif "StreamReset" in str(exc_type):
|
||||||
|
raise QUICStreamResetError(
|
||||||
|
f"Stream reset during {self.operation}: {exc_val}"
|
||||||
|
) from exc_val
|
||||||
|
elif "timeout" in str(exc_val).lower():
|
||||||
|
if "stream" in self.component.lower():
|
||||||
|
raise QUICStreamTimeoutError(
|
||||||
|
f"Timeout during {self.operation}: {exc_val}"
|
||||||
|
) from exc_val
|
||||||
|
else:
|
||||||
|
raise QUICConnectionTimeoutError(
|
||||||
|
f"Timeout during {self.operation}: {exc_val}"
|
||||||
|
) from exc_val
|
||||||
|
elif "flow control" in str(exc_val).lower():
|
||||||
|
raise QUICStreamBackpressureError(
|
||||||
|
f"Flow control error during {self.operation}: {exc_val}"
|
||||||
|
) from exc_val
|
||||||
|
|
||||||
|
# Let other exceptions propagate
|
||||||
|
return False
|
||||||
1041
libp2p/transport/quic/listener.py
Normal file
1041
libp2p/transport/quic/listener.py
Normal file
File diff suppressed because it is too large
Load Diff
1165
libp2p/transport/quic/security.py
Normal file
1165
libp2p/transport/quic/security.py
Normal file
File diff suppressed because it is too large
Load Diff
656
libp2p/transport/quic/stream.py
Normal file
656
libp2p/transport/quic/stream.py
Normal file
@ -0,0 +1,656 @@
|
|||||||
|
"""
|
||||||
|
QUIC Stream implementation
|
||||||
|
Provides stream interface over QUIC's native multiplexing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
|
import trio
|
||||||
|
|
||||||
|
from .exceptions import (
|
||||||
|
QUICStreamBackpressureError,
|
||||||
|
QUICStreamClosedError,
|
||||||
|
QUICStreamResetError,
|
||||||
|
QUICStreamTimeoutError,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from libp2p.abc import IMuxedStream
|
||||||
|
from libp2p.custom_types import TProtocol
|
||||||
|
|
||||||
|
from .connection import QUICConnection
|
||||||
|
else:
|
||||||
|
IMuxedStream = cast(type, object)
|
||||||
|
TProtocol = cast(type, object)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamState(Enum):
|
||||||
|
"""Stream lifecycle states following libp2p patterns."""
|
||||||
|
|
||||||
|
OPEN = "open"
|
||||||
|
WRITE_CLOSED = "write_closed"
|
||||||
|
READ_CLOSED = "read_closed"
|
||||||
|
CLOSED = "closed"
|
||||||
|
RESET = "reset"
|
||||||
|
|
||||||
|
|
||||||
|
class StreamDirection(Enum):
|
||||||
|
"""Stream direction for tracking initiator."""
|
||||||
|
|
||||||
|
INBOUND = "inbound"
|
||||||
|
OUTBOUND = "outbound"
|
||||||
|
|
||||||
|
|
||||||
|
class StreamTimeline:
|
||||||
|
"""Track stream lifecycle events for debugging and monitoring."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.created_at = time.time()
|
||||||
|
self.opened_at: float | None = None
|
||||||
|
self.first_data_at: float | None = None
|
||||||
|
self.closed_at: float | None = None
|
||||||
|
self.reset_at: float | None = None
|
||||||
|
self.error_code: int | None = None
|
||||||
|
|
||||||
|
def record_open(self) -> None:
|
||||||
|
self.opened_at = time.time()
|
||||||
|
|
||||||
|
def record_first_data(self) -> None:
|
||||||
|
if self.first_data_at is None:
|
||||||
|
self.first_data_at = time.time()
|
||||||
|
|
||||||
|
def record_close(self) -> None:
|
||||||
|
self.closed_at = time.time()
|
||||||
|
|
||||||
|
def record_reset(self, error_code: int) -> None:
|
||||||
|
self.reset_at = time.time()
|
||||||
|
self.error_code = error_code
|
||||||
|
|
||||||
|
|
||||||
|
class QUICStream(IMuxedStream):
|
||||||
|
"""
|
||||||
|
QUIC Stream implementation following libp2p IMuxedStream interface.
|
||||||
|
|
||||||
|
Based on patterns from go-libp2p and js-libp2p, this implementation:
|
||||||
|
- Leverages QUIC's native multiplexing and flow control
|
||||||
|
- Integrates with libp2p resource management
|
||||||
|
- Provides comprehensive error handling with QUIC-specific codes
|
||||||
|
- Supports bidirectional communication with independent close semantics
|
||||||
|
- Implements proper stream lifecycle management
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
connection: "QUICConnection",
|
||||||
|
stream_id: int,
|
||||||
|
direction: StreamDirection,
|
||||||
|
remote_addr: tuple[str, int],
|
||||||
|
resource_scope: Any | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize QUIC stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection: Parent QUIC connection
|
||||||
|
stream_id: QUIC stream identifier
|
||||||
|
direction: Stream direction (inbound/outbound)
|
||||||
|
resource_scope: Resource manager scope for memory accounting
|
||||||
|
remote_addr: Remote addr stream is connected to
|
||||||
|
|
||||||
|
"""
|
||||||
|
self._connection = connection
|
||||||
|
self._stream_id = stream_id
|
||||||
|
self._direction = direction
|
||||||
|
self._resource_scope = resource_scope
|
||||||
|
|
||||||
|
# libp2p interface compliance
|
||||||
|
self._protocol: TProtocol | None = None
|
||||||
|
self._metadata: dict[str, Any] = {}
|
||||||
|
self._remote_addr = remote_addr
|
||||||
|
|
||||||
|
# Stream state management
|
||||||
|
self._state = StreamState.OPEN
|
||||||
|
self._state_lock = trio.Lock()
|
||||||
|
|
||||||
|
# Flow control and buffering
|
||||||
|
self._receive_buffer = bytearray()
|
||||||
|
self._receive_buffer_lock = trio.Lock()
|
||||||
|
self._receive_event = trio.Event()
|
||||||
|
self._backpressure_event = trio.Event()
|
||||||
|
self._backpressure_event.set() # Initially no backpressure
|
||||||
|
|
||||||
|
# Close/reset state
|
||||||
|
self._write_closed = False
|
||||||
|
self._read_closed = False
|
||||||
|
self._close_event = trio.Event()
|
||||||
|
self._reset_error_code: int | None = None
|
||||||
|
|
||||||
|
# Lifecycle tracking
|
||||||
|
self._timeline = StreamTimeline()
|
||||||
|
self._timeline.record_open()
|
||||||
|
|
||||||
|
# Resource accounting
|
||||||
|
self._memory_reserved = 0
|
||||||
|
|
||||||
|
# Stream constant configurations
|
||||||
|
self.READ_TIMEOUT = connection._transport._config.STREAM_READ_TIMEOUT
|
||||||
|
self.WRITE_TIMEOUT = connection._transport._config.STREAM_WRITE_TIMEOUT
|
||||||
|
self.FLOW_CONTROL_WINDOW_SIZE = (
|
||||||
|
connection._transport._config.STREAM_FLOW_CONTROL_WINDOW
|
||||||
|
)
|
||||||
|
self.MAX_RECEIVE_BUFFER_SIZE = (
|
||||||
|
connection._transport._config.MAX_STREAM_RECEIVE_BUFFER
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._resource_scope:
|
||||||
|
self._reserve_memory(self.FLOW_CONTROL_WINDOW_SIZE)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Created QUIC stream {stream_id} "
|
||||||
|
f"({direction.value}, connection: {connection.remote_peer_id()})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Properties for libp2p interface compliance
|
||||||
|
|
||||||
|
@property
|
||||||
|
def protocol(self) -> TProtocol | None:
|
||||||
|
"""Get the protocol identifier for this stream."""
|
||||||
|
return self._protocol
|
||||||
|
|
||||||
|
@protocol.setter
|
||||||
|
def protocol(self, protocol_id: TProtocol) -> None:
|
||||||
|
"""Set the protocol identifier for this stream."""
|
||||||
|
self._protocol = protocol_id
|
||||||
|
self._metadata["protocol"] = protocol_id
|
||||||
|
logger.debug(f"Stream {self.stream_id} protocol set to: {protocol_id}")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stream_id(self) -> str:
|
||||||
|
"""Get stream ID as string for libp2p compatibility."""
|
||||||
|
return str(self._stream_id)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def muxed_conn(self) -> "QUICConnection": # type: ignore
|
||||||
|
"""Get the parent muxed connection."""
|
||||||
|
return self._connection
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self) -> StreamState:
|
||||||
|
"""Get current stream state."""
|
||||||
|
return self._state
|
||||||
|
|
||||||
|
@property
|
||||||
|
def direction(self) -> StreamDirection:
|
||||||
|
"""Get stream direction."""
|
||||||
|
return self._direction
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_initiator(self) -> bool:
|
||||||
|
"""Check if this stream was locally initiated."""
|
||||||
|
return self._direction == StreamDirection.OUTBOUND
|
||||||
|
|
||||||
|
# Core stream operations
|
||||||
|
|
||||||
|
async def read(self, n: int | None = None) -> bytes:
|
||||||
|
"""
|
||||||
|
Read data from the stream with QUIC flow control.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n: Maximum number of bytes to read. If None or -1, read all available.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Data read from stream
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
QUICStreamClosedError: Stream is closed
|
||||||
|
QUICStreamResetError: Stream was reset
|
||||||
|
QUICStreamTimeoutError: Read timeout exceeded
|
||||||
|
|
||||||
|
"""
|
||||||
|
if n is None:
|
||||||
|
n = -1
|
||||||
|
|
||||||
|
async with self._state_lock:
|
||||||
|
if self._state in (StreamState.CLOSED, StreamState.RESET):
|
||||||
|
raise QUICStreamClosedError(f"Stream {self.stream_id} is closed")
|
||||||
|
|
||||||
|
if self._read_closed:
|
||||||
|
# Return any remaining buffered data, then EOF
|
||||||
|
async with self._receive_buffer_lock:
|
||||||
|
if self._receive_buffer:
|
||||||
|
data = self._extract_data_from_buffer(n)
|
||||||
|
self._timeline.record_first_data()
|
||||||
|
return data
|
||||||
|
return b""
|
||||||
|
|
||||||
|
# Wait for data with timeout
|
||||||
|
timeout = self.READ_TIMEOUT
|
||||||
|
try:
|
||||||
|
with trio.move_on_after(timeout) as cancel_scope:
|
||||||
|
while True:
|
||||||
|
async with self._receive_buffer_lock:
|
||||||
|
if self._receive_buffer:
|
||||||
|
data = self._extract_data_from_buffer(n)
|
||||||
|
self._timeline.record_first_data()
|
||||||
|
return data
|
||||||
|
|
||||||
|
# Check if stream was closed while waiting
|
||||||
|
if self._read_closed:
|
||||||
|
return b""
|
||||||
|
|
||||||
|
# Wait for more data
|
||||||
|
await self._receive_event.wait()
|
||||||
|
self._receive_event = trio.Event() # Reset for next wait
|
||||||
|
|
||||||
|
if cancel_scope.cancelled_caught:
|
||||||
|
raise QUICStreamTimeoutError(f"Read timeout on stream {self.stream_id}")
|
||||||
|
|
||||||
|
return b""
|
||||||
|
except QUICStreamResetError:
|
||||||
|
# Stream was reset while reading
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error reading from stream {self.stream_id}: {e}")
|
||||||
|
await self._handle_stream_error(e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def write(self, data: bytes) -> None:
|
||||||
|
"""
|
||||||
|
Write data to the stream with QUIC flow control.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Data to write
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
QUICStreamClosedError: Stream is closed for writing
|
||||||
|
QUICStreamBackpressureError: Flow control window exhausted
|
||||||
|
QUICStreamResetError: Stream was reset
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not data:
|
||||||
|
return
|
||||||
|
|
||||||
|
async with self._state_lock:
|
||||||
|
if self._state in (StreamState.CLOSED, StreamState.RESET):
|
||||||
|
raise QUICStreamClosedError(f"Stream {self.stream_id} is closed")
|
||||||
|
|
||||||
|
if self._write_closed:
|
||||||
|
raise QUICStreamClosedError(
|
||||||
|
f"Stream {self.stream_id} write side is closed"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Handle flow control backpressure
|
||||||
|
await self._backpressure_event.wait()
|
||||||
|
|
||||||
|
# Send data through QUIC connection
|
||||||
|
self._connection._quic.send_stream_data(self._stream_id, data)
|
||||||
|
await self._connection._transmit()
|
||||||
|
|
||||||
|
self._timeline.record_first_data()
|
||||||
|
logger.debug(f"Wrote {len(data)} bytes to stream {self.stream_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error writing to stream {self.stream_id}: {e}")
|
||||||
|
# Convert QUIC-specific errors
|
||||||
|
if "flow control" in str(e).lower():
|
||||||
|
raise QUICStreamBackpressureError(f"Flow control limit reached: {e}")
|
||||||
|
await self._handle_stream_error(e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""
|
||||||
|
Close the stream gracefully (both read and write sides).
|
||||||
|
|
||||||
|
This implements proper close semantics where both sides
|
||||||
|
are closed and resources are cleaned up.
|
||||||
|
"""
|
||||||
|
async with self._state_lock:
|
||||||
|
if self._state in (StreamState.CLOSED, StreamState.RESET):
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug(f"Closing stream {self.stream_id}")
|
||||||
|
|
||||||
|
# Close both sides
|
||||||
|
if not self._write_closed:
|
||||||
|
await self.close_write()
|
||||||
|
if not self._read_closed:
|
||||||
|
await self.close_read()
|
||||||
|
|
||||||
|
# Update state and cleanup
|
||||||
|
async with self._state_lock:
|
||||||
|
self._state = StreamState.CLOSED
|
||||||
|
|
||||||
|
await self._cleanup_resources()
|
||||||
|
self._timeline.record_close()
|
||||||
|
self._close_event.set()
|
||||||
|
|
||||||
|
logger.debug(f"Stream {self.stream_id} closed")
|
||||||
|
|
||||||
|
async def close_write(self) -> None:
|
||||||
|
"""Close the write side of the stream."""
|
||||||
|
if self._write_closed:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Send FIN to close write side
|
||||||
|
self._connection._quic.send_stream_data(
|
||||||
|
self._stream_id, b"", end_stream=True
|
||||||
|
)
|
||||||
|
await self._connection._transmit()
|
||||||
|
|
||||||
|
self._write_closed = True
|
||||||
|
|
||||||
|
async with self._state_lock:
|
||||||
|
if self._read_closed:
|
||||||
|
self._state = StreamState.CLOSED
|
||||||
|
else:
|
||||||
|
self._state = StreamState.WRITE_CLOSED
|
||||||
|
|
||||||
|
logger.debug(f"Stream {self.stream_id} write side closed")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing write side of stream {self.stream_id}: {e}")
|
||||||
|
|
||||||
|
async def close_read(self) -> None:
|
||||||
|
"""Close the read side of the stream."""
|
||||||
|
if self._read_closed:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._read_closed = True
|
||||||
|
|
||||||
|
async with self._state_lock:
|
||||||
|
if self._write_closed:
|
||||||
|
self._state = StreamState.CLOSED
|
||||||
|
else:
|
||||||
|
self._state = StreamState.READ_CLOSED
|
||||||
|
|
||||||
|
# Wake up any pending reads
|
||||||
|
self._receive_event.set()
|
||||||
|
|
||||||
|
logger.debug(f"Stream {self.stream_id} read side closed")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing read side of stream {self.stream_id}: {e}")
|
||||||
|
|
||||||
|
async def reset(self, error_code: int = 0) -> None:
|
||||||
|
"""
|
||||||
|
Reset the stream with the given error code.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_code: QUIC error code for the reset
|
||||||
|
|
||||||
|
"""
|
||||||
|
async with self._state_lock:
|
||||||
|
if self._state == StreamState.RESET:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Resetting stream {self.stream_id} with error code {error_code}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._state = StreamState.RESET
|
||||||
|
self._reset_error_code = error_code
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Send QUIC reset frame
|
||||||
|
self._connection._quic.reset_stream(self._stream_id, error_code)
|
||||||
|
await self._connection._transmit()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error sending reset for stream {self.stream_id}: {e}")
|
||||||
|
finally:
|
||||||
|
# Always cleanup resources
|
||||||
|
await self._cleanup_resources()
|
||||||
|
self._timeline.record_reset(error_code)
|
||||||
|
self._close_event.set()
|
||||||
|
|
||||||
|
def is_closed(self) -> bool:
|
||||||
|
"""Check if stream is completely closed."""
|
||||||
|
return self._state in (StreamState.CLOSED, StreamState.RESET)
|
||||||
|
|
||||||
|
def is_reset(self) -> bool:
|
||||||
|
"""Check if stream was reset."""
|
||||||
|
return self._state == StreamState.RESET
|
||||||
|
|
||||||
|
def can_read(self) -> bool:
|
||||||
|
"""Check if stream can be read from."""
|
||||||
|
return not self._read_closed and self._state not in (
|
||||||
|
StreamState.CLOSED,
|
||||||
|
StreamState.RESET,
|
||||||
|
)
|
||||||
|
|
||||||
|
def can_write(self) -> bool:
|
||||||
|
"""Check if stream can be written to."""
|
||||||
|
return not self._write_closed and self._state not in (
|
||||||
|
StreamState.CLOSED,
|
||||||
|
StreamState.RESET,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def handle_data_received(self, data: bytes, end_stream: bool) -> None:
|
||||||
|
"""
|
||||||
|
Handle data received from the QUIC connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Received data
|
||||||
|
end_stream: Whether this is the last data (FIN received)
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self._state == StreamState.RESET:
|
||||||
|
return
|
||||||
|
|
||||||
|
if data:
|
||||||
|
async with self._receive_buffer_lock:
|
||||||
|
if len(self._receive_buffer) + len(data) > self.MAX_RECEIVE_BUFFER_SIZE:
|
||||||
|
logger.warning(
|
||||||
|
f"Stream {self.stream_id} receive buffer overflow, "
|
||||||
|
f"dropping {len(data)} bytes"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
self._receive_buffer.extend(data)
|
||||||
|
self._timeline.record_first_data()
|
||||||
|
|
||||||
|
# Notify waiting readers
|
||||||
|
self._receive_event.set()
|
||||||
|
|
||||||
|
logger.debug(f"Stream {self.stream_id} received {len(data)} bytes")
|
||||||
|
|
||||||
|
if end_stream:
|
||||||
|
self._read_closed = True
|
||||||
|
async with self._state_lock:
|
||||||
|
if self._write_closed:
|
||||||
|
self._state = StreamState.CLOSED
|
||||||
|
else:
|
||||||
|
self._state = StreamState.READ_CLOSED
|
||||||
|
|
||||||
|
# Wake up readers to process remaining data and EOF
|
||||||
|
self._receive_event.set()
|
||||||
|
|
||||||
|
logger.debug(f"Stream {self.stream_id} received FIN")
|
||||||
|
|
||||||
|
async def handle_stop_sending(self, error_code: int) -> None:
|
||||||
|
"""
|
||||||
|
Handle STOP_SENDING frame from remote peer.
|
||||||
|
|
||||||
|
When a STOP_SENDING frame is received, the peer is requesting that we
|
||||||
|
stop sending data on this stream. We respond by resetting the stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_code: Error code from the STOP_SENDING frame
|
||||||
|
|
||||||
|
"""
|
||||||
|
logger.debug(
|
||||||
|
f"Stream {self.stream_id} handling STOP_SENDING (error_code={error_code})"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._write_closed = True
|
||||||
|
|
||||||
|
# Wake up any pending write operations
|
||||||
|
self._backpressure_event.set()
|
||||||
|
|
||||||
|
async with self._state_lock:
|
||||||
|
if self.direction == StreamDirection.OUTBOUND:
|
||||||
|
self._state = StreamState.CLOSED
|
||||||
|
elif self._read_closed:
|
||||||
|
self._state = StreamState.CLOSED
|
||||||
|
else:
|
||||||
|
# Only write side closed - add WRITE_CLOSED state if needed
|
||||||
|
self._state = StreamState.WRITE_CLOSED
|
||||||
|
|
||||||
|
# Send RESET_STREAM in response (QUIC protocol requirement)
|
||||||
|
try:
|
||||||
|
self._connection._quic.reset_stream(int(self.stream_id), error_code)
|
||||||
|
await self._connection._transmit()
|
||||||
|
logger.debug(f"Sent RESET_STREAM for stream {self.stream_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Could not send RESET_STREAM for stream {self.stream_id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def handle_reset(self, error_code: int) -> None:
|
||||||
|
"""
|
||||||
|
Handle stream reset from remote peer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_code: QUIC error code from reset frame
|
||||||
|
|
||||||
|
"""
|
||||||
|
logger.debug(
|
||||||
|
f"Stream {self.stream_id} reset by peer with error code {error_code}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async with self._state_lock:
|
||||||
|
self._state = StreamState.RESET
|
||||||
|
self._reset_error_code = error_code
|
||||||
|
|
||||||
|
await self._cleanup_resources()
|
||||||
|
self._timeline.record_reset(error_code)
|
||||||
|
self._close_event.set()
|
||||||
|
|
||||||
|
# Wake up any pending operations
|
||||||
|
self._receive_event.set()
|
||||||
|
self._backpressure_event.set()
|
||||||
|
|
||||||
|
async def handle_flow_control_update(self, available_window: int) -> None:
|
||||||
|
"""
|
||||||
|
Handle flow control window updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
available_window: Available flow control window size
|
||||||
|
|
||||||
|
"""
|
||||||
|
if available_window > 0:
|
||||||
|
self._backpressure_event.set()
|
||||||
|
logger.debug(
|
||||||
|
f"Stream {self.stream_id} flow control".__add__(
|
||||||
|
f"window updated: {available_window}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._backpressure_event = trio.Event() # Reset to blocking state
|
||||||
|
logger.debug(f"Stream {self.stream_id} flow control window exhausted")
|
||||||
|
|
||||||
|
def _extract_data_from_buffer(self, n: int) -> bytes:
|
||||||
|
"""Extract data from receive buffer with specified limit."""
|
||||||
|
if n == -1:
|
||||||
|
# Read all available data
|
||||||
|
data = bytes(self._receive_buffer)
|
||||||
|
self._receive_buffer.clear()
|
||||||
|
else:
|
||||||
|
# Read up to n bytes
|
||||||
|
data = bytes(self._receive_buffer[:n])
|
||||||
|
self._receive_buffer = self._receive_buffer[n:]
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def _handle_stream_error(self, error: Exception) -> None:
|
||||||
|
"""Handle errors by resetting the stream."""
|
||||||
|
logger.error(f"Stream {self.stream_id} error: {error}")
|
||||||
|
await self.reset(error_code=1) # Generic error code
|
||||||
|
|
||||||
|
def _reserve_memory(self, size: int) -> None:
|
||||||
|
"""Reserve memory with resource manager."""
|
||||||
|
if self._resource_scope:
|
||||||
|
try:
|
||||||
|
self._resource_scope.reserve_memory(size)
|
||||||
|
self._memory_reserved += size
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to reserve memory for stream {self.stream_id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _release_memory(self, size: int) -> None:
|
||||||
|
"""Release memory with resource manager."""
|
||||||
|
if self._resource_scope and size > 0:
|
||||||
|
try:
|
||||||
|
self._resource_scope.release_memory(size)
|
||||||
|
self._memory_reserved = max(0, self._memory_reserved - size)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to release memory for stream {self.stream_id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _cleanup_resources(self) -> None:
|
||||||
|
"""Clean up stream resources."""
|
||||||
|
# Release all reserved memory
|
||||||
|
if self._memory_reserved > 0:
|
||||||
|
self._release_memory(self._memory_reserved)
|
||||||
|
|
||||||
|
# Clear receive buffer
|
||||||
|
async with self._receive_buffer_lock:
|
||||||
|
self._receive_buffer.clear()
|
||||||
|
|
||||||
|
# Remove from connection's stream registry
|
||||||
|
self._connection._remove_stream(self._stream_id)
|
||||||
|
|
||||||
|
logger.debug(f"Stream {self.stream_id} resources cleaned up")
|
||||||
|
|
||||||
|
# Abstact implementations
|
||||||
|
|
||||||
|
def get_remote_address(self) -> tuple[str, int]:
|
||||||
|
return self._remote_addr
|
||||||
|
|
||||||
|
async def __aenter__(self) -> "QUICStream":
|
||||||
|
"""Enter the async context manager."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: TracebackType | None,
|
||||||
|
) -> None:
|
||||||
|
"""Exit the async context manager and close the stream."""
|
||||||
|
logger.debug("Exiting the context and closing the stream")
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
def set_deadline(self, ttl: int) -> bool:
|
||||||
|
"""
|
||||||
|
Set a deadline for the stream. QUIC does not support deadlines natively,
|
||||||
|
so this method always returns False to indicate the operation is unsupported.
|
||||||
|
|
||||||
|
:param ttl: Time-to-live in seconds (ignored).
|
||||||
|
:return: False, as deadlines are not supported.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("QUIC does not support setting read deadlines")
|
||||||
|
|
||||||
|
# String representation for debugging
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"QUICStream(id={self.stream_id}, "
|
||||||
|
f"state={self._state.value}, "
|
||||||
|
f"direction={self._direction.value}, "
|
||||||
|
f"protocol={self._protocol})"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"QUICStream({self.stream_id})"
|
||||||
491
libp2p/transport/quic/transport.py
Normal file
491
libp2p/transport/quic/transport.py
Normal file
@ -0,0 +1,491 @@
|
|||||||
|
"""
|
||||||
|
QUIC Transport implementation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
import ssl
|
||||||
|
from typing import TYPE_CHECKING, cast
|
||||||
|
|
||||||
|
from aioquic.quic.configuration import (
|
||||||
|
QuicConfiguration,
|
||||||
|
)
|
||||||
|
from aioquic.quic.connection import (
|
||||||
|
QuicConnection as NativeQUICConnection,
|
||||||
|
)
|
||||||
|
from aioquic.quic.logger import QuicLogger
|
||||||
|
import multiaddr
|
||||||
|
import trio
|
||||||
|
|
||||||
|
from libp2p.abc import (
|
||||||
|
ITransport,
|
||||||
|
)
|
||||||
|
from libp2p.crypto.keys import (
|
||||||
|
PrivateKey,
|
||||||
|
)
|
||||||
|
from libp2p.custom_types import TProtocol, TQUICConnHandlerFn
|
||||||
|
from libp2p.peer.id import (
|
||||||
|
ID,
|
||||||
|
)
|
||||||
|
from libp2p.transport.quic.security import QUICTLSSecurityConfig
|
||||||
|
from libp2p.transport.quic.utils import (
|
||||||
|
create_client_config_from_base,
|
||||||
|
create_server_config_from_base,
|
||||||
|
get_alpn_protocols,
|
||||||
|
is_quic_multiaddr,
|
||||||
|
multiaddr_to_quic_version,
|
||||||
|
quic_multiaddr_to_endpoint,
|
||||||
|
quic_version_to_wire_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from libp2p.network.swarm import Swarm
|
||||||
|
else:
|
||||||
|
Swarm = cast(type, object)
|
||||||
|
|
||||||
|
from .config import (
|
||||||
|
QUICTransportConfig,
|
||||||
|
)
|
||||||
|
from .connection import (
|
||||||
|
QUICConnection,
|
||||||
|
)
|
||||||
|
from .exceptions import (
|
||||||
|
QUICDialError,
|
||||||
|
QUICListenError,
|
||||||
|
QUICSecurityError,
|
||||||
|
)
|
||||||
|
from .listener import (
|
||||||
|
QUICListener,
|
||||||
|
)
|
||||||
|
from .security import (
|
||||||
|
QUICTLSConfigManager,
|
||||||
|
create_quic_security_transport,
|
||||||
|
)
|
||||||
|
|
||||||
|
QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1
|
||||||
|
QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class QUICTransport(ITransport):
|
||||||
|
"""
|
||||||
|
QUIC Stream implementation following libp2p IMuxedStream interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, private_key: PrivateKey, config: QUICTransportConfig | None = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize QUIC transport with security integration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
private_key: libp2p private key for identity and TLS cert generation
|
||||||
|
config: QUIC transport configuration options
|
||||||
|
|
||||||
|
"""
|
||||||
|
self._private_key = private_key
|
||||||
|
self._peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||||
|
self._config = config or QUICTransportConfig()
|
||||||
|
|
||||||
|
# Connection management
|
||||||
|
self._connections: dict[str, QUICConnection] = {}
|
||||||
|
self._listeners: list[QUICListener] = []
|
||||||
|
|
||||||
|
# Security manager for TLS integration
|
||||||
|
self._security_manager = create_quic_security_transport(
|
||||||
|
self._private_key, self._peer_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# QUIC configurations for different versions
|
||||||
|
self._quic_configs: dict[TProtocol, QuicConfiguration] = {}
|
||||||
|
self._setup_quic_configurations()
|
||||||
|
|
||||||
|
# Resource management
|
||||||
|
self._closed = False
|
||||||
|
self._nursery_manager = trio.CapacityLimiter(1)
|
||||||
|
self._background_nursery: trio.Nursery | None = None
|
||||||
|
|
||||||
|
self._swarm: Swarm | None = None
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Initialized QUIC transport with security for peer {self._peer_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_background_nursery(self, nursery: trio.Nursery) -> None:
|
||||||
|
"""Set the nursery to use for background tasks (called by swarm)."""
|
||||||
|
self._background_nursery = nursery
|
||||||
|
logger.debug("Transport background nursery set")
|
||||||
|
|
||||||
|
def set_swarm(self, swarm: Swarm) -> None:
|
||||||
|
"""Set the swarm for adding incoming connections."""
|
||||||
|
self._swarm = swarm
|
||||||
|
|
||||||
|
def _setup_quic_configurations(self) -> None:
|
||||||
|
"""Setup QUIC configurations."""
|
||||||
|
try:
|
||||||
|
# Get TLS configuration from security manager
|
||||||
|
server_tls_config = self._security_manager.create_server_config()
|
||||||
|
client_tls_config = self._security_manager.create_client_config()
|
||||||
|
|
||||||
|
# Base server configuration
|
||||||
|
base_server_config = QuicConfiguration(
|
||||||
|
is_client=False,
|
||||||
|
alpn_protocols=get_alpn_protocols(),
|
||||||
|
verify_mode=self._config.verify_mode,
|
||||||
|
max_datagram_frame_size=self._config.max_datagram_size,
|
||||||
|
idle_timeout=self._config.idle_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Base client configuration
|
||||||
|
base_client_config = QuicConfiguration(
|
||||||
|
is_client=True,
|
||||||
|
alpn_protocols=get_alpn_protocols(),
|
||||||
|
verify_mode=self._config.verify_mode,
|
||||||
|
max_datagram_frame_size=self._config.max_datagram_size,
|
||||||
|
idle_timeout=self._config.idle_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply TLS configuration
|
||||||
|
self._apply_tls_configuration(base_server_config, server_tls_config)
|
||||||
|
self._apply_tls_configuration(base_client_config, client_tls_config)
|
||||||
|
|
||||||
|
# QUIC v1 (RFC 9000) configurations
|
||||||
|
if self._config.enable_v1:
|
||||||
|
quic_v1_server_config = create_server_config_from_base(
|
||||||
|
base_server_config, self._security_manager, self._config
|
||||||
|
)
|
||||||
|
quic_v1_server_config.supported_versions = [
|
||||||
|
quic_version_to_wire_format(QUIC_V1_PROTOCOL)
|
||||||
|
]
|
||||||
|
|
||||||
|
quic_v1_client_config = create_client_config_from_base(
|
||||||
|
base_client_config, self._security_manager, self._config
|
||||||
|
)
|
||||||
|
quic_v1_client_config.supported_versions = [
|
||||||
|
quic_version_to_wire_format(QUIC_V1_PROTOCOL)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Store both server and client configs for v1
|
||||||
|
self._quic_configs[TProtocol(f"{QUIC_V1_PROTOCOL}_server")] = (
|
||||||
|
quic_v1_server_config
|
||||||
|
)
|
||||||
|
self._quic_configs[TProtocol(f"{QUIC_V1_PROTOCOL}_client")] = (
|
||||||
|
quic_v1_client_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# QUIC draft-29 configurations for compatibility
|
||||||
|
if self._config.enable_draft29:
|
||||||
|
draft29_server_config: QuicConfiguration = copy.copy(base_server_config)
|
||||||
|
draft29_server_config.supported_versions = [
|
||||||
|
quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL)
|
||||||
|
]
|
||||||
|
|
||||||
|
draft29_client_config = copy.copy(base_client_config)
|
||||||
|
draft29_client_config.supported_versions = [
|
||||||
|
quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL)
|
||||||
|
]
|
||||||
|
|
||||||
|
self._quic_configs[TProtocol(f"{QUIC_DRAFT29_PROTOCOL}_server")] = (
|
||||||
|
draft29_server_config
|
||||||
|
)
|
||||||
|
self._quic_configs[TProtocol(f"{QUIC_DRAFT29_PROTOCOL}_client")] = (
|
||||||
|
draft29_client_config
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("QUIC configurations initialized with libp2p TLS security")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise QUICSecurityError(
|
||||||
|
f"Failed to setup QUIC TLS configurations: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
def _apply_tls_configuration(
|
||||||
|
self, config: QuicConfiguration, tls_config: QUICTLSSecurityConfig
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Apply TLS configuration to a QUIC configuration using aioquic's actual API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: QuicConfiguration to update
|
||||||
|
tls_config: TLS configuration dictionary from security manager
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
config.certificate = tls_config.certificate
|
||||||
|
config.private_key = tls_config.private_key
|
||||||
|
config.certificate_chain = tls_config.certificate_chain
|
||||||
|
config.alpn_protocols = tls_config.alpn_protocols
|
||||||
|
config.verify_mode = ssl.CERT_NONE
|
||||||
|
|
||||||
|
logger.debug("Successfully applied TLS configuration to QUIC config")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise QUICSecurityError(f"Failed to apply TLS configuration: {e}") from e
|
||||||
|
|
||||||
|
async def dial(
|
||||||
|
self,
|
||||||
|
maddr: multiaddr.Multiaddr,
|
||||||
|
) -> QUICConnection:
|
||||||
|
"""
|
||||||
|
Dial a remote peer using QUIC transport with security verification.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
maddr: Multiaddr of the remote peer (e.g., /ip4/1.2.3.4/udp/4001/quic-v1)
|
||||||
|
peer_id: Expected peer ID for verification
|
||||||
|
nursery: Nursery to execute the background tasks
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Raw connection interface to the remote peer
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
QUICDialError: If dialing fails
|
||||||
|
QUICSecurityError: If security verification fails
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self._closed:
|
||||||
|
raise QUICDialError("Transport is closed")
|
||||||
|
|
||||||
|
if not is_quic_multiaddr(maddr):
|
||||||
|
raise QUICDialError(f"Invalid QUIC multiaddr: {maddr}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Extract connection details from multiaddr
|
||||||
|
host, port = quic_multiaddr_to_endpoint(maddr)
|
||||||
|
remote_peer_id = maddr.get_peer_id()
|
||||||
|
if remote_peer_id is not None:
|
||||||
|
remote_peer_id = ID.from_base58(remote_peer_id)
|
||||||
|
|
||||||
|
if remote_peer_id is None:
|
||||||
|
logger.error("Unable to derive peer id from multiaddr")
|
||||||
|
raise QUICDialError("Unable to derive peer id from multiaddr")
|
||||||
|
quic_version = multiaddr_to_quic_version(maddr)
|
||||||
|
|
||||||
|
# Get appropriate QUIC client configuration
|
||||||
|
config_key = TProtocol(f"{quic_version}_client")
|
||||||
|
logger.debug("config_key", config_key, self._quic_configs.keys())
|
||||||
|
config = self._quic_configs.get(config_key)
|
||||||
|
if not config:
|
||||||
|
raise QUICDialError(f"Unsupported QUIC version: {quic_version}")
|
||||||
|
|
||||||
|
config.is_client = True
|
||||||
|
config.quic_logger = QuicLogger()
|
||||||
|
|
||||||
|
# Ensure client certificate is properly set for mutual authentication
|
||||||
|
if not config.certificate or not config.private_key:
|
||||||
|
logger.warning(
|
||||||
|
"Client config missing certificate - applying TLS config"
|
||||||
|
)
|
||||||
|
client_tls_config = self._security_manager.create_client_config()
|
||||||
|
self._apply_tls_configuration(config, client_tls_config)
|
||||||
|
|
||||||
|
# Debug log to verify certificate is present
|
||||||
|
logger.info(
|
||||||
|
f"Dialing QUIC connection to {host}:{port} (version: {{quic_version}})"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Starting QUIC Connection")
|
||||||
|
# Create QUIC connection using aioquic's sans-IO core
|
||||||
|
native_quic_connection = NativeQUICConnection(configuration=config)
|
||||||
|
|
||||||
|
# Create trio-based QUIC connection wrapper with security
|
||||||
|
connection = QUICConnection(
|
||||||
|
quic_connection=native_quic_connection,
|
||||||
|
remote_addr=(host, port),
|
||||||
|
remote_peer_id=remote_peer_id,
|
||||||
|
local_peer_id=self._peer_id,
|
||||||
|
is_initiator=True,
|
||||||
|
maddr=maddr,
|
||||||
|
transport=self,
|
||||||
|
security_manager=self._security_manager,
|
||||||
|
)
|
||||||
|
logger.debug("QUIC Connection Created")
|
||||||
|
|
||||||
|
if self._background_nursery is None:
|
||||||
|
logger.error("No nursery set to execute background tasks")
|
||||||
|
raise QUICDialError("No nursery found to execute tasks")
|
||||||
|
|
||||||
|
await connection.connect(self._background_nursery)
|
||||||
|
|
||||||
|
# Store connection for management
|
||||||
|
conn_id = f"{host}:{port}"
|
||||||
|
self._connections[conn_id] = connection
|
||||||
|
|
||||||
|
return connection
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to dial QUIC connection to {maddr}: {e}")
|
||||||
|
raise QUICDialError(f"Dial failed: {e}") from e
|
||||||
|
|
||||||
|
async def _verify_peer_identity(
|
||||||
|
self, connection: QUICConnection, expected_peer_id: ID
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Verify remote peer identity after TLS handshake.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection: The established QUIC connection
|
||||||
|
expected_peer_id: Expected peer ID
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
QUICSecurityError: If peer verification fails
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get peer certificate from the connection
|
||||||
|
peer_certificate = await connection.get_peer_certificate()
|
||||||
|
|
||||||
|
if not peer_certificate:
|
||||||
|
raise QUICSecurityError("No peer certificate available")
|
||||||
|
|
||||||
|
# Verify peer identity using security manager
|
||||||
|
verified_peer_id = self._security_manager.verify_peer_identity(
|
||||||
|
peer_certificate, expected_peer_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if verified_peer_id != expected_peer_id:
|
||||||
|
raise QUICSecurityError(
|
||||||
|
"Peer ID verification failed: expected "
|
||||||
|
f"{expected_peer_id}, got {verified_peer_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Peer identity verified: {verified_peer_id}")
|
||||||
|
logger.debug(f"Peer identity verified: {verified_peer_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise QUICSecurityError(f"Peer identity verification failed: {e}") from e
|
||||||
|
|
||||||
|
def create_listener(self, handler_function: TQUICConnHandlerFn) -> QUICListener:
|
||||||
|
"""
|
||||||
|
Create a QUIC listener with integrated security.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
handler_function: Function to handle new connections
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
QUIC listener instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
QUICListenError: If transport is closed
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self._closed:
|
||||||
|
raise QUICListenError("Transport is closed")
|
||||||
|
|
||||||
|
# Get server configurations for the listener
|
||||||
|
server_configs = {
|
||||||
|
version: config
|
||||||
|
for version, config in self._quic_configs.items()
|
||||||
|
if version.endswith("_server")
|
||||||
|
}
|
||||||
|
|
||||||
|
listener = QUICListener(
|
||||||
|
transport=self,
|
||||||
|
handler_function=handler_function,
|
||||||
|
quic_configs=server_configs,
|
||||||
|
config=self._config,
|
||||||
|
security_manager=self._security_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._listeners.append(listener)
|
||||||
|
logger.debug("Created QUIC listener with security")
|
||||||
|
return listener
|
||||||
|
|
||||||
|
def can_dial(self, maddr: multiaddr.Multiaddr) -> bool:
|
||||||
|
"""
|
||||||
|
Check if this transport can dial the given multiaddr.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
maddr: Multiaddr to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if this transport can dial the address
|
||||||
|
|
||||||
|
"""
|
||||||
|
return is_quic_multiaddr(maddr)
|
||||||
|
|
||||||
|
def protocols(self) -> list[TProtocol]:
|
||||||
|
"""
|
||||||
|
Get supported protocol identifiers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of supported protocol strings
|
||||||
|
|
||||||
|
"""
|
||||||
|
protocols = [QUIC_V1_PROTOCOL]
|
||||||
|
if self._config.enable_draft29:
|
||||||
|
protocols.append(QUIC_DRAFT29_PROTOCOL)
|
||||||
|
return protocols
|
||||||
|
|
||||||
|
def listen_order(self) -> int:
|
||||||
|
"""
|
||||||
|
Get the listen order priority for this transport.
|
||||||
|
Matches go-libp2p's ListenOrder = 1 for QUIC.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Priority order for listening (lower = higher priority)
|
||||||
|
|
||||||
|
"""
|
||||||
|
return 1
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the transport and cleanup resources."""
|
||||||
|
if self._closed:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._closed = True
|
||||||
|
logger.debug("Closing QUIC transport")
|
||||||
|
|
||||||
|
# Close all active connections and listeners concurrently using trio nursery
|
||||||
|
async with trio.open_nursery() as nursery:
|
||||||
|
# Close all connections
|
||||||
|
for connection in self._connections.values():
|
||||||
|
nursery.start_soon(connection.close)
|
||||||
|
|
||||||
|
# Close all listeners
|
||||||
|
for listener in self._listeners:
|
||||||
|
nursery.start_soon(listener.close)
|
||||||
|
|
||||||
|
self._connections.clear()
|
||||||
|
self._listeners.clear()
|
||||||
|
|
||||||
|
logger.debug("QUIC transport closed")
|
||||||
|
|
||||||
|
async def _cleanup_terminated_connection(self, connection: QUICConnection) -> None:
|
||||||
|
"""Clean up a terminated connection from all listeners."""
|
||||||
|
try:
|
||||||
|
for listener in self._listeners:
|
||||||
|
await listener._remove_connection_by_object(connection)
|
||||||
|
logger.debug(
|
||||||
|
"✅ TRANSPORT: Cleaned up terminated connection from all listeners"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ TRANSPORT: Error cleaning up terminated connection: {e}")
|
||||||
|
|
||||||
|
def get_stats(self) -> dict[str, int | list[str] | object]:
|
||||||
|
"""Get transport statistics including security info."""
|
||||||
|
return {
|
||||||
|
"active_connections": len(self._connections),
|
||||||
|
"active_listeners": len(self._listeners),
|
||||||
|
"supported_protocols": self.protocols(),
|
||||||
|
"local_peer_id": str(self._peer_id),
|
||||||
|
"security_enabled": True,
|
||||||
|
"tls_configured": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_security_manager(self) -> QUICTLSConfigManager:
|
||||||
|
"""
|
||||||
|
Get the security manager for this transport.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The QUIC TLS configuration manager
|
||||||
|
|
||||||
|
"""
|
||||||
|
return self._security_manager
|
||||||
|
|
||||||
|
def get_listener_socket(self) -> trio.socket.SocketType | None:
|
||||||
|
"""Get the socket from the first active listener."""
|
||||||
|
for listener in self._listeners:
|
||||||
|
if listener.is_listening() and listener._socket:
|
||||||
|
return listener._socket
|
||||||
|
return None
|
||||||
466
libp2p/transport/quic/utils.py
Normal file
466
libp2p/transport/quic/utils.py
Normal file
@ -0,0 +1,466 @@
|
|||||||
|
"""
|
||||||
|
Multiaddr utilities for QUIC transport - Module 4.
|
||||||
|
Essential utilities required for QUIC transport implementation.
|
||||||
|
Based on go-libp2p and js-libp2p QUIC implementations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
|
import logging
|
||||||
|
import ssl
|
||||||
|
|
||||||
|
from aioquic.quic.configuration import QuicConfiguration
|
||||||
|
import multiaddr
|
||||||
|
|
||||||
|
from libp2p.custom_types import TProtocol
|
||||||
|
from libp2p.transport.quic.security import QUICTLSConfigManager
|
||||||
|
|
||||||
|
from .config import QUICTransportConfig
|
||||||
|
from .exceptions import QUICInvalidMultiaddrError, QUICUnsupportedVersionError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Protocol constants
|
||||||
|
QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1
|
||||||
|
QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29
|
||||||
|
UDP_PROTOCOL = "udp"
|
||||||
|
IP4_PROTOCOL = "ip4"
|
||||||
|
IP6_PROTOCOL = "ip6"
|
||||||
|
|
||||||
|
SERVER_CONFIG_PROTOCOL_V1 = f"{QUIC_V1_PROTOCOL}_server"
|
||||||
|
CLIENT_CONFIG_PROTCOL_V1 = f"{QUIC_V1_PROTOCOL}_client"
|
||||||
|
|
||||||
|
SERVER_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_DRAFT29_PROTOCOL}_server"
|
||||||
|
CLIENT_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_DRAFT29_PROTOCOL}_client"
|
||||||
|
|
||||||
|
CUSTOM_QUIC_VERSION_MAPPING: dict[str, int] = {
|
||||||
|
SERVER_CONFIG_PROTOCOL_V1: 0x00000001, # RFC 9000
|
||||||
|
CLIENT_CONFIG_PROTCOL_V1: 0x00000001, # RFC 9000
|
||||||
|
SERVER_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29
|
||||||
|
CLIENT_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29
|
||||||
|
}
|
||||||
|
|
||||||
|
# QUIC version to wire format mappings (required for aioquic)
|
||||||
|
QUIC_VERSION_MAPPINGS: dict[TProtocol, int] = {
|
||||||
|
QUIC_V1_PROTOCOL: 0x00000001, # RFC 9000
|
||||||
|
QUIC_DRAFT29_PROTOCOL: 0xFF00001D, # draft-29
|
||||||
|
}
|
||||||
|
|
||||||
|
# ALPN protocols for libp2p over QUIC
|
||||||
|
LIBP2P_ALPN_PROTOCOLS: list[str] = ["libp2p"]
|
||||||
|
|
||||||
|
|
||||||
|
def is_quic_multiaddr(maddr: multiaddr.Multiaddr) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a multiaddr represents a QUIC address.
|
||||||
|
|
||||||
|
Valid QUIC multiaddrs:
|
||||||
|
- /ip4/127.0.0.1/udp/4001/quic-v1
|
||||||
|
- /ip4/127.0.0.1/udp/4001/quic
|
||||||
|
- /ip6/::1/udp/4001/quic-v1
|
||||||
|
- /ip6/::1/udp/4001/quic
|
||||||
|
|
||||||
|
Args:
|
||||||
|
maddr: Multiaddr to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the multiaddr represents a QUIC address
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
addr_str = str(maddr)
|
||||||
|
|
||||||
|
# Check for required components
|
||||||
|
has_ip = f"/{IP4_PROTOCOL}/" in addr_str or f"/{IP6_PROTOCOL}/" in addr_str
|
||||||
|
has_udp = f"/{UDP_PROTOCOL}/" in addr_str
|
||||||
|
has_quic = (
|
||||||
|
f"/{QUIC_V1_PROTOCOL}" in addr_str
|
||||||
|
or f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str
|
||||||
|
or "/quic" in addr_str
|
||||||
|
)
|
||||||
|
|
||||||
|
return has_ip and has_udp and has_quic
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> tuple[str, int]:
|
||||||
|
"""
|
||||||
|
Extract host and port from a QUIC multiaddr.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
maddr: QUIC multiaddr
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (host, port)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
QUICInvalidMultiaddrError: If multiaddr is not a valid QUIC address
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not is_quic_multiaddr(maddr):
|
||||||
|
raise QUICInvalidMultiaddrError(f"Not a valid QUIC multiaddr: {maddr}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
host = None
|
||||||
|
port = None
|
||||||
|
|
||||||
|
# Try to get IPv4 address
|
||||||
|
try:
|
||||||
|
host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Try to get IPv6 address if IPv4 not found
|
||||||
|
if host is None:
|
||||||
|
try:
|
||||||
|
host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Get UDP port
|
||||||
|
try:
|
||||||
|
port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP) # type: ignore
|
||||||
|
port = int(port_str)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if host is None or port is None:
|
||||||
|
raise QUICInvalidMultiaddrError(f"Could not extract host/port from {maddr}")
|
||||||
|
|
||||||
|
return host, port
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise QUICInvalidMultiaddrError(
|
||||||
|
f"Failed to parse QUIC multiaddr {maddr}: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
def multiaddr_to_quic_version(maddr: multiaddr.Multiaddr) -> TProtocol:
|
||||||
|
"""
|
||||||
|
Determine QUIC version from multiaddr.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
maddr: QUIC multiaddr
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
QUIC version identifier ("quic-v1" or "quic")
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
QUICInvalidMultiaddrError: If multiaddr doesn't contain QUIC protocol
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
addr_str = str(maddr)
|
||||||
|
|
||||||
|
if f"/{QUIC_V1_PROTOCOL}" in addr_str:
|
||||||
|
return QUIC_V1_PROTOCOL # RFC 9000
|
||||||
|
elif f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str:
|
||||||
|
return QUIC_DRAFT29_PROTOCOL # draft-29
|
||||||
|
else:
|
||||||
|
raise QUICInvalidMultiaddrError(f"No QUIC protocol found in {maddr}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise QUICInvalidMultiaddrError(
|
||||||
|
f"Failed to determine QUIC version from {maddr}: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
def create_quic_multiaddr(
|
||||||
|
host: str, port: int, version: str = "quic-v1"
|
||||||
|
) -> multiaddr.Multiaddr:
|
||||||
|
"""
|
||||||
|
Create a QUIC multiaddr from host, port, and version.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
host: IP address (IPv4 or IPv6)
|
||||||
|
port: UDP port number
|
||||||
|
version: QUIC version ("quic-v1" or "quic")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
QUIC multiaddr
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
QUICInvalidMultiaddrError: If invalid parameters provided
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Determine IP version
|
||||||
|
try:
|
||||||
|
ip = ipaddress.ip_address(host)
|
||||||
|
if isinstance(ip, ipaddress.IPv4Address):
|
||||||
|
ip_proto = IP4_PROTOCOL
|
||||||
|
else:
|
||||||
|
ip_proto = IP6_PROTOCOL
|
||||||
|
except ValueError:
|
||||||
|
raise QUICInvalidMultiaddrError(f"Invalid IP address: {host}")
|
||||||
|
|
||||||
|
# Validate port
|
||||||
|
if not (0 <= port <= 65535):
|
||||||
|
raise QUICInvalidMultiaddrError(f"Invalid port: {port}")
|
||||||
|
|
||||||
|
# Validate and normalize QUIC version
|
||||||
|
if version == "quic-v1" or version == "/quic-v1":
|
||||||
|
quic_proto = QUIC_V1_PROTOCOL
|
||||||
|
elif version == "quic" or version == "/quic":
|
||||||
|
quic_proto = QUIC_DRAFT29_PROTOCOL
|
||||||
|
else:
|
||||||
|
raise QUICInvalidMultiaddrError(f"Invalid QUIC version: {version}")
|
||||||
|
|
||||||
|
# Construct multiaddr
|
||||||
|
addr_str = f"/{ip_proto}/{host}/{UDP_PROTOCOL}/{port}/{quic_proto}"
|
||||||
|
return multiaddr.Multiaddr(addr_str)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise QUICInvalidMultiaddrError(f"Failed to create QUIC multiaddr: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def quic_version_to_wire_format(version: TProtocol) -> int:
|
||||||
|
"""
|
||||||
|
Convert QUIC version string to wire format integer for aioquic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
version: QUIC version string ("quic-v1" or "quic")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Wire format version number
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
QUICUnsupportedVersionError: If version is not supported
|
||||||
|
|
||||||
|
"""
|
||||||
|
wire_version = QUIC_VERSION_MAPPINGS.get(version)
|
||||||
|
if wire_version is None:
|
||||||
|
raise QUICUnsupportedVersionError(f"Unsupported QUIC version: {version}")
|
||||||
|
|
||||||
|
return wire_version
|
||||||
|
|
||||||
|
|
||||||
|
def custom_quic_version_to_wire_format(version: TProtocol) -> int:
|
||||||
|
"""
|
||||||
|
Convert QUIC version string to wire format integer for aioquic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
version: QUIC version string ("quic-v1" or "quic")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Wire format version number
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
QUICUnsupportedVersionError: If version is not supported
|
||||||
|
|
||||||
|
"""
|
||||||
|
wire_version = CUSTOM_QUIC_VERSION_MAPPING.get(version)
|
||||||
|
if wire_version is None:
|
||||||
|
raise QUICUnsupportedVersionError(f"Unsupported QUIC version: {version}")
|
||||||
|
|
||||||
|
return wire_version
|
||||||
|
|
||||||
|
|
||||||
|
def get_alpn_protocols() -> list[str]:
|
||||||
|
"""
|
||||||
|
Get ALPN protocols for libp2p over QUIC.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ALPN protocol identifiers
|
||||||
|
|
||||||
|
"""
|
||||||
|
return LIBP2P_ALPN_PROTOCOLS.copy()
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_quic_multiaddr(maddr: multiaddr.Multiaddr) -> multiaddr.Multiaddr:
|
||||||
|
"""
|
||||||
|
Normalize a QUIC multiaddr to canonical form.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
maddr: Input QUIC multiaddr
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized multiaddr
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
QUICInvalidMultiaddrError: If not a valid QUIC multiaddr
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not is_quic_multiaddr(maddr):
|
||||||
|
raise QUICInvalidMultiaddrError(f"Not a QUIC multiaddr: {maddr}")
|
||||||
|
|
||||||
|
host, port = quic_multiaddr_to_endpoint(maddr)
|
||||||
|
version = multiaddr_to_quic_version(maddr)
|
||||||
|
|
||||||
|
return create_quic_multiaddr(host, port, version)
|
||||||
|
|
||||||
|
|
||||||
|
def create_server_config_from_base(
|
||||||
|
base_config: QuicConfiguration,
|
||||||
|
security_manager: QUICTLSConfigManager | None = None,
|
||||||
|
transport_config: QUICTransportConfig | None = None,
|
||||||
|
) -> QuicConfiguration:
|
||||||
|
"""
|
||||||
|
Create a server configuration without using deepcopy.
|
||||||
|
Manually copies attributes while handling cryptography objects properly.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Create new server configuration from scratch
|
||||||
|
server_config = QuicConfiguration(is_client=False)
|
||||||
|
server_config.verify_mode = ssl.CERT_NONE
|
||||||
|
|
||||||
|
# Copy basic configuration attributes (these are safe to copy)
|
||||||
|
copyable_attrs = [
|
||||||
|
"alpn_protocols",
|
||||||
|
"verify_mode",
|
||||||
|
"max_datagram_frame_size",
|
||||||
|
"idle_timeout",
|
||||||
|
"max_concurrent_streams",
|
||||||
|
"supported_versions",
|
||||||
|
"max_data",
|
||||||
|
"max_stream_data",
|
||||||
|
"stateless_retry",
|
||||||
|
"quantum_readiness_test",
|
||||||
|
]
|
||||||
|
|
||||||
|
for attr in copyable_attrs:
|
||||||
|
if hasattr(base_config, attr):
|
||||||
|
value = getattr(base_config, attr)
|
||||||
|
if value is not None:
|
||||||
|
setattr(server_config, attr, value)
|
||||||
|
|
||||||
|
# Handle cryptography objects - these need direct reference, not copying
|
||||||
|
crypto_attrs = [
|
||||||
|
"certificate",
|
||||||
|
"private_key",
|
||||||
|
"certificate_chain",
|
||||||
|
"ca_certs",
|
||||||
|
]
|
||||||
|
|
||||||
|
for attr in crypto_attrs:
|
||||||
|
if hasattr(base_config, attr):
|
||||||
|
value = getattr(base_config, attr)
|
||||||
|
if value is not None:
|
||||||
|
setattr(server_config, attr, value)
|
||||||
|
|
||||||
|
# Apply security manager configuration if available
|
||||||
|
if security_manager:
|
||||||
|
try:
|
||||||
|
server_tls_config = security_manager.create_server_config()
|
||||||
|
|
||||||
|
# Override with security manager's TLS configuration
|
||||||
|
if server_tls_config.certificate:
|
||||||
|
server_config.certificate = server_tls_config.certificate
|
||||||
|
if server_tls_config.private_key:
|
||||||
|
server_config.private_key = server_tls_config.private_key
|
||||||
|
if server_tls_config.certificate_chain:
|
||||||
|
server_config.certificate_chain = (
|
||||||
|
server_tls_config.certificate_chain
|
||||||
|
)
|
||||||
|
if server_tls_config.alpn_protocols:
|
||||||
|
server_config.alpn_protocols = server_tls_config.alpn_protocols
|
||||||
|
server_tls_config.request_client_certificate = True
|
||||||
|
if getattr(server_tls_config, "request_client_certificate", False):
|
||||||
|
server_config._libp2p_request_client_cert = True # type: ignore
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
"🔧 Failed to set request_client_certificate in server config"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to apply security manager config: {e}")
|
||||||
|
|
||||||
|
# Set transport-specific defaults if provided
|
||||||
|
if transport_config:
|
||||||
|
if server_config.idle_timeout == 0:
|
||||||
|
server_config.idle_timeout = getattr(
|
||||||
|
transport_config, "idle_timeout", 30.0
|
||||||
|
)
|
||||||
|
if server_config.max_datagram_frame_size is None:
|
||||||
|
server_config.max_datagram_frame_size = getattr(
|
||||||
|
transport_config, "max_datagram_size", 1200
|
||||||
|
)
|
||||||
|
# Ensure we have ALPN protocols
|
||||||
|
if not server_config.alpn_protocols:
|
||||||
|
server_config.alpn_protocols = ["libp2p"]
|
||||||
|
|
||||||
|
logger.debug("Successfully created server config without deepcopy")
|
||||||
|
return server_config
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create server config: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def create_client_config_from_base(
|
||||||
|
base_config: QuicConfiguration,
|
||||||
|
security_manager: QUICTLSConfigManager | None = None,
|
||||||
|
transport_config: QUICTransportConfig | None = None,
|
||||||
|
) -> QuicConfiguration:
|
||||||
|
"""
|
||||||
|
Create a client configuration without using deepcopy.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Create new client configuration from scratch
|
||||||
|
client_config = QuicConfiguration(is_client=True)
|
||||||
|
client_config.verify_mode = ssl.CERT_NONE
|
||||||
|
|
||||||
|
# Copy basic configuration attributes
|
||||||
|
copyable_attrs = [
|
||||||
|
"alpn_protocols",
|
||||||
|
"verify_mode",
|
||||||
|
"max_datagram_frame_size",
|
||||||
|
"idle_timeout",
|
||||||
|
"max_concurrent_streams",
|
||||||
|
"supported_versions",
|
||||||
|
"max_data",
|
||||||
|
"max_stream_data",
|
||||||
|
"quantum_readiness_test",
|
||||||
|
]
|
||||||
|
|
||||||
|
for attr in copyable_attrs:
|
||||||
|
if hasattr(base_config, attr):
|
||||||
|
value = getattr(base_config, attr)
|
||||||
|
if value is not None:
|
||||||
|
setattr(client_config, attr, value)
|
||||||
|
|
||||||
|
# Handle cryptography objects - these need direct reference, not copying
|
||||||
|
crypto_attrs = [
|
||||||
|
"certificate",
|
||||||
|
"private_key",
|
||||||
|
"certificate_chain",
|
||||||
|
"ca_certs",
|
||||||
|
]
|
||||||
|
|
||||||
|
for attr in crypto_attrs:
|
||||||
|
if hasattr(base_config, attr):
|
||||||
|
value = getattr(base_config, attr)
|
||||||
|
if value is not None:
|
||||||
|
setattr(client_config, attr, value)
|
||||||
|
|
||||||
|
# Apply security manager configuration if available
|
||||||
|
if security_manager:
|
||||||
|
try:
|
||||||
|
client_tls_config = security_manager.create_client_config()
|
||||||
|
|
||||||
|
# Override with security manager's TLS configuration
|
||||||
|
if client_tls_config.certificate:
|
||||||
|
client_config.certificate = client_tls_config.certificate
|
||||||
|
if client_tls_config.private_key:
|
||||||
|
client_config.private_key = client_tls_config.private_key
|
||||||
|
if client_tls_config.certificate_chain:
|
||||||
|
client_config.certificate_chain = (
|
||||||
|
client_tls_config.certificate_chain
|
||||||
|
)
|
||||||
|
if client_tls_config.alpn_protocols:
|
||||||
|
client_config.alpn_protocols = client_tls_config.alpn_protocols
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to apply security manager config: {e}")
|
||||||
|
|
||||||
|
# Ensure we have ALPN protocols
|
||||||
|
if not client_config.alpn_protocols:
|
||||||
|
client_config.alpn_protocols = ["libp2p"]
|
||||||
|
|
||||||
|
logger.debug("Successfully created client config without deepcopy")
|
||||||
|
return client_config
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create client config: {e}")
|
||||||
|
raise
|
||||||
267
libp2p/transport/transport_registry.py
Normal file
267
libp2p/transport/transport_registry.py
Normal file
@ -0,0 +1,267 @@
|
|||||||
|
"""
|
||||||
|
Transport registry for dynamic transport selection based on multiaddr protocols.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from multiaddr import Multiaddr
|
||||||
|
from multiaddr.protocols import Protocol
|
||||||
|
|
||||||
|
from libp2p.abc import ITransport
|
||||||
|
from libp2p.transport.tcp.tcp import TCP
|
||||||
|
from libp2p.transport.upgrader import TransportUpgrader
|
||||||
|
from libp2p.transport.websocket.multiaddr_utils import (
|
||||||
|
is_valid_websocket_multiaddr,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Import QUIC utilities here to avoid circular imports
|
||||||
|
def _get_quic_transport() -> Any:
|
||||||
|
from libp2p.transport.quic.transport import QUICTransport
|
||||||
|
|
||||||
|
return QUICTransport
|
||||||
|
|
||||||
|
|
||||||
|
def _get_quic_validation() -> Callable[[Multiaddr], bool]:
|
||||||
|
from libp2p.transport.quic.utils import is_quic_multiaddr
|
||||||
|
|
||||||
|
return is_quic_multiaddr
|
||||||
|
|
||||||
|
|
||||||
|
# Import WebsocketTransport here to avoid circular imports
|
||||||
|
def _get_websocket_transport() -> Any:
|
||||||
|
from libp2p.transport.websocket.transport import WebsocketTransport
|
||||||
|
|
||||||
|
return WebsocketTransport
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger("libp2p.transport.registry")
|
||||||
|
|
||||||
|
|
||||||
|
def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
|
||||||
|
"""
|
||||||
|
Validate that a multiaddr has a valid TCP structure.
|
||||||
|
|
||||||
|
:param maddr: The multiaddr to validate
|
||||||
|
:return: True if valid TCP structure, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# TCP multiaddr should have structure like /ip4/127.0.0.1/tcp/8080
|
||||||
|
# or /ip6/::1/tcp/8080
|
||||||
|
protocols: list[Protocol] = list(maddr.protocols())
|
||||||
|
|
||||||
|
# Must have at least 2 protocols: network (ip4/ip6) + tcp
|
||||||
|
if len(protocols) < 2:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# First protocol should be a network protocol (ip4, ip6, dns4, dns6)
|
||||||
|
if protocols[0].name not in ["ip4", "ip6", "dns4", "dns6"]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Second protocol should be tcp
|
||||||
|
if protocols[1].name != "tcp":
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Should not have any protocols after tcp (unless it's a valid
|
||||||
|
# continuation like p2p)
|
||||||
|
# For now, we'll be strict and only allow network + tcp
|
||||||
|
if len(protocols) > 2:
|
||||||
|
# Check if the additional protocols are valid continuations
|
||||||
|
valid_continuations = ["p2p"] # Add more as needed
|
||||||
|
for i in range(2, len(protocols)):
|
||||||
|
if protocols[i].name not in valid_continuations:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class TransportRegistry:
|
||||||
|
"""
|
||||||
|
Registry for mapping multiaddr protocols to transport implementations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._transports: dict[str, type[ITransport]] = {}
|
||||||
|
self._register_default_transports()
|
||||||
|
|
||||||
|
def _register_default_transports(self) -> None:
|
||||||
|
"""Register the default transport implementations."""
|
||||||
|
# Register TCP transport for /tcp protocol
|
||||||
|
self.register_transport("tcp", TCP)
|
||||||
|
|
||||||
|
# Register WebSocket transport for /ws and /wss protocols
|
||||||
|
WebsocketTransport = _get_websocket_transport()
|
||||||
|
self.register_transport("ws", WebsocketTransport)
|
||||||
|
self.register_transport("wss", WebsocketTransport)
|
||||||
|
|
||||||
|
# Register QUIC transport for /quic and /quic-v1 protocols
|
||||||
|
QUICTransport = _get_quic_transport()
|
||||||
|
self.register_transport("quic", QUICTransport)
|
||||||
|
self.register_transport("quic-v1", QUICTransport)
|
||||||
|
|
||||||
|
def register_transport(
|
||||||
|
self, protocol: str, transport_class: type[ITransport]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Register a transport class for a specific protocol.
|
||||||
|
|
||||||
|
:param protocol: The protocol identifier (e.g., "tcp", "ws")
|
||||||
|
:param transport_class: The transport class to register
|
||||||
|
"""
|
||||||
|
self._transports[protocol] = transport_class
|
||||||
|
logger.debug(
|
||||||
|
f"Registered transport {transport_class.__name__} for protocol {protocol}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_transport(self, protocol: str) -> type[ITransport] | None:
|
||||||
|
"""
|
||||||
|
Get the transport class for a specific protocol.
|
||||||
|
|
||||||
|
:param protocol: The protocol identifier
|
||||||
|
:return: The transport class or None if not found
|
||||||
|
"""
|
||||||
|
return self._transports.get(protocol)
|
||||||
|
|
||||||
|
def get_supported_protocols(self) -> list[str]:
|
||||||
|
"""Get list of supported transport protocols."""
|
||||||
|
return list(self._transports.keys())
|
||||||
|
|
||||||
|
def create_transport(
|
||||||
|
self, protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any
|
||||||
|
) -> ITransport | None:
|
||||||
|
"""
|
||||||
|
Create a transport instance for a specific protocol.
|
||||||
|
|
||||||
|
:param protocol: The protocol identifier
|
||||||
|
:param upgrader: The transport upgrader instance (required for WebSocket)
|
||||||
|
:param kwargs: Additional arguments for transport construction
|
||||||
|
:return: Transport instance or None if protocol not supported or creation fails
|
||||||
|
"""
|
||||||
|
transport_class = self.get_transport(protocol)
|
||||||
|
if transport_class is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
if protocol in ["ws", "wss"]:
|
||||||
|
# WebSocket transport requires upgrader
|
||||||
|
if upgrader is None:
|
||||||
|
logger.warning(
|
||||||
|
f"WebSocket transport '{protocol}' requires upgrader"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
# Use explicit WebsocketTransport to avoid type issues
|
||||||
|
WebsocketTransport = _get_websocket_transport()
|
||||||
|
return WebsocketTransport(
|
||||||
|
upgrader,
|
||||||
|
tls_client_config=kwargs.get("tls_client_config"),
|
||||||
|
tls_server_config=kwargs.get("tls_server_config"),
|
||||||
|
handshake_timeout=kwargs.get("handshake_timeout", 15.0),
|
||||||
|
)
|
||||||
|
elif protocol in ["quic", "quic-v1"]:
|
||||||
|
# QUIC transport requires private_key
|
||||||
|
private_key = kwargs.get("private_key")
|
||||||
|
if private_key is None:
|
||||||
|
logger.warning(f"QUIC transport '{protocol}' requires private_key")
|
||||||
|
return None
|
||||||
|
# Use explicit QUICTransport to avoid type issues
|
||||||
|
QUICTransport = _get_quic_transport()
|
||||||
|
config = kwargs.get("config")
|
||||||
|
return QUICTransport(private_key, config)
|
||||||
|
else:
|
||||||
|
# TCP transport doesn't require upgrader
|
||||||
|
return transport_class()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create transport for protocol {protocol}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Global transport registry instance (lazy initialization)
|
||||||
|
_global_registry: TransportRegistry | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_transport_registry() -> TransportRegistry:
|
||||||
|
"""Get the global transport registry instance."""
|
||||||
|
global _global_registry
|
||||||
|
if _global_registry is None:
|
||||||
|
_global_registry = TransportRegistry()
|
||||||
|
return _global_registry
|
||||||
|
|
||||||
|
|
||||||
|
def register_transport(protocol: str, transport_class: type[ITransport]) -> None:
|
||||||
|
"""Register a transport class in the global registry."""
|
||||||
|
registry = get_transport_registry()
|
||||||
|
registry.register_transport(protocol, transport_class)
|
||||||
|
|
||||||
|
|
||||||
|
def create_transport_for_multiaddr(
|
||||||
|
maddr: Multiaddr, upgrader: TransportUpgrader, **kwargs: Any
|
||||||
|
) -> ITransport | None:
|
||||||
|
"""
|
||||||
|
Create the appropriate transport for a given multiaddr.
|
||||||
|
|
||||||
|
:param maddr: The multiaddr to create transport for
|
||||||
|
:param upgrader: The transport upgrader instance
|
||||||
|
:param kwargs: Additional arguments for transport construction
|
||||||
|
(e.g., private_key for QUIC)
|
||||||
|
:return: Transport instance or None if no suitable transport found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get all protocols in the multiaddr
|
||||||
|
protocols = [proto.name for proto in maddr.protocols()]
|
||||||
|
|
||||||
|
# Check for supported transport protocols in order of preference
|
||||||
|
# We need to validate that the multiaddr structure is valid for our transports
|
||||||
|
if "quic" in protocols or "quic-v1" in protocols:
|
||||||
|
# For QUIC, we need a valid structure like:
|
||||||
|
# /ip4/127.0.0.1/udp/4001/quic
|
||||||
|
# /ip4/127.0.0.1/udp/4001/quic-v1
|
||||||
|
is_quic_multiaddr = _get_quic_validation()
|
||||||
|
if is_quic_multiaddr(maddr):
|
||||||
|
# Determine QUIC version
|
||||||
|
registry = get_transport_registry()
|
||||||
|
if "quic-v1" in protocols:
|
||||||
|
return registry.create_transport("quic-v1", upgrader, **kwargs)
|
||||||
|
else:
|
||||||
|
return registry.create_transport("quic", upgrader, **kwargs)
|
||||||
|
elif "ws" in protocols or "wss" in protocols or "tls" in protocols:
|
||||||
|
# For WebSocket, we need a valid structure like:
|
||||||
|
# /ip4/127.0.0.1/tcp/8080/ws (insecure)
|
||||||
|
# /ip4/127.0.0.1/tcp/8080/wss (secure)
|
||||||
|
# /ip4/127.0.0.1/tcp/8080/tls/ws (secure with TLS)
|
||||||
|
# /ip4/127.0.0.1/tcp/8080/tls/sni/example.com/ws (secure with SNI)
|
||||||
|
if is_valid_websocket_multiaddr(maddr):
|
||||||
|
# Determine if this is a secure WebSocket connection
|
||||||
|
registry = get_transport_registry()
|
||||||
|
if "wss" in protocols or "tls" in protocols:
|
||||||
|
return registry.create_transport("wss", upgrader, **kwargs)
|
||||||
|
else:
|
||||||
|
return registry.create_transport("ws", upgrader, **kwargs)
|
||||||
|
elif "tcp" in protocols:
|
||||||
|
# For TCP, we need a valid structure like /ip4/127.0.0.1/tcp/8080
|
||||||
|
# Check if the multiaddr has proper TCP structure
|
||||||
|
if _is_valid_tcp_multiaddr(maddr):
|
||||||
|
registry = get_transport_registry()
|
||||||
|
return registry.create_transport("tcp", upgrader)
|
||||||
|
|
||||||
|
# If no supported transport protocol found or structure is invalid, return None
|
||||||
|
logger.warning(
|
||||||
|
f"No supported transport protocol found or invalid structure in "
|
||||||
|
f"multiaddr: {maddr}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Handle any errors gracefully (e.g., invalid multiaddr)
|
||||||
|
logger.warning(f"Error processing multiaddr {maddr}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_supported_transport_protocols() -> list[str]:
|
||||||
|
"""Get list of supported transport protocols from the global registry."""
|
||||||
|
registry = get_transport_registry()
|
||||||
|
return registry.get_supported_protocols()
|
||||||
@ -14,6 +14,9 @@ from libp2p.protocol_muxer.exceptions import (
|
|||||||
MultiselectClientError,
|
MultiselectClientError,
|
||||||
MultiselectError,
|
MultiselectError,
|
||||||
)
|
)
|
||||||
|
from libp2p.protocol_muxer.multiselect import (
|
||||||
|
DEFAULT_NEGOTIATE_TIMEOUT,
|
||||||
|
)
|
||||||
from libp2p.security.exceptions import (
|
from libp2p.security.exceptions import (
|
||||||
HandshakeFailure,
|
HandshakeFailure,
|
||||||
)
|
)
|
||||||
@ -37,9 +40,12 @@ class TransportUpgrader:
|
|||||||
self,
|
self,
|
||||||
secure_transports_by_protocol: TSecurityOptions,
|
secure_transports_by_protocol: TSecurityOptions,
|
||||||
muxer_transports_by_protocol: TMuxerOptions,
|
muxer_transports_by_protocol: TMuxerOptions,
|
||||||
|
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||||
):
|
):
|
||||||
self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
|
self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
|
||||||
self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol)
|
self.muxer_multistream = MuxerMultistream(
|
||||||
|
muxer_transports_by_protocol, negotiate_timeout
|
||||||
|
)
|
||||||
|
|
||||||
async def upgrade_security(
|
async def upgrade_security(
|
||||||
self,
|
self,
|
||||||
|
|||||||
198
libp2p/transport/websocket/connection.py
Normal file
198
libp2p/transport/websocket/connection.py
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import trio
|
||||||
|
|
||||||
|
from libp2p.io.abc import ReadWriteCloser
|
||||||
|
from libp2p.io.exceptions import IOException
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class P2PWebSocketConnection(ReadWriteCloser):
|
||||||
|
"""
|
||||||
|
Wraps a WebSocketConnection to provide the raw stream interface
|
||||||
|
that libp2p protocols expect.
|
||||||
|
|
||||||
|
Implements production-ready buffer management and flow control
|
||||||
|
as recommended in the libp2p WebSocket specification.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ws_connection: Any,
|
||||||
|
ws_context: Any = None,
|
||||||
|
is_secure: bool = False,
|
||||||
|
max_buffered_amount: int = 4 * 1024 * 1024,
|
||||||
|
) -> None:
|
||||||
|
self._ws_connection = ws_connection
|
||||||
|
self._ws_context = ws_context
|
||||||
|
self._is_secure = is_secure
|
||||||
|
self._read_buffer = b""
|
||||||
|
self._read_lock = trio.Lock()
|
||||||
|
self._connection_start_time = time.time()
|
||||||
|
self._bytes_read = 0
|
||||||
|
self._bytes_written = 0
|
||||||
|
self._closed = False
|
||||||
|
self._close_lock = trio.Lock()
|
||||||
|
self._max_buffered_amount = max_buffered_amount
|
||||||
|
self._write_lock = trio.Lock()
|
||||||
|
|
||||||
|
async def write(self, data: bytes) -> None:
|
||||||
|
"""Write data with flow control and buffer management"""
|
||||||
|
if self._closed:
|
||||||
|
raise IOException("Connection is closed")
|
||||||
|
|
||||||
|
async with self._write_lock:
|
||||||
|
try:
|
||||||
|
logger.debug(f"WebSocket writing {len(data)} bytes")
|
||||||
|
|
||||||
|
# Check buffer amount for flow control
|
||||||
|
if hasattr(self._ws_connection, "bufferedAmount"):
|
||||||
|
buffered = self._ws_connection.bufferedAmount
|
||||||
|
if buffered > self._max_buffered_amount:
|
||||||
|
logger.warning(f"WebSocket buffer full: {buffered} bytes")
|
||||||
|
# In production, you might want to
|
||||||
|
# wait or implement backpressure
|
||||||
|
# For now, we'll continue but log the warning
|
||||||
|
|
||||||
|
# Send as a binary WebSocket message
|
||||||
|
await self._ws_connection.send_message(data)
|
||||||
|
self._bytes_written += len(data)
|
||||||
|
logger.debug(f"WebSocket wrote {len(data)} bytes successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"WebSocket write failed: {e}")
|
||||||
|
self._closed = True
|
||||||
|
raise IOException from e
|
||||||
|
|
||||||
|
async def read(self, n: int | None = None) -> bytes:
|
||||||
|
"""
|
||||||
|
Read up to n bytes (if n is given), else read up to 64KiB.
|
||||||
|
This implementation provides byte-level access to WebSocket messages,
|
||||||
|
which is required for libp2p protocol compatibility.
|
||||||
|
|
||||||
|
For WebSocket compatibility with libp2p protocols, this method:
|
||||||
|
1. Buffers incoming WebSocket messages
|
||||||
|
2. Returns exactly the requested number of bytes when n is specified
|
||||||
|
3. Accumulates multiple WebSocket messages if needed to satisfy the request
|
||||||
|
4. Returns empty bytes (not raises) when connection is closed and no data
|
||||||
|
available
|
||||||
|
"""
|
||||||
|
if self._closed:
|
||||||
|
raise IOException("Connection is closed")
|
||||||
|
|
||||||
|
async with self._read_lock:
|
||||||
|
try:
|
||||||
|
# If n is None, read at least one message and return all buffered data
|
||||||
|
if n is None:
|
||||||
|
if not self._read_buffer:
|
||||||
|
try:
|
||||||
|
# Use a short timeout to avoid blocking indefinitely
|
||||||
|
with trio.fail_after(1.0): # 1 second timeout
|
||||||
|
message = await self._ws_connection.get_message()
|
||||||
|
if isinstance(message, str):
|
||||||
|
message = message.encode("utf-8")
|
||||||
|
self._read_buffer = message
|
||||||
|
except trio.TooSlowError:
|
||||||
|
# No message available within timeout
|
||||||
|
return b""
|
||||||
|
except Exception:
|
||||||
|
# Return empty bytes if no data available
|
||||||
|
# (connection closed)
|
||||||
|
return b""
|
||||||
|
|
||||||
|
result = self._read_buffer
|
||||||
|
self._read_buffer = b""
|
||||||
|
self._bytes_read += len(result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
# For specific byte count requests, return UP TO n bytes (not exactly n)
|
||||||
|
# This matches TCP semantics where read(1024) returns available data
|
||||||
|
# up to 1024 bytes
|
||||||
|
|
||||||
|
# If we don't have any data buffered, try to get at least one message
|
||||||
|
if not self._read_buffer:
|
||||||
|
try:
|
||||||
|
# Use a short timeout to avoid blocking indefinitely
|
||||||
|
with trio.fail_after(1.0): # 1 second timeout
|
||||||
|
message = await self._ws_connection.get_message()
|
||||||
|
if isinstance(message, str):
|
||||||
|
message = message.encode("utf-8")
|
||||||
|
self._read_buffer = message
|
||||||
|
except trio.TooSlowError:
|
||||||
|
return b"" # No data available
|
||||||
|
except Exception:
|
||||||
|
return b""
|
||||||
|
|
||||||
|
# Now return up to n bytes from the buffer (TCP-like semantics)
|
||||||
|
if len(self._read_buffer) == 0:
|
||||||
|
return b""
|
||||||
|
|
||||||
|
# Return up to n bytes (like TCP read())
|
||||||
|
result = self._read_buffer[:n]
|
||||||
|
self._read_buffer = self._read_buffer[len(result) :]
|
||||||
|
self._bytes_read += len(result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"WebSocket read failed: {e}")
|
||||||
|
raise IOException from e
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the WebSocket connection. This method is idempotent."""
|
||||||
|
async with self._close_lock:
|
||||||
|
if self._closed:
|
||||||
|
return # Already closed
|
||||||
|
|
||||||
|
logger.debug("WebSocket connection closing")
|
||||||
|
self._closed = True
|
||||||
|
try:
|
||||||
|
# Always close the connection directly, avoid context manager issues
|
||||||
|
# The context manager may be causing cancel scope corruption
|
||||||
|
logger.debug("WebSocket closing connection directly")
|
||||||
|
await self._ws_connection.aclose()
|
||||||
|
# Exit the context manager if we have one
|
||||||
|
if self._ws_context is not None:
|
||||||
|
await self._ws_context.__aexit__(None, None, None)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"WebSocket close error: {e}")
|
||||||
|
# Don't raise here, as close() should be idempotent
|
||||||
|
finally:
|
||||||
|
logger.debug("WebSocket connection closed")
|
||||||
|
|
||||||
|
def is_closed(self) -> bool:
|
||||||
|
"""Check if the connection is closed"""
|
||||||
|
return self._closed
|
||||||
|
|
||||||
|
def conn_state(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Return connection state information similar to Go's ConnState() method.
|
||||||
|
|
||||||
|
:return: Dictionary containing connection state information
|
||||||
|
"""
|
||||||
|
current_time = time.time()
|
||||||
|
return {
|
||||||
|
"transport": "websocket",
|
||||||
|
"secure": self._is_secure,
|
||||||
|
"connection_duration": current_time - self._connection_start_time,
|
||||||
|
"bytes_read": self._bytes_read,
|
||||||
|
"bytes_written": self._bytes_written,
|
||||||
|
"total_bytes": self._bytes_read + self._bytes_written,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_remote_address(self) -> tuple[str, int] | None:
|
||||||
|
# Try to get remote address from the WebSocket connection
|
||||||
|
try:
|
||||||
|
remote = self._ws_connection.remote
|
||||||
|
if hasattr(remote, "address") and hasattr(remote, "port"):
|
||||||
|
return str(remote.address), int(remote.port)
|
||||||
|
elif isinstance(remote, str):
|
||||||
|
# Parse address:port format
|
||||||
|
if ":" in remote:
|
||||||
|
host, port = remote.rsplit(":", 1)
|
||||||
|
return host, int(port)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
225
libp2p/transport/websocket/listener.py
Normal file
225
libp2p/transport/websocket/listener.py
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
import logging
|
||||||
|
import ssl
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from multiaddr import Multiaddr
|
||||||
|
import trio
|
||||||
|
from trio_typing import TaskStatus
|
||||||
|
from trio_websocket import serve_websocket
|
||||||
|
|
||||||
|
from libp2p.abc import IListener
|
||||||
|
from libp2p.custom_types import THandler
|
||||||
|
from libp2p.transport.upgrader import TransportUpgrader
|
||||||
|
from libp2p.transport.websocket.multiaddr_utils import parse_websocket_multiaddr
|
||||||
|
|
||||||
|
from .connection import P2PWebSocketConnection
|
||||||
|
|
||||||
|
logger = logging.getLogger("libp2p.transport.websocket.listener")
|
||||||
|
|
||||||
|
|
||||||
|
class WebsocketListener(IListener):
|
||||||
|
"""
|
||||||
|
Listen on /ip4/.../tcp/.../ws addresses, handshake WS, wrap into RawConnection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
handler: THandler,
|
||||||
|
upgrader: TransportUpgrader,
|
||||||
|
tls_config: ssl.SSLContext | None = None,
|
||||||
|
handshake_timeout: float = 15.0,
|
||||||
|
) -> None:
|
||||||
|
self._handler = handler
|
||||||
|
self._upgrader = upgrader
|
||||||
|
self._tls_config = tls_config
|
||||||
|
self._handshake_timeout = handshake_timeout
|
||||||
|
self._server = None
|
||||||
|
self._shutdown_event = trio.Event()
|
||||||
|
self._nursery: trio.Nursery | None = None
|
||||||
|
self._listeners: Any = None
|
||||||
|
self._is_wss = False # Track whether this is a WSS listener
|
||||||
|
|
||||||
|
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
|
||||||
|
logger.debug(f"WebsocketListener.listen called with {maddr}")
|
||||||
|
|
||||||
|
# Parse the WebSocket multiaddr to determine if it's secure
|
||||||
|
try:
|
||||||
|
parsed = parse_websocket_multiaddr(maddr)
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(f"Invalid WebSocket multiaddr: {e}") from e
|
||||||
|
|
||||||
|
# Check if WSS is requested but no TLS config provided
|
||||||
|
if parsed.is_wss and self._tls_config is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot listen on WSS address {maddr} without TLS configuration"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store whether this is a WSS listener
|
||||||
|
self._is_wss = parsed.is_wss
|
||||||
|
|
||||||
|
# Extract host and port from the base multiaddr
|
||||||
|
host = (
|
||||||
|
parsed.rest_multiaddr.value_for_protocol("ip4")
|
||||||
|
or parsed.rest_multiaddr.value_for_protocol("ip6")
|
||||||
|
or parsed.rest_multiaddr.value_for_protocol("dns")
|
||||||
|
or parsed.rest_multiaddr.value_for_protocol("dns4")
|
||||||
|
or parsed.rest_multiaddr.value_for_protocol("dns6")
|
||||||
|
or "0.0.0.0"
|
||||||
|
)
|
||||||
|
port_str = parsed.rest_multiaddr.value_for_protocol("tcp")
|
||||||
|
if port_str is None:
|
||||||
|
raise ValueError(f"No TCP port found in multiaddr: {maddr}")
|
||||||
|
port = int(port_str)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"WebsocketListener: host={host}, port={port}, secure={parsed.is_wss}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def serve_websocket_tcp(
|
||||||
|
handler: Callable[[Any], Awaitable[None]],
|
||||||
|
port: int,
|
||||||
|
host: str,
|
||||||
|
task_status: TaskStatus[Any],
|
||||||
|
) -> None:
|
||||||
|
"""Start TCP server and handle WebSocket connections manually"""
|
||||||
|
logger.debug(
|
||||||
|
"serve_websocket_tcp %s %s (secure=%s)", host, port, parsed.is_wss
|
||||||
|
)
|
||||||
|
|
||||||
|
async def websocket_handler(request: Any) -> None:
|
||||||
|
"""Handle WebSocket requests"""
|
||||||
|
logger.debug("WebSocket request received")
|
||||||
|
try:
|
||||||
|
# Apply handshake timeout
|
||||||
|
with trio.fail_after(self._handshake_timeout):
|
||||||
|
# Accept the WebSocket connection
|
||||||
|
ws_connection = await request.accept()
|
||||||
|
logger.debug("WebSocket handshake successful")
|
||||||
|
|
||||||
|
# Create the WebSocket connection wrapper
|
||||||
|
conn = P2PWebSocketConnection(
|
||||||
|
ws_connection, is_secure=parsed.is_wss
|
||||||
|
) # type: ignore[no-untyped-call]
|
||||||
|
|
||||||
|
# Call the handler function that was passed to create_listener
|
||||||
|
# This handler will handle the security and muxing upgrades
|
||||||
|
logger.debug("Calling connection handler")
|
||||||
|
await self._handler(conn)
|
||||||
|
|
||||||
|
# Don't keep the connection alive indefinitely
|
||||||
|
# Let the handler manage the connection lifecycle
|
||||||
|
logger.debug(
|
||||||
|
"Handler completed, connection will be managed by handler"
|
||||||
|
)
|
||||||
|
|
||||||
|
except trio.TooSlowError:
|
||||||
|
logger.debug(
|
||||||
|
f"WebSocket handshake timeout after {self._handshake_timeout}s"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await request.reject(408) # Request Timeout
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"WebSocket connection error: {e}")
|
||||||
|
logger.debug(f"Error type: {type(e)}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||||
|
# Reject the connection
|
||||||
|
try:
|
||||||
|
await request.reject(400)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Use trio_websocket.serve_websocket for proper WebSocket handling
|
||||||
|
ssl_context = self._tls_config if parsed.is_wss else None
|
||||||
|
await serve_websocket(
|
||||||
|
websocket_handler, host, port, ssl_context, task_status=task_status
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store the nursery for shutdown
|
||||||
|
self._nursery = nursery
|
||||||
|
|
||||||
|
# Start the server using nursery.start() like TCP does
|
||||||
|
logger.debug("Calling nursery.start()...")
|
||||||
|
started_listeners = await nursery.start(
|
||||||
|
serve_websocket_tcp,
|
||||||
|
None, # No handler needed since it's defined inside serve_websocket_tcp
|
||||||
|
port,
|
||||||
|
host,
|
||||||
|
)
|
||||||
|
logger.debug(f"nursery.start() returned: {started_listeners}")
|
||||||
|
|
||||||
|
if started_listeners is None:
|
||||||
|
logger.error(f"Failed to start WebSocket listener for {maddr}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Store the listeners for get_addrs() and close() - these are real
|
||||||
|
# SocketListener objects
|
||||||
|
self._listeners = started_listeners
|
||||||
|
logger.debug(
|
||||||
|
"WebsocketListener.listen returning True with WebSocketServer object"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_addrs(self) -> tuple[Multiaddr, ...]:
|
||||||
|
if not hasattr(self, "_listeners") or not self._listeners:
|
||||||
|
logger.debug("No listeners available for get_addrs()")
|
||||||
|
return ()
|
||||||
|
|
||||||
|
# Handle WebSocketServer objects
|
||||||
|
if hasattr(self._listeners, "port"):
|
||||||
|
# This is a WebSocketServer object
|
||||||
|
port = self._listeners.port
|
||||||
|
# Create a multiaddr from the port with correct WSS/WS protocol
|
||||||
|
protocol = "wss" if self._is_wss else "ws"
|
||||||
|
return (Multiaddr(f"/ip4/127.0.0.1/tcp/{port}/{protocol}"),)
|
||||||
|
else:
|
||||||
|
# This is a list of listeners (like TCP)
|
||||||
|
listeners = self._listeners
|
||||||
|
# Get addresses from listeners like TCP does
|
||||||
|
return tuple(
|
||||||
|
_multiaddr_from_socket(listener.socket, self._is_wss)
|
||||||
|
for listener in listeners
|
||||||
|
)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the WebSocket listener and stop accepting new connections"""
|
||||||
|
logger.debug("WebsocketListener.close called")
|
||||||
|
if hasattr(self, "_listeners") and self._listeners:
|
||||||
|
# Signal shutdown
|
||||||
|
self._shutdown_event.set()
|
||||||
|
|
||||||
|
# Close the WebSocket server
|
||||||
|
if hasattr(self._listeners, "aclose"):
|
||||||
|
# This is a WebSocketServer object
|
||||||
|
logger.debug("Closing WebSocket server")
|
||||||
|
await self._listeners.aclose()
|
||||||
|
logger.debug("WebSocket server closed")
|
||||||
|
elif isinstance(self._listeners, (list, tuple)):
|
||||||
|
# This is a list of listeners (like TCP)
|
||||||
|
logger.debug("Closing TCP listeners")
|
||||||
|
for listener in self._listeners:
|
||||||
|
await listener.aclose()
|
||||||
|
logger.debug("TCP listeners closed")
|
||||||
|
else:
|
||||||
|
# Unknown type, try to close it directly
|
||||||
|
logger.debug("Closing unknown listener type")
|
||||||
|
if hasattr(self._listeners, "close"):
|
||||||
|
self._listeners.close()
|
||||||
|
logger.debug("Unknown listener closed")
|
||||||
|
|
||||||
|
# Clear the listeners reference
|
||||||
|
self._listeners = None
|
||||||
|
logger.debug("WebsocketListener.close completed")
|
||||||
|
|
||||||
|
|
||||||
|
def _multiaddr_from_socket(
|
||||||
|
socket: trio.socket.SocketType, is_wss: bool = False
|
||||||
|
) -> Multiaddr:
|
||||||
|
"""Convert socket to multiaddr"""
|
||||||
|
ip, port = socket.getsockname()
|
||||||
|
protocol = "wss" if is_wss else "ws"
|
||||||
|
return Multiaddr(f"/ip4/{ip}/tcp/{port}/{protocol}")
|
||||||
202
libp2p/transport/websocket/multiaddr_utils.py
Normal file
202
libp2p/transport/websocket/multiaddr_utils.py
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
"""
|
||||||
|
WebSocket multiaddr parsing utilities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
from multiaddr import Multiaddr
|
||||||
|
from multiaddr.protocols import Protocol
|
||||||
|
|
||||||
|
|
||||||
|
class ParsedWebSocketMultiaddr(NamedTuple):
|
||||||
|
"""Parsed WebSocket multiaddr information."""
|
||||||
|
|
||||||
|
is_wss: bool
|
||||||
|
sni: str | None
|
||||||
|
rest_multiaddr: Multiaddr
|
||||||
|
|
||||||
|
|
||||||
|
def parse_websocket_multiaddr(maddr: Multiaddr) -> ParsedWebSocketMultiaddr:
|
||||||
|
"""
|
||||||
|
Parse a WebSocket multiaddr and extract security information.
|
||||||
|
|
||||||
|
:param maddr: The multiaddr to parse
|
||||||
|
:return: Parsed WebSocket multiaddr information
|
||||||
|
:raises ValueError: If the multiaddr is not a valid WebSocket multiaddr
|
||||||
|
"""
|
||||||
|
# First validate that this is a valid WebSocket multiaddr
|
||||||
|
if not is_valid_websocket_multiaddr(maddr):
|
||||||
|
raise ValueError(f"Not a valid WebSocket multiaddr: {maddr}")
|
||||||
|
|
||||||
|
protocols = list(maddr.protocols())
|
||||||
|
|
||||||
|
# Find the WebSocket protocol and check for security
|
||||||
|
is_wss = False
|
||||||
|
sni = None
|
||||||
|
ws_index = -1
|
||||||
|
tls_index = -1
|
||||||
|
sni_index = -1
|
||||||
|
|
||||||
|
# Find protocol indices
|
||||||
|
for i, protocol in enumerate(protocols):
|
||||||
|
if protocol.name == "ws":
|
||||||
|
ws_index = i
|
||||||
|
elif protocol.name == "wss":
|
||||||
|
ws_index = i
|
||||||
|
is_wss = True
|
||||||
|
elif protocol.name == "tls":
|
||||||
|
tls_index = i
|
||||||
|
elif protocol.name == "sni":
|
||||||
|
sni_index = i
|
||||||
|
sni = protocol.value
|
||||||
|
|
||||||
|
if ws_index == -1:
|
||||||
|
raise ValueError("Not a WebSocket multiaddr")
|
||||||
|
|
||||||
|
# Handle /wss protocol (convert to /tls/ws internally)
|
||||||
|
if is_wss and tls_index == -1:
|
||||||
|
# Convert /wss to /tls/ws format
|
||||||
|
# Remove /wss to get the base multiaddr
|
||||||
|
without_wss = maddr.decapsulate(Multiaddr("/wss"))
|
||||||
|
return ParsedWebSocketMultiaddr(
|
||||||
|
is_wss=True, sni=None, rest_multiaddr=without_wss
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle /tls/ws and /tls/sni/.../ws formats
|
||||||
|
if tls_index != -1:
|
||||||
|
is_wss = True
|
||||||
|
# Extract the base multiaddr (everything before /tls)
|
||||||
|
# For /ip4/127.0.0.1/tcp/8080/tls/ws, we want /ip4/127.0.0.1/tcp/8080
|
||||||
|
# Use multiaddr methods to properly extract the base
|
||||||
|
rest_multiaddr = maddr
|
||||||
|
# Remove /tls/ws or /tls/sni/.../ws from the end
|
||||||
|
if sni_index != -1:
|
||||||
|
# /tls/sni/example.com/ws format
|
||||||
|
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/ws"))
|
||||||
|
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr(f"/sni/{sni}"))
|
||||||
|
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/tls"))
|
||||||
|
else:
|
||||||
|
# /tls/ws format
|
||||||
|
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/ws"))
|
||||||
|
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/tls"))
|
||||||
|
return ParsedWebSocketMultiaddr(
|
||||||
|
is_wss=is_wss, sni=sni, rest_multiaddr=rest_multiaddr
|
||||||
|
)
|
||||||
|
|
||||||
|
# Regular /ws multiaddr - remove /ws and any additional protocols
|
||||||
|
rest_multiaddr = maddr.decapsulate(Multiaddr("/ws"))
|
||||||
|
return ParsedWebSocketMultiaddr(
|
||||||
|
is_wss=False, sni=None, rest_multiaddr=rest_multiaddr
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool:
|
||||||
|
"""
|
||||||
|
Validate that a multiaddr has a valid WebSocket structure.
|
||||||
|
|
||||||
|
:param maddr: The multiaddr to validate
|
||||||
|
:return: True if valid WebSocket structure, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# WebSocket multiaddr should have structure like:
|
||||||
|
# /ip4/127.0.0.1/tcp/8080/ws (insecure)
|
||||||
|
# /ip4/127.0.0.1/tcp/8080/wss (secure)
|
||||||
|
# /ip4/127.0.0.1/tcp/8080/tls/ws (secure with TLS)
|
||||||
|
# /ip4/127.0.0.1/tcp/8080/tls/sni/example.com/ws (secure with SNI)
|
||||||
|
protocols: list[Protocol] = list(maddr.protocols())
|
||||||
|
|
||||||
|
# Must have at least 3 protocols: network (ip4/ip6/dns4/dns6) + tcp + ws/wss
|
||||||
|
if len(protocols) < 3:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# First protocol should be a network protocol (ip4, ip6, dns, dns4, dns6)
|
||||||
|
if protocols[0].name not in ["ip4", "ip6", "dns", "dns4", "dns6"]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Second protocol should be tcp
|
||||||
|
if protocols[1].name != "tcp":
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check for valid WebSocket protocols
|
||||||
|
ws_protocols = ["ws", "wss"]
|
||||||
|
tls_protocols = ["tls"]
|
||||||
|
sni_protocols = ["sni"]
|
||||||
|
|
||||||
|
# Find the WebSocket protocol
|
||||||
|
ws_protocol_found = False
|
||||||
|
tls_found = False
|
||||||
|
# sni_found = False # Not used currently
|
||||||
|
|
||||||
|
for i, protocol in enumerate(protocols[2:], start=2):
|
||||||
|
if protocol.name in ws_protocols:
|
||||||
|
ws_protocol_found = True
|
||||||
|
break
|
||||||
|
elif protocol.name in tls_protocols:
|
||||||
|
tls_found = True
|
||||||
|
elif protocol.name in sni_protocols:
|
||||||
|
pass # sni_found = True # Not used in current implementation
|
||||||
|
|
||||||
|
if not ws_protocol_found:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Validate protocol sequence
|
||||||
|
# For /ws: network + tcp + ws
|
||||||
|
# For /wss: network + tcp + wss
|
||||||
|
# For /tls/ws: network + tcp + tls + ws
|
||||||
|
# For /tls/sni/example.com/ws: network + tcp + tls + sni + ws
|
||||||
|
|
||||||
|
# Check if it's a simple /ws or /wss
|
||||||
|
if len(protocols) == 3:
|
||||||
|
return protocols[2].name in ["ws", "wss"]
|
||||||
|
|
||||||
|
# Check for /tls/ws or /tls/sni/.../ws patterns
|
||||||
|
if tls_found:
|
||||||
|
# Must end with /ws (not /wss when using /tls)
|
||||||
|
if protocols[-1].name != "ws":
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check for valid TLS sequence
|
||||||
|
tls_index = None
|
||||||
|
for i, protocol in enumerate(protocols[2:], start=2):
|
||||||
|
if protocol.name == "tls":
|
||||||
|
tls_index = i
|
||||||
|
break
|
||||||
|
|
||||||
|
if tls_index is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# After tls, we can have sni, then ws
|
||||||
|
remaining_protocols = protocols[tls_index + 1 :]
|
||||||
|
if len(remaining_protocols) == 1:
|
||||||
|
# /tls/ws
|
||||||
|
return remaining_protocols[0].name == "ws"
|
||||||
|
elif len(remaining_protocols) == 2:
|
||||||
|
# /tls/sni/example.com/ws
|
||||||
|
return (
|
||||||
|
remaining_protocols[0].name == "sni"
|
||||||
|
and remaining_protocols[1].name == "ws"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# If we have more than 3 protocols but no TLS, check for valid continuations
|
||||||
|
# Allow additional protocols after the WebSocket protocol (like /p2p)
|
||||||
|
valid_continuations = ["p2p"]
|
||||||
|
|
||||||
|
# Find the WebSocket protocol index
|
||||||
|
ws_index = None
|
||||||
|
for i, protocol in enumerate(protocols):
|
||||||
|
if protocol.name in ["ws", "wss"]:
|
||||||
|
ws_index = i
|
||||||
|
break
|
||||||
|
|
||||||
|
if ws_index is not None:
|
||||||
|
# Check protocols after the WebSocket protocol
|
||||||
|
for i in range(ws_index + 1, len(protocols)):
|
||||||
|
if protocols[i].name not in valid_continuations:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
229
libp2p/transport/websocket/transport.py
Normal file
229
libp2p/transport/websocket/transport.py
Normal file
@ -0,0 +1,229 @@
|
|||||||
|
import logging
|
||||||
|
import ssl
|
||||||
|
|
||||||
|
from multiaddr import Multiaddr
|
||||||
|
|
||||||
|
from libp2p.abc import IListener, ITransport
|
||||||
|
from libp2p.custom_types import THandler
|
||||||
|
from libp2p.network.connection.raw_connection import RawConnection
|
||||||
|
from libp2p.transport.exceptions import OpenConnectionError
|
||||||
|
from libp2p.transport.upgrader import TransportUpgrader
|
||||||
|
from libp2p.transport.websocket.multiaddr_utils import parse_websocket_multiaddr
|
||||||
|
|
||||||
|
from .connection import P2PWebSocketConnection
|
||||||
|
from .listener import WebsocketListener
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WebsocketTransport(ITransport):
|
||||||
|
"""
|
||||||
|
Libp2p WebSocket transport: dial and listen on /ip4/.../tcp/.../ws and /wss
|
||||||
|
|
||||||
|
Implements production-ready WebSocket transport with:
|
||||||
|
- Flow control and buffer management
|
||||||
|
- Connection limits and rate limiting
|
||||||
|
- Proper error handling and cleanup
|
||||||
|
- Support for both WS and WSS protocols
|
||||||
|
- TLS configuration and handshake timeout
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
upgrader: TransportUpgrader,
|
||||||
|
tls_client_config: ssl.SSLContext | None = None,
|
||||||
|
tls_server_config: ssl.SSLContext | None = None,
|
||||||
|
handshake_timeout: float = 15.0,
|
||||||
|
max_buffered_amount: int = 4 * 1024 * 1024,
|
||||||
|
):
|
||||||
|
self._upgrader = upgrader
|
||||||
|
self._tls_client_config = tls_client_config
|
||||||
|
self._tls_server_config = tls_server_config
|
||||||
|
self._handshake_timeout = handshake_timeout
|
||||||
|
self._max_buffered_amount = max_buffered_amount
|
||||||
|
self._connection_count = 0
|
||||||
|
self._max_connections = 1000 # Production limit
|
||||||
|
|
||||||
|
async def dial(self, maddr: Multiaddr) -> RawConnection:
|
||||||
|
"""Dial a WebSocket connection to the given multiaddr."""
|
||||||
|
logger.debug(f"WebsocketTransport.dial called with {maddr}")
|
||||||
|
|
||||||
|
# Parse the WebSocket multiaddr to determine if it's secure
|
||||||
|
try:
|
||||||
|
parsed = parse_websocket_multiaddr(maddr)
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(f"Invalid WebSocket multiaddr: {e}") from e
|
||||||
|
|
||||||
|
# Extract host and port from the base multiaddr
|
||||||
|
host = (
|
||||||
|
parsed.rest_multiaddr.value_for_protocol("ip4")
|
||||||
|
or parsed.rest_multiaddr.value_for_protocol("ip6")
|
||||||
|
or parsed.rest_multiaddr.value_for_protocol("dns")
|
||||||
|
or parsed.rest_multiaddr.value_for_protocol("dns4")
|
||||||
|
or parsed.rest_multiaddr.value_for_protocol("dns6")
|
||||||
|
)
|
||||||
|
port_str = parsed.rest_multiaddr.value_for_protocol("tcp")
|
||||||
|
if port_str is None:
|
||||||
|
raise ValueError(f"No TCP port found in multiaddr: {maddr}")
|
||||||
|
port = int(port_str)
|
||||||
|
|
||||||
|
# Build WebSocket URL based on security
|
||||||
|
if parsed.is_wss:
|
||||||
|
ws_url = f"wss://{host}:{port}/"
|
||||||
|
else:
|
||||||
|
ws_url = f"ws://{host}:{port}/"
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"WebsocketTransport.dial connecting to {ws_url} (secure={parsed.is_wss})"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check connection limits
|
||||||
|
if self._connection_count >= self._max_connections:
|
||||||
|
raise OpenConnectionError(
|
||||||
|
f"Maximum connections reached: {self._max_connections}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare SSL context for WSS connections
|
||||||
|
ssl_context = None
|
||||||
|
if parsed.is_wss:
|
||||||
|
if self._tls_client_config:
|
||||||
|
ssl_context = self._tls_client_config
|
||||||
|
else:
|
||||||
|
# Create default SSL context for client
|
||||||
|
ssl_context = ssl.create_default_context()
|
||||||
|
# Set SNI if available
|
||||||
|
if parsed.sni:
|
||||||
|
ssl_context.check_hostname = False
|
||||||
|
ssl_context.verify_mode = ssl.CERT_NONE
|
||||||
|
|
||||||
|
logger.debug(f"WebsocketTransport.dial opening connection to {ws_url}")
|
||||||
|
|
||||||
|
# Use a different approach: start background nursery that will persist
|
||||||
|
logger.debug("WebsocketTransport.dial establishing connection")
|
||||||
|
|
||||||
|
# Import trio-websocket functions
|
||||||
|
from trio_websocket import connect_websocket
|
||||||
|
from trio_websocket._impl import _url_to_host
|
||||||
|
|
||||||
|
# Parse the WebSocket URL to get host, port, resource
|
||||||
|
# like trio-websocket does
|
||||||
|
ws_host, ws_port, ws_resource, ws_ssl_context = _url_to_host(
|
||||||
|
ws_url, ssl_context
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"WebsocketTransport.dial parsed URL: host={ws_host}, "
|
||||||
|
f"port={ws_port}, resource={ws_resource}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a background task manager for this connection
|
||||||
|
import trio
|
||||||
|
|
||||||
|
nursery_manager = trio.lowlevel.current_task().parent_nursery
|
||||||
|
if nursery_manager is None:
|
||||||
|
raise OpenConnectionError(
|
||||||
|
f"No parent nursery available for WebSocket connection to {maddr}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply timeout to the connection process
|
||||||
|
with trio.fail_after(self._handshake_timeout):
|
||||||
|
logger.debug("WebsocketTransport.dial connecting WebSocket")
|
||||||
|
ws = await connect_websocket(
|
||||||
|
nursery_manager, # Use the existing nursery from libp2p
|
||||||
|
ws_host,
|
||||||
|
ws_port,
|
||||||
|
ws_resource,
|
||||||
|
use_ssl=ws_ssl_context,
|
||||||
|
message_queue_size=1024, # Reasonable defaults
|
||||||
|
max_message_size=16 * 1024 * 1024, # 16MB max message
|
||||||
|
)
|
||||||
|
logger.debug("WebsocketTransport.dial WebSocket connection established")
|
||||||
|
|
||||||
|
# Create our connection wrapper with both WSS support and flow control
|
||||||
|
conn = P2PWebSocketConnection(
|
||||||
|
ws,
|
||||||
|
None,
|
||||||
|
is_secure=parsed.is_wss,
|
||||||
|
max_buffered_amount=self._max_buffered_amount,
|
||||||
|
)
|
||||||
|
logger.debug("WebsocketTransport.dial created P2PWebSocketConnection")
|
||||||
|
|
||||||
|
self._connection_count += 1
|
||||||
|
logger.debug(f"Total connections: {self._connection_count}")
|
||||||
|
|
||||||
|
return RawConnection(conn, initiator=True)
|
||||||
|
except trio.TooSlowError as e:
|
||||||
|
raise OpenConnectionError(
|
||||||
|
f"WebSocket handshake timeout after {self._handshake_timeout}s "
|
||||||
|
f"for {maddr}"
|
||||||
|
) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to dial WebSocket {maddr}: {e}")
|
||||||
|
raise OpenConnectionError(f"Failed to dial WebSocket {maddr}: {e}") from e
|
||||||
|
|
||||||
|
def create_listener(self, handler: THandler) -> IListener: # type: ignore[override]
|
||||||
|
"""
|
||||||
|
The type checker is incorrectly reporting this as an inconsistent override.
|
||||||
|
"""
|
||||||
|
logger.debug("WebsocketTransport.create_listener called")
|
||||||
|
return WebsocketListener(
|
||||||
|
handler, self._upgrader, self._tls_server_config, self._handshake_timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
def resolve(self, maddr: Multiaddr) -> list[Multiaddr]:
|
||||||
|
"""
|
||||||
|
Resolve a WebSocket multiaddr, automatically adding SNI for DNS names.
|
||||||
|
Similar to Go's Resolve() method.
|
||||||
|
|
||||||
|
:param maddr: The multiaddr to resolve
|
||||||
|
:return: List of resolved multiaddrs
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
parsed = parse_websocket_multiaddr(maddr)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.debug(f"Invalid WebSocket multiaddr for resolution: {e}")
|
||||||
|
return [maddr] # Return original if not a valid WebSocket multiaddr
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Parsed multiaddr {maddr}: is_wss={parsed.is_wss}, sni={parsed.sni}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not parsed.is_wss:
|
||||||
|
# No /tls/ws component, this isn't a secure websocket multiaddr
|
||||||
|
return [maddr]
|
||||||
|
|
||||||
|
if parsed.sni is not None:
|
||||||
|
# Already has SNI, return as-is
|
||||||
|
return [maddr]
|
||||||
|
|
||||||
|
# Try to extract DNS name from the base multiaddr
|
||||||
|
dns_name = None
|
||||||
|
for protocol_name in ["dns", "dns4", "dns6"]:
|
||||||
|
try:
|
||||||
|
dns_name = parsed.rest_multiaddr.value_for_protocol(protocol_name)
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if dns_name is None:
|
||||||
|
# No DNS name found, return original
|
||||||
|
return [maddr]
|
||||||
|
|
||||||
|
# Create new multiaddr with SNI
|
||||||
|
# For /dns/example.com/tcp/8080/wss ->
|
||||||
|
# /dns/example.com/tcp/8080/tls/sni/example.com/ws
|
||||||
|
try:
|
||||||
|
# Remove /wss and add /tls/sni/example.com/ws
|
||||||
|
without_wss = maddr.decapsulate(Multiaddr("/wss"))
|
||||||
|
sni_component = Multiaddr(f"/sni/{dns_name}")
|
||||||
|
resolved = (
|
||||||
|
without_wss.encapsulate(Multiaddr("/tls"))
|
||||||
|
.encapsulate(sni_component)
|
||||||
|
.encapsulate(Multiaddr("/ws"))
|
||||||
|
)
|
||||||
|
logger.debug(f"Resolved {maddr} to {resolved}")
|
||||||
|
return [resolved]
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Failed to resolve multiaddr {maddr}: {e}")
|
||||||
|
return [maddr]
|
||||||
@ -3,38 +3,24 @@ from __future__ import annotations
|
|||||||
import socket
|
import socket
|
||||||
|
|
||||||
from multiaddr import Multiaddr
|
from multiaddr import Multiaddr
|
||||||
|
from multiaddr.utils import get_network_addrs, get_thin_waist_addresses
|
||||||
try:
|
|
||||||
from multiaddr.utils import ( # type: ignore
|
|
||||||
get_network_addrs,
|
|
||||||
get_thin_waist_addresses,
|
|
||||||
)
|
|
||||||
|
|
||||||
_HAS_THIN_WAIST = True
|
|
||||||
except ImportError: # pragma: no cover - only executed in older environments
|
|
||||||
_HAS_THIN_WAIST = False
|
|
||||||
get_thin_waist_addresses = None # type: ignore
|
|
||||||
get_network_addrs = None # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
def _safe_get_network_addrs(ip_version: int) -> list[str]:
|
def _safe_get_network_addrs(ip_version: int) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Internal safe wrapper. Returns a list of IP addresses for the requested IP version.
|
Internal safe wrapper. Returns a list of IP addresses for the requested IP version.
|
||||||
Falls back to minimal defaults when Thin Waist helpers are missing.
|
|
||||||
|
|
||||||
:param ip_version: 4 or 6
|
:param ip_version: 4 or 6
|
||||||
"""
|
"""
|
||||||
if _HAS_THIN_WAIST and get_network_addrs:
|
try:
|
||||||
try:
|
return get_network_addrs(ip_version) or []
|
||||||
return get_network_addrs(ip_version) or []
|
except Exception: # pragma: no cover - defensive
|
||||||
except Exception: # pragma: no cover - defensive
|
# Fallback behavior (very conservative)
|
||||||
return []
|
if ip_version == 4:
|
||||||
# Fallback behavior (very conservative)
|
return ["127.0.0.1"]
|
||||||
if ip_version == 4:
|
if ip_version == 6:
|
||||||
return ["127.0.0.1"]
|
return ["::1"]
|
||||||
if ip_version == 6:
|
return []
|
||||||
return ["::1"]
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def find_free_port() -> int:
|
def find_free_port() -> int:
|
||||||
@ -47,16 +33,13 @@ def find_free_port() -> int:
|
|||||||
def _safe_expand(addr: Multiaddr, port: int | None = None) -> list[Multiaddr]:
|
def _safe_expand(addr: Multiaddr, port: int | None = None) -> list[Multiaddr]:
|
||||||
"""
|
"""
|
||||||
Internal safe expansion wrapper. Returns a list of Multiaddr objects.
|
Internal safe expansion wrapper. Returns a list of Multiaddr objects.
|
||||||
If Thin Waist isn't available, returns [addr] (identity).
|
|
||||||
"""
|
"""
|
||||||
if _HAS_THIN_WAIST and get_thin_waist_addresses:
|
try:
|
||||||
try:
|
if port is not None:
|
||||||
if port is not None:
|
return get_thin_waist_addresses(addr, port=port) or []
|
||||||
return get_thin_waist_addresses(addr, port=port) or []
|
return get_thin_waist_addresses(addr) or []
|
||||||
return get_thin_waist_addresses(addr) or []
|
except Exception: # pragma: no cover - defensive
|
||||||
except Exception: # pragma: no cover - defensive
|
return [addr]
|
||||||
return [addr]
|
|
||||||
return [addr]
|
|
||||||
|
|
||||||
|
|
||||||
def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr]:
|
def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr]:
|
||||||
@ -73,8 +56,9 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr
|
|||||||
seen_v4: set[str] = set()
|
seen_v4: set[str] = set()
|
||||||
|
|
||||||
for ip in _safe_get_network_addrs(4):
|
for ip in _safe_get_network_addrs(4):
|
||||||
seen_v4.add(ip)
|
if ip not in seen_v4: # Avoid duplicates
|
||||||
addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}"))
|
seen_v4.add(ip)
|
||||||
|
addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}"))
|
||||||
|
|
||||||
# Ensure IPv4 loopback is always included when IPv4 interfaces are discovered
|
# Ensure IPv4 loopback is always included when IPv4 interfaces are discovered
|
||||||
if seen_v4 and "127.0.0.1" not in seen_v4:
|
if seen_v4 and "127.0.0.1" not in seen_v4:
|
||||||
@ -89,8 +73,9 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr
|
|||||||
#
|
#
|
||||||
# seen_v6: set[str] = set()
|
# seen_v6: set[str] = set()
|
||||||
# for ip in _safe_get_network_addrs(6):
|
# for ip in _safe_get_network_addrs(6):
|
||||||
# seen_v6.add(ip)
|
# if ip not in seen_v6: # Avoid duplicates
|
||||||
# addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}"))
|
# seen_v6.add(ip)
|
||||||
|
# addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}"))
|
||||||
#
|
#
|
||||||
# # Always include IPv6 loopback for testing purposes when IPv6 is available
|
# # Always include IPv6 loopback for testing purposes when IPv6 is available
|
||||||
# # This ensures IPv6 functionality can be tested even without global IPv6 addresses
|
# # This ensures IPv6 functionality can be tested even without global IPv6 addresses
|
||||||
@ -99,7 +84,7 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr
|
|||||||
|
|
||||||
# Fallback if nothing discovered
|
# Fallback if nothing discovered
|
||||||
if not addrs:
|
if not addrs:
|
||||||
addrs.append(Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}"))
|
addrs.append(Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}"))
|
||||||
|
|
||||||
return addrs
|
return addrs
|
||||||
|
|
||||||
@ -120,6 +105,20 @@ def expand_wildcard_address(
|
|||||||
return expanded
|
return expanded
|
||||||
|
|
||||||
|
|
||||||
|
def get_wildcard_address(port: int, protocol: str = "tcp") -> Multiaddr:
|
||||||
|
"""
|
||||||
|
Get wildcard address (0.0.0.0) when explicitly needed.
|
||||||
|
|
||||||
|
This function provides access to wildcard binding as a feature when
|
||||||
|
explicitly required, preserving the ability to bind to all interfaces.
|
||||||
|
|
||||||
|
:param port: Port number.
|
||||||
|
:param protocol: Transport protocol.
|
||||||
|
:return: A Multiaddr with wildcard binding (0.0.0.0).
|
||||||
|
"""
|
||||||
|
return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")
|
||||||
|
|
||||||
|
|
||||||
def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr:
|
def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr:
|
||||||
"""
|
"""
|
||||||
Choose an optimal address for an example to bind to:
|
Choose an optimal address for an example to bind to:
|
||||||
@ -148,13 +147,14 @@ def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr:
|
|||||||
if "/ip4/127." in str(c) or "/ip6/::1" in str(c):
|
if "/ip4/127." in str(c) or "/ip6/::1" in str(c):
|
||||||
return c
|
return c
|
||||||
|
|
||||||
# As a final fallback, produce a wildcard
|
# As a final fallback, produce a loopback address
|
||||||
return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")
|
return Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}")
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_available_interfaces",
|
"get_available_interfaces",
|
||||||
"get_optimal_binding_address",
|
"get_optimal_binding_address",
|
||||||
|
"get_wildcard_address",
|
||||||
"expand_wildcard_address",
|
"expand_wildcard_address",
|
||||||
"find_free_port",
|
"find_free_port",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,7 +1,4 @@
|
|||||||
import atexit
|
import atexit
|
||||||
from datetime import (
|
|
||||||
datetime,
|
|
||||||
)
|
|
||||||
import logging
|
import logging
|
||||||
import logging.handlers
|
import logging.handlers
|
||||||
import os
|
import os
|
||||||
@ -21,6 +18,9 @@ log_queue: "queue.Queue[Any]" = queue.Queue()
|
|||||||
# Store the current listener to stop it on exit
|
# Store the current listener to stop it on exit
|
||||||
_current_listener: logging.handlers.QueueListener | None = None
|
_current_listener: logging.handlers.QueueListener | None = None
|
||||||
|
|
||||||
|
# Store the handlers for proper cleanup
|
||||||
|
_current_handlers: list[logging.Handler] = []
|
||||||
|
|
||||||
# Event to track when the listener is ready
|
# Event to track when the listener is ready
|
||||||
_listener_ready = threading.Event()
|
_listener_ready = threading.Event()
|
||||||
|
|
||||||
@ -95,7 +95,7 @@ def setup_logging() -> None:
|
|||||||
- Child loggers inherit their parent's level unless explicitly set
|
- Child loggers inherit their parent's level unless explicitly set
|
||||||
- The root libp2p logger controls the default level
|
- The root libp2p logger controls the default level
|
||||||
"""
|
"""
|
||||||
global _current_listener, _listener_ready
|
global _current_listener, _listener_ready, _current_handlers
|
||||||
|
|
||||||
# Reset the event
|
# Reset the event
|
||||||
_listener_ready.clear()
|
_listener_ready.clear()
|
||||||
@ -105,6 +105,12 @@ def setup_logging() -> None:
|
|||||||
_current_listener.stop()
|
_current_listener.stop()
|
||||||
_current_listener = None
|
_current_listener = None
|
||||||
|
|
||||||
|
# Close and clear existing handlers
|
||||||
|
for handler in _current_handlers:
|
||||||
|
if isinstance(handler, logging.FileHandler):
|
||||||
|
handler.close()
|
||||||
|
_current_handlers.clear()
|
||||||
|
|
||||||
# Get the log level from environment variable
|
# Get the log level from environment variable
|
||||||
debug_str = os.environ.get("LIBP2P_DEBUG", "")
|
debug_str = os.environ.get("LIBP2P_DEBUG", "")
|
||||||
|
|
||||||
@ -148,13 +154,10 @@ def setup_logging() -> None:
|
|||||||
log_path = Path(log_file)
|
log_path = Path(log_file)
|
||||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
else:
|
else:
|
||||||
# Default log file with timestamp and unique identifier
|
# Use cross-platform temp file creation
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
from libp2p.utils.paths import create_temp_file
|
||||||
unique_id = os.urandom(4).hex() # Add a unique identifier to prevent collisions
|
|
||||||
if os.name == "nt": # Windows
|
log_file = str(create_temp_file(prefix="py-libp2p_", suffix=".log"))
|
||||||
log_file = f"C:\\Windows\\Temp\\py-libp2p_{timestamp}_{unique_id}.log"
|
|
||||||
else: # Unix-like
|
|
||||||
log_file = f"/tmp/py-libp2p_{timestamp}_{unique_id}.log"
|
|
||||||
|
|
||||||
# Print the log file path so users know where to find it
|
# Print the log file path so users know where to find it
|
||||||
print(f"Logging to: {log_file}", file=sys.stderr)
|
print(f"Logging to: {log_file}", file=sys.stderr)
|
||||||
@ -195,6 +198,9 @@ def setup_logging() -> None:
|
|||||||
logger.setLevel(level)
|
logger.setLevel(level)
|
||||||
logger.propagate = False # Prevent message duplication
|
logger.propagate = False # Prevent message duplication
|
||||||
|
|
||||||
|
# Store handlers globally for cleanup
|
||||||
|
_current_handlers.extend(handlers)
|
||||||
|
|
||||||
# Start the listener AFTER configuring all loggers
|
# Start the listener AFTER configuring all loggers
|
||||||
_current_listener = logging.handlers.QueueListener(
|
_current_listener = logging.handlers.QueueListener(
|
||||||
log_queue, *handlers, respect_handler_level=True
|
log_queue, *handlers, respect_handler_level=True
|
||||||
@ -209,7 +215,13 @@ def setup_logging() -> None:
|
|||||||
@atexit.register
|
@atexit.register
|
||||||
def cleanup_logging() -> None:
|
def cleanup_logging() -> None:
|
||||||
"""Clean up logging resources on exit."""
|
"""Clean up logging resources on exit."""
|
||||||
global _current_listener
|
global _current_listener, _current_handlers
|
||||||
if _current_listener is not None:
|
if _current_listener is not None:
|
||||||
_current_listener.stop()
|
_current_listener.stop()
|
||||||
_current_listener = None
|
_current_listener = None
|
||||||
|
|
||||||
|
# Close all file handlers to ensure proper cleanup on Windows
|
||||||
|
for handler in _current_handlers:
|
||||||
|
if isinstance(handler, logging.FileHandler):
|
||||||
|
handler.close()
|
||||||
|
_current_handlers.clear()
|
||||||
|
|||||||
267
libp2p/utils/paths.py
Normal file
267
libp2p/utils/paths.py
Normal file
@ -0,0 +1,267 @@
|
|||||||
|
"""
|
||||||
|
Cross-platform path utilities for py-libp2p.
|
||||||
|
|
||||||
|
This module provides standardized path operations to ensure consistent
|
||||||
|
behavior across Windows, macOS, and Linux platforms.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
PathLike = Union[str, Path]
|
||||||
|
|
||||||
|
|
||||||
|
def get_temp_dir() -> Path:
|
||||||
|
"""
|
||||||
|
Get cross-platform temporary directory.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: Platform-specific temporary directory path
|
||||||
|
|
||||||
|
"""
|
||||||
|
return Path(tempfile.gettempdir())
|
||||||
|
|
||||||
|
|
||||||
|
def get_project_root() -> Path:
|
||||||
|
"""
|
||||||
|
Get the project root directory.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: Path to the py-libp2p project root
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Navigate from libp2p/utils/paths.py to project root
|
||||||
|
return Path(__file__).parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
def join_paths(*parts: PathLike) -> Path:
|
||||||
|
"""
|
||||||
|
Cross-platform path joining.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*parts: Path components to join
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: Joined path using platform-appropriate separator
|
||||||
|
|
||||||
|
"""
|
||||||
|
return Path(*parts)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_dir_exists(path: PathLike) -> Path:
|
||||||
|
"""
|
||||||
|
Ensure directory exists, create if needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Directory path to ensure exists
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: Path object for the directory
|
||||||
|
|
||||||
|
"""
|
||||||
|
path_obj = Path(path)
|
||||||
|
path_obj.mkdir(parents=True, exist_ok=True)
|
||||||
|
return path_obj
|
||||||
|
|
||||||
|
|
||||||
|
def get_config_dir() -> Path:
|
||||||
|
"""
|
||||||
|
Get user config directory (cross-platform).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: Platform-specific config directory
|
||||||
|
|
||||||
|
"""
|
||||||
|
if os.name == "nt": # Windows
|
||||||
|
appdata = os.environ.get("APPDATA", "")
|
||||||
|
if appdata:
|
||||||
|
return Path(appdata) / "py-libp2p"
|
||||||
|
else:
|
||||||
|
# Fallback to user home directory
|
||||||
|
return Path.home() / "AppData" / "Roaming" / "py-libp2p"
|
||||||
|
else: # Unix-like (Linux, macOS)
|
||||||
|
return Path.home() / ".config" / "py-libp2p"
|
||||||
|
|
||||||
|
|
||||||
|
def get_script_dir(script_path: PathLike | None = None) -> Path:
|
||||||
|
"""
|
||||||
|
Get the directory containing a script file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
script_path: Path to the script file. If None, uses __file__
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: Directory containing the script
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If script path cannot be determined
|
||||||
|
|
||||||
|
"""
|
||||||
|
if script_path is None:
|
||||||
|
# This will be the directory of the calling script
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
frame = inspect.currentframe()
|
||||||
|
if frame and frame.f_back:
|
||||||
|
script_path = frame.f_back.f_globals.get("__file__")
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Could not determine script path")
|
||||||
|
|
||||||
|
if script_path is None:
|
||||||
|
raise RuntimeError("Script path is None")
|
||||||
|
|
||||||
|
return Path(script_path).parent.absolute()
|
||||||
|
|
||||||
|
|
||||||
|
def create_temp_file(prefix: str = "py-libp2p_", suffix: str = ".log") -> Path:
|
||||||
|
"""
|
||||||
|
Create a temporary file with a unique name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prefix: File name prefix
|
||||||
|
suffix: File name suffix
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: Path to the created temporary file
|
||||||
|
|
||||||
|
"""
|
||||||
|
temp_dir = get_temp_dir()
|
||||||
|
# Create a unique filename using timestamp and random bytes
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
|
||||||
|
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||||
|
microseconds = f"{time.time() % 1:.6f}"[2:] # Get microseconds as string
|
||||||
|
unique_id = secrets.token_hex(4)
|
||||||
|
filename = f"{prefix}{timestamp}_{microseconds}_{unique_id}{suffix}"
|
||||||
|
|
||||||
|
temp_file = temp_dir / filename
|
||||||
|
# Create the file by touching it
|
||||||
|
temp_file.touch()
|
||||||
|
return temp_file
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_relative_path(base_path: PathLike, relative_path: PathLike) -> Path:
|
||||||
|
"""
|
||||||
|
Resolve a relative path from a base path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_path: Base directory path
|
||||||
|
relative_path: Relative path to resolve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: Resolved absolute path
|
||||||
|
|
||||||
|
"""
|
||||||
|
base = Path(base_path).resolve()
|
||||||
|
relative = Path(relative_path)
|
||||||
|
|
||||||
|
if relative.is_absolute():
|
||||||
|
return relative
|
||||||
|
else:
|
||||||
|
return (base / relative).resolve()
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_path(path: PathLike) -> Path:
|
||||||
|
"""
|
||||||
|
Normalize a path, resolving any symbolic links and relative components.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Path to normalize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: Normalized absolute path
|
||||||
|
|
||||||
|
"""
|
||||||
|
return Path(path).resolve()
|
||||||
|
|
||||||
|
|
||||||
|
def get_venv_path() -> Path | None:
|
||||||
|
"""
|
||||||
|
Get virtual environment path if active.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: Virtual environment path if active, None otherwise
|
||||||
|
|
||||||
|
"""
|
||||||
|
venv_path = os.environ.get("VIRTUAL_ENV")
|
||||||
|
if venv_path:
|
||||||
|
return Path(venv_path)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_python_executable() -> Path:
|
||||||
|
"""
|
||||||
|
Get current Python executable path.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: Path to the current Python executable
|
||||||
|
|
||||||
|
"""
|
||||||
|
return Path(sys.executable)
|
||||||
|
|
||||||
|
|
||||||
|
def find_executable(name: str) -> Path | None:
|
||||||
|
"""
|
||||||
|
Find executable in system PATH.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Name of the executable to find
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: Path to executable if found, None otherwise
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Check if name already contains path
|
||||||
|
if os.path.dirname(name):
|
||||||
|
path = Path(name)
|
||||||
|
if path.exists() and os.access(path, os.X_OK):
|
||||||
|
return path
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Search in PATH
|
||||||
|
for path_dir in os.environ.get("PATH", "").split(os.pathsep):
|
||||||
|
if not path_dir:
|
||||||
|
continue
|
||||||
|
path = Path(path_dir) / name
|
||||||
|
if path.exists() and os.access(path, os.X_OK):
|
||||||
|
return path
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_script_binary_path() -> Path:
|
||||||
|
"""
|
||||||
|
Get path to script's binary directory.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: Directory containing the script's binary
|
||||||
|
|
||||||
|
"""
|
||||||
|
return get_python_executable().parent
|
||||||
|
|
||||||
|
|
||||||
|
def get_binary_path(binary_name: str) -> Path | None:
|
||||||
|
"""
|
||||||
|
Find binary in PATH or virtual environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
binary_name: Name of the binary to find
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: Path to binary if found, None otherwise
|
||||||
|
|
||||||
|
"""
|
||||||
|
# First check in virtual environment if active
|
||||||
|
venv_path = get_venv_path()
|
||||||
|
if venv_path:
|
||||||
|
venv_bin = venv_path / "bin" if os.name != "nt" else venv_path / "Scripts"
|
||||||
|
binary_path = venv_bin / binary_name
|
||||||
|
if binary_path.exists() and os.access(binary_path, os.X_OK):
|
||||||
|
return binary_path
|
||||||
|
|
||||||
|
# Fall back to system PATH
|
||||||
|
return find_executable(binary_name)
|
||||||
12
newsfragments/585.feature.rst
Normal file
12
newsfragments/585.feature.rst
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
Added experimental WebSocket transport support with basic WS and WSS functionality. This includes:
|
||||||
|
|
||||||
|
- WebSocket transport implementation with trio-websocket backend
|
||||||
|
- Support for both WS (WebSocket) and WSS (WebSocket Secure) protocols
|
||||||
|
- Basic connection management and stream handling
|
||||||
|
- TLS configuration support for WSS connections
|
||||||
|
- Multiaddr parsing for WebSocket addresses
|
||||||
|
- Integration with libp2p host and peer discovery
|
||||||
|
|
||||||
|
**Note**: This is experimental functionality. Advanced features like proxy support,
|
||||||
|
interop testing, and production examples are still in development. See
|
||||||
|
https://github.com/libp2p/py-libp2p/discussions/937 for the complete roadmap of missing features.
|
||||||
1
newsfragments/735.internal.rst
Normal file
1
newsfragments/735.internal.rst
Normal file
@ -0,0 +1 @@
|
|||||||
|
Improve relay selection by load balancing and reservation priority.
|
||||||
1
newsfragments/763.feature.rst
Normal file
1
newsfragments/763.feature.rst
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add QUIC transport support for faster, more efficient peer-to-peer connections with native stream multiplexing.
|
||||||
1
newsfragments/843.bugfix.rst
Normal file
1
newsfragments/843.bugfix.rst
Normal file
@ -0,0 +1 @@
|
|||||||
|
Fixed message id type inconsistency in handle ihave and message id parsing improvement in handle iwant in pubsub module.
|
||||||
1
newsfragments/849.feature.rst
Normal file
1
newsfragments/849.feature.rst
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add automatic peer dialing in bootstrap module using trio.Nursery.
|
||||||
2
newsfragments/885.feature.rst
Normal file
2
newsfragments/885.feature.rst
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
Updated all example scripts and core modules to use secure loopback addresses instead of wildcard addresses for network binding.
|
||||||
|
The `get_wildcard_address` function and related logic now utilize all available interfaces safely, improving security and consistency across the codebase.
|
||||||
2
newsfragments/886.bugfix.rst
Normal file
2
newsfragments/886.bugfix.rst
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
Fixed cross-platform path handling by replacing hardcoded OS-specific
|
||||||
|
paths with standardized utilities in core modules and examples.
|
||||||
1
newsfragments/896.bugfix.rst
Normal file
1
newsfragments/896.bugfix.rst
Normal file
@ -0,0 +1 @@
|
|||||||
|
Exposed timeout method in muxer multistream and updated all the usage. Added testcases to verify that timeout value is passed correctly
|
||||||
6
newsfragments/897.bugfix.rst
Normal file
6
newsfragments/897.bugfix.rst
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
enhancement: Add write lock to `YamuxStream` to prevent concurrent write race conditions
|
||||||
|
|
||||||
|
- Implements ReadWriteLock for `YamuxStream` write operations
|
||||||
|
- Prevents data corruption from concurrent write operations
|
||||||
|
- Read operations remain lock-free due to existing `Yamux` architecture
|
||||||
|
- Resolves race conditions identified in Issue #793
|
||||||
11
newsfragments/917.internal.rst
Normal file
11
newsfragments/917.internal.rst
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
Replace magic numbers with named constants and enums for clarity and maintainability
|
||||||
|
|
||||||
|
**Key Changes:**
|
||||||
|
- **Introduced type-safe enums** for better code clarity:
|
||||||
|
- `RelayRole(Flag)` enum with HOP, STOP, CLIENT roles supporting bitwise combinations (e.g., `RelayRole.HOP | RelayRole.STOP`)
|
||||||
|
- `ReservationStatus(Enum)` for reservation lifecycle management (ACTIVE, EXPIRED, REJECTED)
|
||||||
|
- **Replaced magic numbers with named constants** throughout the codebase, improving code maintainability and eliminating hardcoded timeout values (15s, 30s, 10s) with descriptive constant names
|
||||||
|
- **Added comprehensive timeout configuration system** with new `TimeoutConfig` dataclass supporting component-specific timeouts (discovery, protocol, DCUtR)
|
||||||
|
- **Enhanced configurability** of `RelayDiscovery`, `CircuitV2Protocol`, and `DCUtRProtocol` constructors with optional timeout parameters
|
||||||
|
- **Improved architecture consistency** with clean configuration flow across all circuit relay components
|
||||||
|
**Backward Compatibility:** All changes maintain full backward compatibility. Existing code continues to work unchanged while new timeout configuration options are available for users who need them.
|
||||||
1
newsfragments/927.bugfix.rst
Normal file
1
newsfragments/927.bugfix.rst
Normal file
@ -0,0 +1 @@
|
|||||||
|
Fix multiaddr dependency to use the last py-multiaddr commit hash to resolve installation issues
|
||||||
1
newsfragments/934.misc.rst
Normal file
1
newsfragments/934.misc.rst
Normal file
@ -0,0 +1 @@
|
|||||||
|
Updated multiaddr dependency from git repository to pip package version 0.0.11.
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user